Skip to content

Commit c2f8486

Browse files
Add Python API to check if radix sort is supported for given dtype
1 parent e274c2d commit c2f8486

File tree

1 file changed

+16
-0
lines changed

1 file changed

+16
-0
lines changed

dpctl/tensor/libtensor/source/sorting/radix_sort.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
//===--------------------------------------------------------------------===//
2424

2525
#include <cstdint>
26+
#include <exception>
2627
#include <utility>
2728
#include <vector>
2829

@@ -105,6 +106,19 @@ void init_radix_sort_dispatch_vectors(void)
105106
dtv2.populate_dispatch_vector(descending_radix_sort_contig_dispatch_vector);
106107
}
107108

109+
bool py_radix_sort_defined(int typenum)
110+
{
111+
const auto &array_types = td_ns::usm_ndarray_types();
112+
113+
try {
114+
int type_id = array_types.typenum_to_lookup_id(typenum);
115+
return (nullptr !=
116+
ascending_radix_sort_contig_dispatch_vector[type_id]);
117+
} catch (const std::exception &e) {
118+
return false;
119+
}
120+
}
121+
108122
void init_radix_sort_functions(py::module_ m)
109123
{
110124
dpctl::tensor::py_internal::init_radix_sort_dispatch_vectors();
@@ -139,6 +153,8 @@ void init_radix_sort_functions(py::module_ m)
139153
py::arg("trailing_dims_to_sort"), py::arg("dst"),
140154
py::arg("sycl_queue"), py::arg("depends") = py::list());
141155

156+
m.def("_radix_sort_dtype_supported", py_radix_sort_defined);
157+
142158
return;
143159
}
144160

0 commit comments

Comments
 (0)