diff --git a/dpctl/tensor/CMakeLists.txt b/dpctl/tensor/CMakeLists.txt
index 2a278c51ec..59728f64d8 100644
--- a/dpctl/tensor/CMakeLists.txt
+++ b/dpctl/tensor/CMakeLists.txt
@@ -114,9 +114,11 @@ set(_reduction_sources
 set(_sorting_sources
     ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/sort.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/argsort.cpp
+    ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/searchsorted.cpp
+)
+set(_sorting_radix_sources
     ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/radix_sort.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/radix_argsort.cpp
-    ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/searchsorted.cpp
 )
 set(_static_lib_sources
     ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/simplify_iteration_space.cpp
@@ -153,6 +155,10 @@ set(_tensor_sorting_impl_sources
     ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/tensor_sorting.cpp
     ${_sorting_sources}
 )
+set(_tensor_sorting_radix_impl_sources
+    ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/tensor_sorting_radix.cpp
+    ${_sorting_radix_sources}
+)
 set(_linalg_sources
     ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/elementwise_functions_type_utils.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/linalg_functions/dot.cpp
@@ -162,10 +168,10 @@ set(_tensor_linalg_impl_sources
     ${_linalg_sources}
 )
 set(_accumulator_sources
-${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/accumulators/accumulators_common.cpp
-${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/accumulators/cumulative_logsumexp.cpp
-${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/accumulators/cumulative_prod.cpp
-${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/accumulators/cumulative_sum.cpp
+    ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/accumulators/accumulators_common.cpp
+    ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/accumulators/cumulative_logsumexp.cpp
+    ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/accumulators/cumulative_prod.cpp
+    ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/accumulators/cumulative_sum.cpp
 )
 set(_tensor_accumulation_impl_sources
     ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/tensor_accumulation.cpp
@@ -207,6 +213,12 @@ add_sycl_to_target(TARGET ${python_module_name} SOURCES ${_tensor_sorting_impl_s
 target_link_libraries(${python_module_name} PRIVATE ${_static_lib_trgt})
 list(APPEND _py_trgts ${python_module_name})
 
+set(python_module_name _tensor_sorting_radix_impl)
+pybind11_add_module(${python_module_name} MODULE ${_tensor_sorting_radix_impl_sources})
+add_sycl_to_target(TARGET ${python_module_name} SOURCES ${_tensor_sorting_radix_impl_sources})
+target_link_libraries(${python_module_name} PRIVATE ${_static_lib_trgt})
+list(APPEND _py_trgts ${python_module_name})
+
 set(python_module_name _tensor_linalg_impl)
 pybind11_add_module(${python_module_name} MODULE ${_tensor_linalg_impl_sources})
 add_sycl_to_target(TARGET ${python_module_name} SOURCES ${_tensor_linalg_impl_sources})
diff --git a/dpctl/tensor/_sorting.py b/dpctl/tensor/_sorting.py
index bbf6489412..d5026a6ee8 100644
--- a/dpctl/tensor/_sorting.py
+++ b/dpctl/tensor/_sorting.py
@@ -22,13 +22,15 @@
 from ._tensor_sorting_impl import (
     _argsort_ascending,
     _argsort_descending,
+    _sort_ascending,
+    _sort_descending,
+)
+from ._tensor_sorting_radix_impl import (
     _radix_argsort_ascending,
     _radix_argsort_descending,
     _radix_sort_ascending,
     _radix_sort_descending,
     _radix_sort_dtype_supported,
-    _sort_ascending,
-    _sort_descending,
 )
 
 __all__ = ["sort", "argsort"]
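An orientation note (not part of the patch): the new CMake target above builds a third sorting extension, which the Python layer now imports alongside `_tensor_sorting_impl`. The module definition itself appears at the end of this diff in tensor_sorting_radix.cpp; a minimal sketch of the pattern, assuming the registration helpers carry the names used in the diff:

    // Sketch: one pybind11 extension per translation unit. Splitting the
    // template-heavy radix kernels into their own module keeps them out of
    // the merge-sort module's compile unit and lets them be built (and
    // imported) independently.
    #include <pybind11/pybind11.h>

    namespace py = pybind11;

    namespace dpctl::tensor::py_internal
    {
    // assumed declarations, provided by sorting/radix_sort.hpp and
    // sorting/radix_argsort.hpp in the real sources
    void init_radix_sort_functions(py::module_);
    void init_radix_argsort_functions(py::module_);
    } // namespace dpctl::tensor::py_internal

    PYBIND11_MODULE(_tensor_sorting_radix_impl, m)
    {
        dpctl::tensor::py_internal::init_radix_sort_functions(m);
        dpctl::tensor::py_internal::init_radix_argsort_functions(m);
    }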
diff --git a/dpctl/tensor/libtensor/include/kernels/sorting/merge_sort.hpp b/dpctl/tensor/libtensor/include/kernels/sorting/merge_sort.hpp
index 79f351bf51..f3b5030c48 100644
--- a/dpctl/tensor/libtensor/include/kernels/sorting/merge_sort.hpp
+++ b/dpctl/tensor/libtensor/include/kernels/sorting/merge_sort.hpp
@@ -807,8 +807,7 @@ sycl::event stable_argsort_axis1_contig_impl(
     const IndexComp index_comp{arg_tp, ValueComp{}};
 
     static constexpr size_t determine_automatically = 0;
-    size_t sorted_block_size =
-        (sort_nelems >= 512) ? 512 : determine_automatically;
+    size_t sorted_block_size = determine_automatically;
 
     const size_t total_nelems = iter_nelems * sort_nelems;
diff --git a/dpctl/tensor/libtensor/include/kernels/sorting/radix_sort.hpp b/dpctl/tensor/libtensor/include/kernels/sorting/radix_sort.hpp
index 3f834b4317..348ad72d13 100644
--- a/dpctl/tensor/libtensor/include/kernels/sorting/radix_sort.hpp
+++ b/dpctl/tensor/libtensor/include/kernels/sorting/radix_sort.hpp
@@ -49,16 +49,16 @@ namespace kernels
 namespace radix_sort_details
 {
 
-template <std::uint32_t, bool, typename... TrailingNames>
+template <std::uint32_t, typename... TrailingNames>
 class radix_sort_count_kernel;
 
 template <typename... TrailingNames> class radix_sort_scan_kernel;
 
-template <std::uint32_t, bool, peer_prefix_algo, typename... TrailingNames>
+template <std::uint32_t, peer_prefix_algo, typename... TrailingNames>
 class radix_sort_reorder_peer_kernel;
 
-template <std::uint32_t, bool, typename... TrailingNames>
+template <std::uint32_t, typename... TrailingNames>
 class radix_sort_reorder_kernel;
 
 //----------------------------------------------------------
@@ -223,7 +223,6 @@ std::uint32_t get_bucket_id(T val, std::uint32_t radix_offset)
 
 template <typename KernelNameT,
           std::uint32_t radix_bits,
-          bool is_ascending,
           typename ValueT,
           typename CountT,
           typename Proj>
@@ -238,6 +237,7 @@ radix_sort_count_submit(sycl::queue &exec_q,
                         std::size_t n_counts,
                         CountT *counts_ptr,
                         const Proj &proj_op,
+                        const bool is_ascending,
                         const std::vector<sycl::event> &dependency_events)
 {
     // bin_count = radix_states used for an array storing bucket state counters
@@ -280,18 +280,37 @@ radix_sort_count_submit(sycl::queue &exec_q,
             // count array
             const std::size_t seg_end =
                 sycl::min(seg_start + elems_per_segment, n);
-            for (std::size_t val_id = seg_start + lid; val_id < seg_end;
-                 val_id += wg_size)
-            {
-                // get the bucket for the bit-ordered input value,
-                // applying the offset and mask for radix bits
-                const auto val = order_preserving_cast<is_ascending>(
-                    proj_op(vals_ptr[val_iter_offset + val_id]));
-                const std::uint32_t bucket_id =
-                    get_bucket_id<radix_mask>(val, radix_offset);
-
-                // increment counter for this bit bucket
-                ++counts_arr[bucket_id];
+            if (is_ascending) {
+                for (std::size_t val_id = seg_start + lid; val_id < seg_end;
+                     val_id += wg_size)
+                {
+                    // get the bucket for the bit-ordered input value,
+                    // applying the offset and mask for radix bits
+                    const auto val =
+                        order_preserving_cast</* is_ascending */ true>(
+                            proj_op(vals_ptr[val_iter_offset + val_id]));
+                    const std::uint32_t bucket_id =
+                        get_bucket_id<radix_mask>(val, radix_offset);
+
+                    // increment counter for this bit bucket
+                    ++counts_arr[bucket_id];
+                }
+            }
+            else {
+                for (std::size_t val_id = seg_start + lid; val_id < seg_end;
+                     val_id += wg_size)
+                {
+                    // get the bucket for the bit-ordered input value,
+                    // applying the offset and mask for radix bits
+                    const auto val =
+                        order_preserving_cast</* is_ascending */ false>(
+                            proj_op(vals_ptr[val_iter_offset + val_id]));
+                    const std::uint32_t bucket_id =
+                        get_bucket_id<radix_mask>(val, radix_offset);
+
+                    // increment counter for this bit bucket
+                    ++counts_arr[bucket_id];
+                }
             }
 
             // count per work-item: write private count array to local count
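The count phase hinges on order_preserving_cast and get_bucket_id. The underlying idea, sketched as a hypothetical standalone helper (dpctl's actual implementation also covers unsigned and floating-point types, which need their own bit manipulations): map every value to an unsigned key whose unsigned order matches the requested sort order, then histogram a radix_bits-wide slice of that key.

    #include <cstdint>

    // Two's-complement int32 -> uint32 key: flipping the sign bit makes
    // unsigned comparison of keys agree with signed comparison of values;
    // complementing the key reverses the order for descending sorts.
    template <bool is_ascending>
    std::uint32_t order_preserving_key(std::int32_t v)
    {
        const std::uint32_t u = static_cast<std::uint32_t>(v) ^ 0x80000000u;
        return is_ascending ? u : ~u;
    }

    // Bucket index: the radix_bits-wide digit starting at bit radix_offset.
    constexpr std::uint32_t radix_bits = 4;
    constexpr std::uint32_t radix_mask = (1u << radix_bits) - 1;

    std::uint32_t bucket_id(std::uint32_t key, std::uint32_t radix_offset)
    {
        return (key >> radix_offset) & radix_mask;
    }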
@@ -622,7 +641,6 @@ void copy_func_for_radix_sort(const std::size_t n_segments,
 //-----------------------------------------------------------------------
 template <typename KernelNameT,
           std::uint32_t radix_bits,
-          bool is_ascending,
           peer_prefix_algo PeerAlgo,
           typename InputT,
           typename OutputT,
@@ -648,6 +666,7 @@ radix_sort_reorder_submit(sycl::queue &exec_q,
                           std::size_t n_offsets,
                           OffsetT *offsets_ptr,
                           const ProjT &proj_op,
+                          const bool is_ascending,
                           const std::vector<sycl::event> &dependency_events)
 {
     using ValueT = InputT;
@@ -735,32 +754,65 @@ radix_sort_reorder_submit(sycl::queue &exec_q,
                 // find offsets for the same values within a segment and fill the
                 // resulting buffer
-                for (std::size_t val_id = seg_start + lid; val_id < seg_end;
-                     val_id += sg_size)
-                {
-                    ValueT in_val = std::move(b_input_ptr[val_id]);
+                if (is_ascending) {
+                    for (std::size_t val_id = seg_start + lid; val_id < seg_end;
+                         val_id += sg_size)
+                    {
+                        ValueT in_val = std::move(b_input_ptr[val_id]);
 
-                    // get the bucket for the bit-ordered input value, applying the
-                    // offset and mask for radix bits
-                    const auto mapped_val =
-                        order_preserving_cast<is_ascending>(proj_op(in_val));
-                    std::uint32_t bucket_id =
-                        get_bucket_id<radix_mask>(mapped_val, radix_offset);
+                        // get the bucket for the bit-ordered input value, applying
+                        // the offset and mask for radix bits
+                        const auto mapped_val =
+                            order_preserving_cast</* is_ascending */ true>(
+                                proj_op(in_val));
+                        std::uint32_t bucket_id =
+                            get_bucket_id<radix_mask>(mapped_val, radix_offset);
 
-                    OffsetT new_offset_id = 0;
-                    for (std::uint32_t radix_state_id = 0;
-                         radix_state_id < radix_states; ++radix_state_id)
+                        OffsetT new_offset_id = 0;
+                        for (std::uint32_t radix_state_id = 0;
+                             radix_state_id < radix_states; ++radix_state_id)
+                        {
+                            bool is_current_bucket = (bucket_id == radix_state_id);
+                            std::uint32_t sg_total_offset =
+                                peer_prefix_hlp.peer_contribution(
+                                    /* modified by reference */ new_offset_id,
+                                    offset_arr[radix_state_id],
+                                    /* bit contribution from this work-item */
+                                    is_current_bucket);
+                            offset_arr[radix_state_id] += sg_total_offset;
+                        }
+                        b_output_ptr[new_offset_id] = std::move(in_val);
+                    }
+                }
+                else {
+                    for (std::size_t val_id = seg_start + lid; val_id < seg_end;
+                         val_id += sg_size)
                     {
-                        bool is_current_bucket = (bucket_id == radix_state_id);
-                        std::uint32_t sg_total_offset =
-                            peer_prefix_hlp.peer_contribution(
-                                /* modified by reference */ new_offset_id,
-                                offset_arr[radix_state_id],
-                                /* bit contribution from this work-item */
-                                is_current_bucket);
-                        offset_arr[radix_state_id] += sg_total_offset;
+                        ValueT in_val = std::move(b_input_ptr[val_id]);
+
+                        // get the bucket for the bit-ordered input value, applying
+                        // the offset and mask for radix bits
+                        const auto mapped_val =
+                            order_preserving_cast</* is_ascending */ false>(
+                                proj_op(in_val));
+                        std::uint32_t bucket_id =
+                            get_bucket_id<radix_mask>(mapped_val, radix_offset);
+
+                        OffsetT new_offset_id = 0;
+                        for (std::uint32_t radix_state_id = 0;
+                             radix_state_id < radix_states; ++radix_state_id)
+                        {
+                            bool is_current_bucket = (bucket_id == radix_state_id);
+                            std::uint32_t sg_total_offset =
+                                peer_prefix_hlp.peer_contribution(
+                                    /* modified by reference */ new_offset_id,
+                                    offset_arr[radix_state_id],
+                                    /* bit contribution from this work-item */
+                                    is_current_bucket);
+                            offset_arr[radix_state_id] += sg_total_offset;
+                        }
+                        b_output_ptr[new_offset_id] = std::move(in_val);
                     }
-                    b_output_ptr[new_offset_id] = std::move(in_val);
                 }
 
                 if (tail_size > 0) {
                     ValueT in_val;
@@ -770,8 +822,13 @@ radix_sort_reorder_submit(sycl::queue &exec_q,
                 if (lid < tail_size) {
                     in_val = std::move(b_input_ptr[seg_end + lid]);
 
+                    const auto proj_val = proj_op(in_val);
                     const auto mapped_val =
-                        order_preserving_cast<is_ascending>(proj_op(in_val));
+                        (is_ascending)
+                            ? order_preserving_cast</* is_ascending */ true>(
+                                  proj_val)
+                            : order_preserving_cast</* is_ascending */ false>(
+                                  proj_val);
 
                     bucket_id = get_bucket_id<radix_mask>(mapped_val, radix_offset);
                 }
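Note the shape of the change in the reorder (and count) kernels: rather than testing is_ascending per element, the branch is hoisted outside the loop, and each arm keeps a compile-time direction for order_preserving_cast. The same hoisting can be expressed generically; this helper is purely illustrative and is not part of dpctl:

    #include <type_traits>
    #include <utility>

    // Convert a runtime bool into a compile-time constant for a callable.
    template <typename F> void bool_switch(bool b, F &&f)
    {
        if (b)
            std::forward<F>(f)(std::true_type{});
        else
            std::forward<F>(f)(std::false_type{});
    }

    // usage sketch:
    //   bool_switch(is_ascending, [&](auto asc) {
    //       for (std::size_t i = 0; i < n; ++i)
    //           keys[i] = order_preserving_key<decltype(asc)::value>(vals[i]);
    //   });

Writing the two loops out by hand, as the patch does, achieves the same effect with no extra machinery: the `if (is_ascending)` test runs once per kernel invocation, not once per element.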
@@ -820,20 +877,18 @@ sizeT _slm_adjusted_work_group_size(sycl::queue &exec_q,
 //-----------------------------------------------------------------------
 // radix sort: one iteration
 //-----------------------------------------------------------------------
-template <std::uint32_t radix_bits, bool is_ascending, bool even>
+template <std::uint32_t radix_bits, bool even>
 struct parallel_radix_sort_iteration_step
 {
     template <typename... Names>
-    using count_phase =
-        radix_sort_count_kernel<radix_bits, is_ascending, Names...>;
+    using count_phase = radix_sort_count_kernel<radix_bits, Names...>;
     template <typename... Names>
     using local_scan_phase = radix_sort_scan_kernel<Names...>;
     template <typename... Names>
     using reorder_peer_phase =
-        radix_sort_reorder_peer_kernel<radix_bits, is_ascending, Names...>;
+        radix_sort_reorder_peer_kernel<radix_bits, Names...>;
     template <typename... Names>
-    using reorder_phase =
-        radix_sort_reorder_kernel<radix_bits, is_ascending, Names...>;
+    using reorder_phase = radix_sort_reorder_kernel<radix_bits, Names...>;
 
     template <typename ValueT, typename ProjT>
     static sycl::event submit(sycl::queue &exec_q,
@@ -860,6 +915,7 @@ struct parallel_radix_sort_iteration_step
                               std::size_t n_counts,
                               CountT *counts_ptr,
                               const ProjT &proj_op,
+                              const bool is_ascending,
                               const std::vector<sycl::event> &dependency_events)
     {
         using _RadixCountKernel = count_phase<ValueT, CountT, ProjT>;
@@ -898,10 +954,9 @@ struct parallel_radix_sort_iteration_step
 
         // 1. Count Phase
        sycl::event count_ev =
-            radix_sort_count_submit<_RadixCountKernel, radix_bits,
-                                    is_ascending>(
+            radix_sort_count_submit<_RadixCountKernel, radix_bits>(
                 exec_q, n_iters, n_segments, count_wg_size, radix_offset,
-                n_values, in_ptr, n_counts, counts_ptr, proj_op,
+                n_values, in_ptr, n_counts, counts_ptr, proj_op, is_ascending,
                 dependency_events);
 
         // 2. Scan Phase
@@ -917,21 +972,21 @@ struct parallel_radix_sort_iteration_step
         {
             constexpr auto peer_algorithm = peer_prefix_algo::subgroup_ballot;
-            reorder_ev =
-                radix_sort_reorder_submit<_RadixReorderPeerKernel, radix_bits,
-                                          is_ascending, peer_algorithm>(
-                    exec_q, n_iters, n_segments, radix_offset, n_values, in_ptr,
-                    out_ptr, n_counts, counts_ptr, proj_op, {scan_ev});
+            reorder_ev = radix_sort_reorder_submit<_RadixReorderPeerKernel,
+                                                   radix_bits, peer_algorithm>(
+                exec_q, n_iters, n_segments, radix_offset, n_values, in_ptr,
+                out_ptr, n_counts, counts_ptr, proj_op, is_ascending,
+                {scan_ev});
         }
         else {
             constexpr auto peer_algorithm =
                 peer_prefix_algo::scan_then_broadcast;
-            reorder_ev =
-                radix_sort_reorder_submit<_RadixReorderKernel, radix_bits,
-                                          is_ascending, peer_algorithm>(
-                    exec_q, n_iters, n_segments, radix_offset, n_values, in_ptr,
-                    out_ptr, n_counts, counts_ptr, proj_op, {scan_ev});
+            reorder_ev = radix_sort_reorder_submit<_RadixReorderKernel,
+                                                   radix_bits, peer_algorithm>(
+                exec_q, n_iters, n_segments, radix_offset, n_values, in_ptr,
+                out_ptr, n_counts, counts_ptr, proj_op, is_ascending,
+                {scan_ev});
         }
 
         return reorder_ev;
     }
 };
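parallel_radix_sort_iteration_step wires one count/scan/reorder round per radix digit. For orientation, here are the same three phases in a plain host-side LSD radix sort (illustrative only; the device version tiles the data over segments and work-groups and performs the scan in parallel):

    #include <array>
    #include <cstddef>
    #include <cstdint>
    #include <utility>
    #include <vector>

    void lsd_radix_sort(std::vector<std::uint32_t> &a)
    {
        std::vector<std::uint32_t> tmp(a.size());
        std::uint32_t *src = a.data();
        std::uint32_t *dst = tmp.data();
        for (int shift = 0; shift < 32; shift += 8) {
            // 1. count phase: histogram of the current 8-bit digit
            std::array<std::size_t, 257> offsets{};
            for (std::size_t i = 0; i < a.size(); ++i)
                ++offsets[((src[i] >> shift) & 0xFFu) + 1];
            // 2. scan phase: exclusive prefix sum turns counts into offsets
            for (int b = 0; b < 256; ++b)
                offsets[b + 1] += offsets[b];
            // 3. reorder phase: stable scatter by digit
            for (std::size_t i = 0; i < a.size(); ++i)
                dst[offsets[(src[i] >> shift) & 0xFFu]++] = src[i];
            std::swap(src, dst); // ping-pong buffers between passes
        }
        // four passes (an even number), so the result is back in `a`
    }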
@@ -945,7 +1000,6 @@ struct parallel_radix_sort_iteration_step
 template <typename KernelNameBase,
           uint16_t wg_size,
           uint16_t block_size,
-          std::uint32_t radix_bits,
-          bool is_ascending>
+          std::uint32_t radix_bits>
 struct subgroup_radix_sort
 {
@@ -965,6 +1019,7 @@ struct subgroup_radix_sort
                            ValueT *input_ptr,
                            OutputT *output_ptr,
                            ProjT proj_op,
+                           const bool is_ascending,
                            const std::vector<sycl::event> &depends)
     {
         static_assert(std::is_same_v<std::remove_cv_t<ValueT>, OutputT>);
@@ -995,7 +1050,8 @@ struct subgroup_radix_sort
 
             return one_group_submitter<_SortKernelLoc>()(
                 exec_q, n_iters, n_iters, n_values, input_ptr, output_ptr,
-                proj_op, storage_for_values, storage_for_counters, depends);
+                proj_op, is_ascending, storage_for_values, storage_for_counters,
+                depends);
         }
         case temp_allocations::counters_in_slm:
         {
@@ -1004,7 +1060,8 @@ struct subgroup_radix_sort
 
             return one_group_submitter<_SortKernelPartGlob>()(
                 exec_q, n_iters, n_batch_size, n_values, input_ptr, output_ptr,
-                proj_op, storage_for_values, storage_for_counters, depends);
+                proj_op, is_ascending, storage_for_values, storage_for_counters,
+                depends);
         }
         default:
         {
@@ -1013,7 +1070,8 @@ struct subgroup_radix_sort
 
             return one_group_submitter<_SortKernelGlob>()(
                 exec_q, n_iters, n_batch_size, n_values, input_ptr, output_ptr,
-                proj_op, storage_for_values, storage_for_counters, depends);
+                proj_op, is_ascending, storage_for_values, storage_for_counters,
+                depends);
         }
         }
     }
@@ -1111,6 +1169,7 @@ struct subgroup_radix_sort
                                InputT *input_arr,
                                OutputT *output_arr,
                                const ProjT &proj_op,
+                               const bool is_ascending,
                                SLM_value_tag,
                                SLM_counter_tag,
                                const std::vector<sycl::event> &depends)
@@ -1216,28 +1275,63 @@ struct subgroup_radix_sort
 
                     sycl::group_barrier(ndit.get_group());
 
+                    if (is_ascending) {
 #pragma unroll
-                    for (uint16_t i = 0; i < block_size; ++i) {
-                        const uint16_t id = wi * block_size + i;
-                        constexpr uint16_t bin_mask = bin_count - 1;
-
-                        // points to the padded element, i.e. id is
-                        // in-range
-                        constexpr std::uint16_t
-                            default_out_of_range_bin_id = bin_mask;
-
-                        const uint16_t bin =
-                            (id < n) ? get_bucket_id<bin_mask>(
-                                           order_preserving_cast<
-                                               is_ascending>(
-                                               proj_op(values[i])),
-                                           begin_bit)
-                                     : default_out_of_range_bin_id;
-
-                        // counting and local offset calculation
-                        counters[i] = &pcounter[bin * wg_size];
-                        indices[i] = *counters[i];
-                        *counters[i] = indices[i] + 1;
+                        for (uint16_t i = 0; i < block_size; ++i) {
+                            const uint16_t id = wi * block_size + i;
+                            constexpr uint16_t bin_mask =
+                                bin_count - 1;
+
+                            // points to the padded element, i.e. id
+                            // is in-range
+                            constexpr std::uint16_t
+                                default_out_of_range_bin_id =
+                                    bin_mask;
+
+                            const uint16_t bin =
+                                (id < n)
+                                    ? get_bucket_id<bin_mask>(
+                                          order_preserving_cast<
+                                              /* is_ascending */
+                                              true>(
+                                              proj_op(values[i])),
+                                          begin_bit)
+                                    : default_out_of_range_bin_id;
+
+                            // counting and local offset calculation
+                            counters[i] = &pcounter[bin * wg_size];
+                            indices[i] = *counters[i];
+                            *counters[i] = indices[i] + 1;
+                        }
+                    }
+                    else {
+#pragma unroll
+                        for (uint16_t i = 0; i < block_size; ++i) {
+                            const uint16_t id = wi * block_size + i;
+                            constexpr uint16_t bin_mask =
+                                bin_count - 1;
+
+                            // points to the padded element, i.e. id
+                            // is in-range
+                            constexpr std::uint16_t
+                                default_out_of_range_bin_id =
+                                    bin_mask;
+
+                            const uint16_t bin =
+                                (id < n)
+                                    ? get_bucket_id<bin_mask>(
+                                          order_preserving_cast<
+                                              /* is_ascending */
+                                              false>(
+                                              proj_op(values[i])),
+                                          begin_bit)
+                                    : default_out_of_range_bin_id;
+
+                            // counting and local offset calculation
+                            counters[i] = &pcounter[bin * wg_size];
+                            indices[i] = *counters[i];
+                            *counters[i] = indices[i] + 1;
+                        }
                     }
 
                     sycl::group_barrier(ndit.get_group());
@@ -1351,19 +1445,19 @@ struct subgroup_radix_sort
     };
 };
 
-template <typename T, bool is_ascending>
-struct OneWorkGroupRadixSortKernel;
+template <typename T> struct OneWorkGroupRadixSortKernel;
 
 //-----------------------------------------------------------------------
 // radix sort: main function
 //-----------------------------------------------------------------------
-template <typename ValueT, typename ProjT, bool is_ascending>
+template <typename ValueT, typename ProjT>
 sycl::event parallel_radix_sort_impl(sycl::queue &exec_q,
                                      std::size_t n_iters,
                                      std::size_t n_to_sort,
                                      const ValueT *input_arr,
                                      ValueT *output_arr,
                                      const ProjT &proj_op,
+                                     const bool is_ascending,
                                      const std::vector<sycl::event> &depends)
 {
     assert(n_to_sort > 1);
@@ -1377,14 +1471,13 @@ sycl::event parallel_radix_sort_impl(sycl::queue &exec_q,
 
     sycl::event sort_ev{};
 
+    const auto &dev = exec_q.get_device();
     const auto max_wg_size =
-        exec_q.get_device()
-            .template get_info<sycl::info::device::max_work_group_size>();
+        dev.template get_info<sycl::info::device::max_work_group_size>();
 
     constexpr std::uint16_t ref_wg_size = 64;
     if (n_to_sort <= 16384 && ref_wg_size * 8 <= max_wg_size) {
-        using _RadixSortKernel =
-            OneWorkGroupRadixSortKernel<ValueT, is_ascending>;
+        using _RadixSortKernel = OneWorkGroupRadixSortKernel<ValueT>;
 
         if (n_to_sort <= 64 && ref_wg_size <= max_wg_size) {
             // wg_size * block_size == 64 * 1 * 1 == 64
@@ -1392,9 +1485,9 @@ sycl::event parallel_radix_sort_impl(sycl::queue &exec_q,
             constexpr std::uint16_t wg_size = ref_wg_size;
             constexpr std::uint16_t block_size = 1;
             sort_ev = subgroup_radix_sort<_RadixSortKernel, wg_size, block_size,
-                                          radix_bits, is_ascending>{}(
+                                          radix_bits>{}(
                 exec_q, n_iters, n_to_sort, input_arr, output_arr, proj_op,
-                depends);
+                is_ascending, depends);
         }
         else if (n_to_sort <= 128 && ref_wg_size * 2 <= max_wg_size) {
             // wg_size * block_size == 64 * 2 * 1 == 128
@@ -1402,9 +1495,9 @@ sycl::event parallel_radix_sort_impl(sycl::queue &exec_q,
             constexpr std::uint16_t wg_size = ref_wg_size * 2;
             constexpr std::uint16_t block_size = 1;
             sort_ev = subgroup_radix_sort<_RadixSortKernel, wg_size, block_size,
-                                          radix_bits, is_ascending>{}(
+                                          radix_bits>{}(
                 exec_q, n_iters, n_to_sort, input_arr, output_arr, proj_op,
-                depends);
+                is_ascending, depends);
         }
         else if (n_to_sort <= 256 && ref_wg_size * 2 <= max_wg_size) {
             // wg_size * block_size == 64 * 2 * 2 == 256
@@ -1412,9 +1505,9 @@ sycl::event parallel_radix_sort_impl(sycl::queue &exec_q,
             constexpr std::uint16_t wg_size = ref_wg_size * 2;
             constexpr std::uint16_t block_size = 2;
             sort_ev = subgroup_radix_sort<_RadixSortKernel, wg_size, block_size,
-                                          radix_bits, is_ascending>{}(
+                                          radix_bits>{}(
                 exec_q, n_iters, n_to_sort, input_arr, output_arr, proj_op,
-                depends);
+                is_ascending, depends);
         }
         else if (n_to_sort <= 512 && ref_wg_size * 2 <= max_wg_size) {
             // wg_size * block_size == 64 * 2 * 4 == 512
@@ -1422,9 +1515,9 @@ sycl::event parallel_radix_sort_impl(sycl::queue &exec_q,
             constexpr std::uint16_t wg_size = ref_wg_size * 2;
             constexpr std::uint16_t block_size = 4;
             sort_ev = subgroup_radix_sort<_RadixSortKernel, wg_size, block_size,
-                                          radix_bits, is_ascending>{}(
+                                          radix_bits>{}(
                 exec_q, n_iters, n_to_sort, input_arr, output_arr, proj_op,
-                depends);
+                is_ascending, depends);
         }
         else if (n_to_sort <= 1024 && ref_wg_size * 2 <= max_wg_size) {
             // wg_size * block_size == 64 * 2 * 8 == 1024
@@ -1432,9 +1525,9 @@ sycl::event parallel_radix_sort_impl(sycl::queue &exec_q,
             constexpr std::uint16_t wg_size = ref_wg_size * 2;
             constexpr std::uint16_t block_size = 8;
             sort_ev = subgroup_radix_sort<_RadixSortKernel, wg_size, block_size,
-                                          radix_bits, is_ascending>{}(
+                                          radix_bits>{}(
                 exec_q, n_iters, n_to_sort, input_arr, output_arr, proj_op,
-                depends);
+                is_ascending, depends);
         }
         else if (n_to_sort <= 2048 && ref_wg_size * 4 <= max_wg_size) {
             // wg_size * block_size == 64 * 4 * 8 == 2048
@@ -1442,9 +1535,9 @@ sycl::event parallel_radix_sort_impl(sycl::queue &exec_q,
             constexpr std::uint16_t wg_size = ref_wg_size * 4;
             constexpr std::uint16_t block_size = 8;
             sort_ev = subgroup_radix_sort<_RadixSortKernel, wg_size, block_size,
-                                          radix_bits, is_ascending>{}(
+                                          radix_bits>{}(
                 exec_q, n_iters, n_to_sort, input_arr, output_arr, proj_op,
-                depends);
+                is_ascending, depends);
         }
         else if (n_to_sort <= 4096 && ref_wg_size * 4 <= max_wg_size) {
             // wg_size * block_size == 64 * 4 * 16 == 4096
@@ -1452,9 +1545,9 @@ sycl::event parallel_radix_sort_impl(sycl::queue &exec_q,
             constexpr std::uint16_t wg_size = ref_wg_size * 4;
             constexpr std::uint16_t block_size = 16;
             sort_ev = subgroup_radix_sort<_RadixSortKernel, wg_size, block_size,
-                                          radix_bits, is_ascending>{}(
+                                          radix_bits>{}(
                 exec_q, n_iters, n_to_sort, input_arr, output_arr, proj_op,
-                depends);
+                is_ascending, depends);
         }
         else if (n_to_sort <= 8192 && ref_wg_size * 8 <= max_wg_size) {
             // wg_size * block_size == 64 * 8 * 16 == 8192
@@ -1462,9 +1555,9 @@ sycl::event parallel_radix_sort_impl(sycl::queue &exec_q,
             constexpr std::uint16_t wg_size = ref_wg_size * 8;
             constexpr std::uint16_t block_size = 16;
             sort_ev = subgroup_radix_sort<_RadixSortKernel, wg_size, block_size,
-                                          radix_bits, is_ascending>{}(
+                                          radix_bits>{}(
                 exec_q, n_iters, n_to_sort, input_arr, output_arr, proj_op,
-                depends);
+                is_ascending, depends);
         }
         else {
             // wg_size * block_size == 64 * 8 * 32 == 16384
@@ -1472,9 +1565,9 @@ sycl::event parallel_radix_sort_impl(sycl::queue &exec_q,
             constexpr std::uint16_t wg_size = ref_wg_size * 8;
             constexpr std::uint16_t block_size = 32;
             sort_ev = subgroup_radix_sort<_RadixSortKernel, wg_size, block_size,
-                                          radix_bits, is_ascending>{}(
+                                          radix_bits>{}(
                 exec_q, n_iters, n_to_sort, input_arr, output_arr, proj_op,
-                depends);
+                is_ascending, depends);
         }
     }
     else {
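The ladder above encodes a small table of one-work-group configurations; the thresholds and factors come straight from its comments. Collected in one place for readability (illustration only, since wg_size and block_size must be compile-time template arguments, which is exactly why the real code enumerates the branches):

    #include <cstddef>
    #include <cstdint>

    struct one_wg_config
    {
        std::size_t max_n;        // largest n_to_sort served by this tier
        std::uint16_t wg_factor;  // wg_size = 64 * wg_factor
        std::uint16_t block_size; // elements sorted per work-item
    };

    // wg_size * block_size == max_n for every tier
    constexpr one_wg_config one_wg_tiers[] = {
        {64, 1, 1},   {128, 2, 1},  {256, 2, 2},   {512, 2, 4},
        {1024, 2, 8}, {2048, 4, 8}, {4096, 4, 16}, {8192, 8, 16},
        {16384, 8, 32},
    };
    // each tier also requires the device to support 64 * wg_factor
    // work-items per group; otherwise the multi-pass path below runs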
@@ -1512,11 +1605,11 @@ sycl::event parallel_radix_sort_impl(sycl::queue &exec_q,
 
         if constexpr (std::is_same_v<ValueT, bool>) {
 
             sort_ev = parallel_radix_sort_iteration_step<
-                radix_bits, is_ascending,
-                /*even=*/true>::submit(exec_q, n_iters, n_segments,
-                                       zero_radix_iter, n_to_sort, input_arr,
-                                       output_arr, n_counts, count_ptr, proj_op,
-                                       depends);
+                radix_bits, /*even=*/true>::submit(exec_q, n_iters, n_segments,
+                                                   zero_radix_iter, n_to_sort,
+                                                   input_arr, output_arr,
+                                                   n_counts, count_ptr, proj_op,
+                                                   is_ascending, depends);
 
             sort_ev = exec_q.submit([=](sycl::handler &cgh) {
                 cgh.depends_on(sort_ev);
@@ -1542,33 +1635,30 @@ sycl::event parallel_radix_sort_impl(sycl::queue &exec_q,
             assert(radix_iters > 0);
 
             sort_ev = parallel_radix_sort_iteration_step<
-                radix_bits, is_ascending, /*even=*/true>::submit(exec_q, n_iters,
-                                                                 n_segments,
-                                                                 zero_radix_iter,
-                                                                 n_to_sort,
-                                                                 input_arr, tmp_arr,
-                                                                 n_counts,
-                                                                 count_ptr, proj_op,
-                                                                 depends);
+                radix_bits, /*even=*/true>::submit(exec_q, n_iters, n_segments,
+                                                   zero_radix_iter, n_to_sort,
+                                                   input_arr, tmp_arr, n_counts,
+                                                   count_ptr, proj_op, is_ascending,
+                                                   depends);
 
             for (std::uint32_t radix_iter = 1; radix_iter < radix_iters;
                  ++radix_iter)
             {
                 if (radix_iter % 2 == 0) {
                     sort_ev = parallel_radix_sort_iteration_step<
-                        radix_bits, is_ascending,
+                        radix_bits,
                         /*even=*/true>::submit(exec_q, n_iters, n_segments,
                                                radix_iter, n_to_sort,
                                                output_arr, tmp_arr, n_counts,
-                                               count_ptr, proj_op, {sort_ev});
+                                               count_ptr, proj_op, is_ascending,
+                                               {sort_ev});
                 }
                 else {
                     sort_ev = parallel_radix_sort_iteration_step<
-                        radix_bits, is_ascending,
+                        radix_bits,
                         /*even=*/false>::submit(exec_q, n_iters, n_segments,
                                                 radix_iter, n_to_sort, tmp_arr,
                                                 output_arr, n_counts, count_ptr,
-                                                proj_op, {sort_ev});
+                                                proj_op, is_ascending,
+                                                {sort_ev});
                 }
             }
@@ -1621,9 +1711,10 @@ template <typename IndexT, typename ValueT, typename ProjT> struct IndexedProj
 
 } // end of namespace radix_sort_details
 
-template <typename argTy, bool sort_ascending>
+template <typename argTy>
 sycl::event
 radix_sort_axis1_contig_impl(sycl::queue &exec_q,
+                             const bool sort_ascending,
                              // number of sub-arrays to sort (num. of rows in a
                              // matrix when sorting over rows)
                              size_t iter_nelems,
@@ -1647,22 +1738,23 @@ radix_sort_axis1_contig_impl(sycl::queue &exec_q,
     constexpr Proj proj_op{};
 
     sycl::event radix_sort_ev =
-        radix_sort_details::parallel_radix_sort_impl<argTy, Proj,
-                                                     sort_ascending>(
-            exec_q, iter_nelems, sort_nelems, arg_tp, res_tp, proj_op, depends);
+        radix_sort_details::parallel_radix_sort_impl(
+            exec_q, iter_nelems, sort_nelems, arg_tp, res_tp, proj_op,
+            sort_ascending, depends);
 
     return radix_sort_ev;
 }
 
-template <typename T1, typename T2, bool b>
+template <typename T1, typename T2>
 class populate_indexed_data_for_radix_sort_krn;
 
-template <typename T1, typename T2, bool b>
+template <typename T1, typename T2>
 class index_write_out_for_radix_sort_krn;
 
-template <typename argTy, typename IndexTy, bool sort_ascending>
+template <typename argTy, typename IndexTy>
 sycl::event
 radix_argsort_axis1_contig_impl(sycl::queue &exec_q,
+                                const bool sort_ascending,
                                 // number of sub-arrays to sort (num. of
                                 // rows in a matrix when sorting over rows)
                                 size_t iter_nelems,
@@ -1704,8 +1796,7 @@ radix_argsort_axis1_contig_impl(sycl::queue &exec_q,
         cgh.depends_on(depends);
 
         using KernelName =
-            populate_indexed_data_for_radix_sort_krn<IndexTy, argTy,
-                                                     sort_ascending>;
+            populate_indexed_data_for_radix_sort_krn<IndexTy, argTy>;
 
         cgh.parallel_for<KernelName>(
             sycl::range<1>(total_nelems), [=](sycl::id<1> id) {
@@ -1716,16 +1807,14 @@ radix_argsort_axis1_contig_impl(sycl::queue &exec_q,
     });
 
     sycl::event radix_sort_ev =
-        radix_sort_details::parallel_radix_sort_impl<ValueIndexT, IndexedProjT,
-                                                     sort_ascending>(
+        radix_sort_details::parallel_radix_sort_impl(
             exec_q, iter_nelems, sort_nelems, indexed_data_tp, temp_tp, proj_op,
-            {populate_indexed_data_ev});
+            sort_ascending, {populate_indexed_data_ev});
 
     sycl::event write_out_ev = exec_q.submit([&](sycl::handler &cgh) {
         cgh.depends_on(radix_sort_ev);
 
-        using KernelName =
-            index_write_out_for_radix_sort_krn<argTy, IndexTy, sort_ascending>;
+        using KernelName = index_write_out_for_radix_sort_krn<argTy, IndexTy>;
 
         cgh.parallel_for<KernelName>(
             sycl::range<1>(total_nelems),
@@ -1743,12 +1832,12 @@ radix_argsort_axis1_contig_impl(sycl::queue &exec_q,
     return cleanup_ev;
 }
 
-template <typename T1, typename T2, bool b>
-class iota_for_radix_sort_krn;
+template <typename T1, typename T2> class iota_for_radix_sort_krn;
 
-template <typename argTy, typename IndexTy, bool sort_ascending>
+template <typename argTy, typename IndexTy>
 sycl::event
 radix_argsort_axis1_contig_alt_impl(sycl::queue &exec_q,
+                                    const bool sort_ascending,
                                     // number of sub-arrays to sort (num. of
                                     // rows in a matrix when sorting over rows)
                                     size_t iter_nelems,
@@ -1785,8 +1874,7 @@ radix_argsort_axis1_contig_alt_impl(sycl::queue &exec_q,
     sycl::event iota_ev = exec_q.submit([&](sycl::handler &cgh) {
         cgh.depends_on(depends);
 
-        using KernelName =
-            iota_for_radix_sort_krn<argTy, IndexTy, sort_ascending>;
+        using KernelName = iota_for_radix_sort_krn<argTy, IndexTy>;
 
         cgh.parallel_for<KernelName>(
             sycl::range<1>(total_nelems), [=](sycl::id<1> id) {
@@ -1797,16 +1885,14 @@ radix_argsort_axis1_contig_alt_impl(sycl::queue &exec_q,
     });
 
     sycl::event radix_sort_ev =
-        radix_sort_details::parallel_radix_sort_impl<IndexTy, IndexedProjT,
-                                                     sort_ascending>(
+        radix_sort_details::parallel_radix_sort_impl(
             exec_q, iter_nelems, sort_nelems, workspace, res_tp, proj_op,
-            {iota_ev});
+            sort_ascending, {iota_ev});
 
     sycl::event map_back_ev = exec_q.submit([&](sycl::handler &cgh) {
         cgh.depends_on(radix_sort_ev);
 
-        using KernelName =
-            index_write_out_for_radix_sort_krn<argTy, IndexTy, sort_ascending>;
+        using KernelName = index_write_out_for_radix_sort_krn<argTy, IndexTy>;
 
         cgh.parallel_for<KernelName>(
             sycl::range<1>(total_nelems), [=](sycl::id<1> id) {
diff --git a/dpctl/tensor/libtensor/source/sorting/radix_argsort.cpp b/dpctl/tensor/libtensor/source/sorting/radix_argsort.cpp
index 74ab28c684..a98e5677b2 100644
--- a/dpctl/tensor/libtensor/source/sorting/radix_argsort.cpp
+++ b/dpctl/tensor/libtensor/source/sorting/radix_argsort.cpp
@@ -38,6 +38,7 @@
 #include "utils/sycl_alloc_utils.hpp"
 #include "utils/type_dispatch.hpp"
 
+#include "kernels/dpctl_tensor_types.hpp"
 #include "kernels/sorting/radix_sort.hpp"
 #include "kernels/sorting/sort_impl_fn_ptr_t.hpp"
 
@@ -64,6 +65,31 @@ static sort_contig_fn_ptr_t
     descending_radix_argsort_contig_dispatch_table[td_ns::num_types]
                                                   [td_ns::num_types];
 
+namespace
+{
+
+template <bool is_ascending, typename argTy, typename IndexTy>
+sycl::event argsort_axis1_contig_caller(sycl::queue &q,
+                                        size_t iter_nelems,
+                                        size_t sort_nelems,
+                                        const char *arg_cp,
+                                        char *res_cp,
+                                        ssize_t iter_arg_offset,
+                                        ssize_t iter_res_offset,
+                                        ssize_t sort_arg_offset,
+                                        ssize_t sort_res_offset,
+                                        const std::vector<sycl::event> &depends)
+{
+    using dpctl::tensor::kernels::radix_argsort_axis1_contig_alt_impl;
+
+    return radix_argsort_axis1_contig_alt_impl<argTy, IndexTy>(
+        q, is_ascending, iter_nelems, sort_nelems, arg_cp, res_cp,
+        iter_arg_offset, iter_res_offset, sort_arg_offset, sort_res_offset,
+        depends);
+}
+
+} // end of anonymous namespace
+
 template <typename fnT, typename argTy, typename IndexTy>
 struct AscendingRadixArgSortContigFactory
 {
@@ -73,9 +99,8 @@ struct AscendingRadixArgSortContigFactory
             (std::is_same_v<IndexTy, std::int64_t> ||
              std::is_same_v<IndexTy, std::uint64_t>))
         {
-            using dpctl::tensor::kernels::radix_argsort_axis1_contig_alt_impl;
-            return radix_argsort_axis1_contig_alt_impl<argTy, IndexTy,
-                                                       /*ascending*/ true>;
+            return argsort_axis1_contig_caller<
+                /*ascending*/ true, argTy, IndexTy>;
         }
         else {
             return nullptr;
@@ -92,9 +117,8 @@ struct DescendingRadixArgSortContigFactory
             (std::is_same_v<IndexTy, std::int64_t> ||
              std::is_same_v<IndexTy, std::uint64_t>))
        {
-            using dpctl::tensor::kernels::radix_argsort_axis1_contig_alt_impl;
-            return radix_argsort_axis1_contig_alt_impl<argTy, IndexTy,
-                                                       /*ascending*/ false>;
+            return argsort_axis1_contig_caller<
+                /*ascending*/ false, argTy, IndexTy>;
         }
         else {
             return nullptr;
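The anonymous-namespace caller above exists so the dispatch tables keep a single function-pointer type while the direction flag moves into the call: each factory bakes the flag into a distinct instantiation with a uniform signature. A condensed sketch of the pattern with simplified types (dpctl's real tables are indexed by type id and use the signature from sort_impl_fn_ptr_t.hpp):

    #include <cstddef>

    // uniform entry-point type stored in the dispatch table (simplified)
    using sort_contig_fn_ptr_t = void (*)(const void *src, void *dst,
                                          std::size_t n);

    template <bool is_ascending, typename T>
    void sort_contig_caller(const void *src, void *dst, std::size_t n)
    {
        // cast src/dst to T* here and forward to the templated launcher,
        // passing is_ascending as an ordinary runtime argument
    }

    template <typename T> struct AscendingSortFactory
    {
        sort_contig_fn_ptr_t get() { return sort_contig_caller<true, T>; }
    };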
diff --git a/dpctl/tensor/libtensor/source/sorting/radix_sort.cpp b/dpctl/tensor/libtensor/source/sorting/radix_sort.cpp
index 83afe7c6ff..09eb75d1f1 100644
--- a/dpctl/tensor/libtensor/source/sorting/radix_sort.cpp
+++ b/dpctl/tensor/libtensor/source/sorting/radix_sort.cpp
@@ -38,6 +38,7 @@
 #include "utils/output_validation.hpp"
 #include "utils/type_dispatch.hpp"
 
+#include "kernels/dpctl_tensor_types.hpp"
 #include "kernels/sorting/radix_sort.hpp"
 #include "kernels/sorting/sort_impl_fn_ptr_t.hpp"
 
@@ -61,13 +62,37 @@ static sort_contig_fn_ptr_t
 static sort_contig_fn_ptr_t
     descending_radix_sort_contig_dispatch_vector[td_ns::num_types];
 
+namespace
+{
+
+template <bool is_ascending, typename argTy>
+sycl::event sort_axis1_contig_caller(sycl::queue &q,
+                                     size_t iter_nelems,
+                                     size_t sort_nelems,
+                                     const char *arg_cp,
+                                     char *res_cp,
+                                     ssize_t iter_arg_offset,
+                                     ssize_t iter_res_offset,
+                                     ssize_t sort_arg_offset,
+                                     ssize_t sort_res_offset,
+                                     const std::vector<sycl::event> &depends)
+{
+    using dpctl::tensor::kernels::radix_sort_axis1_contig_impl;
+
+    return radix_sort_axis1_contig_impl<argTy>(
+        q, is_ascending, iter_nelems, sort_nelems, arg_cp, res_cp,
+        iter_arg_offset, iter_res_offset, sort_arg_offset, sort_res_offset,
+        depends);
+}
+
+} // end of anonymous namespace
+
 template <typename fnT, typename argTy> struct AscendingRadixSortContigFactory
 {
     fnT get()
     {
         if constexpr (RadixSortSupportVector<argTy>::is_defined) {
-            using dpctl::tensor::kernels::radix_sort_axis1_contig_impl;
-            return radix_sort_axis1_contig_impl<argTy, /*ascending*/ true>;
+            return sort_axis1_contig_caller</*ascending*/ true, argTy>;
         }
         else {
             return nullptr;
@@ -80,8 +105,7 @@ template <typename fnT, typename argTy> struct DescendingRadixSortContigFactory
     fnT get()
     {
         if constexpr (RadixSortSupportVector<argTy>::is_defined) {
-            using dpctl::tensor::kernels::radix_sort_axis1_contig_impl;
-            return radix_sort_axis1_contig_impl<argTy, /*ascending*/ false>;
+            return sort_axis1_contig_caller</*ascending*/ false, argTy>;
         }
         else {
             return nullptr;
diff --git a/dpctl/tensor/libtensor/source/tensor_sorting.cpp b/dpctl/tensor/libtensor/source/tensor_sorting.cpp
index 80351ed102..6f2f965285 100644
--- a/dpctl/tensor/libtensor/source/tensor_sorting.cpp
+++ b/dpctl/tensor/libtensor/source/tensor_sorting.cpp
@@ -29,16 +29,11 @@
 #include "sorting/searchsorted.hpp"
 #include "sorting/sort.hpp"
 
-#include "sorting/radix_argsort.hpp"
-#include "sorting/radix_sort.hpp"
-
 namespace py = pybind11;
 
 PYBIND11_MODULE(_tensor_sorting_impl, m)
 {
     dpctl::tensor::py_internal::init_sort_functions(m);
-    dpctl::tensor::py_internal::init_radix_sort_functions(m);
     dpctl::tensor::py_internal::init_argsort_functions(m);
-    dpctl::tensor::py_internal::init_radix_argsort_functions(m);
     dpctl::tensor::py_internal::init_searchsorted_functions(m);
 }
diff --git a/dpctl/tensor/libtensor/source/tensor_sorting_radix.cpp b/dpctl/tensor/libtensor/source/tensor_sorting_radix.cpp
new file mode 100644
index 0000000000..b5ef49e0ac
--- /dev/null
+++ b/dpctl/tensor/libtensor/source/tensor_sorting_radix.cpp
@@ -0,0 +1,37 @@
+//===-- tensor_sorting_radix.cpp -                       -----*-C++-*-/===//
+// Implementation of _tensor_sorting_radix_impl module
+//
+// Data Parallel Control (dpctl)
+//
+// Copyright 2020-2024 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//    http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file
+/// This file defines functions of the _tensor_sorting_radix_impl extension
+//===----------------------------------------------------------------------===//
+
+#include <pybind11/pybind11.h>
+
+#include "sorting/radix_argsort.hpp"
+#include "sorting/radix_sort.hpp"
+
+namespace py = pybind11;
+
+PYBIND11_MODULE(_tensor_sorting_radix_impl, m)
+{
+    dpctl::tensor::py_internal::init_radix_sort_functions(m);
+    dpctl::tensor::py_internal::init_radix_argsort_functions(m);
+}