Skip to content

Commit 567686a

Browse files
committed
Sort statistical tests by spec order
1 parent 3966c71 commit 567686a

File tree

1 file changed

+62
-62
lines changed

1 file changed

+62
-62
lines changed

array_api_tests/test_statistical_functions.py

Lines changed: 62 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -105,83 +105,83 @@ def assert_equals(
105105
),
106106
data=st.data(),
107107
)
108-
def test_min(x, data):
108+
def test_max(x, data):
109109
kw = data.draw(hh.kwargs(axis=axes(x.ndim), keepdims=st.booleans()), label="kw")
110110

111-
out = xp.min(x, **kw)
111+
out = xp.max(x, **kw)
112112

113-
ph.assert_dtype("min", x.dtype, out.dtype)
113+
ph.assert_dtype("max", x.dtype, out.dtype)
114114
_axes = normalise_axis(kw.get("axis", None), x.ndim)
115115
assert_keepdimable_shape(
116-
"min", out.shape, x.shape, _axes, kw.get("keepdims", False), **kw
116+
"max", out.shape, x.shape, _axes, kw.get("keepdims", False), **kw
117117
)
118118
scalar_type = dh.get_scalar_type(out.dtype)
119119
for indices, out_idx in zip(axes_ndindex(x.shape, _axes), ah.ndindex(out.shape)):
120-
min_ = scalar_type(out[out_idx])
120+
max_ = scalar_type(out[out_idx])
121121
elements = []
122122
for idx in indices:
123123
s = scalar_type(x[idx])
124124
elements.append(s)
125-
expected = min(elements)
126-
assert_equals("min", dh.get_scalar_type(out.dtype), out_idx, min_, expected)
125+
expected = max(elements)
126+
assert_equals("max", dh.get_scalar_type(out.dtype), out_idx, max_, expected)
127127

128128

129129
@given(
130130
x=xps.arrays(
131-
dtype=xps.numeric_dtypes(),
131+
dtype=xps.floating_dtypes(),
132132
shape=hh.shapes(min_side=1),
133133
elements={"allow_nan": False},
134134
),
135135
data=st.data(),
136136
)
137-
def test_max(x, data):
137+
def test_mean(x, data):
138138
kw = data.draw(hh.kwargs(axis=axes(x.ndim), keepdims=st.booleans()), label="kw")
139139

140-
out = xp.max(x, **kw)
140+
out = xp.mean(x, **kw)
141141

142-
ph.assert_dtype("max", x.dtype, out.dtype)
142+
ph.assert_dtype("mean", x.dtype, out.dtype)
143143
_axes = normalise_axis(kw.get("axis", None), x.ndim)
144144
assert_keepdimable_shape(
145-
"max", out.shape, x.shape, _axes, kw.get("keepdims", False), **kw
145+
"mean", out.shape, x.shape, _axes, kw.get("keepdims", False), **kw
146146
)
147-
scalar_type = dh.get_scalar_type(out.dtype)
148147
for indices, out_idx in zip(axes_ndindex(x.shape, _axes), ah.ndindex(out.shape)):
149-
max_ = scalar_type(out[out_idx])
148+
mean = float(out[out_idx])
149+
assume(not math.isinf(mean)) # mean may become inf due to internal overflows
150150
elements = []
151151
for idx in indices:
152-
s = scalar_type(x[idx])
152+
s = float(x[idx])
153153
elements.append(s)
154-
expected = max(elements)
155-
assert_equals("max", dh.get_scalar_type(out.dtype), out_idx, max_, expected)
154+
expected = sum(elements) / len(elements)
155+
assert_equals("mean", dh.get_scalar_type(out.dtype), out_idx, mean, expected)
156156

157157

158158
@given(
159159
x=xps.arrays(
160-
dtype=xps.floating_dtypes(),
160+
dtype=xps.numeric_dtypes(),
161161
shape=hh.shapes(min_side=1),
162162
elements={"allow_nan": False},
163163
),
164164
data=st.data(),
165165
)
166-
def test_mean(x, data):
166+
def test_min(x, data):
167167
kw = data.draw(hh.kwargs(axis=axes(x.ndim), keepdims=st.booleans()), label="kw")
168168

169-
out = xp.mean(x, **kw)
169+
out = xp.min(x, **kw)
170170

