diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index 93a50d87..97abe076 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -377,9 +377,13 @@ def your_function(x, y): elif use_compat is False: import jax.numpy as jnp else: - # jax.experimental.array_api is already an array namespace. We do - # not have a wrapper submodule for it. - import jax.experimental.array_api as jnp + # JAX v0.4.32 and newer implements the array API directly in jax.numpy. + # For older JAX versions, it is available via jax.experimental.array_api. + import jax.numpy + if hasattr(jax.numpy, "__array_api_version__"): + jnp = jax.numpy + else: + import jax.experimental.array_api as jnp namespaces.add(jnp) elif is_pydata_sparse_array(x): if use_compat is True: @@ -613,8 +617,9 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]] return x raise ValueError(f"Unsupported device {device!r}") elif is_jax_array(x): - # This import adds to_device to x - import jax.experimental.array_api # noqa: F401 + if not hasattr(x, "__array_namespace__"): + # In JAX v0.4.31 and older, this import adds to_device method to x. + import jax.experimental.array_api # noqa: F401 return x.to_device(device, stream=stream) elif is_pydata_sparse_array(x) and device == _device(x): # Perform trivial check to return the same array if diff --git a/docs/index.md b/docs/index.md index a78e9314..b268e61a 100644 --- a/docs/index.md +++ b/docs/index.md @@ -63,7 +63,8 @@ import array_api_compat.dask as da ```{note} There are no `array_api_compat` submodules for JAX, sparse, or ndonnx. These support for these libraries is contained in the libraries themselves (JAX -support is in the `jax.experimental.array_api` module). The +support is in the `jax.numpy` module in JAX v0.4.32 or newer, and in the +`jax.experimental.array_api` module for older JAX versions). The array-api-compat support for these libraries consists of supporting them in the [helper functions](helper-functions). ``` diff --git a/tests/_helpers.py b/tests/_helpers.py index e52c205d..5f8cd74f 100644 --- a/tests/_helpers.py +++ b/tests/_helpers.py @@ -18,7 +18,11 @@ def import_(library, wrapper=False): pytest.importorskip(library) if wrapper: if 'jax' in library: - library = 'jax.experimental.array_api' + # JAX v0.4.32 implements the array API directly in jax.numpy + # Older jax versions use jax.experimental.array_api + jax_numpy = import_module("jax.numpy") + if not hasattr(jax_numpy, "__array_api_version__"): + library = 'jax.experimental.array_api' elif library.startswith('sparse'): library = 'sparse' else: diff --git a/tests/test_array_namespace.py b/tests/test_array_namespace.py index 8707b05a..e35e31e1 100644 --- a/tests/test_array_namespace.py +++ b/tests/test_array_namespace.py @@ -26,8 +26,14 @@ def test_array_namespace(library, api_version, use_compat): if use_compat is False or use_compat is None and library not in wrapped_libraries: if library == "jax.numpy" and use_compat is None: - import jax.experimental.array_api - assert namespace == jax.experimental.array_api + import jax.numpy + if hasattr(jax.numpy, "__array_api_version__"): + # JAX v0.4.32 or later uses jax.numpy directly + assert namespace == jax.numpy + else: + # JAX v0.4.31 or earlier uses jax.experimental.array_api + import jax.experimental.array_api + assert namespace == jax.experimental.array_api else: assert namespace == xp else: @@ -58,8 +64,11 @@ def test_array_namespace(library, api_version, use_compat): assert 'jax.experimental.array_api' not in sys.modules namespace = array_api_compat.array_namespace(array, api_version={api_version!r}) -import jax.experimental.array_api -assert namespace == jax.experimental.array_api +if hasattr(jax.numpy, '__array_api_version__'): + assert namespace == jax.numpy +else: + import jax.experimental.array_api + assert namespace == jax.experimental.array_api """ subprocess.run([sys.executable, "-c", code], check=True)