Skip to content

Commit 6d63145

Browse files
feat: add adjoint for remake_buffer
1 parent 0e6e842 commit 6d63145

File tree

4 files changed

+145
-2
lines changed

4 files changed

+145
-2
lines changed

ext/MTKChainRulesCoreExt.jl

Lines changed: 89 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,102 @@ module MTKChainRulesCoreExt
22

33
import ModelingToolkit as MTK
44
import ChainRulesCore
5-
import ChainRulesCore: NoTangent
5+
import ChainRulesCore: Tangent, ZeroTangent, NoTangent, zero_tangent, unthunk
66

77
function ChainRulesCore.rrule(::Type{MTK.MTKParameters}, tunables, args...)
88
function mtp_pullback(dt)
9+
dt = unthunk(dt)
910
(NoTangent(), dt.tunable[1:length(tunables)],
1011
ntuple(_ -> NoTangent(), length(args))...)
1112
end
1213
MTK.MTKParameters(tunables, args...), mtp_pullback
1314
end
1415

16+
notangent_or_else(::NoTangent, _, x) = x
17+
notangent_or_else(_, x, _) = x
18+
notangent_fallback(x, y) = notangent_or_else(x, x, y)
19+
reduce_to_notangent(x, y) = notangent_or_else(x, y, x)
20+
21+
function subset_idxs(idxs, portion, template)
22+
ntuple(Val(length(template))) do subi
23+
[Base.tail(idx.idx) for idx in idxs if idx.portion == portion && idx.idx[1] == subi]
24+
end
25+
end
26+
27+
selected_tangents(::NoTangent, _) = ()
28+
selected_tangents(::ZeroTangent, _) = ZeroTangent()
29+
function selected_tangents(
30+
tangents::AbstractArray{T}, idxs::Vector{Tuple{Int}}) where {T <: Number}
31+
selected_tangents(tangents, map(only, idxs))
32+
end
33+
function selected_tangents(tangents::AbstractArray{T}, idxs...) where {T <: Number}
34+
newtangents = copy(tangents)
35+
view(newtangents, idxs...) .= zero(T)
36+
newtangents
37+
end
38+
function selected_tangents(
39+
tangents::AbstractVector{T}, idxs) where {S <: Number, T <: AbstractArray{S}}
40+
newtangents = copy(tangents)
41+
for i in idxs
42+
j, k... = i
43+
if k == ()
44+
newtangents[j] = zero(newtangents[j])
45+
else
46+
newtangents[j] = selected_tangents(newtangents[j], k...)
47+
end
48+
end
49+
newtangents
50+
end
51+
function selected_tangents(tangents::AbstractVector{T}, idxs) where {T <: AbstractArray}
52+
newtangents = similar(tangents, Union{T, NoTangent})
53+
copyto!(newtangents, tangents)
54+
for i in idxs
55+
j, k... = i
56+
if k == ()
57+
newtangents[j] = NoTangent()
58+
else
59+
newtangents[j] = selected_tangents(newtangents[j], k...)
60+
end
61+
end
62+
newtangents
63+
end
64+
function selected_tangents(
65+
tangents::Union{Tangent{<:Tuple}, Tangent{T, <:Tuple}}, idxs) where {T}
66+
ntuple(Val(length(tangents))) do i
67+
selected_tangents(tangents[i], idxs[i])
68+
end
69+
end
70+
71+
function ChainRulesCore.rrule(
72+
::typeof(MTK.remake_buffer), indp, oldbuf::MTK.MTKParameters, idxs, vals)
73+
if idxs isa AbstractSet
74+
idxs = collect(idxs)
75+
end
76+
idxs = map(idxs) do i
77+
i isa MTK.ParameterIndex ? i : MTK.parameter_index(indp, i)
78+
end
79+
newbuf = MTK.remake_buffer(indp, oldbuf, idxs, vals)
80+
tunable_idxs = reduce(
81+
vcat, (idx.idx for idx in idxs if idx.portion isa MTK.SciMLStructures.Tunable))
82+
disc_idxs = subset_idxs(idxs, MTK.SciMLStructures.Discrete(), oldbuf.discrete)
83+
const_idxs = subset_idxs(idxs, MTK.SciMLStructures.Constants(), oldbuf.constant)
84+
nn_idxs = subset_idxs(idxs, MTK.NONNUMERIC_PORTION, oldbuf.nonnumeric)
85+
86+
function remake_buffer_pullback(buf′)
87+
buf′ = unthunk(buf′)
88+
f′ = NoTangent()
89+
indp′ = NoTangent()
90+
91+
tunable = selected_tangents(buf′.tunable, tunable_idxs)
92+
discrete = selected_tangents(buf′.discrete, disc_idxs)
93+
constant = selected_tangents(buf′.constant, const_idxs)
94+
nonnumeric = selected_tangents(buf′.nonnumeric, nn_idxs)
95+
oldbuf′ = Tangent{typeof(oldbuf)}(; tunable, discrete, constant, nonnumeric)
96+
idxs′ = NoTangent()
97+
vals′ = map(i -> MTK._ducktyped_parameter_values(buf′, i), idxs)
98+
return f′, indp′, oldbuf′, idxs′, vals′
99+
end
100+
newbuf, remake_buffer_pullback
101+
end
102+
15103
end

