Skip to content

Commit b685ab5

Browse files
feat: add @generated method for _remake_buffer for type-inference
1 parent c2ff784 commit b685ab5

File tree

1 file changed

+66
-0
lines changed

1 file changed

+66
-0
lines changed

src/systems/parameter_buffer.jl

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -496,6 +496,72 @@ end
496496
function SymbolicIndexingInterface.remake_buffer(indp, oldbuf::MTKParameters, idxs, vals)
497497
_remake_buffer(indp, oldbuf, idxs, vals)
498498
end
499+
500+
# For type-inference when using `SII.setp_oop`
501+
@generated function _remake_buffer(
502+
indp, oldbuf::MTKParameters{T, I, D, C, N, H}, idxs::Tuple{Vararg{ParameterIndex}},
503+
vals::Union{AbstractArray, Tuple}; validate = true) where {T, I, D, C, N, H}
504+
valtype(i) = vals <: AbstractArray ? eltype(vals) : fieldtype(vals, i)
505+
tunablesT = eltype(T)
506+
for (i, idxT) in enumerate(fieldtypes(idxs))
507+
idxT <: ParameterIndex{SciMLStructures.Tunable} || continue
508+
tunablesT = promote_type(tunablesT, valtype(i))
509+
end
510+
initialsT = eltype(I)
511+
for (i, idxT) in enumerate(fieldtypes(idxs))
512+
idxT <: ParameterIndex{SciMLStructures.Initials} || continue
513+
initialsT = promote_type(initialsT, valtype(i))
514+
end
515+
discretesT = ntuple(Val(fieldcount(D))) do i
516+
bufT = eltype(fieldtype(D, i))
517+
for (j, idxT) in enumerate(fieldtypes(idxs))
518+
idxT <: ParameterIndex{SciMLStructures.Discrete, i} || continue
519+
bufT = promote_type(bufT, valtype(i))
520+
end
521+
bufT
522+
end
523+
constantsT = ntuple(Val(fieldcount(C))) do i
524+
bufT = eltype(fieldtype(C, i))
525+
for (j, idxT) in enumerate(fieldtypes(idxs))
526+
idxT <: ParameterIndex{SciMLStructures.Constants, i} || continue
527+
bufT = promote_type(bufT, valtype(i))
528+
end
529+
bufT
530+
end
531+
nonnumericT = ntuple(Val(fieldcount(N))) do i
532+
bufT = eltype(fieldtype(N, i))
533+
for (j, idxT) in enumerate(fieldtypes(idxs))
534+
idxT <: ParameterIndex{Nonnumeric, i} || continue
535+
bufT = promote_type(bufT, valtype(i))
536+
end
537+
bufT
538+
end
539+
540+
expr = quote
541+
tunables = $similar(oldbuf.tunable, $tunablesT)
542+
copyto!(tunables, oldbuf.tunable)
543+
initials = $similar(oldbuf.initials, $initialsT)
544+
copyto!(initials, oldbuf.initials)
545+
discretes = $(Expr(:tuple,
546+
(:($similar(oldbuf.discrete[$i], $(discretesT[i]))) for i in 1:length(discretesT))...))
547+
$((:($copyto!(discretes[$i], oldbuf.discrete[$i])) for i in 1:length(discretesT))...)
548+
constants = $(Expr(:tuple,
549+
(:($similar(oldbuf.constant[$i], $(constantsT[i]))) for i in 1:length(constantsT))...))
550+
$((:($copyto!(constants[$i], oldbuf.constant[$i])) for i in 1:length(constantsT))...)
551+
nonnumerics = $(Expr(:tuple,
552+
(:($similar(oldbuf.nonnumeric[$i], $(nonnumericT[i]))) for i in 1:length(nonnumericT))...))
553+
$((:($copyto!(nonnumerics[$i], oldbuf.nonnumeric[$i])) for i in 1:length(nonnumericT))...)
554+
newbuf = MTKParameters(
555+
tunables, initials, discretes, constants, nonnumerics, copy.(oldbuf.caches))
556+
end
557+
for i in 1:fieldcount(idxs)
558+
push!(expr.args, :($setindex!(newbuf, vals[$i], idxs[$i])))
559+
end
560+
push!(expr.args, :(return newbuf))
561+
562+
return expr
563+
end
564+
499565
function _remake_buffer(indp, oldbuf::MTKParameters, idxs, vals; validate = true)
500566
newbuf = @set oldbuf.tunable = similar(oldbuf.tunable, Any)
501567
@set! newbuf.initials = similar(oldbuf.initials, Any)

0 commit comments

Comments
 (0)