diff --git a/CHANGELOG.md b/CHANGELOG.md
index bafc653489..2e4dcea993 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -12,6 +12,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 * Improved performance of copy-and-cast operations from `numpy.ndarray` to `tensor.usm_ndarray` for contiguous inputs [gh-1829](https://github.com/IntelPython/dpctl/pull/1829)
 * Improved performance of copying operation to C-/F-contig array, with optimization for batch of square matrices [gh-1850](https://github.com/IntelPython/dpctl/pull/1850)
+* Improved performance of `tensor.argsort` function for all types [gh-1859](https://github.com/IntelPython/dpctl/pull/1859)
 
 ### Fixed
diff --git a/dpctl/tensor/libtensor/include/kernels/sorting/sort.hpp b/dpctl/tensor/libtensor/include/kernels/sorting/sort.hpp
index b26638ff75..28db00facd 100644
--- a/dpctl/tensor/libtensor/include/kernels/sorting/sort.hpp
+++ b/dpctl/tensor/libtensor/include/kernels/sorting/sort.hpp
@@ -768,18 +768,25 @@ sycl::event stable_sort_axis1_contig_impl(
     }
 }
 
-template <typename T1, typename T2, typename T3>
-class populate_indexed_data_krn;
+template <typename T1, typename T2, typename T3> class populate_index_data_krn;
 
-template <typename T1, typename T2, typename T3> class index_write_out_krn;
+template <typename T1, typename T2, typename T3> class index_map_to_rows_krn;
 
-template <typename pairT, typename ValueComp> struct TupleComp
+template <typename IndexT, typename ValueT, typename ValueComp> struct IndexComp
 {
-    bool operator()(const pairT &p1, const pairT &p2) const
+    IndexComp(const ValueT *data, const ValueComp &comp_op)
+        : ptr(data), value_comp(comp_op)
     {
-        const ValueComp value_comp{};
-        return value_comp(std::get<0>(p1), std::get<0>(p2));
     }
+
+    bool operator()(const IndexT &i1, const IndexT &i2) const
+    {
+        return value_comp(ptr[i1], ptr[i2]);
+    }
+
+private:
+    const ValueT *ptr;
+    ValueComp value_comp;
 };
 
 template <typename argTy, typename IndexTy, typename ValueComp>
 sycl::event stable_argsort_axis1_contig_impl(/* ... */)
 {
     IndexTy *res_tp =
         reinterpret_cast<IndexTy *>(res_cp) + iter_res_offset + sort_res_offset;
 
-    using ValueIndexT = std::pair<argTy, IndexTy>;
-    const TupleComp<ValueIndexT, ValueComp> tuple_comp{};
+    const IndexComp<IndexTy, argTy, ValueComp> index_comp{arg_tp, ValueComp{}};
 
     static constexpr size_t determine_automatically = 0;
     size_t sorted_block_size =
         (sort_nelems >= 512) ? 512 : determine_automatically;
 
-    sycl::buffer<ValueIndexT> indexed_data(
-        sycl::range<1>(iter_nelems * sort_nelems));
-    sycl::buffer<ValueIndexT> temp_buf(
-        sycl::range<1>(iter_nelems * sort_nelems));
+    const size_t total_nelems = iter_nelems * sort_nelems;
 
     sycl::event populate_indexed_data_ev =
         exec_q.submit([&](sycl::handler &cgh) {
             cgh.depends_on(depends);
 
-            sycl::accessor acc(indexed_data, cgh, sycl::write_only,
-                               sycl::no_init);
-            auto const &range = indexed_data.get_range();
+            const sycl::range<1> range{total_nelems};
 
             using KernelName =
-                populate_indexed_data_krn<argTy, IndexTy, ValueComp>;
+                populate_index_data_krn<argTy, IndexTy, ValueComp>;
 
             cgh.parallel_for<KernelName>(range, [=](sycl::id<1> id) {
                 size_t i = id[0];
-                size_t sort_id = i % sort_nelems;
-                acc[i] =
-                    std::make_pair(arg_tp[i], static_cast<IndexTy>(sort_id));
+                res_tp[i] = static_cast<IndexTy>(i);
             });
         });
 
     // Sort segments of the array
     sycl::event base_sort_ev = sort_detail::sort_over_work_group_contig_impl(
-        exec_q, iter_nelems, sort_nelems, indexed_data, temp_buf, tuple_comp,
+        exec_q, iter_nelems, sort_nelems, res_tp, res_tp, index_comp,
         sorted_block_size, // modified in place with size of sorted block size
         {populate_indexed_data_ev});
 
     // Merge segments in parallel until all elements are sorted
     sycl::event merges_ev = sort_detail::merge_sorted_block_contig_impl(
-        exec_q, iter_nelems, sort_nelems, temp_buf, tuple_comp,
-        sorted_block_size, {base_sort_ev});
+        exec_q, iter_nelems, sort_nelems, res_tp, index_comp, sorted_block_size,
+        {base_sort_ev});
 
     sycl::event write_out_ev = exec_q.submit([&](sycl::handler &cgh) {
         cgh.depends_on(merges_ev);
 
         auto temp_acc =
-            sort_detail::GetReadOnlyAccess<decltype(temp_buf)>{}(temp_buf, cgh);
+            sort_detail::GetReadOnlyAccess<decltype(res_tp)>{}(res_tp, cgh);
+
+        using KernelName = index_map_to_rows_krn<argTy, IndexTy, ValueComp>;
 
-        using KernelName = index_write_out_krn<argTy, IndexTy, ValueComp>;
+        const sycl::range<1> range{total_nelems};
 
-        cgh.parallel_for<KernelName>(temp_buf.get_range(), [=](sycl::id<1> id) {
-            res_tp[id] = std::get<1>(temp_acc[id]);
+        cgh.parallel_for<KernelName>(range, [=](sycl::id<1> id) {
+            res_tp[id] = (temp_acc[id] % sort_nelems);
         });
     });
diff --git a/dpctl/tensor/libtensor/source/sorting/argsort.cpp b/dpctl/tensor/libtensor/source/sorting/argsort.cpp
index 0bd998e90b..b5a052ef94 100644
--- a/dpctl/tensor/libtensor/source/sorting/argsort.cpp
+++ b/dpctl/tensor/libtensor/source/sorting/argsort.cpp
@@ -129,8 +129,6 @@ py_argsort(const dpctl::tensor::usm_ndarray &src,
     bool is_dst_c_contig = dst.is_c_contiguous();
 
     if (is_src_c_contig && is_dst_c_contig) {
-        using dpctl::tensor::kernels::stable_argsort_axis1_contig_impl;
-
         static constexpr py::ssize_t zero_offset = py::ssize_t(0);
 
         auto fn = stable_sort_contig_fns[src_typeid][dst_typeid];
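
In plain terms, the change drops the two `sycl::buffer`s of `(value, index)` pairs and instead fills the result array with flattened indices, sorts those indices in place with a comparator (`IndexComp`) that reads the values through a pointer, and finally reduces each flattened index modulo `sort_nelems` to get the row-local position. A minimal host-side sketch of that idea, using plain C++ and `std::stable_sort` instead of the device kernels; `argsort_rows` and the sample data below are illustrative only and not part of the dpctl sources:

```cpp
#include <algorithm>
#include <cstddef>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

// Comparator over indices that reads values through a pointer,
// mirroring the IndexComp functor introduced in the diff.
template <typename IndexT, typename ValueT, typename ValueComp> struct IndexComp
{
    IndexComp(const ValueT *data, const ValueComp &comp_op)
        : ptr(data), value_comp(comp_op)
    {
    }

    bool operator()(const IndexT &i1, const IndexT &i2) const
    {
        return value_comp(ptr[i1], ptr[i2]);
    }

private:
    const ValueT *ptr;
    ValueComp value_comp;
};

// Illustrative host-side argsort over each row of a C-contiguous
// (iter_nelems x sort_nelems) batch; not part of dpctl.
template <typename ValueT, typename IndexT = std::size_t>
std::vector<IndexT> argsort_rows(const std::vector<ValueT> &a,
                                 std::size_t iter_nelems,
                                 std::size_t sort_nelems)
{
    std::vector<IndexT> res(iter_nelems * sort_nelems);

    // Step 1 (populate_index_data): fill the result with flattened indices.
    std::iota(res.begin(), res.end(), IndexT(0));

    const IndexComp<IndexT, ValueT, std::less<ValueT>> index_comp{
        a.data(), std::less<ValueT>{}};

    // Step 2: stably sort each row's slice of indices by the values they reference.
    for (std::size_t row = 0; row < iter_nelems; ++row) {
        auto first = res.begin() + row * sort_nelems;
        std::stable_sort(first, first + sort_nelems, index_comp);
    }

    // Step 3 (index_map_to_rows): map flattened indices to row-local positions.
    for (auto &idx : res) {
        idx = static_cast<IndexT>(idx % sort_nelems);
    }
    return res;
}

int main()
{
    // Two rows of three values each.
    const std::vector<float> a{3.0f, 1.0f, 2.0f, 0.5f, 2.5f, 1.5f};
    const auto order = argsort_rows(a, 2, 3);
    for (auto i : order) {
        std::cout << i << ' '; // expected: 1 2 0 0 2 1
    }
    std::cout << '\n';
}
```

In the kernels above the same three steps run on the device: the indices are written directly into `res_tp`, `sort_over_work_group_contig_impl` and `merge_sorted_block_contig_impl` sort them in place through `IndexComp`, and the final kernel applies the `% sort_nelems` mapping, which avoids allocating temporary pair buffers altogether.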