Skip to content

Commit 18f7ea0

Browse files
Added static assert to verify that n_vecs * vec_sz fits in uint8_t
Also reordered the template parameters to vec_sz, n_vecs for consistency with the wider code base.
1 parent 04fd35c commit 18f7ea0

File tree

1 file changed

+9
-5
lines changed

1 file changed

+9
-5
lines changed

dpctl/tensor/libtensor/include/kernels/copy_as_contiguous.hpp

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -132,8 +132,8 @@ class CopyAsCContigFunctor
132132

133133
template <typename T,
134134
typename IndexerT,
135-
std::uint32_t n_vecs,
136135
std::uint32_t vec_sz,
136+
std::uint32_t n_vecs,
137137
bool enable_sg_load,
138138
typename KernelName>
139139
sycl::event submit_c_contiguous_copy(sycl::queue &exec_q,
@@ -143,6 +143,10 @@ sycl::event submit_c_contiguous_copy(sycl::queue &exec_q,
143143
const IndexerT &src_indexer,
144144
const std::vector<sycl::event> &depends)
145145
{
146+
static_assert(vec_sz > 0);
147+
static_assert(n_vecs > 0);
148+
static_assert(vec_sz * n_vecs < (std::uint32_t(1) << 8));
149+
146150
constexpr std::size_t preferred_lws = 256;
147151

148152
const auto &kernel_id = sycl::get_kernel_id<KernelName>();
@@ -206,8 +210,8 @@ as_c_contiguous_array_generic_impl(sycl::queue &exec_q,
206210
using IndexerT = dpctl::tensor::offset_utils::StridedIndexer;
207211
const IndexerT src_indexer(nd, ssize_t(0), shape_and_strides);
208212

209-
constexpr std::uint32_t n_vecs = 2;
210-
constexpr std::uint32_t vec_sz = 4;
213+
constexpr std::uint32_t vec_sz = 4u;
214+
constexpr std::uint32_t n_vecs = 2u;
211215

212216
using dpctl::tensor::kernels::alignment_utils::
213217
disabled_sg_loadstore_wrapper_krn;
@@ -219,7 +223,7 @@ as_c_contiguous_array_generic_impl(sycl::queue &exec_q,
219223
constexpr bool enable_sg_load = true;
220224
using KernelName =
221225
as_contig_krn<T, IndexerT, vec_sz, n_vecs, enable_sg_load>;
222-
copy_ev = submit_c_contiguous_copy<T, IndexerT, n_vecs, vec_sz,
226+
copy_ev = submit_c_contiguous_copy<T, IndexerT, vec_sz, n_vecs,
223227
enable_sg_load, KernelName>(
224228
exec_q, nelems, src_tp, dst_tp, src_indexer, depends);
225229
}
@@ -228,7 +232,7 @@ as_c_contiguous_array_generic_impl(sycl::queue &exec_q,
228232
using InnerKernelName =
229233
as_contig_krn<T, IndexerT, vec_sz, n_vecs, disable_sg_load>;
230234
using KernelName = disabled_sg_loadstore_wrapper_krn<InnerKernelName>;
231-
copy_ev = submit_c_contiguous_copy<T, IndexerT, n_vecs, vec_sz,
235+
copy_ev = submit_c_contiguous_copy<T, IndexerT, vec_sz, n_vecs,
232236
disable_sg_load, KernelName>(
233237
exec_q, nelems, src_tp, dst_tp, src_indexer, depends);
234238
}

0 commit comments

Comments
 (0)