Skip to content

Commit c5cbb08

Browse files
Merge pull request #1976 from IntelPython/reduce-elementwise-extension-size
Reduce elementwise extension size
2 parents d9e9bf8 + 1a95394 commit c5cbb08

File tree

7 files changed

+214
-94
lines changed

7 files changed

+214
-94
lines changed

dpctl/tensor/libtensor/include/kernels/boolean_advanced_indexing.hpp

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,11 @@
2626
#include <cstddef>
2727
#include <cstdint>
2828
#include <limits>
29-
#include <sycl/sycl.hpp>
3029
#include <utility>
3130
#include <vector>
3231

32+
#include <sycl/sycl.hpp>
33+
3334
#include "dpctl_tensor_types.hpp"
3435
#include "utils/offset_utils.hpp"
3536
#include "utils/type_dispatch_building.hpp"
@@ -599,6 +600,10 @@ sycl::event masked_place_all_slices_strided_impl(
599600
sycl::nd_range<2> ndRange{gRange, lRange};
600601

601602
using LocalAccessorT = sycl::local_accessor<indT, 1>;
603+
using Impl =
604+
MaskedPlaceStridedFunctor<TwoZeroOffsets_Indexer, StridedIndexer,
605+
Strided1DCyclicIndexer, dataT, indT,
606+
LocalAccessorT>;
602607

603608
dataT *dst_tp = reinterpret_cast<dataT *>(dst_p);
604609
const dataT *rhs_tp = reinterpret_cast<const dataT *>(rhs_p);
@@ -611,13 +616,9 @@ sycl::event masked_place_all_slices_strided_impl(
611616
LocalAccessorT lacc(lacc_size, cgh);
612617

613618
cgh.parallel_for<KernelName>(
614-
ndRange,
615-
MaskedPlaceStridedFunctor<TwoZeroOffsets_Indexer, StridedIndexer,
616-
Strided1DCyclicIndexer, dataT, indT,
617-
LocalAccessorT>(
618-
dst_tp, cumsum_tp, rhs_tp, iteration_size,
619-
orthog_dst_rhs_indexer, masked_dst_indexer, masked_rhs_indexer,
620-
lacc));
619+
ndRange, Impl(dst_tp, cumsum_tp, rhs_tp, iteration_size,
620+
orthog_dst_rhs_indexer, masked_dst_indexer,
621+
masked_rhs_indexer, lacc));
621622
});
622623

623624
return comp_ev;
@@ -696,6 +697,10 @@ sycl::event masked_place_some_slices_strided_impl(
696697
sycl::nd_range<2> ndRange{gRange, lRange};
697698

698699
using LocalAccessorT = sycl::local_accessor<indT, 1>;
700+
using Impl =
701+
MaskedPlaceStridedFunctor<TwoOffsets_StridedIndexer, StridedIndexer,
702+
Strided1DCyclicIndexer, dataT, indT,
703+
LocalAccessorT>;
699704

700705
dataT *dst_tp = reinterpret_cast<dataT *>(dst_p);
701706
const dataT *rhs_tp = reinterpret_cast<const dataT *>(rhs_p);
@@ -708,13 +713,9 @@ sycl::event masked_place_some_slices_strided_impl(
708713
LocalAccessorT lacc(lacc_size, cgh);
709714

710715
cgh.parallel_for<KernelName>(
711-
ndRange,
712-
MaskedPlaceStridedFunctor<TwoOffsets_StridedIndexer, StridedIndexer,
713-
Strided1DCyclicIndexer, dataT, indT,
714-
LocalAccessorT>(
715-
dst_tp, cumsum_tp, rhs_tp, masked_nelems,
716-
orthog_dst_rhs_indexer, masked_dst_indexer, masked_rhs_indexer,
717-
lacc));
716+
ndRange, Impl(dst_tp, cumsum_tp, rhs_tp, masked_nelems,
717+
orthog_dst_rhs_indexer, masked_dst_indexer,
718+
masked_rhs_indexer, lacc));
718719
});
719720

720721
return comp_ev;

