@@ -1190,52 +1190,69 @@ def local_neg_to_mul(fgraph, node):
1190
1190
1191
1191
@register_specialize
1192
1192
@node_rewriter ([Sum , Prod ])
1193
- def local_sum_prod_of_mul (fgraph , node ):
1193
+ def local_sum_prod_of_mul_or_div (fgraph , node ):
1194
1194
"""
1195
1195
sum(a * X) -> a * sum(X), when a is broadcasted along the sum dimensions
1196
1196
1197
1197
or
1198
1198
1199
1199
prod(a * X) -> (a ** size(X)) * prod(X)
1200
1200
1201
+ It also applies to reduction of X / a,
1202
+ but not a / X, as that would still require inverting every value in X before the reduction
1203
+
1201
1204
TODO: In the case where not all axes overlap with broadcast dimensions,
1202
1205
consider introducing an outer reduction after factoring out the compatible reduced dimensions
1203
1206
E.g. sum(arange(5) * X, axis=(0, 2)) -> sum(sum(X, axis=0) * arange(5), axis=1)
1204
1207
"""
1205
- # TODO: if the the thing inside the Sum is a division,
1206
- # we should get at the numerator....
1207
1208
1208
1209
[node_inps ] = node .inputs
1209
- if not (node_inps .owner and node_inps .owner .op == mul ):
1210
+ if not node_inps .owner :
1211
+ return None
1212
+
1213
+ inner_op = node_inps .owner .op
1214
+ if not (inner_op == mul or inner_op == true_div ):
1210
1215
return None
1211
1216
1212
1217
reduced_axes = node .op .axis
1213
1218
if reduced_axes is None :
1214
1219
reduced_axes = tuple (range (node_inps .type .ndim ))
1215
1220
1216
1221
# Separate terms that can be moved out of the Sum/Prod and those that cannot
1217
- outer_terms = []
1218
- inner_terms = []
1219
- for term in node_inps .owner .inputs :
1220
- term_bcast = term .type .broadcastable
1221
- if all (term_bcast [i ] for i in reduced_axes ):
1222
- outer_terms .append (term .squeeze (reduced_axes ))
1223
- else :
1224
- inner_terms .append (term )
1222
+ if inner_op == mul :
1223
+ # Mul accepts arbitrary inputs, so we need to separate into two groups
1224
+ outer_terms = []
1225
+ inner_terms = []
1226
+ for term in node_inps .owner .inputs :
1227
+ term_bcast = term .type .broadcastable
1228
+ if all (term_bcast [i ] for i in reduced_axes ):
1229
+ outer_terms .append (term .squeeze (reduced_axes ))
1230
+ else :
1231
+ inner_terms .append (term )
1225
1232
1226
- if not outer_terms :
1227
- return None
1228
- elif len (outer_terms ) == 1 :
1229
- [outer_term ] = outer_terms
1230
- else :
1231
- outer_term = mul (* outer_terms )
1233
+ if not outer_terms :
1234
+ return None
1235
+ elif len (outer_terms ) == 1 :
1236
+ [outer_term ] = outer_terms
1237
+ else :
1238
+ outer_term = mul (* outer_terms )
1232
1239
1233
- if not inner_terms :
1234
- inner_term = None
1235
- elif len (inner_terms ) == 1 :
1236
- [inner_term ] = inner_terms
1237
- else :
1238
- inner_term = mul (* inner_terms )
1240
+ if not inner_terms :
1241
+ inner_term = None
1242
+ elif len (inner_terms ) == 1 :
1243
+ [inner_term ] = inner_terms
1244
+ else :
1245
+ inner_term = mul (* inner_terms )
1246
+
1247
+ else : # true_div
1248
+ # We only care about moving the denominator out of the reduction
1249
+ numerator , denominator = node_inps .owner .inputs
1250
+ denominator_bcast = denominator .type .broadcastable
1251
+ if all (denominator_bcast [i ] for i in reduced_axes ):
1252
+ outer_term = denominator .squeeze (reduced_axes )
1253
+ inner_term = numerator
1254
+ else :
1255
+ return None
1239
1256
1240
1257
# If we have a `Prod`, then the outside terms need to be raised to the power of the number of elements
1241
1258
# that were contracted in the input
@@ -1246,12 +1263,16 @@ def local_sum_prod_of_mul(fgraph, node):
1246
1263
)
1247
1264
outer_term = outer_term ** n_reduced_elements
1248
1265
1249
- # Sum/Prod is useless, just return the outer_term
1250
1266
if not inner_term :
1267
+ # Sum/Prod is useless, just return the outer_term
1268
+ # (This can only happen for mul, not division)
1251
1269
new_out = outer_term
1252
1270
else :
1253
1271
reduced_inner_term = node .op (inner_term )
1254
- new_out = outer_term * reduced_inner_term
1272
+ if inner_op == mul :
1273
+ new_out = outer_term * reduced_inner_term
1274
+ else :
1275
+ new_out = reduced_inner_term / outer_term
1255
1276
copy_stack_trace (node .outputs , [inner_term , reduced_inner_term , outer_term ])
1256
1277
1257
1278
copy_stack_trace (node .outputs , new_out )
@@ -1510,99 +1531,6 @@ def investigate(node):
1510
1531
return
1511
1532
1512
1533
1513
@register_canonicalize
@register_specialize
@node_rewriter([Sum, Prod])
def local_sum_prod_div_dimshuffle(fgraph, node):
    """Pull a broadcasted denominator out of a `Sum`/`Prod` of a division.

    Rewrites

        sum(a / dimshuffle{...}(b), axis=l) -> sum(a, axis={...}) / b

    and

        prod(a / dimshuffle{...}(b), axis=l) -> prod(a, axis={...}) / b ** a.shape[l]

    whenever the `DimShuffle` in the denominator introduces an ``'x'``
    (broadcast) dimension at a reduced axis ``l``.

    Reduced axes that do not line up with an ``'x'`` entry are handled by a
    second, outer reduction applied after the division.

    Only a `DimShuffle` in the *denominator* is considered: with the
    `DimShuffle` in the numerator, the denominator would still have to be
    inverted elementwise before reducing, so nothing would be gained.

    Returns a one-element list with the replacement output, or ``None`` when
    the pattern does not match.
    """
    reduction = node.op
    if not isinstance(reduction, (Sum, Prod)):
        return None

    reduced_input = node.inputs[0]
    axes = reduction.axis
    if axes is None:
        # No explicit axis means every dimension of the input is reduced
        axes = list(range(reduced_input.ndim))

    if not (reduced_input.owner and reduced_input.owner.op == true_div):
        return None
    numerator, denominator = reduced_input.owner.inputs

    if not (denominator.owner and isinstance(denominator.owner.op, DimShuffle)):
        return None
    base_denominator = denominator.owner.inputs[0]
    order = denominator.owner.op.new_order

    def lands_on_broadcast(ax):
        # True when the denominator's DimShuffle puts an 'x' at this axis
        return ax < len(order) and order[ax] == "x"

    # Axes over which the denominator is constant, and the leftover ones
    factorable_axes = [ax for ax in axes if lands_on_broadcast(ax)]
    if not factorable_axes:
        return None
    leftover_axes = [ax for ax in axes if not lands_on_broadcast(ax)]

    # Once the factorable axes are reduced away, every leftover axis shifts
    # left by the number of factorable axes that preceded it
    shifted_leftover_axes = [
        ax - sum(f_ax < ax for f_ax in factorable_axes) for ax in leftover_axes
    ]

    # Drop from the DimShuffle pattern the 'x' entries at reduced positions
    trimmed_order = [
        entry
        for position, entry in enumerate(order)
        if position not in axes or entry != "x"
    ]

    # Strip leading 'x' entries (the original author notes that these are
    # added back automatically)
    first_kept = next(
        (k for k, entry in enumerate(trimmed_order) if entry != "x"),
        len(trimmed_order),
    )
    trimmed_order = trimmed_order[first_kept:]

    if trimmed_order == list(range(len(trimmed_order))):
        # Pattern is 0, 1, 2, ... with no 'x' left: the DimShuffle is useless
        outer_denominator = base_denominator
    else:
        outer_denominator = DimShuffle(
            base_denominator.type.broadcastable,
            trimmed_order,
        )(base_denominator)

    if isinstance(reduction, Sum):
        partial_reduction = at_sum(numerator, axis=factorable_axes)
        rval = true_div(partial_reduction, outer_denominator)
        if shifted_leftover_axes:
            rval = at_sum(rval, axis=shifted_leftover_axes)
    elif isinstance(reduction, Prod):
        partial_reduction = prod(numerator, axis=factorable_axes)
        dtype = numerator.dtype
        # The denominator appears once per factored-out element of the product
        n_factored = prod(
            [numerator.shape[ax].astype(dtype) for ax in factorable_axes]
        )
        rval = true_div(partial_reduction, outer_denominator**n_factored)
        if shifted_leftover_axes:
            rval = prod(rval, axis=shifted_leftover_axes)
    return [rval]
1604
-
1605
-
1606
1534
@register_canonicalize
1607
1535
@node_rewriter ([Sum , Prod ])
1608
1536
def local_sum_prod_all_to_none (fgraph , node ):
0 commit comments