Skip to content

Commit 2ea5dfd

Browse files
Merge pull request #1891 from IntelPython/contribution-to-fix-gh-1887
Contribution to fix gh 1887
2 parents a91fc80 + 18f7ea0 commit 2ea5dfd

File tree

1 file changed

+50
-34
lines changed

1 file changed

+50
-34
lines changed

dpctl/tensor/libtensor/include/kernels/copy_as_contiguous.hpp

Lines changed: 50 additions & 34 deletions
Original file line number | Diff line number | Diff line change
@@ -44,8 +44,8 @@ namespace copy_as_contig
4444

4545
template <typename T,
4646
typename IndexerT,
47-
int vec_sz = 4,
48-
int n_vecs = 2,
47+
std::uint32_t vec_sz = 4u,
48+
std::uint32_t n_vecs = 2u,
4949
bool enable_sg_loadstore = true>
5050
class CopyAsCContigFunctor
5151
{
@@ -66,53 +66,63 @@ class CopyAsCContigFunctor
6666

6767
void operator()(sycl::nd_item<1> ndit) const
6868
{
69+
static_assert(vec_sz > 0);
70+
static_assert(n_vecs > 0);
71+
static_assert(vec_sz * n_vecs < (std::uint32_t(1) << 8));
72+
73+
constexpr std::uint8_t elems_per_wi =
74+
static_cast<std::uint8_t>(vec_sz * n_vecs);
75+
6976
using dpctl::tensor::type_utils::is_complex;
7077
if constexpr (!enable_sg_loadstore || is_complex<T>::value) {
71-
const std::uint32_t sgSize =
78+
const std::uint16_t sgSize =
7279
ndit.get_sub_group().get_local_range()[0];
7380
const std::size_t gid = ndit.get_global_linear_id();
7481

75-
const std::size_t base =
76-
(gid / sgSize) * sgSize * n_vecs * vec_sz + (gid % sgSize);
77-
for (size_t offset = base;
78-
offset < std::min(nelems, base + sgSize * (n_vecs * vec_sz));
79-
offset += sgSize)
80-
{
82+
// base = (gid / sgSize) * sgSize * elems_per_wi + (gid % sgSize)
83+
// gid % sgSize == gid - (gid / sgSize) * sgSize
84+
const std::size_t elems_per_sg = sgSize * (elems_per_wi - 1);
85+
const std::size_t base = (gid / sgSize) * elems_per_sg + gid;
86+
const std::size_t offset_max =
87+
std::min(nelems, base + sgSize * elems_per_wi);
88+
89+
for (size_t offset = base; offset < offset_max; offset += sgSize) {
8190
auto src_offset = src_indexer(offset);
8291
dst_p[offset] = src_p[src_offset];
8392
}
8493
}
8594
else {
8695
auto sg = ndit.get_sub_group();
87-
const std::uint32_t sgSize = sg.get_local_range()[0];
88-
const size_t base = n_vecs * vec_sz *
89-
(ndit.get_group(0) * ndit.get_local_range(0) +
90-
sg.get_group_id()[0] * sgSize);
96+
const std::uint16_t sgSize = sg.get_max_local_range()[0];
97+
const size_t base =
98+
elems_per_wi * (ndit.get_group(0) * ndit.get_local_range(0) +
99+
sg.get_group_id()[0] * sgSize);
91100

92-
if (base + n_vecs * vec_sz * sgSize < nelems) {
101+
if (base + elems_per_wi * sgSize < nelems) {
93102
sycl::vec<T, vec_sz> dst_vec;
94103

95104
#pragma unroll
96-
for (std::uint32_t it = 0; it < n_vecs * vec_sz; it += vec_sz) {
105+
for (std::uint8_t it = 0; it < elems_per_wi; it += vec_sz) {
106+
const size_t block_start_id = base + it * sgSize;
97107
auto dst_multi_ptr = sycl::address_space_cast<
98108
sycl::access::address_space::global_space,
99-
sycl::access::decorated::yes>(
100-
&dst_p[base + it * sgSize]);
109+
sycl::access::decorated::yes>(&dst_p[block_start_id]);
101110

111+
const size_t elem_id0 = block_start_id + sg.get_local_id();
102112
#pragma unroll
103-
for (std::uint32_t k = 0; k < vec_sz; k++) {
104-
ssize_t src_offset = src_indexer(
105-
base + (it + k) * sgSize + sg.get_local_id());
113+
for (std::uint8_t k = 0; k < vec_sz; k++) {
114+
const size_t elem_id = elem_id0 + k * sgSize;
115+
const ssize_t src_offset = src_indexer(elem_id);
106116
dst_vec[k] = src_p[src_offset];
107117
}
108118
sg.store<vec_sz>(dst_multi_ptr, dst_vec);
109119
}
110120
}
111121
else {
112-
for (size_t k = base + sg.get_local_id()[0]; k < nelems;
113-
k += sgSize)
114-
{
115-
ssize_t src_offset = src_indexer(k);
122+
const size_t lane_id = sg.get_local_id()[0];
123+
const size_t k0 = base + lane_id;
124+
for (size_t k = k0; k < nelems; k += sgSize) {
125+
const ssize_t src_offset = src_indexer(k);
116126
dst_p[k] = src_p[src_offset];
117127
}
118128
}
@@ -122,8 +132,8 @@ class CopyAsCContigFunctor
122132

123133
template <typename T,
124134
typename IndexerT,
125-
std::uint32_t n_vecs,
126135
std::uint32_t vec_sz,
136+
std::uint32_t n_vecs,
127137
bool enable_sg_load,
128138
typename KernelName>
129139
sycl::event submit_c_contiguous_copy(sycl::queue &exec_q,
@@ -133,6 +143,10 @@ sycl::event submit_c_contiguous_copy(sycl::queue &exec_q,
133143
const IndexerT &src_indexer,
134144
const std::vector<sycl::event> &depends)
135145
{
146+
static_assert(vec_sz > 0);
147+
static_assert(n_vecs > 0);
148+
static_assert(vec_sz * n_vecs < (std::uint32_t(1) << 8));
149+
136150
constexpr std::size_t preferred_lws = 256;
137151

138152
const auto &kernel_id = sycl::get_kernel_id<KernelName>();
@@ -150,9 +164,11 @@ sycl::event submit_c_contiguous_copy(sycl::queue &exec_q,
150164
const std::size_t lws =
151165
((preferred_lws + max_sg_size - 1) / max_sg_size) * max_sg_size;
152166

153-
constexpr std::uint32_t nelems_per_wi = n_vecs * vec_sz;
154-
size_t n_groups =
155-
(nelems + nelems_per_wi * lws - 1) / (nelems_per_wi * lws);
167+
constexpr std::uint8_t nelems_per_wi = n_vecs * vec_sz;
168+
169+
const size_t nelems_per_group = nelems_per_wi * lws;
170+
const size_t n_groups =
171+
(nelems + nelems_per_group - 1) / (nelems_per_group);
156172

157173
sycl::event copy_ev = exec_q.submit([&](sycl::handler &cgh) {
158174
cgh.depends_on(depends);
@@ -171,8 +187,8 @@ sycl::event submit_c_contiguous_copy(sycl::queue &exec_q,
171187

172188
template <typename T,
173189
typename IndexT,
174-
int vec_sz,
175-
int n_vecs,
190+
std::uint32_t vec_sz,
191+
std::uint32_t n_vecs,
176192
bool enable_sgload>
177193
class as_contig_krn;
178194

@@ -194,8 +210,8 @@ as_c_contiguous_array_generic_impl(sycl::queue &exec_q,
194210
using IndexerT = dpctl::tensor::offset_utils::StridedIndexer;
195211
const IndexerT src_indexer(nd, ssize_t(0), shape_and_strides);
196212

197-
constexpr std::uint32_t n_vecs = 2;
198-
constexpr std::uint32_t vec_sz = 4;
213+
constexpr std::uint32_t vec_sz = 4u;
214+
constexpr std::uint32_t n_vecs = 2u;
199215

200216
using dpctl::tensor::kernels::alignment_utils::
201217
disabled_sg_loadstore_wrapper_krn;
@@ -207,7 +223,7 @@ as_c_contiguous_array_generic_impl(sycl::queue &exec_q,
207223
constexpr bool enable_sg_load = true;
208224
using KernelName =
209225
as_contig_krn<T, IndexerT, vec_sz, n_vecs, enable_sg_load>;
210-
copy_ev = submit_c_contiguous_copy<T, IndexerT, n_vecs, vec_sz,
226+
copy_ev = submit_c_contiguous_copy<T, IndexerT, vec_sz, n_vecs,
211227
enable_sg_load, KernelName>(
212228
exec_q, nelems, src_tp, dst_tp, src_indexer, depends);
213229
}
@@ -216,7 +232,7 @@ as_c_contiguous_array_generic_impl(sycl::queue &exec_q,
216232
using InnerKernelName =
217233
as_contig_krn<T, IndexerT, vec_sz, n_vecs, disable_sg_load>;
218234
using KernelName = disabled_sg_loadstore_wrapper_krn<InnerKernelName>;
219-
copy_ev = submit_c_contiguous_copy<T, IndexerT, n_vecs, vec_sz,
235+
copy_ev = submit_c_contiguous_copy<T, IndexerT, vec_sz, n_vecs,
220236
disable_sg_load, KernelName>(
221237
exec_q, nelems, src_tp, dst_tp, src_indexer, depends);
222238
}

0 commit comments

Comments
 (0)