Skip to content

Commit a3e5a01

Browse files
committed
Move creation assert helpers to pytest_helpers
1 parent a4a0a35 commit a3e5a01

File tree

2 files changed

+116
-95
lines changed

2 files changed

+116
-95
lines changed

array_api_tests/pytest_helpers.py

Lines changed: 78 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,26 @@
1+
import math
12
from inspect import getfullargspec
2-
from typing import Optional, Tuple
3+
from typing import Any, Dict, Optional, Tuple, Union
34

5+
from . import array_helpers as ah
46
from . import dtype_helpers as dh
57
from . import function_stubs
6-
from .typing import DataType
8+
from .typing import Array, DataType, Scalar, Shape
79

10+
__all__ = [
11+
"raises",
12+
"doesnt_raise",
13+
"nargs",
14+
"assert_dtype",
15+
"assert_kw_dtype",
16+
"assert_default_float",
17+
"assert_default_int",
18+
"assert_shape",
19+
"assert_fill",
20+
]
821

9-
def raises(exceptions, function, message=''):
22+
23+
def raises(exceptions, function, message=""):
1024
"""
1125
Like pytest.raises() except it allows custom error messages
1226
"""
@@ -16,11 +30,14 @@ def raises(exceptions, function, message=''):
1630
return
1731
except Exception as e:
1832
if message:
19-
raise AssertionError(f"Unexpected exception {e!r} (expected {exceptions}): {message}")
33+
raise AssertionError(
34+
f"Unexpected exception {e!r} (expected {exceptions}): {message}"
35+
)
2036
raise AssertionError(f"Unexpected exception {e!r} (expected {exceptions})")
2137
raise AssertionError(message)
2238

