Skip to content

Commit b97fcb4

Browse files
Replace sg.load/sg.store with sub_group_load/sub_group_store utilities
1 parent 968f937 commit b97fcb4

File tree

2 files changed: +75 -66 lines changed

dpctl/tensor/libtensor/include/kernels/clip.hpp

Lines changed: 37 additions & 32 deletions
Original file line number | Diff line number | Diff line change
@@ -33,6 +33,7 @@
3333
#include "kernels/alignment.hpp"
3434
#include "utils/math_utils.hpp"
3535
#include "utils/offset_utils.hpp"
36+
#include "utils/sycl_utils.hpp"
3637
#include "utils/type_utils.hpp"
3738

3839
namespace dpctl
@@ -51,6 +52,9 @@ using dpctl::tensor::kernels::alignment_utils::
5152
using dpctl::tensor::kernels::alignment_utils::is_aligned;
5253
using dpctl::tensor::kernels::alignment_utils::required_alignment;
5354

55+
using dpctl::tensor::sycl_utils::sub_group_load;
56+
using dpctl::tensor::sycl_utils::sub_group_store;
57+
5458
template <typename T> T clip(const T &x, const T &min, const T &max)
5559
{
5660
using dpctl::tensor::type_utils::is_complex;
@@ -75,8 +79,8 @@ template <typename T> T clip(const T &x, const T &min, const T &max)
7579
}
7680

