diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index e6adc948..727bf5a2 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -154,6 +154,11 @@ def _cupy_to_device(x, device, /, stream=None): if device == x.device: return x + elif device == "cpu": + # allowing us to use `to_device(x, "cpu")` + # is useful for portable test swapping between + # host and device backends + return x.get() elif not isinstance(device, _Device): raise ValueError(f"Unsupported device {device!r}") else: diff --git a/tests/test_common.py b/tests/test_common.py new file mode 100644 index 00000000..86886b7f --- /dev/null +++ b/tests/test_common.py @@ -0,0 +1,23 @@ +from ._helpers import import_ +from array_api_compat import to_device, device + +import pytest +import numpy as np +from numpy.testing import assert_allclose + +@pytest.mark.parametrize("library", ["cupy", "numpy", "torch"]) +def test_to_device_host(library): + # different libraries have different semantics + # for DtoH transfers; ensure that we support a portable + # shim for common array libs + # see: https://github.com/scipy/scipy/issues/18286#issuecomment-1527552919 + xp = import_('array_api_compat.' + library) + expected = np.array([1, 2, 3]) + x = xp.asarray([1, 2, 3]) + x = to_device(x, "cpu") + # torch will return a genuine Device object, but + # the other libs will do something different with + # a `device(x)` query; however, what's really important + # here is that we can test portably after calling + # to_device(x, "cpu") to return to host + assert_allclose(x, expected)