@@ -496,7 +496,119 @@ end
496
496
function SymbolicIndexingInterface. remake_buffer (indp, oldbuf:: MTKParameters , idxs, vals)
497
497
_remake_buffer (indp, oldbuf, idxs, vals)
498
498
end
499
+
500
+ # For type-inference when using `SII.setp_oop`
501
+ @generated function _remake_buffer (
502
+ indp, oldbuf:: MTKParameters{T, I, D, C, N, H} ,
503
+ idxs:: Union{Tuple{Vararg{ParameterIndex}}, AbstractArray{<:ParameterIndex{P}}} ,
504
+ vals:: Union{AbstractArray, Tuple} ; validate = true ) where {T, I, D, C, N, H, P}
505
+
506
+ # fallback to non-generated method if values aren't type-stable
507
+ if vals <: AbstractArray && ! isconcretetype (eltype (vals))
508
+ return quote
509
+ $ __remake_buffer (indp, oldbuf, collect (idxs), vals; validate)
510
+ end
511
+ end
512
+
513
+ # given an index in idxs/vals and the current `eltype` of the buffer,
514
+ # return the promoted eltype of the buffer
515
+ function promote_valtype (i, valT)
516
+ # tuples have distinct types, arrays have a common eltype
517
+ valT′ = vals <: AbstractArray ? eltype (vals) : fieldtype (vals, i)
518
+ # if the buffer is a scalarized buffer but the variable is an array
519
+ # e.g. an array tunable, take the eltype
520
+ if valT′ <: AbstractArray && ! (valT <: AbstractArray )
521
+ valT′ = eltype (valT′)
522
+ end
523
+ return promote_type (valT, valT′)
524
+ end
525
+
526
+ # types of the idxs
527
+ idxtypes = if idxs <: AbstractArray
528
+ # if both are arrays, there is only one possible type to check
529
+ if vals <: AbstractArray
530
+ (eltype (idxs),)
531
+ else
532
+ # if `vals` is a tuple, we repeat `eltype(idxs)` to check against
533
+ # every possible type of the buffer
534
+ ntuple (Returns (eltype (idxs)), Val (fieldcount (vals)))
535
+ end
536
+ else
537
+ # `idxs` is a tuple, so we check against all buffers
538
+ fieldtypes (idxs)
539
+ end
540
+ # promote types
541
+ tunablesT = eltype (T)
542
+ for (i, idxT) in enumerate (idxtypes)
543
+ idxT <: ParameterIndex{SciMLStructures.Tunable} || continue
544
+ tunablesT = promote_valtype (i, tunablesT)
545
+ end
546
+ initialsT = eltype (I)
547
+ for (i, idxT) in enumerate (idxtypes)
548
+ idxT <: ParameterIndex{SciMLStructures.Initials} || continue
549
+ initialsT = promote_valtype (i, initialsT)
550
+ end
551
+ discretesT = ntuple (Val (fieldcount (D))) do i
552
+ bufT = eltype (fieldtype (D, i))
553
+ for (j, idxT) in enumerate (idxtypes)
554
+ idxT <: ParameterIndex{SciMLStructures.Discrete, i} || continue
555
+ bufT = promote_valtype (i, bufT)
556
+ end
557
+ bufT
558
+ end
559
+ constantsT = ntuple (Val (fieldcount (C))) do i
560
+ bufT = eltype (fieldtype (C, i))
561
+ for (j, idxT) in enumerate (idxtypes)
562
+ idxT <: ParameterIndex{SciMLStructures.Constants, i} || continue
563
+ bufT = promote_valtype (i, bufT)
564
+ end
565
+ bufT
566
+ end
567
+ nonnumericT = ntuple (Val (fieldcount (N))) do i
568
+ bufT = eltype (fieldtype (N, i))
569
+ for (j, idxT) in enumerate (idxtypes)
570
+ idxT <: ParameterIndex{Nonnumeric, i} || continue
571
+ bufT = promote_valtype (i, bufT)
572
+ end
573
+ bufT
574
+ end
575
+
576
+ expr = quote
577
+ tunables = $ similar (oldbuf. tunable, $ tunablesT)
578
+ copyto! (tunables, oldbuf. tunable)
579
+ initials = $ similar (oldbuf. initials, $ initialsT)
580
+ copyto! (initials, oldbuf. initials)
581
+ discretes = $ (Expr (:tuple ,
582
+ (:($ similar (oldbuf. discrete[$ i], $ (discretesT[i]))) for i in 1 : length (discretesT)). .. ))
583
+ $ ((:($ copyto! (discretes[$ i], oldbuf. discrete[$ i])) for i in 1 : length (discretesT)). .. )
584
+ constants = $ (Expr (:tuple ,
585
+ (:($ similar (oldbuf. constant[$ i], $ (constantsT[i]))) for i in 1 : length (constantsT)). .. ))
586
+ $ ((:($ copyto! (constants[$ i], oldbuf. constant[$ i])) for i in 1 : length (constantsT)). .. )
587
+ nonnumerics = $ (Expr (:tuple ,
588
+ (:($ similar (oldbuf. nonnumeric[$ i], $ (nonnumericT[i]))) for i in 1 : length (nonnumericT)). .. ))
589
+ $ ((:($ copyto! (nonnumerics[$ i], oldbuf. nonnumeric[$ i])) for i in 1 : length (nonnumericT)). .. )
590
+ newbuf = MTKParameters (
591
+ tunables, initials, discretes, constants, nonnumerics, copy .(oldbuf. caches))
592
+ end
593
+ if idxs <: AbstractArray
594
+ push! (expr. args, :(for (idx, val) in zip (idxs, vals)
595
+ $ setindex! (newbuf, val, idx)
596
+ end ))
597
+ else
598
+ for i in 1 : fieldcount (idxs)
599
+ push! (expr. args, :($ setindex! (newbuf, vals[$ i], idxs[$ i])))
600
+ end
601
+ end
602
+ push! (expr. args, :(return newbuf))
603
+
604
+ return expr
605
+ end
606
+
499
607
function _remake_buffer (indp, oldbuf:: MTKParameters , idxs, vals; validate = true )
608
+ return __remake_buffer (indp, oldbuf, idxs, vals; validate)
609
+ end
610
+
611
+ function __remake_buffer (indp, oldbuf:: MTKParameters , idxs, vals; validate = true )
500
612
newbuf = @set oldbuf. tunable = similar (oldbuf. tunable, Any)
501
613
@set! newbuf. initials = similar (oldbuf. initials, Any)
502
614
@set! newbuf. discrete = Tuple (similar (buf, Any) for buf in newbuf. discrete)
0 commit comments