Skip to content

Commit ed85e69

Browse files
committed
Generalize and rename local_reduce_chain
1 parent b0cb903 commit ed85e69

File tree

2 files changed

+200
-181
lines changed

2 files changed

+200
-181
lines changed

pytensor/tensor/rewriting/math.py

Lines changed: 37 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1578,42 +1578,48 @@ def local_sum_prod_all_to_none(fgraph, node):
15781578

15791579

15801580
@register_canonicalize
1581-
@node_rewriter([Sum, Prod])
1582-
def local_op_of_op(fgraph, node):
1581+
@node_rewriter([CAReduce])
1582+
def local_reduce_chain(fgraph, node):
15831583
"""
1584-
Prod(Prod()) -> single Prod()
1585-
or
15861584
Sum(Sum()) -> single Sum()
1585+
or any CAReduce(CAReduce(x)) of the same type
15871586
15881587
"""
1589-
op_type = Sum if isinstance(node.op, Sum) else Prod
1590-
(node_inps,) = node.inputs
1591-
out_dtype = node.op.dtype
1592-
# This is done to make sure the rewrite doesn't affect other
1593-
# computations.
1594-
if len(fgraph.clients[node_inps]) == 1:
1595-
if node_inps.owner and (isinstance(node_inps.owner.op, node.op.__class__)):
1596-
# check to see either the inner or outer prod is doing a
1597-
# product over all axis, in which case we can remove it
1598-
if node_inps.owner.op.axis is None or node.op.axis is None:
1599-
return [op_type(None, dtype=out_dtype)(node_inps.owner.inputs[0])]
1600-
1601-
# figure out which axes were in the original sum
1602-
newaxis = list(node_inps.owner.op.axis)
1603-
for i in node.op.axis:
1604-
new_i = i
1605-
for ii in node_inps.owner.op.axis:
1606-
if new_i >= ii:
1607-
new_i += 1
1608-
assert new_i not in newaxis
1609-
newaxis.append(new_i)
1610-
1611-
assert len(newaxis) == len(
1612-
list(node_inps.owner.op.axis) + list(node.op.axis)
1613-
)
1588+
[inner_reduce] = node.inputs
1589+
if not (inner_reduce.owner and isinstance(inner_reduce.owner.op, CAReduce)):
1590+
return None
1591+
1592+
# Don't apply rewrite if inner_reduce is used elsewhere
1593+
if len(fgraph.clients[inner_reduce]) > 1:
1594+
return None
1595+
1596+
# Check if CAReduces have the same scalar op
1597+
outer_op: CAReduce = node.op
1598+
inner_op = inner_reduce.owner.op
1599+
1600+
if outer_op.scalar_op != inner_op.scalar_op:
1601+
return None
16141602

1615-
combined = op_type(newaxis, dtype=out_dtype)
1616-
return [combined(node_inps.owner.inputs[0])]
1603+
outer_axis = outer_op.axis
1604+
inner_axis = inner_op.axis
1605+
[x] = inner_reduce.owner.inputs
1606+
# check to see whether either the inner or outer reduce is doing a
1607+
# reduction over all axes, in which case we can remove it
1608+
if outer_axis is None or inner_axis is None:
1609+
return [outer_op.clone(axis=None)(x)]
1610+
1611+
# Merge axis
1612+
newaxis = list(inner_axis)
1613+
for i in outer_axis:
1614+
new_i = i
1615+
for ii in inner_axis:
1616+
if new_i >= ii:
1617+
new_i += 1
1618+
assert new_i not in newaxis
1619+
newaxis.append(new_i)
1620+
1621+
assert len(newaxis) == len(inner_axis) + len(outer_axis)
1622+
return [outer_op.clone(axis=sorted(newaxis))(x)]
16171623

16181624

16191625
@register_canonicalize

0 commit comments

Comments
 (0)