Commit 35d4b60

Extend local_sum_prod_of_mul rewrite to non-scalar terms
Also:
* Separates the sum-of-negation rewrite into its own node rewriter (local_sum_of_neg_to_neg_of_sum)
* Fixes a bug in partial prod reduction
1 parent 865d944 commit 35d4b60

File tree

2 files changed (+221, -67 lines)

pytensor/tensor/rewriting/math.py

Lines changed: 68 additions & 66 deletions
@@ -1189,86 +1189,88 @@ def local_neg_to_mul(fgraph, node):
 
 @register_specialize
 @node_rewriter([Sum, Prod])
-def local_sum_prod_mul_by_scalar(fgraph, node):
+def local_sum_prod_of_mul(fgraph, node):
     """
-    sum(scalar * smth) -> scalar * sum(smth)
-    sum(-smth) -> -sum(smth)
+    sum(a * X) -> a * sum(X), when a is broadcasted along the sum dimensions
 
     or
 
-    prod(scalar * smth) -> scalar ** size(smth) * prod(smth)
-    prod(-smth) -> -1 ** size(smth) * prod(smth)
+    prod(a * X) -> (a ** size(X)) * prod(X)
 
+    TODO: In the case where not all axis overlap with broadcast dimensions,
+    consider introducing an outer reduction after factoring out the compatible reduced dimensions
+    E.g. sum(arange(5) * X, axis=(0, 2)) -> sum(sum(X, axis=0) * arange(5), axis=1)
     """
     # TODO: if the the thing inside the Sum is a division,
     # we should get at the numerator....
-    if isinstance(node.op, (Sum, Prod)):
-        (node_inps,) = node.inputs
-        if node_inps.owner and node_inps.owner.op == mul:
-            terms = node_inps.owner.inputs
-            scalars = [t.dimshuffle() for t in terms if all(t.type.broadcastable)]
 
-            if len(scalars) == 0:
-                return
+    [node_inps] = node.inputs
+    if not (node_inps.owner and node_inps.owner.op == mul):
+        return None
 
-            non_scalars = [t for t in terms if not all(t.broadcastable)]
+    reduced_axes = node.op.axis
+    if reduced_axes is None:
+        reduced_axes = tuple(range(node_inps.type.ndim))
+
+    # Separate terms that can be moved out of the Sum/Prod and those that cannot
+    outer_terms = []
+    inner_terms = []
+    for term in node_inps.owner.inputs:
+        term_bcast = term.type.broadcastable
+        if all(term_bcast[i] for i in reduced_axes):
+            outer_terms.append(term.squeeze(reduced_axes))
+        else:
+            inner_terms.append(term)
 
-            # Perform the op only on the non-scalar inputs, if applicable
-            if len(non_scalars) == 0:
-                new_op_input_nb_elements = 1
-                new_op_output = 1
-            elif len(non_scalars) == 1:
-                new_op_input_nb_elements = non_scalars[0].size
-                new_op_output = node.op(non_scalars[0])
-            else:
-                new_op_input = mul(*non_scalars)
-                # We assume that errors always come from the prod/mul op in the
-                # original computational graph, and therefore need to only
-                # copy over its output stacktrace.
-                copy_stack_trace(node.outputs, new_op_input)
-
-                new_op_input_nb_elements = new_op_input.size
-                new_op_output = node.op(new_op_input)
-
-            if len(non_scalars) != 0:
-                # Copy over stacktrace from previous output to new mul op,
-                # for same reason as above.
-                copy_stack_trace(node.outputs, new_op_output)
-
-            # If `node.op` is a `Prod`, then the scalars need to be raised to
-            # the power of the number of elements in the input to the `Prod`
-            if isinstance(node.op, Prod) and new_op_input_nb_elements != 1:
-                scalars = [s**new_op_input_nb_elements for s in scalars]
-
-            # Scale the output of the op by the scalars and return as
-            # replacement for the original output
-            mul_inputs = scalars
-            if new_op_input_nb_elements != 1:
-                mul_inputs.append(new_op_output)
-
-            if len(mul_inputs) == 1:
-                # Copy over stacktrace from previous output to new mul op,
-                # for same reason as above.
-                copy_stack_trace(node.outputs, mul_inputs)
-
-                return mul_inputs
-            else:
-                ret = mul(*mul_inputs)
-                # Copy over stacktrace from previous output to new mul op,
-                # for same reason as above.
-                copy_stack_trace(node.outputs, [ret] + mul_inputs)
+    if not outer_terms:
+        return None
+    elif len(outer_terms) == 1:
+        [outer_term] = outer_terms
+    else:
+        outer_term = mul(*outer_terms)
 
-                return [ret]
+    if not inner_terms:
+        inner_term = None
+    elif len(inner_terms) == 1:
+        [inner_term] = inner_terms
+    else:
+        inner_term = mul(*inner_terms)
+
+    # If we have a `Prod`, then the outside terms need to be raised to the power of the number of elements
+    # that were contracted in the input
+    if isinstance(node.op, Prod) and inner_term:
+        dtype = inner_term.dtype
+        n_reduced_elements = prod(
+            [inner_term.shape[i].astype(dtype) for i in reduced_axes]
+        )
+        outer_term = outer_term**n_reduced_elements
 
-    if isinstance(node.op, Sum) and node_inps.owner and node_inps.owner.op == neg:
-        s = node.op(node_inps.owner.inputs[0])
-        ret = neg(s)
-        # There are never errors in the negative op, thus
-        # we need only to copy over stacktrace from previous output node to
-        # the two new ops.
-        copy_stack_trace(node.outputs, [s, ret])
+    # Sum/Prod is useless, just return the outer_term
+    if not inner_term:
+        new_out = outer_term
+    else:
+        reduced_inner_term = node.op(inner_term)
+        new_out = outer_term * reduced_inner_term
+        copy_stack_trace(node.outputs, [inner_term, reduced_inner_term, outer_term])
 
-        return [ret]
+    copy_stack_trace(node.outputs, new_out)
+    return [new_out]
+
+
+@register_specialize
+@node_rewriter([Sum])
+def local_sum_of_neg_to_neg_of_sum(fgraph, node):
+    """Rewrite sum(-X) -> -sum(X)."""
+    [node_inps] = node.inputs
+    if node_inps.owner and node_inps.owner.op == neg:
+        s = node.op(node_inps.owner.inputs[0])
+        ret = neg(s)
+        # There are never errors in the negative op, thus
+        # we need only to copy over stacktrace from previous output node to
+        # the two new ops.
+        copy_stack_trace(node.outputs, [s, ret])
+
+        return [ret]
 
 
@register_specialize
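
Note on the identities involved: the factoring performed by local_sum_prod_of_mul above can be checked with plain NumPy. The sketch below is illustrative only (not part of the commit, variable names are arbitrary); it also shows why, for prod, the factored-out term must be raised to the number of elements reduced along the chosen axes (what the new n_reduced_elements computes) rather than the total size of the input.

import numpy as np

rng = np.random.default_rng(0)
a = rng.random((1, 4))  # broadcast along axis 0, the reduced axis
X = rng.random((3, 4))

# sum(a * X, axis=0) == a * sum(X, axis=0)
np.testing.assert_allclose(
    (a * X).sum(axis=0),
    a.squeeze(0) * X.sum(axis=0),
)

# prod(a * X, axis=0) == a ** n_reduced * prod(X, axis=0),
# where n_reduced is the number of elements reduced along axis 0 (here 3),
# not the total size of X (12).
n_reduced = X.shape[0]
np.testing.assert_allclose(
    (a * X).prod(axis=0),
    a.squeeze(0) ** n_reduced * X.prod(axis=0),
)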

tests/tensor/rewriting/test_math.py

Lines changed: 153 additions & 1 deletion
@@ -92,6 +92,7 @@
     local_grad_log_erfc_neg,
     local_greedy_distributor,
     local_mul_canonizer,
+    local_sum_prod_of_mul,
     mul_canonizer,
     parse_mul_tree,
     perform_sigm_times_exp,

@@ -2503,7 +2504,7 @@ class TestLocalSumProd:
     def setup_method(self):
         self.mode = get_default_mode().including("canonicalize", "specialize")
 
-    def test_local_sum_prod_mul_by_scalar(self):
+    def test_local_sum_prod_of_scalar_mul(self):
         # Test the rewrite `local_sum_prod_mul_by_scalar` for both Sum and
         # Prod ops in six cases each :
         # 1-the inputs to the mul contain a scalar and no non-scalar

@@ -2653,6 +2654,157 @@ def test_reduction_rewrite(
             axis=(0,),
         )
 
+    def test_sum_of_non_scalar_mul(self):
+        mode = Mode("vm", optimizer="None")
+        rewrite = out2in(local_sum_prod_of_mul)
+
+        row1 = matrix(shape=(1, None), dtype="float64")
+        row2 = matrix(shape=(1, None), dtype="float64")
+        col1 = matrix(shape=(None, 1), dtype="float64")
+        col2 = matrix(shape=(None, 1), dtype="float64")
+        mat1 = matrix(shape=(None, None), dtype="float64")
+        mat2 = matrix(shape=(None, None), dtype="float64")
+
+        inputs = [row1, row2, col1, col2, mat1, mat2]
+        test_vals = [
+            np.random.random((1, 2)),
+            np.random.random((1, 2)),
+            np.random.random((2, 1)),
+            np.random.random((2, 1)),
+            np.random.random((2, 2)),
+            np.random.random((2, 2)),
+        ]
+
+        for out, expected_out in [
+            (
+                mul(row1, row2, mat1, mat2, col1, col2).sum(axis=None),
+                mul(row1, row2, mat1, mat2, col1, col2).sum(axis=None),
+            ),
+            (
+                mul(row1, row2, mat1, mat2, col1, col2).sum(axis=0),
+                mul(row1.squeeze(), row2.squeeze())
+                * mul(mat1, mat2, col1, col2).sum(axis=0),
+            ),
+            (
+                mul(row1, mat1, mat2, col1, col2).sum(axis=0),
+                row1.squeeze() * mul(mat1, mat2, col1, col2).sum(axis=0),
+            ),
+            (
+                mul(row1, row2, mat1, mat2, col1, col2).sum(axis=1),
+                mul(col1.squeeze(), col2.squeeze())
+                * mul(row1, row2, mat1, mat2).sum(axis=1),
+            ),
+            (
+                mul(row1, row2, mat1, mat2, col2).sum(axis=1),
+                col2.squeeze() * mul(row1, row2, mat1, mat2).sum(axis=1),
+            ),
+            (
+                mul(row1, row2).sum(axis=1),
+                mul(row1, row2).sum(axis=1),
+            ),
+            (
+                mul(row1, row2).sum(axis=0),
+                mul(row1.squeeze(), row2.squeeze()),
+            ),
+            (
+                mul(row1, col1).sum(axis=0),
+                row1.squeeze() * col1.sum(axis=0),
+            ),
+        ]:
+            out_fn = pytensor.function(inputs, out, mode=mode, on_unused_input="ignore")
+
+            rewritten_out = rewrite_graph(out, custom_rewrite=rewrite)
+            assert equal_computations([rewritten_out], [expected_out])
+
+            rewritten_out_fn = pytensor.function(
+                inputs, rewritten_out, mode=mode, on_unused_input="ignore"
+            )
+            np.testing.assert_allclose(
+                out_fn(*test_vals),
+                rewritten_out_fn(*test_vals),
+            )
+
+    def test_prod_of_non_scalar_mul(self):
+        mode = Mode("vm", optimizer="None")
+        rewrite = out2in(local_sum_prod_of_mul)
+
+        scl1 = matrix(shape=(1, 1), dtype="float64")
+        row1 = matrix(shape=(1, None), dtype="float64")
+        row2 = matrix(shape=(1, None), dtype="float64")
+        col1 = matrix(shape=(None, 1), dtype="float64")
+        col2 = matrix(shape=(None, 1), dtype="float64")
+        mat1 = matrix(shape=(None, None), dtype="float64")
+        mat2 = matrix(shape=(None, None), dtype="float64")
+
+        inputs = [scl1, row1, row2, col1, col2, mat1, mat2]
+        test_vals = [
+            np.random.random((1, 1)),
+            np.random.random((1, 2)),
+            np.random.random((1, 2)),
+            np.random.random((2, 1)),
+            np.random.random((2, 1)),
+            np.random.random((2, 2)),
+            np.random.random((2, 2)),
+        ]
+
+        for out, expected_out in [
+            (
+                mul(row1, row2, mat1, mat2, col1, col2).prod(axis=None),
+                mul(row1, row2, mat1, mat2, col1, col2).prod(axis=None),
+            ),
+            (
+                mul(row1, row2, mat1, mat2, col1, col2).prod(axis=0),
+                (
+                    mul(row1.squeeze(), row2.squeeze())
+                    ** prod([mul(mat1, mat2, col1, col2).shape[0]])
+                    * mul(mat1, mat2, col1, col2).prod(axis=0)
+                ),
+            ),
+            (
+                mul(row1, mat1, mat2, col1, col2).prod(axis=0),
+                (
+                    row1.squeeze() ** prod([mul(mat1, mat2, col1, col2).shape[0]])
+                    * mul(mat1, mat2, col1, col2).prod(axis=0)
+                ),
+            ),
+            (
+                mul(row1, row2, mat1, mat2, col1, col2).prod(axis=1),
+                (
+                    mul(col1.squeeze(), col2.squeeze())
+                    ** prod([mul(row1, row2, mat1, mat2).shape[1]])
+                    * mul(row1, row2, mat1, mat2).prod(axis=1)
+                ),
+            ),
+            (
+                mul(row1, row2).prod(axis=0),
+                mul(row1.squeeze(), row2.squeeze()),
+            ),
+            (
+                mul(row1, col1).prod(axis=0),
+                (row1.squeeze() ** prod([col1.shape[0]]) * col1.prod(axis=0)),
+            ),
+            (
+                mul(scl1, mat1, row1).prod(axis=None),
+                (
+                    scl1.squeeze()
+                    ** prod([mul(mat1, row1).shape[0], mul(mat1, row1).shape[1]])
+                    * mul(mat1, row1).prod(axis=None)
+                ),
+            ),
+        ]:
+            out_fn = pytensor.function(inputs, out, mode=mode, on_unused_input="ignore")
+
+            rewritten_out = rewrite_graph(out, custom_rewrite=rewrite)
+            assert equal_computations([rewritten_out], [expected_out])
+
+            rewritten_out_fn = pytensor.function(
+                inputs, rewritten_out, mode=mode, on_unused_input="ignore"
+            )
+            np.testing.assert_allclose(
+                out_fn(*test_vals),
+                rewritten_out_fn(*test_vals),
+            )
+
     def test_local_sum_prod_all_to_none(self):
         a = tensor3()
input = np.arange(3 * 4 * 5, dtype=config.floatX).reshape(3, 4, 5)
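
Usage-level note (not part of the commit): because the rewrite is registered with @register_specialize, it should also fire during ordinary function compilation, without applying it by hand via out2in/rewrite_graph as the tests above do. A minimal sketch, assuming the same pytensor API used in the tests:

import numpy as np
import pytensor
import pytensor.tensor as pt

row = pt.matrix("row", shape=(1, None), dtype="float64")
mat = pt.matrix("mat", shape=(None, None), dtype="float64")

# `row` is broadcast along axis 0, the reduced axis, so local_sum_prod_of_mul
# can factor it out as row.squeeze() * mat.sum(axis=0) during specialization.
out = (row * mat).sum(axis=0)
fn = pytensor.function([row, mat], out)

row_val = np.random.random((1, 4))
mat_val = np.random.random((3, 4))

# Whether or not the rewrite fires, the compiled result matches the
# hand-factored expression.
np.testing.assert_allclose(
    fn(row_val, mat_val),
    row_val.squeeze(0) * mat_val.sum(axis=0),
)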
