22
22
from ._tensor_sorting_impl import (
23
23
_argsort_ascending ,
24
24
_argsort_descending ,
25
+ _radix_argsort_ascending ,
26
+ _radix_argsort_descending ,
27
+ _radix_sort_ascending ,
28
+ _radix_sort_descending ,
29
+ _radix_sort_dtype_supported ,
25
30
_sort_ascending ,
26
31
_sort_descending ,
27
32
)
28
33
29
34
__all__ = ["sort" , "argsort" ]
30
35
31
36
32
- def sort (x , / , * , axis = - 1 , descending = False , stable = True ):
37
+ def _get_mergesort_impl_fn (descending ):
38
+ return _sort_descending if descending else _sort_ascending
39
+
40
+
41
+ def _get_radixsort_impl_fn (descending ):
42
+ return _radix_sort_descending if descending else _radix_sort_ascending
43
+
44
+
45
+ def sort (x , / , * , axis = - 1 , descending = False , stable = True , kind = None ):
33
46
"""sort(x, axis=-1, descending=False, stable=True)
34
47
35
48
Returns a sorted copy of an input array `x`.
@@ -49,7 +62,10 @@ def sort(x, /, *, axis=-1, descending=False, stable=True):
49
62
relative order of `x` values which compare as equal. If `False`,
50
63
the returned array may or may not maintain the relative order of
51
64
`x` values which compare as equal. Default: `True`.
52
-
65
+ kind (Optional[Literal["stable", "mergesort", "radixsort"]]):
66
+ Sorting algorithm. The default is `"stable"`, which uses parallel
67
+ merge-sort or parallel radix-sort algorithms depending on the
68
+ array data type.
53
69
Returns:
54
70
usm_ndarray:
55
71
a sorted array. The returned array has the same data type and
@@ -74,10 +90,33 @@ def sort(x, /, *, axis=-1, descending=False, stable=True):
74
90
axis ,
75
91
]
76
92
arr = dpt .permute_dims (x , perm )
93
+ if kind is None :
94
+ kind = "stable"
95
+ if not isinstance (kind , str ) or kind not in [
96
+ "stable" ,
97
+ "radixsort" ,
98
+ "mergesort" ,
99
+ ]:
100
+ raise ValueError (
101
+ "Unsupported kind value. Expected 'stable', 'mergesort', "
102
+ f"or 'radixsort', but got '{ kind } '"
103
+ )
104
+ if kind == "mergesort" :
105
+ impl_fn = _get_mergesort_impl_fn (descending )
106
+ elif kind == "radixsort" :
107
+ if _radix_sort_dtype_supported (x .dtype .num ):
108
+ impl_fn = _get_radixsort_impl_fn (descending )
109
+ else :
110
+ raise ValueError (f"Radix sort is not supported for { x .dtype } " )
111
+ else :
112
+ dt = x .dtype
113
+ if dt in [dpt .bool , dpt .uint8 , dpt .int8 , dpt .int16 , dpt .uint16 ]:
114
+ impl_fn = _get_radixsort_impl_fn (descending )
115
+ else :
116
+ impl_fn = _get_mergesort_impl_fn (descending )
77
117
exec_q = x .sycl_queue
78
118
_manager = du .SequentialOrderManager [exec_q ]
79
119
dep_evs = _manager .submitted_events
80
- impl_fn = _sort_descending if descending else _sort_ascending
81
120
if arr .flags .c_contiguous :
82
121
res = dpt .empty_like (arr , order = "C" )
83
122
ht_ev , impl_ev = impl_fn (
@@ -109,7 +148,15 @@ def sort(x, /, *, axis=-1, descending=False, stable=True):
109
148
return res
110
149
111
150
112
- def argsort (x , axis = - 1 , descending = False , stable = True ):
151
+ def _get_mergeargsort_impl_fn (descending ):
152
+ return _argsort_descending if descending else _argsort_ascending
153
+
154
+
155
+ def _get_radixargsort_impl_fn (descending ):
156
+ return _radix_argsort_descending if descending else _radix_argsort_ascending
157
+
158
+
159
+ def argsort (x , axis = - 1 , descending = False , stable = True , kind = None ):
113
160
"""argsort(x, axis=-1, descending=False, stable=True)
114
161
115
162
Returns the indices that sort an array `x` along a specified axis.
@@ -129,6 +176,10 @@ def argsort(x, axis=-1, descending=False, stable=True):
129
176
relative order of `x` values which compare as equal. If `False`,
130
177
the returned array may or may not maintain the relative order of
131
178
`x` values which compare as equal. Default: `True`.
179
+ kind (Optional[Literal["stable", "mergesort", "radixsort"]]):
180
+ Sorting algorithm. The default is `"stable"`, which uses parallel
181
+ merge-sort or parallel radix-sort algorithms depending on the
182
+ array data type.
132
183
133
184
Returns:
134
185
usm_ndarray:
@@ -157,10 +208,33 @@ def argsort(x, axis=-1, descending=False, stable=True):
157
208
axis ,
158
209
]
159
210
arr = dpt .permute_dims (x , perm )
211
+ if kind is None :
212
+ kind = "stable"
213
+ if not isinstance (kind , str ) or kind not in [
214
+ "stable" ,
215
+ "radixsort" ,
216
+ "mergesort" ,
217
+ ]:
218
+ raise ValueError (
219
+ "Unsupported kind value. Expected 'stable', 'mergesort', "
220
+ f"or 'radixsort', but got '{ kind } '"
221
+ )
222
+ if kind == "mergesort" :
223
+ impl_fn = _get_mergeargsort_impl_fn (descending )
224
+ elif kind == "radixsort" :
225
+ if _radix_sort_dtype_supported (x .dtype .num ):
226
+ impl_fn = _get_radixargsort_impl_fn (descending )
227
+ else :
228
+ raise ValueError (f"Radix sort is not supported for { x .dtype } " )
229
+ else :
230
+ dt = x .dtype
231
+ if dt in [dpt .bool , dpt .uint8 , dpt .int8 , dpt .int16 , dpt .uint16 ]:
232
+ impl_fn = _get_radixargsort_impl_fn (descending )
233
+ else :
234
+ impl_fn = _get_mergeargsort_impl_fn (descending )
160
235
exec_q = x .sycl_queue
161
236
_manager = du .SequentialOrderManager [exec_q ]
162
237
dep_evs = _manager .submitted_events
163
- impl_fn = _argsort_descending if descending else _argsort_ascending
164
238
index_dt = ti .default_device_index_type (exec_q )
165
239
if arr .flags .c_contiguous :
166
240
res = dpt .empty_like (arr , dtype = index_dt , order = "C" )
0 commit comments