Skip to content

Commit 98d04a4

Browse files
committed
Test mean()
1 parent 2bc4c2f commit 98d04a4

File tree

1 file changed

+45
-7
lines changed

1 file changed

+45
-7
lines changed

array_api_tests/test_statistical_functions.py

Lines changed: 45 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
from . import pytest_helpers as ph
1111
from . import xps
1212

13+
RTOL = 0.05
14+
1315

1416
@given(
1517
x=xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes(min_side=1)),
@@ -37,7 +39,7 @@ def test_min(x, data):
3739
if keepdims:
3840
idx = tuple(1 for _ in x.shape)
3941
msg = f"{out.shape=}, should be reduced dimension {idx} [{f_func}]"
40-
assert out.shape == idx
42+
assert out.shape == idx, msg
4143
else:
4244
ph.assert_shape("min", out.shape, (), **kw)
4345

@@ -84,7 +86,7 @@ def test_max(x, data):
8486
if keepdims:
8587
idx = tuple(1 for _ in x.shape)
8688
msg = f"{out.shape=}, should be reduced dimension {idx} [{f_func}]"
87-
assert out.shape == idx
89+
assert out.shape == idx, msg
8890
else:
8991
ph.assert_shape("max", out.shape, (), **kw)
9092

@@ -105,11 +107,47 @@ def test_max(x, data):
105107
assert max_ == expected, msg
106108

107109

108-
# TODO: generate kwargs
109-
@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes(min_side=1)))
110-
def test_mean(x):
111-
xp.mean(x)
112-
# TODO
110+
@given(
111+
x=xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes(min_side=1)),
112+
data=st.data(),
113+
)
114+
def test_mean(x, data):
115+
axis_strats = [st.none()]
116+
if x.shape != ():
117+
axis_strats.append(
118+
st.integers(-x.ndim, x.ndim - 1) | xps.valid_tuple_axes(x.ndim)
119+
)
120+
kw = data.draw(
121+
hh.kwargs(axis=st.one_of(axis_strats), keepdims=st.booleans()), label="kw"
122+
)
123+
124+
out = xp.mean(x, **kw)
125+
126+
ph.assert_dtype("mean", x.dtype, out.dtype)
127+
128+
f_func = f"mean({ph.fmt_kw(kw)})"
129+
130+
# TODO: support axis
131+
if kw.get("axis") is None:
132+
keepdims = kw.get("keepdims", False)
133+
if keepdims:
134+
idx = tuple(1 for _ in x.shape)
135+
msg = f"{out.shape=}, should be reduced dimension {idx} [{f_func}]"
136+
assert out.shape == idx, msg
137+
else:
138+
ph.assert_shape("max", out.shape, (), **kw)
139+
140+
# TODO: figure out NaN behaviour
141+
if not xp.any(xp.isnan(x)):
142+
_out = xp.reshape(out, ()) if keepdims else out
143+
elements = []
144+
for idx in ah.ndindex(x.shape):
145+
s = float(x[idx])
146+
elements.append(s)
147+
mean = float(_out)
148+
expected = sum(elements) / len(elements)
149+
msg = f"out={mean}, should be roughly {expected} [{f_func}]"
150+
assert math.isclose(mean, expected, rel_tol=RTOL), msg
113151

114152

115153
# TODO: generate kwargs

0 commit comments

Comments
 (0)