Skip to content

Commit a1ae929

Browse files
committed
Remove useless ALL_REDUCE list
1 parent 8e0958a commit a1ae929

File tree

1 file changed

+3
-21
lines changed

1 file changed

+3
-21
lines changed

pytensor/tensor/rewriting/math.py

Lines changed: 3 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -42,13 +42,8 @@
4242
from pytensor.tensor.exceptions import NotScalarConstantError
4343
from pytensor.tensor.extra_ops import broadcast_arrays
4444
from pytensor.tensor.math import (
45-
All,
46-
Any,
4745
Dot,
48-
FixedOpCAReduce,
49-
NonZeroDimsCAReduce,
5046
Prod,
51-
ProdWithoutZeros,
5247
Sum,
5348
_conj,
5449
add,
@@ -1618,22 +1613,9 @@ def local_op_of_op(fgraph, node):
16181613
return [combined(node_inps.owner.inputs[0])]
16191614

16201615

1621-
ALL_REDUCE = [
1622-
CAReduce,
1623-
All,
1624-
Any,
1625-
Sum,
1626-
Prod,
1627-
ProdWithoutZeros,
1628-
*CAReduce.__subclasses__(),
1629-
*FixedOpCAReduce.__subclasses__(),
1630-
*NonZeroDimsCAReduce.__subclasses__(),
1631-
]
1632-
1633-
16341616
@register_canonicalize
16351617
@register_uncanonicalize # Needed for MaxAndArgmax -> CAReduce
1636-
@node_rewriter(ALL_REDUCE)
1618+
@node_rewriter([CAReduce])
16371619
def local_reduce_join(fgraph, node):
16381620
"""
16391621
CAReduce{scalar.op}(Join(axis=0, a, b), axis=0) -> Elemwise{scalar.op}(a, b)
@@ -1703,7 +1685,7 @@ def local_reduce_join(fgraph, node):
17031685
@register_infer_shape
17041686
@register_canonicalize("fast_compile", "local_cut_useless_reduce")
17051687
@register_useless("local_cut_useless_reduce")
1706-
@node_rewriter(ALL_REDUCE)
1688+
@node_rewriter([CAReduce])
17071689
def local_useless_reduce(fgraph, node):
17081690
"""Sum(a, axis=[]) -> a"""
17091691
(summed,) = node.inputs
@@ -1715,7 +1697,7 @@ def local_useless_reduce(fgraph, node):
17151697
@register_canonicalize
17161698
@register_uncanonicalize
17171699
@register_specialize
1718-
@node_rewriter(ALL_REDUCE)
1700+
@node_rewriter([CAReduce])
17191701
def local_reduce_broadcastable(fgraph, node):
17201702
"""Remove reduction over broadcastable dimensions."""
17211703
(reduced,) = node.inputs

0 commit comments

Comments
 (0)