@@ -768,18 +768,25 @@ sycl::event stable_sort_axis1_contig_impl(
768
768
}
769
769
}
770
770
771
- template <typename T1, typename T2, typename T3>
772
- class populate_indexed_data_krn ;
771
+ template <typename T1, typename T2, typename T3> class populate_index_data_krn ;
773
772
774
- template <typename T1, typename T2, typename T3> class index_write_out_krn ;
773
+ template <typename T1, typename T2, typename T3> class index_map_to_rows_krn ;
775
774
776
- template <typename pairT , typename ValueComp> struct TupleComp
775
+ template <typename IndexT , typename ValueT, typename ValueComp> struct IndexComp
777
776
{
778
- bool operator ()(const pairT &p1, const pairT &p2) const
777
+ IndexComp (const ValueT *data, const ValueComp &comp_op)
778
+ : ptr(data), value_comp(comp_op)
779
779
{
780
- const ValueComp value_comp{};
781
- return value_comp (std::get<0 >(p1), std::get<0 >(p2));
782
780
}
781
+
782
+ bool operator ()(const IndexT &i1, const IndexT &i2) const
783
+ {
784
+ return value_comp (ptr[i1], ptr[i2]);
785
+ }
786
+
787
+ private:
788
+ const ValueT *ptr;
789
+ ValueComp value_comp;
783
790
};
784
791
785
792
template <typename argTy,
@@ -804,59 +811,54 @@ sycl::event stable_argsort_axis1_contig_impl(
804
811
IndexTy *res_tp =
805
812
reinterpret_cast <IndexTy *>(res_cp) + iter_res_offset + sort_res_offset;
806
813
807
- using ValueIndexT = std::pair<argTy, IndexTy>;
808
- const TupleComp<ValueIndexT, ValueComp> tuple_comp{};
814
+ const IndexComp<IndexTy, argTy, ValueComp> index_comp{arg_tp, ValueComp{}};
809
815
810
816
static constexpr size_t determine_automatically = 0 ;
811
817
size_t sorted_block_size =
812
818
(sort_nelems >= 512 ) ? 512 : determine_automatically;
813
819
814
- sycl::buffer<ValueIndexT, 1 > indexed_data (
815
- sycl::range<1 >(iter_nelems * sort_nelems));
816
- sycl::buffer<ValueIndexT, 1 > temp_buf (
820
+ sycl::buffer<IndexTy, 1 > index_data (
817
821
sycl::range<1 >(iter_nelems * sort_nelems));
818
822
819
823
sycl::event populate_indexed_data_ev =
820
824
exec_q.submit ([&](sycl::handler &cgh) {
821
825
cgh.depends_on (depends);
822
- sycl::accessor acc (indexed_data , cgh, sycl::write_only,
826
+ sycl::accessor acc (index_data , cgh, sycl::write_only,
823
827
sycl::no_init);
824
828
825
- auto const &range = indexed_data .get_range ();
829
+ auto const &range = index_data .get_range ();
826
830
827
831
using KernelName =
828
- populate_indexed_data_krn <argTy, IndexTy, ValueComp>;
832
+ populate_index_data_krn <argTy, IndexTy, ValueComp>;
829
833
830
834
cgh.parallel_for <KernelName>(range, [=](sycl::id<1 > id) {
831
835
size_t i = id[0 ];
832
- size_t sort_id = i % sort_nelems;
833
- acc[i] =
834
- std::make_pair (arg_tp[i], static_cast <IndexTy>(sort_id));
836
+ acc[i] = static_cast <IndexTy>(i);
835
837
});
836
838
});
837
839
838
840
// Sort segments of the array
839
841
sycl::event base_sort_ev = sort_detail::sort_over_work_group_contig_impl (
840
- exec_q, iter_nelems, sort_nelems, indexed_data, temp_buf, tuple_comp ,
842
+ exec_q, iter_nelems, sort_nelems, index_data, res_tp, index_comp ,
841
843
sorted_block_size, // modified in place with size of sorted block size
842
844
{populate_indexed_data_ev});
843
845
844
846
// Merge segments in parallel until all elements are sorted
845
847
sycl::event merges_ev = sort_detail::merge_sorted_block_contig_impl (
846
- exec_q, iter_nelems, sort_nelems, temp_buf, tuple_comp ,
847
- sorted_block_size, {base_sort_ev});
848
+ exec_q, iter_nelems, sort_nelems, res_tp, index_comp, sorted_block_size ,
849
+ {base_sort_ev});
848
850
849
851
sycl::event write_out_ev = exec_q.submit ([&](sycl::handler &cgh) {
850
852
cgh.depends_on (merges_ev);
851
853
852
854
auto temp_acc =
853
- sort_detail::GetReadOnlyAccess<decltype (temp_buf )>{}(temp_buf , cgh);
855
+ sort_detail::GetReadOnlyAccess<decltype (res_tp )>{}(res_tp , cgh);
854
856
855
- using KernelName = index_write_out_krn <argTy, IndexTy, ValueComp>;
857
+ using KernelName = index_map_to_rows_krn <argTy, IndexTy, ValueComp>;
856
858
857
- cgh.parallel_for <KernelName>(temp_buf. get_range (), [=](sycl::id< 1 > id) {
858
- res_tp[id] = std::get< 1 >(temp_acc[id]);
859
- });
859
+ cgh.parallel_for <KernelName>(
860
+ index_data. get_range (),
861
+ [=](sycl::id< 1 > id) { res_tp[id] = (temp_acc[id] % sort_nelems); });
860
862
});
861
863
862
864
return write_out_ev;
0 commit comments