Skip to content

Commit dc211d4

Browse files
committed
Default expected and out_name in ph.assert_dtype()
1 parent d6665c0 commit dc211d4

File tree

3 files changed

+19
-23
lines changed

3 files changed

+19
-23
lines changed

array_api_tests/pytest_helpers.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from inspect import getfullargspec
2-
from typing import Tuple
2+
from typing import Optional, Tuple
33

44
from . import dtype_helpers as dh
55
from . import function_stubs
@@ -43,12 +43,15 @@ def nargs(func_name):
4343
def assert_dtype(
4444
func_name: str,
4545
in_dtypes: Tuple[DataType, ...],
46-
out_name: str,
4746
out_dtype: DataType,
48-
expected: DataType
47+
expected: Optional[DataType] = None,
48+
*,
49+
out_name: str = "out.dtype",
4950
):
5051
f_in_dtypes = dh.fmt_types(in_dtypes)
5152
f_out_dtype = dh.dtype_to_name[out_dtype]
53+
if expected is None:
54+
expected = dh.result_type(*in_dtypes)
5255
f_expected = dh.dtype_to_name[expected]
5356
msg = (
5457
f"{out_name}={f_out_dtype}, but should be {f_expected} "

array_api_tests/test_linalg.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -280,13 +280,7 @@ def test_matmul(x1, x2):
280280
else:
281281
res = _array_module.matmul(x1, x2)
282282

283-
ph.assert_dtype(
284-
"matmul",
285-
(x1.dtype, x2.dtype),
286-
"out.dtype",
287-
res.dtype,
288-
dh.promotion_table[x1.dtype, x2.dtype],
289-
)
283+
ph.assert_dtype("matmul", (x1.dtype, x2.dtype), res.dtype)
290284

291285
if len(x1.shape) == len(x2.shape) == 1:
292286
assert res.shape == ()

array_api_tests/test_type_promotion.py

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
@given(hh.mutually_promotable_dtypes(None))
2121
def test_result_type(dtypes):
2222
out = xp.result_type(*dtypes)
23-
ph.assert_dtype('result_type', dtypes, 'out', out, dh.result_type(*dtypes))
23+
ph.assert_dtype('result_type', dtypes, out, out_name='out')
2424

2525

2626
@given(
@@ -34,9 +34,8 @@ def test_meshgrid(dtypes, data):
3434
x = data.draw(xps.arrays(dtype=dtype, shape=shape), label=f'x{i}')
3535
arrays.append(x)
3636
out = xp.meshgrid(*arrays)
37-
expected = dh.result_type(*dtypes)
3837
for i, x in enumerate(out):
39-
ph.assert_dtype('meshgrid', dtypes, f'out[{i}].dtype', x.dtype, expected)
38+
ph.assert_dtype('meshgrid', dtypes, x.dtype, out_name=f'out[{i}].dtype')
4039

4140

4241
@given(
@@ -50,7 +49,7 @@ def test_concat(shape, dtypes, data):
5049
x = data.draw(xps.arrays(dtype=dtype, shape=shape), label=f'x{i}')
5150
arrays.append(x)
5251
out = xp.concat(arrays)
53-
ph.assert_dtype('concat', dtypes, 'out.dtype', out.dtype, dh.result_type(*dtypes))
52+
ph.assert_dtype('concat', dtypes, out.dtype)
5453

5554

5655
@given(
@@ -64,7 +63,7 @@ def test_stack(shape, dtypes, data):
6463
x = data.draw(xps.arrays(dtype=dtype, shape=shape), label=f'x{i}')
6564
arrays.append(x)
6665
out = xp.stack(arrays)
67-
ph.assert_dtype('stack', dtypes, 'out.dtype', out.dtype, dh.result_type(*dtypes))
66+
ph.assert_dtype('stack', dtypes, out.dtype)
6867

6968

7069
bitwise_shift_funcs = [
@@ -150,7 +149,7 @@ def test_func_promotion(func_name, in_dtypes, out_dtype, data):
150149
out = func(*arrays)
151150
except OverflowError:
152151
reject()
153-
ph.assert_dtype(func_name, in_dtypes, 'out.dtype', out.dtype, out_dtype)
152+
ph.assert_dtype(func_name, in_dtypes, out.dtype, out_dtype)
154153

155154

156155
promotion_params: List[Param[Tuple[DataType, DataType], DataType]] = []
@@ -170,7 +169,7 @@ def test_where(in_dtypes, out_dtype, shapes, data):
170169
x2 = data.draw(xps.arrays(dtype=in_dtypes[1], shape=shapes[1]), label='x2')
171170
cond = data.draw(xps.arrays(dtype=xp.bool, shape=shapes[2]), label='condition')
172171
out = xp.where(cond, x1, x2)
173-
ph.assert_dtype('where', in_dtypes, 'out.dtype', out.dtype, out_dtype)
172+
ph.assert_dtype('where', in_dtypes, out.dtype, out_dtype)
174173

175174

176175
numeric_promotion_params = promotion_params[1:]
@@ -182,7 +181,7 @@ def test_tensordot(in_dtypes, out_dtype, shapes, data):
182181
x1 = data.draw(xps.arrays(dtype=in_dtypes[0], shape=shapes[0]), label='x1')
183182
x2 = data.draw(xps.arrays(dtype=in_dtypes[1], shape=shapes[1]), label='x2')
184183
out = xp.tensordot(x1, x2)
185-
ph.assert_dtype('tensordot', in_dtypes, 'out.dtype', out.dtype, out_dtype)
184+
ph.assert_dtype('tensordot', in_dtypes, out.dtype, out_dtype)
186185

187186

188187
@pytest.mark.parametrize('in_dtypes, out_dtype', numeric_promotion_params)
@@ -191,7 +190,7 @@ def test_vecdot(in_dtypes, out_dtype, shapes, data):
191190
x1 = data.draw(xps.arrays(dtype=in_dtypes[0], shape=shapes[0]), label='x1')
192191
x2 = data.draw(xps.arrays(dtype=in_dtypes[1], shape=shapes[1]), label='x2')
193192
out = xp.vecdot(x1, x2)
194-
ph.assert_dtype('vecdot', in_dtypes, 'out.dtype', out.dtype, out_dtype)
193+
ph.assert_dtype('vecdot', in_dtypes, out.dtype, out_dtype)
195194

196195

197196
op_params: List[Param[str, str, Tuple[DataType, ...], DataType]] = []
@@ -259,7 +258,7 @@ def test_op_promotion(op, expr, in_dtypes, out_dtype, data):
259258
out = eval(expr, locals_)
260259
except OverflowError:
261260
reject()
262-
ph.assert_dtype(op, in_dtypes, 'out.dtype', out.dtype, out_dtype)
261+
ph.assert_dtype(op, in_dtypes, out.dtype, out_dtype)
263262

264263

265264
inplace_params: List[Param[str, str, Tuple[DataType, ...], DataType]] = []
@@ -300,7 +299,7 @@ def test_inplace_op_promotion(op, expr, in_dtypes, out_dtype, shapes, data):
300299
except OverflowError:
301300
reject()
302301
x1 = locals_['x1']
303-
ph.assert_dtype(op, in_dtypes, 'x1.dtype', x1.dtype, out_dtype)
302+
ph.assert_dtype(op, in_dtypes, x1.dtype, out_dtype, out_name='x1.dtype')
304303

305304

306305
op_scalar_params: List[Param[str, str, DataType, ScalarType, DataType]] = []
@@ -334,7 +333,7 @@ def test_op_scalar_promotion(op, expr, in_dtype, in_stype, out_dtype, data):
334333
out = eval(expr, {'x': x, 's': s})
335334
except OverflowError:
336335
reject()
337-
ph.assert_dtype(op, (in_dtype, in_stype), 'out.dtype', out.dtype, out_dtype)
336+
ph.assert_dtype(op, (in_dtype, in_stype), out.dtype, out_dtype)
338337

339338

340339
inplace_scalar_params: List[Param[str, str, DataType, ScalarType]] = []
@@ -369,7 +368,7 @@ def test_inplace_op_scalar_promotion(op, expr, dtype, in_stype, data):
369368
reject()
370369
x = locals_['x']
371370
assert x.dtype == dtype, f'{x.dtype=!s}, but should be {dtype}'
372-
ph.assert_dtype(op, (dtype, in_stype), 'x.dtype', x.dtype, dtype)
371+
ph.assert_dtype(op, (dtype, in_stype), x.dtype, dtype, out_name='x.dtype')
373372

374373

375374
if __name__ == '__main__':

0 commit comments

Comments
 (0)