Skip to content

Commit 8686654

Browse files
committed
Test axes results
1 parent d44c689 commit 8686654

File tree

2 files changed

+166
-127
lines changed

2 files changed

+166
-127
lines changed

array_api_tests/meta/test_utils.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import pytest
22

3-
from ..test_signatures import extension_module
43
from ..test_creation_functions import frange
4+
from ..test_signatures import extension_module
5+
from ..test_statistical_functions import axes_ndindex
56

67

78
def test_extension_module_is_extension():
@@ -24,3 +25,21 @@ def test_extension_func_is_not_extension():
2425
def test_frange(r, size, elements):
2526
assert len(r) == size
2627
assert list(r) == elements
28+
29+
30+
@pytest.mark.parametrize(
31+
"shape, axes, expected",
32+
[
33+
((), (), [((),)]),
34+
(
35+
(2, 2),
36+
(0,),
37+
[
38+
((0, 0), (1, 0)),
39+
((0, 1), (1, 1)),
40+
],
41+
),
42+
],
43+
)
44+
def test_axes_ndindex(shape, axes, expected):
45+
assert list(axes_ndindex(shape, axes)) == expected

array_api_tests/test_statistical_functions.py

Lines changed: 146 additions & 126 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import math
2-
from typing import Optional, Union
2+
from itertools import product
3+
from typing import Iterator, Optional, Tuple, Union
34

45
from hypothesis import assume, given
56
from hypothesis import strategies as st
@@ -21,23 +22,82 @@ def axes(ndim: int) -> st.SearchStrategy[Optional[Union[int, Shape]]]:
2122
return st.one_of(axes_strats)
2223

2324

25+
def normalise_axis(
26+
axis: Optional[Union[int, Tuple[int, ...]]], ndim: int
27+
) -> Tuple[int, ...]:
28+
if axis is None:
29+
return tuple(range(ndim))
30+
axes = axis if isinstance(axis, tuple) else (axis,)
31+
axes = tuple(axis if axis >= 0 else ndim + axis for axis in axes)
32+
return axes
33+
34+
35+
def axes_ndindex(shape: Shape, axes: Tuple[int, ...]) -> Iterator[Tuple[Shape, ...]]:
36+
base_iterables = []
37+
axes_iterables = []
38+
for axis, side in enumerate(shape):
39+
if axis in axes:
40+
base_iterables.append((None,))
41+
axes_iterables.append(range(side))
42+
else:
43+
base_iterables.append(range(side))
44+
axes_iterables.append((None,))
45+
for base_idx in product(*base_iterables):
46+
indices = []
47+
for idx in product(*axes_iterables):
48+
idx = list(idx)
49+
for axis, side in enumerate(idx):
50+
if axis not in axes:
51+
idx[axis] = base_idx[axis]
52+
idx = tuple(idx)
53+
indices.append(idx)
54+
yield tuple(indices)
55+
56+
57+
def assert_keepdimable_shape(
58+
func_name: str,
59+
in_shape: Shape,
60+
axes: Tuple[int, ...],
61+
keepdims: bool,
62+
out_shape: Shape,
63+
/,
64+
**kw,
65+
):
66+
if keepdims:
67+
shape = tuple(1 if axis in axes else side for axis, side in enumerate(in_shape))
68+
else:
69+
shape = tuple(side for axis, side in enumerate(in_shape) if axis not in axes)
70+
ph.assert_shape(func_name, out_shape, shape, **kw)
71+
72+
2473
def assert_equals(
25-
func_name: str, type_: ScalarType, out: Scalar, expected: Scalar, /, **kw
74+
func_name: str,
75+
type_: ScalarType,
76+
idx: Shape,
77+
out: Scalar,
78+
expected: Scalar,
79+
/,
80+
**kw,
2681
):
82+
out_repr = "out" if idx == () else f"out[{idx}]"
2783
f_func = f"{func_name}({ph.fmt_kw(kw)})"
2884
if type_ is bool or type_ is int:
29-
msg = f"{out=}, should be {expected} [{f_func}]"
85+
msg = f"{out_repr}={out}, should be {expected} [{f_func}]"
3086
assert out == expected, msg
3187
elif math.isnan(expected):
32-
msg = f"{out=}, should be {expected} [{f_func}]"
88+
msg = f"{out_repr}={out}, should be {expected} [{f_func}]"
3389
assert math.isnan(out), msg
3490
else:
35-
msg = f"{out=}, should be roughly {expected} [{f_func}]"
91+
msg = f"{out_repr}={out}, should be roughly {expected} [{f_func}]"
3692
assert math.isclose(out, expected, rel_tol=0.05), msg
3793

