Skip to content

Commit 39b6221

Browse files
Merge pull request #3042 from AayushSabharwal/as/remake-buffer-adjoint
feat: add adjoint for `remake_buffer`
2 parents 9611e18 + 55e3452 commit 39b6221

File tree

5 files changed

+185
-31
lines changed

5 files changed

+185
-31
lines changed

ext/MTKChainRulesCoreExt.jl

Lines changed: 86 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,99 @@ module MTKChainRulesCoreExt
22

33
import ModelingToolkit as MTK
44
import ChainRulesCore
5-
import ChainRulesCore: NoTangent
5+
import ChainRulesCore: Tangent, ZeroTangent, NoTangent, zero_tangent, unthunk
66

77
function ChainRulesCore.rrule(::Type{MTK.MTKParameters}, tunables, args...)
88
function mtp_pullback(dt)
9+
dt = unthunk(dt)
910
(NoTangent(), dt.tunable[1:length(tunables)],
1011
ntuple(_ -> NoTangent(), length(args))...)
1112
end
1213
MTK.MTKParameters(tunables, args...), mtp_pullback
1314
end
1415

16+
function subset_idxs(idxs, portion, template)
17+
ntuple(Val(length(template))) do subi
18+
[Base.tail(idx.idx) for idx in idxs if idx.portion == portion && idx.idx[1] == subi]
19+
end
20+
end
21+
22+
selected_tangents(::NoTangent, _) = ()
23+
selected_tangents(::ZeroTangent, _) = ZeroTangent()
24+
function selected_tangents(
25+
tangents::AbstractArray{T}, idxs::Vector{Tuple{Int}}) where {T <: Number}
26+
selected_tangents(tangents, map(only, idxs))
27+
end
28+
function selected_tangents(tangents::AbstractArray{T}, idxs...) where {T <: Number}
29+
newtangents = copy(tangents)
30+
view(newtangents, idxs...) .= zero(T)
31+
newtangents
32+
end
33+
function selected_tangents(
34+
tangents::AbstractVector{T}, idxs) where {S <: Number, T <: AbstractArray{S}}
35+
newtangents = copy(tangents)
36+
for i in idxs
37+
j, k... = i
38+
if k == ()
39+
newtangents[j] = zero(newtangents[j])
40+
else
41+
newtangents[j] = selected_tangents(newtangents[j], k...)
42+
end
43+
end
44+
newtangents
45+
end
46+
function selected_tangents(tangents::AbstractVector{T}, idxs) where {T <: AbstractArray}
47+
newtangents = similar(tangents, Union{T, NoTangent})
48+
copyto!(newtangents, tangents)
49+
for i in idxs
50+
j, k... = i
51+
if k == ()
52+
newtangents[j] = NoTangent()
53+
else
54+
newtangents[j] = selected_tangents(newtangents[j], k...)
55+
end
56+
end
57+
newtangents
58+
end
59+
function selected_tangents(
60+
tangents::Union{Tangent{<:Tuple}, Tangent{T, <:Tuple}}, idxs) where {T}
61+
ntuple(Val(length(tangents))) do i
62+
selected_tangents(tangents[i], idxs[i])
63+
end
64+
end
65+
66+
function ChainRulesCore.rrule(
67+
::typeof(MTK.remake_buffer), indp, oldbuf::MTK.MTKParameters, idxs, vals)
68+
if idxs isa AbstractSet
69+
idxs = collect(idxs)
70+
end
71+
idxs = map(idxs) do i
72+
i isa MTK.ParameterIndex ? i : MTK.parameter_index(indp, i)
73+
end
74+
newbuf = MTK.remake_buffer(indp, oldbuf, idxs, vals)
75+
tunable_idxs = reduce(
76+
vcat, (idx.idx for idx in idxs if idx.portion isa MTK.SciMLStructures.Tunable))
77+
disc_idxs = subset_idxs(idxs, MTK.SciMLStructures.Discrete(), oldbuf.discrete)
78+
const_idxs = subset_idxs(idxs, MTK.SciMLStructures.Constants(), oldbuf.constant)
79+
nn_idxs = subset_idxs(idxs, MTK.NONNUMERIC_PORTION, oldbuf.nonnumeric)
80+
81+
pullback = let idxs = idxs
82+
function remake_buffer_pullback(buf′)
83+
buf′ = unthunk(buf′)
84+
f′ = NoTangent()
85+
indp′ = NoTangent()
86+
87+
tunable = selected_tangents(buf′.tunable, tunable_idxs)
88+
discrete = selected_tangents(buf′.discrete, disc_idxs)
89+
constant = selected_tangents(buf′.constant, const_idxs)
90+
nonnumeric = selected_tangents(buf′.nonnumeric, nn_idxs)
91+
oldbuf′ = Tangent{typeof(oldbuf)}(; tunable, discrete, constant, nonnumeric)
92+
idxs′ = NoTangent()
93+
vals′ = map(i -> MTK._ducktyped_parameter_values(buf′, i), idxs)
94+
return f′, indp′, oldbuf′, idxs′, vals′
95+
end
96+
end
97+
newbuf, pullback
98+
end
99+
15100
end

