Skip to content

Commit f299353

Browse files
Merge pull request #2605 from AayushSabharwal/as/fix-array-observed
fix: fix callback codegen, observed eqs with non-scalarized symbolic arrays
2 parents 2a0938c + d927e04 commit f299353

File tree

9 files changed

+113
-34
lines changed

9 files changed

+113
-34
lines changed

src/structural_transformation/symbolics_tearing.jl

Lines changed: 36 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -551,7 +551,43 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
551551
end
552552

553553
sys = state.sys
554+
555+
obs_sub = dummy_sub
556+
for eq in neweqs
557+
isdiffeq(eq) || continue
558+
obs_sub[eq.lhs] = eq.rhs
559+
end
560+
# TODO: compute the dependency correctly so that we don't have to do this
561+
obs = [fast_substitute(observed(sys), obs_sub); subeqs]
562+
563+
# HACK: Substitute non-scalarized symbolic arrays of observed variables
564+
# E.g. if `p[1] ~ (...)` and `p[2] ~ (...)` then substitute `p => [p[1], p[2]]` in all equations
565+
# ideally, we want to support equations such as `p ~ [p[1], p[2]]` which will then be handled
566+
# by the topological sorting and dependency identification pieces
567+
obs_arr_subs = Dict()
568+
569+
for eq in obs
570+
lhs = eq.lhs
571+
istree(lhs) || continue
572+
operation(lhs) === getindex || continue
573+
Symbolics.shape(lhs) !== Symbolics.Unknown() || continue
574+
arg1 = arguments(lhs)[1]
575+
haskey(obs_arr_subs, arg1) && continue
576+
obs_arr_subs[arg1] = [arg1[i] for i in eachindex(arg1)]
577+
end
578+
for i in eachindex(neweqs)
579+
neweqs[i] = fast_substitute(neweqs[i], obs_arr_subs; operator = Symbolics.Operator)
580+
end
581+
for i in eachindex(obs)
582+
obs[i] = fast_substitute(obs[i], obs_arr_subs; operator = Symbolics.Operator)
583+
end
584+
for i in eachindex(subeqs)
585+
subeqs[i] = fast_substitute(subeqs[i], obs_arr_subs; operator = Symbolics.Operator)
586+
end
587+
554588
@set! sys.eqs = neweqs
589+
@set! sys.observed = obs
590+
555591
unknowns = Any[v
556592
for (i, v) in enumerate(fullvars)
557593
if diff_to_var[i] === nothing && ispresent(i)]
@@ -563,15 +599,6 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
563599
@set! sys.unknowns = unknowns
564600
@set! sys.substitutions = Substitutions(subeqs, deps)
565601

566-
obs_sub = dummy_sub
567-
for eq in equations(sys)
568-
isdiffeq(eq) || continue
569-
obs_sub[eq.lhs] = eq.rhs
570-
end
571-
# TODO: compute the dependency correctly so that we don't have to do this
572-
obs = [fast_substitute(observed(sys), obs_sub); subeqs]
573-
@set! sys.observed = obs
574-
575602
# Only makes sense for time-dependent
576603
# TODO: generalize to SDE
577604
if sys isa ODESystem

