custom_vjp bug#31
Open
zou3519 wants to merge 1 commit into
Open
Conversation
zou3519
commented
May 2, 2022
Comment on lines
+1323
to
+1347
| """ | ||
| import jax | ||
| import jax.numpy as jnp | ||
|
|
||
| x = jnp.array(0.123) | ||
|
|
||
| @jax.custom_vjp | ||
| def jax_square(x): | ||
| return None | ||
|
|
||
| def f_fwd(x): | ||
| two_x = 2 * x | ||
| return x ** 2, (two_x,) | ||
|
|
||
| def f_bwd(saved, grad_output): | ||
| two_x, = saved | ||
| return (grad_output * two_x,) | ||
|
|
||
| jax_square.defvjp(f_fwd, f_bwd) | ||
|
|
||
| # This is 2.0 which is correct | ||
| ddx = jax.grad(jax.grad(jax_square))(x) | ||
| print(ddx) | ||
| """ |
Collaborator
Author
There was a problem hiding this comment.
This is the "reference". Note that JAX's f_fwd returns an intermediate, two_x and uses it in the gradient computation. Computing the second order gradient is correct.
zou3519
commented
May 2, 2022
Comment on lines
+1316
to
+1321
| # Bug 0: | ||
| # This is wrong: The result is None, but it should be 2. | ||
| # Somehow the gradients aren't getting recorded. | ||
| ddx = run_gradgrad(outer, inner, x) | ||
| import pdb; pdb.set_trace() |
Collaborator
Author
There was a problem hiding this comment.
Bug 0: ddx should equal two, but it is actually None here. This behavior is in-line with with autograd.Function does today, actually. autograd.Function requires the user to specify a gradient formula for the intermediate (two_x).
However, we're trying to explore what it would take to not require the user to specify the gradient formula, so let's not end this discussion at "this is expected".
zou3519
commented
May 2, 2022
Comment on lines
+1306
to
+1307
| # print(outer.gradient_tape) | ||
| # Bug 1: outer.grad still extend outer.gradient_tape, even though create_graph is False. |
Collaborator
Author
There was a problem hiding this comment.
Bug 1, which is probably related
There's a bug in simple functorch's custom_vjp. In particular, we would like to learn what it would take to get simple functorch's custom_vjp to work like JAX's custom_vjp w.r.t. to the behavior towards intermediate Tensors. TODO: we should try to fix the bugs mentioned.
zou3519
commented
May 2, 2022
Comment on lines
+1309
to
+1314
| # Bug 2: inner.gradient_tape | ||
| # The second TapeEntry doesn't use x at all! In fact, no tapes | ||
| # capture the 2 * x behavior. | ||
| # [TapeEntry(inputs=['x'], outputs=['v199'], propagate=<function Autograd.custom_vjp.<locals> | ||
| # .propagate at 0x7f8310658a60>), TapeEntry(inputs=['v200', 'v198'], outputs=['v201'], propagate | ||
| # =<function Autograd.mul.<locals>.propagate at 0x7f83106a3dc0>)] |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
There's a bug in simple functorch's custom_vjp. In particular, we would
like to learn what it would take to get simple functorch's custom_vjp to
work like JAX's custom_vjp w.r.t. to the behavior towards intermediate
Tensors.
TODO: we should try to fix the bugs mentioned.