Skip to content

Custom vmap implementation#25

Open
zou3519 wants to merge 1 commit into
mainfrom
custom_vmap
Open

Custom vmap implementation#25
zou3519 wants to merge 1 commit into
mainfrom
custom_vmap

Conversation

@zou3519

@zou3519 zou3519 commented Apr 18, 2022

Copy link
Copy Markdown
Collaborator

TODO: needs description of what is going on

TODO: needs description of what is going on
Comment thread simple_functorch.py
Comment on lines +768 to +769
result = batch_rule(self.inner, *args)
return result

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unlike custom_vjp, custom_vmap does not call custom_vmap on the inner dispatcher!

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how come. It doesn't seem to me like Batched(Batched()) wouldn't apply. Perhaps you are saying, it is impossible for a batching rule to recursively refer to itself?

Comment thread simple_functorch.py
loss88 = d.sum(result, name='loss88')

grad_x, = d.grad(loss88, [x])
assert torch.allclose(grad_x, torch.full_like(x, 2))

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think I agree with this?
I would expect that Autograd(Batched(Torch(), length=2)) would give 2 * x as the gradient while Batched(Autograd(Torch()), length=2) would give 2. No?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Assuming someone doesn't use this to completely change the values of the output (like we're doing here), then the output values should be the same. The difference is how the backward pass is being executed.

As written in this PR right now Autograd(Batched(Torch(), length=2)) executes the backward pass of the batching rule, not the backward pass of the original function. To check, your claim is that Autograd(Batched(Torch(), length=2)) should execute the backward pass of the original function, not the backward pass of the batching rule, right?

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes.
I would expect here to be able to see the different between batch of gradients and element-wise gradients.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@albanD for Batched(Autograd(Torch()) and a custom_vjp(f_fwd, f_bwd, *args) call, what would you expect to happen?

Option 1: The backward pass differentiates through the batching rule for f_fwd

Option 2: The backward pass runs vmap(f_bwd)

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't that question beyond what we're discussing here?
The question is still there without considering any custom_vjp right?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh I'm just trying to understand the difference between this and custom_vjp. I think the current semantics are analogous to what custom_vjp is doing but reasoning through it is confusing

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cc @ezyang @Chillee -- this is what we were discussing during the Composability hangout, if you folks have opinions

@samdow samdow Apr 19, 2022

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Disregarding custom_vjp for a moment, if we're going off of the argument from the meeting today that this should match the behavior of a normal batch rule, isn't this code right?

Let's say we're calling d2.unsqueeze(x, dim) with d2 = Batched(Autograd(Torch())). First it hits the Batched dispatcher, so we get d2.inner.unsqueeze(x, dim + 1) so the Autograd dispatcher sees unsqueeze(x, dim + 1) (which is the batch rule) and does autograd on that function. Using the same dispatcher stack, we similarly expect autograd runs on the "batch rule"

As a related a note, I always get confused that Batched(Autograd(Torch())) is the same as grad(vmap()) so it might be worth to add the transform implementations if that's not too much work? I think this makes a lot more sense if we're able to say grad(vmap()) gets the derivative of the custom batch rule rather than remembering to invert the interpreter stack.

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree that the current Dispatcher for the regular code does that.
But now if we call the custom_vmap with this unsqueeze function does it do that as well?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But now if we call the custom_vmap with this unsqueeze function does it do that as well?

Yes?

More explanation:

I would expect that Autograd(Batched(Torch(), length=2)) would give 2 * x as the gradient

Given that the original function is d.mul(x, x), this would mean that autograd is running on the unbatched function, not the batch rule. If we agree that given this set of Dispatchers should end up with Autograd running on the batched rule, then 2 is the expected gradient because d.add(x, x) is the batched rule and the derivative of that is 2

Comment thread simple_functorch.py
def wrapped(d, *args):
saved = self.inner
try:
self.inner = d

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

woooow so spicy

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was going to say that you can do this without mutating the dispatcher stack by simply creating a new Autograd dispatcher on the fly, whose inner is d, but then the tapes would not be shared. This seems... dubious.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess if we represent the tape with an extra indirection this isn't too hard to do. Probably better and then makes this rule nicely symmetric for how custom_vjp is implemented in Batched.

Comment thread simple_functorch.py
)
return r, saved

def custom_vmap(self, fn, batch_rule, *args):

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This appears to diverge substantially from JAX's custom_vmap, at https://github.com/google/jax/blob/main/jax/_src/custom_batching.py

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ezyang is your comment that the API is different, the implementation is different, or both?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

implementation. But later I worked out that this is exactly analogous to how we did batching, so... idk, maybe it's still fine

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants