Skip to content

Commit f79f997

Browse files
feat: support update_initializeprob!
1 parent 191b9a1 commit f79f997

File tree

3 files changed

+42
-1
lines changed

3 files changed

+42
-1
lines changed

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -331,6 +331,7 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem,
331331
analytic = nothing,
332332
split_idxs = nothing,
333333
initializeprob = nothing,
334+
update_initializeprob! = nothing,
334335
initializeprobmap = nothing,
335336
initializeprobpmap = nothing,
336337
kwargs...) where {iip, specialize}
@@ -434,6 +435,7 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem,
434435
sparsity = sparsity ? jacobian_sparsity(sys) : nothing,
435436
analytic = analytic,
436437
initializeprob = initializeprob,
438+
update_initializeprob! = update_initializeprob!,
437439
initializeprobmap = initializeprobmap,
438440
initializeprobpmap = initializeprobpmap)
439441
end
@@ -778,6 +780,17 @@ function (f::GetUpdatedMTKParameters)(prob, initializesol)
778780
mtkp
779781
end
780782

783+
struct UpdateInitializeprob{G, S}
784+
# `getu` functor which gets all values from prob
785+
getvals::G
786+
# `setu` functor which updates initializeprob with values
787+
setvals::S
788+
end
789+
790+
function (f::UpdateInitializeprob)(initializeprob, prob)
791+
f.setvals(initializeprob, f.getvals(prob))
792+
end
793+
781794
function get_temporary_value(p)
782795
stype = symtype(unwrap(p))
783796
return if stype == Real
@@ -865,6 +878,10 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
865878
getpunknowns = getu(initializeprob, punknowns)
866879
setpunknowns = setp(sys, punknowns)
867880
initializeprobpmap = GetUpdatedMTKParameters(getpunknowns, setpunknowns)
881+
reqd_syms = vcat(
882+
variable_symbols(initializeprob), parameter_symbols(initializeprob))
883+
update_initializeprob! = UpdateInitializeprob(
884+
getu(sys, reqd_syms), setu(initializeprob, reqd_syms))
868885

869886
zerovars = Dict(setdiff(unknowns(sys), keys(defaults(sys))) .=> 0.0)
870887
if parammap isa SciMLBase.NullParameters
@@ -880,6 +897,7 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
880897
(trueinit = SVector{length(trueinit)}(trueinit))
881898
else
882899
initializeprob = nothing
900+
update_initializeprob! = nothing
883901
initializeprobmap = nothing
884902
initializeprobpmap = nothing
885903
trueinit = u0map
@@ -929,6 +947,7 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
929947
sparse = sparse, eval_expression = eval_expression,
930948
eval_module = eval_module,
931949
initializeprob = initializeprob,
950+
update_initializeprob! = update_initializeprob!,
932951
initializeprobmap = initializeprobmap,
933952
initializeprobpmap = initializeprobpmap,
934953
kwargs...)

src/systems/nonlinear/initializesystem.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -201,5 +201,8 @@ function SciMLBase.remake_initializeprob(sys::ODESystem, odefn, u0, t0, p)
201201
getpunknowns = getu(initprob, punknowns)
202202
setpunknowns = setp(sys, punknowns)
203203
initprobpmap = GetUpdatedMTKParameters(getpunknowns, setpunknowns)
204-
return initprob, initprobmap, initprobpmap
204+
reqd_syms = vcat(variable_symbols(initprob), parameter_symbols(initprob))
205+
update_initializeprob! = UpdateInitializeprob(
206+
getu(sys, reqd_syms), setu(initprob, reqd_syms))
207+
return initprob, update_initializeprob!, initprobmap, initprobpmap
205208
end

test/initializationsystem.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -690,3 +690,22 @@ end
690690
prob5 = remake(prob)
691691
@test init(prob, Tsit5()).ps[p] 2.0
692692
end
693+
694+
@testset "Update initializeprob parameters" begin
695+
@variables x(t) y(t)
696+
@parameters p q
697+
@mtkbuild sys = ODESystem(
698+
[D(x) ~ x, p ~ x + y], t; guesses = [x => 0.0, p => 0.0])
699+
prob = ODEProblem(sys, [y => 1.0], (0.0, 1.0), [p => 3.0])
700+
@test prob.f.initializeprob.ps[p] 3.0
701+
@test init(prob, Tsit5())[x] 2.0
702+
prob.ps[p] = 2.0
703+
@test prob.f.initializeprob.ps[p] 3.0
704+
@test init(prob, Tsit5())[x] 1.0
705+
ModelingToolkit.defaults(prob.f.sys)[p] = missing
706+
prob2 = remake(prob; u0 = [y => 1.0], p = [p => 3x])
707+
@test !is_variable(prob2.f.initializeprob, p) &&
708+
!is_parameter(prob2.f.initializeprob, p)
709+
@test init(prob2, Tsit5())[x] 0.5
710+
@test_nowarn solve(prob2, Tsit5())
711+
end

0 commit comments

Comments
 (0)