@@ -26,11 +26,13 @@
 #include <cstddef>
 #include <cstdint>
 #include <stdexcept>
-#include <sycl/sycl.hpp>
 #include <utility>
 
+#include <sycl/sycl.hpp>
+
 #include "kernels/alignment.hpp"
 #include "kernels/dpctl_tensor_types.hpp"
+#include "kernels/elementwise_functions/common_detail.hpp"
 #include "utils/offset_utils.hpp"
 #include "utils/sycl_alloc_utils.hpp"
 #include "utils/sycl_utils.hpp"
@@ -324,21 +326,23 @@ sycl::event unary_contig_impl(sycl::queue &exec_q,
         {
             constexpr bool enable_sg_loadstore = true;
             using KernelName = BaseKernelName;
+            using Impl = ContigFunctorT<argTy, resTy, vec_sz, n_vecs,
+                                        enable_sg_loadstore>;
 
             cgh.parallel_for<KernelName>(
                 sycl::nd_range<1>(gws_range, lws_range),
-                ContigFunctorT<argTy, resTy, vec_sz, n_vecs,
-                               enable_sg_loadstore>(arg_tp, res_tp, nelems));
+                Impl(arg_tp, res_tp, nelems));
         }
         else {
             constexpr bool disable_sg_loadstore = false;
             using KernelName =
                 disabled_sg_loadstore_wrapper_krn<BaseKernelName>;
+            using Impl = ContigFunctorT<argTy, resTy, vec_sz, n_vecs,
+                                        disable_sg_loadstore>;
 
             cgh.parallel_for<KernelName>(
                 sycl::nd_range<1>(gws_range, lws_range),
-                ContigFunctorT<argTy, resTy, vec_sz, n_vecs,
-                               disable_sg_loadstore>(arg_tp, res_tp, nelems));
+                Impl(arg_tp, res_tp, nelems));
         }
     });
 
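Note: the same alias-then-launch refactor repeats at every submission site below: the long functor instantiation is named once via `using Impl = ...;` so each `parallel_for` call site shrinks to one line. A self-contained sketch of the pattern (the toy functor and every name in it are illustrative, not dpctl's):

```cpp
#include <cstddef>

#include <sycl/sycl.hpp>

// Toy stand-in for ContigFunctorT: doubles each element. Only the shape
// of the kernel launch matters here; the functor itself is illustrative.
template <typename T> struct ToyContigFunctor
{
    const T *in_;
    T *out_;

    ToyContigFunctor(const T *in, T *out) : in_(in), out_(out) {}

    void operator()(sycl::id<1> id) const { out_[id[0]] = T(2) * in_[id[0]]; }
};

int main()
{
    sycl::queue q;
    constexpr std::size_t n = 16;
    int *in = sycl::malloc_shared<int>(n, q);
    int *out = sycl::malloc_shared<int>(n, q);
    for (std::size_t i = 0; i < n; ++i)
        in[i] = static_cast<int>(i);

    // The pattern this diff applies: name the functor instantiation once,
    // keep the parallel_for call site short.
    using Impl = ToyContigFunctor<int>;
    q.parallel_for(sycl::range<1>{n}, Impl(in, out)).wait();

    sycl::free(in, q);
    sycl::free(out, q);
    return 0;
}
```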
@@ -377,9 +381,10 @@ unary_strided_impl(sycl::queue &exec_q,
         const argTy *arg_tp = reinterpret_cast<const argTy *>(arg_p);
         resTy *res_tp = reinterpret_cast<resTy *>(res_p);
 
+        using Impl = StridedFunctorT<argTy, resTy, IndexerT>;
+
         cgh.parallel_for<kernel_name<argTy, resTy, IndexerT>>(
-            {nelems},
-            StridedFunctorT<argTy, resTy, IndexerT>(arg_tp, res_tp, indexer));
+            {nelems}, Impl(arg_tp, res_tp, indexer));
     });
     return comp_ev;
 }
@@ -814,22 +819,23 @@ sycl::event binary_contig_impl(sycl::queue &exec_q,
         {
             constexpr bool enable_sg_loadstore = true;
             using KernelName = BaseKernelName;
+            using Impl = BinaryContigFunctorT<argTy1, argTy2, resTy, vec_sz,
+                                              n_vecs, enable_sg_loadstore>;
 
             cgh.parallel_for<KernelName>(
                 sycl::nd_range<1>(gws_range, lws_range),
-                BinaryContigFunctorT<argTy1, argTy2, resTy, vec_sz, n_vecs,
-                                     enable_sg_loadstore>(arg1_tp, arg2_tp,
-                                                          res_tp, nelems));
+                Impl(arg1_tp, arg2_tp, res_tp, nelems));
         }
         else {
             constexpr bool disable_sg_loadstore = false;
             using KernelName =
                 disabled_sg_loadstore_wrapper_krn<BaseKernelName>;
+            using Impl = BinaryContigFunctorT<argTy1, argTy2, resTy, vec_sz,
+                                              n_vecs, disable_sg_loadstore>;
+
             cgh.parallel_for<KernelName>(
                 sycl::nd_range<1>(gws_range, lws_range),
-                BinaryContigFunctorT<argTy1, argTy2, resTy, vec_sz, n_vecs,
-                                     disable_sg_loadstore>(arg1_tp, arg2_tp,
-                                                           res_tp, nelems));
+                Impl(arg1_tp, arg2_tp, res_tp, nelems));
         }
     });
     return comp_ev;
@@ -873,9 +879,10 @@ binary_strided_impl(sycl::queue &exec_q,
         const argTy2 *arg2_tp = reinterpret_cast<const argTy2 *>(arg2_p);
         resTy *res_tp = reinterpret_cast<resTy *>(res_p);
 
+        using Impl = BinaryStridedFunctorT<argTy1, argTy2, resTy, IndexerT>;
+
         cgh.parallel_for<kernel_name<argTy1, argTy2, resTy, IndexerT>>(
-            {nelems}, BinaryStridedFunctorT<argTy1, argTy2, resTy, IndexerT>(
-                          arg1_tp, arg2_tp, res_tp, indexer));
+            {nelems}, Impl(arg1_tp, arg2_tp, res_tp, indexer));
     });
     return comp_ev;
 }
@@ -917,13 +924,9 @@ sycl::event binary_contig_matrix_contig_row_broadcast_impl(
         exec_q);
     argT2 *padded_vec = padded_vec_owner.get();
 
-    sycl::event make_padded_vec_ev = exec_q.submit([&](sycl::handler &cgh) {
-        cgh.depends_on(depends); // ensure vec contains actual data
-        cgh.parallel_for({n1_padded}, [=](sycl::id<1> id) {
-            auto i = id[0];
-            padded_vec[i] = vec[i % n1];
-        });
-    });
+    sycl::event make_padded_vec_ev =
+        dpctl::tensor::kernels::elementwise_detail::populate_padded_vector<
+            argT2>(exec_q, vec, n1, padded_vec, n1_padded, depends);
 
     // sub-group spans work-items [I, I + sgSize)
     // base = ndit.get_global_linear_id() - sg.get_local_id()[0]
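Here, and again in the matching hunk below, the inline padding kernel moves behind `populate_padded_vector` from `kernels/elementwise_functions/common_detail.hpp`. That header is not part of this diff; judging from the deleted lines, the helper plausibly looks like the following sketch (reconstructed from the removed code, not the actual definition):

```cpp
#include <cstddef>
#include <vector>

#include <sycl/sycl.hpp>

// Hypothetical stand-in for the factored-out helper: cyclically tiles a
// row of length n into a padded buffer of length n_padded, so later
// sub-group loads never read past the end of the original row.
template <typename T>
sycl::event populate_padded_vector(sycl::queue &exec_q,
                                   const T *vec,
                                   std::size_t n,
                                   T *padded_vec,
                                   std::size_t n_padded,
                                   const std::vector<sycl::event> &depends)
{
    return exec_q.submit([&](sycl::handler &cgh) {
        cgh.depends_on(depends); // ensure vec contains actual data
        cgh.parallel_for({n_padded}, [=](sycl::id<1> id) {
            auto i = id[0];
            padded_vec[i] = vec[i % n];
        });
    });
}
```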
@@ -942,10 +945,12 @@ sycl::event binary_contig_matrix_contig_row_broadcast_impl(
         std::size_t n_groups = (n_elems + lws - 1) / lws;
         auto gwsRange = sycl::range<1>(n_groups * lws);
 
+        using Impl =
+            BinaryContigMatrixContigRowBroadcastFunctorT<argT1, argT2, resT>;
+
         cgh.parallel_for<class kernel_name<argT1, argT2, resT>>(
             sycl::nd_range<1>(gwsRange, lwsRange),
-            BinaryContigMatrixContigRowBroadcastFunctorT<argT1, argT2, resT>(
-                mat, padded_vec, res, n_elems, n1));
+            Impl(mat, padded_vec, res, n_elems, n1));
     });
 
     sycl::event tmp_cleanup_ev = dpctl::tensor::alloc_utils::async_smart_free(
@@ -993,13 +998,9 @@ sycl::event binary_contig_row_contig_matrix_broadcast_impl(
         exec_q);
     argT2 *padded_vec = padded_vec_owner.get();
 
-    sycl::event make_padded_vec_ev = exec_q.submit([&](sycl::handler &cgh) {
-        cgh.depends_on(depends); // ensure vec contains actual data
-        cgh.parallel_for({n1_padded}, [=](sycl::id<1> id) {
-            auto i = id[0];
-            padded_vec[i] = vec[i % n1];
-        });
-    });
+    sycl::event make_padded_vec_ev =
+        dpctl::tensor::kernels::elementwise_detail::populate_padded_vector<
+            argT2>(exec_q, vec, n1, padded_vec, n1_padded, depends);
 
     // sub-group spans work-items [I, I + sgSize)
     // base = ndit.get_global_linear_id() - sg.get_local_id()[0]
@@ -1018,10 +1019,12 @@ sycl::event binary_contig_row_contig_matrix_broadcast_impl(
         std::size_t n_groups = (n_elems + lws - 1) / lws;
         auto gwsRange = sycl::range<1>(n_groups * lws);
 
+        using Impl =
+            BinaryContigRowContigMatrixBroadcastFunctorT<argT1, argT2, resT>;
+
         cgh.parallel_for<class kernel_name<argT1, argT2, resT>>(
             sycl::nd_range<1>(gwsRange, lwsRange),
-            BinaryContigRowContigMatrixBroadcastFunctorT<argT1, argT2, resT>(
-                padded_vec, mat, res, n_elems, n1));
+            Impl(padded_vec, mat, res, n_elems, n1));
     });
 
     sycl::event tmp_cleanup_ev = dpctl::tensor::alloc_utils::async_smart_free(