Skip to content

Commit fd39ebb

Browse files
committed
Future-proof JAX array API import
1 parent d57c671 commit fd39ebb

File tree

1 file changed

+7
-3
lines changed

1 file changed

+7
-3
lines changed

array_api_compat/common/_helpers.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -377,9 +377,13 @@ def your_function(x, y):
377377
elif use_compat is False:
378378
import jax.numpy as jnp
379379
else:
380-
# jax.experimental.array_api is already an array namespace. We do
381-
# not have a wrapper submodule for it.
382-
import jax.experimental.array_api as jnp
380+
# JAX v0.4.32 and newer implements the array API directly.
381+
# For older JAX versions, it is available via jax.experimental.array_api.
382+
import jax.numpy
383+
if hasattr(jax.numpy, "__array_api_version__"):
384+
jnp = jax.numpy
385+
else:
386+
import jax.experimental.array_api as jnp
383387
namespaces.add(jnp)
384388
elif is_pydata_sparse_array(x):
385389
if use_compat is True:

0 commit comments

Comments
 (0)