From bbe10191bddf0f2c72eb65c83eb5b8b1b3dd02d6 Mon Sep 17 00:00:00 2001
From: Oleksandr Pavlyk
Date: Tue, 29 Oct 2024 09:04:51 -0500
Subject: [PATCH 1/2] Pass sorting direction as a call argument, not a
 template parameter

The intent is to reduce the build time, the build memory footprint, and
the binary size of the sorting_impl module. With this change the module
stands at 46 MB, down from 72 MB.
---
 .../include/kernels/sorting/merge_sort.hpp    |   3 +-
 .../include/kernels/sorting/radix_sort.hpp    | 386 +++++++++++-------
 .../source/sorting/radix_argsort.cpp          |  36 +-
 .../libtensor/source/sorting/radix_sort.cpp   |  32 +-
 4 files changed, 295 insertions(+), 162 deletions(-)

diff --git a/dpctl/tensor/libtensor/include/kernels/sorting/merge_sort.hpp b/dpctl/tensor/libtensor/include/kernels/sorting/merge_sort.hpp
index 79f351bf51..f3b5030c48 100644
--- a/dpctl/tensor/libtensor/include/kernels/sorting/merge_sort.hpp
+++ b/dpctl/tensor/libtensor/include/kernels/sorting/merge_sort.hpp
@@ -807,8 +807,7 @@ sycl::event stable_argsort_axis1_contig_impl(
     const IndexComp index_comp{arg_tp, ValueComp{}};
 
     static constexpr size_t determine_automatically = 0;
-    size_t sorted_block_size =
-        (sort_nelems >= 512) ? 512 : determine_automatically;
+    size_t sorted_block_size = determine_automatically;
 
     const size_t total_nelems = iter_nelems * sort_nelems;
 
diff --git a/dpctl/tensor/libtensor/include/kernels/sorting/radix_sort.hpp b/dpctl/tensor/libtensor/include/kernels/sorting/radix_sort.hpp
index 3f834b4317..348ad72d13 100644
--- a/dpctl/tensor/libtensor/include/kernels/sorting/radix_sort.hpp
+++ b/dpctl/tensor/libtensor/include/kernels/sorting/radix_sort.hpp
@@ -49,16 +49,16 @@ namespace kernels
 namespace radix_sort_details
 {
 
-template <std::uint32_t, bool, typename... TrailingNames>
+template <std::uint32_t, typename... TrailingNames>
 class radix_sort_count_kernel;
 
 template <typename... TrailingNames>
 class radix_sort_scan_kernel;
 
-template <std::uint32_t, bool, typename... TrailingNames>
+template <std::uint32_t, typename... TrailingNames>
 class radix_sort_reorder_peer_kernel;
 
-template <std::uint32_t, bool, typename... TrailingNames>
+template <std::uint32_t, typename... TrailingNames>
 class radix_sort_reorder_kernel;
 
 //----------------------------------------------------------
@@ -223,7 +223,6 @@ std::uint32_t get_bucket_id(T val, std::uint32_t radix_offset)
 
 template <typename KernelNameT,
           std::uint32_t radix_bits,
-          bool is_ascending,
           typename ValueT,
           typename CountT,
           typename Proj>
@@ -238,6 +237,7 @@ radix_sort_count_submit(sycl::queue &exec_q,
                         std::size_t n_counts,
                         CountT *counts_ptr,
                         const Proj &proj_op,
+                        const bool is_ascending,
                         const std::vector<sycl::event> &dependency_events)
 {
     // bin_count = radix_states used for an array storing bucket state counters
             // count array
             const std::size_t seg_end = sycl::min(seg_start + elems_per_segment, n);
 
-            for (std::size_t val_id = seg_start + lid; val_id < seg_end;
-                 val_id += wg_size)
-            {
-                // get the bucket for the bit-ordered input value,
-                // applying the offset and mask for radix bits
-                const auto val = order_preserving_cast<is_ascending>(
-                    proj_op(vals_ptr[val_iter_offset + val_id]));
-                const std::uint32_t bucket_id =
-                    get_bucket_id(val, radix_offset);
-
-                // increment counter for this bit bucket
-                ++counts_arr[bucket_id];
+            if (is_ascending) {
+                for (std::size_t val_id = seg_start + lid; val_id < seg_end;
+                     val_id += wg_size)
+                {
+                    // get the bucket for the bit-ordered input value,
+                    // applying the offset and mask for radix bits
+                    const auto val =
+                        order_preserving_cast</*is_ascending*/ true>(
+                            proj_op(vals_ptr[val_iter_offset + val_id]));
+                    const std::uint32_t bucket_id =
+                        get_bucket_id(val, radix_offset);
+
+                    // increment counter for this bit bucket
+                    ++counts_arr[bucket_id];
+                }
+            }
+            else {
+                for (std::size_t val_id = seg_start + lid; val_id < seg_end;
+                     val_id += wg_size)
+                {
+                    // get the bucket for the bit-ordered input value,
+                    // applying the offset and mask for radix bits
+                    const auto val =
+                        order_preserving_cast</*is_ascending*/ false>(
+                            proj_op(vals_ptr[val_iter_offset + val_id]));
+                    const std::uint32_t bucket_id =
+                        get_bucket_id(val, radix_offset);
+
+                    // increment counter for this bit bucket
+                    ++counts_arr[bucket_id];
+                }
             }
 
             // count per work-item: write private count array to local count
@@ -622,7 +641,6 @@ void copy_func_for_radix_sort(const std::size_t n_segments,
 //-----------------------------------------------------------------------
 template <typename KernelNameT,
           std::uint32_t radix_bits,
-          bool is_ascending,
           peer_prefix_algo PeerAlgo,
           typename InputT,
           typename OutputT,
                           dependency_events)
 {
     using ValueT = InputT;
@@ -735,32 +754,65 @@ radix_sort_reorder_submit(sycl::queue &exec_q,
 
         // find offsets for the same values within a segment and fill the
         // resulting buffer
-        for (std::size_t val_id = seg_start + lid; val_id < seg_end;
-             val_id += sg_size)
-        {
-            ValueT in_val = std::move(b_input_ptr[val_id]);
+        if (is_ascending) {
+            for (std::size_t val_id = seg_start + lid; val_id < seg_end;
+                 val_id += sg_size)
+            {
+                ValueT in_val = std::move(b_input_ptr[val_id]);
 
-            // get the bucket for the bit-ordered input value, applying the
-            // offset and mask for radix bits
-            const auto mapped_val =
-                order_preserving_cast<is_ascending>(proj_op(in_val));
-            std::uint32_t bucket_id =
-                get_bucket_id(mapped_val, radix_offset);
+                // get the bucket for the bit-ordered input value, applying
+                // the offset and mask for radix bits
+                const auto mapped_val =
+                    order_preserving_cast</*is_ascending*/ true>(
+                        proj_op(in_val));
+                std::uint32_t bucket_id =
+                    get_bucket_id(mapped_val, radix_offset);
 
-            OffsetT new_offset_id = 0;
-            for (std::uint32_t radix_state_id = 0;
-                 radix_state_id < radix_states; ++radix_state_id)
+                OffsetT new_offset_id = 0;
+                for (std::uint32_t radix_state_id = 0;
+                     radix_state_id < radix_states; ++radix_state_id)
+                {
+                    bool is_current_bucket = (bucket_id == radix_state_id);
+                    std::uint32_t sg_total_offset =
+                        peer_prefix_hlp.peer_contribution(
+                            /* modified by reference */ new_offset_id,
+                            offset_arr[radix_state_id],
+                            /* bit contribution from this work-item */
+                            is_current_bucket);
+                    offset_arr[radix_state_id] += sg_total_offset;
+                }
+                b_output_ptr[new_offset_id] = std::move(in_val);
+            }
+        }
+        else {
+            for (std::size_t val_id = seg_start + lid; val_id < seg_end;
+                 val_id += sg_size)
             {
-                bool is_current_bucket = (bucket_id == radix_state_id);
-                std::uint32_t sg_total_offset =
-                    peer_prefix_hlp.peer_contribution(
-                        /* modified by reference */ new_offset_id,
-                        offset_arr[radix_state_id],
-                        /* bit contribution from this work-item */
-                        is_current_bucket);
-                offset_arr[radix_state_id] += sg_total_offset;
+                ValueT in_val = std::move(b_input_ptr[val_id]);
+
+                // get the bucket for the bit-ordered input value, applying
+                // the offset and mask for radix bits
+                const auto mapped_val =
+                    order_preserving_cast</*is_ascending*/ false>(
+                        proj_op(in_val));
+                std::uint32_t bucket_id =
+                    get_bucket_id(mapped_val, radix_offset);
+
+                OffsetT new_offset_id = 0;
+                for (std::uint32_t radix_state_id = 0;
+                     radix_state_id < radix_states; ++radix_state_id)
+                {
+                    bool is_current_bucket = (bucket_id == radix_state_id);
+                    std::uint32_t sg_total_offset =
+                        peer_prefix_hlp.peer_contribution(
+                            /* modified by reference */ new_offset_id,
+                            offset_arr[radix_state_id],
+                            /* bit contribution from this work-item */
+                            is_current_bucket);
+                    offset_arr[radix_state_id] += sg_total_offset;
+                }
+                b_output_ptr[new_offset_id] = std::move(in_val);
             }
-            b_output_ptr[new_offset_id] = std::move(in_val);
         }
         if (tail_size > 0) {
             ValueT in_val;
@@ -770,8 +822,13 @@ radix_sort_reorder_submit(sycl::queue &exec_q,
 
             if (lid < tail_size) {
                 in_val = std::move(b_input_ptr[seg_end + lid]);
+                const auto proj_val = proj_op(in_val);
                 const auto mapped_val =
-                    order_preserving_cast<is_ascending>(proj_op(in_val));
+                    (is_ascending)
+                        ? order_preserving_cast</*is_ascending*/ true>(
+                              proj_val)
+                        : order_preserving_cast</*is_ascending*/ false>(
+                              proj_val);
                 bucket_id = get_bucket_id(mapped_val, radix_offset);
             }
@@ -820,20 +877,18 @@ sizeT _slm_adjusted_work_group_size(sycl::queue &exec_q,
 // radix sort: one iteration
 //-----------------------------------------------------------------------
-template <std::uint32_t radix_bits, bool is_ascending, bool even>
+template <std::uint32_t radix_bits, bool even>
 struct parallel_radix_sort_iteration_step
 {
     template <typename... Names>
-    using count_phase =
-        radix_sort_count_kernel<radix_bits, is_ascending, Names...>;
+    using count_phase = radix_sort_count_kernel<radix_bits, Names...>;
     template <typename... Names>
     using local_scan_phase = radix_sort_scan_kernel<Names...>;
     template <typename... Names>
     using reorder_peer_phase =
-        radix_sort_reorder_peer_kernel<radix_bits, is_ascending, Names...>;
+        radix_sort_reorder_peer_kernel<radix_bits, Names...>;
     template <typename... Names>
-    using reorder_phase =
-        radix_sort_reorder_kernel<radix_bits, is_ascending, Names...>;
+    using reorder_phase = radix_sort_reorder_kernel<radix_bits, Names...>;
 
     template
                              &dependency_events)
     {
         using _RadixCountKernel = count_phase;
@@ -898,10 +954,9 @@ struct parallel_radix_sort_iteration_step
 
         // 1. Count Phase
         sycl::event count_ev =
-            radix_sort_count_submit<_RadixCountKernel, radix_bits,
-                                    is_ascending>(
+            radix_sort_count_submit<_RadixCountKernel, radix_bits>(
                 exec_q, n_iters, n_segments, count_wg_size, radix_offset,
-                n_values, in_ptr, n_counts, counts_ptr, proj_op,
+                n_values, in_ptr, n_counts, counts_ptr, proj_op, is_ascending,
                 dependency_events);
 
         // 2. Scan Phase
@@ -917,21 +972,21 @@ struct parallel_radix_sort_iteration_step
         {
             constexpr auto peer_algorithm = peer_prefix_algo::subgroup_ballot;
 
-            reorder_ev =
-                radix_sort_reorder_submit<_RadixReorderPeerKernel, radix_bits,
-                                          is_ascending, peer_algorithm>(
-                    exec_q, n_iters, n_segments, radix_offset, n_values, in_ptr,
-                    out_ptr, n_counts, counts_ptr, proj_op, {scan_ev});
+            reorder_ev = radix_sort_reorder_submit<_RadixReorderPeerKernel,
+                                                   radix_bits, peer_algorithm>(
+                exec_q, n_iters, n_segments, radix_offset, n_values, in_ptr,
+                out_ptr, n_counts, counts_ptr, proj_op, is_ascending,
+                {scan_ev});
         }
         else {
             constexpr auto peer_algorithm =
                 peer_prefix_algo::scan_then_broadcast;
 
-            reorder_ev =
-                radix_sort_reorder_submit<_RadixReorderKernel, radix_bits,
-                                          is_ascending, peer_algorithm>(
-                    exec_q, n_iters, n_segments, radix_offset, n_values, in_ptr,
-                    out_ptr, n_counts, counts_ptr, proj_op, {scan_ev});
+            reorder_ev = radix_sort_reorder_submit<_RadixReorderKernel,
+                                                   radix_bits, peer_algorithm>(
+                exec_q, n_iters, n_segments, radix_offset, n_values, in_ptr,
+                out_ptr, n_counts, counts_ptr, proj_op, is_ascending,
+                {scan_ev});
         }
 
         return reorder_ev;
@@ -945,7 +1000,6 @@
 template <typename KernelNameBase,
           std::uint16_t wg_size,
           std::uint16_t block_size,
-          std::uint32_t radix_bits,
-          bool is_ascending>
+          std::uint32_t radix_bits>
 struct subgroup_radix_sort
 {
@@ -965,6 +1019,7 @@
                        ValueT *input_ptr,
                        OutputT *output_ptr,
                        ProjT proj_op,
+                       const bool is_ascending,
                        const std::vector<sycl::event> &depends)
     {
         static_assert(std::is_same_v<std::remove_cv_t<ValueT>, OutputT>);
@@ -995,7 +1050,8 @@
 
                 return one_group_submitter<_SortKernelLoc>()(
                     exec_q, n_iters, n_iters, n_values, input_ptr, output_ptr,
-                    proj_op, storage_for_values, storage_for_counters, depends);
+                    proj_op, is_ascending, storage_for_values, storage_for_counters,
+                    depends);
             }
             case temp_allocations::counters_in_slm: {
@@ -1004,7 +1060,8 @@
 
                 return one_group_submitter<_SortKernelPartGlob>()(
                     exec_q, n_iters, n_batch_size, n_values, input_ptr, output_ptr,
-                    proj_op, storage_for_values, storage_for_counters, depends);
+                    proj_op, is_ascending, storage_for_values, storage_for_counters,
+                    depends);
             }
             default: {
@@ -1013,7 +1070,8 @@
 
                 return one_group_submitter<_SortKernelGlob>()(
exec_q, n_iters, n_batch_size, n_values, input_ptr, output_ptr, - proj_op, storage_for_values, storage_for_counters, depends); + proj_op, is_ascending, storage_for_values, storage_for_counters, + depends); } } } @@ -1111,6 +1169,7 @@ struct subgroup_radix_sort InputT *input_arr, OutputT *output_arr, const ProjT &proj_op, + const bool is_ascending, SLM_value_tag, SLM_counter_tag, const std::vector &depends) @@ -1216,28 +1275,63 @@ struct subgroup_radix_sort sycl::group_barrier(ndit.get_group()); + if (is_ascending) { #pragma unroll - for (uint16_t i = 0; i < block_size; ++i) { - const uint16_t id = wi * block_size + i; - constexpr uint16_t bin_mask = bin_count - 1; - - // points to the padded element, i.e. id is - // in-range - constexpr std::uint16_t - default_out_of_range_bin_id = bin_mask; - - const uint16_t bin = - (id < n) ? get_bucket_id( - order_preserving_cast< - is_ascending>( - proj_op(values[i])), - begin_bit) - : default_out_of_range_bin_id; - - // counting and local offset calculation - counters[i] = &pcounter[bin * wg_size]; - indices[i] = *counters[i]; - *counters[i] = indices[i] + 1; + for (uint16_t i = 0; i < block_size; ++i) { + const uint16_t id = wi * block_size + i; + constexpr uint16_t bin_mask = + bin_count - 1; + + // points to the padded element, i.e. id + // is in-range + constexpr std::uint16_t + default_out_of_range_bin_id = + bin_mask; + + const uint16_t bin = + (id < n) + ? get_bucket_id( + order_preserving_cast< + /* is_ascending */ + true>( + proj_op(values[i])), + begin_bit) + : default_out_of_range_bin_id; + + // counting and local offset calculation + counters[i] = &pcounter[bin * wg_size]; + indices[i] = *counters[i]; + *counters[i] = indices[i] + 1; + } + } + else { +#pragma unroll + for (uint16_t i = 0; i < block_size; ++i) { + const uint16_t id = wi * block_size + i; + constexpr uint16_t bin_mask = + bin_count - 1; + + // points to the padded element, i.e. id + // is in-range + constexpr std::uint16_t + default_out_of_range_bin_id = + bin_mask; + + const uint16_t bin = + (id < n) + ? 
get_bucket_id( + order_preserving_cast< + /* is_ascending */ + false>( + proj_op(values[i])), + begin_bit) + : default_out_of_range_bin_id; + + // counting and local offset calculation + counters[i] = &pcounter[bin * wg_size]; + indices[i] = *counters[i]; + *counters[i] = indices[i] + 1; + } } sycl::group_barrier(ndit.get_group()); @@ -1351,19 +1445,19 @@ struct subgroup_radix_sort }; }; -template -struct OneWorkGroupRadixSortKernel; +template struct OneWorkGroupRadixSortKernel; //----------------------------------------------------------------------- // radix sort: main function //----------------------------------------------------------------------- -template +template sycl::event parallel_radix_sort_impl(sycl::queue &exec_q, std::size_t n_iters, std::size_t n_to_sort, const ValueT *input_arr, ValueT *output_arr, const ProjT &proj_op, + const bool is_ascending, const std::vector &depends) { assert(n_to_sort > 1); @@ -1377,14 +1471,13 @@ sycl::event parallel_radix_sort_impl(sycl::queue &exec_q, sycl::event sort_ev{}; + const auto &dev = exec_q.get_device(); const auto max_wg_size = - exec_q.get_device() - .template get_info(); + dev.template get_info(); constexpr std::uint16_t ref_wg_size = 64; if (n_to_sort <= 16384 && ref_wg_size * 8 <= max_wg_size) { - using _RadixSortKernel = - OneWorkGroupRadixSortKernel; + using _RadixSortKernel = OneWorkGroupRadixSortKernel; if (n_to_sort <= 64 && ref_wg_size <= max_wg_size) { // wg_size * block_size == 64 * 1 * 1 == 64 @@ -1392,9 +1485,9 @@ sycl::event parallel_radix_sort_impl(sycl::queue &exec_q, constexpr std::uint16_t block_size = 1; sort_ev = subgroup_radix_sort<_RadixSortKernel, wg_size, block_size, - radix_bits, is_ascending>{}( + radix_bits>{}( exec_q, n_iters, n_to_sort, input_arr, output_arr, proj_op, - depends); + is_ascending, depends); } else if (n_to_sort <= 128 && ref_wg_size * 2 <= max_wg_size) { // wg_size * block_size == 64 * 2 * 1 == 128 @@ -1402,9 +1495,9 @@ sycl::event parallel_radix_sort_impl(sycl::queue &exec_q, constexpr std::uint16_t block_size = 1; sort_ev = subgroup_radix_sort<_RadixSortKernel, wg_size, block_size, - radix_bits, is_ascending>{}( + radix_bits>{}( exec_q, n_iters, n_to_sort, input_arr, output_arr, proj_op, - depends); + is_ascending, depends); } else if (n_to_sort <= 256 && ref_wg_size * 2 <= max_wg_size) { // wg_size * block_size == 64 * 2 * 2 == 256 @@ -1412,9 +1505,9 @@ sycl::event parallel_radix_sort_impl(sycl::queue &exec_q, constexpr std::uint16_t block_size = 2; sort_ev = subgroup_radix_sort<_RadixSortKernel, wg_size, block_size, - radix_bits, is_ascending>{}( + radix_bits>{}( exec_q, n_iters, n_to_sort, input_arr, output_arr, proj_op, - depends); + is_ascending, depends); } else if (n_to_sort <= 512 && ref_wg_size * 2 <= max_wg_size) { // wg_size * block_size == 64 * 2 * 4 == 512 @@ -1422,9 +1515,9 @@ sycl::event parallel_radix_sort_impl(sycl::queue &exec_q, constexpr std::uint16_t block_size = 4; sort_ev = subgroup_radix_sort<_RadixSortKernel, wg_size, block_size, - radix_bits, is_ascending>{}( + radix_bits>{}( exec_q, n_iters, n_to_sort, input_arr, output_arr, proj_op, - depends); + is_ascending, depends); } else if (n_to_sort <= 1024 && ref_wg_size * 2 <= max_wg_size) { // wg_size * block_size == 64 * 2 * 8 == 1024 @@ -1432,9 +1525,9 @@ sycl::event parallel_radix_sort_impl(sycl::queue &exec_q, constexpr std::uint16_t block_size = 8; sort_ev = subgroup_radix_sort<_RadixSortKernel, wg_size, block_size, - radix_bits, is_ascending>{}( + radix_bits>{}( exec_q, n_iters, n_to_sort, input_arr, 
output_arr, proj_op, - depends); + is_ascending, depends); } else if (n_to_sort <= 2048 && ref_wg_size * 4 <= max_wg_size) { // wg_size * block_size == 64 * 4 * 8 == 2048 @@ -1442,9 +1535,9 @@ sycl::event parallel_radix_sort_impl(sycl::queue &exec_q, constexpr std::uint16_t block_size = 8; sort_ev = subgroup_radix_sort<_RadixSortKernel, wg_size, block_size, - radix_bits, is_ascending>{}( + radix_bits>{}( exec_q, n_iters, n_to_sort, input_arr, output_arr, proj_op, - depends); + is_ascending, depends); } else if (n_to_sort <= 4096 && ref_wg_size * 4 <= max_wg_size) { // wg_size * block_size == 64 * 4 * 16 == 4096 @@ -1452,9 +1545,9 @@ sycl::event parallel_radix_sort_impl(sycl::queue &exec_q, constexpr std::uint16_t block_size = 16; sort_ev = subgroup_radix_sort<_RadixSortKernel, wg_size, block_size, - radix_bits, is_ascending>{}( + radix_bits>{}( exec_q, n_iters, n_to_sort, input_arr, output_arr, proj_op, - depends); + is_ascending, depends); } else if (n_to_sort <= 8192 && ref_wg_size * 8 <= max_wg_size) { // wg_size * block_size == 64 * 8 * 16 == 8192 @@ -1462,9 +1555,9 @@ sycl::event parallel_radix_sort_impl(sycl::queue &exec_q, constexpr std::uint16_t block_size = 16; sort_ev = subgroup_radix_sort<_RadixSortKernel, wg_size, block_size, - radix_bits, is_ascending>{}( + radix_bits>{}( exec_q, n_iters, n_to_sort, input_arr, output_arr, proj_op, - depends); + is_ascending, depends); } else { // wg_size * block_size == 64 * 8 * 32 == 16384 @@ -1472,9 +1565,9 @@ sycl::event parallel_radix_sort_impl(sycl::queue &exec_q, constexpr std::uint16_t block_size = 32; sort_ev = subgroup_radix_sort<_RadixSortKernel, wg_size, block_size, - radix_bits, is_ascending>{}( + radix_bits>{}( exec_q, n_iters, n_to_sort, input_arr, output_arr, proj_op, - depends); + is_ascending, depends); } } else { @@ -1512,11 +1605,11 @@ sycl::event parallel_radix_sort_impl(sycl::queue &exec_q, if constexpr (std::is_same_v) { sort_ev = parallel_radix_sort_iteration_step< - radix_bits, is_ascending, - /*even=*/true>::submit(exec_q, n_iters, n_segments, - zero_radix_iter, n_to_sort, input_arr, - output_arr, n_counts, count_ptr, proj_op, - depends); + radix_bits, /*even=*/true>::submit(exec_q, n_iters, n_segments, + zero_radix_iter, n_to_sort, + input_arr, output_arr, + n_counts, count_ptr, proj_op, + is_ascending, depends); sort_ev = exec_q.submit([=](sycl::handler &cgh) { cgh.depends_on(sort_ev); @@ -1542,33 +1635,30 @@ sycl::event parallel_radix_sort_impl(sycl::queue &exec_q, assert(radix_iters > 0); sort_ev = parallel_radix_sort_iteration_step< - radix_bits, is_ascending, /*even=*/true>::submit(exec_q, n_iters, - n_segments, - zero_radix_iter, - n_to_sort, - input_arr, tmp_arr, - n_counts, - count_ptr, proj_op, - depends); + radix_bits, /*even=*/true>::submit(exec_q, n_iters, n_segments, + zero_radix_iter, n_to_sort, + input_arr, tmp_arr, n_counts, + count_ptr, proj_op, is_ascending, + depends); for (std::uint32_t radix_iter = 1; radix_iter < radix_iters; ++radix_iter) { if (radix_iter % 2 == 0) { sort_ev = parallel_radix_sort_iteration_step< - radix_bits, is_ascending, + radix_bits, /*even=*/true>::submit(exec_q, n_iters, n_segments, radix_iter, n_to_sort, output_arr, tmp_arr, n_counts, count_ptr, - proj_op, {sort_ev}); + proj_op, is_ascending, {sort_ev}); } else { sort_ev = parallel_radix_sort_iteration_step< - radix_bits, is_ascending, + radix_bits, /*even=*/false>::submit(exec_q, n_iters, n_segments, radix_iter, n_to_sort, tmp_arr, output_arr, n_counts, count_ptr, - proj_op, {sort_ev}); + proj_op, is_ascending, 
{sort_ev}); } } @@ -1621,9 +1711,10 @@ template struct IndexedProj } // end of namespace radix_sort_details -template +template sycl::event radix_sort_axis1_contig_impl(sycl::queue &exec_q, + const bool sort_ascending, // number of sub-arrays to sort (num. of rows in a // matrix when sorting over rows) size_t iter_nelems, @@ -1647,22 +1738,23 @@ radix_sort_axis1_contig_impl(sycl::queue &exec_q, constexpr Proj proj_op{}; sycl::event radix_sort_ev = - radix_sort_details::parallel_radix_sort_impl( - exec_q, iter_nelems, sort_nelems, arg_tp, res_tp, proj_op, depends); + radix_sort_details::parallel_radix_sort_impl( + exec_q, iter_nelems, sort_nelems, arg_tp, res_tp, proj_op, + sort_ascending, depends); return radix_sort_ev; } -template +template class populate_indexed_data_for_radix_sort_krn; -template +template class index_write_out_for_radix_sort_krn; -template +template sycl::event radix_argsort_axis1_contig_impl(sycl::queue &exec_q, + const bool sort_ascending, // number of sub-arrays to sort (num. of rows in // a matrix when sorting over rows) size_t iter_nelems, @@ -1704,8 +1796,7 @@ radix_argsort_axis1_contig_impl(sycl::queue &exec_q, cgh.depends_on(depends); using KernelName = - populate_indexed_data_for_radix_sort_krn; + populate_indexed_data_for_radix_sort_krn; cgh.parallel_for( sycl::range<1>(total_nelems), [=](sycl::id<1> id) { @@ -1716,16 +1807,14 @@ radix_argsort_axis1_contig_impl(sycl::queue &exec_q, }); sycl::event radix_sort_ev = - radix_sort_details::parallel_radix_sort_impl( + radix_sort_details::parallel_radix_sort_impl( exec_q, iter_nelems, sort_nelems, indexed_data_tp, temp_tp, proj_op, - {populate_indexed_data_ev}); + sort_ascending, {populate_indexed_data_ev}); sycl::event write_out_ev = exec_q.submit([&](sycl::handler &cgh) { cgh.depends_on(radix_sort_ev); - using KernelName = - index_write_out_for_radix_sort_krn; + using KernelName = index_write_out_for_radix_sort_krn; cgh.parallel_for( sycl::range<1>(total_nelems), @@ -1743,12 +1832,12 @@ radix_argsort_axis1_contig_impl(sycl::queue &exec_q, return cleanup_ev; } -template -class iota_for_radix_sort_krn; +template class iota_for_radix_sort_krn; -template +template sycl::event radix_argsort_axis1_contig_alt_impl(sycl::queue &exec_q, + const bool sort_ascending, // number of sub-arrays to sort (num. 
of // rows in a matrix when sorting over rows) size_t iter_nelems, @@ -1785,8 +1874,7 @@ radix_argsort_axis1_contig_alt_impl(sycl::queue &exec_q, sycl::event iota_ev = exec_q.submit([&](sycl::handler &cgh) { cgh.depends_on(depends); - using KernelName = - iota_for_radix_sort_krn; + using KernelName = iota_for_radix_sort_krn; cgh.parallel_for( sycl::range<1>(total_nelems), [=](sycl::id<1> id) { @@ -1797,16 +1885,14 @@ radix_argsort_axis1_contig_alt_impl(sycl::queue &exec_q, }); sycl::event radix_sort_ev = - radix_sort_details::parallel_radix_sort_impl( + radix_sort_details::parallel_radix_sort_impl( exec_q, iter_nelems, sort_nelems, workspace, res_tp, proj_op, - {iota_ev}); + sort_ascending, {iota_ev}); sycl::event map_back_ev = exec_q.submit([&](sycl::handler &cgh) { cgh.depends_on(radix_sort_ev); - using KernelName = - index_write_out_for_radix_sort_krn; + using KernelName = index_write_out_for_radix_sort_krn; cgh.parallel_for( sycl::range<1>(total_nelems), [=](sycl::id<1> id) { diff --git a/dpctl/tensor/libtensor/source/sorting/radix_argsort.cpp b/dpctl/tensor/libtensor/source/sorting/radix_argsort.cpp index 74ab28c684..a98e5677b2 100644 --- a/dpctl/tensor/libtensor/source/sorting/radix_argsort.cpp +++ b/dpctl/tensor/libtensor/source/sorting/radix_argsort.cpp @@ -38,6 +38,7 @@ #include "utils/sycl_alloc_utils.hpp" #include "utils/type_dispatch.hpp" +#include "kernels/dpctl_tensor_types.hpp" #include "kernels/sorting/radix_sort.hpp" #include "kernels/sorting/sort_impl_fn_ptr_t.hpp" @@ -64,6 +65,31 @@ static sort_contig_fn_ptr_t descending_radix_argsort_contig_dispatch_table[td_ns::num_types] [td_ns::num_types]; +namespace +{ + +template +sycl::event argsort_axis1_contig_caller(sycl::queue &q, + size_t iter_nelems, + size_t sort_nelems, + const char *arg_cp, + char *res_cp, + ssize_t iter_arg_offset, + ssize_t iter_res_offset, + ssize_t sort_arg_offset, + ssize_t sort_res_offset, + const std::vector &depends) +{ + using dpctl::tensor::kernels::radix_argsort_axis1_contig_alt_impl; + + return radix_argsort_axis1_contig_alt_impl( + q, is_ascending, iter_nelems, sort_nelems, arg_cp, res_cp, + iter_arg_offset, iter_res_offset, sort_arg_offset, sort_res_offset, + depends); +} + +} // end of anonymous namespace + template struct AscendingRadixArgSortContigFactory { @@ -73,9 +99,8 @@ struct AscendingRadixArgSortContigFactory (std::is_same_v || std::is_same_v)) { - using dpctl::tensor::kernels::radix_argsort_axis1_contig_alt_impl; - return radix_argsort_axis1_contig_alt_impl; + return argsort_axis1_contig_caller< + /*ascending*/ true, argTy, IndexTy>; } else { return nullptr; @@ -92,9 +117,8 @@ struct DescendingRadixArgSortContigFactory (std::is_same_v || std::is_same_v)) { - using dpctl::tensor::kernels::radix_argsort_axis1_contig_alt_impl; - return radix_argsort_axis1_contig_alt_impl; + return argsort_axis1_contig_caller< + /*ascending*/ false, argTy, IndexTy>; } else { return nullptr; diff --git a/dpctl/tensor/libtensor/source/sorting/radix_sort.cpp b/dpctl/tensor/libtensor/source/sorting/radix_sort.cpp index 83afe7c6ff..09eb75d1f1 100644 --- a/dpctl/tensor/libtensor/source/sorting/radix_sort.cpp +++ b/dpctl/tensor/libtensor/source/sorting/radix_sort.cpp @@ -38,6 +38,7 @@ #include "utils/output_validation.hpp" #include "utils/type_dispatch.hpp" +#include "kernels/dpctl_tensor_types.hpp" #include "kernels/sorting/radix_sort.hpp" #include "kernels/sorting/sort_impl_fn_ptr_t.hpp" @@ -61,13 +62,37 @@ static sort_contig_fn_ptr_t static sort_contig_fn_ptr_t 
    descending_radix_sort_contig_dispatch_vector[td_ns::num_types];
 
+namespace
+{
+
+template <bool is_ascending, typename argTy>
+sycl::event sort_axis1_contig_caller(sycl::queue &q,
+                                     size_t iter_nelems,
+                                     size_t sort_nelems,
+                                     const char *arg_cp,
+                                     char *res_cp,
+                                     ssize_t iter_arg_offset,
+                                     ssize_t iter_res_offset,
+                                     ssize_t sort_arg_offset,
+                                     ssize_t sort_res_offset,
+                                     const std::vector<sycl::event> &depends)
+{
+    using dpctl::tensor::kernels::radix_sort_axis1_contig_impl;
+
+    return radix_sort_axis1_contig_impl<argTy>(
+        q, is_ascending, iter_nelems, sort_nelems, arg_cp, res_cp,
+        iter_arg_offset, iter_res_offset, sort_arg_offset, sort_res_offset,
+        depends);
+}
+
+} // end of anonymous namespace
+
 template <typename fnT, typename argTy> struct AscendingRadixSortContigFactory
 {
     fnT get()
     {
         if constexpr (RadixSortSupportVector<argTy>::is_defined) {
-            using dpctl::tensor::kernels::radix_sort_axis1_contig_impl;
-            return radix_sort_axis1_contig_impl<argTy, /*ascending*/ true>;
+            return sort_axis1_contig_caller</*ascending*/ true, argTy>;
         }
         else {
             return nullptr;
@@ -80,8 +105,7 @@ template <typename fnT, typename argTy> struct DescendingRadixSortContigFactory
     fnT get()
     {
         if constexpr (RadixSortSupportVector<argTy>::is_defined) {
-            using dpctl::tensor::kernels::radix_sort_axis1_contig_impl;
-            return radix_sort_axis1_contig_impl<argTy, /*ascending*/ false>;
+            return sort_axis1_contig_caller</*ascending*/ false, argTy>;
         }
         else {
             return nullptr;

From ec6a930876ec8c386d53615f6b8f05512db764b9 Mon Sep 17 00:00:00 2001
From: Oleksandr Pavlyk
Date: Tue, 29 Oct 2024 09:44:23 -0500
Subject: [PATCH 2/2] Move radix sort Python API to a dedicated module,
 _tensor_sorting_radix_impl

With this change, _tensor_sorting_impl goes back to 17 MB, and
_tensor_sorting_radix_impl is 30 MB. Linking two smaller modules should
greatly reduce the memory footprint of the build, speed up the build
process, and provide better parallelisation opportunities for the build
job. The build time on my Core i7 dropped from 45 minutes to 33 minutes.
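For illustration, the only user-visible change is which extension module
hosts the radix sort entry points. A sketch of the resulting Python-level
imports, mirroring the _sorting.py hunk below (not part of the patch
itself):

    # merge-sort based entry points stay in _tensor_sorting_impl
    from dpctl.tensor._tensor_sorting_impl import _sort_ascending
    # radix-sort entry points move to the new dedicated module
    from dpctl.tensor._tensor_sorting_radix_impl import _radix_sort_ascending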
--- dpctl/tensor/CMakeLists.txt | 22 ++++++++--- dpctl/tensor/_sorting.py | 6 ++- .../libtensor/source/tensor_sorting.cpp | 5 --- .../libtensor/source/tensor_sorting_radix.cpp | 37 +++++++++++++++++++ 4 files changed, 58 insertions(+), 12 deletions(-) create mode 100644 dpctl/tensor/libtensor/source/tensor_sorting_radix.cpp diff --git a/dpctl/tensor/CMakeLists.txt b/dpctl/tensor/CMakeLists.txt index 2a278c51ec..59728f64d8 100644 --- a/dpctl/tensor/CMakeLists.txt +++ b/dpctl/tensor/CMakeLists.txt @@ -114,9 +114,11 @@ set(_reduction_sources set(_sorting_sources ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/sort.cpp ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/argsort.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/searchsorted.cpp +) +set(_sorting_radix_sources ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/radix_sort.cpp ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/radix_argsort.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/searchsorted.cpp ) set(_static_lib_sources ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/simplify_iteration_space.cpp @@ -153,6 +155,10 @@ set(_tensor_sorting_impl_sources ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/tensor_sorting.cpp ${_sorting_sources} ) +set(_tensor_sorting_radix_impl_sources + ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/tensor_sorting_radix.cpp + ${_sorting_radix_sources} +) set(_linalg_sources ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/elementwise_functions_type_utils.cpp ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/linalg_functions/dot.cpp @@ -162,10 +168,10 @@ set(_tensor_linalg_impl_sources ${_linalg_sources} ) set(_accumulator_sources -${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/accumulators/accumulators_common.cpp -${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/accumulators/cumulative_logsumexp.cpp -${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/accumulators/cumulative_prod.cpp -${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/accumulators/cumulative_sum.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/accumulators/accumulators_common.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/accumulators/cumulative_logsumexp.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/accumulators/cumulative_prod.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/accumulators/cumulative_sum.cpp ) set(_tensor_accumulation_impl_sources ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/tensor_accumulation.cpp @@ -207,6 +213,12 @@ add_sycl_to_target(TARGET ${python_module_name} SOURCES ${_tensor_sorting_impl_s target_link_libraries(${python_module_name} PRIVATE ${_static_lib_trgt}) list(APPEND _py_trgts ${python_module_name}) +set(python_module_name _tensor_sorting_radix_impl) +pybind11_add_module(${python_module_name} MODULE ${_tensor_sorting_radix_impl_sources}) +add_sycl_to_target(TARGET ${python_module_name} SOURCES ${_tensor_sorting_radix_impl_sources}) +target_link_libraries(${python_module_name} PRIVATE ${_static_lib_trgt}) +list(APPEND _py_trgts ${python_module_name}) + set(python_module_name _tensor_linalg_impl) pybind11_add_module(${python_module_name} MODULE ${_tensor_linalg_impl_sources}) add_sycl_to_target(TARGET ${python_module_name} SOURCES ${_tensor_linalg_impl_sources}) diff --git a/dpctl/tensor/_sorting.py b/dpctl/tensor/_sorting.py index bbf6489412..d5026a6ee8 100644 --- a/dpctl/tensor/_sorting.py +++ b/dpctl/tensor/_sorting.py @@ -22,13 +22,15 @@ from ._tensor_sorting_impl import ( _argsort_ascending, _argsort_descending, + _sort_ascending, + _sort_descending, +) +from 
._tensor_sorting_radix_impl import (
     _radix_argsort_ascending,
     _radix_argsort_descending,
     _radix_sort_ascending,
     _radix_sort_descending,
     _radix_sort_dtype_supported,
-    _sort_ascending,
-    _sort_descending,
 )
 
 __all__ = ["sort", "argsort"]
diff --git a/dpctl/tensor/libtensor/source/tensor_sorting.cpp b/dpctl/tensor/libtensor/source/tensor_sorting.cpp
index 80351ed102..6f2f965285 100644
--- a/dpctl/tensor/libtensor/source/tensor_sorting.cpp
+++ b/dpctl/tensor/libtensor/source/tensor_sorting.cpp
@@ -29,16 +29,11 @@
 #include "sorting/searchsorted.hpp"
 #include "sorting/sort.hpp"
 
-#include "sorting/radix_argsort.hpp"
-#include "sorting/radix_sort.hpp"
-
 namespace py = pybind11;
 
 PYBIND11_MODULE(_tensor_sorting_impl, m)
 {
     dpctl::tensor::py_internal::init_sort_functions(m);
-    dpctl::tensor::py_internal::init_radix_sort_functions(m);
     dpctl::tensor::py_internal::init_argsort_functions(m);
-    dpctl::tensor::py_internal::init_radix_argsort_functions(m);
     dpctl::tensor::py_internal::init_searchsorted_functions(m);
 }
diff --git a/dpctl/tensor/libtensor/source/tensor_sorting_radix.cpp b/dpctl/tensor/libtensor/source/tensor_sorting_radix.cpp
new file mode 100644
index 0000000000..b5ef49e0ac
--- /dev/null
+++ b/dpctl/tensor/libtensor/source/tensor_sorting_radix.cpp
@@ -0,0 +1,37 @@
+//===-- tensor_sorting_radix.cpp - -----*-C++-*-/===//
+//   Implementation of the _tensor_sorting_radix_impl module
+//
+// Data Parallel Control (dpctl)
+//
+// Copyright 2020-2024 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//    http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file
+/// This file defines functions of dpctl.tensor._tensor_sorting_radix_impl extensions
+//===----------------------------------------------------------------------===//
+
+#include <pybind11/pybind11.h>
+
+#include "sorting/radix_argsort.hpp"
+#include "sorting/radix_sort.hpp"
+
+namespace py = pybind11;
+
+PYBIND11_MODULE(_tensor_sorting_radix_impl, m)
+{
+    dpctl::tensor::py_internal::init_radix_sort_functions(m);
+    dpctl::tensor::py_internal::init_radix_argsort_functions(m);
+}