Skip to content

Commit 2c32ea1

Browse files
committed
Rudimentary unified signature tests
Just top-level functions for now
1 parent 88d8236 commit 2c32ea1

File tree

2 files changed

+79
-281
lines changed

2 files changed

+79
-281
lines changed

array_api_tests/stubs.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from types import FunctionType, ModuleType
66
from typing import Dict, List
77

8-
__all__ = ["category_to_funcs", "array", "extension_to_funcs"]
8+
__all__ = ["category_to_funcs", "array", "EXTENSIONS", "extension_to_funcs"]
99

1010

1111
spec_dir = Path(__file__).parent.parent / "array-api" / "spec" / "API_specification"
@@ -31,11 +31,10 @@
3131
assert all(isinstance(o, FunctionType) for o in objects)
3232
category_to_funcs[category] = objects
3333

34-
3534
array = name_to_mod["array_object"].array
3635

3736

38-
EXTENSIONS = ["linalg"]
37+
EXTENSIONS: str = ["linalg"]
3938
extension_to_funcs: Dict[str, List[FunctionType]] = {}
4039
for ext in EXTENSIONS:
4140
mod = name_to_mod[ext]

array_api_tests/test_signatures.py

Lines changed: 77 additions & 278 deletions
Original file line numberDiff line numberDiff line change
@@ -1,289 +1,88 @@
1-
import inspect
1+
from inspect import Parameter, signature
2+
from types import FunctionType
3+
from typing import Dict
24

35
import pytest
46

5-
from ._array_module import mod, mod_name, ones, eye, float64, bool, int64, _UndefinedStub
6-
from .pytest_helpers import raises, doesnt_raise
7-
from . import dtype_helpers as dh
7+
from ._array_module import mod as xp
8+
from .stubs import category_to_funcs
89

9-
from . import function_stubs
10+
kind_to_str: Dict[Parameter, str] = {
11+
Parameter.POSITIONAL_OR_KEYWORD: "normal argument",
12+
Parameter.POSITIONAL_ONLY: "pos-only argument",
13+
Parameter.KEYWORD_ONLY: "keyword-only argument",
14+
Parameter.VAR_POSITIONAL: "star-args (i.e. *args) argument",
15+
Parameter.VAR_KEYWORD: "star-kwargs (i.e. **kwargs) argument",
16+
}
1017

1118

