10
10
from . import dtype_helpers as dh
11
11
from . import pytest_helpers as ph
12
12
from . import xps
13
- from .typing import Shape , DataType
13
+ from .typing import Shape , DataType , Array
14
14
15
15
16
16
def assert_default_float (func_name : str , dtype : DataType ):
@@ -33,11 +33,7 @@ def assert_default_int(func_name: str, dtype: DataType):
33
33
assert dtype == dh .default_int , msg
34
34
35
35
36
- def assert_kw_dtype (
37
- func_name : str ,
38
- kw_dtype : DataType ,
39
- out_dtype : DataType ,
40
- ):
36
+ def assert_kw_dtype (func_name : str , kw_dtype : DataType , out_dtype : DataType ):
41
37
f_kw_dtype = dh .dtype_to_name [kw_dtype ]
42
38
f_out_dtype = dh .dtype_to_name [out_dtype ]
43
39
msg = (
@@ -47,12 +43,7 @@ def assert_kw_dtype(
47
43
assert out_dtype == kw_dtype , msg
48
44
49
45
50
- def assert_shape (
51
- func_name : str ,
52
- out_shape : Shape ,
53
- expected : Union [int , Shape ],
54
- ** kw ,
55
- ):
46
+ def assert_shape (func_name : str , out_shape : Shape , expected : Union [int , Shape ], ** kw ):
56
47
f_kw = ", " .join (f"{ k } ={ v } " for k , v in kw .items ())
57
48
msg = f"out.shape={ out_shape } , but should be { expected } [{ func_name } ({ f_kw } )]"
58
49
if isinstance (expected , int ):
@@ -61,6 +52,18 @@ def assert_shape(
61
52
62
53
63
54
55
+ def assert_fill (func_name : str , fill : float , dtype : DataType , out : Array , ** kw ):
56
+ f_kw = ", " .join (f"{ k } ={ v } " for k , v in kw .items ())
57
+ msg = (
58
+ f"out not filled with { fill } [{ func_name } ({ f_kw } )]\n "
59
+ f"{ out = } "
60
+ )
61
+ if math .isnan (fill ):
62
+ assert ah .all (ah .isnan (out )), msg
63
+ else :
64
+ assert ah .all (ah .equal (out , ah .asarray (fill , dtype = dtype ))), msg
65
+
66
+
64
67
# Testing xp.arange() requires bounding the start/stop/step arguments to only
65
68
# test argument combinations compliant with the Array API, as well as to not
66
69
# produce arrays with sizes not supproted by an array module.
@@ -234,8 +237,9 @@ def test_eye(n_rows, n_cols, kw):
234
237
)
235
238
236
239
240
+
237
241
@st .composite
238
- def full_fill_values (draw ):
242
+ def full_fill_values (draw ) -> st . SearchStrategy [ float ] :
239
243
kw = draw (st .shared (hh .kwargs (dtype = st .none () | xps .scalar_dtypes ()), key = "full_kw" ))
240
244
dtype = kw .get ("dtype" , None ) or draw (default_safe_dtypes )
241
245
return draw (xps .from_dtype (dtype ))
@@ -266,10 +270,7 @@ def test_full(shape, fill_value, kw):
266
270
else :
267
271
assert_kw_dtype ("full" , kw ["dtype" ], out .dtype )
268
272
assert_shape ("full" , out .shape , shape , shape = shape )
269
- if dh .is_float_dtype (out .dtype ) and math .isnan (fill_value ):
270
- assert ah .all (ah .isnan (out )), "full() array did not equal the fill value"
271
- else :
272
- assert ah .all (ah .equal (out , ah .asarray (fill_value , dtype = dtype ))), "full() array did not equal the fill value"
273
+ assert_fill ("full" , fill_value , dtype , out , fill_value = fill_value )
273
274
274
275
275
276
@st .composite
@@ -291,13 +292,8 @@ def test_full_like(x, fill_value, kw):
291
292
ph .assert_dtype ("full_like" , (x .dtype ,), out .dtype )
292
293
else :
293
294
assert_kw_dtype ("full_like" , kw ["dtype" ], out .dtype )
294
-
295
295
assert_shape ("full_like" , out .shape , x .shape )
296
- if dh .is_float_dtype (dtype ) and math .isnan (fill_value ):
297
- assert ah .all (ah .isnan (out )), "full_like() array did not equal the fill value"
298
- else :
299
- assert ah .all (ah .equal (out , ah .asarray (fill_value , dtype = dtype ))), "full_like() array did not equal the fill value"
300
-
296
+ assert_fill ("full_like" , fill_value , dtype , out , fill_value = fill_value )
301
297
302
298
finite_kw = {"allow_nan" : False , "allow_infinity" : False }
303
299
@@ -364,7 +360,7 @@ def test_linspace(num, dtype, endpoint, data):
364
360
# TODO: array assertions ala test_arange
365
361
366
362
367
- def make_one (dtype ) :
363
+ def make_one (dtype : DataType ) -> Union [ bool , float ] :
368
364
if dtype is None or dh .is_float_dtype (dtype ):
369
365
return 1.0
370
366
elif dh .is_int_dtype (dtype ):
@@ -382,7 +378,7 @@ def test_ones(shape, kw):
382
378
assert_kw_dtype ("ones" , kw ["dtype" ], out .dtype )
383
379
assert_shape ("ones" , out .shape , shape , shape = shape )
384
380
dtype = kw .get ("dtype" , None ) or dh .default_float
385
- assert ah . all ( ah . equal ( out , ah . asarray ( make_one (dtype ), dtype = dtype ))), "ones() array did not equal 1"
381
+ assert_fill ( "ones" , make_one (dtype ), dtype , out )
386
382
387
383
388
384
@given (
@@ -397,10 +393,10 @@ def test_ones_like(x, kw):
397
393
assert_kw_dtype ("ones_like" , kw ["dtype" ], out .dtype )
398
394
assert_shape ("ones_like" , out .shape , x .shape )
399
395
dtype = kw .get ("dtype" , None ) or x .dtype
400
- assert ah . all ( ah . equal ( out , ah . asarray ( make_one (dtype ), dtype = dtype ))), "ones_like() array elements did not equal 1"
396
+ assert_fill ( "ones_like" , make_one (dtype ), dtype , out )
401
397
402
398
403
- def make_zero (dtype ) :
399
+ def make_zero (dtype : DataType ) -> Union [ bool , float ] :
404
400
if dtype is None or dh .is_float_dtype (dtype ):
405
401
return 0.0
406
402
elif dh .is_int_dtype (dtype ):
@@ -418,7 +414,7 @@ def test_zeros(shape, kw):
418
414
assert_kw_dtype ("zeros" , kw ["dtype" ], out .dtype )
419
415
assert_shape ("zeros" , out .shape , shape , shape = shape )
420
416
dtype = kw .get ("dtype" , None ) or dh .default_float
421
- assert ah . all ( ah . equal ( out , ah . asarray ( make_zero (dtype ), dtype = dtype ))), "zeros() array did not equal 0"
417
+ assert_fill ( "zeros" , make_zero (dtype ), dtype , out )
422
418
423
419
424
420
@given (
@@ -433,4 +429,4 @@ def test_zeros_like(x, kw):
433
429
assert_kw_dtype ("zeros_like" , kw ["dtype" ], out .dtype )
434
430
assert_shape ("zeros_like" , out .shape , x .shape )
435
431
dtype = kw .get ("dtype" , None ) or x .dtype
436
- assert ah . all ( ah . equal ( out , ah . asarray ( make_zero (dtype ), dtype = out . dtype ))), "xp.zeros_like() array elements did not ah.all xp.equal 0"
432
+ assert_fill ( "zeros_like" , make_zero (dtype ), dtype , out )
0 commit comments