Skip to content

Commit 051d2b1

Browse files
committed
Construct test case name in ph.assert_dtype()
1 parent 3ff7aba commit 051d2b1

File tree

2 files changed

+27
-43
lines changed

2 files changed

+27
-43
lines changed

array_api_tests/pytest_helpers.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from inspect import getfullargspec
2+
from typing import Tuple
23

34
from . import dtype_helpers as dh
45
from . import function_stubs
@@ -38,11 +39,21 @@ def doesnt_raise(function, message=''):
3839
def nargs(func_name):
3940
return len(getfullargspec(getattr(function_stubs, func_name)).args)
4041

41-
def assert_dtype(test_case: str, result_name: str, dtype: DataType, expected: DataType):
42+
43+
def assert_dtype(
44+
func_name: str,
45+
in_dtypes: Tuple[DataType, ...],
46+
out_name: str,
47+
out_dtype: DataType,
48+
expected: DataType
49+
):
50+
f_in_dtypes = dh.fmt_types(in_dtypes)
51+
f_out_dtype = dh.dtype_to_name[out_dtype]
52+
f_expected = dh.dtype_to_name[expected]
4253
msg = (
43-
f'{result_name}={dh.dtype_to_name[dtype]}, '
44-
f'but should be {dh.dtype_to_name[expected]} [{test_case}]'
54+
f"{out_name}={f_out_dtype}, but should be {f_expected} "
55+
f"[{func_name}({f_in_dtypes})]"
4556
)
46-
assert dtype == expected, msg
57+
assert out_dtype == expected, msg
4758

4859

array_api_tests/test_type_promotion.py

Lines changed: 12 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,7 @@
2020
@given(hh.mutually_promotable_dtypes(None))
2121
def test_result_type(dtypes):
2222
out = xp.result_type(*dtypes)
23-
ph.assert_dtype(
24-
f'result_type({dh.fmt_types(dtypes)})', 'out', out, dh.result_type(*dtypes)
25-
)
23+
ph.assert_dtype('result_type', dtypes, 'out', out, dh.result_type(*dtypes))
2624

2725

