Skip to content

fix: scalarize Initial parameters for split = false systems #3569

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions src/systems/abstractsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -696,7 +696,7 @@ end

supports_initialization(sys::AbstractSystem) = true

function add_initialization_parameters(sys::AbstractSystem)
function add_initialization_parameters(sys::AbstractSystem; split = true)
@assert !has_systems(sys) || isempty(get_systems(sys))
supports_initialization(sys) || return sys
is_initializesystem(sys) && return sys
Expand All @@ -711,7 +711,7 @@ function add_initialization_parameters(sys::AbstractSystem)
obs, eqs = unhack_observed(observed(sys), eqs)
for x in Iterators.flatten((unknowns(sys), Iterators.map(eq -> eq.lhs, obs)))
x = unwrap(x)
if iscall(x) && operation(x) == getindex
if iscall(x) && operation(x) == getindex && split
push!(all_initialvars, arguments(x)[1])
else
push!(all_initialvars, x)
Expand Down Expand Up @@ -788,7 +788,7 @@ function complete(
end
sys = newsys
if add_initial_parameters
sys = add_initialization_parameters(sys)
sys = add_initialization_parameters(sys; split)
end
end
if split && has_index_cache(sys)
Expand Down Expand Up @@ -1465,7 +1465,11 @@ function parameters(sys::AbstractSystem; initial_parameters = false)
result = unique(isempty(systems) ? ps :
[ps; reduce(vcat, namespace_parameters.(systems))])
if !initial_parameters && !is_initializesystem(sys)
filter!(x -> !iscall(x) || !isa(operation(x), Initial), result)
filter!(result) do sym
return !(isoperator(sym, Initial) ||
iscall(sym) && operation(sym) == getindex &&
isoperator(arguments(sym)[1], Initial))
end
end
return result
end
Expand Down
16 changes: 16 additions & 0 deletions test/initial_values.jl
Original file line number Diff line number Diff line change
Expand Up @@ -265,3 +265,19 @@ end
@test eltype(oprob.u0) == Float32
@test eltype(eltype(sol.u)) == Float32
end

@testset "Array initials and scalar parameters with `split = false`" begin
@variables x(t)[1:2]
@parameters p
@mtkbuild sys=ODESystem([D(x[1]) ~ x[1], D(x[2]) ~ x[2] + p], t) split=false
ps = Set(parameters(sys; initial_parameters = true))
@test length(ps) == 5
for i in 1:2
@test Initial(x[i]) in ps
@test Initial(D(x[i])) in ps
end
@test p in ps
prob = ODEProblem(sys, [x => ones(2)], (0.0, 1.0), [p => 1.0])
@test prob.p isa Vector{Float64}
@test length(prob.p) == 5
end
4 changes: 2 additions & 2 deletions test/odesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1305,7 +1305,7 @@ end
ps = zeros(8)
setp(sys2, x)(ps, 2ones(2))
setp(sys2, p)(ps, 2ones(2, 2))
@test_nowarn fn2(ones(4), 2ones(6), 4.0)
@test_nowarn fn2(ones(4), 2ones(14), 4.0)
end

# https://github.com/SciML/ModelingToolkit.jl/issues/2969
Expand Down Expand Up @@ -1416,7 +1416,7 @@ end
obsfn = ModelingToolkit.build_explicit_observed_function(
sys1, u + x + p[1:2]; inputs = [x...])

@test obsfn(ones(2), 2ones(2), 3ones(4), 4.0) == 6ones(2)
@test obsfn(ones(2), 2ones(2), 3ones(12), 4.0) == 6ones(2)
end

@testset "Passing `nothing` to `u0`" begin
Expand Down
Loading