@@ -499,40 +499,76 @@ end
499
499
500
500
# For type-inference when using `SII.setp_oop`
501
501
@generated function _remake_buffer (
502
- indp, oldbuf:: MTKParameters{T, I, D, C, N, H} , idxs:: Tuple{Vararg{ParameterIndex}} ,
503
- vals:: Union{AbstractArray, Tuple} ; validate = true ) where {T, I, D, C, N, H}
504
- valtype (i) = vals <: AbstractArray ? eltype (vals) : fieldtype (vals, i)
502
+ indp, oldbuf:: MTKParameters{T, I, D, C, N, H} ,
503
+ idxs:: Union{Tuple{Vararg{ParameterIndex}}, AbstractArray{<:ParameterIndex{P}}} ,
504
+ vals:: Union{AbstractArray, Tuple} ; validate = true ) where {T, I, D, C, N, H, P}
505
+
506
+ # fallback to non-generated method if values aren't type-stable
507
+ if vals <: AbstractArray && ! isconcretetype (eltype (vals))
508
+ return quote
509
+ $ _remake_buffer (indp, oldbuf, collect (idxs), vals; validate)
510
+ end
511
+ end
512
+
513
+ # given an index in idxs/vals and the current `eltype` of the buffer,
514
+ # return the promoted eltype of the buffer
515
+ function promote_valtype (i, valT)
516
+ # tuples have distinct types, arrays have a common eltype
517
+ valT′ = vals <: AbstractArray ? eltype (vals) : fieldtype (vals, i)
518
+ # if the buffer is a scalarized buffer but the variable is an array
519
+ # e.g. an array tunable, take the eltype
520
+ if valT′ <: AbstractArray && ! (valT <: AbstractArray )
521
+ valT′ = eltype (valT′)
522
+ end
523
+ return promote_type (valT, valT′)
524
+ end
525
+
526
+ # types of the idxs
527
+ idxtypes = if idxs <: AbstractArray
528
+ # if both are arrays, there is only one possible type to check
529
+ if vals <: AbstractArray
530
+ (eltype (idxs),)
531
+ else
532
+ # if `vals` is a tuple, we repeat `eltype(idxs)` to check against
533
+ # every possible type of the buffer
534
+ ntuple (Returns (eltype (idxs)), Val (fieldcount (vals)))
535
+ end
536
+ else
537
+ # `idxs` is a tuple, so we check against all buffers
538
+ fieldtypes (idxs)
539
+ end
540
+ # promote types
505
541
tunablesT = eltype (T)
506
- for (i, idxT) in enumerate (fieldtypes (idxs) )
542
+ for (i, idxT) in enumerate (idxtypes )
507
543
idxT <: ParameterIndex{SciMLStructures.Tunable} || continue
508
- tunablesT = promote_type (tunablesT, valtype (i) )
544
+ tunablesT = promote_valtype (i, tunablesT )
509
545
end
510
546
initialsT = eltype (I)
511
- for (i, idxT) in enumerate (fieldtypes (idxs) )
547
+ for (i, idxT) in enumerate (idxtypes )
512
548
idxT <: ParameterIndex{SciMLStructures.Initials} || continue
513
- initialsT = promote_type (initialsT, valtype (i) )
549
+ initialsT = promote_valtype (i, initialsT )
514
550
end
515
551
discretesT = ntuple (Val (fieldcount (D))) do i
516
552
bufT = eltype (fieldtype (D, i))
517
- for (j, idxT) in enumerate (fieldtypes (idxs) )
553
+ for (j, idxT) in enumerate (idxtypes )
518
554
idxT <: ParameterIndex{SciMLStructures.Discrete, i} || continue
519
- bufT = promote_type (bufT, valtype (i) )
555
+ bufT = promote_valtype (i, bufT )
520
556
end
521
557
bufT
522
558
end
523
559
constantsT = ntuple (Val (fieldcount (C))) do i
524
560
bufT = eltype (fieldtype (C, i))
525
- for (j, idxT) in enumerate (fieldtypes (idxs) )
561
+ for (j, idxT) in enumerate (idxtypes )
526
562
idxT <: ParameterIndex{SciMLStructures.Constants, i} || continue
527
- bufT = promote_type (bufT, valtype (i) )
563
+ bufT = promote_valtype (i, bufT )
528
564
end
529
565
bufT
530
566
end
531
567
nonnumericT = ntuple (Val (fieldcount (N))) do i
532
568
bufT = eltype (fieldtype (N, i))
533
- for (j, idxT) in enumerate (fieldtypes (idxs) )
569
+ for (j, idxT) in enumerate (idxtypes )
534
570
idxT <: ParameterIndex{Nonnumeric, i} || continue
535
- bufT = promote_type (bufT, valtype (i) )
571
+ bufT = promote_valtype (i, bufT )
536
572
end
537
573
bufT
538
574
end
554
590
newbuf = MTKParameters (
555
591
tunables, initials, discretes, constants, nonnumerics, copy .(oldbuf. caches))
556
592
end
557
- for i in 1 : fieldcount (idxs)
558
- push! (expr. args, :($ setindex! (newbuf, vals[$ i], idxs[$ i])))
593
+ if idxs <: AbstractArray
594
+ push! (expr. args, :(for (idx, val) in zip (idxs, vals)
595
+ $ setindex! (newbuf, val, idx)
596
+ end ))
597
+ else
598
+ for i in 1 : fieldcount (idxs)
599
+ push! (expr. args, :($ setindex! (newbuf, vals[$ i], idxs[$ i])))
600
+ end
559
601
end
560
602
push! (expr. args, :(return newbuf))
561
603
0 commit comments