Skip to content

Commit 9611e18

Browse files
Merge pull request #3040 from AayushSabharwal/as/fix-build-obs
fix: fix missing `unwrap` in `build_explicit_observed_function`
2 parents 58ac094 + 986b6bc commit 9611e18

File tree

2 files changed

+23
-3
lines changed

2 files changed

+23
-3
lines changed

src/systems/diffeqs/odesystem.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -487,14 +487,14 @@ function build_explicit_observed_function(sys, ts;
487487
end
488488
_ps = ps
489489
if ps isa Tuple
490-
ps = DestructuredArgs.(ps, inbounds = !checkbounds)
490+
ps = DestructuredArgs.(unwrap.(ps), inbounds = !checkbounds)
491491
elseif has_index_cache(sys) && get_index_cache(sys) !== nothing
492-
ps = DestructuredArgs.(reorder_parameters(get_index_cache(sys), ps))
492+
ps = DestructuredArgs.(reorder_parameters(get_index_cache(sys), unwrap.(ps)))
493493
if isempty(ps) && inputs !== nothing
494494
ps = (:EMPTY,)
495495
end
496496
else
497-
ps = (DestructuredArgs(ps, inbounds = !checkbounds),)
497+
ps = (DestructuredArgs(unwrap.(ps), inbounds = !checkbounds),)
498498
end
499499
dvs = DestructuredArgs(unknowns(sys), inbounds = !checkbounds)
500500
if inputs === nothing

test/odesystem.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1367,3 +1367,23 @@ end
13671367
@test length(ModelingToolkit.guesses(sys2)) == 3
13681368
@test ModelingToolkit.guesses(sys2)[p5] == 10.0
13691369
end
1370+
1371+
@testset "Observed with inputs" begin
1372+
@variables u(t)[1:2] x(t)[1:2] o(t)[1:2]
1373+
@parameters p[1:4]
1374+
1375+
eqs = [D(u[1]) ~ p[1] * u[1] - p[2] * u[1] * u[2] + x[1] + 0.1
1376+
D(u[2]) ~ p[4] * u[1] * u[2] - p[3] * u[2] - x[2]
1377+
o[1] ~ sum(p) * sum(u)
1378+
o[2] ~ sum(p) * sum(x)]
1379+
1380+
@named sys = ODESystem(eqs, t, [u..., x..., o], [p...])
1381+
sys1, = structural_simplify(sys, ([x...], [o...]), split = false)
1382+
1383+
@test_nowarn ModelingToolkit.build_explicit_observed_function(sys1, u; inputs = [x...])
1384+
1385+
obsfn = ModelingToolkit.build_explicit_observed_function(
1386+
sys1, u + x + p[1:2]; inputs = [x...])
1387+
1388+
@test obsfn(ones(2), 2ones(2), 3ones(4), 4.0) == 6ones(2)
1389+
end

0 commit comments

Comments
 (0)