Skip to content

Commit a15e4aa

Browse files
Add support for kind keyword in sort/argsort
Supported values for kind are "radixsort", "mergesort", "stable". The default is None (same as "stable"). For stable, radix sort is used for bool, (u)int8, (u)int16. Radix sort is supported for integral, boolean and real floating point types.
1 parent c2f8486 commit a15e4aa

File tree

1 file changed

+79
-5
lines changed

1 file changed

+79
-5
lines changed

dpctl/tensor/_sorting.py

Lines changed: 79 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,27 @@
2222
from ._tensor_sorting_impl import (
2323
_argsort_ascending,
2424
_argsort_descending,
25+
_radix_argsort_ascending,
26+
_radix_argsort_descending,
27+
_radix_sort_ascending,
28+
_radix_sort_descending,
29+
_radix_sort_dtype_supported,
2530
_sort_ascending,
2631
_sort_descending,
2732
)
2833

2934
__all__ = ["sort", "argsort"]
3035

3136

32-
def sort(x, /, *, axis=-1, descending=False, stable=True):
37+
def _get_mergesort_impl_fn(descending):
38+
return _sort_descending if descending else _sort_ascending
39+
40+
41+
def _get_radixsort_impl_fn(descending):
42+
return _radix_sort_descending if descending else _radix_sort_ascending
43+
44+
45+
def sort(x, /, *, axis=-1, descending=False, stable=True, kind=None):
3346
"""sort(x, axis=-1, descending=False, stable=True)
3447
3548
Returns a sorted copy of an input array `x`.
@@ -49,7 +62,10 @@ def sort(x, /, *, axis=-1, descending=False, stable=True):
4962
relative order of `x` values which compare as equal. If `False`,
5063
the returned array may or may not maintain the relative order of
5164
`x` values which compare as equal. Default: `True`.
52-
65+
kind (Optional[Literal["stable", "mergesort", "radixsort"]]):
66+
Sorting algorithm. The default is `"stable"`, which uses parallel
67+
merge-sort or parallel radix-sort algorithms depending on the
68+
array data type.
5369
Returns:
5470
usm_ndarray:
5571
a sorted array. The returned array has the same data type and
@@ -74,10 +90,33 @@ def sort(x, /, *, axis=-1, descending=False, stable=True):
7490
axis,
7591
]
7692
arr = dpt.permute_dims(x, perm)
93+
if kind is None:
94+
kind = "stable"
95+
if not isinstance(kind, str) or kind not in [
96+
"stable",
97+
"radixsort",
98+
"mergesort",
99+
]:
100+
raise ValueError(
101+
"Unsupported kind value. Expected 'stable', 'mergesort', "
102+
f"or 'radixsort', but got '{kind}'"
103+
)
104+
if kind == "mergesort":
105+
impl_fn = _get_mergesort_impl_fn(descending)
106+
elif kind == "radixsort":
107+
if _radix_sort_dtype_supported(x.dtype.num):
108+
impl_fn = _get_radixsort_impl_fn(descending)
109+
else:
110+
raise ValueError(f"Radix sort is not supported for {x.dtype}")
111+
else:
112+
dt = x.dtype
113+
if dt in [dpt.bool, dpt.uint8, dpt.int8, dpt.int16, dpt.uint16]:
114+
impl_fn = _get_radixsort_impl_fn(descending)
115+
else:
116+
impl_fn = _get_mergesort_impl_fn(descending)
77117
exec_q = x.sycl_queue
78118
_manager = du.SequentialOrderManager[exec_q]
79119
dep_evs = _manager.submitted_events
80-
impl_fn = _sort_descending if descending else _sort_ascending
81120
if arr.flags.c_contiguous:
82121
res = dpt.empty_like(arr, order="C")
83122
ht_ev, impl_ev = impl_fn(
@@ -109,7 +148,15 @@ def sort(x, /, *, axis=-1, descending=False, stable=True):
109148
return res
110149

111150

112-
def argsort(x, axis=-1, descending=False, stable=True):
151+
def _get_mergeargsort_impl_fn(descending):
152+
return _argsort_descending if descending else _argsort_ascending
153+
154+
155+
def _get_radixargsort_impl_fn(descending):
156+
return _radix_argsort_descending if descending else _radix_argsort_ascending
157+
158+
159+
def argsort(x, axis=-1, descending=False, stable=True, kind=None):
113160
"""argsort(x, axis=-1, descending=False, stable=True)
114161
115162
Returns the indices that sort an array `x` along a specified axis.
@@ -129,6 +176,10 @@ def argsort(x, axis=-1, descending=False, stable=True):
129176
relative order of `x` values which compare as equal. If `False`,
130177
the returned array may or may not maintain the relative order of
131178
`x` values which compare as equal. Default: `True`.
179+
kind (Optional[Literal["stable", "mergesort", "radixsort"]]):
180+
Sorting algorithm. The default is `"stable"`, which uses parallel
181+
merge-sort or parallel radix-sort algorithms depending on the
182+
array data type.
132183
133184
Returns:
134185
usm_ndarray:
@@ -157,10 +208,33 @@ def argsort(x, axis=-1, descending=False, stable=True):
157208
axis,
158209
]
159210
arr = dpt.permute_dims(x, perm)
211+
if kind is None:
212+
kind = "stable"
213+
if not isinstance(kind, str) or kind not in [
214+
"stable",
215+
"radixsort",
216+
"mergesort",
217+
]:
218+
raise ValueError(
219+
"Unsupported kind value. Expected 'stable', 'mergesort', "
220+
f"or 'radixsort', but got '{kind}'"
221+
)
222+
if kind == "mergesort":
223+
impl_fn = _get_mergeargsort_impl_fn(descending)
224+
elif kind == "radixsort":
225+
if _radix_sort_dtype_supported(x.dtype.num):
226+
impl_fn = _get_radixargsort_impl_fn(descending)
227+
else:
228+
raise ValueError(f"Radix sort is not supported for {x.dtype}")
229+
else:
230+
dt = x.dtype
231+
if dt in [dpt.bool, dpt.uint8, dpt.int8, dpt.int16, dpt.uint16]:
232+
impl_fn = _get_radixargsort_impl_fn(descending)
233+
else:
234+
impl_fn = _get_mergeargsort_impl_fn(descending)
160235
exec_q = x.sycl_queue
161236
_manager = du.SequentialOrderManager[exec_q]
162237
dep_evs = _manager.submitted_events
163-
impl_fn = _argsort_descending if descending else _argsort_ascending
164238
index_dt = ti.default_device_index_type(exec_q)
165239
if arr.flags.c_contiguous:
166240
res = dpt.empty_like(arr, dtype=index_dt, order="C")

0 commit comments

Comments
 (0)