Skip to content

Commit 85e5130

Browse files
Merge pull request #2403 from SciML/initializesystem
Fix up initializesystem for hierarchical models
2 parents 7c85930 + c78321f commit 85e5130

19 files changed

+727
-169
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,12 +89,12 @@ Libdl = "1"
8989
LinearAlgebra = "1"
9090
MLStyle = "0.4.17"
9191
NaNMath = "0.3, 1"
92-
OrdinaryDiffEq = "6"
92+
OrdinaryDiffEq = "6.72.0"
9393
PrecompileTools = "1"
9494
RecursiveArrayTools = "2.3, 3"
9595
Reexport = "0.2, 1"
9696
RuntimeGeneratedFunctions = "0.5.9"
97-
SciMLBase = "2.0.1"
97+
SciMLBase = "2.28.0"
9898
SciMLStructures = "1.0"
9999
Serialization = "1"
100100
Setfield = "0.7, 0.8, 1"

src/ModelingToolkit.jl

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -137,19 +137,18 @@ include("systems/model_parsing.jl")
137137
include("systems/connectors.jl")
138138
include("systems/callbacks.jl")
139139

140+
include("systems/nonlinear/nonlinearsystem.jl")
140141
include("systems/diffeqs/odesystem.jl")
141142
include("systems/diffeqs/sdesystem.jl")
142143
include("systems/diffeqs/abstractodesystem.jl")
144+
include("systems/nonlinear/modelingtoolkitize.jl")
145+
include("systems/nonlinear/initializesystem.jl")
143146
include("systems/diffeqs/first_order_transform.jl")
144147
include("systems/diffeqs/modelingtoolkitize.jl")
145148
include("systems/diffeqs/basic_transformations.jl")
146149

147150
include("systems/jumps/jumpsystem.jl")
148151

149-
include("systems/nonlinear/nonlinearsystem.jl")
150-
include("systems/nonlinear/modelingtoolkitize.jl")
151-
include("systems/nonlinear/initializesystem.jl")
152-
153152
include("systems/optimization/constraints_system.jl")
154153
include("systems/optimization/optimizationsystem.jl")
155154
include("systems/optimization/modelingtoolkitize.jl")
@@ -253,7 +252,7 @@ export toexpr, get_variables
253252
export simplify, substitute
254253
export build_function
255254
export modelingtoolkitize
256-
export initializesystem
255+
export initializesystem, generate_initializesystem
257256

258257
export @variables, @parameters, @constants, @brownian
259258
export @named, @nonamespace, @namespace, extend, compose, complete

