Skip to content

Commit 2ec609d

Browse files
authored
Merge pull request #40 from tylerjereddy/treddy_to_device_cpu_cupy
ENH: support CuPy to_device "cpu"
2 parents 9ef7f72 + b9ceea9 commit 2ec609d

File tree

2 files changed

+28
-0
lines changed

2 files changed

+28
-0
lines changed

array_api_compat/common/_helpers.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,11 @@ def _cupy_to_device(x, device, /, stream=None):
154154

155155
if device == x.device:
156156
return x
157+
elif device == "cpu":
158+
# allowing us to use `to_device(x, "cpu")`
159+
# is useful for portable test swapping between
160+
# host and device backends
161+
return x.get()
157162
elif not isinstance(device, _Device):
158163
raise ValueError(f"Unsupported device {device!r}")
159164
else:

tests/test_common.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
from ._helpers import import_
2+
from array_api_compat import to_device, device
3+
4+
import pytest
5+
import numpy as np
6+
from numpy.testing import assert_allclose
7+
8+
@pytest.mark.parametrize("library", ["cupy", "numpy", "torch"])
9+
def test_to_device_host(library):
10+
# different libraries have different semantics
11+
# for DtoH transfers; ensure that we support a portable
12+
# shim for common array libs
13+
# see: https://github.com/scipy/scipy/issues/18286#issuecomment-1527552919
14+
xp = import_('array_api_compat.' + library)
15+
expected = np.array([1, 2, 3])
16+
x = xp.asarray([1, 2, 3])
17+
x = to_device(x, "cpu")
18+
# torch will return a genuine Device object, but
19+
# the other libs will do something different with
20+
# a `device(x)` query; however, what's really important
21+
# here is that we can test portably after calling
22+
# to_device(x, "cpu") to return to host
23+
assert_allclose(x, expected)

0 commit comments

Comments
 (0)