src/systems/index_cache.jl

Lines changed: 26 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -288,30 +288,25 @@ function IndexCache(sys::AbstractSystem)
288288
end
289289

290290
function SymbolicIndexingInterface.is_variable(ic::IndexCache, sym)
291-
if sym isa Symbol
292-
sym = get(ic.symbol_to_variable, sym, nothing)
293-
sym === nothing && return false
294-
end
295-
return check_index_map(ic.unknown_idx, sym) !== nothing
291+
variable_index(ic, sym) !== nothing
296292
end
297293

298294
function SymbolicIndexingInterface.variable_index(ic::IndexCache, sym)
299295
if sym isa Symbol
300296
sym = get(ic.symbol_to_variable, sym, nothing)
301297
sym === nothing && return nothing
302298
end
303-
return check_index_map(ic.unknown_idx, sym)
299+
idx = check_index_map(ic.unknown_idx, sym)
300+
idx === nothing || return idx
301+
iscall(sym) && operation(sym) == getindex || return nothing
302+
args = arguments(sym)
303+
idx = variable_index(ic, args[1])
304+
idx === nothing && return nothing
305+
return idx[args[2:end]...]
304306
end
305307

306308
function SymbolicIndexingInterface.is_parameter(ic::IndexCache, sym)
307-
if sym isa Symbol
308-
sym = get(ic.symbol_to_variable, sym, nothing)
309-
sym === nothing && return false
310-
end
311-
return check_index_map(ic.tunable_idx, sym) !== nothing ||
312-
check_index_map(ic.discrete_idx, sym) !== nothing ||
313-
check_index_map(ic.constant_idx, sym) !== nothing ||
314-
check_index_map(ic.nonnumeric_idx, sym) !== nothing
309+
parameter_index(ic, sym) !== nothing
315310
end
316311

317312
function SymbolicIndexingInterface.parameter_index(ic::IndexCache, sym)
@@ -331,17 +326,21 @@ function SymbolicIndexingInterface.parameter_index(ic::IndexCache, sym)
331326
ParameterIndex(SciMLStructures.Constants(), idx, validate_size)
332327
elseif (idx = check_index_map(ic.nonnumeric_idx, sym)) !== nothing
333328
ParameterIndex(NONNUMERIC_PORTION, idx, validate_size)
334-
else
335-
nothing
329+
elseif iscall(sym) && operation(sym) == getindex
330+
args = arguments(sym)
331+
pidx = parameter_index(ic, args[1])
332+
pidx === nothing && return nothing
333+
if pidx.portion == SciMLStructures.Tunable()
334+
ParameterIndex(pidx.portion, reshape(pidx.idx, size(args[1]))[args[2:end]...],
335+
pidx.validate_size)
336+
else
337+
ParameterIndex(pidx.portion, (pidx.idx..., args[2:end]...), pidx.validate_size)
338+
end
336339
end
337340
end
338341

339342
function SymbolicIndexingInterface.is_timeseries_parameter(ic::IndexCache, sym)
340-
if sym isa Symbol
341-
sym = get(ic.symbol_to_variable, sym, nothing)
342-
sym === nothing && return false
343-
end
344-
return check_index_map(ic.discrete_idx, sym) !== nothing
343+
timeseries_parameter_index(ic, sym) !== nothing
345344
end
346345

