Commit c4b816f

Extend local_sum_prod_of_mul rewrite to non-scalar terms
Also:
* Separates the sum of negation rewrite
* Fixes bug in partial prod reduction
1 parent 849c556 commit c4b816f
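
To make the new behavior concrete, here is a minimal sketch of what the extended rewrite targets, assuming a PyTensor build that includes this commit (variable names are illustrative): a multiplication term that is broadcastable along every reduced axis can be hoisted out of the reduction, not just a full scalar as before.

    import numpy as np
    import pytensor
    import pytensor.tensor as pt

    # `col` is broadcastable along axis 1, so for a sum over axis 1 the rewrite
    # can hoist it out: sum(col * mat, axis=1) -> col.squeeze(1) * sum(mat, axis=1)
    col = pt.matrix("col", shape=(None, 1))
    mat = pt.matrix("mat", shape=(None, None))
    out = (col * mat).sum(axis=1)

    # The default mode includes the "specialize" set this rewrite is registered in
    fn = pytensor.function([col, mat], out)

    col_val, mat_val = np.random.random((3, 1)), np.random.random((3, 4))
    np.testing.assert_allclose(
        fn(col_val, mat_val),
        col_val.squeeze(1) * mat_val.sum(axis=1),
    )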

2 files changed: +216 -67 lines changed

pytensor/tensor/rewriting/math.py

Lines changed: 63 additions & 66 deletions
@@ -1188,86 +1188,83 @@ def local_neg_to_mul(fgraph, node):
 
 @register_specialize
 @node_rewriter([Sum, Prod])
-def local_sum_prod_mul_by_scalar(fgraph, node):
+def local_sum_prod_of_mul(fgraph, node):
     """
-    sum(scalar * smth) -> scalar * sum(smth)
-    sum(-smth) -> -sum(smth)
+    sum(a * X) -> a * sum(X), when a is broadcasted along the sum dimensions
 
     or
 
-    prod(scalar * smth) -> scalar ** size(smth) * prod(smth)
-    prod(-smth) -> -1 ** size(smth) * prod(smth)
+    prod(a * X) -> (a ** size(X)) * prod(X)
 
+    Both a and X can be the output of other multiplications
     """
     # TODO: if the the thing inside the Sum is a division,
     # we should get at the numerator....
-    if isinstance(node.op, (Sum, Prod)):
-        (node_inps,) = node.inputs
-        if node_inps.owner and node_inps.owner.op == mul:
-            terms = node_inps.owner.inputs
-            scalars = [t.dimshuffle() for t in terms if all(t.type.broadcastable)]
 
-            if len(scalars) == 0:
-                return
+    [node_inps] = node.inputs
+    if not (node_inps.owner and node_inps.owner.op == mul):
+        return None
 
-            non_scalars = [t for t in terms if not all(t.broadcastable)]
+    reduced_axes = node.op.axis
+    if reduced_axes is None:
+        reduced_axes = tuple(range(node_inps.type.ndim))
+
+    # Separate terms that can be moved out of the Sum/Prod and those that cannot
+    terms_bcasted_along_reduced_axes = []
+    other_terms = []
+    for term in node_inps.owner.inputs:
+        term_bcast = term.type.broadcastable
+        if all(term_bcast[i] for i in reduced_axes):
+            terms_bcasted_along_reduced_axes.append(term.squeeze(reduced_axes))
+        else:
+            other_terms.append(term)
 
-            # Perform the op only on the non-scalar inputs, if applicable
-            if len(non_scalars) == 0:
-                new_op_input_nb_elements = 1
-                new_op_output = 1
-            elif len(non_scalars) == 1:
-                new_op_input_nb_elements = non_scalars[0].size
-                new_op_output = node.op(non_scalars[0])
-            else:
-                new_op_input = mul(*non_scalars)
-                # We assume that errors always come from the prod/mul op in the
-                # original computational graph, and therefore need to only
-                # copy over its output stacktrace.
-                copy_stack_trace(node.outputs, new_op_input)
-
-                new_op_input_nb_elements = new_op_input.size
-                new_op_output = node.op(new_op_input)
-
-            if len(non_scalars) != 0:
-                # Copy over stacktrace from previous output to new mul op,
-                # for same reason as above.
-                copy_stack_trace(node.outputs, new_op_output)
-
-            # If `node.op` is a `Prod`, then the scalars need to be raised to
-            # the power of the number of elements in the input to the `Prod`
-            if isinstance(node.op, Prod) and new_op_input_nb_elements != 1:
-                scalars = [s**new_op_input_nb_elements for s in scalars]
-
-            # Scale the output of the op by the scalars and return as
-            # replacement for the original output
-            mul_inputs = scalars
-            if new_op_input_nb_elements != 1:
-                mul_inputs.append(new_op_output)
-
-            if len(mul_inputs) == 1:
-                # Copy over stacktrace from previous output to new mul op,
-                # for same reason as above.
-                copy_stack_trace(node.outputs, mul_inputs)
-
-                return mul_inputs
-            else:
-                ret = mul(*mul_inputs)
-                # Copy over stacktrace from previous output to new mul op,
-                # for same reason as above.
-                copy_stack_trace(node.outputs, [ret] + mul_inputs)
+    if not terms_bcasted_along_reduced_axes:
+        return None
+    elif len(terms_bcasted_along_reduced_axes) == 1:
+        [outside_term] = terms_bcasted_along_reduced_axes
+    else:
+        outside_term = mul(*terms_bcasted_along_reduced_axes)
 
-                return [ret]
+    if not other_terms:
+        inner_term = None
+    elif len(other_terms) == 1:
+        [inner_term] = other_terms
+    else:
+        inner_term = mul(*other_terms)
 
-    if isinstance(node.op, Sum) and node_inps.owner and node_inps.owner.op == neg:
-        s = node.op(node_inps.owner.inputs[0])
-        ret = neg(s)
-        # There are never errors in the negative op, thus
-        # we need only to copy over stacktrace from previous output node to
-        # the two new ops.
-        copy_stack_trace(node.outputs, [s, ret])
+    # If we have a `Prod`, then the outside terms need to be raised to the power of the number of elements
+    # that were contracted in the input
+    if isinstance(node.op, Prod) and inner_term:
+        n_reduced_elements = prod([inner_term.shape[i] for i in reduced_axes])
+        outside_term = outside_term**n_reduced_elements
 
-        return [ret]
+    # Sum/Prod is useless, just return the outside_term
+    if not inner_term:
+        new_out = outside_term
+    else:
+        reduce_inner_term = node.op(inner_term)
+        new_out = outside_term * reduce_inner_term
+        copy_stack_trace(node.outputs, [inner_term, reduce_inner_term, outside_term])
+
+    copy_stack_trace(node.outputs, new_out)
+    return [new_out]
+
+
+@register_specialize
+@node_rewriter([Sum])
+def local_sum_of_neg_to_neg_of_sum(fgraph, node):
+    """Rewrite sum(-X) -> -sum(X)."""
+    [node_inps] = node.inputs
+    if node_inps.owner and node_inps.owner.op == neg:
+        s = node.op(node_inps.owner.inputs[0])
+        ret = neg(s)
+        # There are never errors in the negative op, thus
+        # we need only to copy over stacktrace from previous output node to
+        # the two new ops.
+        copy_stack_trace(node.outputs, [s, ret])
+
+        return [ret]
 
 
 @register_specialize
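
A note on the Prod branch above: `outside_term` is raised to `n_reduced_elements` because a term that broadcasts along the reduced axes contributes one identical factor per reduced element of the product. A quick NumPy check of the identity (a sketch with illustrative shapes, not part of the commit):

    import numpy as np

    rng = np.random.default_rng(0)
    a = rng.random((1, 3))  # broadcastable along axis 0, the reduced axis
    X = rng.random((4, 3))

    # prod(a * X, axis=0) multiplies a_j once for each of the 4 reduced rows,
    # so the exponent is the number of reduced elements, X.shape[0]
    np.testing.assert_allclose(
        (a * X).prod(axis=0),
        a.squeeze(0) ** X.shape[0] * X.prod(axis=0),
    )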

tests/tensor/rewriting/test_math.py

Lines changed: 153 additions & 1 deletion
@@ -92,6 +92,7 @@
     local_grad_log_erfc_neg,
     local_greedy_distributor,
     local_mul_canonizer,
+    local_sum_prod_of_mul,
     mul_canonizer,
     parse_mul_tree,
     perform_sigm_times_exp,
@@ -2503,7 +2504,7 @@ class TestLocalSumProd:
     def setup_method(self):
         self.mode = get_default_mode().including("canonicalize", "specialize")
 
-    def test_local_sum_prod_mul_by_scalar(self):
+    def test_local_sum_prod_of_scalar_mul(self):
         # Test the rewrite `local_sum_prod_mul_by_scalar` for both Sum and
         # Prod ops in six cases each :
         # 1-the inputs to the mul contain a scalar and no non-scalar
@@ -2653,6 +2654,157 @@ def test_reduction_rewrite(
             axis=(0,),
         )
 
+    def test_sum_of_non_scalar_mul(self):
+        mode = Mode("vm", optimizer="None")
+        rewrite = out2in(local_sum_prod_of_mul)
+
+        row1 = matrix(shape=(1, None), dtype="float64")
+        row2 = matrix(shape=(1, None), dtype="float64")
+        col1 = matrix(shape=(None, 1), dtype="float64")
+        col2 = matrix(shape=(None, 1), dtype="float64")
+        mat1 = matrix(shape=(None, None), dtype="float64")
+        mat2 = matrix(shape=(None, None), dtype="float64")
+
+        inputs = [row1, row2, col1, col2, mat1, mat2]
+        test_vals = [
+            np.random.random((1, 2)),
+            np.random.random((1, 2)),
+            np.random.random((2, 1)),
+            np.random.random((2, 1)),
+            np.random.random((2, 2)),
+            np.random.random((2, 2)),
+        ]
+
+        for out, expected_out in [
+            (
+                mul(row1, row2, mat1, mat2, col1, col2).sum(axis=None),
+                mul(row1, row2, mat1, mat2, col1, col2).sum(axis=None),
+            ),
+            (
+                mul(row1, row2, mat1, mat2, col1, col2).sum(axis=0),
+                mul(row1.squeeze(), row2.squeeze())
+                * mul(mat1, mat2, col1, col2).sum(axis=0),
+            ),
+            (
+                mul(row1, mat1, mat2, col1, col2).sum(axis=0),
+                row1.squeeze() * mul(mat1, mat2, col1, col2).sum(axis=0),
+            ),
+            (
+                mul(row1, row2, mat1, mat2, col1, col2).sum(axis=1),
+                mul(col1.squeeze(), col2.squeeze())
+                * mul(row1, row2, mat1, mat2).sum(axis=1),
+            ),
+            (
+                mul(row1, row2, mat1, mat2, col2).sum(axis=1),
+                col2.squeeze() * mul(row1, row2, mat1, mat2).sum(axis=1),
+            ),
+            (
+                mul(row1, row2).sum(axis=1),
+                mul(row1, row2).sum(axis=1),
+            ),
+            (
+                mul(row1, row2).sum(axis=0),
+                mul(row1.squeeze(), row2.squeeze()),
+            ),
+            (
+                mul(row1, col1).sum(axis=0),
+                row1.squeeze() * col1.sum(axis=0),
+            ),
+        ]:
+            out_fn = pytensor.function(inputs, out, mode=mode, on_unused_input="ignore")
+
+            rewritten_out = rewrite_graph(out, custom_rewrite=rewrite)
+            assert equal_computations([rewritten_out], [expected_out])
+
+            rewritten_out_fn = pytensor.function(
+                inputs, rewritten_out, mode=mode, on_unused_input="ignore"
+            )
+            np.testing.assert_allclose(
+                out_fn(*test_vals),
+                rewritten_out_fn(*test_vals),
+            )
+
+    def test_prod_of_non_scalar_mul(self):
+        mode = Mode("vm", optimizer="None")
+        rewrite = out2in(local_sum_prod_of_mul)
+
+        scl1 = matrix(shape=(1, 1), dtype="float64")
+        row1 = matrix(shape=(1, None), dtype="float64")
+        row2 = matrix(shape=(1, None), dtype="float64")
+        col1 = matrix(shape=(None, 1), dtype="float64")
+        col2 = matrix(shape=(None, 1), dtype="float64")
+        mat1 = matrix(shape=(None, None), dtype="float64")
+        mat2 = matrix(shape=(None, None), dtype="float64")
+
+        inputs = [scl1, row1, row2, col1, col2, mat1, mat2]
+        test_vals = [
+            np.random.random((1, 1)),
+            np.random.random((1, 2)),
+            np.random.random((1, 2)),
+            np.random.random((2, 1)),
+            np.random.random((2, 1)),
+            np.random.random((2, 2)),
+            np.random.random((2, 2)),
+        ]
+
+        for out, expected_out in [
+            (
+                mul(row1, row2, mat1, mat2, col1, col2).prod(axis=None),
+                mul(row1, row2, mat1, mat2, col1, col2).prod(axis=None),
+            ),
+            (
+                mul(row1, row2, mat1, mat2, col1, col2).prod(axis=0),
+                (
+                    mul(row1.squeeze(), row2.squeeze())
+                    ** prod([mul(mat1, mat2, col1, col2).shape[0]])
+                    * mul(mat1, mat2, col1, col2).prod(axis=0)
+                ),
+            ),
+            (
+                mul(row1, mat1, mat2, col1, col2).prod(axis=0),
+                (
+                    row1.squeeze() ** prod([mul(mat1, mat2, col1, col2).shape[0]])
+                    * mul(mat1, mat2, col1, col2).prod(axis=0)
+                ),
+            ),
+            (
+                mul(row1, row2, mat1, mat2, col1, col2).prod(axis=1),
+                (
+                    mul(col1.squeeze(), col2.squeeze())
+                    ** prod([mul(row1, row2, mat1, mat2).shape[1]])
+                    * mul(row1, row2, mat1, mat2).prod(axis=1)
+                ),
+            ),
+            (
+                mul(row1, row2).prod(axis=0),
+                mul(row1.squeeze(), row2.squeeze()),
+            ),
+            (
+                mul(row1, col1).prod(axis=0),
+                (row1.squeeze() ** prod([col1.shape[0]]) * col1.prod(axis=0)),
+            ),
+            (
+                mul(scl1, mat1, row1).prod(axis=None),
+                (
+                    scl1.squeeze()
+                    ** prod([mul(mat1, row1).shape[0], mul(mat1, row1).shape[1]])
+                    * mul(mat1, row1).prod(axis=None)
+                ),
+            ),
+        ]:
+            out_fn = pytensor.function(inputs, out, mode=mode, on_unused_input="ignore")
+
+            rewritten_out = rewrite_graph(out, custom_rewrite=rewrite)
+            assert equal_computations([rewritten_out], [expected_out])
+
+            rewritten_out_fn = pytensor.function(
+                inputs, rewritten_out, mode=mode, on_unused_input="ignore"
+            )
+            np.testing.assert_allclose(
+                out_fn(*test_vals),
+                rewritten_out_fn(*test_vals),
+            )
+
     def test_local_sum_prod_all_to_none(self):
         a = tensor3()
         input = np.arange(3 * 4 * 5, dtype=config.floatX).reshape(3, 4, 5)
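
On the "Fixes bug in partial prod reduction" note in the commit message: as far as the diff shows, the old rewrite raised the hoisted scalars to the total element count of the inner term (`new_op_input.size`), which is only correct when reducing over all axes, while the new code counts only the elements along the reduced axes. A NumPy sketch of the difference (my reading of the diff; shapes are illustrative):

    import numpy as np

    rng = np.random.default_rng(1)
    a = rng.random((1, 3))
    X = rng.random((4, 3))

    # Correct exponent for prod over axis 0: the 4 reduced elements
    correct = a.squeeze(0) ** X.shape[0] * X.prod(axis=0)
    # Exponent matching the old behavior: all 12 elements of X
    wrong = a.squeeze(0) ** X.size * X.prod(axis=0)

    np.testing.assert_allclose((a * X).prod(axis=0), correct)
    assert not np.allclose((a * X).prod(axis=0), wrong)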
