Skip to content

Commit 0ebb16f

Browse files
Change implementation of argsort
Instead of implementing argsort as sort over structures (index, value), with subsequent projection to index, it is now implemented as sort over linear indices themselves, with dereferencing comparator, and subsequent mapping from linear index to row-wise index. On Iris Xe, argsort call took 215 ms to argsort 5670000 elements of type int32, and it now takes 117 ms. The new implementation also has smaller temporary allocation footprint. Previously, it would allocate 2*(sizeof(ValueT) + sizeof(IndexT)), now it only allocates sizeof(IndexT) for storing linear indices.
1 parent 9061c37 commit 0ebb16f

File tree

1 file changed

+28
-26
lines changed
  • dpctl/tensor/libtensor/include/kernels/sorting

1 file changed

+28
-26
lines changed

dpctl/tensor/libtensor/include/kernels/sorting/sort.hpp

Lines changed: 28 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -768,18 +768,25 @@ sycl::event stable_sort_axis1_contig_impl(
768768
}
769769
}
770770

771-
template <typename T1, typename T2, typename T3>
772-
class populate_indexed_data_krn;
771+
template <typename T1, typename T2, typename T3> class populate_index_data_krn;
773772

774-
template <typename T1, typename T2, typename T3> class index_write_out_krn;
773+
template <typename T1, typename T2, typename T3> class index_map_to_rows_krn;
775774

776-
template <typename pairT, typename ValueComp> struct TupleComp
775+
template <typename IndexT, typename ValueT, typename ValueComp> struct IndexComp
777776
{
778-
bool operator()(const pairT &p1, const pairT &p2) const
777+
IndexComp(const ValueT *data, const ValueComp &comp_op)
778+
: ptr(data), value_comp(comp_op)
779779
{
780-
const ValueComp value_comp{};
781-
return value_comp(std::get<0>(p1), std::get<0>(p2));
782780
}
781+
782+
bool operator()(const IndexT &i1, const IndexT &i2) const
783+
{
784+
return value_comp(ptr[i1], ptr[i2]);
785+
}
786+
787+
private:
788+
const ValueT *ptr;
789+
ValueComp value_comp;
783790
};
784791

785792
template <typename argTy,
@@ -804,59 +811,54 @@ sycl::event stable_argsort_axis1_contig_impl(
804811
IndexTy *res_tp =
805812
reinterpret_cast<IndexTy *>(res_cp) + iter_res_offset + sort_res_offset;
806813

807-
using ValueIndexT = std::pair<argTy, IndexTy>;
808-
const TupleComp<ValueIndexT, ValueComp> tuple_comp{};
814+
const IndexComp<IndexTy, argTy, ValueComp> index_comp{arg_tp, ValueComp{}};
809815

810816
static constexpr size_t determine_automatically = 0;
811817
size_t sorted_block_size =
812818
(sort_nelems >= 512) ? 512 : determine_automatically;
813819

814-
sycl::buffer<ValueIndexT, 1> indexed_data(
815-
sycl::range<1>(iter_nelems * sort_nelems));
816-
sycl::buffer<ValueIndexT, 1> temp_buf(
820+
sycl::buffer<IndexTy, 1> index_data(
817821
sycl::range<1>(iter_nelems * sort_nelems));
818822

819823
sycl::event populate_indexed_data_ev =
820824
exec_q.submit([&](sycl::handler &cgh) {
821825
cgh.depends_on(depends);
822-
sycl::accessor acc(indexed_data, cgh, sycl::write_only,
826+
sycl::accessor acc(index_data, cgh, sycl::write_only,
823827
sycl::no_init);
824828

825-
auto const &range = indexed_data.get_range();
829+
auto const &range = index_data.get_range();
826830

827831
using KernelName =
828-
populate_indexed_data_krn<argTy, IndexTy, ValueComp>;
832+
populate_index_data_krn<argTy, IndexTy, ValueComp>;
829833

830834
cgh.parallel_for<KernelName>(range, [=](sycl::id<1> id) {
831835
size_t i = id[0];
832-
size_t sort_id = i % sort_nelems;
833-
acc[i] =
834-
std::make_pair(arg_tp[i], static_cast<IndexTy>(sort_id));
836+
acc[i] = static_cast<IndexTy>(i);
835837
});
836838
});
837839

838840
// Sort segments of the array
839841
sycl::event base_sort_ev = sort_detail::sort_over_work_group_contig_impl(
840-
exec_q, iter_nelems, sort_nelems, indexed_data, temp_buf, tuple_comp,
842+
exec_q, iter_nelems, sort_nelems, index_data, res_tp, index_comp,
841843
sorted_block_size, // modified in place with size of sorted block size
842844
{populate_indexed_data_ev});
843845

844846
// Merge segments in parallel until all elements are sorted
845847
sycl::event merges_ev = sort_detail::merge_sorted_block_contig_impl(
846-
exec_q, iter_nelems, sort_nelems, temp_buf, tuple_comp,
847-
sorted_block_size, {base_sort_ev});
848+
exec_q, iter_nelems, sort_nelems, res_tp, index_comp, sorted_block_size,
849+
{base_sort_ev});
848850

849851
sycl::event write_out_ev = exec_q.submit([&](sycl::handler &cgh) {
850852
cgh.depends_on(merges_ev);
851853

852854
auto temp_acc =
853-
sort_detail::GetReadOnlyAccess<decltype(temp_buf)>{}(temp_buf, cgh);
855+
sort_detail::GetReadOnlyAccess<decltype(res_tp)>{}(res_tp, cgh);
854856

855-
using KernelName = index_write_out_krn<argTy, IndexTy, ValueComp>;
857+
using KernelName = index_map_to_rows_krn<argTy, IndexTy, ValueComp>;
856858

857-
cgh.parallel_for<KernelName>(temp_buf.get_range(), [=](sycl::id<1> id) {
858-
res_tp[id] = std::get<1>(temp_acc[id]);
859-
});
859+
cgh.parallel_for<KernelName>(
860+
index_data.get_range(),
861+
[=](sycl::id<1> id) { res_tp[id] = (temp_acc[id] % sort_nelems); });
860862
});
861863

862864
return write_out_ev;

0 commit comments

Comments
 (0)