dpctl/tensor/libtensor/include/kernels/clip.hpp

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -216,22 +216,24 @@ sycl::event clip_contig_impl(sycl::queue &q,
216216
{
217217
constexpr bool enable_sg_loadstore = true;
218218
using KernelName = clip_contig_kernel<T, vec_sz, n_vecs>;
219+
using Impl =
220+
ClipContigFunctor<T, vec_sz, n_vecs, enable_sg_loadstore>;
219221

220222
cgh.parallel_for<KernelName>(
221223
sycl::nd_range<1>(gws_range, lws_range),
222-
ClipContigFunctor<T, vec_sz, n_vecs, enable_sg_loadstore>(
223-
nelems, x_tp, min_tp, max_tp, dst_tp));
224+
Impl(nelems, x_tp, min_tp, max_tp, dst_tp));
224225
}
225226
else {
226227
constexpr bool disable_sg_loadstore = false;
227228
using InnerKernelName = clip_contig_kernel<T, vec_sz, n_vecs>;
228229
using KernelName =
229230
disabled_sg_loadstore_wrapper_krn<InnerKernelName>;
231+
using Impl =
232+
ClipContigFunctor<T, vec_sz, n_vecs, disable_sg_loadstore>;
230233

231234
cgh.parallel_for<KernelName>(
232235
sycl::nd_range<1>(gws_range, lws_range),
233-
ClipContigFunctor<T, vec_sz, n_vecs, disable_sg_loadstore>(
234-
nelems, x_tp, min_tp, max_tp, dst_tp));
236+
Impl(nelems, x_tp, min_tp, max_tp, dst_tp));
235237
}
236238
});
237239

@@ -311,10 +313,12 @@ sycl::event clip_strided_impl(sycl::queue &q,
311313
const FourOffsets_StridedIndexer indexer{
312314
nd, x_offset, min_offset, max_offset, dst_offset, shape_strides};
313315

314-
cgh.parallel_for<clip_strided_kernel<T, FourOffsets_StridedIndexer>>(
316+
using KernelName = clip_strided_kernel<T, FourOffsets_StridedIndexer>;
317+
using Impl = ClipStridedFunctor<T, FourOffsets_StridedIndexer>;
318+
319+
cgh.parallel_for<KernelName>(
315320
sycl::range<1>(nelems),
316-
ClipStridedFunctor<T, FourOffsets_StridedIndexer>(
317-
x_tp, min_tp, max_tp, dst_tp, indexer));
321+
Impl(x_tp, min_tp, max_tp, dst_tp, indexer));
318322
});
319323

320324
return clip_ev;

dpctl/tensor/libtensor/include/kernels/constructors.hpp

Lines changed: 32 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,15 @@
2424
//===----------------------------------------------------------------------===//
2525

2626
#pragma once
27+
#include <complex>
28+
#include <cstddef>
29+
30+
#include <sycl/sycl.hpp>
31+
2732
#include "dpctl_tensor_types.hpp"
2833
#include "utils/offset_utils.hpp"
2934
#include "utils/strided_iters.hpp"
3035
#include "utils/type_utils.hpp"
31-
#include <complex>
32-
#include <cstddef>
33-
#include <sycl/sycl.hpp>
3436

