Skip to content

Commit 3ff7aba

Browse files
committed
Factor out assert_dtype and fmt_types, add typing.py
1 parent 67a2d6b commit 3ff7aba

File tree

4 files changed

+88
-58
lines changed

4 files changed

+88
-58
lines changed

array_api_tests/dtype_helpers.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1-
from typing import NamedTuple
1+
from functools import lru_cache
2+
from typing import NamedTuple, Tuple, Union
23

34
from . import _array_module as xp
5+
from .typing import DataType, ScalarType
46

57

68
__all__ = [
@@ -24,6 +26,7 @@
2426
'binary_op_to_symbol',
2527
'unary_op_to_symbol',
2628
'inplace_op_to_symbol',
29+
'fmt_types',
2730
]
2831

2932

@@ -351,3 +354,15 @@ def result_type(*dtypes):
351354
inplace_op_to_symbol[iop] = f'{symbol}='
352355
func_in_dtypes[iop] = func_in_dtypes[op]
353356
func_returns_bool[iop] = func_returns_bool[op]
357+
358+
359+
@lru_cache
360+
def fmt_types(types: Tuple[Union[DataType, ScalarType], ...]) -> str:
361+
f_types = []
362+
for type_ in types:
363+
try:
364+
f_types.append(dtype_to_name[type_])
365+
except KeyError:
366+
# i.e. dtype is bool, int, or float
367+
f_types.append(type_.__name__)
368+
return ', '.join(f_types)

array_api_tests/pytest_helpers.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
from inspect import getfullargspec
2+
3+
from . import dtype_helpers as dh
24
from . import function_stubs
5+
from .typing import DataType
6+
37

48
def raises(exceptions, function, message=''):
59
"""
@@ -33,3 +37,12 @@ def doesnt_raise(function, message=''):
3337

3438
def nargs(func_name):
3539
return len(getfullargspec(getattr(function_stubs, func_name)).args)
40+
41+
def assert_dtype(test_case: str, result_name: str, dtype: DataType, expected: DataType):
42+
msg = (
43+
f'{result_name}={dh.dtype_to_name[dtype]}, '
44+
f'but should be {dh.dtype_to_name[expected]} [{test_case}]'
45+
)
46+
assert dtype == expected, msg
47+
48+

array_api_tests/test_type_promotion.py

Lines changed: 48 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,7 @@
22
https://data-apis.github.io/array-api/latest/API_specification/type_promotion.html
33
"""
44
from collections import defaultdict
5-
from functools import lru_cache
6-
from typing import Tuple, Type, Union, List
5+
from typing import Tuple, Union, List
76

87
import pytest
98
from hypothesis import assume, given, reject
@@ -12,41 +11,17 @@
1211
from . import _array_module as xp
1312
from . import dtype_helpers as dh
1413
from . import hypothesis_helpers as hh
14+
from . import pytest_helpers as ph
1515
from . import xps
16+
from .typing import DataType, ScalarType, Param
1617
from .function_stubs import elementwise_functions
17-
from .pytest_helpers import nargs
18-
19-
20-
DT = Type
21-
ScalarType = Union[Type[bool], Type[int], Type[float]]
22-
Param = Tuple
23-
24-
25-
@lru_cache
26-
def fmt_types(types: Tuple[Union[DT, ScalarType], ...]) -> str:
27-
f_types = []
28-
for type_ in types:
29-
try:
30-
f_types.append(dh.dtype_to_name[type_])
31-
except KeyError:
32-
# i.e. dtype is bool, int, or float
33-
f_types.append(type_.__name__)
34-
return ', '.join(f_types)
35-
36-
37-
def assert_dtype(test_case: str, result_name: str, dtype: DT, expected: DT):
38-
msg = (
39-
f'{result_name}={dh.dtype_to_name[dtype]}, '
40-
f'but should be {dh.dtype_to_name[expected]} [{test_case}]'
41-
)
42-
assert dtype == expected, msg
4318

4419

