@@ -37,6 +37,7 @@ def example_argument(arg, func_name, dtype):
37
37
# (it can have the same behavior as the default, just not literally the
38
38
# same value).
39
39
known_args = dict (
40
+ api_version = '2021.1' ,
40
41
arrays = (ones ((1 , 3 , 3 ), dtype = dtype ), ones ((1 , 3 , 3 ), dtype = dtype )),
41
42
# These cannot be the same as each other, which is why all our test
42
43
# arrays have to have at least 3 dimensions.
@@ -97,6 +98,13 @@ def example_argument(arg, func_name, dtype):
97
98
# finfo requires a float dtype and iinfo requires an int dtype
98
99
elif func_name == 'iinfo' and arg == 'type' :
99
100
return int64
101
+ # tensordot args must be contractible with each other
102
+ elif func_name == 'tensordot' and arg == 'x2' :
103
+ return ones ((3 , 3 , 1 ), dtype = dtype )
104
+ # tensordot "axes" is either a number representing the number of
105
+ # contractible axes or a 2-tuple or axes
106
+ elif func_name == 'tensordot' and arg == 'axes' :
107
+ return 1
100
108
return known_args [arg ]
101
109
else :
102
110
raise RuntimeError (f"Don't know how to test argument { arg } . Please update test_signatures.py" )
@@ -151,15 +159,15 @@ def test_function_positional_args(name):
151
159
pytest .skip (f"{ name } is not a function, skipping." )
152
160
mod_func = getattr (_mod , name )
153
161
argspec = inspect .getfullargspec (stub_func )
154
- args = argspec .args
162
+ func_args = argspec .args
155
163
if name .startswith ('__' ):
156
- args = args [1 :]
157
- nargs = [len (args )]
164
+ func_args = func_args [1 :]
165
+ nargs = [len (func_args )]
158
166
if argspec .defaults :
159
167
# The actual default values are checked in the specific tests
160
- nargs .extend ([len (args ) - i for i in range (1 , len (argspec .defaults ) + 1 )])
168
+ nargs .extend ([len (func_args ) - i for i in range (1 , len (argspec .defaults ) + 1 )])
161
169
162
- args = [example_argument (arg , name , dtype ) for arg in args ]
170
+ args = [example_argument (arg , name , dtype ) for arg in func_args ]
163
171
if not args :
164
172
args = [example_argument ('x' , name , dtype )]
165
173
else :
0 commit comments