Skip to content

Commit 816bf41

Browse files
committed
Test var() and sum()
1 parent 37bd6df commit 816bf41

File tree

1 file changed

+93
-16
lines changed

1 file changed

+93
-16
lines changed

array_api_tests/test_statistical_functions.py

Lines changed: 93 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -56,10 +56,10 @@ def axes_ndindex(shape: Shape, axes: Tuple[int, ...]) -> Iterator[Tuple[Shape, .
5656

5757
def assert_keepdimable_shape(
5858
func_name: str,
59+
out_shape: Shape,
5960
in_shape: Shape,
6061
axes: Tuple[int, ...],
6162
keepdims: bool,
62-
out_shape: Shape,
6363
/,
6464
**kw,
6565
):
@@ -108,7 +108,7 @@ def test_min(x, data):
108108
ph.assert_dtype("min", x.dtype, out.dtype)
109109
_axes = normalise_axis(kw.get("axis", None), x.ndim)
110110
assert_keepdimable_shape(
111-
"min", x.shape, _axes, kw.get("keepdims", False), out.shape, **kw
111+
"min", out.shape, x.shape, _axes, kw.get("keepdims", False), **kw
112112
)
113113
scalar_type = dh.get_scalar_type(out.dtype)
114114
for indices, out_idx in zip(axes_ndindex(x.shape, _axes), ah.ndindex(out.shape)):
@@ -137,7 +137,7 @@ def test_max(x, data):
137137
ph.assert_dtype("max", x.dtype, out.dtype)
138138
_axes = normalise_axis(kw.get("axis", None), x.ndim)
139139
assert_keepdimable_shape(
140-
"max", x.shape, _axes, kw.get("keepdims", False), out.shape, **kw
140+
"max", out.shape, x.shape, _axes, kw.get("keepdims", False), **kw
141141
)
142142
scalar_type = dh.get_scalar_type(out.dtype)
143143
for indices, out_idx in zip(axes_ndindex(x.shape, _axes), ah.ndindex(out.shape)):
@@ -166,7 +166,7 @@ def test_mean(x, data):
166166
ph.assert_dtype("mean", x.dtype, out.dtype)
167167
_axes = normalise_axis(kw.get("axis", None), x.ndim)
168168
assert_keepdimable_shape(
169-
"mean", x.shape, _axes, kw.get("keepdims", False), out.shape, **kw
169+
"mean", out.shape, x.shape, _axes, kw.get("keepdims", False), **kw
170170
)
171171
for indices, out_idx in zip(axes_ndindex(x.shape, _axes), ah.ndindex(out.shape)):
172172
mean = float(out[out_idx])
@@ -217,7 +217,7 @@ def test_prod(x, data):
217217
ph.assert_dtype("prod", x.dtype, out.dtype, _dtype)
218218
_axes = normalise_axis(kw.get("axis", None), x.ndim)
219219
assert_keepdimable_shape(
220-
"prod", x.shape, _axes, kw.get("keepdims", False), out.shape, **kw
220+
"prod", out.shape, x.shape, _axes, kw.get("keepdims", False), **kw
221221
)
222222
scalar_type = dh.get_scalar_type(out.dtype)
223223
for indices, out_idx in zip(axes_ndindex(x.shape, _axes), ah.ndindex(out.shape)):
@@ -264,20 +264,97 @@ def test_std(x, data):
264264

265265
ph.assert_dtype("std", x.dtype, out.dtype)
266266
assert_keepdimable_shape(
267-
"std", x.shape, _axes, kw.get("keepdims", False), out.shape, **kw
267+
"std", out.shape, x.shape, _axes, kw.get("keepdims", False), **kw
268268
)
269269
# We can't easily test the result(s) as standard deviation methods vary a lot
270270

271271

272-
# TODO: generate kwargs
273-
@given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes(min_side=1)))
274-
def test_sum(x):
275-
xp.sum(x)
276-
# TODO
272+
@given(
273+
x=xps.arrays(
274+
dtype=xps.floating_dtypes(),
275+
shape=hh.shapes(min_side=1),
276+
elements={"allow_nan": False},
277+
).filter(lambda x: x.size >= 2),
278+
data=st.data(),
279+
)
280+
def test_var(x, data):
281+
axis = data.draw(axes(x.ndim), label="axis")
282+
_axes = normalise_axis(axis, x.ndim)
283+
N = sum(side for axis, side in enumerate(x.shape) if axis not in _axes)
284+
correction = data.draw(
285+
st.floats(0.0, N, allow_infinity=False, allow_nan=False) | st.integers(0, N),
286+
label="correction",
287+
)
288+
keepdims = data.draw(st.booleans(), label="keepdims")
289+
kw = data.draw(
290+
hh.specified_kwargs(
291+
("axis", axis, None),
292+
("correction", correction, 0.0),
293+
("keepdims", keepdims, False),
294+
),
295+
label="kw",
296+
)
297+
298+
out = xp.var(x, **kw)
299+
300+
ph.assert_dtype("var", x.dtype, out.dtype)
301+
assert_keepdimable_shape(
302+
"var", out.shape, x.shape, _axes, kw.get("keepdims", False), **kw
303+
)
304+
# We can't easily test the result(s) as variance methods vary a lot
305+
306+
307+
@given(
308+
x=xps.arrays(
309+
dtype=xps.numeric_dtypes(),
310+
shape=hh.shapes(min_side=1),
311+
elements={"allow_nan": False},
312+
),
313+
data=st.data(),
314+
)
315+
def test_sum(x, data):
316+
kw = data.draw(
317+
hh.kwargs(
318+
axis=axes(x.ndim),
319+
dtype=st.none() | st.just(x.dtype), # TODO: all valid dtypes
320+
keepdims=st.booleans(),
321+
),
322+
label="kw",
323+
)
277324

