Skip to content

add option to set state priority as argument to structural_simplify #2048

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 46 additions & 8 deletions src/structural_transformation/bipartite_tearing/modia_tearing.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,31 @@
# This code is derived from the Modia project and is licensed as follows:
# https://github.com/ModiaSim/Modia.jl/blob/b61daad643ef7edd0c1ccce6bf462c6acfb4ad1a/LICENSE

struct OrderedBitSet <: AbstractSet{Int}
bitset::BitSet
order::Vector{Int}
end
OrderedBitSet() = OrderedBitSet(BitSet(), Int[])
Base.iterate(o::OrderedBitSet) = Base.iterate(o.order)
Base.iterate(o::OrderedBitSet, state) = Base.iterate(o.order, state)
Base.in(a::Int, o::OrderedBitSet) = a in o.bitset
function Base.push!(o::OrderedBitSet, a::Int)
if !(a in o.bitset)
push!(o.bitset, a)
push!(o.order, a)
end
o
end
function Base.empty!(o::OrderedBitSet)
empty!(o.bitset)
empty!(o.order)
o
end
function Base.sort!(o::OrderedBitSet; kw...)
sort!(o.order; kw...)
o
end

function try_assign_eq!(ict::IncrementalCycleTracker, vj::Integer, eq::Integer)
G = ict.graph
add_edge_checked!(ict, Iterators.filter(!=(vj), 𝑠neighbors(G.graph, eq)), vj) do G
Expand All @@ -20,7 +45,7 @@ function try_assign_eq!(ict::IncrementalCycleTracker, vars, v_active, eq::Intege
end

function tearEquations!(ict::IncrementalCycleTracker, Gsolvable, es::Vector{Int},
v_active::BitSet, isder′::F) where {F}
v_active::OrderedBitSet, isder′::F) where {F}
check_der = isder′ !== nothing
if check_der
has_der = Ref(false)
Expand Down Expand Up @@ -54,18 +79,22 @@ function tearEquations!(ict::IncrementalCycleTracker, Gsolvable, es::Vector{Int}
end

function tear_graph_block_modia!(var_eq_matching, ict, solvable_graph, eqs, vars,
isder::F) where {F}
isder::F, solved_eq) where {F}
tearEquations!(ict, solvable_graph.fadjlist, eqs, vars, isder)
for var in vars
var_eq_matching[var] = ict.graph.matching[var]
eq = var_eq_matching[var] = ict.graph.matching[var]
if eq isa Int
push!(solved_eq, eq)
end
end
return nothing
end

function tear_graph_modia(structure::SystemStructure, isder::F = nothing,
::Type{U} = Unassigned;
varfilter::F2 = v -> true,
eqfilter::F3 = eq -> true) where {F, U, F2, F3}
eqfilter::F3 = eq -> true,
state_priority::F4 = nothing) where {F, U, F2, F3, F4}
# It would be possible here to simply iterate over all variables and attempt to
# use tearEquations! to produce a matching that greedily selects the minimal
# number of torn variables. However, we can do this process faster if we first
Expand All @@ -88,20 +117,29 @@ function tear_graph_modia(structure::SystemStructure, isder::F = nothing,
ict = IncrementalCycleTracker(vargraph; dir = :in)

ieqs = Int[]
filtered_vars = BitSet()
filtered_vars = OrderedBitSet()
solved_eq = BitSet()
if state_priority !== nothing
sort!(var_sccs, by = Base.Fix1(sum, state_priority))
end
for vars in var_sccs
for var in vars
if varfilter(var)
push!(filtered_vars, var)
if var_eq_matching[var] !== unassigned
push!(ieqs, var_eq_matching[var])
for eq in 𝑑neighbors(graph, var)
if !(eq in solved_eq)
push!(ieqs, eq)
end
end
end
var_eq_matching[var] = unassigned
end
if state_priority !== nothing
sort!(filtered_vars, by = state_priority)
end
tear_graph_block_modia!(var_eq_matching, ict, solvable_graph, ieqs,
filtered_vars,
isder)
isder, solved_eq)

