From fd39ebbc15567ab1174c0e8fdfb3c99611658371 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Thu, 1 Aug 2024 11:35:30 -0700 Subject: [PATCH 1/4] Future-proof JAX array API import --- array_api_compat/common/_helpers.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index 93a50d87..bdfedcaa 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -377,9 +377,13 @@ def your_function(x, y): elif use_compat is False: import jax.numpy as jnp else: - # jax.experimental.array_api is already an array namespace. We do - # not have a wrapper submodule for it. - import jax.experimental.array_api as jnp + # JAX v0.4.32 and newer implements the array API directly. + # For older JAX versions, it is available via jax.experimental.array_api. + import jax.numpy + if hasattr(jax.numpy, "__array_api_version__"): + jnp = jax.numpy + else: + import jax.experimental.array_api as jnp namespaces.add(jnp) elif is_pydata_sparse_array(x): if use_compat is True: From 1419f08bb443477232b103acd7f5a1e4ec1e7cbb Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Thu, 1 Aug 2024 13:10:43 -0700 Subject: [PATCH 2/4] Account for other locations that reference jax.experimental.array_api --- array_api_compat/common/_helpers.py | 9 +++++---- docs/index.md | 3 ++- tests/_helpers.py | 6 +++++- tests/test_array_namespace.py | 16 ++++++++++++---- 4 files changed, 24 insertions(+), 10 deletions(-) diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index bdfedcaa..269b4a6f 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -377,7 +377,7 @@ def your_function(x, y): elif use_compat is False: import jax.numpy as jnp else: - # JAX v0.4.32 and newer implements the array API directly. + # JAX v0.4.32 and newer implements the array API directly in jax.numpy. # For older JAX versions, it is available via jax.experimental.array_api. import jax.numpy if hasattr(jax.numpy, "__array_api_version__"): @@ -617,9 +617,10 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]] return x raise ValueError(f"Unsupported device {device!r}") elif is_jax_array(x): - # This import adds to_device to x - import jax.experimental.array_api # noqa: F401 - return x.to_device(device, stream=stream) + if not hasattr(x, "__array_namespace__"): + # In JAX v0.4.31 and older, this import adds to_device method to x. + import jax.experimental.array_api # noqa: F401 + return x.to_device(device, stream=stream) elif is_pydata_sparse_array(x) and device == _device(x): # Perform trivial check to return the same array if # device is same instead of err-ing. diff --git a/docs/index.md b/docs/index.md index a78e9314..b268e61a 100644 --- a/docs/index.md +++ b/docs/index.md @@ -63,7 +63,8 @@ import array_api_compat.dask as da ```{note} There are no `array_api_compat` submodules for JAX, sparse, or ndonnx. These support for these libraries is contained in the libraries themselves (JAX -support is in the `jax.experimental.array_api` module). The +support is in the `jax.numpy` module in JAX v0.4.32 or newer, and in the +`jax.experimental.array_api` module for older JAX versions). The array-api-compat support for these libraries consists of supporting them in the [helper functions](helper-functions). ``` diff --git a/tests/_helpers.py b/tests/_helpers.py index e52c205d..5f8cd74f 100644 --- a/tests/_helpers.py +++ b/tests/_helpers.py @@ -18,7 +18,11 @@ def import_(library, wrapper=False): pytest.importorskip(library) if wrapper: if 'jax' in library: - library = 'jax.experimental.array_api' + # JAX v0.4.32 implements the array API directly in jax.numpy + # Older jax versions use jax.experimental.array_api + jax_numpy = import_module("jax.numpy") + if not hasattr(jax_numpy, "__array_api_version__"): + library = 'jax.experimental.array_api' elif library.startswith('sparse'): library = 'sparse' else: diff --git a/tests/test_array_namespace.py b/tests/test_array_namespace.py index 8707b05a..a9ae8020 100644 --- a/tests/test_array_namespace.py +++ b/tests/test_array_namespace.py @@ -26,8 +26,13 @@ def test_array_namespace(library, api_version, use_compat): if use_compat is False or use_compat is None and library not in wrapped_libraries: if library == "jax.numpy" and use_compat is None: - import jax.experimental.array_api - assert namespace == jax.experimental.array_api + if hasattr(jax.numpy, "__array_api_version__"): + # JAX v0.4.32 or later uses jax.numpy directly + assert namespace == jax.numpy + else: + # JAX v0.4.31 or earlier uses jax.experimental.array_api + import jax.experimental.array_api + assert namespace == jax.experimental.array_api else: assert namespace == xp else: @@ -58,8 +63,11 @@ def test_array_namespace(library, api_version, use_compat): assert 'jax.experimental.array_api' not in sys.modules namespace = array_api_compat.array_namespace(array, api_version={api_version!r}) -import jax.experimental.array_api -assert namespace == jax.experimental.array_api +if hasattr(jax.numpy, '__array_api_version__'): + assert namespace == jax.numpy +else: + import jax.experimental.array_api + assert namespace == jax.experimental.array_api """ subprocess.run([sys.executable, "-c", code], check=True) From 096cec7203fa20b9596954e183a6f209f0bb4045 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Thu, 1 Aug 2024 13:12:08 -0700 Subject: [PATCH 3/4] fix typo --- array_api_compat/common/_helpers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index 269b4a6f..97abe076 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -620,7 +620,7 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]] if not hasattr(x, "__array_namespace__"): # In JAX v0.4.31 and older, this import adds to_device method to x. import jax.experimental.array_api # noqa: F401 - return x.to_device(device, stream=stream) + return x.to_device(device, stream=stream) elif is_pydata_sparse_array(x) and device == _device(x): # Perform trivial check to return the same array if # device is same instead of err-ing. From d570ebc7bb9c4174d97f631994528bf3a0eef787 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Thu, 1 Aug 2024 13:15:42 -0700 Subject: [PATCH 4/4] fix ruff --- tests/test_array_namespace.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_array_namespace.py b/tests/test_array_namespace.py index a9ae8020..e35e31e1 100644 --- a/tests/test_array_namespace.py +++ b/tests/test_array_namespace.py @@ -26,6 +26,7 @@ def test_array_namespace(library, api_version, use_compat): if use_compat is False or use_compat is None and library not in wrapped_libraries: if library == "jax.numpy" and use_compat is None: + import jax.numpy if hasattr(jax.numpy, "__array_api_version__"): # JAX v0.4.32 or later uses jax.numpy directly assert namespace == jax.numpy