Skip to content

Commit afe08c8

Browse files
committed
Test extension signatures
1 parent 2c32ea1 commit afe08c8

File tree

1 file changed

+44
-16
lines changed

1 file changed

+44
-16
lines changed

array_api_tests/test_signatures.py

Lines changed: 44 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
from inspect import Parameter, signature
22
from types import FunctionType
3-
from typing import Dict
3+
from typing import Callable, Dict
44

55
import pytest
66

77
from ._array_module import mod as xp
8-
from .stubs import category_to_funcs
8+
from .stubs import category_to_funcs, extension_to_funcs
99

1010
kind_to_str: Dict[Parameter, str] = {
1111
Parameter.POSITIONAL_OR_KEYWORD: "normal argument",
@@ -16,12 +16,7 @@
1616
}
1717

1818

19-
@pytest.mark.parametrize(
20-
"stub",
21-
[s for stubs in category_to_funcs.values() for s in stubs],
22-
ids=lambda f: f.__name__,
23-
)
24-
def test_signature(stub: FunctionType):
19+
def _test_signature(func: Callable, stub: FunctionType):
2520
"""
2621
Signature of function is correct enough to not affect interoperability
2722
@@ -33,13 +28,12 @@ def add(x1, x2, /):
3328
3429
x1 and x2 don't need to be pos-only for the purposes of interoperability.
3530
"""
36-
assert hasattr(xp, stub.__name__), f"{stub.__name__} not found in array module"
37-
func = getattr(xp, stub.__name__)
38-
3931
try:
4032
sig = signature(func)
4133
except ValueError:
42-
pytest.skip(msg=f"type({stub.__name__})={type(func)} not supported by inspect")
34+
pytest.skip(
35+
msg=f"type({stub.__name__})={type(func)} not supported by inspect.signature()"
36+
)
4337
stub_sig = signature(stub)
4438
params = list(sig.parameters.values())
4539
stub_params = list(stub_sig.parameters.values())
@@ -66,7 +60,7 @@ def add(x1, x2, /):
6660
and stub_param.kind != Parameter.POSITIONAL_ONLY
6761
):
6862
pytest.skip(
69-
f"faulty spec - {stub_param.name} should be a "
63+
f"faulty spec - argument {stub_param.name} should be a "
7064
f"{kind_to_str[Parameter.POSITIONAL_OR_KEYWORD]}"
7165
)
7266
f_kind = kind_to_str[param.kind]
@@ -76,13 +70,47 @@ def add(x1, x2, /):
7670
Parameter.VAR_POSITIONAL,
7771
Parameter.VAR_KEYWORD,
7872
]:
79-
assert param.kind == stub_param.kind, (
80-
f"{param.name} is a {f_kind}, " f"but should be a {f_stub_kind}"
81-
)
73+
assert (
74+
param.kind == stub_param.kind
75+
), f"{param.name} is a {f_kind}, but should be a {f_stub_kind}"
8276
else:
8377
# TODO: allow for kw-only args to be out-of-order
8478
assert param.kind in [stub_param.kind, Parameter.POSITIONAL_OR_KEYWORD], (
8579
f"{param.name} is a {f_kind}, "
8680
f"but should be a {f_stub_kind} "
8781
f"(or at least a {kind_to_str[Parameter.POSITIONAL_OR_KEYWORD]})"
8882
)
83+
84+
85+
@pytest.mark.parametrize(
86+
"stub",
87+
[s for stubs in category_to_funcs.values() for s in stubs],
88+
ids=lambda f: f.__name__,
89+
)
90+
def test_signature(stub: FunctionType):
91+
assert hasattr(xp, stub.__name__), f"{stub.__name__} not found in array module"
92+
func = getattr(xp, stub.__name__)
93+
_test_signature(func, stub)
94+
95+
96+
extension_and_stub_params = []
97+
for ext, stubs in extension_to_funcs.items():
98+
for stub in stubs:
99+
extension_and_stub_params.append(
100+
pytest.param(
101+
ext,
102+
stub,
103+
id=f"{ext}.{stub.__name__}",
104+
marks=pytest.mark.xp_extension(ext),
105+
)
106+
)
107+
108+
109+
@pytest.mark.parametrize("extension, stub", extension_and_stub_params)
110+
def test_extension_signature(extension: str, stub: FunctionType):
111+
mod = getattr(xp, extension)
112+
assert hasattr(
113+
mod, stub.__name__
114+
), f"{stub.__name__} not found in {extension} extension"
115+
func = getattr(mod, stub.__name__)
116+
_test_signature(func, stub)

0 commit comments

Comments
 (0)