Skip to content

Commit 8047fbb

Browse files
Setup guesses(sys) and passing override dictionaries
1 parent 250dded commit 8047fbb

File tree

3 files changed

+25
-8
lines changed

3 files changed

+25
-8
lines changed

src/systems/abstractsystem.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -376,6 +376,7 @@ for prop in [:eqs
376376
:var_to_name
377377
:ctrls
378378
:defaults
379+
:guesses
379380
:observed
380381
:tgrad
381382
:jac
@@ -758,6 +759,10 @@ function full_parameters(sys::AbstractSystem)
758759
vcat(parameters(sys), dependent_parameters(sys))
759760
end
760761

762+
function guesses(sys::AbstractSystem)
763+
get_guesses(sys)
764+
end
765+
761766
# required in `src/connectors.jl:437`
762767
parameters(_) = []
763768

src/systems/diffeqs/odesystem.jl

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,11 @@ struct ODESystem <: AbstractODESystem
8888
"""
8989
defaults::Dict
9090
"""
91+
The guesses to use as the initial conditions for the
92+
initialization system.
93+
"""
94+
guesses::Dict
95+
"""
9196
Tearing result specifying how to solve the system.
9297
"""
9398
torn_matching::Union{Matching, Nothing}
@@ -157,7 +162,7 @@ struct ODESystem <: AbstractODESystem
157162
parent::Any
158163

159164
function ODESystem(tag, deqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed, tgrad,
160-
jac, ctrl_jac, Wfact, Wfact_t, name, systems, defaults,
165+
jac, ctrl_jac, Wfact, Wfact_t, name, systems, defaults, guesses,
161166
torn_matching, connector_type, preface, cevents,
162167
devents, parameter_dependencies, metadata = nothing, gui_metadata = nothing,
163168
tearing_state = nothing,
@@ -175,7 +180,7 @@ struct ODESystem <: AbstractODESystem
175180
check_units(u, deqs)
176181
end
177182
new(tag, deqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed, tgrad, jac,
178-
ctrl_jac, Wfact, Wfact_t, name, systems, defaults, torn_matching,
183+
ctrl_jac, Wfact, Wfact_t, name, systems, defaults, guesses, torn_matching,
179184
connector_type, preface, cevents, devents, parameter_dependencies, metadata,
180185
gui_metadata, tearing_state, substitutions, complete, index_cache,
181186
discrete_subsystems, solved_unknowns, split_idxs, parent)
@@ -191,6 +196,7 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
191196
default_u0 = Dict(),
192197
default_p = Dict(),
193198
defaults = _merge(Dict(default_u0), Dict(default_p)),
199+
guesses = Dict(),
194200
connector_type = nothing,
195201
preface = nothing,
196202
continuous_events = nothing,
@@ -217,6 +223,13 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
217223
var_to_name = Dict()
218224
process_variables!(var_to_name, defaults, dvs′)
219225
process_variables!(var_to_name, defaults, ps′)
226+
227+
sysguesses = [ModelingToolkit.getguess(st) for st in dvs′]
228+
hasaguess = findall(!isnothing, sysguesses)
229+
var_guesses = dvs′[hasaguess] .=> sysguesses[hasaguess]
230+
sysguesses = isempty(var_guesses) ? Dict() : todict(var_guesses)
231+
guesses = merge(sysguesses, todict(guesses))
232+
220233
isempty(observed) || collect_var_to_name!(var_to_name, (eq.lhs for eq in observed))
221234

222235
tgrad = RefValue(EMPTY_TGRAD)
@@ -234,11 +247,12 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
234247
parameter_dependencies, ps′)
235248
ODESystem(Threads.atomic_add!(SYSTEM_COUNT, UInt(1)),
236249
deqs, iv′, dvs′, ps′, tspan, var_to_name, ctrl′, observed, tgrad, jac,
237-
ctrl_jac, Wfact, Wfact_t, name, systems, defaults, nothing,
250+
ctrl_jac, Wfact, Wfact_t, name, systems, defaults, guesses, nothing,
238251
connector_type, preface, cont_callbacks, disc_callbacks, parameter_dependencies,
239252
metadata, gui_metadata, checks = checks)
240253
end
241254

255+
242256
function ODESystem(eqs, iv; kwargs...)
243257
eqs = collect(eqs)
244258
# NOTE: this assumes that the order of algebraic equations doesn't matter

src/systems/nonlinear/initializesystem.jl

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,16 +12,14 @@ function initializesystem(sys::ODESystem; name = nameof(sys), guesses = Dict(),
1212
# Start the equations list with algebraic equations
1313
eqs_ics = eqs[idxs_alge]
1414
u0 = Vector{Pair}(undef, 0)
15-
defs = ModelingToolkit.defaults(sys)
15+
defs = defaults(sys)
1616

1717
full_states = [sts; getfield.((observed(sys)), :lhs)]
1818

1919
# Refactor to ODESystem construction
2020
# should be ModelingToolkit.guesses(sys)
21-
sysguesses = [ModelingToolkit.getguess(st) for st in full_states]
22-
hasaguess = findall(!isnothing, sysguesses)
23-
sysguesses = todict(full_states[hasaguess] .=> sysguesses[hasaguess])
24-
guesses = merge(sysguesses, todict(guesses))
21+
22+
guesses = merge(get_guesses(sys), todict(guesses))
2523

2624
for st in full_states
2725
if st keys(defs)

0 commit comments

Comments
 (0)