Skip to content

Commit a5adf3d

Browse files
Merge pull request #3703 from AayushSabharwal/as/discrete-simplify-2
fix: fix discrete system simplification
2 parents 5d06442 + adfc91c commit a5adf3d

File tree

12 files changed

+265
-125
lines changed

12 files changed

+265
-125
lines changed

src/problems/discreteproblem.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,11 +43,11 @@ end
4343
check_compatibility && check_compatible_system(DiscreteProblem, sys)
4444

4545
dvs = unknowns(sys)
46-
u0map = to_varmap(op, dvs)
47-
add_toterms!(u0map; replace = true)
46+
op = to_varmap(op, dvs)
47+
add_toterms!(op; replace = true)
4848
f, u0, p = process_SciMLProblem(DiscreteFunction{iip, spec}, sys, op;
4949
t = tspan !== nothing ? tspan[1] : tspan, check_compatibility, expression,
50-
build_initializeprob = false, kwargs...)
50+
kwargs...)
5151

5252
if expression == Val{true}
5353
u0 = :(f($u0, p, tspan[1]))

src/structural_transformation/StructuralTransformations.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ using ModelingToolkit: System, AbstractSystem, var_from_nested_derivative, Diffe
2222
ExtraEquationsSystemException,
2323
ExtraVariablesSystemException,
2424
vars!, invalidate_cache!,
25+
vars!, invalidate_cache!, Shift,
2526
IncrementalCycleTracker, add_edge_checked!, topological_sort,
2627
filter_kwargs, lower_varname_with_unit,
2728
lower_shift_varname_with_unit, setio, SparseMatrixCLIL,
@@ -39,7 +40,7 @@ using ModelingToolkit: algeqs, EquationsView,
3940
dervars_range, diffvars_range, algvars_range,
4041
DiffGraph, complete!,
4142
get_fullvars, system_subset
42-
using SymbolicIndexingInterface: symbolic_type, ArraySymbolic
43+
using SymbolicIndexingInterface: symbolic_type, ArraySymbolic, NotSymbolic
4344

4445
using ModelingToolkit.DiffEqBase
4546
using ModelingToolkit.StaticArrays

src/structural_transformation/symbolics_tearing.jl

