Skip to content

Commit a3505c5

Browse files
committed
Use keyword-only arguments for the pytest helpers functions
This makes the inscrutable helper function calls more readable, and fixes at least one instance where the arguments were called in the wrong order. This also changes **kw to kw={} for the function keyword arguments, so there is no ambiguity between function keyword arguments and arguments to the helper function.
1 parent a533680 commit a3505c5

14 files changed

+369
-305
lines changed

array_api_tests/meta/test_pytest_helpers.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,18 +5,18 @@
55

66

77
def test_assert_dtype():
8-
ph.assert_dtype("promoted_func", [xp.uint8, xp.int8], xp.int16)
8+
ph.assert_dtype("promoted_func", in_dtype=[xp.uint8, xp.int8], out_dtype=xp.int16)
99
with raises(AssertionError):
10-
ph.assert_dtype("bad_func", [xp.uint8, xp.int8], xp.float32)
11-
ph.assert_dtype("bool_func", [xp.uint8, xp.int8], xp.bool, xp.bool)
12-
ph.assert_dtype("single_promoted_func", [xp.uint8], xp.uint8)
13-
ph.assert_dtype("single_bool_func", [xp.uint8], xp.bool, xp.bool)
10+
ph.assert_dtype("bad_func", in_dtype=[xp.uint8, xp.int8], out_dtype=xp.float32)
11+
ph.assert_dtype("bool_func", in_dtype=[xp.uint8, xp.int8], out_dtype=xp.bool, expected=xp.bool)
12+
ph.assert_dtype("single_promoted_func", in_dtype=[xp.uint8], out_dtype=xp.uint8)
13+
ph.assert_dtype("single_bool_func", in_dtype=[xp.uint8], out_dtype=xp.bool, expected=xp.bool)
1414

1515

1616
def test_assert_array_elements():
17-
ph.assert_array_elements("int zeros", xp.asarray(0), xp.asarray(0))
18-
ph.assert_array_elements("pos zeros", xp.asarray(0.0), xp.asarray(0.0))
17+
ph.assert_array_elements("int zeros", out=xp.asarray(0), expected=xp.asarray(0))
18+
ph.assert_array_elements("pos zeros", out=xp.asarray(0.0), expected=xp.asarray(0.0))
1919
with raises(AssertionError):
20-
ph.assert_array_elements("mixed sign zeros", xp.asarray(0.0), xp.asarray(-0.0))
20+
ph.assert_array_elements("mixed sign zeros", out=xp.asarray(0.0), expected=xp.asarray(-0.0))
2121
with raises(AssertionError):
22-
ph.assert_array_elements("mixed sign zeros", xp.asarray(-0.0), xp.asarray(0.0))
22+
ph.assert_array_elements("mixed sign zeros", out=xp.asarray(-0.0), expected=xp.asarray(0.0))

array_api_tests/pytest_helpers.py

Lines changed: 50 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -82,10 +82,10 @@ def is_neg_zero(n: float) -> bool:
8282

