diff --git a/CHANGELOG.md b/CHANGELOG.md index 5523c684ca..ec4b346c74 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -35,6 +35,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 * Fix additional warnings when generating docs [gh-1861](https://github.com/IntelPython/dpctl/pull/1861) * Add missing include of SYCL header to "math_utils.hpp" [gh-1899](https://github.com/IntelPython/dpctl/pull/1899) * Add support of CV-qualifiers in `is_complex` helper [gh-1900](https://github.com/IntelPython/dpctl/pull/1900) +* Tuning work for elementwise functions with modest performance gains (under 10%) [gh-1889](https://github.com/IntelPython/dpctl/pull/1889) ## [0.18.1] - Oct. 11, 2024 diff --git a/dpctl/tensor/libtensor/include/kernels/alignment.hpp b/dpctl/tensor/libtensor/include/kernels/alignment.hpp index ff4541af4d..9ec14dd027 100644 --- a/dpctl/tensor/libtensor/include/kernels/alignment.hpp +++ b/dpctl/tensor/libtensor/include/kernels/alignment.hpp @@ -30,7 +30,7 @@ namespace kernels namespace alignment_utils { -static constexpr size_t required_alignment = 64; +static constexpr size_t required_alignment = 64UL; template bool is_aligned(Ptr p) { diff --git a/dpctl/tensor/libtensor/include/kernels/clip.hpp b/dpctl/tensor/libtensor/include/kernels/clip.hpp index 7b422c1281..66bedfd1cd 100644 --- a/dpctl/tensor/libtensor/include/kernels/clip.hpp +++ b/dpctl/tensor/libtensor/include/kernels/clip.hpp @@ -33,6 +33,7 @@ #include "kernels/alignment.hpp" #include "utils/math_utils.hpp" #include "utils/offset_utils.hpp" +#include "utils/sycl_utils.hpp" #include "utils/type_utils.hpp" namespace dpctl @@ -51,6 +52,9 @@ using dpctl::tensor::kernels::alignment_utils:: using dpctl::tensor::kernels::alignment_utils::is_aligned; using dpctl::tensor::kernels::alignment_utils::required_alignment; +using dpctl::tensor::sycl_utils::sub_group_load; +using dpctl::tensor::sycl_utils::sub_group_store; + template T clip(const T &x, const T &min, const T &max) { 
using dpctl::tensor::type_utils::is_complex; @@ -75,8 +79,8 @@ template T clip(const T &x, const T &min, const T &max) } template class ClipContigFunctor { @@ -100,37 +104,36 @@ class ClipContigFunctor void operator()(sycl::nd_item<1> ndit) const { + constexpr std::uint8_t nelems_per_wi = n_vecs * vec_sz; + using dpctl::tensor::type_utils::is_complex; if constexpr (is_complex::value || !enable_sg_loadstore) { - std::uint8_t sgSize = ndit.get_sub_group().get_local_range()[0]; - size_t base = ndit.get_global_linear_id(); - - base = (base / sgSize) * sgSize * n_vecs * vec_sz + (base % sgSize); - for (size_t offset = base; - offset < std::min(nelems, base + sgSize * (n_vecs * vec_sz)); - offset += sgSize) - { + const std::uint16_t sgSize = + ndit.get_sub_group().get_local_range()[0]; + const size_t gid = ndit.get_global_linear_id(); + const uint16_t nelems_per_sg = sgSize * nelems_per_wi; + + const size_t start = + (gid / sgSize) * (nelems_per_sg - sgSize) + gid; + const size_t end = std::min(nelems, start + nelems_per_sg); + + for (size_t offset = start; offset < end; offset += sgSize) { dst_p[offset] = clip(x_p[offset], min_p[offset], max_p[offset]); } } else { auto sg = ndit.get_sub_group(); - std::uint8_t sgSize = sg.get_local_range()[0]; - std::uint8_t max_sgSize = sg.get_max_local_range()[0]; - size_t base = n_vecs * vec_sz * - (ndit.get_group(0) * ndit.get_local_range(0) + - sg.get_group_id()[0] * max_sgSize); - - if (base + n_vecs * vec_sz * sgSize < nelems && - sgSize == max_sgSize) - { - sycl::vec x_vec; - sycl::vec min_vec; - sycl::vec max_vec; + const std::uint16_t sgSize = sg.get_max_local_range()[0]; + + const size_t base = + nelems_per_wi * (ndit.get_group(0) * ndit.get_local_range(0) + + sg.get_group_id()[0] * sgSize); + + if (base + nelems_per_wi * sgSize < nelems) { sycl::vec dst_vec; #pragma unroll for (std::uint8_t it = 0; it < n_vecs * vec_sz; it += vec_sz) { - auto idx = base + it * sgSize; + const size_t idx = base + it * sgSize; auto x_multi_ptr 
= sycl::address_space_cast< sycl::access::address_space::global_space, sycl::access::decorated::yes>(&x_p[idx]); @@ -144,21 +147,23 @@ class ClipContigFunctor sycl::access::address_space::global_space, sycl::access::decorated::yes>(&dst_p[idx]); - x_vec = sg.load(x_multi_ptr); - min_vec = sg.load(min_multi_ptr); - max_vec = sg.load(max_multi_ptr); + const sycl::vec x_vec = + sub_group_load(sg, x_multi_ptr); + const sycl::vec min_vec = + sub_group_load(sg, min_multi_ptr); + const sycl::vec max_vec = + sub_group_load(sg, max_multi_ptr); #pragma unroll for (std::uint8_t vec_id = 0; vec_id < vec_sz; ++vec_id) { dst_vec[vec_id] = clip(x_vec[vec_id], min_vec[vec_id], max_vec[vec_id]); } - sg.store(dst_multi_ptr, dst_vec); + sub_group_store(sg, dst_vec, dst_multi_ptr); } } else { - for (size_t k = base + sg.get_local_id()[0]; k < nelems; - k += sgSize) - { + const size_t lane_id = sg.get_local_id()[0]; + for (size_t k = base + lane_id; k < nelems; k += sgSize) { dst_p[k] = clip(x_p[k], min_p[k], max_p[k]); } } @@ -195,8 +200,8 @@ sycl::event clip_contig_impl(sycl::queue &q, cgh.depends_on(depends); size_t lws = 64; - constexpr unsigned int vec_sz = 4; - constexpr unsigned int n_vecs = 2; + constexpr std::uint8_t vec_sz = 4; + constexpr std::uint8_t n_vecs = 2; const size_t n_groups = ((nelems + lws * n_vecs * vec_sz - 1) / (lws * n_vecs * vec_sz)); const auto gws_range = sycl::range<1>(n_groups * lws); diff --git a/dpctl/tensor/libtensor/include/kernels/copy_and_cast.hpp b/dpctl/tensor/libtensor/include/kernels/copy_and_cast.hpp index f48a5a287e..a4e7fceca1 100644 --- a/dpctl/tensor/libtensor/include/kernels/copy_and_cast.hpp +++ b/dpctl/tensor/libtensor/include/kernels/copy_and_cast.hpp @@ -31,6 +31,7 @@ #include "dpctl_tensor_types.hpp" #include "kernels/alignment.hpp" #include "utils/offset_utils.hpp" +#include "utils/sycl_utils.hpp" #include "utils/type_utils.hpp" namespace dpctl @@ -49,13 +50,16 @@ using dpctl::tensor::kernels::alignment_utils:: using 
dpctl::tensor::kernels::alignment_utils::is_aligned; using dpctl::tensor::kernels::alignment_utils::required_alignment; +using dpctl::tensor::sycl_utils::sub_group_load; +using dpctl::tensor::sycl_utils::sub_group_store; + template class copy_cast_generic_kernel; template + std::uint8_t vec_sz, + std::uint8_t n_vecs> class copy_cast_contig_kernel; template @@ -207,8 +211,8 @@ template struct CopyAndCastGenericFactory template class ContigCopyFunctor { @@ -227,58 +231,55 @@ class ContigCopyFunctor { CastFnT fn{}; + constexpr std::uint8_t elems_per_wi = n_vecs * vec_sz; + using dpctl::tensor::type_utils::is_complex; if constexpr (!enable_sg_loadstore || is_complex::value || is_complex::value) { - std::uint8_t sgSize = ndit.get_sub_group().get_local_range()[0]; - size_t base = ndit.get_global_linear_id(); - - base = (base / sgSize) * sgSize * n_vecs * vec_sz + (base % sgSize); - for (size_t offset = base; - offset < std::min(nelems, base + sgSize * (n_vecs * vec_sz)); - offset += sgSize) - { + std::uint16_t sgSize = ndit.get_sub_group().get_local_range()[0]; + const size_t gid = ndit.get_global_linear_id(); + + // start = (gid / sgSize) * elems_per_sg + (gid % sgSize) + const std::uint16_t elems_per_sg = sgSize * elems_per_wi; + const size_t start = (gid / sgSize) * (elems_per_sg - sgSize) + gid; + const size_t end = std::min(nelems, start + elems_per_sg); + for (size_t offset = start; offset < end; offset += sgSize) { dst_p[offset] = fn(src_p[offset]); } } else { auto sg = ndit.get_sub_group(); - std::uint8_t sgSize = sg.get_local_range()[0]; - std::uint8_t max_sgSize = sg.get_max_local_range()[0]; - size_t base = n_vecs * vec_sz * - (ndit.get_group(0) * ndit.get_local_range(0) + - sg.get_group_id()[0] * max_sgSize); - - if (base + n_vecs * vec_sz * sgSize < nelems && - sgSize == max_sgSize) - { - sycl::vec src_vec; + const std::uint16_t sgSize = sg.get_max_local_range()[0]; + const size_t base = + elems_per_wi * (ndit.get_group(0) * ndit.get_local_range(0) + + 
sg.get_group_id()[0] * sgSize); + + if (base + elems_per_wi * sgSize < nelems) { sycl::vec dst_vec; #pragma unroll for (std::uint8_t it = 0; it < n_vecs * vec_sz; it += vec_sz) { + const size_t offset = base + it * sgSize; auto src_multi_ptr = sycl::address_space_cast< sycl::access::address_space::global_space, - sycl::access::decorated::yes>( - &src_p[base + it * sgSize]); + sycl::access::decorated::yes>(&src_p[offset]); auto dst_multi_ptr = sycl::address_space_cast< sycl::access::address_space::global_space, - sycl::access::decorated::yes>( - &dst_p[base + it * sgSize]); + sycl::access::decorated::yes>(&dst_p[offset]); - src_vec = sg.load(src_multi_ptr); + const sycl::vec src_vec = + sub_group_load(sg, src_multi_ptr); #pragma unroll for (std::uint8_t k = 0; k < vec_sz; k++) { dst_vec[k] = fn(src_vec[k]); } - sg.store(dst_multi_ptr, dst_vec); + sub_group_store(sg, dst_vec, dst_multi_ptr); } } else { - for (size_t k = base + sg.get_local_id()[0]; k < nelems; - k += sgSize) - { + const size_t start = base + sg.get_local_id()[0]; + for (size_t k = start; k < nelems; k += sgSize) { dst_p[k] = fn(src_p[k]); } } @@ -332,8 +333,8 @@ sycl::event copy_and_cast_contig_impl(sycl::queue &q, dstTy *dst_tp = reinterpret_cast(dst_cp); size_t lws = 64; - constexpr unsigned int vec_sz = 4; - constexpr unsigned int n_vecs = 2; + constexpr std::uint32_t vec_sz = 4; + constexpr std::uint32_t n_vecs = 2; const size_t n_groups = ((nelems + lws * n_vecs * vec_sz - 1) / (lws * n_vecs * vec_sz)); const auto gws_range = sycl::range<1>(n_groups * lws); diff --git a/dpctl/tensor/libtensor/include/kernels/copy_as_contiguous.hpp b/dpctl/tensor/libtensor/include/kernels/copy_as_contiguous.hpp index c71e487572..1a44946cc4 100644 --- a/dpctl/tensor/libtensor/include/kernels/copy_as_contiguous.hpp +++ b/dpctl/tensor/libtensor/include/kernels/copy_as_contiguous.hpp @@ -31,6 +31,7 @@ #include "dpctl_tensor_types.hpp" #include "kernels/alignment.hpp" #include "utils/offset_utils.hpp" +#include 
"utils/sycl_utils.hpp" #include "utils/type_utils.hpp" namespace dpctl @@ -42,10 +43,12 @@ namespace kernels namespace copy_as_contig { +using dpctl::tensor::sycl_utils::sub_group_store; + template class CopyAsCContigFunctor { @@ -68,25 +71,23 @@ class CopyAsCContigFunctor { static_assert(vec_sz > 0); static_assert(n_vecs > 0); - static_assert(vec_sz * n_vecs < (std::uint32_t(1) << 8)); - constexpr std::uint8_t elems_per_wi = - static_cast(vec_sz * n_vecs); + constexpr std::uint8_t elems_per_wi = vec_sz * n_vecs; using dpctl::tensor::type_utils::is_complex; if constexpr (!enable_sg_loadstore || is_complex::value) { const std::uint16_t sgSize = - ndit.get_sub_group().get_local_range()[0]; + ndit.get_sub_group().get_max_local_range()[0]; const std::size_t gid = ndit.get_global_linear_id(); - // base = (gid / sgSize) * sgSize * elems_per_wi + (gid % sgSize) + // start = (gid / sgSize) * sgSize * elems_per_wi + (gid % sgSize) // gid % sgSize == gid - (gid / sgSize) * sgSize - const std::size_t elems_per_sg = sgSize * (elems_per_wi - 1); - const std::size_t base = (gid / sgSize) * elems_per_sg + gid; - const std::size_t offset_max = - std::min(nelems, base + sgSize * elems_per_wi); + const std::uint16_t elems_per_sg = sgSize * elems_per_wi; + const std::size_t start = + (gid / sgSize) * (elems_per_sg - sgSize) + gid; + const std::size_t end = std::min(nelems, start + elems_per_sg); - for (size_t offset = base; offset < offset_max; offset += sgSize) { + for (size_t offset = start; offset < end; offset += sgSize) { auto src_offset = src_indexer(offset); dst_p[offset] = src_p[src_offset]; } @@ -97,25 +98,26 @@ class CopyAsCContigFunctor const size_t base = elems_per_wi * (ndit.get_group(0) * ndit.get_local_range(0) + sg.get_group_id()[0] * sgSize); + const std::uint16_t elems_per_sg = elems_per_wi * sgSize; - if (base + elems_per_wi * sgSize < nelems) { - sycl::vec dst_vec; - + if (base + elems_per_sg < nelems) { #pragma unroll for (std::uint8_t it = 0; it < elems_per_wi; 
it += vec_sz) { + // it == vec_id * vec_sz, for 0 <= vec_id < n_vecs const size_t block_start_id = base + it * sgSize; auto dst_multi_ptr = sycl::address_space_cast< sycl::access::address_space::global_space, sycl::access::decorated::yes>(&dst_p[block_start_id]); const size_t elem_id0 = block_start_id + sg.get_local_id(); + sycl::vec dst_vec; #pragma unroll - for (std::uint8_t k = 0; k < vec_sz; k++) { + for (std::uint8_t k = 0; k < vec_sz; ++k) { const size_t elem_id = elem_id0 + k * sgSize; const ssize_t src_offset = src_indexer(elem_id); dst_vec[k] = src_p[src_offset]; } - sg.store(dst_multi_ptr, dst_vec); + sub_group_store(sg, dst_vec, dst_multi_ptr); } } else { @@ -132,8 +134,8 @@ class CopyAsCContigFunctor template sycl::event submit_c_contiguous_copy(sycl::queue &exec_q, @@ -145,7 +147,6 @@ sycl::event submit_c_contiguous_copy(sycl::queue &exec_q, { static_assert(vec_sz > 0); static_assert(n_vecs > 0); - static_assert(vec_sz * n_vecs < (std::uint32_t(1) << 8)); constexpr std::size_t preferred_lws = 256; @@ -187,8 +188,8 @@ sycl::event submit_c_contiguous_copy(sycl::queue &exec_q, template class as_contig_krn; @@ -210,8 +211,8 @@ as_c_contiguous_array_generic_impl(sycl::queue &exec_q, using IndexerT = dpctl::tensor::offset_utils::StridedIndexer; const IndexerT src_indexer(nd, ssize_t(0), shape_and_strides); - constexpr std::uint32_t vec_sz = 4u; - constexpr std::uint32_t n_vecs = 2u; + constexpr std::uint8_t vec_sz = 4u; + constexpr std::uint8_t n_vecs = 2u; using dpctl::tensor::kernels::alignment_utils:: disabled_sg_loadstore_wrapper_krn; @@ -256,8 +257,8 @@ template struct AsCContigFactory template + std::uint16_t tile_size, + std::uint16_t n_lines> class as_contig_batch_of_square_matrices_krn; namespace detail @@ -283,14 +284,14 @@ sycl::event as_c_contiguous_batch_of_square_matrices_impl( const T *src_tp = reinterpret_cast(src_p); T *dst_tp = reinterpret_cast(dst_p); - constexpr std::uint32_t private_tile_size = 4; - constexpr std::uint32_t n_lines = 2; - 
constexpr std::uint32_t block_size = + constexpr std::uint16_t private_tile_size = 4; + constexpr std::uint16_t n_lines = 2; + constexpr std::uint16_t block_size = n_lines * private_tile_size * private_tile_size; - constexpr std::uint32_t lws0 = block_size; - constexpr std::uint32_t lws1 = n_lines; - constexpr std::uint32_t nelems_per_wi = (block_size / lws1); + constexpr std::uint16_t lws0 = block_size; + constexpr std::uint16_t lws1 = n_lines; + constexpr std::uint16_t nelems_per_wi = (block_size / lws1); static_assert(nelems_per_wi * lws1 == block_size); static_assert(nelems_per_wi == private_tile_size * private_tile_size); @@ -377,40 +378,41 @@ sycl::event as_c_contiguous_batch_of_square_matrices_impl( std::array private_block_01 = {T(0)}; std::array private_block_10 = {T(0)}; - // 0 <= lid_lin < lws0 * lws1 == (block_size * block_size / - // nelems_per_wi) == (block_size/private_tile_size)**2 - constexpr std::uint32_t n_private_tiles_per_axis = + // 0 <= lid_lin < lws0 * lws1 == + // (block_size * block_size / nelems_per_wi) == + // (block_size/private_tile_size)**2 + constexpr std::uint16_t n_private_tiles_per_axis = block_size / private_tile_size; - const std::uint32_t local_tile_id0 = + const std::uint16_t local_tile_id0 = lid_lin / n_private_tiles_per_axis; - const std::uint32_t local_tile_id1 = + const std::uint16_t local_tile_id1 = lid_lin - local_tile_id0 * n_private_tiles_per_axis; if (local_tile_id0 <= local_tile_id1) { - for (std::uint32_t pr_i0 = 0; pr_i0 < private_tile_size; + for (std::uint16_t pr_i0 = 0; pr_i0 < private_tile_size; ++pr_i0) { - for (std::uint32_t pr_i1 = 0; pr_i1 < private_tile_size; + for (std::uint16_t pr_i1 = 0; pr_i1 < private_tile_size; ++pr_i1) { - const std::uint32_t t0_offset = + const std::uint16_t t0_offset = local_tile_id0 * private_tile_size; - const std::uint32_t t1_offset = + const std::uint16_t t1_offset = local_tile_id1 * private_tile_size; - const std::uint32_t pr_offset = + const std::uint16_t pr_offset = pr_i1 * 
private_tile_size + pr_i0; - const std::uint32_t rel_offset = + const std::uint16_t rel_offset = pr_i0 + pr_i1 * block_size; // read (local_tile_id0, local_tile_id1) - const std::uint32_t local_01_offset = + const std::uint16_t local_01_offset = (t0_offset + t1_offset * block_size) + rel_offset; private_block_01[pr_offset] = local_block[local_01_offset]; // read (local_tile_id1, local_tile_id0) - const std::uint32_t local_10_offset = + const std::uint16_t local_10_offset = (t1_offset + t0_offset * block_size) + rel_offset; private_block_10[pr_offset] = local_block[local_10_offset]; @@ -422,20 +424,20 @@ sycl::event as_c_contiguous_batch_of_square_matrices_impl( sycl::memory_scope::work_group); if (local_tile_id0 <= local_tile_id1) { - for (std::uint32_t pr_i0 = 0; pr_i0 < private_tile_size; + for (std::uint16_t pr_i0 = 0; pr_i0 < private_tile_size; ++pr_i0) { - for (std::uint32_t pr_i1 = 0; pr_i1 < private_tile_size; + for (std::uint16_t pr_i1 = 0; pr_i1 < private_tile_size; ++pr_i1) { - const std::uint32_t t0_offset = + const std::uint16_t t0_offset = local_tile_id0 * private_tile_size; - const std::uint32_t t1_offset = + const std::uint16_t t1_offset = local_tile_id1 * private_tile_size; - const std::uint32_t pr_offset = + const std::uint16_t pr_offset = pr_i0 * private_tile_size + pr_i1; - const std::uint32_t rel_offset = + const std::uint16_t rel_offset = pr_i0 + pr_i1 * block_size; // write back permuted private blocks @@ -444,7 +446,7 @@ sycl::event as_c_contiguous_batch_of_square_matrices_impl( local_block[local_01_offset] = private_block_10[pr_offset]; - const std::uint32_t local_10_offset = + const std::uint16_t local_10_offset = (t1_offset + t0_offset * block_size) + rel_offset; local_block[local_10_offset] = private_block_01[pr_offset]; @@ -461,8 +463,8 @@ sycl::event as_c_contiguous_batch_of_square_matrices_impl( const std::size_t dst_tile_start1 = src_tile_start1; if (local_dim0 == block_size && local_dim1 == block_size) { - const std::uint32_t dst_i0 
= src_i1; - const std::uint32_t dst_i1 = src_i0; + const std::uint16_t dst_i0 = src_i1; + const std::uint16_t dst_i1 = src_i0; const std::size_t dst_gid0 = (dst_tile_start0 + dst_i0); const std::size_t dst_gid1 = (dst_tile_start1 + dst_i1); @@ -471,11 +473,11 @@ sycl::event as_c_contiguous_batch_of_square_matrices_impl( dst_batch_offset + dst_gid0 * dst_stride + dst_gid1 * 1; const std::size_t pr_step_dst = lws1 * dst_stride; - const std::uint32_t _local_offset0 = + const std::uint16_t _local_offset0 = dst_i0 * block_size + dst_i1; - const std::uint32_t _pr_step_local = lws1 * block_size; + const std::uint16_t _pr_step_local = lws1 * block_size; - for (std::uint32_t pr_id = 0; pr_id < nelems_per_wi; ++pr_id) { + for (std::uint16_t pr_id = 0; pr_id < nelems_per_wi; ++pr_id) { if ((dst_gid1 < n) && ((dst_gid0 + pr_id * lws1) < n)) { dst_tp[dst_offset0 + pr_step_dst * pr_id] = local_block[_local_offset0 + @@ -485,24 +487,24 @@ sycl::event as_c_contiguous_batch_of_square_matrices_impl( } else { // map local_linear_id into (local_dim0, local_dim1) - for (std::uint32_t el_id = lid_lin; + for (std::uint16_t el_id = lid_lin; el_id < local_dim0 * local_dim1; el_id += lws0 * lws1) { // 0 <= local_i0 < local_dim0 - const std::uint32_t loc_i0 = el_id / local_dim1; + const std::uint16_t loc_i0 = el_id / local_dim1; // 0 <= local_i1 < local_dim1 - const std::uint32_t loc_i1 = el_id - loc_i0 * local_dim1; + const std::uint16_t loc_i1 = el_id - loc_i0 * local_dim1; - const std::uint32_t dst_i0 = loc_i0; - const std::uint32_t dst_i1 = loc_i1; + const std::uint16_t dst_i0 = loc_i0; + const std::uint16_t dst_i1 = loc_i1; const std::size_t dst_gid0 = (dst_tile_start0 + dst_i0); const std::size_t dst_gid1 = (dst_tile_start1 + dst_i1); const std::size_t dst_offset = dst_batch_offset + dst_gid0 * dst_stride + dst_gid1 * 1; - const std::uint32_t local_offset = + const std::uint16_t local_offset = loc_i0 * block_size + loc_i1; if ((dst_gid1 < n) && (dst_gid0 < n)) { diff --git 
a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/abs.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/abs.hpp index 411040bada..0dd315fc9d 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/abs.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/abs.hpp @@ -32,9 +32,11 @@ #include #include "cabs_impl.hpp" -#include "kernels/elementwise_functions/common.hpp" +#include "vec_size_util.hpp" #include "kernels/dpctl_tensor_types.hpp" +#include "kernels/elementwise_functions/common.hpp" + #include "utils/offset_utils.hpp" #include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" @@ -89,8 +91,8 @@ template struct AbsFunctor template using AbsContigFunctor = elementwise_common::UnaryContigFunctor struct AbsOutputType static constexpr bool is_defined = !std::is_same_v; }; -template +namespace +{ + +namespace vsu_ns = dpctl::tensor::kernels::vec_size_utils; + +using vsu_ns::ContigHyperparameterSetDefault; + +template struct AbsContigHyperparameterSet +{ + using value_type = + typename std::disjunction>; + + constexpr static auto vec_sz = value_type::vec_sz; + constexpr static auto n_vecs = value_type::n_vecs; +}; + +} // namespace + +template class abs_contig_kernel; template @@ -132,9 +152,12 @@ sycl::event abs_contig_impl(sycl::queue &exec_q, char *res_p, const std::vector &depends = {}) { + constexpr std::uint8_t vec_sz = AbsContigHyperparameterSet::vec_sz; + constexpr std::uint8_t n_vec = AbsContigHyperparameterSet::n_vecs; + return elementwise_common::unary_contig_impl< - argTy, AbsOutputType, AbsContigFunctor, abs_contig_kernel>( - exec_q, nelems, arg_p, res_p, depends); + argTy, AbsOutputType, AbsContigFunctor, abs_contig_kernel, vec_sz, + n_vec>(exec_q, nelems, arg_p, res_p, depends); } template struct AbsContigFactory diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/acos.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/acos.hpp index 
a90f4e699f..47c69d5190 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/acos.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/acos.hpp @@ -29,10 +29,12 @@ #include #include -#include "kernels/elementwise_functions/common.hpp" #include "sycl_complex.hpp" +#include "vec_size_util.hpp" #include "kernels/dpctl_tensor_types.hpp" +#include "kernels/elementwise_functions/common.hpp" + #include "utils/offset_utils.hpp" #include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" @@ -128,8 +130,8 @@ template struct AcosFunctor template using AcosContigFunctor = elementwise_common::UnaryContigFunctor struct AcosOutputType static constexpr bool is_defined = !std::is_same_v; }; -template +namespace +{ + +namespace vsu_ns = dpctl::tensor::kernels::vec_size_utils; + +using vsu_ns::ContigHyperparameterSetDefault; + +template struct AcosContigHyperparameterSet +{ + using value_type = + typename std::disjunction>; + + constexpr static auto vec_sz = value_type::vec_sz; + constexpr static auto n_vecs = value_type::n_vecs; +}; + +} // namespace + +template class acos_contig_kernel; template @@ -166,9 +186,12 @@ sycl::event acos_contig_impl(sycl::queue &exec_q, char *res_p, const std::vector &depends = {}) { + constexpr std::uint8_t vec_sz = AcosContigHyperparameterSet::vec_sz; + constexpr std::uint8_t n_vec = AcosContigHyperparameterSet::n_vecs; + return elementwise_common::unary_contig_impl< - argTy, AcosOutputType, AcosContigFunctor, acos_contig_kernel>( - exec_q, nelems, arg_p, res_p, depends); + argTy, AcosOutputType, AcosContigFunctor, acos_contig_kernel, vec_sz, + n_vec>(exec_q, nelems, arg_p, res_p, depends); } template struct AcosContigFactory diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/acosh.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/acosh.hpp index 8af3708427..f199be5a7e 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/acosh.hpp +++ 
b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/acosh.hpp @@ -29,10 +29,12 @@ #include #include -#include "kernels/elementwise_functions/common.hpp" #include "sycl_complex.hpp" +#include "vec_size_util.hpp" #include "kernels/dpctl_tensor_types.hpp" +#include "kernels/elementwise_functions/common.hpp" + #include "utils/offset_utils.hpp" #include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" @@ -155,8 +157,8 @@ template struct AcoshFunctor template using AcoshContigFunctor = elementwise_common::UnaryContigFunctor struct AcoshOutputType static constexpr bool is_defined = !std::is_same_v; }; -template +namespace +{ + +namespace vsu_ns = dpctl::tensor::kernels::vec_size_utils; + +using vsu_ns::ContigHyperparameterSetDefault; +using vsu_ns::UnaryContigHyperparameterSetEntry; + +template struct AcoshContigHyperparameterSet +{ + using value_type = + typename std::disjunction>; + + constexpr static auto vec_sz = value_type::vec_sz; + constexpr static auto n_vecs = value_type::n_vecs; +}; + +} // namespace + +template class acosh_contig_kernel; template @@ -193,9 +214,12 @@ sycl::event acosh_contig_impl(sycl::queue &exec_q, char *res_p, const std::vector &depends = {}) { + constexpr std::uint8_t vec_sz = AcoshContigHyperparameterSet::vec_sz; + constexpr std::uint8_t n_vec = AcoshContigHyperparameterSet::n_vecs; + return elementwise_common::unary_contig_impl< - argTy, AcoshOutputType, AcoshContigFunctor, acosh_contig_kernel>( - exec_q, nelems, arg_p, res_p, depends); + argTy, AcoshOutputType, AcoshContigFunctor, acosh_contig_kernel, vec_sz, + n_vec>(exec_q, nelems, arg_p, res_p, depends); } template struct AcoshContigFactory diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/add.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/add.hpp index c06e98f3e5..69f63b53c0 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/add.hpp +++ 
b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/add.hpp @@ -30,6 +30,8 @@ #include #include "sycl_complex.hpp" +#include "vec_size_util.hpp" + #include "utils/offset_utils.hpp" #include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" @@ -110,8 +112,8 @@ template struct AddFunctor template using AddContigFunctor = elementwise_common::BinaryContigFunctor struct AddOutputType static constexpr bool is_defined = !std::is_same_v; }; +namespace +{ + +namespace vsu_ns = dpctl::tensor::kernels::vec_size_utils; + +using vsu_ns::BinaryContigHyperparameterSetEntry; +using vsu_ns::ContigHyperparameterSetDefault; + +template struct AddContigHyperparameterSet +{ + using value_type = typename std::disjunction< + BinaryContigHyperparameterSetEntry, + BinaryContigHyperparameterSetEntry, + BinaryContigHyperparameterSetEntry, + BinaryContigHyperparameterSetEntry, + BinaryContigHyperparameterSetEntry, + BinaryContigHyperparameterSetEntry, + ContigHyperparameterSetDefault<4u, 2u>>; + + constexpr static auto vec_sz = value_type::vec_sz; + constexpr static auto n_vecs = value_type::n_vecs; +}; + +} // end of anonymous namespace + template + std::uint8_t vec_sz, + std::uint8_t n_vecs> class add_contig_kernel; template @@ -214,10 +271,13 @@ sycl::event add_contig_impl(sycl::queue &exec_q, ssize_t res_offset, const std::vector &depends = {}) { + constexpr auto vec_sz = AddContigHyperparameterSet::vec_sz; + constexpr auto n_vecs = AddContigHyperparameterSet::n_vecs; + return elementwise_common::binary_contig_impl< - argTy1, argTy2, AddOutputType, AddContigFunctor, add_contig_kernel>( - exec_q, nelems, arg1_p, arg1_offset, arg2_p, arg2_offset, res_p, - res_offset, depends); + argTy1, argTy2, AddOutputType, AddContigFunctor, add_contig_kernel, + vec_sz, n_vecs>(exec_q, nelems, arg1_p, arg1_offset, arg2_p, + arg2_offset, res_p, res_offset, depends); } template struct AddContigFactory @@ -410,8 +470,8 @@ template struct AddInplaceFunctor template using 
AddInplaceContigFunctor = elementwise_common::BinaryInplaceContigFunctor< argT, @@ -431,8 +491,8 @@ using AddInplaceStridedFunctor = template + std::uint8_t vec_sz, + std::uint8_t n_vecs> class add_inplace_contig_kernel; /* @brief Types supported by in-place add */ @@ -489,9 +549,13 @@ add_inplace_contig_impl(sycl::queue &exec_q, ssize_t res_offset, const std::vector &depends = {}) { + constexpr auto vec_sz = AddContigHyperparameterSet::vec_sz; + constexpr auto n_vecs = AddContigHyperparameterSet::n_vecs; + return elementwise_common::binary_inplace_contig_impl< - argTy, resTy, AddInplaceContigFunctor, add_inplace_contig_kernel>( - exec_q, nelems, arg_p, arg_offset, res_p, res_offset, depends); + argTy, resTy, AddInplaceContigFunctor, add_inplace_contig_kernel, + vec_sz, n_vecs>(exec_q, nelems, arg_p, arg_offset, res_p, res_offset, + depends); } template struct AddInplaceContigFactory diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/angle.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/angle.hpp index 034b71438f..670b9c10f8 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/angle.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/angle.hpp @@ -30,10 +30,12 @@ #include #include -#include "kernels/elementwise_functions/common.hpp" #include "sycl_complex.hpp" +#include "vec_size_util.hpp" #include "kernels/dpctl_tensor_types.hpp" +#include "kernels/elementwise_functions/common.hpp" + #include "utils/offset_utils.hpp" #include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" @@ -74,8 +76,8 @@ template struct AngleFunctor template using AngleContigFunctor = elementwise_common::UnaryContigFunctor struct AngleOutputType static constexpr bool is_defined = !std::is_same_v; }; -template +namespace +{ + +namespace vsu_ns = dpctl::tensor::kernels::vec_size_utils; + +using vsu_ns::ContigHyperparameterSetDefault; +using vsu_ns::UnaryContigHyperparameterSetEntry; + +template 
struct AngleContigHyperparameterSet +{ + using value_type = + typename std::disjunction>; + + constexpr static auto vec_sz = value_type::vec_sz; + constexpr static auto n_vecs = value_type::n_vecs; +}; + +} // end of anonymous namespace + +template class angle_contig_kernel; template @@ -109,9 +130,12 @@ sycl::event angle_contig_impl(sycl::queue &exec_q, char *res_p, const std::vector &depends = {}) { + constexpr std::uint8_t vec_sz = AngleContigHyperparameterSet::vec_sz; + constexpr std::uint8_t n_vec = AngleContigHyperparameterSet::n_vecs; + return elementwise_common::unary_contig_impl< - argTy, AngleOutputType, AngleContigFunctor, angle_contig_kernel>( - exec_q, nelems, arg_p, res_p, depends); + argTy, AngleOutputType, AngleContigFunctor, angle_contig_kernel, vec_sz, + n_vec>(exec_q, nelems, arg_p, res_p, depends); } template struct AngleContigFactory diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/asin.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/asin.hpp index 35c381aa84..db7ec5723e 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/asin.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/asin.hpp @@ -29,10 +29,12 @@ #include #include -#include "kernels/elementwise_functions/common.hpp" #include "sycl_complex.hpp" +#include "vec_size_util.hpp" #include "kernels/dpctl_tensor_types.hpp" +#include "kernels/elementwise_functions/common.hpp" + #include "utils/offset_utils.hpp" #include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" @@ -148,8 +150,8 @@ template struct AsinFunctor template using AsinContigFunctor = elementwise_common::UnaryContigFunctor struct AsinOutputType static constexpr bool is_defined = !std::is_same_v; }; -template +namespace +{ + +namespace vsu_ns = dpctl::tensor::kernels::vec_size_utils; + +using vsu_ns::ContigHyperparameterSetDefault; +using vsu_ns::UnaryContigHyperparameterSetEntry; + +template struct AsinContigHyperparameterSet +{ 
+ using value_type = + typename std::disjunction>; + + constexpr static auto vec_sz = value_type::vec_sz; + constexpr static auto n_vecs = value_type::n_vecs; +}; + +} // end of anonymous namespace + +template class asin_contig_kernel; template @@ -186,9 +207,12 @@ sycl::event asin_contig_impl(sycl::queue &exec_q, char *res_p, const std::vector &depends = {}) { + constexpr std::uint8_t vec_sz = AsinContigHyperparameterSet::vec_sz; + constexpr std::uint8_t n_vec = AsinContigHyperparameterSet::n_vecs; + return elementwise_common::unary_contig_impl< - argTy, AsinOutputType, AsinContigFunctor, asin_contig_kernel>( - exec_q, nelems, arg_p, res_p, depends); + argTy, AsinOutputType, AsinContigFunctor, asin_contig_kernel, vec_sz, + n_vec>(exec_q, nelems, arg_p, res_p, depends); } template struct AsinContigFactory diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/asinh.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/asinh.hpp index 7373dc39d5..9b58d7ad19 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/asinh.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/asinh.hpp @@ -29,10 +29,12 @@ #include #include -#include "kernels/elementwise_functions/common.hpp" #include "sycl_complex.hpp" +#include "vec_size_util.hpp" #include "kernels/dpctl_tensor_types.hpp" +#include "kernels/elementwise_functions/common.hpp" + #include "utils/offset_utils.hpp" #include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" @@ -131,8 +133,8 @@ template struct AsinhFunctor template using AsinhContigFunctor = elementwise_common::UnaryContigFunctor struct AsinhOutputType static constexpr bool is_defined = !std::is_same_v; }; -template +namespace +{ + +namespace vsu_ns = dpctl::tensor::kernels::vec_size_utils; + +using vsu_ns::ContigHyperparameterSetDefault; +using vsu_ns::UnaryContigHyperparameterSetEntry; + +template struct AsinhContigHyperparameterSet +{ + using value_type = + typename 
std::disjunction>; + + constexpr static auto vec_sz = value_type::vec_sz; + constexpr static auto n_vecs = value_type::n_vecs; +}; + +} // end of anonymous namespace + +template class asinh_contig_kernel; template @@ -169,9 +190,12 @@ sycl::event asinh_contig_impl(sycl::queue &exec_q, char *res_p, const std::vector &depends = {}) { + constexpr std::uint8_t vec_sz = AsinhContigHyperparameterSet::vec_sz; + constexpr std::uint8_t n_vec = AsinhContigHyperparameterSet::n_vecs; + return elementwise_common::unary_contig_impl< - argTy, AsinhOutputType, AsinhContigFunctor, asinh_contig_kernel>( - exec_q, nelems, arg_p, res_p, depends); + argTy, AsinhOutputType, AsinhContigFunctor, asinh_contig_kernel, vec_sz, + n_vec>(exec_q, nelems, arg_p, res_p, depends); } template struct AsinhContigFactory diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/atan.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/atan.hpp index fbba3fc436..3f96f95526 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/atan.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/atan.hpp @@ -30,10 +30,12 @@ #include #include -#include "kernels/elementwise_functions/common.hpp" #include "sycl_complex.hpp" +#include "vec_size_util.hpp" #include "kernels/dpctl_tensor_types.hpp" +#include "kernels/elementwise_functions/common.hpp" + #include "utils/offset_utils.hpp" #include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" @@ -49,6 +51,9 @@ namespace atan namespace td_ns = dpctl::tensor::type_dispatch; +using dpctl::tensor::kernels::vec_size_utils::ContigHyperparameterSetDefault; +using dpctl::tensor::kernels::vec_size_utils::UnaryContigHyperparameterSetEntry; + using dpctl::tensor::type_utils::is_complex; template struct AtanFunctor @@ -138,8 +143,8 @@ template struct AtanFunctor template using AtanContigFunctor = elementwise_common::UnaryContigFunctor struct AtanOutputType static constexpr bool is_defined = 
!std::is_same_v; }; -template +namespace +{ + +namespace vsu_ns = dpctl::tensor::kernels::vec_size_utils; + +using vsu_ns::ContigHyperparameterSetDefault; +using vsu_ns::UnaryContigHyperparameterSetEntry; + +template struct AtanContigHyperparameterSet +{ + using value_type = + typename std::disjunction>; + + constexpr static auto vec_sz = value_type::vec_sz; + constexpr static auto n_vecs = value_type::n_vecs; +}; + +} // end of anonymous namespace + +template class atan_contig_kernel; template @@ -176,9 +200,12 @@ sycl::event atan_contig_impl(sycl::queue &exec_q, char *res_p, const std::vector &depends = {}) { + constexpr std::uint8_t vec_sz = AtanContigHyperparameterSet::vec_sz; + constexpr std::uint8_t n_vec = AtanContigHyperparameterSet::n_vecs; + return elementwise_common::unary_contig_impl< - argTy, AtanOutputType, AtanContigFunctor, atan_contig_kernel>( - exec_q, nelems, arg_p, res_p, depends); + argTy, AtanOutputType, AtanContigFunctor, atan_contig_kernel, vec_sz, + n_vec>(exec_q, nelems, arg_p, res_p, depends); } template struct AtanContigFactory diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/atan2.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/atan2.hpp index 1a694527dd..37bd66fb54 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/atan2.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/atan2.hpp @@ -29,6 +29,8 @@ #include #include +#include "vec_size_util.hpp" + #include "utils/offset_utils.hpp" #include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" @@ -68,8 +70,8 @@ template struct Atan2Functor template using Atan2ContigFunctor = elementwise_common::BinaryContigFunctor struct Atan2OutputType static constexpr bool is_defined = !std::is_same_v; }; +namespace +{ + +namespace vsu_ns = dpctl::tensor::kernels::vec_size_utils; + +using vsu_ns::BinaryContigHyperparameterSetEntry; +using vsu_ns::ContigHyperparameterSetDefault; + +template struct 
Atan2ContigHyperparameterSet +{ + using value_type = + typename std::disjunction>; + + constexpr static auto vec_sz = value_type::vec_sz; + constexpr static auto n_vecs = value_type::n_vecs; +}; + +} // end of anonymous namespace + template + std::uint8_t vec_sz, + std::uint8_t n_vecs> class atan2_contig_kernel; template @@ -121,10 +142,16 @@ sycl::event atan2_contig_impl(sycl::queue &exec_q, ssize_t res_offset, const std::vector &depends = {}) { + constexpr std::uint8_t vec_sz = + Atan2ContigHyperparameterSet::vec_sz; + constexpr std::uint8_t n_vecs = + Atan2ContigHyperparameterSet::n_vecs; + return elementwise_common::binary_contig_impl< argTy1, argTy2, Atan2OutputType, Atan2ContigFunctor, - atan2_contig_kernel>(exec_q, nelems, arg1_p, arg1_offset, arg2_p, - arg2_offset, res_p, res_offset, depends); + atan2_contig_kernel, vec_sz, n_vecs>(exec_q, nelems, arg1_p, + arg1_offset, arg2_p, arg2_offset, + res_p, res_offset, depends); } template struct Atan2ContigFactory diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/atanh.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/atanh.hpp index 340e72b11c..25c15ef614 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/atanh.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/atanh.hpp @@ -30,10 +30,12 @@ #include #include -#include "kernels/elementwise_functions/common.hpp" #include "sycl_complex.hpp" +#include "vec_size_util.hpp" #include "kernels/dpctl_tensor_types.hpp" +#include "kernels/elementwise_functions/common.hpp" + #include "utils/offset_utils.hpp" #include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" @@ -132,8 +134,8 @@ template struct AtanhFunctor template using AtanhContigFunctor = elementwise_common::UnaryContigFunctor struct AtanhOutputType static constexpr bool is_defined = !std::is_same_v; }; -template +namespace +{ + +namespace vsu_ns = dpctl::tensor::kernels::vec_size_utils; + +using 
vsu_ns::ContigHyperparameterSetDefault; +using vsu_ns::UnaryContigHyperparameterSetEntry; + +template struct AtanhContigHyperparameterSet +{ + using value_type = + typename std::disjunction>; + + constexpr static auto vec_sz = value_type::vec_sz; + constexpr static auto n_vecs = value_type::n_vecs; +}; + +} // end of anonymous namespace + +template class atanh_contig_kernel; template @@ -170,9 +191,12 @@ sycl::event atanh_contig_impl(sycl::queue &exec_q, char *res_p, const std::vector &depends = {}) { + constexpr std::uint8_t vec_sz = AtanhContigHyperparameterSet::vec_sz; + constexpr std::uint8_t n_vec = AtanhContigHyperparameterSet::n_vecs; + return elementwise_common::unary_contig_impl< - argTy, AtanhOutputType, AtanhContigFunctor, atanh_contig_kernel>( - exec_q, nelems, arg_p, res_p, depends); + argTy, AtanhOutputType, AtanhContigFunctor, atanh_contig_kernel, vec_sz, + n_vec>(exec_q, nelems, arg_p, res_p, depends); } template struct AtanhContigFactory diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_and.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_and.hpp index da32b17183..45a03c913d 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_and.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_and.hpp @@ -28,6 +28,8 @@ #include #include +#include "vec_size_util.hpp" + #include "utils/offset_utils.hpp" #include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" @@ -91,8 +93,8 @@ struct BitwiseAndFunctor template using BitwiseAndContigFunctor = elementwise_common::BinaryContigFunctor< argT1, @@ -160,11 +162,30 @@ template struct BitwiseAndOutputType static constexpr bool is_defined = !std::is_same_v; }; +namespace +{ + +namespace vsu_ns = dpctl::tensor::kernels::vec_size_utils; + +using vsu_ns::BinaryContigHyperparameterSetEntry; +using vsu_ns::ContigHyperparameterSetDefault; + +template +struct BitwiseAndContigHyperparameterSet +{ + 
using value_type = + typename std::disjunction>; + + constexpr static auto vec_sz = value_type::vec_sz; + constexpr static auto n_vecs = value_type::n_vecs; +}; +} // namespace + template + std::uint8_t vec_sz, + std::uint8_t n_vecs> class bitwise_and_contig_kernel; template @@ -179,10 +200,16 @@ bitwise_and_contig_impl(sycl::queue &exec_q, ssize_t res_offset, const std::vector &depends = {}) { + constexpr std::uint8_t vec_sz = + BitwiseAndContigHyperparameterSet::vec_sz; + constexpr std::uint8_t n_vec = + BitwiseAndContigHyperparameterSet::n_vecs; + return elementwise_common::binary_contig_impl< argTy1, argTy2, BitwiseAndOutputType, BitwiseAndContigFunctor, - bitwise_and_contig_kernel>(exec_q, nelems, arg1_p, arg1_offset, arg2_p, - arg2_offset, res_p, res_offset, depends); + bitwise_and_contig_kernel, vec_sz, n_vec>( + exec_q, nelems, arg1_p, arg1_offset, arg2_p, arg2_offset, res_p, + res_offset, depends); } template struct BitwiseAndContigFactory @@ -290,8 +317,8 @@ template struct BitwiseAndInplaceFunctor template using BitwiseAndInplaceContigFunctor = elementwise_common::BinaryInplaceContigFunctor< @@ -312,8 +339,8 @@ using BitwiseAndInplaceStridedFunctor = template + std::uint8_t vec_sz, + std::uint8_t n_vecs> class bitwise_and_inplace_contig_kernel; /* @brief Types supported by in-place bitwise AND */ @@ -361,10 +388,15 @@ bitwise_and_inplace_contig_impl(sycl::queue &exec_q, ssize_t res_offset, const std::vector &depends = {}) { + constexpr std::uint8_t vec_sz = + BitwiseAndContigHyperparameterSet::vec_sz; + constexpr std::uint8_t n_vecs = + BitwiseAndContigHyperparameterSet::n_vecs; + return elementwise_common::binary_inplace_contig_impl< argTy, resTy, BitwiseAndInplaceContigFunctor, - bitwise_and_inplace_contig_kernel>(exec_q, nelems, arg_p, arg_offset, - res_p, res_offset, depends); + bitwise_and_inplace_contig_kernel, vec_sz, n_vecs>( + exec_q, nelems, arg_p, arg_offset, res_p, res_offset, depends); } template diff --git 
a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_invert.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_invert.hpp index d6c1bc72db..582da57c29 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_invert.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_invert.hpp @@ -30,6 +30,8 @@ #include #include +#include "vec_size_util.hpp" + #include "utils/offset_utils.hpp" #include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" @@ -80,8 +82,8 @@ template struct BitwiseInvertFunctor template using BitwiseInvertContigFunctor = elementwise_common::UnaryContigFunctor struct BitwiseInvertOutputType static constexpr bool is_defined = !std::is_same_v; }; -template +namespace +{ + +namespace vsu_ns = dpctl::tensor::kernels::vec_size_utils; + +using vsu_ns::ContigHyperparameterSetDefault; +using vsu_ns::UnaryContigHyperparameterSetEntry; + +template struct BitwiseInvertContigHyperparameterSet +{ + using value_type = + typename std::disjunction>; + + constexpr static auto vec_sz = value_type::vec_sz; + constexpr static auto n_vecs = value_type::n_vecs; +}; + +} // end of anonymous namespace + +template class bitwise_invert_contig_kernel; template @@ -126,10 +147,15 @@ bitwise_invert_contig_impl(sycl::queue &exec_q, char *res_p, const std::vector &depends = {}) { - return elementwise_common::unary_contig_impl( - exec_q, nelems, arg_p, res_p, depends); + constexpr std::uint8_t vec_sz = + BitwiseInvertContigHyperparameterSet::vec_sz; + constexpr std::uint8_t n_vec = + BitwiseInvertContigHyperparameterSet::n_vecs; + + return elementwise_common::unary_contig_impl< + argTy, BitwiseInvertOutputType, BitwiseInvertContigFunctor, + bitwise_invert_contig_kernel, vec_sz, n_vec>(exec_q, nelems, arg_p, + res_p, depends); } template struct BitwiseInvertContigFactory diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_left_shift.hpp 
b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_left_shift.hpp index a987c8d604..8cb0dcc9d0 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_left_shift.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_left_shift.hpp @@ -29,6 +29,8 @@ #include #include +#include "vec_size_util.hpp" + #include "utils/offset_utils.hpp" #include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" @@ -100,8 +102,8 @@ struct BitwiseLeftShiftFunctor template using BitwiseLeftShiftContigFunctor = elementwise_common::BinaryContigFunctor< argT1, @@ -169,11 +171,31 @@ template struct BitwiseLeftShiftOutputType static constexpr bool is_defined = !std::is_same_v; }; +namespace +{ + +namespace vsu_ns = dpctl::tensor::kernels::vec_size_utils; + +using vsu_ns::BinaryContigHyperparameterSetEntry; +using vsu_ns::ContigHyperparameterSetDefault; + +template +struct BitwiseLeftShiftContigHyperparameterSet +{ + using value_type = + typename std::disjunction>; + + constexpr static auto vec_sz = value_type::vec_sz; + constexpr static auto n_vecs = value_type::n_vecs; +}; + +} // end of anonymous namespace + template + std::uint8_t vec_sz, + std::uint8_t n_vecs> class bitwise_left_shift_contig_kernel; template @@ -188,11 +210,16 @@ bitwise_left_shift_contig_impl(sycl::queue &exec_q, ssize_t res_offset, const std::vector &depends = {}) { + constexpr std::uint8_t vec_sz = + BitwiseLeftShiftContigHyperparameterSet::vec_sz; + constexpr std::uint8_t n_vecs = + BitwiseLeftShiftContigHyperparameterSet::n_vecs; + return elementwise_common::binary_contig_impl< argTy1, argTy2, BitwiseLeftShiftOutputType, - BitwiseLeftShiftContigFunctor, bitwise_left_shift_contig_kernel>( - exec_q, nelems, arg1_p, arg1_offset, arg2_p, arg2_offset, res_p, - res_offset, depends); + BitwiseLeftShiftContigFunctor, bitwise_left_shift_contig_kernel, vec_sz, + n_vecs>(exec_q, nelems, arg1_p, arg1_offset, arg2_p, arg2_offset, res_p, + 
res_offset, depends); } template @@ -304,8 +331,8 @@ template struct BitwiseLeftShiftInplaceFunctor template using BitwiseLeftShiftInplaceContigFunctor = elementwise_common::BinaryInplaceContigFunctor< @@ -326,8 +353,8 @@ using BitwiseLeftShiftInplaceStridedFunctor = template + std::uint8_t vec_sz, + std::uint8_t n_vecs> class bitwise_left_shift_inplace_contig_kernel; /* @brief Types supported by in-place bitwise left shift */ @@ -375,9 +402,14 @@ sycl::event bitwise_left_shift_inplace_contig_impl( ssize_t res_offset, const std::vector &depends = {}) { + constexpr std::uint8_t vec_sz = + BitwiseLeftShiftContigHyperparameterSet::vec_sz; + constexpr std::uint8_t n_vecs = + BitwiseLeftShiftContigHyperparameterSet::n_vecs; + return elementwise_common::binary_inplace_contig_impl< argTy, resTy, BitwiseLeftShiftInplaceContigFunctor, - bitwise_left_shift_inplace_contig_kernel>( + bitwise_left_shift_inplace_contig_kernel, vec_sz, n_vecs>( exec_q, nelems, arg_p, arg_offset, res_p, res_offset, depends); } diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_or.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_or.hpp index 71f3e809d9..e1de5be474 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_or.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_or.hpp @@ -28,6 +28,8 @@ #include #include +#include "vec_size_util.hpp" + #include "utils/offset_utils.hpp" #include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" @@ -90,8 +92,8 @@ template struct BitwiseOrFunctor template using BitwiseOrContigFunctor = elementwise_common::BinaryContigFunctor< argT1, @@ -159,11 +161,31 @@ template struct BitwiseOrOutputType static constexpr bool is_defined = !std::is_same_v; }; +namespace +{ + +namespace vsu_ns = dpctl::tensor::kernels::vec_size_utils; + +using vsu_ns::BinaryContigHyperparameterSetEntry; +using vsu_ns::ContigHyperparameterSetDefault; + +template 
+struct BitwiseOrContigHyperparameterSet +{ + using value_type = + typename std::disjunction>; + + constexpr static auto vec_sz = value_type::vec_sz; + constexpr static auto n_vecs = value_type::n_vecs; +}; + +} // end of anonymous namespace + template + std::uint8_t vec_sz, + std::uint8_t n_vecs> class bitwise_or_contig_kernel; template @@ -177,10 +199,16 @@ sycl::event bitwise_or_contig_impl(sycl::queue &exec_q, ssize_t res_offset, const std::vector &depends = {}) { + constexpr std::uint8_t vec_sz = + BitwiseOrContigHyperparameterSet::vec_sz; + constexpr std::uint8_t n_vecs = + BitwiseOrContigHyperparameterSet::n_vecs; + return elementwise_common::binary_contig_impl< argTy1, argTy2, BitwiseOrOutputType, BitwiseOrContigFunctor, - bitwise_or_contig_kernel>(exec_q, nelems, arg1_p, arg1_offset, arg2_p, - arg2_offset, res_p, res_offset, depends); + bitwise_or_contig_kernel, vec_sz, n_vecs>( + exec_q, nelems, arg1_p, arg1_offset, arg2_p, arg2_offset, res_p, + res_offset, depends); } template struct BitwiseOrContigFactory @@ -286,8 +314,8 @@ template struct BitwiseOrInplaceFunctor template using BitwiseOrInplaceContigFunctor = elementwise_common::BinaryInplaceContigFunctor< @@ -308,8 +336,8 @@ using BitwiseOrInplaceStridedFunctor = template + std::uint8_t vec_sz, + std::uint8_t n_vecs> class bitwise_or_inplace_contig_kernel; /* @brief Types supported by in-place bitwise OR */ @@ -355,10 +383,15 @@ bitwise_or_inplace_contig_impl(sycl::queue &exec_q, ssize_t res_offset, const std::vector &depends = {}) { + constexpr std::uint8_t vec_sz = + BitwiseOrContigHyperparameterSet::vec_sz; + constexpr std::uint8_t n_vecs = + BitwiseOrContigHyperparameterSet::n_vecs; + return elementwise_common::binary_inplace_contig_impl< argTy, resTy, BitwiseOrInplaceContigFunctor, - bitwise_or_inplace_contig_kernel>(exec_q, nelems, arg_p, arg_offset, - res_p, res_offset, depends); + bitwise_or_inplace_contig_kernel, vec_sz, n_vecs>( + exec_q, nelems, arg_p, arg_offset, res_p, res_offset, 
depends); } template diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_right_shift.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_right_shift.hpp index e4dfee2ed6..35d3352c41 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_right_shift.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_right_shift.hpp @@ -29,6 +29,8 @@ #include #include +#include "vec_size_util.hpp" + #include "utils/offset_utils.hpp" #include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" @@ -101,8 +103,8 @@ struct BitwiseRightShiftFunctor template using BitwiseRightShiftContigFunctor = elementwise_common::BinaryContigFunctor< argT1, @@ -171,11 +173,31 @@ template struct BitwiseRightShiftOutputType static constexpr bool is_defined = !std::is_same_v; }; +namespace +{ + +namespace vsu_ns = dpctl::tensor::kernels::vec_size_utils; + +using vsu_ns::BinaryContigHyperparameterSetEntry; +using vsu_ns::ContigHyperparameterSetDefault; + +template +struct BitwiseRightShiftContigHyperparameterSet +{ + using value_type = + typename std::disjunction>; + + constexpr static auto vec_sz = value_type::vec_sz; + constexpr static auto n_vecs = value_type::n_vecs; +}; + +} // end of anonymous namespace + template + std::uint8_t vec_sz, + std::uint8_t n_vecs> class bitwise_right_shift_contig_kernel; template @@ -190,11 +212,16 @@ bitwise_right_shift_contig_impl(sycl::queue &exec_q, ssize_t res_offset, const std::vector &depends = {}) { + constexpr std::uint8_t vec_sz = + BitwiseRightShiftContigHyperparameterSet::vec_sz; + constexpr std::uint8_t n_vecs = + BitwiseRightShiftContigHyperparameterSet::n_vecs; + return elementwise_common::binary_contig_impl< argTy1, argTy2, BitwiseRightShiftOutputType, - BitwiseRightShiftContigFunctor, bitwise_right_shift_contig_kernel>( - exec_q, nelems, arg1_p, arg1_offset, arg2_p, arg2_offset, res_p, - res_offset, depends); + 
BitwiseRightShiftContigFunctor, bitwise_right_shift_contig_kernel, + vec_sz, n_vecs>(exec_q, nelems, arg1_p, arg1_offset, arg2_p, + arg2_offset, res_p, res_offset, depends); } template @@ -308,8 +335,8 @@ template struct BitwiseRightShiftInplaceFunctor template using BitwiseRightShiftInplaceContigFunctor = elementwise_common::BinaryInplaceContigFunctor< @@ -330,8 +357,8 @@ using BitwiseRightShiftInplaceStridedFunctor = template + std::uint8_t vec_sz, + std::uint8_t n_vecs> class bitwise_right_shift_inplace_contig_kernel; /* @brief Types supported by in-place bitwise right shift */ @@ -379,9 +406,15 @@ sycl::event bitwise_right_shift_inplace_contig_impl( ssize_t res_offset, const std::vector &depends = {}) { + // res = OP(res, arg) + constexpr std::uint8_t vec_sz = + BitwiseRightShiftContigHyperparameterSet::vec_sz; + constexpr std::uint8_t n_vecs = + BitwiseRightShiftContigHyperparameterSet::n_vecs; + return elementwise_common::binary_inplace_contig_impl< argTy, resTy, BitwiseRightShiftInplaceContigFunctor, - bitwise_right_shift_inplace_contig_kernel>( + bitwise_right_shift_inplace_contig_kernel, vec_sz, n_vecs>( exec_q, nelems, arg_p, arg_offset, res_p, res_offset, depends); } diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_xor.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_xor.hpp index d035b31170..fb18128cc1 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_xor.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_xor.hpp @@ -28,6 +28,8 @@ #include #include +#include "vec_size_util.hpp" + #include "utils/offset_utils.hpp" #include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" @@ -91,8 +93,8 @@ struct BitwiseXorFunctor template using BitwiseXorContigFunctor = elementwise_common::BinaryContigFunctor< argT1, @@ -160,11 +162,31 @@ template struct BitwiseXorOutputType static constexpr bool is_defined = !std::is_same_v; }; 
+namespace +{ + +namespace vsu_ns = dpctl::tensor::kernels::vec_size_utils; + +using vsu_ns::BinaryContigHyperparameterSetEntry; +using vsu_ns::ContigHyperparameterSetDefault; + +template +struct BitwiseXorContigHyperparameterSet +{ + using value_type = + typename std::disjunction>; + + constexpr static auto vec_sz = value_type::vec_sz; + constexpr static auto n_vecs = value_type::n_vecs; +}; + +} // end of anonymous namespace + template + std::uint8_t vec_sz, + std::uint8_t n_vecs> class bitwise_xor_contig_kernel; template @@ -179,10 +201,16 @@ bitwise_xor_contig_impl(sycl::queue &exec_q, ssize_t res_offset, const std::vector &depends = {}) { + constexpr std::uint8_t vec_sz = + BitwiseXorContigHyperparameterSet::vec_sz; + constexpr std::uint8_t n_vecs = + BitwiseXorContigHyperparameterSet::n_vecs; + return elementwise_common::binary_contig_impl< argTy1, argTy2, BitwiseXorOutputType, BitwiseXorContigFunctor, - bitwise_xor_contig_kernel>(exec_q, nelems, arg1_p, arg1_offset, arg2_p, - arg2_offset, res_p, res_offset, depends); + bitwise_xor_contig_kernel, vec_sz, n_vecs>( + exec_q, nelems, arg1_p, arg1_offset, arg2_p, arg2_offset, res_p, + res_offset, depends); } template struct BitwiseXorContigFactory @@ -290,8 +318,8 @@ template struct BitwiseXorInplaceFunctor template using BitwiseXorInplaceContigFunctor = elementwise_common::BinaryInplaceContigFunctor< @@ -312,8 +340,8 @@ using BitwiseXorInplaceStridedFunctor = template + std::uint8_t vec_sz, + std::uint8_t n_vecs> class bitwise_xor_inplace_contig_kernel; /* @brief Types supported by in-place bitwise XOR */ @@ -361,10 +389,15 @@ bitwise_xor_inplace_contig_impl(sycl::queue &exec_q, ssize_t res_offset, const std::vector &depends = {}) { + constexpr std::uint8_t vec_sz = + BitwiseXorContigHyperparameterSet::vec_sz; + constexpr std::uint8_t n_vecs = + BitwiseXorContigHyperparameterSet::n_vecs; + return elementwise_common::binary_inplace_contig_impl< argTy, resTy, BitwiseXorInplaceContigFunctor, - 
bitwise_xor_inplace_contig_kernel>(exec_q, nelems, arg_p, arg_offset, - res_p, res_offset, depends); + bitwise_xor_inplace_contig_kernel, vec_sz, n_vecs>( + exec_q, nelems, arg_p, arg_offset, res_p, res_offset, depends); } template diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cabs_impl.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cabs_impl.hpp index e61304bed8..fc42d2d4ba 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cabs_impl.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cabs_impl.hpp @@ -57,21 +57,16 @@ template realT cabs(std::complex const &z) constexpr realT q_nan = std::numeric_limits::quiet_NaN(); constexpr realT p_inf = std::numeric_limits::infinity(); - if (std::isinf(x)) { - return p_inf; - } - else if (std::isinf(y)) { - return p_inf; - } - else if (std::isnan(x)) { - return q_nan; - } - else if (std::isnan(y)) { - return q_nan; - } - else { - return exprm_ns::abs(exprm_ns::complex(z)); - } + const realT res = + std::isinf(x) + ? p_inf + : ((std::isinf(y) + ? p_inf + : ((std::isnan(x) + ? 
q_nan + : exprm_ns::abs(exprm_ns::complex(z)))))); + + return res; } } // namespace detail diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cbrt.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cbrt.hpp index 4f2634f17a..a071558a5f 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cbrt.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cbrt.hpp @@ -30,6 +30,8 @@ #include #include +#include "vec_size_util.hpp" + #include "kernels/elementwise_functions/common.hpp" #include "kernels/dpctl_tensor_types.hpp" @@ -65,8 +67,8 @@ template struct CbrtFunctor template using CbrtContigFunctor = elementwise_common::UnaryContigFunctor struct CbrtOutputType static constexpr bool is_defined = !std::is_same_v; }; -template +namespace +{ + +namespace vsu_ns = dpctl::tensor::kernels::vec_size_utils; + +using vsu_ns::ContigHyperparameterSetDefault; +using vsu_ns::UnaryContigHyperparameterSetEntry; + +template struct CbrtContigHyperparameterSet +{ + using value_type = + typename std::disjunction>; + + constexpr static auto vec_sz = value_type::vec_sz; + constexpr static auto n_vecs = value_type::n_vecs; +}; + +} // end of anonymous namespace + +template class cbrt_contig_kernel; template @@ -101,9 +122,12 @@ sycl::event cbrt_contig_impl(sycl::queue &exec_q, char *res_p, const std::vector &depends = {}) { + constexpr std::uint8_t vec_sz = CbrtContigHyperparameterSet::vec_sz; + constexpr std::uint8_t n_vecs = CbrtContigHyperparameterSet::n_vecs; + return elementwise_common::unary_contig_impl< - argTy, CbrtOutputType, CbrtContigFunctor, cbrt_contig_kernel>( - exec_q, nelems, arg_p, res_p, depends); + argTy, CbrtOutputType, CbrtContigFunctor, cbrt_contig_kernel, vec_sz, + n_vecs>(exec_q, nelems, arg_p, res_p, depends); } template struct CbrtContigFactory diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/ceil.hpp 
b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/ceil.hpp index 59bc630720..ab7610088f 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/ceil.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/ceil.hpp @@ -29,6 +29,8 @@ #include #include +#include "vec_size_util.hpp" + #include "kernels/elementwise_functions/common.hpp" #include "kernels/dpctl_tensor_types.hpp" @@ -78,8 +80,8 @@ template struct CeilFunctor template using CeilContigFunctor = elementwise_common::UnaryContigFunctor struct CeilOutputType static constexpr bool is_defined = !std::is_same_v; }; -template +namespace +{ + +namespace vsu_ns = dpctl::tensor::kernels::vec_size_utils; + +using vsu_ns::ContigHyperparameterSetDefault; +using vsu_ns::UnaryContigHyperparameterSetEntry; + +template struct CeilContigHyperparameterSet +{ + using value_type = + typename std::disjunction>; + + constexpr static auto vec_sz = value_type::vec_sz; + constexpr static auto n_vecs = value_type::n_vecs; +}; + +} // end of anonymous namespace + +template class ceil_contig_kernel; template @@ -122,9 +143,12 @@ sycl::event ceil_contig_impl(sycl::queue &exec_q, char *res_p, const std::vector &depends = {}) { + constexpr std::uint8_t vec_sz = CeilContigHyperparameterSet::vec_sz; + constexpr std::uint8_t n_vecs = CeilContigHyperparameterSet::n_vecs; + return elementwise_common::unary_contig_impl< - argTy, CeilOutputType, CeilContigFunctor, ceil_contig_kernel>( - exec_q, nelems, arg_p, res_p, depends); + argTy, CeilOutputType, CeilContigFunctor, ceil_contig_kernel, vec_sz, + n_vecs>(exec_q, nelems, arg_p, res_p, depends); } template struct CeilContigFactory diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/common.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/common.hpp index ee955dcde5..7efd4b02ee 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/common.hpp +++ 
b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/common.hpp @@ -33,6 +33,7 @@ #include "kernels/dpctl_tensor_types.hpp" #include "utils/offset_utils.hpp" #include "utils/sycl_alloc_utils.hpp" +#include "utils/sycl_utils.hpp" namespace dpctl { @@ -48,12 +49,15 @@ using dpctl::tensor::kernels::alignment_utils:: using dpctl::tensor::kernels::alignment_utils::is_aligned; using dpctl::tensor::kernels::alignment_utils::required_alignment; +using dpctl::tensor::sycl_utils::sub_group_load; +using dpctl::tensor::sycl_utils::sub_group_store; + /*! @brief Functor for unary function evaluation on contiguous array */ template struct UnaryContigFunctor { @@ -70,9 +74,10 @@ struct UnaryContigFunctor void operator()(sycl::nd_item<1> ndit) const { + constexpr std::uint8_t elems_per_wi = n_vecs * vec_sz; UnaryOperatorT op{}; /* Each work-item processes vec_sz elements, contiguous in memory */ - /* NOTE: vec_sz must divide sg.max_local_range()[0] */ + /* NOTE: work-group size must be divisible by sub-group size */ if constexpr (enable_sg_loadstore && UnaryOperatorT::is_constant::value) { @@ -80,53 +85,44 @@ struct UnaryContigFunctor constexpr resT const_val = UnaryOperatorT::constant_value; auto sg = ndit.get_sub_group(); - std::uint8_t sgSize = sg.get_local_range()[0]; - std::uint8_t max_sgSize = sg.get_max_local_range()[0]; - size_t base = n_vecs * vec_sz * - (ndit.get_group(0) * ndit.get_local_range(0) + - sg.get_group_id()[0] * sgSize); - if (base + n_vecs * vec_sz * sgSize < nelems_ && - max_sgSize == sgSize) - { - sycl::vec res_vec(const_val); + const std::uint16_t sgSize = sg.get_max_local_range()[0]; + + const size_t base = + elems_per_wi * (ndit.get_group(0) * ndit.get_local_range(0) + + sg.get_group_id()[0] * sgSize); + if (base + elems_per_wi * sgSize < nelems_) { + constexpr sycl::vec res_vec(const_val); #pragma unroll - for (std::uint8_t it = 0; it < n_vecs * vec_sz; it += vec_sz) { - size_t offset = base + static_cast(it) * - static_cast(sgSize); + for 
(std::uint8_t it = 0; it < elems_per_wi; it += vec_sz) { + const size_t offset = base + it * sgSize; auto out_multi_ptr = sycl::address_space_cast< sycl::access::address_space::global_space, sycl::access::decorated::yes>(&out[offset]); - sg.store(out_multi_ptr, res_vec); + sub_group_store(sg, res_vec, out_multi_ptr); } } else { - for (size_t k = base + sg.get_local_id()[0]; k < nelems_; - k += sgSize) - { + const size_t lane_id = sg.get_local_id()[0]; + for (size_t k = base + lane_id; k < nelems_; k += sgSize) { out[k] = const_val; } } } else if constexpr (enable_sg_loadstore && UnaryOperatorT::supports_sg_loadstore::value && - UnaryOperatorT::supports_vec::value) + UnaryOperatorT::supports_vec::value && (vec_sz > 1)) { auto sg = ndit.get_sub_group(); - std::uint16_t sgSize = sg.get_local_range()[0]; - std::uint16_t max_sgSize = sg.get_max_local_range()[0]; - size_t base = n_vecs * vec_sz * - (ndit.get_group(0) * ndit.get_local_range(0) + - sg.get_group_id()[0] * max_sgSize); - if (base + n_vecs * vec_sz * sgSize < nelems_ && - sgSize == max_sgSize) - { - sycl::vec x; + const std::uint16_t sgSize = sg.get_max_local_range()[0]; + const size_t base = + elems_per_wi * (ndit.get_group(0) * ndit.get_local_range(0) + + sg.get_group_id()[0] * sgSize); + if (base + elems_per_wi * sgSize < nelems_) { #pragma unroll - for (std::uint16_t it = 0; it < n_vecs * vec_sz; it += vec_sz) { - size_t offset = base + static_cast(it) * - static_cast(sgSize); + for (std::uint8_t it = 0; it < elems_per_wi; it += vec_sz) { + const size_t offset = base + it * sgSize; auto in_multi_ptr = sycl::address_space_cast< sycl::access::address_space::global_space, sycl::access::decorated::yes>(&in[offset]); @@ -134,15 +130,15 @@ struct UnaryContigFunctor sycl::access::address_space::global_space, sycl::access::decorated::yes>(&out[offset]); - x = sg.load(in_multi_ptr); - sycl::vec res_vec = op(x); - sg.store(out_multi_ptr, res_vec); + const sycl::vec x = + sub_group_load(sg, in_multi_ptr); + const 
sycl::vec res_vec = op(x); + sub_group_store(sg, res_vec, out_multi_ptr); } } else { - for (size_t k = base + sg.get_local_id()[0]; k < nelems_; - k += sgSize) - { + const size_t lane_id = sg.get_local_id()[0]; + for (size_t k = base + lane_id; k < nelems_; k += sgSize) { // scalar call out[k] = op(in[k]); } @@ -155,21 +151,15 @@ struct UnaryContigFunctor // default: use scalar-value function auto sg = ndit.get_sub_group(); - std::uint8_t sgSize = sg.get_local_range()[0]; - std::uint8_t maxsgSize = sg.get_max_local_range()[0]; - size_t base = n_vecs * vec_sz * - (ndit.get_group(0) * ndit.get_local_range(0) + - sg.get_group_id()[0] * maxsgSize); - - if ((base + n_vecs * vec_sz * sgSize < nelems_) && - (maxsgSize == sgSize)) - { - sycl::vec arg_vec; + const std::uint16_t sgSize = sg.get_max_local_range()[0]; + const size_t base = + elems_per_wi * (ndit.get_group(0) * ndit.get_local_range(0) + + sg.get_group_id()[0] * sgSize); + if (base + elems_per_wi * sgSize < nelems_) { #pragma unroll - for (std::uint8_t it = 0; it < n_vecs * vec_sz; it += vec_sz) { - size_t offset = base + static_cast(it) * - static_cast(sgSize); + for (std::uint8_t it = 0; it < elems_per_wi; it += vec_sz) { + const size_t offset = base + it * sgSize; auto in_multi_ptr = sycl::address_space_cast< sycl::access::address_space::global_space, sycl::access::decorated::yes>(&in[offset]); @@ -177,18 +167,18 @@ struct UnaryContigFunctor sycl::access::address_space::global_space, sycl::access::decorated::yes>(&out[offset]); - arg_vec = sg.load(in_multi_ptr); + sycl::vec arg_vec = + sub_group_load(sg, in_multi_ptr); #pragma unroll - for (std::uint8_t k = 0; k < vec_sz; ++k) { + for (std::uint32_t k = 0; k < vec_sz; ++k) { arg_vec[k] = op(arg_vec[k]); } - sg.store(out_multi_ptr, arg_vec); + sub_group_store(sg, arg_vec, out_multi_ptr); } } else { - for (size_t k = base + sg.get_local_id()[0]; k < nelems_; - k += sgSize) - { + const size_t lane_id = sg.get_local_id()[0]; + for (size_t k = base + lane_id; k < 
nelems_; k += sgSize) { out[k] = op(in[k]); } } @@ -199,22 +189,15 @@ struct UnaryContigFunctor // default: use scalar-value function auto sg = ndit.get_sub_group(); - std::uint8_t sgSize = sg.get_local_range()[0]; - std::uint8_t maxsgSize = sg.get_max_local_range()[0]; - size_t base = n_vecs * vec_sz * - (ndit.get_group(0) * ndit.get_local_range(0) + - sg.get_group_id()[0] * maxsgSize); - - if ((base + n_vecs * vec_sz * sgSize < nelems_) && - (maxsgSize == sgSize)) - { - sycl::vec arg_vec; - sycl::vec res_vec; + const std::uint16_t sgSize = sg.get_max_local_range()[0]; + const size_t base = + elems_per_wi * (ndit.get_group(0) * ndit.get_local_range(0) + + sg.get_group_id()[0] * sgSize); + if (base + elems_per_wi * sgSize < nelems_) { #pragma unroll - for (std::uint8_t it = 0; it < n_vecs * vec_sz; it += vec_sz) { - size_t offset = base + static_cast(it) * - static_cast(sgSize); + for (std::uint8_t it = 0; it < elems_per_wi; it += vec_sz) { + const size_t offset = base + it * sgSize; auto in_multi_ptr = sycl::address_space_cast< sycl::access::address_space::global_space, sycl::access::decorated::yes>(&in[offset]); @@ -222,31 +205,32 @@ struct UnaryContigFunctor sycl::access::address_space::global_space, sycl::access::decorated::yes>(&out[offset]); - arg_vec = sg.load(in_multi_ptr); + const sycl::vec arg_vec = + sub_group_load(sg, in_multi_ptr); + sycl::vec res_vec; #pragma unroll for (std::uint8_t k = 0; k < vec_sz; ++k) { res_vec[k] = op(arg_vec[k]); } - sg.store(out_multi_ptr, res_vec); + sub_group_store(sg, res_vec, out_multi_ptr); } } else { - for (size_t k = base + sg.get_local_id()[0]; k < nelems_; - k += sgSize) - { + const size_t lane_id = sg.get_local_id()[0]; + for (size_t k = base + lane_id; k < nelems_; k += sgSize) { out[k] = op(in[k]); } } } else { - std::uint8_t sgSize = ndit.get_sub_group().get_local_range()[0]; - size_t base = ndit.get_global_linear_id(); - - base = (base / sgSize) * sgSize * n_vecs * vec_sz + (base % sgSize); - for (size_t offset 
= base; - offset < std::min(nelems_, base + sgSize * (n_vecs * vec_sz)); - offset += sgSize) - { + const std::uint16_t sgSize = + ndit.get_sub_group().get_local_range()[0]; + const size_t gid = ndit.get_global_linear_id(); + const std::uint16_t elems_per_sg = sgSize * elems_per_wi; + + const size_t start = (gid / sgSize) * (elems_per_sg - sgSize) + gid; + const size_t end = std::min(nelems_, start + elems_per_sg); + for (size_t offset = start; offset < end; offset += sgSize) { out[offset] = op(in[offset]); } } @@ -281,43 +265,62 @@ struct UnaryStridedFunctor } }; +template +SizeT select_lws(const sycl::device &, SizeT n_work_items_needed) +{ + // TODO: make the decision based on device descriptors + + // constexpr SizeT few_threshold = (SizeT(1) << 17); + constexpr SizeT med_threshold = (SizeT(1) << 21); + + const SizeT lws = + (n_work_items_needed <= med_threshold ? SizeT(128) : SizeT(256)); + + return lws; +} + template class UnaryOutputType, template class ContigFunctorT, - template + template class kernel_name, - unsigned int vec_sz = 4, - unsigned int n_vecs = 2> + std::uint8_t vec_sz = 4u, + std::uint8_t n_vecs = 2u> sycl::event unary_contig_impl(sycl::queue &exec_q, size_t nelems, const char *arg_p, char *res_p, const std::vector &depends = {}) { - sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(depends); + constexpr std::uint8_t elems_per_wi = n_vecs * vec_sz; + const size_t n_work_items_needed = nelems / elems_per_wi; + const size_t lws = select_lws(exec_q.get_device(), n_work_items_needed); - const size_t lws = 128; - const size_t n_groups = - ((nelems + lws * n_vecs * vec_sz - 1) / (lws * n_vecs * vec_sz)); - const auto gws_range = sycl::range<1>(n_groups * lws); - const auto lws_range = sycl::range<1>(lws); + const size_t n_groups = + ((nelems + lws * elems_per_wi - 1) / (lws * elems_per_wi)); + const auto gws_range = sycl::range<1>(n_groups * lws); + const auto lws_range = sycl::range<1>(lws); - using resTy = typename 
UnaryOutputType::value_type; - const argTy *arg_tp = reinterpret_cast(arg_p); - resTy *res_tp = reinterpret_cast(res_p); + using resTy = typename UnaryOutputType::value_type; + using BaseKernelName = kernel_name; + + const argTy *arg_tp = reinterpret_cast(arg_p); + resTy *res_tp = reinterpret_cast(res_p); + + sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); if (is_aligned(arg_p) && is_aligned(res_p)) { constexpr bool enable_sg_loadstore = true; - using KernelName = kernel_name; + using KernelName = BaseKernelName; cgh.parallel_for( sycl::nd_range<1>(gws_range, lws_range), @@ -326,9 +329,8 @@ sycl::event unary_contig_impl(sycl::queue &exec_q, } else { constexpr bool disable_sg_loadstore = false; - using InnerKernelName = kernel_name; using KernelName = - disabled_sg_loadstore_wrapper_krn; + disabled_sg_loadstore_wrapper_krn; cgh.parallel_for( sycl::nd_range<1>(gws_range, lws_range), @@ -336,6 +338,7 @@ sycl::event unary_contig_impl(sycl::queue &exec_q, disable_sg_loadstore>(arg_tp, res_tp, nelems)); } }); + return comp_ev; } @@ -382,8 +385,8 @@ template struct BinaryContigFunctor { @@ -404,32 +407,28 @@ struct BinaryContigFunctor void operator()(sycl::nd_item<1> ndit) const { + constexpr std::uint8_t elems_per_wi = n_vecs * vec_sz; BinaryOperatorT op{}; /* Each work-item processes vec_sz elements, contiguous in memory */ + /* NOTE: work-group size must be divisible by sub-group size */ if constexpr (enable_sg_loadstore && BinaryOperatorT::supports_sg_loadstore::value && - BinaryOperatorT::supports_vec::value) + BinaryOperatorT::supports_vec::value && (vec_sz > 1)) { auto sg = ndit.get_sub_group(); - std::uint8_t sgSize = sg.get_local_range()[0]; - std::uint8_t maxsgSize = sg.get_max_local_range()[0]; - - size_t base = n_vecs * vec_sz * - (ndit.get_group(0) * ndit.get_local_range(0) + - sg.get_group_id()[0] * sgSize); - - if ((base + n_vecs * vec_sz * sgSize < nelems_) && - (sgSize == maxsgSize)) - { - sycl::vec arg1_vec; - 
sycl::vec arg2_vec; + std::uint16_t sgSize = sg.get_max_local_range()[0]; + + const size_t base = + elems_per_wi * (ndit.get_group(0) * ndit.get_local_range(0) + + sg.get_group_id()[0] * sgSize); + + if (base + elems_per_wi * sgSize < nelems_) { sycl::vec res_vec; #pragma unroll - for (std::uint8_t it = 0; it < n_vecs * vec_sz; it += vec_sz) { - size_t offset = base + static_cast(it) * - static_cast(sgSize); + for (std::uint8_t it = 0; it < elems_per_wi; it += vec_sz) { + size_t offset = base + it * sgSize; auto in1_multi_ptr = sycl::address_space_cast< sycl::access::address_space::global_space, sycl::access::decorated::yes>(&in1[offset]); @@ -440,16 +439,17 @@ struct BinaryContigFunctor sycl::access::address_space::global_space, sycl::access::decorated::yes>(&out[offset]); - arg1_vec = sg.load(in1_multi_ptr); - arg2_vec = sg.load(in2_multi_ptr); + const sycl::vec arg1_vec = + sub_group_load(sg, in1_multi_ptr); + const sycl::vec arg2_vec = + sub_group_load(sg, in2_multi_ptr); res_vec = op(arg1_vec, arg2_vec); - sg.store(out_multi_ptr, res_vec); + sub_group_store(sg, res_vec, out_multi_ptr); } } else { - for (size_t k = base + sg.get_local_id()[0]; k < nelems_; - k += sgSize) - { + const std::size_t lane_id = sg.get_local_id()[0]; + for (size_t k = base + lane_id; k < nelems_; k += sgSize) { out[k] = op(in1[k], in2[k]); } } @@ -458,24 +458,16 @@ struct BinaryContigFunctor BinaryOperatorT::supports_sg_loadstore::value) { auto sg = ndit.get_sub_group(); - std::uint8_t sgSize = sg.get_local_range()[0]; - std::uint8_t maxsgSize = sg.get_max_local_range()[0]; - - size_t base = n_vecs * vec_sz * - (ndit.get_group(0) * ndit.get_local_range(0) + - sg.get_group_id()[0] * sgSize); - - if ((base + n_vecs * vec_sz * sgSize < nelems_) && - (sgSize == maxsgSize)) - { - sycl::vec arg1_vec; - sycl::vec arg2_vec; - sycl::vec res_vec; + const std::uint16_t sgSize = sg.get_max_local_range()[0]; + const size_t base = + elems_per_wi * (ndit.get_group(0) * ndit.get_local_range(0) + + 
sg.get_group_id()[0] * sgSize); + + if (base + elems_per_wi * sgSize < nelems_) { #pragma unroll - for (std::uint8_t it = 0; it < n_vecs * vec_sz; it += vec_sz) { - size_t offset = base + static_cast(it) * - static_cast(sgSize); + for (std::uint8_t it = 0; it < elems_per_wi; it += vec_sz) { + const size_t offset = base + it * sgSize; auto in1_multi_ptr = sycl::address_space_cast< sycl::access::address_space::global_space, sycl::access::decorated::yes>(&in1[offset]); @@ -486,33 +478,35 @@ struct BinaryContigFunctor sycl::access::address_space::global_space, sycl::access::decorated::yes>(&out[offset]); - arg1_vec = sg.load(in1_multi_ptr); - arg2_vec = sg.load(in2_multi_ptr); + const sycl::vec arg1_vec = + sub_group_load(sg, in1_multi_ptr); + const sycl::vec arg2_vec = + sub_group_load(sg, in2_multi_ptr); + + sycl::vec res_vec; #pragma unroll for (std::uint8_t vec_id = 0; vec_id < vec_sz; ++vec_id) { res_vec[vec_id] = op(arg1_vec[vec_id], arg2_vec[vec_id]); } - sg.store(out_multi_ptr, res_vec); + sub_group_store(sg, res_vec, out_multi_ptr); } } else { - for (size_t k = base + sg.get_local_id()[0]; k < nelems_; - k += sgSize) - { + const std::size_t lane_id = sg.get_local_id()[0]; + for (size_t k = base + lane_id; k < nelems_; k += sgSize) { out[k] = op(in1[k], in2[k]); } } } else { - std::uint8_t sgSize = ndit.get_sub_group().get_local_range()[0]; - size_t base = ndit.get_global_linear_id(); - - base = (base / sgSize) * sgSize * n_vecs * vec_sz + (base % sgSize); - for (size_t offset = base; - offset < std::min(nelems_, base + sgSize * (n_vecs * vec_sz)); - offset += sgSize) - { + const size_t sgSize = ndit.get_sub_group().get_local_range()[0]; + const size_t gid = ndit.get_global_linear_id(); + const size_t elems_per_sg = sgSize * elems_per_wi; + + const size_t start = (gid / sgSize) * (elems_per_sg - sgSize) + gid; + const size_t end = std::min(nelems_, start + elems_per_sg); + for (size_t offset = start; offset < end; offset += sgSize) { out[offset] = 
op(in1[offset], in2[offset]); } } @@ -582,14 +576,16 @@ struct BinaryContigMatrixContigRowBroadcastingFunctor void operator()(sycl::nd_item<1> ndit) const { + /* NOTE: work-group size must be divisible by sub-group size */ + BinaryOperatorT op{}; static_assert(BinaryOperatorT::supports_sg_loadstore::value); - auto sg = ndit.get_sub_group(); - size_t gid = ndit.get_global_linear_id(); + const auto &sg = ndit.get_sub_group(); + const size_t gid = ndit.get_global_linear_id(); - std::uint8_t sgSize = sg.get_local_range()[0]; - size_t base = gid - sg.get_local_id()[0]; + const size_t sgSize = sg.get_max_local_range()[0]; + const size_t base = gid - sg.get_local_id()[0]; if (base + sgSize < n_elems) { auto in1_multi_ptr = sycl::address_space_cast< @@ -604,17 +600,16 @@ struct BinaryContigMatrixContigRowBroadcastingFunctor sycl::access::address_space::global_space, sycl::access::decorated::yes>(&res[base]); - const argT1 mat_el = sg.load(in1_multi_ptr); - const argT2 vec_el = sg.load(in2_multi_ptr); + const argT1 mat_el = sub_group_load(sg, in1_multi_ptr); + const argT2 vec_el = sub_group_load(sg, in2_multi_ptr); resT res_el = op(mat_el, vec_el); - sg.store(out_multi_ptr, res_el); + sub_group_store(sg, res_el, out_multi_ptr); } else { - for (size_t k = base + sg.get_local_id()[0]; k < n_elems; - k += sgSize) - { + const size_t lane_id = sg.get_local_id()[0]; + for (size_t k = base + lane_id; k < n_elems; k += sgSize) { res[k] = op(mat[k], padded_vec[k % n1]); } } @@ -647,14 +642,15 @@ struct BinaryContigRowContigMatrixBroadcastingFunctor void operator()(sycl::nd_item<1> ndit) const { + /* NOTE: work-group size must be divisible by sub-group size */ BinaryOperatorT op{}; static_assert(BinaryOperatorT::supports_sg_loadstore::value); - auto sg = ndit.get_sub_group(); + const auto &sg = ndit.get_sub_group(); size_t gid = ndit.get_global_linear_id(); - std::uint8_t sgSize = sg.get_local_range()[0]; - size_t base = gid - sg.get_local_id()[0]; + const size_t sgSize = 
sg.get_max_local_range()[0]; + const size_t base = gid - sg.get_local_id()[0]; if (base + sgSize < n_elems) { auto in1_multi_ptr = sycl::address_space_cast< @@ -669,17 +665,16 @@ struct BinaryContigRowContigMatrixBroadcastingFunctor sycl::access::address_space::global_space, sycl::access::decorated::yes>(&res[base]); - const argT2 mat_el = sg.load(in2_multi_ptr); - const argT1 vec_el = sg.load(in1_multi_ptr); + const argT2 mat_el = sub_group_load(sg, in2_multi_ptr); + const argT1 vec_el = sub_group_load(sg, in1_multi_ptr); resT res_el = op(vec_el, mat_el); - sg.store(out_multi_ptr, res_el); + sub_group_store(sg, res_el, out_multi_ptr); } else { - for (size_t k = base + sg.get_local_id()[0]; k < n_elems; - k += sgSize) - { + const size_t lane_id = sg.get_local_id()[0]; + for (size_t k = base + lane_id; k < n_elems; k += sgSize) { res[k] = op(padded_vec[k % n1], mat[k]); } } @@ -765,18 +760,18 @@ template class BinaryContigFunctorT, template + std::uint8_t vs, + std::uint8_t nv> class kernel_name, - unsigned int vec_sz = 4, - unsigned int n_vecs = 2> + std::uint8_t vec_sz = 4u, + std::uint8_t n_vecs = 2u> sycl::event binary_contig_impl(sycl::queue &exec_q, size_t nelems, const char *arg1_p, @@ -787,30 +782,33 @@ sycl::event binary_contig_impl(sycl::queue &exec_q, ssize_t res_offset, const std::vector &depends = {}) { - sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(depends); + const size_t n_work_items_needed = nelems / (n_vecs * vec_sz); + const size_t lws = select_lws(exec_q.get_device(), n_work_items_needed); - const size_t lws = 128; - const size_t n_groups = - ((nelems + lws * n_vecs * vec_sz - 1) / (lws * n_vecs * vec_sz)); - const auto gws_range = sycl::range<1>(n_groups * lws); - const auto lws_range = sycl::range<1>(lws); + const size_t n_groups = + ((nelems + lws * n_vecs * vec_sz - 1) / (lws * n_vecs * vec_sz)); + const auto gws_range = sycl::range<1>(n_groups * lws); + const auto lws_range = sycl::range<1>(lws); - using 
resTy = typename BinaryOutputType::value_type; + using resTy = typename BinaryOutputType::value_type; + using BaseKernelName = kernel_name; - const argTy1 *arg1_tp = - reinterpret_cast(arg1_p) + arg1_offset; - const argTy2 *arg2_tp = - reinterpret_cast(arg2_p) + arg2_offset; - resTy *res_tp = reinterpret_cast(res_p) + res_offset; + const argTy1 *arg1_tp = + reinterpret_cast(arg1_p) + arg1_offset; + const argTy2 *arg2_tp = + reinterpret_cast(arg2_p) + arg2_offset; + resTy *res_tp = reinterpret_cast(res_p) + res_offset; + + sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); if (is_aligned(arg1_tp) && is_aligned(arg2_tp) && is_aligned(res_tp)) { constexpr bool enable_sg_loadstore = true; - using KernelName = - kernel_name; + using KernelName = BaseKernelName; + cgh.parallel_for( sycl::nd_range<1>(gws_range, lws_range), BinaryContigFunctorT; using KernelName = - disabled_sg_loadstore_wrapper_krn; + disabled_sg_loadstore_wrapper_krn; cgh.parallel_for( sycl::nd_range<1>(gws_range, lws_range), BinaryContigFunctorT struct BinaryInplaceContigFunctor { @@ -72,47 +76,46 @@ struct BinaryInplaceContigFunctor void operator()(sycl::nd_item<1> ndit) const { BinaryInplaceOperatorT op{}; + constexpr std::uint8_t elems_per_wi = vec_sz * n_vecs; /* Each work-item processes vec_sz elements, contiguous in memory */ + /* NB: Workgroup size must be divisible by sub-group size */ if constexpr (enable_sg_loadstore && BinaryInplaceOperatorT::supports_sg_loadstore::value && - BinaryInplaceOperatorT::supports_vec::value) + BinaryInplaceOperatorT::supports_vec::value && + (vec_sz > 1)) { auto sg = ndit.get_sub_group(); - std::uint8_t sgSize = sg.get_local_range()[0]; - std::uint8_t maxsgSize = sg.get_max_local_range()[0]; - - size_t base = n_vecs * vec_sz * - (ndit.get_group(0) * ndit.get_local_range(0) + - sg.get_group_id()[0] * sgSize); + std::uint16_t sgSize = sg.get_max_local_range()[0]; - if ((base + n_vecs * vec_sz * sgSize < nelems_) && - (sgSize == 
maxsgSize)) - { + size_t base = + elems_per_wi * (ndit.get_group(0) * ndit.get_local_range(0) + + sg.get_group_id()[0] * sgSize); - sycl::vec arg_vec; - sycl::vec res_vec; + if (base + elems_per_wi * sgSize < nelems_) { #pragma unroll - for (std::uint8_t it = 0; it < n_vecs * vec_sz; it += vec_sz) { + for (std::uint8_t it = 0; it < elems_per_wi; it += vec_sz) { + const size_t offset = base + it * sgSize; auto rhs_multi_ptr = sycl::address_space_cast< sycl::access::address_space::global_space, - sycl::access::decorated::yes>(&rhs[base + it * sgSize]); + sycl::access::decorated::yes>(&rhs[offset]); auto lhs_multi_ptr = sycl::address_space_cast< sycl::access::address_space::global_space, - sycl::access::decorated::yes>(&lhs[base + it * sgSize]); + sycl::access::decorated::yes>(&lhs[offset]); - arg_vec = sg.load(rhs_multi_ptr); - res_vec = sg.load(lhs_multi_ptr); + const sycl::vec &arg_vec = + sub_group_load(sg, rhs_multi_ptr); + sycl::vec res_vec = + sub_group_load(sg, lhs_multi_ptr); op(res_vec, arg_vec); - sg.store(lhs_multi_ptr, res_vec); + sub_group_store(sg, res_vec, lhs_multi_ptr); } } else { - for (size_t k = base + sg.get_local_id()[0]; k < nelems_; - k += sgSize) - { + const size_t lane_id = sg.get_local_id()[0]; + for (size_t k = base + lane_id; k < nelems_; k += sgSize) { op(lhs[k], rhs[k]); } } @@ -121,54 +124,49 @@ struct BinaryInplaceContigFunctor BinaryInplaceOperatorT::supports_sg_loadstore::value) { auto sg = ndit.get_sub_group(); - std::uint8_t sgSize = sg.get_local_range()[0]; - std::uint8_t maxsgSize = sg.get_max_local_range()[0]; + std::uint16_t sgSize = sg.get_max_local_range()[0]; - size_t base = n_vecs * vec_sz * - (ndit.get_group(0) * ndit.get_local_range(0) + - sg.get_group_id()[0] * sgSize); - - if ((base + n_vecs * vec_sz * sgSize < nelems_) && - (sgSize == maxsgSize)) - { - sycl::vec arg_vec; - sycl::vec res_vec; + size_t base = + elems_per_wi * (ndit.get_group(0) * ndit.get_local_range(0) + + sg.get_group_id()[0] * sgSize); + if (base + 
elems_per_wi * sgSize < nelems_) { #pragma unroll - for (std::uint8_t it = 0; it < n_vecs * vec_sz; it += vec_sz) { + for (std::uint8_t it = 0; it < elems_per_wi; it += vec_sz) { + const size_t offset = base + it * sgSize; auto rhs_multi_ptr = sycl::address_space_cast< sycl::access::address_space::global_space, - sycl::access::decorated::yes>(&rhs[base + it * sgSize]); + sycl::access::decorated::yes>(&rhs[offset]); auto lhs_multi_ptr = sycl::address_space_cast< sycl::access::address_space::global_space, - sycl::access::decorated::yes>(&lhs[base + it * sgSize]); + sycl::access::decorated::yes>(&lhs[offset]); - arg_vec = sg.load(rhs_multi_ptr); - res_vec = sg.load(lhs_multi_ptr); + const sycl::vec arg_vec = + sub_group_load(sg, rhs_multi_ptr); + sycl::vec res_vec = + sub_group_load(sg, lhs_multi_ptr); #pragma unroll for (std::uint8_t vec_id = 0; vec_id < vec_sz; ++vec_id) { op(res_vec[vec_id], arg_vec[vec_id]); } - sg.store(lhs_multi_ptr, res_vec); + sub_group_store(sg, res_vec, lhs_multi_ptr); } } else { - for (size_t k = base + sg.get_local_id()[0]; k < nelems_; - k += sgSize) - { + const size_t lane_id = sg.get_local_id()[0]; + for (size_t k = base + lane_id; k < nelems_; k += sgSize) { op(lhs[k], rhs[k]); } } } else { - std::uint8_t sgSize = ndit.get_sub_group().get_local_range()[0]; - size_t base = ndit.get_global_linear_id(); - - base = (base / sgSize) * sgSize * n_vecs * vec_sz + (base % sgSize); - for (size_t offset = base; - offset < std::min(nelems_, base + sgSize * (n_vecs * vec_sz)); - offset += sgSize) - { + const size_t sgSize = ndit.get_sub_group().get_local_range()[0]; + const size_t gid = ndit.get_global_linear_id(); + const size_t elems_per_sg = elems_per_wi * sgSize; + + const size_t start = (gid / sgSize) * (elems_per_sg - sgSize) + gid; + const size_t end = std::min(nelems_, start + elems_per_sg); + for (size_t offset = start; offset < end; offset += sgSize) { op(lhs[offset], rhs[offset]); } } @@ -228,13 +226,14 @@ struct 
BinaryInplaceRowMatrixBroadcastingFunctor void operator()(sycl::nd_item<1> ndit) const { + /* Workgroup size is expected to be a multiple of sub-group size */ BinaryOperatorT op{}; static_assert(BinaryOperatorT::supports_sg_loadstore::value); auto sg = ndit.get_sub_group(); - size_t gid = ndit.get_global_linear_id(); + const size_t gid = ndit.get_global_linear_id(); - std::uint8_t sgSize = sg.get_local_range()[0]; + std::uint8_t sgSize = sg.get_max_local_range()[0]; size_t base = gid - sg.get_local_id()[0]; if (base + sgSize < n_elems) { @@ -246,17 +245,16 @@ struct BinaryInplaceRowMatrixBroadcastingFunctor sycl::access::address_space::global_space, sycl::access::decorated::yes>(&mat[base]); - const argT vec_el = sg.load(in_multi_ptr); - resT mat_el = sg.load(out_multi_ptr); + const argT vec_el = sub_group_load(sg, in_multi_ptr); + resT mat_el = sub_group_load(sg, out_multi_ptr); op(mat_el, vec_el); - sg.store(out_multi_ptr, mat_el); + sub_group_store(sg, mat_el, out_multi_ptr); } else { - for (size_t k = base + sg.get_local_id()[0]; k < n_elems; - k += sgSize) - { + const size_t start = base + sg.get_local_id()[0]; + for (size_t k = start; k < n_elems; k += sgSize) { op(mat[k], padded_vec[k % n1]); } } @@ -301,14 +299,14 @@ template class BinaryInplaceContigFunctorT, - template + template class kernel_name, - unsigned int vec_sz = 4, - unsigned int n_vecs = 2> + std::uint8_t vec_sz = 4u, + std::uint8_t n_vecs = 2u> sycl::event binary_inplace_contig_impl(sycl::queue &exec_q, size_t nelems, @@ -437,10 +435,10 @@ sycl::event binary_inplace_row_matrix_broadcast_impl( // sub-group spans work-items [I, I + sgSize) // base = ndit.get_global_linear_id() - sg.get_local_id()[0] - // Generically, sg.load( &mat[base]) may load arrays from + // Generically, sub_group_load( &mat[base]) may load arrays from // different rows of mat. The start corresponds to row (base / n0) - // We read sg.load(&padded_vec[(base / n0)]). 
The vector is padded to - // ensure that reads are accessible + // We read sub_group_load(&padded_vec[(base / n0)]). The vector is + // padded to ensure that reads are accessible const size_t lws = 128; diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/conj.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/conj.hpp index 4953feedb2..486174435c 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/conj.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/conj.hpp @@ -31,10 +31,12 @@ #include #include -#include "kernels/elementwise_functions/common.hpp" #include "sycl_complex.hpp" +#include "vec_size_util.hpp" #include "kernels/dpctl_tensor_types.hpp" +#include "kernels/elementwise_functions/common.hpp" + #include "utils/offset_utils.hpp" #include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" @@ -82,8 +84,8 @@ template struct ConjFunctor template using ConjContigFunctor = elementwise_common::UnaryContigFunctor struct ConjOutputType static constexpr bool is_defined = !std::is_same_v; }; -template +namespace +{ + +namespace vsu_ns = dpctl::tensor::kernels::vec_size_utils; + +using vsu_ns::ContigHyperparameterSetDefault; +using vsu_ns::UnaryContigHyperparameterSetEntry; + +template struct ConjContigHyperparameterSet +{ + using value_type = + typename std::disjunction>; + + constexpr static auto vec_sz = value_type::vec_sz; + constexpr static auto n_vecs = value_type::n_vecs; +}; + +} // end of anonymous namespace + +template class conj_contig_kernel; template @@ -129,9 +150,12 @@ sycl::event conj_contig_impl(sycl::queue &exec_q, char *res_p, const std::vector &depends = {}) { + constexpr std::uint8_t vec_sz = ConjContigHyperparameterSet::vec_sz; + constexpr std::uint8_t n_vecs = ConjContigHyperparameterSet::n_vecs; + return elementwise_common::unary_contig_impl< - argTy, ConjOutputType, ConjContigFunctor, conj_contig_kernel>( - exec_q, nelems, arg_p, res_p, depends); + 
argTy, ConjOutputType, ConjContigFunctor, conj_contig_kernel, vec_sz, + n_vecs>(exec_q, nelems, arg_p, res_p, depends); } template struct ConjContigFactory diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/copysign.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/copysign.hpp index 92997b572b..9ad6a6ad65 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/copysign.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/copysign.hpp @@ -29,6 +29,8 @@ #include #include +#include "vec_size_util.hpp" + #include "utils/offset_utils.hpp" #include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" @@ -82,8 +84,8 @@ template struct CopysignFunctor template using CopysignContigFunctor = elementwise_common::BinaryContigFunctor struct CopysignOutputType static constexpr bool is_defined = !std::is_same_v; }; +namespace +{ + +namespace vsu_ns = dpctl::tensor::kernels::vec_size_utils; + +using vsu_ns::BinaryContigHyperparameterSetEntry; +using vsu_ns::ContigHyperparameterSetDefault; + +template +struct CopysignContigHyperparameterSet +{ + using value_type = + typename std::disjunction>; + + constexpr static auto vec_sz = value_type::vec_sz; + constexpr static auto n_vecs = value_type::n_vecs; +}; + +} // end of anonymous namespace + template + std::uint8_t vec_sz, + std::uint8_t n_vecs> class copysign_contig_kernel; template @@ -135,10 +157,16 @@ sycl::event copysign_contig_impl(sycl::queue &exec_q, ssize_t res_offset, const std::vector &depends = {}) { + constexpr std::uint8_t vec_sz = + CopysignContigHyperparameterSet::vec_sz; + constexpr std::uint8_t n_vecs = + CopysignContigHyperparameterSet::n_vecs; + return elementwise_common::binary_contig_impl< argTy1, argTy2, CopysignOutputType, CopysignContigFunctor, - copysign_contig_kernel>(exec_q, nelems, arg1_p, arg1_offset, arg2_p, - arg2_offset, res_p, res_offset, depends); + copysign_contig_kernel, vec_sz, n_vecs>( + exec_q, nelems, 
arg1_p, arg1_offset, arg2_p, arg2_offset, res_p, + res_offset, depends); } template struct CopysignContigFactory diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cos.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cos.hpp index 8b6b0c5fbe..52fbebe545 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cos.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cos.hpp @@ -29,10 +29,12 @@ #include #include -#include "kernels/elementwise_functions/common.hpp" #include "sycl_complex.hpp" +#include "vec_size_util.hpp" #include "kernels/dpctl_tensor_types.hpp" +#include "kernels/elementwise_functions/common.hpp" + #include "utils/offset_utils.hpp" #include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" @@ -163,8 +165,8 @@ template struct CosFunctor template using CosContigFunctor = elementwise_common::UnaryContigFunctor struct CosOutputType static constexpr bool is_defined = !std::is_same_v; }; -template +namespace +{ + +namespace vsu_ns = dpctl::tensor::kernels::vec_size_utils; + +using vsu_ns::ContigHyperparameterSetDefault; +using vsu_ns::UnaryContigHyperparameterSetEntry; + +template struct CosContigHyperparameterSet +{ + using value_type = + typename std::disjunction>; + + constexpr static auto vec_sz = value_type::vec_sz; + constexpr static auto n_vecs = value_type::n_vecs; +}; + +} // end of anonymous namespace + +template class cos_contig_kernel; template @@ -202,9 +223,12 @@ sycl::event cos_contig_impl(sycl::queue &exec_q, char *res_p, const std::vector &depends = {}) { + constexpr std::uint8_t vec_sz = CosContigHyperparameterSet::vec_sz; + constexpr std::uint8_t n_vecs = CosContigHyperparameterSet::n_vecs; + return elementwise_common::unary_contig_impl< - argTy, CosOutputType, CosContigFunctor, cos_contig_kernel>( - exec_q, nelems, arg_p, res_p, depends); + argTy, CosOutputType, CosContigFunctor, cos_contig_kernel, vec_sz, + n_vecs>(exec_q, nelems, arg_p, 
res_p, depends); } template struct CosContigFactory diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cosh.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cosh.hpp index cff1038ed9..b1752e5929 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cosh.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cosh.hpp @@ -29,10 +29,12 @@ #include #include -#include "kernels/elementwise_functions/common.hpp" #include "sycl_complex.hpp" +#include "vec_size_util.hpp" #include "kernels/dpctl_tensor_types.hpp" +#include "kernels/elementwise_functions/common.hpp" + #include "utils/offset_utils.hpp" #include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" @@ -153,8 +155,8 @@ template struct CoshFunctor template using CoshContigFunctor = elementwise_common::UnaryContigFunctor struct CoshOutputType static constexpr bool is_defined = !std::is_same_v; }; -template +namespace +{ + +namespace vsu_ns = dpctl::tensor::kernels::vec_size_utils; + +using vsu_ns::ContigHyperparameterSetDefault; +using vsu_ns::UnaryContigHyperparameterSetEntry; + +template struct CoshContigHyperparameterSet +{ + using value_type = + typename std::disjunction>; + + constexpr static auto vec_sz = value_type::vec_sz; + constexpr static auto n_vecs = value_type::n_vecs; +}; + +} // end of anonymous namespace + +template class cosh_contig_kernel; template @@ -191,9 +212,12 @@ sycl::event cosh_contig_impl(sycl::queue &exec_q, char *res_p, const std::vector &depends = {}) { + constexpr std::uint8_t vec_sz = CoshContigHyperparameterSet::vec_sz; + constexpr std::uint8_t n_vecs = CoshContigHyperparameterSet::n_vecs; + return elementwise_common::unary_contig_impl< - argTy, CoshOutputType, CoshContigFunctor, cosh_contig_kernel>( - exec_q, nelems, arg_p, res_p, depends); + argTy, CoshOutputType, CoshContigFunctor, cosh_contig_kernel, vec_sz, + n_vecs>(exec_q, nelems, arg_p, res_p, depends); } template struct 
CoshContigFactory diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/equal.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/equal.hpp index d368658afc..6a455509c5 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/equal.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/equal.hpp @@ -30,6 +30,8 @@ #include #include "sycl_complex.hpp" +#include "vec_size_util.hpp" + #include "utils/offset_utils.hpp" #include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" @@ -119,8 +121,8 @@ template struct EqualFunctor template using EqualContigFunctor = elementwise_common::BinaryContigFunctor struct EqualOutputType static constexpr bool is_defined = !std::is_same_v; }; +namespace +{ + +namespace vsu_ns = dpctl::tensor::kernels::vec_size_utils; + +using vsu_ns::BinaryContigHyperparameterSetEntry; +using vsu_ns::ContigHyperparameterSetDefault; + +template struct EqualContigHyperparameterSet +{ + using value_type = + typename std::disjunction>; + + constexpr static auto vec_sz = value_type::vec_sz; + constexpr static auto n_vecs = value_type::n_vecs; +}; + +} // end of anonymous namespace + template + std::uint8_t vec_sz, + std::uint8_t n_vecs> class equal_contig_kernel; template @@ -207,10 +228,16 @@ sycl::event equal_contig_impl(sycl::queue &exec_q, ssize_t res_offset, const std::vector &depends = {}) { + constexpr std::uint8_t vec_sz = + EqualContigHyperparameterSet::vec_sz; + constexpr std::uint8_t n_vecs = + EqualContigHyperparameterSet::n_vecs; + return elementwise_common::binary_contig_impl< argTy1, argTy2, EqualOutputType, EqualContigFunctor, - equal_contig_kernel>(exec_q, nelems, arg1_p, arg1_offset, arg2_p, - arg2_offset, res_p, res_offset, depends); + equal_contig_kernel, vec_sz, n_vecs>(exec_q, nelems, arg1_p, + arg1_offset, arg2_p, arg2_offset, + res_p, res_offset, depends); } template struct EqualContigFactory diff --git 
a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/exp.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/exp.hpp index 7e613c9731..21edeaeb31 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/exp.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/exp.hpp @@ -29,10 +29,12 @@ #include #include -#include "kernels/elementwise_functions/common.hpp" #include "sycl_complex.hpp" +#include "vec_size_util.hpp" #include "kernels/dpctl_tensor_types.hpp" +#include "kernels/elementwise_functions/common.hpp" + #include "utils/offset_utils.hpp" #include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" @@ -122,8 +124,8 @@ template struct ExpFunctor template using ExpContigFunctor = elementwise_common::UnaryContigFunctor struct ExpOutputType static constexpr bool is_defined = !std::is_same_v; }; -template +namespace +{ + +namespace vsu_ns = dpctl::tensor::kernels::vec_size_utils; + +using vsu_ns::ContigHyperparameterSetDefault; +using vsu_ns::UnaryContigHyperparameterSetEntry; + +template struct ExpContigHyperparameterSet +{ + using value_type = + typename std::disjunction>; + + constexpr static auto vec_sz = value_type::vec_sz; + constexpr static auto n_vecs = value_type::n_vecs; +}; + +} // end of anonymous namespace + +template class exp_contig_kernel; template @@ -160,9 +181,12 @@ sycl::event exp_contig_impl(sycl::queue &exec_q, char *res_p, const std::vector &depends = {}) { + constexpr std::uint8_t vec_sz = ExpContigHyperparameterSet::vec_sz; + constexpr std::uint8_t n_vecs = ExpContigHyperparameterSet::n_vecs; + return elementwise_common::unary_contig_impl< - argTy, ExpOutputType, ExpContigFunctor, exp_contig_kernel>( - exec_q, nelems, arg_p, res_p, depends); + argTy, ExpOutputType, ExpContigFunctor, exp_contig_kernel, vec_sz, + n_vecs>(exec_q, nelems, arg_p, res_p, depends); } template struct ExpContigFactory diff --git 
a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/exp2.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/exp2.hpp index b436bb3855..df9a472329 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/exp2.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/exp2.hpp @@ -30,10 +30,12 @@ #include #include -#include "kernels/elementwise_functions/common.hpp" #include "sycl_complex.hpp" +#include "vec_size_util.hpp" #include "kernels/dpctl_tensor_types.hpp" +#include "kernels/elementwise_functions/common.hpp" + #include "utils/offset_utils.hpp" #include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" @@ -124,8 +126,8 @@ template struct Exp2Functor template using Exp2ContigFunctor = elementwise_common::UnaryContigFunctor struct Exp2OutputType static constexpr bool is_defined = !std::is_same_v; }; -template +namespace +{ + +namespace vsu_ns = dpctl::tensor::kernels::vec_size_utils; + +using vsu_ns::ContigHyperparameterSetDefault; +using vsu_ns::UnaryContigHyperparameterSetEntry; + +template struct Exp2ContigHyperparameterSet +{ + using value_type = + typename std::disjunction>; + + constexpr static auto vec_sz = value_type::vec_sz; + constexpr static auto n_vecs = value_type::n_vecs; +}; + +} // end of anonymous namespace + +template class exp2_contig_kernel; template @@ -162,9 +183,12 @@ sycl::event exp2_contig_impl(sycl::queue &exec_q, char *res_p, const std::vector &depends = {}) { + constexpr std::uint8_t vec_sz = Exp2ContigHyperparameterSet::vec_sz; + constexpr std::uint8_t n_vecs = Exp2ContigHyperparameterSet::n_vecs; + return elementwise_common::unary_contig_impl< - argTy, Exp2OutputType, Exp2ContigFunctor, exp2_contig_kernel>( - exec_q, nelems, arg_p, res_p, depends); + argTy, Exp2OutputType, Exp2ContigFunctor, exp2_contig_kernel, vec_sz, + n_vecs>(exec_q, nelems, arg_p, res_p, depends); } template struct Exp2ContigFactory diff --git 
a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/expm1.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/expm1.hpp index 9a9d0a1562..a8bebd7a15 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/expm1.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/expm1.hpp @@ -31,9 +31,11 @@ #include #include -#include "kernels/elementwise_functions/common.hpp" +#include "vec_size_util.hpp" #include "kernels/dpctl_tensor_types.hpp" +#include "kernels/elementwise_functions/common.hpp" + #include "utils/offset_utils.hpp" #include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" @@ -136,8 +138,8 @@ template struct Expm1Functor template using Expm1ContigFunctor = elementwise_common::UnaryContigFunctor struct Expm1OutputType static constexpr bool is_defined = !std::is_same_v; }; -template +namespace +{ + +namespace vsu_ns = dpctl::tensor::kernels::vec_size_utils; + +using vsu_ns::ContigHyperparameterSetDefault; +using vsu_ns::UnaryContigHyperparameterSetEntry; + +template struct Expm1ContigHyperparameterSet +{ + using value_type = + typename std::disjunction>; + + constexpr static auto vec_sz = value_type::vec_sz; + constexpr static auto n_vecs = value_type::n_vecs; +}; + +} // end of anonymous namespace + +template class expm1_contig_kernel; template @@ -175,9 +196,12 @@ sycl::event expm1_contig_impl(sycl::queue &exec_q, char *res_p, const std::vector &depends = {}) { + constexpr std::uint8_t vec_sz = Expm1ContigHyperparameterSet::vec_sz; + constexpr std::uint8_t n_vecs = Expm1ContigHyperparameterSet::n_vecs; + return elementwise_common::unary_contig_impl< - argTy, Expm1OutputType, Expm1ContigFunctor, expm1_contig_kernel>( - exec_q, nelems, arg_p, res_p, depends); + argTy, Expm1OutputType, Expm1ContigFunctor, expm1_contig_kernel, vec_sz, + n_vecs>(exec_q, nelems, arg_p, res_p, depends); } template struct Expm1ContigFactory diff --git 
a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/floor.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/floor.hpp index 530dd3d9aa..2381327766 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/floor.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/floor.hpp @@ -29,9 +29,11 @@ #include #include -#include "kernels/elementwise_functions/common.hpp" +#include "vec_size_util.hpp" #include "kernels/dpctl_tensor_types.hpp" +#include "kernels/elementwise_functions/common.hpp" + #include "utils/offset_utils.hpp" #include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" @@ -78,8 +80,8 @@ template struct FloorFunctor template using FloorContigFunctor = elementwise_common::UnaryContigFunctor struct FloorOutputType static constexpr bool is_defined = !std::is_same_v; }; -template +namespace +{ + +namespace vsu_ns = dpctl::tensor::kernels::vec_size_utils; + +using vsu_ns::ContigHyperparameterSetDefault; +using vsu_ns::UnaryContigHyperparameterSetEntry; + +template struct FloorContigHyperparameterSet +{ + using value_type = + typename std::disjunction>; + + constexpr static auto vec_sz = value_type::vec_sz; + constexpr static auto n_vecs = value_type::n_vecs; +}; + +} // end of anonymous namespace + +template class floor_contig_kernel; template @@ -122,9 +143,12 @@ sycl::event floor_contig_impl(sycl::queue &exec_q, char *res_p, const std::vector &depends = {}) { + constexpr std::uint8_t vec_sz = FloorContigHyperparameterSet::vec_sz; + constexpr std::uint8_t n_vecs = FloorContigHyperparameterSet::n_vecs; + return elementwise_common::unary_contig_impl< - argTy, FloorOutputType, FloorContigFunctor, floor_contig_kernel>( - exec_q, nelems, arg_p, res_p, depends); + argTy, FloorOutputType, FloorContigFunctor, floor_contig_kernel, vec_sz, + n_vecs>(exec_q, nelems, arg_p, res_p, depends); } template struct FloorContigFactory diff --git 
a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/floor_divide.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/floor_divide.hpp index 72ee3a789a..98bc9820ba 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/floor_divide.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/floor_divide.hpp @@ -29,6 +29,8 @@ #include #include +#include "vec_size_util.hpp" + #include "utils/offset_utils.hpp" #include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" @@ -126,8 +128,8 @@ struct FloorDivideFunctor template using FloorDivideContigFunctor = elementwise_common::BinaryContigFunctor< argT1, @@ -201,11 +203,31 @@ template struct FloorDivideOutputType static constexpr bool is_defined = !std::is_same_v; }; +namespace +{ + +namespace vsu_ns = dpctl::tensor::kernels::vec_size_utils; + +using vsu_ns::BinaryContigHyperparameterSetEntry; +using vsu_ns::ContigHyperparameterSetDefault; + +template +struct FloorDivideContigHyperparameterSet +{ + using value_type = + typename std::disjunction>; + + constexpr static auto vec_sz = value_type::vec_sz; + constexpr static auto n_vecs = value_type::n_vecs; +}; + +} // end of anonymous namespace + template + std::uint8_t vec_sz, + std::uint8_t n_vecs> class floor_divide_contig_kernel; template @@ -220,10 +242,16 @@ floor_divide_contig_impl(sycl::queue &exec_q, ssize_t res_offset, const std::vector &depends = {}) { + constexpr std::uint8_t vec_sz = + FloorDivideContigHyperparameterSet::vec_sz; + constexpr std::uint8_t n_vecs = + FloorDivideContigHyperparameterSet::n_vecs; + return elementwise_common::binary_contig_impl< argTy1, argTy2, FloorDivideOutputType, FloorDivideContigFunctor, - floor_divide_contig_kernel>(exec_q, nelems, arg1_p, arg1_offset, arg2_p, - arg2_offset, res_p, res_offset, depends); + floor_divide_contig_kernel, vec_sz, n_vecs>( + exec_q, nelems, arg1_p, arg1_offset, arg2_p, arg2_offset, res_p, + res_offset, depends); } template @@ -367,8 
+395,8 @@ template struct FloorDivideInplaceFunctor template using FloorDivideInplaceContigFunctor = elementwise_common::BinaryInplaceContigFunctor< @@ -389,8 +417,8 @@ using FloorDivideInplaceStridedFunctor = template + std::uint8_t vec_sz, + std::uint8_t n_vecs> class floor_divide_inplace_contig_kernel; /* @brief Types supported by in-place floor division */ @@ -440,10 +468,15 @@ floor_divide_inplace_contig_impl(sycl::queue &exec_q, ssize_t res_offset, const std::vector &depends = {}) { + constexpr std::uint8_t vec_sz = + FloorDivideContigHyperparameterSet::vec_sz; + constexpr std::uint8_t n_vecs = + FloorDivideContigHyperparameterSet::n_vecs; + return elementwise_common::binary_inplace_contig_impl< argTy, resTy, FloorDivideInplaceContigFunctor, - floor_divide_inplace_contig_kernel>(exec_q, nelems, arg_p, arg_offset, - res_p, res_offset, depends); + floor_divide_inplace_contig_kernel, vec_sz, n_vecs>( + exec_q, nelems, arg_p, arg_offset, res_p, res_offset, depends); } template diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/greater.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/greater.hpp index 05c2a36b0c..588ebc780d 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/greater.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/greater.hpp @@ -30,6 +30,8 @@ #include #include +#include "vec_size_util.hpp" + #include "utils/math_utils.hpp" #include "utils/offset_utils.hpp" #include "utils/type_dispatch_building.hpp" @@ -120,8 +122,8 @@ template struct GreaterFunctor template using GreaterContigFunctor = elementwise_common::BinaryContigFunctor struct GreaterOutputType static constexpr bool is_defined = !std::is_same_v; }; +namespace +{ + +namespace vsu_ns = dpctl::tensor::kernels::vec_size_utils; + +using vsu_ns::BinaryContigHyperparameterSetEntry; +using vsu_ns::ContigHyperparameterSetDefault; + +template +struct GreaterContigHyperparameterSet +{ + using value_type = + typename 
std::disjunction>; + + constexpr static auto vec_sz = value_type::vec_sz; + constexpr static auto n_vecs = value_type::n_vecs; +}; + +} // end of anonymous namespace + template + std::uint8_t vec_sz, + std::uint8_t n_vecs> class greater_contig_kernel; template @@ -208,10 +230,16 @@ sycl::event greater_contig_impl(sycl::queue &exec_q, ssize_t res_offset, const std::vector &depends = {}) { + constexpr std::uint8_t vec_sz = + GreaterContigHyperparameterSet::vec_sz; + constexpr std::uint8_t n_vecs = + GreaterContigHyperparameterSet::n_vecs; + return elementwise_common::binary_contig_impl< argTy1, argTy2, GreaterOutputType, GreaterContigFunctor, - greater_contig_kernel>(exec_q, nelems, arg1_p, arg1_offset, arg2_p, - arg2_offset, res_p, res_offset, depends); + greater_contig_kernel, vec_sz, n_vecs>(exec_q, nelems, arg1_p, + arg1_offset, arg2_p, arg2_offset, + res_p, res_offset, depends); } template struct GreaterContigFactory diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/greater_equal.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/greater_equal.hpp index 43e4e98db1..614fb202e1 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/greater_equal.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/greater_equal.hpp @@ -30,6 +30,8 @@ #include #include +#include "vec_size_util.hpp" + #include "utils/math_utils.hpp" #include "utils/offset_utils.hpp" #include "utils/type_dispatch_building.hpp" @@ -121,8 +123,8 @@ struct GreaterEqualFunctor template using GreaterEqualContigFunctor = elementwise_common::BinaryContigFunctor< argT1, @@ -191,11 +193,31 @@ template struct GreaterEqualOutputType static constexpr bool is_defined = !std::is_same_v; }; +namespace +{ + +namespace vsu_ns = dpctl::tensor::kernels::vec_size_utils; + +using vsu_ns::BinaryContigHyperparameterSetEntry; +using vsu_ns::ContigHyperparameterSetDefault; + +template +struct GreaterEqualContigHyperparameterSet +{ + using value_type 
= + typename std::disjunction>; + + constexpr static auto vec_sz = value_type::vec_sz; + constexpr static auto n_vecs = value_type::n_vecs; +}; + +} // end of anonymous namespace + template + std::uint8_t vec_sz, + std::uint8_t n_vecs> class greater_equal_contig_kernel; template @@ -210,11 +232,16 @@ greater_equal_contig_impl(sycl::queue &exec_q, ssize_t res_offset, const std::vector &depends = {}) { + constexpr std::uint8_t vec_sz = + GreaterEqualContigHyperparameterSet::vec_sz; + constexpr std::uint8_t n_vecs = + GreaterEqualContigHyperparameterSet::n_vecs; + return elementwise_common::binary_contig_impl< argTy1, argTy2, GreaterEqualOutputType, GreaterEqualContigFunctor, - greater_equal_contig_kernel>(exec_q, nelems, arg1_p, arg1_offset, - arg2_p, arg2_offset, res_p, res_offset, - depends); + greater_equal_contig_kernel, vec_sz, n_vecs>( + exec_q, nelems, arg1_p, arg1_offset, arg2_p, arg2_offset, res_p, + res_offset, depends); } template diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/hypot.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/hypot.hpp index c5b68644a9..f65951f36b 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/hypot.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/hypot.hpp @@ -29,6 +29,8 @@ #include #include +#include "vec_size_util.hpp" + #include "utils/offset_utils.hpp" #include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" @@ -84,8 +86,8 @@ template struct HypotFunctor template using HypotContigFunctor = elementwise_common::BinaryContigFunctor struct HypotOutputType static constexpr bool is_defined = !std::is_same_v; }; +namespace +{ + +namespace vsu_ns = dpctl::tensor::kernels::vec_size_utils; + +using vsu_ns::BinaryContigHyperparameterSetEntry; +using vsu_ns::ContigHyperparameterSetDefault; + +template struct HypotContigHyperparameterSet +{ + using value_type = + typename std::disjunction>; + + constexpr static auto vec_sz = 
value_type::vec_sz; + constexpr static auto n_vecs = value_type::n_vecs; +}; + +} // end of anonymous namespace + template + std::uint8_t vec_sz, + std::uint8_t n_vecs> class hypot_contig_kernel; template @@ -137,10 +158,16 @@ sycl::event hypot_contig_impl(sycl::queue &exec_q, ssize_t res_offset, const std::vector &depends = {}) { + constexpr std::uint8_t vec_sz = + HypotContigHyperparameterSet::vec_sz; + constexpr std::uint8_t n_vecs = + HypotContigHyperparameterSet::n_vecs; + return elementwise_common::binary_contig_impl< argTy1, argTy2, HypotOutputType, HypotContigFunctor, - hypot_contig_kernel>(exec_q, nelems, arg1_p, arg1_offset, arg2_p, - arg2_offset, res_p, res_offset, depends); + hypot_contig_kernel, vec_sz, n_vecs>(exec_q, nelems, arg1_p, + arg1_offset, arg2_p, arg2_offset, + res_p, res_offset, depends); } template struct HypotContigFactory diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/imag.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/imag.hpp index e918bc0ac7..1d33f83d27 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/imag.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/imag.hpp @@ -31,9 +31,11 @@ #include #include -#include "kernels/elementwise_functions/common.hpp" +#include "vec_size_util.hpp" #include "kernels/dpctl_tensor_types.hpp" +#include "kernels/elementwise_functions/common.hpp" + #include "utils/offset_utils.hpp" #include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" @@ -78,8 +80,8 @@ template struct ImagFunctor template using ImagContigFunctor = elementwise_common::UnaryContigFunctor struct ImagOutputType static constexpr bool is_defined = !std::is_same_v; }; -template +namespace +{ + +namespace vsu_ns = dpctl::tensor::kernels::vec_size_utils; + +using vsu_ns::ContigHyperparameterSetDefault; +using vsu_ns::UnaryContigHyperparameterSetEntry; + +template struct ImagContigHyperparameterSet +{ + using value_type = + typename 
std::disjunction>; + + constexpr static auto vec_sz = value_type::vec_sz; + constexpr static auto n_vecs = value_type::n_vecs; +}; + +} // end of anonymous namespace + +template class imag_contig_kernel; template @@ -125,9 +146,12 @@ sycl::event imag_contig_impl(sycl::queue &exec_q, char *res_p, const std::vector &depends = {}) { + constexpr std::uint8_t vec_sz = ImagContigHyperparameterSet::vec_sz; + constexpr std::uint8_t n_vecs = ImagContigHyperparameterSet::n_vecs; + return elementwise_common::unary_contig_impl< - argTy, ImagOutputType, ImagContigFunctor, imag_contig_kernel>( - exec_q, nelems, arg_p, res_p, depends); + argTy, ImagOutputType, ImagContigFunctor, imag_contig_kernel, vec_sz, + n_vecs>(exec_q, nelems, arg_p, res_p, depends); } template struct ImagContigFactory diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isfinite.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isfinite.hpp index df979eec76..067e3e36ee 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isfinite.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isfinite.hpp @@ -30,6 +30,8 @@ #include #include +#include "vec_size_util.hpp" + #include "utils/offset_utils.hpp" #include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" @@ -97,8 +99,8 @@ template struct IsFiniteFunctor template using IsFiniteContigFunctor = elementwise_common::UnaryContigFunctor struct IsFiniteOutputType using value_type = bool; }; -template +namespace +{ + +namespace vsu_ns = dpctl::tensor::kernels::vec_size_utils; + +using vsu_ns::ContigHyperparameterSetDefault; +using vsu_ns::UnaryContigHyperparameterSetEntry; + +template struct IsFiniteContigHyperparameterSet +{ + using value_type = + typename std::disjunction>; + + constexpr static auto vec_sz = value_type::vec_sz; + constexpr static auto n_vecs = value_type::n_vecs; +}; + +} // end of anonymous namespace + +template class isfinite_contig_kernel; template @@ 
-127,10 +148,15 @@ sycl::event isfinite_contig_impl(sycl::queue &exec_q, char *res_p, const std::vector &depends = {}) { - return elementwise_common::unary_contig_impl( - exec_q, nelems, arg_p, res_p, depends); + constexpr std::uint8_t vec_sz = + IsFiniteContigHyperparameterSet::vec_sz; + constexpr std::uint8_t n_vecs = + IsFiniteContigHyperparameterSet::n_vecs; + + return elementwise_common::unary_contig_impl< + argTy, IsFiniteOutputType, IsFiniteContigFunctor, + isfinite_contig_kernel, vec_sz, n_vecs>(exec_q, nelems, arg_p, res_p, + depends); } template struct IsFiniteContigFactory diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isinf.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isinf.hpp index 24be019a44..70069bdaa2 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isinf.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isinf.hpp @@ -30,7 +30,10 @@ #include #include +#include "vec_size_util.hpp" + #include "kernels/dpctl_tensor_types.hpp" + #include "utils/offset_utils.hpp" #include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" @@ -96,8 +99,8 @@ template struct IsInfFunctor template using IsInfContigFunctor = elementwise_common::UnaryContigFunctor struct IsInfOutputType using value_type = bool; }; -template +namespace +{ + +namespace vsu_ns = dpctl::tensor::kernels::vec_size_utils; + +using vsu_ns::ContigHyperparameterSetDefault; +using vsu_ns::UnaryContigHyperparameterSetEntry; + +template struct IsInfContigHyperparameterSet +{ + using value_type = + typename std::disjunction>; + + constexpr static auto vec_sz = value_type::vec_sz; + constexpr static auto n_vecs = value_type::n_vecs; +}; + +} // end of anonymous namespace + +template class isinf_contig_kernel; template @@ -126,9 +148,12 @@ sycl::event isinf_contig_impl(sycl::queue &exec_q, char *res_p, const std::vector &depends = {}) { + constexpr std::uint8_t vec_sz = 
IsInfContigHyperparameterSet::vec_sz; + constexpr std::uint8_t n_vecs = IsInfContigHyperparameterSet::n_vecs; + return elementwise_common::unary_contig_impl< - argTy, IsInfOutputType, IsInfContigFunctor, isinf_contig_kernel>( - exec_q, nelems, arg_p, res_p, depends); + argTy, IsInfOutputType, IsInfContigFunctor, isinf_contig_kernel, vec_sz, + n_vecs>(exec_q, nelems, arg_p, res_p, depends); } template struct IsInfContigFactory diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isnan.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isnan.hpp index cc452a25b1..0d8a15d0b8 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isnan.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isnan.hpp @@ -29,7 +29,10 @@ #include #include +#include "vec_size_util.hpp" + #include "kernels/dpctl_tensor_types.hpp" + #include "utils/offset_utils.hpp" #include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" @@ -94,8 +97,8 @@ template struct IsNanFunctor template using IsNanContigFunctor = elementwise_common::UnaryContigFunctor struct IsNanOutputType using value_type = bool; }; -template +namespace +{ + +namespace vsu_ns = dpctl::tensor::kernels::vec_size_utils; + +using vsu_ns::ContigHyperparameterSetDefault; +using vsu_ns::UnaryContigHyperparameterSetEntry; + +template struct IsNanContigHyperparameterSet +{ + using value_type = + typename std::disjunction>; + + constexpr static auto vec_sz = value_type::vec_sz; + constexpr static auto n_vecs = value_type::n_vecs; +}; + +} // end of anonymous namespace + +template class isnan_contig_kernel; template @@ -124,9 +146,12 @@ sycl::event isnan_contig_impl(sycl::queue &exec_q, char *res_p, const std::vector &depends = {}) { + constexpr std::uint8_t vec_sz = IsNanContigHyperparameterSet::vec_sz; + constexpr std::uint8_t n_vecs = IsNanContigHyperparameterSet::n_vecs; + return elementwise_common::unary_contig_impl< - argTy, IsNanOutputType, 
IsNanContigFunctor, isnan_contig_kernel>( - exec_q, nelems, arg_p, res_p, depends); + argTy, IsNanOutputType, IsNanContigFunctor, isnan_contig_kernel, vec_sz, + n_vecs>(exec_q, nelems, arg_p, res_p, depends); } template struct IsNanContigFactory diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/less.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/less.hpp index 0b26342563..43f11725b7 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/less.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/less.hpp @@ -29,12 +29,14 @@ #include #include -#include "kernels/dpctl_tensor_types.hpp" +#include "vec_size_util.hpp" + #include "utils/math_utils.hpp" #include "utils/offset_utils.hpp" #include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" +#include "kernels/dpctl_tensor_types.hpp" #include "kernels/elementwise_functions/common.hpp" namespace dpctl @@ -118,8 +120,8 @@ template struct LessFunctor template using LessContigFunctor = elementwise_common::BinaryContigFunctor struct LessOutputType static constexpr bool is_defined = !std::is_same_v; }; +namespace +{ + +namespace vsu_ns = dpctl::tensor::kernels::vec_size_utils; + +using vsu_ns::BinaryContigHyperparameterSetEntry; +using vsu_ns::ContigHyperparameterSetDefault; + +template struct LessContigHyperparameterSet +{ + using value_type = + typename std::disjunction>; + + constexpr static auto vec_sz = value_type::vec_sz; + constexpr static auto n_vecs = value_type::n_vecs; +}; + +} // end of anonymous namespace + template + std::uint8_t vec_sz, + std::uint8_t n_vecs> class less_contig_kernel; template @@ -206,10 +227,15 @@ sycl::event less_contig_impl(sycl::queue &exec_q, ssize_t res_offset, const std::vector &depends = {}) { + constexpr std::uint8_t vec_sz = + LessContigHyperparameterSet::vec_sz; + constexpr std::uint8_t n_vecs = + LessContigHyperparameterSet::n_vecs; + return elementwise_common::binary_contig_impl< - 
argTy1, argTy2, LessOutputType, LessContigFunctor, less_contig_kernel>( - exec_q, nelems, arg1_p, arg1_offset, arg2_p, arg2_offset, res_p, - res_offset, depends); + argTy1, argTy2, LessOutputType, LessContigFunctor, less_contig_kernel, + vec_sz, n_vecs>(exec_q, nelems, arg1_p, arg1_offset, arg2_p, + arg2_offset, res_p, res_offset, depends); } template struct LessContigFactory diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/less_equal.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/less_equal.hpp index 01289ae98f..81cc375c16 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/less_equal.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/less_equal.hpp @@ -30,6 +30,8 @@ #include #include +#include "vec_size_util.hpp" + #include "utils/math_utils.hpp" #include "utils/offset_utils.hpp" #include "utils/type_dispatch_building.hpp" @@ -119,8 +121,8 @@ template struct LessEqualFunctor template using LessEqualContigFunctor = elementwise_common::BinaryContigFunctor< argT1, @@ -189,11 +191,31 @@ template struct LessEqualOutputType static constexpr bool is_defined = !std::is_same_v; }; +namespace +{ + +namespace vsu_ns = dpctl::tensor::kernels::vec_size_utils; + +using vsu_ns::BinaryContigHyperparameterSetEntry; +using vsu_ns::ContigHyperparameterSetDefault; + +template +struct LessEqualContigHyperparameterSet +{ + using value_type = + typename std::disjunction>; + + constexpr static auto vec_sz = value_type::vec_sz; + constexpr static auto n_vecs = value_type::n_vecs; +}; + +} // end of anonymous namespace + template + std::uint8_t vec_sz, + std::uint8_t n_vecs> class less_equal_contig_kernel; template @@ -207,10 +229,16 @@ sycl::event less_equal_contig_impl(sycl::queue &exec_q, ssize_t res_offset, const std::vector &depends = {}) { + constexpr std::uint8_t vec_sz = + LessEqualContigHyperparameterSet::vec_sz; + constexpr std::uint8_t n_vecs = + LessEqualContigHyperparameterSet::n_vecs; + 
return elementwise_common::binary_contig_impl< argTy1, argTy2, LessEqualOutputType, LessEqualContigFunctor, - less_equal_contig_kernel>(exec_q, nelems, arg1_p, arg1_offset, arg2_p, - arg2_offset, res_p, res_offset, depends); + less_equal_contig_kernel, vec_sz, n_vecs>( + exec_q, nelems, arg1_p, arg1_offset, arg2_p, arg2_offset, res_p, + res_offset, depends); } template struct LessEqualContigFactory diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log.hpp index a3e28ef5d7..13eb64afca 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log.hpp @@ -30,10 +30,12 @@ #include #include -#include "kernels/elementwise_functions/common.hpp" #include "sycl_complex.hpp" +#include "vec_size_util.hpp" #include "kernels/dpctl_tensor_types.hpp" +#include "kernels/elementwise_functions/common.hpp" + #include "utils/offset_utils.hpp" #include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" @@ -78,8 +80,8 @@ template struct LogFunctor template using LogContigFunctor = elementwise_common::UnaryContigFunctor struct LogOutputType static constexpr bool is_defined = !std::is_same_v; }; -template +namespace +{ + +namespace vsu_ns = dpctl::tensor::kernels::vec_size_utils; + +using vsu_ns::ContigHyperparameterSetDefault; +using vsu_ns::UnaryContigHyperparameterSetEntry; + +template struct LogContigHyperparameterSet +{ + using value_type = + typename std::disjunction>; + + constexpr static auto vec_sz = value_type::vec_sz; + constexpr static auto n_vecs = value_type::n_vecs; +}; + +} // end of anonymous namespace + +template class log_contig_kernel; template @@ -117,9 +138,12 @@ sycl::event log_contig_impl(sycl::queue &exec_q, char *res_p, const std::vector &depends = {}) { + constexpr std::uint8_t vec_sz = LogContigHyperparameterSet::vec_sz; + constexpr std::uint8_t n_vecs = 
LogContigHyperparameterSet::n_vecs; + return elementwise_common::unary_contig_impl< - argTy, LogOutputType, LogContigFunctor, log_contig_kernel>( - exec_q, nelems, arg_p, res_p, depends); + argTy, LogOutputType, LogContigFunctor, log_contig_kernel, vec_sz, + n_vecs>(exec_q, nelems, arg_p, res_p, depends); } template struct LogContigFactory diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log10.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log10.hpp index 793b910f69..ea486239e7 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log10.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log10.hpp @@ -31,10 +31,12 @@ #include #include -#include "kernels/elementwise_functions/common.hpp" #include "sycl_complex.hpp" +#include "vec_size_util.hpp" #include "kernels/dpctl_tensor_types.hpp" +#include "kernels/elementwise_functions/common.hpp" + #include "utils/offset_utils.hpp" #include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" @@ -97,8 +99,8 @@ template struct Log10Functor template using Log10ContigFunctor = elementwise_common::UnaryContigFunctor struct Log10OutputType static constexpr bool is_defined = !std::is_same_v; }; -template +namespace +{ + +namespace vsu_ns = dpctl::tensor::kernels::vec_size_utils; + +using vsu_ns::ContigHyperparameterSetDefault; +using vsu_ns::UnaryContigHyperparameterSetEntry; + +template struct Log10ContigHyperparameterSet +{ + using value_type = + typename std::disjunction>; + + constexpr static auto vec_sz = value_type::vec_sz; + constexpr static auto n_vecs = value_type::n_vecs; +}; + +} // end of anonymous namespace + +template class log10_contig_kernel; template @@ -136,9 +157,12 @@ sycl::event log10_contig_impl(sycl::queue &exec_q, char *res_p, const std::vector &depends = {}) { + constexpr std::uint8_t vec_sz = Log10ContigHyperparameterSet::vec_sz; + constexpr std::uint8_t n_vecs = Log10ContigHyperparameterSet::n_vecs; + return 
elementwise_common::unary_contig_impl< - argTy, Log10OutputType, Log10ContigFunctor, log10_contig_kernel>( - exec_q, nelems, arg_p, res_p, depends); + argTy, Log10OutputType, Log10ContigFunctor, log10_contig_kernel, vec_sz, + n_vecs>(exec_q, nelems, arg_p, res_p, depends); } template struct Log10ContigFactory diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log1p.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log1p.hpp index 19238e7e37..3df38d05f0 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log1p.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log1p.hpp @@ -30,9 +30,11 @@ #include #include -#include "kernels/elementwise_functions/common.hpp" +#include "vec_size_util.hpp" #include "kernels/dpctl_tensor_types.hpp" +#include "kernels/elementwise_functions/common.hpp" + #include "utils/offset_utils.hpp" #include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" @@ -102,8 +104,8 @@ template struct Log1pFunctor template using Log1pContigFunctor = elementwise_common::UnaryContigFunctor struct Log1pOutputType static constexpr bool is_defined = !std::is_same_v; }; -template +namespace +{ + +namespace vsu_ns = dpctl::tensor::kernels::vec_size_utils; + +using vsu_ns::ContigHyperparameterSetDefault; +using vsu_ns::UnaryContigHyperparameterSetEntry; + +template struct Log1pContigHyperparameterSet +{ + using value_type = + typename std::disjunction>; + + constexpr static auto vec_sz = value_type::vec_sz; + constexpr static auto n_vecs = value_type::n_vecs; +}; + +} // end of anonymous namespace + +template class log1p_contig_kernel; template @@ -141,9 +162,12 @@ sycl::event log1p_contig_impl(sycl::queue &exec_q, char *res_p, const std::vector &depends = {}) { + constexpr std::uint8_t vec_sz = Log1pContigHyperparameterSet::vec_sz; + constexpr std::uint8_t n_vecs = Log1pContigHyperparameterSet::n_vecs; + return elementwise_common::unary_contig_impl< - argTy, 
Log1pOutputType, Log1pContigFunctor, log1p_contig_kernel>( - exec_q, nelems, arg_p, res_p, depends); + argTy, Log1pOutputType, Log1pContigFunctor, log1p_contig_kernel, vec_sz, + n_vecs>(exec_q, nelems, arg_p, res_p, depends); } template struct Log1pContigFactory diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log2.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log2.hpp index 69d0022c72..2da4c55de0 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log2.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log2.hpp @@ -31,10 +31,12 @@ #include #include -#include "kernels/elementwise_functions/common.hpp" #include "sycl_complex.hpp" +#include "vec_size_util.hpp" #include "kernels/dpctl_tensor_types.hpp" +#include "kernels/elementwise_functions/common.hpp" + #include "utils/offset_utils.hpp" #include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" @@ -98,8 +100,8 @@ template struct Log2Functor template using Log2ContigFunctor = elementwise_common::UnaryContigFunctor struct Log2OutputType static constexpr bool is_defined = !std::is_same_v; }; -template +namespace +{ + +namespace vsu_ns = dpctl::tensor::kernels::vec_size_utils; + +using vsu_ns::ContigHyperparameterSetDefault; +using vsu_ns::UnaryContigHyperparameterSetEntry; + +template struct Log2ContigHyperparameterSet +{ + using value_type = + typename std::disjunction>; + + constexpr static auto vec_sz = value_type::vec_sz; + constexpr static auto n_vecs = value_type::n_vecs; +}; + +} // end of anonymous namespace + +template class log2_contig_kernel; template @@ -137,9 +158,12 @@ sycl::event log2_contig_impl(sycl::queue &exec_q, char *res_p, const std::vector &depends = {}) { + constexpr std::uint8_t vec_sz = Log2ContigHyperparameterSet::vec_sz; + constexpr std::uint8_t n_vecs = Log2ContigHyperparameterSet::n_vecs; + return elementwise_common::unary_contig_impl< - argTy, Log2OutputType, Log2ContigFunctor, 
log2_contig_kernel>( - exec_q, nelems, arg_p, res_p, depends); + argTy, Log2OutputType, Log2ContigFunctor, log2_contig_kernel, vec_sz, + n_vecs>(exec_q, nelems, arg_p, res_p, depends); } template struct Log2ContigFactory diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/logaddexp.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/logaddexp.hpp index b0be45ea54..6d2375c20d 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/logaddexp.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/logaddexp.hpp @@ -31,12 +31,14 @@ #include #include -#include "kernels/dpctl_tensor_types.hpp" +#include "vec_size_util.hpp" + #include "utils/math_utils.hpp" #include "utils/offset_utils.hpp" #include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" +#include "kernels/dpctl_tensor_types.hpp" #include "kernels/elementwise_functions/common.hpp" namespace dpctl @@ -99,8 +101,8 @@ template struct LogAddExpFunctor template using LogAddExpContigFunctor = elementwise_common::BinaryContigFunctor< argT1, @@ -134,11 +136,31 @@ template struct LogAddExpOutputType static constexpr bool is_defined = !std::is_same_v; }; +namespace +{ + +namespace vsu_ns = dpctl::tensor::kernels::vec_size_utils; + +using vsu_ns::BinaryContigHyperparameterSetEntry; +using vsu_ns::ContigHyperparameterSetDefault; + +template +struct LogAddExpContigHyperparameterSet +{ + using value_type = + typename std::disjunction>; + + constexpr static auto vec_sz = value_type::vec_sz; + constexpr static auto n_vecs = value_type::n_vecs; +}; + +} // end of anonymous namespace + template + std::uint8_t vec_sz, + std::uint8_t n_vecs> class logaddexp_contig_kernel; template @@ -152,10 +174,16 @@ sycl::event logaddexp_contig_impl(sycl::queue &exec_q, ssize_t res_offset, const std::vector &depends = {}) { + constexpr std::uint8_t vec_sz = + LogAddExpContigHyperparameterSet::vec_sz; + constexpr std::uint8_t n_vecs = + 
LogAddExpContigHyperparameterSet::n_vecs; + return elementwise_common::binary_contig_impl< argTy1, argTy2, LogAddExpOutputType, LogAddExpContigFunctor, - logaddexp_contig_kernel>(exec_q, nelems, arg1_p, arg1_offset, arg2_p, - arg2_offset, res_p, res_offset, depends); + logaddexp_contig_kernel, vec_sz, n_vecs>( + exec_q, nelems, arg1_p, arg1_offset, arg2_p, arg2_offset, res_p, + res_offset, depends); } template struct LogAddExpContigFactory diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/logical_and.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/logical_and.hpp index f15caa02e6..768ace7754 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/logical_and.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/logical_and.hpp @@ -30,11 +30,13 @@ #include #include -#include "kernels/dpctl_tensor_types.hpp" +#include "vec_size_util.hpp" + #include "utils/offset_utils.hpp" #include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" +#include "kernels/dpctl_tensor_types.hpp" #include "kernels/elementwise_functions/common.hpp" namespace dpctl @@ -93,8 +95,8 @@ struct LogicalAndFunctor template using LogicalAndContigFunctor = elementwise_common::BinaryContigFunctor< argT1, @@ -159,11 +161,31 @@ template struct LogicalAndOutputType static constexpr bool is_defined = !std::is_same_v; }; +namespace +{ + +namespace vsu_ns = dpctl::tensor::kernels::vec_size_utils; + +using vsu_ns::BinaryContigHyperparameterSetEntry; +using vsu_ns::ContigHyperparameterSetDefault; + +template +struct LogicalAndContigHyperparameterSet +{ + using value_type = + typename std::disjunction>; + + constexpr static auto vec_sz = value_type::vec_sz; + constexpr static auto n_vecs = value_type::n_vecs; +}; + +} // end of anonymous namespace + template + std::uint8_t vec_sz, + std::uint8_t n_vecs> class logical_and_contig_kernel; template @@ -178,10 +200,16 @@ logical_and_contig_impl(sycl::queue &exec_q, ssize_t 
res_offset, const std::vector &depends = {}) { + constexpr std::uint8_t vec_sz = + LogicalAndContigHyperparameterSet::vec_sz; + constexpr std::uint8_t n_vecs = + LogicalAndContigHyperparameterSet::n_vecs; + return elementwise_common::binary_contig_impl< argTy1, argTy2, LogicalAndOutputType, LogicalAndContigFunctor, - logical_and_contig_kernel>(exec_q, nelems, arg1_p, arg1_offset, arg2_p, - arg2_offset, res_p, res_offset, depends); + logical_and_contig_kernel, vec_sz, n_vecs>( + exec_q, nelems, arg1_p, arg1_offset, arg2_p, arg2_offset, res_p, + res_offset, depends); } template struct LogicalAndContigFactory diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/logical_not.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/logical_not.hpp index 7c83e07072..53c5404caa 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/logical_not.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/logical_not.hpp @@ -30,7 +30,10 @@ #include #include +#include "vec_size_util.hpp" + #include "kernels/dpctl_tensor_types.hpp" + #include "utils/offset_utils.hpp" #include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" @@ -66,8 +69,8 @@ template struct LogicalNotFunctor template using LogicalNotContigFunctor = elementwise_common::UnaryContigFunctor struct LogicalNotOutputType using value_type = bool; }; -template +namespace +{ + +namespace vsu_ns = dpctl::tensor::kernels::vec_size_utils; + +using vsu_ns::ContigHyperparameterSetDefault; +using vsu_ns::UnaryContigHyperparameterSetEntry; + +template struct LogicalNotContigHyperparameterSet +{ + using value_type = + typename std::disjunction>; + + constexpr static auto vec_sz = value_type::vec_sz; + constexpr static auto n_vecs = value_type::n_vecs; +}; + +} // end of anonymous namespace + +template class logical_not_contig_kernel; template @@ -100,10 +122,15 @@ logical_not_contig_impl(sycl::queue &exec_q, char *res_p, const std::vector &depends = 
{}) { - return elementwise_common::unary_contig_impl( - exec_q, nelems, arg_p, res_p, depends); + constexpr std::uint8_t vec_sz = + LogicalNotContigHyperparameterSet::vec_sz; + constexpr std::uint8_t n_vecs = + LogicalNotContigHyperparameterSet::n_vecs; + + return elementwise_common::unary_contig_impl< + argTy, LogicalNotOutputType, LogicalNotContigFunctor, + logical_not_contig_kernel, vec_sz, n_vecs>(exec_q, nelems, arg_p, res_p, + depends); } template struct LogicalNotContigFactory diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/logical_or.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/logical_or.hpp index 43e02f2102..93c5f3b9a6 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/logical_or.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/logical_or.hpp @@ -30,11 +30,13 @@ #include #include -#include "kernels/dpctl_tensor_types.hpp" +#include "vec_size_util.hpp" + #include "utils/offset_utils.hpp" #include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" +#include "kernels/dpctl_tensor_types.hpp" #include "kernels/elementwise_functions/common.hpp" namespace dpctl @@ -92,8 +94,8 @@ template struct LogicalOrFunctor template using LogicalOrContigFunctor = elementwise_common::BinaryContigFunctor< argT1, @@ -158,11 +160,31 @@ template struct LogicalOrOutputType static constexpr bool is_defined = !std::is_same_v; }; +namespace +{ + +namespace vsu_ns = dpctl::tensor::kernels::vec_size_utils; + +using vsu_ns::BinaryContigHyperparameterSetEntry; +using vsu_ns::ContigHyperparameterSetDefault; + +template +struct LogicalOrContigHyperparameterSet +{ + using value_type = + typename std::disjunction>; + + constexpr static auto vec_sz = value_type::vec_sz; + constexpr static auto n_vecs = value_type::n_vecs; +}; + +} // end of anonymous namespace + template + std::uint8_t vec_sz, + std::uint8_t n_vecs> class logical_or_contig_kernel; template @@ -176,10 +198,16 @@ 
sycl::event logical_or_contig_impl(sycl::queue &exec_q, ssize_t res_offset, const std::vector &depends = {}) { + constexpr std::uint8_t vec_sz = + LogicalOrContigHyperparameterSet::vec_sz; + constexpr std::uint8_t n_vecs = + LogicalOrContigHyperparameterSet::n_vecs; + return elementwise_common::binary_contig_impl< argTy1, argTy2, LogicalOrOutputType, LogicalOrContigFunctor, - logical_or_contig_kernel>(exec_q, nelems, arg1_p, arg1_offset, arg2_p, - arg2_offset, res_p, res_offset, depends); + logical_or_contig_kernel, vec_sz, n_vecs>( + exec_q, nelems, arg1_p, arg1_offset, arg2_p, arg2_offset, res_p, + res_offset, depends); } template struct LogicalOrContigFactory diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/logical_xor.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/logical_xor.hpp index dc41760985..9ff54b6f16 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/logical_xor.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/logical_xor.hpp @@ -30,11 +30,13 @@ #include #include -#include "kernels/dpctl_tensor_types.hpp" +#include "vec_size_util.hpp" + #include "utils/offset_utils.hpp" #include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" +#include "kernels/dpctl_tensor_types.hpp" #include "kernels/elementwise_functions/common.hpp" namespace dpctl @@ -94,8 +96,8 @@ struct LogicalXorFunctor template using LogicalXorContigFunctor = elementwise_common::BinaryContigFunctor< argT1, @@ -160,11 +162,31 @@ template struct LogicalXorOutputType static constexpr bool is_defined = !std::is_same_v; }; +namespace +{ + +namespace vsu_ns = dpctl::tensor::kernels::vec_size_utils; + +using vsu_ns::BinaryContigHyperparameterSetEntry; +using vsu_ns::ContigHyperparameterSetDefault; + +template +struct LogicalXorContigHyperparameterSet +{ + using value_type = + typename std::disjunction>; + + constexpr static auto vec_sz = value_type::vec_sz; + constexpr static auto n_vecs = 
value_type::n_vecs; +}; + +} // end of anonymous namespace + template + std::uint8_t vec_sz, + std::uint8_t n_vecs> class logical_xor_contig_kernel; template @@ -179,10 +201,16 @@ logical_xor_contig_impl(sycl::queue &exec_q, ssize_t res_offset, const std::vector &depends = {}) { + constexpr std::uint8_t vec_sz = + LogicalXorContigHyperparameterSet::vec_sz; + constexpr std::uint8_t n_vecs = + LogicalXorContigHyperparameterSet::n_vecs; + return elementwise_common::binary_contig_impl< argTy1, argTy2, LogicalXorOutputType, LogicalXorContigFunctor, - logical_xor_contig_kernel>(exec_q, nelems, arg1_p, arg1_offset, arg2_p, - arg2_offset, res_p, res_offset, depends); + logical_xor_contig_kernel, vec_sz, n_vecs>( + exec_q, nelems, arg1_p, arg1_offset, arg2_p, arg2_offset, res_p, + res_offset, depends); } template struct LogicalXorContigFactory diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/maximum.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/maximum.hpp index e73704bad8..ed44b8ade7 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/maximum.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/maximum.hpp @@ -29,12 +29,14 @@ #include #include -#include "kernels/dpctl_tensor_types.hpp" +#include "vec_size_util.hpp" + #include "utils/math_utils.hpp" #include "utils/offset_utils.hpp" #include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" +#include "kernels/dpctl_tensor_types.hpp" #include "kernels/elementwise_functions/common.hpp" namespace dpctl @@ -70,9 +72,13 @@ template struct MaximumFunctor } else if constexpr (std::is_floating_point_v || std::is_same_v) - return (std::isnan(in1) || in1 > in2) ? in1 : in2; - else + { + const bool choose_first = (std::isnan(in1) || (in1 > in2)); + return (choose_first) ? in1 : in2; + } + else { return (in1 > in2) ? 
in1 : in2; + } } template @@ -83,11 +89,17 @@ template struct MaximumFunctor sycl::vec res; #pragma unroll for (int i = 0; i < vec_sz; ++i) { - if constexpr (std::is_floating_point_v) - res[i] = - (sycl::isnan(in1[i]) || in1[i] > in2[i]) ? in1[i] : in2[i]; - else - res[i] = (in1[i] > in2[i]) ? in1[i] : in2[i]; + const auto &v1 = in1[i]; + const auto &v2 = in2[i]; + if constexpr (std::is_floating_point_v || + std::is_same_v) + { + const bool choose_first = (std::isnan(v1) || (v1 > v2)); + res[i] = (choose_first) ? v1 : v2; + } + else { + res[i] = (v1 > v2) ? v1 : v2; + } } return res; } @@ -96,8 +108,8 @@ template struct MaximumFunctor template using MaximumContigFunctor = elementwise_common::BinaryContigFunctor struct MaximumOutputType static constexpr bool is_defined = !std::is_same_v; }; +namespace +{ + +namespace vsu_ns = dpctl::tensor::kernels::vec_size_utils; + +using vsu_ns::BinaryContigHyperparameterSetEntry; +using vsu_ns::ContigHyperparameterSetDefault; + +template +struct MaximumContigHyperparameterSet +{ + using value_type = + typename std::disjunction>; + + constexpr static auto vec_sz = value_type::vec_sz; + constexpr static auto n_vecs = value_type::n_vecs; +}; + +} // end of anonymous namespace + template + std::uint8_t vec_sz, + std::uint8_t n_vecs> class maximum_contig_kernel; template @@ -200,10 +232,16 @@ sycl::event maximum_contig_impl(sycl::queue &exec_q, ssize_t res_offset, const std::vector &depends = {}) { + constexpr std::uint8_t vec_sz = + MaximumContigHyperparameterSet::vec_sz; + constexpr std::uint8_t n_vecs = + MaximumContigHyperparameterSet::n_vecs; + return elementwise_common::binary_contig_impl< argTy1, argTy2, MaximumOutputType, MaximumContigFunctor, - maximum_contig_kernel>(exec_q, nelems, arg1_p, arg1_offset, arg2_p, - arg2_offset, res_p, res_offset, depends); + maximum_contig_kernel, vec_sz, n_vecs>(exec_q, nelems, arg1_p, + arg1_offset, arg2_p, arg2_offset, + res_p, res_offset, depends); } template struct MaximumContigFactory 
diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/minimum.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/minimum.hpp index 590c0b6486..551daf0498 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/minimum.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/minimum.hpp @@ -29,12 +29,14 @@ #include #include -#include "kernels/dpctl_tensor_types.hpp" +#include "vec_size_util.hpp" + #include "utils/math_utils.hpp" #include "utils/offset_utils.hpp" #include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" +#include "kernels/dpctl_tensor_types.hpp" #include "kernels/elementwise_functions/common.hpp" namespace dpctl @@ -70,9 +72,13 @@ template struct MinimumFunctor } else if constexpr (std::is_floating_point_v || std::is_same_v) - return (std::isnan(in1) || in1 < in2) ? in1 : in2; - else + { + const bool choose_first = sycl::isnan(in1) || (in1 < in2); + return (choose_first) ? in1 : in2; + } + else { return (in1 < in2) ? in1 : in2; + } } template @@ -83,11 +89,17 @@ template struct MinimumFunctor sycl::vec res; #pragma unroll for (int i = 0; i < vec_sz; ++i) { - if constexpr (std::is_floating_point_v) - res[i] = - (sycl::isnan(in1[i]) || in1[i] < in2[i]) ? in1[i] : in2[i]; - else - res[i] = (in1[i] < in2[i]) ? in1[i] : in2[i]; + const auto &v1 = in1[i]; + const auto &v2 = in2[i]; + if constexpr (std::is_floating_point_v || + std::is_same_v) + { + const bool choose_first = sycl::isnan(v1) || (v1 < v2); + res[i] = (choose_first) ? v1 : v2; + } + else { + res[i] = (v1 < v2) ? 
v1 : v2; + } } return res; } @@ -96,8 +108,8 @@ template struct MinimumFunctor template using MinimumContigFunctor = elementwise_common::BinaryContigFunctor struct MinimumOutputType static constexpr bool is_defined = !std::is_same_v; }; +namespace +{ + +namespace vsu_ns = dpctl::tensor::kernels::vec_size_utils; + +using vsu_ns::BinaryContigHyperparameterSetEntry; +using vsu_ns::ContigHyperparameterSetDefault; + +template +struct MinimumContigHyperparameterSet +{ + using value_type = + typename std::disjunction>; + + constexpr static auto vec_sz = value_type::vec_sz; + constexpr static auto n_vecs = value_type::n_vecs; +}; + +} // end of anonymous namespace + template + std::uint8_t vec_sz, + std::uint8_t n_vecs> class minimum_contig_kernel; template @@ -200,10 +232,16 @@ sycl::event minimum_contig_impl(sycl::queue &exec_q, ssize_t res_offset, const std::vector &depends = {}) { + constexpr std::uint8_t vec_sz = + MinimumContigHyperparameterSet::vec_sz; + constexpr std::uint8_t n_vecs = + MinimumContigHyperparameterSet::n_vecs; + return elementwise_common::binary_contig_impl< argTy1, argTy2, MinimumOutputType, MinimumContigFunctor, - minimum_contig_kernel>(exec_q, nelems, arg1_p, arg1_offset, arg2_p, - arg2_offset, res_p, res_offset, depends); + minimum_contig_kernel, vec_sz, n_vecs>(exec_q, nelems, arg1_p, + arg1_offset, arg2_p, arg2_offset, + res_p, res_offset, depends); } template struct MinimumContigFactory diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/multiply.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/multiply.hpp index 1af284f55b..37b3803c27 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/multiply.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/multiply.hpp @@ -30,12 +30,14 @@ #include #include -#include "kernels/dpctl_tensor_types.hpp" #include "sycl_complex.hpp" +#include "vec_size_util.hpp" + #include "utils/offset_utils.hpp" #include 
"utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" +#include "kernels/dpctl_tensor_types.hpp" #include "kernels/elementwise_functions/common.hpp" #include "kernels/elementwise_functions/common_inplace.hpp" @@ -98,8 +100,8 @@ template struct MultiplyFunctor template using MultiplyContigFunctor = elementwise_common::BinaryContigFunctor struct MultiplyOutputType static constexpr bool is_defined = !std::is_same_v; }; +namespace +{ + +namespace vsu_ns = dpctl::tensor::kernels::vec_size_utils; + +using vsu_ns::BinaryContigHyperparameterSetEntry; +using vsu_ns::ContigHyperparameterSetDefault; + +template +struct MultiplyContigHyperparameterSet +{ + using value_type = + typename std::disjunction>; + + constexpr static auto vec_sz = value_type::vec_sz; + constexpr static auto n_vecs = value_type::n_vecs; +}; + +} // end of anonymous namespace + template + std::uint8_t vec_sz, + std::uint8_t n_vecs> class multiply_contig_kernel; template @@ -202,10 +224,16 @@ sycl::event multiply_contig_impl(sycl::queue &exec_q, ssize_t res_offset, const std::vector &depends = {}) { + constexpr std::uint8_t vec_sz = + MultiplyContigHyperparameterSet::vec_sz; + constexpr std::uint8_t n_vecs = + MultiplyContigHyperparameterSet::n_vecs; + return elementwise_common::binary_contig_impl< argTy1, argTy2, MultiplyOutputType, MultiplyContigFunctor, - multiply_contig_kernel>(exec_q, nelems, arg1_p, arg1_offset, arg2_p, - arg2_offset, res_p, res_offset, depends); + multiply_contig_kernel, vec_sz, n_vecs>( + exec_q, nelems, arg1_p, arg1_offset, arg2_p, arg2_offset, res_p, + res_offset, depends); } template struct MultiplyContigFactory @@ -402,8 +430,8 @@ template struct MultiplyInplaceFunctor template using MultiplyInplaceContigFunctor = elementwise_common::BinaryInplaceContigFunctor< @@ -424,8 +452,8 @@ using MultiplyInplaceStridedFunctor = template + std::uint8_t vec_sz, + std::uint8_t n_vecs> class multiply_inplace_contig_kernel; /* @brief Types supported by in-place multiplication */ 
@@ -482,10 +510,15 @@ multiply_inplace_contig_impl(sycl::queue &exec_q, ssize_t res_offset, const std::vector &depends = {}) { + constexpr std::uint8_t vec_sz = + MultiplyContigHyperparameterSet::vec_sz; + constexpr std::uint8_t n_vecs = + MultiplyContigHyperparameterSet::n_vecs; + return elementwise_common::binary_inplace_contig_impl< argTy, resTy, MultiplyInplaceContigFunctor, - multiply_inplace_contig_kernel>(exec_q, nelems, arg_p, arg_offset, - res_p, res_offset, depends); + multiply_inplace_contig_kernel, vec_sz, n_vecs>( + exec_q, nelems, arg_p, arg_offset, res_p, res_offset, depends); } template diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/negative.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/negative.hpp index 83f17dd47b..a036158ccd 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/negative.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/negative.hpp @@ -30,9 +30,11 @@ #include #include -#include "kernels/elementwise_functions/common.hpp" +#include "vec_size_util.hpp" #include "kernels/dpctl_tensor_types.hpp" +#include "kernels/elementwise_functions/common.hpp" + #include "utils/offset_utils.hpp" #include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" @@ -65,8 +67,8 @@ template struct NegativeFunctor template using NegativeContigFunctor = elementwise_common::UnaryContigFunctor struct NegativeOutputType static constexpr bool is_defined = !std::is_same_v; }; -template +namespace +{ + +namespace vsu_ns = dpctl::tensor::kernels::vec_size_utils; + +using vsu_ns::ContigHyperparameterSetDefault; +using vsu_ns::UnaryContigHyperparameterSetEntry; + +template struct NegativeContigHyperparameterSet +{ + using value_type = + typename std::disjunction>; + + constexpr static auto vec_sz = value_type::vec_sz; + constexpr static auto n_vecs = value_type::n_vecs; +}; + +} // end of anonymous namespace + +template class negative_contig_kernel; template @@ 
-107,10 +128,15 @@ sycl::event negative_contig_impl(sycl::queue &exec_q, char *res_p, const std::vector &depends = {}) { - return elementwise_common::unary_contig_impl( - exec_q, nelems, arg_p, res_p, depends); + constexpr std::uint8_t vec_sz = + NegativeContigHyperparameterSet::vec_sz; + constexpr std::uint8_t n_vecs = + NegativeContigHyperparameterSet::n_vecs; + + return elementwise_common::unary_contig_impl< + argTy, NegativeOutputType, NegativeContigFunctor, + negative_contig_kernel, vec_sz, n_vecs>(exec_q, nelems, arg_p, res_p, + depends); } template struct NegativeContigFactory diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/nextafter.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/nextafter.hpp index 5dc9ea40b3..b58b1b98ef 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/nextafter.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/nextafter.hpp @@ -29,6 +29,8 @@ #include #include +#include "vec_size_util.hpp" + #include "utils/offset_utils.hpp" #include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" @@ -82,8 +84,8 @@ template struct NextafterFunctor template using NextafterContigFunctor = elementwise_common::BinaryContigFunctor< argT1, @@ -117,11 +119,31 @@ template struct NextafterOutputType static constexpr bool is_defined = !std::is_same_v; }; +namespace +{ + +namespace vsu_ns = dpctl::tensor::kernels::vec_size_utils; + +using vsu_ns::BinaryContigHyperparameterSetEntry; +using vsu_ns::ContigHyperparameterSetDefault; + +template +struct NextafterContigHyperparameterSet +{ + using value_type = + typename std::disjunction>; + + constexpr static auto vec_sz = value_type::vec_sz; + constexpr static auto n_vecs = value_type::n_vecs; +}; + +} // end of anonymous namespace + template + std::uint8_t vec_sz, + std::uint8_t n_vecs> class nextafter_contig_kernel; template @@ -135,10 +157,16 @@ sycl::event nextafter_contig_impl(sycl::queue &exec_q, ssize_t 
res_offset, const std::vector &depends = {}) { + constexpr std::uint8_t vec_sz = + NextafterContigHyperparameterSet::vec_sz; + constexpr std::uint8_t n_vecs = + NextafterContigHyperparameterSet::n_vecs; + return elementwise_common::binary_contig_impl< argTy1, argTy2, NextafterOutputType, NextafterContigFunctor, - nextafter_contig_kernel>(exec_q, nelems, arg1_p, arg1_offset, arg2_p, - arg2_offset, res_p, res_offset, depends); + nextafter_contig_kernel, vec_sz, n_vecs>( + exec_q, nelems, arg1_p, arg1_offset, arg2_p, arg2_offset, res_p, + res_offset, depends); } template struct NextafterContigFactory diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/not_equal.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/not_equal.hpp index c1b920193b..be1231648c 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/not_equal.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/not_equal.hpp @@ -29,11 +29,13 @@ #include #include -#include "kernels/dpctl_tensor_types.hpp" +#include "vec_size_util.hpp" + #include "utils/offset_utils.hpp" #include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" +#include "kernels/dpctl_tensor_types.hpp" #include "kernels/elementwise_functions/common.hpp" namespace dpctl @@ -103,8 +105,8 @@ template struct NotEqualFunctor template using NotEqualContigFunctor = elementwise_common::BinaryContigFunctor struct NotEqualOutputType static constexpr bool is_defined = !std::is_same_v; }; +namespace +{ + +namespace vsu_ns = dpctl::tensor::kernels::vec_size_utils; + +using vsu_ns::BinaryContigHyperparameterSetEntry; +using vsu_ns::ContigHyperparameterSetDefault; + +template +struct NotEqualContigHyperparameterSet +{ + using value_type = + typename std::disjunction>; + + constexpr static auto vec_sz = value_type::vec_sz; + constexpr static auto n_vecs = value_type::n_vecs; +}; + +} // end of anonymous namespace + template + std::uint8_t vec_sz, + std::uint8_t n_vecs> 
class not_equal_contig_kernel; template @@ -191,10 +213,16 @@ sycl::event not_equal_contig_impl(sycl::queue &exec_q, ssize_t res_offset, const std::vector &depends = {}) { + constexpr std::uint8_t vec_sz = + NotEqualContigHyperparameterSet::vec_sz; + constexpr std::uint8_t n_vecs = + NotEqualContigHyperparameterSet::n_vecs; + return elementwise_common::binary_contig_impl< argTy1, argTy2, NotEqualOutputType, NotEqualContigFunctor, - not_equal_contig_kernel>(exec_q, nelems, arg1_p, arg1_offset, arg2_p, - arg2_offset, res_p, res_offset, depends); + not_equal_contig_kernel, vec_sz, n_vecs>( + exec_q, nelems, arg1_p, arg1_offset, arg2_p, arg2_offset, res_p, + res_offset, depends); } template struct NotEqualContigFactory diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/positive.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/positive.hpp index ae2711ed0e..3ccca611d8 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/positive.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/positive.hpp @@ -30,9 +30,11 @@ #include #include -#include "kernels/elementwise_functions/common.hpp" +#include "vec_size_util.hpp" #include "kernels/dpctl_tensor_types.hpp" +#include "kernels/elementwise_functions/common.hpp" + #include "utils/offset_utils.hpp" #include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" @@ -80,8 +82,8 @@ template struct PositiveFunctor template using PositiveContigFunctor = elementwise_common::UnaryContigFunctor struct PositiveOutputType static constexpr bool is_defined = !std::is_same_v; }; -template +namespace +{ + +namespace vsu_ns = dpctl::tensor::kernels::vec_size_utils; + +using vsu_ns::ContigHyperparameterSetDefault; +using vsu_ns::UnaryContigHyperparameterSetEntry; + +template struct PositiveContigHyperparameterSet +{ + using value_type = + typename std::disjunction>; + + constexpr static auto vec_sz = value_type::vec_sz; + constexpr static auto n_vecs = 
value_type::n_vecs; +}; + +} // end of anonymous namespace + +template class positive_contig_kernel; template @@ -122,10 +143,15 @@ sycl::event positive_contig_impl(sycl::queue &exec_q, char *res_p, const std::vector &depends = {}) { - return elementwise_common::unary_contig_impl( - exec_q, nelems, arg_p, res_p, depends); + constexpr std::uint8_t vec_sz = + PositiveContigHyperparameterSet::vec_sz; + constexpr std::uint8_t n_vecs = + PositiveContigHyperparameterSet::n_vecs; + + return elementwise_common::unary_contig_impl< + argTy, PositiveOutputType, PositiveContigFunctor, + positive_contig_kernel, vec_sz, n_vecs>(exec_q, nelems, arg_p, res_p, + depends); } template struct PositiveContigFactory diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/pow.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/pow.hpp index bb462dceae..353e516d28 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/pow.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/pow.hpp @@ -30,12 +30,14 @@ #include #include -#include "kernels/dpctl_tensor_types.hpp" #include "sycl_complex.hpp" +#include "vec_size_util.hpp" + #include "utils/offset_utils.hpp" #include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" +#include "kernels/dpctl_tensor_types.hpp" #include "kernels/elementwise_functions/common.hpp" #include "kernels/elementwise_functions/common_inplace.hpp" @@ -151,8 +153,8 @@ template struct PowFunctor template using PowContigFunctor = elementwise_common::BinaryContigFunctor struct PowOutputType static constexpr bool is_defined = !std::is_same_v; }; +namespace +{ + +namespace vsu_ns = dpctl::tensor::kernels::vec_size_utils; + +using vsu_ns::BinaryContigHyperparameterSetEntry; +using vsu_ns::ContigHyperparameterSetDefault; + +template struct PowContigHyperparameterSet +{ + using value_type = + typename std::disjunction>; + + constexpr static auto vec_sz = value_type::vec_sz; + constexpr static 
auto n_vecs = value_type::n_vecs; +}; + +} // end of anonymous namespace + template + std::uint8_t vec_sz, + std::uint8_t n_vecs> class pow_contig_kernel; template @@ -254,10 +275,15 @@ sycl::event pow_contig_impl(sycl::queue &exec_q, ssize_t res_offset, const std::vector &depends = {}) { + constexpr std::uint8_t vec_sz = + PowContigHyperparameterSet::vec_sz; + constexpr std::uint8_t n_vecs = + PowContigHyperparameterSet::n_vecs; + return elementwise_common::binary_contig_impl< - argTy1, argTy2, PowOutputType, PowContigFunctor, pow_contig_kernel>( - exec_q, nelems, arg1_p, arg1_offset, arg2_p, arg2_offset, res_p, - res_offset, depends); + argTy1, argTy2, PowOutputType, PowContigFunctor, pow_contig_kernel, + vec_sz, n_vecs>(exec_q, nelems, arg1_p, arg1_offset, arg2_p, + arg2_offset, res_p, res_offset, depends); } template struct PowContigFactory @@ -417,8 +443,8 @@ template struct PowInplaceFunctor template using PowInplaceContigFunctor = elementwise_common::BinaryInplaceContigFunctor< argT, @@ -438,8 +464,8 @@ using PowInplaceStridedFunctor = template + std::uint8_t vec_sz, + std::uint8_t n_vecs> class pow_inplace_contig_kernel; /* @brief Types supported by in-place pow */ @@ -495,9 +521,15 @@ pow_inplace_contig_impl(sycl::queue &exec_q, ssize_t res_offset, const std::vector &depends = {}) { + constexpr std::uint8_t vec_sz = + PowContigHyperparameterSet::vec_sz; + constexpr std::uint8_t n_vecs = + PowContigHyperparameterSet::n_vecs; + return elementwise_common::binary_inplace_contig_impl< - argTy, resTy, PowInplaceContigFunctor, pow_inplace_contig_kernel>( - exec_q, nelems, arg_p, arg_offset, res_p, res_offset, depends); + argTy, resTy, PowInplaceContigFunctor, pow_inplace_contig_kernel, + vec_sz, n_vecs>(exec_q, nelems, arg_p, arg_offset, res_p, res_offset, + depends); } template struct PowInplaceContigFactory diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/proj.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/proj.hpp 
index 2c3dce0c9c..1297dab283 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/proj.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/proj.hpp @@ -32,9 +32,11 @@ #include #include -#include "kernels/elementwise_functions/common.hpp" +#include "vec_size_util.hpp" #include "kernels/dpctl_tensor_types.hpp" +#include "kernels/elementwise_functions/common.hpp" + #include "utils/offset_utils.hpp" #include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" @@ -91,8 +93,8 @@ template struct ProjFunctor template using ProjContigFunctor = elementwise_common::UnaryContigFunctor struct ProjOutputType static constexpr bool is_defined = !std::is_same_v; }; -template +namespace +{ + +namespace vsu_ns = dpctl::tensor::kernels::vec_size_utils; + +using vsu_ns::ContigHyperparameterSetDefault; +using vsu_ns::UnaryContigHyperparameterSetEntry; + +template struct ProjContigHyperparameterSet +{ + using value_type = + typename std::disjunction>; + + constexpr static auto vec_sz = value_type::vec_sz; + constexpr static auto n_vecs = value_type::n_vecs; +}; + +} // end of anonymous namespace + +template class proj_contig_kernel; template @@ -126,9 +147,12 @@ sycl::event proj_contig_impl(sycl::queue &exec_q, char *res_p, const std::vector &depends = {}) { + constexpr std::uint8_t vec_sz = ProjContigHyperparameterSet::vec_sz; + constexpr std::uint8_t n_vecs = ProjContigHyperparameterSet::n_vecs; + return elementwise_common::unary_contig_impl< - argTy, ProjOutputType, ProjContigFunctor, proj_contig_kernel>( - exec_q, nelems, arg_p, res_p, depends); + argTy, ProjOutputType, ProjContigFunctor, proj_contig_kernel, vec_sz, + n_vecs>(exec_q, nelems, arg_p, res_p, depends); } template struct ProjContigFactory diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/real.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/real.hpp index c66e4003cb..270b613346 100644 --- 
a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/real.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/real.hpp @@ -31,9 +31,11 @@ #include #include -#include "kernels/elementwise_functions/common.hpp" +#include "vec_size_util.hpp" #include "kernels/dpctl_tensor_types.hpp" +#include "kernels/elementwise_functions/common.hpp" + #include "utils/offset_utils.hpp" #include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" @@ -78,8 +80,8 @@ template struct RealFunctor template using RealContigFunctor = elementwise_common::UnaryContigFunctor struct RealOutputType static constexpr bool is_defined = !std::is_same_v; }; -template +namespace +{ + +namespace vsu_ns = dpctl::tensor::kernels::vec_size_utils; + +using vsu_ns::ContigHyperparameterSetDefault; +using vsu_ns::UnaryContigHyperparameterSetEntry; + +template struct RealContigHyperparameterSet +{ + using value_type = + typename std::disjunction>; + + constexpr static auto vec_sz = value_type::vec_sz; + constexpr static auto n_vecs = value_type::n_vecs; +}; + +} // end of anonymous namespace + +template class real_contig_kernel; template @@ -125,9 +146,12 @@ sycl::event real_contig_impl(sycl::queue &exec_q, char *res_p, const std::vector &depends = {}) { + constexpr std::uint8_t vec_sz = RealContigHyperparameterSet::vec_sz; + constexpr std::uint8_t n_vecs = RealContigHyperparameterSet::n_vecs; + return elementwise_common::unary_contig_impl< - argTy, RealOutputType, RealContigFunctor, real_contig_kernel>( - exec_q, nelems, arg_p, res_p, depends); + argTy, RealOutputType, RealContigFunctor, real_contig_kernel, vec_sz, + n_vecs>(exec_q, nelems, arg_p, res_p, depends); } template struct RealContigFactory diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/reciprocal.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/reciprocal.hpp index 4d4b70fd4f..90909ea772 100644 --- 
a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/reciprocal.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/reciprocal.hpp @@ -32,12 +32,14 @@ #include #include -#include "kernels/dpctl_tensor_types.hpp" #include "sycl_complex.hpp" +#include "vec_size_util.hpp" + #include "utils/offset_utils.hpp" #include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" +#include "kernels/dpctl_tensor_types.hpp" #include "kernels/elementwise_functions/common.hpp" namespace dpctl @@ -81,8 +83,8 @@ template struct ReciprocalFunctor template using ReciprocalContigFunctor = elementwise_common::UnaryContigFunctor struct ReciprocalOutputType static constexpr bool is_defined = !std::is_same_v; }; -template +namespace +{ + +namespace vsu_ns = dpctl::tensor::kernels::vec_size_utils; + +using vsu_ns::ContigHyperparameterSetDefault; +using vsu_ns::UnaryContigHyperparameterSetEntry; + +template struct ReciprocalContigHyperparameterSet +{ + using value_type = + typename std::disjunction>; + + constexpr static auto vec_sz = value_type::vec_sz; + constexpr static auto n_vecs = value_type::n_vecs; +}; + +} // end of anonymous namespace + +template class reciprocal_contig_kernel; template @@ -122,10 +143,15 @@ sycl::event reciprocal_contig_impl(sycl::queue &exec_q, char *res_p, const std::vector &depends = {}) { - return elementwise_common::unary_contig_impl( - exec_q, nelems, arg_p, res_p, depends); + constexpr std::uint8_t vec_sz = + ReciprocalContigHyperparameterSet::vec_sz; + constexpr std::uint8_t n_vecs = + ReciprocalContigHyperparameterSet::n_vecs; + + return elementwise_common::unary_contig_impl< + argTy, ReciprocalOutputType, ReciprocalContigFunctor, + reciprocal_contig_kernel, vec_sz, n_vecs>(exec_q, nelems, arg_p, res_p, + depends); } template struct ReciprocalContigFactory diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/remainder.hpp 
b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/remainder.hpp index 7bb070cc00..57467d56b3 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/remainder.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/remainder.hpp @@ -30,11 +30,13 @@ #include #include -#include "kernels/dpctl_tensor_types.hpp" +#include "vec_size_util.hpp" + #include "utils/offset_utils.hpp" #include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" +#include "kernels/dpctl_tensor_types.hpp" #include "kernels/elementwise_functions/common.hpp" #include "kernels/elementwise_functions/common_inplace.hpp" @@ -144,8 +146,8 @@ template struct RemainderFunctor template using RemainderContigFunctor = elementwise_common::BinaryContigFunctor< argT1, @@ -219,11 +221,31 @@ template struct RemainderOutputType static constexpr bool is_defined = !std::is_same_v; }; +namespace +{ + +namespace vsu_ns = dpctl::tensor::kernels::vec_size_utils; + +using vsu_ns::BinaryContigHyperparameterSetEntry; +using vsu_ns::ContigHyperparameterSetDefault; + +template +struct RemainderContigHyperparameterSet +{ + using value_type = + typename std::disjunction>; + + constexpr static auto vec_sz = value_type::vec_sz; + constexpr static auto n_vecs = value_type::n_vecs; +}; + +} // end of anonymous namespace + template + std::uint8_t vec_sz, + std::uint8_t n_vecs> class remainder_contig_kernel; template @@ -237,10 +259,16 @@ sycl::event remainder_contig_impl(sycl::queue &exec_q, ssize_t res_offset, const std::vector &depends = {}) { + constexpr std::uint8_t vec_sz = + RemainderContigHyperparameterSet::vec_sz; + constexpr std::uint8_t n_vecs = + RemainderContigHyperparameterSet::n_vecs; + return elementwise_common::binary_contig_impl< argTy1, argTy2, RemainderOutputType, RemainderContigFunctor, - remainder_contig_kernel>(exec_q, nelems, arg1_p, arg1_offset, arg2_p, - arg2_offset, res_p, res_offset, depends); + remainder_contig_kernel, vec_sz, n_vecs>( + 
exec_q, nelems, arg1_p, arg1_offset, arg2_p, arg2_offset, res_p, + res_offset, depends); } template struct RemainderContigFactory @@ -393,8 +421,8 @@ template struct RemainderInplaceFunctor template using RemainderInplaceContigFunctor = elementwise_common::BinaryInplaceContigFunctor< @@ -415,8 +443,8 @@ using RemainderInplaceStridedFunctor = template + std::uint8_t vec_sz, + std::uint8_t n_vecs> class remainder_inplace_contig_kernel; /* @brief Types supported by in-place remainder */ @@ -464,10 +492,15 @@ remainder_inplace_contig_impl(sycl::queue &exec_q, ssize_t res_offset, const std::vector &depends = {}) { + constexpr std::uint8_t vec_sz = + RemainderContigHyperparameterSet::vec_sz; + constexpr std::uint8_t n_vecs = + RemainderContigHyperparameterSet::n_vecs; + return elementwise_common::binary_inplace_contig_impl< argTy, resTy, RemainderInplaceContigFunctor, - remainder_inplace_contig_kernel>(exec_q, nelems, arg_p, arg_offset, - res_p, res_offset, depends); + remainder_inplace_contig_kernel, vec_sz, n_vecs>( + exec_q, nelems, arg_p, arg_offset, res_p, res_offset, depends); } template diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/round.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/round.hpp index 241f75c1bb..60ea58f7c3 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/round.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/round.hpp @@ -29,9 +29,11 @@ #include #include -#include "kernels/elementwise_functions/common.hpp" +#include "vec_size_util.hpp" #include "kernels/dpctl_tensor_types.hpp" +#include "kernels/elementwise_functions/common.hpp" + #include "utils/offset_utils.hpp" #include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" @@ -87,8 +89,8 @@ template struct RoundFunctor template using RoundContigFunctor = elementwise_common::UnaryContigFunctor struct RoundOutputType static constexpr bool is_defined = !std::is_same_v; }; -template +namespace 
+{ + +namespace vsu_ns = dpctl::tensor::kernels::vec_size_utils; + +using vsu_ns::ContigHyperparameterSetDefault; +using vsu_ns::UnaryContigHyperparameterSetEntry; + +template struct RoundContigHyperparameterSet +{ + using value_type = + typename std::disjunction>; + + constexpr static auto vec_sz = value_type::vec_sz; + constexpr static auto n_vecs = value_type::n_vecs; +}; + +} // end of anonymous namespace + +template class round_contig_kernel; template @@ -133,9 +154,12 @@ sycl::event round_contig_impl(sycl::queue &exec_q, char *res_p, const std::vector &depends = {}) { + constexpr std::uint8_t vec_sz = RoundContigHyperparameterSet::vec_sz; + constexpr std::uint8_t n_vecs = RoundContigHyperparameterSet::n_vecs; + return elementwise_common::unary_contig_impl< - argTy, RoundOutputType, RoundContigFunctor, round_contig_kernel>( - exec_q, nelems, arg_p, res_p, depends); + argTy, RoundOutputType, RoundContigFunctor, round_contig_kernel, vec_sz, + n_vecs>(exec_q, nelems, arg_p, res_p, depends); } template struct RoundContigFactory diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/rsqrt.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/rsqrt.hpp index 61aafb13d9..f92dac50b1 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/rsqrt.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/rsqrt.hpp @@ -33,9 +33,11 @@ #include #include -#include "kernels/elementwise_functions/common.hpp" +#include "vec_size_util.hpp" #include "kernels/dpctl_tensor_types.hpp" +#include "kernels/elementwise_functions/common.hpp" + #include "utils/offset_utils.hpp" #include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" @@ -68,8 +70,8 @@ template struct RsqrtFunctor template using RsqrtContigFunctor = elementwise_common::UnaryContigFunctor struct RsqrtOutputType static constexpr bool is_defined = !std::is_same_v; }; -template +namespace +{ + +namespace vsu_ns = 
dpctl::tensor::kernels::vec_size_utils; + +using vsu_ns::ContigHyperparameterSetDefault; +using vsu_ns::UnaryContigHyperparameterSetEntry; + +template struct RsqrtContigHyperparameterSet +{ + using value_type = + typename std::disjunction>; + + constexpr static auto vec_sz = value_type::vec_sz; + constexpr static auto n_vecs = value_type::n_vecs; +}; + +} // end of anonymous namespace + +template class rsqrt_contig_kernel; template @@ -104,9 +125,12 @@ sycl::event rsqrt_contig_impl(sycl::queue &exec_q, char *res_p, const std::vector &depends = {}) { + constexpr std::uint8_t vec_sz = RsqrtContigHyperparameterSet::vec_sz; + constexpr std::uint8_t n_vecs = RsqrtContigHyperparameterSet::n_vecs; + return elementwise_common::unary_contig_impl< - argTy, RsqrtOutputType, RsqrtContigFunctor, rsqrt_contig_kernel>( - exec_q, nelems, arg_p, res_p, depends); + argTy, RsqrtOutputType, RsqrtContigFunctor, rsqrt_contig_kernel, vec_sz, + n_vecs>(exec_q, nelems, arg_p, res_p, depends); } template struct RsqrtContigFactory diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sign.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sign.hpp index 651f7d5d9a..ffb4183474 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sign.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sign.hpp @@ -31,9 +31,11 @@ #include #include "cabs_impl.hpp" -#include "kernels/elementwise_functions/common.hpp" +#include "vec_size_util.hpp" #include "kernels/dpctl_tensor_types.hpp" +#include "kernels/elementwise_functions/common.hpp" + #include "utils/offset_utils.hpp" #include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" @@ -103,8 +105,8 @@ template struct SignFunctor template using SignContigFunctor = elementwise_common::UnaryContigFunctor struct SignOutputType static constexpr bool is_defined = !std::is_same_v; }; -template +namespace +{ + +namespace vsu_ns = dpctl::tensor::kernels::vec_size_utils; + +using 
vsu_ns::ContigHyperparameterSetDefault; +using vsu_ns::UnaryContigHyperparameterSetEntry; + +template struct SignContigHyperparameterSet +{ + using value_type = + typename std::disjunction>; + + constexpr static auto vec_sz = value_type::vec_sz; + constexpr static auto n_vecs = value_type::n_vecs; +}; + +} // end of anonymous namespace + +template class sign_contig_kernel; template @@ -145,9 +166,12 @@ sycl::event sign_contig_impl(sycl::queue &exec_q, char *res_p, const std::vector &depends = {}) { + constexpr std::uint8_t vec_sz = SignContigHyperparameterSet::vec_sz; + constexpr std::uint8_t n_vecs = SignContigHyperparameterSet::n_vecs; + return elementwise_common::unary_contig_impl< - argTy, SignOutputType, SignContigFunctor, sign_contig_kernel>( - exec_q, nelems, arg_p, res_p, depends); + argTy, SignOutputType, SignContigFunctor, sign_contig_kernel, vec_sz, + n_vecs>(exec_q, nelems, arg_p, res_p, depends); } template struct SignContigFactory diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/signbit.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/signbit.hpp index e8ac7709ad..7ba04fcd17 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/signbit.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/signbit.hpp @@ -30,7 +30,10 @@ #include #include +#include "vec_size_util.hpp" + #include "kernels/dpctl_tensor_types.hpp" + #include "utils/offset_utils.hpp" #include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" @@ -74,8 +77,8 @@ template struct SignbitFunctor template using SignbitContigFunctor = elementwise_common::UnaryContigFunctor struct SignbitOutputType static constexpr bool is_defined = !std::is_same_v; }; -template +namespace +{ + +namespace vsu_ns = dpctl::tensor::kernels::vec_size_utils; + +using vsu_ns::ContigHyperparameterSetDefault; +using vsu_ns::UnaryContigHyperparameterSetEntry; + +template struct SignbitContigHyperparameterSet +{ + using value_type 
= + typename std::disjunction>; + + constexpr static auto vec_sz = value_type::vec_sz; + constexpr static auto n_vecs = value_type::n_vecs; +}; + +} // end of anonymous namespace + +template class signbit_contig_kernel; template @@ -110,9 +132,14 @@ sycl::event signbit_contig_impl(sycl::queue &exec_q, char *res_p, const std::vector &depends = {}) { + constexpr std::uint8_t vec_sz = + SignbitContigHyperparameterSet::vec_sz; + constexpr std::uint8_t n_vecs = + SignbitContigHyperparameterSet::n_vecs; + return elementwise_common::unary_contig_impl< - argTy, SignbitOutputType, SignbitContigFunctor, signbit_contig_kernel>( - exec_q, nelems, arg_p, res_p, depends); + argTy, SignbitOutputType, SignbitContigFunctor, signbit_contig_kernel, + vec_sz, n_vecs>(exec_q, nelems, arg_p, res_p, depends); } template struct SignbitContigFactory diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sin.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sin.hpp index 8bc12097a8..596c1de9e4 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sin.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sin.hpp @@ -29,10 +29,12 @@ #include #include -#include "kernels/elementwise_functions/common.hpp" #include "sycl_complex.hpp" +#include "vec_size_util.hpp" #include "kernels/dpctl_tensor_types.hpp" +#include "kernels/elementwise_functions/common.hpp" + #include "utils/offset_utils.hpp" #include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" @@ -186,8 +188,8 @@ template struct SinFunctor template using SinContigFunctor = elementwise_common::UnaryContigFunctor struct SinOutputType static constexpr bool is_defined = !std::is_same_v; }; -template +namespace +{ + +namespace vsu_ns = dpctl::tensor::kernels::vec_size_utils; + +using vsu_ns::ContigHyperparameterSetDefault; +using vsu_ns::UnaryContigHyperparameterSetEntry; + +template struct SinContigHyperparameterSet +{ + using value_type = + typename 
std::disjunction>; + + constexpr static auto vec_sz = value_type::vec_sz; + constexpr static auto n_vecs = value_type::n_vecs; +}; + +} // end of anonymous namespace + +template class sin_contig_kernel; template @@ -224,9 +245,12 @@ sycl::event sin_contig_impl(sycl::queue &exec_q, char *res_p, const std::vector &depends = {}) { + constexpr std::uint8_t vec_sz = SinContigHyperparameterSet::vec_sz; + constexpr std::uint8_t n_vecs = SinContigHyperparameterSet::n_vecs; + return elementwise_common::unary_contig_impl< - argTy, SinOutputType, SinContigFunctor, sin_contig_kernel>( - exec_q, nelems, arg_p, res_p, depends); + argTy, SinOutputType, SinContigFunctor, sin_contig_kernel, vec_sz, + n_vecs>(exec_q, nelems, arg_p, res_p, depends); } template struct SinContigFactory diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sinh.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sinh.hpp index e83626e56d..6d418872b8 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sinh.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sinh.hpp @@ -29,10 +29,12 @@ #include #include -#include "kernels/elementwise_functions/common.hpp" #include "sycl_complex.hpp" +#include "vec_size_util.hpp" #include "kernels/dpctl_tensor_types.hpp" +#include "kernels/elementwise_functions/common.hpp" + #include "utils/offset_utils.hpp" #include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" @@ -155,8 +157,8 @@ template struct SinhFunctor template using SinhContigFunctor = elementwise_common::UnaryContigFunctor struct SinhOutputType static constexpr bool is_defined = !std::is_same_v; }; -template +namespace +{ + +namespace vsu_ns = dpctl::tensor::kernels::vec_size_utils; + +using vsu_ns::ContigHyperparameterSetDefault; +using vsu_ns::UnaryContigHyperparameterSetEntry; + +template struct SinhContigHyperparameterSet +{ + using value_type = + typename std::disjunction>; + + constexpr static auto vec_sz = 
value_type::vec_sz; + constexpr static auto n_vecs = value_type::n_vecs; +}; + +} // end of anonymous namespace + +template class sinh_contig_kernel; template @@ -193,9 +214,12 @@ sycl::event sinh_contig_impl(sycl::queue &exec_q, char *res_p, const std::vector &depends = {}) { + constexpr std::uint8_t vec_sz = SinhContigHyperparameterSet::vec_sz; + constexpr std::uint8_t n_vecs = SinhContigHyperparameterSet::n_vecs; + return elementwise_common::unary_contig_impl< - argTy, SinhOutputType, SinhContigFunctor, sinh_contig_kernel>( - exec_q, nelems, arg_p, res_p, depends); + argTy, SinhOutputType, SinhContigFunctor, sinh_contig_kernel, vec_sz, + n_vecs>(exec_q, nelems, arg_p, res_p, depends); } template struct SinhContigFactory diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sqrt.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sqrt.hpp index 5adb41b20d..6dcb2ca742 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sqrt.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sqrt.hpp @@ -32,10 +32,12 @@ #include #include -#include "kernels/elementwise_functions/common.hpp" #include "sycl_complex.hpp" +#include "vec_size_util.hpp" #include "kernels/dpctl_tensor_types.hpp" +#include "kernels/elementwise_functions/common.hpp" + #include "utils/offset_utils.hpp" #include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" @@ -80,8 +82,8 @@ template struct SqrtFunctor template using SqrtContigFunctor = elementwise_common::UnaryContigFunctor struct SqrtOutputType static constexpr bool is_defined = !std::is_same_v; }; -template +namespace +{ + +namespace vsu_ns = dpctl::tensor::kernels::vec_size_utils; + +using vsu_ns::ContigHyperparameterSetDefault; +using vsu_ns::UnaryContigHyperparameterSetEntry; + +template struct SqrtContigHyperparameterSet +{ + using value_type = + typename std::disjunction>; + + constexpr static auto vec_sz = value_type::vec_sz; + constexpr static auto 
n_vecs = value_type::n_vecs; +}; + +} // end of anonymous namespace + +template class sqrt_contig_kernel; template @@ -119,9 +140,12 @@ sycl::event sqrt_contig_impl(sycl::queue &exec_q, char *res_p, const std::vector &depends = {}) { + constexpr std::uint8_t vec_sz = SqrtContigHyperparameterSet::vec_sz; + constexpr std::uint8_t n_vecs = SqrtContigHyperparameterSet::n_vecs; + return elementwise_common::unary_contig_impl< - argTy, SqrtOutputType, SqrtContigFunctor, sqrt_contig_kernel>( - exec_q, nelems, arg_p, res_p, depends); + argTy, SqrtOutputType, SqrtContigFunctor, sqrt_contig_kernel, vec_sz, + n_vecs>(exec_q, nelems, arg_p, res_p, depends); } template struct SqrtContigFactory diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/square.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/square.hpp index 4b096cc291..dbf665b79c 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/square.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/square.hpp @@ -30,10 +30,12 @@ #include #include -#include "kernels/elementwise_functions/common.hpp" #include "sycl_complex.hpp" +#include "vec_size_util.hpp" #include "kernels/dpctl_tensor_types.hpp" +#include "kernels/elementwise_functions/common.hpp" + #include "utils/offset_utils.hpp" #include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" @@ -97,8 +99,8 @@ template struct SquareFunctor template using SquareContigFunctor = elementwise_common::UnaryContigFunctor struct SquareOutputType static constexpr bool is_defined = !std::is_same_v; }; -template +namespace +{ + +namespace vsu_ns = dpctl::tensor::kernels::vec_size_utils; + +using vsu_ns::ContigHyperparameterSetDefault; +using vsu_ns::UnaryContigHyperparameterSetEntry; + +template struct SquareContigHyperparameterSet +{ + using value_type = + typename std::disjunction>; + + constexpr static auto vec_sz = value_type::vec_sz; + constexpr static auto n_vecs = value_type::n_vecs; +}; 
+ +} // end of anonymous namespace + +template class square_contig_kernel; template @@ -144,9 +165,14 @@ sycl::event square_contig_impl(sycl::queue &exec_q, char *res_p, const std::vector &depends = {}) { + constexpr std::uint8_t vec_sz = + SquareContigHyperparameterSet::vec_sz; + constexpr std::uint8_t n_vecs = + SquareContigHyperparameterSet::n_vecs; + return elementwise_common::unary_contig_impl< - argTy, SquareOutputType, SquareContigFunctor, square_contig_kernel>( - exec_q, nelems, arg_p, res_p, depends); + argTy, SquareOutputType, SquareContigFunctor, square_contig_kernel, + vec_sz, n_vecs>(exec_q, nelems, arg_p, res_p, depends); } template struct SquareContigFactory diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/subtract.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/subtract.hpp index 4ee3ae089b..47ca000c3f 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/subtract.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/subtract.hpp @@ -29,11 +29,13 @@ #include #include -#include "kernels/dpctl_tensor_types.hpp" +#include "vec_size_util.hpp" + #include "utils/offset_utils.hpp" #include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" +#include "kernels/dpctl_tensor_types.hpp" #include "kernels/elementwise_functions/common.hpp" #include "kernels/elementwise_functions/common_inplace.hpp" @@ -85,8 +87,8 @@ template struct SubtractFunctor template using SubtractContigFunctor = elementwise_common::BinaryContigFunctor struct SubtractOutputType static constexpr bool is_defined = !std::is_same_v; }; +namespace +{ + +namespace vsu_ns = dpctl::tensor::kernels::vec_size_utils; + +using vsu_ns::BinaryContigHyperparameterSetEntry; +using vsu_ns::ContigHyperparameterSetDefault; + +template +struct SubtractContigHyperparameterSet +{ + using value_type = + typename std::disjunction>; + + constexpr static auto vec_sz = value_type::vec_sz; + constexpr static auto n_vecs = 
value_type::n_vecs; +}; + +} // end of anonymous namespace + template + std::uint8_t vec_sz, + std::uint8_t n_vecs> class subtract_contig_kernel; template @@ -188,10 +210,16 @@ sycl::event subtract_contig_impl(sycl::queue &exec_q, ssize_t res_offset, const std::vector &depends = {}) { + constexpr std::uint8_t vec_sz = + SubtractContigHyperparameterSet::vec_sz; + constexpr std::uint8_t n_vecs = + SubtractContigHyperparameterSet::n_vecs; + return elementwise_common::binary_contig_impl< argTy1, argTy2, SubtractOutputType, SubtractContigFunctor, - subtract_contig_kernel>(exec_q, nelems, arg1_p, arg1_offset, arg2_p, - arg2_offset, res_p, res_offset, depends); + subtract_contig_kernel, vec_sz, n_vecs>( + exec_q, nelems, arg1_p, arg1_offset, arg2_p, arg2_offset, res_p, + res_offset, depends); } template struct SubtractContigFactory @@ -401,8 +429,8 @@ template struct SubtractInplaceFunctor template using SubtractInplaceContigFunctor = elementwise_common::BinaryInplaceContigFunctor< @@ -423,8 +451,8 @@ using SubtractInplaceStridedFunctor = template + std::uint8_t vec_sz, + std::uint8_t n_vecs> class subtract_inplace_contig_kernel; /* @brief Types supported by in-place subtraction */ @@ -480,10 +508,15 @@ subtract_inplace_contig_impl(sycl::queue &exec_q, ssize_t res_offset, const std::vector &depends = {}) { + constexpr std::uint8_t vec_sz = + SubtractContigHyperparameterSet::vec_sz; + constexpr std::uint8_t n_vecs = + SubtractContigHyperparameterSet::n_vecs; + return elementwise_common::binary_inplace_contig_impl< argTy, resTy, SubtractInplaceContigFunctor, - subtract_inplace_contig_kernel>(exec_q, nelems, arg_p, arg_offset, - res_p, res_offset, depends); + subtract_inplace_contig_kernel, vec_sz, n_vecs>( + exec_q, nelems, arg_p, arg_offset, res_p, res_offset, depends); } template diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/tan.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/tan.hpp index 4364d81fb7..a7da718a4b 100644 --- 
a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/tan.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/tan.hpp @@ -30,10 +30,12 @@ #include #include -#include "kernels/elementwise_functions/common.hpp" #include "sycl_complex.hpp" +#include "vec_size_util.hpp" #include "kernels/dpctl_tensor_types.hpp" +#include "kernels/elementwise_functions/common.hpp" + #include "utils/offset_utils.hpp" #include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" @@ -130,8 +132,8 @@ template struct TanFunctor template using TanContigFunctor = elementwise_common::UnaryContigFunctor struct TanOutputType static constexpr bool is_defined = !std::is_same_v; }; -template +namespace +{ + +namespace vsu_ns = dpctl::tensor::kernels::vec_size_utils; + +using vsu_ns::ContigHyperparameterSetDefault; +using vsu_ns::UnaryContigHyperparameterSetEntry; + +template struct TanContigHyperparameterSet +{ + using value_type = + typename std::disjunction>; + + constexpr static auto vec_sz = value_type::vec_sz; + constexpr static auto n_vecs = value_type::n_vecs; +}; + +} // end of anonymous namespace + +template class tan_contig_kernel; template @@ -168,9 +189,12 @@ sycl::event tan_contig_impl(sycl::queue &exec_q, char *res_p, const std::vector &depends = {}) { + constexpr std::uint8_t vec_sz = TanContigHyperparameterSet::vec_sz; + constexpr std::uint8_t n_vecs = TanContigHyperparameterSet::n_vecs; + return elementwise_common::unary_contig_impl< - argTy, TanOutputType, TanContigFunctor, tan_contig_kernel>( - exec_q, nelems, arg_p, res_p, depends); + argTy, TanOutputType, TanContigFunctor, tan_contig_kernel, vec_sz, + n_vecs>(exec_q, nelems, arg_p, res_p, depends); } template struct TanContigFactory diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/tanh.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/tanh.hpp index 0af4e4e628..626420d48b 100644 --- 
a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/tanh.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/tanh.hpp @@ -31,10 +31,12 @@ #include #include -#include "kernels/elementwise_functions/common.hpp" #include "sycl_complex.hpp" +#include "vec_size_util.hpp" #include "kernels/dpctl_tensor_types.hpp" +#include "kernels/elementwise_functions/common.hpp" + #include "utils/offset_utils.hpp" #include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" @@ -124,8 +126,8 @@ template struct TanhFunctor template using TanhContigFunctor = elementwise_common::UnaryContigFunctor struct TanhOutputType static constexpr bool is_defined = !std::is_same_v; }; -template +namespace +{ + +namespace vsu_ns = dpctl::tensor::kernels::vec_size_utils; + +using vsu_ns::ContigHyperparameterSetDefault; +using vsu_ns::UnaryContigHyperparameterSetEntry; + +template struct TanhContigHyperparameterSet +{ + using value_type = + typename std::disjunction>; + + constexpr static auto vec_sz = value_type::vec_sz; + constexpr static auto n_vecs = value_type::n_vecs; +}; + +} // end of anonymous namespace + +template class tanh_contig_kernel; template @@ -162,9 +183,12 @@ sycl::event tanh_contig_impl(sycl::queue &exec_q, char *res_p, const std::vector &depends = {}) { + constexpr std::uint8_t vec_sz = TanhContigHyperparameterSet::vec_sz; + constexpr std::uint8_t n_vecs = TanhContigHyperparameterSet::n_vecs; + return elementwise_common::unary_contig_impl< - argTy, TanhOutputType, TanhContigFunctor, tanh_contig_kernel>( - exec_q, nelems, arg_p, res_p, depends); + argTy, TanhOutputType, TanhContigFunctor, tanh_contig_kernel, vec_sz, + n_vecs>(exec_q, nelems, arg_p, res_p, depends); } template struct TanhContigFactory diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/true_divide.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/true_divide.hpp index 53db1e163c..27de2069ff 100644 --- 
a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/true_divide.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/true_divide.hpp @@ -29,12 +29,14 @@ #include #include -#include "kernels/dpctl_tensor_types.hpp" #include "sycl_complex.hpp" +#include "vec_size_util.hpp" + #include "utils/offset_utils.hpp" #include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" +#include "kernels/dpctl_tensor_types.hpp" #include "kernels/elementwise_functions/common.hpp" #include "kernels/elementwise_functions/common_inplace.hpp" @@ -112,8 +114,8 @@ struct TrueDivideFunctor template using TrueDivideContigFunctor = elementwise_common::BinaryContigFunctor< argT1, @@ -177,11 +179,31 @@ template struct TrueDivideOutputType static constexpr bool is_defined = !std::is_same_v; }; +namespace +{ + +namespace vsu_ns = dpctl::tensor::kernels::vec_size_utils; + +using vsu_ns::BinaryContigHyperparameterSetEntry; +using vsu_ns::ContigHyperparameterSetDefault; + +template +struct TrueDivideContigHyperparameterSet +{ + using value_type = + typename std::disjunction>; + + constexpr static auto vec_sz = value_type::vec_sz; + constexpr static auto n_vecs = value_type::n_vecs; +}; + +} // end of anonymous namespace + template + std::uint8_t vec_sz, + std::uint8_t n_vecs> class true_divide_contig_kernel; template @@ -196,10 +218,16 @@ true_divide_contig_impl(sycl::queue &exec_q, ssize_t res_offset, const std::vector &depends = {}) { + constexpr std::uint8_t vec_sz = + TrueDivideContigHyperparameterSet::vec_sz; + constexpr std::uint8_t n_vecs = + TrueDivideContigHyperparameterSet::n_vecs; + return elementwise_common::binary_contig_impl< argTy1, argTy2, TrueDivideOutputType, TrueDivideContigFunctor, - true_divide_contig_kernel>(exec_q, nelems, arg1_p, arg1_offset, arg2_p, - arg2_offset, res_p, res_offset, depends); + true_divide_contig_kernel, vec_sz, n_vecs>( + exec_q, nelems, arg1_p, arg1_offset, arg2_p, arg2_offset, res_p, + res_offset, depends); } 
template struct TrueDivideContigFactory @@ -473,8 +501,8 @@ struct TrueDivideInplaceTypeMapFactory template using TrueDivideInplaceContigFunctor = elementwise_common::BinaryInplaceContigFunctor< @@ -495,8 +523,8 @@ using TrueDivideInplaceStridedFunctor = template + std::uint8_t vec_sz, + std::uint8_t n_vecs> class true_divide_inplace_contig_kernel; template @@ -509,10 +537,15 @@ true_divide_inplace_contig_impl(sycl::queue &exec_q, ssize_t res_offset, const std::vector &depends = {}) { + constexpr std::uint8_t vec_sz = + TrueDivideContigHyperparameterSet::vec_sz; + constexpr std::uint8_t n_vecs = + TrueDivideContigHyperparameterSet::n_vecs; + return elementwise_common::binary_inplace_contig_impl< argTy, resTy, TrueDivideInplaceContigFunctor, - true_divide_inplace_contig_kernel>(exec_q, nelems, arg_p, arg_offset, - res_p, res_offset, depends); + true_divide_inplace_contig_kernel, vec_sz, n_vecs>( + exec_q, nelems, arg_p, arg_offset, res_p, res_offset, depends); } template diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/trunc.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/trunc.hpp index 55c8493880..cf9d6fa14f 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/trunc.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/trunc.hpp @@ -29,9 +29,11 @@ #include #include -#include "kernels/elementwise_functions/common.hpp" +#include "vec_size_util.hpp" #include "kernels/dpctl_tensor_types.hpp" +#include "kernels/elementwise_functions/common.hpp" + #include "utils/offset_utils.hpp" #include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" @@ -75,8 +77,8 @@ template struct TruncFunctor template using TruncContigFunctor = elementwise_common::UnaryContigFunctor struct TruncOutputType static constexpr bool is_defined = !std::is_same_v; }; -template +namespace +{ + +namespace vsu_ns = dpctl::tensor::kernels::vec_size_utils; + +using vsu_ns::ContigHyperparameterSetDefault;
+using vsu_ns::UnaryContigHyperparameterSetEntry; + +template struct TruncContigHyperparameterSet +{ + using value_type = + typename std::disjunction>; + + constexpr static auto vec_sz = value_type::vec_sz; + constexpr static auto n_vecs = value_type::n_vecs; +}; + +} // end of anonymous namespace + +template class trunc_contig_kernel; template @@ -119,9 +140,12 @@ sycl::event trunc_contig_impl(sycl::queue &exec_q, char *res_p, const std::vector &depends = {}) { + constexpr std::uint8_t vec_sz = TruncContigHyperparameterSet::vec_sz; + constexpr std::uint8_t n_vecs = TruncContigHyperparameterSet::n_vecs; + return elementwise_common::unary_contig_impl< - argTy, TruncOutputType, TruncContigFunctor, trunc_contig_kernel>( - exec_q, nelems, arg_p, res_p, depends); + argTy, TruncOutputType, TruncContigFunctor, trunc_contig_kernel, vec_sz, + n_vecs>(exec_q, nelems, arg_p, res_p, depends); } template struct TruncContigFactory diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/vec_size_util.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/vec_size_util.hpp new file mode 100644 index 0000000000..0be2c68c3b --- /dev/null +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/vec_size_util.hpp @@ -0,0 +1,73 @@ +//=== vec_size_utils.hpp - -------/ /*-C++-*--/===// +// +// Data Parallel Control (dpctl) +// +// Copyright 2020-2024 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +//===---------------------------------------------------------------------===// +/// +/// \file +/// This file defines utilities for selection of hyperparameters for kernels +/// implementing unary and binary elementwise functions for contiguous inputs +//===---------------------------------------------------------------------===// + +#pragma once + +#include +#include + +namespace dpctl +{ +namespace tensor +{ +namespace kernels +{ +namespace vec_size_utils +{ + +template +struct BinaryContigHyperparameterSetEntry + : std::conjunction, std::is_same> +{ + static constexpr std::uint8_t vec_sz = vec_sz_v; + static constexpr std::uint8_t n_vecs = n_vecs_v; +}; + +template +struct UnaryContigHyperparameterSetEntry : std::is_same +{ + static constexpr std::uint8_t vec_sz = vec_sz_v; + static constexpr std::uint8_t n_vecs = n_vecs_v; +}; + +template +struct ContigHyperparameterSetDefault : std::true_type +{ + static constexpr std::uint8_t vec_sz = vec_sz_v; + static constexpr std::uint8_t n_vecs = n_vecs_v; +}; + +} // end of namespace vec_size_utils +} // end of namespace kernels +} // end of namespace tensor +} // end of namespace dpctl diff --git a/dpctl/tensor/libtensor/include/kernels/where.hpp b/dpctl/tensor/libtensor/include/kernels/where.hpp index b356c256c3..dbf3fdfedf 100644 --- a/dpctl/tensor/libtensor/include/kernels/where.hpp +++ b/dpctl/tensor/libtensor/include/kernels/where.hpp @@ -32,6 +32,7 @@ #include "dpctl_tensor_types.hpp" #include "kernels/alignment.hpp" #include "utils/offset_utils.hpp" +#include "utils/sycl_utils.hpp" #include "utils/type_utils.hpp" namespace dpctl @@ -50,15 +51,18 @@ using dpctl::tensor::kernels::alignment_utils:: using dpctl::tensor::kernels::alignment_utils::is_aligned; using dpctl::tensor::kernels::alignment_utils::required_alignment; +using dpctl::tensor::sycl_utils::sub_group_load; +using dpctl::tensor::sycl_utils::sub_group_store; + template class where_strided_kernel; -template +template class where_contig_kernel; 
template class WhereContigFunctor { @@ -82,42 +86,40 @@ class WhereContigFunctor void operator()(sycl::nd_item<1> ndit) const { + constexpr std::uint8_t nelems_per_wi = n_vecs * vec_sz; + using dpctl::tensor::type_utils::is_complex; if constexpr (!enable_sg_loadstore || is_complex::value || is_complex::value) { - std::uint8_t sgSize = ndit.get_sub_group().get_local_range()[0]; - size_t base = ndit.get_global_linear_id(); - - base = (base / sgSize) * sgSize * n_vecs * vec_sz + (base % sgSize); - for (size_t offset = base; - offset < std::min(nelems, base + sgSize * (n_vecs * vec_sz)); - offset += sgSize) - { + const std::uint16_t sgSize = + ndit.get_sub_group().get_local_range()[0]; + const size_t gid = ndit.get_global_linear_id(); + + const std::uint16_t nelems_per_sg = sgSize * nelems_per_wi; + const size_t start = + (gid / sgSize) * (nelems_per_sg - sgSize) + gid; + const size_t end = std::min(nelems, start + nelems_per_sg); + for (size_t offset = start; offset < end; offset += sgSize) { using dpctl::tensor::type_utils::convert_impl; - bool check = convert_impl(cond_p[offset]); + const bool check = convert_impl(cond_p[offset]); dst_p[offset] = check ? 
x1_p[offset] : x2_p[offset]; } } else { auto sg = ndit.get_sub_group(); - std::uint8_t sgSize = sg.get_local_range()[0]; - std::uint8_t max_sgSize = sg.get_max_local_range()[0]; - size_t base = n_vecs * vec_sz * - (ndit.get_group(0) * ndit.get_local_range(0) + - sg.get_group_id()[0] * max_sgSize); - - if (base + n_vecs * vec_sz * sgSize < nelems && - sgSize == max_sgSize) - { + const std::uint16_t sgSize = sg.get_max_local_range()[0]; + + const size_t base = + nelems_per_wi * (ndit.get_group(0) * ndit.get_local_range(0) + + sg.get_group_id()[0] * sgSize); + + if (base + nelems_per_wi * sgSize < nelems) { sycl::vec dst_vec; - sycl::vec x1_vec; - sycl::vec x2_vec; - sycl::vec cond_vec; #pragma unroll for (std::uint8_t it = 0; it < n_vecs * vec_sz; it += vec_sz) { - auto idx = base + it * sgSize; + const size_t idx = base + it * sgSize; auto x1_multi_ptr = sycl::address_space_cast< sycl::access::address_space::global_space, sycl::access::decorated::yes>(&x1_p[idx]); @@ -131,20 +133,22 @@ class WhereContigFunctor sycl::access::address_space::global_space, sycl::access::decorated::yes>(&dst_p[idx]); - x1_vec = sg.load(x1_multi_ptr); - x2_vec = sg.load(x2_multi_ptr); - cond_vec = sg.load(cond_multi_ptr); + const sycl::vec x1_vec = + sub_group_load(sg, x1_multi_ptr); + const sycl::vec x2_vec = + sub_group_load(sg, x2_multi_ptr); + const sycl::vec cond_vec = + sub_group_load(sg, cond_multi_ptr); #pragma unroll for (std::uint8_t k = 0; k < vec_sz; ++k) { dst_vec[k] = cond_vec[k] ? x1_vec[k] : x2_vec[k]; } - sg.store(dst_multi_ptr, dst_vec); + sub_group_store(sg, dst_vec, dst_multi_ptr); } } else { - for (size_t k = base + sg.get_local_id()[0]; k < nelems; - k += sgSize) - { + const size_t lane_id = sg.get_local_id()[0]; + for (size_t k = base + lane_id; k < nelems; k += sgSize) { dst_p[k] = cond_p[k] ? 
x1_p[k] : x2_p[k]; } } @@ -179,8 +183,8 @@ sycl::event where_contig_impl(sycl::queue &q, cgh.depends_on(depends); size_t lws = 64; - constexpr unsigned int vec_sz = 4; - constexpr unsigned int n_vecs = 2; + constexpr std::uint8_t vec_sz = 4u; + constexpr std::uint8_t n_vecs = 2u; const size_t n_groups = ((nelems + lws * n_vecs * vec_sz - 1) / (lws * n_vecs * vec_sz)); const auto gws_range = sycl::range<1>(n_groups * lws); diff --git a/dpctl/tensor/libtensor/include/utils/sycl_utils.hpp b/dpctl/tensor/libtensor/include/utils/sycl_utils.hpp index 3ad465db6a..19be8645c9 100644 --- a/dpctl/tensor/libtensor/include/utils/sycl_utils.hpp +++ b/dpctl/tensor/libtensor/include/utils/sycl_utils.hpp @@ -421,6 +421,100 @@ struct Identity::value>> static constexpr T value = sycl::known_identity::value; }; +// Sub-group load/store + +#ifndef USE_GROUP_LOAD_STORE +#if defined(SYCL_EXT_ONEAPI_GROUP_LOAD_STORE) && \ + SYCL_EXT_ONEAPI_GROUP_LOAD_STORE +#define USE_GROUP_LOAD_STORE 1 +#else +#if defined(__INTEL_LLVM_COMPILER) && (__INTEL_LLVM_COMPILER > 20250100u) +#define USE_GROUP_LOAD_STORE 1 +#else +#define USE_GROUP_LOAD_STORE 0 +#endif +#endif +#endif + +#if (USE_GROUP_LOAD_STORE) +namespace ls_ns = sycl::ext::oneapi::experimental; +#endif + +template +auto sub_group_load(const sycl::sub_group &sg, + sycl::multi_ptr m_ptr) +{ +#if (USE_GROUP_LOAD_STORE) + using ValueT = typename std::remove_cv_t; + sycl::vec x{}; + ls_ns::group_load(sg, m_ptr, x, ls_ns::data_placement_blocked); + return x; +#else + return sg.load(m_ptr); +#endif +} + +template +auto sub_group_load(const sycl::sub_group &sg, + sycl::multi_ptr m_ptr) +{ +#if (USE_GROUP_LOAD_STORE) + using ValueT = typename std::remove_cv_t; + ValueT x{}; + ls_ns::group_load(sg, m_ptr, x, ls_ns::data_placement_blocked); + return x; +#else + return sg.load(m_ptr); +#endif +} + +template +std::enable_if_t< + std::is_same_v, std::remove_cv_t>, + void> +sub_group_store(const sycl::sub_group &sg, + const sycl::vec &val, + 
sycl::multi_ptr m_ptr) +{ +#if (USE_GROUP_LOAD_STORE) + static_assert(std::is_same_v); + ls_ns::group_store(sg, val, m_ptr, ls_ns::data_placement_blocked); + return; +#else + sg.store(m_ptr, val); + return; +#endif +} + +template +std::enable_if_t< + std::is_same_v, std::remove_cv_t>, + void> +sub_group_store(const sycl::sub_group &sg, + const VecT &val, + sycl::multi_ptr m_ptr) +{ +#if (USE_GROUP_LOAD_STORE) + ls_ns::group_store(sg, val, m_ptr, ls_ns::data_placement_blocked); + return; +#else + sg.store(m_ptr, val); + return; +#endif +} + } // namespace sycl_utils } // namespace tensor } // namespace dpctl diff --git a/dpctl/tensor/libtensor/include/utils/type_dispatch_building.hpp b/dpctl/tensor/libtensor/include/utils/type_dispatch_building.hpp index a9f3a0c876..1cd378f83e 100644 --- a/dpctl/tensor/libtensor/include/utils/type_dispatch_building.hpp +++ b/dpctl/tensor/libtensor/include/utils/type_dispatch_building.hpp @@ -26,6 +26,8 @@ #pragma once #include +#include + #include namespace dpctl @@ -161,7 +163,7 @@ class DispatchVectorBuilder /*! @brief struct to define result_type typename for Ty == ArgTy */ template -struct TypeMapResultEntry : std::bool_constant> +struct TypeMapResultEntry : std::is_same { using result_type = ResTy; }; @@ -174,8 +176,7 @@ template struct BinaryTypeMapResultEntry - : std::bool_constant, - std::is_same>> + : std::conjunction, std::is_same> { using result_type = ResTy; }; @@ -272,8 +273,8 @@ template struct NullPtrTable }; template -struct TypePairDefinedEntry : std::bool_constant && - std::is_same_v> +struct TypePairDefinedEntry + : std::conjunction, std::is_same> { static constexpr bool is_defined = true; };