347346
function SymbolicIndexingInterface.timeseries_parameter_index(ic::IndexCache, sym)
@@ -350,8 +349,13 @@ function SymbolicIndexingInterface.timeseries_parameter_index(ic::IndexCache, sy
350349
sym === nothing && return nothing
351350
end
352351
idx = check_index_map(ic.discrete_idx, sym)
352+
idx === nothing ||
353+
return ParameterTimeseriesIndex(idx.clock_idx, (idx.buffer_idx, idx.idx_in_clock))
354+
iscall(sym) && operation(sym) == getindex || return nothing
355+
args = arguments(sym)
356+
idx = timeseries_parameter_index(ic, args[1])
353357
idx === nothing && return nothing
354-
return ParameterTimeseriesIndex(idx.clock_idx, (idx.buffer_idx, idx.idx_in_clock))
358+
ParameterIndex(idx.portion, (idx.idx..., args[2:end]...), idx.validate_size)
355359
end
356360

357361
function check_index_map(idxmap, sym)

src/systems/parameter_buffer.jl

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -336,14 +336,15 @@ function Base.copy(p::MTKParameters)
336336
end
337337

338338
function SymbolicIndexingInterface.parameter_values(p::MTKParameters, pind::ParameterIndex)
339+
_ducktyped_parameter_values(p, pind)
340+
end
341+
function _ducktyped_parameter_values(p, pind::ParameterIndex)
339342
@unpack portion, idx = pind
340343
if portion isa SciMLStructures.Tunable
341344
return idx isa Int ? p.tunable[idx] : view(p.tunable, idx)
342345
end
343346
i, j, k... = idx
344-
if portion isa SciMLStructures.Tunable
345-
return isempty(k) ? p.tunable[i][j] : p.tunable[i][j][k...]
346-
elseif portion isa SciMLStructures.Discrete
347+
if portion isa SciMLStructures.Discrete
347348
return isempty(k) ? p.discrete[i][j] : p.discrete[i][j][k...]
348349
elseif portion isa SciMLStructures.Constants
349350
return isempty(k) ? p.constant[i][j] : p.constant[i][j][k...]
@@ -435,20 +436,26 @@ function validate_parameter_type(ic::IndexCache, p, idx::ParameterIndex, val)
435436
end
436437

437438
function validate_parameter_type(ic::IndexCache, idx::ParameterIndex, val)
439+
stype = get_buffer_template(ic, idx).type
440+
if idx.portion == SciMLStructures.Tunable() && !(idx.idx isa Int)
441+
stype = AbstractArray{<:stype}
442+
end
438443
validate_parameter_type(
439-
ic, get_buffer_template(ic, idx).type, Symbolics.Unknown(), nothing, idx, val)
444+
ic, stype, Symbolics.Unknown(), nothing, idx, val)
440445
end
441446

442447
function validate_parameter_type(ic::IndexCache, stype, sz, sym, index, val)
443448
(; portion) = index
444449
# Nonnumeric parameters have to match the type
445450
if portion === NONNUMERIC_PORTION
446451
val isa stype && return nothing
447-
throw(ParameterTypeException(:validate_parameter_type, sym, stype, val))
452+
throw(ParameterTypeException(
453+
:validate_parameter_type, sym === nothing ? index : sym, stype, val))
448454
end
449455
# Array parameters need array values...
450456
if stype <: AbstractArray && !isa(val, AbstractArray)
451-
throw(ParameterTypeException(:validate_parameter_type, sym, stype, val))
457+
throw(ParameterTypeException(
458+
:validate_parameter_type, sym === nothing ? index : sym, stype, val))
452459
end
453460
# ... and must match sizes
454461
if stype <: AbstractArray && sz != Symbolics.Unknown() && size(val) != sz
@@ -465,15 +472,16 @@ function validate_parameter_type(ic::IndexCache, stype, sz, sym, index, val)
465472
# This is for duals and other complicated number types
466473
etype = SciMLBase.parameterless_type(etype)
467474
eltype(val) <: etype || throw(ParameterTypeException(
468-
:validate_parameter_type, sym, AbstractArray{etype}, val))
475+
:validate_parameter_type, sym === nothing ? index : sym, AbstractArray{etype}, val))
469476
else
470477
# Real check
471478
if stype <: Real
472479
stype = Real
473480
end
474481
stype = SciMLBase.parameterless_type(stype)
475482
val isa stype ||
476-
throw(ParameterTypeException(:validate_parameter_type, sym, stype, val))
483+
throw(ParameterTypeException(
484+
:validate_parameter_type, sym === nothing ? index : sym, stype, val))
477485
end
478486
end
479487

