Skip to content

Commit f9837e4

Browse files
fixup! feat: allow parameters to be unknowns in the initialization system
1 parent 9a3cd2e commit f9837e4

File tree

1 file changed

+54
-19
lines changed

1 file changed

+54
-19
lines changed

src/systems/nonlinear/initializesystem.jl

Lines changed: 54 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -102,36 +102,59 @@ function generate_initializesystem(sys::ODESystem;
102102
if pmap isa SciMLBase.NullParameters
103103
pmap = Dict()
104104
end
105+
pmap = todict(pmap)
105106
for p in parameters(sys)
106-
# If either of them are `missing` the parameter is an unknown
107-
# But if the parameter is passed a value, use that as an additional
108-
# equation in the system
109-
if (_val1 = get(pmap, p, nothing)) === missing || get(defs, p, nothing) === missing
107+
if is_parameter_solvable(p, pmap, defs, guesses)
108+
# If either of them are `missing` the parameter is an unknown
109+
# But if the parameter is passed a value, use that as an additional
110+
# equation in the system
111+
_val1 = get(pmap, p, nothing)
112+
_val2 = get(defs, p, nothing)
113+
_val3 = get(guesses, p, nothing)
110114
varp = tovar(p)
111115
paramsubs[p] = varp
112-
if _val1 !== nothing && _val1 !== missing
113-
push!(eqs_ics, varp ~ _val1)
114-
end
115-
if !haskey(guesses, p)
116-
error("Invalid setup: parameter $(p) has no default value or initial guess")
116+
# Has a default of `missing`, and (either an equation using the value passed to `ODEProblem` or a guess)
117+
if _val2 === missing
118+
if _val1 !== nothing && _val1 !== missing
119+
push!(eqs_ics, varp ~ _val1)
120+
push!(u0, varp => _val1)
121+
elseif _val3 !== nothing
122+
# assuming an equation exists (either via algebraic equations or initialization_eqs)
123+
push!(u0, varp => _val1)
124+
elseif check_defguess
125+
error("Invalid setup: parameter $(p) has no default value, initial value, or guess")
126+
end
127+
# `missing` passed to `ODEProblem`, and (either an equation using default or a guess)
128+
elseif _val1 === missing
129+
if _val2 !== nothing && _val2 !== missing
130+
push!(eqs_ics, varp ~ _val2)
131+
push!(u0, varp => _val2)
132+
elseif _val3 !== nothing
133+
push!(u0, varp => _val1)
134+
elseif check_defguess
135+
error("Invalid setup: parameter $(p) has no default value, initial value, or guess")
136+
end
137+
# No value passed to `ODEProblem`, but a default and a guess are present
138+
# _val2 !== missing is implied by it falling this far in the elseif chain
139+
elseif _val1 === nothing && _val2 !== nothing && _val3 !== nothing
140+
push!(eqs_ics, varp ~ _val2)
141+
push!(u0, varp => _val3)
142+
else
143+
# _val1 !== missing and _val1 !== nothing, so a value was provided to ODEProblem
144+
# This would mean `is_parameter_solvable` returned `false`, so we never end up
145+
# here
146+
error("This should never be reached")
117147
end
118-
push!(u0, varp => guesses[p])
119148
end
120149
end
121150
pars = vcat(
122151
[get_iv(sys)],
123152
[p for p in parameters(sys) if !haskey(paramsubs, p)]
124153
)
125-
pdeps = parameter_dependencies(sys)
126-
if !isempty(pdeps)
127-
pdep_eqs = [k ~ v for (k, v) in pdeps]
128-
else
129-
pdep_eqs = Equation[]
130-
end
131154
nleqs = if algebraic_only
132-
[eqs_ics; observed(sys); pdep_eqs]
155+
[eqs_ics; observed(sys)]
133156
else
134-
[eqs_ics; get_initialization_eqs(sys); initialization_eqs; observed(sys); pdep_eqs]
157+
[eqs_ics; get_initialization_eqs(sys); initialization_eqs; observed(sys)]
135158
end
136159
nleqs = Symbolics.substitute.(nleqs, (paramsubs,))
137160
unks = [full_states; collect(values(paramsubs))]
@@ -151,6 +174,15 @@ function generate_initializesystem(sys::ODESystem;
151174
return sys_nl
152175
end
153176

177+
function is_parameter_solvable(p, pmap, defs, guesses)
178+
_val1 = pmap isa AbstractDict ? get(pmap, p, nothing) : nothing
179+
_val2 = get(defs, p, nothing)
180+
_val3 = get(guesses, p, nothing)
181+
# either (missing is a default or was passed to the ODEProblem) or (nothing was passed to
182+
# the ODEProblem and it has a default and a guess)
183+
return (_val1 === missing || _val2 === missing) || (_val1 === nothing && _val2 !== nothing && _val3 !== nothing)
184+
end
185+
154186
function SciMLBase.remake_initializeprob(sys::ODESystem, odefn, u0, t0, p)
155187
if (u0 === missing || !(eltype(u0) <: Pair) || isempty(u0)) &&
156188
(p === missing || !(eltype(p) <: Pair) || isempty(p))
@@ -162,13 +194,16 @@ function SciMLBase.remake_initializeprob(sys::ODESystem, odefn, u0, t0, p)
162194
if p === missing
163195
p = Dict()
164196
end
197+
if t0 === nothing
198+
t0 = 0.0
199+
end
165200
u0 = todict(u0)
166201
p = todict(p)
167202
initprob = InitializationProblem(sys, t0, u0, p)
168203
initprobmap = getu(initprob, unknowns(sys))
169204
punknowns = [p for p in all_variable_symbols(initprob) if is_parameter(sys, p)]
170205
getpunknowns = getu(initprob, punknowns)
171-
setpunknowns = setp_oop(sys, punknowns)
206+
setpunknowns = setp(sys, punknowns)
172207
initprobpmap = GetUpdatedMTKParameters(getpunknowns, setpunknowns)
173208
return initprob, initprobmap, initprobpmap
174209
end

0 commit comments

Comments
 (0)