@@ -13,11 +13,6 @@ function ChainRulesCore.rrule(::Type{MTK.MTKParameters}, tunables, args...)
13
13
MTK. MTKParameters (tunables, args... ), mtp_pullback
14
14
end
15
15
16
- notangent_or_else (:: NoTangent , _, x) = x
17
- notangent_or_else (_, x, _) = x
18
- notangent_fallback (x, y) = notangent_or_else (x, x, y)
19
- reduce_to_notangent (x, y) = notangent_or_else (x, y, x)
20
-
21
16
function subset_idxs (idxs, portion, template)
22
17
ntuple (Val (length (template))) do subi
23
18
[Base. tail (idx. idx) for idx in idxs if idx. portion == portion && idx. idx[1 ] == subi]
@@ -83,21 +78,23 @@ function ChainRulesCore.rrule(
83
78
const_idxs = subset_idxs (idxs, MTK. SciMLStructures. Constants (), oldbuf. constant)
84
79
nn_idxs = subset_idxs (idxs, MTK. NONNUMERIC_PORTION, oldbuf. nonnumeric)
85
80
86
- function remake_buffer_pullback (buf′)
87
- buf′ = unthunk (buf′)
88
- f′ = NoTangent ()
89
- indp′ = NoTangent ()
81
+ pullback = let idxs = idxs
82
+ function remake_buffer_pullback (buf′)
83
+ buf′ = unthunk (buf′)
84
+ f′ = NoTangent ()
85
+ indp′ = NoTangent ()
90
86
91
- tunable = selected_tangents (buf′. tunable, tunable_idxs)
92
- discrete = selected_tangents (buf′. discrete, disc_idxs)
93
- constant = selected_tangents (buf′. constant, const_idxs)
94
- nonnumeric = selected_tangents (buf′. nonnumeric, nn_idxs)
95
- oldbuf′ = Tangent {typeof(oldbuf)} (; tunable, discrete, constant, nonnumeric)
96
- idxs′ = NoTangent ()
97
- vals′ = map (i -> MTK. _ducktyped_parameter_values (buf′, i), idxs)
98
- return f′, indp′, oldbuf′, idxs′, vals′
87
+ tunable = selected_tangents (buf′. tunable, tunable_idxs)
88
+ discrete = selected_tangents (buf′. discrete, disc_idxs)
89
+ constant = selected_tangents (buf′. constant, const_idxs)
90
+ nonnumeric = selected_tangents (buf′. nonnumeric, nn_idxs)
91
+ oldbuf′ = Tangent {typeof(oldbuf)} (; tunable, discrete, constant, nonnumeric)
92
+ idxs′ = NoTangent ()
93
+ vals′ = map (i -> MTK. _ducktyped_parameter_values (buf′, i), idxs)
94
+ return f′, indp′, oldbuf′, idxs′, vals′
95
+ end
99
96
end
100
- newbuf, remake_buffer_pullback
97
+ newbuf, pullback
101
98
end
102
99
103
100
end
0 commit comments