@@ -874,14 +874,11 @@ std::pair<sycl::event, sycl::event> py_search_over_axis(
874
874
int dst_typeid = array_types.typenum_to_lookup_id (dst_typenum);
875
875
876
876
// handle special case when both reduction and iteration are 1D contiguous
877
- // and can be done with atomics
878
877
bool is_src_c_contig = src.is_c_contiguous ();
879
878
bool is_dst_c_contig = dst.is_c_contiguous ();
880
879
bool is_src_f_contig = src.is_f_contiguous ();
881
880
882
- if ((is_src_c_contig && is_dst_c_contig) ||
883
- (is_src_f_contig && dst_nelems == 1 ))
884
- {
881
+ if (is_src_c_contig && is_dst_c_contig) {
885
882
auto fn = axis1_contig_dispatch_table[src_typeid][dst_typeid];
886
883
if (fn != nullptr ) {
887
884
size_t iter_nelems = dst_nelems;
@@ -903,9 +900,7 @@ std::pair<sycl::event, sycl::event> py_search_over_axis(
903
900
reduction_over_axis_contig_ev);
904
901
}
905
902
}
906
- else if (is_src_f_contig &&
907
- ((is_dst_c_contig && dst_nd == 1 ) || dst.is_f_contiguous ()))
908
- {
903
+ else if (is_src_f_contig && dst_nd == 1 ) {
909
904
auto fn = axis0_contig_dispatch_table[src_typeid][dst_typeid];
910
905
if (fn != nullptr ) {
911
906
size_t iter_nelems = dst_nelems;
@@ -983,11 +978,9 @@ std::pair<sycl::event, sycl::event> py_search_over_axis(
983
978
if ((reduction_nd == 1 ) && (iteration_nd == 1 )) {
984
979
bool mat_reduce_over_axis1 = false ;
985
980
bool mat_reduce_over_axis0 = false ;
986
- bool array_reduce_all_elems = false ;
987
981
size_t iter_nelems = dst_nelems;
988
982
989
983
if (compact_reduction_src_strides[0 ] == 1 ) {
990
- array_reduce_all_elems = (simplified_iteration_shape[0 ] == 1 );
991
984
mat_reduce_over_axis1 =
992
985
(simplified_iteration_dst_strides[0 ] == 1 ) &&
993
986
(static_cast <size_t >(simplified_iteration_src_strides[0 ]) ==
@@ -1000,7 +993,7 @@ std::pair<sycl::event, sycl::event> py_search_over_axis(
1000
993
(simplified_iteration_src_strides[0 ] == 1 );
1001
994
}
1002
995
1003
- if (mat_reduce_over_axis1 || array_reduce_all_elems ) {
996
+ if (mat_reduce_over_axis1) {
1004
997
auto fn = axis1_contig_dispatch_table[src_typeid][dst_typeid];
1005
998
if (fn != nullptr ) {
1006
999
sycl::event reduction_over_axis1_contig_ev =
0 commit comments