Skip to content

fix: fix discrete system simplification #3703

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Jun 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/problems/discreteproblem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,11 @@ end
check_compatibility && check_compatible_system(DiscreteProblem, sys)

dvs = unknowns(sys)
u0map = to_varmap(op, dvs)
add_toterms!(u0map; replace = true)
op = to_varmap(op, dvs)
add_toterms!(op; replace = true)
f, u0, p = process_SciMLProblem(DiscreteFunction{iip, spec}, sys, op;
t = tspan !== nothing ? tspan[1] : tspan, check_compatibility, expression,
build_initializeprob = false, kwargs...)
kwargs...)

if expression == Val{true}
u0 = :(f($u0, p, tspan[1]))
Expand Down
3 changes: 2 additions & 1 deletion src/structural_transformation/StructuralTransformations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ using ModelingToolkit: System, AbstractSystem, var_from_nested_derivative, Diffe
ExtraEquationsSystemException,
ExtraVariablesSystemException,
vars!, invalidate_cache!,
vars!, invalidate_cache!, Shift,
IncrementalCycleTracker, add_edge_checked!, topological_sort,
filter_kwargs, lower_varname_with_unit,
lower_shift_varname_with_unit, setio, SparseMatrixCLIL,
Expand All @@ -39,7 +40,7 @@ using ModelingToolkit: algeqs, EquationsView,
dervars_range, diffvars_range, algvars_range,
DiffGraph, complete!,
get_fullvars, system_subset
using SymbolicIndexingInterface: symbolic_type, ArraySymbolic
using SymbolicIndexingInterface: symbolic_type, ArraySymbolic, NotSymbolic

