Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions ext/MatrixAlgebraKitChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -274,4 +274,22 @@ function ChainRulesCore.rrule(::typeof(right_polar!), A, PWᴴ, alg)
return PWᴴ, right_polar_pullback
end

# Reverse-mode rule for `project_hermitian(A, alg)`.
#
# Returns the primal result together with a pullback. The pullback unthunks the
# incoming cotangent and applies `project_hermitian` to it (called without `alg`);
# the function itself and `alg` both receive `NoTangent()`.
function ChainRulesCore.rrule(::typeof(project_hermitian), A, alg)
    function project_hermitian_pullback(Δ)
        cotangent = unthunk(Δ)
        return (NoTangent(), project_hermitian(cotangent), NoTangent())
    end
    return project_hermitian(A, alg), project_hermitian_pullback
end

# Reverse-mode rule for `project_antihermitian(A, alg)`.
#
# Returns the primal result together with a pullback. The pullback unthunks the
# incoming cotangent and applies `project_antihermitian` to it (called without
# `alg`); the function itself and `alg` both receive `NoTangent()`.
function ChainRulesCore.rrule(::typeof(project_antihermitian), A, alg)
    function project_antihermitian_pullback(Δ)
        cotangent = unthunk(Δ)
        return (NoTangent(), project_antihermitian(cotangent), NoTangent())
    end
    return project_antihermitian(A, alg), project_antihermitian_pullback
end

end
58 changes: 51 additions & 7 deletions ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ for f in (:eig, :eigh)
_warn_pullback_truncerror(dϵ)

# compute pullbacks
$f_pullback!(dA, Ac, DVc, dDVtrunc, ind)
$f_pullback!(dA, Ac, DV, dDVtrunc, ind)
zero!.(dDVtrunc) # since this is allocated in this function this is probably not required

# restore state
Expand Down Expand Up @@ -351,8 +351,8 @@ for f in (:eig, :eigh)
dDVtrunc = last.(arrayify.(DVtrunc, Mooncake.tangent(DVtrunc_dDVtrunc)))
function $f_adjoint!(::NoRData)
# compute pullbacks
$f_pullback!(dA, Ac, DVc, dDVtrunc, ind)
zero!.(dDVtrunc) # since this is allocated in this function this is probably not required
$f_pullback!(dA, Ac, DV, dDVtrunc, ind)
zero!.(dDV)

# restore state
copy!(A, Ac)
Expand Down Expand Up @@ -425,7 +425,7 @@ for (f!, f) in (
S, dS = arrayify(USVᴴ[2], dUSVᴴ[2])
Vᴴ, dVᴴ = arrayify(USVᴴ[3], dUSVᴴ[3])
USVᴴc = copy.(USVᴴ)
output = $f!(A, Mooncake.primal(alg_dalg))
output = $f!(A, USVᴴ, Mooncake.primal(alg_dalg))
function svd_adjoint(::NoRData)
copy!(A, Ac)
if $(f! == svd_compact!)
Expand Down Expand Up @@ -590,7 +590,7 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc!)}, A_dA::CoDual, USVᴴ_dUS
_warn_pullback_truncerror(dϵ)

# compute pullbacks
svd_pullback!(dA, Ac, USVᴴc, dUSVᴴtrunc, ind)
svd_pullback!(dA, Ac, USVᴴ, dUSVᴴtrunc, ind)
zero!.(dUSVᴴtrunc) # since this is allocated in this function this is probably not required
zero!.(dUSVᴴ)

Expand Down Expand Up @@ -717,8 +717,7 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc_no_error!)}, A_dA::CoDual, U
dUSVᴴtrunc = last.(arrayify.(USVᴴtrunc, Mooncake.tangent(USVᴴtrunc_dUSVᴴtrunc)))
function svd_trunc_adjoint(::NoRData)
# compute pullbacks
svd_pullback!(dA, Ac, USVᴴc, dUSVᴴtrunc, ind)
zero!.(dUSVᴴtrunc) # since this is allocated in this function this is probably not required
svd_pullback!(dA, Ac, USVᴴ, dUSVᴴtrunc, ind)
zero!.(dUSVᴴ)

