Skip to content

Commit e4f0785

Browse files
committed
fix mutated db
1 parent 54f0722 commit e4f0785

File tree

1 file changed

+15
-14
lines changed

1 file changed

+15
-14
lines changed

ext/LinearSolveEnzymeExt.jl

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ function EnzymeCore.EnzymeRules.augmented_primal(config, func::Const{typeof(Line
2929
else
3030
(dval.b for dval in dres)
3131
end
32+
3233
return EnzymeCore.EnzymeRules.AugmentedReturn(res, dres, (d_A, d_b))
3334
end
3435

@@ -89,20 +90,6 @@ function EnzymeCore.EnzymeRules.augmented_primal(config, func::Const{typeof(Line
8990
(dr.u for dr in dres)
9091
end
9192

92-
cache = (res, resvals, deepcopy(linsolve.val))
93-
return EnzymeCore.EnzymeRules.AugmentedReturn(res, dres, cache)
94-
end
95-
96-
function EnzymeCore.EnzymeRules.reverse(config, func::Const{typeof(LinearSolve.solve!)}, ::Type{RT}, cache, linsolve::EnzymeCore.Annotation{LP}; kwargs...) where {RT, LP <: LinearSolve.LinearCache}
97-
y, dys, _linsolve = cache
98-
99-
@assert !(typeof(linsolve) <: Const)
100-
@assert !(typeof(linsolve) <: Active)
101-
102-
if EnzymeRules.width(config) == 1
103-
dys = (dys,)
104-
end
105-
10693
dAs = if EnzymeRules.width(config) == 1
10794
(linsolve.dval.A,)
10895
else
@@ -115,6 +102,20 @@ function EnzymeCore.EnzymeRules.reverse(config, func::Const{typeof(LinearSolve.s
115102
(dval.b for dval in linsolve.dval)
116103
end
117104

105+
cache = (res, resvals, deepcopy(linsolve.val), dAs, dbs)
106+
return EnzymeCore.EnzymeRules.AugmentedReturn(res, dres, cache)
107+
end
108+
109+
function EnzymeCore.EnzymeRules.reverse(config, func::Const{typeof(LinearSolve.solve!)}, ::Type{RT}, cache, linsolve::EnzymeCore.Annotation{LP}; kwargs...) where {RT, LP <: LinearSolve.LinearCache}
110+
y, dys, _linsolve, dAs, dbs = cache
111+
112+
@assert !(typeof(linsolve) <: Const)
113+
@assert !(typeof(linsolve) <: Active)
114+
115+
if EnzymeRules.width(config) == 1
116+
dys = (dys,)
117+
end
118+
118119
for (dA, db, dy) in zip(dAs, dbs, dys)
119120
z = if _linsolve.cacheval isa Factorization
120121
_linsolve.cacheval' \ dy

0 commit comments

Comments
 (0)