From a2509d463b589e094760502765288e12c1d104d4 Mon Sep 17 00:00:00 2001 From: Tanish Taneja Date: Thu, 27 Jun 2024 17:41:02 +0530 Subject: [PATCH 1/4] added check for square matrix in make_node to Det --- pytensor/tensor/nlinalg.py | 10 +++++++++- tests/tensor/test_nlinalg.py | 5 +++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/pytensor/tensor/nlinalg.py b/pytensor/tensor/nlinalg.py index 30a3c8d5c8..969c324260 100644 --- a/pytensor/tensor/nlinalg.py +++ b/pytensor/tensor/nlinalg.py @@ -198,7 +198,15 @@ class Det(Op): def make_node(self, x): x = as_tensor_variable(x) - assert x.ndim == 2 + if x.ndim != 2: + raise ValueError() + # Check for known shapes and square matrix + if all(shape is not None for shape in x.type.shape) and ( + x.type.shape[0] != x.type.shape[1] + ): + raise ValueError( + f"Det not defined for non-square matrix inputs. Shape received is {x.type.shape}" + ) o = scalar(dtype=x.dtype) return Apply(self, [x], [o]) diff --git a/tests/tensor/test_nlinalg.py b/tests/tensor/test_nlinalg.py index 8a0c6c6596..29744505fa 100644 --- a/tests/tensor/test_nlinalg.py +++ b/tests/tensor/test_nlinalg.py @@ -365,6 +365,11 @@ def test_det(): assert np.allclose(np.linalg.det(r), f(r)) +def test_det_non_square(): + with pytest.raises(ValueError, match="Det not defined"): + det(tensor("x", shape=(5, 7))) + + def test_det_grad(): rng = np.random.default_rng(utt.fetch_seed()) From 9bcf15b58a573f8e946f82b3751ebbebeb6652fc Mon Sep 17 00:00:00 2001 From: Tanish Taneja Date: Thu, 27 Jun 2024 21:42:08 +0530 Subject: [PATCH 2/4] cleaned code --- pytensor/tensor/nlinalg.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pytensor/tensor/nlinalg.py b/pytensor/tensor/nlinalg.py index 969c324260..57a98fa301 100644 --- a/pytensor/tensor/nlinalg.py +++ b/pytensor/tensor/nlinalg.py @@ -199,11 +199,11 @@ class Det(Op): def make_node(self, x): x = as_tensor_variable(x) if x.ndim != 2: - raise ValueError() + raise ValueError( + f"Input passed is not a valid 2D matrix. Current ndim {x.ndim} != 2" + ) # Check for known shapes and square matrix - if all(shape is not None for shape in x.type.shape) and ( - x.type.shape[0] != x.type.shape[1] - ): + if None not in x.type.shape and (x.type.shape[0] != x.type.shape[1]): raise ValueError( f"Det not defined for non-square matrix inputs. Shape received is {x.type.shape}" ) From 6e1454462df117504566158e1f2d9fc19ac2cec7 Mon Sep 17 00:00:00 2001 From: Tanish Taneja Date: Fri, 28 Jun 2024 21:45:03 +0530 Subject: [PATCH 3/4] small fixes --- pytensor/tensor/nlinalg.py | 2 +- tests/tensor/test_nlinalg.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pytensor/tensor/nlinalg.py b/pytensor/tensor/nlinalg.py index 57a98fa301..6db6ae2638 100644 --- a/pytensor/tensor/nlinalg.py +++ b/pytensor/tensor/nlinalg.py @@ -205,7 +205,7 @@ def make_node(self, x): # Check for known shapes and square matrix if None not in x.type.shape and (x.type.shape[0] != x.type.shape[1]): raise ValueError( - f"Det not defined for non-square matrix inputs. Shape received is {x.type.shape}" + f"Determinant not defined for non-square matrix inputs. Shape received is {x.type.shape}" ) o = scalar(dtype=x.dtype) return Apply(self, [x], [o]) diff --git a/tests/tensor/test_nlinalg.py b/tests/tensor/test_nlinalg.py index 29744505fa..7a1480aaa8 100644 --- a/tests/tensor/test_nlinalg.py +++ b/tests/tensor/test_nlinalg.py @@ -365,7 +365,7 @@ def test_det(): assert np.allclose(np.linalg.det(r), f(r)) -def test_det_non_square(): +def test_det_non_square_raises(): with pytest.raises(ValueError, match="Det not defined"): det(tensor("x", shape=(5, 7))) From 38fb1130771b6a63a0fcce10e52a61716f084a7d Mon Sep 17 00:00:00 2001 From: Tanish Taneja Date: Fri, 28 Jun 2024 21:56:00 +0530 Subject: [PATCH 4/4] updated regex in test --- tests/tensor/test_nlinalg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/tensor/test_nlinalg.py b/tests/tensor/test_nlinalg.py index 7a1480aaa8..1a13992011 100644 --- a/tests/tensor/test_nlinalg.py +++ b/tests/tensor/test_nlinalg.py @@ -366,7 +366,7 @@ def test_det(): def test_det_non_square_raises(): - with pytest.raises(ValueError, match="Det not defined"): + with pytest.raises(ValueError, match="Determinant not defined"): det(tensor("x", shape=(5, 7)))