1
1
import math
2
- from typing import Optional , Union
2
+ from itertools import product
3
+ from typing import Iterator , Optional , Tuple , Union
3
4
4
5
from hypothesis import assume , given
5
6
from hypothesis import strategies as st
@@ -21,23 +22,82 @@ def axes(ndim: int) -> st.SearchStrategy[Optional[Union[int, Shape]]]:
21
22
return st .one_of (axes_strats )
22
23
23
24
25
+ def normalise_axis (
26
+ axis : Optional [Union [int , Tuple [int , ...]]], ndim : int
27
+ ) -> Tuple [int , ...]:
28
+ if axis is None :
29
+ return tuple (range (ndim ))
30
+ axes = axis if isinstance (axis , tuple ) else (axis ,)
31
+ axes = tuple (axis if axis >= 0 else ndim + axis for axis in axes )
32
+ return axes
33
+
34
+
35
+ def axes_ndindex (shape : Shape , axes : Tuple [int , ...]) -> Iterator [Tuple [Shape , ...]]:
36
+ base_iterables = []
37
+ axes_iterables = []
38
+ for axis , side in enumerate (shape ):
39
+ if axis in axes :
40
+ base_iterables .append ((None ,))
41
+ axes_iterables .append (range (side ))
42
+ else :
43
+ base_iterables .append (range (side ))
44
+ axes_iterables .append ((None ,))
45
+ for base_idx in product (* base_iterables ):
46
+ indices = []
47
+ for idx in product (* axes_iterables ):
48
+ idx = list (idx )
49
+ for axis , side in enumerate (idx ):
50
+ if axis not in axes :
51
+ idx [axis ] = base_idx [axis ]
52
+ idx = tuple (idx )
53
+ indices .append (idx )
54
+ yield tuple (indices )
55
+
56
+
57
+ def assert_keepdimable_shape (
58
+ func_name : str ,
59
+ in_shape : Shape ,
60
+ axes : Tuple [int , ...],
61
+ keepdims : bool ,
62
+ out_shape : Shape ,
63
+ / ,
64
+ ** kw ,
65
+ ):
66
+ if keepdims :
67
+ shape = tuple (1 if axis in axes else side for axis , side in enumerate (in_shape ))
68
+ else :
69
+ shape = tuple (side for axis , side in enumerate (in_shape ) if axis not in axes )
70
+ ph .assert_shape (func_name , out_shape , shape , ** kw )
71
+
72
+
24
73
def assert_equals (
25
- func_name : str , type_ : ScalarType , out : Scalar , expected : Scalar , / , ** kw
74
+ func_name : str ,
75
+ type_ : ScalarType ,
76
+ idx : Shape ,
77
+ out : Scalar ,
78
+ expected : Scalar ,
79
+ / ,
80
+ ** kw ,
26
81
):
82
+ out_repr = "out" if idx == () else f"out[{ idx } ]"
27
83
f_func = f"{ func_name } ({ ph .fmt_kw (kw )} )"
28
84
if type_ is bool or type_ is int :
29
- msg = f"{ out = } , should be { expected } [{ f_func } ]"
85
+ msg = f"{ out_repr } = { out } , should be { expected } [{ f_func } ]"
30
86
assert out == expected , msg
31
87
elif math .isnan (expected ):
32
- msg = f"{ out = } , should be { expected } [{ f_func } ]"
88
+ msg = f"{ out_repr } = { out } , should be { expected } [{ f_func } ]"
33
89
assert math .isnan (out ), msg
34
90
else :
35
- msg = f"{ out = } , should be roughly { expected } [{ f_func } ]"
91
+ msg = f"{ out_repr } = { out } , should be roughly { expected } [{ f_func } ]"
36
92
assert math .isclose (out , expected , rel_tol = 0.05 ), msg
37
93
38
94
39
95
@given (
40
- x = xps .arrays (dtype = xps .numeric_dtypes (), shape = hh .shapes (min_side = 1 )),
96
+ x = xps .arrays (
97
+ dtype = xps .numeric_dtypes (),
98
+ shape = hh .shapes (min_side = 1 ),
99
+ elements = {"allow_nan" : False },
100
+ ),
41
101
data = st .data (),
42
102
)
43
103
def test_min (x , data ):
@@ -46,34 +106,27 @@ def test_min(x, data):
46
106
out = xp .min (x , ** kw )
47
107
48
108
ph .assert_dtype ("min" , x .dtype , out .dtype )
49
-
50
- f_func = f"min({ ph .fmt_kw (kw )} )"
51
-
52
- # TODO: support axis
53
- if kw .get ("axis" , None ) is None :
54
- keepdims = kw .get ("keepdims" , False )
55
- if keepdims :
56
- shape = tuple (1 for _ in x .shape )
57
- msg = f"{ out .shape = } , should be reduced dimension { shape } [{ f_func } ]"
58
- assert out .shape == shape , msg
59
- else :
60
- ph .assert_shape ("min" , out .shape , (), ** kw )
61
-
62
- # TODO: figure out NaN behaviour
63
- if dh .is_int_dtype (x .dtype ) or not xp .any (xp .isnan (x )):
64
- _out = xp .reshape (out , ()) if keepdims else out
65
- scalar_type = dh .get_scalar_type (out .dtype )
66
- elements = []
67
- for idx in ah .ndindex (x .shape ):
68
- s = scalar_type (x [idx ])
69
- elements .append (s )
70
- min_ = scalar_type (_out )
71
- expected = min (elements )
72
- assert_equals ("min" , dh .get_scalar_type (out .dtype ), min_ , expected )
109
+ _axes = normalise_axis (kw .get ("axis" , None ), x .ndim )
110
+ assert_keepdimable_shape (
111
+ "min" , x .shape , _axes , kw .get ("keepdims" , False ), out .shape , ** kw
112
+ )
113
+ scalar_type = dh .get_scalar_type (out .dtype )
114
+ for indices , out_idx in zip (axes_ndindex (x .shape , _axes ), ah .ndindex (out .shape )):
115
+ min_ = scalar_type (out [out_idx ])
116
+ elements = []
117
+ for idx in indices :
118
+ s = scalar_type (x [idx ])
119
+ elements .append (s )
120
+ expected = min (elements )
121
+ assert_equals ("min" , dh .get_scalar_type (out .dtype ), out_idx , min_ , expected )
73
122
74
123
75
124
@given (
76
- x = xps .arrays (dtype = xps .numeric_dtypes (), shape = hh .shapes (min_side = 1 )),
125
+ x = xps .arrays (
126
+ dtype = xps .numeric_dtypes (),
127
+ shape = hh .shapes (min_side = 1 ),
128
+ elements = {"allow_nan" : False },
129
+ ),
77
130
data = st .data (),
78
131
)
79
132
def test_max (x , data ):
@@ -82,34 +135,27 @@ def test_max(x, data):
82
135
out = xp .max (x , ** kw )
83
136
84
137
ph .assert_dtype ("max" , x .dtype , out .dtype )
85
-
86
- f_func = f"max({ ph .fmt_kw (kw )} )"
87
-
88
- # TODO: support axis
89
- if kw .get ("axis" , None ) is None :
90
- keepdims = kw .get ("keepdims" , False )
91
- if keepdims :
92
- shape = tuple (1 for _ in x .shape )
93
- msg = f"{ out .shape = } , should be reduced dimension { shape } [{ f_func } ]"
94
- assert out .shape == shape , msg
95
- else :
96
- ph .assert_shape ("max" , out .shape , (), ** kw )
97
-
98
- # TODO: figure out NaN behaviour
99
- if dh .is_int_dtype (x .dtype ) or not xp .any (xp .isnan (x )):
100
- _out = xp .reshape (out , ()) if keepdims else out
101
- scalar_type = dh .get_scalar_type (out .dtype )
102
- elements = []
103
- for idx in ah .ndindex (x .shape ):
104
- s = scalar_type (x [idx ])
105
- elements .append (s )
106
- max_ = scalar_type (_out )
107
- expected = max (elements )
108
- assert_equals ("mean" , dh .get_scalar_type (out .dtype ), max_ , expected )
138
+ _axes = normalise_axis (kw .get ("axis" , None ), x .ndim )
139
+ assert_keepdimable_shape (
140
+ "max" , x .shape , _axes , kw .get ("keepdims" , False ), out .shape , ** kw
141
+ )
142
+ scalar_type = dh .get_scalar_type (out .dtype )
143
+ for indices , out_idx in zip (axes_ndindex (x .shape , _axes ), ah .ndindex (out .shape )):
144
+ max_ = scalar_type (out [out_idx ])
145
+ elements = []
146
+ for idx in indices :
147
+ s = scalar_type (x [idx ])
148
+ elements .append (s )
149
+ expected = max (elements )
150
+ assert_equals ("max" , dh .get_scalar_type (out .dtype ), out_idx , max_ , expected )
109
151
110
152
111
153
@given (
112
- x = xps .arrays (dtype = xps .floating_dtypes (), shape = hh .shapes (min_side = 1 )),
154
+ x = xps .arrays (
155
+ dtype = xps .floating_dtypes (),
156
+ shape = hh .shapes (min_side = 1 ),
157
+ elements = {"allow_nan" : False },
158
+ ),
113
159
data = st .data (),
114
160
)
115
161
def test_mean (x , data ):
@@ -118,33 +164,26 @@ def test_mean(x, data):
118
164
out = xp .mean (x , ** kw )
119
165
120
166
ph .assert_dtype ("mean" , x .dtype , out .dtype )
121
-
122
- f_func = f"mean({ ph .fmt_kw (kw )} )"
123
-
124
- # TODO: support axis
125
- if kw .get ("axis" , None ) is None :
126
- keepdims = kw .get ("keepdims" , False )
127
- if keepdims :
128
- shape = tuple (1 for _ in x .shape )
129
- msg = f"{ out .shape = } , should be reduced dimension { shape } [{ f_func } ]"
130
- assert out .shape == shape , msg
131
- else :
132
- ph .assert_shape ("max" , out .shape , (), ** kw )
133
-
134
- # TODO: figure out NaN behaviour
135
- if not xp .any (xp .isnan (x )):
136
- _out = xp .reshape (out , ()) if keepdims else out
137
- elements = []
138
- for idx in ah .ndindex (x .shape ):
139
- s = float (x [idx ])
140
- elements .append (s )
141
- mean = float (_out )
142
- expected = sum (elements ) / len (elements )
143
- assert_equals ("mean" , float , mean , expected )
167
+ _axes = normalise_axis (kw .get ("axis" , None ), x .ndim )
168
+ assert_keepdimable_shape (
169
+ "mean" , x .shape , _axes , kw .get ("keepdims" , False ), out .shape , ** kw
170
+ )
171
+ for indices , out_idx in zip (axes_ndindex (x .shape , _axes ), ah .ndindex (out .shape )):
172
+ mean = float (out [out_idx ])
173
+ elements = []
174
+ for idx in indices :
175
+ s = float (x [idx ])
176
+ elements .append (s )
177
+ expected = sum (elements ) / len (elements )
178
+ assert_equals ("mean" , dh .get_scalar_type (out .dtype ), out_idx , mean , expected )
144
179
145
180
146
181
@given (
147
- x = xps .arrays (dtype = xps .numeric_dtypes (), shape = hh .shapes (min_side = 1 )),
182
+ x = xps .arrays (
183
+ dtype = xps .numeric_dtypes (),
184
+ shape = hh .shapes (min_side = 1 ),
185
+ elements = {"allow_nan" : False },
186
+ ),
148
187
data = st .data (),
149
188
)
150
189
def test_prod (x , data ):
@@ -176,52 +215,37 @@ def test_prod(x, data):
176
215
else :
177
216
_dtype = dtype
178
217
ph .assert_dtype ("prod" , x .dtype , out .dtype , _dtype )
179
-
180
- f_func = f"prod({ ph .fmt_kw (kw )} )"
181
-
182
- # TODO: support axis
183
- if kw .get ("axis" , None ) is None :
184
- keepdims = kw .get ("keepdims" , False )
185
- if keepdims :
186
- shape = tuple (1 for _ in x .shape )
187
- msg = f"{ out .shape = } , should be reduced dimension { shape } [{ f_func } ]"
188
- assert out .shape == shape , msg
189
- else :
190
- ph .assert_shape ("prod" , out .shape , (), ** kw )
191
-
192
- # TODO: figure out NaN behaviour
193
- if dh .is_int_dtype (x .dtype ) or not xp .any (xp .isnan (x )):
194
- _out = xp .reshape (out , ()) if keepdims else out
195
- scalar_type = dh .get_scalar_type (out .dtype )
196
- elements = []
197
- for idx in ah .ndindex (x .shape ):
198
- s = scalar_type (x [idx ])
199
- elements .append (s )
200
- prod = scalar_type (_out )
201
- expected = math .prod (elements )
202
- if dh .is_int_dtype (out .dtype ):
203
- m , M = dh .dtype_ranges [out .dtype ]
204
- assume (m <= expected <= M )
205
- assert_equals ("prod" , dh .get_scalar_type (out .dtype ), prod , expected )
218
+ _axes = normalise_axis (kw .get ("axis" , None ), x .ndim )
219
+ assert_keepdimable_shape (
220
+ "prod" , x .shape , _axes , kw .get ("keepdims" , False ), out .shape , ** kw
221
+ )
222
+ scalar_type = dh .get_scalar_type (out .dtype )
223
+ for indices , out_idx in zip (axes_ndindex (x .shape , _axes ), ah .ndindex (out .shape )):
224
+ prod = scalar_type (out [out_idx ])
225
+ assume (not math .isinf (prod ))
226
+ elements = []
227
+ for idx in indices :
228
+ s = scalar_type (x [idx ])
229
+ elements .append (s )
230
+ expected = math .prod (elements )
231
+ if dh .is_int_dtype (out .dtype ):
232
+ m , M = dh .dtype_ranges [out .dtype ]
233
+ assume (m <= expected <= M )
234
+ assert_equals ("prod" , dh .get_scalar_type (out .dtype ), out_idx , prod , expected )
206
235
207
236
208
237
@given (
209
- x = xps .arrays (dtype = xps .floating_dtypes (), shape = hh .shapes (min_side = 1 )).filter (
210
- lambda x : x .size >= 2
211
- ),
238
+ x = xps .arrays (
239
+ dtype = xps .floating_dtypes (),
240
+ shape = hh .shapes (min_side = 1 ),
241
+ elements = {"allow_nan" : False },
242
+ ).filter (lambda x : x .size >= 2 ),
212
243
data = st .data (),
213
244
)
214
245
def test_std (x , data ):
215
246
axis = data .draw (axes (x .ndim ), label = "axis" )
216
- if axis is None :
217
- N = x .size
218
- _axes = tuple (range (x .ndim ))
219
- else :
220
- _axes = axis if isinstance (axis , tuple ) else (axis ,)
221
- _axes = tuple (
222
- axis if axis >= 0 else x .ndim + axis for axis in _axes
223
- ) # normalise
224
- N = sum (side for axis , side in enumerate (x .shape ) if axis not in _axes )
247
+ _axes = normalise_axis (axis , x .ndim )
248
+ N = sum (side for axis , side in enumerate (x .shape ) if axis not in _axes )
225
249
correction = data .draw (
226
250
st .floats (0.0 , N , allow_infinity = False , allow_nan = False ) | st .integers (0 , N ),
227
251
label = "correction" ,
@@ -239,13 +263,9 @@ def test_std(x, data):
239
263
out = xp .std (x , ** kw )
240
264
241
265
ph .assert_dtype ("std" , x .dtype , out .dtype )
242
-
243
- if keepdims :
244
- shape = tuple (1 if axis in _axes else side for axis , side in enumerate (x .shape ))
245
- else :
246
- shape = tuple (side for axis , side in enumerate (x .shape ) if axis not in _axes )
247
- ph .assert_shape ("std" , out .shape , shape , ** kw )
248
-
266
+ assert_keepdimable_shape (
267
+ "std" , x .shape , _axes , kw .get ("keepdims" , False ), out .shape , ** kw
268
+ )
249
269
# We can't easily test the result(s) as standard deviation methods vary a lot
250
270
251
271
0 commit comments