Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 37 additions & 9 deletions lib/DiffEqBase/ext/DiffEqBaseForwardDiffExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,22 +82,49 @@ function _make_fww(
}(fwt, cs)
end

# Replace the eltype of a Vector with `D`, *bypassing* ForwardDiff's
# tag-precedence `promote_type`. That precedence is a `@generated tagcount`
# that bakes a first-compile literal, so `promote_type(Dual{NL}, Dual{OrdEq,
# Dual{NL}, CS})` can invert the nesting (producing `Dual{NL, Dual{OrdEq,
# …}, 2}` rather than `Dual{OrdEq, Dual{NL}, CS}`) depending on which package
# precompiled which tag first. For the Jacobian-case FunctionWrapper slots we
# need a deterministic nested-dual type that matches the seeded `D`, so we
# force it here.
function _dualify_eltype(::Type{T}, ::Type{D}) where {T, D}
T <: AbstractArray || return T
eltype(T) <: ForwardDiff.Dual && T <: Vector && return Vector{D}
return ArrayInterface.promote_eltype(T, D)
end

# When `p` carries a Dual (e.g. NonlinearSolve drives a Jacobian through an
# ODE solve), the Jacobian-case FunctionWrapper signatures need to accept a
# `p` already promoted to the inner nested Dual so that the downstream widen
# step in `OrdinaryDiffEqDifferentiation.jacobian!` has a slot to dispatch to.
# The compiled body then multiplies `u*p` within a single tag hierarchy.
# Non-array `p` (NullParameters, scalars, tuples) stays as-is.
function _promote_p_sig(::Type{T3}, ::Type{DualT}) where {T3, DualT}
T3 <: AbstractArray || return T3
SciMLBase.anyeltypedual(T3) <: ForwardDiff.Dual || return T3
return _dualify_eltype(T3, DualT)
end

function wrapfun_iip(
ff,
inputs::Tuple{T1, T2, T3, T4}
) where {T1, T2, T3, T4}
T = eltype(T2)
dualT = dualgen(T)
dualT1 = ArrayInterface.promote_eltype(T1, dualT)
dualT2 = ArrayInterface.promote_eltype(T2, dualT)
dualT1 = _dualify_eltype(T1, dualT)
dualT2 = _dualify_eltype(T2, dualT)
dualT3 = _promote_p_sig(T3, dualT)
dualT4 = dualgen(promote_type(T, T4))

return _make_fww(
Void(ff),
Tuple{T1, T2, T3, T4},
Tuple{dualT1, dualT2, T3, T4},
Tuple{dualT1, dualT2, dualT3, T4},
Tuple{dualT1, T2, T3, dualT4},
Tuple{dualT1, dualT2, T3, dualT4}
Tuple{dualT1, dualT2, dualT3, dualT4}
)
end

