diff --git a/pytensor/link/jax/dispatch/tensor_basic.py b/pytensor/link/jax/dispatch/tensor_basic.py index bf1a93ce5b..0807006d06 100644 --- a/pytensor/link/jax/dispatch/tensor_basic.py +++ b/pytensor/link/jax/dispatch/tensor_basic.py @@ -22,6 +22,7 @@ ) from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.shape import Shape_i +from pytensor.tensor.sort import SortOp ARANGE_CONCRETE_VALUE_ERROR = """JAX requires the arguments of `jax.numpy.arange` @@ -205,3 +206,11 @@ def tri(*args): return jnp.tri(*args, dtype=op.dtype) return tri + + +@jax_funcify.register(SortOp) +def jax_funcify_Sort(op, **kwargs): + def sort(arr, axis): + return jnp.sort(arr, axis=axis) + + return sort diff --git a/tests/link/jax/test_tensor_basic.py b/tests/link/jax/test_tensor_basic.py index 94755ddf2c..2ab6a0e11b 100644 --- a/tests/link/jax/test_tensor_basic.py +++ b/tests/link/jax/test_tensor_basic.py @@ -218,6 +218,15 @@ def test_tri(): compare_jax_and_py(fgraph, []) +@pytest.mark.parametrize("axis", [None, -1]) +def test_sort(axis): + x = matrix("x", shape=(2, 2), dtype="float64") + out = pytensor.tensor.sort(x, axis=axis) + fgraph = FunctionGraph([x], [out]) + arr = np.array([[1.0, 4.0], [5.0, 2.0]]) + compare_jax_and_py(fgraph, [arr]) + + def test_tri_nonconcrete(): """JAX cannot JIT-compile `jax.numpy.tri` when arguments are not concrete values."""