Skip to content

Commit a3f26ec

Browse files
fixup! feat: allow parameters to be unknowns in the initialization system
1 parent a427126 commit a3f26ec

File tree

1 file changed

+54
-19
lines changed

1 file changed

+54
-19
lines changed

src/systems/nonlinear/initializesystem.jl

Lines changed: 54 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -96,36 +96,59 @@ function generate_initializesystem(sys::ODESystem;
9696
if pmap isa SciMLBase.NullParameters
9797
pmap = Dict()
9898
end
99+
pmap = todict(pmap)
99100
for p in parameters(sys)
100-
# If either of them are `missing` the parameter is an unknown
101-
# But if the parameter is passed a value, use that as an additional
102-
# equation in the system
103-
if (_val1 = get(pmap, p, nothing)) === missing || get(defs, p, nothing) === missing
101+
if is_parameter_solvable(p, pmap, defs, guesses)
102+
# If either of them are `missing` the parameter is an unknown
103+
# But if the parameter is passed a value, use that as an additional
104+
# equation in the system
105+
_val1 = get(pmap, p, nothing)
106+
_val2 = get(defs, p, nothing)
107+
_val3 = get(guesses, p, nothing)
104108
varp = tovar(p)
105109
paramsubs[p] = varp
106-
if _val1 !== nothing && _val1 !== missing
107-
push!(eqs_ics, varp ~ _val1)
108-
end
109-
if !haskey(guesses, p)
110-
error("Invalid setup: parameter $(p) has no default value or initial guess")
110+
# Has a default of `missing`, and (either an equation using the value passed to `ODEProblem` or a guess)
111+
if _val2 === missing
112+
if _val1 !== nothing && _val1 !== missing
113+
push!(eqs_ics, varp ~ _val1)
114+
push!(u0, varp => _val1)
115+
elseif _val3 !== nothing
116+
# assuming an equation exists (either via algebraic equations or initialization_eqs)
117+
push!(u0, varp => _val1)
118+
elseif check_defguess
119+
error("Invalid setup: parameter $(p) has no default value, initial value, or guess")
120+
end
121+
# `missing` passed to `ODEProblem`, and (either an equation using default or a guess)
122+
elseif _val1 === missing
123+
if _val2 !== nothing && _val2 !== missing
124+
push!(eqs_ics, varp ~ _val2)
125+
push!(u0, varp => _val2)
126+
elseif _val3 !== nothing
127+
push!(u0, varp => _val1)
128+
elseif check_defguess
129+
error("Invalid setup: parameter $(p) has no default value, initial value, or guess")
130+
end
131+
# No value passed to `ODEProblem`, but a default and a guess are present
132+
# _val2 !== missing is implied by it falling this far in the elseif chain
133+
elseif _val1 === nothing && _val2 !== nothing && _val3 !== nothing
134+
push!(eqs_ics, varp ~ _val2)
135+
push!(u0, varp => _val3)
136+
else
137+
# _val1 !== missing and _val1 !== nothing, so a value was provided to ODEProblem
138+
# This would mean `is_parameter_solvable` returned `false`, so we never end up
139+
# here
140+
error("This should never be reached")
111141
end
112-
push!(u0, varp => guesses[p])
113142
end
114143
end
115144
pars = vcat(
116145
[get_iv(sys)],
117146
[p for p in parameters(sys) if !haskey(paramsubs, p)]
118147
)
119-
pdeps = parameter_dependencies(sys)
120-
if !isempty(pdeps)
121-
pdep_eqs = [k ~ v for (k, v) in pdeps]
122-
else
123-
pdep_eqs = Equation[]
124-
end
125148
nleqs = if algebraic_only
126-
[eqs_ics; observed(sys); pdep_eqs]
149+
[eqs_ics; observed(sys)]
127150
else
128-
[eqs_ics; get_initialization_eqs(sys); initialization_eqs; observed(sys); pdep_eqs]
151+
[eqs_ics; get_initialization_eqs(sys); initialization_eqs; observed(sys)]
129152
end
130153
nleqs = Symbolics.substitute.(nleqs, (paramsubs,))
131154
unks = [full_states; collect(values(paramsubs))]
@@ -142,6 +165,15 @@ function generate_initializesystem(sys::ODESystem;
142165
return sys_nl
143166
end
144167

168+
function is_parameter_solvable(p, pmap, defs, guesses)
169+
_val1 = pmap isa AbstractDict ? get(pmap, p, nothing) : nothing
170+
_val2 = get(defs, p, nothing)
171+
_val3 = get(guesses, p, nothing)
172+
# either (missing is a default or was passed to the ODEProblem) or (nothing was passed to
173+
# the ODEProblem and it has a default and a guess)
174+
return (_val1 === missing || _val2 === missing) || (_val1 === nothing && _val2 !== nothing && _val3 !== nothing)
175+
end
176+
145177
function SciMLBase.remake_initializeprob(sys::ODESystem, odefn, u0, t0, p)
146178
if (u0 === missing || !(eltype(u0) <: Pair) || isempty(u0)) &&
147179
(p === missing || !(eltype(p) <: Pair) || isempty(p))
@@ -153,13 +185,16 @@ function SciMLBase.remake_initializeprob(sys::ODESystem, odefn, u0, t0, p)
153185
if p === missing
154186
p = Dict()
155187
end
188+
if t0 === nothing
189+
t0 = 0.0
190+
end
156191
u0 = todict(u0)
157192
p = todict(p)
158193
initprob = InitializationProblem(sys, t0, u0, p)
159194
initprobmap = getu(initprob, unknowns(sys))
160195
punknowns = [p for p in all_variable_symbols(initprob) if is_parameter(sys, p)]
161196
getpunknowns = getu(initprob, punknowns)
162-
setpunknowns = setp_oop(sys, punknowns)
197+
setpunknowns = setp(sys, punknowns)
163198
initprobpmap = GetUpdatedMTKParameters(getpunknowns, setpunknowns)
164199
return initprob, initprobmap, initprobpmap
165200
end

0 commit comments

Comments
 (0)