171-
ph.assert_dtype("mean", x.dtype, out.dtype)
171+
ph.assert_dtype("min", x.dtype, out.dtype)
172172
_axes = normalise_axis(kw.get("axis", None), x.ndim)
173173
assert_keepdimable_shape(
174-
"mean", out.shape, x.shape, _axes, kw.get("keepdims", False), **kw
174+
"min", out.shape, x.shape, _axes, kw.get("keepdims", False), **kw
175175
)
176+
scalar_type = dh.get_scalar_type(out.dtype)
176177
for indices, out_idx in zip(axes_ndindex(x.shape, _axes), ah.ndindex(out.shape)):
177-
mean = float(out[out_idx])
178-
assume(not math.isinf(mean)) # mean may become inf due to internal overflows
178+
min_ = scalar_type(out[out_idx])
179179
elements = []
180180
for idx in indices:
181-
s = float(x[idx])
181+
s = scalar_type(x[idx])
182182
elements.append(s)
183-
expected = sum(elements) / len(elements)
184-
assert_equals("mean", dh.get_scalar_type(out.dtype), out_idx, mean, expected)
183+
expected = min(elements)
184+
assert_equals("min", dh.get_scalar_type(out.dtype), out_idx, min_, expected)
185185

186186

187187
@given(
@@ -279,41 +279,6 @@ def test_std(x, data):
279279
# We can't easily test the result(s) as standard deviation methods vary a lot
280280

281281

282-
@given(
283-
x=xps.arrays(
284-
dtype=xps.floating_dtypes(),
285-
shape=hh.shapes(min_side=1),
286-
elements={"allow_nan": False},
287-
).filter(lambda x: x.size >= 2),
288-
data=st.data(),
289-
)
290-
def test_var(x, data):
291-
axis = data.draw(axes(x.ndim), label="axis")
292-
_axes = normalise_axis(axis, x.ndim)
293-
N = sum(side for axis, side in enumerate(x.shape) if axis not in _axes)
294-
correction = data.draw(
295-
st.floats(0.0, N, allow_infinity=False, allow_nan=False) | st.integers(0, N),
296-
label="correction",
297-
)
298-
keepdims = data.draw(st.booleans(), label="keepdims")
299-
kw = data.draw(
300-
hh.specified_kwargs(
301-
("axis", axis, None),
302-
("correction", correction, 0.0),
303-
("keepdims", keepdims, False),
304-
),
305-
label="kw",
306-
)
307-
308-
out = xp.var(x, **kw)
309-
310-
ph.assert_dtype("var", x.dtype, out.dtype)
311-
assert_keepdimable_shape(
312-
"var", out.shape, x.shape, _axes, kw.get("keepdims", False), **kw
313-
)
314-
# We can't easily test the result(s) as variance methods vary a lot
315-
316-
317282
@given(
318283
x=xps.arrays(
319284
dtype=xps.numeric_dtypes(),
@@ -372,3 +337,38 @@ def test_sum(x, data):
372337
m, M = dh.dtype_ranges[out.dtype]
373338
assume(m <= expected <= M)
374339
assert_equals("sum", dh.get_scalar_type(out.dtype), out_idx, sum_, expected)
340+
341+
342+
@given(
343+
x=xps.arrays(
344+
dtype=xps.floating_dtypes(),
345+
shape=hh.shapes(min_side=1),
346+
elements={"allow_nan": False},
347+
).filter(lambda x: x.size >= 2),
348+
data=st.data(),
349+
)
350+
def test_var(x, data):
351+
axis = data.draw(axes(x.ndim), label="axis")
352+
_axes = normalise_axis(axis, x.ndim)
353+
N = sum(side for axis, side in enumerate(x.shape) if axis not in _axes)
354+
correction = data.draw(
355+
st.floats(0.0, N, allow_infinity=False, allow_nan=False) | st.integers(0, N),
356+
label="correction",
357+
)
358+
keepdims = data.draw(st.booleans(), label="keepdims")
359+
kw = data.draw(
360+
hh.specified_kwargs(
361+
("axis", axis, None),
362+
("correction", correction, 0.0),
363+
("keepdims", keepdims, False),
364+
),
365+
label="kw",
366+
)
367+
368+
out = xp.var(x, **kw)
369+
370+
ph.assert_dtype("var", x.dtype, out.dtype)
371+
assert_keepdimable_shape(
372+
"var", out.shape, x.shape, _axes, kw.get("keepdims", False), **kw
373+
)
374+
# We can't easily test the result(s) as variance methods vary a lot

0 commit comments

Comments
 (0)