From 424898c41a00b31e97f4d91ed55450ebacfd75ac Mon Sep 17 00:00:00 2001 From: Matthew Date: Thu, 6 Jan 2022 11:56:54 +0000 Subject: [PATCH] Correctly test unstable argsorts --- xptests/pytest_helpers.py | 6 +-- xptests/test_sorting_functions.py | 68 ++++++++++++++++++++++++------- 2 files changed, 57 insertions(+), 17 deletions(-) diff --git a/xptests/pytest_helpers.py b/xptests/pytest_helpers.py index c8fd1fdb..b10647dc 100644 --- a/xptests/pytest_helpers.py +++ b/xptests/pytest_helpers.py @@ -208,13 +208,13 @@ def assert_scalar_equals( out_repr = "out" if idx == () else f"out[{idx}]" f_func = f"{func_name}({fmt_kw(kw)})" if type_ is bool or type_ is int: - msg = f"{out_repr}={out}, should be {expected} [{f_func}]" + msg = f"{out_repr}={out}, but should be {expected} [{f_func}]" assert out == expected, msg elif math.isnan(expected): - msg = f"{out_repr}={out}, should be {expected} [{f_func}]" + msg = f"{out_repr}={out}, but should be {expected} [{f_func}]" assert math.isnan(out), msg else: - msg = f"{out_repr}={out}, should be roughly {expected} [{f_func}]" + msg = f"{out_repr}={out}, but should be roughly {expected} [{f_func}]" assert math.isclose(out, expected, rel_tol=0.25, abs_tol=1), msg diff --git a/xptests/test_sorting_functions.py b/xptests/test_sorting_functions.py index 0c7334cc..3a527ff7 100644 --- a/xptests/test_sorting_functions.py +++ b/xptests/test_sorting_functions.py @@ -1,7 +1,12 @@ +import math +from typing import Set + from hypothesis import given from hypothesis import strategies as st from hypothesis.control import assume +from xptests.typing import Scalar, ScalarType, Shape + from . import _array_module as xp from . import dtype_helpers as dh from . import hypothesis_helpers as hh @@ -10,6 +15,22 @@ from . import xps +def assert_scalar_in_set( + func_name: str, + type_: ScalarType, + idx: Shape, + out: Scalar, + set_: Set[Scalar], + /, + **kw, +): + out_repr = "out" if idx == () else f"out[{idx}]" + if math.isnan(out): + raise NotImplementedError() + msg = f"{out_repr}={out}, but should be in {set_} [{func_name}({ph.fmt_kw(kw)})]" + assert out in set_, msg + + # TODO: Test with signed zeros and NaNs (and ignore them somehow) @given( x=xps.arrays( @@ -34,20 +55,39 @@ def test_argsort(x, data): out = xp.argsort(x, **kw) - ph.assert_default_index("sort", out.dtype) - ph.assert_shape("sort", out.shape, x.shape, **kw) + ph.assert_default_index("argsort", out.dtype) + ph.assert_shape("argsort", out.shape, x.shape, **kw) axis = kw.get("axis", -1) axes = sh.normalise_axis(axis, x.ndim) - descending = kw.get("descending", False) scalar_type = dh.get_scalar_type(x.dtype) for indices in sh.axes_ndindex(x.shape, axes): elements = [scalar_type(x[idx]) for idx in indices] - indices_order = sorted(range(len(indices)), key=elements.__getitem__) - if descending: - # sorted(..., reverse=descending) doesn't always work - indices_order = reversed(indices_order) - for idx, o in zip(indices, indices_order): - ph.assert_scalar_equals("argsort", int, idx, int(out[idx]), o) + orders = sorted(range(len(elements)), key=elements.__getitem__) + if kw.get("descending", False): + orders = reversed(orders) + if kw.get("stable", True): + for idx, o in zip(indices, orders): + ph.assert_scalar_equals("argsort", int, idx, int(out[idx]), o) + else: + idx_elements = dict(zip(indices, elements)) + idx_orders = dict(zip(indices, orders)) + element_orders = {} + for e in set(elements): + element_orders[e] = [ + idx_orders[idx] for idx in indices if idx_elements[idx] == e + ] + for idx, e in zip(indices, elements): + o = int(out[idx]) + expected_orders = element_orders[e] + if len(expected_orders) == 1: + expected_order = expected_orders[0] + ph.assert_scalar_equals( + "argsort", int, idx, o, expected_order, **kw + ) + else: + assert_scalar_in_set( + "argsort", int, idx, o, set(expected_orders), **kw + ) # TODO: Test with signed zeros and NaNs (and ignore them somehow) @@ -78,15 +118,15 @@ def test_sort(x, data): ph.assert_shape("sort", out.shape, x.shape, **kw) axis = kw.get("axis", -1) axes = sh.normalise_axis(axis, x.ndim) - descending = kw.get("descending", False) scalar_type = dh.get_scalar_type(x.dtype) for indices in sh.axes_ndindex(x.shape, axes): elements = [scalar_type(x[idx]) for idx in indices] - indices_order = sorted( - range(len(indices)), key=elements.__getitem__, reverse=descending + size = len(elements) + orders = sorted( + range(size), key=elements.__getitem__, reverse=kw.get("descending", False) ) - x_indices = [indices[o] for o in indices_order] - for out_idx, x_idx in zip(indices, x_indices): + for out_idx, o in zip(indices, orders): + x_idx = indices[o] ph.assert_0d_equals( "sort", f"x[{x_idx}]",