Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/SciMLOperators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,9 @@ export
export update_coefficients!,
update_coefficients, isconstant,
iscached,
cache_operator, issquare,
cache_operator, cache_operator_hinted,
update_cache,
issquare,
islinear,
concretize,
isconvertible, has_adjoint,
Expand Down
69 changes: 65 additions & 4 deletions src/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -648,9 +648,56 @@ has_adjoint(L::AddedOperator) = all(has_adjoint, L.ops)
@generated function cache_internals(L::AddedOperator, v::AbstractVecOrMat)
ops_types = L.parameters[2].parameters
N = length(ops_types)

# Within each wrapper-type group, the op whose eltype equals promote_type(group eltypes)
# becomes the donor — all others in the group reuse its cache instead of allocating their own.
# If promote_type yields a new type that no op in the group has (e.g. Float64 + ComplexF32
# → ComplexF64), no op qualifies as donor and every op in that group caches independently.
# NOTE: the resulting cache aliasing (ops[i].cache === ops[j].cache) is safe only because
# AddedOperator's mul! evaluates sub-ops strictly serially. Any parallelism would require
# independent caches per sub-op.
wrappers = ntuple(i -> Base.typename(ops_types[i]).wrapper, N)

donor = ntuple(N) do i
idx_eltype = foldl(enumerate(wrappers); init = (i, wrappers[i])) do a, b
Ta = eltype(ops_types[a[1]])
Tb = eltype(ops_types[b[1]])
T = promote_type(Ta, Tb)
eltype_condition_a = Ta === T
eltype_condition_b = Tb === T

if b[2] !== wrappers[i] || !eltype_condition_b
return a
else
if eltype_condition_a && !eltype_condition_b
return a
elseif eltype_condition_b && !eltype_condition_a
return b
else
return b
end
end
end

return idx_eltype[1]
end

# Unique variable names for each cached sub-operator
syms = ntuple(i -> Symbol(:op_, i), N)

# Emit donors first so their symbols are defined before any asker references them
donor_stmts = [
:($(syms[i]) = cache_operator(L.ops[$i], v)) for i in 1:N if donor[i] == i
]
asker_stmts = [
:($(syms[i]) = cache_operator_hinted(L.ops[$i], getcache($(syms[donor[i]])), v))
for i in 1:N if donor[i] != i
]

return quote
ops = Base.@ntuple $N i -> cache_operator(L.ops[i], v)
return AddedOperator(ops)
$(donor_stmts...)
$(asker_stmts...)
return AddedOperator(($(syms...),))
end
end

Expand Down Expand Up @@ -874,6 +921,7 @@ function update_coefficients(L::ComposedOperator, u, p, t; kwargs...)
end

getops(L::ComposedOperator) = L.ops
getcache(op::ComposedOperator) = op.cache

# Copy method to avoid aliasing
function Base.copy(L::ComposedOperator)
Expand Down Expand Up @@ -939,6 +987,17 @@ end
end
end

function _get_cache_shapes(L::ComposedOperator, v::AbstractVecOrMat)
N = length(L.ops)
K = size(v, 2)

if v isa AbstractMatrix
return ntuple(i -> i < N ? (size(L.ops[i + 1], 1), K) : (size(v, 1), K), Val(N))
else
return ntuple(i -> i < N ? (size(L.ops[i + 1], 1),) : (size(v, 1),), Val(N))
end
end

@generated function cache_self(L::ComposedOperator, v::AbstractVecOrMat)
N = length(L.parameters[2].parameters) # Number of operators

Expand Down Expand Up @@ -1199,6 +1258,7 @@ function update_coefficients(L::InvertedOperator, u, p, t; kwargs...)
end

getops(L::InvertedOperator) = (L.L,)
getcache(op::InvertedOperator) = op.cache
islinear(L::InvertedOperator) = islinear(L.L)
isconvertible(::InvertedOperator) = false

Expand Down Expand Up @@ -1229,9 +1289,10 @@ function Base.copy(L::InvertedOperator)
)
end

_get_cache_shapes(::InvertedOperator, v::AbstractVecOrMat) = size(v)

