Skip to content

Commit 101514f

Browse files
Implement custom scan group function for generic binary operator
Using

```
import dpctl.tensor as dpt
import dpctl

x = dpt.ones(2048000, dtype="f4")
q_prof = dpctl.SyclQueue(x.sycl_context, x.sycl_device, property="enable_profiling")
xx = x.to_device(q_prof)

mm = dpt.cumulative_logsumexp(xx)

timer = dpctl.SyclTimer(device_timer="order_manager", time_scale=1e9)
with timer(q_prof):
    for _ in range(250):
        dpt.cumulative_logsumexp(xx, out=mm)

print(f"dpctl.__version__ = {dpctl.__version__}")
print(f"Device: {x.sycl_device}")
print(f"host_dt={timer.dt.host_dt/250}, device_dt={timer.dt.device_dt/250}")
```

Testing on Iris Xe from WSL. This branch:

```
$ python ~/cumlogsumexp.py
dpctl.__version__ = 0.19.0dev0+351.gffd26092a0.dirty
Device: <dpctl.SyclDevice [backend_type.level_zero, device_type.gpu, Intel(R) Graphics [0x9a49]] at 0x7f37a8f995f0>
host_dt=1059589.7079911083, device_dt=1154782.72
```

vs. main branch:

```
$ python cumlogsumexp.py
dpctl.__version__ = 0.19.0dev0+307.g04a8228748
Device: <dpctl.SyclDevice [backend_type.level_zero, device_type.gpu, Intel(R) Graphics [0x9a49]] at 0x7ff6147d3cf0>
host_dt=2721938.803792, device_dt=10048323.168
```

So this is about an 8x speed-up in device time (roughly 2.6x in host time).
1 parent ffd2609 commit 101514f

File tree

2 files changed

+73
-15
lines changed

2 files changed

+73
-15
lines changed

dpctl/tensor/libtensor/include/kernels/accumulators.hpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -280,7 +280,8 @@ inclusive_scan_base_step_blocked(sycl::queue &exec_q,
280280
}
281281
else {
282282
wg_iscan_val = su_ns::custom_inclusive_scan_over_group(
283-
it.get_group(), slm_iscan_tmp, local_iscan.back(), scan_op);
283+
it.get_group(), it.get_sub_group(), slm_iscan_tmp,
284+
local_iscan.back(), identity, scan_op);
284285
// ensure all finished reading from SLM, to avoid race condition
285286
// with subsequent writes into SLM
286287
it.barrier(sycl::access::fence_space::local_space);
@@ -454,7 +455,8 @@ inclusive_scan_base_step_striped(sycl::queue &exec_q,
454455
}
455456
else {
456457
wg_iscan_val = su_ns::custom_inclusive_scan_over_group(
457-
it.get_group(), slm_iscan_tmp, local_iscan.back(), scan_op);
458+
it.get_group(), sg, slm_iscan_tmp, local_iscan.back(),
459+
identity, scan_op);
458460
// ensure all finished reading from SLM, to avoid race condition
459461
// with subsequent writes into SLM
460462
it.barrier(sycl::access::fence_space::local_space);

dpctl/tensor/libtensor/include/utils/sycl_utils.hpp

Lines changed: 69 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -212,29 +212,85 @@ T custom_reduce_over_group(const GroupT &wg,
212212
return sycl::group_broadcast(wg, red_val_over_wg, 0);
213213
}
214214

