Skip to content

Commit c0b90b4

Browse files
committed
Remove useless ALL_REDUCE list
1 parent c49c0ef commit c0b90b4

File tree

1 file changed

+3
-22
lines changed

1 file changed

+3
-22
lines changed

pytensor/tensor/rewriting/math.py

Lines changed: 3 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -42,13 +42,8 @@
4242
from pytensor.tensor.exceptions import NotScalarConstantError
4343
from pytensor.tensor.extra_ops import broadcast_arrays
4444
from pytensor.tensor.math import (
45-
All,
46-
Any,
4745
Dot,
48-
FixedOpCAReduce,
49-
NonZeroDimsCAReduce,
5046
Prod,
51-
ProdWithoutZeros,
5247
Sum,
5348
_conj,
5449
add,
@@ -1621,22 +1616,8 @@ def local_op_of_op(fgraph, node):
16211616
return [combined(node_inps.owner.inputs[0])]
16221617

16231618

1624-
ALL_REDUCE = [
1625-
CAReduce,
1626-
All,
1627-
Any,
1628-
Sum,
1629-
Prod,
1630-
ProdWithoutZeros,
1631-
*CAReduce.__subclasses__(),
1632-
*FixedOpCAReduce.__subclasses__(),
1633-
*NonZeroDimsCAReduce.__subclasses__(),
1634-
]
1635-
1636-
16371619
@register_canonicalize
1638-
@register_uncanonicalize # Needed for MaxAndArgmax -> CAReduce
1639-
@node_rewriter(ALL_REDUCE)
1620+
@node_rewriter([CAReduce])
16401621
def local_reduce_join(fgraph, node):
16411622
"""
16421623
CAReduce{scalar.op}(Join(axis=0, a, b), axis=0) -> Elemwise{scalar.op}(a, b)
@@ -1706,7 +1687,7 @@ def local_reduce_join(fgraph, node):
17061687
@register_infer_shape
17071688
@register_canonicalize("fast_compile", "local_cut_useless_reduce")
17081689
@register_useless("local_cut_useless_reduce")
1709-
@node_rewriter(ALL_REDUCE)
1690+
@node_rewriter([CAReduce])
17101691
def local_useless_reduce(fgraph, node):
17111692
"""Sum(a, axis=[]) -> a"""
17121693
(summed,) = node.inputs
@@ -1718,7 +1699,7 @@ def local_useless_reduce(fgraph, node):
17181699
@register_canonicalize
17191700
@register_uncanonicalize
17201701
@register_specialize
1721-
@node_rewriter(ALL_REDUCE)
1702+
@node_rewriter([CAReduce])
17221703
def local_reduce_broadcastable(fgraph, node):
17231704
"""Remove reduction over broadcastable dimensions."""
17241705
(reduced,) = node.inputs

0 commit comments

Comments
 (0)