From 287c5fed66d156c3ac5eeca63d50777c16669d3e Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Wed, 4 Feb 2026 11:49:00 -0500 Subject: [PATCH 1/5] add specializations `svd_trunc(!)` for `TruncatedAlgorithm` --- .../MatrixAlgebraKitMooncakeExt.jl | 132 ++++++++++++++++++ 1 file changed, 132 insertions(+) diff --git a/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl b/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl index e4ec256f..c323f953 100644 --- a/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl +++ b/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl @@ -10,6 +10,7 @@ using MatrixAlgebraKit: eig_pullback!, eigh_pullback!, eig_vals_pullback! using MatrixAlgebraKit: eig_trunc_pullback!, eigh_trunc_pullback!, eigh_vals_pullback! using MatrixAlgebraKit: left_polar_pullback!, right_polar_pullback! using MatrixAlgebraKit: svd_pullback!, svd_trunc_pullback!, svd_vals_pullback! +using MatrixAlgebraKit: TruncatedAlgorithm using LinearAlgebra @@ -437,6 +438,48 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc!)}, A_dA::CoDual, USVᴴ_dUS end return output_codual, svd_trunc_adjoint end +function Mooncake.rrule!!(::CoDual{typeof(svd_trunc!)}, A_dA::CoDual, USVᴴ_dUSVᴴ::CoDual, alg_dalg::CoDual{<:TruncatedAlgorithm}) + # unpack variables + A, dA = arrayify(A_dA) + USVᴴ_dUSVᴴ_arr = arrayify.(Mooncake.primal(USVᴴ_dUSVᴴ), Mooncake.tangent(USVᴴ_dUSVᴴ)) + USVᴴ, dUSVᴴ = first.(USVᴴ_dUSVᴴ_arr), last.(USVᴴ_dUSVᴴ_arr) + alg = Mooncake.primal(alg_dalg) + + # store state prior to primal call + Ac = copy(A) + USVᴴc = copy.(USVᴴ) + + # compute primal - capture full USVᴴ and ind + USVᴴ = svd_compact!(A, USVᴴ, alg.alg) + USVᴴtrunc, ind = MatrixAlgebraKit.truncate(svd_trunc!, USVᴴ, alg.trunc) + ϵ = MatrixAlgebraKit.truncation_error(diagview(USVᴴ[2]), ind) + + # pack output - note that we allocate new dUSVᴴtrunc because these aren't actually + # overwritten in the input! + USVᴴtrunc_dUSVᴴtrunc = Mooncake.zero_fcodual((USVᴴtrunc..., ϵ)) + + # define pullback + local svd_trunc_adjoint + let ind = ind, dUSVᴴtrunc = last.(arrayify.(USVᴴtrunc, Base.front(Mooncake.tangent(USVᴴtrunc_dUSVᴴtrunc)))) + function svd_trunc_adjoint((_, _, _, dϵ)::Tuple{NoRData, NoRData, NoRData, Real}) + abs(dϵ) ≤ MatrixAlgebraKit.defaulttol(dϵ) || + @warn "Pullback for `svd_trunc` ignores non-zero tangents for truncation error" + + # compute pullbacks + svd_pullback!(dA, Ac, USVᴴc, dUSVᴴtrunc, ind) + zero!.(dUSVᴴtrunc) # since this is allocated in this function this is probably not required + zero!.(dUSVᴴ) + + # restore state + copy!(A, Ac) + copy!.(USVᴴ, USVᴴc) + + return ntuple(Returns(NoRData()), 4) + end + end + + return USVᴴtrunc_dUSVᴴtrunc, svd_trunc_adjoint +end @is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(svd_trunc), Any, MatrixAlgebraKit.AbstractAlgorithm} function Mooncake.rrule!!(::CoDual{typeof(svd_trunc)}, A_dA::CoDual, alg_dalg::CoDual) @@ -464,6 +507,33 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc)}, A_dA::CoDual, alg_dalg::C end return output_codual, svd_trunc_adjoint end +function Mooncake.rrule!!(::CoDual{typeof(svd_trunc)}, A_dA::CoDual, alg_dalg::CoDual{<:TruncatedAlgorithm}) + # unpack variables + A, dA = arrayify(A_dA) + alg = Mooncake.primal(alg_dalg) + + # compute primal - capture full USVᴴ and ind + USVᴴ = svd_compact(A, alg.alg) + USVᴴtrunc, ind = MatrixAlgebraKit.truncate(svd_trunc!, USVᴴ, alg.trunc) + ϵ = MatrixAlgebraKit.truncation_error(diagview(USVᴴ[2]), ind) + + # pack output + USVᴴtrunc_dUSVᴴtrunc = Mooncake.zero_fcodual((USVᴴtrunc..., ϵ)) + + # define pullback + local svd_trunc_adjoint + let ind = ind, dUSVᴴtrunc = last.(arrayify.(USVᴴtrunc, Base.front(Mooncake.tangent(USVᴴtrunc_dUSVᴴtrunc)))) + function svd_trunc_adjoint((_, _, _, dϵ)::Tuple{NoRData, NoRData, NoRData, Real}) + abs(dϵ) ≤ MatrixAlgebraKit.defaulttol(dϵ) || + @warn "Pullback for `svd_trunc` ignores non-zero tangents for truncation error" + svd_pullback!(dA, A, USVᴴ, dUSVᴴtrunc, ind) + zero!.(dUSVᴴtrunc) # since this is allocated in this function this is probably not required + return ntuple(Returns(NoRData()), 3) + end + end + + return USVᴴtrunc_dUSVᴴtrunc, svd_trunc_adjoint +end @is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(svd_trunc_no_error!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm} function Mooncake.rrule!!(::CoDual{typeof(svd_trunc_no_error!)}, A_dA::CoDual, USVᴴ_dUSVᴴ::CoDual, alg_dalg::CoDual) @@ -504,6 +574,44 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc_no_error!)}, A_dA::CoDual, U end return output_codual, svd_trunc_adjoint end +function Mooncake.rrule!!(::CoDual{typeof(svd_trunc_no_error!)}, A_dA::CoDual, USVᴴ_dUSVᴴ::CoDual, alg_dalg::CoDual{<:TruncatedAlgorithm}) + # unpack variables + A, dA = arrayify(A_dA) + USVᴴ_dUSVᴴ_arr = arrayify.(Mooncake.primal(USVᴴ_dUSVᴴ), Mooncake.tangent(USVᴴ_dUSVᴴ)) + USVᴴ, dUSVᴴ = first.(USVᴴ_dUSVᴴ_arr), last.(USVᴴ_dUSVᴴ_arr) + alg = Mooncake.primal(alg_dalg) + + # store state prior to primal call + Ac = copy(A) + USVᴴc = copy.(USVᴴ) + + # compute primal - capture full USVᴴ and ind + USVᴴ = svd_compact!(A, USVᴴ, alg.alg) + USVᴴtrunc, ind = MatrixAlgebraKit.truncate(svd_trunc!, USVᴴ, alg.trunc) + + # pack output - note that we allocate new dUSVᴴtrunc because these aren't actually + # overwritten in the input! + USVᴴtrunc_dUSVᴴtrunc = Mooncake.zero_fcodual(USVᴴtrunc) + + # define pullback + local svd_trunc_adjoint + let ind = ind, dUSVᴴtrunc = last.(arrayify.(USVᴴtrunc, Mooncake.tangent(USVᴴtrunc_dUSVᴴtrunc))) + function svd_trunc_adjoint(::NoRData) + # compute pullbacks + svd_pullback!(dA, Ac, USVᴴc, dUSVᴴtrunc, ind) + zero!.(dUSVᴴtrunc) # since this is allocated in this function this is probably not required + zero!.(dUSVᴴ) + + # restore state + copy!(A, Ac) + copy!.(USVᴴ, USVᴴc) + + return ntuple(Returns(NoRData()), 4) + end + end + + return USVᴴtrunc_dUSVᴴtrunc, svd_trunc_adjoint +end @is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(svd_trunc_no_error), Any, MatrixAlgebraKit.AbstractAlgorithm} function Mooncake.rrule!!(::CoDual{typeof(svd_trunc_no_error)}, A_dA::CoDual, alg_dalg::CoDual) @@ -530,5 +638,29 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc_no_error)}, A_dA::CoDual, al end return output_codual, svd_trunc_adjoint end +function Mooncake.rrule!!(::CoDual{typeof(svd_trunc_no_error)}, A_dA::CoDual, alg_dalg::CoDual{<:TruncatedAlgorithm}) + # unpack variables + A, dA = arrayify(A_dA) + alg = Mooncake.primal(alg_dalg) + + # compute primal - capture full USVᴴ and ind + USVᴴ = svd_compact(A, alg.alg) + USVᴴtrunc, ind = MatrixAlgebraKit.truncate(svd_trunc!, USVᴴ, alg.trunc) + + # pack output + USVᴴtrunc_dUSVᴴtrunc = Mooncake.zero_fcodual(USVᴴtrunc) + + # define pullback + local svd_trunc_adjoint + let ind = ind, dUSVᴴtrunc = last.(arrayify.(USVᴴtrunc, Mooncake.tangent(USVᴴtrunc_dUSVᴴtrunc))) + function svd_trunc_adjoint(::NoRData) + svd_pullback!(dA, A, USVᴴ, dUSVᴴtrunc, ind) + zero!.(dUSVᴴtrunc) # since this is allocated in this function this is probably not required + return ntuple(Returns(NoRData()), 3) + end + end + + return USVᴴtrunc_dUSVᴴtrunc, svd_trunc_adjoint +end end From a23735c8c90e2ad6884f2397987aab91b08f3bb3 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Wed, 4 Feb 2026 11:53:01 -0500 Subject: [PATCH 2/5] update changelog --- docs/src/changelog.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/src/changelog.md b/docs/src/changelog.md index 9147c4da..3283f9de 100644 --- a/docs/src/changelog.md +++ b/docs/src/changelog.md @@ -24,6 +24,8 @@ When releasing a new version, move the "Unreleased" changes to a new version sec ### Changed +- The Mooncake rules for truncated decompositions with `TruncatedAlgorithm` now use the pullbacks that make use of the full decomposition. ([#171](https://github.com/QuantumKitHub/MatrixAlgebraKit.jl/pull/171)) + ### Deprecated ### Removed From bf151a080c5eedb8192ff55aca39c87989201fa5 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Wed, 4 Feb 2026 13:27:23 -0500 Subject: [PATCH 3/5] add specializations `eig_trunc(!)` for `TruncatedAlgorithm` --- .../MatrixAlgebraKitMooncakeExt.jl | 192 +++++++++++++++--- 1 file changed, 163 insertions(+), 29 deletions(-) diff --git a/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl b/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl index c323f953..a96cebd6 100644 --- a/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl +++ b/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl @@ -168,14 +168,21 @@ for (f!, f, f_full, pb, adj) in ( end end -for (f!, f, f_ne!, f_ne, pb, adj) in ( - (:eig_trunc!, :eig_trunc, :eig_trunc_no_error!, :eig_trunc_no_error, :eig_trunc_pullback!, :eig_trunc_adjoint), - (:eigh_trunc!, :eigh_trunc, :eigh_trunc_no_error!, :eigh_trunc_no_error, :eigh_trunc_pullback!, :eigh_trunc_adjoint), - ) +for f in (:eig, :eigh) + f_trunc = Symbol(f, :_trunc) + f_trunc! = Symbol(f_trunc, :!) + f_full = Symbol(f, :_full) + f_full! = Symbol(f_full, :!) + f_pullback! = Symbol(f, :_pullback!) + f_trunc_pullback! = Symbol(f_trunc, :_pullback!) + f_adjoint! = Symbol(f, :_adjoint!) + f_trunc_no_error = Symbol(f_trunc, :_no_error) + f_trunc_no_error! = Symbol(f_trunc_no_error, :!) + @eval begin - @is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm} - @is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f), Any, MatrixAlgebraKit.AbstractAlgorithm} - function Mooncake.rrule!!(::CoDual{typeof($f!)}, A_dA::CoDual, DV_dDV::CoDual, alg_dalg::CoDual) + @is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f_trunc!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm} + @is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f_trunc), Any, MatrixAlgebraKit.AbstractAlgorithm} + function Mooncake.rrule!!(::CoDual{typeof($f_trunc!)}, A_dA::CoDual, DV_dDV::CoDual, alg_dalg::CoDual) # compute primal A, dA = arrayify(A_dA) DV = Mooncake.primal(DV_dDV) @@ -183,54 +190,121 @@ for (f!, f, f_ne!, f_ne, pb, adj) in ( Ac = copy(A) DVc = copy.(DV) alg = Mooncake.primal(alg_dalg) - output = $f!(A, DV, alg) + output = $f_trunc!(A, DV, alg) # fdata call here is necessary to convert complicated Tangent type (e.g. of a Diagonal # of ComplexF32) into the correct **forwards** data type (since we are now in the forward # pass). For many types this is done automatically when the forward step returns, but # not for nested structs with various fields (like Diagonal{Complex}) - output_codual = CoDual(output, Mooncake.fdata(Mooncake.zero_tangent(output))) - function $adj(dy::Tuple{NoRData, NoRData, T}) where {T <: Real} + output_codual = Mooncake.zero_fcodual(output) + function $f_adjoint!(dy::Tuple{NoRData, NoRData, <:Real}) copy!(A, Ac) Dtrunc, Vtrunc, ϵ = Mooncake.primal(output_codual) dDtrunc_, dVtrunc_, dϵ = Mooncake.tangent(output_codual) abs(dy[3]) > MatrixAlgebraKit.defaulttol(dy[3]) && @warn "Pullback for $f does not yet support non-zero tangent for the truncation error" D′, dD′ = arrayify(Dtrunc, dDtrunc_) V′, dV′ = arrayify(Vtrunc, dVtrunc_) - $pb(dA, A, (D′, V′), (dD′, dV′)) + $f_trunc_pullback!(dA, A, (D′, V′), (dD′, dV′)) copy!(DV[1], DVc[1]) copy!(DV[2], DVc[2]) zero!(dD′) zero!(dV′) return NoRData(), NoRData(), NoRData(), NoRData() end - return output_codual, $adj + return output_codual, $f_adjoint! end - function Mooncake.rrule!!(::CoDual{typeof($f)}, A_dA::CoDual, alg_dalg::CoDual) + function Mooncake.rrule!!(::CoDual{typeof($f_trunc!)}, A_dA::CoDual, DV_dDV::CoDual, alg_dalg::CoDual{<:TruncatedAlgorithm}) + # unpack variables + A, dA = arrayify(A_dA) + DV_dDV_arr = arrayify.(Mooncake.primal(DV_dDV), Mooncake.tangent(DV_dDV)) + DV, dDV = first.(DV_dDV_arr), last.(DV_dDV_arr) + alg = Mooncake.primal(alg_dalg) + + # store state prior to primal call + Ac = copy(A) + DVc = copy.(DV) + + # compute primal - capture full DV and ind + DV = $f_full!(A, DV, alg.alg) + DVtrunc, ind = MatrixAlgebraKit.truncate($f_trunc!, DV, alg.trunc) + ϵ = MatrixAlgebraKit.truncation_error(diagview(DV[1]), ind) + + # pack output - note that we allocate new dDVtrunc because these aren't overwritten in the input + DVtrunc_dDVtrunc = Mooncake.zero_fcodual((DVtrunc..., ϵ)) + + # define pullback + local $f_adjoint! + let ind = ind, dDVtrunc = last.(arrayify.(DVtrunc, Base.front(Mooncake.tangent(DVtrunc_dDVtrunc)))) + function $f_adjoint!((_, _, dϵ)::Tuple{NoRData, NoRData, Real}) + abs(dϵ) ≤ MatrixAlgebraKit.defaulttol(dϵ) || + @warn "Pullback for `$f!` ignores non-zero tangents for truncation error" + + # compute pullbacks + $f_pullback!(dA, Ac, DVc, dDVtrunc, ind) + zero!.(dDVtrunc) # since this is allocated in this function this is probably not required + + # restore state + copy!(A, Ac) + copy!.(DV, DVc) + + return ntuple(Returns(NoRData()), 4) + end + end + + return DVtrunc_dDVtrunc, $f_adjoint! + end + function Mooncake.rrule!!(::CoDual{typeof($f_trunc)}, A_dA::CoDual, alg_dalg::CoDual) # compute primal A, dA = arrayify(A_dA) alg = Mooncake.primal(alg_dalg) - output = $f(A, alg) + output = $f_trunc(A, alg) # fdata call here is necessary to convert complicated Tangent type (e.g. of a Diagonal # of ComplexF32) into the correct **forwards** data type (since we are now in the forward # pass). For many types this is done automatically when the forward step returns, but # not for nested structs with various fields (like Diagonal{Complex}) output_codual = CoDual(output, Mooncake.fdata(Mooncake.zero_tangent(output))) - function $adj(dy::Tuple{NoRData, NoRData, T}) where {T <: Real} + function $f_adjoint!(dy::Tuple{NoRData, NoRData, T}) where {T <: Real} Dtrunc, Vtrunc, ϵ = Mooncake.primal(output_codual) dDtrunc_, dVtrunc_, dϵ = Mooncake.tangent(output_codual) abs(dy[3]) > MatrixAlgebraKit.defaulttol(dy[3]) && @warn "Pullback for $f does not yet support non-zero tangent for the truncation error" D, dD = arrayify(Dtrunc, dDtrunc_) V, dV = arrayify(Vtrunc, dVtrunc_) - $pb(dA, A, (D, V), (dD, dV)) + $f_trunc_pullback!(dA, A, (D, V), (dD, dV)) zero!(dD) zero!(dV) return NoRData(), NoRData(), NoRData() end - return output_codual, $adj + return output_codual, $f_adjoint! end - @is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f_ne!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm} - @is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f_ne), Any, MatrixAlgebraKit.AbstractAlgorithm} - function Mooncake.rrule!!(::CoDual{typeof($f_ne!)}, A_dA::CoDual, DV_dDV::CoDual, alg_dalg::CoDual) + function Mooncake.rrule!!(::CoDual{typeof($f_trunc)}, A_dA::CoDual, alg_dalg::CoDual{<:TruncatedAlgorithm}) + # unpack variables + A, dA = arrayify(A_dA) + alg = Mooncake.primal(alg_dalg) + + # compute primal - capture full DV and ind + DV = $f_full(A, alg.alg) + DVtrunc, ind = MatrixAlgebraKit.truncate($f_trunc!, DV, alg.trunc) + ϵ = MatrixAlgebraKit.truncation_error(diagview(DV[1]), ind) + + # pack output + DVtrunc_dDVtrunc = Mooncake.zero_fcodual((DVtrunc..., ϵ)) + + # define pullback + local $f_adjoint! + let ind = ind, dDVtrunc = last.(arrayify.(DVtrunc, Base.front(Mooncake.tangent(DVtrunc_dDVtrunc)))) + function $f_adjoint!((_, _, dϵ)::Tuple{NoRData, NoRData, Real}) + abs(dϵ) ≤ MatrixAlgebraKit.defaulttol(dϵ) || + @warn "Pullback for `$f_trunc` ignores non-zero tangents for truncation error" + $f_pullback!(dA, A, DV, dDVtrunc, ind) + zero!.(dDVtrunc) # since this is allocated in this function this is probably not required + return ntuple(Returns(NoRData()), 3) + end + end + + return DVtrunc_dDVtrunc, $f_adjoint! + end + @is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f_trunc_no_error!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm} + @is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f_trunc_no_error), Any, MatrixAlgebraKit.AbstractAlgorithm} + function Mooncake.rrule!!(::CoDual{typeof($f_trunc_no_error!)}, A_dA::CoDual, DV_dDV::CoDual, alg_dalg::CoDual) # compute primal A, dA = arrayify(A_dA) alg = Mooncake.primal(alg_dalg) @@ -238,48 +312,108 @@ for (f!, f, f_ne!, f_ne, pb, adj) in ( dDV = Mooncake.tangent(DV_dDV) Ac = copy(A) DVc = copy.(DV) - output = $f_ne!(A, DV, alg) + output = $f_trunc_no_error!(A, DV, alg) # fdata call here is necessary to convert complicated Tangent type (e.g. of a Diagonal # of ComplexF32) into the correct **forwards** data type (since we are now in the forward # pass). For many types this is done automatically when the forward step returns, but # not for nested structs with various fields (like Diagonal{Complex}) output_codual = CoDual(output, Mooncake.fdata(Mooncake.zero_tangent(output))) - function $adj(::NoRData) + function $f_adjoint!(::NoRData) copy!(A, Ac) Dtrunc, Vtrunc = Mooncake.primal(output_codual) dDtrunc_, dVtrunc_ = Mooncake.tangent(output_codual) D′, dD′ = arrayify(Dtrunc, dDtrunc_) V′, dV′ = arrayify(Vtrunc, dVtrunc_) - $pb(dA, A, (D′, V′), (dD′, dV′)) + $f_pullback!(dA, A, (D′, V′), (dD′, dV′)) copy!(DV[1], DVc[1]) copy!(DV[2], DVc[2]) zero!(dD′) zero!(dV′) return NoRData(), NoRData(), NoRData(), NoRData() end - return output_codual, $adj + return output_codual, $f_adjoint! + end + function Mooncake.rrule!!(::CoDual{typeof($f_trunc_no_error!)}, A_dA::CoDual, DV_dDV::CoDual, alg_dalg::CoDual{<:TruncatedAlgorithm}) + # unpack variables + A, dA = arrayify(A_dA) + DV_dDV_arr = arrayify.(Mooncake.primal(DV_dDV), Mooncake.tangent(DV_dDV)) + DV, dDV = first.(DV_dDV_arr), last.(DV_dDV_arr) + alg = Mooncake.primal(alg_dalg) + + # store state prior to primal call + Ac = copy(A) + DVc = copy.(DV) + + # compute primal - capture full DV and ind + DV = $f_full!(A, DV, alg.alg) + DVtrunc, ind = MatrixAlgebraKit.truncate($f_trunc!, DV, alg.trunc) + + # pack output - note that we allocate new dDVtrunc because these aren't overwritten in the input + DVtrunc_dDVtrunc = Mooncake.zero_fcodual(DVtrunc) + + # define pullback + local $f_adjoint! + let ind = ind, dDVtrunc = last.(arrayify.(DVtrunc, Mooncake.tangent(DVtrunc_dDVtrunc))) + function $f_adjoint!(::NoRData) + # compute pullbacks + $f_pullback!(dA, Ac, DVc, dDVtrunc, ind) + zero!.(dDVtrunc) # since this is allocated in this function this is probably not required + + # restore state + copy!(A, Ac) + copy!.(DV, DVc) + + return ntuple(Returns(NoRData()), 4) + end + end + + return DVtrunc_dDVtrunc, $f_adjoint! end - function Mooncake.rrule!!(::CoDual{typeof($f_ne)}, A_dA::CoDual, alg_dalg::CoDual) + function Mooncake.rrule!!(::CoDual{typeof($f_trunc_no_error)}, A_dA::CoDual, alg_dalg::CoDual) # compute primal A, dA = arrayify(A_dA) alg = Mooncake.primal(alg_dalg) - output = $f_ne(A, alg) + output = $f_trunc_no_error(A, alg) # fdata call here is necessary to convert complicated Tangent type (e.g. of a Diagonal # of ComplexF32) into the correct **forwards** data type (since we are now in the forward # pass). For many types this is done automatically when the forward step returns, but # not for nested structs with various fields (like Diagonal{Complex}) output_codual = CoDual(output, Mooncake.fdata(Mooncake.zero_tangent(output))) - function $adj(::NoRData) + function $f_adjoint!(::NoRData) Dtrunc, Vtrunc = Mooncake.primal(output_codual) dDtrunc_, dVtrunc_ = Mooncake.tangent(output_codual) D, dD = arrayify(Dtrunc, dDtrunc_) V, dV = arrayify(Vtrunc, dVtrunc_) - $pb(dA, A, (D, V), (dD, dV)) + $f_trunc_pullback!(dA, A, (D, V), (dD, dV)) zero!(dD) zero!(dV) return NoRData(), NoRData(), NoRData() end - return output_codual, $adj + return output_codual, $f_adjoint! + end + function Mooncake.rrule!!(::CoDual{typeof($f_trunc_no_error)}, A_dA::CoDual, alg_dalg::CoDual{<:TruncatedAlgorithm}) + # unpack variables + A, dA = arrayify(A_dA) + alg = Mooncake.primal(alg_dalg) + + # compute primal - capture full DV and ind + DV = $f_full(A, alg.alg) + DVtrunc, ind = MatrixAlgebraKit.truncate($f_trunc!, DV, alg.trunc) + + # pack output + DVtrunc_dDVtrunc = Mooncake.zero_fcodual(DVtrunc) + + # define pullback + local $f_adjoint! + let ind = ind, dDVtrunc = last.(arrayify.(DVtrunc, Mooncake.tangent(DVtrunc_dDVtrunc))) + function $f_adjoint!(::NoRData) + $f_pullback!(dA, A, DV, dDVtrunc, ind) + zero!.(dDVtrunc) # since this is allocated in this function this is probably not required + return ntuple(Returns(NoRData()), 3) + end + end + + return DVtrunc_dDVtrunc, $f_adjoint! end end end From 6b352e9786a589ff2635f909b5183439860f76d1 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Thu, 5 Feb 2026 09:12:58 -0500 Subject: [PATCH 4/5] pullback truncation error warning --- .../MatrixAlgebraKitMooncakeExt.jl | 23 +++++++++---------- 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl b/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl index a96cebd6..ea284aeb 100644 --- a/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl +++ b/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl @@ -168,6 +168,9 @@ for (f!, f, f_full, pb, adj) in ( end end +_warn_pullback_truncerror(dϵ::Real; tol = MatrixAlgebraKit.defaulttol(dϵ)) = + abs(dϵ) ≤ tol || @warn "Pullback ignores non-zero tangents for truncation error" + for f in (:eig, :eigh) f_trunc = Symbol(f, :_trunc) f_trunc! = Symbol(f_trunc, :!) @@ -200,7 +203,7 @@ for f in (:eig, :eigh) copy!(A, Ac) Dtrunc, Vtrunc, ϵ = Mooncake.primal(output_codual) dDtrunc_, dVtrunc_, dϵ = Mooncake.tangent(output_codual) - abs(dy[3]) > MatrixAlgebraKit.defaulttol(dy[3]) && @warn "Pullback for $f does not yet support non-zero tangent for the truncation error" + _warn_pullback_truncerror(dy[3]) D′, dD′ = arrayify(Dtrunc, dDtrunc_) V′, dV′ = arrayify(Vtrunc, dVtrunc_) $f_trunc_pullback!(dA, A, (D′, V′), (dD′, dV′)) @@ -235,8 +238,7 @@ for f in (:eig, :eigh) local $f_adjoint! let ind = ind, dDVtrunc = last.(arrayify.(DVtrunc, Base.front(Mooncake.tangent(DVtrunc_dDVtrunc)))) function $f_adjoint!((_, _, dϵ)::Tuple{NoRData, NoRData, Real}) - abs(dϵ) ≤ MatrixAlgebraKit.defaulttol(dϵ) || - @warn "Pullback for `$f!` ignores non-zero tangents for truncation error" + _warn_pullback_truncerror(dϵ) # compute pullbacks $f_pullback!(dA, Ac, DVc, dDVtrunc, ind) @@ -265,7 +267,7 @@ for f in (:eig, :eigh) function $f_adjoint!(dy::Tuple{NoRData, NoRData, T}) where {T <: Real} Dtrunc, Vtrunc, ϵ = Mooncake.primal(output_codual) dDtrunc_, dVtrunc_, dϵ = Mooncake.tangent(output_codual) - abs(dy[3]) > MatrixAlgebraKit.defaulttol(dy[3]) && @warn "Pullback for $f does not yet support non-zero tangent for the truncation error" + _warn_pullback_truncerror(dy[3]) D, dD = arrayify(Dtrunc, dDtrunc_) V, dV = arrayify(Vtrunc, dVtrunc_) $f_trunc_pullback!(dA, A, (D, V), (dD, dV)) @@ -292,8 +294,7 @@ for f in (:eig, :eigh) local $f_adjoint! let ind = ind, dDVtrunc = last.(arrayify.(DVtrunc, Base.front(Mooncake.tangent(DVtrunc_dDVtrunc)))) function $f_adjoint!((_, _, dϵ)::Tuple{NoRData, NoRData, Real}) - abs(dϵ) ≤ MatrixAlgebraKit.defaulttol(dϵ) || - @warn "Pullback for `$f_trunc` ignores non-zero tangents for truncation error" + _warn_pullback_truncerror(dϵ) $f_pullback!(dA, A, DV, dDVtrunc, ind) zero!.(dDVtrunc) # since this is allocated in this function this is probably not required return ntuple(Returns(NoRData()), 3) @@ -554,7 +555,7 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc!)}, A_dA::CoDual, USVᴴ_dUS copy!(A, Ac) Utrunc, Strunc, Vᴴtrunc, ϵ = Mooncake.primal(output_codual) dUtrunc_, dStrunc_, dVᴴtrunc_, dϵ = Mooncake.tangent(output_codual) - abs(dy[4]) > MatrixAlgebraKit.defaulttol(dy[4]) && @warn "Pullback for svd_trunc does not yet support non-zero tangent for the truncation error" + _warn_pullback_truncerror(dy[4]) U′, dU′ = arrayify(Utrunc, dUtrunc_) S′, dS′ = arrayify(Strunc, dStrunc_) Vᴴ′, dVᴴ′ = arrayify(Vᴴtrunc, dVᴴtrunc_) @@ -596,8 +597,7 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc!)}, A_dA::CoDual, USVᴴ_dUS local svd_trunc_adjoint let ind = ind, dUSVᴴtrunc = last.(arrayify.(USVᴴtrunc, Base.front(Mooncake.tangent(USVᴴtrunc_dUSVᴴtrunc)))) function svd_trunc_adjoint((_, _, _, dϵ)::Tuple{NoRData, NoRData, NoRData, Real}) - abs(dϵ) ≤ MatrixAlgebraKit.defaulttol(dϵ) || - @warn "Pullback for `svd_trunc` ignores non-zero tangents for truncation error" + _warn_pullback_truncerror(dϵ) # compute pullbacks svd_pullback!(dA, Ac, USVᴴc, dUSVᴴtrunc, ind) @@ -629,7 +629,7 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc)}, A_dA::CoDual, alg_dalg::C function svd_trunc_adjoint(dy::Tuple{NoRData, NoRData, NoRData, T}) where {T <: Real} Utrunc, Strunc, Vᴴtrunc, ϵ = Mooncake.primal(output_codual) dUtrunc_, dStrunc_, dVᴴtrunc_, dϵ = Mooncake.tangent(output_codual) - abs(dy[4]) > MatrixAlgebraKit.defaulttol(dy[4]) && @warn "Pullback for svd_trunc does not yet support non-zero tangent for the truncation error" + _warn_pullback_truncerror(dy[4]) U, dU = arrayify(Utrunc, dUtrunc_) S, dS = arrayify(Strunc, dStrunc_) Vᴴ, dVᴴ = arrayify(Vᴴtrunc, dVᴴtrunc_) @@ -658,8 +658,7 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc)}, A_dA::CoDual, alg_dalg::C local svd_trunc_adjoint let ind = ind, dUSVᴴtrunc = last.(arrayify.(USVᴴtrunc, Base.front(Mooncake.tangent(USVᴴtrunc_dUSVᴴtrunc)))) function svd_trunc_adjoint((_, _, _, dϵ)::Tuple{NoRData, NoRData, NoRData, Real}) - abs(dϵ) ≤ MatrixAlgebraKit.defaulttol(dϵ) || - @warn "Pullback for `svd_trunc` ignores non-zero tangents for truncation error" + _warn_pullback_truncerror(dϵ) svd_pullback!(dA, A, USVᴴ, dUSVᴴtrunc, ind) zero!.(dUSVᴴtrunc) # since this is allocated in this function this is probably not required return ntuple(Returns(NoRData()), 3) From 34643f8f05c450dbc7feed4377b448257e071862 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Thu, 5 Feb 2026 09:18:15 -0500 Subject: [PATCH 5/5] remove letblocks --- .../MatrixAlgebraKitMooncakeExt.jl | 149 ++++++++---------- 1 file changed, 66 insertions(+), 83 deletions(-) diff --git a/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl b/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl index ea284aeb..3a113c20 100644 --- a/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl +++ b/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl @@ -13,7 +13,6 @@ using MatrixAlgebraKit: svd_pullback!, svd_trunc_pullback!, svd_vals_pullback! using MatrixAlgebraKit: TruncatedAlgorithm using LinearAlgebra - Mooncake.tangent_type(::Type{<:MatrixAlgebraKit.AbstractAlgorithm}) = Mooncake.NoTangent @is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(copy_input), Any, Any} @@ -235,21 +234,19 @@ for f in (:eig, :eigh) DVtrunc_dDVtrunc = Mooncake.zero_fcodual((DVtrunc..., ϵ)) # define pullback - local $f_adjoint! - let ind = ind, dDVtrunc = last.(arrayify.(DVtrunc, Base.front(Mooncake.tangent(DVtrunc_dDVtrunc)))) - function $f_adjoint!((_, _, dϵ)::Tuple{NoRData, NoRData, Real}) - _warn_pullback_truncerror(dϵ) + dDVtrunc = last.(arrayify.(DVtrunc, Base.front(Mooncake.tangent(DVtrunc_dDVtrunc)))) + function $f_adjoint!((_, _, dϵ)::Tuple{NoRData, NoRData, Real}) + _warn_pullback_truncerror(dϵ) - # compute pullbacks - $f_pullback!(dA, Ac, DVc, dDVtrunc, ind) - zero!.(dDVtrunc) # since this is allocated in this function this is probably not required + # compute pullbacks + $f_pullback!(dA, Ac, DVc, dDVtrunc, ind) + zero!.(dDVtrunc) # since this is allocated in this function this is probably not required - # restore state - copy!(A, Ac) - copy!.(DV, DVc) + # restore state + copy!(A, Ac) + copy!.(DV, DVc) - return ntuple(Returns(NoRData()), 4) - end + return ntuple(Returns(NoRData()), 4) end return DVtrunc_dDVtrunc, $f_adjoint! @@ -291,14 +288,12 @@ for f in (:eig, :eigh) DVtrunc_dDVtrunc = Mooncake.zero_fcodual((DVtrunc..., ϵ)) # define pullback - local $f_adjoint! - let ind = ind, dDVtrunc = last.(arrayify.(DVtrunc, Base.front(Mooncake.tangent(DVtrunc_dDVtrunc)))) - function $f_adjoint!((_, _, dϵ)::Tuple{NoRData, NoRData, Real}) - _warn_pullback_truncerror(dϵ) - $f_pullback!(dA, A, DV, dDVtrunc, ind) - zero!.(dDVtrunc) # since this is allocated in this function this is probably not required - return ntuple(Returns(NoRData()), 3) - end + dDVtrunc = last.(arrayify.(DVtrunc, Base.front(Mooncake.tangent(DVtrunc_dDVtrunc)))) + function $f_adjoint!((_, _, dϵ)::Tuple{NoRData, NoRData, Real}) + _warn_pullback_truncerror(dϵ) + $f_pullback!(dA, A, DV, dDVtrunc, ind) + zero!.(dDVtrunc) # since this is allocated in this function this is probably not required + return ntuple(Returns(NoRData()), 3) end return DVtrunc_dDVtrunc, $f_adjoint! @@ -353,19 +348,17 @@ for f in (:eig, :eigh) DVtrunc_dDVtrunc = Mooncake.zero_fcodual(DVtrunc) # define pullback - local $f_adjoint! - let ind = ind, dDVtrunc = last.(arrayify.(DVtrunc, Mooncake.tangent(DVtrunc_dDVtrunc))) - function $f_adjoint!(::NoRData) - # compute pullbacks - $f_pullback!(dA, Ac, DVc, dDVtrunc, ind) - zero!.(dDVtrunc) # since this is allocated in this function this is probably not required - - # restore state - copy!(A, Ac) - copy!.(DV, DVc) - - return ntuple(Returns(NoRData()), 4) - end + dDVtrunc = last.(arrayify.(DVtrunc, Mooncake.tangent(DVtrunc_dDVtrunc))) + function $f_adjoint!(::NoRData) + # compute pullbacks + $f_pullback!(dA, Ac, DVc, dDVtrunc, ind) + zero!.(dDVtrunc) # since this is allocated in this function this is probably not required + + # restore state + copy!(A, Ac) + copy!.(DV, DVc) + + return ntuple(Returns(NoRData()), 4) end return DVtrunc_dDVtrunc, $f_adjoint! @@ -405,13 +398,11 @@ for f in (:eig, :eigh) DVtrunc_dDVtrunc = Mooncake.zero_fcodual(DVtrunc) # define pullback - local $f_adjoint! - let ind = ind, dDVtrunc = last.(arrayify.(DVtrunc, Mooncake.tangent(DVtrunc_dDVtrunc))) - function $f_adjoint!(::NoRData) - $f_pullback!(dA, A, DV, dDVtrunc, ind) - zero!.(dDVtrunc) # since this is allocated in this function this is probably not required - return ntuple(Returns(NoRData()), 3) - end + dDVtrunc = last.(arrayify.(DVtrunc, Mooncake.tangent(DVtrunc_dDVtrunc))) + function $f_adjoint!(::NoRData) + $f_pullback!(dA, A, DV, dDVtrunc, ind) + zero!.(dDVtrunc) # since this is allocated in this function this is probably not required + return ntuple(Returns(NoRData()), 3) end return DVtrunc_dDVtrunc, $f_adjoint! @@ -594,22 +585,20 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc!)}, A_dA::CoDual, USVᴴ_dUS USVᴴtrunc_dUSVᴴtrunc = Mooncake.zero_fcodual((USVᴴtrunc..., ϵ)) # define pullback - local svd_trunc_adjoint - let ind = ind, dUSVᴴtrunc = last.(arrayify.(USVᴴtrunc, Base.front(Mooncake.tangent(USVᴴtrunc_dUSVᴴtrunc)))) - function svd_trunc_adjoint((_, _, _, dϵ)::Tuple{NoRData, NoRData, NoRData, Real}) - _warn_pullback_truncerror(dϵ) + dUSVᴴtrunc = last.(arrayify.(USVᴴtrunc, Base.front(Mooncake.tangent(USVᴴtrunc_dUSVᴴtrunc)))) + function svd_trunc_adjoint((_, _, _, dϵ)::Tuple{NoRData, NoRData, NoRData, Real}) + _warn_pullback_truncerror(dϵ) - # compute pullbacks - svd_pullback!(dA, Ac, USVᴴc, dUSVᴴtrunc, ind) - zero!.(dUSVᴴtrunc) # since this is allocated in this function this is probably not required - zero!.(dUSVᴴ) + # compute pullbacks + svd_pullback!(dA, Ac, USVᴴc, dUSVᴴtrunc, ind) + zero!.(dUSVᴴtrunc) # since this is allocated in this function this is probably not required + zero!.(dUSVᴴ) - # restore state - copy!(A, Ac) - copy!.(USVᴴ, USVᴴc) + # restore state + copy!(A, Ac) + copy!.(USVᴴ, USVᴴc) - return ntuple(Returns(NoRData()), 4) - end + return ntuple(Returns(NoRData()), 4) end return USVᴴtrunc_dUSVᴴtrunc, svd_trunc_adjoint @@ -655,14 +644,12 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc)}, A_dA::CoDual, alg_dalg::C USVᴴtrunc_dUSVᴴtrunc = Mooncake.zero_fcodual((USVᴴtrunc..., ϵ)) # define pullback - local svd_trunc_adjoint - let ind = ind, dUSVᴴtrunc = last.(arrayify.(USVᴴtrunc, Base.front(Mooncake.tangent(USVᴴtrunc_dUSVᴴtrunc)))) - function svd_trunc_adjoint((_, _, _, dϵ)::Tuple{NoRData, NoRData, NoRData, Real}) - _warn_pullback_truncerror(dϵ) - svd_pullback!(dA, A, USVᴴ, dUSVᴴtrunc, ind) - zero!.(dUSVᴴtrunc) # since this is allocated in this function this is probably not required - return ntuple(Returns(NoRData()), 3) - end + dUSVᴴtrunc = last.(arrayify.(USVᴴtrunc, Base.front(Mooncake.tangent(USVᴴtrunc_dUSVᴴtrunc)))) + function svd_trunc_adjoint((_, _, _, dϵ)::Tuple{NoRData, NoRData, NoRData, Real}) + _warn_pullback_truncerror(dϵ) + svd_pullback!(dA, A, USVᴴ, dUSVᴴtrunc, ind) + zero!.(dUSVᴴtrunc) # since this is allocated in this function this is probably not required + return ntuple(Returns(NoRData()), 3) end return USVᴴtrunc_dUSVᴴtrunc, svd_trunc_adjoint @@ -727,20 +714,18 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc_no_error!)}, A_dA::CoDual, U USVᴴtrunc_dUSVᴴtrunc = Mooncake.zero_fcodual(USVᴴtrunc) # define pullback - local svd_trunc_adjoint - let ind = ind, dUSVᴴtrunc = last.(arrayify.(USVᴴtrunc, Mooncake.tangent(USVᴴtrunc_dUSVᴴtrunc))) - function svd_trunc_adjoint(::NoRData) - # compute pullbacks - svd_pullback!(dA, Ac, USVᴴc, dUSVᴴtrunc, ind) - zero!.(dUSVᴴtrunc) # since this is allocated in this function this is probably not required - zero!.(dUSVᴴ) - - # restore state - copy!(A, Ac) - copy!.(USVᴴ, USVᴴc) - - return ntuple(Returns(NoRData()), 4) - end + dUSVᴴtrunc = last.(arrayify.(USVᴴtrunc, Mooncake.tangent(USVᴴtrunc_dUSVᴴtrunc))) + function svd_trunc_adjoint(::NoRData) + # compute pullbacks + svd_pullback!(dA, Ac, USVᴴc, dUSVᴴtrunc, ind) + zero!.(dUSVᴴtrunc) # since this is allocated in this function this is probably not required + zero!.(dUSVᴴ) + + # restore state + copy!(A, Ac) + copy!.(USVᴴ, USVᴴc) + + return ntuple(Returns(NoRData()), 4) end return USVᴴtrunc_dUSVᴴtrunc, svd_trunc_adjoint @@ -784,13 +769,11 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc_no_error)}, A_dA::CoDual, al USVᴴtrunc_dUSVᴴtrunc = Mooncake.zero_fcodual(USVᴴtrunc) # define pullback - local svd_trunc_adjoint - let ind = ind, dUSVᴴtrunc = last.(arrayify.(USVᴴtrunc, Mooncake.tangent(USVᴴtrunc_dUSVᴴtrunc))) - function svd_trunc_adjoint(::NoRData) - svd_pullback!(dA, A, USVᴴ, dUSVᴴtrunc, ind) - zero!.(dUSVᴴtrunc) # since this is allocated in this function this is probably not required - return ntuple(Returns(NoRData()), 3) - end + dUSVᴴtrunc = last.(arrayify.(USVᴴtrunc, Mooncake.tangent(USVᴴtrunc_dUSVᴴtrunc))) + function svd_trunc_adjoint(::NoRData) + svd_pullback!(dA, A, USVᴴ, dUSVᴴtrunc, ind) + zero!.(dUSVᴴtrunc) # since this is allocated in this function this is probably not required + return ntuple(Returns(NoRData()), 3) end return USVᴴtrunc_dUSVᴴtrunc, svd_trunc_adjoint