Skip to content

Commit 9a7868d

Browse files
Do not invoke map_back kernel if iter_nelems is 1
The map_back operation is a no-op then.
1 parent 082fca1 commit 9a7868d

File tree

2 files changed

+13
-3
lines changed

2 files changed

+13
-3
lines changed

dpctl/tensor/libtensor/include/kernels/sorting/merge_sort.hpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -814,6 +814,10 @@ sycl::event stable_argsort_axis1_contig_impl(
814814
exec_q, iter_nelems, sort_nelems, res_tp, index_comp, sorted_block_size,
815815
{base_sort_ev});
816816

817+
// no need to map back if iter_nelems == 1
818+
if (iter_nelems == 1u)
819+
return merges_ev;
820+
817821
using MapBackKernelName = index_map_to_rows_krn<argTy, IndexTy>;
818822
using dpctl::tensor::kernels::sort_utils_detail::map_back_impl;
819823

dpctl/tensor/libtensor/include/kernels/sorting/radix_sort.hpp

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1867,11 +1867,17 @@ radix_argsort_axis1_contig_impl(sycl::queue &exec_q,
18671867
using MapBackKernelName = radix_argsort_index_write_out_krn<argTy, IndexTy>;
18681868
using dpctl::tensor::kernels::sort_utils_detail::map_back_impl;
18691869

1870-
sycl::event map_back_ev = map_back_impl<MapBackKernelName, IndexTy>(
1871-
exec_q, total_nelems, res_tp, res_tp, sort_nelems, {radix_sort_ev});
1870+
sycl::event dep = radix_sort_ev;
1871+
1872+
// no need to perform map_back ( id % sort_nelems)
1873+
// if total_nelems == sort_nelems
1874+
if (iter_nelems > 1u) {
1875+
dep = map_back_impl<MapBackKernelName, IndexTy>(
1876+
exec_q, total_nelems, res_tp, res_tp, sort_nelems, {dep});
1877+
}
18721878

18731879
sycl::event cleanup_ev = dpctl::tensor::alloc_utils::async_smart_free(
1874-
exec_q, {map_back_ev}, workspace_owner);
1880+
exec_q, {dep}, workspace_owner);
18751881

18761882
return cleanup_ev;
18771883
}

0 commit comments

Comments
 (0)