We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent d57c671 commit fd39ebbCopy full SHA for fd39ebb
array_api_compat/common/_helpers.py
@@ -377,9 +377,13 @@ def your_function(x, y):
377
elif use_compat is False:
378
import jax.numpy as jnp
379
else:
380
- # jax.experimental.array_api is already an array namespace. We do
381
- # not have a wrapper submodule for it.
382
- import jax.experimental.array_api as jnp
+ # JAX v0.4.32 and newer implements the array API directly.
+ # For older JAX versions, it is available via jax.experimental.array_api.
+ import jax.numpy
383
+ if hasattr(jax.numpy, "__array_api_version__"):
384
+ jnp = jax.numpy
385
+ else:
386
+ import jax.experimental.array_api as jnp
387
namespaces.add(jnp)
388
elif is_pydata_sparse_array(x):
389
if use_compat is True:
0 commit comments