Skip to content

Commit ce07cd9

Browse files
committed
Use jax.Array instead of jax.numpy.ndarray in is_jax_array
1 parent 6d59ae8 commit ce07cd9

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

array_api_compat/common/_helpers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,9 +55,9 @@ def is_jax_array(x):
5555
if 'jax' not in sys.modules:
5656
return False
5757

58-
import jax.numpy
58+
import jax
5959

60-
return isinstance(x, jax.numpy.ndarray)
60+
return isinstance(x, jax.Array)
6161

6262
def is_array_api_obj(x):
6363
"""

0 commit comments

Comments
 (0)