3537
namespace dpctl
3638
{
@@ -200,22 +202,25 @@ sycl::event lin_space_affine_impl(sycl::queue &exec_q,
200202
{
201203
dpctl::tensor::type_utils::validate_type_for_device<Ty>(exec_q);
202204

203-
bool device_supports_doubles = exec_q.get_device().has(sycl::aspect::fp64);
205+
const bool device_supports_doubles =
206+
exec_q.get_device().has(sycl::aspect::fp64);
207+
const std::size_t den = (include_endpoint) ? nelems - 1 : nelems;
208+
204209
sycl::event lin_space_affine_event = exec_q.submit([&](sycl::handler &cgh) {
205210
cgh.depends_on(depends);
206211
if (device_supports_doubles) {
207-
cgh.parallel_for<linear_sequence_affine_kernel<Ty, double>>(
208-
sycl::range<1>{nelems},
209-
LinearSequenceAffineFunctor<Ty, double>(
210-
array_data, start_v, end_v,
211-
(include_endpoint) ? nelems - 1 : nelems));
212+
using KernelName = linear_sequence_affine_kernel<Ty, double>;
213+
using Impl = LinearSequenceAffineFunctor<Ty, double>;
214+
215+
cgh.parallel_for<KernelName>(sycl::range<1>{nelems},
216+
Impl(array_data, start_v, end_v, den));
212217
}
213218
else {
214-
cgh.parallel_for<linear_sequence_affine_kernel<Ty, float>>(
215-
sycl::range<1>{nelems},
216-
LinearSequenceAffineFunctor<Ty, float>(
217-
array_data, start_v, end_v,
218-
(include_endpoint) ? nelems - 1 : nelems));
219+
using KernelName = linear_sequence_affine_kernel<Ty, float>;
220+
using Impl = LinearSequenceAffineFunctor<Ty, float>;
221+
222+
cgh.parallel_for<KernelName>(sycl::range<1>{nelems},
223+
Impl(array_data, start_v, end_v, den));
219224
}
220225
});
221226

@@ -312,10 +317,12 @@ sycl::event full_strided_impl(sycl::queue &q,
312317

313318
sycl::event fill_ev = q.submit([&](sycl::handler &cgh) {
314319
cgh.depends_on(depends);
315-
cgh.parallel_for<full_strided_kernel<dstTy>>(
316-
sycl::range<1>{nelems},
317-
FullStridedFunctor<dstTy, decltype(strided_indexer)>(
318-
dst_tp, fill_v, strided_indexer));
320+
321+
using KernelName = full_strided_kernel<dstTy>;
322+
using Impl = FullStridedFunctor<dstTy, StridedIndexer>;
323+
324+
cgh.parallel_for<KernelName>(sycl::range<1>{nelems},
325+
Impl(dst_tp, fill_v, strided_indexer));
319326
});
320327

321328
return fill_ev;
@@ -388,9 +395,12 @@ sycl::event eye_impl(sycl::queue &exec_q,
388395
dpctl::tensor::type_utils::validate_type_for_device<Ty>(exec_q);
389396
sycl::event eye_event = exec_q.submit([&](sycl::handler &cgh) {
390397
cgh.depends_on(depends);
391-
cgh.parallel_for<eye_kernel<Ty>>(
392-
sycl::range<1>{nelems},
393-
EyeFunctor<Ty>(array_data, start, end, step));
398+
399+
using KernelName = eye_kernel<Ty>;
400+
using Impl = EyeFunctor<Ty>;
401+
402+
cgh.parallel_for<KernelName>(sycl::range<1>{nelems},
403+
Impl(array_data, start, end, step));
394404
});
395405

396406
return eye_event;
@@ -478,7 +488,7 @@ sycl::event tri_impl(sycl::queue &exec_q,
478488
ssize_t inner_gid = idx[0] - inner_range * outer_gid;
479489

480490
ssize_t src_inner_offset = 0, dst_inner_offset = 0;
481-
bool to_copy(true);
491+
bool to_copy{false};
482492

483493
{
484494
using dpctl::tensor::strides::CIndexer_array;

dpctl/tensor/libtensor/include/kernels/elementwise_functions/common.hpp

Lines changed: 36 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,13 @@
2626
#include <cstddef>
2727
#include <cstdint>
2828
#include <stdexcept>
29-
#include <sycl/sycl.hpp>
3029
#include <utility>
3130

31+
#include <sycl/sycl.hpp>
32+
3233
#include "kernels/alignment.hpp"
3334
#include "kernels/dpctl_tensor_types.hpp"
35+
#include "kernels/elementwise_functions/common_detail.hpp"
3436
#include "utils/offset_utils.hpp"
3537
#include "utils/sycl_alloc_utils.hpp"
3638
#include "utils/sycl_utils.hpp"
@@ -324,21 +326,23 @@ sycl::event unary_contig_impl(sycl::queue &exec_q,
324326
{
325327
constexpr bool enable_sg_loadstore = true;
326328
using KernelName = BaseKernelName;
329+
using Impl = ContigFunctorT<argTy, resTy, vec_sz, n_vecs,
330+
enable_sg_loadstore>;
327331

328332
cgh.parallel_for<KernelName>(
329333
sycl::nd_range<1>(gws_range, lws_range),
330-
ContigFunctorT<argTy, resTy, vec_sz, n_vecs,
331-
enable_sg_loadstore>(arg_tp, res_tp, nelems));
334+
Impl(arg_tp, res_tp, nelems));
332335
}
333336
else {
334337
constexpr bool disable_sg_loadstore = false;
335338
using KernelName =
336339
disabled_sg_loadstore_wrapper_krn<BaseKernelName>;
340+
using Impl = ContigFunctorT<argTy, resTy, vec_sz, n_vecs,
341+
disable_sg_loadstore>;
337342

338343
cgh.parallel_for<KernelName>(
339344
sycl::nd_range<1>(gws_range, lws_range),
340-
ContigFunctorT<argTy, resTy, vec_sz, n_vecs,
341-
disable_sg_loadstore>(arg_tp, res_tp, nelems));
345+
Impl(arg_tp, res_tp, nelems));
342346
}
343347
});
344348

@@ -377,9 +381,10 @@ unary_strided_impl(sycl::queue &exec_q,
377381
const argTy *arg_tp = reinterpret_cast<const argTy *>(arg_p);
378382
resTy *res_tp = reinterpret_cast<resTy *>(res_p);
379383

384+
using Impl = StridedFunctorT<argTy, resTy, IndexerT>;
385+
380386
cgh.parallel_for<kernel_name<argTy, resTy, IndexerT>>(
381-
{nelems},
382-
StridedFunctorT<argTy, resTy, IndexerT>(arg_tp, res_tp, indexer));
387+
{nelems}, Impl(arg_tp, res_tp, indexer));
383388
});
384389
return comp_ev;
385390
}
@@ -814,22 +819,23 @@ sycl::event binary_contig_impl(sycl::queue &exec_q,
814819
{
815820
constexpr bool enable_sg_loadstore = true;
816821
using KernelName = BaseKernelName;
822+
using Impl = BinaryContigFunctorT<argTy1, argTy2, resTy, vec_sz,
823+
n_vecs, enable_sg_loadstore>;
817824

818825
cgh.parallel_for<KernelName>(
819826
sycl::nd_range<1>(gws_range, lws_range),
820-
BinaryContigFunctorT<argTy1, argTy2, resTy, vec_sz, n_vecs,
821-
enable_sg_loadstore>(arg1_tp, arg2_tp,
822-
res_tp, nelems));
827+
Impl(arg1_tp, arg2_tp, res_tp, nelems));
823828
}
824829
else {
825830
constexpr bool disable_sg_loadstore = false;
826831
using KernelName =
827832
disabled_sg_loadstore_wrapper_krn<BaseKernelName>;
833+
using Impl = BinaryContigFunctorT<argTy1, argTy2, resTy, vec_sz,
834+
n_vecs, disable_sg_loadstore>;
835+
828836
cgh.parallel_for<KernelName>(
829837
sycl::nd_range<1>(gws_range, lws_range),
830-
BinaryContigFunctorT<argTy1, argTy2, resTy, vec_sz, n_vecs,
831-
disable_sg_loadstore>(arg1_tp, arg2_tp,
832-
res_tp, nelems));
838+
Impl(arg1_tp, arg2_tp, res_tp, nelems));
833839
}
834840
});
835841
return comp_ev;
@@ -873,9 +879,10 @@ binary_strided_impl(sycl::queue &exec_q,
873879
const argTy2 *arg2_tp = reinterpret_cast<const argTy2 *>(arg2_p);
874880
resTy *res_tp = reinterpret_cast<resTy *>(res_p);
875881

882+
using Impl = BinaryStridedFunctorT<argTy1, argTy2, resTy, IndexerT>;
883+
876884
cgh.parallel_for<kernel_name<argTy1, argTy2, resTy, IndexerT>>(
877-
{nelems}, BinaryStridedFunctorT<argTy1, argTy2, resTy, IndexerT>(
878-
arg1_tp, arg2_tp, res_tp, indexer));
885+
{nelems}, Impl(arg1_tp, arg2_tp, res_tp, indexer));
879886
});
880887
return comp_ev;
881888
}
@@ -917,13 +924,9 @@ sycl::event binary_contig_matrix_contig_row_broadcast_impl(
917924
exec_q);
918925
argT2 *padded_vec = padded_vec_owner.get();
919926