7781
template <typename T,
78-
int vec_sz = 4,
79-
int n_vecs = 2,
82+
std::uint8_t vec_sz = 4,
83+
std::uint8_t n_vecs = 2,
8084
bool enable_sg_loadstore = true>
8185
class ClipContigFunctor
8286
{
@@ -100,37 +104,36 @@ class ClipContigFunctor
100104

101105
void operator()(sycl::nd_item<1> ndit) const
102106
{
107+
constexpr std::uint8_t nelems_per_wi = n_vecs * vec_sz;
108+
103109
using dpctl::tensor::type_utils::is_complex;
104110
if constexpr (is_complex<T>::value || !enable_sg_loadstore) {
105-
std::uint8_t sgSize = ndit.get_sub_group().get_local_range()[0];
106-
size_t base = ndit.get_global_linear_id();
107-
108-
base = (base / sgSize) * sgSize * n_vecs * vec_sz + (base % sgSize);
109-
for (size_t offset = base;
110-
offset < std::min(nelems, base + sgSize * (n_vecs * vec_sz));
111-
offset += sgSize)
112-
{
111+
const std::uint16_t sgSize =
112+
ndit.get_sub_group().get_local_range()[0];
113+
const size_t gid = ndit.get_global_linear_id();
114+
const uint16_t nelems_per_sg = sgSize * nelems_per_wi;
115+
116+
const size_t start =
117+
(gid / sgSize) * (nelems_per_sg - sgSize) + gid;
118+
const size_t end = std::min(nelems, start + nelems_per_sg);
119+
120+
for (size_t offset = start; offset < end; offset += sgSize) {
113121
dst_p[offset] = clip(x_p[offset], min_p[offset], max_p[offset]);
114122
}
115123
}
116124
else {
117125
auto sg = ndit.get_sub_group();
118-
std::uint8_t sgSize = sg.get_local_range()[0];
119-
std::uint8_t max_sgSize = sg.get_max_local_range()[0];
120-
size_t base = n_vecs * vec_sz *
121-
(ndit.get_group(0) * ndit.get_local_range(0) +
122-
sg.get_group_id()[0] * max_sgSize);
123-
124-
if (base + n_vecs * vec_sz * sgSize < nelems &&
125-
sgSize == max_sgSize)
126-
{
127-
sycl::vec<T, vec_sz> x_vec;
128-
sycl::vec<T, vec_sz> min_vec;
129-
sycl::vec<T, vec_sz> max_vec;
126+
const std::uint16_t sgSize = sg.get_max_local_range()[0];
127+
128+
const size_t base =
129+
nelems_per_wi * (ndit.get_group(0) * ndit.get_local_range(0) +
130+
sg.get_group_id()[0] * sgSize);
131+
132+
if (base + nelems_per_wi * sgSize < nelems) {
130133
sycl::vec<T, vec_sz> dst_vec;
131134
#pragma unroll
132135
for (std::uint8_t it = 0; it < n_vecs * vec_sz; it += vec_sz) {
133-
auto idx = base + it * sgSize;
136+
const size_t idx = base + it * sgSize;
134137
auto x_multi_ptr = sycl::address_space_cast<
135138
sycl::access::address_space::global_space,
136139
sycl::access::decorated::yes>(&x_p[idx]);
@@ -144,21 +147,23 @@ class ClipContigFunctor
144147
sycl::access::address_space::global_space,
145148
sycl::access::decorated::yes>(&dst_p[idx]);
146149

147-
x_vec = sg.load<vec_sz>(x_multi_ptr);
148-
min_vec = sg.load<vec_sz>(min_multi_ptr);
149-
max_vec = sg.load<vec_sz>(max_multi_ptr);
150+
const sycl::vec<T, vec_sz> x_vec =
151+
sub_group_load<vec_sz>(sg, x_multi_ptr);
152+
const sycl::vec<T, vec_sz> min_vec =
153+
sub_group_load<vec_sz>(sg, min_multi_ptr);
154+
const sycl::vec<T, vec_sz> max_vec =
155+
sub_group_load<vec_sz>(sg, max_multi_ptr);
150156
#pragma unroll
151157
for (std::uint8_t vec_id = 0; vec_id < vec_sz; ++vec_id) {
152158
dst_vec[vec_id] = clip(x_vec[vec_id], min_vec[vec_id],
153159
max_vec[vec_id]);
154160
}
155-
sg.store<vec_sz>(dst_multi_ptr, dst_vec);
161+
sub_group_store<vec_sz>(sg, dst_vec, dst_multi_ptr);
156162
}
157163
}
158164
else {
159-
for (size_t k = base + sg.get_local_id()[0]; k < nelems;
160-
k += sgSize)
161-
{
165+
const size_t lane_id = sg.get_local_id()[0];
166+
for (size_t k = base + lane_id; k < nelems; k += sgSize) {
162167
dst_p[k] = clip(x_p[k], min_p[k], max_p[k]);
163168
}
164169
}
@@ -195,8 +200,8 @@ sycl::event clip_contig_impl(sycl::queue &q,
195200
cgh.depends_on(depends);
196201

197202
size_t lws = 64;
198-
constexpr unsigned int vec_sz = 4;
199-
constexpr unsigned int n_vecs = 2;
203+
constexpr std::uint8_t vec_sz = 4;
204+
constexpr std::uint8_t n_vecs = 2;
200205
const size_t n_groups =
201206
((nelems + lws * n_vecs * vec_sz - 1) / (lws * n_vecs * vec_sz));
202207
const auto gws_range = sycl::range<1>(n_groups * lws);

dpctl/tensor/libtensor/include/kernels/where.hpp

Lines changed: 38 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
#include "dpctl_tensor_types.hpp"
3333
#include "kernels/alignment.hpp"
3434
#include "utils/offset_utils.hpp"
35+
#include "utils/sycl_utils.hpp"
3536
#include "utils/type_utils.hpp"
3637

3738
namespace dpctl
@@ -50,15 +51,18 @@ using dpctl::tensor::kernels::alignment_utils::
5051
using dpctl::tensor::kernels::alignment_utils::is_aligned;
5152
using dpctl::tensor::kernels::alignment_utils::required_alignment;
5253

54+
using dpctl::tensor::sycl_utils::sub_group_load;
55+
using dpctl::tensor::sycl_utils::sub_group_store;
56+
5357
template <typename T, typename condT, typename IndexerT>
5458
class where_strided_kernel;
55-
template <typename T, typename condT, int vec_sz, int n_vecs>
59+
template <typename T, typename condT, std::uint8_t vec_sz, std::uint8_t n_vecs>
5660
class where_contig_kernel;
5761

5862
template <typename T,
5963
typename condT,
60-
int vec_sz = 4,
61-
int n_vecs = 2,
64+
std::uint8_t vec_sz = 4u,
65+
std::uint8_t n_vecs = 2u,
6266
bool enable_sg_loadstore = true>
6367
class WhereContigFunctor
6468
{
@@ -82,42 +86,40 @@ class WhereContigFunctor
8286

8387
void operator()(sycl::nd_item<1> ndit) const
8488
{
89+
constexpr std::uint8_t nelems_per_wi = n_vecs * vec_sz;
90+
8591
using dpctl::tensor::type_utils::is_complex;
8692
if constexpr (!enable_sg_loadstore || is_complex<condT>::value ||
8793
is_complex<T>::value)
8894
{
89-
std::uint8_t sgSize = ndit.get_sub_group().get_local_range()[0];
90-
size_t base = ndit.get_global_linear_id();
91-
92-
base = (base / sgSize) * sgSize * n_vecs * vec_sz + (base % sgSize);
93-
for (size_t offset = base;
94-
offset < std::min(nelems, base + sgSize * (n_vecs * vec_sz));
95-
offset += sgSize)
96-
{
95+
const std::uint16_t sgSize =
96+
ndit.get_sub_group().get_local_range()[0];
97+
const size_t gid = ndit.get_global_linear_id();
98+
99+
const std::uint16_t nelems_per_sg = sgSize * nelems_per_wi;
100+
const size_t start =
101+
(gid / sgSize) * (nelems_per_sg - sgSize) + gid;
102+
const size_t end = std::min(nelems, start + nelems_per_sg);
103+
for (size_t offset = start; offset < end; offset += sgSize) {
97104
using dpctl::tensor::type_utils::convert_impl;
98-
bool check = convert_impl<bool, condT>(cond_p[offset]);
105+
const bool check = convert_impl<bool, condT>(cond_p[offset]);
99106
dst_p[offset] = check ? x1_p[offset] : x2_p[offset];
100107
}
101108
}
102109
else {
103110
auto sg = ndit.get_sub_group();
104-
std::uint8_t sgSize = sg.get_local_range()[0];
105-
std::uint8_t max_sgSize = sg.get_max_local_range()[0];
106-
size_t base = n_vecs * vec_sz *
107-
(ndit.get_group(0) * ndit.get_local_range(0) +
108-
sg.get_group_id()[0] * max_sgSize);
109-
110-
if (base + n_vecs * vec_sz * sgSize < nelems &&
111-
sgSize == max_sgSize)
112-
{
111+
const std::uint16_t sgSize = sg.get_max_local_range()[0];
112+
113+
const size_t base =
114+
nelems_per_wi * (ndit.get_group(0) * ndit.get_local_range(0) +
115+
sg.get_group_id()[0] * sgSize);
116+
117+
if (base + nelems_per_wi * sgSize < nelems) {
113118
sycl::vec<T, vec_sz> dst_vec;
114-
sycl::vec<T, vec_sz> x1_vec;
115-
sycl::vec<T, vec_sz> x2_vec;
116-
sycl::vec<condT, vec_sz> cond_vec;
117119

118120
#pragma unroll
119121
for (std::uint8_t it = 0; it < n_vecs * vec_sz; it += vec_sz) {
120-
auto idx = base + it * sgSize;
122+
const size_t idx = base + it * sgSize;
121123
auto x1_multi_ptr = sycl::address_space_cast<
122124
sycl::access::address_space::global_space,
123125
sycl::access::decorated::yes>(&x1_p[idx]);
@@ -131,20 +133,22 @@ class WhereContigFunctor
131133
sycl::access::address_space::global_space,
132134
sycl::access::decorated::yes>(&dst_p[idx]);
133135

134-
x1_vec = sg.load<vec_sz>(x1_multi_ptr);
135-
x2_vec = sg.load<vec_sz>(x2_multi_ptr);
136-
cond_vec = sg.load<vec_sz>(cond_multi_ptr);
136+
const sycl::vec<T, vec_sz> x1_vec =
137+
sub_group_load<vec_sz>(sg, x1_multi_ptr);
138+
const sycl::vec<T, vec_sz> x2_vec =
139+
sub_group_load<vec_sz>(sg, x2_multi_ptr);
140+
const sycl::vec<condT, vec_sz> cond_vec =
141+
sub_group_load<vec_sz>(sg, cond_multi_ptr);
137142
#pragma unroll
138143
for (std::uint8_t k = 0; k < vec_sz; ++k) {
139144
dst_vec[k] = cond_vec[k] ? x1_vec[k] : x2_vec[k];
140145
}
141-
sg.store<vec_sz>(dst_multi_ptr, dst_vec);
146+
sub_group_store<vec_sz>(sg, dst_vec, dst_multi_ptr);
142147
}
143148
}
144149
else {
145-
for (size_t k = base + sg.get_local_id()[0]; k < nelems;
146-
k += sgSize)
147-
{
150+
const size_t lane_id = sg.get_local_id()[0];
151+
for (size_t k = base + lane_id; k < nelems; k += sgSize) {
148152
dst_p[k] = cond_p[k] ? x1_p[k] : x2_p[k];
149153
}
150154
}
@@ -179,8 +183,8 @@ sycl::event where_contig_impl(sycl::queue &q,
179183
cgh.depends_on(depends);
180184

181185
size_t lws = 64;
182-
constexpr unsigned int vec_sz = 4;
183-
constexpr unsigned int n_vecs = 2;
186+
constexpr std::uint8_t vec_sz = 4u;
187+
constexpr std::uint8_t n_vecs = 2u;
184188
const size_t n_groups =
185189
((nelems + lws * n_vecs * vec_sz - 1) / (lws * n_vecs * vec_sz));
186190
const auto gws_range = sycl::range<1>(n_groups * lws);

0 commit comments

Comments
 (0)