Skip to content

feat: add adjoint for remake_buffer #3042

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 86 additions & 1 deletion ext/MTKChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,99 @@ module MTKChainRulesCoreExt

import ModelingToolkit as MTK
import ChainRulesCore
import ChainRulesCore: NoTangent
import ChainRulesCore: Tangent, ZeroTangent, NoTangent, zero_tangent, unthunk

# Reverse rule for the `MTKParameters(tunables, args...)` constructor.
#
# The pullback returns one cotangent per primal argument: none for the type
# itself, the leading `length(tunables)` entries of the incoming tunable tangent
# for `tunables`, and `NoTangent()` for every remaining positional argument.
function ChainRulesCore.rrule(::Type{MTK.MTKParameters}, tunables, args...)
    primal = MTK.MTKParameters(tunables, args...)
    function mtp_pullback(dt)
        # Materialize a possibly lazy cotangent before reading its fields.
        dt = unthunk(dt)
        tunables′ = dt.tunable[1:length(tunables)]
        # `args` is a tuple, so mapping over it yields a tuple of equal length.
        args′ = map(_ -> NoTangent(), args)
        return (NoTangent(), tunables′, args′...)
    end
    return primal, mtp_pullback
end

# Group the entries of `idxs` that belong to `portion` by the sub-buffer they
# address.
#
# Returns a tuple with one vector per element of `template`. Vector number `n`
# holds, for each matching entry whose leading index component equals `n`, the
# remaining index components (`Base.tail` of the entry's `idx`).
function subset_idxs(idxs, portion, template)
    nbuffers = Val(length(template))
    return ntuple(nbuffers) do bufi
        [Base.tail(pidx.idx)
         for pidx in idxs if pidx.portion == portion && pidx.idx[1] == bufi]
    end
end

# A `NoTangent` input has nothing to select from: the result is an empty tuple,
# whatever indices are passed.
function selected_tangents(::NoTangent, _)
    return ()
end
# A `ZeroTangent` input stays a `ZeroTangent`, whatever indices are passed.
function selected_tangents(::ZeroTangent, _)
    return ZeroTangent()
end
# Numeric tangent array addressed by a vector of 1-tuples: unwrap every tuple to
# its single component and forward to the method taking plain indices.
function selected_tangents(
        tangents::AbstractArray{T}, idxs::Vector{Tuple{Int}}) where {T <: Number}
    unwrapped = map(only, idxs)
    return selected_tangents(tangents, unwrapped)
end
# Return a copy of the numeric array `tangents` in which the entries addressed by
# `idxs...` are set to `zero(T)`. The input array itself is left untouched.
function selected_tangents(tangents::AbstractArray{T}, idxs...) where {T <: Number}
    masked = copy(tangents)
    target = view(masked, idxs...)
    target .= zero(T)
    return masked
end
# Vector of numeric arrays. Every entry of `idxs` is destructured into a buffer
# position followed by optional inner indices:
#   * with inner indices, only those entries of that buffer are zeroed (through
#     the method for a single numeric array);
#   * without, the whole buffer is replaced by `zero` of itself.
# Works on a shallow copy, so `tangents` itself is never reassigned.
function selected_tangents(
        tangents::AbstractVector{T}, idxs) where {S <: Number, T <: AbstractArray{S}}
    masked = copy(tangents)
    for entry in idxs
        bufi, inner... = entry
        masked[bufi] = if inner != ()
            selected_tangents(masked[bufi], inner...)
        else
            zero(masked[bufi])
        end
    end
    return masked
end
# Vector of arrays with arbitrary element type. The result vector has its element
# type widened to `Union{T, NoTangent}` so that a fully selected buffer can be
# replaced by `NoTangent()`; a partially selected buffer is delegated to
# `selected_tangents` with the inner indices.
function selected_tangents(tangents::AbstractVector{T}, idxs) where {T <: AbstractArray}
    masked = similar(tangents, Union{T, NoTangent})
    copyto!(masked, tangents)
    for entry in idxs
        bufi, inner... = entry
        masked[bufi] = if inner != ()
            selected_tangents(masked[bufi], inner...)
        else
            NoTangent()
        end
    end
    return masked
