|
19 | 19 |
|
20 | 20 | DT = Type
|
21 | 21 | ScalarType = Union[Type[bool], Type[int], Type[float]]
|
| 22 | +Param = Tuple |
22 | 23 |
|
23 | 24 |
|
24 | 25 | def multi_promotable_dtypes(
|
@@ -126,7 +127,7 @@ def make_id(
|
126 | 127 | return f'{func_name}({f_args}) -> {f_out_dtype}'
|
127 | 128 |
|
128 | 129 |
|
129 |
| -func_params: List[Tuple[str, Tuple[DT, ...], DT]] = [] |
| 130 | +func_params: List[Param[str, Tuple[DT, ...], DT]] = [] |
130 | 131 | for func_name in elementwise_functions.__all__:
|
131 | 132 | valid_in_dtypes = dh.func_in_dtypes[func_name]
|
132 | 133 | ndtypes = nargs(func_name)
|
@@ -185,7 +186,7 @@ def test_func_promotion(func_name, in_dtypes, out_dtype, data):
|
185 | 186 | assert out.dtype == out_dtype, f'{out.dtype=!s}, but should be {out_dtype}'
|
186 | 187 |
|
187 | 188 |
|
188 |
| -promotion_params: List[Tuple[Tuple[DT, DT], DT]] = [] |
| 189 | +promotion_params: List[Param[Tuple[DT, DT], DT]] = [] |
189 | 190 | for (dtype1, dtype2), promoted_dtype in dh.promotion_table.items():
|
190 | 191 | p = pytest.param(
|
191 | 192 | (dtype1, dtype2),
|
@@ -235,7 +236,7 @@ def test_vecdot(in_dtypes, out_dtype, shapes, data):
|
235 | 236 | assert out.dtype == out_dtype, f'{out.dtype=!s}, but should be {out_dtype}'
|
236 | 237 |
|
237 | 238 |
|
238 |
| -op_params: List[Tuple[str, str, Tuple[DT, ...], DT]] = [] |
| 239 | +op_params: List[Param[str, str, Tuple[DT, ...], DT]] = [] |
239 | 240 | op_to_symbol = {**dh.unary_op_to_symbol, **dh.binary_op_to_symbol}
|
240 | 241 | for op, symbol in op_to_symbol.items():
|
241 | 242 | if op == '__matmul__':
|
@@ -303,7 +304,7 @@ def test_op_promotion(op, expr, in_dtypes, out_dtype, data):
|
303 | 304 | assert out.dtype == out_dtype, f'{out.dtype=!s}, but should be {out_dtype}'
|
304 | 305 |
|
305 | 306 |
|
306 |
| -inplace_params: List[Tuple[str, str, Tuple[DT, ...], DT]] = [] |
| 307 | +inplace_params: List[Param[str, str, Tuple[DT, ...], DT]] = [] |
307 | 308 | for op, symbol in dh.inplace_op_to_symbol.items():
|
308 | 309 | if op == '__imatmul__':
|
309 | 310 | continue
|
@@ -344,7 +345,7 @@ def test_inplace_op_promotion(op, expr, in_dtypes, out_dtype, shapes, data):
|
344 | 345 | assert x1.dtype == out_dtype, f'{x1.dtype=!s}, but should be {out_dtype}'
|
345 | 346 |
|
346 | 347 |
|
347 |
| -op_scalar_params: List[Tuple[str, str, DT, ScalarType, DT]] = [] |
| 348 | +op_scalar_params: List[Param[str, str, DT, ScalarType, DT]] = [] |
348 | 349 | for op, symbol in dh.binary_op_to_symbol.items():
|
349 | 350 | if op == '__matmul__':
|
350 | 351 | continue
|
@@ -378,7 +379,7 @@ def test_op_scalar_promotion(op, expr, in_dtype, in_stype, out_dtype, data):
|
378 | 379 | assert out.dtype == out_dtype, f'{out.dtype=!s}, but should be {out_dtype}'
|
379 | 380 |
|
380 | 381 |
|
381 |
| -inplace_scalar_params: List[Tuple[str, str, DT, ScalarType]] = [] |
| 382 | +inplace_scalar_params: List[Param[str, str, DT, ScalarType]] = [] |
382 | 383 | for op, symbol in dh.inplace_op_to_symbol.items():
|
383 | 384 | if op == '__imatmul__':
|
384 | 385 | continue
|
|
0 commit comments