8383
def assert_dtype(
8484
func_name: str,
85+
*,
8586
in_dtype: Union[DataType, Sequence[DataType]],
8687
out_dtype: DataType,
8788
expected: Optional[DataType] = None,
88-
*,
8989
repr_name: str = "out.dtype",
9090
):
9191
"""
@@ -96,7 +96,7 @@ def assert_dtype(
9696
9797
>>> x = xp.arange(5, dtype=xp.uint8)
9898
>>> out = xp.abs(x)
99-
>>> assert_dtype('abs', x.dtype, out.dtype)
99+
>>> assert_dtype('abs', in_dtype=x.dtype, out_dtype=out.dtype)
100100
101101
is equivalent to
102102
@@ -108,7 +108,7 @@ def assert_dtype(
108108
>>> x1 = xp.arange(5, dtype=xp.uint8)
109109
>>> x2 = xp.arange(5, dtype=xp.uint16)
110110
>>> out = xp.add(x1, x2)
111-
>>> assert_dtype('add', [x1.dtype, x2.dtype], out.dtype)
111+
>>> assert_dtype('add', in_dtype=[x1.dtype, x2.dtype], out_dtype=out.dtype)
112112
113113
is equivalent to
114114
@@ -119,7 +119,7 @@ def assert_dtype(
119119
>>> x = xp.arange(5, dtype=xp.int8)
120120
>>> out = xp.sum(x)
121121
>>> default_int = xp.asarray(0).dtype
122-
>>> assert_dtype('sum', x, out.dtype, default_int)
122+
>>> assert_dtype('sum', in_dtype=x, out_dtype=out.dtype, expected=default_int)
123123
124124
"""
125125
in_dtypes = in_dtype if isinstance(in_dtype, Sequence) and not isinstance(in_dtype, str) else [in_dtype]
@@ -135,13 +135,18 @@ def assert_dtype(
135135
assert out_dtype == expected, msg
136136

137137

138-
def assert_kw_dtype(func_name: str, kw_dtype: DataType, out_dtype: DataType):
138+
def assert_kw_dtype(
139+
func_name: str,
140+
*,
141+
kw_dtype: DataType,
142+
out_dtype: DataType,
143+
):
139144
"""
140145
Assert the output dtype is the passed keyword dtype, e.g.
141146
142147
>>> kw = {'dtype': xp.uint8}
143-
>>> out = xp.ones(5, **kw)
144-
>>> assert_kw_dtype('ones', kw['dtype'], out.dtype)
148+
>>> out = xp.ones(5, kw=kw)
149+
>>> assert_kw_dtype('ones', kw_dtype=kw['dtype'], out_dtype=out.dtype)
145150
146151
"""
147152
f_kw_dtype = dh.dtype_to_name[kw_dtype]
@@ -222,17 +227,17 @@ def assert_default_index(func_name: str, out_dtype: DataType, repr_name="out.dty
222227

223228
def assert_shape(
224229
func_name: str,
230+
*,
225231
out_shape: Union[int, Shape],
226232
expected: Union[int, Shape],
227-
/,
228233
repr_name="out.shape",
229-
**kw,
234+
kw: dict = {},
230235
):
231236
"""
232237
Assert the output shape is as expected, e.g.
233238
234239
>>> out = xp.ones((3, 3, 3))
235-
>>> assert_shape('ones', out.shape, (3, 3, 3))
240+
>>> assert_shape('ones', out_shape=out.shape, expected=(3, 3, 3))
236241
237242
"""
238243
if isinstance(out_shape, int):
@@ -249,11 +254,10 @@ def assert_result_shape(
249254
func_name: str,
250255
in_shapes: Sequence[Shape],
251256
out_shape: Shape,
252-
/,
253257
expected: Optional[Shape] = None,
254258
*,
255259
repr_name="out.shape",
256-
**kw,
260+
kw: dict = {},
257261
):
258262
"""
259263
Assert the output shape is as expected.
@@ -262,7 +266,7 @@ def assert_result_shape(
262266
in_shapes, to test against out_shape, e.g.
263267
264268
>>> out = xp.add(xp.ones((3, 1)), xp.ones((1, 3)))
265-
>>> assert_shape('add', [(3, 1), (1, 3)], out.shape)
269+
>>> assert_result_shape('add', in_shape=[(3, 1), (1, 3)], out_shape=out.shape)
266270
267271
is equivalent to
268272
@@ -281,21 +285,21 @@ def assert_result_shape(
281285

282286
def assert_keepdimable_shape(
283287
func_name: str,
288+
*,
284289
in_shape: Shape,
285290
out_shape: Shape,
286291
axes: Tuple[int, ...],
287292
keepdims: bool,
288-
/,
289-
**kw,
293+
kw: dict = {},
290294
):
291295
"""
292296
Assert the output shape from a keepdimable function is as expected, e.g.
293297
294298
>>> x = xp.asarray([[0, 1, 2], [3, 4, 5], [6, 7, 8]])
295299
>>> out1 = xp.max(x, keepdims=False)
296300
>>> out2 = xp.max(x, keepdims=True)
297-
>>> assert_keepdimable_shape('max', x.shape, out1.shape, (0, 1), False)
298-
>>> assert_keepdimable_shape('max', x.shape, out2.shape, (0, 1), True)
301+
>>> assert_keepdimable_shape('max', in_shape=x.shape, out_shape=out1.shape, axes=(0, 1), keepdims=False)
302+
>>> assert_keepdimable_shape('max', in_shape=x.shape, out_shape=out2.shape, axes=(0, 1), keepdims=True)
299303
300304
is equivalent to
301305
@@ -307,19 +311,26 @@ def assert_keepdimable_shape(
307311
shape = tuple(1 if axis in axes else side for axis, side in enumerate(in_shape))
308312
else:
309313
shape = tuple(side for axis, side in enumerate(in_shape) if axis not in axes)
310-
assert_shape(func_name, out_shape, shape, **kw)
314+
assert_shape(func_name, out_shape=out_shape, expected=shape, kw=kw)
311315

312316

313317
def assert_0d_equals(
314-
func_name: str, x_repr: str, x_val: Array, out_repr: str, out_val: Array, **kw
318+
func_name: str,
319+
*,
320+
x_repr: str,
321+
x_val: Array,
322+
out_repr: str,
323+
out_val: Array,
324+
kw: dict = {},
315325
):
316326
"""
317327
Assert a 0d array is as expected, e.g.
318328
319329
>>> x = xp.asarray([0, 1, 2])
320-
>>> res = xp.asarray(x, copy=True)
330+
>>> kw = {'copy': True}
331+
>>> res = xp.asarray(x, **kw)
321332
>>> res[0] = 42
322-
>>> assert_0d_equals('asarray', 'x[0]', x[0], 'x[0]', res[0])
333+
>>> assert_0d_equals('asarray', x_repr='x[0]', x_val=x[0], out_repr='x[0]', out_val=res[0], kw=kw)
323334
324335
is equivalent to
325336
@@ -338,20 +349,20 @@ def assert_0d_equals(
338349

339350
def assert_scalar_equals(
340351
func_name: str,
352+
*,
341353
type_: ScalarType,
342354
idx: Shape,
343355
out: Scalar,
344356
expected: Scalar,
345-
/,
346357
repr_name: str = "out",
347-
**kw,
358+
kw: dict = {},
348359
):
349360
"""
350361
Assert a 0d array, convered to a scalar, is as expected, e.g.
351362
352363
>>> x = xp.ones(5, dtype=xp.uint8)
353364
>>> out = xp.sum(x)
354-
>>> assert_scalar_equals('sum', int, (), int(out), 5)
365+
>>> assert_scalar_equals('sum', type_int, out=(), out=int(out), expected=5)
355366
356367
is equivalent to
357368
@@ -372,13 +383,18 @@ def assert_scalar_equals(
372383

373384

374385
def assert_fill(
375-
func_name: str, fill_value: Scalar, dtype: DataType, out: Array, /, **kw
386+
func_name: str,
387+
*,
388+
fill_value: Scalar,
389+
dtype: DataType,
390+
out: Array,
391+
kw: dict = {},
376392
):
377393
"""
378394
Assert all elements of an array is as expected, e.g.
379395
380396
>>> out = xp.full(5, 42, dtype=xp.uint8)
381-
>>> assert_fill('full', 42, xp.uint8, out, 5)
397+
>>> assert_fill('full', fill_value=42, dtype=xp.uint8, out=out, kw=dict(shape=5))
382398
383399
is equivalent to
384400
@@ -408,22 +424,27 @@ def _assert_float_element(at_out: Array, at_expected: Array, msg: str):
408424

409425

410426
def assert_array_elements(
411-
func_name: str, out: Array, expected: Array, /, *, out_repr: str = "out", **kw
427+
func_name: str,
428+
*,
429+
out: Array,
430+
expected: Array,
431+
out_repr: str = "out",
432+
kw: dict = {},
412433
):
413434
"""
414435
Assert array elements are (strictly) as expected, e.g.
415436
416437
>>> x = xp.arange(5)
417438
>>> out = xp.asarray(x)
418-
>>> assert_array_elements('asarray', out, x)
439+
>>> assert_array_elements('asarray', out=out, expected=x)
419440
420441
is equivalent to
421442
422443
>>> assert xp.all(out == x)
423444
424445
"""
425446
dh.result_type(out.dtype, expected.dtype) # sanity check
426-
assert_shape(func_name, out.shape, expected.shape, **kw) # sanity check
447+
assert_shape(func_name, out_shape=out.shape, expected=expected.shape, kw=kw) # sanity check
427448
f_func = f"[{func_name}({fmt_kw(kw)})]"
428449
if out.dtype in dh.float_dtypes:
429450
for idx in sh.ndindex(out.shape):

0 commit comments

Comments
 (0)