Skip to content

Commit 6f3f2cd

Browse files
Merge pull request #405 from SciML/enzyme_batch
Fix enzyme batch mode
2 parents 5a8aa51 + e56227a commit 6f3f2cd

File tree

2 files changed

+42
-20
lines changed

2 files changed

+42
-20
lines changed

ext/LinearSolveEnzymeExt.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -58,14 +58,14 @@ function EnzymeCore.EnzymeRules.reverse(config, func::Const{typeof(LinearSolve.i
5858
d_b .= 0
5959
end
6060
else
61-
for i in 1:EnzymeRules.width(config)
62-
if d_A !== prob_d_A[i]
63-
prob_d_A[i] .+= d_A[i]
64-
d_A[i] .= 0
61+
for (_prob_d_A,_d_A,_prob_d_b, _d_b) in zip(prob_d_A, d_A, prob_d_b, d_b)
62+
if _d_A !== _prob_d_A
63+
_prob_d_A .+= _d_A
64+
_d_A .= 0
6565
end
66-
if d_b !== prob_d_b[i]
67-
prob_d_b[i] .+= d_b[i]
68-
d_b[i] .= 0
66+
if _d_b !== _prob_d_b
67+
_prob_d_b .+= _d_b
68+
_d_b .= 0
6969
end
7070
end
7171
end

test/enzyme.jl

Lines changed: 35 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -33,15 +33,25 @@ b1 = rand(n);
3333
db1 = zeros(n);
3434
db12 = zeros(n);
3535

36-
#=
37-
# Batch test fails
38-
# Captured in MWE: https://github.com/EnzymeAD/Enzyme.jl/issues/1075
36+
# Batch test
37+
n = 4
38+
A = rand(n, n);
39+
dA = zeros(n, n);
40+
dA2 = zeros(n, n);
41+
b1 = rand(n);
42+
db1 = zeros(n);
43+
db12 = zeros(n);
3944

40-
function fbatch(y, A, b1; alg = LUFactorization())
45+
function f(A, b1; alg = LUFactorization())
4146
prob = LinearProblem(A, b1)
42-
4347
sol1 = solve(prob, alg)
48+
s1 = sol1.u
49+
norm(s1)
50+
end
4451

52+
function fbatch(y, A, b1; alg = LUFactorization())
53+
prob = LinearProblem(A, b1)
54+
sol1 = solve(prob, alg)
4555
s1 = sol1.u
4656
y[1] = norm(s1)
4757
nothing
@@ -50,16 +60,28 @@ end
5060
y = [0.0]
5161
dy1 = [1.0]
5262
dy2 = [1.0]
63+
Enzyme.autodiff(Reverse, fbatch, Duplicated(y, dy1), Duplicated(copy(A), dA), Duplicated(copy(b1), db1))
64+
65+
@test y[1] f(copy(A),b1)
66+
dA_2 = ForwardDiff.gradient(x->f(x,eltype(x).(b1)), copy(A))
67+
db1_2 = ForwardDiff.gradient(x->f(eltype(x).(A),x), copy(b1))
68+
69+
@test dA dA_2
70+
@test db1 db1_2
71+
72+
y .= 0
73+
dy1 .= 1
74+
dy2 .= 1
75+
dA .= 0
76+
dA2 .= 0
77+
db1 .= 0
78+
db12 .= 0
5379
Enzyme.autodiff(Reverse, fbatch, BatchDuplicated(y, (dy1, dy2)), BatchDuplicated(copy(A), (dA, dA2)), BatchDuplicated(copy(b1), (db1, db12)))
5480

55-
dA2 = ForwardDiff.gradient(x->f(x,eltype(x).(b1)), copy(A))
56-
db12 = ForwardDiff.gradient(x->f(eltype(x).(A),x), copy(b1))
57-
58-
@test_broken dA ≈ dA_2
59-
@test_broken dA2 ≈ dA_2
60-
@test_broken db1 ≈ db1_2
61-
@test_broken db12 ≈ db1_2
62-
=#
81+
@test dA dA_2
82+
@test db1 db1_2
83+
@test dA2 dA_2
84+
@test db12 db1_2
6385

6486
function f(A, b1, b2; alg = LUFactorization())
6587
prob = LinearProblem(A, b1)

0 commit comments

Comments
 (0)