Skip to content

Commit 2f7933a

Browse files
committed
Alias Param type hint as Tuple
1 parent d7bf0bf commit 2f7933a

File tree

1 file changed

+7
-6
lines changed

1 file changed

+7
-6
lines changed

array_api_tests/test_type_promotion.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
DT = Type
2121
ScalarType = Union[Type[bool], Type[int], Type[float]]
22+
Param = Tuple
2223

2324

2425
def multi_promotable_dtypes(
@@ -126,7 +127,7 @@ def make_id(
126127
return f'{func_name}({f_args}) -> {f_out_dtype}'
127128

128129

129-
func_params: List[Tuple[str, Tuple[DT, ...], DT]] = []
130+
func_params: List[Param[str, Tuple[DT, ...], DT]] = []
130131
for func_name in elementwise_functions.__all__:
131132
valid_in_dtypes = dh.func_in_dtypes[func_name]
132133
ndtypes = nargs(func_name)
@@ -185,7 +186,7 @@ def test_func_promotion(func_name, in_dtypes, out_dtype, data):
185186
assert out.dtype == out_dtype, f'{out.dtype=!s}, but should be {out_dtype}'
186187

187188

188-
promotion_params: List[Tuple[Tuple[DT, DT], DT]] = []
189+
promotion_params: List[Param[Tuple[DT, DT], DT]] = []
189190
for (dtype1, dtype2), promoted_dtype in dh.promotion_table.items():
190191
p = pytest.param(
191192
(dtype1, dtype2),
@@ -235,7 +236,7 @@ def test_vecdot(in_dtypes, out_dtype, shapes, data):
235236
assert out.dtype == out_dtype, f'{out.dtype=!s}, but should be {out_dtype}'
236237

237238

238-
op_params: List[Tuple[str, str, Tuple[DT, ...], DT]] = []
239+
op_params: List[Param[str, str, Tuple[DT, ...], DT]] = []
239240
op_to_symbol = {**dh.unary_op_to_symbol, **dh.binary_op_to_symbol}
240241
for op, symbol in op_to_symbol.items():
241242
if op == '__matmul__':
@@ -303,7 +304,7 @@ def test_op_promotion(op, expr, in_dtypes, out_dtype, data):
303304
assert out.dtype == out_dtype, f'{out.dtype=!s}, but should be {out_dtype}'
304305

305306

306-
inplace_params: List[Tuple[str, str, Tuple[DT, ...], DT]] = []
307+
inplace_params: List[Param[str, str, Tuple[DT, ...], DT]] = []
307308
for op, symbol in dh.inplace_op_to_symbol.items():
308309
if op == '__imatmul__':
309310
continue
@@ -344,7 +345,7 @@ def test_inplace_op_promotion(op, expr, in_dtypes, out_dtype, shapes, data):
344345
assert x1.dtype == out_dtype, f'{x1.dtype=!s}, but should be {out_dtype}'
345346

346347

347-
op_scalar_params: List[Tuple[str, str, DT, ScalarType, DT]] = []
348+
op_scalar_params: List[Param[str, str, DT, ScalarType, DT]] = []
348349
for op, symbol in dh.binary_op_to_symbol.items():
349350
if op == '__matmul__':
350351
continue
@@ -378,7 +379,7 @@ def test_op_scalar_promotion(op, expr, in_dtype, in_stype, out_dtype, data):
378379
assert out.dtype == out_dtype, f'{out.dtype=!s}, but should be {out_dtype}'
379380

380381

381-
inplace_scalar_params: List[Tuple[str, str, DT, ScalarType]] = []
382+
inplace_scalar_params: List[Param[str, str, DT, ScalarType]] = []
382383
for op, symbol in dh.inplace_op_to_symbol.items():
383384
if op == '__imatmul__':
384385
continue

0 commit comments

Comments
 (0)