|
5 | 5 | from ._helpers import import_
|
6 | 6 |
|
7 | 7 | import pytest
|
| 8 | +import numpy as np |
| 9 | +from numpy.testing import assert_allclose |
8 | 10 |
|
9 | 11 | is_functions = {
|
10 | 12 | 'numpy': 'is_numpy_array',
|
@@ -44,3 +46,28 @@ def test_device(library):
|
44 | 46 |
|
45 | 47 | x2 = to_device(x, dev)
|
46 | 48 | assert device(x) == device(x2)
|
| 49 | + |
| 50 | + |
| 51 | +@pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array", "jax.numpy"]) |
| 52 | +def test_to_device_host(library): |
| 53 | + # Test that "cpu" device works. Note: this isn't actually supported by the |
| 54 | + # standard yet. See https://github.com/data-apis/array-api/issues/626. |
| 55 | + |
| 56 | + # different libraries have different semantics |
| 57 | + # for DtoH transfers; ensure that we support a portable |
| 58 | + # shim for common array libs |
| 59 | + # see: https://github.com/scipy/scipy/issues/18286#issuecomment-1527552919 |
| 60 | + if library == "jax.numpy": |
| 61 | + xp = import_('jax.experimental.array_api') |
| 62 | + else: |
| 63 | + xp = import_('array_api_compat.' + library) |
| 64 | + |
| 65 | + expected = np.array([1, 2, 3]) |
| 66 | + x = xp.asarray([1, 2, 3]) |
| 67 | + x = to_device(x, "cpu") |
| 68 | + # torch will return a genuine Device object, but |
| 69 | + # the other libs will do something different with |
| 70 | + # a `device(x)` query; however, what's really important |
| 71 | + # here is that we can test portably after calling |
| 72 | + # to_device(x, "cpu") to return to host |
| 73 | + assert_allclose(x, expected) |
0 commit comments