|
| 1 | +# Data Parallel Control (dpctl) |
| 2 | +# |
| 3 | +# Copyright 2020-2024 Intel Corporation |
| 4 | +# |
| 5 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 6 | +# you may not use this file except in compliance with the License. |
| 7 | +# You may obtain a copy of the License at |
| 8 | +# |
| 9 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 10 | +# |
| 11 | +# Unless required by applicable law or agreed to in writing, software |
| 12 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 13 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 14 | +# See the License for the specific language governing permissions and |
| 15 | +# limitations under the License. |
| 16 | + |
| 17 | +import operator |
| 18 | +from typing import NamedTuple |
| 19 | + |
| 20 | +import dpctl.tensor as dpt |
| 21 | +import dpctl.tensor._tensor_impl as ti |
| 22 | +import dpctl.utils as du |
| 23 | +from dpctl.tensor._numpy_helper import normalize_axis_index |
| 24 | + |
| 25 | +from ._tensor_sorting_impl import _topk |
| 26 | + |
| 27 | + |
| 28 | +def _get_top_k_largest(mode): |
| 29 | + modes = {"largest": True, "smallest": False} |
| 30 | + try: |
| 31 | + return modes[mode] |
| 32 | + except KeyError: |
| 33 | + raise ValueError( |
| 34 | + f"`mode` must be `largest` or `smallest`. Got `{mode}`." |
| 35 | + ) |
| 36 | + |
| 37 | + |
| 38 | +class TopKResult(NamedTuple): |
| 39 | + values: dpt.usm_ndarray |
| 40 | + indices: dpt.usm_ndarray |
| 41 | + |
| 42 | + |
| 43 | +def top_k(x, k, /, *, axis=None, mode="largest"): |
| 44 | + """top_k(x, k, axis=None, mode="largest") |
| 45 | +
|
| 46 | + Returns the `k` largest or smallest values and their indices in the input |
| 47 | + array `x` along the specified axis `axis`. |
| 48 | +
|
| 49 | + Args: |
| 50 | + x (usm_ndarray): |
| 51 | + input array. |
| 52 | + k (int): |
| 53 | + number of elements to find. Must be a positive integer value. |
| 54 | + axis (Optional[int]): |
| 55 | + axis along which to search. If `None`, the search will be performed |
| 56 | + over the flattened array. Default: ``None``. |
| 57 | + mode (Literal["largest", "smallest"]): |
| 58 | + search mode. Must be one of the following modes: |
| 59 | + - `"largest"`: return the `k` largest elements. |
| 60 | + - `"smallest"`: return the `k` smallest elements. |
| 61 | + Default: `"largest"`. |
| 62 | +
|
| 63 | + Returns: |
| 64 | + tuple[usm_ndarray, usm_ndarray]: |
| 65 | + a namedtuple `(values, indices)` whose |
| 66 | +
|
| 67 | + - first element `values` will be an array containing the `k` largest or |
| 68 | + smallest elements of `x`. The array has the same data type as `x`. |
| 69 | + If `axis` was `None`, `values` will be a one-dimensional array |
| 70 | + with shape `(k,)` and otherwise, `values` will have shape |
| 71 | + `x.shape[:axis] + (k,) + x.shape[axis+1:]` |
| 72 | +
|
| 73 | + - second element `indices` will be an array containing indices of `x` |
| 74 | + that result in `values`. The array will have the same shape as |
| 75 | + `values` and will have the default array index data type. |
| 76 | + """ |
| 77 | + largest = _get_top_k_largest(mode) |
| 78 | + if not isinstance(x, dpt.usm_ndarray): |
| 79 | + raise TypeError( |
| 80 | + f"Expected type dpctl.tensor.usm_ndarray, got {type(x)}" |
| 81 | + ) |
| 82 | + |
| 83 | + k = operator.index(k) |
| 84 | + if k < 0: |
| 85 | + raise ValueError("`k` must be a positive integer value") |
| 86 | + |
| 87 | + nd = x.ndim |
| 88 | + if axis is None: |
| 89 | + sz = x.size |
| 90 | + if nd == 0: |
| 91 | + return TopKResult( |
| 92 | + dpt.copy(x, order="C"), |
| 93 | + dpt.zeros_like( |
| 94 | + x, dtype=ti.default_device_index_type(x.sycl_queue) |
| 95 | + ), |
| 96 | + ) |
| 97 | + arr = x |
| 98 | + n_search_dims = None |
| 99 | + res_sh = k |
| 100 | + else: |
| 101 | + axis = normalize_axis_index(axis, ndim=nd, msg_prefix="axis") |
| 102 | + sz = x.shape[axis] |
| 103 | + a1 = axis + 1 |
| 104 | + if a1 == nd: |
| 105 | + perm = list(range(nd)) |
| 106 | + arr = x |
| 107 | + else: |
| 108 | + perm = [i for i in range(nd) if i != axis] + [ |
| 109 | + axis, |
| 110 | + ] |
| 111 | + arr = dpt.permute_dims(x, perm) |
| 112 | + n_search_dims = 1 |
| 113 | + res_sh = arr.shape[: nd - 1] + (k,) |
| 114 | + |
| 115 | + if k > sz: |
| 116 | + raise ValueError(f"`k`={k} is out of bounds {sz}") |
| 117 | + |
| 118 | + exec_q = x.sycl_queue |
| 119 | + _manager = du.SequentialOrderManager[exec_q] |
| 120 | + dep_evs = _manager.submitted_events |
| 121 | + |
| 122 | + res_usm_type = arr.usm_type |
| 123 | + if arr.flags.c_contiguous: |
| 124 | + vals = dpt.empty( |
| 125 | + res_sh, |
| 126 | + dtype=arr.dtype, |
| 127 | + usm_type=res_usm_type, |
| 128 | + order="C", |
| 129 | + sycl_queue=exec_q, |
| 130 | + ) |
| 131 | + inds = dpt.empty( |
| 132 | + res_sh, |
| 133 | + dtype=ti.default_device_index_type(exec_q), |
| 134 | + usm_type=res_usm_type, |
| 135 | + order="C", |
| 136 | + sycl_queue=exec_q, |
| 137 | + ) |
| 138 | + ht_ev, impl_ev = _topk( |
| 139 | + src=arr, |
| 140 | + trailing_dims_to_search=n_search_dims, |
| 141 | + k=k, |
| 142 | + largest=largest, |
| 143 | + vals=vals, |
| 144 | + inds=inds, |
| 145 | + sycl_queue=exec_q, |
| 146 | + depends=dep_evs, |
| 147 | + ) |
| 148 | + _manager.add_event_pair(ht_ev, impl_ev) |
| 149 | + else: |
| 150 | + tmp = dpt.empty_like(arr, order="C") |
| 151 | + ht_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( |
| 152 | + src=arr, dst=tmp, sycl_queue=exec_q, depends=dep_evs |
| 153 | + ) |
| 154 | + _manager.add_event_pair(ht_ev, copy_ev) |
| 155 | + vals = dpt.empty( |
| 156 | + res_sh, |
| 157 | + dtype=arr.dtype, |
| 158 | + usm_type=res_usm_type, |
| 159 | + order="C", |
| 160 | + sycl_queue=exec_q, |
| 161 | + ) |
| 162 | + inds = dpt.empty( |
| 163 | + res_sh, |
| 164 | + dtype=ti.default_device_index_type(exec_q), |
| 165 | + usm_type=res_usm_type, |
| 166 | + order="C", |
| 167 | + sycl_queue=exec_q, |
| 168 | + ) |
| 169 | + ht_ev, impl_ev = _topk( |
| 170 | + src=tmp, |
| 171 | + trailing_dims_to_search=n_search_dims, |
| 172 | + k=k, |
| 173 | + largest=largest, |
| 174 | + vals=vals, |
| 175 | + inds=inds, |
| 176 | + sycl_queue=exec_q, |
| 177 | + depends=[copy_ev], |
| 178 | + ) |
| 179 | + _manager.add_event_pair(ht_ev, impl_ev) |
| 180 | + if axis is not None and a1 != nd: |
| 181 | + inv_perm = sorted(range(nd), key=lambda d: perm[d]) |
| 182 | + vals = dpt.permute_dims(vals, inv_perm) |
| 183 | + inds = dpt.permute_dims(inds, inv_perm) |
| 184 | + return TopKResult(vals, inds) |
0 commit comments