1
1
import math
2
+ from typing import Optional , Union
2
3
3
4
from hypothesis import assume , given
4
5
from hypothesis import strategies as st
9
10
from . import hypothesis_helpers as hh
10
11
from . import pytest_helpers as ph
11
12
from . import xps
12
- from .typing import Scalar , ScalarType
13
+ from .typing import Scalar , ScalarType , Shape
14
+
15
+
16
+ def axes (ndim : int ) -> st .SearchStrategy [Optional [Union [int , Shape ]]]:
17
+ axes_strats = [st .none ()]
18
+ if ndim != 0 :
19
+ axes_strats .append (st .integers (- ndim , ndim - 1 ))
20
+ axes_strats .append (xps .valid_tuple_axes (ndim ))
21
+ return st .one_of (axes_strats )
13
22
14
23
15
24
def assert_equals (
@@ -32,14 +41,7 @@ def assert_equals(
32
41
data = st .data (),
33
42
)
34
43
def test_min (x , data ):
35
- axis_strats = [st .none ()]
36
- if x .shape != ():
37
- axis_strats .append (
38
- st .integers (- x .ndim , x .ndim - 1 ) | xps .valid_tuple_axes (x .ndim )
39
- )
40
- kw = data .draw (
41
- hh .kwargs (axis = st .one_of (axis_strats ), keepdims = st .booleans ()), label = "kw"
42
- )
44
+ kw = data .draw (hh .kwargs (axis = axes (x .ndim ), keepdims = st .booleans ()), label = "kw" )
43
45
44
46
out = xp .min (x , ** kw )
45
47
@@ -75,14 +77,7 @@ def test_min(x, data):
75
77
data = st .data (),
76
78
)
77
79
def test_max (x , data ):
78
- axis_strats = [st .none ()]
79
- if x .shape != ():
80
- axis_strats .append (
81
- st .integers (- x .ndim , x .ndim - 1 ) | xps .valid_tuple_axes (x .ndim )
82
- )
83
- kw = data .draw (
84
- hh .kwargs (axis = st .one_of (axis_strats ), keepdims = st .booleans ()), label = "kw"
85
- )
80
+ kw = data .draw (hh .kwargs (axis = axes (x .ndim ), keepdims = st .booleans ()), label = "kw" )
86
81
87
82
out = xp .max (x , ** kw )
88
83
@@ -118,14 +113,7 @@ def test_max(x, data):
118
113
data = st .data (),
119
114
)
120
115
def test_mean (x , data ):
121
- axis_strats = [st .none ()]
122
- if x .shape != ():
123
- axis_strats .append (
124
- st .integers (- x .ndim , x .ndim - 1 ) | xps .valid_tuple_axes (x .ndim )
125
- )
126
- kw = data .draw (
127
- hh .kwargs (axis = st .one_of (axis_strats ), keepdims = st .booleans ()), label = "kw"
128
- )
116
+ kw = data .draw (hh .kwargs (axis = axes (x .ndim ), keepdims = st .booleans ()), label = "kw" )
129
117
130
118
out = xp .mean (x , ** kw )
131
119
@@ -160,14 +148,9 @@ def test_mean(x, data):
160
148
data = st .data (),
161
149
)
162
150
def test_prod (x , data ):
163
- axis_strats = [st .none ()]
164
- if x .shape != ():
165
- axis_strats .append (
166
- st .integers (- x .ndim , x .ndim - 1 ) | xps .valid_tuple_axes (x .ndim )
167
- )
168
151
kw = data .draw (
169
152
hh .kwargs (
170
- axis = st . one_of ( axis_strats ),
153
+ axis = axes ( x . ndim ),
171
154
dtype = st .none () | st .just (x .dtype ), # TODO: all valid dtypes
172
155
keepdims = st .booleans (),
173
156
),
@@ -222,10 +205,14 @@ def test_prod(x, data):
222
205
assert_equals ("prod" , dh .get_scalar_type (out .dtype ), prod , expected )
223
206
224
207
225
- # TODO: generate kwargs
226
- @given (xps .arrays (dtype = xps .floating_dtypes (), shape = hh .shapes (min_side = 1 )))
227
- def test_std (x ):
228
- xp .std (x )
208
+ @given (
209
+ x = xps .arrays (dtype = xps .floating_dtypes (), shape = hh .shapes (min_side = 1 )),
210
+ data = st .data (),
211
+ )
212
+ def test_std (x , data ):
213
+ kw = data .draw (hh .kwargs (axis = axes (x .ndim ), keepdims = st .booleans ()), label = "kw" )
214
+
215
+ xp .std (x , ** kw )
229
216
# TODO
230
217
231
218
0 commit comments