diff --git a/ext/DiffEqBaseForwardDiffExt.jl b/ext/DiffEqBaseForwardDiffExt.jl index 25264bf06..4cdb0d9cd 100644 --- a/ext/DiffEqBaseForwardDiffExt.jl +++ b/ext/DiffEqBaseForwardDiffExt.jl @@ -225,4 +225,134 @@ if !hasmethod(nextfloat, Tuple{ForwardDiff.Dual}) end end +import PrecompileTools +PrecompileTools.@compile_workload begin + # Scalar operations on Dual numbers (arithmetic, math functions, comparisons) + d1 = dualT(1.0, ForwardDiff.Partials((0.5,))) + d2 = dualT(2.0, ForwardDiff.Partials((1.0,))) + s = 3.14 + + # Arithmetic: Dual-Dual and Dual-scalar + d1 + d2 + d1 - d2 + d1 * d2 + d1 / d2 + d1 + s + s + d1 + d1 - s + s - d1 + d1 * s + s * d1 + d1 / s + s / d1 + -d1 + abs(d1) + + # Powers and roots + d1^2 + d1^3 + d2^0.5 + sqrt(d2) + cbrt(d2) + + # Transcendental functions + exp(d1) + log(d2) + sin(d1) + cos(d1) + tan(d1) + asin(dualT(0.5, ForwardDiff.Partials((1.0,)))) + acos(dualT(0.5, ForwardDiff.Partials((1.0,)))) + atan(d1) + atan(d1, d2) + sinh(d1) + cosh(d1) + tanh(d1) + + # Comparisons (used in step size control, event detection) + d1 < d2 + d1 > d2 + d1 <= d2 + d1 >= d2 + d1 == d2 + isnan(d1) + isinf(d1) + isfinite(d1) + + # min/max (used in limiters and error control) + min(d1, d2) + max(d1, d2) + min(d1, s) + max(d1, s) + + # Conversion and promotion + zero(dualT) + one(dualT) + float(d1) + ForwardDiff.value(d1) + ForwardDiff.partials(d1) + + # Array operations on Vector{dualT} + v1 = [d1, d2, dualT(0.0, ForwardDiff.Partials((0.0,)))] + v2 = [d2, d1, dualT(1.0, ForwardDiff.Partials((0.1,)))] + + # Basic array ops + v1 + v2 + v1 - v2 + v1 .* v2 + v1 ./ v2 + s .* v1 + v1 .+ s + v1 .- s + v1 .^ 2 + v1 .^ 0.5 + + # In-place array operations + out = similar(v1) + out .= v1 .+ v2 + out .= v1 .- v2 + out .= v1 .* v2 + out .= s .* v1 + out .= v1 .* s .+ v2 + out .= v1 .* s .- v2 .* s + + # Reductions (used in norm calculations, error estimation) + sum(v1) + sum(abs2, v1) + maximum(abs, v1) + + # LinearAlgebra operations + using LinearAlgebra + dot(v1, v2) + norm(v1) + norm(v1, Inf) + norm(v1, 1) + + # copy / fill + copy(v1) + fill!(out, zero(dualT)) + + # SubArray primitive broadcast operations for Float64 and Dual types. + # These are generic building blocks used by any ODE function with views. + # Note: fused multi-operand broadcast expressions (e.g. `dy .= k .* y1 .+ k .* y2 .* y3`) + # create unique nested Broadcasted types per expression and cannot be generically precompiled. + for T in (Float64, dualT) + x = zeros(T, 4) + dx = zeros(T, 4) + sv1 = @view x[1:2] + sv2 = @view x[3:4] + dsv1 = @view dx[1:2] + k = 0.04 + + # Primitive SubArray broadcast operations + dsv1 .= sv1 + dsv1 .= k .* sv1 + dsv1 .= sv1 .* sv2 + dsv1 .= sv1 .+ sv2 + dsv1 .= sv1 .- sv2 + dsv1 .= sv1 .^ 2 + dsv1 .= .-sv1 + end +end + end diff --git a/test/downstream/community_callback_tests.jl b/test/downstream/community_callback_tests.jl index 1985cf059..a8aeb0186 100644 --- a/test/downstream/community_callback_tests.jl +++ b/test/downstream/community_callback_tests.jl @@ -225,7 +225,7 @@ cb = VectorContinuousCallback(cond!, terminate_affect!, nothing, 1) u0 = [0.0, 0.0, 1.0] prob = ODEProblem(f!, u0, (0.0, 10.0); callback = cb) soln = solve(prob, Tsit5()) -@test soln.t[end] ≈ 4.712347213360699 +@test soln.t[end] ≈ 4.712347213360699 atol = 1e-4 odefun = ODEFunction((u, p, t) -> [u[2], u[2] - p]; mass_matrix = [1 0; 0 0]) callback = PresetTimeCallback(0.5, integ -> (integ.p = -integ.p))