Skip to content

Commit e851b60

Browse files
committed
Define/assign xp in __init__.py
Prevents circular imports, generally makes more sense
1 parent bbbcb90 commit e851b60

13 files changed

+50
-62
lines changed

README.md

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -36,13 +36,7 @@ You need to specify the array library to test. It can be specified via the
3636
$ export ARRAY_API_TESTS_MODULE=numpy.array_api
3737
```
3838

39-
Alternately, change the `array_module` variable in `array_api_tests/_array_module.py`
40-
line, e.g.
41-
42-
```diff
43-
- array_module = None
44-
+ import numpy.array_api as array_module
45-
```
39+
Alternately, import/define the `xp` variable in `array_api_tests/__init__.py`.
4640

4741
### Run the suite
4842

array_api_tests/__init__.py

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,36 @@
1+
import os
12
from functools import wraps
2-
from os import getenv
3+
from importlib import import_module
34

45
from hypothesis import strategies as st
56
from hypothesis.extra import array_api
67

78
from . import _version
8-
from ._array_module import mod as _xp
99

10-
__all__ = ["api_version", "xps"]
10+
__all__ = ["xp", "api_version", "xps"]
11+
12+
13+
# You can comment the following out and instead import the specific array module
14+
# you want to test, e.g. `import numpy.array_api as xp`.
15+
if "ARRAY_API_TESTS_MODULE" in os.environ:
16+
xp_name = os.environ["ARRAY_API_TESTS_MODULE"]
17+
_module, _sub = xp_name, None
18+
if "." in xp_name:
19+
_module, _sub = xp_name.split(".", 1)
20+
xp = import_module(_module)
21+
if _sub:
22+
try:
23+
xp = getattr(xp, _sub)
24+
except AttributeError:
25+
# _sub may be a submodule that needs to be imported. WE can't
26+
# do this in every case because some array modules are not
27+
# submodules that can be imported (like mxnet.nd).
28+
xp = import_module(xp_name)
29+
else:
30+
raise RuntimeError(
31+
"No array module specified - either edit __init__.py or set the "
32+
"ARRAY_API_TESTS_MODULE environment variable."
33+
)
1134

1235

1336
# We monkey patch floats() to always disable subnormals as they are out-of-scope
@@ -43,9 +66,9 @@ def _from_dtype(*a, **kw):
4366
pass
4467

4568

46-
api_version = getenv(
47-
"ARRAY_API_TESTS_VERSION", getattr(_xp, "__array_api_version__", "2021.12")
69+
api_version = os.getenv(
70+
"ARRAY_API_TESTS_VERSION", getattr(xp, "__array_api_version__", "2021.12")
4871
)
49-
xps = array_api.make_strategies_namespace(_xp, api_version=api_version)
72+
xps = array_api.make_strategies_namespace(xp, api_version=api_version)
5073

5174
__version__ = _version.get_versions()["version"]

array_api_tests/_array_module.py

Lines changed: 3 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,5 @@
1-
import os
2-
from importlib import import_module
1+
from . import stubs, xp
32

4-
from . import stubs
5-
6-
# Replace this with a specific array module to test it, for example,
7-
#
8-
# import numpy as array_module
9-
array_module = None
10-
11-
if array_module is None:
12-
if 'ARRAY_API_TESTS_MODULE' in os.environ:
13-
mod_name = os.environ['ARRAY_API_TESTS_MODULE']
14-
_module, _sub = mod_name, None
15-
if '.' in mod_name:
16-
_module, _sub = mod_name.split('.', 1)
17-
mod = import_module(_module)
18-
if _sub:
19-
try:
20-
mod = getattr(mod, _sub)
21-
except AttributeError:
22-
# _sub may be a submodule that needs to be imported. WE can't
23-
# do this in every case because some array modules are not
24-
# submodules that can be imported (like mxnet.nd).
25-
mod = import_module(mod_name)
26-
else:
27-
raise RuntimeError("No array module specified. Either edit _array_module.py or set the ARRAY_API_TESTS_MODULE environment variable")
28-
else:
29-
mod = array_module
30-
mod_name = mod.__name__
31-
# Names from the spec. This is what should actually be imported from this
32-
# file.
333

344
class _UndefinedStub:
355
"""
@@ -45,7 +15,7 @@ def __init__(self, name):
4515
self.name = name
4616

4717
def _raise(self, *args, **kwargs):
48-
raise AssertionError(f"{self.name} is not defined in {mod_name}")
18+
raise AssertionError(f"{self.name} is not defined in {xp.__name__}")
4919

5020
def __repr__(self):
5121
return f"<undefined stub for {self.name!r}>"
@@ -67,6 +37,6 @@ def __repr__(self):
6737

6838
for attr in _top_level_attrs:
6939
try:
70-
globals()[attr] = getattr(mod, attr)
40+
globals()[attr] = getattr(xp, attr)
7141
except AttributeError:
7242
globals()[attr] = _UndefinedStub(attr)

array_api_tests/dtype_helpers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from warnings import warn
77

88
from . import api_version
9-
from ._array_module import mod as xp
9+
from . import xp
1010
from .stubs import name_to_func
1111
from .typing import DataType, ScalarType
1212

array_api_tests/stubs.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
from types import FunctionType, ModuleType
77
from typing import Dict, List
88

9+
from . import api_version
10+
911
__all__ = [
1012
"name_to_func",
1113
"array_methods",
@@ -15,10 +17,9 @@
1517
"extension_to_funcs",
1618
]
1719

18-
spec_version = "2022.12"
19-
spec_module = "_" + spec_version.replace('.', '_')
20+
spec_module = "_" + api_version.replace('.', '_')
2021

21-
spec_dir = Path(__file__).parent.parent / "array-api" / "spec" / spec_version / "API_specification"
22+
spec_dir = Path(__file__).parent.parent / "array-api" / "spec" / api_version / "API_specification"
2223
assert spec_dir.exists(), f"{spec_dir} not found - try `git submodule update --init`"
2324
sigs_dir = Path(__file__).parent.parent / "array-api" / "src" / "array_api_stubs" / spec_module
2425
assert sigs_dir.exists()

array_api_tests/test_array_object.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from . import pytest_helpers as ph
1414
from . import shape_helpers as sh
1515
from . import xps
16-
from ._array_module import mod as _xp
16+
from . import xp as _xp
1717
from .typing import DataType, Index, Param, Scalar, ScalarType, Shape
1818

1919
pytestmark = pytest.mark.ci

array_api_tests/test_constants.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import pytest
55

66
from . import dtype_helpers as dh
7-
from ._array_module import mod as xp
7+
from . import xp
88
from .typing import Array
99

1010
pytestmark = pytest.mark.ci

array_api_tests/test_data_type_functions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from . import pytest_helpers as ph
1212
from . import shape_helpers as sh
1313
from . import xps
14-
from ._array_module import mod as _xp
14+
from . import xp as _xp
1515
from .typing import DataType
1616

1717
pytestmark = pytest.mark.ci

array_api_tests/test_fft.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from . import pytest_helpers as ph
1515
from . import shape_helpers as sh
1616
from . import xps
17-
from ._array_module import mod as xp
17+
from . import xp
1818

1919
pytestmark = [
2020
pytest.mark.ci,

array_api_tests/test_has_names.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
import pytest
77

8-
from ._array_module import mod as xp, mod_name
8+
from . import xp
99
from .stubs import (array_attributes, array_methods, category_to_funcs,
1010
extension_to_funcs, EXTENSIONS)
1111

@@ -27,13 +27,13 @@
2727
def test_has_names(category, name):
2828
if category in EXTENSIONS:
2929
ext_mod = getattr(xp, category)
30-
assert hasattr(ext_mod, name), f"{mod_name} is missing the {category} extension function {name}()"
30+
assert hasattr(ext_mod, name), f"{xp.__name__} is missing the {category} extension function {name}()"
3131
elif category.startswith('array_'):
3232
# TODO: This would fail if ones() is missing.
3333
arr = xp.ones((1, 1))
3434
if category == 'array_attribute':
35-
assert hasattr(arr, name), f"The {mod_name} array object is missing the attribute {name}"
35+
assert hasattr(arr, name), f"The {xp.__name__} array object is missing the attribute {name}"
3636
else:
37-
assert hasattr(arr, name), f"The {mod_name} array object is missing the method {name}()"
37+
assert hasattr(arr, name), f"The {xp.__name__} array object is missing the method {name}()"
3838
else:
39-
assert hasattr(xp, name), f"{mod_name} is missing the {category} function {name}()"
39+
assert hasattr(xp, name), f"{xp.__name__} is missing the {category} function {name}()"

array_api_tests/test_signatures.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def squeeze(x, /, axis):
3030
import pytest
3131

3232
from . import dtype_helpers as dh
33-
from ._array_module import mod as xp
33+
from . import xp
3434
from .stubs import array_methods, category_to_funcs, extension_to_funcs, name_to_func
3535

3636
pytestmark = pytest.mark.ci

array_api_tests/test_special_cases.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
from . import pytest_helpers as ph
3434
from . import shape_helpers as sh
3535
from . import xps
36-
from ._array_module import mod as xp
36+
from . import xp
3737
from .stubs import category_to_funcs
3838

3939
pytestmark = pytest.mark.ci

reporting.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def pytest_metadata(metadata):
4949
"""
5050
Additional global metadata for --json-report.
5151
"""
52-
metadata['array_api_tests_module'] = xp.mod_name
52+
metadata['array_api_tests_module'] = xp.__name__
5353
metadata['array_api_tests_version'] = __version__
5454

5555
@fixture(autouse=True)

0 commit comments

Comments
 (0)