From 2fa9338aac81ca8bca5a54322dfb91903807a51f Mon Sep 17 00:00:00 2001 From: HangenYuu Date: Fri, 28 Jun 2024 15:38:49 +0800 Subject: [PATCH 1/2] Implemented JAX backend for Eigvalsh --- pytensor/link/jax/dispatch/slinalg.py | 25 ++++++++++++++++++++- tests/link/jax/test_slinalg.py | 31 +++++++++++++++++++++++++++ 2 files changed, 55 insertions(+), 1 deletion(-) diff --git a/pytensor/link/jax/dispatch/slinalg.py b/pytensor/link/jax/dispatch/slinalg.py index 73ddadc2a0..c3b0b1200d 100644 --- a/pytensor/link/jax/dispatch/slinalg.py +++ b/pytensor/link/jax/dispatch/slinalg.py @@ -1,7 +1,30 @@ import jax from pytensor.link.jax.dispatch.basic import jax_funcify -from pytensor.tensor.slinalg import BlockDiagonal, Cholesky, Solve, SolveTriangular +from pytensor.tensor.slinalg import ( + BlockDiagonal, + Cholesky, + Eigvalsh, + Solve, + SolveTriangular, +) + + +@jax_funcify.register(Eigvalsh) +def jax_funcify_Eigvalsh(op, **kwargs): + if op.lower: + UPLO = "L" + else: + UPLO = "U" + + def eigvalsh(a, b): + if b is not None: + raise NotImplementedError( + "jax.numpy.linalg.eigvalsh does not support computing eigenvalues with additional symmetric matrix." + ) + return jax.numpy.linalg.eigvalsh(a, UPLO=UPLO) + + return eigvalsh @jax_funcify.register(Cholesky) diff --git a/tests/link/jax/test_slinalg.py b/tests/link/jax/test_slinalg.py index 53e154facc..827666d37f 100644 --- a/tests/link/jax/test_slinalg.py +++ b/tests/link/jax/test_slinalg.py @@ -163,3 +163,34 @@ def test_jax_block_diag_blockwise(): np.random.normal(size=(5, 3, 3)).astype(config.floatX), ], ) + + +@pytest.mark.parametrize("lower", [False, True]) +def test_jax_eigvalsh(lower): + A = matrix("A") + B = matrix("B") + + out = pt_slinalg.eigvalsh(A, B, lower=lower) + out_fg = FunctionGraph([A, B], [out]) + + with pytest.raises(NotImplementedError): + compare_jax_and_py( + out_fg, + [ + np.array( + [[6, 3, 1, 5], [3, 0, 5, 1], [1, 5, 6, 2], [5, 1, 2, 2]] + ).astype(config.floatX), + np.array( + [[10, 0, 1, 3], [0, 12, 7, 8], [1, 7, 14, 2], [3, 8, 2, 16]] + ).astype(config.floatX), + ], + ) + compare_jax_and_py( + out_fg, + [ + np.array([[6, 3, 1, 5], [3, 0, 5, 1], [1, 5, 6, 2], [5, 1, 2, 2]]).astype( + config.floatX + ), + None, + ], + ) From ba2d5589ea90be520a713afe3270c3d18f4c98c3 Mon Sep 17 00:00:00 2001 From: Pham Nguyen Hung <97870091+HangenYuu@users.noreply.github.com> Date: Fri, 28 Jun 2024 17:27:32 +0800 Subject: [PATCH 2/2] Update pytensor/link/jax/dispatch/slinalg.py Co-authored-by: Jesse Grabowski <48652735+jessegrabowski@users.noreply.github.com> --- pytensor/link/jax/dispatch/slinalg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytensor/link/jax/dispatch/slinalg.py b/pytensor/link/jax/dispatch/slinalg.py index c3b0b1200d..ca362e4531 100644 --- a/pytensor/link/jax/dispatch/slinalg.py +++ b/pytensor/link/jax/dispatch/slinalg.py @@ -20,7 +20,7 @@ def jax_funcify_Eigvalsh(op, **kwargs): def eigvalsh(a, b): if b is not None: raise NotImplementedError( - "jax.numpy.linalg.eigvalsh does not support computing eigenvalues with additional symmetric matrix." + "jax.numpy.linalg.eigvalsh does not support generalized eigenvector problems (b != None)" ) return jax.numpy.linalg.eigvalsh(a, UPLO=UPLO)