# Cache every sub-operator of an `AddedOperator`, letting sub-operators that share a
# wrapper type reuse one cache instead of each allocating its own.
#
# Donor selection happens at generation time and looks only at types. For sub-operator `i`
# the candidate starts at `i`; scanning `j = 1:N`, the candidate moves to `j` whenever `j`
# has the same wrapper type as `i` and `eltype(j) === promote_type(eltype(candidate),
# eltype(j))`. If promotion of two eltypes yields a type neither operator has
# (e.g. Float64 + ComplexF32 → ComplexF64) the candidate does not move, so those operators
# end up caching independently.
#
# The generated code first caches the donors with `cache_operator`, then caches every other
# sub-operator with `cache_operator_hinted`, passing `getcache(donor)` as the hint.
#
# NOTE: the resulting cache aliasing (ops[i].cache === ops[j].cache) is safe only because
# AddedOperator's mul! evaluates sub-ops strictly serially. Any parallelism would require
# independent caches per sub-op.
@generated function cache_internals(L::AddedOperator, v::AbstractVecOrMat)
    ops_types = L.parameters[2].parameters
    N = length(ops_types)

    wrappers = [Base.typename(ops_types[i]).wrapper for i in 1:N]
    eltypes = [eltype(ops_types[i]) for i in 1:N]

    # donor[i] == i means "sub-operator i allocates its own cache".
    donor = collect(1:N)
    for i in 1:N
        for j in 1:N
            wrappers[j] === wrappers[i] || continue
            Tj = eltypes[j]
            # Move to `j` only if `j`'s eltype already is the promoted eltype.
            if promote_type(eltypes[donor[i]], Tj) === Tj
                donor[i] = j
            end
        end
    end

    # Unique variable names for each cached sub-operator
    syms = [Symbol(:op_, i) for i in 1:N]

    # Donors are emitted first so their symbols are defined before any asker references
    # them. NOTE(review): this relies on every selected donor being its own donor
    # (donor[donor[i]] == donor[i]) — confirm for exotic eltype combinations.
    donor_stmts = Expr[]
    asker_stmts = Expr[]
    for i in 1:N
        if donor[i] == i
            push!(donor_stmts, :($(syms[i]) = cache_operator(L.ops[$i], v)))
        else
            hint = :(getcache($(syms[donor[i]])))
            push!(asker_stmts, :($(syms[i]) = cache_operator_hinted(L.ops[$i], $hint, v)))
        end
    end

    return quote
        $(donor_stmts...)
        $(asker_stmts...)
        return AddedOperator(($(syms...),))
    end
end
# Cache shape specification for a `FunctionOperator`: a pair of shapes.
#
# Non-batch operators report the first two entries of `L.traits.sizes`.
# Batch operators report `size(v)` followed by `(size(L, 1),)`, extended with the column
# count of `v` when `v` is a matrix.
function _get_cache_shapes(L::FunctionOperator, v::AbstractVecOrMat)
    L.traits.batch || return (L.traits.sizes[1], L.traits.sizes[2])

    nrows = size(L, 1)
    second_shape = v isa AbstractMatrix ? (nrows, size(v, 2)) : (nrows,)
    return (size(v), second_shape)
end
"""
    _cache_compatible(hint, shapes, v::AbstractArray) -> Bool

Check whether `hint` (an existing cache) can stand in for the cache described by `shapes`
(as returned by `_get_cache_shapes`) when the operator is applied to `v`.

- An array `hint` matches a single shape (a tuple of `Int`s) when it has the same array
  wrapper type as `v`, its eltype already absorbs `eltype(v)` under `promote_type`, and
  its size equals the shape.
- A tuple `hint` matches a tuple of shapes when both have the same length and every
  element matches pairwise.
- Every other combination — including `nothing` on either side, or a hint whose structure
  does not line up with `shapes` — is reported as incompatible rather than raising a
  `MethodError`.

Only types and sizes are inspected; array elements are never read.
"""
_cache_compatible(hint, shapes, v::AbstractArray) = false
function _cache_compatible(
        hint::AbstractArray, shape::Tuple{Vararg{Int}}, v::AbstractArray
    )
    # Same array family as the input (keeps e.g. a cache of another array type out).
    Base.typename(typeof(hint)).wrapper === Base.typename(typeof(v)).wrapper || return false
    # The hint's eltype must be able to hold values of eltype(v) without widening.
    promote_type(eltype(v), eltype(hint)) === eltype(hint) || return false
    return size(hint) == shape
end
function _cache_compatible(hint::Tuple, shapes::Tuple, v::AbstractArray)
    length(hint) == length(shapes) || return false
    # `zip` avoids integer-indexed tuple access; mismatched element kinds hit the fallback.
    return all(((h, s),) -> _cache_compatible(h, s, v), zip(hint, shapes))
end
"""
    cache_operator_hinted(op::AbstractSciMLOperator, hint, v::AbstractVecOrMat)

Cache `op` for input `v`, reusing `hint` as its cache when possible.

`hint` is compared against `_get_cache_shapes(op, v)` with `_cache_compatible`. When it
fits, it is installed through `update_cache` and only `cache_internals` is run on the
result. Otherwise the call is equivalent to `cache_operator(op, v)`.
"""
function cache_operator_hinted(op::AbstractSciMLOperator, hint, v::AbstractVecOrMat)
    wanted = _get_cache_shapes(op, v)
    _cache_compatible(hint, wanted, v) || return cache_operator(op, v)

    hinted = update_cache(op, hint)
    return cache_internals(hinted, v)
