diff --git a/src/EnsembleKalmanInversion.jl b/src/EnsembleKalmanInversion.jl index 89c2b3f0f..ac8ba9aa2 100644 --- a/src/EnsembleKalmanInversion.jl +++ b/src/EnsembleKalmanInversion.jl @@ -19,14 +19,13 @@ Provides a failsafe update that - updates the successful ensemble according to the EKI update, - updates the failed ensemble by sampling from the updated successful ensemble. """ -function FailureHandler(process::Inversion, method::SampleSuccGauss) +function FailureHandler(::Inversion, ::SampleSuccGauss) function failsafe_update(ekp, u, g, y, obs_noise_cov, failed_ens) - successful_ens = filter(x -> !(x in failed_ens), collect(1:size(g, 2))) - n_failed = length(failed_ens) - u[:, successful_ens] = - eki_update(ekp, u[:, successful_ens], g[:, successful_ens], y[:, successful_ens], obs_noise_cov) + (failed_ens, sample_transform, sample_dim) = get_correlations(ekp.level_scheduler, failed_ens) + u = eki_update(ekp, u, g, y, obs_noise_cov; ignored_indices = failed_ens) if !isempty(failed_ens) - u[:, failed_ens] = sample_empirical_gaussian(ekp.rng, u[:, successful_ens], n_failed) + new_samples = sample_empirical_gaussian(ekp.rng, ekp, u, sample_dim; ignored_indices = failed_ens) + u[:, failed_ens] = new_samples[:, sample_transform] end return u end @@ -39,7 +38,8 @@ end u::AbstractMatrix{FT}, g::AbstractMatrix{FT}, y::AbstractMatrix{FT}, - obs_noise_cov::Union{AbstractMatrix{CT}, UniformScaling{CT}}, + obs_noise_cov::Union{AbstractMatrix{CT}, UniformScaling{CT}}; + ignored_indices = [], ) where {FT <: Real, IT, CT <: Real} Returns the updated parameter vectors given their current values and @@ -53,14 +53,17 @@ function eki_update( u::AbstractMatrix{FT}, g::AbstractMatrix{FT}, y::AbstractMatrix{FT}, - obs_noise_cov::Union{AbstractMatrix{CT}, UniformScaling{CT}}, + obs_noise_cov::Union{AbstractMatrix{CT}, UniformScaling{CT}}; + ignored_indices = [], ) where {FT <: Real, IT, CT <: Real} - cov_est = cov([u; g], dims = 2, corrected = false) # [(N_par + N_obs)×(N_par + N_obs)] + cov_est = compute_cov(ekp, [u; g]; corrected = false, ignored_indices) # [(N_par + N_obs)×(N_par + N_obs)] # Localization cov_localized = ekp.localizer.localize(cov_est) + cov_uu, cov_ug, cov_gg = get_cov_blocks(cov_localized, size(u, 1)) + cov_gg = posdef(cov_gg) # N_obs × N_obs \ [N_obs × N_ens] # --> tmp is [N_obs × N_ens] @@ -108,9 +111,10 @@ function update_ensemble!( # g: N_obs × N_ens u = get_u_final(ekp) N_obs = size(g, 1) - cov_init = cov(u, dims = 2) if ekp.verbose + cov_init = compute_cov(ekp, u; corrected = true) + if get_N_iterations(ekp) == 0 @info "Iteration 0 (prior)" @info "Covariance trace: $(tr(cov_init))" @@ -123,7 +127,9 @@ function update_ensemble!( # Scale noise using Δt scaled_obs_noise_cov = ekp.obs_noise_cov / ekp.Δt[end] - noise = sqrt(scaled_obs_noise_cov) * rand(ekp.rng, MvNormal(zeros(N_obs), I), ekp.N_ens) + independent_noise_dim = get_N_indep(ekp.level_scheduler) + noise = scaled_obs_noise_cov * rand(ekp.rng, MvNormal(zeros(N_obs), I), independent_noise_dim) + noise = transform_noise(ekp.level_scheduler, noise) # Add obs_mean (N_obs) to each column of noise (N_obs × N_ens) if # G is deterministic @@ -143,10 +149,10 @@ function update_ensemble!( # Store error compute_error!(ekp) - # Diagnostics - cov_new = cov(u, dims = 2) - if ekp.verbose + # Diagnostics + cov_new = compute_cov(ekp, u; corrected = true) + @info "Covariance-weighted error: $(get_error(ekp)[end])\nCovariance trace: $(tr(cov_new))\nCovariance trace ratio (current/previous): $(tr(cov_new)/tr(cov_init))" end diff --git a/src/EnsembleKalmanProcess.jl b/src/EnsembleKalmanProcess.jl index 753e7acb1..175822114 100644 --- a/src/EnsembleKalmanProcess.jl +++ b/src/EnsembleKalmanProcess.jl @@ -32,12 +32,14 @@ abstract type FailureHandlingMethod end # Accelerators abstract type Accelerator end +# Level schedulers +abstract type LevelScheduler end "Failure handling method that ignores forward model failures" struct IgnoreFailures <: FailureHandlingMethod end -"""" +""" SampleSuccGauss <: FailureHandlingMethod Failure handling method that substitutes failed ensemble members by new samples from @@ -130,6 +132,8 @@ struct EnsembleKalmanProcess{ scheduler::LRS "accelerator object that informs EK update steps, stores additional state variables as needed" accelerator::ACC + "" + level_scheduler::LevelScheduler "stored vector of timesteps used in each EK iteration" Δt::Vector{FT} "the particular EK process (`Inversion` or `Sampler` or `Unscented` or `TransformInversion` or `SparseInversion`)" @@ -151,6 +155,7 @@ function EnsembleKalmanProcess( process::P; scheduler::Union{Nothing, LRS} = nothing, accelerator::Union{Nothing, ACC} = nothing, + level_scheduler::Union{Nothing, LS} = nothing, Δt = nothing, rng::AbstractRNG = Random.GLOBAL_RNG, failure_handler_method::FM = IgnoreFailures(), @@ -160,6 +165,7 @@ function EnsembleKalmanProcess( FT <: AbstractFloat, LRS <: LearningRateScheduler, ACC <: Accelerator, + LS <: LevelScheduler, P <: Process, FM <: FailureHandlingMethod, LM <: LocalizationMethod, @@ -221,6 +227,17 @@ function EnsembleKalmanProcess( end end + # set up level scheduler + ls = if isnothing(level_scheduler) + SingleLevelScheduler(N_ens, LevelInfinity()) + else + if !(typeof(process) <: Inversion) + throw(ArgumentError("Only `Inversion` (EKI) can currently be used with multilevel Monte Carlo.")) + end + + level_scheduler + end + # failure handler fh = FailureHandler(process, failure_handler_method) # localizer @@ -239,6 +256,7 @@ function EnsembleKalmanProcess( err, lrs, acc, + ls, Δt, process, rng, @@ -503,20 +521,24 @@ get_error(ekp::EnsembleKalmanProcess) = ekp.err """ sample_empirical_gaussian( rng::AbstractRNG, + ekp::EnsembleKalmanProcess, u::AbstractMatrix{FT}, n::IT; inflation::Union{FT, Nothing} = nothing, + ignored_indices = [], ) where {FT <: Real, IT <: Int} Returns `n` samples from an empirical Gaussian based on point estimates `u`, adding inflation if the covariance is singular. """ function sample_empirical_gaussian( rng::AbstractRNG, + ekp::EnsembleKalmanProcess, u::AbstractMatrix{FT}, n::IT; inflation::Union{FT, Nothing} = nothing, + ignored_indices = [], ) where {FT <: Real, IT <: Int} - cov_u_new = Symmetric(cov(u, dims = 2)) + cov_u_new = Symmetric(posdef(compute_cov(ekp, u; corrected = true, ignored_indices))) if !isposdef(cov_u_new) @warn string("Sample covariance matrix over ensemble is singular.", "\n Applying variance inflation.") if isnothing(inflation) @@ -525,16 +547,18 @@ function sample_empirical_gaussian( end cov_u_new = cov_u_new + inflation * I end - mean_u_new = mean(u, dims = 2) + mean_u_new = compute_mean(ekp, u; ignored_indices) return mean_u_new .+ sqrt(cov_u_new) * rand(rng, MvNormal(zeros(length(mean_u_new[:])), I), n) end function sample_empirical_gaussian( + ekp::EnsembleKalmanProcess, u::AbstractMatrix{FT}, n::IT; inflation::Union{FT, Nothing} = nothing, + ignored_indices = [], ) where {FT <: Real, IT <: Int} - return sample_empirical_gaussian(Random.GLOBAL_RNG, u, n, inflation = inflation) + return sample_empirical_gaussian(Random.GLOBAL_RNG, ekp, u, n; inflation, ignored_indices) end @@ -691,6 +715,8 @@ function update_ensemble!( end +include("SampleStatistics.jl") + ## include the different types of Processes and their exports: # struct Inversion @@ -719,3 +745,6 @@ include("UnscentedKalmanInversion.jl") # struct Accelerator include("Accelerators.jl") + +# Level schedulers +include("Multilevel.jl") diff --git a/src/EnsembleTransformKalmanInversion.jl b/src/EnsembleTransformKalmanInversion.jl index 589f33481..aff13093d 100644 --- a/src/EnsembleTransformKalmanInversion.jl +++ b/src/EnsembleTransformKalmanInversion.jl @@ -32,7 +32,7 @@ function FailureHandler(process::TransformInversion, method::SampleSuccGauss) n_failed = length(failed_ens) u[:, successful_ens] = etki_update(ekp, u[:, successful_ens], g[:, successful_ens], y, obs_noise_cov) if !isempty(failed_ens) - u[:, failed_ens] = sample_empirical_gaussian(ekp.rng, u[:, successful_ens], n_failed) + u[:, failed_ens] = sample_empirical_gaussian(ekp.rng, ekp, u, n_failed; ignored_indices = failed_ens) end return u end diff --git a/src/Multilevel.jl b/src/Multilevel.jl new file mode 100644 index 000000000..9a4e93baf --- /dev/null +++ b/src/Multilevel.jl @@ -0,0 +1,96 @@ +export SingleLevelScheduler, MultilevelScheduler, get_N_ens, get_N_indep, levels, transform_noise + +struct LevelInfinity end + +const SingleLevelType{IT} = Union{IT, LevelInfinity} + +struct MultilevelScheduler{IT <: Integer} <: LevelScheduler + Js::Dict{IT, IT} + N_indep::IT + N_ens::IT + + function MultilevelScheduler(Js::Dict{IT, IT}) where {IT <: Integer} + N_indep = sum(values(Js)) + N_ens = sum(J * (level == 0 ? 1 : 2) for (level, J) in Js) + + new{IT}(Js, N_indep, N_ens) + end +end + +struct SingleLevelScheduler{IT <: Integer} <: LevelScheduler + N_ens::IT + level::SingleLevelType{IT} + + function SingleLevelScheduler(N_ens::IT, level::SingleLevelType{IT} = LevelInfinity()) where {IT <: Integer} + new{IT}(N_ens, level) + end +end + + +get_N_ens(ms::MultilevelScheduler) = ms.N_ens + +get_N_indep(ms::MultilevelScheduler) = ms.N_indep + +levels(ms::MultilevelScheduler) = begin + vcat( + fill(0, ms.Js[0]), + (fill(l, ms.Js[l]) for l in sort(collect(keys(ms.Js))) if l != 0)..., + (fill(l - 1, ms.Js[l]) for l in sort(collect(keys(ms.Js))) if l != 0)..., + ) +end + +statistic_groups(ms::MultilevelScheduler) = begin + groups = [] + + offset = ms.N_indep - ms.Js[0] + + index = 0 + for level in sort(collect(keys(ms.Js))) + J = ms.Js[level] + push!(groups, (index+1:index+J, 1)) + if level > 0 + push!(groups, (index+offset+1:index+offset+J, -1)) + end + + index += J + end + + groups +end + +transform_noise(ms::MultilevelScheduler, noise::AbstractMatrix{FT}) where {FT <: Real} = begin + @assert size(noise, 2) == ms.N_indep + + noise[:, vcat(1:ms.N_indep, ms.Js[0]+1:ms.N_indep)] +end + +get_correlations(ms::MultilevelScheduler, indices::AbstractVector{IT}) where {IT <: Integer} = begin + num_uncorrelated = 0 + new_indices = map(indices) do i + if i <= ms.Js[0] + num_uncorrelated += 1 + i # There is no correlated index + elseif i <= ms.N_indep + i + (ms.N_indep - ms.Js[0]) + else + i - (ms.N_indep - ms.Js[0]) + end + end + all_indices = sort!(unique!(vcat(indices, new_indices))) + num_correlated = (length(all_indices) - num_uncorrelated) ÷ 2 + noise_dim = num_correlated + num_uncorrelated + all_indices, hcat(1:num_uncorrelated, num_uncorrelated+1:noise_dim, num_uncorrelated+1:noise_dim), noise_dim +end + + +get_N_ens(sls::SingleLevelScheduler) = sls.N_ens + +get_N_indep(sls::SingleLevelScheduler) = sls.N_ens + +levels(sls::SingleLevelScheduler) = fill(sls.level, sls.N_ens) + +statistic_groups(sls::SingleLevelScheduler) = [(1:sls.N_ens, 1)] + +transform_noise(::SingleLevelScheduler, noise::AbstractMatrix{FT}) where {FT <: Real} = noise + +get_correlations(::SingleLevelScheduler, indices::AbstractVector{IT}) where {IT <: Integer} = (indices, 1:length(indices), length(indices)) diff --git a/src/SampleStatistics.jl b/src/SampleStatistics.jl new file mode 100644 index 000000000..4ca1b4f39 --- /dev/null +++ b/src/SampleStatistics.jl @@ -0,0 +1,24 @@ +# included in EnsembleKalmanProcess.jl + +export compute_mean, compute_cov + +function posdef(mat) + S, V = eigen(mat) + V = V[:, (S .> 0)] + S = S[S .> 0] + V * diagm(S) * V' +end + +function compute_mean(ekp::EnsembleKalmanProcess, x; ignored_indices = []) + reduce(statistic_groups(ekp.level_scheduler); init = 0) do acc, (indices, multiplier) + indices = setdiff(indices, ignored_indices) + multiplier * mean(x[:, indices]; dims = 2) .+ acc + end +end + +function compute_cov(ekp::EnsembleKalmanProcess, x; corrected, ignored_indices = []) + reduce(statistic_groups(ekp.level_scheduler); init = 0) do acc, (indices, multiplier) + indices = setdiff(indices, ignored_indices) + multiplier * cov(x[:, indices]; corrected, dims = 2) .+ acc + end +end diff --git a/src/SparseEnsembleKalmanInversion.jl b/src/SparseEnsembleKalmanInversion.jl index ce9262158..ebe9f1bac 100644 --- a/src/SparseEnsembleKalmanInversion.jl +++ b/src/SparseEnsembleKalmanInversion.jl @@ -56,7 +56,7 @@ function FailureHandler(process::SparseInversion, method::SampleSuccGauss) u[:, successful_ens] = sparse_eki_update(ekp, u[:, successful_ens], g[:, successful_ens], y[:, successful_ens], obs_noise_cov) if !isempty(failed_ens) - u[:, failed_ens] = sample_empirical_gaussian(ekp.rng, u[:, successful_ens], n_failed) + u[:, failed_ens] = sample_empirical_gaussian(ekp.rng, ekp, u, n_failed; ignored_indices = failed_ens) end return u end diff --git a/test/EnsembleKalmanProcess/runtests.jl b/test/EnsembleKalmanProcess/runtests.jl index 9e9ab4963..ebaabbe34 100644 --- a/test/EnsembleKalmanProcess/runtests.jl +++ b/test/EnsembleKalmanProcess/runtests.jl @@ -960,7 +960,9 @@ end rng = Random.MersenneTwister(rng_seed) u = rand(10, 4) + ekp = EKP.EnsembleKalmanProcess(u, [1.;], [1.;;], Inversion()) @test_logs (:warn, r"Sample covariance matrix over ensemble is singular.") match_mode = :any sample_empirical_gaussian( + ekp, u, 2, ) @@ -968,8 +970,8 @@ end u2 = rand(rng, 5, 20) @test all( isapprox.( - sample_empirical_gaussian(copy(rng), u2, 2), - sample_empirical_gaussian(copy(rng), u2, 2, inflation = 0.0); + sample_empirical_gaussian(copy(rng), ekp, u2, 2), + sample_empirical_gaussian(copy(rng), ekp, u2, 2, inflation = 0.0); atol = 1e-8, ), ) diff --git a/test/Multilevel/runtests.jl b/test/Multilevel/runtests.jl new file mode 100644 index 000000000..a2ac456f8 --- /dev/null +++ b/test/Multilevel/runtests.jl @@ -0,0 +1,94 @@ +using Distributions +using LinearAlgebra +using Random +using Test + +using EnsembleKalmanProcesses +using EnsembleKalmanProcesses.ParameterDistributions +const EKP = EnsembleKalmanProcesses + +Random.seed!(123) +artificial_noise = randn(2) + +forward_model(u; level) = begin + p(x) = u[2]*x + exp(-u[1])*(-x^2/2 + x/2) + exact_solution = [p(.25); p(.75)] + exact_solution + u.^2/norm(u.^2) .* artificial_noise / (10 * 2^(level+1)) +end + +@testset "Multilevel" begin + # Seed for pseudo-random number generator + rng_seed = 42 + rng = Random.MersenneTwister(rng_seed) + + priors = [ + ParameterDistribution(Parameterized(Normal(-3, 1)), no_constraint(), "u1"), + ParameterDistribution(Parameterized(Normal(105, 5)), no_constraint(), "u2"), + ] + prior = combine_distributions(priors) + + y = [27.5; 79.7] + Γ = 0.01 * I + N_iter = 10 + lrs = DefaultScheduler(1) + + # Approximate mean-field limit + println("Approximating mean field") + N_ens = 200_000 + level = 30 + initial_ensemble = EKP.construct_initial_ensemble(rng, prior, N_ens) + eki = EKP.EnsembleKalmanProcess(initial_ensemble, y, Γ, Inversion(); rng = rng, scheduler=lrs) + for i in 1:N_iter + u = get_u_final(eki) + g_ens = hcat((forward_model(u[:,j]; level) for j in 1:N_ens)...) + EKP.update_ensemble!(eki, g_ens) + end + mean_field_limit_approx_mean = compute_mean(eki, get_u_final(eki)) + + # Single-level approximation + println("Approximating single-level") + num_avg = 5 + single_level_cost = 2^20 + single_level_errors = map(4:10) do level + N_ens = floor(Int, single_level_cost / 2^level) + mean(1:num_avg) do _ + initial_ensemble = EKP.construct_initial_ensemble(rng, prior, N_ens) + eki = EKP.EnsembleKalmanProcess(initial_ensemble, y, Γ, Inversion(); rng = rng, scheduler=lrs) + for i in 1:N_iter + u = get_u_final(eki) + g_ens = hcat((forward_model(u[:,j]; level) for j in 1:N_ens)...) + EKP.update_ensemble!(eki, g_ens) + end + norm(mean_field_limit_approx_mean - compute_mean(eki, get_u_final(eki))) + end + end + + # Multilevel approximation + println("Approximating multilevel") + max_level = 9 + Js = Dict(level => floor(Int, 20 * 2^((max_level - level) * 4/3)) for level in 0:max_level) + level_scheduler = MultilevelScheduler(Js) + N_ens = get_N_ens(level_scheduler) + num_avg = 5 + multilevel_error = mean(1:num_avg) do _ + initial_ensemble = EKP.construct_initial_ensemble(rng, prior, get_N_indep(level_scheduler)) + initial_ensemble = transform_noise(level_scheduler, initial_ensemble) + eki = EKP.EnsembleKalmanProcess(initial_ensemble, y, Γ, Inversion(); rng, level_scheduler, scheduler=lrs) + for i in 1:N_iter + u = get_u_final(eki) + g_ens = hcat((forward_model(u[:,j]; level) for (j, level) in zip(1:N_ens, levels(level_scheduler)))...) + EKP.update_ensemble!(eki, g_ens) + end + norm(mean_field_limit_approx_mean - compute_mean(eki, get_u_final(eki))) + end + multilevel_cost = reduce(Js; init = 0) do acc, (level, J) + acc + J * 2^level + (level == 0 ? 0 : J * 2^(level - 1)) + end + + println(multilevel_cost, " ", multilevel_error) + println(single_level_cost, " ", single_level_errors) + @test multilevel_cost < 0.5 * single_level_cost + for single_level_error in single_level_errors + @test multilevel_error < single_level_error + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 71ce1c39b..d1da0ef2d 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -34,6 +34,7 @@ end "TOMLInterface", "SparseInversion", "Inflation", + "Multilevel", ] if all_tests || has_submodule(submodule) || "EnsembleKalmanProcesses" in ARGS include_test(submodule)