Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
224 changes: 133 additions & 91 deletions src/Solvers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,44 +41,85 @@ function propagate_variable(x::T, dt, x_dot::T) where {T}
return (x + dt * x_dot)::T # Just to be clear, this shouldn't change the type.
end

function propagate_set(x::T1, dt, x_dot::T2) where {T1, T2}
@generated function propagate_set(x::T1, dt, x_dot::T2) where {T1, T2}
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was pretty happy that I had no allocations for RK4 without generated functions! I wonder if this is actually an improvement.

names = fieldnames(T1)
values = [
hasfield(T2, name) ?
:(propagate_variable(getfield(x, $(QuoteNode(name))), dt, getfield(x_dot, $(QuoteNode(name))))) :
:(getfield(x, $(QuoteNode(name))))
for name in names
]
return :(NamedTuple{$names}(($(values...),)))
end

is_empty_rates_output(rates_output::RatesOutput) =
isempty(fieldnames(typeof(rates_output.rates))) &&
isempty(fieldnames(typeof(rates_output.models)))

@generated function propagate_models(
submodels::NamedTuple{names}, dt, rates_output::TRO,
) where {names, TRO}
values = [
hasfield(TRO, name) ?
:(propagate(getfield(submodels, $(QuoteNode(name))), dt, getfield(rates_output, $(QuoteNode(name))))) :
:(getfield(submodels, $(QuoteNode(name))))
for name in names
]
return :(NamedTuple{$names}(($(values...),)))
end

function propagate(msd::ModelStateDescription, dt, rates_output::RatesOutput)
is_empty_rates_output(rates_output) && return msd
return copy_model_state_description_except(
msd;
continuous_states = propagate_set(msd.continuous_states, dt, rates_output.rates),
models = propagate_models(msd.models, dt, rates_output.models),
)
end

function propagate_variable_rk4(x::T, dt, k1::T, k2::T, k3::T, k4::T) where {T}
return (x + dt/6 * (k1 + 2*k2 + 2*k3 + k4))::T
end

function propagate_set_rk4(x::T1, dt, r1, r2, r3, r4) where {T1}
return NamedTuple{fieldnames(T1)}(
map(fieldnames(T1)) do f
if hasfield(typeof(x_dot), f)
propagate_variable(x[f], dt, x_dot[f])
if hasfield(typeof(r1), f)
propagate_variable_rk4(x[f], dt, r1[f], r2[f], r3[f], r4[f])
else
x[f]
end
end
)
end

function propagate_models(submodels::NamedTuple, dt, rates_output::NamedTuple)