src/bipartite_graph.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -423,7 +423,7 @@ return `false` may not be matched.
423423
"""
424424
function maximal_matching(g::BipartiteGraph, srcfilter = vsrc -> true,
425425
dstfilter = vdst -> true, ::Type{U} = Unassigned) where {U}
426-
matching = Matching{U}(ndsts(g))
426+
matching = Matching{U}(max(nsrcs(g), ndsts(g)))
427427
foreach(Iterators.filter(srcfilter, 𝑠vertices(g))) do vsrc
428428
construct_augmenting_path!(matching, g, vsrc, dstfilter)
429429
end

src/structural_transformation/StructuralTransformations.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@ using ModelingToolkit: ODESystem, AbstractSystem, var_from_nested_derivative, Di
2323
IncrementalCycleTracker, add_edge_checked!, topological_sort,
2424
invalidate_cache!, Substitutions, get_or_construct_tearing_state,
2525
filter_kwargs, lower_varname, setio, SparseMatrixCLIL,
26-
fast_substitute, get_fullvars, has_equations, observed
26+
fast_substitute, get_fullvars, has_equations, observed,
27+
Schedule
2728

2829
using ModelingToolkit.BipartiteGraphs
2930
import .BipartiteGraphs: invview, complete

src/structural_transformation/symbolics_tearing.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -555,6 +555,12 @@ function tearing_reassemble(state::TearingState, var_eq_matching;
555555
# TODO: compute the dependency correctly so that we don't have to do this
556556
obs = [fast_substitute(observed(sys), obs_sub); subeqs]
557557
@set! sys.observed = obs
558+
559+
# Only makes sense for time-dependent
560+
# TODO: generalize to SDE
561+
if sys isa ODESystem
562+
@set! sys.schedule = Schedule(var_eq_matching, dummy_sub)
563+
end
558564
@set! state.sys = sys
559565
@set! sys.tearing_state = state
560566
return invalidate_cache!(sys)

src/systems/abstractsystem.jl

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -537,6 +537,9 @@ function complete(sys::AbstractSystem; split = true)
537537
if split && has_index_cache(sys)
538538
@set! sys.index_cache = IndexCache(sys)
539539
end
540+
if isdefined(sys, :initializesystem) && get_initializesystem(sys) !== nothing
541+
@set! sys.initializesystem = complete(get_initializesystem(sys); split)
542+
end
540543
isdefined(sys, :complete) ? (@set! sys.complete = true) : sys
541544
end
542545

@@ -551,6 +554,7 @@ for prop in [:eqs
551554
:var_to_name
552555
:ctrls
553556
:defaults
557+
:guesses
554558
:observed
555559
:tgrad
556560
:jac
@@ -571,6 +575,8 @@ for prop in [:eqs
571575
:connections
572576
:preface
573577
:torn_matching
578+
:initializesystem
579+
:schedule
574580
:tearing_state
575581
:substitutions
576582
:metadata
@@ -933,6 +939,10 @@ function full_parameters(sys::AbstractSystem)
933939
vcat(parameters(sys), dependent_parameters(sys))
934940
end
935941

942+
function guesses(sys::AbstractSystem)
943+
get_guesses(sys)
944+
end
945+
936946
# required in `src/connectors.jl:437`
937947
parameters(_) = []
938948

@@ -2259,14 +2269,15 @@ function UnPack.unpack(sys::ModelingToolkit.AbstractSystem, ::Val{p}) where {p}
22592269
end
22602270

22612271
"""
2262-
missing_variable_defaults(sys::AbstractSystem, default = 0.0)
2272+
missing_variable_defaults(sys::AbstractSystem, default = 0.0; subset = unknowns(sys))
22632273
22642274
returns a `Vector{Pair}` of variables set to `default` which are missing from `get_defaults(sys)`. The `default` argument can be a single value or vector to set the missing defaults respectively.
22652275
"""
2266-
function missing_variable_defaults(sys::AbstractSystem, default = 0.0)
2276+
function missing_variable_defaults(
2277+
sys::AbstractSystem, default = 0.0; subset = unknowns(sys))
22672278
varmap = get_defaults(sys)
22682279
varmap = Dict(Symbolics.diff2term(value(k)) => value(varmap[k]) for k in keys(varmap))
2269-
missingvars = setdiff(unknowns(sys), keys(varmap))
2280+
missingvars = setdiff(subset, keys(varmap))
22702281
ds = Pair[]
22712282

22722283
n = length(missingvars)

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 134 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
struct Schedule
2+
var_eq_matching::Any
3+
dummy_sub::Any
4+
end
5+
16
function filter_kwargs(kwargs)
27
kwargs = Dict(kwargs)
38
for key in keys(kwargs)
@@ -316,6 +321,8 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem,
316321
sparsity = false,
317322
analytic = nothing,
318323
split_idxs = nothing,
324+
initializeprob = nothing,
325+
initializeprobmap = nothing,
319326
kwargs...) where {iip, specialize}
320327
if !iscomplete(sys)
321328
error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating an `ODEFunction`")
@@ -487,6 +494,7 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem,
487494
end
488495

489496
@set! sys.split_idxs = split_idxs
497+
490498
ODEFunction{iip, specialize}(f;
491499
sys = sys,
492500
jac = _jac === nothing ? nothing : _jac,
@@ -495,7 +503,9 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem,
495503
jac_prototype = jac_prototype,
496504
observed = observedfun,
497505
sparsity = sparsity ? jacobian_sparsity(sys) : nothing,
498-
analytic = analytic)
506+
analytic = analytic,
507+
initializeprob = initializeprob,
508+
initializeprobmap = initializeprobmap)
499509
end
500510

501511
"""
@@ -525,6 +535,8 @@ function DiffEqBase.DAEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys)
525535
sparse = false, simplify = false,
526536
eval_module = @__MODULE__,
527537
checkbounds = false,
538+
initializeprob = nothing,
539+
initializeprobmap = nothing,
528540
kwargs...) where {iip}
529541
if !iscomplete(sys)
530542
error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating a `DAEFunction`")
@@ -596,7 +608,9 @@ function DiffEqBase.DAEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys)
596608
sys = sys,
597609
jac = _jac === nothing ? nothing : _jac,
598610
jac_prototype = jac_prototype,
599-
observed = observedfun)
611+
observed = observedfun,
612+
initializeprob = initializeprob,
613+
initializeprobmap = initializeprobmap)
600614
end
601615

602616
function DiffEqBase.DDEFunction(sys::AbstractODESystem, args...; kwargs...)
@@ -839,18 +853,46 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
839853
tofloat = true,
840854
symbolic_u0 = false,
841855
u0_constructor = identity,
856+
guesses = Dict(),
857+
t = nothing,
858+
warn_initialize_determined = true,
842859
kwargs...)
843860
eqs = equations(sys)
844861
dvs = unknowns(sys)
845862
ps = full_parameters(sys)
846863
iv = get_iv(sys)
847864