12-
submodules = [m for m in dir(function_stubs) if
13-
inspect.ismodule(getattr(function_stubs, m)) and not
14-
m.startswith('_')]
15-
16-
def stub_module(name):
17-
for m in submodules:
18-
if name in getattr(function_stubs, m).__all__:
19-
return m
20-
21-
def extension_module(name):
22-
return name in submodules and name in function_stubs.__all__
23-
24-
extension_module_names = []
25-
for n in function_stubs.__all__:
26-
if extension_module(n):
27-
extension_module_names.extend([f'{n}.{i}' for i in getattr(function_stubs, n).__all__])
28-
29-
30-
params = []
31-
for name in function_stubs.__all__:
32-
marks = []
33-
if extension_module(name):
34-
marks.append(pytest.mark.xp_extension(name))
35-
params.append(pytest.param(name, marks=marks))
36-
for name in extension_module_names:
37-
ext = name.split('.')[0]
38-
mark = pytest.mark.xp_extension(ext)
39-
params.append(pytest.param(name, marks=[mark]))
40-
41-
42-
def array_method(name):
43-
return stub_module(name) == 'array_object'
44-
45-
def function_category(name):
46-
return stub_module(name).rsplit('_', 1)[0].replace('_', ' ')
47-
48-
def example_argument(arg, func_name, dtype):
49-
"""
50-
Get an example argument for the argument arg for the function func_name
51-
52-
The full tests for function behavior is in other files. We just need to
53-
have an example input for each argument name that should work so that we
54-
can check if the argument is implemented at all.
55-
19+
@pytest.mark.parametrize(
20+
"stub",
21+
[s for stubs in category_to_funcs.values() for s in stubs],
22+
ids=lambda f: f.__name__,
23+
)
24+
def test_signature(stub: FunctionType):
5625
"""
57-
# Note: for keyword arguments that have a default, this should be
58-
# different from the default, as the default argument is tested separately
59-
# (it can have the same behavior as the default, just not literally the
60-
# same value).
61-
known_args = dict(
62-
api_version='2021.1',
63-
arrays=(ones((1, 3, 3), dtype=dtype), ones((1, 3, 3), dtype=dtype)),
64-
# These cannot be the same as each other, which is why all our test
65-
# arrays have to have at least 3 dimensions.
66-
axis1=2,
67-
axis2=2,
68-
axis=1,
69-
axes=(2, 1, 0),
70-
copy=True,
71-
correction=1.0,
72-
descending=True,
73-
# TODO: This will only work on the NumPy implementation. The exact
74-
# value of the device keyword will vary across implementations, so we
75-
# need some way to infer it or for libraries to specify a list of
76-
# valid devices.
77-
device='cpu',
78-
dtype=float64,
79-
endpoint=False,
80-
fill_value=1.0,
81-
from_=int64,
82-
full_matrices=False,
83-
k=1,
84-
keepdims=True,
85-
key=(0, 0),
86-
indexing='ij',
87-
mode='complete',
88-
n=2,
89-
n_cols=1,
90-
n_rows=1,
91-
num=2,
92-
offset=1,
93-
ord=1,
94-
obj = [[[1, 1, 1], [1, 1, 1], [1, 1, 1]]],
95-
other=ones((3, 3), dtype=dtype),
96-
return_counts=True,
97-
return_index=True,
98-
return_inverse=True,
99-
rtol=1e-10,
100-
self=ones((3, 3), dtype=dtype),
101-
shape=(1, 3, 3),
102-
shift=1,
103-
sorted=False,
104-
stable=False,
105-
start=0,
106-
step=2,
107-
stop=1,
108-
# TODO: Update this to be non-default. See the comment on "device" above.
109-
stream=None,
110-
to=float64,
111-
type=float64,
112-
upper=True,
113-
value=0,
114-
x1=ones((1, 3, 3), dtype=dtype),
115-
x2=ones((1, 3, 3), dtype=dtype),
116-
x=ones((1, 3, 3), dtype=dtype),
117-
)
118-
if not isinstance(bool, _UndefinedStub):
119-
known_args['condition'] = ones((1, 3, 3), dtype=bool),
26+
Signature of function is correct enough to not affect interoperability
12027
121-
if arg in known_args:
122-
# Special cases:
123-
124-
# squeeze() requires an axis of size 1, but other functions such as
125-
# cross() require axes of size >1
126-
if func_name == 'squeeze' and arg == 'axis':
127-
return 0
128-
# ones() is not invertible
129-
# finfo requires a float dtype and iinfo requires an int dtype
130-
elif func_name == 'iinfo' and arg == 'type':
131-
return int64
132-
# tensordot args must be contractible with each other
133-
elif func_name == 'tensordot' and arg == 'x2':
134-
return ones((3, 3, 1), dtype=dtype)
135-
# tensordot "axes" is either a number representing the number of
136-
# contractible axes or a 2-tuple or axes
137-
elif func_name == 'tensordot' and arg == 'axes':
138-
return 1
139-
# The inputs to outer() must be 1-dimensional
140-
elif func_name == 'outer' and arg in ['x1', 'x2']:
141-
return ones((3,), dtype=dtype)
142-
# Linear algebra functions tend to error if the input isn't "nice" as
143-
# a matrix
144-
elif arg.startswith('x') and func_name in function_stubs.linalg.__all__:
145-
return eye(3)
146-
return known_args[arg]
147-
else:
148-
raise RuntimeError(f"Don't know how to test argument {arg}. Please update test_signatures.py")
149-
150-
@pytest.mark.parametrize('name', params)
151-
def test_has_names(name):
152-
if extension_module(name):
153-
assert hasattr(mod, name), f'{mod_name} is missing the {name} extension'
154-
elif '.' in name:
155-
extension_mod, name = name.split('.')
156-
assert hasattr(getattr(mod, extension_mod), name), f"{mod_name} is missing the {function_category(name)} extension function {name}()"
157-
elif array_method(name):
158-
arr = ones((1, 1))
159-
if getattr(function_stubs.array_object, name) is None:
160-
assert hasattr(arr, name), f"The array object is missing the attribute {name}"
161-
else:
162-
assert hasattr(arr, name), f"The array object is missing the method {name}()"
163-
else:
164-
assert hasattr(mod, name), f"{mod_name} is missing the {function_category(name)} function {name}()"
28+
We're not interested in being 100% strict - instead we focus on areas which
29+
could affect interop, e.g. with
16530
166-
@pytest.mark.parametrize('name', params)
167-
def test_function_positional_args(name):
168-
# Note: We can't actually test that positional arguments are
169-
# positional-only, as that would require knowing the argument name and
170-
# checking that it can't be used as a keyword argument. But argument name
171-
# inspection does not work for most array library functions that are not
172-
# written in pure Python (e.g., it won't work for numpy ufuncs).
31+
def add(x1, x2, /):
32+
...
17333
174-
if extension_module(name):
175-
return
176-
177-
dtype = None
178-
if (name.startswith('__i') and name not in ['__int__', '__invert__', '__index__']
179-
or name.startswith('__r') and name != '__rshift__'):
180-
n = f'__{name[3:]}'
181-
else:
182-
n = name
183-
in_dtypes = dh.func_in_dtypes.get(n, dh.float_dtypes)
184-
if bool in in_dtypes:
185-
dtype = bool
186-
elif all(d in in_dtypes for d in dh.all_int_dtypes):
187-
dtype = int64
188-
189-
if array_method(name):
190-
if name == '__bool__':
191-
_mod = ones((), dtype=bool)
192-
elif name in ['__int__', '__index__']:
193-
_mod = ones((), dtype=int64)
194-
elif name == '__float__':
195-
_mod = ones((), dtype=float64)
196-
else:
197-
_mod = example_argument('self', name, dtype)
198-
stub_func = getattr(function_stubs, name)
199-
elif '.' in name:
200-
extension_module_name, name = name.split('.')
201-
_mod = getattr(mod, extension_module_name)
202-
stub_func = getattr(getattr(function_stubs, extension_module_name), name)
203-
else:
204-
_mod = mod
205-
stub_func = getattr(function_stubs, name)
206-
207-
if not hasattr(_mod, name):
208-
pytest.skip(f"{mod_name} does not have {name}(), skipping.")
209-
if stub_func is None:
210-
# TODO: Can we make this skip the parameterization entirely?
211-
pytest.skip(f"{name} is not a function, skipping.")
212-
mod_func = getattr(_mod, name)
213-
argspec = inspect.getfullargspec(stub_func)
214-
func_args = argspec.args
215-
if func_args[:1] == ['self']:
216-
func_args = func_args[1:]
217-
nargs = [len(func_args)]
218-
if argspec.defaults:
219-
# The actual default values are checked in the specific tests
220-
nargs.extend([len(func_args) - i for i in range(1, len(argspec.defaults) + 1)])
221-
222-
args = [example_argument(arg, name, dtype) for arg in func_args]
223-
if not args:
224-
args = [example_argument('x', name, dtype)]
225-
else:
226-
# Duplicate the last positional argument for the n+1 test.
227-
args = args + [args[-1]]
228-
229-
kwonlydefaults = argspec.kwonlydefaults or {}
230-
required_kwargs = {arg: example_argument(arg, name, dtype) for arg in argspec.kwonlyargs if arg not in kwonlydefaults}
231-
232-
for n in range(nargs[0]+2):
233-
if name == 'result_type' and n == 0:
234-
# This case is not encoded in the signature, but isn't allowed.
235-
continue
236-
if n in nargs:
237-
doesnt_raise(lambda: mod_func(*args[:n], **required_kwargs))
238-
elif argspec.varargs:
239-
pass
34+
x1 and x2 don't need to be pos-only for the purposes of interoperability.
35+
"""
36+
assert hasattr(xp, stub.__name__), f"{stub.__name__} not found in array module"
37+
func = getattr(xp, stub.__name__)
38+
39+
try:
40+
sig = signature(func)
41+
except ValueError:
42+
pytest.skip(msg=f"type({stub.__name__})={type(func)} not supported by inspect")
43+
stub_sig = signature(stub)
44+
params = list(sig.parameters.values())
45+
stub_params = list(stub_sig.parameters.values())
46+
# We're not interested if the array module has additional arguments, so we
47+
# only iterate through the arguments listed in the spec.
48+
for i, stub_param in enumerate(stub_params):
49+
assert (
50+
len(params) >= i + 1
51+
), f"Argument '{stub_param.name}' missing from signature"
52+
param = params[i]
53+
54+
# We're not interested in the name if it isn't actually used
55+
if stub_param.kind not in [
56+
Parameter.POSITIONAL_ONLY,
57+
Parameter.VAR_POSITIONAL,
58+
Parameter.VAR_KEYWORD,
59+
]:
60+
assert (
61+
param.name == stub_param.name
62+
), f"Expected argument '{param.name}' to be named '{stub_param.name}'"
63+
64+
if (
65+
stub_param.name in ["x", "x1", "x2"]
66+
and stub_param.kind != Parameter.POSITIONAL_ONLY
67+
):
68+
pytest.skip(
69+
f"faulty spec - {stub_param.name} should be a "
70+
f"{kind_to_str[Parameter.POSITIONAL_OR_KEYWORD]}"
71+
)
72+
f_kind = kind_to_str[param.kind]
73+
f_stub_kind = kind_to_str[stub_param.kind]
74+
if stub_param.kind in [
75+
Parameter.POSITIONAL_OR_KEYWORD,
76+
Parameter.VAR_POSITIONAL,
77+
Parameter.VAR_KEYWORD,
78+
]:
79+
assert param.kind == stub_param.kind, (
80+
f"{param.name} is a {f_kind}, " f"but should be a {f_stub_kind}"
81+
)
24082
else:
241-
# NumPy ufuncs raise ValueError instead of TypeError
242-
raises((TypeError, ValueError), lambda: mod_func(*args[:n]), f"{name}() should not accept {n} positional arguments")
243-
244-
@pytest.mark.parametrize('name', params)
245-
def test_function_keyword_only_args(name):
246-
if extension_module(name):
247-
return
248-
249-
if array_method(name):
250-
_mod = ones((1, 1))
251-
stub_func = getattr(function_stubs, name)
252-
elif '.' in name:
253-
extension_module_name, name = name.split('.')
254-
_mod = getattr(mod, extension_module_name)
255-
stub_func = getattr(getattr(function_stubs, extension_module_name), name)
256-
else:
257-
_mod = mod
258-
stub_func = getattr(function_stubs, name)
259-
260-
if not hasattr(_mod, name):
261-
pytest.skip(f"{mod_name} does not have {name}(), skipping.")
262-
if stub_func is None:
263-
# TODO: Can we make this skip the parameterization entirely?
264-
pytest.skip(f"{name} is not a function, skipping.")
265-
mod_func = getattr(_mod, name)
266-
argspec = inspect.getfullargspec(stub_func)
267-
args = argspec.args
268-
if args[:1] == ['self']:
269-
args = args[1:]
270-
kwonlyargs = argspec.kwonlyargs
271-
kwonlydefaults = argspec.kwonlydefaults or {}
272-
dtype = None
273-
274-
args = [example_argument(arg, name, dtype) for arg in args]
275-
276-
for arg in kwonlyargs:
277-
value = example_argument(arg, name, dtype)
278-
# The "only" part of keyword-only is tested by the positional test above.
279-
doesnt_raise(lambda: mod_func(*args, **{arg: value}),
280-
f"{name}() should accept the keyword-only argument {arg!r}")
281-
282-
# Make sure the default is accepted. These tests are not granular
283-
# enough to test that the default is actually the default, i.e., gives
284-
# the same value if the keyword isn't passed. That is tested in the
285-
# specific function tests.
286-
if arg in kwonlydefaults:
287-
default_value = kwonlydefaults[arg]
288-
doesnt_raise(lambda: mod_func(*args, **{arg: default_value}),
289-
f"{name}() should accept the default value {default_value!r} for the keyword-only argument {arg!r}")
83+
# TODO: allow for kw-only args to be out-of-order
84+
assert param.kind in [stub_param.kind, Parameter.POSITIONAL_OR_KEYWORD], (
85+
f"{param.name} is a {f_kind}, "
86+
f"but should be a {f_stub_kind} "
87+
f"(or at least a {kind_to_str[Parameter.POSITIONAL_OR_KEYWORD]})"
88+
)

0 commit comments

Comments
 (0)