end
# Structural tangent backed by a tuple: apply `selected_tangents` to each element
# together with the index collection at the same position in `idxs`, and return
# the results as a plain tuple.
function selected_tangents(
        tangents::Union{Tangent{<:Tuple}, Tangent{T, <:Tuple}}, idxs) where {T}
    select_at(pos) = selected_tangents(tangents[pos], idxs[pos])
    return ntuple(select_at, Val(length(tangents)))
end

# Reverse rule for `remake_buffer(indp, oldbuf::MTKParameters, idxs, vals)`.
#
# Arguments
#   indp   - index provider used to resolve symbolic entries of `idxs`
#   oldbuf - parameter object whose values are being replaced
#   idxs   - indices to overwrite; entries that are not already
#            `MTK.ParameterIndex` are resolved with `MTK.parameter_index`
#   vals   - replacement values, one per index
#
# Returns the new buffer and a pullback. The pullback yields cotangents for
# `(remake_buffer, indp, oldbuf, idxs, vals)`:
#   * `oldbuf` receives the incoming tangent with every overwritten position
#     masked out by `selected_tangents`;
#   * `vals` receives, per index, the incoming tangent read at that index;
#   * everything else receives `NoTangent()`.
function ChainRulesCore.rrule(
        ::typeof(MTK.remake_buffer), indp, oldbuf::MTK.MTKParameters, idxs, vals)
    # Sets have no stable positional order to `map` over; materialize them first.
    if idxs isa AbstractSet
        idxs = collect(idxs)
    end
    idxs = map(idxs) do i
        i isa MTK.ParameterIndex ? i : MTK.parameter_index(indp, i)
    end
    newbuf = MTK.remake_buffer(indp, oldbuf, idxs, vals)

    # `reduce(vcat, ...)` without `init` throws on an empty iterator, which happens
    # whenever none of the requested indices live in the tunable portion. Guard
    # that case explicitly and leave the non-empty path exactly as it was.
    tunable_parts = (
        idx.idx for idx in idxs if idx.portion isa MTK.SciMLStructures.Tunable)
    tunable_idxs = isempty(tunable_parts) ? Int[] : reduce(vcat, tunable_parts)
    disc_idxs = subset_idxs(idxs, MTK.SciMLStructures.Discrete(), oldbuf.discrete)
    const_idxs = subset_idxs(idxs, MTK.SciMLStructures.Constants(), oldbuf.constant)
    nn_idxs = subset_idxs(idxs, MTK.NONNUMERIC_PORTION, oldbuf.nonnumeric)

    # `idxs` is reassigned above; rebinding it in a `let` gives the closure a
    # fresh, never-reassigned binding to capture.
    pullback = let idxs = idxs
        function remake_buffer_pullback(buf′)
            buf′ = unthunk(buf′)
            f′ = NoTangent()
            indp′ = NoTangent()

            # Mask out the positions that were overwritten in each portion.
            tunable = selected_tangents(buf′.tunable, tunable_idxs)
            discrete = selected_tangents(buf′.discrete, disc_idxs)
            constant = selected_tangents(buf′.constant, const_idxs)
            nonnumeric = selected_tangents(buf′.nonnumeric, nn_idxs)
            oldbuf′ = Tangent{typeof(oldbuf)}(; tunable, discrete, constant, nonnumeric)
            idxs′ = NoTangent()
            # Cotangent of each replacement value: the tangent stored at its index.
            vals′ = map(i -> MTK._ducktyped_parameter_values(buf′, i), idxs)
            return f′, indp′, oldbuf′, idxs′, vals′
        end
    end
    return newbuf, pullback
end

end
48 changes: 26 additions & 22 deletions src/systems/index_cache.jl
Original file line number Diff line number Diff line change
Expand Up @@ -288,30 +288,25 @@ function IndexCache(sys::AbstractSystem)
end

# Report whether `sym` resolves to an unknown tracked by `ic`.
#
# The lookup (including translation of a `Symbol` through
# `ic.symbol_to_variable`) lives in `variable_index`; this predicate only checks
# whether that lookup found something. The previous body repeated the lookup
# inline and carried an unreachable trailing expression after its `return`.
function SymbolicIndexingInterface.is_variable(ic::IndexCache, sym)
    return variable_index(ic, sym) !== nothing
end

