diff --git a/array_api_tests/array_helpers.py b/array_api_tests/array_helpers.py index 398f1994..b3ae583c 100644 --- a/array_api_tests/array_helpers.py +++ b/array_api_tests/array_helpers.py @@ -1,5 +1,3 @@ -import itertools - from ._array_module import (isnan, all, any, equal, not_equal, logical_and, logical_or, isfinite, greater, less, less_equal, zeros, ones, full, bool, int8, int16, int32, @@ -23,7 +21,7 @@ 'assert_isinf', 'positive_mathematical_sign', 'assert_positive_mathematical_sign', 'negative_mathematical_sign', 'assert_negative_mathematical_sign', 'same_sign', - 'assert_same_sign', 'ndindex', 'float64', + 'assert_same_sign', 'float64', 'asarray', 'full', 'true', 'false', 'isnan'] def zero(shape, dtype): @@ -319,13 +317,3 @@ def int_to_dtype(x, n, signed): if x & highest_bit: x = -((~x & mask) + 1) return x - -def ndindex(shape): - """ - Iterator of n-D indices to an array - - Yields tuples of integers to index every element of an array of shape - `shape`. Same as np.ndindex(). - - """ - return itertools.product(*[range(i) for i in shape]) diff --git a/array_api_tests/dtype_helpers.py b/array_api_tests/dtype_helpers.py index 65b4090a..ce749d9e 100644 --- a/array_api_tests/dtype_helpers.py +++ b/array_api_tests/dtype_helpers.py @@ -1,45 +1,45 @@ -from warnings import warn from functools import lru_cache from typing import NamedTuple, Tuple, Union +from warnings import warn from . import _array_module as xp from ._array_module import _UndefinedStub from .typing import DataType, ScalarType - __all__ = [ - 'int_dtypes', - 'uint_dtypes', - 'all_int_dtypes', - 'float_dtypes', - 'numeric_dtypes', - 'all_dtypes', - 'dtype_to_name', - 'bool_and_all_int_dtypes', - 'dtype_to_scalars', - 'is_int_dtype', - 'is_float_dtype', - 'get_scalar_type', - 'dtype_ranges', - 'default_int', - 'default_float', - 'promotion_table', - 'dtype_nbits', - 'dtype_signed', - 'func_in_dtypes', - 'func_returns_bool', - 'binary_op_to_symbol', - 'unary_op_to_symbol', - 'inplace_op_to_symbol', - 'op_to_func', - 'fmt_types', + "int_dtypes", + "uint_dtypes", + "all_int_dtypes", + "float_dtypes", + "numeric_dtypes", + "all_dtypes", + "dtype_to_name", + "bool_and_all_int_dtypes", + "dtype_to_scalars", + "is_int_dtype", + "is_float_dtype", + "get_scalar_type", + "dtype_ranges", + "default_int", + "default_uint", + "default_float", + "promotion_table", + "dtype_nbits", + "dtype_signed", + "func_in_dtypes", + "func_returns_bool", + "binary_op_to_symbol", + "unary_op_to_symbol", + "inplace_op_to_symbol", + "op_to_func", + "fmt_types", ] -_uint_names = ('uint8', 'uint16', 'uint32', 'uint64') -_int_names = ('int8', 'int16', 'int32', 'int64') -_float_names = ('float32', 'float64') -_dtype_names = ('bool',) + _uint_names + _int_names + _float_names +_uint_names = ("uint8", "uint16", "uint32", "uint64") +_int_names = ("int8", "int16", "int32", "int64") +_float_names = ("float32", "float64") +_dtype_names = ("bool",) + _uint_names + _int_names + _float_names uint_dtypes = tuple(getattr(xp, name) for name in _uint_names) @@ -101,17 +101,34 @@ class MinMax(NamedTuple): xp.uint64: MinMax(0, +18_446_744_073_709_551_615), } +dtype_nbits = { + **{d: 8 for d in [xp.int8, xp.uint8]}, + **{d: 16 for d in [xp.int16, xp.uint16]}, + **{d: 32 for d in [xp.int32, xp.uint32, xp.float32]}, + **{d: 64 for d in [xp.int64, xp.uint64, xp.float64]}, +} + + +dtype_signed = { + **{d: True for d in int_dtypes}, + **{d: False for d in uint_dtypes}, +} + if isinstance(xp.asarray, _UndefinedStub): default_int = xp.int32 default_float = xp.float32 warn( - 
'array module does not have attribute asarray. ' - 'default int is assumed int32, default float is assumed float32' + "array module does not have attribute asarray. " + "default int is assumed int32, default float is assumed float32" ) else: default_int = xp.asarray(int()).dtype default_float = xp.asarray(float()).dtype +if dtype_nbits[default_int] == 32: + default_uint = xp.uint32 +else: + default_uint = xp.uint64 _numeric_promotions = { @@ -173,200 +190,186 @@ def result_type(*dtypes: DataType): return result -dtype_nbits = { - **{d: 8 for d in [xp.int8, xp.uint8]}, - **{d: 16 for d in [xp.int16, xp.uint16]}, - **{d: 32 for d in [xp.int32, xp.uint32, xp.float32]}, - **{d: 64 for d in [xp.int64, xp.uint64, xp.float64]}, -} - - -dtype_signed = { - **{d: True for d in int_dtypes}, - **{d: False for d in uint_dtypes}, -} - - func_in_dtypes = { # elementwise - 'abs': numeric_dtypes, - 'acos': float_dtypes, - 'acosh': float_dtypes, - 'add': numeric_dtypes, - 'asin': float_dtypes, - 'asinh': float_dtypes, - 'atan': float_dtypes, - 'atan2': float_dtypes, - 'atanh': float_dtypes, - 'bitwise_and': bool_and_all_int_dtypes, - 'bitwise_invert': bool_and_all_int_dtypes, - 'bitwise_left_shift': all_int_dtypes, - 'bitwise_or': bool_and_all_int_dtypes, - 'bitwise_right_shift': all_int_dtypes, - 'bitwise_xor': bool_and_all_int_dtypes, - 'ceil': numeric_dtypes, - 'cos': float_dtypes, - 'cosh': float_dtypes, - 'divide': float_dtypes, - 'equal': all_dtypes, - 'exp': float_dtypes, - 'expm1': float_dtypes, - 'floor': numeric_dtypes, - 'floor_divide': numeric_dtypes, - 'greater': numeric_dtypes, - 'greater_equal': numeric_dtypes, - 'isfinite': numeric_dtypes, - 'isinf': numeric_dtypes, - 'isnan': numeric_dtypes, - 'less': numeric_dtypes, - 'less_equal': numeric_dtypes, - 'log': float_dtypes, - 'logaddexp': float_dtypes, - 'log10': float_dtypes, - 'log1p': float_dtypes, - 'log2': float_dtypes, - 'logical_and': (xp.bool,), - 'logical_not': (xp.bool,), - 'logical_or': (xp.bool,), - 'logical_xor': (xp.bool,), - 'multiply': numeric_dtypes, - 'negative': numeric_dtypes, - 'not_equal': all_dtypes, - 'positive': numeric_dtypes, - 'pow': float_dtypes, - 'remainder': numeric_dtypes, - 'round': numeric_dtypes, - 'sign': numeric_dtypes, - 'sin': float_dtypes, - 'sinh': float_dtypes, - 'sqrt': float_dtypes, - 'square': numeric_dtypes, - 'subtract': numeric_dtypes, - 'tan': float_dtypes, - 'tanh': float_dtypes, - 'trunc': numeric_dtypes, + "abs": numeric_dtypes, + "acos": float_dtypes, + "acosh": float_dtypes, + "add": numeric_dtypes, + "asin": float_dtypes, + "asinh": float_dtypes, + "atan": float_dtypes, + "atan2": float_dtypes, + "atanh": float_dtypes, + "bitwise_and": bool_and_all_int_dtypes, + "bitwise_invert": bool_and_all_int_dtypes, + "bitwise_left_shift": all_int_dtypes, + "bitwise_or": bool_and_all_int_dtypes, + "bitwise_right_shift": all_int_dtypes, + "bitwise_xor": bool_and_all_int_dtypes, + "ceil": numeric_dtypes, + "cos": float_dtypes, + "cosh": float_dtypes, + "divide": float_dtypes, + "equal": all_dtypes, + "exp": float_dtypes, + "expm1": float_dtypes, + "floor": numeric_dtypes, + "floor_divide": numeric_dtypes, + "greater": numeric_dtypes, + "greater_equal": numeric_dtypes, + "isfinite": numeric_dtypes, + "isinf": numeric_dtypes, + "isnan": numeric_dtypes, + "less": numeric_dtypes, + "less_equal": numeric_dtypes, + "log": float_dtypes, + "logaddexp": float_dtypes, + "log10": float_dtypes, + "log1p": float_dtypes, + "log2": float_dtypes, + "logical_and": (xp.bool,), + "logical_not": (xp.bool,), + 
"logical_or": (xp.bool,), + "logical_xor": (xp.bool,), + "multiply": numeric_dtypes, + "negative": numeric_dtypes, + "not_equal": all_dtypes, + "positive": numeric_dtypes, + "pow": float_dtypes, + "remainder": numeric_dtypes, + "round": numeric_dtypes, + "sign": numeric_dtypes, + "sin": float_dtypes, + "sinh": float_dtypes, + "sqrt": float_dtypes, + "square": numeric_dtypes, + "subtract": numeric_dtypes, + "tan": float_dtypes, + "tanh": float_dtypes, + "trunc": numeric_dtypes, # searching - 'where': all_dtypes, + "where": all_dtypes, } func_returns_bool = { # elementwise - 'abs': False, - 'acos': False, - 'acosh': False, - 'add': False, - 'asin': False, - 'asinh': False, - 'atan': False, - 'atan2': False, - 'atanh': False, - 'bitwise_and': False, - 'bitwise_invert': False, - 'bitwise_left_shift': False, - 'bitwise_or': False, - 'bitwise_right_shift': False, - 'bitwise_xor': False, - 'ceil': False, - 'cos': False, - 'cosh': False, - 'divide': False, - 'equal': True, - 'exp': False, - 'expm1': False, - 'floor': False, - 'floor_divide': False, - 'greater': True, - 'greater_equal': True, - 'isfinite': True, - 'isinf': True, - 'isnan': True, - 'less': True, - 'less_equal': True, - 'log': False, - 'logaddexp': False, - 'log10': False, - 'log1p': False, - 'log2': False, - 'logical_and': True, - 'logical_not': True, - 'logical_or': True, - 'logical_xor': True, - 'multiply': False, - 'negative': False, - 'not_equal': True, - 'positive': False, - 'pow': False, - 'remainder': False, - 'round': False, - 'sign': False, - 'sin': False, - 'sinh': False, - 'sqrt': False, - 'square': False, - 'subtract': False, - 'tan': False, - 'tanh': False, - 'trunc': False, + "abs": False, + "acos": False, + "acosh": False, + "add": False, + "asin": False, + "asinh": False, + "atan": False, + "atan2": False, + "atanh": False, + "bitwise_and": False, + "bitwise_invert": False, + "bitwise_left_shift": False, + "bitwise_or": False, + "bitwise_right_shift": False, + "bitwise_xor": False, + "ceil": False, + "cos": False, + "cosh": False, + "divide": False, + "equal": True, + "exp": False, + "expm1": False, + "floor": False, + "floor_divide": False, + "greater": True, + "greater_equal": True, + "isfinite": True, + "isinf": True, + "isnan": True, + "less": True, + "less_equal": True, + "log": False, + "logaddexp": False, + "log10": False, + "log1p": False, + "log2": False, + "logical_and": True, + "logical_not": True, + "logical_or": True, + "logical_xor": True, + "multiply": False, + "negative": False, + "not_equal": True, + "positive": False, + "pow": False, + "remainder": False, + "round": False, + "sign": False, + "sin": False, + "sinh": False, + "sqrt": False, + "square": False, + "subtract": False, + "tan": False, + "tanh": False, + "trunc": False, # searching - 'where': False, + "where": False, } unary_op_to_symbol = { - '__invert__': '~', - '__neg__': '-', - '__pos__': '+', + "__invert__": "~", + "__neg__": "-", + "__pos__": "+", } binary_op_to_symbol = { - '__add__': '+', - '__and__': '&', - '__eq__': '==', - '__floordiv__': '//', - '__ge__': '>=', - '__gt__': '>', - '__le__': '<=', - '__lshift__': '<<', - '__lt__': '<', - '__matmul__': '@', - '__mod__': '%', - '__mul__': '*', - '__ne__': '!=', - '__or__': '|', - '__pow__': '**', - '__rshift__': '>>', - '__sub__': '-', - '__truediv__': '/', - '__xor__': '^', + "__add__": "+", + "__and__": "&", + "__eq__": "==", + "__floordiv__": "//", + "__ge__": ">=", + "__gt__": ">", + "__le__": "<=", + "__lshift__": "<<", + "__lt__": "<", + "__matmul__": "@", + "__mod__": "%", + 
"__mul__": "*", + "__ne__": "!=", + "__or__": "|", + "__pow__": "**", + "__rshift__": ">>", + "__sub__": "-", + "__truediv__": "/", + "__xor__": "^", } op_to_func = { - '__abs__': 'abs', - '__add__': 'add', - '__and__': 'bitwise_and', - '__eq__': 'equal', - '__floordiv__': 'floor_divide', - '__ge__': 'greater_equal', - '__gt__': 'greater', - '__le__': 'less_equal', - '__lt__': 'less', + "__abs__": "abs", + "__add__": "add", + "__and__": "bitwise_and", + "__eq__": "equal", + "__floordiv__": "floor_divide", + "__ge__": "greater_equal", + "__gt__": "greater", + "__le__": "less_equal", + "__lt__": "less", # '__matmul__': 'matmul', # TODO: support matmul - '__mod__': 'remainder', - '__mul__': 'multiply', - '__ne__': 'not_equal', - '__or__': 'bitwise_or', - '__pow__': 'pow', - '__lshift__': 'bitwise_left_shift', - '__rshift__': 'bitwise_right_shift', - '__sub__': 'subtract', - '__truediv__': 'divide', - '__xor__': 'bitwise_xor', - '__invert__': 'bitwise_invert', - '__neg__': 'negative', - '__pos__': 'positive', + "__mod__": "remainder", + "__mul__": "multiply", + "__ne__": "not_equal", + "__or__": "bitwise_or", + "__pow__": "pow", + "__lshift__": "bitwise_left_shift", + "__rshift__": "bitwise_right_shift", + "__sub__": "subtract", + "__truediv__": "divide", + "__xor__": "bitwise_xor", + "__invert__": "bitwise_invert", + "__neg__": "negative", + "__pos__": "positive", } @@ -377,10 +380,10 @@ def result_type(*dtypes: DataType): inplace_op_to_symbol = {} for op, symbol in binary_op_to_symbol.items(): - if op == '__matmul__' or func_returns_bool[op]: + if op == "__matmul__" or func_returns_bool[op]: continue - iop = f'__i{op[2:]}' - inplace_op_to_symbol[iop] = f'{symbol}=' + iop = f"__i{op[2:]}" + inplace_op_to_symbol[iop] = f"{symbol}=" func_in_dtypes[iop] = func_in_dtypes[op] func_returns_bool[iop] = func_returns_bool[op] @@ -394,4 +397,4 @@ def fmt_types(types: Tuple[Union[DataType, ScalarType], ...]) -> str: except KeyError: # i.e. dtype is bool, int, or float f_types.append(type_.__name__) - return ', '.join(f_types) + return ", ".join(f_types) diff --git a/array_api_tests/hypothesis_helpers.py b/array_api_tests/hypothesis_helpers.py index a14f3d51..d8c5f976 100644 --- a/array_api_tests/hypothesis_helpers.py +++ b/array_api_tests/hypothesis_helpers.py @@ -2,7 +2,7 @@ from functools import reduce from math import sqrt from operator import mul -from typing import Any, List, NamedTuple, Optional, Tuple, Sequence +from typing import Any, List, NamedTuple, Optional, Sequence, Tuple, Union from hypothesis import assume from hypothesis.strategies import (SearchStrategy, booleans, composite, floats, @@ -11,15 +11,15 @@ from . import _array_module as xp from . import dtype_helpers as dh +from . import shape_helpers as sh from . import xps from ._array_module import _UndefinedStub from ._array_module import bool as bool_dtype from ._array_module import broadcast_to, eye, float32, float64, full -from .array_helpers import ndindex +from .algos import broadcast_shapes from .function_stubs import elementwise_functions from .pytest_helpers import nargs from .typing import Array, DataType, Shape -from .algos import broadcast_shapes # Set this to True to not fail tests just because a dtype isn't implemented. 
# If no compatible dtype is implemented for a given test, the test will fail @@ -208,7 +208,7 @@ def invertible_matrices(draw, dtypes=xps.floating_dtypes(), stack_shapes=shapes( assume(xp.all(xp.abs(d) > 0.5)) a = xp.zeros(shape) - for j, (idx, i) in enumerate(itertools.product(ndindex(stack_shape), range(n))): + for j, (idx, i) in enumerate(itertools.product(sh.ndindex(stack_shape), range(n))): a[idx + (i, i)] = d[j] return a @@ -399,3 +399,12 @@ def specified_kwargs(draw, *keys_values_defaults: KVD): if value is not default or draw(booleans()): kw[keyword] = value return kw + + +def axes(ndim: int) -> SearchStrategy[Optional[Union[int, Shape]]]: + """Generate valid arguments for some axis keywords""" + axes_strats = [none()] + if ndim != 0: + axes_strats.append(integers(-ndim, ndim - 1)) + axes_strats.append(xps.valid_tuple_axes(ndim)) + return one_of(axes_strats) diff --git a/array_api_tests/meta/test_hypothesis_helpers.py b/array_api_tests/meta/test_hypothesis_helpers.py index 652644c1..b4cb6e96 100644 --- a/array_api_tests/meta/test_hypothesis_helpers.py +++ b/array_api_tests/meta/test_hypothesis_helpers.py @@ -68,7 +68,6 @@ def test_two_broadcastable_shapes(pair): @given(*hh.two_mutual_arrays()) def test_two_mutual_arrays(x1, x2): assert (x1.dtype, x2.dtype) in dh.promotion_table.keys() - assert broadcast_shapes(x1.shape, x2.shape) in (x1.shape, x2.shape) def test_two_mutual_arrays_raises_on_bad_dtypes(): diff --git a/array_api_tests/meta/test_utils.py b/array_api_tests/meta/test_utils.py index 7dfafc5b..3b28b9a9 100644 --- a/array_api_tests/meta/test_utils.py +++ b/array_api_tests/meta/test_utils.py @@ -1,7 +1,9 @@ import pytest -from ..test_signatures import extension_module +from .. import shape_helpers as sh from ..test_creation_functions import frange +from ..test_manipulation_functions import roll_ndindex +from ..test_signatures import extension_module def test_extension_module_is_extension(): @@ -24,3 +26,59 @@ def test_extension_func_is_not_extension(): def test_frange(r, size, elements): assert len(r) == size assert list(r) == elements + + +@pytest.mark.parametrize( + "shape, expected", + [((), [()])], +) +def test_ndindex(shape, expected): + assert list(sh.ndindex(shape)) == expected + + +@pytest.mark.parametrize( + "shape, axis, expected", + [ + ((1,), 0, [(slice(None, None),)]), + ((1, 2), 0, [(slice(None, None), slice(None, None))]), + ( + (2, 4), + 1, + [(0, slice(None, None)), (1, slice(None, None))], + ), + ], +) +def test_axis_ndindex(shape, axis, expected): + assert list(sh.axis_ndindex(shape, axis)) == expected + + +@pytest.mark.parametrize( + "shape, axes, expected", + [ + ((), (), [[()]]), + ((1,), (0,), [[(0,)]]), + ( + (2, 2), + (0,), + [ + [(0, 0), (1, 0)], + [(0, 1), (1, 1)], + ], + ), + ], +) +def test_axes_ndindex(shape, axes, expected): + assert list(sh.axes_ndindex(shape, axes)) == expected + + +@pytest.mark.parametrize( + "shape, shifts, axes, expected", + [ + ((1, 1), (0,), (0,), [(0, 0)]), + ((2, 1), (1, 1), (0, 1), [(1, 0), (0, 0)]), + ((2, 2), (1, 1), (0, 1), [(1, 1), (1, 0), (0, 1), (0, 0)]), + ((2, 2), (-1, 1), (0, 1), [(1, 1), (1, 0), (0, 1), (0, 0)]), + ], +) +def test_roll_ndindex(shape, shifts, axes, expected): + assert list(roll_ndindex(shape, shifts, axes)) == expected diff --git a/array_api_tests/pytest_helpers.py b/array_api_tests/pytest_helpers.py index 9424ba35..c8fd1fdb 100644 --- a/array_api_tests/pytest_helpers.py +++ b/array_api_tests/pytest_helpers.py @@ -1,12 +1,13 @@ -from array_api_tests.algos import broadcast_shapes import math 
from inspect import getfullargspec from typing import Any, Dict, Optional, Tuple, Union +from . import _array_module as xp from . import array_helpers as ah from . import dtype_helpers as dh from . import function_stubs -from .typing import Array, DataType, Scalar, Shape +from .algos import broadcast_shapes +from .typing import Array, DataType, Scalar, ScalarType, Shape __all__ = [ "raises", @@ -17,8 +18,10 @@ "assert_kw_dtype", "assert_default_float", "assert_default_int", + "assert_default_index", "assert_shape", "assert_result_shape", + "assert_keepdimable_shape", "assert_fill", ] @@ -67,12 +70,14 @@ def fmt_kw(kw: Dict[str, Any]) -> str: def assert_dtype( func_name: str, - in_dtypes: Tuple[DataType, ...], + in_dtypes: Union[DataType, Tuple[DataType, ...]], out_dtype: DataType, expected: Optional[DataType] = None, *, repr_name: str = "out.dtype", ): + if not isinstance(in_dtypes, tuple): + in_dtypes = (in_dtypes,) f_in_dtypes = dh.fmt_types(in_dtypes) f_out_dtype = dh.dtype_to_name[out_dtype] if expected is None: @@ -115,6 +120,15 @@ def assert_default_int(func_name: str, dtype: DataType): assert dtype == dh.default_int, msg +def assert_default_index(func_name: str, dtype: DataType, repr_name="out.dtype"): + f_dtype = dh.dtype_to_name[dtype] + msg = ( + f"{repr_name}={f_dtype}, should be the default index dtype, " + f"which is either int32 or int64 [{func_name}()]" + ) + assert dtype in (xp.int32, xp.int64), msg + + def assert_shape( func_name: str, out_shape: Union[int, Shape], @@ -149,10 +163,59 @@ def assert_result_shape( f_sig = f" {f_in_shapes} " if kw: f_sig += f", {fmt_kw(kw)}" + msg = f"{repr_name}={out_shape}, but should be {expected} [{func_name}({f_sig})]" + assert out_shape == expected, msg + + +def assert_keepdimable_shape( + func_name: str, + out_shape: Shape, + in_shape: Shape, + axes: Tuple[int, ...], + keepdims: bool, + /, + **kw, +): + if keepdims: + shape = tuple(1 if axis in axes else side for axis, side in enumerate(in_shape)) + else: + shape = tuple(side for axis, side in enumerate(in_shape) if axis not in axes) + assert_shape(func_name, out_shape, shape, **kw) + + +def assert_0d_equals( + func_name: str, x_repr: str, x_val: Array, out_repr: str, out_val: Array, **kw +): msg = ( - f"{repr_name}={out_shape}, but should be {expected} [{func_name}({f_sig})]" + f"{out_repr}={out_val}, should be {x_repr}={x_val} " + f"[{func_name}({fmt_kw(kw)})]" ) - assert out_shape == expected, msg + if dh.is_float_dtype(out_val.dtype) and xp.isnan(out_val): + assert xp.isnan(x_val), msg + else: + assert x_val == out_val, msg + + +def assert_scalar_equals( + func_name: str, + type_: ScalarType, + idx: Shape, + out: Scalar, + expected: Scalar, + /, + **kw, +): + out_repr = "out" if idx == () else f"out[{idx}]" + f_func = f"{func_name}({fmt_kw(kw)})" + if type_ is bool or type_ is int: + msg = f"{out_repr}={out}, should be {expected} [{f_func}]" + assert out == expected, msg + elif math.isnan(expected): + msg = f"{out_repr}={out}, should be {expected} [{f_func}]" + assert math.isnan(out), msg + else: + msg = f"{out_repr}={out}, should be roughly {expected} [{f_func}]" + assert math.isclose(out, expected, rel_tol=0.25, abs_tol=1), msg def assert_fill( diff --git a/array_api_tests/shape_helpers.py b/array_api_tests/shape_helpers.py new file mode 100644 index 00000000..751b8d49 --- /dev/null +++ b/array_api_tests/shape_helpers.py @@ -0,0 +1,59 @@ +from itertools import product +from typing import Iterator, List, Optional, Tuple, Union + +from .typing import Shape + +__all__ = 
["normalise_axis", "ndindex", "axis_ndindex", "axes_ndindex"] + + +def normalise_axis( + axis: Optional[Union[int, Tuple[int, ...]]], ndim: int +) -> Tuple[int, ...]: + if axis is None: + return tuple(range(ndim)) + axes = axis if isinstance(axis, tuple) else (axis,) + axes = tuple(axis if axis >= 0 else ndim + axis for axis in axes) + return axes + + +def ndindex(shape): + """Iterator of n-D indices to an array + + Yields tuples of integers to index every element of an array of shape + `shape`. Same as np.ndindex(). + """ + return product(*[range(i) for i in shape]) + + +def axis_ndindex( + shape: Shape, axis: int +) -> Iterator[Tuple[Tuple[Union[int, slice], ...], ...]]: + """Generate indices that index all elements in dimensions beyond `axis`""" + assert axis >= 0 # sanity check + axis_indices = [range(side) for side in shape[:axis]] + for _ in range(axis, len(shape)): + axis_indices.append([slice(None, None)]) + yield from product(*axis_indices) + + +def axes_ndindex(shape: Shape, axes: Tuple[int, ...]) -> Iterator[List[Shape]]: + """Generate indices that index all elements except in `axes` dimensions""" + base_indices = [] + axes_indices = [] + for axis, side in enumerate(shape): + if axis in axes: + base_indices.append([None]) + axes_indices.append(range(side)) + else: + base_indices.append(range(side)) + axes_indices.append([None]) + for base_idx in product(*base_indices): + indices = [] + for idx in product(*axes_indices): + idx = list(idx) + for axis, side in enumerate(idx): + if axis not in axes: + idx[axis] = base_idx[axis] + idx = tuple(idx) + indices.append(idx) + yield list(indices) diff --git a/array_api_tests/test_array2scalar.py b/array_api_tests/test_array2scalar.py new file mode 100644 index 00000000..55fb2fe3 --- /dev/null +++ b/array_api_tests/test_array2scalar.py @@ -0,0 +1,39 @@ +import pytest +from hypothesis import given +from hypothesis import strategies as st + +from . import _array_module as xp +from . import dtype_helpers as dh +from . import xps +from .typing import DataType, Param + +method_stype = { + "__bool__": bool, + "__int__": int, + "__index__": int, + "__float__": float, +} + + +def make_param(method_name: str, dtype: DataType) -> Param: + stype = method_stype[method_name] + return pytest.param( + method_name, dtype, stype, id=f"{method_name}({dh.dtype_to_name[dtype]})" + ) + + +@pytest.mark.parametrize( + "method_name, dtype, stype", + [make_param("__bool__", xp.bool)] + + [make_param("__int__", d) for d in dh.all_int_dtypes] + + [make_param("__index__", d) for d in dh.all_int_dtypes] + + [make_param("__float__", d) for d in dh.float_dtypes], +) +@given(data=st.data()) +def test_0d_array_can_convert_to_scalar(method_name, dtype, stype, data): + x = data.draw(xps.arrays(dtype, shape=()), label="x") + method = getattr(x, method_name) + out = method() + assert isinstance( + out, stype + ), f"{method_name}({x})={out}, which is not a {stype.__name__} scalar" diff --git a/array_api_tests/test_elementwise_functions.py b/array_api_tests/test_elementwise_functions.py index 4d288ee4..345eec44 100644 --- a/array_api_tests/test_elementwise_functions.py +++ b/array_api_tests/test_elementwise_functions.py @@ -23,6 +23,7 @@ from . import dtype_helpers as dh from . import hypothesis_helpers as hh from . import pytest_helpers as ph +from . import shape_helpers as sh from . import xps from .algos import broadcast_shapes from .typing import Array, DataType, Param, Scalar @@ -377,13 +378,13 @@ def test_bitwise_and( # Compare against the Python & operator. 
if res.dtype == xp.bool: - for idx in ah.ndindex(res.shape): + for idx in sh.ndindex(res.shape): s_left = bool(_left[idx]) s_right = bool(_right[idx]) s_res = bool(res[idx]) assert (s_left and s_right) == s_res else: - for idx in ah.ndindex(res.shape): + for idx in sh.ndindex(res.shape): s_left = int(_left[idx]) s_right = int(_right[idx]) s_res = int(res[idx]) @@ -427,7 +428,7 @@ def test_bitwise_left_shift( _right = xp.broadcast_to(right, shape) # Compare against the Python << operator. - for idx in ah.ndindex(res.shape): + for idx in sh.ndindex(res.shape): s_left = int(_left[idx]) s_right = int(_right[idx]) s_res = int(res[idx]) @@ -452,12 +453,12 @@ def test_bitwise_invert(func_name, func, strat, data): ph.assert_shape(func_name, out.shape, x.shape) # Compare against the Python ~ operator. if out.dtype == xp.bool: - for idx in ah.ndindex(out.shape): + for idx in sh.ndindex(out.shape): s_x = bool(x[idx]) s_out = bool(out[idx]) assert (not s_x) == s_out else: - for idx in ah.ndindex(out.shape): + for idx in sh.ndindex(out.shape): s_x = int(x[idx]) s_out = int(out[idx]) s_invert = ah.int_to_dtype( @@ -495,13 +496,13 @@ def test_bitwise_or( # Compare against the Python | operator. if res.dtype == xp.bool: - for idx in ah.ndindex(res.shape): + for idx in sh.ndindex(res.shape): s_left = bool(_left[idx]) s_right = bool(_right[idx]) s_res = bool(res[idx]) assert (s_left or s_right) == s_res else: - for idx in ah.ndindex(res.shape): + for idx in sh.ndindex(res.shape): s_left = int(_left[idx]) s_right = int(_right[idx]) s_res = int(res[idx]) @@ -547,7 +548,7 @@ def test_bitwise_right_shift( _right = xp.broadcast_to(right, shape) # Compare against the Python >> operator. - for idx in ah.ndindex(res.shape): + for idx in sh.ndindex(res.shape): s_left = int(_left[idx]) s_right = int(_right[idx]) s_res = int(res[idx]) @@ -586,13 +587,13 @@ def test_bitwise_xor( # Compare against the Python ^ operator. 
if res.dtype == xp.bool: - for idx in ah.ndindex(res.shape): + for idx in sh.ndindex(res.shape): s_left = bool(_left[idx]) s_right = bool(_right[idx]) s_res = bool(res[idx]) assert (s_left ^ s_right) == s_res else: - for idx in ah.ndindex(res.shape): + for idx in sh.ndindex(res.shape): s_left = int(_left[idx]) s_right = int(_right[idx]) s_res = int(res[idx]) @@ -721,7 +722,7 @@ def test_equal( _right = ah.asarray(_right, dtype=promoted_dtype) scalar_type = dh.get_scalar_type(promoted_dtype) - for idx in ah.ndindex(shape): + for idx in sh.ndindex(shape): x1_idx = _left[idx] x2_idx = _right[idx] out_idx = out[idx] @@ -846,7 +847,7 @@ def test_greater( _right = ah.asarray(_right, dtype=promoted_dtype) scalar_type = dh.get_scalar_type(promoted_dtype) - for idx in ah.ndindex(shape): + for idx in sh.ndindex(shape): out_idx = out[idx] x1_idx = _left[idx] x2_idx = _right[idx] @@ -887,7 +888,7 @@ def test_greater_equal( _right = ah.asarray(_right, dtype=promoted_dtype) scalar_type = dh.get_scalar_type(promoted_dtype) - for idx in ah.ndindex(shape): + for idx in sh.ndindex(shape): out_idx = out[idx] x1_idx = _left[idx] x2_idx = _right[idx] @@ -907,7 +908,7 @@ def test_isfinite(x): # Test the exact value by comparing to the math version if dh.is_float_dtype(x.dtype): - for idx in ah.ndindex(x.shape): + for idx in sh.ndindex(x.shape): s = float(x[idx]) assert bool(res[idx]) == math.isfinite(s) @@ -925,7 +926,7 @@ def test_isinf(x): # Test the exact value by comparing to the math version if dh.is_float_dtype(x.dtype): - for idx in ah.ndindex(x.shape): + for idx in sh.ndindex(x.shape): s = float(x[idx]) assert bool(res[idx]) == math.isinf(s) @@ -943,7 +944,7 @@ def test_isnan(x): # Test the exact value by comparing to the math version if dh.is_float_dtype(x.dtype): - for idx in ah.ndindex(x.shape): + for idx in sh.ndindex(x.shape): s = float(x[idx]) assert bool(res[idx]) == math.isnan(s) @@ -979,7 +980,7 @@ def test_less( _right = ah.asarray(_right, dtype=promoted_dtype) scalar_type = dh.get_scalar_type(promoted_dtype) - for idx in ah.ndindex(shape): + for idx in sh.ndindex(shape): x1_idx = _left[idx] x2_idx = _right[idx] out_idx = out[idx] @@ -1020,7 +1021,7 @@ def test_less_equal( _right = ah.asarray(_right, dtype=promoted_dtype) scalar_type = dh.get_scalar_type(promoted_dtype) - for idx in ah.ndindex(shape): + for idx in sh.ndindex(shape): x1_idx = _left[idx] x2_idx = _right[idx] out_idx = out[idx] @@ -1100,7 +1101,7 @@ def test_logical_and(x1, x2): _x1 = xp.broadcast_to(x1, shape) _x2 = xp.broadcast_to(x2, shape) - for idx in ah.ndindex(shape): + for idx in sh.ndindex(shape): assert out[idx] == (bool(_x1[idx]) and bool(_x2[idx])) @@ -1108,7 +1109,7 @@ def test_logical_and(x1, x2): def test_logical_not(x): out = ah.logical_not(x) ph.assert_shape("logical_not", out.shape, x.shape) - for idx in ah.ndindex(x.shape): + for idx in sh.ndindex(x.shape): assert out[idx] == (not bool(x[idx])) @@ -1122,7 +1123,7 @@ def test_logical_or(x1, x2): _x1 = xp.broadcast_to(x1, shape) _x2 = xp.broadcast_to(x2, shape) - for idx in ah.ndindex(shape): + for idx in sh.ndindex(shape): assert out[idx] == (bool(_x1[idx]) or bool(_x2[idx])) @@ -1136,7 +1137,7 @@ def test_logical_xor(x1, x2): _x1 = xp.broadcast_to(x1, shape) _x2 = xp.broadcast_to(x2, shape) - for idx in ah.ndindex(shape): + for idx in sh.ndindex(shape): assert out[idx] == (bool(_x1[idx]) ^ bool(_x2[idx])) @@ -1225,7 +1226,7 @@ def test_not_equal( _right = ah.asarray(_right, dtype=promoted_dtype) scalar_type = dh.get_scalar_type(promoted_dtype) - for idx in 
ah.ndindex(shape): + for idx in sh.ndindex(shape): out_idx = out[idx] x1_idx = _left[idx] x2_idx = _right[idx] diff --git a/array_api_tests/test_linalg.py b/array_api_tests/test_linalg.py index ac9f3359..89707d3f 100644 --- a/array_api_tests/test_linalg.py +++ b/array_api_tests/test_linalg.py @@ -18,7 +18,7 @@ from hypothesis.strategies import (booleans, composite, none, tuples, integers, shared, sampled_from) -from .array_helpers import assert_exactly_equal, ndindex, asarray, equal, zero, infinity +from .array_helpers import assert_exactly_equal, asarray, equal, zero, infinity from .hypothesis_helpers import (xps, dtypes, shapes, kwargs, matrix_shapes, square_matrix_shapes, symmetric_matrices, positive_definite_matrices, MAX_ARRAY_SIZE, @@ -28,6 +28,7 @@ SQRT_MAX_ARRAY_SIZE, finite_matrices) from . import dtype_helpers as dh from . import pytest_helpers as ph +from . import shape_helpers as sh from .algos import broadcast_shapes @@ -53,7 +54,7 @@ def _test_stacks(f, *args, res=None, dims=2, true_val=None, **kw): shape = args[0].shape if len(args) == 1 else broadcast_shapes(*[x.shape for x in args]) - for _idx in ndindex(shape[:-2]): + for _idx in sh.ndindex(shape[:-2]): idx = _idx + (slice(None),)*dims res_stack = res[idx] x_stacks = [x[_idx + (...,)] for x in args] @@ -147,7 +148,7 @@ def test_cross(x1_x2_kw): # is the only function that works the way it does, so it's not really # worth generalizing _test_stacks to handle it. a = axis if axis >= 0 else axis + len(shape) - for _idx in ndindex(shape[:a] + shape[a+1:]): + for _idx in sh.ndindex(shape[:a] + shape[a+1:]): idx = _idx[:a] + (slice(None),) + _idx[a:] assert len(idx) == len(shape), "Invalid index. This indicates a bug in the test suite." res_stack = res[idx] diff --git a/array_api_tests/test_manipulation_functions.py b/array_api_tests/test_manipulation_functions.py new file mode 100644 index 00000000..5cc68f80 --- /dev/null +++ b/array_api_tests/test_manipulation_functions.py @@ -0,0 +1,363 @@ +import math +from collections import deque +from typing import Iterable, Iterator, Tuple, Union + +import pytest +from hypothesis import assume, given +from hypothesis import strategies as st + +from . import _array_module as xp +from . import dtype_helpers as dh +from . import hypothesis_helpers as hh +from . import pytest_helpers as ph +from . import shape_helpers as sh +from . 
import xps +from .typing import Array, Shape + +MAX_SIDE = hh.MAX_ARRAY_SIZE // 64 +MAX_DIMS = min(hh.MAX_ARRAY_SIZE // MAX_SIDE, 32) # NumPy only supports up to 32 dims + + +def shared_shapes(*args, **kwargs) -> st.SearchStrategy[Shape]: + key = "shape" + if args: + key += " " + " ".join(args) + if kwargs: + key += " " + ph.fmt_kw(kwargs) + return st.shared(hh.shapes(*args, **kwargs), key="shape") + + +def assert_array_ndindex( + func_name: str, + x: Array, + x_indices: Iterable[Union[int, Shape]], + out: Array, + out_indices: Iterable[Union[int, Shape]], + /, + **kw, +): + msg_suffix = f" [{func_name}({ph.fmt_kw(kw)})]\n {x=}\n{out=}" + for x_idx, out_idx in zip(x_indices, out_indices): + msg = f"out[{out_idx}]={out[out_idx]}, should be x[{x_idx}]={x[x_idx]}" + msg += msg_suffix + if dh.is_float_dtype(x.dtype) and xp.isnan(x[x_idx]): + assert xp.isnan(out[out_idx]), msg + else: + assert out[out_idx] == x[x_idx], msg + + +@st.composite +def concat_shapes(draw, shape, axis): + shape = list(shape) + shape[axis] = draw(st.integers(1, MAX_SIDE)) + return tuple(shape) + + +@given( + dtypes=hh.mutually_promotable_dtypes(None, dtypes=dh.numeric_dtypes), + kw=hh.kwargs(axis=st.none() | st.integers(-MAX_DIMS, MAX_DIMS - 1)), + data=st.data(), +) +def test_concat(dtypes, kw, data): + axis = kw.get("axis", 0) + if axis is None: + shape_strat = hh.shapes() + else: + _axis = axis if axis >= 0 else abs(axis) - 1 + shape_strat = shared_shapes(min_dims=_axis + 1).flatmap( + lambda s: concat_shapes(s, axis) + ) + arrays = [] + for i, dtype in enumerate(dtypes, 1): + x = data.draw(xps.arrays(dtype=dtype, shape=shape_strat), label=f"x{i}") + arrays.append(x) + + out = xp.concat(arrays, **kw) + + ph.assert_dtype("concat", dtypes, out.dtype) + + shapes = tuple(x.shape for x in arrays) + axis = kw.get("axis", 0) + if axis is None: + size = sum(math.prod(s) for s in shapes) + shape = (size,) + else: + shape = list(shapes[0]) + for other_shape in shapes[1:]: + shape[axis] += other_shape[axis] + shape = tuple(shape) + ph.assert_result_shape("concat", shapes, out.shape, shape, **kw) + + if axis is None: + out_indices = (i for i in range(out.size)) + for x_num, x in enumerate(arrays, 1): + for x_idx in sh.ndindex(x.shape): + out_i = next(out_indices) + ph.assert_0d_equals( + "concat", + f"x{x_num}[{x_idx}]", + x[x_idx], + f"out[{out_i}]", + out[out_i], + **kw, + ) + else: + out_indices = sh.ndindex(out.shape) + for idx in sh.axis_ndindex(shapes[0], _axis): + f_idx = ", ".join(str(i) if isinstance(i, int) else ":" for i in idx) + for x_num, x in enumerate(arrays, 1): + indexed_x = x[idx] + for x_idx in sh.ndindex(indexed_x.shape): + out_idx = next(out_indices) + ph.assert_0d_equals( + "concat", + f"x{x_num}[{f_idx}][{x_idx}]", + indexed_x[x_idx], + f"out[{out_idx}]", + out[out_idx], + **kw, + ) + + +@given( + x=xps.arrays(dtype=xps.scalar_dtypes(), shape=shared_shapes()), + axis=shared_shapes().flatmap( + # Generate both valid and invalid axis + lambda s: st.integers(2 * (-len(s) - 1), 2 * len(s)) + ), +) +def test_expand_dims(x, axis): + if axis < -x.ndim - 1 or axis > x.ndim: + with pytest.raises(IndexError): + xp.expand_dims(x, axis=axis) + return + + out = xp.expand_dims(x, axis=axis) + + ph.assert_dtype("expand_dims", x.dtype, out.dtype) + + shape = [side for side in x.shape] + index = axis if axis >= 0 else x.ndim + axis + 1 + shape.insert(index, 1) + shape = tuple(shape) + ph.assert_result_shape("expand_dims", (x.shape,), out.shape, shape) + + assert_array_ndindex( + "expand_dims", x, sh.ndindex(x.shape), 
out, sh.ndindex(out.shape) + ) + + +@given( + x=xps.arrays( + dtype=xps.scalar_dtypes(), shape=hh.shapes(min_side=1).filter(lambda s: 1 in s) + ), + data=st.data(), +) +def test_squeeze(x, data): + axes = st.integers(-x.ndim, x.ndim - 1) + axis = data.draw( + axes + | st.lists(axes, unique_by=lambda i: i if i >= 0 else i + x.ndim).map(tuple), + label="axis", + ) + + axes = (axis,) if isinstance(axis, int) else axis + axes = sh.normalise_axis(axes, x.ndim) + + squeezable_axes = [i for i, side in enumerate(x.shape) if side == 1] + if any(i not in squeezable_axes for i in axes): + with pytest.raises(ValueError): + xp.squeeze(x, axis) + return + + out = xp.squeeze(x, axis) + + ph.assert_dtype("squeeze", x.dtype, out.dtype) + + shape = [] + for i, side in enumerate(x.shape): + if i not in axes: + shape.append(side) + shape = tuple(shape) + ph.assert_result_shape("squeeze", (x.shape,), out.shape, shape, axis=axis) + + assert_array_ndindex("squeeze", x, sh.ndindex(x.shape), out, sh.ndindex(out.shape)) + + +@given( + x=xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes()), + data=st.data(), +) +def test_flip(x, data): + if x.ndim == 0: + axis_strat = st.none() + else: + axis_strat = ( + st.none() | st.integers(-x.ndim, x.ndim - 1) | xps.valid_tuple_axes(x.ndim) + ) + kw = data.draw(hh.kwargs(axis=axis_strat), label="kw") + + out = xp.flip(x, **kw) + + ph.assert_dtype("flip", x.dtype, out.dtype) + + _axes = sh.normalise_axis(kw.get("axis", None), x.ndim) + for indices in sh.axes_ndindex(x.shape, _axes): + reverse_indices = indices[::-1] + assert_array_ndindex("flip", x, indices, out, reverse_indices) + + +@given( + x=xps.arrays(dtype=xps.scalar_dtypes(), shape=shared_shapes(min_dims=1)), + axes=shared_shapes(min_dims=1).flatmap( + lambda s: st.lists( + st.integers(0, len(s) - 1), + min_size=len(s), + max_size=len(s), + unique=True, + ).map(tuple) + ), +) +def test_permute_dims(x, axes): + out = xp.permute_dims(x, axes) + + ph.assert_dtype("permute_dims", x.dtype, out.dtype) + + shape = [None for _ in range(len(axes))] + for i, dim in enumerate(axes): + side = x.shape[dim] + shape[i] = side + shape = tuple(shape) + ph.assert_result_shape("permute_dims", (x.shape,), out.shape, shape, axes=axes) + + indices = list(sh.ndindex(x.shape)) + permuted_indices = [tuple(idx[axis] for axis in axes) for idx in indices] + assert_array_ndindex("permute_dims", x, indices, out, permuted_indices) + + +@st.composite +def reshape_shapes(draw, shape): + size = 1 if len(shape) == 0 else math.prod(shape) + rshape = draw(st.lists(st.integers(0)).filter(lambda s: math.prod(s) == size)) + assume(all(side <= MAX_SIDE for side in rshape)) + if len(rshape) != 0 and size > 0 and draw(st.booleans()): + index = draw(st.integers(0, len(rshape) - 1)) + rshape[index] = -1 + return tuple(rshape) + + +@given( + x=xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes(max_side=MAX_SIDE)), + data=st.data(), +) +def test_reshape(x, data): + shape = data.draw(reshape_shapes(x.shape)) + + out = xp.reshape(x, shape) + + ph.assert_dtype("reshape", x.dtype, out.dtype) + + _shape = list(shape) + if any(side == -1 for side in shape): + size = math.prod(x.shape) + rsize = math.prod(shape) * -1 + _shape[shape.index(-1)] = size / rsize + _shape = tuple(_shape) + ph.assert_result_shape("reshape", (x.shape,), out.shape, _shape, shape=shape) + + assert_array_ndindex("reshape", x, sh.ndindex(x.shape), out, sh.ndindex(out.shape)) + + +def roll_ndindex(shape: Shape, shifts: Tuple[int], axes: Tuple[int]) -> Iterator[Shape]: + assert len(shifts) == 
len(axes) # sanity check + all_shifts = [0 for _ in shape] + for s, a in zip(shifts, axes): + all_shifts[a] = s + for idx in sh.ndindex(shape): + yield tuple((i + sh) % si for i, sh, si in zip(idx, all_shifts, shape)) + + +@given(xps.arrays(dtype=xps.scalar_dtypes(), shape=shared_shapes()), st.data()) +def test_roll(x, data): + shift_strat = st.integers(-hh.MAX_ARRAY_SIZE, hh.MAX_ARRAY_SIZE) + if x.ndim > 0: + shift_strat = shift_strat | st.lists( + shift_strat, min_size=1, max_size=x.ndim + ).map(tuple) + shift = data.draw(shift_strat, label="shift") + if isinstance(shift, tuple): + axis_strat = xps.valid_tuple_axes(x.ndim).filter(lambda t: len(t) == len(shift)) + kw_strat = axis_strat.map(lambda t: {"axis": t}) + else: + axis_strat = st.none() + if x.ndim != 0: + axis_strat = axis_strat | st.integers(-x.ndim, x.ndim - 1) + kw_strat = hh.kwargs(axis=axis_strat) + kw = data.draw(kw_strat, label="kw") + + out = xp.roll(x, shift, **kw) + + kw = {"shift": shift, **kw} # for error messages + + ph.assert_dtype("roll", x.dtype, out.dtype) + + ph.assert_result_shape("roll", (x.shape,), out.shape) + + if kw.get("axis", None) is None: + assert isinstance(shift, int) # sanity check + indices = list(sh.ndindex(x.shape)) + shifted_indices = deque(indices) + shifted_indices.rotate(-shift) + assert_array_ndindex("roll", x, indices, out, shifted_indices, **kw) + else: + shifts = (shift,) if isinstance(shift, int) else shift + axes = sh.normalise_axis(kw["axis"], x.ndim) + shifted_indices = roll_ndindex(x.shape, shifts, axes) + assert_array_ndindex("roll", x, sh.ndindex(x.shape), out, shifted_indices, **kw) + + +@given( + shape=shared_shapes(min_dims=1), + dtypes=hh.mutually_promotable_dtypes(None), + kw=hh.kwargs( + axis=shared_shapes(min_dims=1).flatmap( + lambda s: st.integers(-len(s), len(s) - 1) + ) + ), + data=st.data(), +) +def test_stack(shape, dtypes, kw, data): + arrays = [] + for i, dtype in enumerate(dtypes, 1): + x = data.draw(xps.arrays(dtype=dtype, shape=shape), label=f"x{i}") + arrays.append(x) + + out = xp.stack(arrays, **kw) + + ph.assert_dtype("stack", dtypes, out.dtype) + + axis = kw.get("axis", 0) + _axis = axis if axis >= 0 else len(shape) + axis + 1 + _shape = list(shape) + _shape.insert(_axis, len(arrays)) + _shape = tuple(_shape) + ph.assert_result_shape( + "stack", tuple(x.shape for x in arrays), out.shape, _shape, **kw + ) + + out_indices = sh.ndindex(out.shape) + for idx in sh.axis_ndindex(arrays[0].shape, axis=_axis): + f_idx = ", ".join(str(i) if isinstance(i, int) else ":" for i in idx) + print(f"{f_idx=}") + for x_num, x in enumerate(arrays, 1): + indexed_x = x[idx] + for x_idx in sh.ndindex(indexed_x.shape): + out_idx = next(out_indices) + ph.assert_0d_equals( + "stack", + f"x{x_num}[{f_idx}][{x_idx}]", + indexed_x[x_idx], + f"out[{out_idx}]", + out[out_idx], + **kw, + ) diff --git a/array_api_tests/test_searching_functions.py b/array_api_tests/test_searching_functions.py new file mode 100644 index 00000000..244e7c24 --- /dev/null +++ b/array_api_tests/test_searching_functions.py @@ -0,0 +1,148 @@ +from hypothesis import given +from hypothesis import strategies as st + +from . import _array_module as xp +from . import dtype_helpers as dh +from . import hypothesis_helpers as hh +from . import pytest_helpers as ph +from . import shape_helpers as sh +from . 
import xps +from .algos import broadcast_shapes + + +@given( + x=xps.arrays( + dtype=xps.numeric_dtypes(), + shape=hh.shapes(min_side=1), + elements={"allow_nan": False}, + ), + data=st.data(), +) +def test_argmax(x, data): + kw = data.draw( + hh.kwargs( + axis=st.none() | st.integers(-x.ndim, max(x.ndim - 1, 0)), + keepdims=st.booleans(), + ), + label="kw", + ) + + out = xp.argmax(x, **kw) + + ph.assert_default_index("argmax", out.dtype) + axes = sh.normalise_axis(kw.get("axis", None), x.ndim) + ph.assert_keepdimable_shape( + "argmax", out.shape, x.shape, axes, kw.get("keepdims", False), **kw + ) + scalar_type = dh.get_scalar_type(x.dtype) + for indices, out_idx in zip(sh.axes_ndindex(x.shape, axes), sh.ndindex(out.shape)): + max_i = int(out[out_idx]) + elements = [] + for idx in indices: + s = scalar_type(x[idx]) + elements.append(s) + expected = max(range(len(elements)), key=elements.__getitem__) + ph.assert_scalar_equals("argmax", int, out_idx, max_i, expected) + + +@given( + x=xps.arrays( + dtype=xps.numeric_dtypes(), + shape=hh.shapes(min_side=1), + elements={"allow_nan": False}, + ), + data=st.data(), +) +def test_argmin(x, data): + kw = data.draw( + hh.kwargs( + axis=st.none() | st.integers(-x.ndim, max(x.ndim - 1, 0)), + keepdims=st.booleans(), + ), + label="kw", + ) + + out = xp.argmin(x, **kw) + + ph.assert_default_index("argmin", out.dtype) + axes = sh.normalise_axis(kw.get("axis", None), x.ndim) + ph.assert_keepdimable_shape( + "argmin", out.shape, x.shape, axes, kw.get("keepdims", False), **kw + ) + scalar_type = dh.get_scalar_type(x.dtype) + for indices, out_idx in zip(sh.axes_ndindex(x.shape, axes), sh.ndindex(out.shape)): + min_i = int(out[out_idx]) + elements = [] + for idx in indices: + s = scalar_type(x[idx]) + elements.append(s) + expected = min(range(len(elements)), key=elements.__getitem__) + ph.assert_scalar_equals("argmin", int, out_idx, min_i, expected) + + +# TODO: skip if opted out +@given(xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes(min_side=1))) +def test_nonzero(x): + out = xp.nonzero(x) + if x.ndim == 0: + assert len(out) == 1, f"{len(out)=}, but should be 1 for 0-dimensional arrays" + else: + assert len(out) == x.ndim, f"{len(out)=}, but should be {x.ndim=}" + size = out[0].size + for i in range(len(out)): + assert out[i].ndim == 1, f"out[{i}].ndim={x.ndim}, but should be 1" + assert ( + out[i].size == size + ), f"out[{i}].size={x.size}, but should be out[0].size={size}" + ph.assert_default_index("nonzero", out[i].dtype, repr_name=f"out[{i}].dtype") + indices = [] + if x.dtype == xp.bool: + for idx in sh.ndindex(x.shape): + if x[idx]: + indices.append(idx) + else: + for idx in sh.ndindex(x.shape): + if x[idx] != 0: + indices.append(idx) + if x.ndim == 0: + assert out[0].size == len( + indices + ), f"{out[0].size=}, but should be {len(indices)}" + else: + for i in range(size): + idx = tuple(int(x[i]) for x in out) + f_idx = f"Extrapolated index (x[{i}] for x in out)={idx}" + f_element = f"x[{idx}]={x[idx]}" + assert idx in indices, f"{f_idx} results in {f_element}, a zero element" + assert ( + idx == indices[i] + ), f"{f_idx} is in the wrong position, should be {indices.index(idx)}" + + +@given( + shapes=hh.mutually_broadcastable_shapes(3), + dtypes=hh.mutually_promotable_dtypes(), + data=st.data(), +) +def test_where(shapes, dtypes, data): + cond = data.draw(xps.arrays(dtype=xp.bool, shape=shapes[0]), label="condition") + x1 = data.draw(xps.arrays(dtype=dtypes[0], shape=shapes[1]), label="x1") + x2 = data.draw(xps.arrays(dtype=dtypes[1], 
shape=shapes[2]), label="x2") + + out = xp.where(cond, x1, x2) + + shape = broadcast_shapes(*shapes) + ph.assert_shape("where", out.shape, shape) + # TODO: generate indices without broadcasting arrays + _cond = xp.broadcast_to(cond, shape) + _x1 = xp.broadcast_to(x1, shape) + _x2 = xp.broadcast_to(x2, shape) + for idx in sh.ndindex(shape): + if _cond[idx]: + ph.assert_0d_equals( + "where", f"_x1[{idx}]", _x1[idx], f"out[{idx}]", out[idx] + ) + else: + ph.assert_0d_equals( + "where", f"_x2[{idx}]", _x2[idx], f"out[{idx}]", out[idx] + ) diff --git a/array_api_tests/test_set_functions.py b/array_api_tests/test_set_functions.py new file mode 100644 index 00000000..02660c6a --- /dev/null +++ b/array_api_tests/test_set_functions.py @@ -0,0 +1,240 @@ +# TODO: disable if opted out, refactor things +import math +from collections import Counter, defaultdict + +from hypothesis import assume, given + +from . import _array_module as xp +from . import dtype_helpers as dh +from . import hypothesis_helpers as hh +from . import pytest_helpers as ph +from . import shape_helpers as sh +from . import xps + + +@given( + xps.arrays( + dtype=xps.scalar_dtypes(), shape=hh.shapes(min_side=1, min_dims=1, max_dims=1) + ) +) # TODO +def test_unique_all(x): + out = xp.unique_all(x) + + assert hasattr(out, "values") + assert hasattr(out, "indices") + assert hasattr(out, "inverse_indices") + assert hasattr(out, "counts") + + ph.assert_dtype( + "unique_all", x.dtype, out.values.dtype, repr_name="out.values.dtype" + ) + ph.assert_default_index( + "unique_all", out.indices.dtype, repr_name="out.indices.dtype" + ) + ph.assert_default_index( + "unique_all", out.inverse_indices.dtype, repr_name="out.inverse_indices.dtype" + ) + ph.assert_default_index( + "unique_all", out.counts.dtype, repr_name="out.counts.dtype" + ) + + assert ( + out.indices.shape == out.values.shape + ), f"{out.indices.shape=}, but should be {out.values.shape=}" + ph.assert_shape( + "unique_all", + out.inverse_indices.shape, + x.shape, + repr_name="out.inverse_indices.shape", + ) + assert ( + out.counts.shape == out.values.shape + ), f"{out.counts.shape=}, but should be {out.values.shape=}" + + scalar_type = dh.get_scalar_type(out.values.dtype) + counts = defaultdict(int) + firsts = {} + for i, idx in enumerate(sh.ndindex(x.shape)): + val = scalar_type(x[idx]) + if counts[val] == 0: + firsts[val] = i + counts[val] += 1 + + for idx in sh.ndindex(out.indices.shape): + val = scalar_type(out.values[idx]) + if math.isnan(val): + break + i = int(out.indices[idx]) + expected = firsts[val] + assert i == expected, ( + f"out.values[{idx}]={val} and out.indices[{idx}]={i}, " + f"but first occurence of {val} is at {expected}" + ) + + for idx in sh.ndindex(out.inverse_indices.shape): + ridx = int(out.inverse_indices[idx]) + val = out.values[ridx] + expected = x[idx] + msg = ( + f"out.inverse_indices[{idx}]={ridx} results in out.values[{ridx}]={val}, " + f"but should result in x[{idx}]={expected}" + ) + if dh.is_float_dtype(out.values.dtype) and xp.isnan(expected): + assert xp.isnan(val), msg + else: + assert val == expected, msg + + vals_idx = {} + nans = 0 + for idx in sh.ndindex(out.values.shape): + val = scalar_type(out.values[idx]) + count = int(out.counts[idx]) + if math.isnan(val): + nans += 1 + assert count == 1, ( + f"out.counts[{idx}]={count} for out.values[{idx}]={val}, " + "but count should be 1 as NaNs are distinct" + ) + else: + expected = counts[val] + assert ( + expected > 0 + ), f"out.values[{idx}]={val}, but {val} not in input array" + count = 
int(out.counts[idx]) + assert count == expected, ( + f"out.counts[{idx}]={count} for out.values[{idx}]={val}, " + f"but should be {expected}" + ) + assert ( + val not in vals_idx.keys() + ), f"out[{idx}]={val}, but {val} is also in out[{vals_idx[val]}]" + vals_idx[val] = idx + + if dh.is_float_dtype(out.values.dtype): + assume(x.size <= 128) # may not be representable + expected = sum(v for k, v in counts.items() if math.isnan(k)) + assert nans == expected, f"{nans} NaNs in out, but should be {expected}" + + +@given(xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes(min_side=1))) +def test_unique_counts(x): + out = xp.unique_counts(x) + assert hasattr(out, "values") + assert hasattr(out, "counts") + ph.assert_dtype( + "unique_counts", x.dtype, out.values.dtype, repr_name="out.values.dtype" + ) + ph.assert_default_index( + "unique_counts", out.counts.dtype, repr_name="out.counts.dtype" + ) + assert ( + out.counts.shape == out.values.shape + ), f"{out.counts.shape=}, but should be {out.values.shape=}" + scalar_type = dh.get_scalar_type(out.values.dtype) + counts = Counter(scalar_type(x[idx]) for idx in sh.ndindex(x.shape)) + vals_idx = {} + nans = 0 + for idx in sh.ndindex(out.values.shape): + val = scalar_type(out.values[idx]) + count = int(out.counts[idx]) + if math.isnan(val): + nans += 1 + assert count == 1, ( + f"out.counts[{idx}]={count} for out.values[{idx}]={val}, " + "but count should be 1 as NaNs are distinct" + ) + else: + expected = counts[val] + assert ( + expected > 0 + ), f"out.values[{idx}]={val}, but {val} not in input array" + count = int(out.counts[idx]) + assert count == expected, ( + f"out.counts[{idx}]={count} for out.values[{idx}]={val}, " + f"but should be {expected}" + ) + assert ( + val not in vals_idx.keys() + ), f"out[{idx}]={val}, but {val} is also in out[{vals_idx[val]}]" + vals_idx[val] = idx + if dh.is_float_dtype(out.values.dtype): + assume(x.size <= 128) # may not be representable + expected = sum(v for k, v in counts.items() if math.isnan(k)) + assert nans == expected, f"{nans} NaNs in out, but should be {expected}" + + +@given(xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes(min_side=1))) +def test_unique_inverse(x): + out = xp.unique_inverse(x) + assert hasattr(out, "values") + assert hasattr(out, "inverse_indices") + ph.assert_dtype( + "unique_inverse", x.dtype, out.values.dtype, repr_name="out.values.dtype" + ) + ph.assert_default_index( + "unique_inverse", + out.inverse_indices.dtype, + repr_name="out.inverse_indices.dtype", + ) + ph.assert_shape( + "unique_inverse", + out.inverse_indices.shape, + x.shape, + repr_name="out.inverse_indices.shape", + ) + scalar_type = dh.get_scalar_type(out.values.dtype) + distinct = set(scalar_type(x[idx]) for idx in sh.ndindex(x.shape)) + vals_idx = {} + nans = 0 + for idx in sh.ndindex(out.values.shape): + val = scalar_type(out.values[idx]) + if math.isnan(val): + nans += 1 + else: + assert ( + val in distinct + ), f"out.values[{idx}]={val}, but {val} not in input array" + assert ( + val not in vals_idx.keys() + ), f"out.values[{idx}]={val}, but {val} is also in out[{vals_idx[val]}]" + vals_idx[val] = idx + for idx in sh.ndindex(out.inverse_indices.shape): + ridx = int(out.inverse_indices[idx]) + val = out.values[ridx] + expected = x[idx] + msg = ( + f"out.inverse_indices[{idx}]={ridx} results in out.values[{ridx}]={val}, " + f"but should result in x[{idx}]={expected}" + ) + if dh.is_float_dtype(out.values.dtype) and xp.isnan(expected): + assert xp.isnan(val), msg + else: + assert val == expected, msg + if 
dh.is_float_dtype(out.values.dtype): + assume(x.size <= 128) # may not be representable + expected = xp.sum(xp.astype(xp.isnan(x), xp.uint8)) + assert nans == expected, f"{nans} NaNs in out.values, but should be {expected}" + + +@given(xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes(min_side=1))) +def test_unique_values(x): + out = xp.unique_values(x) + ph.assert_dtype("unique_values", x.dtype, out.dtype) + scalar_type = dh.get_scalar_type(x.dtype) + distinct = set(scalar_type(x[idx]) for idx in sh.ndindex(x.shape)) + vals_idx = {} + nans = 0 + for idx in sh.ndindex(out.shape): + val = scalar_type(out[idx]) + if math.isnan(val): + nans += 1 + else: + assert val in distinct, f"out[{idx}]={val}, but {val} not in input array" + assert ( + val not in vals_idx.keys() + ), f"out[{idx}]={val}, but {val} is also in out[{vals_idx[val]}]" + vals_idx[val] = idx + if dh.is_float_dtype(out.dtype): + assume(x.size <= 128) # may not be representable + expected = xp.sum(xp.astype(xp.isnan(x), xp.uint8)) + assert nans == expected, f"{nans} NaNs in out, but should be {expected}" diff --git a/array_api_tests/test_sorting_functions.py b/array_api_tests/test_sorting_functions.py new file mode 100644 index 00000000..0c7334cc --- /dev/null +++ b/array_api_tests/test_sorting_functions.py @@ -0,0 +1,97 @@ +from hypothesis import given +from hypothesis import strategies as st +from hypothesis.control import assume + +from . import _array_module as xp +from . import dtype_helpers as dh +from . import hypothesis_helpers as hh +from . import pytest_helpers as ph +from . import shape_helpers as sh +from . import xps + + +# TODO: Test with signed zeros and NaNs (and ignore them somehow) +@given( + x=xps.arrays( + dtype=xps.scalar_dtypes(), + shape=hh.shapes(min_dims=1, min_side=1), + elements={"allow_nan": False}, + ), + data=st.data(), +) +def test_argsort(x, data): + if dh.is_float_dtype(x.dtype): + assume(not xp.any(x == -0.0) and not xp.any(x == +0.0)) + + kw = data.draw( + hh.kwargs( + axis=st.integers(-x.ndim, x.ndim - 1), + descending=st.booleans(), + stable=st.booleans(), + ), + label="kw", + ) + + out = xp.argsort(x, **kw) + + ph.assert_default_index("sort", out.dtype) + ph.assert_shape("sort", out.shape, x.shape, **kw) + axis = kw.get("axis", -1) + axes = sh.normalise_axis(axis, x.ndim) + descending = kw.get("descending", False) + scalar_type = dh.get_scalar_type(x.dtype) + for indices in sh.axes_ndindex(x.shape, axes): + elements = [scalar_type(x[idx]) for idx in indices] + indices_order = sorted(range(len(indices)), key=elements.__getitem__) + if descending: + # sorted(..., reverse=descending) doesn't always work + indices_order = reversed(indices_order) + for idx, o in zip(indices, indices_order): + ph.assert_scalar_equals("argsort", int, idx, int(out[idx]), o) + + +# TODO: Test with signed zeros and NaNs (and ignore them somehow) +@given( + x=xps.arrays( + dtype=xps.scalar_dtypes(), + shape=hh.shapes(min_dims=1, min_side=1), + elements={"allow_nan": False}, + ), + data=st.data(), +) +def test_sort(x, data): + if dh.is_float_dtype(x.dtype): + assume(not xp.any(x == -0.0) and not xp.any(x == +0.0)) + + kw = data.draw( + hh.kwargs( + axis=st.integers(-x.ndim, x.ndim - 1), + descending=st.booleans(), + stable=st.booleans(), + ), + label="kw", + ) + + out = xp.sort(x, **kw) + + ph.assert_dtype("sort", out.dtype, x.dtype) + ph.assert_shape("sort", out.shape, x.shape, **kw) + axis = kw.get("axis", -1) + axes = sh.normalise_axis(axis, x.ndim) + descending = kw.get("descending", False) + scalar_type = 
dh.get_scalar_type(x.dtype) + for indices in sh.axes_ndindex(x.shape, axes): + elements = [scalar_type(x[idx]) for idx in indices] + indices_order = sorted( + range(len(indices)), key=elements.__getitem__, reverse=descending + ) + x_indices = [indices[o] for o in indices_order] + for out_idx, x_idx in zip(indices, x_indices): + ph.assert_0d_equals( + "sort", + f"x[{x_idx}]", + x[x_idx], + f"out[{out_idx}]", + out[out_idx], + **kw, + ) diff --git a/array_api_tests/test_statistical_functions.py b/array_api_tests/test_statistical_functions.py new file mode 100644 index 00000000..c2fb33db --- /dev/null +++ b/array_api_tests/test_statistical_functions.py @@ -0,0 +1,302 @@ +import math +from typing import Optional + +from hypothesis import assume, given +from hypothesis import strategies as st +from hypothesis.control import reject + +from . import _array_module as xp +from . import dtype_helpers as dh +from . import hypothesis_helpers as hh +from . import pytest_helpers as ph +from . import shape_helpers as sh +from . import xps +from .typing import DataType + + +def kwarg_dtypes(dtype: DataType) -> st.SearchStrategy[Optional[DataType]]: + dtypes = [d2 for d1, d2 in dh.promotion_table if d1 == dtype] + return st.none() | st.sampled_from(dtypes) + + +@given( + x=xps.arrays( + dtype=xps.numeric_dtypes(), + shape=hh.shapes(min_side=1), + elements={"allow_nan": False}, + ), + data=st.data(), +) +def test_max(x, data): + kw = data.draw(hh.kwargs(axis=hh.axes(x.ndim), keepdims=st.booleans()), label="kw") + + out = xp.max(x, **kw) + + ph.assert_dtype("max", x.dtype, out.dtype) + _axes = sh.normalise_axis(kw.get("axis", None), x.ndim) + ph.assert_keepdimable_shape( + "max", out.shape, x.shape, _axes, kw.get("keepdims", False), **kw + ) + scalar_type = dh.get_scalar_type(out.dtype) + for indices, out_idx in zip(sh.axes_ndindex(x.shape, _axes), sh.ndindex(out.shape)): + max_ = scalar_type(out[out_idx]) + elements = [] + for idx in indices: + s = scalar_type(x[idx]) + elements.append(s) + expected = max(elements) + ph.assert_scalar_equals("max", scalar_type, out_idx, max_, expected) + + +@given( + x=xps.arrays( + dtype=xps.floating_dtypes(), + shape=hh.shapes(min_side=1), + elements={"allow_nan": False}, + ), + data=st.data(), +) +def test_mean(x, data): + kw = data.draw(hh.kwargs(axis=hh.axes(x.ndim), keepdims=st.booleans()), label="kw") + + out = xp.mean(x, **kw) + + ph.assert_dtype("mean", x.dtype, out.dtype) + _axes = sh.normalise_axis(kw.get("axis", None), x.ndim) + ph.assert_keepdimable_shape( + "mean", out.shape, x.shape, _axes, kw.get("keepdims", False), **kw + ) + for indices, out_idx in zip(sh.axes_ndindex(x.shape, _axes), sh.ndindex(out.shape)): + mean = float(out[out_idx]) + assume(not math.isinf(mean)) # mean may become inf due to internal overflows + elements = [] + for idx in indices: + s = float(x[idx]) + elements.append(s) + expected = sum(elements) / len(elements) + ph.assert_scalar_equals("mean", float, out_idx, mean, expected) + + +@given( + x=xps.arrays( + dtype=xps.numeric_dtypes(), + shape=hh.shapes(min_side=1), + elements={"allow_nan": False}, + ), + data=st.data(), +) +def test_min(x, data): + kw = data.draw(hh.kwargs(axis=hh.axes(x.ndim), keepdims=st.booleans()), label="kw") + + out = xp.min(x, **kw) + + ph.assert_dtype("min", x.dtype, out.dtype) + _axes = sh.normalise_axis(kw.get("axis", None), x.ndim) + ph.assert_keepdimable_shape( + "min", out.shape, x.shape, _axes, kw.get("keepdims", False), **kw + ) + scalar_type = dh.get_scalar_type(out.dtype) + for indices, out_idx in 
zip(sh.axes_ndindex(x.shape, _axes), sh.ndindex(out.shape)): + min_ = scalar_type(out[out_idx]) + elements = [] + for idx in indices: + s = scalar_type(x[idx]) + elements.append(s) + expected = min(elements) + ph.assert_scalar_equals("min", scalar_type, out_idx, min_, expected) + + +@given( + x=xps.arrays( + dtype=xps.numeric_dtypes(), + shape=hh.shapes(min_side=1), + elements={"allow_nan": False}, + ), + data=st.data(), +) +def test_prod(x, data): + kw = data.draw( + hh.kwargs( + axis=hh.axes(x.ndim), + dtype=kwarg_dtypes(x.dtype), + keepdims=st.booleans(), + ), + label="kw", + ) + + try: + out = xp.prod(x, **kw) + except OverflowError: + reject() + + dtype = kw.get("dtype", None) + if dtype is None: + if dh.is_int_dtype(x.dtype): + if x.dtype in dh.uint_dtypes: + default_dtype = dh.default_uint + else: + default_dtype = dh.default_int + m, M = dh.dtype_ranges[x.dtype] + d_m, d_M = dh.dtype_ranges[default_dtype] + if m < d_m or M > d_M: + _dtype = x.dtype + else: + _dtype = default_dtype + else: + if dh.dtype_nbits[x.dtype] > dh.dtype_nbits[dh.default_float]: + _dtype = x.dtype + else: + _dtype = dh.default_float + else: + _dtype = dtype + ph.assert_dtype("prod", x.dtype, out.dtype, _dtype) + _axes = sh.normalise_axis(kw.get("axis", None), x.ndim) + ph.assert_keepdimable_shape( + "prod", out.shape, x.shape, _axes, kw.get("keepdims", False), **kw + ) + scalar_type = dh.get_scalar_type(out.dtype) + for indices, out_idx in zip(sh.axes_ndindex(x.shape, _axes), sh.ndindex(out.shape)): + prod = scalar_type(out[out_idx]) + assume(math.isfinite(prod)) + elements = [] + for idx in indices: + s = scalar_type(x[idx]) + elements.append(s) + expected = math.prod(elements) + if dh.is_int_dtype(out.dtype): + m, M = dh.dtype_ranges[out.dtype] + assume(m <= expected <= M) + ph.assert_scalar_equals("prod", scalar_type, out_idx, prod, expected) + + +@given( + x=xps.arrays( + dtype=xps.floating_dtypes(), + shape=hh.shapes(min_side=1), + elements={"allow_nan": False}, + ).filter(lambda x: x.size >= 2), + data=st.data(), +) +def test_std(x, data): + axis = data.draw(hh.axes(x.ndim), label="axis") + _axes = sh.normalise_axis(axis, x.ndim) + N = sum(side for axis, side in enumerate(x.shape) if axis not in _axes) + correction = data.draw( + st.floats(0.0, N, allow_infinity=False, allow_nan=False) | st.integers(0, N), + label="correction", + ) + keepdims = data.draw(st.booleans(), label="keepdims") + kw = data.draw( + hh.specified_kwargs( + ("axis", axis, None), + ("correction", correction, 0.0), + ("keepdims", keepdims, False), + ), + label="kw", + ) + + out = xp.std(x, **kw) + + ph.assert_dtype("std", x.dtype, out.dtype) + ph.assert_keepdimable_shape( + "std", out.shape, x.shape, _axes, kw.get("keepdims", False), **kw + ) + # We can't easily test the result(s) as standard deviation methods vary a lot + + +@given( + x=xps.arrays( + dtype=xps.numeric_dtypes(), + shape=hh.shapes(min_side=1), + elements={"allow_nan": False}, + ), + data=st.data(), +) +def test_sum(x, data): + kw = data.draw( + hh.kwargs( + axis=hh.axes(x.ndim), + dtype=kwarg_dtypes(x.dtype), + keepdims=st.booleans(), + ), + label="kw", + ) + + try: + out = xp.sum(x, **kw) + except OverflowError: + reject() + + dtype = kw.get("dtype", None) + if dtype is None: + if dh.is_int_dtype(x.dtype): + if x.dtype in dh.uint_dtypes: + default_dtype = dh.default_uint + else: + default_dtype = dh.default_int + m, M = dh.dtype_ranges[x.dtype] + d_m, d_M = dh.dtype_ranges[default_dtype] + if m < d_m or M > d_M: + _dtype = x.dtype + else: + _dtype = default_dtype + 
else:
+            if dh.dtype_nbits[x.dtype] > dh.dtype_nbits[dh.default_float]:
+                _dtype = x.dtype
+            else:
+                _dtype = dh.default_float
+    else:
+        _dtype = dtype
+    ph.assert_dtype("sum", x.dtype, out.dtype, _dtype)
+    _axes = sh.normalise_axis(kw.get("axis", None), x.ndim)
+    ph.assert_keepdimable_shape(
+        "sum", out.shape, x.shape, _axes, kw.get("keepdims", False), **kw
+    )
+    scalar_type = dh.get_scalar_type(out.dtype)
+    for indices, out_idx in zip(sh.axes_ndindex(x.shape, _axes), sh.ndindex(out.shape)):
+        sum_ = scalar_type(out[out_idx])
+        assume(math.isfinite(sum_))
+        elements = []
+        for idx in indices:
+            s = scalar_type(x[idx])
+            elements.append(s)
+        expected = sum(elements)
+        if dh.is_int_dtype(out.dtype):
+            m, M = dh.dtype_ranges[out.dtype]
+            assume(m <= expected <= M)
+        ph.assert_scalar_equals("sum", scalar_type, out_idx, sum_, expected)
+
+
+@given(
+    x=xps.arrays(
+        dtype=xps.floating_dtypes(),
+        shape=hh.shapes(min_side=1),
+        elements={"allow_nan": False},
+    ).filter(lambda x: x.size >= 2),
+    data=st.data(),
+)
+def test_var(x, data):
+    axis = data.draw(hh.axes(x.ndim), label="axis")
+    _axes = sh.normalise_axis(axis, x.ndim)
+    N = sum(side for axis, side in enumerate(x.shape) if axis not in _axes)
+    correction = data.draw(
+        st.floats(0.0, N, allow_infinity=False, allow_nan=False) | st.integers(0, N),
+        label="correction",
+    )
+    keepdims = data.draw(st.booleans(), label="keepdims")
+    kw = data.draw(
+        hh.specified_kwargs(
+            ("axis", axis, None),
+            ("correction", correction, 0.0),
+            ("keepdims", keepdims, False),
+        ),
+        label="kw",
+    )
+
+    out = xp.var(x, **kw)
+
+    ph.assert_dtype("var", x.dtype, out.dtype)
+    ph.assert_keepdimable_shape(
+        "var", out.shape, x.shape, _axes, kw.get("keepdims", False), **kw
+    )
+    # We can't easily test the result(s) as variance methods vary a lot
diff --git a/array_api_tests/test_type_promotion.py b/array_api_tests/test_type_promotion.py
index d304071a..2fb669e6 100644
--- a/array_api_tests/test_type_promotion.py
+++ b/array_api_tests/test_type_promotion.py
@@ -19,7 +19,7 @@
 
 
 # TODO: move tests not covering elementwise funcs/ops into standalone tests
-# result_type, meshgrid, concat, stack, where, tensordor, vecdot
+# result_type, meshgrid, tensordot, vecdot
 
 
 @given(hh.mutually_promotable_dtypes(None))
@@ -51,34 +51,6 @@ def test_meshgrid(dtypes, data):
         ph.assert_dtype("meshgrid", dtypes, x.dtype, repr_name=f"out[{i}].dtype")
 
 
-@given(
-    shape=hh.shapes(min_dims=1),
-    dtypes=hh.mutually_promotable_dtypes(None, dtypes=dh.numeric_dtypes),
-    data=st.data(),
-)
-def test_concat(shape, dtypes, data):
-    arrays = []
-    for i, dtype in enumerate(dtypes, 1):
-        x = data.draw(xps.arrays(dtype=dtype, shape=shape), label=f"x{i}")
-        arrays.append(x)
-    out = xp.concat(arrays)
-    ph.assert_dtype("concat", dtypes, out.dtype)
-
-
-@given(
-    shape=hh.shapes(),
-    dtypes=hh.mutually_promotable_dtypes(None),
-    data=st.data(),
-)
-def test_stack(shape, dtypes, data):
-    arrays = []
-    for i, dtype in enumerate(dtypes, 1):
-        x = data.draw(xps.arrays(dtype=dtype, shape=shape), label=f"x{i}")
-        arrays.append(x)
-    out = xp.stack(arrays)
-    ph.assert_dtype("stack", dtypes, out.dtype)
-
-
 bitwise_shift_funcs = [
     "bitwise_left_shift",
     "bitwise_right_shift",
diff --git a/array_api_tests/test_utility_functions.py b/array_api_tests/test_utility_functions.py
new file mode 100644
index 00000000..c10d0dbd
--- /dev/null
+++ b/array_api_tests/test_utility_functions.py
@@ -0,0 +1,59 @@
+from hypothesis import given
+from hypothesis import strategies as st
+
+from . import _array_module as xp
+from . 
import dtype_helpers as dh +from . import hypothesis_helpers as hh +from . import pytest_helpers as ph +from . import shape_helpers as sh +from . import xps + + +@given( + x=xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes(min_side=1)), + data=st.data(), +) +def test_all(x, data): + kw = data.draw(hh.kwargs(axis=hh.axes(x.ndim), keepdims=st.booleans()), label="kw") + + out = xp.all(x, **kw) + + ph.assert_dtype("all", x.dtype, out.dtype, xp.bool) + _axes = sh.normalise_axis(kw.get("axis", None), x.ndim) + ph.assert_keepdimable_shape( + "all", out.shape, x.shape, _axes, kw.get("keepdims", False), **kw + ) + scalar_type = dh.get_scalar_type(x.dtype) + for indices, out_idx in zip(sh.axes_ndindex(x.shape, _axes), sh.ndindex(out.shape)): + result = bool(out[out_idx]) + elements = [] + for idx in indices: + s = scalar_type(x[idx]) + elements.append(s) + expected = all(elements) + ph.assert_scalar_equals("all", scalar_type, out_idx, result, expected) + + +@given( + x=xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes()), + data=st.data(), +) +def test_any(x, data): + kw = data.draw(hh.kwargs(axis=hh.axes(x.ndim), keepdims=st.booleans()), label="kw") + + out = xp.any(x, **kw) + + ph.assert_dtype("any", x.dtype, out.dtype, xp.bool) + _axes = sh.normalise_axis(kw.get("axis", None), x.ndim) + ph.assert_keepdimable_shape( + "any", out.shape, x.shape, _axes, kw.get("keepdims", False), **kw + ) + scalar_type = dh.get_scalar_type(x.dtype) + for indices, out_idx in zip(sh.axes_ndindex(x.shape, _axes), sh.ndindex(out.shape)): + result = bool(out[out_idx]) + elements = [] + for idx in indices: + s = scalar_type(x[idx]) + elements.append(s) + expected = any(elements) + ph.assert_scalar_equals("any", scalar_type, out_idx, result, expected)
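
A note on the unique_* expectations in the set-functions tests above: they build their reference counts with collections.Counter over the flattened input and keep a separate tally of NaNs, because NaN never compares equal to itself, so every NaN in the input has to be treated as its own distinct value. The following standalone sketch shows the same bookkeeping on plain Python scalars; the helper name expected_unique_counts is made up for illustration and is not part of the suite.

import math
from collections import Counter


def expected_unique_counts(flat_values):
    # Counts for non-NaN values; NaNs are tallied separately since each one
    # is treated as distinct by the unique_* functions.
    counts = Counter(v for v in flat_values if not math.isnan(v))
    nans = sum(1 for v in flat_values if math.isnan(v))
    return counts, nans


counts, nans = expected_unique_counts([0.0, 1.0, 1.0, float("nan"), float("nan")])
assert counts == Counter({1.0: 2, 0.0: 1})
assert nans == 2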
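
test_argsort derives its expected ordering by sorting positions rather than values: sorted(range(len(elements)), key=elements.__getitem__) yields, for each rank, the position of the element that belongs there, which is exactly what argsort returns along a one-dimensional slice. For descending order the test then reverses this ascending permutation rather than calling sorted(..., reverse=True), as noted in the inline comment. A quick self-contained illustration of the position-sorting trick:

elements = [3.0, 1.0, 2.0]
order = sorted(range(len(elements)), key=elements.__getitem__)
assert order == [1, 2, 0]  # elements[1] <= elements[2] <= elements[0]
assert [elements[i] for i in order] == sorted(elements)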
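
Several of the new tests iterate with sh.normalise_axis and sh.axes_ndindex, which live in shape_helpers and are not shown in this diff. Judging from how they are called here, normalise_axis turns an axis argument (None, an int, or a tuple of ints, possibly negative) into a tuple of non-negative axes, and axes_ndindex yields, for every position of the kept dimensions, the list of full indices running over the reduced dimensions, in the same row-major order as sh.ndindex over the output shape. The sketch below is a rough reimplementation of that assumed behaviour, not the actual helpers; the tests only rely on the grouping and ordering it demonstrates.

import itertools
from typing import Iterator, List, Optional, Tuple, Union

Shape = Tuple[int, ...]


def normalise_axis(
    axis: Optional[Union[int, Tuple[int, ...]]], ndim: int
) -> Tuple[int, ...]:
    # None means "reduce over every axis"; negative axes wrap around.
    if axis is None:
        return tuple(range(ndim))
    axes = axis if isinstance(axis, tuple) else (axis,)
    return tuple(a % ndim for a in axes)


def axes_ndindex(shape: Shape, axes: Tuple[int, ...]) -> Iterator[List[Shape]]:
    # For each position of the kept dimensions, yield every index that runs
    # over the reduced dimensions, so each yielded group maps onto exactly
    # one element of the reduced output.
    kept = [range(side) if i not in axes else (None,) for i, side in enumerate(shape)]
    reduced = [range(side) if i in axes else (None,) for i, side in enumerate(shape)]
    for outer in itertools.product(*kept):
        group = []
        for inner in itertools.product(*reduced):
            group.append(tuple(o if o is not None else i for o, i in zip(outer, inner)))
        yield group


assert list(axes_ndindex((2, 3), (1,))) == [
    [(0, 0), (0, 1), (0, 2)],
    [(1, 0), (1, 1), (1, 2)],
]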
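
The dtype branching in test_prod and test_sum pins down what the tests expect out.dtype to be when no dtype keyword is given: integer inputs accumulate in the default integer dtype of the same signedness unless the input dtype's range exceeds it, and floating-point inputs upcast to the default float unless they are already at least as wide. Pulled out of the tests into a hypothetical helper (expected_accumulator_dtype is an illustrative name, not part of the suite, and the import assumes the package is importable as array_api_tests, matching the paths above), the rule reads:

from array_api_tests import dtype_helpers as dh


def expected_accumulator_dtype(in_dtype, kw_dtype=None):
    # Mirrors the branching used by test_sum/test_prod when deciding what
    # out.dtype should be; kw_dtype is the dtype keyword passed to sum/prod.
    if kw_dtype is not None:
        return kw_dtype
    if dh.is_int_dtype(in_dtype):
        default = dh.default_uint if in_dtype in dh.uint_dtypes else dh.default_int
        m, M = dh.dtype_ranges[in_dtype]
        d_m, d_M = dh.dtype_ranges[default]
        # Keep the input dtype if the default integer cannot cover its range.
        return in_dtype if (m < d_m or M > d_M) else default
    if dh.dtype_nbits[in_dtype] > dh.dtype_nbits[dh.default_float]:
        return in_dtype
    return dh.default_float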