Skip to content

Commit 23323bc

Browse files
Merge pull request #2399 from SciML/staticarrays
Preserve staticarrays in the problem construction
2 parents 7e0917f + a0adfe9 commit 23323bc

File tree

5 files changed

+47
-3
lines changed

5 files changed

+47
-3
lines changed

src/parameters.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,11 @@ function split_parameters_by_type(ps)
100100
end
101101
tighten_types = x -> identity.(x)
102102
split_ps = tighten_types.(Base.Fix1(getindex, ps).(split_idxs))
103+
104+
if ps isa StaticArray
105+
parrs = map(x-> SArray{Tuple{size(x)...}}(x), split_ps)
106+
split_ps = SArray{Tuple{size(parrs)...}}(parrs)
107+
end
103108
if length(split_ps) == 1 #Tuple not needed, only 1 type
104109
return split_ps[1], split_idxs
105110
else

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -912,6 +912,10 @@ function DiffEqBase.ODEProblem(sys::AbstractODESystem, args...; kwargs...)
912912
ODEProblem{true}(sys, args...; kwargs...)
913913
end
914914

915+
function DiffEqBase.ODEProblem(sys::AbstractODESystem, u0map::StaticArray, args...; kwargs...)
916+
ODEProblem{false, SciMLBase.FullSpecialize}(sys, u0map, args...; kwargs...)
917+
end
918+
915919
function DiffEqBase.ODEProblem{true}(sys::AbstractODESystem, args...; kwargs...)
916920
ODEProblem{true, SciMLBase.AutoSpecialize}(sys, args...; kwargs...)
917921
end

src/variables.jl

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -75,10 +75,16 @@ function varmap_to_vars(varmap, varlist; defaults = Dict(), check = true,
7575
end
7676
end
7777

78-
# T = typeof(varmap)
79-
# We respect the input type (feature removed, not needed with Tuple support)
78+
# We respect the input type if it's a static array
79+
# otherwise canonicalize to a normal array
8080
# container_type = T <: Union{Dict,Tuple} ? Array : T
81-
container_type = Array
81+
if varmap isa StaticArray
82+
container_type = typeof(varmap)
83+
else
84+
container_type = Array
85+
end
86+
87+
@show container_type
8288

8389
vals = if eltype(varmap) <: Pair # `varmap` is a dict or an array of pairs
8490
varmap = todict(varmap)

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ end
3535
@safetestset "Reduction Test" include("reduction.jl")
3636
@safetestset "Split Parameters Test" include("split_parameters.jl")
3737
@safetestset "ODAEProblem Test" include("odaeproblem.jl")
38+
@safetestset "StaticArrays Test" include("static_arrays.jl")
3839
@safetestset "Components Test" include("components.jl")
3940
@safetestset "Model Parsing Test" include("model_parsing.jl")
4041
@safetestset "print_tree" include("print_tree.jl")

test/static_arrays.jl

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
using ModelingToolkit, SciMLBase, StaticArrays, Test
2+
3+
@parameters σ ρ β
4+
@variables t x(t) y(t) z(t)
5+
D = Differential(t)
6+
7+
eqs = [D(D(x)) ~ σ * (y - x),
8+
D(y) ~ x *- z) - y,
9+
D(z) ~ x * y - β * z]
10+
11+
@named sys = ODESystem(eqs)
12+
sys = structural_simplify(sys)
13+
14+
u0 = @SVector [D(x) => 2.0,
15+
x => 1.0,
16+
y => 0.0,
17+
z => 0.0]
18+
19+
p = @SVector=> 28.0,
20+
ρ => 10.0,
21+
β => 8 / 3]
22+
23+
tspan = (0.0, 100.0)
24+
prob_mtk = ODEProblem(sys, u0, tspan, p)
25+
26+
@test !SciMLBase.isinplace(prob_mtk)
27+
@test prob_mtk.u0 isa SArray
28+
@test prob_mtk.p isa SArray

0 commit comments

Comments
 (0)