 from pytensor.tensor.exceptions import NotScalarConstantError
 from pytensor.tensor.extra_ops import broadcast_arrays
 from pytensor.tensor.math import (
-    All,
-    Any,
     Dot,
-    FixedOpCAReduce,
-    NonZeroDimsCAReduce,
     Prod,
-    ProdWithoutZeros,
     Sum,
     _conj,
     add,
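The imports removed above (All, Any, ProdWithoutZeros, FixedOpCAReduce, NonZeroDimsCAReduce) were only needed to spell out the ALL_REDUCE list deleted below. A minimal sketch of why tracking the CAReduce base class alone should be enough, assuming the node rewriter matches tracked Op classes including their subclasses (the premise of this change):

from pytensor.tensor.elemwise import CAReduce
from pytensor.tensor.math import All, Any, Prod, ProdWithoutZeros, Sum

# Every concrete reduction Op derives from CAReduce, so a rewriter that
# tracks the base class also covers all of these subclasses.
assert all(
    issubclass(op, CAReduce) for op in (All, Any, Prod, ProdWithoutZeros, Sum)
)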
@@ -1618,22 +1613,9 @@ def local_op_of_op(fgraph, node):
         return [combined(node_inps.owner.inputs[0])]


-ALL_REDUCE = [
-    CAReduce,
-    All,
-    Any,
-    Sum,
-    Prod,
-    ProdWithoutZeros,
-    *CAReduce.__subclasses__(),
-    *FixedOpCAReduce.__subclasses__(),
-    *NonZeroDimsCAReduce.__subclasses__(),
-]
-
-
 @register_canonicalize
 @register_uncanonicalize  # Needed for MaxAndArgmax -> CAReduce
-@node_rewriter(ALL_REDUCE)
+@node_rewriter([CAReduce])
 def local_reduce_join(fgraph, node):
     """
     CAReduce{scalar.op}(Join(axis=0, a, b), axis=0) -> Elemwise{scalar.op}(a, b)
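As a quick illustration of the rewrite named in that docstring (a sketch, not part of the diff; variable names are made up), summing over the join axis of two stacked rows should compile down to a single elementwise add:

import pytensor
import pytensor.tensor as pt

a = pt.vector("a")
b = pt.vector("b")
# Join two length-1 rows along a new leading axis, then reduce over it:
# local_reduce_join can turn this into Elemwise{add}(a, b).
out = pt.sum(pt.join(0, a[None, :], b[None, :]), axis=0)
f = pytensor.function([a, b], out)
pytensor.dprint(f)  # expect an add of a and b, with no Join/Sum left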
@@ -1703,7 +1685,7 @@ def local_reduce_join(fgraph, node):
 @register_infer_shape
 @register_canonicalize("fast_compile", "local_cut_useless_reduce")
 @register_useless("local_cut_useless_reduce")
-@node_rewriter(ALL_REDUCE)
+@node_rewriter([CAReduce])
 def local_useless_reduce(fgraph, node):
     """Sum(a, axis=[]) -> a"""
     (summed,) = node.inputs
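For local_useless_reduce, a reduction over an empty axis list touches no dimensions, so the Op can be cut entirely. A minimal sketch of the intended effect, assuming pt.sum accepts an empty axis list as the docstring suggests:

import pytensor
import pytensor.tensor as pt

x = pt.matrix("x")
# Sum over no axes at all: the reduction is a no-op and
# local_useless_reduce should remove it from the graph.
out = pt.sum(x, axis=[])
f = pytensor.function([x], out)
pytensor.dprint(f)  # expect x passed through with no Sum node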
@@ -1715,7 +1697,7 @@ def local_useless_reduce(fgraph, node):
 @register_canonicalize
 @register_uncanonicalize
 @register_specialize
-@node_rewriter(ALL_REDUCE)
+@node_rewriter([CAReduce])
 def local_reduce_broadcastable(fgraph, node):
     """Remove reduction over broadcastable dimensions."""
     (reduced,) = node.inputs
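And for local_reduce_broadcastable, reducing over a length-1 (broadcastable) dimension carries no information, so the reduction can be replaced by simply dropping that dimension. A sketch of the expected behavior (names are illustrative):

import pytensor
import pytensor.tensor as pt

x = pt.row("x")  # shape (1, n); the leading dimension is broadcastable
# Summing over the broadcastable axis adds up a single element, so the
# rewrite should replace the Sum with a DimShuffle that drops axis 0.
out = pt.sum(x, axis=0)
f = pytensor.function([x], out)
pytensor.dprint(f)  # expect a DimShuffle instead of a Sum over axis 0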