From 2858e6935b995f8ac900af3b9254f804666b36f4 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Tue, 3 Oct 2023 16:40:41 +0100 Subject: [PATCH 1/2] Fix TensorVariable __rmatmul__ --- pytensor/tensor/variable.py | 2 +- tests/tensor/test_variable.py | 9 +++++++-- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/pytensor/tensor/variable.py b/pytensor/tensor/variable.py index 0f41e1c1d2..57804a204a 100644 --- a/pytensor/tensor/variable.py +++ b/pytensor/tensor/variable.py @@ -652,7 +652,7 @@ def __matmul__(left, right): return at.math.matmul(left, right) def __rmatmul__(right, left): - return at.math.matmul(right, left) + return at.math.matmul(left, right) def sum(self, axis=None, dtype=None, keepdims=False, acc_dtype=None): """See :func:`pytensor.tensor.math.sum`.""" diff --git a/tests/tensor/test_variable.py b/tests/tensor/test_variable.py index 4d0d6b46d6..d1692e5ba2 100644 --- a/tests/tensor/test_variable.py +++ b/tests/tensor/test_variable.py @@ -10,7 +10,7 @@ from pytensor.compile.mode import get_default_mode from pytensor.graph.basic import Constant, equal_computations from pytensor.tensor import get_vector_length -from pytensor.tensor.basic import as_tensor, constant +from pytensor.tensor.basic import constant from pytensor.tensor.elemwise import DimShuffle from pytensor.tensor.math import dot, eq, matmul from pytensor.tensor.shape import Shape @@ -98,10 +98,15 @@ def test_infix_matmul_method(): assert equal_computations([res], [exp_res]) X_val = np.arange(2 * 3).reshape((2, 3)) - res = as_tensor(X_val) @ y + res = X_val @ y exp_res = matmul(X_val, y) assert equal_computations([res], [exp_res]) + y_val = np.arange(3) + res = X @ y_val + exp_res = matmul(X, y_val) + assert equal_computations([res], [exp_res]) + def test_empty_list_indexing(): ynp = np.zeros((2, 2))[:, []] From 686725c709339ca874028c011751fd7eb2453285 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Tue, 3 Oct 2023 16:41:15 +0100 Subject: [PATCH 2/2] Add explicit check for failing ndarray.dot(TensorVariable) --- tests/tensor/test_variable.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/tests/tensor/test_variable.py b/tests/tensor/test_variable.py index d1692e5ba2..b43cb2c4e4 100644 --- a/tests/tensor/test_variable.py +++ b/tests/tensor/test_variable.py @@ -75,7 +75,7 @@ def test_numpy_method(fct, value): utt.assert_allclose(np.nan_to_num(f(value)), np.nan_to_num(fct(value))) -def test_infix_dot_method(): +def test_dot_method(): X = dmatrix("X") y = dvector("y") @@ -83,10 +83,12 @@ def test_infix_dot_method(): exp_res = dot(X, y) assert equal_computations([res], [exp_res]) + # This doesn't work. Numpy calls TensorVariable.__rmul__ at some point and everything is messed up X_val = np.arange(2 * 3).reshape((2, 3)) - res = as_tensor(X_val).dot(y) + res = X_val.dot(y) exp_res = dot(X_val, y) - assert equal_computations([res], [exp_res]) + with pytest.raises(AssertionError): + assert equal_computations([res], [exp_res]) def test_infix_matmul_method():