src/systems/parameter_buffer.jl

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -436,8 +436,12 @@ function validate_parameter_type(ic::IndexCache, p, idx::ParameterIndex, val)
436436
end
437437

438438
function validate_parameter_type(ic::IndexCache, idx::ParameterIndex, val)
439+
stype = get_buffer_template(ic, idx).type
440+
if idx.portion == SciMLStructures.Tunable() && !(idx.idx isa Int)
441+
stype = AbstractArray{<:stype}
442+
end
439443
validate_parameter_type(
440-
ic, get_buffer_template(ic, idx).type, Symbolics.Unknown(), nothing, idx, val)
444+
ic, stype, Symbolics.Unknown(), nothing, idx, val)
441445
end
442446

443447
function validate_parameter_type(ic::IndexCache, stype, sz, sym, index, val)
@@ -489,6 +493,9 @@ function indp_to_system(indp)
489493
end
490494

491495
function SymbolicIndexingInterface.remake_buffer(indp, oldbuf::MTKParameters, idxs, vals)
496+
_remake_buffer(indp, oldbuf, idxs, vals)
497+
end
498+
function _remake_buffer(indp, oldbuf::MTKParameters, idxs, vals; validate = true)
492499
newbuf = @set oldbuf.tunable = similar(oldbuf.tunable, Any)
493500
@set! newbuf.discrete = Tuple(similar(buf, Any) for buf in newbuf.discrete)
494501
@set! newbuf.constant = Tuple(similar(buf, Any) for buf in newbuf.constant)

test/extensions/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
[deps]
22
BifurcationKit = "0f109fa4-8a5d-4b75-95aa-f515264e7665"
3+
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
4+
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
35
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
46
LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800"
57
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"

test/extensions/ad.jl

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@ using SciMLStructures
66
using OrdinaryDiffEq
77
using SciMLSensitivity
88
using ForwardDiff
9+
using ChainRulesCore
10+
using ChainRulesCore: NoTangent
11+
using ChainRulesTestUtils: test_rrule
912

1013
@variables x(t)[1:3] y(t)
1114
@parameters p[1:3, 1:3] q
@@ -51,3 +54,46 @@ end
5154

5255
@test ForwardDiff.gradient(x_at_0, [0.3, 0.7]) == zeros(2)
5356
end
57+
58+
@parameters a b[1:3] c(t) d::Integer e[1:3] f[1:3, 1:3]::Int g::Vector{AbstractFloat} h::String
59+
@named sys = ODESystem(
60+
Equation[], t, [], [a, b, c, d, e, f, g, h],
61+
continuous_events = [[a ~ 0] => [c ~ 0]])
62+
sys = complete(sys)
63+
64+
ivs = Dict(c => 3a, b => ones(3), a => 1.0, d => 4, e => [5.0, 6.0, 7.0],
65+
f => ones(Int, 3, 3), g => [0.1, 0.2, 0.3], h => "foo")
66+
67+
ps = MTKParameters(sys, ivs)
68+
69+
varmap = Dict(a => 1.0f0, b => 3ones(Float32, 3), c => 2.0,
70+
e => Float32[0.4, 0.5, 0.6], g => ones(Float32, 4))
71+
get_values = getp(sys, [a, b..., c, e...])
72+
get_g = getp(sys, g)
73+
for (_idxs, vals) in [
74+
# all portions
75+
(collect(keys(varmap)), collect(values(varmap))),
76+
# non-arrays
77+
(keys(varmap), values(varmap)),
78+
# tunable only
79+
([a], [varmap[a]]),
80+
([a, b], (varmap[a], varmap[b])),
81+
([a, b[2]], (varmap[a], varmap[b][2]))
82+
]
83+
for idxs in [_idxs, map(i -> parameter_index(sys, i), collect(_idxs))]
84+
loss = function (p)
85+
newps = remake_buffer(sys, ps, idxs, p)
86+
return sum(get_values(newps)) + sum(get_g(newps))
87+
end
88+
89+
grad = Zygote.gradient(loss, vals)[1]
90+
for (val, g) in zip(vals, grad)
91+
@test eltype(val) == eltype(g)
92+
if val isa Number
93+
@test isone(g)
94+
else
95+
@test all(isone, g)
96+
end
97+
end
98+
end
99+
end

0 commit comments

Comments
 (0)