Skip to content

Commit 19f8164

Browse files
Merge pull request #3546 from AayushSabharwal/as/generated-remake
feat: improve type-inference of `remake_buffer` in certain cases
2 parents c2ff784 + 06a63df commit 19f8164

File tree

3 files changed

+134
-1
lines changed

3 files changed

+134
-1
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ SpecialFunctions = "0.7, 0.8, 0.9, 0.10, 1.0, 2"
149149
StaticArrays = "0.10, 0.11, 0.12, 1.0"
150150
StochasticDelayDiffEq = "1.8.1"
151151
StochasticDiffEq = "6.72.1"
152-
SymbolicIndexingInterface = "0.3.37"
152+
SymbolicIndexingInterface = "0.3.39"
153153
SymbolicUtils = "3.25.1"
154154
Symbolics = "6.37"
155155
URIs = "1"

src/systems/parameter_buffer.jl

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -496,7 +496,119 @@ end
496496
function SymbolicIndexingInterface.remake_buffer(indp, oldbuf::MTKParameters, idxs, vals)
497497
_remake_buffer(indp, oldbuf, idxs, vals)
498498
end
499+
500+
# For type-inference when using `SII.setp_oop`
501+
@generated function _remake_buffer(
502+
indp, oldbuf::MTKParameters{T, I, D, C, N, H},
503+
idxs::Union{Tuple{Vararg{ParameterIndex}}, AbstractArray{<:ParameterIndex{P}}},
504+
vals::Union{AbstractArray, Tuple}; validate = true) where {T, I, D, C, N, H, P}
505+
506+
# fallback to non-generated method if values aren't type-stable
507+
if vals <: AbstractArray && !isconcretetype(eltype(vals))
508+
return quote
509+
$__remake_buffer(indp, oldbuf, collect(idxs), vals; validate)
510+
end
511+
end
512+
513+
# given an index in idxs/vals and the current `eltype` of the buffer,
514+
# return the promoted eltype of the buffer
515+
function promote_valtype(i, valT)
516+
# tuples have distinct types, arrays have a common eltype
517+
valT′ = vals <: AbstractArray ? eltype(vals) : fieldtype(vals, i)
518+
# if the buffer is a scalarized buffer but the variable is an array
519+
# e.g. an array tunable, take the eltype
520+
if valT′ <: AbstractArray && !(valT <: AbstractArray)
521+
valT′ = eltype(valT′)
522+
end
523+
return promote_type(valT, valT′)
524+
end
525+
526+
# types of the idxs
527+
idxtypes = if idxs <: AbstractArray
528+
# if both are arrays, there is only one possible type to check
529+
if vals <: AbstractArray
530+
(eltype(idxs),)
531+
else
532+
# if `vals` is a tuple, we repeat `eltype(idxs)` to check against
533+
# every possible type of the buffer
534+
ntuple(Returns(eltype(idxs)), Val(fieldcount(vals)))
535+
end
536+
else
537+
# `idxs` is a tuple, so we check against all buffers
538+
fieldtypes(idxs)
539+
end
540+
# promote types
541+
tunablesT = eltype(T)
542+
for (i, idxT) in enumerate(idxtypes)
543+
idxT <: ParameterIndex{SciMLStructures.Tunable} || continue
544+
tunablesT = promote_valtype(i, tunablesT)
545+
end
546+
initialsT = eltype(I)
547+
for (i, idxT) in enumerate(idxtypes)
548+
idxT <: ParameterIndex{SciMLStructures.Initials} || continue
549+
initialsT = promote_valtype(i, initialsT)
550+
end
551+
discretesT = ntuple(Val(fieldcount(D))) do i
552+
bufT = eltype(fieldtype(D, i))
553+
for (j, idxT) in enumerate(idxtypes)
554+
idxT <: ParameterIndex{SciMLStructures.Discrete, i} || continue
555+
bufT = promote_valtype(i, bufT)
556+
end
557+
bufT
558+
end
559+
constantsT = ntuple(Val(fieldcount(C))) do i
560+
bufT = eltype(fieldtype(C, i))
561+
for (j, idxT) in enumerate(idxtypes)
562+
idxT <: ParameterIndex{SciMLStructures.Constants, i} || continue
563+
bufT = promote_valtype(i, bufT)
564+
end
565+
bufT
566+
end
567+
nonnumericT = ntuple(Val(fieldcount(N))) do i
568+
bufT = eltype(fieldtype(N, i))
569+
for (j, idxT) in enumerate(idxtypes)
570+
idxT <: ParameterIndex{Nonnumeric, i} || continue
571+
bufT = promote_valtype(i, bufT)
572+
end
573+
bufT
574+
end
575+
576+
expr = quote
577+
tunables = $similar(oldbuf.tunable, $tunablesT)
578+
copyto!(tunables, oldbuf.tunable)
579+
initials = $similar(oldbuf.initials, $initialsT)
580+
copyto!(initials, oldbuf.initials)
581+
discretes = $(Expr(:tuple,
582+
(:($similar(oldbuf.discrete[$i], $(discretesT[i]))) for i in 1:length(discretesT))...))
583+
$((:($copyto!(discretes[$i], oldbuf.discrete[$i])) for i in 1:length(discretesT))...)
584+
constants = $(Expr(:tuple,
585+
(:($similar(oldbuf.constant[$i], $(constantsT[i]))) for i in 1:length(constantsT))...))
586+
$((:($copyto!(constants[$i], oldbuf.constant[$i])) for i in 1:length(constantsT))...)
587+
nonnumerics = $(Expr(:tuple,
588+
(:($similar(oldbuf.nonnumeric[$i], $(nonnumericT[i]))) for i in 1:length(nonnumericT))...))
589+
$((:($copyto!(nonnumerics[$i], oldbuf.nonnumeric[$i])) for i in 1:length(nonnumericT))...)
590+
newbuf = MTKParameters(
591+
tunables, initials, discretes, constants, nonnumerics, copy.(oldbuf.caches))
592+
end
593+
if idxs <: AbstractArray
594+
push!(expr.args, :(for (idx, val) in zip(idxs, vals)
595+
$setindex!(newbuf, val, idx)
596+
end))
597+
else
598+
for i in 1:fieldcount(idxs)
599+
push!(expr.args, :($setindex!(newbuf, vals[$i], idxs[$i])))
600+
end
601+
end
602+
push!(expr.args, :(return newbuf))
603+
604+
return expr
605+
end
606+
499607
function _remake_buffer(indp, oldbuf::MTKParameters, idxs, vals; validate = true)
608+
return __remake_buffer(indp, oldbuf, idxs, vals; validate)
609+
end
610+
611+
function __remake_buffer(indp, oldbuf::MTKParameters, idxs, vals; validate = true)
500612
newbuf = @set oldbuf.tunable = similar(oldbuf.tunable, Any)
501613
@set! newbuf.initials = similar(oldbuf.initials, Any)
502614
@set! newbuf.discrete = Tuple(similar(buf, Any) for buf in newbuf.discrete)

