diff --git a/Project.toml b/Project.toml index 63aef5b8da..693a81e148 100644 --- a/Project.toml +++ b/Project.toml @@ -149,7 +149,7 @@ SpecialFunctions = "0.7, 0.8, 0.9, 0.10, 1.0, 2" StaticArrays = "0.10, 0.11, 0.12, 1.0" StochasticDelayDiffEq = "1.8.1" StochasticDiffEq = "6.72.1" -SymbolicIndexingInterface = "0.3.37" +SymbolicIndexingInterface = "0.3.39" SymbolicUtils = "3.25.1" Symbolics = "6.37" URIs = "1" diff --git a/src/systems/parameter_buffer.jl b/src/systems/parameter_buffer.jl index 36105c0de8..672fb58795 100644 --- a/src/systems/parameter_buffer.jl +++ b/src/systems/parameter_buffer.jl @@ -496,7 +496,119 @@ end function SymbolicIndexingInterface.remake_buffer(indp, oldbuf::MTKParameters, idxs, vals) _remake_buffer(indp, oldbuf, idxs, vals) end + +# For type-inference when using `SII.setp_oop` +@generated function _remake_buffer( + indp, oldbuf::MTKParameters{T, I, D, C, N, H}, + idxs::Union{Tuple{Vararg{ParameterIndex}}, AbstractArray{<:ParameterIndex{P}}}, + vals::Union{AbstractArray, Tuple}; validate = true) where {T, I, D, C, N, H, P} + + # fallback to non-generated method if values aren't type-stable + if vals <: AbstractArray && !isconcretetype(eltype(vals)) + return quote + $__remake_buffer(indp, oldbuf, collect(idxs), vals; validate) + end + end + + # given an index in idxs/vals and the current `eltype` of the buffer, + # return the promoted eltype of the buffer + function promote_valtype(i, valT) + # tuples have distinct types, arrays have a common eltype + valT′ = vals <: AbstractArray ? eltype(vals) : fieldtype(vals, i) + # if the buffer is a scalarized buffer but the variable is an array + # e.g. an array tunable, take the eltype + if valT′ <: AbstractArray && !(valT <: AbstractArray) + valT′ = eltype(valT′) + end + return promote_type(valT, valT′) + end + + # types of the idxs + idxtypes = if idxs <: AbstractArray + # if both are arrays, there is only one possible type to check + if vals <: AbstractArray + (eltype(idxs),) + else + # if `vals` is a tuple, we repeat `eltype(idxs)` to check against + # every possible type of the buffer + ntuple(Returns(eltype(idxs)), Val(fieldcount(vals))) + end + else + # `idxs` is a tuple, so we check against all buffers + fieldtypes(idxs) + end + # promote types + tunablesT = eltype(T) + for (i, idxT) in enumerate(idxtypes) + idxT <: ParameterIndex{SciMLStructures.Tunable} || continue + tunablesT = promote_valtype(i, tunablesT) + end + initialsT = eltype(I) + for (i, idxT) in enumerate(idxtypes) + idxT <: ParameterIndex{SciMLStructures.Initials} || continue + initialsT = promote_valtype(i, initialsT) + end + discretesT = ntuple(Val(fieldcount(D))) do i + bufT = eltype(fieldtype(D, i)) + for (j, idxT) in enumerate(idxtypes) + idxT <: ParameterIndex{SciMLStructures.Discrete, i} || continue + bufT = promote_valtype(i, bufT) + end + bufT + end + constantsT = ntuple(Val(fieldcount(C))) do i + bufT = eltype(fieldtype(C, i)) + for (j, idxT) in enumerate(idxtypes) + idxT <: ParameterIndex{SciMLStructures.Constants, i} || continue + bufT = promote_valtype(i, bufT) + end + bufT + end + nonnumericT = ntuple(Val(fieldcount(N))) do i + bufT = eltype(fieldtype(N, i)) + for (j, idxT) in enumerate(idxtypes) + idxT <: ParameterIndex{Nonnumeric, i} || continue + bufT = promote_valtype(i, bufT) + end + bufT + end + + expr = quote + tunables = $similar(oldbuf.tunable, $tunablesT) + copyto!(tunables, oldbuf.tunable) + initials = $similar(oldbuf.initials, $initialsT) + copyto!(initials, oldbuf.initials) + discretes = $(Expr(:tuple, + (:($similar(oldbuf.discrete[$i], $(discretesT[i]))) for i in 1:length(discretesT))...)) + $((:($copyto!(discretes[$i], oldbuf.discrete[$i])) for i in 1:length(discretesT))...) + constants = $(Expr(:tuple, + (:($similar(oldbuf.constant[$i], $(constantsT[i]))) for i in 1:length(constantsT))...)) + $((:($copyto!(constants[$i], oldbuf.constant[$i])) for i in 1:length(constantsT))...) + nonnumerics = $(Expr(:tuple, + (:($similar(oldbuf.nonnumeric[$i], $(nonnumericT[i]))) for i in 1:length(nonnumericT))...)) + $((:($copyto!(nonnumerics[$i], oldbuf.nonnumeric[$i])) for i in 1:length(nonnumericT))...) + newbuf = MTKParameters( + tunables, initials, discretes, constants, nonnumerics, copy.(oldbuf.caches)) + end + if idxs <: AbstractArray + push!(expr.args, :(for (idx, val) in zip(idxs, vals) + $setindex!(newbuf, val, idx) + end)) + else + for i in 1:fieldcount(idxs) + push!(expr.args, :($setindex!(newbuf, vals[$i], idxs[$i]))) + end + end + push!(expr.args, :(return newbuf)) + + return expr +end + function _remake_buffer(indp, oldbuf::MTKParameters, idxs, vals; validate = true) + return __remake_buffer(indp, oldbuf, idxs, vals; validate) +end + +function __remake_buffer(indp, oldbuf::MTKParameters, idxs, vals; validate = true) newbuf = @set oldbuf.tunable = similar(oldbuf.tunable, Any) @set! newbuf.initials = similar(oldbuf.initials, Any) @set! newbuf.discrete = Tuple(similar(buf, Any) for buf in newbuf.discrete) diff --git a/test/mtkparameters.jl b/test/mtkparameters.jl index f9d71f00bc..22201b1988 100644 --- a/test/mtkparameters.jl +++ b/test/mtkparameters.jl @@ -91,6 +91,27 @@ end @test getp(sys, f)(newps) isa Matrix{UInt} @test getp(sys, g)(newps) isa Vector{Float32} +@testset "Type-stability of `remake_buffer`" begin + prob = ODEProblem(sys, [], (0.0, 1.0), ivs) + + idxs = (a, c, d, e, f, g, h) + vals = (1.0, 2.0, 3, ones(3), ones(Int, 3, 3), ones(2), "a") + + setter = setsym_oop(prob, idxs) + @test_nowarn @inferred setter(prob, vals) + @test_throws ErrorException @inferred setter(prob, collect(vals)) + + idxs = (a, c, e...) + vals = Float16[1.0, 2.0, 3.0, 4.0, 5.0] + setter = setsym_oop(prob, idxs) + @test_nowarn @inferred setter(prob, vals) + + idxs = [a, e] + vals = (Float16(1.0), ForwardDiff.Dual{Nothing, Float16, 0}[1.0, 2.0, 3.0]) + setter = setsym_oop(prob, idxs) + @test_nowarn @inferred setter(prob, vals) +end + ps = MTKParameters(sys, ivs) function loss(value, sys, ps) @test value isa ForwardDiff.Dual