@@ -331,6 +331,7 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem,
331
331
analytic = nothing ,
332
332
split_idxs = nothing ,
333
333
initializeprob = nothing ,
334
+ update_initializeprob! = nothing ,
334
335
initializeprobmap = nothing ,
335
336
initializeprobpmap = nothing ,
336
337
kwargs... ) where {iip, specialize}
@@ -434,6 +435,7 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem,
434
435
sparsity = sparsity ? jacobian_sparsity (sys) : nothing ,
435
436
analytic = analytic,
436
437
initializeprob = initializeprob,
438
+ update_initializeprob! = update_initializeprob!,
437
439
initializeprobmap = initializeprobmap,
438
440
initializeprobpmap = initializeprobpmap)
439
441
end
@@ -778,6 +780,17 @@ function (f::GetUpdatedMTKParameters)(prob, initializesol)
778
780
mtkp
779
781
end
780
782
783
+ struct UpdateInitializeprob{G, S}
784
+ # `getu` functor which gets all values from prob
785
+ getvals:: G
786
+ # `setu` functor which updates initializeprob with values
787
+ setvals:: S
788
+ end
789
+
790
+ function (f:: UpdateInitializeprob )(initializeprob, prob)
791
+ f. setvals (initializeprob, f. getvals (prob))
792
+ end
793
+
781
794
function get_temporary_value (p)
782
795
stype = symtype (unwrap (p))
783
796
return if stype == Real
@@ -865,6 +878,10 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
865
878
getpunknowns = getu (initializeprob, punknowns)
866
879
setpunknowns = setp (sys, punknowns)
867
880
initializeprobpmap = GetUpdatedMTKParameters (getpunknowns, setpunknowns)
881
+ reqd_syms = vcat (
882
+ variable_symbols (initializeprob), parameter_symbols (initializeprob))
883
+ update_initializeprob! = UpdateInitializeprob (
884
+ getu (sys, reqd_syms), setu (initializeprob, reqd_syms))
868
885
869
886
zerovars = Dict (setdiff (unknowns (sys), keys (defaults (sys))) .=> 0.0 )
870
887
if parammap isa SciMLBase. NullParameters
@@ -880,6 +897,7 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
880
897
(trueinit = SVector {length(trueinit)} (trueinit))
881
898
else
882
899
initializeprob = nothing
900
+ update_initializeprob! = nothing
883
901
initializeprobmap = nothing
884
902
initializeprobpmap = nothing
885
903
trueinit = u0map
@@ -929,6 +947,7 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
929
947
sparse = sparse, eval_expression = eval_expression,
930
948
eval_module = eval_module,
931
949
initializeprob = initializeprob,
950
+ update_initializeprob! = update_initializeprob!,
932
951
initializeprobmap = initializeprobmap,
933
952
initializeprobpmap = initializeprobpmap,
934
953
kwargs... )
0 commit comments