Skip to content

Commit 422401e

Browse files
committed
Rudimentary support for uninspectable signatures
1 parent a1d92cb commit 422401e

File tree

2 files changed

+166
-74
lines changed

2 files changed

+166
-74
lines changed

array_api_tests/dtype_helpers.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,8 @@ def result_type(*dtypes: DataType):
199199
return result
200200

201201

202-
r_in_dtypes = re.compile("x1?: array\n.+Should have an? (.+) data type.")
202+
r_alias = re.compile("[aA]lias")
203+
r_in_dtypes = re.compile("x1?: array\n.+have an? (.+) data type.")
203204
r_int_note = re.compile(
204205
"If one or both of the input arrays have integer data types, "
205206
"the result is implementation-dependent"
@@ -285,6 +286,8 @@ def result_type(*dtypes: DataType):
285286
"trunc": False,
286287
# searching
287288
"where": False,
289+
# linalg
290+
"matmul": False,
288291
}
289292

290293

@@ -328,7 +331,7 @@ def result_type(*dtypes: DataType):
328331
"__gt__": "greater",
329332
"__le__": "less_equal",
330333
"__lt__": "less",
331-
# '__matmul__': 'matmul', # TODO: support matmul
334+
"__matmul__": "matmul",
332335
"__mod__": "remainder",
333336
"__mul__": "multiply",
334337
"__ne__": "not_equal",
@@ -364,6 +367,7 @@ def result_type(*dtypes: DataType):
364367
func_in_dtypes["__int__"] = all_int_dtypes
365368
func_in_dtypes["__index__"] = all_int_dtypes
366369
func_in_dtypes["__float__"] = float_dtypes
370+
func_in_dtypes["from_dlpack"] = numeric_dtypes
367371
func_in_dtypes["__dlpack__"] = numeric_dtypes
368372

369373

array_api_tests/test_signatures.py

Lines changed: 160 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,44 @@
1+
"""
2+
We're not interested in being 100% strict - instead we focus on areas which
3+
could affect interop, e.g. with
4+
5+
def add(x1, x2, /):
6+
...
7+
8+
x1 and x2 don't need to be pos-only for the purposes of interoperability, but with
9+
10+
def squeeze(x, /, axis):
11+
...
12+
13+
axis has to be pos-or-keyword to support both styles
14+
15+
>>> squeeze(x, 0)
16+
...
17+
>>> squeeze(x, axis=0)
18+
...
19+
20+
"""
21+
from collections import defaultdict
122
from inspect import Parameter, Signature, signature
23+
from itertools import chain
224
from types import FunctionType
3-
from typing import Callable, Dict
25+
from typing import Callable, DefaultDict, Dict, List
426

527
import pytest
628
from hypothesis import given
29+
from hypothesis import strategies as st
730

31+
from . import dtype_helpers as dh
832
from . import hypothesis_helpers as hh
933
from . import xps
34+
from ._array_module import _UndefinedStub
1035
from ._array_module import mod as xp
1136
from .stubs import array_methods, category_to_funcs, extension_to_funcs
37+
from .typing import DataType, Shape
1238

1339
pytestmark = pytest.mark.ci
1440

41+
1542
kind_to_str: Dict[Parameter, str] = {
1643
Parameter.POSITIONAL_OR_KEYWORD: "normal argument",
1744
Parameter.POSITIONAL_ONLY: "pos-only argument",
@@ -20,91 +47,149 @@
2047
Parameter.VAR_KEYWORD: "star-kwargs (i.e. **kwargs) argument",
2148
}
2249

50+
VAR_KINDS = (Parameter.VAR_POSITIONAL, Parameter.VAR_KEYWORD)
2351

24-
def _test_signature(
25-
func: Callable, stub: FunctionType, ignore_first_stub_param: bool = False
26-
):
27-
"""
28-
Signature of function is correct enough to not affect interoperability
29-
30-
We're not interested in being 100% strict - instead we focus on areas which
31-
could affect interop, e.g. with
32-
33-
def add(x1, x2, /):
34-
...
3552

36-
x1 and x2 don't need to be pos-only for the purposes of interoperability, but with
37-
38-
def squeeze(x, /, axis):
39-
...
40-
41-
axis has to be pos-or-keyword to support both styles
42-
43-
>>> squeeze(x, 0)
44-
...
45-
>>> squeeze(x, axis=0)
46-
...
47-
48-
"""
49-
try:
50-
sig = signature(func)
51-
except ValueError:
52-
pytest.skip(
53-
msg=f"type({stub.__name__})={type(func)} not supported by inspect.signature()"
54-
)
53+
def _test_inspectable_func(sig: Signature, stub_sig: Signature):
5554
params = list(sig.parameters.values())
56-
57-
stub_sig = signature(stub)
5855
stub_params = list(stub_sig.parameters.values())
59-
if ignore_first_stub_param:
60-
stub_params = stub_params[1:]
61-
stub = Signature(
62-
parameters=stub_params, return_annotation=stub_sig.return_annotation
63-
)
64-
6556
# We're not interested if the array module has additional arguments, so we
6657
# only iterate through the arguments listed in the spec.
6758
for i, stub_param in enumerate(stub_params):
68-
assert (
69-
len(params) >= i + 1
70-
), f"Argument '{stub_param.name}' missing from signature"
71-
param = params[i]
59+
if sig is not None:
60+
assert (
61+
len(params) >= i + 1
62+
), f"Argument '{stub_param.name}' missing from signature"
63+
param = params[i]
7264