src/systems/abstractsystem.jl

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -153,15 +153,15 @@ generate_custom_function(sys::AbstractSystem, exprs, dvs = unknowns(sys),
153153
154154
Generate a function to evaluate `exprs`. `exprs` is a symbolic expression or
155155
array of symbolic expression involving symbolic variables in `sys`. The symbolic variables
156-
may be subsetted using `dvs` and `ps`. All `kwargs` except `postprocess_fbody` and `states`
157-
are passed to the internal [`build_function`](@ref) call. The returned function can be called
158-
as `f(u, p, t)` or `f(du, u, p, t)` for time-dependent systems and `f(u, p)` or `f(du, u, p)`
159-
for time-independent systems. If `split=true` (the default) was passed to [`complete`](@ref),
156+
may be subsetted using `dvs` and `ps`. All `kwargs` are passed to the internal
157+
[`build_function`](@ref) call. The returned function can be called as `f(u, p, t)` or
158+
`f(du, u, p, t)` for time-dependent systems and `f(u, p)` or `f(du, u, p)` for
159+
time-independent systems. If `split=true` (the default) was passed to [`complete`](@ref),
160160
[`structural_simplify`](@ref) or [`@mtkbuild`](@ref), `p` is expected to be an `MTKParameters`
161161
object.
162162
"""
163163
function generate_custom_function(sys::AbstractSystem, exprs, dvs = unknowns(sys),
164-
ps = parameters(sys); wrap_code = nothing, kwargs...)
164+
ps = parameters(sys); wrap_code = nothing, postprocess_fbody = nothing, states = nothing, kwargs...)
165165
if !iscomplete(sys)
166166
error("A completed system is required. Call `complete` or `structural_simplify` on the system.")
167167
end
@@ -170,16 +170,21 @@ function generate_custom_function(sys::AbstractSystem, exprs, dvs = unknowns(sys
170170
if wrap_code === nothing
171171
wrap_code = isscalar ? identity : (identity, identity)
172172
end
173-
pre, sol_states = get_substitutions_and_solved_unknowns(sys)
174-
173+
pre, sol_states = get_substitutions_and_solved_unknowns(sys, isscalar ? [exprs] : exprs)
174+
if postprocess_fbody === nothing
175+
postprocess_fbody = pre
176+
end
177+
if states === nothing
178+
states = sol_states
179+
end
175180
if is_time_dependent(sys)
176181
return build_function(exprs,
177182
dvs,
178183
p...,
179184
get_iv(sys);
180185
kwargs...,
181-
postprocess_fbody = pre,
182-
states = sol_states,
186+
postprocess_fbody,
187+
states,
183188
wrap_code = wrap_code .∘ wrap_mtkparameters(sys, isscalar) .∘
184189
wrap_array_vars(sys, exprs; dvs)
185190
)
@@ -188,8 +193,8 @@ function generate_custom_function(sys::AbstractSystem, exprs, dvs = unknowns(sys
188193
dvs,
189194
p...;
190195
kwargs...,
191-
postprocess_fbody = pre,
192-
states = sol_states,
196+
postprocess_fbody,
197+
states,
193198
wrap_code = wrap_code .∘ wrap_mtkparameters(sys, isscalar) .∘
194199
wrap_array_vars(sys, exprs; dvs)
195200
)

src/systems/callbacks.jl

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -388,6 +388,7 @@ function compile_affect(eqs::Vector{Equation}, sys, dvs, ps; outputidxs = nothin
388388
return (args...) -> () # We don't do anything in the callback, we're just after the event
389389
end
390390
else
391+
eqs = flatten_equations(eqs)
391392
rhss = map(x -> x.rhs, eqs)
392393
outvar = :u
393394
if outputidxs === nothing
@@ -457,7 +458,7 @@ end
457458

458459
function generate_rootfinding_callback(cbs, sys::AbstractODESystem, dvs = unknowns(sys),
459460
ps = full_parameters(sys); kwargs...)
460-
eqs = map(cb -> cb.eqs, cbs)
461+
eqs = map(cb -> flatten_equations(cb.eqs), cbs)
461462
num_eqs = length.(eqs)
462463
(isempty(eqs) || sum(num_eqs) == 0) && return nothing
463464
# fuse equations to create VectorContinuousCallback
@@ -471,12 +472,8 @@ function generate_rootfinding_callback(cbs, sys::AbstractODESystem, dvs = unknow
471472
rhss = map(x -> x.rhs, eqs)
472473
root_eq_vars = unique(collect(Iterators.flatten(map(ModelingToolkit.vars, rhss))))
473474

474-
u = map(x -> time_varying_as_func(value(x), sys), dvs)
475-
p = map.(x -> time_varying_as_func(value(x), sys), reorder_parameters(sys, ps))
476-
t = get_iv(sys)
477-
pre = get_preprocess_constants(rhss)
478-
rf_oop, rf_ip = build_function(rhss, u, p..., t; expression = Val{false},
479-
postprocess_fbody = pre, kwargs...)
475+
rf_oop, rf_ip = generate_custom_function(sys, rhss, dvs, ps; expression = Val{false},
476+
kwargs...)
480477

481478
affect_functions = map(cbs) do cb # Keep affect function separate
482479
eq_aff = affects(cb)
@@ -487,16 +484,16 @@ function generate_rootfinding_callback(cbs, sys::AbstractODESystem, dvs = unknow
487484
cond = function (u, t, integ)
488485
if DiffEqBase.isinplace(integ.sol.prob)
489486
tmp, = DiffEqBase.get_tmp_cache(integ)
490-
rf_ip(tmp, u, parameter_values(integ)..., t)
487+
rf_ip(tmp, u, parameter_values(integ), t)
491488
tmp[1]
492489
else
493-
rf_oop(u, parameter_values(integ)..., t)
490+
rf_oop(u, parameter_values(integ), t)
494491
end
495492
end
496493
ContinuousCallback(cond, affect_functions[])
497494
else
498495
cond = function (out, u, t, integ)
499-
rf_ip(out, u, parameter_values(integ)..., t)
496+
rf_ip(out, u, parameter_values(integ), t)
500497
end
501498

502499
# since there may be different number of conditions and affects,

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -857,6 +857,7 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
857857
varmap = u0map === nothing || isempty(u0map) || eltype(u0map) <: Number ?
858858
defaults(sys) :
859859
merge(defaults(sys), todict(u0map))
860+
varmap = canonicalize_varmap(varmap)
860861
varlist = collect(map(unwrap, dvs))
861862
missingvars = setdiff(varlist, collect(keys(varmap)))
862863

src/systems/optimization/constraints_system.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -226,12 +226,15 @@ function generate_canonical_form_lhss(sys)
226226
lhss = subs_constants([Symbolics.canonical_form(eq).lhs for eq in constraints(sys)])
227227
end
228228

229-
function get_cmap(sys::ConstraintsSystem)
229+
function get_cmap(sys::ConstraintsSystem, exprs = nothing)
230230
#Inject substitutions for constants => values
231231
cs = collect_constants([get_constraints(sys); get_observed(sys)]) #ctrls? what else?
232232
if !empty_substitutions(sys)
233233
cs = [cs; collect_constants(get_substitutions(sys).subs)]
234234
end
235+
if exprs !== nothing
236+
cs = [cs; collect_constants(exprs)]
237+
end
235238
# Swap constants for their values
236239
cmap = map(x -> x ~ getdefault(x), cs)
237240
return cmap, cs

src/utils.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -564,19 +564,22 @@ function empty_substitutions(sys)
564564
isnothing(subs) || isempty(subs.deps)
565565
end
566566

567-
function get_cmap(sys)
567+
function get_cmap(sys, exprs = nothing)
568568
#Inject substitutions for constants => values
569569
cs = collect_constants([get_eqs(sys); get_observed(sys)]) #ctrls? what else?
570570
if !empty_substitutions(sys)
571571
cs = [cs; collect_constants(get_substitutions(sys).subs)]
572572
end
573+
if exprs !== nothing
574+
cs = [cs; collect_constants(exprs)]
575+
end
573576
# Swap constants for their values
574577
cmap = map(x -> x ~ getdefault(x), cs)
575578
return cmap, cs
576579
end
577580

578-
function get_substitutions_and_solved_unknowns(sys; no_postprocess = false)
579-
cmap, cs = get_cmap(sys)
581+
function get_substitutions_and_solved_unknowns(sys, exprs = nothing; no_postprocess = false)
582+
cmap, cs = get_cmap(sys, exprs)
580583
if empty_substitutions(sys) && isempty(cs)
581584
sol_states = Code.LazyState()
582585
pre = no_postprocess ? (ex -> ex) : get_postprocess_fbody(sys)

src/variables.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,7 @@ function canonicalize_varmap(varmap; toterm = Symbolics.diff2term)
216216
if Symbolics.isarraysymbolic(k) && Symbolics.shape(k) !== Symbolics.Unknown()
217217
for i in eachindex(k)
218218
new_varmap[k[i]] = v[i]
219+
new_varmap[toterm(k[i])] = v[i]
219220
end
220221
end
221222
end

test/initial_values.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,3 +67,10 @@ varmap = Dict(p => ones(3), q => 2ones(3))
6767
cvarmap = ModelingToolkit.canonicalize_varmap(varmap)
6868
target_varmap = Dict(p => ones(3), q => 2ones(3), q[1] => 2.0, q[2] => 2.0, q[3] => 2.0)
6969
@test cvarmap == target_varmap
70+
71+
# Initialization of ODEProblem with dummy derivatives of multidimensional arrays
72+
# Issue#1283
73+
@variables z(t)[1:2, 1:2]
74+
eqs = [D(D(z)) ~ ones(2, 2)]
75+
@mtkbuild sys = ODESystem(eqs, t)
76+
@test_nowarn ODEProblem(sys, [z => zeros(2, 2), D(z) => ones(2, 2)], (0.0, 10.0))

test/odesystem.jl

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -995,3 +995,38 @@ let # Issue https://github.com/SciML/ModelingToolkit.jl/issues/2322
995995
sol = solve(prob, Rodas4())
996996
@test sol(1)[]0.6065307685451087 rtol=1e-4
997997
end
998+
999+
# Issue#2599
1000+
@variables x(t) y(t)
1001+
eqs = [D(x) ~ x * t, y ~ 2x]
1002+
@mtkbuild sys = ODESystem(eqs, t; continuous_events = [[y ~ 3] => [x ~ 2]])
1003+
prob = ODEProblem(sys, [x => 1.0], (0.0, 10.0))
1004+
@test_nowarn solve(prob, Tsit5())
1005+
1006+
# Issue#2383
1007+
@variables x(t)[1:3]
1008+
@parameters p[1:3, 1:3]
1009+
eqs = [
1010+
D(x) ~ p * x
1011+
]
1012+
@mtkbuild sys = ODESystem(eqs, t; continuous_events = [[norm(x) ~ 3.0] => [x ~ ones(3)]])
1013+
# array affect equations used to not work
1014+
prob1 = @test_nowarn ODEProblem(sys, [x => ones(3)], (0.0, 10.0), [p => ones(3, 3)])
1015+
sol1 = @test_nowarn solve(prob1, Tsit5())
1016+
1017+
# array condition equations also used to not work
1018+
@mtkbuild sys = ODESystem(
1019+
eqs, t; continuous_events = [[x ~ sqrt(3) * ones(3)] => [x ~ ones(3)]])
1020+
# array affect equations used to not work
1021+
prob2 = @test_nowarn ODEProblem(sys, [x => ones(3)], (0.0, 10.0), [p => ones(3, 3)])
1022+
sol2 = @test_nowarn solve(prob2, Tsit5())
1023+
1024+
@test sol1 sol2
1025+
1026+
# Requires fix in symbolics for `linear_expansion(p * x, D(y))`
1027+
@test_broken begin
1028+
@variables x(t)[1:3] y(t)
1029+
@parameters p[1:3, 1:3]
1030+
@test_nowarn @mtkbuild sys = ODESystem([D(x) ~ p * x, D(y) ~ x' * p * x], t)
1031+
@test_nowarn ODEProblem(sys, [x => ones(3), y => 2], (0.0, 10.0), [p => ones(3, 3)])
1032+
end

0 commit comments

Comments
 (0)