Skip to content

Commit e15e3aa

Browse files
Tweak bounds of cooperative reduction steps
Factor out bounds as constexpr values, reused between power-of-2 branch and not-power-of-two branch. Lowered lower bounds from 32 to 8 based on pefrormance testing on PVC and Iris Xe.
1 parent 03910f3 commit e15e3aa

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

dpctl/tensor/libtensor/include/utils/sycl_utils.hpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,8 @@ T custom_reduce_over_group(const GroupT &wg,
166166
const T &local_val,
167167
const OpT &op)
168168
{
169+
constexpr std::uint32_t low_sz = 8u;
170+
constexpr std::uint32_t high_sz = 1024u;
169171
const std::uint32_t wgs = wg.get_local_linear_range();
170172
const std::uint32_t lid = wg.get_local_linear_id();
171173

@@ -176,7 +178,7 @@ T custom_reduce_over_group(const GroupT &wg,
176178
if (wgs & (wgs - 1)) {
177179
// wgs is not a power of 2
178180
#pragma unroll
179-
for (std::uint32_t sz = 1024; sz >= 32; sz >>= 1) {
181+
for (std::uint32_t sz = high_sz; sz >= low_sz; sz >>= 1) {
180182
if (n_witems >= sz) {
181183
const std::uint32_t n_witems_ = (n_witems + 1) >> 1;
182184
_fold(local_mem_acc, lid, n_witems - n_witems_, n_witems_, op);
@@ -188,7 +190,7 @@ T custom_reduce_over_group(const GroupT &wg,
188190
else {
189191
// wgs is a power of 2
190192
#pragma unroll
191-
for (std::uint32_t sz = 1024; sz >= 32; sz >>= 1) {
193+
for (std::uint32_t sz = high_sz; sz >= low_sz; sz >>= 1) {
192194
if (n_witems >= sz) {
193195
n_witems = (n_witems + 1) >> 1;
194196
_fold(local_mem_acc, lid, n_witems, op);
@@ -204,8 +206,6 @@ T custom_reduce_over_group(const GroupT &wg,
204206
}
205207
}
206208

207-
// sycl::group_barrier(wg, sycl::memory_scope::work_group);
208-
209209
return sycl::group_broadcast(wg, red_val_over_wg, 0);
210210
}
211211

0 commit comments

Comments
 (0)