|
14 | 14 | # See the License for the specific language governing permissions and
|
15 | 15 | # limitations under the License.
|
16 | 16 |
|
| 17 | +import operator |
| 18 | +from typing import NamedTuple |
| 19 | + |
17 | 20 | import dpctl.tensor as dpt
|
18 | 21 | import dpctl.tensor._tensor_impl as ti
|
19 | 22 | import dpctl.utils as du
|
|
24 | 27 | _argsort_descending,
|
25 | 28 | _sort_ascending,
|
26 | 29 | _sort_descending,
|
| 30 | + _topk, |
27 | 31 | )
|
28 | 32 | from ._tensor_sorting_radix_impl import (
|
29 | 33 | _radix_argsort_ascending,
|
@@ -267,3 +271,166 @@ def argsort(x, axis=-1, descending=False, stable=True, kind=None):
|
267 | 271 | inv_perm = sorted(range(nd), key=lambda d: perm[d])
|
268 | 272 | res = dpt.permute_dims(res, inv_perm)
|
269 | 273 | return res
|
| 274 | + |
| 275 | + |
| 276 | +def _get_top_k_largest(mode): |
| 277 | + modes = {"largest": True, "smallest": False} |
| 278 | + try: |
| 279 | + return modes[mode] |
| 280 | + except KeyError: |
| 281 | + raise ValueError( |
| 282 | + f"`mode` must be `largest` or `smallest`. Got `{mode}`." |
| 283 | + ) |
| 284 | + |
| 285 | + |
class TopKResult(NamedTuple):
    """Named pair ``(values, indices)`` returned by ``top_k``."""

    # array holding the selected elements
    values: dpt.usm_ndarray
    # array holding the indices of the selected elements
    indices: dpt.usm_ndarray
| 289 | + |
| 290 | + |
def top_k(x, k, /, *, axis=None, mode="largest"):
    """top_k(x, k, axis=None, mode="largest")

    Finds the `k` largest or `k` smallest values of the input array `x`
    along the axis `axis`, together with the indices of those values.

    Args:
        x (usm_ndarray):
            input array.
        k (int):
            number of elements to find. Must be a positive integer value.
        axis (Optional[int]):
            axis along which to search. If `None`, the search is performed
            over the flattened array. Default: ``None``.
        mode (Literal["largest", "smallest"]):
            search mode. Must be one of the following modes:

            - `"largest"`: return the `k` largest elements.
            - `"smallest"`: return the `k` smallest elements.

            Default: `"largest"`.

    Returns:
        tuple[usm_ndarray, usm_ndarray]
            a namedtuple `(values, indices)` whose

            * first element `values` is an array containing the `k`
              largest or smallest elements of `x`. The array has the same
              data type as `x`. If `axis` was `None`, `values` is a
              one-dimensional array with shape `(k,)`; otherwise `values`
              has shape `x.shape[:axis] + (k,) + x.shape[axis+1:]`
            * second element `indices` is an array containing indices of
              `x` that result in `values`. The array has the same shape
              as `values` and the default array index data type.

    Raises:
        TypeError:
            if `x` is not a `usm_ndarray`.
        ValueError:
            if `mode` is not recognized, if `k` is negative, or if `k`
            exceeds the number of elements available to search.
    """
    largest = _get_top_k_largest(mode)
    if not isinstance(x, dpt.usm_ndarray):
        raise TypeError(
            f"Expected type dpctl.tensor.usm_ndarray, got {type(x)}"
        )

    k = operator.index(k)
    # NOTE(review): only negative values are rejected here, so `k == 0`
    # passes despite the wording of the message -- confirm intended contract.
    if k < 0:
        raise ValueError("`k` must be a positive integer value")

    ndim = x.ndim
    # Set only when the search axis had to be moved to the last position.
    restore_axes = None
    if axis is None:
        if ndim == 0:
            # A zero-dimensional array has a single element at index 0.
            if k > 1:
                raise ValueError(f"`k`={k} is out of bounds 1")
            index_dt = ti.default_device_index_type(x.sycl_queue)
            return TopKResult(
                dpt.copy(x, order="C"),
                dpt.zeros_like(x, dtype=index_dt),
            )
        axis_len = x.size
        src = x
        trailing_dims = None
        out_shape = k
    else:
        axis = normalize_axis_index(axis, ndim=ndim, msg_prefix="axis")
        axis_len = x.shape[axis]
        if axis + 1 == ndim:
            src = x
        else:
            # Move the search axis to the end; remember how to undo it.
            fwd = [d for d in range(ndim) if d != axis]
            fwd.append(axis)
            src = dpt.permute_dims(x, fwd)
            restore_axes = sorted(range(ndim), key=fwd.__getitem__)
        trailing_dims = 1
        out_shape = src.shape[: ndim - 1] + (k,)

    if k > axis_len:
        raise ValueError(f"`k`={k} is out of bounds {axis_len}")

    exec_q = x.sycl_queue
    _manager = du.SequentialOrderManager[exec_q]
    dep_evs = _manager.submitted_events

    if src.flags.c_contiguous:
        topk_src = src
        topk_deps = dep_evs
    else:
        # Copy into a C-contiguous temporary; the search then depends on
        # that copy instead of the previously submitted events.
        topk_src = dpt.empty_like(src, order="C")
        ht_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
            src=src, dst=topk_src, sycl_queue=exec_q, depends=dep_evs
        )
        _manager.add_event_pair(ht_ev, copy_ev)
        topk_deps = [copy_ev]

    alloc_kwargs = {
        "usm_type": src.usm_type,
        "order": "C",
        "sycl_queue": exec_q,
    }
    vals = dpt.empty(out_shape, dtype=src.dtype, **alloc_kwargs)
    inds = dpt.empty(
        out_shape,
        dtype=ti.default_device_index_type(exec_q),
        **alloc_kwargs,
    )
    ht_ev, impl_ev = _topk(
        src=topk_src,
        trailing_dims_to_search=trailing_dims,
        k=k,
        largest=largest,
        vals=vals,
        inds=inds,
        sycl_queue=exec_q,
        depends=topk_deps,
    )
    _manager.add_event_pair(ht_ev, impl_ev)

    if restore_axes is not None:
        # Put the searched axis back where the caller had it.
        vals = dpt.permute_dims(vals, restore_axes)
        inds = dpt.permute_dims(inds, restore_axes)

    return TopKResult(vals, inds)
0 commit comments