From 33e25dc287ab618698da65b2ee72dead7409ea6e Mon Sep 17 00:00:00 2001 From: ChrisRackauckas-Claude Date: Sun, 22 Feb 2026 04:47:11 -0500 Subject: [PATCH 1/2] Thread ForwardDiff chunk size through FunctionWrapper compilation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Use the new `SciMLBase.forwarddiff_chunksize` trait to compile FunctionWrapper variants with the algorithm's actual chunk size instead of hardcoding chunk=1. Changes: - Add `_forwarddiff_chunksize` helper in solve.jl - Thread `Val(CS)` through `get_concrete_problem` → `promote_f` → `wrapfun_iip` - Add parameterized `dualgen(T, Val(CS))` in ForwardDiff extension - Add 3-arg `wrapfun_iip(ff, inputs, Val(CS))` in ForwardDiff extension - Add 3-arg fallback `wrapfun_iip(ff, inputs, ::Val)` in norecompile.jl This fixes NoFunctionWrapperFoundError when algorithms specify a non-default chunk size (e.g. AutoForwardDiff(chunksize=3)) on mass matrix problems. Requires SciMLBase >= 2.143.0. Co-Authored-By: Chris Rackauckas --- Project.toml | 4 ++-- ext/DiffEqBaseForwardDiffExt.jl | 36 +++++++++++++++++++++++++++++++ src/norecompile.jl | 3 +++ src/solve.jl | 38 ++++++++++++++++++++++++++------- 4 files changed, 71 insertions(+), 10 deletions(-) diff --git a/Project.toml b/Project.toml index b5c35be18..53c1d8986 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "DiffEqBase" uuid = "2b5f629d-d688-5b77-993f-72d75c75574e" authors = ["Chris Rackauckas "] -version = "6.208.0" +version = "6.209.0" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" @@ -98,7 +98,7 @@ Printf = "1.9" RecursiveArrayTools = "3.1" Reexport = "1.0" ReverseDiff = "1" -SciMLBase = "2.142.0" +SciMLBase = "2.143.0" SciMLOperators = "1" SciMLStructures = "1.5" Setfield = "1" diff --git a/ext/DiffEqBaseForwardDiffExt.jl b/ext/DiffEqBaseForwardDiffExt.jl index 9937d1c04..25264bf06 100644 --- a/ext/DiffEqBaseForwardDiffExt.jl +++ b/ext/DiffEqBaseForwardDiffExt.jl @@ -11,6 +11,7 @@ import SciMLBase: isdualtype, DualEltypeChecker, sse, __sum const dualT = ForwardDiff.Dual{ForwardDiff.Tag{OrdinaryDiffEqTag, Float64}, Float64, 1} dualgen(::Type{T}) where {T} = ForwardDiff.Dual{ForwardDiff.Tag{OrdinaryDiffEqTag, T}, T, 1} +dualgen(::Type{T}, ::Val{CS}) where {T, CS} = ForwardDiff.Dual{ForwardDiff.Tag{OrdinaryDiffEqTag, T}, T, CS} const NORECOMPILE_IIP_SUPPORTED_ARGS = ( Tuple{ @@ -85,6 +86,41 @@ function wrapfun_iip( return FunctionWrappersWrappers.FunctionWrappersWrapper{typeof(fwt), false}(fwt) end +# 3-arg version: compile FunctionWrapper variants with the specified chunk size. +# Uses chunk=CS for u-related duals (Jacobian computation) and chunk=1 for +# t-related duals (time derivative is always scalar, so chunk=1). +function wrapfun_iip( + ff, + inputs::Tuple{T1, T2, T3, T4}, + ::Val{CS} + ) where {T1, T2, T3, T4, CS} + T = eltype(T2) + + # Jacobian (u-derivative) uses chunk=CS + dualT_jac = dualgen(T, Val(CS)) + dualT1_jac = ArrayInterface.promote_eltype(T1, dualT_jac) + dualT2_jac = ArrayInterface.promote_eltype(T2, dualT_jac) + + # Time derivative uses chunk=1 (scalar differentiation w.r.t. t) + dualT_time = dualgen(T) + dualT1_time = ArrayInterface.promote_eltype(T1, dualT_time) + dualT4_time = dualgen(promote_type(T, T4)) + + iip_arglists = ( + Tuple{T1, T2, T3, T4}, # plain + Tuple{dualT1_jac, dualT2_jac, T3, T4}, # Jacobian (u dual, chunk=CS) + Tuple{dualT1_time, T2, T3, dualT4_time}, # time derivative (chunk=1) + Tuple{dualT1_jac, dualT2_jac, T3, dualT4_time}, # both + ) + + iip_returnlists = ntuple(x -> Nothing, 4) + + fwt = map(iip_arglists, iip_returnlists) do A, R + FunctionWrappersWrappers.FunctionWrappers.FunctionWrapper{R, A}(Void(ff)) + end + return FunctionWrappersWrappers.FunctionWrappersWrapper{typeof(fwt), false}(fwt) +end + const iip_arglists_default = ( Tuple{ Vector{Float64}, Vector{Float64}, Vector{Float64}, diff --git a/src/norecompile.jl b/src/norecompile.jl index 615bf3dc8..ccfc86f35 100644 --- a/src/norecompile.jl +++ b/src/norecompile.jl @@ -28,6 +28,9 @@ function wrapfun_iip(ff, inputs) ) end +# 3-arg fallback: when ForwardDiff extension is not loaded, ignore chunk size +wrapfun_iip(ff, inputs, ::Val) = wrapfun_iip(ff, inputs) + function wrapfun_oop(ff, inputs) return FunctionWrappersWrappers.FunctionWrappersWrapper( ff, (typeof(inputs),), (typeof(inputs[1]),) diff --git a/src/solve.jl b/src/solve.jl index fd120bde4..0c0d0d382 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -675,7 +675,8 @@ function get_concrete_problem(prob, isadapt; alg = nothing, kwargs...) tspan_promote = promote_tspan(u0_promote, p, tspan, prob, kwargs) f_promote = promote_f( prob.f, Val(SciMLBase.specialization(prob.f)), u0_promote, p, - tspan_promote[1], Val(_uses_forwarddiff(alg)) + tspan_promote[1], Val(_uses_forwarddiff(alg)), + Val(_forwarddiff_chunksize(alg)) ) if isconcreteu0(prob, tspan[1], kwargs) && prob.u0 === u0 && typeof(u0_promote) === typeof(prob.u0) && @@ -704,7 +705,8 @@ function get_concrete_problem(prob::DAEProblem, isadapt; alg = nothing, kwargs.. f_promote = promote_f( prob.f, Val(SciMLBase.specialization(prob.f)), u0_promote, p, - tspan_promote[1], Val(_uses_forwarddiff(alg)) + tspan_promote[1], Val(_uses_forwarddiff(alg)), + Val(_forwarddiff_chunksize(alg)) ) if isconcreteu0(prob, tspan[1], kwargs) && typeof(u0_promote) === typeof(prob.u0) && isconcretedu0(prob, tspan[1], kwargs) && typeof(du0_promote) === typeof(prob.du0) && @@ -752,6 +754,14 @@ function _promote_tspan(tspan, kwargs) end end +# Helper to get the effective ForwardDiff chunk size from the algorithm. +# Defaults to 1 when algorithm is not known or doesn't specify. +_forwarddiff_chunksize(::Nothing) = 1 +function _forwarddiff_chunksize(alg) + cs = SciMLBase.forwarddiff_chunksize(alg) + return (cs === nothing || cs <= 0) ? 1 : cs +end + # Helper to determine if we need ForwardDiff-aware function wrapping. # Default to true (full wrapping) when algorithm is not known. _uses_forwarddiff(::Nothing) = true @@ -772,7 +782,10 @@ end # Full path for algorithms that use ForwardDiff internally (e.g. Rosenbrock). # These algorithms precompile AFTER the ForwardDiff extension loads, so # backedges to hasdualpromote/wrapfun_iip don't cause invalidation issues. -function promote_f(f::F, ::Val{specialize}, u0, p, t, ::Val{true}) where {F, specialize} +function promote_f( + f::F, ::Val{specialize}, u0, p, t, ::Val{true}, + ::Val{CS} = Val(1) + ) where {F, specialize, CS} uElType = u0 === nothing ? Float64 : eltype(u0) if isdefined(f, :jac_prototype) && f.jac_prototype isa AbstractArray f = @set f.jac_prototype = similar(f.jac_prototype, uElType) @@ -802,10 +815,10 @@ function promote_f(f::F, ::Val{specialize}, u0, p, t, ::Val{true}) where {F, spe if f.jac !== nothing && !(f.jac isa FunctionWrappersWrappers.FunctionWrappersWrapper) n = length(u0) J_proto = f.jac_prototype !== nothing ? similar(f.jac_prototype, uElType) : - zeros(uElType, n, n) + zeros(uElType, n, n) f = @set f.jac = wrapfun_jac_iip(f.jac, (J_proto, u0, p, t)) end - return unwrapped_f(f, wrapfun_iip(f.f, (u0, u0, p, t))) + return unwrapped_f(f, wrapfun_iip(f.f, (u0, u0, p, t), Val(CS))) else return f end @@ -815,7 +828,10 @@ end # Avoids calling hasdualpromote/wrapfun_iip which have extension overrides in # DiffEqBaseForwardDiffExt that would create invalidating method table backedges. # Uses a simple single-signature FunctionWrapper instead. -function promote_f(f::F, ::Val{specialize}, u0, p, t, ::Val{false}) where {F, specialize} +function promote_f( + f::F, ::Val{specialize}, u0, p, t, ::Val{false}, + ::Val{CS} = Val(1) + ) where {F, specialize, CS} uElType = u0 === nothing ? Float64 : eltype(u0) if isdefined(f, :jac_prototype) && f.jac_prototype isa AbstractArray f = @set f.jac_prototype = similar(f.jac_prototype, uElType) @@ -849,7 +865,10 @@ end hasdualpromote(u0, t) = true -function promote_f(f::SplitFunction, ::Val{specialize}, u0, p, t, ::Val{true}) where {specialize} +function promote_f( + f::SplitFunction, ::Val{specialize}, u0, p, t, ::Val{true}, + ::Val{CS} = Val(1) + ) where {specialize, CS} return if isnothing(f._func_cache) f else @@ -857,7 +876,10 @@ function promote_f(f::SplitFunction, ::Val{specialize}, u0, p, t, ::Val{true}) w remake(f, _func_cache = copy(f._func_cache)) end end -function promote_f(f::SplitFunction, ::Val{specialize}, u0, p, t, ::Val{false}) where {specialize} +function promote_f( + f::SplitFunction, ::Val{specialize}, u0, p, t, ::Val{false}, + ::Val{CS} = Val(1) + ) where {specialize, CS} return if isnothing(f._func_cache) f else From 50cb7642ac3fe354df3fcd130c845ef5b6a2b0c3 Mon Sep 17 00:00:00 2001 From: ChrisRackauckas-Claude Date: Sun, 22 Feb 2026 08:03:00 -0500 Subject: [PATCH 2/2] Dispatch on Val from forwarddiff_chunksize trait instead of runtime branching Now that SciMLBase.forwarddiff_chunksize returns Val{N}(), dispatch on it directly via _resolve_chunksize(::Val{0}) = Val(1) instead of runtime Int comparison. Removes Val() wrapping at call sites since _forwarddiff_chunksize already returns Val. Co-Authored-By: Chris Rackauckas --- src/solve.jl | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/src/solve.jl b/src/solve.jl index 0c0d0d382..6e3512040 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -676,7 +676,7 @@ function get_concrete_problem(prob, isadapt; alg = nothing, kwargs...) f_promote = promote_f( prob.f, Val(SciMLBase.specialization(prob.f)), u0_promote, p, tspan_promote[1], Val(_uses_forwarddiff(alg)), - Val(_forwarddiff_chunksize(alg)) + _forwarddiff_chunksize(alg) ) if isconcreteu0(prob, tspan[1], kwargs) && prob.u0 === u0 && typeof(u0_promote) === typeof(prob.u0) && @@ -706,7 +706,7 @@ function get_concrete_problem(prob::DAEProblem, isadapt; alg = nothing, kwargs.. f_promote = promote_f( prob.f, Val(SciMLBase.specialization(prob.f)), u0_promote, p, tspan_promote[1], Val(_uses_forwarddiff(alg)), - Val(_forwarddiff_chunksize(alg)) + _forwarddiff_chunksize(alg) ) if isconcreteu0(prob, tspan[1], kwargs) && typeof(u0_promote) === typeof(prob.u0) && isconcretedu0(prob, tspan[1], kwargs) && typeof(du0_promote) === typeof(prob.du0) && @@ -755,12 +755,11 @@ function _promote_tspan(tspan, kwargs) end # Helper to get the effective ForwardDiff chunk size from the algorithm. -# Defaults to 1 when algorithm is not known or doesn't specify. -_forwarddiff_chunksize(::Nothing) = 1 -function _forwarddiff_chunksize(alg) - cs = SciMLBase.forwarddiff_chunksize(alg) - return (cs === nothing || cs <= 0) ? 1 : cs -end +# Returns Val{CS}(). Defaults to Val(1) when algorithm is not known or unspecified. +_forwarddiff_chunksize(::Nothing) = Val(1) +_forwarddiff_chunksize(alg) = _resolve_chunksize(SciMLBase.forwarddiff_chunksize(alg)) +_resolve_chunksize(::Val{0}) = Val(1) +_resolve_chunksize(v::Val) = v # Helper to determine if we need ForwardDiff-aware function wrapping. # Default to true (full wrapping) when algorithm is not known.