3894

3995
@given(
40-
x=xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes(min_side=1)),
96+
x=xps.arrays(
97+
dtype=xps.numeric_dtypes(),
98+
shape=hh.shapes(min_side=1),
99+
elements={"allow_nan": False},
100+
),
41101
data=st.data(),
42102
)
43103
def test_min(x, data):
@@ -46,34 +106,27 @@ def test_min(x, data):
46106
out = xp.min(x, **kw)
47107

48108
ph.assert_dtype("min", x.dtype, out.dtype)
49-
50-
f_func = f"min({ph.fmt_kw(kw)})"
51-
52-
# TODO: support axis
53-
if kw.get("axis", None) is None:
54-
keepdims = kw.get("keepdims", False)
55-
if keepdims:
56-
shape = tuple(1 for _ in x.shape)
57-
msg = f"{out.shape=}, should be reduced dimension {shape} [{f_func}]"
58-
assert out.shape == shape, msg
59-
else:
60-
ph.assert_shape("min", out.shape, (), **kw)
61-
62-
# TODO: figure out NaN behaviour
63-
if dh.is_int_dtype(x.dtype) or not xp.any(xp.isnan(x)):
64-
_out = xp.reshape(out, ()) if keepdims else out
65-
scalar_type = dh.get_scalar_type(out.dtype)
66-
elements = []
67-
for idx in ah.ndindex(x.shape):
68-
s = scalar_type(x[idx])
69-
elements.append(s)
70-
min_ = scalar_type(_out)
71-
expected = min(elements)
72-
assert_equals("min", dh.get_scalar_type(out.dtype), min_, expected)
109+
_axes = normalise_axis(kw.get("axis", None), x.ndim)
110+
assert_keepdimable_shape(
111+
"min", x.shape, _axes, kw.get("keepdims", False), out.shape, **kw
112+
)
113+
scalar_type = dh.get_scalar_type(out.dtype)
114+
for indices, out_idx in zip(axes_ndindex(x.shape, _axes), ah.ndindex(out.shape)):
115+
min_ = scalar_type(out[out_idx])
116+
elements = []
117+
for idx in indices:
118+
s = scalar_type(x[idx])
119+
elements.append(s)
120+
expected = min(elements)
121+
assert_equals("min", dh.get_scalar_type(out.dtype), out_idx, min_, expected)
73122

74123

75124
@given(
76-
x=xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes(min_side=1)),
125+
x=xps.arrays(
126+
dtype=xps.numeric_dtypes(),
127+
shape=hh.shapes(min_side=1),
128+
elements={"allow_nan": False},
129+
),
77130
data=st.data(),
78131
)
79132
def test_max(x, data):
@@ -82,34 +135,27 @@ def test_max(x, data):
82135
out = xp.max(x, **kw)
83136

