Skip to content

Commit 6004b97

Browse files
committed
Add a basic test for device() and to_device()
1 parent ddb313e commit 6004b97

File tree

1 file changed

+18
-1
lines changed

1 file changed

+18
-1
lines changed

tests/test_helpers.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from array_api_compat import (is_numpy_array, is_cupy_array, is_torch_array,
2-
is_dask_array, is_jax_array, is_array_api_obj)
2+
is_dask_array, is_jax_array, is_array_api_obj,
3+
device, to_device)
34

45
from ._helpers import import_
56

@@ -24,3 +25,19 @@ def test_is_xp_array(library, func):
2425
assert is_func(x) == (func == is_functions[library])
2526

2627
assert is_array_api_obj(x)
28+
29+
@pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array", "jax.numpy"])
30+
def test_device(library):
31+
if library == "jax.numpy":
32+
xp = import_('jax.experimental.array_api')
33+
else:
34+
xp = import_('array_api_compat.' + library)
35+
36+
# We can't test much for device() and to_device() other than that
37+
# x.to_device(x.device) works.
38+
39+
x = xp.asarray([1, 2, 3])
40+
dev = device(x)
41+
42+
x2 = to_device(x, dev)
43+
assert device(x) == device(x2)

0 commit comments

Comments
 (0)