865+
# Append zeros to the variables which are determined by the initialization system
866+
# This essentially bypasses the check for if initial conditions are defined for DAEs
867+
# since they will be checked in the initialization problem's construction
868+
# TODO: make check for if a DAE cheaper than calculating the mass matrix a second time!
869+
ci = infer_clocks!(ClockInference(TearingState(sys)))
870+
# TODO: make it work with clocks
871+
# ModelingToolkit.get_tearing_state(sys) !== nothing => Requires structural_simplify first
872+
if (implicit_dae || calculate_massmatrix(sys) !== I) &&
873+
all(isequal(Continuous()), ci.var_domain) &&
874+
ModelingToolkit.get_tearing_state(sys) !== nothing
875+
if eltype(u0map) <: Number
876+
u0map = unknowns(sys) .=> u0map
877+
end
878+
initializeprob = ModelingToolkit.InitializationProblem(
879+
sys, t, u0map, parammap; guesses, warn_initialize_determined)
880+
initializeprobmap = getu(initializeprob, unknowns(sys))
881+
882+
zerovars = setdiff(unknowns(sys), keys(defaults(sys))) .=> 0.0
883+
trueinit = identity.([zerovars; u0map])
884+
else
885+
initializeprob = nothing
886+
initializeprobmap = nothing
887+
trueinit = u0map
888+
end
889+
848890
if has_index_cache(sys) && get_index_cache(sys) !== nothing
849-
u0, defs = get_u0(sys, u0map, parammap; symbolic_u0)
891+
u0, defs = get_u0(sys, trueinit, parammap; symbolic_u0)
850892
p = MTKParameters(sys, parammap)
851893
else
852894
u0, p, defs = get_u0_p(sys,
853-
u0map,
895+
trueinit,
854896
parammap;
855897
tofloat,
856898
use_union,
@@ -881,6 +923,8 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
881923
checkbounds = checkbounds, p = p,
882924
linenumbers = linenumbers, parallel = parallel, simplify = simplify,
883925
sparse = sparse, eval_expression = eval_expression,
926+
initializeprob = initializeprob,
927+
initializeprobmap = initializeprobmap,
884928
kwargs...)
885929
implicit_dae ? (f, du0, u0, p) : (f, u0, p)
886930
end
@@ -984,13 +1028,14 @@ function DiffEqBase.ODEProblem{iip, specialize}(sys::AbstractODESystem, u0map =
9841028
parammap = DiffEqBase.NullParameters();
9851029
callback = nothing,
9861030
check_length = true,
1031+
warn_initialize_determined = true,
9871032
kwargs...) where {iip, specialize}
9881033
if !iscomplete(sys)
9891034
error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating an `ODEProblem`")
9901035
end
9911036
f, u0, p = process_DEProblem(ODEFunction{iip, specialize}, sys, u0map, parammap;
9921037
t = tspan !== nothing ? tspan[1] : tspan,
993-
check_length, kwargs...)
1038+
check_length, warn_initialize_determined, kwargs...)
9941039
cbs = process_events(sys; callback, kwargs...)
9951040
inits = []
9961041
if has_discrete_subsystems(sys) && (dss = get_discrete_subsystems(sys)) !== nothing
@@ -1055,13 +1100,15 @@ end
10551100

