diff --git a/array_api_tests/dtype_helpers.py b/array_api_tests/dtype_helpers.py index 94576ee3..f691d771 100644 --- a/array_api_tests/dtype_helpers.py +++ b/array_api_tests/dtype_helpers.py @@ -121,7 +121,6 @@ def is_float_dtype(dtype): # See https://github.com/numpy/numpy/issues/18434 if dtype is None: return False - # TODO: Return True for float dtypes that aren't part of the spec e.g. np.float16 return dtype in float_dtypes diff --git a/array_api_tests/meta/test_pytest_helpers.py b/array_api_tests/meta/test_pytest_helpers.py index 117e2b11..a6851a15 100644 --- a/array_api_tests/meta/test_pytest_helpers.py +++ b/array_api_tests/meta/test_pytest_helpers.py @@ -13,10 +13,10 @@ def test_assert_dtype(): ph.assert_dtype("single_bool_func", [xp.uint8], xp.bool, xp.bool) -def test_assert_array(): - ph.assert_array("int zeros", xp.asarray(0), xp.asarray(0)) - ph.assert_array("pos zeros", xp.asarray(0.0), xp.asarray(0.0)) +def test_assert_array_elements(): + ph.assert_array_elements("int zeros", xp.asarray(0), xp.asarray(0)) + ph.assert_array_elements("pos zeros", xp.asarray(0.0), xp.asarray(0.0)) with raises(AssertionError): - ph.assert_array("mixed sign zeros", xp.asarray(0.0), xp.asarray(-0.0)) + ph.assert_array_elements("mixed sign zeros", xp.asarray(0.0), xp.asarray(-0.0)) with raises(AssertionError): - ph.assert_array("mixed sign zeros", xp.asarray(-0.0), xp.asarray(0.0)) + ph.assert_array_elements("mixed sign zeros", xp.asarray(-0.0), xp.asarray(0.0)) diff --git a/array_api_tests/meta/test_utils.py b/array_api_tests/meta/test_utils.py index 1fe3ca66..268a81aa 100644 --- a/array_api_tests/meta/test_utils.py +++ b/array_api_tests/meta/test_utils.py @@ -91,6 +91,7 @@ def test_roll_ndindex(shape, shifts, axes, expected): ((), "x"), (42, "x[42]"), ((42,), "x[42]"), + ((42, 7), "x[42, 7]"), (slice(None, 2), "x[:2]"), (slice(2, None), "x[2:]"), (slice(0, 2), "x[0:2]"), @@ -98,6 +99,7 @@ def test_roll_ndindex(shape, shifts, axes, expected): (slice(None, None, -1), "x[::-1]"), (slice(None, None), "x[:]"), (..., "x[...]"), + ((None, 42), "x[None, 42]"), ], ) def test_fmt_idx(idx, expected): diff --git a/array_api_tests/pytest_helpers.py b/array_api_tests/pytest_helpers.py index 39513670..78797c30 100644 --- a/array_api_tests/pytest_helpers.py +++ b/array_api_tests/pytest_helpers.py @@ -25,7 +25,7 @@ "assert_keepdimable_shape", "assert_0d_equals", "assert_fill", - "assert_array", + "assert_array_elements", ] @@ -301,7 +301,7 @@ def assert_0d_equals( >>> x = xp.asarray([0, 1, 2]) >>> res = xp.asarray(x, copy=True) >>> res[0] = 42 - >>> assert_0d_equals('__setitem__', 'x[0]', x[0], 'x[0]', res[0]) + >>> assert_0d_equals('asarray', 'x[0]', x[0], 'x[0]', res[0]) is equivalent to @@ -374,28 +374,30 @@ def assert_fill( assert xp.all(xp.equal(out, xp.asarray(fill_value, dtype=dtype))), msg -def assert_array(func_name: str, out: Array, expected: Array, /, **kw): +def assert_array_elements( + func_name: str, out: Array, expected: Array, /, *, out_repr: str = "out", **kw +): """ - Assert array is (strictly) as expected, e.g. + Assert array elements are (strictly) as expected, e.g. >>> x = xp.arange(5) >>> out = xp.asarray(x) - >>> assert_array('asarray', out, x) + >>> assert_array_elements('asarray', out, x) is equivalent to >>> assert xp.all(out == x) """ - assert_dtype(func_name, out.dtype, expected.dtype) - assert_shape(func_name, out.shape, expected.shape, **kw) + dh.result_type(out.dtype, expected.dtype) # sanity check + assert_shape(func_name, out.shape, expected.shape, **kw) # sanity check f_func = f"[{func_name}({fmt_kw(kw)})]" if dh.is_float_dtype(out.dtype): for idx in sh.ndindex(out.shape): at_out = out[idx] at_expected = expected[idx] msg = ( - f"{sh.fmt_idx('out', idx)}={at_out}, should be {at_expected} " + f"{sh.fmt_idx(out_repr, idx)}={at_out}, should be {at_expected} " f"{f_func}" ) if xp.isnan(at_expected): @@ -411,6 +413,6 @@ def assert_array(func_name: str, out: Array, expected: Array, /, **kw): else: assert at_out == at_expected, msg else: - assert xp.all(out == expected), ( - f"out not as expected {f_func}\n" f"{out=}\n{expected=}" - ) + assert xp.all( + out == expected + ), f"{out_repr} not as expected {f_func}\n{out_repr}={out!r}\n{expected=}" diff --git a/array_api_tests/shape_helpers.py b/array_api_tests/shape_helpers.py index 9b3d001b..ba7d994e 100644 --- a/array_api_tests/shape_helpers.py +++ b/array_api_tests/shape_helpers.py @@ -156,6 +156,8 @@ def fmt_i(i: AtomicIndex) -> str: if i.step is not None: res += f":{i.step}" return res + elif i is None: + return "None" else: return "..." diff --git a/array_api_tests/test_array_object.py b/array_api_tests/test_array_object.py index 50db7e51..df3edb88 100644 --- a/array_api_tests/test_array_object.py +++ b/array_api_tests/test_array_object.py @@ -1,6 +1,6 @@ import math from itertools import product -from typing import List, get_args +from typing import List, Sequence, Tuple, Union, get_args import pytest from hypothesis import assume, given, note @@ -12,12 +12,15 @@ from . import pytest_helpers as ph from . import shape_helpers as sh from . import xps -from .typing import DataType, Param, Scalar, ScalarType, Shape +from .test_operators_and_elementwise_functions import oneway_promotable_dtypes +from .typing import DataType, Index, Param, Scalar, ScalarType, Shape pytestmark = pytest.mark.ci -def scalar_objects(dtype: DataType, shape: Shape) -> st.SearchStrategy[List[Scalar]]: +def scalar_objects( + dtype: DataType, shape: Shape +) -> st.SearchStrategy[Union[Scalar, List[Scalar]]]: """Generates scalars or nested sequences which are valid for xp.asarray()""" size = math.prod(shape) return st.lists(xps.from_dtype(dtype), min_size=size, max_size=size).map( @@ -25,17 +28,13 @@ def scalar_objects(dtype: DataType, shape: Shape) -> st.SearchStrategy[List[Scal ) -@given(hh.shapes(min_side=1), st.data()) # TODO: test 0-sided arrays -def test_getitem(shape, data): - dtype = data.draw(xps.scalar_dtypes(), label="dtype") - obj = data.draw(scalar_objects(dtype, shape), label="obj") - x = xp.asarray(obj, dtype=dtype) - note(f"{x=}") - key = data.draw(xps.indices(shape=shape, allow_newaxis=True), label="key") - - out = x[key] +def normalise_key(key: Index, shape: Shape) -> Tuple[Union[int, slice], ...]: + """ + Normalise an indexing key. - ph.assert_dtype("__getitem__", x.dtype, out.dtype) + * If a non-tuple index, wrap as a tuple. + * Represent ellipsis as equivalent slices. + """ _key = tuple(key) if isinstance(key, tuple) else (key,) if Ellipsis in _key: nonexpanding_key = tuple(i for i in _key if i is not None) @@ -44,71 +43,109 @@ def test_getitem(shape, data): slices = tuple(slice(None) for _ in range(start_a, stop_a)) start_pos = _key.index(Ellipsis) _key = _key[:start_pos] + slices + _key[start_pos + 1 :] + return _key + + +def get_indexed_axes_and_out_shape( + key: Tuple[Union[int, slice, None], ...], shape: Shape +) -> Tuple[Tuple[Sequence[int], ...], Shape]: + """ + From the (normalised) key and input shape, calculates: + + * indexed_axes: For each dimension, the axes which the key indexes. + * out_shape: The resulting shape of indexing an array (of the input shape) + with the key. + """ axes_indices = [] out_shape = [] a = 0 - for i in _key: + for i in key: if i is None: out_shape.append(1) else: + side = shape[a] if isinstance(i, int): - axes_indices.append([i]) + if i < 0: + i += side + axes_indices.append((i,)) else: - assert isinstance(i, slice) # sanity check - side = shape[a] indices = range(side)[i] axes_indices.append(indices) out_shape.append(len(indices)) a += 1 - out_shape = tuple(out_shape) + return tuple(axes_indices), tuple(out_shape) + + +@given(shape=hh.shapes(), dtype=xps.scalar_dtypes(), data=st.data()) +def test_getitem(shape, dtype, data): + zero_sided = any(side == 0 for side in shape) + if zero_sided: + x = xp.zeros(shape, dtype=dtype) + else: + obj = data.draw(scalar_objects(dtype, shape), label="obj") + x = xp.asarray(obj, dtype=dtype) + note(f"{x=}") + key = data.draw(xps.indices(shape=shape, allow_newaxis=True), label="key") + + out = x[key] + + ph.assert_dtype("__getitem__", x.dtype, out.dtype) + _key = normalise_key(key, shape) + axes_indices, out_shape = get_indexed_axes_and_out_shape(_key, shape) ph.assert_shape("__getitem__", out.shape, out_shape) - assume(all(len(indices) > 0 for indices in axes_indices)) - out_obj = [] - for idx in product(*axes_indices): - val = obj - for i in idx: - val = val[i] - out_obj.append(val) - out_obj = sh.reshape(out_obj, out_shape) - expected = xp.asarray(out_obj, dtype=dtype) - ph.assert_array("__getitem__", out, expected) - - -@given(hh.shapes(min_side=1), st.data()) # TODO: test 0-sided arrays -def test_setitem(shape, data): - dtype = data.draw(xps.scalar_dtypes(), label="dtype") - obj = data.draw(scalar_objects(dtype, shape), label="obj") - x = xp.asarray(obj, dtype=dtype) + out_zero_sided = any(side == 0 for side in out_shape) + if not zero_sided and not out_zero_sided: + out_obj = [] + for idx in product(*axes_indices): + val = obj + for i in idx: + val = val[i] + out_obj.append(val) + out_obj = sh.reshape(out_obj, out_shape) + expected = xp.asarray(out_obj, dtype=dtype) + ph.assert_array_elements("__getitem__", out, expected) + + +@given( + shape=hh.shapes(), + dtypes=oneway_promotable_dtypes(dh.all_dtypes), + data=st.data(), +) +def test_setitem(shape, dtypes, data): + zero_sided = any(side == 0 for side in shape) + if zero_sided: + x = xp.zeros(shape, dtype=dtypes.result_dtype) + else: + obj = data.draw(scalar_objects(dtypes.result_dtype, shape), label="obj") + x = xp.asarray(obj, dtype=dtypes.result_dtype) note(f"{x=}") - # TODO: test setting non-0d arrays - key = data.draw(xps.indices(shape=shape, max_dims=0), label="key") - value = data.draw( - xps.from_dtype(dtype) | xps.arrays(dtype=dtype, shape=()), label="value" - ) + key = data.draw(xps.indices(shape=shape), label="key") + _key = normalise_key(key, shape) + axes_indices, out_shape = get_indexed_axes_and_out_shape(_key, shape) + value_strat = xps.arrays(dtype=dtypes.result_dtype, shape=out_shape) + if out_shape == (): + # We can pass scalars if we're only indexing one element + value_strat |= xps.from_dtype(dtypes.result_dtype) + value = data.draw(value_strat, label="value") res = xp.asarray(x, copy=True) res[key] = value ph.assert_dtype("__setitem__", x.dtype, res.dtype, repr_name="x.dtype") ph.assert_shape("__setitem__", res.shape, x.shape, repr_name="x.shape") + f_res = sh.fmt_idx("x", key) if isinstance(value, get_args(Scalar)): - msg = f"x[{key}]={res[key]!r}, but should be {value=} [__setitem__()]" + msg = f"{f_res}={res[key]!r}, but should be {value=} [__setitem__()]" if math.isnan(value): assert xp.isnan(res[key]), msg else: assert res[key] == value, msg else: - ph.assert_0d_equals( - "__setitem__", "value", value, f"modified x[{key}]", res[key] - ) - _key = key if isinstance(key, tuple) else (key,) - assume(all(isinstance(i, int) for i in _key)) # TODO: normalise slices and ellipsis - _key = tuple(i if i >= 0 else s + i for i, s in zip(_key, x.shape)) - unaffected_indices = list(sh.ndindex(res.shape)) - unaffected_indices.remove(_key) + ph.assert_array_elements("__setitem__", res[key], value, out_repr=f_res) + unaffected_indices = set(sh.ndindex(res.shape)) - set(product(*axes_indices)) for idx in unaffected_indices: ph.assert_0d_equals( - "__setitem__", f"old x[{idx}]", x[idx], f"modified x[{idx}]", res[idx] + "__setitem__", f"old {f_res}", x[idx], f"modified {f_res}", res[idx] ) diff --git a/array_api_tests/test_creation_functions.py b/array_api_tests/test_creation_functions.py index f5cb6342..4ae92f1f 100644 --- a/array_api_tests/test_creation_functions.py +++ b/array_api_tests/test_creation_functions.py @@ -12,6 +12,7 @@ from . import pytest_helpers as ph from . import shape_helpers as sh from . import xps +from .test_operators_and_elementwise_functions import oneway_promotable_dtypes from .typing import DataType, Scalar pytestmark = pytest.mark.ci @@ -180,7 +181,7 @@ def test_arange(dtype, data): if dh.is_int_dtype(_dtype): elements = list(r) assume(out.size == len(elements)) - ph.assert_array("arange", out, xp.asarray(elements, dtype=_dtype)) + ph.assert_array_elements("arange", out, xp.asarray(elements, dtype=_dtype)) else: assume(out.size == size) if out.size > 0: @@ -245,11 +246,25 @@ def test_asarray_scalars(shape, data): ph.assert_scalar_equals("asarray", scalar_type, idx, v, v_expect, **kw) -@given(xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes()), st.data()) -def test_asarray_arrays(x, data): - # TODO: test other valid dtypes +def scalar_eq(s1: Scalar, s2: Scalar) -> bool: + if math.isnan(s1): + return math.isnan(s2) + else: + return s1 == s2 + + +@given( + shape=hh.shapes(), + dtypes=oneway_promotable_dtypes(dh.all_dtypes), + data=st.data(), +) +def test_asarray_arrays(shape, dtypes, data): + x = data.draw(xps.arrays(dtype=dtypes.input_dtype, shape=shape), label="x") + dtypes_strat = st.just(dtypes.input_dtype) + if dtypes.input_dtype == dtypes.result_dtype: + dtypes_strat |= st.none() kw = data.draw( - hh.kwargs(dtype=st.none() | st.just(x.dtype), copy=st.none() | st.booleans()), + hh.kwargs(dtype=dtypes_strat, copy=st.none() | st.booleans()), label="kw", ) @@ -261,27 +276,35 @@ def test_asarray_arrays(x, data): else: ph.assert_kw_dtype("asarray", dtype, out.dtype) ph.assert_shape("asarray", out.shape, x.shape) - if dtype is None or dtype == x.dtype: - ph.assert_array("asarray", out, x, **kw) - else: - pass # TODO + ph.assert_array_elements("asarray", out, x, **kw) copy = kw.get("copy", None) if copy is not None: + stype = dh.get_scalar_type(x.dtype) idx = data.draw(xps.indices(x.shape, max_dims=0), label="mutating idx") - _dtype = x.dtype if dtype is None else dtype - old_value = x[idx] + old_value = stype(x[idx]) + scalar_strat = xps.from_dtype(dtypes.input_dtype).filter( + lambda n: not scalar_eq(n, old_value) + ) value = data.draw( - xps.arrays(dtype=_dtype, shape=()).filter(lambda y: y != old_value), + scalar_strat | scalar_strat.map(lambda n: xp.asarray(n, dtype=x.dtype)), label="mutating value", ) x[idx] = value note(f"mutated {x=}") + # sanity check + ph.assert_scalar_equals( + "__setitem__", stype, idx, stype(x[idx]), value, repr_name="x" + ) + new_out_value = stype(out[idx]) + f_out = f"{sh.fmt_idx('out', idx)}={new_out_value}" if copy: - assert not xp.all( - out == x - ), f"xp.all(out == x)=True, but should be False after x was mutated\n{out=}" - elif copy is False: - pass # TODO + assert scalar_eq( + new_out_value, old_value + ), f"{f_out}, but should be {old_value} even after x was mutated" + else: + assert scalar_eq( + new_out_value, value + ), f"{f_out}, but should be {value} after x was mutated" @given(hh.shapes(), hh.kwargs(dtype=st.none() | hh.shared_dtypes)) @@ -452,7 +475,7 @@ def test_linspace(num, dtype, endpoint, data): # the first num elements when endpoint=False expected = xp.linspace(start, stop, num + 1, dtype=dtype, endpoint=True) expected = expected[:-1] - ph.assert_array("linspace", out, expected) + ph.assert_array_elements("linspace", out, expected) @given(dtype=xps.numeric_dtypes(), data=st.data()) diff --git a/array_api_tests/test_operators_and_elementwise_functions.py b/array_api_tests/test_operators_and_elementwise_functions.py index beff36de..d4349372 100644 --- a/array_api_tests/test_operators_and_elementwise_functions.py +++ b/array_api_tests/test_operators_and_elementwise_functions.py @@ -1124,7 +1124,7 @@ def test_positive(ctx, data): ph.assert_dtype(ctx.func_name, x.dtype, out.dtype) ph.assert_shape(ctx.func_name, out.shape, x.shape) - ph.assert_array(ctx.func_name, out, x) + ph.assert_array_elements(ctx.func_name, out, x) @pytest.mark.parametrize("ctx", make_binary_params("pow", dh.numeric_dtypes)) diff --git a/array_api_tests/typing.py b/array_api_tests/typing.py index da8652ae..f0ed8c50 100644 --- a/array_api_tests/typing.py +++ b/array_api_tests/typing.py @@ -16,6 +16,6 @@ ScalarType = Union[Type[bool], Type[int], Type[float]] Array = Any Shape = Tuple[int, ...] -AtomicIndex = Union[int, "ellipsis", slice] # noqa +AtomicIndex = Union[int, "ellipsis", slice, None] # noqa Index = Union[AtomicIndex, Tuple[AtomicIndex, ...]] Param = Tuple