diff --git a/src/callbacks.jl b/src/callbacks.jl index cb6285c9e..47776873e 100644 --- a/src/callbacks.jl +++ b/src/callbacks.jl @@ -398,6 +398,15 @@ function find_root(f, tup, rootfind::SciMLBase.RootfindOpt) IntervalNonlinearProblem{false}(f, tup), ModAB(), abstol = 0.0, reltol = 0.0 ) + if is_inverted_root_pair(sol, f, tup) + # "Inverted" root pair (#1290); direction of integration flips the bracket side + return if (sol.resid > 0) ⊻ (tup[1] > tup[2]) + find_root(f, (tup[1], sol.left), rootfind) + else + find_root(f, (sol.right, tup[2]), rootfind) + end + end + if rootfind == SciMLBase.LeftRootFind return sol.left else @@ -405,6 +414,29 @@ function find_root(f, tup, rootfind::SciMLBase.RootfindOpt) end end +""" +Determine if the root pair is "inverted" — i.e. if the final root bracket, when evaluated on +the condition function, has the opposite signs as the initial bracket. This can occur due to +numerical floating point noise. +""" +function is_inverted_root_pair(sol, f, tup) + + # Fast path (f(t) == 0). Alternative implementation: check if + # sol.retcode ∈ (ReturnCode.ExactSolutionLeft, ReturnCode.ExactSolutionRight) + iszero(sol.resid) && return false + + # Should be equal to sol.resid under current implementation of ModAB, but this is + # more robust against implementation changes and it also works around ModAB#860 + most_positive_residual = f(max(sol.left, sol.right)) + + # Under current implementation of ModAB, sol.resid = f(max(sol.left, sol.right)) + # Therefore, the residual should have the same sign as the condition function evaluated + # at maximum(tup); otherwise, the root pair is inverted. + # @show sol.resid, f(maximum(tup)) + return sign(most_positive_residual) != sign(f(maximum(tup))) +end + + """ findall_events!(next_sign,affect!,affect_neg!,prev_sign) diff --git a/test/callbacks.jl b/test/callbacks.jl index c409c7362..3b7aefecc 100644 --- a/test/callbacks.jl +++ b/test/callbacks.jl @@ -1,4 +1,5 @@ -using DiffEqBase, Test +using DiffEqBase, SciMLBase, Test +using BracketingNonlinearSolve: ModAB condition = function (u, t, integrator) # Event when event_f(u,t,k) == 0 return t - 2.95 @@ -134,3 +135,35 @@ test_find_first_callback(callbacks, find_first_integrator); @test irrational_f(after) < 0.0 @test nextfloat(after) == before end + +# https://github.com/SciML/DiffEqBase.jl/issues/1290 +@testset "Inverted root pair detection and correction" begin + # Discovered via random search + coeffs = [ + -0.7270388932299022, -0.6929210470992349, -0.7343652899957108, + 0.8310017775620168, 0.6030921975763498, 0.46703506019208685, + -2.3581581735824186, 2.0556608750360628, -0.8183123724103458, + -2.5113469878793513, 0.10406374497692948, 0.0701494558467343, + ] + a1, b1, c1, a2, b2, c2, a3, b3, c3, a4, b4, c4 = coeffs + f(t) = + (a1 * t^2 + b1 * t + c1) * (a3 * t^2 + b3 * t + c3) + + (a2 * t^2 + b2 * t + c2) * (a4 * t^2 + b4 * t + c4) + f(t, _) = f(t) + tspan = (0.4294759977027207, 0.5371755582641773) + + # Verify that ModAB directly produces an inverted root pair for this input + raw_sol = solve( + IntervalNonlinearProblem{false}(f, tspan), + ModAB(), abstol = 0.0, reltol = 0.0 + ) + @test DiffEqBase.is_inverted_root_pair(raw_sol, f, tspan) + + # find_root must detect and correct the inversion: the returned roots must bracket + # a root with matching condition signs + left = DiffEqBase.find_root(f, tspan, SciMLBase.LeftRootFind) + right = DiffEqBase.find_root(f, tspan, SciMLBase.RightRootFind) + @test f(left) > 0.0 + @test f(right) < 0.0 + @test nextfloat(left) == right +end