diff --git a/dpctl/tensor/libtensor/include/kernels/accumulators.hpp b/dpctl/tensor/libtensor/include/kernels/accumulators.hpp index aec9863bf0..f4a97c20b9 100644 --- a/dpctl/tensor/libtensor/include/kernels/accumulators.hpp +++ b/dpctl/tensor/libtensor/include/kernels/accumulators.hpp @@ -227,23 +227,22 @@ inclusive_scan_base_step(sycl::queue &exec_q, #pragma unroll for (nwiT m_wi = 0; m_wi < n_wi; ++m_wi) { + const size_t i_m_wi = i + m_wi; if constexpr (!include_initial) { local_iscan[m_wi] = - (i + m_wi < acc_nelems) - ? transformer( - input[inp_iter_offset + - inp_indexer(s0 + s1 * (i + m_wi))]) + (i_m_wi < acc_nelems) + ? transformer(input[inp_iter_offset + + inp_indexer(s0 + s1 * i_m_wi)]) : identity; } else { // shift input to the left by a single element relative to // output local_iscan[m_wi] = - (i + m_wi < acc_nelems && i + m_wi > 0) + (i_m_wi < acc_nelems && i_m_wi > 0) ? transformer( input[inp_iter_offset + - inp_indexer((s0 + s1 * (i + m_wi)) - - 1)]) + inp_indexer((s0 + s1 * i_m_wi) - 1)]) : identity; } } @@ -280,9 +279,9 @@ inclusive_scan_base_step(sycl::queue &exec_q, local_iscan[m_wi] = scan_op(local_iscan[m_wi], addand); } - for (nwiT m_wi = 0; (m_wi < n_wi) && (i + m_wi < acc_nelems); - ++m_wi) - { + const nwiT m_max = + std::min(n_wi, std::max(i, acc_nelems) - i); + for (nwiT m_wi = 0; m_wi < m_max; ++m_wi) { output[out_iter_offset + out_indexer(i + m_wi)] = local_iscan[m_wi]; }