Skip to content

Commit 96694de

Browse files
feat: handle more type-stable cases in @generated _remake_buffer
1 parent 0ba6120 commit 96694de

File tree

1 file changed

+57
-15
lines changed

1 file changed

+57
-15
lines changed

src/systems/parameter_buffer.jl

Lines changed: 57 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -499,40 +499,76 @@ end
499499

500500
# For type-inference when using `SII.setp_oop`
501501
@generated function _remake_buffer(
502-
indp, oldbuf::MTKParameters{T, I, D, C, N, H}, idxs::Tuple{Vararg{ParameterIndex}},
503-
vals::Union{AbstractArray, Tuple}; validate = true) where {T, I, D, C, N, H}
504-
valtype(i) = vals <: AbstractArray ? eltype(vals) : fieldtype(vals, i)
502+
indp, oldbuf::MTKParameters{T, I, D, C, N, H},
503+
idxs::Union{Tuple{Vararg{ParameterIndex}}, AbstractArray{<:ParameterIndex{P}}},
504+
vals::Union{AbstractArray, Tuple}; validate = true) where {T, I, D, C, N, H, P}
505+
506+
# fallback to non-generated method if values aren't type-stable
507+
if vals <: AbstractArray && !isconcretetype(eltype(vals))
508+
return quote
509+
$_remake_buffer(indp, oldbuf, collect(idxs), vals; validate)
510+
end
511+
end
512+
513+
# given an index in idxs/vals and the current `eltype` of the buffer,
514+
# return the promoted eltype of the buffer
515+
function promote_valtype(i, valT)
516+
# tuples have distinct types, arrays have a common eltype
517+
valT′ = vals <: AbstractArray ? eltype(vals) : fieldtype(vals, i)
518+
# if the buffer is a scalarized buffer but the variable is an array
519+
# e.g. an array tunable, take the eltype
520+
if valT′ <: AbstractArray && !(valT <: AbstractArray)
521+
valT′ = eltype(valT′)
522+
end
523+
return promote_type(valT, valT′)
524+
end
525+
526+
# types of the idxs
527+
idxtypes = if idxs <: AbstractArray
528+
# if both are arrays, there is only one possible type to check
529+
if vals <: AbstractArray
530+
(eltype(idxs),)
531+
else
532+
# if `vals` is a tuple, we repeat `eltype(idxs)` to check against
533+
# every possible type of the buffer
534+
ntuple(Returns(eltype(idxs)), Val(fieldcount(vals)))
535+
end
536+
else
537+
# `idxs` is a tuple, so we check against all buffers
538+
fieldtypes(idxs)
539+
end
540+
# promote types
505541
tunablesT = eltype(T)
506-
for (i, idxT) in enumerate(fieldtypes(idxs))
542+
for (i, idxT) in enumerate(idxtypes)
507543
idxT <: ParameterIndex{SciMLStructures.Tunable} || continue
508-
tunablesT = promote_type(tunablesT, valtype(i))
544+
tunablesT = promote_valtype(i, tunablesT)
509545
end
510546
initialsT = eltype(I)
511-
for (i, idxT) in enumerate(fieldtypes(idxs))
547+
for (i, idxT) in enumerate(idxtypes)
512548
idxT <: ParameterIndex{SciMLStructures.Initials} || continue
513-
initialsT = promote_type(initialsT, valtype(i))
549+
initialsT = promote_valtype(i, initialsT)
514550
end
515551
discretesT = ntuple(Val(fieldcount(D))) do i
516552
bufT = eltype(fieldtype(D, i))
517-
for (j, idxT) in enumerate(fieldtypes(idxs))
553+
for (j, idxT) in enumerate(idxtypes)
518554
idxT <: ParameterIndex{SciMLStructures.Discrete, i} || continue
519-
bufT = promote_type(bufT, valtype(i))
555+
bufT = promote_valtype(i, bufT)
520556
end
521557
bufT
522558
end
523559
constantsT = ntuple(Val(fieldcount(C))) do i
524560
bufT = eltype(fieldtype(C, i))
525-
for (j, idxT) in enumerate(fieldtypes(idxs))
561+
for (j, idxT) in enumerate(idxtypes)
526562
idxT <: ParameterIndex{SciMLStructures.Constants, i} || continue
527-
bufT = promote_type(bufT, valtype(i))
563+
bufT = promote_valtype(i, bufT)
528564
end
529565
bufT
530566
end
531567
nonnumericT = ntuple(Val(fieldcount(N))) do i
532568
bufT = eltype(fieldtype(N, i))
533-
for (j, idxT) in enumerate(fieldtypes(idxs))
569+
for (j, idxT) in enumerate(idxtypes)
534570
idxT <: ParameterIndex{Nonnumeric, i} || continue
535-
bufT = promote_type(bufT, valtype(i))
571+
bufT = promote_valtype(i, bufT)
536572
end
537573
bufT
538574
end
@@ -554,8 +590,14 @@ end
554590
newbuf = MTKParameters(
555591
tunables, initials, discretes, constants, nonnumerics, copy.(oldbuf.caches))
556592
end
557-
for i in 1:fieldcount(idxs)
558-
push!(expr.args, :($setindex!(newbuf, vals[$i], idxs[$i])))
593+
if idxs <: AbstractArray
594+
push!(expr.args, :(for (idx, val) in zip(idxs, vals)
595+
$setindex!(newbuf, val, idx)
596+
end))
597+
else
598+
for i in 1:fieldcount(idxs)
599+
push!(expr.args, :($setindex!(newbuf, vals[$i], idxs[$i])))
600+
end
559601
end
560602
push!(expr.args, :(return newbuf))
561603

0 commit comments

Comments
 (0)