From faa78bbb5522086b713f2f681d9532465e03bef2 Mon Sep 17 00:00:00 2001 From: HarshvirSandhu Date: Sun, 3 Mar 2024 13:02:25 +0530 Subject: [PATCH 1/2] Add JAX support for SortOp --- pytensor/link/jax/dispatch/tensor_basic.py | 9 +++++++++ tests/link/jax/test_tensor_basic.py | 8 ++++++++ 2 files changed, 17 insertions(+) diff --git a/pytensor/link/jax/dispatch/tensor_basic.py b/pytensor/link/jax/dispatch/tensor_basic.py index bf1a93ce5b..1379badb75 100644 --- a/pytensor/link/jax/dispatch/tensor_basic.py +++ b/pytensor/link/jax/dispatch/tensor_basic.py @@ -22,6 +22,7 @@ ) from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.shape import Shape_i +from pytensor.tensor.sort import SortOp ARANGE_CONCRETE_VALUE_ERROR = """JAX requires the arguments of `jax.numpy.arange` @@ -205,3 +206,11 @@ def tri(*args): return jnp.tri(*args, dtype=op.dtype) return tri + + +@jax_funcify.register(SortOp) +def jax_funcify_Sort(op, **kwargs): + def sort(arr, *args): + return jnp.sort(arr, *args) + + return sort diff --git a/tests/link/jax/test_tensor_basic.py b/tests/link/jax/test_tensor_basic.py index 94755ddf2c..e9bee312cd 100644 --- a/tests/link/jax/test_tensor_basic.py +++ b/tests/link/jax/test_tensor_basic.py @@ -218,6 +218,14 @@ def test_tri(): compare_jax_and_py(fgraph, []) +def test_sort(): + x = matrix("x") + out = pytensor.tensor.sort(x) + fgraph = FunctionGraph([x], [out]) + arr = np.array([[1.0, 4.0], [5.0, 2.0]]) + compare_jax_and_py(fgraph, [arr]) + + def test_tri_nonconcrete(): """JAX cannot JIT-compile `jax.numpy.tri` when arguments are not concrete values.""" From 1ef52384a67cc04a9d090bfc8661dfc134c24091 Mon Sep 17 00:00:00 2001 From: HarshvirSandhu Date: Mon, 4 Mar 2024 22:16:25 +0530 Subject: [PATCH 2/2] Increase readability --- pytensor/link/jax/dispatch/tensor_basic.py | 4 ++-- tests/link/jax/test_tensor_basic.py | 7 ++++--- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/pytensor/link/jax/dispatch/tensor_basic.py b/pytensor/link/jax/dispatch/tensor_basic.py index 1379badb75..0807006d06 100644 --- a/pytensor/link/jax/dispatch/tensor_basic.py +++ b/pytensor/link/jax/dispatch/tensor_basic.py @@ -210,7 +210,7 @@ def tri(*args): @jax_funcify.register(SortOp) def jax_funcify_Sort(op, **kwargs): - def sort(arr, *args): - return jnp.sort(arr, *args) + def sort(arr, axis): + return jnp.sort(arr, axis=axis) return sort diff --git a/tests/link/jax/test_tensor_basic.py b/tests/link/jax/test_tensor_basic.py index e9bee312cd..2ab6a0e11b 100644 --- a/tests/link/jax/test_tensor_basic.py +++ b/tests/link/jax/test_tensor_basic.py @@ -218,9 +218,10 @@ def test_tri(): compare_jax_and_py(fgraph, []) -def test_sort(): - x = matrix("x") - out = pytensor.tensor.sort(x) +@pytest.mark.parametrize("axis", [None, -1]) +def test_sort(axis): + x = matrix("x", shape=(2, 2), dtype="float64") + out = pytensor.tensor.sort(x, axis=axis) fgraph = FunctionGraph([x], [out]) arr = np.array([[1.0, 4.0], [5.0, 2.0]]) compare_jax_and_py(fgraph, [arr])