2826
@given(
@@ -37,9 +35,8 @@ def test_meshgrid(dtypes, data):
3735
arrays.append(x)
3836
out = xp.meshgrid(*arrays)
3937
expected = dh.result_type(*dtypes)
40-
test_case = f'meshgrid({dh.fmt_types(dtypes)})'
4138
for i, x in enumerate(out):
42-
ph.assert_dtype(test_case, f'out[{i}].dtype', x.dtype, expected)
39+
ph.assert_dtype('meshgrid', dtypes, f'out[{i}].dtype', x.dtype, expected)
4340

4441

4542
@given(
@@ -53,12 +50,7 @@ def test_concat(shape, dtypes, data):
5350
x = data.draw(xps.arrays(dtype=dtype, shape=shape), label=f'x{i}')
5451
arrays.append(x)
5552
out = xp.concat(arrays)
56-
ph.assert_dtype(
57-
f'concat({dh.fmt_types(dtypes)})',
58-
'out.dtype',
59-
out.dtype,
60-
dh.result_type(*dtypes),
61-
)
53+
ph.assert_dtype('concat', dtypes, 'out.dtype', out.dtype, dh.result_type(*dtypes))
6254

6355

6456
@given(
@@ -72,12 +64,7 @@ def test_stack(shape, dtypes, data):
7264
x = data.draw(xps.arrays(dtype=dtype, shape=shape), label=f'x{i}')
7365
arrays.append(x)
7466
out = xp.stack(arrays)
75-
ph.assert_dtype(
76-
f'stack({dh.fmt_types(dtypes)})',
77-
'out.dtype',
78-
out.dtype,
79-
dh.result_type(*dtypes),
80-
)
67+
ph.assert_dtype('stack', dtypes, 'out.dtype', out.dtype, dh.result_type(*dtypes))
8168

8269

8370
bitwise_shift_funcs = [
@@ -163,9 +150,7 @@ def test_func_promotion(func_name, in_dtypes, out_dtype, data):
163150
out = func(*arrays)
164151
except OverflowError:
165152
reject()
166-
ph.assert_dtype(
167-
f'{func_name}({dh.fmt_types(in_dtypes)})', 'out.dtype', out.dtype, out_dtype
168-
)
153+
ph.assert_dtype(func_name, in_dtypes, 'out.dtype', out.dtype, out_dtype)
169154

170155

171156
promotion_params: List[Param[Tuple[DataType, DataType], DataType]] = []
@@ -185,9 +170,7 @@ def test_where(in_dtypes, out_dtype, shapes, data):
185170
x2 = data.draw(xps.arrays(dtype=in_dtypes[1], shape=shapes[1]), label='x2')
186171
cond = data.draw(xps.arrays(dtype=xp.bool, shape=shapes[2]), label='condition')
187172
out = xp.where(cond, x1, x2)
188-
ph.assert_dtype(
189-
f'where({dh.fmt_types(in_dtypes)})', 'out.dtype', out.dtype, out_dtype
190-
)
173+
ph.assert_dtype('where', in_dtypes, 'out.dtype', out.dtype, out_dtype)
191174

192175

193176
numeric_promotion_params = promotion_params[1:]
@@ -199,9 +182,7 @@ def test_tensordot(in_dtypes, out_dtype, shapes, data):
199182
x1 = data.draw(xps.arrays(dtype=in_dtypes[0], shape=shapes[0]), label='x1')
200183
x2 = data.draw(xps.arrays(dtype=in_dtypes[1], shape=shapes[1]), label='x2')
201184
out = xp.tensordot(x1, x2)
202-
ph.assert_dtype(
203-
f'tensordot({dh.fmt_types(in_dtypes)})', 'out.dtype', out.dtype, out_dtype
204-
)
185+
ph.assert_dtype('tensordot', in_dtypes, 'out.dtype', out.dtype, out_dtype)
205186

206187

207188
@pytest.mark.parametrize('in_dtypes, out_dtype', numeric_promotion_params)
@@ -210,9 +191,7 @@ def test_vecdot(in_dtypes, out_dtype, shapes, data):
210191
x1 = data.draw(xps.arrays(dtype=in_dtypes[0], shape=shapes[0]), label='x1')
211192
x2 = data.draw(xps.arrays(dtype=in_dtypes[1], shape=shapes[1]), label='x2')
212193
out = xp.vecdot(x1, x2)
213-
ph.assert_dtype(
214-
f'vecdot({dh.fmt_types(in_dtypes)})', 'out.dtype', out.dtype, out_dtype
215-
)
194+
ph.assert_dtype('vecdot', in_dtypes, 'out.dtype', out.dtype, out_dtype)
216195

217196

218197
op_params: List[Param[str, str, Tuple[DataType, ...], DataType]] = []
@@ -280,9 +259,7 @@ def test_op_promotion(op, expr, in_dtypes, out_dtype, data):
280259
out = eval(expr, locals_)
281260
except OverflowError:
282261
reject()
283-
ph.assert_dtype(
284-
f'{op}({dh.fmt_types(in_dtypes)})', 'out.dtype', out.dtype, out_dtype
285-
)
262+
ph.assert_dtype(op, in_dtypes, 'out.dtype', out.dtype, out_dtype)
286263

287264

288265
inplace_params: List[Param[str, str, Tuple[DataType, ...], DataType]] = []
@@ -323,7 +300,7 @@ def test_inplace_op_promotion(op, expr, in_dtypes, out_dtype, shapes, data):
323300
except OverflowError:
324301
reject()
325302
x1 = locals_['x1']
326-
ph.assert_dtype(f'{op}({dh.fmt_types(in_dtypes)})', 'x1.dtype', x1.dtype, out_dtype)
303+
ph.assert_dtype(op, in_dtypes, 'x1.dtype', x1.dtype, out_dtype)
327304

328305

329306
op_scalar_params: List[Param[str, str, DataType, ScalarType, DataType]] = []
@@ -357,9 +334,7 @@ def test_op_scalar_promotion(op, expr, in_dtype, in_stype, out_dtype, data):
357334
out = eval(expr, {'x': x, 's': s})
358335
except OverflowError:
359336
reject()
360-
ph.assert_dtype(
361-
f'{op}({dh.fmt_types((in_dtype, in_stype))})', 'out.dtype', out.dtype, out_dtype
362-
)
337+
ph.assert_dtype(op, (in_dtype, in_stype), 'out.dtype', out.dtype, out_dtype)
363338

364339

365340
inplace_scalar_params: List[Param[str, str, DataType, ScalarType]] = []
@@ -394,9 +369,7 @@ def test_inplace_op_scalar_promotion(op, expr, dtype, in_stype, data):
394369
reject()
395370
x = locals_['x']
396371
assert x.dtype == dtype, f'{x.dtype=!s}, but should be {dtype}'
397-
ph.assert_dtype(
398-
f'{op}({dh.fmt_types((dtype, in_stype))})', 'x.dtype', x.dtype, dtype
399-
)
372+
ph.assert_dtype(op, (dtype, in_stype), 'x.dtype', x.dtype, dtype)
400373

401374

402375
if __name__ == '__main__':

0 commit comments

Comments
 (0)