4520
@given(hh.mutually_promotable_dtypes(None))
4621
def test_result_type(dtypes):
4722
out = xp.result_type(*dtypes)
48-
assert_dtype(
49-
f'result_type({fmt_types(dtypes)})', 'out', out, dh.result_type(*dtypes)
23+
ph.assert_dtype(
24+
f'result_type({dh.fmt_types(dtypes)})', 'out', out, dh.result_type(*dtypes)
5025
)
5126

5227

@@ -62,9 +37,9 @@ def test_meshgrid(dtypes, data):
6237
arrays.append(x)
6338
out = xp.meshgrid(*arrays)
6439
expected = dh.result_type(*dtypes)
65-
test_case = f'meshgrid({fmt_types(dtypes)})'
40+
test_case = f'meshgrid({dh.fmt_types(dtypes)})'
6641
for i, x in enumerate(out):
67-
assert_dtype(test_case, f'out[{i}].dtype', x.dtype, expected)
42+
ph.assert_dtype(test_case, f'out[{i}].dtype', x.dtype, expected)
6843

6944

7045
@given(
@@ -78,8 +53,11 @@ def test_concat(shape, dtypes, data):
7853
x = data.draw(xps.arrays(dtype=dtype, shape=shape), label=f'x{i}')
7954
arrays.append(x)
8055
out = xp.concat(arrays)
81-
assert_dtype(
82-
f'concat({fmt_types(dtypes)})', 'out.dtype', out.dtype, dh.result_type(*dtypes)
56+
ph.assert_dtype(
57+
f'concat({dh.fmt_types(dtypes)})',
58+
'out.dtype',
59+
out.dtype,
60+
dh.result_type(*dtypes),
8361
)
8462

8563

@@ -94,8 +72,11 @@ def test_stack(shape, dtypes, data):
9472
x = data.draw(xps.arrays(dtype=dtype, shape=shape), label=f'x{i}')
9573
arrays.append(x)
9674
out = xp.stack(arrays)
97-
assert_dtype(
98-
f'stack({fmt_types(dtypes)})', 'out.dtype', out.dtype, dh.result_type(*dtypes)
75+
ph.assert_dtype(
76+
f'stack({dh.fmt_types(dtypes)})',
77+
'out.dtype',
78+
out.dtype,
79+
dh.result_type(*dtypes),
9980
)
10081

10182

@@ -117,17 +98,19 @@ def test_stack(shape, dtypes, data):
11798

11899

119100
def make_id(
120-
func_name: str, in_dtypes: Tuple[Union[DT, ScalarType], ...], out_dtype: DT
101+
func_name: str,
102+
in_dtypes: Tuple[Union[DataType, ScalarType], ...],
103+
out_dtype: DataType,
121104
) -> str:
122-
f_args = fmt_types(in_dtypes)
105+
f_args = dh.fmt_types(in_dtypes)
123106
f_out_dtype = dh.dtype_to_name[out_dtype]
124107
return f'{func_name}({f_args}) -> {f_out_dtype}'
125108

126109

127-
func_params: List[Param[str, Tuple[DT, ...], DT]] = []
110+
func_params: List[Param[str, Tuple[DataType, ...], DataType]] = []
128111
for func_name in elementwise_functions.__all__:
129112
valid_in_dtypes = dh.func_in_dtypes[func_name]
130-
ndtypes = nargs(func_name)
113+
ndtypes = ph.nargs(func_name)
131114
if ndtypes == 1:
132115
for in_dtype in valid_in_dtypes:
133116
out_dtype = xp.bool if dh.func_returns_bool[func_name] else in_dtype
@@ -180,12 +163,12 @@ def test_func_promotion(func_name, in_dtypes, out_dtype, data):
180163
out = func(*arrays)
181164
except OverflowError:
182165
reject()
183-
assert_dtype(
184-
f'{func_name}({fmt_types(in_dtypes)})', 'out.dtype', out.dtype, out_dtype
166+
ph.assert_dtype(
167+
f'{func_name}({dh.fmt_types(in_dtypes)})', 'out.dtype', out.dtype, out_dtype
185168
)
186169

187170

188-
promotion_params: List[Param[Tuple[DT, DT], DT]] = []
171+
promotion_params: List[Param[Tuple[DataType, DataType], DataType]] = []
189172
for (dtype1, dtype2), promoted_dtype in dh.promotion_table.items():
190173
p = pytest.param(
191174
(dtype1, dtype2),
@@ -202,7 +185,9 @@ def test_where(in_dtypes, out_dtype, shapes, data):
202185
x2 = data.draw(xps.arrays(dtype=in_dtypes[1], shape=shapes[1]), label='x2')
203186
cond = data.draw(xps.arrays(dtype=xp.bool, shape=shapes[2]), label='condition')
204187
out = xp.where(cond, x1, x2)
205-
assert_dtype(f'where({fmt_types(in_dtypes)})', 'out.dtype', out.dtype, out_dtype)
188+
ph.assert_dtype(
189+
f'where({dh.fmt_types(in_dtypes)})', 'out.dtype', out.dtype, out_dtype
190+
)
206191

207192

208193
numeric_promotion_params = promotion_params[1:]
@@ -214,8 +199,8 @@ def test_tensordot(in_dtypes, out_dtype, shapes, data):
214199
x1 = data.draw(xps.arrays(dtype=in_dtypes[0], shape=shapes[0]), label='x1')
215200
x2 = data.draw(xps.arrays(dtype=in_dtypes[1], shape=shapes[1]), label='x2')
216201
out = xp.tensordot(x1, x2)
217-
assert_dtype(
218-
f'tensordot({fmt_types(in_dtypes)})', 'out.dtype', out.dtype, out_dtype
202+
ph.assert_dtype(
203+
f'tensordot({dh.fmt_types(in_dtypes)})', 'out.dtype', out.dtype, out_dtype
219204
)
220205

221206

@@ -225,16 +210,18 @@ def test_vecdot(in_dtypes, out_dtype, shapes, data):
225210
x1 = data.draw(xps.arrays(dtype=in_dtypes[0], shape=shapes[0]), label='x1')
226211
x2 = data.draw(xps.arrays(dtype=in_dtypes[1], shape=shapes[1]), label='x2')
227212
out = xp.vecdot(x1, x2)
228-
assert_dtype(f'vecdot({fmt_types(in_dtypes)})', 'out.dtype', out.dtype, out_dtype)
213+
ph.assert_dtype(
214+
f'vecdot({dh.fmt_types(in_dtypes)})', 'out.dtype', out.dtype, out_dtype
215+
)
229216

230217

231-
op_params: List[Param[str, str, Tuple[DT, ...], DT]] = []
218+
op_params: List[Param[str, str, Tuple[DataType, ...], DataType]] = []
232219
op_to_symbol = {**dh.unary_op_to_symbol, **dh.binary_op_to_symbol}
233220
for op, symbol in op_to_symbol.items():
234221
if op == '__matmul__':
235222
continue
236223
valid_in_dtypes = dh.func_in_dtypes[op]
237-
ndtypes = nargs(op)
224+
ndtypes = ph.nargs(op)
238225
if ndtypes == 1:
239226
for in_dtype in valid_in_dtypes:
240227
out_dtype = xp.bool if dh.func_returns_bool[op] else in_dtype
@@ -293,10 +280,12 @@ def test_op_promotion(op, expr, in_dtypes, out_dtype, data):
293280
out = eval(expr, locals_)
294281
except OverflowError:
295282
reject()
296-
assert_dtype(f'{op}({fmt_types(in_dtypes)})', 'out.dtype', out.dtype, out_dtype)
283+
ph.assert_dtype(
284+
f'{op}({dh.fmt_types(in_dtypes)})', 'out.dtype', out.dtype, out_dtype
285+
)
297286

298287

299-
inplace_params: List[Param[str, str, Tuple[DT, ...], DT]] = []
288+
inplace_params: List[Param[str, str, Tuple[DataType, ...], DataType]] = []
300289
for op, symbol in dh.inplace_op_to_symbol.items():
301290
if op == '__imatmul__':
302291
continue
@@ -334,10 +323,10 @@ def test_inplace_op_promotion(op, expr, in_dtypes, out_dtype, shapes, data):
334323
except OverflowError:
335324
reject()
336325
x1 = locals_['x1']
337-
assert_dtype(f'{op}({fmt_types(in_dtypes)})', 'x1.dtype', x1.dtype, out_dtype)
326+
ph.assert_dtype(f'{op}({dh.fmt_types(in_dtypes)})', 'x1.dtype', x1.dtype, out_dtype)
338327

339328

340-
op_scalar_params: List[Param[str, str, DT, ScalarType, DT]] = []
329+
op_scalar_params: List[Param[str, str, DataType, ScalarType, DataType]] = []
341330
for op, symbol in dh.binary_op_to_symbol.items():
342331
if op == '__matmul__':
343332
continue
@@ -368,12 +357,12 @@ def test_op_scalar_promotion(op, expr, in_dtype, in_stype, out_dtype, data):
368357
out = eval(expr, {'x': x, 's': s})
369358
except OverflowError:
370359
reject()
371-
assert_dtype(
372-
f'{op}({fmt_types((in_dtype, in_stype))})', 'out.dtype', out.dtype, out_dtype
360+
ph.assert_dtype(
361+
f'{op}({dh.fmt_types((in_dtype, in_stype))})', 'out.dtype', out.dtype, out_dtype
373362
)
374363

375364

376-
inplace_scalar_params: List[Param[str, str, DT, ScalarType]] = []
365+
inplace_scalar_params: List[Param[str, str, DataType, ScalarType]] = []
377366
for op, symbol in dh.inplace_op_to_symbol.items():
378367
if op == '__imatmul__':
379368
continue
@@ -405,7 +394,9 @@ def test_inplace_op_scalar_promotion(op, expr, dtype, in_stype, data):
405394
reject()
406395
x = locals_['x']
407396
assert x.dtype == dtype, f'{x.dtype=!s}, but should be {dtype}'
408-
assert_dtype(f'{op}({fmt_types((dtype, in_stype))})', 'x.dtype', x.dtype, dtype)
397+
ph.assert_dtype(
398+
f'{op}({dh.fmt_types((dtype, in_stype))})', 'x.dtype', x.dtype, dtype
399+
)
409400

410401

411402
if __name__ == '__main__':

array_api_tests/typing.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
from typing import Tuple, Type, Union, Any
2+
3+
__all__ = [
4+
"DataType",
5+
"ScalarType",
6+
"Param",
7+
]
8+
9+
DataType = Type[Any]
10+
ScalarType = Union[Type[bool], Type[int], Type[float]]
11+
Param = Tuple

0 commit comments

Comments
 (0)