Skip to content

Commit c4f99c6

Browse files
committed
Merge rewrite for sum/prod of div with that of mul
1 parent 35d4b60 commit c4f99c6

File tree

3 files changed

+188
-270
lines changed

3 files changed

+188
-270
lines changed

pytensor/tensor/rewriting/math.py

Lines changed: 47 additions & 119 deletions
Original file line numberDiff line numberDiff line change
@@ -1189,52 +1189,69 @@ def local_neg_to_mul(fgraph, node):
11891189

11901190
@register_specialize
11911191
@node_rewriter([Sum, Prod])
1192-
def local_sum_prod_of_mul(fgraph, node):
1192+
def local_sum_prod_of_mul_or_div(fgraph, node):
11931193
"""
11941194
sum(a * X) -> a * sum(X), when a is broadcasted along the sum dimensions
11951195
11961196
or
11971197
11981198
prod(a * X) -> (a ** size(X)) * prod(X)
11991199
1200+
It also applies to reduction of X / a,
1201+
but not a / X, as that would still require inverting every value in X before the reduction
1202+
12001203
TODO: In the case where not all axis overlap with broadcast dimensions,
12011204
consider introducing an outer reduction after factoring out the compatible reduced dimensions
12021205
E.g. sum(arange(5) * X, axis=(0, 2)) -> sum(sum(X, axis=0) * arange(5), axis=1)
12031206
"""
1204-
# TODO: if the the thing inside the Sum is a division,
1205-
# we should get at the numerator....
12061207

12071208
[node_inps] = node.inputs
1208-
if not (node_inps.owner and node_inps.owner.op == mul):
1209+
if not node_inps.owner:
1210+
return None
1211+
1212+
inner_op = node_inps.owner.op
1213+
if not (inner_op == mul or inner_op == true_div):
12091214
return None
12101215

12111216
reduced_axes = node.op.axis
12121217
if reduced_axes is None:
12131218
reduced_axes = tuple(range(node_inps.type.ndim))
12141219

12151220
# Separate terms that can be moved out of the Sum/Prod and those that cannot
1216-
outer_terms = []
1217-
inner_terms = []
1218-
for term in node_inps.owner.inputs:
1219-
term_bcast = term.type.broadcastable
1220-
if all(term_bcast[i] for i in reduced_axes):
1221-
outer_terms.append(term.squeeze(reduced_axes))
1222-
else:
1223-
inner_terms.append(term)
1221+
if inner_op == mul:
1222+
# Mul accepts arbitrary inputs, so we need to separate into two groups
1223+
outer_terms = []
1224+
inner_terms = []
1225+
for term in node_inps.owner.inputs:
1226+
term_bcast = term.type.broadcastable
1227+
if all(term_bcast[i] for i in reduced_axes):
1228+
outer_terms.append(term.squeeze(reduced_axes))
1229+
else:
1230+
inner_terms.append(term)
12241231

1225-
if not outer_terms:
1226-
return None
1227-
elif len(outer_terms) == 1:
1228-
[outer_term] = outer_terms
1229-
else:
1230-
outer_term = mul(*outer_terms)
1232+
if not outer_terms:
1233+
return None
1234+
elif len(outer_terms) == 1:
1235+
[outer_term] = outer_terms
1236+
else:
1237+
outer_term = mul(*outer_terms)
12311238

1232-
if not inner_terms:
1233-
inner_term = None
1234-
elif len(inner_terms) == 1:
1235-
[inner_term] = inner_terms
1236-
else:
1237-
inner_term = mul(*inner_terms)
1239+
if not inner_terms:
1240+
inner_term = None
1241+
elif len(inner_terms) == 1:
1242+
[inner_term] = inner_terms
1243+
else:
1244+
inner_term = mul(*inner_terms)
1245+
1246+
else: # true_div
1247+
# We only care about removing the denominator out of the reduction
1248+
numerator, denominator = node_inps.owner.inputs
1249+
denominator_bcast = denominator.type.broadcastable
1250+
if all(denominator_bcast[i] for i in reduced_axes):
1251+
outer_term = denominator.squeeze(reduced_axes)
1252+
inner_term = numerator
1253+
else:
1254+
return None
12381255

12391256
# If we have a `Prod`, then the outside terms need to be raised to the power of the number of elements
12401257
# that were contracted in the input
@@ -1245,12 +1262,16 @@ def local_sum_prod_of_mul(fgraph, node):
12451262
)
12461263
outer_term = outer_term**n_reduced_elements
12471264

1248-
# Sum/Prod is useless, just return the outer_term
12491265
if not inner_term:
1266+
# Sum/Prod is useless, just return the outer_term
1267+
# (This can only happen for mul, not division)
12501268
new_out = outer_term
12511269
else:
12521270
reduced_inner_term = node.op(inner_term)
1253-
new_out = outer_term * reduced_inner_term
1271+
if inner_op == mul:
1272+
new_out = outer_term * reduced_inner_term
1273+
else:
1274+
new_out = reduced_inner_term / outer_term
12541275
copy_stack_trace(node.outputs, [inner_term, reduced_inner_term, outer_term])
12551276

