From fdaaafbf5927ed2f88d6759f9c6fa670931fe723 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mateusz=20Sok=C3=B3=C5=82?= Date: Wed, 4 Jan 2023 20:37:44 +0100 Subject: [PATCH 1/3] Add slogdet for Numba --- pytensor/link/numba/dispatch/nlinalg.py | 20 ++++++++++++++ pytensor/tensor/nlinalg.py | 33 +++++++++++++++++++++++ tests/link/numba/test_nlinalg.py | 35 +++++++++++++++++++++++++ tests/tensor/test_nlinalg.py | 21 +++++++++++++++ 4 files changed, 109 insertions(+) diff --git a/pytensor/link/numba/dispatch/nlinalg.py b/pytensor/link/numba/dispatch/nlinalg.py index 21fa34e1bb..174c684d06 100644 --- a/pytensor/link/numba/dispatch/nlinalg.py +++ b/pytensor/link/numba/dispatch/nlinalg.py @@ -18,6 +18,7 @@ MatrixInverse, MatrixPinv, QRFull, + SLogDet, ) @@ -58,6 +59,25 @@ def det(x): return det +@numba_funcify.register(SLogDet) +def numba_funcify_SLogDet(op, node, **kwargs): + + out_dtype_1 = node.outputs[0].type.numpy_dtype + out_dtype_2 = node.outputs[1].type.numpy_dtype + + inputs_cast = int_to_float_fn(node.inputs, out_dtype_1) + + @numba_basic.numba_njit + def slogdet(x): + sign, det = np.linalg.slogdet(inputs_cast(x)) + return ( + numba_basic.direct_cast(sign, out_dtype_1), + numba_basic.direct_cast(det, out_dtype_2), + ) + + return slogdet + + @numba_funcify.register(Eig) def numba_funcify_Eig(op, node, **kwargs): diff --git a/pytensor/tensor/nlinalg.py b/pytensor/tensor/nlinalg.py index 799f0fcf54..7da8bb044d 100644 --- a/pytensor/tensor/nlinalg.py +++ b/pytensor/tensor/nlinalg.py @@ -231,6 +231,39 @@ def __str__(self): det = Det() +class SLogDet(Op): + """ + Compute the log determinant and its sign of the matrix. Input should be a square matrix. + """ + + __props__ = () + + def make_node(self, x): + x = as_tensor_variable(x) + assert x.ndim == 2 + sign = scalar(dtype=x.dtype) + det = scalar(dtype=x.dtype) + return Apply(self, [x], [sign, det]) + + def perform(self, node, inputs, outputs): + (x,) = inputs + (sign, det) = outputs + try: + sign[0], det[0] = (z.astype(x.dtype) for z in np.linalg.slogdet(x)) + except Exception: + print("Failed to compute determinant", x) + raise + + def infer_shape(self, fgraph, node, shapes): + return [(), ()] + + def __str__(self): + return "SLogDet" + + +slogdet = SLogDet() + + class Eig(Op): """ Compute the eigenvalues and right eigenvectors of a square array. diff --git a/tests/link/numba/test_nlinalg.py b/tests/link/numba/test_nlinalg.py index 3018b9a97a..7bc60d1313 100644 --- a/tests/link/numba/test_nlinalg.py +++ b/tests/link/numba/test_nlinalg.py @@ -179,6 +179,41 @@ def test_Det(x, exc): ) +@pytest.mark.parametrize( + "x, exc", + [ + ( + set_test_value( + at.dmatrix(), + (lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")), + ), + None, + ), + ( + set_test_value( + at.lmatrix(), + (lambda x: x.T.dot(x))(rng.poisson(size=(3, 3)).astype("int64")), + ), + None, + ), + ], +) +def test_SLogDet(x, exc): + g = nlinalg.SLogDet()(x) + g_fg = FunctionGraph(outputs=g) + + cm = contextlib.suppress() if exc is None else pytest.warns(exc) + with cm: + compare_numba_and_py( + g_fg, + [ + i.tag.test_value + for i in g_fg.inputs + if not isinstance(i, (SharedVariable, Constant)) + ], + ) + + # We were seeing some weird results in CI where the following two almost # sign-swapped results were being return from Numba and Python, respectively. # The issue might be related to https://github.com/numba/numba/issues/4519. diff --git a/tests/tensor/test_nlinalg.py b/tests/tensor/test_nlinalg.py index b39ac0ba3c..7ba4f875e3 100644 --- a/tests/tensor/test_nlinalg.py +++ b/tests/tensor/test_nlinalg.py @@ -24,6 +24,7 @@ norm, pinv, qr, + slogdet, svd, tensorinv, tensorsolve, @@ -280,6 +281,26 @@ def test_det_shape(): assert tuple(det_shape.data) == () +def test_slogdet(): + rng = np.random.default_rng(utt.fetch_seed()) + + r = rng.standard_normal((5, 5)).astype(config.floatX) + x = matrix() + f = pytensor.function([x], slogdet(x)) + f_sign, f_det = f(r) + sign, det = np.linalg.slogdet(r) + assert np.equal(sign, f_sign) + assert np.allclose(det, f_det) + + +def test_slogdet_shape(): + x = matrix() + sign, det = slogdet(x) + for shape in [sign.shape, det.shape]: + assert isinstance(shape, Constant) + assert tuple(shape.data) == () + + def test_trace(): rng = np.random.default_rng(utt.fetch_seed()) x = matrix() From c8a6f7faf23d766b8f7b90c7e7cef6e1a4f45d96 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mateusz=20Sok=C3=B3=C5=82?= Date: Sat, 7 Jan 2023 12:15:27 +0100 Subject: [PATCH 2/3] Add slogdet for JAX --- pytensor/link/jax/dispatch/nlinalg.py | 10 +++++++++- tests/link/jax/test_nlinalg.py | 4 ++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/pytensor/link/jax/dispatch/nlinalg.py b/pytensor/link/jax/dispatch/nlinalg.py index 68efb3e10e..26af73b137 100644 --- a/pytensor/link/jax/dispatch/nlinalg.py +++ b/pytensor/link/jax/dispatch/nlinalg.py @@ -3,7 +3,7 @@ from pytensor.link.jax.dispatch import jax_funcify from pytensor.tensor.blas import BatchedDot from pytensor.tensor.math import Dot, MaxAndArgmax -from pytensor.tensor.nlinalg import SVD, Det, Eig, Eigh, MatrixInverse, QRFull +from pytensor.tensor.nlinalg import SVD, Det, Eig, Eigh, MatrixInverse, QRFull, SLogDet @jax_funcify.register(SVD) @@ -25,6 +25,14 @@ def det(x): return det +@jax_funcify.register(SLogDet) +def jax_funcify_SLogDet(op, **kwargs): + def slogdet(x): + return jnp.linalg.slogdet(x) + + return slogdet + + @jax_funcify.register(Eig) def jax_funcify_Eig(op, **kwargs): def eig(x): diff --git a/tests/link/jax/test_nlinalg.py b/tests/link/jax/test_nlinalg.py index 50d1cf378e..146766f400 100644 --- a/tests/link/jax/test_nlinalg.py +++ b/tests/link/jax/test_nlinalg.py @@ -85,6 +85,10 @@ def assert_fn(x, y): out_fg = FunctionGraph([x], outs) compare_jax_and_py(out_fg, [X.astype(config.floatX)], assert_fn=assert_fn) + outs = at_nlinalg.slogdet(x) + out_fg = FunctionGraph([x], outs) + compare_jax_and_py(out_fg, [X.astype(config.floatX)], assert_fn=assert_fn) + @pytest.mark.xfail( version_parse(jax.__version__) >= version_parse("0.2.12"), From 79917e06f81962502783db0a9e86364993ac815c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mateusz=20Sok=C3=B3=C5=82?= Date: Sat, 7 Jan 2023 17:24:44 +0100 Subject: [PATCH 3/3] Remove test_slogdet_shape test --- tests/tensor/test_nlinalg.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/tests/tensor/test_nlinalg.py b/tests/tensor/test_nlinalg.py index 7ba4f875e3..15dbbff083 100644 --- a/tests/tensor/test_nlinalg.py +++ b/tests/tensor/test_nlinalg.py @@ -293,14 +293,6 @@ def test_slogdet(): assert np.allclose(det, f_det) -def test_slogdet_shape(): - x = matrix() - sign, det = slogdet(x) - for shape in [sign.shape, det.shape]: - assert isinstance(shape, Constant) - assert tuple(shape.data) == () - - def test_trace(): rng = np.random.default_rng(utt.fetch_seed()) x = matrix()