Skip to content

Commit 152a551

Browse files
committed
Merge rewrite for sum/prod of div with that of mul
1 parent efefa70 commit 152a551

File tree

3 files changed

+202
-275
lines changed

3 files changed

+202
-275
lines changed

pytensor/tensor/rewriting/math.py

Lines changed: 47 additions & 119 deletions
Original file line numberDiff line numberDiff line change
@@ -1190,52 +1190,69 @@ def local_neg_to_mul(fgraph, node):
11901190

11911191
@register_specialize
11921192
@node_rewriter([Sum, Prod])
1193-
def local_sum_prod_of_mul(fgraph, node):
1193+
def local_sum_prod_of_mul_or_div(fgraph, node):
11941194
"""
11951195
sum(a * X) -> a * sum(X), when a is broadcasted along the sum dimensions
11961196
11971197
or
11981198
11991199
prod(a * X) -> (a ** size(X)) * prod(X)
12001200
1201+
It also applies to reduction of X / a,
1202+
but not a / X, as that would still require inverting every value in X before the reduction
1203+
12011204
TODO: In the case where not all axes overlap with broadcast dimensions,
12021205
consider introducing an outer reduction after factoring out the compatible reduced dimensions
12031206
E.g. sum(arange(5) * X, axis=(0, 2)) -> sum(sum(X, axis=0) * arange(5), axis=1)
12041207
"""
1205-
# TODO: if the thing inside the Sum is a division,
1206-
# we should get at the numerator....
12071208

12081209
[node_inps] = node.inputs
1209-
if not (node_inps.owner and node_inps.owner.op == mul):
1210+
if not node_inps.owner:
1211+
return None
1212+
1213+
inner_op = node_inps.owner.op
1214+
if not (inner_op == mul or inner_op == true_div):
12101215
return None
12111216

12121217
reduced_axes = node.op.axis
12131218
if reduced_axes is None:
12141219
reduced_axes = tuple(range(node_inps.type.ndim))
12151220

12161221
# Separate terms that can be moved out of the Sum/Prod and those that cannot
1217-
outer_terms = []
1218-
inner_terms = []
1219-
for term in node_inps.owner.inputs:
1220-
term_bcast = term.type.broadcastable
1221-
if all(term_bcast[i] for i in reduced_axes):
1222-
outer_terms.append(term.squeeze(reduced_axes))
1223-
else:
1224-
inner_terms.append(term)
1222+
if inner_op == mul:
1223+
# Mul accepts arbitrary inputs, so we need to separate into two groups
1224+
outer_terms = []
1225+
inner_terms = []
1226+
for term in node_inps.owner.inputs:
1227+
term_bcast = term.type.broadcastable
1228+
if all(term_bcast[i] for i in reduced_axes):
1229+
outer_terms.append(term.squeeze(reduced_axes))
1230+
else:
1231+
inner_terms.append(term)
12251232

1226-
if not outer_terms:
1227-
return None
1228-
elif len(outer_terms) == 1:
1229-
[outer_term] = outer_terms
1230-
else:
1231-
outer_term = mul(*outer_terms)
1233+
if not outer_terms:
1234+
return None
1235+
elif len(outer_terms) == 1:
1236+
[outer_term] = outer_terms
1237+
else:
1238+
outer_term = mul(*outer_terms)
12321239

1233-
if not inner_terms:
1234-
inner_term = None
1235-
elif len(inner_terms) == 1:
1236-
[inner_term] = inner_terms
1237-
else:
1238-
inner_term = mul(*inner_terms)
1240+
if not inner_terms:
1241+
inner_term = None
1242+
elif len(inner_terms) == 1:
1243+
[inner_term] = inner_terms
1244+
else:
1245+
inner_term = mul(*inner_terms)
1246+
1247+
else: # true_div
1248+
# We only care about removing the denominator out of the reduction
1249+
numerator, denominator = node_inps.owner.inputs
1250+
denominator_bcast = denominator.type.broadcastable
1251+
if all(denominator_bcast[i] for i in reduced_axes):
1252+
outer_term = denominator.squeeze(reduced_axes)
1253+
inner_term = numerator
1254+
else:
1255+
return None
12391256

12401257
# If we have a `Prod`, then the outside terms need to be raised to the power of the number of elements
12411258
# that were contracted in the input
@@ -1246,12 +1263,16 @@ def local_sum_prod_of_mul(fgraph, node):
12461263
)
12471264
outer_term = outer_term**n_reduced_elements
12481265

1249-
# Sum/Prod is useless, just return the outer_term
12501266
if not inner_term:
1267+
# Sum/Prod is useless, just return the outer_term
1268+
# (This can only happen for mul, not division)
12511269
new_out = outer_term
12521270
else:
12531271
reduced_inner_term = node.op(inner_term)
1254-
new_out = outer_term * reduced_inner_term
1272+
if inner_op == mul:
1273+
new_out = outer_term * reduced_inner_term
1274+
else:
1275+
new_out = reduced_inner_term / outer_term
12551276
copy_stack_trace(node.outputs, [inner_term, reduced_inner_term, outer_term])
12561277

