diff --git a/Project.toml b/Project.toml index b5c35be18..53c1d8986 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "DiffEqBase" uuid = "2b5f629d-d688-5b77-993f-72d75c75574e" authors = ["Chris Rackauckas "] -version = "6.208.0" +version = "6.209.0" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" @@ -98,7 +98,7 @@ Printf = "1.9" RecursiveArrayTools = "3.1" Reexport = "1.0" ReverseDiff = "1" -SciMLBase = "2.142.0" +SciMLBase = "2.143.0" SciMLOperators = "1" SciMLStructures = "1.5" Setfield = "1" diff --git a/ext/DiffEqBaseForwardDiffExt.jl b/ext/DiffEqBaseForwardDiffExt.jl index 9937d1c04..25264bf06 100644 --- a/ext/DiffEqBaseForwardDiffExt.jl +++ b/ext/DiffEqBaseForwardDiffExt.jl @@ -11,6 +11,7 @@ import SciMLBase: isdualtype, DualEltypeChecker, sse, __sum const dualT = ForwardDiff.Dual{ForwardDiff.Tag{OrdinaryDiffEqTag, Float64}, Float64, 1} dualgen(::Type{T}) where {T} = ForwardDiff.Dual{ForwardDiff.Tag{OrdinaryDiffEqTag, T}, T, 1} +dualgen(::Type{T}, ::Val{CS}) where {T, CS} = ForwardDiff.Dual{ForwardDiff.Tag{OrdinaryDiffEqTag, T}, T, CS} const NORECOMPILE_IIP_SUPPORTED_ARGS = ( Tuple{ @@ -85,6 +86,41 @@ function wrapfun_iip( return FunctionWrappersWrappers.FunctionWrappersWrapper{typeof(fwt), false}(fwt) end +# 3-arg version: compile FunctionWrapper variants with the specified chunk size. +# Uses chunk=CS for u-related duals (Jacobian computation) and chunk=1 for +# t-related duals (time derivative is always scalar, so chunk=1). +function wrapfun_iip( + ff, + inputs::Tuple{T1, T2, T3, T4}, + ::Val{CS} + ) where {T1, T2, T3, T4, CS} + T = eltype(T2) + + # Jacobian (u-derivative) uses chunk=CS + dualT_jac = dualgen(T, Val(CS)) + dualT1_jac = ArrayInterface.promote_eltype(T1, dualT_jac) + dualT2_jac = ArrayInterface.promote_eltype(T2, dualT_jac) + + # Time derivative uses chunk=1 (scalar differentiation w.r.t. t) + dualT_time = dualgen(T) + dualT1_time = ArrayInterface.promote_eltype(T1, dualT_time) + dualT4_time = dualgen(promote_type(T, T4)) + + iip_arglists = ( + Tuple{T1, T2, T3, T4}, # plain + Tuple{dualT1_jac, dualT2_jac, T3, T4}, # Jacobian (u dual, chunk=CS) + Tuple{dualT1_time, T2, T3, dualT4_time}, # time derivative (chunk=1) + Tuple{dualT1_jac, dualT2_jac, T3, dualT4_time}, # both + ) + + iip_returnlists = ntuple(x -> Nothing, 4) + + fwt = map(iip_arglists, iip_returnlists) do A, R + FunctionWrappersWrappers.FunctionWrappers.FunctionWrapper{R, A}(Void(ff)) + end + return FunctionWrappersWrappers.FunctionWrappersWrapper{typeof(fwt), false}(fwt) +end + const iip_arglists_default = ( Tuple{ Vector{Float64}, Vector{Float64}, Vector{Float64}, diff --git a/src/norecompile.jl b/src/norecompile.jl index 615bf3dc8..ccfc86f35 100644 --- a/src/norecompile.jl +++ b/src/norecompile.jl @@ -28,6 +28,9 @@ function wrapfun_iip(ff, inputs) ) end +# 3-arg fallback: when ForwardDiff extension is not loaded, ignore chunk size +wrapfun_iip(ff, inputs, ::Val) = wrapfun_iip(ff, inputs) + function wrapfun_oop(ff, inputs) return FunctionWrappersWrappers.FunctionWrappersWrapper( ff, (typeof(inputs),), (typeof(inputs[1]),) diff --git a/src/solve.jl b/src/solve.jl index fd120bde4..6e3512040 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -675,7 +675,8 @@ function get_concrete_problem(prob, isadapt; alg = nothing, kwargs...) tspan_promote = promote_tspan(u0_promote, p, tspan, prob, kwargs) f_promote = promote_f( prob.f, Val(SciMLBase.specialization(prob.f)), u0_promote, p, - tspan_promote[1], Val(_uses_forwarddiff(alg)) + tspan_promote[1], Val(_uses_forwarddiff(alg)), + _forwarddiff_chunksize(alg) ) if isconcreteu0(prob, tspan[1], kwargs) && prob.u0 === u0 && typeof(u0_promote) === typeof(prob.u0) && @@ -704,7 +705,8 @@ function get_concrete_problem(prob::DAEProblem, isadapt; alg = nothing, kwargs.. f_promote = promote_f( prob.f, Val(SciMLBase.specialization(prob.f)), u0_promote, p, - tspan_promote[1], Val(_uses_forwarddiff(alg)) + tspan_promote[1], Val(_uses_forwarddiff(alg)), + _forwarddiff_chunksize(alg) ) if isconcreteu0(prob, tspan[1], kwargs) && typeof(u0_promote) === typeof(prob.u0) && isconcretedu0(prob, tspan[1], kwargs) && typeof(du0_promote) === typeof(prob.du0) && @@ -752,6 +754,13 @@ function _promote_tspan(tspan, kwargs) end end +# Helper to get the effective ForwardDiff chunk size from the algorithm. +# Returns Val{CS}(). Defaults to Val(1) when algorithm is not known or unspecified. +_forwarddiff_chunksize(::Nothing) = Val(1) +_forwarddiff_chunksize(alg) = _resolve_chunksize(SciMLBase.forwarddiff_chunksize(alg)) +_resolve_chunksize(::Val{0}) = Val(1) +_resolve_chunksize(v::Val) = v + # Helper to determine if we need ForwardDiff-aware function wrapping. # Default to true (full wrapping) when algorithm is not known. _uses_forwarddiff(::Nothing) = true @@ -772,7 +781,10 @@ end # Full path for algorithms that use ForwardDiff internally (e.g. Rosenbrock). # These algorithms precompile AFTER the ForwardDiff extension loads, so # backedges to hasdualpromote/wrapfun_iip don't cause invalidation issues. -function promote_f(f::F, ::Val{specialize}, u0, p, t, ::Val{true}) where {F, specialize} +function promote_f( + f::F, ::Val{specialize}, u0, p, t, ::Val{true}, + ::Val{CS} = Val(1) + ) where {F, specialize, CS} uElType = u0 === nothing ? Float64 : eltype(u0) if isdefined(f, :jac_prototype) && f.jac_prototype isa AbstractArray f = @set f.jac_prototype = similar(f.jac_prototype, uElType) @@ -802,10 +814,10 @@ function promote_f(f::F, ::Val{specialize}, u0, p, t, ::Val{true}) where {F, spe if f.jac !== nothing && !(f.jac isa FunctionWrappersWrappers.FunctionWrappersWrapper) n = length(u0) J_proto = f.jac_prototype !== nothing ? similar(f.jac_prototype, uElType) : - zeros(uElType, n, n) + zeros(uElType, n, n) f = @set f.jac = wrapfun_jac_iip(f.jac, (J_proto, u0, p, t)) end - return unwrapped_f(f, wrapfun_iip(f.f, (u0, u0, p, t))) + return unwrapped_f(f, wrapfun_iip(f.f, (u0, u0, p, t), Val(CS))) else return f end @@ -815,7 +827,10 @@ end # Avoids calling hasdualpromote/wrapfun_iip which have extension overrides in # DiffEqBaseForwardDiffExt that would create invalidating method table backedges. # Uses a simple single-signature FunctionWrapper instead. -function promote_f(f::F, ::Val{specialize}, u0, p, t, ::Val{false}) where {F, specialize} +function promote_f( + f::F, ::Val{specialize}, u0, p, t, ::Val{false}, + ::Val{CS} = Val(1) + ) where {F, specialize, CS} uElType = u0 === nothing ? Float64 : eltype(u0) if isdefined(f, :jac_prototype) && f.jac_prototype isa AbstractArray f = @set f.jac_prototype = similar(f.jac_prototype, uElType) @@ -849,7 +864,10 @@ end hasdualpromote(u0, t) = true -function promote_f(f::SplitFunction, ::Val{specialize}, u0, p, t, ::Val{true}) where {specialize} +function promote_f( + f::SplitFunction, ::Val{specialize}, u0, p, t, ::Val{true}, + ::Val{CS} = Val(1) + ) where {specialize, CS} return if isnothing(f._func_cache) f else @@ -857,7 +875,10 @@ function promote_f(f::SplitFunction, ::Val{specialize}, u0, p, t, ::Val{true}) w remake(f, _func_cache = copy(f._func_cache)) end end -function promote_f(f::SplitFunction, ::Val{specialize}, u0, p, t, ::Val{false}) where {specialize} +function promote_f( + f::SplitFunction, ::Val{specialize}, u0, p, t, ::Val{false}, + ::Val{CS} = Val(1) + ) where {specialize, CS} return if isnothing(f._func_cache) f else