Skip to content

Commit 049d557

Browse files
committed
Allow to_device(x, "cpu") for JAX arrays
1 parent 701a5ef commit 049d557

File tree

2 files changed

+8
-2
lines changed

2 files changed

+8
-2
lines changed

array_api_compat/common/_helpers.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,8 @@ def to_device(x: "Array", device: "Device", /, *, stream: "Optional[Union[int, A
272272
elif is_jax_array(x):
273273
# This import adds to_device to x
274274
import jax.experimental.array_api
275+
if device == 'cpu':
276+
device = jax.devices('cpu')[0]
275277
return x.to_device(device, stream=stream)
276278
return x.to_device(device, stream=stream)
277279

tests/test_common.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,17 @@
55
import numpy as np
66
from numpy.testing import assert_allclose
77

8-
@pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array"])
8+
@pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array", "jax.numpy"])
99
def test_to_device_host(library):
1010
# different libraries have different semantics
1111
# for DtoH transfers; ensure that we support a portable
1212
# shim for common array libs
1313
# see: https://github.com/scipy/scipy/issues/18286#issuecomment-1527552919
14-
xp = import_('array_api_compat.' + library)
14+
if library == "jax.numpy":
15+
xp = import_('jax.experimental.array_api')
16+
else:
17+
xp = import_('array_api_compat.' + library)
18+
1519
expected = np.array([1, 2, 3])
1620
x = xp.asarray([1, 2, 3])
1721
x = to_device(x, "cpu")

0 commit comments

Comments
 (0)