diff --git a/pytensor/tensor/math.py b/pytensor/tensor/math.py
index fc8f21cea8..86846c3249 100644
--- a/pytensor/tensor/math.py
+++ b/pytensor/tensor/math.py
@@ -142,15 +142,10 @@ def get_params(self, node):
     def make_node(self, x):
         x = as_tensor_variable(x)
 
-        # We keep the original broadcastable flags for dimensions on which
-        # we do not perform the max / argmax.
+        # Keep the original shapes for axes on which we do not perform the max/argmax.
         all_axes = set(self.axis)
         inputs = [x]
-        out_shape = tuple(
-            1 if s == 1 else None
-            for i, s in enumerate(x.type.shape)
-            if i not in all_axes
-        )
+        out_shape = tuple(s for i, s in enumerate(x.type.shape) if i not in all_axes)
         outputs = [
             tensor(dtype=x.type.dtype, shape=out_shape, name="max"),
             tensor(dtype="int64", shape=out_shape, name="argmax"),
@@ -1521,7 +1516,6 @@ def perform(self, node, inp, out):
         output[0] = np.asarray(np.mean(input, dtype="float64", axis=axis))
 
     def c_code(self, node, name, inames, onames, sub):
-
         ret = super().c_code(node, name, inames, onames, sub)
 
         if self.axis is not None:
@@ -1940,7 +1934,6 @@ def perform(self, node, inp, out):
         z[0] = np.asarray(np.dot(x, y))
 
     def grad(self, inp, grads):
-
         x, y = inp
         (gz,) = grads
         xdim, ydim, gdim = x.type.ndim, y.type.ndim, gz.type.ndim
@@ -2631,7 +2624,6 @@ def L_op(self, inp, out, grads):
             # this handles inputs with zeros, but only certain input shapes
             return [grad_case_without_zeros]
         else:
-
             where_zeros = eq(prod_in, 0.0)
             sum_where_zeros = sum(where_zeros, axis=self.axis)
             groups_with_single_zero = eq(sum_where_zeros, 1).dimshuffle(new_dims)
@@ -2924,7 +2916,6 @@ def _get_output_shape(cls, x1, x2, shapes, validate=False):
                 )
             return x2_shape[:-2] + x1_shape[-2:-1] + x2_shape[-1:]
         else:
-
             if validate:
                 from pytensor.tensor.random.basic import broadcast_shapes
 
diff --git a/tests/tensor/test_math.py b/tests/tensor/test_math.py
index 414ea86ab4..3fc8756506 100644
--- a/tests/tensor/test_math.py
+++ b/tests/tensor/test_math.py
@@ -771,10 +771,9 @@ def test_basic_1(self):
         v = eval_outputs(max_and_argmax(n)[0].shape)
         assert len(v) == 0
 
-    def test_basic_2(self):
-        data = random(2, 3)
-        n = as_tensor_variable(data)
-        for (axis, np_axis) in [
+    @pytest.mark.parametrize(
+        "axis,np_axis",
+        [
             (-1, -1),
             (0, 0),
             (1, 1),
@@ -783,19 +782,28 @@ def test_basic_2(self):
             ([1, 0], None),
             (NoneConst.clone(), None),
             (constant(0), 0),
-        ]:
-            v, i = eval_outputs(max_and_argmax(n, axis))
-            assert i.dtype == "int64"
-            assert np.all(v == np.max(data, np_axis))
-            assert np.all(i == np.argmax(data, np_axis))
-            v_shape = eval_outputs(max_and_argmax(n, axis)[0].shape)
-            assert tuple(v_shape) == np.max(data, np_axis).shape
-
-    def test_basic_2_float16(self):
-        # Test negative values and bigger range to make sure numpy don't do the argmax as on uint16
-        data = (random(20, 30).astype("float16") - 0.5) * 20
-        n = shared(data)
-        for (axis, np_axis) in [
+        ],
+    )
+    def test_basic_2(self, axis, np_axis):
+        data = random(2, 3)
+        n = as_tensor_variable(data)
+        # Test shape propagates (static & eval)
+        vt, it = max_and_argmax(n, axis)
+        np_max, np_argm = np.max(data, np_axis), np.argmax(data, np_axis)
+        assert vt.type.shape == np_max.shape
+        assert it.type.shape == np_argm.shape
+        v_shape, i_shape = eval_outputs([vt.shape, it.shape])
+        assert tuple(v_shape) == vt.type.shape
+        assert tuple(i_shape) == it.type.shape
+        # Test values
+        v, i = eval_outputs([vt, it])
+        assert i.dtype == "int64"
+        assert np.all(v == np_max)
+        assert np.all(i == np_argm)
+
+    @pytest.mark.parametrize(
+        "axis,np_axis",
+        [
             (-1, -1),
             (0, 0),
             (1, 1),
@@ -804,13 +812,25 @@ def test_basic_2_float16(self):
             ([1, 0], None),
             (NoneConst.clone(), None),
             (constant(0), 0),
-        ]:
-            v, i = eval_outputs(max_and_argmax(n, axis), (MaxAndArgmax,))
-            assert i.dtype == "int64"
-            assert np.all(v == np.max(data, np_axis))
-            assert np.all(i == np.argmax(data, np_axis))
-            v_shape = eval_outputs(max_and_argmax(n, axis)[0].shape)
-            assert tuple(v_shape) == np.max(data, np_axis).shape
+        ],
+    )
+    def test_basic_2_float16(self, axis, np_axis):
+        # Test negative values and bigger range to make sure numpy don't do the argmax as on uint16
+        data = (random(20, 30).astype("float16") - 0.5) * 20
+        n = as_tensor_variable(data)
+        # Test shape propagates (static & eval)
+        vt, it = max_and_argmax(n, axis)
+        np_max, np_argm = np.max(data, np_axis), np.argmax(data, np_axis)
+        assert vt.type.shape == np_max.shape
+        assert it.type.shape == np_argm.shape
+        v_shape, i_shape = eval_outputs([vt.shape, it.shape])
+        assert tuple(v_shape) == vt.type.shape
+        assert tuple(i_shape) == it.type.shape
+        # Test values
+        v, i = eval_outputs([vt, it])
+        assert i.dtype == "int64"
+        assert np.all(v == np_max)
+        assert np.all(i == np_argm)
 
     def test_basic_2_invalid(self):
         n = as_tensor_variable(random(2, 3))
@@ -840,23 +860,33 @@ def test_basic_2_valid_neg(self):
         v = eval_outputs(max_and_argmax(n, -2)[0].shape)
         assert v == (3)
 
-    def test_basic_3(self):
-        data = random(2, 3, 4)
-        n = as_tensor_variable(data)
-        for (axis, np_axis) in [
+    @pytest.mark.parametrize(
+        "axis,np_axis",
+        [
             (-1, -1),
             (0, 0),
             (1, 1),
             (None, None),
             ([0, 1, 2], None),
             ([1, 2, 0], None),
-        ]:
-            v, i = eval_outputs(max_and_argmax(n, axis))
-            assert i.dtype == "int64"
-            assert np.all(v == np.max(data, np_axis))
-            assert np.all(i == np.argmax(data, np_axis))
-            v = eval_outputs(max_and_argmax(n, axis)[0].shape)
-            assert tuple(v) == np.max(data, np_axis).shape
+        ],
+    )
+    def test_basic_3(self, axis, np_axis):
+        data = random(2, 3, 4)
+        n = as_tensor_variable(data)
+        # Test shape propagates (static & eval)
+        vt, it = max_and_argmax(n, axis)
+        np_max, np_argm = np.max(data, np_axis), np.argmax(data, np_axis)
+        assert vt.type.shape == np_max.shape
+        assert it.type.shape == np_argm.shape
+        v_shape, i_shape = eval_outputs([vt.shape, it.shape])
+        assert tuple(v_shape) == vt.type.shape
+        assert tuple(i_shape) == it.type.shape
+        # Test values
+        v, i = eval_outputs([vt, it])
+        assert i.dtype == "int64"
+        assert np.all(v == np_max)
+        assert np.all(i == np_argm)
 
     def test_arg_grad(self):
         # The test checks that the gradient of argmax(x).sum() is 0
@@ -948,17 +978,19 @@ def test_preserve_broadcastable(self):
         # Ensure the original broadcastable flags are preserved by Max/Argmax.
         x = matrix().dimshuffle("x", 0, "x", 1, "x")
         y = x.max(axis=1)
+        assert y.type.shape == (1, 1, None, 1)
         assert y.type.broadcastable == (True, True, False, True)
 
     def test_multiple_axes(self):
         data = np.arange(24).reshape(3, 2, 4)
         x = as_tensor_variable(data)
-        v, i = eval_outputs(max_and_argmax(x, [1, -1]))
+        vt, it = max_and_argmax(x, [1, -1])
+        assert vt.type.shape == it.type.shape == (3,)
+        v, i = eval_outputs([vt, it])
         assert np.all(v == np.array([7, 15, 23]))
         assert np.all(i == np.array([7, 7, 7]))
-
-        v = eval_outputs(max_and_argmax(x, [1, -1])[0].shape)
-        assert tuple(v) == np.max(data, (1, -1)).shape
+        v = eval_outputs(vt.shape)
+        assert tuple(v) == vt.type.shape
 
     def test_zero_shape(self):
         x = matrix()
@@ -972,8 +1004,8 @@ def test_zero_shape(self):
     def test_numpy_input(self):
         ar = np.array([1, 2, 3])
         max_at, argmax_at = max_and_argmax(ar, axis=None)
-        assert max_at.eval(), 3
-        assert argmax_at.eval(), 2
+        assert max_at.eval() == 3
+        assert argmax_at.eval() == 2
 
 
 class TestArgminArgmax:
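Note (not part of the patch): a minimal sketch of the static-shape propagation the make_node change above is meant to provide. It assumes the tensor constructor and max_and_argmax are importable from pytensor.tensor.type and pytensor.tensor.math as they are used in the files touched here; the variable names and shapes are illustrative only.

    from pytensor.tensor.type import tensor
    from pytensor.tensor.math import max_and_argmax

    # An input with a mix of known and unknown (None) static dimensions.
    x = tensor(dtype="float64", shape=(2, None, 4), name="x")

    # With out_shape now copying the non-reduced entries of x.type.shape,
    # the known sizes 2 and 4 survive the reduction over axis=1 instead of
    # collapsing to None whenever they differ from 1.
    v, i = max_and_argmax(x, axis=1)
    assert v.type.shape == (2, 4)
    assert i.type.shape == (2, 4)
    assert i.type.dtype == "int64"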