Skip to content

Commit ec6a930

Browse files
Moved radix sort Python API to dedicated module, _tensor_sorting_radix_impl
With this change, _tensor_sorting_impl goes back to 17MB, and _tensor_sorting_radix_impl is 30MB. Splitting the module should greatly reduce the memory required for linking, speed up the build, and provide better parallelisation opportunities for the build job. On my Core i7, the build time dropped from 45 minutes to 33 minutes.
1 parent bbe1019 commit ec6a930

File tree

4 files changed

+58
-12
lines changed

4 files changed

+58
-12
lines changed

dpctl/tensor/CMakeLists.txt

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -114,9 +114,11 @@ set(_reduction_sources
114114
set(_sorting_sources
115115
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/sort.cpp
116116
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/argsort.cpp
117+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/searchsorted.cpp
118+
)
119+
set(_sorting_radix_sources
117120
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/radix_sort.cpp
118121
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/radix_argsort.cpp
119-
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/searchsorted.cpp
120122
)
121123
set(_static_lib_sources
122124
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/simplify_iteration_space.cpp
@@ -153,6 +155,10 @@ set(_tensor_sorting_impl_sources
153155
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/tensor_sorting.cpp
154156
${_sorting_sources}
155157
)
158+
set(_tensor_sorting_radix_impl_sources
159+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/tensor_sorting_radix.cpp
160+
${_sorting_radix_sources}
161+
)
156162
set(_linalg_sources
157163
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/elementwise_functions_type_utils.cpp
158164
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/linalg_functions/dot.cpp
@@ -162,10 +168,10 @@ set(_tensor_linalg_impl_sources
162168
${_linalg_sources}
163169
)
164170
set(_accumulator_sources
165-
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/accumulators/accumulators_common.cpp
166-
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/accumulators/cumulative_logsumexp.cpp
167-
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/accumulators/cumulative_prod.cpp
168-
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/accumulators/cumulative_sum.cpp
171+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/accumulators/accumulators_common.cpp
172+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/accumulators/cumulative_logsumexp.cpp
173+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/accumulators/cumulative_prod.cpp
174+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/accumulators/cumulative_sum.cpp
169175
)
170176
set(_tensor_accumulation_impl_sources
171177
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/tensor_accumulation.cpp
@@ -207,6 +213,12 @@ add_sycl_to_target(TARGET ${python_module_name} SOURCES ${_tensor_sorting_impl_s
207213
target_link_libraries(${python_module_name} PRIVATE ${_static_lib_trgt})
208214
list(APPEND _py_trgts ${python_module_name})
209215

216+
set(python_module_name _tensor_sorting_radix_impl)
217+
pybind11_add_module(${python_module_name} MODULE ${_tensor_sorting_radix_impl_sources})
218+
add_sycl_to_target(TARGET ${python_module_name} SOURCES ${_tensor_sorting_radix_impl_sources})
219+
target_link_libraries(${python_module_name} PRIVATE ${_static_lib_trgt})
220+
list(APPEND _py_trgts ${python_module_name})
221+
210222
set(python_module_name _tensor_linalg_impl)
211223
pybind11_add_module(${python_module_name} MODULE ${_tensor_linalg_impl_sources})
212224
add_sycl_to_target(TARGET ${python_module_name} SOURCES ${_tensor_linalg_impl_sources})

dpctl/tensor/_sorting.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,15 @@
2222
from ._tensor_sorting_impl import (
2323
_argsort_ascending,
2424
_argsort_descending,
25+
_sort_ascending,
26+
_sort_descending,
27+
)
28+
from ._tensor_sorting_radix_impl import (
2529
_radix_argsort_ascending,
2630
_radix_argsort_descending,
2731
_radix_sort_ascending,
2832
_radix_sort_descending,
2933
_radix_sort_dtype_supported,
30-
_sort_ascending,
31-
_sort_descending,
3234
)
3335

3436
__all__ = ["sort", "argsort"]

dpctl/tensor/libtensor/source/tensor_sorting.cpp

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,16 +29,11 @@
2929
#include "sorting/searchsorted.hpp"
3030
#include "sorting/sort.hpp"
3131

32-
#include "sorting/radix_argsort.hpp"
33-
#include "sorting/radix_sort.hpp"
34-
3532
namespace py = pybind11;
3633

3734
PYBIND11_MODULE(_tensor_sorting_impl, m)
3835
{
3936
dpctl::tensor::py_internal::init_sort_functions(m);
40-
dpctl::tensor::py_internal::init_radix_sort_functions(m);
4137
dpctl::tensor::py_internal::init_argsort_functions(m);
42-
dpctl::tensor::py_internal::init_radix_argsort_functions(m);
4338
dpctl::tensor::py_internal::init_searchsorted_functions(m);
4439
}
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
//===-- tensor_sorting_radix.cpp - -----*-C++-*-/===//
2+
// Implementation of _tensor_sorting_radix_impl module
3+
//
4+
// Data Parallel Control (dpctl)
5+
//
6+
// Copyright 2020-2024 Intel Corporation
7+
//
8+
// Licensed under the Apache License, Version 2.0 (the "License");
9+
// you may not use this file except in compliance with the License.
10+
// You may obtain a copy of the License at
11+
//
12+
// http://www.apache.org/licenses/LICENSE-2.0
13+
//
14+
// Unless required by applicable law or agreed to in writing, software
15+
// distributed under the License is distributed on an "AS IS" BASIS,
16+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17+
// See the License for the specific language governing permissions and
18+
// limitations under the License.
19+
//
20+
//===----------------------------------------------------------------------===//
21+
///
22+
/// \file
23+
/// This file defines functions of dpctl.tensor._tensor_sorting_radix_impl extension
24+
//===----------------------------------------------------------------------===//
25+
26+
#include <pybind11/pybind11.h>
27+
28+
#include "sorting/radix_argsort.hpp"
29+
#include "sorting/radix_sort.hpp"
30+
31+
namespace py = pybind11;
32+
33+
PYBIND11_MODULE(_tensor_sorting_radix_impl, m)
34+
{
35+
dpctl::tensor::py_internal::init_radix_sort_functions(m);
36+
dpctl::tensor::py_internal::init_radix_argsort_functions(m);
37+
}

0 commit comments

Comments
 (0)