function cache_self(L::InvertedOperator, u::AbstractVecOrMat)
cache = zero(u)
@reset L.cache = cache
@reset L.cache = zero(u)
return L
end

Expand Down
16 changes: 16 additions & 0 deletions src/func.jl
Original file line number Diff line number Diff line change
Expand Up @@ -590,6 +590,22 @@ function _cache_operator(L::FunctionOperator, u::AbstractArray)
return L
end

function _get_cache_shapes(L::FunctionOperator, v::AbstractVecOrMat)
if L.traits.batch
M = size(L, 1)
if v isa AbstractMatrix
return (size(v), (M, size(v, 2)))
else
return (size(v), (M,))
end
end

return (L.traits.sizes[1], L.traits.sizes[2])
end

getcache(op::FunctionOperator) = op.cache
update_cache(L::FunctionOperator, new_cache) = set_cache(L, new_cache)

# fix method amg bw AbstractArray, AbstractVecOrMat
cache_self(L::FunctionOperator, v::AbstractArray) = _cache_self(L, v)
cache_self(L::FunctionOperator, v::AbstractVecOrMat) = _cache_self(L, v)
Expand Down
63 changes: 62 additions & 1 deletion src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,15 @@ getops(L) = ()
"""
$SIGNATURES

Return the current cache held by `op`, or `nothing` if it holds none.
New operator types get the safe `nothing` default automatically; override
for types that store a shareable `.cache` field.
"""
getcache(::AbstractSciMLOperator) = nothing

"""
$SIGNATURES

Checks whether `L` has preallocated caches for inplace evaluations.
"""
function iscached(L::AbstractSciMLOperator)
Expand Down Expand Up @@ -179,12 +188,64 @@ end
cache_self(L::AbstractSciMLOperator, ::AbstractVecOrMat) = L
cache_internals(L::AbstractSciMLOperator, ::AbstractVecOrMat) = L

"""
$SIGNATURES

Return the expected cache shape specification for `L` given input `v`.
Returns `nothing` if `L` requires no cache.
The return value can be a single `NTuple{N,Int}` (single-array cache),
a `Tuple` of such shapes (multi-array cache), or `nothing` for absent slots.
"""
_get_cache_shapes(::AbstractSciMLOperator, ::AbstractVecOrMat) = nothing

"""
$SIGNATURES

Check whether `hint` is shape-compatible with `shapes` (as returned by `_get_cache_shapes`).
Uses `zip` to avoid integer-indexed Tuple access. Reads only array metadata — safe on GPU.
"""
_cache_compatible(hint, ::Nothing, v::AbstractArray) = false
_cache_compatible(::Nothing, shapes, v::AbstractArray) = false
_cache_compatible(::Nothing, ::Nothing, v::AbstractArray) = false
function _cache_compatible(hint::AbstractArray, shape::Tuple{Vararg{Int}}, v::AbstractArray)
# Check array device compatibility (CPU, GPU, etc.)
Base.typename(typeof(hint)).wrapper === Base.typename(typeof(v)).wrapper || return false
promote_type(eltype(v), eltype(hint)) === eltype(hint) || return false
return size(hint) == shape || return false
end
function _cache_compatible(hint::Tuple, shapes::Tuple, v::AbstractArray)
length(hint) != length(shapes) && return false
return all(((h, s),) -> _cache_compatible(h, s, v), zip(hint, shapes))
end

"""
$SIGNATURES

Inject `new_cache` into `op` as its cache. Default uses `@reset op.cache = new_cache`.
Override for operators that don't use the `.cache` field convention.
"""
update_cache(op::AbstractSciMLOperator, new_cache) = @reset op.cache = new_cache

"""
$SIGNATURES

Like `cache_operator`, but tries to reuse `hint` (an existing cache from a compatible operator)
instead of allocating new buffers. Falls back to `cache_operator` when `hint` is not compatible.
"""
function cache_operator_hinted(op::AbstractSciMLOperator, hint, v::AbstractVecOrMat)
if _cache_compatible(hint, _get_cache_shapes(op, v), v)
op = update_cache(op, hint)
return cache_internals(op, v)
end
return cache_operator(op, v)
end

###
# operator traits
###

