Skip to content

Commit fa758f7

Browse files
committed
Rename import_or_skip_cupy to import_ and move jax logic into it
1 parent 919ec41 commit fa758f7

File tree

4 files changed

+18
-20
lines changed

4 files changed

+18
-20
lines changed

tests/_helpers.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,14 @@
33
import pytest
44

55

6-
def import_or_skip_cupy(library):
7-
if "cupy" in library:
6+
def import_(library, wrapper=False):
7+
if library == 'cupy':
88
return pytest.importorskip(library)
9+
10+
if wrapper:
11+
if 'jax' in library:
12+
library = 'jax.experimental.array_api'
13+
else:
14+
library = 'array_api_compat.' + library
15+
916
return import_module(library)

tests/test_array_namespace.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,13 @@
55
import array_api_compat
66
from array_api_compat import array_namespace
77

8-
from ._helpers import import_or_skip_cupy
8+
from ._helpers import import_
99

1010

1111
@pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array"])
1212
@pytest.mark.parametrize("api_version", [None, "2021.12"])
1313
def test_array_namespace(library, api_version):
14-
xp = import_or_skip_cupy(library)
14+
xp = import_(library)
1515

1616
array = xp.asarray([1.0, 2.0, 3.0])
1717
namespace = array_api_compat.array_namespace(array, api_version=api_version)

tests/test_helpers.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
is_dask_array, is_jax_array, is_array_api_obj,
33
device, to_device)
44

5-
from ._helpers import import_or_skip_cupy
5+
from ._helpers import import_
66

77
import pytest
88
import numpy as np
@@ -19,7 +19,7 @@
1919
@pytest.mark.parametrize('library', is_functions.keys())
2020
@pytest.mark.parametrize('func', is_functions.values())
2121
def test_is_xp_array(library, func):
22-
lib = import_or_skip_cupy(library)
22+
lib = import_(library)
2323
is_func = globals()[func]
2424

2525
x = lib.asarray([1, 2, 3])
@@ -33,10 +33,7 @@ def test_device(library):
3333
if library == "dask.array":
3434
pytest.xfail("device() needs to be fixed for dask")
3535

36-
if library == "jax.numpy":
37-
xp = import_or_skip_cupy('jax.experimental.array_api')
38-
else:
39-
xp = import_or_skip_cupy('array_api_compat.' + library)
36+
xp = import_(library, wrapper=True)
4037

4138
# We can't test much for device() and to_device() other than that
4239
# x.to_device(x.device) works.
@@ -54,7 +51,7 @@ def test_to_device_host(library):
5451
# for DtoH transfers; ensure that we support a portable
5552
# shim for common array libs
5653
# see: https://github.com/scipy/scipy/issues/18286#issuecomment-1527552919
57-
xp = import_or_skip_cupy('array_api_compat.' + library)
54+
xp = import_(library, wrapper=True)
5855

5956
expected = np.array([1, 2, 3])
6057
x = xp.asarray([1, 2, 3])

tests/test_isdtype.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
import pytest
77

8-
from ._helpers import import_or_skip_cupy
8+
from ._helpers import import_
99

1010
# Check the known dtypes by their string names
1111

@@ -66,10 +66,7 @@ def isdtype_(dtype_, kind):
6666

6767
@pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array", "jax.numpy"])
6868
def test_isdtype_spec_dtypes(library):
69-
if library == "jax.numpy":
70-
xp = import_or_skip_cupy('jax.experimental.array_api')
71-
else:
72-
xp = import_or_skip_cupy('array_api_compat.' + library)
69+
xp = import_(library, wrapper=True)
7370

7471
isdtype = xp.isdtype
7572

@@ -104,10 +101,7 @@ def test_isdtype_spec_dtypes(library):
104101
@pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array", "jax.numpy"])
105102
@pytest.mark.parametrize("dtype_", additional_dtypes)
106103
def test_isdtype_additional_dtypes(library, dtype_):
107-
if library == "jax.numpy":
108-
xp = import_or_skip_cupy('jax.experimental.array_api')
109-
else:
110-
xp = import_or_skip_cupy('array_api_compat.' + library)
104+
xp = import_(library, wrapper=True)
111105

112106
isdtype = xp.isdtype
113107

0 commit comments

Comments
 (0)