@@ -646,30 +646,40 @@ struct ReconstructInitializeprob{GP, GU}
646
646
ugetter:: GU
647
647
end
648
648
649
+ """
650
+ $(TYPEDEF)
651
+
652
+ A wrapper over an observed function which allows calling it on a problem-like object.
653
+ `TD` determines whether the getter function is `(u, p, t)` (if `true`) or `(u, p)` (if
654
+ `false`).
655
+ """
656
+ struct ObservedWrapper{TD, F}
657
+ f:: F
658
+ end
659
+
660
+ ObservedWrapper {TD} (f:: F ) where {TD, F} = ObservedWrapper {TD, F} (f)
661
+
662
+ function (ow:: ObservedWrapper{true} )(prob)
663
+ ow. f (state_values (prob), parameter_values (prob), current_time (prob))
664
+ end
665
+
666
+ function (ow:: ObservedWrapper{false} )(prob)
667
+ ow. f (state_values (prob), parameter_values (prob))
668
+ end
669
+
649
670
"""
650
671
$(TYPEDSIGNATURES)
651
672
652
673
Given an index provider `indp` and a vector of symbols `syms` return a type-stable getter
653
- function by splitting `syms` into contiguous buffers where the getter of each buffer
654
- is type-stable and constructing a function that calls and concatenates the results.
655
- """
656
- function concrete_getu (indp, syms:: AbstractVector )
657
- # a list of contiguous buffer
658
- split_syms = [Any[syms[1 ]]]
659
- # the type of the getter of the last buffer
660
- current = typeof (getu (indp, syms[1 ]))
661
- for sym in syms[2 : end ]
662
- getter = getu (indp, sym)
663
- if typeof (getter) != current
664
- # if types don't match, build a new buffer
665
- push! (split_syms, [])
666
- current = typeof (getter)
667
- end
668
- push! (split_syms[end ], sym)
669
- end
670
- split_syms = Tuple (split_syms)
671
- # the getter is now type-stable, and we can vcat it to get the full buffer
672
- return Base. Fix1 (reduce, vcat) ∘ getu (indp, split_syms)
674
+ function.
675
+
676
+ Note that the getter ONLY works for problem-like objects, since it generates an observed
677
+ function. It does NOT work for solutions.
678
+ """
679
+ Base. @nospecializeinfer function concrete_getu (indp, syms:: AbstractVector )
680
+ @nospecialize
681
+ obsfn = build_explicit_observed_function (indp, syms; wrap_delays = false )
682
+ return ObservedWrapper {is_time_dependent(indp)} (obsfn)
673
683
end
674
684
675
685
"""
0 commit comments