diff --git a/array_api_tests/dtype_helpers.py b/array_api_tests/dtype_helpers.py index 65b4090a..ce749d9e 100644 --- a/array_api_tests/dtype_helpers.py +++ b/array_api_tests/dtype_helpers.py @@ -1,45 +1,45 @@ -from warnings import warn from functools import lru_cache from typing import NamedTuple, Tuple, Union +from warnings import warn from . import _array_module as xp from ._array_module import _UndefinedStub from .typing import DataType, ScalarType - __all__ = [ - 'int_dtypes', - 'uint_dtypes', - 'all_int_dtypes', - 'float_dtypes', - 'numeric_dtypes', - 'all_dtypes', - 'dtype_to_name', - 'bool_and_all_int_dtypes', - 'dtype_to_scalars', - 'is_int_dtype', - 'is_float_dtype', - 'get_scalar_type', - 'dtype_ranges', - 'default_int', - 'default_float', - 'promotion_table', - 'dtype_nbits', - 'dtype_signed', - 'func_in_dtypes', - 'func_returns_bool', - 'binary_op_to_symbol', - 'unary_op_to_symbol', - 'inplace_op_to_symbol', - 'op_to_func', - 'fmt_types', + "int_dtypes", + "uint_dtypes", + "all_int_dtypes", + "float_dtypes", + "numeric_dtypes", + "all_dtypes", + "dtype_to_name", + "bool_and_all_int_dtypes", + "dtype_to_scalars", + "is_int_dtype", + "is_float_dtype", + "get_scalar_type", + "dtype_ranges", + "default_int", + "default_uint", + "default_float", + "promotion_table", + "dtype_nbits", + "dtype_signed", + "func_in_dtypes", + "func_returns_bool", + "binary_op_to_symbol", + "unary_op_to_symbol", + "inplace_op_to_symbol", + "op_to_func", + "fmt_types", ] -_uint_names = ('uint8', 'uint16', 'uint32', 'uint64') -_int_names = ('int8', 'int16', 'int32', 'int64') -_float_names = ('float32', 'float64') -_dtype_names = ('bool',) + _uint_names + _int_names + _float_names +_uint_names = ("uint8", "uint16", "uint32", "uint64") +_int_names = ("int8", "int16", "int32", "int64") +_float_names = ("float32", "float64") +_dtype_names = ("bool",) + _uint_names + _int_names + _float_names uint_dtypes = tuple(getattr(xp, name) for name in _uint_names) @@ -101,17 +101,34 @@ class MinMax(NamedTuple): xp.uint64: MinMax(0, +18_446_744_073_709_551_615), } +dtype_nbits = { + **{d: 8 for d in [xp.int8, xp.uint8]}, + **{d: 16 for d in [xp.int16, xp.uint16]}, + **{d: 32 for d in [xp.int32, xp.uint32, xp.float32]}, + **{d: 64 for d in [xp.int64, xp.uint64, xp.float64]}, +} + + +dtype_signed = { + **{d: True for d in int_dtypes}, + **{d: False for d in uint_dtypes}, +} + if isinstance(xp.asarray, _UndefinedStub): default_int = xp.int32 default_float = xp.float32 warn( - 'array module does not have attribute asarray. ' - 'default int is assumed int32, default float is assumed float32' + "array module does not have attribute asarray. " + "default int is assumed int32, default float is assumed float32" ) else: default_int = xp.asarray(int()).dtype default_float = xp.asarray(float()).dtype +if dtype_nbits[default_int] == 32: + default_uint = xp.uint32 +else: + default_uint = xp.uint64 _numeric_promotions = { @@ -173,200 +190,186 @@ def result_type(*dtypes: DataType): return result -dtype_nbits = { - **{d: 8 for d in [xp.int8, xp.uint8]}, - **{d: 16 for d in [xp.int16, xp.uint16]}, - **{d: 32 for d in [xp.int32, xp.uint32, xp.float32]}, - **{d: 64 for d in [xp.int64, xp.uint64, xp.float64]}, -} - - -dtype_signed = { - **{d: True for d in int_dtypes}, - **{d: False for d in uint_dtypes}, -} - - func_in_dtypes = { # elementwise - 'abs': numeric_dtypes, - 'acos': float_dtypes, - 'acosh': float_dtypes, - 'add': numeric_dtypes, - 'asin': float_dtypes, - 'asinh': float_dtypes, - 'atan': float_dtypes, - 'atan2': float_dtypes, - 'atanh': float_dtypes, - 'bitwise_and': bool_and_all_int_dtypes, - 'bitwise_invert': bool_and_all_int_dtypes, - 'bitwise_left_shift': all_int_dtypes, - 'bitwise_or': bool_and_all_int_dtypes, - 'bitwise_right_shift': all_int_dtypes, - 'bitwise_xor': bool_and_all_int_dtypes, - 'ceil': numeric_dtypes, - 'cos': float_dtypes, - 'cosh': float_dtypes, - 'divide': float_dtypes, - 'equal': all_dtypes, - 'exp': float_dtypes, - 'expm1': float_dtypes, - 'floor': numeric_dtypes, - 'floor_divide': numeric_dtypes, - 'greater': numeric_dtypes, - 'greater_equal': numeric_dtypes, - 'isfinite': numeric_dtypes, - 'isinf': numeric_dtypes, - 'isnan': numeric_dtypes, - 'less': numeric_dtypes, - 'less_equal': numeric_dtypes, - 'log': float_dtypes, - 'logaddexp': float_dtypes, - 'log10': float_dtypes, - 'log1p': float_dtypes, - 'log2': float_dtypes, - 'logical_and': (xp.bool,), - 'logical_not': (xp.bool,), - 'logical_or': (xp.bool,), - 'logical_xor': (xp.bool,), - 'multiply': numeric_dtypes, - 'negative': numeric_dtypes, - 'not_equal': all_dtypes, - 'positive': numeric_dtypes, - 'pow': float_dtypes, - 'remainder': numeric_dtypes, - 'round': numeric_dtypes, - 'sign': numeric_dtypes, - 'sin': float_dtypes, - 'sinh': float_dtypes, - 'sqrt': float_dtypes, - 'square': numeric_dtypes, - 'subtract': numeric_dtypes, - 'tan': float_dtypes, - 'tanh': float_dtypes, - 'trunc': numeric_dtypes, + "abs": numeric_dtypes, + "acos": float_dtypes, + "acosh": float_dtypes, + "add": numeric_dtypes, + "asin": float_dtypes, + "asinh": float_dtypes, + "atan": float_dtypes, + "atan2": float_dtypes, + "atanh": float_dtypes, + "bitwise_and": bool_and_all_int_dtypes, + "bitwise_invert": bool_and_all_int_dtypes, + "bitwise_left_shift": all_int_dtypes, + "bitwise_or": bool_and_all_int_dtypes, + "bitwise_right_shift": all_int_dtypes, + "bitwise_xor": bool_and_all_int_dtypes, + "ceil": numeric_dtypes, + "cos": float_dtypes, + "cosh": float_dtypes, + "divide": float_dtypes, + "equal": all_dtypes, + "exp": float_dtypes, + "expm1": float_dtypes, + "floor": numeric_dtypes, + "floor_divide": numeric_dtypes, + "greater": numeric_dtypes, + "greater_equal": numeric_dtypes, + "isfinite": numeric_dtypes, + "isinf": numeric_dtypes, + "isnan": numeric_dtypes, + "less": numeric_dtypes, + "less_equal": numeric_dtypes, + "log": float_dtypes, + "logaddexp": float_dtypes, + "log10": float_dtypes, + "log1p": float_dtypes, + "log2": float_dtypes, + "logical_and": (xp.bool,), + "logical_not": (xp.bool,), + "logical_or": (xp.bool,), + "logical_xor": (xp.bool,), + "multiply": numeric_dtypes, + "negative": numeric_dtypes, + "not_equal": all_dtypes, + "positive": numeric_dtypes, + "pow": float_dtypes, + "remainder": numeric_dtypes, + "round": numeric_dtypes, + "sign": numeric_dtypes, + "sin": float_dtypes, + "sinh": float_dtypes, + "sqrt": float_dtypes, + "square": numeric_dtypes, + "subtract": numeric_dtypes, + "tan": float_dtypes, + "tanh": float_dtypes, + "trunc": numeric_dtypes, # searching - 'where': all_dtypes, + "where": all_dtypes, } func_returns_bool = { # elementwise - 'abs': False, - 'acos': False, - 'acosh': False, - 'add': False, - 'asin': False, - 'asinh': False, - 'atan': False, - 'atan2': False, - 'atanh': False, - 'bitwise_and': False, - 'bitwise_invert': False, - 'bitwise_left_shift': False, - 'bitwise_or': False, - 'bitwise_right_shift': False, - 'bitwise_xor': False, - 'ceil': False, - 'cos': False, - 'cosh': False, - 'divide': False, - 'equal': True, - 'exp': False, - 'expm1': False, - 'floor': False, - 'floor_divide': False, - 'greater': True, - 'greater_equal': True, - 'isfinite': True, - 'isinf': True, - 'isnan': True, - 'less': True, - 'less_equal': True, - 'log': False, - 'logaddexp': False, - 'log10': False, - 'log1p': False, - 'log2': False, - 'logical_and': True, - 'logical_not': True, - 'logical_or': True, - 'logical_xor': True, - 'multiply': False, - 'negative': False, - 'not_equal': True, - 'positive': False, - 'pow': False, - 'remainder': False, - 'round': False, - 'sign': False, - 'sin': False, - 'sinh': False, - 'sqrt': False, - 'square': False, - 'subtract': False, - 'tan': False, - 'tanh': False, - 'trunc': False, + "abs": False, + "acos": False, + "acosh": False, + "add": False, + "asin": False, + "asinh": False, + "atan": False, + "atan2": False, + "atanh": False, + "bitwise_and": False, + "bitwise_invert": False, + "bitwise_left_shift": False, + "bitwise_or": False, + "bitwise_right_shift": False, + "bitwise_xor": False, + "ceil": False, + "cos": False, + "cosh": False, + "divide": False, + "equal": True, + "exp": False, + "expm1": False, + "floor": False, + "floor_divide": False, + "greater": True, + "greater_equal": True, + "isfinite": True, + "isinf": True, + "isnan": True, + "less": True, + "less_equal": True, + "log": False, + "logaddexp": False, + "log10": False, + "log1p": False, + "log2": False, + "logical_and": True, + "logical_not": True, + "logical_or": True, + "logical_xor": True, + "multiply": False, + "negative": False, + "not_equal": True, + "positive": False, + "pow": False, + "remainder": False, + "round": False, + "sign": False, + "sin": False, + "sinh": False, + "sqrt": False, + "square": False, + "subtract": False, + "tan": False, + "tanh": False, + "trunc": False, # searching - 'where': False, + "where": False, } unary_op_to_symbol = { - '__invert__': '~', - '__neg__': '-', - '__pos__': '+', + "__invert__": "~", + "__neg__": "-", + "__pos__": "+", } binary_op_to_symbol = { - '__add__': '+', - '__and__': '&', - '__eq__': '==', - '__floordiv__': '//', - '__ge__': '>=', - '__gt__': '>', - '__le__': '<=', - '__lshift__': '<<', - '__lt__': '<', - '__matmul__': '@', - '__mod__': '%', - '__mul__': '*', - '__ne__': '!=', - '__or__': '|', - '__pow__': '**', - '__rshift__': '>>', - '__sub__': '-', - '__truediv__': '/', - '__xor__': '^', + "__add__": "+", + "__and__": "&", + "__eq__": "==", + "__floordiv__": "//", + "__ge__": ">=", + "__gt__": ">", + "__le__": "<=", + "__lshift__": "<<", + "__lt__": "<", + "__matmul__": "@", + "__mod__": "%", + "__mul__": "*", + "__ne__": "!=", + "__or__": "|", + "__pow__": "**", + "__rshift__": ">>", + "__sub__": "-", + "__truediv__": "/", + "__xor__": "^", } op_to_func = { - '__abs__': 'abs', - '__add__': 'add', - '__and__': 'bitwise_and', - '__eq__': 'equal', - '__floordiv__': 'floor_divide', - '__ge__': 'greater_equal', - '__gt__': 'greater', - '__le__': 'less_equal', - '__lt__': 'less', + "__abs__": "abs", + "__add__": "add", + "__and__": "bitwise_and", + "__eq__": "equal", + "__floordiv__": "floor_divide", + "__ge__": "greater_equal", + "__gt__": "greater", + "__le__": "less_equal", + "__lt__": "less", # '__matmul__': 'matmul', # TODO: support matmul - '__mod__': 'remainder', - '__mul__': 'multiply', - '__ne__': 'not_equal', - '__or__': 'bitwise_or', - '__pow__': 'pow', - '__lshift__': 'bitwise_left_shift', - '__rshift__': 'bitwise_right_shift', - '__sub__': 'subtract', - '__truediv__': 'divide', - '__xor__': 'bitwise_xor', - '__invert__': 'bitwise_invert', - '__neg__': 'negative', - '__pos__': 'positive', + "__mod__": "remainder", + "__mul__": "multiply", + "__ne__": "not_equal", + "__or__": "bitwise_or", + "__pow__": "pow", + "__lshift__": "bitwise_left_shift", + "__rshift__": "bitwise_right_shift", + "__sub__": "subtract", + "__truediv__": "divide", + "__xor__": "bitwise_xor", + "__invert__": "bitwise_invert", + "__neg__": "negative", + "__pos__": "positive", } @@ -377,10 +380,10 @@ def result_type(*dtypes: DataType): inplace_op_to_symbol = {} for op, symbol in binary_op_to_symbol.items(): - if op == '__matmul__' or func_returns_bool[op]: + if op == "__matmul__" or func_returns_bool[op]: continue - iop = f'__i{op[2:]}' - inplace_op_to_symbol[iop] = f'{symbol}=' + iop = f"__i{op[2:]}" + inplace_op_to_symbol[iop] = f"{symbol}=" func_in_dtypes[iop] = func_in_dtypes[op] func_returns_bool[iop] = func_returns_bool[op] @@ -394,4 +397,4 @@ def fmt_types(types: Tuple[Union[DataType, ScalarType], ...]) -> str: except KeyError: # i.e. dtype is bool, int, or float f_types.append(type_.__name__) - return ', '.join(f_types) + return ", ".join(f_types) diff --git a/array_api_tests/meta/test_utils.py b/array_api_tests/meta/test_utils.py index 7dfafc5b..34fa1836 100644 --- a/array_api_tests/meta/test_utils.py +++ b/array_api_tests/meta/test_utils.py @@ -1,7 +1,10 @@ import pytest -from ..test_signatures import extension_module +from .. import array_helpers as ah from ..test_creation_functions import frange +from ..test_manipulation_functions import axis_ndindex +from ..test_signatures import extension_module +from ..test_statistical_functions import axes_ndindex def test_extension_module_is_extension(): @@ -24,3 +27,45 @@ def test_extension_func_is_not_extension(): def test_frange(r, size, elements): assert len(r) == size assert list(r) == elements + + +@pytest.mark.parametrize( + "shape, expected", + [((), [()])], +) +def test_ndindex(shape, expected): + assert list(ah.ndindex(shape)) == expected + + +@pytest.mark.parametrize( + "shape, axis, expected", + [ + ((1,), 0, [(slice(None, None),)]), + ((1, 2), 0, [(slice(None, None), slice(None, None))]), + ( + (2, 4), + 1, + [(0, slice(None, None)), (1, slice(None, None))], + ), + ], +) +def test_axis_ndindex(shape, axis, expected): + assert list(axis_ndindex(shape, axis)) == expected + + +@pytest.mark.parametrize( + "shape, axes, expected", + [ + ((), (), [((),)]), + ( + (2, 2), + (0,), + [ + ((0, 0), (1, 0)), + ((0, 1), (1, 1)), + ], + ), + ], +) +def test_axes_ndindex(shape, axes, expected): + assert list(axes_ndindex(shape, axes)) == expected diff --git a/array_api_tests/pytest_helpers.py b/array_api_tests/pytest_helpers.py index 9424ba35..b138af3e 100644 --- a/array_api_tests/pytest_helpers.py +++ b/array_api_tests/pytest_helpers.py @@ -67,12 +67,14 @@ def fmt_kw(kw: Dict[str, Any]) -> str: def assert_dtype( func_name: str, - in_dtypes: Tuple[DataType, ...], + in_dtypes: Union[DataType, Tuple[DataType, ...]], out_dtype: DataType, expected: Optional[DataType] = None, *, repr_name: str = "out.dtype", ): + if not isinstance(in_dtypes, tuple): + in_dtypes = (in_dtypes,) f_in_dtypes = dh.fmt_types(in_dtypes) f_out_dtype = dh.dtype_to_name[out_dtype] if expected is None: @@ -149,9 +151,7 @@ def assert_result_shape( f_sig = f" {f_in_shapes} " if kw: f_sig += f", {fmt_kw(kw)}" - msg = ( - f"{repr_name}={out_shape}, but should be {expected} [{func_name}({f_sig})]" - ) + msg = f"{repr_name}={out_shape}, but should be {expected} [{func_name}({f_sig})]" assert out_shape == expected, msg diff --git a/array_api_tests/test_manipulation_functions.py b/array_api_tests/test_manipulation_functions.py new file mode 100644 index 00000000..9453fc25 --- /dev/null +++ b/array_api_tests/test_manipulation_functions.py @@ -0,0 +1,383 @@ +import math +from collections import deque +from itertools import product +from typing import Iterable, Iterator, Tuple, Union + +import pytest +from hypothesis import assume, given +from hypothesis import strategies as st + +from . import _array_module as xp +from . import array_helpers as ah +from . import dtype_helpers as dh +from . import hypothesis_helpers as hh +from . import pytest_helpers as ph +from . import xps +from .test_statistical_functions import axes_ndindex, normalise_axis +from .typing import Array, Shape + +MAX_SIDE = hh.MAX_ARRAY_SIZE // 64 +MAX_DIMS = min(hh.MAX_ARRAY_SIZE // MAX_SIDE, 32) # NumPy only supports up to 32 dims + + +def shared_shapes(*args, **kwargs) -> st.SearchStrategy[Shape]: + key = "shape" + if args: + key += " " + " ".join(args) + if kwargs: + key += " " + ph.fmt_kw(kwargs) + return st.shared(hh.shapes(*args, **kwargs), key="shape") + + +def axis_ndindex( + shape: Shape, axis: int +) -> Iterator[Tuple[Tuple[Union[int, slice], ...], ...]]: + """Generate indices that index all elements in dimensions beyond `axis`""" + assert axis >= 0 # sanity check + axis_indices = [range(side) for side in shape[:axis]] + for _ in range(axis, len(shape)): + axis_indices.append([slice(None, None)]) + yield from product(*axis_indices) + + +def assert_array_ndindex( + func_name: str, + x: Array, + x_indices: Iterable[Union[int, Shape]], + out: Array, + out_indices: Iterable[Union[int, Shape]], +): + msg_suffix = f" [{func_name}()]\n {x=}\n{out=}" + for x_idx, out_idx in zip(x_indices, out_indices): + msg = f"out[{out_idx}]={out[out_idx]}, should be x[{x_idx}]={x[x_idx]}" + msg += msg_suffix + if dh.is_float_dtype(x.dtype) and xp.isnan(x[x_idx]): + assert xp.isnan(out[out_idx]), msg + else: + assert out[out_idx] == x[x_idx], msg + + +def assert_equals( + func_name: str, x_repr: str, x_val: Array, out_repr: str, out_val: Array, **kw +): + msg = ( + f"{out_repr}={out_val}, should be {x_repr}={x_val} " + f"[{func_name}({ph.fmt_kw(kw)})]" + ) + if dh.is_float_dtype(out_val.dtype) and xp.isnan(out_val): + assert xp.isnan(x_val), msg + else: + assert x_val == out_val, msg + + +@st.composite +def concat_shapes(draw, shape, axis): + shape = list(shape) + shape[axis] = draw(st.integers(1, MAX_SIDE)) + return tuple(shape) + + +@given( + dtypes=hh.mutually_promotable_dtypes(None, dtypes=dh.numeric_dtypes), + kw=hh.kwargs(axis=st.none() | st.integers(-MAX_DIMS, MAX_DIMS - 1)), + data=st.data(), +) +def test_concat(dtypes, kw, data): + axis = kw.get("axis", 0) + if axis is None: + shape_strat = hh.shapes() + else: + _axis = axis if axis >= 0 else abs(axis) - 1 + shape_strat = shared_shapes(min_dims=_axis + 1).flatmap( + lambda s: concat_shapes(s, axis) + ) + arrays = [] + for i, dtype in enumerate(dtypes, 1): + x = data.draw(xps.arrays(dtype=dtype, shape=shape_strat), label=f"x{i}") + arrays.append(x) + + out = xp.concat(arrays, **kw) + + ph.assert_dtype("concat", dtypes, out.dtype) + + shapes = tuple(x.shape for x in arrays) + axis = kw.get("axis", 0) + if axis is None: + size = sum(math.prod(s) for s in shapes) + shape = (size,) + else: + shape = list(shapes[0]) + for other_shape in shapes[1:]: + shape[axis] += other_shape[axis] + shape = tuple(shape) + ph.assert_result_shape("concat", shapes, out.shape, shape, **kw) + + if axis is None: + out_indices = (i for i in range(out.size)) + for x_num, x in enumerate(arrays, 1): + for x_idx in ah.ndindex(x.shape): + out_i = next(out_indices) + assert_equals( + "concat", + f"x{x_num}[{x_idx}]", + x[x_idx], + f"out[{out_i}]", + out[out_i], + **kw, + ) + else: + out_indices = ah.ndindex(out.shape) + for idx in axis_ndindex(shapes[0], _axis): + f_idx = ", ".join(str(i) if isinstance(i, int) else ":" for i in idx) + for x_num, x in enumerate(arrays, 1): + indexed_x = x[idx] + for x_idx in ah.ndindex(indexed_x.shape): + out_idx = next(out_indices) + assert_equals( + "concat", + f"x{x_num}[{f_idx}][{x_idx}]", + indexed_x[x_idx], + f"out[{out_idx}]", + out[out_idx], + **kw, + ) + + +@given( + x=xps.arrays(dtype=xps.scalar_dtypes(), shape=shared_shapes()), + axis=shared_shapes().flatmap( + # Generate both valid and invalid axis + lambda s: st.integers(2 * (-len(s) - 1), 2 * len(s)) + ), +) +def test_expand_dims(x, axis): + if axis < -x.ndim - 1 or axis > x.ndim: + with pytest.raises(IndexError): + xp.expand_dims(x, axis=axis) + return + + out = xp.expand_dims(x, axis=axis) + + ph.assert_dtype("expand_dims", x.dtype, out.dtype) + + shape = [side for side in x.shape] + index = axis if axis >= 0 else x.ndim + axis + 1 + shape.insert(index, 1) + shape = tuple(shape) + ph.assert_result_shape("expand_dims", (x.shape,), out.shape, shape) + + assert_array_ndindex( + "expand_dims", x, ah.ndindex(x.shape), out, ah.ndindex(out.shape) + ) + + +@given( + x=xps.arrays( + dtype=xps.scalar_dtypes(), shape=hh.shapes(min_side=1).filter(lambda s: 1 in s) + ), + data=st.data(), +) +def test_squeeze(x, data): + axes = st.integers(-x.ndim, x.ndim - 1) + axis = data.draw( + axes + | st.lists(axes, unique_by=lambda i: i if i >= 0 else i + x.ndim).map(tuple), + label="axis", + ) + + axes = (axis,) if isinstance(axis, int) else axis + axes = normalise_axis(axes, x.ndim) + + squeezable_axes = [i for i, side in enumerate(x.shape) if side == 1] + if any(i not in squeezable_axes for i in axes): + with pytest.raises(ValueError): + xp.squeeze(x, axis) + return + + out = xp.squeeze(x, axis) + + ph.assert_dtype("squeeze", x.dtype, out.dtype) + + shape = [] + for i, side in enumerate(x.shape): + if i not in axes: + shape.append(side) + shape = tuple(shape) + ph.assert_result_shape("squeeze", (x.shape,), out.shape, shape, axis=axis) + + assert_array_ndindex("squeeze", x, ah.ndindex(x.shape), out, ah.ndindex(out.shape)) + + +@given( + x=xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes()), + data=st.data(), +) +def test_flip(x, data): + if x.ndim == 0: + axis_strat = st.none() + else: + axis_strat = ( + st.none() | st.integers(-x.ndim, x.ndim - 1) | xps.valid_tuple_axes(x.ndim) + ) + kw = data.draw(hh.kwargs(axis=axis_strat), label="kw") + + out = xp.flip(x, **kw) + + ph.assert_dtype("flip", x.dtype, out.dtype) + + _axes = normalise_axis(kw.get("axis", None), x.ndim) + for indices in axes_ndindex(x.shape, _axes): + reverse_indices = indices[::-1] + assert_array_ndindex("flip", x, indices, out, reverse_indices) + + +@given( + x=xps.arrays(dtype=xps.scalar_dtypes(), shape=shared_shapes(min_dims=1)), + axes=shared_shapes(min_dims=1).flatmap( + lambda s: st.lists( + st.integers(0, len(s) - 1), + min_size=len(s), + max_size=len(s), + unique=True, + ).map(tuple) + ), +) +def test_permute_dims(x, axes): + out = xp.permute_dims(x, axes) + + ph.assert_dtype("permute_dims", x.dtype, out.dtype) + + shape = [None for _ in range(len(axes))] + for i, dim in enumerate(axes): + side = x.shape[dim] + shape[i] = side + shape = tuple(shape) + ph.assert_result_shape("permute_dims", (x.shape,), out.shape, shape, axes=axes) + + indices = list(ah.ndindex(x.shape)) + permuted_indices = [tuple(idx[axis] for axis in axes) for idx in indices] + assert_array_ndindex("permute_dims", x, indices, out, permuted_indices) + + +@st.composite +def reshape_shapes(draw, shape): + size = 1 if len(shape) == 0 else math.prod(shape) + rshape = draw(st.lists(st.integers(0)).filter(lambda s: math.prod(s) == size)) + assume(all(side <= MAX_SIDE for side in rshape)) + if len(rshape) != 0 and size > 0 and draw(st.booleans()): + index = draw(st.integers(0, len(rshape) - 1)) + rshape[index] = -1 + return tuple(rshape) + + +@given( + x=xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes(max_side=MAX_SIDE)), + data=st.data(), +) +def test_reshape(x, data): + shape = data.draw(reshape_shapes(x.shape)) + + out = xp.reshape(x, shape) + + ph.assert_dtype("reshape", x.dtype, out.dtype) + + _shape = list(shape) + if any(side == -1 for side in shape): + size = math.prod(x.shape) + rsize = math.prod(shape) * -1 + _shape[shape.index(-1)] = size / rsize + _shape = tuple(_shape) + ph.assert_result_shape("reshape", (x.shape,), out.shape, _shape, shape=shape) + + assert_array_ndindex("reshape", x, ah.ndindex(x.shape), out, ah.ndindex(out.shape)) + + +@pytest.mark.skip(reason="faulty test logic") # TODO +@given(xps.arrays(dtype=xps.scalar_dtypes(), shape=shared_shapes()), st.data()) +def test_roll(x, data): + shift_strat = st.integers(-hh.MAX_ARRAY_SIZE, hh.MAX_ARRAY_SIZE) + if x.ndim > 0: + shift_strat = shift_strat | st.lists( + shift_strat, min_size=1, max_size=x.ndim + ).map(tuple) + shift = data.draw(shift_strat, label="shift") + if isinstance(shift, tuple): + axis_strat = xps.valid_tuple_axes(x.ndim).filter(lambda t: len(t) == len(shift)) + kw_strat = axis_strat.map(lambda t: {"axis": t}) + else: + axis_strat = st.none() + if x.ndim != 0: + axis_strat = axis_strat | st.integers(-x.ndim, x.ndim - 1) + kw_strat = hh.kwargs(axis=axis_strat) + kw = data.draw(kw_strat, label="kw") + + out = xp.roll(x, shift, **kw) + + ph.assert_dtype("roll", x.dtype, out.dtype) + + ph.assert_result_shape("roll", (x.shape,), out.shape) + + if kw.get("axis", None) is None: + assert isinstance(shift, int) # sanity check + indices = list(ah.ndindex(x.shape)) + shifted_indices = deque(indices) + shifted_indices.rotate(-shift) + assert_array_ndindex("roll", x, indices, out, shifted_indices) + else: + _shift = (shift,) if isinstance(shift, int) else shift + axes = normalise_axis(kw["axis"], x.ndim) + all_indices = list(ah.ndindex(x.shape)) + for s, a in zip(_shift, axes): + side = x.shape[a] + for i in range(side): + indices = [idx for idx in all_indices if idx[a] == i] + shifted_indices = deque(indices) + shifted_indices.rotate(-s) + assert_array_ndindex("roll", x, indices, out, shifted_indices) + + +@given( + shape=shared_shapes(min_dims=1), + dtypes=hh.mutually_promotable_dtypes(None), + kw=hh.kwargs( + axis=shared_shapes(min_dims=1).flatmap( + lambda s: st.integers(-len(s), len(s) - 1) + ) + ), + data=st.data(), +) +def test_stack(shape, dtypes, kw, data): + arrays = [] + for i, dtype in enumerate(dtypes, 1): + x = data.draw(xps.arrays(dtype=dtype, shape=shape), label=f"x{i}") + arrays.append(x) + + out = xp.stack(arrays, **kw) + + ph.assert_dtype("stack", dtypes, out.dtype) + + axis = kw.get("axis", 0) + _axis = axis if axis >= 0 else len(shape) + axis + 1 + _shape = list(shape) + _shape.insert(_axis, len(arrays)) + _shape = tuple(_shape) + ph.assert_result_shape( + "stack", tuple(x.shape for x in arrays), out.shape, _shape, **kw + ) + + out_indices = ah.ndindex(out.shape) + for idx in axis_ndindex(arrays[0].shape, axis=_axis): + f_idx = ", ".join(str(i) if isinstance(i, int) else ":" for i in idx) + print(f"{f_idx=}") + for x_num, x in enumerate(arrays, 1): + indexed_x = x[idx] + for x_idx in ah.ndindex(indexed_x.shape): + out_idx = next(out_indices) + assert_equals( + "stack", + f"x{x_num}[{f_idx}][{x_idx}]", + indexed_x[x_idx], + f"out[{out_idx}]", + out[out_idx], + **kw, + ) diff --git a/array_api_tests/test_searching_functions.py b/array_api_tests/test_searching_functions.py new file mode 100644 index 00000000..c3686bb7 --- /dev/null +++ b/array_api_tests/test_searching_functions.py @@ -0,0 +1,41 @@ +from hypothesis import given +from hypothesis import strategies as st + +from . import _array_module as xp +from . import hypothesis_helpers as hh +from . import xps + + +# TODO: generate kwargs +@given(xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes(min_side=1))) +def test_argmin(x): + xp.argmin(x) + # TODO + + +# TODO: generate kwargs +@given(xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes(min_side=1))) +def test_argmax(x): + xp.argmax(x) + # TODO + + +# TODO: generate kwargs, skip if opted out +@given(xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes(min_side=1))) +def test_nonzero(x): + xp.nonzero(x) + # TODO + + +# TODO: skip if opted out +@given( + shapes=hh.mutually_broadcastable_shapes(3), + dtypes=hh.mutually_promotable_dtypes(), + data=st.data(), +) +def test_where(shapes, dtypes, data): + cond = data.draw(xps.arrays(dtype=xp.bool, shape=shapes[0]), label="condition") + x1 = data.draw(xps.arrays(dtype=dtypes[0], shape=shapes[1]), label="x1") + x2 = data.draw(xps.arrays(dtype=dtypes[1], shape=shapes[2]), label="x2") + xp.where(cond, x1, x2) + # TODO diff --git a/array_api_tests/test_set_functions.py b/array_api_tests/test_set_functions.py new file mode 100644 index 00000000..856a7282 --- /dev/null +++ b/array_api_tests/test_set_functions.py @@ -0,0 +1,29 @@ +from hypothesis import given + +from . import _array_module as xp +from . import hypothesis_helpers as hh +from . import xps + + +@given(xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes())) +def test_unique_all(x): + xp.unique_all(x) + # TODO + + +@given(xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes())) +def test_unique_counts(x): + xp.unique_counts(x) + # TODO + + +@given(xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes())) +def test_unique_inverse(x): + xp.unique_inverse(x) + # TODO + + +@given(xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes())) +def test_unique_values(x): + xp.unique_values(x) + # TODO diff --git a/array_api_tests/test_sorting.py b/array_api_tests/test_sorting.py new file mode 100644 index 00000000..58179b3c --- /dev/null +++ b/array_api_tests/test_sorting.py @@ -0,0 +1,19 @@ +from hypothesis import given + +from . import _array_module as xp +from . import hypothesis_helpers as hh +from . import xps + + +# TODO: generate kwargs +@given(xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes())) +def test_argsort(x): + xp.argsort(x) + # TODO + + +# TODO: generate 0d arrays, generate kwargs +@given(xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes(min_dims=1))) +def test_sort(x): + xp.sort(x) + # TODO diff --git a/array_api_tests/test_statistical_functions.py b/array_api_tests/test_statistical_functions.py new file mode 100644 index 00000000..7813d2ea --- /dev/null +++ b/array_api_tests/test_statistical_functions.py @@ -0,0 +1,382 @@ +import math +from itertools import product +from typing import Iterator, Optional, Tuple, Union + +from hypothesis import assume, given +from hypothesis import strategies as st +from hypothesis.control import reject + +from . import _array_module as xp +from . import array_helpers as ah +from . import dtype_helpers as dh +from . import hypothesis_helpers as hh +from . import pytest_helpers as ph +from . import xps +from .typing import DataType, Scalar, ScalarType, Shape + + +def axes(ndim: int) -> st.SearchStrategy[Optional[Union[int, Shape]]]: + axes_strats = [st.none()] + if ndim != 0: + axes_strats.append(st.integers(-ndim, ndim - 1)) + axes_strats.append(xps.valid_tuple_axes(ndim)) + return st.one_of(axes_strats) + + +def kwarg_dtypes(dtype: DataType) -> st.SearchStrategy[Optional[DataType]]: + dtypes = [d2 for d1, d2 in dh.promotion_table if d1 == dtype] + return st.none() | st.sampled_from(dtypes) + + +def normalise_axis( + axis: Optional[Union[int, Tuple[int, ...]]], ndim: int +) -> Tuple[int, ...]: + if axis is None: + return tuple(range(ndim)) + axes = axis if isinstance(axis, tuple) else (axis,) + axes = tuple(axis if axis >= 0 else ndim + axis for axis in axes) + return axes + + +def axes_ndindex(shape: Shape, axes: Tuple[int, ...]) -> Iterator[Tuple[Shape, ...]]: + """Generate indices that index all elements except in `axes` dimensions""" + base_indices = [] + axes_indices = [] + for axis, side in enumerate(shape): + if axis in axes: + base_indices.append([None]) + axes_indices.append(range(side)) + else: + base_indices.append(range(side)) + axes_indices.append([None]) + for base_idx in product(*base_indices): + indices = [] + for idx in product(*axes_indices): + idx = list(idx) + for axis, side in enumerate(idx): + if axis not in axes: + idx[axis] = base_idx[axis] + idx = tuple(idx) + indices.append(idx) + yield tuple(indices) + + +def assert_keepdimable_shape( + func_name: str, + out_shape: Shape, + in_shape: Shape, + axes: Tuple[int, ...], + keepdims: bool, + /, + **kw, +): + if keepdims: + shape = tuple(1 if axis in axes else side for axis, side in enumerate(in_shape)) + else: + shape = tuple(side for axis, side in enumerate(in_shape) if axis not in axes) + ph.assert_shape(func_name, out_shape, shape, **kw) + + +def assert_equals( + func_name: str, + type_: ScalarType, + idx: Shape, + out: Scalar, + expected: Scalar, + /, + **kw, +): + out_repr = "out" if idx == () else f"out[{idx}]" + f_func = f"{func_name}({ph.fmt_kw(kw)})" + if type_ is bool or type_ is int: + msg = f"{out_repr}={out}, should be {expected} [{f_func}]" + assert out == expected, msg + elif math.isnan(expected): + msg = f"{out_repr}={out}, should be {expected} [{f_func}]" + assert math.isnan(out), msg + else: + msg = f"{out_repr}={out}, should be roughly {expected} [{f_func}]" + assert math.isclose(out, expected, rel_tol=0.25, abs_tol=1), msg + + +@given( + x=xps.arrays( + dtype=xps.numeric_dtypes(), + shape=hh.shapes(min_side=1), + elements={"allow_nan": False}, + ), + data=st.data(), +) +def test_max(x, data): + kw = data.draw(hh.kwargs(axis=axes(x.ndim), keepdims=st.booleans()), label="kw") + + out = xp.max(x, **kw) + + ph.assert_dtype("max", x.dtype, out.dtype) + _axes = normalise_axis(kw.get("axis", None), x.ndim) + assert_keepdimable_shape( + "max", out.shape, x.shape, _axes, kw.get("keepdims", False), **kw + ) + scalar_type = dh.get_scalar_type(out.dtype) + for indices, out_idx in zip(axes_ndindex(x.shape, _axes), ah.ndindex(out.shape)): + max_ = scalar_type(out[out_idx]) + elements = [] + for idx in indices: + s = scalar_type(x[idx]) + elements.append(s) + expected = max(elements) + assert_equals("max", scalar_type, out_idx, max_, expected) + + +@given( + x=xps.arrays( + dtype=xps.floating_dtypes(), + shape=hh.shapes(min_side=1), + elements={"allow_nan": False}, + ), + data=st.data(), +) +def test_mean(x, data): + kw = data.draw(hh.kwargs(axis=axes(x.ndim), keepdims=st.booleans()), label="kw") + + out = xp.mean(x, **kw) + + ph.assert_dtype("mean", x.dtype, out.dtype) + _axes = normalise_axis(kw.get("axis", None), x.ndim) + assert_keepdimable_shape( + "mean", out.shape, x.shape, _axes, kw.get("keepdims", False), **kw + ) + for indices, out_idx in zip(axes_ndindex(x.shape, _axes), ah.ndindex(out.shape)): + mean = float(out[out_idx]) + assume(not math.isinf(mean)) # mean may become inf due to internal overflows + elements = [] + for idx in indices: + s = float(x[idx]) + elements.append(s) + expected = sum(elements) / len(elements) + assert_equals("mean", float, out_idx, mean, expected) + + +@given( + x=xps.arrays( + dtype=xps.numeric_dtypes(), + shape=hh.shapes(min_side=1), + elements={"allow_nan": False}, + ), + data=st.data(), +) +def test_min(x, data): + kw = data.draw(hh.kwargs(axis=axes(x.ndim), keepdims=st.booleans()), label="kw") + + out = xp.min(x, **kw) + + ph.assert_dtype("min", x.dtype, out.dtype) + _axes = normalise_axis(kw.get("axis", None), x.ndim) + assert_keepdimable_shape( + "min", out.shape, x.shape, _axes, kw.get("keepdims", False), **kw + ) + scalar_type = dh.get_scalar_type(out.dtype) + for indices, out_idx in zip(axes_ndindex(x.shape, _axes), ah.ndindex(out.shape)): + min_ = scalar_type(out[out_idx]) + elements = [] + for idx in indices: + s = scalar_type(x[idx]) + elements.append(s) + expected = min(elements) + assert_equals("min", scalar_type, out_idx, min_, expected) + + +@given( + x=xps.arrays( + dtype=xps.numeric_dtypes(), + shape=hh.shapes(min_side=1), + elements={"allow_nan": False}, + ), + data=st.data(), +) +def test_prod(x, data): + kw = data.draw( + hh.kwargs( + axis=axes(x.ndim), + dtype=kwarg_dtypes(x.dtype), + keepdims=st.booleans(), + ), + label="kw", + ) + + try: + out = xp.prod(x, **kw) + except OverflowError: + reject() + + dtype = kw.get("dtype", None) + if dtype is None: + if dh.is_int_dtype(x.dtype): + if x.dtype in dh.uint_dtypes: + default_dtype = dh.default_uint + else: + default_dtype = dh.default_int + m, M = dh.dtype_ranges[x.dtype] + d_m, d_M = dh.dtype_ranges[default_dtype] + if m < d_m or M > d_M: + _dtype = x.dtype + else: + _dtype = default_dtype + else: + if dh.dtype_nbits[x.dtype] > dh.dtype_nbits[dh.default_float]: + _dtype = x.dtype + else: + _dtype = dh.default_float + else: + _dtype = dtype + ph.assert_dtype("prod", x.dtype, out.dtype, _dtype) + _axes = normalise_axis(kw.get("axis", None), x.ndim) + assert_keepdimable_shape( + "prod", out.shape, x.shape, _axes, kw.get("keepdims", False), **kw + ) + scalar_type = dh.get_scalar_type(out.dtype) + for indices, out_idx in zip(axes_ndindex(x.shape, _axes), ah.ndindex(out.shape)): + prod = scalar_type(out[out_idx]) + assume(math.isfinite(prod)) + elements = [] + for idx in indices: + s = scalar_type(x[idx]) + elements.append(s) + expected = math.prod(elements) + if dh.is_int_dtype(out.dtype): + m, M = dh.dtype_ranges[out.dtype] + assume(m <= expected <= M) + assert_equals("prod", scalar_type, out_idx, prod, expected) + + +@given( + x=xps.arrays( + dtype=xps.floating_dtypes(), + shape=hh.shapes(min_side=1), + elements={"allow_nan": False}, + ).filter(lambda x: x.size >= 2), + data=st.data(), +) +def test_std(x, data): + axis = data.draw(axes(x.ndim), label="axis") + _axes = normalise_axis(axis, x.ndim) + N = sum(side for axis, side in enumerate(x.shape) if axis not in _axes) + correction = data.draw( + st.floats(0.0, N, allow_infinity=False, allow_nan=False) | st.integers(0, N), + label="correction", + ) + keepdims = data.draw(st.booleans(), label="keepdims") + kw = data.draw( + hh.specified_kwargs( + ("axis", axis, None), + ("correction", correction, 0.0), + ("keepdims", keepdims, False), + ), + label="kw", + ) + + out = xp.std(x, **kw) + + ph.assert_dtype("std", x.dtype, out.dtype) + assert_keepdimable_shape( + "std", out.shape, x.shape, _axes, kw.get("keepdims", False), **kw + ) + # We can't easily test the result(s) as standard deviation methods vary a lot + + +@given( + x=xps.arrays( + dtype=xps.numeric_dtypes(), + shape=hh.shapes(min_side=1), + elements={"allow_nan": False}, + ), + data=st.data(), +) +def test_sum(x, data): + kw = data.draw( + hh.kwargs( + axis=axes(x.ndim), + dtype=kwarg_dtypes(x.dtype), + keepdims=st.booleans(), + ), + label="kw", + ) + + try: + out = xp.sum(x, **kw) + except OverflowError: + reject() + + dtype = kw.get("dtype", None) + if dtype is None: + if dh.is_int_dtype(x.dtype): + if x.dtype in dh.uint_dtypes: + default_dtype = dh.default_uint + else: + default_dtype = dh.default_int + m, M = dh.dtype_ranges[x.dtype] + d_m, d_M = dh.dtype_ranges[default_dtype] + if m < d_m or M > d_M: + _dtype = x.dtype + else: + _dtype = default_dtype + else: + if dh.dtype_nbits[x.dtype] > dh.dtype_nbits[dh.default_float]: + _dtype = x.dtype + else: + _dtype = dh.default_float + else: + _dtype = dtype + ph.assert_dtype("sum", x.dtype, out.dtype, _dtype) + _axes = normalise_axis(kw.get("axis", None), x.ndim) + assert_keepdimable_shape( + "sum", out.shape, x.shape, _axes, kw.get("keepdims", False), **kw + ) + scalar_type = dh.get_scalar_type(out.dtype) + for indices, out_idx in zip(axes_ndindex(x.shape, _axes), ah.ndindex(out.shape)): + sum_ = scalar_type(out[out_idx]) + assume(math.isfinite(sum_)) + elements = [] + for idx in indices: + s = scalar_type(x[idx]) + elements.append(s) + expected = sum(elements) + if dh.is_int_dtype(out.dtype): + m, M = dh.dtype_ranges[out.dtype] + assume(m <= expected <= M) + assert_equals("sum", scalar_type, out_idx, sum_, expected) + + +@given( + x=xps.arrays( + dtype=xps.floating_dtypes(), + shape=hh.shapes(min_side=1), + elements={"allow_nan": False}, + ).filter(lambda x: x.size >= 2), + data=st.data(), +) +def test_var(x, data): + axis = data.draw(axes(x.ndim), label="axis") + _axes = normalise_axis(axis, x.ndim) + N = sum(side for axis, side in enumerate(x.shape) if axis not in _axes) + correction = data.draw( + st.floats(0.0, N, allow_infinity=False, allow_nan=False) | st.integers(0, N), + label="correction", + ) + keepdims = data.draw(st.booleans(), label="keepdims") + kw = data.draw( + hh.specified_kwargs( + ("axis", axis, None), + ("correction", correction, 0.0), + ("keepdims", keepdims, False), + ), + label="kw", + ) + + out = xp.var(x, **kw) + + ph.assert_dtype("var", x.dtype, out.dtype) + assert_keepdimable_shape( + "var", out.shape, x.shape, _axes, kw.get("keepdims", False), **kw + ) + # We can't easily test the result(s) as variance methods vary a lot diff --git a/array_api_tests/test_type_promotion.py b/array_api_tests/test_type_promotion.py index d304071a..2fb669e6 100644 --- a/array_api_tests/test_type_promotion.py +++ b/array_api_tests/test_type_promotion.py @@ -19,7 +19,7 @@ # TODO: move tests not covering elementwise funcs/ops into standalone tests -# result_type, meshgrid, concat, stack, where, tensordor, vecdot +# result_type, meshgrid, tensordor, vecdot @given(hh.mutually_promotable_dtypes(None)) @@ -51,34 +51,6 @@ def test_meshgrid(dtypes, data): ph.assert_dtype("meshgrid", dtypes, x.dtype, repr_name=f"out[{i}].dtype") -@given( - shape=hh.shapes(min_dims=1), - dtypes=hh.mutually_promotable_dtypes(None, dtypes=dh.numeric_dtypes), - data=st.data(), -) -def test_concat(shape, dtypes, data): - arrays = [] - for i, dtype in enumerate(dtypes, 1): - x = data.draw(xps.arrays(dtype=dtype, shape=shape), label=f"x{i}") - arrays.append(x) - out = xp.concat(arrays) - ph.assert_dtype("concat", dtypes, out.dtype) - - -@given( - shape=hh.shapes(), - dtypes=hh.mutually_promotable_dtypes(None), - data=st.data(), -) -def test_stack(shape, dtypes, data): - arrays = [] - for i, dtype in enumerate(dtypes, 1): - x = data.draw(xps.arrays(dtype=dtype, shape=shape), label=f"x{i}") - arrays.append(x) - out = xp.stack(arrays) - ph.assert_dtype("stack", dtypes, out.dtype) - - bitwise_shift_funcs = [ "bitwise_left_shift", "bitwise_right_shift", diff --git a/array_api_tests/test_utility_functions.py b/array_api_tests/test_utility_functions.py new file mode 100644 index 00000000..140aa85f --- /dev/null +++ b/array_api_tests/test_utility_functions.py @@ -0,0 +1,19 @@ +from hypothesis import given + +from . import _array_module as xp +from . import hypothesis_helpers as hh +from . import xps + + +# TODO: generate kwargs +@given(xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes())) +def test_any(x): + xp.any(x) + # TODO + + +# TODO: generate kwargs +@given(xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes())) +def test_all(x): + xp.all(x) + # TODO