Skip to content

Commit 39b3717

Browse files
WIP: fix KrylovJL_GMRES with Enzyme
1 parent 33911f6 commit 39b3717

File tree

2 files changed

+7
-2
lines changed

2 files changed

+7
-2
lines changed

ext/LinearSolveEnzymeExt.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@ using Enzyme
88

99
using EnzymeCore
1010

11+
@inline EnzymeCore.EnzymeRules.inactive_type(v::Type{LinearSolve.KrylovJL}) = true
12+
@inline EnzymeCore.EnzymeRules.inactive_type(v::Type{LinearSolve.Krylov.GmresSolver}) = true
13+
1114
function EnzymeCore.EnzymeRules.forward(
1215
func::Const{typeof(LinearSolve.init)}, ::Type{RT}, prob::EnzymeCore.Annotation{LP},
1316
alg::Const; kwargs...) where {RT, LP <: LinearSolve.LinearProblem}

test/enzyme.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ Enzyme.autodiff(Reverse, f2, Duplicated(copy(A), dA),
158158
@test db1 db12
159159
@test db2 db22
160160

161-
#=
161+
162162
function f3(A, b1, b2; alg = KrylovJL_GMRES())
163163
prob = LinearProblem(A, b1)
164164
cache = init(prob, alg)
@@ -168,12 +168,14 @@ function f3(A, b1, b2; alg = KrylovJL_GMRES())
168168
norm(s1 + s2)
169169
end
170170

171+
dA = zeros(n, n);
172+
db1 = zeros(n);
173+
db2 = zeros(n);
171174
Enzyme.autodiff(Reverse, f3, Duplicated(copy(A), dA), Duplicated(copy(b1), db1), Duplicated(copy(b2), db2))
172175

173176
@test dA dA2 atol=5e-5
174177
@test db1 db12
175178
@test db2 db22
176-
=#
177179

178180
A = rand(n, n);
179181
dA = zeros(n, n);

0 commit comments

Comments
 (0)