Skip to content

Commit fb56b22

Browse files
committed
Refactor fill assertions
1 parent 7e613fb commit fb56b22

File tree

2 files changed

+27
-29
lines changed

2 files changed

+27
-29
lines changed

array_api_tests/test_creation_functions.py

Lines changed: 25 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from . import dtype_helpers as dh
1111
from . import pytest_helpers as ph
1212
from . import xps
13-
from .typing import Shape, DataType
13+
from .typing import Shape, DataType, Array
1414

1515

1616
def assert_default_float(func_name: str, dtype: DataType):
@@ -33,11 +33,7 @@ def assert_default_int(func_name: str, dtype: DataType):
3333
assert dtype == dh.default_int, msg
3434

3535

36-
def assert_kw_dtype(
37-
func_name: str,
38-
kw_dtype: DataType,
39-
out_dtype: DataType,
40-
):
36+
def assert_kw_dtype(func_name: str, kw_dtype: DataType, out_dtype: DataType):
4137
f_kw_dtype = dh.dtype_to_name[kw_dtype]
4238
f_out_dtype = dh.dtype_to_name[out_dtype]
4339
msg = (
@@ -47,12 +43,7 @@ def assert_kw_dtype(
4743
assert out_dtype == kw_dtype, msg
4844

4945

50-
def assert_shape(
51-
func_name: str,
52-
out_shape: Shape,
53-
expected: Union[int, Shape],
54-
**kw,
55-
):
46+
def assert_shape(func_name: str, out_shape: Shape, expected: Union[int, Shape], **kw):
5647
f_kw = ", ".join(f"{k}={v}" for k, v in kw.items())
5748
msg = f"out.shape={out_shape}, but should be {expected} [{func_name}({f_kw})]"
5849
if isinstance(expected, int):
@@ -61,6 +52,18 @@ def assert_shape(
6152

6253

6354

55+
def assert_fill(func_name: str, fill: float, dtype: DataType, out: Array, **kw):
56+
f_kw = ", ".join(f"{k}={v}" for k, v in kw.items())
57+
msg = (
58+
f"out not filled with {fill} [{func_name}({f_kw})]\n"
59+
f"{out=}"
60+
)
61+
if math.isnan(fill):
62+
assert ah.all(ah.isnan(out)), msg
63+
else:
64+
assert ah.all(ah.equal(out, ah.asarray(fill, dtype=dtype))), msg
65+
66+
6467
# Testing xp.arange() requires bounding the start/stop/step arguments to only
6568
# test argument combinations compliant with the Array API, as well as to not
6669
# produce arrays with sizes not supproted by an array module.
@@ -234,8 +237,9 @@ def test_eye(n_rows, n_cols, kw):
234237
)
235238

236239

240+
237241
@st.composite
238-
def full_fill_values(draw):
242+
def full_fill_values(draw) -> st.SearchStrategy[float]:
239243
kw = draw(st.shared(hh.kwargs(dtype=st.none() | xps.scalar_dtypes()), key="full_kw"))
240244
dtype = kw.get("dtype", None) or draw(default_safe_dtypes)
241245
return draw(xps.from_dtype(dtype))
@@ -266,10 +270,7 @@ def test_full(shape, fill_value, kw):
266270
else:
267271
assert_kw_dtype("full", kw["dtype"], out.dtype)
268272
assert_shape("full", out.shape, shape, shape=shape)
269-
if dh.is_float_dtype(out.dtype) and math.isnan(fill_value):
270-
assert ah.all(ah.isnan(out)), "full() array did not equal the fill value"
271-
else:
272-
assert ah.all(ah.equal(out, ah.asarray(fill_value, dtype=dtype))), "full() array did not equal the fill value"
273+
assert_fill("full", fill_value, dtype, out, fill_value=fill_value)
273274

274275

275276
@st.composite
@@ -291,13 +292,8 @@ def test_full_like(x, fill_value, kw):
291292
ph.assert_dtype("full_like", (x.dtype,), out.dtype)
292293
else:
293294
assert_kw_dtype("full_like", kw["dtype"], out.dtype)
294-
295295
assert_shape("full_like", out.shape, x.shape)
296-
if dh.is_float_dtype(dtype) and math.isnan(fill_value):
297-
assert ah.all(ah.isnan(out)), "full_like() array did not equal the fill value"
298-
else:
299-
assert ah.all(ah.equal(out, ah.asarray(fill_value, dtype=dtype))), "full_like() array did not equal the fill value"
300-
296+
assert_fill("full_like", fill_value, dtype, out, fill_value=fill_value)
301297

302298
finite_kw = {"allow_nan": False, "allow_infinity": False}
303299

@@ -364,7 +360,7 @@ def test_linspace(num, dtype, endpoint, data):
364360
# TODO: array assertions ala test_arange
365361

366362

367-
def make_one(dtype):
363+
def make_one(dtype: DataType) -> Union[bool, float]:
368364
if dtype is None or dh.is_float_dtype(dtype):
369365
return 1.0
370366
elif dh.is_int_dtype(dtype):
@@ -382,7 +378,7 @@ def test_ones(shape, kw):
382378
assert_kw_dtype("ones", kw["dtype"], out.dtype)
383379
assert_shape("ones", out.shape, shape, shape=shape)
384380
dtype = kw.get("dtype", None) or dh.default_float
385-
assert ah.all(ah.equal(out, ah.asarray(make_one(dtype), dtype=dtype))), "ones() array did not equal 1"
381+
assert_fill("ones", make_one(dtype), dtype, out)
386382

387383

388384
@given(
@@ -397,10 +393,10 @@ def test_ones_like(x, kw):
397393
assert_kw_dtype("ones_like", kw["dtype"], out.dtype)
398394
assert_shape("ones_like", out.shape, x.shape)
399395
dtype = kw.get("dtype", None) or x.dtype
400-
assert ah.all(ah.equal(out, ah.asarray(make_one(dtype), dtype=dtype))), "ones_like() array elements did not equal 1"
396+
assert_fill("ones_like", make_one(dtype), dtype, out)
401397

402398

403-
def make_zero(dtype):
399+
def make_zero(dtype: DataType) -> Union[bool, float]:
404400
if dtype is None or dh.is_float_dtype(dtype):
405401
return 0.0
406402
elif dh.is_int_dtype(dtype):
@@ -418,7 +414,7 @@ def test_zeros(shape, kw):
418414
assert_kw_dtype("zeros", kw["dtype"], out.dtype)
419415
assert_shape("zeros", out.shape, shape, shape=shape)
420416
dtype = kw.get("dtype", None) or dh.default_float
421-
assert ah.all(ah.equal(out, ah.asarray(make_zero(dtype), dtype=dtype))), "zeros() array did not equal 0"
417+
assert_fill("zeros", make_zero(dtype), dtype, out)
422418

423419

424420
@given(
@@ -433,4 +429,4 @@ def test_zeros_like(x, kw):
433429
assert_kw_dtype("zeros_like", kw["dtype"], out.dtype)
434430
assert_shape("zeros_like", out.shape, x.shape)
435431
dtype = kw.get("dtype", None) or x.dtype
436-
assert ah.all(ah.equal(out, ah.asarray(make_zero(dtype), dtype=out.dtype))), "xp.zeros_like() array elements did not ah.all xp.equal 0"
432+
assert_fill("zeros_like", make_zero(dtype), dtype, out)

array_api_tests/typing.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,13 @@
33
__all__ = [
44
"DataType",
55
"ScalarType",
6+
"Array",
67
"Shape",
78
"Param",
89
]
910

1011
DataType = Type[Any]
1112
ScalarType = Union[Type[bool], Type[int], Type[float]]
13+
Array = Any
1214
Shape = Tuple[int, ...]
1315
Param = Tuple

0 commit comments

Comments
 (0)