From 5a1a19fb0bf700e29e90cc0b2b86e405f4cf22a5 Mon Sep 17 00:00:00 2001 From: Matthew Date: Tue, 9 Nov 2021 09:57:59 +0000 Subject: [PATCH 01/41] Move manipulation tests to their own file --- .../test_manipulation_functions.py | 38 +++++++++++++++++++ array_api_tests/test_type_promotion.py | 30 +-------------- 2 files changed, 39 insertions(+), 29 deletions(-) create mode 100644 array_api_tests/test_manipulation_functions.py diff --git a/array_api_tests/test_manipulation_functions.py b/array_api_tests/test_manipulation_functions.py new file mode 100644 index 00000000..79395feb --- /dev/null +++ b/array_api_tests/test_manipulation_functions.py @@ -0,0 +1,38 @@ +from hypothesis import given +from hypothesis import strategies as st + +from . import _array_module as xp +from . import dtype_helpers as dh +from . import hypothesis_helpers as hh +from . import pytest_helpers as ph +from . import xps + + +@given( + shape=hh.shapes(min_dims=1), + dtypes=hh.mutually_promotable_dtypes(None, dtypes=dh.numeric_dtypes), + data=st.data(), +) +def test_concat(shape, dtypes, data): + arrays = [] + for i, dtype in enumerate(dtypes, 1): + x = data.draw(xps.arrays(dtype=dtype, shape=shape), label=f"x{i}") + arrays.append(x) + out = xp.concat(arrays) + ph.assert_dtype("concat", dtypes, out.dtype) + # TODO + + +@given( + shape=hh.shapes(), + dtypes=hh.mutually_promotable_dtypes(None), + data=st.data(), +) +def test_stack(shape, dtypes, data): + arrays = [] + for i, dtype in enumerate(dtypes, 1): + x = data.draw(xps.arrays(dtype=dtype, shape=shape), label=f"x{i}") + arrays.append(x) + out = xp.stack(arrays) + ph.assert_dtype("stack", dtypes, out.dtype) + # TODO diff --git a/array_api_tests/test_type_promotion.py b/array_api_tests/test_type_promotion.py index d304071a..22af9526 100644 --- a/array_api_tests/test_type_promotion.py +++ b/array_api_tests/test_type_promotion.py @@ -19,7 +19,7 @@ # TODO: move tests not covering elementwise funcs/ops into standalone tests -# result_type, meshgrid, concat, stack, where, tensordor, vecdot +# result_type, meshgrid, where, tensordor, vecdot @given(hh.mutually_promotable_dtypes(None)) @@ -51,34 +51,6 @@ def test_meshgrid(dtypes, data): ph.assert_dtype("meshgrid", dtypes, x.dtype, repr_name=f"out[{i}].dtype") -@given( - shape=hh.shapes(min_dims=1), - dtypes=hh.mutually_promotable_dtypes(None, dtypes=dh.numeric_dtypes), - data=st.data(), -) -def test_concat(shape, dtypes, data): - arrays = [] - for i, dtype in enumerate(dtypes, 1): - x = data.draw(xps.arrays(dtype=dtype, shape=shape), label=f"x{i}") - arrays.append(x) - out = xp.concat(arrays) - ph.assert_dtype("concat", dtypes, out.dtype) - - -@given( - shape=hh.shapes(), - dtypes=hh.mutually_promotable_dtypes(None), - data=st.data(), -) -def test_stack(shape, dtypes, data): - arrays = [] - for i, dtype in enumerate(dtypes, 1): - x = data.draw(xps.arrays(dtype=dtype, shape=shape), label=f"x{i}") - arrays.append(x) - out = xp.stack(arrays) - ph.assert_dtype("stack", dtypes, out.dtype) - - bitwise_shift_funcs = [ "bitwise_left_shift", "bitwise_right_shift", From 6dd7f3d58a73026a2efc681d257216d454972a21 Mon Sep 17 00:00:00 2001 From: Matthew Date: Tue, 9 Nov 2021 12:10:59 +0000 Subject: [PATCH 02/41] Minor `test_concat` improvements --- array_api_tests/test_manipulation_functions.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/array_api_tests/test_manipulation_functions.py b/array_api_tests/test_manipulation_functions.py index 79395feb..9a7d1d7f 100644 --- a/array_api_tests/test_manipulation_functions.py +++ b/array_api_tests/test_manipulation_functions.py @@ -1,3 +1,5 @@ +import math + from hypothesis import given from hypothesis import strategies as st @@ -11,16 +13,23 @@ @given( shape=hh.shapes(min_dims=1), dtypes=hh.mutually_promotable_dtypes(None, dtypes=dh.numeric_dtypes), + kw=hh.kwargs(axis=st.just(0) | st.none()), # TODO: test with axis >= 1 data=st.data(), ) -def test_concat(shape, dtypes, data): +def test_concat(shape, dtypes, kw, data): arrays = [] for i, dtype in enumerate(dtypes, 1): x = data.draw(xps.arrays(dtype=dtype, shape=shape), label=f"x{i}") arrays.append(x) - out = xp.concat(arrays) + out = xp.concat(arrays, **kw) ph.assert_dtype("concat", dtypes, out.dtype) - # TODO + shapes = tuple(x.shape for x in arrays) + if kw.get("axis", 0) == 0: + pass # TODO: assert expected shape + elif kw["axis"] is None: + size = sum(math.prod(s) for s in shapes) + ph.assert_result_shape("concat", shapes, out.shape, (size,), **kw) + # TODO: assert out elements match input arrays @given( From e5bb97442b759871d8ec013bf7dbb4a9e6f64552 Mon Sep 17 00:00:00 2001 From: Matthew Date: Tue, 9 Nov 2021 19:07:08 +0000 Subject: [PATCH 03/41] Smoke all manipulation methods --- .../test_manipulation_functions.py | 77 +++++++++++++++++++ 1 file changed, 77 insertions(+) diff --git a/array_api_tests/test_manipulation_functions.py b/array_api_tests/test_manipulation_functions.py index 9a7d1d7f..eea4bb00 100644 --- a/array_api_tests/test_manipulation_functions.py +++ b/array_api_tests/test_manipulation_functions.py @@ -9,6 +9,8 @@ from . import pytest_helpers as ph from . import xps +shared_shapes = st.shared(hh.shapes(), key="shape") + @given( shape=hh.shapes(min_dims=1), @@ -32,6 +34,81 @@ def test_concat(shape, dtypes, kw, data): # TODO: assert out elements match input arrays +@given( + x=xps.arrays(dtype=xps.scalar_dtypes(), shape=shared_shapes), + axis=shared_shapes.flatmap(lambda s: st.integers(-len(s), len(s))), +) +def test_expand_dims(x, axis): + xp.expand_dims(x, axis=axis) + # TODO + + +@given( + x=xps.arrays(dtype=xps.scalar_dtypes(), shape=shared_shapes), + kw=hh.kwargs( + axis=st.one_of( + st.none(), + shared_shapes.flatmap( + lambda s: st.none() + if len(s) == 0 + else st.integers(-len(s) + 1, len(s) - 1), + ), + ) + ), +) +def test_flip(x, kw): + xp.flip(x, **kw) + # TODO + + +@given( + x=xps.arrays(dtype=xps.scalar_dtypes(), shape=shared_shapes), + axes=shared_shapes.flatmap( + lambda s: st.lists( + st.integers(0, max(len(s) - 1, 0)), + min_size=len(s), + max_size=len(s), + unique=True, + ).map(tuple) + ), +) +def test_permute_dims(x, axes): + xp.permute_dims(x, axes) + # TODO + + +@given( + x=xps.arrays(dtype=xps.scalar_dtypes(), shape=shared_shapes), + shape=shared_shapes, # TODO: test more compatible shapes +) +def test_reshape(x, shape): + xp.reshape(x, shape) + # TODO + + +@given( + # TODO: axis arguments, update shift respectively + x=xps.arrays(dtype=xps.scalar_dtypes(), shape=shared_shapes), + shift=shared_shapes.flatmap(lambda s: st.integers(0, max(math.prod(s) - 1, 0))), +) +def test_roll(x, shift): + xp.roll(x, shift) + # TODO + + +@given( + x=xps.arrays(dtype=xps.scalar_dtypes(), shape=shared_shapes), + axis=shared_shapes.flatmap( + lambda s: st.just(0) + if len(s) == 0 + else st.integers(-len(s) + 1, len(s) - 1).filter(lambda i: s[i] == 1) + ), # TODO: tuple of axis i.e. axes +) +def test_squeeze(x, axis): + xp.squeeze(x, axis) + # TODO + + @given( shape=hh.shapes(), dtypes=hh.mutually_promotable_dtypes(None), From 41a0292c1d396dc876fa33d8b3db40c81ac4f3aa Mon Sep 17 00:00:00 2001 From: Matthew Date: Tue, 9 Nov 2021 19:45:17 +0000 Subject: [PATCH 04/41] Smoke statistical functions --- array_api_tests/test_statistical_functions.py | 54 +++++++++++++++++++ 1 file changed, 54 insertions(+) create mode 100644 array_api_tests/test_statistical_functions.py diff --git a/array_api_tests/test_statistical_functions.py b/array_api_tests/test_statistical_functions.py new file mode 100644 index 00000000..7455123c --- /dev/null +++ b/array_api_tests/test_statistical_functions.py @@ -0,0 +1,54 @@ +from hypothesis import given + +from . import _array_module as xp +from . import hypothesis_helpers as hh +from . import xps + + +# TODO generate kwargs +@given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes(min_side=1))) +def test_min(x): + xp.min(x) + # TODO + + +# TODO generate kwargs +@given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes(min_side=1))) +def test_max(x): + xp.max(x) + # TODO + + +# TODO generate kwargs +@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes(min_side=1))) +def test_mean(x): + xp.mean(x) + # TODO + + +# TODO generate kwargs +@given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes(min_side=1))) +def test_prod(x): + xp.prod(x) + # TODO + + +# TODO generate kwargs +@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes(min_side=1))) +def test_std(x): + xp.std(x) + # TODO + + +# TODO generate kwargs +@given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes(min_side=1))) +def test_sum(x): + xp.sum(x) + # TODO + + +# TODO generate kwargs +@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes(min_side=1))) +def test_var(x): + xp.var(x) + # TODO From 97997eecdb5415b19d1115fb79f3e8ae12b526c5 Mon Sep 17 00:00:00 2001 From: Matthew Date: Tue, 9 Nov 2021 20:40:18 +0000 Subject: [PATCH 05/41] Smoke searching functions --- array_api_tests/test_searching_functions.py | 41 +++++++++++++++++++ array_api_tests/test_statistical_functions.py | 14 +++---- array_api_tests/test_type_promotion.py | 2 +- 3 files changed, 49 insertions(+), 8 deletions(-) create mode 100644 array_api_tests/test_searching_functions.py diff --git a/array_api_tests/test_searching_functions.py b/array_api_tests/test_searching_functions.py new file mode 100644 index 00000000..c3686bb7 --- /dev/null +++ b/array_api_tests/test_searching_functions.py @@ -0,0 +1,41 @@ +from hypothesis import given +from hypothesis import strategies as st + +from . import _array_module as xp +from . import hypothesis_helpers as hh +from . import xps + + +# TODO: generate kwargs +@given(xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes(min_side=1))) +def test_argmin(x): + xp.argmin(x) + # TODO + + +# TODO: generate kwargs +@given(xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes(min_side=1))) +def test_argmax(x): + xp.argmax(x) + # TODO + + +# TODO: generate kwargs, skip if opted out +@given(xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes(min_side=1))) +def test_nonzero(x): + xp.nonzero(x) + # TODO + + +# TODO: skip if opted out +@given( + shapes=hh.mutually_broadcastable_shapes(3), + dtypes=hh.mutually_promotable_dtypes(), + data=st.data(), +) +def test_where(shapes, dtypes, data): + cond = data.draw(xps.arrays(dtype=xp.bool, shape=shapes[0]), label="condition") + x1 = data.draw(xps.arrays(dtype=dtypes[0], shape=shapes[1]), label="x1") + x2 = data.draw(xps.arrays(dtype=dtypes[1], shape=shapes[2]), label="x2") + xp.where(cond, x1, x2) + # TODO diff --git a/array_api_tests/test_statistical_functions.py b/array_api_tests/test_statistical_functions.py index 7455123c..98cade7d 100644 --- a/array_api_tests/test_statistical_functions.py +++ b/array_api_tests/test_statistical_functions.py @@ -5,49 +5,49 @@ from . import xps -# TODO generate kwargs +# TODO: generate kwargs @given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes(min_side=1))) def test_min(x): xp.min(x) # TODO -# TODO generate kwargs +# TODO: generate kwargs @given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes(min_side=1))) def test_max(x): xp.max(x) # TODO -# TODO generate kwargs +# TODO: generate kwargs @given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes(min_side=1))) def test_mean(x): xp.mean(x) # TODO -# TODO generate kwargs +# TODO: generate kwargs @given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes(min_side=1))) def test_prod(x): xp.prod(x) # TODO -# TODO generate kwargs +# TODO: generate kwargs @given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes(min_side=1))) def test_std(x): xp.std(x) # TODO -# TODO generate kwargs +# TODO: generate kwargs @given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes(min_side=1))) def test_sum(x): xp.sum(x) # TODO -# TODO generate kwargs +# TODO: generate kwargs @given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes(min_side=1))) def test_var(x): xp.var(x) diff --git a/array_api_tests/test_type_promotion.py b/array_api_tests/test_type_promotion.py index 22af9526..2fb669e6 100644 --- a/array_api_tests/test_type_promotion.py +++ b/array_api_tests/test_type_promotion.py @@ -19,7 +19,7 @@ # TODO: move tests not covering elementwise funcs/ops into standalone tests -# result_type, meshgrid, where, tensordor, vecdot +# result_type, meshgrid, tensordor, vecdot @given(hh.mutually_promotable_dtypes(None)) From 3efff6c0598d1c6e69ba955c5890b41f6708c64b Mon Sep 17 00:00:00 2001 From: Matthew Date: Wed, 10 Nov 2021 10:09:49 +0000 Subject: [PATCH 06/41] Smoke remaining functions --- array_api_tests/test_set_functions.py | 29 +++++++++++++++++++++++ array_api_tests/test_sorting.py | 19 +++++++++++++++ array_api_tests/test_utility_functions.py | 19 +++++++++++++++ 3 files changed, 67 insertions(+) create mode 100644 array_api_tests/test_set_functions.py create mode 100644 array_api_tests/test_sorting.py create mode 100644 array_api_tests/test_utility_functions.py diff --git a/array_api_tests/test_set_functions.py b/array_api_tests/test_set_functions.py new file mode 100644 index 00000000..856a7282 --- /dev/null +++ b/array_api_tests/test_set_functions.py @@ -0,0 +1,29 @@ +from hypothesis import given + +from . import _array_module as xp +from . import hypothesis_helpers as hh +from . import xps + + +@given(xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes())) +def test_unique_all(x): + xp.unique_all(x) + # TODO + + +@given(xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes())) +def test_unique_counts(x): + xp.unique_counts(x) + # TODO + + +@given(xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes())) +def test_unique_inverse(x): + xp.unique_inverse(x) + # TODO + + +@given(xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes())) +def test_unique_values(x): + xp.unique_values(x) + # TODO diff --git a/array_api_tests/test_sorting.py b/array_api_tests/test_sorting.py new file mode 100644 index 00000000..58179b3c --- /dev/null +++ b/array_api_tests/test_sorting.py @@ -0,0 +1,19 @@ +from hypothesis import given + +from . import _array_module as xp +from . import hypothesis_helpers as hh +from . import xps + + +# TODO: generate kwargs +@given(xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes())) +def test_argsort(x): + xp.argsort(x) + # TODO + + +# TODO: generate 0d arrays, generate kwargs +@given(xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes(min_dims=1))) +def test_sort(x): + xp.sort(x) + # TODO diff --git a/array_api_tests/test_utility_functions.py b/array_api_tests/test_utility_functions.py new file mode 100644 index 00000000..140aa85f --- /dev/null +++ b/array_api_tests/test_utility_functions.py @@ -0,0 +1,19 @@ +from hypothesis import given + +from . import _array_module as xp +from . import hypothesis_helpers as hh +from . import xps + + +# TODO: generate kwargs +@given(xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes())) +def test_any(x): + xp.any(x) + # TODO + + +# TODO: generate kwargs +@given(xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes())) +def test_all(x): + xp.all(x) + # TODO From 08b23917d97d4cf6d16656ae06e8f8bc3242e938 Mon Sep 17 00:00:00 2001 From: Matthew Date: Wed, 10 Nov 2021 10:56:50 +0000 Subject: [PATCH 07/41] Test `expand_dims()` --- array_api_tests/pytest_helpers.py | 8 ++++---- array_api_tests/test_manipulation_functions.py | 11 +++++++++-- 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/array_api_tests/pytest_helpers.py b/array_api_tests/pytest_helpers.py index 9424ba35..b138af3e 100644 --- a/array_api_tests/pytest_helpers.py +++ b/array_api_tests/pytest_helpers.py @@ -67,12 +67,14 @@ def fmt_kw(kw: Dict[str, Any]) -> str: def assert_dtype( func_name: str, - in_dtypes: Tuple[DataType, ...], + in_dtypes: Union[DataType, Tuple[DataType, ...]], out_dtype: DataType, expected: Optional[DataType] = None, *, repr_name: str = "out.dtype", ): + if not isinstance(in_dtypes, tuple): + in_dtypes = (in_dtypes,) f_in_dtypes = dh.fmt_types(in_dtypes) f_out_dtype = dh.dtype_to_name[out_dtype] if expected is None: @@ -149,9 +151,7 @@ def assert_result_shape( f_sig = f" {f_in_shapes} " if kw: f_sig += f", {fmt_kw(kw)}" - msg = ( - f"{repr_name}={out_shape}, but should be {expected} [{func_name}({f_sig})]" - ) + msg = f"{repr_name}={out_shape}, but should be {expected} [{func_name}({f_sig})]" assert out_shape == expected, msg diff --git a/array_api_tests/test_manipulation_functions.py b/array_api_tests/test_manipulation_functions.py index eea4bb00..620e8968 100644 --- a/array_api_tests/test_manipulation_functions.py +++ b/array_api_tests/test_manipulation_functions.py @@ -39,8 +39,15 @@ def test_concat(shape, dtypes, kw, data): axis=shared_shapes.flatmap(lambda s: st.integers(-len(s), len(s))), ) def test_expand_dims(x, axis): - xp.expand_dims(x, axis=axis) - # TODO + out = xp.expand_dims(x, axis=axis) + + ph.assert_dtype("expand_dims", x.dtype, out.dtype) + + shape = [side for side in x.shape] + index = axis if axis >= 0 else x.ndim + axis + 1 + shape.insert(index, 1) + shape = tuple(shape) + ph.assert_result_shape("expand_dims", (x.shape,), out.shape, shape) @given( From 40a74bff0e718eeef3a14e897071dcadffa9a17b Mon Sep 17 00:00:00 2001 From: Matthew Date: Wed, 10 Nov 2021 12:03:20 +0000 Subject: [PATCH 08/41] Test `flip()` when `axis=None` --- .../test_manipulation_functions.py | 37 +++++++++++++------ 1 file changed, 25 insertions(+), 12 deletions(-) diff --git a/array_api_tests/test_manipulation_functions.py b/array_api_tests/test_manipulation_functions.py index 620e8968..d5ec92df 100644 --- a/array_api_tests/test_manipulation_functions.py +++ b/array_api_tests/test_manipulation_functions.py @@ -4,6 +4,7 @@ from hypothesis import strategies as st from . import _array_module as xp +from . import array_helpers as ah from . import dtype_helpers as dh from . import hypothesis_helpers as hh from . import pytest_helpers as ph @@ -50,22 +51,34 @@ def test_expand_dims(x, axis): ph.assert_result_shape("expand_dims", (x.shape,), out.shape, shape) +@st.composite +def flip_axis(draw, shape): + if len(shape) == 0 or draw(st.booleans()): + return None + else: + ndim = len(shape) + return draw(st.integers(-ndim, ndim - 1) | xps.valid_tuple_axes(ndim)) + + @given( x=xps.arrays(dtype=xps.scalar_dtypes(), shape=shared_shapes), - kw=hh.kwargs( - axis=st.one_of( - st.none(), - shared_shapes.flatmap( - lambda s: st.none() - if len(s) == 0 - else st.integers(-len(s) + 1, len(s) - 1), - ), - ) - ), + kw=hh.kwargs(axis=shared_shapes.flatmap(flip_axis)), ) def test_flip(x, kw): - xp.flip(x, **kw) - # TODO + out = xp.flip(x, **kw) + + ph.assert_dtype("expand_dims", x.dtype, out.dtype) + + # TODO: test all axis scenarios + if kw.get("axis", None) is None: + indices = list(ah.ndindex(x.shape)) + reverse_indices = indices[::-1] + for x_idx, out_idx in zip(indices, reverse_indices): + msg = f"out[{out_idx}]={out[out_idx]}, should be x[{x_idx}]={x[x_idx]}" + if dh.is_float_dtype(x.dtype) and xp.isnan(x[x_idx]): + assert xp.isnan(out[out_idx]), msg + else: + assert out[out_idx] == x[x_idx], msg @given( From d05a612575bb1669b6355e0faad4a41164bcd61f Mon Sep 17 00:00:00 2001 From: Matthew Date: Wed, 10 Nov 2021 15:41:13 +0000 Subject: [PATCH 09/41] Test `permute_dims()` --- .../test_manipulation_functions.py | 49 +++++++++++++------ 1 file changed, 33 insertions(+), 16 deletions(-) diff --git a/array_api_tests/test_manipulation_functions.py b/array_api_tests/test_manipulation_functions.py index d5ec92df..08288787 100644 --- a/array_api_tests/test_manipulation_functions.py +++ b/array_api_tests/test_manipulation_functions.py @@ -9,8 +9,16 @@ from . import hypothesis_helpers as hh from . import pytest_helpers as ph from . import xps +from .typing import Shape -shared_shapes = st.shared(hh.shapes(), key="shape") + +def shared_shapes(*args, **kwargs) -> st.SearchStrategy[Shape]: + key = "shape" + if args: + key += " " + " ".join(args) + if kwargs: + key += " " + ph.fmt_kw(kwargs) + return st.shared(hh.shapes(*args, **kwargs), key="shape") @given( @@ -36,8 +44,8 @@ def test_concat(shape, dtypes, kw, data): @given( - x=xps.arrays(dtype=xps.scalar_dtypes(), shape=shared_shapes), - axis=shared_shapes.flatmap(lambda s: st.integers(-len(s), len(s))), + x=xps.arrays(dtype=xps.scalar_dtypes(), shape=shared_shapes()), + axis=shared_shapes().flatmap(lambda s: st.integers(-len(s), len(s))), ) def test_expand_dims(x, axis): out = xp.expand_dims(x, axis=axis) @@ -61,8 +69,8 @@ def flip_axis(draw, shape): @given( - x=xps.arrays(dtype=xps.scalar_dtypes(), shape=shared_shapes), - kw=hh.kwargs(axis=shared_shapes.flatmap(flip_axis)), + x=xps.arrays(dtype=xps.scalar_dtypes(), shape=shared_shapes()), + kw=hh.kwargs(axis=shared_shapes().flatmap(flip_axis)), ) def test_flip(x, kw): out = xp.flip(x, **kw) @@ -82,10 +90,10 @@ def test_flip(x, kw): @given( - x=xps.arrays(dtype=xps.scalar_dtypes(), shape=shared_shapes), - axes=shared_shapes.flatmap( + x=xps.arrays(dtype=xps.scalar_dtypes(), shape=shared_shapes(min_dims=1)), + axes=shared_shapes(min_dims=1).flatmap( lambda s: st.lists( - st.integers(0, max(len(s) - 1, 0)), + st.integers(0, len(s) - 1), min_size=len(s), max_size=len(s), unique=True, @@ -93,13 +101,22 @@ def test_flip(x, kw): ), ) def test_permute_dims(x, axes): - xp.permute_dims(x, axes) - # TODO + out = xp.permute_dims(x, axes) + + ph.assert_dtype("permute_dims", x.dtype, out.dtype) + + shape = [None for _ in range(len(axes))] + for i, dim in enumerate(axes): + side = x.shape[dim] + shape[i] = side + assert all(isinstance(side, int) for side in shape) # sanity check + shape = tuple(shape) + ph.assert_result_shape("permute_dims", (x.shape,), out.shape, shape, axes=axes) @given( - x=xps.arrays(dtype=xps.scalar_dtypes(), shape=shared_shapes), - shape=shared_shapes, # TODO: test more compatible shapes + x=xps.arrays(dtype=xps.scalar_dtypes(), shape=shared_shapes()), + shape=shared_shapes(), # TODO: test more compatible shapes ) def test_reshape(x, shape): xp.reshape(x, shape) @@ -108,8 +125,8 @@ def test_reshape(x, shape): @given( # TODO: axis arguments, update shift respectively - x=xps.arrays(dtype=xps.scalar_dtypes(), shape=shared_shapes), - shift=shared_shapes.flatmap(lambda s: st.integers(0, max(math.prod(s) - 1, 0))), + x=xps.arrays(dtype=xps.scalar_dtypes(), shape=shared_shapes()), + shift=shared_shapes().flatmap(lambda s: st.integers(0, max(math.prod(s) - 1, 0))), ) def test_roll(x, shift): xp.roll(x, shift) @@ -117,8 +134,8 @@ def test_roll(x, shift): @given( - x=xps.arrays(dtype=xps.scalar_dtypes(), shape=shared_shapes), - axis=shared_shapes.flatmap( + x=xps.arrays(dtype=xps.scalar_dtypes(), shape=shared_shapes()), + axis=shared_shapes().flatmap( lambda s: st.just(0) if len(s) == 0 else st.integers(-len(s) + 1, len(s) - 1).filter(lambda i: s[i] == 1) From a9f9237bfbc10263add9cca3fe957e5a4558b0e8 Mon Sep 17 00:00:00 2001 From: Matthew Date: Wed, 10 Nov 2021 18:16:22 +0000 Subject: [PATCH 10/41] Test `reshape()` --- .../test_manipulation_functions.py | 46 +++++++++++++++++-- 1 file changed, 41 insertions(+), 5 deletions(-) diff --git a/array_api_tests/test_manipulation_functions.py b/array_api_tests/test_manipulation_functions.py index 08288787..4f7f3426 100644 --- a/array_api_tests/test_manipulation_functions.py +++ b/array_api_tests/test_manipulation_functions.py @@ -1,6 +1,6 @@ import math -from hypothesis import given +from hypothesis import assume, given from hypothesis import strategies as st from . import _array_module as xp @@ -113,15 +113,51 @@ def test_permute_dims(x, axes): shape = tuple(shape) ph.assert_result_shape("permute_dims", (x.shape,), out.shape, shape, axes=axes) + # TODO: test elements + + +MAX_RESHAPE_SIDE = hh.MAX_ARRAY_SIZE // 64 +reshape_x_shapes = st.shared( + hh.shapes().filter(lambda s: math.prod(s) <= MAX_RESHAPE_SIDE), + key="reshape x shape", +) + + +@st.composite +def reshape_shapes(draw, shape): + size = 1 if len(shape) == 0 else math.prod(shape) + rshape = draw(st.lists(st.integers(0)).filter(lambda s: math.prod(s) == size)) + assume(all(side <= MAX_RESHAPE_SIDE for side in rshape)) + if len(rshape) != 0 and size > 0 and draw(st.booleans()): + index = draw(st.integers(0, len(rshape) - 1)) + rshape[index] = -1 + return tuple(rshape) + @given( - x=xps.arrays(dtype=xps.scalar_dtypes(), shape=shared_shapes()), - shape=shared_shapes(), # TODO: test more compatible shapes + x=xps.arrays(dtype=xps.scalar_dtypes(), shape=reshape_x_shapes), + shape=reshape_x_shapes.flatmap(reshape_shapes), ) def test_reshape(x, shape): - xp.reshape(x, shape) - # TODO + assume(math.prod(shape) == math.prod(x.shape)) + + out = xp.reshape(x, shape) + + ph.assert_dtype("reshape", x.dtype, out.dtype) + + _shape = shape + if any(side == -1 for side in shape): + size = math.prod(x.shape) + rsize = math.prod(shape) * -1 + _shape[shape.index(-1)] = size / rsize + ph.assert_result_shape("reshape", (x.shape,), out.shape, _shape, shape=shape) + for x_idx, out_idx in zip(ah.ndindex(x.shape), ah.ndindex(out.shape)): + msg = f"out[{out_idx}]={out[out_idx]}, should be x[{x_idx}]={x[x_idx]}" + if dh.is_float_dtype(x.dtype) and xp.isnan(x[x_idx]): + assert xp.isnan(out[out_idx]), msg + else: + assert out[out_idx] == x[x_idx], msg @given( # TODO: axis arguments, update shift respectively From 3ac57efff17d322a871a667eef49dab35f44b469 Mon Sep 17 00:00:00 2001 From: Matthew Date: Wed, 10 Nov 2021 19:44:35 +0000 Subject: [PATCH 11/41] Test `roll()` --- .../test_manipulation_functions.py | 40 +++++++++++++++---- 1 file changed, 32 insertions(+), 8 deletions(-) diff --git a/array_api_tests/test_manipulation_functions.py b/array_api_tests/test_manipulation_functions.py index 4f7f3426..b8adec65 100644 --- a/array_api_tests/test_manipulation_functions.py +++ b/array_api_tests/test_manipulation_functions.py @@ -1,4 +1,5 @@ import math +from collections import deque from hypothesis import assume, given from hypothesis import strategies as st @@ -159,14 +160,37 @@ def test_reshape(x, shape): else: assert out[out_idx] == x[x_idx], msg -@given( - # TODO: axis arguments, update shift respectively - x=xps.arrays(dtype=xps.scalar_dtypes(), shape=shared_shapes()), - shift=shared_shapes().flatmap(lambda s: st.integers(0, max(math.prod(s) - 1, 0))), -) -def test_roll(x, shift): - xp.roll(x, shift) - # TODO + +@given(xps.arrays(dtype=xps.scalar_dtypes(), shape=shared_shapes()), st.data()) +def test_roll(x, data): + shift = data.draw( + st.integers() | st.lists(st.integers(), max_size=x.ndim).map(tuple), + label="shift", + ) + axis_strats = [st.none()] + if x.shape != (): + axis_strats.append(st.integers(-x.ndim, x.ndim - 1)) + if isinstance(shift, int): + axis_strats.append(xps.valid_tuple_axes(x.ndim)) + kw = data.draw(hh.kwargs(axis=st.one_of(axis_strats)), label="kw") + + out = xp.roll(x, shift, **kw) + + ph.assert_dtype("roll", x.dtype, out.dtype) + + ph.assert_result_shape("roll", (x.shape,), out.shape) + + # TODO: test all shift/axis scenarios + if isinstance(shift, int) and kw.get("axis", None) is None: + indices = list(ah.ndindex(x.shape)) + shifted_indices = deque(indices) + shifted_indices.rotate(shift) + for x_idx, out_idx in zip(indices, shifted_indices): + msg = f"out[{out_idx}]={out[out_idx]}, should be x[{x_idx}]={x[x_idx]}" + if dh.is_float_dtype(x.dtype) and xp.isnan(x[x_idx]): + assert xp.isnan(out[out_idx]), msg + else: + assert out[out_idx] == x[x_idx], msg @given( From 258b75e5aa6f25c2aa5eb1de2100649813c176d1 Mon Sep 17 00:00:00 2001 From: Matthew Date: Wed, 10 Nov 2021 20:30:09 +0000 Subject: [PATCH 12/41] Test `squeeze()` --- .../test_manipulation_functions.py | 62 ++++++++++++++----- 1 file changed, 48 insertions(+), 14 deletions(-) diff --git a/array_api_tests/test_manipulation_functions.py b/array_api_tests/test_manipulation_functions.py index b8adec65..e17682c0 100644 --- a/array_api_tests/test_manipulation_functions.py +++ b/array_api_tests/test_manipulation_functions.py @@ -46,7 +46,7 @@ def test_concat(shape, dtypes, kw, data): @given( x=xps.arrays(dtype=xps.scalar_dtypes(), shape=shared_shapes()), - axis=shared_shapes().flatmap(lambda s: st.integers(-len(s), len(s))), + axis=shared_shapes().flatmap(lambda s: st.integers(-len(s) - 1, len(s))), ) def test_expand_dims(x, axis): out = xp.expand_dims(x, axis=axis) @@ -59,6 +59,53 @@ def test_expand_dims(x, axis): shape = tuple(shape) ph.assert_result_shape("expand_dims", (x.shape,), out.shape, shape) + for x_idx, out_idx in zip(ah.ndindex(x.shape), ah.ndindex(out.shape)): + msg = f"out[{out_idx}]={out[out_idx]}, should be x[{x_idx}]={x[x_idx]}" + if dh.is_float_dtype(x.dtype) and xp.isnan(x[x_idx]): + assert xp.isnan(out[out_idx]), msg + else: + assert out[out_idx] == x[x_idx], msg + + +@given( + x=xps.arrays( + dtype=xps.scalar_dtypes(), shape=hh.shapes(min_side=1).filter(lambda s: 1 in s) + ), + data=st.data(), +) +def test_squeeze(x, data): + # axis=shared_shapes(min_side=1).flatmap(lambda s: nd_axes(len(s))), + squeezable_axes = st.sampled_from( + [i for i, side in enumerate(x.shape) if side == 1] + ) + axis = data.draw( + # TODO: generate valid negative axis + squeezable_axes | st.lists(squeezable_axes, unique=True).map(tuple), + label="axis", + ) + + out = xp.squeeze(x, axis) + + ph.assert_dtype("squeeze", x.dtype, out.dtype) + + if isinstance(axis, int): + axes = (axis,) + else: + axes = axis + shape = [] + for i, side in enumerate(x.shape): + if i not in axes: + shape.append(side) + shape = tuple(shape) + ph.assert_result_shape("squeeze", (x.shape,), out.shape, shape, axis=axis) + + for x_idx, out_idx in zip(ah.ndindex(x.shape), ah.ndindex(out.shape)): + msg = f"out[{out_idx}]={out[out_idx]}, should be x[{x_idx}]={x[x_idx]}" + if dh.is_float_dtype(x.dtype) and xp.isnan(x[x_idx]): + assert xp.isnan(out[out_idx]), msg + else: + assert out[out_idx] == x[x_idx], msg + @st.composite def flip_axis(draw, shape): @@ -193,19 +240,6 @@ def test_roll(x, data): assert out[out_idx] == x[x_idx], msg -@given( - x=xps.arrays(dtype=xps.scalar_dtypes(), shape=shared_shapes()), - axis=shared_shapes().flatmap( - lambda s: st.just(0) - if len(s) == 0 - else st.integers(-len(s) + 1, len(s) - 1).filter(lambda i: s[i] == 1) - ), # TODO: tuple of axis i.e. axes -) -def test_squeeze(x, axis): - xp.squeeze(x, axis) - # TODO - - @given( shape=hh.shapes(), dtypes=hh.mutually_promotable_dtypes(None), From 1bdf6e476ee54f9e402d74959dc5ce0363a706b8 Mon Sep 17 00:00:00 2001 From: Matthew Date: Wed, 10 Nov 2021 20:42:01 +0000 Subject: [PATCH 13/41] Refactor assertions using ndindex --- .../test_manipulation_functions.py | 59 +++++++++---------- 1 file changed, 27 insertions(+), 32 deletions(-) diff --git a/array_api_tests/test_manipulation_functions.py b/array_api_tests/test_manipulation_functions.py index e17682c0..18b0cf2e 100644 --- a/array_api_tests/test_manipulation_functions.py +++ b/array_api_tests/test_manipulation_functions.py @@ -1,5 +1,6 @@ import math from collections import deque +from typing import Iterable, Union from hypothesis import assume, given from hypothesis import strategies as st @@ -10,7 +11,7 @@ from . import hypothesis_helpers as hh from . import pytest_helpers as ph from . import xps -from .typing import Shape +from .typing import Array, Shape def shared_shapes(*args, **kwargs) -> st.SearchStrategy[Shape]: @@ -22,6 +23,21 @@ def shared_shapes(*args, **kwargs) -> st.SearchStrategy[Shape]: return st.shared(hh.shapes(*args, **kwargs), key="shape") +def assert_array_ndindex( + func_name: str, + x: Array, + x_indices: Iterable[Union[int, Shape]], + out: Array, + out_indices: Iterable[Union[int, Shape]], +): + for x_idx, out_idx in zip(x_indices, out_indices): + msg = f"out[{out_idx}]={out[out_idx]}, should be x[{x_idx}]={x[x_idx]} [{func_name}()]" + if dh.is_float_dtype(x.dtype) and xp.isnan(x[x_idx]): + assert xp.isnan(out[out_idx]), msg + else: + assert out[out_idx] == x[x_idx], msg + + @given( shape=hh.shapes(min_dims=1), dtypes=hh.mutually_promotable_dtypes(None, dtypes=dh.numeric_dtypes), @@ -59,12 +75,9 @@ def test_expand_dims(x, axis): shape = tuple(shape) ph.assert_result_shape("expand_dims", (x.shape,), out.shape, shape) - for x_idx, out_idx in zip(ah.ndindex(x.shape), ah.ndindex(out.shape)): - msg = f"out[{out_idx}]={out[out_idx]}, should be x[{x_idx}]={x[x_idx]}" - if dh.is_float_dtype(x.dtype) and xp.isnan(x[x_idx]): - assert xp.isnan(out[out_idx]), msg - else: - assert out[out_idx] == x[x_idx], msg + assert_array_ndindex( + "expand_dims", x, ah.ndindex(x.shape), out, ah.ndindex(out.shape) + ) @given( @@ -99,12 +112,7 @@ def test_squeeze(x, data): shape = tuple(shape) ph.assert_result_shape("squeeze", (x.shape,), out.shape, shape, axis=axis) - for x_idx, out_idx in zip(ah.ndindex(x.shape), ah.ndindex(out.shape)): - msg = f"out[{out_idx}]={out[out_idx]}, should be x[{x_idx}]={x[x_idx]}" - if dh.is_float_dtype(x.dtype) and xp.isnan(x[x_idx]): - assert xp.isnan(out[out_idx]), msg - else: - assert out[out_idx] == x[x_idx], msg + assert_array_ndindex("squeeze", x, ah.ndindex(x.shape), out, ah.ndindex(out.shape)) @st.composite @@ -123,18 +131,13 @@ def flip_axis(draw, shape): def test_flip(x, kw): out = xp.flip(x, **kw) - ph.assert_dtype("expand_dims", x.dtype, out.dtype) + ph.assert_dtype("flip", x.dtype, out.dtype) # TODO: test all axis scenarios if kw.get("axis", None) is None: indices = list(ah.ndindex(x.shape)) reverse_indices = indices[::-1] - for x_idx, out_idx in zip(indices, reverse_indices): - msg = f"out[{out_idx}]={out[out_idx]}, should be x[{x_idx}]={x[x_idx]}" - if dh.is_float_dtype(x.dtype) and xp.isnan(x[x_idx]): - assert xp.isnan(out[out_idx]), msg - else: - assert out[out_idx] == x[x_idx], msg + assert_array_ndindex("flip", x, indices, out, reverse_indices) @given( @@ -200,12 +203,7 @@ def test_reshape(x, shape): _shape[shape.index(-1)] = size / rsize ph.assert_result_shape("reshape", (x.shape,), out.shape, _shape, shape=shape) - for x_idx, out_idx in zip(ah.ndindex(x.shape), ah.ndindex(out.shape)): - msg = f"out[{out_idx}]={out[out_idx]}, should be x[{x_idx}]={x[x_idx]}" - if dh.is_float_dtype(x.dtype) and xp.isnan(x[x_idx]): - assert xp.isnan(out[out_idx]), msg - else: - assert out[out_idx] == x[x_idx], msg + assert_array_ndindex("reshape", x, ah.ndindex(x.shape), out, ah.ndindex(out.shape)) @given(xps.arrays(dtype=xps.scalar_dtypes(), shape=shared_shapes()), st.data()) @@ -232,12 +230,9 @@ def test_roll(x, data): indices = list(ah.ndindex(x.shape)) shifted_indices = deque(indices) shifted_indices.rotate(shift) - for x_idx, out_idx in zip(indices, shifted_indices): - msg = f"out[{out_idx}]={out[out_idx]}, should be x[{x_idx}]={x[x_idx]}" - if dh.is_float_dtype(x.dtype) and xp.isnan(x[x_idx]): - assert xp.isnan(out[out_idx]), msg - else: - assert out[out_idx] == x[x_idx], msg + print(f"{indices=}") + print(f"{shifted_indices=}") + assert_array_ndindex("roll", x, indices, out, shifted_indices) @given( From f1cc3ea02aeec032525b943ee1629c8af7cfe277 Mon Sep 17 00:00:00 2001 From: Matthew Date: Thu, 11 Nov 2021 09:34:20 +0000 Subject: [PATCH 14/41] Fix `test_roll` --- array_api_tests/test_manipulation_functions.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/array_api_tests/test_manipulation_functions.py b/array_api_tests/test_manipulation_functions.py index 18b0cf2e..1e91fafe 100644 --- a/array_api_tests/test_manipulation_functions.py +++ b/array_api_tests/test_manipulation_functions.py @@ -30,8 +30,10 @@ def assert_array_ndindex( out: Array, out_indices: Iterable[Union[int, Shape]], ): + msg_suffix = f" [{func_name}()]\n {x=}\n{out=}" for x_idx, out_idx in zip(x_indices, out_indices): - msg = f"out[{out_idx}]={out[out_idx]}, should be x[{x_idx}]={x[x_idx]} [{func_name}()]" + msg = f"out[{out_idx}]={out[out_idx]}, should be x[{x_idx}]={x[x_idx]}" + msg += msg_suffix if dh.is_float_dtype(x.dtype) and xp.isnan(x[x_idx]): assert xp.isnan(out[out_idx]), msg else: @@ -229,9 +231,7 @@ def test_roll(x, data): if isinstance(shift, int) and kw.get("axis", None) is None: indices = list(ah.ndindex(x.shape)) shifted_indices = deque(indices) - shifted_indices.rotate(shift) - print(f"{indices=}") - print(f"{shifted_indices=}") + shifted_indices.rotate(-shift) assert_array_ndindex("roll", x, indices, out, shifted_indices) From 36679c43979a10b68175f1fbfc5b19da9b35bba1 Mon Sep 17 00:00:00 2001 From: Matthew Date: Thu, 11 Nov 2021 11:27:04 +0000 Subject: [PATCH 15/41] Test `concat()` --- .../test_manipulation_functions.py | 63 +++++++++++++++---- 1 file changed, 51 insertions(+), 12 deletions(-) diff --git a/array_api_tests/test_manipulation_functions.py b/array_api_tests/test_manipulation_functions.py index 1e91fafe..ed0267b7 100644 --- a/array_api_tests/test_manipulation_functions.py +++ b/array_api_tests/test_manipulation_functions.py @@ -13,6 +13,9 @@ from . import xps from .typing import Array, Shape +MAX_SIDE = hh.MAX_ARRAY_SIZE // 64 +MAX_DIMS = min(hh.MAX_ARRAY_SIZE // MAX_SIDE, 32) # NumPy only supports up to 32 dims + def shared_shapes(*args, **kwargs) -> st.SearchStrategy[Shape]: key = "shape" @@ -40,26 +43,63 @@ def assert_array_ndindex( assert out[out_idx] == x[x_idx], msg +@st.composite +def concat_shapes(draw, shape, axis): + shape = list(shape) + shape[axis] = draw(st.integers(1, MAX_SIDE)) + return tuple(shape) + + @given( - shape=hh.shapes(min_dims=1), dtypes=hh.mutually_promotable_dtypes(None, dtypes=dh.numeric_dtypes), - kw=hh.kwargs(axis=st.just(0) | st.none()), # TODO: test with axis >= 1 + kw=hh.kwargs(axis=st.none() | st.integers(-MAX_DIMS, MAX_DIMS - 1)), data=st.data(), ) -def test_concat(shape, dtypes, kw, data): +def test_concat(dtypes, kw, data): + axis = kw.get("axis", 0) + if axis is None: + shape_strat = hh.shapes() + else: + _axis = axis if axis >= 0 else abs(axis) - 1 + shape_strat = shared_shapes(min_dims=_axis + 1).flatmap( + lambda s: concat_shapes(s, axis) + ) arrays = [] for i, dtype in enumerate(dtypes, 1): - x = data.draw(xps.arrays(dtype=dtype, shape=shape), label=f"x{i}") + x = data.draw(xps.arrays(dtype=dtype, shape=shape_strat), label=f"x{i}") arrays.append(x) + out = xp.concat(arrays, **kw) + ph.assert_dtype("concat", dtypes, out.dtype) + shapes = tuple(x.shape for x in arrays) - if kw.get("axis", 0) == 0: - pass # TODO: assert expected shape - elif kw["axis"] is None: + axis = kw.get("axis", 0) + if axis is None: size = sum(math.prod(s) for s in shapes) - ph.assert_result_shape("concat", shapes, out.shape, (size,), **kw) - # TODO: assert out elements match input arrays + shape = (size,) + else: + shape = list(shapes[0]) + for other_shape in shapes[1:]: + shape[axis] += other_shape[axis] + shape = tuple(shape) + ph.assert_result_shape("concat", shapes, out.shape, shape, **kw) + + # TODO: adjust indices with nonzero axis + if axis is None or axis == 0: + out_indices = ah.ndindex(out.shape) + for i, x in enumerate(arrays, 1): + msg_suffix = f" [concat({ph.fmt_kw(kw)})]\nx{i}={x!r}\n{out=}" + for x_idx in ah.ndindex(x.shape): + out_idx = next(out_indices) + msg = ( + f"out[{out_idx}]={out[out_idx]}, should be x{i}[{x_idx}]={x[x_idx]}" + ) + msg += msg_suffix + if dh.is_float_dtype(x.dtype) and xp.isnan(x[x_idx]): + assert xp.isnan(out[out_idx]), msg + else: + assert out[out_idx] == x[x_idx], msg @given( @@ -169,9 +209,8 @@ def test_permute_dims(x, axes): # TODO: test elements -MAX_RESHAPE_SIDE = hh.MAX_ARRAY_SIZE // 64 reshape_x_shapes = st.shared( - hh.shapes().filter(lambda s: math.prod(s) <= MAX_RESHAPE_SIDE), + hh.shapes().filter(lambda s: math.prod(s) <= MAX_SIDE), key="reshape x shape", ) @@ -180,7 +219,7 @@ def test_permute_dims(x, axes): def reshape_shapes(draw, shape): size = 1 if len(shape) == 0 else math.prod(shape) rshape = draw(st.lists(st.integers(0)).filter(lambda s: math.prod(s) == size)) - assume(all(side <= MAX_RESHAPE_SIDE for side in rshape)) + assume(all(side <= MAX_SIDE for side in rshape)) if len(rshape) != 0 and size > 0 and draw(st.booleans()): index = draw(st.integers(0, len(rshape) - 1)) rshape[index] = -1 From a8fb70d51c22f530d30a134e2bf0b5273d07db35 Mon Sep 17 00:00:00 2001 From: Matthew Date: Thu, 11 Nov 2021 11:54:05 +0000 Subject: [PATCH 16/41] Test `stack()` --- .../test_manipulation_functions.py | 39 +++++++++++++++++-- 1 file changed, 35 insertions(+), 4 deletions(-) diff --git a/array_api_tests/test_manipulation_functions.py b/array_api_tests/test_manipulation_functions.py index ed0267b7..6ccadf43 100644 --- a/array_api_tests/test_manipulation_functions.py +++ b/array_api_tests/test_manipulation_functions.py @@ -275,15 +275,46 @@ def test_roll(x, data): @given( - shape=hh.shapes(), + shape=shared_shapes(min_dims=1), dtypes=hh.mutually_promotable_dtypes(None), + kw=hh.kwargs( + axis=shared_shapes(min_dims=1).flatmap( + lambda s: st.integers(-len(s), len(s) - 1) + ) + ), data=st.data(), ) -def test_stack(shape, dtypes, data): +def test_stack(shape, dtypes, kw, data): arrays = [] for i, dtype in enumerate(dtypes, 1): x = data.draw(xps.arrays(dtype=dtype, shape=shape), label=f"x{i}") arrays.append(x) - out = xp.stack(arrays) + + out = xp.stack(arrays, **kw) + ph.assert_dtype("stack", dtypes, out.dtype) - # TODO + + axis = kw.get("axis", 0) + _axis = axis if axis >= 0 else len(shape) + axis + 1 + _shape = list(shape) + _shape.insert(_axis, len(arrays)) + _shape = tuple(_shape) + ph.assert_result_shape( + "stack", tuple(x.shape for x in arrays), out.shape, _shape, **kw + ) + + # TODO: adjust indices with nonzero axis + if axis == 0: + out_indices = ah.ndindex(out.shape) + for i, x in enumerate(arrays, 1): + msg_suffix = f" [stack({ph.fmt_kw(kw)})]\nx{i}={x!r}\n{out=}" + for x_idx in ah.ndindex(x.shape): + out_idx = next(out_indices) + msg = ( + f"out[{out_idx}]={out[out_idx]}, should be x{i}[{x_idx}]={x[x_idx]}" + ) + msg += msg_suffix + if dh.is_float_dtype(x.dtype) and xp.isnan(x[x_idx]): + assert xp.isnan(out[out_idx]), msg + else: + assert out[out_idx] == x[x_idx], msg From cf7688844cc2267b35bff5eb9f0172f963a40177 Mon Sep 17 00:00:00 2001 From: Matthew Date: Thu, 11 Nov 2021 12:09:13 +0000 Subject: [PATCH 17/41] Manipulation tests clean up --- .../test_manipulation_functions.py | 43 ++++++++----------- 1 file changed, 18 insertions(+), 25 deletions(-) diff --git a/array_api_tests/test_manipulation_functions.py b/array_api_tests/test_manipulation_functions.py index 6ccadf43..408a587a 100644 --- a/array_api_tests/test_manipulation_functions.py +++ b/array_api_tests/test_manipulation_functions.py @@ -129,12 +129,11 @@ def test_expand_dims(x, axis): data=st.data(), ) def test_squeeze(x, data): - # axis=shared_shapes(min_side=1).flatmap(lambda s: nd_axes(len(s))), + # TODO: generate valid negative axis (which keep uniqueness) squeezable_axes = st.sampled_from( [i for i, side in enumerate(x.shape) if side == 1] ) axis = data.draw( - # TODO: generate valid negative axis squeezable_axes | st.lists(squeezable_axes, unique=True).map(tuple), label="axis", ) @@ -157,20 +156,19 @@ def test_squeeze(x, data): assert_array_ndindex("squeeze", x, ah.ndindex(x.shape), out, ah.ndindex(out.shape)) -@st.composite -def flip_axis(draw, shape): - if len(shape) == 0 or draw(st.booleans()): - return None - else: - ndim = len(shape) - return draw(st.integers(-ndim, ndim - 1) | xps.valid_tuple_axes(ndim)) - - @given( - x=xps.arrays(dtype=xps.scalar_dtypes(), shape=shared_shapes()), - kw=hh.kwargs(axis=shared_shapes().flatmap(flip_axis)), + x=xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes()), + data=st.data(), ) -def test_flip(x, kw): +def test_flip(x, data): + if x.ndim == 0: + axis_strat = st.none() + else: + axis_strat = ( + st.none() | st.integers(-x.ndim, x.ndim - 1) | xps.valid_tuple_axes(x.ndim) + ) + kw = data.draw(hh.kwargs(axis=axis_strat), label="kw") + out = xp.flip(x, **kw) ph.assert_dtype("flip", x.dtype, out.dtype) @@ -209,12 +207,6 @@ def test_permute_dims(x, axes): # TODO: test elements -reshape_x_shapes = st.shared( - hh.shapes().filter(lambda s: math.prod(s) <= MAX_SIDE), - key="reshape x shape", -) - - @st.composite def reshape_shapes(draw, shape): size = 1 if len(shape) == 0 else math.prod(shape) @@ -227,21 +219,22 @@ def reshape_shapes(draw, shape): @given( - x=xps.arrays(dtype=xps.scalar_dtypes(), shape=reshape_x_shapes), - shape=reshape_x_shapes.flatmap(reshape_shapes), + x=xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes(max_side=MAX_SIDE)), + data=st.data(), ) -def test_reshape(x, shape): - assume(math.prod(shape) == math.prod(x.shape)) +def test_reshape(x, data): + shape = data.draw(reshape_shapes(x.shape)) out = xp.reshape(x, shape) ph.assert_dtype("reshape", x.dtype, out.dtype) - _shape = shape + _shape = list(shape) if any(side == -1 for side in shape): size = math.prod(x.shape) rsize = math.prod(shape) * -1 _shape[shape.index(-1)] = size / rsize + _shape = tuple(_shape) ph.assert_result_shape("reshape", (x.shape,), out.shape, _shape, shape=shape) assert_array_ndindex("reshape", x, ah.ndindex(x.shape), out, ah.ndindex(out.shape)) From 30a072163576a6664c94b93dd4437432f4fbeec7 Mon Sep 17 00:00:00 2001 From: Matthew Date: Thu, 11 Nov 2021 18:57:49 +0000 Subject: [PATCH 18/41] Improve `min()`/`max()` tests --- array_api_tests/test_statistical_functions.py | 106 ++++++++++++++++-- 1 file changed, 96 insertions(+), 10 deletions(-) diff --git a/array_api_tests/test_statistical_functions.py b/array_api_tests/test_statistical_functions.py index 98cade7d..1b2078b7 100644 --- a/array_api_tests/test_statistical_functions.py +++ b/array_api_tests/test_statistical_functions.py @@ -1,22 +1,108 @@ +import math + from hypothesis import given +from hypothesis import strategies as st from . import _array_module as xp +from . import array_helpers as ah +from . import dtype_helpers as dh from . import hypothesis_helpers as hh +from . import pytest_helpers as ph from . import xps -# TODO: generate kwargs -@given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes(min_side=1))) -def test_min(x): - xp.min(x) - # TODO +@given( + x=xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes(min_side=1)), + data=st.data(), +) +def test_min(x, data): + axis_strats = [st.none()] + if x.shape != (): + axis_strats.append( + st.integers(-x.ndim, x.ndim - 1) | xps.valid_tuple_axes(x.ndim) + ) + kw = data.draw( + hh.kwargs(axis=st.one_of(axis_strats), keepdims=st.booleans()), label="kw" + ) + out = xp.min(x, **kw) -# TODO: generate kwargs -@given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes(min_side=1))) -def test_max(x): - xp.max(x) - # TODO + ph.assert_dtype("min", x.dtype, out.dtype) + + f_func = f"min({ph.fmt_kw(kw)})" + + # TODO: support axis + if kw.get("axis") is None: + keepdims = kw.get("keepdims", False) + if keepdims: + idx = tuple(1 for _ in x.shape) + msg = f"{out.shape=}, should be reduced dimension {idx} [{f_func}]" + assert out.shape == idx + else: + ph.assert_shape("min", out.shape, (), **kw) + + # TODO: figure out NaN behaviour + if dh.is_int_dtype(x.dtype) or not xp.any(xp.isnan(x)): + _out = xp.reshape(out, ()) if keepdims else out + scalar_type = dh.get_scalar_type(out.dtype) + elements = [] + for idx in ah.ndindex(x.shape): + s = scalar_type(x[idx]) + elements.append(s) + min_ = scalar_type(_out) + expected = min(elements) + msg = f"out={min_}, should be {expected} [{f_func}]" + if math.isnan(min_): + assert math.isnan(expected), msg + else: + assert min_ == expected, msg + + +@given( + x=xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes(min_side=1)), + data=st.data(), +) +def test_max(x, data): + axis_strats = [st.none()] + if x.shape != (): + axis_strats.append( + st.integers(-x.ndim, x.ndim - 1) | xps.valid_tuple_axes(x.ndim) + ) + kw = data.draw( + hh.kwargs(axis=st.one_of(axis_strats), keepdims=st.booleans()), label="kw" + ) + + out = xp.max(x, **kw) + + ph.assert_dtype("max", x.dtype, out.dtype) + + f_func = f"max({ph.fmt_kw(kw)})" + + # TODO: support axis + if kw.get("axis") is None: + keepdims = kw.get("keepdims", False) + if keepdims: + idx = tuple(1 for _ in x.shape) + msg = f"{out.shape=}, should be reduced dimension {idx} [{f_func}]" + assert out.shape == idx + else: + ph.assert_shape("max", out.shape, (), **kw) + + # TODO: figure out NaN behaviour + if dh.is_int_dtype(x.dtype) or not xp.any(xp.isnan(x)): + _out = xp.reshape(out, ()) if keepdims else out + scalar_type = dh.get_scalar_type(out.dtype) + elements = [] + for idx in ah.ndindex(x.shape): + s = scalar_type(x[idx]) + elements.append(s) + max_ = scalar_type(_out) + expected = max(elements) + msg = f"out={max_}, should be {expected} [{f_func}]" + if math.isnan(max_): + assert math.isnan(expected), msg + else: + assert max_ == expected, msg # TODO: generate kwargs From 50a63b24fbd908a6a50abad1178c15b05152de3e Mon Sep 17 00:00:00 2001 From: Matthew Date: Fri, 12 Nov 2021 09:35:43 +0000 Subject: [PATCH 19/41] Test `mean()` --- array_api_tests/test_statistical_functions.py | 52 ++++++++++++++++--- 1 file changed, 45 insertions(+), 7 deletions(-) diff --git a/array_api_tests/test_statistical_functions.py b/array_api_tests/test_statistical_functions.py index 1b2078b7..54ec1d5b 100644 --- a/array_api_tests/test_statistical_functions.py +++ b/array_api_tests/test_statistical_functions.py @@ -10,6 +10,8 @@ from . import pytest_helpers as ph from . import xps +RTOL = 0.05 + @given( x=xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes(min_side=1)), @@ -37,7 +39,7 @@ def test_min(x, data): if keepdims: idx = tuple(1 for _ in x.shape) msg = f"{out.shape=}, should be reduced dimension {idx} [{f_func}]" - assert out.shape == idx + assert out.shape == idx, msg else: ph.assert_shape("min", out.shape, (), **kw) @@ -84,7 +86,7 @@ def test_max(x, data): if keepdims: idx = tuple(1 for _ in x.shape) msg = f"{out.shape=}, should be reduced dimension {idx} [{f_func}]" - assert out.shape == idx + assert out.shape == idx, msg else: ph.assert_shape("max", out.shape, (), **kw) @@ -105,11 +107,47 @@ def test_max(x, data): assert max_ == expected, msg -# TODO: generate kwargs -@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes(min_side=1))) -def test_mean(x): - xp.mean(x) - # TODO +@given( + x=xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes(min_side=1)), + data=st.data(), +) +def test_mean(x, data): + axis_strats = [st.none()] + if x.shape != (): + axis_strats.append( + st.integers(-x.ndim, x.ndim - 1) | xps.valid_tuple_axes(x.ndim) + ) + kw = data.draw( + hh.kwargs(axis=st.one_of(axis_strats), keepdims=st.booleans()), label="kw" + ) + + out = xp.mean(x, **kw) + + ph.assert_dtype("mean", x.dtype, out.dtype) + + f_func = f"mean({ph.fmt_kw(kw)})" + + # TODO: support axis + if kw.get("axis") is None: + keepdims = kw.get("keepdims", False) + if keepdims: + idx = tuple(1 for _ in x.shape) + msg = f"{out.shape=}, should be reduced dimension {idx} [{f_func}]" + assert out.shape == idx, msg + else: + ph.assert_shape("max", out.shape, (), **kw) + + # TODO: figure out NaN behaviour + if not xp.any(xp.isnan(x)): + _out = xp.reshape(out, ()) if keepdims else out + elements = [] + for idx in ah.ndindex(x.shape): + s = float(x[idx]) + elements.append(s) + mean = float(_out) + expected = sum(elements) / len(elements) + msg = f"out={mean}, should be roughly {expected} [{f_func}]" + assert math.isclose(mean, expected, rel_tol=RTOL), msg # TODO: generate kwargs From 1ff7271d661518b305a5aef5d4d111a1d513d007 Mon Sep 17 00:00:00 2001 From: Matthew Date: Fri, 12 Nov 2021 11:02:31 +0000 Subject: [PATCH 20/41] Test `prod()` --- array_api_tests/test_statistical_functions.py | 109 ++++++++++++++---- 1 file changed, 87 insertions(+), 22 deletions(-) diff --git a/array_api_tests/test_statistical_functions.py b/array_api_tests/test_statistical_functions.py index 54ec1d5b..8219dfd5 100644 --- a/array_api_tests/test_statistical_functions.py +++ b/array_api_tests/test_statistical_functions.py @@ -1,6 +1,6 @@ import math -from hypothesis import given +from hypothesis import assume, given from hypothesis import strategies as st from . import _array_module as xp @@ -9,8 +9,22 @@ from . import hypothesis_helpers as hh from . import pytest_helpers as ph from . import xps +from .typing import Scalar, ScalarType -RTOL = 0.05 + +def assert_equals( + func_name: str, type_: ScalarType, out: Scalar, expected: Scalar, /, **kw +): + f_func = f"{func_name}({ph.fmt_kw(kw)})" + if type_ is bool or type_ is int: + msg = f"{out=}, should be {expected} [{f_func}]" + assert out == expected, msg + elif math.isnan(expected): + msg = f"{out=}, should be {expected} [{f_func}]" + assert math.isnan(out), msg + else: + msg = f"{out=}, should be roughly {expected} [{f_func}]" + assert math.isclose(out, expected, rel_tol=0.05), msg @given( @@ -34,7 +48,7 @@ def test_min(x, data): f_func = f"min({ph.fmt_kw(kw)})" # TODO: support axis - if kw.get("axis") is None: + if kw.get("axis", None) is None: keepdims = kw.get("keepdims", False) if keepdims: idx = tuple(1 for _ in x.shape) @@ -53,11 +67,7 @@ def test_min(x, data): elements.append(s) min_ = scalar_type(_out) expected = min(elements) - msg = f"out={min_}, should be {expected} [{f_func}]" - if math.isnan(min_): - assert math.isnan(expected), msg - else: - assert min_ == expected, msg + assert_equals("min", dh.get_scalar_type(out.dtype), min_, expected) @given( @@ -81,7 +91,7 @@ def test_max(x, data): f_func = f"max({ph.fmt_kw(kw)})" # TODO: support axis - if kw.get("axis") is None: + if kw.get("axis", None) is None: keepdims = kw.get("keepdims", False) if keepdims: idx = tuple(1 for _ in x.shape) @@ -100,11 +110,7 @@ def test_max(x, data): elements.append(s) max_ = scalar_type(_out) expected = max(elements) - msg = f"out={max_}, should be {expected} [{f_func}]" - if math.isnan(max_): - assert math.isnan(expected), msg - else: - assert max_ == expected, msg + assert_equals("mean", dh.get_scalar_type(out.dtype), max_, expected) @given( @@ -128,7 +134,7 @@ def test_mean(x, data): f_func = f"mean({ph.fmt_kw(kw)})" # TODO: support axis - if kw.get("axis") is None: + if kw.get("axis", None) is None: keepdims = kw.get("keepdims", False) if keepdims: idx = tuple(1 for _ in x.shape) @@ -146,15 +152,74 @@ def test_mean(x, data): elements.append(s) mean = float(_out) expected = sum(elements) / len(elements) - msg = f"out={mean}, should be roughly {expected} [{f_func}]" - assert math.isclose(mean, expected, rel_tol=RTOL), msg + assert_equals("mean", float, mean, expected) -# TODO: generate kwargs -@given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes(min_side=1))) -def test_prod(x): - xp.prod(x) - # TODO +@given( + x=xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes(min_side=1)), + data=st.data(), +) +def test_prod(x, data): + axis_strats = [st.none()] + if x.shape != (): + axis_strats.append( + st.integers(-x.ndim, x.ndim - 1) | xps.valid_tuple_axes(x.ndim) + ) + kw = data.draw( + hh.kwargs( + axis=st.one_of(axis_strats), + dtype=st.none() | st.just(x.dtype), # TODO: all valid dtypes + keepdims=st.booleans(), + ), + label="kw", + ) + + out = xp.prod(x, **kw) + + dtype = kw.get("dtype", None) + if dtype is None: + if dh.is_int_dtype(x.dtype): + m, M = dh.dtype_ranges[x.dtype] + d_m, d_M = dh.dtype_ranges[dh.default_int] + if m < d_m or M > d_M: + _dtype = x.dtype + else: + _dtype = dh.default_int + else: + if dh.dtype_nbits[x.dtype] > dh.dtype_nbits[dh.default_float]: + _dtype = x.dtype + else: + _dtype = dh.default_float + else: + _dtype = dtype + ph.assert_dtype("prod", x.dtype, out.dtype, _dtype) + + f_func = f"prod({ph.fmt_kw(kw)})" + + # TODO: support axis + if kw.get("axis", None) is None: + keepdims = kw.get("keepdims", False) + if keepdims: + idx = tuple(1 for _ in x.shape) + msg = f"{out.shape=}, should be reduced dimension {idx} [{f_func}]" + assert out.shape == idx, msg + else: + ph.assert_shape("prod", out.shape, (), **kw) + + # TODO: figure out NaN behaviour + if dh.is_int_dtype(x.dtype) or not xp.any(xp.isnan(x)): + _out = xp.reshape(out, ()) if keepdims else out + scalar_type = dh.get_scalar_type(out.dtype) + elements = [] + for idx in ah.ndindex(x.shape): + s = scalar_type(x[idx]) + elements.append(s) + prod = scalar_type(_out) + expected = math.prod(elements) + if dh.is_int_dtype(out.dtype): + m, M = dh.dtype_ranges[out.dtype] + assume(m <= expected <= M) + assert_equals("prod", dh.get_scalar_type(out.dtype), prod, expected) # TODO: generate kwargs From f426f88b0d8bcd5e7e1896fb7c47a4222c11bce0 Mon Sep 17 00:00:00 2001 From: Matthew Date: Fri, 12 Nov 2021 11:15:46 +0000 Subject: [PATCH 21/41] Refactor axes strategies for stat functions --- array_api_tests/test_statistical_functions.py | 57 +++++++------------ 1 file changed, 22 insertions(+), 35 deletions(-) diff --git a/array_api_tests/test_statistical_functions.py b/array_api_tests/test_statistical_functions.py index 8219dfd5..640ee7ec 100644 --- a/array_api_tests/test_statistical_functions.py +++ b/array_api_tests/test_statistical_functions.py @@ -1,4 +1,5 @@ import math +from typing import Optional, Union from hypothesis import assume, given from hypothesis import strategies as st @@ -9,7 +10,15 @@ from . import hypothesis_helpers as hh from . import pytest_helpers as ph from . import xps -from .typing import Scalar, ScalarType +from .typing import Scalar, ScalarType, Shape + + +def axes(ndim: int) -> st.SearchStrategy[Optional[Union[int, Shape]]]: + axes_strats = [st.none()] + if ndim != 0: + axes_strats.append(st.integers(-ndim, ndim - 1)) + axes_strats.append(xps.valid_tuple_axes(ndim)) + return st.one_of(axes_strats) def assert_equals( @@ -32,14 +41,7 @@ def assert_equals( data=st.data(), ) def test_min(x, data): - axis_strats = [st.none()] - if x.shape != (): - axis_strats.append( - st.integers(-x.ndim, x.ndim - 1) | xps.valid_tuple_axes(x.ndim) - ) - kw = data.draw( - hh.kwargs(axis=st.one_of(axis_strats), keepdims=st.booleans()), label="kw" - ) + kw = data.draw(hh.kwargs(axis=axes(x.ndim), keepdims=st.booleans()), label="kw") out = xp.min(x, **kw) @@ -75,14 +77,7 @@ def test_min(x, data): data=st.data(), ) def test_max(x, data): - axis_strats = [st.none()] - if x.shape != (): - axis_strats.append( - st.integers(-x.ndim, x.ndim - 1) | xps.valid_tuple_axes(x.ndim) - ) - kw = data.draw( - hh.kwargs(axis=st.one_of(axis_strats), keepdims=st.booleans()), label="kw" - ) + kw = data.draw(hh.kwargs(axis=axes(x.ndim), keepdims=st.booleans()), label="kw") out = xp.max(x, **kw) @@ -118,14 +113,7 @@ def test_max(x, data): data=st.data(), ) def test_mean(x, data): - axis_strats = [st.none()] - if x.shape != (): - axis_strats.append( - st.integers(-x.ndim, x.ndim - 1) | xps.valid_tuple_axes(x.ndim) - ) - kw = data.draw( - hh.kwargs(axis=st.one_of(axis_strats), keepdims=st.booleans()), label="kw" - ) + kw = data.draw(hh.kwargs(axis=axes(x.ndim), keepdims=st.booleans()), label="kw") out = xp.mean(x, **kw) @@ -160,14 +148,9 @@ def test_mean(x, data): data=st.data(), ) def test_prod(x, data): - axis_strats = [st.none()] - if x.shape != (): - axis_strats.append( - st.integers(-x.ndim, x.ndim - 1) | xps.valid_tuple_axes(x.ndim) - ) kw = data.draw( hh.kwargs( - axis=st.one_of(axis_strats), + axis=axes(x.ndim), dtype=st.none() | st.just(x.dtype), # TODO: all valid dtypes keepdims=st.booleans(), ), @@ -222,10 +205,14 @@ def test_prod(x, data): assert_equals("prod", dh.get_scalar_type(out.dtype), prod, expected) -# TODO: generate kwargs -@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes(min_side=1))) -def test_std(x): - xp.std(x) +@given( + x=xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes(min_side=1)), + data=st.data(), +) +def test_std(x, data): + kw = data.draw(hh.kwargs(axis=axes(x.ndim), keepdims=st.booleans()), label="kw") + + xp.std(x, **kw) # TODO From 1238e75e5f7a946987b10d1000e11ebbe0c7780e Mon Sep 17 00:00:00 2001 From: Matthew Date: Fri, 12 Nov 2021 18:23:36 +0000 Subject: [PATCH 22/41] Test `std()` --- array_api_tests/test_statistical_functions.py | 65 ++++++++++++++----- 1 file changed, 49 insertions(+), 16 deletions(-) diff --git a/array_api_tests/test_statistical_functions.py b/array_api_tests/test_statistical_functions.py index 640ee7ec..6b160c7e 100644 --- a/array_api_tests/test_statistical_functions.py +++ b/array_api_tests/test_statistical_functions.py @@ -53,9 +53,9 @@ def test_min(x, data): if kw.get("axis", None) is None: keepdims = kw.get("keepdims", False) if keepdims: - idx = tuple(1 for _ in x.shape) - msg = f"{out.shape=}, should be reduced dimension {idx} [{f_func}]" - assert out.shape == idx, msg + shape = tuple(1 for _ in x.shape) + msg = f"{out.shape=}, should be reduced dimension {shape} [{f_func}]" + assert out.shape == shape, msg else: ph.assert_shape("min", out.shape, (), **kw) @@ -89,9 +89,9 @@ def test_max(x, data): if kw.get("axis", None) is None: keepdims = kw.get("keepdims", False) if keepdims: - idx = tuple(1 for _ in x.shape) - msg = f"{out.shape=}, should be reduced dimension {idx} [{f_func}]" - assert out.shape == idx, msg + shape = tuple(1 for _ in x.shape) + msg = f"{out.shape=}, should be reduced dimension {shape} [{f_func}]" + assert out.shape == shape, msg else: ph.assert_shape("max", out.shape, (), **kw) @@ -125,9 +125,9 @@ def test_mean(x, data): if kw.get("axis", None) is None: keepdims = kw.get("keepdims", False) if keepdims: - idx = tuple(1 for _ in x.shape) - msg = f"{out.shape=}, should be reduced dimension {idx} [{f_func}]" - assert out.shape == idx, msg + shape = tuple(1 for _ in x.shape) + msg = f"{out.shape=}, should be reduced dimension {shape} [{f_func}]" + assert out.shape == shape, msg else: ph.assert_shape("max", out.shape, (), **kw) @@ -183,9 +183,9 @@ def test_prod(x, data): if kw.get("axis", None) is None: keepdims = kw.get("keepdims", False) if keepdims: - idx = tuple(1 for _ in x.shape) - msg = f"{out.shape=}, should be reduced dimension {idx} [{f_func}]" - assert out.shape == idx, msg + shape = tuple(1 for _ in x.shape) + msg = f"{out.shape=}, should be reduced dimension {shape} [{f_func}]" + assert out.shape == shape, msg else: ph.assert_shape("prod", out.shape, (), **kw) @@ -206,14 +206,47 @@ def test_prod(x, data): @given( - x=xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes(min_side=1)), + x=xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes(min_side=1)).filter( + lambda x: x.size >= 2 + ), data=st.data(), ) def test_std(x, data): - kw = data.draw(hh.kwargs(axis=axes(x.ndim), keepdims=st.booleans()), label="kw") + axis = data.draw(axes(x.ndim), label="axis") + if axis is None: + N = x.size + _axes = tuple(range(x.ndim)) + else: + _axes = axis if isinstance(axis, tuple) else (axis,) + _axes = tuple( + axis if axis >= 0 else x.ndim + axis for axis in _axes + ) # normalise + N = sum(side for axis, side in enumerate(x.shape) if axis not in _axes) + correction = data.draw( + st.floats(0.0, N, allow_infinity=False, allow_nan=False) | st.integers(0, N), + label="correction", + ) + keepdims = data.draw(st.booleans(), label="keepdims") + kw = data.draw( + hh.specified_kwargs( + ("axis", axis, None), + ("correction", correction, 0.0), + ("keepdims", keepdims, False), + ), + label="kw", + ) - xp.std(x, **kw) - # TODO + out = xp.std(x, **kw) + + ph.assert_dtype("std", x.dtype, out.dtype) + + if keepdims: + shape = tuple(1 if axis in _axes else side for axis, side in enumerate(x.shape)) + else: + shape = tuple(side for axis, side in enumerate(x.shape) if axis not in _axes) + ph.assert_shape("std", out.shape, shape, **kw) + + # We can't easily test the result(s) as standard deviation methods vary a lot # TODO: generate kwargs From d1eece1057708f8df73159f1536a687bfb3860d6 Mon Sep 17 00:00:00 2001 From: Matthew Date: Mon, 15 Nov 2021 12:56:31 +0000 Subject: [PATCH 23/41] Test axes results --- array_api_tests/meta/test_utils.py | 21 +- array_api_tests/test_statistical_functions.py | 272 ++++++++++-------- 2 files changed, 166 insertions(+), 127 deletions(-) diff --git a/array_api_tests/meta/test_utils.py b/array_api_tests/meta/test_utils.py index 7dfafc5b..35c884a3 100644 --- a/array_api_tests/meta/test_utils.py +++ b/array_api_tests/meta/test_utils.py @@ -1,7 +1,8 @@ import pytest -from ..test_signatures import extension_module from ..test_creation_functions import frange +from ..test_signatures import extension_module +from ..test_statistical_functions import axes_ndindex def test_extension_module_is_extension(): @@ -24,3 +25,21 @@ def test_extension_func_is_not_extension(): def test_frange(r, size, elements): assert len(r) == size assert list(r) == elements + + +@pytest.mark.parametrize( + "shape, axes, expected", + [ + ((), (), [((),)]), + ( + (2, 2), + (0,), + [ + ((0, 0), (1, 0)), + ((0, 1), (1, 1)), + ], + ), + ], +) +def test_axes_ndindex(shape, axes, expected): + assert list(axes_ndindex(shape, axes)) == expected diff --git a/array_api_tests/test_statistical_functions.py b/array_api_tests/test_statistical_functions.py index 6b160c7e..3907d64e 100644 --- a/array_api_tests/test_statistical_functions.py +++ b/array_api_tests/test_statistical_functions.py @@ -1,5 +1,6 @@ import math -from typing import Optional, Union +from itertools import product +from typing import Iterator, Optional, Tuple, Union from hypothesis import assume, given from hypothesis import strategies as st @@ -21,23 +22,82 @@ def axes(ndim: int) -> st.SearchStrategy[Optional[Union[int, Shape]]]: return st.one_of(axes_strats) +def normalise_axis( + axis: Optional[Union[int, Tuple[int, ...]]], ndim: int +) -> Tuple[int, ...]: + if axis is None: + return tuple(range(ndim)) + axes = axis if isinstance(axis, tuple) else (axis,) + axes = tuple(axis if axis >= 0 else ndim + axis for axis in axes) + return axes + + +def axes_ndindex(shape: Shape, axes: Tuple[int, ...]) -> Iterator[Tuple[Shape, ...]]: + base_iterables = [] + axes_iterables = [] + for axis, side in enumerate(shape): + if axis in axes: + base_iterables.append((None,)) + axes_iterables.append(range(side)) + else: + base_iterables.append(range(side)) + axes_iterables.append((None,)) + for base_idx in product(*base_iterables): + indices = [] + for idx in product(*axes_iterables): + idx = list(idx) + for axis, side in enumerate(idx): + if axis not in axes: + idx[axis] = base_idx[axis] + idx = tuple(idx) + indices.append(idx) + yield tuple(indices) + + +def assert_keepdimable_shape( + func_name: str, + in_shape: Shape, + axes: Tuple[int, ...], + keepdims: bool, + out_shape: Shape, + /, + **kw, +): + if keepdims: + shape = tuple(1 if axis in axes else side for axis, side in enumerate(in_shape)) + else: + shape = tuple(side for axis, side in enumerate(in_shape) if axis not in axes) + ph.assert_shape(func_name, out_shape, shape, **kw) + + def assert_equals( - func_name: str, type_: ScalarType, out: Scalar, expected: Scalar, /, **kw + func_name: str, + type_: ScalarType, + idx: Shape, + out: Scalar, + expected: Scalar, + /, + **kw, ): + out_repr = "out" if idx == () else f"out[{idx}]" f_func = f"{func_name}({ph.fmt_kw(kw)})" if type_ is bool or type_ is int: - msg = f"{out=}, should be {expected} [{f_func}]" + msg = f"{out_repr}={out}, should be {expected} [{f_func}]" assert out == expected, msg elif math.isnan(expected): - msg = f"{out=}, should be {expected} [{f_func}]" + msg = f"{out_repr}={out}, should be {expected} [{f_func}]" assert math.isnan(out), msg else: - msg = f"{out=}, should be roughly {expected} [{f_func}]" + msg = f"{out_repr}={out}, should be roughly {expected} [{f_func}]" assert math.isclose(out, expected, rel_tol=0.05), msg @given( - x=xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes(min_side=1)), + x=xps.arrays( + dtype=xps.numeric_dtypes(), + shape=hh.shapes(min_side=1), + elements={"allow_nan": False}, + ), data=st.data(), ) def test_min(x, data): @@ -46,34 +106,27 @@ def test_min(x, data): out = xp.min(x, **kw) ph.assert_dtype("min", x.dtype, out.dtype) - - f_func = f"min({ph.fmt_kw(kw)})" - - # TODO: support axis - if kw.get("axis", None) is None: - keepdims = kw.get("keepdims", False) - if keepdims: - shape = tuple(1 for _ in x.shape) - msg = f"{out.shape=}, should be reduced dimension {shape} [{f_func}]" - assert out.shape == shape, msg - else: - ph.assert_shape("min", out.shape, (), **kw) - - # TODO: figure out NaN behaviour - if dh.is_int_dtype(x.dtype) or not xp.any(xp.isnan(x)): - _out = xp.reshape(out, ()) if keepdims else out - scalar_type = dh.get_scalar_type(out.dtype) - elements = [] - for idx in ah.ndindex(x.shape): - s = scalar_type(x[idx]) - elements.append(s) - min_ = scalar_type(_out) - expected = min(elements) - assert_equals("min", dh.get_scalar_type(out.dtype), min_, expected) + _axes = normalise_axis(kw.get("axis", None), x.ndim) + assert_keepdimable_shape( + "min", x.shape, _axes, kw.get("keepdims", False), out.shape, **kw + ) + scalar_type = dh.get_scalar_type(out.dtype) + for indices, out_idx in zip(axes_ndindex(x.shape, _axes), ah.ndindex(out.shape)): + min_ = scalar_type(out[out_idx]) + elements = [] + for idx in indices: + s = scalar_type(x[idx]) + elements.append(s) + expected = min(elements) + assert_equals("min", dh.get_scalar_type(out.dtype), out_idx, min_, expected) @given( - x=xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes(min_side=1)), + x=xps.arrays( + dtype=xps.numeric_dtypes(), + shape=hh.shapes(min_side=1), + elements={"allow_nan": False}, + ), data=st.data(), ) def test_max(x, data): @@ -82,34 +135,27 @@ def test_max(x, data): out = xp.max(x, **kw) ph.assert_dtype("max", x.dtype, out.dtype) - - f_func = f"max({ph.fmt_kw(kw)})" - - # TODO: support axis - if kw.get("axis", None) is None: - keepdims = kw.get("keepdims", False) - if keepdims: - shape = tuple(1 for _ in x.shape) - msg = f"{out.shape=}, should be reduced dimension {shape} [{f_func}]" - assert out.shape == shape, msg - else: - ph.assert_shape("max", out.shape, (), **kw) - - # TODO: figure out NaN behaviour - if dh.is_int_dtype(x.dtype) or not xp.any(xp.isnan(x)): - _out = xp.reshape(out, ()) if keepdims else out - scalar_type = dh.get_scalar_type(out.dtype) - elements = [] - for idx in ah.ndindex(x.shape): - s = scalar_type(x[idx]) - elements.append(s) - max_ = scalar_type(_out) - expected = max(elements) - assert_equals("mean", dh.get_scalar_type(out.dtype), max_, expected) + _axes = normalise_axis(kw.get("axis", None), x.ndim) + assert_keepdimable_shape( + "max", x.shape, _axes, kw.get("keepdims", False), out.shape, **kw + ) + scalar_type = dh.get_scalar_type(out.dtype) + for indices, out_idx in zip(axes_ndindex(x.shape, _axes), ah.ndindex(out.shape)): + max_ = scalar_type(out[out_idx]) + elements = [] + for idx in indices: + s = scalar_type(x[idx]) + elements.append(s) + expected = max(elements) + assert_equals("max", dh.get_scalar_type(out.dtype), out_idx, max_, expected) @given( - x=xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes(min_side=1)), + x=xps.arrays( + dtype=xps.floating_dtypes(), + shape=hh.shapes(min_side=1), + elements={"allow_nan": False}, + ), data=st.data(), ) def test_mean(x, data): @@ -118,33 +164,26 @@ def test_mean(x, data): out = xp.mean(x, **kw) ph.assert_dtype("mean", x.dtype, out.dtype) - - f_func = f"mean({ph.fmt_kw(kw)})" - - # TODO: support axis - if kw.get("axis", None) is None: - keepdims = kw.get("keepdims", False) - if keepdims: - shape = tuple(1 for _ in x.shape) - msg = f"{out.shape=}, should be reduced dimension {shape} [{f_func}]" - assert out.shape == shape, msg - else: - ph.assert_shape("max", out.shape, (), **kw) - - # TODO: figure out NaN behaviour - if not xp.any(xp.isnan(x)): - _out = xp.reshape(out, ()) if keepdims else out - elements = [] - for idx in ah.ndindex(x.shape): - s = float(x[idx]) - elements.append(s) - mean = float(_out) - expected = sum(elements) / len(elements) - assert_equals("mean", float, mean, expected) + _axes = normalise_axis(kw.get("axis", None), x.ndim) + assert_keepdimable_shape( + "mean", x.shape, _axes, kw.get("keepdims", False), out.shape, **kw + ) + for indices, out_idx in zip(axes_ndindex(x.shape, _axes), ah.ndindex(out.shape)): + mean = float(out[out_idx]) + elements = [] + for idx in indices: + s = float(x[idx]) + elements.append(s) + expected = sum(elements) / len(elements) + assert_equals("mean", dh.get_scalar_type(out.dtype), out_idx, mean, expected) @given( - x=xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes(min_side=1)), + x=xps.arrays( + dtype=xps.numeric_dtypes(), + shape=hh.shapes(min_side=1), + elements={"allow_nan": False}, + ), data=st.data(), ) def test_prod(x, data): @@ -176,52 +215,37 @@ def test_prod(x, data): else: _dtype = dtype ph.assert_dtype("prod", x.dtype, out.dtype, _dtype) - - f_func = f"prod({ph.fmt_kw(kw)})" - - # TODO: support axis - if kw.get("axis", None) is None: - keepdims = kw.get("keepdims", False) - if keepdims: - shape = tuple(1 for _ in x.shape) - msg = f"{out.shape=}, should be reduced dimension {shape} [{f_func}]" - assert out.shape == shape, msg - else: - ph.assert_shape("prod", out.shape, (), **kw) - - # TODO: figure out NaN behaviour - if dh.is_int_dtype(x.dtype) or not xp.any(xp.isnan(x)): - _out = xp.reshape(out, ()) if keepdims else out - scalar_type = dh.get_scalar_type(out.dtype) - elements = [] - for idx in ah.ndindex(x.shape): - s = scalar_type(x[idx]) - elements.append(s) - prod = scalar_type(_out) - expected = math.prod(elements) - if dh.is_int_dtype(out.dtype): - m, M = dh.dtype_ranges[out.dtype] - assume(m <= expected <= M) - assert_equals("prod", dh.get_scalar_type(out.dtype), prod, expected) + _axes = normalise_axis(kw.get("axis", None), x.ndim) + assert_keepdimable_shape( + "prod", x.shape, _axes, kw.get("keepdims", False), out.shape, **kw + ) + scalar_type = dh.get_scalar_type(out.dtype) + for indices, out_idx in zip(axes_ndindex(x.shape, _axes), ah.ndindex(out.shape)): + prod = scalar_type(out[out_idx]) + assume(not math.isinf(prod)) + elements = [] + for idx in indices: + s = scalar_type(x[idx]) + elements.append(s) + expected = math.prod(elements) + if dh.is_int_dtype(out.dtype): + m, M = dh.dtype_ranges[out.dtype] + assume(m <= expected <= M) + assert_equals("prod", dh.get_scalar_type(out.dtype), out_idx, prod, expected) @given( - x=xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes(min_side=1)).filter( - lambda x: x.size >= 2 - ), + x=xps.arrays( + dtype=xps.floating_dtypes(), + shape=hh.shapes(min_side=1), + elements={"allow_nan": False}, + ).filter(lambda x: x.size >= 2), data=st.data(), ) def test_std(x, data): axis = data.draw(axes(x.ndim), label="axis") - if axis is None: - N = x.size - _axes = tuple(range(x.ndim)) - else: - _axes = axis if isinstance(axis, tuple) else (axis,) - _axes = tuple( - axis if axis >= 0 else x.ndim + axis for axis in _axes - ) # normalise - N = sum(side for axis, side in enumerate(x.shape) if axis not in _axes) + _axes = normalise_axis(axis, x.ndim) + N = sum(side for axis, side in enumerate(x.shape) if axis not in _axes) correction = data.draw( st.floats(0.0, N, allow_infinity=False, allow_nan=False) | st.integers(0, N), label="correction", @@ -239,13 +263,9 @@ def test_std(x, data): out = xp.std(x, **kw) ph.assert_dtype("std", x.dtype, out.dtype) - - if keepdims: - shape = tuple(1 if axis in _axes else side for axis, side in enumerate(x.shape)) - else: - shape = tuple(side for axis, side in enumerate(x.shape) if axis not in _axes) - ph.assert_shape("std", out.shape, shape, **kw) - + assert_keepdimable_shape( + "std", x.shape, _axes, kw.get("keepdims", False), out.shape, **kw + ) # We can't easily test the result(s) as standard deviation methods vary a lot From 687e40a0e40eb30a4ac74f1e2c1ccf7cbb22c3c9 Mon Sep 17 00:00:00 2001 From: Matthew Date: Mon, 15 Nov 2021 13:11:53 +0000 Subject: [PATCH 24/41] Test `var()` and `sum()` --- array_api_tests/test_statistical_functions.py | 109 +++++++++++++++--- 1 file changed, 93 insertions(+), 16 deletions(-) diff --git a/array_api_tests/test_statistical_functions.py b/array_api_tests/test_statistical_functions.py index 3907d64e..700f5bc3 100644 --- a/array_api_tests/test_statistical_functions.py +++ b/array_api_tests/test_statistical_functions.py @@ -56,10 +56,10 @@ def axes_ndindex(shape: Shape, axes: Tuple[int, ...]) -> Iterator[Tuple[Shape, . def assert_keepdimable_shape( func_name: str, + out_shape: Shape, in_shape: Shape, axes: Tuple[int, ...], keepdims: bool, - out_shape: Shape, /, **kw, ): @@ -108,7 +108,7 @@ def test_min(x, data): ph.assert_dtype("min", x.dtype, out.dtype) _axes = normalise_axis(kw.get("axis", None), x.ndim) assert_keepdimable_shape( - "min", x.shape, _axes, kw.get("keepdims", False), out.shape, **kw + "min", out.shape, x.shape, _axes, kw.get("keepdims", False), **kw ) scalar_type = dh.get_scalar_type(out.dtype) for indices, out_idx in zip(axes_ndindex(x.shape, _axes), ah.ndindex(out.shape)): @@ -137,7 +137,7 @@ def test_max(x, data): ph.assert_dtype("max", x.dtype, out.dtype) _axes = normalise_axis(kw.get("axis", None), x.ndim) assert_keepdimable_shape( - "max", x.shape, _axes, kw.get("keepdims", False), out.shape, **kw + "max", out.shape, x.shape, _axes, kw.get("keepdims", False), **kw ) scalar_type = dh.get_scalar_type(out.dtype) for indices, out_idx in zip(axes_ndindex(x.shape, _axes), ah.ndindex(out.shape)): @@ -166,7 +166,7 @@ def test_mean(x, data): ph.assert_dtype("mean", x.dtype, out.dtype) _axes = normalise_axis(kw.get("axis", None), x.ndim) assert_keepdimable_shape( - "mean", x.shape, _axes, kw.get("keepdims", False), out.shape, **kw + "mean", out.shape, x.shape, _axes, kw.get("keepdims", False), **kw ) for indices, out_idx in zip(axes_ndindex(x.shape, _axes), ah.ndindex(out.shape)): mean = float(out[out_idx]) @@ -217,7 +217,7 @@ def test_prod(x, data): ph.assert_dtype("prod", x.dtype, out.dtype, _dtype) _axes = normalise_axis(kw.get("axis", None), x.ndim) assert_keepdimable_shape( - "prod", x.shape, _axes, kw.get("keepdims", False), out.shape, **kw + "prod", out.shape, x.shape, _axes, kw.get("keepdims", False), **kw ) scalar_type = dh.get_scalar_type(out.dtype) for indices, out_idx in zip(axes_ndindex(x.shape, _axes), ah.ndindex(out.shape)): @@ -264,20 +264,97 @@ def test_std(x, data): ph.assert_dtype("std", x.dtype, out.dtype) assert_keepdimable_shape( - "std", x.shape, _axes, kw.get("keepdims", False), out.shape, **kw + "std", out.shape, x.shape, _axes, kw.get("keepdims", False), **kw ) # We can't easily test the result(s) as standard deviation methods vary a lot -# TODO: generate kwargs -@given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes(min_side=1))) -def test_sum(x): - xp.sum(x) - # TODO +@given( + x=xps.arrays( + dtype=xps.floating_dtypes(), + shape=hh.shapes(min_side=1), + elements={"allow_nan": False}, + ).filter(lambda x: x.size >= 2), + data=st.data(), +) +def test_var(x, data): + axis = data.draw(axes(x.ndim), label="axis") + _axes = normalise_axis(axis, x.ndim) + N = sum(side for axis, side in enumerate(x.shape) if axis not in _axes) + correction = data.draw( + st.floats(0.0, N, allow_infinity=False, allow_nan=False) | st.integers(0, N), + label="correction", + ) + keepdims = data.draw(st.booleans(), label="keepdims") + kw = data.draw( + hh.specified_kwargs( + ("axis", axis, None), + ("correction", correction, 0.0), + ("keepdims", keepdims, False), + ), + label="kw", + ) + + out = xp.var(x, **kw) + + ph.assert_dtype("var", x.dtype, out.dtype) + assert_keepdimable_shape( + "var", out.shape, x.shape, _axes, kw.get("keepdims", False), **kw + ) + # We can't easily test the result(s) as variance methods vary a lot + + +@given( + x=xps.arrays( + dtype=xps.numeric_dtypes(), + shape=hh.shapes(min_side=1), + elements={"allow_nan": False}, + ), + data=st.data(), +) +def test_sum(x, data): + kw = data.draw( + hh.kwargs( + axis=axes(x.ndim), + dtype=st.none() | st.just(x.dtype), # TODO: all valid dtypes + keepdims=st.booleans(), + ), + label="kw", + ) + out = xp.sum(x, **kw) -# TODO: generate kwargs -@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes(min_side=1))) -def test_var(x): - xp.var(x) - # TODO + dtype = kw.get("dtype", None) + if dtype is None: + if dh.is_int_dtype(x.dtype): + m, M = dh.dtype_ranges[x.dtype] + d_m, d_M = dh.dtype_ranges[dh.default_int] + if m < d_m or M > d_M: + _dtype = x.dtype + else: + _dtype = dh.default_int + else: + if dh.dtype_nbits[x.dtype] > dh.dtype_nbits[dh.default_float]: + _dtype = x.dtype + else: + _dtype = dh.default_float + else: + _dtype = dtype + ph.assert_dtype("sum", x.dtype, out.dtype, _dtype) + _axes = normalise_axis(kw.get("axis", None), x.ndim) + assert_keepdimable_shape( + "sum", out.shape, x.shape, _axes, kw.get("keepdims", False), **kw + ) + scalar_type = dh.get_scalar_type(out.dtype) + for indices, out_idx in zip(axes_ndindex(x.shape, _axes), ah.ndindex(out.shape)): + sum_ = scalar_type(out[out_idx]) + assume(not math.isinf(sum_)) + elements = [] + for idx in indices: + s = scalar_type(x[idx]) + elements.append(s) + expected = sum(elements) + if dh.is_int_dtype(out.dtype): + m, M = dh.dtype_ranges[out.dtype] + assume(m <= expected <= M) + assert_equals("sum", dh.get_scalar_type(out.dtype), out_idx, sum_, expected) From d38eee20dcead12e94d34e54863c12b00fa3282b Mon Sep 17 00:00:00 2001 From: Matthew Date: Mon, 15 Nov 2021 13:27:57 +0000 Subject: [PATCH 25/41] Generate all valid dtypes in `test_prod` and `test_sum` --- array_api_tests/test_statistical_functions.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/array_api_tests/test_statistical_functions.py b/array_api_tests/test_statistical_functions.py index 700f5bc3..81498062 100644 --- a/array_api_tests/test_statistical_functions.py +++ b/array_api_tests/test_statistical_functions.py @@ -11,7 +11,7 @@ from . import hypothesis_helpers as hh from . import pytest_helpers as ph from . import xps -from .typing import Scalar, ScalarType, Shape +from .typing import DataType, Scalar, ScalarType, Shape def axes(ndim: int) -> st.SearchStrategy[Optional[Union[int, Shape]]]: @@ -22,6 +22,11 @@ def axes(ndim: int) -> st.SearchStrategy[Optional[Union[int, Shape]]]: return st.one_of(axes_strats) +def kwarg_dtypes(dtype: DataType) -> st.SearchStrategy[Optional[DataType]]: + dtypes = [d2 for d1, d2 in dh.promotion_table if d1 == dtype] + return st.none() | st.sampled_from(dtypes) + + def normalise_axis( axis: Optional[Union[int, Tuple[int, ...]]], ndim: int ) -> Tuple[int, ...]: @@ -190,7 +195,7 @@ def test_prod(x, data): kw = data.draw( hh.kwargs( axis=axes(x.ndim), - dtype=st.none() | st.just(x.dtype), # TODO: all valid dtypes + dtype=kwarg_dtypes(x.dtype), keepdims=st.booleans(), ), label="kw", @@ -316,7 +321,7 @@ def test_sum(x, data): kw = data.draw( hh.kwargs( axis=axes(x.ndim), - dtype=st.none() | st.just(x.dtype), # TODO: all valid dtypes + dtype=kwarg_dtypes(x.dtype), keepdims=st.booleans(), ), label="kw", From 704e47bf1efa2a220aa513b642ee018182a5eece Mon Sep 17 00:00:00 2001 From: Matthew Date: Tue, 16 Nov 2021 12:12:03 +0000 Subject: [PATCH 26/41] Cover axis in `test_concat` --- array_api_tests/meta/test_utils.py | 26 ++++++++ .../test_manipulation_functions.py | 65 +++++++++++++++---- array_api_tests/test_statistical_functions.py | 4 +- 3 files changed, 79 insertions(+), 16 deletions(-) diff --git a/array_api_tests/meta/test_utils.py b/array_api_tests/meta/test_utils.py index 35c884a3..34fa1836 100644 --- a/array_api_tests/meta/test_utils.py +++ b/array_api_tests/meta/test_utils.py @@ -1,6 +1,8 @@ import pytest +from .. import array_helpers as ah from ..test_creation_functions import frange +from ..test_manipulation_functions import axis_ndindex from ..test_signatures import extension_module from ..test_statistical_functions import axes_ndindex @@ -27,6 +29,30 @@ def test_frange(r, size, elements): assert list(r) == elements +@pytest.mark.parametrize( + "shape, expected", + [((), [()])], +) +def test_ndindex(shape, expected): + assert list(ah.ndindex(shape)) == expected + + +@pytest.mark.parametrize( + "shape, axis, expected", + [ + ((1,), 0, [(slice(None, None),)]), + ((1, 2), 0, [(slice(None, None), slice(None, None))]), + ( + (2, 4), + 1, + [(0, slice(None, None)), (1, slice(None, None))], + ), + ], +) +def test_axis_ndindex(shape, axis, expected): + assert list(axis_ndindex(shape, axis)) == expected + + @pytest.mark.parametrize( "shape, axes, expected", [ diff --git a/array_api_tests/test_manipulation_functions.py b/array_api_tests/test_manipulation_functions.py index 408a587a..a2b17c3e 100644 --- a/array_api_tests/test_manipulation_functions.py +++ b/array_api_tests/test_manipulation_functions.py @@ -1,6 +1,7 @@ import math from collections import deque -from typing import Iterable, Union +from itertools import product +from typing import Iterable, Iterator, Tuple, Union from hypothesis import assume, given from hypothesis import strategies as st @@ -43,6 +44,28 @@ def assert_array_ndindex( assert out[out_idx] == x[x_idx], msg +def axis_ndindex( + shape: Shape, axis: int +) -> Iterator[Tuple[Tuple[Union[int, slice], ...], ...]]: + iterables = [range(side) for side in shape[:axis]] + for _ in range(len(shape[axis:])): + iterables.append([slice(None, None)]) + yield from product(*iterables) + + +def assert_equals( + func_name: str, x_repr: str, x_val: Array, out_repr: str, out_val: Array, **kw +): + msg = ( + f"{out_repr}={out_val}, should be {x_repr}={x_val} " + f"[{func_name}({ph.fmt_kw(kw)})]" + ) + if dh.is_float_dtype(out_val.dtype) and xp.isnan(out_val): + assert xp.isnan(x_val), msg + else: + assert out_val == out_val, msg + + @st.composite def concat_shapes(draw, shape, axis): shape = list(shape) @@ -85,21 +108,35 @@ def test_concat(dtypes, kw, data): shape = tuple(shape) ph.assert_result_shape("concat", shapes, out.shape, shape, **kw) - # TODO: adjust indices with nonzero axis - if axis is None or axis == 0: - out_indices = ah.ndindex(out.shape) - for i, x in enumerate(arrays, 1): - msg_suffix = f" [concat({ph.fmt_kw(kw)})]\nx{i}={x!r}\n{out=}" + if axis is None: + out_indices = (i for i in range(out.size)) + for x_num, x in enumerate(arrays, 1): for x_idx in ah.ndindex(x.shape): - out_idx = next(out_indices) - msg = ( - f"out[{out_idx}]={out[out_idx]}, should be x{i}[{x_idx}]={x[x_idx]}" + out_i = next(out_indices) + assert_equals( + "concat", + f"x{x_num}[{x_idx}]", + x[x_idx], + f"out[{out_i}]", + out[out_i], + **kw, ) - msg += msg_suffix - if dh.is_float_dtype(x.dtype) and xp.isnan(x[x_idx]): - assert xp.isnan(out[out_idx]), msg - else: - assert out[out_idx] == x[x_idx], msg + else: + out_indices = ah.ndindex(out.shape) + for idx in axis_ndindex(shape, axis): + f_idx = ", ".join(str(i) if isinstance(i, int) else ":" for i in idx) + for x_num, x in enumerate(arrays, 1): + indexed_x = x[idx] + for x_idx in ah.ndindex(indexed_x.shape): + out_idx = next(out_indices) + assert_equals( + "concat", + f"x{x_num}[{f_idx}][{x_idx}]", + indexed_x[x_idx], + f"out[{out_idx}]", + out[out_idx], + **kw, + ) @given( diff --git a/array_api_tests/test_statistical_functions.py b/array_api_tests/test_statistical_functions.py index 81498062..0373ac47 100644 --- a/array_api_tests/test_statistical_functions.py +++ b/array_api_tests/test_statistical_functions.py @@ -42,11 +42,11 @@ def axes_ndindex(shape: Shape, axes: Tuple[int, ...]) -> Iterator[Tuple[Shape, . axes_iterables = [] for axis, side in enumerate(shape): if axis in axes: - base_iterables.append((None,)) + base_iterables.append([None]) axes_iterables.append(range(side)) else: base_iterables.append(range(side)) - axes_iterables.append((None,)) + axes_iterables.append([None]) for base_idx in product(*base_iterables): indices = [] for idx in product(*axes_iterables): From 40167c34a8f0785566e38b922a45651db252abdb Mon Sep 17 00:00:00 2001 From: Matthew Date: Tue, 16 Nov 2021 12:38:23 +0000 Subject: [PATCH 27/41] Cover axis in `test_flip` --- array_api_tests/test_manipulation_functions.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/array_api_tests/test_manipulation_functions.py b/array_api_tests/test_manipulation_functions.py index a2b17c3e..061b7f0a 100644 --- a/array_api_tests/test_manipulation_functions.py +++ b/array_api_tests/test_manipulation_functions.py @@ -12,6 +12,7 @@ from . import hypothesis_helpers as hh from . import pytest_helpers as ph from . import xps +from .test_statistical_functions import axes_ndindex, normalise_axis # TODO: Move from .typing import Array, Shape MAX_SIDE = hh.MAX_ARRAY_SIZE // 64 @@ -210,9 +211,8 @@ def test_flip(x, data): ph.assert_dtype("flip", x.dtype, out.dtype) - # TODO: test all axis scenarios - if kw.get("axis", None) is None: - indices = list(ah.ndindex(x.shape)) + _axes = normalise_axis(kw.get("axis", None), x.ndim) + for indices in axes_ndindex(x.shape, _axes): reverse_indices = indices[::-1] assert_array_ndindex("flip", x, indices, out, reverse_indices) From 2decdf0349f5b0e1c86b209d6a26c9591a4c4cd4 Mon Sep 17 00:00:00 2001 From: Matthew Date: Tue, 16 Nov 2021 15:39:30 +0000 Subject: [PATCH 28/41] Cover elements in `test_permute_dims` --- array_api_tests/test_manipulation_functions.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/array_api_tests/test_manipulation_functions.py b/array_api_tests/test_manipulation_functions.py index 061b7f0a..44098d20 100644 --- a/array_api_tests/test_manipulation_functions.py +++ b/array_api_tests/test_manipulation_functions.py @@ -237,11 +237,12 @@ def test_permute_dims(x, axes): for i, dim in enumerate(axes): side = x.shape[dim] shape[i] = side - assert all(isinstance(side, int) for side in shape) # sanity check shape = tuple(shape) ph.assert_result_shape("permute_dims", (x.shape,), out.shape, shape, axes=axes) - # TODO: test elements + indices = list(ah.ndindex(x.shape)) + permuted_indices = [tuple(idx[axis] for axis in axes) for idx in indices] + assert_array_ndindex("permute_dims", x, indices, out, permuted_indices) @st.composite From 80c4e31f403b63c0a0706c830e6ffc134f1265f2 Mon Sep 17 00:00:00 2001 From: Matthew Date: Mon, 6 Dec 2021 11:28:26 +0000 Subject: [PATCH 29/41] Fix `test_concat` axes iteration --- array_api_tests/test_manipulation_functions.py | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/array_api_tests/test_manipulation_functions.py b/array_api_tests/test_manipulation_functions.py index 44098d20..7397d47f 100644 --- a/array_api_tests/test_manipulation_functions.py +++ b/array_api_tests/test_manipulation_functions.py @@ -1,7 +1,7 @@ import math from collections import deque from itertools import product -from typing import Iterable, Iterator, Tuple, Union +from typing import Iterable, Union from hypothesis import assume, given from hypothesis import strategies as st @@ -45,15 +45,6 @@ def assert_array_ndindex( assert out[out_idx] == x[x_idx], msg -def axis_ndindex( - shape: Shape, axis: int -) -> Iterator[Tuple[Tuple[Union[int, slice], ...], ...]]: - iterables = [range(side) for side in shape[:axis]] - for _ in range(len(shape[axis:])): - iterables.append([slice(None, None)]) - yield from product(*iterables) - - def assert_equals( func_name: str, x_repr: str, x_val: Array, out_repr: str, out_val: Array, **kw ): @@ -124,7 +115,10 @@ def test_concat(dtypes, kw, data): ) else: out_indices = ah.ndindex(out.shape) - for idx in axis_ndindex(shape, axis): + axis_indices = [range(side) for side in shapes[0][:_axis]] + for _ in range(_axis, len(shape)): + axis_indices.append([slice(None, None)]) + for idx in product(*axis_indices): f_idx = ", ".join(str(i) if isinstance(i, int) else ":" for i in idx) for x_num, x in enumerate(arrays, 1): indexed_x = x[idx] From 767bd3fbb872dd70da86e77096103e1e28102f8f Mon Sep 17 00:00:00 2001 From: Matthew Date: Mon, 6 Dec 2021 16:07:56 +0000 Subject: [PATCH 30/41] Cover all shift/axes scenarios in `test_roll` --- .../test_manipulation_functions.py | 40 +++++++++++++------ 1 file changed, 28 insertions(+), 12 deletions(-) diff --git a/array_api_tests/test_manipulation_functions.py b/array_api_tests/test_manipulation_functions.py index 7397d47f..2db88177 100644 --- a/array_api_tests/test_manipulation_functions.py +++ b/array_api_tests/test_manipulation_functions.py @@ -274,16 +274,21 @@ def test_reshape(x, data): @given(xps.arrays(dtype=xps.scalar_dtypes(), shape=shared_shapes()), st.data()) def test_roll(x, data): - shift = data.draw( - st.integers() | st.lists(st.integers(), max_size=x.ndim).map(tuple), - label="shift", - ) - axis_strats = [st.none()] - if x.shape != (): - axis_strats.append(st.integers(-x.ndim, x.ndim - 1)) - if isinstance(shift, int): - axis_strats.append(xps.valid_tuple_axes(x.ndim)) - kw = data.draw(hh.kwargs(axis=st.one_of(axis_strats)), label="kw") + shift_strat = st.integers(-hh.MAX_ARRAY_SIZE, hh.MAX_ARRAY_SIZE) + if x.ndim > 0: + shift_strat = shift_strat | st.lists( + shift_strat, min_size=1, max_size=x.ndim + ).map(tuple) + shift = data.draw(shift_strat, label="shift") + if isinstance(shift, tuple): + axis_strat = xps.valid_tuple_axes(x.ndim).filter(lambda t: len(t) == len(shift)) + kw_strat = axis_strat.map(lambda t: {"axis": t}) + else: + axis_strat = st.none() + if x.ndim != 0: + axis_strat = axis_strat | st.integers(-x.ndim, x.ndim - 1) + kw_strat = hh.kwargs(axis=axis_strat) + kw = data.draw(kw_strat, label="kw") out = xp.roll(x, shift, **kw) @@ -291,12 +296,23 @@ def test_roll(x, data): ph.assert_result_shape("roll", (x.shape,), out.shape) - # TODO: test all shift/axis scenarios - if isinstance(shift, int) and kw.get("axis", None) is None: + if kw.get("axis", None) is None: + assert isinstance(shift, int) # sanity check indices = list(ah.ndindex(x.shape)) shifted_indices = deque(indices) shifted_indices.rotate(-shift) assert_array_ndindex("roll", x, indices, out, shifted_indices) + else: + _shift = (shift,) if isinstance(shift, int) else shift + axes = normalise_axis(kw["axis"], x.ndim) + all_indices = list(ah.ndindex(x.shape)) + for s, a in zip(_shift, axes): + side = x.shape[a] + for i in range(side): + indices = [idx for idx in all_indices if idx[a] == i] + shifted_indices = deque(indices) + shifted_indices.rotate(-s) + assert_array_ndindex("roll", x, indices, out, shifted_indices) @given( From aa7aaa06058e5968a3e5236ecdc3e3393dbb3f0f Mon Sep 17 00:00:00 2001 From: Matthew Date: Tue, 7 Dec 2021 11:52:26 +0000 Subject: [PATCH 31/41] Cover all axis scenarios in `test_stack` --- .../test_manipulation_functions.py | 44 +++++++++++-------- 1 file changed, 26 insertions(+), 18 deletions(-) diff --git a/array_api_tests/test_manipulation_functions.py b/array_api_tests/test_manipulation_functions.py index 2db88177..49f86694 100644 --- a/array_api_tests/test_manipulation_functions.py +++ b/array_api_tests/test_manipulation_functions.py @@ -1,7 +1,7 @@ import math from collections import deque from itertools import product -from typing import Iterable, Union +from typing import Iterable, Iterator, Tuple, Union from hypothesis import assume, given from hypothesis import strategies as st @@ -28,6 +28,16 @@ def shared_shapes(*args, **kwargs) -> st.SearchStrategy[Shape]: return st.shared(hh.shapes(*args, **kwargs), key="shape") +def axis_ndindex( + shape: Shape, axis: int +) -> Iterator[Tuple[Tuple[Union[int, slice], ...], ...]]: + assert axis >= 0 # sanity check + axis_indices = [range(side) for side in shape[:axis]] + for _ in range(axis, len(shape)): + axis_indices.append([slice(None, None)]) + yield from product(*axis_indices) + + def assert_array_ndindex( func_name: str, x: Array, @@ -115,10 +125,7 @@ def test_concat(dtypes, kw, data): ) else: out_indices = ah.ndindex(out.shape) - axis_indices = [range(side) for side in shapes[0][:_axis]] - for _ in range(_axis, len(shape)): - axis_indices.append([slice(None, None)]) - for idx in product(*axis_indices): + for idx in axis_ndindex(shapes[0], _axis): f_idx = ", ".join(str(i) if isinstance(i, int) else ":" for i in idx) for x_num, x in enumerate(arrays, 1): indexed_x = x[idx] @@ -344,18 +351,19 @@ def test_stack(shape, dtypes, kw, data): "stack", tuple(x.shape for x in arrays), out.shape, _shape, **kw ) - # TODO: adjust indices with nonzero axis - if axis == 0: - out_indices = ah.ndindex(out.shape) - for i, x in enumerate(arrays, 1): - msg_suffix = f" [stack({ph.fmt_kw(kw)})]\nx{i}={x!r}\n{out=}" - for x_idx in ah.ndindex(x.shape): + out_indices = ah.ndindex(out.shape) + for idx in axis_ndindex(arrays[0].shape, axis=_axis): + f_idx = ", ".join(str(i) if isinstance(i, int) else ":" for i in idx) + print(f"{f_idx=}") + for x_num, x in enumerate(arrays, 1): + indexed_x = x[idx] + for x_idx in ah.ndindex(indexed_x.shape): out_idx = next(out_indices) - msg = ( - f"out[{out_idx}]={out[out_idx]}, should be x{i}[{x_idx}]={x[x_idx]}" + assert_equals( + "stack", + f"x{x_num}[{f_idx}][{x_idx}]", + indexed_x[x_idx], + f"out[{out_idx}]", + out[out_idx], + **kw, ) - msg += msg_suffix - if dh.is_float_dtype(x.dtype) and xp.isnan(x[x_idx]): - assert xp.isnan(out[out_idx]), msg - else: - assert out[out_idx] == x[x_idx], msg From b32fff07f624e8cc36569393bb0f039e8d60dc04 Mon Sep 17 00:00:00 2001 From: Matthew Date: Tue, 7 Dec 2021 12:03:43 +0000 Subject: [PATCH 32/41] Make float assertions more lenient in statistical tests --- array_api_tests/test_statistical_functions.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/array_api_tests/test_statistical_functions.py b/array_api_tests/test_statistical_functions.py index 0373ac47..91aac683 100644 --- a/array_api_tests/test_statistical_functions.py +++ b/array_api_tests/test_statistical_functions.py @@ -94,7 +94,7 @@ def assert_equals( assert math.isnan(out), msg else: msg = f"{out_repr}={out}, should be roughly {expected} [{f_func}]" - assert math.isclose(out, expected, rel_tol=0.05), msg + assert math.isclose(out, expected, rel_tol=0.25, abs_tol=1), msg @given( @@ -175,6 +175,7 @@ def test_mean(x, data): ) for indices, out_idx in zip(axes_ndindex(x.shape, _axes), ah.ndindex(out.shape)): mean = float(out[out_idx]) + assume(not math.isinf(mean)) # mean may become inf due to internal overflows elements = [] for idx in indices: s = float(x[idx]) From 5978fdae0b9c0e4e4cf40d880d312ebc5d4d12d2 Mon Sep 17 00:00:00 2001 From: Matthew Date: Tue, 7 Dec 2021 12:26:31 +0000 Subject: [PATCH 33/41] Update sum and prod tests to use new default uint --- array_api_tests/dtype_helpers.py | 421 +++++++++--------- array_api_tests/test_statistical_functions.py | 16 +- 2 files changed, 224 insertions(+), 213 deletions(-) diff --git a/array_api_tests/dtype_helpers.py b/array_api_tests/dtype_helpers.py index 65b4090a..ce749d9e 100644 --- a/array_api_tests/dtype_helpers.py +++ b/array_api_tests/dtype_helpers.py @@ -1,45 +1,45 @@ -from warnings import warn from functools import lru_cache from typing import NamedTuple, Tuple, Union +from warnings import warn from . import _array_module as xp from ._array_module import _UndefinedStub from .typing import DataType, ScalarType - __all__ = [ - 'int_dtypes', - 'uint_dtypes', - 'all_int_dtypes', - 'float_dtypes', - 'numeric_dtypes', - 'all_dtypes', - 'dtype_to_name', - 'bool_and_all_int_dtypes', - 'dtype_to_scalars', - 'is_int_dtype', - 'is_float_dtype', - 'get_scalar_type', - 'dtype_ranges', - 'default_int', - 'default_float', - 'promotion_table', - 'dtype_nbits', - 'dtype_signed', - 'func_in_dtypes', - 'func_returns_bool', - 'binary_op_to_symbol', - 'unary_op_to_symbol', - 'inplace_op_to_symbol', - 'op_to_func', - 'fmt_types', + "int_dtypes", + "uint_dtypes", + "all_int_dtypes", + "float_dtypes", + "numeric_dtypes", + "all_dtypes", + "dtype_to_name", + "bool_and_all_int_dtypes", + "dtype_to_scalars", + "is_int_dtype", + "is_float_dtype", + "get_scalar_type", + "dtype_ranges", + "default_int", + "default_uint", + "default_float", + "promotion_table", + "dtype_nbits", + "dtype_signed", + "func_in_dtypes", + "func_returns_bool", + "binary_op_to_symbol", + "unary_op_to_symbol", + "inplace_op_to_symbol", + "op_to_func", + "fmt_types", ] -_uint_names = ('uint8', 'uint16', 'uint32', 'uint64') -_int_names = ('int8', 'int16', 'int32', 'int64') -_float_names = ('float32', 'float64') -_dtype_names = ('bool',) + _uint_names + _int_names + _float_names +_uint_names = ("uint8", "uint16", "uint32", "uint64") +_int_names = ("int8", "int16", "int32", "int64") +_float_names = ("float32", "float64") +_dtype_names = ("bool",) + _uint_names + _int_names + _float_names uint_dtypes = tuple(getattr(xp, name) for name in _uint_names) @@ -101,17 +101,34 @@ class MinMax(NamedTuple): xp.uint64: MinMax(0, +18_446_744_073_709_551_615), } +dtype_nbits = { + **{d: 8 for d in [xp.int8, xp.uint8]}, + **{d: 16 for d in [xp.int16, xp.uint16]}, + **{d: 32 for d in [xp.int32, xp.uint32, xp.float32]}, + **{d: 64 for d in [xp.int64, xp.uint64, xp.float64]}, +} + + +dtype_signed = { + **{d: True for d in int_dtypes}, + **{d: False for d in uint_dtypes}, +} + if isinstance(xp.asarray, _UndefinedStub): default_int = xp.int32 default_float = xp.float32 warn( - 'array module does not have attribute asarray. ' - 'default int is assumed int32, default float is assumed float32' + "array module does not have attribute asarray. " + "default int is assumed int32, default float is assumed float32" ) else: default_int = xp.asarray(int()).dtype default_float = xp.asarray(float()).dtype +if dtype_nbits[default_int] == 32: + default_uint = xp.uint32 +else: + default_uint = xp.uint64 _numeric_promotions = { @@ -173,200 +190,186 @@ def result_type(*dtypes: DataType): return result -dtype_nbits = { - **{d: 8 for d in [xp.int8, xp.uint8]}, - **{d: 16 for d in [xp.int16, xp.uint16]}, - **{d: 32 for d in [xp.int32, xp.uint32, xp.float32]}, - **{d: 64 for d in [xp.int64, xp.uint64, xp.float64]}, -} - - -dtype_signed = { - **{d: True for d in int_dtypes}, - **{d: False for d in uint_dtypes}, -} - - func_in_dtypes = { # elementwise - 'abs': numeric_dtypes, - 'acos': float_dtypes, - 'acosh': float_dtypes, - 'add': numeric_dtypes, - 'asin': float_dtypes, - 'asinh': float_dtypes, - 'atan': float_dtypes, - 'atan2': float_dtypes, - 'atanh': float_dtypes, - 'bitwise_and': bool_and_all_int_dtypes, - 'bitwise_invert': bool_and_all_int_dtypes, - 'bitwise_left_shift': all_int_dtypes, - 'bitwise_or': bool_and_all_int_dtypes, - 'bitwise_right_shift': all_int_dtypes, - 'bitwise_xor': bool_and_all_int_dtypes, - 'ceil': numeric_dtypes, - 'cos': float_dtypes, - 'cosh': float_dtypes, - 'divide': float_dtypes, - 'equal': all_dtypes, - 'exp': float_dtypes, - 'expm1': float_dtypes, - 'floor': numeric_dtypes, - 'floor_divide': numeric_dtypes, - 'greater': numeric_dtypes, - 'greater_equal': numeric_dtypes, - 'isfinite': numeric_dtypes, - 'isinf': numeric_dtypes, - 'isnan': numeric_dtypes, - 'less': numeric_dtypes, - 'less_equal': numeric_dtypes, - 'log': float_dtypes, - 'logaddexp': float_dtypes, - 'log10': float_dtypes, - 'log1p': float_dtypes, - 'log2': float_dtypes, - 'logical_and': (xp.bool,), - 'logical_not': (xp.bool,), - 'logical_or': (xp.bool,), - 'logical_xor': (xp.bool,), - 'multiply': numeric_dtypes, - 'negative': numeric_dtypes, - 'not_equal': all_dtypes, - 'positive': numeric_dtypes, - 'pow': float_dtypes, - 'remainder': numeric_dtypes, - 'round': numeric_dtypes, - 'sign': numeric_dtypes, - 'sin': float_dtypes, - 'sinh': float_dtypes, - 'sqrt': float_dtypes, - 'square': numeric_dtypes, - 'subtract': numeric_dtypes, - 'tan': float_dtypes, - 'tanh': float_dtypes, - 'trunc': numeric_dtypes, + "abs": numeric_dtypes, + "acos": float_dtypes, + "acosh": float_dtypes, + "add": numeric_dtypes, + "asin": float_dtypes, + "asinh": float_dtypes, + "atan": float_dtypes, + "atan2": float_dtypes, + "atanh": float_dtypes, + "bitwise_and": bool_and_all_int_dtypes, + "bitwise_invert": bool_and_all_int_dtypes, + "bitwise_left_shift": all_int_dtypes, + "bitwise_or": bool_and_all_int_dtypes, + "bitwise_right_shift": all_int_dtypes, + "bitwise_xor": bool_and_all_int_dtypes, + "ceil": numeric_dtypes, + "cos": float_dtypes, + "cosh": float_dtypes, + "divide": float_dtypes, + "equal": all_dtypes, + "exp": float_dtypes, + "expm1": float_dtypes, + "floor": numeric_dtypes, + "floor_divide": numeric_dtypes, + "greater": numeric_dtypes, + "greater_equal": numeric_dtypes, + "isfinite": numeric_dtypes, + "isinf": numeric_dtypes, + "isnan": numeric_dtypes, + "less": numeric_dtypes, + "less_equal": numeric_dtypes, + "log": float_dtypes, + "logaddexp": float_dtypes, + "log10": float_dtypes, + "log1p": float_dtypes, + "log2": float_dtypes, + "logical_and": (xp.bool,), + "logical_not": (xp.bool,), + "logical_or": (xp.bool,), + "logical_xor": (xp.bool,), + "multiply": numeric_dtypes, + "negative": numeric_dtypes, + "not_equal": all_dtypes, + "positive": numeric_dtypes, + "pow": float_dtypes, + "remainder": numeric_dtypes, + "round": numeric_dtypes, + "sign": numeric_dtypes, + "sin": float_dtypes, + "sinh": float_dtypes, + "sqrt": float_dtypes, + "square": numeric_dtypes, + "subtract": numeric_dtypes, + "tan": float_dtypes, + "tanh": float_dtypes, + "trunc": numeric_dtypes, # searching - 'where': all_dtypes, + "where": all_dtypes, } func_returns_bool = { # elementwise - 'abs': False, - 'acos': False, - 'acosh': False, - 'add': False, - 'asin': False, - 'asinh': False, - 'atan': False, - 'atan2': False, - 'atanh': False, - 'bitwise_and': False, - 'bitwise_invert': False, - 'bitwise_left_shift': False, - 'bitwise_or': False, - 'bitwise_right_shift': False, - 'bitwise_xor': False, - 'ceil': False, - 'cos': False, - 'cosh': False, - 'divide': False, - 'equal': True, - 'exp': False, - 'expm1': False, - 'floor': False, - 'floor_divide': False, - 'greater': True, - 'greater_equal': True, - 'isfinite': True, - 'isinf': True, - 'isnan': True, - 'less': True, - 'less_equal': True, - 'log': False, - 'logaddexp': False, - 'log10': False, - 'log1p': False, - 'log2': False, - 'logical_and': True, - 'logical_not': True, - 'logical_or': True, - 'logical_xor': True, - 'multiply': False, - 'negative': False, - 'not_equal': True, - 'positive': False, - 'pow': False, - 'remainder': False, - 'round': False, - 'sign': False, - 'sin': False, - 'sinh': False, - 'sqrt': False, - 'square': False, - 'subtract': False, - 'tan': False, - 'tanh': False, - 'trunc': False, + "abs": False, + "acos": False, + "acosh": False, + "add": False, + "asin": False, + "asinh": False, + "atan": False, + "atan2": False, + "atanh": False, + "bitwise_and": False, + "bitwise_invert": False, + "bitwise_left_shift": False, + "bitwise_or": False, + "bitwise_right_shift": False, + "bitwise_xor": False, + "ceil": False, + "cos": False, + "cosh": False, + "divide": False, + "equal": True, + "exp": False, + "expm1": False, + "floor": False, + "floor_divide": False, + "greater": True, + "greater_equal": True, + "isfinite": True, + "isinf": True, + "isnan": True, + "less": True, + "less_equal": True, + "log": False, + "logaddexp": False, + "log10": False, + "log1p": False, + "log2": False, + "logical_and": True, + "logical_not": True, + "logical_or": True, + "logical_xor": True, + "multiply": False, + "negative": False, + "not_equal": True, + "positive": False, + "pow": False, + "remainder": False, + "round": False, + "sign": False, + "sin": False, + "sinh": False, + "sqrt": False, + "square": False, + "subtract": False, + "tan": False, + "tanh": False, + "trunc": False, # searching - 'where': False, + "where": False, } unary_op_to_symbol = { - '__invert__': '~', - '__neg__': '-', - '__pos__': '+', + "__invert__": "~", + "__neg__": "-", + "__pos__": "+", } binary_op_to_symbol = { - '__add__': '+', - '__and__': '&', - '__eq__': '==', - '__floordiv__': '//', - '__ge__': '>=', - '__gt__': '>', - '__le__': '<=', - '__lshift__': '<<', - '__lt__': '<', - '__matmul__': '@', - '__mod__': '%', - '__mul__': '*', - '__ne__': '!=', - '__or__': '|', - '__pow__': '**', - '__rshift__': '>>', - '__sub__': '-', - '__truediv__': '/', - '__xor__': '^', + "__add__": "+", + "__and__": "&", + "__eq__": "==", + "__floordiv__": "//", + "__ge__": ">=", + "__gt__": ">", + "__le__": "<=", + "__lshift__": "<<", + "__lt__": "<", + "__matmul__": "@", + "__mod__": "%", + "__mul__": "*", + "__ne__": "!=", + "__or__": "|", + "__pow__": "**", + "__rshift__": ">>", + "__sub__": "-", + "__truediv__": "/", + "__xor__": "^", } op_to_func = { - '__abs__': 'abs', - '__add__': 'add', - '__and__': 'bitwise_and', - '__eq__': 'equal', - '__floordiv__': 'floor_divide', - '__ge__': 'greater_equal', - '__gt__': 'greater', - '__le__': 'less_equal', - '__lt__': 'less', + "__abs__": "abs", + "__add__": "add", + "__and__": "bitwise_and", + "__eq__": "equal", + "__floordiv__": "floor_divide", + "__ge__": "greater_equal", + "__gt__": "greater", + "__le__": "less_equal", + "__lt__": "less", # '__matmul__': 'matmul', # TODO: support matmul - '__mod__': 'remainder', - '__mul__': 'multiply', - '__ne__': 'not_equal', - '__or__': 'bitwise_or', - '__pow__': 'pow', - '__lshift__': 'bitwise_left_shift', - '__rshift__': 'bitwise_right_shift', - '__sub__': 'subtract', - '__truediv__': 'divide', - '__xor__': 'bitwise_xor', - '__invert__': 'bitwise_invert', - '__neg__': 'negative', - '__pos__': 'positive', + "__mod__": "remainder", + "__mul__": "multiply", + "__ne__": "not_equal", + "__or__": "bitwise_or", + "__pow__": "pow", + "__lshift__": "bitwise_left_shift", + "__rshift__": "bitwise_right_shift", + "__sub__": "subtract", + "__truediv__": "divide", + "__xor__": "bitwise_xor", + "__invert__": "bitwise_invert", + "__neg__": "negative", + "__pos__": "positive", } @@ -377,10 +380,10 @@ def result_type(*dtypes: DataType): inplace_op_to_symbol = {} for op, symbol in binary_op_to_symbol.items(): - if op == '__matmul__' or func_returns_bool[op]: + if op == "__matmul__" or func_returns_bool[op]: continue - iop = f'__i{op[2:]}' - inplace_op_to_symbol[iop] = f'{symbol}=' + iop = f"__i{op[2:]}" + inplace_op_to_symbol[iop] = f"{symbol}=" func_in_dtypes[iop] = func_in_dtypes[op] func_returns_bool[iop] = func_returns_bool[op] @@ -394,4 +397,4 @@ def fmt_types(types: Tuple[Union[DataType, ScalarType], ...]) -> str: except KeyError: # i.e. dtype is bool, int, or float f_types.append(type_.__name__) - return ', '.join(f_types) + return ", ".join(f_types) diff --git a/array_api_tests/test_statistical_functions.py b/array_api_tests/test_statistical_functions.py index 91aac683..d7ab381a 100644 --- a/array_api_tests/test_statistical_functions.py +++ b/array_api_tests/test_statistical_functions.py @@ -207,12 +207,16 @@ def test_prod(x, data): dtype = kw.get("dtype", None) if dtype is None: if dh.is_int_dtype(x.dtype): + if x.dtype in dh.uint_dtypes: + default_dtype = dh.default_uint + else: + default_dtype = dh.default_int m, M = dh.dtype_ranges[x.dtype] - d_m, d_M = dh.dtype_ranges[dh.default_int] + d_m, d_M = dh.dtype_ranges[default_dtype] if m < d_m or M > d_M: _dtype = x.dtype else: - _dtype = dh.default_int + _dtype = default_dtype else: if dh.dtype_nbits[x.dtype] > dh.dtype_nbits[dh.default_float]: _dtype = x.dtype @@ -333,12 +337,16 @@ def test_sum(x, data): dtype = kw.get("dtype", None) if dtype is None: if dh.is_int_dtype(x.dtype): + if x.dtype in dh.uint_dtypes: + default_dtype = dh.default_uint + else: + default_dtype = dh.default_int m, M = dh.dtype_ranges[x.dtype] - d_m, d_M = dh.dtype_ranges[dh.default_int] + d_m, d_M = dh.dtype_ranges[default_dtype] if m < d_m or M > d_M: _dtype = x.dtype else: - _dtype = dh.default_int + _dtype = default_dtype else: if dh.dtype_nbits[x.dtype] > dh.dtype_nbits[dh.default_float]: _dtype = x.dtype From 9f3a83e73ea9fc6120ed991798a108d439157f7c Mon Sep 17 00:00:00 2001 From: Matthew Date: Wed, 8 Dec 2021 09:56:44 +0000 Subject: [PATCH 34/41] Sort statistical tests by spec order --- array_api_tests/test_statistical_functions.py | 124 +++++++++--------- 1 file changed, 62 insertions(+), 62 deletions(-) diff --git a/array_api_tests/test_statistical_functions.py b/array_api_tests/test_statistical_functions.py index d7ab381a..5ec98fca 100644 --- a/array_api_tests/test_statistical_functions.py +++ b/array_api_tests/test_statistical_functions.py @@ -105,83 +105,83 @@ def assert_equals( ), data=st.data(), ) -def test_min(x, data): +def test_max(x, data): kw = data.draw(hh.kwargs(axis=axes(x.ndim), keepdims=st.booleans()), label="kw") - out = xp.min(x, **kw) + out = xp.max(x, **kw) - ph.assert_dtype("min", x.dtype, out.dtype) + ph.assert_dtype("max", x.dtype, out.dtype) _axes = normalise_axis(kw.get("axis", None), x.ndim) assert_keepdimable_shape( - "min", out.shape, x.shape, _axes, kw.get("keepdims", False), **kw + "max", out.shape, x.shape, _axes, kw.get("keepdims", False), **kw ) scalar_type = dh.get_scalar_type(out.dtype) for indices, out_idx in zip(axes_ndindex(x.shape, _axes), ah.ndindex(out.shape)): - min_ = scalar_type(out[out_idx]) + max_ = scalar_type(out[out_idx]) elements = [] for idx in indices: s = scalar_type(x[idx]) elements.append(s) - expected = min(elements) - assert_equals("min", dh.get_scalar_type(out.dtype), out_idx, min_, expected) + expected = max(elements) + assert_equals("max", dh.get_scalar_type(out.dtype), out_idx, max_, expected) @given( x=xps.arrays( - dtype=xps.numeric_dtypes(), + dtype=xps.floating_dtypes(), shape=hh.shapes(min_side=1), elements={"allow_nan": False}, ), data=st.data(), ) -def test_max(x, data): +def test_mean(x, data): kw = data.draw(hh.kwargs(axis=axes(x.ndim), keepdims=st.booleans()), label="kw") - out = xp.max(x, **kw) + out = xp.mean(x, **kw) - ph.assert_dtype("max", x.dtype, out.dtype) + ph.assert_dtype("mean", x.dtype, out.dtype) _axes = normalise_axis(kw.get("axis", None), x.ndim) assert_keepdimable_shape( - "max", out.shape, x.shape, _axes, kw.get("keepdims", False), **kw + "mean", out.shape, x.shape, _axes, kw.get("keepdims", False), **kw ) - scalar_type = dh.get_scalar_type(out.dtype) for indices, out_idx in zip(axes_ndindex(x.shape, _axes), ah.ndindex(out.shape)): - max_ = scalar_type(out[out_idx]) + mean = float(out[out_idx]) + assume(not math.isinf(mean)) # mean may become inf due to internal overflows elements = [] for idx in indices: - s = scalar_type(x[idx]) + s = float(x[idx]) elements.append(s) - expected = max(elements) - assert_equals("max", dh.get_scalar_type(out.dtype), out_idx, max_, expected) + expected = sum(elements) / len(elements) + assert_equals("mean", dh.get_scalar_type(out.dtype), out_idx, mean, expected) @given( x=xps.arrays( - dtype=xps.floating_dtypes(), + dtype=xps.numeric_dtypes(), shape=hh.shapes(min_side=1), elements={"allow_nan": False}, ), data=st.data(), ) -def test_mean(x, data): +def test_min(x, data): kw = data.draw(hh.kwargs(axis=axes(x.ndim), keepdims=st.booleans()), label="kw") - out = xp.mean(x, **kw) + out = xp.min(x, **kw) - ph.assert_dtype("mean", x.dtype, out.dtype) + ph.assert_dtype("min", x.dtype, out.dtype) _axes = normalise_axis(kw.get("axis", None), x.ndim) assert_keepdimable_shape( - "mean", out.shape, x.shape, _axes, kw.get("keepdims", False), **kw + "min", out.shape, x.shape, _axes, kw.get("keepdims", False), **kw ) + scalar_type = dh.get_scalar_type(out.dtype) for indices, out_idx in zip(axes_ndindex(x.shape, _axes), ah.ndindex(out.shape)): - mean = float(out[out_idx]) - assume(not math.isinf(mean)) # mean may become inf due to internal overflows + min_ = scalar_type(out[out_idx]) elements = [] for idx in indices: - s = float(x[idx]) + s = scalar_type(x[idx]) elements.append(s) - expected = sum(elements) / len(elements) - assert_equals("mean", dh.get_scalar_type(out.dtype), out_idx, mean, expected) + expected = min(elements) + assert_equals("min", dh.get_scalar_type(out.dtype), out_idx, min_, expected) @given( @@ -279,41 +279,6 @@ def test_std(x, data): # We can't easily test the result(s) as standard deviation methods vary a lot -@given( - x=xps.arrays( - dtype=xps.floating_dtypes(), - shape=hh.shapes(min_side=1), - elements={"allow_nan": False}, - ).filter(lambda x: x.size >= 2), - data=st.data(), -) -def test_var(x, data): - axis = data.draw(axes(x.ndim), label="axis") - _axes = normalise_axis(axis, x.ndim) - N = sum(side for axis, side in enumerate(x.shape) if axis not in _axes) - correction = data.draw( - st.floats(0.0, N, allow_infinity=False, allow_nan=False) | st.integers(0, N), - label="correction", - ) - keepdims = data.draw(st.booleans(), label="keepdims") - kw = data.draw( - hh.specified_kwargs( - ("axis", axis, None), - ("correction", correction, 0.0), - ("keepdims", keepdims, False), - ), - label="kw", - ) - - out = xp.var(x, **kw) - - ph.assert_dtype("var", x.dtype, out.dtype) - assert_keepdimable_shape( - "var", out.shape, x.shape, _axes, kw.get("keepdims", False), **kw - ) - # We can't easily test the result(s) as variance methods vary a lot - - @given( x=xps.arrays( dtype=xps.numeric_dtypes(), @@ -372,3 +337,38 @@ def test_sum(x, data): m, M = dh.dtype_ranges[out.dtype] assume(m <= expected <= M) assert_equals("sum", dh.get_scalar_type(out.dtype), out_idx, sum_, expected) + + +@given( + x=xps.arrays( + dtype=xps.floating_dtypes(), + shape=hh.shapes(min_side=1), + elements={"allow_nan": False}, + ).filter(lambda x: x.size >= 2), + data=st.data(), +) +def test_var(x, data): + axis = data.draw(axes(x.ndim), label="axis") + _axes = normalise_axis(axis, x.ndim) + N = sum(side for axis, side in enumerate(x.shape) if axis not in _axes) + correction = data.draw( + st.floats(0.0, N, allow_infinity=False, allow_nan=False) | st.integers(0, N), + label="correction", + ) + keepdims = data.draw(st.booleans(), label="keepdims") + kw = data.draw( + hh.specified_kwargs( + ("axis", axis, None), + ("correction", correction, 0.0), + ("keepdims", keepdims, False), + ), + label="kw", + ) + + out = xp.var(x, **kw) + + ph.assert_dtype("var", x.dtype, out.dtype) + assert_keepdimable_shape( + "var", out.shape, x.shape, _axes, kw.get("keepdims", False), **kw + ) + # We can't easily test the result(s) as variance methods vary a lot From 09aa26b831e66abf02322604949d35af5b2c4256 Mon Sep 17 00:00:00 2001 From: Matthew Date: Wed, 8 Dec 2021 10:28:10 +0000 Subject: [PATCH 35/41] Check error raising in `test_squeeze`, use negative axes --- .../test_manipulation_functions.py | 22 +++++++++++-------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/array_api_tests/test_manipulation_functions.py b/array_api_tests/test_manipulation_functions.py index 49f86694..47ed8576 100644 --- a/array_api_tests/test_manipulation_functions.py +++ b/array_api_tests/test_manipulation_functions.py @@ -3,6 +3,7 @@ from itertools import product from typing import Iterable, Iterator, Tuple, Union +import pytest from hypothesis import assume, given from hypothesis import strategies as st @@ -168,23 +169,26 @@ def test_expand_dims(x, axis): data=st.data(), ) def test_squeeze(x, data): - # TODO: generate valid negative axis (which keep uniqueness) - squeezable_axes = st.sampled_from( - [i for i, side in enumerate(x.shape) if side == 1] - ) + axes = st.integers(-x.ndim, x.ndim - 1) axis = data.draw( - squeezable_axes | st.lists(squeezable_axes, unique=True).map(tuple), + axes + | st.lists(axes, unique_by=lambda i: i if i >= 0 else i + x.ndim).map(tuple), label="axis", ) + axes = (axis,) if isinstance(axis, int) else axis + axes = normalise_axis(axes, x.ndim) + + squeezable_axes = [i for i, side in enumerate(x.shape) if side == 1] + if any(i not in squeezable_axes for i in axes): + with pytest.raises(ValueError): + xp.squeeze(x, axis) + return + out = xp.squeeze(x, axis) ph.assert_dtype("squeeze", x.dtype, out.dtype) - if isinstance(axis, int): - axes = (axis,) - else: - axes = axis shape = [] for i, side in enumerate(x.shape): if i not in axes: From 18709f6310cfd9d7afcdf9c93fcdd68845db3b96 Mon Sep 17 00:00:00 2001 From: Matthew Date: Wed, 8 Dec 2021 10:46:32 +0000 Subject: [PATCH 36/41] Cover invalid axis in `test_expand_dims` --- array_api_tests/test_manipulation_functions.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/array_api_tests/test_manipulation_functions.py b/array_api_tests/test_manipulation_functions.py index 47ed8576..9e33567a 100644 --- a/array_api_tests/test_manipulation_functions.py +++ b/array_api_tests/test_manipulation_functions.py @@ -144,9 +144,17 @@ def test_concat(dtypes, kw, data): @given( x=xps.arrays(dtype=xps.scalar_dtypes(), shape=shared_shapes()), - axis=shared_shapes().flatmap(lambda s: st.integers(-len(s) - 1, len(s))), + axis=shared_shapes().flatmap( + # Generate both valid and invalid axis + lambda s: st.integers(2 * (-len(s) - 1), 2 * len(s)) + ), ) def test_expand_dims(x, axis): + if axis < -x.ndim - 1 or axis > x.ndim: + with pytest.raises(IndexError): + xp.expand_dims(x, axis=axis) + return + out = xp.expand_dims(x, axis=axis) ph.assert_dtype("expand_dims", x.dtype, out.dtype) From 2ad6ddcb4dca8a69d4cfe7db8e57fa7c2131baa6 Mon Sep 17 00:00:00 2001 From: Matthew Date: Wed, 8 Dec 2021 11:05:32 +0000 Subject: [PATCH 37/41] Try to ignore overflow scenarios in `prod` and `sum` tests --- array_api_tests/test_statistical_functions.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/array_api_tests/test_statistical_functions.py b/array_api_tests/test_statistical_functions.py index 5ec98fca..24067108 100644 --- a/array_api_tests/test_statistical_functions.py +++ b/array_api_tests/test_statistical_functions.py @@ -4,6 +4,7 @@ from hypothesis import assume, given from hypothesis import strategies as st +from hypothesis.control import reject from . import _array_module as xp from . import array_helpers as ah @@ -202,7 +203,10 @@ def test_prod(x, data): label="kw", ) - out = xp.prod(x, **kw) + try: + out = xp.prod(x, **kw) + except OverflowError: + reject() dtype = kw.get("dtype", None) if dtype is None: @@ -232,7 +236,7 @@ def test_prod(x, data): scalar_type = dh.get_scalar_type(out.dtype) for indices, out_idx in zip(axes_ndindex(x.shape, _axes), ah.ndindex(out.shape)): prod = scalar_type(out[out_idx]) - assume(not math.isinf(prod)) + assume(math.isfinite(prod)) elements = [] for idx in indices: s = scalar_type(x[idx]) @@ -297,7 +301,10 @@ def test_sum(x, data): label="kw", ) - out = xp.sum(x, **kw) + try: + out = xp.sum(x, **kw) + except OverflowError: + reject() dtype = kw.get("dtype", None) if dtype is None: @@ -327,7 +334,7 @@ def test_sum(x, data): scalar_type = dh.get_scalar_type(out.dtype) for indices, out_idx in zip(axes_ndindex(x.shape, _axes), ah.ndindex(out.shape)): sum_ = scalar_type(out[out_idx]) - assume(not math.isinf(sum_)) + assume(math.isfinite(sum_)) elements = [] for idx in indices: s = scalar_type(x[idx]) From 6e4564b618e38400c3f50cc1bb47fb20c52ea606 Mon Sep 17 00:00:00 2001 From: Matthew Date: Wed, 8 Dec 2021 11:15:54 +0000 Subject: [PATCH 38/41] Docstrings for axes helpers --- array_api_tests/test_manipulation_functions.py | 3 ++- array_api_tests/test_statistical_functions.py | 17 +++++++++-------- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/array_api_tests/test_manipulation_functions.py b/array_api_tests/test_manipulation_functions.py index 9e33567a..a2e2e2a3 100644 --- a/array_api_tests/test_manipulation_functions.py +++ b/array_api_tests/test_manipulation_functions.py @@ -13,7 +13,7 @@ from . import hypothesis_helpers as hh from . import pytest_helpers as ph from . import xps -from .test_statistical_functions import axes_ndindex, normalise_axis # TODO: Move +from .test_statistical_functions import axes_ndindex, normalise_axis from .typing import Array, Shape MAX_SIDE = hh.MAX_ARRAY_SIZE // 64 @@ -32,6 +32,7 @@ def shared_shapes(*args, **kwargs) -> st.SearchStrategy[Shape]: def axis_ndindex( shape: Shape, axis: int ) -> Iterator[Tuple[Tuple[Union[int, slice], ...], ...]]: + """Generate indices that index all elements in dimensions beyond `axis`""" assert axis >= 0 # sanity check axis_indices = [range(side) for side in shape[:axis]] for _ in range(axis, len(shape)): diff --git a/array_api_tests/test_statistical_functions.py b/array_api_tests/test_statistical_functions.py index 24067108..c62be30f 100644 --- a/array_api_tests/test_statistical_functions.py +++ b/array_api_tests/test_statistical_functions.py @@ -39,18 +39,19 @@ def normalise_axis( def axes_ndindex(shape: Shape, axes: Tuple[int, ...]) -> Iterator[Tuple[Shape, ...]]: - base_iterables = [] - axes_iterables = [] + """Generate indices that index all elements except in `axes` dimensions""" + base_indices = [] + axes_indices = [] for axis, side in enumerate(shape): if axis in axes: - base_iterables.append([None]) - axes_iterables.append(range(side)) + base_indices.append([None]) + axes_indices.append(range(side)) else: - base_iterables.append(range(side)) - axes_iterables.append([None]) - for base_idx in product(*base_iterables): + base_indices.append(range(side)) + axes_indices.append([None]) + for base_idx in product(*base_indices): indices = [] - for idx in product(*axes_iterables): + for idx in product(*axes_indices): idx = list(idx) for axis, side in enumerate(idx): if axis not in axes: From 0326aa395c2ec75d4133962fb877f5956e538504 Mon Sep 17 00:00:00 2001 From: Matthew Date: Wed, 8 Dec 2021 18:56:28 +0000 Subject: [PATCH 39/41] Remove redundant calls to `dh.get_scalar_type()` --- array_api_tests/test_statistical_functions.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/array_api_tests/test_statistical_functions.py b/array_api_tests/test_statistical_functions.py index c62be30f..7813d2ea 100644 --- a/array_api_tests/test_statistical_functions.py +++ b/array_api_tests/test_statistical_functions.py @@ -125,7 +125,7 @@ def test_max(x, data): s = scalar_type(x[idx]) elements.append(s) expected = max(elements) - assert_equals("max", dh.get_scalar_type(out.dtype), out_idx, max_, expected) + assert_equals("max", scalar_type, out_idx, max_, expected) @given( @@ -154,7 +154,7 @@ def test_mean(x, data): s = float(x[idx]) elements.append(s) expected = sum(elements) / len(elements) - assert_equals("mean", dh.get_scalar_type(out.dtype), out_idx, mean, expected) + assert_equals("mean", float, out_idx, mean, expected) @given( @@ -183,7 +183,7 @@ def test_min(x, data): s = scalar_type(x[idx]) elements.append(s) expected = min(elements) - assert_equals("min", dh.get_scalar_type(out.dtype), out_idx, min_, expected) + assert_equals("min", scalar_type, out_idx, min_, expected) @given( @@ -246,7 +246,7 @@ def test_prod(x, data): if dh.is_int_dtype(out.dtype): m, M = dh.dtype_ranges[out.dtype] assume(m <= expected <= M) - assert_equals("prod", dh.get_scalar_type(out.dtype), out_idx, prod, expected) + assert_equals("prod", scalar_type, out_idx, prod, expected) @given( @@ -344,7 +344,7 @@ def test_sum(x, data): if dh.is_int_dtype(out.dtype): m, M = dh.dtype_ranges[out.dtype] assume(m <= expected <= M) - assert_equals("sum", dh.get_scalar_type(out.dtype), out_idx, sum_, expected) + assert_equals("sum", scalar_type, out_idx, sum_, expected) @given( From cd8f1172b2de845b546312d750f05cba39cfacf9 Mon Sep 17 00:00:00 2001 From: Matthew Date: Fri, 10 Dec 2021 19:30:58 +0000 Subject: [PATCH 40/41] Fix `test_manipulation_functions.assert_equals` --- array_api_tests/test_manipulation_functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/array_api_tests/test_manipulation_functions.py b/array_api_tests/test_manipulation_functions.py index a2e2e2a3..1e95ca10 100644 --- a/array_api_tests/test_manipulation_functions.py +++ b/array_api_tests/test_manipulation_functions.py @@ -67,7 +67,7 @@ def assert_equals( if dh.is_float_dtype(out_val.dtype) and xp.isnan(out_val): assert xp.isnan(x_val), msg else: - assert out_val == out_val, msg + assert x_val == out_val, msg @st.composite From af285de6675fe4cdfabce6a610e16c6c7542c37a Mon Sep 17 00:00:00 2001 From: Matthew Date: Fri, 10 Dec 2021 19:44:00 +0000 Subject: [PATCH 41/41] Skip `test_roll` as its wrong --- array_api_tests/test_manipulation_functions.py | 1 + 1 file changed, 1 insertion(+) diff --git a/array_api_tests/test_manipulation_functions.py b/array_api_tests/test_manipulation_functions.py index 1e95ca10..9453fc25 100644 --- a/array_api_tests/test_manipulation_functions.py +++ b/array_api_tests/test_manipulation_functions.py @@ -292,6 +292,7 @@ def test_reshape(x, data): assert_array_ndindex("reshape", x, ah.ndindex(x.shape), out, ah.ndindex(out.shape)) +@pytest.mark.skip(reason="faulty test logic") # TODO @given(xps.arrays(dtype=xps.scalar_dtypes(), shape=shared_shapes()), st.data()) def test_roll(x, data): shift_strat = st.integers(-hh.MAX_ARRAY_SIZE, hh.MAX_ARRAY_SIZE)