Skip to content

Commit b5c023e

Browse files
committed
Test different arg/kwarg arrangements for uninspectable normal args
1 parent 422401e commit b5c023e

File tree

1 file changed

+96
-40
lines changed

1 file changed

+96
-40
lines changed

array_api_tests/test_signatures.py

Lines changed: 96 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,14 @@ def squeeze(x, /, axis):
1919
2020
"""
2121
from collections import defaultdict
22+
from copy import copy
2223
from inspect import Parameter, Signature, signature
2324
from itertools import chain
2425
from types import FunctionType
25-
from typing import Callable, DefaultDict, Dict, List
26+
from typing import Any, Callable, DefaultDict, Dict, List, Literal, Sequence, get_args
2627

2728
import pytest
28-
from hypothesis import given
29+
from hypothesis import given, note
2930
from hypothesis import strategies as st
3031

3132
from . import dtype_helpers as dh
@@ -38,17 +39,23 @@ def squeeze(x, /, axis):
3839

3940
pytestmark = pytest.mark.ci
4041

41-
42-
kind_to_str: Dict[Parameter, str] = {
42+
ParameterKind = Literal[
43+
Parameter.POSITIONAL_ONLY,
44+
Parameter.VAR_POSITIONAL,
45+
Parameter.POSITIONAL_OR_KEYWORD,
46+
Parameter.KEYWORD_ONLY,
47+
Parameter.VAR_KEYWORD,
48+
]
49+
ALL_KINDS = get_args(ParameterKind)
50+
VAR_KINDS = (Parameter.VAR_POSITIONAL, Parameter.VAR_KEYWORD)
51+
kind_to_str: Dict[ParameterKind, str] = {
4352
Parameter.POSITIONAL_OR_KEYWORD: "normal argument",
4453
Parameter.POSITIONAL_ONLY: "pos-only argument",
4554
Parameter.KEYWORD_ONLY: "keyword-only argument",
4655
Parameter.VAR_POSITIONAL: "star-args (i.e. *args) argument",
4756
Parameter.VAR_KEYWORD: "star-kwargs (i.e. **kwargs) argument",
4857
}
4958

50-
VAR_KINDS = (Parameter.VAR_POSITIONAL, Parameter.VAR_KEYWORD)
51-
5259

5360
def _test_inspectable_func(sig: Signature, stub_sig: Signature):
5461
params = list(sig.parameters.values())
@@ -89,11 +96,12 @@ def _test_inspectable_func(sig: Signature, stub_sig: Signature):
8996
], (
9097
f"{param.name} is a {kind_to_str[param.kind]}, "
9198
f"but should be a {f_stub_kind} "
92-
f"(or at least a {kind_to_str[Parameter.POSITIONAL_OR_KEYWORD]})"
99+
f"(or at least a {kind_to_str[ParameterKind.POSITIONAL_OR_KEYWORD]})"
93100
)
94101
else:
95102
pass
96103

104+
97105
def shapes(**kw) -> st.SearchStrategy[Shape]:
98106
if "min_side" not in kw.keys():
99107
kw["min_side"] = 1
@@ -111,7 +119,7 @@ def shapes(**kw) -> st.SearchStrategy[Shape]:
111119
**{k: st.just(()) for k in ["__bool__", "__int__", "__index__", "__float__"]},
112120
"sort": shapes(min_dims=1), # for axis=-1,
113121
**{k: shapes(min_dims=2) for k in matrixy_funcs},
114-
# Override for some matrixy functions
122+
# Overwrite min_dims=2 shapes for some matrixy functions
115123
"cross": shapes(min_side=3, max_side=3, min_dims=3, max_dims=3),
116124
"outer": shapes(min_dims=1, max_dims=1),
117125
},
@@ -128,50 +136,98 @@ def get_dtypes_strategy(func_name: str) -> st.SearchStrategy[DataType]:
128136
return xps.scalar_dtypes()
129137

130138

139+
func_to_example_values: Dict[str, Dict[ParameterKind, Dict[str, Any]]] = {
140+
"broadcast_to": {
141+
Parameter.POSITIONAL_ONLY: {"x": xp.asarray([0, 1])},
142+
Parameter.POSITIONAL_OR_KEYWORD: {"shape": (1, 2)},
143+
},
144+
"cholesky": {
145+
Parameter.POSITIONAL_ONLY: {"x": xp.asarray([[1.0, 0.0], [0.0, 1.0]])}
146+
},
147+
"inv": {Parameter.POSITIONAL_ONLY: {"x": xp.asarray([[1.0, 0.0], [0.0, 1.0]])}},
148+
}
149+
150+
151+
def make_pretty_func(func_name: str, args: Sequence[Any], kwargs: Dict[str, Any]):
152+
f_sig = f"{func_name}("
153+
f_sig += ", ".join(str(a) for a in args)
154+
if len(kwargs) != 0:
155+
if len(args) != 0:
156+
f_sig += ", "
157+
f_sig += ", ".join(f"{k}={v}" for k, v in kwargs.items())
158+
f_sig += ")"
159+
return f_sig
160+
161+
131162
@given(data=st.data())
132163
def _test_uninspectable_func(func_name: str, func: Callable, stub_sig: Signature, data):
133-
if func_name in ["cholesky", "inv"]:
134-
func(xp.asarray([[1.0, 0.0], [0.0, 1.0]]))
135-
return
136-
elif func_name == "solve":
137-
func(xp.asarray([[1.0, 2.0], [3.0, 5.0]]), xp.asarray([1.0, 2.0]))
138-
return
139-
140-
pos_argname_to_example_value = {}
141-
normal_argname_to_example_value = {}
142-
kw_argname_to_example_value = {}
143-
for stub_param in stub_sig.parameters.values():
144-
if stub_param.name in ["x", "x1"]:
164+
example_values: Dict[ParameterKind, Dict[str, Any]] = func_to_example_values.get(
165+
func_name, {}
166+
)
167+
for kind in ALL_KINDS:
168+
example_values.setdefault(kind, {})
169+
170+
for param in stub_sig.parameters.values():
171+
for name_to_value in example_values.values():
172+
if param.name in name_to_value.keys():
173+
continue
174+
175+
if param.default != Parameter.empty:
176+
example_value = param.default
177+
elif param.name in ["x", "x1"]:
145178
dtypes = get_dtypes_strategy(func_name)
146179
shapes = func_to_shapes[func_name]
147180
example_value = data.draw(
148-
xps.arrays(dtype=dtypes, shape=shapes), label=stub_param.name
181+
xps.arrays(dtype=dtypes, shape=shapes), label=param.name
149182
)
150-
elif stub_param.name == "x2":
151-
assert "x1" in pos_argname_to_example_value.keys() # sanity check
152-
x1 = pos_argname_to_example_value["x1"]
183+
elif param.name == "x2":
184+
# sanity check
185+
assert "x1" in example_values[Parameter.POSITIONAL_ONLY].keys()
186+
x1 = example_values[Parameter.POSITIONAL_ONLY]["x1"]
153187
example_value = data.draw(
154188
xps.arrays(dtype=x1.dtype, shape=x1.shape), label="x2"
155189
)
190+
elif param.name == "axes":
191+
example_value = ()
192+
elif param.name == "shape":
193+
example_value = ()
156194
else:
157-
if stub_param.default != Parameter.empty:
158-
example_value = stub_param.default
159-
else:
160-
pytest.skip(f"No example value for argument '{stub_param.name}'")
161-
162-
if stub_param.kind == Parameter.POSITIONAL_ONLY:
163-
pos_argname_to_example_value[stub_param.name] = example_value
164-
elif stub_param.kind == Parameter.POSITIONAL_OR_KEYWORD:
165-
normal_argname_to_example_value[stub_param.name] = example_value
166-
elif stub_param.kind == Parameter.KEYWORD_ONLY:
167-
kw_argname_to_example_value[stub_param.name] = example_value
168-
else:
169-
pytest.skip()
195+
pytest.skip(f"No example value for argument '{param.name}'")
170196

171-
if len(normal_argname_to_example_value) == 0:
172-
func(*pos_argname_to_example_value.values(), **kw_argname_to_example_value)
197+
if param.kind in VAR_KINDS:
198+
pytest.skip("TODO")
199+
example_values[param.kind][param.name] = example_value
200+
201+
if len(example_values[Parameter.POSITIONAL_OR_KEYWORD]) == 0:
202+
f_func = make_pretty_func(
203+
func_name,
204+
example_values[Parameter.POSITIONAL_ONLY].values(),
205+
example_values[Parameter.KEYWORD_ONLY],
206+
)
207+
note(f"trying {f_func}")
208+
func(
209+
*example_values[Parameter.POSITIONAL_ONLY].values(),
210+
**example_values[Parameter.KEYWORD_ONLY],
211+
)
173212
else:
174-
pass # TODO
213+
either_argname_value_pairs = list(
214+
example_values[Parameter.POSITIONAL_OR_KEYWORD].items()
215+
)
216+
n_either_args = len(either_argname_value_pairs)
217+
for n_extra_args in reversed(range(n_either_args + 1)):
218+
extra_args = [v for _, v in either_argname_value_pairs[:n_extra_args]]
219+
if n_extra_args < n_either_args:
220+
extra_kwargs = dict(either_argname_value_pairs[n_extra_args:])
221+
else:
222+
extra_kwargs = {}
223+
args = list(example_values[Parameter.POSITIONAL_ONLY].values())
224+
args += extra_args
225+
kwargs = copy(example_values[Parameter.KEYWORD_ONLY])
226+
if len(extra_kwargs) != 0:
227+
kwargs.update(extra_kwargs)
228+
f_func = make_pretty_func(func_name, args, kwargs)
229+
note(f"trying {f_func}")
230+
func(*args, **kwargs)
175231

176232

177233
def _test_func_signature(

0 commit comments

Comments
 (0)