From 038ae16b6f89403063a659dc0b61d5bb5e4aee9b Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Mon, 18 Nov 2024 20:35:56 +0100 Subject: [PATCH 1/3] ENH: test astype with complex inputs --- array_api_tests/test_data_type_functions.py | 29 +++++++++++++++++---- 1 file changed, 24 insertions(+), 5 deletions(-) diff --git a/array_api_tests/test_data_type_functions.py b/array_api_tests/test_data_type_functions.py index 34c40024..73087ecb 100644 --- a/array_api_tests/test_data_type_functions.py +++ b/array_api_tests/test_data_type_functions.py @@ -19,13 +19,21 @@ def non_complex_dtypes(): return xps.boolean_dtypes() | hh.real_dtypes +def numeric_dtypes(): + return xps.boolean_dtypes() | hh.real_dtypes | hh.complex_dtypes + + def float32(n: Union[int, float]) -> float: return struct.unpack("!f", struct.pack("!f", float(n)))[0] +def _float_match_complex(complex_dtype): + return xp.float32 if complex_dtype == xp.complex64 else xp.float64 + + @given( - x_dtype=non_complex_dtypes(), - dtype=non_complex_dtypes(), + x_dtype=numeric_dtypes(), + dtype=numeric_dtypes(), kw=hh.kwargs(copy=st.booleans()), data=st.data(), ) @@ -33,16 +41,27 @@ def test_astype(x_dtype, dtype, kw, data): if xp.bool in (x_dtype, dtype): elements_strat = hh.from_dtype(x_dtype) else: - m1, M1 = dh.dtype_ranges[x_dtype] - m2, M2 = dh.dtype_ranges[dtype] + if dh.is_int_dtype(x_dtype): cast = int - elif x_dtype == xp.float32: + elif x_dtype in (xp.float32, xp.complex64): cast = float32 else: cast = float + + real_dtype = x_dtype + if x_dtype in (xp.complex64, xp.complex128): + real_dtype = _float_match_complex(x_dtype) + m1, M1 = dh.dtype_ranges[real_dtype] + + real_dtype = dtype + if dtype in (xp.complex64, xp.complex128): + real_dtype = _float_match_complex(x_dtype) + m2, M2 = dh.dtype_ranges[real_dtype] + min_value = cast(max(m1, m2)) max_value = cast(min(M1, M2)) + elements_strat = hh.from_dtype( x_dtype, min_value=min_value, From 653921d67adf2ede680df37ae1aeb945d61176cd Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Sat, 23 Nov 2024 14:10:40 +0200 Subject: [PATCH 2/3] test astype for all_dtypes --- array_api_tests/test_data_type_functions.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/array_api_tests/test_data_type_functions.py b/array_api_tests/test_data_type_functions.py index 73087ecb..1f556169 100644 --- a/array_api_tests/test_data_type_functions.py +++ b/array_api_tests/test_data_type_functions.py @@ -19,10 +19,6 @@ def non_complex_dtypes(): return xps.boolean_dtypes() | hh.real_dtypes -def numeric_dtypes(): - return xps.boolean_dtypes() | hh.real_dtypes | hh.complex_dtypes - - def float32(n: Union[int, float]) -> float: return struct.unpack("!f", struct.pack("!f", float(n)))[0] @@ -32,8 +28,8 @@ def _float_match_complex(complex_dtype): @given( - x_dtype=numeric_dtypes(), - dtype=numeric_dtypes(), + x_dtype=hh.all_dtypes, + dtype=hh.all_dtypes, kw=hh.kwargs(copy=st.booleans()), data=st.data(), ) From 3f739136c1b6c87b5aad3ddffd3d4ff2bca82c95 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Tue, 26 Nov 2024 16:46:12 +0200 Subject: [PATCH 3/3] avoid testing asttype(complex, not complex) --- array_api_tests/test_data_type_functions.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/array_api_tests/test_data_type_functions.py b/array_api_tests/test_data_type_functions.py index 1f556169..49898b6e 100644 --- a/array_api_tests/test_data_type_functions.py +++ b/array_api_tests/test_data_type_functions.py @@ -2,7 +2,7 @@ from typing import Union import pytest -from hypothesis import given +from hypothesis import given, assume from hypothesis import strategies as st from . import _array_module as xp @@ -34,6 +34,8 @@ def _float_match_complex(complex_dtype): data=st.data(), ) def test_astype(x_dtype, dtype, kw, data): + _complex_dtypes = (xp.complex64, xp.complex128) + if xp.bool in (x_dtype, dtype): elements_strat = hh.from_dtype(x_dtype) else: @@ -46,12 +48,12 @@ def test_astype(x_dtype, dtype, kw, data): cast = float real_dtype = x_dtype - if x_dtype in (xp.complex64, xp.complex128): + if x_dtype in _complex_dtypes: real_dtype = _float_match_complex(x_dtype) m1, M1 = dh.dtype_ranges[real_dtype] real_dtype = dtype - if dtype in (xp.complex64, xp.complex128): + if dtype in _complex_dtypes: real_dtype = _float_match_complex(x_dtype) m2, M2 = dh.dtype_ranges[real_dtype] @@ -69,6 +71,11 @@ def test_astype(x_dtype, dtype, kw, data): hh.arrays(dtype=x_dtype, shape=hh.shapes(), elements=elements_strat), label="x" ) + # according to the spec, "Casting a complex floating-point array to a real-valued + # data type should not be permitted." + # https://data-apis.org/array-api/latest/API_specification/generated/array_api.astype.html#astype + assume(not ((x_dtype in _complex_dtypes) and (dtype not in _complex_dtypes))) + out = xp.astype(x, dtype, **kw) ph.assert_kw_dtype("astype", kw_dtype=dtype, out_dtype=out.dtype)