Skip to content

Commit ddb313e

Browse files
committed
Use the native jax to_device() method
1 parent ce07cd9 commit ddb313e

File tree

1 file changed

+4
-8
lines changed

1 file changed

+4
-8
lines changed

array_api_compat/common/_helpers.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ def _check_device(xp, device):
153153
if device not in ["cpu", None]:
154154
raise ValueError(f"Unsupported device for NumPy: {device!r}")
155155

156-
# device() is not on numpy.ndarray and and to_device() is not on numpy.ndarray
156+
# device() is not on numpy.ndarray and to_device() is not on numpy.ndarray
157157
# or cupy.ndarray. They are not included in array objects of this library
158158
# because this library just reuses the respective ndarray classes without
159159
# wrapping or subclassing them. These helper functions can be used instead of
@@ -230,12 +230,6 @@ def _torch_to_device(x, device, /, stream=None):
230230
raise NotImplementedError
231231
return x.to(device)
232232

233-
def _jax_to_device(x, device, /, stream=None):
234-
import jax
235-
if stream is not None:
236-
raise NotImplementedError
237-
return jax.device_put(x, device)
238-
239233
def to_device(x: "Array", device: "Device", /, *, stream: "Optional[Union[int, Any]]" = None) -> "Array":
240234
"""
241235
Copy the array from the device on which it currently resides to the specified ``device``.
@@ -276,7 +270,9 @@ def to_device(x: "Array", device: "Device", /, *, stream: "Optional[Union[int, A
276270
return x
277271
raise ValueError(f"Unsupported device {device!r}")
278272
elif is_jax_array(x):
279-
return _jax_to_device(x, device, stream=stream)
273+
# This import adds to_device to x
274+
import jax.experimental.array_api
275+
return x.to_device(device, stream=stream)
280276
return x.to_device(device, stream=stream)
281277

282278
def size(x):

0 commit comments

Comments
 (0)