Skip to content

Commit 0985d6e

Browse files
Remove unused radix_argsort_impl
Renamed radix_argsort_alt_impl to radux_argsort_impl. Also renamed associated kernels
1 parent 09236c9 commit 0985d6e

File tree

2 files changed

+10
-95
lines changed

2 files changed

+10
-95
lines changed

dpctl/tensor/libtensor/include/kernels/sorting/radix_sort.hpp

Lines changed: 8 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -1748,20 +1748,19 @@ radix_sort_axis1_contig_impl(sycl::queue &exec_q,
17481748
}
17491749

17501750
template <typename ValueT, typename IndexT>
1751-
class populate_indexed_data_for_radix_sort_krn;
1751+
class radix_argsort_index_write_out_krn;
17521752

1753-
template <typename ValueT, typename IndexT>
1754-
class index_write_out_for_radix_sort_krn;
1753+
template <typename ValueT, typename IndexT> class radix_argsort_iota_krn;
17551754

17561755
template <typename argTy, typename IndexTy>
17571756
sycl::event
17581757
radix_argsort_axis1_contig_impl(sycl::queue &exec_q,
17591758
const bool sort_ascending,
1760-
// number of sub-arrays to sort (num. of rows in
1761-
// a matrix when sorting over rows)
1759+
// number of sub-arrays to sort (num. of
1760+
// rows in a matrix when sorting over rows)
17621761
size_t iter_nelems,
1763-
// size of each array to sort (length of rows,
1764-
// i.e. number of columns)
1762+
// size of each array to sort (length of
1763+
// rows, i.e. number of columns)
17651764
size_t sort_nelems,
17661765
const char *arg_cp,
17671766
char *res_cp,
@@ -1776,90 +1775,6 @@ radix_argsort_axis1_contig_impl(sycl::queue &exec_q,
17761775
IndexTy *res_tp =
17771776
reinterpret_cast<IndexTy *>(res_cp) + iter_res_offset + sort_res_offset;
17781777

1779-
using ValueIndexT = std::pair<argTy, IndexTy>;
1780-
1781-
const std::size_t total_nelems = iter_nelems * sort_nelems;
1782-
const std::size_t padded_total_nelems = ((total_nelems + 63) / 64) * 64;
1783-
ValueIndexT *workspace = sycl::malloc_device<ValueIndexT>(
1784-
padded_total_nelems + total_nelems, exec_q);
1785-
1786-
if (nullptr == workspace) {
1787-
throw std::runtime_error("Could not allocate workspace on device");
1788-
}
1789-
1790-
ValueIndexT *indexed_data_tp = workspace;
1791-
ValueIndexT *temp_tp = workspace + padded_total_nelems;
1792-
1793-
using Proj = radix_sort_details::ValueProj<argTy, IndexTy>;
1794-
constexpr Proj proj_op{};
1795-
1796-
sycl::event populate_indexed_data_ev =
1797-
exec_q.submit([&](sycl::handler &cgh) {
1798-
cgh.depends_on(depends);
1799-
1800-
using KernelName =
1801-
populate_indexed_data_for_radix_sort_krn<argTy, IndexTy>;
1802-
1803-
cgh.parallel_for<KernelName>(
1804-
sycl::range<1>(total_nelems), [=](sycl::id<1> id) {
1805-
size_t i = id[0];
1806-
IndexTy sort_id = static_cast<IndexTy>(i % sort_nelems);
1807-
indexed_data_tp[i] = std::make_pair(arg_tp[i], sort_id);
1808-
});
1809-
});
1810-
1811-
sycl::event radix_sort_ev =
1812-
radix_sort_details::parallel_radix_sort_impl<ValueIndexT, Proj>(
1813-
exec_q, iter_nelems, sort_nelems, indexed_data_tp, temp_tp, proj_op,
1814-
sort_ascending, {populate_indexed_data_ev});
1815-
1816-
sycl::event write_out_ev = exec_q.submit([&](sycl::handler &cgh) {
1817-
cgh.depends_on(radix_sort_ev);
1818-
1819-
using KernelName = index_write_out_for_radix_sort_krn<argTy, IndexTy>;
1820-
1821-
cgh.parallel_for<KernelName>(
1822-
sycl::range<1>(total_nelems),
1823-
[=](sycl::id<1> id) { res_tp[id] = std::get<1>(temp_tp[id]); });
1824-
});
1825-
1826-
sycl::event cleanup_ev = exec_q.submit([&](sycl::handler &cgh) {
1827-
cgh.depends_on(write_out_ev);
1828-
1829-
const sycl::context &ctx = exec_q.get_context();
1830-
1831-
using dpctl::tensor::alloc_utils::sycl_free_noexcept;
1832-
cgh.host_task([ctx, workspace] { sycl_free_noexcept(workspace, ctx); });
1833-
});
1834-
1835-
return cleanup_ev;
1836-
}
1837-
1838-
template <typename ValueT, typename IndexT> class iota_for_radix_sort_krn;
1839-
1840-
template <typename argTy, typename IndexTy>
1841-
sycl::event
1842-
radix_argsort_axis1_contig_alt_impl(sycl::queue &exec_q,
1843-
const bool sort_ascending,
1844-
// number of sub-arrays to sort (num. of
1845-
// rows in a matrix when sorting over rows)
1846-
size_t iter_nelems,
1847-
// size of each array to sort (length of
1848-
// rows, i.e. number of columns)
1849-
size_t sort_nelems,
1850-
const char *arg_cp,
1851-
char *res_cp,
1852-
ssize_t iter_arg_offset,
1853-
ssize_t iter_res_offset,
1854-
ssize_t sort_arg_offset,
1855-
ssize_t sort_res_offset,
1856-
const std::vector<sycl::event> &depends)
1857-
{
1858-
const argTy *arg_tp = reinterpret_cast<const argTy *>(arg_cp) +
1859-
iter_arg_offset + sort_arg_offset;
1860-
IndexTy *res_tp =
1861-
reinterpret_cast<IndexTy *>(res_cp) + iter_res_offset + sort_res_offset;
1862-
18631778
const std::size_t total_nelems = iter_nelems * sort_nelems;
18641779
const std::size_t padded_total_nelems = ((total_nelems + 63) / 64) * 64;
18651780
IndexTy *workspace = sycl::malloc_device<IndexTy>(
@@ -1877,7 +1792,7 @@ radix_argsort_axis1_contig_alt_impl(sycl::queue &exec_q,
18771792
sycl::event iota_ev = exec_q.submit([&](sycl::handler &cgh) {
18781793
cgh.depends_on(depends);
18791794

1880-
using KernelName = iota_for_radix_sort_krn<argTy, IndexTy>;
1795+
using KernelName = radix_argsort_iota_krn<argTy, IndexTy>;
18811796

18821797
cgh.parallel_for<KernelName>(
18831798
sycl::range<1>(total_nelems), [=](sycl::id<1> id) {
@@ -1895,7 +1810,7 @@ radix_argsort_axis1_contig_alt_impl(sycl::queue &exec_q,
18951810
sycl::event map_back_ev = exec_q.submit([&](sycl::handler &cgh) {
18961811
cgh.depends_on(radix_sort_ev);
18971812

1898-
using KernelName = index_write_out_for_radix_sort_krn<argTy, IndexTy>;
1813+
using KernelName = radix_argsort_index_write_out_krn<argTy, IndexTy>;
18991814

19001815
cgh.parallel_for<KernelName>(
19011816
sycl::range<1>(total_nelems), [=](sycl::id<1> id) {

dpctl/tensor/libtensor/source/sorting/radix_argsort.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,9 +80,9 @@ sycl::event argsort_axis1_contig_caller(sycl::queue &q,
8080
ssize_t sort_res_offset,
8181
const std::vector<sycl::event> &depends)
8282
{
83-
using dpctl::tensor::kernels::radix_argsort_axis1_contig_alt_impl;
83+
using dpctl::tensor::kernels::radix_argsort_axis1_contig_impl;
8484

85-
return radix_argsort_axis1_contig_alt_impl<T, I>(
85+
return radix_argsort_axis1_contig_impl<T, I>(
8686
q, is_ascending, iter_nelems, sort_nelems, arg_cp, res_cp,
8787
iter_arg_offset, iter_res_offset, sort_arg_offset, sort_res_offset,
8888
depends);

0 commit comments

Comments
 (0)