Skip to content

Commit 01b388a

Browse files
Merge pull request #3697 from AayushSabharwal/as/array-initprobpmap
fix: remove CSE hack, fix unscalarized variables in initializeprobpmap
2 parents a5adf3d + aff913f commit 01b388a

File tree

5 files changed

+68
-168
lines changed

5 files changed

+68
-168
lines changed

src/structural_transformation/symbolics_tearing.jl

Lines changed: 11 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -929,7 +929,7 @@ Update the system equations, unknowns, and observables after simplification.
929929
"""
930930
function update_simplified_system!(
931931
state::TearingState, neweqs, solved_eqs, dummy_sub, var_sccs, extra_unknowns;
932-
cse_hack = true, array_hack = true, D = nothing, iv = nothing)
932+
array_hack = true, D = nothing, iv = nothing)
933933
@unpack fullvars, structure = state
934934
@unpack solvable_graph, var_to_diff, eq_to_diff, graph = structure
935935
diff_to_var = invview(var_to_diff)
@@ -978,8 +978,7 @@ function update_simplified_system!(
978978
end
979979
@set! sys.unknowns = unknowns
980980

981-
obs = cse_and_array_hacks(
982-
sys, obs, unknowns, neweqs; cse = cse_hack, array = array_hack)
981+
obs = tearing_hacks(sys, obs, unknowns, neweqs; array = array_hack)
983982

984983
@set! sys.eqs = neweqs
985984
@set! sys.observed = obs
@@ -1035,7 +1034,7 @@ differential variables.
10351034
according to `full_var_eq_matching`.
10361035
"""
10371036
function tearing_reassemble(state::TearingState, var_eq_matching::Matching,
1038-
full_var_eq_matching::Matching, var_sccs::Vector{Vector{Int}}; simplify = false, mm, cse_hack = true,
1037+
full_var_eq_matching::Matching, var_sccs::Vector{Vector{Int}}; simplify = false, mm,
10391038
array_hack = true, fully_determined = true)
10401039
extra_eqs_vars = get_extra_eqs_vars(
10411040
state, var_eq_matching, full_var_eq_matching, fully_determined)
@@ -1074,7 +1073,7 @@ function tearing_reassemble(state::TearingState, var_eq_matching::Matching,
10741073
# var_eq_matching and full_var_eq_matching are now invalidated
10751074

10761075
sys = update_simplified_system!(state, neweqs, solved_eqs, dummy_sub, var_sccs,
1077-
extra_unknowns; cse_hack, array_hack, iv, D)
1076+
extra_unknowns; array_hack, iv, D)
10781077

