@@ -1748,20 +1748,19 @@ radix_sort_axis1_contig_impl(sycl::queue &exec_q,
1748
1748
}
1749
1749
1750
1750
template <typename ValueT, typename IndexT>
1751
- class populate_indexed_data_for_radix_sort_krn ;
1751
+ class radix_argsort_index_write_out_krn ;
1752
1752
1753
- template <typename ValueT, typename IndexT>
1754
- class index_write_out_for_radix_sort_krn ;
1753
+ template <typename ValueT, typename IndexT> class radix_argsort_iota_krn ;
1755
1754
1756
1755
template <typename argTy, typename IndexTy>
1757
1756
sycl::event
1758
1757
radix_argsort_axis1_contig_impl (sycl::queue &exec_q,
1759
1758
const bool sort_ascending,
1760
- // number of sub-arrays to sort (num. of rows in
1761
- // a matrix when sorting over rows)
1759
+ // number of sub-arrays to sort (num. of
1760
+ // rows in a matrix when sorting over rows)
1762
1761
size_t iter_nelems,
1763
- // size of each array to sort (length of rows,
1764
- // i.e. number of columns)
1762
+ // size of each array to sort (length of
1763
+ // rows, i.e. number of columns)
1765
1764
size_t sort_nelems,
1766
1765
const char *arg_cp,
1767
1766
char *res_cp,
@@ -1776,90 +1775,6 @@ radix_argsort_axis1_contig_impl(sycl::queue &exec_q,
1776
1775
IndexTy *res_tp =
1777
1776
reinterpret_cast <IndexTy *>(res_cp) + iter_res_offset + sort_res_offset;
1778
1777
1779
- using ValueIndexT = std::pair<argTy, IndexTy>;
1780
-
1781
- const std::size_t total_nelems = iter_nelems * sort_nelems;
1782
- const std::size_t padded_total_nelems = ((total_nelems + 63 ) / 64 ) * 64 ;
1783
- ValueIndexT *workspace = sycl::malloc_device<ValueIndexT>(
1784
- padded_total_nelems + total_nelems, exec_q);
1785
-
1786
- if (nullptr == workspace) {
1787
- throw std::runtime_error (" Could not allocate workspace on device" );
1788
- }
1789
-
1790
- ValueIndexT *indexed_data_tp = workspace;
1791
- ValueIndexT *temp_tp = workspace + padded_total_nelems;
1792
-
1793
- using Proj = radix_sort_details::ValueProj<argTy, IndexTy>;
1794
- constexpr Proj proj_op{};
1795
-
1796
- sycl::event populate_indexed_data_ev =
1797
- exec_q.submit ([&](sycl::handler &cgh) {
1798
- cgh.depends_on (depends);
1799
-
1800
- using KernelName =
1801
- populate_indexed_data_for_radix_sort_krn<argTy, IndexTy>;
1802
-
1803
- cgh.parallel_for <KernelName>(
1804
- sycl::range<1 >(total_nelems), [=](sycl::id<1 > id) {
1805
- size_t i = id[0 ];
1806
- IndexTy sort_id = static_cast <IndexTy>(i % sort_nelems);
1807
- indexed_data_tp[i] = std::make_pair (arg_tp[i], sort_id);
1808
- });
1809
- });
1810
-
1811
- sycl::event radix_sort_ev =
1812
- radix_sort_details::parallel_radix_sort_impl<ValueIndexT, Proj>(
1813
- exec_q, iter_nelems, sort_nelems, indexed_data_tp, temp_tp, proj_op,
1814
- sort_ascending, {populate_indexed_data_ev});
1815
-
1816
- sycl::event write_out_ev = exec_q.submit ([&](sycl::handler &cgh) {
1817
- cgh.depends_on (radix_sort_ev);
1818
-
1819
- using KernelName = index_write_out_for_radix_sort_krn<argTy, IndexTy>;
1820
-
1821
- cgh.parallel_for <KernelName>(
1822
- sycl::range<1 >(total_nelems),
1823
- [=](sycl::id<1 > id) { res_tp[id] = std::get<1 >(temp_tp[id]); });
1824
- });
1825
-
1826
- sycl::event cleanup_ev = exec_q.submit ([&](sycl::handler &cgh) {
1827
- cgh.depends_on (write_out_ev);
1828
-
1829
- const sycl::context &ctx = exec_q.get_context ();
1830
-
1831
- using dpctl::tensor::alloc_utils::sycl_free_noexcept;
1832
- cgh.host_task ([ctx, workspace] { sycl_free_noexcept (workspace, ctx); });
1833
- });
1834
-
1835
- return cleanup_ev;
1836
- }
1837
-
1838
- template <typename ValueT, typename IndexT> class iota_for_radix_sort_krn ;
1839
-
1840
- template <typename argTy, typename IndexTy>
1841
- sycl::event
1842
- radix_argsort_axis1_contig_alt_impl (sycl::queue &exec_q,
1843
- const bool sort_ascending,
1844
- // number of sub-arrays to sort (num. of
1845
- // rows in a matrix when sorting over rows)
1846
- size_t iter_nelems,
1847
- // size of each array to sort (length of
1848
- // rows, i.e. number of columns)
1849
- size_t sort_nelems,
1850
- const char *arg_cp,
1851
- char *res_cp,
1852
- ssize_t iter_arg_offset,
1853
- ssize_t iter_res_offset,
1854
- ssize_t sort_arg_offset,
1855
- ssize_t sort_res_offset,
1856
- const std::vector<sycl::event> &depends)
1857
- {
1858
- const argTy *arg_tp = reinterpret_cast <const argTy *>(arg_cp) +
1859
- iter_arg_offset + sort_arg_offset;
1860
- IndexTy *res_tp =
1861
- reinterpret_cast <IndexTy *>(res_cp) + iter_res_offset + sort_res_offset;
1862
-
1863
1778
const std::size_t total_nelems = iter_nelems * sort_nelems;
1864
1779
const std::size_t padded_total_nelems = ((total_nelems + 63 ) / 64 ) * 64 ;
1865
1780
IndexTy *workspace = sycl::malloc_device<IndexTy>(
@@ -1877,7 +1792,7 @@ radix_argsort_axis1_contig_alt_impl(sycl::queue &exec_q,
1877
1792
sycl::event iota_ev = exec_q.submit ([&](sycl::handler &cgh) {
1878
1793
cgh.depends_on (depends);
1879
1794
1880
- using KernelName = iota_for_radix_sort_krn <argTy, IndexTy>;
1795
+ using KernelName = radix_argsort_iota_krn <argTy, IndexTy>;
1881
1796
1882
1797
cgh.parallel_for <KernelName>(
1883
1798
sycl::range<1 >(total_nelems), [=](sycl::id<1 > id) {
@@ -1895,7 +1810,7 @@ radix_argsort_axis1_contig_alt_impl(sycl::queue &exec_q,
1895
1810
sycl::event map_back_ev = exec_q.submit ([&](sycl::handler &cgh) {
1896
1811
cgh.depends_on (radix_sort_ev);
1897
1812
1898
- using KernelName = index_write_out_for_radix_sort_krn <argTy, IndexTy>;
1813
+ using KernelName = radix_argsort_index_write_out_krn <argTy, IndexTy>;
1899
1814
1900
1815
cgh.parallel_for <KernelName>(
1901
1816
sycl::range<1 >(total_nelems), [=](sycl::id<1 > id) {
0 commit comments