diff --git a/Project.toml b/Project.toml index a868f85a77c..2f69aada669 100644 --- a/Project.toml +++ b/Project.toml @@ -175,6 +175,7 @@ SimpleNonlinearSolve = "2.7" Static = "1.2" StaticArrayInterface = "1.8" StaticArrays = "1.9.14" +StableRNGs = "1" StructArrays = "0.6, 0.7" Symbolics = "6, 7" TruncatedStacktraces = "1.4" @@ -197,6 +198,7 @@ Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" RecursiveFactorization = "f2c3362d-daeb-58d1-803e-2bc74f2840b4" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" +StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" SparseConnectivityTracer = "9f842d2f-2579-4b1d-911e-f412cf18a3f5" SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" @@ -206,4 +208,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" [targets] -test = ["ComponentArrays", "AlgebraicMultigrid", "IncompleteLU", "DiffEqCallbacks", "DifferentiationInterface", "DiffEqDevTools", "ExplicitImports", "ODEProblemLibrary", "ElasticArrays", "JLArrays", "Random", "SafeTestsets", "StructArrays", "Test", "Unitful", "Pkg", "NLsolve", "RecursiveFactorization", "SparseConnectivityTracer", "SparseMatrixColorings", "Statistics"] +test = ["ComponentArrays", "AlgebraicMultigrid", "IncompleteLU", "DiffEqCallbacks", "DifferentiationInterface", "DiffEqDevTools", "ExplicitImports", "ODEProblemLibrary", "ElasticArrays", "JLArrays", "Random", "SafeTestsets", "StableRNGs", "StructArrays", "Test", "Unitful", "Pkg", "NLsolve", "RecursiveFactorization", "SparseConnectivityTracer", "SparseMatrixColorings", "Statistics"] diff --git a/lib/OrdinaryDiffEqCore/Project.toml b/lib/OrdinaryDiffEqCore/Project.toml index 3ca2a7adf3c..7512219c0fa 100644 --- a/lib/OrdinaryDiffEqCore/Project.toml +++ b/lib/OrdinaryDiffEqCore/Project.toml @@ -4,6 +4,7 @@ authors = ["ParamThakkar123 "] version = "3.6.0" [deps] +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961" SciMLStructures = "53ae85a6-f571-4167-b2af-e1d143709226" @@ -43,7 +44,6 @@ Reexport = "189a3867-3050-52da-a836-e630ba90ab69" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" -Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" DiffEqDevTools = "f3b72e0c-5b89-59e1-b016-84e28bfd966d" Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" AllocCheck = "9b6a8646-10ed-4001-bbdc-1d2f46dfbb1a" @@ -101,7 +101,7 @@ Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" [targets] -test = ["DiffEqDevTools", "Random", "SafeTestsets", "SparseArrays", "Test", "Pkg"] +test = ["DiffEqDevTools", "SafeTestsets", "SparseArrays", "Test", "Pkg"] [extensions] OrdinaryDiffEqCoreMooncakeExt = "Mooncake" diff --git a/lib/OrdinaryDiffEqCore/src/OrdinaryDiffEqCore.jl b/lib/OrdinaryDiffEqCore/src/OrdinaryDiffEqCore.jl index 259f75aebc1..fbfd8a743b7 100644 --- a/lib/OrdinaryDiffEqCore/src/OrdinaryDiffEqCore.jl +++ b/lib/OrdinaryDiffEqCore/src/OrdinaryDiffEqCore.jl @@ -42,6 +42,8 @@ import SciMLOperators: AbstractSciMLOperator, AbstractSciMLScalarOperator, using DiffEqBase: DEIntegrator +import Random + import RecursiveArrayTools: chain, recursivecopy!, recursivecopy, recursive_bottom_eltype, recursive_unitless_bottom_eltype, recursive_unitless_eltype, copyat_or_push!, DiffEqArray, recursivefill! import RecursiveArrayTools diff --git a/lib/OrdinaryDiffEqCore/src/integrators/integrator_interface.jl b/lib/OrdinaryDiffEqCore/src/integrators/integrator_interface.jl index 44fad59d06d..e60693b9eea 100644 --- a/lib/OrdinaryDiffEqCore/src/integrators/integrator_interface.jl +++ b/lib/OrdinaryDiffEqCore/src/integrators/integrator_interface.jl @@ -360,6 +360,23 @@ end const EMPTY_ARRAY_OF_PAIRS = Pair[] +SciMLBase.has_rng(::ODEIntegrator) = true +SciMLBase.get_rng(integrator::ODEIntegrator) = integrator.rng +function SciMLBase.set_rng!(integrator::ODEIntegrator, rng) + R = typeof(integrator.rng) + if !isa(rng, R) + throw( + ArgumentError( + "Cannot set RNG of type $(typeof(rng)) on an integrator " * + "whose RNG type parameter is $R. " * + "Construct a new integrator via `init(prob, alg; rng = your_rng)` instead." + ) + ) + end + integrator.rng = rng + return nothing +end + SciMLBase.has_reinit(integrator::ODEIntegrator) = true function SciMLBase.reinit!( integrator::ODEIntegrator, u0 = integrator.sol.prob.u0; @@ -374,7 +391,8 @@ function SciMLBase.reinit!( reinit_dae = true, reinit_callbacks = true, initialize_save = true, reinit_cache = true, - reinit_retcode = true + reinit_retcode = true, + rng = nothing ) if reinit_dae && SciMLBase.has_initializeprob(integrator.sol.prob.f) # This is `remake` infrastructure. `reinit!` is somewhat like `remake` for @@ -461,6 +479,10 @@ function SciMLBase.reinit!( integrator.erracc = typeof(integrator.erracc)(1) integrator.dtacc = typeof(integrator.dtacc)(1) + if rng !== nothing + SciMLBase.set_rng!(integrator, rng) + end + if reset_dt auto_dt_reset!(integrator) end diff --git a/lib/OrdinaryDiffEqCore/src/integrators/type.jl b/lib/OrdinaryDiffEqCore/src/integrators/type.jl index 827193e8b62..d97636952bd 100644 --- a/lib/OrdinaryDiffEqCore/src/integrators/type.jl +++ b/lib/OrdinaryDiffEqCore/src/integrators/type.jl @@ -89,7 +89,7 @@ mutable struct ODEIntegrator{ algType <: Union{OrdinaryDiffEqAlgorithm, DAEAlgorithm}, IIP, uType, duType, tType, pType, eigenType, EEstT, QT, tdirType, ksEltype, SolType, F, CacheType, O, FSALType, EventErrorType, - CallbackCacheType, IA, DV, CC, + CallbackCacheType, IA, DV, CC, RNGType, } <: SciMLBase.AbstractODEIntegrator{algType, IIP, uType, tType} sol::SolType @@ -144,4 +144,5 @@ mutable struct ODEIntegrator{ differential_vars::DV fsalfirst::FSALType fsallast::FSALType + rng::RNGType end diff --git a/lib/OrdinaryDiffEqCore/src/solve.jl b/lib/OrdinaryDiffEqCore/src/solve.jl index 09e72d5f07b..976402705f8 100644 --- a/lib/OrdinaryDiffEqCore/src/solve.jl +++ b/lib/OrdinaryDiffEqCore/src/solve.jl @@ -79,6 +79,7 @@ function SciMLBase.__init( initialize_integrator = true, alias = ODEAliasSpecifier(), initializealg = DefaultInit(), + rng = nothing, kwargs... ) if prob isa SciMLBase.AbstractDAEProblem && alg isa OrdinaryDiffEqAlgorithm @@ -650,6 +651,8 @@ function SciMLBase.__init( saveiter_dense = 0 fsalfirst, fsallast = get_fsalfirstlast(cache, rate_prototype) + _rng = rng === nothing ? Random.default_rng() : rng + integrator = ODEIntegrator{ typeof(_alg), isinplace(prob), uType, typeof(du), tType, typeof(p), @@ -659,7 +662,7 @@ function SciMLBase.__init( typeof(opts), typeof(fsalfirst), typeof(last_event_error), typeof(callback_cache), typeof(initializealg), typeof(differential_vars), - typeof(controller_cache), + typeof(controller_cache), typeof(_rng), }( sol, u, du, k, t, tType(_dt), f, p, uprev, uprev2, duprev, tprev, @@ -682,7 +685,7 @@ function SciMLBase.__init( isout, reeval_fsal, u_modified, reinitiailize, isdae, opts, stats, initializealg, differential_vars, - fsalfirst, fsallast + fsalfirst, fsallast, _rng ) if initialize_integrator diff --git a/test/integrators/integrator_rng_tests.jl b/test/integrators/integrator_rng_tests.jl new file mode 100644 index 00000000000..2d482327a7b --- /dev/null +++ b/test/integrators/integrator_rng_tests.jl @@ -0,0 +1,261 @@ +using OrdinaryDiffEq, Test, Random, StableRNGs + +# Simple ODE for testing: du/dt = 2u +f_oop(u, p, t) = 2u +f_iip(du, u, p, t) = (du .= 2 .* u) + +@testset "Integrator RNG Interface" begin + @testset "Default RNG (rng not provided)" begin + prob = ODEProblem(f_oop, 0.5, (0.0, 1.0)) + integrator = init(prob, Tsit5()) + + @test SciMLBase.has_rng(integrator) + @test SciMLBase.get_rng(integrator) === Random.default_rng() + end + + @testset "Custom RNG via init (out-of-place)" begin + prob = ODEProblem(f_oop, 0.5, (0.0, 1.0)) + rng = Random.Xoshiro(42) + integrator = init(prob, Tsit5(); rng) + + @test SciMLBase.has_rng(integrator) + @test SciMLBase.get_rng(integrator) === rng + end + + @testset "Custom RNG via init (in-place)" begin + prob = ODEProblem(f_iip, [0.5], (0.0, 1.0)) + rng = Random.Xoshiro(123) + integrator = init(prob, Tsit5(); rng) + + @test SciMLBase.has_rng(integrator) + @test SciMLBase.get_rng(integrator) === rng + end + + @testset "Custom RNG via solve propagates to integrator" begin + rng_from_callback = Ref{Any}(nothing) + cb = DiscreteCallback( + (u, t, integrator) -> true, + integrator -> begin + rng_from_callback[] = SciMLBase.get_rng(integrator) + return nothing + end + ) + + prob = ODEProblem(f_oop, 0.5, (0.0, 1.0)) + rng = Random.Xoshiro(99) + sol = solve(prob, Tsit5(); rng, callback = cb) + + @test sol.retcode == ReturnCode.Success + @test rng_from_callback[] === rng + end + + @testset "set_rng! replaces the RNG" begin + prob = ODEProblem(f_oop, 0.5, (0.0, 1.0)) + rng1 = Random.Xoshiro(1) + integrator = init(prob, Tsit5(); rng = rng1) + + @test SciMLBase.get_rng(integrator) === rng1 + + rng2 = Random.Xoshiro(2) + SciMLBase.set_rng!(integrator, rng2) + @test SciMLBase.get_rng(integrator) === rng2 + end + + @testset "reinit! with rng kwarg" begin + prob = ODEProblem(f_oop, 0.5, (0.0, 1.0)) + rng1 = Random.Xoshiro(10) + integrator = init(prob, Tsit5(); rng = rng1) + @test SciMLBase.get_rng(integrator) === rng1 + + # reinit! without rng should keep existing RNG + reinit!(integrator) + @test SciMLBase.get_rng(integrator) === rng1 + + # reinit! with rng should replace it + rng2 = Random.Xoshiro(20) + reinit!(integrator; rng = rng2) + @test SciMLBase.get_rng(integrator) === rng2 + end + + @testset "reinit! with rng sets RNG before callback initialization" begin + rng_seen_in_init = Ref{Any}(nothing) + cb = DiscreteCallback( + (u, t, integrator) -> true, + integrator -> nothing; + initialize = (cb, u, t, integrator) -> begin + rng_seen_in_init[] = SciMLBase.get_rng(integrator) + return nothing + end + ) + + prob = ODEProblem(f_oop, 0.5, (0.0, 1.0)) + rng1 = Random.Xoshiro(10) + integrator = init(prob, Tsit5(); rng = rng1, callback = cb) + + rng2 = Random.Xoshiro(20) + reinit!(integrator; rng = rng2, reinit_callbacks = true) + + # The callback's initialize hook should see rng2, not rng1 + @test rng_seen_in_init[] === rng2 + end + + @testset "RNG preserved across solve! cycle" begin + prob = ODEProblem(f_oop, 0.5, (0.0, 1.0)) + rng = Random.Xoshiro(42) + integrator = init(prob, Tsit5(); rng) + solve!(integrator) + + @test SciMLBase.get_rng(integrator) === rng + @test integrator.sol.retcode == ReturnCode.Success + end + + @testset "RNG type parameter is concrete" begin + prob = ODEProblem(f_oop, 0.5, (0.0, 1.0)) + rng = Random.Xoshiro(42) + integrator = init(prob, Tsit5(); rng) + + # The RNG type should be a concrete type parameter, not Any + @test typeof(integrator).parameters[end] === typeof(rng) + end + + @testset "Callback can access RNG via get_rng" begin + rng_from_callback = Ref{Any}(nothing) + cb = DiscreteCallback( + (u, t, integrator) -> true, + integrator -> begin + rng_from_callback[] = SciMLBase.get_rng(integrator) + return nothing + end + ) + + prob = ODEProblem(f_oop, 0.5, (0.0, 1.0)) + rng = Random.Xoshiro(42) + sol = solve(prob, Tsit5(); rng, callback = cb) + + @test rng_from_callback[] === rng + end + + @testset "set_rng! with incompatible type throws ArgumentError" begin + prob = ODEProblem(f_oop, 0.5, (0.0, 1.0)) + rng = Random.Xoshiro(42) + integrator = init(prob, Tsit5(); rng) + + @test_throws ArgumentError SciMLBase.set_rng!(integrator, Random.MersenneTwister(1)) + end + + @testset "Different solvers support rng kwarg" begin + prob_oop = ODEProblem(f_oop, 0.5, (0.0, 1.0)) + prob_iip = ODEProblem(f_iip, [0.5], (0.0, 1.0)) + rng = Random.Xoshiro(42) + + for (alg, prob) in [ + (Tsit5(), prob_oop), + (Vern7(), prob_oop), + (RK4(), prob_iip), + (Rosenbrock23(), prob_iip), + ] + integrator = init(prob, alg; rng) + @test SciMLBase.has_rng(integrator) + @test SciMLBase.get_rng(integrator) === rng + solve!(integrator) + @test integrator.sol.retcode == ReturnCode.Success + end + end + + @testset "Callback rand draws are reproducible with StableRNG" begin + # A callback that draws from the integrator RNG at every step + function make_collecting_callback() + draws = Float64[] + cb = DiscreteCallback( + (u, t, integrator) -> true, + integrator -> begin + push!(draws, rand(SciMLBase.get_rng(integrator))) + return nothing + end + ) + return cb, draws + end + + prob = ODEProblem(f_oop, 0.5, (0.0, 1.0)) + + cb1, draws1 = make_collecting_callback() + solve(prob, Tsit5(); rng = StableRNG(42), callback = cb1) + + cb2, draws2 = make_collecting_callback() + solve(prob, Tsit5(); rng = StableRNG(42), callback = cb2) + + @test !isempty(draws1) + @test draws1 == draws2 + end + + @testset "Callback rand draws differ with different StableRNG seeds" begin + function make_collecting_callback() + draws = Float64[] + cb = DiscreteCallback( + (u, t, integrator) -> true, + integrator -> begin + push!(draws, rand(SciMLBase.get_rng(integrator))) + return nothing + end + ) + return cb, draws + end + + prob = ODEProblem(f_oop, 0.5, (0.0, 1.0)) + + cb1, draws1 = make_collecting_callback() + solve(prob, Tsit5(); rng = StableRNG(42), callback = cb1) + + cb2, draws2 = make_collecting_callback() + solve(prob, Tsit5(); rng = StableRNG(99), callback = cb2) + + @test !isempty(draws1) + @test draws1 != draws2 + end + + @testset "reinit! with new StableRNG resets rand sequence" begin + draws = Float64[] + cb = DiscreteCallback( + (u, t, integrator) -> true, + integrator -> begin + push!(draws, rand(SciMLBase.get_rng(integrator))) + return nothing + end + ) + + prob = ODEProblem(f_oop, 0.5, (0.0, 1.0)) + integrator = init(prob, Tsit5(); rng = StableRNG(42), callback = cb) + solve!(integrator) + draws_run1 = copy(draws) + + empty!(draws) + reinit!(integrator; rng = StableRNG(42)) + solve!(integrator) + draws_run2 = draws + + @test !isempty(draws_run1) + @test draws_run1 == draws_run2 + end + + @testset "StableRNG type parameter is concrete" begin + prob = ODEProblem(f_oop, 0.5, (0.0, 1.0)) + rng = StableRNG(42) + integrator = init(prob, Tsit5(); rng) + + @test typeof(integrator).parameters[end] === StableRNG + @test SciMLBase.get_rng(integrator) === rng + end + + @testset "set_rng! works with same StableRNG type" begin + prob = ODEProblem(f_oop, 0.5, (0.0, 1.0)) + rng1 = StableRNG(1) + integrator = init(prob, Tsit5(); rng = rng1) + + rng2 = StableRNG(2) + SciMLBase.set_rng!(integrator, rng2) + @test SciMLBase.get_rng(integrator) === rng2 + + # Cross-type should fail + @test_throws ArgumentError SciMLBase.set_rng!(integrator, Random.Xoshiro(1)) + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 3ddd477f5b5..6f810d27715 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -137,6 +137,7 @@ end if !is_APPVEYOR && (GROUP == "All" || GROUP == "Integrators_II" || GROUP == "Integrators") + @time @safetestset "Integrator RNG Tests" include("integrators/integrator_rng_tests.jl") @time @safetestset "Reverse Directioned Event Tests" include("integrators/rev_events_tests.jl") @time @safetestset "Differentiation Direction Tests" include("integrators/diffdir_tests.jl") @time @safetestset "Resize Tests" include("integrators/resize_tests.jl")