Skip to content

Commit 4f161dd

Browse files
committed
Fix some issues with the signatures tests
1 parent ceeec55 commit 4f161dd

File tree

1 file changed

+13
-5
lines changed

1 file changed

+13
-5
lines changed

array_api_tests/test_signatures.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ def example_argument(arg, func_name, dtype):
3737
# (it can have the same behavior as the default, just not literally the
3838
# same value).
3939
known_args = dict(
40+
api_version='2021.1',
4041
arrays=(ones((1, 3, 3), dtype=dtype), ones((1, 3, 3), dtype=dtype)),
4142
# These cannot be the same as each other, which is why all our test
4243
# arrays have to have at least 3 dimensions.
@@ -97,6 +98,13 @@ def example_argument(arg, func_name, dtype):
9798
# finfo requires a float dtype and iinfo requires an int dtype
9899
elif func_name == 'iinfo' and arg == 'type':
99100
return int64
101+
# tensordot args must be contractible with each other
102+
elif func_name == 'tensordot' and arg == 'x2':
103+
return ones((3, 3, 1), dtype=dtype)
104+
# tensordot "axes" is either a number representing the number of
105+
# contractible axes or a 2-tuple or axes
106+
elif func_name == 'tensordot' and arg == 'axes':
107+
return 1
100108
return known_args[arg]
101109
else:
102110
raise RuntimeError(f"Don't know how to test argument {arg}. Please update test_signatures.py")
@@ -151,15 +159,15 @@ def test_function_positional_args(name):
151159
pytest.skip(f"{name} is not a function, skipping.")
152160
mod_func = getattr(_mod, name)
153161
argspec = inspect.getfullargspec(stub_func)
154-
args = argspec.args
162+
func_args = argspec.args
155163
if name.startswith('__'):
156-
args = args[1:]
157-
nargs = [len(args)]
164+
func_args = func_args[1:]
165+
nargs = [len(func_args)]
158166
if argspec.defaults:
159167
# The actual default values are checked in the specific tests
160-
nargs.extend([len(args) - i for i in range(1, len(argspec.defaults) + 1)])
168+
nargs.extend([len(func_args) - i for i in range(1, len(argspec.defaults) + 1)])
161169

162-
args = [example_argument(arg, name, dtype) for arg in args]
170+
args = [example_argument(arg, name, dtype) for arg in func_args]
163171
if not args:
164172
args = [example_argument('x', name, dtype)]
165173
else:

0 commit comments

Comments
 (0)