Skip to content

Commit 7493ff8

Browse files
committed
Test array object method signatures
1 parent afe08c8 commit 7493ff8

File tree

2 files changed

+38
-17
lines changed

2 files changed

+38
-17
lines changed

array_api_tests/stubs.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
import sys
2+
import inspect
23
from importlib import import_module
34
from importlib.util import find_spec
45
from pathlib import Path
56
from types import FunctionType, ModuleType
67
from typing import Dict, List
78

8-
__all__ = ["category_to_funcs", "array", "EXTENSIONS", "extension_to_funcs"]
9+
__all__ = ["array_methods", "category_to_funcs", "EXTENSIONS", "extension_to_funcs"]
910

1011

1112
spec_dir = Path(__file__).parent.parent / "array-api" / "spec" / "API_specification"
@@ -22,6 +23,12 @@
2223
name = path.name.replace(".py", "")
2324
name_to_mod[name] = import_module(f"signatures.{name}")
2425

26+
array = name_to_mod["array_object"].array
27+
array_methods = [
28+
f for n, f in inspect.getmembers(array, predicate=inspect.isfunction)
29+
if n != "__init__" # probably exists for Sphinx
30+
]
31+
2532

2633
category_to_funcs: Dict[str, List[FunctionType]] = {}
2734
for name, mod in name_to_mod.items():
@@ -31,9 +38,6 @@
3138
assert all(isinstance(o, FunctionType) for o in objects)
3239
category_to_funcs[category] = objects
3340

34-
array = name_to_mod["array_object"].array
35-
36-
3741
EXTENSIONS: str = ["linalg"]
3842
extension_to_funcs: Dict[str, List[FunctionType]] = {}
3943
for ext in EXTENSIONS:

array_api_tests/test_signatures.py

Lines changed: 30 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
1-
from inspect import Parameter, signature
1+
from inspect import Parameter, Signature, signature
22
from types import FunctionType
33
from typing import Callable, Dict
44

55
import pytest
6+
from hypothesis import given
67

8+
from . import hypothesis_helpers as hh
9+
from . import xps
710
from ._array_module import mod as xp
8-
from .stubs import category_to_funcs, extension_to_funcs
11+
from .stubs import array_methods, category_to_funcs, extension_to_funcs
912

1013
kind_to_str: Dict[Parameter, str] = {
1114
Parameter.POSITIONAL_OR_KEYWORD: "normal argument",
@@ -16,7 +19,9 @@
1619
}
1720

1821

19-
def _test_signature(func: Callable, stub: FunctionType):
22+
def _test_signature(
23+
func: Callable, stub: FunctionType, ignore_first_stub_param: bool = False
24+
):
2025
"""
2126
Signature of function is correct enough to not affect interoperability
2227
@@ -34,9 +39,16 @@ def add(x1, x2, /):
3439
pytest.skip(
3540
msg=f"type({stub.__name__})={type(func)} not supported by inspect.signature()"
3641
)
37-
stub_sig = signature(stub)
3842
params = list(sig.parameters.values())
43+
44+
stub_sig = signature(stub)
3945
stub_params = list(stub_sig.parameters.values())
46+
if ignore_first_stub_param:
47+
stub_params = stub_params[1:]
48+
stub = Signature(
49+
parameters=stub_params, return_annotation=stub_sig.return_annotation
50+
)
51+
4052
# We're not interested if the array module has additional arguments, so we
4153
# only iterate through the arguments listed in the spec.
4254
for i, stub_param in enumerate(stub_params):
@@ -87,7 +99,7 @@ def add(x1, x2, /):
8799
[s for stubs in category_to_funcs.values() for s in stubs],
88100
ids=lambda f: f.__name__,
89101
)
90-
def test_signature(stub: FunctionType):
102+
def test_func_signature(stub: FunctionType):
91103
assert hasattr(xp, stub.__name__), f"{stub.__name__} not found in array module"
92104
func = getattr(xp, stub.__name__)
93105
_test_signature(func, stub)
@@ -96,21 +108,26 @@ def test_signature(stub: FunctionType):
96108
extension_and_stub_params = []
97109
for ext, stubs in extension_to_funcs.items():
98110
for stub in stubs:
99-
extension_and_stub_params.append(
100-
pytest.param(
101-
ext,
102-
stub,
103-
id=f"{ext}.{stub.__name__}",
104-
marks=pytest.mark.xp_extension(ext),
105-
)
111+
p = pytest.param(
112+
ext, stub, id=f"{ext}.{stub.__name__}", marks=pytest.mark.xp_extension(ext)
106113
)
114+
extension_and_stub_params.append(p)
107115

108116

109117
@pytest.mark.parametrize("extension, stub", extension_and_stub_params)
110-
def test_extension_signature(extension: str, stub: FunctionType):
118+
def test_extension_func_signature(extension: str, stub: FunctionType):
111119
mod = getattr(xp, extension)
112120
assert hasattr(
113121
mod, stub.__name__
114122
), f"{stub.__name__} not found in {extension} extension"
115123
func = getattr(mod, stub.__name__)
116124
_test_signature(func, stub)
125+
126+
127+
@pytest.mark.parametrize("stub", array_methods, ids=lambda f: f.__name__)
128+
@given(x=xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes()))
129+
def test_array_method_signature(stub: FunctionType, x):
130+
assert hasattr(x, stub.__name__), f"{stub.__name__} not found in array object {x!r}"
131+
method = getattr(x, stub.__name__)
132+
# Ignore 'self' arg in stub, which won't be present in instantiated objects.
133+
_test_signature(method, stub, ignore_first_stub_param=True)

0 commit comments

Comments
 (0)