@@ -485,6 +493,9 @@ function indp_to_system(indp)
485493
end
486494

487495
function SymbolicIndexingInterface.remake_buffer(indp, oldbuf::MTKParameters, idxs, vals)
496+
_remake_buffer(indp, oldbuf, idxs, vals)
497+
end
498+
function _remake_buffer(indp, oldbuf::MTKParameters, idxs, vals; validate = true)
488499
newbuf = @set oldbuf.tunable = similar(oldbuf.tunable, Any)
489500
@set! newbuf.discrete = Tuple(similar(buf, Any) for buf in newbuf.discrete)
490501
@set! newbuf.constant = Tuple(similar(buf, Any) for buf in newbuf.constant)

test/extensions/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
[deps]
22
BifurcationKit = "0f109fa4-8a5d-4b75-95aa-f515264e7665"
3+
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
4+
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
35
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
46
LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800"
57
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"

test/extensions/ad.jl

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@ using SciMLStructures
66
using OrdinaryDiffEq
77
using SciMLSensitivity
88
using ForwardDiff
9+
using ChainRulesCore
10+
using ChainRulesCore: NoTangent
11+
using ChainRulesTestUtils: test_rrule, rand_tangent
912

1013
@variables x(t)[1:3] y(t)
1114
@parameters p[1:3, 1:3] q
@@ -51,3 +54,52 @@ end
5154

5255
@test ForwardDiff.gradient(x_at_0, [0.3, 0.7]) == zeros(2)
5356
end
57+
58+
@parameters a b[1:3] c(t) d::Integer e[1:3] f[1:3, 1:3]::Int g::Vector{AbstractFloat} h::String
59+
@named sys = ODESystem(
60+
Equation[], t, [], [a, b, c, d, e, f, g, h],
61+
continuous_events = [[a ~ 0] => [c ~ 0]])
62+
sys = complete(sys)
63+
64+
ivs = Dict(c => 3a, b => ones(3), a => 1.0, d => 4, e => [5.0, 6.0, 7.0],
65+
f => ones(Int, 3, 3), g => [0.1, 0.2, 0.3], h => "foo")
66+
67+
ps = MTKParameters(sys, ivs)
68+
69+
varmap = Dict(a => 1.0f0, b => 3ones(Float32, 3), c => 2.0,
70+
e => Float32[0.4, 0.5, 0.6], g => ones(Float32, 4))
71+
get_values = getp(sys, [a, b..., c, e...])
72+
get_g = getp(sys, g)
73+
for (_idxs, vals) in [
74+
# all portions
75+
(collect(keys(varmap)), collect(values(varmap))),
76+
# non-arrays
77+
(keys(varmap), values(varmap)),
78+
# tunable only
79+
([a], [varmap[a]]),
80+
([a, b], (varmap[a], varmap[b])),
81+
([a, b[2]], (varmap[a], varmap[b][2]))
82+
]
83+
for idxs in [_idxs, map(i -> parameter_index(sys, i), collect(_idxs))]
84+
loss = function (p)
85+
newps = remake_buffer(sys, ps, idxs, p)
86+
return sum(get_values(newps)) + sum(get_g(newps))
87+
end
88+
89+
grad = Zygote.gradient(loss, vals)[1]
90+
for (val, g) in zip(vals, grad)
91+
@test eltype(val) == eltype(g)
92+
if val isa Number
93+
@test isone(g)
94+
else
95+
@test all(isone, g)
96+
end
97+
end
98+
end
99+
end
100+
101+
idxs = (parameter_index(sys, a), parameter_index(sys, b))
102+
vals = (1.0f0, 3ones(Float32, 3))
103+
tangent = rand_tangent(ps)
104+
fwd, back = ChainRulesCore.rrule(remake_buffer, sys, ps, idxs, vals)
105+
@inferred back(tangent)

0 commit comments

Comments
 (0)