Skip to content

Commit 3a1a7c5

Browse files
authored
Merge pull request #1921 from IntelPython/feature/topk
Implements `top_k` in dpctl.tensor
2 parents f7cb1b1 + 8c6abf5 commit 3a1a7c5

File tree

14 files changed

+1854
-187
lines changed

14 files changed

+1854
-187
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
88

99
### Added
1010

11+
* Added `dpctl.tensor.top_k` per Python Array API specification: [#1921](https://github.com/IntelPython/dpctl/pull/1921)
12+
1113
### Changed
1214

1315
* Improved performance of copy-and-cast operations from `numpy.ndarray` to `tensor.usm_ndarray` for contiguous inputs [gh-1829](https://github.com/IntelPython/dpctl/pull/1829)

docs/doc_sources/api_reference/dpctl/tensor.sorting_functions.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,4 @@ Sorting functions
1010

1111
argsort
1212
sort
13+
top_k

dpctl/tensor/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ set(_sorting_sources
115115
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/merge_sort.cpp
116116
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/merge_argsort.cpp
117117
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/searchsorted.cpp
118+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/topk.cpp
118119
)
119120
set(_sorting_radix_sources
120121
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/radix_sort.cpp

dpctl/tensor/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,7 @@
199199
unique_inverse,
200200
unique_values,
201201
)
202-
from ._sorting import argsort, sort
202+
from ._sorting import argsort, sort, top_k
203203
from ._testing import allclose
204204
from ._type_utils import can_cast, finfo, iinfo, isdtype, result_type
205205

@@ -387,4 +387,5 @@
387387
"DLDeviceType",
388388
"take_along_axis",
389389
"put_along_axis",
390+
"top_k",
390391
]

dpctl/tensor/_sorting.py

Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@
1414
# See the License for the specific language governing permissions and
1515
# limitations under the License.
1616

17+
import operator
18+
from typing import NamedTuple
19+
1720
import dpctl.tensor as dpt
1821
import dpctl.tensor._tensor_impl as ti
1922
import dpctl.utils as du
@@ -24,6 +27,7 @@
2427
_argsort_descending,
2528
_sort_ascending,
2629
_sort_descending,
30+
_topk,
2731
)
2832
from ._tensor_sorting_radix_impl import (
2933
_radix_argsort_ascending,
@@ -267,3 +271,166 @@ def argsort(x, axis=-1, descending=False, stable=True, kind=None):
267271
inv_perm = sorted(range(nd), key=lambda d: perm[d])
268272
res = dpt.permute_dims(res, inv_perm)
269273
return res
274+
275+
276+
def _get_top_k_largest(mode):
277+
modes = {"largest": True, "smallest": False}
278+
try:
279+
return modes[mode]
280+
except KeyError:
281+
raise ValueError(
282+
f"`mode` must be `largest` or `smallest`. Got `{mode}`."
283+
)
284+
285+
286+
class TopKResult(NamedTuple):
287+
values: dpt.usm_ndarray
288+
indices: dpt.usm_ndarray
289+
290+
291+
def top_k(x, k, /, *, axis=None, mode="largest"):
292+
"""top_k(x, k, axis=None, mode="largest")
293+
294+
Returns the `k` largest or smallest values and their indices in the input
295+
array `x` along the specified axis `axis`.
296+
297+
Args:
298+
x (usm_ndarray):
299+
input array.
300+
k (int):
301+
number of elements to find. Must be a positive integer value.
302+
axis (Optional[int]):
303+
axis along which to search. If `None`, the search will be performed
304+
over the flattened array. Default: ``None``.
305+
mode (Literal["largest", "smallest"]):
306+
search mode. Must be one of the following modes:
307+
308+
- `"largest"`: return the `k` largest elements.
309+
- `"smallest"`: return the `k` smallest elements.
310+
311+
Default: `"largest"`.
312+
313+
Returns:
314+
tuple[usm_ndarray, usm_ndarray]
315+
a namedtuple `(values, indices)` whose
316+
317+
* first element `values` will be an array containing the `k`
318+
largest or smallest elements of `x`. The array has the same data
319+
type as `x`. If `axis` was `None`, `values` will be a
320+
one-dimensional array with shape `(k,)` and otherwise, `values`
321+
will have shape `x.shape[:axis] + (k,) + x.shape[axis+1:]`
322+
* second element `indices` will be an array containing indices of
323+
`x` that result in `values`. The array will have the same shape
324+
as `values` and will have the default array index data type.
325+
"""
326+
largest = _get_top_k_largest(mode)
327+
if not isinstance(x, dpt.usm_ndarray):
328+
raise TypeError(
329+
f"Expected type dpctl.tensor.usm_ndarray, got {type(x)}"
330+
)
331+
332+
k = operator.index(k)
333+
if k < 0:
334+
raise ValueError("`k` must be a positive integer value")
335+
336+
nd = x.ndim
337+
if axis is None:
338+
sz = x.size
339+
if nd == 0:
340+
if k > 1:
341+
raise ValueError(f"`k`={k} is out of bounds 1")
342+
return TopKResult(
343+
dpt.copy(x, order="C"),
344+
dpt.zeros_like(
345+
x, dtype=ti.default_device_index_type(x.sycl_queue)
346+
),
347+
)
348+
arr = x
349+
n_search_dims = None
350+
res_sh = k
351+
else:
352+
axis = normalize_axis_index(axis, ndim=nd, msg_prefix="axis")
353+
sz = x.shape[axis]
354+
a1 = axis + 1
355+
if a1 == nd:
356+
perm = list(range(nd))
357+
arr = x
358+
else:
359+
perm = [i for i in range(nd) if i != axis] + [
360+
axis,
361+
]
362+
arr = dpt.permute_dims(x, perm)
363+
n_search_dims = 1
364+
res_sh = arr.shape[: nd - 1] + (k,)
365+
366+
if k > sz:
367+
raise ValueError(f"`k`={k} is out of bounds {sz}")
368+
369+
exec_q = x.sycl_queue
370+
_manager = du.SequentialOrderManager[exec_q]
371+
dep_evs = _manager.submitted_events
372+
373+
res_usm_type = arr.usm_type
374+
if arr.flags.c_contiguous:
375+
vals = dpt.empty(
376+
res_sh,
377+
dtype=arr.dtype,
378+
usm_type=res_usm_type,
379+
order="C",
380+
sycl_queue=exec_q,
381+
)
382+
inds = dpt.empty(
383+
res_sh,
384+
dtype=ti.default_device_index_type(exec_q),
385+
usm_type=res_usm_type,
386+
order="C",
387+
sycl_queue=exec_q,
388+
)
389+
ht_ev, impl_ev = _topk(
390+
src=arr,
391+
trailing_dims_to_search=n_search_dims,
392+
k=k,
393+
largest=largest,
394+
vals=vals,
395+
inds=inds,
396+
sycl_queue=exec_q,
397+
depends=dep_evs,
398+
)
399+
_manager.add_event_pair(ht_ev, impl_ev)
400+
else:
401+
tmp = dpt.empty_like(arr, order="C")
402+
ht_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
403+
src=arr, dst=tmp, sycl_queue=exec_q, depends=dep_evs
404+
)
405+
_manager.add_event_pair(ht_ev, copy_ev)
406+
vals = dpt.empty(
407+
res_sh,
408+
dtype=arr.dtype,
409+
usm_type=res_usm_type,
410+
order="C",
411+
sycl_queue=exec_q,
412+
)
413+
inds = dpt.empty(
414+
res_sh,
415+
dtype=ti.default_device_index_type(exec_q),
416+
usm_type=res_usm_type,
417+
order="C",
418+
sycl_queue=exec_q,
419+
)
420+
ht_ev, impl_ev = _topk(
421+
src=tmp,
422+
trailing_dims_to_search=n_search_dims,
423+
k=k,
424+
largest=largest,
425+
vals=vals,
426+
inds=inds,
427+
sycl_queue=exec_q,
428+
depends=[copy_ev],
429+
)
430+
_manager.add_event_pair(ht_ev, impl_ev)
431+
if axis is not None and a1 != nd:
432+
inv_perm = sorted(range(nd), key=lambda d: perm[d])
433+
vals = dpt.permute_dims(vals, inv_perm)
434+
inds = dpt.permute_dims(inds, inv_perm)
435+
436+
return TopKResult(vals, inds)

