From 3c6f7036312f874f5dabe3cf92a92ae5ae5b2de8 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 6 Jun 2025 17:14:32 +0530 Subject: [PATCH 01/13] feat: add `is_discrete` flag to systems --- src/systems/abstractsystem.jl | 1 + src/systems/system.jl | 16 +++++++++------- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index 533206883f..fb69d8b702 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -760,6 +760,7 @@ for prop in [:eqs :metadata :gui_metadata :is_initializesystem + :is_discrete :parameter_dependencies :assertions :ignored_connections diff --git a/src/systems/system.jl b/src/systems/system.jl index 7db44c7c97..cb330b382f 100644 --- a/src/systems/system.jl +++ b/src/systems/system.jl @@ -235,6 +235,7 @@ struct System <: AbstractSystem Whether the current system is an initialization system. """ is_initializesystem::Bool + is_discrete::Bool """ $INTERNAL_FIELD_WARNING Whether the system has been simplified by `mtkcompile`. @@ -255,8 +256,8 @@ struct System <: AbstractSystem is_dde = false, tstops = [], tearing_state = nothing, namespacing = true, complete = false, index_cache = nothing, ignored_connections = nothing, preface = nothing, parent = nothing, initializesystem = nothing, - is_initializesystem = false, isscheduled = false, schedule = nothing; - checks::Union{Bool, Int} = true) + is_initializesystem = false, is_discrete = false, isscheduled = false, + schedule = nothing; checks::Union{Bool, Int} = true) if is_initializesystem && iv !== nothing throw(ArgumentError(""" Expected initialization system to be time-independent. Found independent @@ -293,7 +294,8 @@ struct System <: AbstractSystem guesses, systems, initialization_eqs, continuous_events, discrete_events, connector_type, assertions, metadata, gui_metadata, is_dde, tstops, tearing_state, namespacing, complete, index_cache, ignored_connections, - preface, parent, initializesystem, is_initializesystem, isscheduled, schedule) + preface, parent, initializesystem, is_initializesystem, is_discrete, + isscheduled, schedule) end end @@ -330,8 +332,8 @@ function System(eqs::Vector{Equation}, iv, dvs, ps, brownians = []; is_dde = nothing, tstops = [], tearing_state = nothing, ignored_connections = nothing, parent = nothing, description = "", name = nothing, discover_from_metadata = true, - initializesystem = nothing, is_initializesystem = false, preface = [], - checks = true) + initializesystem = nothing, is_initializesystem = false, is_discrete = false, + preface = [], checks = true) name === nothing && throw(NoNameError()) if !isempty(parameter_dependencies) @warn """ @@ -411,7 +413,7 @@ function System(eqs::Vector{Equation}, iv, dvs, ps, brownians = []; var_to_name, name, description, defaults, guesses, systems, initialization_eqs, continuous_events, discrete_events, connector_type, assertions, metadata, gui_metadata, is_dde, tstops, tearing_state, true, false, nothing, ignored_connections, preface, parent, - initializesystem, is_initializesystem; checks) + initializesystem, is_initializesystem, is_discrete; checks) end """ @@ -668,7 +670,7 @@ callbacks, so checking if any LHS is shifted is sufficient. If a variable is shi the input equations there _will_ be a `Shift` equation in the simplified system. """ function is_discrete_system(sys::System) - any(eq -> isoperator(eq.lhs, Shift), equations(sys)) + get_is_discrete(sys) || any(eq -> isoperator(eq.lhs, Shift), equations(sys)) end SymbolicIndexingInterface.is_time_dependent(sys::System) = get_iv(sys) !== nothing From 60439b136d844b7f10ca329013cc01521300201b Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 6 Jun 2025 17:14:45 +0530 Subject: [PATCH 02/13] fix: explicitly mark callback systems as discrete --- src/systems/callbacks.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/systems/callbacks.jl b/src/systems/callbacks.jl index a9069e575c..a12dad9d72 100644 --- a/src/systems/callbacks.jl +++ b/src/systems/callbacks.jl @@ -259,7 +259,7 @@ function make_affect(affect::Vector{Equation}; discrete_parameters = Any[], @named affectsys = System( vcat(affect, alg_eqs), iv, collect(union(_dvs, discretes)), - collect(union(pre_params, sys_params))) + collect(union(pre_params, sys_params)); is_discrete = true) affectsys = mtkcompile(affectsys; fully_determined = nothing) # get accessed parameters p from Pre(p) in the callback parameters accessed_params = Vector{Any}(filter(isparameter, map(unPre, collect(pre_params)))) From d584eab3c823774da464752626419318c08321d1 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 6 Jun 2025 17:15:21 +0530 Subject: [PATCH 03/13] fix: check and update `is_discrete` flag in simplification --- src/systems/systemstructure.jl | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/systems/systemstructure.jl b/src/systems/systemstructure.jl index 8b0e020cc4..a2e021c66e 100644 --- a/src/systems/systemstructure.jl +++ b/src/systems/systemstructure.jl @@ -720,8 +720,13 @@ function mtkcompile!(state::TearingState; simplify = false, """)) end end - if continuous_id == 1 && any(Base.Fix2(isoperator, Shift), state.fullvars) + if get_is_discrete(state.sys) || + continuous_id == 1 && any(Base.Fix2(isoperator, Shift), state.fullvars) state.structure.only_discrete = true + state = shift_discrete_system(state) + sys = state.sys + @set! sys.is_discrete = true + state.sys = sys end sys = _mtkcompile!(state; simplify, check_consistency, From 57f96fc024480f4d9adbe140f835706d19e060af Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 6 Jun 2025 17:15:46 +0530 Subject: [PATCH 04/13] feat: add `descend_lower_shift_varname` and `_with_unit` variant --- src/structural_transformation/utils.jl | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/src/structural_transformation/utils.jl b/src/structural_transformation/utils.jl index 50f9aaf887..a5a8febd18 100644 --- a/src/structural_transformation/utils.jl +++ b/src/structural_transformation/utils.jl @@ -482,11 +482,29 @@ function lower_shift_varname(var, iv) end end +function descend_lower_shift_varname_with_unit(var, iv) + symbolic_type(var) == NotSymbolic() && return var + ModelingToolkit._with_unit(descend_lower_shift_varname, var, iv, iv) +end +function descend_lower_shift_varname(var, iv) + iscall(var) || return var + op = operation(var) + if op isa Shift + return shift2term(var) + else + args = arguments(var) + args = map(Base.Fix2(descend_lower_shift_varname, iv), args) + return maketerm(typeof(var), op, args, Symbolics.metadata(var)) + end +end + """ Rename a Shift variable with negative shift, Shift(t, k)(x(t)) to xₜ₋ₖ(t). """ function shift2term(var) + iscall(var) || return var op = operation(var) + op isa Shift || return var iv = op.t arg = only(arguments(var)) if operation(arg) === getindex From 224948df5b7215d8e884399eb9dc9ac3556a0e69 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 6 Jun 2025 17:16:04 +0530 Subject: [PATCH 05/13] feat: properly handle simplification of (implicit) discrete systems --- .../StructuralTransformations.jl | 3 +- .../symbolics_tearing.jl | 200 ++++++++++++++++-- 2 files changed, 190 insertions(+), 13 deletions(-) diff --git a/src/structural_transformation/StructuralTransformations.jl b/src/structural_transformation/StructuralTransformations.jl index eeb480e0f7..681025cb81 100644 --- a/src/structural_transformation/StructuralTransformations.jl +++ b/src/structural_transformation/StructuralTransformations.jl @@ -22,6 +22,7 @@ using ModelingToolkit: System, AbstractSystem, var_from_nested_derivative, Diffe ExtraEquationsSystemException, ExtraVariablesSystemException, vars!, invalidate_cache!, + vars!, invalidate_cache!, Shift, IncrementalCycleTracker, add_edge_checked!, topological_sort, filter_kwargs, lower_varname_with_unit, lower_shift_varname_with_unit, setio, SparseMatrixCLIL, @@ -39,7 +40,7 @@ using ModelingToolkit: algeqs, EquationsView, dervars_range, diffvars_range, algvars_range, DiffGraph, complete!, get_fullvars, system_subset -using SymbolicIndexingInterface: symbolic_type, ArraySymbolic +using SymbolicIndexingInterface: symbolic_type, ArraySymbolic, NotSymbolic using ModelingToolkit.DiffEqBase using ModelingToolkit.StaticArrays diff --git a/src/structural_transformation/symbolics_tearing.jl b/src/structural_transformation/symbolics_tearing.jl index 0014a947c8..c5c0757295 100644 --- a/src/structural_transformation/symbolics_tearing.jl +++ b/src/structural_transformation/symbolics_tearing.jl @@ -421,6 +421,23 @@ function generate_derivative_variables!( for (i, idxs) in idxs_to_remove deleteat!(var_sccs[i], idxs) end + new_sccs = insert_sccs(var_sccs, sccs_to_insert) + + if mm !== nothing + @set! mm.ncols = ndsts(graph) + end + + return new_sccs +end + +""" + $(TYPEDSIGNATURES) + +Given a list of SCCs and a list of SCCs to insert at specific indices, insert them and +return the new SCC vector. +""" +function insert_sccs( + var_sccs::Vector{Vector{Int}}, sccs_to_insert::Vector{Tuple{Int, Vector{Int}}}) # insert the new SCCs, accounting for the fact that we might have multiple entries # in `sccs_to_insert` to be inserted at the same index. old_idx = 1 @@ -441,10 +458,6 @@ function generate_derivative_variables!( end filter!(!isempty, new_sccs) - if mm !== nothing - @set! mm.ncols = ndsts(graph) - end - return new_sccs end @@ -742,7 +755,17 @@ function codegen_equation!(eg::EquationGenerator, @unpack fullvars, sys, structure = state @unpack solvable_graph, var_to_diff, eq_to_diff, graph = structure diff_to_var = invview(var_to_diff) - if is_solvable(eg, ieq, iv) && is_dervar(eg, iv) + + issolvable = is_solvable(eg, ieq, iv) + isdervar = issolvable && is_dervar(eg, iv) + isdisc = is_only_discrete(structure) + # The variable is derivative variable and the "most differentiated" + # This is only used for discrete systems, and basically refers to + # `Shift(t, 1)(x(k))` in `Shift(t, 1)(x(k)) ~ x(k) + x(k-1)`. As illustrated in + # the docstring for `add_additional_history!`, this is an exception and needs to be + # treated like a solved equation rather than a differential equation. + is_highest_diff = iv isa Int && isdervar && var_to_diff[iv] === nothing + if issolvable && isdervar && (!isdisc || !is_highest_diff) var = fullvars[iv] isnothing(D) && throw(UnexpectedDifferentialError(equations(sys)[ieq])) order, lv = var_order(iv, diff_to_var) @@ -762,15 +785,25 @@ function codegen_equation!(eg::EquationGenerator, push!(neweqs′, neweq) push!(eq_ordering, ieq) push!(var_ordering, diff_to_var[iv]) - elseif is_solvable(eg, ieq, iv) + elseif issolvable var = fullvars[iv] neweq = make_solved_equation(var, eq, total_sub; simplify) if neweq !== nothing + # backshift solved equations to calculate the value of the variable at the + # current time. This works because we added one additional history element + # in `add_additional_history!`. + if isdisc + neweq = backshift_expr(neweq, idep) + end push!(solved_eqs, neweq) push!(solved_vars, iv) end else neweq = make_algebraic_equation(eq, total_sub) + # For the same reason as solved equations (they are effectively the same) + if isdisc + neweq = backshift_expr(neweq, idep) + end push!(neweqs′, neweq) push!(eq_ordering, ieq) # we push a dummy to `var_ordering` here because `iv` is `unassigned` @@ -896,9 +929,24 @@ Update the system equations, unknowns, and observables after simplification. """ function update_simplified_system!( state::TearingState, neweqs, solved_eqs, dummy_sub, var_sccs, extra_unknowns; - cse_hack = true, array_hack = true) - @unpack solvable_graph, var_to_diff, eq_to_diff, graph = state.structure + cse_hack = true, array_hack = true, D = nothing, iv = nothing) + @unpack fullvars, structure = state + @unpack solvable_graph, var_to_diff, eq_to_diff, graph = structure diff_to_var = invview(var_to_diff) + # Since we solved the highest order derivative varible in discrete systems, + # we make a list of the solved variables and avoid including them in the + # unknowns. + solved_vars = Set() + if is_only_discrete(structure) + for eq in solved_eqs + var = eq.lhs + if isequal(eq.lhs, eq.rhs) + var = lower_shift_varname_with_unit(D(eq.lhs), iv) + end + push!(solved_vars, var) + end + filter!(eq -> !isequal(eq.lhs, eq.rhs), solved_eqs) + end ispresent = let var_to_diff = var_to_diff, graph = graph i -> (!isempty(𝑑neighbors(graph, i)) || @@ -915,9 +963,19 @@ function update_simplified_system!( obs = [fast_substitute(observed(sys), obs_sub); solved_eqs] unknown_idxs = filter( - i -> diff_to_var[i] === nothing && ispresent(i), eachindex(state.fullvars)) + i -> diff_to_var[i] === nothing && ispresent(i) && !(fullvars[i] in solved_vars), eachindex(state.fullvars)) unknowns = state.fullvars[unknown_idxs] unknowns = [unknowns; extra_unknowns] + if is_only_discrete(structure) + # Algebraic variables are shifted forward by one, so we backshift them. + unknowns = map(enumerate(unknowns)) do (i, var) + if iscall(var) && operation(var) isa Shift && operation(var).steps == 1 + backshift_expr(var, iv) + else + var + end + end + end @set! sys.unknowns = unknowns obs = cse_and_array_hacks( @@ -979,7 +1037,8 @@ differential variables. function tearing_reassemble(state::TearingState, var_eq_matching::Matching, full_var_eq_matching::Matching, var_sccs::Vector{Vector{Int}}; simplify = false, mm, cse_hack = true, array_hack = true, fully_determined = true) - extra_eqs_vars = get_extra_eqs_vars(state, full_var_eq_matching, fully_determined) + extra_eqs_vars = get_extra_eqs_vars( + state, var_eq_matching, full_var_eq_matching, fully_determined) neweqs = collect(equations(state)) dummy_sub = Dict() @@ -995,6 +1054,11 @@ function tearing_reassemble(state::TearingState, var_eq_matching::Matching, end extra_unknowns = state.fullvars[extra_eqs_vars[2]] + if is_only_discrete(state.structure) + var_sccs = add_additional_history!( + state, neweqs, var_eq_matching, full_var_eq_matching, var_sccs; iv, D) + end + # Structural simplification substitute_derivatives_algevars!(state, neweqs, var_eq_matching, dummy_sub; iv, D) @@ -1010,13 +1074,121 @@ function tearing_reassemble(state::TearingState, var_eq_matching::Matching, # var_eq_matching and full_var_eq_matching are now invalidated sys = update_simplified_system!(state, neweqs, solved_eqs, dummy_sub, var_sccs, - extra_unknowns; cse_hack, array_hack) + extra_unknowns; cse_hack, array_hack, iv, D) @set! state.sys = sys @set! sys.tearing_state = state return invalidate_cache!(sys) end +""" + $(TYPEDSIGNATURES) + +Add one more history equation for discrete systems. For example, if we have + +```julia +Shift(t, 1)(x(k-1)) ~ x(k) +Shift(t, 1)(x(k)) ~ x(k) + x(k-1) +``` + +This turns it into + +```julia +Shift(t, 1)(x(k-2)) ~ x(k-1) +Shift(t, 1)(x(k-1)) ~ x(k) +Shift(t, 1)(x(k)) ~ x(k) + x(k-1) +``` + +Thus adding an additional unknown as well. Later, the highest derivative equation will +be backshifted by one and turned into an observed equation, resulting in: + +```julia +Shift(t, 1)(x(k-2)) ~ x(k-1) +Shift(t, 1)(x(k-1)) ~ x(k) + +x(k) ~ x(k-1) + x(k-2) +``` + +Where the last equation is the observed equation. +""" +function add_additional_history!( + state::TearingState, neweqs::Vector, var_eq_matching::Matching, + full_var_eq_matching::Matching, var_sccs::Vector{Vector{Int}}; iv, D) + @unpack fullvars, sys, structure = state + @unpack solvable_graph, var_to_diff, eq_to_diff, graph = structure + eq_var_matching = invview(var_eq_matching) + diff_to_var = invview(var_to_diff) + is_discrete = is_only_discrete(structure) + digraph = DiCMOBiGraph{false}(graph, var_eq_matching) + + # We need the inverse mapping of `var_sccs` to update it efficiently later. + v_to_scc = Vector{NTuple{2, Int}}(undef, ndsts(graph)) + for (i, scc) in enumerate(var_sccs), (j, v) in enumerate(scc) + v_to_scc[v] = (i, j) + end + + vars_to_backshift = BitSet() + eqs_to_backshift = BitSet() + # add history for differential variables + for ivar in 1:length(fullvars) + ieq = var_eq_matching[ivar] + # the variable to backshift is a state variable which is not the + # derivative of any other one. + ieq isa SelectedState || continue + diff_to_var[ivar] === nothing || continue + push!(vars_to_backshift, ivar) + end + + inserts = Tuple{Int, Vector{Int}}[] + + for var in vars_to_backshift + add_backshifted_var!(state, var, iv) + # all backshifted vars are differential vars, hence SelectedState + push!(var_eq_matching, SelectedState()) + push!(full_var_eq_matching, unassigned) + # add to the SCCs right before the variable that was backshifted + push!(inserts, (v_to_scc[var][1], [length(fullvars)])) + end + + sort!(inserts, by = first) + new_sccs = insert_sccs(var_sccs, inserts) + return new_sccs +end + +""" + $(TYPEDSIGNATURES) + +Add the backshifted version of variable `ivar` to the system. +""" +function add_backshifted_var!(state::TearingState, ivar::Int, iv) + @unpack fullvars, structure = state + @unpack var_to_diff, graph, solvable_graph = structure + + var = fullvars[ivar] + newvar = simplify_shifts(Shift(iv, -1)(var)) + push!(fullvars, newvar) + inewvar = add_vertex!(var_to_diff) + add_edge!(var_to_diff, inewvar, ivar) + add_vertex!(graph, DST) + add_vertex!(solvable_graph, DST) + return inewvar +end + +""" + $(TYPEDSIGNATURES) + +Backshift the given expression `ex`. +""" +function backshift_expr(ex, iv) + ex isa Symbolic || return ex + return descend_lower_shift_varname_with_unit( + simplify_shifts(distribute_shift(Shift(iv, -1)(ex))), iv) +end + +function backshift_expr(ex::Equation, iv) + return backshift_expr(ex.lhs, iv) ~ backshift_expr(ex.rhs, iv) +end + """ $(TYPEDSIGNATURES) @@ -1025,7 +1197,7 @@ respectively. For fully-determined systems, both of these are empty. Overdetermi have extra equations, and underdetermined systems have extra variables. """ function get_extra_eqs_vars( - state::TearingState, full_var_eq_matching::Matching, fully_determined::Bool) + state::TearingState, var_eq_matching::Matching, full_var_eq_matching::Matching, fully_determined::Bool) fully_determined && return Int[], Int[] extra_eqs = Int[] @@ -1035,6 +1207,10 @@ function get_extra_eqs_vars( for v in 𝑑vertices(state.structure.graph) eq = full_var_eq_matching[v] eq isa Int && continue + # Only if the variable is also unmatched in `var_eq_matching`. + # Otherwise, `SelectedState` differential variables from order lowering + # are also considered "extra" + var_eq_matching[v] === unassigned || continue push!(extra_vars, v) end for eq in 𝑠vertices(state.structure.graph) From 190fa108bc47e4a3182c8c85ef8b115122d5f762 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 6 Jun 2025 17:16:20 +0530 Subject: [PATCH 06/13] fix: fix codegen of implicit discrete systems --- src/systems/codegen.jl | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/systems/codegen.jl b/src/systems/codegen.jl index 2a5bc1a728..17ff652c41 100644 --- a/src/systems/codegen.jl +++ b/src/systems/codegen.jl @@ -54,13 +54,16 @@ function generate_rhs(sys::System; implicit_dae = false, # Handle observables in algebraic equations, since they are shifted shifted_obs = Equation[distribute_shift(D(eq)) for eq in obs] obsidxs = observed_equations_used_by(sys, rhss; obs = shifted_obs) - extra_assignments = [Assignment(shifted_obs[i].lhs, shifted_obs[i].rhs) - for i in obsidxs] + ddvs = map(D, dvs) + + append!(extra_assignments, + [Assignment(shifted_obs[i].lhs, shifted_obs[i].rhs) + for i in obsidxs]) else D = Differential(t) + ddvs = map(D, dvs) rhss = [_iszero(eq.lhs) ? eq.rhs : eq.rhs - eq.lhs for eq in eqs] end - ddvs = map(D, dvs) else if !override_discrete && !is_discrete_system(sys) check_operator_variables(eqs, Differential) From 44f71a8cc4247d0a04e74eab232740892c3a53e3 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 6 Jun 2025 18:30:22 +0530 Subject: [PATCH 07/13] fix: fix toterm handling in `DiscreteProblem` --- src/problems/discreteproblem.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/problems/discreteproblem.jl b/src/problems/discreteproblem.jl index 8e1abc46e4..c4ab069ebe 100644 --- a/src/problems/discreteproblem.jl +++ b/src/problems/discreteproblem.jl @@ -43,11 +43,11 @@ end check_compatibility && check_compatible_system(DiscreteProblem, sys) dvs = unknowns(sys) - u0map = to_varmap(op, dvs) - add_toterms!(u0map; replace = true) + op = to_varmap(op, dvs) + add_toterms!(op; replace = true) f, u0, p = process_SciMLProblem(DiscreteFunction{iip, spec}, sys, op; t = tspan !== nothing ? tspan[1] : tspan, check_compatibility, expression, - build_initializeprob = false, kwargs...) + kwargs...) if expression == Val{true} u0 = :(f($u0, p, tspan[1])) From 9b69e0d02cf711960c9ed18cb42fddd0cc148d3a Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 6 Jun 2025 18:31:15 +0530 Subject: [PATCH 08/13] fix: handle scalarized array symbolics in `distribute_shift` --- src/structural_transformation/utils.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/structural_transformation/utils.jl b/src/structural_transformation/utils.jl index a5a8febd18..c462013923 100644 --- a/src/structural_transformation/utils.jl +++ b/src/structural_transformation/utils.jl @@ -596,7 +596,7 @@ function _distribute_shift(expr, shift) (op isa Pre || op isa Initial) && return expr args = arguments(expr) - if ModelingToolkit.isvariable(expr) + if ModelingToolkit.isvariable(expr) && operation(expr) !== getindex (length(args) == 1 && isequal(shift.t, only(args))) ? (return shift(expr)) : (return expr) elseif op isa Shift From 8747976430354aeb2c93a6ae7dd03feccb9889e2 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 6 Jun 2025 18:31:32 +0530 Subject: [PATCH 09/13] fix: perform `toterm` in `Initial` for `Shift`ed variables --- src/systems/abstractsystem.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index fb69d8b702..941739f127 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -503,7 +503,7 @@ function (f::Initial)(x) return x end # differential variables are default-toterm-ed - if iscall(x) && operation(x) isa Differential + if iscall(x) && operation(x) isa Union{Differential, Shift} x = default_toterm(x) end # don't double wrap From 0facf9b17a7ae1323b98a1f5a85b016ae0c46d9e Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 6 Jun 2025 18:31:49 +0530 Subject: [PATCH 10/13] test: update discrete system tests --- test/discrete_system.jl | 104 ++++++------------------------- test/implicit_discrete_system.jl | 6 +- 2 files changed, 22 insertions(+), 88 deletions(-) diff --git a/test/discrete_system.jl b/test/discrete_system.jl index ccebe5d346..d3db4df5f2 100644 --- a/test/discrete_system.jl +++ b/test/discrete_system.jl @@ -40,7 +40,6 @@ u = ModelingToolkit.varmap_to_vars( Dict([S(k - 1) => 1, I(k - 1) => 2, R(k - 1) => 3]), unknowns(syss)) p = MTKParameters(syss, [c, nsteps, δt, β, γ] .=> collect(1:5)) df.f(du, u, p, 0) -@test_broken getu(syss, [S, I, R]) reorderer = getu(syss, [S(k - 1), I(k - 1), R(k - 1)]) @test reorderer(du) ≈ [0.01831563888873422, 0.9816849729159067, 4.999999388195359] @@ -49,16 +48,17 @@ reorderer = getu(syss, [S(k - 1), I(k - 1), R(k - 1)]) [0.01831563888873422, 0.9816849729159067, 4.999999388195359] # Problem -u0 = [S(k - 1) => 990.0, I(k - 1) => 10.0, R(k - 1) => 0.0] +u0 = [S => 990.0, I => 10.0, R => 0.0] p = [β => 0.05, c => 10.0, γ => 0.25, δt => 0.1, nsteps => 400] tspan = (0.0, ModelingToolkit.value(substitute(nsteps, p))) # value function (from Symbolics) is used to convert a Num to Float64 -prob_map = DiscreteProblem(syss, [u0; p], tspan) +prob_map = DiscreteProblem( + syss, [u0; p], tspan; guesses = [S(k - 1) => 1.0, I(k - 1) => 1.0, R(k - 1) => 1.0]) @test prob_map.f.sys === syss # Solution using OrdinaryDiffEq sol_map = solve(prob_map, FunctionMap()); -@test_broken sol_map[S] isa Vector +@test sol_map[S] isa Vector @test sol_map[S(k - 1)] isa Vector # Using defaults constructor @@ -78,16 +78,16 @@ eqs2 = [S ~ S(k - 1) - infection2, eqs2, t, [S, I, R, R2], [c, nsteps, δt, β, γ]) @test ModelingToolkit.defaults(sys) != Dict() -@test_broken DiscreteProblem(sys, [], tspan) -prob_map2 = DiscreteProblem(sys, [S(k - 1) => S, I(k - 1) => I, R(k - 1) => R], tspan) +prob_map2 = DiscreteProblem(sys, [], tspan) +# prob_map2 = DiscreteProblem(sys, [S(k - 1) => S, I(k - 1) => I, R(k - 1) => R], tspan) sol_map2 = solve(prob_map2, FunctionMap()); @test sol_map.u ≈ sol_map2.u for p in parameters(sys) @test sol_map.prob.ps[p] ≈ sol_map2.prob.ps[p] end -@test sol_map2[R2][begin:(end - 1)] == sol_map2[R(k - 1)][(begin + 1):end] -@test_broken sol_map2[R2(k + 1)][begin:(end - 1)] == sol_map2[R][(begin + 1):end] +@test sol_map2[R2][begin:(end - 1)] == sol_map2[R(k - 1)][(begin + 1):end] == + sol_map2[R][begin:(end - 1)] # Direct Implementation function sir_map!(u_diff, u, p, t) @@ -103,14 +103,12 @@ function sir_map!(u_diff, u, p, t) end nothing end; -@test_broken prob_map2[[S, I, R]] -u0 = prob_map2[[S(k - 1), I(k - 1), R(k - 1)]]; +u0 = sol_map2[[S, I, R], 1]; p = [0.05, 10.0, 0.25, 0.1]; prob_map = DiscreteProblem(sir_map!, u0, tspan, p); sol_map2 = solve(prob_map, FunctionMap()); -@test_broken reduce(hcat, sol_map[[S, I, R]]) ≈ Array(sol_map2) -@test reduce(hcat, sol_map[[S(k - 1), I(k - 1), R(k - 1)]]) ≈ Array(sol_map2) +@test reduce(hcat, sol_map[[S, I, R]]) ≈ Array(sol_map2) # Delayed difference equation # @variables x(..) y(..) z(t) @@ -217,7 +215,7 @@ eqs = [u ~ 1 prob = DiscreteProblem(de, [x(k - 1) => 0.0], (0, 10)) sol = solve(prob, FunctionMap()) -@test reduce(vcat, sol.u) == 1:11 +@test sol[x] == 1:11 # Issue#2585 getdata(buffer, t) = buffer[mod1(Int(t), length(buffer))] @@ -251,77 +249,12 @@ end @test_nowarn @mtkcompile sys = System(; buffer = ones(10)) @testset "Passing `nothing` to `u0`" begin - @test_broken begin - @variables x(t) = 1 - k = ShiftIndex() - @mtkcompile sys = System([x(k) ~ x(k - 1) + 1], t) - prob = @test_nowarn DiscreteProblem(sys, nothing, (0.0, 1.0)) - @test_nowarn solve(prob, FunctionMap()) - end -end - -@testset "Initialization" begin - @test_broken begin - # test that default values apply to the entire history - @variables x(t) = 1.0 - @mtkcompile de = System([x ~ x(k - 1) + x(k - 2)], t) - prob = DiscreteProblem(de, [], (0, 10)) - @test prob[x] == 2.0 - @test prob[x(k - 1)] == 1.0 - - # must provide initial conditions for history - @test_throws ErrorException DiscreteProblem(de, [x => 2.0], (0, 10)) - @test_throws ErrorException DiscreteProblem(de, [x(k + 1) => 2.0], (0, 10)) - - # initial values only affect _that timestep_, not the entire history - prob = DiscreteProblem(de, [x(k - 1) => 2.0], (0, 10)) - @test prob[x] == 3.0 - @test prob[x(k - 1)] == 2.0 - @variables xₜ₋₁(t) - @test prob[xₜ₋₁] == 2.0 - - # Test initial assignment with lowered variable - prob = DiscreteProblem(de, [xₜ₋₁(k - 1) => 4.0], (0, 10)) - @test prob[x(k - 1)] == prob[xₜ₋₁] == 1.0 - @test prob[x] == 5.0 - - # Test missing initial throws error - @variables x(t) - @mtkcompile de = System([x ~ x(k - 1) + x(k - 2) * x(k - 3)], t) - @test_throws ErrorException prob=DiscreteProblem(de, [x(k - 3) => 2.0], (0, 10)) - @test_throws ErrorException prob=DiscreteProblem( - de, [x(k - 3) => 2.0, x(k - 1) => 3.0], (0, 10)) - - # Test non-assigned initials are given default value - @variables x(t) = 2.0 - @mtkcompile de = System([x ~ x(k - 1) + x(k - 2) * x(k - 3)], t) - prob = DiscreteProblem(de, [x(k - 3) => 12.0], (0, 10)) - @test prob[x] == 26.0 - @test prob[x(k - 1)] == 2.0 - @test prob[x(k - 2)] == 2.0 - - # Elaborate test - @variables xₜ₋₂(t) zₜ₋₁(t) z(t) - eqs = [x ~ x(k - 1) + z(k - 2), - z ~ x(k - 2) * x(k - 3) - z(k - 1)^2] - @mtkcompile de = System(eqs, t) - u0 = [x(k - 1) => 3, - xₜ₋₂(k - 1) => 4, - x(k - 2) => 1, - z(k - 1) => 5, - zₜ₋₁(k - 1) => 12] - prob = DiscreteProblem(de, u0, (0, 10)) - @test prob[x] == 15 - @test prob[z] == -21 - - import ModelingToolkit: shift2term - # unknowns(de) = xₜ₋₁, x, zₜ₋₁, xₜ₋₂, z - vars = sort(ModelingToolkit.value.(unknowns(de)); by = string) - @test isequal(shift2term(Shift(t, 1)(vars[2])), vars[1]) - @test isequal(shift2term(Shift(t, 1)(vars[3])), vars[2]) - @test isequal(shift2term(Shift(t, -1)(vars[4])), vars[5]) - @test isequal(shift2term(Shift(t, -2)(vars[1])), vars[3]) - end + @variables x(t) = 1 + k = ShiftIndex() + @mtkcompile sys = System([x(k) ~ x(k - 1) + 1], t) + prob = @test_nowarn DiscreteProblem(sys, nothing, (0.0, 1.0)) + sol = solve(prob, FunctionMap()) + @test SciMLBase.successful_retcode(sol) end @testset "Shifted array variables" begin @@ -339,6 +272,5 @@ end (0, 4)) @test all(isone, prob.u0) sol = solve(prob, FunctionMap()) - @test_broken sol[[x..., y...], end] - @test sol[[x(k - 1)..., y(k - 1)...], end] == 8ones(4) + @test sol[[x..., y...], end] == 8ones(4) end diff --git a/test/implicit_discrete_system.jl b/test/implicit_discrete_system.jl index bbfb045b1f..57c67116d1 100644 --- a/test/implicit_discrete_system.jl +++ b/test/implicit_discrete_system.jl @@ -1,4 +1,4 @@ -using ModelingToolkit, Test +using ModelingToolkit, SymbolicIndexingInterface, Test using ModelingToolkit: t_nounits as t using StableRNGs @@ -45,6 +45,8 @@ end 1 - (u_next[1] + u_next[2])^2 - u_next[3]^2] end + reorderer = getu(sys, [x(k - 2), x(k - 1), y]) + for _ in 1:10 u_next = rand(rng, 3) u = rand(rng, 3) @@ -73,6 +75,6 @@ end y(k) ~ x(k - 1) + x(k - 2), z(k) * x(k) ~ 3] @mtkcompile sys = System(eqs, t) - @test occursin("var\"Shift(t, 1)(z(t))\"", + @test occursin("var\"Shift(t, 1)(x(t))\"", string(ImplicitDiscreteFunction(sys; expression = Val{true}))) end From aaf7c579e1a1bbc44532b3910938eef59b14f471 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Sat, 7 Jun 2025 01:08:32 +0530 Subject: [PATCH 11/13] fix: do not shift inside `Pre` in `shift_discrete_system` --- src/systems/systemstructure.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/systems/systemstructure.jl b/src/systems/systemstructure.jl index a2e021c66e..7e00f6ef95 100644 --- a/src/systems/systemstructure.jl +++ b/src/systems/systemstructure.jl @@ -544,21 +544,21 @@ function shift_discrete_system(ts::TearingState) discvars = OrderedSet() eqs = equations(sys) for eq in eqs - vars!(discvars, eq; op = Union{Sample, Hold}) + vars!(discvars, eq; op = Union{Sample, Hold, Pre}) end iv = get_iv(sys) discmap = Dict(k => StructuralTransformations.simplify_shifts(Shift(iv, 1)(k)) for k in discvars - if any(isequal(k), fullvars) && !isa(operation(k), Union{Sample, Hold})) + if any(isequal(k), fullvars) && !isa(operation(k), Union{Sample, Hold, Pre})) for i in eachindex(fullvars) fullvars[i] = StructuralTransformations.simplify_shifts(fast_substitute( - fullvars[i], discmap; operator = Union{Sample, Hold})) + fullvars[i], discmap; operator = Union{Sample, Hold, Pre})) end for i in eachindex(eqs) eqs[i] = StructuralTransformations.simplify_shifts(fast_substitute( - eqs[i], discmap; operator = Union{Sample, Hold})) + eqs[i], discmap; operator = Union{Sample, Hold, Pre})) end @set! ts.sys.eqs = eqs @set! ts.fullvars = fullvars From da27d6999bbe9530fb1f11574e754fe60136ddc0 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Sat, 7 Jun 2025 11:14:18 +0530 Subject: [PATCH 12/13] fix: do not `distribute_shift` inside `Sample` and `Hold` --- src/structural_transformation/utils.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/structural_transformation/utils.jl b/src/structural_transformation/utils.jl index c462013923..02a9763226 100644 --- a/src/structural_transformation/utils.jl +++ b/src/structural_transformation/utils.jl @@ -593,7 +593,7 @@ end function _distribute_shift(expr, shift) if iscall(expr) op = operation(expr) - (op isa Pre || op isa Initial) && return expr + (op isa Union{Pre, Initial, Sample, Hold}) && return expr args = arguments(expr) if ModelingToolkit.isvariable(expr) && operation(expr) !== getindex From adfc91cf43e4f67ceb87863ea7e5c2d7ce9d7e08 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Sat, 7 Jun 2025 11:14:24 +0530 Subject: [PATCH 13/13] test: update clock tests --- test/clock.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/clock.jl b/test/clock.jl index a50026b38f..7afd7572fb 100644 --- a/test/clock.jl +++ b/test/clock.jl @@ -74,8 +74,8 @@ sss = ModelingToolkit._mtkcompile!( d = Clock(dt) k = ShiftIndex(d) @test issetequal(observed(sss), - [yd(k + 1) ~ Sample(dt)(y); r(k + 1) ~ 1.0; - ud(k + 1) ~ kp * (r(k + 1) - yd(k + 1))]) + [yd ~ Sample(dt)(y); r ~ 1.0; + ud ~ kp * (r - yd)]) canonical_eqs = map(eqs) do eq if iscall(eq.lhs) && operation(eq.lhs) isa Differential