@@ -1578,42 +1578,48 @@ def local_sum_prod_all_to_none(fgraph, node):
1578
1578
1579
1579
1580
1580
@register_canonicalize
1581
- @node_rewriter ([Sum , Prod ])
1582
- def local_op_of_op (fgraph , node ):
1581
+ @node_rewriter ([CAReduce ])
1582
+ def local_reduce_chain (fgraph , node ):
1583
1583
"""
1584
- Prod(Prod()) -> single Prod()
1585
- or
1586
1584
Sum(Sum()) -> single Sum()
1585
+ or any CAReduce(CAReduce(x)) of the same type
1587
1586
1588
1587
"""
1589
- op_type = Sum if isinstance (node .op , Sum ) else Prod
1590
- (node_inps ,) = node .inputs
1591
- out_dtype = node .op .dtype
1592
- # This is done to make sure the rewrite doesn't affect other
1593
- # computations.
1594
- if len (fgraph .clients [node_inps ]) == 1 :
1595
- if node_inps .owner and (isinstance (node_inps .owner .op , node .op .__class__ )):
1596
- # check to see either the inner or outer prod is doing a
1597
- # product over all axis, in which case we can remove it
1598
- if node_inps .owner .op .axis is None or node .op .axis is None :
1599
- return [op_type (None , dtype = out_dtype )(node_inps .owner .inputs [0 ])]
1600
-
1601
- # figure out which axes were in the original sum
1602
- newaxis = list (node_inps .owner .op .axis )
1603
- for i in node .op .axis :
1604
- new_i = i
1605
- for ii in node_inps .owner .op .axis :
1606
- if new_i >= ii :
1607
- new_i += 1
1608
- assert new_i not in newaxis
1609
- newaxis .append (new_i )
1610
-
1611
- assert len (newaxis ) == len (
1612
- list (node_inps .owner .op .axis ) + list (node .op .axis )
1613
- )
1588
+ [inner_reduce ] = node .inputs
1589
+ if not (inner_reduce .owner and isinstance (inner_reduce .owner .op , CAReduce )):
1590
+ return None
1591
+
1592
+ # Don't apply rewrite if inner_reduce is used elsewhere
1593
+ if len (fgraph .clients [inner_reduce ]) > 1 :
1594
+ return None
1595
+
1596
+ # Check if CAReduces have the same scalar op
1597
+ outer_op : CAReduce = node .op
1598
+ inner_op = inner_reduce .owner .op
1599
+
1600
+ if outer_op .scalar_op != inner_op .scalar_op :
1601
+ return None
1614
1602
1615
- combined = op_type (newaxis , dtype = out_dtype )
1616
- return [combined (node_inps .owner .inputs [0 ])]
1603
+ outer_axis = outer_op .axis
1604
+ inner_axis = inner_op .axis
1605
+ [x ] = inner_reduce .owner .inputs
1606
+ # If either the inner or the outer reduction is over all axes,
1607
+ # the chain collapses to a single reduction over all axes of x
1608
+ if outer_axis is None or inner_axis is None :
1609
+ return [outer_op .clone (axis = None )(x )]
1610
+
1611
+ # Merge axes: shift each outer axis past the inner axes so it indexes into x
1612
+ newaxis = list (inner_axis )
1613
+ for i in outer_axis :
1614
+ new_i = i
1615
+ for ii in inner_axis :
1616
+ if new_i >= ii :
1617
+ new_i += 1
1618
+ assert new_i not in newaxis
1619
+ newaxis .append (new_i )
1620
+
1621
+ assert len (newaxis ) == len (inner_axis ) + len (outer_axis )
1622
+ return [outer_op .clone (axis = sorted (newaxis ))(x )]
1617
1623
1618
1624
1619
1625
@register_canonicalize
0 commit comments