Skip to content

Commit 02cbd76

Browse files
Merge pull request #3202 from AayushSabharwal/as/tuple-observed
feat: support directly generating observed functions for tuples
2 parents 9cf859f + 8edc6b0 commit 02cbd76

File tree

4 files changed

+36
-3
lines changed

4 files changed

+36
-3
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ SimpleNonlinearSolve = "0.1.0, 1, 2"
131131
SparseArrays = "1"
132132
SpecialFunctions = "0.7, 0.8, 0.9, 0.10, 1.0, 2"
133133
StaticArrays = "0.10, 0.11, 0.12, 1.0"
134-
SymbolicIndexingInterface = "0.3.31"
134+
SymbolicIndexingInterface = "0.3.35"
135135
SymbolicUtils = "3.7"
136136
Symbolics = "6.15.4"
137137
URIs = "1"

src/systems/abstractsystem.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -818,6 +818,8 @@ function SymbolicIndexingInterface.is_observed(sys::AbstractSystem, sym)
818818
!is_independent_variable(sys, sym) && symbolic_type(sym) != NotSymbolic()
819819
end
820820

821+
SymbolicIndexingInterface.supports_tuple_observed(::AbstractSystem) = true
822+
821823
function SymbolicIndexingInterface.observed(
822824
sys::AbstractSystem, sym; eval_expression = false, eval_module = @__MODULE__)
823825
if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing
@@ -827,7 +829,8 @@ function SymbolicIndexingInterface.observed(
827829
throw(ArgumentError("Symbol $sym does not exist in the system"))
828830
end
829831
sym = _sym
830-
elseif sym isa AbstractArray && symbolic_type(sym) isa NotSymbolic &&
832+
elseif (sym isa Tuple ||
833+
(sym isa AbstractArray && symbolic_type(sym) isa NotSymbolic)) &&
831834
any(x -> x isa Symbol, sym)
832835
sym = map(sym) do s
833836
if s isa Symbol

src/systems/diffeqs/odesystem.jl

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -427,6 +427,10 @@ function build_explicit_observed_function(sys, ts;
427427
param_only = false,
428428
op = Operator,
429429
throw = true)
430+
is_tuple = ts isa Tuple
431+
if is_tuple
432+
ts = collect(ts)
433+
end
430434
if (isscalar = symbolic_type(ts) !== NotSymbolic())
431435
ts = [ts]
432436
end
@@ -573,9 +577,16 @@ function build_explicit_observed_function(sys, ts;
573577

574578
# Need to keep old method of building the function since it uses `output_type`,
575579
# which can't be provided to `build_function`
580+
return_value = if isscalar
581+
ts[1]
582+
elseif is_tuple
583+
MakeTuple(Tuple(ts))
584+
else
585+
MakeArray(ts, output_type)
586+
end
576587
oop_fn = Func(args, [],
577588
pre(Let(obsexprs,
578-
isscalar ? ts[1] : MakeArray(ts, output_type),
589+
return_value,
579590
false))) |> array_wrapper[1] |> oop_mtkp_wrapper |> toexpr
580591
oop_fn = expression ? oop_fn : eval_or_rgf(oop_fn; eval_expression, eval_module)
581592

test/symbolic_indexing_interface.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ using SciMLStructures: Tunable
88
eqs = [D(x) ~ a * y + t, D(y) ~ b * t]
99
@named odesys = ODESystem(eqs, t, [x, y], [a, b]; observed = [xy ~ x + y])
1010
odesys = complete(odesys)
11+
@test SymbolicIndexingInterface.supports_tuple_observed(odesys)
1112
@test all(is_variable.((odesys,), [x, y, 1, 2, :x, :y]))
1213
@test all(.!is_variable.((odesys,), [a, b, t, 3, 0, :a, :b]))
1314
@test variable_index.((odesys,), [x, y, a, b, t, 1, 2, :x, :y, :a, :b]) ==
@@ -33,6 +34,14 @@ using SciMLStructures: Tunable
3334
@test default_values(odesys)[y] == 2.0
3435
@test isequal(default_values(odesys)[xy], x + y)
3536

37+
prob = ODEProblem(odesys, [], (0.0, 1.0), [a => 1.0, b => 2.0])
38+
getter = getu(odesys, (x + 1, x + 2))
39+
@test getter(prob) isa Tuple
40+
@test_nowarn @inferred getter(prob)
41+
getter = getp(odesys, (a + 1, a + 2))
42+
@test getter(prob) isa Tuple
43+
@test_nowarn @inferred getter(prob)
44+
3645
@named odesys = ODESystem(
3746
eqs, t, [x, y], [a, b]; defaults = [xy => 3.0], observed = [xy ~ x + y])
3847
odesys = complete(odesys)
@@ -99,6 +108,7 @@ end
99108
0 ~ x * y - β * z]
100109
@named ns = NonlinearSystem(eqs, [x, y, z], [σ, ρ, β])
101110
ns = complete(ns)
111+
@test SymbolicIndexingInterface.supports_tuple_observed(ns)
102112
@test !is_time_dependent(ns)
103113
ps = ModelingToolkit.MTKParameters(ns, [σ => 1.0, ρ => 2.0, β => 3.0])
104114
pobs = parameter_observed(ns, σ + ρ)
@@ -107,6 +117,15 @@ end
107117
pobs = parameter_observed(ns, [σ + ρ, ρ + β])
108118
@test isempty(get_all_timeseries_indexes(ns, [σ + ρ, ρ + β]))
109119
@test pobs(ps) == [3.0, 5.0]
120+
121+
prob = NonlinearProblem(
122+
ns, [x => 1.0, y => 2.0, z => 3.0], [σ => 1.0, ρ => 2.0, β => 3.0])
123+
getter = getu(ns, (x + 1, x + 2))
124+
@test getter(prob) isa Tuple
125+
@test_nowarn @inferred getter(prob)
126+
getter = getp(ns, (σ + 1, σ + 2))
127+
@test getter(prob) isa Tuple
128+
@test_nowarn @inferred getter(prob)
110129
end
111130

112131
@testset "PDESystem" begin

0 commit comments

Comments
 (0)