From 82e6312fff3cc21fa3ae1f1179bb8d287235d195 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Wed, 7 Dec 2022 17:37:55 +0000 Subject: [PATCH 01/23] Rudimentary complex testing for unary elwise functions --- array_api_tests/dtype_helpers.py | 26 ++++++++++-- array_api_tests/hypothesis_helpers.py | 2 +- array_api_tests/pytest_helpers.py | 41 +++++++++++++------ ...est_operators_and_elementwise_functions.py | 41 +++++++++++++------ 4 files changed, 80 insertions(+), 30 deletions(-) diff --git a/array_api_tests/dtype_helpers.py b/array_api_tests/dtype_helpers.py index 1527611c..4afe73d7 100644 --- a/array_api_tests/dtype_helpers.py +++ b/array_api_tests/dtype_helpers.py @@ -15,6 +15,7 @@ "uint_dtypes", "all_int_dtypes", "float_dtypes", + "real_dtypes", "numeric_dtypes", "all_dtypes", "dtype_to_name", @@ -30,6 +31,7 @@ "promotion_table", "dtype_nbits", "dtype_signed", + "dtype_components", "func_in_dtypes", "func_returns_bool", "binary_op_to_symbol", @@ -86,14 +88,19 @@ def __repr__(self): _uint_names = ("uint8", "uint16", "uint32", "uint64") _int_names = ("int8", "int16", "int32", "int64") _float_names = ("float32", "float64") -_dtype_names = ("bool",) + _uint_names + _int_names + _float_names +_real_names = _uint_names + _int_names + _float_names +_complex_names = ("complex64", "complex128") +_numeric_names = _real_names + _complex_names +_dtype_names = ("bool",) + _numeric_names uint_dtypes = tuple(getattr(xp, name) for name in _uint_names) int_dtypes = tuple(getattr(xp, name) for name in _int_names) float_dtypes = tuple(getattr(xp, name) for name in _float_names) all_int_dtypes = uint_dtypes + int_dtypes -numeric_dtypes = all_int_dtypes + float_dtypes +real_dtypes = all_int_dtypes + float_dtypes +complex_dtypes = tuple(getattr(xp, name) for name in _complex_names) +numeric_dtypes = real_dtypes + complex_dtypes all_dtypes = (xp.bool,) + numeric_dtypes bool_and_all_int_dtypes = (xp.bool,) + all_int_dtypes @@ -129,6 +136,8 @@ def get_scalar_type(dtype: DataType) -> ScalarType: return int elif is_float_dtype(dtype): return float + elif dtype in complex_dtypes: + return complex else: return bool @@ -157,7 +166,8 @@ class MinMax(NamedTuple): [(d, 8) for d in [xp.int8, xp.uint8]] + [(d, 16) for d in [xp.int16, xp.uint16]] + [(d, 32) for d in [xp.int32, xp.uint32, xp.float32]] - + [(d, 64) for d in [xp.int64, xp.uint64, xp.float64]] + + [(d, 64) for d in [xp.int64, xp.uint64, xp.float64, xp.complex64]] + + [(xp.complex128, 128)] ) @@ -166,6 +176,11 @@ class MinMax(NamedTuple): ) +dtype_components = EqualityMapping( + [(xp.complex64, xp.float32), (xp.complex128, xp.float64)] +) + + if isinstance(xp.asarray, _UndefinedStub): default_int = xp.int32 default_float = xp.float32 @@ -226,6 +241,11 @@ class MinMax(NamedTuple): ((xp.float32, xp.float32), xp.float32), ((xp.float32, xp.float64), xp.float64), ((xp.float64, xp.float64), xp.float64), + # complex + ((xp.complex64, xp.complex64), xp.complex64), + ((xp.complex64, xp.complex128), xp.complex128), + ((xp.complex128, xp.complex128), xp.complex128), + ] _numeric_promotions += [((d2, d1), res) for (d1, d2), res in _numeric_promotions] _promotion_table = list(set(_numeric_promotions)) diff --git a/array_api_tests/hypothesis_helpers.py b/array_api_tests/hypothesis_helpers.py index 20cc0e03..6634b1fc 100644 --- a/array_api_tests/hypothesis_helpers.py +++ b/array_api_tests/hypothesis_helpers.py @@ -46,7 +46,7 @@ shared_dtypes = shared(dtypes, key="dtype") shared_floating_dtypes = shared(floating_dtypes, key="dtype") -_dtype_categories = [(xp.bool,), dh.uint_dtypes, dh.int_dtypes, dh.float_dtypes] +_dtype_categories = [(xp.bool,), dh.uint_dtypes, dh.int_dtypes, dh.float_dtypes, dh.complex_dtypes] _sorted_dtypes = [d for category in _dtype_categories for d in category] def _dtypes_sorter(dtype_pair: Tuple[DataType, DataType]): diff --git a/array_api_tests/pytest_helpers.py b/array_api_tests/pytest_helpers.py index 051a063f..8ba39731 100644 --- a/array_api_tests/pytest_helpers.py +++ b/array_api_tests/pytest_helpers.py @@ -374,6 +374,21 @@ def assert_fill( assert xp.all(xp.equal(out, xp.asarray(fill_value, dtype=dtype))), msg +def _assert_float_element(at_out: Array, at_expected: Array, msg: str): + if xp.isnan(at_expected): + assert xp.isnan(at_out), msg + elif at_expected == 0.0 or at_expected == -0.0: + scalar_at_expected = float(at_expected) + scalar_at_out = float(at_out) + if is_pos_zero(scalar_at_expected): + assert is_pos_zero(scalar_at_out), msg + else: + assert is_neg_zero(scalar_at_expected) # sanity check + assert is_neg_zero(scalar_at_out), msg + else: + assert at_out == at_expected, msg + + def assert_array_elements( func_name: str, out: Array, expected: Array, /, *, out_repr: str = "out", **kw ): @@ -392,7 +407,17 @@ def assert_array_elements( dh.result_type(out.dtype, expected.dtype) # sanity check assert_shape(func_name, out.shape, expected.shape, **kw) # sanity check f_func = f"[{func_name}({fmt_kw(kw)})]" - if dh.is_float_dtype(out.dtype): + if out.dtype in dh.float_dtypes: + for idx in sh.ndindex(out.shape): + at_out = out[idx] + at_expected = expected[idx] + msg = ( + f"{sh.fmt_idx(out_repr, idx)}={at_out}, should be {at_expected} " + f"{f_func}" + ) + _assert_float_element(at_out, at_expected, msg) + elif out.dtype in dh.complex_dtypes: + assert (out.dtype in dh.complex_dtypes) == (expected.dtype in dh.complex_dtypes) for idx in sh.ndindex(out.shape): at_out = out[idx] at_expected = expected[idx] @@ -400,18 +425,8 @@ def assert_array_elements( f"{sh.fmt_idx(out_repr, idx)}={at_out}, should be {at_expected} " f"{f_func}" ) - if xp.isnan(at_expected): - assert xp.isnan(at_out), msg - elif at_expected == 0.0 or at_expected == -0.0: - scalar_at_expected = float(at_expected) - scalar_at_out = float(at_out) - if is_pos_zero(scalar_at_expected): - assert is_pos_zero(scalar_at_out), msg - else: - assert is_neg_zero(scalar_at_expected) # sanity check - assert is_neg_zero(scalar_at_out), msg - else: - assert at_out == at_expected, msg + _assert_float_element(at_out.real, at_expected.real, msg) + _assert_float_element(at_out.imag, at_expected.imag, msg) else: assert xp.all( out == expected diff --git a/array_api_tests/test_operators_and_elementwise_functions.py b/array_api_tests/test_operators_and_elementwise_functions.py index 967a43a6..774afd89 100644 --- a/array_api_tests/test_operators_and_elementwise_functions.py +++ b/array_api_tests/test_operators_and_elementwise_functions.py @@ -103,6 +103,8 @@ def default_filter(s: Scalar) -> bool: """ if isinstance(s, int): # note bools are ints return True + elif isinstance(s, complex): + return default_filter(s.real) and default_filter(s.imag) else: return math.isfinite(s) and s != 0 @@ -247,7 +249,12 @@ def unary_assert_against_refimpl( in_stype = dh.get_scalar_type(in_.dtype) if res_stype is None: res_stype = in_stype - m, M = dh.dtype_ranges.get(res.dtype, (None, None)) + if res.dtype == xp.bool: + m, M = (None, None) + if res.dtype in dh.complex_dtypes: + m, M = dh.dtype_ranges[dh.dtype_components[res.dtype]] + else: + m, M = dh.dtype_ranges[res.dtype] for idx in sh.ndindex(in_.shape): scalar_i = in_stype(in_[idx]) if not filter_(scalar_i): @@ -257,9 +264,13 @@ def unary_assert_against_refimpl( except Exception: continue if res.dtype != xp.bool: - assert m is not None and M is not None # for mypy - if expected <= m or expected >= M: - continue + if res.dtype in dh.complex_dtypes: + for component in [expected.real, expected.imag]: + if component <= m or expected >= M: + continue + else: + if expected <= m or expected >= M: + continue scalar_o = res_stype(res[idx]) f_i = sh.fmt_idx("x", idx) f_o = sh.fmt_idx("out", idx) @@ -418,8 +429,11 @@ def __repr__(self): def make_unary_params( - elwise_func_name: str, dtypes_strat: st.SearchStrategy[DataType] + elwise_func_name: str, dtypes: Sequence[DataType] ) -> List[Param[UnaryParamContext]]: + if hh.FILTER_UNDEFINED_DTYPES: + dtypes = [d for d in dtypes if not isinstance(d, xp._UndefinedStub)] + dtypes_strat = st.sampled_from(dtypes) strat = xps.arrays(dtype=dtypes_strat, shape=hh.shapes()) func_ctx = UnaryParamContext( func_name=elwise_func_name, func=getattr(xp, elwise_func_name), strat=strat @@ -633,7 +647,7 @@ def binary_param_assert_against_refimpl( ) -@pytest.mark.parametrize("ctx", make_unary_params("abs", xps.numeric_dtypes())) +@pytest.mark.parametrize("ctx", make_unary_params("abs", dh.numeric_dtypes)) @given(data=st.data()) def test_abs(ctx, data): x = data.draw(ctx.strat, label="x") @@ -643,7 +657,10 @@ def test_abs(ctx, data): out = ctx.func(x) - ph.assert_dtype(ctx.func_name, x.dtype, out.dtype) + if x.dtype in dh.complex_dtypes: + assert out.dtype == dh.complex_components[x.dtype] + else: + ph.assert_dtype(ctx.func_name, x.dtype, out.dtype) ph.assert_shape(ctx.func_name, out.shape, x.shape) unary_assert_against_refimpl( ctx.func_name, @@ -783,7 +800,7 @@ def test_bitwise_left_shift(ctx, data): @pytest.mark.parametrize( - "ctx", make_unary_params("bitwise_invert", boolean_and_all_integer_dtypes()) + "ctx", make_unary_params("bitwise_invert", dh.bool_and_all_int_dtypes) ) @given(data=st.data()) def test_bitwise_invert(ctx, data): @@ -1187,9 +1204,7 @@ def test_multiply(ctx, data): # TODO: clarify if uints are acceptable, adjust accordingly -@pytest.mark.parametrize( - "ctx", make_unary_params("negative", xps.integer_dtypes() | xps.floating_dtypes()) -) +@pytest.mark.parametrize("ctx", make_unary_params("negative", dh.numeric_dtypes)) @given(data=st.data()) def test_negative(ctx, data): x = data.draw(ctx.strat, label="x") @@ -1226,7 +1241,7 @@ def test_not_equal(ctx, data): ) -@pytest.mark.parametrize("ctx", make_unary_params("positive", xps.numeric_dtypes())) +@pytest.mark.parametrize("ctx", make_unary_params("positive", dh.numeric_dtypes)) @given(data=st.data()) def test_positive(ctx, data): x = data.draw(ctx.strat, label="x") @@ -1317,7 +1332,7 @@ def test_square(x): ph.assert_dtype("square", x.dtype, out.dtype) ph.assert_shape("square", out.shape, x.shape) unary_assert_against_refimpl( - "square", x, out, lambda s: s ** 2, expr_template="{}²={}" + "square", x, out, lambda s: s**2, expr_template="{}²={}" ) From 7a1e48e2bc93fbbc46df51bf0e017649eaf0b4f9 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Wed, 7 Dec 2022 18:31:02 +0000 Subject: [PATCH 02/23] Define elwise filters only for component dtypes --- .../test_operators_and_elementwise_functions.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/array_api_tests/test_operators_and_elementwise_functions.py b/array_api_tests/test_operators_and_elementwise_functions.py index 774afd89..37b630ad 100644 --- a/array_api_tests/test_operators_and_elementwise_functions.py +++ b/array_api_tests/test_operators_and_elementwise_functions.py @@ -3,6 +3,7 @@ """ import math import operator +from copy import copy from enum import Enum, auto from typing import Callable, List, NamedTuple, Optional, Sequence, TypeVar, Union @@ -103,8 +104,6 @@ def default_filter(s: Scalar) -> bool: """ if isinstance(s, int): # note bools are ints return True - elif isinstance(s, complex): - return default_filter(s.real) and default_filter(s.imag) else: return math.isfinite(s) and s != 0 @@ -255,6 +254,9 @@ def unary_assert_against_refimpl( m, M = dh.dtype_ranges[dh.dtype_components[res.dtype]] else: m, M = dh.dtype_ranges[res.dtype] + if in_.dtype in dh.complex_dtypes: + component_filter = copy(filter_) + filter_ = lambda s: component_filter(s.real) and component_filter(s.imag) for idx in sh.ndindex(in_.shape): scalar_i = in_stype(in_[idx]) if not filter_(scalar_i): @@ -313,6 +315,9 @@ def binary_assert_against_refimpl( if res_stype is None: res_stype = in_stype m, M = dh.dtype_ranges.get(res.dtype, (None, None)) + if left.dtype in dh.complex_dtypes: + component_filter = copy(filter_) + filter_ = lambda s: component_filter(s.real) and component_filter(s.imag) for l_idx, r_idx, o_idx in sh.iter_indices(left.shape, right.shape, res.shape): scalar_l = in_stype(left[l_idx]) scalar_r = in_stype(right[r_idx]) From 5889a7ee1dfb1ad6d23206af954567bb8a126418 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Thu, 8 Dec 2022 12:39:46 +0000 Subject: [PATCH 03/23] Complex testing for all elwise funcs --- array_api_tests/dtype_helpers.py | 4 +- ...est_operators_and_elementwise_functions.py | 149 ++++++++++++------ 2 files changed, 102 insertions(+), 51 deletions(-) diff --git a/array_api_tests/dtype_helpers.py b/array_api_tests/dtype_helpers.py index 4afe73d7..87e7823f 100644 --- a/array_api_tests/dtype_helpers.py +++ b/array_api_tests/dtype_helpers.py @@ -18,8 +18,9 @@ "real_dtypes", "numeric_dtypes", "all_dtypes", - "dtype_to_name", + "all_float_dtypes", "bool_and_all_int_dtypes", + "dtype_to_name", "dtype_to_scalars", "is_int_dtype", "is_float_dtype", @@ -102,6 +103,7 @@ def __repr__(self): complex_dtypes = tuple(getattr(xp, name) for name in _complex_names) numeric_dtypes = real_dtypes + complex_dtypes all_dtypes = (xp.bool,) + numeric_dtypes +all_float_dtypes = float_dtypes + complex_dtypes bool_and_all_int_dtypes = (xp.bool,) + all_int_dtypes diff --git a/array_api_tests/test_operators_and_elementwise_functions.py b/array_api_tests/test_operators_and_elementwise_functions.py index 37b630ad..e66612df 100644 --- a/array_api_tests/test_operators_and_elementwise_functions.py +++ b/array_api_tests/test_operators_and_elementwise_functions.py @@ -34,6 +34,10 @@ def boolean_and_all_integer_dtypes() -> st.SearchStrategy[DataType]: return xps.boolean_dtypes() | all_integer_dtypes() +def all_floating_dtypes() -> st.SearchStrategy[DataType]: + return xps.floating_dtypes() | xps.complex_dtypes() + + class OnewayPromotableDtypes(NamedTuple): input_dtype: DataType result_dtype: DataType @@ -250,7 +254,7 @@ def unary_assert_against_refimpl( res_stype = in_stype if res.dtype == xp.bool: m, M = (None, None) - if res.dtype in dh.complex_dtypes: + elif res.dtype in dh.complex_dtypes: m, M = dh.dtype_ranges[dh.dtype_components[res.dtype]] else: m, M = dh.dtype_ranges[res.dtype] @@ -267,9 +271,10 @@ def unary_assert_against_refimpl( continue if res.dtype != xp.bool: if res.dtype in dh.complex_dtypes: - for component in [expected.real, expected.imag]: - if component <= m or expected >= M: - continue + if expected.real <= m or expected.real >= M: + continue + if expected.imag <= m or expected.imag >= M: + continue else: if expected <= m or expected >= M: continue @@ -277,11 +282,16 @@ def unary_assert_against_refimpl( f_i = sh.fmt_idx("x", idx) f_o = sh.fmt_idx("out", idx) expr = expr_template.format(f_i, expected) - if strict_check == False or dh.is_float_dtype(res.dtype): - assert isclose(scalar_o, expected), ( + if strict_check == False or res.dtype in dh.all_float_dtypes: + msg = ( f"{f_o}={scalar_o}, but should be roughly {expr} [{func_name}()]\n" f"{f_i}={scalar_i}" ) + if res.dtype in dh.complex_dtypes: + assert isclose(scalar_o.real, expected.real), msg + assert isclose(scalar_o.imag, expected.imag), msg + else: + assert isclose(scalar_o, expected), msg else: assert scalar_o == expected, ( f"{f_o}={scalar_o}, but should be {expr} [{func_name}()]\n" @@ -314,7 +324,14 @@ def binary_assert_against_refimpl( in_stype = dh.get_scalar_type(left.dtype) if res_stype is None: res_stype = in_stype - m, M = dh.dtype_ranges.get(res.dtype, (None, None)) + if res_stype is None: + res_stype = in_stype + if res.dtype == xp.bool: + m, M = (None, None) + elif res.dtype in dh.complex_dtypes: + m, M = dh.dtype_ranges[dh.dtype_components[res.dtype]] + else: + m, M = dh.dtype_ranges[res.dtype] if left.dtype in dh.complex_dtypes: component_filter = copy(filter_) filter_ = lambda s: component_filter(s.real) and component_filter(s.imag) @@ -328,19 +345,29 @@ def binary_assert_against_refimpl( except Exception: continue if res.dtype != xp.bool: - assert m is not None and M is not None # for mypy - if expected <= m or expected >= M: - continue + if res.dtype in dh.complex_dtypes: + if expected.real <= m or expected.real >= M: + continue + if expected.imag <= m or expected.imag >= M: + continue + else: + if expected <= m or expected >= M: + continue scalar_o = res_stype(res[o_idx]) f_l = sh.fmt_idx(left_sym, l_idx) f_r = sh.fmt_idx(right_sym, r_idx) f_o = sh.fmt_idx(res_name, o_idx) expr = expr_template.format(f_l, f_r, expected) - if strict_check == False or dh.is_float_dtype(res.dtype): - assert isclose(scalar_o, expected), ( + if strict_check == False or res.dtype in dh.all_float_dtypes: + msg = ( f"{f_o}={scalar_o}, but should be roughly {expr} [{func_name}()]\n" f"{f_l}={scalar_l}, {f_r}={scalar_r}" ) + if res.dtype in dh.complex_dtypes: + assert isclose(scalar_o.real, expected.real), msg + assert isclose(scalar_o.imag, expected.imag), msg + else: + assert isclose(scalar_o, expected), msg else: assert scalar_o == expected, ( f"{f_o}={scalar_o}, but should be {expr} [{func_name}()]\n" @@ -367,33 +394,53 @@ def right_scalar_assert_against_refimpl( See unary_assert_against_refimpl for more information. """ + if left.dtype in dh.complex_dtypes: + component_filter = copy(filter_) + filter_ = lambda s: component_filter(s.real) and component_filter(s.imag) if filter_(right): return # short-circuit here as there will be nothing to test in_stype = dh.get_scalar_type(left.dtype) if res_stype is None: res_stype = in_stype - m, M = dh.dtype_ranges.get(left.dtype, (None, None)) + if res_stype is None: + res_stype = in_stype + if res.dtype == xp.bool: + m, M = (None, None) + elif left.dtype in dh.complex_dtypes: + m, M = dh.dtype_ranges[dh.dtype_components[left.dtype]] + else: + m, M = dh.dtype_ranges[left.dtype] for idx in sh.ndindex(res.shape): scalar_l = in_stype(left[idx]) - if not filter_(scalar_l): + if not (filter_(scalar_l) and filter_(right)): continue try: expected = refimpl(scalar_l, right) except Exception: continue if left.dtype != xp.bool: - assert m is not None and M is not None # for mypy - if expected <= m or expected >= M: - continue + if res.dtype in dh.complex_dtypes: + if expected.real <= m or expected.real >= M: + continue + if expected.imag <= m or expected.imag >= M: + continue + else: + if expected <= m or expected >= M: + continue scalar_o = res_stype(res[idx]) f_l = sh.fmt_idx(left_sym, idx) f_o = sh.fmt_idx(res_name, idx) expr = expr_template.format(f_l, right, expected) - if strict_check == False or dh.is_float_dtype(res.dtype): - assert isclose(scalar_o, expected), ( + if strict_check == False or res.dtype in dh.all_float_dtypes: + msg = ( f"{f_o}={scalar_o}, but should be roughly {expr} [{func_name}()]\n" f"{f_l}={scalar_l}" ) + if res.dtype in dh.complex_dtypes: + assert isclose(scalar_o.real, expected.real), msg + assert isclose(scalar_o.imag, expected.imag), msg + else: + assert isclose(scalar_o, expected), msg else: assert scalar_o == expected, ( f"{f_o}={scalar_o}, but should be {expr} [{func_name}()]\n" @@ -663,7 +710,7 @@ def test_abs(ctx, data): out = ctx.func(x) if x.dtype in dh.complex_dtypes: - assert out.dtype == dh.complex_components[x.dtype] + assert out.dtype == dh.dtype_components[x.dtype] else: ph.assert_dtype(ctx.func_name, x.dtype, out.dtype) ph.assert_shape(ctx.func_name, out.shape, x.shape) @@ -672,6 +719,7 @@ def test_abs(ctx, data): x, out, abs, # type: ignore + res_stype=float if x.dtype in dh.complex_dtypes else None, expr_template="abs({})={}", filter_=lambda s: ( s == float("infinity") or (math.isfinite(s) and not ph.is_neg_zero(s)) @@ -679,7 +727,7 @@ def test_abs(ctx, data): ) -@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) +@given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes())) def test_acos(x): out = xp.acos(x) ph.assert_dtype("acos", x.dtype, out.dtype) @@ -689,7 +737,7 @@ def test_acos(x): ) -@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) +@given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes())) def test_acosh(x): out = xp.acosh(x) ph.assert_dtype("acosh", x.dtype, out.dtype) @@ -715,7 +763,7 @@ def test_add(ctx, data): binary_param_assert_against_refimpl(ctx, left, right, res, "+", operator.add) -@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) +@given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes())) def test_asin(x): out = xp.asin(x) ph.assert_dtype("asin", x.dtype, out.dtype) @@ -725,7 +773,7 @@ def test_asin(x): ) -@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) +@given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes())) def test_asinh(x): out = xp.asinh(x) ph.assert_dtype("asinh", x.dtype, out.dtype) @@ -733,7 +781,7 @@ def test_asinh(x): unary_assert_against_refimpl("asinh", x, out, math.asinh) -@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) +@given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes())) def test_atan(x): out = xp.atan(x) ph.assert_dtype("atan", x.dtype, out.dtype) @@ -749,7 +797,7 @@ def test_atan2(x1, x2): binary_assert_against_refimpl("atan2", x1, x2, out, math.atan2) -@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) +@given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes())) def test_atanh(x): out = xp.atanh(x) ph.assert_dtype("atanh", x.dtype, out.dtype) @@ -881,7 +929,7 @@ def test_bitwise_xor(ctx, data): binary_param_assert_against_refimpl(ctx, left, right, res, "^", refimpl) -@given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes())) +@given(xps.arrays(dtype=xps.real_dtypes(), shape=hh.shapes())) def test_ceil(x): out = xp.ceil(x) ph.assert_dtype("ceil", x.dtype, out.dtype) @@ -889,7 +937,7 @@ def test_ceil(x): unary_assert_against_refimpl("ceil", x, out, math.ceil, strict_check=True) -@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) +@given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes())) def test_cos(x): out = xp.cos(x) ph.assert_dtype("cos", x.dtype, out.dtype) @@ -897,7 +945,7 @@ def test_cos(x): unary_assert_against_refimpl("cos", x, out, math.cos) -@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) +@given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes())) def test_cosh(x): out = xp.cosh(x) ph.assert_dtype("cosh", x.dtype, out.dtype) @@ -905,7 +953,7 @@ def test_cosh(x): unary_assert_against_refimpl("cosh", x, out, math.cosh) -@pytest.mark.parametrize("ctx", make_binary_params("divide", dh.float_dtypes)) +@pytest.mark.parametrize("ctx", make_binary_params("divide", dh.all_float_dtypes)) @given(data=st.data()) def test_divide(ctx, data): left = data.draw(ctx.left_strat, label=ctx.left_sym) @@ -956,7 +1004,7 @@ def test_equal(ctx, data): ) -@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) +@given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes())) def test_exp(x): out = xp.exp(x) ph.assert_dtype("exp", x.dtype, out.dtype) @@ -964,7 +1012,7 @@ def test_exp(x): unary_assert_against_refimpl("exp", x, out, math.exp) -@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) +@given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes())) def test_expm1(x): out = xp.expm1(x) ph.assert_dtype("expm1", x.dtype, out.dtype) @@ -972,7 +1020,7 @@ def test_expm1(x): unary_assert_against_refimpl("expm1", x, out, math.expm1) -@given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes())) +@given(xps.arrays(dtype=xps.real_dtypes(), shape=hh.shapes())) def test_floor(x): out = xp.floor(x) ph.assert_dtype("floor", x.dtype, out.dtype) @@ -980,7 +1028,7 @@ def test_floor(x): unary_assert_against_refimpl("floor", x, out, math.floor, strict_check=True) -@pytest.mark.parametrize("ctx", make_binary_params("floor_divide", dh.numeric_dtypes)) +@pytest.mark.parametrize("ctx", make_binary_params("floor_divide", dh.real_dtypes)) @given(data=st.data()) def test_floor_divide(ctx, data): left = data.draw( @@ -999,7 +1047,7 @@ def test_floor_divide(ctx, data): binary_param_assert_against_refimpl(ctx, left, right, res, "//", operator.floordiv) -@pytest.mark.parametrize("ctx", make_binary_params("greater", dh.numeric_dtypes)) +@pytest.mark.parametrize("ctx", make_binary_params("greater", dh.real_dtypes)) @given(data=st.data()) def test_greater(ctx, data): left = data.draw(ctx.left_strat, label=ctx.left_sym) @@ -1019,7 +1067,7 @@ def test_greater(ctx, data): ) -@pytest.mark.parametrize("ctx", make_binary_params("greater_equal", dh.numeric_dtypes)) +@pytest.mark.parametrize("ctx", make_binary_params("greater_equal", dh.real_dtypes)) @given(data=st.data()) def test_greater_equal(ctx, data): left = data.draw(ctx.left_strat, label=ctx.left_sym) @@ -1063,7 +1111,7 @@ def test_isnan(x): unary_assert_against_refimpl("isnan", x, out, math.isnan, res_stype=bool) -@pytest.mark.parametrize("ctx", make_binary_params("less", dh.numeric_dtypes)) +@pytest.mark.parametrize("ctx", make_binary_params("less", dh.real_dtypes)) @given(data=st.data()) def test_less(ctx, data): left = data.draw(ctx.left_strat, label=ctx.left_sym) @@ -1083,7 +1131,7 @@ def test_less(ctx, data): ) -@pytest.mark.parametrize("ctx", make_binary_params("less_equal", dh.numeric_dtypes)) +@pytest.mark.parametrize("ctx", make_binary_params("less_equal", dh.real_dtypes)) @given(data=st.data()) def test_less_equal(ctx, data): left = data.draw(ctx.left_strat, label=ctx.left_sym) @@ -1103,7 +1151,7 @@ def test_less_equal(ctx, data): ) -@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) +@given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes())) def test_log(x): out = xp.log(x) ph.assert_dtype("log", x.dtype, out.dtype) @@ -1113,7 +1161,7 @@ def test_log(x): ) -@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) +@given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes())) def test_log1p(x): out = xp.log1p(x) ph.assert_dtype("log1p", x.dtype, out.dtype) @@ -1123,7 +1171,7 @@ def test_log1p(x): ) -@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) +@given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes())) def test_log2(x): out = xp.log2(x) ph.assert_dtype("log2", x.dtype, out.dtype) @@ -1133,7 +1181,7 @@ def test_log2(x): ) -@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) +@given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes())) def test_log10(x): out = xp.log10(x) ph.assert_dtype("log10", x.dtype, out.dtype) @@ -1280,7 +1328,7 @@ def test_pow(ctx, data): # Values testing pow is too finicky -@pytest.mark.parametrize("ctx", make_binary_params("remainder", dh.numeric_dtypes)) +@pytest.mark.parametrize("ctx", make_binary_params("remainder", dh.real_dtypes)) @given(data=st.data()) def test_remainder(ctx, data): left = data.draw(ctx.left_strat, label=ctx.left_sym) @@ -1305,7 +1353,8 @@ def test_round(x): unary_assert_against_refimpl("round", x, out, round, strict_check=True) -@given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes(), elements=finite_kw)) +# TODO: https://github.com/data-apis/array-api/issues/545 +@given(xps.arrays(dtype=xps.real_dtypes(), shape=hh.shapes(), elements=finite_kw)) def test_sign(x): out = xp.sign(x) ph.assert_dtype("sign", x.dtype, out.dtype) @@ -1315,7 +1364,7 @@ def test_sign(x): ) -@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) +@given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes())) def test_sin(x): out = xp.sin(x) ph.assert_dtype("sin", x.dtype, out.dtype) @@ -1323,7 +1372,7 @@ def test_sin(x): unary_assert_against_refimpl("sin", x, out, math.sin) -@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) +@given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes())) def test_sinh(x): out = xp.sinh(x) ph.assert_dtype("sinh", x.dtype, out.dtype) @@ -1341,7 +1390,7 @@ def test_square(x): ) -@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) +@given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes())) def test_sqrt(x): out = xp.sqrt(x) ph.assert_dtype("sqrt", x.dtype, out.dtype) @@ -1367,7 +1416,7 @@ def test_subtract(ctx, data): binary_param_assert_against_refimpl(ctx, left, right, res, "-", operator.sub) -@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) +@given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes())) def test_tan(x): out = xp.tan(x) ph.assert_dtype("tan", x.dtype, out.dtype) @@ -1375,7 +1424,7 @@ def test_tan(x): unary_assert_against_refimpl("tan", x, out, math.tan) -@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) +@given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes())) def test_tanh(x): out = xp.tanh(x) ph.assert_dtype("tanh", x.dtype, out.dtype) @@ -1383,7 +1432,7 @@ def test_tanh(x): unary_assert_against_refimpl("tanh", x, out, math.tanh) -@given(xps.arrays(dtype=hh.numeric_dtypes, shape=xps.array_shapes())) +@given(xps.arrays(dtype=xps.real_dtypes(), shape=xps.array_shapes())) def test_trunc(x): out = xp.trunc(x) ph.assert_dtype("trunc", x.dtype, out.dtype) From ff865bcdbbacf4729a581a0f1bfeff8612533484 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Sat, 10 Dec 2022 14:10:49 +0000 Subject: [PATCH 04/23] TODOs for `test_divide` --- array_api_tests/test_operators_and_elementwise_functions.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/array_api_tests/test_operators_and_elementwise_functions.py b/array_api_tests/test_operators_and_elementwise_functions.py index e66612df..0d156938 100644 --- a/array_api_tests/test_operators_and_elementwise_functions.py +++ b/array_api_tests/test_operators_and_elementwise_functions.py @@ -959,12 +959,14 @@ def test_divide(ctx, data): left = data.draw(ctx.left_strat, label=ctx.left_sym) right = data.draw(ctx.right_strat, label=ctx.right_sym) if ctx.right_is_scalar: - assume + assume # TODO: assume what? res = ctx.func(left, right) binary_param_assert_dtype(ctx, left, right, res) binary_param_assert_shape(ctx, left, right, res) + if res.dtype in dh.complex_dtypes: + return # TOOD: handle complex division binary_param_assert_against_refimpl( ctx, left, From 9f03c9219ecef2ae6a320825a5324a8ae021a9db Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Sun, 11 Dec 2022 14:07:36 +0000 Subject: [PATCH 05/23] Loose assertion of infinities to very large floats --- ...est_operators_and_elementwise_functions.py | 35 ++++++++++++------- 1 file changed, 23 insertions(+), 12 deletions(-) diff --git a/array_api_tests/test_operators_and_elementwise_functions.py b/array_api_tests/test_operators_and_elementwise_functions.py index 0d156938..978a29cc 100644 --- a/array_api_tests/test_operators_and_elementwise_functions.py +++ b/array_api_tests/test_operators_and_elementwise_functions.py @@ -90,14 +90,25 @@ def mock_int_dtype(n: int, dtype: DataType) -> int: return n -def isclose(a: float, b: float, *, rel_tol: float = 0.25, abs_tol: float = 1) -> bool: +def isclose( + a: float, + b: float, + M: float, + *, + rel_tol: float = 0.25, + abs_tol: float = 1, +) -> bool: """Wraps math.isclose with very generous defaults. This is useful for many floating-point operations where the spec does not make accuracy requirements. """ - if not (math.isfinite(a) and math.isfinite(b)): - raise ValueError(f"{a=} and {b=}, but input must be finite") + if math.isnan(a) or math.isnan(b): + raise ValueError(f"{a=} and {b=}, but input must be non-NaN") + if math.isinf(a): + return math.isinf(b) or abs(b) > math.log(M) + elif math.isinf(b): + return math.isinf(a) or abs(a) > math.log(M) return math.isclose(a, b, rel_tol=rel_tol, abs_tol=abs_tol) @@ -288,10 +299,10 @@ def unary_assert_against_refimpl( f"{f_i}={scalar_i}" ) if res.dtype in dh.complex_dtypes: - assert isclose(scalar_o.real, expected.real), msg - assert isclose(scalar_o.imag, expected.imag), msg + assert isclose(scalar_o.real, expected.real, M), msg + assert isclose(scalar_o.imag, expected.imag, M), msg else: - assert isclose(scalar_o, expected), msg + assert isclose(scalar_o, expected, M), msg else: assert scalar_o == expected, ( f"{f_o}={scalar_o}, but should be {expr} [{func_name}()]\n" @@ -364,10 +375,10 @@ def binary_assert_against_refimpl( f"{f_l}={scalar_l}, {f_r}={scalar_r}" ) if res.dtype in dh.complex_dtypes: - assert isclose(scalar_o.real, expected.real), msg - assert isclose(scalar_o.imag, expected.imag), msg + assert isclose(scalar_o.real, expected.real, M), msg + assert isclose(scalar_o.imag, expected.imag, M), msg else: - assert isclose(scalar_o, expected), msg + assert isclose(scalar_o, expected, M), msg else: assert scalar_o == expected, ( f"{f_o}={scalar_o}, but should be {expr} [{func_name}()]\n" @@ -437,10 +448,10 @@ def right_scalar_assert_against_refimpl( f"{f_l}={scalar_l}" ) if res.dtype in dh.complex_dtypes: - assert isclose(scalar_o.real, expected.real), msg - assert isclose(scalar_o.imag, expected.imag), msg + assert isclose(scalar_o.real, expected.real, M), msg + assert isclose(scalar_o.imag, expected.imag, M), msg else: - assert isclose(scalar_o, expected), msg + assert isclose(scalar_o, expected, M), msg else: assert scalar_o == expected, ( f"{f_o}={scalar_o}, but should be {expr} [{func_name}()]\n" From 1ba2efdabde15a917e6cf54366405e454d1e4a5a Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Mon, 12 Dec 2022 13:03:53 +0000 Subject: [PATCH 06/23] Complex-related updates * Update test_sign for complex inputs * Use result dtype for res_type inference in assert-against-refimpl utils * Test `xp.real()` and `xp.imag()` --- ...est_operators_and_elementwise_functions.py | 25 +++++++++++++++---- 1 file changed, 20 insertions(+), 5 deletions(-) diff --git a/array_api_tests/test_operators_and_elementwise_functions.py b/array_api_tests/test_operators_and_elementwise_functions.py index 978a29cc..a5cdbf9e 100644 --- a/array_api_tests/test_operators_and_elementwise_functions.py +++ b/array_api_tests/test_operators_and_elementwise_functions.py @@ -262,7 +262,7 @@ def unary_assert_against_refimpl( expr_template = func_name + "({})={}" in_stype = dh.get_scalar_type(in_.dtype) if res_stype is None: - res_stype = in_stype + res_stype = dh.get_scalar_type(res.dtype) if res.dtype == xp.bool: m, M = (None, None) elif res.dtype in dh.complex_dtypes: @@ -334,7 +334,7 @@ def binary_assert_against_refimpl( expr_template = func_name + "({}, {})={}" in_stype = dh.get_scalar_type(left.dtype) if res_stype is None: - res_stype = in_stype + res_stype = dh.get_scalar_type(left.dtype) if res_stype is None: res_stype = in_stype if res.dtype == xp.bool: @@ -412,7 +412,7 @@ def right_scalar_assert_against_refimpl( return # short-circuit here as there will be nothing to test in_stype = dh.get_scalar_type(left.dtype) if res_stype is None: - res_stype = in_stype + res_stype = dh.get_scalar_type(left.dtype) if res_stype is None: res_stype = in_stype if res.dtype == xp.bool: @@ -1100,6 +1100,14 @@ def test_greater_equal(ctx, data): ) +@given(xps.arrays(dtype=xps.complex_dtypes(), shape=hh.shapes())) +def test_imag(x): + out = xp.imag(x) + ph.assert_dtype("imag", x.dtype, out.dtype, dh.dtype_components[x.dtype]) + ph.assert_shape("imag", out.shape, x.shape) + unary_assert_against_refimpl("imag", x, out, operator.attrgetter("imag")) + + @given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes())) def test_isfinite(x): out = xp.isfinite(x) @@ -1341,6 +1349,14 @@ def test_pow(ctx, data): # Values testing pow is too finicky +@given(xps.arrays(dtype=xps.complex_dtypes(), shape=hh.shapes())) +def test_real(x): + out = xp.real(x) + ph.assert_dtype("real", x.dtype, out.dtype, dh.dtype_components[x.dtype]) + ph.assert_shape("real", out.shape, x.shape) + unary_assert_against_refimpl("real", x, out, operator.attrgetter("real")) + + @pytest.mark.parametrize("ctx", make_binary_params("remainder", dh.real_dtypes)) @given(data=st.data()) def test_remainder(ctx, data): @@ -1366,8 +1382,7 @@ def test_round(x): unary_assert_against_refimpl("round", x, out, round, strict_check=True) -# TODO: https://github.com/data-apis/array-api/issues/545 -@given(xps.arrays(dtype=xps.real_dtypes(), shape=hh.shapes(), elements=finite_kw)) +@given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes(), elements=finite_kw)) def test_sign(x): out = xp.sign(x) ph.assert_dtype("sign", x.dtype, out.dtype) From 2d6b2d8a98170609bd3a239304eb31540232f775 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Mon, 12 Dec 2022 13:15:29 +0000 Subject: [PATCH 07/23] `test_conj` --- .../test_operators_and_elementwise_functions.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/array_api_tests/test_operators_and_elementwise_functions.py b/array_api_tests/test_operators_and_elementwise_functions.py index a5cdbf9e..adbc7530 100644 --- a/array_api_tests/test_operators_and_elementwise_functions.py +++ b/array_api_tests/test_operators_and_elementwise_functions.py @@ -948,6 +948,14 @@ def test_ceil(x): unary_assert_against_refimpl("ceil", x, out, math.ceil, strict_check=True) +@given(xps.arrays(dtype=xps.complex_dtypes(), shape=hh.shapes())) +def test_conj(x): + out = xp.conj(x) + ph.assert_dtype("conj", x.dtype, out.dtype) + ph.assert_shape("conj", out.shape, x.shape) + unary_assert_against_refimpl("conj", x, out, operator.methodcaller("conjugate")) + + @given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes())) def test_cos(x): out = xp.cos(x) From f5723ade8712c6023b08386fc58ee7a1523461a8 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Fri, 27 Jan 2023 11:49:01 +0000 Subject: [PATCH 08/23] `min_version()` marker and other versioning nicities --- README.md | 6 ++ array_api_tests/__init__.py | 13 ++-- array_api_tests/_array_module.py | 1 + ...est_operators_and_elementwise_functions.py | 67 +++++++++++++------ conftest.py | 15 +++++ 5 files changed, 75 insertions(+), 27 deletions(-) diff --git a/README.md b/README.md index 9eebc397..5f5814d2 100644 --- a/README.md +++ b/README.md @@ -160,6 +160,12 @@ library to fail. ### Configuration +#### API version + +You can specify the API version to use when testing via the +ARRAY_API_TESTS_VERSION environment variable. Currently this defaults to +`"2021.12"`. + #### CI flag Use the `--ci` flag to run only the primary and special cases tests. You can diff --git a/array_api_tests/__init__.py b/array_api_tests/__init__.py index c472b862..931b83a9 100644 --- a/array_api_tests/__init__.py +++ b/array_api_tests/__init__.py @@ -1,11 +1,16 @@ from functools import wraps +from os import getenv from hypothesis import strategies as st from hypothesis.extra import array_api +from . import _version from ._array_module import mod as _xp -__all__ = ["xps"] +__all__ = ["COMPLEX_VER", "api_version", "xps"] + + +COMPLEX_VER: str = "2022.12" # We monkey patch floats() to always disable subnormals as they are out-of-scope @@ -41,9 +46,7 @@ def _from_dtype(*a, **kw): pass -xps = array_api.make_strategies_namespace(_xp, api_version="2021.12") - - -from . import _version +api_version = getenv("ARRAY_API_TESTS_VERSION", "2021.12") +xps = array_api.make_strategies_namespace(_xp, api_version=api_version) __version__ = _version.get_versions()["version"] diff --git a/array_api_tests/_array_module.py b/array_api_tests/_array_module.py index e83cd6ca..b4aaf76c 100644 --- a/array_api_tests/_array_module.py +++ b/array_api_tests/_array_module.py @@ -58,6 +58,7 @@ def __repr__(self): "uint8", "uint16", "uint32", "uint64", "int8", "int16", "int32", "int64", "float32", "float64", + "complex64", "complex128", ] _constants = ["e", "inf", "nan", "pi"] _funcs = [f.__name__ for funcs in stubs.category_to_funcs.values() for f in funcs] diff --git a/array_api_tests/test_operators_and_elementwise_functions.py b/array_api_tests/test_operators_and_elementwise_functions.py index adbc7530..d921dedf 100644 --- a/array_api_tests/test_operators_and_elementwise_functions.py +++ b/array_api_tests/test_operators_and_elementwise_functions.py @@ -12,7 +12,7 @@ from hypothesis import strategies as st from hypothesis.control import reject -from . import _array_module as xp +from . import COMPLEX_VER, _array_module as xp, api_version from . import array_helpers as ah from . import dtype_helpers as dh from . import hypothesis_helpers as hh @@ -35,7 +35,10 @@ def boolean_and_all_integer_dtypes() -> st.SearchStrategy[DataType]: def all_floating_dtypes() -> st.SearchStrategy[DataType]: - return xps.floating_dtypes() | xps.complex_dtypes() + strat = xps.floating_dtypes() + if api_version >= COMPLEX_VER: + strat |= xps.complex_dtypes() + return strat class OnewayPromotableDtypes(NamedTuple): @@ -492,10 +495,15 @@ def __repr__(self): def make_unary_params( - elwise_func_name: str, dtypes: Sequence[DataType] + elwise_func_name: str, + dtypes: Sequence[DataType], + *, + min_version: str = "2021.12", ) -> List[Param[UnaryParamContext]]: if hh.FILTER_UNDEFINED_DTYPES: dtypes = [d for d in dtypes if not isinstance(d, xp._UndefinedStub)] + if api_version < COMPLEX_VER: + dtypes = [d for d in dtypes if d not in dh.complex_dtypes] dtypes_strat = st.sampled_from(dtypes) strat = xps.arrays(dtype=dtypes_strat, shape=hh.shapes()) func_ctx = UnaryParamContext( @@ -505,7 +513,16 @@ def make_unary_params( op_ctx = UnaryParamContext( func_name=op_name, func=lambda x: getattr(x, op_name)(), strat=strat ) - return [pytest.param(func_ctx, id=func_ctx.id), pytest.param(op_ctx, id=op_ctx.id)] + if api_version < min_version: + marks = pytest.mark.skip( + reason=f"requires ARRAY_API_TESTS_VERSION=>{min_version}" + ) + else: + marks = () + return [ + pytest.param(func_ctx, id=func_ctx.id, marks=marks), + pytest.param(op_ctx, id=op_ctx.id, marks=marks), + ] class FuncType(Enum): @@ -948,12 +965,14 @@ def test_ceil(x): unary_assert_against_refimpl("ceil", x, out, math.ceil, strict_check=True) -@given(xps.arrays(dtype=xps.complex_dtypes(), shape=hh.shapes())) -def test_conj(x): - out = xp.conj(x) - ph.assert_dtype("conj", x.dtype, out.dtype) - ph.assert_shape("conj", out.shape, x.shape) - unary_assert_against_refimpl("conj", x, out, operator.methodcaller("conjugate")) +if api_version >= COMPLEX_VER: + + @given(xps.arrays(dtype=xps.complex_dtypes(), shape=hh.shapes())) + def test_conj(x): + out = xp.conj(x) + ph.assert_dtype("conj", x.dtype, out.dtype) + ph.assert_shape("conj", out.shape, x.shape) + unary_assert_against_refimpl("conj", x, out, operator.methodcaller("conjugate")) @given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes())) @@ -1108,12 +1127,14 @@ def test_greater_equal(ctx, data): ) -@given(xps.arrays(dtype=xps.complex_dtypes(), shape=hh.shapes())) -def test_imag(x): - out = xp.imag(x) - ph.assert_dtype("imag", x.dtype, out.dtype, dh.dtype_components[x.dtype]) - ph.assert_shape("imag", out.shape, x.shape) - unary_assert_against_refimpl("imag", x, out, operator.attrgetter("imag")) +if api_version >= COMPLEX_VER: + + @given(xps.arrays(dtype=xps.complex_dtypes(), shape=hh.shapes())) + def test_imag(x): + out = xp.imag(x) + ph.assert_dtype("imag", x.dtype, out.dtype, dh.dtype_components[x.dtype]) + ph.assert_shape("imag", out.shape, x.shape) + unary_assert_against_refimpl("imag", x, out, operator.attrgetter("imag")) @given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes())) @@ -1357,12 +1378,14 @@ def test_pow(ctx, data): # Values testing pow is too finicky -@given(xps.arrays(dtype=xps.complex_dtypes(), shape=hh.shapes())) -def test_real(x): - out = xp.real(x) - ph.assert_dtype("real", x.dtype, out.dtype, dh.dtype_components[x.dtype]) - ph.assert_shape("real", out.shape, x.shape) - unary_assert_against_refimpl("real", x, out, operator.attrgetter("real")) +if api_version >= COMPLEX_VER: + + @given(xps.arrays(dtype=xps.complex_dtypes(), shape=hh.shapes())) + def test_real(x): + out = xp.real(x) + ph.assert_dtype("real", x.dtype, out.dtype, dh.dtype_components[x.dtype]) + ph.assert_shape("real", out.shape, x.shape) + unary_assert_against_refimpl("real", x, out, operator.attrgetter("real")) @pytest.mark.parametrize("ctx", make_binary_params("remainder", dh.real_dtypes)) diff --git a/conftest.py b/conftest.py index e0453e40..9b7e7956 100644 --- a/conftest.py +++ b/conftest.py @@ -5,6 +5,7 @@ from pytest import mark from array_api_tests import _array_module as xp +from array_api_tests import api_version from array_api_tests._array_module import _UndefinedStub from reporting import pytest_metadata, pytest_json_modifyreport, add_extra_json_metadata # noqa @@ -59,6 +60,10 @@ def pytest_configure(config): "markers", "data_dependent_shapes: output shapes are dependent on inputs" ) config.addinivalue_line("markers", "ci: primary test") + config.addinivalue_line( + "markers", + "min_version(api_version): run when greater or equal to api_version", + ) # Hypothesis hypothesis_max_examples = config.getoption("--hypothesis-max-examples") disable_deadline = config.getoption("--hypothesis-disable-deadline") @@ -126,3 +131,13 @@ def pytest_collection_modifyitems(config, items): ci_mark = next((m for m in markers if m.name == "ci"), None) if ci_mark is None: item.add_marker(mark.skip(reason="disabled via --ci")) + # skip if test is for greater api_version + ver_mark = next((m for m in markers if m.name == "min_version"), None) + if ver_mark is not None: + min_version = ver_mark.args[0] + if api_version < min_version: + item.add_marker( + mark.skip( + reason=f"requires ARRAY_API_TESTS_VERSION=>{min_version}" + ) + ) From f34bbf5fe5c0c39a2d2b270a5547f9de4c38ca83 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Fri, 27 Jan 2023 12:02:05 +0000 Subject: [PATCH 09/23] Try inferring `api_version` from `xp.__array_api_version__` --- README.md | 5 +++-- array_api_tests/__init__.py | 4 +++- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 5f5814d2..dcfe6c9f 100644 --- a/README.md +++ b/README.md @@ -163,8 +163,9 @@ library to fail. #### API version You can specify the API version to use when testing via the -ARRAY_API_TESTS_VERSION environment variable. Currently this defaults to -`"2021.12"`. +`ARRAY_API_TESTS_VERSION` environment variable. Currently this defaults to the +array module's `__array_api_version__` value, and if that attribute doesn't +exist then we fallback to `"2021.12"`. #### CI flag diff --git a/array_api_tests/__init__.py b/array_api_tests/__init__.py index 931b83a9..2d3ea019 100644 --- a/array_api_tests/__init__.py +++ b/array_api_tests/__init__.py @@ -46,7 +46,9 @@ def _from_dtype(*a, **kw): pass -api_version = getenv("ARRAY_API_TESTS_VERSION", "2021.12") +api_version = getenv( + "ARRAY_API_TESTS_VERSION", getattr(_xp, "__array_api_version__", "2021.12") +) xps = array_api.make_strategies_namespace(_xp, api_version=api_version) __version__ = _version.get_versions()["version"] From e1d56a01d9d7ff0c4ac2deddb6bcf84974800dcf Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Tue, 14 Feb 2023 11:04:44 +0000 Subject: [PATCH 10/23] Bump Hypothesis to `>=6.68.0` --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 07b8b189..bb33bc90 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ pytest pytest-json-report -hypothesis>=6.62.1 +hypothesis>=6.68.0 ndindex>=1.6 From 03b735f8ed5406c369b3fcc6080897b95320eb4f Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Tue, 14 Feb 2023 11:17:09 +0000 Subject: [PATCH 11/23] Remove `COMPLEX_VER` --- array_api_tests/__init__.py | 5 +---- .../test_operators_and_elementwise_functions.py | 12 ++++++------ 2 files changed, 7 insertions(+), 10 deletions(-) diff --git a/array_api_tests/__init__.py b/array_api_tests/__init__.py index 2d3ea019..e083d522 100644 --- a/array_api_tests/__init__.py +++ b/array_api_tests/__init__.py @@ -7,10 +7,7 @@ from . import _version from ._array_module import mod as _xp -__all__ = ["COMPLEX_VER", "api_version", "xps"] - - -COMPLEX_VER: str = "2022.12" +__all__ = ["api_version", "xps"] # We monkey patch floats() to always disable subnormals as they are out-of-scope diff --git a/array_api_tests/test_operators_and_elementwise_functions.py b/array_api_tests/test_operators_and_elementwise_functions.py index d921dedf..7521f56d 100644 --- a/array_api_tests/test_operators_and_elementwise_functions.py +++ b/array_api_tests/test_operators_and_elementwise_functions.py @@ -12,7 +12,7 @@ from hypothesis import strategies as st from hypothesis.control import reject -from . import COMPLEX_VER, _array_module as xp, api_version +from . import _array_module as xp, api_version from . import array_helpers as ah from . import dtype_helpers as dh from . import hypothesis_helpers as hh @@ -36,7 +36,7 @@ def boolean_and_all_integer_dtypes() -> st.SearchStrategy[DataType]: def all_floating_dtypes() -> st.SearchStrategy[DataType]: strat = xps.floating_dtypes() - if api_version >= COMPLEX_VER: + if api_version >= "2022.12": strat |= xps.complex_dtypes() return strat @@ -502,7 +502,7 @@ def make_unary_params( ) -> List[Param[UnaryParamContext]]: if hh.FILTER_UNDEFINED_DTYPES: dtypes = [d for d in dtypes if not isinstance(d, xp._UndefinedStub)] - if api_version < COMPLEX_VER: + if api_version < "2022.12": dtypes = [d for d in dtypes if d not in dh.complex_dtypes] dtypes_strat = st.sampled_from(dtypes) strat = xps.arrays(dtype=dtypes_strat, shape=hh.shapes()) @@ -965,7 +965,7 @@ def test_ceil(x): unary_assert_against_refimpl("ceil", x, out, math.ceil, strict_check=True) -if api_version >= COMPLEX_VER: +if api_version >= "2022.12": @given(xps.arrays(dtype=xps.complex_dtypes(), shape=hh.shapes())) def test_conj(x): @@ -1127,7 +1127,7 @@ def test_greater_equal(ctx, data): ) -if api_version >= COMPLEX_VER: +if api_version >= "2022.12": @given(xps.arrays(dtype=xps.complex_dtypes(), shape=hh.shapes())) def test_imag(x): @@ -1378,7 +1378,7 @@ def test_pow(ctx, data): # Values testing pow is too finicky -if api_version >= COMPLEX_VER: +if api_version >= "2022.12": @given(xps.arrays(dtype=xps.complex_dtypes(), shape=hh.shapes())) def test_real(x): From 78c57d0eba3730db9ad22b1ed5a4d5cf1889581f Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Fri, 24 Feb 2023 10:12:16 +0000 Subject: [PATCH 12/23] `hypothesis_helper` dtype strats just alias `xps` strats --- array_api_tests/hypothesis_helpers.py | 25 +++++++++---------------- 1 file changed, 9 insertions(+), 16 deletions(-) diff --git a/array_api_tests/hypothesis_helpers.py b/array_api_tests/hypothesis_helpers.py index 6634b1fc..7c8890a3 100644 --- a/array_api_tests/hypothesis_helpers.py +++ b/array_api_tests/hypothesis_helpers.py @@ -26,22 +26,15 @@ # work for floating point dtypes as those are assumed to be defined in other # places in the tests. FILTER_UNDEFINED_DTYPES = True - -integer_dtypes = sampled_from(dh.all_int_dtypes) -floating_dtypes = sampled_from(dh.float_dtypes) -numeric_dtypes = sampled_from(dh.numeric_dtypes) -integer_or_boolean_dtypes = sampled_from(dh.bool_and_all_int_dtypes) -boolean_dtypes = just(xp.bool) -dtypes = sampled_from(dh.all_dtypes) - -if FILTER_UNDEFINED_DTYPES: - integer_dtypes = integer_dtypes.filter(lambda x: not isinstance(x, _UndefinedStub)) - floating_dtypes = floating_dtypes.filter(lambda x: not isinstance(x, _UndefinedStub)) - numeric_dtypes = numeric_dtypes.filter(lambda x: not isinstance(x, _UndefinedStub)) - integer_or_boolean_dtypes = integer_or_boolean_dtypes.filter(lambda x: not - isinstance(x, _UndefinedStub)) - boolean_dtypes = boolean_dtypes.filter(lambda x: not isinstance(x, _UndefinedStub)) - dtypes = dtypes.filter(lambda x: not isinstance(x, _UndefinedStub)) +# TODO: currently we assume this to be true - we probably can remove this completely +assert FILTER_UNDEFINED_DTYPES + +integer_dtypes = xps.integer_dtypes() | xps.unsigned_integer_dtypes() +floating_dtypes = xps.floating_dtypes() +numeric_dtypes = xps.numeric_dtypes() +integer_or_boolean_dtypes = xps.boolean_dtypes() | integer_dtypes +boolean_dtypes = xps.boolean_dtypes() +dtypes = xps.scalar_dtypes() shared_dtypes = shared(dtypes, key="dtype") shared_floating_dtypes = shared(floating_dtypes, key="dtype") From 84bd3ef238f9311ecfa36ddb9a4dc1234a2f701a Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Mon, 27 Feb 2023 10:12:16 +0000 Subject: [PATCH 13/23] Move oneway strategies to `hypothesis_helpers.py` --- array_api_tests/hypothesis_helpers.py | 42 ++++++++++++++++- array_api_tests/meta/test_utils.py | 11 ++--- array_api_tests/test_array_object.py | 3 +- array_api_tests/test_creation_functions.py | 3 +- ...est_operators_and_elementwise_functions.py | 47 ++----------------- array_api_tests/test_special_cases.py | 8 +--- 6 files changed, 52 insertions(+), 62 deletions(-) diff --git a/array_api_tests/hypothesis_helpers.py b/array_api_tests/hypothesis_helpers.py index 7c8890a3..04369214 100644 --- a/array_api_tests/hypothesis_helpers.py +++ b/array_api_tests/hypothesis_helpers.py @@ -4,7 +4,7 @@ from operator import mul from typing import Any, List, NamedTuple, Optional, Sequence, Tuple, Union -from hypothesis import assume +from hypothesis import assume, reject from hypothesis.strategies import (SearchStrategy, booleans, composite, floats, integers, just, lists, none, one_of, sampled_from, shared) @@ -99,6 +99,46 @@ def mutually_promotable_dtypes( return one_of(strats).map(tuple) +class OnewayPromotableDtypes(NamedTuple): + input_dtype: DataType + result_dtype: DataType + + +@composite +def oneway_promotable_dtypes( + draw, dtypes: Sequence[DataType] +) -> SearchStrategy[OnewayPromotableDtypes]: + """Return a strategy for input dtypes that promote to result dtypes.""" + d1, d2 = draw(mutually_promotable_dtypes(dtypes=dtypes)) + result_dtype = dh.result_type(d1, d2) + if d1 == result_dtype: + return OnewayPromotableDtypes(d2, d1) + elif d2 == result_dtype: + return OnewayPromotableDtypes(d1, d2) + else: + reject() + + +class OnewayBroadcastableShapes(NamedTuple): + input_shape: Shape + result_shape: Shape + + +@composite +def oneway_broadcastable_shapes(draw) -> SearchStrategy[OnewayBroadcastableShapes]: + """Return a strategy for input shapes that broadcast to result shapes.""" + result_shape = draw(shapes(min_side=1)) + input_shape = draw( + xps.broadcastable_shapes( + result_shape, + # Override defaults so bad shapes are less likely to be generated. + max_side=None if result_shape == () else max(result_shape), + max_dims=len(result_shape), + ).filter(lambda s: sh.broadcast_shapes(result_shape, s) == result_shape) + ) + return OnewayBroadcastableShapes(input_shape, result_shape) + + # shared() allows us to draw either the function or the function name and they # will both correspond to the same function. diff --git a/array_api_tests/meta/test_utils.py b/array_api_tests/meta/test_utils.py index 268a81aa..deeab264 100644 --- a/array_api_tests/meta/test_utils.py +++ b/array_api_tests/meta/test_utils.py @@ -4,15 +4,12 @@ from .. import _array_module as xp from .. import dtype_helpers as dh +from .. import hypothesis_helpers as hh from .. import shape_helpers as sh from .. import xps from ..test_creation_functions import frange from ..test_manipulation_functions import roll_ndindex -from ..test_operators_and_elementwise_functions import ( - mock_int_dtype, - oneway_broadcastable_shapes, - oneway_promotable_dtypes, -) +from ..test_operators_and_elementwise_functions import mock_int_dtype @pytest.mark.parametrize( @@ -115,11 +112,11 @@ def test_int_to_dtype(x, dtype): assert mock_int_dtype(x, dtype) == d -@given(oneway_promotable_dtypes(dh.all_dtypes)) +@given(hh.oneway_promotable_dtypes(dh.all_dtypes)) def test_oneway_promotable_dtypes(D): assert D.result_dtype == dh.result_type(*D) -@given(oneway_broadcastable_shapes()) +@given(hh.oneway_broadcastable_shapes()) def test_oneway_broadcastable_shapes(S): assert S.result_shape == sh.broadcast_shapes(*S) diff --git a/array_api_tests/test_array_object.py b/array_api_tests/test_array_object.py index df3edb88..b35976d5 100644 --- a/array_api_tests/test_array_object.py +++ b/array_api_tests/test_array_object.py @@ -12,7 +12,6 @@ from . import pytest_helpers as ph from . import shape_helpers as sh from . import xps -from .test_operators_and_elementwise_functions import oneway_promotable_dtypes from .typing import DataType, Index, Param, Scalar, ScalarType, Shape pytestmark = pytest.mark.ci @@ -108,7 +107,7 @@ def test_getitem(shape, dtype, data): @given( shape=hh.shapes(), - dtypes=oneway_promotable_dtypes(dh.all_dtypes), + dtypes=hh.oneway_promotable_dtypes(dh.all_dtypes), data=st.data(), ) def test_setitem(shape, dtypes, data): diff --git a/array_api_tests/test_creation_functions.py b/array_api_tests/test_creation_functions.py index 76a6a072..47025f03 100644 --- a/array_api_tests/test_creation_functions.py +++ b/array_api_tests/test_creation_functions.py @@ -12,7 +12,6 @@ from . import pytest_helpers as ph from . import shape_helpers as sh from . import xps -from .test_operators_and_elementwise_functions import oneway_promotable_dtypes from .typing import DataType, Scalar pytestmark = pytest.mark.ci @@ -256,7 +255,7 @@ def scalar_eq(s1: Scalar, s2: Scalar) -> bool: @given( shape=hh.shapes(), - dtypes=oneway_promotable_dtypes(dh.all_dtypes), + dtypes=hh.oneway_promotable_dtypes(dh.all_dtypes), data=st.data(), ) def test_asarray_arrays(shape, dtypes, data): diff --git a/array_api_tests/test_operators_and_elementwise_functions.py b/array_api_tests/test_operators_and_elementwise_functions.py index 7521f56d..691494cb 100644 --- a/array_api_tests/test_operators_and_elementwise_functions.py +++ b/array_api_tests/test_operators_and_elementwise_functions.py @@ -8,9 +8,8 @@ from typing import Callable, List, NamedTuple, Optional, Sequence, TypeVar, Union import pytest -from hypothesis import assume, given +from hypothesis import assume, given, reject from hypothesis import strategies as st -from hypothesis.control import reject from . import _array_module as xp, api_version from . import array_helpers as ah @@ -41,46 +40,6 @@ def all_floating_dtypes() -> st.SearchStrategy[DataType]: return strat -class OnewayPromotableDtypes(NamedTuple): - input_dtype: DataType - result_dtype: DataType - - -@st.composite -def oneway_promotable_dtypes( - draw, dtypes: Sequence[DataType] -) -> st.SearchStrategy[OnewayPromotableDtypes]: - """Return a strategy for input dtypes that promote to result dtypes.""" - d1, d2 = draw(hh.mutually_promotable_dtypes(dtypes=dtypes)) - result_dtype = dh.result_type(d1, d2) - if d1 == result_dtype: - return OnewayPromotableDtypes(d2, d1) - elif d2 == result_dtype: - return OnewayPromotableDtypes(d1, d2) - else: - reject() - - -class OnewayBroadcastableShapes(NamedTuple): - input_shape: Shape - result_shape: Shape - - -@st.composite -def oneway_broadcastable_shapes(draw) -> st.SearchStrategy[OnewayBroadcastableShapes]: - """Return a strategy for input shapes that broadcast to result shapes.""" - result_shape = draw(hh.shapes(min_side=1)) - input_shape = draw( - xps.broadcastable_shapes( - result_shape, - # Override defaults so bad shapes are less likely to be generated. - max_side=None if result_shape == () else max(result_shape), - max_dims=len(result_shape), - ).filter(lambda s: sh.broadcast_shapes(result_shape, s) == result_shape) - ) - return OnewayBroadcastableShapes(input_shape, result_shape) - - def mock_int_dtype(n: int, dtype: DataType) -> int: """Returns equivalent of `n` that mocks `dtype` behaviour.""" nbits = dh.dtype_nbits[dtype] @@ -557,7 +516,7 @@ def make_binary_params( ) -> List[Param[BinaryParamContext]]: if hh.FILTER_UNDEFINED_DTYPES: dtypes = [d for d in dtypes if not isinstance(d, xp._UndefinedStub)] - shared_oneway_dtypes = st.shared(oneway_promotable_dtypes(dtypes)) + shared_oneway_dtypes = st.shared(hh.oneway_promotable_dtypes(dtypes)) left_dtypes = shared_oneway_dtypes.map(lambda D: D.result_dtype) right_dtypes = shared_oneway_dtypes.map(lambda D: D.input_dtype) @@ -576,7 +535,7 @@ def make_param( right_strat = right_dtypes.flatmap(lambda d: xps.from_dtype(d, **finite_kw)) else: if func_type is FuncType.IOP: - shared_oneway_shapes = st.shared(oneway_broadcastable_shapes()) + shared_oneway_shapes = st.shared(hh.oneway_broadcastable_shapes()) left_strat = xps.arrays( dtype=left_dtypes, shape=shared_oneway_shapes.map(lambda S: S.result_shape), diff --git a/array_api_tests/test_special_cases.py b/array_api_tests/test_special_cases.py index 9999d9b0..2e4167ce 100644 --- a/array_api_tests/test_special_cases.py +++ b/array_api_tests/test_special_cases.py @@ -35,10 +35,6 @@ from . import xps from ._array_module import mod as xp from .stubs import category_to_funcs -from .test_operators_and_elementwise_functions import ( - oneway_broadcastable_shapes, - oneway_promotable_dtypes, -) pytestmark = pytest.mark.ci @@ -1281,8 +1277,8 @@ def test_binary(func_name, func, case, x1, x2, data): @pytest.mark.parametrize("iop_name, iop, case", iop_params) @given( - oneway_dtypes=oneway_promotable_dtypes(dh.float_dtypes), - oneway_shapes=oneway_broadcastable_shapes(), + oneway_dtypes=hh.oneway_promotable_dtypes(dh.float_dtypes), + oneway_shapes=hh.oneway_broadcastable_shapes(), data=st.data(), ) def test_iop(iop_name, iop, case, oneway_dtypes, oneway_shapes, data): From a9506a8c0cf34c7ec3d1b0c426406c48097d1661 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Mon, 27 Feb 2023 11:36:13 +0000 Subject: [PATCH 14/23] Use `cmath` where obvious --- array_api_tests/pytest_helpers.py | 9 +++++---- array_api_tests/test_array_object.py | 3 ++- array_api_tests/test_creation_functions.py | 5 +++-- array_api_tests/test_set_functions.py | 7 ++++--- array_api_tests/test_sorting_functions.py | 4 ++-- array_api_tests/test_statistical_functions.py | 5 +++-- 6 files changed, 19 insertions(+), 14 deletions(-) diff --git a/array_api_tests/pytest_helpers.py b/array_api_tests/pytest_helpers.py index 8ba39731..3f478fd9 100644 --- a/array_api_tests/pytest_helpers.py +++ b/array_api_tests/pytest_helpers.py @@ -1,3 +1,4 @@ +import cmath import math from inspect import getfullargspec from typing import Any, Dict, Optional, Sequence, Tuple, Union @@ -345,12 +346,12 @@ def assert_scalar_equals( if type_ in [bool, int]: msg = f"{repr_name}={out}, but should be {expected} [{f_func}]" assert out == expected, msg - elif math.isnan(expected): + elif cmath.isnan(expected): msg = f"{repr_name}={out}, but should be {expected} [{f_func}]" - assert math.isnan(out), msg + assert cmath.isnan(out), msg else: msg = f"{repr_name}={out}, but should be roughly {expected} [{f_func}]" - assert math.isclose(out, expected, rel_tol=0.25, abs_tol=1), msg + assert cmath.isclose(out, expected, rel_tol=0.25, abs_tol=1), msg def assert_fill( @@ -368,7 +369,7 @@ def assert_fill( """ msg = f"out not filled with {fill_value} [{func_name}({fmt_kw(kw)})]\n{out=}" - if math.isnan(fill_value): + if cmath.isnan(fill_value): assert xp.all(xp.isnan(out)), msg else: assert xp.all(xp.equal(out, xp.asarray(fill_value, dtype=dtype))), msg diff --git a/array_api_tests/test_array_object.py b/array_api_tests/test_array_object.py index b35976d5..4a539fc7 100644 --- a/array_api_tests/test_array_object.py +++ b/array_api_tests/test_array_object.py @@ -1,3 +1,4 @@ +import cmath import math from itertools import product from typing import List, Sequence, Tuple, Union, get_args @@ -135,7 +136,7 @@ def test_setitem(shape, dtypes, data): f_res = sh.fmt_idx("x", key) if isinstance(value, get_args(Scalar)): msg = f"{f_res}={res[key]!r}, but should be {value=} [__setitem__()]" - if math.isnan(value): + if cmath.isnan(value): assert xp.isnan(res[key]), msg else: assert res[key] == value, msg diff --git a/array_api_tests/test_creation_functions.py b/array_api_tests/test_creation_functions.py index 47025f03..b733a413 100644 --- a/array_api_tests/test_creation_functions.py +++ b/array_api_tests/test_creation_functions.py @@ -1,3 +1,4 @@ +import cmath import math from itertools import count from typing import Iterator, NamedTuple, Union @@ -247,8 +248,8 @@ def test_asarray_scalars(shape, data): def scalar_eq(s1: Scalar, s2: Scalar) -> bool: - if math.isnan(s1): - return math.isnan(s2) + if cmath.isnan(s1): + return cmath.isnan(s2) else: return s1 == s2 diff --git a/array_api_tests/test_set_functions.py b/array_api_tests/test_set_functions.py index 5e415858..193087d9 100644 --- a/array_api_tests/test_set_functions.py +++ b/array_api_tests/test_set_functions.py @@ -1,4 +1,5 @@ # TODO: disable if opted out, refactor things +import cmath import math from collections import Counter, defaultdict @@ -61,7 +62,7 @@ def test_unique_all(x): for idx in sh.ndindex(out.indices.shape): val = scalar_type(out.values[idx]) - if math.isnan(val): + if cmath.isnan(val): break i = int(out.indices[idx]) expected = firsts[val] @@ -88,7 +89,7 @@ def test_unique_all(x): for idx in sh.ndindex(out.values.shape): val = scalar_type(out.values[idx]) count = int(out.counts[idx]) - if math.isnan(val): + if cmath.isnan(val): nans += 1 assert count == 1, ( f"out.counts[{idx}]={count} for out.values[{idx}]={val}, " @@ -225,7 +226,7 @@ def test_unique_values(x): nans = 0 for idx in sh.ndindex(out.shape): val = scalar_type(out[idx]) - if math.isnan(val): + if cmath.isnan(val): nans += 1 else: assert val in distinct, f"out[{idx}]={val}, but {val} not in input array" diff --git a/array_api_tests/test_sorting_functions.py b/array_api_tests/test_sorting_functions.py index 7c5a1411..69149c1b 100644 --- a/array_api_tests/test_sorting_functions.py +++ b/array_api_tests/test_sorting_functions.py @@ -1,4 +1,4 @@ -import math +import cmath from typing import Set import pytest @@ -26,7 +26,7 @@ def assert_scalar_in_set( **kw, ): out_repr = "out" if idx == () else f"out[{idx}]" - if math.isnan(out): + if cmath.isnan(out): raise NotImplementedError() msg = f"{out_repr}={out}, but should be in {set_} [{func_name}({ph.fmt_kw(kw)})]" assert out in set_, msg diff --git a/array_api_tests/test_statistical_functions.py b/array_api_tests/test_statistical_functions.py index 4371fb07..2d433dc6 100644 --- a/array_api_tests/test_statistical_functions.py +++ b/array_api_tests/test_statistical_functions.py @@ -1,3 +1,4 @@ +import cmath import math from typing import Optional @@ -162,7 +163,7 @@ def test_prod(x, data): scalar_type = dh.get_scalar_type(out.dtype) for indices, out_idx in zip(sh.axes_ndindex(x.shape, _axes), sh.ndindex(out.shape)): prod = scalar_type(out[out_idx]) - assume(math.isfinite(prod)) + assume(cmath.isfinite(prod)) elements = [] for idx in indices: s = scalar_type(x[idx]) @@ -267,7 +268,7 @@ def test_sum(x, data): scalar_type = dh.get_scalar_type(out.dtype) for indices, out_idx in zip(sh.axes_ndindex(x.shape, _axes), sh.ndindex(out.shape)): sum_ = scalar_type(out[out_idx]) - assume(math.isfinite(sum_)) + assume(cmath.isfinite(sum_)) elements = [] for idx in indices: s = scalar_type(x[idx]) From 12c3aa2a739718e5ac71ca77ae6de63a83266d57 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Mon, 27 Feb 2023 11:36:24 +0000 Subject: [PATCH 15/23] Stop testing complex in `test_arange` Also ignore very large distances --- array_api_tests/test_creation_functions.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/array_api_tests/test_creation_functions.py b/array_api_tests/test_creation_functions.py index b733a413..df034383 100644 --- a/array_api_tests/test_creation_functions.py +++ b/array_api_tests/test_creation_functions.py @@ -79,7 +79,8 @@ def reals(min_value=None, max_value=None) -> st.SearchStrategy[Union[int, float] ) -@given(dtype=st.none() | hh.numeric_dtypes, data=st.data()) +# TODO: support testing complex dtypes +@given(dtype=st.none() | xps.real_dtypes(), data=st.data()) def test_arange(dtype, data): if dtype is None or dh.is_float_dtype(dtype): start = data.draw(reals(), label="start") @@ -128,6 +129,12 @@ def test_arange(dtype, data): assert m <= _start <= M assert m <= _stop <= M assert m <= step <= M + # Ignore ridiculous distances so we don't fail like + # + # >>> torch.arange(9132051521638391890, 0, -91320515216383920) + # RuntimeError: invalid size, possible overflow? + # + assume(abs(_start - _stop) < M // 2) r = frange(_start, _stop, step) size = len(r) From 74101de037879d732c537b314ab213b941cdf5d7 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Mon, 27 Feb 2023 11:49:46 +0000 Subject: [PATCH 16/23] Remove unnecessary use of `hh.shared_dtypes()` in `test_empty` --- array_api_tests/test_creation_functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/array_api_tests/test_creation_functions.py b/array_api_tests/test_creation_functions.py index df034383..fc22406e 100644 --- a/array_api_tests/test_creation_functions.py +++ b/array_api_tests/test_creation_functions.py @@ -315,7 +315,7 @@ def test_asarray_arrays(shape, dtypes, data): ), f"{f_out}, but should be {value} after x was mutated" -@given(hh.shapes(), hh.kwargs(dtype=st.none() | hh.shared_dtypes)) +@given(hh.shapes(), hh.kwargs(dtype=st.none() | xps.scalar_dtypes())) def test_empty(shape, kw): out = xp.empty(shape, **kw) if kw.get("dtype", None) is None: From b41d447a6f1e9e2c9799f86f9a3715a18eac0ac1 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Mon, 27 Feb 2023 12:12:50 +0000 Subject: [PATCH 17/23] Support complex in `test_full`, complex dtype utilities --- array_api_tests/dtype_helpers.py | 11 +++++++++++ array_api_tests/pytest_helpers.py | 17 ++++++++++++++++ array_api_tests/test_creation_functions.py | 23 ++++++++++++++++++---- 3 files changed, 47 insertions(+), 4 deletions(-) diff --git a/array_api_tests/dtype_helpers.py b/array_api_tests/dtype_helpers.py index 87e7823f..430fd529 100644 --- a/array_api_tests/dtype_helpers.py +++ b/array_api_tests/dtype_helpers.py @@ -5,6 +5,7 @@ from typing import Any, Dict, NamedTuple, Sequence, Tuple, Union from warnings import warn +from . import api_version from . import _array_module as xp from ._array_module import _UndefinedStub from .stubs import name_to_func @@ -29,6 +30,7 @@ "default_int", "default_uint", "default_float", + "default_complex", "promotion_table", "dtype_nbits", "dtype_signed", @@ -197,6 +199,15 @@ class MinMax(NamedTuple): default_float = xp.asarray(float()).dtype if default_float not in float_dtypes: warn(f"inferred default float is {default_float!r}, which is not a float") + if api_version > "2021.12": + default_complex = xp.asarray(complex()).dtype + if default_complex not in complex_dtypes: + warn( + f"inferred default complex is {default_complex!r}, " + "which is not a complex" + ) + else: + default_complex = None if dtype_nbits[default_int] == 32: default_uint = xp.uint32 else: diff --git a/array_api_tests/pytest_helpers.py b/array_api_tests/pytest_helpers.py index 3f478fd9..0eb34180 100644 --- a/array_api_tests/pytest_helpers.py +++ b/array_api_tests/pytest_helpers.py @@ -170,6 +170,23 @@ def assert_default_float(func_name: str, out_dtype: DataType): assert out_dtype == dh.default_float, msg +def assert_default_complex(func_name: str, out_dtype: DataType): + """ + Assert the output dtype is the default complex, e.g. + + >>> out = xp.asarray(4+2j) + >>> assert_default_complex('asarray', out.dtype) + + """ + f_dtype = dh.dtype_to_name[out_dtype] + f_default = dh.dtype_to_name[dh.default_complex] + msg = ( + f"out.dtype={f_dtype}, should be default " + f"complex dtype {f_default} [{func_name}()]" + ) + assert out_dtype == dh.default_complex, msg + + def assert_default_int(func_name: str, out_dtype: DataType): """ Assert the output dtype is the default int, e.g. diff --git a/array_api_tests/test_creation_functions.py b/array_api_tests/test_creation_functions.py index fc22406e..c37d2ca7 100644 --- a/array_api_tests/test_creation_functions.py +++ b/array_api_tests/test_creation_functions.py @@ -369,13 +369,15 @@ def test_eye(n_rows, n_cols, kw): default_unsafe_dtypes.extend([xp.uint32, xp.int64]) if dh.default_float == xp.float32: default_unsafe_dtypes.append(xp.float64) +if dh.default_complex == xp.complex64: + default_unsafe_dtypes.append(xp.complex64) default_safe_dtypes: st.SearchStrategy = xps.scalar_dtypes().filter( lambda d: d not in default_unsafe_dtypes ) @st.composite -def full_fill_values(draw) -> st.SearchStrategy[float]: +def full_fill_values(draw) -> st.SearchStrategy[Union[bool, int, float, complex]]: kw = draw( st.shared(hh.kwargs(dtype=st.none() | xps.scalar_dtypes()), key="full_kw") ) @@ -396,15 +398,28 @@ def test_full(shape, fill_value, kw): dtype = xp.bool elif isinstance(fill_value, int): dtype = dh.default_int - else: + elif isinstance(fill_value, float): dtype = dh.default_float + else: + assert isinstance(fill_value, complex) # sanity check + dtype = dh.default_complex + # Ignore large components so we don't fail like + # + # >>> torch.fill(complex(0.0, 3.402823466385289e+38)) + # RuntimeError: value cannot be converted to complex without overflow + # + M = dh.dtype_ranges[dh.dtype_components[dtype]].max + assume(all(abs(c) < math.sqrt(M) for c in [fill_value.real, fill_value.imag])) if kw.get("dtype", None) is None: if isinstance(fill_value, bool): - pass # TODO + assert out.dtype == xp.bool, f"{out.dtype=}, but should be bool [full()]" elif isinstance(fill_value, int): ph.assert_default_int("full", out.dtype) - else: + elif isinstance(fill_value, float): ph.assert_default_float("full", out.dtype) + else: + assert isinstance(fill_value, complex) # sanity check + ph.assert_default_complex("full", out.dtype) else: ph.assert_kw_dtype("full", kw["dtype"], out.dtype) ph.assert_shape("full", out.shape, shape, shape=shape) From 90e7837aca8cf55aa924060a82f2b48acb740daa Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Mon, 27 Feb 2023 12:27:09 +0000 Subject: [PATCH 18/23] Skip even not-so-very-large distances in `test_linspace` --- array_api_tests/test_creation_functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/array_api_tests/test_creation_functions.py b/array_api_tests/test_creation_functions.py index c37d2ca7..2ebd3b07 100644 --- a/array_api_tests/test_creation_functions.py +++ b/array_api_tests/test_creation_functions.py @@ -470,7 +470,7 @@ def test_linspace(num, dtype, endpoint, data): assume(not xp.isnan(xp.asarray(start - stop, dtype=_dtype))) # avoid generating very large distances # https://github.com/data-apis/array-api-tests/issues/125 - assume(abs(stop - start) < dh.dtype_ranges[_dtype].max) + assume(abs(stop - start) < math.sqrt(dh.dtype_ranges[_dtype].max)) kw = data.draw( hh.specified_kwargs( From 05f2cf9930415d0498da7c0bcafb103fa5f7f5f8 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Mon, 27 Feb 2023 12:35:13 +0000 Subject: [PATCH 19/23] Skip testing complex dtypes in `test_data_type_functions.py` for now --- array_api_tests/test_data_type_functions.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/array_api_tests/test_data_type_functions.py b/array_api_tests/test_data_type_functions.py index 115ec9b9..5cd409ce 100644 --- a/array_api_tests/test_data_type_functions.py +++ b/array_api_tests/test_data_type_functions.py @@ -16,13 +16,18 @@ pytestmark = pytest.mark.ci +# TODO: test with complex dtypes +def non_complex_dtypes(): + return xps.boolean_dtypes() | xps.real_dtypes() + + def float32(n: Union[int, float]) -> float: return struct.unpack("!f", struct.pack("!f", float(n)))[0] @given( - x_dtype=xps.scalar_dtypes(), - dtype=xps.scalar_dtypes(), + x_dtype=non_complex_dtypes(), + dtype=non_complex_dtypes(), kw=hh.kwargs(copy=st.booleans()), data=st.data(), ) @@ -101,7 +106,7 @@ def test_broadcast_to(x, data): # TODO: test values -@given(_from=xps.scalar_dtypes(), to=xps.scalar_dtypes(), data=st.data()) +@given(_from=non_complex_dtypes(), to=non_complex_dtypes(), data=st.data()) def test_can_cast(_from, to, data): from_ = data.draw( st.just(_from) | xps.arrays(dtype=_from, shape=hh.shapes()), label="from_" @@ -114,10 +119,12 @@ def test_can_cast(_from, to, data): if _from == xp.bool: expected = to == xp.bool else: - for dtypes in [dh.all_int_dtypes, dh.float_dtypes]: + same_family = None + for dtypes in [dh.all_int_dtypes, dh.float_dtypes, dh.complex_dtypes]: if _from in dtypes: same_family = to in dtypes break + assert same_family is not None # sanity check if same_family: from_min, from_max = dh.dtype_ranges[_from] to_min, to_max = dh.dtype_ranges[to] From 8922b83eff752b1da0e4d91a54485bd769b5dc86 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Mon, 27 Feb 2023 12:41:57 +0000 Subject: [PATCH 20/23] Skip testing complex numbers in `test_linalg.py` for now --- array_api_tests/test_linalg.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/array_api_tests/test_linalg.py b/array_api_tests/test_linalg.py index 321263d3..cc07e6b4 100644 --- a/array_api_tests/test_linalg.py +++ b/array_api_tests/test_linalg.py @@ -12,6 +12,7 @@ required, but we don't yet have a clean way to disable only those tests (see https://github.com/data-apis/array-api-tests/issues/25). """ +# TODO: test with complex dtypes where appropiate import pytest from hypothesis import assume, given @@ -20,7 +21,7 @@ from ndindex import iter_indices from .array_helpers import assert_exactly_equal, asarray -from .hypothesis_helpers import (xps, dtypes, shapes, kwargs, matrix_shapes, +from .hypothesis_helpers import (xps, shapes, kwargs, matrix_shapes, square_matrix_shapes, symmetric_matrices, positive_definite_matrices, MAX_ARRAY_SIZE, invertible_matrices, two_mutual_arrays, @@ -117,7 +118,7 @@ def test_cholesky(x, kw): @composite -def cross_args(draw, dtype_objects=dh.numeric_dtypes): +def cross_args(draw, dtype_objects=dh.real_dtypes): """ cross() requires two arrays with a size 3 in the 'axis' dimension @@ -192,7 +193,7 @@ def test_det(x): @pytest.mark.xp_extension('linalg') @given( - x=xps.arrays(dtype=dtypes, shape=matrix_shapes()), + x=xps.arrays(dtype=xps.real_dtypes(), shape=matrix_shapes()), # offset may produce an overflow if it is too large. Supporting offsets # that are way larger than the array shape isn't very important. kw=kwargs(offset=integers(-MAX_ARRAY_SIZE, MAX_ARRAY_SIZE)) @@ -277,7 +278,7 @@ def test_inv(x): # TODO: Test that the result is actually the inverse @given( - *two_mutual_arrays(dh.numeric_dtypes) + *two_mutual_arrays(dh.real_dtypes) ) def test_matmul(x1, x2): # TODO: Make this also test the @ operator @@ -366,7 +367,7 @@ def test_matrix_rank(x, kw): linalg.matrix_rank(x, **kw) @given( - x=xps.arrays(dtype=dtypes, shape=matrix_shapes()), + x=xps.arrays(dtype=xps.real_dtypes(), shape=matrix_shapes()), ) def test_matrix_transpose(x): res = _array_module.matrix_transpose(x) @@ -384,7 +385,7 @@ def test_matrix_transpose(x): @pytest.mark.xp_extension('linalg') @given( - *two_mutual_arrays(dtypes=dh.numeric_dtypes, + *two_mutual_arrays(dtypes=dh.real_dtypes, two_shapes=tuples(one_d_shapes, one_d_shapes)) ) def test_outer(x1, x2): @@ -573,7 +574,7 @@ def test_svdvals(x): @given( - dtypes=mutually_promotable_dtypes(dtypes=dh.numeric_dtypes), + dtypes=mutually_promotable_dtypes(dtypes=dh.real_dtypes), shape=shapes(), data=data(), ) @@ -590,7 +591,7 @@ def test_tensordot(dtypes, shape, data): @pytest.mark.xp_extension('linalg') @given( - x=xps.arrays(dtype=xps.numeric_dtypes(), shape=matrix_shapes()), + x=xps.arrays(dtype=xps.real_dtypes(), shape=matrix_shapes()), # offset may produce an overflow if it is too large. Supporting offsets # that are way larger than the array shape isn't very important. kw=kwargs(offset=integers(-MAX_ARRAY_SIZE, MAX_ARRAY_SIZE)) @@ -629,7 +630,7 @@ def true_trace(x_stack): @given( - dtypes=mutually_promotable_dtypes(dtypes=dh.numeric_dtypes), + dtypes=mutually_promotable_dtypes(dtypes=dh.real_dtypes), shape=shapes(min_dims=1), data=data(), ) From 33a0f6cc1893f498cee4f95d58218fd6455c6ac3 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Mon, 27 Feb 2023 12:44:06 +0000 Subject: [PATCH 21/23] Remove debug `print` statement in `test_manipulation_functions.py` --- array_api_tests/test_manipulation_functions.py | 1 - 1 file changed, 1 deletion(-) diff --git a/array_api_tests/test_manipulation_functions.py b/array_api_tests/test_manipulation_functions.py index a30a0030..c5f19633 100644 --- a/array_api_tests/test_manipulation_functions.py +++ b/array_api_tests/test_manipulation_functions.py @@ -350,7 +350,6 @@ def test_stack(shape, dtypes, kw, data): out_indices = sh.ndindex(out.shape) for idx in sh.axis_ndindex(arrays[0].shape, axis=_axis): f_idx = ", ".join(str(i) if isinstance(i, int) else ":" for i in idx) - print(f"{f_idx=}") for x_num, x in enumerate(arrays, 1): indexed_x = x[idx] for x_idx in sh.ndindex(indexed_x.shape): From d2267e4a9fdf8be1dc1c3325e1cfcf3d7db734df Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Mon, 27 Feb 2023 18:36:29 +0000 Subject: [PATCH 22/23] Change dtype helpers behaviour depending on `api_version` --- array_api_tests/dtype_helpers.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/array_api_tests/dtype_helpers.py b/array_api_tests/dtype_helpers.py index 430fd529..fb167168 100644 --- a/array_api_tests/dtype_helpers.py +++ b/array_api_tests/dtype_helpers.py @@ -103,9 +103,13 @@ def __repr__(self): all_int_dtypes = uint_dtypes + int_dtypes real_dtypes = all_int_dtypes + float_dtypes complex_dtypes = tuple(getattr(xp, name) for name in _complex_names) -numeric_dtypes = real_dtypes + complex_dtypes +numeric_dtypes = real_dtypes +if api_version > "2021.12": + numeric_dtypes += complex_dtypes all_dtypes = (xp.bool,) + numeric_dtypes -all_float_dtypes = float_dtypes + complex_dtypes +all_float_dtypes = float_dtypes +if api_version > "2021.12": + all_float_dtypes += complex_dtypes bool_and_all_int_dtypes = (xp.bool,) + all_int_dtypes @@ -132,7 +136,10 @@ def is_float_dtype(dtype): # See https://github.com/numpy/numpy/issues/18434 if dtype is None: return False - return dtype in float_dtypes + valid_dtypes = float_dtypes + if api_version > "2021.12": + valid_dtypes += complex_dtypes + return dtype in valid_dtypes def get_scalar_type(dtype: DataType) -> ScalarType: From ef0e3b1672297f16b57f34eeda238a310c12ecd9 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Mon, 27 Feb 2023 19:20:58 +0000 Subject: [PATCH 23/23] Update type hints relating to `complex` --- array_api_tests/typing.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/array_api_tests/typing.py b/array_api_tests/typing.py index f0ed8c50..84311ff3 100644 --- a/array_api_tests/typing.py +++ b/array_api_tests/typing.py @@ -12,8 +12,8 @@ ] DataType = Type[Any] -Scalar = Union[bool, int, float] -ScalarType = Union[Type[bool], Type[int], Type[float]] +Scalar = Union[bool, int, float, complex] +ScalarType = Union[Type[bool], Type[int], Type[float], Type[complex]] Array = Any Shape = Tuple[int, ...] AtomicIndex = Union[int, "ellipsis", slice, None] # noqa