test/mtkparameters.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,27 @@ end
9191
@test getp(sys, f)(newps) isa Matrix{UInt}
9292
@test getp(sys, g)(newps) isa Vector{Float32}
9393

94+
@testset "Type-stability of `remake_buffer`" begin
95+
prob = ODEProblem(sys, [], (0.0, 1.0), ivs)
96+
97+
idxs = (a, c, d, e, f, g, h)
98+
vals = (1.0, 2.0, 3, ones(3), ones(Int, 3, 3), ones(2), "a")
99+
100+
setter = setsym_oop(prob, idxs)
101+
@test_nowarn @inferred setter(prob, vals)
102+
@test_throws ErrorException @inferred setter(prob, collect(vals))
103+
104+
idxs = (a, c, e...)
105+
vals = Float16[1.0, 2.0, 3.0, 4.0, 5.0]
106+
setter = setsym_oop(prob, idxs)
107+
@test_nowarn @inferred setter(prob, vals)
108+
109+
idxs = [a, e]
110+
vals = (Float16(1.0), ForwardDiff.Dual{Nothing, Float16, 0}[1.0, 2.0, 3.0])
111+
setter = setsym_oop(prob, idxs)
112+
@test_nowarn @inferred setter(prob, vals)
113+
end
114+
94115
ps = MTKParameters(sys, ivs)
95116
function loss(value, sys, ps)
96117
@test value isa ForwardDiff.Dual

0 commit comments

Comments
 (0)