using ModelingToolkit.DiffEqBase
using ModelingToolkit.StaticArrays
Expand Down
200 changes: 188 additions & 12 deletions src/structural_transformation/symbolics_tearing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,23 @@ function generate_derivative_variables!(
for (i, idxs) in idxs_to_remove
deleteat!(var_sccs[i], idxs)
end
new_sccs = insert_sccs(var_sccs, sccs_to_insert)

if mm !== nothing
@set! mm.ncols = ndsts(graph)
end

return new_sccs
end

"""
$(TYPEDSIGNATURES)

Given a list of SCCs and a list of SCCs to insert at specific indices, insert them and
return the new SCC vector.
"""
function insert_sccs(
var_sccs::Vector{Vector{Int}}, sccs_to_insert::Vector{Tuple{Int, Vector{Int}}})
# insert the new SCCs, accounting for the fact that we might have multiple entries
# in `sccs_to_insert` to be inserted at the same index.
old_idx = 1
Expand All @@ -441,10 +458,6 @@ function generate_derivative_variables!(
end

filter!(!isempty, new_sccs)
if mm !== nothing
@set! mm.ncols = ndsts(graph)
end

return new_sccs
end

Expand Down Expand Up @@ -742,7 +755,17 @@ function codegen_equation!(eg::EquationGenerator,
@unpack fullvars, sys, structure = state
@unpack solvable_graph, var_to_diff, eq_to_diff, graph = structure
diff_to_var = invview(var_to_diff)
if is_solvable(eg, ieq, iv) && is_dervar(eg, iv)

issolvable = is_solvable(eg, ieq, iv)
isdervar = issolvable && is_dervar(eg, iv)
isdisc = is_only_discrete(structure)
# The variable is derivative variable and the "most differentiated"
# This is only used for discrete systems, and basically refers to
# `Shift(t, 1)(x(k))` in `Shift(t, 1)(x(k)) ~ x(k) + x(k-1)`. As illustrated in
# the docstring for `add_additional_history!`, this is an exception and needs to be
# treated like a solved equation rather than a differential equation.
is_highest_diff = iv isa Int && isdervar && var_to_diff[iv] === nothing
if issolvable && isdervar && (!isdisc || !is_highest_diff)
var = fullvars[iv]
isnothing(D) && throw(UnexpectedDifferentialError(equations(sys)[ieq]))
order, lv = var_order(iv, diff_to_var)
Expand All @@ -762,15 +785,25 @@ function codegen_equation!(eg::EquationGenerator,
push!(neweqs′, neweq)
push!(eq_ordering, ieq)
push!(var_ordering, diff_to_var[iv])
elseif is_solvable(eg, ieq, iv)
elseif issolvable
var = fullvars[iv]
neweq = make_solved_equation(var, eq, total_sub; simplify)
if neweq !== nothing
# backshift solved equations to calculate the value of the variable at the
# current time. This works because we added one additional history element
# in `add_additional_history!`.
if isdisc
neweq = backshift_expr(neweq, idep)
end
push!(solved_eqs, neweq)
push!(solved_vars, iv)
end
else
neweq = make_algebraic_equation(eq, total_sub)
# For the same reason as solved equations (they are effectively the same)
if isdisc
neweq = backshift_expr(neweq, idep)
end
push!(neweqs′, neweq)
push!(eq_ordering, ieq)
# we push a dummy to `var_ordering` here because `iv` is `unassigned`
Expand Down Expand Up @@ -896,9 +929,24 @@ Update the system equations, unknowns, and observables after simplification.
"""
function update_simplified_system!(
state::TearingState, neweqs, solved_eqs, dummy_sub, var_sccs, extra_unknowns;
cse_hack = true, array_hack = true)
@unpack solvable_graph, var_to_diff, eq_to_diff, graph = state.structure
cse_hack = true, array_hack = true, D = nothing, iv = nothing)
@unpack fullvars, structure = state
@unpack solvable_graph, var_to_diff, eq_to_diff, graph = structure
diff_to_var = invview(var_to_diff)
# Since we solved the highest order derivative varible in discrete systems,
# we make a list of the solved variables and avoid including them in the
# unknowns.
solved_vars = Set()
if is_only_discrete(structure)
for eq in solved_eqs
var = eq.lhs
if isequal(eq.lhs, eq.rhs)
var = lower_shift_varname_with_unit(D(eq.lhs), iv)
end
push!(solved_vars, var)
end
filter!(eq -> !isequal(eq.lhs, eq.rhs), solved_eqs)
end

ispresent = let var_to_diff = var_to_diff, graph = graph
i -> (!isempty(𝑑neighbors(graph, i)) ||
Expand All @@ -915,9 +963,19 @@ function update_simplified_system!(
obs = [fast_substitute(observed(sys), obs_sub); solved_eqs]

unknown_idxs = filter(
i -> diff_to_var[i] === nothing && ispresent(i), eachindex(state.fullvars))
i -> diff_to_var[i] === nothing && ispresent(i) && !(fullvars[i] in solved_vars), eachindex(state.fullvars))
unknowns = state.fullvars[unknown_idxs]
unknowns = [unknowns; extra_unknowns]
if is_only_discrete(structure)
# Algebraic variables are shifted forward by one, so we backshift them.
unknowns = map(enumerate(unknowns)) do (i, var)
if iscall(var) && operation(var) isa Shift && operation(var).steps == 1
backshift_expr(var, iv)
else
var
end
end
end
@set! sys.unknowns = unknowns

obs = cse_and_array_hacks(
Expand Down Expand Up @@ -979,7 +1037,8 @@ differential variables.
function tearing_reassemble(state::TearingState, var_eq_matching::Matching,
full_var_eq_matching::Matching, var_sccs::Vector{Vector{Int}}; simplify = false, mm, cse_hack = true,
array_hack = true, fully_determined = true)
extra_eqs_vars = get_extra_eqs_vars(state, full_var_eq_matching, fully_determined)
extra_eqs_vars = get_extra_eqs_vars(
state, var_eq_matching, full_var_eq_matching, fully_determined)
neweqs = collect(equations(state))
dummy_sub = Dict()

Expand All @@ -995,6 +1054,11 @@ function tearing_reassemble(state::TearingState, var_eq_matching::Matching,
end

extra_unknowns = state.fullvars[extra_eqs_vars[2]]
if is_only_discrete(state.structure)
var_sccs = add_additional_history!(
state, neweqs, var_eq_matching, full_var_eq_matching, var_sccs; iv, D)
end

# Structural simplification
substitute_derivatives_algevars!(state, neweqs, var_eq_matching, dummy_sub; iv, D)

Expand All @@ -1010,13 +1074,121 @@ function tearing_reassemble(state::TearingState, var_eq_matching::Matching,
# var_eq_matching and full_var_eq_matching are now invalidated

sys = update_simplified_system!(state, neweqs, solved_eqs, dummy_sub, var_sccs,
extra_unknowns; cse_hack, array_hack)
extra_unknowns; cse_hack, array_hack, iv, D)

@set! state.sys = sys
@set! sys.tearing_state = state
return invalidate_cache!(sys)
end

"""
$(TYPEDSIGNATURES)

Add one more history equation for discrete systems. For example, if we have

```julia
Shift(t, 1)(x(k-1)) ~ x(k)
Shift(t, 1)(x(k)) ~ x(k) + x(k-1)
```

This turns it into

```julia
Shift(t, 1)(x(k-2)) ~ x(k-1)
Shift(t, 1)(x(k-1)) ~ x(k)
Shift(t, 1)(x(k)) ~ x(k) + x(k-1)
```

Thus adding an additional unknown as well. Later, the highest derivative equation will
be backshifted by one and turned into an observed equation, resulting in:

```julia
Shift(t, 1)(x(k-2)) ~ x(k-1)
Shift(t, 1)(x(k-1)) ~ x(k)

x(k) ~ x(k-1) + x(k-2)
```

Where the last equation is the observed equation.
"""
function add_additional_history!(
state::TearingState, neweqs::Vector, var_eq_matching::Matching,
full_var_eq_matching::Matching, var_sccs::Vector{Vector{Int}}; iv, D)
@unpack fullvars, sys, structure = state
@unpack solvable_graph, var_to_diff, eq_to_diff, graph = structure
eq_var_matching = invview(var_eq_matching)
diff_to_var = invview(var_to_diff)
is_discrete = is_only_discrete(structure)
digraph = DiCMOBiGraph{false}(graph, var_eq_matching)

# We need the inverse mapping of `var_sccs` to update it efficiently later.
v_to_scc = Vector{NTuple{2, Int}}(undef, ndsts(graph))
for (i, scc) in enumerate(var_sccs), (j, v) in enumerate(scc)
v_to_scc[v] = (i, j)
end

vars_to_backshift = BitSet()
eqs_to_backshift = BitSet()
# add history for differential variables
for ivar in 1:length(fullvars)
ieq = var_eq_matching[ivar]
# the variable to backshift is a state variable which is not the
# derivative of any other one.
ieq isa SelectedState || continue
diff_to_var[ivar] === nothing || continue
push!(vars_to_backshift, ivar)
end

inserts = Tuple{Int, Vector{Int}}[]

for var in vars_to_backshift
add_backshifted_var!(state, var, iv)
# all backshifted vars are differential vars, hence SelectedState
push!(var_eq_matching, SelectedState())
push!(full_var_eq_matching, unassigned)
# add to the SCCs right before the variable that was backshifted
push!(inserts, (v_to_scc[var][1], [length(fullvars)]))
end

sort!(inserts, by = first)
new_sccs = insert_sccs(var_sccs, inserts)
return new_sccs
end

"""
$(TYPEDSIGNATURES)

Add the backshifted version of variable `ivar` to the system.
"""
function add_backshifted_var!(state::TearingState, ivar::Int, iv)
@unpack fullvars, structure = state
@unpack var_to_diff, graph, solvable_graph = structure

var = fullvars[ivar]
newvar = simplify_shifts(Shift(iv, -1)(var))
push!(fullvars, newvar)
inewvar = add_vertex!(var_to_diff)
add_edge!(var_to_diff, inewvar, ivar)
add_vertex!(graph, DST)
add_vertex!(solvable_graph, DST)
return inewvar
end

"""
$(TYPEDSIGNATURES)

Backshift the given expression `ex`.
"""
function backshift_expr(ex, iv)
ex isa Symbolic || return ex
return descend_lower_shift_varname_with_unit(
simplify_shifts(distribute_shift(Shift(iv, -1)(ex))), iv)
end

function backshift_expr(ex::Equation, iv)
return backshift_expr(ex.lhs, iv) ~ backshift_expr(ex.rhs, iv)
end

"""
$(TYPEDSIGNATURES)

Expand All @@ -1025,7 +1197,7 @@ respectively. For fully-determined systems, both of these are empty. Overdetermi
have extra equations, and underdetermined systems have extra variables.
"""
function get_extra_eqs_vars(
state::TearingState, full_var_eq_matching::Matching, fully_determined::Bool)
state::TearingState, var_eq_matching::Matching, full_var_eq_matching::Matching, fully_determined::Bool)
fully_determined && return Int[], Int[]

extra_eqs = Int[]
Expand All @@ -1035,6 +1207,10 @@ function get_extra_eqs_vars(
for v in 𝑑vertices(state.structure.graph)
eq = full_var_eq_matching[v]
eq isa Int && continue
# Only if the variable is also unmatched in `var_eq_matching`.
# Otherwise, `SelectedState` differential variables from order lowering
# are also considered "extra"
var_eq_matching[v] === unassigned || continue
push!(extra_vars, v)
end
for eq in 𝑠vertices(state.structure.graph)
Expand Down
22 changes: 20 additions & 2 deletions src/structural_transformation/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -482,11 +482,29 @@ function lower_shift_varname(var, iv)
end
end

function descend_lower_shift_varname_with_unit(var, iv)
symbolic_type(var) == NotSymbolic() && return var
ModelingToolkit._with_unit(descend_lower_shift_varname, var, iv, iv)
end
function descend_lower_shift_varname(var, iv)
iscall(var) || return var
op = operation(var)
if op isa Shift
return shift2term(var)
else
args = arguments(var)
args = map(Base.Fix2(descend_lower_shift_varname, iv), args)
return maketerm(typeof(var), op, args, Symbolics.metadata(var))
end
end

"""
Rename a Shift variable with negative shift, Shift(t, k)(x(t)) to xₜ₋ₖ(t).
"""
function shift2term(var)
iscall(var) || return var
op = operation(var)
op isa Shift || return var
iv = op.t
arg = only(arguments(var))
if operation(arg) === getindex
Expand Down Expand Up @@ -575,10 +593,10 @@ end
function _distribute_shift(expr, shift)
if iscall(expr)
op = operation(expr)
(op isa Pre || op isa Initial) && return expr
(op isa Union{Pre, Initial, Sample, Hold}) && return expr
args = arguments(expr)

if ModelingToolkit.isvariable(expr)
if ModelingToolkit.isvariable(expr) && operation(expr) !== getindex
(length(args) == 1 && isequal(shift.t, only(args))) ? (return shift(expr)) :
(return expr)
elseif op isa Shift
Expand Down
3 changes: 2 additions & 1 deletion src/systems/abstractsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -503,7 +503,7 @@ function (f::Initial)(x)
return x
end
# differential variables are default-toterm-ed
if iscall(x) && operation(x) isa Differential
if iscall(x) && operation(x) isa Union{Differential, Shift}
x = default_toterm(x)
end
# don't double wrap
Expand Down Expand Up @@ -760,6 +760,7 @@ for prop in [:eqs
:metadata
:gui_metadata
:is_initializesystem
:is_discrete
:parameter_dependencies
:assertions
:ignored_connections
Expand Down
2 changes: 1 addition & 1 deletion src/systems/callbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ function make_affect(affect::Vector{Equation}; discrete_parameters = Any[],

@named affectsys = System(
vcat(affect, alg_eqs), iv, collect(union(_dvs, discretes)),
collect(union(pre_params, sys_params)))
collect(union(pre_params, sys_params)); is_discrete = true)
affectsys = mtkcompile(affectsys; fully_determined = nothing)
# get accessed parameters p from Pre(p) in the callback parameters
accessed_params = Vector{Any}(filter(isparameter, map(unPre, collect(pre_params))))
Expand Down
Loading
Loading