File tree Expand file tree Collapse file tree 1 file changed +16
-0
lines changed
dpctl/tensor/libtensor/source/sorting Expand file tree Collapse file tree 1 file changed +16
-0
lines changed Original file line number Diff line number Diff line change 23
23
// ===--------------------------------------------------------------------===//
24
24
25
25
#include < cstdint>
26
+ #include < exception>
26
27
#include < utility>
27
28
#include < vector>
28
29
@@ -105,6 +106,19 @@ void init_radix_sort_dispatch_vectors(void)
105
106
dtv2.populate_dispatch_vector (descending_radix_sort_contig_dispatch_vector);
106
107
}
107
108
109
+ bool py_radix_sort_defined (int typenum)
110
+ {
111
+ const auto &array_types = td_ns::usm_ndarray_types ();
112
+
113
+ try {
114
+ int type_id = array_types.typenum_to_lookup_id (typenum);
115
+ return (nullptr !=
116
+ ascending_radix_sort_contig_dispatch_vector[type_id]);
117
+ } catch (const std::exception &e) {
118
+ return false ;
119
+ }
120
+ }
121
+
108
122
void init_radix_sort_functions (py::module_ m)
109
123
{
110
124
dpctl::tensor::py_internal::init_radix_sort_dispatch_vectors ();
@@ -139,6 +153,8 @@ void init_radix_sort_functions(py::module_ m)
139
153
py::arg (" trailing_dims_to_sort" ), py::arg (" dst" ),
140
154
py::arg (" sycl_queue" ), py::arg (" depends" ) = py::list ());
141
155
156
+ m.def (" _radix_sort_dtype_supported" , py_radix_sort_defined);
157
+
142
158
return ;
143
159
}
144
160
You can’t perform that action at this time.
0 commit comments