Skip to content

Commit 9afe8c7

Browse files
authored
Merge pull request #213 from honno/bump-spec
Bump `array-api` submodule and utilise its all-versions setup
2 parents 9d7777b + 6b870a9 commit 9afe8c7

14 files changed

+92
-73
lines changed

README.md

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -36,13 +36,7 @@ You need to specify the array library to test. It can be specified via the
3636
$ export ARRAY_API_TESTS_MODULE=numpy.array_api
3737
```
3838

39-
Alternately, change the `array_module` variable in `array_api_tests/_array_module.py`
40-
line, e.g.
41-
42-
```diff
43-
- array_module = None
44-
+ import numpy.array_api as array_module
45-
```
39+
Alternately, import/define the `xp` variable in `array_api_tests/__init__.py`.
4640

4741
### Run the suite
4842

array-api

Submodule array-api updated 211 files

array_api_tests/__init__.py

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,36 @@
1+
import os
12
from functools import wraps
2-
from os import getenv
3+
from importlib import import_module
34

45
from hypothesis import strategies as st
56
from hypothesis.extra import array_api
67

78
from . import _version
8-
from ._array_module import mod as _xp
99

10-
__all__ = ["api_version", "xps"]
10+
__all__ = ["xp", "api_version", "xps"]
11+
12+
13+
# You can comment the following out and instead import the specific array module
14+
# you want to test, e.g. `import numpy.array_api as xp`.
15+
if "ARRAY_API_TESTS_MODULE" in os.environ:
16+
xp_name = os.environ["ARRAY_API_TESTS_MODULE"]
17+
_module, _sub = xp_name, None
18+
if "." in xp_name:
19+
_module, _sub = xp_name.split(".", 1)
20+
xp = import_module(_module)
21+
if _sub:
22+
try:
23+
xp = getattr(xp, _sub)
24+
except AttributeError:
25+
# _sub may be a submodule that needs to be imported. WE can't
26+
# do this in every case because some array modules are not
27+
# submodules that can be imported (like mxnet.nd).
28+
xp = import_module(xp_name)
29+
else:
30+
raise RuntimeError(
31+
"No array module specified - either edit __init__.py or set the "
32+
"ARRAY_API_TESTS_MODULE environment variable."
33+
)
1134

1235

1336
# We monkey patch floats() to always disable subnormals as they are out-of-scope
@@ -43,9 +66,9 @@ def _from_dtype(*a, **kw):
4366
pass
4467

4568

46-
api_version = getenv(
47-
"ARRAY_API_TESTS_VERSION", getattr(_xp, "__array_api_version__", "2021.12")
69+
api_version = os.getenv(
70+
"ARRAY_API_TESTS_VERSION", getattr(xp, "__array_api_version__", "2021.12")
4871
)
49-
xps = array_api.make_strategies_namespace(_xp, api_version=api_version)
72+
xps = array_api.make_strategies_namespace(xp, api_version=api_version)
5073

5174
__version__ = _version.get_versions()["version"]

array_api_tests/_array_module.py

Lines changed: 3 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,5 @@
1-
import os
2-
from importlib import import_module
1+
from . import stubs, xp
32

4-
from . import stubs
5-
6-
# Replace this with a specific array module to test it, for example,
7-
#
8-
# import numpy as array_module
9-
array_module = None
10-
11-
if array_module is None:
12-
if 'ARRAY_API_TESTS_MODULE' in os.environ:
13-
mod_name = os.environ['ARRAY_API_TESTS_MODULE']
14-
_module, _sub = mod_name, None
15-
if '.' in mod_name:
16-
_module, _sub = mod_name.split('.', 1)
17-
mod = import_module(_module)
18-
if _sub:
19-
try:
20-
mod = getattr(mod, _sub)
21-
except AttributeError:
22-
# _sub may be a submodule that needs to be imported. WE can't
23-
# do this in every case because some array modules are not
24-
# submodules that can be imported (like mxnet.nd).
25-
mod = import_module(mod_name)
26-
else:
27-
raise RuntimeError("No array module specified. Either edit _array_module.py or set the ARRAY_API_TESTS_MODULE environment variable")
28-
else:
29-
mod = array_module
30-
mod_name = mod.__name__
31-
# Names from the spec. This is what should actually be imported from this
32-
# file.
333

344
class _UndefinedStub:
355
"""
@@ -45,7 +15,7 @@ def __init__(self, name):
4515
self.name = name
4616

4717
def _raise(self, *args, **kwargs):
48-
raise AssertionError(f"{self.name} is not defined in {mod_name}")
18+
raise AssertionError(f"{self.name} is not defined in {xp.__name__}")
4919

5020
def __repr__(self):
5121
return f"<undefined stub for {self.name!r}>"
@@ -67,6 +37,6 @@ def __repr__(self):
6737

6838
for attr in _top_level_attrs:
6939
try:
70-
globals()[attr] = getattr(mod, attr)
40+
globals()[attr] = getattr(xp, attr)
7141
except AttributeError:
7242
globals()[attr] = _UndefinedStub(attr)

array_api_tests/dtype_helpers.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from warnings import warn
77

88
from . import api_version
9-
from ._array_module import mod as xp
9+
from . import xp
1010
from .stubs import name_to_func
1111
from .typing import DataType, ScalarType
1212

@@ -352,6 +352,9 @@ def result_type(*dtypes: DataType):
352352
"boolean": (xp.bool,),
353353
"integer": all_int_dtypes,
354354
"floating-point": real_float_dtypes,
355+
"real-valued": real_float_dtypes,
356+
"real-valued floating-point": real_float_dtypes,
357+
"complex floating-point": complex_dtypes,
355358
"numeric": numeric_dtypes,
356359
"integer or boolean": bool_and_all_int_dtypes,
357360
}
@@ -364,8 +367,6 @@ def result_type(*dtypes: DataType):
364367
dtype_category = "floating-point"
365368
dtypes = category_to_dtypes[dtype_category]
366369
func_in_dtypes[name] = dtypes
367-
# See https://github.com/data-apis/array-api/pull/413
368-
func_in_dtypes["expm1"] = real_float_dtypes
369370

370371

371372
func_returns_bool = {

array_api_tests/stubs.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
from types import FunctionType, ModuleType
77
from typing import Dict, List
88

9+
from . import api_version
10+
911
__all__ = [
1012
"name_to_func",
1113
"array_methods",
@@ -15,20 +17,21 @@
1517
"extension_to_funcs",
1618
]
1719

20+
spec_module = "_" + api_version.replace('.', '_')
1821

19-
spec_dir = Path(__file__).parent.parent / "array-api" / "spec" / "API_specification"
22+
spec_dir = Path(__file__).parent.parent / "array-api" / "spec" / api_version / "API_specification"
2023
assert spec_dir.exists(), f"{spec_dir} not found - try `git submodule update --init`"
21-
sigs_dir = spec_dir / "signatures"
24+
sigs_dir = Path(__file__).parent.parent / "array-api" / "src" / "array_api_stubs" / spec_module
2225
assert sigs_dir.exists()
2326

24-
spec_abs_path: str = str(spec_dir.resolve())
25-
sys.path.append(spec_abs_path)
26-
assert find_spec("signatures") is not None
27+
sigs_abs_path: str = str(sigs_dir.parent.parent.resolve())
28+
sys.path.append(sigs_abs_path)
29+
assert find_spec(f"array_api_stubs.{spec_module}") is not None
2730

2831
name_to_mod: Dict[str, ModuleType] = {}
2932
for path in sigs_dir.glob("*.py"):
3033
name = path.name.replace(".py", "")
31-
name_to_mod[name] = import_module(f"signatures.{name}")
34+
name_to_mod[name] = import_module(f"array_api_stubs.{spec_module}.{name}")
3235

3336
array = name_to_mod["array_object"].array
3437
array_methods = [
@@ -70,3 +73,7 @@
7073
for func in funcs:
7174
if func.__name__ not in name_to_func.keys():
7275
name_to_func[func.__name__] = func
76+
77+
# sanity check public attributes are not empty
78+
for attr in __all__:
79+
assert len(locals()[attr]) != 0, f"{attr} is empty"

array_api_tests/test_array_object.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from . import pytest_helpers as ph
1414
from . import shape_helpers as sh
1515
from . import xps
16-
from ._array_module import mod as _xp
16+
from . import xp as _xp
1717
from .typing import DataType, Index, Param, Scalar, ScalarType, Shape
1818

1919
pytestmark = pytest.mark.ci

array_api_tests/test_constants.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import pytest
55

66
from . import dtype_helpers as dh
7-
from ._array_module import mod as xp
7+
from . import xp
88
from .typing import Array
99

1010
pytestmark = pytest.mark.ci

array_api_tests/test_data_type_functions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from . import pytest_helpers as ph
1212
from . import shape_helpers as sh
1313
from . import xps
14-
from ._array_module import mod as _xp
14+
from . import xp as _xp
1515
from .typing import DataType
1616

1717
pytestmark = pytest.mark.ci

array_api_tests/test_fft.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from . import pytest_helpers as ph
1515
from . import shape_helpers as sh
1616
from . import xps
17-
from ._array_module import mod as xp
17+
from . import xp
1818

1919
pytestmark = [
2020
pytest.mark.ci,

array_api_tests/test_has_names.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
import pytest
77

8-
from ._array_module import mod as xp, mod_name
8+
from . import xp
99
from .stubs import (array_attributes, array_methods, category_to_funcs,
1010
extension_to_funcs, EXTENSIONS)
1111

@@ -27,13 +27,13 @@
2727
def test_has_names(category, name):
2828
if category in EXTENSIONS:
2929
ext_mod = getattr(xp, category)
30-
assert hasattr(ext_mod, name), f"{mod_name} is missing the {category} extension function {name}()"
30+
assert hasattr(ext_mod, name), f"{xp.__name__} is missing the {category} extension function {name}()"
3131
elif category.startswith('array_'):
3232
# TODO: This would fail if ones() is missing.
3333
arr = xp.ones((1, 1))
3434
if category == 'array_attribute':
35-
assert hasattr(arr, name), f"The {mod_name} array object is missing the attribute {name}"
35+
assert hasattr(arr, name), f"The {xp.__name__} array object is missing the attribute {name}"
3636
else:
37-
assert hasattr(arr, name), f"The {mod_name} array object is missing the method {name}()"
37+
assert hasattr(arr, name), f"The {xp.__name__} array object is missing the method {name}()"
3838
else:
39-
assert hasattr(xp, name), f"{mod_name} is missing the {category} function {name}()"
39+
assert hasattr(xp, name), f"{xp.__name__} is missing the {category} function {name}()"

array_api_tests/test_signatures.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def squeeze(x, /, axis):
3030
import pytest
3131

3232
from . import dtype_helpers as dh
33-
from ._array_module import mod as xp
33+
from . import xp
3434
from .stubs import array_methods, category_to_funcs, extension_to_funcs, name_to_func
3535

3636
pytestmark = pytest.mark.ci

array_api_tests/test_special_cases.py

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,7 @@
3232
from . import hypothesis_helpers as hh
3333
from . import pytest_helpers as ph
3434
from . import shape_helpers as sh
35-
from . import xps
36-
from ._array_module import mod as xp
35+
from . import xp, xps
3736
from .stubs import category_to_funcs
3837

3938
pytestmark = pytest.mark.ci
@@ -126,6 +125,8 @@ def abs_cond(i: float) -> bool:
126125
"infinity": float("inf"),
127126
"0": 0.0,
128127
"1": 1.0,
128+
"False": 0.0,
129+
"True": 1.0,
129130
}
130131
r_value = re.compile(r"([+-]?)(.+)")
131132
r_pi = re.compile(r"(\d?)π(?:/(\d))?")
@@ -158,7 +159,10 @@ def parse_value(value_str: str) -> float:
158159
if denominator := pi_m.group(2):
159160
value /= int(denominator)
160161
else:
161-
value = repr_to_value[m.group(2)]
162+
try:
163+
value = repr_to_value[m.group(2)]
164+
except KeyError as e:
165+
raise ParseError(value_str) from e
162166
if sign := m.group(1):
163167
if sign == "-":
164168
value *= -1
@@ -507,7 +511,10 @@ def __repr__(self) -> str:
507511
return f"{self.__class__.__name__}(<{self}>)"
508512

509513

510-
r_case_block = re.compile(r"\*\*Special [Cc]ases\*\*\n+((?:(.*\n)+))\n+\s*Parameters")
514+
r_case_block = re.compile(
515+
r"\*\*Special [Cc]ases\*\*\n+((?:(.*\n)+))\n+\s*"
516+
r"(?:.+\n--+)?(?:\.\. versionchanged.*)?"
517+
)
511518
r_case = re.compile(r"\s+-\s*(.*)\.")
512519

513520

@@ -1121,6 +1128,9 @@ def parse_binary_case_block(case_block: str) -> List[BinaryCase]:
11211128
iop_params = []
11221129
func_to_op: Dict[str, str] = {v: k for k, v in dh.op_to_func.items()}
11231130
for stub in category_to_funcs["elementwise"]:
1131+
# if stub.__name__ == "abs":
1132+
# import ipdb; ipdb.set_trace()
1133+
11241134
if stub.__doc__ is None:
11251135
warn(f"{stub.__name__}() stub has no docstring")
11261136
continue
@@ -1167,6 +1177,8 @@ def parse_binary_case_block(case_block: str) -> List[BinaryCase]:
11671177
op = getattr(operator, op_name)
11681178
name_to_func[op_name] = op
11691179
# We collect inplace operator test cases seperately
1180+
if stub.__name__ == "equal":
1181+
break
11701182
iop_name = "__i" + op_name[2:]
11711183
iop = getattr(operator, iop_name)
11721184
for case in cases:
@@ -1197,6 +1209,11 @@ def parse_binary_case_block(case_block: str) -> List[BinaryCase]:
11971209
# its False - Hypothesis will complain if we reject too many examples, thus
11981210
# indicating we've done something wrong.
11991211

1212+
# sanity checks
1213+
assert len(unary_params) != 0
1214+
assert len(binary_params) != 0
1215+
assert len(iop_params) != 0
1216+
12001217

12011218
@pytest.mark.parametrize("func_name, func, case", unary_params)
12021219
@given(
@@ -1254,7 +1271,12 @@ def test_binary(func_name, func, case, x1, x2, data):
12541271

12551272
res = func(x1, x2)
12561273
# sanity check
1257-
ph.assert_result_shape(func_name, in_shapes=[x1.shape, x2.shape], out_shape=res.shape, expected=result_shape)
1274+
ph.assert_result_shape(
1275+
func_name,
1276+
in_shapes=[x1.shape, x2.shape],
1277+
out_shape=res.shape,
1278+
expected=result_shape,
1279+
)
12581280

12591281
good_example = False
12601282
for l_idx, r_idx, o_idx in all_indices:
@@ -1306,7 +1328,9 @@ def test_iop(iop_name, iop, case, oneway_dtypes, oneway_shapes, data):
13061328
res = xp.asarray(x1, copy=True)
13071329
res = iop(res, x2)
13081330
# sanity check
1309-
ph.assert_result_shape(iop_name, in_shapes=[x1.shape, x2.shape], out_shape=res.shape)
1331+
ph.assert_result_shape(
1332+
iop_name, in_shapes=[x1.shape, x2.shape], out_shape=res.shape
1333+
)
13101334

13111335
good_example = False
13121336
for l_idx, r_idx, o_idx in all_indices:

reporting.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def pytest_metadata(metadata):
4949
"""
5050
Additional global metadata for --json-report.
5151
"""
52-
metadata['array_api_tests_module'] = xp.mod_name
52+
metadata['array_api_tests_module'] = xp.__name__
5353
metadata['array_api_tests_version'] = __version__
5454

5555
@fixture(autouse=True)

0 commit comments

Comments
 (0)