Skip to content

Commit db667ea

Browse files
committed
Move tests in test_common.py to test_helpers.py
1 parent 049d557 commit db667ea

File tree

2 files changed

+27
-27
lines changed

2 files changed

+27
-27
lines changed

tests/test_common.py

Lines changed: 0 additions & 27 deletions
This file was deleted.

tests/test_helpers.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
from ._helpers import import_
66

77
import pytest
8+
import numpy as np
9+
from numpy.testing import assert_allclose
810

911
is_functions = {
1012
'numpy': 'is_numpy_array',
@@ -44,3 +46,28 @@ def test_device(library):
4446

4547
x2 = to_device(x, dev)
4648
assert device(x) == device(x2)
49+
50+
51+
@pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array", "jax.numpy"])
52+
def test_to_device_host(library):
53+
# Test that "cpu" device works. Note: this isn't actually supported by the
54+
# standard yet. See https://github.com/data-apis/array-api/issues/626.
55+
56+
# different libraries have different semantics
57+
# for DtoH transfers; ensure that we support a portable
58+
# shim for common array libs
59+
# see: https://github.com/scipy/scipy/issues/18286#issuecomment-1527552919
60+
if library == "jax.numpy":
61+
xp = import_('jax.experimental.array_api')
62+
else:
63+
xp = import_('array_api_compat.' + library)
64+
65+
expected = np.array([1, 2, 3])
66+
x = xp.asarray([1, 2, 3])
67+
x = to_device(x, "cpu")
68+
# torch will return a genuine Device object, but
69+
# the other libs will do something different with
70+
# a `device(x)` query; however, what's really important
71+
# here is that we can test portably after calling
72+
# to_device(x, "cpu") to return to host
73+
assert_allclose(x, expected)

0 commit comments

Comments
 (0)