Skip to content

Commit 6bcc4a9

Browse files
committed
update
1 parent 327a1e2 commit 6bcc4a9

File tree

1 file changed

+22
-5
lines changed

1 file changed

+22
-5
lines changed

array_api_compat/common/_helpers.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,16 @@ def _check_device(xp, device):
159159
if device not in ["cpu", None]:
160160
raise ValueError(f"Unsupported device for NumPy: {device!r}")
161161

162-
# device() is not on numpy.ndarray and to_device() is not on numpy.ndarray
162+
# Placeholder object to represent the dask device
163+
# when the array backend is not the CPU.
164+
# (since it is not easy to tell which device a dask array is on)
165+
class _dask_device:
166+
def __repr__(self):
167+
return "DASK_DEVICE"
168+
169+
DASK_DEVICE = _dask_device()
170+
171+
# device() is not on numpy.ndarray or dask.array and to_device() is not on numpy.ndarray
163172
# or cupy.ndarray. They are not included in array objects of this library
164173
# because this library just reuses the respective ndarray classes without
165174
# wrapping or subclassing them. These helper functions can be used instead of
@@ -179,11 +188,19 @@ def device(x: Array, /) -> Device:
179188
out: device
180189
a ``device`` object (see the "Device Support" section of the array API specification).
181190
"""
182-
if is_numpy_array(x) or is_dask_array(x):
183-
# TODO: dask technically can support GPU arrays
184-
# Detecting the array backend isn't easy for dask, though, so just return CPU for now
191+
if is_numpy_array(x):
185192
return "cpu"
186-
if is_jax_array(x):
193+
elif is_dask_array(x):
194+
# Peek at the metadata of the jax array to determine type
195+
try:
196+
import numpy as np
197+
if isinstance(x._meta, np.ndarray):
198+
# Must be on CPU since backed by numpy
199+
return "cpu"
200+
except ImportError:
201+
pass
202+
return DASK_DEVICE
203+
elif is_jax_array(x):
187204
# JAX has .device() as a method, but it is being deprecated so that it
188205
# can become a property, in accordance with the standard. In order for
189206
# this function to not break when JAX makes the flip, we check for

0 commit comments

Comments
 (0)