12561277
copy_stack_trace(node.outputs, new_out)
@@ -1509,99 +1530,6 @@ def investigate(node):
15091530
return
15101531

15111532

1512-
@register_canonicalize
1513-
@register_specialize
1514-
@node_rewriter([Sum, Prod])
1515-
def local_sum_prod_div_dimshuffle(fgraph, node):
1516-
"""
1517-
sum(a / dimshuffle{...}(b), axis=l) -> sum(a, axis={...}) / b,
1518-
if dimension l of the DimShuffle is 'x'
1519-
1520-
or
1521-
1522-
prod(a / dimshuffle{...}(b), axis=l) ->
1523-
prod(a, axis={...}) / b ** a.shape[l],
1524-
if dimension l of the DimShuffle is 'x'
1525-
"""
1526-
1527-
# It does not make much sense now to extend it to the case where the
1528-
# dimshuffle is in the numerator, since elemwise inversion of the
1529-
# denominator would still be needed before the summation or production.
1530-
1531-
if isinstance(node.op, (Sum, Prod)):
1532-
axis = node.op.axis
1533-
if axis is None:
1534-
axis = list(range(node.inputs[0].ndim))
1535-
node_input = node.inputs[0]
1536-
if node_input.owner and node_input.owner.op == true_div:
1537-
numerator, denominator = node_input.owner.inputs
1538-
1539-
if denominator.owner and isinstance(denominator.owner.op, DimShuffle):
1540-
dimshuffle_input = denominator.owner.inputs[0]
1541-
dimshuffle_order = denominator.owner.op.new_order
1542-
1543-
compatible_dims = []
1544-
incompatible_dims = []
1545-
for ax in axis:
1546-
if ax < len(dimshuffle_order) and dimshuffle_order[ax] == "x":
1547-
compatible_dims.append(ax)
1548-
else:
1549-
incompatible_dims.append(ax)
1550-
reordered_incompatible_dims = []
1551-
for ic_ax in incompatible_dims:
1552-
reordered_incompatible_dims.append(
1553-
ic_ax - sum(1 for c_ax in compatible_dims if c_ax < ic_ax)
1554-
)
1555-
1556-
if len(compatible_dims) > 0:
1557-
optimized_dimshuffle_order = [
1558-
ax
1559-
for i, ax in enumerate(dimshuffle_order)
1560-
if (i not in axis) or (ax != "x")
1561-
]
1562-
1563-
# Removing leading 'x' (since it will be done automatically)
1564-
while (
1565-
len(optimized_dimshuffle_order) > 0
1566-
and optimized_dimshuffle_order[0] == "x"
1567-
):
1568-
del optimized_dimshuffle_order[0]
1569-
1570-
# if optimized_dimshuffle_order is sorted with
1571-
# not 'x', then dimshuffle is useless.
1572-
if all(i == e for i, e in enumerate(optimized_dimshuffle_order)):
1573-
optimized_dimshuffle = dimshuffle_input
1574-
else:
1575-
optimized_dimshuffle = DimShuffle(
1576-
dimshuffle_input.type.broadcastable,
1577-
optimized_dimshuffle_order,
1578-
)(dimshuffle_input)
1579-
1580-
if isinstance(node.op, Sum):
1581-
op_on_compatible_dims = at_sum(numerator, axis=compatible_dims)
1582-
rval = true_div(op_on_compatible_dims, optimized_dimshuffle)
1583-
if len(reordered_incompatible_dims) > 0:
1584-
rval = at_sum(rval, axis=reordered_incompatible_dims)
1585-
elif isinstance(node.op, Prod):
1586-
op_on_compatible_dims = prod(numerator, axis=compatible_dims)
1587-
dtype = numerator.dtype
1588-
rval = true_div(
1589-
op_on_compatible_dims,
1590-
(
1591-
optimized_dimshuffle
1592-
** prod(
1593-
[
1594-
numerator.shape[ax].astype(dtype)
1595-
for ax in compatible_dims
1596-
]
1597-
)
1598-
),
1599-
)
1600-
if len(reordered_incompatible_dims) > 0:
1601-
rval = prod(rval, axis=reordered_incompatible_dims)
1602-
return [rval]
1603-
1604-
16051533
@register_canonicalize
16061534
@node_rewriter([Sum, Prod])
16071535
def local_sum_prod_all_to_none(fgraph, node):

tests/tensor/rewriting/test_elemwise.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -899,7 +899,7 @@ def large_fuseable_graph(self, n):
899899
),
900900
(fx, fy),
901901
(fxv, fyv),
902-
3,
902+
2,
903903
(
904904
np.sum(-((fxv - fyv) ** 2) / 2),
905905
-(fxv - fyv),

0 commit comments

Comments
 (0)