Skip to content

Commit 79b97d9

Browse files
committed
Implements top_k in dpctl.tensor
1 parent a4bc0c4 commit 79b97d9

File tree

7 files changed

+2424
-0
lines changed

7 files changed

+2424
-0
lines changed

dpctl/tensor/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ set(_sorting_sources
115115
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/merge_sort.cpp
116116
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/merge_argsort.cpp
117117
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/searchsorted.cpp
118+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/topk.cpp
118119
)
119120
set(_sorting_radix_sources
120121
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/radix_sort.cpp

dpctl/tensor/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,7 @@
201201
)
202202
from ._sorting import argsort, sort
203203
from ._testing import allclose
204+
from ._topk import top_k
204205
from ._type_utils import can_cast, finfo, iinfo, isdtype, result_type
205206

206207
__all__ = [
@@ -387,4 +388,5 @@
387388
"DLDeviceType",
388389
"take_along_axis",
389390
"put_along_axis",
391+
"top_k",
390392
]

dpctl/tensor/_topk.py

Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,184 @@
1+
# Data Parallel Control (dpctl)
2+
#
3+
# Copyright 2020-2024 Intel Corporation
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
import operator
18+
from typing import NamedTuple
19+
20+
import dpctl.tensor as dpt
21+
import dpctl.tensor._tensor_impl as ti
22+
import dpctl.utils as du
23+
from dpctl.tensor._numpy_helper import normalize_axis_index
24+
25+
from ._tensor_sorting_impl import _topk
26+
27+
28+
def _get_top_k_largest(mode):
29+
modes = {"largest": True, "smallest": False}
30+
try:
31+
return modes[mode]
32+
except KeyError:
33+
raise ValueError(
34+
f"`mode` must be `largest` or `smallest`. Got `{mode}`."
35+
)
36+
37+
38+
class TopKResult(NamedTuple):
39+
values: dpt.usm_ndarray
40+
indices: dpt.usm_ndarray
41+
42+
43+
def top_k(x, k, /, *, axis=None, mode="largest"):
44+
"""top_k(x, k, axis=None, mode="largest")
45+
46+
Returns the `k` largest or smallest values and their indices in the input
47+
array `x` along the specified axis `axis`.
48+
49+
Args:
50+
x (usm_ndarray):
51+
input array.
52+
k (int):
53+
number of elements to find. Must be a positive integer value.
54+
axis (Optional[int]):
55+
axis along which to search. If `None`, the search will be performed
56+
over the flattened array. Default: ``None``.
57+
mode (Literal["largest", "smallest"]):
58+
search mode. Must be one of the following modes:
59+
- `"largest"`: return the `k` largest elements.
60+
- `"smallest"`: return the `k` smallest elements.
61+
Default: `"largest"`.
62+
63+
Returns:
64+
tuple[usm_ndarray, usm_ndarray]:
65+
a namedtuple `(values, indices)` whose
66+
67+
- first element `values` will be an array containing the `k` largest or
68+
smallest elements of `x`. The array has the same data type as `x`.
69+
If `axis` was `None`, `values` will be a one-dimensional array
70+
with shape `(k,)` and otherwise, `values` will have shape
71+
`x.shape[:axis] + (k,) + x.shape[axis+1:]`
72+
73+
- second element `indices` will be an array containing indices of `x`
74+
that result in `values`. The array will have the same shape as
75+
`values` and will have the default array index data type.
76+
"""
77+
largest = _get_top_k_largest(mode)
78+
if not isinstance(x, dpt.usm_ndarray):
79+
raise TypeError(
80+
f"Expected type dpctl.tensor.usm_ndarray, got {type(x)}"
81+
)
82+
83+
k = operator.index(k)
84+
if k < 0:
85+
raise ValueError("`k` must be a positive integer value")
86+
87+
nd = x.ndim
88+
if axis is None:
89+
sz = x.size
90+
if nd == 0:
91+
return TopKResult(
92+
dpt.copy(x, order="C"),
93+
dpt.zeros_like(
94+
x, dtype=ti.default_device_index_type(x.sycl_queue)
95+
),
96+
)
97+
arr = x
98+
n_search_dims = None
99+
res_sh = k
100+
else:
101+
axis = normalize_axis_index(axis, ndim=nd, msg_prefix="axis")
102+
sz = x.shape[axis]
103+
a1 = axis + 1
104+
if a1 == nd:
105+
perm = list(range(nd))
106+
arr = x
107+
else:
108+
perm = [i for i in range(nd) if i != axis] + [
109+
axis,
110+
]
111+
arr = dpt.permute_dims(x, perm)
112+
n_search_dims = 1
113+
res_sh = arr.shape[: nd - 1] + (k,)
114+
115+
if k > sz:
116+
raise ValueError(f"`k`={k} is out of bounds {sz}")
117+
118+
exec_q = x.sycl_queue
119+
_manager = du.SequentialOrderManager[exec_q]
120+
dep_evs = _manager.submitted_events
121+
122+
res_usm_type = arr.usm_type
123+
if arr.flags.c_contiguous:
124+
vals = dpt.empty(
125+
res_sh,
126+
dtype=arr.dtype,
127+
usm_type=res_usm_type,
128+
order="C",
129+
sycl_queue=exec_q,
130+
)
131+
inds = dpt.empty(
132+
res_sh,
133+
dtype=ti.default_device_index_type(exec_q),
134+
usm_type=res_usm_type,
135+
order="C",
136+
sycl_queue=exec_q,
137+
)
138+
ht_ev, impl_ev = _topk(
139+
src=arr,
140+
trailing_dims_to_search=n_search_dims,
141+
k=k,
142+
largest=largest,
143+
vals=vals,
144+
inds=inds,
145+
sycl_queue=exec_q,
146+
depends=dep_evs,
147+
)
148+
_manager.add_event_pair(ht_ev, impl_ev)
149+
else:
150+
tmp = dpt.empty_like(arr, order="C")
151+
ht_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
152+
src=arr, dst=tmp, sycl_queue=exec_q, depends=dep_evs
153+
)
154+
_manager.add_event_pair(ht_ev, copy_ev)
155+
vals = dpt.empty(
156+
res_sh,
157+
dtype=arr.dtype,
158+
usm_type=res_usm_type,
159+
order="C",
160+
sycl_queue=exec_q,
161+
)
162+
inds = dpt.empty(
163+
res_sh,
164+
dtype=ti.default_device_index_type(exec_q),
165+
usm_type=res_usm_type,
166+
order="C",
167+
sycl_queue=exec_q,
168+
)
169+
ht_ev, impl_ev = _topk(
170+
src=tmp,
171+
trailing_dims_to_search=n_search_dims,
172+
k=k,
173+
largest=largest,
174+
vals=vals,
175+
inds=inds,
176+
sycl_queue=exec_q,
177+
depends=[copy_ev],
178+
)
179+
_manager.add_event_pair(ht_ev, impl_ev)
180+
if axis is not None and a1 != nd:
181+
inv_perm = sorted(range(nd), key=lambda d: perm[d])
182+
vals = dpt.permute_dims(vals, inv_perm)
183+
inds = dpt.permute_dims(inds, inv_perm)
184+
return TopKResult(vals, inds)

0 commit comments

Comments
 (0)