Skip to content

Commit 505b64c

Browse files
oleksandr-pavlykndgrigorian
authored andcommitted
Remove skipping of tests for i1/i2 dtypes since work-around
was applied in C++. Add tests for 2d input arrays, for axis=0 and axis=1 Add a test for non-contiguous input, 0d input, validation 100% coverage of top_k function implementation achieved
1 parent 5125e11 commit 505b64c

File tree

1 file changed

+156
-6
lines changed

1 file changed

+156
-6
lines changed

dpctl/tests/test_usm_ndarray_top_k.py

Lines changed: 156 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def _expected_largest_inds(inp, n, shift, k):
5555
@pytest.mark.parametrize(
5656
"dtype",
5757
[
58-
pytest.param("i1", marks=pytest.mark.skip(reason="CPU bug")),
58+
"i1",
5959
"u1",
6060
"i2",
6161
"u2",
@@ -74,8 +74,6 @@ def _expected_largest_inds(inp, n, shift, k):
7474
def test_top_k_1d_largest(dtype, n):
7575
q = get_queue_or_skip()
7676
skip_if_dtype_not_supported(dtype, q)
77-
if dtype == "i1":
78-
pytest.skip()
7977

8078
shift, k = 734, 5
8179
o = dpt.ones(n, dtype=dtype)
@@ -89,9 +87,9 @@ def test_top_k_1d_largest(dtype, n):
8987
assert s.values.shape == (k,)
9088
assert s.values.dtype == inp.dtype
9189
assert s.indices.shape == (k,)
92-
assert dpt.all(s.indices == expected_inds)
9390
assert dpt.all(s.values == dpt.ones(k, dtype=dtype)), s.values
9491
assert dpt.all(s.values == inp[s.indices]), s.indices
92+
assert dpt.all(s.indices == expected_inds), (s.indices, expected_inds)
9593

9694

9795
def _expected_smallest_inds(inp, n, shift, k):
@@ -128,7 +126,7 @@ def _expected_smallest_inds(inp, n, shift, k):
128126
@pytest.mark.parametrize(
129127
"dtype",
130128
[
131-
pytest.param("i1", marks=pytest.mark.skip(reason="CPU bug")),
129+
"i1",
132130
"u1",
133131
"i2",
134132
"u2",
@@ -160,6 +158,158 @@ def test_top_k_1d_smallest(dtype, n):
160158
assert s.values.shape == (k,)
161159
assert s.values.dtype == inp.dtype
162160
assert s.indices.shape == (k,)
163-
assert dpt.all(s.indices == expected_inds)
164161
assert dpt.all(s.values == dpt.zeros(k, dtype=dtype)), s.values
165162
assert dpt.all(s.values == inp[s.indices]), s.indices
163+
assert dpt.all(s.indices == expected_inds), (s.indices, expected_inds)
164+
165+
166+
@pytest.mark.parametrize(
167+
"dtype",
168+
[
169+
# skip short types to ensure that m*n can be represented
170+
# in the type
171+
"i4",
172+
"u4",
173+
"i8",
174+
"u8",
175+
"f2",
176+
"f4",
177+
"f8",
178+
"c8",
179+
"c16",
180+
],
181+
)
182+
@pytest.mark.parametrize("n", [37, 39, 61, 255, 257, 513, 1021, 8193])
183+
def test_top_k_2d_largest(dtype, n):
184+
q = get_queue_or_skip()
185+
skip_if_dtype_not_supported(dtype, q)
186+
187+
m, k = 8, 3
188+
if dtype == "f2" and m * n > 2000:
189+
pytest.skip(
190+
"f2 can not distinguish between large integers used in this test"
191+
)
192+
193+
x = dpt.reshape(dpt.arange(m * n, dtype=dtype), (m, n))
194+
195+
r = dpt.top_k(x, k, axis=1)
196+
197+
assert r.values.shape == (m, k)
198+
assert r.indices.shape == (m, k)
199+
expected_inds = dpt.reshape(dpt.arange(n, dtype=r.indices.dtype), (1, n))[
200+
:, -k:
201+
]
202+
assert expected_inds.shape == (1, k)
203+
assert dpt.all(
204+
dpt.sort(r.indices, axis=1) == dpt.sort(expected_inds, axis=1)
205+
), (r.indices, expected_inds)
206+
expected_vals = x[:, -k:]
207+
assert dpt.all(
208+
dpt.sort(r.values, axis=1) == dpt.sort(expected_vals, axis=1)
209+
)
210+
211+
212+
@pytest.mark.parametrize(
213+
"dtype",
214+
[
215+
# skip short types to ensure that m*n can be represented
216+
# in the type
217+
"i4",
218+
"u4",
219+
"i8",
220+
"u8",
221+
"f2",
222+
"f4",
223+
"f8",
224+
"c8",
225+
"c16",
226+
],
227+
)
228+
@pytest.mark.parametrize("n", [37, 39, 61, 255, 257, 513, 1021, 8193])
229+
def test_top_k_2d_smallest(dtype, n):
230+
q = get_queue_or_skip()
231+
skip_if_dtype_not_supported(dtype, q)
232+
233+
m, k = 8, 3
234+
if dtype == "f2" and m * n > 2000:
235+
pytest.skip(
236+
"f2 can not distinguish between large integers used in this test"
237+
)
238+
239+
x = dpt.reshape(dpt.arange(m * n, dtype=dtype), (m, n))
240+
241+
r = dpt.top_k(x, k, axis=1, mode="smallest")
242+
243+
assert r.values.shape == (m, k)
244+
assert r.indices.shape == (m, k)
245+
expected_inds = dpt.reshape(dpt.arange(n, dtype=r.indices.dtype), (1, n))[
246+
:, :k
247+
]
248+
assert dpt.all(
249+
dpt.sort(r.indices, axis=1) == dpt.sort(expected_inds, axis=1)
250+
)
251+
assert dpt.all(dpt.sort(r.values, axis=1) == dpt.sort(x[:, :k], axis=1))
252+
253+
254+
def test_top_k_0d():
255+
get_queue_or_skip()
256+
257+
a = dpt.ones(tuple(), dtype="i4")
258+
assert a.ndim == 0
259+
assert a.size == 1
260+
261+
r = dpt.top_k(a, 1)
262+
assert r.values == a
263+
assert r.indices == dpt.zeros_like(a, dtype=r.indices.dtype)
264+
265+
266+
def test_top_k_noncontig():
267+
get_queue_or_skip()
268+
269+
a = dpt.arange(256, dtype=dpt.int32)[::2]
270+
r = dpt.top_k(a, 3)
271+
272+
assert dpt.all(dpt.sort(r.values) == dpt.asarray([250, 252, 254])), r.values
273+
assert dpt.all(
274+
dpt.sort(r.indices) == dpt.asarray([125, 126, 127])
275+
), r.indices
276+
277+
278+
def test_top_k_axis0():
279+
get_queue_or_skip()
280+
281+
m, n, k = 128, 8, 3
282+
x = dpt.reshape(dpt.arange(m * n, dtype=dpt.int32), (m, n))
283+
284+
r = dpt.top_k(x, k, axis=0, mode="smallest")
285+
assert r.values.shape == (k, n)
286+
assert r.indices.shape == (k, n)
287+
expected_inds = dpt.reshape(dpt.arange(m, dtype=r.indices.dtype), (m, 1))[
288+
:k, :
289+
]
290+
assert dpt.all(
291+
dpt.sort(r.indices, axis=0) == dpt.sort(expected_inds, axis=0)
292+
)
293+
assert dpt.all(dpt.sort(r.values, axis=0) == dpt.sort(x[:k, :], axis=0))
294+
295+
296+
def test_top_k_validation():
297+
get_queue_or_skip()
298+
x = dpt.ones(10, dtype=dpt.int64)
299+
with pytest.raises(ValueError):
300+
# k must be positive
301+
dpt.top_k(x, -1)
302+
with pytest.raises(TypeError):
303+
# argument should be usm_ndarray
304+
dpt.top_k(list(), 2)
305+
x2 = dpt.reshape(x, (2, 5))
306+
with pytest.raises(ValueError):
307+
# k must not exceed array dimension
308+
# along specified axis
309+
dpt.top_k(x2, 100, axis=1)
310+
with pytest.raises(ValueError):
311+
# for 0d arrays, k must be 1
312+
dpt.top_k(x[0], 2)
313+
with pytest.raises(ValueError):
314+
# mode must be "largest", or "smallest"
315+
dpt.top_k(x, 2, mode="invalid")

0 commit comments

Comments
 (0)