Skip to content

Commit 5765ec0

Browse files
committed
Streamline uninspectable testing, just skipping awkward cases
1 parent b5c023e commit 5765ec0

File tree

1 file changed

+81
-123
lines changed

1 file changed

+81
-123
lines changed

array_api_tests/test_signatures.py

Lines changed: 81 additions & 123 deletions
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,13 @@ def squeeze(x, /, axis):
1818
...
1919
2020
"""
21-
from collections import defaultdict
2221
from copy import copy
2322
from inspect import Parameter, Signature, signature
2423
from itertools import chain
2524
from types import FunctionType
26-
from typing import Any, Callable, DefaultDict, Dict, List, Literal, Sequence, get_args
25+
from typing import Any, Callable, Dict, List, Literal, Sequence, get_args
2726

2827
import pytest
29-
from hypothesis import given, note
3028
from hypothesis import strategies as st
3129

3230
from . import dtype_helpers as dh
@@ -35,7 +33,7 @@ def squeeze(x, /, axis):
3533
from ._array_module import _UndefinedStub
3634
from ._array_module import mod as xp
3735
from .stubs import array_methods, category_to_funcs, extension_to_funcs
38-
from .typing import DataType, Shape
36+
from .typing import DataType
3937

4038
pytestmark = pytest.mark.ci
4139

@@ -53,7 +51,7 @@ def squeeze(x, /, axis):
5351
Parameter.POSITIONAL_ONLY: "pos-only argument",
5452
Parameter.KEYWORD_ONLY: "keyword-only argument",
5553
Parameter.VAR_POSITIONAL: "star-args (i.e. *args) argument",
56-
Parameter.VAR_KEYWORD: "star-kwargs (i.e. **kwargs) argument",
54+
Parameter.VAR_KEYWORD: "star-kwonly (i.e. **kwonly) argument",
5755
}
5856

5957

@@ -63,14 +61,13 @@ def _test_inspectable_func(sig: Signature, stub_sig: Signature):
6361
# We're not interested if the array module has additional arguments, so we
6462
# only iterate through the arguments listed in the spec.
6563
for i, stub_param in enumerate(stub_params):
66-
if sig is not None:
67-
assert (
68-
len(params) >= i + 1
69-
), f"Argument '{stub_param.name}' missing from signature"
70-
param = params[i]
64+
assert (
65+
len(params) >= i + 1
66+
), f"Argument '{stub_param.name}' missing from signature"
67+
param = params[i]
7168

7269
# We're not interested in the name if it isn't actually used
73-
if sig is not None and stub_param.kind not in [
70+
if stub_param.kind not in [
7471
Parameter.POSITIONAL_ONLY,
7572
*VAR_KINDS,
7673
]:
@@ -80,50 +77,17 @@ def _test_inspectable_func(sig: Signature, stub_sig: Signature):
8077

8178
f_stub_kind = kind_to_str[stub_param.kind]
8279
if stub_param.kind in [Parameter.POSITIONAL_OR_KEYWORD, *VAR_KINDS]:
83-
if sig is not None:
84-
assert param.kind == stub_param.kind, (
85-
f"{param.name} is a {kind_to_str[param.kind]}, "
86-
f"but should be a {f_stub_kind}"
87-
)
88-
else:
89-
pass
80+
assert param.kind == stub_param.kind, (
81+
f"{param.name} is a {kind_to_str[param.kind]}, "
82+
f"but should be a {f_stub_kind}"
83+
)
9084
else:
9185
# TODO: allow for kw-only args to be out-of-order
92-
if sig is not None:
93-
assert param.kind in [
94-
stub_param.kind,
95-
Parameter.POSITIONAL_OR_KEYWORD,
96-
], (
97-
f"{param.name} is a {kind_to_str[param.kind]}, "
98-
f"but should be a {f_stub_kind} "
99-
f"(or at least a {kind_to_str[ParameterKind.POSITIONAL_OR_KEYWORD]})"
100-
)
101-
else:
102-
pass
103-
104-
105-
def shapes(**kw) -> st.SearchStrategy[Shape]:
106-
if "min_side" not in kw.keys():
107-
kw["min_side"] = 1
108-
return hh.shapes(**kw)
109-
110-
111-
matrixy_funcs: List[str] = [
112-
f.__name__
113-
for f in chain(category_to_funcs["linear_algebra"], extension_to_funcs["linalg"])
114-
]
115-
matrixy_funcs += ["__matmul__", "triu", "tril"]
116-
func_to_shapes: DefaultDict[str, st.SearchStrategy[Shape]] = defaultdict(
117-
shapes,
118-
{
119-
**{k: st.just(()) for k in ["__bool__", "__int__", "__index__", "__float__"]},
120-
"sort": shapes(min_dims=1), # for axis=-1,
121-
**{k: shapes(min_dims=2) for k in matrixy_funcs},
122-
# Overwrite min_dims=2 shapes for some matrixy functions
123-
"cross": shapes(min_side=3, max_side=3, min_dims=3, max_dims=3),
124-
"outer": shapes(min_dims=1, max_dims=1),
125-
},
126-
)
86+
assert param.kind in [stub_param.kind, Parameter.POSITIONAL_OR_KEYWORD,], (
87+
f"{param.name} is a {kind_to_str[param.kind]}, "
88+
f"but should be a {f_stub_kind} "
89+
f"(or at least a {kind_to_str[ParameterKind.POSITIONAL_OR_KEYWORD]})"
90+
)
12791

12892

12993
def get_dtypes_strategy(func_name: str) -> st.SearchStrategy[DataType]:
@@ -136,97 +100,93 @@ def get_dtypes_strategy(func_name: str) -> st.SearchStrategy[DataType]:
136100
return xps.scalar_dtypes()
137101

138102

139-
func_to_example_values: Dict[str, Dict[ParameterKind, Dict[str, Any]]] = {
140-
"broadcast_to": {
141-
Parameter.POSITIONAL_ONLY: {"x": xp.asarray([0, 1])},
142-
Parameter.POSITIONAL_OR_KEYWORD: {"shape": (1, 2)},
143-
},
144-
"cholesky": {
145-
Parameter.POSITIONAL_ONLY: {"x": xp.asarray([[1.0, 0.0], [0.0, 1.0]])}
146-
},
147-
"inv": {Parameter.POSITIONAL_ONLY: {"x": xp.asarray([[1.0, 0.0], [0.0, 1.0]])}},
148-
}
149-
150-
151-
def make_pretty_func(func_name: str, args: Sequence[Any], kwargs: Dict[str, Any]):
103+
def make_pretty_func(func_name: str, args: Sequence[Any], kwonly: Dict[str, Any]):
152104
f_sig = f"{func_name}("
153105
f_sig += ", ".join(str(a) for a in args)
154-
if len(kwargs) != 0:
106+
if len(kwonly) != 0:
155107
if len(args) != 0:
156108
f_sig += ", "
157-
f_sig += ", ".join(f"{k}={v}" for k, v in kwargs.items())
109+
f_sig += ", ".join(f"{k}={v}" for k, v in kwonly.items())
158110
f_sig += ")"
159111
return f_sig
160112

161113

162-
@given(data=st.data())
163-
def _test_uninspectable_func(func_name: str, func: Callable, stub_sig: Signature, data):
164-
example_values: Dict[ParameterKind, Dict[str, Any]] = func_to_example_values.get(
165-
func_name, {}
166-
)
167-
for kind in ALL_KINDS:
168-
example_values.setdefault(kind, {})
114+
matrixy_funcs: List[str] = [
115+
f.__name__
116+
for f in chain(category_to_funcs["linear_algebra"], extension_to_funcs["linalg"])
117+
]
118+
matrixy_funcs += ["__matmul__", "triu", "tril"]
169119

170-
for param in stub_sig.parameters.values():
171-
for name_to_value in example_values.values():
172-
if param.name in name_to_value.keys():
173-
continue
174120

175-
if param.default != Parameter.empty:
176-
example_value = param.default
121+
def _test_uninspectable_func(func_name: str, func: Callable, stub_sig: Signature):
122+
skip_msg = (
123+
f"Signature for {func_name}() is not inspectable "
124+
"and is too troublesome to test for otherwise"
125+
)
126+
if func_name in [
127+
"__bool__",
128+
"__int__",
129+
"__index__",
130+
"__float__",
131+
"pow",
132+
"bitwise_left_shift",
133+
"bitwise_right_shift",
134+
"broadcast_to",
135+
"permute_dims",
136+
"sort",
137+
*matrixy_funcs,
138+
]:
139+
pytest.skip(skip_msg)
140+
141+
param_to_value: Dict[Parameter, Any] = {}
142+
for param in stub_sig.parameters.values():
143+
if param.kind in VAR_KINDS:
144+
pytest.skip(skip_msg)
145+
elif param.default != Parameter.empty:
146+
value = param.default
177147
elif param.name in ["x", "x1"]:
178148
dtypes = get_dtypes_strategy(func_name)
179-
shapes = func_to_shapes[func_name]
180-
example_value = data.draw(
181-
xps.arrays(dtype=dtypes, shape=shapes), label=param.name
182-
)
149+
value = xps.arrays(dtype=dtypes, shape=hh.shapes(min_side=1)).example()
183150
elif param.name == "x2":
184151
# sanity check
185-
assert "x1" in example_values[Parameter.POSITIONAL_ONLY].keys()
186-
x1 = example_values[Parameter.POSITIONAL_ONLY]["x1"]
187-
example_value = data.draw(
188-
xps.arrays(dtype=x1.dtype, shape=x1.shape), label="x2"
189-
)
190-
elif param.name == "axes":
191-
example_value = ()
192-
elif param.name == "shape":
193-
example_value = ()
152+
assert "x1" in [p.name for p in param_to_value.keys()]
153+
x1 = next(v for p, v in param_to_value.items() if p.name == "x1")
154+
value = xps.arrays(dtype=x1.dtype, shape=x1.shape).example()
194155
else:
195-
pytest.skip(f"No example value for argument '{param.name}'")
196-
197-
if param.kind in VAR_KINDS:
198-
pytest.skip("TODO")
199-
example_values[param.kind][param.name] = example_value
200-
201-
if len(example_values[Parameter.POSITIONAL_OR_KEYWORD]) == 0:
202-
f_func = make_pretty_func(
203-
func_name,
204-
example_values[Parameter.POSITIONAL_ONLY].values(),
205-
example_values[Parameter.KEYWORD_ONLY],
206-
)
207-
note(f"trying {f_func}")
208-
func(
209-
*example_values[Parameter.POSITIONAL_ONLY].values(),
210-
**example_values[Parameter.KEYWORD_ONLY],
211-
)
156+
pytest.skip(skip_msg)
157+
param_to_value[param] = value
158+
159+
posonly: List[Any] = [
160+
v for p, v in param_to_value.items() if p.kind == Parameter.POSITIONAL_ONLY
161+
]
162+
kwonly: Dict[str, Any] = {
163+
p.name: v for p, v in param_to_value.items() if p.kind == Parameter.KEYWORD_ONLY
164+
}
165+
if (
166+
sum(p.kind == Parameter.POSITIONAL_OR_KEYWORD for p in param_to_value.keys())
167+
== 0
168+
):
169+
f_func = make_pretty_func(func_name, posonly, kwonly)
170+
print(f"trying {f_func}")
171+
func(*posonly, **kwonly)
212172
else:
213173
either_argname_value_pairs = list(
214-
example_values[Parameter.POSITIONAL_OR_KEYWORD].items()
174+
(p.name, v)
175+
for p, v in param_to_value.items()
176+
if p.kind == Parameter.POSITIONAL_OR_KEYWORD
215177
)
216178
n_either_args = len(either_argname_value_pairs)
217179
for n_extra_args in reversed(range(n_either_args + 1)):
218-
extra_args = [v for _, v in either_argname_value_pairs[:n_extra_args]]
180+
extra_posargs = [v for _, v in either_argname_value_pairs[:n_extra_args]]
219181
if n_extra_args < n_either_args:
220182
extra_kwargs = dict(either_argname_value_pairs[n_extra_args:])
221183
else:
222184
extra_kwargs = {}
223-
args = list(example_values[Parameter.POSITIONAL_ONLY].values())
224-
args += extra_args
225-
kwargs = copy(example_values[Parameter.KEYWORD_ONLY])
226-
if len(extra_kwargs) != 0:
227-
kwargs.update(extra_kwargs)
185+
args = copy(posonly)
186+
args += extra_posargs
187+
kwargs = {**kwonly, **extra_kwargs}
228188
f_func = make_pretty_func(func_name, args, kwargs)
229-
note(f"trying {f_func}")
189+
print(f"trying {f_func}")
230190
func(*args, **kwargs)
231191

232192

@@ -279,11 +239,9 @@ def test_extension_func_signature(extension: str, stub: FunctionType):
279239

280240

281241
@pytest.mark.parametrize("stub", array_methods, ids=lambda f: f.__name__)
282-
@given(data=st.data())
283-
def test_array_method_signature(stub: FunctionType, data):
242+
def test_array_method_signature(stub: FunctionType):
284243
dtypes = get_dtypes_strategy(stub.__name__)
285-
shapes = func_to_shapes[stub.__name__]
286-
x = data.draw(xps.arrays(dtype=dtypes, shape=shapes), label="x")
244+
x = xps.arrays(dtype=dtypes, shape=hh.shapes(min_side=1)).example()
287245
assert hasattr(x, stub.__name__), f"{stub.__name__} not found in array object {x!r}"
288246
method = getattr(x, stub.__name__)
289247
# Ignore 'self' arg in stub, which won't be present in instantiated objects.

0 commit comments

Comments
 (0)