# restore state
Expand Down Expand Up @@ -779,4 +778,49 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc_no_error)}, A_dA::CoDual, al
return USVᴴtrunc_dUSVᴴtrunc, svd_trunc_adjoint
end

# single-output projections: project_hermitian!, project_antihermitian!
#
# For each `(f!, f, adj)` triple this loop `@eval`s two Mooncake reverse-mode rules:
#   1. one for the in-place call `f!(A, arg, alg)`, and
#   2. one for the out-of-place call `f(A, alg)`.
# `adj` is the name given to the adjoint closure returned by both rules.
for (f!, f, adj) in (
(:project_hermitian!, :project_hermitian, :project_hermitian_adjoint),
(:project_antihermitian!, :project_antihermitian, :project_antihermitian_adjoint),
)
@eval begin
# Register the in-place signature as a reverse-mode primitive in the default context.
@is_primitive DefaultCtx Mooncake.ReverseMode Tuple{typeof($f!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm}
# In-place rule: runs `f!(A, arg, alg)` on the primals and returns the `arg_darg`
# CoDual together with the adjoint closure.
function Mooncake.rrule!!(f_df::CoDual{typeof($f!)}, A_dA::CoDual, arg_darg::CoDual, alg_dalg::CoDual{<:MatrixAlgebraKit.AbstractAlgorithm})
A, dA = arrayify(A_dA)
# When the same CoDual is passed as both input and output, reuse `(A, dA)` so that
# the aliasing can be detected in the adjoint via `dA !== darg`.
arg, darg = A_dA === arg_darg ? (A, dA) : arrayify(arg_darg)
# Snapshot of the output buffer taken before the primal call; restored in the adjoint.
argc = copy(arg)
arg = $f!(A, arg, Mooncake.primal(alg_dalg))

function $adj(::NoRData)
# Apply the same projection, in place, to the output cotangent.
# NOTE(review): presumably valid because the projection is its own adjoint — confirm.
$f!(darg)
# Distinct input/output tangents: accumulate into the input tangent and reset the
# output tangent. When they are the same array, `darg` already is the input tangent.
if dA !== darg
dA .+= darg
zero!(darg)
end
# Restore `arg` to the contents it had before the forward call.
copy!(arg, argc)
# One `NoRData()` per rule argument: f!, A, arg, alg.
return ntuple(Returns(NoRData()), 4)
end

return arg_darg, $adj
end

# Register the out-of-place signature as a reverse-mode primitive in the default context.
@is_primitive DefaultCtx Mooncake.ReverseMode Tuple{typeof($f), Any, MatrixAlgebraKit.AbstractAlgorithm}
# Out-of-place rule: computes `f(A, alg)` and wraps the result with
# `Mooncake.zero_fcodual` to obtain the output CoDual.
function Mooncake.rrule!!(f_df::CoDual{typeof($f)}, A_dA::CoDual, alg_dalg::CoDual{<:MatrixAlgebraKit.AbstractAlgorithm})
A, dA = arrayify(A_dA)
output = $f(A, Mooncake.primal(alg_dalg))
output_doutput = Mooncake.zero_fcodual(output)

# Array form of the output tangent, captured by the adjoint below.
doutput = last(arrayify(output_doutput))
function $adj(::NoRData)
# TODO: need accumulating projection to avoid intermediate here
# `$f(doutput)` produces a projected copy that is then added into `dA`.
dA .+= $f(doutput)
zero!(doutput)
# One `NoRData()` per rule argument: f, A, alg.
return ntuple(Returns(NoRData()), 3)
end

return output_doutput, $adj
end
end
end

end
4 changes: 2 additions & 2 deletions src/pullbacks/polar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ function left_polar_pullback!(ΔA::AbstractMatrix, A, WP, ΔWP; kwargs...)
!iszerotangent(ΔW) && mul!(M, W', ΔW, 1, 1)
!iszerotangent(ΔP) && mul!(M, ΔP, P, -1, 1)
C = _sylvester(P, P, M' - M)
C .+= ΔP
!iszerotangent(ΔP) && (C .+= ΔP)
ΔA = mul!(ΔA, W, C, 1, 1)
if !iszerotangent(ΔW)
ΔWP = ΔW / P
Expand Down Expand Up @@ -47,7 +47,7 @@ function right_polar_pullback!(ΔA::AbstractMatrix, A, PWᴴ, ΔPWᴴ; kwargs...
!iszerotangent(ΔWᴴ) && mul!(M, ΔWᴴ, Wᴴ', 1, 1)
!iszerotangent(ΔP) && mul!(M, P, ΔP, -1, 1)
C = _sylvester(P, P, M' - M)
C .+= ΔP
!iszerotangent(ΔP) && (C .+= ΔP)
ΔA = mul!(ΔA, C, Wᴴ, 1, 1)
if !iszerotangent(ΔWᴴ)
PΔWᴴ = P \ ΔWᴴ
Expand Down
22 changes: 15 additions & 7 deletions src/pullbacks/qr.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,11 @@
function check_qr_cotangents(Q, R, ΔQ, ΔR, minmn::Int, p::Int; gauge_atol::Real = default_pullback_gauge_atol(ΔQ))
# Index of the last diagonal entry of `R` whose absolute value is at least
# `rank_atol`, or `0` when no diagonal entry reaches that threshold.
function qr_rank(R; rank_atol = default_pullback_rank_atol(R))
    last_significant = findlast(>=(rank_atol) ∘ abs, diagview(R))
    return @something last_significant 0
end

function check_qr_cotangents(
Q, R, ΔQ, ΔR, p::Int;
gauge_atol::Real = default_pullback_gauge_atol(ΔQ)
)
minmn = min(size(Q, 1), size(R, 2))
if minmn > p # case where A is rank-deficient
Δgauge = abs(zero(eltype(Q)))
if !iszerotangent(ΔQ)
Expand All @@ -7,11 +14,13 @@ function check_qr_cotangents(Q, R, ΔQ, ΔR, minmn::Int, p::Int; gauge_atol::Rea
# columns of ΔQ should be zero for a gauge-invariant
# cost function
ΔQ2 = view(ΔQ, :, (p + 1):size(Q, 2))
Δgauge = max(Δgauge, norm(ΔQ2, Inf))
Δgauge_Q = norm(ΔQ2, Inf)
Δgauge = max(Δgauge, Δgauge_Q)
end
if !iszerotangent(ΔR)
ΔR22 = view(ΔR, (p + 1):minmn, (p + 1):size(R, 2))
Δgauge = max(Δgauge, norm(ΔR22, Inf))
Δgauge_R = norm(ΔR22, Inf)
Δgauge = max(Δgauge, Δgauge_R)
end
Δgauge ≤ gauge_atol ||
@warn "`qr` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
Expand All @@ -29,7 +38,7 @@ function check_qr_full_cotangents(Q1, ΔQ2, Q1dΔQ2; gauge_atol::Real = default_
# Q2' * ΔQ2 as a gauge dependent quantity.
Δgauge = norm(mul!(copy(ΔQ2), Q1, Q1dΔQ2, -1, 1), Inf)
Δgauge ≤ gauge_atol ||
@warn "`qr` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
@warn "`qr` full cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
return
end

Expand Down Expand Up @@ -60,9 +69,8 @@ function qr_pullback!(
Q, R = QR
m = size(Q, 1)
n = size(R, 2)
minmn = min(m, n)
Rd = diagview(R)
p = @something findlast(>=(rank_atol) ∘ abs, Rd) 0
p = qr_rank(R)

ΔQ, ΔR = ΔQR

Expand All @@ -72,7 +80,7 @@ function qr_pullback!(
ΔA1 = view(ΔA, :, 1:p)
ΔA2 = view(ΔA, :, (p + 1):n)

check_qr_cotangents(Q, R, ΔQ, ΔR, minmn, p; gauge_atol)
check_qr_cotangents(Q, R, ΔQ, ΔR, p; gauge_atol)

ΔQ̃ = zero!(similar(Q, (m, p)))
if !iszerotangent(ΔQ)
Expand Down
29 changes: 0 additions & 29 deletions test/mooncake.jl

This file was deleted.

19 changes: 19 additions & 0 deletions test/mooncake/eig.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
using MatrixAlgebraKit
using Test
using LinearAlgebra: Diagonal
using CUDA, AMDGPU

# Element types exercised by this file; the full suite is too expensive on CI.
BLASFloats = (Float32, ComplexF64)
GenericFloats = ()

# Include the shared test suite only when `TestSuite` is not already defined.
@isdefined(TestSuite) || include("../testsuite/TestSuite.jl")
using .TestSuite

# True only when the BUILDKITE environment variable is exactly "true".
is_buildkite = get(ENV, "BUILDKITE", "false") == "true"

m = 19
for T in (BLASFloats..., GenericFloats...)
    # The RNG is reseeded for every element type, even when the checks are skipped.
    TestSuite.seed_rng!(123)
    is_buildkite && continue
    local tol = m * m * TestSuite.precision(T)
    TestSuite.test_mooncake_eig(T, (m, m); atol = tol, rtol = tol)
end
19 changes: 19 additions & 0 deletions test/mooncake/eigh.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
using MatrixAlgebraKit
using Test
using LinearAlgebra: Diagonal
using CUDA, AMDGPU

# Element types exercised by this file; the full suite is too expensive on CI.
BLASFloats = (Float32, ComplexF64)
GenericFloats = ()

# Include the shared test suite only when `TestSuite` is not already defined.
@isdefined(TestSuite) || include("../testsuite/TestSuite.jl")
using .TestSuite

# True only when the BUILDKITE environment variable is exactly "true".
is_buildkite = get(ENV, "BUILDKITE", "false") == "true"

m = 19
for T in (BLASFloats..., GenericFloats...)
    # The RNG is reseeded for every element type, even when the checks are skipped.
    TestSuite.seed_rng!(123)
    is_buildkite && continue
    local tol = m * m * TestSuite.precision(T)
    TestSuite.test_mooncake_eigh(T, (m, m); atol = tol, rtol = tol)
end
19 changes: 19 additions & 0 deletions test/mooncake/lq.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
using MatrixAlgebraKit
using Test
using LinearAlgebra: Diagonal
using CUDA, AMDGPU

# Element types exercised by this file; the full suite is too expensive on CI.
BLASFloats = (Float32, ComplexF64)
GenericFloats = ()

# Include the shared test suite only when `TestSuite` is not already defined.
@isdefined(TestSuite) || include("../testsuite/TestSuite.jl")
using .TestSuite

# True only when the BUILDKITE environment variable is exactly "true".
is_buildkite = get(ENV, "BUILDKITE", "false") == "true"

m = 19
# Column counts cover n < m, n == m and n > m.
for T in (BLASFloats..., GenericFloats...), n in (17, m, 23)
    # The RNG is reseeded for every (T, n) pair, even when the checks are skipped.
    TestSuite.seed_rng!(123)
    is_buildkite && continue
    local tol = m * n * TestSuite.precision(T)
    TestSuite.test_mooncake_lq(T, (m, n); atol = tol, rtol = tol)
end
19 changes: 19 additions & 0 deletions test/mooncake/orthnull.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
using MatrixAlgebraKit
using Test
using LinearAlgebra: Diagonal
using CUDA, AMDGPU

# Element types exercised by this file; the full suite is too expensive on CI.
BLASFloats = (Float32, ComplexF64)
GenericFloats = ()

# Include the shared test suite only when `TestSuite` is not already defined.
@isdefined(TestSuite) || include("../testsuite/TestSuite.jl")
using .TestSuite

# True only when the BUILDKITE environment variable is exactly "true".
is_buildkite = get(ENV, "BUILDKITE", "false") == "true"

m = 19
# Column counts cover n < m, n == m and n > m.
for T in (BLASFloats..., GenericFloats...), n in (17, m, 23)
    # The RNG is reseeded for every (T, n) pair, even when the checks are skipped.
    TestSuite.seed_rng!(123)
    is_buildkite && continue
    local tol = m * n * TestSuite.precision(T)
    TestSuite.test_mooncake_orthnull(T, (m, n); atol = tol, rtol = tol)
end
19 changes: 19 additions & 0 deletions test/mooncake/polar.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
using MatrixAlgebraKit
using Test
using LinearAlgebra: Diagonal
using CUDA, AMDGPU

# Element types exercised by this file; the full suite is too expensive on CI.
BLASFloats = (Float32, ComplexF64)
GenericFloats = ()

# Include the shared test suite only when `TestSuite` is not already defined.
@isdefined(TestSuite) || include("../testsuite/TestSuite.jl")
using .TestSuite

# True only when the BUILDKITE environment variable is exactly "true".
is_buildkite = get(ENV, "BUILDKITE", "false") == "true"

m = 19
# Column counts cover n < m, n == m and n > m.
for T in (BLASFloats..., GenericFloats...), n in (17, m, 23)
    # The RNG is reseeded for every (T, n) pair, even when the checks are skipped.
    TestSuite.seed_rng!(123)
    is_buildkite && continue
    local tol = m * n * TestSuite.precision(T)
    TestSuite.test_mooncake_polar(T, (m, n); atol = tol, rtol = tol)
end
19 changes: 19 additions & 0 deletions test/mooncake/projections.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
using MatrixAlgebraKit
using Test
using LinearAlgebra: Diagonal
using CUDA, AMDGPU

# Element types exercised by this file; the full suite is too expensive on CI.
BLASFloats = (Float32, ComplexF64)
GenericFloats = ()

# Include the shared test suite only when `TestSuite` is not already defined.
@isdefined(TestSuite) || include("../testsuite/TestSuite.jl")
using .TestSuite

# True only when the BUILDKITE environment variable is exactly "true".
is_buildkite = get(ENV, "BUILDKITE", "false") == "true"

m = 19
for T in (BLASFloats..., GenericFloats...)
    # The RNG is reseeded for every element type, even when the checks are skipped.
    TestSuite.seed_rng!(123)
    is_buildkite && continue
    local tol = m * m * TestSuite.precision(T)
    TestSuite.test_mooncake_projections(T, (m, m); atol = tol, rtol = tol)
end
19 changes: 19 additions & 0 deletions test/mooncake/qr.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
using MatrixAlgebraKit
using Test
using LinearAlgebra: Diagonal
using CUDA, AMDGPU

# Element types exercised by this file; the full suite is too expensive on CI.
BLASFloats = (Float32, ComplexF64)
GenericFloats = ()

# Include the shared test suite only when `TestSuite` is not already defined.
@isdefined(TestSuite) || include("../testsuite/TestSuite.jl")
using .TestSuite

# True only when the BUILDKITE environment variable is exactly "true".
is_buildkite = get(ENV, "BUILDKITE", "false") == "true"

m = 19
# Column counts cover n < m, n == m and n > m.
for T in (BLASFloats..., GenericFloats...), n in (17, m, 23)
    # The RNG is reseeded for every (T, n) pair, even when the checks are skipped.
    TestSuite.seed_rng!(123)
    is_buildkite && continue
    local tol = m * n * TestSuite.precision(T)
    TestSuite.test_mooncake_qr(T, (m, n); atol = tol, rtol = tol)
end
19 changes: 19 additions & 0 deletions test/mooncake/svd.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
using MatrixAlgebraKit
using Test
using LinearAlgebra: Diagonal
using CUDA, AMDGPU

# Element types exercised by this file; the full suite is too expensive on CI.
BLASFloats = (Float32, ComplexF64)
GenericFloats = ()

# Include the shared test suite only when `TestSuite` is not already defined.
@isdefined(TestSuite) || include("../testsuite/TestSuite.jl")
using .TestSuite

# True only when the BUILDKITE environment variable is exactly "true".
is_buildkite = get(ENV, "BUILDKITE", "false") == "true"

m = 19
# Column counts cover n < m, n == m and n > m.
for T in (BLASFloats..., GenericFloats...), n in (17, m, 23)
    # The RNG is reseeded for every (T, n) pair, even when the checks are skipped.
    TestSuite.seed_rng!(123)
    is_buildkite && continue
    local tol = m * n * TestSuite.precision(T)
    TestSuite.test_mooncake_svd(T, (m, n); atol = tol, rtol = tol)
end
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ if filter_tests!(testsuite, args)
is_apple_ci = Sys.isapple() && get(ENV, "CI", "false") == "true"
if is_apple_ci
delete!(testsuite, "enzyme")
delete!(testsuite, "mooncake")
filter!(p -> !startswith(first(p), "mooncake/"), testsuite)
delete!(testsuite, "chainrules")
end
Sys.iswindows() && delete!(testsuite, "enzyme")
Expand Down
Loading
Loading