Skip to content

Commit b1e09a6

Browse files
Merge pull request #3681 from AayushSabharwal/as/better-sccs
refactor: BLT sort equations and variables in tearing
2 parents 7d330f4 + 26437e2 commit b1e09a6

22 files changed

+513
-405
lines changed

src/problems/sccnonlinearproblem.jl

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -80,21 +80,32 @@ function SciMLBase.SCCNonlinearProblem{iip}(sys::System, op; eval_expression = f
8080
end
8181

8282
ts = get_tearing_state(sys)
83-
var_eq_matching, var_sccs = StructuralTransformations.algebraic_variables_scc(ts)
83+
sched = get_schedule(sys)
84+
if sched === nothing
85+
@warn "System is simplified but does not have a schedule. This should not happen."
86+
var_eq_matching, var_sccs = StructuralTransformations.algebraic_variables_scc(ts)
87+
condensed_graph = MatchedCondensationGraph(
88+
DiCMOBiGraph{true}(complete(ts.structure.graph),
89+
complete(var_eq_matching)),
90+
var_sccs)
91+
toporder = topological_sort_by_dfs(condensed_graph)
92+
var_sccs = var_sccs[toporder]
93+
eq_sccs = map(Base.Fix1(getindex, var_eq_matching), var_sccs)
94+
else
95+
var_sccs = sched.var_sccs
96+
# Equations are already in the order of SCCs
97+
eq_sccs = length.(var_sccs)
98+
cumsum!(eq_sccs, eq_sccs)
99+
eq_sccs = map(enumerate(eq_sccs)) do (i, lasti)
100+
i == 1 ? (1:lasti) : ((eq_sccs[i - 1] + 1):lasti)
101+
end
102+
end
84103

85104
if length(var_sccs) == 1
86105
return NonlinearProblem{iip}(
87106
sys, op; eval_expression, eval_module, kwargs...)
88107
end
89108

90-
condensed_graph = MatchedCondensationGraph(
91-
DiCMOBiGraph{true}(complete(ts.structure.graph),
92-
complete(var_eq_matching)),
93-
var_sccs)
94-
toporder = topological_sort_by_dfs(condensed_graph)
95-
var_sccs = var_sccs[toporder]
96-
eq_sccs = map(Base.Fix1(getindex, var_eq_matching), var_sccs)
97-
98109
dvs = unknowns(sys)
99110
ps = parameters(sys)
100111
eqs = equations(sys)

src/structural_transformation/StructuralTransformations.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ using ModelingToolkit: System, AbstractSystem, var_from_nested_derivative, Diffe
2626
filter_kwargs, lower_varname_with_unit,
2727
lower_shift_varname_with_unit, setio, SparseMatrixCLIL,
2828
get_fullvars, has_equations, observed,
29-
Schedule, schedule
29+
Schedule, schedule, iscomplete, get_schedule
3030

3131
using ModelingToolkit.BipartiteGraphs
3232
import .BipartiteGraphs: invview, complete
@@ -55,7 +55,7 @@ using SimpleNonlinearSolve
5555

5656
using DocStringExtensions
5757

58-
export tearing, partial_state_selection, dae_index_lowering, check_consistency
58+
export tearing, dae_index_lowering, check_consistency
5959
export dummy_derivative
6060
export sorted_incidence_matrix,
6161
pantelides!, pantelides_reassemble, tearing_reassemble, find_solvables!,

src/structural_transformation/partial_state_selection.jl

Lines changed: 1 addition & 174 deletions
Original file line numberDiff line numberDiff line change
@@ -1,173 +1,4 @@
1-
function partial_state_selection_graph!(state::TransformationState)
2-
find_solvables!(state; allow_symbolic = true)
3-
var_eq_matching = complete(pantelides!(state))
4-
complete!(state.structure)
5-
partial_state_selection_graph!(state.structure, var_eq_matching)
6-
end
7-
8-
function ascend_dg(xs, dg, level)
9-
while level > 0
10-
xs = Int[dg[x] for x in xs]
11-
level -= 1
12-
end
13-
return xs
14-
end
15-
16-
function ascend_dg_all(xs, dg, level, maxlevel)
17-
r = Int[]
18-
while true
19-
if level <= 0
20-
append!(r, xs)
21-
end
22-
maxlevel <= 0 && break
23-
xs = Int[dg[x] for x in xs if dg[x] !== nothing]
24-
level -= 1
25-
maxlevel -= 1
26-
end
27-
return r
28-
end
29-
30-
function pss_graph_modia!(structure::SystemStructure, maximal_top_matching, varlevel,
31-
inv_varlevel, inv_eqlevel)
32-
@unpack eq_to_diff, var_to_diff, graph, solvable_graph = structure
33-
34-
# var_eq_matching is a maximal matching on the top-differentiated variables.
35-
# Find Strongly connected components. Note that after pantelides, we expect
36-
# a balanced system, so a maximal matching should be possible.
37-
var_sccs::Vector{Union{Vector{Int}, Int}} = find_var_sccs(graph, maximal_top_matching)
38-
var_eq_matching = Matching{Union{Unassigned, SelectedState}}(ndsts(graph))
39-
for vars in var_sccs
40-
# TODO: We should have a way to not have the scc code look at unassigned vars.
41-
if length(vars) == 1 && maximal_top_matching[vars[1]] === unassigned
42-
continue
43-
end
44-
45-
# Now proceed level by level from lowest to highest and tear the graph.
46-
eqs = [maximal_top_matching[var]
47-
for var in vars if maximal_top_matching[var] !== unassigned]
48-
isempty(eqs) && continue
49-
maxeqlevel = maximum(map(x -> inv_eqlevel[x], eqs))
50-
maxvarlevel = level = maximum(map(x -> inv_varlevel[x], vars))
51-
old_level_vars = ()
52-
ict = IncrementalCycleTracker(
53-
DiCMOBiGraph{true}(graph,
54-
complete(Matching(ndsts(graph)), nsrcs(graph))),
55-
dir = :in)
56-
57-
while level >= 0
58-
to_tear_eqs_toplevel = filter(eq -> inv_eqlevel[eq] >= level, eqs)
59-
to_tear_eqs = ascend_dg(to_tear_eqs_toplevel, invview(eq_to_diff), level)
60-
61-
to_tear_vars_toplevel = filter(var -> inv_varlevel[var] >= level, vars)
62-
to_tear_vars = ascend_dg(to_tear_vars_toplevel, invview(var_to_diff), level)
63-
64-
assigned_eqs = Int[]
65-
66-
if old_level_vars !== ()
67-
# Inherit constraints from previous level.
68-
# TODO: Is this actually a good idea or do we want full freedom
69-
# to tear differently on each level? Does it make a difference
70-
# whether we're using heuristic or optimal tearing?
71-
removed_eqs = Int[]
72-
removed_vars = Int[]
73-
for var in old_level_vars
74-
old_assign = var_eq_matching[var]
75-
if isa(old_assign, SelectedState)
76-
push!(removed_vars, var)
77-
continue
78-
elseif !isa(old_assign, Int) ||
79-
ict.graph.matching[var_to_diff[var]] !== unassigned
80-
continue
81-
end
82-
# Make sure the ict knows about this edge, so it doesn't accidentally introduce
83-
# a cycle.
84-
assgned_eq = eq_to_diff[old_assign]
85-
ok = try_assign_eq!(ict, var_to_diff[var], assgned_eq)
86-
@assert ok
87-
var_eq_matching[var_to_diff[var]] = assgned_eq
88-
push!(removed_eqs, eq_to_diff[ict.graph.matching[var]])
89-
push!(removed_vars, var_to_diff[var])
90-
push!(removed_vars, var)
91-
end
92-
to_tear_eqs = setdiff(to_tear_eqs, removed_eqs)
93-
to_tear_vars = setdiff(to_tear_vars, removed_vars)
94-
end
95-
tearEquations!(ict, solvable_graph.fadjlist, to_tear_eqs, BitSet(to_tear_vars),
96-
nothing)
97-
98-
for var in to_tear_vars
99-
@assert var_eq_matching[var] === unassigned
100-
assgned_eq = ict.graph.matching[var]
101-
var_eq_matching[var] = assgned_eq
102-
isa(assgned_eq, Int) && push!(assigned_eqs, assgned_eq)
103-
end
104-
105-
if level != 0
106-
remaining_vars = collect(v
107-
for v in to_tear_vars
108-
if var_eq_matching[v] === unassigned)
109-
if !isempty(remaining_vars)
110-
remaining_eqs = setdiff(to_tear_eqs, assigned_eqs)
111-
nlsolve_matching = maximal_matching(graph,
112-
Base.Fix2(in, remaining_eqs),
113-
Base.Fix2(in, remaining_vars))
114-
for var in remaining_vars
115-
if nlsolve_matching[var] === unassigned &&
116-
var_eq_matching[var] === unassigned
117-
var_eq_matching[var] = SelectedState()
118-
end
119-
end
120-
end
121-
end
122-
123-
old_level_vars = to_tear_vars
124-
level -= 1
125-
end
126-
end
127-
return complete(var_eq_matching, nsrcs(graph))
128-
end
129-
1301
struct SelectedState end
131-
function partial_state_selection_graph!(structure::SystemStructure, var_eq_matching)
132-
@unpack eq_to_diff, var_to_diff, graph, solvable_graph = structure
133-
eq_to_diff = complete(eq_to_diff)
134-
135-
inv_eqlevel = map(1:nsrcs(graph)) do eq
136-
level = 0
137-
while invview(eq_to_diff)[eq] !== nothing
138-
eq = invview(eq_to_diff)[eq]
139-
level += 1
140-
end
141-
level
142-
end
143-
144-
varlevel = map(1:ndsts(graph)) do var
145-
graph_level = level = 0
146-
while var_to_diff[var] !== nothing
147-
var = var_to_diff[var]
148-
level += 1
149-
if !isempty(𝑑neighbors(graph, var))
150-
graph_level = level
151-
end
152-
end
153-
graph_level
154-
end
155-
156-
inv_varlevel = map(1:ndsts(graph)) do var
157-
level = 0
158-
while invview(var_to_diff)[var] !== nothing
159-
var = invview(var_to_diff)[var]
160-
level += 1
161-
end
162-
level
163-
end
164-
165-
var_eq_matching = pss_graph_modia!(structure,
166-
complete(var_eq_matching), varlevel, inv_varlevel,
167-
inv_eqlevel)
168-
169-
var_eq_matching
170-
end
1712

1723
function dummy_derivative_graph!(state::TransformationState, jac = nothing;
1734
state_priority = nothing, log = Val(false), kwargs...)
@@ -343,11 +174,7 @@ function dummy_derivative_graph!(
343174
end
344175

345176
ret = tearing_with_dummy_derivatives(structure, BitSet(dummy_derivatives))
346-
if log
347-
(ret..., DummyDerivativeSummary(var_dummy_scc, var_state_priority))
348-
else
349-
ret[1]
350-
end
177+
(ret..., DummyDerivativeSummary(var_dummy_scc, var_state_priority))
351178
end
352179

353180
function is_present(structure, v)::Bool

0 commit comments

Comments
 (0)