10561101
function DiffEqBase.DAEProblem{iip}(sys::AbstractODESystem, du0map, u0map, tspan,
10571102
parammap = DiffEqBase.NullParameters();
1103+
warn_initialize_determined = true,
10581104
check_length = true, kwargs...) where {iip}
10591105
if !iscomplete(sys)
10601106
error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating a `DAEProblem`")
10611107
end
10621108
f, du0, u0, p = process_DEProblem(DAEFunction{iip}, sys, u0map, parammap;
10631109
implicit_dae = true, du0map = du0map, check_length,
1064-
kwargs...)
1110+
t = tspan !== nothing ? tspan[1] : tspan,
1111+
warn_initialize_determined, kwargs...)
10651112
diffvars = collect_differential_variables(sys)
10661113
sts = unknowns(sys)
10671114
differential_vars = map(Base.Fix2(in, diffvars), sts)
@@ -1237,6 +1284,7 @@ function ODEProblemExpr{iip}(sys::AbstractODESystem, u0map, tspan,
12371284
error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating a `ODEProblemExpr`")
12381285
end
12391286
f, u0, p = process_DEProblem(ODEFunctionExpr{iip}, sys, u0map, parammap; check_length,
1287+
t = tspan !== nothing ? tspan[1] : tspan,
12401288
kwargs...)
12411289
linenumbers = get(kwargs, :linenumbers, true)
12421290
kwargs = filter_kwargs(kwargs)
@@ -1282,6 +1330,7 @@ function DAEProblemExpr{iip}(sys::AbstractODESystem, du0map, u0map, tspan,
12821330
error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating a `DAEProblemExpr`")
12831331
end
12841332
f, du0, u0, p = process_DEProblem(DAEFunctionExpr{iip}, sys, u0map, parammap;
1333+
t = tspan !== nothing ? tspan[1] : tspan,
12851334
implicit_dae = true, du0map = du0map, check_length,
12861335
kwargs...)
12871336
linenumbers = get(kwargs, :linenumbers, true)
@@ -1442,3 +1491,82 @@ function flatten_equations(eqs)
14421491
end
14431492
end
14441493
end
1494+
1495+
struct InitializationProblem{iip, specialization} end
1496+
1497+
"""
1498+
```julia
1499+
InitializationProblem{iip}(sys::AbstractODESystem, u0map, tspan,
1500+
parammap = DiffEqBase.NullParameters();
1501+
version = nothing, tgrad = false,
1502+
jac = false,
1503+
checkbounds = false, sparse = false,
1504+
simplify = false,
1505+
linenumbers = true, parallel = SerialForm(),
1506+
kwargs...) where {iip}
1507+
```
1508+
1509+
Generates a NonlinearProblem or NonlinearLeastSquaresProblem from an ODESystem
1510+
which represents the initialization, i.e. the calculation of the consistent
1511+
initial conditions for the given DAE.
1512+
"""
1513+
function InitializationProblem(sys::AbstractODESystem, args...; kwargs...)
1514+
InitializationProblem{true}(sys, args...; kwargs...)
1515+
end
1516+
1517+
function InitializationProblem(sys::AbstractODESystem, t,
1518+
u0map::StaticArray,
1519+
args...;
1520+
kwargs...)
1521+
InitializationProblem{false, SciMLBase.FullSpecialize}(
1522+
sys, t, u0map, args...; kwargs...)
1523+
end
1524+
1525+
function InitializationProblem{true}(sys::AbstractODESystem, args...; kwargs...)
1526+
InitializationProblem{true, SciMLBase.AutoSpecialize}(sys, args...; kwargs...)
1527+
end
1528+
1529+
function InitializationProblem{false}(sys::AbstractODESystem, args...; kwargs...)
1530+
InitializationProblem{false, SciMLBase.FullSpecialize}(sys, args...; kwargs...)
1531+
end
1532+
1533+
function InitializationProblem{iip, specialize}(sys::AbstractODESystem,
1534+
t::Number, u0map = [],
1535+
parammap = DiffEqBase.NullParameters();
1536+
guesses = [],
1537+
check_length = true,
1538+
warn_initialize_determined = true,
1539+
kwargs...) where {iip, specialize}
1540+
if !iscomplete(sys)
1541+
error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating an `ODEProblem`")
1542+
end
1543+
1544+
if isempty(u0map) && get_initializesystem(sys) !== nothing
1545+
isys = get_initializesystem(sys)
1546+
elseif isempty(u0map) && get_initializesystem(sys) === nothing
1547+
isys = structural_simplify(generate_initializesystem(sys); fully_determined = false)
1548+
else
1549+
isys = structural_simplify(
1550+
generate_initializesystem(sys; u0map); fully_determined = false)
1551+
end
1552+
1553+
neqs = length(equations(isys))
1554+
nunknown = length(unknowns(isys))
1555+
1556+
if warn_initialize_determined && neqs > nunknown
1557+
@warn "Initialization system is overdetermined. $neqs equations for $nunknown unknowns. Initialization will default to using least squares. To suppress this warning pass warn_initialize_determined = false."
1558+
end
1559+
if warn_initialize_determined && neqs < nunknown
1560+
@warn "Initialization system is underdetermined. $neqs equations for $nunknown unknowns. Initialization will default to using least squares. To suppress this warning pass warn_initialize_determined = false."
1561+
end
1562+
1563+
parammap isa DiffEqBase.NullParameters || isempty(parammap) ?
1564+
[get_iv(sys) => t] :
1565+
merge(todict(parammap), Dict(get_iv(sys) => t))
1566+
1567+
if neqs == nunknown
1568+
NonlinearProblem(isys, guesses, parammap)
1569+
else
1570+
NonlinearLeastSquaresProblem(isys, guesses, parammap)
1571+
end
1572+
end

0 commit comments

Comments
 (0)