diff --git a/ext/MTKChainRulesCoreExt.jl b/ext/MTKChainRulesCoreExt.jl index e7019e25df..f84690e23f 100644 --- a/ext/MTKChainRulesCoreExt.jl +++ b/ext/MTKChainRulesCoreExt.jl @@ -2,14 +2,99 @@ module MTKChainRulesCoreExt import ModelingToolkit as MTK import ChainRulesCore -import ChainRulesCore: NoTangent +import ChainRulesCore: Tangent, ZeroTangent, NoTangent, zero_tangent, unthunk function ChainRulesCore.rrule(::Type{MTK.MTKParameters}, tunables, args...) function mtp_pullback(dt) + dt = unthunk(dt) (NoTangent(), dt.tunable[1:length(tunables)], ntuple(_ -> NoTangent(), length(args))...) end MTK.MTKParameters(tunables, args...), mtp_pullback end +function subset_idxs(idxs, portion, template) + ntuple(Val(length(template))) do subi + [Base.tail(idx.idx) for idx in idxs if idx.portion == portion && idx.idx[1] == subi] + end +end + +selected_tangents(::NoTangent, _) = () +selected_tangents(::ZeroTangent, _) = ZeroTangent() +function selected_tangents( + tangents::AbstractArray{T}, idxs::Vector{Tuple{Int}}) where {T <: Number} + selected_tangents(tangents, map(only, idxs)) +end +function selected_tangents(tangents::AbstractArray{T}, idxs...) where {T <: Number} + newtangents = copy(tangents) + view(newtangents, idxs...) .= zero(T) + newtangents +end +function selected_tangents( + tangents::AbstractVector{T}, idxs) where {S <: Number, T <: AbstractArray{S}} + newtangents = copy(tangents) + for i in idxs + j, k... = i + if k == () + newtangents[j] = zero(newtangents[j]) + else + newtangents[j] = selected_tangents(newtangents[j], k...) + end + end + newtangents +end +function selected_tangents(tangents::AbstractVector{T}, idxs) where {T <: AbstractArray} + newtangents = similar(tangents, Union{T, NoTangent}) + copyto!(newtangents, tangents) + for i in idxs + j, k... = i + if k == () + newtangents[j] = NoTangent() + else + newtangents[j] = selected_tangents(newtangents[j], k...) + end + end + newtangents +end +function selected_tangents( + tangents::Union{Tangent{<:Tuple}, Tangent{T, <:Tuple}}, idxs) where {T} + ntuple(Val(length(tangents))) do i + selected_tangents(tangents[i], idxs[i]) + end +end + +function ChainRulesCore.rrule( + ::typeof(MTK.remake_buffer), indp, oldbuf::MTK.MTKParameters, idxs, vals) + if idxs isa AbstractSet + idxs = collect(idxs) + end + idxs = map(idxs) do i + i isa MTK.ParameterIndex ? i : MTK.parameter_index(indp, i) + end + newbuf = MTK.remake_buffer(indp, oldbuf, idxs, vals) + tunable_idxs = reduce( + vcat, (idx.idx for idx in idxs if idx.portion isa MTK.SciMLStructures.Tunable)) + disc_idxs = subset_idxs(idxs, MTK.SciMLStructures.Discrete(), oldbuf.discrete) + const_idxs = subset_idxs(idxs, MTK.SciMLStructures.Constants(), oldbuf.constant) + nn_idxs = subset_idxs(idxs, MTK.NONNUMERIC_PORTION, oldbuf.nonnumeric) + + pullback = let idxs = idxs + function remake_buffer_pullback(buf′) + buf′ = unthunk(buf′) + f′ = NoTangent() + indp′ = NoTangent() + + tunable = selected_tangents(buf′.tunable, tunable_idxs) + discrete = selected_tangents(buf′.discrete, disc_idxs) + constant = selected_tangents(buf′.constant, const_idxs) + nonnumeric = selected_tangents(buf′.nonnumeric, nn_idxs) + oldbuf′ = Tangent{typeof(oldbuf)}(; tunable, discrete, constant, nonnumeric) + idxs′ = NoTangent() + vals′ = map(i -> MTK._ducktyped_parameter_values(buf′, i), idxs) + return f′, indp′, oldbuf′, idxs′, vals′ + end + end + newbuf, pullback +end + end diff --git a/src/systems/index_cache.jl b/src/systems/index_cache.jl index 3f2b4ddebe..64e9c134f6 100644 --- a/src/systems/index_cache.jl +++ b/src/systems/index_cache.jl @@ -288,11 +288,7 @@ function IndexCache(sys::AbstractSystem) end function SymbolicIndexingInterface.is_variable(ic::IndexCache, sym) - if sym isa Symbol - sym = get(ic.symbol_to_variable, sym, nothing) - sym === nothing && return false - end - return check_index_map(ic.unknown_idx, sym) !== nothing + variable_index(ic, sym) !== nothing end function SymbolicIndexingInterface.variable_index(ic::IndexCache, sym) @@ -300,18 +296,17 @@ function SymbolicIndexingInterface.variable_index(ic::IndexCache, sym) sym = get(ic.symbol_to_variable, sym, nothing) sym === nothing && return nothing end - return check_index_map(ic.unknown_idx, sym) + idx = check_index_map(ic.unknown_idx, sym) + idx === nothing || return idx + iscall(sym) && operation(sym) == getindex || return nothing + args = arguments(sym) + idx = variable_index(ic, args[1]) + idx === nothing && return nothing + return idx[args[2:end]...] end function SymbolicIndexingInterface.is_parameter(ic::IndexCache, sym) - if sym isa Symbol - sym = get(ic.symbol_to_variable, sym, nothing) - sym === nothing && return false - end - return check_index_map(ic.tunable_idx, sym) !== nothing || - check_index_map(ic.discrete_idx, sym) !== nothing || - check_index_map(ic.constant_idx, sym) !== nothing || - check_index_map(ic.nonnumeric_idx, sym) !== nothing + parameter_index(ic, sym) !== nothing end function SymbolicIndexingInterface.parameter_index(ic::IndexCache, sym) @@ -331,17 +326,21 @@ function SymbolicIndexingInterface.parameter_index(ic::IndexCache, sym) ParameterIndex(SciMLStructures.Constants(), idx, validate_size) elseif (idx = check_index_map(ic.nonnumeric_idx, sym)) !== nothing ParameterIndex(NONNUMERIC_PORTION, idx, validate_size) - else - nothing + elseif iscall(sym) && operation(sym) == getindex + args = arguments(sym) + pidx = parameter_index(ic, args[1]) + pidx === nothing && return nothing + if pidx.portion == SciMLStructures.Tunable() + ParameterIndex(pidx.portion, reshape(pidx.idx, size(args[1]))[args[2:end]...], + pidx.validate_size) + else + ParameterIndex(pidx.portion, (pidx.idx..., args[2:end]...), pidx.validate_size) + end end end function SymbolicIndexingInterface.is_timeseries_parameter(ic::IndexCache, sym) - if sym isa Symbol - sym = get(ic.symbol_to_variable, sym, nothing) - sym === nothing && return false - end - return check_index_map(ic.discrete_idx, sym) !== nothing + timeseries_parameter_index(ic, sym) !== nothing end function SymbolicIndexingInterface.timeseries_parameter_index(ic::IndexCache, sym) @@ -350,8 +349,13 @@ function SymbolicIndexingInterface.timeseries_parameter_index(ic::IndexCache, sy sym === nothing && return nothing end idx = check_index_map(ic.discrete_idx, sym) + idx === nothing || + return ParameterTimeseriesIndex(idx.clock_idx, (idx.buffer_idx, idx.idx_in_clock)) + iscall(sym) && operation(sym) == getindex || return nothing + args = arguments(sym) + idx = timeseries_parameter_index(ic, args[1]) idx === nothing && return nothing - return ParameterTimeseriesIndex(idx.clock_idx, (idx.buffer_idx, idx.idx_in_clock)) + ParameterIndex(idx.portion, (idx.idx..., args[2:end]...), idx.validate_size) end function check_index_map(idxmap, sym) diff --git a/src/systems/parameter_buffer.jl b/src/systems/parameter_buffer.jl index 95ca13adad..ad001b32dc 100644 --- a/src/systems/parameter_buffer.jl +++ b/src/systems/parameter_buffer.jl @@ -336,14 +336,15 @@ function Base.copy(p::MTKParameters) end function SymbolicIndexingInterface.parameter_values(p::MTKParameters, pind::ParameterIndex) + _ducktyped_parameter_values(p, pind) +end +function _ducktyped_parameter_values(p, pind::ParameterIndex) @unpack portion, idx = pind if portion isa SciMLStructures.Tunable return idx isa Int ? p.tunable[idx] : view(p.tunable, idx) end i, j, k... = idx - if portion isa SciMLStructures.Tunable - return isempty(k) ? p.tunable[i][j] : p.tunable[i][j][k...] - elseif portion isa SciMLStructures.Discrete + if portion isa SciMLStructures.Discrete return isempty(k) ? p.discrete[i][j] : p.discrete[i][j][k...] elseif portion isa SciMLStructures.Constants return isempty(k) ? p.constant[i][j] : p.constant[i][j][k...] @@ -435,8 +436,12 @@ function validate_parameter_type(ic::IndexCache, p, idx::ParameterIndex, val) end function validate_parameter_type(ic::IndexCache, idx::ParameterIndex, val) + stype = get_buffer_template(ic, idx).type + if idx.portion == SciMLStructures.Tunable() && !(idx.idx isa Int) + stype = AbstractArray{<:stype} + end validate_parameter_type( - ic, get_buffer_template(ic, idx).type, Symbolics.Unknown(), nothing, idx, val) + ic, stype, Symbolics.Unknown(), nothing, idx, val) end function validate_parameter_type(ic::IndexCache, stype, sz, sym, index, val) @@ -444,11 +449,13 @@ function validate_parameter_type(ic::IndexCache, stype, sz, sym, index, val) # Nonnumeric parameters have to match the type if portion === NONNUMERIC_PORTION val isa stype && return nothing - throw(ParameterTypeException(:validate_parameter_type, sym, stype, val)) + throw(ParameterTypeException( + :validate_parameter_type, sym === nothing ? index : sym, stype, val)) end # Array parameters need array values... if stype <: AbstractArray && !isa(val, AbstractArray) - throw(ParameterTypeException(:validate_parameter_type, sym, stype, val)) + throw(ParameterTypeException( + :validate_parameter_type, sym === nothing ? index : sym, stype, val)) end # ... and must match sizes if stype <: AbstractArray && sz != Symbolics.Unknown() && size(val) != sz @@ -465,7 +472,7 @@ function validate_parameter_type(ic::IndexCache, stype, sz, sym, index, val) # This is for duals and other complicated number types etype = SciMLBase.parameterless_type(etype) eltype(val) <: etype || throw(ParameterTypeException( - :validate_parameter_type, sym, AbstractArray{etype}, val)) + :validate_parameter_type, sym === nothing ? index : sym, AbstractArray{etype}, val)) else # Real check if stype <: Real @@ -473,7 +480,8 @@ function validate_parameter_type(ic::IndexCache, stype, sz, sym, index, val) end stype = SciMLBase.parameterless_type(stype) val isa stype || - throw(ParameterTypeException(:validate_parameter_type, sym, stype, val)) + throw(ParameterTypeException( + :validate_parameter_type, sym === nothing ? index : sym, stype, val)) end end @@ -485,6 +493,9 @@ function indp_to_system(indp) end function SymbolicIndexingInterface.remake_buffer(indp, oldbuf::MTKParameters, idxs, vals) + _remake_buffer(indp, oldbuf, idxs, vals) +end +function _remake_buffer(indp, oldbuf::MTKParameters, idxs, vals; validate = true) newbuf = @set oldbuf.tunable = similar(oldbuf.tunable, Any) @set! newbuf.discrete = Tuple(similar(buf, Any) for buf in newbuf.discrete) @set! newbuf.constant = Tuple(similar(buf, Any) for buf in newbuf.constant) diff --git a/test/extensions/Project.toml b/test/extensions/Project.toml index d8d3d64605..81097d98d2 100644 --- a/test/extensions/Project.toml +++ b/test/extensions/Project.toml @@ -1,5 +1,7 @@ [deps] BifurcationKit = "0f109fa4-8a5d-4b75-95aa-f515264e7665" +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800" ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78" diff --git a/test/extensions/ad.jl b/test/extensions/ad.jl index 1946db5277..954b868a1e 100644 --- a/test/extensions/ad.jl +++ b/test/extensions/ad.jl @@ -6,6 +6,9 @@ using SciMLStructures using OrdinaryDiffEq using SciMLSensitivity using ForwardDiff +using ChainRulesCore +using ChainRulesCore: NoTangent +using ChainRulesTestUtils: test_rrule, rand_tangent @variables x(t)[1:3] y(t) @parameters p[1:3, 1:3] q @@ -51,3 +54,52 @@ end @test ForwardDiff.gradient(x_at_0, [0.3, 0.7]) == zeros(2) end + +@parameters a b[1:3] c(t) d::Integer e[1:3] f[1:3, 1:3]::Int g::Vector{AbstractFloat} h::String +@named sys = ODESystem( + Equation[], t, [], [a, b, c, d, e, f, g, h], + continuous_events = [[a ~ 0] => [c ~ 0]]) +sys = complete(sys) + +ivs = Dict(c => 3a, b => ones(3), a => 1.0, d => 4, e => [5.0, 6.0, 7.0], + f => ones(Int, 3, 3), g => [0.1, 0.2, 0.3], h => "foo") + +ps = MTKParameters(sys, ivs) + +varmap = Dict(a => 1.0f0, b => 3ones(Float32, 3), c => 2.0, + e => Float32[0.4, 0.5, 0.6], g => ones(Float32, 4)) +get_values = getp(sys, [a, b..., c, e...]) +get_g = getp(sys, g) +for (_idxs, vals) in [ + # all portions + (collect(keys(varmap)), collect(values(varmap))), + # non-arrays + (keys(varmap), values(varmap)), + # tunable only + ([a], [varmap[a]]), + ([a, b], (varmap[a], varmap[b])), + ([a, b[2]], (varmap[a], varmap[b][2])) +] + for idxs in [_idxs, map(i -> parameter_index(sys, i), collect(_idxs))] + loss = function (p) + newps = remake_buffer(sys, ps, idxs, p) + return sum(get_values(newps)) + sum(get_g(newps)) + end + + grad = Zygote.gradient(loss, vals)[1] + for (val, g) in zip(vals, grad) + @test eltype(val) == eltype(g) + if val isa Number + @test isone(g) + else + @test all(isone, g) + end + end + end +end + +idxs = (parameter_index(sys, a), parameter_index(sys, b)) +vals = (1.0f0, 3ones(Float32, 3)) +tangent = rand_tangent(ps) +fwd, back = ChainRulesCore.rrule(remake_buffer, sys, ps, idxs, vals) +@inferred back(tangent)