Fix adjoint for NonlinearSolution constructor #998
Conversation
|
Tests fail. |
SciML/SciMLSensitivity.jl#1189 highlights that the downstream tests that are failing for SciMLSensitivity.jl are actually Enzyme issues with v1.11. These are now being tracked here EnzymeAD/Enzyme.jl#2318 for the Enzyme developers to work on. But there are certain things we've been weary about, like: * SciML/SciMLBase.jl#997 * SciML/SciMLBase.jl#998 Because of failing downstream tests. But that's counter productive: we're not improving our autodiff interfaces because we see failures, but those aren't failures of our autodiff interfaces, those are Enzyme failures. So for now the solution seems to be to go to v1.10 in these downstream tests and increase the coverage of SciMLSensitivity, and focus on our parts. We can re-enable "1" when Enzyme is ready for it, but for now it's just noise.
SciML/SciMLSensitivity.jl#1189 highlights that the downstream tests that are failing for SciMLSensitivity.jl are actually Enzyme issues with v1.11. These are now being tracked here EnzymeAD/Enzyme.jl#2318 for the Enzyme developers to work on. But there are certain things we've been weary about, like: * #997 * #998 Because of failing downstream tests. But that's counter productive: we're not improving our autodiff interfaces because we see failures, but those aren't failures of our autodiff interfaces, those are Enzyme failures. So for now the solution seems to be to go to v1.10 in these downstream tests and increase the coverage of SciMLSensitivity, and focus on our parts. We can re-enable "1" when Enzyme is ready for it, but for now it's just noise.
| T, N, uType, R, P, A, O, uType2, S, Tr}}, u, | ||
| args...) where {T, N, uType, R, P, A, O, uType2, S, Tr} | ||
| function NonlinearSolutionAdjoint(ȳ) | ||
| (NoTangent(), ȳ.u, ntuple(_ -> NoTangent(), length(args))...) |
There was a problem hiding this comment.
We should probably track .prob as well (for gradients against parameters).
Also what is the type of ȳ ? What types could we encounter here? If it's another NonlinearSolution, we should investigate what produces it and potentially try to return a tangent type
There was a problem hiding this comment.
an mwe is
@parameters g
@variables x(t) y(t) [state_priority = 10] λ(t)
eqs = [D(D(x)) ~ λ * x
D(D(y)) ~ λ * y - g
x^2 + y^2 ~ 1]
@mtkbuild pend = ODESystem(eqs, t)
prob = ODEProblem(pend, [x => 1, y => 0], (0.0, 1.5), [g => 1.5], guesses = [λ => 1])
sol = solve(prob, Rodas5P())
get_vars = getsym(prob, [pend.x+pend.y]);
Zygote.gradient(sol) do sol
u = get_vars(sol)
# u = sol[pend.x+pend.y]
sum(reduce(vcat, u))
endwhich points to this branch
SciMLBase.jl/ext/SciMLBaseZygoteExt.jl
Line 106 in 42cdf6a
SciMLBase.jl/ext/SciMLBaseZygoteExt.jl
Line 188 in 42cdf6a
There was a problem hiding this comment.
ȳ was always a NonlinearSolution when I was testing.
There was a problem hiding this comment.
Sorry, just now getting back to this.
To track .prob, would it be like this:
function ChainRulesCore.rrule(
::Type{<:SciMLBase.NonlinearSolution{
T, N, uType, R, P, A, O, uType2, S, Tr}}, u, resid, prob,
args...) where {T, N, uType, R, P, A, O, uType2, S, Tr}
function NonlinearSolutionAdjoint(ȳ)
(NoTangent(), ȳ.u, NoTangent(), ŷ.prob, ntuple(_ -> NoTangent(), length(args))...)
end
SciMLBase.NonlinearSolution{T, N, uType, R, P, A, O, uType2, S, Tr}(u, resid, prob, args...),
NonlinearSolutionAdjoint
end?
There was a problem hiding this comment.
Yes this looks reasonable
| T, N, uType, R, P, A, O, uType2, S, Tr}}, u, resid, prob, | ||
| args...) where {T, N, uType, R, P, A, O, uType2, S, Tr} | ||
| function NonlinearSolutionAdjoint(ȳ) | ||
| (NoTangent(), ȳ.u, NoTangent(), ŷ.prob, ntuple(_ -> NoTangent(), length(args))...) |
There was a problem hiding this comment.
As a follow up, is the type of ȳ a solution type still with the latest SciMLSensitivity
There was a problem hiding this comment.
ȳ is a Tangent type with the latest SciMLSensitivity:
typeof(ȳ) = Tangent{Any}(u = 1.0, resid = ChainRulesCore.ZeroTangent(), prob = ChainRulesCore.ZeroTangent(), alg = ChainRulesCore.ZeroTangent(), retcode = ChainRulesCore.ZeroTangent(), original = ChainRulesCore.ZeroTangent(), left = ChainRulesCore.ZeroTangent(), right = ChainRulesCore.ZeroTangent(), stats = ChainRulesCore.ZeroTangent(), trace = ChainRulesCore.ZeroTangent())031520d to
d948f07
Compare
Checklist
contributor guidelines, in particular the SciML Style Guide and
COLPRAC.
Additional context
Should hopefully fix SciML/NonlinearSolve.jl#581, in conjuction with #997