# Return the index of `sym` among the unknowns tracked by `ic`, or `nothing` when
# it is not an unknown.
#
# Resolution order:
#   1. a `Symbol` is translated through `ic.symbol_to_variable` (unknown names
#      yield `nothing`);
#   2. a direct hit in `ic.unknown_idx` is returned as is;
#   3. otherwise, if `sym` is a `getindex` call, the indexed object is looked up
#      recursively and its index is indexed with the remaining call arguments.
#
# An unconditional early `return check_index_map(...)` previously made steps 2-3
# unreachable; it is removed so the fallback for indexed expressions runs.
function SymbolicIndexingInterface.variable_index(ic::IndexCache, sym)
    if sym isa Symbol
        sym = get(ic.symbol_to_variable, sym, nothing)
        sym === nothing && return nothing
    end
    idx = check_index_map(ic.unknown_idx, sym)
    idx === nothing || return idx
    # Only `getindex` call expressions get the element-wise fallback.
    iscall(sym) && operation(sym) == getindex || return nothing
    args = arguments(sym)
    idx = variable_index(ic, args[1])
    idx === nothing && return nothing
    return idx[args[2:end]...]
end

# Report whether `sym` is a parameter known to `ic`.
# NOTE(review): this body looks like both sides of a diff merged together. As
# written, the `return` over the four `check_index_map` lookups ends the function,
# so the trailing `parameter_index(ic, sym) !== nothing` line is never evaluated —
# confirm which form is intended.
function SymbolicIndexingInterface.is_parameter(ic::IndexCache, sym)
# A `Symbol` is first translated through `ic.symbol_to_variable`; an unknown name
# is reported as "not a parameter".
if sym isa Symbol
sym = get(ic.symbol_to_variable, sym, nothing)
sym === nothing && return false
end
# True when any of the tunable, discrete, constant or nonnumeric maps has an entry.
return check_index_map(ic.tunable_idx, sym) !== nothing ||
check_index_map(ic.discrete_idx, sym) !== nothing ||
check_index_map(ic.constant_idx, sym) !== nothing ||
check_index_map(ic.nonnumeric_idx, sym) !== nothing
parameter_index(ic, sym) !== nothing
end

function SymbolicIndexingInterface.parameter_index(ic::IndexCache, sym)
Expand All @@ -331,17 +326,21 @@ function SymbolicIndexingInterface.parameter_index(ic::IndexCache, sym)
ParameterIndex(SciMLStructures.Constants(), idx, validate_size)
elseif (idx = check_index_map(ic.nonnumeric_idx, sym)) !== nothing
ParameterIndex(NONNUMERIC_PORTION, idx, validate_size)
else
nothing
elseif iscall(sym) && operation(sym) == getindex
args = arguments(sym)
pidx = parameter_index(ic, args[1])
pidx === nothing && return nothing
if pidx.portion == SciMLStructures.Tunable()
ParameterIndex(pidx.portion, reshape(pidx.idx, size(args[1]))[args[2:end]...],
pidx.validate_size)
else
ParameterIndex(pidx.portion, (pidx.idx..., args[2:end]...), pidx.validate_size)
end
end
end

# Report whether `sym` has an entry in the discrete index map of `ic`.
# NOTE(review): this body looks like both sides of a diff merged together. As
# written, the `return check_index_map(ic.discrete_idx, sym) !== nothing` line ends
# the function, so the trailing `timeseries_parameter_index(ic, sym) !== nothing`
# line is never evaluated — confirm which form is intended.
function SymbolicIndexingInterface.is_timeseries_parameter(ic::IndexCache, sym)
# A `Symbol` is first translated through `ic.symbol_to_variable`; an unknown name
# yields `false`.
if sym isa Symbol
sym = get(ic.symbol_to_variable, sym, nothing)
sym === nothing && return false
end
return check_index_map(ic.discrete_idx, sym) !== nothing
timeseries_parameter_index(ic, sym) !== nothing
end

