From 5a1a19fb0bf700e29e90cc0b2b86e405f4cf22a5 Mon Sep 17 00:00:00 2001 From: Matthew Date: Tue, 9 Nov 2021 09:57:59 +0000 Subject: [PATCH 01/60] Move manipulation tests to their own file --- .../test_manipulation_functions.py | 38 +++++++++++++++++++ array_api_tests/test_type_promotion.py | 30 +-------------- 2 files changed, 39 insertions(+), 29 deletions(-) create mode 100644 array_api_tests/test_manipulation_functions.py diff --git a/array_api_tests/test_manipulation_functions.py b/array_api_tests/test_manipulation_functions.py new file mode 100644 index 00000000..79395feb --- /dev/null +++ b/array_api_tests/test_manipulation_functions.py @@ -0,0 +1,38 @@ +from hypothesis import given +from hypothesis import strategies as st + +from . import _array_module as xp +from . import dtype_helpers as dh +from . import hypothesis_helpers as hh +from . import pytest_helpers as ph +from . import xps + + +@given( + shape=hh.shapes(min_dims=1), + dtypes=hh.mutually_promotable_dtypes(None, dtypes=dh.numeric_dtypes), + data=st.data(), +) +def test_concat(shape, dtypes, data): + arrays = [] + for i, dtype in enumerate(dtypes, 1): + x = data.draw(xps.arrays(dtype=dtype, shape=shape), label=f"x{i}") + arrays.append(x) + out = xp.concat(arrays) + ph.assert_dtype("concat", dtypes, out.dtype) + # TODO + + +@given( + shape=hh.shapes(), + dtypes=hh.mutually_promotable_dtypes(None), + data=st.data(), +) +def test_stack(shape, dtypes, data): + arrays = [] + for i, dtype in enumerate(dtypes, 1): + x = data.draw(xps.arrays(dtype=dtype, shape=shape), label=f"x{i}") + arrays.append(x) + out = xp.stack(arrays) + ph.assert_dtype("stack", dtypes, out.dtype) + # TODO diff --git a/array_api_tests/test_type_promotion.py b/array_api_tests/test_type_promotion.py index d304071a..22af9526 100644 --- a/array_api_tests/test_type_promotion.py +++ b/array_api_tests/test_type_promotion.py @@ -19,7 +19,7 @@ # TODO: move tests not covering elementwise funcs/ops into standalone tests -# result_type, meshgrid, concat, stack, where, tensordor, vecdot +# result_type, meshgrid, where, tensordor, vecdot @given(hh.mutually_promotable_dtypes(None)) @@ -51,34 +51,6 @@ def test_meshgrid(dtypes, data): ph.assert_dtype("meshgrid", dtypes, x.dtype, repr_name=f"out[{i}].dtype") -@given( - shape=hh.shapes(min_dims=1), - dtypes=hh.mutually_promotable_dtypes(None, dtypes=dh.numeric_dtypes), - data=st.data(), -) -def test_concat(shape, dtypes, data): - arrays = [] - for i, dtype in enumerate(dtypes, 1): - x = data.draw(xps.arrays(dtype=dtype, shape=shape), label=f"x{i}") - arrays.append(x) - out = xp.concat(arrays) - ph.assert_dtype("concat", dtypes, out.dtype) - - -@given( - shape=hh.shapes(), - dtypes=hh.mutually_promotable_dtypes(None), - data=st.data(), -) -def test_stack(shape, dtypes, data): - arrays = [] - for i, dtype in enumerate(dtypes, 1): - x = data.draw(xps.arrays(dtype=dtype, shape=shape), label=f"x{i}") - arrays.append(x) - out = xp.stack(arrays) - ph.assert_dtype("stack", dtypes, out.dtype) - - bitwise_shift_funcs = [ "bitwise_left_shift", "bitwise_right_shift", From 6dd7f3d58a73026a2efc681d257216d454972a21 Mon Sep 17 00:00:00 2001 From: Matthew Date: Tue, 9 Nov 2021 12:10:59 +0000 Subject: [PATCH 02/60] Minor `test_concat` improvements --- array_api_tests/test_manipulation_functions.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/array_api_tests/test_manipulation_functions.py b/array_api_tests/test_manipulation_functions.py index 79395feb..9a7d1d7f 100644 --- a/array_api_tests/test_manipulation_functions.py +++ b/array_api_tests/test_manipulation_functions.py @@ -1,3 +1,5 @@ +import math + from hypothesis import given from hypothesis import strategies as st @@ -11,16 +13,23 @@ @given( shape=hh.shapes(min_dims=1), dtypes=hh.mutually_promotable_dtypes(None, dtypes=dh.numeric_dtypes), + kw=hh.kwargs(axis=st.just(0) | st.none()), # TODO: test with axis >= 1 data=st.data(), ) -def test_concat(shape, dtypes, data): +def test_concat(shape, dtypes, kw, data): arrays = [] for i, dtype in enumerate(dtypes, 1): x = data.draw(xps.arrays(dtype=dtype, shape=shape), label=f"x{i}") arrays.append(x) - out = xp.concat(arrays) + out = xp.concat(arrays, **kw) ph.assert_dtype("concat", dtypes, out.dtype) - # TODO + shapes = tuple(x.shape for x in arrays) + if kw.get("axis", 0) == 0: + pass # TODO: assert expected shape + elif kw["axis"] is None: + size = sum(math.prod(s) for s in shapes) + ph.assert_result_shape("concat", shapes, out.shape, (size,), **kw) + # TODO: assert out elements match input arrays @given( From e5bb97442b759871d8ec013bf7dbb4a9e6f64552 Mon Sep 17 00:00:00 2001 From: Matthew Date: Tue, 9 Nov 2021 19:07:08 +0000 Subject: [PATCH 03/60] Smoke all manipulation methods --- .../test_manipulation_functions.py | 77 +++++++++++++++++++ 1 file changed, 77 insertions(+) diff --git a/array_api_tests/test_manipulation_functions.py b/array_api_tests/test_manipulation_functions.py index 9a7d1d7f..eea4bb00 100644 --- a/array_api_tests/test_manipulation_functions.py +++ b/array_api_tests/test_manipulation_functions.py @@ -9,6 +9,8 @@ from . import pytest_helpers as ph from . import xps +shared_shapes = st.shared(hh.shapes(), key="shape") + @given( shape=hh.shapes(min_dims=1), @@ -32,6 +34,81 @@ def test_concat(shape, dtypes, kw, data): # TODO: assert out elements match input arrays +@given( + x=xps.arrays(dtype=xps.scalar_dtypes(), shape=shared_shapes), + axis=shared_shapes.flatmap(lambda s: st.integers(-len(s), len(s))), +) +def test_expand_dims(x, axis): + xp.expand_dims(x, axis=axis) + # TODO + + +@given( + x=xps.arrays(dtype=xps.scalar_dtypes(), shape=shared_shapes), + kw=hh.kwargs( + axis=st.one_of( + st.none(), + shared_shapes.flatmap( + lambda s: st.none() + if len(s) == 0 + else st.integers(-len(s) + 1, len(s) - 1), + ), + ) + ), +) +def test_flip(x, kw): + xp.flip(x, **kw) + # TODO + + +@given( + x=xps.arrays(dtype=xps.scalar_dtypes(), shape=shared_shapes), + axes=shared_shapes.flatmap( + lambda s: st.lists( + st.integers(0, max(len(s) - 1, 0)), + min_size=len(s), + max_size=len(s), + unique=True, + ).map(tuple) + ), +) +def test_permute_dims(x, axes): + xp.permute_dims(x, axes) + # TODO + + +@given( + x=xps.arrays(dtype=xps.scalar_dtypes(), shape=shared_shapes), + shape=shared_shapes, # TODO: test more compatible shapes +) +def test_reshape(x, shape): + xp.reshape(x, shape) + # TODO + + +@given( + # TODO: axis arguments, update shift respectively + x=xps.arrays(dtype=xps.scalar_dtypes(), shape=shared_shapes), + shift=shared_shapes.flatmap(lambda s: st.integers(0, max(math.prod(s) - 1, 0))), +) +def test_roll(x, shift): + xp.roll(x, shift) + # TODO + + +@given( + x=xps.arrays(dtype=xps.scalar_dtypes(), shape=shared_shapes), + axis=shared_shapes.flatmap( + lambda s: st.just(0) + if len(s) == 0 + else st.integers(-len(s) + 1, len(s) - 1).filter(lambda i: s[i] == 1) + ), # TODO: tuple of axis i.e. axes +) +def test_squeeze(x, axis): + xp.squeeze(x, axis) + # TODO + + @given( shape=hh.shapes(), dtypes=hh.mutually_promotable_dtypes(None), From 41a0292c1d396dc876fa33d8b3db40c81ac4f3aa Mon Sep 17 00:00:00 2001 From: Matthew Date: Tue, 9 Nov 2021 19:45:17 +0000 Subject: [PATCH 04/60] Smoke statistical functions --- array_api_tests/test_statistical_functions.py | 54 +++++++++++++++++++ 1 file changed, 54 insertions(+) create mode 100644 array_api_tests/test_statistical_functions.py diff --git a/array_api_tests/test_statistical_functions.py b/array_api_tests/test_statistical_functions.py new file mode 100644 index 00000000..7455123c --- /dev/null +++ b/array_api_tests/test_statistical_functions.py @@ -0,0 +1,54 @@ +from hypothesis import given + +from . import _array_module as xp +from . import hypothesis_helpers as hh +from . import xps + + +# TODO generate kwargs +@given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes(min_side=1))) +def test_min(x): + xp.min(x) + # TODO + + +# TODO generate kwargs +@given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes(min_side=1))) +def test_max(x): + xp.max(x) + # TODO + + +# TODO generate kwargs +@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes(min_side=1))) +def test_mean(x): + xp.mean(x) + # TODO + + +# TODO generate kwargs +@given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes(min_side=1))) +def test_prod(x): + xp.prod(x) + # TODO + + +# TODO generate kwargs +@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes(min_side=1))) +def test_std(x): + xp.std(x) + # TODO + + +# TODO generate kwargs +@given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes(min_side=1))) +def test_sum(x): + xp.sum(x) + # TODO + + +# TODO generate kwargs +@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes(min_side=1))) +def test_var(x): + xp.var(x) + # TODO From 97997eecdb5415b19d1115fb79f3e8ae12b526c5 Mon Sep 17 00:00:00 2001 From: Matthew Date: Tue, 9 Nov 2021 20:40:18 +0000 Subject: [PATCH 05/60] Smoke searching functions --- array_api_tests/test_searching_functions.py | 41 +++++++++++++++++++ array_api_tests/test_statistical_functions.py | 14 +++---- array_api_tests/test_type_promotion.py | 2 +- 3 files changed, 49 insertions(+), 8 deletions(-) create mode 100644 array_api_tests/test_searching_functions.py diff --git a/array_api_tests/test_searching_functions.py b/array_api_tests/test_searching_functions.py new file mode 100644 index 00000000..c3686bb7 --- /dev/null +++ b/array_api_tests/test_searching_functions.py @@ -0,0 +1,41 @@ +from hypothesis import given +from hypothesis import strategies as st + +from . import _array_module as xp +from . import hypothesis_helpers as hh +from . import xps + + +# TODO: generate kwargs +@given(xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes(min_side=1))) +def test_argmin(x): + xp.argmin(x) + # TODO + + +# TODO: generate kwargs +@given(xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes(min_side=1))) +def test_argmax(x): + xp.argmax(x) + # TODO + + +# TODO: generate kwargs, skip if opted out +@given(xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes(min_side=1))) +def test_nonzero(x): + xp.nonzero(x) + # TODO + + +# TODO: skip if opted out +@given( + shapes=hh.mutually_broadcastable_shapes(3), + dtypes=hh.mutually_promotable_dtypes(), + data=st.data(), +) +def test_where(shapes, dtypes, data): + cond = data.draw(xps.arrays(dtype=xp.bool, shape=shapes[0]), label="condition") + x1 = data.draw(xps.arrays(dtype=dtypes[0], shape=shapes[1]), label="x1") + x2 = data.draw(xps.arrays(dtype=dtypes[1], shape=shapes[2]), label="x2") + xp.where(cond, x1, x2) + # TODO diff --git a/array_api_tests/test_statistical_functions.py b/array_api_tests/test_statistical_functions.py index 7455123c..98cade7d 100644 --- a/array_api_tests/test_statistical_functions.py +++ b/array_api_tests/test_statistical_functions.py @@ -5,49 +5,49 @@ from . import xps -# TODO generate kwargs +# TODO: generate kwargs @given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes(min_side=1))) def test_min(x): xp.min(x) # TODO -# TODO generate kwargs +# TODO: generate kwargs @given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes(min_side=1))) def test_max(x): xp.max(x) # TODO -# TODO generate kwargs +# TODO: generate kwargs @given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes(min_side=1))) def test_mean(x): xp.mean(x) # TODO -# TODO generate kwargs +# TODO: generate kwargs @given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes(min_side=1))) def test_prod(x): xp.prod(x) # TODO -# TODO generate kwargs +# TODO: generate kwargs @given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes(min_side=1))) def test_std(x): xp.std(x) # TODO -# TODO generate kwargs +# TODO: generate kwargs @given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes(min_side=1))) def test_sum(x): xp.sum(x) # TODO -# TODO generate kwargs +# TODO: generate kwargs @given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes(min_side=1))) def test_var(x): xp.var(x) diff --git a/array_api_tests/test_type_promotion.py b/array_api_tests/test_type_promotion.py index 22af9526..2fb669e6 100644 --- a/array_api_tests/test_type_promotion.py +++ b/array_api_tests/test_type_promotion.py @@ -19,7 +19,7 @@ # TODO: move tests not covering elementwise funcs/ops into standalone tests -# result_type, meshgrid, where, tensordor, vecdot +# result_type, meshgrid, tensordor, vecdot @given(hh.mutually_promotable_dtypes(None)) From 3efff6c0598d1c6e69ba955c5890b41f6708c64b Mon Sep 17 00:00:00 2001 From: Matthew Date: Wed, 10 Nov 2021 10:09:49 +0000 Subject: [PATCH 06/60] Smoke remaining functions --- array_api_tests/test_set_functions.py | 29 +++++++++++++++++++++++ array_api_tests/test_sorting.py | 19 +++++++++++++++ array_api_tests/test_utility_functions.py | 19 +++++++++++++++ 3 files changed, 67 insertions(+) create mode 100644 array_api_tests/test_set_functions.py create mode 100644 array_api_tests/test_sorting.py create mode 100644 array_api_tests/test_utility_functions.py diff --git a/array_api_tests/test_set_functions.py b/array_api_tests/test_set_functions.py new file mode 100644 index 00000000..856a7282 --- /dev/null +++ b/array_api_tests/test_set_functions.py @@ -0,0 +1,29 @@ +from hypothesis import given + +from . import _array_module as xp +from . import hypothesis_helpers as hh +from . import xps + + +@given(xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes())) +def test_unique_all(x): + xp.unique_all(x) + # TODO + + +@given(xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes())) +def test_unique_counts(x): + xp.unique_counts(x) + # TODO + + +@given(xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes())) +def test_unique_inverse(x): + xp.unique_inverse(x) + # TODO + + +@given(xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes())) +def test_unique_values(x): + xp.unique_values(x) + # TODO diff --git a/array_api_tests/test_sorting.py b/array_api_tests/test_sorting.py new file mode 100644 index 00000000..58179b3c --- /dev/null +++ b/array_api_tests/test_sorting.py @@ -0,0 +1,19 @@ +from hypothesis import given + +from . import _array_module as xp +from . import hypothesis_helpers as hh +from . import xps + + +# TODO: generate kwargs +@given(xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes())) +def test_argsort(x): + xp.argsort(x) + # TODO + + +# TODO: generate 0d arrays, generate kwargs +@given(xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes(min_dims=1))) +def test_sort(x): + xp.sort(x) + # TODO diff --git a/array_api_tests/test_utility_functions.py b/array_api_tests/test_utility_functions.py new file mode 100644 index 00000000..140aa85f --- /dev/null +++ b/array_api_tests/test_utility_functions.py @@ -0,0 +1,19 @@ +from hypothesis import given + +from . import _array_module as xp +from . import hypothesis_helpers as hh +from . import xps + + +# TODO: generate kwargs +@given(xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes())) +def test_any(x): + xp.any(x) + # TODO + + +# TODO: generate kwargs +@given(xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes())) +def test_all(x): + xp.all(x) + # TODO From 08b23917d97d4cf6d16656ae06e8f8bc3242e938 Mon Sep 17 00:00:00 2001 From: Matthew Date: Wed, 10 Nov 2021 10:56:50 +0000 Subject: [PATCH 07/60] Test `expand_dims()` --- array_api_tests/pytest_helpers.py | 8 ++++---- array_api_tests/test_manipulation_functions.py | 11 +++++++++-- 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/array_api_tests/pytest_helpers.py b/array_api_tests/pytest_helpers.py index 9424ba35..b138af3e 100644 --- a/array_api_tests/pytest_helpers.py +++ b/array_api_tests/pytest_helpers.py @@ -67,12 +67,14 @@ def fmt_kw(kw: Dict[str, Any]) -> str: def assert_dtype( func_name: str, - in_dtypes: Tuple[DataType, ...], + in_dtypes: Union[DataType, Tuple[DataType, ...]], out_dtype: DataType, expected: Optional[DataType] = None, *, repr_name: str = "out.dtype", ): + if not isinstance(in_dtypes, tuple): + in_dtypes = (in_dtypes,) f_in_dtypes = dh.fmt_types(in_dtypes) f_out_dtype = dh.dtype_to_name[out_dtype] if expected is None: @@ -149,9 +151,7 @@ def assert_result_shape( f_sig = f" {f_in_shapes} " if kw: f_sig += f", {fmt_kw(kw)}" - msg = ( - f"{repr_name}={out_shape}, but should be {expected} [{func_name}({f_sig})]" - ) + msg = f"{repr_name}={out_shape}, but should be {expected} [{func_name}({f_sig})]" assert out_shape == expected, msg diff --git a/array_api_tests/test_manipulation_functions.py b/array_api_tests/test_manipulation_functions.py index eea4bb00..620e8968 100644 --- a/array_api_tests/test_manipulation_functions.py +++ b/array_api_tests/test_manipulation_functions.py @@ -39,8 +39,15 @@ def test_concat(shape, dtypes, kw, data): axis=shared_shapes.flatmap(lambda s: st.integers(-len(s), len(s))), ) def test_expand_dims(x, axis): - xp.expand_dims(x, axis=axis) - # TODO + out = xp.expand_dims(x, axis=axis) + + ph.assert_dtype("expand_dims", x.dtype, out.dtype) + + shape = [side for side in x.shape] + index = axis if axis >= 0 else x.ndim + axis + 1 + shape.insert(index, 1) + shape = tuple(shape) + ph.assert_result_shape("expand_dims", (x.shape,), out.shape, shape) @given( From 40a74bff0e718eeef3a14e897071dcadffa9a17b Mon Sep 17 00:00:00 2001 From: Matthew Date: Wed, 10 Nov 2021 12:03:20 +0000 Subject: [PATCH 08/60] Test `flip()` when `axis=None` --- .../test_manipulation_functions.py | 37 +++++++++++++------ 1 file changed, 25 insertions(+), 12 deletions(-) diff --git a/array_api_tests/test_manipulation_functions.py b/array_api_tests/test_manipulation_functions.py index 620e8968..d5ec92df 100644 --- a/array_api_tests/test_manipulation_functions.py +++ b/array_api_tests/test_manipulation_functions.py @@ -4,6 +4,7 @@ from hypothesis import strategies as st from . import _array_module as xp +from . import array_helpers as ah from . import dtype_helpers as dh from . import hypothesis_helpers as hh from . import pytest_helpers as ph @@ -50,22 +51,34 @@ def test_expand_dims(x, axis): ph.assert_result_shape("expand_dims", (x.shape,), out.shape, shape) +@st.composite +def flip_axis(draw, shape): + if len(shape) == 0 or draw(st.booleans()): + return None + else: + ndim = len(shape) + return draw(st.integers(-ndim, ndim - 1) | xps.valid_tuple_axes(ndim)) + + @given( x=xps.arrays(dtype=xps.scalar_dtypes(), shape=shared_shapes), - kw=hh.kwargs( - axis=st.one_of( - st.none(), - shared_shapes.flatmap( - lambda s: st.none() - if len(s) == 0 - else st.integers(-len(s) + 1, len(s) - 1), - ), - ) - ), + kw=hh.kwargs(axis=shared_shapes.flatmap(flip_axis)), ) def test_flip(x, kw): - xp.flip(x, **kw) - # TODO + out = xp.flip(x, **kw) + + ph.assert_dtype("expand_dims", x.dtype, out.dtype) + + # TODO: test all axis scenarios + if kw.get("axis", None) is None: + indices = list(ah.ndindex(x.shape)) + reverse_indices = indices[::-1] + for x_idx, out_idx in zip(indices, reverse_indices): + msg = f"out[{out_idx}]={out[out_idx]}, should be x[{x_idx}]={x[x_idx]}" + if dh.is_float_dtype(x.dtype) and xp.isnan(x[x_idx]): + assert xp.isnan(out[out_idx]), msg + else: + assert out[out_idx] == x[x_idx], msg @given( From d05a612575bb1669b6355e0faad4a41164bcd61f Mon Sep 17 00:00:00 2001 From: Matthew Date: Wed, 10 Nov 2021 15:41:13 +0000 Subject: [PATCH 09/60] Test `permute_dims()` --- .../test_manipulation_functions.py | 49 +++++++++++++------ 1 file changed, 33 insertions(+), 16 deletions(-) diff --git a/array_api_tests/test_manipulation_functions.py b/array_api_tests/test_manipulation_functions.py index d5ec92df..08288787 100644 --- a/array_api_tests/test_manipulation_functions.py +++ b/array_api_tests/test_manipulation_functions.py @@ -9,8 +9,16 @@ from . import hypothesis_helpers as hh from . import pytest_helpers as ph from . import xps +from .typing import Shape -shared_shapes = st.shared(hh.shapes(), key="shape") + +def shared_shapes(*args, **kwargs) -> st.SearchStrategy[Shape]: + key = "shape" + if args: + key += " " + " ".join(args) + if kwargs: + key += " " + ph.fmt_kw(kwargs) + return st.shared(hh.shapes(*args, **kwargs), key="shape") @given( @@ -36,8 +44,8 @@ def test_concat(shape, dtypes, kw, data): @given( - x=xps.arrays(dtype=xps.scalar_dtypes(), shape=shared_shapes), - axis=shared_shapes.flatmap(lambda s: st.integers(-len(s), len(s))), + x=xps.arrays(dtype=xps.scalar_dtypes(), shape=shared_shapes()), + axis=shared_shapes().flatmap(lambda s: st.integers(-len(s), len(s))), ) def test_expand_dims(x, axis): out = xp.expand_dims(x, axis=axis) @@ -61,8 +69,8 @@ def flip_axis(draw, shape): @given( - x=xps.arrays(dtype=xps.scalar_dtypes(), shape=shared_shapes), - kw=hh.kwargs(axis=shared_shapes.flatmap(flip_axis)), + x=xps.arrays(dtype=xps.scalar_dtypes(), shape=shared_shapes()), + kw=hh.kwargs(axis=shared_shapes().flatmap(flip_axis)), ) def test_flip(x, kw): out = xp.flip(x, **kw) @@ -82,10 +90,10 @@ def test_flip(x, kw): @given( - x=xps.arrays(dtype=xps.scalar_dtypes(), shape=shared_shapes), - axes=shared_shapes.flatmap( + x=xps.arrays(dtype=xps.scalar_dtypes(), shape=shared_shapes(min_dims=1)), + axes=shared_shapes(min_dims=1).flatmap( lambda s: st.lists( - st.integers(0, max(len(s) - 1, 0)), + st.integers(0, len(s) - 1), min_size=len(s), max_size=len(s), unique=True, @@ -93,13 +101,22 @@ def test_flip(x, kw): ), ) def test_permute_dims(x, axes): - xp.permute_dims(x, axes) - # TODO + out = xp.permute_dims(x, axes) + + ph.assert_dtype("permute_dims", x.dtype, out.dtype) + + shape = [None for _ in range(len(axes))] + for i, dim in enumerate(axes): + side = x.shape[dim] + shape[i] = side + assert all(isinstance(side, int) for side in shape) # sanity check + shape = tuple(shape) + ph.assert_result_shape("permute_dims", (x.shape,), out.shape, shape, axes=axes) @given( - x=xps.arrays(dtype=xps.scalar_dtypes(), shape=shared_shapes), - shape=shared_shapes, # TODO: test more compatible shapes + x=xps.arrays(dtype=xps.scalar_dtypes(), shape=shared_shapes()), + shape=shared_shapes(), # TODO: test more compatible shapes ) def test_reshape(x, shape): xp.reshape(x, shape) @@ -108,8 +125,8 @@ def test_reshape(x, shape): @given( # TODO: axis arguments, update shift respectively - x=xps.arrays(dtype=xps.scalar_dtypes(), shape=shared_shapes), - shift=shared_shapes.flatmap(lambda s: st.integers(0, max(math.prod(s) - 1, 0))), + x=xps.arrays(dtype=xps.scalar_dtypes(), shape=shared_shapes()), + shift=shared_shapes().flatmap(lambda s: st.integers(0, max(math.prod(s) - 1, 0))), ) def test_roll(x, shift): xp.roll(x, shift) @@ -117,8 +134,8 @@ def test_roll(x, shift): @given( - x=xps.arrays(dtype=xps.scalar_dtypes(), shape=shared_shapes), - axis=shared_shapes.flatmap( + x=xps.arrays(dtype=xps.scalar_dtypes(), shape=shared_shapes()), + axis=shared_shapes().flatmap( lambda s: st.just(0) if len(s) == 0 else st.integers(-len(s) + 1, len(s) - 1).filter(lambda i: s[i] == 1) From a9f9237bfbc10263add9cca3fe957e5a4558b0e8 Mon Sep 17 00:00:00 2001 From: Matthew Date: Wed, 10 Nov 2021 18:16:22 +0000 Subject: [PATCH 10/60] Test `reshape()` --- .../test_manipulation_functions.py | 46 +++++++++++++++++-- 1 file changed, 41 insertions(+), 5 deletions(-) diff --git a/array_api_tests/test_manipulation_functions.py b/array_api_tests/test_manipulation_functions.py index 08288787..4f7f3426 100644 --- a/array_api_tests/test_manipulation_functions.py +++ b/array_api_tests/test_manipulation_functions.py @@ -1,6 +1,6 @@ import math -from hypothesis import given +from hypothesis import assume, given from hypothesis import strategies as st from . import _array_module as xp @@ -113,15 +113,51 @@ def test_permute_dims(x, axes): shape = tuple(shape) ph.assert_result_shape("permute_dims", (x.shape,), out.shape, shape, axes=axes) + # TODO: test elements + + +MAX_RESHAPE_SIDE = hh.MAX_ARRAY_SIZE // 64 +reshape_x_shapes = st.shared( + hh.shapes().filter(lambda s: math.prod(s) <= MAX_RESHAPE_SIDE), + key="reshape x shape", +) + + +@st.composite +def reshape_shapes(draw, shape): + size = 1 if len(shape) == 0 else math.prod(shape) + rshape = draw(st.lists(st.integers(0)).filter(lambda s: math.prod(s) == size)) + assume(all(side <= MAX_RESHAPE_SIDE for side in rshape)) + if len(rshape) != 0 and size > 0 and draw(st.booleans()): + index = draw(st.integers(0, len(rshape) - 1)) + rshape[index] = -1 + return tuple(rshape) + @given( - x=xps.arrays(dtype=xps.scalar_dtypes(), shape=shared_shapes()), - shape=shared_shapes(), # TODO: test more compatible shapes + x=xps.arrays(dtype=xps.scalar_dtypes(), shape=reshape_x_shapes), + shape=reshape_x_shapes.flatmap(reshape_shapes), ) def test_reshape(x, shape): - xp.reshape(x, shape) - # TODO + assume(math.prod(shape) == math.prod(x.shape)) + + out = xp.reshape(x, shape) + + ph.assert_dtype("reshape", x.dtype, out.dtype) + + _shape = shape + if any(side == -1 for side in shape): + size = math.prod(x.shape) + rsize = math.prod(shape) * -1 + _shape[shape.index(-1)] = size / rsize + ph.assert_result_shape("reshape", (x.shape,), out.shape, _shape, shape=shape) + for x_idx, out_idx in zip(ah.ndindex(x.shape), ah.ndindex(out.shape)): + msg = f"out[{out_idx}]={out[out_idx]}, should be x[{x_idx}]={x[x_idx]}" + if dh.is_float_dtype(x.dtype) and xp.isnan(x[x_idx]): + assert xp.isnan(out[out_idx]), msg + else: + assert out[out_idx] == x[x_idx], msg @given( # TODO: axis arguments, update shift respectively From 3ac57efff17d322a871a667eef49dab35f44b469 Mon Sep 17 00:00:00 2001 From: Matthew Date: Wed, 10 Nov 2021 19:44:35 +0000 Subject: [PATCH 11/60] Test `roll()` --- .../test_manipulation_functions.py | 40 +++++++++++++++---- 1 file changed, 32 insertions(+), 8 deletions(-) diff --git a/array_api_tests/test_manipulation_functions.py b/array_api_tests/test_manipulation_functions.py index 4f7f3426..b8adec65 100644 --- a/array_api_tests/test_manipulation_functions.py +++ b/array_api_tests/test_manipulation_functions.py @@ -1,4 +1,5 @@ import math +from collections import deque from hypothesis import assume, given from hypothesis import strategies as st @@ -159,14 +160,37 @@ def test_reshape(x, shape): else: assert out[out_idx] == x[x_idx], msg -@given( - # TODO: axis arguments, update shift respectively - x=xps.arrays(dtype=xps.scalar_dtypes(), shape=shared_shapes()), - shift=shared_shapes().flatmap(lambda s: st.integers(0, max(math.prod(s) - 1, 0))), -) -def test_roll(x, shift): - xp.roll(x, shift) - # TODO + +@given(xps.arrays(dtype=xps.scalar_dtypes(), shape=shared_shapes()), st.data()) +def test_roll(x, data): + shift = data.draw( + st.integers() | st.lists(st.integers(), max_size=x.ndim).map(tuple), + label="shift", + ) + axis_strats = [st.none()] + if x.shape != (): + axis_strats.append(st.integers(-x.ndim, x.ndim - 1)) + if isinstance(shift, int): + axis_strats.append(xps.valid_tuple_axes(x.ndim)) + kw = data.draw(hh.kwargs(axis=st.one_of(axis_strats)), label="kw") + + out = xp.roll(x, shift, **kw) + + ph.assert_dtype("roll", x.dtype, out.dtype) + + ph.assert_result_shape("roll", (x.shape,), out.shape) + + # TODO: test all shift/axis scenarios + if isinstance(shift, int) and kw.get("axis", None) is None: + indices = list(ah.ndindex(x.shape)) + shifted_indices = deque(indices) + shifted_indices.rotate(shift) + for x_idx, out_idx in zip(indices, shifted_indices): + msg = f"out[{out_idx}]={out[out_idx]}, should be x[{x_idx}]={x[x_idx]}" + if dh.is_float_dtype(x.dtype) and xp.isnan(x[x_idx]): + assert xp.isnan(out[out_idx]), msg + else: + assert out[out_idx] == x[x_idx], msg @given( From 258b75e5aa6f25c2aa5eb1de2100649813c176d1 Mon Sep 17 00:00:00 2001 From: Matthew Date: Wed, 10 Nov 2021 20:30:09 +0000 Subject: [PATCH 12/60] Test `squeeze()` --- .../test_manipulation_functions.py | 62 ++++++++++++++----- 1 file changed, 48 insertions(+), 14 deletions(-) diff --git a/array_api_tests/test_manipulation_functions.py b/array_api_tests/test_manipulation_functions.py index b8adec65..e17682c0 100644 --- a/array_api_tests/test_manipulation_functions.py +++ b/array_api_tests/test_manipulation_functions.py @@ -46,7 +46,7 @@ def test_concat(shape, dtypes, kw, data): @given( x=xps.arrays(dtype=xps.scalar_dtypes(), shape=shared_shapes()), - axis=shared_shapes().flatmap(lambda s: st.integers(-len(s), len(s))), + axis=shared_shapes().flatmap(lambda s: st.integers(-len(s) - 1, len(s))), ) def test_expand_dims(x, axis): out = xp.expand_dims(x, axis=axis) @@ -59,6 +59,53 @@ def test_expand_dims(x, axis): shape = tuple(shape) ph.assert_result_shape("expand_dims", (x.shape,), out.shape, shape) + for x_idx, out_idx in zip(ah.ndindex(x.shape), ah.ndindex(out.shape)): + msg = f"out[{out_idx}]={out[out_idx]}, should be x[{x_idx}]={x[x_idx]}" + if dh.is_float_dtype(x.dtype) and xp.isnan(x[x_idx]): + assert xp.isnan(out[out_idx]), msg + else: + assert out[out_idx] == x[x_idx], msg + + +@given( + x=xps.arrays( + dtype=xps.scalar_dtypes(), shape=hh.shapes(min_side=1).filter(lambda s: 1 in s) + ), + data=st.data(), +) +def test_squeeze(x, data): + # axis=shared_shapes(min_side=1).flatmap(lambda s: nd_axes(len(s))), + squeezable_axes = st.sampled_from( + [i for i, side in enumerate(x.shape) if side == 1] + ) + axis = data.draw( + # TODO: generate valid negative axis + squeezable_axes | st.lists(squeezable_axes, unique=True).map(tuple), + label="axis", + ) + + out = xp.squeeze(x, axis) + + ph.assert_dtype("squeeze", x.dtype, out.dtype) + + if isinstance(axis, int): + axes = (axis,) + else: + axes = axis + shape = [] + for i, side in enumerate(x.shape): + if i not in axes: + shape.append(side) + shape = tuple(shape) + ph.assert_result_shape("squeeze", (x.shape,), out.shape, shape, axis=axis) + + for x_idx, out_idx in zip(ah.ndindex(x.shape), ah.ndindex(out.shape)): + msg = f"out[{out_idx}]={out[out_idx]}, should be x[{x_idx}]={x[x_idx]}" + if dh.is_float_dtype(x.dtype) and xp.isnan(x[x_idx]): + assert xp.isnan(out[out_idx]), msg + else: + assert out[out_idx] == x[x_idx], msg + @st.composite def flip_axis(draw, shape): @@ -193,19 +240,6 @@ def test_roll(x, data): assert out[out_idx] == x[x_idx], msg -@given( - x=xps.arrays(dtype=xps.scalar_dtypes(), shape=shared_shapes()), - axis=shared_shapes().flatmap( - lambda s: st.just(0) - if len(s) == 0 - else st.integers(-len(s) + 1, len(s) - 1).filter(lambda i: s[i] == 1) - ), # TODO: tuple of axis i.e. axes -) -def test_squeeze(x, axis): - xp.squeeze(x, axis) - # TODO - - @given( shape=hh.shapes(), dtypes=hh.mutually_promotable_dtypes(None), From 1bdf6e476ee54f9e402d74959dc5ce0363a706b8 Mon Sep 17 00:00:00 2001 From: Matthew Date: Wed, 10 Nov 2021 20:42:01 +0000 Subject: [PATCH 13/60] Refactor assertions using ndindex --- .../test_manipulation_functions.py | 59 +++++++++---------- 1 file changed, 27 insertions(+), 32 deletions(-) diff --git a/array_api_tests/test_manipulation_functions.py b/array_api_tests/test_manipulation_functions.py index e17682c0..18b0cf2e 100644 --- a/array_api_tests/test_manipulation_functions.py +++ b/array_api_tests/test_manipulation_functions.py @@ -1,5 +1,6 @@ import math from collections import deque +from typing import Iterable, Union from hypothesis import assume, given from hypothesis import strategies as st @@ -10,7 +11,7 @@ from . import hypothesis_helpers as hh from . import pytest_helpers as ph from . import xps -from .typing import Shape +from .typing import Array, Shape def shared_shapes(*args, **kwargs) -> st.SearchStrategy[Shape]: @@ -22,6 +23,21 @@ def shared_shapes(*args, **kwargs) -> st.SearchStrategy[Shape]: return st.shared(hh.shapes(*args, **kwargs), key="shape") +def assert_array_ndindex( + func_name: str, + x: Array, + x_indices: Iterable[Union[int, Shape]], + out: Array, + out_indices: Iterable[Union[int, Shape]], +): + for x_idx, out_idx in zip(x_indices, out_indices): + msg = f"out[{out_idx}]={out[out_idx]}, should be x[{x_idx}]={x[x_idx]} [{func_name}()]" + if dh.is_float_dtype(x.dtype) and xp.isnan(x[x_idx]): + assert xp.isnan(out[out_idx]), msg + else: + assert out[out_idx] == x[x_idx], msg + + @given( shape=hh.shapes(min_dims=1), dtypes=hh.mutually_promotable_dtypes(None, dtypes=dh.numeric_dtypes), @@ -59,12 +75,9 @@ def test_expand_dims(x, axis): shape = tuple(shape) ph.assert_result_shape("expand_dims", (x.shape,), out.shape, shape) - for x_idx, out_idx in zip(ah.ndindex(x.shape), ah.ndindex(out.shape)): - msg = f"out[{out_idx}]={out[out_idx]}, should be x[{x_idx}]={x[x_idx]}" - if dh.is_float_dtype(x.dtype) and xp.isnan(x[x_idx]): - assert xp.isnan(out[out_idx]), msg - else: - assert out[out_idx] == x[x_idx], msg + assert_array_ndindex( + "expand_dims", x, ah.ndindex(x.shape), out, ah.ndindex(out.shape) + ) @given( @@ -99,12 +112,7 @@ def test_squeeze(x, data): shape = tuple(shape) ph.assert_result_shape("squeeze", (x.shape,), out.shape, shape, axis=axis) - for x_idx, out_idx in zip(ah.ndindex(x.shape), ah.ndindex(out.shape)): - msg = f"out[{out_idx}]={out[out_idx]}, should be x[{x_idx}]={x[x_idx]}" - if dh.is_float_dtype(x.dtype) and xp.isnan(x[x_idx]): - assert xp.isnan(out[out_idx]), msg - else: - assert out[out_idx] == x[x_idx], msg + assert_array_ndindex("squeeze", x, ah.ndindex(x.shape), out, ah.ndindex(out.shape)) @st.composite @@ -123,18 +131,13 @@ def flip_axis(draw, shape): def test_flip(x, kw): out = xp.flip(x, **kw) - ph.assert_dtype("expand_dims", x.dtype, out.dtype) + ph.assert_dtype("flip", x.dtype, out.dtype) # TODO: test all axis scenarios if kw.get("axis", None) is None: indices = list(ah.ndindex(x.shape)) reverse_indices = indices[::-1] - for x_idx, out_idx in zip(indices, reverse_indices): - msg = f"out[{out_idx}]={out[out_idx]}, should be x[{x_idx}]={x[x_idx]}" - if dh.is_float_dtype(x.dtype) and xp.isnan(x[x_idx]): - assert xp.isnan(out[out_idx]), msg - else: - assert out[out_idx] == x[x_idx], msg + assert_array_ndindex("flip", x, indices, out, reverse_indices) @given( @@ -200,12 +203,7 @@ def test_reshape(x, shape): _shape[shape.index(-1)] = size / rsize ph.assert_result_shape("reshape", (x.shape,), out.shape, _shape, shape=shape) - for x_idx, out_idx in zip(ah.ndindex(x.shape), ah.ndindex(out.shape)): - msg = f"out[{out_idx}]={out[out_idx]}, should be x[{x_idx}]={x[x_idx]}" - if dh.is_float_dtype(x.dtype) and xp.isnan(x[x_idx]): - assert xp.isnan(out[out_idx]), msg - else: - assert out[out_idx] == x[x_idx], msg + assert_array_ndindex("reshape", x, ah.ndindex(x.shape), out, ah.ndindex(out.shape)) @given(xps.arrays(dtype=xps.scalar_dtypes(), shape=shared_shapes()), st.data()) @@ -232,12 +230,9 @@ def test_roll(x, data): indices = list(ah.ndindex(x.shape)) shifted_indices = deque(indices) shifted_indices.rotate(shift) - for x_idx, out_idx in zip(indices, shifted_indices): - msg = f"out[{out_idx}]={out[out_idx]}, should be x[{x_idx}]={x[x_idx]}" - if dh.is_float_dtype(x.dtype) and xp.isnan(x[x_idx]): - assert xp.isnan(out[out_idx]), msg - else: - assert out[out_idx] == x[x_idx], msg + print(f"{indices=}") + print(f"{shifted_indices=}") + assert_array_ndindex("roll", x, indices, out, shifted_indices) @given( From f1cc3ea02aeec032525b943ee1629c8af7cfe277 Mon Sep 17 00:00:00 2001 From: Matthew Date: Thu, 11 Nov 2021 09:34:20 +0000 Subject: [PATCH 14/60] Fix `test_roll` --- array_api_tests/test_manipulation_functions.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/array_api_tests/test_manipulation_functions.py b/array_api_tests/test_manipulation_functions.py index 18b0cf2e..1e91fafe 100644 --- a/array_api_tests/test_manipulation_functions.py +++ b/array_api_tests/test_manipulation_functions.py @@ -30,8 +30,10 @@ def assert_array_ndindex( out: Array, out_indices: Iterable[Union[int, Shape]], ): + msg_suffix = f" [{func_name}()]\n {x=}\n{out=}" for x_idx, out_idx in zip(x_indices, out_indices): - msg = f"out[{out_idx}]={out[out_idx]}, should be x[{x_idx}]={x[x_idx]} [{func_name}()]" + msg = f"out[{out_idx}]={out[out_idx]}, should be x[{x_idx}]={x[x_idx]}" + msg += msg_suffix if dh.is_float_dtype(x.dtype) and xp.isnan(x[x_idx]): assert xp.isnan(out[out_idx]), msg else: @@ -229,9 +231,7 @@ def test_roll(x, data): if isinstance(shift, int) and kw.get("axis", None) is None: indices = list(ah.ndindex(x.shape)) shifted_indices = deque(indices) - shifted_indices.rotate(shift) - print(f"{indices=}") - print(f"{shifted_indices=}") + shifted_indices.rotate(-shift) assert_array_ndindex("roll", x, indices, out, shifted_indices) From 36679c43979a10b68175f1fbfc5b19da9b35bba1 Mon Sep 17 00:00:00 2001 From: Matthew Date: Thu, 11 Nov 2021 11:27:04 +0000 Subject: [PATCH 15/60] Test `concat()` --- .../test_manipulation_functions.py | 63 +++++++++++++++---- 1 file changed, 51 insertions(+), 12 deletions(-) diff --git a/array_api_tests/test_manipulation_functions.py b/array_api_tests/test_manipulation_functions.py index 1e91fafe..ed0267b7 100644 --- a/array_api_tests/test_manipulation_functions.py +++ b/array_api_tests/test_manipulation_functions.py @@ -13,6 +13,9 @@ from . import xps from .typing import Array, Shape +MAX_SIDE = hh.MAX_ARRAY_SIZE // 64 +MAX_DIMS = min(hh.MAX_ARRAY_SIZE // MAX_SIDE, 32) # NumPy only supports up to 32 dims + def shared_shapes(*args, **kwargs) -> st.SearchStrategy[Shape]: key = "shape" @@ -40,26 +43,63 @@ def assert_array_ndindex( assert out[out_idx] == x[x_idx], msg +@st.composite +def concat_shapes(draw, shape, axis): + shape = list(shape) + shape[axis] = draw(st.integers(1, MAX_SIDE)) + return tuple(shape) + + @given( - shape=hh.shapes(min_dims=1), dtypes=hh.mutually_promotable_dtypes(None, dtypes=dh.numeric_dtypes), - kw=hh.kwargs(axis=st.just(0) | st.none()), # TODO: test with axis >= 1 + kw=hh.kwargs(axis=st.none() | st.integers(-MAX_DIMS, MAX_DIMS - 1)), data=st.data(), ) -def test_concat(shape, dtypes, kw, data): +def test_concat(dtypes, kw, data): + axis = kw.get("axis", 0) + if axis is None: + shape_strat = hh.shapes() + else: + _axis = axis if axis >= 0 else abs(axis) - 1 + shape_strat = shared_shapes(min_dims=_axis + 1).flatmap( + lambda s: concat_shapes(s, axis) + ) arrays = [] for i, dtype in enumerate(dtypes, 1): - x = data.draw(xps.arrays(dtype=dtype, shape=shape), label=f"x{i}") + x = data.draw(xps.arrays(dtype=dtype, shape=shape_strat), label=f"x{i}") arrays.append(x) + out = xp.concat(arrays, **kw) + ph.assert_dtype("concat", dtypes, out.dtype) + shapes = tuple(x.shape for x in arrays) - if kw.get("axis", 0) == 0: - pass # TODO: assert expected shape - elif kw["axis"] is None: + axis = kw.get("axis", 0) + if axis is None: size = sum(math.prod(s) for s in shapes) - ph.assert_result_shape("concat", shapes, out.shape, (size,), **kw) - # TODO: assert out elements match input arrays + shape = (size,) + else: + shape = list(shapes[0]) + for other_shape in shapes[1:]: + shape[axis] += other_shape[axis] + shape = tuple(shape) + ph.assert_result_shape("concat", shapes, out.shape, shape, **kw) + + # TODO: adjust indices with nonzero axis + if axis is None or axis == 0: + out_indices = ah.ndindex(out.shape) + for i, x in enumerate(arrays, 1): + msg_suffix = f" [concat({ph.fmt_kw(kw)})]\nx{i}={x!r}\n{out=}" + for x_idx in ah.ndindex(x.shape): + out_idx = next(out_indices) + msg = ( + f"out[{out_idx}]={out[out_idx]}, should be x{i}[{x_idx}]={x[x_idx]}" + ) + msg += msg_suffix + if dh.is_float_dtype(x.dtype) and xp.isnan(x[x_idx]): + assert xp.isnan(out[out_idx]), msg + else: + assert out[out_idx] == x[x_idx], msg @given( @@ -169,9 +209,8 @@ def test_permute_dims(x, axes): # TODO: test elements -MAX_RESHAPE_SIDE = hh.MAX_ARRAY_SIZE // 64 reshape_x_shapes = st.shared( - hh.shapes().filter(lambda s: math.prod(s) <= MAX_RESHAPE_SIDE), + hh.shapes().filter(lambda s: math.prod(s) <= MAX_SIDE), key="reshape x shape", ) @@ -180,7 +219,7 @@ def test_permute_dims(x, axes): def reshape_shapes(draw, shape): size = 1 if len(shape) == 0 else math.prod(shape) rshape = draw(st.lists(st.integers(0)).filter(lambda s: math.prod(s) == size)) - assume(all(side <= MAX_RESHAPE_SIDE for side in rshape)) + assume(all(side <= MAX_SIDE for side in rshape)) if len(rshape) != 0 and size > 0 and draw(st.booleans()): index = draw(st.integers(0, len(rshape) - 1)) rshape[index] = -1 From a8fb70d51c22f530d30a134e2bf0b5273d07db35 Mon Sep 17 00:00:00 2001 From: Matthew Date: Thu, 11 Nov 2021 11:54:05 +0000 Subject: [PATCH 16/60] Test `stack()` --- .../test_manipulation_functions.py | 39 +++++++++++++++++-- 1 file changed, 35 insertions(+), 4 deletions(-) diff --git a/array_api_tests/test_manipulation_functions.py b/array_api_tests/test_manipulation_functions.py index ed0267b7..6ccadf43 100644 --- a/array_api_tests/test_manipulation_functions.py +++ b/array_api_tests/test_manipulation_functions.py @@ -275,15 +275,46 @@ def test_roll(x, data): @given( - shape=hh.shapes(), + shape=shared_shapes(min_dims=1), dtypes=hh.mutually_promotable_dtypes(None), + kw=hh.kwargs( + axis=shared_shapes(min_dims=1).flatmap( + lambda s: st.integers(-len(s), len(s) - 1) + ) + ), data=st.data(), ) -def test_stack(shape, dtypes, data): +def test_stack(shape, dtypes, kw, data): arrays = [] for i, dtype in enumerate(dtypes, 1): x = data.draw(xps.arrays(dtype=dtype, shape=shape), label=f"x{i}") arrays.append(x) - out = xp.stack(arrays) + + out = xp.stack(arrays, **kw) + ph.assert_dtype("stack", dtypes, out.dtype) - # TODO + + axis = kw.get("axis", 0) + _axis = axis if axis >= 0 else len(shape) + axis + 1 + _shape = list(shape) + _shape.insert(_axis, len(arrays)) + _shape = tuple(_shape) + ph.assert_result_shape( + "stack", tuple(x.shape for x in arrays), out.shape, _shape, **kw + ) + + # TODO: adjust indices with nonzero axis + if axis == 0: + out_indices = ah.ndindex(out.shape) + for i, x in enumerate(arrays, 1): + msg_suffix = f" [stack({ph.fmt_kw(kw)})]\nx{i}={x!r}\n{out=}" + for x_idx in ah.ndindex(x.shape): + out_idx = next(out_indices) + msg = ( + f"out[{out_idx}]={out[out_idx]}, should be x{i}[{x_idx}]={x[x_idx]}" + ) + msg += msg_suffix + if dh.is_float_dtype(x.dtype) and xp.isnan(x[x_idx]): + assert xp.isnan(out[out_idx]), msg + else: + assert out[out_idx] == x[x_idx], msg From cf7688844cc2267b35bff5eb9f0172f963a40177 Mon Sep 17 00:00:00 2001 From: Matthew Date: Thu, 11 Nov 2021 12:09:13 +0000 Subject: [PATCH 17/60] Manipulation tests clean up --- .../test_manipulation_functions.py | 43 ++++++++----------- 1 file changed, 18 insertions(+), 25 deletions(-) diff --git a/array_api_tests/test_manipulation_functions.py b/array_api_tests/test_manipulation_functions.py index 6ccadf43..408a587a 100644 --- a/array_api_tests/test_manipulation_functions.py +++ b/array_api_tests/test_manipulation_functions.py @@ -129,12 +129,11 @@ def test_expand_dims(x, axis): data=st.data(), ) def test_squeeze(x, data): - # axis=shared_shapes(min_side=1).flatmap(lambda s: nd_axes(len(s))), + # TODO: generate valid negative axis (which keep uniqueness) squeezable_axes = st.sampled_from( [i for i, side in enumerate(x.shape) if side == 1] ) axis = data.draw( - # TODO: generate valid negative axis squeezable_axes | st.lists(squeezable_axes, unique=True).map(tuple), label="axis", ) @@ -157,20 +156,19 @@ def test_squeeze(x, data): assert_array_ndindex("squeeze", x, ah.ndindex(x.shape), out, ah.ndindex(out.shape)) -@st.composite -def flip_axis(draw, shape): - if len(shape) == 0 or draw(st.booleans()): - return None - else: - ndim = len(shape) - return draw(st.integers(-ndim, ndim - 1) | xps.valid_tuple_axes(ndim)) - - @given( - x=xps.arrays(dtype=xps.scalar_dtypes(), shape=shared_shapes()), - kw=hh.kwargs(axis=shared_shapes().flatmap(flip_axis)), + x=xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes()), + data=st.data(), ) -def test_flip(x, kw): +def test_flip(x, data): + if x.ndim == 0: + axis_strat = st.none() + else: + axis_strat = ( + st.none() | st.integers(-x.ndim, x.ndim - 1) | xps.valid_tuple_axes(x.ndim) + ) + kw = data.draw(hh.kwargs(axis=axis_strat), label="kw") + out = xp.flip(x, **kw) ph.assert_dtype("flip", x.dtype, out.dtype) @@ -209,12 +207,6 @@ def test_permute_dims(x, axes): # TODO: test elements -reshape_x_shapes = st.shared( - hh.shapes().filter(lambda s: math.prod(s) <= MAX_SIDE), - key="reshape x shape", -) - - @st.composite def reshape_shapes(draw, shape): size = 1 if len(shape) == 0 else math.prod(shape) @@ -227,21 +219,22 @@ def reshape_shapes(draw, shape): @given( - x=xps.arrays(dtype=xps.scalar_dtypes(), shape=reshape_x_shapes), - shape=reshape_x_shapes.flatmap(reshape_shapes), + x=xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes(max_side=MAX_SIDE)), + data=st.data(), ) -def test_reshape(x, shape): - assume(math.prod(shape) == math.prod(x.shape)) +def test_reshape(x, data): + shape = data.draw(reshape_shapes(x.shape)) out = xp.reshape(x, shape) ph.assert_dtype("reshape", x.dtype, out.dtype) - _shape = shape + _shape = list(shape) if any(side == -1 for side in shape): size = math.prod(x.shape) rsize = math.prod(shape) * -1 _shape[shape.index(-1)] = size / rsize + _shape = tuple(_shape) ph.assert_result_shape("reshape", (x.shape,), out.shape, _shape, shape=shape) assert_array_ndindex("reshape", x, ah.ndindex(x.shape), out, ah.ndindex(out.shape)) From 30a072163576a6664c94b93dd4437432f4fbeec7 Mon Sep 17 00:00:00 2001 From: Matthew Date: Thu, 11 Nov 2021 18:57:49 +0000 Subject: [PATCH 18/60] Improve `min()`/`max()` tests --- array_api_tests/test_statistical_functions.py | 106 ++++++++++++++++-- 1 file changed, 96 insertions(+), 10 deletions(-) diff --git a/array_api_tests/test_statistical_functions.py b/array_api_tests/test_statistical_functions.py index 98cade7d..1b2078b7 100644 --- a/array_api_tests/test_statistical_functions.py +++ b/array_api_tests/test_statistical_functions.py @@ -1,22 +1,108 @@ +import math + from hypothesis import given +from hypothesis import strategies as st from . import _array_module as xp +from . import array_helpers as ah +from . import dtype_helpers as dh from . import hypothesis_helpers as hh +from . import pytest_helpers as ph from . import xps -# TODO: generate kwargs -@given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes(min_side=1))) -def test_min(x): - xp.min(x) - # TODO +@given( + x=xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes(min_side=1)), + data=st.data(), +) +def test_min(x, data): + axis_strats = [st.none()] + if x.shape != (): + axis_strats.append( + st.integers(-x.ndim, x.ndim - 1) | xps.valid_tuple_axes(x.ndim) + ) + kw = data.draw( + hh.kwargs(axis=st.one_of(axis_strats), keepdims=st.booleans()), label="kw" + ) + out = xp.min(x, **kw) -# TODO: generate kwargs -@given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes(min_side=1))) -def test_max(x): - xp.max(x) - # TODO + ph.assert_dtype("min", x.dtype, out.dtype) + + f_func = f"min({ph.fmt_kw(kw)})" + + # TODO: support axis + if kw.get("axis") is None: + keepdims = kw.get("keepdims", False) + if keepdims: + idx = tuple(1 for _ in x.shape) + msg = f"{out.shape=}, should be reduced dimension {idx} [{f_func}]" + assert out.shape == idx + else: + ph.assert_shape("min", out.shape, (), **kw) + + # TODO: figure out NaN behaviour + if dh.is_int_dtype(x.dtype) or not xp.any(xp.isnan(x)): + _out = xp.reshape(out, ()) if keepdims else out + scalar_type = dh.get_scalar_type(out.dtype) + elements = [] + for idx in ah.ndindex(x.shape): + s = scalar_type(x[idx]) + elements.append(s) + min_ = scalar_type(_out) + expected = min(elements) + msg = f"out={min_}, should be {expected} [{f_func}]" + if math.isnan(min_): + assert math.isnan(expected), msg + else: + assert min_ == expected, msg + + +@given( + x=xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes(min_side=1)), + data=st.data(), +) +def test_max(x, data): + axis_strats = [st.none()] + if x.shape != (): + axis_strats.append( + st.integers(-x.ndim, x.ndim - 1) | xps.valid_tuple_axes(x.ndim) + ) + kw = data.draw( + hh.kwargs(axis=st.one_of(axis_strats), keepdims=st.booleans()), label="kw" + ) + + out = xp.max(x, **kw) + + ph.assert_dtype("max", x.dtype, out.dtype) + + f_func = f"max({ph.fmt_kw(kw)})" + + # TODO: support axis + if kw.get("axis") is None: + keepdims = kw.get("keepdims", False) + if keepdims: + idx = tuple(1 for _ in x.shape) + msg = f"{out.shape=}, should be reduced dimension {idx} [{f_func}]" + assert out.shape == idx + else: + ph.assert_shape("max", out.shape, (), **kw) + + # TODO: figure out NaN behaviour + if dh.is_int_dtype(x.dtype) or not xp.any(xp.isnan(x)): + _out = xp.reshape(out, ()) if keepdims else out + scalar_type = dh.get_scalar_type(out.dtype) + elements = [] + for idx in ah.ndindex(x.shape): + s = scalar_type(x[idx]) + elements.append(s) + max_ = scalar_type(_out) + expected = max(elements) + msg = f"out={max_}, should be {expected} [{f_func}]" + if math.isnan(max_): + assert math.isnan(expected), msg + else: + assert max_ == expected, msg # TODO: generate kwargs From 50a63b24fbd908a6a50abad1178c15b05152de3e Mon Sep 17 00:00:00 2001 From: Matthew Date: Fri, 12 Nov 2021 09:35:43 +0000 Subject: [PATCH 19/60] Test `mean()` --- array_api_tests/test_statistical_functions.py | 52 ++++++++++++++++--- 1 file changed, 45 insertions(+), 7 deletions(-) diff --git a/array_api_tests/test_statistical_functions.py b/array_api_tests/test_statistical_functions.py index 1b2078b7..54ec1d5b 100644 --- a/array_api_tests/test_statistical_functions.py +++ b/array_api_tests/test_statistical_functions.py @@ -10,6 +10,8 @@ from . import pytest_helpers as ph from . import xps +RTOL = 0.05 + @given( x=xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes(min_side=1)), @@ -37,7 +39,7 @@ def test_min(x, data): if keepdims: idx = tuple(1 for _ in x.shape) msg = f"{out.shape=}, should be reduced dimension {idx} [{f_func}]" - assert out.shape == idx + assert out.shape == idx, msg else: ph.assert_shape("min", out.shape, (), **kw) @@ -84,7 +86,7 @@ def test_max(x, data): if keepdims: idx = tuple(1 for _ in x.shape) msg = f"{out.shape=}, should be reduced dimension {idx} [{f_func}]" - assert out.shape == idx + assert out.shape == idx, msg else: ph.assert_shape("max", out.shape, (), **kw) @@ -105,11 +107,47 @@ def test_max(x, data): assert max_ == expected, msg -# TODO: generate kwargs -@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes(min_side=1))) -def test_mean(x): - xp.mean(x) - # TODO +@given( + x=xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes(min_side=1)), + data=st.data(), +) +def test_mean(x, data): + axis_strats = [st.none()] + if x.shape != (): + axis_strats.append( + st.integers(-x.ndim, x.ndim - 1) | xps.valid_tuple_axes(x.ndim) + ) + kw = data.draw( + hh.kwargs(axis=st.one_of(axis_strats), keepdims=st.booleans()), label="kw" + ) + + out = xp.mean(x, **kw) + + ph.assert_dtype("mean", x.dtype, out.dtype) + + f_func = f"mean({ph.fmt_kw(kw)})" + + # TODO: support axis + if kw.get("axis") is None: + keepdims = kw.get("keepdims", False) + if keepdims: + idx = tuple(1 for _ in x.shape) + msg = f"{out.shape=}, should be reduced dimension {idx} [{f_func}]" + assert out.shape == idx, msg + else: + ph.assert_shape("max", out.shape, (), **kw) + + # TODO: figure out NaN behaviour + if not xp.any(xp.isnan(x)): + _out = xp.reshape(out, ()) if keepdims else out + elements = [] + for idx in ah.ndindex(x.shape): + s = float(x[idx]) + elements.append(s) + mean = float(_out) + expected = sum(elements) / len(elements) + msg = f"out={mean}, should be roughly {expected} [{f_func}]" + assert math.isclose(mean, expected, rel_tol=RTOL), msg # TODO: generate kwargs From 1ff7271d661518b305a5aef5d4d111a1d513d007 Mon Sep 17 00:00:00 2001 From: Matthew Date: Fri, 12 Nov 2021 11:02:31 +0000 Subject: [PATCH 20/60] Test `prod()` --- array_api_tests/test_statistical_functions.py | 109 ++++++++++++++---- 1 file changed, 87 insertions(+), 22 deletions(-) diff --git a/array_api_tests/test_statistical_functions.py b/array_api_tests/test_statistical_functions.py index 54ec1d5b..8219dfd5 100644 --- a/array_api_tests/test_statistical_functions.py +++ b/array_api_tests/test_statistical_functions.py @@ -1,6 +1,6 @@ import math -from hypothesis import given +from hypothesis import assume, given from hypothesis import strategies as st from . import _array_module as xp @@ -9,8 +9,22 @@ from . import hypothesis_helpers as hh from . import pytest_helpers as ph from . import xps +from .typing import Scalar, ScalarType -RTOL = 0.05 + +def assert_equals( + func_name: str, type_: ScalarType, out: Scalar, expected: Scalar, /, **kw +): + f_func = f"{func_name}({ph.fmt_kw(kw)})" + if type_ is bool or type_ is int: + msg = f"{out=}, should be {expected} [{f_func}]" + assert out == expected, msg + elif math.isnan(expected): + msg = f"{out=}, should be {expected} [{f_func}]" + assert math.isnan(out), msg + else: + msg = f"{out=}, should be roughly {expected} [{f_func}]" + assert math.isclose(out, expected, rel_tol=0.05), msg @given( @@ -34,7 +48,7 @@ def test_min(x, data): f_func = f"min({ph.fmt_kw(kw)})" # TODO: support axis - if kw.get("axis") is None: + if kw.get("axis", None) is None: keepdims = kw.get("keepdims", False) if keepdims: idx = tuple(1 for _ in x.shape) @@ -53,11 +67,7 @@ def test_min(x, data): elements.append(s) min_ = scalar_type(_out) expected = min(elements) - msg = f"out={min_}, should be {expected} [{f_func}]" - if math.isnan(min_): - assert math.isnan(expected), msg - else: - assert min_ == expected, msg + assert_equals("min", dh.get_scalar_type(out.dtype), min_, expected) @given( @@ -81,7 +91,7 @@ def test_max(x, data): f_func = f"max({ph.fmt_kw(kw)})" # TODO: support axis - if kw.get("axis") is None: + if kw.get("axis", None) is None: keepdims = kw.get("keepdims", False) if keepdims: idx = tuple(1 for _ in x.shape) @@ -100,11 +110,7 @@ def test_max(x, data): elements.append(s) max_ = scalar_type(_out) expected = max(elements) - msg = f"out={max_}, should be {expected} [{f_func}]" - if math.isnan(max_): - assert math.isnan(expected), msg - else: - assert max_ == expected, msg + assert_equals("mean", dh.get_scalar_type(out.dtype), max_, expected) @given( @@ -128,7 +134,7 @@ def test_mean(x, data): f_func = f"mean({ph.fmt_kw(kw)})" # TODO: support axis - if kw.get("axis") is None: + if kw.get("axis", None) is None: keepdims = kw.get("keepdims", False) if keepdims: idx = tuple(1 for _ in x.shape) @@ -146,15 +152,74 @@ def test_mean(x, data): elements.append(s) mean = float(_out) expected = sum(elements) / len(elements) - msg = f"out={mean}, should be roughly {expected} [{f_func}]" - assert math.isclose(mean, expected, rel_tol=RTOL), msg + assert_equals("mean", float, mean, expected) -# TODO: generate kwargs -@given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes(min_side=1))) -def test_prod(x): - xp.prod(x) - # TODO +@given( + x=xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes(min_side=1)), + data=st.data(), +) +def test_prod(x, data): + axis_strats = [st.none()] + if x.shape != (): + axis_strats.append( + st.integers(-x.ndim, x.ndim - 1) | xps.valid_tuple_axes(x.ndim) + ) + kw = data.draw( + hh.kwargs( + axis=st.one_of(axis_strats), + dtype=st.none() | st.just(x.dtype), # TODO: all valid dtypes + keepdims=st.booleans(), + ), + label="kw", + ) + + out = xp.prod(x, **kw) + + dtype = kw.get("dtype", None) + if dtype is None: + if dh.is_int_dtype(x.dtype): + m, M = dh.dtype_ranges[x.dtype] + d_m, d_M = dh.dtype_ranges[dh.default_int] + if m < d_m or M > d_M: + _dtype = x.dtype + else: + _dtype = dh.default_int + else: + if dh.dtype_nbits[x.dtype] > dh.dtype_nbits[dh.default_float]: + _dtype = x.dtype + else: + _dtype = dh.default_float + else: + _dtype = dtype + ph.assert_dtype("prod", x.dtype, out.dtype, _dtype) + + f_func = f"prod({ph.fmt_kw(kw)})" + + # TODO: support axis + if kw.get("axis", None) is None: + keepdims = kw.get("keepdims", False) + if keepdims: + idx = tuple(1 for _ in x.shape) + msg = f"{out.shape=}, should be reduced dimension {idx} [{f_func}]" + assert out.shape == idx, msg + else: + ph.assert_shape("prod", out.shape, (), **kw) + + # TODO: figure out NaN behaviour + if dh.is_int_dtype(x.dtype) or not xp.any(xp.isnan(x)): + _out = xp.reshape(out, ()) if keepdims else out + scalar_type = dh.get_scalar_type(out.dtype) + elements = [] + for idx in ah.ndindex(x.shape): + s = scalar_type(x[idx]) + elements.append(s) + prod = scalar_type(_out) + expected = math.prod(elements) + if dh.is_int_dtype(out.dtype): + m, M = dh.dtype_ranges[out.dtype] + assume(m <= expected <= M) + assert_equals("prod", dh.get_scalar_type(out.dtype), prod, expected) # TODO: generate kwargs From f426f88b0d8bcd5e7e1896fb7c47a4222c11bce0 Mon Sep 17 00:00:00 2001 From: Matthew Date: Fri, 12 Nov 2021 11:15:46 +0000 Subject: [PATCH 21/60] Refactor axes strategies for stat functions --- array_api_tests/test_statistical_functions.py | 57 +++++++------------ 1 file changed, 22 insertions(+), 35 deletions(-) diff --git a/array_api_tests/test_statistical_functions.py b/array_api_tests/test_statistical_functions.py index 8219dfd5..640ee7ec 100644 --- a/array_api_tests/test_statistical_functions.py +++ b/array_api_tests/test_statistical_functions.py @@ -1,4 +1,5 @@ import math +from typing import Optional, Union from hypothesis import assume, given from hypothesis import strategies as st @@ -9,7 +10,15 @@ from . import hypothesis_helpers as hh from . import pytest_helpers as ph from . import xps -from .typing import Scalar, ScalarType +from .typing import Scalar, ScalarType, Shape + + +def axes(ndim: int) -> st.SearchStrategy[Optional[Union[int, Shape]]]: + axes_strats = [st.none()] + if ndim != 0: + axes_strats.append(st.integers(-ndim, ndim - 1)) + axes_strats.append(xps.valid_tuple_axes(ndim)) + return st.one_of(axes_strats) def assert_equals( @@ -32,14 +41,7 @@ def assert_equals( data=st.data(), ) def test_min(x, data): - axis_strats = [st.none()] - if x.shape != (): - axis_strats.append( - st.integers(-x.ndim, x.ndim - 1) | xps.valid_tuple_axes(x.ndim) - ) - kw = data.draw( - hh.kwargs(axis=st.one_of(axis_strats), keepdims=st.booleans()), label="kw" - ) + kw = data.draw(hh.kwargs(axis=axes(x.ndim), keepdims=st.booleans()), label="kw") out = xp.min(x, **kw) @@ -75,14 +77,7 @@ def test_min(x, data): data=st.data(), ) def test_max(x, data): - axis_strats = [st.none()] - if x.shape != (): - axis_strats.append( - st.integers(-x.ndim, x.ndim - 1) | xps.valid_tuple_axes(x.ndim) - ) - kw = data.draw( - hh.kwargs(axis=st.one_of(axis_strats), keepdims=st.booleans()), label="kw" - ) + kw = data.draw(hh.kwargs(axis=axes(x.ndim), keepdims=st.booleans()), label="kw") out = xp.max(x, **kw) @@ -118,14 +113,7 @@ def test_max(x, data): data=st.data(), ) def test_mean(x, data): - axis_strats = [st.none()] - if x.shape != (): - axis_strats.append( - st.integers(-x.ndim, x.ndim - 1) | xps.valid_tuple_axes(x.ndim) - ) - kw = data.draw( - hh.kwargs(axis=st.one_of(axis_strats), keepdims=st.booleans()), label="kw" - ) + kw = data.draw(hh.kwargs(axis=axes(x.ndim), keepdims=st.booleans()), label="kw") out = xp.mean(x, **kw) @@ -160,14 +148,9 @@ def test_mean(x, data): data=st.data(), ) def test_prod(x, data): - axis_strats = [st.none()] - if x.shape != (): - axis_strats.append( - st.integers(-x.ndim, x.ndim - 1) | xps.valid_tuple_axes(x.ndim) - ) kw = data.draw( hh.kwargs( - axis=st.one_of(axis_strats), + axis=axes(x.ndim), dtype=st.none() | st.just(x.dtype), # TODO: all valid dtypes keepdims=st.booleans(), ), @@ -222,10 +205,14 @@ def test_prod(x, data): assert_equals("prod", dh.get_scalar_type(out.dtype), prod, expected) -# TODO: generate kwargs -@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes(min_side=1))) -def test_std(x): - xp.std(x) +@given( + x=xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes(min_side=1)), + data=st.data(), +) +def test_std(x, data): + kw = data.draw(hh.kwargs(axis=axes(x.ndim), keepdims=st.booleans()), label="kw") + + xp.std(x, **kw) # TODO From 1238e75e5f7a946987b10d1000e11ebbe0c7780e Mon Sep 17 00:00:00 2001 From: Matthew Date: Fri, 12 Nov 2021 18:23:36 +0000 Subject: [PATCH 22/60] Test `std()` --- array_api_tests/test_statistical_functions.py | 65 ++++++++++++++----- 1 file changed, 49 insertions(+), 16 deletions(-) diff --git a/array_api_tests/test_statistical_functions.py b/array_api_tests/test_statistical_functions.py index 640ee7ec..6b160c7e 100644 --- a/array_api_tests/test_statistical_functions.py +++ b/array_api_tests/test_statistical_functions.py @@ -53,9 +53,9 @@ def test_min(x, data): if kw.get("axis", None) is None: keepdims = kw.get("keepdims", False) if keepdims: - idx = tuple(1 for _ in x.shape) - msg = f"{out.shape=}, should be reduced dimension {idx} [{f_func}]" - assert out.shape == idx, msg + shape = tuple(1 for _ in x.shape) + msg = f"{out.shape=}, should be reduced dimension {shape} [{f_func}]" + assert out.shape == shape, msg else: ph.assert_shape("min", out.shape, (), **kw) @@ -89,9 +89,9 @@ def test_max(x, data): if kw.get("axis", None) is None: keepdims = kw.get("keepdims", False) if keepdims: - idx = tuple(1 for _ in x.shape) - msg = f"{out.shape=}, should be reduced dimension {idx} [{f_func}]" - assert out.shape == idx, msg + shape = tuple(1 for _ in x.shape) + msg = f"{out.shape=}, should be reduced dimension {shape} [{f_func}]" + assert out.shape == shape, msg else: ph.assert_shape("max", out.shape, (), **kw) @@ -125,9 +125,9 @@ def test_mean(x, data): if kw.get("axis", None) is None: keepdims = kw.get("keepdims", False) if keepdims: - idx = tuple(1 for _ in x.shape) - msg = f"{out.shape=}, should be reduced dimension {idx} [{f_func}]" - assert out.shape == idx, msg + shape = tuple(1 for _ in x.shape) + msg = f"{out.shape=}, should be reduced dimension {shape} [{f_func}]" + assert out.shape == shape, msg else: ph.assert_shape("max", out.shape, (), **kw) @@ -183,9 +183,9 @@ def test_prod(x, data): if kw.get("axis", None) is None: keepdims = kw.get("keepdims", False) if keepdims: - idx = tuple(1 for _ in x.shape) - msg = f"{out.shape=}, should be reduced dimension {idx} [{f_func}]" - assert out.shape == idx, msg + shape = tuple(1 for _ in x.shape) + msg = f"{out.shape=}, should be reduced dimension {shape} [{f_func}]" + assert out.shape == shape, msg else: ph.assert_shape("prod", out.shape, (), **kw) @@ -206,14 +206,47 @@ def test_prod(x, data): @given( - x=xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes(min_side=1)), + x=xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes(min_side=1)).filter( + lambda x: x.size >= 2 + ), data=st.data(), ) def test_std(x, data): - kw = data.draw(hh.kwargs(axis=axes(x.ndim), keepdims=st.booleans()), label="kw") + axis = data.draw(axes(x.ndim), label="axis") + if axis is None: + N = x.size + _axes = tuple(range(x.ndim)) + else: + _axes = axis if isinstance(axis, tuple) else (axis,) + _axes = tuple( + axis if axis >= 0 else x.ndim + axis for axis in _axes + ) # normalise + N = sum(side for axis, side in enumerate(x.shape) if axis not in _axes) + correction = data.draw( + st.floats(0.0, N, allow_infinity=False, allow_nan=False) | st.integers(0, N), + label="correction", + ) + keepdims = data.draw(st.booleans(), label="keepdims") + kw = data.draw( + hh.specified_kwargs( + ("axis", axis, None), + ("correction", correction, 0.0), + ("keepdims", keepdims, False), + ), + label="kw", + ) - xp.std(x, **kw) - # TODO + out = xp.std(x, **kw) + + ph.assert_dtype("std", x.dtype, out.dtype) + + if keepdims: + shape = tuple(1 if axis in _axes else side for axis, side in enumerate(x.shape)) + else: + shape = tuple(side for axis, side in enumerate(x.shape) if axis not in _axes) + ph.assert_shape("std", out.shape, shape, **kw) + + # We can't easily test the result(s) as standard deviation methods vary a lot # TODO: generate kwargs From d1eece1057708f8df73159f1536a687bfb3860d6 Mon Sep 17 00:00:00 2001 From: Matthew Date: Mon, 15 Nov 2021 12:56:31 +0000 Subject: [PATCH 23/60] Test axes results --- array_api_tests/meta/test_utils.py | 21 +- array_api_tests/test_statistical_functions.py | 272 ++++++++++-------- 2 files changed, 166 insertions(+), 127 deletions(-) diff --git a/array_api_tests/meta/test_utils.py b/array_api_tests/meta/test_utils.py index 7dfafc5b..35c884a3 100644 --- a/array_api_tests/meta/test_utils.py +++ b/array_api_tests/meta/test_utils.py @@ -1,7 +1,8 @@ import pytest -from ..test_signatures import extension_module from ..test_creation_functions import frange +from ..test_signatures import extension_module +from ..test_statistical_functions import axes_ndindex def test_extension_module_is_extension(): @@ -24,3 +25,21 @@ def test_extension_func_is_not_extension(): def test_frange(r, size, elements): assert len(r) == size assert list(r) == elements + + +@pytest.mark.parametrize( + "shape, axes, expected", + [ + ((), (), [((),)]), + ( + (2, 2), + (0,), + [ + ((0, 0), (1, 0)), + ((0, 1), (1, 1)), + ], + ), + ], +) +def test_axes_ndindex(shape, axes, expected): + assert list(axes_ndindex(shape, axes)) == expected diff --git a/array_api_tests/test_statistical_functions.py b/array_api_tests/test_statistical_functions.py index 6b160c7e..3907d64e 100644 --- a/array_api_tests/test_statistical_functions.py +++ b/array_api_tests/test_statistical_functions.py @@ -1,5 +1,6 @@ import math -from typing import Optional, Union +from itertools import product +from typing import Iterator, Optional, Tuple, Union from hypothesis import assume, given from hypothesis import strategies as st @@ -21,23 +22,82 @@ def axes(ndim: int) -> st.SearchStrategy[Optional[Union[int, Shape]]]: return st.one_of(axes_strats) +def normalise_axis( + axis: Optional[Union[int, Tuple[int, ...]]], ndim: int +) -> Tuple[int, ...]: + if axis is None: + return tuple(range(ndim)) + axes = axis if isinstance(axis, tuple) else (axis,) + axes = tuple(axis if axis >= 0 else ndim + axis for axis in axes) + return axes + + +def axes_ndindex(shape: Shape, axes: Tuple[int, ...]) -> Iterator[Tuple[Shape, ...]]: + base_iterables = [] + axes_iterables = [] + for axis, side in enumerate(shape): + if axis in axes: + base_iterables.append((None,)) + axes_iterables.append(range(side)) + else: + base_iterables.append(range(side)) + axes_iterables.append((None,)) + for base_idx in product(*base_iterables): + indices = [] + for idx in product(*axes_iterables): + idx = list(idx) + for axis, side in enumerate(idx): + if axis not in axes: + idx[axis] = base_idx[axis] + idx = tuple(idx) + indices.append(idx) + yield tuple(indices) + + +def assert_keepdimable_shape( + func_name: str, + in_shape: Shape, + axes: Tuple[int, ...], + keepdims: bool, + out_shape: Shape, + /, + **kw, +): + if keepdims: + shape = tuple(1 if axis in axes else side for axis, side in enumerate(in_shape)) + else: + shape = tuple(side for axis, side in enumerate(in_shape) if axis not in axes) + ph.assert_shape(func_name, out_shape, shape, **kw) + + def assert_equals( - func_name: str, type_: ScalarType, out: Scalar, expected: Scalar, /, **kw + func_name: str, + type_: ScalarType, + idx: Shape, + out: Scalar, + expected: Scalar, + /, + **kw, ): + out_repr = "out" if idx == () else f"out[{idx}]" f_func = f"{func_name}({ph.fmt_kw(kw)})" if type_ is bool or type_ is int: - msg = f"{out=}, should be {expected} [{f_func}]" + msg = f"{out_repr}={out}, should be {expected} [{f_func}]" assert out == expected, msg elif math.isnan(expected): - msg = f"{out=}, should be {expected} [{f_func}]" + msg = f"{out_repr}={out}, should be {expected} [{f_func}]" assert math.isnan(out), msg else: - msg = f"{out=}, should be roughly {expected} [{f_func}]" + msg = f"{out_repr}={out}, should be roughly {expected} [{f_func}]" assert math.isclose(out, expected, rel_tol=0.05), msg @given( - x=xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes(min_side=1)), + x=xps.arrays( + dtype=xps.numeric_dtypes(), + shape=hh.shapes(min_side=1), + elements={"allow_nan": False}, + ), data=st.data(), ) def test_min(x, data): @@ -46,34 +106,27 @@ def test_min(x, data): out = xp.min(x, **kw) ph.assert_dtype("min", x.dtype, out.dtype) - - f_func = f"min({ph.fmt_kw(kw)})" - - # TODO: support axis - if kw.get("axis", None) is None: - keepdims = kw.get("keepdims", False) - if keepdims: - shape = tuple(1 for _ in x.shape) - msg = f"{out.shape=}, should be reduced dimension {shape} [{f_func}]" - assert out.shape == shape, msg - else: - ph.assert_shape("min", out.shape, (), **kw) - - # TODO: figure out NaN behaviour - if dh.is_int_dtype(x.dtype) or not xp.any(xp.isnan(x)): - _out = xp.reshape(out, ()) if keepdims else out - scalar_type = dh.get_scalar_type(out.dtype) - elements = [] - for idx in ah.ndindex(x.shape): - s = scalar_type(x[idx]) - elements.append(s) - min_ = scalar_type(_out) - expected = min(elements) - assert_equals("min", dh.get_scalar_type(out.dtype), min_, expected) + _axes = normalise_axis(kw.get("axis", None), x.ndim) + assert_keepdimable_shape( + "min", x.shape, _axes, kw.get("keepdims", False), out.shape, **kw + ) + scalar_type = dh.get_scalar_type(out.dtype) + for indices, out_idx in zip(axes_ndindex(x.shape, _axes), ah.ndindex(out.shape)): + min_ = scalar_type(out[out_idx]) + elements = [] + for idx in indices: + s = scalar_type(x[idx]) + elements.append(s) + expected = min(elements) + assert_equals("min", dh.get_scalar_type(out.dtype), out_idx, min_, expected) @given( - x=xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes(min_side=1)), + x=xps.arrays( + dtype=xps.numeric_dtypes(), + shape=hh.shapes(min_side=1), + elements={"allow_nan": False}, + ), data=st.data(), ) def test_max(x, data): @@ -82,34 +135,27 @@ def test_max(x, data): out = xp.max(x, **kw) ph.assert_dtype("max", x.dtype, out.dtype) - - f_func = f"max({ph.fmt_kw(kw)})" - - # TODO: support axis - if kw.get("axis", None) is None: - keepdims = kw.get("keepdims", False) - if keepdims: - shape = tuple(1 for _ in x.shape) - msg = f"{out.shape=}, should be reduced dimension {shape} [{f_func}]" - assert out.shape == shape, msg - else: - ph.assert_shape("max", out.shape, (), **kw) - - # TODO: figure out NaN behaviour - if dh.is_int_dtype(x.dtype) or not xp.any(xp.isnan(x)): - _out = xp.reshape(out, ()) if keepdims else out - scalar_type = dh.get_scalar_type(out.dtype) - elements = [] - for idx in ah.ndindex(x.shape): - s = scalar_type(x[idx]) - elements.append(s) - max_ = scalar_type(_out) - expected = max(elements) - assert_equals("mean", dh.get_scalar_type(out.dtype), max_, expected) + _axes = normalise_axis(kw.get("axis", None), x.ndim) + assert_keepdimable_shape( + "max", x.shape, _axes, kw.get("keepdims", False), out.shape, **kw + ) + scalar_type = dh.get_scalar_type(out.dtype) + for indices, out_idx in zip(axes_ndindex(x.shape, _axes), ah.ndindex(out.shape)): + max_ = scalar_type(out[out_idx]) + elements = [] + for idx in indices: + s = scalar_type(x[idx]) + elements.append(s) + expected = max(elements) + assert_equals("max", dh.get_scalar_type(out.dtype), out_idx, max_, expected) @given( - x=xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes(min_side=1)), + x=xps.arrays( + dtype=xps.floating_dtypes(), + shape=hh.shapes(min_side=1), + elements={"allow_nan": False}, + ), data=st.data(), ) def test_mean(x, data): @@ -118,33 +164,26 @@ def test_mean(x, data): out = xp.mean(x, **kw) ph.assert_dtype("mean", x.dtype, out.dtype) - - f_func = f"mean({ph.fmt_kw(kw)})" - - # TODO: support axis - if kw.get("axis", None) is None: - keepdims = kw.get("keepdims", False) - if keepdims: - shape = tuple(1 for _ in x.shape) - msg = f"{out.shape=}, should be reduced dimension {shape} [{f_func}]" - assert out.shape == shape, msg - else: - ph.assert_shape("max", out.shape, (), **kw) - - # TODO: figure out NaN behaviour - if not xp.any(xp.isnan(x)): - _out = xp.reshape(out, ()) if keepdims else out - elements = [] - for idx in ah.ndindex(x.shape): - s = float(x[idx]) - elements.append(s) - mean = float(_out) - expected = sum(elements) / len(elements) - assert_equals("mean", float, mean, expected) + _axes = normalise_axis(kw.get("axis", None), x.ndim) + assert_keepdimable_shape( + "mean", x.shape, _axes, kw.get("keepdims", False), out.shape, **kw + ) + for indices, out_idx in zip(axes_ndindex(x.shape, _axes), ah.ndindex(out.shape)): + mean = float(out[out_idx]) + elements = [] + for idx in indices: + s = float(x[idx]) + elements.append(s) + expected = sum(elements) / len(elements) + assert_equals("mean", dh.get_scalar_type(out.dtype), out_idx, mean, expected) @given( - x=xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes(min_side=1)), + x=xps.arrays( + dtype=xps.numeric_dtypes(), + shape=hh.shapes(min_side=1), + elements={"allow_nan": False}, + ), data=st.data(), ) def test_prod(x, data): @@ -176,52 +215,37 @@ def test_prod(x, data): else: _dtype = dtype ph.assert_dtype("prod", x.dtype, out.dtype, _dtype) - - f_func = f"prod({ph.fmt_kw(kw)})" - - # TODO: support axis - if kw.get("axis", None) is None: - keepdims = kw.get("keepdims", False) - if keepdims: - shape = tuple(1 for _ in x.shape) - msg = f"{out.shape=}, should be reduced dimension {shape} [{f_func}]" - assert out.shape == shape, msg - else: - ph.assert_shape("prod", out.shape, (), **kw) - - # TODO: figure out NaN behaviour - if dh.is_int_dtype(x.dtype) or not xp.any(xp.isnan(x)): - _out = xp.reshape(out, ()) if keepdims else out - scalar_type = dh.get_scalar_type(out.dtype) - elements = [] - for idx in ah.ndindex(x.shape): - s = scalar_type(x[idx]) - elements.append(s) - prod = scalar_type(_out) - expected = math.prod(elements) - if dh.is_int_dtype(out.dtype): - m, M = dh.dtype_ranges[out.dtype] - assume(m <= expected <= M) - assert_equals("prod", dh.get_scalar_type(out.dtype), prod, expected) + _axes = normalise_axis(kw.get("axis", None), x.ndim) + assert_keepdimable_shape( + "prod", x.shape, _axes, kw.get("keepdims", False), out.shape, **kw + ) + scalar_type = dh.get_scalar_type(out.dtype) + for indices, out_idx in zip(axes_ndindex(x.shape, _axes), ah.ndindex(out.shape)): + prod = scalar_type(out[out_idx]) + assume(not math.isinf(prod)) + elements = [] + for idx in indices: + s = scalar_type(x[idx]) + elements.append(s) + expected = math.prod(elements) + if dh.is_int_dtype(out.dtype): + m, M = dh.dtype_ranges[out.dtype] + assume(m <= expected <= M) + assert_equals("prod", dh.get_scalar_type(out.dtype), out_idx, prod, expected) @given( - x=xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes(min_side=1)).filter( - lambda x: x.size >= 2 - ), + x=xps.arrays( + dtype=xps.floating_dtypes(), + shape=hh.shapes(min_side=1), + elements={"allow_nan": False}, + ).filter(lambda x: x.size >= 2), data=st.data(), ) def test_std(x, data): axis = data.draw(axes(x.ndim), label="axis") - if axis is None: - N = x.size - _axes = tuple(range(x.ndim)) - else: - _axes = axis if isinstance(axis, tuple) else (axis,) - _axes = tuple( - axis if axis >= 0 else x.ndim + axis for axis in _axes - ) # normalise - N = sum(side for axis, side in enumerate(x.shape) if axis not in _axes) + _axes = normalise_axis(axis, x.ndim) + N = sum(side for axis, side in enumerate(x.shape) if axis not in _axes) correction = data.draw( st.floats(0.0, N, allow_infinity=False, allow_nan=False) | st.integers(0, N), label="correction", @@ -239,13 +263,9 @@ def test_std(x, data): out = xp.std(x, **kw) ph.assert_dtype("std", x.dtype, out.dtype) - - if keepdims: - shape = tuple(1 if axis in _axes else side for axis, side in enumerate(x.shape)) - else: - shape = tuple(side for axis, side in enumerate(x.shape) if axis not in _axes) - ph.assert_shape("std", out.shape, shape, **kw) - + assert_keepdimable_shape( + "std", x.shape, _axes, kw.get("keepdims", False), out.shape, **kw + ) # We can't easily test the result(s) as standard deviation methods vary a lot From 687e40a0e40eb30a4ac74f1e2c1ccf7cbb22c3c9 Mon Sep 17 00:00:00 2001 From: Matthew Date: Mon, 15 Nov 2021 13:11:53 +0000 Subject: [PATCH 24/60] Test `var()` and `sum()` --- array_api_tests/test_statistical_functions.py | 109 +++++++++++++++--- 1 file changed, 93 insertions(+), 16 deletions(-) diff --git a/array_api_tests/test_statistical_functions.py b/array_api_tests/test_statistical_functions.py index 3907d64e..700f5bc3 100644 --- a/array_api_tests/test_statistical_functions.py +++ b/array_api_tests/test_statistical_functions.py @@ -56,10 +56,10 @@ def axes_ndindex(shape: Shape, axes: Tuple[int, ...]) -> Iterator[Tuple[Shape, . def assert_keepdimable_shape( func_name: str, + out_shape: Shape, in_shape: Shape, axes: Tuple[int, ...], keepdims: bool, - out_shape: Shape, /, **kw, ): @@ -108,7 +108,7 @@ def test_min(x, data): ph.assert_dtype("min", x.dtype, out.dtype) _axes = normalise_axis(kw.get("axis", None), x.ndim) assert_keepdimable_shape( - "min", x.shape, _axes, kw.get("keepdims", False), out.shape, **kw + "min", out.shape, x.shape, _axes, kw.get("keepdims", False), **kw ) scalar_type = dh.get_scalar_type(out.dtype) for indices, out_idx in zip(axes_ndindex(x.shape, _axes), ah.ndindex(out.shape)): @@ -137,7 +137,7 @@ def test_max(x, data): ph.assert_dtype("max", x.dtype, out.dtype) _axes = normalise_axis(kw.get("axis", None), x.ndim) assert_keepdimable_shape( - "max", x.shape, _axes, kw.get("keepdims", False), out.shape, **kw + "max", out.shape, x.shape, _axes, kw.get("keepdims", False), **kw ) scalar_type = dh.get_scalar_type(out.dtype) for indices, out_idx in zip(axes_ndindex(x.shape, _axes), ah.ndindex(out.shape)): @@ -166,7 +166,7 @@ def test_mean(x, data): ph.assert_dtype("mean", x.dtype, out.dtype) _axes = normalise_axis(kw.get("axis", None), x.ndim) assert_keepdimable_shape( - "mean", x.shape, _axes, kw.get("keepdims", False), out.shape, **kw + "mean", out.shape, x.shape, _axes, kw.get("keepdims", False), **kw ) for indices, out_idx in zip(axes_ndindex(x.shape, _axes), ah.ndindex(out.shape)): mean = float(out[out_idx]) @@ -217,7 +217,7 @@ def test_prod(x, data): ph.assert_dtype("prod", x.dtype, out.dtype, _dtype) _axes = normalise_axis(kw.get("axis", None), x.ndim) assert_keepdimable_shape( - "prod", x.shape, _axes, kw.get("keepdims", False), out.shape, **kw + "prod", out.shape, x.shape, _axes, kw.get("keepdims", False), **kw ) scalar_type = dh.get_scalar_type(out.dtype) for indices, out_idx in zip(axes_ndindex(x.shape, _axes), ah.ndindex(out.shape)): @@ -264,20 +264,97 @@ def test_std(x, data): ph.assert_dtype("std", x.dtype, out.dtype) assert_keepdimable_shape( - "std", x.shape, _axes, kw.get("keepdims", False), out.shape, **kw + "std", out.shape, x.shape, _axes, kw.get("keepdims", False), **kw ) # We can't easily test the result(s) as standard deviation methods vary a lot -# TODO: generate kwargs -@given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes(min_side=1))) -def test_sum(x): - xp.sum(x) - # TODO +@given( + x=xps.arrays( + dtype=xps.floating_dtypes(), + shape=hh.shapes(min_side=1), + elements={"allow_nan": False}, + ).filter(lambda x: x.size >= 2), + data=st.data(), +) +def test_var(x, data): + axis = data.draw(axes(x.ndim), label="axis") + _axes = normalise_axis(axis, x.ndim) + N = sum(side for axis, side in enumerate(x.shape) if axis not in _axes) + correction = data.draw( + st.floats(0.0, N, allow_infinity=False, allow_nan=False) | st.integers(0, N), + label="correction", + ) + keepdims = data.draw(st.booleans(), label="keepdims") + kw = data.draw( + hh.specified_kwargs( + ("axis", axis, None), + ("correction", correction, 0.0), + ("keepdims", keepdims, False), + ), + label="kw", + ) + + out = xp.var(x, **kw) + + ph.assert_dtype("var", x.dtype, out.dtype) + assert_keepdimable_shape( + "var", out.shape, x.shape, _axes, kw.get("keepdims", False), **kw + ) + # We can't easily test the result(s) as variance methods vary a lot + + +@given( + x=xps.arrays( + dtype=xps.numeric_dtypes(), + shape=hh.shapes(min_side=1), + elements={"allow_nan": False}, + ), + data=st.data(), +) +def test_sum(x, data): + kw = data.draw( + hh.kwargs( + axis=axes(x.ndim), + dtype=st.none() | st.just(x.dtype), # TODO: all valid dtypes + keepdims=st.booleans(), + ), + label="kw", + ) + out = xp.sum(x, **kw) -# TODO: generate kwargs -@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes(min_side=1))) -def test_var(x): - xp.var(x) - # TODO + dtype = kw.get("dtype", None) + if dtype is None: + if dh.is_int_dtype(x.dtype): + m, M = dh.dtype_ranges[x.dtype] + d_m, d_M = dh.dtype_ranges[dh.default_int] + if m < d_m or M > d_M: + _dtype = x.dtype + else: + _dtype = dh.default_int + else: + if dh.dtype_nbits[x.dtype] > dh.dtype_nbits[dh.default_float]: + _dtype = x.dtype + else: + _dtype = dh.default_float + else: + _dtype = dtype + ph.assert_dtype("sum", x.dtype, out.dtype, _dtype) + _axes = normalise_axis(kw.get("axis", None), x.ndim) + assert_keepdimable_shape( + "sum", out.shape, x.shape, _axes, kw.get("keepdims", False), **kw + ) + scalar_type = dh.get_scalar_type(out.dtype) + for indices, out_idx in zip(axes_ndindex(x.shape, _axes), ah.ndindex(out.shape)): + sum_ = scalar_type(out[out_idx]) + assume(not math.isinf(sum_)) + elements = [] + for idx in indices: + s = scalar_type(x[idx]) + elements.append(s) + expected = sum(elements) + if dh.is_int_dtype(out.dtype): + m, M = dh.dtype_ranges[out.dtype] + assume(m <= expected <= M) + assert_equals("sum", dh.get_scalar_type(out.dtype), out_idx, sum_, expected) From d38eee20dcead12e94d34e54863c12b00fa3282b Mon Sep 17 00:00:00 2001 From: Matthew Date: Mon, 15 Nov 2021 13:27:57 +0000 Subject: [PATCH 25/60] Generate all valid dtypes in `test_prod` and `test_sum` --- array_api_tests/test_statistical_functions.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/array_api_tests/test_statistical_functions.py b/array_api_tests/test_statistical_functions.py index 700f5bc3..81498062 100644 --- a/array_api_tests/test_statistical_functions.py +++ b/array_api_tests/test_statistical_functions.py @@ -11,7 +11,7 @@ from . import hypothesis_helpers as hh from . import pytest_helpers as ph from . import xps -from .typing import Scalar, ScalarType, Shape +from .typing import DataType, Scalar, ScalarType, Shape def axes(ndim: int) -> st.SearchStrategy[Optional[Union[int, Shape]]]: @@ -22,6 +22,11 @@ def axes(ndim: int) -> st.SearchStrategy[Optional[Union[int, Shape]]]: return st.one_of(axes_strats) +def kwarg_dtypes(dtype: DataType) -> st.SearchStrategy[Optional[DataType]]: + dtypes = [d2 for d1, d2 in dh.promotion_table if d1 == dtype] + return st.none() | st.sampled_from(dtypes) + + def normalise_axis( axis: Optional[Union[int, Tuple[int, ...]]], ndim: int ) -> Tuple[int, ...]: @@ -190,7 +195,7 @@ def test_prod(x, data): kw = data.draw( hh.kwargs( axis=axes(x.ndim), - dtype=st.none() | st.just(x.dtype), # TODO: all valid dtypes + dtype=kwarg_dtypes(x.dtype), keepdims=st.booleans(), ), label="kw", @@ -316,7 +321,7 @@ def test_sum(x, data): kw = data.draw( hh.kwargs( axis=axes(x.ndim), - dtype=st.none() | st.just(x.dtype), # TODO: all valid dtypes + dtype=kwarg_dtypes(x.dtype), keepdims=st.booleans(), ), label="kw", From 704e47bf1efa2a220aa513b642ee018182a5eece Mon Sep 17 00:00:00 2001 From: Matthew Date: Tue, 16 Nov 2021 12:12:03 +0000 Subject: [PATCH 26/60] Cover axis in `test_concat` --- array_api_tests/meta/test_utils.py | 26 ++++++++ .../test_manipulation_functions.py | 65 +++++++++++++++---- array_api_tests/test_statistical_functions.py | 4 +- 3 files changed, 79 insertions(+), 16 deletions(-) diff --git a/array_api_tests/meta/test_utils.py b/array_api_tests/meta/test_utils.py index 35c884a3..34fa1836 100644 --- a/array_api_tests/meta/test_utils.py +++ b/array_api_tests/meta/test_utils.py @@ -1,6 +1,8 @@ import pytest +from .. import array_helpers as ah from ..test_creation_functions import frange +from ..test_manipulation_functions import axis_ndindex from ..test_signatures import extension_module from ..test_statistical_functions import axes_ndindex @@ -27,6 +29,30 @@ def test_frange(r, size, elements): assert list(r) == elements +@pytest.mark.parametrize( + "shape, expected", + [((), [()])], +) +def test_ndindex(shape, expected): + assert list(ah.ndindex(shape)) == expected + + +@pytest.mark.parametrize( + "shape, axis, expected", + [ + ((1,), 0, [(slice(None, None),)]), + ((1, 2), 0, [(slice(None, None), slice(None, None))]), + ( + (2, 4), + 1, + [(0, slice(None, None)), (1, slice(None, None))], + ), + ], +) +def test_axis_ndindex(shape, axis, expected): + assert list(axis_ndindex(shape, axis)) == expected + + @pytest.mark.parametrize( "shape, axes, expected", [ diff --git a/array_api_tests/test_manipulation_functions.py b/array_api_tests/test_manipulation_functions.py index 408a587a..a2b17c3e 100644 --- a/array_api_tests/test_manipulation_functions.py +++ b/array_api_tests/test_manipulation_functions.py @@ -1,6 +1,7 @@ import math from collections import deque -from typing import Iterable, Union +from itertools import product +from typing import Iterable, Iterator, Tuple, Union from hypothesis import assume, given from hypothesis import strategies as st @@ -43,6 +44,28 @@ def assert_array_ndindex( assert out[out_idx] == x[x_idx], msg +def axis_ndindex( + shape: Shape, axis: int +) -> Iterator[Tuple[Tuple[Union[int, slice], ...], ...]]: + iterables = [range(side) for side in shape[:axis]] + for _ in range(len(shape[axis:])): + iterables.append([slice(None, None)]) + yield from product(*iterables) + + +def assert_equals( + func_name: str, x_repr: str, x_val: Array, out_repr: str, out_val: Array, **kw +): + msg = ( + f"{out_repr}={out_val}, should be {x_repr}={x_val} " + f"[{func_name}({ph.fmt_kw(kw)})]" + ) + if dh.is_float_dtype(out_val.dtype) and xp.isnan(out_val): + assert xp.isnan(x_val), msg + else: + assert out_val == out_val, msg + + @st.composite def concat_shapes(draw, shape, axis): shape = list(shape) @@ -85,21 +108,35 @@ def test_concat(dtypes, kw, data): shape = tuple(shape) ph.assert_result_shape("concat", shapes, out.shape, shape, **kw) - # TODO: adjust indices with nonzero axis - if axis is None or axis == 0: - out_indices = ah.ndindex(out.shape) - for i, x in enumerate(arrays, 1): - msg_suffix = f" [concat({ph.fmt_kw(kw)})]\nx{i}={x!r}\n{out=}" + if axis is None: + out_indices = (i for i in range(out.size)) + for x_num, x in enumerate(arrays, 1): for x_idx in ah.ndindex(x.shape): - out_idx = next(out_indices) - msg = ( - f"out[{out_idx}]={out[out_idx]}, should be x{i}[{x_idx}]={x[x_idx]}" + out_i = next(out_indices) + assert_equals( + "concat", + f"x{x_num}[{x_idx}]", + x[x_idx], + f"out[{out_i}]", + out[out_i], + **kw, ) - msg += msg_suffix - if dh.is_float_dtype(x.dtype) and xp.isnan(x[x_idx]): - assert xp.isnan(out[out_idx]), msg - else: - assert out[out_idx] == x[x_idx], msg + else: + out_indices = ah.ndindex(out.shape) + for idx in axis_ndindex(shape, axis): + f_idx = ", ".join(str(i) if isinstance(i, int) else ":" for i in idx) + for x_num, x in enumerate(arrays, 1): + indexed_x = x[idx] + for x_idx in ah.ndindex(indexed_x.shape): + out_idx = next(out_indices) + assert_equals( + "concat", + f"x{x_num}[{f_idx}][{x_idx}]", + indexed_x[x_idx], + f"out[{out_idx}]", + out[out_idx], + **kw, + ) @given( diff --git a/array_api_tests/test_statistical_functions.py b/array_api_tests/test_statistical_functions.py index 81498062..0373ac47 100644 --- a/array_api_tests/test_statistical_functions.py +++ b/array_api_tests/test_statistical_functions.py @@ -42,11 +42,11 @@ def axes_ndindex(shape: Shape, axes: Tuple[int, ...]) -> Iterator[Tuple[Shape, . axes_iterables = [] for axis, side in enumerate(shape): if axis in axes: - base_iterables.append((None,)) + base_iterables.append([None]) axes_iterables.append(range(side)) else: base_iterables.append(range(side)) - axes_iterables.append((None,)) + axes_iterables.append([None]) for base_idx in product(*base_iterables): indices = [] for idx in product(*axes_iterables): From 40167c34a8f0785566e38b922a45651db252abdb Mon Sep 17 00:00:00 2001 From: Matthew Date: Tue, 16 Nov 2021 12:38:23 +0000 Subject: [PATCH 27/60] Cover axis in `test_flip` --- array_api_tests/test_manipulation_functions.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/array_api_tests/test_manipulation_functions.py b/array_api_tests/test_manipulation_functions.py index a2b17c3e..061b7f0a 100644 --- a/array_api_tests/test_manipulation_functions.py +++ b/array_api_tests/test_manipulation_functions.py @@ -12,6 +12,7 @@ from . import hypothesis_helpers as hh from . import pytest_helpers as ph from . import xps +from .test_statistical_functions import axes_ndindex, normalise_axis # TODO: Move from .typing import Array, Shape MAX_SIDE = hh.MAX_ARRAY_SIZE // 64 @@ -210,9 +211,8 @@ def test_flip(x, data): ph.assert_dtype("flip", x.dtype, out.dtype) - # TODO: test all axis scenarios - if kw.get("axis", None) is None: - indices = list(ah.ndindex(x.shape)) + _axes = normalise_axis(kw.get("axis", None), x.ndim) + for indices in axes_ndindex(x.shape, _axes): reverse_indices = indices[::-1] assert_array_ndindex("flip", x, indices, out, reverse_indices) From 2decdf0349f5b0e1c86b209d6a26c9591a4c4cd4 Mon Sep 17 00:00:00 2001 From: Matthew Date: Tue, 16 Nov 2021 15:39:30 +0000 Subject: [PATCH 28/60] Cover elements in `test_permute_dims` --- array_api_tests/test_manipulation_functions.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/array_api_tests/test_manipulation_functions.py b/array_api_tests/test_manipulation_functions.py index 061b7f0a..44098d20 100644 --- a/array_api_tests/test_manipulation_functions.py +++ b/array_api_tests/test_manipulation_functions.py @@ -237,11 +237,12 @@ def test_permute_dims(x, axes): for i, dim in enumerate(axes): side = x.shape[dim] shape[i] = side - assert all(isinstance(side, int) for side in shape) # sanity check shape = tuple(shape) ph.assert_result_shape("permute_dims", (x.shape,), out.shape, shape, axes=axes) - # TODO: test elements + indices = list(ah.ndindex(x.shape)) + permuted_indices = [tuple(idx[axis] for axis in axes) for idx in indices] + assert_array_ndindex("permute_dims", x, indices, out, permuted_indices) @st.composite From 80c4e31f403b63c0a0706c830e6ffc134f1265f2 Mon Sep 17 00:00:00 2001 From: Matthew Date: Mon, 6 Dec 2021 11:28:26 +0000 Subject: [PATCH 29/60] Fix `test_concat` axes iteration --- array_api_tests/test_manipulation_functions.py | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/array_api_tests/test_manipulation_functions.py b/array_api_tests/test_manipulation_functions.py index 44098d20..7397d47f 100644 --- a/array_api_tests/test_manipulation_functions.py +++ b/array_api_tests/test_manipulation_functions.py @@ -1,7 +1,7 @@ import math from collections import deque from itertools import product -from typing import Iterable, Iterator, Tuple, Union +from typing import Iterable, Union from hypothesis import assume, given from hypothesis import strategies as st @@ -45,15 +45,6 @@ def assert_array_ndindex( assert out[out_idx] == x[x_idx], msg -def axis_ndindex( - shape: Shape, axis: int -) -> Iterator[Tuple[Tuple[Union[int, slice], ...], ...]]: - iterables = [range(side) for side in shape[:axis]] - for _ in range(len(shape[axis:])): - iterables.append([slice(None, None)]) - yield from product(*iterables) - - def assert_equals( func_name: str, x_repr: str, x_val: Array, out_repr: str, out_val: Array, **kw ): @@ -124,7 +115,10 @@ def test_concat(dtypes, kw, data): ) else: out_indices = ah.ndindex(out.shape) - for idx in axis_ndindex(shape, axis): + axis_indices = [range(side) for side in shapes[0][:_axis]] + for _ in range(_axis, len(shape)): + axis_indices.append([slice(None, None)]) + for idx in product(*axis_indices): f_idx = ", ".join(str(i) if isinstance(i, int) else ":" for i in idx) for x_num, x in enumerate(arrays, 1): indexed_x = x[idx] From 767bd3fbb872dd70da86e77096103e1e28102f8f Mon Sep 17 00:00:00 2001 From: Matthew Date: Mon, 6 Dec 2021 16:07:56 +0000 Subject: [PATCH 30/60] Cover all shift/axes scenarios in `test_roll` --- .../test_manipulation_functions.py | 40 +++++++++++++------ 1 file changed, 28 insertions(+), 12 deletions(-) diff --git a/array_api_tests/test_manipulation_functions.py b/array_api_tests/test_manipulation_functions.py index 7397d47f..2db88177 100644 --- a/array_api_tests/test_manipulation_functions.py +++ b/array_api_tests/test_manipulation_functions.py @@ -274,16 +274,21 @@ def test_reshape(x, data): @given(xps.arrays(dtype=xps.scalar_dtypes(), shape=shared_shapes()), st.data()) def test_roll(x, data): - shift = data.draw( - st.integers() | st.lists(st.integers(), max_size=x.ndim).map(tuple), - label="shift", - ) - axis_strats = [st.none()] - if x.shape != (): - axis_strats.append(st.integers(-x.ndim, x.ndim - 1)) - if isinstance(shift, int): - axis_strats.append(xps.valid_tuple_axes(x.ndim)) - kw = data.draw(hh.kwargs(axis=st.one_of(axis_strats)), label="kw") + shift_strat = st.integers(-hh.MAX_ARRAY_SIZE, hh.MAX_ARRAY_SIZE) + if x.ndim > 0: + shift_strat = shift_strat | st.lists( + shift_strat, min_size=1, max_size=x.ndim + ).map(tuple) + shift = data.draw(shift_strat, label="shift") + if isinstance(shift, tuple): + axis_strat = xps.valid_tuple_axes(x.ndim).filter(lambda t: len(t) == len(shift)) + kw_strat = axis_strat.map(lambda t: {"axis": t}) + else: + axis_strat = st.none() + if x.ndim != 0: + axis_strat = axis_strat | st.integers(-x.ndim, x.ndim - 1) + kw_strat = hh.kwargs(axis=axis_strat) + kw = data.draw(kw_strat, label="kw") out = xp.roll(x, shift, **kw) @@ -291,12 +296,23 @@ def test_roll(x, data): ph.assert_result_shape("roll", (x.shape,), out.shape) - # TODO: test all shift/axis scenarios - if isinstance(shift, int) and kw.get("axis", None) is None: + if kw.get("axis", None) is None: + assert isinstance(shift, int) # sanity check indices = list(ah.ndindex(x.shape)) shifted_indices = deque(indices) shifted_indices.rotate(-shift) assert_array_ndindex("roll", x, indices, out, shifted_indices) + else: + _shift = (shift,) if isinstance(shift, int) else shift + axes = normalise_axis(kw["axis"], x.ndim) + all_indices = list(ah.ndindex(x.shape)) + for s, a in zip(_shift, axes): + side = x.shape[a] + for i in range(side): + indices = [idx for idx in all_indices if idx[a] == i] + shifted_indices = deque(indices) + shifted_indices.rotate(-s) + assert_array_ndindex("roll", x, indices, out, shifted_indices) @given( From aa7aaa06058e5968a3e5236ecdc3e3393dbb3f0f Mon Sep 17 00:00:00 2001 From: Matthew Date: Tue, 7 Dec 2021 11:52:26 +0000 Subject: [PATCH 31/60] Cover all axis scenarios in `test_stack` --- .../test_manipulation_functions.py | 44 +++++++++++-------- 1 file changed, 26 insertions(+), 18 deletions(-) diff --git a/array_api_tests/test_manipulation_functions.py b/array_api_tests/test_manipulation_functions.py index 2db88177..49f86694 100644 --- a/array_api_tests/test_manipulation_functions.py +++ b/array_api_tests/test_manipulation_functions.py @@ -1,7 +1,7 @@ import math from collections import deque from itertools import product -from typing import Iterable, Union +from typing import Iterable, Iterator, Tuple, Union from hypothesis import assume, given from hypothesis import strategies as st @@ -28,6 +28,16 @@ def shared_shapes(*args, **kwargs) -> st.SearchStrategy[Shape]: return st.shared(hh.shapes(*args, **kwargs), key="shape") +def axis_ndindex( + shape: Shape, axis: int +) -> Iterator[Tuple[Tuple[Union[int, slice], ...], ...]]: + assert axis >= 0 # sanity check + axis_indices = [range(side) for side in shape[:axis]] + for _ in range(axis, len(shape)): + axis_indices.append([slice(None, None)]) + yield from product(*axis_indices) + + def assert_array_ndindex( func_name: str, x: Array, @@ -115,10 +125,7 @@ def test_concat(dtypes, kw, data): ) else: out_indices = ah.ndindex(out.shape) - axis_indices = [range(side) for side in shapes[0][:_axis]] - for _ in range(_axis, len(shape)): - axis_indices.append([slice(None, None)]) - for idx in product(*axis_indices): + for idx in axis_ndindex(shapes[0], _axis): f_idx = ", ".join(str(i) if isinstance(i, int) else ":" for i in idx) for x_num, x in enumerate(arrays, 1): indexed_x = x[idx] @@ -344,18 +351,19 @@ def test_stack(shape, dtypes, kw, data): "stack", tuple(x.shape for x in arrays), out.shape, _shape, **kw ) - # TODO: adjust indices with nonzero axis - if axis == 0: - out_indices = ah.ndindex(out.shape) - for i, x in enumerate(arrays, 1): - msg_suffix = f" [stack({ph.fmt_kw(kw)})]\nx{i}={x!r}\n{out=}" - for x_idx in ah.ndindex(x.shape): + out_indices = ah.ndindex(out.shape) + for idx in axis_ndindex(arrays[0].shape, axis=_axis): + f_idx = ", ".join(str(i) if isinstance(i, int) else ":" for i in idx) + print(f"{f_idx=}") + for x_num, x in enumerate(arrays, 1): + indexed_x = x[idx] + for x_idx in ah.ndindex(indexed_x.shape): out_idx = next(out_indices) - msg = ( - f"out[{out_idx}]={out[out_idx]}, should be x{i}[{x_idx}]={x[x_idx]}" + assert_equals( + "stack", + f"x{x_num}[{f_idx}][{x_idx}]", + indexed_x[x_idx], + f"out[{out_idx}]", + out[out_idx], + **kw, ) - msg += msg_suffix - if dh.is_float_dtype(x.dtype) and xp.isnan(x[x_idx]): - assert xp.isnan(out[out_idx]), msg - else: - assert out[out_idx] == x[x_idx], msg From b32fff07f624e8cc36569393bb0f039e8d60dc04 Mon Sep 17 00:00:00 2001 From: Matthew Date: Tue, 7 Dec 2021 12:03:43 +0000 Subject: [PATCH 32/60] Make float assertions more lenient in statistical tests --- array_api_tests/test_statistical_functions.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/array_api_tests/test_statistical_functions.py b/array_api_tests/test_statistical_functions.py index 0373ac47..91aac683 100644 --- a/array_api_tests/test_statistical_functions.py +++ b/array_api_tests/test_statistical_functions.py @@ -94,7 +94,7 @@ def assert_equals( assert math.isnan(out), msg else: msg = f"{out_repr}={out}, should be roughly {expected} [{f_func}]" - assert math.isclose(out, expected, rel_tol=0.05), msg + assert math.isclose(out, expected, rel_tol=0.25, abs_tol=1), msg @given( @@ -175,6 +175,7 @@ def test_mean(x, data): ) for indices, out_idx in zip(axes_ndindex(x.shape, _axes), ah.ndindex(out.shape)): mean = float(out[out_idx]) + assume(not math.isinf(mean)) # mean may become inf due to internal overflows elements = [] for idx in indices: s = float(x[idx]) From 5978fdae0b9c0e4e4cf40d880d312ebc5d4d12d2 Mon Sep 17 00:00:00 2001 From: Matthew Date: Tue, 7 Dec 2021 12:26:31 +0000 Subject: [PATCH 33/60] Update sum and prod tests to use new default uint --- array_api_tests/dtype_helpers.py | 421 +++++++++--------- array_api_tests/test_statistical_functions.py | 16 +- 2 files changed, 224 insertions(+), 213 deletions(-) diff --git a/array_api_tests/dtype_helpers.py b/array_api_tests/dtype_helpers.py index 65b4090a..ce749d9e 100644 --- a/array_api_tests/dtype_helpers.py +++ b/array_api_tests/dtype_helpers.py @@ -1,45 +1,45 @@ -from warnings import warn from functools import lru_cache from typing import NamedTuple, Tuple, Union +from warnings import warn from . import _array_module as xp from ._array_module import _UndefinedStub from .typing import DataType, ScalarType - __all__ = [ - 'int_dtypes', - 'uint_dtypes', - 'all_int_dtypes', - 'float_dtypes', - 'numeric_dtypes', - 'all_dtypes', - 'dtype_to_name', - 'bool_and_all_int_dtypes', - 'dtype_to_scalars', - 'is_int_dtype', - 'is_float_dtype', - 'get_scalar_type', - 'dtype_ranges', - 'default_int', - 'default_float', - 'promotion_table', - 'dtype_nbits', - 'dtype_signed', - 'func_in_dtypes', - 'func_returns_bool', - 'binary_op_to_symbol', - 'unary_op_to_symbol', - 'inplace_op_to_symbol', - 'op_to_func', - 'fmt_types', + "int_dtypes", + "uint_dtypes", + "all_int_dtypes", + "float_dtypes", + "numeric_dtypes", + "all_dtypes", + "dtype_to_name", + "bool_and_all_int_dtypes", + "dtype_to_scalars", + "is_int_dtype", + "is_float_dtype", + "get_scalar_type", + "dtype_ranges", + "default_int", + "default_uint", + "default_float", + "promotion_table", + "dtype_nbits", + "dtype_signed", + "func_in_dtypes", + "func_returns_bool", + "binary_op_to_symbol", + "unary_op_to_symbol", + "inplace_op_to_symbol", + "op_to_func", + "fmt_types", ] -_uint_names = ('uint8', 'uint16', 'uint32', 'uint64') -_int_names = ('int8', 'int16', 'int32', 'int64') -_float_names = ('float32', 'float64') -_dtype_names = ('bool',) + _uint_names + _int_names + _float_names +_uint_names = ("uint8", "uint16", "uint32", "uint64") +_int_names = ("int8", "int16", "int32", "int64") +_float_names = ("float32", "float64") +_dtype_names = ("bool",) + _uint_names + _int_names + _float_names uint_dtypes = tuple(getattr(xp, name) for name in _uint_names) @@ -101,17 +101,34 @@ class MinMax(NamedTuple): xp.uint64: MinMax(0, +18_446_744_073_709_551_615), } +dtype_nbits = { + **{d: 8 for d in [xp.int8, xp.uint8]}, + **{d: 16 for d in [xp.int16, xp.uint16]}, + **{d: 32 for d in [xp.int32, xp.uint32, xp.float32]}, + **{d: 64 for d in [xp.int64, xp.uint64, xp.float64]}, +} + + +dtype_signed = { + **{d: True for d in int_dtypes}, + **{d: False for d in uint_dtypes}, +} + if isinstance(xp.asarray, _UndefinedStub): default_int = xp.int32 default_float = xp.float32 warn( - 'array module does not have attribute asarray. ' - 'default int is assumed int32, default float is assumed float32' + "array module does not have attribute asarray. " + "default int is assumed int32, default float is assumed float32" ) else: default_int = xp.asarray(int()).dtype default_float = xp.asarray(float()).dtype +if dtype_nbits[default_int] == 32: + default_uint = xp.uint32 +else: + default_uint = xp.uint64 _numeric_promotions = { @@ -173,200 +190,186 @@ def result_type(*dtypes: DataType): return result -dtype_nbits = { - **{d: 8 for d in [xp.int8, xp.uint8]}, - **{d: 16 for d in [xp.int16, xp.uint16]}, - **{d: 32 for d in [xp.int32, xp.uint32, xp.float32]}, - **{d: 64 for d in [xp.int64, xp.uint64, xp.float64]}, -} - - -dtype_signed = { - **{d: True for d in int_dtypes}, - **{d: False for d in uint_dtypes}, -} - - func_in_dtypes = { # elementwise - 'abs': numeric_dtypes, - 'acos': float_dtypes, - 'acosh': float_dtypes, - 'add': numeric_dtypes, - 'asin': float_dtypes, - 'asinh': float_dtypes, - 'atan': float_dtypes, - 'atan2': float_dtypes, - 'atanh': float_dtypes, - 'bitwise_and': bool_and_all_int_dtypes, - 'bitwise_invert': bool_and_all_int_dtypes, - 'bitwise_left_shift': all_int_dtypes, - 'bitwise_or': bool_and_all_int_dtypes, - 'bitwise_right_shift': all_int_dtypes, - 'bitwise_xor': bool_and_all_int_dtypes, - 'ceil': numeric_dtypes, - 'cos': float_dtypes, - 'cosh': float_dtypes, - 'divide': float_dtypes, - 'equal': all_dtypes, - 'exp': float_dtypes, - 'expm1': float_dtypes, - 'floor': numeric_dtypes, - 'floor_divide': numeric_dtypes, - 'greater': numeric_dtypes, - 'greater_equal': numeric_dtypes, - 'isfinite': numeric_dtypes, - 'isinf': numeric_dtypes, - 'isnan': numeric_dtypes, - 'less': numeric_dtypes, - 'less_equal': numeric_dtypes, - 'log': float_dtypes, - 'logaddexp': float_dtypes, - 'log10': float_dtypes, - 'log1p': float_dtypes, - 'log2': float_dtypes, - 'logical_and': (xp.bool,), - 'logical_not': (xp.bool,), - 'logical_or': (xp.bool,), - 'logical_xor': (xp.bool,), - 'multiply': numeric_dtypes, - 'negative': numeric_dtypes, - 'not_equal': all_dtypes, - 'positive': numeric_dtypes, - 'pow': float_dtypes, - 'remainder': numeric_dtypes, - 'round': numeric_dtypes, - 'sign': numeric_dtypes, - 'sin': float_dtypes, - 'sinh': float_dtypes, - 'sqrt': float_dtypes, - 'square': numeric_dtypes, - 'subtract': numeric_dtypes, - 'tan': float_dtypes, - 'tanh': float_dtypes, - 'trunc': numeric_dtypes, + "abs": numeric_dtypes, + "acos": float_dtypes, + "acosh": float_dtypes, + "add": numeric_dtypes, + "asin": float_dtypes, + "asinh": float_dtypes, + "atan": float_dtypes, + "atan2": float_dtypes, + "atanh": float_dtypes, + "bitwise_and": bool_and_all_int_dtypes, + "bitwise_invert": bool_and_all_int_dtypes, + "bitwise_left_shift": all_int_dtypes, + "bitwise_or": bool_and_all_int_dtypes, + "bitwise_right_shift": all_int_dtypes, + "bitwise_xor": bool_and_all_int_dtypes, + "ceil": numeric_dtypes, + "cos": float_dtypes, + "cosh": float_dtypes, + "divide": float_dtypes, + "equal": all_dtypes, + "exp": float_dtypes, + "expm1": float_dtypes, + "floor": numeric_dtypes, + "floor_divide": numeric_dtypes, + "greater": numeric_dtypes, + "greater_equal": numeric_dtypes, + "isfinite": numeric_dtypes, + "isinf": numeric_dtypes, + "isnan": numeric_dtypes, + "less": numeric_dtypes, + "less_equal": numeric_dtypes, + "log": float_dtypes, + "logaddexp": float_dtypes, + "log10": float_dtypes, + "log1p": float_dtypes, + "log2": float_dtypes, + "logical_and": (xp.bool,), + "logical_not": (xp.bool,), + "logical_or": (xp.bool,), + "logical_xor": (xp.bool,), + "multiply": numeric_dtypes, + "negative": numeric_dtypes, + "not_equal": all_dtypes, + "positive": numeric_dtypes, + "pow": float_dtypes, + "remainder": numeric_dtypes, + "round": numeric_dtypes, + "sign": numeric_dtypes, + "sin": float_dtypes, + "sinh": float_dtypes, + "sqrt": float_dtypes, + "square": numeric_dtypes, + "subtract": numeric_dtypes, + "tan": float_dtypes, + "tanh": float_dtypes, + "trunc": numeric_dtypes, # searching - 'where': all_dtypes, + "where": all_dtypes, } func_returns_bool = { # elementwise - 'abs': False, - 'acos': False, - 'acosh': False, - 'add': False, - 'asin': False, - 'asinh': False, - 'atan': False, - 'atan2': False, - 'atanh': False, - 'bitwise_and': False, - 'bitwise_invert': False, - 'bitwise_left_shift': False, - 'bitwise_or': False, - 'bitwise_right_shift': False, - 'bitwise_xor': False, - 'ceil': False, - 'cos': False, - 'cosh': False, - 'divide': False, - 'equal': True, - 'exp': False, - 'expm1': False, - 'floor': False, - 'floor_divide': False, - 'greater': True, - 'greater_equal': True, - 'isfinite': True, - 'isinf': True, - 'isnan': True, - 'less': True, - 'less_equal': True, - 'log': False, - 'logaddexp': False, - 'log10': False, - 'log1p': False, - 'log2': False, - 'logical_and': True, - 'logical_not': True, - 'logical_or': True, - 'logical_xor': True, - 'multiply': False, - 'negative': False, - 'not_equal': True, - 'positive': False, - 'pow': False, - 'remainder': False, - 'round': False, - 'sign': False, - 'sin': False, - 'sinh': False, - 'sqrt': False, - 'square': False, - 'subtract': False, - 'tan': False, - 'tanh': False, - 'trunc': False, + "abs": False, + "acos": False, + "acosh": False, + "add": False, + "asin": False, + "asinh": False, + "atan": False, + "atan2": False, + "atanh": False, + "bitwise_and": False, + "bitwise_invert": False, + "bitwise_left_shift": False, + "bitwise_or": False, + "bitwise_right_shift": False, + "bitwise_xor": False, + "ceil": False, + "cos": False, + "cosh": False, + "divide": False, + "equal": True, + "exp": False, + "expm1": False, + "floor": False, + "floor_divide": False, + "greater": True, + "greater_equal": True, + "isfinite": True, + "isinf": True, + "isnan": True, + "less": True, + "less_equal": True, + "log": False, + "logaddexp": False, + "log10": False, + "log1p": False, + "log2": False, + "logical_and": True, + "logical_not": True, + "logical_or": True, + "logical_xor": True, + "multiply": False, + "negative": False, + "not_equal": True, + "positive": False, + "pow": False, + "remainder": False, + "round": False, + "sign": False, + "sin": False, + "sinh": False, + "sqrt": False, + "square": False, + "subtract": False, + "tan": False, + "tanh": False, + "trunc": False, # searching - 'where': False, + "where": False, } unary_op_to_symbol = { - '__invert__': '~', - '__neg__': '-', - '__pos__': '+', + "__invert__": "~", + "__neg__": "-", + "__pos__": "+", } binary_op_to_symbol = { - '__add__': '+', - '__and__': '&', - '__eq__': '==', - '__floordiv__': '//', - '__ge__': '>=', - '__gt__': '>', - '__le__': '<=', - '__lshift__': '<<', - '__lt__': '<', - '__matmul__': '@', - '__mod__': '%', - '__mul__': '*', - '__ne__': '!=', - '__or__': '|', - '__pow__': '**', - '__rshift__': '>>', - '__sub__': '-', - '__truediv__': '/', - '__xor__': '^', + "__add__": "+", + "__and__": "&", + "__eq__": "==", + "__floordiv__": "//", + "__ge__": ">=", + "__gt__": ">", + "__le__": "<=", + "__lshift__": "<<", + "__lt__": "<", + "__matmul__": "@", + "__mod__": "%", + "__mul__": "*", + "__ne__": "!=", + "__or__": "|", + "__pow__": "**", + "__rshift__": ">>", + "__sub__": "-", + "__truediv__": "/", + "__xor__": "^", } op_to_func = { - '__abs__': 'abs', - '__add__': 'add', - '__and__': 'bitwise_and', - '__eq__': 'equal', - '__floordiv__': 'floor_divide', - '__ge__': 'greater_equal', - '__gt__': 'greater', - '__le__': 'less_equal', - '__lt__': 'less', + "__abs__": "abs", + "__add__": "add", + "__and__": "bitwise_and", + "__eq__": "equal", + "__floordiv__": "floor_divide", + "__ge__": "greater_equal", + "__gt__": "greater", + "__le__": "less_equal", + "__lt__": "less", # '__matmul__': 'matmul', # TODO: support matmul - '__mod__': 'remainder', - '__mul__': 'multiply', - '__ne__': 'not_equal', - '__or__': 'bitwise_or', - '__pow__': 'pow', - '__lshift__': 'bitwise_left_shift', - '__rshift__': 'bitwise_right_shift', - '__sub__': 'subtract', - '__truediv__': 'divide', - '__xor__': 'bitwise_xor', - '__invert__': 'bitwise_invert', - '__neg__': 'negative', - '__pos__': 'positive', + "__mod__": "remainder", + "__mul__": "multiply", + "__ne__": "not_equal", + "__or__": "bitwise_or", + "__pow__": "pow", + "__lshift__": "bitwise_left_shift", + "__rshift__": "bitwise_right_shift", + "__sub__": "subtract", + "__truediv__": "divide", + "__xor__": "bitwise_xor", + "__invert__": "bitwise_invert", + "__neg__": "negative", + "__pos__": "positive", } @@ -377,10 +380,10 @@ def result_type(*dtypes: DataType): inplace_op_to_symbol = {} for op, symbol in binary_op_to_symbol.items(): - if op == '__matmul__' or func_returns_bool[op]: + if op == "__matmul__" or func_returns_bool[op]: continue - iop = f'__i{op[2:]}' - inplace_op_to_symbol[iop] = f'{symbol}=' + iop = f"__i{op[2:]}" + inplace_op_to_symbol[iop] = f"{symbol}=" func_in_dtypes[iop] = func_in_dtypes[op] func_returns_bool[iop] = func_returns_bool[op] @@ -394,4 +397,4 @@ def fmt_types(types: Tuple[Union[DataType, ScalarType], ...]) -> str: except KeyError: # i.e. dtype is bool, int, or float f_types.append(type_.__name__) - return ', '.join(f_types) + return ", ".join(f_types) diff --git a/array_api_tests/test_statistical_functions.py b/array_api_tests/test_statistical_functions.py index 91aac683..d7ab381a 100644 --- a/array_api_tests/test_statistical_functions.py +++ b/array_api_tests/test_statistical_functions.py @@ -207,12 +207,16 @@ def test_prod(x, data): dtype = kw.get("dtype", None) if dtype is None: if dh.is_int_dtype(x.dtype): + if x.dtype in dh.uint_dtypes: + default_dtype = dh.default_uint + else: + default_dtype = dh.default_int m, M = dh.dtype_ranges[x.dtype] - d_m, d_M = dh.dtype_ranges[dh.default_int] + d_m, d_M = dh.dtype_ranges[default_dtype] if m < d_m or M > d_M: _dtype = x.dtype else: - _dtype = dh.default_int + _dtype = default_dtype else: if dh.dtype_nbits[x.dtype] > dh.dtype_nbits[dh.default_float]: _dtype = x.dtype @@ -333,12 +337,16 @@ def test_sum(x, data): dtype = kw.get("dtype", None) if dtype is None: if dh.is_int_dtype(x.dtype): + if x.dtype in dh.uint_dtypes: + default_dtype = dh.default_uint + else: + default_dtype = dh.default_int m, M = dh.dtype_ranges[x.dtype] - d_m, d_M = dh.dtype_ranges[dh.default_int] + d_m, d_M = dh.dtype_ranges[default_dtype] if m < d_m or M > d_M: _dtype = x.dtype else: - _dtype = dh.default_int + _dtype = default_dtype else: if dh.dtype_nbits[x.dtype] > dh.dtype_nbits[dh.default_float]: _dtype = x.dtype From 9f3a83e73ea9fc6120ed991798a108d439157f7c Mon Sep 17 00:00:00 2001 From: Matthew Date: Wed, 8 Dec 2021 09:56:44 +0000 Subject: [PATCH 34/60] Sort statistical tests by spec order --- array_api_tests/test_statistical_functions.py | 124 +++++++++--------- 1 file changed, 62 insertions(+), 62 deletions(-) diff --git a/array_api_tests/test_statistical_functions.py b/array_api_tests/test_statistical_functions.py index d7ab381a..5ec98fca 100644 --- a/array_api_tests/test_statistical_functions.py +++ b/array_api_tests/test_statistical_functions.py @@ -105,83 +105,83 @@ def assert_equals( ), data=st.data(), ) -def test_min(x, data): +def test_max(x, data): kw = data.draw(hh.kwargs(axis=axes(x.ndim), keepdims=st.booleans()), label="kw") - out = xp.min(x, **kw) + out = xp.max(x, **kw) - ph.assert_dtype("min", x.dtype, out.dtype) + ph.assert_dtype("max", x.dtype, out.dtype) _axes = normalise_axis(kw.get("axis", None), x.ndim) assert_keepdimable_shape( - "min", out.shape, x.shape, _axes, kw.get("keepdims", False), **kw + "max", out.shape, x.shape, _axes, kw.get("keepdims", False), **kw ) scalar_type = dh.get_scalar_type(out.dtype) for indices, out_idx in zip(axes_ndindex(x.shape, _axes), ah.ndindex(out.shape)): - min_ = scalar_type(out[out_idx]) + max_ = scalar_type(out[out_idx]) elements = [] for idx in indices: s = scalar_type(x[idx]) elements.append(s) - expected = min(elements) - assert_equals("min", dh.get_scalar_type(out.dtype), out_idx, min_, expected) + expected = max(elements) + assert_equals("max", dh.get_scalar_type(out.dtype), out_idx, max_, expected) @given( x=xps.arrays( - dtype=xps.numeric_dtypes(), + dtype=xps.floating_dtypes(), shape=hh.shapes(min_side=1), elements={"allow_nan": False}, ), data=st.data(), ) -def test_max(x, data): +def test_mean(x, data): kw = data.draw(hh.kwargs(axis=axes(x.ndim), keepdims=st.booleans()), label="kw") - out = xp.max(x, **kw) + out = xp.mean(x, **kw) - ph.assert_dtype("max", x.dtype, out.dtype) + ph.assert_dtype("mean", x.dtype, out.dtype) _axes = normalise_axis(kw.get("axis", None), x.ndim) assert_keepdimable_shape( - "max", out.shape, x.shape, _axes, kw.get("keepdims", False), **kw + "mean", out.shape, x.shape, _axes, kw.get("keepdims", False), **kw ) - scalar_type = dh.get_scalar_type(out.dtype) for indices, out_idx in zip(axes_ndindex(x.shape, _axes), ah.ndindex(out.shape)): - max_ = scalar_type(out[out_idx]) + mean = float(out[out_idx]) + assume(not math.isinf(mean)) # mean may become inf due to internal overflows elements = [] for idx in indices: - s = scalar_type(x[idx]) + s = float(x[idx]) elements.append(s) - expected = max(elements) - assert_equals("max", dh.get_scalar_type(out.dtype), out_idx, max_, expected) + expected = sum(elements) / len(elements) + assert_equals("mean", dh.get_scalar_type(out.dtype), out_idx, mean, expected) @given( x=xps.arrays( - dtype=xps.floating_dtypes(), + dtype=xps.numeric_dtypes(), shape=hh.shapes(min_side=1), elements={"allow_nan": False}, ), data=st.data(), ) -def test_mean(x, data): +def test_min(x, data): kw = data.draw(hh.kwargs(axis=axes(x.ndim), keepdims=st.booleans()), label="kw") - out = xp.mean(x, **kw) + out = xp.min(x, **kw) - ph.assert_dtype("mean", x.dtype, out.dtype) + ph.assert_dtype("min", x.dtype, out.dtype) _axes = normalise_axis(kw.get("axis", None), x.ndim) assert_keepdimable_shape( - "mean", out.shape, x.shape, _axes, kw.get("keepdims", False), **kw + "min", out.shape, x.shape, _axes, kw.get("keepdims", False), **kw ) + scalar_type = dh.get_scalar_type(out.dtype) for indices, out_idx in zip(axes_ndindex(x.shape, _axes), ah.ndindex(out.shape)): - mean = float(out[out_idx]) - assume(not math.isinf(mean)) # mean may become inf due to internal overflows + min_ = scalar_type(out[out_idx]) elements = [] for idx in indices: - s = float(x[idx]) + s = scalar_type(x[idx]) elements.append(s) - expected = sum(elements) / len(elements) - assert_equals("mean", dh.get_scalar_type(out.dtype), out_idx, mean, expected) + expected = min(elements) + assert_equals("min", dh.get_scalar_type(out.dtype), out_idx, min_, expected) @given( @@ -279,41 +279,6 @@ def test_std(x, data): # We can't easily test the result(s) as standard deviation methods vary a lot -@given( - x=xps.arrays( - dtype=xps.floating_dtypes(), - shape=hh.shapes(min_side=1), - elements={"allow_nan": False}, - ).filter(lambda x: x.size >= 2), - data=st.data(), -) -def test_var(x, data): - axis = data.draw(axes(x.ndim), label="axis") - _axes = normalise_axis(axis, x.ndim) - N = sum(side for axis, side in enumerate(x.shape) if axis not in _axes) - correction = data.draw( - st.floats(0.0, N, allow_infinity=False, allow_nan=False) | st.integers(0, N), - label="correction", - ) - keepdims = data.draw(st.booleans(), label="keepdims") - kw = data.draw( - hh.specified_kwargs( - ("axis", axis, None), - ("correction", correction, 0.0), - ("keepdims", keepdims, False), - ), - label="kw", - ) - - out = xp.var(x, **kw) - - ph.assert_dtype("var", x.dtype, out.dtype) - assert_keepdimable_shape( - "var", out.shape, x.shape, _axes, kw.get("keepdims", False), **kw - ) - # We can't easily test the result(s) as variance methods vary a lot - - @given( x=xps.arrays( dtype=xps.numeric_dtypes(), @@ -372,3 +337,38 @@ def test_sum(x, data): m, M = dh.dtype_ranges[out.dtype] assume(m <= expected <= M) assert_equals("sum", dh.get_scalar_type(out.dtype), out_idx, sum_, expected) + + +@given( + x=xps.arrays( + dtype=xps.floating_dtypes(), + shape=hh.shapes(min_side=1), + elements={"allow_nan": False}, + ).filter(lambda x: x.size >= 2), + data=st.data(), +) +def test_var(x, data): + axis = data.draw(axes(x.ndim), label="axis") + _axes = normalise_axis(axis, x.ndim) + N = sum(side for axis, side in enumerate(x.shape) if axis not in _axes) + correction = data.draw( + st.floats(0.0, N, allow_infinity=False, allow_nan=False) | st.integers(0, N), + label="correction", + ) + keepdims = data.draw(st.booleans(), label="keepdims") + kw = data.draw( + hh.specified_kwargs( + ("axis", axis, None), + ("correction", correction, 0.0), + ("keepdims", keepdims, False), + ), + label="kw", + ) + + out = xp.var(x, **kw) + + ph.assert_dtype("var", x.dtype, out.dtype) + assert_keepdimable_shape( + "var", out.shape, x.shape, _axes, kw.get("keepdims", False), **kw + ) + # We can't easily test the result(s) as variance methods vary a lot From 09aa26b831e66abf02322604949d35af5b2c4256 Mon Sep 17 00:00:00 2001 From: Matthew Date: Wed, 8 Dec 2021 10:28:10 +0000 Subject: [PATCH 35/60] Check error raising in `test_squeeze`, use negative axes --- .../test_manipulation_functions.py | 22 +++++++++++-------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/array_api_tests/test_manipulation_functions.py b/array_api_tests/test_manipulation_functions.py index 49f86694..47ed8576 100644 --- a/array_api_tests/test_manipulation_functions.py +++ b/array_api_tests/test_manipulation_functions.py @@ -3,6 +3,7 @@ from itertools import product from typing import Iterable, Iterator, Tuple, Union +import pytest from hypothesis import assume, given from hypothesis import strategies as st @@ -168,23 +169,26 @@ def test_expand_dims(x, axis): data=st.data(), ) def test_squeeze(x, data): - # TODO: generate valid negative axis (which keep uniqueness) - squeezable_axes = st.sampled_from( - [i for i, side in enumerate(x.shape) if side == 1] - ) + axes = st.integers(-x.ndim, x.ndim - 1) axis = data.draw( - squeezable_axes | st.lists(squeezable_axes, unique=True).map(tuple), + axes + | st.lists(axes, unique_by=lambda i: i if i >= 0 else i + x.ndim).map(tuple), label="axis", ) + axes = (axis,) if isinstance(axis, int) else axis + axes = normalise_axis(axes, x.ndim) + + squeezable_axes = [i for i, side in enumerate(x.shape) if side == 1] + if any(i not in squeezable_axes for i in axes): + with pytest.raises(ValueError): + xp.squeeze(x, axis) + return + out = xp.squeeze(x, axis) ph.assert_dtype("squeeze", x.dtype, out.dtype) - if isinstance(axis, int): - axes = (axis,) - else: - axes = axis shape = [] for i, side in enumerate(x.shape): if i not in axes: From 18709f6310cfd9d7afcdf9c93fcdd68845db3b96 Mon Sep 17 00:00:00 2001 From: Matthew Date: Wed, 8 Dec 2021 10:46:32 +0000 Subject: [PATCH 36/60] Cover invalid axis in `test_expand_dims` --- array_api_tests/test_manipulation_functions.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/array_api_tests/test_manipulation_functions.py b/array_api_tests/test_manipulation_functions.py index 47ed8576..9e33567a 100644 --- a/array_api_tests/test_manipulation_functions.py +++ b/array_api_tests/test_manipulation_functions.py @@ -144,9 +144,17 @@ def test_concat(dtypes, kw, data): @given( x=xps.arrays(dtype=xps.scalar_dtypes(), shape=shared_shapes()), - axis=shared_shapes().flatmap(lambda s: st.integers(-len(s) - 1, len(s))), + axis=shared_shapes().flatmap( + # Generate both valid and invalid axis + lambda s: st.integers(2 * (-len(s) - 1), 2 * len(s)) + ), ) def test_expand_dims(x, axis): + if axis < -x.ndim - 1 or axis > x.ndim: + with pytest.raises(IndexError): + xp.expand_dims(x, axis=axis) + return + out = xp.expand_dims(x, axis=axis) ph.assert_dtype("expand_dims", x.dtype, out.dtype) From 2ad6ddcb4dca8a69d4cfe7db8e57fa7c2131baa6 Mon Sep 17 00:00:00 2001 From: Matthew Date: Wed, 8 Dec 2021 11:05:32 +0000 Subject: [PATCH 37/60] Try to ignore overflow scenarios in `prod` and `sum` tests --- array_api_tests/test_statistical_functions.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/array_api_tests/test_statistical_functions.py b/array_api_tests/test_statistical_functions.py index 5ec98fca..24067108 100644 --- a/array_api_tests/test_statistical_functions.py +++ b/array_api_tests/test_statistical_functions.py @@ -4,6 +4,7 @@ from hypothesis import assume, given from hypothesis import strategies as st +from hypothesis.control import reject from . import _array_module as xp from . import array_helpers as ah @@ -202,7 +203,10 @@ def test_prod(x, data): label="kw", ) - out = xp.prod(x, **kw) + try: + out = xp.prod(x, **kw) + except OverflowError: + reject() dtype = kw.get("dtype", None) if dtype is None: @@ -232,7 +236,7 @@ def test_prod(x, data): scalar_type = dh.get_scalar_type(out.dtype) for indices, out_idx in zip(axes_ndindex(x.shape, _axes), ah.ndindex(out.shape)): prod = scalar_type(out[out_idx]) - assume(not math.isinf(prod)) + assume(math.isfinite(prod)) elements = [] for idx in indices: s = scalar_type(x[idx]) @@ -297,7 +301,10 @@ def test_sum(x, data): label="kw", ) - out = xp.sum(x, **kw) + try: + out = xp.sum(x, **kw) + except OverflowError: + reject() dtype = kw.get("dtype", None) if dtype is None: @@ -327,7 +334,7 @@ def test_sum(x, data): scalar_type = dh.get_scalar_type(out.dtype) for indices, out_idx in zip(axes_ndindex(x.shape, _axes), ah.ndindex(out.shape)): sum_ = scalar_type(out[out_idx]) - assume(not math.isinf(sum_)) + assume(math.isfinite(sum_)) elements = [] for idx in indices: s = scalar_type(x[idx]) From 6e4564b618e38400c3f50cc1bb47fb20c52ea606 Mon Sep 17 00:00:00 2001 From: Matthew Date: Wed, 8 Dec 2021 11:15:54 +0000 Subject: [PATCH 38/60] Docstrings for axes helpers --- array_api_tests/test_manipulation_functions.py | 3 ++- array_api_tests/test_statistical_functions.py | 17 +++++++++-------- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/array_api_tests/test_manipulation_functions.py b/array_api_tests/test_manipulation_functions.py index 9e33567a..a2e2e2a3 100644 --- a/array_api_tests/test_manipulation_functions.py +++ b/array_api_tests/test_manipulation_functions.py @@ -13,7 +13,7 @@ from . import hypothesis_helpers as hh from . import pytest_helpers as ph from . import xps -from .test_statistical_functions import axes_ndindex, normalise_axis # TODO: Move +from .test_statistical_functions import axes_ndindex, normalise_axis from .typing import Array, Shape MAX_SIDE = hh.MAX_ARRAY_SIZE // 64 @@ -32,6 +32,7 @@ def shared_shapes(*args, **kwargs) -> st.SearchStrategy[Shape]: def axis_ndindex( shape: Shape, axis: int ) -> Iterator[Tuple[Tuple[Union[int, slice], ...], ...]]: + """Generate indices that index all elements in dimensions beyond `axis`""" assert axis >= 0 # sanity check axis_indices = [range(side) for side in shape[:axis]] for _ in range(axis, len(shape)): diff --git a/array_api_tests/test_statistical_functions.py b/array_api_tests/test_statistical_functions.py index 24067108..c62be30f 100644 --- a/array_api_tests/test_statistical_functions.py +++ b/array_api_tests/test_statistical_functions.py @@ -39,18 +39,19 @@ def normalise_axis( def axes_ndindex(shape: Shape, axes: Tuple[int, ...]) -> Iterator[Tuple[Shape, ...]]: - base_iterables = [] - axes_iterables = [] + """Generate indices that index all elements except in `axes` dimensions""" + base_indices = [] + axes_indices = [] for axis, side in enumerate(shape): if axis in axes: - base_iterables.append([None]) - axes_iterables.append(range(side)) + base_indices.append([None]) + axes_indices.append(range(side)) else: - base_iterables.append(range(side)) - axes_iterables.append([None]) - for base_idx in product(*base_iterables): + base_indices.append(range(side)) + axes_indices.append([None]) + for base_idx in product(*base_indices): indices = [] - for idx in product(*axes_iterables): + for idx in product(*axes_indices): idx = list(idx) for axis, side in enumerate(idx): if axis not in axes: From 0326aa395c2ec75d4133962fb877f5956e538504 Mon Sep 17 00:00:00 2001 From: Matthew Date: Wed, 8 Dec 2021 18:56:28 +0000 Subject: [PATCH 39/60] Remove redundant calls to `dh.get_scalar_type()` --- array_api_tests/test_statistical_functions.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/array_api_tests/test_statistical_functions.py b/array_api_tests/test_statistical_functions.py index c62be30f..7813d2ea 100644 --- a/array_api_tests/test_statistical_functions.py +++ b/array_api_tests/test_statistical_functions.py @@ -125,7 +125,7 @@ def test_max(x, data): s = scalar_type(x[idx]) elements.append(s) expected = max(elements) - assert_equals("max", dh.get_scalar_type(out.dtype), out_idx, max_, expected) + assert_equals("max", scalar_type, out_idx, max_, expected) @given( @@ -154,7 +154,7 @@ def test_mean(x, data): s = float(x[idx]) elements.append(s) expected = sum(elements) / len(elements) - assert_equals("mean", dh.get_scalar_type(out.dtype), out_idx, mean, expected) + assert_equals("mean", float, out_idx, mean, expected) @given( @@ -183,7 +183,7 @@ def test_min(x, data): s = scalar_type(x[idx]) elements.append(s) expected = min(elements) - assert_equals("min", dh.get_scalar_type(out.dtype), out_idx, min_, expected) + assert_equals("min", scalar_type, out_idx, min_, expected) @given( @@ -246,7 +246,7 @@ def test_prod(x, data): if dh.is_int_dtype(out.dtype): m, M = dh.dtype_ranges[out.dtype] assume(m <= expected <= M) - assert_equals("prod", dh.get_scalar_type(out.dtype), out_idx, prod, expected) + assert_equals("prod", scalar_type, out_idx, prod, expected) @given( @@ -344,7 +344,7 @@ def test_sum(x, data): if dh.is_int_dtype(out.dtype): m, M = dh.dtype_ranges[out.dtype] assume(m <= expected <= M) - assert_equals("sum", dh.get_scalar_type(out.dtype), out_idx, sum_, expected) + assert_equals("sum", scalar_type, out_idx, sum_, expected) @given( From cd8f1172b2de845b546312d750f05cba39cfacf9 Mon Sep 17 00:00:00 2001 From: Matthew Date: Fri, 10 Dec 2021 19:30:58 +0000 Subject: [PATCH 40/60] Fix `test_manipulation_functions.assert_equals` --- array_api_tests/test_manipulation_functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/array_api_tests/test_manipulation_functions.py b/array_api_tests/test_manipulation_functions.py index a2e2e2a3..1e95ca10 100644 --- a/array_api_tests/test_manipulation_functions.py +++ b/array_api_tests/test_manipulation_functions.py @@ -67,7 +67,7 @@ def assert_equals( if dh.is_float_dtype(out_val.dtype) and xp.isnan(out_val): assert xp.isnan(x_val), msg else: - assert out_val == out_val, msg + assert x_val == out_val, msg @st.composite From af285de6675fe4cdfabce6a610e16c6c7542c37a Mon Sep 17 00:00:00 2001 From: Matthew Date: Fri, 10 Dec 2021 19:44:00 +0000 Subject: [PATCH 41/60] Skip `test_roll` as its wrong --- array_api_tests/test_manipulation_functions.py | 1 + 1 file changed, 1 insertion(+) diff --git a/array_api_tests/test_manipulation_functions.py b/array_api_tests/test_manipulation_functions.py index 1e95ca10..9453fc25 100644 --- a/array_api_tests/test_manipulation_functions.py +++ b/array_api_tests/test_manipulation_functions.py @@ -292,6 +292,7 @@ def test_reshape(x, data): assert_array_ndindex("reshape", x, ah.ndindex(x.shape), out, ah.ndindex(out.shape)) +@pytest.mark.skip(reason="faulty test logic") # TODO @given(xps.arrays(dtype=xps.scalar_dtypes(), shape=shared_shapes()), st.data()) def test_roll(x, data): shift_strat = st.integers(-hh.MAX_ARRAY_SIZE, hh.MAX_ARRAY_SIZE) From 83a6f5e3f28656ba1f16c3338536b73eff176752 Mon Sep 17 00:00:00 2001 From: Matthew Date: Fri, 10 Dec 2021 10:21:24 +0000 Subject: [PATCH 42/60] Cover everything for `argmin` and `argmax` tests --- array_api_tests/test_searching_functions.py | 95 ++++++++++++++++++--- 1 file changed, 85 insertions(+), 10 deletions(-) diff --git a/array_api_tests/test_searching_functions.py b/array_api_tests/test_searching_functions.py index c3686bb7..4672c173 100644 --- a/array_api_tests/test_searching_functions.py +++ b/array_api_tests/test_searching_functions.py @@ -1,23 +1,98 @@ from hypothesis import given from hypothesis import strategies as st +from array_api_tests.test_statistical_functions import ( + assert_equals, + assert_keepdimable_shape, + axes_ndindex, + normalise_axis, +) +from array_api_tests.typing import DataType + from . import _array_module as xp +from . import array_helpers as ah +from . import dtype_helpers as dh from . import hypothesis_helpers as hh from . import xps -# TODO: generate kwargs -@given(xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes(min_side=1))) -def test_argmin(x): - xp.argmin(x) - # TODO +def assert_default_index(func_name: str, dtype: DataType): + f_dtype = dh.dtype_to_name[dtype] + msg = ( + f"out.dtype={f_dtype}, should be the default index dtype, " + f"which is either int32 or int64 [{func_name}()]" + ) + assert dtype in (xp.int32, xp.int64), msg -# TODO: generate kwargs -@given(xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes(min_side=1))) -def test_argmax(x): - xp.argmax(x) - # TODO +@given( + x=xps.arrays( + dtype=xps.numeric_dtypes(), + shape=hh.shapes(min_side=1), + elements={"allow_nan": False}, + ), + data=st.data(), +) +def test_argmax(x, data): + kw = data.draw( + hh.kwargs( + axis=st.none() | st.integers(-x.ndim, max(x.ndim - 1, 0)), + keepdims=st.booleans(), + ), + label="kw", + ) + + out = xp.argmax(x, **kw) + + assert_default_index("argmax", out.dtype) + axes = normalise_axis(kw.get("axis", None), x.ndim) + assert_keepdimable_shape( + "argmax", out.shape, x.shape, axes, kw.get("keepdims", False), **kw + ) + scalar_type = dh.get_scalar_type(x.dtype) + for indices, out_idx in zip(axes_ndindex(x.shape, axes), ah.ndindex(out.shape)): + max_i = int(out[out_idx]) + elements = [] + for idx in indices: + s = scalar_type(x[idx]) + elements.append(s) + expected = max(range(len(elements)), key=elements.__getitem__) + assert_equals("argmax", int, out_idx, max_i, expected) + + +@given( + x=xps.arrays( + dtype=xps.numeric_dtypes(), + shape=hh.shapes(min_side=1), + elements={"allow_nan": False}, + ), + data=st.data(), +) +def test_argmin(x, data): + kw = data.draw( + hh.kwargs( + axis=st.none() | st.integers(-x.ndim, max(x.ndim - 1, 0)), + keepdims=st.booleans(), + ), + label="kw", + ) + + out = xp.argmin(x, **kw) + + assert_default_index("argmin", out.dtype) + axes = normalise_axis(kw.get("axis", None), x.ndim) + assert_keepdimable_shape( + "argmin", out.shape, x.shape, axes, kw.get("keepdims", False), **kw + ) + scalar_type = dh.get_scalar_type(x.dtype) + for indices, out_idx in zip(axes_ndindex(x.shape, axes), ah.ndindex(out.shape)): + min_i = int(out[out_idx]) + elements = [] + for idx in indices: + s = scalar_type(x[idx]) + elements.append(s) + expected = min(range(len(elements)), key=elements.__getitem__) + assert_equals("argmin", int, out_idx, min_i, expected) # TODO: generate kwargs, skip if opted out From c8f40072c908fea3d2b69ca091d98263cbbee8af Mon Sep 17 00:00:00 2001 From: Matthew Date: Fri, 10 Dec 2021 12:04:06 +0000 Subject: [PATCH 43/60] Cover everything for `test_nonzero` --- array_api_tests/test_searching_functions.py | 41 ++++++++++++++++++--- 1 file changed, 36 insertions(+), 5 deletions(-) diff --git a/array_api_tests/test_searching_functions.py b/array_api_tests/test_searching_functions.py index 4672c173..f9cb986c 100644 --- a/array_api_tests/test_searching_functions.py +++ b/array_api_tests/test_searching_functions.py @@ -16,10 +16,10 @@ from . import xps -def assert_default_index(func_name: str, dtype: DataType): +def assert_default_index(func_name: str, dtype: DataType, repr_name="out.dtype"): f_dtype = dh.dtype_to_name[dtype] msg = ( - f"out.dtype={f_dtype}, should be the default index dtype, " + f"{repr_name}={f_dtype}, should be the default index dtype, " f"which is either int32 or int64 [{func_name}()]" ) assert dtype in (xp.int32, xp.int64), msg @@ -95,11 +95,42 @@ def test_argmin(x, data): assert_equals("argmin", int, out_idx, min_i, expected) -# TODO: generate kwargs, skip if opted out @given(xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes(min_side=1))) def test_nonzero(x): - xp.nonzero(x) - # TODO + out = xp.nonzero(x) + if x.ndim == 0: + assert len(out) == 1, f"{len(out)=}, but should be 1 for 0-dimensional arrays" + else: + assert len(out) == x.ndim, f"{len(out)=}, but should be {x.ndim=}" + size = out[0].size + for i in range(len(out)): + assert out[i].ndim == 1, f"out[{i}].ndim={x.ndim}, but should be 1" + assert ( + out[i].size == size + ), f"out[{i}].size={x.size}, but should be out[0].size={size}" + assert_default_index("nonzero", out[i].dtype, repr_name=f"out[{i}].dtype") + indices = [] + if x.dtype == xp.bool: + for idx in ah.ndindex(x.shape): + if x[idx]: + indices.append(idx) + else: + for idx in ah.ndindex(x.shape): + if x[idx] != 0: + indices.append(idx) + if x.ndim == 0: + assert out[0].size == len( + indices + ), f"{out[0].size=}, but should be {len(indices)}" + else: + for i in range(size): + idx = tuple(int(x[i]) for x in out) + f_idx = f"Extrapolated index (x[{i}] for x in out)={idx}" + f_element = f"x[{idx}]={x[idx]}" + assert idx in indices, f"{f_idx} results in {f_element}, a zero element" + assert ( + idx == indices[i] + ), f"{f_idx} is in the wrong position, should be {indices.index(idx)}" # TODO: skip if opted out From d6c4fc661171f41489ab49f4f8d17980bd747c0f Mon Sep 17 00:00:00 2001 From: Matthew Date: Fri, 10 Dec 2021 12:46:00 +0000 Subject: [PATCH 44/60] Cover everything for `test_where` --- array_api_tests/test_searching_functions.py | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/array_api_tests/test_searching_functions.py b/array_api_tests/test_searching_functions.py index f9cb986c..6003ccd4 100644 --- a/array_api_tests/test_searching_functions.py +++ b/array_api_tests/test_searching_functions.py @@ -1,6 +1,8 @@ from hypothesis import given from hypothesis import strategies as st +from array_api_tests.algos import broadcast_shapes +from array_api_tests.test_manipulation_functions import assert_equals as assert_equals_ from array_api_tests.test_statistical_functions import ( assert_equals, assert_keepdimable_shape, @@ -13,6 +15,7 @@ from . import array_helpers as ah from . import dtype_helpers as dh from . import hypothesis_helpers as hh +from . import pytest_helpers as ph from . import xps @@ -95,6 +98,7 @@ def test_argmin(x, data): assert_equals("argmin", int, out_idx, min_i, expected) +# TODO: skip if opted out @given(xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes(min_side=1))) def test_nonzero(x): out = xp.nonzero(x) @@ -133,7 +137,6 @@ def test_nonzero(x): ), f"{f_idx} is in the wrong position, should be {indices.index(idx)}" -# TODO: skip if opted out @given( shapes=hh.mutually_broadcastable_shapes(3), dtypes=hh.mutually_promotable_dtypes(), @@ -143,5 +146,17 @@ def test_where(shapes, dtypes, data): cond = data.draw(xps.arrays(dtype=xp.bool, shape=shapes[0]), label="condition") x1 = data.draw(xps.arrays(dtype=dtypes[0], shape=shapes[1]), label="x1") x2 = data.draw(xps.arrays(dtype=dtypes[1], shape=shapes[2]), label="x2") - xp.where(cond, x1, x2) - # TODO + + out = xp.where(cond, x1, x2) + + shape = broadcast_shapes(*shapes) + ph.assert_shape("where", out.shape, shape) + # TODO: generate indices without broadcasting arrays + _cond = xp.broadcast_to(cond, shape) + _x1 = xp.broadcast_to(x1, shape) + _x2 = xp.broadcast_to(x2, shape) + for idx in ah.ndindex(shape): + if _cond[idx]: + assert_equals_("where", f"_x1[{idx}]", _x1[idx], f"out[{idx}]", out[idx]) + else: + assert_equals_("where", f"_x2[{idx}]", _x2[idx], f"out[{idx}]", out[idx]) From f108941b76ab3e019a9935191572efa6784645a1 Mon Sep 17 00:00:00 2001 From: Matthew Date: Fri, 10 Dec 2021 19:38:04 +0000 Subject: [PATCH 45/60] Cover most things in `test_sort` --- array_api_tests/test_sorting.py | 60 ++++++++++++++++++++++++++++++--- 1 file changed, 55 insertions(+), 5 deletions(-) diff --git a/array_api_tests/test_sorting.py b/array_api_tests/test_sorting.py index 58179b3c..473d70ad 100644 --- a/array_api_tests/test_sorting.py +++ b/array_api_tests/test_sorting.py @@ -1,8 +1,14 @@ from hypothesis import given +from hypothesis import strategies as st +from hypothesis.control import assume from . import _array_module as xp +from . import array_helpers as ah +from . import dtype_helpers as dh from . import hypothesis_helpers as hh +from . import pytest_helpers as ph from . import xps +from .test_manipulation_functions import assert_equals, axis_ndindex # TODO: generate kwargs @@ -12,8 +18,52 @@ def test_argsort(x): # TODO -# TODO: generate 0d arrays, generate kwargs -@given(xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes(min_dims=1))) -def test_sort(x): - xp.sort(x) - # TODO +# TODO: Test with signed zeros and NaNs (and ignore them somehow) +@given( + x=xps.arrays( + dtype=xps.scalar_dtypes(), + shape=hh.shapes(min_dims=1, min_side=1), + elements={"allow_nan": False}, + ), + data=st.data(), +) +def test_sort(x, data): + if dh.is_float_dtype(x.dtype): + assume(not xp.any(x == -0.0) and not xp.any(x == +0.0)) + + kw = data.draw( + hh.kwargs( + axis=st.integers(-x.ndim, x.ndim - 1), + descending=st.booleans(), + stable=st.booleans(), + ), + label="kw", + ) + + out = xp.sort(x, **kw) + + ph.assert_dtype("sort", out.dtype, x.dtype) + ph.assert_shape("sort", out.shape, x.shape, **kw) + axis = kw.get("axis", -1) + _axis = axis if axis >= 0 else x.ndim + axis + descending = kw.get("descending", False) + scalar_type = dh.get_scalar_type(x.dtype) + for idx in axis_ndindex(x.shape, _axis): + f_idx = ", ".join(str(i) if isinstance(i, int) else ":" for i in idx) + indexed_x = x[idx] + indexed_out = out[idx] + out_indices = list(ah.ndindex(indexed_x.shape)) + elements = [scalar_type(indexed_x[idx2]) for idx2 in out_indices] + indices_order = sorted( + range(len(out_indices)), key=elements.__getitem__, reverse=descending + ) + x_indices = [out_indices[o] for o in indices_order] + for out_idx, x_idx in zip(out_indices, x_indices): + assert_equals( + "sort", + f"x[{f_idx}][{x_idx}]", + indexed_x[x_idx], + f"out[{f_idx}][{out_idx}]", + indexed_out[out_idx], + **kw, + ) From 5b44997b64ac832e85a95f93a889bd48444c5679 Mon Sep 17 00:00:00 2001 From: Matthew Date: Mon, 13 Dec 2021 11:03:03 +0000 Subject: [PATCH 46/60] Fix `test_sort` using wrong axis iteration --- array_api_tests/meta/test_utils.py | 7 +++-- array_api_tests/test_sorting.py | 28 ++++++++----------- array_api_tests/test_statistical_functions.py | 6 ++-- 3 files changed, 19 insertions(+), 22 deletions(-) diff --git a/array_api_tests/meta/test_utils.py b/array_api_tests/meta/test_utils.py index 34fa1836..83a4b75d 100644 --- a/array_api_tests/meta/test_utils.py +++ b/array_api_tests/meta/test_utils.py @@ -56,13 +56,14 @@ def test_axis_ndindex(shape, axis, expected): @pytest.mark.parametrize( "shape, axes, expected", [ - ((), (), [((),)]), + ((), (), [[()]]), + ((1,), (0,), [[(0,)]]), ( (2, 2), (0,), [ - ((0, 0), (1, 0)), - ((0, 1), (1, 1)), + [(0, 0), (1, 0)], + [(0, 1), (1, 1)], ], ), ], diff --git a/array_api_tests/test_sorting.py b/array_api_tests/test_sorting.py index 473d70ad..4345c64a 100644 --- a/array_api_tests/test_sorting.py +++ b/array_api_tests/test_sorting.py @@ -3,12 +3,12 @@ from hypothesis.control import assume from . import _array_module as xp -from . import array_helpers as ah from . import dtype_helpers as dh from . import hypothesis_helpers as hh from . import pytest_helpers as ph from . import xps -from .test_manipulation_functions import assert_equals, axis_ndindex +from .test_manipulation_functions import assert_equals +from .test_statistical_functions import axes_ndindex, normalise_axis # TODO: generate kwargs @@ -45,25 +45,21 @@ def test_sort(x, data): ph.assert_dtype("sort", out.dtype, x.dtype) ph.assert_shape("sort", out.shape, x.shape, **kw) axis = kw.get("axis", -1) - _axis = axis if axis >= 0 else x.ndim + axis + axes = normalise_axis(axis, x.ndim) descending = kw.get("descending", False) scalar_type = dh.get_scalar_type(x.dtype) - for idx in axis_ndindex(x.shape, _axis): - f_idx = ", ".join(str(i) if isinstance(i, int) else ":" for i in idx) - indexed_x = x[idx] - indexed_out = out[idx] - out_indices = list(ah.ndindex(indexed_x.shape)) - elements = [scalar_type(indexed_x[idx2]) for idx2 in out_indices] + for indices in axes_ndindex(x.shape, axes): + elements = [scalar_type(x[idx]) for idx in indices] indices_order = sorted( - range(len(out_indices)), key=elements.__getitem__, reverse=descending + range(len(indices)), key=elements.__getitem__, reverse=descending ) - x_indices = [out_indices[o] for o in indices_order] - for out_idx, x_idx in zip(out_indices, x_indices): + x_indices = [indices[o] for o in indices_order] + for out_idx, x_idx in zip(indices, x_indices): assert_equals( "sort", - f"x[{f_idx}][{x_idx}]", - indexed_x[x_idx], - f"out[{f_idx}][{out_idx}]", - indexed_out[out_idx], + f"x[{x_idx}]", + x[x_idx], + f"out[{out_idx}]", + out[out_idx], **kw, ) diff --git a/array_api_tests/test_statistical_functions.py b/array_api_tests/test_statistical_functions.py index 7813d2ea..81e11558 100644 --- a/array_api_tests/test_statistical_functions.py +++ b/array_api_tests/test_statistical_functions.py @@ -1,6 +1,6 @@ import math from itertools import product -from typing import Iterator, Optional, Tuple, Union +from typing import Iterator, List, Optional, Tuple, Union from hypothesis import assume, given from hypothesis import strategies as st @@ -38,7 +38,7 @@ def normalise_axis( return axes -def axes_ndindex(shape: Shape, axes: Tuple[int, ...]) -> Iterator[Tuple[Shape, ...]]: +def axes_ndindex(shape: Shape, axes: Tuple[int, ...]) -> Iterator[List[Shape]]: """Generate indices that index all elements except in `axes` dimensions""" base_indices = [] axes_indices = [] @@ -58,7 +58,7 @@ def axes_ndindex(shape: Shape, axes: Tuple[int, ...]) -> Iterator[Tuple[Shape, . idx[axis] = base_idx[axis] idx = tuple(idx) indices.append(idx) - yield tuple(indices) + yield list(indices) def assert_keepdimable_shape( From c0a47fd87460294198997a2533e3b21b53baeb75 Mon Sep 17 00:00:00 2001 From: Matthew Date: Mon, 13 Dec 2021 11:26:18 +0000 Subject: [PATCH 47/60] Cover most things for `test_argsort` --- array_api_tests/test_sorting.py | 50 +++++++++++++++++++++++++++------ 1 file changed, 42 insertions(+), 8 deletions(-) diff --git a/array_api_tests/test_sorting.py b/array_api_tests/test_sorting.py index 4345c64a..14f76ff9 100644 --- a/array_api_tests/test_sorting.py +++ b/array_api_tests/test_sorting.py @@ -7,15 +7,49 @@ from . import hypothesis_helpers as hh from . import pytest_helpers as ph from . import xps -from .test_manipulation_functions import assert_equals -from .test_statistical_functions import axes_ndindex, normalise_axis +from .test_manipulation_functions import assert_equals as assert_equals_ +from .test_searching_functions import assert_default_index +from .test_statistical_functions import assert_equals, axes_ndindex, normalise_axis -# TODO: generate kwargs -@given(xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes())) -def test_argsort(x): - xp.argsort(x) - # TODO +# TODO: Test with signed zeros and NaNs (and ignore them somehow) +@given( + x=xps.arrays( + dtype=xps.scalar_dtypes(), + shape=hh.shapes(min_dims=1, min_side=1), + elements={"allow_nan": False}, + ), + data=st.data(), +) +def test_argsort(x, data): + if dh.is_float_dtype(x.dtype): + assume(not xp.any(x == -0.0) and not xp.any(x == +0.0)) + + kw = data.draw( + hh.kwargs( + axis=st.integers(-x.ndim, x.ndim - 1), + descending=st.booleans(), + stable=st.booleans(), + ), + label="kw", + ) + + out = xp.argsort(x, **kw) + + assert_default_index("sort", out.dtype) + ph.assert_shape("sort", out.shape, x.shape, **kw) + axis = kw.get("axis", -1) + axes = normalise_axis(axis, x.ndim) + descending = kw.get("descending", False) + scalar_type = dh.get_scalar_type(x.dtype) + for indices in axes_ndindex(x.shape, axes): + elements = [scalar_type(x[idx]) for idx in indices] + indices_order = sorted(range(len(indices)), key=elements.__getitem__) + if descending: + # sorted(..., reverse=descending) doesn't always work + indices_order = reversed(indices_order) + for idx, o in zip(indices, indices_order): + assert_equals("argsort", int, idx, int(out[idx]), o) # TODO: Test with signed zeros and NaNs (and ignore them somehow) @@ -55,7 +89,7 @@ def test_sort(x, data): ) x_indices = [indices[o] for o in indices_order] for out_idx, x_idx in zip(indices, x_indices): - assert_equals( + assert_equals_( "sort", f"x[{x_idx}]", x[x_idx], From d594ff5430a0ef086a9d595bc6cfd2bcccc6e00a Mon Sep 17 00:00:00 2001 From: Matthew Date: Mon, 13 Dec 2021 13:15:38 +0000 Subject: [PATCH 48/60] Cover everything in `test_unique_values()` --- array_api_tests/test_set_functions.py | 31 +++++++++++++++++++++++---- 1 file changed, 27 insertions(+), 4 deletions(-) diff --git a/array_api_tests/test_set_functions.py b/array_api_tests/test_set_functions.py index 856a7282..d79fb169 100644 --- a/array_api_tests/test_set_functions.py +++ b/array_api_tests/test_set_functions.py @@ -1,7 +1,12 @@ -from hypothesis import given +import math + +from hypothesis import assume, given from . import _array_module as xp +from . import array_helpers as ah +from . import dtype_helpers as dh from . import hypothesis_helpers as hh +from . import pytest_helpers as ph from . import xps @@ -23,7 +28,25 @@ def test_unique_inverse(x): # TODO -@given(xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes())) +@given(xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes(min_side=1))) def test_unique_values(x): - xp.unique_values(x) - # TODO + out = xp.unique_values(x) + ph.assert_dtype("unique_values", x.dtype, out.dtype) + scalar_type = dh.get_scalar_type(x.dtype) + distinct = set(scalar_type(x[idx]) for idx in ah.ndindex(x.shape)) + vals_idx = {} + nans = 0 + for idx in ah.ndindex(out.shape): + val = scalar_type(out[idx]) + if math.isnan(val): + nans += 1 + else: + assert val in distinct, f"out[{idx}]={val}, but {val} not in input array" + assert ( + val not in vals_idx.keys() + ), f"out[{idx}]={val}, but {val} is also in out[{vals_idx[val]}]" + vals_idx[val] = idx + if dh.is_float_dtype(out.dtype): + assume(x.size <= 128) # may not be representable + expected_nans = xp.sum(xp.astype(xp.isnan(x), xp.uint8)) + assert nans == expected_nans, f"{nans} NaNs in out, expected {expected_nans}" From f8c99d5e2f8609e3c96e6d9ca3a5406b6c6448af Mon Sep 17 00:00:00 2001 From: Matthew Date: Mon, 13 Dec 2021 16:13:13 +0000 Subject: [PATCH 49/60] Cover everything in `test_unique_counts` --- array_api_tests/test_set_functions.py | 54 ++++++++++++++++++++++++--- 1 file changed, 49 insertions(+), 5 deletions(-) diff --git a/array_api_tests/test_set_functions.py b/array_api_tests/test_set_functions.py index d79fb169..7a932f6d 100644 --- a/array_api_tests/test_set_functions.py +++ b/array_api_tests/test_set_functions.py @@ -1,4 +1,5 @@ import math +from collections import Counter from hypothesis import assume, given @@ -8,6 +9,7 @@ from . import hypothesis_helpers as hh from . import pytest_helpers as ph from . import xps +from .test_searching_functions import assert_default_index @given(xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes())) @@ -16,10 +18,52 @@ def test_unique_all(x): # TODO -@given(xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes())) +@given(xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes(min_side=1))) def test_unique_counts(x): - xp.unique_counts(x) - # TODO + out = xp.unique_counts(x) + assert hasattr(out, "values") + assert hasattr(out, "counts") + ph.assert_dtype( + "unique_counts", x.dtype, out.values.dtype, repr_name="out.values.dtype" + ) + assert_default_index( + "unique_counts", out.counts.dtype, repr_name="out.counts.dtype" + ) + assert ( + out.counts.shape == out.values.shape + ), f"{out.counts.shape=}, but should be {out.values.shape=}" + scalar_type = dh.get_scalar_type(out.values.dtype) + counts = Counter(scalar_type(x[idx]) for idx in ah.ndindex(x.shape)) + vals_idx = {} + nans = 0 + for idx in ah.ndindex(out.values.shape): + val = scalar_type(out.values[idx]) + count = int(out.counts[idx]) + if math.isnan(val): + nans += 1 + assert count == 1, ( + f"out.counts[{idx}]={count} for out.values[{idx}]={val}, " + "but count should be 1 as NaNs are distinct" + ) + else: + expected = counts[val] + assert ( + expected > 0 + ), f"out.values[{idx}]={val}, but {val} not in input array" + count = int(out.counts[idx]) + assert count == expected, ( + f"out.counts[{idx}]={count} for out.values[{idx}]={val}, " + f"but should be {expected}" + ) + assert ( + val not in vals_idx.keys() + ), f"out[{idx}]={val}, but {val} is also in out[{vals_idx[val]}]" + vals_idx[val] = idx + if dh.is_float_dtype(out.values.dtype): + assume(x.size <= 128) # may not be representable + expected = sum(v for k, v in counts.items() if math.isnan(k)) + print(f"{counts=}") + assert nans == expected, f"{nans} NaNs in out, but should be {expected}" @given(xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes())) @@ -48,5 +92,5 @@ def test_unique_values(x): vals_idx[val] = idx if dh.is_float_dtype(out.dtype): assume(x.size <= 128) # may not be representable - expected_nans = xp.sum(xp.astype(xp.isnan(x), xp.uint8)) - assert nans == expected_nans, f"{nans} NaNs in out, expected {expected_nans}" + expected = xp.sum(xp.astype(xp.isnan(x), xp.uint8)) + assert nans == expected, f"{nans} NaNs in out, but should be {expected}" From ab8674fcd637ba4f710501075fae530a9993ffd6 Mon Sep 17 00:00:00 2001 From: Matthew Date: Tue, 14 Dec 2021 11:21:13 +0000 Subject: [PATCH 50/60] Cover everything in `test_unique_inverse` --- array_api_tests/test_set_functions.py | 55 +++++++++++++++++++++++++-- 1 file changed, 51 insertions(+), 4 deletions(-) diff --git a/array_api_tests/test_set_functions.py b/array_api_tests/test_set_functions.py index 7a932f6d..7f1ad10a 100644 --- a/array_api_tests/test_set_functions.py +++ b/array_api_tests/test_set_functions.py @@ -1,3 +1,4 @@ +# TODO: disable if opted out import math from collections import Counter @@ -62,14 +63,60 @@ def test_unique_counts(x): if dh.is_float_dtype(out.values.dtype): assume(x.size <= 128) # may not be representable expected = sum(v for k, v in counts.items() if math.isnan(k)) - print(f"{counts=}") assert nans == expected, f"{nans} NaNs in out, but should be {expected}" -@given(xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes())) +@given(xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes(min_side=1))) def test_unique_inverse(x): - xp.unique_inverse(x) - # TODO + out = xp.unique_inverse(x) + assert hasattr(out, "values") + assert hasattr(out, "inverse_indices") + ph.assert_dtype( + "unique_inverse", x.dtype, out.values.dtype, repr_name="out.values.dtype" + ) + assert_default_index( + "unique_inverse", + out.inverse_indices.dtype, + repr_name="out.inverse_indices.dtype", + ) + ph.assert_shape( + "unique_inverse", + out.inverse_indices.shape, + x.shape, + repr_name="out.inverse_indices.shape", + ) + scalar_type = dh.get_scalar_type(out.values.dtype) + distinct = set(scalar_type(x[idx]) for idx in ah.ndindex(x.shape)) + vals_idx = {} + nans = 0 + for idx in ah.ndindex(out.values.shape): + val = scalar_type(out.values[idx]) + if math.isnan(val): + nans += 1 + else: + assert ( + val in distinct + ), f"out.values[{idx}]={val}, but {val} not in input array" + assert ( + val not in vals_idx.keys() + ), f"out.values[{idx}]={val}, but {val} is also in out[{vals_idx[val]}]" + vals_idx[val] = idx + for idx in ah.ndindex(out.inverse_indices.shape): + ridx = int(out.inverse_indices[idx]) + val = out.values[ridx] + expected = x[idx] + msg = ( + f"out.inverse_indices[{idx}]={ridx} results in out.values[{ridx}]={val}, " + f"but should result in x[{idx}]={expected}" + ) + if dh.is_float_dtype(out.values.dtype) and xp.isnan(expected): + assert xp.isnan(val), msg + else: + assert val == expected, msg + if dh.is_float_dtype(out.values.dtype): + assume(x.size <= 128) # may not be representable + expected = xp.sum(xp.astype(xp.isnan(x), xp.uint8)) + assert nans == expected, f"{nans} NaNs in out.values, but should be {expected}" @given(xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes(min_side=1))) From 131dd3151cd95612f5022465bd544e55520a4de1 Mon Sep 17 00:00:00 2001 From: Matthew Date: Tue, 14 Dec 2021 19:46:22 +0000 Subject: [PATCH 51/60] Cover everything in `test_unique_all` (if messily) --- array_api_tests/test_set_functions.py | 102 +++++++++++++++++++++++++- 1 file changed, 98 insertions(+), 4 deletions(-) diff --git a/array_api_tests/test_set_functions.py b/array_api_tests/test_set_functions.py index 7f1ad10a..54e81c93 100644 --- a/array_api_tests/test_set_functions.py +++ b/array_api_tests/test_set_functions.py @@ -1,6 +1,6 @@ # TODO: disable if opted out import math -from collections import Counter +from collections import Counter, defaultdict from hypothesis import assume, given @@ -13,10 +13,104 @@ from .test_searching_functions import assert_default_index -@given(xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes())) +@given( + xps.arrays( + dtype=xps.scalar_dtypes(), shape=hh.shapes(min_side=1, min_dims=1, max_dims=1) + ) +) # TODO def test_unique_all(x): - xp.unique_all(x) - # TODO + out = xp.unique_all(x) + + assert hasattr(out, "values") + assert hasattr(out, "indices") + assert hasattr(out, "inverse_indices") + assert hasattr(out, "counts") + + ph.assert_dtype( + "unique_all", x.dtype, out.values.dtype, repr_name="out.values.dtype" + ) + assert_default_index("unique_all", out.indices.dtype, repr_name="out.indices.dtype") + assert_default_index( + "unique_all", out.inverse_indices.dtype, repr_name="out.inverse_indices.dtype" + ) + assert_default_index("unique_all", out.counts.dtype, repr_name="out.counts.dtype") + + assert ( + out.indices.shape == out.values.shape + ), f"{out.indices.shape=}, but should be {out.values.shape=}" + ph.assert_shape( + "unique_all", + out.inverse_indices.shape, + x.shape, + repr_name="out.inverse_indices.shape", + ) + assert ( + out.counts.shape == out.values.shape + ), f"{out.counts.shape=}, but should be {out.values.shape=}" + + scalar_type = dh.get_scalar_type(out.values.dtype) + counts = defaultdict(int) + firsts = {} + for i, idx in enumerate(ah.ndindex(x.shape)): + val = scalar_type(x[idx]) + if counts[val] == 0: + firsts[val] = i + counts[val] += 1 + + for idx in ah.ndindex(out.indices.shape): + val = scalar_type(out.values[idx]) + if math.isnan(val): + break + i = int(out.indices[idx]) + expected = firsts[val] + assert i == expected, ( + f"out.values[{idx}]={val} and out.indices[{idx}]={i}, " + f"but first occurence of {val} is at {expected}" + ) + + for idx in ah.ndindex(out.inverse_indices.shape): + ridx = int(out.inverse_indices[idx]) + val = out.values[ridx] + expected = x[idx] + msg = ( + f"out.inverse_indices[{idx}]={ridx} results in out.values[{ridx}]={val}, " + f"but should result in x[{idx}]={expected}" + ) + if dh.is_float_dtype(out.values.dtype) and xp.isnan(expected): + assert xp.isnan(val), msg + else: + assert val == expected, msg + + vals_idx = {} + nans = 0 + for idx in ah.ndindex(out.values.shape): + val = scalar_type(out.values[idx]) + count = int(out.counts[idx]) + if math.isnan(val): + nans += 1 + assert count == 1, ( + f"out.counts[{idx}]={count} for out.values[{idx}]={val}, " + "but count should be 1 as NaNs are distinct" + ) + else: + expected = counts[val] + assert ( + expected > 0 + ), f"out.values[{idx}]={val}, but {val} not in input array" + count = int(out.counts[idx]) + assert count == expected, ( + f"out.counts[{idx}]={count} for out.values[{idx}]={val}, " + f"but should be {expected}" + ) + assert ( + val not in vals_idx.keys() + ), f"out[{idx}]={val}, but {val} is also in out[{vals_idx[val]}]" + vals_idx[val] = idx + + if dh.is_float_dtype(out.values.dtype): + assume(x.size <= 128) # may not be representable + expected = sum(v for k, v in counts.items() if math.isnan(k)) + assert nans == expected, f"{nans} NaNs in out, but should be {expected}" @given(xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes(min_side=1))) From 8267a7c9bb0b6c918143cf1e0bc0f6f56420d85f Mon Sep 17 00:00:00 2001 From: Matthew Date: Thu, 16 Dec 2021 11:38:06 +0000 Subject: [PATCH 52/60] Cover everything in `test_all` --- array_api_tests/test_set_functions.py | 2 +- array_api_tests/test_utility_functions.py | 43 +++++++++++++++++++---- 2 files changed, 37 insertions(+), 8 deletions(-) diff --git a/array_api_tests/test_set_functions.py b/array_api_tests/test_set_functions.py index 54e81c93..544820f0 100644 --- a/array_api_tests/test_set_functions.py +++ b/array_api_tests/test_set_functions.py @@ -1,4 +1,4 @@ -# TODO: disable if opted out +# TODO: disable if opted out, refactor things import math from collections import Counter, defaultdict diff --git a/array_api_tests/test_utility_functions.py b/array_api_tests/test_utility_functions.py index 140aa85f..39bb4555 100644 --- a/array_api_tests/test_utility_functions.py +++ b/array_api_tests/test_utility_functions.py @@ -1,19 +1,48 @@ from hypothesis import given +from hypothesis import strategies as st from . import _array_module as xp +from . import array_helpers as ah +from . import dtype_helpers as dh from . import hypothesis_helpers as hh +from . import pytest_helpers as ph from . import xps +from .test_statistical_functions import ( + assert_equals, + assert_keepdimable_shape, + axes, + axes_ndindex, + normalise_axis, +) -# TODO: generate kwargs -@given(xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes())) -def test_any(x): - xp.any(x) - # TODO +@given( + x=xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes(min_side=1)), + data=st.data(), +) +def test_all(x, data): + kw = data.draw(hh.kwargs(axis=axes(x.ndim), keepdims=st.booleans()), label="kw") + + out = xp.all(x, **kw) + + ph.assert_dtype("all", x.dtype, out.dtype, xp.bool) + _axes = normalise_axis(kw.get("axis", None), x.ndim) + assert_keepdimable_shape( + "all", out.shape, x.shape, _axes, kw.get("keepdims", False), **kw + ) + scalar_type = dh.get_scalar_type(x.dtype) + for indices, out_idx in zip(axes_ndindex(x.shape, _axes), ah.ndindex(out.shape)): + result = bool(out[out_idx]) + elements = [] + for idx in indices: + s = scalar_type(x[idx]) + elements.append(s) + expected = all(elements) + assert_equals("all", scalar_type, out_idx, result, expected) # TODO: generate kwargs @given(xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes())) -def test_all(x): - xp.all(x) +def test_any(x): + xp.any(x) # TODO From 3fc57beb0660d60058e96e3c3d91454ef776efa5 Mon Sep 17 00:00:00 2001 From: Matthew Date: Thu, 16 Dec 2021 11:39:40 +0000 Subject: [PATCH 53/60] Cover everything in `test_any` --- array_api_tests/test_utility_functions.py | 28 +++++++++++++++++++---- 1 file changed, 23 insertions(+), 5 deletions(-) diff --git a/array_api_tests/test_utility_functions.py b/array_api_tests/test_utility_functions.py index 39bb4555..78a3649b 100644 --- a/array_api_tests/test_utility_functions.py +++ b/array_api_tests/test_utility_functions.py @@ -41,8 +41,26 @@ def test_all(x, data): assert_equals("all", scalar_type, out_idx, result, expected) -# TODO: generate kwargs -@given(xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes())) -def test_any(x): - xp.any(x) - # TODO +@given( + x=xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes()), + data=st.data(), +) +def test_any(x, data): + kw = data.draw(hh.kwargs(axis=axes(x.ndim), keepdims=st.booleans()), label="kw") + + out = xp.any(x, **kw) + + ph.assert_dtype("any", x.dtype, out.dtype, xp.bool) + _axes = normalise_axis(kw.get("axis", None), x.ndim) + assert_keepdimable_shape( + "any", out.shape, x.shape, _axes, kw.get("keepdims", False), **kw + ) + scalar_type = dh.get_scalar_type(x.dtype) + for indices, out_idx in zip(axes_ndindex(x.shape, _axes), ah.ndindex(out.shape)): + result = bool(out[out_idx]) + elements = [] + for idx in indices: + s = scalar_type(x[idx]) + elements.append(s) + expected = any(elements) + assert_equals("any", scalar_type, out_idx, result, expected) From a435e632b777a7afddab2f086fc9a94aa2a20f20 Mon Sep 17 00:00:00 2001 From: Matthew Date: Thu, 16 Dec 2021 12:45:20 +0000 Subject: [PATCH 54/60] Test 0d arrays conversion to scalars --- array_api_tests/test_array2scalar.py | 39 ++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) create mode 100644 array_api_tests/test_array2scalar.py diff --git a/array_api_tests/test_array2scalar.py b/array_api_tests/test_array2scalar.py new file mode 100644 index 00000000..55fb2fe3 --- /dev/null +++ b/array_api_tests/test_array2scalar.py @@ -0,0 +1,39 @@ +import pytest +from hypothesis import given +from hypothesis import strategies as st + +from . import _array_module as xp +from . import dtype_helpers as dh +from . import xps +from .typing import DataType, Param + +method_stype = { + "__bool__": bool, + "__int__": int, + "__index__": int, + "__float__": float, +} + + +def make_param(method_name: str, dtype: DataType) -> Param: + stype = method_stype[method_name] + return pytest.param( + method_name, dtype, stype, id=f"{method_name}({dh.dtype_to_name[dtype]})" + ) + + +@pytest.mark.parametrize( + "method_name, dtype, stype", + [make_param("__bool__", xp.bool)] + + [make_param("__int__", d) for d in dh.all_int_dtypes] + + [make_param("__index__", d) for d in dh.all_int_dtypes] + + [make_param("__float__", d) for d in dh.float_dtypes], +) +@given(data=st.data()) +def test_0d_array_can_convert_to_scalar(method_name, dtype, stype, data): + x = data.draw(xps.arrays(dtype, shape=()), label="x") + method = getattr(x, method_name) + out = method() + assert isinstance( + out, stype + ), f"{method_name}({x})={out}, which is not a {stype.__name__} scalar" From dcc4adf931789ea68a1a9006f600f290194f870c Mon Sep 17 00:00:00 2001 From: Matthew Date: Thu, 16 Dec 2021 14:39:30 +0000 Subject: [PATCH 55/60] Move `axes()` strategy to `hypothesis_helpers.py` --- array_api_tests/hypothesis_helpers.py | 11 +++++++++- array_api_tests/pytest_helpers.py | 2 +- array_api_tests/test_searching_functions.py | 19 ++++++++-------- array_api_tests/test_statistical_functions.py | 22 ++++++------------- array_api_tests/test_utility_functions.py | 5 ++--- 5 files changed, 29 insertions(+), 30 deletions(-) diff --git a/array_api_tests/hypothesis_helpers.py b/array_api_tests/hypothesis_helpers.py index a14f3d51..8871651a 100644 --- a/array_api_tests/hypothesis_helpers.py +++ b/array_api_tests/hypothesis_helpers.py @@ -2,7 +2,7 @@ from functools import reduce from math import sqrt from operator import mul -from typing import Any, List, NamedTuple, Optional, Tuple, Sequence +from typing import Any, List, NamedTuple, Optional, Tuple, Sequence, Union from hypothesis import assume from hypothesis.strategies import (SearchStrategy, booleans, composite, floats, @@ -399,3 +399,12 @@ def specified_kwargs(draw, *keys_values_defaults: KVD): if value is not default or draw(booleans()): kw[keyword] = value return kw + + +def axes(ndim: int) -> SearchStrategy[Optional[Union[int, Shape]]]: + """Generate valid arguments for some axis keywords""" + axes_strats = [none()] + if ndim != 0: + axes_strats.append(integers(-ndim, ndim - 1)) + axes_strats.append(xps.valid_tuple_axes(ndim)) + return one_of(axes_strats) diff --git a/array_api_tests/pytest_helpers.py b/array_api_tests/pytest_helpers.py index b138af3e..fa9e8b87 100644 --- a/array_api_tests/pytest_helpers.py +++ b/array_api_tests/pytest_helpers.py @@ -1,4 +1,4 @@ -from array_api_tests.algos import broadcast_shapes +from .algos import broadcast_shapes import math from inspect import getfullargspec from typing import Any, Dict, Optional, Tuple, Union diff --git a/array_api_tests/test_searching_functions.py b/array_api_tests/test_searching_functions.py index 6003ccd4..88a9da26 100644 --- a/array_api_tests/test_searching_functions.py +++ b/array_api_tests/test_searching_functions.py @@ -1,22 +1,21 @@ from hypothesis import given from hypothesis import strategies as st -from array_api_tests.algos import broadcast_shapes -from array_api_tests.test_manipulation_functions import assert_equals as assert_equals_ -from array_api_tests.test_statistical_functions import ( - assert_equals, - assert_keepdimable_shape, - axes_ndindex, - normalise_axis, -) -from array_api_tests.typing import DataType - from . import _array_module as xp from . import array_helpers as ah from . import dtype_helpers as dh from . import hypothesis_helpers as hh from . import pytest_helpers as ph from . import xps +from .algos import broadcast_shapes +from .test_manipulation_functions import assert_equals as assert_equals_ +from .test_statistical_functions import ( + assert_equals, + assert_keepdimable_shape, + axes_ndindex, + normalise_axis, +) +from .typing import DataType def assert_default_index(func_name: str, dtype: DataType, repr_name="out.dtype"): diff --git a/array_api_tests/test_statistical_functions.py b/array_api_tests/test_statistical_functions.py index 81e11558..c0a03a3d 100644 --- a/array_api_tests/test_statistical_functions.py +++ b/array_api_tests/test_statistical_functions.py @@ -15,14 +15,6 @@ from .typing import DataType, Scalar, ScalarType, Shape -def axes(ndim: int) -> st.SearchStrategy[Optional[Union[int, Shape]]]: - axes_strats = [st.none()] - if ndim != 0: - axes_strats.append(st.integers(-ndim, ndim - 1)) - axes_strats.append(xps.valid_tuple_axes(ndim)) - return st.one_of(axes_strats) - - def kwarg_dtypes(dtype: DataType) -> st.SearchStrategy[Optional[DataType]]: dtypes = [d2 for d1, d2 in dh.promotion_table if d1 == dtype] return st.none() | st.sampled_from(dtypes) @@ -108,7 +100,7 @@ def assert_equals( data=st.data(), ) def test_max(x, data): - kw = data.draw(hh.kwargs(axis=axes(x.ndim), keepdims=st.booleans()), label="kw") + kw = data.draw(hh.kwargs(axis=hh.axes(x.ndim), keepdims=st.booleans()), label="kw") out = xp.max(x, **kw) @@ -137,7 +129,7 @@ def test_max(x, data): data=st.data(), ) def test_mean(x, data): - kw = data.draw(hh.kwargs(axis=axes(x.ndim), keepdims=st.booleans()), label="kw") + kw = data.draw(hh.kwargs(axis=hh.axes(x.ndim), keepdims=st.booleans()), label="kw") out = xp.mean(x, **kw) @@ -166,7 +158,7 @@ def test_mean(x, data): data=st.data(), ) def test_min(x, data): - kw = data.draw(hh.kwargs(axis=axes(x.ndim), keepdims=st.booleans()), label="kw") + kw = data.draw(hh.kwargs(axis=hh.axes(x.ndim), keepdims=st.booleans()), label="kw") out = xp.min(x, **kw) @@ -197,7 +189,7 @@ def test_min(x, data): def test_prod(x, data): kw = data.draw( hh.kwargs( - axis=axes(x.ndim), + axis=hh.axes(x.ndim), dtype=kwarg_dtypes(x.dtype), keepdims=st.booleans(), ), @@ -258,7 +250,7 @@ def test_prod(x, data): data=st.data(), ) def test_std(x, data): - axis = data.draw(axes(x.ndim), label="axis") + axis = data.draw(hh.axes(x.ndim), label="axis") _axes = normalise_axis(axis, x.ndim) N = sum(side for axis, side in enumerate(x.shape) if axis not in _axes) correction = data.draw( @@ -295,7 +287,7 @@ def test_std(x, data): def test_sum(x, data): kw = data.draw( hh.kwargs( - axis=axes(x.ndim), + axis=hh.axes(x.ndim), dtype=kwarg_dtypes(x.dtype), keepdims=st.booleans(), ), @@ -356,7 +348,7 @@ def test_sum(x, data): data=st.data(), ) def test_var(x, data): - axis = data.draw(axes(x.ndim), label="axis") + axis = data.draw(hh.axes(x.ndim), label="axis") _axes = normalise_axis(axis, x.ndim) N = sum(side for axis, side in enumerate(x.shape) if axis not in _axes) correction = data.draw( diff --git a/array_api_tests/test_utility_functions.py b/array_api_tests/test_utility_functions.py index 78a3649b..fe95ca01 100644 --- a/array_api_tests/test_utility_functions.py +++ b/array_api_tests/test_utility_functions.py @@ -10,7 +10,6 @@ from .test_statistical_functions import ( assert_equals, assert_keepdimable_shape, - axes, axes_ndindex, normalise_axis, ) @@ -21,7 +20,7 @@ data=st.data(), ) def test_all(x, data): - kw = data.draw(hh.kwargs(axis=axes(x.ndim), keepdims=st.booleans()), label="kw") + kw = data.draw(hh.kwargs(axis=hh.axes(x.ndim), keepdims=st.booleans()), label="kw") out = xp.all(x, **kw) @@ -46,7 +45,7 @@ def test_all(x, data): data=st.data(), ) def test_any(x, data): - kw = data.draw(hh.kwargs(axis=axes(x.ndim), keepdims=st.booleans()), label="kw") + kw = data.draw(hh.kwargs(axis=hh.axes(x.ndim), keepdims=st.booleans()), label="kw") out = xp.any(x, **kw) From 6194b087afc14801c780e0da5c7dbababea72529 Mon Sep 17 00:00:00 2001 From: Matthew Date: Thu, 16 Dec 2021 15:58:48 +0000 Subject: [PATCH 56/60] Remove faulty assertion in `test_two_mutual_arrays` --- array_api_tests/meta/test_hypothesis_helpers.py | 1 - 1 file changed, 1 deletion(-) diff --git a/array_api_tests/meta/test_hypothesis_helpers.py b/array_api_tests/meta/test_hypothesis_helpers.py index 652644c1..b4cb6e96 100644 --- a/array_api_tests/meta/test_hypothesis_helpers.py +++ b/array_api_tests/meta/test_hypothesis_helpers.py @@ -68,7 +68,6 @@ def test_two_broadcastable_shapes(pair): @given(*hh.two_mutual_arrays()) def test_two_mutual_arrays(x1, x2): assert (x1.dtype, x2.dtype) in dh.promotion_table.keys() - assert broadcast_shapes(x1.shape, x2.shape) in (x1.shape, x2.shape) def test_two_mutual_arrays_raises_on_bad_dtypes(): From 1273270e4fa66010b5e5ae69521d1f1f72213659 Mon Sep 17 00:00:00 2001 From: Matthew Date: Thu, 16 Dec 2021 16:32:23 +0000 Subject: [PATCH 57/60] Move shape-related helpers into `shape_helpers.py` --- array_api_tests/array_helpers.py | 14 +---- array_api_tests/hypothesis_helpers.py | 8 +-- array_api_tests/meta/test_utils.py | 10 ++- array_api_tests/shape_helpers.py | 59 ++++++++++++++++++ array_api_tests/test_elementwise_functions.py | 47 +++++++------- array_api_tests/test_linalg.py | 7 ++- .../test_manipulation_functions.py | 51 ++++++--------- array_api_tests/test_searching_functions.py | 23 +++---- array_api_tests/test_set_functions.py | 24 +++---- array_api_tests/test_sorting.py | 11 ++-- array_api_tests/test_statistical_functions.py | 62 +++++-------------- array_api_tests/test_utility_functions.py | 17 ++--- 12 files changed, 162 insertions(+), 171 deletions(-) create mode 100644 array_api_tests/shape_helpers.py diff --git a/array_api_tests/array_helpers.py b/array_api_tests/array_helpers.py index 398f1994..b3ae583c 100644 --- a/array_api_tests/array_helpers.py +++ b/array_api_tests/array_helpers.py @@ -1,5 +1,3 @@ -import itertools - from ._array_module import (isnan, all, any, equal, not_equal, logical_and, logical_or, isfinite, greater, less, less_equal, zeros, ones, full, bool, int8, int16, int32, @@ -23,7 +21,7 @@ 'assert_isinf', 'positive_mathematical_sign', 'assert_positive_mathematical_sign', 'negative_mathematical_sign', 'assert_negative_mathematical_sign', 'same_sign', - 'assert_same_sign', 'ndindex', 'float64', + 'assert_same_sign', 'float64', 'asarray', 'full', 'true', 'false', 'isnan'] def zero(shape, dtype): @@ -319,13 +317,3 @@ def int_to_dtype(x, n, signed): if x & highest_bit: x = -((~x & mask) + 1) return x - -def ndindex(shape): - """ - Iterator of n-D indices to an array - - Yields tuples of integers to index every element of an array of shape - `shape`. Same as np.ndindex(). - - """ - return itertools.product(*[range(i) for i in shape]) diff --git a/array_api_tests/hypothesis_helpers.py b/array_api_tests/hypothesis_helpers.py index 8871651a..d8c5f976 100644 --- a/array_api_tests/hypothesis_helpers.py +++ b/array_api_tests/hypothesis_helpers.py @@ -2,7 +2,7 @@ from functools import reduce from math import sqrt from operator import mul -from typing import Any, List, NamedTuple, Optional, Tuple, Sequence, Union +from typing import Any, List, NamedTuple, Optional, Sequence, Tuple, Union from hypothesis import assume from hypothesis.strategies import (SearchStrategy, booleans, composite, floats, @@ -11,15 +11,15 @@ from . import _array_module as xp from . import dtype_helpers as dh +from . import shape_helpers as sh from . import xps from ._array_module import _UndefinedStub from ._array_module import bool as bool_dtype from ._array_module import broadcast_to, eye, float32, float64, full -from .array_helpers import ndindex +from .algos import broadcast_shapes from .function_stubs import elementwise_functions from .pytest_helpers import nargs from .typing import Array, DataType, Shape -from .algos import broadcast_shapes # Set this to True to not fail tests just because a dtype isn't implemented. # If no compatible dtype is implemented for a given test, the test will fail @@ -208,7 +208,7 @@ def invertible_matrices(draw, dtypes=xps.floating_dtypes(), stack_shapes=shapes( assume(xp.all(xp.abs(d) > 0.5)) a = xp.zeros(shape) - for j, (idx, i) in enumerate(itertools.product(ndindex(stack_shape), range(n))): + for j, (idx, i) in enumerate(itertools.product(sh.ndindex(stack_shape), range(n))): a[idx + (i, i)] = d[j] return a diff --git a/array_api_tests/meta/test_utils.py b/array_api_tests/meta/test_utils.py index 83a4b75d..814c62cb 100644 --- a/array_api_tests/meta/test_utils.py +++ b/array_api_tests/meta/test_utils.py @@ -1,10 +1,8 @@ import pytest -from .. import array_helpers as ah +from .. import shape_helpers as sh from ..test_creation_functions import frange -from ..test_manipulation_functions import axis_ndindex from ..test_signatures import extension_module -from ..test_statistical_functions import axes_ndindex def test_extension_module_is_extension(): @@ -34,7 +32,7 @@ def test_frange(r, size, elements): [((), [()])], ) def test_ndindex(shape, expected): - assert list(ah.ndindex(shape)) == expected + assert list(sh.ndindex(shape)) == expected @pytest.mark.parametrize( @@ -50,7 +48,7 @@ def test_ndindex(shape, expected): ], ) def test_axis_ndindex(shape, axis, expected): - assert list(axis_ndindex(shape, axis)) == expected + assert list(sh.axis_ndindex(shape, axis)) == expected @pytest.mark.parametrize( @@ -69,4 +67,4 @@ def test_axis_ndindex(shape, axis, expected): ], ) def test_axes_ndindex(shape, axes, expected): - assert list(axes_ndindex(shape, axes)) == expected + assert list(sh.axes_ndindex(shape, axes)) == expected diff --git a/array_api_tests/shape_helpers.py b/array_api_tests/shape_helpers.py new file mode 100644 index 00000000..751b8d49 --- /dev/null +++ b/array_api_tests/shape_helpers.py @@ -0,0 +1,59 @@ +from itertools import product +from typing import Iterator, List, Optional, Tuple, Union + +from .typing import Shape + +__all__ = ["normalise_axis", "ndindex", "axis_ndindex", "axes_ndindex"] + + +def normalise_axis( + axis: Optional[Union[int, Tuple[int, ...]]], ndim: int +) -> Tuple[int, ...]: + if axis is None: + return tuple(range(ndim)) + axes = axis if isinstance(axis, tuple) else (axis,) + axes = tuple(axis if axis >= 0 else ndim + axis for axis in axes) + return axes + + +def ndindex(shape): + """Iterator of n-D indices to an array + + Yields tuples of integers to index every element of an array of shape + `shape`. Same as np.ndindex(). + """ + return product(*[range(i) for i in shape]) + + +def axis_ndindex( + shape: Shape, axis: int +) -> Iterator[Tuple[Tuple[Union[int, slice], ...], ...]]: + """Generate indices that index all elements in dimensions beyond `axis`""" + assert axis >= 0 # sanity check + axis_indices = [range(side) for side in shape[:axis]] + for _ in range(axis, len(shape)): + axis_indices.append([slice(None, None)]) + yield from product(*axis_indices) + + +def axes_ndindex(shape: Shape, axes: Tuple[int, ...]) -> Iterator[List[Shape]]: + """Generate indices that index all elements except in `axes` dimensions""" + base_indices = [] + axes_indices = [] + for axis, side in enumerate(shape): + if axis in axes: + base_indices.append([None]) + axes_indices.append(range(side)) + else: + base_indices.append(range(side)) + axes_indices.append([None]) + for base_idx in product(*base_indices): + indices = [] + for idx in product(*axes_indices): + idx = list(idx) + for axis, side in enumerate(idx): + if axis not in axes: + idx[axis] = base_idx[axis] + idx = tuple(idx) + indices.append(idx) + yield list(indices) diff --git a/array_api_tests/test_elementwise_functions.py b/array_api_tests/test_elementwise_functions.py index 4d288ee4..345eec44 100644 --- a/array_api_tests/test_elementwise_functions.py +++ b/array_api_tests/test_elementwise_functions.py @@ -23,6 +23,7 @@ from . import dtype_helpers as dh from . import hypothesis_helpers as hh from . import pytest_helpers as ph +from . import shape_helpers as sh from . import xps from .algos import broadcast_shapes from .typing import Array, DataType, Param, Scalar @@ -377,13 +378,13 @@ def test_bitwise_and( # Compare against the Python & operator. if res.dtype == xp.bool: - for idx in ah.ndindex(res.shape): + for idx in sh.ndindex(res.shape): s_left = bool(_left[idx]) s_right = bool(_right[idx]) s_res = bool(res[idx]) assert (s_left and s_right) == s_res else: - for idx in ah.ndindex(res.shape): + for idx in sh.ndindex(res.shape): s_left = int(_left[idx]) s_right = int(_right[idx]) s_res = int(res[idx]) @@ -427,7 +428,7 @@ def test_bitwise_left_shift( _right = xp.broadcast_to(right, shape) # Compare against the Python << operator. - for idx in ah.ndindex(res.shape): + for idx in sh.ndindex(res.shape): s_left = int(_left[idx]) s_right = int(_right[idx]) s_res = int(res[idx]) @@ -452,12 +453,12 @@ def test_bitwise_invert(func_name, func, strat, data): ph.assert_shape(func_name, out.shape, x.shape) # Compare against the Python ~ operator. if out.dtype == xp.bool: - for idx in ah.ndindex(out.shape): + for idx in sh.ndindex(out.shape): s_x = bool(x[idx]) s_out = bool(out[idx]) assert (not s_x) == s_out else: - for idx in ah.ndindex(out.shape): + for idx in sh.ndindex(out.shape): s_x = int(x[idx]) s_out = int(out[idx]) s_invert = ah.int_to_dtype( @@ -495,13 +496,13 @@ def test_bitwise_or( # Compare against the Python | operator. if res.dtype == xp.bool: - for idx in ah.ndindex(res.shape): + for idx in sh.ndindex(res.shape): s_left = bool(_left[idx]) s_right = bool(_right[idx]) s_res = bool(res[idx]) assert (s_left or s_right) == s_res else: - for idx in ah.ndindex(res.shape): + for idx in sh.ndindex(res.shape): s_left = int(_left[idx]) s_right = int(_right[idx]) s_res = int(res[idx]) @@ -547,7 +548,7 @@ def test_bitwise_right_shift( _right = xp.broadcast_to(right, shape) # Compare against the Python >> operator. - for idx in ah.ndindex(res.shape): + for idx in sh.ndindex(res.shape): s_left = int(_left[idx]) s_right = int(_right[idx]) s_res = int(res[idx]) @@ -586,13 +587,13 @@ def test_bitwise_xor( # Compare against the Python ^ operator. if res.dtype == xp.bool: - for idx in ah.ndindex(res.shape): + for idx in sh.ndindex(res.shape): s_left = bool(_left[idx]) s_right = bool(_right[idx]) s_res = bool(res[idx]) assert (s_left ^ s_right) == s_res else: - for idx in ah.ndindex(res.shape): + for idx in sh.ndindex(res.shape): s_left = int(_left[idx]) s_right = int(_right[idx]) s_res = int(res[idx]) @@ -721,7 +722,7 @@ def test_equal( _right = ah.asarray(_right, dtype=promoted_dtype) scalar_type = dh.get_scalar_type(promoted_dtype) - for idx in ah.ndindex(shape): + for idx in sh.ndindex(shape): x1_idx = _left[idx] x2_idx = _right[idx] out_idx = out[idx] @@ -846,7 +847,7 @@ def test_greater( _right = ah.asarray(_right, dtype=promoted_dtype) scalar_type = dh.get_scalar_type(promoted_dtype) - for idx in ah.ndindex(shape): + for idx in sh.ndindex(shape): out_idx = out[idx] x1_idx = _left[idx] x2_idx = _right[idx] @@ -887,7 +888,7 @@ def test_greater_equal( _right = ah.asarray(_right, dtype=promoted_dtype) scalar_type = dh.get_scalar_type(promoted_dtype) - for idx in ah.ndindex(shape): + for idx in sh.ndindex(shape): out_idx = out[idx] x1_idx = _left[idx] x2_idx = _right[idx] @@ -907,7 +908,7 @@ def test_isfinite(x): # Test the exact value by comparing to the math version if dh.is_float_dtype(x.dtype): - for idx in ah.ndindex(x.shape): + for idx in sh.ndindex(x.shape): s = float(x[idx]) assert bool(res[idx]) == math.isfinite(s) @@ -925,7 +926,7 @@ def test_isinf(x): # Test the exact value by comparing to the math version if dh.is_float_dtype(x.dtype): - for idx in ah.ndindex(x.shape): + for idx in sh.ndindex(x.shape): s = float(x[idx]) assert bool(res[idx]) == math.isinf(s) @@ -943,7 +944,7 @@ def test_isnan(x): # Test the exact value by comparing to the math version if dh.is_float_dtype(x.dtype): - for idx in ah.ndindex(x.shape): + for idx in sh.ndindex(x.shape): s = float(x[idx]) assert bool(res[idx]) == math.isnan(s) @@ -979,7 +980,7 @@ def test_less( _right = ah.asarray(_right, dtype=promoted_dtype) scalar_type = dh.get_scalar_type(promoted_dtype) - for idx in ah.ndindex(shape): + for idx in sh.ndindex(shape): x1_idx = _left[idx] x2_idx = _right[idx] out_idx = out[idx] @@ -1020,7 +1021,7 @@ def test_less_equal( _right = ah.asarray(_right, dtype=promoted_dtype) scalar_type = dh.get_scalar_type(promoted_dtype) - for idx in ah.ndindex(shape): + for idx in sh.ndindex(shape): x1_idx = _left[idx] x2_idx = _right[idx] out_idx = out[idx] @@ -1100,7 +1101,7 @@ def test_logical_and(x1, x2): _x1 = xp.broadcast_to(x1, shape) _x2 = xp.broadcast_to(x2, shape) - for idx in ah.ndindex(shape): + for idx in sh.ndindex(shape): assert out[idx] == (bool(_x1[idx]) and bool(_x2[idx])) @@ -1108,7 +1109,7 @@ def test_logical_and(x1, x2): def test_logical_not(x): out = ah.logical_not(x) ph.assert_shape("logical_not", out.shape, x.shape) - for idx in ah.ndindex(x.shape): + for idx in sh.ndindex(x.shape): assert out[idx] == (not bool(x[idx])) @@ -1122,7 +1123,7 @@ def test_logical_or(x1, x2): _x1 = xp.broadcast_to(x1, shape) _x2 = xp.broadcast_to(x2, shape) - for idx in ah.ndindex(shape): + for idx in sh.ndindex(shape): assert out[idx] == (bool(_x1[idx]) or bool(_x2[idx])) @@ -1136,7 +1137,7 @@ def test_logical_xor(x1, x2): _x1 = xp.broadcast_to(x1, shape) _x2 = xp.broadcast_to(x2, shape) - for idx in ah.ndindex(shape): + for idx in sh.ndindex(shape): assert out[idx] == (bool(_x1[idx]) ^ bool(_x2[idx])) @@ -1225,7 +1226,7 @@ def test_not_equal( _right = ah.asarray(_right, dtype=promoted_dtype) scalar_type = dh.get_scalar_type(promoted_dtype) - for idx in ah.ndindex(shape): + for idx in sh.ndindex(shape): out_idx = out[idx] x1_idx = _left[idx] x2_idx = _right[idx] diff --git a/array_api_tests/test_linalg.py b/array_api_tests/test_linalg.py index ac9f3359..89707d3f 100644 --- a/array_api_tests/test_linalg.py +++ b/array_api_tests/test_linalg.py @@ -18,7 +18,7 @@ from hypothesis.strategies import (booleans, composite, none, tuples, integers, shared, sampled_from) -from .array_helpers import assert_exactly_equal, ndindex, asarray, equal, zero, infinity +from .array_helpers import assert_exactly_equal, asarray, equal, zero, infinity from .hypothesis_helpers import (xps, dtypes, shapes, kwargs, matrix_shapes, square_matrix_shapes, symmetric_matrices, positive_definite_matrices, MAX_ARRAY_SIZE, @@ -28,6 +28,7 @@ SQRT_MAX_ARRAY_SIZE, finite_matrices) from . import dtype_helpers as dh from . import pytest_helpers as ph +from . import shape_helpers as sh from .algos import broadcast_shapes @@ -53,7 +54,7 @@ def _test_stacks(f, *args, res=None, dims=2, true_val=None, **kw): shape = args[0].shape if len(args) == 1 else broadcast_shapes(*[x.shape for x in args]) - for _idx in ndindex(shape[:-2]): + for _idx in sh.ndindex(shape[:-2]): idx = _idx + (slice(None),)*dims res_stack = res[idx] x_stacks = [x[_idx + (...,)] for x in args] @@ -147,7 +148,7 @@ def test_cross(x1_x2_kw): # is the only function that works the way it does, so it's not really # worth generalizing _test_stacks to handle it. a = axis if axis >= 0 else axis + len(shape) - for _idx in ndindex(shape[:a] + shape[a+1:]): + for _idx in sh.ndindex(shape[:a] + shape[a+1:]): idx = _idx[:a] + (slice(None),) + _idx[a:] assert len(idx) == len(shape), "Invalid index. This indicates a bug in the test suite." res_stack = res[idx] diff --git a/array_api_tests/test_manipulation_functions.py b/array_api_tests/test_manipulation_functions.py index 9453fc25..433b096a 100644 --- a/array_api_tests/test_manipulation_functions.py +++ b/array_api_tests/test_manipulation_functions.py @@ -1,19 +1,17 @@ import math from collections import deque -from itertools import product -from typing import Iterable, Iterator, Tuple, Union +from typing import Iterable, Union import pytest from hypothesis import assume, given from hypothesis import strategies as st from . import _array_module as xp -from . import array_helpers as ah from . import dtype_helpers as dh from . import hypothesis_helpers as hh from . import pytest_helpers as ph +from . import shape_helpers as sh from . import xps -from .test_statistical_functions import axes_ndindex, normalise_axis from .typing import Array, Shape MAX_SIDE = hh.MAX_ARRAY_SIZE // 64 @@ -29,17 +27,6 @@ def shared_shapes(*args, **kwargs) -> st.SearchStrategy[Shape]: return st.shared(hh.shapes(*args, **kwargs), key="shape") -def axis_ndindex( - shape: Shape, axis: int -) -> Iterator[Tuple[Tuple[Union[int, slice], ...], ...]]: - """Generate indices that index all elements in dimensions beyond `axis`""" - assert axis >= 0 # sanity check - axis_indices = [range(side) for side in shape[:axis]] - for _ in range(axis, len(shape)): - axis_indices.append([slice(None, None)]) - yield from product(*axis_indices) - - def assert_array_ndindex( func_name: str, x: Array, @@ -115,7 +102,7 @@ def test_concat(dtypes, kw, data): if axis is None: out_indices = (i for i in range(out.size)) for x_num, x in enumerate(arrays, 1): - for x_idx in ah.ndindex(x.shape): + for x_idx in sh.ndindex(x.shape): out_i = next(out_indices) assert_equals( "concat", @@ -126,12 +113,12 @@ def test_concat(dtypes, kw, data): **kw, ) else: - out_indices = ah.ndindex(out.shape) - for idx in axis_ndindex(shapes[0], _axis): + out_indices = sh.ndindex(out.shape) + for idx in sh.axis_ndindex(shapes[0], _axis): f_idx = ", ".join(str(i) if isinstance(i, int) else ":" for i in idx) for x_num, x in enumerate(arrays, 1): indexed_x = x[idx] - for x_idx in ah.ndindex(indexed_x.shape): + for x_idx in sh.ndindex(indexed_x.shape): out_idx = next(out_indices) assert_equals( "concat", @@ -167,7 +154,7 @@ def test_expand_dims(x, axis): ph.assert_result_shape("expand_dims", (x.shape,), out.shape, shape) assert_array_ndindex( - "expand_dims", x, ah.ndindex(x.shape), out, ah.ndindex(out.shape) + "expand_dims", x, sh.ndindex(x.shape), out, sh.ndindex(out.shape) ) @@ -186,7 +173,7 @@ def test_squeeze(x, data): ) axes = (axis,) if isinstance(axis, int) else axis - axes = normalise_axis(axes, x.ndim) + axes = sh.normalise_axis(axes, x.ndim) squeezable_axes = [i for i, side in enumerate(x.shape) if side == 1] if any(i not in squeezable_axes for i in axes): @@ -205,7 +192,7 @@ def test_squeeze(x, data): shape = tuple(shape) ph.assert_result_shape("squeeze", (x.shape,), out.shape, shape, axis=axis) - assert_array_ndindex("squeeze", x, ah.ndindex(x.shape), out, ah.ndindex(out.shape)) + assert_array_ndindex("squeeze", x, sh.ndindex(x.shape), out, sh.ndindex(out.shape)) @given( @@ -225,8 +212,8 @@ def test_flip(x, data): ph.assert_dtype("flip", x.dtype, out.dtype) - _axes = normalise_axis(kw.get("axis", None), x.ndim) - for indices in axes_ndindex(x.shape, _axes): + _axes = sh.normalise_axis(kw.get("axis", None), x.ndim) + for indices in sh.axes_ndindex(x.shape, _axes): reverse_indices = indices[::-1] assert_array_ndindex("flip", x, indices, out, reverse_indices) @@ -254,7 +241,7 @@ def test_permute_dims(x, axes): shape = tuple(shape) ph.assert_result_shape("permute_dims", (x.shape,), out.shape, shape, axes=axes) - indices = list(ah.ndindex(x.shape)) + indices = list(sh.ndindex(x.shape)) permuted_indices = [tuple(idx[axis] for axis in axes) for idx in indices] assert_array_ndindex("permute_dims", x, indices, out, permuted_indices) @@ -289,7 +276,7 @@ def test_reshape(x, data): _shape = tuple(_shape) ph.assert_result_shape("reshape", (x.shape,), out.shape, _shape, shape=shape) - assert_array_ndindex("reshape", x, ah.ndindex(x.shape), out, ah.ndindex(out.shape)) + assert_array_ndindex("reshape", x, sh.ndindex(x.shape), out, sh.ndindex(out.shape)) @pytest.mark.skip(reason="faulty test logic") # TODO @@ -319,14 +306,14 @@ def test_roll(x, data): if kw.get("axis", None) is None: assert isinstance(shift, int) # sanity check - indices = list(ah.ndindex(x.shape)) + indices = list(sh.ndindex(x.shape)) shifted_indices = deque(indices) shifted_indices.rotate(-shift) assert_array_ndindex("roll", x, indices, out, shifted_indices) else: _shift = (shift,) if isinstance(shift, int) else shift - axes = normalise_axis(kw["axis"], x.ndim) - all_indices = list(ah.ndindex(x.shape)) + axes = sh.normalise_axis(kw["axis"], x.ndim) + all_indices = list(sh.ndindex(x.shape)) for s, a in zip(_shift, axes): side = x.shape[a] for i in range(side): @@ -365,13 +352,13 @@ def test_stack(shape, dtypes, kw, data): "stack", tuple(x.shape for x in arrays), out.shape, _shape, **kw ) - out_indices = ah.ndindex(out.shape) - for idx in axis_ndindex(arrays[0].shape, axis=_axis): + out_indices = sh.ndindex(out.shape) + for idx in sh.axis_ndindex(arrays[0].shape, axis=_axis): f_idx = ", ".join(str(i) if isinstance(i, int) else ":" for i in idx) print(f"{f_idx=}") for x_num, x in enumerate(arrays, 1): indexed_x = x[idx] - for x_idx in ah.ndindex(indexed_x.shape): + for x_idx in sh.ndindex(indexed_x.shape): out_idx = next(out_indices) assert_equals( "stack", diff --git a/array_api_tests/test_searching_functions.py b/array_api_tests/test_searching_functions.py index 88a9da26..dff4590f 100644 --- a/array_api_tests/test_searching_functions.py +++ b/array_api_tests/test_searching_functions.py @@ -2,19 +2,14 @@ from hypothesis import strategies as st from . import _array_module as xp -from . import array_helpers as ah from . import dtype_helpers as dh from . import hypothesis_helpers as hh from . import pytest_helpers as ph +from . import shape_helpers as sh from . import xps from .algos import broadcast_shapes from .test_manipulation_functions import assert_equals as assert_equals_ -from .test_statistical_functions import ( - assert_equals, - assert_keepdimable_shape, - axes_ndindex, - normalise_axis, -) +from .test_statistical_functions import assert_equals, assert_keepdimable_shape from .typing import DataType @@ -47,12 +42,12 @@ def test_argmax(x, data): out = xp.argmax(x, **kw) assert_default_index("argmax", out.dtype) - axes = normalise_axis(kw.get("axis", None), x.ndim) + axes = sh.normalise_axis(kw.get("axis", None), x.ndim) assert_keepdimable_shape( "argmax", out.shape, x.shape, axes, kw.get("keepdims", False), **kw ) scalar_type = dh.get_scalar_type(x.dtype) - for indices, out_idx in zip(axes_ndindex(x.shape, axes), ah.ndindex(out.shape)): + for indices, out_idx in zip(sh.axes_ndindex(x.shape, axes), sh.ndindex(out.shape)): max_i = int(out[out_idx]) elements = [] for idx in indices: @@ -82,12 +77,12 @@ def test_argmin(x, data): out = xp.argmin(x, **kw) assert_default_index("argmin", out.dtype) - axes = normalise_axis(kw.get("axis", None), x.ndim) + axes = sh.normalise_axis(kw.get("axis", None), x.ndim) assert_keepdimable_shape( "argmin", out.shape, x.shape, axes, kw.get("keepdims", False), **kw ) scalar_type = dh.get_scalar_type(x.dtype) - for indices, out_idx in zip(axes_ndindex(x.shape, axes), ah.ndindex(out.shape)): + for indices, out_idx in zip(sh.axes_ndindex(x.shape, axes), sh.ndindex(out.shape)): min_i = int(out[out_idx]) elements = [] for idx in indices: @@ -114,11 +109,11 @@ def test_nonzero(x): assert_default_index("nonzero", out[i].dtype, repr_name=f"out[{i}].dtype") indices = [] if x.dtype == xp.bool: - for idx in ah.ndindex(x.shape): + for idx in sh.ndindex(x.shape): if x[idx]: indices.append(idx) else: - for idx in ah.ndindex(x.shape): + for idx in sh.ndindex(x.shape): if x[idx] != 0: indices.append(idx) if x.ndim == 0: @@ -154,7 +149,7 @@ def test_where(shapes, dtypes, data): _cond = xp.broadcast_to(cond, shape) _x1 = xp.broadcast_to(x1, shape) _x2 = xp.broadcast_to(x2, shape) - for idx in ah.ndindex(shape): + for idx in sh.ndindex(shape): if _cond[idx]: assert_equals_("where", f"_x1[{idx}]", _x1[idx], f"out[{idx}]", out[idx]) else: diff --git a/array_api_tests/test_set_functions.py b/array_api_tests/test_set_functions.py index 544820f0..214df6d0 100644 --- a/array_api_tests/test_set_functions.py +++ b/array_api_tests/test_set_functions.py @@ -5,10 +5,10 @@ from hypothesis import assume, given from . import _array_module as xp -from . import array_helpers as ah from . import dtype_helpers as dh from . import hypothesis_helpers as hh from . import pytest_helpers as ph +from . import shape_helpers as sh from . import xps from .test_searching_functions import assert_default_index @@ -51,13 +51,13 @@ def test_unique_all(x): scalar_type = dh.get_scalar_type(out.values.dtype) counts = defaultdict(int) firsts = {} - for i, idx in enumerate(ah.ndindex(x.shape)): + for i, idx in enumerate(sh.ndindex(x.shape)): val = scalar_type(x[idx]) if counts[val] == 0: firsts[val] = i counts[val] += 1 - for idx in ah.ndindex(out.indices.shape): + for idx in sh.ndindex(out.indices.shape): val = scalar_type(out.values[idx]) if math.isnan(val): break @@ -68,7 +68,7 @@ def test_unique_all(x): f"but first occurence of {val} is at {expected}" ) - for idx in ah.ndindex(out.inverse_indices.shape): + for idx in sh.ndindex(out.inverse_indices.shape): ridx = int(out.inverse_indices[idx]) val = out.values[ridx] expected = x[idx] @@ -83,7 +83,7 @@ def test_unique_all(x): vals_idx = {} nans = 0 - for idx in ah.ndindex(out.values.shape): + for idx in sh.ndindex(out.values.shape): val = scalar_type(out.values[idx]) count = int(out.counts[idx]) if math.isnan(val): @@ -128,10 +128,10 @@ def test_unique_counts(x): out.counts.shape == out.values.shape ), f"{out.counts.shape=}, but should be {out.values.shape=}" scalar_type = dh.get_scalar_type(out.values.dtype) - counts = Counter(scalar_type(x[idx]) for idx in ah.ndindex(x.shape)) + counts = Counter(scalar_type(x[idx]) for idx in sh.ndindex(x.shape)) vals_idx = {} nans = 0 - for idx in ah.ndindex(out.values.shape): + for idx in sh.ndindex(out.values.shape): val = scalar_type(out.values[idx]) count = int(out.counts[idx]) if math.isnan(val): @@ -180,10 +180,10 @@ def test_unique_inverse(x): repr_name="out.inverse_indices.shape", ) scalar_type = dh.get_scalar_type(out.values.dtype) - distinct = set(scalar_type(x[idx]) for idx in ah.ndindex(x.shape)) + distinct = set(scalar_type(x[idx]) for idx in sh.ndindex(x.shape)) vals_idx = {} nans = 0 - for idx in ah.ndindex(out.values.shape): + for idx in sh.ndindex(out.values.shape): val = scalar_type(out.values[idx]) if math.isnan(val): nans += 1 @@ -195,7 +195,7 @@ def test_unique_inverse(x): val not in vals_idx.keys() ), f"out.values[{idx}]={val}, but {val} is also in out[{vals_idx[val]}]" vals_idx[val] = idx - for idx in ah.ndindex(out.inverse_indices.shape): + for idx in sh.ndindex(out.inverse_indices.shape): ridx = int(out.inverse_indices[idx]) val = out.values[ridx] expected = x[idx] @@ -218,10 +218,10 @@ def test_unique_values(x): out = xp.unique_values(x) ph.assert_dtype("unique_values", x.dtype, out.dtype) scalar_type = dh.get_scalar_type(x.dtype) - distinct = set(scalar_type(x[idx]) for idx in ah.ndindex(x.shape)) + distinct = set(scalar_type(x[idx]) for idx in sh.ndindex(x.shape)) vals_idx = {} nans = 0 - for idx in ah.ndindex(out.shape): + for idx in sh.ndindex(out.shape): val = scalar_type(out[idx]) if math.isnan(val): nans += 1 diff --git a/array_api_tests/test_sorting.py b/array_api_tests/test_sorting.py index 14f76ff9..2578ea99 100644 --- a/array_api_tests/test_sorting.py +++ b/array_api_tests/test_sorting.py @@ -6,10 +6,11 @@ from . import dtype_helpers as dh from . import hypothesis_helpers as hh from . import pytest_helpers as ph +from . import shape_helpers as sh from . import xps from .test_manipulation_functions import assert_equals as assert_equals_ from .test_searching_functions import assert_default_index -from .test_statistical_functions import assert_equals, axes_ndindex, normalise_axis +from .test_statistical_functions import assert_equals # TODO: Test with signed zeros and NaNs (and ignore them somehow) @@ -39,10 +40,10 @@ def test_argsort(x, data): assert_default_index("sort", out.dtype) ph.assert_shape("sort", out.shape, x.shape, **kw) axis = kw.get("axis", -1) - axes = normalise_axis(axis, x.ndim) + axes = sh.normalise_axis(axis, x.ndim) descending = kw.get("descending", False) scalar_type = dh.get_scalar_type(x.dtype) - for indices in axes_ndindex(x.shape, axes): + for indices in sh.axes_ndindex(x.shape, axes): elements = [scalar_type(x[idx]) for idx in indices] indices_order = sorted(range(len(indices)), key=elements.__getitem__) if descending: @@ -79,10 +80,10 @@ def test_sort(x, data): ph.assert_dtype("sort", out.dtype, x.dtype) ph.assert_shape("sort", out.shape, x.shape, **kw) axis = kw.get("axis", -1) - axes = normalise_axis(axis, x.ndim) + axes = sh.normalise_axis(axis, x.ndim) descending = kw.get("descending", False) scalar_type = dh.get_scalar_type(x.dtype) - for indices in axes_ndindex(x.shape, axes): + for indices in sh.axes_ndindex(x.shape, axes): elements = [scalar_type(x[idx]) for idx in indices] indices_order = sorted( range(len(indices)), key=elements.__getitem__, reverse=descending diff --git a/array_api_tests/test_statistical_functions.py b/array_api_tests/test_statistical_functions.py index c0a03a3d..c858b6c6 100644 --- a/array_api_tests/test_statistical_functions.py +++ b/array_api_tests/test_statistical_functions.py @@ -1,16 +1,15 @@ import math -from itertools import product -from typing import Iterator, List, Optional, Tuple, Union +from typing import Optional, Tuple from hypothesis import assume, given from hypothesis import strategies as st from hypothesis.control import reject from . import _array_module as xp -from . import array_helpers as ah from . import dtype_helpers as dh from . import hypothesis_helpers as hh from . import pytest_helpers as ph +from . import shape_helpers as sh from . import xps from .typing import DataType, Scalar, ScalarType, Shape @@ -20,39 +19,6 @@ def kwarg_dtypes(dtype: DataType) -> st.SearchStrategy[Optional[DataType]]: return st.none() | st.sampled_from(dtypes) -def normalise_axis( - axis: Optional[Union[int, Tuple[int, ...]]], ndim: int -) -> Tuple[int, ...]: - if axis is None: - return tuple(range(ndim)) - axes = axis if isinstance(axis, tuple) else (axis,) - axes = tuple(axis if axis >= 0 else ndim + axis for axis in axes) - return axes - - -def axes_ndindex(shape: Shape, axes: Tuple[int, ...]) -> Iterator[List[Shape]]: - """Generate indices that index all elements except in `axes` dimensions""" - base_indices = [] - axes_indices = [] - for axis, side in enumerate(shape): - if axis in axes: - base_indices.append([None]) - axes_indices.append(range(side)) - else: - base_indices.append(range(side)) - axes_indices.append([None]) - for base_idx in product(*base_indices): - indices = [] - for idx in product(*axes_indices): - idx = list(idx) - for axis, side in enumerate(idx): - if axis not in axes: - idx[axis] = base_idx[axis] - idx = tuple(idx) - indices.append(idx) - yield list(indices) - - def assert_keepdimable_shape( func_name: str, out_shape: Shape, @@ -105,12 +71,12 @@ def test_max(x, data): out = xp.max(x, **kw) ph.assert_dtype("max", x.dtype, out.dtype) - _axes = normalise_axis(kw.get("axis", None), x.ndim) + _axes = sh.normalise_axis(kw.get("axis", None), x.ndim) assert_keepdimable_shape( "max", out.shape, x.shape, _axes, kw.get("keepdims", False), **kw ) scalar_type = dh.get_scalar_type(out.dtype) - for indices, out_idx in zip(axes_ndindex(x.shape, _axes), ah.ndindex(out.shape)): + for indices, out_idx in zip(sh.axes_ndindex(x.shape, _axes), sh.ndindex(out.shape)): max_ = scalar_type(out[out_idx]) elements = [] for idx in indices: @@ -134,11 +100,11 @@ def test_mean(x, data): out = xp.mean(x, **kw) ph.assert_dtype("mean", x.dtype, out.dtype) - _axes = normalise_axis(kw.get("axis", None), x.ndim) + _axes = sh.normalise_axis(kw.get("axis", None), x.ndim) assert_keepdimable_shape( "mean", out.shape, x.shape, _axes, kw.get("keepdims", False), **kw ) - for indices, out_idx in zip(axes_ndindex(x.shape, _axes), ah.ndindex(out.shape)): + for indices, out_idx in zip(sh.axes_ndindex(x.shape, _axes), sh.ndindex(out.shape)): mean = float(out[out_idx]) assume(not math.isinf(mean)) # mean may become inf due to internal overflows elements = [] @@ -163,12 +129,12 @@ def test_min(x, data): out = xp.min(x, **kw) ph.assert_dtype("min", x.dtype, out.dtype) - _axes = normalise_axis(kw.get("axis", None), x.ndim) + _axes = sh.normalise_axis(kw.get("axis", None), x.ndim) assert_keepdimable_shape( "min", out.shape, x.shape, _axes, kw.get("keepdims", False), **kw ) scalar_type = dh.get_scalar_type(out.dtype) - for indices, out_idx in zip(axes_ndindex(x.shape, _axes), ah.ndindex(out.shape)): + for indices, out_idx in zip(sh.axes_ndindex(x.shape, _axes), sh.ndindex(out.shape)): min_ = scalar_type(out[out_idx]) elements = [] for idx in indices: @@ -222,12 +188,12 @@ def test_prod(x, data): else: _dtype = dtype ph.assert_dtype("prod", x.dtype, out.dtype, _dtype) - _axes = normalise_axis(kw.get("axis", None), x.ndim) + _axes = sh.normalise_axis(kw.get("axis", None), x.ndim) assert_keepdimable_shape( "prod", out.shape, x.shape, _axes, kw.get("keepdims", False), **kw ) scalar_type = dh.get_scalar_type(out.dtype) - for indices, out_idx in zip(axes_ndindex(x.shape, _axes), ah.ndindex(out.shape)): + for indices, out_idx in zip(sh.axes_ndindex(x.shape, _axes), sh.ndindex(out.shape)): prod = scalar_type(out[out_idx]) assume(math.isfinite(prod)) elements = [] @@ -251,7 +217,7 @@ def test_prod(x, data): ) def test_std(x, data): axis = data.draw(hh.axes(x.ndim), label="axis") - _axes = normalise_axis(axis, x.ndim) + _axes = sh.normalise_axis(axis, x.ndim) N = sum(side for axis, side in enumerate(x.shape) if axis not in _axes) correction = data.draw( st.floats(0.0, N, allow_infinity=False, allow_nan=False) | st.integers(0, N), @@ -320,12 +286,12 @@ def test_sum(x, data): else: _dtype = dtype ph.assert_dtype("sum", x.dtype, out.dtype, _dtype) - _axes = normalise_axis(kw.get("axis", None), x.ndim) + _axes = sh.normalise_axis(kw.get("axis", None), x.ndim) assert_keepdimable_shape( "sum", out.shape, x.shape, _axes, kw.get("keepdims", False), **kw ) scalar_type = dh.get_scalar_type(out.dtype) - for indices, out_idx in zip(axes_ndindex(x.shape, _axes), ah.ndindex(out.shape)): + for indices, out_idx in zip(sh.axes_ndindex(x.shape, _axes), sh.ndindex(out.shape)): sum_ = scalar_type(out[out_idx]) assume(math.isfinite(sum_)) elements = [] @@ -349,7 +315,7 @@ def test_sum(x, data): ) def test_var(x, data): axis = data.draw(hh.axes(x.ndim), label="axis") - _axes = normalise_axis(axis, x.ndim) + _axes = sh.normalise_axis(axis, x.ndim) N = sum(side for axis, side in enumerate(x.shape) if axis not in _axes) correction = data.draw( st.floats(0.0, N, allow_infinity=False, allow_nan=False) | st.integers(0, N), diff --git a/array_api_tests/test_utility_functions.py b/array_api_tests/test_utility_functions.py index fe95ca01..5c107ecb 100644 --- a/array_api_tests/test_utility_functions.py +++ b/array_api_tests/test_utility_functions.py @@ -2,17 +2,12 @@ from hypothesis import strategies as st from . import _array_module as xp -from . import array_helpers as ah from . import dtype_helpers as dh from . import hypothesis_helpers as hh from . import pytest_helpers as ph +from . import shape_helpers as sh from . import xps -from .test_statistical_functions import ( - assert_equals, - assert_keepdimable_shape, - axes_ndindex, - normalise_axis, -) +from .test_statistical_functions import assert_equals, assert_keepdimable_shape @given( @@ -25,12 +20,12 @@ def test_all(x, data): out = xp.all(x, **kw) ph.assert_dtype("all", x.dtype, out.dtype, xp.bool) - _axes = normalise_axis(kw.get("axis", None), x.ndim) + _axes = sh.normalise_axis(kw.get("axis", None), x.ndim) assert_keepdimable_shape( "all", out.shape, x.shape, _axes, kw.get("keepdims", False), **kw ) scalar_type = dh.get_scalar_type(x.dtype) - for indices, out_idx in zip(axes_ndindex(x.shape, _axes), ah.ndindex(out.shape)): + for indices, out_idx in zip(sh.axes_ndindex(x.shape, _axes), sh.ndindex(out.shape)): result = bool(out[out_idx]) elements = [] for idx in indices: @@ -50,12 +45,12 @@ def test_any(x, data): out = xp.any(x, **kw) ph.assert_dtype("any", x.dtype, out.dtype, xp.bool) - _axes = normalise_axis(kw.get("axis", None), x.ndim) + _axes = sh.normalise_axis(kw.get("axis", None), x.ndim) assert_keepdimable_shape( "any", out.shape, x.shape, _axes, kw.get("keepdims", False), **kw ) scalar_type = dh.get_scalar_type(x.dtype) - for indices, out_idx in zip(axes_ndindex(x.shape, _axes), ah.ndindex(out.shape)): + for indices, out_idx in zip(sh.axes_ndindex(x.shape, _axes), sh.ndindex(out.shape)): result = bool(out[out_idx]) elements = [] for idx in indices: From a5fd48f25f576a45b72e4ba0755ae94369271f80 Mon Sep 17 00:00:00 2001 From: Matthew Date: Thu, 16 Dec 2021 16:52:17 +0000 Subject: [PATCH 58/60] Move assertion helpers to `pytest_helpers.py` --- array_api_tests/pytest_helpers.py | 67 ++++++++++++++++++- .../test_manipulation_functions.py | 19 +----- array_api_tests/test_searching_functions.py | 34 ++++------ array_api_tests/test_set_functions.py | 15 +++-- array_api_tests/test_sorting.py | 9 +-- array_api_tests/test_statistical_functions.py | 66 ++++-------------- array_api_tests/test_utility_functions.py | 9 ++- 7 files changed, 111 insertions(+), 108 deletions(-) diff --git a/array_api_tests/pytest_helpers.py b/array_api_tests/pytest_helpers.py index fa9e8b87..c8fd1fdb 100644 --- a/array_api_tests/pytest_helpers.py +++ b/array_api_tests/pytest_helpers.py @@ -1,12 +1,13 @@ -from .algos import broadcast_shapes import math from inspect import getfullargspec from typing import Any, Dict, Optional, Tuple, Union +from . import _array_module as xp from . import array_helpers as ah from . import dtype_helpers as dh from . import function_stubs -from .typing import Array, DataType, Scalar, Shape +from .algos import broadcast_shapes +from .typing import Array, DataType, Scalar, ScalarType, Shape __all__ = [ "raises", @@ -17,8 +18,10 @@ "assert_kw_dtype", "assert_default_float", "assert_default_int", + "assert_default_index", "assert_shape", "assert_result_shape", + "assert_keepdimable_shape", "assert_fill", ] @@ -117,6 +120,15 @@ def assert_default_int(func_name: str, dtype: DataType): assert dtype == dh.default_int, msg +def assert_default_index(func_name: str, dtype: DataType, repr_name="out.dtype"): + f_dtype = dh.dtype_to_name[dtype] + msg = ( + f"{repr_name}={f_dtype}, should be the default index dtype, " + f"which is either int32 or int64 [{func_name}()]" + ) + assert dtype in (xp.int32, xp.int64), msg + + def assert_shape( func_name: str, out_shape: Union[int, Shape], @@ -155,6 +167,57 @@ def assert_result_shape( assert out_shape == expected, msg +def assert_keepdimable_shape( + func_name: str, + out_shape: Shape, + in_shape: Shape, + axes: Tuple[int, ...], + keepdims: bool, + /, + **kw, +): + if keepdims: + shape = tuple(1 if axis in axes else side for axis, side in enumerate(in_shape)) + else: + shape = tuple(side for axis, side in enumerate(in_shape) if axis not in axes) + assert_shape(func_name, out_shape, shape, **kw) + + +def assert_0d_equals( + func_name: str, x_repr: str, x_val: Array, out_repr: str, out_val: Array, **kw +): + msg = ( + f"{out_repr}={out_val}, should be {x_repr}={x_val} " + f"[{func_name}({fmt_kw(kw)})]" + ) + if dh.is_float_dtype(out_val.dtype) and xp.isnan(out_val): + assert xp.isnan(x_val), msg + else: + assert x_val == out_val, msg + + +def assert_scalar_equals( + func_name: str, + type_: ScalarType, + idx: Shape, + out: Scalar, + expected: Scalar, + /, + **kw, +): + out_repr = "out" if idx == () else f"out[{idx}]" + f_func = f"{func_name}({fmt_kw(kw)})" + if type_ is bool or type_ is int: + msg = f"{out_repr}={out}, should be {expected} [{f_func}]" + assert out == expected, msg + elif math.isnan(expected): + msg = f"{out_repr}={out}, should be {expected} [{f_func}]" + assert math.isnan(out), msg + else: + msg = f"{out_repr}={out}, should be roughly {expected} [{f_func}]" + assert math.isclose(out, expected, rel_tol=0.25, abs_tol=1), msg + + def assert_fill( func_name: str, fill_value: Scalar, dtype: DataType, out: Array, /, **kw ): diff --git a/array_api_tests/test_manipulation_functions.py b/array_api_tests/test_manipulation_functions.py index 433b096a..da2ba385 100644 --- a/array_api_tests/test_manipulation_functions.py +++ b/array_api_tests/test_manipulation_functions.py @@ -44,19 +44,6 @@ def assert_array_ndindex( assert out[out_idx] == x[x_idx], msg -def assert_equals( - func_name: str, x_repr: str, x_val: Array, out_repr: str, out_val: Array, **kw -): - msg = ( - f"{out_repr}={out_val}, should be {x_repr}={x_val} " - f"[{func_name}({ph.fmt_kw(kw)})]" - ) - if dh.is_float_dtype(out_val.dtype) and xp.isnan(out_val): - assert xp.isnan(x_val), msg - else: - assert x_val == out_val, msg - - @st.composite def concat_shapes(draw, shape, axis): shape = list(shape) @@ -104,7 +91,7 @@ def test_concat(dtypes, kw, data): for x_num, x in enumerate(arrays, 1): for x_idx in sh.ndindex(x.shape): out_i = next(out_indices) - assert_equals( + ph.assert_0d_equals( "concat", f"x{x_num}[{x_idx}]", x[x_idx], @@ -120,7 +107,7 @@ def test_concat(dtypes, kw, data): indexed_x = x[idx] for x_idx in sh.ndindex(indexed_x.shape): out_idx = next(out_indices) - assert_equals( + ph.assert_0d_equals( "concat", f"x{x_num}[{f_idx}][{x_idx}]", indexed_x[x_idx], @@ -360,7 +347,7 @@ def test_stack(shape, dtypes, kw, data): indexed_x = x[idx] for x_idx in sh.ndindex(indexed_x.shape): out_idx = next(out_indices) - assert_equals( + ph.assert_0d_equals( "stack", f"x{x_num}[{f_idx}][{x_idx}]", indexed_x[x_idx], diff --git a/array_api_tests/test_searching_functions.py b/array_api_tests/test_searching_functions.py index dff4590f..244e7c24 100644 --- a/array_api_tests/test_searching_functions.py +++ b/array_api_tests/test_searching_functions.py @@ -8,18 +8,6 @@ from . import shape_helpers as sh from . import xps from .algos import broadcast_shapes -from .test_manipulation_functions import assert_equals as assert_equals_ -from .test_statistical_functions import assert_equals, assert_keepdimable_shape -from .typing import DataType - - -def assert_default_index(func_name: str, dtype: DataType, repr_name="out.dtype"): - f_dtype = dh.dtype_to_name[dtype] - msg = ( - f"{repr_name}={f_dtype}, should be the default index dtype, " - f"which is either int32 or int64 [{func_name}()]" - ) - assert dtype in (xp.int32, xp.int64), msg @given( @@ -41,9 +29,9 @@ def test_argmax(x, data): out = xp.argmax(x, **kw) - assert_default_index("argmax", out.dtype) + ph.assert_default_index("argmax", out.dtype) axes = sh.normalise_axis(kw.get("axis", None), x.ndim) - assert_keepdimable_shape( + ph.assert_keepdimable_shape( "argmax", out.shape, x.shape, axes, kw.get("keepdims", False), **kw ) scalar_type = dh.get_scalar_type(x.dtype) @@ -54,7 +42,7 @@ def test_argmax(x, data): s = scalar_type(x[idx]) elements.append(s) expected = max(range(len(elements)), key=elements.__getitem__) - assert_equals("argmax", int, out_idx, max_i, expected) + ph.assert_scalar_equals("argmax", int, out_idx, max_i, expected) @given( @@ -76,9 +64,9 @@ def test_argmin(x, data): out = xp.argmin(x, **kw) - assert_default_index("argmin", out.dtype) + ph.assert_default_index("argmin", out.dtype) axes = sh.normalise_axis(kw.get("axis", None), x.ndim) - assert_keepdimable_shape( + ph.assert_keepdimable_shape( "argmin", out.shape, x.shape, axes, kw.get("keepdims", False), **kw ) scalar_type = dh.get_scalar_type(x.dtype) @@ -89,7 +77,7 @@ def test_argmin(x, data): s = scalar_type(x[idx]) elements.append(s) expected = min(range(len(elements)), key=elements.__getitem__) - assert_equals("argmin", int, out_idx, min_i, expected) + ph.assert_scalar_equals("argmin", int, out_idx, min_i, expected) # TODO: skip if opted out @@ -106,7 +94,7 @@ def test_nonzero(x): assert ( out[i].size == size ), f"out[{i}].size={x.size}, but should be out[0].size={size}" - assert_default_index("nonzero", out[i].dtype, repr_name=f"out[{i}].dtype") + ph.assert_default_index("nonzero", out[i].dtype, repr_name=f"out[{i}].dtype") indices = [] if x.dtype == xp.bool: for idx in sh.ndindex(x.shape): @@ -151,6 +139,10 @@ def test_where(shapes, dtypes, data): _x2 = xp.broadcast_to(x2, shape) for idx in sh.ndindex(shape): if _cond[idx]: - assert_equals_("where", f"_x1[{idx}]", _x1[idx], f"out[{idx}]", out[idx]) + ph.assert_0d_equals( + "where", f"_x1[{idx}]", _x1[idx], f"out[{idx}]", out[idx] + ) else: - assert_equals_("where", f"_x2[{idx}]", _x2[idx], f"out[{idx}]", out[idx]) + ph.assert_0d_equals( + "where", f"_x2[{idx}]", _x2[idx], f"out[{idx}]", out[idx] + ) diff --git a/array_api_tests/test_set_functions.py b/array_api_tests/test_set_functions.py index 214df6d0..02660c6a 100644 --- a/array_api_tests/test_set_functions.py +++ b/array_api_tests/test_set_functions.py @@ -10,7 +10,6 @@ from . import pytest_helpers as ph from . import shape_helpers as sh from . import xps -from .test_searching_functions import assert_default_index @given( @@ -29,11 +28,15 @@ def test_unique_all(x): ph.assert_dtype( "unique_all", x.dtype, out.values.dtype, repr_name="out.values.dtype" ) - assert_default_index("unique_all", out.indices.dtype, repr_name="out.indices.dtype") - assert_default_index( + ph.assert_default_index( + "unique_all", out.indices.dtype, repr_name="out.indices.dtype" + ) + ph.assert_default_index( "unique_all", out.inverse_indices.dtype, repr_name="out.inverse_indices.dtype" ) - assert_default_index("unique_all", out.counts.dtype, repr_name="out.counts.dtype") + ph.assert_default_index( + "unique_all", out.counts.dtype, repr_name="out.counts.dtype" + ) assert ( out.indices.shape == out.values.shape @@ -121,7 +124,7 @@ def test_unique_counts(x): ph.assert_dtype( "unique_counts", x.dtype, out.values.dtype, repr_name="out.values.dtype" ) - assert_default_index( + ph.assert_default_index( "unique_counts", out.counts.dtype, repr_name="out.counts.dtype" ) assert ( @@ -168,7 +171,7 @@ def test_unique_inverse(x): ph.assert_dtype( "unique_inverse", x.dtype, out.values.dtype, repr_name="out.values.dtype" ) - assert_default_index( + ph.assert_default_index( "unique_inverse", out.inverse_indices.dtype, repr_name="out.inverse_indices.dtype", diff --git a/array_api_tests/test_sorting.py b/array_api_tests/test_sorting.py index 2578ea99..0c7334cc 100644 --- a/array_api_tests/test_sorting.py +++ b/array_api_tests/test_sorting.py @@ -8,9 +8,6 @@ from . import pytest_helpers as ph from . import shape_helpers as sh from . import xps -from .test_manipulation_functions import assert_equals as assert_equals_ -from .test_searching_functions import assert_default_index -from .test_statistical_functions import assert_equals # TODO: Test with signed zeros and NaNs (and ignore them somehow) @@ -37,7 +34,7 @@ def test_argsort(x, data): out = xp.argsort(x, **kw) - assert_default_index("sort", out.dtype) + ph.assert_default_index("sort", out.dtype) ph.assert_shape("sort", out.shape, x.shape, **kw) axis = kw.get("axis", -1) axes = sh.normalise_axis(axis, x.ndim) @@ -50,7 +47,7 @@ def test_argsort(x, data): # sorted(..., reverse=descending) doesn't always work indices_order = reversed(indices_order) for idx, o in zip(indices, indices_order): - assert_equals("argsort", int, idx, int(out[idx]), o) + ph.assert_scalar_equals("argsort", int, idx, int(out[idx]), o) # TODO: Test with signed zeros and NaNs (and ignore them somehow) @@ -90,7 +87,7 @@ def test_sort(x, data): ) x_indices = [indices[o] for o in indices_order] for out_idx, x_idx in zip(indices, x_indices): - assert_equals_( + ph.assert_0d_equals( "sort", f"x[{x_idx}]", x[x_idx], diff --git a/array_api_tests/test_statistical_functions.py b/array_api_tests/test_statistical_functions.py index c858b6c6..c2fb33db 100644 --- a/array_api_tests/test_statistical_functions.py +++ b/array_api_tests/test_statistical_functions.py @@ -1,5 +1,5 @@ import math -from typing import Optional, Tuple +from typing import Optional from hypothesis import assume, given from hypothesis import strategies as st @@ -11,7 +11,7 @@ from . import pytest_helpers as ph from . import shape_helpers as sh from . import xps -from .typing import DataType, Scalar, ScalarType, Shape +from .typing import DataType def kwarg_dtypes(dtype: DataType) -> st.SearchStrategy[Optional[DataType]]: @@ -19,44 +19,6 @@ def kwarg_dtypes(dtype: DataType) -> st.SearchStrategy[Optional[DataType]]: return st.none() | st.sampled_from(dtypes) -def assert_keepdimable_shape( - func_name: str, - out_shape: Shape, - in_shape: Shape, - axes: Tuple[int, ...], - keepdims: bool, - /, - **kw, -): - if keepdims: - shape = tuple(1 if axis in axes else side for axis, side in enumerate(in_shape)) - else: - shape = tuple(side for axis, side in enumerate(in_shape) if axis not in axes) - ph.assert_shape(func_name, out_shape, shape, **kw) - - -def assert_equals( - func_name: str, - type_: ScalarType, - idx: Shape, - out: Scalar, - expected: Scalar, - /, - **kw, -): - out_repr = "out" if idx == () else f"out[{idx}]" - f_func = f"{func_name}({ph.fmt_kw(kw)})" - if type_ is bool or type_ is int: - msg = f"{out_repr}={out}, should be {expected} [{f_func}]" - assert out == expected, msg - elif math.isnan(expected): - msg = f"{out_repr}={out}, should be {expected} [{f_func}]" - assert math.isnan(out), msg - else: - msg = f"{out_repr}={out}, should be roughly {expected} [{f_func}]" - assert math.isclose(out, expected, rel_tol=0.25, abs_tol=1), msg - - @given( x=xps.arrays( dtype=xps.numeric_dtypes(), @@ -72,7 +34,7 @@ def test_max(x, data): ph.assert_dtype("max", x.dtype, out.dtype) _axes = sh.normalise_axis(kw.get("axis", None), x.ndim) - assert_keepdimable_shape( + ph.assert_keepdimable_shape( "max", out.shape, x.shape, _axes, kw.get("keepdims", False), **kw ) scalar_type = dh.get_scalar_type(out.dtype) @@ -83,7 +45,7 @@ def test_max(x, data): s = scalar_type(x[idx]) elements.append(s) expected = max(elements) - assert_equals("max", scalar_type, out_idx, max_, expected) + ph.assert_scalar_equals("max", scalar_type, out_idx, max_, expected) @given( @@ -101,7 +63,7 @@ def test_mean(x, data): ph.assert_dtype("mean", x.dtype, out.dtype) _axes = sh.normalise_axis(kw.get("axis", None), x.ndim) - assert_keepdimable_shape( + ph.assert_keepdimable_shape( "mean", out.shape, x.shape, _axes, kw.get("keepdims", False), **kw ) for indices, out_idx in zip(sh.axes_ndindex(x.shape, _axes), sh.ndindex(out.shape)): @@ -112,7 +74,7 @@ def test_mean(x, data): s = float(x[idx]) elements.append(s) expected = sum(elements) / len(elements) - assert_equals("mean", float, out_idx, mean, expected) + ph.assert_scalar_equals("mean", float, out_idx, mean, expected) @given( @@ -130,7 +92,7 @@ def test_min(x, data): ph.assert_dtype("min", x.dtype, out.dtype) _axes = sh.normalise_axis(kw.get("axis", None), x.ndim) - assert_keepdimable_shape( + ph.assert_keepdimable_shape( "min", out.shape, x.shape, _axes, kw.get("keepdims", False), **kw ) scalar_type = dh.get_scalar_type(out.dtype) @@ -141,7 +103,7 @@ def test_min(x, data): s = scalar_type(x[idx]) elements.append(s) expected = min(elements) - assert_equals("min", scalar_type, out_idx, min_, expected) + ph.assert_scalar_equals("min", scalar_type, out_idx, min_, expected) @given( @@ -189,7 +151,7 @@ def test_prod(x, data): _dtype = dtype ph.assert_dtype("prod", x.dtype, out.dtype, _dtype) _axes = sh.normalise_axis(kw.get("axis", None), x.ndim) - assert_keepdimable_shape( + ph.assert_keepdimable_shape( "prod", out.shape, x.shape, _axes, kw.get("keepdims", False), **kw ) scalar_type = dh.get_scalar_type(out.dtype) @@ -204,7 +166,7 @@ def test_prod(x, data): if dh.is_int_dtype(out.dtype): m, M = dh.dtype_ranges[out.dtype] assume(m <= expected <= M) - assert_equals("prod", scalar_type, out_idx, prod, expected) + ph.assert_scalar_equals("prod", scalar_type, out_idx, prod, expected) @given( @@ -236,7 +198,7 @@ def test_std(x, data): out = xp.std(x, **kw) ph.assert_dtype("std", x.dtype, out.dtype) - assert_keepdimable_shape( + ph.assert_keepdimable_shape( "std", out.shape, x.shape, _axes, kw.get("keepdims", False), **kw ) # We can't easily test the result(s) as standard deviation methods vary a lot @@ -287,7 +249,7 @@ def test_sum(x, data): _dtype = dtype ph.assert_dtype("sum", x.dtype, out.dtype, _dtype) _axes = sh.normalise_axis(kw.get("axis", None), x.ndim) - assert_keepdimable_shape( + ph.assert_keepdimable_shape( "sum", out.shape, x.shape, _axes, kw.get("keepdims", False), **kw ) scalar_type = dh.get_scalar_type(out.dtype) @@ -302,7 +264,7 @@ def test_sum(x, data): if dh.is_int_dtype(out.dtype): m, M = dh.dtype_ranges[out.dtype] assume(m <= expected <= M) - assert_equals("sum", scalar_type, out_idx, sum_, expected) + ph.assert_scalar_equals("sum", scalar_type, out_idx, sum_, expected) @given( @@ -334,7 +296,7 @@ def test_var(x, data): out = xp.var(x, **kw) ph.assert_dtype("var", x.dtype, out.dtype) - assert_keepdimable_shape( + ph.assert_keepdimable_shape( "var", out.shape, x.shape, _axes, kw.get("keepdims", False), **kw ) # We can't easily test the result(s) as variance methods vary a lot diff --git a/array_api_tests/test_utility_functions.py b/array_api_tests/test_utility_functions.py index 5c107ecb..c10d0dbd 100644 --- a/array_api_tests/test_utility_functions.py +++ b/array_api_tests/test_utility_functions.py @@ -7,7 +7,6 @@ from . import pytest_helpers as ph from . import shape_helpers as sh from . import xps -from .test_statistical_functions import assert_equals, assert_keepdimable_shape @given( @@ -21,7 +20,7 @@ def test_all(x, data): ph.assert_dtype("all", x.dtype, out.dtype, xp.bool) _axes = sh.normalise_axis(kw.get("axis", None), x.ndim) - assert_keepdimable_shape( + ph.assert_keepdimable_shape( "all", out.shape, x.shape, _axes, kw.get("keepdims", False), **kw ) scalar_type = dh.get_scalar_type(x.dtype) @@ -32,7 +31,7 @@ def test_all(x, data): s = scalar_type(x[idx]) elements.append(s) expected = all(elements) - assert_equals("all", scalar_type, out_idx, result, expected) + ph.assert_scalar_equals("all", scalar_type, out_idx, result, expected) @given( @@ -46,7 +45,7 @@ def test_any(x, data): ph.assert_dtype("any", x.dtype, out.dtype, xp.bool) _axes = sh.normalise_axis(kw.get("axis", None), x.ndim) - assert_keepdimable_shape( + ph.assert_keepdimable_shape( "any", out.shape, x.shape, _axes, kw.get("keepdims", False), **kw ) scalar_type = dh.get_scalar_type(x.dtype) @@ -57,4 +56,4 @@ def test_any(x, data): s = scalar_type(x[idx]) elements.append(s) expected = any(elements) - assert_equals("any", scalar_type, out_idx, result, expected) + ph.assert_scalar_equals("any", scalar_type, out_idx, result, expected) From 4ad9e94cfbee0b2d9a4c2d17d2cf101e5b0ef9ce Mon Sep 17 00:00:00 2001 From: Matthew Date: Fri, 17 Dec 2021 12:43:54 +0000 Subject: [PATCH 59/60] Fix `test_roll` with bespoke axis iterator --- array_api_tests/meta/test_utils.py | 14 ++++++++ .../test_manipulation_functions.py | 32 +++++++++++-------- 2 files changed, 33 insertions(+), 13 deletions(-) diff --git a/array_api_tests/meta/test_utils.py b/array_api_tests/meta/test_utils.py index 814c62cb..3b28b9a9 100644 --- a/array_api_tests/meta/test_utils.py +++ b/array_api_tests/meta/test_utils.py @@ -2,6 +2,7 @@ from .. import shape_helpers as sh from ..test_creation_functions import frange +from ..test_manipulation_functions import roll_ndindex from ..test_signatures import extension_module @@ -68,3 +69,16 @@ def test_axis_ndindex(shape, axis, expected): ) def test_axes_ndindex(shape, axes, expected): assert list(sh.axes_ndindex(shape, axes)) == expected + + +@pytest.mark.parametrize( + "shape, shifts, axes, expected", + [ + ((1, 1), (0,), (0,), [(0, 0)]), + ((2, 1), (1, 1), (0, 1), [(1, 0), (0, 0)]), + ((2, 2), (1, 1), (0, 1), [(1, 1), (1, 0), (0, 1), (0, 0)]), + ((2, 2), (-1, 1), (0, 1), [(1, 1), (1, 0), (0, 1), (0, 0)]), + ], +) +def test_roll_ndindex(shape, shifts, axes, expected): + assert list(roll_ndindex(shape, shifts, axes)) == expected diff --git a/array_api_tests/test_manipulation_functions.py b/array_api_tests/test_manipulation_functions.py index da2ba385..5cc68f80 100644 --- a/array_api_tests/test_manipulation_functions.py +++ b/array_api_tests/test_manipulation_functions.py @@ -1,6 +1,6 @@ import math from collections import deque -from typing import Iterable, Union +from typing import Iterable, Iterator, Tuple, Union import pytest from hypothesis import assume, given @@ -33,8 +33,10 @@ def assert_array_ndindex( x_indices: Iterable[Union[int, Shape]], out: Array, out_indices: Iterable[Union[int, Shape]], + /, + **kw, ): - msg_suffix = f" [{func_name}()]\n {x=}\n{out=}" + msg_suffix = f" [{func_name}({ph.fmt_kw(kw)})]\n {x=}\n{out=}" for x_idx, out_idx in zip(x_indices, out_indices): msg = f"out[{out_idx}]={out[out_idx]}, should be x[{x_idx}]={x[x_idx]}" msg += msg_suffix @@ -266,7 +268,15 @@ def test_reshape(x, data): assert_array_ndindex("reshape", x, sh.ndindex(x.shape), out, sh.ndindex(out.shape)) -@pytest.mark.skip(reason="faulty test logic") # TODO +def roll_ndindex(shape: Shape, shifts: Tuple[int], axes: Tuple[int]) -> Iterator[Shape]: + assert len(shifts) == len(axes) # sanity check + all_shifts = [0 for _ in shape] + for s, a in zip(shifts, axes): + all_shifts[a] = s + for idx in sh.ndindex(shape): + yield tuple((i + sh) % si for i, sh, si in zip(idx, all_shifts, shape)) + + @given(xps.arrays(dtype=xps.scalar_dtypes(), shape=shared_shapes()), st.data()) def test_roll(x, data): shift_strat = st.integers(-hh.MAX_ARRAY_SIZE, hh.MAX_ARRAY_SIZE) @@ -287,6 +297,8 @@ def test_roll(x, data): out = xp.roll(x, shift, **kw) + kw = {"shift": shift, **kw} # for error messages + ph.assert_dtype("roll", x.dtype, out.dtype) ph.assert_result_shape("roll", (x.shape,), out.shape) @@ -296,18 +308,12 @@ def test_roll(x, data): indices = list(sh.ndindex(x.shape)) shifted_indices = deque(indices) shifted_indices.rotate(-shift) - assert_array_ndindex("roll", x, indices, out, shifted_indices) + assert_array_ndindex("roll", x, indices, out, shifted_indices, **kw) else: - _shift = (shift,) if isinstance(shift, int) else shift + shifts = (shift,) if isinstance(shift, int) else shift axes = sh.normalise_axis(kw["axis"], x.ndim) - all_indices = list(sh.ndindex(x.shape)) - for s, a in zip(_shift, axes): - side = x.shape[a] - for i in range(side): - indices = [idx for idx in all_indices if idx[a] == i] - shifted_indices = deque(indices) - shifted_indices.rotate(-s) - assert_array_ndindex("roll", x, indices, out, shifted_indices) + shifted_indices = roll_ndindex(x.shape, shifts, axes) + assert_array_ndindex("roll", x, sh.ndindex(x.shape), out, shifted_indices, **kw) @given( From 318b8ddd42d87640ae1b617c542dc797257b21b0 Mon Sep 17 00:00:00 2001 From: Matthew Date: Tue, 21 Dec 2021 10:34:28 +0000 Subject: [PATCH 60/60] Rename file --- array_api_tests/{test_sorting.py => test_sorting_functions.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename array_api_tests/{test_sorting.py => test_sorting_functions.py} (100%) diff --git a/array_api_tests/test_sorting.py b/array_api_tests/test_sorting_functions.py similarity index 100% rename from array_api_tests/test_sorting.py rename to array_api_tests/test_sorting_functions.py