Skip to content

fix: fix values being promoted to Float64 in problem construction #3561

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 24 commits into from
Apr 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
50f932f
fix: use `Float16` zero default for `Initial` parameters
AayushSabharwal Apr 14, 2025
0305c9d
test: test `Float32` values retained in problem construction
AayushSabharwal Apr 14, 2025
bcaa869
Update src/systems/abstractsystem.jl
ChrisRackauckas Apr 14, 2025
65a326d
Update src/systems/abstractsystem.jl
ChrisRackauckas Apr 14, 2025
aa7f38b
fix: fix type promotion in `MTKParameters` constructor
AayushSabharwal Apr 15, 2025
227a0d8
fix: use `Bool(0)` for `default_dd_guess`
AayushSabharwal Apr 15, 2025
b228fd9
fix: promote integer valued reals to `Float16`
AayushSabharwal Apr 15, 2025
c814779
feat: allow specifying `eltype` of numeric buffer in `better_varmap_t…
AayushSabharwal Apr 15, 2025
a8d59d6
feat: allow specifying floating point type in `get_temporary_value`
AayushSabharwal Apr 15, 2025
742146e
refactor: remove `zero_var`
AayushSabharwal Apr 15, 2025
fb1fb93
refactor: determine floating point type earlier and propagate
AayushSabharwal Apr 15, 2025
59c6b31
refactor: remove type-promotion hack in `InitializationProblem`
AayushSabharwal Apr 15, 2025
27e7efb
fix: use `remake` to promote intialization problem
AayushSabharwal Apr 15, 2025
8d2482c
fix: fix `symconvert`
AayushSabharwal Apr 15, 2025
38724b2
fix: allow specifying floating point type in `MTKParameters`
AayushSabharwal Apr 15, 2025
57d7da1
Revert "fix: use `remake` to promote intialization problem"
AayushSabharwal Apr 15, 2025
9e6e894
fix: use `remake_initialization_data` to promote the initialization p…
AayushSabharwal Apr 15, 2025
662217b
fix: promote initializeprob in `remake_initialization_data`
AayushSabharwal Apr 15, 2025
f9cdd16
fix: respect `floatT` in `maybe_build_initialization_problem`
AayushSabharwal Apr 15, 2025
9fe7586
fix: respect `tofloat` when using `floatT`, default to not promoting
AayushSabharwal Apr 15, 2025
a8e2495
fix: do not require `isconcretetype` when discovering `floatT`
AayushSabharwal Apr 15, 2025
6023efb
refactor: add `float_type_from_varmap`
AayushSabharwal Apr 16, 2025
ead8682
fix: infer `floatT` in `MTKParameters` constructor if not provided
AayushSabharwal Apr 16, 2025
cd19c7e
fix: infer `floatT` in `remake_initialization_data`
AayushSabharwal Apr 16, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/systems/abstractsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -733,11 +733,11 @@ function add_initialization_parameters(sys::AbstractSystem)
defs = copy(get_defaults(sys))
for ivar in initials
if symbolic_type(ivar) == ScalarSymbolic()
defs[ivar] = zero_var(ivar)
defs[ivar] = false
else
defs[ivar] = collect(ivar)
for scal_ivar in defs[ivar]
defs[scal_ivar] = zero_var(scal_ivar)
defs[scal_ivar] = false
end
end
end
Expand Down
26 changes: 0 additions & 26 deletions src/systems/diffeqs/abstractodesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1541,32 +1541,6 @@ function InitializationProblem{iip, specialize}(sys::AbstractSystem,
filter_missing_values!(parammap)
u0map = merge(ModelingToolkit.guesses(sys), todict(guesses), u0map)

fullmap = merge(u0map, parammap)
u0T = Union{}
for sym in unknowns(isys)
val = fixpoint_sub(sym, fullmap)
symbolic_type(val) == NotSymbolic() || continue
u0T = promote_type(u0T, typeof(val))
end
for eq in observed(isys)
# ignore HACK-ed observed equations
symbolic_type(eq.lhs) == ArraySymbolic() && continue
val = fixpoint_sub(eq.lhs, fullmap)
symbolic_type(val) == NotSymbolic() || continue
u0T = promote_type(u0T, typeof(val))
end
if u0T != Union{}
u0T = eltype(u0T)
u0map = Dict(k => if v === nothing
nothing
elseif symbolic_type(v) == NotSymbolic() && !is_array_of_symbolics(v)
v isa AbstractArray ? u0T.(v) : u0T(v)
else
v
end
for (k, v) in u0map)
end

TProb = if neqs == nunknown && isempty(unassigned_vars)
if use_scc && neqs > 0
if is_split(isys)
Expand Down
8 changes: 5 additions & 3 deletions src/systems/nonlinear/initializesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ function generate_initializesystem(sys::AbstractTimeDependentSystem;
pmap = Dict(),
initialization_eqs = [],
guesses = Dict(),
default_dd_guess = 0.0,
default_dd_guess = Bool(0),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just false?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

false makes the kwarg seem like a flag, when it is actually a value. So I opted for Bool(0) just to make it clearer.

algebraic_only = false,
check_units = true, check_defguess = false,
name = nameof(sys), extra_metadata = (;), kwargs...)
Expand Down Expand Up @@ -646,10 +646,12 @@ function SciMLBase.remake_initialization_data(

op, missing_unknowns, missing_pars = build_operating_point!(sys,
u0map, pmap, defs, cmap, dvs, ps)
floatT = float_type_from_varmap(op)
kws = maybe_build_initialization_problem(
sys, op, u0map, pmap, t0, defs, guesses, missing_unknowns;
use_scc, initialization_eqs, allow_incomplete = true)
return get(kws, :initialization_data, nothing)
use_scc, initialization_eqs, floatT, allow_incomplete = true)

return SciMLBase.remake_initialization_data(sys, kws, newu0, t0, newp, newu0, newp)
end

function SciMLBase.late_binding_update_u0_p(
Expand Down
13 changes: 10 additions & 3 deletions src/systems/parameter_buffer.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
symconvert(::Type{Symbolics.Struct{T}}, x) where {T} = convert(T, x)
symconvert(::Type{T}, x) where {T} = convert(T, x)
symconvert(::Type{Real}, x::Integer) = convert(Float64, x)
symconvert(::Type{T}, x::V) where {T, V} = convert(promote_type(T, V), x)
symconvert(::Type{Real}, x::Integer) = convert(Float16, x)
symconvert(::Type{V}, x) where {V <: AbstractArray} = convert(V, symconvert.(eltype(V), x))

struct MTKParameters{T, I, D, C, N, H}
Expand Down Expand Up @@ -28,7 +28,7 @@ the default behavior).
"""
function MTKParameters(
sys::AbstractSystem, p, u0 = Dict(); tofloat = false,
t0 = nothing, substitution_limit = 1000)
t0 = nothing, substitution_limit = 1000, floatT = nothing)
ic = if has_index_cache(sys) && get_index_cache(sys) !== nothing
get_index_cache(sys)
else
Expand Down Expand Up @@ -56,6 +56,10 @@ function MTKParameters(
op[get_iv(sys)] = t0
end

if floatT === nothing
floatT = float(float_type_from_varmap(op))
end

isempty(missing_pars) || throw(MissingParametersError(collect(missing_pars)))
evaluate_varmap!(op, ps; limit = substitution_limit)

Expand Down Expand Up @@ -111,6 +115,9 @@ function MTKParameters(
if ctype <: FnType
ctype = fntype_to_function_type(ctype)
end
if ctype == Real && floatT !== nothing
ctype = floatT
end
val = symconvert(ctype, val)
done = set_value(sym, val)
if !done && Symbolics.isarraysymbolic(sym)
Expand Down
100 changes: 69 additions & 31 deletions src/systems/problem_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -247,25 +247,6 @@ function recursive_unwrap(x::AbstractDict)
return anydict(unwrap(k) => recursive_unwrap(v) for (k, v) in x)
end

"""
$(TYPEDSIGNATURES)

Return the appropriate zero value for a symbolic variable representing a number or array of
numbers. Sized array symbolics return a zero-filled array of matching size. Unsized array
symbolics return an empty array of the appropriate `eltype`.
"""
function zero_var(x::Symbolic{T}) where {V <: Number, T <: Union{V, AbstractArray{V}}}
if Symbolics.isarraysymbolic(x)
if is_sized_array_symbolic(x)
return zeros(eltype(T), size(x))
else
return T[]
end
else
return zero(T)
end
end

"""
$(TYPEDSIGNATURES)

Expand Down Expand Up @@ -362,7 +343,7 @@ Keyword arguments:
- `is_initializeprob, guesses`: Used to determine whether the system is missing guesses.
"""
function better_varmap_to_vars(varmap::AbstractDict, vars::Vector;
tofloat = true, container_type = Array,
tofloat = true, container_type = Array, floatT = Nothing,
toterm = default_toterm, promotetoconcrete = nothing, check = true,
allow_symbolic = false, is_initializeprob = false)
isempty(vars) && return nothing
Expand All @@ -385,6 +366,9 @@ function better_varmap_to_vars(varmap::AbstractDict, vars::Vector;
is_initializeprob ? throw(MissingGuessError(missingsyms, missingvals)) :
throw(UnexpectedSymbolicValueInVarmap(missingsyms[1], missingvals[1]))
end
if tofloat && !(floatT == Nothing)
vals = floatT.(vals)
end
end

if container_type <: Union{AbstractDict, Tuple, Nothing, SciMLBase.NullParameters}
Expand Down Expand Up @@ -533,12 +517,12 @@ function (f::UpdateInitializeprob)(initializeprob, prob)
f.setvals(initializeprob, f.getvals(prob))
end

function get_temporary_value(p)
function get_temporary_value(p, floatT = Float64)
stype = symtype(unwrap(p))
return if stype == Real
zero(Float64)
zero(floatT)
elseif stype <: AbstractArray{Real}
zeros(Float64, size(p))
zeros(floatT, size(p))
elseif stype <: Real
zero(stype)
elseif stype <: AbstractArray
Expand Down Expand Up @@ -648,15 +632,32 @@ All other keyword arguments are forwarded to `InitializationProblem`.
"""
function maybe_build_initialization_problem(
sys::AbstractSystem, op::AbstractDict, u0map, pmap, t, defs,
guesses, missing_unknowns; implicit_dae = false, u0_constructor = identity, kwargs...)
guesses, missing_unknowns; implicit_dae = false,
u0_constructor = identity, floatT = Float64, kwargs...)
guesses = merge(ModelingToolkit.guesses(sys), todict(guesses))

if t === nothing && is_time_dependent(sys)
t = 0.0
t = zero(floatT)
end

initializeprob = ModelingToolkit.InitializationProblem{true, SciMLBase.FullSpecialize}(
sys, t, u0map, pmap; guesses, kwargs...)
if state_values(initializeprob) !== nothing
initializeprob = remake(initializeprob; u0 = floatT.(state_values(initializeprob)))
end
initp = parameter_values(initializeprob)
if is_split(sys)
buffer, repack, _ = SciMLStructures.canonicalize(SciMLStructures.Tunable(), initp)
initp = repack(floatT.(buffer))
buffer, repack, _ = SciMLStructures.canonicalize(SciMLStructures.Initials(), initp)
initp = repack(floatT.(buffer))
elseif initp isa AbstractArray
initp′ = similar(initp, floatT)
copyto!(initp′, initp)
initp = initp′
end
initializeprob = remake(initializeprob; p = initp)

meta = get_metadata(initializeprob.f.sys)

if is_time_dependent(sys)
Expand Down Expand Up @@ -692,7 +693,7 @@ function maybe_build_initialization_problem(
get(op, p, missing) === missing || continue
p = unwrap(p)
stype = symtype(p)
op[p] = get_temporary_value(p)
op[p] = get_temporary_value(p, floatT)
if iscall(p) && operation(p) === getindex
arrp = arguments(p)[1]
op[arrp] = collect(arrp)
Expand All @@ -701,7 +702,7 @@ function maybe_build_initialization_problem(

if is_time_dependent(sys)
for v in missing_unknowns
op[v] = zero_var(v)
op[v] = get_temporary_value(v, floatT)
end
empty!(missing_unknowns)
end
Expand All @@ -712,6 +713,26 @@ function maybe_build_initialization_problem(
initializeprobpmap))
end

"""
$(TYPEDSIGNATURES)

Calculate the floating point type to use from the given `varmap` by looking at variables
with a constant value.
"""
function float_type_from_varmap(varmap, floatT = Bool)
for (k, v) in varmap
symbolic_type(v) == NotSymbolic() || continue
is_array_of_symbolics(v) && continue

if v isa AbstractArray
floatT = promote_type(floatT, eltype(v))
elseif v isa Real
floatT = promote_type(floatT, typeof(v))
end
end
return float(floatT)
end

"""
$(TYPEDSIGNATURES)

Expand Down Expand Up @@ -815,12 +836,19 @@ function process_SciMLProblem(
op, missing_unknowns, missing_pars = build_operating_point!(sys,
u0map, pmap, defs, cmap, dvs, ps)

floatT = Bool
if u0Type <: AbstractArray && eltype(u0Type) <: Real
floatT = float(eltype(u0Type))
else
floatT = float_type_from_varmap(op, floatT)
end

if !is_time_dependent(sys) || is_initializesystem(sys)
add_observed_equations!(u0map, obs)
end
if u0_constructor === identity && u0Type <: StaticArray
u0_constructor = vals -> SymbolicUtils.Code.create_array(
u0Type, eltype(vals), Val(1), Val(length(vals)), vals...)
u0Type, floatT, Val(1), Val(length(vals)), vals...)
end
if build_initializeprob
kws = maybe_build_initialization_problem(
Expand All @@ -830,7 +858,7 @@ function process_SciMLProblem(
warn_cyclic_dependency, check_units = check_initialization_units,
circular_dependency_max_cycle_length, circular_dependency_max_cycles, use_scc,
force_time_independent = force_initialization_time_independent, algebraic_only, allow_incomplete,
u0_constructor)
u0_constructor, floatT)

kwargs = merge(kwargs, kws)
end
Expand Down Expand Up @@ -858,7 +886,7 @@ function process_SciMLProblem(
evaluate_varmap!(op, dvs; limit = substitution_limit)

u0 = better_varmap_to_vars(
op, dvs; tofloat,
op, dvs; tofloat, floatT,
container_type = u0Type, allow_symbolic = symbolic_u0, is_initializeprob)

if u0 !== nothing
Expand All @@ -882,7 +910,7 @@ function process_SciMLProblem(
end
evaluate_varmap!(op, ps; limit = substitution_limit)
if is_split(sys)
p = MTKParameters(sys, op)
p = MTKParameters(sys, op; floatT = floatT)
else
p = better_varmap_to_vars(op, ps; tofloat, container_type = pType)
end
Expand All @@ -898,6 +926,16 @@ function process_SciMLProblem(
du0 = nothing
end

if build_initializeprob
t0 = t
if is_time_dependent(sys) && t0 === nothing
t0 = zero(floatT)
end
initialization_data = SciMLBase.remake_initialization_data(
kwargs.initialization_data, kwargs, u0, t0, p, u0, p)
kwargs = merge(kwargs,)
end

f = constructor(sys, dvs, ps, u0; p = p,
eval_expression = eval_expression,
eval_module = eval_module,
Expand Down
13 changes: 13 additions & 0 deletions test/initial_values.jl
Original file line number Diff line number Diff line change
Expand Up @@ -252,3 +252,16 @@ end
ps = [p => [4.0, 5.0]]
@test_nowarn NonlinearProblem(nlsys, u0, ps)
end

@testset "Issue#3553: Retain `Float32` initial values" begin
@parameters p d
@variables X(t)
eqs = [D(X) ~ p - d * X]
@mtkbuild osys = ODESystem(eqs, t)
u0 = [X => 1.0f0]
ps = [p => 1.0f0, d => 2.0f0]
oprob = ODEProblem(osys, u0, (0.0f0, 1.0f0), ps)
sol = solve(oprob)
@test eltype(oprob.u0) == Float32
@test eltype(eltype(sol.u)) == Float32
end
Loading