Skip to content

Commit 11ecba8

Browse files
authored
Fix search reductions giving incorrect results for F-contiguous inputs (#1462)
* Fixes correctness regression in search functions ``py_search_over_axis`` no longer calls the ``axis1`` contiguous variant ``py_search_over_axis`` now only calls ``axis0`` variant wh * Adds tests for fixed search reduction behavior
1 parent 9131925 commit 11ecba8

File tree

2 files changed

+19
-10
lines changed

2 files changed

+19
-10
lines changed

dpctl/tensor/libtensor/source/reductions/reduction_over_axis.hpp

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -874,14 +874,11 @@ std::pair<sycl::event, sycl::event> py_search_over_axis(
874874
int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum);
875875

876876
// handle special case when both reduction and iteration are 1D contiguous
877-
// and can be done with atomics
878877
bool is_src_c_contig = src.is_c_contiguous();
879878
bool is_dst_c_contig = dst.is_c_contiguous();
880879
bool is_src_f_contig = src.is_f_contiguous();
881880

882-
if ((is_src_c_contig && is_dst_c_contig) ||
883-
(is_src_f_contig && dst_nelems == 1))
884-
{
881+
if (is_src_c_contig && is_dst_c_contig) {
885882
auto fn = axis1_contig_dispatch_table[src_typeid][dst_typeid];
886883
if (fn != nullptr) {
887884
size_t iter_nelems = dst_nelems;
@@ -903,9 +900,7 @@ std::pair<sycl::event, sycl::event> py_search_over_axis(
903900
reduction_over_axis_contig_ev);
904901
}
905902
}
906-
else if (is_src_f_contig &&
907-
((is_dst_c_contig && dst_nd == 1) || dst.is_f_contiguous()))
908-
{
903+
else if (is_src_f_contig && dst_nd == 1) {
909904
auto fn = axis0_contig_dispatch_table[src_typeid][dst_typeid];
910905
if (fn != nullptr) {
911906
size_t iter_nelems = dst_nelems;
@@ -983,11 +978,9 @@ std::pair<sycl::event, sycl::event> py_search_over_axis(
983978
if ((reduction_nd == 1) && (iteration_nd == 1)) {
984979
bool mat_reduce_over_axis1 = false;
985980
bool mat_reduce_over_axis0 = false;
986-
bool array_reduce_all_elems = false;
987981
size_t iter_nelems = dst_nelems;
988982

989983
if (compact_reduction_src_strides[0] == 1) {
990-
array_reduce_all_elems = (simplified_iteration_shape[0] == 1);
991984
mat_reduce_over_axis1 =
992985
(simplified_iteration_dst_strides[0] == 1) &&
993986
(static_cast<size_t>(simplified_iteration_src_strides[0]) ==
@@ -1000,7 +993,7 @@ std::pair<sycl::event, sycl::event> py_search_over_axis(
1000993
(simplified_iteration_src_strides[0] == 1);
1001994
}
1002995

1003-
if (mat_reduce_over_axis1 || array_reduce_all_elems) {
996+
if (mat_reduce_over_axis1) {
1004997
auto fn = axis1_contig_dispatch_table[src_typeid][dst_typeid];
1005998
if (fn != nullptr) {
1006999
sycl::event reduction_over_axis1_contig_ev =

dpctl/tests/test_usm_ndarray_reductions.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,22 @@ def test_argmax_argmin_identities():
265265
assert dpt.argmin(x) == 0
266266

267267

268+
@pytest.mark.parametrize("order", ["C", "F"])
269+
def test_argmax_axis0_axis1(order):
270+
get_queue_or_skip()
271+
272+
x = dpt.asarray([[1, 2, 3], [6, 5, 4]], dtype="i4", order=order)
273+
assert dpt.argmax(x) == 3
274+
275+
res = dpt.argmax(x, axis=0)
276+
expected = dpt.asarray([1, 1, 1], dtype=res.dtype)
277+
assert dpt.all(res == expected)
278+
279+
res = dpt.argmax(x, axis=1)
280+
expected = dpt.asarray([2, 0], dtype=res.dtype)
281+
assert dpt.all(res == expected)
282+
283+
268284
def test_reduction_arg_validation():
269285
get_queue_or_skip()
270286

0 commit comments

Comments
 (0)