diff --git a/pytensor/link/jax/dispatch/nlinalg.py b/pytensor/link/jax/dispatch/nlinalg.py index 68efb3e10e..26af73b137 100644 --- a/pytensor/link/jax/dispatch/nlinalg.py +++ b/pytensor/link/jax/dispatch/nlinalg.py @@ -3,7 +3,7 @@ from pytensor.link.jax.dispatch import jax_funcify from pytensor.tensor.blas import BatchedDot from pytensor.tensor.math import Dot, MaxAndArgmax -from pytensor.tensor.nlinalg import SVD, Det, Eig, Eigh, MatrixInverse, QRFull +from pytensor.tensor.nlinalg import SVD, Det, Eig, Eigh, MatrixInverse, QRFull, SLogDet @jax_funcify.register(SVD) @@ -25,6 +25,14 @@ def det(x): return det +@jax_funcify.register(SLogDet) +def jax_funcify_SLogDet(op, **kwargs): + def slogdet(x): + return jnp.linalg.slogdet(x) + + return slogdet + + @jax_funcify.register(Eig) def jax_funcify_Eig(op, **kwargs): def eig(x): diff --git a/pytensor/link/numba/dispatch/nlinalg.py b/pytensor/link/numba/dispatch/nlinalg.py index 21fa34e1bb..174c684d06 100644 --- a/pytensor/link/numba/dispatch/nlinalg.py +++ b/pytensor/link/numba/dispatch/nlinalg.py @@ -18,6 +18,7 @@ MatrixInverse, MatrixPinv, QRFull, + SLogDet, ) @@ -58,6 +59,25 @@ def det(x): return det +@numba_funcify.register(SLogDet) +def numba_funcify_SLogDet(op, node, **kwargs): + + out_dtype_1 = node.outputs[0].type.numpy_dtype + out_dtype_2 = node.outputs[1].type.numpy_dtype + + inputs_cast = int_to_float_fn(node.inputs, out_dtype_1) + + @numba_basic.numba_njit + def slogdet(x): + sign, det = np.linalg.slogdet(inputs_cast(x)) + return ( + numba_basic.direct_cast(sign, out_dtype_1), + numba_basic.direct_cast(det, out_dtype_2), + ) + + return slogdet + + @numba_funcify.register(Eig) def numba_funcify_Eig(op, node, **kwargs): diff --git a/pytensor/tensor/nlinalg.py b/pytensor/tensor/nlinalg.py index 799f0fcf54..7da8bb044d 100644 --- a/pytensor/tensor/nlinalg.py +++ b/pytensor/tensor/nlinalg.py @@ -231,6 +231,39 @@ def __str__(self): det = Det() +class SLogDet(Op): + """ + Compute the log determinant and its sign of the matrix. Input should be a square matrix. + """ + + __props__ = () + + def make_node(self, x): + x = as_tensor_variable(x) + assert x.ndim == 2 + sign = scalar(dtype=x.dtype) + det = scalar(dtype=x.dtype) + return Apply(self, [x], [sign, det]) + + def perform(self, node, inputs, outputs): + (x,) = inputs + (sign, det) = outputs + try: + sign[0], det[0] = (z.astype(x.dtype) for z in np.linalg.slogdet(x)) + except Exception: + print("Failed to compute determinant", x) + raise + + def infer_shape(self, fgraph, node, shapes): + return [(), ()] + + def __str__(self): + return "SLogDet" + + +slogdet = SLogDet() + + class Eig(Op): """ Compute the eigenvalues and right eigenvectors of a square array. diff --git a/tests/link/jax/test_nlinalg.py b/tests/link/jax/test_nlinalg.py index 50d1cf378e..146766f400 100644 --- a/tests/link/jax/test_nlinalg.py +++ b/tests/link/jax/test_nlinalg.py @@ -85,6 +85,10 @@ def assert_fn(x, y): out_fg = FunctionGraph([x], outs) compare_jax_and_py(out_fg, [X.astype(config.floatX)], assert_fn=assert_fn) + outs = at_nlinalg.slogdet(x) + out_fg = FunctionGraph([x], outs) + compare_jax_and_py(out_fg, [X.astype(config.floatX)], assert_fn=assert_fn) + @pytest.mark.xfail( version_parse(jax.__version__) >= version_parse("0.2.12"), diff --git a/tests/link/numba/test_nlinalg.py b/tests/link/numba/test_nlinalg.py index 3018b9a97a..7bc60d1313 100644 --- a/tests/link/numba/test_nlinalg.py +++ b/tests/link/numba/test_nlinalg.py @@ -179,6 +179,41 @@ def test_Det(x, exc): ) +@pytest.mark.parametrize( + "x, exc", + [ + ( + set_test_value( + at.dmatrix(), + (lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")), + ), + None, + ), + ( + set_test_value( + at.lmatrix(), + (lambda x: x.T.dot(x))(rng.poisson(size=(3, 3)).astype("int64")), + ), + None, + ), + ], +) +def test_SLogDet(x, exc): + g = nlinalg.SLogDet()(x) + g_fg = FunctionGraph(outputs=g) + + cm = contextlib.suppress() if exc is None else pytest.warns(exc) + with cm: + compare_numba_and_py( + g_fg, + [ + i.tag.test_value + for i in g_fg.inputs + if not isinstance(i, (SharedVariable, Constant)) + ], + ) + + # We were seeing some weird results in CI where the following two almost # sign-swapped results were being return from Numba and Python, respectively. # The issue might be related to https://github.com/numba/numba/issues/4519. diff --git a/tests/tensor/test_nlinalg.py b/tests/tensor/test_nlinalg.py index b39ac0ba3c..15dbbff083 100644 --- a/tests/tensor/test_nlinalg.py +++ b/tests/tensor/test_nlinalg.py @@ -24,6 +24,7 @@ norm, pinv, qr, + slogdet, svd, tensorinv, tensorsolve, @@ -280,6 +281,18 @@ def test_det_shape(): assert tuple(det_shape.data) == () +def test_slogdet(): + rng = np.random.default_rng(utt.fetch_seed()) + + r = rng.standard_normal((5, 5)).astype(config.floatX) + x = matrix() + f = pytensor.function([x], slogdet(x)) + f_sign, f_det = f(r) + sign, det = np.linalg.slogdet(r) + assert np.equal(sign, f_sign) + assert np.allclose(det, f_det) + + def test_trace(): rng = np.random.default_rng(utt.fetch_seed()) x = matrix()