From e47037b80ad22d3eab5e8eeb93a2452832c188ed Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Tue, 7 Nov 2023 18:26:50 +0100 Subject: [PATCH 1/3] Refactor sum_prod_mul rewrite test and add failing case Rewrite from prod of mul was not correct when only some axes were reduced by prod --- tests/tensor/rewriting/test_math.py | 30 +++++++++++++++++++++++++++-- 1 file changed, 28 insertions(+), 2 deletions(-) diff --git a/tests/tensor/rewriting/test_math.py b/tests/tensor/rewriting/test_math.py index a37a161d62..440975933f 100644 --- a/tests/tensor/rewriting/test_math.py +++ b/tests/tensor/rewriting/test_math.py @@ -2512,6 +2512,7 @@ def test_local_sum_prod_mul_by_scalar(self): # 4-the inputs to the mul contain two scalars and no non-scalar # 5-the inputs to the mul contain two scalars and one non-scalar # 6-the inputs to the mul contain two scalars and two non-scalars + # 7-the reduction happens across only the first of two axes vect = dvector() mat = dmatrix() @@ -2524,10 +2525,15 @@ def test_local_sum_prod_mul_by_scalar(self): s2_val = np.random.random() def test_reduction_rewrite( - inputs, inputs_val, reduction_op, expected_output, nb_expected_sum_nodes + inputs, + inputs_val, + reduction_op, + expected_output, + nb_expected_sum_nodes, + axis=None, ): mul_out = mul(*inputs) - f = function(inputs, reduction_op()(mul_out), mode=self.mode) + f = function(inputs, reduction_op(axis=axis)(mul_out), mode=self.mode) out = f(*inputs_val) utt.assert_allclose(out, expected_output) @@ -2581,6 +2587,16 @@ def test_reduction_rewrite( 1, ) + # Case 7 + test_reduction_rewrite( + [mat, scalar1, scalar2], + [m_val, s1_val, s2_val], + Sum, + (s1_val * s2_val * m_val).sum(0), + 1, + axis=(0,), + ) + # Test prod # Case 1 @@ -2627,6 +2643,16 @@ def test_reduction_rewrite( 2, ) + # Case 7 + test_reduction_rewrite( + [mat, scalar1, scalar2], + [m_val, s1_val, s2_val], + Prod, + (s1_val * s2_val * m_val).prod(0), + 1, + axis=(0,), + ) + def test_local_sum_prod_all_to_none(self): a = tensor3() input = np.arange(3 * 4 * 5, dtype=config.floatX).reshape(3, 4, 5) From efefa70503e03f67254a9682dd7281747af3a593 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Tue, 7 Nov 2023 18:26:56 +0100 Subject: [PATCH 2/3] Extend local_sum_prod_of_mul rewrite to non-scalar terms Also: * Separates the sum of negation rewrite * Fixes bug in partial prod reduction --- pytensor/tensor/rewriting/math.py | 134 ++++++++++++------------ tests/tensor/rewriting/test_math.py | 154 +++++++++++++++++++++++++++- 2 files changed, 221 insertions(+), 67 deletions(-) diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py index a814ffdf69..0d6a9552df 100644 --- a/pytensor/tensor/rewriting/math.py +++ b/pytensor/tensor/rewriting/math.py @@ -1190,86 +1190,88 @@ def local_neg_to_mul(fgraph, node): @register_specialize @node_rewriter([Sum, Prod]) -def local_sum_prod_mul_by_scalar(fgraph, node): +def local_sum_prod_of_mul(fgraph, node): """ - sum(scalar * smth) -> scalar * sum(smth) - sum(-smth) -> -sum(smth) + sum(a * X) -> a * sum(X), when a is broadcasted along the sum dimensions or - prod(scalar * smth) -> scalar ** size(smth) * prod(smth) - prod(-smth) -> -1 ** size(smth) * prod(smth) + prod(a * X) -> (a ** size(X)) * prod(X) + TODO: In the case where not all axis overlap with broadcast dimensions, + consider introducing an outer reduction after factoring out the compatible reduced dimensions + E.g. 
sum(arange(5) * X, axis=(0, 2)) -> sum(sum(X, axis=0) * arange(5), axis=1) """ # TODO: if the the thing inside the Sum is a division, # we should get at the numerator.... - if isinstance(node.op, (Sum, Prod)): - (node_inps,) = node.inputs - if node_inps.owner and node_inps.owner.op == mul: - terms = node_inps.owner.inputs - scalars = [t.dimshuffle() for t in terms if all(t.type.broadcastable)] - if len(scalars) == 0: - return + [node_inps] = node.inputs + if not (node_inps.owner and node_inps.owner.op == mul): + return None - non_scalars = [t for t in terms if not all(t.broadcastable)] + reduced_axes = node.op.axis + if reduced_axes is None: + reduced_axes = tuple(range(node_inps.type.ndim)) + + # Separate terms that can be moved out of the Sum/Prod and those that cannot + outer_terms = [] + inner_terms = [] + for term in node_inps.owner.inputs: + term_bcast = term.type.broadcastable + if all(term_bcast[i] for i in reduced_axes): + outer_terms.append(term.squeeze(reduced_axes)) + else: + inner_terms.append(term) - # Perform the op only on the non-scalar inputs, if applicable - if len(non_scalars) == 0: - new_op_input_nb_elements = 1 - new_op_output = 1 - elif len(non_scalars) == 1: - new_op_input_nb_elements = non_scalars[0].size - new_op_output = node.op(non_scalars[0]) - else: - new_op_input = mul(*non_scalars) - # We assume that errors always come from the prod/mul op in the - # original computational graph, and therefore need to only - # copy over its output stacktrace. - copy_stack_trace(node.outputs, new_op_input) - - new_op_input_nb_elements = new_op_input.size - new_op_output = node.op(new_op_input) - - if len(non_scalars) != 0: - # Copy over stacktrace from previous output to new mul op, - # for same reason as above. - copy_stack_trace(node.outputs, new_op_output) - - # If `node.op` is a `Prod`, then the scalars need to be raised to - # the power of the number of elements in the input to the `Prod` - if isinstance(node.op, Prod) and new_op_input_nb_elements != 1: - scalars = [s**new_op_input_nb_elements for s in scalars] - - # Scale the output of the op by the scalars and return as - # replacement for the original output - mul_inputs = scalars - if new_op_input_nb_elements != 1: - mul_inputs.append(new_op_output) - - if len(mul_inputs) == 1: - # Copy over stacktrace from previous output to new mul op, - # for same reason as above. - copy_stack_trace(node.outputs, mul_inputs) - - return mul_inputs - else: - ret = mul(*mul_inputs) - # Copy over stacktrace from previous output to new mul op, - # for same reason as above. 
- copy_stack_trace(node.outputs, [ret] + mul_inputs) + if not outer_terms: + return None + elif len(outer_terms) == 1: + [outer_term] = outer_terms + else: + outer_term = mul(*outer_terms) - return [ret] + if not inner_terms: + inner_term = None + elif len(inner_terms) == 1: + [inner_term] = inner_terms + else: + inner_term = mul(*inner_terms) + + # If we have a `Prod`, then the outside terms need to be raised to the power of the number of elements + # that were contracted in the input + if isinstance(node.op, Prod) and inner_term: + dtype = inner_term.dtype + n_reduced_elements = prod( + [inner_term.shape[i].astype(dtype) for i in reduced_axes] + ) + outer_term = outer_term**n_reduced_elements - if isinstance(node.op, Sum) and node_inps.owner and node_inps.owner.op == neg: - s = node.op(node_inps.owner.inputs[0]) - ret = neg(s) - # There are never errors in the negative op, thus - # we need only to copy over stacktrace from previous output node to - # the two new ops. - copy_stack_trace(node.outputs, [s, ret]) + # Sum/Prod is useless, just return the outer_term + if not inner_term: + new_out = outer_term + else: + reduced_inner_term = node.op(inner_term) + new_out = outer_term * reduced_inner_term + copy_stack_trace(node.outputs, [inner_term, reduced_inner_term, outer_term]) - return [ret] + copy_stack_trace(node.outputs, new_out) + return [new_out] + + +@register_specialize +@node_rewriter([Sum]) +def local_sum_of_neg_to_neg_of_sum(fgraph, node): + """Rewrite sum(-X) -> -sum(X).""" + [node_inps] = node.inputs + if node_inps.owner and node_inps.owner.op == neg: + s = node.op(node_inps.owner.inputs[0]) + ret = neg(s) + # There are never errors in the negative op, thus + # we need only to copy over stacktrace from previous output node to + # the two new ops. 
+ copy_stack_trace(node.outputs, [s, ret]) + + return [ret] @register_specialize diff --git a/tests/tensor/rewriting/test_math.py b/tests/tensor/rewriting/test_math.py index 440975933f..adcea2aa68 100644 --- a/tests/tensor/rewriting/test_math.py +++ b/tests/tensor/rewriting/test_math.py @@ -92,6 +92,7 @@ local_grad_log_erfc_neg, local_greedy_distributor, local_mul_canonizer, + local_sum_prod_of_mul, mul_canonizer, parse_mul_tree, perform_sigm_times_exp, @@ -2503,7 +2504,7 @@ class TestLocalSumProd: def setup_method(self): self.mode = get_default_mode().including("canonicalize", "specialize") - def test_local_sum_prod_mul_by_scalar(self): + def test_local_sum_prod_of_scalar_mul(self): # Test the rewrite `local_sum_prod_mul_by_scalar` for both Sum and # Prod ops in six cases each : # 1-the inputs to the mul contain a scalar and no non-scalar @@ -2653,6 +2654,157 @@ def test_reduction_rewrite( axis=(0,), ) + def test_sum_of_non_scalar_mul(self): + mode = Mode("vm", optimizer="None") + rewrite = out2in(local_sum_prod_of_mul) + + row1 = matrix(shape=(1, None), dtype="float64") + row2 = matrix(shape=(1, None), dtype="float64") + col1 = matrix(shape=(None, 1), dtype="float64") + col2 = matrix(shape=(None, 1), dtype="float64") + mat1 = matrix(shape=(None, None), dtype="float64") + mat2 = matrix(shape=(None, None), dtype="float64") + + inputs = [row1, row2, col1, col2, mat1, mat2] + test_vals = [ + np.random.random((1, 2)), + np.random.random((1, 2)), + np.random.random((2, 1)), + np.random.random((2, 1)), + np.random.random((2, 2)), + np.random.random((2, 2)), + ] + + for out, expected_out in [ + ( + mul(row1, row2, mat1, mat2, col1, col2).sum(axis=None), + mul(row1, row2, mat1, mat2, col1, col2).sum(axis=None), + ), + ( + mul(row1, row2, mat1, mat2, col1, col2).sum(axis=0), + mul(row1.squeeze(), row2.squeeze()) + * mul(mat1, mat2, col1, col2).sum(axis=0), + ), + ( + mul(row1, mat1, mat2, col1, col2).sum(axis=0), + row1.squeeze() * mul(mat1, mat2, col1, col2).sum(axis=0), + ), + ( + mul(row1, row2, mat1, mat2, col1, col2).sum(axis=1), + mul(col1.squeeze(), col2.squeeze()) + * mul(row1, row2, mat1, mat2).sum(axis=1), + ), + ( + mul(row1, row2, mat1, mat2, col2).sum(axis=1), + col2.squeeze() * mul(row1, row2, mat1, mat2).sum(axis=1), + ), + ( + mul(row1, row2).sum(axis=1), + mul(row1, row2).sum(axis=1), + ), + ( + mul(row1, row2).sum(axis=0), + mul(row1.squeeze(), row2.squeeze()), + ), + ( + mul(row1, col1).sum(axis=0), + row1.squeeze() * col1.sum(axis=0), + ), + ]: + out_fn = pytensor.function(inputs, out, mode=mode, on_unused_input="ignore") + + rewritten_out = rewrite_graph(out, custom_rewrite=rewrite) + assert equal_computations([rewritten_out], [expected_out]) + + rewritten_out_fn = pytensor.function( + inputs, rewritten_out, mode=mode, on_unused_input="ignore" + ) + np.testing.assert_allclose( + out_fn(*test_vals), + rewritten_out_fn(*test_vals), + ) + + def test_prod_of_non_scalar_mul(self): + mode = Mode("vm", optimizer="None") + rewrite = out2in(local_sum_prod_of_mul) + + scl1 = matrix(shape=(1, 1), dtype="float64") + row1 = matrix(shape=(1, None), dtype="float64") + row2 = matrix(shape=(1, None), dtype="float64") + col1 = matrix(shape=(None, 1), dtype="float64") + col2 = matrix(shape=(None, 1), dtype="float64") + mat1 = matrix(shape=(None, None), dtype="float64") + mat2 = matrix(shape=(None, None), dtype="float64") + + inputs = [scl1, row1, row2, col1, col2, mat1, mat2] + test_vals = [ + np.random.random((1, 1)), + np.random.random((1, 2)), + np.random.random((1, 2)), + 
np.random.random((2, 1)), + np.random.random((2, 1)), + np.random.random((2, 2)), + np.random.random((2, 2)), + ] + + for out, expected_out in [ + ( + mul(row1, row2, mat1, mat2, col1, col2).prod(axis=None), + mul(row1, row2, mat1, mat2, col1, col2).prod(axis=None), + ), + ( + mul(row1, row2, mat1, mat2, col1, col2).prod(axis=0), + ( + mul(row1.squeeze(), row2.squeeze()) + ** prod([mul(mat1, mat2, col1, col2).shape[0]]) + * mul(mat1, mat2, col1, col2).prod(axis=0) + ), + ), + ( + mul(row1, mat1, mat2, col1, col2).prod(axis=0), + ( + row1.squeeze() ** prod([mul(mat1, mat2, col1, col2).shape[0]]) + * mul(mat1, mat2, col1, col2).prod(axis=0) + ), + ), + ( + mul(row1, row2, mat1, mat2, col1, col2).prod(axis=1), + ( + mul(col1.squeeze(), col2.squeeze()) + ** prod([mul(row1, row2, mat1, mat2).shape[1]]) + * mul(row1, row2, mat1, mat2).prod(axis=1) + ), + ), + ( + mul(row1, row2).prod(axis=0), + mul(row1.squeeze(), row2.squeeze()), + ), + ( + mul(row1, col1).prod(axis=0), + (row1.squeeze() ** prod([col1.shape[0]]) * col1.prod(axis=0)), + ), + ( + mul(scl1, mat1, row1).prod(axis=None), + ( + scl1.squeeze() + ** prod([mul(mat1, row1).shape[0], mul(mat1, row1).shape[1]]) + * mul(mat1, row1).prod(axis=None) + ), + ), + ]: + out_fn = pytensor.function(inputs, out, mode=mode, on_unused_input="ignore") + + rewritten_out = rewrite_graph(out, custom_rewrite=rewrite) + assert equal_computations([rewritten_out], [expected_out]) + + rewritten_out_fn = pytensor.function( + inputs, rewritten_out, mode=mode, on_unused_input="ignore" + ) + np.testing.assert_allclose( + out_fn(*test_vals), + rewritten_out_fn(*test_vals), + ) + def test_local_sum_prod_all_to_none(self): a = tensor3() input = np.arange(3 * 4 * 5, dtype=config.floatX).reshape(3, 4, 5) From 152a551dfc1272ed39c4e50e66f6a91b77797bf4 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Mon, 20 Nov 2023 15:20:33 +0100 Subject: [PATCH 3/3] Merge rewrite for sum/prod of div with that of mul --- pytensor/tensor/rewriting/math.py | 166 ++++--------- tests/tensor/rewriting/test_elemwise.py | 2 +- tests/tensor/rewriting/test_math.py | 309 ++++++++++++------------ 3 files changed, 202 insertions(+), 275 deletions(-) diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py index 0d6a9552df..67dc8eedeb 100644 --- a/pytensor/tensor/rewriting/math.py +++ b/pytensor/tensor/rewriting/math.py @@ -1190,7 +1190,7 @@ def local_neg_to_mul(fgraph, node): @register_specialize @node_rewriter([Sum, Prod]) -def local_sum_prod_of_mul(fgraph, node): +def local_sum_prod_of_mul_or_div(fgraph, node): """ sum(a * X) -> a * sum(X), when a is broadcasted along the sum dimensions @@ -1198,15 +1198,20 @@ def local_sum_prod_of_mul(fgraph, node): prod(a * X) -> (a ** size(X)) * prod(X) + It also applies to reduction of X / a, + but not a / X, as that would still require inverting every value in X before the reduction + TODO: In the case where not all axis overlap with broadcast dimensions, consider introducing an outer reduction after factoring out the compatible reduced dimensions E.g. sum(arange(5) * X, axis=(0, 2)) -> sum(sum(X, axis=0) * arange(5), axis=1) """ - # TODO: if the the thing inside the Sum is a division, - # we should get at the numerator.... 
[node_inps] = node.inputs - if not (node_inps.owner and node_inps.owner.op == mul): + if not node_inps.owner: + return None + + inner_op = node_inps.owner.op + if not (inner_op == mul or inner_op == true_div): return None reduced_axes = node.op.axis @@ -1214,28 +1219,40 @@ def local_sum_prod_of_mul(fgraph, node): reduced_axes = tuple(range(node_inps.type.ndim)) # Separate terms that can be moved out of the Sum/Prod and those that cannot - outer_terms = [] - inner_terms = [] - for term in node_inps.owner.inputs: - term_bcast = term.type.broadcastable - if all(term_bcast[i] for i in reduced_axes): - outer_terms.append(term.squeeze(reduced_axes)) - else: - inner_terms.append(term) + if inner_op == mul: + # Mul accepts arbitrary inputs, so we need to separate into two groups + outer_terms = [] + inner_terms = [] + for term in node_inps.owner.inputs: + term_bcast = term.type.broadcastable + if all(term_bcast[i] for i in reduced_axes): + outer_terms.append(term.squeeze(reduced_axes)) + else: + inner_terms.append(term) - if not outer_terms: - return None - elif len(outer_terms) == 1: - [outer_term] = outer_terms - else: - outer_term = mul(*outer_terms) + if not outer_terms: + return None + elif len(outer_terms) == 1: + [outer_term] = outer_terms + else: + outer_term = mul(*outer_terms) - if not inner_terms: - inner_term = None - elif len(inner_terms) == 1: - [inner_term] = inner_terms - else: - inner_term = mul(*inner_terms) + if not inner_terms: + inner_term = None + elif len(inner_terms) == 1: + [inner_term] = inner_terms + else: + inner_term = mul(*inner_terms) + + else: # true_div + # We only care about removing the denominator out of the reduction + numerator, denominator = node_inps.owner.inputs + denominator_bcast = denominator.type.broadcastable + if all(denominator_bcast[i] for i in reduced_axes): + outer_term = denominator.squeeze(reduced_axes) + inner_term = numerator + else: + return None # If we have a `Prod`, then the outside terms need to be raised to the power of the number of elements # that were contracted in the input @@ -1246,12 +1263,16 @@ def local_sum_prod_of_mul(fgraph, node): ) outer_term = outer_term**n_reduced_elements - # Sum/Prod is useless, just return the outer_term if not inner_term: + # Sum/Prod is useless, just return the outer_term + # (This can only happen for mul, not division) new_out = outer_term else: reduced_inner_term = node.op(inner_term) - new_out = outer_term * reduced_inner_term + if inner_op == mul: + new_out = outer_term * reduced_inner_term + else: + new_out = reduced_inner_term / outer_term copy_stack_trace(node.outputs, [inner_term, reduced_inner_term, outer_term]) copy_stack_trace(node.outputs, new_out) @@ -1510,99 +1531,6 @@ def investigate(node): return -@register_canonicalize -@register_specialize -@node_rewriter([Sum, Prod]) -def local_sum_prod_div_dimshuffle(fgraph, node): - """ - sum(a / dimshuffle{...}(b), axis=l) -> sum(a, axis={...}) / b, - if dimension l of the DimShuffle is 'x' - - or - - prod(a / dimshuffle{...}(b), axis=l) -> - prod(a, axis={...}) / b ** a.shape[l], - if dimension l of the DimShuffle is 'x' - """ - - # It does not make much sense now to extend it to the case where the - # dimshuffle is in the numerator, since elemwise inversion of the - # denominator would still be needed before the summation or production. 
- - if isinstance(node.op, (Sum, Prod)): - axis = node.op.axis - if axis is None: - axis = list(range(node.inputs[0].ndim)) - node_input = node.inputs[0] - if node_input.owner and node_input.owner.op == true_div: - numerator, denominator = node_input.owner.inputs - - if denominator.owner and isinstance(denominator.owner.op, DimShuffle): - dimshuffle_input = denominator.owner.inputs[0] - dimshuffle_order = denominator.owner.op.new_order - - compatible_dims = [] - incompatible_dims = [] - for ax in axis: - if ax < len(dimshuffle_order) and dimshuffle_order[ax] == "x": - compatible_dims.append(ax) - else: - incompatible_dims.append(ax) - reordered_incompatible_dims = [] - for ic_ax in incompatible_dims: - reordered_incompatible_dims.append( - ic_ax - sum(1 for c_ax in compatible_dims if c_ax < ic_ax) - ) - - if len(compatible_dims) > 0: - optimized_dimshuffle_order = [ - ax - for i, ax in enumerate(dimshuffle_order) - if (i not in axis) or (ax != "x") - ] - - # Removing leading 'x' (since it will be done automatically) - while ( - len(optimized_dimshuffle_order) > 0 - and optimized_dimshuffle_order[0] == "x" - ): - del optimized_dimshuffle_order[0] - - # if optimized_dimshuffle_order is sorted with - # not 'x', then dimshuffle is useless. - if all(i == e for i, e in enumerate(optimized_dimshuffle_order)): - optimized_dimshuffle = dimshuffle_input - else: - optimized_dimshuffle = DimShuffle( - dimshuffle_input.type.broadcastable, - optimized_dimshuffle_order, - )(dimshuffle_input) - - if isinstance(node.op, Sum): - op_on_compatible_dims = at_sum(numerator, axis=compatible_dims) - rval = true_div(op_on_compatible_dims, optimized_dimshuffle) - if len(reordered_incompatible_dims) > 0: - rval = at_sum(rval, axis=reordered_incompatible_dims) - elif isinstance(node.op, Prod): - op_on_compatible_dims = prod(numerator, axis=compatible_dims) - dtype = numerator.dtype - rval = true_div( - op_on_compatible_dims, - ( - optimized_dimshuffle - ** prod( - [ - numerator.shape[ax].astype(dtype) - for ax in compatible_dims - ] - ) - ), - ) - if len(reordered_incompatible_dims) > 0: - rval = prod(rval, axis=reordered_incompatible_dims) - return [rval] - - @register_canonicalize @node_rewriter([Sum, Prod]) def local_sum_prod_all_to_none(fgraph, node): diff --git a/tests/tensor/rewriting/test_elemwise.py b/tests/tensor/rewriting/test_elemwise.py index ac4d293f16..8e7c754d5e 100644 --- a/tests/tensor/rewriting/test_elemwise.py +++ b/tests/tensor/rewriting/test_elemwise.py @@ -899,7 +899,7 @@ def large_fuseable_graph(self, n): ), (fx, fy), (fxv, fyv), - 3, + 2, ( np.sum(-((fxv - fyv) ** 2) / 2), -(fxv - fyv), diff --git a/tests/tensor/rewriting/test_math.py b/tests/tensor/rewriting/test_math.py index adcea2aa68..4bc7ae3ad3 100644 --- a/tests/tensor/rewriting/test_math.py +++ b/tests/tensor/rewriting/test_math.py @@ -92,7 +92,7 @@ local_grad_log_erfc_neg, local_greedy_distributor, local_mul_canonizer, - local_sum_prod_of_mul, + local_sum_prod_of_mul_or_div, mul_canonizer, parse_mul_tree, perform_sigm_times_exp, @@ -2656,7 +2656,7 @@ def test_reduction_rewrite( def test_sum_of_non_scalar_mul(self): mode = Mode("vm", optimizer="None") - rewrite = out2in(local_sum_prod_of_mul) + rewrite = out2in(local_sum_prod_of_mul_or_div) row1 = matrix(shape=(1, None), dtype="float64") row2 = matrix(shape=(1, None), dtype="float64") @@ -2726,7 +2726,7 @@ def test_sum_of_non_scalar_mul(self): def test_prod_of_non_scalar_mul(self): mode = Mode("vm", optimizer="None") - rewrite = out2in(local_sum_prod_of_mul) + rewrite = 
out2in(local_sum_prod_of_mul_or_div) scl1 = matrix(shape=(1, 1), dtype="float64") row1 = matrix(shape=(1, None), dtype="float64") @@ -2756,14 +2756,15 @@ def test_prod_of_non_scalar_mul(self): mul(row1, row2, mat1, mat2, col1, col2).prod(axis=0), ( mul(row1.squeeze(), row2.squeeze()) - ** prod([mul(mat1, mat2, col1, col2).shape[0]]) + ** prod([mul(mat1, mat2, col1, col2).shape[0].astype("float64")]) * mul(mat1, mat2, col1, col2).prod(axis=0) ), ), ( mul(row1, mat1, mat2, col1, col2).prod(axis=0), ( - row1.squeeze() ** prod([mul(mat1, mat2, col1, col2).shape[0]]) + row1.squeeze() + ** prod([mul(mat1, mat2, col1, col2).shape[0].astype("float64")]) * mul(mat1, mat2, col1, col2).prod(axis=0) ), ), @@ -2771,7 +2772,7 @@ def test_prod_of_non_scalar_mul(self): mul(row1, row2, mat1, mat2, col1, col2).prod(axis=1), ( mul(col1.squeeze(), col2.squeeze()) - ** prod([mul(row1, row2, mat1, mat2).shape[1]]) + ** prod([mul(row1, row2, mat1, mat2).shape[1].astype("float64")]) * mul(row1, row2, mat1, mat2).prod(axis=1) ), ), @@ -2781,13 +2782,21 @@ def test_prod_of_non_scalar_mul(self): ), ( mul(row1, col1).prod(axis=0), - (row1.squeeze() ** prod([col1.shape[0]]) * col1.prod(axis=0)), + ( + row1.squeeze() ** prod([col1.shape[0].astype("float64")]) + * col1.prod(axis=0) + ), ), ( mul(scl1, mat1, row1).prod(axis=None), ( scl1.squeeze() - ** prod([mul(mat1, row1).shape[0], mul(mat1, row1).shape[1]]) + ** prod( + [ + mul(mat1, row1).shape[0].astype("float64"), + mul(mat1, row1).shape[1].astype("float64"), + ] + ) * mul(mat1, row1).prod(axis=None) ), ), @@ -3050,146 +3059,7 @@ def test_local_sum_prod_mul_by_scalar_stack_trace(self): f = function([mat], at_sum(-mat), mode=m0) assert check_stack_trace(f, ops_to_check=[Sum]) - -class TestLocalReduce: - def setup_method(self): - self.mode = get_default_mode().including( - "canonicalize", "specialize", "uncanonicalize", "local_max_and_argmax" - ) - - def test_local_reduce_broadcast_all_0(self): - for fct in [ - at_sum, - at_all, - at_any, - prod, - at_max, - at_min, - ]: - x = TensorType("int64", shape=(1, 1, 1))() - f = function([x], [fct(x)], mode=self.mode) - assert not any( - isinstance(node.op, CAReduce) for node in f.maker.fgraph.toposort() - ) - - def test_local_reduce_broadcast_all_1(self): - for fct in [ - at_sum, - at_all, - at_any, - prod, - at_max, - at_min, - ]: - x = TensorType("int64", shape=(1, 1))() - f = function([x], [fct(x, axis=[0, 1])], mode=self.mode) - assert not any( - isinstance(node.op, CAReduce) for node in f.maker.fgraph.toposort() - ) - - def test_local_reduce_broadcast_some_0(self): - for fct in [ - at_sum, - at_all, - at_any, - prod, - at_max, - at_min, - ]: - x = TensorType("int64", shape=(1, None, 1))() - f = function([x], [fct(x, axis=[0, 1])], mode=self.mode) - - order = f.maker.fgraph.toposort() - assert 1 == sum(isinstance(node.op, CAReduce) for node in order) - - node = [node for node in order if isinstance(node.op, CAReduce)][0] - - op = node.op - assert isinstance(op, CAReduce) - # The leading broadcastable dimension has been dropped by the - # `local_reduce_broadcastable` rewrite. Now, summation is over - # the original `x`'s dimension 1. 
- assert node.inputs[0].ndim == 2, node - assert op.axis == (0,), op.axis - - def test_local_reduce_broadcast_some_1(self): - for fct in [ - at_sum, - at_all, - at_any, - prod, - at_max, - at_min, - ]: - x = TensorType("int64", shape=(1, 1, 1))() - f = function([x], [fct(x, axis=[0, 2])], mode=self.mode) - assert not any( - isinstance(node.op, CAReduce) for node in f.maker.fgraph.toposort() - ) - - def test_local_reduce_join(self): - vx = matrix() - vy = matrix() - vz = matrix() - x = np.asarray([[1, 0], [3, 4]], dtype=config.floatX) - y = np.asarray([[4, 0], [2, 1]], dtype=config.floatX) - z = np.asarray([[5, 0], [1, 2]], dtype=config.floatX) - # Test different reduction scalar operation - for out, res in [ - (at_max((vx, vy), 0), np.max((x, y), 0)), - (at_min((vx, vy), 0), np.min((x, y), 0)), - (at_sum((vx, vy, vz), 0), np.sum((x, y, z), 0)), - (prod((vx, vy, vz), 0), np.prod((x, y, z), 0)), - (prod((vx, vy.T, vz), 0), np.prod((x, y.T, z), 0)), - ]: - f = function([vx, vy, vz], out, on_unused_input="ignore", mode=self.mode) - assert (f(x, y, z) == res).all(), out - topo = f.maker.fgraph.toposort() - assert len(topo) <= 2, out - assert isinstance(topo[-1].op, Elemwise), out - - # Test different axis for the join and the reduction - # We must force the dtype, of otherwise, this tests will fail - # on 32 bit systems - A = shared(np.array([1, 2, 3, 4, 5], dtype="int64")) - - f = function([], at_sum(at.stack([A, A]), axis=0), mode=self.mode) - utt.assert_allclose(f(), [2, 4, 6, 8, 10]) - topo = f.maker.fgraph.toposort() - assert isinstance(topo[-1].op, Elemwise) - - # Test a case that was bugged in a old PyTensor bug - f = function([], at_sum(at.stack([A, A]), axis=1), mode=self.mode) - - utt.assert_allclose(f(), [15, 15]) - topo = f.maker.fgraph.toposort() - assert not isinstance(topo[-1].op, Elemwise) - - # This case could be rewritten - A = shared(np.array([1, 2, 3, 4, 5]).reshape(5, 1)) - f = function([], at_sum(at.concatenate((A, A), axis=1), axis=1), mode=self.mode) - utt.assert_allclose(f(), [2, 4, 6, 8, 10]) - topo = f.maker.fgraph.toposort() - assert not isinstance(topo[-1].op, Elemwise) - - A = shared(np.array([1, 2, 3, 4, 5]).reshape(5, 1)) - f = function([], at_sum(at.concatenate((A, A), axis=1), axis=0), mode=self.mode) - utt.assert_allclose(f(), [15, 15]) - topo = f.maker.fgraph.toposort() - assert not isinstance(topo[-1].op, Elemwise) - - # Test that the rewrite does not crash in one case where it - # is not applied. Reported at - # https://groups.google.com/d/topic/theano-users/EDgyCU00fFA/discussion - out = at_sum([vx, vy, vz], axis=None) - f = function([vx, vy, vz], out) - - -class TestLocalSumProdDimshuffle: - def setup_method(self): - self.mode = get_default_mode().including("canonicalize") - - def test_local_sum_div_dimshuffle(self): + def test_local_sum_of_div(self): a = matrix("a") b = vector("b") c = tensor3("c") @@ -3242,7 +3112,7 @@ def test_local_sum_div_dimshuffle(self): assert isinstance(g[-1].op.scalar_op, aes.basic.TrueDiv) f(a_val, b_val, c_val, d_val) - def test_local_prod_div_dimshuffle(self): + def test_local_prod_of_div(self): a = matrix("a") b = vector("b") c = tensor3("c") @@ -3295,9 +3165,9 @@ def test_local_prod_div_dimshuffle(self): # `FusionOptimizer` is included to make sure that `expected_outer_operator` # remains the same for all rewrite modes. 
mode_with_rewrite = default_mode.including( - "local_sum_prod_div_dimshuffle", "FusionOptimizer" + "local_sum_prod_of_mul_or_div", "FusionOptimizer" ) - mode_without_rewrite = default_mode.excluding("local_sum_prod_div_dimshuffle") + mode_without_rewrite = default_mode.excluding("local_sum_prod_of_mul_or_div") # Numerical tests: tests whether the numerical values with and without # rewrites are equal or not. @@ -3345,9 +3215,139 @@ def test_local_prod_div_dimshuffle(self): g.maker.fgraph.toposort()[-1].op.scalar_op, expected_outer_operator[i] ) - # TODO: - # test_local_sum_prod_dimshuffle (a * b * c) - # test_local_sum_divprod_dimshuffle ((a * b) / (c * d)) + +class TestLocalReduce: + def setup_method(self): + self.mode = get_default_mode().including( + "canonicalize", "specialize", "uncanonicalize", "local_max_and_argmax" + ) + + def test_local_reduce_broadcast_all_0(self): + for fct in [ + at_sum, + at_all, + at_any, + prod, + at_max, + at_min, + ]: + x = TensorType("int64", shape=(1, 1, 1))() + f = function([x], [fct(x)], mode=self.mode) + assert not any( + isinstance(node.op, CAReduce) for node in f.maker.fgraph.toposort() + ) + + def test_local_reduce_broadcast_all_1(self): + for fct in [ + at_sum, + at_all, + at_any, + prod, + at_max, + at_min, + ]: + x = TensorType("int64", shape=(1, 1))() + f = function([x], [fct(x, axis=[0, 1])], mode=self.mode) + assert not any( + isinstance(node.op, CAReduce) for node in f.maker.fgraph.toposort() + ) + + def test_local_reduce_broadcast_some_0(self): + for fct in [ + at_sum, + at_all, + at_any, + prod, + at_max, + at_min, + ]: + x = TensorType("int64", shape=(1, None, 1))() + f = function([x], [fct(x, axis=[0, 1])], mode=self.mode) + + order = f.maker.fgraph.toposort() + assert 1 == sum(isinstance(node.op, CAReduce) for node in order) + + node = [node for node in order if isinstance(node.op, CAReduce)][0] + + op = node.op + assert isinstance(op, CAReduce) + # The leading broadcastable dimension has been dropped by the + # `local_reduce_broadcastable` rewrite. Now, summation is over + # the original `x`'s dimension 1. 
+ assert node.inputs[0].ndim == 2, node + assert op.axis == (0,), op.axis + + def test_local_reduce_broadcast_some_1(self): + for fct in [ + at_sum, + at_all, + at_any, + prod, + at_max, + at_min, + ]: + x = TensorType("int64", shape=(1, 1, 1))() + f = function([x], [fct(x, axis=[0, 2])], mode=self.mode) + assert not any( + isinstance(node.op, CAReduce) for node in f.maker.fgraph.toposort() + ) + + def test_local_reduce_join(self): + vx = matrix() + vy = matrix() + vz = matrix() + x = np.asarray([[1, 0], [3, 4]], dtype=config.floatX) + y = np.asarray([[4, 0], [2, 1]], dtype=config.floatX) + z = np.asarray([[5, 0], [1, 2]], dtype=config.floatX) + # Test different reduction scalar operation + for out, res in [ + (at_max((vx, vy), 0), np.max((x, y), 0)), + (at_min((vx, vy), 0), np.min((x, y), 0)), + (at_sum((vx, vy, vz), 0), np.sum((x, y, z), 0)), + (prod((vx, vy, vz), 0), np.prod((x, y, z), 0)), + (prod((vx, vy.T, vz), 0), np.prod((x, y.T, z), 0)), + ]: + f = function([vx, vy, vz], out, on_unused_input="ignore", mode=self.mode) + assert (f(x, y, z) == res).all(), out + topo = f.maker.fgraph.toposort() + assert len(topo) <= 2, out + assert isinstance(topo[-1].op, Elemwise), out + + # Test different axis for the join and the reduction + # We must force the dtype, of otherwise, this tests will fail + # on 32 bit systems + A = shared(np.array([1, 2, 3, 4, 5], dtype="int64")) + + f = function([], at_sum(at.stack([A, A]), axis=0), mode=self.mode) + utt.assert_allclose(f(), [2, 4, 6, 8, 10]) + topo = f.maker.fgraph.toposort() + assert isinstance(topo[-1].op, Elemwise) + + # Test a case that was bugged in a old PyTensor bug + f = function([], at_sum(at.stack([A, A]), axis=1), mode=self.mode) + + utt.assert_allclose(f(), [15, 15]) + topo = f.maker.fgraph.toposort() + assert not isinstance(topo[-1].op, Elemwise) + + # This case could be rewritten + A = shared(np.array([1, 2, 3, 4, 5]).reshape(5, 1)) + f = function([], at_sum(at.concatenate((A, A), axis=1), axis=1), mode=self.mode) + utt.assert_allclose(f(), [2, 4, 6, 8, 10]) + topo = f.maker.fgraph.toposort() + assert not isinstance(topo[-1].op, Elemwise) + + A = shared(np.array([1, 2, 3, 4, 5]).reshape(5, 1)) + f = function([], at_sum(at.concatenate((A, A), axis=1), axis=0), mode=self.mode) + utt.assert_allclose(f(), [15, 15]) + topo = f.maker.fgraph.toposort() + assert not isinstance(topo[-1].op, Elemwise) + + # Test that the rewrite does not crash in one case where it + # is not applied. Reported at + # https://groups.google.com/d/topic/theano-users/EDgyCU00fFA/discussion + out = at_sum([vx, vy, vz], axis=None) + f = function([vx, vy, vz], out) def test_local_useless_adds(): @@ -3534,7 +3534,6 @@ def test_local_mul_exp_to_exp_add(): # e^x * e^y * e^z * e^w = e^(x+y+z+w) op = expx * expy * expz * expw f = function([x, y, z, w], op, mode) - pytensor.dprint(f) utt.assert_allclose(f(3, 4, 5, 6), np.exp(3 + 4 + 5 + 6)) graph = f.maker.fgraph.toposort() assert all(isinstance(n.op, Elemwise) for n in graph)
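
The algebraic identities the new local_sum_prod_of_mul_or_div rewrite relies on can be checked numerically. Below is a minimal NumPy sketch (illustrative only, not part of the patch; the shapes are arbitrary example values): a factor that is broadcast along the reduced axis can be pulled out of a sum, it must be raised to the number of reduced elements when pulled out of a prod, and a broadcast denominator can be pulled out of a reduced division.

import numpy as np

rng = np.random.default_rng(0)
a = rng.random((1, 3))  # broadcastable along axis 0, the reduced axis
X = rng.random((4, 3))

# sum(a * X, axis=0) == a.squeeze(0) * sum(X, axis=0)
np.testing.assert_allclose((a * X).sum(axis=0), a.squeeze(0) * X.sum(axis=0))

# prod(a * X, axis=0) == a.squeeze(0) ** X.shape[0] * prod(X, axis=0):
# each of the X.shape[0] reduced elements contributes one factor of a
np.testing.assert_allclose(
    (a * X).prod(axis=0), a.squeeze(0) ** X.shape[0] * X.prod(axis=0)
)

# sum(X / a, axis=0) == sum(X, axis=0) / a.squeeze(0); only the denominator is
# factored out, since a / X would still require inverting every element of X
np.testing.assert_allclose((X / a).sum(axis=0), X.sum(axis=0) / a.squeeze(0))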