Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 52 additions & 24 deletions test/nopre/enzyme.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,25 +42,32 @@ _ff = (
)
_ff(copy(A), copy(b1))

Enzyme.autodiff(
Reverse,
(
x,
y,
) -> f(
x,
y;
alg = LinearSolve.DefaultLinearSolver(LinearSolve.DefaultAlgorithmChoice.LUFactorization)
),
Duplicated(copy(A), dA),
Duplicated(copy(b1), db1)
)
# Enzyme >= 0.13.155 crashes compiling any primal containing
# RecursiveFactorization's explicit-SIMD kernels (vector GEP MethodError in
# abs_typeof, https://github.com/EnzymeAD/Enzyme.jl/issues/3164). With
# RecursiveFactorization loaded, DefaultLinearSolver compiles the RFLU branch,
# so this site is affected. The @test_broken flips to an unexpected-pass error
# when a fixed Enzyme release lands — unbreak this and the other #3164 sites.
@test_broken begin
Enzyme.autodiff(
Reverse,
(
x,
y,
) -> f(
x,
y;
alg = LinearSolve.DefaultLinearSolver(LinearSolve.DefaultAlgorithmChoice.LUFactorization)
),
Duplicated(copy(A), dA),
Duplicated(copy(b1), db1)
)

dA2 = ForwardDiff.gradient(x -> f(x, eltype(x).(b1)), copy(A))
db12 = ForwardDiff.gradient(x -> f(eltype(x).(A), x), copy(b1))
dA2 = ForwardDiff.gradient(x -> f(x, eltype(x).(b1)), copy(A))
db12 = ForwardDiff.gradient(x -> f(eltype(x).(A), x), copy(b1))

@test dA ≈ dA2
@test db1 ≈ db12
dA ≈ dA2 && db1 ≈ db12
end

A = rand(n, n);
dA = zeros(n, n);
Expand Down Expand Up @@ -167,14 +174,16 @@ f2(A, b1, b2)
dA = zeros(n, n);
db1 = zeros(n);
db2 = zeros(n);
Enzyme.autodiff(
Reverse, f2, Duplicated(copy(A), dA),
Duplicated(copy(b1), db1), Duplicated(copy(b2), db2)
)
# Broken by the Enzyme >= 0.13.155 RecursiveFactorization regression
# (https://github.com/EnzymeAD/Enzyme.jl/issues/3164)
@test_broken begin
Enzyme.autodiff(
Reverse, f2, Duplicated(copy(A), dA),
Duplicated(copy(b1), db1), Duplicated(copy(b2), db2)
)

@test dA ≈ dA2
@test db1 ≈ db12
@test db2 ≈ db22
dA ≈ dA2 && db1 ≈ db12 && db2 ≈ db22
end

function f3(A, b1, b2; alg = KrylovJL_GMRES())
prob = LinearProblem(A, b1)
Expand Down Expand Up @@ -245,6 +254,25 @@ end
LUFactorization(),
RFLUFactorization(), # KrylovJL_GMRES(), fails
)
# Forward mode over RFLU also hits the Enzyme >= 0.13.155 regression
# (https://github.com/EnzymeAD/Enzyme.jl/issues/3164)
if alg isa RFLUFactorization
@test_broken begin
en_jac = map(onehot(b1)) do db1
return only(
Enzyme.autodiff(
set_runtime_activity(Forward), fnice,
Const(A), Duplicated(b1, db1), Const(alg)
)
)
end |> collect
fd_jac = FiniteDiff.finite_difference_jacobian(b -> fnice(A, b, alg), b1) |>
vec
isapprox(en_jac, fd_jac, rtol = 1.0e-4)
end
continue
end

fb_closure = b -> fnice(A, b, alg)

fd_jac = FiniteDiff.finite_difference_jacobian(fb_closure, b1) |> vec
Expand Down
Loading