From a753731486f1bd672ecfeeca9a345e6e76297318 Mon Sep 17 00:00:00 2001 From: Alberto Mercurio Date: Wed, 6 May 2026 13:29:30 +0200 Subject: [PATCH 1/6] Share same cache among similar operators in `AddedOperator` Co-authored-by: Copilot --- src/SciMLOperators.jl | 4 +++- src/basic.jl | 38 +++++++++++++++++++++++++---- src/func.jl | 7 ++++++ src/interface.jl | 56 +++++++++++++++++++++++++++++++++++++++++++ src/tensor.jl | 24 +++++++++++++++++++ test/basic.jl | 50 ++++++++++++++++++++++++++++++++++++++ 6 files changed, 174 insertions(+), 5 deletions(-) diff --git a/src/SciMLOperators.jl b/src/SciMLOperators.jl index 2a756e71..50e859ab 100644 --- a/src/SciMLOperators.jl +++ b/src/SciMLOperators.jl @@ -234,7 +234,9 @@ export export update_coefficients!, update_coefficients, isconstant, iscached, - cache_operator, issquare, + cache_operator, cache_operator_hinted, + update_cache, + issquare, islinear, concretize, isconvertible, has_adjoint, diff --git a/src/basic.jl b/src/basic.jl index c7a34060..9c500a4f 100644 --- a/src/basic.jl +++ b/src/basic.jl @@ -644,13 +644,30 @@ end islinear(L::AddedOperator) = all(islinear, getops(L)) Base.iszero(L::AddedOperator) = all(iszero, getops(L)) has_adjoint(L::AddedOperator) = all(has_adjoint, L.ops) +LinearAlgebra.ishermitian(L::AddedOperator) = all(ishermitian, L.ops) @generated function cache_internals(L::AddedOperator, v::AbstractVecOrMat) ops_types = L.parameters[2].parameters N = length(ops_types) + + # If multiple sub-operators share the same outermost type constructor (wrapper), we can cache one of them and reuse the cache for the others. This is because operators with the same wrapper will have the same caching structure, so we can avoid redundant caching work. The `donor` tuple identifies which operator's cache to use for each sub-operator. + + donor = ntuple(i -> findfirst(j -> ops_types[j].name.wrapper === ops_types[i].name.wrapper, 1:N), N) + + # Unique variable names for each cached sub-operator + syms = ntuple(i -> Symbol(:op_, i), N) + + # Emit cache_operator for donors, cache_operator_hinted for the rest + stmts = ntuple(N) do i + d = donor[i] + d == i ? + :($(syms[i]) = cache_operator(L.ops[$i], v)) : + :($(syms[i]) = cache_operator_hinted(L.ops[$i], getcache($(syms[d])), v)) + end + return quote - ops = Base.@ntuple $N i -> cache_operator(L.ops[i], v) - return AddedOperator(ops) + $(stmts...) + return AddedOperator(($(syms...),)) end end @@ -874,6 +891,7 @@ function update_coefficients(L::ComposedOperator, u, p, t; kwargs...) end getops(L::ComposedOperator) = L.ops +getcache(op::ComposedOperator) = op.cache # Copy method to avoid aliasing function Base.copy(L::ComposedOperator) @@ -939,6 +957,16 @@ end end end +function _get_cache_shapes(L::ComposedOperator, v::AbstractVecOrMat) + N = length(L.ops) + res = if v isa AbstractMatrix + ntuple(i -> (size(L.ops[i], 1), size(v, 2)), Val(N)) + else + ntuple(i -> (size(L.ops[i], 1),), Val(N)) + end + return res +end + @generated function cache_self(L::ComposedOperator, v::AbstractVecOrMat) N = length(L.parameters[2].parameters) # Number of operators @@ -1199,6 +1227,7 @@ function update_coefficients(L::InvertedOperator, u, p, t; kwargs...) end getops(L::InvertedOperator) = (L.L,) +getcache(op::InvertedOperator) = op.cache islinear(L::InvertedOperator) = islinear(L.L) isconvertible(::InvertedOperator) = false @@ -1229,9 +1258,10 @@ function Base.copy(L::InvertedOperator) ) end +_get_cache_shapes(::InvertedOperator, v::AbstractVecOrMat) = size(v) + function cache_self(L::InvertedOperator, u::AbstractVecOrMat) - cache = zero(u) - @reset L.cache = cache + @reset L.cache = zero(u) return L end diff --git a/src/func.jl b/src/func.jl index 19cd14bd..eeaa28c9 100644 --- a/src/func.jl +++ b/src/func.jl @@ -590,6 +590,13 @@ function _cache_operator(L::FunctionOperator, u::AbstractArray) return L end +function _get_cache_shapes(L::FunctionOperator, v::AbstractVecOrMat) + return (L.traits.sizes[1], L.traits.sizes[2]) +end + +getcache(op::FunctionOperator) = op.cache +update_cache(L::FunctionOperator, new_cache) = set_cache(L, new_cache) + # fix method amg bw AbstractArray, AbstractVecOrMat cache_self(L::FunctionOperator, v::AbstractArray) = _cache_self(L, v) cache_self(L::FunctionOperator, v::AbstractVecOrMat) = _cache_self(L, v) diff --git a/src/interface.jl b/src/interface.jl index 2801dc70..dc201749 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -139,6 +139,15 @@ getops(L) = () """ $SIGNATURES +Return the current cache held by `op`, or `nothing` if it holds none. +New operator types get the safe `nothing` default automatically; override +for types that store a shareable `.cache` field. +""" +getcache(::AbstractSciMLOperator) = nothing + +""" +$SIGNATURES + Checks whether `L` has preallocated caches for inplace evaluations. """ function iscached(L::AbstractSciMLOperator) @@ -179,6 +188,53 @@ end cache_self(L::AbstractSciMLOperator, ::AbstractVecOrMat) = L cache_internals(L::AbstractSciMLOperator, ::AbstractVecOrMat) = L +""" +$SIGNATURES + +Return the expected cache shape specification for `L` given input `v`. +Returns `nothing` if `L` requires no cache. +The return value can be a single `NTuple{N,Int}` (single-array cache), +a `Tuple` of such shapes (multi-array cache), or `nothing` for absent slots. +""" +_get_cache_shapes(::AbstractSciMLOperator, ::AbstractVecOrMat) = nothing + +""" +$SIGNATURES + +Check whether `hint` is shape-compatible with `shapes` (as returned by `_get_cache_shapes`). +Uses `zip` to avoid integer-indexed Tuple access. Reads only array metadata — safe on GPU. +""" +_cache_compatible(hint, ::Nothing) = false +_cache_compatible(::Nothing, shapes) = false +_cache_compatible(::Nothing, ::Nothing) = false +_cache_compatible(hint::AbstractArray, shape::Tuple{Vararg{Int}}) = size(hint) == shape +function _cache_compatible(hint::Tuple, shapes::Tuple) + length(hint) != length(shapes) && return false + return all(((h, s),) -> _cache_compatible(h, s), zip(hint, shapes)) +end + +""" +$SIGNATURES + +Inject `new_cache` into `op` as its cache. Default uses `@reset op.cache = new_cache`. +Override for operators that don't use the `.cache` field convention. +""" +update_cache(op::AbstractSciMLOperator, new_cache) = @reset op.cache = new_cache + +""" +$SIGNATURES + +Like `cache_operator`, but tries to reuse `hint` (an existing cache from a compatible operator) +instead of allocating new buffers. Falls back to `cache_operator` when `hint` is not compatible. +""" +function cache_operator_hinted(op::AbstractSciMLOperator, hint, v::AbstractVecOrMat) + if _cache_compatible(hint, _get_cache_shapes(op, v)) + op = update_cache(op, hint) + return cache_internals(op, v) + end + return cache_operator(op, v) +end + ### # operator traits ### diff --git a/src/tensor.jl b/src/tensor.jl index 34d8be87..aba70a81 100644 --- a/src/tensor.jl +++ b/src/tensor.jl @@ -175,6 +175,7 @@ function update_coefficients(L::TensorProductOperator, u, p, t; kwargs...) end getops(L::TensorProductOperator) = L.ops +getcache(op::TensorProductOperator) = op.cache # Copy method to avoid aliasing function Base.copy(L::TensorProductOperator) @@ -362,6 +363,29 @@ function Base.:\(L::TensorProductOperator, v::AbstractVecOrMat) return v isa AbstractMatrix ? reshape(V, (n, k)) : reshape(V, (n,)) end +function _get_cache_shapes(L::TensorProductOperator, v::AbstractVecOrMat) + outer, inner = L.ops + outer isa IdentityOperator && return nothing + + mi, ni = size(inner) + mo, no = size(outer) + k = size(v, 2) + + s1 = (mi, no * k) + s2 = (no, mi, k) + s3 = (mo, mi * k) + s4 = (mo * mi, k) + + if reduce(&, issquare.(L.ops)) + return (s1, s2, s3, s4, s1, s2, s3) + else + s5 = (ni, mo * k) + s6 = (mo, ni, k) + s7 = (no, ni * k) + return (s1, s2, s3, s4, s5, s6, s7) + end +end + function cache_self(L::TensorProductOperator, v::AbstractVecOrMat) outer, inner = L.ops diff --git a/test/basic.jl b/test/basic.jl index aaab3567..9a08d7da 100644 --- a/test/basic.jl +++ b/test/basic.jl @@ -408,6 +408,56 @@ end end end +@testset "AddedOperator cache sharing (Composed, Tensor, Composed, Tensor, Tensor)" begin + using SciMLOperators: cache_operator_hinted + + m1, m2 = 2, 4 # m1 * m2 == N + + # C1 and C2: same wrapper (ComposedOperator), different inner type params + # C1 = A1*B1 → ops::Tuple{MatrixOperator, MatrixOperator} + # C2 = A2*B2' → ops::Tuple{MatrixOperator, AdjointOperator{…}} + A1 = MatrixOperator(rand(N, N)); B1 = MatrixOperator(rand(N, N)) + A2 = MatrixOperator(rand(N, N)); B2 = MatrixOperator(rand(N, N)) + C1 = A1 * B1 + C2 = A2 * B2' + + # T1, T2, T3: same wrapper (TensorProductOperator), different inner type params + # T1 = Ao ⊗ Ai, T2 = Ao' ⊗ Ai, T3 = Ao ⊗ Ai' + Ao = MatrixOperator(rand(m1, m1)); Ai = MatrixOperator(rand(m2, m2)) + T1 = TensorProductOperator(Ao, Ai) + T2 = TensorProductOperator(Ao', Ai) + T3 = TensorProductOperator(Ao, Ai') + + L = C1 + T1 + C2 + T2 + T3 + A1 + A2 + @test L isa AddedOperator + @test length(L.ops) == 7 + + # Matrix input + u = rand(N, K) + L = cache_operator(L, u) + + # Correctness: the cached operator gives the right result + expected = C1 * u + T1 * u + C2 * u + T2 * u + T3 * u + A1 * u + A2 * u + @test L * u ≈ expected + + # Cache sharing: same-wrapper sub-operators with compatible sizes share physical buffers + @test L.ops[3].cache === L.ops[1].cache # C2 (A2*B2') reuses C1's cache (same wrapper) + @test L.ops[4].cache === L.ops[2].cache # T2 (Ao'⊗Ai) reuses T1's cache (same wrapper) + @test L.ops[5].cache === L.ops[2].cache # T3 (Ao⊗Ai') reuses T1's cache (same wrapper) + + # Vector input + v = rand(N) + L = cache_operator(L, v) + + expected = C1 * v + T1 * v + C2 * v + T2 * v + T3 * v + A1 * v + A2 * v + @test L * v ≈ expected + + @test L.ops[3].cache === L.ops[1].cache + @test L.ops[4].cache === L.ops[2].cache + @test L.ops[5].cache === L.ops[2].cache +end + + @testset "ComposedOperator" begin A = rand(N, N) B = rand(N, N) From 3c8a12415881de76c47082c59ed8c2c314406d13 Mon Sep 17 00:00:00 2001 From: Alberto Mercurio Date: Wed, 6 May 2026 14:24:57 +0200 Subject: [PATCH 2/6] Remove ishermitian method --- src/basic.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/basic.jl b/src/basic.jl index 9c500a4f..54b55ab6 100644 --- a/src/basic.jl +++ b/src/basic.jl @@ -644,7 +644,6 @@ end islinear(L::AddedOperator) = all(islinear, getops(L)) Base.iszero(L::AddedOperator) = all(iszero, getops(L)) has_adjoint(L::AddedOperator) = all(has_adjoint, L.ops) -LinearAlgebra.ishermitian(L::AddedOperator) = all(ishermitian, L.ops) @generated function cache_internals(L::AddedOperator, v::AbstractVecOrMat) ops_types = L.parameters[2].parameters From a86a6b33203650fc86e329d5b8f2dd0bce2fce83 Mon Sep 17 00:00:00 2001 From: Alberto Mercurio Date: Wed, 6 May 2026 14:45:16 +0200 Subject: [PATCH 3/6] Reuse `_get_cache_shapes` in `cache_self` Co-authored-by: Copilot --- src/tensor.jl | 32 +++++++++++++++----------------- 1 file changed, 15 insertions(+), 17 deletions(-) diff --git a/src/tensor.jl b/src/tensor.jl index aba70a81..871c8972 100644 --- a/src/tensor.jl +++ b/src/tensor.jl @@ -387,29 +387,27 @@ function _get_cache_shapes(L::TensorProductOperator, v::AbstractVecOrMat) end function cache_self(L::TensorProductOperator, v::AbstractVecOrMat) - outer, inner = L.ops - - mi, ni = size(inner) - mo, no = size(outer) - k = size(v, 2) + shapes = _get_cache_shapes(L, v) - is_outer_identity = outer isa IdentityOperator + # outer is IdentityOperator — no buffers needed + if isnothing(shapes) + @reset L.cache = (nothing, nothing, nothing, nothing, nothing, nothing, nothing) + return L + end - # 3 arg mul! - c1 = is_outer_identity ? nothing : lmul!(false, similar(v, (mi, no * k))) # c1 = inner * v - c2 = is_outer_identity ? nothing : lmul!(false, similar(v, (no, mi, k))) # permute (2, 1, 3) - c3 = is_outer_identity ? nothing : lmul!(false, similar(v, (mo, mi * k))) # c3 = outer * c2 + s1, s2, s3, s4, s5, s6, s7 = shapes - # 5 arg mul! - c4 = is_outer_identity ? nothing : lmul!(false, similar(v, (mo * mi, k))) # cache v in 5 arg mul! + c1 = lmul!(false, similar(v, s1)) # inner * v (3-arg mul!) + c2 = lmul!(false, similar(v, s2)) # permute (2,1,3) + c3 = lmul!(false, similar(v, s3)) # outer * c2 + c4 = lmul!(false, similar(v, s4)) # copy of w for 5-arg mul! - # 3 arg ldiv! if mapreduce(issquare, &, L.ops) - c5, c6, c7 = c1, c2, c3 + c5, c6, c7 = c1, c2, c3 # square case: ldiv! reuses mul! buffers else - c5 = lmul!(false, similar(v, (ni, mo * k))) # c5 = inner \ v - c6 = lmul!(false, similar(v, (mo, ni, k))) # permute (2, 1, 3) - c7 = lmul!(false, similar(v, (no, ni * k))) # c7 = outer \ c6 + c5 = lmul!(false, similar(v, s5)) # inner \ v (3-arg ldiv!) + c6 = lmul!(false, similar(v, s6)) # permute (2,1,3) + c7 = lmul!(false, similar(v, s7)) # outer \ c6 end @reset L.cache = (c1, c2, c3, c4, c5, c6, c7) From d803c2fbac893d3f258f82a1de3a8ae2aafdfd26 Mon Sep 17 00:00:00 2001 From: Alberto Mercurio Date: Wed, 6 May 2026 15:27:15 +0200 Subject: [PATCH 4/6] Format code --- test/basic.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/basic.jl b/test/basic.jl index 9a08d7da..74f27d9d 100644 --- a/test/basic.jl +++ b/test/basic.jl @@ -424,9 +424,9 @@ end # T1, T2, T3: same wrapper (TensorProductOperator), different inner type params # T1 = Ao ⊗ Ai, T2 = Ao' ⊗ Ai, T3 = Ao ⊗ Ai' Ao = MatrixOperator(rand(m1, m1)); Ai = MatrixOperator(rand(m2, m2)) - T1 = TensorProductOperator(Ao, Ai) + T1 = TensorProductOperator(Ao, Ai) T2 = TensorProductOperator(Ao', Ai) - T3 = TensorProductOperator(Ao, Ai') + T3 = TensorProductOperator(Ao, Ai') L = C1 + T1 + C2 + T2 + T3 + A1 + A2 @test L isa AddedOperator @@ -437,7 +437,7 @@ end L = cache_operator(L, u) # Correctness: the cached operator gives the right result - expected = C1 * u + T1 * u + C2 * u + T2 * u + T3 * u + A1 * u + A2 * u + expected = C1 * u + T1 * u + C2 * u + T2 * u + T3 * u + A1 * u + A2 * u @test L * u ≈ expected # Cache sharing: same-wrapper sub-operators with compatible sizes share physical buffers From 33017d56db013a8b218ca51f6c15ed9137acaee9 Mon Sep 17 00:00:00 2001 From: Alberto Mercurio Date: Tue, 19 May 2026 04:11:07 +0200 Subject: [PATCH 5/6] Fix issues --- src/basic.jl | 60 +++++++++++++++++++++++++++++++++---------- src/func.jl | 9 +++++++ src/interface.jl | 21 +++++++++------ test/basic.jl | 67 +++++++++++++++++++++++++++++++----------------- 4 files changed, 112 insertions(+), 45 deletions(-) diff --git a/src/basic.jl b/src/basic.jl index 54b55ab6..82a6eae3 100644 --- a/src/basic.jl +++ b/src/basic.jl @@ -649,23 +649,54 @@ has_adjoint(L::AddedOperator) = all(has_adjoint, L.ops) ops_types = L.parameters[2].parameters N = length(ops_types) - # If multiple sub-operators share the same outermost type constructor (wrapper), we can cache one of them and reuse the cache for the others. This is because operators with the same wrapper will have the same caching structure, so we can avoid redundant caching work. The `donor` tuple identifies which operator's cache to use for each sub-operator. + # Within each wrapper-type group, the op whose eltype equals promote_type(group eltypes) + # becomes the donor — all others in the group reuse its cache instead of allocating their own. + # If promote_type yields a new type that no op in the group has (e.g. Float64 + ComplexF32 + # → ComplexF64), no op qualifies as donor and every op in that group caches independently. + # NOTE: the resulting cache aliasing (ops[i].cache === ops[j].cache) is safe only because + # AddedOperator's mul! evaluates sub-ops strictly serially. Any parallelism would require + # independent caches per sub-op. + wrappers = ntuple(i -> Base.typename(ops_types[i]).wrapper, N) + + donor = ntuple(N) do i + idx_eltype = foldl(enumerate(wrappers); init = (i, wrappers[i])) do a, b + Ta = eltype(ops_types[a[1]]) + Tb = eltype(ops_types[b[1]]) + T = promote_type(Ta, Tb) + eltype_condition_a = Ta === T + eltype_condition_b = Tb === T + + if b[2] !== wrappers[i] || !eltype_condition_b + return a + else + if eltype_condition_a && !eltype_condition_b + return a + elseif eltype_condition_b && !eltype_condition_a + return b + else + return b + end + end + end - donor = ntuple(i -> findfirst(j -> ops_types[j].name.wrapper === ops_types[i].name.wrapper, 1:N), N) + return idx_eltype[1] + end # Unique variable names for each cached sub-operator syms = ntuple(i -> Symbol(:op_, i), N) - # Emit cache_operator for donors, cache_operator_hinted for the rest - stmts = ntuple(N) do i - d = donor[i] - d == i ? - :($(syms[i]) = cache_operator(L.ops[$i], v)) : - :($(syms[i]) = cache_operator_hinted(L.ops[$i], getcache($(syms[d])), v)) - end + # Emit donors first so their symbols are defined before any asker references them + donor_stmts = [ + :($(syms[i]) = cache_operator(L.ops[$i], v)) for i in 1:N if donor[i] == i + ] + asker_stmts = [ + :($(syms[i]) = cache_operator_hinted(L.ops[$i], getcache($(syms[donor[i]])), v)) + for i in 1:N if donor[i] != i + ] return quote - $(stmts...) + $(donor_stmts...) + $(asker_stmts...) return AddedOperator(($(syms...),)) end end @@ -958,12 +989,13 @@ end function _get_cache_shapes(L::ComposedOperator, v::AbstractVecOrMat) N = length(L.ops) - res = if v isa AbstractMatrix - ntuple(i -> (size(L.ops[i], 1), size(v, 2)), Val(N)) + K = size(v, 2) + + if v isa AbstractMatrix + return ntuple(i -> i < N ? (size(L.ops[i + 1], 1), K) : (size(v, 1), K), Val(N)) else - ntuple(i -> (size(L.ops[i], 1),), Val(N)) + return ntuple(i -> i < N ? (size(L.ops[i + 1], 1),) : (size(v, 1),), Val(N)) end - return res end @generated function cache_self(L::ComposedOperator, v::AbstractVecOrMat) diff --git a/src/func.jl b/src/func.jl index eeaa28c9..312f49dd 100644 --- a/src/func.jl +++ b/src/func.jl @@ -591,6 +591,15 @@ function _cache_operator(L::FunctionOperator, u::AbstractArray) end function _get_cache_shapes(L::FunctionOperator, v::AbstractVecOrMat) + if L.traits.batch + M = size(L, 1) + if v isa AbstractMatrix + return (size(v), (M, size(v, 2))) + else + return (size(v), (M,)) + end + end + return (L.traits.sizes[1], L.traits.sizes[2]) end diff --git a/src/interface.jl b/src/interface.jl index dc201749..209d83e4 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -204,13 +204,18 @@ $SIGNATURES Check whether `hint` is shape-compatible with `shapes` (as returned by `_get_cache_shapes`). Uses `zip` to avoid integer-indexed Tuple access. Reads only array metadata — safe on GPU. """ -_cache_compatible(hint, ::Nothing) = false -_cache_compatible(::Nothing, shapes) = false -_cache_compatible(::Nothing, ::Nothing) = false -_cache_compatible(hint::AbstractArray, shape::Tuple{Vararg{Int}}) = size(hint) == shape -function _cache_compatible(hint::Tuple, shapes::Tuple) +_cache_compatible(hint, ::Nothing, v::AbstractArray) = false +_cache_compatible(::Nothing, shapes, v::AbstractArray) = false +_cache_compatible(::Nothing, ::Nothing, v::AbstractArray) = false +function _cache_compatible(hint::AbstractArray, shape::Tuple{Vararg{Int}}, v::AbstractArray) + # Check array device compatibility (CPU, GPU, etc.) + Base.typename(typeof(hint)).wrapper === Base.typename(typeof(v)).wrapper || return false + promote_type(eltype(v), eltype(hint)) === eltype(hint) || return false + size(hint) == shape || return false +end +function _cache_compatible(hint::Tuple, shapes::Tuple, v::AbstractArray) length(hint) != length(shapes) && return false - return all(((h, s),) -> _cache_compatible(h, s), zip(hint, shapes)) + return all(((h, s),) -> _cache_compatible(h, s, v), zip(hint, shapes)) end """ @@ -228,7 +233,7 @@ Like `cache_operator`, but tries to reuse `hint` (an existing cache from a compa instead of allocating new buffers. Falls back to `cache_operator` when `hint` is not compatible. """ function cache_operator_hinted(op::AbstractSciMLOperator, hint, v::AbstractVecOrMat) - if _cache_compatible(hint, _get_cache_shapes(op, v)) + if _cache_compatible(hint, _get_cache_shapes(op, v), v) op = update_cache(op, hint) return cache_internals(op, v) end @@ -240,7 +245,7 @@ end ### Base.size(A::AbstractSciMLOperator, d::Integer) = d <= 2 ? size(A)[d] : 1 -Base.eltype(::Type{AbstractSciMLOperator{T}}) where {T} = T +Base.eltype(::Type{<:AbstractSciMLOperator{T}}) where {T} = T Base.eltype(::AbstractSciMLOperator{T}) where {T} = T Base.promote_eltype(::AbstractSciMLOperator{<:T1}, ::AbstractSciMLOperator{<:T2}) where {T1, T2} = Base.promote_type(T1, T2) diff --git a/test/basic.jl b/test/basic.jl index 74f27d9d..69d4e1bd 100644 --- a/test/basic.jl +++ b/test/basic.jl @@ -409,7 +409,7 @@ end end @testset "AddedOperator cache sharing (Composed, Tensor, Composed, Tensor, Tensor)" begin - using SciMLOperators: cache_operator_hinted + using SciMLOperators: cache_operator_hinted, _get_cache_shapes m1, m2 = 2, 4 # m1 * m2 == N @@ -432,29 +432,50 @@ end @test L isa AddedOperator @test length(L.ops) == 7 - # Matrix input - u = rand(N, K) - L = cache_operator(L, u) - - # Correctness: the cached operator gives the right result - expected = C1 * u + T1 * u + C2 * u + T2 * u + T3 * u + A1 * u + A2 * u - @test L * u ≈ expected - - # Cache sharing: same-wrapper sub-operators with compatible sizes share physical buffers - @test L.ops[3].cache === L.ops[1].cache # C2 (A2*B2') reuses C1's cache (same wrapper) - @test L.ops[4].cache === L.ops[2].cache # T2 (Ao'⊗Ai) reuses T1's cache (same wrapper) - @test L.ops[5].cache === L.ops[2].cache # T3 (Ao⊗Ai') reuses T1's cache (same wrapper) - - # Vector input - v = rand(N) - L = cache_operator(L, v) - - expected = C1 * v + T1 * v + C2 * v + T2 * v + T3 * v + A1 * v + A2 * v - @test L * v ≈ expected + for input in (rand(N, K), rand(N)) + L = cache_operator(L, input) + expected = C1 * input + T1 * input + C2 * input + T2 * input + T3 * input + + A1 * input + A2 * input + + # Correctness: out-of-place (*) and in-place (mul!) paths + @test L * input ≈ expected + w = similar(input) + mul!(w, L, input) + @test w ≈ expected + + # Cache sharing: same-wrapper ops with compatible sizes share physical buffers + @test L.ops[3].cache === L.ops[1].cache # C2 (A2*B2') reuses C1's cache + @test L.ops[4].cache === L.ops[2].cache # T2 (Ao'⊗Ai) reuses T1's cache + @test L.ops[5].cache === L.ops[2].cache # T3 (Ao⊗Ai') reuses T1's cache + end - @test L.ops[3].cache === L.ops[1].cache - @test L.ops[4].cache === L.ops[2].cache - @test L.ops[5].cache === L.ops[2].cache + # --- Mixed-eltype: Float64 + ComplexF64 ComposedOperators --- + # promote_type(Float64, ComplexF64) = ComplexF64 = eltype(Ac) → Ac is donor, Ar reuses its cache + Ar = MatrixOperator(rand(Float64, N, N)) * MatrixOperator(rand(Float64, N, N)) + Ac = MatrixOperator(rand(ComplexF64, N, N)) * MatrixOperator(rand(ComplexF64, N, N)) + v_real = rand(Float64, N) + Lcs = cache_operator(Ar + Ac, v_real) + @test Lcs.ops[1].cache === Lcs.ops[2].cache # Ar reuses Ac's ComplexF64 cache + w_c = similar(v_real, ComplexF64) + mul!(w_c, Lcs, v_real) + @test w_c ≈ Ar * v_real + Ac * v_real + + # --- No sharing: Float64 + ComplexF32 → promote_type = ComplexF64 (neither op's eltype) --- + Af32 = MatrixOperator(rand(ComplexF32, N, N)) * MatrixOperator(rand(ComplexF32, N, N)) + Lf32s = cache_operator(Ar + Af32, rand(Float64, N)) + @test Lf32s.ops[1].cache !== Lf32s.ops[2].cache # independent caches + + # --- Non-square ComposedOperator: _get_cache_shapes must match cache_self's allocation --- + M1, M2, M3 = 5, 3, 4 + P = MatrixOperator(rand(M1, M2)); Q = MatrixOperator(rand(M2, M3)) + PQc = cache_operator(P * Q, rand(M3)) + + @test _get_cache_shapes(PQc, rand(M3)) == ((M2,), (M3,)) + @test map(size, PQc.cache) == ((M2,), (M3,)) + v_ns = rand(M3) + w_ns = zeros(M1) + mul!(w_ns, PQc, v_ns) + @test w_ns ≈ P * (Q * v_ns) end From 39a6b9a347cd51cde83ab37dcd7a65fd5ace7568 Mon Sep 17 00:00:00 2001 From: Alberto Mercurio Date: Tue, 19 May 2026 04:11:28 +0200 Subject: [PATCH 6/6] Format code --- src/basic.jl | 2 +- src/interface.jl | 2 +- test/basic.jl | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/basic.jl b/src/basic.jl index 82a6eae3..7b70c521 100644 --- a/src/basic.jl +++ b/src/basic.jl @@ -691,7 +691,7 @@ has_adjoint(L::AddedOperator) = all(has_adjoint, L.ops) ] asker_stmts = [ :($(syms[i]) = cache_operator_hinted(L.ops[$i], getcache($(syms[donor[i]])), v)) - for i in 1:N if donor[i] != i + for i in 1:N if donor[i] != i ] return quote diff --git a/src/interface.jl b/src/interface.jl index 209d83e4..c4a7e1f1 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -211,7 +211,7 @@ function _cache_compatible(hint::AbstractArray, shape::Tuple{Vararg{Int}}, v::Ab # Check array device compatibility (CPU, GPU, etc.) Base.typename(typeof(hint)).wrapper === Base.typename(typeof(v)).wrapper || return false promote_type(eltype(v), eltype(hint)) === eltype(hint) || return false - size(hint) == shape || return false + return size(hint) == shape || return false end function _cache_compatible(hint::Tuple, shapes::Tuple, v::AbstractArray) length(hint) != length(shapes) && return false diff --git a/test/basic.jl b/test/basic.jl index 69d4e1bd..8f2df554 100644 --- a/test/basic.jl +++ b/test/basic.jl @@ -435,7 +435,7 @@ end for input in (rand(N, K), rand(N)) L = cache_operator(L, input) expected = C1 * input + T1 * input + C2 * input + T2 * input + T3 * input + - A1 * input + A2 * input + A1 * input + A2 * input # Correctness: out-of-place (*) and in-place (mul!) paths @test L * input ≈ expected