@@ -3,8 +3,9 @@ module LinearSolveEnzymeExt
3
3
using LinearSolve
4
4
using LinearSolve. LinearAlgebra
5
5
using EnzymeCore
6
+ using EnzymeCore: EnzymeRules
6
7
7
- function EnzymeCore . EnzymeRules. forward (config:: EnzymeCore. EnzymeRules.FwdConfigWidth{1} ,
8
+ function EnzymeRules. forward (config:: EnzymeRules.FwdConfigWidth{1} ,
8
9
func:: Const{typeof(LinearSolve.init)} , :: Type{RT} , prob:: EnzymeCore.Annotation{LP} ,
9
10
alg:: Const ; kwargs... ) where {RT, LP <: LinearSolve.LinearProblem }
10
11
@assert ! (prob isa Const)
@@ -19,26 +20,20 @@ function EnzymeCore.EnzymeRules.forward(config::EnzymeCore.EnzymeRules.FwdConfig
19
20
dres = func. val (prob. dval, alg. val; kwargs... )
20
21
dres. b .= res. b == dres. b ? zero (dres. b) : dres. b
21
22
dres. A .= res. A == dres. A ? zero (dres. A) : dres. A
22
- if RT <: DuplicatedNoNeed
23
- return dres
24
- elseif RT <: Duplicated
25
- return Duplicated (res, dres)
26
- end
27
- error (" Unsupported return type $RT " )
28
23
29
24
if EnzymeRules. needs_primal (config) && EnzymeRules. needs_shadow (config)
30
- Duplicated (res, dres)
25
+ return Duplicated (res, dres)
31
26
elseif EnzymeRules. needs_shadow (config)
32
- dres
27
+ return dres
33
28
elseif EnzymeRules. needs_primal (config)
34
- res
29
+ return res
35
30
else
36
- nothing
31
+ return nothing
37
32
end
38
33
end
39
34
40
- function EnzymeCore . EnzymeRules. forward (
41
- config:: EnzymeCore. EnzymeRules.FwdConfigWidth{1} , func:: Const{typeof(LinearSolve.solve!)} ,
35
+ function EnzymeRules. forward (
36
+ config:: EnzymeRules.FwdConfigWidth{1} , func:: Const{typeof(LinearSolve.solve!)} ,
42
37
:: Type{RT} , linsolve:: EnzymeCore.Annotation{LP} ;
43
38
kwargs... ) where {RT, LP <: LinearSolve.LinearCache }
44
39
@assert ! (linsolve isa Const)
@@ -66,17 +61,17 @@ function EnzymeCore.EnzymeRules.forward(
66
61
linsolve. val. b = b
67
62
68
63
if EnzymeRules. needs_primal (config) && EnzymeRules. needs_shadow (config)
69
- Duplicated (res, dres)
64
+ return Duplicated (res, dres)
70
65
elseif EnzymeRules. needs_shadow (config)
71
- dres
66
+ return dres
72
67
elseif EnzymeRules. needs_primal (config)
73
- res
68
+ return res
74
69
else
75
- nothing
70
+ return nothing
76
71
end
77
72
end
78
73
79
- function EnzymeCore . EnzymeRules. augmented_primal (
74
+ function EnzymeRules. augmented_primal (
80
75
config, func:: Const{typeof(LinearSolve.init)} ,
81
76
:: Type{RT} , prob:: EnzymeCore.Annotation{LP} , alg:: Const ;
82
77
kwargs... ) where {RT, LP <: LinearSolve.LinearProblem }
@@ -111,10 +106,10 @@ function EnzymeCore.EnzymeRules.augmented_primal(
111
106
(dval. b for dval in prob. dval)
112
107
end
113
108
114
- return EnzymeCore . EnzymeRules. AugmentedReturn (res, dres, (d_A, d_b, prob_d_A, prob_d_b))
109
+ return EnzymeRules. AugmentedReturn (res, dres, (d_A, d_b, prob_d_A, prob_d_b))
115
110
end
116
111
117
- function EnzymeCore . EnzymeRules. reverse (
112
+ function EnzymeRules. reverse (
118
113
config, func:: Const{typeof(LinearSolve.init)} , :: Type{RT} ,
119
114
cache, prob:: EnzymeCore.Annotation{LP} , alg:: Const ;
120
115
kwargs... ) where {RT, LP <: LinearSolve.LinearProblem }
148
143
# y=inv(A) B
149
144
# dA −= z y^T
150
145
# dB += z, where z = inv(A^T) dy
151
- function EnzymeCore . EnzymeRules. augmented_primal (
146
+ function EnzymeRules. augmented_primal (
152
147
config, func:: Const{typeof(LinearSolve.solve!)} ,
153
148
:: Type{RT} , linsolve:: EnzymeCore.Annotation{LP} ;
154
149
kwargs... ) where {RT, LP <: LinearSolve.LinearCache }
@@ -201,10 +196,10 @@ function EnzymeCore.EnzymeRules.augmented_primal(
201
196
cachesolve = deepcopy (linsolve. val)
202
197
203
198
cache = (copy (res. u), resvals, cachesolve, dAs, dbs)
204
- return EnzymeCore . EnzymeRules. AugmentedReturn (res, dres, cache)
199
+ return EnzymeRules. AugmentedReturn (res, dres, cache)
205
200
end
206
201
207
- function EnzymeCore . EnzymeRules. reverse (config, func:: Const{typeof(LinearSolve.solve!)} ,
202
+ function EnzymeRules. reverse (config, func:: Const{typeof(LinearSolve.solve!)} ,
208
203
:: Type{RT} , cache, linsolve:: EnzymeCore.Annotation{LP} ;
209
204
kwargs... ) where {RT, LP <: LinearSolve.LinearCache }
210
205
y, dys, _linsolve, dAs, dbs = cache
0 commit comments