Skip to content

Commit 1e443d4

Browse files
committed
Remove useless ALL_REDUCE list
1 parent 31bf682 commit 1e443d4

File tree

1 file changed

+3
-21
lines changed

1 file changed

+3
-21
lines changed

pytensor/tensor/rewriting/math.py

Lines changed: 3 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -42,13 +42,8 @@
4242
from pytensor.tensor.exceptions import NotScalarConstantError
4343
from pytensor.tensor.extra_ops import broadcast_arrays
4444
from pytensor.tensor.math import (
45-
All,
46-
Any,
4745
Dot,
48-
FixedOpCAReduce,
49-
NonZeroDimsCAReduce,
5046
Prod,
51-
ProdWithoutZeros,
5247
Sum,
5348
_conj,
5449
add,
@@ -1621,22 +1616,9 @@ def local_op_of_op(fgraph, node):
16211616
return [combined(node_inps.owner.inputs[0])]
16221617

16231618

1624-
ALL_REDUCE = [
1625-
CAReduce,
1626-
All,
1627-
Any,
1628-
Sum,
1629-
Prod,
1630-
ProdWithoutZeros,
1631-
*CAReduce.__subclasses__(),
1632-
*FixedOpCAReduce.__subclasses__(),
1633-
*NonZeroDimsCAReduce.__subclasses__(),
1634-
]
1635-
1636-
16371619
@register_canonicalize
16381620
@register_uncanonicalize # Needed for MaxAndArgmax -> CAReduce
1639-
@node_rewriter(ALL_REDUCE)
1621+
@node_rewriter([CAReduce])
16401622
def local_reduce_join(fgraph, node):
16411623
"""
16421624
CAReduce{scalar.op}(Join(axis=0, a, b), axis=0) -> Elemwise{scalar.op}(a, b)
@@ -1706,7 +1688,7 @@ def local_reduce_join(fgraph, node):
17061688
@register_infer_shape
17071689
@register_canonicalize("fast_compile", "local_cut_useless_reduce")
17081690
@register_useless("local_cut_useless_reduce")
1709-
@node_rewriter(ALL_REDUCE)
1691+
@node_rewriter([CAReduce])
17101692
def local_useless_reduce(fgraph, node):
17111693
"""Sum(a, axis=[]) -> a"""
17121694
(summed,) = node.inputs
@@ -1718,7 +1700,7 @@ def local_useless_reduce(fgraph, node):
17181700
@register_canonicalize
17191701
@register_uncanonicalize
17201702
@register_specialize
1721-
@node_rewriter(ALL_REDUCE)
1703+
@node_rewriter([CAReduce])
17221704
def local_reduce_broadcastable(fgraph, node):
17231705
"""Remove reduction over broadcastable dimensions."""
17241706
(reduced,) = node.inputs

0 commit comments

Comments
 (0)