Skip to content

Commit 038ae16

Browse files
committed
ENH: test astype with complex inputs
1 parent a3f3f37 commit 038ae16

File tree

1 file changed

+24
-5
lines changed

1 file changed

+24
-5
lines changed

array_api_tests/test_data_type_functions.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,30 +19,49 @@ def non_complex_dtypes():
1919
return xps.boolean_dtypes() | hh.real_dtypes
2020

2121

22+
def numeric_dtypes():
23+
return xps.boolean_dtypes() | hh.real_dtypes | hh.complex_dtypes
24+
25+
2226
def float32(n: Union[int, float]) -> float:
2327
return struct.unpack("!f", struct.pack("!f", float(n)))[0]
2428

2529

30+
def _float_match_complex(complex_dtype):
31+
return xp.float32 if complex_dtype == xp.complex64 else xp.float64
32+
33+
2634
@given(
27-
x_dtype=non_complex_dtypes(),
28-
dtype=non_complex_dtypes(),
35+
x_dtype=numeric_dtypes(),
36+
dtype=numeric_dtypes(),
2937
kw=hh.kwargs(copy=st.booleans()),
3038
data=st.data(),
3139
)
3240
def test_astype(x_dtype, dtype, kw, data):
3341
if xp.bool in (x_dtype, dtype):
3442
elements_strat = hh.from_dtype(x_dtype)
3543
else:
36-
m1, M1 = dh.dtype_ranges[x_dtype]
37-
m2, M2 = dh.dtype_ranges[dtype]
44+
3845
if dh.is_int_dtype(x_dtype):
3946
cast = int
40-
elif x_dtype == xp.float32:
47+
elif x_dtype in (xp.float32, xp.complex64):
4148
cast = float32
4249
else:
4350
cast = float
51+
52+
real_dtype = x_dtype
53+
if x_dtype in (xp.complex64, xp.complex128):
54+
real_dtype = _float_match_complex(x_dtype)
55+
m1, M1 = dh.dtype_ranges[real_dtype]
56+
57+
real_dtype = dtype
58+
if dtype in (xp.complex64, xp.complex128):
59+
real_dtype = _float_match_complex(x_dtype)
60+
m2, M2 = dh.dtype_ranges[real_dtype]
61+
4462
min_value = cast(max(m1, m2))
4563
max_value = cast(min(M1, M2))
64+
4665
elements_strat = hh.from_dtype(
4766
x_dtype,
4867
min_value=min_value,

0 commit comments

Comments
 (0)