diff --git a/pytensor/link/jax/dispatch/slinalg.py b/pytensor/link/jax/dispatch/slinalg.py
index 73ddadc2a0..ca362e4531 100644
--- a/pytensor/link/jax/dispatch/slinalg.py
+++ b/pytensor/link/jax/dispatch/slinalg.py
@@ -1,7 +1,30 @@
 import jax
 
 from pytensor.link.jax.dispatch.basic import jax_funcify
-from pytensor.tensor.slinalg import BlockDiagonal, Cholesky, Solve, SolveTriangular
+from pytensor.tensor.slinalg import (
+    BlockDiagonal,
+    Cholesky,
+    Eigvalsh,
+    Solve,
+    SolveTriangular,
+)
+
+
+@jax_funcify.register(Eigvalsh)
+def jax_funcify_Eigvalsh(op, **kwargs):
+    if op.lower:
+        UPLO = "L"
+    else:
+        UPLO = "U"
+
+    def eigvalsh(a, b):
+        if b is not None:
+            raise NotImplementedError(
+                "jax.numpy.linalg.eigvalsh does not support generalized eigenvector problems (b != None)"
+            )
+        return jax.numpy.linalg.eigvalsh(a, UPLO=UPLO)
+
+    return eigvalsh
 
 
 @jax_funcify.register(Cholesky)
diff --git a/tests/link/jax/test_slinalg.py b/tests/link/jax/test_slinalg.py
index 53e154facc..827666d37f 100644
--- a/tests/link/jax/test_slinalg.py
+++ b/tests/link/jax/test_slinalg.py
@@ -163,3 +163,34 @@
         np.random.normal(size=(5, 3, 3)).astype(config.floatX),
     ],
 )
+
+
+@pytest.mark.parametrize("lower", [False, True])
+def test_jax_eigvalsh(lower):
+    A = matrix("A")
+    B = matrix("B")
+
+    out = pt_slinalg.eigvalsh(A, B, lower=lower)
+    out_fg = FunctionGraph([A, B], [out])
+
+    with pytest.raises(NotImplementedError):
+        compare_jax_and_py(
+            out_fg,
+            [
+                np.array(
+                    [[6, 3, 1, 5], [3, 0, 5, 1], [1, 5, 6, 2], [5, 1, 2, 2]]
+                ).astype(config.floatX),
+                np.array(
+                    [[10, 0, 1, 3], [0, 12, 7, 8], [1, 7, 14, 2], [3, 8, 2, 16]]
+                ).astype(config.floatX),
+            ],
+        )
+    compare_jax_and_py(
+        out_fg,
+        [
+            np.array([[6, 3, 1, 5], [3, 0, 5, 1], [1, 5, 6, 2], [5, 1, 2, 2]]).astype(
+                config.floatX
+            ),
+            None,
+        ],
+    )