@@ -56,10 +56,10 @@ def axes_ndindex(shape: Shape, axes: Tuple[int, ...]) -> Iterator[Tuple[Shape, .
56
56
57
57
def assert_keepdimable_shape (
58
58
func_name : str ,
59
+ out_shape : Shape ,
59
60
in_shape : Shape ,
60
61
axes : Tuple [int , ...],
61
62
keepdims : bool ,
62
- out_shape : Shape ,
63
63
/ ,
64
64
** kw ,
65
65
):
@@ -108,7 +108,7 @@ def test_min(x, data):
108
108
ph .assert_dtype ("min" , x .dtype , out .dtype )
109
109
_axes = normalise_axis (kw .get ("axis" , None ), x .ndim )
110
110
assert_keepdimable_shape (
111
- "min" , x .shape , _axes , kw .get ("keepdims" , False ), out . shape , ** kw
111
+ "min" , out . shape , x .shape , _axes , kw .get ("keepdims" , False ), ** kw
112
112
)
113
113
scalar_type = dh .get_scalar_type (out .dtype )
114
114
for indices , out_idx in zip (axes_ndindex (x .shape , _axes ), ah .ndindex (out .shape )):
@@ -137,7 +137,7 @@ def test_max(x, data):
137
137
ph .assert_dtype ("max" , x .dtype , out .dtype )
138
138
_axes = normalise_axis (kw .get ("axis" , None ), x .ndim )
139
139
assert_keepdimable_shape (
140
- "max" , x .shape , _axes , kw .get ("keepdims" , False ), out . shape , ** kw
140
+ "max" , out . shape , x .shape , _axes , kw .get ("keepdims" , False ), ** kw
141
141
)
142
142
scalar_type = dh .get_scalar_type (out .dtype )
143
143
for indices , out_idx in zip (axes_ndindex (x .shape , _axes ), ah .ndindex (out .shape )):
@@ -166,7 +166,7 @@ def test_mean(x, data):
166
166
ph .assert_dtype ("mean" , x .dtype , out .dtype )
167
167
_axes = normalise_axis (kw .get ("axis" , None ), x .ndim )
168
168
assert_keepdimable_shape (
169
- "mean" , x .shape , _axes , kw .get ("keepdims" , False ), out . shape , ** kw
169
+ "mean" , out . shape , x .shape , _axes , kw .get ("keepdims" , False ), ** kw
170
170
)
171
171
for indices , out_idx in zip (axes_ndindex (x .shape , _axes ), ah .ndindex (out .shape )):
172
172
mean = float (out [out_idx ])
@@ -217,7 +217,7 @@ def test_prod(x, data):
217
217
ph .assert_dtype ("prod" , x .dtype , out .dtype , _dtype )
218
218
_axes = normalise_axis (kw .get ("axis" , None ), x .ndim )
219
219
assert_keepdimable_shape (
220
- "prod" , x .shape , _axes , kw .get ("keepdims" , False ), out . shape , ** kw
220
+ "prod" , out . shape , x .shape , _axes , kw .get ("keepdims" , False ), ** kw
221
221
)
222
222
scalar_type = dh .get_scalar_type (out .dtype )
223
223
for indices , out_idx in zip (axes_ndindex (x .shape , _axes ), ah .ndindex (out .shape )):
@@ -264,20 +264,97 @@ def test_std(x, data):
264
264
265
265
ph .assert_dtype ("std" , x .dtype , out .dtype )
266
266
assert_keepdimable_shape (
267
- "std" , x .shape , _axes , kw .get ("keepdims" , False ), out . shape , ** kw
267
+ "std" , out . shape , x .shape , _axes , kw .get ("keepdims" , False ), ** kw
268
268
)
269
269
# We can't easily test the result(s) as standard deviation methods vary a lot
270
270
271
271
272
- # TODO: generate kwargs
273
- @given (xps .arrays (dtype = xps .numeric_dtypes (), shape = hh .shapes (min_side = 1 )))
274
- def test_sum (x ):
275
- xp .sum (x )
276
- # TODO
272
+ @given (
273
+ x = xps .arrays (
274
+ dtype = xps .floating_dtypes (),
275
+ shape = hh .shapes (min_side = 1 ),
276
+ elements = {"allow_nan" : False },
277
+ ).filter (lambda x : x .size >= 2 ),
278
+ data = st .data (),
279
+ )
280
+ def test_var (x , data ):
281
+ axis = data .draw (axes (x .ndim ), label = "axis" )
282
+ _axes = normalise_axis (axis , x .ndim )
283
+ N = sum (side for axis , side in enumerate (x .shape ) if axis not in _axes )
284
+ correction = data .draw (
285
+ st .floats (0.0 , N , allow_infinity = False , allow_nan = False ) | st .integers (0 , N ),
286
+ label = "correction" ,
287
+ )
288
+ keepdims = data .draw (st .booleans (), label = "keepdims" )
289
+ kw = data .draw (
290
+ hh .specified_kwargs (
291
+ ("axis" , axis , None ),
292
+ ("correction" , correction , 0.0 ),
293
+ ("keepdims" , keepdims , False ),
294
+ ),
295
+ label = "kw" ,
296
+ )
297
+
298
+ out = xp .var (x , ** kw )
299
+
300
+ ph .assert_dtype ("var" , x .dtype , out .dtype )
301
+ assert_keepdimable_shape (
302
+ "var" , out .shape , x .shape , _axes , kw .get ("keepdims" , False ), ** kw
303
+ )
304
+ # We can't easily test the result(s) as variance methods vary a lot
305
+
306
+
307
+ @given (
308
+ x = xps .arrays (
309
+ dtype = xps .numeric_dtypes (),
310
+ shape = hh .shapes (min_side = 1 ),
311
+ elements = {"allow_nan" : False },
312
+ ),
313
+ data = st .data (),
314
+ )
315
+ def test_sum (x , data ):
316
+ kw = data .draw (
317
+ hh .kwargs (
318
+ axis = axes (x .ndim ),
319
+ dtype = st .none () | st .just (x .dtype ), # TODO: all valid dtypes
320
+ keepdims = st .booleans (),
321
+ ),
322
+ label = "kw" ,
323
+ )
277
324
325
+ out = xp .sum (x , ** kw )
278
326
279
- # TODO: generate kwargs
280
- @given (xps .arrays (dtype = xps .floating_dtypes (), shape = hh .shapes (min_side = 1 )))
281
- def test_var (x ):
282
- xp .var (x )
283
- # TODO
327
+ dtype = kw .get ("dtype" , None )
328
+ if dtype is None :
329
+ if dh .is_int_dtype (x .dtype ):
330
+ m , M = dh .dtype_ranges [x .dtype ]
331
+ d_m , d_M = dh .dtype_ranges [dh .default_int ]
332
+ if m < d_m or M > d_M :
333
+ _dtype = x .dtype
334
+ else :
335
+ _dtype = dh .default_int
336
+ else :
337
+ if dh .dtype_nbits [x .dtype ] > dh .dtype_nbits [dh .default_float ]:
338
+ _dtype = x .dtype
339
+ else :
340
+ _dtype = dh .default_float
341
+ else :
342
+ _dtype = dtype
343
+ ph .assert_dtype ("sum" , x .dtype , out .dtype , _dtype )
344
+ _axes = normalise_axis (kw .get ("axis" , None ), x .ndim )
345
+ assert_keepdimable_shape (
346
+ "sum" , out .shape , x .shape , _axes , kw .get ("keepdims" , False ), ** kw
347
+ )
348
+ scalar_type = dh .get_scalar_type (out .dtype )
349
+ for indices , out_idx in zip (axes_ndindex (x .shape , _axes ), ah .ndindex (out .shape )):
350
+ sum_ = scalar_type (out [out_idx ])
351
+ assume (not math .isinf (sum_ ))
352
+ elements = []
353
+ for idx in indices :
354
+ s = scalar_type (x [idx ])
355
+ elements .append (s )
356
+ expected = sum (elements )
357
+ if dh .is_int_dtype (out .dtype ):
358
+ m , M = dh .dtype_ranges [out .dtype ]
359
+ assume (m <= expected <= M )
360
+ assert_equals ("sum" , dh .get_scalar_type (out .dtype ), out_idx , sum_ , expected )
0 commit comments