 #include "dpctl_tensor_types.hpp"
 #include "kernels/alignment.hpp"
 #include "utils/offset_utils.hpp"
+#include "utils/sycl_utils.hpp"
 #include "utils/type_utils.hpp"
 
 namespace dpctl
@@ -50,15 +51,18 @@ using dpctl::tensor::kernels::alignment_utils::
 using dpctl::tensor::kernels::alignment_utils::is_aligned;
 using dpctl::tensor::kernels::alignment_utils::required_alignment;
 
+using dpctl::tensor::sycl_utils::sub_group_load;
+using dpctl::tensor::sycl_utils::sub_group_store;
+
 template <typename T, typename condT, typename IndexerT>
 class where_strided_kernel;
-template <typename T, typename condT, int vec_sz, int n_vecs>
+template <typename T, typename condT, std::uint8_t vec_sz, std::uint8_t n_vecs>
 class where_contig_kernel;
 
 template <typename T,
           typename condT,
-          int vec_sz = 4,
-          int n_vecs = 2,
+          std::uint8_t vec_sz = 4u,
+          std::uint8_t n_vecs = 2u,
           bool enable_sg_loadstore = true>
 class WhereContigFunctor
 {
@@ -82,42 +86,40 @@ class WhereContigFunctor
 
     void operator()(sycl::nd_item<1> ndit) const
     {
+        constexpr std::uint8_t nelems_per_wi = n_vecs * vec_sz;
+
         using dpctl::tensor::type_utils::is_complex;
         if constexpr (!enable_sg_loadstore || is_complex<condT>::value ||
                       is_complex<T>::value)
         {
-            std::uint8_t sgSize = ndit.get_sub_group().get_local_range()[0];
-            size_t base = ndit.get_global_linear_id();
-
-            base = (base / sgSize) * sgSize * n_vecs * vec_sz + (base % sgSize);
-            for (size_t offset = base;
-                 offset < std::min(nelems, base + sgSize * (n_vecs * vec_sz));
-                 offset += sgSize)
-            {
+            const std::uint16_t sgSize =
+                ndit.get_sub_group().get_local_range()[0];
+            const size_t gid = ndit.get_global_linear_id();
+
+            const std::uint16_t nelems_per_sg = sgSize * nelems_per_wi;
+            const size_t start =
+                (gid / sgSize) * (nelems_per_sg - sgSize) + gid;
+            const size_t end = std::min(nelems, start + nelems_per_sg);
+            for (size_t offset = start; offset < end; offset += sgSize) {
                 using dpctl::tensor::type_utils::convert_impl;
-                bool check = convert_impl<bool, condT>(cond_p[offset]);
+                const bool check = convert_impl<bool, condT>(cond_p[offset]);
                 dst_p[offset] = check ? x1_p[offset] : x2_p[offset];
             }
         }
         else {
             auto sg = ndit.get_sub_group();
-            std::uint8_t sgSize = sg.get_local_range()[0];
-            std::uint8_t max_sgSize = sg.get_max_local_range()[0];
-            size_t base = n_vecs * vec_sz *
-                          (ndit.get_group(0) * ndit.get_local_range(0) +
-                           sg.get_group_id()[0] * max_sgSize);
-
-            if (base + n_vecs * vec_sz * sgSize < nelems &&
-                sgSize == max_sgSize)
-            {
+            const std::uint16_t sgSize = sg.get_max_local_range()[0];
+
+            const size_t base =
+                nelems_per_wi * (ndit.get_group(0) * ndit.get_local_range(0) +
+                                 sg.get_group_id()[0] * sgSize);
+
+            if (base + nelems_per_wi * sgSize < nelems) {
                 sycl::vec<T, vec_sz> dst_vec;
-                sycl::vec<T, vec_sz> x1_vec;
-                sycl::vec<T, vec_sz> x2_vec;
-                sycl::vec<condT, vec_sz> cond_vec;
 
 #pragma unroll
                 for (std::uint8_t it = 0; it < n_vecs * vec_sz; it += vec_sz) {
-                    auto idx = base + it * sgSize;
+                    const size_t idx = base + it * sgSize;
                     auto x1_multi_ptr = sycl::address_space_cast<
                         sycl::access::address_space::global_space,
                         sycl::access::decorated::yes>(&x1_p[idx]);
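Note on the reworked indexing in the scalar (non sub-group load/store) branch: with nelems_per_sg = sgSize * nelems_per_wi, the new expression (gid / sgSize) * (nelems_per_sg - sgSize) + gid expands to (gid / sgSize) * sgSize * nelems_per_wi + (gid % sgSize), i.e. the same starting offset the removed code computed in two steps. A minimal standalone check of that identity (the concrete sgSize and nelems_per_wi values below are illustrative, not taken from the patch):

    #include <cassert>
    #include <cstddef>

    int main()
    {
        const std::size_t sgSize = 32;        // hypothetical sub-group size
        const std::size_t nelems_per_wi = 8;  // n_vecs * vec_sz in the kernel
        const std::size_t nelems_per_sg = sgSize * nelems_per_wi;
        for (std::size_t gid = 0; gid < 4096; ++gid) {
            // new formulation used in the diff
            const std::size_t start =
                (gid / sgSize) * (nelems_per_sg - sgSize) + gid;
            // old formulation it replaces
            const std::size_t base =
                (gid / sgSize) * sgSize * nelems_per_wi + (gid % sgSize);
            assert(start == base);
        }
        return 0;
    }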
@@ -131,20 +133,22 @@ class WhereContigFunctor
                         sycl::access::address_space::global_space,
                         sycl::access::decorated::yes>(&dst_p[idx]);
 
-                    x1_vec = sg.load<vec_sz>(x1_multi_ptr);
-                    x2_vec = sg.load<vec_sz>(x2_multi_ptr);
-                    cond_vec = sg.load<vec_sz>(cond_multi_ptr);
+                    const sycl::vec<T, vec_sz> x1_vec =
+                        sub_group_load<vec_sz>(sg, x1_multi_ptr);
+                    const sycl::vec<T, vec_sz> x2_vec =
+                        sub_group_load<vec_sz>(sg, x2_multi_ptr);
+                    const sycl::vec<condT, vec_sz> cond_vec =
+                        sub_group_load<vec_sz>(sg, cond_multi_ptr);
 #pragma unroll
                     for (std::uint8_t k = 0; k < vec_sz; ++k) {
                         dst_vec[k] = cond_vec[k] ? x1_vec[k] : x2_vec[k];
                     }
-                    sg.store<vec_sz>(dst_multi_ptr, dst_vec);
+                    sub_group_store<vec_sz>(sg, dst_vec, dst_multi_ptr);
                 }
             }
             else {
-                for (size_t k = base + sg.get_local_id()[0]; k < nelems;
-                     k += sgSize)
-                {
+                const size_t lane_id = sg.get_local_id()[0];
+                for (size_t k = base + lane_id; k < nelems; k += sgSize) {
                     dst_p[k] = cond_p[k] ? x1_p[k] : x2_p[k];
                 }
             }
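In the vectorized branch, the sg.load<vec_sz>() / sg.store<vec_sz>() calls are swapped for the sub_group_load / sub_group_store helpers brought in from utils/sycl_utils.hpp, with the store taking the value before the destination pointer (sub_group_store<vec_sz>(sg, dst_vec, dst_multi_ptr)). The previous sgSize == max_sgSize guard around this path is dropped as well; the helpers appear intended to cope with partially occupied sub-groups themselves, though that is inferred from this diff rather than stated in it.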
@@ -179,8 +183,8 @@ sycl::event where_contig_impl(sycl::queue &q,
         cgh.depends_on(depends);
 
         size_t lws = 64;
-        constexpr unsigned int vec_sz = 4;
-        constexpr unsigned int n_vecs = 2;
+        constexpr std::uint8_t vec_sz = 4u;
+        constexpr std::uint8_t n_vecs = 2u;
         const size_t n_groups =
             ((nelems + lws * n_vecs * vec_sz - 1) / (lws * n_vecs * vec_sz));
         const auto gws_range = sycl::range<1>(n_groups * lws);
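For context, the host-side launch configuration is unchanged apart from the vec_sz / n_vecs types: each work-item still handles n_vecs * vec_sz = 8 elements, so a work-group of lws = 64 work-items covers 512 elements and n_groups is the ceiling division of nelems by that tile. A quick sketch of the arithmetic (the nelems value is hypothetical, chosen only for illustration):

    #include <cstddef>
    #include <cstdio>

    int main()
    {
        constexpr std::size_t lws = 64;   // work-group size used in where_contig_impl
        constexpr std::size_t vec_sz = 4;
        constexpr std::size_t n_vecs = 2;
        constexpr std::size_t tile = lws * n_vecs * vec_sz; // 512 elements per work-group

        const std::size_t nelems = 1000;  // hypothetical number of elements
        const std::size_t n_groups = (nelems + tile - 1) / tile; // ceiling division -> 2
        std::printf("n_groups = %zu\n", n_groups);
        return 0;
    }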