@@ -159,7 +159,16 @@ def _check_device(xp, device):
159
159
if device not in ["cpu" , None ]:
160
160
raise ValueError (f"Unsupported device for NumPy: { device !r} " )
161
161
162
- # device() is not on numpy.ndarray and to_device() is not on numpy.ndarray
162
+ # Placeholder object to represent the dask device
163
+ # when the array backend is not the CPU.
164
+ # (since it is not easy to tell which device a dask array is on)
165
+ class _dask_device :
166
+ def __repr__ (self ):
167
+ return "DASK_DEVICE"
168
+
169
+ DASK_DEVICE = _dask_device ()
170
+
171
+ # device() is not on numpy.ndarray or dask.array and to_device() is not on numpy.ndarray
163
172
# or cupy.ndarray. They are not included in array objects of this library
164
173
# because this library just reuses the respective ndarray classes without
165
174
# wrapping or subclassing them. These helper functions can be used instead of
@@ -179,11 +188,19 @@ def device(x: Array, /) -> Device:
179
188
out: device
180
189
a ``device`` object (see the "Device Support" section of the array API specification).
181
190
"""
182
- if is_numpy_array (x ) or is_dask_array (x ):
183
- # TODO: dask technically can support GPU arrays
184
- # Detecting the array backend isn't easy for dask, though, so just return CPU for now
191
+ if is_numpy_array (x ):
185
192
return "cpu"
186
- if is_jax_array (x ):
193
+ elif is_dask_array (x ):
194
+ # Peek at the metadata of the jax array to determine type
195
+ try :
196
+ import numpy as np
197
+ if isinstance (x ._meta , np .ndarray ):
198
+ # Must be on CPU since backed by numpy
199
+ return "cpu"
200
+ except ImportError :
201
+ pass
202
+ return DASK_DEVICE
203
+ elif is_jax_array (x ):
187
204
# JAX has .device() as a method, but it is being deprecated so that it
188
205
# can become a property, in accordance with the standard. In order for
189
206
# this function to not break when JAX makes the flip, we check for
0 commit comments