Skip to content

refactor: BLT sort equations and variables in tearing #3681

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Jun 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 20 additions & 9 deletions src/problems/sccnonlinearproblem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,21 +80,32 @@ function SciMLBase.SCCNonlinearProblem{iip}(sys::System, op; eval_expression = f
end

ts = get_tearing_state(sys)
var_eq_matching, var_sccs = StructuralTransformations.algebraic_variables_scc(ts)
sched = get_schedule(sys)
if sched === nothing
@warn "System is simplified but does not have a schedule. This should not happen."
var_eq_matching, var_sccs = StructuralTransformations.algebraic_variables_scc(ts)
condensed_graph = MatchedCondensationGraph(
DiCMOBiGraph{true}(complete(ts.structure.graph),
complete(var_eq_matching)),
var_sccs)
toporder = topological_sort_by_dfs(condensed_graph)
var_sccs = var_sccs[toporder]
eq_sccs = map(Base.Fix1(getindex, var_eq_matching), var_sccs)
else
var_sccs = sched.var_sccs
# Equations are already in the order of SCCs
eq_sccs = length.(var_sccs)
cumsum!(eq_sccs, eq_sccs)
eq_sccs = map(enumerate(eq_sccs)) do (i, lasti)
i == 1 ? (1:lasti) : ((eq_sccs[i - 1] + 1):lasti)
end
end

if length(var_sccs) == 1
return NonlinearProblem{iip}(
sys, op; eval_expression, eval_module, kwargs...)
end

condensed_graph = MatchedCondensationGraph(
DiCMOBiGraph{true}(complete(ts.structure.graph),
complete(var_eq_matching)),
var_sccs)
toporder = topological_sort_by_dfs(condensed_graph)
var_sccs = var_sccs[toporder]
eq_sccs = map(Base.Fix1(getindex, var_eq_matching), var_sccs)

dvs = unknowns(sys)
ps = parameters(sys)
eqs = equations(sys)
Expand Down
4 changes: 2 additions & 2 deletions src/structural_transformation/StructuralTransformations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ using ModelingToolkit: System, AbstractSystem, var_from_nested_derivative, Diffe
filter_kwargs, lower_varname_with_unit,
lower_shift_varname_with_unit, setio, SparseMatrixCLIL,
get_fullvars, has_equations, observed,
Schedule, schedule
Schedule, schedule, iscomplete, get_schedule

using ModelingToolkit.BipartiteGraphs
import .BipartiteGraphs: invview, complete
Expand Down Expand Up @@ -55,7 +55,7 @@ using SimpleNonlinearSolve

using DocStringExtensions

export tearing, partial_state_selection, dae_index_lowering, check_consistency
export tearing, dae_index_lowering, check_consistency
export dummy_derivative
export sorted_incidence_matrix,
pantelides!, pantelides_reassemble, tearing_reassemble, find_solvables!,
Expand Down
175 changes: 1 addition & 174 deletions src/structural_transformation/partial_state_selection.jl
Original file line number Diff line number Diff line change
@@ -1,173 +1,4 @@
function partial_state_selection_graph!(state::TransformationState)
find_solvables!(state; allow_symbolic = true)
var_eq_matching = complete(pantelides!(state))
complete!(state.structure)
partial_state_selection_graph!(state.structure, var_eq_matching)
end

function ascend_dg(xs, dg, level)
while level > 0
xs = Int[dg[x] for x in xs]
level -= 1
end
return xs
end

function ascend_dg_all(xs, dg, level, maxlevel)
r = Int[]
while true
if level <= 0
append!(r, xs)
end
maxlevel <= 0 && break
xs = Int[dg[x] for x in xs if dg[x] !== nothing]
level -= 1
maxlevel -= 1
end
return r
end

function pss_graph_modia!(structure::SystemStructure, maximal_top_matching, varlevel,
inv_varlevel, inv_eqlevel)
@unpack eq_to_diff, var_to_diff, graph, solvable_graph = structure

# var_eq_matching is a maximal matching on the top-differentiated variables.
# Find Strongly connected components. Note that after pantelides, we expect
# a balanced system, so a maximal matching should be possible.
var_sccs::Vector{Union{Vector{Int}, Int}} = find_var_sccs(graph, maximal_top_matching)
var_eq_matching = Matching{Union{Unassigned, SelectedState}}(ndsts(graph))
for vars in var_sccs
# TODO: We should have a way to not have the scc code look at unassigned vars.
if length(vars) == 1 && maximal_top_matching[vars[1]] === unassigned
continue
end

# Now proceed level by level from lowest to highest and tear the graph.
eqs = [maximal_top_matching[var]
for var in vars if maximal_top_matching[var] !== unassigned]
isempty(eqs) && continue
maxeqlevel = maximum(map(x -> inv_eqlevel[x], eqs))
maxvarlevel = level = maximum(map(x -> inv_varlevel[x], vars))
old_level_vars = ()
ict = IncrementalCycleTracker(
DiCMOBiGraph{true}(graph,
complete(Matching(ndsts(graph)), nsrcs(graph))),
dir = :in)

while level >= 0
to_tear_eqs_toplevel = filter(eq -> inv_eqlevel[eq] >= level, eqs)
to_tear_eqs = ascend_dg(to_tear_eqs_toplevel, invview(eq_to_diff), level)

to_tear_vars_toplevel = filter(var -> inv_varlevel[var] >= level, vars)
to_tear_vars = ascend_dg(to_tear_vars_toplevel, invview(var_to_diff), level)

assigned_eqs = Int[]

if old_level_vars !== ()
# Inherit constraints from previous level.
# TODO: Is this actually a good idea or do we want full freedom
# to tear differently on each level? Does it make a difference
# whether we're using heuristic or optimal tearing?
removed_eqs = Int[]
removed_vars = Int[]
for var in old_level_vars
old_assign = var_eq_matching[var]
if isa(old_assign, SelectedState)
push!(removed_vars, var)
continue
elseif !isa(old_assign, Int) ||
ict.graph.matching[var_to_diff[var]] !== unassigned
continue
end
# Make sure the ict knows about this edge, so it doesn't accidentally introduce
# a cycle.
assgned_eq = eq_to_diff[old_assign]
ok = try_assign_eq!(ict, var_to_diff[var], assgned_eq)
@assert ok
var_eq_matching[var_to_diff[var]] = assgned_eq
push!(removed_eqs, eq_to_diff[ict.graph.matching[var]])
push!(removed_vars, var_to_diff[var])
push!(removed_vars, var)
end
to_tear_eqs = setdiff(to_tear_eqs, removed_eqs)
to_tear_vars = setdiff(to_tear_vars, removed_vars)
end
tearEquations!(ict, solvable_graph.fadjlist, to_tear_eqs, BitSet(to_tear_vars),
nothing)

for var in to_tear_vars
@assert var_eq_matching[var] === unassigned
assgned_eq = ict.graph.matching[var]
var_eq_matching[var] = assgned_eq
isa(assgned_eq, Int) && push!(assigned_eqs, assgned_eq)
end

if level != 0
remaining_vars = collect(v
for v in to_tear_vars
if var_eq_matching[v] === unassigned)
if !isempty(remaining_vars)
remaining_eqs = setdiff(to_tear_eqs, assigned_eqs)
nlsolve_matching = maximal_matching(graph,
Base.Fix2(in, remaining_eqs),
Base.Fix2(in, remaining_vars))
for var in remaining_vars
if nlsolve_matching[var] === unassigned &&
var_eq_matching[var] === unassigned
var_eq_matching[var] = SelectedState()
end
end
end
end

old_level_vars = to_tear_vars
level -= 1
end
end
return complete(var_eq_matching, nsrcs(graph))
end

struct SelectedState end
function partial_state_selection_graph!(structure::SystemStructure, var_eq_matching)
@unpack eq_to_diff, var_to_diff, graph, solvable_graph = structure
eq_to_diff = complete(eq_to_diff)

inv_eqlevel = map(1:nsrcs(graph)) do eq
level = 0
while invview(eq_to_diff)[eq] !== nothing
eq = invview(eq_to_diff)[eq]
level += 1
end
level
end

varlevel = map(1:ndsts(graph)) do var
graph_level = level = 0
while var_to_diff[var] !== nothing
var = var_to_diff[var]
level += 1
if !isempty(𝑑neighbors(graph, var))
graph_level = level
end
end
graph_level
end

inv_varlevel = map(1:ndsts(graph)) do var
level = 0
while invview(var_to_diff)[var] !== nothing
var = invview(var_to_diff)[var]
level += 1
end
level
end

var_eq_matching = pss_graph_modia!(structure,
complete(var_eq_matching), varlevel, inv_varlevel,
inv_eqlevel)

var_eq_matching
end

function dummy_derivative_graph!(state::TransformationState, jac = nothing;
state_priority = nothing, log = Val(false), kwargs...)
Expand Down Expand Up @@ -343,11 +174,7 @@ function dummy_derivative_graph!(
end

ret = tearing_with_dummy_derivatives(structure, BitSet(dummy_derivatives))
if log
(ret..., DummyDerivativeSummary(var_dummy_scc, var_state_priority))
else
ret[1]
end
(ret..., DummyDerivativeSummary(var_dummy_scc, var_state_priority))
end

function is_present(structure, v)::Bool
Expand Down
Loading
Loading