Skip to content

Commit 65853ff

Browse files
feat: validate parameter type and allow dependent initial values in param init
1 parent 31f260d commit 65853ff

File tree

3 files changed

+34
-7
lines changed

3 files changed

+34
-7
lines changed

src/systems/nonlinear/initializesystem.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -170,13 +170,16 @@ struct InitializationSystemMetadata
170170
end
171171

172172
function is_parameter_solvable(p, pmap, defs, guesses)
173+
p = unwrap(p)
174+
is_variable_floatingpoint(p) || return false
173175
_val1 = pmap isa AbstractDict ? get(pmap, p, nothing) : nothing
174176
_val2 = get(defs, p, nothing)
175177
_val3 = get(guesses, p, nothing)
176178
# either (missing is a default or was passed to the ODEProblem) or (nothing was passed to
177179
# the ODEProblem and it has a default and a guess)
178180
return ((_val1 === missing || _val2 === missing) ||
179-
(_val1 === nothing && _val2 !== nothing)) && _val3 !== nothing
181+
(symbolic_type(_val1) != NotSymbolic() ||
182+
_val1 === nothing && _val2 !== nothing)) && _val3 !== nothing
180183
end
181184

182185
function SciMLBase.remake_initializeprob(sys::ODESystem, odefn, u0, t0, p)

src/utils.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -885,3 +885,9 @@ end
885885

886886
diff2term_with_unit(x, t) = _with_unit(diff2term, x, t)
887887
lower_varname_with_unit(var, iv, order) = _with_unit(lower_varname, var, iv, iv, order)
888+
889+
function is_variable_floatingpoint(sym)
890+
sym = unwrap(sym)
891+
T = symtype(sym)
892+
return T == Real || T <: AbstractFloat || T <: AbstractArray{Real} || T <: AbstractArray{<:AbstractFloat}
893+
end

test/initializationsystem.jl

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -646,21 +646,39 @@ end
646646
prob2.ps[p] = 0.0
647647
test_parameter(prob2, p, 2.0)
648648

649-
# Should not be solved for:
649+
# Default overridden by ODEProblem, guess provided
650+
@mtkbuild sys = ODESystem(
651+
[D(x) ~ q * x, D(y) ~ y * p], t; defaults = [p => 2q], guesses = [p => 1.0])
652+
_pmap = merge(pmap, Dict(p => q))
653+
prob = ODEProblem(sys, u0map, (0.0, 1.0), _pmap)
654+
test_parameter(prob, p, pmap[q])
655+
test_initializesystem(sys, u0map, pmap, p, 0 ~ q - p)
650656

651-
# ODEProblem value with guess, no `missing`
657+
# ODEProblem dependent value with guess, no `missing`
652658
@mtkbuild sys = ODESystem([D(x) ~ x * q, D(y) ~ y * p], t; guesses = [p => 0.0])
653659
_pmap = merge(pmap, Dict(p => 3q))
654660
prob = ODEProblem(sys, u0map, (0.0, 1.0), _pmap)
655-
@test prob.ps[p] 3.0
656-
@test prob.f.initializeprob === nothing
657-
# Default overridden by ODEProblem, guess provided
661+
test_parameter(prob, p, 3pmap[q])
662+
663+
# Should not be solved for:
664+
665+
# Override dependent default with direct value
658666
@mtkbuild sys = ODESystem(
659667
[D(x) ~ q * x, D(y) ~ y * p], t; defaults = [p => 2q], guesses = [p => 1.0])
668+
_pmap = merge(pmap, Dict(p => 1.0))
660669
prob = ODEProblem(sys, u0map, (0.0, 1.0), _pmap)
661-
@test prob.ps[p] 3.0
670+
@test prob.ps[p] 1.0
662671
@test prob.f.initializeprob === nothing
663672

673+
# Non-floating point
674+
@parameters r::Int s::Int
675+
@mtkbuild sys = ODESystem(
676+
[D(x) ~ s * x, D(y) ~ y * r], t; defaults = [s => 2r], guesses = [s => 1.0])
677+
prob = ODEProblem(sys, u0map, (0.0, 1.0), [r => 1])
678+
@test prob.ps[r] == 1
679+
@test prob.ps[s] == 2
680+
@test prob.f.initializeprob === nothing
681+
664682
@mtkbuild sys = ODESystem([D(x) ~ x, p ~ x + y], t; guesses = [p => 0.0])
665683
@test_throws ModelingToolkit.MissingParametersError ODEProblem(
666684
sys, [x => 1.0, y => 1.0], (0.0, 1.0))

0 commit comments

Comments
 (0)