Skip to content
This repository was archived by the owner on May 12, 2026. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "DiffEqBase"
uuid = "2b5f629d-d688-5b77-993f-72d75c75574e"
authors = ["Chris Rackauckas <accounts@chrisrackauckas.com>"]
version = "6.208.0"
version = "6.209.0"

[deps]
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
Expand Down Expand Up @@ -98,7 +98,7 @@ Printf = "1.9"
RecursiveArrayTools = "3.1"
Reexport = "1.0"
ReverseDiff = "1"
SciMLBase = "2.142.0"
SciMLBase = "2.143.0"
SciMLOperators = "1"
SciMLStructures = "1.5"
Setfield = "1"
Expand Down
36 changes: 36 additions & 0 deletions ext/DiffEqBaseForwardDiffExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import SciMLBase: isdualtype, DualEltypeChecker, sse, __sum

const dualT = ForwardDiff.Dual{ForwardDiff.Tag{OrdinaryDiffEqTag, Float64}, Float64, 1}
dualgen(::Type{T}) where {T} = ForwardDiff.Dual{ForwardDiff.Tag{OrdinaryDiffEqTag, T}, T, 1}
dualgen(::Type{T}, ::Val{CS}) where {T, CS} = ForwardDiff.Dual{ForwardDiff.Tag{OrdinaryDiffEqTag, T}, T, CS}

