Skip to content

Commit f9b679f

Browse files
committed
Clearer size/shape assertions
1 parent f6ea2c1 commit f9b679f

File tree

1 file changed

+17
-12
lines changed

1 file changed

+17
-12
lines changed

array_api_tests/test_creation_functions.py

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -59,12 +59,14 @@ def assert_kw_dtype(func_name: str, kw_dtype: DataType, out_dtype: DataType):
5959

6060

6161
def assert_shape(
62-
func_name: str, out_shape: Shape, expected: Union[int, Shape], /, **kw
62+
func_name: str, out_shape: Union[int, Shape], expected: Union[int, Shape], /, **kw
6363
):
64-
f_kw = ", ".join(f"{k}={v}" for k, v in kw.items())
65-
msg = f"out.shape={out_shape}, but should be {expected} [{func_name}({f_kw})]"
64+
if isinstance(out_shape, int):
65+
out_shape = (out_shape,)
6666
if isinstance(expected, int):
6767
expected = (expected,)
68+
f_kw = ", ".join(f"{k}={v}" for k, v in kw.items())
69+
msg = f"out.shape={out_shape}, but should be {expected} [{func_name}({f_kw})]"
6870
assert out_shape == expected, msg
6971

7072

@@ -183,7 +185,7 @@ def test_arange(dtype, data):
183185
else:
184186
_dtype = dtype
185187

186-
# sanity check
188+
# sanity checks
187189
if dh.is_int_dtype(_dtype):
188190
m, M = dh.dtype_ranges[_dtype]
189191
assert m <= _start <= M
@@ -213,9 +215,10 @@ def test_arange(dtype, data):
213215
assert_default_float("arange", out.dtype)
214216
else:
215217
assert out.dtype == dtype
216-
assert out.ndim == 1, f"{out.ndim=}, should be 1 [linspace()]"
218+
assert out.ndim == 1, f"{out.ndim=}, but should be 1 [linspace()]"
219+
f_func = f"[linspace({start=}, {stop=}, {step=})]"
217220
if dh.is_int_dtype(_dtype):
218-
assert out.size == size
221+
assert out.size == size, f"{out.size=}, but should be {size} {f_func}"
219222
else:
220223
# We check size is roughly as expected to avoid edge cases e.g.
221224
#
@@ -224,7 +227,11 @@ def test_arange(dtype, data):
224227
# >>> xp.arange(2, step=0.3333333333333333)
225228
# [0.0, 0.33, 0.66, 1.0, 1.33, 1.66]
226229
#
227-
assert math.floor(math.sqrt(size)) <= out.size <= math.ceil(size ** 2)
230+
min_size = math.floor(size * 0.9)
231+
max_size = math.ceil(size * 1.1)
232+
assert (
233+
min_size <= out.size <= max_size
234+
), f"{out.size=}, but should be roughly {size} {f_func}"
228235
assume(out.size == size)
229236
if dh.is_int_dtype(_dtype):
230237
ah.assert_exactly_equal(out, ah.asarray(list(r), dtype=_dtype))
@@ -407,24 +414,22 @@ def test_linspace(num, dtype, endpoint, data):
407414
out = xp.linspace(start, stop, num, **kw)
408415

409416
assert_shape("linspace", out.shape, num, start=stop, stop=stop, num=num)
417+
f_func = f"[linspace({start=}, {stop=}, {num=})]"
410418
if num > 0:
411419
assert ah.equal(
412420
out[0], ah.asarray(start, dtype=out.dtype)
413-
), f"out[0]={out[0]}, but should be {start=} [linspace({stop=}, {num=})]"
421+
), f"out[0]={out[0]}, but should be {start} {f_func}"
414422
if endpoint:
415423
if num > 1:
416424
assert ah.equal(
417425
out[-1], ah.asarray(stop, dtype=out.dtype)
418-
), f"out[-1]={out[-1]}, but should be {stop=} [linspace({start=}, {num=})]"
426+
), f"out[-1]={out[-1]}, but should be {stop} {f_func}"
419427
else:
420428
# linspace(..., num, endpoint=True) should return an array equivalent to
421429
# the first num elements when endpoint=False
422430
expected = xp.linspace(start, stop, num + 1, dtype=dtype, endpoint=True)
423431
expected = expected[:-1]
424432
ah.assert_exactly_equal(out, expected)
425-
assert (
426-
out.size == num
427-
), f"{out.size=}, but should be {num=} [linspace({start=}, {stop=})]"
428433

429434

430435
def make_one(dtype: DataType) -> Scalar:

0 commit comments

Comments
 (0)