Skip to content

Commit 69c2dab

Browse files
authored
Merge pull request #61 from honno/unstable-sort
Correctly test unstable argsorts
2 parents b08b33f + 424898c commit 69c2dab

File tree

2 files changed

+57
-17
lines changed

2 files changed

+57
-17
lines changed

xptests/pytest_helpers.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -208,13 +208,13 @@ def assert_scalar_equals(
208208
out_repr = "out" if idx == () else f"out[{idx}]"
209209
f_func = f"{func_name}({fmt_kw(kw)})"
210210
if type_ is bool or type_ is int:
211-
msg = f"{out_repr}={out}, should be {expected} [{f_func}]"
211+
msg = f"{out_repr}={out}, but should be {expected} [{f_func}]"
212212
assert out == expected, msg
213213
elif math.isnan(expected):
214-
msg = f"{out_repr}={out}, should be {expected} [{f_func}]"
214+
msg = f"{out_repr}={out}, but should be {expected} [{f_func}]"
215215
assert math.isnan(out), msg
216216
else:
217-
msg = f"{out_repr}={out}, should be roughly {expected} [{f_func}]"
217+
msg = f"{out_repr}={out}, but should be roughly {expected} [{f_func}]"
218218
assert math.isclose(out, expected, rel_tol=0.25, abs_tol=1), msg
219219

220220

xptests/test_sorting_functions.py

Lines changed: 54 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,12 @@
1+
import math
2+
from typing import Set
3+
14
from hypothesis import given
25
from hypothesis import strategies as st
36
from hypothesis.control import assume
47

8+
from xptests.typing import Scalar, ScalarType, Shape
9+
510
from . import _array_module as xp
611
from . import dtype_helpers as dh
712
from . import hypothesis_helpers as hh
@@ -10,6 +15,22 @@
1015
from . import xps
1116

1217

18+
def assert_scalar_in_set(
19+
func_name: str,
20+
type_: ScalarType,
21+
idx: Shape,
22+
out: Scalar,
23+
set_: Set[Scalar],
24+
/,
25+
**kw,
26+
):
27+
out_repr = "out" if idx == () else f"out[{idx}]"
28+
if math.isnan(out):
29+
raise NotImplementedError()
30+
msg = f"{out_repr}={out}, but should be in {set_} [{func_name}({ph.fmt_kw(kw)})]"
31+
assert out in set_, msg
32+
33+
1334
# TODO: Test with signed zeros and NaNs (and ignore them somehow)
1435
@given(
1536
x=xps.arrays(
@@ -34,20 +55,39 @@ def test_argsort(x, data):
3455

3556
out = xp.argsort(x, **kw)
3657

37-
ph.assert_default_index("sort", out.dtype)
38-
ph.assert_shape("sort", out.shape, x.shape, **kw)
58+
ph.assert_default_index("argsort", out.dtype)
59+
ph.assert_shape("argsort", out.shape, x.shape, **kw)
3960
axis = kw.get("axis", -1)
4061
axes = sh.normalise_axis(axis, x.ndim)
41-
descending = kw.get("descending", False)
4262
scalar_type = dh.get_scalar_type(x.dtype)
4363
for indices in sh.axes_ndindex(x.shape, axes):
4464
elements = [scalar_type(x[idx]) for idx in indices]
45-
indices_order = sorted(range(len(indices)), key=elements.__getitem__)
46-
if descending:
47-
# sorted(..., reverse=descending) doesn't always work
48-
indices_order = reversed(indices_order)
49-
for idx, o in zip(indices, indices_order):
50-
ph.assert_scalar_equals("argsort", int, idx, int(out[idx]), o)
65+
orders = sorted(range(len(elements)), key=elements.__getitem__)
66+
if kw.get("descending", False):
67+
orders = reversed(orders)
68+
if kw.get("stable", True):
69+
for idx, o in zip(indices, orders):
70+
ph.assert_scalar_equals("argsort", int, idx, int(out[idx]), o)
71+
else:
72+
idx_elements = dict(zip(indices, elements))
73+
idx_orders = dict(zip(indices, orders))
74+
element_orders = {}
75+
for e in set(elements):
76+
element_orders[e] = [
77+
idx_orders[idx] for idx in indices if idx_elements[idx] == e
78+
]
79+
for idx, e in zip(indices, elements):
80+
o = int(out[idx])
81+
expected_orders = element_orders[e]
82+
if len(expected_orders) == 1:
83+
expected_order = expected_orders[0]
84+
ph.assert_scalar_equals(
85+
"argsort", int, idx, o, expected_order, **kw
86+
)
87+
else:
88+
assert_scalar_in_set(
89+
"argsort", int, idx, o, set(expected_orders), **kw
90+
)
5191

5292

5393
# TODO: Test with signed zeros and NaNs (and ignore them somehow)
@@ -78,15 +118,15 @@ def test_sort(x, data):
78118
ph.assert_shape("sort", out.shape, x.shape, **kw)
79119
axis = kw.get("axis", -1)
80120
axes = sh.normalise_axis(axis, x.ndim)
81-
descending = kw.get("descending", False)
82121
scalar_type = dh.get_scalar_type(x.dtype)
83122
for indices in sh.axes_ndindex(x.shape, axes):
84123
elements = [scalar_type(x[idx]) for idx in indices]
85-
indices_order = sorted(
86-
range(len(indices)), key=elements.__getitem__, reverse=descending
124+
size = len(elements)
125+
orders = sorted(
126+
range(size), key=elements.__getitem__, reverse=kw.get("descending", False)
87127
)
88-
x_indices = [indices[o] for o in indices_order]
89-
for out_idx, x_idx in zip(indices, x_indices):
128+
for out_idx, o in zip(indices, orders):
129+
x_idx = indices[o]
90130
ph.assert_0d_equals(
91131
"sort",
92132
f"x[{x_idx}]",

0 commit comments

Comments
 (0)