Skip to content

Commit dc45158

Browse files
Further improvement to argsort
Eliminate the use of a temporary allocation altogether, cutting argsort execution time from 116 ms to 110 ms for a 5,670,000-element array of type int32_t.
1 parent 0ebb16f commit dc45158

File tree

1 file changed

+9
-10
lines changed
  • dpctl/tensor/libtensor/include/kernels/sorting

1 file changed

+9
-10
lines changed

dpctl/tensor/libtensor/include/kernels/sorting/sort.hpp

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -817,29 +817,26 @@ sycl::event stable_argsort_axis1_contig_impl(
817817
size_t sorted_block_size =
818818
(sort_nelems >= 512) ? 512 : determine_automatically;
819819

820-
sycl::buffer<IndexTy, 1> index_data(
821-
sycl::range<1>(iter_nelems * sort_nelems));
820+
const size_t total_nelems = iter_nelems * sort_nelems;
822821

823822
sycl::event populate_indexed_data_ev =
824823
exec_q.submit([&](sycl::handler &cgh) {
825824
cgh.depends_on(depends);
826-
sycl::accessor acc(index_data, cgh, sycl::write_only,
827-
sycl::no_init);
828825

829-
auto const &range = index_data.get_range();
826+
const sycl::range<1> range{total_nelems};
830827

831828
using KernelName =
832829
populate_index_data_krn<argTy, IndexTy, ValueComp>;
833830

834831
cgh.parallel_for<KernelName>(range, [=](sycl::id<1> id) {
835832
size_t i = id[0];
836-
acc[i] = static_cast<IndexTy>(i);
833+
res_tp[i] = static_cast<IndexTy>(i);
837834
});
838835
});
839836

840837
// Sort segments of the array
841838
sycl::event base_sort_ev = sort_detail::sort_over_work_group_contig_impl(
842-
exec_q, iter_nelems, sort_nelems, index_data, res_tp, index_comp,
839+
exec_q, iter_nelems, sort_nelems, res_tp, res_tp, index_comp,
843840
sorted_block_size, // modified in place with size of sorted block size
844841
{populate_indexed_data_ev});
845842

@@ -856,9 +853,11 @@ sycl::event stable_argsort_axis1_contig_impl(
856853

857854
using KernelName = index_map_to_rows_krn<argTy, IndexTy, ValueComp>;
858855

859-
cgh.parallel_for<KernelName>(
860-
index_data.get_range(),
861-
[=](sycl::id<1> id) { res_tp[id] = (temp_acc[id] % sort_nelems); });
856+
const sycl::range<1> range{total_nelems};
857+
858+
cgh.parallel_for<KernelName>(range, [=](sycl::id<1> id) {
859+
res_tp[id] = (temp_acc[id] % sort_nelems);
860+
});
862861
});
863862

864863
return write_out_ev;

0 commit comments

Comments
 (0)