From 81e0ea8b4018e0a4ba6002a268938daa4c8db5cc Mon Sep 17 00:00:00 2001 From: theorashid Date: Thu, 6 Jun 2024 11:41:59 +0100 Subject: [PATCH 1/2] Refactor np.linalg.slogdet to use np.array instead of z.astype for dtype consistency --- pytensor/tensor/nlinalg.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pytensor/tensor/nlinalg.py b/pytensor/tensor/nlinalg.py index d46f576a71..30a3c8d5c8 100644 --- a/pytensor/tensor/nlinalg.py +++ b/pytensor/tensor/nlinalg.py @@ -246,7 +246,7 @@ def perform(self, node, inputs, outputs): (x,) = inputs (sign, det) = outputs try: - sign[0], det[0] = (z.astype(x.dtype) for z in np.linalg.slogdet(x)) + sign[0], det[0] = (np.array(z, dtype=x.dtype) for z in np.linalg.slogdet(x)) except Exception: print("Failed to compute determinant", x) raise @@ -1186,6 +1186,7 @@ def kron(a, b): "lstsq", "matrix_power", "norm", + "slogdet", "tensorinv", "tensorsolve", "kron", From c45485f33c44c2e29926680f2711d7043dfeb60c Mon Sep 17 00:00:00 2001 From: theorashid Date: Thu, 6 Jun 2024 12:30:35 +0100 Subject: [PATCH 2/2] add a test on the breaking case for slogdet types --- tests/tensor/test_nlinalg.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/tensor/test_nlinalg.py b/tests/tensor/test_nlinalg.py index bf4bb1a904..8a0c6c6596 100644 --- a/tests/tensor/test_nlinalg.py +++ b/tests/tensor/test_nlinalg.py @@ -387,6 +387,11 @@ def test_slogdet(): sign, det = np.linalg.slogdet(r) assert np.equal(sign, f_sign) assert np.allclose(det, f_det) + # check numpy array types is returned + # see https://github.com/pymc-devs/pytensor/issues/799 + sign, logdet = slogdet(x) + det = sign * pytensor.tensor.exp(logdet) + assert_array_almost_equal(det.eval({x: [[1]]}), np.array(1.0)) def test_trace():