Skip to content

Commit b55646e

Browse files
Merge pull request #3714 from AayushSabharwal/as/fix-initsys
fix: ensure `=> nothing` overrides defaults
2 parents 29f3c6e + 3224c33 commit b55646e

File tree

3 files changed

+38
-3
lines changed

3 files changed

+38
-3
lines changed

src/systems/nonlinear/initializesystem.jl

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,12 @@ function generate_initializesystem_timevarying(sys::AbstractSystem;
6161
isempty(trueobs) || filter_delay_equations_variables!(sys, trueobs)
6262
vars = unique([unknowns(sys); getfield.(trueobs, :lhs)])
6363
vars_set = Set(vars) # for efficient in-lookup
64+
arrvars = Set()
65+
for var in vars
66+
if iscall(var) && operation(var) === getindex
67+
push!(arrvars, first(arguments(var)))
68+
end
69+
end
6470

6571
eqs_ics = Equation[]
6672
defs = copy(defaults(sys)) # copy so we don't modify sys.defaults
@@ -71,9 +77,13 @@ function generate_initializesystem_timevarying(sys::AbstractSystem;
7177

7278
# PREPROCESSING
7379
op = anydict(op)
80+
if isempty(op)
81+
op = copy(defs)
82+
end
83+
scalarize_vars_in_varmap!(op, arrvars)
7484
u0map = anydict()
7585
pmap = anydict()
76-
build_operating_point!(sys, op, u0map, pmap, defs, unknowns(sys),
86+
build_operating_point!(sys, op, u0map, pmap, Dict(), unknowns(sys),
7787
parameters(sys; initial_parameters = true))
7888
for (k, v) in op
7989
if has_parameter_dependency_with_lhs(sys, k) && is_variable_floatingpoint(k)
@@ -144,7 +154,7 @@ function generate_initializesystem_timevarying(sys::AbstractSystem;
144154

145155
# 3) process other variables
146156
for var in vars
147-
if var keys(defs)
157+
if var keys(op)
148158
push!(eqs_ics, var ~ defs[var])
149159
elseif var keys(guesses)
150160
push!(defs, var => guesses[var])
@@ -238,7 +248,7 @@ function generate_initializesystem_timeindependent(sys::AbstractSystem;
238248
op = anydict(op)
239249
u0map = anydict()
240250
pmap = anydict()
241-
build_operating_point!(sys, op, u0map, pmap, defs, unknowns(sys),
251+
build_operating_point!(sys, op, u0map, pmap, Dict(), unknowns(sys),
242252
parameters(sys; initial_parameters = true))
243253
for (k, v) in op
244254
if has_parameter_dependency_with_lhs(sys, k) && is_variable_floatingpoint(k)

src/systems/problem_utils.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -523,6 +523,24 @@ function scalarize_varmap!(varmap::AbstractDict)
523523
return varmap
524524
end
525525

526+
"""
527+
$(TYPEDSIGNATURES)
528+
529+
For each array variable in `vars`, scalarize the corresponding entry in `varmap`.
530+
If a scalarized entry already exists, it is not overridden.
531+
"""
532+
function scalarize_vars_in_varmap!(varmap::AbstractDict, vars)
533+
for var in vars
534+
symbolic_type(var) == ArraySymbolic() || continue
535+
is_sized_array_symbolic(var) || continue
536+
haskey(varmap, var) || continue
537+
for i in eachindex(var)
538+
haskey(varmap, var[i]) && continue
539+
varmap[var[i]] = varmap[var][i]
540+
end
541+
end
542+
end
543+
526544
function get_temporary_value(p, floatT = Float64)
527545
stype = symtype(unwrap(p))
528546
return if stype == Real

test/initializationsystem.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1664,3 +1664,10 @@ end
16641664
sol = solve(prob, Tsit5())
16651665
@test SciMLBase.successful_retcode(sol)
16661666
end
1667+
1668+
@testset "Defaults removed with ` => nothing` aren't retained" begin
1669+
@variables x(t)[1:2]
1670+
@mtkbuild sys = System([D(x[1]) ~ -x[1], x[1] + x[2] ~ 3], t; defaults = [x[1] => 1])
1671+
prob = ODEProblem(sys, [x[1] => nothing, x[2] => 1], (0.0, 1.0))
1672+
@test SciMLBase.initialization_status(prob) == SciMLBase.FULLY_DETERMINED
1673+
end

0 commit comments

Comments
 (0)