@@ -29,6 +29,7 @@ function EnzymeCore.EnzymeRules.augmented_primal(config, func::Const{typeof(Line
29
29
else
30
30
(dval. b for dval in dres)
31
31
end
32
+
32
33
return EnzymeCore. EnzymeRules. AugmentedReturn (res, dres, (d_A, d_b))
33
34
end
34
35
@@ -89,20 +90,6 @@ function EnzymeCore.EnzymeRules.augmented_primal(config, func::Const{typeof(Line
89
90
(dr. u for dr in dres)
90
91
end
91
92
92
- cache = (res, resvals, deepcopy (linsolve. val))
93
- return EnzymeCore. EnzymeRules. AugmentedReturn (res, dres, cache)
94
- end
95
-
96
- function EnzymeCore. EnzymeRules. reverse (config, func:: Const{typeof(LinearSolve.solve!)} , :: Type{RT} , cache, linsolve:: EnzymeCore.Annotation{LP} ; kwargs... ) where {RT, LP <: LinearSolve.LinearCache }
97
- y, dys, _linsolve = cache
98
-
99
- @assert ! (typeof (linsolve) <: Const )
100
- @assert ! (typeof (linsolve) <: Active )
101
-
102
- if EnzymeRules. width (config) == 1
103
- dys = (dys,)
104
- end
105
-
106
93
dAs = if EnzymeRules. width (config) == 1
107
94
(linsolve. dval. A,)
108
95
else
@@ -115,6 +102,20 @@ function EnzymeCore.EnzymeRules.reverse(config, func::Const{typeof(LinearSolve.s
115
102
(dval. b for dval in linsolve. dval)
116
103
end
117
104
105
+ cache = (res, resvals, deepcopy (linsolve. val), dAs, dbs)
106
+ return EnzymeCore. EnzymeRules. AugmentedReturn (res, dres, cache)
107
+ end
108
+
109
+ function EnzymeCore. EnzymeRules. reverse (config, func:: Const{typeof(LinearSolve.solve!)} , :: Type{RT} , cache, linsolve:: EnzymeCore.Annotation{LP} ; kwargs... ) where {RT, LP <: LinearSolve.LinearCache }
110
+ y, dys, _linsolve, dAs, dbs = cache
111
+
112
+ @assert ! (typeof (linsolve) <: Const )
113
+ @assert ! (typeof (linsolve) <: Active )
114
+
115
+ if EnzymeRules. width (config) == 1
116
+ dys = (dys,)
117
+ end
118
+
118
119
for (dA, db, dy) in zip (dAs, dbs, dys)
119
120
z = if _linsolve. cacheval isa Factorization
120
121
_linsolve. cacheval' \ dy
0 commit comments