Expand All @@ -113,20 +140,21 @@ function wrapfun_iip(

# Jacobian (u-derivative) uses chunk=CS
dualT_jac = dualgen(T, Val(CS))
dualT1_jac = ArrayInterface.promote_eltype(T1, dualT_jac)
dualT2_jac = ArrayInterface.promote_eltype(T2, dualT_jac)
dualT1_jac = _dualify_eltype(T1, dualT_jac)
dualT2_jac = _dualify_eltype(T2, dualT_jac)
dualT3_jac = _promote_p_sig(T3, dualT_jac)

# Time derivative uses chunk=1 (scalar differentiation w.r.t. t)
dualT_time = dualgen(T)
dualT1_time = ArrayInterface.promote_eltype(T1, dualT_time)
dualT1_time = _dualify_eltype(T1, dualT_time)
dualT4_time = dualgen(promote_type(T, T4))

return _make_fww(
Void(ff),
Tuple{T1, T2, T3, T4},
Tuple{dualT1_jac, dualT2_jac, T3, T4},
Tuple{dualT1_jac, dualT2_jac, dualT3_jac, T4},
Tuple{dualT1_time, T2, T3, dualT4_time},
Tuple{dualT1_jac, dualT2_jac, T3, dualT4_time}
Tuple{dualT1_jac, dualT2_jac, dualT3_jac, dualT4_time}
)
end

Expand Down
55 changes: 53 additions & 2 deletions lib/OrdinaryDiffEqDifferentiation/src/derivative_wrappers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,55 @@ function jacobian(f::F, x, integrator) where {F}
return jac
end

# Inner-Dual eltype that the prepared ForwardDiff JacobianConfig allocates for
# its xdual buffer. Returns `nothing` for non-ForwardDiff preps (e.g.
# AutoFiniteDiff), so the widening fast-path below is a single type-stable
# no-op dispatch in that case.
function _jac_prep_inner_dual_eltype(prep)
hasfield(typeof(prep), :config) || return nothing
cfg = getfield(prep, :config)
cfg isa ForwardDiff.JacobianConfig || return nothing
duals = cfg.duals
duals isa Tuple && length(duals) >= 2 || return nothing
return eltype(duals[2])
end

# When the integrator's stored `p` (held in `f::UJacobianWrapper`) is a
# `Vector{<:Dual}` because we are *inside* an outer ForwardDiff Jacobian /
# gradient, the inner Rosenbrock Jacobian widens `u` into a deeper nested-Dual
# type via the prepared `JacobianConfig`. If the user-facing function is a
# plain Julia function (the `FullSpecialize`/`NoSpecialize` case), then the
# subsequent `p[i] * u[i]` inside the user body multiplies values at two
# different Dual nesting levels, which dispatches through ForwardDiff's
# `tagcount`-based tag precedence. That precedence is a `@generated` function
# whose literal value is baked at first compile and depends on which package
# precompiled which tag type first, so the result type can come out wrong and
# crash inside `setindex!(du, ...)` with `Float64(::nested_dual)`.
#
# Lift `p` into the inner nested-Dual type once, *before* delegating to
# `DI.jacobian!`, so the user body never multiplies across tag levels. The
# widened `p` carries zero inner partials (correct — `p` does not depend on
# `u`). The per-step `convert.(inner_T, p)` allocation is amortized across
# every chunk evaluation DI performs inside one `jacobian!` call.
#
# This handles both `FullSpecialize` (direct user function) and
# `AutoSpecialize` (FunctionWrappersWrapper). For the latter, DiffEqBase's
# `wrapfun_iip` compiles Jacobian-case signatures with the promoted `p` type,
# so FWW dispatches to the nested-`p` slot whose compiled body multiplies
# `u*p` within a single tag hierarchy.
_widen_uf_p_for_jac(f, prep) = f
function _widen_uf_p_for_jac(f::UJacobianWrapper, prep)
inner_T = _jac_prep_inner_dual_eltype(prep)
inner_T === nothing && return f
p = f.p
p isa AbstractArray || return f
Tp = eltype(p)
Tp <: ForwardDiff.Dual || return f
Tp === inner_T && return f
inner_T <: ForwardDiff.Dual || return f
return @set f.p = convert.(inner_T, p)
end

function jacobian!(
J::AbstractMatrix{<:Number}, f::F, x::AbstractArray{<:Number},
fx::AbstractArray{<:Number}, integrator::SciMLBase.DEIntegrator,
Expand Down Expand Up @@ -240,14 +289,16 @@ function jacobian!(
config = jac_config[1]
end

f_eff = _widen_uf_p_for_jac(f, config)

if integrator.iter == 1
try
DI.jacobian!(f, fx, J, config, gpu_safe_autodiff(alg_autodiff(alg), x), x)
DI.jacobian!(f_eff, fx, J, config, gpu_safe_autodiff(alg_autodiff(alg), x), x)
catch e
throw(FirstAutodiffJacError(e))
end
else
DI.jacobian!(f, fx, J, config, gpu_safe_autodiff(alg_autodiff(alg), x), x)
DI.jacobian!(f_eff, fx, J, config, gpu_safe_autodiff(alg_autodiff(alg), x), x)
end

return nothing
Expand Down
109 changes: 109 additions & 0 deletions lib/OrdinaryDiffEqDifferentiation/test/nested_forwarddiff_tests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
using Test
using OrdinaryDiffEqRosenbrock
using OrdinaryDiffEqDifferentiation
using SciMLBase
using ADTypes
using ForwardDiff

# Regression test for nested ForwardDiff over an ODE solve
# (https://github.com/SciML/OrdinaryDiffEq.jl/issues/3381).
#
# When a Rosenbrock solver is invoked with a `Vector{<:Dual}` `p` (i.e. we are
# inside an *outer* ForwardDiff layer), the inner Rosenbrock Jacobian widens
# `u` into a deeper nested-Dual type via its `JacobianConfig`. The user body
# `f(du, u, p, t)` then multiplies `p[i] * u[i]` across two different Dual
# nesting levels. ForwardDiff's cross-tag multiplication goes through a
# `@generated tagcount` precedence whose literal value is baked at first
# compile and depends on precompile ordering; that ordering can invert the
# nesting and produce a triple-nested `Dual` that crashes the eventual
# `setindex!(du, ...)` with `Float64(::nested_dual)`.
#
# The fix (`_widen_uf_p_for_jac` in derivative_wrappers.jl) lifts `uf.p` into
# the inner nested-Dual type ahead of `DI.jacobian!`, so the user body never
# multiplies across tag levels.
@testset "Nested ForwardDiff through Rosenbrock (issue #3381)" begin
function ode!(du, u, p, t)
du[1] = -p[1] * u[1]
du[2] = -u[1] - p[2] * u[2]
return nothing
end
ode_f = ODEFunction{true, SciMLBase.FullSpecialize}(ode!)

outer_f = function (p)
prob = ODEProblem(ode_f, [1.0, 1.0], (0.0, 1.0), p)
sol = solve(
prob, Rosenbrock23(autodiff = AutoForwardDiff(chunksize = 2));
reltol = 1.0e-8, abstol = 1.0e-8
)
return sol.u[end]
end

J = ForwardDiff.jacobian(outer_f, [1.5, 2.0])
@test size(J) == (2, 2)
@test all(isfinite, J)

# Nested case with a hand-rolled outer tag — mimics NonlinearSolve's
# Tag{NonlinearSolveTag, Float64} wrapping `p` while Rosenbrock seeds `u`
# under Tag{OrdinaryDiffEqTag, …}.
T = ForwardDiff.Tag{:NestedForwardDiffOuter, Float64}
p_dual = ForwardDiff.Dual{T, Float64, 2}[
ForwardDiff.Dual{T}(1.5, ForwardDiff.Partials{2, Float64}((1.0, 0.0))),
ForwardDiff.Dual{T}(2.0, ForwardDiff.Partials{2, Float64}((0.0, 1.0))),
]
prob = ODEProblem(ode_f, [1.0, 1.0], (0.0, 1.0), p_dual)
sol = solve(
prob, Rosenbrock23(autodiff = AutoForwardDiff(chunksize = 2));
reltol = 1.0e-8, abstol = 1.0e-8
)
@test SciMLBase.successful_retcode(sol)
@test all(u -> all(isfinite, u), sol.u)

# AutoSpecialize wraps `ode!` in a FunctionWrappersWrapper. DiffEqBase's
# `wrapfun_iip` compiles a Jacobian-case slot with the promoted `p` type;
# the widen step in `jacobian!` then dispatches into that slot whose body
# multiplies `u*p` within one tag hierarchy.
ode_f_auto = ODEFunction{true, SciMLBase.AutoSpecialize}(ode!)
prob_auto = ODEProblem(ode_f_auto, [1.0, 1.0], (0.0, 1.0), p_dual)
sol_auto = solve(
prob_auto, Rosenbrock23(autodiff = AutoForwardDiff(chunksize = 2));
reltol = 1.0e-8, abstol = 1.0e-8
)
@test SciMLBase.successful_retcode(sol_auto)
@test all(u -> all(isfinite, u), sol_auto.u)
end

# Regression test for the cache-time pre-widen optimization. The Rosenbrock
# alg_cache builders call `_widen_uf_p_for_jac(uf, jac_config[1])` once after
# `build_jac_config`, so the integrator's cached `uf` already has `p` lifted
# to the inner nested-Dual eltype. Subsequent calls to `_widen_uf_p_for_jac`
# from `jacobian!` then hit the `Tp === inner_T && return f` fast-path and
# allocate zero bytes per Jacobian eval. If a future change drops the
# pre-widen line from a cache builder, the eltype check fails; if the
# fast-path itself is broken, the @allocated check fails.
@testset "Pre-widened UJacobianWrapper.p in Rosenbrock cache (issue #3381)" begin
function ode!(du, u, p, t)
du[1] = -p[1] * u[1]
du[2] = -u[1] - p[2] * u[2]
return nothing
end
ode_f = ODEFunction{true, SciMLBase.FullSpecialize}(ode!)

T = ForwardDiff.Tag{:NestedForwardDiffOuter, Float64}
p_dual = ForwardDiff.Dual{T, Float64, 2}[
ForwardDiff.Dual{T}(1.5, ForwardDiff.Partials{2, Float64}((1.0, 0.0))),
ForwardDiff.Dual{T}(2.0, ForwardDiff.Partials{2, Float64}((0.0, 1.0))),
]
prob = ODEProblem(ode_f, [1.0, 1.0], (0.0, 1.0), p_dual)
integ = init(prob, Rosenbrock23(autodiff = AutoForwardDiff(chunksize = 2)))

uf = integ.cache.uf
cfg = integ.cache.jac_config[1]
inner_T = eltype(cfg.config.duals[2])

# Cache must hold a uf whose p is the inner nested-Dual eltype.
@test eltype(uf.p) === inner_T

# Per-call widen on the cached uf must be a zero-allocation no-op.
OrdinaryDiffEqDifferentiation._widen_uf_p_for_jac(uf, cfg) # warmup
@test (@allocated OrdinaryDiffEqDifferentiation._widen_uf_p_for_jac(uf, cfg)) == 0
end
1 change: 1 addition & 0 deletions lib/OrdinaryDiffEqDifferentiation/test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ if TEST_GROUP ∉ ("QA", "Sparse", "ModelingToolkit")
@time @safetestset "Differentiation Trait Tests" include("differentiation_traits_tests.jl")
@time @safetestset "Autodiff Error Tests" include("autodiff_error_tests.jl")
@time @safetestset "No Jac Tests" include("nojac_tests.jl")
@time @safetestset "Nested ForwardDiff" include("nested_forwarddiff_tests.jl")
end

# Run sparse tests (separate environment due to ComponentArrays dep conflicts)
Expand Down
3 changes: 2 additions & 1 deletion lib/OrdinaryDiffEqRosenbrock/src/OrdinaryDiffEqRosenbrock.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ using OrdinaryDiffEqDifferentiation: TimeDerivativeWrapper, TimeGradientWrapper,
build_jac_config, issuccess_W, jacobian2W!,
resize_jac_config!, resize_grad_config!,
calc_W, calc_rosenbrock_differentiation!, build_J_W,
UJacobianWrapper, dolinsolve, WOperator, resize_J_W!
UJacobianWrapper, dolinsolve, WOperator, resize_J_W!,
_widen_uf_p_for_jac

using OrdinaryDiffEqDifferentiation: calc_rosenbrock_differentiation

Expand Down
4 changes: 4 additions & 0 deletions lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,7 @@ function alg_cache(

grad_config = build_grad_config(alg, f, tf, du1, t)
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2)
uf = _widen_uf_p_for_jac(uf, jac_config[1])

J, W = build_J_W(alg, u, uprev, p, t, dt, f, jac_config, uEltypeNoUnits, Val(true))

Expand Down Expand Up @@ -292,6 +293,7 @@ function alg_cache(

grad_config = build_grad_config(alg, f, tf, du1, t)
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2)
uf = _widen_uf_p_for_jac(uf, jac_config[1])

J, W = build_J_W(alg, u, uprev, p, t, dt, f, jac_config, uEltypeNoUnits, Val(true))

Expand Down Expand Up @@ -535,6 +537,7 @@ function alg_cache(

grad_config = build_grad_config(alg, f, tf, du1, t)
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2)
uf = _widen_uf_p_for_jac(uf, jac_config[1])

J, W = build_J_W(alg, u, uprev, p, t, dt, f, jac_config, uEltypeNoUnits, Val(true))

Expand Down Expand Up @@ -680,6 +683,7 @@ function alg_cache(

grad_config = build_grad_config(alg, f, tf, du1, t)
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2)
uf = _widen_uf_p_for_jac(uf, jac_config[1])
J, W = build_J_W(alg, u, uprev, p, t, dt, f, jac_config, uEltypeNoUnits, Val(true))

linprob = LinearProblem(W, _vec(linsolve_tmp), (nothing, u, p, t); u0 = _vec(tmp))
Expand Down
Loading