diff --git a/dpctl/tensor/libtensor/include/kernels/copy_as_contiguous.hpp b/dpctl/tensor/libtensor/include/kernels/copy_as_contiguous.hpp
index 8c0b7b65f5..c71e487572 100644
--- a/dpctl/tensor/libtensor/include/kernels/copy_as_contiguous.hpp
+++ b/dpctl/tensor/libtensor/include/kernels/copy_as_contiguous.hpp
@@ -44,8 +44,8 @@ namespace copy_as_contig
 template <...>
 class CopyAsCContigFunctor
 {
@@ -66,53 +66,63 @@ class CopyAsCContigFunctor
     void operator()(sycl::nd_item<1> ndit) const
     {
+        static_assert(vec_sz > 0);
+        static_assert(n_vecs > 0);
+        static_assert(vec_sz * n_vecs < (std::uint32_t(1) << 8));
+
+        constexpr std::uint8_t elems_per_wi =
+            static_cast<std::uint8_t>(vec_sz * n_vecs);
+
         using dpctl::tensor::type_utils::is_complex;
         if constexpr (!enable_sg_loadstore || is_complex<T>::value) {
-            const std::uint32_t sgSize =
+            const std::uint16_t sgSize =
                 ndit.get_sub_group().get_local_range()[0];
             const std::size_t gid = ndit.get_global_linear_id();
 
-            const std::size_t base =
-                (gid / sgSize) * sgSize * n_vecs * vec_sz + (gid % sgSize);
-            for (size_t offset = base;
-                 offset < std::min(nelems, base + sgSize * (n_vecs * vec_sz));
-                 offset += sgSize)
-            {
+            // base = (gid / sgSize) * sgSize * elems_per_wi + (gid % sgSize)
+            // gid % sgSize == gid - (gid / sgSize) * sgSize
+            const std::size_t elems_per_sg = sgSize * (elems_per_wi - 1);
+            const std::size_t base = (gid / sgSize) * elems_per_sg + gid;
+            const std::size_t offset_max =
+                std::min(nelems, base + sgSize * elems_per_wi);
+
+            for (size_t offset = base; offset < offset_max; offset += sgSize) {
                 auto src_offset = src_indexer(offset);
                 dst_p[offset] = src_p[src_offset];
             }
         }
         else {
             auto sg = ndit.get_sub_group();
-            const std::uint32_t sgSize = sg.get_local_range()[0];
-            const size_t base = n_vecs * vec_sz *
-                                (ndit.get_group(0) * ndit.get_local_range(0) +
-                                 sg.get_group_id()[0] * sgSize);
+            const std::uint16_t sgSize = sg.get_max_local_range()[0];
+            const size_t base =
+                elems_per_wi * (ndit.get_group(0) * ndit.get_local_range(0) +
+                                sg.get_group_id()[0] * sgSize);
 
-            if (base + n_vecs * vec_sz * sgSize < nelems) {
+            if (base + elems_per_wi * sgSize < nelems) {
                 sycl::vec<T, vec_sz> dst_vec;
 
 #pragma unroll
-                for (std::uint32_t it = 0; it < n_vecs * vec_sz; it += vec_sz) {
+                for (std::uint8_t it = 0; it < elems_per_wi; it += vec_sz) {
+                    const size_t block_start_id = base + it * sgSize;
                     auto dst_multi_ptr = sycl::address_space_cast<
                         sycl::access::address_space::global_space,
-                        sycl::access::decorated::yes>(
-                        &dst_p[base + it * sgSize]);
+                        sycl::access::decorated::yes>(&dst_p[block_start_id]);
 
+                    const size_t elem_id0 = block_start_id + sg.get_local_id();
 #pragma unroll
-                    for (std::uint32_t k = 0; k < vec_sz; k++) {
-                        ssize_t src_offset = src_indexer(
-                            base + (it + k) * sgSize + sg.get_local_id());
+                    for (std::uint8_t k = 0; k < vec_sz; k++) {
+                        const size_t elem_id = elem_id0 + k * sgSize;
+                        const ssize_t src_offset = src_indexer(elem_id);
                         dst_vec[k] = src_p[src_offset];
                     }
                     sg.store<vec_sz>(dst_multi_ptr, dst_vec);
                 }
             }
             else {
-                for (size_t k = base + sg.get_local_id()[0]; k < nelems;
-                     k += sgSize)
-                {
-                    ssize_t src_offset = src_indexer(k);
+                const size_t lane_id = sg.get_local_id()[0];
+                const size_t k0 = base + lane_id;
+                for (size_t k = k0; k < nelems; k += sgSize) {
+                    const ssize_t src_offset = src_indexer(k);
                     dst_p[k] = src_p[src_offset];
                 }
             }
@@ -122,8 +132,8 @@ class CopyAsCContigFunctor
 template <...>
 sycl::event submit_c_contiguous_copy(sycl::queue &exec_q,
@@ -133,6 +143,10 @@ sycl::event submit_c_contiguous_copy(sycl::queue &exec_q,
                                      const IndexerT &src_indexer,
                                      const std::vector<sycl::event> &depends)
 {
+    static_assert(vec_sz > 0);
+    static_assert(n_vecs > 0);
+    static_assert(vec_sz * n_vecs < (std::uint32_t(1) << 8));
+
     constexpr std::size_t preferred_lws = 256;
 
     const auto &kernel_id = sycl::get_kernel_id<KernelName>();
@@ -150,9 +164,11 @@ sycl::event submit_c_contiguous_copy(sycl::queue &exec_q,
     const std::size_t lws =
         ((preferred_lws + max_sg_size - 1) / max_sg_size) * max_sg_size;
 
-    constexpr std::uint32_t nelems_per_wi = n_vecs * vec_sz;
-    size_t n_groups =
-        (nelems + nelems_per_wi * lws - 1) / (nelems_per_wi * lws);
+    constexpr std::uint8_t nelems_per_wi = n_vecs * vec_sz;
+
+    const size_t nelems_per_group = nelems_per_wi * lws;
+    const size_t n_groups =
+        (nelems + nelems_per_group - 1) / (nelems_per_group);
 
     sycl::event copy_ev = exec_q.submit([&](sycl::handler &cgh) {
         cgh.depends_on(depends);
@@ -171,8 +187,8 @@ sycl::event submit_c_contiguous_copy(sycl::queue &exec_q,
 template <...>
 class as_contig_krn;
@@ -194,8 +210,8 @@ as_c_contiguous_array_generic_impl(sycl::queue &exec_q,
     using IndexerT = dpctl::tensor::offset_utils::StridedIndexer;
     const IndexerT src_indexer(nd, ssize_t(0), shape_and_strides);
 
-    constexpr std::uint32_t n_vecs = 2;
-    constexpr std::uint32_t vec_sz = 4;
+    constexpr std::uint32_t vec_sz = 4u;
+    constexpr std::uint32_t n_vecs = 2u;
 
     using dpctl::tensor::kernels::alignment_utils::
         disabled_sg_loadstore_wrapper_krn;
@@ -207,7 +223,7 @@ as_c_contiguous_array_generic_impl(sycl::queue &exec_q,
         constexpr bool enable_sg_load = true;
         using KernelName = as_contig_krn<...>;
-        copy_ev = submit_c_contiguous_copy<...>(
+        copy_ev = submit_c_contiguous_copy<...>(
            exec_q, nelems, src_tp, dst_tp, src_indexer, depends);
     }
@@ -216,7 +232,7 @@ as_c_contiguous_array_generic_impl(sycl::queue &exec_q,
         using InnerKernelName = as_contig_krn<...>;
         using KernelName = disabled_sg_loadstore_wrapper_krn<InnerKernelName>;
-        copy_ev = submit_c_contiguous_copy<...>(
+        copy_ev = submit_c_contiguous_copy<...>(
            exec_q, nelems, src_tp, dst_tp, src_indexer, depends);
    }
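
Note on the index arithmetic in the first branch: the patch replaces
base = (gid / sgSize) * sgSize * elems_per_wi + (gid % sgSize) with
base = (gid / sgSize) * (sgSize * (elems_per_wi - 1)) + gid, relying on the identity
gid % sgSize == gid - (gid / sgSize) * sgSize quoted in the added comment, so each
sub-group still starts its block of sgSize * elems_per_wi elements at the same offset.
The small standalone program below (not part of the patch) checks that identity; the
sub-group sizes, the gid range, and vec_sz = 4 / n_vecs = 2 are arbitrary example values.

// Standalone sanity check (not part of the patch) for the identity behind the
// rewritten `base` computation:
//   (gid / sgSize) * sgSize * elems_per_wi + (gid % sgSize)
//     == (gid / sgSize) * (sgSize * (elems_per_wi - 1)) + gid
// which holds because gid % sgSize == gid - (gid / sgSize) * sgSize.
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <initializer_list>

int main()
{
    // Example parameters: vec_sz = 4, n_vecs = 2, as in the kernel defaults.
    constexpr std::uint8_t elems_per_wi = 4 * 2;

    for (std::uint16_t sgSize : {8, 16, 32}) {
        for (std::size_t gid = 0; gid < 10000; ++gid) {
            // Formula removed by the patch.
            const std::size_t old_base =
                (gid / sgSize) * sgSize * elems_per_wi + (gid % sgSize);

            // Formula added by the patch.
            const std::size_t elems_per_sg = sgSize * (elems_per_wi - 1);
            const std::size_t new_base = (gid / sgSize) * elems_per_sg + gid;

            assert(old_base == new_base);
        }
    }
    return 0;
}

The refactored form avoids an explicit gid % sgSize: adding gid itself already contributes
(gid / sgSize) * sgSize plus the lane offset, which is why elems_per_sg is scaled by
elems_per_wi - 1 rather than elems_per_wi.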