1
1
import math
2
- from typing import Union , Any , Tuple , NamedTuple , Iterator
3
2
from itertools import count
3
+ from typing import Any , Iterator , NamedTuple , Tuple , Union
4
4
5
- from hypothesis import assume , given , strategies as st
5
+ from hypothesis import assume , given
6
+ from hypothesis import strategies as st
6
7
7
8
from . import _array_module as xp
8
9
from . import array_helpers as ah
9
- from . import hypothesis_helpers as hh
10
10
from . import dtype_helpers as dh
11
+ from . import hypothesis_helpers as hh
11
12
from . import pytest_helpers as ph
12
13
from . import xps
13
- from .typing import Shape , DataType , Array , Scalar
14
+ from .typing import DataType , Scalar
14
15
15
16
16
17
@st .composite
@@ -28,59 +29,6 @@ def specified_kwargs(draw, *keys_values_defaults: Tuple[str, Any, Any]):
28
29
return kw
29
30
30
31
31
- def assert_default_float (func_name : str , dtype : DataType ):
32
- f_dtype = dh .dtype_to_name [dtype ]
33
- f_default = dh .dtype_to_name [dh .default_float ]
34
- msg = (
35
- f"out.dtype={ f_dtype } , should be default "
36
- f"floating-point dtype { f_default } [{ func_name } ()]"
37
- )
38
- assert dtype == dh .default_float , msg
39
-
40
-
41
- def assert_default_int (func_name : str , dtype : DataType ):
42
- f_dtype = dh .dtype_to_name [dtype ]
43
- f_default = dh .dtype_to_name [dh .default_int ]
44
- msg = (
45
- f"out.dtype={ f_dtype } , should be default "
46
- f"integer dtype { f_default } [{ func_name } ()]"
47
- )
48
- assert dtype == dh .default_int , msg
49
-
50
-
51
- def assert_kw_dtype (func_name : str , kw_dtype : DataType , out_dtype : DataType ):
52
- f_kw_dtype = dh .dtype_to_name [kw_dtype ]
53
- f_out_dtype = dh .dtype_to_name [out_dtype ]
54
- msg = (
55
- f"out.dtype={ f_out_dtype } , but should be { f_kw_dtype } "
56
- f"[{ func_name } (dtype={ f_kw_dtype } )]"
57
- )
58
- assert out_dtype == kw_dtype , msg
59
-
60
-
61
- def assert_shape (
62
- func_name : str , out_shape : Union [int , Shape ], expected : Union [int , Shape ], / , ** kw
63
- ):
64
- if isinstance (out_shape , int ):
65
- out_shape = (out_shape ,)
66
- if isinstance (expected , int ):
67
- expected = (expected ,)
68
- f_kw = ", " .join (f"{ k } ={ v } " for k , v in kw .items ())
69
- msg = f"out.shape={ out_shape } , but should be { expected } [{ func_name } ({ f_kw } )]"
70
- assert out_shape == expected , msg
71
-
72
-
73
- def assert_fill (
74
- func_name : str , fill_value : Scalar , dtype : DataType , out : Array , / , ** kw
75
- ):
76
- f_kw = ", " .join (f"{ k } ={ v } " for k , v in kw .items ())
77
- msg = f"out not filled with { fill_value } [{ func_name } ({ f_kw } )]\n " f"{ out = } "
78
- if math .isnan (fill_value ):
79
- assert ah .all (ah .isnan (out )), msg
80
- else :
81
- assert ah .all (ah .equal (out , ah .asarray (fill_value , dtype = dtype ))), msg
82
-
83
-
84
32
class frange (NamedTuple ):
85
33
start : float
86
34
stop : float
@@ -210,9 +158,9 @@ def test_arange(dtype, data):
210
158
211
159
if dtype is None :
212
160
if all_int :
213
- assert_default_int ("arange" , out .dtype )
161
+ ph . assert_default_int ("arange" , out .dtype )
214
162
else :
215
- assert_default_float ("arange" , out .dtype )
163
+ ph . assert_default_float ("arange" , out .dtype )
216
164
else :
217
165
assert out .dtype == dtype
218
166
assert out .ndim == 1 , f"{ out .ndim = } , but should be 1 [linspace()]"
@@ -253,10 +201,10 @@ def test_arange(dtype, data):
253
201
def test_empty (shape , kw ):
254
202
out = xp .empty (shape , ** kw )
255
203
if kw .get ("dtype" , None ) is None :
256
- assert_default_float ("empty" , out .dtype )
204
+ ph . assert_default_float ("empty" , out .dtype )
257
205
else :
258
- assert_kw_dtype ("empty" , kw ["dtype" ], out .dtype )
259
- assert_shape ("empty" , out .shape , shape , shape = shape )
206
+ ph . assert_kw_dtype ("empty" , kw ["dtype" ], out .dtype )
207
+ ph . assert_shape ("empty" , out .shape , shape , shape = shape )
260
208
261
209
262
210
@given (
@@ -268,8 +216,8 @@ def test_empty_like(x, kw):
268
216
if kw .get ("dtype" , None ) is None :
269
217
ph .assert_dtype ("empty_like" , (x .dtype ,), out .dtype )
270
218
else :
271
- assert_kw_dtype ("empty_like" , kw ["dtype" ], out .dtype )
272
- assert_shape ("empty_like" , out .shape , x .shape )
219
+ ph . assert_kw_dtype ("empty_like" , kw ["dtype" ], out .dtype )
220
+ ph . assert_shape ("empty_like" , out .shape , x .shape )
273
221
274
222
275
223
@given (
@@ -283,11 +231,11 @@ def test_empty_like(x, kw):
283
231
def test_eye (n_rows , n_cols , kw ):
284
232
out = xp .eye (n_rows , n_cols , ** kw )
285
233
if kw .get ("dtype" , None ) is None :
286
- assert_default_float ("eye" , out .dtype )
234
+ ph . assert_default_float ("eye" , out .dtype )
287
235
else :
288
- assert_kw_dtype ("eye" , kw ["dtype" ], out .dtype )
236
+ ph . assert_kw_dtype ("eye" , kw ["dtype" ], out .dtype )
289
237
_n_cols = n_rows if n_cols is None else n_cols
290
- assert_shape ("eye" , out .shape , (n_rows , _n_cols ), n_rows = n_rows , n_cols = n_cols )
238
+ ph . assert_shape ("eye" , out .shape , (n_rows , _n_cols ), n_rows = n_rows , n_cols = n_cols )
291
239
f_func = f"[eye({ n_rows = } , { n_cols = } )]"
292
240
for i in range (n_rows ):
293
241
for j in range (_n_cols ):
@@ -336,13 +284,13 @@ def test_full(shape, fill_value, kw):
336
284
if isinstance (fill_value , bool ):
337
285
pass # TODO
338
286
elif isinstance (fill_value , int ):
339
- assert_default_int ("full" , out .dtype )
287
+ ph . assert_default_int ("full" , out .dtype )
340
288
else :
341
- assert_default_float ("full" , out .dtype )
289
+ ph . assert_default_float ("full" , out .dtype )
342
290
else :
343
- assert_kw_dtype ("full" , kw ["dtype" ], out .dtype )
344
- assert_shape ("full" , out .shape , shape , shape = shape )
345
- assert_fill ("full" , fill_value , dtype , out , fill_value = fill_value )
291
+ ph . assert_kw_dtype ("full" , kw ["dtype" ], out .dtype )
292
+ ph . assert_shape ("full" , out .shape , shape , shape = shape )
293
+ ph . assert_fill ("full" , fill_value , dtype , out , fill_value = fill_value )
346
294
347
295
348
296
@st .composite
@@ -365,9 +313,9 @@ def test_full_like(x, fill_value, kw):
365
313
if kw .get ("dtype" , None ) is None :
366
314
ph .assert_dtype ("full_like" , (x .dtype ,), out .dtype )
367
315
else :
368
- assert_kw_dtype ("full_like" , kw ["dtype" ], out .dtype )
369
- assert_shape ("full_like" , out .shape , x .shape )
370
- assert_fill ("full_like" , fill_value , dtype , out , fill_value = fill_value )
316
+ ph . assert_kw_dtype ("full_like" , kw ["dtype" ], out .dtype )
317
+ ph . assert_shape ("full_like" , out .shape , x .shape )
318
+ ph . assert_fill ("full_like" , fill_value , dtype , out , fill_value = fill_value )
371
319
372
320
373
321
finite_kw = {"allow_nan" : False , "allow_infinity" : False }
@@ -420,7 +368,7 @@ def test_linspace(num, dtype, endpoint, data):
420
368
)
421
369
out = xp .linspace (start , stop , num , ** kw )
422
370
423
- assert_shape ("linspace" , out .shape , num , start = stop , stop = stop , num = num )
371
+ ph . assert_shape ("linspace" , out .shape , num , start = stop , stop = stop , num = num )
424
372
f_func = f"[linspace({ start = } , { stop = } , { num = } )]"
425
373
if num > 0 :
426
374
assert ah .equal (
@@ -452,12 +400,12 @@ def make_one(dtype: DataType) -> Scalar:
452
400
def test_ones (shape , kw ):
453
401
out = xp .ones (shape , ** kw )
454
402
if kw .get ("dtype" , None ) is None :
455
- assert_default_float ("ones" , out .dtype )
403
+ ph . assert_default_float ("ones" , out .dtype )
456
404
else :
457
- assert_kw_dtype ("ones" , kw ["dtype" ], out .dtype )
458
- assert_shape ("ones" , out .shape , shape , shape = shape )
405
+ ph . assert_kw_dtype ("ones" , kw ["dtype" ], out .dtype )
406
+ ph . assert_shape ("ones" , out .shape , shape , shape = shape )
459
407
dtype = kw .get ("dtype" , None ) or dh .default_float
460
- assert_fill ("ones" , make_one (dtype ), dtype , out )
408
+ ph . assert_fill ("ones" , make_one (dtype ), dtype , out )
461
409
462
410
463
411
@given (
@@ -469,10 +417,10 @@ def test_ones_like(x, kw):
469
417
if kw .get ("dtype" , None ) is None :
470
418
ph .assert_dtype ("ones_like" , (x .dtype ,), out .dtype )
471
419
else :
472
- assert_kw_dtype ("ones_like" , kw ["dtype" ], out .dtype )
473
- assert_shape ("ones_like" , out .shape , x .shape )
420
+ ph . assert_kw_dtype ("ones_like" , kw ["dtype" ], out .dtype )
421
+ ph . assert_shape ("ones_like" , out .shape , x .shape )
474
422
dtype = kw .get ("dtype" , None ) or x .dtype
475
- assert_fill ("ones_like" , make_one (dtype ), dtype , out )
423
+ ph . assert_fill ("ones_like" , make_one (dtype ), dtype , out )
476
424
477
425
478
426
def make_zero (dtype : DataType ) -> Scalar :
@@ -488,12 +436,12 @@ def make_zero(dtype: DataType) -> Scalar:
488
436
def test_zeros (shape , kw ):
489
437
out = xp .zeros (shape , ** kw )
490
438
if kw .get ("dtype" , None ) is None :
491
- assert_default_float ("zeros" , out .dtype )
439
+ ph . assert_default_float ("zeros" , out .dtype )
492
440
else :
493
- assert_kw_dtype ("zeros" , kw ["dtype" ], out .dtype )
494
- assert_shape ("zeros" , out .shape , shape , shape = shape )
441
+ ph . assert_kw_dtype ("zeros" , kw ["dtype" ], out .dtype )
442
+ ph . assert_shape ("zeros" , out .shape , shape , shape = shape )
495
443
dtype = kw .get ("dtype" , None ) or dh .default_float
496
- assert_fill ("zeros" , make_zero (dtype ), dtype , out )
444
+ ph . assert_fill ("zeros" , make_zero (dtype ), dtype , out )
497
445
498
446
499
447
@given (
@@ -505,7 +453,7 @@ def test_zeros_like(x, kw):
505
453
if kw .get ("dtype" , None ) is None :
506
454
ph .assert_dtype ("zeros_like" , (x .dtype ,), out .dtype )
507
455
else :
508
- assert_kw_dtype ("zeros_like" , kw ["dtype" ], out .dtype )
509
- assert_shape ("zeros_like" , out .shape , x .shape )
456
+ ph . assert_kw_dtype ("zeros_like" , kw ["dtype" ], out .dtype )
457
+ ph . assert_shape ("zeros_like" , out .shape , x .shape )
510
458
dtype = kw .get ("dtype" , None ) or x .dtype
511
- assert_fill ("zeros_like" , make_zero (dtype ), dtype , out )
459
+ ph . assert_fill ("zeros_like" , make_zero (dtype ), dtype , out )
0 commit comments