920-
sycl::event make_padded_vec_ev = exec_q.submit([&](sycl::handler &cgh) {
921-
cgh.depends_on(depends); // ensure vec contains actual data
922-
cgh.parallel_for({n1_padded}, [=](sycl::id<1> id) {
923-
auto i = id[0];
924-
padded_vec[i] = vec[i % n1];
925-
});
926-
});
927+
sycl::event make_padded_vec_ev =
928+
dpctl::tensor::kernels::elementwise_detail::populate_padded_vector<
929+
argT2>(exec_q, vec, n1, padded_vec, n1_padded, depends);
927930

928931
// sub-group spans work-items [I, I + sgSize)
929932
// base = ndit.get_global_linear_id() - sg.get_local_id()[0]
@@ -942,10 +945,12 @@ sycl::event binary_contig_matrix_contig_row_broadcast_impl(
942945
std::size_t n_groups = (n_elems + lws - 1) / lws;
943946
auto gwsRange = sycl::range<1>(n_groups * lws);
944947

948+
using Impl =
949+
BinaryContigMatrixContigRowBroadcastFunctorT<argT1, argT2, resT>;
950+
945951
cgh.parallel_for<class kernel_name<argT1, argT2, resT>>(
946952
sycl::nd_range<1>(gwsRange, lwsRange),
947-
BinaryContigMatrixContigRowBroadcastFunctorT<argT1, argT2, resT>(
948-
mat, padded_vec, res, n_elems, n1));
953+
Impl(mat, padded_vec, res, n_elems, n1));
949954
});
950955

