diff --git a/array_api_tests/dtype_helpers.py b/array_api_tests/dtype_helpers.py index 6ff8d49c..4729d8f9 100644 --- a/array_api_tests/dtype_helpers.py +++ b/array_api_tests/dtype_helpers.py @@ -322,7 +322,7 @@ def accumulation_result_dtype(x_dtype, dtype_kwarg): default_float = xp.asarray(float()).dtype if default_float not in real_float_dtypes: warn(f"inferred default float is {default_float!r}, which is not a float") - if api_version > "2021.12": + if api_version > "2021.12" and ({'complex64', 'complex128'} - set(skip_dtypes)): default_complex = xp.asarray(complex()).dtype if default_complex not in complex_dtypes: warn( diff --git a/array_api_tests/hypothesis_helpers.py b/array_api_tests/hypothesis_helpers.py index cc70c13d..54255cb4 100644 --- a/array_api_tests/hypothesis_helpers.py +++ b/array_api_tests/hypothesis_helpers.py @@ -186,11 +186,11 @@ def oneway_broadcastable_shapes(draw) -> OnewayBroadcastableShapes: real_floating_dtypes = sampled_from(dh.real_float_dtypes) numeric_dtypes = sampled_from(dh.numeric_dtypes) # Note: this always returns complex dtypes, even if api_version < 2022.12 -complex_dtypes = sampled_from(dh.complex_dtypes) +complex_dtypes: SearchStrategy[Any] | None = sampled_from(dh.complex_dtypes) if dh.complex_dtypes else None def all_floating_dtypes() -> SearchStrategy[DataType]: strat = floating_dtypes - if api_version >= "2022.12": + if api_version >= "2022.12" and complex_dtypes is not None: strat |= complex_dtypes return strat