@@ -55,7 +55,7 @@ def _expected_largest_inds(inp, n, shift, k):
55
55
@pytest .mark .parametrize (
56
56
"dtype" ,
57
57
[
58
- pytest . param ( "i1" , marks = pytest . mark . skip ( reason = "CPU bug" )) ,
58
+ "i1" ,
59
59
"u1" ,
60
60
"i2" ,
61
61
"u2" ,
@@ -74,8 +74,6 @@ def _expected_largest_inds(inp, n, shift, k):
74
74
def test_top_k_1d_largest (dtype , n ):
75
75
q = get_queue_or_skip ()
76
76
skip_if_dtype_not_supported (dtype , q )
77
- if dtype == "i1" :
78
- pytest .skip ()
79
77
80
78
shift , k = 734 , 5
81
79
o = dpt .ones (n , dtype = dtype )
@@ -89,9 +87,9 @@ def test_top_k_1d_largest(dtype, n):
89
87
assert s .values .shape == (k ,)
90
88
assert s .values .dtype == inp .dtype
91
89
assert s .indices .shape == (k ,)
92
- assert dpt .all (s .indices == expected_inds )
93
90
assert dpt .all (s .values == dpt .ones (k , dtype = dtype )), s .values
94
91
assert dpt .all (s .values == inp [s .indices ]), s .indices
92
+ assert dpt .all (s .indices == expected_inds ), (s .indices , expected_inds )
95
93
96
94
97
95
def _expected_smallest_inds (inp , n , shift , k ):
@@ -128,7 +126,7 @@ def _expected_smallest_inds(inp, n, shift, k):
128
126
@pytest .mark .parametrize (
129
127
"dtype" ,
130
128
[
131
- pytest . param ( "i1" , marks = pytest . mark . skip ( reason = "CPU bug" )) ,
129
+ "i1" ,
132
130
"u1" ,
133
131
"i2" ,
134
132
"u2" ,
@@ -160,6 +158,158 @@ def test_top_k_1d_smallest(dtype, n):
160
158
assert s .values .shape == (k ,)
161
159
assert s .values .dtype == inp .dtype
162
160
assert s .indices .shape == (k ,)
163
- assert dpt .all (s .indices == expected_inds )
164
161
assert dpt .all (s .values == dpt .zeros (k , dtype = dtype )), s .values
165
162
assert dpt .all (s .values == inp [s .indices ]), s .indices
163
+ assert dpt .all (s .indices == expected_inds ), (s .indices , expected_inds )
164
+
165
+
166
+ @pytest .mark .parametrize (
167
+ "dtype" ,
168
+ [
169
+ # skip short types to ensure that m*n can be represented
170
+ # in the type
171
+ "i4" ,
172
+ "u4" ,
173
+ "i8" ,
174
+ "u8" ,
175
+ "f2" ,
176
+ "f4" ,
177
+ "f8" ,
178
+ "c8" ,
179
+ "c16" ,
180
+ ],
181
+ )
182
+ @pytest .mark .parametrize ("n" , [37 , 39 , 61 , 255 , 257 , 513 , 1021 , 8193 ])
183
+ def test_top_k_2d_largest (dtype , n ):
184
+ q = get_queue_or_skip ()
185
+ skip_if_dtype_not_supported (dtype , q )
186
+
187
+ m , k = 8 , 3
188
+ if dtype == "f2" and m * n > 2000 :
189
+ pytest .skip (
190
+ "f2 can not distinguish between large integers used in this test"
191
+ )
192
+
193
+ x = dpt .reshape (dpt .arange (m * n , dtype = dtype ), (m , n ))
194
+
195
+ r = dpt .top_k (x , k , axis = 1 )
196
+
197
+ assert r .values .shape == (m , k )
198
+ assert r .indices .shape == (m , k )
199
+ expected_inds = dpt .reshape (dpt .arange (n , dtype = r .indices .dtype ), (1 , n ))[
200
+ :, - k :
201
+ ]
202
+ assert expected_inds .shape == (1 , k )
203
+ assert dpt .all (
204
+ dpt .sort (r .indices , axis = 1 ) == dpt .sort (expected_inds , axis = 1 )
205
+ ), (r .indices , expected_inds )
206
+ expected_vals = x [:, - k :]
207
+ assert dpt .all (
208
+ dpt .sort (r .values , axis = 1 ) == dpt .sort (expected_vals , axis = 1 )
209
+ )
210
+
211
+
212
+ @pytest .mark .parametrize (
213
+ "dtype" ,
214
+ [
215
+ # skip short types to ensure that m*n can be represented
216
+ # in the type
217
+ "i4" ,
218
+ "u4" ,
219
+ "i8" ,
220
+ "u8" ,
221
+ "f2" ,
222
+ "f4" ,
223
+ "f8" ,
224
+ "c8" ,
225
+ "c16" ,
226
+ ],
227
+ )
228
+ @pytest .mark .parametrize ("n" , [37 , 39 , 61 , 255 , 257 , 513 , 1021 , 8193 ])
229
+ def test_top_k_2d_smallest (dtype , n ):
230
+ q = get_queue_or_skip ()
231
+ skip_if_dtype_not_supported (dtype , q )
232
+
233
+ m , k = 8 , 3
234
+ if dtype == "f2" and m * n > 2000 :
235
+ pytest .skip (
236
+ "f2 can not distinguish between large integers used in this test"
237
+ )
238
+
239
+ x = dpt .reshape (dpt .arange (m * n , dtype = dtype ), (m , n ))
240
+
241
+ r = dpt .top_k (x , k , axis = 1 , mode = "smallest" )
242
+
243
+ assert r .values .shape == (m , k )
244
+ assert r .indices .shape == (m , k )
245
+ expected_inds = dpt .reshape (dpt .arange (n , dtype = r .indices .dtype ), (1 , n ))[
246
+ :, :k
247
+ ]
248
+ assert dpt .all (
249
+ dpt .sort (r .indices , axis = 1 ) == dpt .sort (expected_inds , axis = 1 )
250
+ )
251
+ assert dpt .all (dpt .sort (r .values , axis = 1 ) == dpt .sort (x [:, :k ], axis = 1 ))
252
+
253
+
254
+ def test_top_k_0d ():
255
+ get_queue_or_skip ()
256
+
257
+ a = dpt .ones (tuple (), dtype = "i4" )
258
+ assert a .ndim == 0
259
+ assert a .size == 1
260
+
261
+ r = dpt .top_k (a , 1 )
262
+ assert r .values == a
263
+ assert r .indices == dpt .zeros_like (a , dtype = r .indices .dtype )
264
+
265
+
266
+ def test_top_k_noncontig ():
267
+ get_queue_or_skip ()
268
+
269
+ a = dpt .arange (256 , dtype = dpt .int32 )[::2 ]
270
+ r = dpt .top_k (a , 3 )
271
+
272
+ assert dpt .all (dpt .sort (r .values ) == dpt .asarray ([250 , 252 , 254 ])), r .values
273
+ assert dpt .all (
274
+ dpt .sort (r .indices ) == dpt .asarray ([125 , 126 , 127 ])
275
+ ), r .indices
276
+
277
+
278
+ def test_top_k_axis0 ():
279
+ get_queue_or_skip ()
280
+
281
+ m , n , k = 128 , 8 , 3
282
+ x = dpt .reshape (dpt .arange (m * n , dtype = dpt .int32 ), (m , n ))
283
+
284
+ r = dpt .top_k (x , k , axis = 0 , mode = "smallest" )
285
+ assert r .values .shape == (k , n )
286
+ assert r .indices .shape == (k , n )
287
+ expected_inds = dpt .reshape (dpt .arange (m , dtype = r .indices .dtype ), (m , 1 ))[
288
+ :k , :
289
+ ]
290
+ assert dpt .all (
291
+ dpt .sort (r .indices , axis = 0 ) == dpt .sort (expected_inds , axis = 0 )
292
+ )
293
+ assert dpt .all (dpt .sort (r .values , axis = 0 ) == dpt .sort (x [:k , :], axis = 0 ))
294
+
295
+
296
+ def test_top_k_validation ():
297
+ get_queue_or_skip ()
298
+ x = dpt .ones (10 , dtype = dpt .int64 )
299
+ with pytest .raises (ValueError ):
300
+ # k must be positive
301
+ dpt .top_k (x , - 1 )
302
+ with pytest .raises (TypeError ):
303
+ # argument should be usm_ndarray
304
+ dpt .top_k (list (), 2 )
305
+ x2 = dpt .reshape (x , (2 , 5 ))
306
+ with pytest .raises (ValueError ):
307
+ # k must not exceed array dimension
308
+ # along specified axis
309
+ dpt .top_k (x2 , 100 , axis = 1 )
310
+ with pytest .raises (ValueError ):
311
+ # for 0d arrays, k must be 1
312
+ dpt .top_k (x [0 ], 2 )
313
+ with pytest .raises (ValueError ):
314
+ # mode must be "largest", or "smallest"
315
+ dpt .top_k (x , 2 , mode = "invalid" )
0 commit comments