Base.size(A::AbstractSciMLOperator, d::Integer) = d <= 2 ? size(A)[d] : 1
Base.eltype(::Type{AbstractSciMLOperator{T}}) where {T} = T
Base.eltype(::Type{<:AbstractSciMLOperator{T}}) where {T} = T
Base.eltype(::AbstractSciMLOperator{T}) where {T} = T
Base.promote_eltype(::AbstractSciMLOperator{<:T1}, ::AbstractSciMLOperator{<:T2}) where {T1, T2} = Base.promote_type(T1, T2)

Expand Down
48 changes: 35 additions & 13 deletions src/tensor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,7 @@ function update_coefficients(L::TensorProductOperator, u, p, t; kwargs...)
end

getops(L::TensorProductOperator) = L.ops
getcache(op::TensorProductOperator) = op.cache

# Copy method to avoid aliasing
function Base.copy(L::TensorProductOperator)
Expand Down Expand Up @@ -362,30 +363,51 @@ function Base.:\(L::TensorProductOperator, v::AbstractVecOrMat)
return v isa AbstractMatrix ? reshape(V, (n, k)) : reshape(V, (n,))
end

function cache_self(L::TensorProductOperator, v::AbstractVecOrMat)
function _get_cache_shapes(L::TensorProductOperator, v::AbstractVecOrMat)
outer, inner = L.ops
outer isa IdentityOperator && return nothing

mi, ni = size(inner)
mo, no = size(outer)
k = size(v, 2)

is_outer_identity = outer isa IdentityOperator
s1 = (mi, no * k)
s2 = (no, mi, k)
s3 = (mo, mi * k)
s4 = (mo * mi, k)

if reduce(&, issquare.(L.ops))
return (s1, s2, s3, s4, s1, s2, s3)
else
s5 = (ni, mo * k)
s6 = (mo, ni, k)
s7 = (no, ni * k)
return (s1, s2, s3, s4, s5, s6, s7)
end
end

function cache_self(L::TensorProductOperator, v::AbstractVecOrMat)
shapes = _get_cache_shapes(L, v)

# outer is IdentityOperator — no buffers needed
if isnothing(shapes)
@reset L.cache = (nothing, nothing, nothing, nothing, nothing, nothing, nothing)
return L
end

# 3 arg mul!
c1 = is_outer_identity ? nothing : lmul!(false, similar(v, (mi, no * k))) # c1 = inner * v
c2 = is_outer_identity ? nothing : lmul!(false, similar(v, (no, mi, k))) # permute (2, 1, 3)
c3 = is_outer_identity ? nothing : lmul!(false, similar(v, (mo, mi * k))) # c3 = outer * c2
s1, s2, s3, s4, s5, s6, s7 = shapes

# 5 arg mul!
c4 = is_outer_identity ? nothing : lmul!(false, similar(v, (mo * mi, k))) # cache v in 5 arg mul!
c1 = lmul!(false, similar(v, s1)) # inner * v (3-arg mul!)
c2 = lmul!(false, similar(v, s2)) # permute (2,1,3)
c3 = lmul!(false, similar(v, s3)) # outer * c2
c4 = lmul!(false, similar(v, s4)) # copy of w for 5-arg mul!

# 3 arg ldiv!
if mapreduce(issquare, &, L.ops)
c5, c6, c7 = c1, c2, c3
c5, c6, c7 = c1, c2, c3 # square case: ldiv! reuses mul! buffers
else
c5 = lmul!(false, similar(v, (ni, mo * k))) # c5 = inner \ v
c6 = lmul!(false, similar(v, (mo, ni, k))) # permute (2, 1, 3)
c7 = lmul!(false, similar(v, (no, ni * k))) # c7 = outer \ c6
c5 = lmul!(false, similar(v, s5)) # inner \ v (3-arg ldiv!)
c6 = lmul!(false, similar(v, s6)) # permute (2,1,3)
c7 = lmul!(false, similar(v, s7)) # outer \ c6
end

@reset L.cache = (c1, c2, c3, c4, c5, c6, c7)
Expand Down
71 changes: 71 additions & 0 deletions test/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -408,6 +408,77 @@ end
end
end

