Skip to content

Commit b237706

Browse files
committed
Clean up generate_initializesystem()
1 parent 85d8d10 commit b237706

File tree

1 file changed

+58
-75
lines changed

1 file changed

+58
-75
lines changed

src/systems/nonlinear/initializesystem.jl

Lines changed: 58 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -5,109 +5,92 @@ Generate `NonlinearSystem` which initializes an ODE problem from specified initi
55
"""
66
function generate_initializesystem(sys::ODESystem;
77
u0map = Dict(),
8-
name = nameof(sys),
9-
guesses = Dict(), check_defguess = false,
10-
default_dd_value = 0.0,
11-
algebraic_only = false,
128
initialization_eqs = [],
13-
check_units = true,
14-
kwargs...)
15-
sts, eqs = unknowns(sys), equations(sys)
9+
guesses = Dict(),
10+
default_dd_guess = 0.0,
11+
algebraic_only = false,
12+
check_units = true, check_defguess = false,
13+
name = nameof(sys), kwargs...)
14+
vars = unique([unknowns(sys); getfield.((observed(sys)), :lhs)])
15+
vars_set = Set(vars) # for efficient in-lookup
16+
17+
eqs = equations(sys)
1618
idxs_diff = isdiffeq.(eqs)
1719
idxs_alge = .!idxs_diff
18-
num_alge = sum(idxs_alge)
19-
20-
# Start the equations list with algebraic equations
21-
eqs_ics = eqs[idxs_alge]
22-
u0 = Vector{Pair}(undef, 0)
2320

21+
# prepare map for dummy derivative substitution
2422
eqs_diff = eqs[idxs_diff]
25-
diffmap = Dict(getfield.(eqs_diff, :lhs) .=> getfield.(eqs_diff, :rhs))
26-
observed_diffmap = Dict(Differential(get_iv(sys)).(getfield.((observed(sys)), :lhs)) .=>
27-
Differential(get_iv(sys)).(getfield.((observed(sys)), :rhs)))
28-
full_diffmap = merge(diffmap, observed_diffmap)
23+
D = Differential(get_iv(sys))
24+
diffmap = merge(
25+
Dict(eq.lhs => eq.rhs for eq in eqs_diff),
26+
Dict(D(eq.lhs) => D(eq.rhs) for eq in observed(sys))
27+
)
2928

30-
full_states = unique([sts; getfield.((observed(sys)), :lhs)])
31-
set_full_states = Set(full_states)
29+
# 1) process dummy derivatives and u0map into initialization system
30+
eqs_ics = eqs[idxs_alge] # start equation list with algebraic equations
31+
defs = copy(defaults(sys)) # copy so we don't modify sys.defaults
3232
guesses = merge(get_guesses(sys), todict(guesses))
3333
schedule = getfield(sys, :schedule)
34-
35-
if schedule !== nothing
36-
guessmap = [x[1] => get(guesses, x[1], default_dd_value)
37-
for x in schedule.dummy_sub]
38-
dd_guess = Dict(filter(x -> !isnothing(x[1]), guessmap))
39-
if u0map === nothing || isempty(u0map)
40-
filtered_u0 = u0map
41-
else
42-
filtered_u0 = Pair[]
43-
for x in u0map
44-
y = get(schedule.dummy_sub, x[1], x[1])
45-
y = ModelingToolkit.fixpoint_sub(y, full_diffmap)
46-
47-
if y set_full_states
48-
# defer initialization until defaults are merged below
49-
push!(filtered_u0, y => x[2])
34+
if !isnothing(schedule)
35+
for x in filter(x -> !isnothing(x[1]), schedule.dummy_sub)
36+
# set dummy derivatives to default_dd_guess unless specified
37+
push!(defs, x[1] => get(guesses, x[1], default_dd_guess))
38+
end
39+
if !isnothing(u0map)
40+
for (y, x) in u0map
41+
y = get(schedule.dummy_sub, y, y)
42+
y = fixpoint_sub(y, diffmap)
43+
if y vars_set
44+
# variables specified in u0 overrides defaults
45+
push!(defs, y => x)
5046
elseif y isa Symbolics.Arr
51-
# scalarize array # TODO: don't scalarize arrays
52-
_y = collect(y)
53-
for i in eachindex(_y)
54-
push!(filtered_u0, _y[i] => x[2][i])
55-
end
47+
# TODO: don't scalarize arrays
48+
push!(defs, collect(y) .=> x)
5649
elseif y isa Symbolics.BasicSymbolic
57-
# y is a derivative expression expanded
58-
# add to the initialization equations
59-
push!(eqs_ics, y ~ x[2])
50+
# y is a derivative expression expanded; add it to the initialization equations
51+
push!(eqs_ics, y ~ x)
6052
else
6153
error("Initialization expression $y is currently not supported. If its a higher order derivative expression, then only the dummy derivative expressions are supported.")
6254
end
6355
end
64-
filtered_u0 = todict(filtered_u0)
6556
end
66-
else
67-
dd_guess = Dict()
68-
filtered_u0 = todict(u0map)
6957
end
7058

71-
defs = merge(defaults(sys), filtered_u0)
72-
73-
for st in full_states
74-
if st keys(defs)
75-
def = defs[st]
76-
59+
# 2) process other variables
60+
for var in vars
61+
if var keys(defs)
62+
def = defs[var]
7763
if def isa Equation
78-
st keys(guesses) && check_defguess &&
79-
error("Invalid setup: unknown $(st) has an initial condition equation with no guess.")
64+
# TODO: this behavior is not tested!
65+
var keys(guesses) && check_defguess &&
66+
error("Invalid setup: variable $(var) has an initial condition equation with no guess.")
8067
push!(eqs_ics, def)
81-
push!(u0, st => guesses[st])
68+
push!(defs, var => guesses[var])
8269
else
83-
push!(eqs_ics, st ~ def)
84-
push!(u0, st => def)
70+
push!(eqs_ics, var ~ def)
8571
end
86-
elseif st keys(guesses)
87-
push!(u0, st => guesses[st])
72+
elseif var keys(guesses)
73+
push!(defs, var => guesses[var])
8874
elseif check_defguess
89-
error("Invalid setup: unknown $(st) has no default value or initial guess")
75+
error("Invalid setup: variable $(var) has no default value or initial guess")
9076
end
9177
end
9278

79+
# 3) process explicitly provided initialization equations
9380
if !algebraic_only
94-
for eq in [get_initialization_eqs(sys); initialization_eqs]
95-
_eq = ModelingToolkit.fixpoint_sub(eq, full_diffmap)
96-
push!(eqs_ics, _eq)
81+
initialization_eqs = [get_initialization_eqs(sys); initialization_eqs]
82+
for eq in initialization_eqs
83+
eq = fixpoint_sub(eq, diffmap) # expand dummy derivatives
84+
push!(eqs_ics, eq)
9785
end
9886
end
9987

100-
pars = [parameters(sys); get_iv(sys)]
101-
nleqs = [eqs_ics; observed(sys)]
102-
103-
sys_nl = NonlinearSystem(nleqs,
104-
full_states,
105-
pars;
106-
defaults = merge(ModelingToolkit.defaults(sys), todict(u0), dd_guess),
107-
parameter_dependencies = parameter_dependencies(sys),
88+
pars = [parameters(sys); get_iv(sys)] # include independent variable as pseudo-parameter
89+
eqs_ics = [eqs_ics; observed(sys)]
90+
return NonlinearSystem(
91+
eqs_ics, vars, pars;
92+
defaults = defs, parameter_dependencies = parameter_dependencies(sys),
10893
checks = check_units,
109-
name,
110-
kwargs...)
111-
112-
return sys_nl
94+
name, kwargs...
95+
)
11396
end

0 commit comments

Comments
 (0)