const NORECOMPILE_IIP_SUPPORTED_ARGS = (
Tuple{
Expand Down Expand Up @@ -85,6 +86,41 @@ function wrapfun_iip(
return FunctionWrappersWrappers.FunctionWrappersWrapper{typeof(fwt), false}(fwt)
end

# 3-arg version: compile FunctionWrapper variants with the specified chunk size.
# Uses chunk=CS for u-related duals (Jacobian computation) and chunk=1 for
# t-related duals (time derivative is always scalar, so chunk=1).
function wrapfun_iip(
ff,
inputs::Tuple{T1, T2, T3, T4},
::Val{CS}
) where {T1, T2, T3, T4, CS}
T = eltype(T2)

# Jacobian (u-derivative) uses chunk=CS
dualT_jac = dualgen(T, Val(CS))
dualT1_jac = ArrayInterface.promote_eltype(T1, dualT_jac)
dualT2_jac = ArrayInterface.promote_eltype(T2, dualT_jac)

# Time derivative uses chunk=1 (scalar differentiation w.r.t. t)
dualT_time = dualgen(T)
dualT1_time = ArrayInterface.promote_eltype(T1, dualT_time)
dualT4_time = dualgen(promote_type(T, T4))

iip_arglists = (
Tuple{T1, T2, T3, T4}, # plain
Tuple{dualT1_jac, dualT2_jac, T3, T4}, # Jacobian (u dual, chunk=CS)
Tuple{dualT1_time, T2, T3, dualT4_time}, # time derivative (chunk=1)
Tuple{dualT1_jac, dualT2_jac, T3, dualT4_time}, # both
)

iip_returnlists = ntuple(x -> Nothing, 4)

fwt = map(iip_arglists, iip_returnlists) do A, R
FunctionWrappersWrappers.FunctionWrappers.FunctionWrapper{R, A}(Void(ff))
end
return FunctionWrappersWrappers.FunctionWrappersWrapper{typeof(fwt), false}(fwt)
end

const iip_arglists_default = (
Tuple{
Vector{Float64}, Vector{Float64}, Vector{Float64},
Expand Down
3 changes: 3 additions & 0 deletions src/norecompile.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ function wrapfun_iip(ff, inputs)
)
end

# 3-arg fallback: when ForwardDiff extension is not loaded, ignore chunk size
wrapfun_iip(ff, inputs, ::Val) = wrapfun_iip(ff, inputs)

function wrapfun_oop(ff, inputs)
return FunctionWrappersWrappers.FunctionWrappersWrapper(
ff, (typeof(inputs),), (typeof(inputs[1]),)
Expand Down
37 changes: 29 additions & 8 deletions src/solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -675,7 +675,8 @@ function get_concrete_problem(prob, isadapt; alg = nothing, kwargs...)
tspan_promote = promote_tspan(u0_promote, p, tspan, prob, kwargs)
f_promote = promote_f(
prob.f, Val(SciMLBase.specialization(prob.f)), u0_promote, p,
tspan_promote[1], Val(_uses_forwarddiff(alg))
tspan_promote[1], Val(_uses_forwarddiff(alg)),
_forwarddiff_chunksize(alg)
)
if isconcreteu0(prob, tspan[1], kwargs) && prob.u0 === u0 &&
typeof(u0_promote) === typeof(prob.u0) &&
Expand Down Expand Up @@ -704,7 +705,8 @@ function get_concrete_problem(prob::DAEProblem, isadapt; alg = nothing, kwargs..

f_promote = promote_f(
prob.f, Val(SciMLBase.specialization(prob.f)), u0_promote, p,
tspan_promote[1], Val(_uses_forwarddiff(alg))
tspan_promote[1], Val(_uses_forwarddiff(alg)),
_forwarddiff_chunksize(alg)
)
if isconcreteu0(prob, tspan[1], kwargs) && typeof(u0_promote) === typeof(prob.u0) &&
isconcretedu0(prob, tspan[1], kwargs) && typeof(du0_promote) === typeof(prob.du0) &&
Expand Down Expand Up @@ -752,6 +754,13 @@ function _promote_tspan(tspan, kwargs)
end
end

# Helper to get the effective ForwardDiff chunk size from the algorithm.
# Returns Val{CS}(). Defaults to Val(1) when algorithm is not known or unspecified.
_forwarddiff_chunksize(::Nothing) = Val(1)
_forwarddiff_chunksize(alg) = _resolve_chunksize(SciMLBase.forwarddiff_chunksize(alg))
_resolve_chunksize(::Val{0}) = Val(1)
_resolve_chunksize(v::Val) = v

# Helper to determine if we need ForwardDiff-aware function wrapping.
# Default to true (full wrapping) when algorithm is not known.
_uses_forwarddiff(::Nothing) = true
Expand All @@ -772,7 +781,10 @@ end
# Full path for algorithms that use ForwardDiff internally (e.g. Rosenbrock).
# These algorithms precompile AFTER the ForwardDiff extension loads, so
# backedges to hasdualpromote/wrapfun_iip don't cause invalidation issues.
function promote_f(f::F, ::Val{specialize}, u0, p, t, ::Val{true}) where {F, specialize}
function promote_f(
f::F, ::Val{specialize}, u0, p, t, ::Val{true},
::Val{CS} = Val(1)
) where {F, specialize, CS}
uElType = u0 === nothing ? Float64 : eltype(u0)
if isdefined(f, :jac_prototype) && f.jac_prototype isa AbstractArray
f = @set f.jac_prototype = similar(f.jac_prototype, uElType)
Expand Down Expand Up @@ -802,10 +814,10 @@ function promote_f(f::F, ::Val{specialize}, u0, p, t, ::Val{true}) where {F, spe
if f.jac !== nothing && !(f.jac isa FunctionWrappersWrappers.FunctionWrappersWrapper)
n = length(u0)
J_proto = f.jac_prototype !== nothing ? similar(f.jac_prototype, uElType) :
zeros(uElType, n, n)
zeros(uElType, n, n)
f = @set f.jac = wrapfun_jac_iip(f.jac, (J_proto, u0, p, t))
end
return unwrapped_f(f, wrapfun_iip(f.f, (u0, u0, p, t)))
return unwrapped_f(f, wrapfun_iip(f.f, (u0, u0, p, t), Val(CS)))
else
return f
end
Expand All @@ -815,7 +827,10 @@ end
# Avoids calling hasdualpromote/wrapfun_iip which have extension overrides in
# DiffEqBaseForwardDiffExt that would create invalidating method table backedges.
# Uses a simple single-signature FunctionWrapper instead.
function promote_f(f::F, ::Val{specialize}, u0, p, t, ::Val{false}) where {F, specialize}
function promote_f(
f::F, ::Val{specialize}, u0, p, t, ::Val{false},
::Val{CS} = Val(1)
) where {F, specialize, CS}
uElType = u0 === nothing ? Float64 : eltype(u0)
if isdefined(f, :jac_prototype) && f.jac_prototype isa AbstractArray
f = @set f.jac_prototype = similar(f.jac_prototype, uElType)
Expand Down Expand Up @@ -849,15 +864,21 @@ end

hasdualpromote(u0, t) = true

function promote_f(f::SplitFunction, ::Val{specialize}, u0, p, t, ::Val{true}) where {specialize}
function promote_f(
f::SplitFunction, ::Val{specialize}, u0, p, t, ::Val{true},
::Val{CS} = Val(1)
) where {specialize, CS}
return if isnothing(f._func_cache)
f
else
# Copy the cache to ensure it's properly initialized
remake(f, _func_cache = copy(f._func_cache))
end
end
function promote_f(f::SplitFunction, ::Val{specialize}, u0, p, t, ::Val{false}) where {specialize}
function promote_f(
f::SplitFunction, ::Val{specialize}, u0, p, t, ::Val{false},
::Val{CS} = Val(1)
) where {specialize, CS}
return if isnothing(f._func_cache)
f
else
Expand Down
Loading