From 9061c372ca64e4383a7df77d4f7170eb2ab8cae5 Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Wed, 9 Oct 2024 08:02:29 -0500 Subject: [PATCH 1/4] Remove unused imported qualifier in py_argsort function --- dpctl/tensor/libtensor/source/sorting/argsort.cpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/dpctl/tensor/libtensor/source/sorting/argsort.cpp b/dpctl/tensor/libtensor/source/sorting/argsort.cpp index 0bd998e90b..b5a052ef94 100644 --- a/dpctl/tensor/libtensor/source/sorting/argsort.cpp +++ b/dpctl/tensor/libtensor/source/sorting/argsort.cpp @@ -129,8 +129,6 @@ py_argsort(const dpctl::tensor::usm_ndarray &src, bool is_dst_c_contig = dst.is_c_contiguous(); if (is_src_c_contig && is_dst_c_contig) { - using dpctl::tensor::kernels::stable_argsort_axis1_contig_impl; - static constexpr py::ssize_t zero_offset = py::ssize_t(0); auto fn = stable_sort_contig_fns[src_typeid][dst_typeid]; From 0ebb16fafad825c3f81a358cb578066262c4c428 Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Wed, 9 Oct 2024 08:11:46 -0500 Subject: [PATCH 2/4] Change implementation of argsort Instead of implementing argsort as sort over structures (index, value), with subsequent projection to index, it is now implemented as sort over linear indices themselves, with dereferencing comparator, and subsequent mapping from linear index to row-wise index. On Iris Xe, argsort call took 215 ms to argsort 5670000 elements of type int32, and it now takes 117 ms. The new implementation also has smaller temporary allocation footprint. Previously, it would allocate 2*(sizeof(ValueT) + sizeof(IndexT)), now it only allocates sizeof(IndexT) for storing linear indices. --- .../include/kernels/sorting/sort.hpp | 54 ++++++++++--------- 1 file changed, 28 insertions(+), 26 deletions(-) diff --git a/dpctl/tensor/libtensor/include/kernels/sorting/sort.hpp b/dpctl/tensor/libtensor/include/kernels/sorting/sort.hpp index b26638ff75..e941daf6e1 100644 --- a/dpctl/tensor/libtensor/include/kernels/sorting/sort.hpp +++ b/dpctl/tensor/libtensor/include/kernels/sorting/sort.hpp @@ -768,18 +768,25 @@ sycl::event stable_sort_axis1_contig_impl( } } -template -class populate_indexed_data_krn; +template class populate_index_data_krn; -template class index_write_out_krn; +template class index_map_to_rows_krn; -template struct TupleComp +template struct IndexComp { - bool operator()(const pairT &p1, const pairT &p2) const + IndexComp(const ValueT *data, const ValueComp &comp_op) + : ptr(data), value_comp(comp_op) { - const ValueComp value_comp{}; - return value_comp(std::get<0>(p1), std::get<0>(p2)); } + + bool operator()(const IndexT &i1, const IndexT &i2) const + { + return value_comp(ptr[i1], ptr[i2]); + } + +private: + const ValueT *ptr; + ValueComp value_comp; }; template (res_cp) + iter_res_offset + sort_res_offset; - using ValueIndexT = std::pair; - const TupleComp tuple_comp{}; + const IndexComp index_comp{arg_tp, ValueComp{}}; static constexpr size_t determine_automatically = 0; size_t sorted_block_size = (sort_nelems >= 512) ? 512 : determine_automatically; - sycl::buffer indexed_data( - sycl::range<1>(iter_nelems * sort_nelems)); - sycl::buffer temp_buf( + sycl::buffer index_data( sycl::range<1>(iter_nelems * sort_nelems)); sycl::event populate_indexed_data_ev = exec_q.submit([&](sycl::handler &cgh) { cgh.depends_on(depends); - sycl::accessor acc(indexed_data, cgh, sycl::write_only, + sycl::accessor acc(index_data, cgh, sycl::write_only, sycl::no_init); - auto const &range = indexed_data.get_range(); + auto const &range = index_data.get_range(); using KernelName = - populate_indexed_data_krn; + populate_index_data_krn; cgh.parallel_for(range, [=](sycl::id<1> id) { size_t i = id[0]; - size_t sort_id = i % sort_nelems; - acc[i] = - std::make_pair(arg_tp[i], static_cast(sort_id)); + acc[i] = static_cast(i); }); }); // Sort segments of the array sycl::event base_sort_ev = sort_detail::sort_over_work_group_contig_impl( - exec_q, iter_nelems, sort_nelems, indexed_data, temp_buf, tuple_comp, + exec_q, iter_nelems, sort_nelems, index_data, res_tp, index_comp, sorted_block_size, // modified in place with size of sorted block size {populate_indexed_data_ev}); // Merge segments in parallel until all elements are sorted sycl::event merges_ev = sort_detail::merge_sorted_block_contig_impl( - exec_q, iter_nelems, sort_nelems, temp_buf, tuple_comp, - sorted_block_size, {base_sort_ev}); + exec_q, iter_nelems, sort_nelems, res_tp, index_comp, sorted_block_size, + {base_sort_ev}); sycl::event write_out_ev = exec_q.submit([&](sycl::handler &cgh) { cgh.depends_on(merges_ev); auto temp_acc = - sort_detail::GetReadOnlyAccess{}(temp_buf, cgh); + sort_detail::GetReadOnlyAccess{}(res_tp, cgh); - using KernelName = index_write_out_krn; + using KernelName = index_map_to_rows_krn; - cgh.parallel_for(temp_buf.get_range(), [=](sycl::id<1> id) { - res_tp[id] = std::get<1>(temp_acc[id]); - }); + cgh.parallel_for( + index_data.get_range(), + [=](sycl::id<1> id) { res_tp[id] = (temp_acc[id] % sort_nelems); }); }); return write_out_ev; From dc45158b38002a67c14f2fa3046bfc5787db652c Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Wed, 9 Oct 2024 08:36:35 -0500 Subject: [PATCH 3/4] Further improvement to argsort Eliminate use of temporary allocation altogether, cutting argsort execution time from 116 ms to 110 ms for 5670000 element array of type int32_t. --- .../include/kernels/sorting/sort.hpp | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/dpctl/tensor/libtensor/include/kernels/sorting/sort.hpp b/dpctl/tensor/libtensor/include/kernels/sorting/sort.hpp index e941daf6e1..28db00facd 100644 --- a/dpctl/tensor/libtensor/include/kernels/sorting/sort.hpp +++ b/dpctl/tensor/libtensor/include/kernels/sorting/sort.hpp @@ -817,29 +817,26 @@ sycl::event stable_argsort_axis1_contig_impl( size_t sorted_block_size = (sort_nelems >= 512) ? 512 : determine_automatically; - sycl::buffer index_data( - sycl::range<1>(iter_nelems * sort_nelems)); + const size_t total_nelems = iter_nelems * sort_nelems; sycl::event populate_indexed_data_ev = exec_q.submit([&](sycl::handler &cgh) { cgh.depends_on(depends); - sycl::accessor acc(index_data, cgh, sycl::write_only, - sycl::no_init); - auto const &range = index_data.get_range(); + const sycl::range<1> range{total_nelems}; using KernelName = populate_index_data_krn; cgh.parallel_for(range, [=](sycl::id<1> id) { size_t i = id[0]; - acc[i] = static_cast(i); + res_tp[i] = static_cast(i); }); }); // Sort segments of the array sycl::event base_sort_ev = sort_detail::sort_over_work_group_contig_impl( - exec_q, iter_nelems, sort_nelems, index_data, res_tp, index_comp, + exec_q, iter_nelems, sort_nelems, res_tp, res_tp, index_comp, sorted_block_size, // modified in place with size of sorted block size {populate_indexed_data_ev}); @@ -856,9 +853,11 @@ sycl::event stable_argsort_axis1_contig_impl( using KernelName = index_map_to_rows_krn; - cgh.parallel_for( - index_data.get_range(), - [=](sycl::id<1> id) { res_tp[id] = (temp_acc[id] % sort_nelems); }); + const sycl::range<1> range{total_nelems}; + + cgh.parallel_for(range, [=](sycl::id<1> id) { + res_tp[id] = (temp_acc[id] % sort_nelems); + }); }); return write_out_ev; From aeb1b1ffaee2f4cecc7829e1f997305c5dceb0bc Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Wed, 9 Oct 2024 08:59:45 -0500 Subject: [PATCH 4/4] Add changelog entry for improvement to argsort performance --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index bafc653489..2e4dcea993 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 * Improved performance of copy-and-cast operations from `numpy.ndarray` to `tensor.usm_ndarray` for contiguous inputs [gh-1829](https://github.com/IntelPython/dpctl/pull/1829) * Improved performance of copying operation to C-/F-contig array, with optimization for batch of square matrices [gh-1850](https://github.com/IntelPython/dpctl/pull/1850) +* Improved performance of `tensor.argsort` function for all types [gh-1859](https://github.com/IntelPython/dpctl/pull/1859) ### Fixed