Skip to content

Commit aafbbaa

Browse files
committed
Remove "cpu" device from jax to_device()
1 parent db667ea commit aafbbaa

File tree

2 files changed

+2
-10
lines changed

2 files changed

+2
-10
lines changed

array_api_compat/common/_helpers.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -272,8 +272,6 @@ def to_device(x: "Array", device: "Device", /, *, stream: "Optional[Union[int, A
272272
elif is_jax_array(x):
273273
# This import adds to_device to x
274274
import jax.experimental.array_api
275-
if device == 'cpu':
276-
device = jax.devices('cpu')[0]
277275
return x.to_device(device, stream=stream)
278276
return x.to_device(device, stream=stream)
279277

tests/test_helpers.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -48,19 +48,13 @@ def test_device(library):
4848
assert device(x) == device(x2)
4949

5050

51-
@pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array", "jax.numpy"])
51+
@pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array"])
5252
def test_to_device_host(library):
53-
# Test that "cpu" device works. Note: this isn't actually supported by the
54-
# standard yet. See https://github.com/data-apis/array-api/issues/626.
55-
5653
# different libraries have different semantics
5754
# for DtoH transfers; ensure that we support a portable
5855
# shim for common array libs
5956
# see: https://github.com/scipy/scipy/issues/18286#issuecomment-1527552919
60-
if library == "jax.numpy":
61-
xp = import_('jax.experimental.array_api')
62-
else:
63-
xp = import_('array_api_compat.' + library)
57+
xp = import_('array_api_compat.' + library)
6458

6559
expected = np.array([1, 2, 3])
6660
x = xp.asarray([1, 2, 3])

0 commit comments

Comments
 (0)