function SymbolicIndexingInterface.timeseries_parameter_index(ic::IndexCache, sym)
Expand All @@ -350,8 +349,13 @@ function SymbolicIndexingInterface.timeseries_parameter_index(ic::IndexCache, sy
sym === nothing && return nothing
end
idx = check_index_map(ic.discrete_idx, sym)
idx === nothing ||
return ParameterTimeseriesIndex(idx.clock_idx, (idx.buffer_idx, idx.idx_in_clock))
iscall(sym) && operation(sym) == getindex || return nothing
args = arguments(sym)
idx = timeseries_parameter_index(ic, args[1])
idx === nothing && return nothing
return ParameterTimeseriesIndex(idx.clock_idx, (idx.buffer_idx, idx.idx_in_clock))
ParameterIndex(idx.portion, (idx.idx..., args[2:end]...), idx.validate_size)
end

function check_index_map(idxmap, sym)
Expand Down
27 changes: 19 additions & 8 deletions src/systems/parameter_buffer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -336,14 +336,15 @@ function Base.copy(p::MTKParameters)
end

# Read the value addressed by `pind` from `p`. This typed entry point only
# forwards to `_ducktyped_parameter_values`, which does the actual lookup.
function SymbolicIndexingInterface.parameter_values(
        p::MTKParameters, pind::ParameterIndex)
    return _ducktyped_parameter_values(p, pind)
end
function _ducktyped_parameter_values(p, pind::ParameterIndex)
@unpack portion, idx = pind
if portion isa SciMLStructures.Tunable
return idx isa Int ? p.tunable[idx] : view(p.tunable, idx)
end
i, j, k... = idx
if portion isa SciMLStructures.Tunable
return isempty(k) ? p.tunable[i][j] : p.tunable[i][j][k...]
elseif portion isa SciMLStructures.Discrete
if portion isa SciMLStructures.Discrete
return isempty(k) ? p.discrete[i][j] : p.discrete[i][j][k...]
elseif portion isa SciMLStructures.Constants
return isempty(k) ? p.constant[i][j] : p.constant[i][j][k...]
Expand Down Expand Up @@ -435,20 +436,26 @@ function validate_parameter_type(ic::IndexCache, p, idx::ParameterIndex, val)
end

# Validate `val` against the buffer template type recorded for `idx`.
#
# The expected type starts as the template's `type`. For a tunable index that is
# not a single `Int` (i.e. it addresses several entries), the expected type is
# widened to an array of that type. Size and symbol are unknown at this level,
# so `Symbolics.Unknown()` and `nothing` are forwarded for them.
#
# The call previously carried two interleaved argument lists, which left the
# expression unparseable and the computed `stype` unused; only the list that uses
# `stype` is kept.
function validate_parameter_type(ic::IndexCache, idx::ParameterIndex, val)
    stype = get_buffer_template(ic, idx).type
    if idx.portion == SciMLStructures.Tunable() && !(idx.idx isa Int)
        stype = AbstractArray{<:stype}
    end
    return validate_parameter_type(
        ic, stype, Symbolics.Unknown(), nothing, idx, val)
end

function validate_parameter_type(ic::IndexCache, stype, sz, sym, index, val)
(; portion) = index
# Nonnumeric parameters have to match the type
if portion === NONNUMERIC_PORTION
val isa stype && return nothing
throw(ParameterTypeException(:validate_parameter_type, sym, stype, val))
throw(ParameterTypeException(
:validate_parameter_type, sym === nothing ? index : sym, stype, val))
end
# Array parameters need array values...
if stype <: AbstractArray && !isa(val, AbstractArray)
throw(ParameterTypeException(:validate_parameter_type, sym, stype, val))
throw(ParameterTypeException(
:validate_parameter_type, sym === nothing ? index : sym, stype, val))
end
# ... and must match sizes
if stype <: AbstractArray && sz != Symbolics.Unknown() && size(val) != sz
Expand All @@ -465,15 +472,16 @@ function validate_parameter_type(ic::IndexCache, stype, sz, sym, index, val)
# This is for duals and other complicated number types
etype = SciMLBase.parameterless_type(etype)
eltype(val) <: etype || throw(ParameterTypeException(
:validate_parameter_type, sym, AbstractArray{etype}, val))
:validate_parameter_type, sym === nothing ? index : sym, AbstractArray{etype}, val))
else
# Real check
if stype <: Real
stype = Real
end
stype = SciMLBase.parameterless_type(stype)
val isa stype ||
throw(ParameterTypeException(:validate_parameter_type, sym, stype, val))
throw(ParameterTypeException(
:validate_parameter_type, sym === nothing ? index : sym, stype, val))
end
end

