Skip to content

Commit b0a323d

Browse files
committed
Rename import_ to import_or_skip_cupy
1 parent ff51015 commit b0a323d

File tree

5 files changed

+27
-19
lines changed

5 files changed

+27
-19
lines changed

tests/_helpers.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22

33
import pytest
44

5-
def import_(library):
6-
if 'cupy' in library:
5+
6+
def import_or_skip_cupy(library):
7+
if "cupy" in library:
78
return pytest.importorskip(library)
89
return import_module(library)

tests/test_array_namespace.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,23 @@
1+
import numpy as np
12
import pytest
3+
import torch
24

35
import array_api_compat
46
from array_api_compat import array_namespace
57

8+
from ._helpers import import_or_skip_cupy
9+
610

711
@pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array"])
812
@pytest.mark.parametrize("api_version", [None, "2021.12"])
913
def test_array_namespace(library, api_version):
10-
lib = pytest.importorskip(library)
14+
xp = import_or_skip_cupy(library)
1115

12-
array = lib.asarray([1.0, 2.0, 3.0])
16+
array = xp.asarray([1.0, 2.0, 3.0])
1317
namespace = array_api_compat.array_namespace(array, api_version=api_version)
1418

1519
if "array_api" in library:
16-
assert namespace == lib
20+
assert namespace == xp
1721
else:
1822
if library == "dask.array":
1923
assert namespace == array_api_compat.dask.array
@@ -22,8 +26,6 @@ def test_array_namespace(library, api_version):
2226

2327

2428
def test_array_namespace_errors():
25-
np = pytest.importorskip("numpy")
26-
2729
pytest.raises(TypeError, lambda: array_namespace([1]))
2830
pytest.raises(TypeError, lambda: array_namespace())
2931

@@ -33,9 +35,6 @@ def test_array_namespace_errors():
3335

3436

3537
def test_array_namespace_errors_torch():
36-
torch = pytest.importorskip("torch")
37-
np = pytest.importorskip("numpy")
38-
3938
y = torch.asarray([1, 2])
4039
x = np.asarray([1, 2])
4140
pytest.raises(TypeError, lambda: array_namespace(x, y))

tests/test_common.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,20 @@
1-
from array_api_compat import to_device
2-
3-
import pytest
41
import numpy as np
2+
import pytest
53
from numpy.testing import assert_allclose
64

5+
from array_api_compat import to_device
6+
7+
from ._helpers import import_or_skip_cupy
8+
9+
710
@pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array"])
811
def test_to_device_host(library):
912
# different libraries have different semantics
1013
# for DtoH transfers; ensure that we support a portable
1114
# shim for common array libs
1215
# see: https://github.com/scipy/scipy/issues/18286#issuecomment-1527552919
13-
xp = pytest.importorskip('array_api_compat.' + library)
16+
xp = import_or_skip_cupy("array_api_compat." + library)
17+
1418
expected = np.array([1, 2, 3])
1519
x = xp.asarray([1, 2, 3])
1620
x = to_device(x, "cpu")

tests/test_isdtype.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55

66
import pytest
77

8+
from ._helpers import import_or_skip_cupy
9+
810
# Check the known dtypes by their string names
911

1012
def _spec_dtypes(library):
@@ -59,12 +61,12 @@ def isdtype_(dtype_, kind):
5961
res = dtype_categories[kind](dtype_)
6062
else:
6163
res = dtype_ == kind
62-
assert isinstance(res, bool)
64+
assert type(res) is bool # noqa: E721
6365
return res
6466

6567
@pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array"])
6668
def test_isdtype_spec_dtypes(library):
67-
xp = pytest.importorskip('array_api_compat.' + library)
69+
xp = import_or_skip_cupy('array_api_compat.' + library)
6870

6971
isdtype = xp.isdtype
7072

@@ -99,7 +101,7 @@ def test_isdtype_spec_dtypes(library):
99101
@pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array"])
100102
@pytest.mark.parametrize("dtype_", additional_dtypes)
101103
def test_isdtype_additional_dtypes(library, dtype_):
102-
xp = pytest.importorskip('array_api_compat.' + library)
104+
xp = import_or_skip_cupy('array_api_compat.' + library)
103105

104106
isdtype = xp.isdtype
105107

tests/test_vendoring.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,21 @@
33

44
def test_vendoring_numpy():
55
from vendor_test import uses_numpy
6+
67
uses_numpy._test_numpy()
78

89

910
def test_vendoring_cupy():
1011
pytest.importorskip("cupy")
1112

1213
from vendor_test import uses_cupy
14+
1315
uses_cupy._test_cupy()
1416

17+
1518
def test_vendoring_torch():
16-
pytest.importorskip("torch")
17-
1819
from vendor_test import uses_torch
20+
1921
uses_torch._test_torch()
2022

2123
def test_vendoring_dask():

0 commit comments

Comments
 (0)