@@ -269,6 +269,23 @@ struct UnaryStridedFunctor
269
269
}
270
270
};
271
271
272
+ template <typename SizeT>
273
+ SizeT select_lws (const sycl::device &, SizeT n_work_items_needed)
274
+ {
275
+ // TODO: make the decision based on device descriptors
276
+
277
+ constexpr SizeT few_threshold = (SizeT (1 ) << 17 );
278
+ constexpr SizeT med_threshold = (SizeT (1 ) << 21 );
279
+
280
+ const SizeT lws =
281
+ ((n_work_items_needed <= few_threshold)
282
+ ? SizeT (64 )
283
+ : (n_work_items_needed <= med_threshold ? SizeT (128 )
284
+ : SizeT (256 )));
285
+
286
+ return lws;
287
+ }
288
+
272
289
template <typename argTy,
273
290
template <typename T>
274
291
class UnaryOutputType ,
@@ -288,26 +305,28 @@ sycl::event unary_contig_impl(sycl::queue &exec_q,
288
305
char *res_p,
289
306
const std::vector<sycl::event> &depends = {})
290
307
{
291
- sycl::event comp_ev = exec_q. submit ([&](sycl::handler &cgh) {
292
- cgh. depends_on (depends );
308
+ const size_t n_work_items_needed = nelems / (n_vecs * vec_sz);
309
+ const size_t lws = select_lws (exec_q. get_device (), n_work_items_needed );
293
310
294
- // Choose work-group size to occupy all threads of since vector core
295
- // busy (8 threads, simd32)
296
- const size_t lws = 256 ;
297
- const size_t n_groups =
298
- ((nelems + lws * n_vecs * vec_sz - 1 ) / (lws * n_vecs * vec_sz));
299
- const auto gws_range = sycl::range<1 >(n_groups * lws);
300
- const auto lws_range = sycl::range<1 >(lws);
311
+ const size_t n_groups =
312
+ ((nelems + lws * n_vecs * vec_sz - 1 ) / (lws * n_vecs * vec_sz));
313
+ const auto gws_range = sycl::range<1 >(n_groups * lws);
314
+ const auto lws_range = sycl::range<1 >(lws);
301
315
302
- using resTy = typename UnaryOutputType<argTy>::value_type;
303
- const argTy *arg_tp = reinterpret_cast <const argTy *>(arg_p);
304
- resTy *res_tp = reinterpret_cast <resTy *>(res_p);
316
+ using resTy = typename UnaryOutputType<argTy>::value_type;
317
+ using BaseKernelName = kernel_name<argTy, resTy, vec_sz, n_vecs>;
318
+
319
+ const argTy *arg_tp = reinterpret_cast <const argTy *>(arg_p);
320
+ resTy *res_tp = reinterpret_cast <resTy *>(res_p);
321
+
322
+ sycl::event comp_ev = exec_q.submit ([&](sycl::handler &cgh) {
323
+ cgh.depends_on (depends);
305
324
306
325
if (is_aligned<required_alignment>(arg_p) &&
307
326
is_aligned<required_alignment>(res_p))
308
327
{
309
328
constexpr bool enable_sg_loadstore = true ;
310
- using KernelName = kernel_name<argTy, resTy, vec_sz, n_vecs> ;
329
+ using KernelName = BaseKernelName ;
311
330
312
331
cgh.parallel_for <KernelName>(
313
332
sycl::nd_range<1 >(gws_range, lws_range),
@@ -316,16 +335,16 @@ sycl::event unary_contig_impl(sycl::queue &exec_q,
316
335
}
317
336
else {
318
337
constexpr bool disable_sg_loadstore = false ;
319
- using InnerKernelName = kernel_name<argTy, resTy, vec_sz, n_vecs>;
320
338
using KernelName =
321
- disabled_sg_loadstore_wrapper_krn<InnerKernelName >;
339
+ disabled_sg_loadstore_wrapper_krn<BaseKernelName >;
322
340
323
341
cgh.parallel_for <KernelName>(
324
342
sycl::nd_range<1 >(gws_range, lws_range),
325
343
ContigFunctorT<argTy, resTy, vec_sz, n_vecs,
326
344
disable_sg_loadstore>(arg_tp, res_tp, nelems));
327
345
}
328
346
});
347
+
329
348
return comp_ev;
330
349
}
331
350
@@ -773,32 +792,33 @@ sycl::event binary_contig_impl(sycl::queue &exec_q,
773
792
ssize_t res_offset,
774
793
const std::vector<sycl::event> &depends = {})
775
794
{
776
- sycl::event comp_ev = exec_q. submit ([&](sycl::handler &cgh) {
777
- cgh. depends_on (depends );
795
+ const size_t n_work_items_needed = nelems / (n_vecs * vec_sz);
796
+ const size_t lws = select_lws (exec_q. get_device (), n_work_items_needed );
778
797
779
- // Choose work-group size to occupy all threads of since vector core
780
- // busy (8 threads, simd32)
781
- const size_t lws = 256 ;
782
- const size_t n_groups =
783
- ((nelems + lws * n_vecs * vec_sz - 1 ) / (lws * n_vecs * vec_sz));
784
- const auto gws_range = sycl::range<1 >(n_groups * lws);
785
- const auto lws_range = sycl::range<1 >(lws);
798
+ const size_t n_groups =
799
+ ((nelems + lws * n_vecs * vec_sz - 1 ) / (lws * n_vecs * vec_sz));
800
+ const auto gws_range = sycl::range<1 >(n_groups * lws);
801
+ const auto lws_range = sycl::range<1 >(lws);
786
802
787
- using resTy = typename BinaryOutputType<argTy1, argTy2>::value_type;
803
+ using resTy = typename BinaryOutputType<argTy1, argTy2>::value_type;
804
+ using BaseKernelName = kernel_name<argTy1, argTy2, resTy, vec_sz, n_vecs>;
805
+
806
+ const argTy1 *arg1_tp =
807
+ reinterpret_cast <const argTy1 *>(arg1_p) + arg1_offset;
808
+ const argTy2 *arg2_tp =
809
+ reinterpret_cast <const argTy2 *>(arg2_p) + arg2_offset;
810
+ resTy *res_tp = reinterpret_cast <resTy *>(res_p) + res_offset;
788
811
789
- const argTy1 *arg1_tp =
790
- reinterpret_cast <const argTy1 *>(arg1_p) + arg1_offset;
791
- const argTy2 *arg2_tp =
792
- reinterpret_cast <const argTy2 *>(arg2_p) + arg2_offset;
793
- resTy *res_tp = reinterpret_cast <resTy *>(res_p) + res_offset;
812
+ sycl::event comp_ev = exec_q.submit ([&](sycl::handler &cgh) {
813
+ cgh.depends_on (depends);
794
814
795
815
if (is_aligned<required_alignment>(arg1_tp) &&
796
816
is_aligned<required_alignment>(arg2_tp) &&
797
817
is_aligned<required_alignment>(res_tp))
798
818
{
799
819
constexpr bool enable_sg_loadstore = true ;
800
- using KernelName =
801
- kernel_name<argTy1, argTy2, resTy, vec_sz, n_vecs>;
820
+ using KernelName = BaseKernelName;
821
+
802
822
cgh.parallel_for <KernelName>(
803
823
sycl::nd_range<1 >(gws_range, lws_range),
804
824
BinaryContigFunctorT<argTy1, argTy2, resTy, vec_sz, n_vecs,
@@ -807,10 +827,8 @@ sycl::event binary_contig_impl(sycl::queue &exec_q,
807
827
}
808
828
else {
809
829
constexpr bool disable_sg_loadstore = false ;
810
- using InnerKernelName =
811
- kernel_name<argTy1, argTy2, resTy, vec_sz, n_vecs>;
812
830
using KernelName =
813
- disabled_sg_loadstore_wrapper_krn<InnerKernelName >;
831
+ disabled_sg_loadstore_wrapper_krn<BaseKernelName >;
814
832
cgh.parallel_for <KernelName>(
815
833
sycl::nd_range<1 >(gws_range, lws_range),
816
834
BinaryContigFunctorT<argTy1, argTy2, resTy, vec_sz, n_vecs,
0 commit comments