Skip to content

Commit 3088d1f

Browse files
Merge pull request #3051 from hersle/cleanup_initialization
Clean up generate_initializesystem()
2 parents b408d0d + 01a7cf9 commit 3088d1f

File tree

5 files changed

+91
-97
lines changed

5 files changed

+91
-97
lines changed

src/systems/nonlinear/initializesystem.jl

Lines changed: 56 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -5,109 +5,81 @@ Generate `NonlinearSystem` which initializes an ODE problem from specified initi
55
"""
66
function generate_initializesystem(sys::ODESystem;
77
u0map = Dict(),
8-
name = nameof(sys),
9-
guesses = Dict(), check_defguess = false,
10-
default_dd_value = 0.0,
11-
algebraic_only = false,
128
initialization_eqs = [],
13-
check_units = true,
14-
kwargs...)
15-
sts, eqs = unknowns(sys), equations(sys)
9+
guesses = Dict(),
10+
default_dd_guess = 0.0,
11+
algebraic_only = false,
12+
check_units = true, check_defguess = false,
13+
name = nameof(sys), kwargs...)
14+
vars = unique([unknowns(sys); getfield.((observed(sys)), :lhs)])
15+
vars_set = Set(vars) # for efficient in-lookup
16+
17+
eqs = equations(sys)
1618
idxs_diff = isdiffeq.(eqs)
1719
idxs_alge = .!idxs_diff
18-
num_alge = sum(idxs_alge)
19-
20-
# Start the equations list with algebraic equations
21-
eqs_ics = eqs[idxs_alge]
22-
u0 = Vector{Pair}(undef, 0)
2320

21+
# prepare map for dummy derivative substitution
2422
eqs_diff = eqs[idxs_diff]
25-
diffmap = Dict(getfield.(eqs_diff, :lhs) .=> getfield.(eqs_diff, :rhs))
26-
observed_diffmap = Dict(Differential(get_iv(sys)).(getfield.((observed(sys)), :lhs)) .=>
27-
Differential(get_iv(sys)).(getfield.((observed(sys)), :rhs)))
28-
full_diffmap = merge(diffmap, observed_diffmap)
23+
D = Differential(get_iv(sys))
24+
diffmap = merge(
25+
Dict(eq.lhs => eq.rhs for eq in eqs_diff),
26+
Dict(D(eq.lhs) => D(eq.rhs) for eq in observed(sys))
27+
)
2928

30-
full_states = unique([sts; getfield.((observed(sys)), :lhs)])
31-
set_full_states = Set(full_states)
29+
# 1) process dummy derivatives and u0map into initialization system
30+
eqs_ics = eqs[idxs_alge] # start equation list with algebraic equations
31+
defs = copy(defaults(sys)) # copy so we don't modify sys.defaults
3232
guesses = merge(get_guesses(sys), todict(guesses))
3333
schedule = getfield(sys, :schedule)
34-
35-
if schedule !== nothing
36-
guessmap = [x[1] => get(guesses, x[1], default_dd_value)
37-
for x in schedule.dummy_sub]
38-
dd_guess = Dict(filter(x -> !isnothing(x[1]), guessmap))
39-
if u0map === nothing || isempty(u0map)
40-
filtered_u0 = u0map
41-
else
42-
filtered_u0 = Pair[]
43-
for x in u0map
44-
y = get(schedule.dummy_sub, x[1], x[1])
45-
y = ModelingToolkit.fixpoint_sub(y, full_diffmap)
46-
47-
if y set_full_states
48-
# defer initialization until defaults are merged below
49-
push!(filtered_u0, y => x[2])
50-
elseif y isa Symbolics.Arr
51-
# scalarize array # TODO: don't scalarize arrays
52-
_y = collect(y)
53-
for i in eachindex(_y)
54-
push!(filtered_u0, _y[i] => x[2][i])
55-
end
56-
elseif y isa Symbolics.BasicSymbolic
57-
# y is a derivative expression expanded
58-
# add to the initialization equations
59-
push!(eqs_ics, y ~ x[2])
60-
else
61-
error("Initialization expression $y is currently not supported. If its a higher order derivative expression, then only the dummy derivative expressions are supported.")
62-
end
34+
if !isnothing(schedule)
35+
for x in filter(x -> !isnothing(x[1]), schedule.dummy_sub)
36+
# set dummy derivatives to default_dd_guess unless specified
37+
push!(defs, x[1] => get(guesses, x[1], default_dd_guess))
38+
end
39+
for (y, x) in u0map
40+
y = get(schedule.dummy_sub, y, y)
41+
y = fixpoint_sub(y, diffmap)
42+
if y vars_set
43+
# variables specified in u0 overrides defaults
44+
push!(defs, y => x)
45+
elseif y isa Symbolics.Arr
46+
# TODO: don't scalarize arrays
47+
merge!(defs, Dict(scalarize(y .=> x)))
48+
elseif y isa Symbolics.BasicSymbolic
49+
# y is a derivative expression expanded; add it to the initialization equations
50+
push!(eqs_ics, y ~ x)
51+
else
52+
error("Initialization expression $y is currently not supported. If its a higher order derivative expression, then only the dummy derivative expressions are supported.")
6353
end
64-
filtered_u0 = todict(filtered_u0)
6554
end
66-
else
67-
dd_guess = Dict()
68-
filtered_u0 = todict(u0map)
6955
end
7056

71-
defs = merge(defaults(sys), filtered_u0)
72-
73-
for st in full_states
74-
if st keys(defs)
75-
def = defs[st]
76-
77-
if def isa Equation
78-
st keys(guesses) && check_defguess &&
79-
error("Invalid setup: unknown $(st) has an initial condition equation with no guess.")
80-
push!(eqs_ics, def)
81-
push!(u0, st => guesses[st])
82-
else
83-
push!(eqs_ics, st ~ def)
84-
push!(u0, st => def)
85-
end
86-
elseif st keys(guesses)
87-
push!(u0, st => guesses[st])
57+
# 2) process other variables
58+
for var in vars
59+
if var keys(defs)
60+
push!(eqs_ics, var ~ defs[var])
61+
elseif var keys(guesses)
62+
push!(defs, var => guesses[var])
8863
elseif check_defguess
89-
error("Invalid setup: unknown $(st) has no default value or initial guess")
64+
error("Invalid setup: variable $(var) has no default value or initial guess")
9065
end
9166
end
9267

68+
# 3) process explicitly provided initialization equations
9369
if !algebraic_only
94-
for eq in [get_initialization_eqs(sys); initialization_eqs]
95-
_eq = ModelingToolkit.fixpoint_sub(eq, full_diffmap)
96-
push!(eqs_ics, _eq)
70+
initialization_eqs = [get_initialization_eqs(sys); initialization_eqs]
71+
for eq in initialization_eqs
72+
eq = fixpoint_sub(eq, diffmap) # expand dummy derivatives
73+
push!(eqs_ics, eq)
9774
end
9875
end
9976

100-
pars = [parameters(sys); get_iv(sys)]
101-
nleqs = [eqs_ics; observed(sys)]
102-
103-
sys_nl = NonlinearSystem(nleqs,
104-
full_states,
105-
pars;
106-
defaults = merge(ModelingToolkit.defaults(sys), todict(u0), dd_guess),
107-
parameter_dependencies = parameter_dependencies(sys),
77+
pars = [parameters(sys); get_iv(sys)] # include independent variable as pseudo-parameter
78+
eqs_ics = [eqs_ics; observed(sys)]
79+
return NonlinearSystem(
80+
eqs_ics, vars, pars;
81+
defaults = defs, parameter_dependencies = parameter_dependencies(sys),
10882
checks = check_units,
109-
name,
110-
kwargs...)
111-
112-
return sys_nl
83+
name, kwargs...
84+
)
11385
end

src/systems/nonlinear/nonlinearsystem.jl

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -126,25 +126,23 @@ function NonlinearSystem(eqs, unknowns, ps;
126126
throw(ArgumentError("NonlinearSystem does not accept `continuous_events`, you provided $continuous_events"))
127127
discrete_events === nothing || isempty(discrete_events) ||
128128
throw(ArgumentError("NonlinearSystem does not accept `discrete_events`, you provided $discrete_events"))
129-
130129
name === nothing &&
131130
throw(ArgumentError("The `name` keyword must be provided. Please consider using the `@named` macro"))
132-
# Move things over, but do not touch array expressions
133-
#
134-
# # we cannot scalarize in the loop because `eqs` itself might require
135-
# scalarization
136-
eqs = [x.lhs isa Union{Symbolic, Number} ? 0 ~ x.rhs - x.lhs : x
137-
for x in scalarize(eqs)]
138-
139-
if !(isempty(default_u0) && isempty(default_p))
131+
length(unique(nameof.(systems))) == length(systems) ||
132+
throw(ArgumentError("System names must be unique."))
133+
(isempty(default_u0) && isempty(default_p)) ||
140134
Base.depwarn(
141135
"`default_u0` and `default_p` are deprecated. Use `defaults` instead.",
142136
:NonlinearSystem, force = true)
137+
138+
# Accept a single (scalar/vector) equation, but make array for consistent internal handling
139+
if !(eqs isa AbstractArray)
140+
eqs = [eqs]
143141
end
144-
sysnames = nameof.(systems)
145-
if length(unique(sysnames)) != length(sysnames)
146-
throw(ArgumentError("System names must be unique."))
147-
end
142+
143+
# Copy equations to canonical form, but do not touch array expressions
144+
eqs = [wrap(eq.lhs) isa Symbolics.Arr ? eq : 0 ~ eq.rhs - eq.lhs for eq in eqs]
145+
148146
jac = RefValue{Any}(EMPTY_JAC)
149147
defaults = todict(defaults)
150148
defaults = Dict{Any, Any}(value(k) => value(v)

test/initializationsystem.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -567,3 +567,12 @@ oprob_2nd_order_2 = ODEProblem(sys_2nd_order, u0_2nd_order_2, tspan, ps)
567567
sol = solve(oprob_2nd_order_2, Rosenbrock23()) # retcode: Success
568568
@test sol[Y][1] == 2.0
569569
@test sol[D(Y)][1] == 0.5
570+
571+
@testset "Vector in initial conditions" begin
572+
@variables x(t)[1:5] y(t)[1:5]
573+
@named sys = ODESystem([D(x) ~ x, D(y) ~ y], t; initialization_eqs = [y ~ -x])
574+
sys = structural_simplify(sys)
575+
prob = ODEProblem(sys, [sys.x => ones(5)], (0.0, 1.0), [])
576+
sol = solve(prob, Tsit5(), reltol = 1e-4)
577+
@test all(sol(1.0, idxs = sys.x) .≈ +exp(1)) && all(sol(1.0, idxs = sys.y) .≈ -exp(1))
578+
end

test/nonlinearsystem.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -327,6 +327,20 @@ end
327327
@test_nowarn solve(prob)
328328
end
329329

330+
@testset "System of linear equations with vector variable" begin
331+
# 1st example in https://en.wikipedia.org/w/index.php?title=System_of_linear_equations&oldid=1247697953
332+
@variables x[1:3]
333+
A = [3 2 -1
334+
2 -2 4
335+
-1 1/2 -1]
336+
b = [1, -2, 0]
337+
@named sys = NonlinearSystem(A * x ~ b, [x], [])
338+
sys = structural_simplify(sys)
339+
prob = NonlinearProblem(sys, unknowns(sys) .=> 0.0)
340+
sol = solve(prob)
341+
@test all(sol[x] .≈ A \ b)
342+
end
343+
330344
@testset "resid_prototype when system has no unknowns and an equation" begin
331345
@variables x
332346
@parameters p

test/reduction.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,7 @@ A = reshape(1:(N^2), N, N)
178178
eqs = xs ~ A * xs
179179
@named sys′ = NonlinearSystem(eqs, [xs], [])
180180
sys = structural_simplify(sys′)
181+
@test length(equations(sys)) == 3 && length(observed(sys)) == 2
181182

182183
# issue 958
183184
@parameters k₁ k₂ k₋₁ E₀

0 commit comments

Comments
 (0)