951956
sycl::event tmp_cleanup_ev = dpctl::tensor::alloc_utils::async_smart_free(
@@ -993,13 +998,9 @@ sycl::event binary_contig_row_contig_matrix_broadcast_impl(
993998
exec_q);
994999
argT2 *padded_vec = padded_vec_owner.get();
9951000

996-
sycl::event make_padded_vec_ev = exec_q.submit([&](sycl::handler &cgh) {
997-
cgh.depends_on(depends); // ensure vec contains actual data
998-
cgh.parallel_for({n1_padded}, [=](sycl::id<1> id) {
999-
auto i = id[0];
1000-
padded_vec[i] = vec[i % n1];
1001-
});
1002-
});
1001+
sycl::event make_padded_vec_ev =
1002+
dpctl::tensor::kernels::elementwise_detail::populate_padded_vector<
1003+
argT2>(exec_q, vec, n1, padded_vec, n1_padded, depends);
10031004

10041005
// sub-group spans work-items [I, I + sgSize)
10051006
// base = ndit.get_global_linear_id() - sg.get_local_id()[0]
@@ -1018,10 +1019,12 @@ sycl::event binary_contig_row_contig_matrix_broadcast_impl(
10181019
std::size_t n_groups = (n_elems + lws - 1) / lws;
10191020
auto gwsRange = sycl::range<1>(n_groups * lws);
10201021

1022+
using Impl =
1023+
BinaryContigRowContigMatrixBroadcastFunctorT<argT1, argT2, resT>;
1024+
10211025
cgh.parallel_for<class kernel_name<argT1, argT2, resT>>(
10221026
sycl::nd_range<1>(gwsRange, lwsRange),
1023-
BinaryContigRowContigMatrixBroadcastFunctorT<argT1, argT2, resT>(
1024-
padded_vec, mat, res, n_elems, n1));
1027+
Impl(padded_vec, mat, res, n_elems, n1));
10251028
});
10261029

10271030
sycl::event tmp_cleanup_ev = dpctl::tensor::alloc_utils::async_smart_free(

0 commit comments

Comments (0)