From 0618a68c5ed0f7888bca2f439bc3c04e9c962e48 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 16 Apr 2025 13:36:40 +0530 Subject: [PATCH 1/4] fix: scalarize `Initial` parameters for `split = false` systems --- src/systems/abstractsystem.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index 72f3132889..4ad08bb1f4 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -696,7 +696,7 @@ end supports_initialization(sys::AbstractSystem) = true -function add_initialization_parameters(sys::AbstractSystem) +function add_initialization_parameters(sys::AbstractSystem; split = true) @assert !has_systems(sys) || isempty(get_systems(sys)) supports_initialization(sys) || return sys is_initializesystem(sys) && return sys @@ -711,7 +711,7 @@ function add_initialization_parameters(sys::AbstractSystem) obs, eqs = unhack_observed(observed(sys), eqs) for x in Iterators.flatten((unknowns(sys), Iterators.map(eq -> eq.lhs, obs))) x = unwrap(x) - if iscall(x) && operation(x) == getindex + if iscall(x) && operation(x) == getindex && split push!(all_initialvars, arguments(x)[1]) else push!(all_initialvars, x) @@ -788,7 +788,7 @@ function complete( end sys = newsys if add_initial_parameters - sys = add_initialization_parameters(sys) + sys = add_initialization_parameters(sys; split) end end if split && has_index_cache(sys) From 8147aabad454263c2a5279394c3141145f9d2002 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 16 Apr 2025 13:36:54 +0530 Subject: [PATCH 2/4] fix: recognize scalarized `Initial` parameters in `parameters` --- src/systems/abstractsystem.jl | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index 4ad08bb1f4..29338d0722 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -1465,7 +1465,11 @@ function parameters(sys::AbstractSystem; initial_parameters = false) result = unique(isempty(systems) ? ps : [ps; reduce(vcat, namespace_parameters.(systems))]) if !initial_parameters && !is_initializesystem(sys) - filter!(x -> !iscall(x) || !isa(operation(x), Initial), result) + filter!(result) do sym + return !(isoperator(sym, Initial) || + iscall(sym) && operation(sym) == getindex && + isoperator(arguments(sym)[1], Initial)) + end end return result end From ec4fd42d6c31888f374d4852971683abecd33a54 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 16 Apr 2025 13:37:07 +0530 Subject: [PATCH 3/4] test: test mix of array initials and scalar parameters with `split = false` --- test/initial_values.jl | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/test/initial_values.jl b/test/initial_values.jl index 01a053a9bf..79b6b8e067 100644 --- a/test/initial_values.jl +++ b/test/initial_values.jl @@ -265,3 +265,19 @@ end @test eltype(oprob.u0) == Float32 @test eltype(eltype(sol.u)) == Float32 end + +@testset "Array initials and scalar parameters with `split = false`" begin + @variables x(t)[1:2] + @parameters p + @mtkbuild sys=ODESystem([D(x[1]) ~ x[1], D(x[2]) ~ x[2] + p], t) split=false + ps = Set(parameters(sys; initial_parameters = true)) + @test length(ps) == 5 + for i in 1:2 + @test Initial(x[i]) in ps + @test Initial(D(x[i])) in ps + end + @test p in ps + prob = ODEProblem(sys, [x => ones(2)], (0.0, 1.0), [p => 1.0]) + @test prob.p isa Vector{Float64} + @test length(prob.p) == 5 +end From 3dea97cd9430bc9ae75c81b96921587d4eec2650 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 16 Apr 2025 16:36:53 +0530 Subject: [PATCH 4/4] test: account for new scalarized initials in tests --- test/odesystem.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/odesystem.jl b/test/odesystem.jl index a45ed99f35..4b76da6e9d 100644 --- a/test/odesystem.jl +++ b/test/odesystem.jl @@ -1305,7 +1305,7 @@ end ps = zeros(8) setp(sys2, x)(ps, 2ones(2)) setp(sys2, p)(ps, 2ones(2, 2)) - @test_nowarn fn2(ones(4), 2ones(6), 4.0) + @test_nowarn fn2(ones(4), 2ones(14), 4.0) end # https://github.com/SciML/ModelingToolkit.jl/issues/2969 @@ -1416,7 +1416,7 @@ end obsfn = ModelingToolkit.build_explicit_observed_function( sys1, u + x + p[1:2]; inputs = [x...]) - @test obsfn(ones(2), 2ones(2), 3ones(4), 4.0) == 6ones(2) + @test obsfn(ones(2), 2ones(2), 3ones(12), 4.0) == 6ones(2) end @testset "Passing `nothing` to `u0`" begin