Skip to content

Commit 2bc4c2f

Browse files
committed
Improve min()/max() tests
1 parent 9a0ccf8 commit 2bc4c2f

File tree

1 file changed

+96
-10
lines changed

1 file changed

+96
-10
lines changed

array_api_tests/test_statistical_functions.py

Lines changed: 96 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,108 @@
1+
import math
2+
13
from hypothesis import given
4+
from hypothesis import strategies as st
25

36
from . import _array_module as xp
7+
from . import array_helpers as ah
8+
from . import dtype_helpers as dh
49
from . import hypothesis_helpers as hh
10+
from . import pytest_helpers as ph
511
from . import xps
612

713

8-
# TODO: generate kwargs
9-
@given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes(min_side=1)))
10-
def test_min(x):
11-
xp.min(x)
12-
# TODO
14+
@given(
15+
x=xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes(min_side=1)),
16+
data=st.data(),
17+
)
18+
def test_min(x, data):
19+
axis_strats = [st.none()]
20+
if x.shape != ():
21+
axis_strats.append(
22+
st.integers(-x.ndim, x.ndim - 1) | xps.valid_tuple_axes(x.ndim)
23+
)
24+
kw = data.draw(
25+
hh.kwargs(axis=st.one_of(axis_strats), keepdims=st.booleans()), label="kw"
26+
)
1327

28+
out = xp.min(x, **kw)
1429

15-
# TODO: generate kwargs
16-
@given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes(min_side=1)))
17-
def test_max(x):
18-
xp.max(x)
19-
# TODO
30+
ph.assert_dtype("min", x.dtype, out.dtype)
31+
32+
f_func = f"min({ph.fmt_kw(kw)})"
33+
34+
# TODO: support axis
35+
if kw.get("axis") is None:
36+
keepdims = kw.get("keepdims", False)
37+
if keepdims:
38+
idx = tuple(1 for _ in x.shape)
39+
msg = f"{out.shape=}, should be reduced dimension {idx} [{f_func}]"
40+
assert out.shape == idx
41+
else:
42+
ph.assert_shape("min", out.shape, (), **kw)
43+
44+
# TODO: figure out NaN behaviour
45+
if dh.is_int_dtype(x.dtype) or not xp.any(xp.isnan(x)):
46+
_out = xp.reshape(out, ()) if keepdims else out
47+
scalar_type = dh.get_scalar_type(out.dtype)
48+
elements = []
49+
for idx in ah.ndindex(x.shape):
50+
s = scalar_type(x[idx])
51+
elements.append(s)
52+
min_ = scalar_type(_out)
53+
expected = min(elements)
54+
msg = f"out={min_}, should be {expected} [{f_func}]"
55+
if math.isnan(min_):
56+
assert math.isnan(expected), msg
57+
else:
58+
assert min_ == expected, msg
59+
60+
61+
@given(
62+
x=xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes(min_side=1)),
63+
data=st.data(),
64+
)
65+
def test_max(x, data):
66+
axis_strats = [st.none()]
67+
if x.shape != ():
68+
axis_strats.append(
69+
st.integers(-x.ndim, x.ndim - 1) | xps.valid_tuple_axes(x.ndim)
70+
)
71+
kw = data.draw(
72+
hh.kwargs(axis=st.one_of(axis_strats), keepdims=st.booleans()), label="kw"
73+
)
74+
75+
out = xp.max(x, **kw)
76+
77+
ph.assert_dtype("max", x.dtype, out.dtype)
78+
79+
f_func = f"max({ph.fmt_kw(kw)})"
80+
81+
# TODO: support axis
82+
if kw.get("axis") is None:
83+
keepdims = kw.get("keepdims", False)
84+
if keepdims:
85+
idx = tuple(1 for _ in x.shape)
86+
msg = f"{out.shape=}, should be reduced dimension {idx} [{f_func}]"
87+
assert out.shape == idx
88+
else:
89+
ph.assert_shape("max", out.shape, (), **kw)
90+
91+
# TODO: figure out NaN behaviour
92+
if dh.is_int_dtype(x.dtype) or not xp.any(xp.isnan(x)):
93+
_out = xp.reshape(out, ()) if keepdims else out
94+
scalar_type = dh.get_scalar_type(out.dtype)
95+
elements = []
96+
for idx in ah.ndindex(x.shape):
97+
s = scalar_type(x[idx])
98+
elements.append(s)
99+
max_ = scalar_type(_out)
100+
expected = max(elements)
101+
msg = f"out={max_}, should be {expected} [{f_func}]"
102+
if math.isnan(max_):
103+
assert math.isnan(expected), msg
104+
else:
105+
assert max_ == expected, msg
20106

21107

22108
# TODO: generate kwargs

0 commit comments

Comments
 (0)