@@ -19,30 +19,49 @@ def non_complex_dtypes():
19
19
return xps .boolean_dtypes () | hh .real_dtypes
20
20
21
21
22
+ def numeric_dtypes ():
23
+ return xps .boolean_dtypes () | hh .real_dtypes | hh .complex_dtypes
24
+
25
+
22
26
def float32 (n : Union [int , float ]) -> float :
23
27
return struct .unpack ("!f" , struct .pack ("!f" , float (n )))[0 ]
24
28
25
29
30
+ def _float_match_complex (complex_dtype ):
31
+ return xp .float32 if complex_dtype == xp .complex64 else xp .float64
32
+
33
+
26
34
@given (
27
- x_dtype = non_complex_dtypes (),
28
- dtype = non_complex_dtypes (),
35
+ x_dtype = numeric_dtypes (),
36
+ dtype = numeric_dtypes (),
29
37
kw = hh .kwargs (copy = st .booleans ()),
30
38
data = st .data (),
31
39
)
32
40
def test_astype (x_dtype , dtype , kw , data ):
33
41
if xp .bool in (x_dtype , dtype ):
34
42
elements_strat = hh .from_dtype (x_dtype )
35
43
else :
36
- m1 , M1 = dh .dtype_ranges [x_dtype ]
37
- m2 , M2 = dh .dtype_ranges [dtype ]
44
+
38
45
if dh .is_int_dtype (x_dtype ):
39
46
cast = int
40
- elif x_dtype == xp .float32 :
47
+ elif x_dtype in ( xp .float32 , xp . complex64 ) :
41
48
cast = float32
42
49
else :
43
50
cast = float
51
+
52
+ real_dtype = x_dtype
53
+ if x_dtype in (xp .complex64 , xp .complex128 ):
54
+ real_dtype = _float_match_complex (x_dtype )
55
+ m1 , M1 = dh .dtype_ranges [real_dtype ]
56
+
57
+ real_dtype = dtype
58
+ if dtype in (xp .complex64 , xp .complex128 ):
59
+ real_dtype = _float_match_complex (x_dtype )
60
+ m2 , M2 = dh .dtype_ranges [real_dtype ]
61
+
44
62
min_value = cast (max (m1 , m2 ))
45
63
max_value = cast (min (M1 , M2 ))
64
+
46
65
elements_strat = hh .from_dtype (
47
66
x_dtype ,
48
67
min_value = min_value ,
0 commit comments