From 50f932fd0589da2f0ab0e1821abdc08c73c04f34 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 14 Apr 2025 17:43:56 +0530 Subject: [PATCH 01/24] fix: use `Float16` zero default for `Initial` parameters --- src/systems/abstractsystem.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index 67bcbc88c0..b28b2d5279 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -733,11 +733,11 @@ function add_initialization_parameters(sys::AbstractSystem) defs = copy(get_defaults(sys)) for ivar in initials if symbolic_type(ivar) == ScalarSymbolic() - defs[ivar] = zero_var(ivar) + defs[ivar] = zero(Float16) else defs[ivar] = collect(ivar) for scal_ivar in defs[ivar] - defs[scal_ivar] = zero_var(scal_ivar) + defs[scal_ivar] = zero(Float16) end end end From 0305c9ddba2cf12bb4683cb08202dc2041b2958c Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 14 Apr 2025 17:44:16 +0530 Subject: [PATCH 02/24] test: test `Float32` values retained in problem construction --- test/initial_values.jl | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/test/initial_values.jl b/test/initial_values.jl index f63333cb2f..01a053a9bf 100644 --- a/test/initial_values.jl +++ b/test/initial_values.jl @@ -252,3 +252,16 @@ end ps = [p => [4.0, 5.0]] @test_nowarn NonlinearProblem(nlsys, u0, ps) end + +@testset "Issue#3553: Retain `Float32` initial values" begin + @parameters p d + @variables X(t) + eqs = [D(X) ~ p - d * X] + @mtkbuild osys = ODESystem(eqs, t) + u0 = [X => 1.0f0] + ps = [p => 1.0f0, d => 2.0f0] + oprob = ODEProblem(osys, u0, (0.0f0, 1.0f0), ps) + sol = solve(oprob) + @test eltype(oprob.u0) == Float32 + @test eltype(eltype(sol.u)) == Float32 +end From bcaa86941e2b7c0bb2b1e36acc0bbce26a0c82c2 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Mon, 14 Apr 2025 09:08:48 -0400 Subject: [PATCH 03/24] Update src/systems/abstractsystem.jl --- src/systems/abstractsystem.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index b28b2d5279..7da0179676 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -733,7 +733,7 @@ function add_initialization_parameters(sys::AbstractSystem) defs = copy(get_defaults(sys)) for ivar in initials if symbolic_type(ivar) == ScalarSymbolic() - defs[ivar] = zero(Float16) + defs[ivar] = false else defs[ivar] = collect(ivar) for scal_ivar in defs[ivar] From 65a326db7a8cb80b4fb4b50fe958cf4569b5d940 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Mon, 14 Apr 2025 09:08:53 -0400 Subject: [PATCH 04/24] Update src/systems/abstractsystem.jl --- src/systems/abstractsystem.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index 7da0179676..72f3132889 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -737,7 +737,7 @@ function add_initialization_parameters(sys::AbstractSystem) else defs[ivar] = collect(ivar) for scal_ivar in defs[ivar] - defs[scal_ivar] = zero(Float16) + defs[scal_ivar] = false end end end From aa7f38b25c7282ae93f6c12c63388293c9baa93f Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 15 Apr 2025 11:43:14 +0530 Subject: [PATCH 05/24] fix: fix type promotion in `MTKParameters` constructor --- src/systems/parameter_buffer.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/systems/parameter_buffer.jl b/src/systems/parameter_buffer.jl index 36105c0de8..2cf87769a0 100644 --- a/src/systems/parameter_buffer.jl +++ b/src/systems/parameter_buffer.jl @@ -1,6 +1,6 @@ symconvert(::Type{Symbolics.Struct{T}}, x) where {T} = convert(T, x) symconvert(::Type{T}, x) where {T} = convert(T, x) -symconvert(::Type{Real}, x::Integer) = convert(Float64, x) +symconvert(::Type{Real}, x::Integer) = x symconvert(::Type{V}, x) where {V <: AbstractArray} = convert(V, symconvert.(eltype(V), x)) struct MTKParameters{T, I, D, C, N, H} From 227a0d8edf5623a4be727ab0b5b7601d4a5e5c3d Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 15 Apr 2025 11:43:21 +0530 Subject: [PATCH 06/24] fix: use `Bool(0)` for `default_dd_guess` --- src/systems/nonlinear/initializesystem.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/systems/nonlinear/initializesystem.jl b/src/systems/nonlinear/initializesystem.jl index 9c85b9b324..3b90bfc66a 100644 --- a/src/systems/nonlinear/initializesystem.jl +++ b/src/systems/nonlinear/initializesystem.jl @@ -8,7 +8,7 @@ function generate_initializesystem(sys::AbstractTimeDependentSystem; pmap = Dict(), initialization_eqs = [], guesses = Dict(), - default_dd_guess = 0.0, + default_dd_guess = Bool(0), algebraic_only = false, check_units = true, check_defguess = false, name = nameof(sys), extra_metadata = (;), kwargs...) From b228fd957bb9ad4413c982254df85fd16cdad92b Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 15 Apr 2025 13:02:37 +0530 Subject: [PATCH 07/24] fix: promote integer valued reals to `Float16` --- src/systems/parameter_buffer.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/systems/parameter_buffer.jl b/src/systems/parameter_buffer.jl index 2cf87769a0..6035605a02 100644 --- a/src/systems/parameter_buffer.jl +++ b/src/systems/parameter_buffer.jl @@ -1,6 +1,6 @@ symconvert(::Type{Symbolics.Struct{T}}, x) where {T} = convert(T, x) symconvert(::Type{T}, x) where {T} = convert(T, x) -symconvert(::Type{Real}, x::Integer) = x +symconvert(::Type{Real}, x::Integer) = convert(Float16, x) symconvert(::Type{V}, x) where {V <: AbstractArray} = convert(V, symconvert.(eltype(V), x)) struct MTKParameters{T, I, D, C, N, H} From c814779c365205629edba03cf653e4840d0310cb Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 15 Apr 2025 17:08:21 +0530 Subject: [PATCH 08/24] feat: allow specifying `eltype` of numeric buffer in `better_varmap_to_vars` --- src/systems/problem_utils.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/systems/problem_utils.jl b/src/systems/problem_utils.jl index 1db153781a..236a3b31b8 100644 --- a/src/systems/problem_utils.jl +++ b/src/systems/problem_utils.jl @@ -362,7 +362,7 @@ Keyword arguments: - `is_initializeprob, guesses`: Used to determine whether the system is missing guesses. """ function better_varmap_to_vars(varmap::AbstractDict, vars::Vector; - tofloat = true, container_type = Array, + tofloat = true, container_type = Array, floatT = Float64, toterm = default_toterm, promotetoconcrete = nothing, check = true, allow_symbolic = false, is_initializeprob = false) isempty(vars) && return nothing @@ -385,6 +385,8 @@ function better_varmap_to_vars(varmap::AbstractDict, vars::Vector; is_initializeprob ? throw(MissingGuessError(missingsyms, missingvals)) : throw(UnexpectedSymbolicValueInVarmap(missingsyms[1], missingvals[1])) end + + vals = floatT.(vals) end if container_type <: Union{AbstractDict, Tuple, Nothing, SciMLBase.NullParameters} From a8d59d65d985938e0495ec97964bacf48d112f87 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 15 Apr 2025 17:08:38 +0530 Subject: [PATCH 09/24] feat: allow specifying floating point type in `get_temporary_value` --- src/systems/problem_utils.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/systems/problem_utils.jl b/src/systems/problem_utils.jl index 236a3b31b8..8ba5349011 100644 --- a/src/systems/problem_utils.jl +++ b/src/systems/problem_utils.jl @@ -535,12 +535,12 @@ function (f::UpdateInitializeprob)(initializeprob, prob) f.setvals(initializeprob, f.getvals(prob)) end -function get_temporary_value(p) +function get_temporary_value(p, floatT = Float64) stype = symtype(unwrap(p)) return if stype == Real - zero(Float64) + zero(floatT) elseif stype <: AbstractArray{Real} - zeros(Float64, size(p)) + zeros(floatT, size(p)) elseif stype <: Real zero(stype) elseif stype <: AbstractArray From 742146eac0f9fe085e648ff280297520d978b22d Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 15 Apr 2025 17:08:48 +0530 Subject: [PATCH 10/24] refactor: remove `zero_var` --- src/systems/problem_utils.jl | 19 ------------------- 1 file changed, 19 deletions(-) diff --git a/src/systems/problem_utils.jl b/src/systems/problem_utils.jl index 8ba5349011..78f2473f63 100644 --- a/src/systems/problem_utils.jl +++ b/src/systems/problem_utils.jl @@ -247,25 +247,6 @@ function recursive_unwrap(x::AbstractDict) return anydict(unwrap(k) => recursive_unwrap(v) for (k, v) in x) end -""" - $(TYPEDSIGNATURES) - -Return the appropriate zero value for a symbolic variable representing a number or array of -numbers. Sized array symbolics return a zero-filled array of matching size. Unsized array -symbolics return an empty array of the appropriate `eltype`. -""" -function zero_var(x::Symbolic{T}) where {V <: Number, T <: Union{V, AbstractArray{V}}} - if Symbolics.isarraysymbolic(x) - if is_sized_array_symbolic(x) - return zeros(eltype(T), size(x)) - else - return T[] - end - else - return zero(T) - end -end - """ $(TYPEDSIGNATURES) From fb1fb93e464d82af6490c77180180684623e9e7a Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 15 Apr 2025 17:09:10 +0530 Subject: [PATCH 11/24] refactor: determine floating point type earlier and propagate --- src/systems/problem_utils.jl | 31 +++++++++++++++++++++++-------- 1 file changed, 23 insertions(+), 8 deletions(-) diff --git a/src/systems/problem_utils.jl b/src/systems/problem_utils.jl index 78f2473f63..877b844a32 100644 --- a/src/systems/problem_utils.jl +++ b/src/systems/problem_utils.jl @@ -631,11 +631,12 @@ All other keyword arguments are forwarded to `InitializationProblem`. """ function maybe_build_initialization_problem( sys::AbstractSystem, op::AbstractDict, u0map, pmap, t, defs, - guesses, missing_unknowns; implicit_dae = false, u0_constructor = identity, kwargs...) + guesses, missing_unknowns; implicit_dae = false, + u0_constructor = identity, floatT = Float64, kwargs...) guesses = merge(ModelingToolkit.guesses(sys), todict(guesses)) if t === nothing && is_time_dependent(sys) - t = 0.0 + t = zero(floatT) end initializeprob = ModelingToolkit.InitializationProblem{true, SciMLBase.FullSpecialize}( @@ -675,7 +676,7 @@ function maybe_build_initialization_problem( get(op, p, missing) === missing || continue p = unwrap(p) stype = symtype(p) - op[p] = get_temporary_value(p) + op[p] = get_temporary_value(p, floatT) if iscall(p) && operation(p) === getindex arrp = arguments(p)[1] op[arrp] = collect(arrp) @@ -684,7 +685,7 @@ function maybe_build_initialization_problem( if is_time_dependent(sys) for v in missing_unknowns - op[v] = zero_var(v) + op[v] = get_temporary_value(v, floatT) end empty!(missing_unknowns) end @@ -798,12 +799,26 @@ function process_SciMLProblem( op, missing_unknowns, missing_pars = build_operating_point!(sys, u0map, pmap, defs, cmap, dvs, ps) + floatT = Bool + for (k, v) in op + symbolic_type(v) == NotSymbolic() || continue + is_array_of_symbolics(v) && continue + + if v isa AbstractArray + isconcretetype(eltype(v)) || continue + floatT = promote_type(floatT, eltype(v)) + elseif v isa Real && isconcretetype(v) + floatT = promote_type(floatT, typeof(v)) + end + end + floatT = float(floatT) + if !is_time_dependent(sys) || is_initializesystem(sys) add_observed_equations!(u0map, obs) end if u0_constructor === identity && u0Type <: StaticArray u0_constructor = vals -> SymbolicUtils.Code.create_array( - u0Type, eltype(vals), Val(1), Val(length(vals)), vals...) + u0Type, floatT, Val(1), Val(length(vals)), vals...) end if build_initializeprob kws = maybe_build_initialization_problem( @@ -813,7 +828,7 @@ function process_SciMLProblem( warn_cyclic_dependency, check_units = check_initialization_units, circular_dependency_max_cycle_length, circular_dependency_max_cycles, use_scc, force_time_independent = force_initialization_time_independent, algebraic_only, allow_incomplete, - u0_constructor) + u0_constructor, floatT) kwargs = merge(kwargs, kws) end @@ -841,7 +856,7 @@ function process_SciMLProblem( evaluate_varmap!(op, dvs; limit = substitution_limit) u0 = better_varmap_to_vars( - op, dvs; tofloat, + op, dvs; tofloat, floatT, container_type = u0Type, allow_symbolic = symbolic_u0, is_initializeprob) if u0 !== nothing @@ -865,7 +880,7 @@ function process_SciMLProblem( end evaluate_varmap!(op, ps; limit = substitution_limit) if is_split(sys) - p = MTKParameters(sys, op) + p = MTKParameters(sys, op; floatT = floatT) else p = better_varmap_to_vars(op, ps; tofloat, container_type = pType) end From 59c6b31dd44b6d7194b49326270a182f8cf76f2d Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 15 Apr 2025 17:09:27 +0530 Subject: [PATCH 12/24] refactor: remove type-promotion hack in `InitializationProblem` --- src/systems/diffeqs/abstractodesystem.jl | 26 ------------------------ 1 file changed, 26 deletions(-) diff --git a/src/systems/diffeqs/abstractodesystem.jl b/src/systems/diffeqs/abstractodesystem.jl index 196c692c67..62ddd12a08 100644 --- a/src/systems/diffeqs/abstractodesystem.jl +++ b/src/systems/diffeqs/abstractodesystem.jl @@ -1541,32 +1541,6 @@ function InitializationProblem{iip, specialize}(sys::AbstractSystem, filter_missing_values!(parammap) u0map = merge(ModelingToolkit.guesses(sys), todict(guesses), u0map) - fullmap = merge(u0map, parammap) - u0T = Union{} - for sym in unknowns(isys) - val = fixpoint_sub(sym, fullmap) - symbolic_type(val) == NotSymbolic() || continue - u0T = promote_type(u0T, typeof(val)) - end - for eq in observed(isys) - # ignore HACK-ed observed equations - symbolic_type(eq.lhs) == ArraySymbolic() && continue - val = fixpoint_sub(eq.lhs, fullmap) - symbolic_type(val) == NotSymbolic() || continue - u0T = promote_type(u0T, typeof(val)) - end - if u0T != Union{} - u0T = eltype(u0T) - u0map = Dict(k => if v === nothing - nothing - elseif symbolic_type(v) == NotSymbolic() && !is_array_of_symbolics(v) - v isa AbstractArray ? u0T.(v) : u0T(v) - else - v - end - for (k, v) in u0map) - end - TProb = if neqs == nunknown && isempty(unassigned_vars) if use_scc && neqs > 0 if is_split(isys) From 27e7efbf73ccbe3ccb48c67e49c2ea8fa990830e Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 15 Apr 2025 17:09:56 +0530 Subject: [PATCH 13/24] fix: use `remake` to promote intialization problem --- src/systems/diffeqs/abstractodesystem.jl | 30 +++++++++++++++++------- src/systems/diffeqs/sdesystem.jl | 9 +++++-- src/systems/nonlinear/nonlinearsystem.jl | 9 +++++-- 3 files changed, 36 insertions(+), 12 deletions(-) diff --git a/src/systems/diffeqs/abstractodesystem.jl b/src/systems/diffeqs/abstractodesystem.jl index 62ddd12a08..c6b2eb8a87 100644 --- a/src/systems/diffeqs/abstractodesystem.jl +++ b/src/systems/diffeqs/abstractodesystem.jl @@ -817,7 +817,9 @@ function DiffEqBase.ODEProblem{iip, specialize}(sys::AbstractODESystem, u0map = end # Call `remake` so it runs initialization if it is trivial - return remake(ODEProblem{iip}(f, u0, tspan, p, pt; kwargs1..., kwargs...)) + # Pass `u0` and `p` to run `ReconstructInitializeprob` which will promote + # u0 and p of initializeprob + return remake(ODEProblem{iip}(f, u0, tspan, p, pt; kwargs1..., kwargs...); u0, p) end get_callback(prob::ODEProblem) = prob.kwargs[:callback] @@ -1040,9 +1042,14 @@ function DiffEqBase.DAEProblem{iip}(sys::AbstractODESystem, du0map, u0map, tspan end # Call `remake` so it runs initialization if it is trivial - return remake(DAEProblem{iip}( - f, du0, u0, tspan, p; differential_vars = differential_vars, - kwargs..., kwargs1...)) + # Pass `u0` and `p` to run `ReconstructInitializeprob` which will promote + # u0 and p of initializeprob + return remake( + DAEProblem{iip}( + f, du0, u0, tspan, p; differential_vars = differential_vars, + kwargs..., kwargs1...); + u0, + p) end function generate_history(sys::AbstractODESystem, u0; expression = Val{false}, kwargs...) @@ -1088,7 +1095,9 @@ function DiffEqBase.DDEProblem{iip}(sys::AbstractODESystem, u0map = [], kwargs1 = merge(kwargs1, (callback = cbs,)) end # Call `remake` so it runs initialization if it is trivial - return remake(DDEProblem{iip}(f, u0, h, tspan, p; kwargs1..., kwargs...)) + # Pass `u0` and `p` to run `ReconstructInitializeprob` which will promote + # u0 and p of initializeprob + return remake(DDEProblem{iip}(f, u0, h, tspan, p; kwargs1..., kwargs...); u0, p) end function DiffEqBase.SDDEProblem(sys::AbstractODESystem, args...; kwargs...) @@ -1139,9 +1148,14 @@ function DiffEqBase.SDDEProblem{iip}(sys::AbstractODESystem, u0map = [], noise_rate_prototype = zeros(eltype(u0), size(noiseeqs)) end # Call `remake` so it runs initialization if it is trivial - return remake(SDDEProblem{iip}(f, f.g, u0, h, tspan, p; - noise_rate_prototype = - noise_rate_prototype, kwargs1..., kwargs...)) + # Pass `u0` and `p` to run `ReconstructInitializeprob` which will promote + # u0 and p of initializeprob + return remake( + SDDEProblem{iip}(f, f.g, u0, h, tspan, p; + noise_rate_prototype = + noise_rate_prototype, kwargs1..., kwargs...); + u0, + p) end """ diff --git a/src/systems/diffeqs/sdesystem.jl b/src/systems/diffeqs/sdesystem.jl index 3fa1302630..f77b245b8a 100644 --- a/src/systems/diffeqs/sdesystem.jl +++ b/src/systems/diffeqs/sdesystem.jl @@ -818,8 +818,13 @@ function DiffEqBase.SDEProblem{iip, specialize}( kwargs = filter_kwargs(kwargs) # Call `remake` so it runs initialization if it is trivial - return remake(SDEProblem{iip}(f, u0, tspan, p; callback = cbs, noise, - noise_rate_prototype = noise_rate_prototype, kwargs...)) + # Pass `u0` and `p` to run `ReconstructInitializeprob` which will promote + # u0 and p of initializeprob + return remake( + SDEProblem{iip}(f, u0, tspan, p; callback = cbs, noise, + noise_rate_prototype = noise_rate_prototype, kwargs...); + u0, + p) end function DiffEqBase.SDEProblem(sys::ODESystem, args...; kwargs...) diff --git a/src/systems/nonlinear/nonlinearsystem.jl b/src/systems/nonlinear/nonlinearsystem.jl index 856822492b..c87d71abf5 100644 --- a/src/systems/nonlinear/nonlinearsystem.jl +++ b/src/systems/nonlinear/nonlinearsystem.jl @@ -557,7 +557,9 @@ function DiffEqBase.NonlinearProblem{iip}(sys::NonlinearSystem, u0map, check_length, kwargs...) pt = something(get_metadata(sys), StandardNonlinearProblem()) # Call `remake` so it runs initialization if it is trivial - return remake(NonlinearProblem{iip}(f, u0, p, pt; filter_kwargs(kwargs)...)) + # Pass `u0` and `p` to run `ReconstructInitializeprob` which will promote + # u0 and p of initializeprob + return remake(NonlinearProblem{iip}(f, u0, p, pt; filter_kwargs(kwargs)...); u0, p) end function DiffEqBase.NonlinearProblem(sys::AbstractODESystem, args...; kwargs...) @@ -591,7 +593,10 @@ function DiffEqBase.NonlinearLeastSquaresProblem{iip}(sys::NonlinearSystem, u0ma check_length, kwargs...) pt = something(get_metadata(sys), StandardNonlinearProblem()) # Call `remake` so it runs initialization if it is trivial - return remake(NonlinearLeastSquaresProblem{iip}(f, u0, p; filter_kwargs(kwargs)...)) + # Pass `u0` and `p` to run `ReconstructInitializeprob` which will promote + # u0 and p of initializeprob + return remake( + NonlinearLeastSquaresProblem{iip}(f, u0, p; filter_kwargs(kwargs)...); u0, p) end const TypeT = Union{DataType, UnionAll} From 8d2482c5b61ef18ee38a9c9da561d01ca6da7c57 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 15 Apr 2025 17:10:05 +0530 Subject: [PATCH 14/24] fix: fix `symconvert` --- src/systems/parameter_buffer.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/systems/parameter_buffer.jl b/src/systems/parameter_buffer.jl index 6035605a02..e1bb83238e 100644 --- a/src/systems/parameter_buffer.jl +++ b/src/systems/parameter_buffer.jl @@ -1,5 +1,5 @@ symconvert(::Type{Symbolics.Struct{T}}, x) where {T} = convert(T, x) -symconvert(::Type{T}, x) where {T} = convert(T, x) +symconvert(::Type{T}, x::V) where {T, V} = convert(promote_type(T, V), x) symconvert(::Type{Real}, x::Integer) = convert(Float16, x) symconvert(::Type{V}, x) where {V <: AbstractArray} = convert(V, symconvert.(eltype(V), x)) From 38724b2e38e180a8b8dae14cddf86a789dd376a5 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 15 Apr 2025 17:10:20 +0530 Subject: [PATCH 15/24] fix: allow specifying floating point type in `MTKParameters` --- src/systems/parameter_buffer.jl | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/systems/parameter_buffer.jl b/src/systems/parameter_buffer.jl index e1bb83238e..529be4eaf8 100644 --- a/src/systems/parameter_buffer.jl +++ b/src/systems/parameter_buffer.jl @@ -28,7 +28,7 @@ the default behavior). """ function MTKParameters( sys::AbstractSystem, p, u0 = Dict(); tofloat = false, - t0 = nothing, substitution_limit = 1000) + t0 = nothing, substitution_limit = 1000, floatT = nothing) ic = if has_index_cache(sys) && get_index_cache(sys) !== nothing get_index_cache(sys) else @@ -111,6 +111,9 @@ function MTKParameters( if ctype <: FnType ctype = fntype_to_function_type(ctype) end + if ctype == Real && floatT !== nothing + ctype = floatT + end val = symconvert(ctype, val) done = set_value(sym, val) if !done && Symbolics.isarraysymbolic(sym) From 57d7da10dc4fe8dd09ae690af84d21ba149c8b2e Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 15 Apr 2025 18:04:37 +0530 Subject: [PATCH 16/24] Revert "fix: use `remake` to promote intialization problem" This reverts commit 27e7efbf73ccbe3ccb48c67e49c2ea8fa990830e. --- src/systems/diffeqs/abstractodesystem.jl | 30 +++++++----------------- src/systems/diffeqs/sdesystem.jl | 9 ++----- src/systems/nonlinear/nonlinearsystem.jl | 9 ++----- 3 files changed, 12 insertions(+), 36 deletions(-) diff --git a/src/systems/diffeqs/abstractodesystem.jl b/src/systems/diffeqs/abstractodesystem.jl index c6b2eb8a87..62ddd12a08 100644 --- a/src/systems/diffeqs/abstractodesystem.jl +++ b/src/systems/diffeqs/abstractodesystem.jl @@ -817,9 +817,7 @@ function DiffEqBase.ODEProblem{iip, specialize}(sys::AbstractODESystem, u0map = end # Call `remake` so it runs initialization if it is trivial - # Pass `u0` and `p` to run `ReconstructInitializeprob` which will promote - # u0 and p of initializeprob - return remake(ODEProblem{iip}(f, u0, tspan, p, pt; kwargs1..., kwargs...); u0, p) + return remake(ODEProblem{iip}(f, u0, tspan, p, pt; kwargs1..., kwargs...)) end get_callback(prob::ODEProblem) = prob.kwargs[:callback] @@ -1042,14 +1040,9 @@ function DiffEqBase.DAEProblem{iip}(sys::AbstractODESystem, du0map, u0map, tspan end # Call `remake` so it runs initialization if it is trivial - # Pass `u0` and `p` to run `ReconstructInitializeprob` which will promote - # u0 and p of initializeprob - return remake( - DAEProblem{iip}( - f, du0, u0, tspan, p; differential_vars = differential_vars, - kwargs..., kwargs1...); - u0, - p) + return remake(DAEProblem{iip}( + f, du0, u0, tspan, p; differential_vars = differential_vars, + kwargs..., kwargs1...)) end function generate_history(sys::AbstractODESystem, u0; expression = Val{false}, kwargs...) @@ -1095,9 +1088,7 @@ function DiffEqBase.DDEProblem{iip}(sys::AbstractODESystem, u0map = [], kwargs1 = merge(kwargs1, (callback = cbs,)) end # Call `remake` so it runs initialization if it is trivial - # Pass `u0` and `p` to run `ReconstructInitializeprob` which will promote - # u0 and p of initializeprob - return remake(DDEProblem{iip}(f, u0, h, tspan, p; kwargs1..., kwargs...); u0, p) + return remake(DDEProblem{iip}(f, u0, h, tspan, p; kwargs1..., kwargs...)) end function DiffEqBase.SDDEProblem(sys::AbstractODESystem, args...; kwargs...) @@ -1148,14 +1139,9 @@ function DiffEqBase.SDDEProblem{iip}(sys::AbstractODESystem, u0map = [], noise_rate_prototype = zeros(eltype(u0), size(noiseeqs)) end # Call `remake` so it runs initialization if it is trivial - # Pass `u0` and `p` to run `ReconstructInitializeprob` which will promote - # u0 and p of initializeprob - return remake( - SDDEProblem{iip}(f, f.g, u0, h, tspan, p; - noise_rate_prototype = - noise_rate_prototype, kwargs1..., kwargs...); - u0, - p) + return remake(SDDEProblem{iip}(f, f.g, u0, h, tspan, p; + noise_rate_prototype = + noise_rate_prototype, kwargs1..., kwargs...)) end """ diff --git a/src/systems/diffeqs/sdesystem.jl b/src/systems/diffeqs/sdesystem.jl index f77b245b8a..3fa1302630 100644 --- a/src/systems/diffeqs/sdesystem.jl +++ b/src/systems/diffeqs/sdesystem.jl @@ -818,13 +818,8 @@ function DiffEqBase.SDEProblem{iip, specialize}( kwargs = filter_kwargs(kwargs) # Call `remake` so it runs initialization if it is trivial - # Pass `u0` and `p` to run `ReconstructInitializeprob` which will promote - # u0 and p of initializeprob - return remake( - SDEProblem{iip}(f, u0, tspan, p; callback = cbs, noise, - noise_rate_prototype = noise_rate_prototype, kwargs...); - u0, - p) + return remake(SDEProblem{iip}(f, u0, tspan, p; callback = cbs, noise, + noise_rate_prototype = noise_rate_prototype, kwargs...)) end function DiffEqBase.SDEProblem(sys::ODESystem, args...; kwargs...) diff --git a/src/systems/nonlinear/nonlinearsystem.jl b/src/systems/nonlinear/nonlinearsystem.jl index c87d71abf5..856822492b 100644 --- a/src/systems/nonlinear/nonlinearsystem.jl +++ b/src/systems/nonlinear/nonlinearsystem.jl @@ -557,9 +557,7 @@ function DiffEqBase.NonlinearProblem{iip}(sys::NonlinearSystem, u0map, check_length, kwargs...) pt = something(get_metadata(sys), StandardNonlinearProblem()) # Call `remake` so it runs initialization if it is trivial - # Pass `u0` and `p` to run `ReconstructInitializeprob` which will promote - # u0 and p of initializeprob - return remake(NonlinearProblem{iip}(f, u0, p, pt; filter_kwargs(kwargs)...); u0, p) + return remake(NonlinearProblem{iip}(f, u0, p, pt; filter_kwargs(kwargs)...)) end function DiffEqBase.NonlinearProblem(sys::AbstractODESystem, args...; kwargs...) @@ -593,10 +591,7 @@ function DiffEqBase.NonlinearLeastSquaresProblem{iip}(sys::NonlinearSystem, u0ma check_length, kwargs...) pt = something(get_metadata(sys), StandardNonlinearProblem()) # Call `remake` so it runs initialization if it is trivial - # Pass `u0` and `p` to run `ReconstructInitializeprob` which will promote - # u0 and p of initializeprob - return remake( - NonlinearLeastSquaresProblem{iip}(f, u0, p; filter_kwargs(kwargs)...); u0, p) + return remake(NonlinearLeastSquaresProblem{iip}(f, u0, p; filter_kwargs(kwargs)...)) end const TypeT = Union{DataType, UnionAll} From 9e6e8942ec8ab21548043a05801f2317bfc699af Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 15 Apr 2025 18:17:09 +0530 Subject: [PATCH 17/24] fix: use `remake_initialization_data` to promote the initialization problem --- src/systems/problem_utils.jl | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/systems/problem_utils.jl b/src/systems/problem_utils.jl index 877b844a32..8a343baf7b 100644 --- a/src/systems/problem_utils.jl +++ b/src/systems/problem_utils.jl @@ -896,6 +896,16 @@ function process_SciMLProblem( du0 = nothing end + if build_initializeprob + t0 = t + if is_time_dependent(sys) && t0 === nothing + t0 = zero(floatT) + end + initialization_data = SciMLBase.remake_initialization_data( + kwargs.initialization_data, kwargs, u0, t0, p, u0, p) + kwargs = merge(kwargs,) + end + f = constructor(sys, dvs, ps, u0; p = p, eval_expression = eval_expression, eval_module = eval_module, From 662217b28fc301b65d7e18a3a9a9f05e56ff22fa Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 15 Apr 2025 20:52:57 +0530 Subject: [PATCH 18/24] fix: promote initializeprob in `remake_initialization_data` --- src/systems/nonlinear/initializesystem.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/systems/nonlinear/initializesystem.jl b/src/systems/nonlinear/initializesystem.jl index 3b90bfc66a..4d08042533 100644 --- a/src/systems/nonlinear/initializesystem.jl +++ b/src/systems/nonlinear/initializesystem.jl @@ -649,7 +649,8 @@ function SciMLBase.remake_initialization_data( kws = maybe_build_initialization_problem( sys, op, u0map, pmap, t0, defs, guesses, missing_unknowns; use_scc, initialization_eqs, allow_incomplete = true) - return get(kws, :initialization_data, nothing) + + return SciMLBase.remake_initialization_data(sys, kws, newu0, t0, newp, newu0, newp) end function SciMLBase.late_binding_update_u0_p( From f9cdd1684e852f0a6cb181dfa9b2b06a56c42289 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 15 Apr 2025 20:53:18 +0530 Subject: [PATCH 19/24] fix: respect `floatT` in `maybe_build_initialization_problem` --- src/systems/problem_utils.jl | 38 +++++++++++++++++++++++++++--------- 1 file changed, 29 insertions(+), 9 deletions(-) diff --git a/src/systems/problem_utils.jl b/src/systems/problem_utils.jl index 8a343baf7b..4cc03b52ab 100644 --- a/src/systems/problem_utils.jl +++ b/src/systems/problem_utils.jl @@ -641,6 +641,22 @@ function maybe_build_initialization_problem( initializeprob = ModelingToolkit.InitializationProblem{true, SciMLBase.FullSpecialize}( sys, t, u0map, pmap; guesses, kwargs...) + if state_values(initializeprob) !== nothing + initializeprob = remake(initializeprob; u0 = floatT.(state_values(initializeprob))) + end + initp = parameter_values(initializeprob) + if is_split(sys) + buffer, repack, _ = SciMLStructures.canonicalize(SciMLStructures.Tunable(), initp) + initp = repack(floatT.(buffer)) + buffer, repack, _ = SciMLStructures.canonicalize(SciMLStructures.Initials(), initp) + initp = repack(floatT.(buffer)) + elseif initp isa AbstractArray + initp′ = similar(initp, floatT) + copyto!(initp′, initp) + initp = initp′ + end + initializeprob = remake(initializeprob; p = initp) + meta = get_metadata(initializeprob.f.sys) if is_time_dependent(sys) @@ -800,15 +816,19 @@ function process_SciMLProblem( u0map, pmap, defs, cmap, dvs, ps) floatT = Bool - for (k, v) in op - symbolic_type(v) == NotSymbolic() || continue - is_array_of_symbolics(v) && continue - - if v isa AbstractArray - isconcretetype(eltype(v)) || continue - floatT = promote_type(floatT, eltype(v)) - elseif v isa Real && isconcretetype(v) - floatT = promote_type(floatT, typeof(v)) + if u0Type <: AbstractArray && isconcretetype(eltype(u0Type)) && eltype(u0Type) <: Real + floatT = eltype(u0Type) + else + for (k, v) in op + symbolic_type(v) == NotSymbolic() || continue + is_array_of_symbolics(v) && continue + + if v isa AbstractArray + isconcretetype(eltype(v)) || continue + floatT = promote_type(floatT, eltype(v)) + elseif v isa Real && isconcretetype(v) + floatT = promote_type(floatT, typeof(v)) + end end end floatT = float(floatT) From 9fe75863d927ba29a556fc1c8bad2ad67cade42d Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 15 Apr 2025 21:05:00 +0530 Subject: [PATCH 20/24] fix: respect `tofloat` when using `floatT`, default to not promoting --- src/systems/problem_utils.jl | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/systems/problem_utils.jl b/src/systems/problem_utils.jl index 4cc03b52ab..96444bc31f 100644 --- a/src/systems/problem_utils.jl +++ b/src/systems/problem_utils.jl @@ -343,7 +343,7 @@ Keyword arguments: - `is_initializeprob, guesses`: Used to determine whether the system is missing guesses. """ function better_varmap_to_vars(varmap::AbstractDict, vars::Vector; - tofloat = true, container_type = Array, floatT = Float64, + tofloat = true, container_type = Array, floatT = Nothing, toterm = default_toterm, promotetoconcrete = nothing, check = true, allow_symbolic = false, is_initializeprob = false) isempty(vars) && return nothing @@ -366,8 +366,9 @@ function better_varmap_to_vars(varmap::AbstractDict, vars::Vector; is_initializeprob ? throw(MissingGuessError(missingsyms, missingvals)) : throw(UnexpectedSymbolicValueInVarmap(missingsyms[1], missingvals[1])) end - - vals = floatT.(vals) + if tofloat && !(floatT == Nothing) + vals = floatT.(vals) + end end if container_type <: Union{AbstractDict, Tuple, Nothing, SciMLBase.NullParameters} From a8e2495964d04cd7818c43b841aea434f7a1d5c4 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 15 Apr 2025 21:05:14 +0530 Subject: [PATCH 21/24] fix: do not require `isconcretetype` when discovering `floatT` --- src/systems/problem_utils.jl | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/systems/problem_utils.jl b/src/systems/problem_utils.jl index 96444bc31f..51c018bacc 100644 --- a/src/systems/problem_utils.jl +++ b/src/systems/problem_utils.jl @@ -817,7 +817,7 @@ function process_SciMLProblem( u0map, pmap, defs, cmap, dvs, ps) floatT = Bool - if u0Type <: AbstractArray && isconcretetype(eltype(u0Type)) && eltype(u0Type) <: Real + if u0Type <: AbstractArray && eltype(u0Type) <: Real floatT = eltype(u0Type) else for (k, v) in op @@ -825,9 +825,8 @@ function process_SciMLProblem( is_array_of_symbolics(v) && continue if v isa AbstractArray - isconcretetype(eltype(v)) || continue floatT = promote_type(floatT, eltype(v)) - elseif v isa Real && isconcretetype(v) + elseif v isa Real floatT = promote_type(floatT, typeof(v)) end end From 6023efbad37a8e0326ab7a5e8990feb3296959dd Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 16 Apr 2025 10:42:05 +0530 Subject: [PATCH 22/24] refactor: add `float_type_from_varmap` --- src/systems/problem_utils.jl | 34 ++++++++++++++++++++++------------ 1 file changed, 22 insertions(+), 12 deletions(-) diff --git a/src/systems/problem_utils.jl b/src/systems/problem_utils.jl index 51c018bacc..0750585905 100644 --- a/src/systems/problem_utils.jl +++ b/src/systems/problem_utils.jl @@ -713,6 +713,26 @@ function maybe_build_initialization_problem( initializeprobpmap)) end +""" + $(TYPEDSIGNATURES) + +Calculate the floating point type to use from the given `varmap` by looking at variables +with a constant value. +""" +function float_type_from_varmap(varmap, floatT = Bool) + for (k, v) in varmap + symbolic_type(v) == NotSymbolic() || continue + is_array_of_symbolics(v) && continue + + if v isa AbstractArray + floatT = promote_type(floatT, eltype(v)) + elseif v isa Real + floatT = promote_type(floatT, typeof(v)) + end + end + return float(floatT) +end + """ $(TYPEDSIGNATURES) @@ -818,20 +838,10 @@ function process_SciMLProblem( floatT = Bool if u0Type <: AbstractArray && eltype(u0Type) <: Real - floatT = eltype(u0Type) + floatT = float(eltype(u0Type)) else - for (k, v) in op - symbolic_type(v) == NotSymbolic() || continue - is_array_of_symbolics(v) && continue - - if v isa AbstractArray - floatT = promote_type(floatT, eltype(v)) - elseif v isa Real - floatT = promote_type(floatT, typeof(v)) - end - end + floatT = float_type_from_varmap(op, floatT) end - floatT = float(floatT) if !is_time_dependent(sys) || is_initializesystem(sys) add_observed_equations!(u0map, obs) From ead8682fd88b02ff019d6e244e1100c90fc72631 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 16 Apr 2025 10:42:16 +0530 Subject: [PATCH 23/24] fix: infer `floatT` in `MTKParameters` constructor if not provided --- src/systems/parameter_buffer.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/systems/parameter_buffer.jl b/src/systems/parameter_buffer.jl index 529be4eaf8..5fb830290f 100644 --- a/src/systems/parameter_buffer.jl +++ b/src/systems/parameter_buffer.jl @@ -56,6 +56,10 @@ function MTKParameters( op[get_iv(sys)] = t0 end + if floatT === nothing + floatT = float(float_type_from_varmap(op)) + end + isempty(missing_pars) || throw(MissingParametersError(collect(missing_pars))) evaluate_varmap!(op, ps; limit = substitution_limit) From cd19c7edab02ec27ed1015230c217ac9cff3a04d Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 16 Apr 2025 10:42:26 +0530 Subject: [PATCH 24/24] fix: infer `floatT` in `remake_initialization_data` --- src/systems/nonlinear/initializesystem.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/systems/nonlinear/initializesystem.jl b/src/systems/nonlinear/initializesystem.jl index 4d08042533..ec25b9b660 100644 --- a/src/systems/nonlinear/initializesystem.jl +++ b/src/systems/nonlinear/initializesystem.jl @@ -646,9 +646,10 @@ function SciMLBase.remake_initialization_data( op, missing_unknowns, missing_pars = build_operating_point!(sys, u0map, pmap, defs, cmap, dvs, ps) + floatT = float_type_from_varmap(op) kws = maybe_build_initialization_problem( sys, op, u0map, pmap, t0, defs, guesses, missing_unknowns; - use_scc, initialization_eqs, allow_incomplete = true) + use_scc, initialization_eqs, floatT, allow_incomplete = true) return SciMLBase.remake_initialization_data(sys, kws, newu0, t0, newp, newu0, newp) end