Skip to content

Commit 4b3e736

Browse files
committed
Fixes correctness regression in search functions
``py_search_over_axis`` no longer calls the ``axis1`` contiguous variant ``py_search_over_axis`` now only calls ``axis0`` variant wh
1 parent 9131925 commit 4b3e736

File tree

1 file changed

+3
-10
lines changed

1 file changed

+3
-10
lines changed

dpctl/tensor/libtensor/source/reductions/reduction_over_axis.hpp

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -874,14 +874,11 @@ std::pair<sycl::event, sycl::event> py_search_over_axis(
874874
int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum);
875875

876876
// handle special case when both reduction and iteration are 1D contiguous
877-
// and can be done with atomics
878877
bool is_src_c_contig = src.is_c_contiguous();
879878
bool is_dst_c_contig = dst.is_c_contiguous();
880879
bool is_src_f_contig = src.is_f_contiguous();
881880

882-
if ((is_src_c_contig && is_dst_c_contig) ||
883-
(is_src_f_contig && dst_nelems == 1))
884-
{
881+
if (is_src_c_contig && is_dst_c_contig) {
885882
auto fn = axis1_contig_dispatch_table[src_typeid][dst_typeid];
886883
if (fn != nullptr) {
887884
size_t iter_nelems = dst_nelems;
@@ -903,9 +900,7 @@ std::pair<sycl::event, sycl::event> py_search_over_axis(
903900
reduction_over_axis_contig_ev);
904901
}
905902
}
906-
else if (is_src_f_contig &&
907-
((is_dst_c_contig && dst_nd == 1) || dst.is_f_contiguous()))
908-
{
903+
else if (is_src_f_contig && dst_nd == 1) {
909904
auto fn = axis0_contig_dispatch_table[src_typeid][dst_typeid];
910905
if (fn != nullptr) {
911906
size_t iter_nelems = dst_nelems;
@@ -983,11 +978,9 @@ std::pair<sycl::event, sycl::event> py_search_over_axis(
983978
if ((reduction_nd == 1) && (iteration_nd == 1)) {
984979
bool mat_reduce_over_axis1 = false;
985980
bool mat_reduce_over_axis0 = false;
986-
bool array_reduce_all_elems = false;
987981
size_t iter_nelems = dst_nelems;
988982

989983
if (compact_reduction_src_strides[0] == 1) {
990-
array_reduce_all_elems = (simplified_iteration_shape[0] == 1);
991984
mat_reduce_over_axis1 =
992985
(simplified_iteration_dst_strides[0] == 1) &&
993986
(static_cast<size_t>(simplified_iteration_src_strides[0]) ==
@@ -1000,7 +993,7 @@ std::pair<sycl::event, sycl::event> py_search_over_axis(
1000993
(simplified_iteration_src_strides[0] == 1);
1001994
}
1002995

1003-
if (mat_reduce_over_axis1 || array_reduce_all_elems) {
996+
if (mat_reduce_over_axis1) {
1004997
auto fn = axis1_contig_dispatch_table[src_typeid][dst_typeid];
1005998
if (fn != nullptr) {
1006999
sycl::event reduction_over_axis1_contig_ev =

0 commit comments

Comments
 (0)