Skip to content

Commit cef4359

Browse files
Make local work-group size dependent on the number of elements to process
1 parent cd783e0 commit cef4359

File tree

1 file changed

+53
-35
lines changed
  • dpctl/tensor/libtensor/include/kernels/elementwise_functions

1 file changed

+53
-35
lines changed

dpctl/tensor/libtensor/include/kernels/elementwise_functions/common.hpp

Lines changed: 53 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -269,6 +269,23 @@ struct UnaryStridedFunctor
269269
}
270270
};
271271

272+
template <typename SizeT>
273+
SizeT select_lws(const sycl::device &, SizeT n_work_items_needed)
274+
{
275+
// TODO: make the decision based on device descriptors
276+
277+
constexpr SizeT few_threshold = (SizeT(1) << 17);
278+
constexpr SizeT med_threshold = (SizeT(1) << 21);
279+
280+
const SizeT lws =
281+
((n_work_items_needed <= few_threshold)
282+
? SizeT(64)
283+
: (n_work_items_needed <= med_threshold ? SizeT(128)
284+
: SizeT(256)));
285+
286+
return lws;
287+
}
288+
272289
template <typename argTy,
273290
template <typename T>
274291
class UnaryOutputType,
@@ -288,26 +305,28 @@ sycl::event unary_contig_impl(sycl::queue &exec_q,
288305
char *res_p,
289306
const std::vector<sycl::event> &depends = {})
290307
{
291-
sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
292-
cgh.depends_on(depends);
308+
const size_t n_work_items_needed = nelems / (n_vecs * vec_sz);
309+
const size_t lws = select_lws(exec_q.get_device(), n_work_items_needed);
293310

294-
// Choose work-group size to occupy all threads of since vector core
295-
// busy (8 threads, simd32)
296-
const size_t lws = 256;
297-
const size_t n_groups =
298-
((nelems + lws * n_vecs * vec_sz - 1) / (lws * n_vecs * vec_sz));
299-
const auto gws_range = sycl::range<1>(n_groups * lws);
300-
const auto lws_range = sycl::range<1>(lws);
311+
const size_t n_groups =
312+
((nelems + lws * n_vecs * vec_sz - 1) / (lws * n_vecs * vec_sz));
313+
const auto gws_range = sycl::range<1>(n_groups * lws);
314+
const auto lws_range = sycl::range<1>(lws);
301315

302-
using resTy = typename UnaryOutputType<argTy>::value_type;
303-
const argTy *arg_tp = reinterpret_cast<const argTy *>(arg_p);
304-
resTy *res_tp = reinterpret_cast<resTy *>(res_p);
316+
using resTy = typename UnaryOutputType<argTy>::value_type;
317+
using BaseKernelName = kernel_name<argTy, resTy, vec_sz, n_vecs>;
318+
319+
const argTy *arg_tp = reinterpret_cast<const argTy *>(arg_p);
320+
resTy *res_tp = reinterpret_cast<resTy *>(res_p);
321+
322+
sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
323+
cgh.depends_on(depends);
305324

306325
if (is_aligned<required_alignment>(arg_p) &&
307326
is_aligned<required_alignment>(res_p))
308327
{
309328
constexpr bool enable_sg_loadstore = true;
310-
using KernelName = kernel_name<argTy, resTy, vec_sz, n_vecs>;
329+
using KernelName = BaseKernelName;
311330

312331
cgh.parallel_for<KernelName>(
313332
sycl::nd_range<1>(gws_range, lws_range),
@@ -316,16 +335,16 @@ sycl::event unary_contig_impl(sycl::queue &exec_q,
316335
}
317336
else {
318337
constexpr bool disable_sg_loadstore = false;
319-
using InnerKernelName = kernel_name<argTy, resTy, vec_sz, n_vecs>;
320338
using KernelName =
321-
disabled_sg_loadstore_wrapper_krn<InnerKernelName>;
339+
disabled_sg_loadstore_wrapper_krn<BaseKernelName>;
322340

323341
cgh.parallel_for<KernelName>(
324342
sycl::nd_range<1>(gws_range, lws_range),
325343
ContigFunctorT<argTy, resTy, vec_sz, n_vecs,
326344
disable_sg_loadstore>(arg_tp, res_tp, nelems));
327345
}
328346
});
347+
329348
return comp_ev;
330349
}
331350

