1
1
import math
2
2
3
- from hypothesis import given
3
+ from hypothesis import assume , given
4
4
from hypothesis import strategies as st
5
5
6
6
from . import _array_module as xp
9
9
from . import hypothesis_helpers as hh
10
10
from . import pytest_helpers as ph
11
11
from . import xps
12
+ from .typing import Scalar , ScalarType
12
13
13
- RTOL = 0.05
14
+
15
+ def assert_equals (
16
+ func_name : str , type_ : ScalarType , out : Scalar , expected : Scalar , / , ** kw
17
+ ):
18
+ f_func = f"{ func_name } ({ ph .fmt_kw (kw )} )"
19
+ if type_ is bool or type_ is int :
20
+ msg = f"{ out = } , should be { expected } [{ f_func } ]"
21
+ assert out == expected , msg
22
+ elif math .isnan (expected ):
23
+ msg = f"{ out = } , should be { expected } [{ f_func } ]"
24
+ assert math .isnan (out ), msg
25
+ else :
26
+ msg = f"{ out = } , should be roughly { expected } [{ f_func } ]"
27
+ assert math .isclose (out , expected , rel_tol = 0.05 ), msg
14
28
15
29
16
30
@given (
@@ -34,7 +48,7 @@ def test_min(x, data):
34
48
f_func = f"min({ ph .fmt_kw (kw )} )"
35
49
36
50
# TODO: support axis
37
- if kw .get ("axis" ) is None :
51
+ if kw .get ("axis" , None ) is None :
38
52
keepdims = kw .get ("keepdims" , False )
39
53
if keepdims :
40
54
idx = tuple (1 for _ in x .shape )
@@ -53,11 +67,7 @@ def test_min(x, data):
53
67
elements .append (s )
54
68
min_ = scalar_type (_out )
55
69
expected = min (elements )
56
- msg = f"out={ min_ } , should be { expected } [{ f_func } ]"
57
- if math .isnan (min_ ):
58
- assert math .isnan (expected ), msg
59
- else :
60
- assert min_ == expected , msg
70
+ assert_equals ("min" , dh .get_scalar_type (out .dtype ), min_ , expected )
61
71
62
72
63
73
@given (
@@ -81,7 +91,7 @@ def test_max(x, data):
81
91
f_func = f"max({ ph .fmt_kw (kw )} )"
82
92
83
93
# TODO: support axis
84
- if kw .get ("axis" ) is None :
94
+ if kw .get ("axis" , None ) is None :
85
95
keepdims = kw .get ("keepdims" , False )
86
96
if keepdims :
87
97
idx = tuple (1 for _ in x .shape )
@@ -100,11 +110,7 @@ def test_max(x, data):
100
110
elements .append (s )
101
111
max_ = scalar_type (_out )
102
112
expected = max (elements )
103
- msg = f"out={ max_ } , should be { expected } [{ f_func } ]"
104
- if math .isnan (max_ ):
105
- assert math .isnan (expected ), msg
106
- else :
107
- assert max_ == expected , msg
113
+ assert_equals ("mean" , dh .get_scalar_type (out .dtype ), max_ , expected )
108
114
109
115
110
116
@given (
@@ -128,7 +134,7 @@ def test_mean(x, data):
128
134
f_func = f"mean({ ph .fmt_kw (kw )} )"
129
135
130
136
# TODO: support axis
131
- if kw .get ("axis" ) is None :
137
+ if kw .get ("axis" , None ) is None :
132
138
keepdims = kw .get ("keepdims" , False )
133
139
if keepdims :
134
140
idx = tuple (1 for _ in x .shape )
@@ -146,15 +152,74 @@ def test_mean(x, data):
146
152
elements .append (s )
147
153
mean = float (_out )
148
154
expected = sum (elements ) / len (elements )
149
- msg = f"out={ mean } , should be roughly { expected } [{ f_func } ]"
150
- assert math .isclose (mean , expected , rel_tol = RTOL ), msg
155
+ assert_equals ("mean" , float , mean , expected )
151
156
152
157
153
- # TODO: generate kwargs
154
- @given (xps .arrays (dtype = xps .numeric_dtypes (), shape = hh .shapes (min_side = 1 )))
155
- def test_prod (x ):
156
- xp .prod (x )
157
- # TODO
158
+ @given (
159
+ x = xps .arrays (dtype = xps .numeric_dtypes (), shape = hh .shapes (min_side = 1 )),
160
+ data = st .data (),
161
+ )
162
+ def test_prod (x , data ):
163
+ axis_strats = [st .none ()]
164
+ if x .shape != ():
165
+ axis_strats .append (
166
+ st .integers (- x .ndim , x .ndim - 1 ) | xps .valid_tuple_axes (x .ndim )
167
+ )
168
+ kw = data .draw (
169
+ hh .kwargs (
170
+ axis = st .one_of (axis_strats ),
171
+ dtype = st .none () | st .just (x .dtype ), # TODO: all valid dtypes
172
+ keepdims = st .booleans (),
173
+ ),
174
+ label = "kw" ,
175
+ )
176
+
177
+ out = xp .prod (x , ** kw )
178
+
179
+ dtype = kw .get ("dtype" , None )
180
+ if dtype is None :
181
+ if dh .is_int_dtype (x .dtype ):
182
+ m , M = dh .dtype_ranges [x .dtype ]
183
+ d_m , d_M = dh .dtype_ranges [dh .default_int ]
184
+ if m < d_m or M > d_M :
185
+ _dtype = x .dtype
186
+ else :
187
+ _dtype = dh .default_int
188
+ else :
189
+ if dh .dtype_nbits [x .dtype ] > dh .dtype_nbits [dh .default_float ]:
190
+ _dtype = x .dtype
191
+ else :
192
+ _dtype = dh .default_float
193
+ else :
194
+ _dtype = dtype
195
+ ph .assert_dtype ("prod" , x .dtype , out .dtype , _dtype )
196
+
197
+ f_func = f"prod({ ph .fmt_kw (kw )} )"
198
+
199
+ # TODO: support axis
200
+ if kw .get ("axis" , None ) is None :
201
+ keepdims = kw .get ("keepdims" , False )
202
+ if keepdims :
203
+ idx = tuple (1 for _ in x .shape )
204
+ msg = f"{ out .shape = } , should be reduced dimension { idx } [{ f_func } ]"
205
+ assert out .shape == idx , msg
206
+ else :
207
+ ph .assert_shape ("prod" , out .shape , (), ** kw )
208
+
209
+ # TODO: figure out NaN behaviour
210
+ if dh .is_int_dtype (x .dtype ) or not xp .any (xp .isnan (x )):
211
+ _out = xp .reshape (out , ()) if keepdims else out
212
+ scalar_type = dh .get_scalar_type (out .dtype )
213
+ elements = []
214
+ for idx in ah .ndindex (x .shape ):
215
+ s = scalar_type (x [idx ])
216
+ elements .append (s )
217
+ prod = scalar_type (_out )
218
+ expected = math .prod (elements )
219
+ if dh .is_int_dtype (out .dtype ):
220
+ m , M = dh .dtype_ranges [out .dtype ]
221
+ assume (m <= expected <= M )
222
+ assert_equals ("prod" , dh .get_scalar_type (out .dtype ), prod , expected )
158
223
159
224
160
225
# TODO: generate kwargs
0 commit comments