Skip to content

Commit d63dd70

Browse files
Merge pull request #1883 from IntelPython/change-descending-from-template-parameter-to-an-argument
Change descending from template parameter to an argument
2 parents 3c05c1b + ec6a930 commit d63dd70

File tree

8 files changed

+353
-174
lines changed

8 files changed

+353
-174
lines changed

dpctl/tensor/CMakeLists.txt

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -114,9 +114,11 @@ set(_reduction_sources
114114
set(_sorting_sources
115115
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/sort.cpp
116116
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/argsort.cpp
117+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/searchsorted.cpp
118+
)
119+
set(_sorting_radix_sources
117120
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/radix_sort.cpp
118121
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/radix_argsort.cpp
119-
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/searchsorted.cpp
120122
)
121123
set(_static_lib_sources
122124
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/simplify_iteration_space.cpp
@@ -153,6 +155,10 @@ set(_tensor_sorting_impl_sources
153155
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/tensor_sorting.cpp
154156
${_sorting_sources}
155157
)
158+
set(_tensor_sorting_radix_impl_sources
159+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/tensor_sorting_radix.cpp
160+
${_sorting_radix_sources}
161+
)
156162
set(_linalg_sources
157163
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/elementwise_functions_type_utils.cpp
158164
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/linalg_functions/dot.cpp
@@ -162,10 +168,10 @@ set(_tensor_linalg_impl_sources
162168
${_linalg_sources}
163169
)
164170
set(_accumulator_sources
165-
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/accumulators/accumulators_common.cpp
166-
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/accumulators/cumulative_logsumexp.cpp
167-
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/accumulators/cumulative_prod.cpp
168-
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/accumulators/cumulative_sum.cpp
171+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/accumulators/accumulators_common.cpp
172+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/accumulators/cumulative_logsumexp.cpp
173+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/accumulators/cumulative_prod.cpp
174+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/accumulators/cumulative_sum.cpp
169175
)
170176
set(_tensor_accumulation_impl_sources
171177
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/tensor_accumulation.cpp
@@ -207,6 +213,12 @@ add_sycl_to_target(TARGET ${python_module_name} SOURCES ${_tensor_sorting_impl_s
207213
target_link_libraries(${python_module_name} PRIVATE ${_static_lib_trgt})
208214
list(APPEND _py_trgts ${python_module_name})
209215

216+
set(python_module_name _tensor_sorting_radix_impl)
217+
pybind11_add_module(${python_module_name} MODULE ${_tensor_sorting_radix_impl_sources})
218+
add_sycl_to_target(TARGET ${python_module_name} SOURCES ${_tensor_sorting_radix_impl_sources})
219+
target_link_libraries(${python_module_name} PRIVATE ${_static_lib_trgt})
220+
list(APPEND _py_trgts ${python_module_name})
221+
210222
set(python_module_name _tensor_linalg_impl)
211223
pybind11_add_module(${python_module_name} MODULE ${_tensor_linalg_impl_sources})
212224
add_sycl_to_target(TARGET ${python_module_name} SOURCES ${_tensor_linalg_impl_sources})

dpctl/tensor/_sorting.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,15 @@
2222
from ._tensor_sorting_impl import (
2323
_argsort_ascending,
2424
_argsort_descending,
25+
_sort_ascending,
26+
_sort_descending,
27+
)
28+
from ._tensor_sorting_radix_impl import (
2529
_radix_argsort_ascending,
2630
_radix_argsort_descending,
2731
_radix_sort_ascending,
2832
_radix_sort_descending,
2933
_radix_sort_dtype_supported,
30-
_sort_ascending,
31-
_sort_descending,
3234
)
3335

3436
__all__ = ["sort", "argsort"]

dpctl/tensor/libtensor/include/kernels/sorting/merge_sort.hpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -807,8 +807,7 @@ sycl::event stable_argsort_axis1_contig_impl(
807807
const IndexComp<IndexTy, argTy, ValueComp> index_comp{arg_tp, ValueComp{}};
808808

809809
static constexpr size_t determine_automatically = 0;
810-
size_t sorted_block_size =
811-
(sort_nelems >= 512) ? 512 : determine_automatically;
810+
size_t sorted_block_size = determine_automatically;
812811

813812
const size_t total_nelems = iter_nelems * sort_nelems;
814813

0 commit comments

Comments
 (0)