Skip to content

Commit 6e50fde

Browse files
oleksandr-pavlykndgrigorian
authored andcommitted
Add information displayed on failure, renamed variables
Add check of computed against expected indices
1 parent 16aca70 commit 6e50fde

File tree

1 file changed

+83
-14
lines changed

1 file changed

+83
-14
lines changed

dpctl/tests/test_usm_ndarray_top_k.py

Lines changed: 83 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,38 @@
2020
from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported
2121

2222

23+
def _expected_largest_inds(inp, n, shift, k):
24+
"Computed expected top_k indices for mode='largest'"
25+
assert k < n
26+
ones_start_id = shift % (2 * n)
27+
28+
alloc_dev = inp.device
29+
30+
if ones_start_id < n:
31+
expected_inds = dpt.arange(
32+
ones_start_id, ones_start_id + k, dtype="i8", device=alloc_dev
33+
)
34+
else:
35+
# wrap-around
36+
ones_end_id = (ones_start_id + n) % (2 * n)
37+
if ones_end_id >= k:
38+
expected_inds = dpt.arange(k, dtype="i8", device=alloc_dev)
39+
else:
40+
expected_inds = dpt.concat(
41+
(
42+
dpt.arange(ones_end_id, dtype="i8", device=alloc_dev),
43+
dpt.arange(
44+
ones_start_id,
45+
ones_start_id + k - ones_end_id,
46+
dtype="i8",
47+
device=alloc_dev,
48+
),
49+
)
50+
)
51+
52+
return expected_inds
53+
54+
2355
@pytest.mark.parametrize(
2456
"dtype",
2557
[
@@ -38,23 +70,57 @@
3870
"c16",
3971
],
4072
)
41-
@pytest.mark.parametrize("n", [33, 255, 511, 1021, 8193])
42-
def test_topk_1d_largest(dtype, n):
73+
@pytest.mark.parametrize("n", [33, 43, 255, 511, 1021, 8193])
74+
def test_top_k_1d_largest(dtype, n):
4375
q = get_queue_or_skip()
4476
skip_if_dtype_not_supported(dtype, q)
4577

78+
shift, k = 734, 5
4679
o = dpt.ones(n, dtype=dtype)
4780
z = dpt.zeros(n, dtype=dtype)
48-
zo = dpt.concat((o, z))
49-
inp = dpt.roll(zo, 734)
50-
k = 5
81+
oz = dpt.concat((o, z))
82+
inp = dpt.roll(oz, shift)
83+
84+
expected_inds = _expected_largest_inds(oz, n, shift, k)
5185

5286
s = dpt.top_k(inp, k, mode="largest")
5387
assert s.values.shape == (k,)
5488
assert s.values.dtype == inp.dtype
5589
assert s.indices.shape == (k,)
56-
assert dpt.all(s.values == dpt.ones(k, dtype=dtype))
57-
assert dpt.all(s.values == inp[s.indices])
90+
assert dpt.all(s.indices == expected_inds)
91+
assert dpt.all(s.values == dpt.ones(k, dtype=dtype)), s.values
92+
assert dpt.all(s.values == inp[s.indices]), s.indices
93+
94+
95+
def _expected_smallest_inds(inp, n, shift, k):
96+
"Computed expected top_k indices for mode='smallest'"
97+
assert k < n
98+
zeros_start_id = (n + shift) % (2 * n)
99+
zeros_end_id = (shift) % (2 * n)
100+
101+
alloc_dev = inp.device
102+
103+
if zeros_start_id < zeros_end_id:
104+
expected_inds = dpt.arange(
105+
zeros_start_id, zeros_start_id + k, dtype="i8", device=alloc_dev
106+
)
107+
else:
108+
if zeros_end_id >= k:
109+
expected_inds = dpt.arange(k, dtype="i8", device=alloc_dev)
110+
else:
111+
expected_inds = dpt.concat(
112+
(
113+
dpt.arange(zeros_end_id, dtype="i8", device=alloc_dev),
114+
dpt.arange(
115+
zeros_start_id,
116+
zeros_start_id + k - zeros_end_id,
117+
dtype="i8",
118+
device=alloc_dev,
119+
),
120+
)
121+
)
122+
123+
return expected_inds
58124

59125

60126
@pytest.mark.parametrize(
@@ -75,20 +141,23 @@ def test_topk_1d_largest(dtype, n):
75141
"c16",
76142
],
77143
)
78-
@pytest.mark.parametrize("n", [33, 255, 257, 513, 1021, 8193])
79-
def test_topk_1d_smallest(dtype, n):
144+
@pytest.mark.parametrize("n", [37, 39, 61, 255, 257, 513, 1021, 8193])
145+
def test_top_k_1d_smallest(dtype, n):
80146
q = get_queue_or_skip()
81147
skip_if_dtype_not_supported(dtype, q)
82148

149+
shift, k = 734, 5
83150
o = dpt.ones(n, dtype=dtype)
84151
z = dpt.zeros(n, dtype=dtype)
85-
zo = dpt.concat((o, z))
86-
inp = dpt.roll(zo, 734)
87-
k = 5
152+
oz = dpt.concat((o, z))
153+
inp = dpt.roll(oz, shift)
154+
155+
expected_inds = _expected_smallest_inds(oz, n, shift, k)
88156

89157
s = dpt.top_k(inp, k, mode="smallest")
90158
assert s.values.shape == (k,)
91159
assert s.values.dtype == inp.dtype
92160
assert s.indices.shape == (k,)
93-
assert dpt.all(s.values == dpt.zeros(k, dtype=dtype))
94-
assert dpt.all(s.values == inp[s.indices])
161+
assert dpt.all(s.indices == expected_inds)
162+
assert dpt.all(s.values == dpt.zeros(k, dtype=dtype)), s.values
163+
assert dpt.all(s.values == inp[s.indices]), s.indices

0 commit comments

Comments
 (0)