diff --git a/src/structural_transformation/bipartite_tearing/modia_tearing.jl b/src/structural_transformation/bipartite_tearing/modia_tearing.jl index e8304752e8..90c1d1f367 100644 --- a/src/structural_transformation/bipartite_tearing/modia_tearing.jl +++ b/src/structural_transformation/bipartite_tearing/modia_tearing.jl @@ -1,6 +1,31 @@ # This code is derived from the Modia project and is licensed as follows: # https://github.com/ModiaSim/Modia.jl/blob/b61daad643ef7edd0c1ccce6bf462c6acfb4ad1a/LICENSE +struct OrderedBitSet <: AbstractSet{Int} + bitset::BitSet + order::Vector{Int} +end +OrderedBitSet() = OrderedBitSet(BitSet(), Int[]) +Base.iterate(o::OrderedBitSet) = Base.iterate(o.order) +Base.iterate(o::OrderedBitSet, state) = Base.iterate(o.order, state) +Base.in(a::Int, o::OrderedBitSet) = a in o.bitset +function Base.push!(o::OrderedBitSet, a::Int) + if !(a in o.bitset) + push!(o.bitset, a) + push!(o.order, a) + end + o +end +function Base.empty!(o::OrderedBitSet) + empty!(o.bitset) + empty!(o.order) + o +end +function Base.sort!(o::OrderedBitSet; kw...) + sort!(o.order; kw...) + o +end + function try_assign_eq!(ict::IncrementalCycleTracker, vj::Integer, eq::Integer) G = ict.graph add_edge_checked!(ict, Iterators.filter(!=(vj), 𝑠neighbors(G.graph, eq)), vj) do G @@ -20,7 +45,7 @@ function try_assign_eq!(ict::IncrementalCycleTracker, vars, v_active, eq::Intege end function tearEquations!(ict::IncrementalCycleTracker, Gsolvable, es::Vector{Int}, - v_active::BitSet, isder′::F) where {F} + v_active::OrderedBitSet, isder′::F) where {F} check_der = isder′ !== nothing if check_der has_der = Ref(false) @@ -54,10 +79,13 @@ function tearEquations!(ict::IncrementalCycleTracker, Gsolvable, es::Vector{Int} end function tear_graph_block_modia!(var_eq_matching, ict, solvable_graph, eqs, vars, - isder::F) where {F} + isder::F, solved_eq) where {F} tearEquations!(ict, solvable_graph.fadjlist, eqs, vars, isder) for var in vars - var_eq_matching[var] = ict.graph.matching[var] + eq = var_eq_matching[var] = ict.graph.matching[var] + if eq isa Int + push!(solved_eq, eq) + end end return nothing end @@ -65,7 +93,8 @@ end function tear_graph_modia(structure::SystemStructure, isder::F = nothing, ::Type{U} = Unassigned; varfilter::F2 = v -> true, - eqfilter::F3 = eq -> true) where {F, U, F2, F3} + eqfilter::F3 = eq -> true, + state_priority::F4 = nothing) where {F, U, F2, F3, F4} # It would be possible here to simply iterate over all variables and attempt to # use tearEquations! to produce a matching that greedily selects the minimal # number of torn variables. However, we can do this process faster if we first @@ -88,20 +117,29 @@ function tear_graph_modia(structure::SystemStructure, isder::F = nothing, ict = IncrementalCycleTracker(vargraph; dir = :in) ieqs = Int[] - filtered_vars = BitSet() + filtered_vars = OrderedBitSet() + solved_eq = BitSet() + if state_priority !== nothing + sort!(var_sccs, by = Base.Fix1(sum, state_priority)) + end for vars in var_sccs for var in vars if varfilter(var) push!(filtered_vars, var) - if var_eq_matching[var] !== unassigned - push!(ieqs, var_eq_matching[var]) + for eq in 𝑑neighbors(graph, var) + if !(eq in solved_eq) + push!(ieqs, eq) + end end end var_eq_matching[var] = unassigned end + if state_priority !== nothing + sort!(filtered_vars, by = state_priority) + end tear_graph_block_modia!(var_eq_matching, ict, solvable_graph, ieqs, filtered_vars, - isder) + isder, solved_eq) # clear cache vargraph.ne = 0 diff --git a/src/structural_transformation/partial_state_selection.jl b/src/structural_transformation/partial_state_selection.jl index f2fec0269f..6b2ed3ac71 100644 --- a/src/structural_transformation/partial_state_selection.jl +++ b/src/structural_transformation/partial_state_selection.jl @@ -340,7 +340,7 @@ function dummy_derivative_graph!(structure::SystemStructure, var_eq_matching, ja var_eq_matching, full_var_eq_matching = tear_graph_modia(structure, isdiffed, Union{Unassigned, SelectedState}; - varfilter = can_eliminate) + varfilter = can_eliminate, state_priority) for v in eachindex(var_eq_matching) is_not_present(v) && continue dv = var_to_diff[v] diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index 83a28a259c..533d93e400 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -1452,6 +1452,17 @@ function markio!(state, orig_inputs, inputs, outputs; check = true) state, orig_inputs end +function set_priorities!(state, priorities) + fullvars = state.fullvars + prio_dict = Dict(priorities) + for (i, v) in enumerate(fullvars) + p = get(prio_dict, v, nothing) + p === nothing && continue + v = setmetadata(v, VariableStatePriority, p) + fullvars[i] = v + end +end + """ (; A, B, C, D), simplified_sys = linearize(sys, inputs, outputs; t=0.0, op = Dict(), allow_input_derivatives = false, zero_dummy_der=false, kwargs...) (; A, B, C, D) = linearize(simplified_sys, lin_fun; t=0.0, op = Dict(), allow_input_derivatives = false, zero_dummy_der=false) diff --git a/src/systems/systemstructure.jl b/src/systems/systemstructure.jl index 4e330619d5..3e1467eb6b 100644 --- a/src/systems/systemstructure.jl +++ b/src/systems/systemstructure.jl @@ -599,13 +599,14 @@ end function _structural_simplify!(state::TearingState, io; simplify = false, check_consistency = true, fully_determined = true, - kwargs...) + priorities = Dict(), kwargs...) check_consistency &= fully_determined has_io = io !== nothing orig_inputs = Set() if has_io ModelingToolkit.markio!(state, orig_inputs, io...) end + isempty(priorities) || ModelingToolkit.set_priorities!(state, priorities) state, input_idxs = ModelingToolkit.inputs_to_parameters!(state, io) sys, mm = ModelingToolkit.alias_elimination!(state; kwargs...) if check_consistency diff --git a/test/structural_transformation/tearing.jl b/test/structural_transformation/tearing.jl index a30238a645..eed5f91aec 100644 --- a/test/structural_transformation/tearing.jl +++ b/test/structural_transformation/tearing.jl @@ -219,3 +219,52 @@ sys = structural_simplify(ms_model) prob_complex = ODAEProblem(sys, u0, (0, 1.0)) sol = solve(prob_complex, Tsit5()) @test all(sol[mass.v] .== 1) + +## Test priorities +@parameters t +D = Differential(t) + +function Cart(; init_pos, init_vel, mass, name = :cart) + @variables pos(t)=init_pos vel(t)=init_vel + @variables f(t) + @parameters mass = mass + + eqs = [D(pos) ~ vel + D(vel) ~ f / mass] + + return ODESystem(eqs; name) +end + +function PDController(; kp = 0, kd = 0, name = :controller) + @variables x(t) v(t) f(t) + @parameters kp=kp kd=kd + + eqs = [ + f ~ -kp * x - kd * v, + ] + + return ODESystem(eqs; name) +end + +function ControlledCart(; cart, cont, name = :sys) + eqs = [cart.pos ~ cont.x + cart.vel ~ cont.v + cart.f ~ cont.f] + return ODESystem(eqs; name, systems = [cart, cont]) +end + +cart = Cart(init_pos = 0.0, init_vel = 1.0, mass = 0.5) +cont = PDController(kp = 1.0, kd = 0.5) +controlled_cart = ControlledCart(; cart, cont) +@variables z(t) k(t) +eqs = [z ~ k; + sqrt(z) ~ abs(cont.x)] +@named alge = ODESystem(eqs, t); +controlled_cart = extend(controlled_cart, alge) +# Test that our state priorities are respected +s1 = states(structural_simplify(controlled_cart, + priorities = [k => 2])) +s2 = states(structural_simplify(controlled_cart, + priorities = [z => 2])) +@test Set(s1) == Set([cart.pos, cart.vel, k]) +@test Set(s2) == Set([cart.pos, cart.vel, z])