Skip to content

Commit 55e3452

Browse files
feat: improve type-stability of adjoint
1 parent a246d55 commit 55e3452

File tree

2 files changed

+22
-19
lines changed

2 files changed

+22
-19
lines changed

ext/MTKChainRulesCoreExt.jl

Lines changed: 15 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,6 @@ function ChainRulesCore.rrule(::Type{MTK.MTKParameters}, tunables, args...)
1313
MTK.MTKParameters(tunables, args...), mtp_pullback
1414
end
1515

16-
notangent_or_else(::NoTangent, _, x) = x
17-
notangent_or_else(_, x, _) = x
18-
notangent_fallback(x, y) = notangent_or_else(x, x, y)
19-
reduce_to_notangent(x, y) = notangent_or_else(x, y, x)
20-
2116
function subset_idxs(idxs, portion, template)
2217
ntuple(Val(length(template))) do subi
2318
[Base.tail(idx.idx) for idx in idxs if idx.portion == portion && idx.idx[1] == subi]
@@ -83,21 +78,23 @@ function ChainRulesCore.rrule(
8378
const_idxs = subset_idxs(idxs, MTK.SciMLStructures.Constants(), oldbuf.constant)
8479
nn_idxs = subset_idxs(idxs, MTK.NONNUMERIC_PORTION, oldbuf.nonnumeric)
8580

86-
function remake_buffer_pullback(buf′)
87-
buf′ = unthunk(buf′)
88-
f′ = NoTangent()
89-
indp′ = NoTangent()
81+
pullback = let idxs = idxs
82+
function remake_buffer_pullback(buf′)
83+
buf′ = unthunk(buf′)
84+
f′ = NoTangent()
85+
indp′ = NoTangent()
9086

91-
tunable = selected_tangents(buf′.tunable, tunable_idxs)
92-
discrete = selected_tangents(buf′.discrete, disc_idxs)
93-
constant = selected_tangents(buf′.constant, const_idxs)
94-
nonnumeric = selected_tangents(buf′.nonnumeric, nn_idxs)
95-
oldbuf′ = Tangent{typeof(oldbuf)}(; tunable, discrete, constant, nonnumeric)
96-
idxs′ = NoTangent()
97-
vals′ = map(i -> MTK._ducktyped_parameter_values(buf′, i), idxs)
98-
return f′, indp′, oldbuf′, idxs′, vals′
87+
tunable = selected_tangents(buf′.tunable, tunable_idxs)
88+
discrete = selected_tangents(buf′.discrete, disc_idxs)
89+
constant = selected_tangents(buf′.constant, const_idxs)
90+
nonnumeric = selected_tangents(buf′.nonnumeric, nn_idxs)
91+
oldbuf′ = Tangent{typeof(oldbuf)}(; tunable, discrete, constant, nonnumeric)
92+
idxs′ = NoTangent()
93+
vals′ = map(i -> MTK._ducktyped_parameter_values(buf′, i), idxs)
94+
return f′, indp′, oldbuf′, idxs′, vals′
95+
end
9996
end
100-
newbuf, remake_buffer_pullback
97+
newbuf, pullback
10198
end
10299

103100
end

test/extensions/ad.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ using SciMLSensitivity
88
using ForwardDiff
99
using ChainRulesCore
1010
using ChainRulesCore: NoTangent
11-
using ChainRulesTestUtils: test_rrule
11+
using ChainRulesTestUtils: test_rrule, rand_tangent
1212

1313
@variables x(t)[1:3] y(t)
1414
@parameters p[1:3, 1:3] q
@@ -97,3 +97,9 @@ for (_idxs, vals) in [
9797
end
9898
end
9999
end
100+
101+
idxs = (parameter_index(sys, a), parameter_index(sys, b))
102+
vals = (1.0f0, 3ones(Float32, 3))
103+
tangent = rand_tangent(ps)
104+
fwd, back = ChainRulesCore.rrule(remake_buffer, sys, ps, idxs, vals)
105+
@inferred back(tangent)

0 commit comments

Comments
 (0)