-
Notifications
You must be signed in to change notification settings - Fork 0
Fixed-Step Simulation Hot-Path Improvements #21
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -41,44 +41,85 @@ function propagate_variable(x::T, dt, x_dot::T) where {T} | |
| return (x + dt * x_dot)::T # Just to be clear, this shouldn't change the type. | ||
| end | ||
|
|
||
| function propagate_set(x::T1, dt, x_dot::T2) where {T1, T2} | ||
| @generated function propagate_set(x::T1, dt, x_dot::T2) where {T1, T2} | ||
| names = fieldnames(T1) | ||
| values = [ | ||
| hasfield(T2, name) ? | ||
| :(propagate_variable(getfield(x, $(QuoteNode(name))), dt, getfield(x_dot, $(QuoteNode(name))))) : | ||
| :(getfield(x, $(QuoteNode(name)))) | ||
| for name in names | ||
| ] | ||
| return :(NamedTuple{$names}(($(values...),))) | ||
| end | ||
|
|
||
| is_empty_rates_output(rates_output::RatesOutput) = | ||
| isempty(fieldnames(typeof(rates_output.rates))) && | ||
| isempty(fieldnames(typeof(rates_output.models))) | ||
|
|
||
| @generated function propagate_models( | ||
| submodels::NamedTuple{names}, dt, rates_output::TRO, | ||
| ) where {names, TRO} | ||
| values = [ | ||
| hasfield(TRO, name) ? | ||
| :(propagate(getfield(submodels, $(QuoteNode(name))), dt, getfield(rates_output, $(QuoteNode(name))))) : | ||
| :(getfield(submodels, $(QuoteNode(name)))) | ||
| for name in names | ||
| ] | ||
| return :(NamedTuple{$names}(($(values...),))) | ||
| end | ||
|
|
||
| function propagate(msd::ModelStateDescription, dt, rates_output::RatesOutput) | ||
| is_empty_rates_output(rates_output) && return msd | ||
| return copy_model_state_description_except( | ||
| msd; | ||
| continuous_states = propagate_set(msd.continuous_states, dt, rates_output.rates), | ||
| models = propagate_models(msd.models, dt, rates_output.models), | ||
| ) | ||
| end | ||
|
|
||
| function propagate_variable_rk4(x::T, dt, k1::T, k2::T, k3::T, k4::T) where {T} | ||
| return (x + dt/6 * (k1 + 2*k2 + 2*k3 + k4))::T | ||
| end | ||
|
|
||
| function propagate_set_rk4(x::T1, dt, r1, r2, r3, r4) where {T1} | ||
| return NamedTuple{fieldnames(T1)}( | ||
| map(fieldnames(T1)) do f | ||
| if hasfield(typeof(x_dot), f) | ||
| propagate_variable(x[f], dt, x_dot[f]) | ||
| if hasfield(typeof(r1), f) | ||
| propagate_variable_rk4(x[f], dt, r1[f], r2[f], r3[f], r4[f]) | ||
| else | ||
| x[f] | ||
| end | ||
| end | ||
| ) | ||
| end | ||
|
|
||
| function propagate_models(submodels::NamedTuple, dt, rates_output::NamedTuple) | ||
|
|
||
| # A user's RatesOutput's model entry could contain the models in any order. Here, we | ||
| # build a named tuple that matches the order of the original set of submodels. Plus, if | ||
| # an entry is missing, we fill it in with a blank RatesOutput(). This lets us simply | ||
| # `map` below. | ||
| complete_rates_output = NamedTuple{fieldnames(typeof(submodels))}( | ||
| map(fieldnames(typeof(submodels))) do f | ||
| if hasfield(typeof(rates_output), f) | ||
| rates_output[f] | ||
| else | ||
| RatesOutput() | ||
| end | ||
| end | ||
| function propagate_models_rk4(submodels::NamedTuple, dt, m1::NamedTuple, m2::NamedTuple, m3::NamedTuple, m4::NamedTuple) | ||
| names = fieldnames(typeof(submodels)) | ||
| complete_m1 = NamedTuple{names}(map(names) do f | ||
| hasfield(typeof(m1), f) ? getfield(m1, f) : RatesOutput() | ||
| end) | ||
| complete_m2 = NamedTuple{names}(map(names) do f | ||
| hasfield(typeof(m2), f) ? getfield(m2, f) : RatesOutput() | ||
| end) | ||
| complete_m3 = NamedTuple{names}(map(names) do f | ||
| hasfield(typeof(m3), f) ? getfield(m3, f) : RatesOutput() | ||
| end) | ||
| complete_m4 = NamedTuple{names}(map(names) do f | ||
| hasfield(typeof(m4), f) ? getfield(m4, f) : RatesOutput() | ||
| end) | ||
| return map( | ||
| (sm, ro1, ro2, ro3, ro4) -> propagate_rk4(sm, dt, ro1, ro2, ro3, ro4), | ||
| submodels, complete_m1, complete_m2, complete_m3, complete_m4, | ||
|
Comment on lines
+110
to
+112
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm surprised about this part! The multi-input |
||
| ) | ||
|
|
||
| # Now this is a simple map and doesn't allocate. | ||
| return map((sm, ro) -> propagate(sm, dt, ro), submodels, complete_rates_output) | ||
|
|
||
| end | ||
|
|
||
| function propagate(msd::ModelStateDescription, dt, rates_output::RatesOutput) | ||
| function propagate_rk4(msd::ModelStateDescription, dt, k1::RatesOutput, k2::RatesOutput, k3::RatesOutput, k4::RatesOutput) | ||
| is_empty_rates_output(k1) && is_empty_rates_output(k2) && | ||
| is_empty_rates_output(k3) && is_empty_rates_output(k4) && return msd | ||
|
Comment on lines
+117
to
+118
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If any of these is empty when the others aren't, that would be an error, so it's sufficient to check only one. |
||
| return copy_model_state_description_except( | ||
| msd; | ||
| continuous_states = propagate_set(msd.continuous_states, dt, rates_output.rates), | ||
| models = propagate_models(msd.models, dt, rates_output.models), | ||
| continuous_states = propagate_set_rk4(msd.continuous_states, dt, k1.rates, k2.rates, k3.rates, k4.rates), | ||
| models = propagate_models_rk4(msd.models, dt, k1.models, k2.models, k3.models, k4.models), | ||
| ) | ||
| end | ||
|
|
||
|
|
@@ -89,38 +130,41 @@ function propagate_variable(x::T, gains, x_dot::NTuple{N, T}) where {T, N} | |
| return (x + sum(gains .* x_dot))::T # Just to be clear, this shouldn't change the type. | ||
| end | ||
|
|
||
| function propagate_set(x::T1, gains, x_dot::Tuple) where {T1} | ||
| return NamedTuple{fieldnames(T1)}( | ||
| map(fieldnames(T1)) do f | ||
| if hasfield(typeof(first(x_dot)), f) # TODO: Check this for efficiency. | ||
| propagate_variable(x[f], gains, getfield.(x_dot, f)) | ||
| else | ||
| x[f] # Allow fields to not be updated (empty rates output). | ||
| end | ||
| end | ||
| ) | ||
| @generated function propagate_set(x::T1, gains, x_dot::TX) where {T1, TX <: Tuple} | ||
| names = fieldnames(T1) | ||
| n = fieldcount(TX) | ||
| first_rates_type = fieldtype(TX, 1) | ||
| values = [ | ||
| hasfield(first_rates_type, name) ? | ||
| :(propagate_variable( | ||
| getfield(x, $(QuoteNode(name))), gains, | ||
| ntuple(i -> getfield(getfield(x_dot, i), $(QuoteNode(name))), Val($n)), | ||
| )) : | ||
| :(getfield(x, $(QuoteNode(name)))) | ||
| for name in names | ||
| ] | ||
| return :(NamedTuple{$names}(($(values...),))) | ||
| end | ||
|
|
||
| # `submodels` is a named tuple of ModelStateDescriptions. | ||
| # `gains` is a tuple of gains. | ||
| # `rates_output` is a tuple (one for each gain) of named tuples holding the RatesOutput | ||
| # of each of the submodels (for submodels that have such an output). | ||
| function propagate_models(submodels::NamedTuple, gains::Tuple, rates_outputs::Tuple) | ||
| complete_rates_outputs = map(rates_outputs) do ro | ||
| NamedTuple{fieldnames(typeof(submodels))}( | ||
| map(fieldnames(typeof(submodels))) do f | ||
| if hasfield(typeof(ro), f) # If we have derivatives for this state... | ||
| getfield(ro, f) # Get it for all of them. | ||
| else | ||
| RatesOutput() | ||
| end | ||
| end | ||
| ) | ||
| end | ||
| return map( | ||
| (sm, ro...) -> propagate(sm, gains, ro), | ||
| submodels, complete_rates_outputs... | ||
| ) | ||
| @generated function propagate_models( | ||
| submodels::NamedTuple{names}, gains::Tuple, rates_outputs::TRO, | ||
| ) where {names, TRO <: Tuple} | ||
| n = fieldcount(TRO) | ||
| first_models_type = fieldtype(TRO, 1) | ||
| values = [ | ||
| hasfield(first_models_type, name) ? | ||
| :(propagate( | ||
| getfield(submodels, $(QuoteNode(name))), gains, | ||
| ntuple(i -> getfield(getfield(rates_outputs, i), $(QuoteNode(name))), Val($n)), | ||
| )) : | ||
| :(getfield(submodels, $(QuoteNode(name)))) | ||
| for name in names | ||
| ] | ||
| return :(NamedTuple{$names}(($(values...),))) | ||
| end | ||
|
|
||
| function propagate(msd::ModelStateDescription{T}, gains::Tuple, rates_outputs::Tuple) where {T} | ||
|
|
@@ -148,56 +192,54 @@ end | |
| create_solver(options::RungeKutta4Options, msd::ModelStateDescription) = RungeKutta4(options) | ||
|
|
||
| get_initial_time_step(solver::RungeKutta4) = solver.options.dt | ||
| handles_internal_substepping(::AbstractSolver) = false | ||
| handles_internal_substepping(::RungeKutta4) = true | ||
|
|
||
| # TODO: It seems like there's a lot about `solve` that could be abstracted and simplified. | ||
| function solve(ommd, solver::RungeKutta4, t_last, t_next, msd_km1, rates_fcn, t_end) | ||
| dt_solver = solver.options.dt | ||
| function rk4_substep(msd_start, t_start, t_stop, msd_with_draws, k1) | ||
| # If there's no actual work to do here, skip the calculations. | ||
| if t_start == t_stop | ||
| return msd_with_draws | ||
| else | ||
| t_start_f = float(t_start) | ||
| t_stop_f = float(t_stop) | ||
| dt = t_stop_f - t_start_f | ||
| msd2 = propagate(msd_with_draws, dt/2, k1) | ||
| k2 = rates_fcn(t_start_f + dt/2, model(msd2)) | ||
| msd3 = propagate(msd_with_draws, dt/2, k2) | ||
| k3 = rates_fcn(t_start_f + dt/2, model(msd3)) | ||
| msd4 = propagate(msd_with_draws, dt, k3) | ||
| k4 = rates_fcn(t_start_f + dt, model(msd4)) | ||
| msd_stop = msd_with_draws | ||
| msd_stop = propagate(msd_stop, dt/6, k1) | ||
| msd_stop = propagate(msd_stop, dt/3, k2) | ||
| msd_stop = propagate(msd_stop, dt/3, k3) | ||
| msd_stop = propagate(msd_stop, dt/6, k4) | ||
| return msd_stop | ||
| end | ||
| end | ||
|
|
||
| t_last_f = float(t_last) | ||
| t_next_f = float(t_next) | ||
|
|
||
| # Make the draws for the continuous-time function. | ||
| msd_km1_with_draws = draw_wc(t_last_f, t_next_f, ommd, msd_km1) | ||
|
|
||
| # The first derivative is different because it's an output. The rest are ephemeral. | ||
| msd1 = msd_km1_with_draws | ||
| k1 = rates_fcn(t_last_f, model(msd1)) | ||
|
|
||
| # If there's no actual work to do here, skip the calculations. | ||
| if t_last == t_next | ||
|
|
||
| msd_k = msd_km1_with_draws | ||
|
|
||
| else | ||
|
|
||
| dt = t_next_f - t_last_f | ||
| msd2 = propagate(msd1, dt/2, k1) | ||
| k2 = rates_fcn(t_last_f + dt/2, model(msd2)) | ||
| msd3 = propagate(msd1, dt/2, k2) | ||
| k3 = rates_fcn(t_last_f + dt/2, model(msd3)) | ||
| msd4 = propagate(msd1, dt, k3) | ||
| k4 = rates_fcn(t_last_f + dt, model(msd4)) | ||
|
|
||
| # This seems more efficient: | ||
| # propagate( | ||
| # msd_km1_with_draws, | ||
| # (dt/6, dt/3, dt/3, dt/6), | ||
| # (k1, k2, k3, k4), | ||
| # ) | ||
|
|
||
| # But this doesn't allocate and is actually slightly faster. | ||
| msd_k = msd_km1_with_draws | ||
| msd_k = propagate(msd_k, dt/6, k1) | ||
| msd_k = propagate(msd_k, dt/3, k2) | ||
| msd_k = propagate(msd_k, dt/3, k3) | ||
| msd_k = propagate(msd_k, dt/6, k4) | ||
|
|
||
| t_sub_last = t_last | ||
| t_sub_next = min(t_next, t_sub_last + dt_solver) | ||
| first_msd = draw_wc(float(t_sub_last), float(t_sub_next), ommd, msd_km1) | ||
| first_rates = rates_fcn(float(t_sub_last), model(first_msd)) | ||
| msd_k = rk4_substep(msd_km1, t_sub_last, t_sub_next, first_msd, first_rates) | ||
|
|
||
| while t_sub_next != t_next | ||
| t_sub_last = t_sub_next | ||
| t_sub_next = min(t_next, t_sub_last + dt_solver) | ||
| msd_with_draws = draw_wc(float(t_sub_last), float(t_sub_next), ommd, msd_k) | ||
| k1 = rates_fcn(float(t_sub_last), model(msd_with_draws)) | ||
| msd_k = rk4_substep(msd_k, t_sub_last, t_sub_next, msd_with_draws, k1) | ||
| end | ||
|
|
||
| return SolverOutputs(; | ||
| t_completed = t_next, # This should already be a rational. | ||
| msd_km1 = msd_km1_with_draws, | ||
| msd_km1 = first_msd, | ||
| msd_k, | ||
| rates = k1, | ||
| rates = first_rates, | ||
| stop = UnknownStopReason(), | ||
| t_next_suggested = t_next + solver.options.dt, # Already rational | ||
| ) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I was pretty happy that I had no allocations for RK4 without generated functions! I wonder if this is actually an improvement.