Skip to content

Commit b41d447

Browse files
committed
Support complex in test_full, complex dtype utilities
1 parent 74101de commit b41d447

File tree

3 files changed

+47
-4
lines changed

3 files changed

+47
-4
lines changed

array_api_tests/dtype_helpers.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from typing import Any, Dict, NamedTuple, Sequence, Tuple, Union
66
from warnings import warn
77

8+
from . import api_version
89
from . import _array_module as xp
910
from ._array_module import _UndefinedStub
1011
from .stubs import name_to_func
@@ -29,6 +30,7 @@
2930
"default_int",
3031
"default_uint",
3132
"default_float",
33+
"default_complex",
3234
"promotion_table",
3335
"dtype_nbits",
3436
"dtype_signed",
@@ -197,6 +199,15 @@ class MinMax(NamedTuple):
197199
default_float = xp.asarray(float()).dtype
198200
if default_float not in float_dtypes:
199201
warn(f"inferred default float is {default_float!r}, which is not a float")
202+
if api_version > "2021.12":
203+
default_complex = xp.asarray(complex()).dtype
204+
if default_complex not in complex_dtypes:
205+
warn(
206+
f"inferred default complex is {default_complex!r}, "
207+
"which is not a complex"
208+
)
209+
else:
210+
default_complex = None
200211
if dtype_nbits[default_int] == 32:
201212
default_uint = xp.uint32
202213
else:

array_api_tests/pytest_helpers.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,23 @@ def assert_default_float(func_name: str, out_dtype: DataType):
170170
assert out_dtype == dh.default_float, msg
171171

172172

173+
def assert_default_complex(func_name: str, out_dtype: DataType):
174+
"""
175+
Assert the output dtype is the default complex, e.g.
176+
177+
>>> out = xp.asarray(4+2j)
178+
>>> assert_default_complex('asarray', out.dtype)
179+
180+
"""
181+
f_dtype = dh.dtype_to_name[out_dtype]
182+
f_default = dh.dtype_to_name[dh.default_complex]
183+
msg = (
184+
f"out.dtype={f_dtype}, should be default "
185+
f"complex dtype {f_default} [{func_name}()]"
186+
)
187+
assert out_dtype == dh.default_complex, msg
188+
189+
173190
def assert_default_int(func_name: str, out_dtype: DataType):
174191
"""
175192
Assert the output dtype is the default int, e.g.

array_api_tests/test_creation_functions.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -369,13 +369,15 @@ def test_eye(n_rows, n_cols, kw):
369369
default_unsafe_dtypes.extend([xp.uint32, xp.int64])
370370
if dh.default_float == xp.float32:
371371
default_unsafe_dtypes.append(xp.float64)
372+
if dh.default_complex == xp.complex64:
373+
default_unsafe_dtypes.append(xp.complex64)
372374
default_safe_dtypes: st.SearchStrategy = xps.scalar_dtypes().filter(
373375
lambda d: d not in default_unsafe_dtypes
374376
)
375377

376378

377379
@st.composite
378-
def full_fill_values(draw) -> st.SearchStrategy[float]:
380+
def full_fill_values(draw) -> st.SearchStrategy[Union[bool, int, float, complex]]:
379381
kw = draw(
380382
st.shared(hh.kwargs(dtype=st.none() | xps.scalar_dtypes()), key="full_kw")
381383
)
@@ -396,15 +398,28 @@ def test_full(shape, fill_value, kw):
396398
dtype = xp.bool
397399
elif isinstance(fill_value, int):
398400
dtype = dh.default_int
399-
else:
401+
elif isinstance(fill_value, float):
400402
dtype = dh.default_float
403+
else:
404+
assert isinstance(fill_value, complex) # sanity check
405+
dtype = dh.default_complex
406+
# Ignore large components so we don't fail like
407+
#
408+
# >>> torch.fill(complex(0.0, 3.402823466385289e+38))
409+
# RuntimeError: value cannot be converted to complex<float> without overflow
410+
#
411+
M = dh.dtype_ranges[dh.dtype_components[dtype]].max
412+
assume(all(abs(c) < math.sqrt(M) for c in [fill_value.real, fill_value.imag]))
401413
if kw.get("dtype", None) is None:
402414
if isinstance(fill_value, bool):
403-
pass # TODO
415+
assert out.dtype == xp.bool, f"{out.dtype=}, but should be bool [full()]"
404416
elif isinstance(fill_value, int):
405417
ph.assert_default_int("full", out.dtype)
406-
else:
418+
elif isinstance(fill_value, float):
407419
ph.assert_default_float("full", out.dtype)
420+
else:
421+
assert isinstance(fill_value, complex) # sanity check
422+
ph.assert_default_complex("full", out.dtype)
408423
else:
409424
ph.assert_kw_dtype("full", kw["dtype"], out.dtype)
410425
ph.assert_shape("full", out.shape, shape, shape=shape)

0 commit comments

Comments
 (0)