23-
def doesnt_raise(function, message=''):
39+
40+
def doesnt_raise(function, message=""):
2441
"""
2542
The inverse of raises().
2643
@@ -36,10 +53,15 @@ def doesnt_raise(function, message=''):
3653
raise AssertionError(f"Unexpected exception {e!r}: {message}")
3754
raise AssertionError(f"Unexpected exception {e!r}")
3855

56+
3957
def nargs(func_name):
4058
return len(getfullargspec(getattr(function_stubs, func_name)).args)
4159

4260

61+
def _fmt_kw(kw: Dict[str, Any]) -> str:
62+
return ", ".join(f"{k}={v}" for k, v in kw.items())
63+
64+
4365
def assert_dtype(
4466
func_name: str,
4567
in_dtypes: Tuple[DataType, ...],
@@ -60,3 +82,54 @@ def assert_dtype(
6082
assert out_dtype == expected, msg
6183

6284

85+
def assert_kw_dtype(func_name: str, kw_dtype: DataType, out_dtype: DataType):
86+
f_kw_dtype = dh.dtype_to_name[kw_dtype]
87+
f_out_dtype = dh.dtype_to_name[out_dtype]
88+
msg = (
89+
f"out.dtype={f_out_dtype}, but should be {f_kw_dtype} "
90+
f"[{func_name}(dtype={f_kw_dtype})]"
91+
)
92+
assert out_dtype == kw_dtype, msg
93+
94+
95+
def assert_default_float(func_name: str, dtype: DataType):
96+
f_dtype = dh.dtype_to_name[dtype]
97+
f_default = dh.dtype_to_name[dh.default_float]
98+
msg = (
99+
f"out.dtype={f_dtype}, should be default "
100+
f"floating-point dtype {f_default} [{func_name}()]"
101+
)
102+
assert dtype == dh.default_float, msg
103+
104+
105+
def assert_default_int(func_name: str, dtype: DataType):
106+
f_dtype = dh.dtype_to_name[dtype]
107+
f_default = dh.dtype_to_name[dh.default_int]
108+
msg = (
109+
f"out.dtype={f_dtype}, should be default "
110+
f"integer dtype {f_default} [{func_name}()]"
111+
)
112+
assert dtype == dh.default_int, msg
113+
114+
115+
def assert_shape(
116+
func_name: str, out_shape: Union[int, Shape], expected: Union[int, Shape], /, **kw
117+
):
118+
if isinstance(out_shape, int):
119+
out_shape = (out_shape,)
120+
if isinstance(expected, int):
121+
expected = (expected,)
122+
msg = (
123+
f"out.shape={out_shape}, but should be {expected} [{func_name}({_fmt_kw(kw)})]"
124+
)
125+
assert out_shape == expected, msg
126+
127+
128+
def assert_fill(
129+
func_name: str, fill_value: Scalar, dtype: DataType, out: Array, /, **kw
130+
):
131+
msg = f"out not filled with {fill_value} [{func_name}({_fmt_kw(kw)})]\n{out=}"
132+
if math.isnan(fill_value):
133+
assert ah.all(ah.isnan(out)), msg
134+
else:
135+
assert ah.all(ah.equal(out, ah.asarray(fill_value, dtype=dtype))), msg

array_api_tests/test_creation_functions.py

Lines changed: 38 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,17 @@
11
import math
2-
from typing import Union, Any, Tuple, NamedTuple, Iterator
32
from itertools import count
3+
from typing import Any, Iterator, NamedTuple, Tuple, Union
44

5-
from hypothesis import assume, given, strategies as st
5+
from hypothesis import assume, given
6+
from hypothesis import strategies as st
67

78
from . import _array_module as xp
89
from . import array_helpers as ah
9-
from . import hypothesis_helpers as hh
1010
from . import dtype_helpers as dh
11+
from . import hypothesis_helpers as hh
1112
from . import pytest_helpers as ph
1213
from . import xps
13-
from .typing import Shape, DataType, Array, Scalar
14+
from .typing import DataType, Scalar
1415

1516

1617
@st.composite
@@ -28,59 +29,6 @@ def specified_kwargs(draw, *keys_values_defaults: Tuple[str, Any, Any]):
2829
return kw
2930

3031

31-
def assert_default_float(func_name: str, dtype: DataType):
32-
f_dtype = dh.dtype_to_name[dtype]
33-
f_default = dh.dtype_to_name[dh.default_float]
34-
msg = (
35-
f"out.dtype={f_dtype}, should be default "
36-
f"floating-point dtype {f_default} [{func_name}()]"
37-
)
38-
assert dtype == dh.default_float, msg
39-
40-
41-
def assert_default_int(func_name: str, dtype: DataType):
42-
f_dtype = dh.dtype_to_name[dtype]
43-
f_default = dh.dtype_to_name[dh.default_int]
44-
msg = (
45-
f"out.dtype={f_dtype}, should be default "
46-
f"integer dtype {f_default} [{func_name}()]"
47-
)
48-
assert dtype == dh.default_int, msg
49-
50-
51-
def assert_kw_dtype(func_name: str, kw_dtype: DataType, out_dtype: DataType):
52-
f_kw_dtype = dh.dtype_to_name[kw_dtype]
53-
f_out_dtype = dh.dtype_to_name[out_dtype]
54-
msg = (
55-
f"out.dtype={f_out_dtype}, but should be {f_kw_dtype} "
56-
f"[{func_name}(dtype={f_kw_dtype})]"
57-
)
58-
assert out_dtype == kw_dtype, msg
59-
60-
61-
def assert_shape(
62-
func_name: str, out_shape: Union[int, Shape], expected: Union[int, Shape], /, **kw
63-
):
64-
if isinstance(out_shape, int):
65-
out_shape = (out_shape,)
66-
if isinstance(expected, int):
67-
expected = (expected,)
68-
f_kw = ", ".join(f"{k}={v}" for k, v in kw.items())
69-
msg = f"out.shape={out_shape}, but should be {expected} [{func_name}({f_kw})]"
70-
assert out_shape == expected, msg
71-
72-
73-
def assert_fill(
74-
func_name: str, fill_value: Scalar, dtype: DataType, out: Array, /, **kw
75-
):
76-
f_kw = ", ".join(f"{k}={v}" for k, v in kw.items())
77-
msg = f"out not filled with {fill_value} [{func_name}({f_kw})]\n" f"{out=}"
78-
if math.isnan(fill_value):
79-
assert ah.all(ah.isnan(out)), msg
80-
else:
81-
assert ah.all(ah.equal(out, ah.asarray(fill_value, dtype=dtype))), msg
82-
83-
8432
class frange(NamedTuple):
8533
start: float
8634
stop: float
@@ -210,9 +158,9 @@ def test_arange(dtype, data):
210158

211159
if dtype is None:
212160
if all_int:
213-
assert_default_int("arange", out.dtype)
161+
ph.assert_default_int("arange", out.dtype)
214162
else:
215-
assert_default_float("arange", out.dtype)
163+
ph.assert_default_float("arange", out.dtype)
216164
else:
217165
assert out.dtype == dtype
218166
assert out.ndim == 1, f"{out.ndim=}, but should be 1 [linspace()]"
@@ -253,10 +201,10 @@ def test_arange(dtype, data):
253201
def test_empty(shape, kw):
254202
out = xp.empty(shape, **kw)
255203
if kw.get("dtype", None) is None:
256-
assert_default_float("empty", out.dtype)
204+
ph.assert_default_float("empty", out.dtype)
257205
else:
258-
assert_kw_dtype("empty", kw["dtype"], out.dtype)
259-
assert_shape("empty", out.shape, shape, shape=shape)
206+
ph.assert_kw_dtype("empty", kw["dtype"], out.dtype)
207+
ph.assert_shape("empty", out.shape, shape, shape=shape)
260208

261209

262210
@given(
@@ -268,8 +216,8 @@ def test_empty_like(x, kw):
268216
if kw.get("dtype", None) is None:
269217
ph.assert_dtype("empty_like", (x.dtype,), out.dtype)
270218
else:
271-
assert_kw_dtype("empty_like", kw["dtype"], out.dtype)
272-
assert_shape("empty_like", out.shape, x.shape)
219+
ph.assert_kw_dtype("empty_like", kw["dtype"], out.dtype)
220+
ph.assert_shape("empty_like", out.shape, x.shape)
273221

274222

275223
@given(
@@ -283,11 +231,11 @@ def test_empty_like(x, kw):
283231
def test_eye(n_rows, n_cols, kw):
284232
out = xp.eye(n_rows, n_cols, **kw)
285233
if kw.get("dtype", None) is None:
286-
assert_default_float("eye", out.dtype)
234+
ph.assert_default_float("eye", out.dtype)
287235
else:
288-
assert_kw_dtype("eye", kw["dtype"], out.dtype)
236+
ph.assert_kw_dtype("eye", kw["dtype"], out.dtype)
289237
_n_cols = n_rows if n_cols is None else n_cols
290-
assert_shape("eye", out.shape, (n_rows, _n_cols), n_rows=n_rows, n_cols=n_cols)
238+
ph.assert_shape("eye", out.shape, (n_rows, _n_cols), n_rows=n_rows, n_cols=n_cols)
291239
f_func = f"[eye({n_rows=}, {n_cols=})]"
292240
for i in range(n_rows):
293241
for j in range(_n_cols):
@@ -336,13 +284,13 @@ def test_full(shape, fill_value, kw):
336284
if isinstance(fill_value, bool):
337285
pass # TODO
338286
elif isinstance(fill_value, int):
339-
assert_default_int("full", out.dtype)
287+
ph.assert_default_int("full", out.dtype)
340288
else:
341-
assert_default_float("full", out.dtype)
289+
ph.assert_default_float("full", out.dtype)
342290
else:
343-
assert_kw_dtype("full", kw["dtype"], out.dtype)
344-
assert_shape("full", out.shape, shape, shape=shape)
345-
assert_fill("full", fill_value, dtype, out, fill_value=fill_value)
291+
ph.assert_kw_dtype("full", kw["dtype"], out.dtype)
292+
ph.assert_shape("full", out.shape, shape, shape=shape)
293+
ph.assert_fill("full", fill_value, dtype, out, fill_value=fill_value)
346294

347295

348296
@st.composite
@@ -365,9 +313,9 @@ def test_full_like(x, fill_value, kw):
365313
if kw.get("dtype", None) is None:
366314
ph.assert_dtype("full_like", (x.dtype,), out.dtype)
367315
else:
368-
assert_kw_dtype("full_like", kw["dtype"], out.dtype)
369-
assert_shape("full_like", out.shape, x.shape)
370-
assert_fill("full_like", fill_value, dtype, out, fill_value=fill_value)
316+
ph.assert_kw_dtype("full_like", kw["dtype"], out.dtype)
317+
ph.assert_shape("full_like", out.shape, x.shape)
318+
ph.assert_fill("full_like", fill_value, dtype, out, fill_value=fill_value)
371319

372320

373321
finite_kw = {"allow_nan": False, "allow_infinity": False}
@@ -420,7 +368,7 @@ def test_linspace(num, dtype, endpoint, data):
420368
)
421369
out = xp.linspace(start, stop, num, **kw)
422370

423-
assert_shape("linspace", out.shape, num, start=stop, stop=stop, num=num)
371+
ph.assert_shape("linspace", out.shape, num, start=stop, stop=stop, num=num)
424372
f_func = f"[linspace({start=}, {stop=}, {num=})]"
425373
if num > 0:
426374
assert ah.equal(
@@ -452,12 +400,12 @@ def make_one(dtype: DataType) -> Scalar:
452400
def test_ones(shape, kw):
453401
out = xp.ones(shape, **kw)
454402
if kw.get("dtype", None) is None:
455-
assert_default_float("ones", out.dtype)
403+
ph.assert_default_float("ones", out.dtype)
456404
else:
457-
assert_kw_dtype("ones", kw["dtype"], out.dtype)
458-
assert_shape("ones", out.shape, shape, shape=shape)
405+
ph.assert_kw_dtype("ones", kw["dtype"], out.dtype)
406+
ph.assert_shape("ones", out.shape, shape, shape=shape)
459407
dtype = kw.get("dtype", None) or dh.default_float
460-
assert_fill("ones", make_one(dtype), dtype, out)
408+
ph.assert_fill("ones", make_one(dtype), dtype, out)
461409

462410

463411
@given(
@@ -469,10 +417,10 @@ def test_ones_like(x, kw):
469417
if kw.get("dtype", None) is None:
470418
ph.assert_dtype("ones_like", (x.dtype,), out.dtype)
471419
else:
472-
assert_kw_dtype("ones_like", kw["dtype"], out.dtype)
473-
assert_shape("ones_like", out.shape, x.shape)
420+
ph.assert_kw_dtype("ones_like", kw["dtype"], out.dtype)
421+
ph.assert_shape("ones_like", out.shape, x.shape)
474422
dtype = kw.get("dtype", None) or x.dtype
475-
assert_fill("ones_like", make_one(dtype), dtype, out)
423+
ph.assert_fill("ones_like", make_one(dtype), dtype, out)
476424

477425

478426
def make_zero(dtype: DataType) -> Scalar:
@@ -488,12 +436,12 @@ def make_zero(dtype: DataType) -> Scalar:
488436
def test_zeros(shape, kw):
489437
out = xp.zeros(shape, **kw)
490438
if kw.get("dtype", None) is None:
491-
assert_default_float("zeros", out.dtype)
439+
ph.assert_default_float("zeros", out.dtype)
492440
else:
493-
assert_kw_dtype("zeros", kw["dtype"], out.dtype)
494-
assert_shape("zeros", out.shape, shape, shape=shape)
441+
ph.assert_kw_dtype("zeros", kw["dtype"], out.dtype)
442+
ph.assert_shape("zeros", out.shape, shape, shape=shape)
495443
dtype = kw.get("dtype", None) or dh.default_float
496-
assert_fill("zeros", make_zero(dtype), dtype, out)
444+
ph.assert_fill("zeros", make_zero(dtype), dtype, out)
497445

498446

499447
@given(
@@ -505,7 +453,7 @@ def test_zeros_like(x, kw):
505453
if kw.get("dtype", None) is None:
506454
ph.assert_dtype("zeros_like", (x.dtype,), out.dtype)
507455
else:
508-
assert_kw_dtype("zeros_like", kw["dtype"], out.dtype)
509-
assert_shape("zeros_like", out.shape, x.shape)
456+
ph.assert_kw_dtype("zeros_like", kw["dtype"], out.dtype)
457+
ph.assert_shape("zeros_like", out.shape, x.shape)
510458
dtype = kw.get("dtype", None) or x.dtype
511-
assert_fill("zeros_like", make_zero(dtype), dtype, out)
459+
ph.assert_fill("zeros_like", make_zero(dtype), dtype, out)

0 commit comments

Comments
 (0)