Skip to content

Commit 1db958a

Browse files
Save repeated expression in accumulation functor operator to a variable
1 parent dd2812f commit 1db958a

File tree

1 file changed

+9
-10
lines changed

1 file changed

+9
-10
lines changed

dpctl/tensor/libtensor/include/kernels/accumulators.hpp

Lines changed: 9 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -227,23 +227,22 @@ inclusive_scan_base_step(sycl::queue &exec_q,
227227

228228
#pragma unroll
229229
for (nwiT m_wi = 0; m_wi < n_wi; ++m_wi) {
230+
const size_t i_m_wi = i + m_wi;
230231
if constexpr (!include_initial) {
231232
local_iscan[m_wi] =
232-
(i + m_wi < acc_nelems)
233-
? transformer(
234-
input[inp_iter_offset +
235-
inp_indexer(s0 + s1 * (i + m_wi))])
233+
(i_m_wi < acc_nelems)
234+
? transformer(input[inp_iter_offset +
235+
inp_indexer(s0 + s1 * i_m_wi)])
236236
: identity;
237237
}
238238
else {
239239
// shift input to the left by a single element relative to
240240
// output
241241
local_iscan[m_wi] =
242-
(i + m_wi < acc_nelems && i + m_wi > 0)
242+
(i_m_wi < acc_nelems && i_m_wi > 0)
243243
? transformer(
244244
input[inp_iter_offset +
245-
inp_indexer((s0 + s1 * (i + m_wi)) -
246-
1)])
245+
inp_indexer((s0 + s1 * i_m_wi) - 1)])
247246
: identity;
248247
}
249248
}
@@ -280,9 +279,9 @@ inclusive_scan_base_step(sycl::queue &exec_q,
280279
local_iscan[m_wi] = scan_op(local_iscan[m_wi], addand);
281280
}
282281

283-
for (nwiT m_wi = 0; (m_wi < n_wi) && (i + m_wi < acc_nelems);
284-
++m_wi)
285-
{
282+
const nwiT m_max =
283+
std::min<nwiT>(n_wi, std::max(i, acc_nelems) - i);
284+
for (nwiT m_wi = 0; m_wi < m_max; ++m_wi) {
286285
output[out_iter_offset + out_indexer(i + m_wi)] =
287286
local_iscan[m_wi];
288287
}

0 commit comments

Comments
 (0)