@@ -44,8 +44,8 @@ namespace copy_as_contig
 
 template <typename T,
           typename IndexerT,
-          int vec_sz = 4,
-          int n_vecs = 2,
+          unsigned int vec_sz = 4,
+          unsigned int n_vecs = 2,
           bool enable_sg_loadstore = true>
 class CopyAsCContigFunctor
 {
@@ -66,52 +66,59 @@ class CopyAsCContigFunctor
 
     void operator()(sycl::nd_item<1> ndit) const
     {
+        constexpr std::uint8_t elems_per_wi =
+            static_cast<std::uint8_t>(vec_sz * n_vecs);
+
         using dpctl::tensor::type_utils::is_complex;
+
         if constexpr (!enable_sg_loadstore || is_complex<T>::value) {
-            const std::uint32_t sgSize =
-                ndit.get_sub_group().get_local_range()[0];
+            const std::uint16_t sgSize =
+                ndit.get_sub_group().get_max_local_range()[0];
             const std::size_t gid = ndit.get_global_linear_id();
 
-            const std::size_t base =
-                (gid / sgSize) * sgSize * n_vecs * vec_sz + (gid % sgSize);
-            for (size_t offset = base;
-                 offset < std::min(nelems, base + sgSize * (n_vecs * vec_sz));
-                 offset += sgSize)
-            {
+            // base = (gid / sgSize) * sgSize * elems_per_wi + (gid % sgSize)
+            // gid % sgSize == gid - (gid / sgSize) * sgSize
+            const std::size_t elems_per_sg = sgSize * (elems_per_wi - 1);
+            const std::size_t base = (gid / sgSize) * elems_per_sg + gid;
+            const std::size_t offset_max =
+                std::min(nelems, base + sgSize * elems_per_wi);
+
+            for (size_t offset = base; offset < offset_max; offset += sgSize) {
                 auto src_offset = src_indexer(offset);
                 dst_p[offset] = src_p[src_offset];
             }
         }
         else {
             auto sg = ndit.get_sub_group();
-            const std::uint32_t sgSize = sg.get_local_range()[0];
-            const size_t base = n_vecs * vec_sz *
-                                (ndit.get_group(0) * ndit.get_local_range(0) +
-                                 sg.get_group_id()[0] * sgSize);
+            const std::uint16_t sgSize = sg.get_max_local_range()[0];
+            const size_t base =
+                elems_per_wi * (ndit.get_group(0) * ndit.get_local_range(0) +
+                                sg.get_group_id()[0] * sgSize);
 
-            if (base + n_vecs * vec_sz * sgSize < nelems) {
+            if (base + elems_per_wi * sgSize < nelems) {
                 sycl::vec<T, vec_sz> dst_vec;
 
 #pragma unroll
-                for (std::uint32_t it = 0; it < n_vecs * vec_sz; it += vec_sz) {
+                for (std::uint8_t it = 0; it < elems_per_wi; it += vec_sz) {
+                    const size_t block_start_id = base + it * sgSize;
+
                     auto dst_multi_ptr = sycl::address_space_cast<
                         sycl::access::address_space::global_space,
-                        sycl::access::decorated::yes>(
-                        &dst_p[base + it * sgSize]);
+                        sycl::access::decorated::yes>(&dst_p[block_start_id]);
 
+                    const size_t elem_id0 = block_start_id + sg.get_local_id();
 #pragma unroll
-                    for (std::uint32_t k = 0; k < vec_sz; k++) {
-                        ssize_t src_offset = src_indexer(
-                            base + (it + k) * sgSize + sg.get_local_id());
+                    for (std::uint8_t k = 0; k < vec_sz; k++) {
+                        const size_t elem_id = elem_id0 + k * sgSize;
+                        const ssize_t src_offset = src_indexer(elem_id);
                         dst_vec[k] = src_p[src_offset];
                     }
                     sg.store<vec_sz>(dst_multi_ptr, dst_vec);
                 }
             }
             else {
-                for (size_t k = base + sg.get_local_id()[0]; k < nelems;
-                     k += sgSize)
-                {
+                const size_t lane_id = sg.get_local_id()[0];
+                const size_t k0 = base + lane_id;
+                for (size_t k = k0; k < nelems; k += sgSize) {
                     ssize_t src_offset = src_indexer(k);
                     dst_p[k] = src_p[src_offset];
                 }
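
Side note, not part of the patch: the rewritten base expression in the first branch is algebraically identical to the one it replaces, since gid % sgSize == gid - (gid / sgSize) * sgSize. A minimal host-side sketch (with arbitrary sub-group sizes and the default vec_sz = 4, n_vecs = 2) that checks the equivalence:

// Standalone illustration only; verifies that
//   (gid / sgSize) * (sgSize * (elems_per_wi - 1)) + gid
// equals
//   (gid / sgSize) * sgSize * elems_per_wi + (gid % sgSize)
#include <cassert>
#include <cstddef>
#include <cstdint>

int main()
{
    constexpr std::uint8_t elems_per_wi = 4 * 2; // vec_sz * n_vecs
    for (std::size_t sgSize : {8u, 16u, 32u}) {  // hypothetical sub-group sizes
        for (std::size_t gid = 0; gid < 4096; ++gid) {
            const std::size_t original =
                (gid / sgSize) * sgSize * elems_per_wi + (gid % sgSize);
            const std::size_t elems_per_sg = sgSize * (elems_per_wi - 1);
            const std::size_t refactored = (gid / sgSize) * elems_per_sg + gid;
            assert(original == refactored);
        }
    }
    return 0;
}
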
@@ -122,8 +129,8 @@ class CopyAsCContigFunctor
 
 template <typename T,
           typename IndexT,
-          int vec_sz,
-          int n_vecs,
+          std::uint32_t vec_sz,
+          std::uint32_t n_vecs,
           bool enable_sgload>
 class as_contig_krn;
 
@@ -154,6 +161,9 @@ as_c_contiguous_array_generic_impl(sycl::queue &exec_q,
     using dpctl::tensor::kernels::alignment_utils::is_aligned;
     using dpctl::tensor::kernels::alignment_utils::required_alignment;
 
+    constexpr std::uint8_t nelems_per_wi =
+        static_cast<std::uint8_t>(n_vecs * vec_sz);
+
     sycl::event copy_ev;
     if (is_aligned<required_alignment>(src_p) &&
         is_aligned<required_alignment>(dst_p))
@@ -177,7 +187,6 @@ as_c_contiguous_array_generic_impl(sycl::queue &exec_q,
         const std::size_t lws =
             ((preferred_lws + max_sg_size - 1) / max_sg_size) * max_sg_size;
 
-        constexpr std::uint32_t nelems_per_wi = n_vecs * vec_sz;
         size_t n_groups =
             (nelems + nelems_per_wi * lws - 1) / (nelems_per_wi * lws);
 
@@ -216,7 +225,6 @@ as_c_contiguous_array_generic_impl(sycl::queue &exec_q,
         const std::size_t lws =
             ((preferred_lws + max_sg_size - 1) / max_sg_size) * max_sg_size;
 
-        constexpr std::uint32_t nelems_per_wi = n_vecs * vec_sz;
         size_t n_groups =
             (nelems + nelems_per_wi * lws - 1) / (nelems_per_wi * lws);
 
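
Side note, not part of the patch: with nelems_per_wi hoisted to a single constexpr near the top of as_c_contiguous_array_generic_impl, both launch paths compute the same rounded-up group count. A small host-side sketch of that arithmetic, with hypothetical nelems and lws values:

// Standalone illustration only: ceil-division so that n_groups * lws work-items,
// each handling nelems_per_wi elements, cover every element.
#include <cstddef>
#include <cstdint>
#include <iostream>

int main()
{
    const std::size_t nelems = 1000003;           // hypothetical number of elements
    const std::size_t lws = 128;                  // hypothetical work-group size
    constexpr std::uint8_t nelems_per_wi = 4 * 2; // vec_sz * n_vecs

    const std::size_t n_groups =
        (nelems + nelems_per_wi * lws - 1) / (nelems_per_wi * lws);
    const std::size_t covered = n_groups * lws * nelems_per_wi;

    std::cout << "n_groups = " << n_groups << ", elements covered = " << covered
              << " (>= " << nelems << ")\n";
    return 0;
}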