@@ -185,17 +185,17 @@ def test_func_promotion(func_name, in_dtypes, out_dtype, data):
185
185
assert out .dtype == out_dtype , f'{ out .dtype = !s} , but should be { out_dtype } '
186
186
187
187
188
- promotion_table_params : List [Tuple [Tuple [DT , DT ], DT ]] = []
188
+ promotion_params : List [Tuple [Tuple [DT , DT ], DT ]] = []
189
189
for (dtype1 , dtype2 ), promoted_dtype in dh .promotion_table .items ():
190
190
p = pytest .param (
191
191
(dtype1 , dtype2 ),
192
192
promoted_dtype ,
193
193
id = make_id ('' , (dtype1 , dtype2 ), promoted_dtype ),
194
194
)
195
- promotion_table_params .append (p )
195
+ promotion_params .append (p )
196
196
197
197
198
- @pytest .mark .parametrize ('in_dtypes, out_dtype' , promotion_table_params )
198
+ @pytest .mark .parametrize ('in_dtypes, out_dtype' , promotion_params )
199
199
@given (shapes = hh .mutually_broadcastable_shapes (3 ), data = st .data ())
200
200
def test_where (in_dtypes , out_dtype , shapes , data ):
201
201
x1 = data .draw (xps .arrays (dtype = in_dtypes [0 ], shape = shapes [0 ]), label = 'x1' )
@@ -205,10 +205,10 @@ def test_where(in_dtypes, out_dtype, shapes, data):
205
205
assert out .dtype == out_dtype , f'{ out .dtype = !s} , but should be { out_dtype } '
206
206
207
207
208
- numeric_promotion_table_params = promotion_table_params [1 :]
208
+ numeric_promotion_params = promotion_params [1 :]
209
209
210
210
211
- @pytest .mark .parametrize ('in_dtypes, out_dtype' , numeric_promotion_table_params )
211
+ @pytest .mark .parametrize ('in_dtypes, out_dtype' , numeric_promotion_params )
212
212
@given (shapes = hh .mutually_broadcastable_shapes (2 , min_dims = 1 ), data = st .data ())
213
213
def test_matmul (in_dtypes , out_dtype , shapes , data ):
214
214
x1 = data .draw (xps .arrays (dtype = in_dtypes [0 ], shape = shapes [0 ]), label = 'x1' )
@@ -217,6 +217,18 @@ def test_matmul(in_dtypes, out_dtype, shapes, data):
217
217
assert out .dtype == out_dtype , f'{ out .dtype = !s} , but should be { out_dtype } '
218
218
219
219
220
+ @pytest .mark .parametrize ('in_dtypes, out_dtype' , numeric_promotion_params )
221
+ @given (shapes = hh .mutually_broadcastable_shapes (2 ), data = st .data ())
222
+ def test_tensordot (in_dtypes , out_dtype , shapes , data ):
223
+ pass # TODO: figure out acceptable shape behaviour
224
+
225
+
226
+ @pytest .mark .parametrize ('in_dtypes, out_dtype' , numeric_promotion_params )
227
+ @given (shapes = hh .mutually_broadcastable_shapes (2 ), data = st .data ())
228
+ def test_vecdot (in_dtypes , out_dtype , shapes , data ):
229
+ pass # TODO: figure out acceptable shape behaviour
230
+
231
+
220
232
op_params : List [Tuple [str , str , Tuple [DT , ...], DT ]] = []
221
233
op_to_symbol = {** dh .unary_op_to_symbol , ** dh .binary_op_to_symbol }
222
234
for op , symbol in op_to_symbol .items ():
0 commit comments