Skip to content

Commit d69af77

Browse files
committed
More caching
1 parent e4f0785 commit d69af77

File tree

1 file changed

+37
-14
lines changed

1 file changed

+37
-14
lines changed

ext/LinearSolveEnzymeExt.jl

Lines changed: 37 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -30,29 +30,41 @@ function EnzymeCore.EnzymeRules.augmented_primal(config, func::Const{typeof(Line
3030
(dval.b for dval in dres)
3131
end
3232

33-
return EnzymeCore.EnzymeRules.AugmentedReturn(res, dres, (d_A, d_b))
33+
34+
prob_d_A = if EnzymeRules.width(config) == 1
35+
prob.dval.A
36+
else
37+
(dval.A for dval in prob.dval)
38+
end
39+
prob_d_b = if EnzymeRules.width(config) == 1
40+
prob.dval.b
41+
else
42+
(dval.b for dval in prob.dval)
43+
end
44+
45+
return EnzymeCore.EnzymeRules.AugmentedReturn(res, dres, (d_A, d_b, prob_d_A, prob_d_b))
3446
end
3547

3648
function EnzymeCore.EnzymeRules.reverse(config, func::Const{typeof(LinearSolve.init)}, ::Type{RT}, cache, prob::EnzymeCore.Annotation{LP}, alg::Const; kwargs...) where {RT, LP <: LinearSolve.LinearProblem}
37-
d_A, d_b = cache
49+
d_A, d_b, prob_d_A, prob_d_b = cache
3850

3951
if EnzymeRules.width(config) == 1
40-
if d_A !== prob.dval.A
41-
prob.dval.A .+= d_A
52+
if d_A !== prob_d_A
53+
prob_d_A .+= d_A
4254
d_A .= 0
4355
end
44-
if d_b !== prob.dval.b
45-
prob.dval.b .+= d_b
56+
if d_b !== prob_d_b
57+
prob_d_b .+= d_b
4658
d_b .= 0
4759
end
4860
else
4961
for i in 1:EnzymeRules.width(config)
50-
if d_A !== prob.dval.A
51-
prob.dval.A[i] .+= d_A[i]
62+
if d_A !== prob_d_A[i]
63+
prob_d_A[i] .+= d_A[i]
5264
d_A[i] .= 0
5365
end
54-
if d_b !== prob.dval.b
55-
prob.dval.b[i] .+= d_b[i]
66+
if d_b !== prob_d_b[i]
67+
prob_d_b[i] .+= d_b[i]
5668
d_b[i] .= 0
5769
end
5870
end
@@ -87,22 +99,33 @@ function EnzymeCore.EnzymeRules.augmented_primal(config, func::Const{typeof(Line
8799
resvals = if EnzymeRules.width(config) == 1
88100
dres.u
89101
else
90-
(dr.u for dr in dres)
102+
ntuple(Val(EnzymeRules.width(config))) do i
103+
Base.@_inline_meta
104+
dres[i].u
105+
end
91106
end
92107

93108
dAs = if EnzymeRules.width(config) == 1
94109
(linsolve.dval.A,)
95110
else
96-
(dval.A for dval in linsolve.dval)
111+
ntuple(Val(EnzymeRules.width(config))) do i
112+
Base.@_inline_meta
113+
linsolve.dval[i].A
114+
end
97115
end
98116

99117
dbs = if EnzymeRules.width(config) == 1
100118
(linsolve.dval.b,)
101119
else
102-
(dval.b for dval in linsolve.dval)
120+
ntuple(Val(EnzymeRules.width(config))) do i
121+
Base.@_inline_meta
122+
linsolve.dval[i].b
123+
end
103124
end
104125

105-
cache = (res, resvals, deepcopy(linsolve.val), dAs, dbs)
126+
cachesolve = deepcopy(linsolve.val)
127+
128+
cache = (copy(res.u), resvals, cachesolve, dAs, dbs)
106129
return EnzymeCore.EnzymeRules.AugmentedReturn(res, dres, cache)
107130
end
108131

0 commit comments

Comments
 (0)