We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 9c8bed6 commit 6d59ae8Copy full SHA for 6d59ae8
array_api_compat/common/_helpers.py
@@ -125,9 +125,8 @@ def your_function(x, y):
125
raise TypeError("_use_compat cannot be False if input array is a dask array!")
126
elif is_jax_array(x):
127
_check_api_version(api_version)
128
- # jax.numpy is already an array namespace, but requires this
129
- # side-effecting import for __array_namespace__ and some other
130
- # things to be defined.
+ # jax.experimental.array_api is already an array namespace. We do
+ # not have a wrapper submodule for it.
131
import jax.experimental.array_api as jnp
132
namespaces.add(jnp)
133
elif hasattr(x, '__array_namespace__'):
0 commit comments