From 5cd2e72c8e08f3e26f56d098fdb69f70db8c09dc Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Wed, 18 May 2022 16:20:40 +0100 Subject: [PATCH 01/10] Remove `dh.is_float_dtype()` TODO > Return True for float dtypes that aren't part of the spec e.g. np.float16 Such utility hasn't ended up being desired anywhere --- array_api_tests/dtype_helpers.py | 1 - 1 file changed, 1 deletion(-) diff --git a/array_api_tests/dtype_helpers.py b/array_api_tests/dtype_helpers.py index 94576ee3..f691d771 100644 --- a/array_api_tests/dtype_helpers.py +++ b/array_api_tests/dtype_helpers.py @@ -121,7 +121,6 @@ def is_float_dtype(dtype): # See https://github.com/numpy/numpy/issues/18434 if dtype is None: return False - # TODO: Return True for float dtypes that aren't part of the spec e.g. np.float16 return dtype in float_dtypes From 34fda926a1d445ccadeb7828d8873e610f83cc68 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Wed, 18 May 2022 17:11:20 +0100 Subject: [PATCH 02/10] Test 0-sided array in `test_getitem` --- array_api_tests/test_array_object.py | 31 ++++++++++++++++------------ 1 file changed, 18 insertions(+), 13 deletions(-) diff --git a/array_api_tests/test_array_object.py b/array_api_tests/test_array_object.py index 50db7e51..3cc590ef 100644 --- a/array_api_tests/test_array_object.py +++ b/array_api_tests/test_array_object.py @@ -25,11 +25,15 @@ def scalar_objects(dtype: DataType, shape: Shape) -> st.SearchStrategy[List[Scal ) -@given(hh.shapes(min_side=1), st.data()) # TODO: test 0-sided arrays +@given(hh.shapes(), st.data()) def test_getitem(shape, data): dtype = data.draw(xps.scalar_dtypes(), label="dtype") - obj = data.draw(scalar_objects(dtype, shape), label="obj") - x = xp.asarray(obj, dtype=dtype) + zero_sided = any(side == 0 for side in shape) + if zero_sided: + x = xp.ones(shape, dtype=dtype) + else: + obj = data.draw(scalar_objects(dtype, shape), label="obj") + x = xp.asarray(obj, dtype=dtype) note(f"{x=}") key = data.draw(xps.indices(shape=shape, allow_newaxis=True), label="key") @@ -62,16 +66,17 @@ def test_getitem(shape, data): a += 1 out_shape = tuple(out_shape) ph.assert_shape("__getitem__", out.shape, out_shape) - assume(all(len(indices) > 0 for indices in axes_indices)) - out_obj = [] - for idx in product(*axes_indices): - val = obj - for i in idx: - val = val[i] - out_obj.append(val) - out_obj = sh.reshape(out_obj, out_shape) - expected = xp.asarray(out_obj, dtype=dtype) - ph.assert_array("__getitem__", out, expected) + out_zero_sided = any(side == 0 for side in out_shape) + if not zero_sided and not out_zero_sided: + out_obj = [] + for idx in product(*axes_indices): + val = obj + for i in idx: + val = val[i] + out_obj.append(val) + out_obj = sh.reshape(out_obj, out_shape) + expected = xp.asarray(out_obj, dtype=dtype) + ph.assert_array("__getitem__", out, expected) @given(hh.shapes(min_side=1), st.data()) # TODO: test 0-sided arrays From ff3fed431457da9a27d307145c606b4b2fb087e2 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Wed, 18 May 2022 19:24:12 +0100 Subject: [PATCH 03/10] Test non-0d-resulting keys in `test_setitem` Also `ph.assert_array()` -> `ph.assert_array_elements()` --- array_api_tests/meta/test_pytest_helpers.py | 10 +-- array_api_tests/pytest_helpers.py | 22 +++--- array_api_tests/test_array_object.py | 72 +++++++++++-------- array_api_tests/test_creation_functions.py | 6 +- ...est_operators_and_elementwise_functions.py | 2 +- 5 files changed, 65 insertions(+), 47 deletions(-) diff --git a/array_api_tests/meta/test_pytest_helpers.py b/array_api_tests/meta/test_pytest_helpers.py index 117e2b11..a6851a15 100644 --- a/array_api_tests/meta/test_pytest_helpers.py +++ b/array_api_tests/meta/test_pytest_helpers.py @@ -13,10 +13,10 @@ def test_assert_dtype(): ph.assert_dtype("single_bool_func", [xp.uint8], xp.bool, xp.bool) -def test_assert_array(): - ph.assert_array("int zeros", xp.asarray(0), xp.asarray(0)) - ph.assert_array("pos zeros", xp.asarray(0.0), xp.asarray(0.0)) +def test_assert_array_elements(): + ph.assert_array_elements("int zeros", xp.asarray(0), xp.asarray(0)) + ph.assert_array_elements("pos zeros", xp.asarray(0.0), xp.asarray(0.0)) with raises(AssertionError): - ph.assert_array("mixed sign zeros", xp.asarray(0.0), xp.asarray(-0.0)) + ph.assert_array_elements("mixed sign zeros", xp.asarray(0.0), xp.asarray(-0.0)) with raises(AssertionError): - ph.assert_array("mixed sign zeros", xp.asarray(-0.0), xp.asarray(0.0)) + ph.assert_array_elements("mixed sign zeros", xp.asarray(-0.0), xp.asarray(0.0)) diff --git a/array_api_tests/pytest_helpers.py b/array_api_tests/pytest_helpers.py index 39513670..af6766f2 100644 --- a/array_api_tests/pytest_helpers.py +++ b/array_api_tests/pytest_helpers.py @@ -25,7 +25,7 @@ "assert_keepdimable_shape", "assert_0d_equals", "assert_fill", - "assert_array", + "assert_array_elements", ] @@ -374,28 +374,30 @@ def assert_fill( assert xp.all(xp.equal(out, xp.asarray(fill_value, dtype=dtype))), msg -def assert_array(func_name: str, out: Array, expected: Array, /, **kw): +def assert_array_elements( + func_name: str, out: Array, expected: Array, /, *, out_repr: str = "out", **kw +): """ - Assert array is (strictly) as expected, e.g. + Assert array elements are (strictly) as expected, e.g. >>> x = xp.arange(5) >>> out = xp.asarray(x) - >>> assert_array('asarray', out, x) + >>> assert_array_elements('asarray', out, x) is equivalent to >>> assert xp.all(out == x) """ - assert_dtype(func_name, out.dtype, expected.dtype) - assert_shape(func_name, out.shape, expected.shape, **kw) + dh.result_type(out.dtype, expected.dtype) # sanity check + assert_shape(func_name, out.shape, expected.shape, **kw) # sanity check f_func = f"[{func_name}({fmt_kw(kw)})]" if dh.is_float_dtype(out.dtype): for idx in sh.ndindex(out.shape): at_out = out[idx] at_expected = expected[idx] msg = ( - f"{sh.fmt_idx('out', idx)}={at_out}, should be {at_expected} " + f"{sh.fmt_idx(out_repr, idx)}={at_out}, should be {at_expected} " f"{f_func}" ) if xp.isnan(at_expected): @@ -411,6 +413,6 @@ def assert_array(func_name: str, out: Array, expected: Array, /, **kw): else: assert at_out == at_expected, msg else: - assert xp.all(out == expected), ( - f"out not as expected {f_func}\n" f"{out=}\n{expected=}" - ) + assert xp.all( + out == expected + ), f"{out_repr} not as expected {f_func}\n{out_repr}={out!r}\n{expected=}" diff --git a/array_api_tests/test_array_object.py b/array_api_tests/test_array_object.py index 3cc590ef..35b48052 100644 --- a/array_api_tests/test_array_object.py +++ b/array_api_tests/test_array_object.py @@ -25,12 +25,11 @@ def scalar_objects(dtype: DataType, shape: Shape) -> st.SearchStrategy[List[Scal ) -@given(hh.shapes(), st.data()) -def test_getitem(shape, data): - dtype = data.draw(xps.scalar_dtypes(), label="dtype") +@given(shape=hh.shapes(), dtype=xps.scalar_dtypes(), data=st.data()) +def test_getitem(shape, dtype, data): zero_sided = any(side == 0 for side in shape) if zero_sided: - x = xp.ones(shape, dtype=dtype) + x = xp.zeros(shape, dtype=dtype) else: obj = data.draw(scalar_objects(dtype, shape), label="obj") x = xp.asarray(obj, dtype=dtype) @@ -76,45 +75,62 @@ def test_getitem(shape, data): out_obj.append(val) out_obj = sh.reshape(out_obj, out_shape) expected = xp.asarray(out_obj, dtype=dtype) - ph.assert_array("__getitem__", out, expected) + ph.assert_array_elements("__getitem__", out, expected) -@given(hh.shapes(min_side=1), st.data()) # TODO: test 0-sided arrays -def test_setitem(shape, data): - dtype = data.draw(xps.scalar_dtypes(), label="dtype") - obj = data.draw(scalar_objects(dtype, shape), label="obj") - x = xp.asarray(obj, dtype=dtype) +@given(shape=hh.shapes(min_side=1), dtype=xps.scalar_dtypes(), data=st.data()) +def test_setitem(shape, dtype, data): + zero_sided = any(side == 0 for side in shape) + if zero_sided: + x = xp.zeros(shape, dtype=dtype) + else: + obj = data.draw(scalar_objects(dtype, shape), label="obj") + x = xp.asarray(obj, dtype=dtype) note(f"{x=}") - # TODO: test setting non-0d arrays - key = data.draw(xps.indices(shape=shape, max_dims=0), label="key") - value = data.draw( - xps.from_dtype(dtype) | xps.arrays(dtype=dtype, shape=()), label="value" - ) + key = data.draw(xps.indices(shape=shape), label="key") + _key = tuple(key) if isinstance(key, tuple) else (key,) + if Ellipsis in _key: + nonexpanding_key = tuple(i for i in _key if i is not None) + start_a = nonexpanding_key.index(Ellipsis) + stop_a = start_a + (len(shape) - (len(nonexpanding_key) - 1)) + slices = tuple(slice(None) for _ in range(start_a, stop_a)) + start_pos = _key.index(Ellipsis) + _key = _key[:start_pos] + slices + _key[start_pos + 1 :] + out_shape = [] + for a, i in enumerate(_key): + if isinstance(i, slice): + side = shape[a] + indices = range(side)[i] + out_shape.append(len(indices)) + out_shape = tuple(out_shape) + value_strat = xps.arrays(dtype=dtype, shape=out_shape) + if out_shape == (): + # We can pass scalars if we're only indexing one element + value_strat |= xps.from_dtype(dtype) + value = data.draw(value_strat, label="value") res = xp.asarray(x, copy=True) res[key] = value ph.assert_dtype("__setitem__", x.dtype, res.dtype, repr_name="x.dtype") ph.assert_shape("__setitem__", res.shape, x.shape, repr_name="x.shape") + f_res = f"res[{sh.fmt_idx('x', key)}]" if isinstance(value, get_args(Scalar)): - msg = f"x[{key}]={res[key]!r}, but should be {value=} [__setitem__()]" + msg = f"{f_res}={res[key]!r}, but should be {value=} [__setitem__()]" if math.isnan(value): assert xp.isnan(res[key]), msg else: assert res[key] == value, msg else: - ph.assert_0d_equals( - "__setitem__", "value", value, f"modified x[{key}]", res[key] - ) - _key = key if isinstance(key, tuple) else (key,) - assume(all(isinstance(i, int) for i in _key)) # TODO: normalise slices and ellipsis - _key = tuple(i if i >= 0 else s + i for i, s in zip(_key, x.shape)) - unaffected_indices = list(sh.ndindex(res.shape)) - unaffected_indices.remove(_key) - for idx in unaffected_indices: - ph.assert_0d_equals( - "__setitem__", f"old x[{idx}]", x[idx], f"modified x[{idx}]", res[idx] - ) + ph.assert_array_elements("__setitem__", res[key], value, out_repr=f_res) + if all(isinstance(i, int) for i in _key): # TODO: normalise slices and ellipsis + _key = tuple(i if i >= 0 else s + i for i, s in zip(_key, x.shape)) + unaffected_indices = list(sh.ndindex(res.shape)) + unaffected_indices.remove(_key) + for idx in unaffected_indices: + ph.assert_0d_equals( + "__setitem__", f"old x[{idx}]", x[idx], f"modified x[{idx}]", res[idx] + ) @pytest.mark.data_dependent_shapes diff --git a/array_api_tests/test_creation_functions.py b/array_api_tests/test_creation_functions.py index f5cb6342..46a09cd9 100644 --- a/array_api_tests/test_creation_functions.py +++ b/array_api_tests/test_creation_functions.py @@ -180,7 +180,7 @@ def test_arange(dtype, data): if dh.is_int_dtype(_dtype): elements = list(r) assume(out.size == len(elements)) - ph.assert_array("arange", out, xp.asarray(elements, dtype=_dtype)) + ph.assert_array_elements("arange", out, xp.asarray(elements, dtype=_dtype)) else: assume(out.size == size) if out.size > 0: @@ -262,7 +262,7 @@ def test_asarray_arrays(x, data): ph.assert_kw_dtype("asarray", dtype, out.dtype) ph.assert_shape("asarray", out.shape, x.shape) if dtype is None or dtype == x.dtype: - ph.assert_array("asarray", out, x, **kw) + ph.assert_array_elements("asarray", out, x, **kw) else: pass # TODO copy = kw.get("copy", None) @@ -452,7 +452,7 @@ def test_linspace(num, dtype, endpoint, data): # the first num elements when endpoint=False expected = xp.linspace(start, stop, num + 1, dtype=dtype, endpoint=True) expected = expected[:-1] - ph.assert_array("linspace", out, expected) + ph.assert_array_elements("linspace", out, expected) @given(dtype=xps.numeric_dtypes(), data=st.data()) diff --git a/array_api_tests/test_operators_and_elementwise_functions.py b/array_api_tests/test_operators_and_elementwise_functions.py index beff36de..d4349372 100644 --- a/array_api_tests/test_operators_and_elementwise_functions.py +++ b/array_api_tests/test_operators_and_elementwise_functions.py @@ -1124,7 +1124,7 @@ def test_positive(ctx, data): ph.assert_dtype(ctx.func_name, x.dtype, out.dtype) ph.assert_shape(ctx.func_name, out.shape, x.shape) - ph.assert_array(ctx.func_name, out, x) + ph.assert_array_elements(ctx.func_name, out, x) @pytest.mark.parametrize("ctx", make_binary_params("pow", dh.numeric_dtypes)) From 7feaa28c2a98586dff7bd7ba09208c8c5cc30bae Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Thu, 19 May 2022 12:13:44 +0100 Subject: [PATCH 04/10] Test different dtypes in `test_setitem` --- array_api_tests/test_array_object.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/array_api_tests/test_array_object.py b/array_api_tests/test_array_object.py index 35b48052..c56871aa 100644 --- a/array_api_tests/test_array_object.py +++ b/array_api_tests/test_array_object.py @@ -12,6 +12,7 @@ from . import pytest_helpers as ph from . import shape_helpers as sh from . import xps +from .test_operators_and_elementwise_functions import oneway_promotable_dtypes from .typing import DataType, Param, Scalar, ScalarType, Shape pytestmark = pytest.mark.ci @@ -78,14 +79,18 @@ def test_getitem(shape, dtype, data): ph.assert_array_elements("__getitem__", out, expected) -@given(shape=hh.shapes(min_side=1), dtype=xps.scalar_dtypes(), data=st.data()) -def test_setitem(shape, dtype, data): +@given( + shape=hh.shapes(), + dtypes=oneway_promotable_dtypes(dh.all_dtypes), + data=st.data(), +) +def test_setitem(shape, dtypes, data): zero_sided = any(side == 0 for side in shape) if zero_sided: - x = xp.zeros(shape, dtype=dtype) + x = xp.zeros(shape, dtype=dtypes.result_dtype) else: - obj = data.draw(scalar_objects(dtype, shape), label="obj") - x = xp.asarray(obj, dtype=dtype) + obj = data.draw(scalar_objects(dtypes.result_dtype, shape), label="obj") + x = xp.asarray(obj, dtype=dtypes.result_dtype) note(f"{x=}") key = data.draw(xps.indices(shape=shape), label="key") _key = tuple(key) if isinstance(key, tuple) else (key,) @@ -103,10 +108,10 @@ def test_setitem(shape, dtype, data): indices = range(side)[i] out_shape.append(len(indices)) out_shape = tuple(out_shape) - value_strat = xps.arrays(dtype=dtype, shape=out_shape) + value_strat = xps.arrays(dtype=dtypes.result_dtype, shape=out_shape) if out_shape == (): # We can pass scalars if we're only indexing one element - value_strat |= xps.from_dtype(dtype) + value_strat |= xps.from_dtype(dtypes.result_dtype) value = data.draw(value_strat, label="value") res = xp.asarray(x, copy=True) From b1dcf77aae986d4f8ee35ef3e2fbeb317801edce Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Thu, 19 May 2022 12:59:43 +0100 Subject: [PATCH 05/10] Test unaffected indices more wholly in `test_setitem` --- array_api_tests/pytest_helpers.py | 2 +- array_api_tests/test_array_object.py | 34 ++++++++++++++++++---------- 2 files changed, 23 insertions(+), 13 deletions(-) diff --git a/array_api_tests/pytest_helpers.py b/array_api_tests/pytest_helpers.py index af6766f2..78797c30 100644 --- a/array_api_tests/pytest_helpers.py +++ b/array_api_tests/pytest_helpers.py @@ -301,7 +301,7 @@ def assert_0d_equals( >>> x = xp.asarray([0, 1, 2]) >>> res = xp.asarray(x, copy=True) >>> res[0] = 42 - >>> assert_0d_equals('__setitem__', 'x[0]', x[0], 'x[0]', res[0]) + >>> assert_0d_equals('asarray', 'x[0]', x[0], 'x[0]', res[0]) is equivalent to diff --git a/array_api_tests/test_array_object.py b/array_api_tests/test_array_object.py index c56871aa..6e9ed8ba 100644 --- a/array_api_tests/test_array_object.py +++ b/array_api_tests/test_array_object.py @@ -55,11 +55,13 @@ def test_getitem(shape, dtype, data): if i is None: out_shape.append(1) else: + side = shape[a] if isinstance(i, int): + if i < 0: + i += side axes_indices.append([i]) else: assert isinstance(i, slice) # sanity check - side = shape[a] indices = range(side)[i] axes_indices.append(indices) out_shape.append(len(indices)) @@ -102,9 +104,9 @@ def test_setitem(shape, dtypes, data): start_pos = _key.index(Ellipsis) _key = _key[:start_pos] + slices + _key[start_pos + 1 :] out_shape = [] - for a, i in enumerate(_key): + + for i, side in zip(_key, shape): if isinstance(i, slice): - side = shape[a] indices = range(side)[i] out_shape.append(len(indices)) out_shape = tuple(out_shape) @@ -119,7 +121,8 @@ def test_setitem(shape, dtypes, data): ph.assert_dtype("__setitem__", x.dtype, res.dtype, repr_name="x.dtype") ph.assert_shape("__setitem__", res.shape, x.shape, repr_name="x.shape") - f_res = f"res[{sh.fmt_idx('x', key)}]" + + f_res = sh.fmt_idx("x", key) if isinstance(value, get_args(Scalar)): msg = f"{f_res}={res[key]!r}, but should be {value=} [__setitem__()]" if math.isnan(value): @@ -128,14 +131,21 @@ def test_setitem(shape, dtypes, data): assert res[key] == value, msg else: ph.assert_array_elements("__setitem__", res[key], value, out_repr=f_res) - if all(isinstance(i, int) for i in _key): # TODO: normalise slices and ellipsis - _key = tuple(i if i >= 0 else s + i for i, s in zip(_key, x.shape)) - unaffected_indices = list(sh.ndindex(res.shape)) - unaffected_indices.remove(_key) - for idx in unaffected_indices: - ph.assert_0d_equals( - "__setitem__", f"old x[{idx}]", x[idx], f"modified x[{idx}]", res[idx] - ) + + axes_indices = [] + for i, side in zip(_key, shape): + if isinstance(i, int): + if i < 0: + i += side + axes_indices.append([i]) + else: + indices = range(side)[i] + axes_indices.append(indices) + unaffected_indices = set(sh.ndindex(res.shape)) - set(product(*axes_indices)) + for idx in unaffected_indices: + ph.assert_0d_equals( + "__setitem__", f"old {f_res}", x[idx], f"modified {f_res}", res[idx] + ) @pytest.mark.data_dependent_shapes From a4dd0759c78aa55db0e9febe16dff6445e7cf1eb Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Mon, 23 May 2022 11:09:51 +0100 Subject: [PATCH 06/10] Fix `scalar_objects()` typing --- array_api_tests/test_array_object.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/array_api_tests/test_array_object.py b/array_api_tests/test_array_object.py index 6e9ed8ba..a8aaaf01 100644 --- a/array_api_tests/test_array_object.py +++ b/array_api_tests/test_array_object.py @@ -1,6 +1,6 @@ import math from itertools import product -from typing import List, get_args +from typing import List, Union, get_args import pytest from hypothesis import assume, given, note @@ -18,7 +18,9 @@ pytestmark = pytest.mark.ci -def scalar_objects(dtype: DataType, shape: Shape) -> st.SearchStrategy[List[Scalar]]: +def scalar_objects( + dtype: DataType, shape: Shape +) -> st.SearchStrategy[Union[Scalar, List[Scalar]]]: """Generates scalars or nested sequences which are valid for xp.asarray()""" size = math.prod(shape) return st.lists(xps.from_dtype(dtype), min_size=size, max_size=size).map( From 131037d30e6b068ba0394d7b866423753f15dfd0 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Mon, 23 May 2022 11:37:06 +0100 Subject: [PATCH 07/10] `normalise_key()` util for indexing tests --- array_api_tests/test_array_object.py | 38 +++++++++++++++------------- array_api_tests/typing.py | 2 +- 2 files changed, 22 insertions(+), 18 deletions(-) diff --git a/array_api_tests/test_array_object.py b/array_api_tests/test_array_object.py index a8aaaf01..5ea5e1a7 100644 --- a/array_api_tests/test_array_object.py +++ b/array_api_tests/test_array_object.py @@ -13,7 +13,7 @@ from . import shape_helpers as sh from . import xps from .test_operators_and_elementwise_functions import oneway_promotable_dtypes -from .typing import DataType, Param, Scalar, ScalarType, Shape +from .typing import DataType, Index, Param, Scalar, ScalarType, Shape pytestmark = pytest.mark.ci @@ -28,6 +28,24 @@ def scalar_objects( ) +def normalise_key(key: Index, shape: Shape): + """ + Normalise an indexing key. + + * If a non-tuple index, wrap as a tuple. + * Represent ellipsis as equivalent slices. + """ + _key = tuple(key) if isinstance(key, tuple) else (key,) + if Ellipsis in _key: + nonexpanding_key = tuple(i for i in _key if i is not None) + start_a = nonexpanding_key.index(Ellipsis) + stop_a = start_a + (len(shape) - (len(nonexpanding_key) - 1)) + slices = tuple(slice(None) for _ in range(start_a, stop_a)) + start_pos = _key.index(Ellipsis) + _key = _key[:start_pos] + slices + _key[start_pos + 1 :] + return _key + + @given(shape=hh.shapes(), dtype=xps.scalar_dtypes(), data=st.data()) def test_getitem(shape, dtype, data): zero_sided = any(side == 0 for side in shape) @@ -42,14 +60,7 @@ def test_getitem(shape, dtype, data): out = x[key] ph.assert_dtype("__getitem__", x.dtype, out.dtype) - _key = tuple(key) if isinstance(key, tuple) else (key,) - if Ellipsis in _key: - nonexpanding_key = tuple(i for i in _key if i is not None) - start_a = nonexpanding_key.index(Ellipsis) - stop_a = start_a + (len(shape) - (len(nonexpanding_key) - 1)) - slices = tuple(slice(None) for _ in range(start_a, stop_a)) - start_pos = _key.index(Ellipsis) - _key = _key[:start_pos] + slices + _key[start_pos + 1 :] + _key = normalise_key(key, shape) axes_indices = [] out_shape = [] a = 0 @@ -97,14 +108,7 @@ def test_setitem(shape, dtypes, data): x = xp.asarray(obj, dtype=dtypes.result_dtype) note(f"{x=}") key = data.draw(xps.indices(shape=shape), label="key") - _key = tuple(key) if isinstance(key, tuple) else (key,) - if Ellipsis in _key: - nonexpanding_key = tuple(i for i in _key if i is not None) - start_a = nonexpanding_key.index(Ellipsis) - stop_a = start_a + (len(shape) - (len(nonexpanding_key) - 1)) - slices = tuple(slice(None) for _ in range(start_a, stop_a)) - start_pos = _key.index(Ellipsis) - _key = _key[:start_pos] + slices + _key[start_pos + 1 :] + _key = normalise_key(key, shape) out_shape = [] for i, side in zip(_key, shape): diff --git a/array_api_tests/typing.py b/array_api_tests/typing.py index da8652ae..f0ed8c50 100644 --- a/array_api_tests/typing.py +++ b/array_api_tests/typing.py @@ -16,6 +16,6 @@ ScalarType = Union[Type[bool], Type[int], Type[float]] Array = Any Shape = Tuple[int, ...] -AtomicIndex = Union[int, "ellipsis", slice] # noqa +AtomicIndex = Union[int, "ellipsis", slice, None] # noqa Index = Union[AtomicIndex, Tuple[AtomicIndex, ...]] Param = Tuple From 373dd4825582f648dacb449bfc913b5e3b0d3b92 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Mon, 23 May 2022 12:15:42 +0100 Subject: [PATCH 08/10] `get_indexed_axes_and_out_shape()` util for indexing tests --- array_api_tests/test_array_object.py | 71 +++++++++++++--------------- 1 file changed, 33 insertions(+), 38 deletions(-) diff --git a/array_api_tests/test_array_object.py b/array_api_tests/test_array_object.py index 5ea5e1a7..df3edb88 100644 --- a/array_api_tests/test_array_object.py +++ b/array_api_tests/test_array_object.py @@ -1,6 +1,6 @@ import math from itertools import product -from typing import List, Union, get_args +from typing import List, Sequence, Tuple, Union, get_args import pytest from hypothesis import assume, given, note @@ -28,7 +28,7 @@ def scalar_objects( ) -def normalise_key(key: Index, shape: Shape): +def normalise_key(key: Index, shape: Shape) -> Tuple[Union[int, slice], ...]: """ Normalise an indexing key. @@ -46,25 +46,20 @@ def normalise_key(key: Index, shape: Shape): return _key -@given(shape=hh.shapes(), dtype=xps.scalar_dtypes(), data=st.data()) -def test_getitem(shape, dtype, data): - zero_sided = any(side == 0 for side in shape) - if zero_sided: - x = xp.zeros(shape, dtype=dtype) - else: - obj = data.draw(scalar_objects(dtype, shape), label="obj") - x = xp.asarray(obj, dtype=dtype) - note(f"{x=}") - key = data.draw(xps.indices(shape=shape, allow_newaxis=True), label="key") - - out = x[key] +def get_indexed_axes_and_out_shape( + key: Tuple[Union[int, slice, None], ...], shape: Shape +) -> Tuple[Tuple[Sequence[int], ...], Shape]: + """ + From the (normalised) key and input shape, calculates: - ph.assert_dtype("__getitem__", x.dtype, out.dtype) - _key = normalise_key(key, shape) + * indexed_axes: For each dimension, the axes which the key indexes. + * out_shape: The resulting shape of indexing an array (of the input shape) + with the key. + """ axes_indices = [] out_shape = [] a = 0 - for i in _key: + for i in key: if i is None: out_shape.append(1) else: @@ -72,14 +67,31 @@ def test_getitem(shape, dtype, data): if isinstance(i, int): if i < 0: i += side - axes_indices.append([i]) + axes_indices.append((i,)) else: - assert isinstance(i, slice) # sanity check indices = range(side)[i] axes_indices.append(indices) out_shape.append(len(indices)) a += 1 - out_shape = tuple(out_shape) + return tuple(axes_indices), tuple(out_shape) + + +@given(shape=hh.shapes(), dtype=xps.scalar_dtypes(), data=st.data()) +def test_getitem(shape, dtype, data): + zero_sided = any(side == 0 for side in shape) + if zero_sided: + x = xp.zeros(shape, dtype=dtype) + else: + obj = data.draw(scalar_objects(dtype, shape), label="obj") + x = xp.asarray(obj, dtype=dtype) + note(f"{x=}") + key = data.draw(xps.indices(shape=shape, allow_newaxis=True), label="key") + + out = x[key] + + ph.assert_dtype("__getitem__", x.dtype, out.dtype) + _key = normalise_key(key, shape) + axes_indices, out_shape = get_indexed_axes_and_out_shape(_key, shape) ph.assert_shape("__getitem__", out.shape, out_shape) out_zero_sided = any(side == 0 for side in out_shape) if not zero_sided and not out_zero_sided: @@ -109,13 +121,7 @@ def test_setitem(shape, dtypes, data): note(f"{x=}") key = data.draw(xps.indices(shape=shape), label="key") _key = normalise_key(key, shape) - out_shape = [] - - for i, side in zip(_key, shape): - if isinstance(i, slice): - indices = range(side)[i] - out_shape.append(len(indices)) - out_shape = tuple(out_shape) + axes_indices, out_shape = get_indexed_axes_and_out_shape(_key, shape) value_strat = xps.arrays(dtype=dtypes.result_dtype, shape=out_shape) if out_shape == (): # We can pass scalars if we're only indexing one element @@ -127,7 +133,6 @@ def test_setitem(shape, dtypes, data): ph.assert_dtype("__setitem__", x.dtype, res.dtype, repr_name="x.dtype") ph.assert_shape("__setitem__", res.shape, x.shape, repr_name="x.shape") - f_res = sh.fmt_idx("x", key) if isinstance(value, get_args(Scalar)): msg = f"{f_res}={res[key]!r}, but should be {value=} [__setitem__()]" @@ -137,16 +142,6 @@ def test_setitem(shape, dtypes, data): assert res[key] == value, msg else: ph.assert_array_elements("__setitem__", res[key], value, out_repr=f_res) - - axes_indices = [] - for i, side in zip(_key, shape): - if isinstance(i, int): - if i < 0: - i += side - axes_indices.append([i]) - else: - indices = range(side)[i] - axes_indices.append(indices) unaffected_indices = set(sh.ndindex(res.shape)) - set(product(*axes_indices)) for idx in unaffected_indices: ph.assert_0d_equals( From 52e835e58983412753ddffa160b8935463c68ee2 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Mon, 23 May 2022 12:21:24 +0100 Subject: [PATCH 09/10] Support `None` in `sh.fmt_i()` --- array_api_tests/meta/test_utils.py | 2 ++ array_api_tests/shape_helpers.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/array_api_tests/meta/test_utils.py b/array_api_tests/meta/test_utils.py index 1fe3ca66..268a81aa 100644 --- a/array_api_tests/meta/test_utils.py +++ b/array_api_tests/meta/test_utils.py @@ -91,6 +91,7 @@ def test_roll_ndindex(shape, shifts, axes, expected): ((), "x"), (42, "x[42]"), ((42,), "x[42]"), + ((42, 7), "x[42, 7]"), (slice(None, 2), "x[:2]"), (slice(2, None), "x[2:]"), (slice(0, 2), "x[0:2]"), @@ -98,6 +99,7 @@ def test_roll_ndindex(shape, shifts, axes, expected): (slice(None, None, -1), "x[::-1]"), (slice(None, None), "x[:]"), (..., "x[...]"), + ((None, 42), "x[None, 42]"), ], ) def test_fmt_idx(idx, expected): diff --git a/array_api_tests/shape_helpers.py b/array_api_tests/shape_helpers.py index 9b3d001b..ba7d994e 100644 --- a/array_api_tests/shape_helpers.py +++ b/array_api_tests/shape_helpers.py @@ -156,6 +156,8 @@ def fmt_i(i: AtomicIndex) -> str: if i.step is not None: res += f":{i.step}" return res + elif i is None: + return "None" else: return "..." From 8a5103bc91978102c8563d8ae6392c329305aaa6 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Mon, 23 May 2022 15:35:32 +0100 Subject: [PATCH 10/10] `test_asarray_arrays` improvements * Test all possible dtype kwargs * Fix erroneous nan equals * Clean up copy testing --- array_api_tests/test_creation_functions.py | 55 +++++++++++++++------- 1 file changed, 39 insertions(+), 16 deletions(-) diff --git a/array_api_tests/test_creation_functions.py b/array_api_tests/test_creation_functions.py index 46a09cd9..4ae92f1f 100644 --- a/array_api_tests/test_creation_functions.py +++ b/array_api_tests/test_creation_functions.py @@ -12,6 +12,7 @@ from . import pytest_helpers as ph from . import shape_helpers as sh from . import xps +from .test_operators_and_elementwise_functions import oneway_promotable_dtypes from .typing import DataType, Scalar pytestmark = pytest.mark.ci @@ -245,11 +246,25 @@ def test_asarray_scalars(shape, data): ph.assert_scalar_equals("asarray", scalar_type, idx, v, v_expect, **kw) -@given(xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes()), st.data()) -def test_asarray_arrays(x, data): - # TODO: test other valid dtypes +def scalar_eq(s1: Scalar, s2: Scalar) -> bool: + if math.isnan(s1): + return math.isnan(s2) + else: + return s1 == s2 + + +@given( + shape=hh.shapes(), + dtypes=oneway_promotable_dtypes(dh.all_dtypes), + data=st.data(), +) +def test_asarray_arrays(shape, dtypes, data): + x = data.draw(xps.arrays(dtype=dtypes.input_dtype, shape=shape), label="x") + dtypes_strat = st.just(dtypes.input_dtype) + if dtypes.input_dtype == dtypes.result_dtype: + dtypes_strat |= st.none() kw = data.draw( - hh.kwargs(dtype=st.none() | st.just(x.dtype), copy=st.none() | st.booleans()), + hh.kwargs(dtype=dtypes_strat, copy=st.none() | st.booleans()), label="kw", ) @@ -261,27 +276,35 @@ def test_asarray_arrays(x, data): else: ph.assert_kw_dtype("asarray", dtype, out.dtype) ph.assert_shape("asarray", out.shape, x.shape) - if dtype is None or dtype == x.dtype: - ph.assert_array_elements("asarray", out, x, **kw) - else: - pass # TODO + ph.assert_array_elements("asarray", out, x, **kw) copy = kw.get("copy", None) if copy is not None: + stype = dh.get_scalar_type(x.dtype) idx = data.draw(xps.indices(x.shape, max_dims=0), label="mutating idx") - _dtype = x.dtype if dtype is None else dtype - old_value = x[idx] + old_value = stype(x[idx]) + scalar_strat = xps.from_dtype(dtypes.input_dtype).filter( + lambda n: not scalar_eq(n, old_value) + ) value = data.draw( - xps.arrays(dtype=_dtype, shape=()).filter(lambda y: y != old_value), + scalar_strat | scalar_strat.map(lambda n: xp.asarray(n, dtype=x.dtype)), label="mutating value", ) x[idx] = value note(f"mutated {x=}") + # sanity check + ph.assert_scalar_equals( + "__setitem__", stype, idx, stype(x[idx]), value, repr_name="x" + ) + new_out_value = stype(out[idx]) + f_out = f"{sh.fmt_idx('out', idx)}={new_out_value}" if copy: - assert not xp.all( - out == x - ), f"xp.all(out == x)=True, but should be False after x was mutated\n{out=}" - elif copy is False: - pass # TODO + assert scalar_eq( + new_out_value, old_value + ), f"{f_out}, but should be {old_value} even after x was mutated" + else: + assert scalar_eq( + new_out_value, value + ), f"{f_out}, but should be {value} after x was mutated" @given(hh.shapes(), hh.kwargs(dtype=st.none() | hh.shared_dtypes))