Commit b8cf169

Tweaks to copy_as_contiguous CContig functor

1. Save common subexpressions to variables.
2. Change the sub-group size type to uint16 (from uint32).
3. Replace sg.get_local_range() with sg.get_max_local_range(). This is safe to do since the work-group size is chosen to be a multiple of the sub-group size for all possible choices of sub-group size (1, 8, 16, 32, 64).
4. Simplify the computation of the base value in the generic branch (taken for complex types, or when sg_load is disabled) to avoid a division, and leave a comment explaining the algebra.
1 parent 6df2811 commit b8cf169
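
Item 4 is index algebra: since `gid % sgSize == gid - (gid / sgSize) * sgSize`, the modulo folds into the quotient term, removing one of the two integer divisions. A minimal standalone check of that equivalence (a sketch, not part of the commit; the gid range and elems_per_wi value are arbitrary):

```cpp
// Sketch verifying base_old == base_new for the rewritten index computation.
// elems_per_wi = vec_sz * n_vecs; 8 matches the defaults vec_sz = 4, n_vecs = 2.
#include <cassert>
#include <cstddef>

int main()
{
    constexpr std::size_t elems_per_wi = 8;
    for (std::size_t sgSize : {1u, 8u, 16u, 32u, 64u}) {
        for (std::size_t gid = 0; gid < 4096; ++gid) {
            // original form: one division plus one modulo
            const std::size_t base_old =
                (gid / sgSize) * sgSize * elems_per_wi + (gid % sgSize);
            // rewritten form: gid % sgSize == gid - (gid / sgSize) * sgSize,
            // so only the division survives
            const std::size_t elems_per_sg = sgSize * (elems_per_wi - 1);
            const std::size_t base_new = (gid / sgSize) * elems_per_sg + gid;
            assert(base_old == base_new);
        }
    }
}
```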

File tree

1 file changed: +36 −28


dpctl/tensor/libtensor/include/kernels/copy_as_contiguous.hpp

Lines changed: 36 additions & 28 deletions
```diff
@@ -44,8 +44,8 @@ namespace copy_as_contig
 
 template <typename T,
           typename IndexerT,
-          int vec_sz = 4,
-          int n_vecs = 2,
+          unsigned int vec_sz = 4,
+          unsigned int n_vecs = 2,
           bool enable_sg_loadstore = true>
 class CopyAsCContigFunctor
 {
@@ -66,52 +66,59 @@ class CopyAsCContigFunctor
 
     void operator()(sycl::nd_item<1> ndit) const
     {
+        constexpr std::uint8_t elems_per_wi =
+            static_cast<std::uint8_t>(vec_sz * n_vecs);
         using dpctl::tensor::type_utils::is_complex;
+
         if constexpr (!enable_sg_loadstore || is_complex<T>::value) {
-            const std::uint32_t sgSize =
-                ndit.get_sub_group().get_local_range()[0];
+            const std::uint16_t sgSize =
+                ndit.get_sub_group().get_max_local_range()[0];
             const std::size_t gid = ndit.get_global_linear_id();
 
-            const std::size_t base =
-                (gid / sgSize) * sgSize * n_vecs * vec_sz + (gid % sgSize);
-            for (size_t offset = base;
-                 offset < std::min(nelems, base + sgSize * (n_vecs * vec_sz));
-                 offset += sgSize)
-            {
+            // base = (gid / sgSize) * sgSize * elems_per_wi + (gid % sgSize)
+            // gid % sgSize == gid - (gid / sgSize) * sgSize
+            const std::size_t elems_per_sg = sgSize * (elems_per_wi - 1);
+            const std::size_t base = (gid / sgSize) * elems_per_sg + gid;
+            const std::size_t offset_max =
+                std::min(nelems, base + sgSize * elems_per_wi);
+
+            for (size_t offset = base; offset < offset_max; offset += sgSize) {
                 auto src_offset = src_indexer(offset);
                 dst_p[offset] = src_p[src_offset];
             }
         }
         else {
             auto sg = ndit.get_sub_group();
-            const std::uint32_t sgSize = sg.get_local_range()[0];
-            const size_t base = n_vecs * vec_sz *
-                                (ndit.get_group(0) * ndit.get_local_range(0) +
-                                 sg.get_group_id()[0] * sgSize);
+            const std::uint16_t sgSize = sg.get_max_local_range()[0];
+            const size_t base =
+                elems_per_wi * (ndit.get_group(0) * ndit.get_local_range(0) +
+                                sg.get_group_id()[0] * sgSize);
 
-            if (base + n_vecs * vec_sz * sgSize < nelems) {
+            if (base + elems_per_wi * sgSize < nelems) {
                 sycl::vec<T, vec_sz> dst_vec;
 
 #pragma unroll
-                for (std::uint32_t it = 0; it < n_vecs * vec_sz; it += vec_sz) {
+                for (std::uint8_t it = 0; it < elems_per_wi; it += vec_sz) {
+                    const size_t block_start_id = base + it * sgSize;
+
                     auto dst_multi_ptr = sycl::address_space_cast<
                         sycl::access::address_space::global_space,
-                        sycl::access::decorated::yes>(
-                        &dst_p[base + it * sgSize]);
+                        sycl::access::decorated::yes>(&dst_p[block_start_id]);
 
+                    const size_t elem_id0 = block_start_id + sg.get_local_id();
 #pragma unroll
-                    for (std::uint32_t k = 0; k < vec_sz; k++) {
-                        ssize_t src_offset = src_indexer(
-                            base + (it + k) * sgSize + sg.get_local_id());
+                    for (std::uint8_t k = 0; k < vec_sz; k++) {
+                        const size_t elem_id = elem_id0 + k * sgSize;
+                        const ssize_t src_offset = src_indexer(elem_id);
                         dst_vec[k] = src_p[src_offset];
                     }
                     sg.store<vec_sz>(dst_multi_ptr, dst_vec);
                 }
             }
             else {
-                for (size_t k = base + sg.get_local_id()[0]; k < nelems;
-                     k += sgSize)
-                {
+                const size_t lane_id = sg.get_local_id()[0];
+                const size_t k0 = base + lane_id;
+                for (size_t k = k0; k < nelems; k += sgSize) {
                     ssize_t src_offset = src_indexer(k);
                     dst_p[k] = src_p[src_offset];
                 }
@@ -122,8 +129,8 @@ class CopyAsCContigFunctor
 
 template <typename T,
           typename IndexT,
-          int vec_sz,
-          int n_vecs,
+          std::uint32_t vec_sz,
+          std::uint32_t n_vecs,
           bool enable_sgload>
 class as_contig_krn;
 
@@ -154,6 +161,9 @@ as_c_contiguous_array_generic_impl(sycl::queue &exec_q,
     using dpctl::tensor::kernels::alignment_utils::is_aligned;
     using dpctl::tensor::kernels::alignment_utils::required_alignment;
 
+    constexpr std::uint8_t nelems_per_wi =
+        static_cast<std::uint8_t>(n_vecs * vec_sz);
+
     sycl::event copy_ev;
     if (is_aligned<required_alignment>(src_p) &&
         is_aligned<required_alignment>(dst_p))
@@ -177,7 +187,6 @@ as_c_contiguous_array_generic_impl(sycl::queue &exec_q,
     const std::size_t lws =
         ((preferred_lws + max_sg_size - 1) / max_sg_size) * max_sg_size;
 
-    constexpr std::uint32_t nelems_per_wi = n_vecs * vec_sz;
     size_t n_groups =
         (nelems + nelems_per_wi * lws - 1) / (nelems_per_wi * lws);
 
@@ -216,7 +225,6 @@ as_c_contiguous_array_generic_impl(sycl::queue &exec_q,
     const std::size_t lws =
         ((preferred_lws + max_sg_size - 1) / max_sg_size) * max_sg_size;
 
-    constexpr std::uint32_t nelems_per_wi = n_vecs * vec_sz;
     size_t n_groups =
         (nelems + nelems_per_wi * lws - 1) / (nelems_per_wi * lws);
 
```
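For context, here is a host-side sketch (not from the commit) that simulates the index arithmetic of the vectorized branch: each lane of a sub-group loads `vec_sz` strided elements per unrolled iteration, and together the lanes cover a contiguous block of `sgSize * elems_per_wi` elements exactly once. The sub-group size of 16 is an assumed value:

```cpp
// Sketch: replay the elem_id computation from the vectorized branch on the
// host and assert the sub-group's block is covered with no gaps or overlaps.
#include <cassert>
#include <cstddef>
#include <vector>

int main()
{
    constexpr std::size_t vec_sz = 4, n_vecs = 2;
    constexpr std::size_t elems_per_wi = vec_sz * n_vecs;
    const std::size_t sgSize = 16; // assumed sub-group size
    const std::size_t base = 0;    // start of this sub-group's block

    std::vector<int> touched(sgSize * elems_per_wi, 0);
    for (std::size_t lane = 0; lane < sgSize; ++lane) { // sg.get_local_id()
        for (std::size_t it = 0; it < elems_per_wi; it += vec_sz) {
            const std::size_t block_start_id = base + it * sgSize;
            const std::size_t elem_id0 = block_start_id + lane;
            for (std::size_t k = 0; k < vec_sz; ++k) {
                const std::size_t elem_id = elem_id0 + k * sgSize;
                ++touched[elem_id]; // kernel reads src_p[src_indexer(elem_id)]
            }
        }
    }
    for (int t : touched)
        assert(t == 1); // every element of the block visited exactly once
}
```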