diff --git a/array_api_tests/hypothesis_helpers.py b/array_api_tests/hypothesis_helpers.py index 2ac923ad..aed4f1a2 100644 --- a/array_api_tests/hypothesis_helpers.py +++ b/array_api_tests/hypothesis_helpers.py @@ -1,25 +1,24 @@ +import itertools from functools import reduce -from operator import mul from math import sqrt -import itertools -from typing import Tuple, Optional, List +from operator import mul +from typing import Any, List, NamedTuple, Optional, Tuple from hypothesis import assume -from hypothesis.strategies import (lists, integers, sampled_from, - shared, floats, just, composite, one_of, - none, booleans, SearchStrategy) +from hypothesis.strategies import (SearchStrategy, booleans, composite, floats, + integers, just, lists, none, one_of, + sampled_from, shared) -from .pytest_helpers import nargs -from .array_helpers import ndindex -from .typing import DataType, Shape -from . import dtype_helpers as dh -from ._array_module import (full, float32, float64, bool as bool_dtype, - _UndefinedStub, eye, broadcast_to) from . import _array_module as xp +from . import dtype_helpers as dh from . import xps - +from ._array_module import _UndefinedStub +from ._array_module import bool as bool_dtype +from ._array_module import broadcast_to, eye, float32, float64, full +from .array_helpers import ndindex from .function_stubs import elementwise_functions - +from .pytest_helpers import nargs +from .typing import DataType, Shape # Set this to True to not fail tests just because a dtype isn't implemented. # If no compatible dtype is implemented for a given test, the test will fail @@ -382,3 +381,24 @@ def test_f(x, kw): if draw(booleans()): result[k] = draw(strat) return result + + +class KVD(NamedTuple): + keyword: str + value: Any + default: Any + + +@composite +def specified_kwargs(draw, *keys_values_defaults: KVD): + """Generates valid kwargs given expected defaults. + + When we can't realistically use hh.kwargs() and thus test whether xp infact + defaults correctly, this strategy lets us remove generated arguments if they + are of the default value anyway. + """ + kw = {} + for keyword, value, default in keys_values_defaults: + if value is not default or draw(booleans()): + kw[keyword] = value + return kw diff --git a/array_api_tests/meta/test_hypothesis_helpers.py b/array_api_tests/meta/test_hypothesis_helpers.py index f583e711..b57db4ed 100644 --- a/array_api_tests/meta/test_hypothesis_helpers.py +++ b/array_api_tests/meta/test_hypothesis_helpers.py @@ -4,6 +4,7 @@ from hypothesis import given, strategies as st, settings from .. import _array_module as xp +from .. import xps from .._array_module import _UndefinedStub from .. import array_helpers as ah from .. import dtype_helpers as dh @@ -76,6 +77,37 @@ def run(kw): assert len(c_results) > 0 assert all(isinstance(kw["c"], str) for kw in c_results) + +def test_specified_kwargs(): + results = [] + + @given(n=st.integers(0, 10), d=st.none() | xps.scalar_dtypes(), data=st.data()) + @settings(max_examples=100) + def run(n, d, data): + kw = data.draw( + hh.specified_kwargs( + hh.KVD("n", n, 0), + hh.KVD("d", d, None), + ), + label="kw", + ) + results.append(kw) + run() + + assert all(isinstance(kw, dict) for kw in results) + + assert any(len(kw) == 0 for kw in results) + + assert any("n" not in kw.keys() for kw in results) + assert any("n" in kw.keys() and kw["n"] == 0 for kw in results) + assert any("n" in kw.keys() and kw["n"] != 0 for kw in results) + + assert any("d" not in kw.keys() for kw in results) + assert any("d" in kw.keys() and kw["d"] is None for kw in results) + assert any("d" in kw.keys() and kw["d"] is xp.float64 for kw in results) + + + @given(m=hh.symmetric_matrices(hh.shared_floating_dtypes, finite=st.shared(st.booleans(), key='finite')), dtype=hh.shared_floating_dtypes, diff --git a/array_api_tests/meta/test_utils.py b/array_api_tests/meta/test_utils.py index f4d7cf1f..7dfafc5b 100644 --- a/array_api_tests/meta/test_utils.py +++ b/array_api_tests/meta/test_utils.py @@ -1,9 +1,26 @@ +import pytest + from ..test_signatures import extension_module +from ..test_creation_functions import frange def test_extension_module_is_extension(): - assert extension_module('linalg') + assert extension_module("linalg") def test_extension_func_is_not_extension(): - assert not extension_module('linalg.cross') + assert not extension_module("linalg.cross") + + +@pytest.mark.parametrize( + "r, size, elements", + [ + (frange(0, 1, 1), 1, [0]), + (frange(1, 0, -1), 1, [1]), + (frange(0, 1, -1), 0, []), + (frange(0, 1, 2), 1, [0]), + ], +) +def test_frange(r, size, elements): + assert len(r) == size + assert list(r) == elements diff --git a/array_api_tests/pytest_helpers.py b/array_api_tests/pytest_helpers.py index 129b6fc3..a9573515 100644 --- a/array_api_tests/pytest_helpers.py +++ b/array_api_tests/pytest_helpers.py @@ -1,12 +1,27 @@ +import math from inspect import getfullargspec -from typing import Optional, Tuple +from typing import Any, Dict, Optional, Tuple, Union +from . import array_helpers as ah from . import dtype_helpers as dh from . import function_stubs -from .typing import DataType +from .typing import Array, DataType, Scalar, Shape +__all__ = [ + "raises", + "doesnt_raise", + "nargs", + "fmt_kw", + "assert_dtype", + "assert_kw_dtype", + "assert_default_float", + "assert_default_int", + "assert_shape", + "assert_fill", +] -def raises(exceptions, function, message=''): + +def raises(exceptions, function, message=""): """ Like pytest.raises() except it allows custom error messages """ @@ -16,11 +31,14 @@ def raises(exceptions, function, message=''): return except Exception as e: if message: - raise AssertionError(f"Unexpected exception {e!r} (expected {exceptions}): {message}") + raise AssertionError( + f"Unexpected exception {e!r} (expected {exceptions}): {message}" + ) raise AssertionError(f"Unexpected exception {e!r} (expected {exceptions})") raise AssertionError(message) -def doesnt_raise(function, message=''): + +def doesnt_raise(function, message=""): """ The inverse of raises(). @@ -36,10 +54,15 @@ def doesnt_raise(function, message=''): raise AssertionError(f"Unexpected exception {e!r}: {message}") raise AssertionError(f"Unexpected exception {e!r}") + def nargs(func_name): return len(getfullargspec(getattr(function_stubs, func_name)).args) +def fmt_kw(kw: Dict[str, Any]) -> str: + return ", ".join(f"{k}={v}" for k, v in kw.items()) + + def assert_dtype( func_name: str, in_dtypes: Tuple[DataType, ...], @@ -60,3 +83,54 @@ def assert_dtype( assert out_dtype == expected, msg +def assert_kw_dtype(func_name: str, kw_dtype: DataType, out_dtype: DataType): + f_kw_dtype = dh.dtype_to_name[kw_dtype] + f_out_dtype = dh.dtype_to_name[out_dtype] + msg = ( + f"out.dtype={f_out_dtype}, but should be {f_kw_dtype} " + f"[{func_name}(dtype={f_kw_dtype})]" + ) + assert out_dtype == kw_dtype, msg + + +def assert_default_float(func_name: str, dtype: DataType): + f_dtype = dh.dtype_to_name[dtype] + f_default = dh.dtype_to_name[dh.default_float] + msg = ( + f"out.dtype={f_dtype}, should be default " + f"floating-point dtype {f_default} [{func_name}()]" + ) + assert dtype == dh.default_float, msg + + +def assert_default_int(func_name: str, dtype: DataType): + f_dtype = dh.dtype_to_name[dtype] + f_default = dh.dtype_to_name[dh.default_int] + msg = ( + f"out.dtype={f_dtype}, should be default " + f"integer dtype {f_default} [{func_name}()]" + ) + assert dtype == dh.default_int, msg + + +def assert_shape( + func_name: str, out_shape: Union[int, Shape], expected: Union[int, Shape], /, **kw +): + if isinstance(out_shape, int): + out_shape = (out_shape,) + if isinstance(expected, int): + expected = (expected,) + msg = ( + f"out.shape={out_shape}, but should be {expected} [{func_name}({fmt_kw(kw)})]" + ) + assert out_shape == expected, msg + + +def assert_fill( + func_name: str, fill_value: Scalar, dtype: DataType, out: Array, /, **kw +): + msg = f"out not filled with {fill_value} [{func_name}({fmt_kw(kw)})]\n{out=}" + if math.isnan(fill_value): + assert ah.all(ah.isnan(out)), msg + else: + assert ah.all(ah.equal(out, ah.asarray(fill_value, dtype=dtype))), msg diff --git a/array_api_tests/test_creation_functions.py b/array_api_tests/test_creation_functions.py index 46587d68..1098f1e1 100644 --- a/array_api_tests/test_creation_functions.py +++ b/array_api_tests/test_creation_functions.py @@ -1,131 +1,238 @@ import math +from itertools import count +from typing import Iterator, NamedTuple, Union + +from hypothesis import assume, given +from hypothesis import strategies as st -from ._array_module import (asarray, arange, ceil, empty, empty_like, eye, full, - full_like, equal, all, linspace, ones, ones_like, - zeros, zeros_like, isnan) from . import _array_module as xp -from .array_helpers import assert_exactly_equal, isintegral -from .hypothesis_helpers import (numeric_dtypes, dtypes, MAX_ARRAY_SIZE, - shapes, sizes, sqrt_sizes, shared_dtypes, - scalars, kwargs) +from . import array_helpers as ah from . import dtype_helpers as dh +from . import hypothesis_helpers as hh +from . import pytest_helpers as ph from . import xps +from .typing import DataType, Scalar + + +class frange(NamedTuple): + start: float + stop: float + step: float + + def __iter__(self) -> Iterator[float]: + pos_range = self.stop > self.start + pos_step = self.step > 0 + if pos_step != pos_range: + return + if pos_range: + for n in count(self.start, self.step): + if n >= self.stop: + break + yield n + else: + for n in count(self.start, self.step): + if n <= self.stop: + break + yield n + + def __len__(self) -> int: + return max(math.ceil((self.stop - self.start) / self.step), 0) + + +# Testing xp.arange() requires bounding the start/stop/step arguments to only +# test argument combinations compliant with the Array API, as well as to not +# produce arrays with sizes not supproted by an array module. +# +# We first make sure generated integers can be represented by an array module's +# default integer type, as even if a float array should be produced a module +# might represent integer arguments as 0d arrays. +# +# This means that float arguments also need to be bound, so that they do not +# require any integer arguments to be outside the representable bounds. +int_min, int_max = dh.dtype_ranges[dh.default_int] +float_min = float(int_min * (hh.MAX_ARRAY_SIZE - 1)) +float_max = float(int_max * (hh.MAX_ARRAY_SIZE - 1)) + + +def reals(min_value=None, max_value=None) -> st.SearchStrategy[Union[int, float]]: + round_ = int + if min_value is not None and min_value > 0: + round_ = math.ceil + elif max_value is not None and max_value < 0: + round_ = math.floor + int_min_value = int_min if min_value is None else max(round_(min_value), int_min) + int_max_value = int_max if max_value is None else min(round_(max_value), int_max) + return st.one_of( + st.integers(int_min_value, int_max_value), + # We do not assign float bounds to the floats() strategy, instead opting + # to filter out-of-bound values. Passing such min/max values will modify + # test case reduction behaviour so that simple bugs will become harder + # for users to identify. Hypothesis plans to improve floats() behaviour + # in https://github.com/HypothesisWorks/hypothesis/issues/2907 + st.floats(min_value, max_value, allow_nan=False, allow_infinity=False).filter( + lambda n: float_min <= n <= float_max + ), + ) + + +@given(dtype=st.none() | hh.numeric_dtypes, data=st.data()) +def test_arange(dtype, data): + if dtype is None or dh.is_float_dtype(dtype): + start = data.draw(reals(), label="start") + stop = data.draw(reals() | st.none(), label="stop") + else: + start = data.draw(xps.from_dtype(dtype), label="start") + stop = data.draw(xps.from_dtype(dtype), label="stop") + if stop is None: + _start = 0 + _stop = start + else: + _start = start + _stop = stop + + # tol is the minimum tolerance for step values, used to avoid scenarios + # where xp.arange() produces arrays that would be over MAX_ARRAY_SIZE. + tol = max(abs(_stop - _start) / (math.sqrt(hh.MAX_ARRAY_SIZE)), 0.01) + assert tol != 0, "tol must not equal 0" # sanity check + assume(-tol > int_min) + assume(tol < int_max) + if dtype is None or dh.is_float_dtype(dtype): + step = data.draw(reals(min_value=tol) | reals(max_value=-tol), label="step") + else: + step_strats = [] + if dtype in dh.int_dtypes: + step_min = min(math.floor(-tol), -1) + step_strats.append(xps.from_dtype(dtype, max_value=step_min)) + step_max = max(math.ceil(tol), 1) + step_strats.append(xps.from_dtype(dtype, min_value=step_max)) + step = data.draw(st.one_of(step_strats), label="step") + assert step != 0, "step must not equal 0" # sanity check -from hypothesis import assume, given -from hypothesis.strategies import integers, floats, one_of, none, booleans, just, shared, composite, SearchStrategy - - -int_range = integers(-MAX_ARRAY_SIZE, MAX_ARRAY_SIZE) -float_range = floats(-MAX_ARRAY_SIZE, MAX_ARRAY_SIZE, - allow_nan=False) -@given(one_of(int_range, float_range), - one_of(none(), int_range, float_range), - one_of(none(), int_range, float_range).filter(lambda x: x != 0 - and (abs(x) > 0.01 if isinstance(x, float) else True)), - one_of(none(), numeric_dtypes)) -def test_arange(start, stop, step, dtype): - if dtype in dh.dtype_ranges: - m, M = dh.dtype_ranges[dtype] - if (not (m <= start <= M) - or isinstance(stop, int) and not (m <= stop <= M) - or isinstance(step, int) and not (m <= step <= M)): - assume(False) - - kwargs = {} if dtype is None else {'dtype': dtype} - - all_int = (dh.is_int_dtype(dtype) - and isinstance(start, int) - and (stop is None or isinstance(stop, int)) - and (step is None or isinstance(step, int))) + all_int = all(arg is None or isinstance(arg, int) for arg in [start, stop, step]) - if stop is None: - # NB: "start" is really the stop - # step is ignored in this case - a = arange(start, **kwargs) - if all_int: - r = range(start) - elif step is None: - a = arange(start, stop, **kwargs) + if dtype is None: if all_int: - r = range(start, stop) + _dtype = dh.default_int + else: + _dtype = dh.default_float else: - a = arange(start, stop, step, **kwargs) - if all_int: - r = range(start, stop, step) + _dtype = dtype + + # sanity checks + if dh.is_int_dtype(_dtype): + m, M = dh.dtype_ranges[_dtype] + assert m <= _start <= M + assert m <= _stop <= M + assert m <= step <= M + + r = frange(_start, _stop, step) + size = len(r) + assert ( + size <= hh.MAX_ARRAY_SIZE + ), f"{size=} should be no more than {hh.MAX_ARRAY_SIZE}" # sanity check + + args_samples = [(start, stop), (start, stop, step)] + if stop is None: + args_samples.insert(0, (start,)) + args = data.draw(st.sampled_from(args_samples), label="args") + kvds = [hh.KVD("dtype", dtype, None)] + if len(args) != 3: + kvds.insert(0, hh.KVD("step", step, 1)) + kwargs = data.draw(hh.specified_kwargs(*kvds), label="kwargs") + + out = xp.arange(*args, **kwargs) + if dtype is None: - # TODO: What is the correct dtype of a? - pass + if all_int: + ph.assert_default_int("arange", out.dtype) + else: + ph.assert_default_float("arange", out.dtype) else: - assert a.dtype == dtype, "arange() produced an incorrect dtype" - assert a.ndim == 1, "arange() should return a 1-dimensional array" - if all_int: - assert a.shape == (len(r),), "arange() produced incorrect shape" - if len(r) <= MAX_ARRAY_SIZE: - assert list(a) == list(r), "arange() produced incorrect values" + ph.assert_dtype("arange", (out.dtype,), dtype) + f_sig = ", ".join(str(n) for n in args) + if len(kwargs) > 0: + f_sig += f", {ph.fmt_kw(kwargs)}" + f_func = f"[arange({f_sig})]" + assert out.ndim == 1, f"{out.ndim=}, but should be 1 [{f_func}]" + # We check size is roughly as expected to avoid edge cases e.g. + # + # >>> xp.arange(2, step=0.333333333333333) + # [0.0, 0.33, 0.66, 1.0, 1.33, 1.66, 2.0] + # >>> xp.arange(2, step=0.3333333333333333) + # [0.0, 0.33, 0.66, 1.0, 1.33, 1.66] + # + # >>> start, stop, step = 0, 108086391056891901, 1080863910568919 + # >>> x = xp.arange(start, stop, step, dtype=xp.uint64) + # >>> x.size + # 100 + # >>> r = range(start, stop, step) + # >>> len(r) + # 101 + # + min_size = math.floor(size * 0.9) + max_size = max(math.ceil(size * 1.1), 1) + assert ( + min_size <= out.size <= max_size + ), f"{out.size=}, but should be roughly {size} {f_func}" + if dh.is_int_dtype(_dtype): + elements = list(r) + assume(out.size == len(elements)) + ah.assert_exactly_equal(out, ah.asarray(elements, dtype=_dtype)) else: - # This is already implied by the len(r) test above - if (stop is not None - and step is not None - and (step > 0 and stop >= start - or step < 0 and stop <= start)): - assert a.size == ceil(asarray((stop-start)/step)), "arange() produced an array of the incorrect size" - -@given(shapes(), kwargs(dtype=none() | shared_dtypes)) + assume(out.size == size) + if out.size > 0: + assert ah.equal( + out[0], ah.asarray(_start, dtype=out.dtype) + ), f"out[0]={out[0]}, but should be {_start} {f_func}" + + +@given(hh.shapes(), hh.kwargs(dtype=st.none() | hh.shared_dtypes)) def test_empty(shape, kw): - out = empty(shape, **kw) - dtype = kw.get("dtype", None) or xp.float64 + out = xp.empty(shape, **kw) if kw.get("dtype", None) is None: - assert dh.is_float_dtype(out.dtype), f"empty() returned an array with dtype {out.dtype}, but should be the default float dtype" + ph.assert_default_float("empty", out.dtype) else: - assert out.dtype == dtype, f"{dtype=!s}, but empty() returned an array with dtype {out.dtype}" - if isinstance(shape, int): - shape = (shape,) - assert out.shape == shape, f"{shape=}, but empty() returned an array with shape {out.shape}" + ph.assert_kw_dtype("empty", kw["dtype"], out.dtype) + ph.assert_shape("empty", out.shape, shape, shape=shape) @given( - x=xps.arrays(dtype=xps.scalar_dtypes(), shape=shapes()), - kw=kwargs(dtype=none() | xps.scalar_dtypes()) + x=xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes()), + kw=hh.kwargs(dtype=st.none() | xps.scalar_dtypes()), ) def test_empty_like(x, kw): - out = empty_like(x, **kw) - dtype = kw.get("dtype", None) or x.dtype + out = xp.empty_like(x, **kw) if kw.get("dtype", None) is None: - assert out.dtype == x.dtype, f"{x.dtype=!s}, but empty_like() returned an array with dtype {out.dtype}" + ph.assert_dtype("empty_like", (x.dtype,), out.dtype) else: - assert out.dtype == dtype, f"{dtype=!s}, but empty_like() returned an array with dtype {out.dtype}" - assert out.shape == x.shape, f"{x.shape=}, but empty_like() returned an array with shape {out.shape}" - + ph.assert_kw_dtype("empty_like", kw["dtype"], out.dtype) + ph.assert_shape("empty_like", out.shape, x.shape) -# TODO: Use this method for all optional arguments -optional_marker = object() -@given(sqrt_sizes, one_of(just(optional_marker), none(), sqrt_sizes), one_of(none(), integers()), numeric_dtypes) -def test_eye(n_rows, n_cols, k, dtype): - kwargs = {k: v for k, v in {'k': k, 'dtype': dtype}.items() if v - is not None} - if n_cols is optional_marker: - a = eye(n_rows, **kwargs) - n_cols = None - else: - a = eye(n_rows, n_cols, **kwargs) - if dtype is None: - assert dh.is_float_dtype(a.dtype), "eye() should return an array with the default floating point dtype" +@given( + n_rows=hh.sqrt_sizes, + n_cols=st.none() | hh.sqrt_sizes, + kw=hh.kwargs( + k=st.integers(), + dtype=xps.numeric_dtypes(), + ), +) +def test_eye(n_rows, n_cols, kw): + out = xp.eye(n_rows, n_cols, **kw) + if kw.get("dtype", None) is None: + ph.assert_default_float("eye", out.dtype) else: - assert a.dtype == dtype, "eye() did not produce the correct dtype" - - if n_cols is None: - n_cols = n_rows - assert a.shape == (n_rows, n_cols), "eye() produced an array with incorrect shape" - - if k is None: - k = 0 + ph.assert_kw_dtype("eye", kw["dtype"], out.dtype) + _n_cols = n_rows if n_cols is None else n_cols + ph.assert_shape("eye", out.shape, (n_rows, _n_cols), n_rows=n_rows, n_cols=n_cols) + f_func = f"[eye({n_rows=}, {n_cols=})]" for i in range(n_rows): - for j in range(n_cols): - if j - i == k: - assert a[i, j] == 1, "eye() did not produce a 1 on the diagonal" + for j in range(_n_cols): + f_indexed_out = f"out[{i}, {j}]={out[i, j]}" + if j - i == kw.get("k", 0): + assert out[i, j] == 1, f"{f_indexed_out}, should be 1 {f_func}" else: - assert a[i, j] == 0, "eye() did not produce a 0 off the diagonal" + assert out[i, j] == 0, f"{f_indexed_out}, should be 0 {f_func}" default_unsafe_dtypes = [xp.uint64] @@ -133,121 +240,148 @@ def test_eye(n_rows, n_cols, k, dtype): default_unsafe_dtypes.extend([xp.uint32, xp.int64]) if dh.default_float == xp.float32: default_unsafe_dtypes.append(xp.float64) -default_safe_scalar_dtypes: SearchStrategy = xps.scalar_dtypes().filter( +default_safe_dtypes: st.SearchStrategy = xps.scalar_dtypes().filter( lambda d: d not in default_unsafe_dtypes ) -@composite -def full_fill_values(draw): - kw = draw(shared(kwargs(dtype=none() | xps.scalar_dtypes()), key="full_kw")) - dtype = kw.get("dtype", None) or draw(default_safe_scalar_dtypes) +@st.composite +def full_fill_values(draw) -> st.SearchStrategy[float]: + kw = draw( + st.shared(hh.kwargs(dtype=st.none() | xps.scalar_dtypes()), key="full_kw") + ) + dtype = kw.get("dtype", None) or draw(default_safe_dtypes) return draw(xps.from_dtype(dtype)) @given( - shape=shapes(), + shape=hh.shapes(), fill_value=full_fill_values(), - kw=shared(kwargs(dtype=none() | xps.scalar_dtypes()), key="full_kw"), + kw=st.shared(hh.kwargs(dtype=st.none() | xps.scalar_dtypes()), key="full_kw"), ) def test_full(shape, fill_value, kw): - out = full(shape, fill_value, **kw) + out = xp.full(shape, fill_value, **kw) if kw.get("dtype", None): dtype = kw["dtype"] elif isinstance(fill_value, bool): dtype = xp.bool elif isinstance(fill_value, int): - dtype = xp.int64 + dtype = dh.default_int else: - dtype = xp.float64 + dtype = dh.default_float if kw.get("dtype", None) is None: - if dtype == xp.float64: - assert dh.is_float_dtype(out.dtype), f"full() returned an array with dtype {out.dtype}, but should be the default float dtype" - elif dtype == xp.int64: - assert out.dtype == xp.int32 or out.dtype == xp.int64, f"full() returned an array with dtype {out.dtype}, but should be the default integer dtype" + if isinstance(fill_value, bool): + pass # TODO + elif isinstance(fill_value, int): + ph.assert_default_int("full", out.dtype) else: - assert out.dtype == xp.bool, f"full() returned an array with dtype {out.dtype}, but should be the bool dtype" - else: - assert out.dtype == dtype - assert out.shape == shape, f"{shape=}, but full() returned an array with shape {out.shape}" - if dh.is_float_dtype(out.dtype) and math.isnan(fill_value): - assert all(isnan(out)), "full() array did not equal the fill value" + ph.assert_default_float("full", out.dtype) else: - assert all(equal(out, asarray(fill_value, dtype=dtype))), "full() array did not equal the fill value" + ph.assert_kw_dtype("full", kw["dtype"], out.dtype) + ph.assert_shape("full", out.shape, shape, shape=shape) + ph.assert_fill("full", fill_value, dtype, out, fill_value=fill_value) -@composite +@st.composite def full_like_fill_values(draw): - kw = draw(shared(kwargs(dtype=none() | xps.scalar_dtypes()), key="full_like_kw")) - dtype = kw.get("dtype", None) or draw(shared_dtypes) + kw = draw( + st.shared(hh.kwargs(dtype=st.none() | xps.scalar_dtypes()), key="full_like_kw") + ) + dtype = kw.get("dtype", None) or draw(hh.shared_dtypes) return draw(xps.from_dtype(dtype)) @given( - x=xps.arrays(dtype=shared_dtypes, shape=shapes()), + x=xps.arrays(dtype=hh.shared_dtypes, shape=hh.shapes()), fill_value=full_like_fill_values(), - kw=shared(kwargs(dtype=none() | xps.scalar_dtypes()), key="full_like_kw"), + kw=st.shared(hh.kwargs(dtype=st.none() | xps.scalar_dtypes()), key="full_like_kw"), ) def test_full_like(x, fill_value, kw): - out = full_like(x, fill_value, **kw) + out = xp.full_like(x, fill_value, **kw) dtype = kw.get("dtype", None) or x.dtype if kw.get("dtype", None) is None: - assert out.dtype == x.dtype, f"{x.dtype=!s}, but full_like() returned an array with dtype {out.dtype}" + ph.assert_dtype("full_like", (x.dtype,), out.dtype) else: - assert out.dtype == dtype, f"{dtype=!s}, but full_like() returned an array with dtype {out.dtype}" - assert out.shape == x.shape, "{x.shape=}, but full_like() returned an array with shape {out.shape}" - if dh.is_float_dtype(dtype) and math.isnan(fill_value): - assert all(isnan(out)), "full_like() array did not equal the fill value" - else: - assert all(equal(out, asarray(fill_value, dtype=dtype))), "full_like() array did not equal the fill value" - - -@given(scalars(shared_dtypes, finite=True), - scalars(shared_dtypes, finite=True), - sizes, - one_of(none(), shared_dtypes), - one_of(none(), booleans()),) -def test_linspace(start, stop, num, dtype, endpoint): - # Skip on int start or stop that cannot be exactly represented as a float, - # since we do not have good approx_equal helpers yet. - if ((dtype is None or dh.is_float_dtype(dtype)) - and ((isinstance(start, int) and not isintegral(asarray(start, dtype=dtype))) - or (isinstance(stop, int) and not isintegral(asarray(stop, dtype=dtype))))): - assume(False) - - kwargs = {k: v for k, v in {'dtype': dtype, 'endpoint': endpoint}.items() - if v is not None} - a = linspace(start, stop, num, **kwargs) + ph.assert_kw_dtype("full_like", kw["dtype"], out.dtype) + ph.assert_shape("full_like", out.shape, x.shape) + ph.assert_fill("full_like", fill_value, dtype, out, fill_value=fill_value) - if dtype is None: - assert dh.is_float_dtype(a.dtype), "linspace() should return an array with the default floating point dtype" - else: - assert a.dtype == dtype, "linspace() did not produce the correct dtype" - assert a.shape == (num,), "linspace() did not return an array with the correct shape" +finite_kw = {"allow_nan": False, "allow_infinity": False} - if endpoint in [None, True]: - if num > 1: - assert all(equal(a[-1], full((), stop, dtype=a.dtype))), "linspace() produced an array that does not include the endpoint" - else: - # linspace(..., num, endpoint=False) is the same as the first num - # elements of linspace(..., num+1, endpoint=True) - b = linspace(start, stop, num + 1, **{**kwargs, 'endpoint': True}) - assert_exactly_equal(b[:-1], a) - if num > 0: - # We need to cast start to dtype - assert all(equal(a[0], full((), start, dtype=a.dtype))), "linspace() produced an array that does not start with the start" +def int_stops( + start: int, num, dtype: DataType, endpoint: bool +) -> st.SearchStrategy[int]: + min_gap = num + if endpoint: + min_gap += 1 + m, M = dh.dtype_ranges[dtype] + max_pos_gap = M - start + max_neg_gap = start - m + max_pos_mul = max_pos_gap // min_gap + max_neg_mul = max_neg_gap // min_gap + return st.one_of( + st.integers(0, max_pos_mul).map(lambda n: start + min_gap * n), + st.integers(0, max_neg_mul).map(lambda n: start - min_gap * n), + ) + + +@given( + num=hh.sizes, + dtype=st.none() | xps.numeric_dtypes(), + endpoint=st.booleans(), + data=st.data(), +) +def test_linspace(num, dtype, endpoint, data): + _dtype = dh.default_float if dtype is None else dtype + + start = data.draw(xps.from_dtype(_dtype, **finite_kw), label="start") + if dh.is_float_dtype(_dtype): + stop = data.draw(xps.from_dtype(_dtype, **finite_kw), label="stop") + # avoid overflow errors + assume(not ah.isnan(ah.asarray(stop - start, dtype=_dtype))) + assume(not ah.isnan(ah.asarray(start - stop, dtype=_dtype))) + else: + if num == 0: + stop = start + else: + stop = data.draw(int_stops(start, num, _dtype, endpoint), label="stop") - # TODO: This requires an assert_approx_equal function + kw = data.draw( + hh.specified_kwargs( + hh.KVD("dtype", dtype, None), + hh.KVD("endpoint", endpoint, True), + ), + label="kw", + ) + out = xp.linspace(start, stop, num, **kw) - # n = num - 1 if endpoint in [None, True] else num - # for i in range(1, num): - # assert all(equal(a[i], full((), i*(stop - start)/n + start, dtype=dtype))), f"linspace() produced an array with an incorrect value at index {i}" + if dtype is None: + ph.assert_default_float("linspace", out.dtype) + else: + ph.assert_dtype("linspace", (out.dtype,), dtype) + ph.assert_shape("linspace", out.shape, num, start=stop, stop=stop, num=num) + f_func = f"[linspace({start}, {stop}, {num})]" + if num > 0: + assert ah.equal( + out[0], ah.asarray(start, dtype=out.dtype) + ), f"out[0]={out[0]}, but should be {start} {f_func}" + if endpoint: + if num > 1: + assert ah.equal( + out[-1], ah.asarray(stop, dtype=out.dtype) + ), f"out[-1]={out[-1]}, but should be {stop} {f_func}" + else: + # linspace(..., num, endpoint=True) should return an array equivalent to + # the first num elements when endpoint=False + expected = xp.linspace(start, stop, num + 1, dtype=dtype, endpoint=True) + expected = expected[:-1] + ah.assert_exactly_equal(out, expected) -def make_one(dtype): - if kwargs is None or dh.is_float_dtype(dtype): +def make_one(dtype: DataType) -> Scalar: + if dtype is None or dh.is_float_dtype(dtype): return 1.0 elif dh.is_int_dtype(dtype): return 1 @@ -255,35 +389,35 @@ def make_one(dtype): return True -@given(shapes(), kwargs(dtype=none() | xps.scalar_dtypes())) +@given(hh.shapes(), hh.kwargs(dtype=st.none() | xps.scalar_dtypes())) def test_ones(shape, kw): - out = ones(shape, **kw) - dtype = kw.get("dtype", None) or xp.float64 + out = xp.ones(shape, **kw) if kw.get("dtype", None) is None: - assert dh.is_float_dtype(out.dtype), f"ones() returned an array with dtype {out.dtype}, but should be the default float dtype" + ph.assert_default_float("ones", out.dtype) else: - assert out.dtype == dtype, f"{dtype=!s}, but ones() returned an array with dtype {out.dtype}" - assert out.shape == shape, f"{shape=}, but empty() returned an array with shape {out.shape}" - assert all(equal(out, full((), make_one(dtype), dtype=dtype))), "ones() array did not equal 1" + ph.assert_kw_dtype("ones", kw["dtype"], out.dtype) + ph.assert_shape("ones", out.shape, shape, shape=shape) + dtype = kw.get("dtype", None) or dh.default_float + ph.assert_fill("ones", make_one(dtype), dtype, out) @given( - x=xps.arrays(dtype=dtypes, shape=shapes()), - kw=kwargs(dtype=none() | xps.scalar_dtypes()), + x=xps.arrays(dtype=hh.dtypes, shape=hh.shapes()), + kw=hh.kwargs(dtype=st.none() | xps.scalar_dtypes()), ) def test_ones_like(x, kw): - out = ones_like(x, **kw) - dtype = kw.get("dtype", None) or x.dtype + out = xp.ones_like(x, **kw) if kw.get("dtype", None) is None: - assert out.dtype == x.dtype, f"{x.dtype=!s}, but ones_like() returned an array with dtype {out.dtype}" + ph.assert_dtype("ones_like", (x.dtype,), out.dtype) else: - assert out.dtype == dtype, f"{dtype=!s}, but ones_like() returned an array with dtype {out.dtype}" - assert out.shape == x.shape, "{x.shape=}, but ones_like() returned an array with shape {out.shape}" - assert all(equal(out, full((), make_one(dtype), dtype=dtype))), "ones_like() array elements did not equal 1" + ph.assert_kw_dtype("ones_like", kw["dtype"], out.dtype) + ph.assert_shape("ones_like", out.shape, x.shape) + dtype = kw.get("dtype", None) or x.dtype + ph.assert_fill("ones_like", make_one(dtype), dtype, out) -def make_zero(dtype): - if dh.is_float_dtype(dtype): +def make_zero(dtype: DataType) -> Scalar: + if dtype is None or dh.is_float_dtype(dtype): return 0.0 elif dh.is_int_dtype(dtype): return 0 @@ -291,29 +425,28 @@ def make_zero(dtype): return False -@given(shapes(), kwargs(dtype=none() | xps.scalar_dtypes())) +@given(hh.shapes(), hh.kwargs(dtype=st.none() | xps.scalar_dtypes())) def test_zeros(shape, kw): - out = zeros(shape, **kw) - dtype = kw.get("dtype", None) or xp.float64 + out = xp.zeros(shape, **kw) if kw.get("dtype", None) is None: - assert dh.is_float_dtype(out.dtype), "zeros() returned an array with dtype {out.dtype}, but should be the default float dtype" + ph.assert_default_float("zeros", out.dtype) else: - assert out.dtype == dtype, f"{dtype=!s}, but zeros() returned an array with dtype {out.dtype}" - assert out.shape == shape, "zeros() produced an array with incorrect shape" - assert all(equal(out, full((), make_zero(dtype), dtype=dtype))), "zeros() array did not equal 0" + ph.assert_kw_dtype("zeros", kw["dtype"], out.dtype) + ph.assert_shape("zeros", out.shape, shape, shape=shape) + dtype = kw.get("dtype", None) or dh.default_float + ph.assert_fill("zeros", make_zero(dtype), dtype, out) @given( - x=xps.arrays(dtype=dtypes, shape=shapes()), - kw=kwargs(dtype=none() | xps.scalar_dtypes()), + x=xps.arrays(dtype=hh.dtypes, shape=hh.shapes()), + kw=hh.kwargs(dtype=st.none() | xps.scalar_dtypes()), ) def test_zeros_like(x, kw): - out = zeros_like(x, **kw) - dtype = kw.get("dtype", None) or x.dtype + out = xp.zeros_like(x, **kw) if kw.get("dtype", None) is None: - assert out.dtype == x.dtype, f"{x.dtype=!s}, but zeros_like() returned an array with dtype {out.dtype}" + ph.assert_dtype("zeros_like", (x.dtype,), out.dtype) else: - assert out.dtype == dtype, f"{dtype=!s}, but zeros_like() returned an array with dtype {out.dtype}" - assert out.shape == x.shape, "{x.shape=}, but zeros_like() returned an array with shape {out.shape}" - assert all(equal(out, full((), make_zero(dtype), dtype=out.dtype))), "zeros_like() array elements did not all equal 0" - + ph.assert_kw_dtype("zeros_like", kw["dtype"], out.dtype) + ph.assert_shape("zeros_like", out.shape, x.shape) + dtype = kw.get("dtype", None) or x.dtype + ph.assert_fill("zeros_like", make_zero(dtype), dtype, out) diff --git a/array_api_tests/test_elementwise_functions.py b/array_api_tests/test_elementwise_functions.py index 5aee5c2d..acfc8cfb 100644 --- a/array_api_tests/test_elementwise_functions.py +++ b/array_api_tests/test_elementwise_functions.py @@ -12,35 +12,17 @@ import math from hypothesis import assume, given -from hypothesis import strategies as st from . import _array_module as xp from . import array_helpers as ah -from . import hypothesis_helpers as hh from . import dtype_helpers as dh +from . import hypothesis_helpers as hh +from . import pytest_helpers as ph from . import xps # We might as well use this implementation rather than requiring # mod.broadcast_shapes(). See test_equal() and others. from .test_broadcasting import broadcast_shapes -# integer_scalars = hh.array_scalars(integer_dtypes) -floating_scalars = hh.array_scalars(hh.floating_dtypes) -numeric_scalars = hh.array_scalars(hh.numeric_dtypes) -integer_or_boolean_scalars = hh.array_scalars(hh.integer_or_boolean_dtypes) -boolean_scalars = hh.array_scalars(hh.boolean_dtypes) - -two_integer_dtypes = hh.mutually_promotable_dtypes(dtypes=dh.all_int_dtypes) -two_floating_dtypes = hh.mutually_promotable_dtypes(dtypes=dh.float_dtypes) -two_numeric_dtypes = hh.mutually_promotable_dtypes(dtypes=dh.numeric_dtypes) -two_integer_or_boolean_dtypes = hh.mutually_promotable_dtypes(dtypes=dh.bool_and_all_int_dtypes) -two_boolean_dtypes = hh.mutually_promotable_dtypes(dtypes=(xp.bool,)) -two_any_dtypes = hh.mutually_promotable_dtypes() - -@st.composite -def two_array_scalars(draw, dtype1, dtype2): - # two_dtypes should be a strategy that returns two dtypes (like - # hh.mutually_promotable_dtypes()) - return draw(hh.array_scalars(st.just(dtype1))), draw(hh.array_scalars(st.just(dtype2))) @given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes())) def test_abs(x): @@ -51,6 +33,7 @@ def test_abs(x): mask = xp.not_equal(x, ah.full(x.shape, minval, dtype=x.dtype)) x = x[mask] a = xp.abs(x) + ph.assert_shape("abs", a.shape, x.shape) assert ah.all(ah.logical_not(ah.negative_mathematical_sign(a))), "abs(x) did not have positive sign" less_zero = ah.negative_mathematical_sign(x) negx = ah.negative(x) @@ -62,6 +45,7 @@ def test_abs(x): @given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) def test_acos(x): a = xp.acos(x) + ph.assert_shape("acos", a.shape, x.shape) ONE = ah.one(x.shape, x.dtype) # Here (and elsewhere), should technically be a.dtype, but this is the # same as x.dtype, as tested by the type_promotion tests. @@ -76,6 +60,7 @@ def test_acos(x): @given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) def test_acosh(x): a = xp.acosh(x) + ph.assert_shape("acosh", a.shape, x.shape) ONE = ah.one(x.shape, x.dtype) INFINITY = ah.infinity(x.shape, x.dtype) ZERO = ah.zero(x.shape, x.dtype) @@ -97,6 +82,7 @@ def test_add(x1, x2): @given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) def test_asin(x): a = xp.asin(x) + ph.assert_shape("asin", a.shape, x.shape) ONE = ah.one(x.shape, x.dtype) PI = ah.π(x.shape, x.dtype) domain = ah.inrange(x, -ONE, ONE) @@ -108,6 +94,7 @@ def test_asin(x): @given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) def test_asinh(x): a = xp.asinh(x) + ph.assert_shape("asinh", a.shape, x.shape) INFINITY = ah.infinity(x.shape, x.dtype) domain = ah.inrange(x, -INFINITY, INFINITY) codomain = ah.inrange(a, -INFINITY, INFINITY) @@ -118,6 +105,7 @@ def test_asinh(x): @given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) def test_atan(x): a = xp.atan(x) + ph.assert_shape("atan", a.shape, x.shape) INFINITY = ah.infinity(x.shape, x.dtype) PI = ah.π(x.shape, x.dtype) domain = ah.inrange(x, -INFINITY, INFINITY) @@ -164,6 +152,7 @@ def test_atan2(x1, x2): @given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) def test_atanh(x): a = xp.atanh(x) + ph.assert_shape("atanh", a.shape, x.shape) ONE = ah.one(x.shape, x.dtype) INFINITY = ah.infinity(x.shape, x.dtype) domain = ah.inrange(x, -ONE, ONE) @@ -178,7 +167,7 @@ def test_bitwise_and(x1, x2): # TODO: generate indices without broadcasting arrays (see test_equal comment) shape = broadcast_shapes(x1.shape, x2.shape) - assert out.shape == shape + ph.assert_shape("bitwise_and", out.shape, shape) _x1 = xp.broadcast_to(x1, shape) _x2 = xp.broadcast_to(x2, shape) @@ -205,7 +194,7 @@ def test_bitwise_left_shift(x1, x2): # TODO: generate indices without broadcasting arrays (see test_equal comment) shape = broadcast_shapes(x1.shape, x2.shape) - assert out.shape == shape + ph.assert_shape("bitwise_left_shift", out.shape, shape) _x1 = xp.broadcast_to(x1, shape) _x2 = xp.broadcast_to(x2, shape) @@ -222,6 +211,7 @@ def test_bitwise_left_shift(x1, x2): @given(xps.arrays(dtype=hh.integer_or_boolean_dtypes, shape=hh.shapes())) def test_bitwise_invert(x): out = xp.bitwise_invert(x) + ph.assert_shape("bitwise_invert", out.shape, x.shape) # Compare against the Python ~ operator. if out.dtype == xp.bool: for idx in ah.ndindex(out.shape): @@ -242,7 +232,7 @@ def test_bitwise_or(x1, x2): # TODO: generate indices without broadcasting arrays (see test_equal comment) shape = broadcast_shapes(x1.shape, x2.shape) - assert out.shape == shape + ph.assert_shape("bitwise_or", out.shape, shape) _x1 = xp.broadcast_to(x1, shape) _x2 = xp.broadcast_to(x2, shape) @@ -269,7 +259,7 @@ def test_bitwise_right_shift(x1, x2): # TODO: generate indices without broadcasting arrays (see test_equal comment) shape = broadcast_shapes(x1.shape, x2.shape) - assert out.shape == shape + ph.assert_shape("bitwise_right_shift", out.shape, shape) _x1 = xp.broadcast_to(x1, shape) _x2 = xp.broadcast_to(x2, shape) @@ -288,7 +278,7 @@ def test_bitwise_xor(x1, x2): # TODO: generate indices without broadcasting arrays (see test_equal comment) shape = broadcast_shapes(x1.shape, x2.shape) - assert out.shape == shape + ph.assert_shape("bitwise_xor", out.shape, shape) _x1 = xp.broadcast_to(x1, shape) _x2 = xp.broadcast_to(x2, shape) @@ -312,6 +302,7 @@ def test_bitwise_xor(x1, x2): def test_ceil(x): # This test is almost identical to test_floor() a = xp.ceil(x) + ph.assert_shape("ceil", a.shape, x.shape) finite = ah.isfinite(x) ah.assert_integral(a[finite]) assert ah.all(ah.less_equal(x[finite], a[finite])) @@ -322,6 +313,7 @@ def test_ceil(x): @given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) def test_cos(x): a = xp.cos(x) + ph.assert_shape("cos", a.shape, x.shape) ONE = ah.one(x.shape, x.dtype) INFINITY = ah.infinity(x.shape, x.dtype) domain = ah.inrange(x, -INFINITY, INFINITY, open=True) @@ -333,6 +325,7 @@ def test_cos(x): @given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) def test_cosh(x): a = xp.cosh(x) + ph.assert_shape("cosh", a.shape, x.shape) INFINITY = ah.infinity(x.shape, x.dtype) domain = ah.inrange(x, -INFINITY, INFINITY) codomain = ah.inrange(a, -INFINITY, INFINITY) @@ -366,6 +359,7 @@ def test_equal(x1, x2): # indices to x1 and x2 that correspond to the broadcasted shapes. This # would avoid the dependence in this test on broadcast_to(). shape = broadcast_shapes(x1.shape, x2.shape) + ph.assert_shape("equal", a.shape, shape) _x1 = xp.broadcast_to(x1, shape) _x2 = xp.broadcast_to(x2, shape) @@ -401,6 +395,7 @@ def test_equal(x1, x2): @given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) def test_exp(x): a = xp.exp(x) + ph.assert_shape("exp", a.shape, x.shape) INFINITY = ah.infinity(x.shape, x.dtype) ZERO = ah.zero(x.shape, x.dtype) domain = ah.inrange(x, -INFINITY, INFINITY) @@ -412,6 +407,7 @@ def test_exp(x): @given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) def test_expm1(x): a = xp.expm1(x) + ph.assert_shape("expm1", a.shape, x.shape) INFINITY = ah.infinity(x.shape, x.dtype) NEGONE = -ah.one(x.shape, x.dtype) domain = ah.inrange(x, -INFINITY, INFINITY) @@ -424,6 +420,7 @@ def test_expm1(x): def test_floor(x): # This test is almost identical to test_ceil a = xp.floor(x) + ph.assert_shape("floor", a.shape, x.shape) finite = ah.isfinite(x) ah.assert_integral(a[finite]) assert ah.all(ah.less_equal(a[finite], x[finite])) @@ -461,6 +458,7 @@ def test_greater(x1, x2): # See the comments in test_equal() for a description of how this test # works. shape = broadcast_shapes(x1.shape, x2.shape) + ph.assert_shape("greater", a.shape, shape) _x1 = xp.broadcast_to(x1, shape) _x2 = xp.broadcast_to(x2, shape) @@ -489,6 +487,7 @@ def test_greater_equal(x1, x2): # See the comments in test_equal() for a description of how this test # works. shape = broadcast_shapes(x1.shape, x2.shape) + ph.assert_shape("greater_equal", a.shape, shape) _x1 = xp.broadcast_to(x1, shape) _x2 = xp.broadcast_to(x2, shape) @@ -513,9 +512,9 @@ def test_greater_equal(x1, x2): @given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes())) def test_isfinite(x): a = ah.isfinite(x) - TRUE = ah.true(x.shape) + ph.assert_shape("isfinite", a.shape, x.shape) if dh.is_int_dtype(x.dtype): - ah.assert_exactly_equal(a, TRUE) + ah.assert_exactly_equal(a, ah.true(x.shape)) # Test that isfinite, isinf, and isnan are self-consistent. inf = ah.logical_or(xp.isinf(x), ah.isnan(x)) ah.assert_exactly_equal(a, ah.logical_not(inf)) @@ -529,9 +528,11 @@ def test_isfinite(x): @given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes())) def test_isinf(x): a = xp.isinf(x) - FALSE = ah.false(x.shape) + + ph.assert_shape("isinf", a.shape, x.shape) + if dh.is_int_dtype(x.dtype): - ah.assert_exactly_equal(a, FALSE) + ah.assert_exactly_equal(a, ah.false(x.shape)) finite_or_nan = ah.logical_or(ah.isfinite(x), ah.isnan(x)) ah.assert_exactly_equal(a, ah.logical_not(finite_or_nan)) @@ -544,9 +545,11 @@ def test_isinf(x): @given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes())) def test_isnan(x): a = ah.isnan(x) - FALSE = ah.false(x.shape) + + ph.assert_shape("isnan", a.shape, x.shape) + if dh.is_int_dtype(x.dtype): - ah.assert_exactly_equal(a, FALSE) + ah.assert_exactly_equal(a, ah.false(x.shape)) finite_or_inf = ah.logical_or(ah.isfinite(x), xp.isinf(x)) ah.assert_exactly_equal(a, ah.logical_not(finite_or_inf)) @@ -563,6 +566,7 @@ def test_less(x1, x2): # See the comments in test_equal() for a description of how this test # works. shape = broadcast_shapes(x1.shape, x2.shape) + ph.assert_shape("less", a.shape, shape) _x1 = xp.broadcast_to(x1, shape) _x2 = xp.broadcast_to(x2, shape) @@ -591,6 +595,7 @@ def test_less_equal(x1, x2): # See the comments in test_equal() for a description of how this test # works. shape = broadcast_shapes(x1.shape, x2.shape) + ph.assert_shape("less_equal", a.shape, shape) _x1 = xp.broadcast_to(x1, shape) _x2 = xp.broadcast_to(x2, shape) @@ -615,6 +620,9 @@ def test_less_equal(x1, x2): @given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) def test_log(x): a = xp.log(x) + + ph.assert_shape("log", a.shape, x.shape) + INFINITY = ah.infinity(x.shape, x.dtype) ZERO = ah.zero(x.shape, x.dtype) domain = ah.inrange(x, ZERO, INFINITY) @@ -626,6 +634,7 @@ def test_log(x): @given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) def test_log1p(x): a = xp.log1p(x) + ph.assert_shape("log1p", a.shape, x.shape) INFINITY = ah.infinity(x.shape, x.dtype) NEGONE = -ah.one(x.shape, x.dtype) codomain = ah.inrange(x, NEGONE, INFINITY) @@ -637,6 +646,7 @@ def test_log1p(x): @given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) def test_log2(x): a = xp.log2(x) + ph.assert_shape("log2", a.shape, x.shape) INFINITY = ah.infinity(x.shape, x.dtype) ZERO = ah.zero(x.shape, x.dtype) domain = ah.inrange(x, ZERO, INFINITY) @@ -648,6 +658,7 @@ def test_log2(x): @given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) def test_log10(x): a = xp.log10(x) + ph.assert_shape("log10", a.shape, x.shape) INFINITY = ah.infinity(x.shape, x.dtype) ZERO = ah.zero(x.shape, x.dtype) domain = ah.inrange(x, ZERO, INFINITY) @@ -669,6 +680,7 @@ def test_logical_and(x1, x2): # See the comments in test_equal shape = broadcast_shapes(x1.shape, x2.shape) + ph.assert_shape("logical_and", a.shape, shape) _x1 = xp.broadcast_to(x1, shape) _x2 = xp.broadcast_to(x2, shape) @@ -678,7 +690,7 @@ def test_logical_and(x1, x2): @given(xps.arrays(dtype=xp.bool, shape=hh.shapes())) def test_logical_not(x): a = ah.logical_not(x) - + ph.assert_shape("logical_not", a.shape, x.shape) for idx in ah.ndindex(x.shape): assert a[idx] == (not bool(x[idx])) @@ -688,6 +700,7 @@ def test_logical_or(x1, x2): # See the comments in test_equal shape = broadcast_shapes(x1.shape, x2.shape) + ph.assert_shape("logical_or", a.shape, shape) _x1 = xp.broadcast_to(x1, shape) _x2 = xp.broadcast_to(x2, shape) @@ -700,6 +713,7 @@ def test_logical_xor(x1, x2): # See the comments in test_equal shape = broadcast_shapes(x1.shape, x2.shape) + ph.assert_shape("logical_xor", a.shape, shape) _x1 = xp.broadcast_to(x1, shape) _x2 = xp.broadcast_to(x2, shape) @@ -718,6 +732,8 @@ def test_multiply(x1, x2): def test_negative(x): out = ah.negative(x) + ph.assert_shape("negative", out.shape, x.shape) + # Negation is an involution ah.assert_exactly_equal(x, ah.negative(out)) @@ -741,6 +757,7 @@ def test_not_equal(x1, x2): # See the comments in test_equal() for a description of how this test # works. shape = broadcast_shapes(x1.shape, x2.shape) + ph.assert_shape("not_equal", a.shape, shape) _x1 = xp.broadcast_to(x1, shape) _x2 = xp.broadcast_to(x2, shape) @@ -766,6 +783,7 @@ def test_not_equal(x1, x2): @given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes())) def test_positive(x): out = xp.positive(x) + ph.assert_shape("positive", out.shape, x.shape) # Positive does nothing ah.assert_exactly_equal(out, x) @@ -792,6 +810,8 @@ def test_remainder(x1, x2): def test_round(x): a = xp.round(x) + ph.assert_shape("round", a.shape, x.shape) + # Test that the res is integral finite = ah.isfinite(x) ah.assert_integral(a[finite]) @@ -811,53 +831,57 @@ def test_round(x): ah.assert_exactly_equal(a[round_down], floor[round_down]) ah.assert_exactly_equal(a[round_up], ceil[round_up]) -@given(numeric_scalars) +@given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes())) def test_sign(x): - # a = xp.sign(x) - pass + out = xp.sign(x) + ph.assert_shape("sign", out.shape, x.shape) + # TODO -@given(floating_scalars) +@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) def test_sin(x): - # a = xp.sin(x) - pass + out = xp.sin(x) + ph.assert_shape("sin", out.shape, x.shape) + # TODO -@given(floating_scalars) +@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) def test_sinh(x): - # a = xp.sinh(x) - pass + out = xp.sinh(x) + ph.assert_shape("sinh", out.shape, x.shape) + # TODO -@given(numeric_scalars) +@given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes())) def test_square(x): - # a = xp.square(x) - pass + out = xp.square(x) + ph.assert_shape("square", out.shape, x.shape) -@given(floating_scalars) +@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) def test_sqrt(x): - # a = xp.sqrt(x) - pass + out = xp.sqrt(x) + ph.assert_shape("sqrt", out.shape, x.shape) -@given(two_numeric_dtypes.flatmap(lambda i: two_array_scalars(*i))) -def test_subtract(args): - x1, x2 = args - # a = xp.subtract(x1, x2) +@given(*hh.two_mutual_arrays(dh.numeric_dtypes)) +def test_subtract(x1, x2): + # out = xp.subtract(x1, x2) + pass -@given(floating_scalars) +@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) def test_tan(x): - # a = xp.tan(x) - pass + out = xp.tan(x) + ph.assert_shape("tan", out.shape, x.shape) + # TODO -@given(floating_scalars) +@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) def test_tanh(x): - # a = xp.tanh(x) - pass + out = xp.tanh(x) + ph.assert_shape("tanh", out.shape, x.shape) + # TODO @given(xps.arrays(dtype=hh.numeric_dtypes, shape=xps.array_shapes())) def test_trunc(x): out = xp.trunc(x) - assert out.dtype == x.dtype, f"{x.dtype=!s} but {out.dtype=!s}" - assert out.shape == x.shape, f"{x.shape=} but {out.shape=}" - if x.dtype in dh.all_int_dtypes: - assert ah.all(ah.equal(x, out)), f"{x=!s} but {out=!s}" + ph.assert_shape("bitwise_trunc", out.shape, x.shape) + if dh.is_int_dtype(x.dtype): + ah.assert_exactly_equal(out, x) else: finite = ah.isfinite(x) ah.assert_integral(out[finite]) diff --git a/array_api_tests/test_type_promotion.py b/array_api_tests/test_type_promotion.py index e3c37c11..5e692632 100644 --- a/array_api_tests/test_type_promotion.py +++ b/array_api_tests/test_type_promotion.py @@ -1,6 +1,7 @@ """ https://data-apis.github.io/array-api/latest/API_specification/type_promotion.html """ +import math from collections import defaultdict from typing import Tuple, Union, List @@ -24,22 +25,30 @@ @given(hh.mutually_promotable_dtypes(None)) def test_result_type(dtypes): out = xp.result_type(*dtypes) - ph.assert_dtype('result_type', dtypes, out, out_name='out') + ph.assert_dtype("result_type", dtypes, out, out_name="out") +# The number and size of generated arrays is arbitrarily limited to prevent +# meshgrid() running out of memory. @given( - dtypes=hh.mutually_promotable_dtypes(None, dtypes=dh.numeric_dtypes), + dtypes=hh.mutually_promotable_dtypes(5, dtypes=dh.numeric_dtypes), data=st.data(), ) def test_meshgrid(dtypes, data): arrays = [] - shapes = data.draw(hh.mutually_broadcastable_shapes(len(dtypes)), label='shapes') + shapes = data.draw( + hh.mutually_broadcastable_shapes( + len(dtypes), min_dims=1, max_dims=1, max_side=5 + ), + label="shapes", + ) for i, (dtype, shape) in enumerate(zip(dtypes, shapes), 1): - x = data.draw(xps.arrays(dtype=dtype, shape=shape), label=f'x{i}') + x = data.draw(xps.arrays(dtype=dtype, shape=shape), label=f"x{i}") arrays.append(x) + assert math.prod(x.size for x in arrays) <= hh.MAX_ARRAY_SIZE # sanity check out = xp.meshgrid(*arrays) for i, x in enumerate(out): - ph.assert_dtype('meshgrid', dtypes, x.dtype, out_name=f'out[{i}].dtype') + ph.assert_dtype("meshgrid", dtypes, x.dtype, out_name=f"out[{i}].dtype") @given( @@ -50,10 +59,10 @@ def test_meshgrid(dtypes, data): def test_concat(shape, dtypes, data): arrays = [] for i, dtype in enumerate(dtypes, 1): - x = data.draw(xps.arrays(dtype=dtype, shape=shape), label=f'x{i}') + x = data.draw(xps.arrays(dtype=dtype, shape=shape), label=f"x{i}") arrays.append(x) out = xp.concat(arrays) - ph.assert_dtype('concat', dtypes, out.dtype) + ph.assert_dtype("concat", dtypes, out.dtype) @given( @@ -64,26 +73,26 @@ def test_concat(shape, dtypes, data): def test_stack(shape, dtypes, data): arrays = [] for i, dtype in enumerate(dtypes, 1): - x = data.draw(xps.arrays(dtype=dtype, shape=shape), label=f'x{i}') + x = data.draw(xps.arrays(dtype=dtype, shape=shape), label=f"x{i}") arrays.append(x) out = xp.stack(arrays) - ph.assert_dtype('stack', dtypes, out.dtype) + ph.assert_dtype("stack", dtypes, out.dtype) bitwise_shift_funcs = [ - 'bitwise_left_shift', - 'bitwise_right_shift', - '__lshift__', - '__rshift__', - '__ilshift__', - '__irshift__', + "bitwise_left_shift", + "bitwise_right_shift", + "__lshift__", + "__rshift__", + "__ilshift__", + "__irshift__", ] # We pass kwargs to the elements strategy used by xps.arrays() so that we don't # generate array elements that are erroneous or undefined for a function. func_elements = defaultdict( - lambda: None, {func: {'min_value': 1} for func in bitwise_shift_funcs} + lambda: None, {func: {"min_value": 1} for func in bitwise_shift_funcs} ) @@ -94,7 +103,7 @@ def make_id( ) -> str: f_args = dh.fmt_types(in_dtypes) f_out_dtype = dh.dtype_to_name[out_dtype] - return f'{func_name}({f_args}) -> {f_out_dtype}' + return f"{func_name}({f_args}) -> {f_out_dtype}" func_params: List[Param[str, Tuple[DataType, ...], DataType]] = [] @@ -128,7 +137,7 @@ def make_id( raise NotImplementedError() -@pytest.mark.parametrize('func_name, in_dtypes, out_dtype', func_params) +@pytest.mark.parametrize("func_name, in_dtypes, out_dtype", func_params) @given(data=st.data()) def test_func_promotion(func_name, in_dtypes, out_dtype, data): func = getattr(xp, func_name) @@ -136,17 +145,17 @@ def test_func_promotion(func_name, in_dtypes, out_dtype, data): if len(in_dtypes) == 1: x = data.draw( xps.arrays(dtype=in_dtypes[0], shape=hh.shapes(), elements=elements), - label='x', + label="x", ) out = func(x) else: arrays = [] shapes = data.draw( - hh.mutually_broadcastable_shapes(len(in_dtypes)), label='shapes' + hh.mutually_broadcastable_shapes(len(in_dtypes)), label="shapes" ) for i, (dtype, shape) in enumerate(zip(in_dtypes, shapes), 1): x = data.draw( - xps.arrays(dtype=dtype, shape=shape, elements=elements), label=f'x{i}' + xps.arrays(dtype=dtype, shape=shape, elements=elements), label=f"x{i}" ) arrays.append(x) try: @@ -161,46 +170,46 @@ def test_func_promotion(func_name, in_dtypes, out_dtype, data): p = pytest.param( (dtype1, dtype2), promoted_dtype, - id=make_id('', (dtype1, dtype2), promoted_dtype), + id=make_id("", (dtype1, dtype2), promoted_dtype), ) promotion_params.append(p) -@pytest.mark.parametrize('in_dtypes, out_dtype', promotion_params) +@pytest.mark.parametrize("in_dtypes, out_dtype", promotion_params) @given(shapes=hh.mutually_broadcastable_shapes(3), data=st.data()) def test_where(in_dtypes, out_dtype, shapes, data): - x1 = data.draw(xps.arrays(dtype=in_dtypes[0], shape=shapes[0]), label='x1') - x2 = data.draw(xps.arrays(dtype=in_dtypes[1], shape=shapes[1]), label='x2') - cond = data.draw(xps.arrays(dtype=xp.bool, shape=shapes[2]), label='condition') + x1 = data.draw(xps.arrays(dtype=in_dtypes[0], shape=shapes[0]), label="x1") + x2 = data.draw(xps.arrays(dtype=in_dtypes[1], shape=shapes[1]), label="x2") + cond = data.draw(xps.arrays(dtype=xp.bool, shape=shapes[2]), label="condition") out = xp.where(cond, x1, x2) - ph.assert_dtype('where', in_dtypes, out.dtype, out_dtype) + ph.assert_dtype("where", in_dtypes, out.dtype, out_dtype) numeric_promotion_params = promotion_params[1:] -@pytest.mark.parametrize('in_dtypes, out_dtype', numeric_promotion_params) +@pytest.mark.parametrize("in_dtypes, out_dtype", numeric_promotion_params) @given(shapes=hh.mutually_broadcastable_shapes(2, min_dims=2), data=st.data()) def test_tensordot(in_dtypes, out_dtype, shapes, data): - x1 = data.draw(xps.arrays(dtype=in_dtypes[0], shape=shapes[0]), label='x1') - x2 = data.draw(xps.arrays(dtype=in_dtypes[1], shape=shapes[1]), label='x2') + x1 = data.draw(xps.arrays(dtype=in_dtypes[0], shape=shapes[0]), label="x1") + x2 = data.draw(xps.arrays(dtype=in_dtypes[1], shape=shapes[1]), label="x2") out = xp.tensordot(x1, x2) - ph.assert_dtype('tensordot', in_dtypes, out.dtype, out_dtype) + ph.assert_dtype("tensordot", in_dtypes, out.dtype, out_dtype) -@pytest.mark.parametrize('in_dtypes, out_dtype', numeric_promotion_params) +@pytest.mark.parametrize("in_dtypes, out_dtype", numeric_promotion_params) @given(shapes=hh.mutually_broadcastable_shapes(2, min_dims=1), data=st.data()) def test_vecdot(in_dtypes, out_dtype, shapes, data): - x1 = data.draw(xps.arrays(dtype=in_dtypes[0], shape=shapes[0]), label='x1') - x2 = data.draw(xps.arrays(dtype=in_dtypes[1], shape=shapes[1]), label='x2') + x1 = data.draw(xps.arrays(dtype=in_dtypes[0], shape=shapes[0]), label="x1") + x2 = data.draw(xps.arrays(dtype=in_dtypes[1], shape=shapes[1]), label="x2") out = xp.vecdot(x1, x2) - ph.assert_dtype('vecdot', in_dtypes, out.dtype, out_dtype) + ph.assert_dtype("vecdot", in_dtypes, out.dtype, out_dtype) op_params: List[Param[str, str, Tuple[DataType, ...], DataType]] = [] op_to_symbol = {**dh.unary_op_to_symbol, **dh.binary_op_to_symbol} for op, symbol in op_to_symbol.items(): - if op == '__matmul__': + if op == "__matmul__": continue valid_in_dtypes = dh.func_in_dtypes[op] ndtypes = ph.nargs(op) @@ -209,7 +218,7 @@ def test_vecdot(in_dtypes, out_dtype, shapes, data): out_dtype = xp.bool if dh.func_returns_bool[op] else in_dtype p = pytest.param( op, - f'{symbol}x', + f"{symbol}x", (in_dtype,), out_dtype, id=make_id(op, (in_dtype,), out_dtype), @@ -221,42 +230,42 @@ def test_vecdot(in_dtypes, out_dtype, shapes, data): out_dtype = xp.bool if dh.func_returns_bool[op] else promoted_dtype p = pytest.param( op, - f'x1 {symbol} x2', + f"x1 {symbol} x2", (in_dtype1, in_dtype2), out_dtype, id=make_id(op, (in_dtype1, in_dtype2), out_dtype), ) op_params.append(p) # We generate params for abs seperately as it does not have an associated symbol -for in_dtype in dh.func_in_dtypes['__abs__']: +for in_dtype in dh.func_in_dtypes["__abs__"]: p = pytest.param( - '__abs__', - 'abs(x)', + "__abs__", + "abs(x)", (in_dtype,), in_dtype, - id=make_id('__abs__', (in_dtype,), in_dtype), + id=make_id("__abs__", (in_dtype,), in_dtype), ) op_params.append(p) -@pytest.mark.parametrize('op, expr, in_dtypes, out_dtype', op_params) +@pytest.mark.parametrize("op, expr, in_dtypes, out_dtype", op_params) @given(data=st.data()) def test_op_promotion(op, expr, in_dtypes, out_dtype, data): elements = func_elements[func_name] if len(in_dtypes) == 1: x = data.draw( xps.arrays(dtype=in_dtypes[0], shape=hh.shapes(), elements=elements), - label='x', + label="x", ) - out = eval(expr, {'x': x}) + out = eval(expr, {"x": x}) else: locals_ = {} shapes = data.draw( - hh.mutually_broadcastable_shapes(len(in_dtypes)), label='shapes' + hh.mutually_broadcastable_shapes(len(in_dtypes)), label="shapes" ) for i, (dtype, shape) in enumerate(zip(in_dtypes, shapes), 1): - locals_[f'x{i}'] = data.draw( - xps.arrays(dtype=dtype, shape=shape, elements=elements), label=f'x{i}' + locals_[f"x{i}"] = data.draw( + xps.arrays(dtype=dtype, shape=shape, elements=elements), label=f"x{i}" ) try: out = eval(expr, locals_) @@ -267,7 +276,7 @@ def test_op_promotion(op, expr, in_dtypes, out_dtype, data): inplace_params: List[Param[str, str, Tuple[DataType, ...], DataType]] = [] for op, symbol in dh.inplace_op_to_symbol.items(): - if op == '__imatmul__': + if op == "__imatmul__": continue valid_in_dtypes = dh.func_in_dtypes[op] for (in_dtype1, in_dtype2), promoted_dtype in dh.promotion_table.items(): @@ -278,7 +287,7 @@ def test_op_promotion(op, expr, in_dtypes, out_dtype, data): ): p = pytest.param( op, - f'x1 {symbol} x2', + f"x1 {symbol} x2", (in_dtype1, in_dtype2), promoted_dtype, id=make_id(op, (in_dtype1, in_dtype2), promoted_dtype), @@ -286,36 +295,36 @@ def test_op_promotion(op, expr, in_dtypes, out_dtype, data): inplace_params.append(p) -@pytest.mark.parametrize('op, expr, in_dtypes, out_dtype', inplace_params) +@pytest.mark.parametrize("op, expr, in_dtypes, out_dtype", inplace_params) @given(shapes=hh.mutually_broadcastable_shapes(2), data=st.data()) def test_inplace_op_promotion(op, expr, in_dtypes, out_dtype, shapes, data): assume(len(shapes[0]) >= len(shapes[1])) elements = func_elements[func_name] x1 = data.draw( - xps.arrays(dtype=in_dtypes[0], shape=shapes[0], elements=elements), label='x1' + xps.arrays(dtype=in_dtypes[0], shape=shapes[0], elements=elements), label="x1" ) x2 = data.draw( - xps.arrays(dtype=in_dtypes[1], shape=shapes[1], elements=elements), label='x2' + xps.arrays(dtype=in_dtypes[1], shape=shapes[1], elements=elements), label="x2" ) - locals_ = {'x1': x1, 'x2': x2} + locals_ = {"x1": x1, "x2": x2} try: exec(expr, locals_) except OverflowError: reject() - x1 = locals_['x1'] - ph.assert_dtype(op, in_dtypes, x1.dtype, out_dtype, out_name='x1.dtype') + x1 = locals_["x1"] + ph.assert_dtype(op, in_dtypes, x1.dtype, out_dtype, out_name="x1.dtype") op_scalar_params: List[Param[str, str, DataType, ScalarType, DataType]] = [] for op, symbol in dh.binary_op_to_symbol.items(): - if op == '__matmul__': + if op == "__matmul__": continue for in_dtype in dh.func_in_dtypes[op]: out_dtype = xp.bool if dh.func_returns_bool[op] else in_dtype for in_stype in dh.dtype_to_scalars[in_dtype]: p = pytest.param( op, - f'x {symbol} s', + f"x {symbol} s", in_dtype, in_stype, out_dtype, @@ -324,17 +333,17 @@ def test_inplace_op_promotion(op, expr, in_dtypes, out_dtype, shapes, data): op_scalar_params.append(p) -@pytest.mark.parametrize('op, expr, in_dtype, in_stype, out_dtype', op_scalar_params) +@pytest.mark.parametrize("op, expr, in_dtype, in_stype, out_dtype", op_scalar_params) @given(data=st.data()) def test_op_scalar_promotion(op, expr, in_dtype, in_stype, out_dtype, data): elements = func_elements[func_name] - kw = {k: in_stype is float for k in ('allow_nan', 'allow_infinity')} - s = data.draw(xps.from_dtype(in_dtype, **kw).map(in_stype), label='scalar') + kw = {k: in_stype is float for k in ("allow_nan", "allow_infinity")} + s = data.draw(xps.from_dtype(in_dtype, **kw).map(in_stype), label="scalar") x = data.draw( - xps.arrays(dtype=in_dtype, shape=hh.shapes(), elements=elements), label='x' + xps.arrays(dtype=in_dtype, shape=hh.shapes(), elements=elements), label="x" ) try: - out = eval(expr, {'x': x, 's': s}) + out = eval(expr, {"x": x, "s": s}) except OverflowError: reject() ph.assert_dtype(op, (in_dtype, in_stype), out.dtype, out_dtype) @@ -342,13 +351,13 @@ def test_op_scalar_promotion(op, expr, in_dtype, in_stype, out_dtype, data): inplace_scalar_params: List[Param[str, str, DataType, ScalarType]] = [] for op, symbol in dh.inplace_op_to_symbol.items(): - if op == '__imatmul__': + if op == "__imatmul__": continue for dtype in dh.func_in_dtypes[op]: for in_stype in dh.dtype_to_scalars[dtype]: p = pytest.param( op, - f'x {symbol} s', + f"x {symbol} s", dtype, in_stype, id=make_id(op, (dtype, in_stype), dtype), @@ -356,25 +365,25 @@ def test_op_scalar_promotion(op, expr, in_dtype, in_stype, out_dtype, data): inplace_scalar_params.append(p) -@pytest.mark.parametrize('op, expr, dtype, in_stype', inplace_scalar_params) +@pytest.mark.parametrize("op, expr, dtype, in_stype", inplace_scalar_params) @given(data=st.data()) def test_inplace_op_scalar_promotion(op, expr, dtype, in_stype, data): elements = func_elements[func_name] - kw = {k: in_stype is float for k in ('allow_nan', 'allow_infinity')} - s = data.draw(xps.from_dtype(dtype, **kw).map(in_stype), label='scalar') + kw = {k: in_stype is float for k in ("allow_nan", "allow_infinity")} + s = data.draw(xps.from_dtype(dtype, **kw).map(in_stype), label="scalar") x = data.draw( - xps.arrays(dtype=dtype, shape=hh.shapes(), elements=elements), label='x' + xps.arrays(dtype=dtype, shape=hh.shapes(), elements=elements), label="x" ) - locals_ = {'x': x, 's': s} + locals_ = {"x": x, "s": s} try: exec(expr, locals_) except OverflowError: reject() - x = locals_['x'] - assert x.dtype == dtype, f'{x.dtype=!s}, but should be {dtype}' - ph.assert_dtype(op, (dtype, in_stype), x.dtype, dtype, out_name='x.dtype') + x = locals_["x"] + assert x.dtype == dtype, f"{x.dtype=!s}, but should be {dtype}" + ph.assert_dtype(op, (dtype, in_stype), x.dtype, dtype, out_name="x.dtype") -if __name__ == '__main__': +if __name__ == "__main__": for (i, j), p in dh.promotion_table.items(): - print(f'({i}, {j}) -> {p}') + print(f"({i}, {j}) -> {p}") diff --git a/array_api_tests/typing.py b/array_api_tests/typing.py index 93165c72..286ce21b 100644 --- a/array_api_tests/typing.py +++ b/array_api_tests/typing.py @@ -2,12 +2,16 @@ __all__ = [ "DataType", + "Scalar", "ScalarType", + "Array", "Shape", "Param", ] DataType = Type[Any] +Scalar = Union[bool, int, float] ScalarType = Union[Type[bool], Type[int], Type[float]] +Array = Any Shape = Tuple[int, ...] Param = Tuple