From 0ab903adebf91abc42cc68b3740fe6ff00949db9 Mon Sep 17 00:00:00 2001 From: Christian Bourjau Date: Tue, 30 Jul 2024 15:05:28 +0200 Subject: [PATCH 1/3] Fix way to determine default_complex --- array_api_tests/dtype_helpers.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/array_api_tests/dtype_helpers.py b/array_api_tests/dtype_helpers.py index 6ff8d49c..3af990e0 100644 --- a/array_api_tests/dtype_helpers.py +++ b/array_api_tests/dtype_helpers.py @@ -322,15 +322,16 @@ def accumulation_result_dtype(x_dtype, dtype_kwarg): default_float = xp.asarray(float()).dtype if default_float not in real_float_dtypes: warn(f"inferred default float is {default_float!r}, which is not a float") - if api_version > "2021.12": + + if not (hasattr(xp, "complex32") or hasattr(xp, "complex64")): + default_complex = None + else: default_complex = xp.asarray(complex()).dtype if default_complex not in complex_dtypes: warn( f"inferred default complex is {default_complex!r}, " "which is not a complex" ) - else: - default_complex = None if dtype_nbits[default_int] == 32: default_uint = _name_to_dtype.get("uint32") From c0a7bf2d412d72840c90c819565d26af0a7246c1 Mon Sep 17 00:00:00 2001 From: Christian Bourjau Date: Wed, 31 Jul 2024 14:12:44 +0200 Subject: [PATCH 2/3] Keep version check and use skip_dtypes instead --- array_api_tests/dtype_helpers.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/array_api_tests/dtype_helpers.py b/array_api_tests/dtype_helpers.py index 3af990e0..4729d8f9 100644 --- a/array_api_tests/dtype_helpers.py +++ b/array_api_tests/dtype_helpers.py @@ -322,16 +322,15 @@ def accumulation_result_dtype(x_dtype, dtype_kwarg): default_float = xp.asarray(float()).dtype if default_float not in real_float_dtypes: warn(f"inferred default float is {default_float!r}, which is not a float") - - if not (hasattr(xp, "complex32") or hasattr(xp, "complex64")): - default_complex = None - else: + if api_version > "2021.12" and ({'complex64', 'complex128'} - set(skip_dtypes)): default_complex = xp.asarray(complex()).dtype if default_complex not in complex_dtypes: warn( f"inferred default complex is {default_complex!r}, " "which is not a complex" ) + else: + default_complex = None if dtype_nbits[default_int] == 32: default_uint = _name_to_dtype.get("uint32") From bef0ea392d0b3b19b13c9894cf28aec5bc505fb8 Mon Sep 17 00:00:00 2001 From: Christian Bourjau Date: Mon, 12 Aug 2024 16:25:32 +0200 Subject: [PATCH 3/3] Avoid invalid strategy when sampling from empty sequence --- array_api_tests/hypothesis_helpers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/array_api_tests/hypothesis_helpers.py b/array_api_tests/hypothesis_helpers.py index cc70c13d..54255cb4 100644 --- a/array_api_tests/hypothesis_helpers.py +++ b/array_api_tests/hypothesis_helpers.py @@ -186,11 +186,11 @@ def oneway_broadcastable_shapes(draw) -> OnewayBroadcastableShapes: real_floating_dtypes = sampled_from(dh.real_float_dtypes) numeric_dtypes = sampled_from(dh.numeric_dtypes) # Note: this always returns complex dtypes, even if api_version < 2022.12 -complex_dtypes = sampled_from(dh.complex_dtypes) +complex_dtypes: SearchStrategy[Any] | None = sampled_from(dh.complex_dtypes) if dh.complex_dtypes else None def all_floating_dtypes() -> SearchStrategy[DataType]: strat = floating_dtypes - if api_version >= "2022.12": + if api_version >= "2022.12" and complex_dtypes is not None: strat |= complex_dtypes return strat