dpctl/tensor/libtensor/include/kernels/sorting/merge_sort.hpp

Lines changed: 9 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333

3434
#include "kernels/dpctl_tensor_types.hpp"
3535
#include "kernels/sorting/search_sorted_detail.hpp"
36+
#include "kernels/sorting/sort_utils.hpp"
3637

3738
namespace dpctl
3839
{
@@ -811,20 +812,12 @@ sycl::event stable_argsort_axis1_contig_impl(
811812

812813
const size_t total_nelems = iter_nelems * sort_nelems;
813814

814-
sycl::event populate_indexed_data_ev =
815-
exec_q.submit([&](sycl::handler &cgh) {
816-
cgh.depends_on(depends);
815+
using dpctl::tensor::kernels::sort_utils_detail::iota_impl;
817816

818-
const sycl::range<1> range{total_nelems};
817+
using IotaKernelName = populate_index_data_krn<argTy, IndexTy, ValueComp>;
819818

820-
using KernelName =
821-
populate_index_data_krn<argTy, IndexTy, ValueComp>;
822-
823-
cgh.parallel_for<KernelName>(range, [=](sycl::id<1> id) {
824-
size_t i = id[0];
825-
res_tp[i] = static_cast<IndexTy>(i);
826-
});
827-
});
819+
sycl::event populate_indexed_data_ev = iota_impl<IotaKernelName, IndexTy>(
820+
exec_q, res_tp, total_nelems, depends);
828821

829822
// Sort segments of the array
830823
sycl::event base_sort_ev =
@@ -839,21 +832,11 @@ sycl::event stable_argsort_axis1_contig_impl(
839832
exec_q, iter_nelems, sort_nelems, res_tp, index_comp, sorted_block_size,
840833
{base_sort_ev});
841834

842-
sycl::event write_out_ev = exec_q.submit([&](sycl::handler &cgh) {
843-
cgh.depends_on(merges_ev);
844-
845-
auto temp_acc =
846-
merge_sort_detail::GetReadOnlyAccess<decltype(res_tp)>{}(res_tp,
847-
cgh);
848-
849-
using KernelName = index_map_to_rows_krn<argTy, IndexTy, ValueComp>;
835+
using MapBackKernelName = index_map_to_rows_krn<argTy, IndexTy, ValueComp>;
836+
using dpctl::tensor::kernels::sort_utils_detail::map_back_impl;
850837

851-
const sycl::range<1> range{total_nelems};
852-
853-
cgh.parallel_for<KernelName>(range, [=](sycl::id<1> id) {
854-
res_tp[id] = (temp_acc[id] % sort_nelems);
855-
});
856-
});
838+
sycl::event write_out_ev = map_back_impl<MapBackKernelName, IndexTy>(
839+
exec_q, total_nelems, res_tp, res_tp, sort_nelems, {merges_ev});
857840

858841
return write_out_ev;
859842
}

0 commit comments

Comments
 (0)