84137
ph.assert_dtype("max", x.dtype, out.dtype)
85-
86-
f_func = f"max({ph.fmt_kw(kw)})"
87-
88-
# TODO: support axis
89-
if kw.get("axis", None) is None:
90-
keepdims = kw.get("keepdims", False)
91-
if keepdims:
92-
shape = tuple(1 for _ in x.shape)
93-
msg = f"{out.shape=}, should be reduced dimension {shape} [{f_func}]"
94-
assert out.shape == shape, msg
95-
else:
96-
ph.assert_shape("max", out.shape, (), **kw)
97-
98-
# TODO: figure out NaN behaviour
99-
if dh.is_int_dtype(x.dtype) or not xp.any(xp.isnan(x)):
100-
_out = xp.reshape(out, ()) if keepdims else out
101-
scalar_type = dh.get_scalar_type(out.dtype)
102-
elements = []
103-
for idx in ah.ndindex(x.shape):
104-
s = scalar_type(x[idx])
105-
elements.append(s)
106-
max_ = scalar_type(_out)
107-
expected = max(elements)
108-
assert_equals("mean", dh.get_scalar_type(out.dtype), max_, expected)
138+
_axes = normalise_axis(kw.get("axis", None), x.ndim)
139+
assert_keepdimable_shape(
140+
"max", x.shape, _axes, kw.get("keepdims", False), out.shape, **kw
141+
)
142+
scalar_type = dh.get_scalar_type(out.dtype)
143+
for indices, out_idx in zip(axes_ndindex(x.shape, _axes), ah.ndindex(out.shape)):
144+
max_ = scalar_type(out[out_idx])
145+
elements = []
146+
for idx in indices:
147+
s = scalar_type(x[idx])
148+
elements.append(s)
149+
expected = max(elements)
150+
assert_equals("max", dh.get_scalar_type(out.dtype), out_idx, max_, expected)
109151

110152

111153
@given(
112-
x=xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes(min_side=1)),
154+
x=xps.arrays(
155+
dtype=xps.floating_dtypes(),
156+
shape=hh.shapes(min_side=1),
157+
elements={"allow_nan": False},
158+
),
113159
data=st.data(),
114160
)
115161
def test_mean(x, data):
@@ -118,33 +164,26 @@ def test_mean(x, data):
118164
out = xp.mean(x, **kw)
119165

120166
ph.assert_dtype("mean", x.dtype, out.dtype)
121-
122-
f_func = f"mean({ph.fmt_kw(kw)})"
123-
124-
# TODO: support axis
125-
if kw.get("axis", None) is None:
126-
keepdims = kw.get("keepdims", False)
127-
if keepdims:
128-
shape = tuple(1 for _ in x.shape)
129-
msg = f"{out.shape=}, should be reduced dimension {shape} [{f_func}]"
130-
assert out.shape == shape, msg
131-
else:
132-
ph.assert_shape("max", out.shape, (), **kw)
133-
134-
# TODO: figure out NaN behaviour
135-
if not xp.any(xp.isnan(x)):
136-
_out = xp.reshape(out, ()) if keepdims else out
137-
elements = []
138-
for idx in ah.ndindex(x.shape):
139-
s = float(x[idx])
140-
elements.append(s)
141-
mean = float(_out)
142-
expected = sum(elements) / len(elements)
143-
assert_equals("mean", float, mean, expected)
167+
_axes = normalise_axis(kw.get("axis", None), x.ndim)
168+
assert_keepdimable_shape(
169+
"mean", x.shape, _axes, kw.get("keepdims", False), out.shape, **kw
170+
)
171+
for indices, out_idx in zip(axes_ndindex(x.shape, _axes), ah.ndindex(out.shape)):
172+
mean = float(out[out_idx])
173+
elements = []
174+
for idx in indices:
175+
s = float(x[idx])
176+
elements.append(s)
177+
expected = sum(elements) / len(elements)
178+
assert_equals("mean", dh.get_scalar_type(out.dtype), out_idx, mean, expected)
144179

145180