# A user's RatesOutput's model entry could contain the models in any order. Here, we
# build a named tuple that matches the order of the original set of submodels. Plus, if
# an entry is missing, we fill it in with a blank RatesOutput(). This lets us simply
# `map` below.
complete_rates_output = NamedTuple{fieldnames(typeof(submodels))}(
map(fieldnames(typeof(submodels))) do f
if hasfield(typeof(rates_output), f)
rates_output[f]
else
RatesOutput()
end
end
function propagate_models_rk4(submodels::NamedTuple, dt, m1::NamedTuple, m2::NamedTuple, m3::NamedTuple, m4::NamedTuple)
names = fieldnames(typeof(submodels))
complete_m1 = NamedTuple{names}(map(names) do f
hasfield(typeof(m1), f) ? getfield(m1, f) : RatesOutput()
end)
complete_m2 = NamedTuple{names}(map(names) do f
hasfield(typeof(m2), f) ? getfield(m2, f) : RatesOutput()
end)
complete_m3 = NamedTuple{names}(map(names) do f
hasfield(typeof(m3), f) ? getfield(m3, f) : RatesOutput()
end)
complete_m4 = NamedTuple{names}(map(names) do f
hasfield(typeof(m4), f) ? getfield(m4, f) : RatesOutput()
end)
return map(
(sm, ro1, ro2, ro3, ro4) -> propagate_rk4(sm, dt, ro1, ro2, ro3, ro4),
submodels, complete_m1, complete_m2, complete_m3, complete_m4,
Comment on lines +110 to +112
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm surprised about this part! The multi-input map often seems to optimize more poorly.

)

# Now this is a simple map and doesn't allocate.
return map((sm, ro) -> propagate(sm, dt, ro), submodels, complete_rates_output)

end

function propagate(msd::ModelStateDescription, dt, rates_output::RatesOutput)
function propagate_rk4(msd::ModelStateDescription, dt, k1::RatesOutput, k2::RatesOutput, k3::RatesOutput, k4::RatesOutput)
is_empty_rates_output(k1) && is_empty_rates_output(k2) &&
is_empty_rates_output(k3) && is_empty_rates_output(k4) && return msd
Comment on lines +117 to +118
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If any of these is empty when the others aren't, that would be an error, so it's sufficient to check only one.

return copy_model_state_description_except(
msd;
continuous_states = propagate_set(msd.continuous_states, dt, rates_output.rates),
models = propagate_models(msd.models, dt, rates_output.models),
continuous_states = propagate_set_rk4(msd.continuous_states, dt, k1.rates, k2.rates, k3.rates, k4.rates),
models = propagate_models_rk4(msd.models, dt, k1.models, k2.models, k3.models, k4.models),
)
end

Expand All @@ -89,38 +130,41 @@ function propagate_variable(x::T, gains, x_dot::NTuple{N, T}) where {T, N}
return (x + sum(gains .* x_dot))::T # Just to be clear, this shouldn't change the type.
end

function propagate_set(x::T1, gains, x_dot::Tuple) where {T1}
return NamedTuple{fieldnames(T1)}(
map(fieldnames(T1)) do f
if hasfield(typeof(first(x_dot)), f) # TODO: Check this for efficiency.
propagate_variable(x[f], gains, getfield.(x_dot, f))
else
x[f] # Allow fields to not be updated (empty rates output).
end
end
)
@generated function propagate_set(x::T1, gains, x_dot::TX) where {T1, TX <: Tuple}
names = fieldnames(T1)
n = fieldcount(TX)
first_rates_type = fieldtype(TX, 1)
values = [
hasfield(first_rates_type, name) ?
:(propagate_variable(
getfield(x, $(QuoteNode(name))), gains,
ntuple(i -> getfield(getfield(x_dot, i), $(QuoteNode(name))), Val($n)),
)) :
:(getfield(x, $(QuoteNode(name))))
for name in names
]
return :(NamedTuple{$names}(($(values...),)))
end

# `submodels` is a named tuple of ModelStateDescriptions.
# `gains` is a tuple of gains.
# `rates_output` is a tuple (one for each gain) of named tuples holding the RatesOutput
# of each of the submodels (for submodels that have such an output).
function propagate_models(submodels::NamedTuple, gains::Tuple, rates_outputs::Tuple)
complete_rates_outputs = map(rates_outputs) do ro
NamedTuple{fieldnames(typeof(submodels))}(
map(fieldnames(typeof(submodels))) do f
if hasfield(typeof(ro), f) # If we have derivatives for this state...
getfield(ro, f) # Get it for all of them.
else
RatesOutput()
end
end
)
end
return map(
(sm, ro...) -> propagate(sm, gains, ro),
submodels, complete_rates_outputs...
)
@generated function propagate_models(
submodels::NamedTuple{names}, gains::Tuple, rates_outputs::TRO,
) where {names, TRO <: Tuple}
n = fieldcount(TRO)
first_models_type = fieldtype(TRO, 1)
values = [
hasfield(first_models_type, name) ?
:(propagate(
getfield(submodels, $(QuoteNode(name))), gains,
ntuple(i -> getfield(getfield(rates_outputs, i), $(QuoteNode(name))), Val($n)),
)) :
:(getfield(submodels, $(QuoteNode(name))))
for name in names
]
return :(NamedTuple{$names}(($(values...),)))
end

function propagate(msd::ModelStateDescription{T}, gains::Tuple, rates_outputs::Tuple) where {T}
Expand Down Expand Up @@ -148,56 +192,54 @@ end
create_solver(options::RungeKutta4Options, msd::ModelStateDescription) = RungeKutta4(options)

get_initial_time_step(solver::RungeKutta4) = solver.options.dt
handles_internal_substepping(::AbstractSolver) = false
handles_internal_substepping(::RungeKutta4) = true

# TODO: It seems like there's a lot about `solve` that could be abstracted and simplified.
function solve(ommd, solver::RungeKutta4, t_last, t_next, msd_km1, rates_fcn, t_end)
dt_solver = solver.options.dt
function rk4_substep(msd_start, t_start, t_stop, msd_with_draws, k1)
# If there's no actual work to do here, skip the calculations.
if t_start == t_stop
return msd_with_draws
else
t_start_f = float(t_start)
t_stop_f = float(t_stop)
dt = t_stop_f - t_start_f
msd2 = propagate(msd_with_draws, dt/2, k1)
k2 = rates_fcn(t_start_f + dt/2, model(msd2))
msd3 = propagate(msd_with_draws, dt/2, k2)
k3 = rates_fcn(t_start_f + dt/2, model(msd3))
msd4 = propagate(msd_with_draws, dt, k3)
k4 = rates_fcn(t_start_f + dt, model(msd4))
msd_stop = msd_with_draws
msd_stop = propagate(msd_stop, dt/6, k1)
msd_stop = propagate(msd_stop, dt/3, k2)
msd_stop = propagate(msd_stop, dt/3, k3)
msd_stop = propagate(msd_stop, dt/6, k4)
return msd_stop
end
end

t_last_f = float(t_last)
t_next_f = float(t_next)

# Make the draws for the continuous-time function.
msd_km1_with_draws = draw_wc(t_last_f, t_next_f, ommd, msd_km1)

# The first derivative is different because it's an output. The rest are ephemeral.
msd1 = msd_km1_with_draws
k1 = rates_fcn(t_last_f, model(msd1))

# If there's no actual work to do here, skip the calculations.
if t_last == t_next

msd_k = msd_km1_with_draws

else

dt = t_next_f - t_last_f
msd2 = propagate(msd1, dt/2, k1)
k2 = rates_fcn(t_last_f + dt/2, model(msd2))
msd3 = propagate(msd1, dt/2, k2)
k3 = rates_fcn(t_last_f + dt/2, model(msd3))
msd4 = propagate(msd1, dt, k3)
k4 = rates_fcn(t_last_f + dt, model(msd4))

# This seems more efficient:
# propagate(
# msd_km1_with_draws,
# (dt/6, dt/3, dt/3, dt/6),
# (k1, k2, k3, k4),
# )

# But this doesn't allocate and is actually slightly faster.
msd_k = msd_km1_with_draws
msd_k = propagate(msd_k, dt/6, k1)
msd_k = propagate(msd_k, dt/3, k2)
msd_k = propagate(msd_k, dt/3, k3)
msd_k = propagate(msd_k, dt/6, k4)

t_sub_last = t_last
t_sub_next = min(t_next, t_sub_last + dt_solver)
first_msd = draw_wc(float(t_sub_last), float(t_sub_next), ommd, msd_km1)
first_rates = rates_fcn(float(t_sub_last), model(first_msd))
msd_k = rk4_substep(msd_km1, t_sub_last, t_sub_next, first_msd, first_rates)

while t_sub_next != t_next
t_sub_last = t_sub_next
t_sub_next = min(t_next, t_sub_last + dt_solver)
msd_with_draws = draw_wc(float(t_sub_last), float(t_sub_next), ommd, msd_k)
k1 = rates_fcn(float(t_sub_last), model(msd_with_draws))
msd_k = rk4_substep(msd_k, t_sub_last, t_sub_next, msd_with_draws, k1)
end

return SolverOutputs(;
t_completed = t_next, # This should already be a rational.
msd_km1 = msd_km1_with_draws,
msd_km1 = first_msd,
msd_k,
rates = k1,
rates = first_rates,
stop = UnknownStopReason(),
t_next_suggested = t_next + solver.options.dt, # Already rational
)
Expand Down
Loading
Loading