7365
# We're not interested in the name if it isn't actually used
74-
if stub_param.kind not in [
66+
if sig is not None and stub_param.kind not in [
7567
Parameter.POSITIONAL_ONLY,
76-
Parameter.VAR_POSITIONAL,
77-
Parameter.VAR_KEYWORD,
68+
*VAR_KINDS,
7869
]:
7970
assert (
8071
param.name == stub_param.name
8172
), f"Expected argument '{param.name}' to be named '{stub_param.name}'"
8273

83-
if (
84-
stub_param.name in ["x", "x1", "x2"]
85-
and stub_param.kind != Parameter.POSITIONAL_ONLY
86-
):
87-
pytest.skip(
88-
f"faulty spec - argument {stub_param.name} should be a "
89-
f"{kind_to_str[Parameter.POSITIONAL_ONLY]}"
90-
)
91-
f_kind = kind_to_str[param.kind]
9274
f_stub_kind = kind_to_str[stub_param.kind]
93-
if stub_param.kind in [
94-
Parameter.POSITIONAL_OR_KEYWORD,
95-
Parameter.VAR_POSITIONAL,
96-
Parameter.VAR_KEYWORD,
97-
]:
98-
assert (
99-
param.kind == stub_param.kind
100-
), f"{param.name} is a {f_kind}, but should be a {f_stub_kind}"
75+
if stub_param.kind in [Parameter.POSITIONAL_OR_KEYWORD, *VAR_KINDS]:
76+
if sig is not None:
77+
assert param.kind == stub_param.kind, (
78+
f"{param.name} is a {kind_to_str[param.kind]}, "
79+
f"but should be a {f_stub_kind}"
80+
)
81+
else:
82+
pass
10183
else:
10284
# TODO: allow for kw-only args to be out-of-order
103-
assert param.kind in [stub_param.kind, Parameter.POSITIONAL_OR_KEYWORD], (
104-
f"{param.name} is a {f_kind}, "
105-
f"but should be a {f_stub_kind} "
106-
f"(or at least a {kind_to_str[Parameter.POSITIONAL_OR_KEYWORD]})"
85+
if sig is not None:
86+
assert param.kind in [
87+
stub_param.kind,
88+
Parameter.POSITIONAL_OR_KEYWORD,
89+
], (
90+
f"{param.name} is a {kind_to_str[param.kind]}, "
91+
f"but should be a {f_stub_kind} "
92+
f"(or at least a {kind_to_str[Parameter.POSITIONAL_OR_KEYWORD]})"
93+
)
94+
else:
95+
pass
96+
97+
def shapes(**kw) -> st.SearchStrategy[Shape]:
98+
if "min_side" not in kw.keys():
99+
kw["min_side"] = 1
100+
return hh.shapes(**kw)
101+
102+
103+
matrixy_funcs: List[str] = [
104+
f.__name__
105+
for f in chain(category_to_funcs["linear_algebra"], extension_to_funcs["linalg"])
106+
]
107+
matrixy_funcs += ["__matmul__", "triu", "tril"]
108+
func_to_shapes: DefaultDict[str, st.SearchStrategy[Shape]] = defaultdict(
109+
shapes,
110+
{
111+
**{k: st.just(()) for k in ["__bool__", "__int__", "__index__", "__float__"]},
112+
"sort": shapes(min_dims=1), # for axis=-1,
113+
**{k: shapes(min_dims=2) for k in matrixy_funcs},
114+
# Override for some matrixy functions
115+
"cross": shapes(min_side=3, max_side=3, min_dims=3, max_dims=3),
116+
"outer": shapes(min_dims=1, max_dims=1),
117+
},
118+
)
119+
120+
121+
def get_dtypes_strategy(func_name: str) -> st.SearchStrategy[DataType]:
122+
if func_name in dh.func_in_dtypes.keys():
123+
dtypes = dh.func_in_dtypes[func_name]
124+
if hh.FILTER_UNDEFINED_DTYPES:
125+
dtypes = [d for d in dtypes if not isinstance(d, _UndefinedStub)]
126+
return st.sampled_from(dtypes)
127+
else:
128+
return xps.scalar_dtypes()
129+
130+
131+
@given(data=st.data())
132+
def _test_uninspectable_func(func_name: str, func: Callable, stub_sig: Signature, data):
133+
if func_name in ["cholesky", "inv"]:
134+
func(xp.asarray([[1.0, 0.0], [0.0, 1.0]]))
135+
return
136+
elif func_name == "solve":
137+
func(xp.asarray([[1.0, 2.0], [3.0, 5.0]]), xp.asarray([1.0, 2.0]))
138+
return
139+
140+
pos_argname_to_example_value = {}
141+
normal_argname_to_example_value = {}
142+
kw_argname_to_example_value = {}
143+
for stub_param in stub_sig.parameters.values():
144+
if stub_param.name in ["x", "x1"]:
145+
dtypes = get_dtypes_strategy(func_name)
146+
shapes = func_to_shapes[func_name]
147+
example_value = data.draw(
148+
xps.arrays(dtype=dtypes, shape=shapes), label=stub_param.name
107149
)
150+
elif stub_param.name == "x2":
151+
assert "x1" in pos_argname_to_example_value.keys() # sanity check
152+
x1 = pos_argname_to_example_value["x1"]
153+
example_value = data.draw(
154+
xps.arrays(dtype=x1.dtype, shape=x1.shape), label="x2"
155+
)
156+
else:
157+
if stub_param.default != Parameter.empty:
158+
example_value = stub_param.default
159+
else:
160+
pytest.skip(f"No example value for argument '{stub_param.name}'")
161+
162+
if stub_param.kind == Parameter.POSITIONAL_ONLY:
163+
pos_argname_to_example_value[stub_param.name] = example_value
164+
elif stub_param.kind == Parameter.POSITIONAL_OR_KEYWORD:
165+
normal_argname_to_example_value[stub_param.name] = example_value
166+
elif stub_param.kind == Parameter.KEYWORD_ONLY:
167+
kw_argname_to_example_value[stub_param.name] = example_value
168+
else:
169+
pytest.skip()
170+
171+
if len(normal_argname_to_example_value) == 0:
172+
func(*pos_argname_to_example_value.values(), **kw_argname_to_example_value)
173+
else:
174+
pass # TODO
175+
176+
177+
def _test_func_signature(
178+
func: Callable, stub: FunctionType, ignore_first_stub_param: bool = False
179+
):
180+
stub_sig = signature(stub)
181+
if ignore_first_stub_param:
182+
stub_params = list(stub_sig.parameters.values())
183+
del stub_params[0]
184+
stub_sig = Signature(
185+
parameters=stub_params, return_annotation=stub_sig.return_annotation
186+
)
187+
188+
try:
189+
sig = signature(func)
190+
_test_inspectable_func(sig, stub_sig)
191+
except ValueError:
192+
_test_uninspectable_func(stub.__name__, func, stub_sig)
108193

109194

110195
@pytest.mark.parametrize(
@@ -115,7 +200,7 @@ def squeeze(x, /, axis):
115200
def test_func_signature(stub: FunctionType):
116201
assert hasattr(xp, stub.__name__), f"{stub.__name__} not found in array module"
117202
func = getattr(xp, stub.__name__)
118-
_test_signature(func, stub)
203+
_test_func_signature(func, stub)
119204

120205

121206
extension_and_stub_params = []
@@ -134,13 +219,16 @@ def test_extension_func_signature(extension: str, stub: FunctionType):
134219
mod, stub.__name__
135220
), f"{stub.__name__} not found in {extension} extension"
136221
func = getattr(mod, stub.__name__)
137-
_test_signature(func, stub)
222+
_test_func_signature(func, stub)
138223

139224

140225
@pytest.mark.parametrize("stub", array_methods, ids=lambda f: f.__name__)
141-
@given(x=xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes()))
142-
def test_array_method_signature(stub: FunctionType, x):
226+
@given(data=st.data())
227+
def test_array_method_signature(stub: FunctionType, data):
228+
dtypes = get_dtypes_strategy(stub.__name__)
229+
shapes = func_to_shapes[stub.__name__]
230+
x = data.draw(xps.arrays(dtype=dtypes, shape=shapes), label="x")
143231
assert hasattr(x, stub.__name__), f"{stub.__name__} not found in array object {x!r}"
144232
method = getattr(x, stub.__name__)
145233
# Ignore 'self' arg in stub, which won't be present in instantiated objects.
146-
_test_signature(method, stub, ignore_first_stub_param=True)
234+
_test_func_signature(method, stub, ignore_first_stub_param=True)

0 commit comments

Comments
 (0)