Skip to content

Commit fb1fb93

Browse files
refactor: determine floating point type earlier and propagate
1 parent 742146e commit fb1fb93

File tree

1 file changed

+23
-8
lines changed

1 file changed

+23
-8
lines changed

src/systems/problem_utils.jl

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -631,11 +631,12 @@ All other keyword arguments are forwarded to `InitializationProblem`.
631631
"""
632632
function maybe_build_initialization_problem(
633633
sys::AbstractSystem, op::AbstractDict, u0map, pmap, t, defs,
634-
guesses, missing_unknowns; implicit_dae = false, u0_constructor = identity, kwargs...)
634+
guesses, missing_unknowns; implicit_dae = false,
635+
u0_constructor = identity, floatT = Float64, kwargs...)
635636
guesses = merge(ModelingToolkit.guesses(sys), todict(guesses))
636637

637638
if t === nothing && is_time_dependent(sys)
638-
t = 0.0
639+
t = zero(floatT)
639640
end
640641

641642
initializeprob = ModelingToolkit.InitializationProblem{true, SciMLBase.FullSpecialize}(
@@ -675,7 +676,7 @@ function maybe_build_initialization_problem(
675676
get(op, p, missing) === missing || continue
676677
p = unwrap(p)
677678
stype = symtype(p)
678-
op[p] = get_temporary_value(p)
679+
op[p] = get_temporary_value(p, floatT)
679680
if iscall(p) && operation(p) === getindex
680681
arrp = arguments(p)[1]
681682
op[arrp] = collect(arrp)
@@ -684,7 +685,7 @@ function maybe_build_initialization_problem(
684685

685686
if is_time_dependent(sys)
686687
for v in missing_unknowns
687-
op[v] = zero_var(v)
688+
op[v] = get_temporary_value(v, floatT)
688689
end
689690
empty!(missing_unknowns)
690691
end
@@ -798,12 +799,26 @@ function process_SciMLProblem(
798799
op, missing_unknowns, missing_pars = build_operating_point!(sys,
799800
u0map, pmap, defs, cmap, dvs, ps)
800801

802+
floatT = Bool
803+
for (k, v) in op
804+
symbolic_type(v) == NotSymbolic() || continue
805+
is_array_of_symbolics(v) && continue
806+
807+
if v isa AbstractArray
808+
isconcretetype(eltype(v)) || continue
809+
floatT = promote_type(floatT, eltype(v))
810+
elseif v isa Real && isconcretetype(v)
811+
floatT = promote_type(floatT, typeof(v))
812+
end
813+
end
814+
floatT = float(floatT)
815+
801816
if !is_time_dependent(sys) || is_initializesystem(sys)
802817
add_observed_equations!(u0map, obs)
803818
end
804819
if u0_constructor === identity && u0Type <: StaticArray
805820
u0_constructor = vals -> SymbolicUtils.Code.create_array(
806-
u0Type, eltype(vals), Val(1), Val(length(vals)), vals...)
821+
u0Type, floatT, Val(1), Val(length(vals)), vals...)
807822
end
808823
if build_initializeprob
809824
kws = maybe_build_initialization_problem(
@@ -813,7 +828,7 @@ function process_SciMLProblem(
813828
warn_cyclic_dependency, check_units = check_initialization_units,
814829
circular_dependency_max_cycle_length, circular_dependency_max_cycles, use_scc,
815830
force_time_independent = force_initialization_time_independent, algebraic_only, allow_incomplete,
816-
u0_constructor)
831+
u0_constructor, floatT)
817832

818833
kwargs = merge(kwargs, kws)
819834
end
@@ -841,7 +856,7 @@ function process_SciMLProblem(
841856
evaluate_varmap!(op, dvs; limit = substitution_limit)
842857

843858
u0 = better_varmap_to_vars(
844-
op, dvs; tofloat,
859+
op, dvs; tofloat, floatT,
845860
container_type = u0Type, allow_symbolic = symbolic_u0, is_initializeprob)
846861

847862
if u0 !== nothing
@@ -865,7 +880,7 @@ function process_SciMLProblem(
865880
end
866881
evaluate_varmap!(op, ps; limit = substitution_limit)
867882
if is_split(sys)
868-
p = MTKParameters(sys, op)
883+
p = MTKParameters(sys, op; floatT = floatT)
869884
else
870885
p = better_varmap_to_vars(op, ps; tofloat, container_type = pType)
871886
end

0 commit comments

Comments
 (0)