42
42
from pytensor .tensor .exceptions import NotScalarConstantError
43
43
from pytensor .tensor .extra_ops import broadcast_arrays
44
44
from pytensor .tensor .math import (
45
- All ,
46
- Any ,
47
45
Dot ,
48
- FixedOpCAReduce ,
49
- NonZeroDimsCAReduce ,
50
46
Prod ,
51
- ProdWithoutZeros ,
52
47
Sum ,
53
48
_conj ,
54
49
add ,
@@ -1621,22 +1616,8 @@ def local_op_of_op(fgraph, node):
1621
1616
return [combined (node_inps .owner .inputs [0 ])]
1622
1617
1623
1618
1624
- ALL_REDUCE = [
1625
- CAReduce ,
1626
- All ,
1627
- Any ,
1628
- Sum ,
1629
- Prod ,
1630
- ProdWithoutZeros ,
1631
- * CAReduce .__subclasses__ (),
1632
- * FixedOpCAReduce .__subclasses__ (),
1633
- * NonZeroDimsCAReduce .__subclasses__ (),
1634
- ]
1635
-
1636
-
1637
1619
@register_canonicalize
1638
- @register_uncanonicalize # Needed for MaxAndArgmax -> CAReduce
1639
- @node_rewriter (ALL_REDUCE )
1620
+ @node_rewriter ([CAReduce ])
1640
1621
def local_reduce_join (fgraph , node ):
1641
1622
"""
1642
1623
CAReduce{scalar.op}(Join(axis=0, a, b), axis=0) -> Elemwise{scalar.op}(a, b)
@@ -1706,7 +1687,7 @@ def local_reduce_join(fgraph, node):
1706
1687
@register_infer_shape
1707
1688
@register_canonicalize ("fast_compile" , "local_cut_useless_reduce" )
1708
1689
@register_useless ("local_cut_useless_reduce" )
1709
- @node_rewriter (ALL_REDUCE )
1690
+ @node_rewriter ([ CAReduce ] )
1710
1691
def local_useless_reduce (fgraph , node ):
1711
1692
"""Sum(a, axis=[]) -> a"""
1712
1693
(summed ,) = node .inputs
@@ -1718,7 +1699,7 @@ def local_useless_reduce(fgraph, node):
1718
1699
@register_canonicalize
1719
1700
@register_uncanonicalize
1720
1701
@register_specialize
1721
- @node_rewriter (ALL_REDUCE )
1702
+ @node_rewriter ([ CAReduce ] )
1722
1703
def local_reduce_broadcastable (fgraph , node ):
1723
1704
"""Remove reduction over broadcastable dimensions."""
1724
1705
(reduced ,) = node .inputs
0 commit comments