From e2c9835b7b95135a6a490527f5ba2194ca9d5f7a Mon Sep 17 00:00:00 2001 From: Arne Bouillon Date: Fri, 3 Nov 2023 14:52:34 +0100 Subject: [PATCH 1/3] Remove redundant symbol --- src/EnsembleKalmanProcess.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/EnsembleKalmanProcess.jl b/src/EnsembleKalmanProcess.jl index 753e7acb1..54413b730 100644 --- a/src/EnsembleKalmanProcess.jl +++ b/src/EnsembleKalmanProcess.jl @@ -37,7 +37,7 @@ abstract type Accelerator end "Failure handling method that ignores forward model failures" struct IgnoreFailures <: FailureHandlingMethod end -"""" +""" SampleSuccGauss <: FailureHandlingMethod Failure handling method that substitutes failed ensemble members by new samples from From 15908c9bbdbf5c6db85cdf603083ef04ea2c11cb Mon Sep 17 00:00:00 2001 From: Arne Bouillon Date: Thu, 9 Nov 2023 11:52:16 +0100 Subject: [PATCH 2/3] Implement proof of concept for MLMC This somehow works? --- src/EnsembleKalmanInversion.jl | 17 +++--- src/EnsembleKalmanProcess.jl | 19 +++++++ src/Multilevel.jl | 76 +++++++++++++++++++++++++++ src/SampleStatistics.jl | 22 ++++++++ test/Multilevel/runtests.jl | 94 ++++++++++++++++++++++++++++++++++ test/runtests.jl | 1 + 6 files changed, 223 insertions(+), 6 deletions(-) create mode 100644 src/Multilevel.jl create mode 100644 src/SampleStatistics.jl create mode 100644 test/Multilevel/runtests.jl diff --git a/src/EnsembleKalmanInversion.jl b/src/EnsembleKalmanInversion.jl index 89c2b3f0f..7c3b20c28 100644 --- a/src/EnsembleKalmanInversion.jl +++ b/src/EnsembleKalmanInversion.jl @@ -56,11 +56,13 @@ function eki_update( obs_noise_cov::Union{AbstractMatrix{CT}, UniformScaling{CT}}, ) where {FT <: Real, IT, CT <: Real} - cov_est = cov([u; g], dims = 2, corrected = false) # [(N_par + N_obs)×(N_par + N_obs)] + cov_est = compute_cov(ekp, u, g; corrected = false) # [(N_par + N_obs)×(N_par + N_obs)] # Localization cov_localized = ekp.localizer.localize(cov_est) + cov_uu, cov_ug, cov_gg = get_cov_blocks(cov_localized, size(u, 1)) + cov_gg = posdef(cov_gg) # N_obs × N_obs \ [N_obs × N_ens] # --> tmp is [N_obs × N_ens] @@ -108,9 +110,10 @@ function update_ensemble!( # g: N_obs × N_ens u = get_u_final(ekp) N_obs = size(g, 1) - cov_init = cov(u, dims = 2) if ekp.verbose + cov_init = get_cov_blocks(compute_cov(ekp, u, g; corrected = true))[1] + if get_N_iterations(ekp) == 0 @info "Iteration 0 (prior)" @info "Covariance trace: $(tr(cov_init))" @@ -123,7 +126,9 @@ function update_ensemble!( # Scale noise using Δt scaled_obs_noise_cov = ekp.obs_noise_cov / ekp.Δt[end] - noise = sqrt(scaled_obs_noise_cov) * rand(ekp.rng, MvNormal(zeros(N_obs), I), ekp.N_ens) + independent_noise_dim = get_N_indep(ekp.level_scheduler) + noise = scaled_obs_noise_cov * rand(ekp.rng, MvNormal(zeros(N_obs), I), independent_noise_dim) + noise = transform_noise(ekp.level_scheduler, noise) # Add obs_mean (N_obs) to each column of noise (N_obs × N_ens) if # G is deterministic @@ -143,10 +148,10 @@ function update_ensemble!( # Store error compute_error!(ekp) - # Diagnostics - cov_new = cov(u, dims = 2) - if ekp.verbose + # Diagnostics + cov_new = get_cov_blocks(compute_cov(ekp, u, g; corrected = true))[1] + @info "Covariance-weighted error: $(get_error(ekp)[end])\nCovariance trace: $(tr(cov_new))\nCovariance trace ratio (current/previous): $(tr(cov_new)/tr(cov_init))" end diff --git a/src/EnsembleKalmanProcess.jl b/src/EnsembleKalmanProcess.jl index 54413b730..c3c95e221 100644 --- a/src/EnsembleKalmanProcess.jl +++ b/src/EnsembleKalmanProcess.jl @@ -32,6 +32,8 @@ abstract type FailureHandlingMethod end # Accelerators abstract type Accelerator end +# Level schedulers +abstract type LevelScheduler end "Failure handling method that ignores forward model failures" @@ -130,6 +132,8 @@ struct EnsembleKalmanProcess{ scheduler::LRS "accelerator object that informs EK update steps, stores additional state variables as needed" accelerator::ACC + "" + level_scheduler::LevelScheduler "stored vector of timesteps used in each EK iteration" Δt::Vector{FT} "the particular EK process (`Inversion` or `Sampler` or `Unscented` or `TransformInversion` or `SparseInversion`)" @@ -151,6 +155,7 @@ function EnsembleKalmanProcess( process::P; scheduler::Union{Nothing, LRS} = nothing, accelerator::Union{Nothing, ACC} = nothing, + level_scheduler::Union{Nothing, LS} = nothing, Δt = nothing, rng::AbstractRNG = Random.GLOBAL_RNG, failure_handler_method::FM = IgnoreFailures(), @@ -160,6 +165,7 @@ function EnsembleKalmanProcess( FT <: AbstractFloat, LRS <: LearningRateScheduler, ACC <: Accelerator, + LS <: LevelScheduler, P <: Process, FM <: FailureHandlingMethod, LM <: LocalizationMethod, @@ -221,6 +227,13 @@ function EnsembleKalmanProcess( end end + # set up level scheduler + ls = if isnothing(level_scheduler) + SingleLevelScheduler(N_ens, LevelInfinity()) + else + level_scheduler + end + # failure handler fh = FailureHandler(process, failure_handler_method) # localizer @@ -239,6 +252,7 @@ function EnsembleKalmanProcess( err, lrs, acc, + ls, Δt, process, rng, @@ -691,6 +705,8 @@ function update_ensemble!( end +include("SampleStatistics.jl") + ## include the different types of Processes and their exports: # struct Inversion @@ -719,3 +735,6 @@ include("UnscentedKalmanInversion.jl") # struct Accelerator include("Accelerators.jl") + +# Level schedulers +include("Multilevel.jl") diff --git a/src/Multilevel.jl b/src/Multilevel.jl new file mode 100644 index 000000000..34230c859 --- /dev/null +++ b/src/Multilevel.jl @@ -0,0 +1,76 @@ +export SingleLevelScheduler, MultilevelScheduler, get_N_ens, get_N_indep, levels, transform_noise + +struct LevelInfinity end + +const SingleLevelType{IT} = Union{IT, LevelInfinity} + +struct MultilevelScheduler{IT <: Integer} <: LevelScheduler + Js::Dict{IT, IT} + N_indep::IT + N_ens::IT + + function MultilevelScheduler(Js::Dict{IT, IT}) where {IT <: Integer} + N_indep = sum(values(Js)) + N_ens = sum(J * (level == 0 ? 1 : 2) for (level, J) in Js) + + new{IT}(Js, N_indep, N_ens) + end +end + +struct SingleLevelScheduler{IT <: Integer} <: LevelScheduler + N_ens::IT + level::SingleLevelType{IT} + + function SingleLevelScheduler(N_ens::IT, level::SingleLevelType{IT} = LevelInfinity()) where {IT <: Integer} + new{IT}(N_ens, level) + end +end + + +get_N_ens(ms::MultilevelScheduler) = ms.N_ens + +get_N_indep(ms::MultilevelScheduler) = ms.N_indep + +levels(ms::MultilevelScheduler) = begin + vcat( + fill(0, ms.Js[0]), + (fill(l, ms.Js[l]) for l in sort(collect(keys(ms.Js))) if l != 0)..., + (fill(l - 1, ms.Js[l]) for l in sort(collect(keys(ms.Js))) if l != 0)..., + ) +end + +transform_noise(ms::MultilevelScheduler, noise::AbstractMatrix{FT}) where {FT <: Real} = begin + @assert size(noise, 2) == ms.N_indep + + noise[:, vcat(1:ms.N_indep, ms.Js[0]+1:ms.N_indep)] +end + +statistic_groups(ms::MultilevelScheduler) = begin + groups = [] + + offset = ms.N_indep - ms.Js[0] + + index = 0; + for level in sort(collect(keys(ms.Js))) + J = ms.Js[level] + push!(groups, (index+1:index+J, 1)) + if level > 0 + push!(groups, (index+offset+1:index+offset+J, -1)) + end + + index += J + end + + groups +end + + +get_N_ens(sls::SingleLevelScheduler) = sls.N_ens + +get_N_indep(sls::SingleLevelScheduler) = sls.N_ens + +levels(sls::SingleLevelScheduler) = fill(sls.level, sls.N_ens) + +transform_noise(sls::SingleLevelScheduler, noise::AbstractMatrix{FT}) where {FT <: Real} = noise + +statistic_groups(sls::SingleLevelScheduler) = [(1:sls.N_ens, 1)] diff --git a/src/SampleStatistics.jl b/src/SampleStatistics.jl new file mode 100644 index 000000000..c8bc5e3c2 --- /dev/null +++ b/src/SampleStatistics.jl @@ -0,0 +1,22 @@ +# included in EnsembleKalmanProcess.jl + +export compute_mean, compute_cov + +function posdef(mat) + S, V = eigen(mat) + V = V[:, (S .> 0)] + S = S[S .> 0] + V * diagm(S) * V' +end + +function compute_mean(ekp::EnsembleKalmanProcess, u) + foldl(statistic_groups(ekp.level_scheduler); init = 0) do acc, (indices, multiplier) + multiplier * mean(u[:, indices]; dims = 2) .+ acc + end +end + +function compute_cov(ekp::EnsembleKalmanProcess, u, g; corrected) + foldl(statistic_groups(ekp.level_scheduler); init = 0) do acc, (indices, multiplier) + multiplier * cov([u; g][:, indices]; corrected = false, dims = 2) .+ acc + end +end diff --git a/test/Multilevel/runtests.jl b/test/Multilevel/runtests.jl new file mode 100644 index 000000000..e337bc1fa --- /dev/null +++ b/test/Multilevel/runtests.jl @@ -0,0 +1,94 @@ +using Distributions +using LinearAlgebra +using Random +using Test + +using EnsembleKalmanProcesses +using EnsembleKalmanProcesses.ParameterDistributions +const EKP = EnsembleKalmanProcesses + +Random.seed!(123) +artificial_noise = randn(2) + +forward_model(u; level) = begin + p(x) = u[2]*x + exp(-u[1])*(-x^2/2 + x/2) + exact_solution = [p(.25); p(.75)] + exact_solution + u.^2/norm(u.^2) .* artificial_noise / (10 * 2^(level+1)) +end + +@testset "Multilevel" begin + # Seed for pseudo-random number generator + rng_seed = 42 + rng = Random.MersenneTwister(rng_seed) + + priors = [ + ParameterDistribution(Parameterized(Normal(-3, 1)), no_constraint(), "u1"), + ParameterDistribution(Parameterized(Normal(105, 5)), no_constraint(), "u2"), + ] + prior = combine_distributions(priors) + + y = [27.5; 79.7] + Γ = 0.01 * I + N_iter = 10 + lrs = DefaultScheduler(1) + + # Approximate mean-field limit + println("Approximating mean field") + N_ens = 200_000 + level = 30 + initial_ensemble = EKP.construct_initial_ensemble(rng, prior, N_ens) + eki = EKP.EnsembleKalmanProcess(initial_ensemble, y, Γ, Inversion(); rng = rng, scheduler=lrs) + for i in 1:N_iter + u = get_u_final(eki) + g_ens = hcat((forward_model(u[:,j]; level) for j in 1:N_ens)...) + EKP.update_ensemble!(eki, g_ens) + end + mean_field_limit_approx_mean = compute_mean(eki, get_u_final(eki)) + + # Single-level approximation + println("Approximating single-level") + num_avg = 5 + single_level_cost = 2^20 + single_level_errors = map(4:10) do level + N_ens = floor(Int, single_level_cost / 2^level) + mean(1:num_avg) do _ + initial_ensemble = EKP.construct_initial_ensemble(rng, prior, N_ens) + eki = EKP.EnsembleKalmanProcess(initial_ensemble, y, Γ, Inversion(); rng = rng, scheduler=lrs) + for i in 1:N_iter + u = get_u_final(eki) + g_ens = hcat((forward_model(u[:,j]; level) for j in 1:N_ens)...) + EKP.update_ensemble!(eki, g_ens) + end + norm(mean_field_limit_approx_mean - compute_mean(eki, get_u_final(eki))) + end + end + + # Multilevel approximation + println("Approximating multilevel") + max_level = 9 + Js = Dict(level => floor(Int, 20 * 2^((max_level - level) * 4/3)) for level in 0:max_level) + level_scheduler = MultilevelScheduler(Js) + N_ens = get_N_ens(level_scheduler) + num_avg = 5 + multilevel_error = mean(1:num_avg) do _ + initial_ensemble = EKP.construct_initial_ensemble(rng, prior, get_N_indep(level_scheduler)) + initial_ensemble = transform_noise(level_scheduler, initial_ensemble) + eki = EKP.EnsembleKalmanProcess(initial_ensemble, y, Γ, Inversion(); rng, level_scheduler, scheduler=lrs) + for i in 1:N_iter + u = get_u_final(eki) + g_ens = hcat((forward_model(u[:,j]; level) for (j, level) in zip(1:N_ens, levels(level_scheduler)))...) + EKP.update_ensemble!(eki, g_ens) + end + norm(mean_field_limit_approx_mean - compute_mean(eki, get_u_final(eki))) + end + multilevel_cost = reduce(Js; init = 0) do acc, (level, J) + acc + J * 2^level + (level == 0 ? 0 : J * 2^(level - 1)) + end + + println(multilevel_cost, " ", multilevel_error) + println(single_level_cost, " ", single_level_errors) + @assert multilevel_cost < 0.5 * single_level_cost + for single_level_error in single_level_errors + @assert multilevel_error < 0.5 * single_level_error + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 71ce1c39b..d1da0ef2d 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -34,6 +34,7 @@ end "TOMLInterface", "SparseInversion", "Inflation", + "Multilevel", ] if all_tests || has_submodule(submodule) || "EnsembleKalmanProcesses" in ARGS include_test(submodule) From 4033e5241b4b47e346635f90dedf9b65f7fb80ac Mon Sep 17 00:00:00 2001 From: Arne Bouillon Date: Thu, 9 Nov 2023 15:59:03 +0100 Subject: [PATCH 3/3] Currently limit Multilevel to EKI --- src/EnsembleKalmanInversion.jl | 23 ++++++++------- src/EnsembleKalmanProcess.jl | 16 +++++++++-- src/EnsembleTransformKalmanInversion.jl | 2 +- src/Multilevel.jl | 38 +++++++++++++++++++------ src/SampleStatistics.jl | 14 +++++---- src/SparseEnsembleKalmanInversion.jl | 2 +- test/EnsembleKalmanProcess/runtests.jl | 6 ++-- test/Multilevel/runtests.jl | 4 +-- 8 files changed, 70 insertions(+), 35 deletions(-) diff --git a/src/EnsembleKalmanInversion.jl b/src/EnsembleKalmanInversion.jl index 7c3b20c28..ac8ba9aa2 100644 --- a/src/EnsembleKalmanInversion.jl +++ b/src/EnsembleKalmanInversion.jl @@ -19,14 +19,13 @@ Provides a failsafe update that - updates the successful ensemble according to the EKI update, - updates the failed ensemble by sampling from the updated successful ensemble. """ -function FailureHandler(process::Inversion, method::SampleSuccGauss) +function FailureHandler(::Inversion, ::SampleSuccGauss) function failsafe_update(ekp, u, g, y, obs_noise_cov, failed_ens) - successful_ens = filter(x -> !(x in failed_ens), collect(1:size(g, 2))) - n_failed = length(failed_ens) - u[:, successful_ens] = - eki_update(ekp, u[:, successful_ens], g[:, successful_ens], y[:, successful_ens], obs_noise_cov) + (failed_ens, sample_transform, sample_dim) = get_correlations(ekp.level_scheduler, failed_ens) + u = eki_update(ekp, u, g, y, obs_noise_cov; ignored_indices = failed_ens) if !isempty(failed_ens) - u[:, failed_ens] = sample_empirical_gaussian(ekp.rng, u[:, successful_ens], n_failed) + new_samples = sample_empirical_gaussian(ekp.rng, ekp, u, sample_dim; ignored_indices = failed_ens) + u[:, failed_ens] = new_samples[:, sample_transform] end return u end @@ -39,7 +38,8 @@ end u::AbstractMatrix{FT}, g::AbstractMatrix{FT}, y::AbstractMatrix{FT}, - obs_noise_cov::Union{AbstractMatrix{CT}, UniformScaling{CT}}, + obs_noise_cov::Union{AbstractMatrix{CT}, UniformScaling{CT}}; + ignored_indices = [], ) where {FT <: Real, IT, CT <: Real} Returns the updated parameter vectors given their current values and @@ -53,10 +53,11 @@ function eki_update( u::AbstractMatrix{FT}, g::AbstractMatrix{FT}, y::AbstractMatrix{FT}, - obs_noise_cov::Union{AbstractMatrix{CT}, UniformScaling{CT}}, + obs_noise_cov::Union{AbstractMatrix{CT}, UniformScaling{CT}}; + ignored_indices = [], ) where {FT <: Real, IT, CT <: Real} - cov_est = compute_cov(ekp, u, g; corrected = false) # [(N_par + N_obs)×(N_par + N_obs)] + cov_est = compute_cov(ekp, [u; g]; corrected = false, ignored_indices) # [(N_par + N_obs)×(N_par + N_obs)] # Localization cov_localized = ekp.localizer.localize(cov_est) @@ -112,7 +113,7 @@ function update_ensemble!( N_obs = size(g, 1) if ekp.verbose - cov_init = get_cov_blocks(compute_cov(ekp, u, g; corrected = true))[1] + cov_init = compute_cov(ekp, u; corrected = true) if get_N_iterations(ekp) == 0 @info "Iteration 0 (prior)" @@ -150,7 +151,7 @@ function update_ensemble!( if ekp.verbose # Diagnostics - cov_new = get_cov_blocks(compute_cov(ekp, u, g; corrected = true))[1] + cov_new = compute_cov(ekp, u; corrected = true) @info "Covariance-weighted error: $(get_error(ekp)[end])\nCovariance trace: $(tr(cov_new))\nCovariance trace ratio (current/previous): $(tr(cov_new)/tr(cov_init))" end diff --git a/src/EnsembleKalmanProcess.jl b/src/EnsembleKalmanProcess.jl index c3c95e221..175822114 100644 --- a/src/EnsembleKalmanProcess.jl +++ b/src/EnsembleKalmanProcess.jl @@ -231,6 +231,10 @@ function EnsembleKalmanProcess( ls = if isnothing(level_scheduler) SingleLevelScheduler(N_ens, LevelInfinity()) else + if !(typeof(process) <: Inversion) + throw(ArgumentError("Only `Inversion` (EKI) can currently be used with multilevel Monte Carlo.")) + end + level_scheduler end @@ -517,20 +521,24 @@ get_error(ekp::EnsembleKalmanProcess) = ekp.err """ sample_empirical_gaussian( rng::AbstractRNG, + ekp::EnsembleKalmanProcess, u::AbstractMatrix{FT}, n::IT; inflation::Union{FT, Nothing} = nothing, + ignored_indices = [], ) where {FT <: Real, IT <: Int} Returns `n` samples from an empirical Gaussian based on point estimates `u`, adding inflation if the covariance is singular. """ function sample_empirical_gaussian( rng::AbstractRNG, + ekp::EnsembleKalmanProcess, u::AbstractMatrix{FT}, n::IT; inflation::Union{FT, Nothing} = nothing, + ignored_indices = [], ) where {FT <: Real, IT <: Int} - cov_u_new = Symmetric(cov(u, dims = 2)) + cov_u_new = Symmetric(posdef(compute_cov(ekp, u; corrected = true, ignored_indices))) if !isposdef(cov_u_new) @warn string("Sample covariance matrix over ensemble is singular.", "\n Applying variance inflation.") if isnothing(inflation) @@ -539,16 +547,18 @@ function sample_empirical_gaussian( end cov_u_new = cov_u_new + inflation * I end - mean_u_new = mean(u, dims = 2) + mean_u_new = compute_mean(ekp, u; ignored_indices) return mean_u_new .+ sqrt(cov_u_new) * rand(rng, MvNormal(zeros(length(mean_u_new[:])), I), n) end function sample_empirical_gaussian( + ekp::EnsembleKalmanProcess, u::AbstractMatrix{FT}, n::IT; inflation::Union{FT, Nothing} = nothing, + ignored_indices = [], ) where {FT <: Real, IT <: Int} - return sample_empirical_gaussian(Random.GLOBAL_RNG, u, n, inflation = inflation) + return sample_empirical_gaussian(Random.GLOBAL_RNG, ekp, u, n; inflation, ignored_indices) end diff --git a/src/EnsembleTransformKalmanInversion.jl b/src/EnsembleTransformKalmanInversion.jl index 589f33481..aff13093d 100644 --- a/src/EnsembleTransformKalmanInversion.jl +++ b/src/EnsembleTransformKalmanInversion.jl @@ -32,7 +32,7 @@ function FailureHandler(process::TransformInversion, method::SampleSuccGauss) n_failed = length(failed_ens) u[:, successful_ens] = etki_update(ekp, u[:, successful_ens], g[:, successful_ens], y, obs_noise_cov) if !isempty(failed_ens) - u[:, failed_ens] = sample_empirical_gaussian(ekp.rng, u[:, successful_ens], n_failed) + u[:, failed_ens] = sample_empirical_gaussian(ekp.rng, ekp, u, n_failed; ignored_indices = failed_ens) end return u end diff --git a/src/Multilevel.jl b/src/Multilevel.jl index 34230c859..9a4e93baf 100644 --- a/src/Multilevel.jl +++ b/src/Multilevel.jl @@ -39,18 +39,12 @@ levels(ms::MultilevelScheduler) = begin ) end -transform_noise(ms::MultilevelScheduler, noise::AbstractMatrix{FT}) where {FT <: Real} = begin - @assert size(noise, 2) == ms.N_indep - - noise[:, vcat(1:ms.N_indep, ms.Js[0]+1:ms.N_indep)] -end - statistic_groups(ms::MultilevelScheduler) = begin groups = [] offset = ms.N_indep - ms.Js[0] - index = 0; + index = 0 for level in sort(collect(keys(ms.Js))) J = ms.Js[level] push!(groups, (index+1:index+J, 1)) @@ -64,6 +58,30 @@ statistic_groups(ms::MultilevelScheduler) = begin groups end +transform_noise(ms::MultilevelScheduler, noise::AbstractMatrix{FT}) where {FT <: Real} = begin + @assert size(noise, 2) == ms.N_indep + + noise[:, vcat(1:ms.N_indep, ms.Js[0]+1:ms.N_indep)] +end + +get_correlations(ms::MultilevelScheduler, indices::AbstractVector{IT}) where {IT <: Integer} = begin + num_uncorrelated = 0 + new_indices = map(indices) do i + if i <= ms.Js[0] + num_uncorrelated += 1 + i # There is no correlated index + elseif i <= ms.N_indep + i + (ms.N_indep - ms.Js[0]) + else + i - (ms.N_indep - ms.Js[0]) + end + end + all_indices = sort!(unique!(vcat(indices, new_indices))) + num_correlated = (length(all_indices) - num_uncorrelated) ÷ 2 + noise_dim = num_correlated + num_uncorrelated + all_indices, hcat(1:num_uncorrelated, num_uncorrelated+1:noise_dim, num_uncorrelated+1:noise_dim), noise_dim +end + get_N_ens(sls::SingleLevelScheduler) = sls.N_ens @@ -71,6 +89,8 @@ get_N_indep(sls::SingleLevelScheduler) = sls.N_ens levels(sls::SingleLevelScheduler) = fill(sls.level, sls.N_ens) -transform_noise(sls::SingleLevelScheduler, noise::AbstractMatrix{FT}) where {FT <: Real} = noise - statistic_groups(sls::SingleLevelScheduler) = [(1:sls.N_ens, 1)] + +transform_noise(::SingleLevelScheduler, noise::AbstractMatrix{FT}) where {FT <: Real} = noise + +get_correlations(::SingleLevelScheduler, indices::AbstractVector{IT}) where {IT <: Integer} = (indices, 1:length(indices), length(indices)) diff --git a/src/SampleStatistics.jl b/src/SampleStatistics.jl index c8bc5e3c2..4ca1b4f39 100644 --- a/src/SampleStatistics.jl +++ b/src/SampleStatistics.jl @@ -9,14 +9,16 @@ function posdef(mat) V * diagm(S) * V' end -function compute_mean(ekp::EnsembleKalmanProcess, u) - foldl(statistic_groups(ekp.level_scheduler); init = 0) do acc, (indices, multiplier) - multiplier * mean(u[:, indices]; dims = 2) .+ acc +function compute_mean(ekp::EnsembleKalmanProcess, x; ignored_indices = []) + reduce(statistic_groups(ekp.level_scheduler); init = 0) do acc, (indices, multiplier) + indices = setdiff(indices, ignored_indices) + multiplier * mean(x[:, indices]; dims = 2) .+ acc end end -function compute_cov(ekp::EnsembleKalmanProcess, u, g; corrected) - foldl(statistic_groups(ekp.level_scheduler); init = 0) do acc, (indices, multiplier) - multiplier * cov([u; g][:, indices]; corrected = false, dims = 2) .+ acc +function compute_cov(ekp::EnsembleKalmanProcess, x; corrected, ignored_indices = []) + reduce(statistic_groups(ekp.level_scheduler); init = 0) do acc, (indices, multiplier) + indices = setdiff(indices, ignored_indices) + multiplier * cov(x[:, indices]; corrected, dims = 2) .+ acc end end diff --git a/src/SparseEnsembleKalmanInversion.jl b/src/SparseEnsembleKalmanInversion.jl index ce9262158..ebe9f1bac 100644 --- a/src/SparseEnsembleKalmanInversion.jl +++ b/src/SparseEnsembleKalmanInversion.jl @@ -56,7 +56,7 @@ function FailureHandler(process::SparseInversion, method::SampleSuccGauss) u[:, successful_ens] = sparse_eki_update(ekp, u[:, successful_ens], g[:, successful_ens], y[:, successful_ens], obs_noise_cov) if !isempty(failed_ens) - u[:, failed_ens] = sample_empirical_gaussian(ekp.rng, u[:, successful_ens], n_failed) + u[:, failed_ens] = sample_empirical_gaussian(ekp.rng, ekp, u, n_failed; ignored_indices = failed_ens) end return u end diff --git a/test/EnsembleKalmanProcess/runtests.jl b/test/EnsembleKalmanProcess/runtests.jl index 9e9ab4963..ebaabbe34 100644 --- a/test/EnsembleKalmanProcess/runtests.jl +++ b/test/EnsembleKalmanProcess/runtests.jl @@ -960,7 +960,9 @@ end rng = Random.MersenneTwister(rng_seed) u = rand(10, 4) + ekp = EKP.EnsembleKalmanProcess(u, [1.;], [1.;;], Inversion()) @test_logs (:warn, r"Sample covariance matrix over ensemble is singular.") match_mode = :any sample_empirical_gaussian( + ekp, u, 2, ) @@ -968,8 +970,8 @@ end u2 = rand(rng, 5, 20) @test all( isapprox.( - sample_empirical_gaussian(copy(rng), u2, 2), - sample_empirical_gaussian(copy(rng), u2, 2, inflation = 0.0); + sample_empirical_gaussian(copy(rng), ekp, u2, 2), + sample_empirical_gaussian(copy(rng), ekp, u2, 2, inflation = 0.0); atol = 1e-8, ), ) diff --git a/test/Multilevel/runtests.jl b/test/Multilevel/runtests.jl index e337bc1fa..a2ac456f8 100644 --- a/test/Multilevel/runtests.jl +++ b/test/Multilevel/runtests.jl @@ -87,8 +87,8 @@ end println(multilevel_cost, " ", multilevel_error) println(single_level_cost, " ", single_level_errors) - @assert multilevel_cost < 0.5 * single_level_cost + @test multilevel_cost < 0.5 * single_level_cost for single_level_error in single_level_errors - @assert multilevel_error < 0.5 * single_level_error + @test multilevel_error < single_level_error end end