diff --git a/pytensor/tensor/nlinalg.py b/pytensor/tensor/nlinalg.py index 30a3c8d5c8..6db6ae2638 100644 --- a/pytensor/tensor/nlinalg.py +++ b/pytensor/tensor/nlinalg.py @@ -198,7 +198,15 @@ class Det(Op): def make_node(self, x): x = as_tensor_variable(x) - assert x.ndim == 2 + if x.ndim != 2: + raise ValueError( + f"Input passed is not a valid 2D matrix. Current ndim {x.ndim} != 2" + ) + # Check for known shapes and square matrix + if None not in x.type.shape and (x.type.shape[0] != x.type.shape[1]): + raise ValueError( + f"Determinant not defined for non-square matrix inputs. Shape received is {x.type.shape}" + ) o = scalar(dtype=x.dtype) return Apply(self, [x], [o]) diff --git a/tests/tensor/test_nlinalg.py b/tests/tensor/test_nlinalg.py index 8a0c6c6596..1a13992011 100644 --- a/tests/tensor/test_nlinalg.py +++ b/tests/tensor/test_nlinalg.py @@ -365,6 +365,11 @@ def test_det(): assert np.allclose(np.linalg.det(r), f(r)) +def test_det_non_square_raises(): + with pytest.raises(ValueError, match="Determinant not defined"): + det(tensor("x", shape=(5, 7))) + + def test_det_grad(): rng = np.random.default_rng(utt.fetch_seed())