diff --git a/array_api_tests/test_data_type_functions.py b/array_api_tests/test_data_type_functions.py index 34c40024..49898b6e 100644 --- a/array_api_tests/test_data_type_functions.py +++ b/array_api_tests/test_data_type_functions.py @@ -2,7 +2,7 @@ from typing import Union import pytest -from hypothesis import given +from hypothesis import given, assume from hypothesis import strategies as st from . import _array_module as xp @@ -23,26 +23,43 @@ def float32(n: Union[int, float]) -> float: return struct.unpack("!f", struct.pack("!f", float(n)))[0] +def _float_match_complex(complex_dtype): + return xp.float32 if complex_dtype == xp.complex64 else xp.float64 + + @given( - x_dtype=non_complex_dtypes(), - dtype=non_complex_dtypes(), + x_dtype=hh.all_dtypes, + dtype=hh.all_dtypes, kw=hh.kwargs(copy=st.booleans()), data=st.data(), ) def test_astype(x_dtype, dtype, kw, data): + _complex_dtypes = (xp.complex64, xp.complex128) + if xp.bool in (x_dtype, dtype): elements_strat = hh.from_dtype(x_dtype) else: - m1, M1 = dh.dtype_ranges[x_dtype] - m2, M2 = dh.dtype_ranges[dtype] + if dh.is_int_dtype(x_dtype): cast = int - elif x_dtype == xp.float32: + elif x_dtype in (xp.float32, xp.complex64): cast = float32 else: cast = float + + real_dtype = x_dtype + if x_dtype in _complex_dtypes: + real_dtype = _float_match_complex(x_dtype) + m1, M1 = dh.dtype_ranges[real_dtype] + + real_dtype = dtype + if dtype in _complex_dtypes: + real_dtype = _float_match_complex(x_dtype) + m2, M2 = dh.dtype_ranges[real_dtype] + min_value = cast(max(m1, m2)) max_value = cast(min(M1, M2)) + elements_strat = hh.from_dtype( x_dtype, min_value=min_value, @@ -54,6 +71,11 @@ def test_astype(x_dtype, dtype, kw, data): hh.arrays(dtype=x_dtype, shape=hh.shapes(), elements=elements_strat), label="x" ) + # according to the spec, "Casting a complex floating-point array to a real-valued + # data type should not be permitted." + # https://data-apis.org/array-api/latest/API_specification/generated/array_api.astype.html#astype + assume(not ((x_dtype in _complex_dtypes) and (dtype not in _complex_dtypes))) + out = xp.astype(x, dtype, **kw) ph.assert_kw_dtype("astype", kw_dtype=dtype, out_dtype=out.dtype)