Custom vmap implementation#25
Conversation
TODO: needs description of what is going on
| result = batch_rule(self.inner, *args) | ||
| return result |
There was a problem hiding this comment.
Unlike custom_vjp, custom_vmap does not call custom_vmap on the inner dispatcher!
There was a problem hiding this comment.
how come. It doesn't seem to me like Batched(Batched()) wouldn't apply. Perhaps you are saying, it is impossible for a batching rule to recursively refer to itself?
| loss88 = d.sum(result, name='loss88') | ||
|
|
||
| grad_x, = d.grad(loss88, [x]) | ||
| assert torch.allclose(grad_x, torch.full_like(x, 2)) |
There was a problem hiding this comment.
I don't think I agree with this?
I would expect that Autograd(Batched(Torch(), length=2)) would give 2 * x as the gradient while Batched(Autograd(Torch()), length=2) would give 2. No?
There was a problem hiding this comment.
Assuming someone doesn't use this to completely change the values of the output (like we're doing here), then the output values should be the same. The difference is how the backward pass is being executed.
As written in this PR right now Autograd(Batched(Torch(), length=2)) executes the backward pass of the batching rule, not the backward pass of the original function. To check, your claim is that Autograd(Batched(Torch(), length=2)) should execute the backward pass of the original function, not the backward pass of the batching rule, right?
There was a problem hiding this comment.
Yes.
I would expect here to be able to see the different between batch of gradients and element-wise gradients.
There was a problem hiding this comment.
@albanD for Batched(Autograd(Torch()) and a custom_vjp(f_fwd, f_bwd, *args) call, what would you expect to happen?
Option 1: The backward pass differentiates through the batching rule for f_fwd
Option 2: The backward pass runs vmap(f_bwd)
There was a problem hiding this comment.
Isn't that question beyond what we're discussing here?
The question is still there without considering any custom_vjp right?
There was a problem hiding this comment.
Oh I'm just trying to understand the difference between this and custom_vjp. I think the current semantics are analogous to what custom_vjp is doing but reasoning through it is confusing
There was a problem hiding this comment.
Disregarding custom_vjp for a moment, if we're going off of the argument from the meeting today that this should match the behavior of a normal batch rule, isn't this code right?
Let's say we're calling d2.unsqueeze(x, dim) with d2 = Batched(Autograd(Torch())). First it hits the Batched dispatcher, so we get d2.inner.unsqueeze(x, dim + 1) so the Autograd dispatcher sees unsqueeze(x, dim + 1) (which is the batch rule) and does autograd on that function. Using the same dispatcher stack, we similarly expect autograd runs on the "batch rule"
As a related a note, I always get confused that Batched(Autograd(Torch())) is the same as grad(vmap()) so it might be worth to add the transform implementations if that's not too much work? I think this makes a lot more sense if we're able to say grad(vmap()) gets the derivative of the custom batch rule rather than remembering to invert the interpreter stack.
There was a problem hiding this comment.
I agree that the current Dispatcher for the regular code does that.
But now if we call the custom_vmap with this unsqueeze function does it do that as well?
There was a problem hiding this comment.
But now if we call the custom_vmap with this unsqueeze function does it do that as well?
Yes?
More explanation:
I would expect that Autograd(Batched(Torch(), length=2)) would give 2 * x as the gradient
Given that the original function is d.mul(x, x), this would mean that autograd is running on the unbatched function, not the batch rule. If we agree that given this set of Dispatchers should end up with Autograd running on the batched rule, then 2 is the expected gradient because d.add(x, x) is the batched rule and the derivative of that is 2
| def wrapped(d, *args): | ||
| saved = self.inner | ||
| try: | ||
| self.inner = d |
There was a problem hiding this comment.
I was going to say that you can do this without mutating the dispatcher stack by simply creating a new Autograd dispatcher on the fly, whose inner is d, but then the tapes would not be shared. This seems... dubious.
There was a problem hiding this comment.
I guess if we represent the tape with an extra indirection this isn't too hard to do. Probably better and then makes this rule nicely symmetric for how custom_vjp is implemented in Batched.
| ) | ||
| return r, saved | ||
|
|
||
| def custom_vmap(self, fn, batch_rule, *args): |
There was a problem hiding this comment.
This appears to diverge substantially from JAX's custom_vmap, at https://github.com/google/jax/blob/main/jax/_src/custom_batching.py
There was a problem hiding this comment.
@ezyang is your comment that the API is different, the implementation is different, or both?
There was a problem hiding this comment.
implementation. But later I worked out that this is exactly analogous to how we did batching, so... idk, maybe it's still fine
TODO: needs description of what is going on