Skip to content

Commit 52ef9ee

Browse files
committed
Use importorskip
1 parent b2f9557 commit 52ef9ee

File tree

3 files changed

+17
-18
lines changed

3 files changed

+17
-18
lines changed

tests/test_array_namespace.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,18 @@
1+
import pytest
2+
13
import array_api_compat
24
from array_api_compat import array_namespace
35

4-
from ._helpers import import_
5-
6-
import pytest
76

87
@pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array"])
9-
@pytest.mark.parametrize("api_version", [None, '2021.12'])
8+
@pytest.mark.parametrize("api_version", [None, "2021.12"])
109
def test_array_namespace(library, api_version):
11-
lib = import_(library)
10+
lib = pytest.importorskip(library)
1211

1312
array = lib.asarray([1.0, 2.0, 3.0])
1413
namespace = array_api_compat.array_namespace(array, api_version=api_version)
1514

16-
if 'array_api' in library:
15+
if "array_api" in library:
1716
assert namespace == lib
1817
else:
1918
if library == "dask.array":
@@ -23,21 +22,22 @@ def test_array_namespace(library, api_version):
2322

2423

2524
def test_array_namespace_errors():
25+
np = pytest.importorskip("numpy")
26+
2627
pytest.raises(TypeError, lambda: array_namespace([1]))
2728
pytest.raises(TypeError, lambda: array_namespace())
2829

29-
import numpy as np
3030
x = np.asarray([1, 2])
31-
3231
pytest.raises(TypeError, lambda: array_namespace((x, x)))
3332
pytest.raises(TypeError, lambda: array_namespace(x, (x, x)))
3433

35-
import torch
36-
y = torch.asarray([1, 2])
3734

38-
pytest.raises(TypeError, lambda: array_namespace(x, y))
35+
def test_array_namespace_errors_torch():
36+
torch = pytest.importorskip("torch")
3937

40-
pytest.raises(ValueError, lambda: array_namespace(x, api_version='2022.12'))
38+
y = torch.asarray([1, 2])
39+
pytest.raises(TypeError, lambda: array_namespace(x, y))
40+
pytest.raises(ValueError, lambda: array_namespace(x, api_version="2022.12"))
4141

4242

4343
def test_get_namespace():

tests/test_common.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
from ._helpers import import_
21
from array_api_compat import to_device
32

43
import pytest
@@ -11,7 +10,7 @@ def test_to_device_host(library):
1110
# for DtoH transfers; ensure that we support a portable
1211
# shim for common array libs
1312
# see: https://github.com/scipy/scipy/issues/18286#issuecomment-1527552919
14-
xp = import_('array_api_compat.' + library)
13+
xp = pytest.importorskip('array_api_compat.' + library)
1514
expected = np.array([1, 2, 3])
1615
x = xp.asarray([1, 2, 3])
1716
x = to_device(x, "cpu")

tests/test_isdtype.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,10 @@
33
non-spec dtypes
44
"""
55

6-
from ._helpers import import_
7-
86
import pytest
97

8+
from ._helpers import import_
9+
1010
# Check the known dtypes by their string names
1111

1212
def _spec_dtypes(library):
@@ -66,7 +66,7 @@ def isdtype_(dtype_, kind):
6666

6767
@pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array"])
6868
def test_isdtype_spec_dtypes(library):
69-
xp = import_('array_api_compat.' + library)
69+
xp = pytest.importorskip('array_api_compat.' + library)
7070

7171
isdtype = xp.isdtype
7272

@@ -101,7 +101,7 @@ def test_isdtype_spec_dtypes(library):
101101
@pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array"])
102102
@pytest.mark.parametrize("dtype_", additional_dtypes)
103103
def test_isdtype_additional_dtypes(library, dtype_):
104-
xp = import_('array_api_compat.' + library)
104+
xp = pytest.importorskip('array_api_compat.' + library)
105105

106106
isdtype = xp.isdtype
107107

0 commit comments

Comments
 (0)