215-
template <typename T, typename GroupT, typename LocAccT, typename OpT>
216-
T custom_inclusive_scan_over_group(const GroupT &wg,
217-
LocAccT local_mem_acc,
218-
const T local_val,
219-
const OpT &op)
215+
template <typename GroupT,
216+
typename SubGroupT,
217+
typename LocAccT,
218+
typename T,
219+
typename OpT>
220+
T custom_inclusive_scan_over_group(GroupT &&wg,
221+
SubGroupT &&sg,
222+
LocAccT &&local_mem_acc,
223+
const T &local_val,
224+
const T &identity,
225+
OpT &&op)
220226
{
221227
const std::uint32_t local_id = wg.get_local_id(0);
222228
const std::uint32_t wgs = wg.get_local_range(0);
223-
local_mem_acc[local_id] = local_val;
224229

230+
const std::uint32_t lane_id = sg.get_local_id()[0];
231+
const std::uint32_t sgSize = sg.get_local_range()[0];
232+
233+
T scan_val = local_val;
234+
for (std::uint32_t step = 1; step < sgSize; step *= 2) {
235+
const bool advanced_lane = (lane_id >= step);
236+
const std::uint32_t src_lane_id =
237+
(advanced_lane ? lane_id - step : lane_id);
238+
const T modifier = sycl::select_from_group(sg, scan_val, src_lane_id);
239+
if (advanced_lane) {
240+
scan_val = op(scan_val, modifier);
241+
}
242+
}
243+
244+
local_mem_acc[local_id] = scan_val;
225245
sycl::group_barrier(wg, sycl::memory_scope::work_group);
226246

227-
if (wg.leader()) {
228-
T scan_val = local_mem_acc[0];
229-
for (std::uint32_t i = 1; i < wgs; ++i) {
230-
scan_val = op(local_mem_acc[i], scan_val);
231-
local_mem_acc[i] = scan_val;
247+
const std::uint32_t max_sgSize = sg.get_max_local_range()[0];
248+
const std::uint32_t sgr_id = sg.get_group_id()[0];
249+
250+
// now scan
251+
const std::uint32_t n_aggregates = 1 + ((wgs - 1) / max_sgSize);
252+
const bool large_wg = (n_aggregates > max_sgSize);
253+
if (large_wg) {
254+
if (wg.leader()) {
255+
T _scan_val = identity;
256+
for (std::uint32_t i = 1; i <= n_aggregates - max_sgSize; ++i) {
257+
_scan_val = op(local_mem_acc[i * max_sgSize - 1], _scan_val);
258+
local_mem_acc[i * max_sgSize - 1] = _scan_val;
259+
}
260+
}
261+
sycl::group_barrier(wg, sycl::memory_scope::work_group);
262+
}
263+
264+
if (sgr_id == 0 && lane_id < n_aggregates) {
265+
const std::uint32_t offset =
266+
(large_wg) ? n_aggregates - max_sgSize : 0u;
267+
T __scan_val = (offset + lane_id > 0)
268+
? local_mem_acc[(offset + lane_id) * max_sgSize - 1]
269+
: identity;
270+
for (std::uint32_t step = 1; step < sgSize; step *= 2) {
271+
const bool advanced_lane = (lane_id >= step);
272+
const std::uint32_t src_lane_id =
273+
(advanced_lane ? lane_id - step : lane_id);
274+
const T modifier =
275+
sycl::select_from_group(sg, __scan_val, src_lane_id);
276+
if (advanced_lane) {
277+
__scan_val = op(__scan_val, modifier);
278+
}
232279
}
280+
sycl::group_barrier(sg);
281+
local_mem_acc[(offset + lane_id) * max_sgSize - 1] = __scan_val;
233282
}
283+
sycl::group_barrier(wg, sycl::memory_scope::work_group);
234284

235-
// ensure all work-items see the same SLM that leader updated
285+
if (sgr_id > 0) {
286+
const T modifier = local_mem_acc[sgr_id * max_sgSize - 1];
287+
scan_val = op(scan_val, modifier);
288+
}
289+
290+
// ensure all work-items finished reading from SLM
236291
sycl::group_barrier(wg, sycl::memory_scope::work_group);
237-
return local_mem_acc[local_id];
292+
293+
return scan_val;
238294
}
239295

240296
// Reduction functors

0 commit comments

Comments
 (0)