Lines changed: 188 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -421,6 +421,23 @@ function generate_derivative_variables!(
421421
for (i, idxs) in idxs_to_remove
422422
deleteat!(var_sccs[i], idxs)
423423
end
424+
new_sccs = insert_sccs(var_sccs, sccs_to_insert)
425+
426+
if mm !== nothing
427+
@set! mm.ncols = ndsts(graph)
428+
end
429+
430+
return new_sccs
431+
end
432+
433+
"""
434+
$(TYPEDSIGNATURES)
435+
436+
Given a list of SCCs and a list of SCCs to insert at specific indices, insert them and
437+
return the new SCC vector.
438+
"""
439+
function insert_sccs(
440+
var_sccs::Vector{Vector{Int}}, sccs_to_insert::Vector{Tuple{Int, Vector{Int}}})
424441
# insert the new SCCs, accounting for the fact that we might have multiple entries
425442
# in `sccs_to_insert` to be inserted at the same index.
426443
old_idx = 1
@@ -441,10 +458,6 @@ function generate_derivative_variables!(
441458
end
442459

443460
filter!(!isempty, new_sccs)
444-
if mm !== nothing
445-
@set! mm.ncols = ndsts(graph)
446-
end
447-
448461
return new_sccs
449462
end
450463

@@ -742,7 +755,17 @@ function codegen_equation!(eg::EquationGenerator,
742755
@unpack fullvars, sys, structure = state
743756
@unpack solvable_graph, var_to_diff, eq_to_diff, graph = structure
744757
diff_to_var = invview(var_to_diff)
745-
if is_solvable(eg, ieq, iv) && is_dervar(eg, iv)
758+
759+
issolvable = is_solvable(eg, ieq, iv)
760+
isdervar = issolvable && is_dervar(eg, iv)
761+
isdisc = is_only_discrete(structure)
762+
# The variable is derivative variable and the "most differentiated"
763+
# This is only used for discrete systems, and basically refers to
764+
# `Shift(t, 1)(x(k))` in `Shift(t, 1)(x(k)) ~ x(k) + x(k-1)`. As illustrated in
765+
# the docstring for `add_additional_history!`, this is an exception and needs to be
766+
# treated like a solved equation rather than a differential equation.
767+
is_highest_diff = iv isa Int && isdervar && var_to_diff[iv] === nothing
768+
if issolvable && isdervar && (!isdisc || !is_highest_diff)
746769
var = fullvars[iv]
747770
isnothing(D) && throw(UnexpectedDifferentialError(equations(sys)[ieq]))
748771
order, lv = var_order(iv, diff_to_var)
@@ -762,15 +785,25 @@ function codegen_equation!(eg::EquationGenerator,
762785
push!(neweqs′, neweq)
763786
push!(eq_ordering, ieq)
764787
push!(var_ordering, diff_to_var[iv])
765-
elseif is_solvable(eg, ieq, iv)
788+
elseif issolvable
766789
var = fullvars[iv]
767790
neweq = make_solved_equation(var, eq, total_sub; simplify)
768791
if neweq !== nothing
792+
# backshift solved equations to calculate the value of the variable at the
793+
# current time. This works because we added one additional history element
794+
# in `add_additional_history!`.
795+
if isdisc
796+
neweq = backshift_expr(neweq, idep)
797+
end
769798
push!(solved_eqs, neweq)
770799
push!(solved_vars, iv)
771800
end
772801
else
773802
neweq = make_algebraic_equation(eq, total_sub)
803+
# For the same reason as solved equations (they are effectively the same)
804+
if isdisc
805+
neweq = backshift_expr(neweq, idep)
806+
end
774807
push!(neweqs′, neweq)
775808
push!(eq_ordering, ieq)
776809
# we push a dummy to `var_ordering` here because `iv` is `unassigned`
@@ -896,9 +929,24 @@ Update the system equations, unknowns, and observables after simplification.
896929
"""
897930
function update_simplified_system!(
898931
state::TearingState, neweqs, solved_eqs, dummy_sub, var_sccs, extra_unknowns;
899-
cse_hack = true, array_hack = true)
900-
@unpack solvable_graph, var_to_diff, eq_to_diff, graph = state.structure
932+
cse_hack = true, array_hack = true, D = nothing, iv = nothing)
933+
@unpack fullvars, structure = state
934+
@unpack solvable_graph, var_to_diff, eq_to_diff, graph = structure
901935
diff_to_var = invview(var_to_diff)
936+
# Since we solved the highest order derivative varible in discrete systems,
937+
# we make a list of the solved variables and avoid including them in the
938+
# unknowns.
939+
solved_vars = Set()
940+
if is_only_discrete(structure)
941+
for eq in solved_eqs
942+
var = eq.lhs
943+
if isequal(eq.lhs, eq.rhs)
944+
var = lower_shift_varname_with_unit(D(eq.lhs), iv)
945+
end
946+
push!(solved_vars, var)
947+
end
948+
filter!(eq -> !isequal(eq.lhs, eq.rhs), solved_eqs)
949+
end
902950

903951
ispresent = let var_to_diff = var_to_diff, graph = graph
904952
i -> (!isempty(𝑑neighbors(graph, i)) ||
@@ -915,9 +963,19 @@ function update_simplified_system!(
915963
obs = [fast_substitute(observed(sys), obs_sub); solved_eqs]
916964

917965
unknown_idxs = filter(
918-
i -> diff_to_var[i] === nothing && ispresent(i), eachindex(state.fullvars))
966+
i -> diff_to_var[i] === nothing && ispresent(i) && !(fullvars[i] in solved_vars), eachindex(state.fullvars))
919967
unknowns = state.fullvars[unknown_idxs]
920968
unknowns = [unknowns; extra_unknowns]
969+
if is_only_discrete(structure)
970+
# Algebraic variables are shifted forward by one, so we backshift them.
971+
unknowns = map(enumerate(unknowns)) do (i, var)
972+
if iscall(var) && operation(var) isa Shift && operation(var).steps == 1
973+
backshift_expr(var, iv)
974+
else
975+
var
976+
end
977+
end
978+
end
921979
@set! sys.unknowns = unknowns
922980

923981
obs = cse_and_array_hacks(
@@ -979,7 +1037,8 @@ differential variables.
9791037
function tearing_reassemble(state::TearingState, var_eq_matching::Matching,
9801038
full_var_eq_matching::Matching, var_sccs::Vector{Vector{Int}}; simplify = false, mm, cse_hack = true,
9811039
array_hack = true, fully_determined = true)
982-
extra_eqs_vars = get_extra_eqs_vars(state, full_var_eq_matching, fully_determined)
1040+
extra_eqs_vars = get_extra_eqs_vars(
1041+
state, var_eq_matching, full_var_eq_matching, fully_determined)
9831042
neweqs = collect(equations(state))
9841043
dummy_sub = Dict()
9851044

@@ -995,6 +1054,11 @@ function tearing_reassemble(state::TearingState, var_eq_matching::Matching,
9951054
end
9961055

9971056
extra_unknowns = state.fullvars[extra_eqs_vars[2]]
1057+
if is_only_discrete(state.structure)
1058+
var_sccs = add_additional_history!(
1059+
state, neweqs, var_eq_matching, full_var_eq_matching, var_sccs; iv, D)
1060+
end
1061+
9981062
# Structural simplification
9991063
substitute_derivatives_algevars!(state, neweqs, var_eq_matching, dummy_sub; iv, D)
10001064

@@ -1010,13 +1074,121 @@ function tearing_reassemble(state::TearingState, var_eq_matching::Matching,
10101074
# var_eq_matching and full_var_eq_matching are now invalidated
10111075

10121076
sys = update_simplified_system!(state, neweqs, solved_eqs, dummy_sub, var_sccs,
1013-
extra_unknowns; cse_hack, array_hack)
1077+
extra_unknowns; cse_hack, array_hack, iv, D)
10141078

10151079
@set! state.sys = sys
10161080
@set! sys.tearing_state = state
10171081
return invalidate_cache!(sys)
10181082
end
10191083

1084+
"""
1085+
$(TYPEDSIGNATURES)
1086+
1087+
Add one more history equation for discrete systems. For example, if we have
1088+
1089+
```julia
1090+
Shift(t, 1)(x(k-1)) ~ x(k)
1091+
Shift(t, 1)(x(k)) ~ x(k) + x(k-1)
1092+
```
1093+
1094+
This turns it into
1095+
1096+
```julia
1097+
Shift(t, 1)(x(k-2)) ~ x(k-1)
1098+
Shift(t, 1)(x(k-1)) ~ x(k)
1099+
Shift(t, 1)(x(k)) ~ x(k) + x(k-1)
1100+
```
1101+
1102+
Thus adding an additional unknown as well. Later, the highest derivative equation will
1103+
be backshifted by one and turned into an observed equation, resulting in:
1104+
1105+
```julia
1106+
Shift(t, 1)(x(k-2)) ~ x(k-1)
1107+
Shift(t, 1)(x(k-1)) ~ x(k)
1108+
1109+
x(k) ~ x(k-1) + x(k-2)
1110+
```
1111+
1112+
Where the last equation is the observed equation.
1113+
"""
1114+
function add_additional_history!(
1115+
state::TearingState, neweqs::Vector, var_eq_matching::Matching,
1116+
full_var_eq_matching::Matching, var_sccs::Vector{Vector{Int}}; iv, D)
1117+
@unpack fullvars, sys, structure = state
1118+
@unpack solvable_graph, var_to_diff, eq_to_diff, graph = structure
1119+
eq_var_matching = invview(var_eq_matching)
1120+
diff_to_var = invview(var_to_diff)
1121+
is_discrete = is_only_discrete(structure)
1122+
digraph = DiCMOBiGraph{false}(graph, var_eq_matching)
1123+
1124+
# We need the inverse mapping of `var_sccs` to update it efficiently later.
1125+
v_to_scc = Vector{NTuple{2, Int}}(undef, ndsts(graph))
1126+
for (i, scc) in enumerate(var_sccs), (j, v) in enumerate(scc)
1127+
v_to_scc[v] = (i, j)
1128+
end
1129+
1130+
vars_to_backshift = BitSet()
1131+
eqs_to_backshift = BitSet()
1132+
# add history for differential variables
1133+
for ivar in 1:length(fullvars)
1134+
ieq = var_eq_matching[ivar]
1135+
# the variable to backshift is a state variable which is not the
1136+
# derivative of any other one.
1137+
ieq isa SelectedState || continue
1138+
diff_to_var[ivar] === nothing || continue
1139+
push!(vars_to_backshift, ivar)
1140+
end
1141+
1142+
inserts = Tuple{Int, Vector{Int}}[]
1143+
1144+
for var in vars_to_backshift
1145+
add_backshifted_var!(state, var, iv)
1146+
# all backshifted vars are differential vars, hence SelectedState
1147+
push!(var_eq_matching, SelectedState())
1148+
push!(full_var_eq_matching, unassigned)
1149+
# add to the SCCs right before the variable that was backshifted
1150+
push!(inserts, (v_to_scc[var][1], [length(fullvars)]))
1151+
end
1152+
1153+
sort!(inserts, by = first)
1154+
new_sccs = insert_sccs(var_sccs, inserts)
1155+
return new_sccs
1156+
end
1157+
1158+
"""
1159+
$(TYPEDSIGNATURES)
1160+
1161+
Add the backshifted version of variable `ivar` to the system.
1162+
"""
1163+
function add_backshifted_var!(state::TearingState, ivar::Int, iv)
1164+
@unpack fullvars, structure = state
1165+
@unpack var_to_diff, graph, solvable_graph = structure
1166+
1167+
var = fullvars[ivar]
1168+
newvar = simplify_shifts(Shift(iv, -1)(var))
1169+
push!(fullvars, newvar)
1170+
inewvar = add_vertex!(var_to_diff)
1171+
add_edge!(var_to_diff, inewvar, ivar)
1172+
add_vertex!(graph, DST)
1173+
add_vertex!(solvable_graph, DST)
1174+
return inewvar
1175+
end
1176+
1177+
"""
1178+
$(TYPEDSIGNATURES)
1179+
1180+
Backshift the given expression `ex`.
1181+
"""
1182+
function backshift_expr(ex, iv)
1183+
ex isa Symbolic || return ex
1184+
return descend_lower_shift_varname_with_unit(
1185+
simplify_shifts(distribute_shift(Shift(iv, -1)(ex))), iv)
1186+
end
1187+
1188+
function backshift_expr(ex::Equation, iv)
1189+
return backshift_expr(ex.lhs, iv) ~ backshift_expr(ex.rhs, iv)
1190+
end
1191+
10201192
"""
10211193
$(TYPEDSIGNATURES)
10221194
@@ -1025,7 +1197,7 @@ respectively. For fully-determined systems, both of these are empty. Overdetermi
10251197
have extra equations, and underdetermined systems have extra variables.
10261198
"""
10271199
function get_extra_eqs_vars(
1028-
state::TearingState, full_var_eq_matching::Matching, fully_determined::Bool)
1200+
state::TearingState, var_eq_matching::Matching, full_var_eq_matching::Matching, fully_determined::Bool)
10291201
fully_determined && return Int[], Int[]
10301202

10311203
extra_eqs = Int[]
@@ -1035,6 +1207,10 @@ function get_extra_eqs_vars(
10351207
for v in 𝑑vertices(state.structure.graph)
10361208
eq = full_var_eq_matching[v]
10371209
eq isa Int && continue
1210+
# Only if the variable is also unmatched in `var_eq_matching`.
1211+
# Otherwise, `SelectedState` differential variables from order lowering
1212+
# are also considered "extra"
1213+
var_eq_matching[v] === unassigned || continue
10381214
push!(extra_vars, v)
10391215
end
10401216
for eq in 𝑠vertices(state.structure.graph)

src/structural_transformation/utils.jl

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -482,11 +482,29 @@ function lower_shift_varname(var, iv)
482482
end
483483
end
484484

485+
function descend_lower_shift_varname_with_unit(var, iv)
486+
symbolic_type(var) == NotSymbolic() && return var
487+
ModelingToolkit._with_unit(descend_lower_shift_varname, var, iv, iv)
488+
end
489+
function descend_lower_shift_varname(var, iv)
490+
iscall(var) || return var
491+
op = operation(var)
492+
if op isa Shift
493+
return shift2term(var)
494+
else
495+
args = arguments(var)
496+
args = map(Base.Fix2(descend_lower_shift_varname, iv), args)
497+
return maketerm(typeof(var), op, args, Symbolics.metadata(var))
498+
end
499+
end
500+
485501
"""
486502
Rename a Shift variable with negative shift, Shift(t, k)(x(t)) to xₜ₋ₖ(t).
487503
"""
488504
function shift2term(var)
505+
iscall(var) || return var
489506
op = operation(var)
507+
op isa Shift || return var
490508
iv = op.t
491509
arg = only(arguments(var))
492510
if operation(arg) === getindex
@@ -575,10 +593,10 @@ end
575593
function _distribute_shift(expr, shift)
576594
if iscall(expr)
577595
op = operation(expr)
578-
(op isa Pre || op isa Initial) && return expr
596+
(op isa Union{Pre, Initial, Sample, Hold}) && return expr
579597
args = arguments(expr)
580598

581-
if ModelingToolkit.isvariable(expr)
599+
if ModelingToolkit.isvariable(expr) && operation(expr) !== getindex
582600
(length(args) == 1 && isequal(shift.t, only(args))) ? (return shift(expr)) :
583601
(return expr)
584602
elseif op isa Shift

src/systems/abstractsystem.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -503,7 +503,7 @@ function (f::Initial)(x)
503503
return x
504504
end
505505
# differential variables are default-toterm-ed
506-
if iscall(x) && operation(x) isa Differential
506+
if iscall(x) && operation(x) isa Union{Differential, Shift}
507507
x = default_toterm(x)
508508
end
509509
# don't double wrap
@@ -760,6 +760,7 @@ for prop in [:eqs
760760
:metadata
761761
:gui_metadata
762762
:is_initializesystem
763+
:is_discrete
763764
:parameter_dependencies
764765
:assertions
765766
:ignored_connections

src/systems/callbacks.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -259,7 +259,7 @@ function make_affect(affect::Vector{Equation}; discrete_parameters = Any[],
259259

260260
@named affectsys = System(
261261
vcat(affect, alg_eqs), iv, collect(union(_dvs, discretes)),
262-
collect(union(pre_params, sys_params)))
262+
collect(union(pre_params, sys_params)); is_discrete = true)
263263
affectsys = mtkcompile(affectsys; fully_determined = nothing)
264264
# get accessed parameters p from Pre(p) in the callback parameters
265265
accessed_params = Vector{Any}(filter(isparameter, map(unPre, collect(pre_params))))

0 commit comments

Comments
 (0)