Fix adjoint for NonlinearSolution constructor #998

Open
jClugstor wants to merge 4 commits into SciML:master from jClugstor:nonlinearsolution_adjoint

Conversation

@jClugstor

Member

Checklist

  • Appropriate tests were added
  • Any code changes were done in a way that does not break public API
  • All documentation related to code changes was updated
  • The new code follows the contributor guidelines, in particular the SciML Style Guide and COLPRAC.
  • Any new documentation only uses public API

Additional context

Should hopefully fix SciML/NonlinearSolve.jl#581, in conjunction with #997.

@ChrisRackauckas

Member

Tests fail.

ChrisRackauckas added a commit to SciML/DiffEqBase.jl that referenced this pull request Apr 25, 2025
SciML/SciMLSensitivity.jl#1189 highlights that the downstream tests that are failing for SciMLSensitivity.jl are actually Enzyme issues with v1.11. These are now being tracked in EnzymeAD/Enzyme.jl#2318 for the Enzyme developers to work on. But there are certain things we've been wary about, like:

* SciML/SciMLBase.jl#997
* SciML/SciMLBase.jl#998

Because of failing downstream tests. But that's counterproductive: we're not improving our autodiff interfaces because we see failures, but those aren't failures of our autodiff interfaces, those are Enzyme failures.

So for now the solution seems to be to go to v1.10 in these downstream tests and increase the coverage of SciMLSensitivity, and focus on our parts. We can re-enable "1" when Enzyme is ready for it, but for now it's just noise.
ChrisRackauckas added a commit that referenced this pull request Apr 25, 2025
SciML/SciMLSensitivity.jl#1189 highlights that the downstream tests that are failing for SciMLSensitivity.jl are actually Enzyme issues with v1.11. These are now being tracked in EnzymeAD/Enzyme.jl#2318 for the Enzyme developers to work on. But there are certain things we've been wary about, like:

* #997
* #998

Because of failing downstream tests. But that's counterproductive: we're not improving our autodiff interfaces because we see failures, but those aren't failures of our autodiff interfaces, those are Enzyme failures.

So for now the solution seems to be to go to v1.10 in these downstream tests and increase the coverage of SciMLSensitivity, and focus on our parts. We can re-enable "1" when Enzyme is ready for it, but for now it's just noise.
Comment thread ext/SciMLBaseChainRulesCoreExt.jl Outdated
T, N, uType, R, P, A, O, uType2, S, Tr}}, u,
args...) where {T, N, uType, R, P, A, O, uType2, S, Tr}
function NonlinearSolutionAdjoint(ȳ)
(NoTangent(), ȳ.u, ntuple(_ -> NoTangent(), length(args))...)

Member

We should probably track .prob as well (for gradients against parameters).

Also, what is the type of ȳ? What types could we encounter here? If it's another NonlinearSolution, we should investigate what produces it and potentially try to return a tangent type.

Member

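An MWE is:

# Assumed imports (ModelingToolkit v9 naming for t and D); not part of the original comment
using ModelingToolkit, OrdinaryDiffEq, Zygote
using ModelingToolkit: t_nounits as t, D_nounits as D
using SymbolicIndexingInterface: getsym

@parameters g
@variables x(t) y(t) [state_priority = 10] λ(t)
eqs = [D(D(x)) ~ λ * x
       D(D(y)) ~ λ * y - g
       x^2 + y^2 ~ 1]
@mtkbuild pend = ODESystem(eqs, t)

prob = ODEProblem(pend, [x => 1, y => 0], (0.0, 1.5), [g => 1.5], guesses = [λ => 1])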

sol = solve(prob, Rodas5P())

get_vars = getsym(prob, [pend.x+pend.y]);

Zygote.gradient(sol) do sol
    u = get_vars(sol)
    # u = sol[pend.x+pend.y]
    sum(reduce(vcat, u))
end

which points to this branch

VA = recursivecopy(VA)
and
VA = recursivecopy(VA)

Member Author

ȳ was always a NonlinearSolution when I was testing.

@jClugstor jClugstor May 13, 2025

Member Author

Sorry, just now getting back to this.
To track .prob, would it be like this:

function ChainRulesCore.rrule(
        ::Type{<:SciMLBase.NonlinearSolution{
            T, N, uType, R, P, A, O, uType2, S, Tr}}, u, resid, prob,
        args...) where {T, N, uType, R, P, A, O, uType2, S, Tr}
    function NonlinearSolutionAdjoint(ȳ)
        (NoTangent(), ȳ.u, NoTangent(), ȳ.prob, ntuple(_ -> NoTangent(), length(args))...)
    end
    SciMLBase.NonlinearSolution{T, N, uType, R, P, A, O, uType2, S, Tr}(u, resid, prob, args...),
    NonlinearSolutionAdjoint
end

?

Member

Yes this looks reasonable
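
The shape of that pullback can be tried out in isolation. The following is a minimal sketch, not the PR's code: `ToySolution` is a hypothetical three-field stand-in for `NonlinearSolution`, and the cotangent passed in is made up for illustration. It only shows that the pullback returns one entry per constructor input, in order: the type itself, then `u`, `resid`, and `prob`.

```julia
using ChainRulesCore

# Hypothetical stand-in for NonlinearSolution; only the field order matters here.
struct ToySolution{U, R, P}
    u::U
    resid::R
    prob::P
end

function ChainRulesCore.rrule(::Type{<:ToySolution}, u, resid, prob)
    function ToySolutionAdjoint(ȳ)
        # One cotangent per primal input: the type, u, resid, prob.
        (NoTangent(), ȳ.u, NoTangent(), ȳ.prob)
    end
    ToySolution(u, resid, prob), ToySolutionAdjoint
end

# Forward pass through the rule, then pull back a made-up structural cotangent.
sol, pullback = rrule(ToySolution, [1.0, 2.0], [0.0, 0.0], (p = 3.0,))
ȳ = Tangent{Any}(u = [1.0, 1.0], prob = Tangent{Any}(p = 0.5))
∂self, ∂u, ∂resid, ∂prob = pullback(ȳ)
```

Here `∂u` and `∂prob` are simply the matching fields of the incoming cotangent, which is why the rule only works if `ȳ` supports `.u` and `.prob` property access.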

Comment thread ext/SciMLBaseChainRulesCoreExt.jl Outdated
T, N, uType, R, P, A, O, uType2, S, Tr}}, u, resid, prob,
args...) where {T, N, uType, R, P, A, O, uType2, S, Tr}
function NonlinearSolutionAdjoint(ȳ)
(NoTangent(), ȳ.u, NoTangent(), ȳ.prob, ntuple(_ -> NoTangent(), length(args))...)

Member

As a follow-up, is the type of ȳ still a solution type with the latest SciMLSensitivity?

Member Author

ȳ is a Tangent type with the latest SciMLSensitivity:

typeof(ȳ) = Tangent{Any}(u = 1.0, resid = ChainRulesCore.ZeroTangent(), prob = ChainRulesCore.ZeroTangent(), alg = ChainRulesCore.ZeroTangent(), retcode = ChainRulesCore.ZeroTangent(), original = ChainRulesCore.ZeroTangent(), left = ChainRulesCore.ZeroTangent(), right = ChainRulesCore.ZeroTangent(), stats = ChainRulesCore.ZeroTangent(), trace = ChainRulesCore.ZeroTangent())
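
A small sketch of why a cotangent of this form is safe to read with `ȳ.u` and `ȳ.prob` in the pullback. The value below is a hypothetical, shortened version of the one printed above, built by hand with ChainRulesCore rather than produced by SciMLSensitivity.

```julia
using ChainRulesCore

# Hypothetical cotangent, shaped like the printed one but with fewer fields.
ȳ = Tangent{Any}(u = 1.0, resid = ZeroTangent(), prob = ZeroTangent())

du = ȳ.u        # the only nonzero entry
dprob = ȳ.prob  # a ZeroTangent, which acts as an additive identity

# Accumulating a ZeroTangent leaves the other gradient unchanged.
accumulated = dprob + 2.5
```

So when no gradient flows into `prob`, returning `ȳ.prob` from the pullback passes a `ZeroTangent` along, and it contributes nothing when accumulated with other gradients.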

@jClugstor jClugstor force-pushed the nonlinearsolution_adjoint branch from 031520d to d948f07 Compare May 14, 2025 15:09
Development

Successfully merging this pull request may close these issues.

IntervalNonlinearProblem fails with Zygote

3 participants