From 8e707595129bd1217debc11f09305701d8443ec9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20Gonz=C3=A1lez?= Date: Tue, 3 Mar 2026 16:38:13 +0100 Subject: [PATCH 1/9] wip --- src/callbacks.jl | 184 ++++++++++++++++++++++++++++------------------- test/gen.jl | 77 ++++++++++++++++++++ test/lv.jl | 69 ++++++++++++++++++ 3 files changed, 257 insertions(+), 73 deletions(-) create mode 100644 test/gen.jl create mode 100644 test/lv.jl diff --git a/src/callbacks.jl b/src/callbacks.jl index cb6285c9e..f0ac3d889 100644 --- a/src/callbacks.jl +++ b/src/callbacks.jl @@ -9,18 +9,18 @@ function initialize!(cb::CallbackSet, u, t, integrator::DEIntegrator) cb.discrete_callbacks... ) end -initialize!(cb::CallbackSet{Tuple{}, Tuple{}}, u, t, integrator::DEIntegrator) = false +initialize!(cb::CallbackSet{Tuple{},Tuple{}}, u, t, integrator::DEIntegrator) = false function initialize!( - u, t, integrator::DEIntegrator, any_modified::Bool, - c::DECallback, cs::DECallback... - ) + u, t, integrator::DEIntegrator, any_modified::Bool, + c::DECallback, cs::DECallback... +) c.initialize(c, u, t, integrator) return initialize!(u, t, integrator, any_modified || integrator.u_modified, cs...) end function initialize!( - u, t, integrator::DEIntegrator, any_modified::Bool, - c::DECallback - ) + u, t, integrator::DEIntegrator, any_modified::Bool, + c::DECallback +) c.initialize(c, u, t, integrator) return any_modified || integrator.u_modified end @@ -33,18 +33,18 @@ Recursively apply `finalize!` and return whether any modified u function finalize!(cb::CallbackSet, u, t, integrator::DEIntegrator) return finalize!(u, t, integrator, false, cb.continuous_callbacks..., cb.discrete_callbacks...) end -finalize!(cb::CallbackSet{Tuple{}, Tuple{}}, u, t, integrator::DEIntegrator) = false +finalize!(cb::CallbackSet{Tuple{},Tuple{}}, u, t, integrator::DEIntegrator) = false function finalize!( - u, t, integrator::DEIntegrator, any_modified::Bool, - c::DECallback, cs::DECallback... - ) + u, t, integrator::DEIntegrator, any_modified::Bool, + c::DECallback, cs::DECallback... +) c.finalize(c, u, t, integrator) return finalize!(u, t, integrator, any_modified || integrator.u_modified, cs...) end function finalize!( - u, t, integrator::DEIntegrator, any_modified::Bool, - c::DECallback - ) + u, t, integrator::DEIntegrator, any_modified::Bool, + c::DECallback +) c.finalize(c, u, t, integrator) return any_modified || integrator.u_modified end @@ -104,13 +104,13 @@ function get_condition(integrator::DEIntegrator, callback, abst) if callback.idxs === nothing integrator(tmp, abst, Val{0}) else - integrator(tmp, abst, Val{0}, idxs = callback.idxs) + integrator(tmp, abst, Val{0}, idxs=callback.idxs) end else if callback.idxs === nothing tmp = integrator(abst, Val{0}) else - tmp = integrator(abst, Val{0}, idxs = callback.idxs) + tmp = integrator(abst, Val{0}, idxs=callback.idxs) end end # ismutable && !(callback.idxs isa Number) ? integrator(tmp,abst,Val{0},idxs=callback.idxs) : @@ -130,24 +130,24 @@ end # Use a generated function for type stability even when many callbacks are given @inline function find_first_continuous_callback( - integrator, - callbacks::Vararg{ - AbstractContinuousCallback, - N, - } - ) where {N} + integrator, + callbacks::Vararg{ + AbstractContinuousCallback, + N, + } +) where {N} return find_first_continuous_callback(integrator, tuple(callbacks...)) end @generated function find_first_continuous_callback( - integrator, - callbacks::NTuple{ - N, - AbstractContinuousCallback, - } - ) where {N} + integrator, + callbacks::NTuple{ + N, + AbstractContinuousCallback, + } +) where {N} ex = quote tmin, upcrossing, - event_occurred, event_idx, residual = find_callback_time( + event_occurred, event_idx, residual = find_callback_time( integrator, callbacks[1], 1 ) @@ -157,7 +157,7 @@ end ex = quote $ex tmin2, upcrossing2, - event_occurred2, event_idx2, residual2 = find_callback_time( + event_occurred2, event_idx2, residual2 = find_callback_time( integrator, callbacks[$i], $i @@ -183,9 +183,9 @@ end end @inline function find_callback_time( - integrator, callback::VectorContinuousCallback, - callback_idx - ) + integrator, callback::VectorContinuousCallback, + callback_idx +) if callback.interp_points != 0 addsteps!(integrator) end @@ -239,7 +239,7 @@ end min_event_idx = -1 for idx in 1:length(event_idx) if ArrayInterface.allowed_getindex(event_idx, idx) != 0 - function zero_func(abst, p = nothing) + function zero_func(abst, p=nothing) return ArrayInterface.allowed_getindex( get_condition( integrator, @@ -271,13 +271,13 @@ end end return callback_t, ArrayInterface.allowed_getindex(bottom_sign, min_event_idx), - event_occurred::Bool, min_event_idx::Int, residual + event_occurred::Bool, min_event_idx::Int, residual end @inline function find_callback_time( - integrator, callback::ContinuousCallback, - callback_idx - ) + integrator, callback::ContinuousCallback, + callback_idx +) if callback.interp_points != 0 addsteps!(integrator) end @@ -310,7 +310,7 @@ end residual = zero(bottom_condition) else # Find callback time - zero_func(abst, p = nothing) = get_condition(integrator, callback, abst) + zero_func(abst, p=nothing) = get_condition(integrator, callback, abst) callback_t = find_root(zero_func, (bottom_t, top_t), callback.rootfind) residual = zero_func(callback_t) end @@ -344,9 +344,9 @@ function check_event_occurence(integrator, callback, bottom_sign) check_event_occurence_upto(integrator, callback, bottom_sign, top_t) if callback.interp_points != 0 && !isdiscrete(integrator.alg) && - any(iszero, event_idx) + any(iszero, event_idx) # Use the interpolants for safety checking - ts = range(integrator.tprev, stop = integrator.t, length = callback.interp_points) + ts = range(integrator.tprev, stop=integrator.t, length=callback.interp_points) for i in 2:length(ts) top_t = ts[i] event_occurred, event_idx, top_sign = @@ -384,6 +384,15 @@ function check_event_occurence_upto(integrator, callback::VectorContinuousCallba return event_occurred, event_idx, top_sign end +_shift(τ, i) = + if iszero(i) + τ + elseif i > 0 + _shift(nextfloat(τ), i - 1) + else + _shift(prevfloat(τ), i + 1) + end + """ Find either exact or floating point precision root of `f`. If the exact root cannot be represented, return closest floating point number depending on `rootfind` @@ -396,8 +405,15 @@ Assumes that: function find_root(f, tup, rootfind::SciMLBase.RootfindOpt) sol = solve( IntervalNonlinearProblem{false}(f, tup), - ModAB(), abstol = 0.0, reltol = 0.0 + ModAB(), abstol=0.0, reltol=0.0 ) + + + if is_inverted_root_pair(sol, f, tup) + # "Inverted" root pair (#1290) + + + if rootfind == SciMLBase.LeftRootFind return sol.left else @@ -405,6 +421,28 @@ function find_root(f, tup, rootfind::SciMLBase.RootfindOpt) end end +""" +Determine if the root pair is "inverted" — i.e. if the final root bracket, when evaluated on +the condition function, has the opposite signs as the initial bracket. This can occur due to +numerical floating point noise. +""" +function is_inverted_root_pair(sol, f, tup) + + # Fast path (f(t) == 0). Alternative implementation: check if + # sol.retcode ∈ (ReturnCode.ExactSolutionLeft, ReturnCode.ExactSolutionRight) + iszero(sol.resid) && return false + + # Under current implementation of ModAB, sol.resid = f(max(sol.left, sol.right)) + # Therefore, the residual should have the same sign as the condition function evaluated + # at maximum(tup); otherwise, the root pair is inverted. + sign(sol.resid) != sign(f(maximum(tup))) +end + + + + + + """ findall_events!(next_sign,affect!,affect_neg!,prev_sign) @@ -413,20 +451,20 @@ in the interval between prev_sign and next_sign. Return `true` if any event occured. """ function findall_events!( - next_sign::Union{Array, SubArray}, affect!::F1, affect_neg!::F2, - prev_sign::Union{Array, SubArray} - ) where {F1, F2} + next_sign::Union{Array,SubArray}, affect!::F1, affect_neg!::F2, + prev_sign::Union{Array,SubArray} +) where {F1,F2} @inbounds for i in 1:length(prev_sign) next_sign[i] = ( (prev_sign[i] < 0 && affect! !== nothing) || - (prev_sign[i] > 0 && affect_neg! !== nothing) + (prev_sign[i] > 0 && affect_neg! !== nothing) ) && - prev_sign[i] * next_sign[i] <= 0 + prev_sign[i] * next_sign[i] <= 0 end return any(isone, next_sign) end -function findall_events!(next_sign, affect!::F1, affect_neg!::F2, prev_sign) where {F1, F2} +function findall_events!(next_sign, affect!::F1, affect_neg!::F2, prev_sign) where {F1,F2} hasaffect::Bool = affect! !== nothing hasaffectneg::Bool = affect_neg! !== nothing f = (n, p) -> ((p < 0 && hasaffect) || (p > 0 && hasaffectneg)) && p * n <= 0 @@ -437,18 +475,18 @@ end """ Return `true` if an event occured. """ -function is_event_occurence(prev_sign::Number, next_sign::Number, affect!::F1, affect_neg!::F2) where {F1, F2} +function is_event_occurence(prev_sign::Number, next_sign::Number, affect!::F1, affect_neg!::F2) where {F1,F2} return ( (prev_sign < 0 && affect! !== nothing) || - (prev_sign > 0 && affect_neg! !== nothing) + (prev_sign > 0 && affect_neg! !== nothing) ) && prev_sign * next_sign <= 0 end function apply_callback!( - integrator, - callback::Union{ContinuousCallback, VectorContinuousCallback}, - cb_time, prev_sign, event_idx - ) + integrator, + callback::Union{ContinuousCallback,VectorContinuousCallback}, + cb_time, prev_sign, event_idx +) if isadaptive(integrator) set_proposed_dt!( integrator, @@ -479,20 +517,20 @@ function apply_callback!( integrator.u_modified = false else callback isa VectorContinuousCallback ? - callback.affect!(integrator, event_idx) : callback.affect!(integrator) + callback.affect!(integrator, event_idx) : callback.affect!(integrator) end elseif prev_sign > 0 if callback.affect_neg! === nothing integrator.u_modified = false else callback isa VectorContinuousCallback ? - callback.affect_neg!(integrator, event_idx) : callback.affect_neg!(integrator) + callback.affect_neg!(integrator, event_idx) : callback.affect_neg!(integrator) end end if integrator.u_modified reeval_internals_due_to_modification!( - integrator, callback_initializealg = callback.initializealg + integrator, callback_initializealg=callback.initializealg ) @inbounds if callback.save_positions[2] @@ -526,7 +564,7 @@ end callback.affect!(integrator) if integrator.u_modified reeval_internals_due_to_modification!( - integrator, false, callback_initializealg = callback.initializealg + integrator, false, callback_initializealg=callback.initializealg ) end @inbounds if callback.save_positions[2] @@ -550,12 +588,12 @@ end end @inline function apply_discrete_callback!( - integrator, discrete_modified::Bool, - saved_in_cb::Bool, callback::DiscreteCallback, - args... - ) + integrator, discrete_modified::Bool, + saved_in_cb::Bool, callback::DiscreteCallback, + args... +) bool, - saved_in_cb2 = apply_discrete_callback!( + saved_in_cb2 = apply_discrete_callback!( integrator, apply_discrete_callback!( integrator, @@ -567,9 +605,9 @@ end end @inline function apply_discrete_callback!( - integrator, discrete_modified::Bool, - saved_in_cb::Bool, callback::DiscreteCallback - ) + integrator, discrete_modified::Bool, + saved_in_cb::Bool, callback::DiscreteCallback +) bool, saved_in_cb2 = apply_discrete_callback!(integrator, callback) return discrete_modified || bool, saved_in_cb || saved_in_cb2 end @@ -605,7 +643,7 @@ end """ $(TYPEDEF) """ -mutable struct CallbackCache{conditionType, signType} +mutable struct CallbackCache{conditionType,signType} tmp_condition::conditionType next_condition::conditionType next_sign::signType @@ -613,9 +651,9 @@ mutable struct CallbackCache{conditionType, signType} end function CallbackCache( - u, max_len, ::Type{conditionType}, - ::Type{signType} - ) where {conditionType, signType} + u, max_len, ::Type{conditionType}, + ::Type{signType} +) where {conditionType,signType} tmp_condition = similar(u, conditionType, max_len) next_condition = similar(u, conditionType, max_len) next_sign = similar(u, signType, max_len) @@ -624,9 +662,9 @@ function CallbackCache( end function CallbackCache( - max_len, ::Type{conditionType}, - ::Type{signType} - ) where {conditionType, signType} + max_len, ::Type{conditionType}, + ::Type{signType} +) where {conditionType,signType} tmp_condition = zeros(conditionType, max_len) next_condition = zeros(conditionType, max_len) next_sign = zeros(signType, max_len) diff --git a/test/gen.jl b/test/gen.jl new file mode 100644 index 000000000..8496a171a --- /dev/null +++ b/test/gen.jl @@ -0,0 +1,77 @@ +using OrdinaryDiffEqTsit5, OrdinaryDiffEqCore, SciMLBase, LinearAlgebra + +# Lotka-Volterra equations +function lotka_volterra!(du, u, p, t) + α, β, δ, γ = 1.5, 1.0, 3.0, 1.0 + x, y = u + du[1] = α * x - β * x * y + du[2] = δ * x * y - γ * y + return nothing +end + +function run_once(; seed=nothing) + if seed !== nothing + Random.seed!(seed) + end + + # Random coefficients for two linear conditions: c' * u + coeffs1 = randn(2) + + u0 = [1.0, 1.0] + tspan = (0.0, 20.0) + + # Record initial signs + initial_signs = [sign(dot(coeffs1, u0))] + + # VCC condition: two linear functions of state + function vcc_condition!(out, u, t, integrator) + out[1] = dot(coeffs1, u) + return nothing + end + + function vcc_affect!(integrator, event_index) + u = integrator.u + vals = [dot(coeffs1, u)] + v = vals[event_index] + if !iszero(v) && sign(v) == initial_signs[event_index] + @show coeffs1 u event_index v initial_signs + error("VCC fired but value has same sign as initial — RightRootFind bug?") + else + # termine simulation + terminate!(integrator) + end + return nothing + end + + cb = VectorContinuousCallback( + vcc_condition!, + vcc_affect!, + 1; + rootfind=SciMLBase.RightRootFind, + ) + + prob = ODEProblem(lotka_volterra!, u0, tspan) + sol = solve(prob, Tsit5(); callback=cb, abstol=1e-10, reltol=1e-10) + return sol +end + +# Main loop — fish for the bug +i = 0 +while true + global i += 1 + if i % 1000 == 0 + println("Iteration $i ...") + end + try + run_once() + catch e + if e isa ErrorException && contains(e.msg, "RightRootFind") + println("\n*** Bug found at iteration $i ***") + rethrow() + else + rethrow() + end + end +end + +println("No bug found after 100_000 iterations.") diff --git a/test/lv.jl b/test/lv.jl new file mode 100644 index 000000000..6e20591d4 --- /dev/null +++ b/test/lv.jl @@ -0,0 +1,69 @@ +using OrdinaryDiffEqTsit5, OrdinaryDiffEqCore, SciMLBase, LinearAlgebra + +# Lotka-Volterra equations +function lotka_volterra!(du, u, p, t) + α, β, δ, γ = 1.5, 1.0, 3.0, 1.0 + x, y = u + du[1] = α * x - β * x * y + du[2] = δ * x * y - γ * y + return nothing +end + +# coeffs1 = [0.6825223495861318, -0.4295322984152052] +# coeffs2 = [1.7358772252665537, -1.0070061675696311] + +coeffs1 = [2.922772251297381, -2.8028553839288595] +u0 = [1.0, 1.0] +tspan = (0.0, 20.0) +tspan = (0.0, 0.03) + +# Record initial signs +initial_conditions = [dot(coeffs1, u0)] +initial_signs = sign.(initial_conditions) + +# VCC condition: two linear functions of state +function vcc_condition!(out, u, t, integrator) + out[1] = dot(coeffs1, u) + return nothing +end + +function vcc_affect!(integrator, event_index) + @show event_index, integrator.t + u = integrator.u + @show integrator.t + if event_index == 1 + @show integrator.u + println("Condition value at crossing: ", [dot(coeffs1, u)]) + terminate!(integrator) + end + return nothing +end + +cb = VectorContinuousCallback( + vcc_condition!, + vcc_affect!, + 1; + rootfind=SciMLBase.RightRootFind, +) + +println("Initial conditions: ", initial_conditions) + +prob = ODEProblem(lotka_volterra!, u0, tspan) +sol = solve(prob, Tsit5(); callback=cb, abstol=1e-10, reltol=1e-10, dense=true) +sol_u = solve(prob, Tsit5(); abstol=1e-10, reltol=1e-10, dense=true) +# sol = solve(prob, Tsit5(); abstol=1e-10, reltol=1e-10) +sol +nothing + +shift(τ, i) = + if iszero(i) + τ + elseif i > 0 + shift(nextfloat(τ), i - 1) + else + shift(prevfloat(τ), i + 1) + end + +# 0.23620973794890948 + +cond2(u) = dot(coeffs2, u) \ No newline at end of file From 6e914fc63961ab2bca2e3e7704c5d25ca8e44768 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20Gonz=C3=A1lez?= Date: Tue, 3 Mar 2026 16:38:19 +0100 Subject: [PATCH 2/9] wip2 --- src/callbacks.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/callbacks.jl b/src/callbacks.jl index f0ac3d889..82305b7f0 100644 --- a/src/callbacks.jl +++ b/src/callbacks.jl @@ -411,7 +411,8 @@ function find_root(f, tup, rootfind::SciMLBase.RootfindOpt) if is_inverted_root_pair(sol, f, tup) # "Inverted" root pair (#1290) - + if sol.resid > 0 && rootfind == SciMLBase.LeftRootFind + return find_root() if rootfind == SciMLBase.LeftRootFind From efd41f996fa7cf78baac326f3d5d8b4b63a5018f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20Gonz=C3=A1lez?= Date: Tue, 3 Mar 2026 18:37:09 +0100 Subject: [PATCH 3/9] Implement inverted root pair management logic --- src/callbacks.jl | 183 +++++++++++++++++++++++------------------------ 1 file changed, 91 insertions(+), 92 deletions(-) diff --git a/src/callbacks.jl b/src/callbacks.jl index 82305b7f0..af733b5e6 100644 --- a/src/callbacks.jl +++ b/src/callbacks.jl @@ -9,18 +9,18 @@ function initialize!(cb::CallbackSet, u, t, integrator::DEIntegrator) cb.discrete_callbacks... ) end -initialize!(cb::CallbackSet{Tuple{},Tuple{}}, u, t, integrator::DEIntegrator) = false +initialize!(cb::CallbackSet{Tuple{}, Tuple{}}, u, t, integrator::DEIntegrator) = false function initialize!( - u, t, integrator::DEIntegrator, any_modified::Bool, - c::DECallback, cs::DECallback... -) + u, t, integrator::DEIntegrator, any_modified::Bool, + c::DECallback, cs::DECallback... + ) c.initialize(c, u, t, integrator) return initialize!(u, t, integrator, any_modified || integrator.u_modified, cs...) end function initialize!( - u, t, integrator::DEIntegrator, any_modified::Bool, - c::DECallback -) + u, t, integrator::DEIntegrator, any_modified::Bool, + c::DECallback + ) c.initialize(c, u, t, integrator) return any_modified || integrator.u_modified end @@ -33,18 +33,18 @@ Recursively apply `finalize!` and return whether any modified u function finalize!(cb::CallbackSet, u, t, integrator::DEIntegrator) return finalize!(u, t, integrator, false, cb.continuous_callbacks..., cb.discrete_callbacks...) end -finalize!(cb::CallbackSet{Tuple{},Tuple{}}, u, t, integrator::DEIntegrator) = false +finalize!(cb::CallbackSet{Tuple{}, Tuple{}}, u, t, integrator::DEIntegrator) = false function finalize!( - u, t, integrator::DEIntegrator, any_modified::Bool, - c::DECallback, cs::DECallback... -) + u, t, integrator::DEIntegrator, any_modified::Bool, + c::DECallback, cs::DECallback... + ) c.finalize(c, u, t, integrator) return finalize!(u, t, integrator, any_modified || integrator.u_modified, cs...) end function finalize!( - u, t, integrator::DEIntegrator, any_modified::Bool, - c::DECallback -) + u, t, integrator::DEIntegrator, any_modified::Bool, + c::DECallback + ) c.finalize(c, u, t, integrator) return any_modified || integrator.u_modified end @@ -104,13 +104,13 @@ function get_condition(integrator::DEIntegrator, callback, abst) if callback.idxs === nothing integrator(tmp, abst, Val{0}) else - integrator(tmp, abst, Val{0}, idxs=callback.idxs) + integrator(tmp, abst, Val{0}, idxs = callback.idxs) end else if callback.idxs === nothing tmp = integrator(abst, Val{0}) else - tmp = integrator(abst, Val{0}, idxs=callback.idxs) + tmp = integrator(abst, Val{0}, idxs = callback.idxs) end end # ismutable && !(callback.idxs isa Number) ? integrator(tmp,abst,Val{0},idxs=callback.idxs) : @@ -130,24 +130,24 @@ end # Use a generated function for type stability even when many callbacks are given @inline function find_first_continuous_callback( - integrator, - callbacks::Vararg{ - AbstractContinuousCallback, - N, - } -) where {N} + integrator, + callbacks::Vararg{ + AbstractContinuousCallback, + N, + } + ) where {N} return find_first_continuous_callback(integrator, tuple(callbacks...)) end @generated function find_first_continuous_callback( - integrator, - callbacks::NTuple{ - N, - AbstractContinuousCallback, - } -) where {N} + integrator, + callbacks::NTuple{ + N, + AbstractContinuousCallback, + } + ) where {N} ex = quote tmin, upcrossing, - event_occurred, event_idx, residual = find_callback_time( + event_occurred, event_idx, residual = find_callback_time( integrator, callbacks[1], 1 ) @@ -157,7 +157,7 @@ end ex = quote $ex tmin2, upcrossing2, - event_occurred2, event_idx2, residual2 = find_callback_time( + event_occurred2, event_idx2, residual2 = find_callback_time( integrator, callbacks[$i], $i @@ -183,9 +183,9 @@ end end @inline function find_callback_time( - integrator, callback::VectorContinuousCallback, - callback_idx -) + integrator, callback::VectorContinuousCallback, + callback_idx + ) if callback.interp_points != 0 addsteps!(integrator) end @@ -239,7 +239,7 @@ end min_event_idx = -1 for idx in 1:length(event_idx) if ArrayInterface.allowed_getindex(event_idx, idx) != 0 - function zero_func(abst, p=nothing) + function zero_func(abst, p = nothing) return ArrayInterface.allowed_getindex( get_condition( integrator, @@ -271,13 +271,13 @@ end end return callback_t, ArrayInterface.allowed_getindex(bottom_sign, min_event_idx), - event_occurred::Bool, min_event_idx::Int, residual + event_occurred::Bool, min_event_idx::Int, residual end @inline function find_callback_time( - integrator, callback::ContinuousCallback, - callback_idx -) + integrator, callback::ContinuousCallback, + callback_idx + ) if callback.interp_points != 0 addsteps!(integrator) end @@ -310,7 +310,7 @@ end residual = zero(bottom_condition) else # Find callback time - zero_func(abst, p=nothing) = get_condition(integrator, callback, abst) + zero_func(abst, p = nothing) = get_condition(integrator, callback, abst) callback_t = find_root(zero_func, (bottom_t, top_t), callback.rootfind) residual = zero_func(callback_t) end @@ -344,9 +344,9 @@ function check_event_occurence(integrator, callback, bottom_sign) check_event_occurence_upto(integrator, callback, bottom_sign, top_t) if callback.interp_points != 0 && !isdiscrete(integrator.alg) && - any(iszero, event_idx) + any(iszero, event_idx) # Use the interpolants for safety checking - ts = range(integrator.tprev, stop=integrator.t, length=callback.interp_points) + ts = range(integrator.tprev, stop = integrator.t, length = callback.interp_points) for i in 2:length(ts) top_t = ts[i] event_occurred, event_idx, top_sign = @@ -385,13 +385,13 @@ function check_event_occurence_upto(integrator, callback::VectorContinuousCallba end _shift(τ, i) = - if iszero(i) - τ - elseif i > 0 - _shift(nextfloat(τ), i - 1) - else - _shift(prevfloat(τ), i + 1) - end +if iszero(i) + τ +elseif i > 0 + _shift(nextfloat(τ), i - 1) +else + _shift(prevfloat(τ), i + 1) +end """ Find either exact or floating point precision root of `f`. @@ -405,15 +405,18 @@ Assumes that: function find_root(f, tup, rootfind::SciMLBase.RootfindOpt) sol = solve( IntervalNonlinearProblem{false}(f, tup), - ModAB(), abstol=0.0, reltol=0.0 + ModAB(), abstol = 0.0, reltol = 0.0 ) - if is_inverted_root_pair(sol, f, tup) - # "Inverted" root pair (#1290) - if sol.resid > 0 && rootfind == SciMLBase.LeftRootFind - return find_root() - + # "Inverted" root pair (#1290); direction of integration flips the bracket side + return if (sol.resid > 0) ⊻ (tup[1] > tup[2]) + find_root(f, (tup[1], sol.left), rootfind) + else + find_root(f, (sol.right, tup[2]), rootfind) + end + end + if rootfind == SciMLBase.LeftRootFind return sol.left @@ -428,21 +431,17 @@ the condition function, has the opposite signs as the initial bracket. This can numerical floating point noise. """ function is_inverted_root_pair(sol, f, tup) - + # Fast path (f(t) == 0). Alternative implementation: check if # sol.retcode ∈ (ReturnCode.ExactSolutionLeft, ReturnCode.ExactSolutionRight) - iszero(sol.resid) && return false + iszero(sol.resid) && return false # Under current implementation of ModAB, sol.resid = f(max(sol.left, sol.right)) # Therefore, the residual should have the same sign as the condition function evaluated # at maximum(tup); otherwise, the root pair is inverted. - sign(sol.resid) != sign(f(maximum(tup))) + return sign(sol.resid) != sign(f(maximum(tup))) end - - - - """ findall_events!(next_sign,affect!,affect_neg!,prev_sign) @@ -452,20 +451,20 @@ in the interval between prev_sign and next_sign. Return `true` if any event occured. """ function findall_events!( - next_sign::Union{Array,SubArray}, affect!::F1, affect_neg!::F2, - prev_sign::Union{Array,SubArray} -) where {F1,F2} + next_sign::Union{Array, SubArray}, affect!::F1, affect_neg!::F2, + prev_sign::Union{Array, SubArray} + ) where {F1, F2} @inbounds for i in 1:length(prev_sign) next_sign[i] = ( (prev_sign[i] < 0 && affect! !== nothing) || - (prev_sign[i] > 0 && affect_neg! !== nothing) + (prev_sign[i] > 0 && affect_neg! !== nothing) ) && - prev_sign[i] * next_sign[i] <= 0 + prev_sign[i] * next_sign[i] <= 0 end return any(isone, next_sign) end -function findall_events!(next_sign, affect!::F1, affect_neg!::F2, prev_sign) where {F1,F2} +function findall_events!(next_sign, affect!::F1, affect_neg!::F2, prev_sign) where {F1, F2} hasaffect::Bool = affect! !== nothing hasaffectneg::Bool = affect_neg! !== nothing f = (n, p) -> ((p < 0 && hasaffect) || (p > 0 && hasaffectneg)) && p * n <= 0 @@ -476,18 +475,18 @@ end """ Return `true` if an event occured. """ -function is_event_occurence(prev_sign::Number, next_sign::Number, affect!::F1, affect_neg!::F2) where {F1,F2} +function is_event_occurence(prev_sign::Number, next_sign::Number, affect!::F1, affect_neg!::F2) where {F1, F2} return ( (prev_sign < 0 && affect! !== nothing) || - (prev_sign > 0 && affect_neg! !== nothing) + (prev_sign > 0 && affect_neg! !== nothing) ) && prev_sign * next_sign <= 0 end function apply_callback!( - integrator, - callback::Union{ContinuousCallback,VectorContinuousCallback}, - cb_time, prev_sign, event_idx -) + integrator, + callback::Union{ContinuousCallback, VectorContinuousCallback}, + cb_time, prev_sign, event_idx + ) if isadaptive(integrator) set_proposed_dt!( integrator, @@ -518,20 +517,20 @@ function apply_callback!( integrator.u_modified = false else callback isa VectorContinuousCallback ? - callback.affect!(integrator, event_idx) : callback.affect!(integrator) + callback.affect!(integrator, event_idx) : callback.affect!(integrator) end elseif prev_sign > 0 if callback.affect_neg! === nothing integrator.u_modified = false else callback isa VectorContinuousCallback ? - callback.affect_neg!(integrator, event_idx) : callback.affect_neg!(integrator) + callback.affect_neg!(integrator, event_idx) : callback.affect_neg!(integrator) end end if integrator.u_modified reeval_internals_due_to_modification!( - integrator, callback_initializealg=callback.initializealg + integrator, callback_initializealg = callback.initializealg ) @inbounds if callback.save_positions[2] @@ -565,7 +564,7 @@ end callback.affect!(integrator) if integrator.u_modified reeval_internals_due_to_modification!( - integrator, false, callback_initializealg=callback.initializealg + integrator, false, callback_initializealg = callback.initializealg ) end @inbounds if callback.save_positions[2] @@ -589,12 +588,12 @@ end end @inline function apply_discrete_callback!( - integrator, discrete_modified::Bool, - saved_in_cb::Bool, callback::DiscreteCallback, - args... -) + integrator, discrete_modified::Bool, + saved_in_cb::Bool, callback::DiscreteCallback, + args... + ) bool, - saved_in_cb2 = apply_discrete_callback!( + saved_in_cb2 = apply_discrete_callback!( integrator, apply_discrete_callback!( integrator, @@ -606,9 +605,9 @@ end end @inline function apply_discrete_callback!( - integrator, discrete_modified::Bool, - saved_in_cb::Bool, callback::DiscreteCallback -) + integrator, discrete_modified::Bool, + saved_in_cb::Bool, callback::DiscreteCallback + ) bool, saved_in_cb2 = apply_discrete_callback!(integrator, callback) return discrete_modified || bool, saved_in_cb || saved_in_cb2 end @@ -644,7 +643,7 @@ end """ $(TYPEDEF) """ -mutable struct CallbackCache{conditionType,signType} +mutable struct CallbackCache{conditionType, signType} tmp_condition::conditionType next_condition::conditionType next_sign::signType @@ -652,9 +651,9 @@ mutable struct CallbackCache{conditionType,signType} end function CallbackCache( - u, max_len, ::Type{conditionType}, - ::Type{signType} -) where {conditionType,signType} + u, max_len, ::Type{conditionType}, + ::Type{signType} + ) where {conditionType, signType} tmp_condition = similar(u, conditionType, max_len) next_condition = similar(u, conditionType, max_len) next_sign = similar(u, signType, max_len) @@ -663,9 +662,9 @@ function CallbackCache( end function CallbackCache( - max_len, ::Type{conditionType}, - ::Type{signType} -) where {conditionType,signType} + max_len, ::Type{conditionType}, + ::Type{signType} + ) where {conditionType, signType} tmp_condition = zeros(conditionType, max_len) next_condition = zeros(conditionType, max_len) next_sign = zeros(signType, max_len) From 4156a5bf1b09676d4f8c0c15965fd649258a5db0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20Gonz=C3=A1lez?= Date: Tue, 3 Mar 2026 19:57:37 +0100 Subject: [PATCH 4/9] Clean-up edge cases --- src/callbacks.jl | 264 ++++++++++++++++++++++++++++++++--------------- 1 file changed, 182 insertions(+), 82 deletions(-) diff --git a/src/callbacks.jl b/src/callbacks.jl index af733b5e6..a66b9fd79 100644 --- a/src/callbacks.jl +++ b/src/callbacks.jl @@ -9,18 +9,18 @@ function initialize!(cb::CallbackSet, u, t, integrator::DEIntegrator) cb.discrete_callbacks... ) end -initialize!(cb::CallbackSet{Tuple{}, Tuple{}}, u, t, integrator::DEIntegrator) = false +initialize!(cb::CallbackSet{Tuple{},Tuple{}}, u, t, integrator::DEIntegrator) = false function initialize!( - u, t, integrator::DEIntegrator, any_modified::Bool, - c::DECallback, cs::DECallback... - ) + u, t, integrator::DEIntegrator, any_modified::Bool, + c::DECallback, cs::DECallback... +) c.initialize(c, u, t, integrator) return initialize!(u, t, integrator, any_modified || integrator.u_modified, cs...) end function initialize!( - u, t, integrator::DEIntegrator, any_modified::Bool, - c::DECallback - ) + u, t, integrator::DEIntegrator, any_modified::Bool, + c::DECallback +) c.initialize(c, u, t, integrator) return any_modified || integrator.u_modified end @@ -33,18 +33,18 @@ Recursively apply `finalize!` and return whether any modified u function finalize!(cb::CallbackSet, u, t, integrator::DEIntegrator) return finalize!(u, t, integrator, false, cb.continuous_callbacks..., cb.discrete_callbacks...) end -finalize!(cb::CallbackSet{Tuple{}, Tuple{}}, u, t, integrator::DEIntegrator) = false +finalize!(cb::CallbackSet{Tuple{},Tuple{}}, u, t, integrator::DEIntegrator) = false function finalize!( - u, t, integrator::DEIntegrator, any_modified::Bool, - c::DECallback, cs::DECallback... - ) + u, t, integrator::DEIntegrator, any_modified::Bool, + c::DECallback, cs::DECallback... +) c.finalize(c, u, t, integrator) return finalize!(u, t, integrator, any_modified || integrator.u_modified, cs...) end function finalize!( - u, t, integrator::DEIntegrator, any_modified::Bool, - c::DECallback - ) + u, t, integrator::DEIntegrator, any_modified::Bool, + c::DECallback +) c.finalize(c, u, t, integrator) return any_modified || integrator.u_modified end @@ -104,13 +104,13 @@ function get_condition(integrator::DEIntegrator, callback, abst) if callback.idxs === nothing integrator(tmp, abst, Val{0}) else - integrator(tmp, abst, Val{0}, idxs = callback.idxs) + integrator(tmp, abst, Val{0}, idxs=callback.idxs) end else if callback.idxs === nothing tmp = integrator(abst, Val{0}) else - tmp = integrator(abst, Val{0}, idxs = callback.idxs) + tmp = integrator(abst, Val{0}, idxs=callback.idxs) end end # ismutable && !(callback.idxs isa Number) ? integrator(tmp,abst,Val{0},idxs=callback.idxs) : @@ -130,24 +130,24 @@ end # Use a generated function for type stability even when many callbacks are given @inline function find_first_continuous_callback( - integrator, - callbacks::Vararg{ - AbstractContinuousCallback, - N, - } - ) where {N} + integrator, + callbacks::Vararg{ + AbstractContinuousCallback, + N, + } +) where {N} return find_first_continuous_callback(integrator, tuple(callbacks...)) end @generated function find_first_continuous_callback( - integrator, - callbacks::NTuple{ - N, - AbstractContinuousCallback, - } - ) where {N} + integrator, + callbacks::NTuple{ + N, + AbstractContinuousCallback, + } +) where {N} ex = quote tmin, upcrossing, - event_occurred, event_idx, residual = find_callback_time( + event_occurred, event_idx, residual = find_callback_time( integrator, callbacks[1], 1 ) @@ -157,7 +157,7 @@ end ex = quote $ex tmin2, upcrossing2, - event_occurred2, event_idx2, residual2 = find_callback_time( + event_occurred2, event_idx2, residual2 = find_callback_time( integrator, callbacks[$i], $i @@ -183,9 +183,9 @@ end end @inline function find_callback_time( - integrator, callback::VectorContinuousCallback, - callback_idx - ) + integrator, callback::VectorContinuousCallback, + callback_idx +) if callback.interp_points != 0 addsteps!(integrator) end @@ -239,7 +239,7 @@ end min_event_idx = -1 for idx in 1:length(event_idx) if ArrayInterface.allowed_getindex(event_idx, idx) != 0 - function zero_func(abst, p = nothing) + function zero_func(abst, p=nothing) return ArrayInterface.allowed_getindex( get_condition( integrator, @@ -271,13 +271,13 @@ end end return callback_t, ArrayInterface.allowed_getindex(bottom_sign, min_event_idx), - event_occurred::Bool, min_event_idx::Int, residual + event_occurred::Bool, min_event_idx::Int, residual end @inline function find_callback_time( - integrator, callback::ContinuousCallback, - callback_idx - ) + integrator, callback::ContinuousCallback, + callback_idx +) if callback.interp_points != 0 addsteps!(integrator) end @@ -310,7 +310,7 @@ end residual = zero(bottom_condition) else # Find callback time - zero_func(abst, p = nothing) = get_condition(integrator, callback, abst) + zero_func(abst, p=nothing) = get_condition(integrator, callback, abst) callback_t = find_root(zero_func, (bottom_t, top_t), callback.rootfind) residual = zero_func(callback_t) end @@ -344,9 +344,9 @@ function check_event_occurence(integrator, callback, bottom_sign) check_event_occurence_upto(integrator, callback, bottom_sign, top_t) if callback.interp_points != 0 && !isdiscrete(integrator.alg) && - any(iszero, event_idx) + any(iszero, event_idx) # Use the interpolants for safety checking - ts = range(integrator.tprev, stop = integrator.t, length = callback.interp_points) + ts = range(integrator.tprev, stop=integrator.t, length=callback.interp_points) for i in 2:length(ts) top_t = ts[i] event_occurred, event_idx, top_sign = @@ -385,12 +385,109 @@ function check_event_occurence_upto(integrator, callback::VectorContinuousCallba end _shift(τ, i) = -if iszero(i) - τ -elseif i > 0 - _shift(nextfloat(τ), i - 1) -else - _shift(prevfloat(τ), i + 1) + if iszero(i) + τ + elseif i > 0 + _shift(nextfloat(τ), i - 1) + else + _shift(prevfloat(τ), i + 1) + end + +using BracketingNonlinearSolve: AbstractBracketingAlgorithm, NonlinearVerbosity, NonlinearSolveBase, build_bracketing_solution +import SciMLBase: @SciMLMessage +struct ModAB2 <: AbstractBracketingAlgorithm +end + +function SciMLBase.__solve( + prob::IntervalNonlinearProblem, alg::ModAB2, args...; + maxiters=1000, abstol=nothing, verbose::NonlinearVerbosity=NonlinearVerbosity(), kwargs... +) + @assert !SciMLBase.isinplace(prob) "`ModAB` only supports out-of-place problems." + + f = Base.Fix2(prob.f, prob.p) + x1, x2 = minmax(promote(prob.tspan...)...) + y1, y2 = f(x1), f(x2) + + abstol = NonlinearSolveBase.get_tolerance( + x1, abstol, promote_type(eltype(x1), eltype(x2)) + ) + + if iszero(y1) + return build_exact_solution(prob, alg, x1, y1, ReturnCode.ExactSolutionLeft) + end + + if iszero(y2) + return build_exact_solution(prob, alg, x2, y2, ReturnCode.ExactSolutionRight) + end + + if sign(y1) == sign(y2) + @SciMLMessage( + "The interval is not an enclosing interval, opposite signs at the \ + boundaries are required.", + verbose, :non_enclosing_interval + ) + return build_bracketing_solution(prob, alg, x1, y1, x1, x2, ReturnCode.InitialFailure) + end + + bisecting = true + side = 0 # tracks the side that has moved at the previous iteration + ϵ = abstol + i = 1 + threshold = x2 - x1 # Threshold to fall back to bisection if AB fails to shrink the interval enough + C = 16 # safety factor for threshold corresponding to 4 iterations = 2^4 + while i < maxiters + local x3, y3 + if bisecting # Bisection method is used + x3 = (x1 + x2) / 2 + y3 = f(x3) # Function value at midpoint + ym = (y1 + y2) / 2 # Ordinate of chord at midpoint + # calculate k on each bisection step with account for local function properties and symmetry + r = 1 - abs(ym / (y2 - y1)) # Symmetry factor + k = r * r # Deviation factor + # Check if the function is close enough to linear + if abs(ym - y3) < k * (abs(ym) + abs(y3)) + bisecting = false + threshold = (x2 - x1) * C + end + else # Anderson-Bjork method is used + # x3 = clamp((x1 * y2 - y1 * x2) / (y2 - y1), x1, x2) + x3 = (x1 * y2 - y1 * x2) / (y2 - y1) + y3 = f(x3) + threshold /= 2 + end + if iszero(y3) + return build_exact_solution(prob, alg, x3, y3, ReturnCode.Success) + elseif (x2 - x1) < 2ϵ + return build_bracketing_solution(prob, alg, x3, y3, x1, x2, ReturnCode.Success) + end + x0 = x3 + if sign(y1) == sign(y3) + if side == 1 # Apply Anderson-Bjork correction on the right side + m = 1 - y3 / y1 + y2 *= m <= 0 ? inv(2 * one(y1)) : m + elseif !bisecting + side = 1 + end + x1, y1 = x3, y3 + else + if side == -1 # Apply Anderson-Bjork correction on the left side + m = 1 - y3 / y2 + y1 *= m <= 0 ? inv(2 * one(y1)) : m + elseif !bisecting + side = -1 + end + x2, y2 = x3, y3 + end + if nextfloat(x1) == x2 + return build_bracketing_solution(prob, alg, x2, f(x2), x1, x2, ReturnCode.FloatingPointLimit) + end + i += 1 + if x2 - x1 > threshold # If AB fails to shrink the interval enough + bisecting = true # reset to bisection + side = 0 + end + end + return build_bracketing_solution(prob, alg, x1, y1, x1, x2, ReturnCode.MaxIters) end """ @@ -405,9 +502,8 @@ Assumes that: function find_root(f, tup, rootfind::SciMLBase.RootfindOpt) sol = solve( IntervalNonlinearProblem{false}(f, tup), - ModAB(), abstol = 0.0, reltol = 0.0 + ModAB2(), abstol=0.0, reltol=0.0 ) - if is_inverted_root_pair(sol, f, tup) # "Inverted" root pair (#1290); direction of integration flips the bracket side return if (sol.resid > 0) ⊻ (tup[1] > tup[2]) @@ -417,7 +513,6 @@ function find_root(f, tup, rootfind::SciMLBase.RootfindOpt) end end - if rootfind == SciMLBase.LeftRootFind return sol.left else @@ -436,10 +531,15 @@ function is_inverted_root_pair(sol, f, tup) # sol.retcode ∈ (ReturnCode.ExactSolutionLeft, ReturnCode.ExactSolutionRight) iszero(sol.resid) && return false + # Should be equal to sol.resid under current implementation of ModAB, but this is + # more robust against implementation changes and it also works around ModAB#860 + most_positive_residual = f(max(sol.left, sol.right)) + # Under current implementation of ModAB, sol.resid = f(max(sol.left, sol.right)) # Therefore, the residual should have the same sign as the condition function evaluated # at maximum(tup); otherwise, the root pair is inverted. - return sign(sol.resid) != sign(f(maximum(tup))) + # @show sol.resid, f(maximum(tup)) + return sign(most_positive_residual) != sign(f(maximum(tup))) end @@ -451,20 +551,20 @@ in the interval between prev_sign and next_sign. Return `true` if any event occured. """ function findall_events!( - next_sign::Union{Array, SubArray}, affect!::F1, affect_neg!::F2, - prev_sign::Union{Array, SubArray} - ) where {F1, F2} + next_sign::Union{Array,SubArray}, affect!::F1, affect_neg!::F2, + prev_sign::Union{Array,SubArray} +) where {F1,F2} @inbounds for i in 1:length(prev_sign) next_sign[i] = ( (prev_sign[i] < 0 && affect! !== nothing) || - (prev_sign[i] > 0 && affect_neg! !== nothing) + (prev_sign[i] > 0 && affect_neg! !== nothing) ) && - prev_sign[i] * next_sign[i] <= 0 + prev_sign[i] * next_sign[i] <= 0 end return any(isone, next_sign) end -function findall_events!(next_sign, affect!::F1, affect_neg!::F2, prev_sign) where {F1, F2} +function findall_events!(next_sign, affect!::F1, affect_neg!::F2, prev_sign) where {F1,F2} hasaffect::Bool = affect! !== nothing hasaffectneg::Bool = affect_neg! !== nothing f = (n, p) -> ((p < 0 && hasaffect) || (p > 0 && hasaffectneg)) && p * n <= 0 @@ -475,18 +575,18 @@ end """ Return `true` if an event occured. """ -function is_event_occurence(prev_sign::Number, next_sign::Number, affect!::F1, affect_neg!::F2) where {F1, F2} +function is_event_occurence(prev_sign::Number, next_sign::Number, affect!::F1, affect_neg!::F2) where {F1,F2} return ( (prev_sign < 0 && affect! !== nothing) || - (prev_sign > 0 && affect_neg! !== nothing) + (prev_sign > 0 && affect_neg! !== nothing) ) && prev_sign * next_sign <= 0 end function apply_callback!( - integrator, - callback::Union{ContinuousCallback, VectorContinuousCallback}, - cb_time, prev_sign, event_idx - ) + integrator, + callback::Union{ContinuousCallback,VectorContinuousCallback}, + cb_time, prev_sign, event_idx +) if isadaptive(integrator) set_proposed_dt!( integrator, @@ -517,20 +617,20 @@ function apply_callback!( integrator.u_modified = false else callback isa VectorContinuousCallback ? - callback.affect!(integrator, event_idx) : callback.affect!(integrator) + callback.affect!(integrator, event_idx) : callback.affect!(integrator) end elseif prev_sign > 0 if callback.affect_neg! === nothing integrator.u_modified = false else callback isa VectorContinuousCallback ? - callback.affect_neg!(integrator, event_idx) : callback.affect_neg!(integrator) + callback.affect_neg!(integrator, event_idx) : callback.affect_neg!(integrator) end end if integrator.u_modified reeval_internals_due_to_modification!( - integrator, callback_initializealg = callback.initializealg + integrator, callback_initializealg=callback.initializealg ) @inbounds if callback.save_positions[2] @@ -564,7 +664,7 @@ end callback.affect!(integrator) if integrator.u_modified reeval_internals_due_to_modification!( - integrator, false, callback_initializealg = callback.initializealg + integrator, false, callback_initializealg=callback.initializealg ) end @inbounds if callback.save_positions[2] @@ -588,12 +688,12 @@ end end @inline function apply_discrete_callback!( - integrator, discrete_modified::Bool, - saved_in_cb::Bool, callback::DiscreteCallback, - args... - ) + integrator, discrete_modified::Bool, + saved_in_cb::Bool, callback::DiscreteCallback, + args... +) bool, - saved_in_cb2 = apply_discrete_callback!( + saved_in_cb2 = apply_discrete_callback!( integrator, apply_discrete_callback!( integrator, @@ -605,9 +705,9 @@ end end @inline function apply_discrete_callback!( - integrator, discrete_modified::Bool, - saved_in_cb::Bool, callback::DiscreteCallback - ) + integrator, discrete_modified::Bool, + saved_in_cb::Bool, callback::DiscreteCallback +) bool, saved_in_cb2 = apply_discrete_callback!(integrator, callback) return discrete_modified || bool, saved_in_cb || saved_in_cb2 end @@ -643,7 +743,7 @@ end """ $(TYPEDEF) """ -mutable struct CallbackCache{conditionType, signType} +mutable struct CallbackCache{conditionType,signType} tmp_condition::conditionType next_condition::conditionType next_sign::signType @@ -651,9 +751,9 @@ mutable struct CallbackCache{conditionType, signType} end function CallbackCache( - u, max_len, ::Type{conditionType}, - ::Type{signType} - ) where {conditionType, signType} + u, max_len, ::Type{conditionType}, + ::Type{signType} +) where {conditionType,signType} tmp_condition = similar(u, conditionType, max_len) next_condition = similar(u, conditionType, max_len) next_sign = similar(u, signType, max_len) @@ -662,9 +762,9 @@ function CallbackCache( end function CallbackCache( - max_len, ::Type{conditionType}, - ::Type{signType} - ) where {conditionType, signType} + max_len, ::Type{conditionType}, + ::Type{signType} +) where {conditionType,signType} tmp_condition = zeros(conditionType, max_len) next_condition = zeros(conditionType, max_len) next_sign = zeros(signType, max_len) From 8884593b93702bf8ceaaf3b80535dcb9875f799a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20Gonz=C3=A1lez?= Date: Tue, 3 Mar 2026 20:01:21 +0100 Subject: [PATCH 5/9] Remove debug code --- src/callbacks.jl | 97 ------------------------------------------------ 1 file changed, 97 deletions(-) diff --git a/src/callbacks.jl b/src/callbacks.jl index a66b9fd79..d04170344 100644 --- a/src/callbacks.jl +++ b/src/callbacks.jl @@ -393,103 +393,6 @@ _shift(τ, i) = _shift(prevfloat(τ), i + 1) end -using BracketingNonlinearSolve: AbstractBracketingAlgorithm, NonlinearVerbosity, NonlinearSolveBase, build_bracketing_solution -import SciMLBase: @SciMLMessage -struct ModAB2 <: AbstractBracketingAlgorithm -end - -function SciMLBase.__solve( - prob::IntervalNonlinearProblem, alg::ModAB2, args...; - maxiters=1000, abstol=nothing, verbose::NonlinearVerbosity=NonlinearVerbosity(), kwargs... -) - @assert !SciMLBase.isinplace(prob) "`ModAB` only supports out-of-place problems." - - f = Base.Fix2(prob.f, prob.p) - x1, x2 = minmax(promote(prob.tspan...)...) - y1, y2 = f(x1), f(x2) - - abstol = NonlinearSolveBase.get_tolerance( - x1, abstol, promote_type(eltype(x1), eltype(x2)) - ) - - if iszero(y1) - return build_exact_solution(prob, alg, x1, y1, ReturnCode.ExactSolutionLeft) - end - - if iszero(y2) - return build_exact_solution(prob, alg, x2, y2, ReturnCode.ExactSolutionRight) - end - - if sign(y1) == sign(y2) - @SciMLMessage( - "The interval is not an enclosing interval, opposite signs at the \ - boundaries are required.", - verbose, :non_enclosing_interval - ) - return build_bracketing_solution(prob, alg, x1, y1, x1, x2, ReturnCode.InitialFailure) - end - - bisecting = true - side = 0 # tracks the side that has moved at the previous iteration - ϵ = abstol - i = 1 - threshold = x2 - x1 # Threshold to fall back to bisection if AB fails to shrink the interval enough - C = 16 # safety factor for threshold corresponding to 4 iterations = 2^4 - while i < maxiters - local x3, y3 - if bisecting # Bisection method is used - x3 = (x1 + x2) / 2 - y3 = f(x3) # Function value at midpoint - ym = (y1 + y2) / 2 # Ordinate of chord at midpoint - # calculate k on each bisection step with account for local function properties and symmetry - r = 1 - abs(ym / (y2 - y1)) # Symmetry factor - k = r * r # Deviation factor - # Check if the function is close enough to linear - if abs(ym - y3) < k * (abs(ym) + abs(y3)) - bisecting = false - threshold = (x2 - x1) * C - end - else # Anderson-Bjork method is used - # x3 = clamp((x1 * y2 - y1 * x2) / (y2 - y1), x1, x2) - x3 = (x1 * y2 - y1 * x2) / (y2 - y1) - y3 = f(x3) - threshold /= 2 - end - if iszero(y3) - return build_exact_solution(prob, alg, x3, y3, ReturnCode.Success) - elseif (x2 - x1) < 2ϵ - return build_bracketing_solution(prob, alg, x3, y3, x1, x2, ReturnCode.Success) - end - x0 = x3 - if sign(y1) == sign(y3) - if side == 1 # Apply Anderson-Bjork correction on the right side - m = 1 - y3 / y1 - y2 *= m <= 0 ? inv(2 * one(y1)) : m - elseif !bisecting - side = 1 - end - x1, y1 = x3, y3 - else - if side == -1 # Apply Anderson-Bjork correction on the left side - m = 1 - y3 / y2 - y1 *= m <= 0 ? inv(2 * one(y1)) : m - elseif !bisecting - side = -1 - end - x2, y2 = x3, y3 - end - if nextfloat(x1) == x2 - return build_bracketing_solution(prob, alg, x2, f(x2), x1, x2, ReturnCode.FloatingPointLimit) - end - i += 1 - if x2 - x1 > threshold # If AB fails to shrink the interval enough - bisecting = true # reset to bisection - side = 0 - end - end - return build_bracketing_solution(prob, alg, x1, y1, x1, x2, ReturnCode.MaxIters) -end - """ Find either exact or floating point precision root of `f`. If the exact root cannot be represented, return closest floating point number depending on `rootfind` From 9716b651755e64683ff9716969214baf6084a054 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20Gonz=C3=A1lez?= Date: Tue, 3 Mar 2026 20:05:02 +0100 Subject: [PATCH 6/9] Reformat with Runic --- src/callbacks.jl | 160 +++++++++++++++++++++++------------------------ 1 file changed, 80 insertions(+), 80 deletions(-) diff --git a/src/callbacks.jl b/src/callbacks.jl index d04170344..6817fb9a6 100644 --- a/src/callbacks.jl +++ b/src/callbacks.jl @@ -9,18 +9,18 @@ function initialize!(cb::CallbackSet, u, t, integrator::DEIntegrator) cb.discrete_callbacks... ) end -initialize!(cb::CallbackSet{Tuple{},Tuple{}}, u, t, integrator::DEIntegrator) = false +initialize!(cb::CallbackSet{Tuple{}, Tuple{}}, u, t, integrator::DEIntegrator) = false function initialize!( - u, t, integrator::DEIntegrator, any_modified::Bool, - c::DECallback, cs::DECallback... -) + u, t, integrator::DEIntegrator, any_modified::Bool, + c::DECallback, cs::DECallback... + ) c.initialize(c, u, t, integrator) return initialize!(u, t, integrator, any_modified || integrator.u_modified, cs...) end function initialize!( - u, t, integrator::DEIntegrator, any_modified::Bool, - c::DECallback -) + u, t, integrator::DEIntegrator, any_modified::Bool, + c::DECallback + ) c.initialize(c, u, t, integrator) return any_modified || integrator.u_modified end @@ -33,18 +33,18 @@ Recursively apply `finalize!` and return whether any modified u function finalize!(cb::CallbackSet, u, t, integrator::DEIntegrator) return finalize!(u, t, integrator, false, cb.continuous_callbacks..., cb.discrete_callbacks...) end -finalize!(cb::CallbackSet{Tuple{},Tuple{}}, u, t, integrator::DEIntegrator) = false +finalize!(cb::CallbackSet{Tuple{}, Tuple{}}, u, t, integrator::DEIntegrator) = false function finalize!( - u, t, integrator::DEIntegrator, any_modified::Bool, - c::DECallback, cs::DECallback... -) + u, t, integrator::DEIntegrator, any_modified::Bool, + c::DECallback, cs::DECallback... + ) c.finalize(c, u, t, integrator) return finalize!(u, t, integrator, any_modified || integrator.u_modified, cs...) end function finalize!( - u, t, integrator::DEIntegrator, any_modified::Bool, - c::DECallback -) + u, t, integrator::DEIntegrator, any_modified::Bool, + c::DECallback + ) c.finalize(c, u, t, integrator) return any_modified || integrator.u_modified end @@ -104,13 +104,13 @@ function get_condition(integrator::DEIntegrator, callback, abst) if callback.idxs === nothing integrator(tmp, abst, Val{0}) else - integrator(tmp, abst, Val{0}, idxs=callback.idxs) + integrator(tmp, abst, Val{0}, idxs = callback.idxs) end else if callback.idxs === nothing tmp = integrator(abst, Val{0}) else - tmp = integrator(abst, Val{0}, idxs=callback.idxs) + tmp = integrator(abst, Val{0}, idxs = callback.idxs) end end # ismutable && !(callback.idxs isa Number) ? integrator(tmp,abst,Val{0},idxs=callback.idxs) : @@ -130,24 +130,24 @@ end # Use a generated function for type stability even when many callbacks are given @inline function find_first_continuous_callback( - integrator, - callbacks::Vararg{ - AbstractContinuousCallback, - N, - } -) where {N} + integrator, + callbacks::Vararg{ + AbstractContinuousCallback, + N, + } + ) where {N} return find_first_continuous_callback(integrator, tuple(callbacks...)) end @generated function find_first_continuous_callback( - integrator, - callbacks::NTuple{ - N, - AbstractContinuousCallback, - } -) where {N} + integrator, + callbacks::NTuple{ + N, + AbstractContinuousCallback, + } + ) where {N} ex = quote tmin, upcrossing, - event_occurred, event_idx, residual = find_callback_time( + event_occurred, event_idx, residual = find_callback_time( integrator, callbacks[1], 1 ) @@ -157,7 +157,7 @@ end ex = quote $ex tmin2, upcrossing2, - event_occurred2, event_idx2, residual2 = find_callback_time( + event_occurred2, event_idx2, residual2 = find_callback_time( integrator, callbacks[$i], $i @@ -183,9 +183,9 @@ end end @inline function find_callback_time( - integrator, callback::VectorContinuousCallback, - callback_idx -) + integrator, callback::VectorContinuousCallback, + callback_idx + ) if callback.interp_points != 0 addsteps!(integrator) end @@ -239,7 +239,7 @@ end min_event_idx = -1 for idx in 1:length(event_idx) if ArrayInterface.allowed_getindex(event_idx, idx) != 0 - function zero_func(abst, p=nothing) + function zero_func(abst, p = nothing) return ArrayInterface.allowed_getindex( get_condition( integrator, @@ -271,13 +271,13 @@ end end return callback_t, ArrayInterface.allowed_getindex(bottom_sign, min_event_idx), - event_occurred::Bool, min_event_idx::Int, residual + event_occurred::Bool, min_event_idx::Int, residual end @inline function find_callback_time( - integrator, callback::ContinuousCallback, - callback_idx -) + integrator, callback::ContinuousCallback, + callback_idx + ) if callback.interp_points != 0 addsteps!(integrator) end @@ -310,7 +310,7 @@ end residual = zero(bottom_condition) else # Find callback time - zero_func(abst, p=nothing) = get_condition(integrator, callback, abst) + zero_func(abst, p = nothing) = get_condition(integrator, callback, abst) callback_t = find_root(zero_func, (bottom_t, top_t), callback.rootfind) residual = zero_func(callback_t) end @@ -344,9 +344,9 @@ function check_event_occurence(integrator, callback, bottom_sign) check_event_occurence_upto(integrator, callback, bottom_sign, top_t) if callback.interp_points != 0 && !isdiscrete(integrator.alg) && - any(iszero, event_idx) + any(iszero, event_idx) # Use the interpolants for safety checking - ts = range(integrator.tprev, stop=integrator.t, length=callback.interp_points) + ts = range(integrator.tprev, stop = integrator.t, length = callback.interp_points) for i in 2:length(ts) top_t = ts[i] event_occurred, event_idx, top_sign = @@ -385,13 +385,13 @@ function check_event_occurence_upto(integrator, callback::VectorContinuousCallba end _shift(τ, i) = - if iszero(i) - τ - elseif i > 0 - _shift(nextfloat(τ), i - 1) - else - _shift(prevfloat(τ), i + 1) - end +if iszero(i) + τ +elseif i > 0 + _shift(nextfloat(τ), i - 1) +else + _shift(prevfloat(τ), i + 1) +end """ Find either exact or floating point precision root of `f`. @@ -405,7 +405,7 @@ Assumes that: function find_root(f, tup, rootfind::SciMLBase.RootfindOpt) sol = solve( IntervalNonlinearProblem{false}(f, tup), - ModAB2(), abstol=0.0, reltol=0.0 + ModAB2(), abstol = 0.0, reltol = 0.0 ) if is_inverted_root_pair(sol, f, tup) # "Inverted" root pair (#1290); direction of integration flips the bracket side @@ -454,20 +454,20 @@ in the interval between prev_sign and next_sign. Return `true` if any event occured. """ function findall_events!( - next_sign::Union{Array,SubArray}, affect!::F1, affect_neg!::F2, - prev_sign::Union{Array,SubArray} -) where {F1,F2} + next_sign::Union{Array, SubArray}, affect!::F1, affect_neg!::F2, + prev_sign::Union{Array, SubArray} + ) where {F1, F2} @inbounds for i in 1:length(prev_sign) next_sign[i] = ( (prev_sign[i] < 0 && affect! !== nothing) || - (prev_sign[i] > 0 && affect_neg! !== nothing) + (prev_sign[i] > 0 && affect_neg! !== nothing) ) && - prev_sign[i] * next_sign[i] <= 0 + prev_sign[i] * next_sign[i] <= 0 end return any(isone, next_sign) end -function findall_events!(next_sign, affect!::F1, affect_neg!::F2, prev_sign) where {F1,F2} +function findall_events!(next_sign, affect!::F1, affect_neg!::F2, prev_sign) where {F1, F2} hasaffect::Bool = affect! !== nothing hasaffectneg::Bool = affect_neg! !== nothing f = (n, p) -> ((p < 0 && hasaffect) || (p > 0 && hasaffectneg)) && p * n <= 0 @@ -478,18 +478,18 @@ end """ Return `true` if an event occured. """ -function is_event_occurence(prev_sign::Number, next_sign::Number, affect!::F1, affect_neg!::F2) where {F1,F2} +function is_event_occurence(prev_sign::Number, next_sign::Number, affect!::F1, affect_neg!::F2) where {F1, F2} return ( (prev_sign < 0 && affect! !== nothing) || - (prev_sign > 0 && affect_neg! !== nothing) + (prev_sign > 0 && affect_neg! !== nothing) ) && prev_sign * next_sign <= 0 end function apply_callback!( - integrator, - callback::Union{ContinuousCallback,VectorContinuousCallback}, - cb_time, prev_sign, event_idx -) + integrator, + callback::Union{ContinuousCallback, VectorContinuousCallback}, + cb_time, prev_sign, event_idx + ) if isadaptive(integrator) set_proposed_dt!( integrator, @@ -520,20 +520,20 @@ function apply_callback!( integrator.u_modified = false else callback isa VectorContinuousCallback ? - callback.affect!(integrator, event_idx) : callback.affect!(integrator) + callback.affect!(integrator, event_idx) : callback.affect!(integrator) end elseif prev_sign > 0 if callback.affect_neg! === nothing integrator.u_modified = false else callback isa VectorContinuousCallback ? - callback.affect_neg!(integrator, event_idx) : callback.affect_neg!(integrator) + callback.affect_neg!(integrator, event_idx) : callback.affect_neg!(integrator) end end if integrator.u_modified reeval_internals_due_to_modification!( - integrator, callback_initializealg=callback.initializealg + integrator, callback_initializealg = callback.initializealg ) @inbounds if callback.save_positions[2] @@ -567,7 +567,7 @@ end callback.affect!(integrator) if integrator.u_modified reeval_internals_due_to_modification!( - integrator, false, callback_initializealg=callback.initializealg + integrator, false, callback_initializealg = callback.initializealg ) end @inbounds if callback.save_positions[2] @@ -591,12 +591,12 @@ end end @inline function apply_discrete_callback!( - integrator, discrete_modified::Bool, - saved_in_cb::Bool, callback::DiscreteCallback, - args... -) + integrator, discrete_modified::Bool, + saved_in_cb::Bool, callback::DiscreteCallback, + args... + ) bool, - saved_in_cb2 = apply_discrete_callback!( + saved_in_cb2 = apply_discrete_callback!( integrator, apply_discrete_callback!( integrator, @@ -608,9 +608,9 @@ end end @inline function apply_discrete_callback!( - integrator, discrete_modified::Bool, - saved_in_cb::Bool, callback::DiscreteCallback -) + integrator, discrete_modified::Bool, + saved_in_cb::Bool, callback::DiscreteCallback + ) bool, saved_in_cb2 = apply_discrete_callback!(integrator, callback) return discrete_modified || bool, saved_in_cb || saved_in_cb2 end @@ -646,7 +646,7 @@ end """ $(TYPEDEF) """ -mutable struct CallbackCache{conditionType,signType} +mutable struct CallbackCache{conditionType, signType} tmp_condition::conditionType next_condition::conditionType next_sign::signType @@ -654,9 +654,9 @@ mutable struct CallbackCache{conditionType,signType} end function CallbackCache( - u, max_len, ::Type{conditionType}, - ::Type{signType} -) where {conditionType,signType} + u, max_len, ::Type{conditionType}, + ::Type{signType} + ) where {conditionType, signType} tmp_condition = similar(u, conditionType, max_len) next_condition = similar(u, conditionType, max_len) next_sign = similar(u, signType, max_len) @@ -665,9 +665,9 @@ function CallbackCache( end function CallbackCache( - max_len, ::Type{conditionType}, - ::Type{signType} -) where {conditionType,signType} + max_len, ::Type{conditionType}, + ::Type{signType} + ) where {conditionType, signType} tmp_condition = zeros(conditionType, max_len) next_condition = zeros(conditionType, max_len) next_sign = zeros(signType, max_len) From e088b3a3df412c352ddac36675051f6d5aef9ef4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20Gonz=C3=A1lez?= Date: Tue, 3 Mar 2026 20:06:12 +0100 Subject: [PATCH 7/9] Remove WIP artifacts --- src/callbacks.jl | 11 +------ test/gen.jl | 77 ------------------------------------------------ test/lv.jl | 69 ------------------------------------------- 3 files changed, 1 insertion(+), 156 deletions(-) delete mode 100644 test/gen.jl delete mode 100644 test/lv.jl diff --git a/src/callbacks.jl b/src/callbacks.jl index 6817fb9a6..47776873e 100644 --- a/src/callbacks.jl +++ b/src/callbacks.jl @@ -384,15 +384,6 @@ function check_event_occurence_upto(integrator, callback::VectorContinuousCallba return event_occurred, event_idx, top_sign end -_shift(τ, i) = -if iszero(i) - τ -elseif i > 0 - _shift(nextfloat(τ), i - 1) -else - _shift(prevfloat(τ), i + 1) -end - """ Find either exact or floating point precision root of `f`. If the exact root cannot be represented, return closest floating point number depending on `rootfind` @@ -405,7 +396,7 @@ Assumes that: function find_root(f, tup, rootfind::SciMLBase.RootfindOpt) sol = solve( IntervalNonlinearProblem{false}(f, tup), - ModAB2(), abstol = 0.0, reltol = 0.0 + ModAB(), abstol = 0.0, reltol = 0.0 ) if is_inverted_root_pair(sol, f, tup) # "Inverted" root pair (#1290); direction of integration flips the bracket side diff --git a/test/gen.jl b/test/gen.jl deleted file mode 100644 index 8496a171a..000000000 --- a/test/gen.jl +++ /dev/null @@ -1,77 +0,0 @@ -using OrdinaryDiffEqTsit5, OrdinaryDiffEqCore, SciMLBase, LinearAlgebra - -# Lotka-Volterra equations -function lotka_volterra!(du, u, p, t) - α, β, δ, γ = 1.5, 1.0, 3.0, 1.0 - x, y = u - du[1] = α * x - β * x * y - du[2] = δ * x * y - γ * y - return nothing -end - -function run_once(; seed=nothing) - if seed !== nothing - Random.seed!(seed) - end - - # Random coefficients for two linear conditions: c' * u - coeffs1 = randn(2) - - u0 = [1.0, 1.0] - tspan = (0.0, 20.0) - - # Record initial signs - initial_signs = [sign(dot(coeffs1, u0))] - - # VCC condition: two linear functions of state - function vcc_condition!(out, u, t, integrator) - out[1] = dot(coeffs1, u) - return nothing - end - - function vcc_affect!(integrator, event_index) - u = integrator.u - vals = [dot(coeffs1, u)] - v = vals[event_index] - if !iszero(v) && sign(v) == initial_signs[event_index] - @show coeffs1 u event_index v initial_signs - error("VCC fired but value has same sign as initial — RightRootFind bug?") - else - # termine simulation - terminate!(integrator) - end - return nothing - end - - cb = VectorContinuousCallback( - vcc_condition!, - vcc_affect!, - 1; - rootfind=SciMLBase.RightRootFind, - ) - - prob = ODEProblem(lotka_volterra!, u0, tspan) - sol = solve(prob, Tsit5(); callback=cb, abstol=1e-10, reltol=1e-10) - return sol -end - -# Main loop — fish for the bug -i = 0 -while true - global i += 1 - if i % 1000 == 0 - println("Iteration $i ...") - end - try - run_once() - catch e - if e isa ErrorException && contains(e.msg, "RightRootFind") - println("\n*** Bug found at iteration $i ***") - rethrow() - else - rethrow() - end - end -end - -println("No bug found after 100_000 iterations.") diff --git a/test/lv.jl b/test/lv.jl deleted file mode 100644 index 6e20591d4..000000000 --- a/test/lv.jl +++ /dev/null @@ -1,69 +0,0 @@ -using OrdinaryDiffEqTsit5, OrdinaryDiffEqCore, SciMLBase, LinearAlgebra - -# Lotka-Volterra equations -function lotka_volterra!(du, u, p, t) - α, β, δ, γ = 1.5, 1.0, 3.0, 1.0 - x, y = u - du[1] = α * x - β * x * y - du[2] = δ * x * y - γ * y - return nothing -end - -# coeffs1 = [0.6825223495861318, -0.4295322984152052] -# coeffs2 = [1.7358772252665537, -1.0070061675696311] - -coeffs1 = [2.922772251297381, -2.8028553839288595] -u0 = [1.0, 1.0] -tspan = (0.0, 20.0) -tspan = (0.0, 0.03) - -# Record initial signs -initial_conditions = [dot(coeffs1, u0)] -initial_signs = sign.(initial_conditions) - -# VCC condition: two linear functions of state -function vcc_condition!(out, u, t, integrator) - out[1] = dot(coeffs1, u) - return nothing -end - -function vcc_affect!(integrator, event_index) - @show event_index, integrator.t - u = integrator.u - @show integrator.t - if event_index == 1 - @show integrator.u - println("Condition value at crossing: ", [dot(coeffs1, u)]) - terminate!(integrator) - end - return nothing -end - -cb = VectorContinuousCallback( - vcc_condition!, - vcc_affect!, - 1; - rootfind=SciMLBase.RightRootFind, -) - -println("Initial conditions: ", initial_conditions) - -prob = ODEProblem(lotka_volterra!, u0, tspan) -sol = solve(prob, Tsit5(); callback=cb, abstol=1e-10, reltol=1e-10, dense=true) -sol_u = solve(prob, Tsit5(); abstol=1e-10, reltol=1e-10, dense=true) -# sol = solve(prob, Tsit5(); abstol=1e-10, reltol=1e-10) -sol -nothing - -shift(τ, i) = - if iszero(i) - τ - elseif i > 0 - shift(nextfloat(τ), i - 1) - else - shift(prevfloat(τ), i + 1) - end - -# 0.23620973794890948 - -cond2(u) = dot(coeffs2, u) \ No newline at end of file From 6cf22832eb006c8b74aacc8d4d4f16c98925c343 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20Gonz=C3=A1lez?= Date: Thu, 5 Mar 2026 11:50:20 +0100 Subject: [PATCH 8/9] Add test --- test/callbacks.jl | 35 ++++++++++++++++++++++++++++++++++- 1 file changed, 34 insertions(+), 1 deletion(-) diff --git a/test/callbacks.jl b/test/callbacks.jl index c409c7362..d15ba9360 100644 --- a/test/callbacks.jl +++ b/test/callbacks.jl @@ -1,4 +1,5 @@ -using DiffEqBase, Test +using DiffEqBase, SciMLBase, Test +using BracketingNonlinearSolve: ModAB condition = function (u, t, integrator) # Event when event_f(u,t,k) == 0 return t - 2.95 @@ -134,3 +135,35 @@ test_find_first_callback(callbacks, find_first_integrator); @test irrational_f(after) < 0.0 @test nextfloat(after) == before end + +# https://github.com/SciML/DiffEqBase.jl/issues/1290 +@testset "Inverted root pair detection and correction" begin + # Discovered via random search + coeffs = [ + -0.7270388932299022, -0.6929210470992349, -0.7343652899957108, + 0.8310017775620168, 0.6030921975763498, 0.46703506019208685, + -2.3581581735824186, 2.0556608750360628, -0.8183123724103458, + -2.5113469878793513, 0.10406374497692948, 0.0701494558467343, + ] + a1, b1, c1, a2, b2, c2, a3, b3, c3, a4, b4, c4 = coeffs + f(t) = + (a1*t^2 + b1*t + c1) * (a3*t^2 + b3*t + c3) + + (a2*t^2 + b2*t + c2) * (a4*t^2 + b4*t + c4) + f(t, _) = f(t) + tspan = (0.4294759977027207, 0.5371755582641773) + + # Verify that ModAB directly produces an inverted root pair for this input + raw_sol = solve( + IntervalNonlinearProblem{false}(f, tspan), + ModAB(), abstol = 0.0, reltol = 0.0 + ) + @test DiffEqBase.is_inverted_root_pair(raw_sol, f, tspan) + + # find_root must detect and correct the inversion: the returned roots must bracket + # a root with matching condition signs + left = DiffEqBase.find_root(f, tspan, SciMLBase.LeftRootFind) + right = DiffEqBase.find_root(f, tspan, SciMLBase.RightRootFind) + @test f(left) > 0.0 + @test f(right) < 0.0 + @test nextfloat(left) == right +end From cfc7d254550feb6070eee4b02ce901afb32b1dbf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20Gonz=C3=A1lez?= Date: Thu, 5 Mar 2026 11:53:35 +0100 Subject: [PATCH 9/9] Formatting --- test/callbacks.jl | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/test/callbacks.jl b/test/callbacks.jl index d15ba9360..3b7aefecc 100644 --- a/test/callbacks.jl +++ b/test/callbacks.jl @@ -140,15 +140,15 @@ end @testset "Inverted root pair detection and correction" begin # Discovered via random search coeffs = [ - -0.7270388932299022, -0.6929210470992349, -0.7343652899957108, - 0.8310017775620168, 0.6030921975763498, 0.46703506019208685, - -2.3581581735824186, 2.0556608750360628, -0.8183123724103458, - -2.5113469878793513, 0.10406374497692948, 0.0701494558467343, + -0.7270388932299022, -0.6929210470992349, -0.7343652899957108, + 0.8310017775620168, 0.6030921975763498, 0.46703506019208685, + -2.3581581735824186, 2.0556608750360628, -0.8183123724103458, + -2.5113469878793513, 0.10406374497692948, 0.0701494558467343, ] a1, b1, c1, a2, b2, c2, a3, b3, c3, a4, b4, c4 = coeffs f(t) = - (a1*t^2 + b1*t + c1) * (a3*t^2 + b3*t + c3) + - (a2*t^2 + b2*t + c2) * (a4*t^2 + b4*t + c4) + (a1 * t^2 + b1 * t + c1) * (a3 * t^2 + b3 * t + c3) + + (a2 * t^2 + b2 * t + c2) * (a4 * t^2 + b4 * t + c4) f(t, _) = f(t) tspan = (0.4294759977027207, 0.5371755582641773) @@ -161,9 +161,9 @@ end # find_root must detect and correct the inversion: the returned roots must bracket # a root with matching condition signs - left = DiffEqBase.find_root(f, tspan, SciMLBase.LeftRootFind) + left = DiffEqBase.find_root(f, tspan, SciMLBase.LeftRootFind) right = DiffEqBase.find_root(f, tspan, SciMLBase.RightRootFind) - @test f(left) > 0.0 + @test f(left) > 0.0 @test f(right) < 0.0 @test nextfloat(left) == right end