Skip to content

Commit 4b6a444

Browse files
authored
Fix bug in nlinalg.slogdet and expose it in linalg module (#807)
1 parent 086323f commit 4b6a444

File tree

2 files changed

+7
-1
lines changed

2 files changed

+7
-1
lines changed

pytensor/tensor/nlinalg.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,7 @@ def perform(self, node, inputs, outputs):
246246
(x,) = inputs
247247
(sign, det) = outputs
248248
try:
249-
sign[0], det[0] = (z.astype(x.dtype) for z in np.linalg.slogdet(x))
249+
sign[0], det[0] = (np.array(z, dtype=x.dtype) for z in np.linalg.slogdet(x))
250250
except Exception:
251251
print("Failed to compute determinant", x)
252252
raise
@@ -1186,6 +1186,7 @@ def kron(a, b):
11861186
"lstsq",
11871187
"matrix_power",
11881188
"norm",
1189+
"slogdet",
11891190
"tensorinv",
11901191
"tensorsolve",
11911192
"kron",

tests/tensor/test_nlinalg.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -387,6 +387,11 @@ def test_slogdet():
387387
sign, det = np.linalg.slogdet(r)
388388
assert np.equal(sign, f_sign)
389389
assert np.allclose(det, f_det)
390+
# check numpy array types is returned
391+
# see https://github.com/pymc-devs/pytensor/issues/799
392+
sign, logdet = slogdet(x)
393+
det = sign * pytensor.tensor.exp(logdet)
394+
assert_array_almost_equal(det.eval({x: [[1]]}), np.array(1.0))
390395

391396

392397
def test_trace():

0 commit comments

Comments
 (0)