Skip to content

feat: improve type-inference of remake_buffer in certain cases #3546

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Apr 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ SpecialFunctions = "0.7, 0.8, 0.9, 0.10, 1.0, 2"
StaticArrays = "0.10, 0.11, 0.12, 1.0"
StochasticDelayDiffEq = "1.8.1"
StochasticDiffEq = "6.72.1"
SymbolicIndexingInterface = "0.3.37"
SymbolicIndexingInterface = "0.3.39"
SymbolicUtils = "3.25.1"
Symbolics = "6.37"
URIs = "1"
Expand Down
112 changes: 112 additions & 0 deletions src/systems/parameter_buffer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -496,7 +496,119 @@ end
function SymbolicIndexingInterface.remake_buffer(indp, oldbuf::MTKParameters, idxs, vals)
_remake_buffer(indp, oldbuf, idxs, vals)
end

# For type-inference when using `SII.setp_oop`
@generated function _remake_buffer(
indp, oldbuf::MTKParameters{T, I, D, C, N, H},
idxs::Union{Tuple{Vararg{ParameterIndex}}, AbstractArray{<:ParameterIndex{P}}},
vals::Union{AbstractArray, Tuple}; validate = true) where {T, I, D, C, N, H, P}

# fallback to non-generated method if values aren't type-stable
if vals <: AbstractArray && !isconcretetype(eltype(vals))
return quote
$__remake_buffer(indp, oldbuf, collect(idxs), vals; validate)
end
end

# given an index in idxs/vals and the current `eltype` of the buffer,
# return the promoted eltype of the buffer
function promote_valtype(i, valT)
# tuples have distinct types, arrays have a common eltype
valT′ = vals <: AbstractArray ? eltype(vals) : fieldtype(vals, i)
# if the buffer is a scalarized buffer but the variable is an array
# e.g. an array tunable, take the eltype
if valT′ <: AbstractArray && !(valT <: AbstractArray)
valT′ = eltype(valT′)
end
return promote_type(valT, valT′)
end

# types of the idxs
idxtypes = if idxs <: AbstractArray
# if both are arrays, there is only one possible type to check
if vals <: AbstractArray
(eltype(idxs),)
else
# if `vals` is a tuple, we repeat `eltype(idxs)` to check against
# every possible type of the buffer
ntuple(Returns(eltype(idxs)), Val(fieldcount(vals)))
end
else
# `idxs` is a tuple, so we check against all buffers
fieldtypes(idxs)
end
# promote types
tunablesT = eltype(T)
for (i, idxT) in enumerate(idxtypes)
idxT <: ParameterIndex{SciMLStructures.Tunable} || continue
tunablesT = promote_valtype(i, tunablesT)
end
initialsT = eltype(I)
for (i, idxT) in enumerate(idxtypes)
idxT <: ParameterIndex{SciMLStructures.Initials} || continue
initialsT = promote_valtype(i, initialsT)
end
discretesT = ntuple(Val(fieldcount(D))) do i
bufT = eltype(fieldtype(D, i))
for (j, idxT) in enumerate(idxtypes)
idxT <: ParameterIndex{SciMLStructures.Discrete, i} || continue
bufT = promote_valtype(i, bufT)
end
bufT
end
constantsT = ntuple(Val(fieldcount(C))) do i
bufT = eltype(fieldtype(C, i))
for (j, idxT) in enumerate(idxtypes)
idxT <: ParameterIndex{SciMLStructures.Constants, i} || continue
bufT = promote_valtype(i, bufT)
end
bufT
end
nonnumericT = ntuple(Val(fieldcount(N))) do i
bufT = eltype(fieldtype(N, i))
for (j, idxT) in enumerate(idxtypes)
idxT <: ParameterIndex{Nonnumeric, i} || continue
bufT = promote_valtype(i, bufT)
end
bufT
end

expr = quote
tunables = $similar(oldbuf.tunable, $tunablesT)
copyto!(tunables, oldbuf.tunable)
initials = $similar(oldbuf.initials, $initialsT)
copyto!(initials, oldbuf.initials)
discretes = $(Expr(:tuple,
(:($similar(oldbuf.discrete[$i], $(discretesT[i]))) for i in 1:length(discretesT))...))
$((:($copyto!(discretes[$i], oldbuf.discrete[$i])) for i in 1:length(discretesT))...)
constants = $(Expr(:tuple,
(:($similar(oldbuf.constant[$i], $(constantsT[i]))) for i in 1:length(constantsT))...))
$((:($copyto!(constants[$i], oldbuf.constant[$i])) for i in 1:length(constantsT))...)
nonnumerics = $(Expr(:tuple,
(:($similar(oldbuf.nonnumeric[$i], $(nonnumericT[i]))) for i in 1:length(nonnumericT))...))
$((:($copyto!(nonnumerics[$i], oldbuf.nonnumeric[$i])) for i in 1:length(nonnumericT))...)
newbuf = MTKParameters(
tunables, initials, discretes, constants, nonnumerics, copy.(oldbuf.caches))
end
if idxs <: AbstractArray
push!(expr.args, :(for (idx, val) in zip(idxs, vals)
$setindex!(newbuf, val, idx)
end))
else
for i in 1:fieldcount(idxs)
push!(expr.args, :($setindex!(newbuf, vals[$i], idxs[$i])))
end
end
push!(expr.args, :(return newbuf))

return expr
end

function _remake_buffer(indp, oldbuf::MTKParameters, idxs, vals; validate = true)
return __remake_buffer(indp, oldbuf, idxs, vals; validate)
end

function __remake_buffer(indp, oldbuf::MTKParameters, idxs, vals; validate = true)
newbuf = @set oldbuf.tunable = similar(oldbuf.tunable, Any)
@set! newbuf.initials = similar(oldbuf.initials, Any)
@set! newbuf.discrete = Tuple(similar(buf, Any) for buf in newbuf.discrete)
Expand Down
21 changes: 21 additions & 0 deletions test/mtkparameters.jl
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,27 @@ end
@test getp(sys, f)(newps) isa Matrix{UInt}
@test getp(sys, g)(newps) isa Vector{Float32}

@testset "Type-stability of `remake_buffer`" begin
prob = ODEProblem(sys, [], (0.0, 1.0), ivs)

idxs = (a, c, d, e, f, g, h)
vals = (1.0, 2.0, 3, ones(3), ones(Int, 3, 3), ones(2), "a")

setter = setsym_oop(prob, idxs)
@test_nowarn @inferred setter(prob, vals)
@test_throws ErrorException @inferred setter(prob, collect(vals))

idxs = (a, c, e...)
vals = Float16[1.0, 2.0, 3.0, 4.0, 5.0]
setter = setsym_oop(prob, idxs)
@test_nowarn @inferred setter(prob, vals)

idxs = [a, e]
vals = (Float16(1.0), ForwardDiff.Dual{Nothing, Float16, 0}[1.0, 2.0, 3.0])
setter = setsym_oop(prob, idxs)
@test_nowarn @inferred setter(prob, vals)
end

ps = MTKParameters(sys, ivs)
function loss(value, sys, ps)
@test value isa ForwardDiff.Dual
Expand Down
Loading