@@ -153,7 +153,7 @@ def _check_device(xp, device):
         if device not in ["cpu", None]:
             raise ValueError(f"Unsupported device for NumPy: {device!r}")

-# device() is not on numpy.ndarray and and to_device() is not on numpy.ndarray
+# device() is not on numpy.ndarray and to_device() is not on numpy.ndarray
 # or cupy.ndarray. They are not included in array objects of this library
 # because this library just reuses the respective ndarray classes without
 # wrapping or subclassing them. These helper functions can be used instead of
@@ -230,12 +230,6 @@ def _torch_to_device(x, device, /, stream=None):
         raise NotImplementedError
     return x.to(device)

-def _jax_to_device(x, device, /, stream=None):
-    import jax
-    if stream is not None:
-        raise NotImplementedError
-    return jax.device_put(x, device)
-
 def to_device(x: "Array", device: "Device", /, *, stream: "Optional[Union[int, Any]]" = None) -> "Array":
     """
     Copy the array from the device on which it currently resides to the specified ``device``.
@@ -276,7 +270,9 @@ def to_device(x: "Array", device: "Device", /, *, stream: "Optional[Union[int, A
             return x
         raise ValueError(f"Unsupported device {device!r}")
     elif is_jax_array(x):
-        return _jax_to_device(x, device, stream=stream)
+        # This import adds to_device to x
+        import jax.experimental.array_api
+        return x.to_device(device, stream=stream)
     return x.to_device(device, stream=stream)

 def size(x):
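
For readers following along, here is a minimal sketch (not part of the patch) contrasting the transfer path this commit removes with the one it switches to. It assumes a JAX build that still provides the experimental jax.experimental.array_api namespace; the sample array and the cpu device below are illustrative only.

# Sketch only, not from the patch: compares the removed jax.device_put path
# with the to_device() method path the commit switches to. Requires a JAX
# version that still ships the experimental array_api namespace.
import jax
import jax.numpy as jnp

x = jnp.arange(3)
cpu = jax.devices("cpu")[0]

# Old path: the deleted _jax_to_device helper simply wrapped jax.device_put.
y_old = jax.device_put(x, cpu)

# New path: per the comment in the patch, this import attaches the standard
# to_device() method to JAX arrays, so the generic
# `return x.to_device(device, stream=stream)` branch now covers JAX as well.
import jax.experimental.array_api  # noqa: F401

y_new = x.to_device(cpu)
assert (y_old == y_new).all()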