Expand All @@ -485,6 +493,9 @@ function indp_to_system(indp)
end

function SymbolicIndexingInterface.remake_buffer(indp, oldbuf::MTKParameters, idxs, vals)
_remake_buffer(indp, oldbuf, idxs, vals)
end
function _remake_buffer(indp, oldbuf::MTKParameters, idxs, vals; validate = true)
newbuf = @set oldbuf.tunable = similar(oldbuf.tunable, Any)
@set! newbuf.discrete = Tuple(similar(buf, Any) for buf in newbuf.discrete)
@set! newbuf.constant = Tuple(similar(buf, Any) for buf in newbuf.constant)
Expand Down
2 changes: 2 additions & 0 deletions test/extensions/Project.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
[deps]
BifurcationKit = "0f109fa4-8a5d-4b75-95aa-f515264e7665"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800"
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
Expand Down
52 changes: 52 additions & 0 deletions test/extensions/ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ using SciMLStructures
using OrdinaryDiffEq
using SciMLSensitivity
using ForwardDiff
using ChainRulesCore
using ChainRulesCore: NoTangent
using ChainRulesTestUtils: test_rrule, rand_tangent

@variables x(t)[1:3] y(t)
@parameters p[1:3, 1:3] q
Expand Down Expand Up @@ -51,3 +54,52 @@ end

@test ForwardDiff.gradient(x_at_0, [0.3, 0.7]) == zeros(2)
end

# --- Gradient of `remake_buffer` with respect to the replacement values ---
# Parameters of several declared kinds: a plain scalar, arrays, a time-dependent
# one, integer-typed ones, a `Vector{AbstractFloat}` and a `String`.
@parameters a b[1:3] c(t) d::Integer e[1:3] f[1:3, 1:3]::Int g::Vector{AbstractFloat} h::String
# System with no equations; the continuous event assigns `c` when `a` crosses zero.
@named sys = ODESystem(
Equation[], t, [], [a, b, c, d, e, f, g, h],
continuous_events = [[a ~ 0] => [c ~ 0]])
sys = complete(sys)

# Initial values used to build the reference parameter object.
ivs = Dict(c => 3a, b => ones(3), a => 1.0, d => 4, e => [5.0, 6.0, 7.0],
f => ones(Int, 3, 3), g => [0.1, 0.2, 0.3], h => "foo")

ps = MTKParameters(sys, ivs)

# Replacement values. Several are `Float32`, and `g` gets 4 entries here versus 3
# in `ivs`.
varmap = Dict(a => 1.0f0, b => 3ones(Float32, 3), c => 2.0,
e => Float32[0.4, 0.5, 0.6], g => ones(Float32, 4))
# Getters whose outputs are summed by the loss below.
get_values = getp(sys, [a, b..., c, e...])
get_g = getp(sys, g)
# Each case is an (indices, values) pair handed to `remake_buffer`.
for (_idxs, vals) in [
# all portions
(collect(keys(varmap)), collect(values(varmap))),
# non-arrays
(keys(varmap), values(varmap)),
# tunable only
([a], [varmap[a]]),
([a, b], (varmap[a], varmap[b])),
([a, b[2]], (varmap[a], varmap[b][2]))
]
# Run every case twice: with the indices as given, and with each converted through
# `parameter_index`.
for idxs in [_idxs, map(i -> parameter_index(sys, i), collect(_idxs))]
# Loss is a plain sum of the retrieved values, so the test expects every gradient
# entry to be one.
loss = function (p)
newps = remake_buffer(sys, ps, idxs, p)
return sum(get_values(newps)) + sum(get_g(newps))
end

grad = Zygote.gradient(loss, vals)[1]
for (val, g) in zip(vals, grad)
# The gradient must keep the element type of the value it belongs to.
@test eltype(val) == eltype(g)
if val isa Number
@test isone(g)
else
@test all(isone, g)
end
end
end
end

# Call the `remake_buffer` rrule directly and check with `@inferred` that the
# pullback's return type is inferrable for a random tangent of `ps`.
idxs = (parameter_index(sys, a), parameter_index(sys, b))
vals = (1.0f0, 3ones(Float32, 3))
tangent = rand_tangent(ps)
fwd, back = ChainRulesCore.rrule(remake_buffer, sys, ps, idxs, vals)
@inferred back(tangent)
Loading