@@ -929,7 +929,7 @@ Update the system equations, unknowns, and observables after simplification.
929
929
"""
930
930
function update_simplified_system! (
931
931
state:: TearingState , neweqs, solved_eqs, dummy_sub, var_sccs, extra_unknowns;
932
- cse_hack = true , array_hack = true , D = nothing , iv = nothing )
932
+ array_hack = true , D = nothing , iv = nothing )
933
933
@unpack fullvars, structure = state
934
934
@unpack solvable_graph, var_to_diff, eq_to_diff, graph = structure
935
935
diff_to_var = invview (var_to_diff)
@@ -978,8 +978,7 @@ function update_simplified_system!(
978
978
end
979
979
@set! sys. unknowns = unknowns
980
980
981
- obs = cse_and_array_hacks (
982
- sys, obs, unknowns, neweqs; cse = cse_hack, array = array_hack)
981
+ obs = tearing_hacks (sys, obs, unknowns, neweqs; array = array_hack)
983
982
984
983
@set! sys. eqs = neweqs
985
984
@set! sys. observed = obs
@@ -1035,7 +1034,7 @@ differential variables.
1035
1034
according to `full_var_eq_matching`.
1036
1035
"""
1037
1036
function tearing_reassemble (state:: TearingState , var_eq_matching:: Matching ,
1038
- full_var_eq_matching:: Matching , var_sccs:: Vector{Vector{Int}} ; simplify = false , mm, cse_hack = true ,
1037
+ full_var_eq_matching:: Matching , var_sccs:: Vector{Vector{Int}} ; simplify = false , mm,
1039
1038
array_hack = true , fully_determined = true )
1040
1039
extra_eqs_vars = get_extra_eqs_vars (
1041
1040
state, var_eq_matching, full_var_eq_matching, fully_determined)
@@ -1074,7 +1073,7 @@ function tearing_reassemble(state::TearingState, var_eq_matching::Matching,
1074
1073
# var_eq_matching and full_var_eq_matching are now invalidated
1075
1074
1076
1075
sys = update_simplified_system! (state, neweqs, solved_eqs, dummy_sub, var_sccs,
1077
- extra_unknowns; cse_hack, array_hack, iv, D)
1076
+ extra_unknowns; array_hack, iv, D)
1078
1077
1079
1078
@set! state. sys = sys
1080
1079
@set! sys. tearing_state = state
@@ -1223,60 +1222,22 @@ function get_extra_eqs_vars(
1223
1222
end
1224
1223
1225
1224
"""
1226
- # HACK 1
1227
-
1228
- Since we don't support array equations, any equation of the sort `x[1:n] ~ f(...)[1:n]`
1229
- gets turned into `x[1] ~ f(...)[1], x[2] ~ f(...)[2]`. Repeatedly calling `f` gets
1230
- _very_ expensive. this hack performs a limited form of CSE specifically for this case to
1231
- avoid the unnecessary cost. This and the below hack are implemented simultaneously
1232
-
1233
- # HACK 2
1225
+ # HACK
1234
1226
1235
1227
Add equations for array observed variables. If `p[i] ~ (...)` are equations, add an
1236
1228
equation `p ~ [p[1], p[2], ...]` allow topsort to reorder them only add the new equation
1237
1229
if all `p[i]` are present and the unscalarized form is used in any equation (observed or
1238
1230
not) we first count the number of times the scalarized form of each observed variable
1239
1231
occurs in observed equations (and unknowns if it's split).
1240
1232
"""
1241
- function cse_and_array_hacks (sys, obs, unknowns, neweqs; cse = true , array = true )
1242
- # HACK 1
1243
- # mapping of rhs to temporary CSE variable
1244
- # `f(...) => tmpvar` in above example
1245
- rhs_to_tempvar = Dict ()
1246
-
1247
- # HACK 2
1233
+ function tearing_hacks (sys, obs, unknowns, neweqs; array = true )
1248
1234
# map of array observed variable (unscalarized) to number of its
1249
1235
# scalarized terms that appear in observed equations
1250
1236
arr_obs_occurrences = Dict ()
1251
1237
for (i, eq) in enumerate (obs)
1252
1238
lhs = eq. lhs
1253
1239
rhs = eq. rhs
1254
1240
1255
- # HACK 1
1256
- if cse && is_getindexed_array (rhs)
1257
- rhs_arr = arguments (rhs)[1 ]
1258
- iscall (rhs_arr) && operation (rhs_arr) isa Symbolics. Operator && continue
1259
- if ! haskey (rhs_to_tempvar, rhs_arr)
1260
- tempvar = gensym (Symbol (lhs))
1261
- N = length (rhs_arr)
1262
- tempvar = unwrap (Symbolics. variable (
1263
- tempvar; T = Symbolics. symtype (rhs_arr)))
1264
- tempvar = setmetadata (
1265
- tempvar, Symbolics. ArrayShapeCtx, Symbolics. shape (rhs_arr))
1266
- tempeq = tempvar ~ rhs_arr
1267
- rhs_to_tempvar[rhs_arr] = tempvar
1268
- push! (obs, tempeq)
1269
- end
1270
-
1271
- # getindex_wrapper is used because `observed2graph` treats `x` and `x[i]` as different,
1272
- # so it doesn't find a dependency between this equation and `tempvar ~ rhs_arr`
1273
- # which fails the topological sort
1274
- neweq = lhs ~ getindex_wrapper (
1275
- rhs_to_tempvar[rhs_arr], Tuple (arguments (rhs)[2 : end ]))
1276
- obs[i] = neweq
1277
- end
1278
- # end HACK 1
1279
-
1280
1241
array || continue
1281
1242
iscall (lhs) || continue
1282
1243
operation (lhs) === getindex || continue
@@ -1287,31 +1248,6 @@ function cse_and_array_hacks(sys, obs, unknowns, neweqs; cse = true, array = tru
1287
1248
continue
1288
1249
end
1289
1250
1290
- # Also do CSE for `equations(sys)`
1291
- if cse
1292
- for (i, eq) in enumerate (neweqs)
1293
- (; lhs, rhs) = eq
1294
- is_getindexed_array (rhs) || continue
1295
- rhs_arr = arguments (rhs)[1 ]
1296
- if ! haskey (rhs_to_tempvar, rhs_arr)
1297
- tempvar = gensym (Symbol (lhs))
1298
- N = length (rhs_arr)
1299
- tempvar = unwrap (Symbolics. variable (
1300
- tempvar; T = Symbolics. symtype (rhs_arr)))
1301
- tempvar = setmetadata (
1302
- tempvar, Symbolics. ArrayShapeCtx, Symbolics. shape (rhs_arr))
1303
- tempeq = tempvar ~ rhs_arr
1304
- rhs_to_tempvar[rhs_arr] = tempvar
1305
- push! (obs, tempeq)
1306
- end
1307
- # don't need getindex_wrapper, but do it anyway to know that this
1308
- # hack took place
1309
- neweq = lhs ~ getindex_wrapper (
1310
- rhs_to_tempvar[rhs_arr], Tuple (arguments (rhs)[2 : end ]))
1311
- neweqs[i] = neweq
1312
- end
1313
- end
1314
-
1315
1251
# count variables in unknowns if they are scalarized forms of variables
1316
1252
# also present as observed. e.g. if `x[1]` is an unknown and `x[2] ~ (..)`
1317
1253
# is an observed equation.
@@ -1346,18 +1282,7 @@ function cse_and_array_hacks(sys, obs, unknowns, neweqs; cse = true, array = tru
1346
1282
return obs
1347
1283
end
1348
1284
1349
- function is_getindexed_array (rhs)
1350
- (! ModelingToolkit. isvariable (rhs) || ModelingToolkit. iscalledparameter (rhs)) &&
1351
- iscall (rhs) && operation (rhs) === getindex &&
1352
- Symbolics. shape (rhs) != Symbolics. Unknown ()
1353
- end
1354
-
1355
- # PART OF HACK 1
1356
- getindex_wrapper (x, i) = x[i... ]
1357
-
1358
- @register_symbolic getindex_wrapper (x:: AbstractArray , i:: Tuple{Vararg{Int}} )
1359
-
1360
- # PART OF HACK 2
1285
+ # PART OF HACK
1361
1286
function change_origin (origin, arr)
1362
1287
if all (isone, Tuple (origin))
1363
1288
return arr
@@ -1385,11 +1310,11 @@ new residual equations after tearing. End users are encouraged to call [`mtkcomp
1385
1310
instead, which calls this function internally.
1386
1311
"""
1387
1312
function tearing (sys:: AbstractSystem , state = TearingState (sys); mm = nothing ,
1388
- simplify = false , cse_hack = true , array_hack = true , fully_determined = true , kwargs... )
1313
+ simplify = false , array_hack = true , fully_determined = true , kwargs... )
1389
1314
var_eq_matching, full_var_eq_matching, var_sccs, can_eliminate = tearing (state)
1390
1315
invalidate_cache! (tearing_reassemble (
1391
1316
state, var_eq_matching, full_var_eq_matching, var_sccs; mm,
1392
- simplify, cse_hack, array_hack, fully_determined))
1317
+ simplify, array_hack, fully_determined))
1393
1318
end
1394
1319
1395
1320
"""
@@ -1399,7 +1324,7 @@ Perform index reduction and use the dummy derivative technique to ensure that
1399
1324
the system is balanced.
1400
1325
"""
1401
1326
function dummy_derivative (sys, state = TearingState (sys); simplify = false ,
1402
- mm = nothing , cse_hack = true , array_hack = true , fully_determined = true , kwargs... )
1327
+ mm = nothing , array_hack = true , fully_determined = true , kwargs... )
1403
1328
jac = let state = state
1404
1329
(eqs, vars) -> begin
1405
1330
symeqs = EquationsView (state)[eqs]
@@ -1425,5 +1350,5 @@ function dummy_derivative(sys, state = TearingState(sys); simplify = false,
1425
1350
state, jac; state_priority,
1426
1351
kwargs... )
1427
1352
tearing_reassemble (state, var_eq_matching, full_var_eq_matching, var_sccs;
1428
- simplify, mm, cse_hack, array_hack, fully_determined)
1353
+ simplify, mm, array_hack, fully_determined)
1429
1354
end
0 commit comments