From b9ceea9d452addc041e57a5e01c6473ae0898958 Mon Sep 17 00:00:00 2001 From: Tyler Reddy Date: Fri, 28 Apr 2023 11:57:05 -0600 Subject: [PATCH] ENH: support CuPy to_device "cpu" * allow `to_device(x, "cpu")` for CuPy--this makes it easier to write portable unit tests where the expected value is on the host (a NumPy array) * otherwise, you can end up writing different shims in downstream libraries to move the array-like back to the host (`.get()` for CuPy...); if CuPy conforms to the standard, we shouldn't need this shim long-term --- array_api_compat/common/_helpers.py | 5 +++++ tests/test_common.py | 23 +++++++++++++++++++++++ 2 files changed, 28 insertions(+) create mode 100644 tests/test_common.py diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index e6adc948..727bf5a2 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -154,6 +154,11 @@ def _cupy_to_device(x, device, /, stream=None): if device == x.device: return x + elif device == "cpu": + # allowing us to use `to_device(x, "cpu")` + # is useful for portable test swapping between + # host and device backends + return x.get() elif not isinstance(device, _Device): raise ValueError(f"Unsupported device {device!r}") else: diff --git a/tests/test_common.py b/tests/test_common.py new file mode 100644 index 00000000..86886b7f --- /dev/null +++ b/tests/test_common.py @@ -0,0 +1,23 @@ +from ._helpers import import_ +from array_api_compat import to_device, device + +import pytest +import numpy as np +from numpy.testing import assert_allclose + +@pytest.mark.parametrize("library", ["cupy", "numpy", "torch"]) +def test_to_device_host(library): + # different libraries have different semantics + # for DtoH transfers; ensure that we support a portable + # shim for common array libs + # see: https://github.com/scipy/scipy/issues/18286#issuecomment-1527552919 + xp = import_('array_api_compat.' + library) + expected = np.array([1, 2, 3]) + x = xp.asarray([1, 2, 3]) + x = to_device(x, "cpu") + # torch will return a genuine Device object, but + # the other libs will do something different with + # a `device(x)` query; however, what's really important + # here is that we can test portably after calling + # to_device(x, "cpu") to return to host + assert_allclose(x, expected)