10791078
@set! state.sys = sys
10801079
@set! sys.tearing_state = state
@@ -1223,60 +1222,22 @@ function get_extra_eqs_vars(
12231222
end
12241223

12251224
"""
1226-
# HACK 1
1227-
1228-
Since we don't support array equations, any equation of the sort `x[1:n] ~ f(...)[1:n]`
1229-
gets turned into `x[1] ~ f(...)[1], x[2] ~ f(...)[2]`. Repeatedly calling `f` gets
1230-
_very_ expensive. this hack performs a limited form of CSE specifically for this case to
1231-
avoid the unnecessary cost. This and the below hack are implemented simultaneously
1232-
1233-
# HACK 2
1225+
# HACK
12341226
12351227
Add equations for array observed variables. If `p[i] ~ (...)` are equations, add an
12361228
equation `p ~ [p[1], p[2], ...]` allow topsort to reorder them only add the new equation
12371229
if all `p[i]` are present and the unscalarized form is used in any equation (observed or
12381230
not) we first count the number of times the scalarized form of each observed variable
12391231
occurs in observed equations (and unknowns if it's split).
12401232
"""
1241-
function cse_and_array_hacks(sys, obs, unknowns, neweqs; cse = true, array = true)
1242-
# HACK 1
1243-
# mapping of rhs to temporary CSE variable
1244-
# `f(...) => tmpvar` in above example
1245-
rhs_to_tempvar = Dict()
1246-
1247-
# HACK 2
1233+
function tearing_hacks(sys, obs, unknowns, neweqs; array = true)
12481234
# map of array observed variable (unscalarized) to number of its
12491235
# scalarized terms that appear in observed equations
12501236
arr_obs_occurrences = Dict()
12511237
for (i, eq) in enumerate(obs)
12521238
lhs = eq.lhs
12531239
rhs = eq.rhs
12541240

1255-
# HACK 1
1256-
if cse && is_getindexed_array(rhs)
1257-
rhs_arr = arguments(rhs)[1]
1258-
iscall(rhs_arr) && operation(rhs_arr) isa Symbolics.Operator && continue
1259-
if !haskey(rhs_to_tempvar, rhs_arr)
1260-
tempvar = gensym(Symbol(lhs))
1261-
N = length(rhs_arr)
1262-
tempvar = unwrap(Symbolics.variable(
1263-
tempvar; T = Symbolics.symtype(rhs_arr)))
1264-
tempvar = setmetadata(
1265-
tempvar, Symbolics.ArrayShapeCtx, Symbolics.shape(rhs_arr))
1266-
tempeq = tempvar ~ rhs_arr
1267-
rhs_to_tempvar[rhs_arr] = tempvar
1268-
push!(obs, tempeq)
1269-
end
1270-
1271-
# getindex_wrapper is used because `observed2graph` treats `x` and `x[i]` as different,
1272-
# so it doesn't find a dependency between this equation and `tempvar ~ rhs_arr`
1273-
# which fails the topological sort
1274-
neweq = lhs ~ getindex_wrapper(
1275-
rhs_to_tempvar[rhs_arr], Tuple(arguments(rhs)[2:end]))
1276-
obs[i] = neweq
1277-
end
1278-
# end HACK 1
1279-
12801241
array || continue
12811242
iscall(lhs) || continue
12821243
operation(lhs) === getindex || continue
@@ -1287,31 +1248,6 @@ function cse_and_array_hacks(sys, obs, unknowns, neweqs; cse = true, array = tru
12871248
continue
12881249
end
12891250

1290-
# Also do CSE for `equations(sys)`
1291-
if cse
1292-
for (i, eq) in enumerate(neweqs)
1293-
(; lhs, rhs) = eq
1294-
is_getindexed_array(rhs) || continue
1295-
rhs_arr = arguments(rhs)[1]
1296-
if !haskey(rhs_to_tempvar, rhs_arr)
1297-
tempvar = gensym(Symbol(lhs))
1298-
N = length(rhs_arr)
1299-
tempvar = unwrap(Symbolics.variable(
1300-
tempvar; T = Symbolics.symtype(rhs_arr)))
1301-
tempvar = setmetadata(
1302-
tempvar, Symbolics.ArrayShapeCtx, Symbolics.shape(rhs_arr))
1303-
tempeq = tempvar ~ rhs_arr
1304-
rhs_to_tempvar[rhs_arr] = tempvar
1305-
push!(obs, tempeq)
1306-
end
1307-
# don't need getindex_wrapper, but do it anyway to know that this
1308-
# hack took place
1309-
neweq = lhs ~ getindex_wrapper(
1310-
rhs_to_tempvar[rhs_arr], Tuple(arguments(rhs)[2:end]))
1311-
neweqs[i] = neweq
1312-
end
1313-
end
1314-
13151251
# count variables in unknowns if they are scalarized forms of variables
13161252
# also present as observed. e.g. if `x[1]` is an unknown and `x[2] ~ (..)`
13171253
# is an observed equation.
@@ -1346,18 +1282,7 @@ function cse_and_array_hacks(sys, obs, unknowns, neweqs; cse = true, array = tru
13461282
return obs
13471283
end
13481284

1349-
function is_getindexed_array(rhs)
1350-
(!ModelingToolkit.isvariable(rhs) || ModelingToolkit.iscalledparameter(rhs)) &&
1351-
iscall(rhs) && operation(rhs) === getindex &&
1352-
Symbolics.shape(rhs) != Symbolics.Unknown()
1353-
end
1354-
1355-
# PART OF HACK 1
1356-
getindex_wrapper(x, i) = x[i...]
1357-
1358-
@register_symbolic getindex_wrapper(x::AbstractArray, i::Tuple{Vararg{Int}})
1359-
1360-
# PART OF HACK 2
1285+
# PART OF HACK
13611286
function change_origin(origin, arr)
13621287
if all(isone, Tuple(origin))
13631288
return arr
@@ -1385,11 +1310,11 @@ new residual equations after tearing. End users are encouraged to call [`mtkcomp
13851310
instead, which calls this function internally.
13861311
"""
13871312
function tearing(sys::AbstractSystem, state = TearingState(sys); mm = nothing,
1388-
simplify = false, cse_hack = true, array_hack = true, fully_determined = true, kwargs...)
1313+
simplify = false, array_hack = true, fully_determined = true, kwargs...)
13891314
var_eq_matching, full_var_eq_matching, var_sccs, can_eliminate = tearing(state)
13901315
invalidate_cache!(tearing_reassemble(
13911316
state, var_eq_matching, full_var_eq_matching, var_sccs; mm,
1392-
simplify, cse_hack, array_hack, fully_determined))
1317+
simplify, array_hack, fully_determined))
13931318
end
13941319

13951320
"""
@@ -1399,7 +1324,7 @@ Perform index reduction and use the dummy derivative technique to ensure that
13991324
the system is balanced.
14001325
"""
14011326
function dummy_derivative(sys, state = TearingState(sys); simplify = false,
1402-
mm = nothing, cse_hack = true, array_hack = true, fully_determined = true, kwargs...)
1327+
mm = nothing, array_hack = true, fully_determined = true, kwargs...)
14031328
jac = let state = state
14041329
(eqs, vars) -> begin
14051330
symeqs = EquationsView(state)[eqs]
@@ -1425,5 +1350,5 @@ function dummy_derivative(sys, state = TearingState(sys); simplify = false,
14251350
state, jac; state_priority,
14261351
kwargs...)
14271352
tearing_reassemble(state, var_eq_matching, full_var_eq_matching, var_sccs;
1428-
simplify, mm, cse_hack, array_hack, fully_determined)
1353+
simplify, mm, array_hack, fully_determined)
14291354
end

src/systems/nonlinear/initializesystem.jl

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -780,20 +780,6 @@ function unhack_observed(obseqs::Vector{Equation}, eqs::Vector{Equation})
780780
push!(rm_idxs, i)
781781
continue
782782
end
783-
if operation(eq.rhs) == StructuralTransformations.getindex_wrapper
784-
var, idxs = arguments(eq.rhs)
785-
subs[eq.rhs] = var[idxs...]
786-
push!(tempvars, var)
787-
end
788-
end
789-
790-
for (i, eq) in enumerate(eqs)
791-
iscall(eq.rhs) || continue
792-
if operation(eq.rhs) == StructuralTransformations.getindex_wrapper
793-
var, idxs = arguments(eq.rhs)
794-
subs[eq.rhs] = var[idxs...]
795-
push!(tempvars, var)
796-
end
797783
end
798784

799785
for (i, eq) in enumerate(obseqs)

test/code_generation.jl

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,3 +79,34 @@ end
7979
@test SciMLBase.successful_retcode(sol)
8080
end
8181
end
82+
83+
@testset "scalarized array observed calling same function multiple times" begin
84+
@variables x(t) y(t)[1:2]
85+
@parameters foo(::Real)[1:2]
86+
val = Ref(0)
87+
function _tmp_fn2(x)
88+
val[] += 1
89+
return [x, 2x]
90+
end
91+
@mtkcompile sys = System([D(x) ~ y[1] + y[2], y ~ foo(x)], t)
92+
@test length(equations(sys)) == 1
93+
@test length(ModelingToolkit.observed(sys)) == 3
94+
prob = ODEProblem(sys, [x => 1.0, foo => _tmp_fn2], (0.0, 1.0))
95+
val[] = 0
96+
@test_nowarn prob.f(prob.u0, prob.p, 0.0)
97+
@test val[] == 1
98+
99+
@testset "CSE in equations(sys)" begin
100+
val[] = 0
101+
@variables z(t)[1:2]
102+
@mtkcompile sys = System(
103+
[D(y) ~ foo(x), D(x) ~ sum(y), zeros(2) ~ foo(prod(z))], t)
104+
@test length(equations(sys)) == 5
105+
@test length(ModelingToolkit.observed(sys)) == 0
106+
prob = ODEProblem(
107+
sys, [y => ones(2), z => 2ones(2), x => 3.0, foo => _tmp_fn2], (0.0, 1.0))
108+
val[] = 0
109+
@test_nowarn prob.f(prob.u0, prob.p, 0.0)
110+
@test val[] == 2
111+
end
112+
end

test/initializationsystem.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1643,3 +1643,24 @@ end
16431643
@test !SciMLBase.isinplace(prob)
16441644
@test !SciMLBase.isinplace(prob.f.initialization_data.initializeprob)
16451645
end
1646+
1647+
@testset "Array unknowns occurring unscalarized in initializeprobpmap" begin
1648+
@variables begin
1649+
u(t)[1:2] = 0.9ones(2)
1650+
x(t)[1:2], [guess = 0.01ones(2)]
1651+
o(t)[1:2]
1652+
end
1653+
@parameters p[1:4] = [2.0, 1.875, 2.0, 1.875]
1654+
1655+
eqs = [D(u[1]) ~ p[1] * u[1] - p[2] * u[1] * u[2] + x[1] + 0.1
1656+
D(u[2]) ~ p[4] * u[1] * u[2] - p[3] * u[2] - x[2]
1657+
o[1] ~ sum(p) * sum(u)
1658+
o[2] ~ sum(p) * sum(x)
1659+
x[1] ~ 0.01exp(-1)
1660+
x[2] ~ 0.01cos(t)]
1661+
1662+
@mtkbuild sys = ODESystem(eqs, t)
1663+
prob = ODEProblem(sys, [], (0.0, 1.0))
1664+
sol = solve(prob, Tsit5())
1665+
@test SciMLBase.successful_retcode(sol)
1666+
end

test/structural_transformation/utils.jl

Lines changed: 5 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ end
5252
@mtkcompile sys = System(
5353
[D(x) ~ z[1] + z[2] + foo(z)[1], y[1] ~ 2t, y[2] ~ 3t, z ~ foo(y)], t)
5454
@test length(equations(sys)) == 1
55-
@test length(observed(sys)) == 7
55+
@test length(observed(sys)) == 6
5656
@test any(obs -> isequal(obs, y), observables(sys))
5757
@test any(obs -> isequal(obs, z), observables(sys))
5858
prob = ODEProblem(sys, [x => 1.0, foo => _tmp_fn], (0.0, 1.0))
@@ -62,76 +62,20 @@ end
6262
@test length(unknowns(isys)) == 5
6363
@test length(equations(isys)) == 4
6464
@test !any(equations(isys)) do eq
65-
iscall(eq.rhs) && operation(eq.rhs) in [StructuralTransformations.getindex_wrapper,
66-
StructuralTransformations.change_origin]
65+
iscall(eq.rhs) && operation(eq.rhs) in [StructuralTransformations.change_origin]
6766
end
6867
end
6968

70-
@testset "scalarized array observed calling same function multiple times" begin
71-
@variables x(t) y(t)[1:2]
72-
@parameters foo(::Real)[1:2]
73-
val = Ref(0)
74-
function _tmp_fn2(x)
75-
val[] += 1
76-
return [x, 2x]
77-
end
78-
@mtkcompile sys = System([D(x) ~ y[1] + y[2], y ~ foo(x)], t)
79-
@test length(equations(sys)) == 1
80-
@test length(observed(sys)) == 4
81-
prob = ODEProblem(sys, [x => 1.0, foo => _tmp_fn2], (0.0, 1.0))
82-
val[] = 0
83-
@test_nowarn prob.f(prob.u0, prob.p, 0.0)
84-
@test val[] == 1
85-
86-
isys = ModelingToolkit.generate_initializesystem(sys)
87-
@test length(unknowns(isys)) == 3
88-
@test length(equations(isys)) == 2
89-
@test !any(equations(isys)) do eq
90-
iscall(eq.rhs) && operation(eq.rhs) in [StructuralTransformations.getindex_wrapper,
91-
StructuralTransformations.change_origin]
92-
end
93-
94-
@testset "CSE hack in equations(sys)" begin
95-
val[] = 0
96-
@variables z(t)[1:2]
97-
@mtkcompile sys = System(
98-
[D(y) ~ foo(x), D(x) ~ sum(y), zeros(2) ~ foo(prod(z))], t)
99-
@test length(equations(sys)) == 5
100-
@test length(observed(sys)) == 2
101-
prob = ODEProblem(
102-
sys, [y => ones(2), z => 2ones(2), x => 3.0, foo => _tmp_fn2], (0.0, 1.0))
103-
val[] = 0
104-
@test_nowarn prob.f(prob.u0, prob.p, 0.0)
105-
@test val[] == 2
106-
107-
isys = ModelingToolkit.generate_initializesystem(sys)
108-
@test length(unknowns(isys)) == 5
109-
@test length(equations(isys)) == 2
110-
@test !any(equations(isys)) do eq
111-
iscall(eq.rhs) &&
112-
operation(eq.rhs) in [StructuralTransformations.getindex_wrapper,
113-
StructuralTransformations.change_origin]
114-
end
115-
end
116-
end
117-
118-
@testset "array and cse hacks can be disabled" begin
69+
@testset "array hack can be disabled" begin
11970
@testset "fully_determined = true" begin
12071
@variables x(t) y(t)[1:2] z(t)[1:2]
12172
@parameters foo(::AbstractVector)[1:2]
12273
_tmp_fn(x) = 2x
12374
@named sys = System(
12475
[D(x) ~ z[1] + z[2] + foo(z)[1], y[1] ~ 2t, y[2] ~ 3t, z ~ foo(y)], t)
12576

126-
sys1 = mtkcompile(sys; cse_hack = false)
127-
@test length(observed(sys1)) == 6
128-
@test !any(observed(sys1)) do eq
129-
iscall(eq.rhs) &&
130-
operation(eq.rhs) == StructuralTransformations.getindex_wrapper
131-
end
132-
13377
sys2 = mtkcompile(sys; array_hack = false)
134-
@test length(observed(sys2)) == 5
78+
@test length(observed(sys2)) == 4
13579
@test !any(observed(sys2)) do eq
13680
iscall(eq.rhs) && operation(eq.rhs) == StructuralTransformations.change_origin
13781
end
@@ -144,15 +88,8 @@ end
14488
@named sys = System(
14589
[D(x) ~ z[1] + z[2] + foo(z)[1] + w, y[1] ~ 2t, y[2] ~ 3t, z ~ foo(y)], t)
14690

147-
sys1 = mtkcompile(sys; cse_hack = false, fully_determined = false)
148-
@test length(observed(sys1)) == 6
149-
@test !any(observed(sys1)) do eq
150-
iscall(eq.rhs) &&
151-
operation(eq.rhs) == StructuralTransformations.getindex_wrapper
152-
end
153-
15491
sys2 = mtkcompile(sys; array_hack = false, fully_determined = false)
155-
@test length(observed(sys2)) == 5
92+
@test length(observed(sys2)) == 4
15693
@test !any(observed(sys2)) do eq
15794
iscall(eq.rhs) && operation(eq.rhs) == StructuralTransformations.change_origin
15895
end

0 commit comments

Comments
 (0)