# clear cache
vargraph.ne = 0
Expand Down
2 changes: 1 addition & 1 deletion src/structural_transformation/partial_state_selection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,7 @@ function dummy_derivative_graph!(structure::SystemStructure, var_eq_matching, ja

var_eq_matching, full_var_eq_matching = tear_graph_modia(structure, isdiffed,
Union{Unassigned, SelectedState};
varfilter = can_eliminate)
varfilter = can_eliminate, state_priority)
for v in eachindex(var_eq_matching)
is_not_present(v) && continue
dv = var_to_diff[v]
Expand Down
11 changes: 11 additions & 0 deletions src/systems/abstractsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1452,6 +1452,17 @@ function markio!(state, orig_inputs, inputs, outputs; check = true)
state, orig_inputs
end

function set_priorities!(state, priorities)
fullvars = state.fullvars
prio_dict = Dict(priorities)
for (i, v) in enumerate(fullvars)
p = get(prio_dict, v, nothing)
p === nothing && continue
v = setmetadata(v, VariableStatePriority, p)
fullvars[i] = v
end
end

"""
(; A, B, C, D), simplified_sys = linearize(sys, inputs, outputs; t=0.0, op = Dict(), allow_input_derivatives = false, zero_dummy_der=false, kwargs...)
(; A, B, C, D) = linearize(simplified_sys, lin_fun; t=0.0, op = Dict(), allow_input_derivatives = false, zero_dummy_der=false)
Expand Down
3 changes: 2 additions & 1 deletion src/systems/systemstructure.jl
Original file line number Diff line number Diff line change
Expand Up @@ -599,13 +599,14 @@ end

function _structural_simplify!(state::TearingState, io; simplify = false,
check_consistency = true, fully_determined = true,
kwargs...)
priorities = Dict(), kwargs...)
check_consistency &= fully_determined
has_io = io !== nothing
orig_inputs = Set()
if has_io
ModelingToolkit.markio!(state, orig_inputs, io...)
end
isempty(priorities) || ModelingToolkit.set_priorities!(state, priorities)
state, input_idxs = ModelingToolkit.inputs_to_parameters!(state, io)
sys, mm = ModelingToolkit.alias_elimination!(state; kwargs...)
if check_consistency
Expand Down
49 changes: 49 additions & 0 deletions test/structural_transformation/tearing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -219,3 +219,52 @@ sys = structural_simplify(ms_model)
prob_complex = ODAEProblem(sys, u0, (0, 1.0))
sol = solve(prob_complex, Tsit5())
@test all(sol[mass.v] .== 1)

## Test priorities
@parameters t
D = Differential(t)

function Cart(; init_pos, init_vel, mass, name = :cart)
@variables pos(t)=init_pos vel(t)=init_vel
@variables f(t)
@parameters mass = mass

eqs = [D(pos) ~ vel
D(vel) ~ f / mass]

return ODESystem(eqs; name)
end

function PDController(; kp = 0, kd = 0, name = :controller)
@variables x(t) v(t) f(t)
@parameters kp=kp kd=kd

eqs = [
f ~ -kp * x - kd * v,
]

return ODESystem(eqs; name)
end

function ControlledCart(; cart, cont, name = :sys)
eqs = [cart.pos ~ cont.x
cart.vel ~ cont.v
cart.f ~ cont.f]
return ODESystem(eqs; name, systems = [cart, cont])
end

cart = Cart(init_pos = 0.0, init_vel = 1.0, mass = 0.5)
cont = PDController(kp = 1.0, kd = 0.5)
controlled_cart = ControlledCart(; cart, cont)
@variables z(t) k(t)
eqs = [z ~ k;
sqrt(z) ~ abs(cont.x)]
@named alge = ODESystem(eqs, t);
controlled_cart = extend(controlled_cart, alge)
# Test that our state priorities are respected
s1 = states(structural_simplify(controlled_cart,
priorities = [k => 2]))
s2 = states(structural_simplify(controlled_cart,
priorities = [z => 2]))
@test Set(s1) == Set([cart.pos, cart.vel, k])
@test Set(s2) == Set([cart.pos, cart.vel, z])