10
10
from . import pytest_helpers as ph
11
11
from . import xps
12
12
13
+ RTOL = 0.05
14
+
13
15
14
16
@given (
15
17
x = xps .arrays (dtype = xps .numeric_dtypes (), shape = hh .shapes (min_side = 1 )),
@@ -37,7 +39,7 @@ def test_min(x, data):
37
39
if keepdims :
38
40
idx = tuple (1 for _ in x .shape )
39
41
msg = f"{ out .shape = } , should be reduced dimension { idx } [{ f_func } ]"
40
- assert out .shape == idx
42
+ assert out .shape == idx , msg
41
43
else :
42
44
ph .assert_shape ("min" , out .shape , (), ** kw )
43
45
@@ -84,7 +86,7 @@ def test_max(x, data):
84
86
if keepdims :
85
87
idx = tuple (1 for _ in x .shape )
86
88
msg = f"{ out .shape = } , should be reduced dimension { idx } [{ f_func } ]"
87
- assert out .shape == idx
89
+ assert out .shape == idx , msg
88
90
else :
89
91
ph .assert_shape ("max" , out .shape , (), ** kw )
90
92
@@ -105,11 +107,47 @@ def test_max(x, data):
105
107
assert max_ == expected , msg
106
108
107
109
108
- # TODO: generate kwargs
109
- @given (xps .arrays (dtype = xps .floating_dtypes (), shape = hh .shapes (min_side = 1 )))
110
- def test_mean (x ):
111
- xp .mean (x )
112
- # TODO
110
+ @given (
111
+ x = xps .arrays (dtype = xps .floating_dtypes (), shape = hh .shapes (min_side = 1 )),
112
+ data = st .data (),
113
+ )
114
+ def test_mean (x , data ):
115
+ axis_strats = [st .none ()]
116
+ if x .shape != ():
117
+ axis_strats .append (
118
+ st .integers (- x .ndim , x .ndim - 1 ) | xps .valid_tuple_axes (x .ndim )
119
+ )
120
+ kw = data .draw (
121
+ hh .kwargs (axis = st .one_of (axis_strats ), keepdims = st .booleans ()), label = "kw"
122
+ )
123
+
124
+ out = xp .mean (x , ** kw )
125
+
126
+ ph .assert_dtype ("mean" , x .dtype , out .dtype )
127
+
128
+ f_func = f"mean({ ph .fmt_kw (kw )} )"
129
+
130
+ # TODO: support axis
131
+ if kw .get ("axis" ) is None :
132
+ keepdims = kw .get ("keepdims" , False )
133
+ if keepdims :
134
+ idx = tuple (1 for _ in x .shape )
135
+ msg = f"{ out .shape = } , should be reduced dimension { idx } [{ f_func } ]"
136
+ assert out .shape == idx , msg
137
+ else :
138
+ ph .assert_shape ("max" , out .shape , (), ** kw )
139
+
140
+ # TODO: figure out NaN behaviour
141
+ if not xp .any (xp .isnan (x )):
142
+ _out = xp .reshape (out , ()) if keepdims else out
143
+ elements = []
144
+ for idx in ah .ndindex (x .shape ):
145
+ s = float (x [idx ])
146
+ elements .append (s )
147
+ mean = float (_out )
148
+ expected = sum (elements ) / len (elements )
149
+ msg = f"out={ mean } , should be roughly { expected } [{ f_func } ]"
150
+ assert math .isclose (mean , expected , rel_tol = RTOL ), msg
113
151
114
152
115
153
# TODO: generate kwargs
0 commit comments