146181
@given(
147-
x=xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes(min_side=1)),
182+
x=xps.arrays(
183+
dtype=xps.numeric_dtypes(),
184+
shape=hh.shapes(min_side=1),
185+
elements={"allow_nan": False},
186+
),
148187
data=st.data(),
149188
)
150189
def test_prod(x, data):
@@ -176,52 +215,37 @@ def test_prod(x, data):
176215
else:
177216
_dtype = dtype
178217
ph.assert_dtype("prod", x.dtype, out.dtype, _dtype)
179-
180-
f_func = f"prod({ph.fmt_kw(kw)})"
181-
182-
# TODO: support axis
183-
if kw.get("axis", None) is None:
184-
keepdims = kw.get("keepdims", False)
185-
if keepdims:
186-
shape = tuple(1 for _ in x.shape)
187-
msg = f"{out.shape=}, should be reduced dimension {shape} [{f_func}]"
188-
assert out.shape == shape, msg
189-
else:
190-
ph.assert_shape("prod", out.shape, (), **kw)
191-
192-
# TODO: figure out NaN behaviour
193-
if dh.is_int_dtype(x.dtype) or not xp.any(xp.isnan(x)):
194-
_out = xp.reshape(out, ()) if keepdims else out
195-
scalar_type = dh.get_scalar_type(out.dtype)
196-
elements = []
197-
for idx in ah.ndindex(x.shape):
198-
s = scalar_type(x[idx])
199-
elements.append(s)
200-
prod = scalar_type(_out)
201-
expected = math.prod(elements)
202-
if dh.is_int_dtype(out.dtype):
203-
m, M = dh.dtype_ranges[out.dtype]
204-
assume(m <= expected <= M)
205-
assert_equals("prod", dh.get_scalar_type(out.dtype), prod, expected)
218+
_axes = normalise_axis(kw.get("axis", None), x.ndim)
219+
assert_keepdimable_shape(
220+
"prod", x.shape, _axes, kw.get("keepdims", False), out.shape, **kw
221+
)
222+
scalar_type = dh.get_scalar_type(out.dtype)
223+
for indices, out_idx in zip(axes_ndindex(x.shape, _axes), ah.ndindex(out.shape)):
224+
prod = scalar_type(out[out_idx])
225+
assume(not math.isinf(prod))
226+
elements = []
227+
for idx in indices:
228+
s = scalar_type(x[idx])
229+
elements.append(s)
230+
expected = math.prod(elements)
231+
if dh.is_int_dtype(out.dtype):
232+
m, M = dh.dtype_ranges[out.dtype]
233+
assume(m <= expected <= M)
234+
assert_equals("prod", dh.get_scalar_type(out.dtype), out_idx, prod, expected)
206235

207236

208237
@given(
209-
x=xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes(min_side=1)).filter(
210-
lambda x: x.size >= 2
211-
),
238+
x=xps.arrays(
239+
dtype=xps.floating_dtypes(),
240+
shape=hh.shapes(min_side=1),
241+
elements={"allow_nan": False},
242+
).filter(lambda x: x.size >= 2),
212243
data=st.data(),
213244
)
214245
def test_std(x, data):
215246
axis = data.draw(axes(x.ndim), label="axis")
216-
if axis is None:
217-
N = x.size
218-
_axes = tuple(range(x.ndim))
219-
else:
220-
_axes = axis if isinstance(axis, tuple) else (axis,)
221-
_axes = tuple(
222-
axis if axis >= 0 else x.ndim + axis for axis in _axes
223-
) # normalise
224-
N = sum(side for axis, side in enumerate(x.shape) if axis not in _axes)
247+
_axes = normalise_axis(axis, x.ndim)
248+
N = sum(side for axis, side in enumerate(x.shape) if axis not in _axes)
225249
correction = data.draw(
226250
st.floats(0.0, N, allow_infinity=False, allow_nan=False) | st.integers(0, N),
227251
label="correction",
@@ -239,13 +263,9 @@ def test_std(x, data):
239263
out = xp.std(x, **kw)
240264

241265
ph.assert_dtype("std", x.dtype, out.dtype)
242-
243-
if keepdims:
244-
shape = tuple(1 if axis in _axes else side for axis, side in enumerate(x.shape))
245-
else:
246-
shape = tuple(side for axis, side in enumerate(x.shape) if axis not in _axes)
247-
ph.assert_shape("std", out.shape, shape, **kw)
248-
266+
assert_keepdimable_shape(
267+
"std", x.shape, _axes, kw.get("keepdims", False), out.shape, **kw
268+
)
249269
# We can't easily test the result(s) as standard deviation methods vary a lot
250270

251271

0 commit comments

Comments
 (0)