diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py index bd33c2e6a1..1c26c5590d 100644 --- a/pytensor/tensor/rewriting/math.py +++ b/pytensor/tensor/rewriting/math.py @@ -42,13 +42,8 @@ from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.extra_ops import broadcast_arrays from pytensor.tensor.math import ( - All, - Any, Dot, - FixedOpCAReduce, - NonZeroDimsCAReduce, Prod, - ProdWithoutZeros, Sum, _conj, add, @@ -96,6 +91,7 @@ register_uncanonicalize, register_useless, ) +from pytensor.tensor.rewriting.elemwise import apply_local_dimshuffle_lift from pytensor.tensor.shape import Shape, Shape_i from pytensor.tensor.subtensor import Subtensor from pytensor.tensor.type import ( @@ -105,7 +101,11 @@ values_eq_approx_remove_inf_nan, values_eq_approx_remove_nan, ) -from pytensor.tensor.variable import TensorConstant, get_unique_constant_value +from pytensor.tensor.variable import ( + TensorConstant, + TensorVariable, + get_unique_constant_value, +) def scalarconsts_rest(inputs, elemwise=True, only_process_constants=False): @@ -1580,130 +1580,110 @@ def local_sum_prod_all_to_none(fgraph, node): @register_canonicalize -@node_rewriter([Sum, Prod]) -def local_op_of_op(fgraph, node): +@node_rewriter([CAReduce]) +def local_reduce_chain(fgraph, node) -> list[TensorVariable] | None: """ - Prod(Prod()) -> single Prod() - or Sum(Sum()) -> single Sum() + or any CAReduce(Careduce(x)) of the same type """ - op_type = Sum if isinstance(node.op, Sum) else Prod - (node_inps,) = node.inputs - out_dtype = node.op.dtype - # This is done to make sure the rewrite doesn't affect other - # computations. - if len(fgraph.clients[node_inps]) == 1: - if node_inps.owner and (isinstance(node_inps.owner.op, node.op.__class__)): - # check to see either the inner or outer prod is doing a - # product over all axis, in which case we can remove it - if node_inps.owner.op.axis is None or node.op.axis is None: - return [op_type(None, dtype=out_dtype)(node_inps.owner.inputs[0])] - - # figure out which axes were in the original sum - newaxis = list(node_inps.owner.op.axis) - for i in node.op.axis: - new_i = i - for ii in node_inps.owner.op.axis: - if new_i >= ii: - new_i += 1 - assert new_i not in newaxis - newaxis.append(new_i) - - assert len(newaxis) == len( - list(node_inps.owner.op.axis) + list(node.op.axis) - ) + [inner_reduce] = node.inputs + if not (inner_reduce.owner and isinstance(inner_reduce.owner.op, CAReduce)): + return None - combined = op_type(newaxis, dtype=out_dtype) - return [combined(node_inps.owner.inputs[0])] + # Don't apply rewrite if inner_reduce is used elsewhere + if len(fgraph.clients[inner_reduce]) > 1: + return None + # Check if CAReduces have the same scalar op + outer_op: CAReduce = node.op + inner_op = inner_reduce.owner.op -ALL_REDUCE = [ - CAReduce, - All, - Any, - Sum, - Prod, - ProdWithoutZeros, - *CAReduce.__subclasses__(), - *FixedOpCAReduce.__subclasses__(), - *NonZeroDimsCAReduce.__subclasses__(), -] + if outer_op.scalar_op != inner_op.scalar_op: + return None + + outer_axis = outer_op.axis + inner_axis = inner_op.axis + [x] = inner_reduce.owner.inputs + # check to see either the inner or outer prod is doing a + # product over all axis, in which case we can remove it + if outer_axis is None or inner_axis is None: + return [outer_op.clone(axis=None)(x)] + + # Merge axis + newaxis = list(inner_axis) + for i in outer_axis: + new_i = i + for ii in inner_axis: + if new_i >= ii: + new_i += 1 + assert new_i not in newaxis + 
newaxis.append(new_i) + + assert len(newaxis) == len(inner_axis) + len(outer_axis) + return [outer_op.clone(axis=sorted(newaxis))(x)] @register_canonicalize @register_uncanonicalize # Needed for MaxAndArgmax -> CAReduce -@node_rewriter(ALL_REDUCE) +@node_rewriter([CAReduce]) def local_reduce_join(fgraph, node): """ - CAReduce{scalar.op}(Join(axis=0, a, b), axis=0) -> Elemwise{scalar.op}(a, b) + CAReduce{scalar.op}(Join(axis=x, a, b), axis=x) -> Elemwise{scalar.op}(a, b) - Notes - ----- - Supported scalar.op are Maximum, Minimum in some cases and Add and Mul in - all cases. - - Currently we must reduce on axis 0. It is probably extensible to the case - where we join and reduce on the same set of axis. + When a, b have a dim length of 1 along the join axis """ - if node.inputs[0].owner and isinstance(node.inputs[0].owner.op, Join): - join_node = node.inputs[0].owner - if extract_constant(join_node.inputs[0], only_process_constants=True) != 0: - return + if not (node.inputs[0].owner and isinstance(node.inputs[0].owner.op, Join)): + return None - if isinstance(node.op.scalar_op, ps.ScalarMaximum | ps.ScalarMinimum): - # Support only 2 inputs for now - if len(join_node.inputs) != 3: - return - elif not isinstance(node.op.scalar_op, ps.Add | ps.Mul): - return - elif len(join_node.inputs) <= 2: - # This is a useless join that should get removed by another rewrite? - return + [joined_out] = node.inputs + joined_node = joined_out.owner + join_axis_tensor, *joined_inputs = joined_node.inputs - new_inp = [] - for inp in join_node.inputs[1:]: - inp = inp.owner - if not inp: - return - if not isinstance(inp.op, DimShuffle) or inp.op.new_order != ( - "x", - *range(inp.inputs[0].ndim), - ): - return - new_inp.append(inp.inputs[0]) - ret = Elemwise(node.op.scalar_op)(*new_inp) + n_joined_inputs = len(joined_inputs) + if n_joined_inputs < 2: + # Let some other rewrite get rid of this useless Join + return None + if n_joined_inputs > 2 and not isinstance(node.op.scalar_op, ps.Add | ps.Mul): + # We don't rewrite if a single Elemwise cannot take all inputs at once + return None - if ret.dtype != node.outputs[0].dtype: - # The reduction do something about the dtype. - return + if not isinstance(join_axis_tensor, Constant): + return None + join_axis = join_axis_tensor.data - reduce_axis = node.op.axis - if reduce_axis is None: - reduce_axis = tuple(range(node.inputs[0].ndim)) + # Check whether reduction happens on joined axis + reduce_op = node.op + reduce_axis = reduce_op.axis + if reduce_axis is None: + if joined_out.type.ndim > 1: + return None + elif reduce_axis != (join_axis,): + return None - if len(reduce_axis) != 1 or 0 not in reduce_axis: - return + # Check all inputs are broadcastable along the join axis and squeeze those dims away + new_inputs = [] + for inp in joined_inputs: + if not inp.type.broadcastable[join_axis]: + return None + # Most times inputs to join have an expand_dims, we eagerly clean up those here + new_input = apply_local_dimshuffle_lift(None, inp.squeeze(join_axis)) + new_inputs.append(new_input) - # We add the new check late to don't add extra warning. - try: - join_axis = get_underlying_scalar_constant_value( - join_node.inputs[0], only_process_constants=True - ) + ret = Elemwise(node.op.scalar_op)(*new_inputs) - if join_axis != reduce_axis[0]: - return - except NotScalarConstantError: - return + if ret.dtype != node.outputs[0].dtype: + # The reduction do something about the dtype. 
+ return None - return [ret] + return [ret] @register_infer_shape @register_canonicalize("fast_compile", "local_cut_useless_reduce") @register_useless("local_cut_useless_reduce") -@node_rewriter(ALL_REDUCE) +@node_rewriter([CAReduce]) def local_useless_reduce(fgraph, node): """Sum(a, axis=[]) -> a""" (summed,) = node.inputs @@ -1715,7 +1695,7 @@ def local_useless_reduce(fgraph, node): @register_canonicalize @register_uncanonicalize @register_specialize -@node_rewriter(ALL_REDUCE) +@node_rewriter([CAReduce]) def local_reduce_broadcastable(fgraph, node): """Remove reduction over broadcastable dimensions.""" (reduced,) = node.inputs diff --git a/tests/tensor/rewriting/test_math.py b/tests/tensor/rewriting/test_math.py index 174858da30..ed4b03d9f3 100644 --- a/tests/tensor/rewriting/test_math.py +++ b/tests/tensor/rewriting/test_math.py @@ -101,6 +101,7 @@ local_grad_log_erfc_neg, local_greedy_distributor, local_mul_canonizer, + local_reduce_chain, local_sum_prod_of_mul_or_div, mul_canonizer, parse_mul_tree, @@ -2497,6 +2498,168 @@ def test_elemwise(self): assert debugprint(g, file="str").count("Switch") == 1 +class TestReduceChain: + def setup_method(self): + self.mode = get_default_mode().including("canonicalize", "specialize") + + def test_local_sum_prod_all_to_none(self): + a = tensor3() + input = np.arange(3 * 4 * 5, dtype=config.floatX).reshape(3, 4, 5) + # test sum + f = function([a], a.sum(), mode=self.mode) + assert len(f.maker.fgraph.apply_nodes) == 1 + utt.assert_allclose(f(input), input.sum()) + # test prod + f = function([a], a.prod(), mode=self.mode) + assert len(f.maker.fgraph.apply_nodes) == 1 + utt.assert_allclose(f(input), input.prod()) + # test sum + f = function([a], a.sum([0, 1, 2]), mode=self.mode) + assert len(f.maker.fgraph.apply_nodes) == 1 + utt.assert_allclose(f(input), input.sum()) + # test prod + f = function([a], a.prod([0, 1, 2]), mode=self.mode) + assert len(f.maker.fgraph.apply_nodes) == 1 + utt.assert_allclose(f(input), input.prod()) + + f = function([a], a.sum(0).sum(0).sum(0), mode=self.mode) + assert len(f.maker.fgraph.apply_nodes) == 1 + utt.assert_allclose(f(input), input.sum()) + + def test_local_sum_sum_prod_prod(self): + a = tensor3() + input = np.arange(3 * 4 * 5, dtype=config.floatX).reshape(3, 4, 5) + dims = [ + (0, 0), + (1, 0), + (2, 0), + (0, 1), + (1, 1), + (2, 1), + ((0, 1), 0), + ((1, 2), 0), + (0, (0, 1)), + (1, (0, 1)), + (2, (0, 1)), + ] + + def my_prod(data, d, dd): + # This prod when d or dd is a tuple of 2 dimensions. + if not isinstance(d, tuple) and not isinstance(dd, tuple): + return data.prod(d).prod(dd) + if isinstance(d, tuple): + d = sorted(d) + return data.prod(d[1]).prod(d[0]).prod(dd) + else: + dd = sorted(dd) + return data.prod(d).prod(dd[1]).prod(dd[0]) + + def my_sum(data, d, dd): + # This sum when d or dd is a tuple of 2 dimensions. + if not isinstance(d, tuple) and not isinstance(dd, tuple): + return data.sum(d).sum(dd) + if isinstance(d, tuple): + d = sorted(d) + return data.sum(d[1]).sum(d[0]).sum(dd) + else: + dd = sorted(dd) + return data.sum(d).sum(dd[1]).sum(dd[0]) + + def my_sum_prod(data, d, dd): + # This sum when d or dd is a tuple of 2 dimensions. 
+ if not isinstance(d, tuple) and not isinstance(dd, tuple): + return data.sum(d).prod(dd) + if isinstance(d, tuple): + d = sorted(d) + return data.sum(d[1]).sum(d[0]).prod(dd) + else: + dd = sorted(dd) + return data.sum(d).prod(dd[1]).prod(dd[0]) + + for d, dd in dims: + expected = my_sum(input, d, dd) + f = function([a], a.sum(d).sum(dd), mode=self.mode) + utt.assert_allclose(f(input), expected) + assert len(f.maker.fgraph.apply_nodes) == 1 + for d, dd in dims[:6]: + f = function([a], a.sum(d).sum(dd).sum(0), mode=self.mode) + utt.assert_allclose(f(input), input.sum(d).sum(dd).sum(0)) + assert len(f.maker.fgraph.apply_nodes) == 1 + for d in [0, 1, 2]: + f = function([a], a.sum(d).sum(None), mode=self.mode) + utt.assert_allclose(f(input), input.sum(d).sum()) + assert len(f.maker.fgraph.apply_nodes) == 1 + f = function([a], a.sum(None).sum(), mode=self.mode) + utt.assert_allclose(f(input), input.sum()) + assert len(f.maker.fgraph.apply_nodes) == 1 + + # test prod + for d, dd in dims: + expected = my_prod(input, d, dd) + f = function([a], a.prod(d).prod(dd), mode=self.mode) + utt.assert_allclose(f(input), expected) + assert len(f.maker.fgraph.apply_nodes) == 1 + for d, dd in dims[:6]: + f = function([a], a.prod(d).prod(dd).prod(0), mode=self.mode) + utt.assert_allclose(f(input), input.prod(d).prod(dd).prod(0)) + assert len(f.maker.fgraph.apply_nodes) == 1 + for d in [0, 1, 2]: + f = function([a], a.prod(d).prod(None), mode=self.mode) + utt.assert_allclose(f(input), input.prod(d).prod()) + assert len(f.maker.fgraph.apply_nodes) == 1 + f = function([a], a.prod(None).prod(), mode=self.mode) + utt.assert_allclose(f(input), input.prod()) + assert len(f.maker.fgraph.apply_nodes) == 1 + + # Test that sum prod didn't get rewritten. + for d, dd in dims: + expected = my_sum_prod(input, d, dd) + f = function([a], a.sum(d).prod(dd), mode=self.mode) + utt.assert_allclose(f(input), expected) + assert len(f.maker.fgraph.apply_nodes) == 2 + for d, dd in dims[:6]: + f = function([a], a.sum(d).prod(dd).prod(0), mode=self.mode) + utt.assert_allclose(f(input), input.sum(d).prod(dd).prod(0)) + assert len(f.maker.fgraph.apply_nodes) == 2 + for d in [0, 1, 2]: + f = function([a], a.sum(d).prod(None), mode=self.mode) + utt.assert_allclose(f(input), input.sum(d).prod()) + assert len(f.maker.fgraph.apply_nodes) == 2 + f = function([a], a.sum(None).prod(), mode=self.mode) + utt.assert_allclose(f(input), input.sum()) + assert len(f.maker.fgraph.apply_nodes) == 1 + + def test_local_sum_sum_int8(self): + """Test that `local_sum_sum` works when combining two sums on an int8 array. + + This is a regression test for ticket gh-356. + """ + + x = tensor3(dtype="int8") + y = x.sum(axis=0).sum(axis=1) + + with config.change_flags(on_opt_error="raise"): + # This compilation would fail prior to fix. + function([x], y) + + def test_local_sum_sum_dtype(self): + """Test that `local_sum_sum` works when specifying dtypes manually.""" + + x = tensor3(dtype="int8") + y = x.sum(axis=0, dtype="int32").sum(axis=1, dtype="int64") + + with config.change_flags(on_opt_error="raise"): + # This compilation would fail prior to fix. 
+ function([x], y) + + def test_all(self): + x = tensor3(dtype=bool) + out = x.all(axis=-1).all(axis=0) + fg = FunctionGraph([x], [out], clone=False) + [new_out] = local_reduce_chain.transform(fg, out.owner) + assert equal_computations([new_out], [x.all(axis=(0, 2))]) + + class TestLocalSumProd: """Test sum/prod rewrites.""" @@ -2813,133 +2976,6 @@ def test_prod_of_non_scalar_mul(self): rewritten_out_fn(*test_vals), ) - def test_local_sum_prod_all_to_none(self): - a = tensor3() - input = np.arange(3 * 4 * 5, dtype=config.floatX).reshape(3, 4, 5) - # test sum - f = function([a], a.sum(), mode=self.mode) - assert len(f.maker.fgraph.apply_nodes) == 1 - utt.assert_allclose(f(input), input.sum()) - # test prod - f = function([a], a.prod(), mode=self.mode) - assert len(f.maker.fgraph.apply_nodes) == 1 - utt.assert_allclose(f(input), input.prod()) - # test sum - f = function([a], a.sum([0, 1, 2]), mode=self.mode) - assert len(f.maker.fgraph.apply_nodes) == 1 - utt.assert_allclose(f(input), input.sum()) - # test prod - f = function([a], a.prod([0, 1, 2]), mode=self.mode) - assert len(f.maker.fgraph.apply_nodes) == 1 - utt.assert_allclose(f(input), input.prod()) - - f = function([a], a.sum(0).sum(0).sum(0), mode=self.mode) - assert len(f.maker.fgraph.apply_nodes) == 1 - utt.assert_allclose(f(input), input.sum()) - - def test_local_sum_sum_prod_prod(self): - a = tensor3() - input = np.arange(3 * 4 * 5, dtype=config.floatX).reshape(3, 4, 5) - dims = [ - (0, 0), - (1, 0), - (2, 0), - (0, 1), - (1, 1), - (2, 1), - ((0, 1), 0), - ((1, 2), 0), - (0, (0, 1)), - (1, (0, 1)), - (2, (0, 1)), - ] - - def my_prod(data, d, dd): - # This prod when d or dd is a tuple of 2 dimensions. - if not isinstance(d, tuple) and not isinstance(dd, tuple): - return data.prod(d).prod(dd) - if isinstance(d, tuple): - d = sorted(d) - return data.prod(d[1]).prod(d[0]).prod(dd) - else: - dd = sorted(dd) - return data.prod(d).prod(dd[1]).prod(dd[0]) - - def my_sum(data, d, dd): - # This sum when d or dd is a tuple of 2 dimensions. - if not isinstance(d, tuple) and not isinstance(dd, tuple): - return data.sum(d).sum(dd) - if isinstance(d, tuple): - d = sorted(d) - return data.sum(d[1]).sum(d[0]).sum(dd) - else: - dd = sorted(dd) - return data.sum(d).sum(dd[1]).sum(dd[0]) - - def my_sum_prod(data, d, dd): - # This sum when d or dd is a tuple of 2 dimensions. 
- if not isinstance(d, tuple) and not isinstance(dd, tuple): - return data.sum(d).prod(dd) - if isinstance(d, tuple): - d = sorted(d) - return data.sum(d[1]).sum(d[0]).prod(dd) - else: - dd = sorted(dd) - return data.sum(d).prod(dd[1]).prod(dd[0]) - - for d, dd in dims: - expected = my_sum(input, d, dd) - f = function([a], a.sum(d).sum(dd), mode=self.mode) - utt.assert_allclose(f(input), expected) - assert len(f.maker.fgraph.apply_nodes) == 1 - for d, dd in dims[:6]: - f = function([a], a.sum(d).sum(dd).sum(0), mode=self.mode) - utt.assert_allclose(f(input), input.sum(d).sum(dd).sum(0)) - assert len(f.maker.fgraph.apply_nodes) == 1 - for d in [0, 1, 2]: - f = function([a], a.sum(d).sum(None), mode=self.mode) - utt.assert_allclose(f(input), input.sum(d).sum()) - assert len(f.maker.fgraph.apply_nodes) == 1 - f = function([a], a.sum(None).sum(), mode=self.mode) - utt.assert_allclose(f(input), input.sum()) - assert len(f.maker.fgraph.apply_nodes) == 1 - - # test prod - for d, dd in dims: - expected = my_prod(input, d, dd) - f = function([a], a.prod(d).prod(dd), mode=self.mode) - utt.assert_allclose(f(input), expected) - assert len(f.maker.fgraph.apply_nodes) == 1 - for d, dd in dims[:6]: - f = function([a], a.prod(d).prod(dd).prod(0), mode=self.mode) - utt.assert_allclose(f(input), input.prod(d).prod(dd).prod(0)) - assert len(f.maker.fgraph.apply_nodes) == 1 - for d in [0, 1, 2]: - f = function([a], a.prod(d).prod(None), mode=self.mode) - utt.assert_allclose(f(input), input.prod(d).prod()) - assert len(f.maker.fgraph.apply_nodes) == 1 - f = function([a], a.prod(None).prod(), mode=self.mode) - utt.assert_allclose(f(input), input.prod()) - assert len(f.maker.fgraph.apply_nodes) == 1 - - # Test that sum prod didn't get rewritten. - for d, dd in dims: - expected = my_sum_prod(input, d, dd) - f = function([a], a.sum(d).prod(dd), mode=self.mode) - utt.assert_allclose(f(input), expected) - assert len(f.maker.fgraph.apply_nodes) == 2 - for d, dd in dims[:6]: - f = function([a], a.sum(d).prod(dd).prod(0), mode=self.mode) - utt.assert_allclose(f(input), input.sum(d).prod(dd).prod(0)) - assert len(f.maker.fgraph.apply_nodes) == 2 - for d in [0, 1, 2]: - f = function([a], a.sum(d).prod(None), mode=self.mode) - utt.assert_allclose(f(input), input.sum(d).prod()) - assert len(f.maker.fgraph.apply_nodes) == 2 - f = function([a], a.sum(None).prod(), mode=self.mode) - utt.assert_allclose(f(input), input.sum()) - assert len(f.maker.fgraph.apply_nodes) == 1 - def test_local_sum_prod_alloc(self): a = dtensor3() input = np.asarray(np.arange(2 * 3 * 4).reshape(2, 3, 4), dtype="float64") @@ -3005,29 +3041,6 @@ def test_local_sum_prod_alloc(self): assert topo[-1].op == pt.alloc assert not any(isinstance(node.op, Sum) for node in topo) - def test_local_sum_sum_int8(self): - """Test that `local_sum_sum` works when combining two sums on an int8 array. - - This is a regression test for ticket gh-356. - """ - - x = tensor3(dtype="int8") - y = x.sum(axis=0).sum(axis=1) - - with config.change_flags(on_opt_error="raise"): - # This compilation would fail prior to fix. - function([x], y) - - def test_local_sum_sum_dtype(self): - """Test that `local_sum_sum` works when specifying dtypes manually.""" - - x = tensor3(dtype="int8") - y = x.sum(axis=0, dtype="int32").sum(axis=1, dtype="int64") - - with config.change_flags(on_opt_error="raise"): - # This compilation would fail prior to fix. 
- function([x], y) - def test_local_sum_prod_mul_by_scalar_stack_trace(self): """Test that stack trace is copied over correctly for `local_sum_prod_mul_by_scalar`.""" m0 = ( @@ -3218,7 +3231,7 @@ def test_local_prod_of_div(self): class TestLocalReduce: def setup_method(self): self.mode = get_default_mode().including( - "canonicalize", "specialize", "uncanonicalize", "local_max_and_argmax" + "canonicalize", "specialize", "uncanonicalize" ) def test_local_reduce_broadcast_all_0(self): @@ -3291,62 +3304,94 @@ def test_local_reduce_broadcast_some_1(self): isinstance(node.op, CAReduce) for node in f.maker.fgraph.toposort() ) - def test_local_reduce_join(self): + +class TestReduceJoin: + def setup_method(self): + self.mode = get_default_mode().including( + "canonicalize", "specialize", "uncanonicalize" + ) + + @pytest.mark.parametrize( + "op, nin", [(pt_sum, 3), (pt_max, 2), (pt_min, 2), (prod, 3)] + ) + def test_local_reduce_join(self, op, nin): vx = matrix() vy = matrix() vz = matrix() x = np.asarray([[1, 0], [3, 4]], dtype=config.floatX) y = np.asarray([[4, 0], [2, 1]], dtype=config.floatX) z = np.asarray([[5, 0], [1, 2]], dtype=config.floatX) - # Test different reduction scalar operation - for out, res in [ - (pt_max((vx, vy), 0), np.max((x, y), 0)), - (pt_min((vx, vy), 0), np.min((x, y), 0)), - (pt_sum((vx, vy, vz), 0), np.sum((x, y, z), 0)), - (prod((vx, vy, vz), 0), np.prod((x, y, z), 0)), - (prod((vx, vy.T, vz), 0), np.prod((x, y.T, z), 0)), - ]: - f = function([vx, vy, vz], out, on_unused_input="ignore", mode=self.mode) - assert (f(x, y, z) == res).all(), out - topo = f.maker.fgraph.toposort() - assert len(topo) <= 2, out - assert isinstance(topo[-1].op, Elemwise), out + inputs = (vx, vy, vz)[:nin] + test_values = (x, y, z)[:nin] + + out = op(inputs, axis=0) + f = function(inputs, out, mode=self.mode) + np.testing.assert_allclose( + f(*test_values), getattr(np, op.__name__)(test_values, axis=0) + ) + topo = f.maker.fgraph.toposort() + assert len(topo) <= 2 + assert isinstance(topo[-1].op, Elemwise) + + def test_type(self): # Test different axis for the join and the reduction # We must force the dtype, of otherwise, this tests will fail # on 32 bit systems A = shared(np.array([1, 2, 3, 4, 5], dtype="int64")) f = function([], pt_sum(pt.stack([A, A]), axis=0), mode=self.mode) - utt.assert_allclose(f(), [2, 4, 6, 8, 10]) + np.testing.assert_allclose(f(), [2, 4, 6, 8, 10]) topo = f.maker.fgraph.toposort() assert isinstance(topo[-1].op, Elemwise) # Test a case that was bugged in a old PyTensor bug f = function([], pt_sum(pt.stack([A, A]), axis=1), mode=self.mode) - utt.assert_allclose(f(), [15, 15]) + np.testing.assert_allclose(f(), [15, 15]) topo = f.maker.fgraph.toposort() assert not isinstance(topo[-1].op, Elemwise) # This case could be rewritten A = shared(np.array([1, 2, 3, 4, 5]).reshape(5, 1)) f = function([], pt_sum(pt.concatenate((A, A), axis=1), axis=1), mode=self.mode) - utt.assert_allclose(f(), [2, 4, 6, 8, 10]) + np.testing.assert_allclose(f(), [2, 4, 6, 8, 10]) topo = f.maker.fgraph.toposort() assert not isinstance(topo[-1].op, Elemwise) A = shared(np.array([1, 2, 3, 4, 5]).reshape(5, 1)) f = function([], pt_sum(pt.concatenate((A, A), axis=1), axis=0), mode=self.mode) - utt.assert_allclose(f(), [15, 15]) + np.testing.assert_allclose(f(), [15, 15]) topo = f.maker.fgraph.toposort() assert not isinstance(topo[-1].op, Elemwise) + def test_not_supported_axis_none(self): # Test that the rewrite does not crash in one case where it # is not applied. 
Reported at # https://groups.google.com/d/topic/theano-users/EDgyCU00fFA/discussion + vx = matrix() + vy = matrix() + vz = matrix() + x = np.asarray([[1, 0], [3, 4]], dtype=config.floatX) + y = np.asarray([[4, 0], [2, 1]], dtype=config.floatX) + z = np.asarray([[5, 0], [1, 2]], dtype=config.floatX) + out = pt_sum([vx, vy, vz], axis=None) - f = function([vx, vy, vz], out) + f = function([vx, vy, vz], out, mode=self.mode) + np.testing.assert_allclose(f(x, y, z), np.sum([x, y, z])) + + def test_not_supported_unequal_shapes(self): + # Not the same shape along the join axis + vx = matrix(shape=(1, 3)) + vy = matrix(shape=(2, 3)) + x = np.asarray([[1, 0, 1]], dtype=config.floatX) + y = np.asarray([[4, 0, 1], [2, 1, 1]], dtype=config.floatX) + out = pt_sum(join(0, vx, vy), axis=0) + + f = function([vx, vy], out, mode=self.mode) + np.testing.assert_allclose( + f(x, y), np.sum(np.concatenate([x, y], axis=0), axis=0) + ) def test_local_useless_adds():
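Reviewer note (illustrative, not part of the patch): the sketch below shows, with ordinary PyTensor user code, the behaviour the reworked `local_reduce_chain` and `local_reduce_join` rewrites are expected to produce. It only uses the public API already exercised by the tests above (`tensor3`, `matrix`, `stack`, `function`, `config.floatX`); the equivalences mirror the new `test_all` and `test_local_reduce_join` cases, and anything beyond that (for example, which compilation mode actually triggers the rewrites) should be read as an assumption rather than something this PR guarantees.

# Illustrative sketch, not part of the patch.
import numpy as np

import pytensor
import pytensor.tensor as pt

# local_reduce_chain: CAReduce(CAReduce(x)) with the same scalar op is merged
# into a single reduction over the combined axes, with the outer axis numbers
# mapped back to the original dimensions (here sum(0) then sum(1) is the same
# as a single sum over axes (0, 2) of the 3d input).
x = pt.tensor3("x")
chained = x.sum(axis=0).sum(axis=1)
f_chain = pytensor.function([x], chained)

data = np.arange(3 * 4 * 5, dtype=pytensor.config.floatX).reshape(3, 4, 5)
np.testing.assert_allclose(f_chain(data), data.sum(axis=(0, 2)))

# local_reduce_join: reducing a Join along the join axis collapses to a single
# Elemwise when every joined input has length 1 on that axis, e.g.
# sum(stack([a, b]), axis=0) -> a + b.
a = pt.matrix("a")
b = pt.matrix("b")
joined_sum = pt.sum(pt.stack([a, b]), axis=0)
f_join = pytensor.function([a, b], joined_sum)

av = np.ones((2, 3), dtype=pytensor.config.floatX)
bv = 2 * av
np.testing.assert_allclose(f_join(av, bv), av + bv)

The second example works because `stack` introduces the expand_dims that the rewrite now squeezes away (cleaned up eagerly via `apply_local_dimshuffle_lift`); joined inputs that are not length 1 along the join axis, as in `test_not_supported_unequal_shapes`, are deliberately left untouched.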