20
20
from dpctl .tests .helper import get_queue_or_skip , skip_if_dtype_not_supported
21
21
22
22
23
+ def _expected_largest_inds (inp , n , shift , k ):
24
+ "Computed expected top_k indices for mode='largest'"
25
+ assert k < n
26
+ ones_start_id = shift % (2 * n )
27
+
28
+ alloc_dev = inp .device
29
+
30
+ if ones_start_id < n :
31
+ expected_inds = dpt .arange (
32
+ ones_start_id , ones_start_id + k , dtype = "i8" , device = alloc_dev
33
+ )
34
+ else :
35
+ # wrap-around
36
+ ones_end_id = (ones_start_id + n ) % (2 * n )
37
+ if ones_end_id >= k :
38
+ expected_inds = dpt .arange (k , dtype = "i8" , device = alloc_dev )
39
+ else :
40
+ expected_inds = dpt .concat (
41
+ (
42
+ dpt .arange (ones_end_id , dtype = "i8" , device = alloc_dev ),
43
+ dpt .arange (
44
+ ones_start_id ,
45
+ ones_start_id + k - ones_end_id ,
46
+ dtype = "i8" ,
47
+ device = alloc_dev ,
48
+ ),
49
+ )
50
+ )
51
+
52
+ return expected_inds
53
+
54
+
23
55
@pytest .mark .parametrize (
24
56
"dtype" ,
25
57
[
38
70
"c16" ,
39
71
],
40
72
)
41
- @pytest .mark .parametrize ("n" , [33 , 255 , 511 , 1021 , 8193 ])
42
- def test_topk_1d_largest (dtype , n ):
73
+ @pytest .mark .parametrize ("n" , [33 , 43 , 255 , 511 , 1021 , 8193 ])
74
+ def test_top_k_1d_largest (dtype , n ):
43
75
q = get_queue_or_skip ()
44
76
skip_if_dtype_not_supported (dtype , q )
45
77
78
+ shift , k = 734 , 5
46
79
o = dpt .ones (n , dtype = dtype )
47
80
z = dpt .zeros (n , dtype = dtype )
48
- zo = dpt .concat ((o , z ))
49
- inp = dpt .roll (zo , 734 )
50
- k = 5
81
+ oz = dpt .concat ((o , z ))
82
+ inp = dpt .roll (oz , shift )
83
+
84
+ expected_inds = _expected_largest_inds (oz , n , shift , k )
51
85
52
86
s = dpt .top_k (inp , k , mode = "largest" )
53
87
assert s .values .shape == (k ,)
54
88
assert s .values .dtype == inp .dtype
55
89
assert s .indices .shape == (k ,)
56
- assert dpt .all (s .values == dpt .ones (k , dtype = dtype ))
57
- assert dpt .all (s .values == inp [s .indices ])
90
+ assert dpt .all (s .indices == expected_inds )
91
+ assert dpt .all (s .values == dpt .ones (k , dtype = dtype )), s .values
92
+ assert dpt .all (s .values == inp [s .indices ]), s .indices
93
+
94
+
95
+ def _expected_smallest_inds (inp , n , shift , k ):
96
+ "Computed expected top_k indices for mode='smallest'"
97
+ assert k < n
98
+ zeros_start_id = (n + shift ) % (2 * n )
99
+ zeros_end_id = (shift ) % (2 * n )
100
+
101
+ alloc_dev = inp .device
102
+
103
+ if zeros_start_id < zeros_end_id :
104
+ expected_inds = dpt .arange (
105
+ zeros_start_id , zeros_start_id + k , dtype = "i8" , device = alloc_dev
106
+ )
107
+ else :
108
+ if zeros_end_id >= k :
109
+ expected_inds = dpt .arange (k , dtype = "i8" , device = alloc_dev )
110
+ else :
111
+ expected_inds = dpt .concat (
112
+ (
113
+ dpt .arange (zeros_end_id , dtype = "i8" , device = alloc_dev ),
114
+ dpt .arange (
115
+ zeros_start_id ,
116
+ zeros_start_id + k - zeros_end_id ,
117
+ dtype = "i8" ,
118
+ device = alloc_dev ,
119
+ ),
120
+ )
121
+ )
122
+
123
+ return expected_inds
58
124
59
125
60
126
@pytest .mark .parametrize (
@@ -75,20 +141,23 @@ def test_topk_1d_largest(dtype, n):
75
141
"c16" ,
76
142
],
77
143
)
78
- @pytest .mark .parametrize ("n" , [33 , 255 , 257 , 513 , 1021 , 8193 ])
79
- def test_topk_1d_smallest (dtype , n ):
144
+ @pytest .mark .parametrize ("n" , [37 , 39 , 61 , 255 , 257 , 513 , 1021 , 8193 ])
145
+ def test_top_k_1d_smallest (dtype , n ):
80
146
q = get_queue_or_skip ()
81
147
skip_if_dtype_not_supported (dtype , q )
82
148
149
+ shift , k = 734 , 5
83
150
o = dpt .ones (n , dtype = dtype )
84
151
z = dpt .zeros (n , dtype = dtype )
85
- zo = dpt .concat ((o , z ))
86
- inp = dpt .roll (zo , 734 )
87
- k = 5
152
+ oz = dpt .concat ((o , z ))
153
+ inp = dpt .roll (oz , shift )
154
+
155
+ expected_inds = _expected_smallest_inds (oz , n , shift , k )
88
156
89
157
s = dpt .top_k (inp , k , mode = "smallest" )
90
158
assert s .values .shape == (k ,)
91
159
assert s .values .dtype == inp .dtype
92
160
assert s .indices .shape == (k ,)
93
- assert dpt .all (s .values == dpt .zeros (k , dtype = dtype ))
94
- assert dpt .all (s .values == inp [s .indices ])
161
+ assert dpt .all (s .indices == expected_inds )
162
+ assert dpt .all (s .values == dpt .zeros (k , dtype = dtype )), s .values
163
+ assert dpt .all (s .values == inp [s .indices ]), s .indices
0 commit comments