end
# Cache shape specification for a `TensorProductOperator` applied to `v`.
#
# Returns `nothing` when the outer operator is an `IdentityOperator`. Otherwise returns
# seven shapes, in the order `cache_self` allocates its buffers: three for 3-arg `mul!`
# (inner * v, the (2, 1, 3) permutation, outer * permuted), one for the copy used by
# 5-arg `mul!`, and three for `ldiv!`. When all operators are square, the `ldiv!` shapes
# repeat the three `mul!` shapes.
function _get_cache_shapes(L::TensorProductOperator, v::AbstractVecOrMat)
    outer, inner = L.ops
    if outer isa IdentityOperator
        return nothing
    end

    rows_in, cols_in = size(inner)
    rows_out, cols_out = size(outer)
    ncols = size(v, 2)

    mul_shapes = (
        (rows_in, cols_out * ncols),
        (cols_out, rows_in, ncols),
        (rows_out, rows_in * ncols),
    )
    copy_shape = (rows_out * rows_in, ncols)

    ldiv_shapes = if all(issquare, L.ops)
        mul_shapes
    else
        (
            (cols_in, rows_out * ncols),
            (rows_out, cols_in, ncols),
            (cols_out, cols_in * ncols),
        )
    end

    return (mul_shapes..., copy_shape, ldiv_shapes...)
end
# Exercises cache sharing between same-wrapper sub-operators of an `AddedOperator`:
# results must match the uncached sum, and compatible sub-operators must end up holding
# the very same cache object. `N` and `K` are defined outside this testset.
@testset "AddedOperator cache sharing (Composed, Tensor, Composed, Tensor, Tensor)" begin
    using SciMLOperators: cache_operator_hinted, _get_cache_shapes

    m1, m2 = 2, 4 # m1 * m2 == N

    # C1 and C2: same wrapper (ComposedOperator), different inner type params
    # C1 = A1*B1  → ops::Tuple{MatrixOperator, MatrixOperator}
    # C2 = A2*B2' → ops::Tuple{MatrixOperator, AdjointOperator{…}}
    A1 = MatrixOperator(rand(N, N)); B1 = MatrixOperator(rand(N, N))
    A2 = MatrixOperator(rand(N, N)); B2 = MatrixOperator(rand(N, N))
    C1 = A1 * B1
    C2 = A2 * B2'

    # T1, T2, T3: same wrapper (TensorProductOperator), different inner type params
    # T1 = Ao ⊗ Ai,  T2 = Ao' ⊗ Ai,  T3 = Ao ⊗ Ai'
    Ao = MatrixOperator(rand(m1, m1)); Ai = MatrixOperator(rand(m2, m2))
    T1 = TensorProductOperator(Ao, Ai)
    T2 = TensorProductOperator(Ao', Ai)
    T3 = TensorProductOperator(Ao, Ai')

    L = C1 + T1 + C2 + T2 + T3 + A1 + A2
    @test L isa AddedOperator
    @test length(L.ops) == 7

    # Run once with a matrix input and once with a vector input; `L` is re-cached each time.
    for input in (rand(N, K), rand(N))
        L = cache_operator(L, input)
        expected = C1 * input + T1 * input + C2 * input + T2 * input + T3 * input +
            A1 * input + A2 * input

        # Correctness: out-of-place (*) and in-place (mul!) paths
        @test L * input ≈ expected
        w = similar(input)
        mul!(w, L, input)
        @test w ≈ expected

        # Cache sharing: same-wrapper ops with compatible sizes hold the same object.
        # `===` is symmetric, so these checks do not pin down which op donated its cache.
        @test L.ops[3].cache === L.ops[1].cache   # C1 and C2 share one cache
        @test L.ops[4].cache === L.ops[2].cache   # T1 and T2 share one cache
        @test L.ops[5].cache === L.ops[2].cache   # T1 and T3 share one cache
    end

    # --- Mixed-eltype: Float64 + ComplexF64 ComposedOperators ---
    # promote_type(Float64, ComplexF64) = ComplexF64 = eltype(Ac) → Ac is donor, Ar reuses its cache
    Ar = MatrixOperator(rand(Float64, N, N)) * MatrixOperator(rand(Float64, N, N))
    Ac = MatrixOperator(rand(ComplexF64, N, N)) * MatrixOperator(rand(ComplexF64, N, N))
    v_real = rand(Float64, N)
    Lcs = cache_operator(Ar + Ac, v_real)
    @test Lcs.ops[1].cache === Lcs.ops[2].cache   # Ar and Ac hold the same cache object
    w_c = similar(v_real, ComplexF64)
    mul!(w_c, Lcs, v_real)
    @test w_c ≈ Ar * v_real + Ac * v_real

    # --- No sharing: Float64 + ComplexF32 → promote_type = ComplexF64 (neither op's eltype) ---
    Af32 = MatrixOperator(rand(ComplexF32, N, N)) * MatrixOperator(rand(ComplexF32, N, N))
    Lf32s = cache_operator(Ar + Af32, rand(Float64, N))
    @test Lf32s.ops[1].cache !== Lf32s.ops[2].cache   # independent caches

    # --- Non-square ComposedOperator: _get_cache_shapes must match cache_self's allocation ---
    M1, M2, M3 = 5, 3, 4
    P = MatrixOperator(rand(M1, M2)); Q = MatrixOperator(rand(M2, M3))
    PQc = cache_operator(P * Q, rand(M3))

    # Reported shapes and the sizes of the actually allocated buffers must agree.
    @test _get_cache_shapes(PQc, rand(M3)) == ((M2,), (M3,))
    @test map(size, PQc.cache) == ((M2,), (M3,))
    v_ns = rand(M3)
    w_ns = zeros(M1)
    mul!(w_ns, PQc, v_ns)
    @test w_ns ≈ P * (Q * v_ns)
end