diff --git a/README.md b/README.md index 9eebc397..dcfe6c9f 100644 --- a/README.md +++ b/README.md @@ -160,6 +160,13 @@ library to fail. ### Configuration +#### API version + +You can specify the API version to use when testing via the +`ARRAY_API_TESTS_VERSION` environment variable. Currently this defaults to the +array module's `__array_api_version__` value, and if that attribute doesn't +exist then we fallback to `"2021.12"`. + #### CI flag Use the `--ci` flag to run only the primary and special cases tests. You can diff --git a/array_api_tests/__init__.py b/array_api_tests/__init__.py index c472b862..e083d522 100644 --- a/array_api_tests/__init__.py +++ b/array_api_tests/__init__.py @@ -1,11 +1,13 @@ from functools import wraps +from os import getenv from hypothesis import strategies as st from hypothesis.extra import array_api +from . import _version from ._array_module import mod as _xp -__all__ = ["xps"] +__all__ = ["api_version", "xps"] # We monkey patch floats() to always disable subnormals as they are out-of-scope @@ -41,9 +43,9 @@ def _from_dtype(*a, **kw): pass -xps = array_api.make_strategies_namespace(_xp, api_version="2021.12") - - -from . import _version +api_version = getenv( + "ARRAY_API_TESTS_VERSION", getattr(_xp, "__array_api_version__", "2021.12") +) +xps = array_api.make_strategies_namespace(_xp, api_version=api_version) __version__ = _version.get_versions()["version"] diff --git a/array_api_tests/_array_module.py b/array_api_tests/_array_module.py index e83cd6ca..b4aaf76c 100644 --- a/array_api_tests/_array_module.py +++ b/array_api_tests/_array_module.py @@ -58,6 +58,7 @@ def __repr__(self): "uint8", "uint16", "uint32", "uint64", "int8", "int16", "int32", "int64", "float32", "float64", + "complex64", "complex128", ] _constants = ["e", "inf", "nan", "pi"] _funcs = [f.__name__ for funcs in stubs.category_to_funcs.values() for f in funcs] diff --git a/array_api_tests/dtype_helpers.py b/array_api_tests/dtype_helpers.py index 1527611c..fb167168 100644 --- a/array_api_tests/dtype_helpers.py +++ b/array_api_tests/dtype_helpers.py @@ -5,6 +5,7 @@ from typing import Any, Dict, NamedTuple, Sequence, Tuple, Union from warnings import warn +from . import api_version from . import _array_module as xp from ._array_module import _UndefinedStub from .stubs import name_to_func @@ -15,10 +16,12 @@ "uint_dtypes", "all_int_dtypes", "float_dtypes", + "real_dtypes", "numeric_dtypes", "all_dtypes", - "dtype_to_name", + "all_float_dtypes", "bool_and_all_int_dtypes", + "dtype_to_name", "dtype_to_scalars", "is_int_dtype", "is_float_dtype", @@ -27,9 +30,11 @@ "default_int", "default_uint", "default_float", + "default_complex", "promotion_table", "dtype_nbits", "dtype_signed", + "dtype_components", "func_in_dtypes", "func_returns_bool", "binary_op_to_symbol", @@ -86,15 +91,25 @@ def __repr__(self): _uint_names = ("uint8", "uint16", "uint32", "uint64") _int_names = ("int8", "int16", "int32", "int64") _float_names = ("float32", "float64") -_dtype_names = ("bool",) + _uint_names + _int_names + _float_names +_real_names = _uint_names + _int_names + _float_names +_complex_names = ("complex64", "complex128") +_numeric_names = _real_names + _complex_names +_dtype_names = ("bool",) + _numeric_names uint_dtypes = tuple(getattr(xp, name) for name in _uint_names) int_dtypes = tuple(getattr(xp, name) for name in _int_names) float_dtypes = tuple(getattr(xp, name) for name in _float_names) all_int_dtypes = uint_dtypes + int_dtypes -numeric_dtypes = all_int_dtypes + float_dtypes +real_dtypes = all_int_dtypes + float_dtypes +complex_dtypes = tuple(getattr(xp, name) for name in _complex_names) +numeric_dtypes = real_dtypes +if api_version > "2021.12": + numeric_dtypes += complex_dtypes all_dtypes = (xp.bool,) + numeric_dtypes +all_float_dtypes = float_dtypes +if api_version > "2021.12": + all_float_dtypes += complex_dtypes bool_and_all_int_dtypes = (xp.bool,) + all_int_dtypes @@ -121,7 +136,10 @@ def is_float_dtype(dtype): # See https://github.com/numpy/numpy/issues/18434 if dtype is None: return False - return dtype in float_dtypes + valid_dtypes = float_dtypes + if api_version > "2021.12": + valid_dtypes += complex_dtypes + return dtype in valid_dtypes def get_scalar_type(dtype: DataType) -> ScalarType: @@ -129,6 +147,8 @@ def get_scalar_type(dtype: DataType) -> ScalarType: return int elif is_float_dtype(dtype): return float + elif dtype in complex_dtypes: + return complex else: return bool @@ -157,7 +177,8 @@ class MinMax(NamedTuple): [(d, 8) for d in [xp.int8, xp.uint8]] + [(d, 16) for d in [xp.int16, xp.uint16]] + [(d, 32) for d in [xp.int32, xp.uint32, xp.float32]] - + [(d, 64) for d in [xp.int64, xp.uint64, xp.float64]] + + [(d, 64) for d in [xp.int64, xp.uint64, xp.float64, xp.complex64]] + + [(xp.complex128, 128)] ) @@ -166,6 +187,11 @@ class MinMax(NamedTuple): ) +dtype_components = EqualityMapping( + [(xp.complex64, xp.float32), (xp.complex128, xp.float64)] +) + + if isinstance(xp.asarray, _UndefinedStub): default_int = xp.int32 default_float = xp.float32 @@ -180,6 +206,15 @@ class MinMax(NamedTuple): default_float = xp.asarray(float()).dtype if default_float not in float_dtypes: warn(f"inferred default float is {default_float!r}, which is not a float") + if api_version > "2021.12": + default_complex = xp.asarray(complex()).dtype + if default_complex not in complex_dtypes: + warn( + f"inferred default complex is {default_complex!r}, " + "which is not a complex" + ) + else: + default_complex = None if dtype_nbits[default_int] == 32: default_uint = xp.uint32 else: @@ -226,6 +261,11 @@ class MinMax(NamedTuple): ((xp.float32, xp.float32), xp.float32), ((xp.float32, xp.float64), xp.float64), ((xp.float64, xp.float64), xp.float64), + # complex + ((xp.complex64, xp.complex64), xp.complex64), + ((xp.complex64, xp.complex128), xp.complex128), + ((xp.complex128, xp.complex128), xp.complex128), + ] _numeric_promotions += [((d2, d1), res) for (d1, d2), res in _numeric_promotions] _promotion_table = list(set(_numeric_promotions)) diff --git a/array_api_tests/hypothesis_helpers.py b/array_api_tests/hypothesis_helpers.py index 20cc0e03..04369214 100644 --- a/array_api_tests/hypothesis_helpers.py +++ b/array_api_tests/hypothesis_helpers.py @@ -4,7 +4,7 @@ from operator import mul from typing import Any, List, NamedTuple, Optional, Sequence, Tuple, Union -from hypothesis import assume +from hypothesis import assume, reject from hypothesis.strategies import (SearchStrategy, booleans, composite, floats, integers, just, lists, none, one_of, sampled_from, shared) @@ -26,27 +26,20 @@ # work for floating point dtypes as those are assumed to be defined in other # places in the tests. FILTER_UNDEFINED_DTYPES = True +# TODO: currently we assume this to be true - we probably can remove this completely +assert FILTER_UNDEFINED_DTYPES -integer_dtypes = sampled_from(dh.all_int_dtypes) -floating_dtypes = sampled_from(dh.float_dtypes) -numeric_dtypes = sampled_from(dh.numeric_dtypes) -integer_or_boolean_dtypes = sampled_from(dh.bool_and_all_int_dtypes) -boolean_dtypes = just(xp.bool) -dtypes = sampled_from(dh.all_dtypes) - -if FILTER_UNDEFINED_DTYPES: - integer_dtypes = integer_dtypes.filter(lambda x: not isinstance(x, _UndefinedStub)) - floating_dtypes = floating_dtypes.filter(lambda x: not isinstance(x, _UndefinedStub)) - numeric_dtypes = numeric_dtypes.filter(lambda x: not isinstance(x, _UndefinedStub)) - integer_or_boolean_dtypes = integer_or_boolean_dtypes.filter(lambda x: not - isinstance(x, _UndefinedStub)) - boolean_dtypes = boolean_dtypes.filter(lambda x: not isinstance(x, _UndefinedStub)) - dtypes = dtypes.filter(lambda x: not isinstance(x, _UndefinedStub)) +integer_dtypes = xps.integer_dtypes() | xps.unsigned_integer_dtypes() +floating_dtypes = xps.floating_dtypes() +numeric_dtypes = xps.numeric_dtypes() +integer_or_boolean_dtypes = xps.boolean_dtypes() | integer_dtypes +boolean_dtypes = xps.boolean_dtypes() +dtypes = xps.scalar_dtypes() shared_dtypes = shared(dtypes, key="dtype") shared_floating_dtypes = shared(floating_dtypes, key="dtype") -_dtype_categories = [(xp.bool,), dh.uint_dtypes, dh.int_dtypes, dh.float_dtypes] +_dtype_categories = [(xp.bool,), dh.uint_dtypes, dh.int_dtypes, dh.float_dtypes, dh.complex_dtypes] _sorted_dtypes = [d for category in _dtype_categories for d in category] def _dtypes_sorter(dtype_pair: Tuple[DataType, DataType]): @@ -106,6 +99,46 @@ def mutually_promotable_dtypes( return one_of(strats).map(tuple) +class OnewayPromotableDtypes(NamedTuple): + input_dtype: DataType + result_dtype: DataType + + +@composite +def oneway_promotable_dtypes( + draw, dtypes: Sequence[DataType] +) -> SearchStrategy[OnewayPromotableDtypes]: + """Return a strategy for input dtypes that promote to result dtypes.""" + d1, d2 = draw(mutually_promotable_dtypes(dtypes=dtypes)) + result_dtype = dh.result_type(d1, d2) + if d1 == result_dtype: + return OnewayPromotableDtypes(d2, d1) + elif d2 == result_dtype: + return OnewayPromotableDtypes(d1, d2) + else: + reject() + + +class OnewayBroadcastableShapes(NamedTuple): + input_shape: Shape + result_shape: Shape + + +@composite +def oneway_broadcastable_shapes(draw) -> SearchStrategy[OnewayBroadcastableShapes]: + """Return a strategy for input shapes that broadcast to result shapes.""" + result_shape = draw(shapes(min_side=1)) + input_shape = draw( + xps.broadcastable_shapes( + result_shape, + # Override defaults so bad shapes are less likely to be generated. + max_side=None if result_shape == () else max(result_shape), + max_dims=len(result_shape), + ).filter(lambda s: sh.broadcast_shapes(result_shape, s) == result_shape) + ) + return OnewayBroadcastableShapes(input_shape, result_shape) + + # shared() allows us to draw either the function or the function name and they # will both correspond to the same function. diff --git a/array_api_tests/meta/test_utils.py b/array_api_tests/meta/test_utils.py index 268a81aa..deeab264 100644 --- a/array_api_tests/meta/test_utils.py +++ b/array_api_tests/meta/test_utils.py @@ -4,15 +4,12 @@ from .. import _array_module as xp from .. import dtype_helpers as dh +from .. import hypothesis_helpers as hh from .. import shape_helpers as sh from .. import xps from ..test_creation_functions import frange from ..test_manipulation_functions import roll_ndindex -from ..test_operators_and_elementwise_functions import ( - mock_int_dtype, - oneway_broadcastable_shapes, - oneway_promotable_dtypes, -) +from ..test_operators_and_elementwise_functions import mock_int_dtype @pytest.mark.parametrize( @@ -115,11 +112,11 @@ def test_int_to_dtype(x, dtype): assert mock_int_dtype(x, dtype) == d -@given(oneway_promotable_dtypes(dh.all_dtypes)) +@given(hh.oneway_promotable_dtypes(dh.all_dtypes)) def test_oneway_promotable_dtypes(D): assert D.result_dtype == dh.result_type(*D) -@given(oneway_broadcastable_shapes()) +@given(hh.oneway_broadcastable_shapes()) def test_oneway_broadcastable_shapes(S): assert S.result_shape == sh.broadcast_shapes(*S) diff --git a/array_api_tests/pytest_helpers.py b/array_api_tests/pytest_helpers.py index 051a063f..0eb34180 100644 --- a/array_api_tests/pytest_helpers.py +++ b/array_api_tests/pytest_helpers.py @@ -1,3 +1,4 @@ +import cmath import math from inspect import getfullargspec from typing import Any, Dict, Optional, Sequence, Tuple, Union @@ -169,6 +170,23 @@ def assert_default_float(func_name: str, out_dtype: DataType): assert out_dtype == dh.default_float, msg +def assert_default_complex(func_name: str, out_dtype: DataType): + """ + Assert the output dtype is the default complex, e.g. + + >>> out = xp.asarray(4+2j) + >>> assert_default_complex('asarray', out.dtype) + + """ + f_dtype = dh.dtype_to_name[out_dtype] + f_default = dh.dtype_to_name[dh.default_complex] + msg = ( + f"out.dtype={f_dtype}, should be default " + f"complex dtype {f_default} [{func_name}()]" + ) + assert out_dtype == dh.default_complex, msg + + def assert_default_int(func_name: str, out_dtype: DataType): """ Assert the output dtype is the default int, e.g. @@ -345,12 +363,12 @@ def assert_scalar_equals( if type_ in [bool, int]: msg = f"{repr_name}={out}, but should be {expected} [{f_func}]" assert out == expected, msg - elif math.isnan(expected): + elif cmath.isnan(expected): msg = f"{repr_name}={out}, but should be {expected} [{f_func}]" - assert math.isnan(out), msg + assert cmath.isnan(out), msg else: msg = f"{repr_name}={out}, but should be roughly {expected} [{f_func}]" - assert math.isclose(out, expected, rel_tol=0.25, abs_tol=1), msg + assert cmath.isclose(out, expected, rel_tol=0.25, abs_tol=1), msg def assert_fill( @@ -368,12 +386,27 @@ def assert_fill( """ msg = f"out not filled with {fill_value} [{func_name}({fmt_kw(kw)})]\n{out=}" - if math.isnan(fill_value): + if cmath.isnan(fill_value): assert xp.all(xp.isnan(out)), msg else: assert xp.all(xp.equal(out, xp.asarray(fill_value, dtype=dtype))), msg +def _assert_float_element(at_out: Array, at_expected: Array, msg: str): + if xp.isnan(at_expected): + assert xp.isnan(at_out), msg + elif at_expected == 0.0 or at_expected == -0.0: + scalar_at_expected = float(at_expected) + scalar_at_out = float(at_out) + if is_pos_zero(scalar_at_expected): + assert is_pos_zero(scalar_at_out), msg + else: + assert is_neg_zero(scalar_at_expected) # sanity check + assert is_neg_zero(scalar_at_out), msg + else: + assert at_out == at_expected, msg + + def assert_array_elements( func_name: str, out: Array, expected: Array, /, *, out_repr: str = "out", **kw ): @@ -392,7 +425,17 @@ def assert_array_elements( dh.result_type(out.dtype, expected.dtype) # sanity check assert_shape(func_name, out.shape, expected.shape, **kw) # sanity check f_func = f"[{func_name}({fmt_kw(kw)})]" - if dh.is_float_dtype(out.dtype): + if out.dtype in dh.float_dtypes: + for idx in sh.ndindex(out.shape): + at_out = out[idx] + at_expected = expected[idx] + msg = ( + f"{sh.fmt_idx(out_repr, idx)}={at_out}, should be {at_expected} " + f"{f_func}" + ) + _assert_float_element(at_out, at_expected, msg) + elif out.dtype in dh.complex_dtypes: + assert (out.dtype in dh.complex_dtypes) == (expected.dtype in dh.complex_dtypes) for idx in sh.ndindex(out.shape): at_out = out[idx] at_expected = expected[idx] @@ -400,18 +443,8 @@ def assert_array_elements( f"{sh.fmt_idx(out_repr, idx)}={at_out}, should be {at_expected} " f"{f_func}" ) - if xp.isnan(at_expected): - assert xp.isnan(at_out), msg - elif at_expected == 0.0 or at_expected == -0.0: - scalar_at_expected = float(at_expected) - scalar_at_out = float(at_out) - if is_pos_zero(scalar_at_expected): - assert is_pos_zero(scalar_at_out), msg - else: - assert is_neg_zero(scalar_at_expected) # sanity check - assert is_neg_zero(scalar_at_out), msg - else: - assert at_out == at_expected, msg + _assert_float_element(at_out.real, at_expected.real, msg) + _assert_float_element(at_out.imag, at_expected.imag, msg) else: assert xp.all( out == expected diff --git a/array_api_tests/test_array_object.py b/array_api_tests/test_array_object.py index df3edb88..4a539fc7 100644 --- a/array_api_tests/test_array_object.py +++ b/array_api_tests/test_array_object.py @@ -1,3 +1,4 @@ +import cmath import math from itertools import product from typing import List, Sequence, Tuple, Union, get_args @@ -12,7 +13,6 @@ from . import pytest_helpers as ph from . import shape_helpers as sh from . import xps -from .test_operators_and_elementwise_functions import oneway_promotable_dtypes from .typing import DataType, Index, Param, Scalar, ScalarType, Shape pytestmark = pytest.mark.ci @@ -108,7 +108,7 @@ def test_getitem(shape, dtype, data): @given( shape=hh.shapes(), - dtypes=oneway_promotable_dtypes(dh.all_dtypes), + dtypes=hh.oneway_promotable_dtypes(dh.all_dtypes), data=st.data(), ) def test_setitem(shape, dtypes, data): @@ -136,7 +136,7 @@ def test_setitem(shape, dtypes, data): f_res = sh.fmt_idx("x", key) if isinstance(value, get_args(Scalar)): msg = f"{f_res}={res[key]!r}, but should be {value=} [__setitem__()]" - if math.isnan(value): + if cmath.isnan(value): assert xp.isnan(res[key]), msg else: assert res[key] == value, msg diff --git a/array_api_tests/test_creation_functions.py b/array_api_tests/test_creation_functions.py index 76a6a072..2ebd3b07 100644 --- a/array_api_tests/test_creation_functions.py +++ b/array_api_tests/test_creation_functions.py @@ -1,3 +1,4 @@ +import cmath import math from itertools import count from typing import Iterator, NamedTuple, Union @@ -12,7 +13,6 @@ from . import pytest_helpers as ph from . import shape_helpers as sh from . import xps -from .test_operators_and_elementwise_functions import oneway_promotable_dtypes from .typing import DataType, Scalar pytestmark = pytest.mark.ci @@ -79,7 +79,8 @@ def reals(min_value=None, max_value=None) -> st.SearchStrategy[Union[int, float] ) -@given(dtype=st.none() | hh.numeric_dtypes, data=st.data()) +# TODO: support testing complex dtypes +@given(dtype=st.none() | xps.real_dtypes(), data=st.data()) def test_arange(dtype, data): if dtype is None or dh.is_float_dtype(dtype): start = data.draw(reals(), label="start") @@ -128,6 +129,12 @@ def test_arange(dtype, data): assert m <= _start <= M assert m <= _stop <= M assert m <= step <= M + # Ignore ridiculous distances so we don't fail like + # + # >>> torch.arange(9132051521638391890, 0, -91320515216383920) + # RuntimeError: invalid size, possible overflow? + # + assume(abs(_start - _stop) < M // 2) r = frange(_start, _stop, step) size = len(r) @@ -248,15 +255,15 @@ def test_asarray_scalars(shape, data): def scalar_eq(s1: Scalar, s2: Scalar) -> bool: - if math.isnan(s1): - return math.isnan(s2) + if cmath.isnan(s1): + return cmath.isnan(s2) else: return s1 == s2 @given( shape=hh.shapes(), - dtypes=oneway_promotable_dtypes(dh.all_dtypes), + dtypes=hh.oneway_promotable_dtypes(dh.all_dtypes), data=st.data(), ) def test_asarray_arrays(shape, dtypes, data): @@ -308,7 +315,7 @@ def test_asarray_arrays(shape, dtypes, data): ), f"{f_out}, but should be {value} after x was mutated" -@given(hh.shapes(), hh.kwargs(dtype=st.none() | hh.shared_dtypes)) +@given(hh.shapes(), hh.kwargs(dtype=st.none() | xps.scalar_dtypes())) def test_empty(shape, kw): out = xp.empty(shape, **kw) if kw.get("dtype", None) is None: @@ -362,13 +369,15 @@ def test_eye(n_rows, n_cols, kw): default_unsafe_dtypes.extend([xp.uint32, xp.int64]) if dh.default_float == xp.float32: default_unsafe_dtypes.append(xp.float64) +if dh.default_complex == xp.complex64: + default_unsafe_dtypes.append(xp.complex64) default_safe_dtypes: st.SearchStrategy = xps.scalar_dtypes().filter( lambda d: d not in default_unsafe_dtypes ) @st.composite -def full_fill_values(draw) -> st.SearchStrategy[float]: +def full_fill_values(draw) -> st.SearchStrategy[Union[bool, int, float, complex]]: kw = draw( st.shared(hh.kwargs(dtype=st.none() | xps.scalar_dtypes()), key="full_kw") ) @@ -389,15 +398,28 @@ def test_full(shape, fill_value, kw): dtype = xp.bool elif isinstance(fill_value, int): dtype = dh.default_int - else: + elif isinstance(fill_value, float): dtype = dh.default_float + else: + assert isinstance(fill_value, complex) # sanity check + dtype = dh.default_complex + # Ignore large components so we don't fail like + # + # >>> torch.fill(complex(0.0, 3.402823466385289e+38)) + # RuntimeError: value cannot be converted to complex without overflow + # + M = dh.dtype_ranges[dh.dtype_components[dtype]].max + assume(all(abs(c) < math.sqrt(M) for c in [fill_value.real, fill_value.imag])) if kw.get("dtype", None) is None: if isinstance(fill_value, bool): - pass # TODO + assert out.dtype == xp.bool, f"{out.dtype=}, but should be bool [full()]" elif isinstance(fill_value, int): ph.assert_default_int("full", out.dtype) - else: + elif isinstance(fill_value, float): ph.assert_default_float("full", out.dtype) + else: + assert isinstance(fill_value, complex) # sanity check + ph.assert_default_complex("full", out.dtype) else: ph.assert_kw_dtype("full", kw["dtype"], out.dtype) ph.assert_shape("full", out.shape, shape, shape=shape) @@ -448,7 +470,7 @@ def test_linspace(num, dtype, endpoint, data): assume(not xp.isnan(xp.asarray(start - stop, dtype=_dtype))) # avoid generating very large distances # https://github.com/data-apis/array-api-tests/issues/125 - assume(abs(stop - start) < dh.dtype_ranges[_dtype].max) + assume(abs(stop - start) < math.sqrt(dh.dtype_ranges[_dtype].max)) kw = data.draw( hh.specified_kwargs( diff --git a/array_api_tests/test_data_type_functions.py b/array_api_tests/test_data_type_functions.py index 115ec9b9..5cd409ce 100644 --- a/array_api_tests/test_data_type_functions.py +++ b/array_api_tests/test_data_type_functions.py @@ -16,13 +16,18 @@ pytestmark = pytest.mark.ci +# TODO: test with complex dtypes +def non_complex_dtypes(): + return xps.boolean_dtypes() | xps.real_dtypes() + + def float32(n: Union[int, float]) -> float: return struct.unpack("!f", struct.pack("!f", float(n)))[0] @given( - x_dtype=xps.scalar_dtypes(), - dtype=xps.scalar_dtypes(), + x_dtype=non_complex_dtypes(), + dtype=non_complex_dtypes(), kw=hh.kwargs(copy=st.booleans()), data=st.data(), ) @@ -101,7 +106,7 @@ def test_broadcast_to(x, data): # TODO: test values -@given(_from=xps.scalar_dtypes(), to=xps.scalar_dtypes(), data=st.data()) +@given(_from=non_complex_dtypes(), to=non_complex_dtypes(), data=st.data()) def test_can_cast(_from, to, data): from_ = data.draw( st.just(_from) | xps.arrays(dtype=_from, shape=hh.shapes()), label="from_" @@ -114,10 +119,12 @@ def test_can_cast(_from, to, data): if _from == xp.bool: expected = to == xp.bool else: - for dtypes in [dh.all_int_dtypes, dh.float_dtypes]: + same_family = None + for dtypes in [dh.all_int_dtypes, dh.float_dtypes, dh.complex_dtypes]: if _from in dtypes: same_family = to in dtypes break + assert same_family is not None # sanity check if same_family: from_min, from_max = dh.dtype_ranges[_from] to_min, to_max = dh.dtype_ranges[to] diff --git a/array_api_tests/test_linalg.py b/array_api_tests/test_linalg.py index 321263d3..cc07e6b4 100644 --- a/array_api_tests/test_linalg.py +++ b/array_api_tests/test_linalg.py @@ -12,6 +12,7 @@ required, but we don't yet have a clean way to disable only those tests (see https://github.com/data-apis/array-api-tests/issues/25). """ +# TODO: test with complex dtypes where appropiate import pytest from hypothesis import assume, given @@ -20,7 +21,7 @@ from ndindex import iter_indices from .array_helpers import assert_exactly_equal, asarray -from .hypothesis_helpers import (xps, dtypes, shapes, kwargs, matrix_shapes, +from .hypothesis_helpers import (xps, shapes, kwargs, matrix_shapes, square_matrix_shapes, symmetric_matrices, positive_definite_matrices, MAX_ARRAY_SIZE, invertible_matrices, two_mutual_arrays, @@ -117,7 +118,7 @@ def test_cholesky(x, kw): @composite -def cross_args(draw, dtype_objects=dh.numeric_dtypes): +def cross_args(draw, dtype_objects=dh.real_dtypes): """ cross() requires two arrays with a size 3 in the 'axis' dimension @@ -192,7 +193,7 @@ def test_det(x): @pytest.mark.xp_extension('linalg') @given( - x=xps.arrays(dtype=dtypes, shape=matrix_shapes()), + x=xps.arrays(dtype=xps.real_dtypes(), shape=matrix_shapes()), # offset may produce an overflow if it is too large. Supporting offsets # that are way larger than the array shape isn't very important. kw=kwargs(offset=integers(-MAX_ARRAY_SIZE, MAX_ARRAY_SIZE)) @@ -277,7 +278,7 @@ def test_inv(x): # TODO: Test that the result is actually the inverse @given( - *two_mutual_arrays(dh.numeric_dtypes) + *two_mutual_arrays(dh.real_dtypes) ) def test_matmul(x1, x2): # TODO: Make this also test the @ operator @@ -366,7 +367,7 @@ def test_matrix_rank(x, kw): linalg.matrix_rank(x, **kw) @given( - x=xps.arrays(dtype=dtypes, shape=matrix_shapes()), + x=xps.arrays(dtype=xps.real_dtypes(), shape=matrix_shapes()), ) def test_matrix_transpose(x): res = _array_module.matrix_transpose(x) @@ -384,7 +385,7 @@ def test_matrix_transpose(x): @pytest.mark.xp_extension('linalg') @given( - *two_mutual_arrays(dtypes=dh.numeric_dtypes, + *two_mutual_arrays(dtypes=dh.real_dtypes, two_shapes=tuples(one_d_shapes, one_d_shapes)) ) def test_outer(x1, x2): @@ -573,7 +574,7 @@ def test_svdvals(x): @given( - dtypes=mutually_promotable_dtypes(dtypes=dh.numeric_dtypes), + dtypes=mutually_promotable_dtypes(dtypes=dh.real_dtypes), shape=shapes(), data=data(), ) @@ -590,7 +591,7 @@ def test_tensordot(dtypes, shape, data): @pytest.mark.xp_extension('linalg') @given( - x=xps.arrays(dtype=xps.numeric_dtypes(), shape=matrix_shapes()), + x=xps.arrays(dtype=xps.real_dtypes(), shape=matrix_shapes()), # offset may produce an overflow if it is too large. Supporting offsets # that are way larger than the array shape isn't very important. kw=kwargs(offset=integers(-MAX_ARRAY_SIZE, MAX_ARRAY_SIZE)) @@ -629,7 +630,7 @@ def true_trace(x_stack): @given( - dtypes=mutually_promotable_dtypes(dtypes=dh.numeric_dtypes), + dtypes=mutually_promotable_dtypes(dtypes=dh.real_dtypes), shape=shapes(min_dims=1), data=data(), ) diff --git a/array_api_tests/test_manipulation_functions.py b/array_api_tests/test_manipulation_functions.py index a30a0030..c5f19633 100644 --- a/array_api_tests/test_manipulation_functions.py +++ b/array_api_tests/test_manipulation_functions.py @@ -350,7 +350,6 @@ def test_stack(shape, dtypes, kw, data): out_indices = sh.ndindex(out.shape) for idx in sh.axis_ndindex(arrays[0].shape, axis=_axis): f_idx = ", ".join(str(i) if isinstance(i, int) else ":" for i in idx) - print(f"{f_idx=}") for x_num, x in enumerate(arrays, 1): indexed_x = x[idx] for x_idx in sh.ndindex(indexed_x.shape): diff --git a/array_api_tests/test_operators_and_elementwise_functions.py b/array_api_tests/test_operators_and_elementwise_functions.py index 967a43a6..691494cb 100644 --- a/array_api_tests/test_operators_and_elementwise_functions.py +++ b/array_api_tests/test_operators_and_elementwise_functions.py @@ -3,15 +3,15 @@ """ import math import operator +from copy import copy from enum import Enum, auto from typing import Callable, List, NamedTuple, Optional, Sequence, TypeVar, Union import pytest -from hypothesis import assume, given +from hypothesis import assume, given, reject from hypothesis import strategies as st -from hypothesis.control import reject -from . import _array_module as xp +from . import _array_module as xp, api_version from . import array_helpers as ah from . import dtype_helpers as dh from . import hypothesis_helpers as hh @@ -33,44 +33,11 @@ def boolean_and_all_integer_dtypes() -> st.SearchStrategy[DataType]: return xps.boolean_dtypes() | all_integer_dtypes() -class OnewayPromotableDtypes(NamedTuple): - input_dtype: DataType - result_dtype: DataType - - -@st.composite -def oneway_promotable_dtypes( - draw, dtypes: Sequence[DataType] -) -> st.SearchStrategy[OnewayPromotableDtypes]: - """Return a strategy for input dtypes that promote to result dtypes.""" - d1, d2 = draw(hh.mutually_promotable_dtypes(dtypes=dtypes)) - result_dtype = dh.result_type(d1, d2) - if d1 == result_dtype: - return OnewayPromotableDtypes(d2, d1) - elif d2 == result_dtype: - return OnewayPromotableDtypes(d1, d2) - else: - reject() - - -class OnewayBroadcastableShapes(NamedTuple): - input_shape: Shape - result_shape: Shape - - -@st.composite -def oneway_broadcastable_shapes(draw) -> st.SearchStrategy[OnewayBroadcastableShapes]: - """Return a strategy for input shapes that broadcast to result shapes.""" - result_shape = draw(hh.shapes(min_side=1)) - input_shape = draw( - xps.broadcastable_shapes( - result_shape, - # Override defaults so bad shapes are less likely to be generated. - max_side=None if result_shape == () else max(result_shape), - max_dims=len(result_shape), - ).filter(lambda s: sh.broadcast_shapes(result_shape, s) == result_shape) - ) - return OnewayBroadcastableShapes(input_shape, result_shape) +def all_floating_dtypes() -> st.SearchStrategy[DataType]: + strat = xps.floating_dtypes() + if api_version >= "2022.12": + strat |= xps.complex_dtypes() + return strat def mock_int_dtype(n: int, dtype: DataType) -> int: @@ -85,14 +52,25 @@ def mock_int_dtype(n: int, dtype: DataType) -> int: return n -def isclose(a: float, b: float, *, rel_tol: float = 0.25, abs_tol: float = 1) -> bool: +def isclose( + a: float, + b: float, + M: float, + *, + rel_tol: float = 0.25, + abs_tol: float = 1, +) -> bool: """Wraps math.isclose with very generous defaults. This is useful for many floating-point operations where the spec does not make accuracy requirements. """ - if not (math.isfinite(a) and math.isfinite(b)): - raise ValueError(f"{a=} and {b=}, but input must be finite") + if math.isnan(a) or math.isnan(b): + raise ValueError(f"{a=} and {b=}, but input must be non-NaN") + if math.isinf(a): + return math.isinf(b) or abs(b) > math.log(M) + elif math.isinf(b): + return math.isinf(a) or abs(a) > math.log(M) return math.isclose(a, b, rel_tol=rel_tol, abs_tol=abs_tol) @@ -246,8 +224,16 @@ def unary_assert_against_refimpl( expr_template = func_name + "({})={}" in_stype = dh.get_scalar_type(in_.dtype) if res_stype is None: - res_stype = in_stype - m, M = dh.dtype_ranges.get(res.dtype, (None, None)) + res_stype = dh.get_scalar_type(res.dtype) + if res.dtype == xp.bool: + m, M = (None, None) + elif res.dtype in dh.complex_dtypes: + m, M = dh.dtype_ranges[dh.dtype_components[res.dtype]] + else: + m, M = dh.dtype_ranges[res.dtype] + if in_.dtype in dh.complex_dtypes: + component_filter = copy(filter_) + filter_ = lambda s: component_filter(s.real) and component_filter(s.imag) for idx in sh.ndindex(in_.shape): scalar_i = in_stype(in_[idx]) if not filter_(scalar_i): @@ -257,18 +243,28 @@ def unary_assert_against_refimpl( except Exception: continue if res.dtype != xp.bool: - assert m is not None and M is not None # for mypy - if expected <= m or expected >= M: - continue + if res.dtype in dh.complex_dtypes: + if expected.real <= m or expected.real >= M: + continue + if expected.imag <= m or expected.imag >= M: + continue + else: + if expected <= m or expected >= M: + continue scalar_o = res_stype(res[idx]) f_i = sh.fmt_idx("x", idx) f_o = sh.fmt_idx("out", idx) expr = expr_template.format(f_i, expected) - if strict_check == False or dh.is_float_dtype(res.dtype): - assert isclose(scalar_o, expected), ( + if strict_check == False or res.dtype in dh.all_float_dtypes: + msg = ( f"{f_o}={scalar_o}, but should be roughly {expr} [{func_name}()]\n" f"{f_i}={scalar_i}" ) + if res.dtype in dh.complex_dtypes: + assert isclose(scalar_o.real, expected.real, M), msg + assert isclose(scalar_o.imag, expected.imag, M), msg + else: + assert isclose(scalar_o, expected, M), msg else: assert scalar_o == expected, ( f"{f_o}={scalar_o}, but should be {expr} [{func_name}()]\n" @@ -299,9 +295,19 @@ def binary_assert_against_refimpl( if expr_template is None: expr_template = func_name + "({}, {})={}" in_stype = dh.get_scalar_type(left.dtype) + if res_stype is None: + res_stype = dh.get_scalar_type(left.dtype) if res_stype is None: res_stype = in_stype - m, M = dh.dtype_ranges.get(res.dtype, (None, None)) + if res.dtype == xp.bool: + m, M = (None, None) + elif res.dtype in dh.complex_dtypes: + m, M = dh.dtype_ranges[dh.dtype_components[res.dtype]] + else: + m, M = dh.dtype_ranges[res.dtype] + if left.dtype in dh.complex_dtypes: + component_filter = copy(filter_) + filter_ = lambda s: component_filter(s.real) and component_filter(s.imag) for l_idx, r_idx, o_idx in sh.iter_indices(left.shape, right.shape, res.shape): scalar_l = in_stype(left[l_idx]) scalar_r = in_stype(right[r_idx]) @@ -312,19 +318,29 @@ def binary_assert_against_refimpl( except Exception: continue if res.dtype != xp.bool: - assert m is not None and M is not None # for mypy - if expected <= m or expected >= M: - continue + if res.dtype in dh.complex_dtypes: + if expected.real <= m or expected.real >= M: + continue + if expected.imag <= m or expected.imag >= M: + continue + else: + if expected <= m or expected >= M: + continue scalar_o = res_stype(res[o_idx]) f_l = sh.fmt_idx(left_sym, l_idx) f_r = sh.fmt_idx(right_sym, r_idx) f_o = sh.fmt_idx(res_name, o_idx) expr = expr_template.format(f_l, f_r, expected) - if strict_check == False or dh.is_float_dtype(res.dtype): - assert isclose(scalar_o, expected), ( + if strict_check == False or res.dtype in dh.all_float_dtypes: + msg = ( f"{f_o}={scalar_o}, but should be roughly {expr} [{func_name}()]\n" f"{f_l}={scalar_l}, {f_r}={scalar_r}" ) + if res.dtype in dh.complex_dtypes: + assert isclose(scalar_o.real, expected.real, M), msg + assert isclose(scalar_o.imag, expected.imag, M), msg + else: + assert isclose(scalar_o, expected, M), msg else: assert scalar_o == expected, ( f"{f_o}={scalar_o}, but should be {expr} [{func_name}()]\n" @@ -351,33 +367,53 @@ def right_scalar_assert_against_refimpl( See unary_assert_against_refimpl for more information. """ + if left.dtype in dh.complex_dtypes: + component_filter = copy(filter_) + filter_ = lambda s: component_filter(s.real) and component_filter(s.imag) if filter_(right): return # short-circuit here as there will be nothing to test in_stype = dh.get_scalar_type(left.dtype) + if res_stype is None: + res_stype = dh.get_scalar_type(left.dtype) if res_stype is None: res_stype = in_stype - m, M = dh.dtype_ranges.get(left.dtype, (None, None)) + if res.dtype == xp.bool: + m, M = (None, None) + elif left.dtype in dh.complex_dtypes: + m, M = dh.dtype_ranges[dh.dtype_components[left.dtype]] + else: + m, M = dh.dtype_ranges[left.dtype] for idx in sh.ndindex(res.shape): scalar_l = in_stype(left[idx]) - if not filter_(scalar_l): + if not (filter_(scalar_l) and filter_(right)): continue try: expected = refimpl(scalar_l, right) except Exception: continue if left.dtype != xp.bool: - assert m is not None and M is not None # for mypy - if expected <= m or expected >= M: - continue + if res.dtype in dh.complex_dtypes: + if expected.real <= m or expected.real >= M: + continue + if expected.imag <= m or expected.imag >= M: + continue + else: + if expected <= m or expected >= M: + continue scalar_o = res_stype(res[idx]) f_l = sh.fmt_idx(left_sym, idx) f_o = sh.fmt_idx(res_name, idx) expr = expr_template.format(f_l, right, expected) - if strict_check == False or dh.is_float_dtype(res.dtype): - assert isclose(scalar_o, expected), ( + if strict_check == False or res.dtype in dh.all_float_dtypes: + msg = ( f"{f_o}={scalar_o}, but should be roughly {expr} [{func_name}()]\n" f"{f_l}={scalar_l}" ) + if res.dtype in dh.complex_dtypes: + assert isclose(scalar_o.real, expected.real, M), msg + assert isclose(scalar_o.imag, expected.imag, M), msg + else: + assert isclose(scalar_o, expected, M), msg else: assert scalar_o == expected, ( f"{f_o}={scalar_o}, but should be {expr} [{func_name}()]\n" @@ -418,8 +454,16 @@ def __repr__(self): def make_unary_params( - elwise_func_name: str, dtypes_strat: st.SearchStrategy[DataType] + elwise_func_name: str, + dtypes: Sequence[DataType], + *, + min_version: str = "2021.12", ) -> List[Param[UnaryParamContext]]: + if hh.FILTER_UNDEFINED_DTYPES: + dtypes = [d for d in dtypes if not isinstance(d, xp._UndefinedStub)] + if api_version < "2022.12": + dtypes = [d for d in dtypes if d not in dh.complex_dtypes] + dtypes_strat = st.sampled_from(dtypes) strat = xps.arrays(dtype=dtypes_strat, shape=hh.shapes()) func_ctx = UnaryParamContext( func_name=elwise_func_name, func=getattr(xp, elwise_func_name), strat=strat @@ -428,7 +472,16 @@ def make_unary_params( op_ctx = UnaryParamContext( func_name=op_name, func=lambda x: getattr(x, op_name)(), strat=strat ) - return [pytest.param(func_ctx, id=func_ctx.id), pytest.param(op_ctx, id=op_ctx.id)] + if api_version < min_version: + marks = pytest.mark.skip( + reason=f"requires ARRAY_API_TESTS_VERSION=>{min_version}" + ) + else: + marks = () + return [ + pytest.param(func_ctx, id=func_ctx.id, marks=marks), + pytest.param(op_ctx, id=op_ctx.id, marks=marks), + ] class FuncType(Enum): @@ -463,7 +516,7 @@ def make_binary_params( ) -> List[Param[BinaryParamContext]]: if hh.FILTER_UNDEFINED_DTYPES: dtypes = [d for d in dtypes if not isinstance(d, xp._UndefinedStub)] - shared_oneway_dtypes = st.shared(oneway_promotable_dtypes(dtypes)) + shared_oneway_dtypes = st.shared(hh.oneway_promotable_dtypes(dtypes)) left_dtypes = shared_oneway_dtypes.map(lambda D: D.result_dtype) right_dtypes = shared_oneway_dtypes.map(lambda D: D.input_dtype) @@ -482,7 +535,7 @@ def make_param( right_strat = right_dtypes.flatmap(lambda d: xps.from_dtype(d, **finite_kw)) else: if func_type is FuncType.IOP: - shared_oneway_shapes = st.shared(oneway_broadcastable_shapes()) + shared_oneway_shapes = st.shared(hh.oneway_broadcastable_shapes()) left_strat = xps.arrays( dtype=left_dtypes, shape=shared_oneway_shapes.map(lambda S: S.result_shape), @@ -633,7 +686,7 @@ def binary_param_assert_against_refimpl( ) -@pytest.mark.parametrize("ctx", make_unary_params("abs", xps.numeric_dtypes())) +@pytest.mark.parametrize("ctx", make_unary_params("abs", dh.numeric_dtypes)) @given(data=st.data()) def test_abs(ctx, data): x = data.draw(ctx.strat, label="x") @@ -643,13 +696,17 @@ def test_abs(ctx, data): out = ctx.func(x) - ph.assert_dtype(ctx.func_name, x.dtype, out.dtype) + if x.dtype in dh.complex_dtypes: + assert out.dtype == dh.dtype_components[x.dtype] + else: + ph.assert_dtype(ctx.func_name, x.dtype, out.dtype) ph.assert_shape(ctx.func_name, out.shape, x.shape) unary_assert_against_refimpl( ctx.func_name, x, out, abs, # type: ignore + res_stype=float if x.dtype in dh.complex_dtypes else None, expr_template="abs({})={}", filter_=lambda s: ( s == float("infinity") or (math.isfinite(s) and not ph.is_neg_zero(s)) @@ -657,7 +714,7 @@ def test_abs(ctx, data): ) -@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) +@given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes())) def test_acos(x): out = xp.acos(x) ph.assert_dtype("acos", x.dtype, out.dtype) @@ -667,7 +724,7 @@ def test_acos(x): ) -@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) +@given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes())) def test_acosh(x): out = xp.acosh(x) ph.assert_dtype("acosh", x.dtype, out.dtype) @@ -693,7 +750,7 @@ def test_add(ctx, data): binary_param_assert_against_refimpl(ctx, left, right, res, "+", operator.add) -@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) +@given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes())) def test_asin(x): out = xp.asin(x) ph.assert_dtype("asin", x.dtype, out.dtype) @@ -703,7 +760,7 @@ def test_asin(x): ) -@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) +@given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes())) def test_asinh(x): out = xp.asinh(x) ph.assert_dtype("asinh", x.dtype, out.dtype) @@ -711,7 +768,7 @@ def test_asinh(x): unary_assert_against_refimpl("asinh", x, out, math.asinh) -@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) +@given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes())) def test_atan(x): out = xp.atan(x) ph.assert_dtype("atan", x.dtype, out.dtype) @@ -727,7 +784,7 @@ def test_atan2(x1, x2): binary_assert_against_refimpl("atan2", x1, x2, out, math.atan2) -@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) +@given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes())) def test_atanh(x): out = xp.atanh(x) ph.assert_dtype("atanh", x.dtype, out.dtype) @@ -783,7 +840,7 @@ def test_bitwise_left_shift(ctx, data): @pytest.mark.parametrize( - "ctx", make_unary_params("bitwise_invert", boolean_and_all_integer_dtypes()) + "ctx", make_unary_params("bitwise_invert", dh.bool_and_all_int_dtypes) ) @given(data=st.data()) def test_bitwise_invert(ctx, data): @@ -859,7 +916,7 @@ def test_bitwise_xor(ctx, data): binary_param_assert_against_refimpl(ctx, left, right, res, "^", refimpl) -@given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes())) +@given(xps.arrays(dtype=xps.real_dtypes(), shape=hh.shapes())) def test_ceil(x): out = xp.ceil(x) ph.assert_dtype("ceil", x.dtype, out.dtype) @@ -867,7 +924,17 @@ def test_ceil(x): unary_assert_against_refimpl("ceil", x, out, math.ceil, strict_check=True) -@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) +if api_version >= "2022.12": + + @given(xps.arrays(dtype=xps.complex_dtypes(), shape=hh.shapes())) + def test_conj(x): + out = xp.conj(x) + ph.assert_dtype("conj", x.dtype, out.dtype) + ph.assert_shape("conj", out.shape, x.shape) + unary_assert_against_refimpl("conj", x, out, operator.methodcaller("conjugate")) + + +@given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes())) def test_cos(x): out = xp.cos(x) ph.assert_dtype("cos", x.dtype, out.dtype) @@ -875,7 +942,7 @@ def test_cos(x): unary_assert_against_refimpl("cos", x, out, math.cos) -@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) +@given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes())) def test_cosh(x): out = xp.cosh(x) ph.assert_dtype("cosh", x.dtype, out.dtype) @@ -883,18 +950,20 @@ def test_cosh(x): unary_assert_against_refimpl("cosh", x, out, math.cosh) -@pytest.mark.parametrize("ctx", make_binary_params("divide", dh.float_dtypes)) +@pytest.mark.parametrize("ctx", make_binary_params("divide", dh.all_float_dtypes)) @given(data=st.data()) def test_divide(ctx, data): left = data.draw(ctx.left_strat, label=ctx.left_sym) right = data.draw(ctx.right_strat, label=ctx.right_sym) if ctx.right_is_scalar: - assume + assume # TODO: assume what? res = ctx.func(left, right) binary_param_assert_dtype(ctx, left, right, res) binary_param_assert_shape(ctx, left, right, res) + if res.dtype in dh.complex_dtypes: + return # TOOD: handle complex division binary_param_assert_against_refimpl( ctx, left, @@ -934,7 +1003,7 @@ def test_equal(ctx, data): ) -@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) +@given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes())) def test_exp(x): out = xp.exp(x) ph.assert_dtype("exp", x.dtype, out.dtype) @@ -942,7 +1011,7 @@ def test_exp(x): unary_assert_against_refimpl("exp", x, out, math.exp) -@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) +@given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes())) def test_expm1(x): out = xp.expm1(x) ph.assert_dtype("expm1", x.dtype, out.dtype) @@ -950,7 +1019,7 @@ def test_expm1(x): unary_assert_against_refimpl("expm1", x, out, math.expm1) -@given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes())) +@given(xps.arrays(dtype=xps.real_dtypes(), shape=hh.shapes())) def test_floor(x): out = xp.floor(x) ph.assert_dtype("floor", x.dtype, out.dtype) @@ -958,7 +1027,7 @@ def test_floor(x): unary_assert_against_refimpl("floor", x, out, math.floor, strict_check=True) -@pytest.mark.parametrize("ctx", make_binary_params("floor_divide", dh.numeric_dtypes)) +@pytest.mark.parametrize("ctx", make_binary_params("floor_divide", dh.real_dtypes)) @given(data=st.data()) def test_floor_divide(ctx, data): left = data.draw( @@ -977,7 +1046,7 @@ def test_floor_divide(ctx, data): binary_param_assert_against_refimpl(ctx, left, right, res, "//", operator.floordiv) -@pytest.mark.parametrize("ctx", make_binary_params("greater", dh.numeric_dtypes)) +@pytest.mark.parametrize("ctx", make_binary_params("greater", dh.real_dtypes)) @given(data=st.data()) def test_greater(ctx, data): left = data.draw(ctx.left_strat, label=ctx.left_sym) @@ -997,7 +1066,7 @@ def test_greater(ctx, data): ) -@pytest.mark.parametrize("ctx", make_binary_params("greater_equal", dh.numeric_dtypes)) +@pytest.mark.parametrize("ctx", make_binary_params("greater_equal", dh.real_dtypes)) @given(data=st.data()) def test_greater_equal(ctx, data): left = data.draw(ctx.left_strat, label=ctx.left_sym) @@ -1017,6 +1086,16 @@ def test_greater_equal(ctx, data): ) +if api_version >= "2022.12": + + @given(xps.arrays(dtype=xps.complex_dtypes(), shape=hh.shapes())) + def test_imag(x): + out = xp.imag(x) + ph.assert_dtype("imag", x.dtype, out.dtype, dh.dtype_components[x.dtype]) + ph.assert_shape("imag", out.shape, x.shape) + unary_assert_against_refimpl("imag", x, out, operator.attrgetter("imag")) + + @given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes())) def test_isfinite(x): out = xp.isfinite(x) @@ -1041,7 +1120,7 @@ def test_isnan(x): unary_assert_against_refimpl("isnan", x, out, math.isnan, res_stype=bool) -@pytest.mark.parametrize("ctx", make_binary_params("less", dh.numeric_dtypes)) +@pytest.mark.parametrize("ctx", make_binary_params("less", dh.real_dtypes)) @given(data=st.data()) def test_less(ctx, data): left = data.draw(ctx.left_strat, label=ctx.left_sym) @@ -1061,7 +1140,7 @@ def test_less(ctx, data): ) -@pytest.mark.parametrize("ctx", make_binary_params("less_equal", dh.numeric_dtypes)) +@pytest.mark.parametrize("ctx", make_binary_params("less_equal", dh.real_dtypes)) @given(data=st.data()) def test_less_equal(ctx, data): left = data.draw(ctx.left_strat, label=ctx.left_sym) @@ -1081,7 +1160,7 @@ def test_less_equal(ctx, data): ) -@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) +@given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes())) def test_log(x): out = xp.log(x) ph.assert_dtype("log", x.dtype, out.dtype) @@ -1091,7 +1170,7 @@ def test_log(x): ) -@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) +@given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes())) def test_log1p(x): out = xp.log1p(x) ph.assert_dtype("log1p", x.dtype, out.dtype) @@ -1101,7 +1180,7 @@ def test_log1p(x): ) -@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) +@given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes())) def test_log2(x): out = xp.log2(x) ph.assert_dtype("log2", x.dtype, out.dtype) @@ -1111,7 +1190,7 @@ def test_log2(x): ) -@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) +@given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes())) def test_log10(x): out = xp.log10(x) ph.assert_dtype("log10", x.dtype, out.dtype) @@ -1187,9 +1266,7 @@ def test_multiply(ctx, data): # TODO: clarify if uints are acceptable, adjust accordingly -@pytest.mark.parametrize( - "ctx", make_unary_params("negative", xps.integer_dtypes() | xps.floating_dtypes()) -) +@pytest.mark.parametrize("ctx", make_unary_params("negative", dh.numeric_dtypes)) @given(data=st.data()) def test_negative(ctx, data): x = data.draw(ctx.strat, label="x") @@ -1226,7 +1303,7 @@ def test_not_equal(ctx, data): ) -@pytest.mark.parametrize("ctx", make_unary_params("positive", xps.numeric_dtypes())) +@pytest.mark.parametrize("ctx", make_unary_params("positive", dh.numeric_dtypes)) @given(data=st.data()) def test_positive(ctx, data): x = data.draw(ctx.strat, label="x") @@ -1260,7 +1337,17 @@ def test_pow(ctx, data): # Values testing pow is too finicky -@pytest.mark.parametrize("ctx", make_binary_params("remainder", dh.numeric_dtypes)) +if api_version >= "2022.12": + + @given(xps.arrays(dtype=xps.complex_dtypes(), shape=hh.shapes())) + def test_real(x): + out = xp.real(x) + ph.assert_dtype("real", x.dtype, out.dtype, dh.dtype_components[x.dtype]) + ph.assert_shape("real", out.shape, x.shape) + unary_assert_against_refimpl("real", x, out, operator.attrgetter("real")) + + +@pytest.mark.parametrize("ctx", make_binary_params("remainder", dh.real_dtypes)) @given(data=st.data()) def test_remainder(ctx, data): left = data.draw(ctx.left_strat, label=ctx.left_sym) @@ -1295,7 +1382,7 @@ def test_sign(x): ) -@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) +@given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes())) def test_sin(x): out = xp.sin(x) ph.assert_dtype("sin", x.dtype, out.dtype) @@ -1303,7 +1390,7 @@ def test_sin(x): unary_assert_against_refimpl("sin", x, out, math.sin) -@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) +@given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes())) def test_sinh(x): out = xp.sinh(x) ph.assert_dtype("sinh", x.dtype, out.dtype) @@ -1317,11 +1404,11 @@ def test_square(x): ph.assert_dtype("square", x.dtype, out.dtype) ph.assert_shape("square", out.shape, x.shape) unary_assert_against_refimpl( - "square", x, out, lambda s: s ** 2, expr_template="{}²={}" + "square", x, out, lambda s: s**2, expr_template="{}²={}" ) -@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) +@given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes())) def test_sqrt(x): out = xp.sqrt(x) ph.assert_dtype("sqrt", x.dtype, out.dtype) @@ -1347,7 +1434,7 @@ def test_subtract(ctx, data): binary_param_assert_against_refimpl(ctx, left, right, res, "-", operator.sub) -@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) +@given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes())) def test_tan(x): out = xp.tan(x) ph.assert_dtype("tan", x.dtype, out.dtype) @@ -1355,7 +1442,7 @@ def test_tan(x): unary_assert_against_refimpl("tan", x, out, math.tan) -@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) +@given(xps.arrays(dtype=all_floating_dtypes(), shape=hh.shapes())) def test_tanh(x): out = xp.tanh(x) ph.assert_dtype("tanh", x.dtype, out.dtype) @@ -1363,7 +1450,7 @@ def test_tanh(x): unary_assert_against_refimpl("tanh", x, out, math.tanh) -@given(xps.arrays(dtype=hh.numeric_dtypes, shape=xps.array_shapes())) +@given(xps.arrays(dtype=xps.real_dtypes(), shape=xps.array_shapes())) def test_trunc(x): out = xp.trunc(x) ph.assert_dtype("trunc", x.dtype, out.dtype) diff --git a/array_api_tests/test_set_functions.py b/array_api_tests/test_set_functions.py index 5e415858..193087d9 100644 --- a/array_api_tests/test_set_functions.py +++ b/array_api_tests/test_set_functions.py @@ -1,4 +1,5 @@ # TODO: disable if opted out, refactor things +import cmath import math from collections import Counter, defaultdict @@ -61,7 +62,7 @@ def test_unique_all(x): for idx in sh.ndindex(out.indices.shape): val = scalar_type(out.values[idx]) - if math.isnan(val): + if cmath.isnan(val): break i = int(out.indices[idx]) expected = firsts[val] @@ -88,7 +89,7 @@ def test_unique_all(x): for idx in sh.ndindex(out.values.shape): val = scalar_type(out.values[idx]) count = int(out.counts[idx]) - if math.isnan(val): + if cmath.isnan(val): nans += 1 assert count == 1, ( f"out.counts[{idx}]={count} for out.values[{idx}]={val}, " @@ -225,7 +226,7 @@ def test_unique_values(x): nans = 0 for idx in sh.ndindex(out.shape): val = scalar_type(out[idx]) - if math.isnan(val): + if cmath.isnan(val): nans += 1 else: assert val in distinct, f"out[{idx}]={val}, but {val} not in input array" diff --git a/array_api_tests/test_sorting_functions.py b/array_api_tests/test_sorting_functions.py index 7c5a1411..69149c1b 100644 --- a/array_api_tests/test_sorting_functions.py +++ b/array_api_tests/test_sorting_functions.py @@ -1,4 +1,4 @@ -import math +import cmath from typing import Set import pytest @@ -26,7 +26,7 @@ def assert_scalar_in_set( **kw, ): out_repr = "out" if idx == () else f"out[{idx}]" - if math.isnan(out): + if cmath.isnan(out): raise NotImplementedError() msg = f"{out_repr}={out}, but should be in {set_} [{func_name}({ph.fmt_kw(kw)})]" assert out in set_, msg diff --git a/array_api_tests/test_special_cases.py b/array_api_tests/test_special_cases.py index 9999d9b0..2e4167ce 100644 --- a/array_api_tests/test_special_cases.py +++ b/array_api_tests/test_special_cases.py @@ -35,10 +35,6 @@ from . import xps from ._array_module import mod as xp from .stubs import category_to_funcs -from .test_operators_and_elementwise_functions import ( - oneway_broadcastable_shapes, - oneway_promotable_dtypes, -) pytestmark = pytest.mark.ci @@ -1281,8 +1277,8 @@ def test_binary(func_name, func, case, x1, x2, data): @pytest.mark.parametrize("iop_name, iop, case", iop_params) @given( - oneway_dtypes=oneway_promotable_dtypes(dh.float_dtypes), - oneway_shapes=oneway_broadcastable_shapes(), + oneway_dtypes=hh.oneway_promotable_dtypes(dh.float_dtypes), + oneway_shapes=hh.oneway_broadcastable_shapes(), data=st.data(), ) def test_iop(iop_name, iop, case, oneway_dtypes, oneway_shapes, data): diff --git a/array_api_tests/test_statistical_functions.py b/array_api_tests/test_statistical_functions.py index 4371fb07..2d433dc6 100644 --- a/array_api_tests/test_statistical_functions.py +++ b/array_api_tests/test_statistical_functions.py @@ -1,3 +1,4 @@ +import cmath import math from typing import Optional @@ -162,7 +163,7 @@ def test_prod(x, data): scalar_type = dh.get_scalar_type(out.dtype) for indices, out_idx in zip(sh.axes_ndindex(x.shape, _axes), sh.ndindex(out.shape)): prod = scalar_type(out[out_idx]) - assume(math.isfinite(prod)) + assume(cmath.isfinite(prod)) elements = [] for idx in indices: s = scalar_type(x[idx]) @@ -267,7 +268,7 @@ def test_sum(x, data): scalar_type = dh.get_scalar_type(out.dtype) for indices, out_idx in zip(sh.axes_ndindex(x.shape, _axes), sh.ndindex(out.shape)): sum_ = scalar_type(out[out_idx]) - assume(math.isfinite(sum_)) + assume(cmath.isfinite(sum_)) elements = [] for idx in indices: s = scalar_type(x[idx]) diff --git a/array_api_tests/typing.py b/array_api_tests/typing.py index f0ed8c50..84311ff3 100644 --- a/array_api_tests/typing.py +++ b/array_api_tests/typing.py @@ -12,8 +12,8 @@ ] DataType = Type[Any] -Scalar = Union[bool, int, float] -ScalarType = Union[Type[bool], Type[int], Type[float]] +Scalar = Union[bool, int, float, complex] +ScalarType = Union[Type[bool], Type[int], Type[float], Type[complex]] Array = Any Shape = Tuple[int, ...] AtomicIndex = Union[int, "ellipsis", slice, None] # noqa diff --git a/conftest.py b/conftest.py index e0453e40..9b7e7956 100644 --- a/conftest.py +++ b/conftest.py @@ -5,6 +5,7 @@ from pytest import mark from array_api_tests import _array_module as xp +from array_api_tests import api_version from array_api_tests._array_module import _UndefinedStub from reporting import pytest_metadata, pytest_json_modifyreport, add_extra_json_metadata # noqa @@ -59,6 +60,10 @@ def pytest_configure(config): "markers", "data_dependent_shapes: output shapes are dependent on inputs" ) config.addinivalue_line("markers", "ci: primary test") + config.addinivalue_line( + "markers", + "min_version(api_version): run when greater or equal to api_version", + ) # Hypothesis hypothesis_max_examples = config.getoption("--hypothesis-max-examples") disable_deadline = config.getoption("--hypothesis-disable-deadline") @@ -126,3 +131,13 @@ def pytest_collection_modifyitems(config, items): ci_mark = next((m for m in markers if m.name == "ci"), None) if ci_mark is None: item.add_marker(mark.skip(reason="disabled via --ci")) + # skip if test is for greater api_version + ver_mark = next((m for m in markers if m.name == "min_version"), None) + if ver_mark is not None: + min_version = ver_mark.args[0] + if api_version < min_version: + item.add_marker( + mark.skip( + reason=f"requires ARRAY_API_TESTS_VERSION=>{min_version}" + ) + ) diff --git a/requirements.txt b/requirements.txt index 07b8b189..bb33bc90 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ pytest pytest-json-report -hypothesis>=6.62.1 +hypothesis>=6.68.0 ndindex>=1.6