12571278
copy_stack_trace(node.outputs, new_out)
@@ -1510,99 +1531,6 @@ def investigate(node):
15101531
return
15111532

15121533

1513-
@register_canonicalize
1514-
@register_specialize
1515-
@node_rewriter([Sum, Prod])
1516-
def local_sum_prod_div_dimshuffle(fgraph, node):
1517-
"""
1518-
sum(a / dimshuffle{...}(b), axis=l) -> sum(a, axis={...}) / b,
1519-
if dimension l of the DimShuffle is 'x'
1520-
1521-
or
1522-
1523-
prod(a / dimshuffle{...}(b), axis=l) ->
1524-
prod(a, axis={...}) / b ** a.shape[l],
1525-
if dimension l of the DimShuffle is 'x'
1526-
"""
1527-
1528-
# It does not make much sense now to extend it to the case where the
1529-
# dimshuffle is in the numerator, since elemwise inversion of the
1530-
# denominator would still be needed before the summation or product.
1531-
1532-
if isinstance(node.op, (Sum, Prod)):
1533-
axis = node.op.axis
1534-
if axis is None:
1535-
axis = list(range(node.inputs[0].ndim))
1536-
node_input = node.inputs[0]
1537-
if node_input.owner and node_input.owner.op == true_div:
1538-
numerator, denominator = node_input.owner.inputs
1539-
1540-
if denominator.owner and isinstance(denominator.owner.op, DimShuffle):
1541-
dimshuffle_input = denominator.owner.inputs[0]
1542-
dimshuffle_order = denominator.owner.op.new_order
1543-
1544-
compatible_dims = []
1545-
incompatible_dims = []
1546-
for ax in axis:
1547-
if ax < len(dimshuffle_order) and dimshuffle_order[ax] == "x":
1548-
compatible_dims.append(ax)
1549-
else:
1550-
incompatible_dims.append(ax)
1551-
reordered_incompatible_dims = []
1552-
for ic_ax in incompatible_dims:
1553-
reordered_incompatible_dims.append(
1554-
ic_ax - sum(1 for c_ax in compatible_dims if c_ax < ic_ax)
1555-
)
1556-
1557-
if len(compatible_dims) > 0:
1558-
optimized_dimshuffle_order = [
1559-
ax
1560-
for i, ax in enumerate(dimshuffle_order)
1561-
if (i not in axis) or (ax != "x")
1562-
]
1563-
1564-
# Removing leading 'x' (since it will be done automatically)
1565-
while (
1566-
len(optimized_dimshuffle_order) > 0
1567-
and optimized_dimshuffle_order[0] == "x"
1568-
):
1569-
del optimized_dimshuffle_order[0]
1570-
1571-
# If optimized_dimshuffle_order is the identity order (0, 1, 2, ...)
1572-
# with no 'x' entries, then the dimshuffle is useless.
1573-
if all(i == e for i, e in enumerate(optimized_dimshuffle_order)):
1574-
optimized_dimshuffle = dimshuffle_input
1575-
else:
1576-
optimized_dimshuffle = DimShuffle(
1577-
dimshuffle_input.type.broadcastable,
1578-
optimized_dimshuffle_order,
1579-
)(dimshuffle_input)
1580-
1581-
if isinstance(node.op, Sum):
1582-
op_on_compatible_dims = at_sum(numerator, axis=compatible_dims)
1583-
rval = true_div(op_on_compatible_dims, optimized_dimshuffle)
1584-
if len(reordered_incompatible_dims) > 0:
1585-
rval = at_sum(rval, axis=reordered_incompatible_dims)
1586-
elif isinstance(node.op, Prod):
1587-
op_on_compatible_dims = prod(numerator, axis=compatible_dims)
1588-
dtype = numerator.dtype
1589-
rval = true_div(
1590-
op_on_compatible_dims,
1591-
(
1592-
optimized_dimshuffle
1593-
** prod(
1594-
[
1595-
numerator.shape[ax].astype(dtype)
1596-
for ax in compatible_dims
1597-
]
1598-
)
1599-
),
1600-
)
1601-
if len(reordered_incompatible_dims) > 0:
1602-
rval = prod(rval, axis=reordered_incompatible_dims)
1603-
return [rval]
1604-
1605-
16061534
@register_canonicalize
16071535
@node_rewriter([Sum, Prod])
16081536
def local_sum_prod_all_to_none(fgraph, node):

tests/tensor/rewriting/test_elemwise.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -899,7 +899,7 @@ def large_fuseable_graph(self, n):
899899
),
900900
(fx, fy),
901901
(fxv, fyv),
902-
3,
902+
2,
903903
(
904904
np.sum(-((fxv - fyv) ** 2) / 2),
905905
-(fxv - fyv),

0 commit comments

Comments
 (0)