@@ -773,32 +792,33 @@ sycl::event binary_contig_impl(sycl::queue &exec_q,
773792
ssize_t res_offset,
774793
const std::vector<sycl::event> &depends = {})
775794
{
776-
sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
777-
cgh.depends_on(depends);
795+
const size_t n_work_items_needed = nelems / (n_vecs * vec_sz);
796+
const size_t lws = select_lws(exec_q.get_device(), n_work_items_needed);
778797

779-
// Choose work-group size to occupy all threads of since vector core
780-
// busy (8 threads, simd32)
781-
const size_t lws = 256;
782-
const size_t n_groups =
783-
((nelems + lws * n_vecs * vec_sz - 1) / (lws * n_vecs * vec_sz));
784-
const auto gws_range = sycl::range<1>(n_groups * lws);
785-
const auto lws_range = sycl::range<1>(lws);
798+
const size_t n_groups =
799+
((nelems + lws * n_vecs * vec_sz - 1) / (lws * n_vecs * vec_sz));
800+
const auto gws_range = sycl::range<1>(n_groups * lws);
801+
const auto lws_range = sycl::range<1>(lws);
786802

787-
using resTy = typename BinaryOutputType<argTy1, argTy2>::value_type;
803+
using resTy = typename BinaryOutputType<argTy1, argTy2>::value_type;
804+
using BaseKernelName = kernel_name<argTy1, argTy2, resTy, vec_sz, n_vecs>;
805+
806+
const argTy1 *arg1_tp =
807+
reinterpret_cast<const argTy1 *>(arg1_p) + arg1_offset;
808+
const argTy2 *arg2_tp =
809+
reinterpret_cast<const argTy2 *>(arg2_p) + arg2_offset;
810+
resTy *res_tp = reinterpret_cast<resTy *>(res_p) + res_offset;
788811

789-
const argTy1 *arg1_tp =
790-
reinterpret_cast<const argTy1 *>(arg1_p) + arg1_offset;
791-
const argTy2 *arg2_tp =
792-
reinterpret_cast<const argTy2 *>(arg2_p) + arg2_offset;
793-
resTy *res_tp = reinterpret_cast<resTy *>(res_p) + res_offset;
812+
sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
813+
cgh.depends_on(depends);
794814

795815
if (is_aligned<required_alignment>(arg1_tp) &&
796816
is_aligned<required_alignment>(arg2_tp) &&
797817
is_aligned<required_alignment>(res_tp))
798818
{
799819
constexpr bool enable_sg_loadstore = true;
800-
using KernelName =
801-
kernel_name<argTy1, argTy2, resTy, vec_sz, n_vecs>;
820+
using KernelName = BaseKernelName;
821+
802822
cgh.parallel_for<KernelName>(
803823
sycl::nd_range<1>(gws_range, lws_range),
804824
BinaryContigFunctorT<argTy1, argTy2, resTy, vec_sz, n_vecs,
@@ -807,10 +827,8 @@ sycl::event binary_contig_impl(sycl::queue &exec_q,
807827
}
808828
else {
809829
constexpr bool disable_sg_loadstore = false;
810-
using InnerKernelName =
811-
kernel_name<argTy1, argTy2, resTy, vec_sz, n_vecs>;
812830
using KernelName =
813-
disabled_sg_loadstore_wrapper_krn<InnerKernelName>;
831+
disabled_sg_loadstore_wrapper_krn<BaseKernelName>;
814832
cgh.parallel_for<KernelName>(
815833
sycl::nd_range<1>(gws_range, lws_range),
816834
BinaryContigFunctorT<argTy1, argTy2, resTy, vec_sz, n_vecs,

0 commit comments

Comments
 (0)