@@ -30,29 +30,41 @@ function EnzymeCore.EnzymeRules.augmented_primal(config, func::Const{typeof(Line
30
30
(dval. b for dval in dres)
31
31
end
32
32
33
- return EnzymeCore. EnzymeRules. AugmentedReturn (res, dres, (d_A, d_b))
33
+
34
+ prob_d_A = if EnzymeRules. width (config) == 1
35
+ prob. dval. A
36
+ else
37
+ (dval. A for dval in prob. dval)
38
+ end
39
+ prob_d_b = if EnzymeRules. width (config) == 1
40
+ prob. dval. b
41
+ else
42
+ (dval. b for dval in prob. dval)
43
+ end
44
+
45
+ return EnzymeCore. EnzymeRules. AugmentedReturn (res, dres, (d_A, d_b, prob_d_A, prob_d_b))
34
46
end
35
47
36
48
function EnzymeCore. EnzymeRules. reverse (config, func:: Const{typeof(LinearSolve.init)} , :: Type{RT} , cache, prob:: EnzymeCore.Annotation{LP} , alg:: Const ; kwargs... ) where {RT, LP <: LinearSolve.LinearProblem }
37
- d_A, d_b = cache
49
+ d_A, d_b, prob_d_A, prob_d_b = cache
38
50
39
51
if EnzymeRules. width (config) == 1
40
- if d_A != = prob . dval . A
41
- prob . dval . A .+ = d_A
52
+ if d_A != = prob_d_A
53
+ prob_d_A .+ = d_A
42
54
d_A .= 0
43
55
end
44
- if d_b != = prob . dval . b
45
- prob . dval . b .+ = d_b
56
+ if d_b != = prob_d_b
57
+ prob_d_b .+ = d_b
46
58
d_b .= 0
47
59
end
48
60
else
49
61
for i in 1 : EnzymeRules. width (config)
50
- if d_A != = prob . dval . A
51
- prob . dval . A [i] .+ = d_A[i]
62
+ if d_A != = prob_d_A[i]
63
+ prob_d_A [i] .+ = d_A[i]
52
64
d_A[i] .= 0
53
65
end
54
- if d_b != = prob . dval . b
55
- prob . dval . b [i] .+ = d_b[i]
66
+ if d_b != = prob_d_b[i]
67
+ prob_d_b [i] .+ = d_b[i]
56
68
d_b[i] .= 0
57
69
end
58
70
end
@@ -87,22 +99,33 @@ function EnzymeCore.EnzymeRules.augmented_primal(config, func::Const{typeof(Line
87
99
resvals = if EnzymeRules. width (config) == 1
88
100
dres. u
89
101
else
90
- (dr. u for dr in dres)
102
+ ntuple (Val (EnzymeRules. width (config))) do i
103
+ Base. @_inline_meta
104
+ dres[i]. u
105
+ end
91
106
end
92
107
93
108
dAs = if EnzymeRules. width (config) == 1
94
109
(linsolve. dval. A,)
95
110
else
96
- (dval. A for dval in linsolve. dval)
111
+ ntuple (Val (EnzymeRules. width (config))) do i
112
+ Base. @_inline_meta
113
+ linsolve. dval[i]. A
114
+ end
97
115
end
98
116
99
117
dbs = if EnzymeRules. width (config) == 1
100
118
(linsolve. dval. b,)
101
119
else
102
- (dval. b for dval in linsolve. dval)
120
+ ntuple (Val (EnzymeRules. width (config))) do i
121
+ Base. @_inline_meta
122
+ linsolve. dval[i]. b
123
+ end
103
124
end
104
125
105
- cache = (res, resvals, deepcopy (linsolve. val), dAs, dbs)
126
+ cachesolve = deepcopy (linsolve. val)
127
+
128
+ cache = (copy (res. u), resvals, cachesolve, dAs, dbs)
106
129
return EnzymeCore. EnzymeRules. AugmentedReturn (res, dres, cache)
107
130
end
108
131
0 commit comments