@testset "AddedOperator cache sharing (Composed, Tensor, Composed, Tensor, Tensor)" begin
using SciMLOperators: cache_operator_hinted, _get_cache_shapes

m1, m2 = 2, 4 # m1 * m2 == N

# C1 and C2: same wrapper (ComposedOperator), different inner type params
# C1 = A1*B1 → ops::Tuple{MatrixOperator, MatrixOperator}
# C2 = A2*B2' → ops::Tuple{MatrixOperator, AdjointOperator{…}}
A1 = MatrixOperator(rand(N, N)); B1 = MatrixOperator(rand(N, N))
A2 = MatrixOperator(rand(N, N)); B2 = MatrixOperator(rand(N, N))
C1 = A1 * B1
C2 = A2 * B2'

# T1, T2, T3: same wrapper (TensorProductOperator), different inner type params
# T1 = Ao ⊗ Ai, T2 = Ao' ⊗ Ai, T3 = Ao ⊗ Ai'
Ao = MatrixOperator(rand(m1, m1)); Ai = MatrixOperator(rand(m2, m2))
T1 = TensorProductOperator(Ao, Ai)
T2 = TensorProductOperator(Ao', Ai)
T3 = TensorProductOperator(Ao, Ai')

L = C1 + T1 + C2 + T2 + T3 + A1 + A2
@test L isa AddedOperator
@test length(L.ops) == 7

for input in (rand(N, K), rand(N))
L = cache_operator(L, input)
expected = C1 * input + T1 * input + C2 * input + T2 * input + T3 * input +
A1 * input + A2 * input

# Correctness: out-of-place (*) and in-place (mul!) paths
@test L * input ≈ expected
w = similar(input)
mul!(w, L, input)
@test w ≈ expected

# Cache sharing: same-wrapper ops with compatible sizes share physical buffers
@test L.ops[3].cache === L.ops[1].cache # C2 (A2*B2') reuses C1's cache
@test L.ops[4].cache === L.ops[2].cache # T2 (Ao'⊗Ai) reuses T1's cache
@test L.ops[5].cache === L.ops[2].cache # T3 (Ao⊗Ai') reuses T1's cache
end

# --- Mixed-eltype: Float64 + ComplexF64 ComposedOperators ---
# promote_type(Float64, ComplexF64) = ComplexF64 = eltype(Ac) → Ac is donor, Ar reuses its cache
Ar = MatrixOperator(rand(Float64, N, N)) * MatrixOperator(rand(Float64, N, N))
Ac = MatrixOperator(rand(ComplexF64, N, N)) * MatrixOperator(rand(ComplexF64, N, N))
v_real = rand(Float64, N)
Lcs = cache_operator(Ar + Ac, v_real)
@test Lcs.ops[1].cache === Lcs.ops[2].cache # Ar reuses Ac's ComplexF64 cache
w_c = similar(v_real, ComplexF64)
mul!(w_c, Lcs, v_real)
@test w_c ≈ Ar * v_real + Ac * v_real

# --- No sharing: Float64 + ComplexF32 → promote_type = ComplexF64 (neither op's eltype) ---
Af32 = MatrixOperator(rand(ComplexF32, N, N)) * MatrixOperator(rand(ComplexF32, N, N))
Lf32s = cache_operator(Ar + Af32, rand(Float64, N))
@test Lf32s.ops[1].cache !== Lf32s.ops[2].cache # independent caches

# --- Non-square ComposedOperator: _get_cache_shapes must match cache_self's allocation ---
M1, M2, M3 = 5, 3, 4
P = MatrixOperator(rand(M1, M2)); Q = MatrixOperator(rand(M2, M3))
PQc = cache_operator(P * Q, rand(M3))

@test _get_cache_shapes(PQc, rand(M3)) == ((M2,), (M3,))
@test map(size, PQc.cache) == ((M2,), (M3,))
v_ns = rand(M3)
w_ns = zeros(M1)
mul!(w_ns, PQc, v_ns)
@test w_ns ≈ P * (Q * v_ns)
end


@testset "ComposedOperator" begin
A = rand(N, N)
B = rand(N, N)
Expand Down
Loading