@@ -153,81 +153,117 @@ def test_sort_validation():
153
153
dpt .sort (dict ())
154
154
155
155
156
+ def test_sort_validation_kind ():
157
+ get_queue_or_skip ()
158
+
159
+ x = dpt .ones (128 , dtype = "u1" )
160
+
161
+ with pytest .raises (ValueError ):
162
+ dpt .sort (x , kind = Ellipsis )
163
+
164
+ with pytest .raises (ValueError ):
165
+ dpt .sort (x , kind = "invalid" )
166
+
167
+
156
168
def test_argsort_validation ():
157
169
with pytest .raises (TypeError ):
158
170
dpt .argsort (dict ())
159
171
160
172
161
- def test_sort_axis0 ():
173
+ def test_argsort_validation_kind ():
174
+ get_queue_or_skip ()
175
+
176
+ x = dpt .arange (127 , stop = 0 , step = - 1 , dtype = "i1" )
177
+
178
+ with pytest .raises (ValueError ):
179
+ dpt .argsort (x , kind = Ellipsis )
180
+
181
+ with pytest .raises (ValueError ):
182
+ dpt .argsort (x , kind = "invalid" )
183
+
184
+
185
+ _all_kinds = ["stable" , "mergesort" , "radixsort" ]
186
+
187
+
188
+ @pytest .mark .parametrize ("kind" , _all_kinds )
189
+ def test_sort_axis0 (kind ):
162
190
get_queue_or_skip ()
163
191
164
192
n , m = 200 , 30
165
193
xf = dpt .arange (n * m , 0 , step = - 1 , dtype = "i4" )
166
194
x = dpt .reshape (xf , (n , m ))
167
- s = dpt .sort (x , axis = 0 )
195
+ s = dpt .sort (x , axis = 0 , kind = kind )
168
196
169
197
assert dpt .all (s [:- 1 , :] <= s [1 :, :])
170
198
171
199
172
- def test_argsort_axis0 ():
200
+ @pytest .mark .parametrize ("kind" , _all_kinds )
201
+ def test_argsort_axis0 (kind ):
173
202
get_queue_or_skip ()
174
203
175
204
n , m = 200 , 30
176
205
xf = dpt .arange (n * m , 0 , step = - 1 , dtype = "i4" )
177
206
x = dpt .reshape (xf , (n , m ))
178
- idx = dpt .argsort (x , axis = 0 )
207
+ idx = dpt .argsort (x , axis = 0 , kind = kind )
179
208
180
209
s = dpt .take_along_axis (x , idx , axis = 0 )
181
210
182
211
assert dpt .all (s [:- 1 , :] <= s [1 :, :])
183
212
184
213
185
- def test_argsort_axis1 ():
214
+ @pytest .mark .parametrize ("kind" , _all_kinds )
215
+ def test_argsort_axis1 (kind ):
186
216
get_queue_or_skip ()
187
217
188
218
n , m = 200 , 30
189
219
xf = dpt .arange (n * m , 0 , step = - 1 , dtype = "i4" )
190
220
x = dpt .reshape (xf , (n , m ))
191
- idx = dpt .argsort (x , axis = 1 )
221
+ idx = dpt .argsort (x , axis = 1 , kind = kind )
192
222
193
223
s = dpt .take_along_axis (x , idx , axis = 1 )
194
224
195
225
assert dpt .all (s [:, :- 1 ] <= s [:, 1 :])
196
226
197
227
198
- def test_sort_strided ():
228
+ @pytest .mark .parametrize ("kind" , _all_kinds )
229
+ def test_sort_strided (kind ):
199
230
get_queue_or_skip ()
200
231
201
232
x_orig = dpt .arange (100 , dtype = "i4" )
202
233
x_flipped = dpt .flip (x_orig , axis = 0 )
203
- s = dpt .sort (x_flipped )
234
+ s = dpt .sort (x_flipped , kind = kind )
204
235
205
236
assert dpt .all (s == x_orig )
206
237
207
238
208
- def test_argsort_strided ():
239
+ @pytest .mark .parametrize ("kind" , _all_kinds )
240
+ def test_argsort_strided (kind ):
209
241
get_queue_or_skip ()
210
242
211
243
x_orig = dpt .arange (100 , dtype = "i4" )
212
244
x_flipped = dpt .flip (x_orig , axis = 0 )
213
- idx = dpt .argsort (x_flipped )
245
+ idx = dpt .argsort (x_flipped , kind = kind )
214
246
s = dpt .take_along_axis (x_flipped , idx , axis = 0 )
215
247
216
248
assert dpt .all (s == x_orig )
217
249
218
250
219
- def test_sort_0d_array ():
251
+ @pytest .mark .parametrize ("kind" , _all_kinds )
252
+ def test_sort_0d_array (kind ):
220
253
get_queue_or_skip ()
221
254
222
255
x = dpt .asarray (1 , dtype = "i4" )
223
- assert dpt .sort (x ) == 1
256
+ expected = dpt .asarray (1 , dtype = "i4" )
257
+ assert dpt .sort (x , kind = kind ) == expected
224
258
225
259
226
- def test_argsort_0d_array ():
260
+ @pytest .mark .parametrize ("kind" , _all_kinds )
261
+ def test_argsort_0d_array (kind ):
227
262
get_queue_or_skip ()
228
263
229
264
x = dpt .asarray (1 , dtype = "i4" )
230
- assert dpt .argsort (x ) == 0
265
+ expected = dpt .asarray (0 , dtype = "i4" )
266
+ assert dpt .argsort (x , kind = kind ) == expected
231
267
232
268
233
269
@pytest .mark .parametrize (
@@ -238,22 +274,23 @@ def test_argsort_0d_array():
238
274
"f8" ,
239
275
],
240
276
)
241
- def test_sort_real_fp_nan (dtype ):
277
+ @pytest .mark .parametrize ("kind" , _all_kinds )
278
+ def test_sort_real_fp_nan (dtype , kind ):
242
279
q = get_queue_or_skip ()
243
280
skip_if_dtype_not_supported (dtype , q )
244
281
245
282
x = dpt .asarray (
246
283
[- 0.0 , 0.1 , dpt .nan , 0.0 , - 0.1 , dpt .nan , 0.2 , - 0.3 ], dtype = dtype
247
284
)
248
- s = dpt .sort (x )
285
+ s = dpt .sort (x , kind = kind )
249
286
250
287
expected = dpt .asarray (
251
288
[- 0.3 , - 0.1 , - 0.0 , 0.0 , 0.1 , 0.2 , dpt .nan , dpt .nan ], dtype = dtype
252
289
)
253
290
254
291
assert dpt .allclose (s , expected , equal_nan = True )
255
292
256
- s = dpt .sort (x , descending = True )
293
+ s = dpt .sort (x , descending = True , kind = kind )
257
294
258
295
expected = dpt .asarray (
259
296
[dpt .nan , dpt .nan , 0.2 , 0.1 , - 0.0 , 0.0 , - 0.1 , - 0.3 ], dtype = dtype
0 commit comments