325+
out = xp.sum(x, **kw)
278326

279-
# TODO: generate kwargs
280-
@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes(min_side=1)))
281-
def test_var(x):
282-
xp.var(x)
283-
# TODO
327+
dtype = kw.get("dtype", None)
328+
if dtype is None:
329+
if dh.is_int_dtype(x.dtype):
330+
m, M = dh.dtype_ranges[x.dtype]
331+
d_m, d_M = dh.dtype_ranges[dh.default_int]
332+
if m < d_m or M > d_M:
333+
_dtype = x.dtype
334+
else:
335+
_dtype = dh.default_int
336+
else:
337+
if dh.dtype_nbits[x.dtype] > dh.dtype_nbits[dh.default_float]:
338+
_dtype = x.dtype
339+
else:
340+
_dtype = dh.default_float
341+
else:
342+
_dtype = dtype
343+
ph.assert_dtype("sum", x.dtype, out.dtype, _dtype)
344+
_axes = normalise_axis(kw.get("axis", None), x.ndim)
345+
assert_keepdimable_shape(
346+
"sum", out.shape, x.shape, _axes, kw.get("keepdims", False), **kw
347+
)
348+
scalar_type = dh.get_scalar_type(out.dtype)
349+
for indices, out_idx in zip(axes_ndindex(x.shape, _axes), ah.ndindex(out.shape)):
350+
sum_ = scalar_type(out[out_idx])
351+
assume(not math.isinf(sum_))
352+
elements = []
353+
for idx in indices:
354+
s = scalar_type(x[idx])
355+
elements.append(s)
356+
expected = sum(elements)
357+
if dh.is_int_dtype(out.dtype):
358+
m, M = dh.dtype_ranges[out.dtype]
359+
assume(m <= expected <= M)
360+
assert_equals("sum", dh.get_scalar_type(out.dtype), out_idx, sum_, expected)

0 commit comments

Comments
 (0)