diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index 67bcbc88c0..72f3132889 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -733,11 +733,11 @@ function add_initialization_parameters(sys::AbstractSystem) defs = copy(get_defaults(sys)) for ivar in initials if symbolic_type(ivar) == ScalarSymbolic() - defs[ivar] = zero_var(ivar) + defs[ivar] = false else defs[ivar] = collect(ivar) for scal_ivar in defs[ivar] - defs[scal_ivar] = zero_var(scal_ivar) + defs[scal_ivar] = false end end end diff --git a/src/systems/diffeqs/abstractodesystem.jl b/src/systems/diffeqs/abstractodesystem.jl index 196c692c67..62ddd12a08 100644 --- a/src/systems/diffeqs/abstractodesystem.jl +++ b/src/systems/diffeqs/abstractodesystem.jl @@ -1541,32 +1541,6 @@ function InitializationProblem{iip, specialize}(sys::AbstractSystem, filter_missing_values!(parammap) u0map = merge(ModelingToolkit.guesses(sys), todict(guesses), u0map) - fullmap = merge(u0map, parammap) - u0T = Union{} - for sym in unknowns(isys) - val = fixpoint_sub(sym, fullmap) - symbolic_type(val) == NotSymbolic() || continue - u0T = promote_type(u0T, typeof(val)) - end - for eq in observed(isys) - # ignore HACK-ed observed equations - symbolic_type(eq.lhs) == ArraySymbolic() && continue - val = fixpoint_sub(eq.lhs, fullmap) - symbolic_type(val) == NotSymbolic() || continue - u0T = promote_type(u0T, typeof(val)) - end - if u0T != Union{} - u0T = eltype(u0T) - u0map = Dict(k => if v === nothing - nothing - elseif symbolic_type(v) == NotSymbolic() && !is_array_of_symbolics(v) - v isa AbstractArray ? u0T.(v) : u0T(v) - else - v - end - for (k, v) in u0map) - end - TProb = if neqs == nunknown && isempty(unassigned_vars) if use_scc && neqs > 0 if is_split(isys) diff --git a/src/systems/nonlinear/initializesystem.jl b/src/systems/nonlinear/initializesystem.jl index 9c85b9b324..ec25b9b660 100644 --- a/src/systems/nonlinear/initializesystem.jl +++ b/src/systems/nonlinear/initializesystem.jl @@ -8,7 +8,7 @@ function generate_initializesystem(sys::AbstractTimeDependentSystem; pmap = Dict(), initialization_eqs = [], guesses = Dict(), - default_dd_guess = 0.0, + default_dd_guess = Bool(0), algebraic_only = false, check_units = true, check_defguess = false, name = nameof(sys), extra_metadata = (;), kwargs...) @@ -646,10 +646,12 @@ function SciMLBase.remake_initialization_data( op, missing_unknowns, missing_pars = build_operating_point!(sys, u0map, pmap, defs, cmap, dvs, ps) + floatT = float_type_from_varmap(op) kws = maybe_build_initialization_problem( sys, op, u0map, pmap, t0, defs, guesses, missing_unknowns; - use_scc, initialization_eqs, allow_incomplete = true) - return get(kws, :initialization_data, nothing) + use_scc, initialization_eqs, floatT, allow_incomplete = true) + + return SciMLBase.remake_initialization_data(sys, kws, newu0, t0, newp, newu0, newp) end function SciMLBase.late_binding_update_u0_p( diff --git a/src/systems/parameter_buffer.jl b/src/systems/parameter_buffer.jl index 36105c0de8..5fb830290f 100644 --- a/src/systems/parameter_buffer.jl +++ b/src/systems/parameter_buffer.jl @@ -1,6 +1,6 @@ symconvert(::Type{Symbolics.Struct{T}}, x) where {T} = convert(T, x) -symconvert(::Type{T}, x) where {T} = convert(T, x) -symconvert(::Type{Real}, x::Integer) = convert(Float64, x) +symconvert(::Type{T}, x::V) where {T, V} = convert(promote_type(T, V), x) +symconvert(::Type{Real}, x::Integer) = convert(Float16, x) symconvert(::Type{V}, x) where {V <: AbstractArray} = convert(V, symconvert.(eltype(V), x)) struct MTKParameters{T, I, D, C, N, H} @@ -28,7 +28,7 @@ the default behavior). """ function MTKParameters( sys::AbstractSystem, p, u0 = Dict(); tofloat = false, - t0 = nothing, substitution_limit = 1000) + t0 = nothing, substitution_limit = 1000, floatT = nothing) ic = if has_index_cache(sys) && get_index_cache(sys) !== nothing get_index_cache(sys) else @@ -56,6 +56,10 @@ function MTKParameters( op[get_iv(sys)] = t0 end + if floatT === nothing + floatT = float(float_type_from_varmap(op)) + end + isempty(missing_pars) || throw(MissingParametersError(collect(missing_pars))) evaluate_varmap!(op, ps; limit = substitution_limit) @@ -111,6 +115,9 @@ function MTKParameters( if ctype <: FnType ctype = fntype_to_function_type(ctype) end + if ctype == Real && floatT !== nothing + ctype = floatT + end val = symconvert(ctype, val) done = set_value(sym, val) if !done && Symbolics.isarraysymbolic(sym) diff --git a/src/systems/problem_utils.jl b/src/systems/problem_utils.jl index 1db153781a..0750585905 100644 --- a/src/systems/problem_utils.jl +++ b/src/systems/problem_utils.jl @@ -247,25 +247,6 @@ function recursive_unwrap(x::AbstractDict) return anydict(unwrap(k) => recursive_unwrap(v) for (k, v) in x) end -""" - $(TYPEDSIGNATURES) - -Return the appropriate zero value for a symbolic variable representing a number or array of -numbers. Sized array symbolics return a zero-filled array of matching size. Unsized array -symbolics return an empty array of the appropriate `eltype`. -""" -function zero_var(x::Symbolic{T}) where {V <: Number, T <: Union{V, AbstractArray{V}}} - if Symbolics.isarraysymbolic(x) - if is_sized_array_symbolic(x) - return zeros(eltype(T), size(x)) - else - return T[] - end - else - return zero(T) - end -end - """ $(TYPEDSIGNATURES) @@ -362,7 +343,7 @@ Keyword arguments: - `is_initializeprob, guesses`: Used to determine whether the system is missing guesses. """ function better_varmap_to_vars(varmap::AbstractDict, vars::Vector; - tofloat = true, container_type = Array, + tofloat = true, container_type = Array, floatT = Nothing, toterm = default_toterm, promotetoconcrete = nothing, check = true, allow_symbolic = false, is_initializeprob = false) isempty(vars) && return nothing @@ -385,6 +366,9 @@ function better_varmap_to_vars(varmap::AbstractDict, vars::Vector; is_initializeprob ? throw(MissingGuessError(missingsyms, missingvals)) : throw(UnexpectedSymbolicValueInVarmap(missingsyms[1], missingvals[1])) end + if tofloat && !(floatT == Nothing) + vals = floatT.(vals) + end end if container_type <: Union{AbstractDict, Tuple, Nothing, SciMLBase.NullParameters} @@ -533,12 +517,12 @@ function (f::UpdateInitializeprob)(initializeprob, prob) f.setvals(initializeprob, f.getvals(prob)) end -function get_temporary_value(p) +function get_temporary_value(p, floatT = Float64) stype = symtype(unwrap(p)) return if stype == Real - zero(Float64) + zero(floatT) elseif stype <: AbstractArray{Real} - zeros(Float64, size(p)) + zeros(floatT, size(p)) elseif stype <: Real zero(stype) elseif stype <: AbstractArray @@ -648,15 +632,32 @@ All other keyword arguments are forwarded to `InitializationProblem`. """ function maybe_build_initialization_problem( sys::AbstractSystem, op::AbstractDict, u0map, pmap, t, defs, - guesses, missing_unknowns; implicit_dae = false, u0_constructor = identity, kwargs...) + guesses, missing_unknowns; implicit_dae = false, + u0_constructor = identity, floatT = Float64, kwargs...) guesses = merge(ModelingToolkit.guesses(sys), todict(guesses)) if t === nothing && is_time_dependent(sys) - t = 0.0 + t = zero(floatT) end initializeprob = ModelingToolkit.InitializationProblem{true, SciMLBase.FullSpecialize}( sys, t, u0map, pmap; guesses, kwargs...) + if state_values(initializeprob) !== nothing + initializeprob = remake(initializeprob; u0 = floatT.(state_values(initializeprob))) + end + initp = parameter_values(initializeprob) + if is_split(sys) + buffer, repack, _ = SciMLStructures.canonicalize(SciMLStructures.Tunable(), initp) + initp = repack(floatT.(buffer)) + buffer, repack, _ = SciMLStructures.canonicalize(SciMLStructures.Initials(), initp) + initp = repack(floatT.(buffer)) + elseif initp isa AbstractArray + initp′ = similar(initp, floatT) + copyto!(initp′, initp) + initp = initp′ + end + initializeprob = remake(initializeprob; p = initp) + meta = get_metadata(initializeprob.f.sys) if is_time_dependent(sys) @@ -692,7 +693,7 @@ function maybe_build_initialization_problem( get(op, p, missing) === missing || continue p = unwrap(p) stype = symtype(p) - op[p] = get_temporary_value(p) + op[p] = get_temporary_value(p, floatT) if iscall(p) && operation(p) === getindex arrp = arguments(p)[1] op[arrp] = collect(arrp) @@ -701,7 +702,7 @@ function maybe_build_initialization_problem( if is_time_dependent(sys) for v in missing_unknowns - op[v] = zero_var(v) + op[v] = get_temporary_value(v, floatT) end empty!(missing_unknowns) end @@ -712,6 +713,26 @@ function maybe_build_initialization_problem( initializeprobpmap)) end +""" + $(TYPEDSIGNATURES) + +Calculate the floating point type to use from the given `varmap` by looking at variables +with a constant value. +""" +function float_type_from_varmap(varmap, floatT = Bool) + for (k, v) in varmap + symbolic_type(v) == NotSymbolic() || continue + is_array_of_symbolics(v) && continue + + if v isa AbstractArray + floatT = promote_type(floatT, eltype(v)) + elseif v isa Real + floatT = promote_type(floatT, typeof(v)) + end + end + return float(floatT) +end + """ $(TYPEDSIGNATURES) @@ -815,12 +836,19 @@ function process_SciMLProblem( op, missing_unknowns, missing_pars = build_operating_point!(sys, u0map, pmap, defs, cmap, dvs, ps) + floatT = Bool + if u0Type <: AbstractArray && eltype(u0Type) <: Real + floatT = float(eltype(u0Type)) + else + floatT = float_type_from_varmap(op, floatT) + end + if !is_time_dependent(sys) || is_initializesystem(sys) add_observed_equations!(u0map, obs) end if u0_constructor === identity && u0Type <: StaticArray u0_constructor = vals -> SymbolicUtils.Code.create_array( - u0Type, eltype(vals), Val(1), Val(length(vals)), vals...) + u0Type, floatT, Val(1), Val(length(vals)), vals...) end if build_initializeprob kws = maybe_build_initialization_problem( @@ -830,7 +858,7 @@ function process_SciMLProblem( warn_cyclic_dependency, check_units = check_initialization_units, circular_dependency_max_cycle_length, circular_dependency_max_cycles, use_scc, force_time_independent = force_initialization_time_independent, algebraic_only, allow_incomplete, - u0_constructor) + u0_constructor, floatT) kwargs = merge(kwargs, kws) end @@ -858,7 +886,7 @@ function process_SciMLProblem( evaluate_varmap!(op, dvs; limit = substitution_limit) u0 = better_varmap_to_vars( - op, dvs; tofloat, + op, dvs; tofloat, floatT, container_type = u0Type, allow_symbolic = symbolic_u0, is_initializeprob) if u0 !== nothing @@ -882,7 +910,7 @@ function process_SciMLProblem( end evaluate_varmap!(op, ps; limit = substitution_limit) if is_split(sys) - p = MTKParameters(sys, op) + p = MTKParameters(sys, op; floatT = floatT) else p = better_varmap_to_vars(op, ps; tofloat, container_type = pType) end @@ -898,6 +926,16 @@ function process_SciMLProblem( du0 = nothing end + if build_initializeprob + t0 = t + if is_time_dependent(sys) && t0 === nothing + t0 = zero(floatT) + end + initialization_data = SciMLBase.remake_initialization_data( + kwargs.initialization_data, kwargs, u0, t0, p, u0, p) + kwargs = merge(kwargs,) + end + f = constructor(sys, dvs, ps, u0; p = p, eval_expression = eval_expression, eval_module = eval_module, diff --git a/test/initial_values.jl b/test/initial_values.jl index f63333cb2f..01a053a9bf 100644 --- a/test/initial_values.jl +++ b/test/initial_values.jl @@ -252,3 +252,16 @@ end ps = [p => [4.0, 5.0]] @test_nowarn NonlinearProblem(nlsys, u0, ps) end + +@testset "Issue#3553: Retain `Float32` initial values" begin + @parameters p d + @variables X(t) + eqs = [D(X) ~ p - d * X] + @mtkbuild osys = ODESystem(eqs, t) + u0 = [X => 1.0f0] + ps = [p => 1.0f0, d => 2.0f0] + oprob = ODEProblem(osys, u0, (0.0f0, 1.0f0), ps) + sol = solve(oprob) + @test eltype(oprob.u0) == Float32 + @test eltype(eltype(sol.u)) == Float32 +end