Skip to content

Commit 1b01e11

Browse files
committed
Refactor axes strategies for stat functions
1 parent 6236e0a commit 1b01e11

File tree

1 file changed

+22
-35
lines changed

1 file changed

+22
-35
lines changed

array_api_tests/test_statistical_functions.py

Lines changed: 22 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import math
2+
from typing import Optional, Union
23

34
from hypothesis import assume, given
45
from hypothesis import strategies as st
@@ -9,7 +10,15 @@
910
from . import hypothesis_helpers as hh
1011
from . import pytest_helpers as ph
1112
from . import xps
12-
from .typing import Scalar, ScalarType
13+
from .typing import Scalar, ScalarType, Shape
14+
15+
16+
def axes(ndim: int) -> st.SearchStrategy[Optional[Union[int, Shape]]]:
17+
axes_strats = [st.none()]
18+
if ndim != 0:
19+
axes_strats.append(st.integers(-ndim, ndim - 1))
20+
axes_strats.append(xps.valid_tuple_axes(ndim))
21+
return st.one_of(axes_strats)
1322

1423

1524
def assert_equals(
@@ -32,14 +41,7 @@ def assert_equals(
3241
data=st.data(),
3342
)
3443
def test_min(x, data):
35-
axis_strats = [st.none()]
36-
if x.shape != ():
37-
axis_strats.append(
38-
st.integers(-x.ndim, x.ndim - 1) | xps.valid_tuple_axes(x.ndim)
39-
)
40-
kw = data.draw(
41-
hh.kwargs(axis=st.one_of(axis_strats), keepdims=st.booleans()), label="kw"
42-
)
44+
kw = data.draw(hh.kwargs(axis=axes(x.ndim), keepdims=st.booleans()), label="kw")
4345

4446
out = xp.min(x, **kw)
4547

@@ -75,14 +77,7 @@ def test_min(x, data):
7577
data=st.data(),
7678
)
7779
def test_max(x, data):
78-
axis_strats = [st.none()]
79-
if x.shape != ():
80-
axis_strats.append(
81-
st.integers(-x.ndim, x.ndim - 1) | xps.valid_tuple_axes(x.ndim)
82-
)
83-
kw = data.draw(
84-
hh.kwargs(axis=st.one_of(axis_strats), keepdims=st.booleans()), label="kw"
85-
)
80+
kw = data.draw(hh.kwargs(axis=axes(x.ndim), keepdims=st.booleans()), label="kw")
8681

8782
out = xp.max(x, **kw)
8883

@@ -118,14 +113,7 @@ def test_max(x, data):
118113
data=st.data(),
119114
)
120115
def test_mean(x, data):
121-
axis_strats = [st.none()]
122-
if x.shape != ():
123-
axis_strats.append(
124-
st.integers(-x.ndim, x.ndim - 1) | xps.valid_tuple_axes(x.ndim)
125-
)
126-
kw = data.draw(
127-
hh.kwargs(axis=st.one_of(axis_strats), keepdims=st.booleans()), label="kw"
128-
)
116+
kw = data.draw(hh.kwargs(axis=axes(x.ndim), keepdims=st.booleans()), label="kw")
129117

130118
out = xp.mean(x, **kw)
131119

@@ -160,14 +148,9 @@ def test_mean(x, data):
160148
data=st.data(),
161149
)
162150
def test_prod(x, data):
163-
axis_strats = [st.none()]
164-
if x.shape != ():
165-
axis_strats.append(
166-
st.integers(-x.ndim, x.ndim - 1) | xps.valid_tuple_axes(x.ndim)
167-
)
168151
kw = data.draw(
169152
hh.kwargs(
170-
axis=st.one_of(axis_strats),
153+
axis=axes(x.ndim),
171154
dtype=st.none() | st.just(x.dtype), # TODO: all valid dtypes
172155
keepdims=st.booleans(),
173156
),
@@ -222,10 +205,14 @@ def test_prod(x, data):
222205
assert_equals("prod", dh.get_scalar_type(out.dtype), prod, expected)
223206

224207

225-
# TODO: generate kwargs
226-
@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes(min_side=1)))
227-
def test_std(x):
228-
xp.std(x)
208+
@given(
209+
x=xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes(min_side=1)),
210+
data=st.data(),
211+
)
212+
def test_std(x, data):
213+
kw = data.draw(hh.kwargs(axis=axes(x.ndim), keepdims=st.booleans()), label="kw")
214+
215+
xp.std(x, **kw)
229216
# TODO
230217

231218

0 commit comments

Comments
 (0)