From 863bd503a6b77ddeb707c4ba1fa1cb2111909b4f Mon Sep 17 00:00:00 2001
From: Nikita Grigorian
Date: Fri, 22 Dec 2023 13:25:08 -0800
Subject: [PATCH 01/48] Adds ThreeOffsets_CombinedIndexer

This enables strided data processing by gemm kernels
---
 .../libtensor/include/utils/offset_utils.hpp  | 26 +++++++++++++++++++
 1 file changed, 26 insertions(+)

diff --git a/dpctl/tensor/libtensor/include/utils/offset_utils.hpp b/dpctl/tensor/libtensor/include/utils/offset_utils.hpp
index 523620737b..440d0d9d0b 100644
--- a/dpctl/tensor/libtensor/include/utils/offset_utils.hpp
+++ b/dpctl/tensor/libtensor/include/utils/offset_utils.hpp
@@ -450,6 +450,32 @@ struct ThreeZeroOffsets_Indexer
     }
 };
 
+template <typename FirstIndexerT,
+          typename SecondIndexerT,
+          typename ThirdIndexerT>
+struct ThreeOffsets_CombinedIndexer
+{
+private:
+    FirstIndexerT first_indexer_;
+    SecondIndexerT second_indexer_;
+    ThirdIndexerT third_indexer_;
+
+public:
+    ThreeOffsets_CombinedIndexer(const FirstIndexerT &first_indexer,
+                                 const SecondIndexerT &second_indexer,
+                                 const ThirdIndexerT &third_indexer)
+        : first_indexer_(first_indexer), second_indexer_(second_indexer),
+          third_indexer_(third_indexer)
+    {
+    }
+
+    ThreeOffsets<py::ssize_t> operator()(py::ssize_t gid) const
+    {
+        return ThreeOffsets<py::ssize_t>(
+            first_indexer_(gid), second_indexer_(gid), third_indexer_(gid));
+    }
+};
+
 template <typename indT> struct FourOffsets
 {
     FourOffsets()
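Note on the new indexer: the sketch below illustrates the composition pattern this patch enables,
mapping one flat work-item id to three element offsets, one per operand, so a kernel over strided
data can address all three arrays. The names here (index_t, Strided1D, ThreeOffsetsCombined) are
simplified stand-ins written for this note, not dpctl types; the actual struct returns
ThreeOffsets<py::ssize_t> and is instantiated with the existing indexer functors from
offset_utils.hpp.

// Standalone sketch (simplified stand-ins, not dpctl code): a combined
// indexer maps one flat work-item id to three per-operand element offsets.
#include <array>
#include <cstddef>
#include <iostream>

using index_t = std::ptrdiff_t; // stands in for py::ssize_t

// Trivial 1-D strided indexer: flat id -> element offset within one array.
struct Strided1D
{
    index_t stride;
    index_t offset;
    index_t operator()(index_t gid) const { return offset + gid * stride; }
};

// Composes three per-operand indexers; the real struct returns
// ThreeOffsets<py::ssize_t> rather than std::array.
template <typename FirstT, typename SecondT, typename ThirdT>
struct ThreeOffsetsCombined
{
    FirstT first_;
    SecondT second_;
    ThirdT third_;

    std::array<index_t, 3> operator()(index_t gid) const
    {
        return {first_(gid), second_(gid), third_(gid)};
    }
};

int main()
{
    // First operand read contiguously, second with stride 2, result contiguous.
    ThreeOffsetsCombined<Strided1D, Strided1D, Strided1D> idx{
        {1, 0}, {2, 0}, {1, 0}};
    const auto offs = idx(5); // offsets used by work-item 5
    std::cout << offs[0] << ' ' << offs[1] << ' ' << offs[2] << '\n'; // 5 10 5
}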
From 7cb2a532a37d607bfee01b94f4db52be85914202 Mon Sep 17 00:00:00 2001
From: Nikita Grigorian
Date: Tue, 2 Jan 2024 11:28:33 -0800
Subject: [PATCH 02/48] Remove unused `elementwise_functions.cpp`

---
 .../source/elementwise_functions.cpp          | 5155 -----------------
 1 file changed, 5155 deletions(-)
 delete mode 100644 dpctl/tensor/libtensor/source/elementwise_functions.cpp

diff --git a/dpctl/tensor/libtensor/source/elementwise_functions.cpp b/dpctl/tensor/libtensor/source/elementwise_functions.cpp
deleted file mode 100644
index 9ab7c0807c..0000000000
--- a/dpctl/tensor/libtensor/source/elementwise_functions.cpp
+++ /dev/null
@@ -1,5155 +0,0 @@
-//===----------- Implementation of _tensor_impl module ---------*-C++-*-/===//
-//
-// Data Parallel Control (dpctl)
-//
-// Copyright 2020-2023 Intel Corporation
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//    http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-//
-//===----------------------------------------------------------------------===//
-///
-/// \file
-/// This file defines functions of dpctl.tensor._tensor_impl extensions,
-/// specifically functions for elementwise operations.
-//===----------------------------------------------------------------------===//

[The remaining deleted lines of elementwise_functions.cpp (roughly 5,100 of the 5,155 removed) are
elided here. They comprised the per-operation kernel headers (kernels/elementwise_functions/abs.hpp
through trunc.hpp), the helpers _dtype_from_typenum and _result_typeid, and one
populate_*_dispatch_vectors()/populate_*_dispatch_tables() function per elementwise operation
(U01 ABS, U02 ACOS, U03 ACOSH, B01 ADD, ..., B21 POW, PROJ, ...), each filling static dispatch
vectors/tables of contiguous and strided implementation function pointers, plus matrix/row
broadcast and in-place variants for the binary operations, via DispatchVectorBuilder and
DispatchTableBuilder.]
dvb1.populate_dispatch_vector(proj_contig_dispatch_vector); - - using fn_ns::ProjStridedFactory; - DispatchVectorBuilder - dvb2; - dvb2.populate_dispatch_vector(proj_strided_dispatch_vector); - - using fn_ns::ProjTypeMapFactory; - DispatchVectorBuilder dvb3; - dvb3.populate_dispatch_vector(proj_output_typeid_vector); -} -} // namespace impl - -// U27: ==== REAL (x) -namespace impl -{ - -namespace real_fn_ns = dpctl::tensor::kernels::real; - -static unary_contig_impl_fn_ptr_t real_contig_dispatch_vector[td_ns::num_types]; -static int real_output_typeid_vector[td_ns::num_types]; -static unary_strided_impl_fn_ptr_t - real_strided_dispatch_vector[td_ns::num_types]; - -void populate_real_dispatch_vectors(void) -{ - using namespace td_ns; - namespace fn_ns = real_fn_ns; - - using fn_ns::RealContigFactory; - DispatchVectorBuilder - dvb1; - dvb1.populate_dispatch_vector(real_contig_dispatch_vector); - - using fn_ns::RealStridedFactory; - DispatchVectorBuilder - dvb2; - dvb2.populate_dispatch_vector(real_strided_dispatch_vector); - - using fn_ns::RealTypeMapFactory; - DispatchVectorBuilder dvb3; - dvb3.populate_dispatch_vector(real_output_typeid_vector); -} -} // namespace impl - -// B22: ==== REMAINDER (x1, x2) -namespace impl -{ - -namespace remainder_fn_ns = dpctl::tensor::kernels::remainder; - -static binary_contig_impl_fn_ptr_t - remainder_contig_dispatch_table[td_ns::num_types][td_ns::num_types]; -static int remainder_output_id_table[td_ns::num_types][td_ns::num_types]; - -static binary_strided_impl_fn_ptr_t - remainder_strided_dispatch_table[td_ns::num_types][td_ns::num_types]; - -void populate_remainder_dispatch_tables(void) -{ - using namespace td_ns; - namespace fn_ns = remainder_fn_ns; - - // which input types are supported, and what is the type of the result - using fn_ns::RemainderTypeMapFactory; - DispatchTableBuilder dtb1; - dtb1.populate_dispatch_table(remainder_output_id_table); - - // function pointers for operation on general strided arrays - using fn_ns::RemainderStridedFactory; - DispatchTableBuilder - dtb2; - dtb2.populate_dispatch_table(remainder_strided_dispatch_table); - - // function pointers for operation on contiguous inputs and output - using fn_ns::RemainderContigFactory; - DispatchTableBuilder - dtb3; - dtb3.populate_dispatch_table(remainder_contig_dispatch_table); -} - -} // namespace impl - -// U28: ==== ROUND (x) -namespace impl -{ - -namespace round_fn_ns = dpctl::tensor::kernels::round; - -static unary_contig_impl_fn_ptr_t - round_contig_dispatch_vector[td_ns::num_types]; -static int round_output_typeid_vector[td_ns::num_types]; -static unary_strided_impl_fn_ptr_t - round_strided_dispatch_vector[td_ns::num_types]; - -void populate_round_dispatch_vectors(void) -{ - using namespace td_ns; - namespace fn_ns = round_fn_ns; - - using fn_ns::RoundContigFactory; - DispatchVectorBuilder - dvb1; - dvb1.populate_dispatch_vector(round_contig_dispatch_vector); - - using fn_ns::RoundStridedFactory; - DispatchVectorBuilder - dvb2; - dvb2.populate_dispatch_vector(round_strided_dispatch_vector); - - using fn_ns::RoundTypeMapFactory; - DispatchVectorBuilder dvb3; - dvb3.populate_dispatch_vector(round_output_typeid_vector); -} - -} // namespace impl - -// U29: ==== SIGN (x) -namespace impl -{ - -namespace sign_fn_ns = dpctl::tensor::kernels::sign; - -static unary_contig_impl_fn_ptr_t sign_contig_dispatch_vector[td_ns::num_types]; -static int sign_output_typeid_vector[td_ns::num_types]; -static unary_strided_impl_fn_ptr_t - sign_strided_dispatch_vector[td_ns::num_types]; - -void 
populate_sign_dispatch_vectors(void) -{ - using namespace td_ns; - namespace fn_ns = sign_fn_ns; - - using fn_ns::SignContigFactory; - DispatchVectorBuilder - dvb1; - dvb1.populate_dispatch_vector(sign_contig_dispatch_vector); - - using fn_ns::SignStridedFactory; - DispatchVectorBuilder - dvb2; - dvb2.populate_dispatch_vector(sign_strided_dispatch_vector); - - using fn_ns::SignTypeMapFactory; - DispatchVectorBuilder dvb3; - dvb3.populate_dispatch_vector(sign_output_typeid_vector); -} - -} // namespace impl - -// ==== SIGNBIT (x) -namespace impl -{ - -namespace signbit_fn_ns = dpctl::tensor::kernels::signbit; - -static unary_contig_impl_fn_ptr_t - signbit_contig_dispatch_vector[td_ns::num_types]; -static int signbit_output_typeid_vector[td_ns::num_types]; -static unary_strided_impl_fn_ptr_t - signbit_strided_dispatch_vector[td_ns::num_types]; - -void populate_signbit_dispatch_vectors(void) -{ - using namespace td_ns; - namespace fn_ns = signbit_fn_ns; - - using fn_ns::SignbitContigFactory; - DispatchVectorBuilder - dvb1; - dvb1.populate_dispatch_vector(signbit_contig_dispatch_vector); - - using fn_ns::SignbitStridedFactory; - DispatchVectorBuilder - dvb2; - dvb2.populate_dispatch_vector(signbit_strided_dispatch_vector); - - using fn_ns::SignbitTypeMapFactory; - DispatchVectorBuilder dvb3; - dvb3.populate_dispatch_vector(signbit_output_typeid_vector); -} - -} // namespace impl - -// U30: ==== SIN (x) -namespace impl -{ - -namespace sin_fn_ns = dpctl::tensor::kernels::sin; - -static unary_contig_impl_fn_ptr_t sin_contig_dispatch_vector[td_ns::num_types]; -static int sin_output_typeid_vector[td_ns::num_types]; -static unary_strided_impl_fn_ptr_t - sin_strided_dispatch_vector[td_ns::num_types]; - -void populate_sin_dispatch_vectors(void) -{ - using namespace td_ns; - namespace fn_ns = sin_fn_ns; - - using fn_ns::SinContigFactory; - DispatchVectorBuilder - dvb1; - dvb1.populate_dispatch_vector(sin_contig_dispatch_vector); - - using fn_ns::SinStridedFactory; - DispatchVectorBuilder - dvb2; - dvb2.populate_dispatch_vector(sin_strided_dispatch_vector); - - using fn_ns::SinTypeMapFactory; - DispatchVectorBuilder dvb3; - dvb3.populate_dispatch_vector(sin_output_typeid_vector); -} - -} // namespace impl - -// U31: ==== SINH (x) -namespace impl -{ - -namespace sinh_fn_ns = dpctl::tensor::kernels::sinh; - -static unary_contig_impl_fn_ptr_t sinh_contig_dispatch_vector[td_ns::num_types]; -static int sinh_output_typeid_vector[td_ns::num_types]; -static unary_strided_impl_fn_ptr_t - sinh_strided_dispatch_vector[td_ns::num_types]; - -void populate_sinh_dispatch_vectors(void) -{ - using namespace td_ns; - namespace fn_ns = sinh_fn_ns; - - using fn_ns::SinhContigFactory; - DispatchVectorBuilder - dvb1; - dvb1.populate_dispatch_vector(sinh_contig_dispatch_vector); - - using fn_ns::SinhStridedFactory; - DispatchVectorBuilder - dvb2; - dvb2.populate_dispatch_vector(sinh_strided_dispatch_vector); - - using fn_ns::SinhTypeMapFactory; - DispatchVectorBuilder dvb3; - dvb3.populate_dispatch_vector(sinh_output_typeid_vector); -} - -} // namespace impl - -// U32: ==== SQUARE (x) -namespace impl -{ - -namespace square_fn_ns = dpctl::tensor::kernels::square; - -static unary_contig_impl_fn_ptr_t - square_contig_dispatch_vector[td_ns::num_types]; -static int square_output_typeid_vector[td_ns::num_types]; -static unary_strided_impl_fn_ptr_t - square_strided_dispatch_vector[td_ns::num_types]; - -void populate_square_dispatch_vectors(void) -{ - using namespace td_ns; - namespace fn_ns = square_fn_ns; - - using 
fn_ns::SquareContigFactory; - DispatchVectorBuilder - dvb1; - dvb1.populate_dispatch_vector(square_contig_dispatch_vector); - - using fn_ns::SquareStridedFactory; - DispatchVectorBuilder - dvb2; - dvb2.populate_dispatch_vector(square_strided_dispatch_vector); - - using fn_ns::SquareTypeMapFactory; - DispatchVectorBuilder dvb3; - dvb3.populate_dispatch_vector(square_output_typeid_vector); -} - -} // namespace impl - -// U33: ==== SQRT (x) -namespace impl -{ - -namespace sqrt_fn_ns = dpctl::tensor::kernels::sqrt; - -static unary_contig_impl_fn_ptr_t sqrt_contig_dispatch_vector[td_ns::num_types]; -static int sqrt_output_typeid_vector[td_ns::num_types]; -static unary_strided_impl_fn_ptr_t - sqrt_strided_dispatch_vector[td_ns::num_types]; - -void populate_sqrt_dispatch_vectors(void) -{ - using namespace td_ns; - namespace fn_ns = sqrt_fn_ns; - - using fn_ns::SqrtContigFactory; - DispatchVectorBuilder - dvb1; - dvb1.populate_dispatch_vector(sqrt_contig_dispatch_vector); - - using fn_ns::SqrtStridedFactory; - DispatchVectorBuilder - dvb2; - dvb2.populate_dispatch_vector(sqrt_strided_dispatch_vector); - - using fn_ns::SqrtTypeMapFactory; - DispatchVectorBuilder dvb3; - dvb3.populate_dispatch_vector(sqrt_output_typeid_vector); -} - -} // namespace impl - -// B23: ==== SUBTRACT (x1, x2) -namespace impl -{ -namespace subtract_fn_ns = dpctl::tensor::kernels::subtract; - -static binary_contig_impl_fn_ptr_t - subtract_contig_dispatch_table[td_ns::num_types][td_ns::num_types]; -static int subtract_output_id_table[td_ns::num_types][td_ns::num_types]; - -static binary_strided_impl_fn_ptr_t - subtract_strided_dispatch_table[td_ns::num_types][td_ns::num_types]; - -// sub(matrix, row) -static binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t - subtract_contig_matrix_contig_row_broadcast_dispatch_table - [td_ns::num_types][td_ns::num_types]; - -// sub(row, matrix) -static binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t - subtract_contig_row_contig_matrix_broadcast_dispatch_table - [td_ns::num_types][td_ns::num_types]; - -static binary_inplace_contig_impl_fn_ptr_t - subtract_inplace_contig_dispatch_table[td_ns::num_types][td_ns::num_types]; -static binary_inplace_strided_impl_fn_ptr_t - subtract_inplace_strided_dispatch_table[td_ns::num_types][td_ns::num_types]; -static binary_inplace_row_matrix_broadcast_impl_fn_ptr_t - subtract_inplace_row_matrix_dispatch_table[td_ns::num_types] - [td_ns::num_types]; - -void populate_subtract_dispatch_tables(void) -{ - using namespace td_ns; - namespace fn_ns = subtract_fn_ns; - - // which input types are supported, and what is the type of the result - using fn_ns::SubtractTypeMapFactory; - DispatchTableBuilder dtb1; - dtb1.populate_dispatch_table(subtract_output_id_table); - - // function pointers for operation on general strided arrays - using fn_ns::SubtractStridedFactory; - DispatchTableBuilder - dtb2; - dtb2.populate_dispatch_table(subtract_strided_dispatch_table); - - // function pointers for operation on contiguous inputs and output - using fn_ns::SubtractContigFactory; - DispatchTableBuilder - dtb3; - dtb3.populate_dispatch_table(subtract_contig_dispatch_table); - - // function pointers for operation on contiguous matrix, contiguous row - // with contiguous matrix output - using fn_ns::SubtractContigMatrixContigRowBroadcastFactory; - DispatchTableBuilder< - binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t, - SubtractContigMatrixContigRowBroadcastFactory, num_types> - dtb4; - dtb4.populate_dispatch_table( - 
subtract_contig_matrix_contig_row_broadcast_dispatch_table); - - // function pointers for operation on contiguous row, contiguous matrix - // with contiguous matrix output - using fn_ns::SubtractContigRowContigMatrixBroadcastFactory; - DispatchTableBuilder< - binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t, - SubtractContigRowContigMatrixBroadcastFactory, num_types> - dtb5; - dtb5.populate_dispatch_table( - subtract_contig_row_contig_matrix_broadcast_dispatch_table); - - // function pointers for inplace operation on general strided arrays - using fn_ns::SubtractInplaceStridedFactory; - DispatchTableBuilder - dtb6; - dtb6.populate_dispatch_table(subtract_inplace_strided_dispatch_table); - - // function pointers for inplace operation on contiguous inputs and output - using fn_ns::SubtractInplaceContigFactory; - DispatchTableBuilder - dtb7; - dtb7.populate_dispatch_table(subtract_inplace_contig_dispatch_table); - - // function pointers for inplace operation on contiguous matrix - // and contiguous row - using fn_ns::SubtractInplaceRowMatrixBroadcastFactory; - DispatchTableBuilder - dtb8; - dtb8.populate_dispatch_table(subtract_inplace_row_matrix_dispatch_table); -}; - -} // namespace impl - -// U34: ==== TAN (x) -namespace impl -{ - -namespace tan_fn_ns = dpctl::tensor::kernels::tan; - -static unary_contig_impl_fn_ptr_t tan_contig_dispatch_vector[td_ns::num_types]; -static int tan_output_typeid_vector[td_ns::num_types]; -static unary_strided_impl_fn_ptr_t - tan_strided_dispatch_vector[td_ns::num_types]; - -void populate_tan_dispatch_vectors(void) -{ - using namespace td_ns; - namespace fn_ns = tan_fn_ns; - - using fn_ns::TanContigFactory; - DispatchVectorBuilder - dvb1; - dvb1.populate_dispatch_vector(tan_contig_dispatch_vector); - - using fn_ns::TanStridedFactory; - DispatchVectorBuilder - dvb2; - dvb2.populate_dispatch_vector(tan_strided_dispatch_vector); - - using fn_ns::TanTypeMapFactory; - DispatchVectorBuilder dvb3; - dvb3.populate_dispatch_vector(tan_output_typeid_vector); -} - -} // namespace impl - -// U35: ==== TANH (x) -namespace impl -{ - -namespace tanh_fn_ns = dpctl::tensor::kernels::tanh; - -static unary_contig_impl_fn_ptr_t tanh_contig_dispatch_vector[td_ns::num_types]; -static int tanh_output_typeid_vector[td_ns::num_types]; -static unary_strided_impl_fn_ptr_t - tanh_strided_dispatch_vector[td_ns::num_types]; - -void populate_tanh_dispatch_vectors(void) -{ - using namespace td_ns; - namespace fn_ns = tanh_fn_ns; - - using fn_ns::TanhContigFactory; - DispatchVectorBuilder - dvb1; - dvb1.populate_dispatch_vector(tanh_contig_dispatch_vector); - - using fn_ns::TanhStridedFactory; - DispatchVectorBuilder - dvb2; - dvb2.populate_dispatch_vector(tanh_strided_dispatch_vector); - - using fn_ns::TanhTypeMapFactory; - DispatchVectorBuilder dvb3; - dvb3.populate_dispatch_vector(tanh_output_typeid_vector); -} - -} // namespace impl - -// U36: ==== TRUNC (x) -namespace impl -{ - -namespace trunc_fn_ns = dpctl::tensor::kernels::trunc; - -static unary_contig_impl_fn_ptr_t - trunc_contig_dispatch_vector[td_ns::num_types]; -static int trunc_output_typeid_vector[td_ns::num_types]; -static unary_strided_impl_fn_ptr_t - trunc_strided_dispatch_vector[td_ns::num_types]; - -void populate_trunc_dispatch_vectors(void) -{ - using namespace td_ns; - namespace fn_ns = trunc_fn_ns; - - using fn_ns::TruncContigFactory; - DispatchVectorBuilder - dvb1; - dvb1.populate_dispatch_vector(trunc_contig_dispatch_vector); - - using fn_ns::TruncStridedFactory; - DispatchVectorBuilder - dvb2; - 
dvb2.populate_dispatch_vector(trunc_strided_dispatch_vector); - - using fn_ns::TruncTypeMapFactory; - DispatchVectorBuilder dvb3; - dvb3.populate_dispatch_vector(trunc_output_typeid_vector); -} - -} // namespace impl - -// B24: ==== HYPOT (x1, x2) -namespace impl -{ -namespace hypot_fn_ns = dpctl::tensor::kernels::hypot; - -static binary_contig_impl_fn_ptr_t - hypot_contig_dispatch_table[td_ns::num_types][td_ns::num_types]; -static int hypot_output_id_table[td_ns::num_types][td_ns::num_types]; - -static binary_strided_impl_fn_ptr_t - hypot_strided_dispatch_table[td_ns::num_types][td_ns::num_types]; - -void populate_hypot_dispatch_tables(void) -{ - using namespace td_ns; - namespace fn_ns = hypot_fn_ns; - - // which input types are supported, and what is the type of the result - using fn_ns::HypotTypeMapFactory; - DispatchTableBuilder dtb1; - dtb1.populate_dispatch_table(hypot_output_id_table); - - // function pointers for operation on general strided arrays - using fn_ns::HypotStridedFactory; - DispatchTableBuilder - dtb2; - dtb2.populate_dispatch_table(hypot_strided_dispatch_table); - - // function pointers for operation on contiguous inputs and output - using fn_ns::HypotContigFactory; - DispatchTableBuilder - dtb3; - dtb3.populate_dispatch_table(hypot_contig_dispatch_table); -}; - -} // namespace impl - -// U37: ==== CBRT (x) -namespace impl -{ - -namespace cbrt_fn_ns = dpctl::tensor::kernels::cbrt; - -static unary_contig_impl_fn_ptr_t cbrt_contig_dispatch_vector[td_ns::num_types]; -static int cbrt_output_typeid_vector[td_ns::num_types]; -static unary_strided_impl_fn_ptr_t - cbrt_strided_dispatch_vector[td_ns::num_types]; - -void populate_cbrt_dispatch_vectors(void) -{ - using namespace td_ns; - namespace fn_ns = cbrt_fn_ns; - - using fn_ns::CbrtContigFactory; - DispatchVectorBuilder - dvb1; - dvb1.populate_dispatch_vector(cbrt_contig_dispatch_vector); - - using fn_ns::CbrtStridedFactory; - DispatchVectorBuilder - dvb2; - dvb2.populate_dispatch_vector(cbrt_strided_dispatch_vector); - - using fn_ns::CbrtTypeMapFactory; - DispatchVectorBuilder dvb3; - dvb3.populate_dispatch_vector(cbrt_output_typeid_vector); -} - -} // namespace impl - -// B24: ==== COPYSIGN (x1, x2) -namespace impl -{ -namespace copysign_fn_ns = dpctl::tensor::kernels::copysign; - -static binary_contig_impl_fn_ptr_t - copysign_contig_dispatch_table[td_ns::num_types][td_ns::num_types]; -static int copysign_output_id_table[td_ns::num_types][td_ns::num_types]; - -static binary_strided_impl_fn_ptr_t - copysign_strided_dispatch_table[td_ns::num_types][td_ns::num_types]; - -void populate_copysign_dispatch_tables(void) -{ - using namespace td_ns; - namespace fn_ns = copysign_fn_ns; - - // which input types are supported, and what is the type of the result - using fn_ns::CopysignTypeMapFactory; - DispatchTableBuilder dtb1; - dtb1.populate_dispatch_table(copysign_output_id_table); - - // function pointers for operation on general strided arrays - using fn_ns::CopysignStridedFactory; - DispatchTableBuilder - dtb2; - dtb2.populate_dispatch_table(copysign_strided_dispatch_table); - - // function pointers for operation on contiguous inputs and output - using fn_ns::CopysignContigFactory; - DispatchTableBuilder - dtb3; - dtb3.populate_dispatch_table(copysign_contig_dispatch_table); -}; - -} // namespace impl - -// U38: ==== EXP2 (x) -namespace impl -{ - -namespace exp2_fn_ns = dpctl::tensor::kernels::exp2; - -static unary_contig_impl_fn_ptr_t exp2_contig_dispatch_vector[td_ns::num_types]; -static int 
exp2_output_typeid_vector[td_ns::num_types]; -static unary_strided_impl_fn_ptr_t - exp2_strided_dispatch_vector[td_ns::num_types]; - -void populate_exp2_dispatch_vectors(void) -{ - using namespace td_ns; - namespace fn_ns = exp2_fn_ns; - - using fn_ns::Exp2ContigFactory; - DispatchVectorBuilder - dvb1; - dvb1.populate_dispatch_vector(exp2_contig_dispatch_vector); - - using fn_ns::Exp2StridedFactory; - DispatchVectorBuilder - dvb2; - dvb2.populate_dispatch_vector(exp2_strided_dispatch_vector); - - using fn_ns::Exp2TypeMapFactory; - DispatchVectorBuilder dvb3; - dvb3.populate_dispatch_vector(exp2_output_typeid_vector); -} - -} // namespace impl - -// U39: ==== RSQRT (x) -namespace impl -{ - -namespace rsqrt_fn_ns = dpctl::tensor::kernels::rsqrt; - -static unary_contig_impl_fn_ptr_t - rsqrt_contig_dispatch_vector[td_ns::num_types]; -static int rsqrt_output_typeid_vector[td_ns::num_types]; -static unary_strided_impl_fn_ptr_t - rsqrt_strided_dispatch_vector[td_ns::num_types]; - -void populate_rsqrt_dispatch_vectors(void) -{ - using namespace td_ns; - namespace fn_ns = rsqrt_fn_ns; - - using fn_ns::RsqrtContigFactory; - DispatchVectorBuilder - dvb1; - dvb1.populate_dispatch_vector(rsqrt_contig_dispatch_vector); - - using fn_ns::RsqrtStridedFactory; - DispatchVectorBuilder - dvb2; - dvb2.populate_dispatch_vector(rsqrt_strided_dispatch_vector); - - using fn_ns::RsqrtTypeMapFactory; - DispatchVectorBuilder dvb3; - dvb3.populate_dispatch_vector(rsqrt_output_typeid_vector); -} - -} // namespace impl - -// ========================================================================================== -// // - -namespace py = pybind11; - -void init_elementwise_functions(py::module_ m) -{ - using arrayT = dpctl::tensor::usm_ndarray; - using event_vecT = std::vector; - - // U01: ==== ABS (x) - { - impl::populate_abs_dispatch_vectors(); - using impl::abs_contig_dispatch_vector; - using impl::abs_output_typeid_vector; - using impl::abs_strided_dispatch_vector; - - auto abs_pyapi = [&](const arrayT &src, const arrayT &dst, - sycl::queue &exec_q, - const event_vecT &depends = {}) { - return py_unary_ufunc( - src, dst, exec_q, depends, abs_output_typeid_vector, - abs_contig_dispatch_vector, abs_strided_dispatch_vector); - }; - m.def("_abs", abs_pyapi, "", py::arg("src"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - - auto abs_result_type_pyapi = [&](const py::dtype &dtype) { - return py_unary_ufunc_result_type(dtype, abs_output_typeid_vector); - }; - m.def("_abs_result_type", abs_result_type_pyapi); - } - - // U02: ==== ACOS (x) - { - impl::populate_acos_dispatch_vectors(); - using impl::acos_contig_dispatch_vector; - using impl::acos_output_typeid_vector; - using impl::acos_strided_dispatch_vector; - - auto acos_pyapi = [&](const arrayT &src, const arrayT &dst, - sycl::queue &exec_q, - const event_vecT &depends = {}) { - return py_unary_ufunc( - src, dst, exec_q, depends, acos_output_typeid_vector, - acos_contig_dispatch_vector, acos_strided_dispatch_vector); - }; - m.def("_acos", acos_pyapi, "", py::arg("src"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - - auto acos_result_type_pyapi = [&](const py::dtype &dtype) { - return py_unary_ufunc_result_type(dtype, acos_output_typeid_vector); - }; - m.def("_acos_result_type", acos_result_type_pyapi); - } - - // U03: ===== ACOSH (x) - { - impl::populate_acosh_dispatch_vectors(); - using impl::acosh_contig_dispatch_vector; - using impl::acosh_output_typeid_vector; - using impl::acosh_strided_dispatch_vector; 
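The populate_* routines removed above all reduce to one registration idiom: a factory class template is instantiated for every supported type id, and the resulting function pointers (or nullptr for an unsupported type) are written into a dispatch vector or table indexed by type id. The following minimal, self-contained sketch illustrates that idiom; num_types, type_to_id, SinContigFactory and DispatchVectorBuilder here are simplified stand-ins for dpctl's td_ns utilities, not the real definitions.

    // Sketch of the factory/dispatch-vector registration pattern (stand-in types).
    #include <cmath>
    #include <cstddef>
    #include <cstdint>
    #include <iostream>
    #include <type_traits>

    namespace sketch
    {

    // tiny "type universe" standing in for td_ns::num_types
    constexpr std::size_t num_types = 3;

    template <typename T> struct type_to_id;
    template <> struct type_to_id<std::int32_t> { static constexpr int value = 0; };
    template <> struct type_to_id<float> { static constexpr int value = 1; };
    template <> struct type_to_id<double> { static constexpr int value = 2; };

    using unary_contig_impl_fn_ptr_t = void (*)(const void *, void *, std::size_t);

    // concrete kernel, one instantiation per supported type
    template <typename T>
    void sin_contig_impl(const void *src_p, void *dst_p, std::size_t n)
    {
        const T *src = static_cast<const T *>(src_p);
        T *dst = static_cast<T *>(dst_p);
        for (std::size_t i = 0; i < n; ++i) {
            dst[i] = std::sin(src[i]);
        }
    }

    // factory: for a given type, return either a kernel pointer or nullptr
    // (entries "may be nullptr" for unsupported types)
    template <typename T> struct SinContigFactory
    {
        unary_contig_impl_fn_ptr_t get() const
        {
            if constexpr (std::is_floating_point_v<T>) {
                return sin_contig_impl<T>;
            }
            else {
                return nullptr;
            }
        }
    };

    // builder: instantiate the factory for every type in the universe and
    // record the resulting pointers in a dispatch vector indexed by type id
    template <typename fnT, template <typename> class FactoryT>
    struct DispatchVectorBuilder
    {
        void populate_dispatch_vector(fnT vec[]) const
        {
            vec[type_to_id<std::int32_t>::value] = FactoryT<std::int32_t>{}.get();
            vec[type_to_id<float>::value] = FactoryT<float>{}.get();
            vec[type_to_id<double>::value] = FactoryT<double>{}.get();
        }
    };

    static unary_contig_impl_fn_ptr_t sin_contig_dispatch_vector[num_types];

    void populate_sin_dispatch_vectors(void)
    {
        DispatchVectorBuilder<unary_contig_impl_fn_ptr_t, SinContigFactory> dvb;
        dvb.populate_dispatch_vector(sin_contig_dispatch_vector);
    }

    } // namespace sketch

    int main()
    {
        sketch::populate_sin_dispatch_vectors();

        double in[2] = {0.0, 1.5707963267948966};
        double out[2] = {0.0, 0.0};
        auto fn =
            sketch::sin_contig_dispatch_vector[sketch::type_to_id<double>::value];
        if (fn) { // a nullptr entry would mean the dtype is unsupported
            fn(in, out, 2);
        }
        std::cout << out[0] << " " << out[1] << std::endl; // prints 0 1
        return 0;
    }

Looking up the kernel by the runtime type id and checking for nullptr is exactly what the table-based binary operations do as well, only with a two-dimensional table indexed by both input type ids.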
- - auto acosh_pyapi = [&](const arrayT &src, const arrayT &dst, - sycl::queue &exec_q, - const event_vecT &depends = {}) { - return py_unary_ufunc( - src, dst, exec_q, depends, acosh_output_typeid_vector, - acosh_contig_dispatch_vector, acosh_strided_dispatch_vector); - }; - m.def("_acosh", acosh_pyapi, "", py::arg("src"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - - auto acosh_result_type_pyapi = [&](const py::dtype &dtype) { - return py_unary_ufunc_result_type(dtype, - acosh_output_typeid_vector); - }; - m.def("_acosh_result_type", acosh_result_type_pyapi); - } - - // B01: ===== ADD (x1, x2) - { - impl::populate_add_dispatch_tables(); - using impl::add_contig_dispatch_table; - using impl::add_contig_matrix_contig_row_broadcast_dispatch_table; - using impl::add_contig_row_contig_matrix_broadcast_dispatch_table; - using impl::add_output_id_table; - using impl::add_strided_dispatch_table; - - auto add_pyapi = [&](const dpctl::tensor::usm_ndarray &src1, - const dpctl::tensor::usm_ndarray &src2, - const dpctl::tensor::usm_ndarray &dst, - sycl::queue &exec_q, - const std::vector &depends = {}) { - return py_binary_ufunc( - src1, src2, dst, exec_q, depends, add_output_id_table, - // function pointers to handle operation on contiguous arrays - // (pointers may be nullptr) - add_contig_dispatch_table, - // function pointers to handle operation on strided arrays (most - // general case) - add_strided_dispatch_table, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - add_contig_matrix_contig_row_broadcast_dispatch_table, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - add_contig_row_contig_matrix_broadcast_dispatch_table); - }; - auto add_result_type_pyapi = [&](const py::dtype &dtype1, - const py::dtype &dtype2) { - return py_binary_ufunc_result_type(dtype1, dtype2, - add_output_id_table); - }; - m.def("_add", add_pyapi, "", py::arg("src1"), py::arg("src2"), - py::arg("dst"), py::arg("sycl_queue"), - py::arg("depends") = py::list()); - m.def("_add_result_type", add_result_type_pyapi, ""); - - using impl::add_inplace_contig_dispatch_table; - using impl::add_inplace_row_matrix_dispatch_table; - using impl::add_inplace_strided_dispatch_table; - - auto add_inplace_pyapi = - [&](const dpctl::tensor::usm_ndarray &src, - const dpctl::tensor::usm_ndarray &dst, sycl::queue &exec_q, - const std::vector &depends = {}) { - return py_binary_inplace_ufunc( - src, dst, exec_q, depends, add_output_id_table, - // function pointers to handle inplace operation on - // contiguous arrays (pointers may be nullptr) - add_inplace_contig_dispatch_table, - // function pointers to handle inplace operation on strided - // arrays (most general case) - add_inplace_strided_dispatch_table, - // function pointers to handle inplace operation on - // c-contig matrix with c-contig row with broadcasting - // (may be nullptr) - add_inplace_row_matrix_dispatch_table); - }; - m.def("_add_inplace", add_inplace_pyapi, "", py::arg("lhs"), - py::arg("rhs"), py::arg("sycl_queue"), - py::arg("depends") = py::list()); - } - - // U04: ===== ASIN (x) - { - impl::populate_asin_dispatch_vectors(); - using impl::asin_contig_dispatch_vector; - using impl::asin_output_typeid_vector; - using impl::asin_strided_dispatch_vector; - - auto asin_pyapi = [&](const arrayT &src, const arrayT &dst, - sycl::queue &exec_q, - const event_vecT &depends = {}) { - return py_unary_ufunc( - 
src, dst, exec_q, depends, asin_output_typeid_vector, - asin_contig_dispatch_vector, asin_strided_dispatch_vector); - }; - m.def("_asin", asin_pyapi, "", py::arg("src"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - - auto asin_result_type_pyapi = [&](const py::dtype &dtype) { - return py_unary_ufunc_result_type(dtype, asin_output_typeid_vector); - }; - m.def("_asin_result_type", asin_result_type_pyapi); - } - - // U05: ===== ASINH (x) - { - impl::populate_asinh_dispatch_vectors(); - using impl::asinh_contig_dispatch_vector; - using impl::asinh_output_typeid_vector; - using impl::asinh_strided_dispatch_vector; - - auto asinh_pyapi = [&](const arrayT &src, const arrayT &dst, - sycl::queue &exec_q, - const event_vecT &depends = {}) { - return py_unary_ufunc( - src, dst, exec_q, depends, asinh_output_typeid_vector, - asinh_contig_dispatch_vector, asinh_strided_dispatch_vector); - }; - m.def("_asinh", asinh_pyapi, "", py::arg("src"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - - auto asinh_result_type_pyapi = [&](const py::dtype &dtype) { - return py_unary_ufunc_result_type(dtype, - asinh_output_typeid_vector); - }; - m.def("_asinh_result_type", asinh_result_type_pyapi); - } - - // U06: ===== ATAN (x) - { - impl::populate_atan_dispatch_vectors(); - using impl::atan_contig_dispatch_vector; - using impl::atan_output_typeid_vector; - using impl::atan_strided_dispatch_vector; - - auto atan_pyapi = [&](arrayT src, arrayT dst, sycl::queue &exec_q, - const event_vecT &depends = {}) { - return py_unary_ufunc( - src, dst, exec_q, depends, atan_output_typeid_vector, - atan_contig_dispatch_vector, atan_strided_dispatch_vector); - }; - m.def("_atan", atan_pyapi, "", py::arg("src"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - - auto atan_result_type_pyapi = [&](py::dtype dtype) { - return py_unary_ufunc_result_type(dtype, atan_output_typeid_vector); - }; - m.def("_atan_result_type", atan_result_type_pyapi); - } - - // B02: ===== ATAN2 (x1, x2) - { - impl::populate_atan2_dispatch_tables(); - using impl::atan2_contig_dispatch_table; - using impl::atan2_output_id_table; - using impl::atan2_strided_dispatch_table; - - auto atan2_pyapi = [&](const dpctl::tensor::usm_ndarray &src1, - const dpctl::tensor::usm_ndarray &src2, - const dpctl::tensor::usm_ndarray &dst, - sycl::queue &exec_q, - const std::vector &depends = {}) { - return py_binary_ufunc( - src1, src2, dst, exec_q, depends, atan2_output_id_table, - // function pointers to handle operation on contiguous arrays - // (pointers may be nullptr) - atan2_contig_dispatch_table, - // function pointers to handle operation on strided arrays (most - // general case) - atan2_strided_dispatch_table, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t>{}, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t>{}); - }; - auto atan2_result_type_pyapi = [&](const py::dtype &dtype1, - const py::dtype &dtype2) { - return py_binary_ufunc_result_type(dtype1, dtype2, - atan2_output_id_table); - }; - m.def("_atan2", atan2_pyapi, "", py::arg("src1"), py::arg("src2"), - py::arg("dst"), py::arg("sycl_queue"), - py::arg("depends") = py::list()); - m.def("_atan2_result_type", 
atan2_result_type_pyapi, ""); - } - - // U07: ===== ATANH (x) - { - impl::populate_atanh_dispatch_vectors(); - using impl::atanh_contig_dispatch_vector; - using impl::atanh_output_typeid_vector; - using impl::atanh_strided_dispatch_vector; - - auto atanh_pyapi = [&](const arrayT &src, const arrayT &dst, - sycl::queue &exec_q, - const event_vecT &depends = {}) { - return py_unary_ufunc( - src, dst, exec_q, depends, atanh_output_typeid_vector, - atanh_contig_dispatch_vector, atanh_strided_dispatch_vector); - }; - m.def("_atanh", atanh_pyapi, "", py::arg("src"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - - auto atanh_result_type_pyapi = [&](const py::dtype &dtype) { - return py_unary_ufunc_result_type(dtype, - atanh_output_typeid_vector); - }; - m.def("_atanh_result_type", atanh_result_type_pyapi); - } - - // B03: ===== BITWISE_AND (x1, x2) - { - impl::populate_bitwise_and_dispatch_tables(); - using impl::bitwise_and_contig_dispatch_table; - using impl::bitwise_and_output_id_table; - using impl::bitwise_and_strided_dispatch_table; - - auto bitwise_and_pyapi = [&](const dpctl::tensor::usm_ndarray &src1, - const dpctl::tensor::usm_ndarray &src2, - const dpctl::tensor::usm_ndarray &dst, - sycl::queue &exec_q, - const std::vector &depends = - {}) { - return py_binary_ufunc( - src1, src2, dst, exec_q, depends, bitwise_and_output_id_table, - // function pointers to handle operation on contiguous arrays - // (pointers may be nullptr) - bitwise_and_contig_dispatch_table, - // function pointers to handle operation on strided arrays (most - // general case) - bitwise_and_strided_dispatch_table, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t>{}, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t>{}); - }; - auto bitwise_and_result_type_pyapi = [&](const py::dtype &dtype1, - const py::dtype &dtype2) { - return py_binary_ufunc_result_type(dtype1, dtype2, - bitwise_and_output_id_table); - }; - m.def("_bitwise_and", bitwise_and_pyapi, "", py::arg("src1"), - py::arg("src2"), py::arg("dst"), py::arg("sycl_queue"), - py::arg("depends") = py::list()); - m.def("_bitwise_and_result_type", bitwise_and_result_type_pyapi, ""); - } - - // B04: ===== BITWISE_LEFT_SHIFT (x1, x2) - { - impl::populate_bitwise_left_shift_dispatch_tables(); - using impl::bitwise_left_shift_contig_dispatch_table; - using impl::bitwise_left_shift_output_id_table; - using impl::bitwise_left_shift_strided_dispatch_table; - - auto bitwise_left_shift_pyapi = [&](const dpctl::tensor::usm_ndarray - &src1, - const dpctl::tensor::usm_ndarray - &src2, - const dpctl::tensor::usm_ndarray - &dst, - sycl::queue &exec_q, - const std::vector - &depends = {}) { - return py_binary_ufunc( - src1, src2, dst, exec_q, depends, - bitwise_left_shift_output_id_table, - // function pointers to handle operation on contiguous arrays - // (pointers may be nullptr) - bitwise_left_shift_contig_dispatch_table, - // function pointers to handle operation on strided arrays (most - // general case) - bitwise_left_shift_strided_dispatch_table, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t>{}, - 
// function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t>{}); - }; - auto bitwise_left_shift_result_type_pyapi = - [&](const py::dtype &dtype1, const py::dtype &dtype2) { - return py_binary_ufunc_result_type( - dtype1, dtype2, bitwise_left_shift_output_id_table); - }; - m.def("_bitwise_left_shift", bitwise_left_shift_pyapi, "", - py::arg("src1"), py::arg("src2"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - m.def("_bitwise_left_shift_result_type", - bitwise_left_shift_result_type_pyapi, ""); - } - - // U08: ===== BITWISE_INVERT (x) - { - impl::populate_bitwise_invert_dispatch_vectors(); - using impl::bitwise_invert_contig_dispatch_vector; - using impl::bitwise_invert_output_typeid_vector; - using impl::bitwise_invert_strided_dispatch_vector; - - auto bitwise_invert_pyapi = [&](const arrayT &src, const arrayT &dst, - sycl::queue &exec_q, - const event_vecT &depends = {}) { - return py_unary_ufunc(src, dst, exec_q, depends, - bitwise_invert_output_typeid_vector, - bitwise_invert_contig_dispatch_vector, - bitwise_invert_strided_dispatch_vector); - }; - m.def("_bitwise_invert", bitwise_invert_pyapi, "", py::arg("src"), - py::arg("dst"), py::arg("sycl_queue"), - py::arg("depends") = py::list()); - - auto bitwise_invert_result_type_pyapi = [&](const py::dtype &dtype) { - return py_unary_ufunc_result_type( - dtype, bitwise_invert_output_typeid_vector); - }; - m.def("_bitwise_invert_result_type", bitwise_invert_result_type_pyapi); - } - - // B05: ===== BITWISE_OR (x1, x2) - { - impl::populate_bitwise_or_dispatch_tables(); - using impl::bitwise_or_contig_dispatch_table; - using impl::bitwise_or_output_id_table; - using impl::bitwise_or_strided_dispatch_table; - - auto bitwise_or_pyapi = [&](const dpctl::tensor::usm_ndarray &src1, - const dpctl::tensor::usm_ndarray &src2, - const dpctl::tensor::usm_ndarray &dst, - sycl::queue &exec_q, - const std::vector &depends = - {}) { - return py_binary_ufunc( - src1, src2, dst, exec_q, depends, bitwise_or_output_id_table, - // function pointers to handle operation on contiguous arrays - // (pointers may be nullptr) - bitwise_or_contig_dispatch_table, - // function pointers to handle operation on strided arrays (most - // general case) - bitwise_or_strided_dispatch_table, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t>{}, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t>{}); - }; - auto bitwise_or_result_type_pyapi = [&](const py::dtype &dtype1, - const py::dtype &dtype2) { - return py_binary_ufunc_result_type(dtype1, dtype2, - bitwise_or_output_id_table); - }; - m.def("_bitwise_or", bitwise_or_pyapi, "", py::arg("src1"), - py::arg("src2"), py::arg("dst"), py::arg("sycl_queue"), - py::arg("depends") = py::list()); - m.def("_bitwise_or_result_type", bitwise_or_result_type_pyapi, ""); - } - - // B06: ===== BITWISE_RIGHT_SHIFT (x1, x2) - { - impl::populate_bitwise_right_shift_dispatch_tables(); - using impl::bitwise_right_shift_contig_dispatch_table; - using impl::bitwise_right_shift_output_id_table; - using impl::bitwise_right_shift_strided_dispatch_table; - - auto bitwise_right_shift_pyapi = 
[&](const dpctl::tensor::usm_ndarray - &src1, - const dpctl::tensor::usm_ndarray - &src2, - const dpctl::tensor::usm_ndarray - &dst, - sycl::queue &exec_q, - const std::vector - &depends = {}) { - return py_binary_ufunc( - src1, src2, dst, exec_q, depends, - bitwise_right_shift_output_id_table, - // function pointers to handle operation on contiguous arrays - // (pointers may be nullptr) - bitwise_right_shift_contig_dispatch_table, - // function pointers to handle operation on strided arrays (most - // general case) - bitwise_right_shift_strided_dispatch_table, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t>{}, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t>{}); - }; - auto bitwise_right_shift_result_type_pyapi = - [&](const py::dtype &dtype1, const py::dtype &dtype2) { - return py_binary_ufunc_result_type( - dtype1, dtype2, bitwise_right_shift_output_id_table); - }; - m.def("_bitwise_right_shift", bitwise_right_shift_pyapi, "", - py::arg("src1"), py::arg("src2"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - m.def("_bitwise_right_shift_result_type", - bitwise_right_shift_result_type_pyapi, ""); - } - - // B07: ===== BITWISE_XOR (x1, x2) - { - impl::populate_bitwise_xor_dispatch_tables(); - using impl::bitwise_xor_contig_dispatch_table; - using impl::bitwise_xor_output_id_table; - using impl::bitwise_xor_strided_dispatch_table; - - auto bitwise_xor_pyapi = [&](const dpctl::tensor::usm_ndarray &src1, - const dpctl::tensor::usm_ndarray &src2, - const dpctl::tensor::usm_ndarray &dst, - sycl::queue &exec_q, - const std::vector &depends = - {}) { - return py_binary_ufunc( - src1, src2, dst, exec_q, depends, bitwise_xor_output_id_table, - // function pointers to handle operation on contiguous arrays - // (pointers may be nullptr) - bitwise_xor_contig_dispatch_table, - // function pointers to handle operation on strided arrays (most - // general case) - bitwise_xor_strided_dispatch_table, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t>{}, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t>{}); - }; - auto bitwise_xor_result_type_pyapi = [&](const py::dtype &dtype1, - const py::dtype &dtype2) { - return py_binary_ufunc_result_type(dtype1, dtype2, - bitwise_xor_output_id_table); - }; - m.def("_bitwise_xor", bitwise_xor_pyapi, "", py::arg("src1"), - py::arg("src2"), py::arg("dst"), py::arg("sycl_queue"), - py::arg("depends") = py::list()); - m.def("_bitwise_xor_result_type", bitwise_xor_result_type_pyapi, ""); - } - - // U09: ==== CEIL (x) - { - impl::populate_ceil_dispatch_vectors(); - using impl::ceil_contig_dispatch_vector; - using impl::ceil_output_typeid_vector; - using impl::ceil_strided_dispatch_vector; - - auto ceil_pyapi = [&](const arrayT &src, const arrayT &dst, - sycl::queue &exec_q, - const event_vecT &depends = {}) { - return py_unary_ufunc( - src, dst, exec_q, depends, ceil_output_typeid_vector, - ceil_contig_dispatch_vector, ceil_strided_dispatch_vector); - 
}; - m.def("_ceil", ceil_pyapi, "", py::arg("src"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - - auto ceil_result_type_pyapi = [&](const py::dtype &dtype) { - return py_unary_ufunc_result_type(dtype, ceil_output_typeid_vector); - }; - m.def("_ceil_result_type", ceil_result_type_pyapi); - } - - // U10: ==== CONJ (x) - { - impl::populate_conj_dispatch_vectors(); - using impl::conj_contig_dispatch_vector; - using impl::conj_output_typeid_vector; - using impl::conj_strided_dispatch_vector; - - auto conj_pyapi = [&](const arrayT &src, const arrayT &dst, - sycl::queue &exec_q, - const event_vecT &depends = {}) { - return py_unary_ufunc( - src, dst, exec_q, depends, conj_output_typeid_vector, - conj_contig_dispatch_vector, conj_strided_dispatch_vector); - }; - m.def("_conj", conj_pyapi, "", py::arg("src"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - - auto conj_result_type_pyapi = [&](const py::dtype &dtype) { - return py_unary_ufunc_result_type(dtype, conj_output_typeid_vector); - }; - m.def("_conj_result_type", conj_result_type_pyapi); - } - - // U11: ==== COS (x) - { - impl::populate_cos_dispatch_vectors(); - using impl::cos_contig_dispatch_vector; - using impl::cos_output_typeid_vector; - using impl::cos_strided_dispatch_vector; - - auto cos_pyapi = [&](const arrayT &src, const arrayT &dst, - sycl::queue &exec_q, - const event_vecT &depends = {}) { - return py_unary_ufunc( - src, dst, exec_q, depends, cos_output_typeid_vector, - cos_contig_dispatch_vector, cos_strided_dispatch_vector); - }; - m.def("_cos", cos_pyapi, "", py::arg("src"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - - auto cos_result_type_pyapi = [&](const py::dtype &dtype) { - return py_unary_ufunc_result_type(dtype, cos_output_typeid_vector); - }; - m.def("_cos_result_type", cos_result_type_pyapi); - } - - // U12: ==== COSH (x) - { - impl::populate_cosh_dispatch_vectors(); - using impl::cosh_contig_dispatch_vector; - using impl::cosh_output_typeid_vector; - using impl::cosh_strided_dispatch_vector; - - auto cosh_pyapi = [&](const arrayT &src, const arrayT &dst, - sycl::queue &exec_q, - const event_vecT &depends = {}) { - return py_unary_ufunc( - src, dst, exec_q, depends, cosh_output_typeid_vector, - cosh_contig_dispatch_vector, cosh_strided_dispatch_vector); - }; - m.def("_cosh", cosh_pyapi, "", py::arg("src"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - - auto cosh_result_type_pyapi = [&](const py::dtype &dtype) { - return py_unary_ufunc_result_type(dtype, cosh_output_typeid_vector); - }; - m.def("_cosh_result_type", cosh_result_type_pyapi); - } - - // B08: ==== DIVIDE (x1, x2) - { - impl::populate_true_divide_dispatch_tables(); - using impl::true_divide_contig_dispatch_table; - using impl:: - true_divide_contig_matrix_contig_row_broadcast_dispatch_table; - using impl:: - true_divide_contig_row_contig_matrix_broadcast_dispatch_table; - using impl::true_divide_output_id_table; - using impl::true_divide_strided_dispatch_table; - - auto divide_pyapi = [&](const dpctl::tensor::usm_ndarray &src1, - const dpctl::tensor::usm_ndarray &src2, - const dpctl::tensor::usm_ndarray &dst, - sycl::queue &exec_q, - const std::vector &depends = {}) { - return py_binary_ufunc( - src1, src2, dst, exec_q, depends, true_divide_output_id_table, - // function pointers to handle operation on contiguous arrays - // (pointers may be nullptr) - true_divide_contig_dispatch_table, - // function pointers to handle operation 
on strided arrays (most - // general case) - true_divide_strided_dispatch_table, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - true_divide_contig_matrix_contig_row_broadcast_dispatch_table, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - true_divide_contig_row_contig_matrix_broadcast_dispatch_table); - }; - auto divide_result_type_pyapi = [&](const py::dtype &dtype1, - const py::dtype &dtype2) { - return py_binary_ufunc_result_type(dtype1, dtype2, - true_divide_output_id_table); - }; - m.def("_divide", divide_pyapi, "", py::arg("src1"), py::arg("src2"), - py::arg("dst"), py::arg("sycl_queue"), - py::arg("depends") = py::list()); - m.def("_divide_result_type", divide_result_type_pyapi, ""); - - using impl::true_divide_inplace_contig_dispatch_table; - using impl::true_divide_inplace_output_id_table; - using impl::true_divide_inplace_row_matrix_dispatch_table; - using impl::true_divide_inplace_strided_dispatch_table; - - auto divide_inplace_pyapi = - [&](const dpctl::tensor::usm_ndarray &src, - const dpctl::tensor::usm_ndarray &dst, sycl::queue &exec_q, - const std::vector &depends = {}) { - return py_binary_inplace_ufunc( - src, dst, exec_q, depends, - true_divide_inplace_output_id_table, - // function pointers to handle inplace operation on - // contiguous arrays (pointers may be nullptr) - true_divide_inplace_contig_dispatch_table, - // function pointers to handle inplace operation on strided - // arrays (most general case) - true_divide_inplace_strided_dispatch_table, - // function pointers to handle inplace operation on - // c-contig matrix with c-contig row with broadcasting - // (may be nullptr) - true_divide_inplace_row_matrix_dispatch_table); - }; - m.def("_divide_inplace", divide_inplace_pyapi, "", py::arg("lhs"), - py::arg("rhs"), py::arg("sycl_queue"), - py::arg("depends") = py::list()); - } - - // B09: ==== EQUAL (x1, x2) - { - impl::populate_equal_dispatch_tables(); - using impl::equal_contig_dispatch_table; - using impl::equal_output_id_table; - using impl::equal_strided_dispatch_table; - - auto equal_pyapi = [&](const dpctl::tensor::usm_ndarray &src1, - const dpctl::tensor::usm_ndarray &src2, - const dpctl::tensor::usm_ndarray &dst, - sycl::queue &exec_q, - const std::vector &depends = {}) { - return py_binary_ufunc( - src1, src2, dst, exec_q, depends, equal_output_id_table, - // function pointers to handle operation on contiguous arrays - // (pointers may be nullptr) - equal_contig_dispatch_table, - // function pointers to handle operation on strided arrays (most - // general case) - equal_strided_dispatch_table, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t>{}, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t>{}); - }; - auto equal_result_type_pyapi = [&](const py::dtype &dtype1, - const py::dtype &dtype2) { - return py_binary_ufunc_result_type(dtype1, dtype2, - equal_output_id_table); - }; - m.def("_equal", equal_pyapi, "", py::arg("src1"), py::arg("src2"), - py::arg("dst"), py::arg("sycl_queue"), - py::arg("depends") = py::list()); - m.def("_equal_result_type", equal_result_type_pyapi, ""); - } - - // U13: ==== EXP (x) - { - 
impl::populate_exp_dispatch_vectors(); - using impl::exp_contig_dispatch_vector; - using impl::exp_output_typeid_vector; - using impl::exp_strided_dispatch_vector; - - auto exp_pyapi = [&](const arrayT &src, const arrayT &dst, - sycl::queue &exec_q, - const event_vecT &depends = {}) { - return py_unary_ufunc( - src, dst, exec_q, depends, exp_output_typeid_vector, - exp_contig_dispatch_vector, exp_strided_dispatch_vector); - }; - m.def("_exp", exp_pyapi, "", py::arg("src"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - - auto exp_result_type_pyapi = [&](const py::dtype &dtype) { - return py_unary_ufunc_result_type(dtype, exp_output_typeid_vector); - }; - m.def("_exp_result_type", exp_result_type_pyapi); - } - - // U14: ==== EXPM1 (x) - { - impl::populate_expm1_dispatch_vectors(); - using impl::expm1_contig_dispatch_vector; - using impl::expm1_output_typeid_vector; - using impl::expm1_strided_dispatch_vector; - - auto expm1_pyapi = [&](const arrayT &src, const arrayT &dst, - sycl::queue &exec_q, - const event_vecT &depends = {}) { - return py_unary_ufunc( - src, dst, exec_q, depends, expm1_output_typeid_vector, - expm1_contig_dispatch_vector, expm1_strided_dispatch_vector); - }; - m.def("_expm1", expm1_pyapi, "", py::arg("src"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - - auto expm1_result_type_pyapi = [&](const py::dtype dtype) { - return py_unary_ufunc_result_type(dtype, - expm1_output_typeid_vector); - }; - m.def("_expm1_result_type", expm1_result_type_pyapi); - } - - // U15: ==== FLOOR (x) - { - impl::populate_floor_dispatch_vectors(); - using impl::floor_contig_dispatch_vector; - using impl::floor_output_typeid_vector; - using impl::floor_strided_dispatch_vector; - - auto floor_pyapi = [&](const arrayT &src, const arrayT &dst, - sycl::queue &exec_q, - const event_vecT &depends = {}) { - return py_unary_ufunc( - src, dst, exec_q, depends, floor_output_typeid_vector, - floor_contig_dispatch_vector, floor_strided_dispatch_vector); - }; - m.def("_floor", floor_pyapi, "", py::arg("src"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - - auto floor_result_type_pyapi = [&](const py::dtype &dtype) { - return py_unary_ufunc_result_type(dtype, - floor_output_typeid_vector); - }; - m.def("_floor_result_type", floor_result_type_pyapi); - } - - // B10: ==== FLOOR_DIVIDE (x1, x2) - { - impl::populate_floor_divide_dispatch_tables(); - using impl::floor_divide_contig_dispatch_table; - using impl::floor_divide_output_id_table; - using impl::floor_divide_strided_dispatch_table; - - auto floor_divide_pyapi = [&](const dpctl::tensor::usm_ndarray &src1, - const dpctl::tensor::usm_ndarray &src2, - const dpctl::tensor::usm_ndarray &dst, - sycl::queue &exec_q, - const std::vector &depends = - {}) { - return py_binary_ufunc( - src1, src2, dst, exec_q, depends, floor_divide_output_id_table, - // function pointers to handle operation on contiguous arrays - // (pointers may be nullptr) - floor_divide_contig_dispatch_table, - // function pointers to handle operation on strided arrays (most - // general case) - floor_divide_strided_dispatch_table, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t>{}, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - 
binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t>{}); - }; - auto floor_divide_result_type_pyapi = [&](const py::dtype &dtype1, - const py::dtype &dtype2) { - return py_binary_ufunc_result_type(dtype1, dtype2, - floor_divide_output_id_table); - }; - m.def("_floor_divide", floor_divide_pyapi, "", py::arg("src1"), - py::arg("src2"), py::arg("dst"), py::arg("sycl_queue"), - py::arg("depends") = py::list()); - m.def("_floor_divide_result_type", floor_divide_result_type_pyapi, ""); - - using impl::floor_divide_inplace_contig_dispatch_table; - using impl::floor_divide_inplace_strided_dispatch_table; - - auto floor_divide_inplace_pyapi = - [&](const dpctl::tensor::usm_ndarray &src, - const dpctl::tensor::usm_ndarray &dst, sycl::queue &exec_q, - const std::vector &depends = {}) { - return py_binary_inplace_ufunc( - src, dst, exec_q, depends, floor_divide_output_id_table, - // function pointers to handle inplace operation on - // contiguous arrays (pointers may be nullptr) - floor_divide_inplace_contig_dispatch_table, - // function pointers to handle inplace operation on strided - // arrays (most general case) - floor_divide_inplace_strided_dispatch_table, - // function pointers to handle inplace operation on - // c-contig matrix with c-contig row with broadcasting - // (may be nullptr) - td_ns::NullPtrTable< - binary_inplace_row_matrix_broadcast_impl_fn_ptr_t>{}); - }; - m.def("_floor_divide_inplace", floor_divide_inplace_pyapi, "", - py::arg("lhs"), py::arg("rhs"), py::arg("sycl_queue"), - py::arg("depends") = py::list()); - } - - // B11: ==== GREATER (x1, x2) - { - impl::populate_greater_dispatch_tables(); - using impl::greater_contig_dispatch_table; - using impl::greater_output_id_table; - using impl::greater_strided_dispatch_table; - - auto greater_pyapi = [&](const dpctl::tensor::usm_ndarray &src1, - const dpctl::tensor::usm_ndarray &src2, - const dpctl::tensor::usm_ndarray &dst, - sycl::queue &exec_q, - const std::vector &depends = {}) { - return py_binary_ufunc( - src1, src2, dst, exec_q, depends, greater_output_id_table, - // function pointers to handle operation on contiguous arrays - // (pointers may be nullptr) - greater_contig_dispatch_table, - // function pointers to handle operation on strided arrays (most - // general case) - greater_strided_dispatch_table, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t>{}, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t>{}); - }; - auto greater_result_type_pyapi = [&](const py::dtype &dtype1, - const py::dtype &dtype2) { - return py_binary_ufunc_result_type(dtype1, dtype2, - greater_output_id_table); - }; - m.def("_greater", greater_pyapi, "", py::arg("src1"), py::arg("src2"), - py::arg("dst"), py::arg("sycl_queue"), - py::arg("depends") = py::list()); - m.def("_greater_result_type", greater_result_type_pyapi, ""); - } - - // B12: ==== GREATER_EQUAL (x1, x2) - { - impl::populate_greater_equal_dispatch_tables(); - using impl::greater_equal_contig_dispatch_table; - using impl::greater_equal_output_id_table; - using impl::greater_equal_strided_dispatch_table; - - auto greater_equal_pyapi = [&](const dpctl::tensor::usm_ndarray &src1, - const dpctl::tensor::usm_ndarray &src2, - const dpctl::tensor::usm_ndarray &dst, - sycl::queue &exec_q, - const 
std::vector &depends = - {}) { - return py_binary_ufunc( - src1, src2, dst, exec_q, depends, greater_equal_output_id_table, - // function pointers to handle operation on contiguous arrays - // (pointers may be nullptr) - greater_equal_contig_dispatch_table, - // function pointers to handle operation on strided arrays (most - // general case) - greater_equal_strided_dispatch_table, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t>{}, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t>{}); - }; - auto greater_equal_result_type_pyapi = [&](const py::dtype &dtype1, - const py::dtype &dtype2) { - return py_binary_ufunc_result_type(dtype1, dtype2, - greater_equal_output_id_table); - }; - m.def("_greater_equal", greater_equal_pyapi, "", py::arg("src1"), - py::arg("src2"), py::arg("dst"), py::arg("sycl_queue"), - py::arg("depends") = py::list()); - m.def("_greater_equal_result_type", greater_equal_result_type_pyapi, - ""); - } - - // U16: ==== IMAG (x) - { - impl::populate_imag_dispatch_vectors(); - using impl::imag_contig_dispatch_vector; - using impl::imag_output_typeid_vector; - using impl::imag_strided_dispatch_vector; - - auto imag_pyapi = [&](const arrayT &src, const arrayT &dst, - sycl::queue &exec_q, - const event_vecT &depends = {}) { - return py_unary_ufunc( - src, dst, exec_q, depends, imag_output_typeid_vector, - imag_contig_dispatch_vector, imag_strided_dispatch_vector); - }; - m.def("_imag", imag_pyapi, "", py::arg("src"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - - auto imag_result_type_pyapi = [&](const py::dtype &dtype) { - return py_unary_ufunc_result_type(dtype, imag_output_typeid_vector); - }; - m.def("_imag_result_type", imag_result_type_pyapi); - } - - // U17: ==== ISFINITE (x) - { - impl::populate_isfinite_dispatch_vectors(); - - using impl::isfinite_contig_dispatch_vector; - using impl::isfinite_output_typeid_vector; - using impl::isfinite_strided_dispatch_vector; - auto isfinite_pyapi = - [&](const dpctl::tensor::usm_ndarray &src, - const dpctl::tensor::usm_ndarray &dst, sycl::queue &exec_q, - const std::vector &depends = {}) { - return py_unary_ufunc(src, dst, exec_q, depends, - isfinite_output_typeid_vector, - isfinite_contig_dispatch_vector, - isfinite_strided_dispatch_vector); - }; - auto isfinite_result_type_pyapi = [&](const py::dtype &dtype) { - return py_unary_ufunc_result_type(dtype, - isfinite_output_typeid_vector); - }; - m.def("_isfinite", isfinite_pyapi, "", py::arg("src"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - m.def("_isfinite_result_type", isfinite_result_type_pyapi, ""); - } - - // U18: ==== ISINF (x) - { - impl::populate_isinf_dispatch_vectors(); - - using impl::isinf_contig_dispatch_vector; - using impl::isinf_output_typeid_vector; - using impl::isinf_strided_dispatch_vector; - auto isinf_pyapi = [&](const dpctl::tensor::usm_ndarray &src, - const dpctl::tensor::usm_ndarray &dst, - sycl::queue &exec_q, - const std::vector &depends = {}) { - return py_unary_ufunc( - src, dst, exec_q, depends, isinf_output_typeid_vector, - isinf_contig_dispatch_vector, isinf_strided_dispatch_vector); - }; - auto isinf_result_type_pyapi = [&](const py::dtype &dtype) { - return 
py_unary_ufunc_result_type(dtype, - isinf_output_typeid_vector); - }; - m.def("_isinf", isinf_pyapi, "", py::arg("src"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - m.def("_isinf_result_type", isinf_result_type_pyapi, ""); - } - - // U19: ==== ISNAN (x) - { - impl::populate_isnan_dispatch_vectors(); - - using impl::isnan_contig_dispatch_vector; - using impl::isnan_output_typeid_vector; - using impl::isnan_strided_dispatch_vector; - auto isnan_pyapi = [&](const dpctl::tensor::usm_ndarray &src, - const dpctl::tensor::usm_ndarray &dst, - sycl::queue &exec_q, - const std::vector &depends = {}) { - return py_unary_ufunc( - src, dst, exec_q, depends, isnan_output_typeid_vector, - isnan_contig_dispatch_vector, isnan_strided_dispatch_vector); - }; - auto isnan_result_type_pyapi = [&](const py::dtype &dtype) { - return py_unary_ufunc_result_type(dtype, - isnan_output_typeid_vector); - }; - m.def("_isnan", isnan_pyapi, "", py::arg("src"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - m.def("_isnan_result_type", isnan_result_type_pyapi, ""); - } - - // B13: ==== LESS (x1, x2) - { - impl::populate_less_dispatch_tables(); - using impl::less_contig_dispatch_table; - using impl::less_output_id_table; - using impl::less_strided_dispatch_table; - - auto less_pyapi = [&](const dpctl::tensor::usm_ndarray &src1, - const dpctl::tensor::usm_ndarray &src2, - const dpctl::tensor::usm_ndarray &dst, - sycl::queue &exec_q, - const std::vector &depends = {}) { - return py_binary_ufunc( - src1, src2, dst, exec_q, depends, less_output_id_table, - // function pointers to handle operation on contiguous arrays - // (pointers may be nullptr) - less_contig_dispatch_table, - // function pointers to handle operation on strided arrays (most - // general case) - less_strided_dispatch_table, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t>{}, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t>{}); - }; - auto less_result_type_pyapi = [&](const py::dtype &dtype1, - const py::dtype &dtype2) { - return py_binary_ufunc_result_type(dtype1, dtype2, - less_output_id_table); - }; - m.def("_less", less_pyapi, "", py::arg("src1"), py::arg("src2"), - py::arg("dst"), py::arg("sycl_queue"), - py::arg("depends") = py::list()); - m.def("_less_result_type", less_result_type_pyapi, ""); - } - - // B14: ==== LESS_EQUAL (x1, x2) - { - impl::populate_less_equal_dispatch_tables(); - using impl::less_equal_contig_dispatch_table; - using impl::less_equal_output_id_table; - using impl::less_equal_strided_dispatch_table; - - auto less_equal_pyapi = [&](const dpctl::tensor::usm_ndarray &src1, - const dpctl::tensor::usm_ndarray &src2, - const dpctl::tensor::usm_ndarray &dst, - sycl::queue &exec_q, - const std::vector &depends = - {}) { - return py_binary_ufunc( - src1, src2, dst, exec_q, depends, less_equal_output_id_table, - // function pointers to handle operation on contiguous arrays - // (pointers may be nullptr) - less_equal_contig_dispatch_table, - // function pointers to handle operation on strided arrays (most - // general case) - less_equal_strided_dispatch_table, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - 
td_ns::NullPtrTable< - binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t>{}, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t>{}); - }; - auto less_equal_result_type_pyapi = [&](const py::dtype &dtype1, - const py::dtype &dtype2) { - return py_binary_ufunc_result_type(dtype1, dtype2, - less_equal_output_id_table); - }; - m.def("_less_equal", less_equal_pyapi, "", py::arg("src1"), - py::arg("src2"), py::arg("dst"), py::arg("sycl_queue"), - py::arg("depends") = py::list()); - m.def("_less_equal_result_type", less_equal_result_type_pyapi, ""); - } - - // U20: ==== LOG (x) - { - impl::populate_log_dispatch_vectors(); - using impl::log_contig_dispatch_vector; - using impl::log_output_typeid_vector; - using impl::log_strided_dispatch_vector; - - auto log_pyapi = [&](const arrayT &src, const arrayT &dst, - sycl::queue &exec_q, - const event_vecT &depends = {}) { - return py_unary_ufunc( - src, dst, exec_q, depends, log_output_typeid_vector, - log_contig_dispatch_vector, log_strided_dispatch_vector); - }; - m.def("_log", log_pyapi, "", py::arg("src"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - - auto log_result_type_pyapi = [&](const py::dtype &dtype) { - return py_unary_ufunc_result_type(dtype, log_output_typeid_vector); - }; - m.def("_log_result_type", log_result_type_pyapi); - } - - // U21: ==== LOG1P (x) - { - impl::populate_log1p_dispatch_vectors(); - using impl::log1p_contig_dispatch_vector; - using impl::log1p_output_typeid_vector; - using impl::log1p_strided_dispatch_vector; - - auto log1p_pyapi = [&](const arrayT &src, const arrayT &dst, - sycl::queue &exec_q, - const event_vecT &depends = {}) { - return py_unary_ufunc( - src, dst, exec_q, depends, log1p_output_typeid_vector, - log1p_contig_dispatch_vector, log1p_strided_dispatch_vector); - }; - m.def("_log1p", log1p_pyapi, "", py::arg("src"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - - auto log1p_result_type_pyapi = [&](const py::dtype &dtype) { - return py_unary_ufunc_result_type(dtype, - log1p_output_typeid_vector); - }; - m.def("_log1p_result_type", log1p_result_type_pyapi); - } - - // U22: ==== LOG2 (x) - { - impl::populate_log2_dispatch_vectors(); - - using impl::log2_contig_dispatch_vector; - using impl::log2_output_typeid_vector; - using impl::log2_strided_dispatch_vector; - auto log2_pyapi = [&](const dpctl::tensor::usm_ndarray &src, - const dpctl::tensor::usm_ndarray &dst, - sycl::queue &exec_q, - const std::vector &depends = {}) { - return py_unary_ufunc( - src, dst, exec_q, depends, log2_output_typeid_vector, - log2_contig_dispatch_vector, log2_strided_dispatch_vector); - }; - auto log2_result_type_pyapi = [&](const py::dtype &dtype) { - return py_unary_ufunc_result_type(dtype, log2_output_typeid_vector); - }; - m.def("_log2", log2_pyapi, "", py::arg("src"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - m.def("_log2_result_type", log2_result_type_pyapi, ""); - } - - // U23: ==== LOG10 (x) - { - impl::populate_log10_dispatch_vectors(); - - using impl::log10_contig_dispatch_vector; - using impl::log10_output_typeid_vector; - using impl::log10_strided_dispatch_vector; - auto log10_pyapi = [&](const dpctl::tensor::usm_ndarray &src, - const dpctl::tensor::usm_ndarray &dst, - sycl::queue &exec_q, - const std::vector &depends = {}) { - return py_unary_ufunc( - src, dst, exec_q, 
depends, log10_output_typeid_vector, - log10_contig_dispatch_vector, log10_strided_dispatch_vector); - }; - auto log10_result_type_pyapi = [&](const py::dtype &dtype) { - return py_unary_ufunc_result_type(dtype, - log10_output_typeid_vector); - }; - m.def("_log10", log10_pyapi, "", py::arg("src"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - m.def("_log10_result_type", log10_result_type_pyapi, ""); - } - - // B15: ==== LOGADDEXP (x1, x2) - { - impl::populate_logaddexp_dispatch_tables(); - using impl::logaddexp_contig_dispatch_table; - using impl::logaddexp_output_id_table; - using impl::logaddexp_strided_dispatch_table; - - auto logaddexp_pyapi = [&](const dpctl::tensor::usm_ndarray &src1, - const dpctl::tensor::usm_ndarray &src2, - const dpctl::tensor::usm_ndarray &dst, - sycl::queue &exec_q, - const std::vector &depends = - {}) { - return py_binary_ufunc( - src1, src2, dst, exec_q, depends, logaddexp_output_id_table, - // function pointers to handle operation on contiguous arrays - // (pointers may be nullptr) - logaddexp_contig_dispatch_table, - // function pointers to handle operation on strided arrays (most - // general case) - logaddexp_strided_dispatch_table, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t>{}, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t>{}); - }; - auto logaddexp_result_type_pyapi = [&](const py::dtype &dtype1, - const py::dtype &dtype2) { - return py_binary_ufunc_result_type(dtype1, dtype2, - logaddexp_output_id_table); - }; - m.def("_logaddexp", logaddexp_pyapi, "", py::arg("src1"), - py::arg("src2"), py::arg("dst"), py::arg("sycl_queue"), - py::arg("depends") = py::list()); - m.def("_logaddexp_result_type", logaddexp_result_type_pyapi, ""); - } - - // B16: ==== LOGICAL_AND (x1, x2) - { - impl::populate_logical_and_dispatch_tables(); - using impl::logical_and_contig_dispatch_table; - using impl::logical_and_output_id_table; - using impl::logical_and_strided_dispatch_table; - - auto logical_and_pyapi = [&](const dpctl::tensor::usm_ndarray &src1, - const dpctl::tensor::usm_ndarray &src2, - const dpctl::tensor::usm_ndarray &dst, - sycl::queue &exec_q, - const std::vector &depends = - {}) { - return py_binary_ufunc( - src1, src2, dst, exec_q, depends, logical_and_output_id_table, - // function pointers to handle operation on contiguous arrays - // (pointers may be nullptr) - logical_and_contig_dispatch_table, - // function pointers to handle operation on strided arrays (most - // general case) - logical_and_strided_dispatch_table, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t>{}, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t>{}); - }; - auto logical_and_result_type_pyapi = [&](const py::dtype &dtype1, - const py::dtype &dtype2) { - return py_binary_ufunc_result_type(dtype1, dtype2, - logical_and_output_id_table); - }; - m.def("_logical_and", logical_and_pyapi, "", py::arg("src1"), - py::arg("src2"), py::arg("dst"), 
py::arg("sycl_queue"), - py::arg("depends") = py::list()); - m.def("_logical_and_result_type", logical_and_result_type_pyapi, ""); - } - - // U24: ==== LOGICAL_NOT (x) - { - impl::populate_logical_not_dispatch_vectors(); - using impl::logical_not_contig_dispatch_vector; - using impl::logical_not_output_typeid_vector; - using impl::logical_not_strided_dispatch_vector; - - auto logical_not_pyapi = [&](const arrayT &src, arrayT dst, - sycl::queue &exec_q, - const event_vecT &depends = {}) { - return py_unary_ufunc(src, dst, exec_q, depends, - logical_not_output_typeid_vector, - logical_not_contig_dispatch_vector, - logical_not_strided_dispatch_vector); - }; - m.def("_logical_not", logical_not_pyapi, "", py::arg("src"), - py::arg("dst"), py::arg("sycl_queue"), - py::arg("depends") = py::list()); - - auto logical_not_result_type_pyapi = [&](const py::dtype &dtype) { - return py_unary_ufunc_result_type(dtype, - logical_not_output_typeid_vector); - }; - m.def("_logical_not_result_type", logical_not_result_type_pyapi); - } - - // B17: ==== LOGICAL_OR (x1, x2) - { - impl::populate_logical_or_dispatch_tables(); - using impl::logical_or_contig_dispatch_table; - using impl::logical_or_output_id_table; - using impl::logical_or_strided_dispatch_table; - - auto logical_or_pyapi = [&](const dpctl::tensor::usm_ndarray &src1, - const dpctl::tensor::usm_ndarray &src2, - const dpctl::tensor::usm_ndarray &dst, - sycl::queue &exec_q, - const std::vector &depends = - {}) { - return py_binary_ufunc( - src1, src2, dst, exec_q, depends, logical_or_output_id_table, - // function pointers to handle operation on contiguous arrays - // (pointers may be nullptr) - logical_or_contig_dispatch_table, - // function pointers to handle operation on strided arrays (most - // general case) - logical_or_strided_dispatch_table, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t>{}, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t>{}); - }; - auto logical_or_result_type_pyapi = [&](const py::dtype &dtype1, - const py::dtype &dtype2) { - return py_binary_ufunc_result_type(dtype1, dtype2, - logical_or_output_id_table); - }; - m.def("_logical_or", logical_or_pyapi, "", py::arg("src1"), - py::arg("src2"), py::arg("dst"), py::arg("sycl_queue"), - py::arg("depends") = py::list()); - m.def("_logical_or_result_type", logical_or_result_type_pyapi, ""); - } - - // B18: ==== LOGICAL_XOR (x1, x2) - { - impl::populate_logical_xor_dispatch_tables(); - using impl::logical_xor_contig_dispatch_table; - using impl::logical_xor_output_id_table; - using impl::logical_xor_strided_dispatch_table; - - auto logical_xor_pyapi = [&](const dpctl::tensor::usm_ndarray &src1, - const dpctl::tensor::usm_ndarray &src2, - const dpctl::tensor::usm_ndarray &dst, - sycl::queue &exec_q, - const std::vector &depends = - {}) { - return py_binary_ufunc( - src1, src2, dst, exec_q, depends, logical_xor_output_id_table, - // function pointers to handle operation on contiguous arrays - // (pointers may be nullptr) - logical_xor_contig_dispatch_table, - // function pointers to handle operation on strided arrays (most - // general case) - logical_xor_strided_dispatch_table, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be 
nullptr) - td_ns::NullPtrTable< - binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t>{}, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t>{}); - }; - auto logical_xor_result_type_pyapi = [&](const py::dtype &dtype1, - const py::dtype &dtype2) { - return py_binary_ufunc_result_type(dtype1, dtype2, - logical_xor_output_id_table); - }; - m.def("_logical_xor", logical_xor_pyapi, "", py::arg("src1"), - py::arg("src2"), py::arg("dst"), py::arg("sycl_queue"), - py::arg("depends") = py::list()); - m.def("_logical_xor_result_type", logical_xor_result_type_pyapi, ""); - } - - // B??: ==== MAXIMUM (x1, x2) - { - impl::populate_maximum_dispatch_tables(); - using impl::maximum_contig_dispatch_table; - using impl::maximum_output_id_table; - using impl::maximum_strided_dispatch_table; - - auto maximum_pyapi = [&](const dpctl::tensor::usm_ndarray &src1, - const dpctl::tensor::usm_ndarray &src2, - const dpctl::tensor::usm_ndarray &dst, - sycl::queue &exec_q, - const std::vector &depends = {}) { - return py_binary_ufunc( - src1, src2, dst, exec_q, depends, maximum_output_id_table, - // function pointers to handle operation on contiguous - // arrays (pointers may be nullptr) - maximum_contig_dispatch_table, - // function pointers to handle operation on strided arrays - // (most general case) - maximum_strided_dispatch_table, - // function pointers to handle operation of c-contig matrix - // and c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t>{}, - // function pointers to handle operation of c-contig matrix - // and c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t>{}); - }; - auto maximum_result_type_pyapi = [&](const py::dtype &dtype1, - const py::dtype &dtype2) { - return py_binary_ufunc_result_type(dtype1, dtype2, - maximum_output_id_table); - }; - m.def("_maximum", maximum_pyapi, "", py::arg("src1"), py::arg("src2"), - py::arg("dst"), py::arg("sycl_queue"), - py::arg("depends") = py::list()); - m.def("_maximum_result_type", maximum_result_type_pyapi, ""); - } - - // B??: ==== MINIMUM (x1, x2) - { - impl::populate_minimum_dispatch_tables(); - using impl::minimum_contig_dispatch_table; - using impl::minimum_output_id_table; - using impl::minimum_strided_dispatch_table; - - auto minimum_pyapi = [&](const dpctl::tensor::usm_ndarray &src1, - const dpctl::tensor::usm_ndarray &src2, - const dpctl::tensor::usm_ndarray &dst, - sycl::queue &exec_q, - const std::vector &depends = {}) { - return py_binary_ufunc( - src1, src2, dst, exec_q, depends, minimum_output_id_table, - // function pointers to handle operation on contiguous - // arrays (pointers may be nullptr) - minimum_contig_dispatch_table, - // function pointers to handle operation on strided arrays - // (most general case) - minimum_strided_dispatch_table, - // function pointers to handle operation of c-contig matrix - // and c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t>{}, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t>{}); - }; - auto minimum_result_type_pyapi = [&](const py::dtype &dtype1, - const py::dtype 
&dtype2) { - return py_binary_ufunc_result_type(dtype1, dtype2, - minimum_output_id_table); - }; - m.def("_minimum", minimum_pyapi, "", py::arg("src1"), py::arg("src2"), - py::arg("dst"), py::arg("sycl_queue"), - py::arg("depends") = py::list()); - m.def("_minimum_result_type", minimum_result_type_pyapi, ""); - } - - // B19: ==== MULTIPLY (x1, x2) - { - impl::populate_multiply_dispatch_tables(); - using impl::multiply_contig_dispatch_table; - using impl::multiply_contig_matrix_contig_row_broadcast_dispatch_table; - using impl::multiply_contig_row_contig_matrix_broadcast_dispatch_table; - using impl::multiply_output_id_table; - using impl::multiply_strided_dispatch_table; - - auto multiply_pyapi = - [&](const dpctl::tensor::usm_ndarray &src1, - const dpctl::tensor::usm_ndarray &src2, - const dpctl::tensor::usm_ndarray &dst, sycl::queue &exec_q, - const std::vector &depends = {}) { - return py_binary_ufunc( - src1, src2, dst, exec_q, depends, multiply_output_id_table, - // function pointers to handle operation on contiguous - // arrays (pointers may be nullptr) - multiply_contig_dispatch_table, - // function pointers to handle operation on strided arrays - // (most general case) - multiply_strided_dispatch_table, - // function pointers to handle operation of c-contig matrix - // and c-contig row with broadcasting (may be nullptr) - multiply_contig_matrix_contig_row_broadcast_dispatch_table, - // function pointers to handle operation of c-contig matrix - // and c-contig row with broadcasting (may be nullptr) - multiply_contig_row_contig_matrix_broadcast_dispatch_table); - }; - auto multiply_result_type_pyapi = [&](const py::dtype &dtype1, - const py::dtype &dtype2) { - return py_binary_ufunc_result_type(dtype1, dtype2, - multiply_output_id_table); - }; - m.def("_multiply", multiply_pyapi, "", py::arg("src1"), py::arg("src2"), - py::arg("dst"), py::arg("sycl_queue"), - py::arg("depends") = py::list()); - m.def("_multiply_result_type", multiply_result_type_pyapi, ""); - - using impl::multiply_inplace_contig_dispatch_table; - using impl::multiply_inplace_row_matrix_dispatch_table; - using impl::multiply_inplace_strided_dispatch_table; - - auto multiply_inplace_pyapi = - [&](const dpctl::tensor::usm_ndarray &src, - const dpctl::tensor::usm_ndarray &dst, sycl::queue &exec_q, - const std::vector &depends = {}) { - return py_binary_inplace_ufunc( - src, dst, exec_q, depends, multiply_output_id_table, - // function pointers to handle inplace operation on - // contiguous arrays (pointers may be nullptr) - multiply_inplace_contig_dispatch_table, - // function pointers to handle inplace operation on strided - // arrays (most general case) - multiply_inplace_strided_dispatch_table, - // function pointers to handle inplace operation on - // c-contig matrix with c-contig row with broadcasting - // (may be nullptr) - multiply_inplace_row_matrix_dispatch_table); - }; - m.def("_multiply_inplace", multiply_inplace_pyapi, "", py::arg("lhs"), - py::arg("rhs"), py::arg("sycl_queue"), - py::arg("depends") = py::list()); - } - - // U25: ==== NEGATIVE (x) - { - impl::populate_negative_dispatch_vectors(); - using impl::negative_contig_dispatch_vector; - using impl::negative_output_typeid_vector; - using impl::negative_strided_dispatch_vector; - - auto negative_pyapi = [&](const arrayT &src, const arrayT &dst, - sycl::queue &exec_q, - const event_vecT &depends = {}) { - return py_unary_ufunc(src, dst, exec_q, depends, - negative_output_typeid_vector, - negative_contig_dispatch_vector, - 
negative_strided_dispatch_vector); - }; - m.def("_negative", negative_pyapi, "", py::arg("src"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - - auto negative_result_type_pyapi = [&](const py::dtype &dtype) { - return py_unary_ufunc_result_type(dtype, - negative_output_typeid_vector); - }; - m.def("_negative_result_type", negative_result_type_pyapi); - } - - // B20: ==== NOT_EQUAL (x1, x2) - { - impl::populate_not_equal_dispatch_tables(); - using impl::not_equal_contig_dispatch_table; - using impl::not_equal_output_id_table; - using impl::not_equal_strided_dispatch_table; - - auto not_equal_pyapi = [&](const dpctl::tensor::usm_ndarray &src1, - const dpctl::tensor::usm_ndarray &src2, - const dpctl::tensor::usm_ndarray &dst, - sycl::queue &exec_q, - const std::vector &depends = - {}) { - return py_binary_ufunc( - src1, src2, dst, exec_q, depends, not_equal_output_id_table, - // function pointers to handle operation on contiguous arrays - // (pointers may be nullptr) - not_equal_contig_dispatch_table, - // function pointers to handle operation on strided arrays (most - // general case) - not_equal_strided_dispatch_table, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t>{}, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t>{}); - }; - auto not_equal_result_type_pyapi = [&](const py::dtype &dtype1, - const py::dtype &dtype2) { - return py_binary_ufunc_result_type(dtype1, dtype2, - not_equal_output_id_table); - }; - m.def("_not_equal", not_equal_pyapi, "", py::arg("src1"), - py::arg("src2"), py::arg("dst"), py::arg("sycl_queue"), - py::arg("depends") = py::list()); - m.def("_not_equal_result_type", not_equal_result_type_pyapi, ""); - } - - // U26: ==== POSITIVE (x) - { - impl::populate_positive_dispatch_vectors(); - using impl::positive_contig_dispatch_vector; - using impl::positive_output_typeid_vector; - using impl::positive_strided_dispatch_vector; - - auto positive_pyapi = [&](const arrayT &src, const arrayT &dst, - sycl::queue &exec_q, - const event_vecT &depends = {}) { - return py_unary_ufunc(src, dst, exec_q, depends, - positive_output_typeid_vector, - positive_contig_dispatch_vector, - positive_strided_dispatch_vector); - }; - m.def("_positive", positive_pyapi, "", py::arg("src"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - - auto positive_result_type_pyapi = [&](const py::dtype &dtype) { - return py_unary_ufunc_result_type(dtype, - positive_output_typeid_vector); - }; - m.def("_positive_result_type", positive_result_type_pyapi); - } - - // B21: ==== POW (x1, x2) - { - impl::populate_pow_dispatch_tables(); - using impl::pow_contig_dispatch_table; - using impl::pow_output_id_table; - using impl::pow_strided_dispatch_table; - - auto pow_pyapi = [&](const dpctl::tensor::usm_ndarray &src1, - const dpctl::tensor::usm_ndarray &src2, - const dpctl::tensor::usm_ndarray &dst, - sycl::queue &exec_q, - const std::vector &depends = {}) { - return py_binary_ufunc( - src1, src2, dst, exec_q, depends, pow_output_id_table, - // function pointers to handle operation on contiguous arrays - // (pointers may be nullptr) - pow_contig_dispatch_table, - // function pointers to handle operation on strided arrays (most - // general case) - 
pow_strided_dispatch_table, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t>{}, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t>{}); - }; - auto pow_result_type_pyapi = [&](const py::dtype &dtype1, - const py::dtype &dtype2) { - return py_binary_ufunc_result_type(dtype1, dtype2, - pow_output_id_table); - }; - m.def("_pow", pow_pyapi, "", py::arg("src1"), py::arg("src2"), - py::arg("dst"), py::arg("sycl_queue"), - py::arg("depends") = py::list()); - m.def("_pow_result_type", pow_result_type_pyapi, ""); - } - - // U??: ==== PROJ (x) - { - impl::populate_proj_dispatch_vectors(); - using impl::proj_contig_dispatch_vector; - using impl::proj_output_typeid_vector; - using impl::proj_strided_dispatch_vector; - - auto proj_pyapi = [&](const arrayT &src, const arrayT &dst, - sycl::queue &exec_q, - const event_vecT &depends = {}) { - return py_unary_ufunc( - src, dst, exec_q, depends, proj_output_typeid_vector, - proj_contig_dispatch_vector, proj_strided_dispatch_vector); - }; - m.def("_proj", proj_pyapi, "", py::arg("src"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - - auto proj_result_type_pyapi = [&](const py::dtype &dtype) { - return py_unary_ufunc_result_type(dtype, proj_output_typeid_vector); - }; - m.def("_proj_result_type", proj_result_type_pyapi); - } - - // U27: ==== REAL (x) - { - impl::populate_real_dispatch_vectors(); - using impl::real_contig_dispatch_vector; - using impl::real_output_typeid_vector; - using impl::real_strided_dispatch_vector; - - auto real_pyapi = [&](const arrayT &src, const arrayT &dst, - sycl::queue &exec_q, - const event_vecT &depends = {}) { - return py_unary_ufunc( - src, dst, exec_q, depends, real_output_typeid_vector, - real_contig_dispatch_vector, real_strided_dispatch_vector); - }; - m.def("_real", real_pyapi, "", py::arg("src"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - - auto real_result_type_pyapi = [&](const py::dtype &dtype) { - return py_unary_ufunc_result_type(dtype, real_output_typeid_vector); - }; - m.def("_real_result_type", real_result_type_pyapi); - } - - // B22: ==== REMAINDER (x1, x2) - { - impl::populate_remainder_dispatch_tables(); - using impl::remainder_contig_dispatch_table; - using impl::remainder_output_id_table; - using impl::remainder_strided_dispatch_table; - - auto remainder_pyapi = [&](const dpctl::tensor::usm_ndarray &src1, - const dpctl::tensor::usm_ndarray &src2, - const dpctl::tensor::usm_ndarray &dst, - sycl::queue &exec_q, - const std::vector &depends = - {}) { - return py_binary_ufunc( - src1, src2, dst, exec_q, depends, remainder_output_id_table, - // function pointers to handle operation on contiguous arrays - // (pointers may be nullptr) - remainder_contig_dispatch_table, - // function pointers to handle operation on strided arrays (most - // general case) - remainder_strided_dispatch_table, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t>{}, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - 
binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t>{}); - }; - auto remainder_result_type_pyapi = [&](const py::dtype &dtype1, - const py::dtype &dtype2) { - return py_binary_ufunc_result_type(dtype1, dtype2, - remainder_output_id_table); - }; - m.def("_remainder", remainder_pyapi, "", py::arg("src1"), - py::arg("src2"), py::arg("dst"), py::arg("sycl_queue"), - py::arg("depends") = py::list()); - m.def("_remainder_result_type", remainder_result_type_pyapi, ""); - } - - // U28: ==== ROUND (x) - { - impl::populate_round_dispatch_vectors(); - using impl::round_contig_dispatch_vector; - using impl::round_output_typeid_vector; - using impl::round_strided_dispatch_vector; - - auto round_pyapi = [&](const arrayT &src, const arrayT &dst, - sycl::queue &exec_q, - const event_vecT &depends = {}) { - return py_unary_ufunc( - src, dst, exec_q, depends, round_output_typeid_vector, - round_contig_dispatch_vector, round_strided_dispatch_vector); - }; - m.def("_round", round_pyapi, "", py::arg("src"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - - auto round_result_type_pyapi = [&](const py::dtype &dtype) { - return py_unary_ufunc_result_type(dtype, - round_output_typeid_vector); - }; - m.def("_round_result_type", round_result_type_pyapi); - } - - // U29: ==== SIGN (x) - { - impl::populate_sign_dispatch_vectors(); - using impl::sign_contig_dispatch_vector; - using impl::sign_output_typeid_vector; - using impl::sign_strided_dispatch_vector; - - auto sign_pyapi = [&](const arrayT &src, const arrayT &dst, - sycl::queue &exec_q, - const event_vecT &depends = {}) { - return py_unary_ufunc( - src, dst, exec_q, depends, sign_output_typeid_vector, - sign_contig_dispatch_vector, sign_strided_dispatch_vector); - }; - m.def("_sign", sign_pyapi, "", py::arg("src"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - - auto sign_result_type_pyapi = [&](const py::dtype &dtype) { - return py_unary_ufunc_result_type(dtype, sign_output_typeid_vector); - }; - m.def("_sign_result_type", sign_result_type_pyapi); - } - - // ==== SIGNBIT (x) - { - impl::populate_signbit_dispatch_vectors(); - using impl::signbit_contig_dispatch_vector; - using impl::signbit_output_typeid_vector; - using impl::signbit_strided_dispatch_vector; - - auto signbit_pyapi = [&](const arrayT &src, const arrayT &dst, - sycl::queue &exec_q, - const event_vecT &depends = {}) { - return py_unary_ufunc(src, dst, exec_q, depends, - signbit_output_typeid_vector, - signbit_contig_dispatch_vector, - signbit_strided_dispatch_vector); - }; - m.def("_signbit", signbit_pyapi, "", py::arg("src"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - - auto signbit_result_type_pyapi = [&](const py::dtype &dtype) { - return py_unary_ufunc_result_type(dtype, - signbit_output_typeid_vector); - }; - m.def("_signbit_result_type", signbit_result_type_pyapi); - } - - // U30: ==== SIN (x) - { - impl::populate_sin_dispatch_vectors(); - using impl::sin_contig_dispatch_vector; - using impl::sin_output_typeid_vector; - using impl::sin_strided_dispatch_vector; - - auto sin_pyapi = [&](const arrayT &src, const arrayT &dst, - sycl::queue &exec_q, - const event_vecT &depends = {}) { - return py_unary_ufunc( - src, dst, exec_q, depends, sin_output_typeid_vector, - sin_contig_dispatch_vector, sin_strided_dispatch_vector); - }; - m.def("_sin", sin_pyapi, "", py::arg("src"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - - auto sin_result_type_pyapi = [&](const py::dtype 
&dtype) { - return py_unary_ufunc_result_type(dtype, sin_output_typeid_vector); - }; - m.def("_sin_result_type", sin_result_type_pyapi); - } - // U31: ==== SINH (x) - { - impl::populate_sinh_dispatch_vectors(); - using impl::sinh_contig_dispatch_vector; - using impl::sinh_output_typeid_vector; - using impl::sinh_strided_dispatch_vector; - - auto sinh_pyapi = [&](const arrayT &src, const arrayT &dst, - sycl::queue &exec_q, - const event_vecT &depends = {}) { - return py_unary_ufunc( - src, dst, exec_q, depends, sinh_output_typeid_vector, - sinh_contig_dispatch_vector, sinh_strided_dispatch_vector); - }; - m.def("_sinh", sinh_pyapi, "", py::arg("src"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - - auto sinh_result_type_pyapi = [&](const py::dtype &dtype) { - return py_unary_ufunc_result_type(dtype, sinh_output_typeid_vector); - }; - m.def("_sinh_result_type", sinh_result_type_pyapi); - } - - // U32: ==== SQUARE (x) - { - impl::populate_square_dispatch_vectors(); - using impl::square_contig_dispatch_vector; - using impl::square_output_typeid_vector; - using impl::square_strided_dispatch_vector; - - auto square_pyapi = [&](const arrayT &src, const arrayT &dst, - sycl::queue &exec_q, - const event_vecT &depends = {}) { - return py_unary_ufunc( - src, dst, exec_q, depends, square_output_typeid_vector, - square_contig_dispatch_vector, square_strided_dispatch_vector); - }; - m.def("_square", square_pyapi, "", py::arg("src"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - - auto square_result_type_pyapi = [&](const py::dtype &dtype) { - return py_unary_ufunc_result_type(dtype, - square_output_typeid_vector); - }; - m.def("_square_result_type", square_result_type_pyapi); - } - - // U33: ==== SQRT (x) - { - impl::populate_sqrt_dispatch_vectors(); - using impl::sqrt_contig_dispatch_vector; - using impl::sqrt_output_typeid_vector; - using impl::sqrt_strided_dispatch_vector; - - auto sqrt_pyapi = [&](const arrayT &src, const arrayT &dst, - sycl::queue &exec_q, - const event_vecT &depends = {}) { - return py_unary_ufunc( - src, dst, exec_q, depends, sqrt_output_typeid_vector, - sqrt_contig_dispatch_vector, sqrt_strided_dispatch_vector); - }; - m.def("_sqrt", sqrt_pyapi, "", py::arg("src"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - - auto sqrt_result_type_pyapi = [&](const py::dtype &dtype) { - return py_unary_ufunc_result_type(dtype, sqrt_output_typeid_vector); - }; - m.def("_sqrt_result_type", sqrt_result_type_pyapi); - } - - // B23: ==== SUBTRACT (x1, x2) - { - impl::populate_subtract_dispatch_tables(); - using impl::subtract_contig_dispatch_table; - using impl::subtract_contig_matrix_contig_row_broadcast_dispatch_table; - using impl::subtract_contig_row_contig_matrix_broadcast_dispatch_table; - using impl::subtract_output_id_table; - using impl::subtract_strided_dispatch_table; - - auto subtract_pyapi = - [&](const dpctl::tensor::usm_ndarray &src1, - const dpctl::tensor::usm_ndarray &src2, - const dpctl::tensor::usm_ndarray &dst, sycl::queue &exec_q, - const std::vector &depends = {}) { - return py_binary_ufunc( - src1, src2, dst, exec_q, depends, subtract_output_id_table, - // function pointers to handle operation on contiguous - // arrays (pointers may be nullptr) - subtract_contig_dispatch_table, - // function pointers to handle operation on strided arrays - // (most general case) - subtract_strided_dispatch_table, - // function pointers to handle operation of c-contig matrix - // and c-contig row with 
broadcasting (may be nullptr) - subtract_contig_matrix_contig_row_broadcast_dispatch_table, - // function pointers to handle operation of c-contig matrix - // and c-contig row with broadcasting (may be nullptr) - subtract_contig_row_contig_matrix_broadcast_dispatch_table); - }; - auto subtract_result_type_pyapi = [&](const py::dtype &dtype1, - const py::dtype &dtype2) { - return py_binary_ufunc_result_type(dtype1, dtype2, - subtract_output_id_table); - }; - m.def("_subtract", subtract_pyapi, "", py::arg("src1"), py::arg("src2"), - py::arg("dst"), py::arg("sycl_queue"), - py::arg("depends") = py::list()); - m.def("_subtract_result_type", subtract_result_type_pyapi, ""); - - using impl::subtract_inplace_contig_dispatch_table; - using impl::subtract_inplace_row_matrix_dispatch_table; - using impl::subtract_inplace_strided_dispatch_table; - - auto subtract_inplace_pyapi = - [&](const dpctl::tensor::usm_ndarray &src, - const dpctl::tensor::usm_ndarray &dst, sycl::queue &exec_q, - const std::vector &depends = {}) { - return py_binary_inplace_ufunc( - src, dst, exec_q, depends, subtract_output_id_table, - // function pointers to handle inplace operation on - // contiguous arrays (pointers may be nullptr) - subtract_inplace_contig_dispatch_table, - // function pointers to handle inplace operation on strided - // arrays (most general case) - subtract_inplace_strided_dispatch_table, - // function pointers to handle inplace operation on - // c-contig matrix with c-contig row with broadcasting - // (may be nullptr) - subtract_inplace_row_matrix_dispatch_table); - }; - m.def("_subtract_inplace", subtract_inplace_pyapi, "", py::arg("lhs"), - py::arg("rhs"), py::arg("sycl_queue"), - py::arg("depends") = py::list()); - } - - // U34: ==== TAN (x) - { - impl::populate_tan_dispatch_vectors(); - using impl::tan_contig_dispatch_vector; - using impl::tan_output_typeid_vector; - using impl::tan_strided_dispatch_vector; - - auto tan_pyapi = [&](const arrayT &src, const arrayT &dst, - sycl::queue &exec_q, - const event_vecT &depends = {}) { - return py_unary_ufunc( - src, dst, exec_q, depends, tan_output_typeid_vector, - tan_contig_dispatch_vector, tan_strided_dispatch_vector); - }; - m.def("_tan", tan_pyapi, "", py::arg("src"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - - auto tan_result_type_pyapi = [&](const py::dtype &dtype) { - return py_unary_ufunc_result_type(dtype, tan_output_typeid_vector); - }; - m.def("_tan_result_type", tan_result_type_pyapi); - } - - // U35: ==== TANH (x) - { - impl::populate_tanh_dispatch_vectors(); - using impl::tanh_contig_dispatch_vector; - using impl::tanh_output_typeid_vector; - using impl::tanh_strided_dispatch_vector; - - auto tanh_pyapi = [&](const arrayT &src, const arrayT &dst, - sycl::queue &exec_q, - const event_vecT &depends = {}) { - return py_unary_ufunc( - src, dst, exec_q, depends, tanh_output_typeid_vector, - tanh_contig_dispatch_vector, tanh_strided_dispatch_vector); - }; - m.def("_tanh", tanh_pyapi, "", py::arg("src"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - - auto tanh_result_type_pyapi = [&](const py::dtype &dtype) { - return py_unary_ufunc_result_type(dtype, tanh_output_typeid_vector); - }; - m.def("_tanh_result_type", tanh_result_type_pyapi); - } - - // U36: ==== TRUNC (x) - { - impl::populate_trunc_dispatch_vectors(); - using impl::trunc_contig_dispatch_vector; - using impl::trunc_output_typeid_vector; - using impl::trunc_strided_dispatch_vector; - - auto trunc_pyapi = [&](const arrayT 
&src, const arrayT &dst, - sycl::queue &exec_q, - const event_vecT &depends = {}) { - return py_unary_ufunc( - src, dst, exec_q, depends, trunc_output_typeid_vector, - trunc_contig_dispatch_vector, trunc_strided_dispatch_vector); - }; - m.def("_trunc", trunc_pyapi, "", py::arg("src"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - - auto trunc_result_type_pyapi = [&](const py::dtype &dtype) { - return py_unary_ufunc_result_type(dtype, - trunc_output_typeid_vector); - }; - m.def("_trunc_result_type", trunc_result_type_pyapi); - } - - // B24: ==== HYPOT (x1, x2) - { - impl::populate_hypot_dispatch_tables(); - using impl::hypot_contig_dispatch_table; - using impl::hypot_output_id_table; - using impl::hypot_strided_dispatch_table; - - auto hypot_pyapi = [&](const dpctl::tensor::usm_ndarray &src1, - const dpctl::tensor::usm_ndarray &src2, - const dpctl::tensor::usm_ndarray &dst, - sycl::queue &exec_q, - const std::vector &depends = {}) { - return py_binary_ufunc( - src1, src2, dst, exec_q, depends, hypot_output_id_table, - // function pointers to handle operation on contiguous arrays - // (pointers may be nullptr) - hypot_contig_dispatch_table, - // function pointers to handle operation on strided arrays (most - // general case) - hypot_strided_dispatch_table, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t>{}, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t>{}); - }; - auto hypot_result_type_pyapi = [&](const py::dtype &dtype1, - const py::dtype &dtype2) { - return py_binary_ufunc_result_type(dtype1, dtype2, - hypot_output_id_table); - }; - m.def("_hypot", hypot_pyapi, "", py::arg("src1"), py::arg("src2"), - py::arg("dst"), py::arg("sycl_queue"), - py::arg("depends") = py::list()); - m.def("_hypot_result_type", hypot_result_type_pyapi, ""); - } - - // U37: ==== CBRT (x) - { - impl::populate_cbrt_dispatch_vectors(); - using impl::cbrt_contig_dispatch_vector; - using impl::cbrt_output_typeid_vector; - using impl::cbrt_strided_dispatch_vector; - - auto cbrt_pyapi = [&](const arrayT &src, const arrayT &dst, - sycl::queue &exec_q, - const event_vecT &depends = {}) { - return py_unary_ufunc( - src, dst, exec_q, depends, cbrt_output_typeid_vector, - cbrt_contig_dispatch_vector, cbrt_strided_dispatch_vector); - }; - m.def("_cbrt", cbrt_pyapi, "", py::arg("src"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - - auto cbrt_result_type_pyapi = [&](const py::dtype &dtype) { - return py_unary_ufunc_result_type(dtype, cbrt_output_typeid_vector); - }; - m.def("_cbrt_result_type", cbrt_result_type_pyapi); - } - - // B25: ==== COPYSIGN (x1, x2) - { - impl::populate_copysign_dispatch_tables(); - using impl::copysign_contig_dispatch_table; - using impl::copysign_output_id_table; - using impl::copysign_strided_dispatch_table; - - auto copysign_pyapi = [&](const dpctl::tensor::usm_ndarray &src1, - const dpctl::tensor::usm_ndarray &src2, - const dpctl::tensor::usm_ndarray &dst, - sycl::queue &exec_q, - const std::vector &depends = - {}) { - return py_binary_ufunc( - src1, src2, dst, exec_q, depends, copysign_output_id_table, - // function pointers to handle operation on contiguous arrays - // (pointers may be nullptr) - copysign_contig_dispatch_table, - // 
function pointers to handle operation on strided arrays (most - // general case) - copysign_strided_dispatch_table, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t>{}, - // function pointers to handle operation of c-contig matrix and - // c-contig row with broadcasting (may be nullptr) - td_ns::NullPtrTable< - binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t>{}); - }; - auto copysign_result_type_pyapi = [&](const py::dtype &dtype1, - const py::dtype &dtype2) { - return py_binary_ufunc_result_type(dtype1, dtype2, - copysign_output_id_table); - }; - m.def("_copysign", copysign_pyapi, "", py::arg("src1"), py::arg("src2"), - py::arg("dst"), py::arg("sycl_queue"), - py::arg("depends") = py::list()); - m.def("_copysign_result_type", copysign_result_type_pyapi, ""); - } - - // U38: ==== EXP2 (x) - { - impl::populate_exp2_dispatch_vectors(); - using impl::exp2_contig_dispatch_vector; - using impl::exp2_output_typeid_vector; - using impl::exp2_strided_dispatch_vector; - - auto exp2_pyapi = [&](const arrayT &src, const arrayT &dst, - sycl::queue &exec_q, - const event_vecT &depends = {}) { - return py_unary_ufunc( - src, dst, exec_q, depends, exp2_output_typeid_vector, - exp2_contig_dispatch_vector, exp2_strided_dispatch_vector); - }; - m.def("_exp2", exp2_pyapi, "", py::arg("src"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - - auto exp2_result_type_pyapi = [&](const py::dtype &dtype) { - return py_unary_ufunc_result_type(dtype, exp2_output_typeid_vector); - }; - m.def("_exp2_result_type", exp2_result_type_pyapi); - } - - // U39: ==== RSQRT (x) - { - impl::populate_rsqrt_dispatch_vectors(); - using impl::rsqrt_contig_dispatch_vector; - using impl::rsqrt_output_typeid_vector; - using impl::rsqrt_strided_dispatch_vector; - - auto rsqrt_pyapi = [&](const arrayT &src, const arrayT &dst, - sycl::queue &exec_q, - const event_vecT &depends = {}) { - return py_unary_ufunc( - src, dst, exec_q, depends, rsqrt_output_typeid_vector, - rsqrt_contig_dispatch_vector, rsqrt_strided_dispatch_vector); - }; - m.def("_rsqrt", rsqrt_pyapi, "", py::arg("src"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - - auto rsqrt_result_type_pyapi = [&](const py::dtype &dtype) { - return py_unary_ufunc_result_type(dtype, - rsqrt_output_typeid_vector); - }; - m.def("_rsqrt_result_type", rsqrt_result_type_pyapi); - } -} - -} // namespace py_internal -} // namespace tensor -} // namespace dpctl From af41424aadd7e37b9ac99e8e2311da0b3acda362 Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Tue, 2 Jan 2024 15:07:51 -0800 Subject: [PATCH 03/48] Implements `matmul`, `vecdot`, and `tensordot` These three functions are implemented through a common `py_dot` binding, which is also part of a new tensor submodule `_tensor_linalg_impl` --- dpctl/tensor/CMakeLists.txt | 15 + dpctl/tensor/__init__.py | 10 +- dpctl/tensor/_linear_algebra_functions.py | 965 ++ .../kernels/linalg_functions/dot_product.hpp | 1137 ++ .../include/kernels/linalg_functions/gemm.hpp | 9840 +++++++++++++++++ .../libtensor/source/linalg_functions/dot.cpp | 857 ++ .../libtensor/source/linalg_functions/dot.hpp | 17 + .../linalg_functions/dot_atomic_support.hpp | 34 + .../source/linalg_functions/dot_dispatch.hpp | 336 + .../tensor/libtensor/source/tensor_linalg.cpp | 34 + dpctl/tests/test_usm_ndarray_linalg.py | 438 +- 11 files changed, 13681 
insertions(+), 2 deletions(-) create mode 100644 dpctl/tensor/libtensor/include/kernels/linalg_functions/dot_product.hpp create mode 100644 dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp create mode 100644 dpctl/tensor/libtensor/source/linalg_functions/dot.cpp create mode 100644 dpctl/tensor/libtensor/source/linalg_functions/dot.hpp create mode 100644 dpctl/tensor/libtensor/source/linalg_functions/dot_atomic_support.hpp create mode 100644 dpctl/tensor/libtensor/source/linalg_functions/dot_dispatch.hpp create mode 100644 dpctl/tensor/libtensor/source/tensor_linalg.cpp diff --git a/dpctl/tensor/CMakeLists.txt b/dpctl/tensor/CMakeLists.txt index de8be6fae0..a57e9c5104 100644 --- a/dpctl/tensor/CMakeLists.txt +++ b/dpctl/tensor/CMakeLists.txt @@ -148,6 +148,15 @@ set(_tensor_reductions_impl_sources ${_boolean_reduction_sources} ${_reduction_sources} ) +set(_linalg_sources + ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/elementwise_functions_type_utils.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/linalg_functions/dot.cpp +) +set(_tensor_linalg_impl_sources + ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/tensor_linalg.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/simplify_iteration_space.cpp + ${_linalg_sources} +) set(_py_trgts) @@ -166,6 +175,11 @@ pybind11_add_module(${python_module_name} MODULE ${_tensor_reductions_impl_sourc add_sycl_to_target(TARGET ${python_module_name} SOURCES ${_tensor_reductions_impl_sources}) list(APPEND _py_trgts ${python_module_name}) +set(python_module_name _tensor_linalg_impl) +pybind11_add_module(${python_module_name} MODULE ${_tensor_linalg_impl_sources}) +add_sycl_to_target(TARGET ${python_module_name} SOURCES ${_tensor_linalg_impl_sources}) +list(APPEND _py_trgts ${python_module_name}) + set(_clang_prefix "") if (WIN32) set(_clang_prefix "/clang:") @@ -179,6 +193,7 @@ set(_no_fast_math_sources list(APPEND _no_fast_math_sources ${_elementwise_sources} ${_reduction_sources} + ${_linalg_sources} ) foreach(_src_fn ${_no_fast_math_sources}) diff --git a/dpctl/tensor/__init__.py b/dpctl/tensor/__init__.py index cdb701e1cb..3cd2aa80d4 100644 --- a/dpctl/tensor/__init__.py +++ b/dpctl/tensor/__init__.py @@ -60,7 +60,12 @@ from dpctl.tensor._device import Device from dpctl.tensor._dlpack import from_dlpack from dpctl.tensor._indexing_functions import extract, nonzero, place, put, take -from dpctl.tensor._linear_algebra_functions import matrix_transpose +from dpctl.tensor._linear_algebra_functions import ( + matmul, + matrix_transpose, + tensordot, + vecdot, +) from dpctl.tensor._manipulation_functions import ( broadcast_arrays, broadcast_to, @@ -343,4 +348,7 @@ "__array_namespace_info__", "reciprocal", "angle", + "matmul", + "tensordot", + "vecdot", ] diff --git a/dpctl/tensor/_linear_algebra_functions.py b/dpctl/tensor/_linear_algebra_functions.py index fd2c58b08a..2588fc0856 100644 --- a/dpctl/tensor/_linear_algebra_functions.py +++ b/dpctl/tensor/_linear_algebra_functions.py @@ -14,7 +14,23 @@ # See the License for the specific language governing permissions and # limitations under the License. 
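The `__init__.py` hunk above exports `matmul`, `tensordot` and `vecdot` as public `dpctl.tensor` functions. A minimal usage sketch of that surface follows; it is a sketch only, assuming a build that includes the new `_tensor_linalg_impl` submodule and a default SYCL device, and the shapes and dtypes are arbitrary illustrations:

    # Illustrative sketch of the public API added by this patch;
    # assumes a default SYCL device is available.
    import dpctl.tensor as dpt

    a = dpt.reshape(dpt.arange(2 * 3 * 4, dtype="f4"), (2, 3, 4))
    b = dpt.reshape(dpt.arange(3 * 4 * 2, dtype="f4"), (3, 4, 2))

    # matmul: ordinary matrix product of two 2-D operands
    c = dpt.matmul(dpt.reshape(a, (6, 4)), dpt.reshape(b, (4, 6)))  # shape (6, 6)

    # tensordot with axes=2: contract the last two axes of `a`, sizes (3, 4),
    # against the first two axes of `b`, also sizes (3, 4)
    t = dpt.tensordot(a, b, axes=2)  # shape (2, 2)

    # vecdot: dot product along one shared axis (the last axis by default)
    v = dpt.vecdot(a, dpt.reshape(b, (2, 3, 4)), axis=-1)  # shape (2, 3)

Per the commit message, all three route through the shared `py_dot` binding; only the axis bookkeeping done in the Python layer below differs between them.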
+import operator + +from numpy.core.numeric import normalize_axis_index, normalize_axis_tuple + +import dpctl import dpctl.tensor as dpt +import dpctl.tensor._tensor_elementwise_impl as tei +import dpctl.tensor._tensor_impl as ti +import dpctl.tensor._tensor_linalg_impl as tli +from dpctl.tensor._copy_utils import _empty_like_orderK, _empty_like_pair_orderK +from dpctl.tensor._manipulation_functions import _broadcast_shape_impl +from dpctl.tensor._type_utils import ( + _acceptance_fn_default_binary, + _find_buf_dtype2, + _to_device_supported_dtype, +) +from dpctl.utils import ExecutionPlacementError def matrix_transpose(x): @@ -46,3 +62,952 @@ def matrix_transpose(x): ) return x.mT + + +def tensordot(x1, x2, axes=2): + """tensordot(x1, x2, axes=2) + + Returns a tensor contraction of `x1` and `x2` over specific axes. + + Args: + x1 (usm_ndarray): + first input array, expected to have numeric data type. + x2 (usm_ndarray): + second input array, expected to have numeric data type. + Corresponding contracted axes of `x1` and `x2` must be equal. + axes (Union[int, Tuple[Sequence[int], Sequence[int]]): + number of axes to contract or explicit sequences of axes for + `x1` and `x2`, respectively. If `axes` is an integer equal to `N`, + then the contraction is performed over last `N` axes of `x1` and + the first `N` axis of `x2` in order. The size of each corresponding + axis must match and must be non-negative. + * if `N` equals `0`, the result is the tensor outer product + * if `N` equals `1`, the result is the tensor dot product + * if `N` equals `2`, the result is the tensor double + contraction (default). + If `axes` is a tuple of two sequences `(x1_axes, x2_axes)`, the + first sequence applies to `x1` and the second sequence applies + to `x2`. Both sequences must have equal length, and each axis + `x1_axes[i]` for `x1` must have the same size as the respective + axis `x2_axes[i]` for `x2`. Each sequence must consist of unique + non-negative integers that specify valid axes for each respective + array. + Returns: + usm_ndarray: + an array containing the tensor contraction whose shape consists of + the non-contracted axes of the first array `x1`, followed by the + non-contracted axes of the second array `x2`. The returned array + must have a data type determined by Type Promotion Rules. + """ + if not isinstance(x1, dpt.usm_ndarray): + raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x1)}") + if not isinstance(x2, dpt.usm_ndarray): + raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x2)}") + q1, x1_usm_type = x1.sycl_queue, x1.usm_type + q2, x2_usm_type = x2.sycl_queue, x2.usm_type + if q1 is None and q2 is None: + raise ExecutionPlacementError( + "Execution placement can not be unambiguously inferred " + "from input arguments. " + "One of the arguments must represent USM allocation and " + "expose `__sycl_usm_array_interface__` property" + ) + if q1 is None: + exec_q = q2 + res_usm_type = x2_usm_type + elif q2 is None: + exec_q = q1 + res_usm_type = x1_usm_type + else: + exec_q = dpctl.utils.get_execution_queue((q1, q2)) + if exec_q is None: + raise ExecutionPlacementError( + "Execution placement can not be unambiguously inferred " + "from input arguments." 
+ ) + res_usm_type = dpctl.utils.get_coerced_usm_type( + ( + x1_usm_type, + x2_usm_type, + ) + ) + dpctl.utils.validate_usm_type(res_usm_type, allow_none=False) + # handle axes and shapes validation + x1_nd = x1.ndim + x2_nd = x2.ndim + x1_shape = x1.shape + x2_shape = x2.shape + if isinstance(axes, int): + n_axes1 = axes + n_axes2 = axes + axes1 = tuple(range(-axes, 0)) + axes2 = tuple(range(0, axes)) + elif isinstance(axes, tuple): + if len(axes) != 2: + raise ValueError( + "`axes` tuple is expected to contain two sequences" + ) + axes1 = tuple(axes[0]) + axes2 = tuple(axes[1]) + n_axes1 = len(axes1) + n_axes2 = len(axes2) + else: + raise TypeError("`axes` must be an integer or a tuple of sequences") + if n_axes1 != n_axes2: + raise ValueError( + "number of axes contracted must be the same for each array" + ) + if n_axes1 == 0: + arr1 = x1[..., dpt.newaxis] + arr2 = x2[dpt.newaxis, ...] + n_axes1 = 1 + n_axes2 = 1 + else: + same_shapes = True + for i in range(n_axes1): + same_shapes = same_shapes and ( + x1_shape[axes1[i]] == x2_shape[axes2[i]] + ) + if not same_shapes: + raise ValueError("shape mismatch in contracted `tensordot` axes") + axes1 = normalize_axis_tuple(axes1, x1_nd) + axes2 = normalize_axis_tuple(axes2, x2_nd) + perm1 = [i for i in range(x1_nd) if i not in axes1] + list(axes1) + perm2 = list(axes2) + [i for i in range(x2_nd) if i not in axes2] + arr1 = dpt.permute_dims(x1, perm1) + arr2 = dpt.permute_dims(x2, perm2) + arr1_outer_nd = arr1.ndim - n_axes1 + arr2_outer_nd = arr2.ndim - n_axes2 + res_shape = arr1.shape[:arr1_outer_nd] + arr2.shape[n_axes2:] + # type validation + sycl_dev = exec_q.sycl_device + x1_dtype = x1.dtype + x2_dtype = x2.dtype + buf1_dt, buf2_dt, res_dt = _find_buf_dtype2( + x1_dtype, + x2_dtype, + tli._dot_result_type, + sycl_dev, + acceptance_fn=_acceptance_fn_default_binary, + ) + if res_dt is None: + raise TypeError( + "function 'tensordot' does not support input types " + f"({x1_dtype}, {x2_dtype}), " + "and the inputs could not be safely coerced to any " + "supported types according to the casting rule ''safe''." 
+ ) + + if buf1_dt is None and buf2_dt is None: + out = dpt.empty( + res_shape, + dtype=res_dt, + usm_type=res_usm_type, + sycl_queue=exec_q, + order="C", + ) + ht_dot_ev, _ = tli._dot( + x1=arr1, + x2=arr2, + batch_dims=0, + x1_outer_dims=arr1_outer_nd, + x2_outer_dims=arr2_outer_nd, + inner_dims=n_axes1, + dst=out, + sycl_queue=exec_q, + ) + ht_dot_ev.wait() + + return out + + elif buf1_dt is None: + buf2 = _empty_like_orderK(arr2, buf2_dt) + ht_copy_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( + src=arr2, dst=buf2, sycl_queue=exec_q + ) + out = dpt.empty( + res_shape, + dtype=res_dt, + usm_type=res_usm_type, + sycl_queue=exec_q, + order="C", + ) + ht_dot_ev, _ = tli._dot( + x1=arr1, + x2=buf2, + batch_dims=0, + x1_outer_dims=arr1_outer_nd, + x2_outer_dims=arr2_outer_nd, + inner_dims=n_axes1, + dst=out, + sycl_queue=exec_q, + depends=[copy_ev], + ) + ht_copy_ev.wait() + ht_dot_ev.wait() + + return out + + elif buf2_dt is None: + buf1 = _empty_like_orderK(arr1, buf1_dt) + ht_copy_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( + src=arr1, dst=buf1, sycl_queue=exec_q + ) + out = dpt.empty( + res_shape, + dtype=res_dt, + usm_type=res_usm_type, + sycl_queue=exec_q, + order="C", + ) + ht_dot_ev, _ = tli._dot( + x1=buf1, + x2=arr2, + batch_dims=0, + x1_outer_dims=arr1_outer_nd, + x2_outer_dims=arr2_outer_nd, + inner_dims=n_axes1, + dst=out, + sycl_queue=exec_q, + depends=[copy_ev], + ) + ht_copy_ev.wait() + ht_dot_ev.wait() + + return out + + buf1 = _empty_like_orderK(arr1, buf1_dt) + ht_copy1_ev, copy1_ev = ti._copy_usm_ndarray_into_usm_ndarray( + src=arr1, dst=buf1, sycl_queue=exec_q + ) + buf2 = _empty_like_orderK(arr2, buf2_dt) + ht_copy2_ev, copy2_ev = ti._copy_usm_ndarray_into_usm_ndarray( + src=arr2, dst=buf2, sycl_queue=exec_q + ) + out = dpt.empty( + res_shape, + dtype=res_dt, + usm_type=res_usm_type, + sycl_queue=exec_q, + order="C", + ) + ht_, _ = tli._dot( + x1=buf1, + x2=buf2, + batch_dims=0, + x1_outer_dims=arr1_outer_nd, + x2_outer_dims=arr2_outer_nd, + inner_dims=n_axes1, + dst=out, + sycl_queue=exec_q, + depends=[copy1_ev, copy2_ev], + ) + dpctl.SyclEvent.wait_for([ht_copy1_ev, ht_copy2_ev, ht_]) + + return out + + +def vecdot(x1, x2, axis=-1): + """vecdot(x1, x2, axis=-1) + + Computes the (vector) dot product of two arrays. + + Args: + x1 (usm_ndarray): + first input array. + x2 (usm_ndarray): + second input array. Input arrays must have compatible + shapes along non-contract axes according to broadcasting + rules, and must have the same size along the contracted + axis. Input arrays should be of numeric type. + axis (Optional[int]): + axis over which to compute the dot product. The axis must + be an integer on the interval `[-N, N)`, where `N` is the + array rank of input arrays after broadcasting rules are + applied. If specified as a negative integer, the axis along + which dot product is performed is counted backward from + the last axes (that is `-1` refers to the last axis). By + default, dot product is computed over the last axis. + Default: `-1`. + + Returns: + usm_ndarray: + if `x1` and `x2` are both one-dimensional arrays, a + zero-dimensional array containing the dot product value + is returned; otherwise, a non-zero-dimensional array containing + the dot products and having rank `N-1`, where `N` is the rank + of the shape of input arrays after broadcasting rules are applied + to non-contracted axes. 
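+            For example (shapes only, as an illustration of the semantics
+            above): if `x1` has shape `(3, 4)` and `x2` has shape `(4,)`,
+            `x2` is broadcast to `(3, 4)`, the dot product is computed over
+            the last axis, and the returned array has shape `(3,)`.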
+ """ + if not isinstance(x1, dpt.usm_ndarray): + raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x1)}") + if not isinstance(x2, dpt.usm_ndarray): + raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x2)}") + q1, x1_usm_type = x1.sycl_queue, x1.usm_type + q2, x2_usm_type = x2.sycl_queue, x2.usm_type + if q1 is None and q2 is None: + raise ExecutionPlacementError( + "Execution placement can not be unambiguously inferred " + "from input arguments. " + "One of the arguments must represent USM allocation and " + "expose `__sycl_usm_array_interface__` property" + ) + if q1 is None: + exec_q = q2 + res_usm_type = x2_usm_type + elif q2 is None: + exec_q = q1 + res_usm_type = x1_usm_type + else: + exec_q = dpctl.utils.get_execution_queue((q1, q2)) + if exec_q is None: + raise ExecutionPlacementError( + "Execution placement can not be unambiguously inferred " + "from input arguments." + ) + res_usm_type = dpctl.utils.get_coerced_usm_type( + ( + x1_usm_type, + x2_usm_type, + ) + ) + dpctl.utils.validate_usm_type(res_usm_type, allow_none=False) + # axis and shape validation + x1_nd = x1.ndim + x2_nd = x2.ndim + x1_shape = x1.shape + x2_shape = x2.shape + if x1_nd > x2_nd: + x2_shape = (1,) * (x1_nd - x2_nd) + x2_shape + x2_nd = len(x2_shape) + elif x2_nd > x1_nd: + x1_shape = (1,) * (x2_nd - x1_nd) + x1_shape + x1_nd = len(x1_shape) + axis = normalize_axis_index(operator.index(axis), x1_nd) + if x1_shape[axis] != x2_shape[axis]: + raise ValueError( + "given axis must have the same shape for `x1` and `x2`" + ) + try: + broadcast_sh = _broadcast_shape_impl( + [ + x1_shape, + x2_shape, + ] + ) + except ValueError: + raise ValueError("mismatch in `vecdot` dimensions") + res_sh = tuple( + [broadcast_sh[i] for i in range(len(broadcast_sh)) if i != axis] + ) + # type validation + sycl_dev = exec_q.sycl_device + x1_dtype = x1.dtype + x2_dtype = x2.dtype + buf1_dt, buf2_dt, res_dt = _find_buf_dtype2( + x1_dtype, + x2_dtype, + tli._dot_result_type, + sycl_dev, + acceptance_fn=_acceptance_fn_default_binary, + ) + if res_dt is None: + raise TypeError( + "function 'vecdot' does not support input types " + f"({x1_dtype}, {x2_dtype}), " + "and the inputs could not be safely coerced to any " + "supported types according to the casting rule ''safe''." 
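+            # Note on the branches below (a summary of existing behavior,
+            # not new logic): inputs that need a dtype cast are first copied
+            # into temporary buffers of the promoted type, and a complex
+            # `x1` is conjugated before the contraction, so the kernel
+            # effectively computes sum(conj(x1) * x2) over the reduced axis,
+            # the usual convention of conjugating the first `vecdot` operand.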
+ ) + + ht_list = [] + deps = [] + if buf1_dt is None and buf2_dt is None: + if x1.dtype.kind == "c": + x1_tmp = _empty_like_orderK(x1, x1.dtype) + ht_conj_ev, conj_ev = tei._conj( + src=x1, + dst=x1_tmp, + sycl_queue=exec_q, + ) + ht_list.append(ht_conj_ev) + deps.append(conj_ev) + x1 = x1_tmp + if x1.shape != broadcast_sh: + x1 = dpt.broadcast_to(x1, broadcast_sh) + if x2.shape != broadcast_sh: + x2 = dpt.broadcast_to(x2, broadcast_sh) + x1 = dpt.moveaxis(x1, axis, -1) + x2 = dpt.moveaxis(x2, axis, -1) + + out = dpt.empty( + res_sh, + dtype=res_dt, + usm_type=res_usm_type, + sycl_queue=exec_q, + order="C", + ) + ht_dot_ev, _ = tli._dot( + x1=x1, + x2=x2, + batch_dims=len(x1.shape[:-1]), + x1_outer_dims=0, + x2_outer_dims=0, + inner_dims=1, + dst=out, + sycl_queue=exec_q, + depends=deps, + ) + ht_list.append(ht_dot_ev) + dpctl.SyclEvent.wait_for(ht_list) + + return dpt.reshape(out, res_sh) + + elif buf1_dt is None: + if x1.dtype.kind == "c": + x1_tmp = _empty_like_orderK(x1, x1.dtype) + ht_conj_ev, conj_e = tei._conj( + src=x1, dst=x1_tmp, sycl_queue=exec_q + ) + ht_list.append(ht_conj_ev) + deps.append(conj_e) + x1 = x1_tmp + buf2 = _empty_like_orderK(x2, buf2_dt) + ht_copy_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( + src=x2, dst=buf2, sycl_queue=exec_q + ) + if x1.shape != broadcast_sh: + x1 = dpt.broadcast_to(x1, broadcast_sh) + if buf2.shape != broadcast_sh: + buf2 = dpt.broadcast_to(buf2, broadcast_sh) + x1 = dpt.moveaxis(x1, axis, -1) + buf2 = dpt.moveaxis(buf2, axis, -1) + out = dpt.empty( + res_sh, + dtype=res_dt, + usm_type=res_usm_type, + sycl_queue=exec_q, + order="C", + ) + ht_dot_ev, _ = tli._dot( + x1=x1, + x2=buf2, + batch_dims=len(x1.shape[:-1]), + x1_outer_dims=0, + x2_outer_dims=0, + inner_dims=1, + dst=out, + sycl_queue=exec_q, + depends=deps, + ) + ht_list.append(ht_dot_ev) + dpctl.SyclEvent.wait_for(ht_list) + + return dpt.reshape(out, res_sh) + + elif buf2_dt is None: + buf1 = _empty_like_orderK(x1, buf1_dt) + ht_copy_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( + src=x1, dst=buf1, sycl_queue=exec_q + ) + ht_list.append(ht_copy_ev) + deps.append(copy_ev) + if buf1.dtype.kind == "c": + ht_conj_ev, conj_ev = tei._conj( + src=buf1, dst=buf1, sycl_queue=exec_q, depends=[copy_ev] + ) + ht_list.append(ht_conj_ev) + deps.append(conj_ev) + if buf1.shape != broadcast_sh: + buf1 = dpt.broadcast_to(buf1, broadcast_sh) + if x2.shape != broadcast_sh: + x2 = dpt.broadcast_to(x2, broadcast_sh) + buf1 = dpt.moveaxis(buf1, axis, -1) + x2 = dpt.moveaxis(x2, axis, -1) + out = dpt.empty( + res_sh, + dtype=res_dt, + usm_type=res_usm_type, + sycl_queue=exec_q, + order="C", + ) + ht_dot_ev, _ = tli._dot( + x1=buf1, + x2=x2, + batch_dims=len(x1.shape[:-1]), + x1_outer_dims=0, + x2_outer_dims=0, + inner_dims=1, + dst=out, + sycl_queue=exec_q, + depends=deps, + ) + ht_list.append(ht_dot_ev) + dpctl.SyclEvent.wait_for(ht_list) + + return dpt.reshape(out, res_sh) + + buf1 = _empty_like_orderK(x1, buf1_dt) + ht_copy1_ev, copy1_ev = ti._copy_usm_ndarray_into_usm_ndarray( + src=x1, dst=buf1, sycl_queue=exec_q + ) + ht_list.append(ht_copy1_ev) + deps.append(copy1_ev) + if buf1.dtype.kind == "c": + ht_conj_ev, conj_ev = tei._conj( + src=buf1, dst=buf1, sycl_queue=exec_q, depends=[copy1_ev] + ) + ht_list.append(ht_conj_ev) + deps.append(conj_ev) + buf2 = _empty_like_orderK(x2, buf2_dt) + ht_copy2_ev, copy2_ev = ti._copy_usm_ndarray_into_usm_ndarray( + src=x2, dst=buf2, sycl_queue=exec_q + ) + ht_list.append(ht_copy2_ev) + deps.append(copy2_ev) + if buf1.shape != broadcast_sh: + buf1 
= dpt.broadcast_to(buf1, broadcast_sh) + if buf2.shape != broadcast_sh: + buf2 = dpt.broadcast_to(buf2, broadcast_sh) + buf1 = dpt.moveaxis(buf1, axis, -1) + buf2 = dpt.moveaxis(buf2, axis, -1) + out = dpt.empty( + res_sh, + dtype=res_dt, + usm_type=res_usm_type, + sycl_queue=exec_q, + order="C", + ) + ht_dot_ev, _ = tli._dot( + x1=buf1, + x2=buf2, + batch_dims=len(x1.shape[:-1]), + x1_outer_dims=0, + x2_outer_dims=0, + inner_dims=1, + dst=out, + sycl_queue=exec_q, + depends=deps, + ) + ht_list.append(ht_dot_ev) + dpctl.SyclEvent.wait_for(ht_list) + + return out + + +def matmul(x1, x2, out=None, dtype=None, order="K"): + """matmul(x1, x2, out=None, order="K") + + Computes the matrix product. Implements the same semantics + as the built-in operator `@`. + + Args: + x1 (usm_ndarray): + first input array. Expected to have numeric data type, and + at least one dimension. If `x1` is one-dimensional having + shape `(M,)`, and `x2` has more than one dimension, `x1` is + effectively treated as a two-dimensional array with shape `(1, M)`, + although the prepended dimension is removed from the output array. + If `x1` has shape `(..., M, K)`, the innermost two dimensions form + matrices on which to perform matrix multiplication. + x2 (usm_ndarray): + second input array. Expected to have numeric data type, and + at least one dimension. If `x2` is one-dimensional having + shape `(N,)`, and `x1` has more than one dimension, `x2` is + effectively treated as a two-dimensional array with shape `(N, 1)`, + although the appended dimension is removed from the output array. + If `x2` has shape `(..., K, N)`, the innermost two dimensions form + matrices on which to perform matrix multiplication. + out (Optional[usm_ndarray]): + the array into which the result of the matrix product is written. + If `None` then a new array is returned. + order (["K", "C", "F", "A"]): + memory layout of the output array, if `out` is `None`, otherwise + the `order` parameter value is not used. + + Returns: + usm_ndarray: + * if both `x1` and `x2` are one-dimensional arrays with shape + `(N,)`, returned array is a zero-dimensional array containing + inner product as its only element. + * if `x1` is two-dimensional array with shape `(M, K)` and `x2` is + a two-dimensional array with shape `(K, N)`, returned array is a + two-dimensional array with shape `(M, N)` and contains the + conventional matrix product. + * if `x1` is a one-dimensinal array with shape `(K,)` and `x2` is an + array with shape `(..., K, N)`, returned array contains the + conventional matrix product and has shape `(..., N)`. + * if `x1` is an array with shape `(..., M, K)` and `x2` is a + one-dimensional array with shape `(K,)`, returned array has shape + `(..., M)` and contains the conventional matrix product. + * if `x1` is a two-dimensional array with shape `(M, K)` and `x2` + is an array with shape `(..., K, N)`, returned array contains + conventional matrix product for each stacked matrix and has shape + `(..., M, N)`. + * if `x1` has shape `(..., M, K)` and `x2` is a two-dimensional + array with shape `(K, N)`, returned array contains conventional + matrix product for each stacked matrix and has shape + `(..., M, N)`. + * if both `x1` and `x2` have more than two dimensions, returned + array contains conventional matrix product for each stacked + matrix and has shape determined by broadcasting rules for + `x1.shape[:-2]` and `x2.shape[:-2]`. + + The data type of the returned array is determined by the Type + Promotion Rules. 
If either `x1` or `x2` has a complex floating + point type, neither argument is complex conjugated or transposed. + """ + if not isinstance(x1, dpt.usm_ndarray): + raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x1)}") + if not isinstance(x2, dpt.usm_ndarray): + raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x2)}") + if order not in ["K", "C", "F", "A"]: + order = "K" + q1, x1_usm_type = x1.sycl_queue, x1.usm_type + q2, x2_usm_type = x2.sycl_queue, x2.usm_type + if q1 is None and q2 is None: + raise ExecutionPlacementError( + "Execution placement can not be unambiguously inferred " + "from input arguments. " + "One of the arguments must represent USM allocation and " + "expose `__sycl_usm_array_interface__` property" + ) + if q1 is None: + exec_q = q2 + res_usm_type = x2_usm_type + elif q2 is None: + exec_q = q1 + res_usm_type = x1_usm_type + else: + exec_q = dpctl.utils.get_execution_queue((q1, q2)) + if exec_q is None: + raise ExecutionPlacementError( + "Execution placement can not be unambiguously inferred " + "from input arguments." + ) + res_usm_type = dpctl.utils.get_coerced_usm_type( + ( + x1_usm_type, + x2_usm_type, + ) + ) + dpctl.utils.validate_usm_type(res_usm_type, allow_none=False) + + x1_nd = x1.ndim + x2_nd = x2.ndim + if x1_nd == 0 or x2_nd == 0: + raise ValueError("one or more operands to `matmul` is 0 dimensional") + x1_shape = x1.shape + x2_shape = x2.shape + appended_axes = [] + if x1_nd == 1: + x1 = x1[dpt.newaxis, :] + x1_shape = x1.shape + appended_axes.append(-2) + if x2_nd == 1: + x2 = x2[:, dpt.newaxis] + x2_shape = x2.shape + appended_axes.append(-1) + if x1_shape[-1] != x2_shape[-2]: + raise ValueError("mismatch in `matmul` inner dimension") + x1_outer_sh = x1_shape[:-2] + x2_outer_sh = x2_shape[:-2] + try: + res_outer_sh = _broadcast_shape_impl( + [ + x1_outer_sh, + x2_outer_sh, + ] + ) + except ValueError: + raise ValueError("mismatch in `matmul` batching dimensions") + x1_broadcast_shape = res_outer_sh + x1_shape[-2:] + x2_broadcast_shape = res_outer_sh + x2_shape[-2:] + res_shape = res_outer_sh + x1_shape[-2:-1] + x2_shape[-1:] + + sycl_dev = exec_q.sycl_device + x1_dtype = x1.dtype + x2_dtype = x2.dtype + if dtype is None: + buf1_dt, buf2_dt, res_dt = _find_buf_dtype2( + x1_dtype, + x2_dtype, + tli._dot_result_type, + sycl_dev, + acceptance_fn=_acceptance_fn_default_binary, + ) + if res_dt is None: + raise ValueError( + "function 'matmul' does not support input types " + f"({x1_dtype}, {x2_dtype}), " + "and the inputs could not be safely coerced to any " + "supported types according to the casting rule ''safe''." + ) + else: + res_dt = dpt.dtype(dtype) + res_dt = _to_device_supported_dtype(res_dt, sycl_dev) + if x1_dtype != res_dt: + if dpt.can_cast(x1_dtype, res_dt, casting="same_kind"): + buf1_dt = res_dt + else: + raise ValueError( + f"`matmul` input `x1` cannot be cast from {x1_dtype} to " + f"requested type {res_dt} according to the casting rule " + "''same_kind''." + ) + if x2_dtype != res_dt: + if dpt.can_cast(x2_dtype, res_dt, casting="same_kind"): + buf2_dt = res_dt + else: + raise ValueError( + f"`matmul` input `x2` cannot be cast from {x2_dtype} to " + f"requested type {res_dt} according to the casting rule " + "''same_kind''." + ) + + orig_out = out + if out is not None: + if not isinstance(out, dpt.usm_ndarray): + raise TypeError( + f"output array must be of usm_ndarray type, got {type(out)}" + ) + + if out.shape != res_shape: + raise ValueError( + "The shape of input and output arrays are inconsistent. 
" + f"Expected output shape is {res_shape}, got {out.shape}" + ) + + if res_dt != out.dtype: + raise ValueError( + f"Output array of type {res_dt} is needed," f"got {out.dtype}" + ) + + if dpctl.utils.get_execution_queue((exec_q, out.sycl_queue)) is None: + raise ExecutionPlacementError( + "Input and output allocation queues are not compatible" + ) + + if isinstance(x1, dpt.usm_ndarray): + if ti._array_overlap(x1, out) and buf1_dt is None: + out = dpt.empty_like(out) + + if isinstance(x2, dpt.usm_ndarray): + if ti._array_overlap(x2, out) and buf2_dt is None: + # should not reach if out is reallocated + # after being checked against x1 + out = dpt.empty_like(out) + + if buf1_dt is None and buf2_dt is None: + if out is None: + if order == "K": + out = _empty_like_pair_orderK( + x1, x2, res_dt, res_shape, res_usm_type, exec_q + ) + else: + if order == "A": + order = ( + "F" + if all( + arr.flags.f_contiguous + for arr in ( + x1, + x2, + ) + ) + else "C" + ) + out = dpt.empty( + res_shape, + dtype=res_dt, + usm_type=res_usm_type, + sycl_queue=exec_q, + order=order, + ) + if x1.shape != res_shape: + x1 = dpt.broadcast_to(x1, x1_broadcast_shape) + if x2.shape != res_shape: + x2 = dpt.broadcast_to(x2, x2_broadcast_shape) + ht_dot_ev, binary_ev = tli._dot( + x1=x1, + x2=x2, + batch_dims=len(res_shape[:-2]), + x1_outer_dims=1, + x2_outer_dims=1, + inner_dims=1, + dst=out, + sycl_queue=exec_q, + ) + if not (orig_out is None or orig_out is out): + # Copy the out data from temporary buffer to original memory + ht_copy_out_ev, _ = ti._copy_usm_ndarray_into_usm_ndarray( + src=out, + dst=orig_out, + sycl_queue=exec_q, + depends=[binary_ev], + ) + ht_copy_out_ev.wait() + out = orig_out + ht_dot_ev.wait() + if appended_axes: + out = dpt.squeeze(out, tuple(appended_axes)) + return out + elif buf1_dt is None: + if order == "K": + buf2 = _empty_like_orderK(x2, buf2_dt) + else: + if order == "A": + order = "F" if x1.flags.f_contiguous else "C" + buf2 = dpt.empty_like(x2, dtype=buf2_dt, order=order) + ht_copy_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( + src=x2, dst=buf2, sycl_queue=exec_q + ) + if out is None: + if order == "K": + out = _empty_like_pair_orderK( + x1, buf2, res_dt, res_shape, res_usm_type, exec_q + ) + else: + out = dpt.empty( + res_shape, + dtype=res_dt, + usm_type=res_usm_type, + sycl_queue=exec_q, + order=order, + ) + + if x1.shape != res_shape: + x1 = dpt.broadcast_to(x1, x1_broadcast_shape) + buf2 = dpt.broadcast_to(buf2, x2_broadcast_shape) + ht_dot_ev, binary_ev = tli._dot( + x1=x1, + x2=buf2, + batch_dims=len(res_shape[:-2]), + x1_outer_dims=1, + x2_outer_dims=1, + inner_dims=1, + dst=out, + sycl_queue=exec_q, + depends=[copy_ev], + ) + if not (orig_out is None or orig_out is out): + # Copy the out data from temporary buffer to original memory + ht_copy_out_ev, _ = ti._copy_usm_ndarray_into_usm_ndarray( + src=out, + dst=orig_out, + sycl_queue=exec_q, + depends=[binary_ev], + ) + ht_copy_out_ev.wait() + out = orig_out + ht_copy_ev.wait() + ht_dot_ev.wait() + if appended_axes: + out = dpt.squeeze(out, tuple(appended_axes)) + return out + + elif buf2_dt is None: + if order == "K": + buf1 = _empty_like_orderK(x1, buf1_dt) + else: + if order == "A": + order = "F" if x1.flags.f_contiguous else "C" + buf1 = dpt.empty_like(x1, dtype=buf1_dt, order=order) + ht_copy_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( + src=x1, dst=buf1, sycl_queue=exec_q + ) + if out is None: + if order == "K": + out = _empty_like_pair_orderK( + buf1, x2, res_dt, res_shape, res_usm_type, exec_q + ) + 
else: + out = dpt.empty( + res_shape, + dtype=res_dt, + usm_type=res_usm_type, + sycl_queue=exec_q, + order=order, + ) + + buf1 = dpt.broadcast_to(buf1, x1_broadcast_shape) + if x2.shape != res_shape: + x2 = dpt.broadcast_to(x2, x2_broadcast_shape) + ht_dot_ev, binary_ev = tli._dot( + x1=buf1, + x2=x2, + batch_dims=len(res_shape[:-2]), + x1_outer_dims=1, + x2_outer_dims=1, + inner_dims=1, + dst=out, + sycl_queue=exec_q, + depends=[copy_ev], + ) + if not (orig_out is None or orig_out is out): + # Copy the out data from temporary buffer to original memory + ht_copy_out_ev, _ = ti._copy_usm_ndarray_into_usm_ndarray( + src=out, + dst=orig_out, + sycl_queue=exec_q, + depends=[binary_ev], + ) + ht_copy_out_ev.wait() + out = orig_out + ht_copy_ev.wait() + ht_dot_ev.wait() + if appended_axes: + out = dpt.squeeze(out, tuple(appended_axes)) + return out + + if order in ["K", "A"]: + if x1.flags.f_contiguous and x2.flags.f_contiguous: + order = "F" + elif x1.flags.c_contiguous and x2.flags.c_contiguous: + order = "C" + else: + order = "C" if order == "A" else "K" + if order == "K": + buf1 = _empty_like_orderK(x1, buf1_dt) + else: + buf1 = dpt.empty_like(x1, dtype=buf1_dt, order=order) + ht_copy1_ev, copy1_ev = ti._copy_usm_ndarray_into_usm_ndarray( + src=x1, dst=buf1, sycl_queue=exec_q + ) + if order == "K": + buf2 = _empty_like_orderK(x2, buf2_dt) + else: + buf2 = dpt.empty_like(x2, dtype=buf2_dt, order=order) + ht_copy2_ev, copy2_ev = ti._copy_usm_ndarray_into_usm_ndarray( + src=x2, dst=buf2, sycl_queue=exec_q + ) + if out is None: + if order == "K": + out = _empty_like_pair_orderK( + buf1, buf2, res_dt, res_shape, res_usm_type, exec_q + ) + else: + out = dpt.empty( + res_shape, + dtype=res_dt, + usm_type=res_usm_type, + sycl_queue=exec_q, + order=order, + ) + + buf1 = dpt.broadcast_to(buf1, x1_broadcast_shape) + buf2 = dpt.broadcast_to(buf2, x2_broadcast_shape) + ht_, _ = tli._dot( + x1=buf1, + x2=buf2, + batch_dims=len(res_shape[:-2]), + x1_outer_dims=1, + x2_outer_dims=1, + inner_dims=1, + dst=out, + sycl_queue=exec_q, + depends=[copy1_ev, copy2_ev], + ) + dpctl.SyclEvent.wait_for([ht_copy1_ev, ht_copy2_ev, ht_]) + if appended_axes: + out = dpt.squeeze(out, tuple(appended_axes)) + return out diff --git a/dpctl/tensor/libtensor/include/kernels/linalg_functions/dot_product.hpp b/dpctl/tensor/libtensor/include/kernels/linalg_functions/dot_product.hpp new file mode 100644 index 0000000000..15e5e35d67 --- /dev/null +++ b/dpctl/tensor/libtensor/include/kernels/linalg_functions/dot_product.hpp @@ -0,0 +1,1137 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include "kernels/reductions.hpp" +#include "pybind11/pybind11.h" +#include "utils/offset_utils.hpp" +#include "utils/sycl_utils.hpp" +#include "utils/type_utils.hpp" + +namespace dpctl +{ +namespace tensor +{ +namespace kernels +{ + +template +struct SequentialDotProduct +{ +private: + const lhsT *lhs_ = nullptr; + const rhsT *rhs_ = nullptr; + outT *out_ = nullptr; + BatchIndexerT batch_indexer_; + RedIndexerT reduced_dims_indexer_; + size_t reduction_max_gid_ = 0; + +public: + SequentialDotProduct(const lhsT *lhs, + const rhsT *rhs, + outT *out, + BatchIndexerT batch_indexer, + RedIndexerT reduced_dims_indexer, + size_t reduction_size) + : lhs_(lhs), rhs_(rhs), out_(out), batch_indexer_(batch_indexer), + reduced_dims_indexer_(reduced_dims_indexer), + reduction_max_gid_(reduction_size) + { + } + + void operator()(sycl::id<1> id) const + { + + auto const &batch_offsets = batch_indexer_(id[0]); + const 
py::ssize_t &lhs_batch_offset = batch_offsets.get_first_offset(); + const py::ssize_t &rhs_batch_offset = batch_offsets.get_second_offset(); + const py::ssize_t &out_batch_offset = batch_offsets.get_third_offset(); + + outT red_val(0); + for (size_t m = 0; m < reduction_max_gid_; ++m) { + auto reduction_offsets = reduced_dims_indexer_(m); + auto lhs_reduction_offset = reduction_offsets.get_first_offset(); + auto rhs_reduction_offset = reduction_offsets.get_second_offset(); + + using dpctl::tensor::type_utils::convert_impl; + red_val += convert_impl( + lhs_[lhs_batch_offset + lhs_reduction_offset]) * + convert_impl( + rhs_[rhs_batch_offset + rhs_reduction_offset]); + } + + out_[out_batch_offset] = red_val; + } +}; + +template +struct DotProductFunctor +{ +private: + const lhsT *lhs_ = nullptr; + const rhsT *rhs_ = nullptr; + outT *out_ = nullptr; + BatchIndexerT batch_indexer_; + RedIndexerT reduced_dims_indexer_; + size_t reduction_max_gid_ = 0; + size_t batches_ = 1; + size_t reductions_per_wi = 16; + +public: + DotProductFunctor(const lhsT *lhs, + const rhsT *rhs, + outT *res, + BatchIndexerT batch_indexer, + RedIndexerT arg_reduced_dims_indexer, + size_t reduction_size, + size_t iteration_size, + size_t reduction_size_per_wi) + : lhs_(lhs), rhs_(rhs), out_(res), batch_indexer_(batch_indexer), + reduced_dims_indexer_(arg_reduced_dims_indexer), + reduction_max_gid_(reduction_size), batches_(iteration_size), + reductions_per_wi(reduction_size_per_wi) + { + } + + void operator()(sycl::nd_item<1> it) const + { + const size_t batch_id = it.get_group(0) % batches_; + const size_t reduction_batch_id = it.get_group(0) / batches_; + + const size_t reduction_lid = it.get_local_id(0); + const size_t wg = it.get_local_range(0); // 0 <= reduction_lid < wg + + // work-items operate over input with indices + // inp_data_id = reduction_batch_id * wg * reductions_per_wi + m * wg + // + reduction_lid + // for 0 <= m < reductions_per_wi + // for each input + + auto batch_offsets_ = batch_indexer_(batch_id); + const auto &lhs_batch_offset = batch_offsets_.get_first_offset(); + const auto &rhs_batch_offset = batch_offsets_.get_second_offset(); + const auto &out_batch_offset = batch_offsets_.get_third_offset(); + + outT local_red_val(0); + size_t arg_reduce_gid0 = + reduction_lid + reduction_batch_id * wg * reductions_per_wi; + size_t arg_reduce_gid_max = std::min( + reduction_max_gid_, arg_reduce_gid0 + reductions_per_wi * wg); + + for (size_t arg_reduce_gid = arg_reduce_gid0; + arg_reduce_gid < arg_reduce_gid_max; arg_reduce_gid += wg) + { + auto reduction_offsets_ = reduced_dims_indexer_(arg_reduce_gid); + const auto &lhs_reduction_offset = + reduction_offsets_.get_first_offset(); + const auto &rhs_reduction_offset = + reduction_offsets_.get_second_offset(); + + using dpctl::tensor::type_utils::convert_impl; + outT val = convert_impl( + lhs_[lhs_batch_offset + lhs_reduction_offset]) * + convert_impl( + rhs_[rhs_batch_offset + rhs_reduction_offset]); + + local_red_val += val; + } + + auto work_group = it.get_group(); + outT red_val_over_wg = sycl::reduce_over_group( + work_group, local_red_val, outT(0), sycl::plus()); + + if (work_group.leader()) { + sycl::atomic_ref + res_ref(out_[out_batch_offset]); + res_ref += red_val_over_wg; + } + } +}; + +template +class dot_product_seq_krn; + +template class dot_product_init_krn; + +template +class dot_product_krn; + +typedef sycl::event (*dot_product_impl_fn_ptr_t)( + sycl::queue &, + size_t, + size_t, + const char *, + const char *, + char *, + int, + const 
py::ssize_t *, + py::ssize_t, + py::ssize_t, + py::ssize_t, + int, + const py::ssize_t *, + py::ssize_t, + py::ssize_t, + const std::vector &); + +template +sycl::event dot_product_impl(sycl::queue &exec_q, + size_t batches, + size_t reduction_nelems, + const char *lhs_cp, + const char *rhs_cp, + char *res_cp, + int batch_nd, + const py::ssize_t *batch_shape_and_strides, + py::ssize_t batch_lhs_offset, + py::ssize_t batch_rhs_offset, + py::ssize_t batch_res_offset, + int red_nd, + const py::ssize_t *reduction_shape_stride, + py::ssize_t reduction_lhs_offset, + py::ssize_t reduction_rhs_offset, + const std::vector &depends = {}) +{ + const lhsTy *lhs_tp = reinterpret_cast(lhs_cp); + const rhsTy *rhs_tp = reinterpret_cast(rhs_cp); + resTy *res_tp = reinterpret_cast(res_cp); + + const sycl::device &d = exec_q.get_device(); + const auto &sg_sizes = d.get_info(); + size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes); + + if (reduction_nelems < wg) { + sycl::event dot_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using InputOutputBatchIndexerT = + dpctl::tensor::offset_utils::ThreeOffsets_StridedIndexer; + using ReductionIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_StridedIndexer; + + InputOutputBatchIndexerT in_out_batch_indexer{ + batch_nd, batch_lhs_offset, batch_rhs_offset, batch_res_offset, + batch_shape_and_strides}; + ReductionIndexerT reduction_indexer{red_nd, reduction_lhs_offset, + reduction_rhs_offset, + reduction_shape_stride}; + + cgh.parallel_for>( + sycl::range<1>(batches), + SequentialDotProduct( + lhs_tp, rhs_tp, res_tp, in_out_batch_indexer, + reduction_indexer, reduction_nelems)); + }); + + return dot_ev; + } + else { + sycl::event res_init_ev = exec_q.submit([&](sycl::handler &cgh) { + using IndexerT = + dpctl::tensor::offset_utils::UnpackedStridedIndexer; + + const py::ssize_t *const &res_shape = batch_shape_and_strides; + const py::ssize_t *const &res_strides = + batch_shape_and_strides + 3 * batch_nd; + IndexerT res_indexer(batch_nd, batch_res_offset, res_shape, + res_strides); + using InitKernelName = + class dot_product_init_krn; + cgh.depends_on(depends); + + cgh.parallel_for( + sycl::range<1>(batches), [=](sycl::id<1> id) { + auto res_offset = res_indexer(id[0]); + res_tp[res_offset] = 0; + }); + }); + + sycl::event dot_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(res_init_ev); + + using BatchIndexerT = + dpctl::tensor::offset_utils::ThreeOffsets_StridedIndexer; + using ReductionIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_StridedIndexer; + + BatchIndexerT batch_indexer{batch_nd, batch_lhs_offset, + batch_rhs_offset, batch_res_offset, + batch_shape_and_strides}; + ReductionIndexerT reduction_indexer{red_nd, reduction_lhs_offset, + reduction_rhs_offset, + reduction_shape_stride}; + + constexpr size_t preferred_reductions_per_wi = + 4; // determined experimentally + size_t reductions_per_wi = + (reduction_nelems < preferred_reductions_per_wi * wg) + ? 
std::max(1, (reduction_nelems + wg - 1) / wg) + : preferred_reductions_per_wi; + + size_t reduction_groups = + (reduction_nelems + reductions_per_wi * wg - 1) / + (reductions_per_wi * wg); + + auto globalRange = sycl::range<1>{batches * reduction_groups * wg}; + auto localRange = sycl::range<1>{wg}; + + using KernelName = + class dot_product_krn; + + cgh.parallel_for( + sycl::nd_range<1>(globalRange, localRange), + DotProductFunctor( + lhs_tp, rhs_tp, res_tp, batch_indexer, reduction_indexer, + reduction_nelems, batches, reductions_per_wi)); + }); + return dot_ev; + } +} + +typedef sycl::event (*dot_product_contig_impl_fn_ptr_t)( + sycl::queue &, + size_t, + size_t, + const char *, + const char *, + char *, + py::ssize_t, + py::ssize_t, + py::ssize_t, + py::ssize_t, + py::ssize_t, + const std::vector &); + +template +sycl::event +dot_product_contig_impl(sycl::queue &exec_q, + size_t batches, + size_t reduction_nelems, + const char *lhs_cp, + const char *rhs_cp, + char *res_cp, + py::ssize_t batch_lhs_offset, + py::ssize_t batch_rhs_offset, + py::ssize_t batch_res_offset, + py::ssize_t reduction_lhs_offset, + py::ssize_t reduction_rhs_offset, + const std::vector &depends = {}) +{ + const lhsTy *lhs_tp = reinterpret_cast(lhs_cp) + + batch_lhs_offset + reduction_lhs_offset; + const rhsTy *rhs_tp = reinterpret_cast(rhs_cp) + + batch_rhs_offset + reduction_rhs_offset; + resTy *res_tp = reinterpret_cast(res_cp) + batch_res_offset; + + const sycl::device &d = exec_q.get_device(); + const auto &sg_sizes = d.get_info(); + size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes); + + if (reduction_nelems < wg) { + sycl::event dot_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using InputBatchIndexerT = + dpctl::tensor::offset_utils::Strided1DIndexer; + using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputBatchIndexerT = + dpctl::tensor::offset_utils::ThreeOffsets_CombinedIndexer< + InputBatchIndexerT, InputBatchIndexerT, NoOpIndexerT>; + using ReductionIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + NoOpIndexerT, NoOpIndexerT>; + + InputBatchIndexerT inp_batch_indexer{ + 0, static_cast(reduction_nelems), + static_cast(batches)}; + InputOutputBatchIndexerT inp_out_batch_indexer{ + inp_batch_indexer, inp_batch_indexer, NoOpIndexerT{}}; + ReductionIndexerT reduction_indexer{NoOpIndexerT{}, NoOpIndexerT{}}; + + cgh.parallel_for>( + sycl::range<1>(batches), + SequentialDotProduct( + lhs_tp, rhs_tp, res_tp, inp_out_batch_indexer, + reduction_indexer, reduction_nelems)); + }); + + return dot_ev; + } + else { + sycl::event res_init_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + cgh.fill(res_tp, resTy(0), batches); + }); + + sycl::event dot_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(res_init_ev); + + using InputBatchIndexerT = + dpctl::tensor::offset_utils::Strided1DIndexer; + using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputBatchIndexerT = + dpctl::tensor::offset_utils::ThreeOffsets_CombinedIndexer< + InputBatchIndexerT, InputBatchIndexerT, NoOpIndexerT>; + using ReductionIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + NoOpIndexerT, NoOpIndexerT>; + + InputBatchIndexerT inp_batch_indexer{ + 0, static_cast(reduction_nelems), + static_cast(batches)}; + InputOutputBatchIndexerT inp_out_batch_indexer{ + inp_batch_indexer, inp_batch_indexer, NoOpIndexerT{}}; + ReductionIndexerT reduction_indexer{NoOpIndexerT{}, 
NoOpIndexerT{}}; + + constexpr size_t preferred_reductions_per_wi = + 4; // determined experimentally + size_t reductions_per_wi = + (reduction_nelems < preferred_reductions_per_wi * wg) + ? std::max(1, (reduction_nelems + wg - 1) / wg) + : preferred_reductions_per_wi; + + size_t reduction_groups = + (reduction_nelems + reductions_per_wi * wg - 1) / + (reductions_per_wi * wg); + + auto globalRange = sycl::range<1>{batches * reduction_groups * wg}; + auto localRange = sycl::range<1>{wg}; + + using KernelName = class dot_product_krn; + + cgh.parallel_for( + sycl::nd_range<1>(globalRange, localRange), + DotProductFunctor( + lhs_tp, rhs_tp, res_tp, inp_out_batch_indexer, + reduction_indexer, reduction_nelems, batches, + reductions_per_wi)); + }); + return dot_ev; + } +} + +template +struct DotProductNoAtomicFunctor +{ +private: + const lhsT *lhs_ = nullptr; + const rhsT *rhs_ = nullptr; + outT *out_ = nullptr; + BatchIndexerT batch_indexer_; + RedIndexerT reduced_dims_indexer_; + size_t reduction_max_gid_ = 0; + size_t batches_ = 1; + size_t reductions_per_wi = 16; + +public: + DotProductNoAtomicFunctor(const lhsT *lhs, + const rhsT *rhs, + outT *res, + BatchIndexerT batch_indexer, + RedIndexerT arg_reduced_dims_indexer, + size_t reduction_size, + size_t iteration_size, + size_t reduction_size_per_wi) + : lhs_(lhs), rhs_(rhs), out_(res), batch_indexer_(batch_indexer), + reduced_dims_indexer_(arg_reduced_dims_indexer), + reduction_max_gid_(reduction_size), batches_(iteration_size), + reductions_per_wi(reduction_size_per_wi) + { + } + + void operator()(sycl::nd_item<1> it) const + { + const size_t reduction_lid = it.get_local_id(0); + const size_t wg = it.get_local_range(0); // 0 <= reduction_lid < wg + + const size_t batch_id = it.get_group(0) % batches_; + const size_t reduction_batch_id = it.get_group(0) / batches_; + const size_t n_reduction_groups = it.get_group_range(0) / batches_; + + // work-items operate over input with indices + // inp_data_id = reduction_batch_id * wg * reductions_per_wi + m * wg + // + reduction_lid + // for 0 <= m < reductions_per_wi + // for each input + + auto batch_offsets_ = batch_indexer_(batch_id); + const auto &lhs_batch_offset = batch_offsets_.get_first_offset(); + const auto &rhs_batch_offset = batch_offsets_.get_second_offset(); + const auto &out_batch_offset = batch_offsets_.get_third_offset(); + + outT local_red_val(0); + size_t arg_reduce_gid0 = + reduction_lid + reduction_batch_id * wg * reductions_per_wi; + size_t arg_reduce_gid_max = std::min( + reduction_max_gid_, arg_reduce_gid0 + reductions_per_wi * wg); + + for (size_t arg_reduce_gid = arg_reduce_gid0; + arg_reduce_gid < arg_reduce_gid_max; arg_reduce_gid += wg) + { + auto reduction_offsets_ = reduced_dims_indexer_(arg_reduce_gid); + const auto &lhs_reduction_offset = + reduction_offsets_.get_first_offset(); + const auto &rhs_reduction_offset = + reduction_offsets_.get_second_offset(); + + using dpctl::tensor::type_utils::convert_impl; + outT val = convert_impl( + lhs_[lhs_batch_offset + lhs_reduction_offset]) * + convert_impl( + rhs_[rhs_batch_offset + rhs_reduction_offset]); + + local_red_val += val; + } + + auto work_group = it.get_group(); + outT red_val_over_wg = sycl::reduce_over_group( + work_group, local_red_val, outT(0), sycl::plus()); + + if (work_group.leader()) { + // each group writes to a different memory location + out_[out_batch_offset * n_reduction_groups + reduction_batch_id] = + red_val_over_wg; + } + } +}; + +template +class dot_product_tree_krn; + +template +class 
dot_product_reduction_over_group_temps_krn; + +template +sycl::event dot_product_tree_impl(sycl::queue &exec_q, + size_t batches, + size_t reduction_nelems, + const char *lhs_cp, + const char *rhs_cp, + char *res_cp, + int batch_nd, + const py::ssize_t *batch_shape_and_strides, + py::ssize_t batch_lhs_offset, + py::ssize_t batch_rhs_offset, + py::ssize_t batch_res_offset, + int red_nd, + const py::ssize_t *reduction_shape_stride, + py::ssize_t reduction_lhs_offset, + py::ssize_t reduction_rhs_offset, + const std::vector &depends = {}) +{ + const lhsTy *lhs_tp = reinterpret_cast(lhs_cp); + const rhsTy *rhs_tp = reinterpret_cast(rhs_cp); + resTy *res_tp = reinterpret_cast(res_cp); + + const sycl::device &d = exec_q.get_device(); + const auto &sg_sizes = d.get_info(); + size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes); + + if (reduction_nelems < wg) { + sycl::event dot_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using InputOutputBatchIndexerT = + dpctl::tensor::offset_utils::ThreeOffsets_StridedIndexer; + using ReductionIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_StridedIndexer; + + InputOutputBatchIndexerT in_out_batch_indexer{ + batch_nd, batch_lhs_offset, batch_rhs_offset, batch_res_offset, + batch_shape_and_strides}; + ReductionIndexerT reduction_indexer{red_nd, reduction_lhs_offset, + reduction_rhs_offset, + reduction_shape_stride}; + + cgh.parallel_for>( + sycl::range<1>(batches), + SequentialDotProduct( + lhs_tp, rhs_tp, res_tp, in_out_batch_indexer, + reduction_indexer, reduction_nelems)); + }); + + return dot_ev; + } + + constexpr size_t preferred_reductions_per_wi = 8; + // max_max_wg prevents running out of resources on CPU + constexpr size_t max_max_wg = 2048; + size_t max_wg = std::min( + max_max_wg, d.get_info() / 2); + + size_t reductions_per_wi(preferred_reductions_per_wi); + if (reduction_nelems <= preferred_reductions_per_wi * max_wg) { + sycl::event dot_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using BatchIndexerT = + dpctl::tensor::offset_utils::ThreeOffsets_StridedIndexer; + using ReductionIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_StridedIndexer; + + BatchIndexerT batch_indexer{batch_nd, batch_lhs_offset, + batch_rhs_offset, batch_res_offset, + batch_shape_and_strides}; + ReductionIndexerT reduction_indexer{red_nd, reduction_lhs_offset, + reduction_rhs_offset, + reduction_shape_stride}; + + if (batches == 1) { + // increase GPU occupancy + wg = max_wg; + } + reductions_per_wi = + std::max(1, (reduction_nelems + wg - 1) / wg); + + size_t reduction_groups = + (reduction_nelems + reductions_per_wi * wg - 1) / + (reductions_per_wi * wg); + assert(reduction_groups == 1); + + auto globalRange = sycl::range<1>{batches * reduction_groups * wg}; + auto localRange = sycl::range<1>{wg}; + using KernelName = + dot_product_tree_krn; + cgh.parallel_for( + sycl::nd_range<1>(globalRange, localRange), + DotProductNoAtomicFunctor( + lhs_tp, rhs_tp, res_tp, batch_indexer, reduction_indexer, + reduction_nelems, batches, reductions_per_wi)); + }); + + return dot_ev; + } + else { + using ReductionOpT = sycl::plus; + constexpr resTy identity_val = + sycl::known_identity::value; + + // more than one work-groups is needed, requires a temporary + size_t reduction_groups = + (reduction_nelems + preferred_reductions_per_wi * wg - 1) / + (preferred_reductions_per_wi * wg); + assert(reduction_groups > 1); + + size_t second_iter_reduction_groups_ = + (reduction_groups + preferred_reductions_per_wi * wg 
- 1) / + (preferred_reductions_per_wi * wg); + + resTy *partially_reduced_tmp = sycl::malloc_device( + batches * (reduction_groups + second_iter_reduction_groups_), + exec_q); + resTy *partially_reduced_tmp2 = nullptr; + + if (partially_reduced_tmp == nullptr) { + throw std::runtime_error("Unable to allocate device_memory"); + } + else { + partially_reduced_tmp2 = + partially_reduced_tmp + reduction_groups * batches; + } + + const sycl::event &first_reduction_ev = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using LhsIndexerT = dpctl::tensor::offset_utils::StridedIndexer; + using RhsIndexerT = + dpctl::tensor::offset_utils::UnpackedStridedIndexer; + using ResIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputBatchIndexerT = + dpctl::tensor::offset_utils::ThreeOffsets_CombinedIndexer< + LhsIndexerT, RhsIndexerT, ResIndexerT>; + using ReductionIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_StridedIndexer; + + LhsIndexerT lhs_indexer(batch_nd, batch_lhs_offset, + batch_shape_and_strides); + RhsIndexerT rhs_indexer(batch_nd, batch_rhs_offset, + batch_shape_and_strides, + batch_shape_and_strides + 2 * batch_nd); + ResIndexerT noop_tmp_indexer{}; + + InputOutputBatchIndexerT in_out_iter_indexer{ + lhs_indexer, rhs_indexer, noop_tmp_indexer}; + ReductionIndexerT reduction_indexer{ + red_nd, reduction_lhs_offset, reduction_rhs_offset, + reduction_shape_stride}; + + auto globalRange = + sycl::range<1>{batches * reduction_groups * wg}; + auto localRange = sycl::range<1>{wg}; + + using KernelName = + class dot_product_tree_krn; + cgh.parallel_for( + sycl::nd_range<1>(globalRange, localRange), + DotProductNoAtomicFunctor( + lhs_tp, rhs_tp, partially_reduced_tmp, + in_out_iter_indexer, reduction_indexer, + reduction_nelems, batches, + preferred_reductions_per_wi)); + }); + + size_t remaining_reduction_nelems = reduction_groups; + + resTy *temp_arg = partially_reduced_tmp; + resTy *temp2_arg = partially_reduced_tmp2; + sycl::event dependent_ev = first_reduction_ev; + + while (remaining_reduction_nelems > + preferred_reductions_per_wi * max_wg) { + size_t reduction_groups_ = (remaining_reduction_nelems + + preferred_reductions_per_wi * wg - 1) / + (preferred_reductions_per_wi * wg); + assert(reduction_groups_ > 1); + + sycl::event partial_reduction_ev = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(dependent_ev); + + using InputIndexerT = + dpctl::tensor::offset_utils::Strided1DIndexer; + using ResIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + InputIndexerT, ResIndexerT>; + using ReductionIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + + InputIndexerT inp_indexer{ + 0, static_cast(batches), + static_cast(reduction_groups_)}; + ResIndexerT res_iter_indexer{}; + + InputOutputIterIndexerT in_out_iter_indexer{ + inp_indexer, res_iter_indexer}; + ReductionIndexerT reduction_indexer{}; + + auto globalRange = + sycl::range<1>{batches * reduction_groups_ * wg}; + auto localRange = sycl::range<1>{wg}; + using KernelName = + class dot_product_reduction_over_group_temps_krn< + resTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT>; + cgh.parallel_for( + sycl::nd_range<1>(globalRange, localRange), + ReductionOverGroupNoAtomicFunctor< + resTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT>( + temp_arg, temp2_arg, ReductionOpT(), identity_val, + in_out_iter_indexer, reduction_indexer, + 
remaining_reduction_nelems, batches, + preferred_reductions_per_wi)); + }); + remaining_reduction_nelems = reduction_groups_; + std::swap(temp_arg, temp2_arg); + dependent_ev = std::move(partial_reduction_ev); + } + // final reduction to res + sycl::event final_reduction_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(dependent_ev); + + using InputIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer; + using ResIndexerT = + dpctl::tensor::offset_utils::UnpackedStridedIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + InputIndexerT, ResIndexerT>; + using ReductionIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + + InputIndexerT inp_indexer{ + 0, static_cast(batches), + static_cast(remaining_reduction_nelems)}; + ResIndexerT res_iter_indexer{batch_nd, batch_res_offset, + /* shape */ batch_shape_and_strides, + /* strides */ batch_shape_and_strides + + 2 * batch_nd}; + + InputOutputIterIndexerT in_out_iter_indexer{inp_indexer, + res_iter_indexer}; + ReductionIndexerT reduction_indexer{}; + + wg = max_wg; + reductions_per_wi = + std::max(1, (remaining_reduction_nelems + wg - 1) / wg); + + size_t reduction_groups = + (remaining_reduction_nelems + reductions_per_wi * wg - 1) / + (reductions_per_wi * wg); + assert(reduction_groups == 1); + + auto globalRange = sycl::range<1>{batches * reduction_groups * wg}; + auto localRange = sycl::range<1>{wg}; + + using KernelName = class dot_product_reduction_over_group_temps_krn< + resTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT>; + cgh.parallel_for( + sycl::nd_range<1>(globalRange, localRange), + ReductionOverGroupNoAtomicFunctor( + temp_arg, res_tp, ReductionOpT(), identity_val, + in_out_iter_indexer, reduction_indexer, + remaining_reduction_nelems, batches, reductions_per_wi)); + }); + sycl::event cleanup_host_task_event = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(final_reduction_ev); + const sycl::context &ctx = exec_q.get_context(); + + cgh.host_task([ctx, partially_reduced_tmp] { + sycl::free(partially_reduced_tmp, ctx); + }); + }); + + return cleanup_host_task_event; + } +} + +template +sycl::event +dot_product_contig_tree_impl(sycl::queue &exec_q, + size_t batches, + size_t reduction_nelems, + const char *lhs_cp, + const char *rhs_cp, + char *res_cp, + py::ssize_t batch_lhs_offset, + py::ssize_t batch_rhs_offset, + py::ssize_t batch_res_offset, + py::ssize_t reduction_lhs_offset, + py::ssize_t reduction_rhs_offset, + const std::vector &depends = {}) +{ + const lhsTy *lhs_tp = reinterpret_cast(lhs_cp) + + batch_lhs_offset + reduction_lhs_offset; + const rhsTy *rhs_tp = reinterpret_cast(rhs_cp) + + batch_rhs_offset + reduction_rhs_offset; + resTy *res_tp = reinterpret_cast(res_cp) + batch_res_offset; + + const sycl::device &d = exec_q.get_device(); + const auto &sg_sizes = d.get_info(); + size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes); + + if (reduction_nelems < wg) { + sycl::event dot_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using InputBatchIndexerT = + dpctl::tensor::offset_utils::Strided1DIndexer; + using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputBatchIndexerT = + dpctl::tensor::offset_utils::ThreeOffsets_CombinedIndexer< + InputBatchIndexerT, InputBatchIndexerT, NoOpIndexerT>; + using ReductionIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + NoOpIndexerT, NoOpIndexerT>; + + InputBatchIndexerT inp_batch_indexer{ + 0, 
static_cast(reduction_nelems), + static_cast(batches)}; + InputOutputBatchIndexerT inp_out_batch_indexer{ + inp_batch_indexer, inp_batch_indexer, NoOpIndexerT{}}; + ReductionIndexerT reduction_indexer{NoOpIndexerT{}, NoOpIndexerT{}}; + + cgh.parallel_for>( + sycl::range<1>(batches), + SequentialDotProduct( + lhs_tp, rhs_tp, res_tp, inp_out_batch_indexer, + reduction_indexer, reduction_nelems)); + }); + + return dot_ev; + } + + constexpr size_t preferred_reductions_per_wi = 8; + // max_max_wg prevents running out of resources on CPU + constexpr size_t max_max_wg = 2048; + size_t max_wg = std::min( + max_max_wg, d.get_info() / 2); + + size_t reductions_per_wi(preferred_reductions_per_wi); + if (reduction_nelems <= preferred_reductions_per_wi * max_wg) { + sycl::event dot_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using InputBatchIndexerT = + dpctl::tensor::offset_utils::Strided1DIndexer; + using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputBatchIndexerT = + dpctl::tensor::offset_utils::ThreeOffsets_CombinedIndexer< + InputBatchIndexerT, InputBatchIndexerT, NoOpIndexerT>; + using ReductionIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + NoOpIndexerT, NoOpIndexerT>; + + InputBatchIndexerT inp_batch_indexer{ + 0, static_cast(reduction_nelems), + static_cast(batches)}; + InputOutputBatchIndexerT inp_out_batch_indexer{ + inp_batch_indexer, inp_batch_indexer, NoOpIndexerT{}}; + ReductionIndexerT reduction_indexer{NoOpIndexerT{}, NoOpIndexerT{}}; + + if (batches == 1) { + // increase GPU occupancy + wg = max_wg; + } + reductions_per_wi = + std::max(1, (reduction_nelems + wg - 1) / wg); + + size_t reduction_groups = + (reduction_nelems + reductions_per_wi * wg - 1) / + (reductions_per_wi * wg); + assert(reduction_groups == 1); + + auto globalRange = sycl::range<1>{batches * reduction_groups * wg}; + auto localRange = sycl::range<1>{wg}; + using KernelName = dot_product_tree_krn; + cgh.parallel_for( + sycl::nd_range<1>(globalRange, localRange), + DotProductNoAtomicFunctor( + lhs_tp, rhs_tp, res_tp, inp_out_batch_indexer, + reduction_indexer, reduction_nelems, batches, + reductions_per_wi)); + }); + + return dot_ev; + } + else { + using ReductionOpT = sycl::plus; + constexpr resTy identity_val = + sycl::known_identity::value; + + // more than one work-groups is needed, requires a temporary + size_t reduction_groups = + (reduction_nelems + preferred_reductions_per_wi * wg - 1) / + (preferred_reductions_per_wi * wg); + assert(reduction_groups > 1); + + size_t second_iter_reduction_groups_ = + (reduction_groups + preferred_reductions_per_wi * wg - 1) / + (preferred_reductions_per_wi * wg); + + resTy *partially_reduced_tmp = sycl::malloc_device( + batches * (reduction_groups + second_iter_reduction_groups_), + exec_q); + resTy *partially_reduced_tmp2 = nullptr; + + if (partially_reduced_tmp == nullptr) { + throw std::runtime_error("Unable to allocate device_memory"); + } + else { + partially_reduced_tmp2 = + partially_reduced_tmp + reduction_groups * batches; + } + + const sycl::event &first_reduction_ev = exec_q.submit([&](sycl::handler + &cgh) { + cgh.depends_on(depends); + + using InputBatchIndexerT = + dpctl::tensor::offset_utils::Strided1DIndexer; + using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputBatchIndexerT = + dpctl::tensor::offset_utils::ThreeOffsets_CombinedIndexer< + InputBatchIndexerT, InputBatchIndexerT, NoOpIndexerT>; + using ReductionIndexerT = + 
dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + NoOpIndexerT, NoOpIndexerT>; + + InputBatchIndexerT inp_batch_indexer{ + 0, static_cast(reduction_nelems), + static_cast(batches)}; + InputOutputBatchIndexerT inp_out_batch_indexer{ + inp_batch_indexer, inp_batch_indexer, NoOpIndexerT{}}; + ReductionIndexerT reduction_indexer{NoOpIndexerT{}, NoOpIndexerT{}}; + + auto globalRange = sycl::range<1>{batches * reduction_groups * wg}; + auto localRange = sycl::range<1>{wg}; + + using KernelName = + class dot_product_tree_krn; + cgh.parallel_for( + sycl::nd_range<1>(globalRange, localRange), + DotProductNoAtomicFunctor( + lhs_tp, rhs_tp, partially_reduced_tmp, + inp_out_batch_indexer, reduction_indexer, reduction_nelems, + batches, preferred_reductions_per_wi)); + }); + + size_t remaining_reduction_nelems = reduction_groups; + + resTy *temp_arg = partially_reduced_tmp; + resTy *temp2_arg = partially_reduced_tmp2; + sycl::event dependent_ev = first_reduction_ev; + + while (remaining_reduction_nelems > + preferred_reductions_per_wi * max_wg) { + size_t reduction_groups_ = (remaining_reduction_nelems + + preferred_reductions_per_wi * wg - 1) / + (preferred_reductions_per_wi * wg); + assert(reduction_groups_ > 1); + + sycl::event partial_reduction_ev = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(dependent_ev); + + using InputIndexerT = + dpctl::tensor::offset_utils::Strided1DIndexer; + using ResIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + InputIndexerT, ResIndexerT>; + using ReductionIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + + InputIndexerT inp_indexer{ + 0, static_cast(batches), + static_cast(reduction_groups_)}; + ResIndexerT res_iter_indexer{}; + + InputOutputIterIndexerT in_out_iter_indexer{ + inp_indexer, res_iter_indexer}; + ReductionIndexerT reduction_indexer{}; + + auto globalRange = + sycl::range<1>{batches * reduction_groups_ * wg}; + auto localRange = sycl::range<1>{wg}; + using KernelName = + class dot_product_reduction_over_group_temps_krn< + resTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT>; + cgh.parallel_for( + sycl::nd_range<1>(globalRange, localRange), + ReductionOverGroupNoAtomicFunctor< + resTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT>( + temp_arg, temp2_arg, ReductionOpT(), identity_val, + in_out_iter_indexer, reduction_indexer, + remaining_reduction_nelems, batches, + preferred_reductions_per_wi)); + }); + remaining_reduction_nelems = reduction_groups_; + std::swap(temp_arg, temp2_arg); + dependent_ev = std::move(partial_reduction_ev); + } + // final reduction to res + sycl::event final_reduction_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(dependent_ev); + + using InputIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer; + using ResIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + InputIndexerT, ResIndexerT>; + using ReductionIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + + InputIndexerT inp_indexer{ + 0, static_cast(batches), + static_cast(remaining_reduction_nelems)}; + ResIndexerT res_iter_indexer{}; + + InputOutputIterIndexerT in_out_iter_indexer{inp_indexer, + res_iter_indexer}; + ReductionIndexerT reduction_indexer{}; + + wg = max_wg; + reductions_per_wi = + std::max(1, (remaining_reduction_nelems + wg - 1) / wg); + + size_t reduction_groups = + 
(remaining_reduction_nelems + reductions_per_wi * wg - 1) / + (reductions_per_wi * wg); + assert(reduction_groups == 1); + + auto globalRange = sycl::range<1>{batches * reduction_groups * wg}; + auto localRange = sycl::range<1>{wg}; + + using KernelName = class dot_product_reduction_over_group_temps_krn< + resTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT>; + cgh.parallel_for( + sycl::nd_range<1>(globalRange, localRange), + ReductionOverGroupNoAtomicFunctor( + temp_arg, res_tp, ReductionOpT(), identity_val, + in_out_iter_indexer, reduction_indexer, + remaining_reduction_nelems, batches, reductions_per_wi)); + }); + sycl::event cleanup_host_task_event = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(final_reduction_ev); + const sycl::context &ctx = exec_q.get_context(); + + cgh.host_task([ctx, partially_reduced_tmp] { + sycl::free(partially_reduced_tmp, ctx); + }); + }); + + return cleanup_host_task_event; + } +} + +} // namespace kernels +} // namespace tensor +} // namespace dpctl diff --git a/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp b/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp new file mode 100644 index 0000000000..68bf9be860 --- /dev/null +++ b/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp @@ -0,0 +1,9840 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include "kernels/reductions.hpp" +#include "pybind11/pybind11.h" +#include "utils/offset_utils.hpp" +#include "utils/sycl_utils.hpp" +#include "utils/type_utils.hpp" + +namespace dpctl +{ +namespace tensor +{ +namespace kernels +{ + +namespace gemm_detail +{ + +template +void scale_gemm_k_parameters(const size_t &local_mem_size, + const size_t &reserved_slm_size, + const size_t delta_k, + size_t &n_wi, + size_t &delta_n) +{ + constexpr size_t slm_elem_size = sizeof(T) * m_groups; + + while (slm_elem_size * (n_wi + delta_n) * delta_k + reserved_slm_size >= + local_mem_size) + { + n_wi = n_wi / 2; + delta_n = delta_n / 2; + if (delta_n == 0) + throw std::runtime_error("Insufficient resources"); + } +} + +template +void scale_gemm_nm_parameters(const size_t &local_mem_size, + const size_t &reserved_slm_size, + const size_t &wi_delta_n, + size_t &wi_delta_k, + size_t &wg_delta_n, + size_t &wg_delta_m) +{ + constexpr size_t slm_A_elem_size = sizeof(T); + constexpr size_t slm_B_elem_size = sizeof(T) * wi_delta_m; + + while ((wi_delta_n * wg_delta_n * wi_delta_k * slm_A_elem_size) + + (wi_delta_k * wg_delta_m * slm_B_elem_size) + + reserved_slm_size >= + local_mem_size) + { + wg_delta_n /= 2; + wg_delta_m /= 2; + wi_delta_k /= 2; + if (wg_delta_n == 0) + throw std::runtime_error("Insufficient resources"); + } +} +} // namespace gemm_detail + +// template +// struct ThreeOffsets_CombinedIndexer +// { +// private: +// FirstIndexerT first_indexer_; +// SecondIndexerT second_indexer_; +// ThirdIndexerT third_indexer_; + +// public: +// ThreeOffsets_CombinedIndexer(const FirstIndexerT &first_indexer, +// const SecondIndexerT &second_indexer, +// const ThirdIndexerT &third_indexer) +// : first_indexer_(first_indexer), second_indexer_(second_indexer), +// third_indexer_(third_indexer) +// { +// } + +// ThreeOffsets operator()(py::ssize_t gid) const +// { +// return ThreeOffsets( +// first_indexer_(gid), second_indexer_(gid), third_indexer_(gid)); +// } +// }; + +using dpctl::tensor::sycl_utils::choose_workgroup_size; + +template +class gemm_reduction_over_group_temps_strided_krn; + +template +sycl::event 
tree_reduction_for_gemm(sycl::queue &exec_q, + T *partially_reduced_tmp, + T *partially_reduced_tmp2, + T *res_tp, + T identity_val, + size_t iter_nelems, + size_t reduction_nelems, + size_t reduction_groups, + size_t wg, + size_t max_wg, + size_t preferred_reductions_per_wi, + size_t reductions_per_wi, + int res_nd, + py::ssize_t res_offset, + const py::ssize_t *res_shape_strides, + std::vector depends = {}) +{ + + const sycl::event &first_reduction_ev = exec_q.submit([&](sycl::handler + &cgh) { + cgh.depends_on(depends); + + using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + NoOpIndexerT, NoOpIndexerT>; + using ReductionIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer; + + // Only 2*iter_nd entries describing shape and strides of + // iterated dimensions of input array from + // iter_shape_and_strides are going to be accessed by + // inp_indexer + + InputOutputIterIndexerT in_out_iter_indexer{NoOpIndexerT{}, + NoOpIndexerT{}}; + ReductionIndexerT reduction_indexer{ + 0, /* size */ static_cast(reduction_nelems), + /* step */ static_cast(iter_nelems)}; + + auto globalRange = sycl::range<1>{iter_nelems * reduction_groups * wg}; + auto localRange = sycl::range<1>{wg}; + + using KernelName = class gemm_reduction_over_group_temps_strided_krn< + T, T, ReductionOpT, InputOutputIterIndexerT, ReductionIndexerT>; + cgh.parallel_for( + sycl::nd_range<1>(globalRange, localRange), + ReductionOverGroupNoAtomicFunctor< + T, T, ReductionOpT, InputOutputIterIndexerT, ReductionIndexerT>( + partially_reduced_tmp, partially_reduced_tmp2, ReductionOpT(), + identity_val, in_out_iter_indexer, reduction_indexer, + reduction_nelems, iter_nelems, reductions_per_wi)); + }); + + size_t remaining_reduction_nelems = reduction_groups; + + T *temp_arg = partially_reduced_tmp2; + T *temp2_arg = partially_reduced_tmp; + sycl::event dependent_ev = first_reduction_ev; + + while (remaining_reduction_nelems > preferred_reductions_per_wi * max_wg) { + size_t reduction_groups_ = (remaining_reduction_nelems + + preferred_reductions_per_wi * wg - 1) / + (preferred_reductions_per_wi * wg); + assert(reduction_groups_ > 1); + + // keep reducing + sycl::event partial_reduction_ev = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(dependent_ev); + + using InputIndexerT = + dpctl::tensor::offset_utils::Strided1DIndexer; + using ResIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + InputIndexerT, ResIndexerT>; + using ReductionIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + + InputIndexerT inp_indexer{ + 0, static_cast(iter_nelems), + static_cast(reduction_groups_)}; + ResIndexerT res_iter_indexer{}; + + InputOutputIterIndexerT in_out_iter_indexer{inp_indexer, + res_iter_indexer}; + + ReductionIndexerT reduction_indexer{}; + + auto globalRange = + sycl::range<1>{iter_nelems * reduction_groups_ * wg}; + auto localRange = sycl::range<1>{wg}; + + using KernelName = + class gemm_reduction_over_group_temps_strided_krn< + T, T, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT>; + cgh.parallel_for( + sycl::nd_range<1>(globalRange, localRange), + ReductionOverGroupNoAtomicFunctor( + temp_arg, temp2_arg, ReductionOpT(), identity_val, + in_out_iter_indexer, reduction_indexer, + remaining_reduction_nelems, iter_nelems, + reductions_per_wi)); + }); + + remaining_reduction_nelems = reduction_groups_; + 
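Each pass of this loop folds preferred_reductions_per_wi * wg partial values per work-group and then swaps the roles of the two scratch buffers, so the output of one pass becomes the input of the next. The standalone sketch below reproduces only that host-side control flow; the buffer stand-ins, sample sizes, and the main driver are hypothetical, and the kernel submissions are elided.

#include <cstddef>
#include <utility>

// Sketch of the buffer ping-pong used by the multi-pass tree reduction.
// wg, preferred_reductions_per_wi, max_wg and the starting element count
// are hypothetical sample values; each iteration shrinks the reduction
// axis by roughly a factor of preferred_reductions_per_wi * wg.
int main()
{
    const std::size_t wg = 256;
    const std::size_t preferred_reductions_per_wi = 8;
    const std::size_t max_wg = 1024;

    std::size_t remaining = 1000000;
    int src_buf = 0, dst_buf = 1; // stand-ins for temp_arg / temp2_arg

    while (remaining > preferred_reductions_per_wi * max_wg) {
        // ceil-divide: each work-group folds preferred_reductions_per_wi * wg inputs
        std::size_t groups =
            (remaining + preferred_reductions_per_wi * wg - 1) /
            (preferred_reductions_per_wi * wg);
        // one kernel pass would reduce src_buf into dst_buf here
        remaining = groups;
        std::swap(src_buf, dst_buf); // this pass's output feeds the next pass
    }
    // remaining now fits in a single work-group pass that writes to res
    return 0;
}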
std::swap(temp_arg, temp2_arg); + dependent_ev = std::move(partial_reduction_ev); + } + + // final reduction to res + sycl::event final_reduction_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(dependent_ev); + + using InputIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer; + using ResIndexerT = dpctl::tensor::offset_utils::StridedIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + InputIndexerT, ResIndexerT>; + using ReductionIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + + InputIndexerT inp_indexer{ + 0, static_cast(iter_nelems), + static_cast(remaining_reduction_nelems)}; + ResIndexerT res_iter_indexer{ + res_nd, static_cast(res_offset), res_shape_strides}; + + InputOutputIterIndexerT in_out_iter_indexer{inp_indexer, + res_iter_indexer}; + ReductionIndexerT reduction_indexer{}; + + wg = max_wg; + reductions_per_wi = + std::max(1, (remaining_reduction_nelems + wg - 1) / wg); + + size_t reduction_groups = + (remaining_reduction_nelems + reductions_per_wi * wg - 1) / + (reductions_per_wi * wg); + assert(reduction_groups == 1); + + auto globalRange = sycl::range<1>{iter_nelems * reduction_groups * wg}; + auto localRange = sycl::range<1>{wg}; + + using KernelName = class gemm_reduction_over_group_temps_strided_krn< + T, T, ReductionOpT, InputOutputIterIndexerT, ReductionIndexerT>; + cgh.parallel_for( + sycl::nd_range<1>(globalRange, localRange), + ReductionOverGroupNoAtomicFunctor< + T, T, ReductionOpT, InputOutputIterIndexerT, ReductionIndexerT>( + temp_arg, res_tp, ReductionOpT(), identity_val, + in_out_iter_indexer, reduction_indexer, + remaining_reduction_nelems, iter_nelems, reductions_per_wi)); + }); + + return final_reduction_ev; +} + +template +class gemm_reduction_over_group_temps_contig_krn; + +template +sycl::event +tree_reduction_for_gemm_contig(sycl::queue &exec_q, + T *partially_reduced_tmp, + T *partially_reduced_tmp2, + T *res_tp, + T identity_val, + size_t iter_nelems, + size_t reduction_nelems, + size_t reduction_groups, + size_t wg, + size_t max_wg, + size_t preferred_reductions_per_wi, + size_t reductions_per_wi, + std::vector depends = {}) +{ + + const sycl::event &first_reduction_ev = exec_q.submit([&](sycl::handler + &cgh) { + cgh.depends_on(depends); + + using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + NoOpIndexerT, NoOpIndexerT>; + using ReductionIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer; + + // Only 2*iter_nd entries describing shape and strides of + // iterated dimensions of input array from + // iter_shape_and_strides are going to be accessed by + // inp_indexer + + InputOutputIterIndexerT in_out_iter_indexer{NoOpIndexerT{}, + NoOpIndexerT{}}; + ReductionIndexerT reduction_indexer{ + 0, /* size */ static_cast(reduction_nelems), + /* step */ static_cast(iter_nelems)}; + + auto globalRange = sycl::range<1>{iter_nelems * reduction_groups * wg}; + auto localRange = sycl::range<1>{wg}; + + using KernelName = class gemm_reduction_over_group_temps_contig_krn< + T, T, ReductionOpT, InputOutputIterIndexerT, ReductionIndexerT>; + cgh.parallel_for( + sycl::nd_range<1>(globalRange, localRange), + ReductionOverGroupNoAtomicFunctor< + T, T, ReductionOpT, InputOutputIterIndexerT, ReductionIndexerT>( + partially_reduced_tmp, partially_reduced_tmp2, ReductionOpT(), + identity_val, in_out_iter_indexer, reduction_indexer, + reduction_nelems, iter_nelems, reductions_per_wi)); + }); 
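The first pass above reads the partial results through a Strided1DIndexer with step iter_nelems, so the values belonging to one output element form a column of a reduction_nelems x iter_nelems row-major layout. A minimal host-side illustration of that offset pattern follows; the sizes and the printf driver are hypothetical.

#include <cstddef>
#include <cstdio>

// The r-th partial result for output element i sits at r * iter_nelems + i:
// the temporary is a reduction_nelems x iter_nelems matrix and each output
// element reduces one column of it.
int main()
{
    const std::size_t iter_nelems = 6;      // number of outputs (hypothetical)
    const std::size_t reduction_nelems = 4; // partial results per output

    const std::size_t i = 2; // pick one output element
    for (std::size_t r = 0; r < reduction_nelems; ++r) {
        std::size_t offset = r * iter_nelems + i; // what the indexer yields
        std::printf("partial %zu of output %zu -> offset %zu\n", r, i, offset);
    }
    return 0;
}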
+ + size_t remaining_reduction_nelems = reduction_groups; + + T *temp_arg = partially_reduced_tmp2; + T *temp2_arg = partially_reduced_tmp; + sycl::event dependent_ev = first_reduction_ev; + + while (remaining_reduction_nelems > preferred_reductions_per_wi * max_wg) { + size_t reduction_groups_ = (remaining_reduction_nelems + + preferred_reductions_per_wi * wg - 1) / + (preferred_reductions_per_wi * wg); + assert(reduction_groups_ > 1); + + // keep reducing + sycl::event partial_reduction_ev = exec_q.submit([&](sycl::handler + &cgh) { + cgh.depends_on(dependent_ev); + + using InputIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer; + using ResIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + InputIndexerT, ResIndexerT>; + using ReductionIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + + // n * m = iter_nelems because essentially, this process + // creates a stack of reduction_nelems 2D matrices and we reduce + // along the stack axis + InputIndexerT inp_indexer{ + 0, static_cast(iter_nelems), + static_cast(reduction_groups_)}; + ResIndexerT res_iter_indexer{}; + + InputOutputIterIndexerT in_out_iter_indexer{inp_indexer, + res_iter_indexer}; + + ReductionIndexerT reduction_indexer{}; + + auto globalRange = + sycl::range<1>{iter_nelems * reduction_groups_ * wg}; + auto localRange = sycl::range<1>{wg}; + + using KernelName = class gemm_reduction_over_group_temps_contig_krn< + T, T, ReductionOpT, InputOutputIterIndexerT, ReductionIndexerT>; + cgh.parallel_for( + sycl::nd_range<1>(globalRange, localRange), + ReductionOverGroupNoAtomicFunctor( + temp_arg, temp2_arg, ReductionOpT(), identity_val, + in_out_iter_indexer, reduction_indexer, + remaining_reduction_nelems, iter_nelems, + reductions_per_wi)); + }); + + remaining_reduction_nelems = reduction_groups_; + std::swap(temp_arg, temp2_arg); + dependent_ev = std::move(partial_reduction_ev); + } + + // final reduction to res + sycl::event final_reduction_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(dependent_ev); + + using InputIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer; + using ResIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + InputIndexerT, ResIndexerT>; + using ReductionIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + + InputIndexerT inp_indexer{ + 0, static_cast(iter_nelems), + static_cast(remaining_reduction_nelems)}; + ResIndexerT res_iter_indexer{}; + + InputOutputIterIndexerT in_out_iter_indexer{inp_indexer, + res_iter_indexer}; + ReductionIndexerT reduction_indexer{}; + + wg = max_wg; + reductions_per_wi = + std::max(1, (remaining_reduction_nelems + wg - 1) / wg); + + size_t reduction_groups = + (remaining_reduction_nelems + reductions_per_wi * wg - 1) / + (reductions_per_wi * wg); + assert(reduction_groups == 1); + + auto globalRange = sycl::range<1>{iter_nelems * reduction_groups * wg}; + auto localRange = sycl::range<1>{wg}; + + using KernelName = class gemm_reduction_over_group_temps_contig_krn< + T, T, ReductionOpT, InputOutputIterIndexerT, ReductionIndexerT>; + cgh.parallel_for( + sycl::nd_range<1>(globalRange, localRange), + ReductionOverGroupNoAtomicFunctor< + T, T, ReductionOpT, InputOutputIterIndexerT, ReductionIndexerT>( + temp_arg, res_tp, ReductionOpT(), identity_val, + in_out_iter_indexer, reduction_indexer, + remaining_reduction_nelems, iter_nelems, reductions_per_wi)); + }); + + 
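The assert in the final pass above rests on a small arithmetic fact: once remaining_reduction_nelems is at most preferred_reductions_per_wi * max_wg, choosing wg = max_wg and reductions_per_wi = ceil(remaining / wg) makes reductions_per_wi * wg >= remaining, so the ceil-division yields exactly one group. A standalone numerical check of that invariant, with hypothetical sample bounds:

#include <cassert>
#include <cstddef>

// Exhaustively verify that the final-pass sizing always produces a single
// reduction group for every admissible remaining count.
int main()
{
    const std::size_t max_wg = 1024;
    const std::size_t preferred_reductions_per_wi = 8;

    for (std::size_t remaining = 1;
         remaining <= preferred_reductions_per_wi * max_wg; ++remaining)
    {
        std::size_t wg = max_wg;
        std::size_t reductions_per_wi = (remaining + wg - 1) / wg;
        std::size_t groups =
            (remaining + reductions_per_wi * wg - 1) /
            (reductions_per_wi * wg);
        assert(groups == 1);
    }
    return 0;
}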
return final_reduction_ev; +} + +template +class GemmFunctorThreadNM +{ +private: + const lhsT *lhs = nullptr; + const rhsT *rhs = nullptr; + resT *res = nullptr; + LocAccT1 local_A_block; + LocAccT2 local_B_block; + size_t n = 0; + size_t wg_delta_n = 0; + size_t k = 0; + size_t k_blocks = 0; + size_t wi_delta_k = 0; + size_t m = 0; + size_t m_blocks = 0; + size_t wg_delta_m = 0; + OuterInnerDimsIndexerT lhs_indexer; + OuterInnerDimsIndexerT rhs_indexer; + OuterInnerDimsIndexerT res_indexer; + +public: + GemmFunctorThreadNM(const lhsT *lhs_, + const rhsT *rhs_, + resT *res_, + LocAccT1 local_A_block_, + LocAccT2 local_B_block_, + size_t n_, + size_t wg_delta_n_, + size_t k_, + size_t k_blocks_, + size_t wi_delta_k_, + size_t m_, + size_t m_blocks_, + size_t wg_delta_m_, + OuterInnerDimsIndexerT lhs_indexer_, + OuterInnerDimsIndexerT rhs_indexer_, + OuterInnerDimsIndexerT res_indexer_) + : lhs(lhs_), rhs(rhs_), res(res_), local_A_block(local_A_block_), + local_B_block(local_B_block_), n(n_), wg_delta_n(wg_delta_n_), k(k_), + k_blocks(k_blocks_), wi_delta_k(wi_delta_k_), m(m_), + m_blocks(m_blocks_), wg_delta_m(wg_delta_m_), + lhs_indexer(lhs_indexer_), rhs_indexer(rhs_indexer_), + res_indexer(res_indexer_) + { + } + + void operator()(sycl::nd_item<1> it) const + { + size_t gr_id = it.get_group_linear_id(); + // lift group_id to (block_i, block_j, block_s), + // 0 <= block_i < n_blocks, 0 <= block_j < m_blocks, 0 <= block_s < + // k_blocks + size_t block_i = gr_id / (m_blocks * k_blocks); + size_t block_r = gr_id - block_i * (m_blocks * k_blocks); + size_t block_j = block_r / k_blocks; + size_t block_s = block_r - block_j * k_blocks; + + size_t lid = it.get_local_linear_id(); + size_t local_i = lid / wg_delta_m; // 0<= local_i < wg_delta_n + size_t local_j = lid - local_i * wg_delta_m; // 0<= local_j < wg_delta_m + + // load A block and B blocks into SLM + + size_t i = block_i * wi_delta_n * wg_delta_n; + size_t j = block_j * wi_delta_m * wg_delta_m; + size_t s = block_s * wi_delta_k; + + const std::int64_t a_st0 = k; + const std::int64_t a_st1 = 1; + + const std::int64_t b_st0 = m; + const std::int64_t b_st1 = 1; + + const std::int64_t c_st0 = m; + const std::int64_t c_st1 = 1; + + size_t lws = it.get_local_range(0); + + for (size_t vid = lid; vid < local_A_block.size(); vid += lws) { + size_t v_i = vid / wi_delta_k; // 0<= v_i < wg_delta_n * wi_delta_n + size_t v_s = vid - v_i * wi_delta_k; // 0<= v_s < wi_delta_k + + size_t g_i = i + v_i; + size_t g_s = s + v_s; + + local_A_block[vid] = + (g_i < n && g_s < k) + ? static_cast( + lhs[lhs_indexer(g_i * a_st0 + g_s * a_st1)]) + : resT(0); + } + + for (size_t vid = lid; vid < local_B_block.size(); vid += lws) { + size_t v_j = vid / wi_delta_k; // 0 <= v_i < wg_delta_m + size_t v_s = vid - v_j * wi_delta_k; // 0 <= v_s < wi_delta_k + + size_t g_j0 = j + v_j * wi_delta_m; + size_t g_s = s + v_s; + + sycl::vec vec{}; +#pragma unroll + for (std::uint8_t lane_id = 0; lane_id < wi_delta_m; ++lane_id) { + size_t g_j = g_j0 + lane_id; + vec[lane_id] = + (g_j < m && g_s < k) + ? 
static_cast(rhs[g_s * b_st0 + g_j * b_st1]) + : resT(0); + } + + local_B_block[vid] = vec; + } + + it.barrier(sycl::access::fence_space::local_space); + + i += local_i * wi_delta_n; + j += local_j * wi_delta_m; + + size_t a_offset = local_i * wi_delta_k * wi_delta_n; + size_t b_offset = local_j * wi_delta_k; + + constexpr resT identity_(0); + for (std::uint8_t private_i = 0; private_i < wi_delta_n; ++private_i) { + size_t a_pr_offset = private_i * wi_delta_k; + + sycl::vec local_sum(identity_); + for (size_t private_s = 0; private_s < wi_delta_k; ++private_s) { + local_sum = local_sum + + (local_A_block[a_offset + a_pr_offset + private_s] * + local_B_block[b_offset + private_s]); + } + + size_t gl_i = i + private_i; + +#pragma unroll + for (std::uint8_t lane_id = 0; lane_id < wi_delta_m; ++lane_id) { + size_t gl_j = j + lane_id; + + if (gl_i < n && gl_j < m) { + sycl::atomic_ref + aout(res[res_indexer(gl_i * c_st0 + gl_j * c_st1)]); + + aout += local_sum[lane_id]; + } + } + } + } +}; + +// specialization for wi_delta_m == 1 +template +class GemmFunctorThreadNM +{ +private: + const lhsT *lhs = nullptr; + const rhsT *rhs = nullptr; + resT *res = nullptr; + LocAccT1 local_A_block; + LocAccT2 local_B_block; + size_t n = 0; + size_t wg_delta_n = 0; + size_t k = 0; + size_t k_blocks = 0; + size_t wi_delta_k = 0; + size_t m = 0; + size_t m_blocks = 0; + size_t wg_delta_m = 0; + OuterInnerDimsIndexerT lhs_indexer; + OuterInnerDimsIndexerT rhs_indexer; + OuterInnerDimsIndexerT res_indexer; + +public: + GemmFunctorThreadNM(const lhsT *lhs_, + const rhsT *rhs_, + resT *res_, + LocAccT1 local_A_block_, + LocAccT2 local_B_block_, + size_t n_, + size_t wg_delta_n_, + size_t k_, + size_t k_blocks_, + size_t wi_delta_k_, + size_t m_, + size_t m_blocks_, + size_t wg_delta_m_, + OuterInnerDimsIndexerT lhs_indexer_, + OuterInnerDimsIndexerT rhs_indexer_, + OuterInnerDimsIndexerT res_indexer_) + : lhs(lhs_), rhs(rhs_), res(res_), local_A_block(local_A_block_), + local_B_block(local_B_block_), n(n_), wg_delta_n(wg_delta_n_), k(k_), + k_blocks(k_blocks_), wi_delta_k(wi_delta_k_), m(m_), + m_blocks(m_blocks_), wg_delta_m(wg_delta_m_), + lhs_indexer(lhs_indexer_), rhs_indexer(rhs_indexer_), + res_indexer(res_indexer_) + { + } + + void operator()(sycl::nd_item<1> it) const + { + size_t gr_id = it.get_group_linear_id(); + // lift group_id to (block_i, block_j, block_s), + // 0 <= block_i < n_blocks, 0 <= block_j < m_blocks, 0 <= block_s < + // k_blocks + size_t block_i = gr_id / (m_blocks * k_blocks); + size_t block_r = gr_id - block_i * (m_blocks * k_blocks); + size_t block_j = block_r / k_blocks; + size_t block_s = block_r - block_j * k_blocks; + + size_t lid = it.get_local_linear_id(); + size_t local_i = lid / wg_delta_m; // 0<= local_i < wg_delta_n + size_t local_j = lid - local_i * wg_delta_m; // 0<= local_j < wg_delta_m + + // load A block and B blocks into SLM + + size_t i = block_i * wi_delta_n * wg_delta_n; + size_t j = block_j * wg_delta_m; + size_t s = block_s * wi_delta_k; + + const std::int64_t a_st0 = k; + const std::int64_t a_st1 = 1; + + const std::int64_t b_st0 = m; + const std::int64_t b_st1 = 1; + + const std::int64_t c_st0 = m; + const std::int64_t c_st1 = 1; + + size_t lws = it.get_local_range(0); + + for (size_t vid = lid; vid < local_A_block.size(); vid += lws) { + size_t v_i = vid / wi_delta_k; // 0<= v_i < wg_delta_n * wi_delta_n + size_t v_s = vid - v_i * wi_delta_k; // 0<= v_s < wi_delta_k + + size_t g_i = i + v_i; + size_t g_s = s + v_s; + + local_A_block[vid] = + (g_i < n && g_s < k) + 
? static_cast( + lhs[lhs_indexer(g_i * a_st0 + g_s * a_st1)]) + : resT(0); + } + + for (size_t vid = lid; vid < local_B_block.size(); vid += lws) { + size_t v_j = vid / wi_delta_k; // 0<= v_i < wg_delta_m + size_t v_s = vid - v_j * wi_delta_k; // 0<= v_s < wi_delta_k + + size_t g_j0 = j + v_j; + size_t g_s = s + v_s; + + resT val = (g_j0 < m && g_s < k) + ? static_cast( + rhs[rhs_indexer(g_s * b_st0 + g_j0 * b_st1)]) + : resT(0); + + local_B_block[vid] = val; + } + + it.barrier(sycl::access::fence_space::local_space); + + i += local_i * wi_delta_n; + j += local_j; + + size_t a_offset = local_i * wi_delta_k * wi_delta_n; + size_t b_offset = local_j * wi_delta_k; + + constexpr resT identity_(0); + for (std::uint8_t private_i = 0; private_i < wi_delta_n; ++private_i) { + size_t a_pr_offset = private_i * wi_delta_k; + + resT local_sum(identity_); + for (size_t private_s = 0; private_s < wi_delta_k; ++private_s) { + local_sum = local_sum + + (local_A_block[a_offset + a_pr_offset + private_s] * + local_B_block[b_offset + private_s]); + } + + size_t gl_i = i + private_i; + + if (gl_i < n && j < m) { + sycl::atomic_ref + aout(res[res_indexer(gl_i * c_st0 + j * c_st1)]); + + aout += local_sum; + } + } + } +}; + +template +class GemmFunctorThreadK +{ +private: + const lhsT *lhs = nullptr; + const rhsT *rhs = nullptr; + resT *res = nullptr; + LocAccT workspace; + LocAccT local_B_block; + size_t n = 0; + size_t n_blocks = 0; + size_t delta_n = 0; + size_t k = 0; + size_t k_blocks = 0; + size_t delta_k = 0; + size_t n_wi = 0; + size_t m = 0; + OuterInnerDimsIndexerT lhs_indexer; + OuterInnerDimsIndexerT rhs_indexer; + OuterInnerDimsIndexerT res_indexer; + +public: + GemmFunctorThreadK(const lhsT *lhs_, + const rhsT *rhs_, + resT *res_, + LocAccT workspace_, + LocAccT local_B_block_, + size_t n_, + size_t n_blocks_, + size_t delta_n_, + size_t k_, + size_t k_blocks_, + size_t delta_k_, + size_t n_wi_, + size_t m_, + OuterInnerDimsIndexerT lhs_indexer_, + OuterInnerDimsIndexerT rhs_indexer_, + OuterInnerDimsIndexerT res_indexer_) + : lhs(lhs_), rhs(rhs_), res(res_), workspace(workspace_), + local_B_block(local_B_block_), n(n_), n_blocks(n_blocks_), + delta_n(delta_n_), k(k_), k_blocks(k_blocks_), delta_k(delta_k_), + n_wi(n_wi_), m(m_), lhs_indexer(lhs_indexer_), + rhs_indexer(rhs_indexer_), res_indexer(res_indexer_) + { + } + + void operator()(sycl::nd_item<1> it) const + { + size_t gr_id = it.get_group_linear_id(); + size_t lid = it.get_local_linear_id(); + + // lift gr_id -> (block_i, block_j, block_s) + // block_i moves fastest, then block_s, then block_j + + size_t block_j = + gr_id / (n_blocks * k_blocks); // 0 <= block_j < m_blocks + size_t block_r = + gr_id - block_j * (n_blocks * + k_blocks); // 0 <= block_r < n_blocks * k_blocks + size_t block_s = block_r / n_blocks; // 0 <= block_s < k_blocks + size_t block_i = + block_r - block_s * n_blocks; // 0 <= block_i < n_blocks + + size_t local_i = lid / (delta_k); // 0 <= local_i < delta_n + size_t local_s = lid - local_i * (delta_k); // 0 <= local_s < delta_k + + size_t i = block_i * delta_n + local_i; + size_t j = m_groups * block_j; + size_t s = block_s * delta_k * n_wi + local_s; + + constexpr resT identity_ = resT(0); + if (local_i == 0) { + for (size_t q = 0; q < n_wi * delta_k; q += delta_k) { + size_t sq = s + q; + size_t sqmj = sq * m + j; + local_B_block[local_s + q] = sycl::vec( + (sq < k && j < m) + ? static_cast(rhs[rhs_indexer(sqmj)]) + : identity_, + (sq < k && j + 1 < m) + ? 
static_cast(rhs[rhs_indexer(sqmj + 1)]) + : identity_); + } + } + + it.barrier(sycl::access::fence_space::local_space); + + size_t t_shift = block_s * delta_k * n_wi; + size_t global_s_offset = i * k + t_shift; + + sycl::vec private_sum(identity_); + for (size_t t = local_s; t < local_B_block.size(); t += delta_k) { + if (t + t_shift < k) { + private_sum += + (static_cast(lhs[lhs_indexer(global_s_offset + t)]) * + local_B_block[t]); + } + } + + size_t workspace_i_shift = local_i * delta_k; + workspace[workspace_i_shift + local_s] = private_sum; + + it.barrier(sycl::access::fence_space::local_space); + + if (local_s == 0 && i < n) { + sycl::vec local_sum(workspace[workspace_i_shift]); + for (size_t t = 1; t < delta_k; ++t) { + local_sum += workspace[workspace_i_shift + t]; + } + + sycl::atomic_ref + aout0(res[res_indexer(i * m + j)]); + + aout0 += local_sum[0]; + + if (j + 1 < m) { + sycl::atomic_ref + aout1(res[res_indexer(i * m + j + 1)]); + + aout1 += local_sum[1]; + } + } + } +}; + +// specialization for m_groups == 1 +template +class GemmFunctorThreadK +{ +private: + const lhsT *lhs = nullptr; + const rhsT *rhs = nullptr; + resT *res = nullptr; + LocAccT workspace; + LocAccT local_B_block; + size_t n = 0; + size_t n_blocks = 0; + size_t delta_n = 0; + size_t k = 0; + size_t k_blocks = 0; + size_t delta_k = 0; + size_t n_wi = 0; + size_t m = 0; + OuterInnerDimsIndexerT lhs_indexer; + OuterInnerDimsIndexerT rhs_indexer; + OuterInnerDimsIndexerT res_indexer; + +public: + GemmFunctorThreadK(const lhsT *lhs_, + const rhsT *rhs_, + resT *res_, + LocAccT workspace_, + LocAccT local_B_block_, + size_t n_, + size_t n_blocks_, + size_t delta_n_, + size_t k_, + size_t k_blocks_, + size_t delta_k_, + size_t n_wi_, + size_t m_, + OuterInnerDimsIndexerT lhs_indexer_, + OuterInnerDimsIndexerT rhs_indexer_, + OuterInnerDimsIndexerT res_indexer_) + : lhs(lhs_), rhs(rhs_), res(res_), workspace(workspace_), + local_B_block(local_B_block_), n(n_), n_blocks(n_blocks_), + delta_n(delta_n_), k(k_), k_blocks(k_blocks_), delta_k(delta_k_), + n_wi(n_wi_), m(m_), lhs_indexer(lhs_indexer_), + rhs_indexer(rhs_indexer_), res_indexer(res_indexer_) + { + } + + void operator()(sycl::nd_item<1> it) const + { + size_t gr_id = it.get_group_linear_id(); + size_t lid = it.get_local_linear_id(); + + // lift gr_id -> (block_i, block_j, block_s) + // block_i moves fastest, then block_s, then block_j + + size_t block_j = + gr_id / (n_blocks * k_blocks); // 0 <= block_j < m_blocks + size_t block_r = + gr_id - block_j * (n_blocks * + k_blocks); // 0 <= block_r < n_blocks * k_blocks + size_t block_s = block_r / n_blocks; // 0 <= block_s < k_blocks + size_t block_i = + block_r - block_s * n_blocks; // 0 <= block_i < n_blocks + + size_t local_i = lid / (delta_k); // 0 <= local_i < delta_n + size_t local_s = lid - local_i * (delta_k); // 0 <= local_s < delta_k + + size_t i = block_i * delta_n + local_i; + size_t j = block_j; + size_t s = block_s * delta_k * n_wi + local_s; + + constexpr resT identity_ = resT(0); + if (local_i == 0) { + for (size_t q = 0; q < n_wi * delta_k; q += delta_k) { + size_t sq = s + q; + size_t sqmj = sq * m + j; + local_B_block[local_s + q] = + (sq < k && j < m) + ? 
static_cast(rhs[rhs_indexer(sqmj)]) + : identity_; + } + } + + it.barrier(sycl::access::fence_space::local_space); + + size_t t_shift = block_s * delta_k * n_wi; + size_t global_s_offset = i * k + t_shift; + + resT private_sum(identity_); + for (size_t t = local_s; t < local_B_block.size(); t += delta_k) { + if (t + t_shift < k) { + private_sum += + (static_cast(lhs[lhs_indexer(global_s_offset + t)]) * + local_B_block[t]); + } + } + + size_t workspace_i_shift = local_i * delta_k; + workspace[workspace_i_shift + local_s] = private_sum; + + it.barrier(sycl::access::fence_space::local_space); + + if (local_s == 0 && i < n) { + resT local_sum(workspace[workspace_i_shift]); + for (size_t t = 1; t < delta_k; ++t) { + local_sum += workspace[workspace_i_shift + t]; + } + + sycl::atomic_ref + aout(res[res_indexer(i * m + j)]); + + aout += local_sum; + } + } +}; + +template class gemm_init_krn; + +template +class gemm_k_krn; + +template +class gemm_nm_krn; + +typedef sycl::event (*gemm_impl_fn_ptr_t)( + sycl::queue &, + const char *, // lhs + const char *, // rhs + char *, // res + size_t, // lhs_outer_nelems (n) + size_t, // inner_nelems (k) + size_t, // rhs_outer_nelems (m) + int, // inner nd + int, // lhs outer nd + const py::ssize_t *, // lhs shape and strides + int, // rhs outer nd + const py::ssize_t *, // rhs shape and strides + int, // res outer nd + const py::ssize_t *, // res shape and strides + std::vector const &); + +template +sycl::event gemm_impl(sycl::queue &exec_q, + const char *lhs_cp, + const char *rhs_cp, + char *res_cp, + size_t n, + size_t k, + size_t m, + int inner_nd, + int lhs_outer_nd, + const py::ssize_t *lhs_shape_strides, + int rhs_outer_nd, + const py::ssize_t *rhs_shape_strides, + int res_outer_nd, + const py::ssize_t *res_shape_strides, + std::vector const &depends = {}) +{ + const lhsTy *lhs_tp = reinterpret_cast(lhs_cp); + const rhsTy *rhs_tp = reinterpret_cast(rhs_cp); + resTy *res_tp = reinterpret_cast(res_cp); + + sycl::event res_init_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using IndexerT = dpctl::tensor::offset_utils::StridedIndexer; + IndexerT res_indexer(res_outer_nd, 0, res_shape_strides); + using InitKernelName = class gemm_init_krn; + cgh.parallel_for( + sycl::range<1>(n * m), [=](sycl::id<1> id) { + auto res_offset = res_indexer(id[0]); + res_tp[res_offset] = resTy(0); + }); + }); + + const sycl::device &dev = exec_q.get_device(); + const size_t local_mem_size = + dev.get_info(); + const size_t reserved_slm_size = 512; + + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(res_init_ev); + + using OuterInnerIndexerT = dpctl::tensor::offset_utils::StridedIndexer; + OuterInnerIndexerT lhs_indexer(inner_nd + lhs_outer_nd, 0, + lhs_shape_strides); + OuterInnerIndexerT rhs_indexer(inner_nd + rhs_outer_nd, 0, + rhs_shape_strides); + OuterInnerIndexerT res_indexer(res_outer_nd, 0, res_shape_strides); + + if (m == 1) { + constexpr size_t m_groups = 1; + size_t delta_k(4); + size_t n_wi(64); + size_t delta_n(32); + + gemm_detail::scale_gemm_k_parameters( + local_mem_size, reserved_slm_size, + delta_k, // modified by reference + n_wi, // modified by reference + delta_n // modified by reference + ); + + size_t n_blocks = (n + delta_n - 1) / delta_n; + size_t m_blocks = (m + m_groups - 1) / m_groups; + size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); + + size_t lws = delta_n * delta_k; + + auto gRange = sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); + auto lRange = sycl::range<1>(lws); 
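The delta_k / n_wi / delta_n values used above were already clipped by gemm_detail::scale_gemm_k_parameters so that the two local accessors allocated next (n_wi * delta_k B-block elements plus delta_n * delta_k workspace elements) fit the device's local memory. The host-only sketch below mirrors that scaling loop with a deliberately small, hypothetical 2 KiB budget to show the halving in action.

#include <cstddef>
#include <cstdio>
#include <stdexcept>

// Per-work-group SLM footprint is sizeof(T) * m_groups * (n_wi + delta_n) *
// delta_k bytes; n_wi and delta_n are halved until it, plus the reserved
// slack, fits the budget. The 2 KiB local_mem_size is hypothetical.
int main()
{
    using T = float;
    constexpr std::size_t m_groups = 1;
    constexpr std::size_t slm_elem_size = sizeof(T) * m_groups;

    const std::size_t local_mem_size = 2048; // hypothetical tiny device
    const std::size_t reserved_slm_size = 512;
    std::size_t delta_k = 4, n_wi = 64, delta_n = 32;

    while (slm_elem_size * (n_wi + delta_n) * delta_k + reserved_slm_size >=
           local_mem_size)
    {
        n_wi /= 2;
        delta_n /= 2;
        if (delta_n == 0)
            throw std::runtime_error("Insufficient resources");
    }
    std::printf("n_wi=%zu delta_n=%zu -> %zu bytes of SLM\n", n_wi, delta_n,
                slm_elem_size * (n_wi + delta_n) * delta_k);
    return 0;
}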
+ + auto ndRange = sycl::nd_range<1>(gRange, lRange); + + using LocAccT = sycl::local_accessor; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + + using KernelName = class gemm_k_krn; + cgh.parallel_for( + ndRange, GemmFunctorThreadK( + lhs_tp, rhs_tp, res_tp, workspace, local_B_block, + n, n_blocks, delta_n, k, k_blocks, delta_k, n_wi, + m, lhs_indexer, rhs_indexer, res_indexer)); + } + else if (k > n && k > m) { + constexpr size_t m_groups = 2; + size_t delta_k(4); + size_t n_wi(64); + size_t delta_n(32); + + gemm_detail::scale_gemm_k_parameters( + local_mem_size, reserved_slm_size, + delta_k, // modified by reference + n_wi, // modified by reference + delta_n // modified by reference + ); + + size_t n_blocks = (n + delta_n - 1) / delta_n; + size_t m_blocks = (m + m_groups - 1) / m_groups; + size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); + + size_t lws = delta_n * delta_k; + + auto gRange = sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); + auto lRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gRange, lRange); + + using LocAccT = sycl::local_accessor, 1>; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + + using KernelName = class gemm_k_krn; + cgh.parallel_for( + ndRange, GemmFunctorThreadK( + lhs_tp, rhs_tp, res_tp, workspace, local_B_block, + n, n_blocks, delta_n, k, k_blocks, delta_k, n_wi, + m, lhs_indexer, rhs_indexer, res_indexer)); + } + else { + constexpr int wi_delta_n = 2; + constexpr int wi_delta_m = 4; + size_t wg_delta_n(16); // rows of A processed in WG + size_t wg_delta_m(16); // rows of B processed in WG + size_t wi_delta_k(64); // Elements in K dimension processed by WI + + gemm_detail::scale_gemm_nm_parameters( + local_mem_size, reserved_slm_size, wi_delta_n, + wi_delta_k, // modified by reference + wg_delta_n, // modified by reference + wg_delta_m // modified by reference + ); + + size_t lws = wg_delta_n * wg_delta_m; + + size_t n_blocks = + ((n + wi_delta_n * wg_delta_n - 1) / (wi_delta_n * wg_delta_n)); + size_t m_blocks = + ((m + wi_delta_m * wg_delta_m - 1) / (wi_delta_m * wg_delta_m)); + size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); + + auto gwsRange = + sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); + auto lwsRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); + + using LocAccT1 = sycl::local_accessor; + LocAccT1 local_A_block( + sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), cgh); + using LocAccT2 = + sycl::local_accessor, 1>; + LocAccT2 local_B_block(sycl::range<1>(wi_delta_k * wg_delta_m), + cgh); + + using KernelName = + class gemm_nm_krn; + cgh.parallel_for( + ndRange, + GemmFunctorThreadNM( + lhs_tp, rhs_tp, res_tp, local_A_block, local_B_block, n, + wg_delta_n, k, k_blocks, wi_delta_k, m, m_blocks, + wg_delta_m, lhs_indexer, rhs_indexer, res_indexer)); + } + }); + + return gemm_ev; +} + +typedef sycl::event (*gemm_contig_impl_fn_ptr_t)( + sycl::queue &, + const char *, // lhs + const char *, // rhs + char *, // res + size_t, // n + size_t, // k + size_t, // m + std::vector const &); + +template +sycl::event gemm_contig_impl(sycl::queue &exec_q, + const char *lhs_cp, + const char *rhs_cp, + char *res_cp, + size_t n, + size_t k, + size_t m, + std::vector const &depends = {}) +{ + const lhsTy *lhs_tp = reinterpret_cast(lhs_cp); + const rhsTy *rhs_tp = reinterpret_cast(rhs_cp); + resTy *res_tp = reinterpret_cast(res_cp); + + sycl::event res_init_ev = 
exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + cgh.fill(res_tp, resTy(0), n * m); + }); + + const sycl::device &dev = exec_q.get_device(); + const size_t local_mem_size = + dev.get_info(); + const size_t reserved_slm_size = 512; + + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(res_init_ev); + + using OuterInnerIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + OuterInnerIndexerT lhs_indexer{}; + OuterInnerIndexerT rhs_indexer{}; + OuterInnerIndexerT res_indexer{}; + + if (m == 1) { + constexpr size_t m_groups = 1; + size_t delta_k(4); + size_t n_wi(64); + size_t delta_n(32); + + gemm_detail::scale_gemm_k_parameters( + local_mem_size, reserved_slm_size, + delta_k, // modified by reference + n_wi, // modified by reference + delta_n // modified by reference + ); + + size_t n_blocks = (n + delta_n - 1) / delta_n; + size_t m_blocks = (m + m_groups - 1) / m_groups; + size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); + + size_t lws = delta_n * delta_k; + + auto gRange = sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); + auto lRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gRange, lRange); + + using LocAccT = sycl::local_accessor; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + + using KernelName = class gemm_k_krn; + cgh.parallel_for( + ndRange, GemmFunctorThreadK( + lhs_tp, rhs_tp, res_tp, workspace, local_B_block, + n, n_blocks, delta_n, k, k_blocks, delta_k, n_wi, + m, lhs_indexer, rhs_indexer, res_indexer)); + } + else if (k > n && k > m) { + constexpr size_t m_groups = 2; + size_t delta_k(4); + size_t n_wi(64); + size_t delta_n(32); + + gemm_detail::scale_gemm_k_parameters( + local_mem_size, reserved_slm_size, + delta_k, // modified by reference + n_wi, // modified by reference + delta_n // modified by reference + ); + + size_t n_blocks = (n + delta_n - 1) / delta_n; + size_t m_blocks = (m + m_groups - 1) / m_groups; + size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); + + size_t lws = delta_n * delta_k; + + auto gRange = sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); + auto lRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gRange, lRange); + + using LocAccT = sycl::local_accessor, 1>; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + + using KernelName = class gemm_k_krn; + cgh.parallel_for( + ndRange, GemmFunctorThreadK( + lhs_tp, rhs_tp, res_tp, workspace, local_B_block, + n, n_blocks, delta_n, k, k_blocks, delta_k, n_wi, + m, lhs_indexer, rhs_indexer, res_indexer)); + } + else { + constexpr int wi_delta_n = 2; + constexpr int wi_delta_m = 4; + size_t wg_delta_n(16); // rows of A processed in WG + size_t wg_delta_m(16); // rows of B processed in WG + size_t wi_delta_k(64); // Elements in K dimension processed by WI + + gemm_detail::scale_gemm_nm_parameters( + local_mem_size, reserved_slm_size, wi_delta_n, + wi_delta_k, // modified by reference + wg_delta_n, // modified by reference + wg_delta_m // modified by reference + ); + + size_t lws = wg_delta_n * wg_delta_m; + + size_t n_blocks = + ((n + wi_delta_n * wg_delta_n - 1) / (wi_delta_n * wg_delta_n)); + size_t m_blocks = + ((m + wi_delta_m * wg_delta_m - 1) / (wi_delta_m * wg_delta_m)); + size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); + + auto gwsRange = + sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); + auto lwsRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); + + 
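With the default tile parameters used above (wi_delta_n = 2, wi_delta_m = 4, wg_delta_n = wg_delta_m = 16, wi_delta_k = 64, before any shrinking by scale_gemm_nm_parameters), each work-group owns a 32 x 64 tile of the result and stages one 64-deep k-slab of A and B in local memory. A small bookkeeping sketch, assuming a float result type:

#include <cstddef>
#include <cstdio>

// Tile and SLM bookkeeping for the NM path: A is staged as
// (wi_delta_n * wg_delta_n) * wi_delta_k scalars, B as
// wi_delta_k * wg_delta_m vectors of wi_delta_m lanes.
int main()
{
    using resT = float; // hypothetical result type
    constexpr std::size_t wi_delta_n = 2, wi_delta_m = 4;
    const std::size_t wg_delta_n = 16, wg_delta_m = 16, wi_delta_k = 64;

    std::size_t tile_rows = wi_delta_n * wg_delta_n; // 32 rows of C
    std::size_t tile_cols = wi_delta_m * wg_delta_m; // 64 cols of C

    std::size_t a_elems = tile_rows * wi_delta_k;               // scalars
    std::size_t b_lanes = wi_delta_k * wg_delta_m * wi_delta_m; // scalar lanes

    std::printf("C tile: %zu x %zu\n", tile_rows, tile_cols);
    std::printf("SLM: A = %zu bytes, B = %zu bytes\n",
                a_elems * sizeof(resT), b_lanes * sizeof(resT));
    return 0;
}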
using LocAccT1 = sycl::local_accessor; + LocAccT1 local_A_block( + sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), cgh); + using LocAccT2 = + sycl::local_accessor, 1>; + LocAccT2 local_B_block(sycl::range<1>(wi_delta_k * wg_delta_m), + cgh); + + using KernelName = + class gemm_nm_krn; + cgh.parallel_for( + ndRange, + GemmFunctorThreadNM( + lhs_tp, rhs_tp, res_tp, local_A_block, local_B_block, n, + wg_delta_n, k, k_blocks, wi_delta_k, m, m_blocks, + wg_delta_m, lhs_indexer, rhs_indexer, res_indexer)); + } + }); + + return gemm_ev; +} + +template +class GemmNoAtomicFunctorThreadNM +{ +private: + const lhsT *lhs = nullptr; + const rhsT *rhs = nullptr; + resT *res = nullptr; + LocAccT1 local_A_block; + LocAccT2 local_B_block; + size_t n = 0; + size_t wg_delta_n = 0; + size_t k = 0; + size_t k_blocks = 0; + size_t wi_delta_k = 0; + size_t m = 0; + size_t m_blocks = 0; + size_t wg_delta_m = 0; + OuterInnerDimsIndexerT lhs_indexer; + OuterInnerDimsIndexerT rhs_indexer; + ResIndexerT res_indexer; + +public: + GemmNoAtomicFunctorThreadNM(const lhsT *lhs_, + const rhsT *rhs_, + resT *res_, + LocAccT1 local_A_block_, + LocAccT2 local_B_block_, + size_t n_, + size_t wg_delta_n_, + size_t k_, + size_t k_blocks_, + size_t wi_delta_k_, + size_t m_, + size_t m_blocks_, + size_t wg_delta_m_, + OuterInnerDimsIndexerT lhs_indexer_, + OuterInnerDimsIndexerT rhs_indexer_, + ResIndexerT res_indexer_) + : lhs(lhs_), rhs(rhs_), res(res_), local_A_block(local_A_block_), + local_B_block(local_B_block_), n(n_), wg_delta_n(wg_delta_n_), k(k_), + k_blocks(k_blocks_), wi_delta_k(wi_delta_k_), m(m_), + m_blocks(m_blocks_), wg_delta_m(wg_delta_m_), + lhs_indexer(lhs_indexer_), rhs_indexer(rhs_indexer_), + res_indexer(res_indexer_) + { + } + + void operator()(sycl::nd_item<1> it) const + { + size_t gr_id = it.get_group_linear_id(); + // lift group_id to (block_i, block_j, block_s), + // 0 <= block_i < n_blocks, 0 <= block_j < m_blocks, 0 <= block_s < + // k_blocks + size_t block_i = gr_id / (m_blocks * k_blocks); + size_t block_r = gr_id - block_i * (m_blocks * k_blocks); + size_t block_j = block_r / k_blocks; + size_t block_s = block_r - block_j * k_blocks; + + size_t lid = it.get_local_linear_id(); + size_t local_i = lid / wg_delta_m; // 0<= local_i < wg_delta_n + size_t local_j = lid - local_i * wg_delta_m; // 0<= local_j < wg_delta_m + + // load A block and B blocks into SLM + + size_t i = block_i * wi_delta_n * wg_delta_n; + size_t j = block_j * wi_delta_m * wg_delta_m; + size_t s = block_s * wi_delta_k; + + const std::int64_t a_st0 = k; + const std::int64_t a_st1 = 1; + + const std::int64_t b_st0 = m; + const std::int64_t b_st1 = 1; + + const std::int64_t c_st0 = m; + const std::int64_t c_st1 = 1; + + size_t lws = it.get_local_range(0); + + for (size_t vid = lid; vid < local_A_block.size(); vid += lws) { + size_t v_i = vid / wi_delta_k; // 0<= v_i < wg_delta_n * wi_delta_n + size_t v_s = vid - v_i * wi_delta_k; // 0<= v_s < wi_delta_k + + size_t g_i = i + v_i; + size_t g_s = s + v_s; + + local_A_block[vid] = + (g_i < n && g_s < k) + ? 
static_cast( + lhs[lhs_indexer(g_i * a_st0 + g_s * a_st1)]) + : resT(0); + } + + for (size_t vid = lid; vid < local_B_block.size(); vid += lws) { + size_t v_j = vid / wi_delta_k; // 0<= v_i < wg_delta_m + size_t v_s = vid - v_j * wi_delta_k; // 0<= v_s < wi_delta_k + + size_t g_j0 = j + v_j * wi_delta_m; + size_t g_s = s + v_s; + + sycl::vec vec{}; +#pragma unroll + for (std::uint8_t lane_id = 0; lane_id < wi_delta_m; ++lane_id) { + size_t g_j = g_j0 + lane_id; + vec[lane_id] = + (g_j < m && g_s < k) + ? static_cast( + rhs[rhs_indexer(g_s * b_st0 + g_j * b_st1)]) + : resT(0); + } + + local_B_block[vid] = vec; + } + + it.barrier(sycl::access::fence_space::local_space); + + i += local_i * wi_delta_n; + j += local_j * wi_delta_m; + + size_t a_offset = local_i * wi_delta_k * wi_delta_n; + size_t b_offset = local_j * wi_delta_k; + + constexpr resT identity_(0); + + for (std::uint8_t private_i = 0; private_i < wi_delta_n; ++private_i) { + size_t a_pr_offset = private_i * wi_delta_k; + + sycl::vec local_sum(identity_); + for (size_t private_s = 0; private_s < wi_delta_k; ++private_s) { + local_sum = local_sum + + (local_A_block[a_offset + a_pr_offset + private_s] * + local_B_block[b_offset + private_s]); + } + + size_t gl_i = i + private_i; + +#pragma unroll + for (std::uint8_t lane_id = 0; lane_id < wi_delta_m; ++lane_id) { + size_t gl_j = j + lane_id; + + if (gl_i < n && gl_j < m) { + res[res_indexer(gl_i * c_st0 + gl_j * c_st1 + + block_s * n * m)] = local_sum[lane_id]; + } + } + } + } +}; + +template +class GemmNoAtomicFunctorThreadNM +{ +private: + const lhsT *lhs = nullptr; + const rhsT *rhs = nullptr; + resT *res = nullptr; + LocAccT1 local_A_block; + LocAccT2 local_B_block; + size_t n = 0; + size_t wg_delta_n = 0; + size_t k = 0; + size_t k_blocks = 0; + size_t wi_delta_k = 0; + size_t m = 0; + size_t m_blocks = 0; + size_t wg_delta_m = 0; + OuterInnerDimsIndexerT lhs_indexer; + OuterInnerDimsIndexerT rhs_indexer; + ResIndexerT res_indexer; + +public: + GemmNoAtomicFunctorThreadNM(const lhsT *lhs_, + const rhsT *rhs_, + resT *res_, + LocAccT1 local_A_block_, + LocAccT2 local_B_block_, + size_t n_, + size_t wg_delta_n_, + size_t k_, + size_t k_blocks_, + size_t wi_delta_k_, + size_t m_, + size_t m_blocks_, + size_t wg_delta_m_, + OuterInnerDimsIndexerT lhs_indexer_, + OuterInnerDimsIndexerT rhs_indexer_, + ResIndexerT res_indexer_) + : lhs(lhs_), rhs(rhs_), res(res_), local_A_block(local_A_block_), + local_B_block(local_B_block_), n(n_), wg_delta_n(wg_delta_n_), k(k_), + k_blocks(k_blocks_), wi_delta_k(wi_delta_k_), m(m_), + m_blocks(m_blocks_), wg_delta_m(wg_delta_m_), + lhs_indexer(lhs_indexer_), rhs_indexer(rhs_indexer_), + res_indexer(res_indexer_) + { + } + + void operator()(sycl::nd_item<1> it) const + { + size_t gr_id = it.get_group_linear_id(); + // lift group_id to (block_i, block_j, block_s), + // 0 <= block_i < n_blocks, 0 <= block_j < m_blocks, 0 <= block_s < + // k_blocks + size_t block_i = gr_id / (m_blocks * k_blocks); + size_t block_r = gr_id - block_i * (m_blocks * k_blocks); + size_t block_j = block_r / k_blocks; + size_t block_s = block_r - block_j * k_blocks; + + size_t lid = it.get_local_linear_id(); + size_t local_i = lid / wg_delta_m; // 0<= local_i < wg_delta_n + size_t local_j = lid - local_i * wg_delta_m; // 0<= local_j < wg_delta_m + + // load A block and B blocks into SLM + + size_t i = block_i * wi_delta_n * wg_delta_n; + size_t j = block_j * wg_delta_m; + size_t s = block_s * wi_delta_k; + + const std::int64_t a_st0 = k; + const std::int64_t a_st1 = 1; + + 
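Unlike the atomic functors earlier in this header, the no-atomic variants write the contribution of each k-block to its own n x m slice of the temporary, shifted by block_s * n * m, and leave the summation to tree_reduction_for_gemm. Assuming the NoOpIndexer that is used when the destination is the temporary buffer, the offsets look like this (sizes hypothetical):

#include <cstddef>
#include <cstdio>

// Partial product of k-block `block_s` for result element (i, j) lands at
// block_s * n * m + i * m + j, so the temporary is a stack of k_blocks
// n x m matrices summed later along the stack axis.
int main()
{
    const std::size_t n = 3, m = 4, k_blocks = 2;
    const std::size_t i = 1, j = 2;

    for (std::size_t block_s = 0; block_s < k_blocks; ++block_s) {
        std::size_t offset = block_s * n * m + i * m + j;
        std::printf("k-block %zu writes partial C[%zu][%zu] at offset %zu\n",
                    block_s, i, j, offset);
    }
    return 0;
}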
const std::int64_t b_st0 = m; + const std::int64_t b_st1 = 1; + + const std::int64_t c_st0 = m; + const std::int64_t c_st1 = 1; + + size_t lws = it.get_local_range(0); + + for (size_t vid = lid; vid < local_A_block.size(); vid += lws) { + size_t v_i = vid / wi_delta_k; // 0<= v_i < wg_delta_n * wi_delta_n + size_t v_s = vid - v_i * wi_delta_k; // 0<= v_s < wi_delta_k + + size_t g_i = i + v_i; + size_t g_s = s + v_s; + + local_A_block[vid] = + (g_i < n && g_s < k) + ? static_cast( + lhs[lhs_indexer(g_i * a_st0 + g_s * a_st1)]) + : resT(0); + } + + for (size_t vid = lid; vid < local_B_block.size(); vid += lws) { + size_t v_j = vid / wi_delta_k; // 0<= v_i < wg_delta_m + size_t v_s = vid - v_j * wi_delta_k; // 0<= v_s < wi_delta_k + + size_t g_j0 = j + v_j; + size_t g_s = s + v_s; + + resT val = (g_j0 < m && g_s < k) + ? static_cast( + rhs[rhs_indexer(g_s * b_st0 + g_j0 * b_st1)]) + : resT(0); + + local_B_block[vid] = val; + } + + it.barrier(sycl::access::fence_space::local_space); + + i += local_i * wi_delta_n; + j += local_j; + + size_t a_offset = local_i * wi_delta_k * wi_delta_n; + size_t b_offset = local_j * wi_delta_k; + + constexpr resT identity_(0); + + for (std::uint8_t private_i = 0; private_i < wi_delta_n; ++private_i) { + size_t a_pr_offset = private_i * wi_delta_k; + + resT local_sum(identity_); + for (size_t private_s = 0; private_s < wi_delta_k; ++private_s) { + local_sum = local_sum + + (local_A_block[a_offset + a_pr_offset + private_s] * + local_B_block[b_offset + private_s]); + } + + size_t gl_i = i + private_i; + + if (gl_i < n && j < m) { + res[res_indexer(gl_i * c_st0 + j * c_st1 + block_s * n * m)] = + local_sum; + } + } + } +}; + +template +class GemmNoAtomicFunctorThreadK +{ +private: + const lhsT *lhs = nullptr; + const rhsT *rhs = nullptr; + resT *res = nullptr; + LocAccT workspace; + LocAccT local_B_block; + size_t n = 0; + size_t n_blocks = 0; + size_t delta_n = 0; + size_t k = 0; + size_t k_blocks = 0; + size_t delta_k = 0; + size_t n_wi = 0; + size_t m = 0; + OuterInnerDimsIndexerT lhs_indexer; + OuterInnerDimsIndexerT rhs_indexer; + ResIndexerT res_indexer; + +public: + GemmNoAtomicFunctorThreadK(const lhsT *lhs_, + const rhsT *rhs_, + resT *res_, + LocAccT workspace_, + LocAccT local_B_block_, + size_t n_, + size_t n_blocks_, + size_t delta_n_, + size_t k_, + size_t k_blocks_, + size_t delta_k_, + size_t n_wi_, + size_t m_, + OuterInnerDimsIndexerT lhs_indexer_, + OuterInnerDimsIndexerT rhs_indexer_, + ResIndexerT res_indexer_) + : lhs(lhs_), rhs(rhs_), res(res_), workspace(workspace_), + local_B_block(local_B_block_), n(n_), n_blocks(n_blocks_), + delta_n(delta_n_), k(k_), k_blocks(k_blocks_), delta_k(delta_k_), + n_wi(n_wi_), m(m_), lhs_indexer(lhs_indexer_), + rhs_indexer(rhs_indexer_), res_indexer(res_indexer_) + { + } + + void operator()(sycl::nd_item<1> it) const + { + size_t gr_id = it.get_group_linear_id(); + size_t lid = it.get_local_linear_id(); + + // lift gr_id -> (block_i, block_j, block_s) + // block_i moves fastest, then block_s, then block_j + + size_t block_j = + gr_id / (n_blocks * k_blocks); // 0 <= block_j < m_blocks + size_t block_r = + gr_id - block_j * (n_blocks * + k_blocks); // 0 <= block_r < n_blocks * k_blocks + size_t block_s = block_r / n_blocks; // 0 <= block_s < k_blocks + size_t block_i = + block_r - block_s * n_blocks; // 0 <= block_i < n_blocks + + size_t local_i = lid / (delta_k); // 0 <= local_i < delta_n + size_t local_s = lid - local_i * (delta_k); // 0 <= local_s < delta_k + + size_t i = block_i * delta_n + local_i; + 
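The index arithmetic above inverts a linearization in which block_i varies fastest, then block_s, then block_j. A quick round-trip check of that decomposition, with hypothetical block counts:

#include <cassert>
#include <cstddef>

// Decompose every linear group id into (block_j, block_s, block_i) exactly
// as the ThreadK functors do, then re-linearize and confirm the round trip.
int main()
{
    const std::size_t n_blocks = 5, k_blocks = 3, m_blocks = 4;

    for (std::size_t gr_id = 0; gr_id < n_blocks * k_blocks * m_blocks;
         ++gr_id)
    {
        std::size_t block_j = gr_id / (n_blocks * k_blocks);
        std::size_t block_r = gr_id - block_j * (n_blocks * k_blocks);
        std::size_t block_s = block_r / n_blocks;
        std::size_t block_i = block_r - block_s * n_blocks;

        assert(block_i < n_blocks && block_s < k_blocks &&
               block_j < m_blocks);
        // re-linearize with block_i fastest, then block_s, then block_j
        assert(gr_id == block_i + n_blocks * (block_s + k_blocks * block_j));
    }
    return 0;
}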
size_t j = m_groups * block_j; + size_t s = block_s * delta_k * n_wi + local_s; + + constexpr resT identity_ = resT(0); + if (local_i == 0) { + for (size_t q = 0; q < n_wi * delta_k; q += delta_k) { + size_t sq = s + q; + size_t sqmj = sq * m + j; + local_B_block[local_s + q] = sycl::vec( + (sq < k && j < m) + ? static_cast(rhs[rhs_indexer(sqmj)]) + : identity_, + (sq < k && j + 1 < m) + ? static_cast(rhs[rhs_indexer(sqmj + 1)]) + : identity_); + } + } + + it.barrier(sycl::access::fence_space::local_space); + + size_t t_shift = block_s * delta_k * n_wi; + size_t global_s_offset = i * k + t_shift; + + sycl::vec private_sum(identity_); + for (size_t t = local_s; t < local_B_block.size(); t += delta_k) { + if (t + t_shift < k) { + private_sum += + (static_cast(lhs[lhs_indexer(global_s_offset + t)]) * + local_B_block[t]); + } + } + + size_t workspace_i_shift = local_i * delta_k; + workspace[workspace_i_shift + local_s] = private_sum; + + it.barrier(sycl::access::fence_space::local_space); + + if (local_s == 0 && i < n) { + sycl::vec local_sum(workspace[workspace_i_shift]); + for (size_t t = 1; t < delta_k; ++t) { + local_sum += workspace[workspace_i_shift + t]; + } + + res[res_indexer(i * m + j) + (block_s * n * m)] = local_sum[0]; + + if (j + 1 < m) { + res[res_indexer(i * m + j + 1) + (block_s * n * m)] = + local_sum[1]; + } + } + } +}; + +template +class GemmNoAtomicFunctorThreadK +{ +private: + const lhsT *lhs = nullptr; + const rhsT *rhs = nullptr; + resT *res = nullptr; + LocAccT workspace; + LocAccT local_B_block; + size_t n = 0; + size_t n_blocks = 0; + size_t delta_n = 0; + size_t k = 0; + size_t k_blocks = 0; + size_t delta_k = 0; + size_t n_wi = 0; + size_t m = 0; + OuterInnerDimsIndexerT lhs_indexer; + OuterInnerDimsIndexerT rhs_indexer; + ResIndexerT res_indexer; + +public: + GemmNoAtomicFunctorThreadK(const lhsT *lhs_, + const rhsT *rhs_, + resT *res_, + LocAccT workspace_, + LocAccT local_B_block_, + size_t n_, + size_t n_blocks_, + size_t delta_n_, + size_t k_, + size_t k_blocks_, + size_t delta_k_, + size_t n_wi_, + size_t m_, + OuterInnerDimsIndexerT lhs_indexer_, + OuterInnerDimsIndexerT rhs_indexer_, + ResIndexerT res_indexer_) + : lhs(lhs_), rhs(rhs_), res(res_), workspace(workspace_), + local_B_block(local_B_block_), n(n_), n_blocks(n_blocks_), + delta_n(delta_n_), k(k_), k_blocks(k_blocks_), delta_k(delta_k_), + n_wi(n_wi_), m(m_), lhs_indexer(lhs_indexer_), + rhs_indexer(rhs_indexer_), res_indexer(res_indexer_) + { + } + + void operator()(sycl::nd_item<1> it) const + { + size_t gr_id = it.get_group_linear_id(); + size_t lid = it.get_local_linear_id(); + + // lift gr_id -> (block_i, block_j, block_s) + // block_i moves fastest, then block_s, then block_j + + size_t block_j = + gr_id / (n_blocks * k_blocks); // 0 <= block_j < m_blocks + size_t block_r = + gr_id - block_j * (n_blocks * + k_blocks); // 0 <= block_r < n_blocks * k_blocks + size_t block_s = block_r / n_blocks; // 0 <= block_s < k_blocks + size_t block_i = + block_r - block_s * n_blocks; // 0 <= block_i < n_blocks + + size_t local_i = lid / (delta_k); // 0 <= local_i < delta_n + size_t local_s = lid - local_i * (delta_k); // 0 <= local_s < delta_k + + size_t i = block_i * delta_n + local_i; + size_t j = block_j; + size_t s = block_s * delta_k * n_wi + local_s; + + constexpr resT identity_ = resT(0); + if (local_i == 0) { + for (size_t q = 0; q < n_wi * delta_k; q += delta_k) { + size_t sq = s + q; + size_t sqmj = sq * m + j; + local_B_block[local_s + q] = + (sq < k && j < m) + ? 
static_cast(rhs[rhs_indexer(sqmj)]) + : identity_; + } + } + + it.barrier(sycl::access::fence_space::local_space); + + size_t t_shift = block_s * delta_k * n_wi; + size_t global_s_offset = i * k + t_shift; + + resT private_sum(identity_); + for (size_t t = local_s; t < local_B_block.size(); t += delta_k) { + if (t + t_shift < k) { + private_sum += + (static_cast(lhs[lhs_indexer(global_s_offset + t)]) * + local_B_block[t]); + } + } + + size_t workspace_i_shift = local_i * delta_k; + workspace[workspace_i_shift + local_s] = private_sum; + + it.barrier(sycl::access::fence_space::local_space); + + if (local_s == 0 && i < n) { + resT local_sum(workspace[workspace_i_shift]); + for (size_t t = 1; t < delta_k; ++t) { + local_sum += workspace[workspace_i_shift + t]; + } + + res[res_indexer(i * m + j) + (block_s * n * m)] = local_sum; + } + } +}; + +template +class gemm_reduction_seq_strided_krn; + +template +class gemm_tree_nm_krn; + +template +class gemm_tree_k_krn; + +template +sycl::event gemm_tree_impl(sycl::queue &exec_q, + const char *lhs_cp, + const char *rhs_cp, + char *res_cp, + size_t n, + size_t k, + size_t m, + int inner_nd, + int lhs_outer_nd, + const py::ssize_t *lhs_outer_inner_shapes_strides, + int rhs_outer_nd, + const py::ssize_t *rhs_outer_inner_shapes_strides, + int res_nd, + const py::ssize_t *res_shapes_strides, + std::vector const &depends = {}) +{ + const lhsTy *lhs_tp = reinterpret_cast(lhs_cp); + const rhsTy *rhs_tp = reinterpret_cast(rhs_cp); + resTy *res_tp = reinterpret_cast(res_cp); + + const sycl::device &dev = exec_q.get_device(); + const size_t local_mem_size = + dev.get_info(); + const size_t reserved_slm_size = 512; + + if ((k > n && k > m) || m == 1) { + // each group processes delta_k * n_wi + // items in a column, so no need for allocating + // temp memory if only one group is needed + size_t delta_k(4); + size_t n_wi(64); + size_t delta_n(32); + + using dpctl::tensor::type_utils::is_complex; + if constexpr (!is_complex::value) { + if (m == 1) { + constexpr int m_groups = 1; + + gemm_detail::scale_gemm_k_parameters( + local_mem_size, reserved_slm_size, + delta_k, // modified by reference + n_wi, // modified by reference + delta_n // modified by reference + ); + + sycl::event gemm_ev; + if (k <= (delta_k * n_wi)) { + gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::StridedIndexer; + OuterInnerDimsIndexerT lhs_indexer( + inner_nd + lhs_outer_nd, 0, + lhs_outer_inner_shapes_strides); + OuterInnerDimsIndexerT rhs_indexer( + inner_nd + rhs_outer_nd, 0, + rhs_outer_inner_shapes_strides); + OuterInnerDimsIndexerT res_indexer(res_nd, 0, + res_shapes_strides); + + size_t n_blocks = (n + delta_n - 1) / delta_n; + size_t k_blocks = + (k + n_wi * delta_k - 1) / (n_wi * delta_k); + size_t m_blocks = (m + m_groups - 1) / m_groups; + + size_t lws = delta_n * delta_k; + + auto gRange = sycl::range<1>(n_blocks * m_blocks * + k_blocks * lws); + auto lRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gRange, lRange); + + using LocAccT = sycl::local_accessor; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + using KernelName = class gemm_tree_k_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, m_groups>; + cgh.parallel_for( + ndRange, + GemmNoAtomicFunctorThreadK< + lhsTy, rhsTy, resTy, LocAccT, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, + m_groups>(lhs_tp, rhs_tp, res_tp, workspace, + 
local_B_block, n, n_blocks, delta_n, + k, k_blocks, delta_k, n_wi, m, + lhs_indexer, rhs_indexer, + res_indexer)); + }); + return gemm_ev; + } + else { + using ReductionOpT = sycl::plus; + constexpr resTy identity_val = + sycl::known_identity::value; + + size_t iter_nelems = n * m; + size_t reduction_nelems = + (k + delta_k * n_wi - 1) / (delta_k * n_wi); + + // more than one work-groups is needed, requires a temporary + // delta_k * n_wi elements processed along k, so if more to + // process use multiple + const auto &sg_sizes = + dev.get_info(); + size_t wg = + choose_workgroup_size<4>(reduction_nelems, sg_sizes); + + if (reduction_nelems < wg) { + resTy *tmp = sycl::malloc_device( + iter_nelems * reduction_nelems, exec_q); + sycl::event gemm_ev = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::StridedIndexer; + OuterInnerDimsIndexerT lhs_indexer( + inner_nd + lhs_outer_nd, 0, + lhs_outer_inner_shapes_strides); + OuterInnerDimsIndexerT rhs_indexer( + inner_nd + rhs_outer_nd, 0, + rhs_outer_inner_shapes_strides); + using ResIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + ResIndexerT res_indexer{}; + + size_t n_blocks = (n + delta_n - 1) / delta_n; + size_t k_blocks = + (k + n_wi * delta_k - 1) / (n_wi * delta_k); + size_t m_blocks = (m + m_groups - 1) / m_groups; + + size_t lws = delta_n * delta_k; + + auto gRange = sycl::range<1>( + n_blocks * m_blocks * k_blocks * lws); + auto lRange = sycl::range<1>(lws); + + auto ndRange = + sycl::nd_range<1>(gRange, lRange); + + using LocAccT = sycl::local_accessor; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + using KernelName = class gemm_tree_k_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + ResIndexerT, m_groups>; + cgh.parallel_for( + ndRange, + GemmNoAtomicFunctorThreadK< + lhsTy, rhsTy, resTy, LocAccT, + OuterInnerDimsIndexerT, ResIndexerT, + m_groups>( + lhs_tp, rhs_tp, tmp, workspace, + local_B_block, n, n_blocks, delta_n, k, + k_blocks, delta_k, n_wi, m, lhs_indexer, + rhs_indexer, res_indexer)); + }); + sycl::event red_ev = exec_q.submit([&](sycl::handler + &cgh) { + cgh.depends_on(gemm_ev); + + using NoOpIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + using ResIndexerT = + dpctl::tensor::offset_utils::StridedIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils:: + TwoOffsets_CombinedIndexer; + using ReductionIndexerT = + dpctl::tensor::offset_utils::Strided1DIndexer; + + ResIndexerT res_iter_indexer{res_nd, 0, + res_shapes_strides}; + InputOutputIterIndexerT in_out_iter_indexer{ + NoOpIndexerT{}, res_iter_indexer}; + ReductionIndexerT reduction_indexer{ + 0, static_cast(reduction_nelems), + static_cast(iter_nelems)}; + + sycl::range<1> iter_range{iter_nelems}; + + cgh.parallel_for< + class gemm_reduction_seq_strided_krn< + resTy, resTy, ReductionOpT, + InputOutputIterIndexerT, + ReductionIndexerT>>( + iter_range, + SequentialReduction( + tmp, res_tp, ReductionOpT(), identity_val, + in_out_iter_indexer, reduction_indexer, + reduction_nelems)); + }); + sycl::event cleanup_host_task_event = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(red_ev); + const sycl::context &ctx = exec_q.get_context(); + + cgh.host_task( + [ctx, tmp] { sycl::free(tmp, ctx); }); + }); + return cleanup_host_task_event; + } + // max_max_wg prevents running out of resources on CPU + constexpr size_t max_max_wg = 2048; + size_t max_wg = std::min( + max_max_wg, + dev.get_info< + 
sycl::info::device::max_work_group_size>() / + 2); + + constexpr size_t preferred_reductions_per_wi = 8; + size_t reductions_per_wi(preferred_reductions_per_wi); + + size_t reduction_groups = + (reduction_nelems + preferred_reductions_per_wi * wg - + 1) / + (preferred_reductions_per_wi * wg); + assert(reduction_groups > 1); + + resTy *partially_reduced_tmp = sycl::malloc_device( + iter_nelems * + (/* temp */ reduction_nelems + + /* first reduction temp */ reduction_groups), + exec_q); + resTy *partially_reduced_tmp2 = nullptr; + + if (partially_reduced_tmp == nullptr) { + throw std::runtime_error( + "Unable to allocate device_memory"); + } + else { + partially_reduced_tmp2 = partially_reduced_tmp + + reduction_nelems * iter_nelems; + } + + sycl::event gemm_ev = exec_q.submit([&](sycl::handler + &cgh) { + cgh.depends_on(depends); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::StridedIndexer; + OuterInnerDimsIndexerT lhs_indexer( + inner_nd + lhs_outer_nd, 0, + lhs_outer_inner_shapes_strides); + OuterInnerDimsIndexerT rhs_indexer( + inner_nd + rhs_outer_nd, 0, + rhs_outer_inner_shapes_strides); + using ResIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + ResIndexerT res_indexer{}; + + size_t n_blocks = (n + delta_n - 1) / delta_n; + size_t k_blocks = + (k + n_wi * delta_k - 1) / (n_wi * delta_k); + size_t m_blocks = (m + m_groups - 1) / m_groups; + + size_t lws = delta_n * delta_k; + + auto gRange = sycl::range<1>(n_blocks * m_blocks * + k_blocks * lws); + auto lRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gRange, lRange); + + using LocAccT = sycl::local_accessor; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + using KernelName = + class gemm_tree_k_krn; + cgh.parallel_for( + ndRange, + GemmNoAtomicFunctorThreadK< + lhsTy, rhsTy, resTy, LocAccT, + OuterInnerDimsIndexerT, ResIndexerT, m_groups>( + lhs_tp, rhs_tp, partially_reduced_tmp, + workspace, local_B_block, n, n_blocks, delta_n, + k, k_blocks, delta_k, n_wi, m, lhs_indexer, + rhs_indexer, res_indexer)); + }); + // tree_reduction_for_gemm returns sycl::event for reduction + sycl::event red_ev = + tree_reduction_for_gemm( + exec_q, partially_reduced_tmp, + partially_reduced_tmp2, res_tp, identity_val, + iter_nelems, reduction_nelems, reduction_groups, wg, + max_wg, preferred_reductions_per_wi, + reductions_per_wi, res_nd, 0, res_shapes_strides, + {gemm_ev}); + + sycl::event cleanup_host_task_event = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(red_ev); + const sycl::context &ctx = exec_q.get_context(); + + cgh.host_task([ctx, partially_reduced_tmp] { + sycl::free(partially_reduced_tmp, ctx); + }); + }); + + return cleanup_host_task_event; + } + } + else { + constexpr int m_groups = 2; + gemm_detail::scale_gemm_k_parameters( + local_mem_size, reserved_slm_size, + delta_k, // modified by reference + n_wi, // modified by reference + delta_n // modified by reference + ); + + sycl::event gemm_ev; + if (k <= (delta_k * n_wi)) { + gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::StridedIndexer; + OuterInnerDimsIndexerT lhs_indexer( + inner_nd + lhs_outer_nd, 0, + lhs_outer_inner_shapes_strides); + OuterInnerDimsIndexerT rhs_indexer( + inner_nd + rhs_outer_nd, 0, + rhs_outer_inner_shapes_strides); + OuterInnerDimsIndexerT res_indexer(res_nd, 0, + res_shapes_strides); + + size_t n_blocks = (n + delta_n - 1) / delta_n; + size_t k_blocks = + (k + n_wi * 
delta_k - 1) / (n_wi * delta_k); + size_t m_blocks = (m + m_groups - 1) / m_groups; + + size_t lws = delta_n * delta_k; + + auto gRange = sycl::range<1>(n_blocks * m_blocks * + k_blocks * lws); + auto lRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gRange, lRange); + + using LocAccT = + sycl::local_accessor, 1>; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + using KernelName = class gemm_tree_k_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, m_groups>; + cgh.parallel_for( + ndRange, + GemmNoAtomicFunctorThreadK< + lhsTy, rhsTy, resTy, LocAccT, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, + m_groups>(lhs_tp, rhs_tp, res_tp, workspace, + local_B_block, n, n_blocks, delta_n, + k, k_blocks, delta_k, n_wi, m, + lhs_indexer, rhs_indexer, + res_indexer)); + }); + return gemm_ev; + } + else { + using ReductionOpT = sycl::plus; + constexpr resTy identity_val = + sycl::known_identity::value; + + size_t iter_nelems = n * m; + size_t reduction_nelems = + (k + delta_k * n_wi - 1) / (delta_k * n_wi); + + // more than one work-groups is needed, requires a temporary + // delta_k * n_wi elements processed along k, so if more to + // process use multiple + const auto &sg_sizes = + dev.get_info(); + size_t wg = + choose_workgroup_size<4>(reduction_nelems, sg_sizes); + + if (reduction_nelems < wg) { + resTy *tmp = sycl::malloc_device( + iter_nelems * reduction_nelems, exec_q); + sycl::event gemm_ev = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::StridedIndexer; + OuterInnerDimsIndexerT lhs_indexer( + inner_nd + lhs_outer_nd, 0, + lhs_outer_inner_shapes_strides); + OuterInnerDimsIndexerT rhs_indexer( + inner_nd + rhs_outer_nd, 0, + rhs_outer_inner_shapes_strides); + using ResIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + ResIndexerT res_indexer{}; + + size_t n_blocks = (n + delta_n - 1) / delta_n; + size_t k_blocks = + (k + n_wi * delta_k - 1) / (n_wi * delta_k); + size_t m_blocks = (m + m_groups - 1) / m_groups; + + size_t lws = delta_n * delta_k; + + auto gRange = sycl::range<1>( + n_blocks * m_blocks * k_blocks * lws); + auto lRange = sycl::range<1>(lws); + + auto ndRange = + sycl::nd_range<1>(gRange, lRange); + + using LocAccT = sycl::local_accessor< + sycl::vec, 1>; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + using KernelName = class gemm_tree_k_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + ResIndexerT, m_groups>; + cgh.parallel_for( + ndRange, + GemmNoAtomicFunctorThreadK< + lhsTy, rhsTy, resTy, LocAccT, + OuterInnerDimsIndexerT, ResIndexerT, + m_groups>( + lhs_tp, rhs_tp, tmp, workspace, + local_B_block, n, n_blocks, delta_n, k, + k_blocks, delta_k, n_wi, m, lhs_indexer, + rhs_indexer, res_indexer)); + }); + sycl::event red_ev = exec_q.submit([&](sycl::handler + &cgh) { + cgh.depends_on(gemm_ev); + + using NoOpIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + using ResIndexerT = + dpctl::tensor::offset_utils::StridedIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils:: + TwoOffsets_CombinedIndexer; + using ReductionIndexerT = + dpctl::tensor::offset_utils::Strided1DIndexer; + + ResIndexerT res_iter_indexer{res_nd, 0, + res_shapes_strides}; + InputOutputIterIndexerT in_out_iter_indexer{ + NoOpIndexerT{}, res_iter_indexer}; + ReductionIndexerT reduction_indexer{ + 0, static_cast(reduction_nelems), + static_cast(iter_nelems)}; 
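// Indexing sketch for the sequential-reduction kernel submitted here
// (illustrative only; it assumes Strided1DIndexer(offset, size, step)
// maps a reduction index j to offset + j * step, and that
// SequentialReduction walks j over [0, reduction_nelems) for every
// output element i):
//
//     for (size_t i = 0; i < iter_nelems; ++i) {        // one work-item per i
//         resTy acc = identity_val;                      // identity of sycl::plus
//         for (size_t j = 0; j < reduction_nelems; ++j)
//             acc += tmp[i + j * iter_nelems];           // j-th partial sum of i
//         res_tp[res_iter_indexer(i)] = acc;             // strided result store
//     }
//
// That is, `tmp` is viewed as reduction_nelems contiguous slices of
// iter_nelems partial results, one slice per k-block written by the GEMM
// kernel above, and the reduction folds those slices into res_tp.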
+ + sycl::range<1> iter_range{iter_nelems}; + + cgh.parallel_for< + class gemm_reduction_seq_strided_krn< + resTy, resTy, ReductionOpT, + InputOutputIterIndexerT, + ReductionIndexerT>>( + iter_range, + SequentialReduction( + tmp, res_tp, ReductionOpT(), identity_val, + in_out_iter_indexer, reduction_indexer, + reduction_nelems)); + }); + sycl::event cleanup_host_task_event = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(red_ev); + const sycl::context &ctx = exec_q.get_context(); + + cgh.host_task( + [ctx, tmp] { sycl::free(tmp, ctx); }); + }); + return cleanup_host_task_event; + } + // max_max_wg prevents running out of resources on CPU + constexpr size_t max_max_wg = 2048; + size_t max_wg = std::min( + max_max_wg, + dev.get_info< + sycl::info::device::max_work_group_size>() / + 2); + + constexpr size_t preferred_reductions_per_wi = 8; + size_t reductions_per_wi(preferred_reductions_per_wi); + + size_t reduction_groups = + (reduction_nelems + preferred_reductions_per_wi * wg - + 1) / + (preferred_reductions_per_wi * wg); + assert(reduction_groups > 1); + + resTy *partially_reduced_tmp = sycl::malloc_device( + iter_nelems * + (/* temp */ reduction_nelems + + /* first reduction temp */ reduction_groups), + exec_q); + resTy *partially_reduced_tmp2 = nullptr; + + if (partially_reduced_tmp == nullptr) { + throw std::runtime_error( + "Unable to allocate device_memory"); + } + else { + partially_reduced_tmp2 = partially_reduced_tmp + + reduction_nelems * iter_nelems; + } + + sycl::event gemm_ev = exec_q.submit([&](sycl::handler + &cgh) { + cgh.depends_on(depends); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::StridedIndexer; + OuterInnerDimsIndexerT lhs_indexer( + inner_nd + lhs_outer_nd, 0, + lhs_outer_inner_shapes_strides); + OuterInnerDimsIndexerT rhs_indexer( + inner_nd + rhs_outer_nd, 0, + rhs_outer_inner_shapes_strides); + using ResIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + ResIndexerT res_indexer{}; + + size_t n_blocks = (n + delta_n - 1) / delta_n; + size_t k_blocks = + (k + n_wi * delta_k - 1) / (n_wi * delta_k); + size_t m_blocks = (m + m_groups - 1) / m_groups; + + size_t lws = delta_n * delta_k; + + auto gRange = sycl::range<1>(n_blocks * m_blocks * + k_blocks * lws); + auto lRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gRange, lRange); + + using LocAccT = + sycl::local_accessor, 1>; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + using KernelName = + class gemm_tree_k_krn; + cgh.parallel_for( + ndRange, + GemmNoAtomicFunctorThreadK< + lhsTy, rhsTy, resTy, LocAccT, + OuterInnerDimsIndexerT, ResIndexerT, m_groups>( + lhs_tp, rhs_tp, partially_reduced_tmp, + workspace, local_B_block, n, n_blocks, delta_n, + k, k_blocks, delta_k, n_wi, m, lhs_indexer, + rhs_indexer, res_indexer)); + }); + + sycl::event red_ev = + tree_reduction_for_gemm( + exec_q, partially_reduced_tmp, + partially_reduced_tmp2, res_tp, identity_val, + iter_nelems, reduction_nelems, reduction_groups, wg, + max_wg, preferred_reductions_per_wi, + reductions_per_wi, res_nd, 0, res_shapes_strides, + {gemm_ev}); + + sycl::event cleanup_host_task_event = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(red_ev); + const sycl::context &ctx = exec_q.get_context(); + + cgh.host_task([ctx, partially_reduced_tmp] { + sycl::free(partially_reduced_tmp, ctx); + }); + }); + + return cleanup_host_task_event; + } + } + } + else { + constexpr int m_groups = 1; + + gemm_detail::scale_gemm_k_parameters( + local_mem_size, 
reserved_slm_size, + delta_k, // modified by reference + n_wi, // modified by reference + delta_n // modified by reference + ); + + sycl::event gemm_ev; + if (k <= (delta_k * n_wi)) { + gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::StridedIndexer; + OuterInnerDimsIndexerT lhs_indexer( + inner_nd + lhs_outer_nd, 0, + lhs_outer_inner_shapes_strides); + OuterInnerDimsIndexerT rhs_indexer( + inner_nd + rhs_outer_nd, 0, + rhs_outer_inner_shapes_strides); + OuterInnerDimsIndexerT res_indexer(res_nd, 0, + res_shapes_strides); + + size_t n_blocks = (n + delta_n - 1) / delta_n; + size_t k_blocks = + (k + n_wi * delta_k - 1) / (n_wi * delta_k); + size_t m_blocks = (m + m_groups - 1) / m_groups; + + size_t lws = delta_n * delta_k; + + auto gRange = + sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); + auto lRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gRange, lRange); + + using LocAccT = sycl::local_accessor; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + using KernelName = + class gemm_tree_k_krn; + cgh.parallel_for( + ndRange, + GemmNoAtomicFunctorThreadK( + lhs_tp, rhs_tp, res_tp, workspace, local_B_block, n, + n_blocks, delta_n, k, k_blocks, delta_k, n_wi, m, + lhs_indexer, rhs_indexer, res_indexer)); + }); + return gemm_ev; + } + else { + using ReductionOpT = sycl::plus; + constexpr resTy identity_val = + sycl::known_identity::value; + + size_t iter_nelems = n * m; + size_t reduction_nelems = + (k + delta_k * n_wi - 1) / (delta_k * n_wi); + + // more than one work-groups is needed, requires a temporary + // delta_k * n_wi elements processed along k, so if more to + // process use multiple + const auto &sg_sizes = + dev.get_info(); + size_t wg = + choose_workgroup_size<4>(reduction_nelems, sg_sizes); + + if (reduction_nelems < wg) { + resTy *tmp = sycl::malloc_device( + iter_nelems * reduction_nelems, exec_q); + sycl::event gemm_ev = exec_q.submit([&](sycl::handler + &cgh) { + cgh.depends_on(depends); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::StridedIndexer; + OuterInnerDimsIndexerT lhs_indexer( + inner_nd + lhs_outer_nd, 0, + lhs_outer_inner_shapes_strides); + OuterInnerDimsIndexerT rhs_indexer( + inner_nd + rhs_outer_nd, 0, + rhs_outer_inner_shapes_strides); + using ResIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + ResIndexerT res_indexer{}; + + size_t n_blocks = (n + delta_n - 1) / delta_n; + size_t k_blocks = + (k + n_wi * delta_k - 1) / (n_wi * delta_k); + size_t m_blocks = (m + m_groups - 1) / m_groups; + + size_t lws = delta_n * delta_k; + + auto gRange = sycl::range<1>(n_blocks * m_blocks * + k_blocks * lws); + auto lRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gRange, lRange); + + using LocAccT = sycl::local_accessor; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + using KernelName = + class gemm_tree_k_krn; + cgh.parallel_for( + ndRange, + GemmNoAtomicFunctorThreadK< + lhsTy, rhsTy, resTy, LocAccT, + OuterInnerDimsIndexerT, ResIndexerT, m_groups>( + lhs_tp, rhs_tp, tmp, workspace, local_B_block, + n, n_blocks, delta_n, k, k_blocks, delta_k, + n_wi, m, lhs_indexer, rhs_indexer, + res_indexer)); + }); + + sycl::event red_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(gemm_ev); + + using NoOpIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + using ResIndexerT = + 
dpctl::tensor::offset_utils::StridedIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils:: + TwoOffsets_CombinedIndexer; + using ReductionIndexerT = + dpctl::tensor::offset_utils::Strided1DIndexer; + + ResIndexerT res_iter_indexer{res_nd, 0, + res_shapes_strides}; + InputOutputIterIndexerT in_out_iter_indexer{ + NoOpIndexerT{}, res_iter_indexer}; + ReductionIndexerT reduction_indexer{ + 0, static_cast(reduction_nelems), + static_cast(iter_nelems)}; + + sycl::range<1> iter_range{iter_nelems}; + + cgh.parallel_for>( + iter_range, + SequentialReduction( + tmp, res_tp, ReductionOpT(), identity_val, + in_out_iter_indexer, reduction_indexer, + reduction_nelems)); + }); + sycl::event cleanup_host_task_event = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(red_ev); + const sycl::context &ctx = exec_q.get_context(); + + cgh.host_task([ctx, tmp] { sycl::free(tmp, ctx); }); + }); + return cleanup_host_task_event; + } + // max_max_wg prevents running out of resources on CPU + constexpr size_t max_max_wg = 2048; + size_t max_wg = std::min( + max_max_wg, + dev.get_info() / + 2); + + constexpr size_t preferred_reductions_per_wi = 8; + size_t reductions_per_wi(preferred_reductions_per_wi); + + size_t reduction_groups = + (reduction_nelems + preferred_reductions_per_wi * wg - 1) / + (preferred_reductions_per_wi * wg); + assert(reduction_groups > 1); + + resTy *partially_reduced_tmp = sycl::malloc_device( + iter_nelems * (/* temp */ reduction_nelems + + /* first reduction temp */ reduction_groups), + exec_q); + resTy *partially_reduced_tmp2 = nullptr; + + if (partially_reduced_tmp == nullptr) { + throw std::runtime_error( + "Unable to allocate device_memory"); + } + else { + partially_reduced_tmp2 = + partially_reduced_tmp + reduction_nelems * iter_nelems; + } + + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::StridedIndexer; + OuterInnerDimsIndexerT lhs_indexer( + inner_nd + lhs_outer_nd, 0, + lhs_outer_inner_shapes_strides); + OuterInnerDimsIndexerT rhs_indexer( + inner_nd + rhs_outer_nd, 0, + rhs_outer_inner_shapes_strides); + using ResIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + ResIndexerT res_indexer{}; + + size_t n_blocks = (n + delta_n - 1) / delta_n; + size_t k_blocks = + (k + n_wi * delta_k - 1) / (n_wi * delta_k); + size_t m_blocks = (m + m_groups - 1) / m_groups; + + size_t lws = delta_n * delta_k; + + auto gRange = + sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); + auto lRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gRange, lRange); + + using LocAccT = sycl::local_accessor; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + using KernelName = + class gemm_tree_k_krn; + cgh.parallel_for( + ndRange, + GemmNoAtomicFunctorThreadK( + lhs_tp, rhs_tp, partially_reduced_tmp, workspace, + local_B_block, n, n_blocks, delta_n, k, k_blocks, + delta_k, n_wi, m, lhs_indexer, rhs_indexer, + res_indexer)); + }); + // tree_reduction_for_gemm returns sycl::event for reduction + sycl::event red_ev = + tree_reduction_for_gemm( + exec_q, partially_reduced_tmp, partially_reduced_tmp2, + res_tp, identity_val, iter_nelems, reduction_nelems, + reduction_groups, wg, max_wg, + preferred_reductions_per_wi, reductions_per_wi, res_nd, + 0, res_shapes_strides, {gemm_ev}); + + sycl::event cleanup_host_task_event = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(red_ev); + const sycl::context &ctx = 
exec_q.get_context(); + + cgh.host_task([ctx, partially_reduced_tmp] { + sycl::free(partially_reduced_tmp, ctx); + }); + }); + + return cleanup_host_task_event; + } + } + } + else { // m > 1, n > k or m > k + constexpr int wi_delta_n = 2; + size_t wg_delta_n(16); // rows of A processed in WG + size_t wg_delta_m(16); // rows of B processed in WG + size_t wi_delta_k(64); // Elements in K dimension processed by WI + + using dpctl::tensor::type_utils::is_complex; + if constexpr (!is_complex::value) { + constexpr int wi_delta_m = 4; + + gemm_detail::scale_gemm_nm_parameters( + local_mem_size, reserved_slm_size, wi_delta_n, + wi_delta_k, // modified by reference + wg_delta_n, // modified by reference + wg_delta_m // modified by reference + ); + + // each group processes delta_k items in a column, + // so no need to allocate temp memory if one group needed + if (k <= wi_delta_k) { + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::StridedIndexer; + OuterInnerDimsIndexerT lhs_indexer( + inner_nd + lhs_outer_nd, 0, + lhs_outer_inner_shapes_strides); + OuterInnerDimsIndexerT rhs_indexer( + inner_nd + rhs_outer_nd, 0, + rhs_outer_inner_shapes_strides); + OuterInnerDimsIndexerT res_indexer(res_nd, 0, + res_shapes_strides); + + size_t lws = wg_delta_n * wg_delta_m; + + size_t n_blocks = ((n + wi_delta_n * wg_delta_n - 1) / + (wi_delta_n * wg_delta_n)); + size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); + size_t m_blocks = ((m + wi_delta_m * wg_delta_m - 1) / + (wi_delta_m * wg_delta_m)); + + auto gwsRange = + sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); + auto lwsRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); + + using LocAccT1 = sycl::local_accessor; + LocAccT1 local_A_block( + sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), + cgh); + using LocAccT2 = + sycl::local_accessor, 1>; + LocAccT2 local_B_block( + sycl::range<1>(wi_delta_k * wg_delta_m), cgh); + + using KernelName = class gemm_tree_nm_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, wi_delta_m>; + cgh.parallel_for( + ndRange, + GemmNoAtomicFunctorThreadNM< + lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, + wi_delta_n, wi_delta_m>( + lhs_tp, rhs_tp, res_tp, local_A_block, + local_B_block, n, wg_delta_n, k, k_blocks, + wi_delta_k, m, m_blocks, wg_delta_m, lhs_indexer, + rhs_indexer, res_indexer)); + }); + return gemm_ev; + } + else { + using ReductionOpT = sycl::plus; + constexpr resTy identity_val = + sycl::known_identity::value; + + size_t iter_nelems = n * m; + size_t reduction_nelems = (k + wi_delta_k - 1) / wi_delta_k; + + // more than one work-groups is needed, requires a temporary + // wi_delta_k elements processed along k, so if more to process + // use multiple + const auto &sg_sizes = + dev.get_info(); + size_t wg = + choose_workgroup_size<4>(reduction_nelems, sg_sizes); + + if (reduction_nelems < wg) { + resTy *tmp = sycl::malloc_device( + iter_nelems * reduction_nelems, exec_q); + sycl::event gemm_ev = exec_q.submit([&](sycl::handler + &cgh) { + cgh.depends_on(depends); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::StridedIndexer; + OuterInnerDimsIndexerT lhs_indexer( + inner_nd + lhs_outer_nd, 0, + lhs_outer_inner_shapes_strides); + OuterInnerDimsIndexerT rhs_indexer( + inner_nd + rhs_outer_nd, 0, + rhs_outer_inner_shapes_strides); + using ResIndexerT = + 
dpctl::tensor::offset_utils::NoOpIndexer; + ResIndexerT res_indexer{}; + + size_t lws = wg_delta_n * wg_delta_m; + + size_t n_blocks = ((n + wi_delta_n * wg_delta_n - 1) / + (wi_delta_n * wg_delta_n)); + size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); + size_t m_blocks = ((m + wi_delta_m * wg_delta_m - 1) / + (wi_delta_m * wg_delta_m)); + + auto gwsRange = sycl::range<1>(n_blocks * m_blocks * + k_blocks * lws); + auto lwsRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); + + using LocAccT1 = sycl::local_accessor; + LocAccT1 local_A_block( + sycl::range<1>((wi_delta_n * wg_delta_n) * + wi_delta_k), + cgh); + using LocAccT2 = + sycl::local_accessor, + 1>; + LocAccT2 local_B_block( + sycl::range<1>(wi_delta_k * wg_delta_m), cgh); + + using KernelName = + class gemm_tree_nm_krn; + cgh.parallel_for( + ndRange, + GemmNoAtomicFunctorThreadNM< + lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, + OuterInnerDimsIndexerT, ResIndexerT, wi_delta_n, + wi_delta_m>(lhs_tp, rhs_tp, tmp, local_A_block, + local_B_block, n, wg_delta_n, k, + k_blocks, wi_delta_k, m, m_blocks, + wg_delta_m, lhs_indexer, + rhs_indexer, res_indexer)); + }); + sycl::event red_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(gemm_ev); + + using NoOpIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + using ResIndexerT = + dpctl::tensor::offset_utils::StridedIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils:: + TwoOffsets_CombinedIndexer; + using ReductionIndexerT = + dpctl::tensor::offset_utils::Strided1DIndexer; + + ResIndexerT res_iter_indexer{res_nd, 0, + res_shapes_strides}; + InputOutputIterIndexerT in_out_iter_indexer{ + NoOpIndexerT{}, res_iter_indexer}; + ReductionIndexerT reduction_indexer{ + 0, static_cast(reduction_nelems), + static_cast(iter_nelems)}; + + sycl::range<1> iter_range{iter_nelems}; + + cgh.parallel_for>( + iter_range, + SequentialReduction( + tmp, res_tp, ReductionOpT(), identity_val, + in_out_iter_indexer, reduction_indexer, + reduction_nelems)); + }); + sycl::event cleanup_host_task_event = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(red_ev); + const sycl::context &ctx = exec_q.get_context(); + + cgh.host_task([ctx, tmp] { sycl::free(tmp, ctx); }); + }); + return cleanup_host_task_event; + } + + // max_max_wg prevents running out of resources on CPU + constexpr size_t max_max_wg = 2048; + size_t max_wg = std::min( + max_max_wg, + dev.get_info() / + 2); + + constexpr size_t preferred_reductions_per_wi = 8; + size_t reductions_per_wi(preferred_reductions_per_wi); + + size_t reduction_groups = + (reduction_nelems + preferred_reductions_per_wi * wg - 1) / + (preferred_reductions_per_wi * wg); + assert(reduction_groups > 1); + + resTy *partially_reduced_tmp = sycl::malloc_device( + iter_nelems * (/* temp */ reduction_nelems + + /* first reduction temp */ reduction_groups), + exec_q); + resTy *partially_reduced_tmp2 = nullptr; + + if (partially_reduced_tmp == nullptr) { + throw std::runtime_error( + "Unable to allocate device_memory"); + } + else { + partially_reduced_tmp2 = + partially_reduced_tmp + reduction_nelems * iter_nelems; + } + + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::StridedIndexer; + OuterInnerDimsIndexerT lhs_indexer( + inner_nd + lhs_outer_nd, 0, + lhs_outer_inner_shapes_strides); + OuterInnerDimsIndexerT rhs_indexer( + inner_nd + rhs_outer_nd, 0, + rhs_outer_inner_shapes_strides); + using 
ResIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + ResIndexerT res_indexer{}; + + size_t lws = wg_delta_n * wg_delta_m; + + size_t n_blocks = ((n + wi_delta_n * wg_delta_n - 1) / + (wi_delta_n * wg_delta_n)); + size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); + size_t m_blocks = ((m + wi_delta_m * wg_delta_m - 1) / + (wi_delta_m * wg_delta_m)); + + auto gwsRange = + sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); + auto lwsRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); + + using LocAccT1 = sycl::local_accessor; + LocAccT1 local_A_block( + sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), + cgh); + using LocAccT2 = + sycl::local_accessor, 1>; + LocAccT2 local_B_block( + sycl::range<1>(wi_delta_k * wg_delta_m), cgh); + + using KernelName = + class gemm_tree_nm_krn; + cgh.parallel_for( + ndRange, + GemmNoAtomicFunctorThreadNM< + lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, + OuterInnerDimsIndexerT, ResIndexerT, wi_delta_n, + wi_delta_m>(lhs_tp, rhs_tp, partially_reduced_tmp, + local_A_block, local_B_block, n, + wg_delta_n, k, k_blocks, wi_delta_k, m, + m_blocks, wg_delta_m, lhs_indexer, + rhs_indexer, res_indexer)); + }); + + sycl::event red_ev = + tree_reduction_for_gemm( + exec_q, partially_reduced_tmp, partially_reduced_tmp2, + res_tp, identity_val, iter_nelems, reduction_nelems, + reduction_groups, wg, max_wg, + preferred_reductions_per_wi, reductions_per_wi, res_nd, + 0, res_shapes_strides, {gemm_ev}); + + sycl::event cleanup_host_task_event = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(red_ev); + const sycl::context &ctx = exec_q.get_context(); + + cgh.host_task([ctx, partially_reduced_tmp] { + sycl::free(partially_reduced_tmp, ctx); + }); + }); + + return cleanup_host_task_event; + } + } + else { + constexpr int wi_delta_m = 1; + + gemm_detail::scale_gemm_nm_parameters( + local_mem_size, reserved_slm_size, wi_delta_n, + wi_delta_k, // modified by reference + wg_delta_n, // modified by reference + wg_delta_m // modified by reference + ); + + // each group processes delta_k items in a column, + // so no need to allocate temp memory if one group needed + if (k <= wi_delta_k) { + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::StridedIndexer; + OuterInnerDimsIndexerT lhs_indexer( + inner_nd + lhs_outer_nd, 0, + lhs_outer_inner_shapes_strides); + OuterInnerDimsIndexerT rhs_indexer( + inner_nd + rhs_outer_nd, 0, + rhs_outer_inner_shapes_strides); + OuterInnerDimsIndexerT res_indexer(res_nd, 0, + res_shapes_strides); + + size_t lws = wg_delta_n * wg_delta_m; + + size_t n_blocks = ((n + wi_delta_n * wg_delta_n - 1) / + (wi_delta_n * wg_delta_n)); + size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); + size_t m_blocks = ((m + wi_delta_m * wg_delta_m - 1) / + (wi_delta_m * wg_delta_m)); + + auto gwsRange = + sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); + auto lwsRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); + + using LocAccT1 = sycl::local_accessor; + LocAccT1 local_A_block( + sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), + cgh); + using LocAccT2 = sycl::local_accessor; + LocAccT2 local_B_block(wi_delta_k * wg_delta_m, cgh); + + using KernelName = class gemm_tree_nm_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, wi_delta_m>; + cgh.parallel_for( + ndRange, + GemmNoAtomicFunctorThreadNM< + lhsTy, rhsTy, resTy, LocAccT1, 
LocAccT2, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, + wi_delta_n, wi_delta_m>( + lhs_tp, rhs_tp, res_tp, local_A_block, + local_B_block, n, wg_delta_n, k, k_blocks, + wi_delta_k, m, m_blocks, wg_delta_m, lhs_indexer, + rhs_indexer, res_indexer)); + }); + return gemm_ev; + } + else { + using ReductionOpT = sycl::plus; + constexpr resTy identity_val = + sycl::known_identity::value; + + size_t iter_nelems = n * m; + size_t reduction_nelems = (k + wi_delta_k - 1) / wi_delta_k; + + // more than one work-groups is needed, requires a temporary + // wi_delta_k elements processed along k, so if more to process + // use multiple + const auto &sg_sizes = + dev.get_info(); + size_t wg = + choose_workgroup_size<4>(reduction_nelems, sg_sizes); + + if (reduction_nelems < wg) { + resTy *tmp = sycl::malloc_device( + iter_nelems * reduction_nelems, exec_q); + sycl::event gemm_ev = exec_q.submit([&](sycl::handler + &cgh) { + cgh.depends_on(depends); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::StridedIndexer; + OuterInnerDimsIndexerT lhs_indexer( + inner_nd + lhs_outer_nd, 0, + lhs_outer_inner_shapes_strides); + OuterInnerDimsIndexerT rhs_indexer( + inner_nd + rhs_outer_nd, 0, + rhs_outer_inner_shapes_strides); + using ResIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + ResIndexerT res_indexer{}; + + size_t lws = wg_delta_n * wg_delta_m; + + size_t n_blocks = ((n + wi_delta_n * wg_delta_n - 1) / + (wi_delta_n * wg_delta_n)); + size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); + size_t m_blocks = ((m + wi_delta_m * wg_delta_m - 1) / + (wi_delta_m * wg_delta_m)); + + auto gwsRange = sycl::range<1>(n_blocks * m_blocks * + k_blocks * lws); + auto lwsRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); + + using LocAccT1 = sycl::local_accessor; + LocAccT1 local_A_block( + sycl::range<1>((wi_delta_n * wg_delta_n) * + wi_delta_k), + cgh); + using LocAccT2 = sycl::local_accessor; + LocAccT2 local_B_block(wi_delta_k * wg_delta_m, cgh); + + using KernelName = + class gemm_tree_nm_krn; + cgh.parallel_for( + ndRange, + GemmNoAtomicFunctorThreadNM< + lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, + OuterInnerDimsIndexerT, ResIndexerT, wi_delta_n, + wi_delta_m>(lhs_tp, rhs_tp, tmp, local_A_block, + local_B_block, n, wg_delta_n, k, + k_blocks, wi_delta_k, m, m_blocks, + wg_delta_m, lhs_indexer, + rhs_indexer, res_indexer)); + }); + sycl::event red_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(gemm_ev); + + using NoOpIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + using ResIndexerT = + dpctl::tensor::offset_utils::StridedIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils:: + TwoOffsets_CombinedIndexer; + using ReductionIndexerT = + dpctl::tensor::offset_utils::Strided1DIndexer; + + ResIndexerT res_iter_indexer{res_nd, 0, + res_shapes_strides}; + InputOutputIterIndexerT in_out_iter_indexer{ + NoOpIndexerT{}, res_iter_indexer}; + ReductionIndexerT reduction_indexer{ + 0, static_cast(reduction_nelems), + static_cast(iter_nelems)}; + + sycl::range<1> iter_range{iter_nelems}; + + cgh.parallel_for>( + iter_range, + SequentialReduction( + tmp, res_tp, ReductionOpT(), identity_val, + in_out_iter_indexer, reduction_indexer, + reduction_nelems)); + }); + sycl::event cleanup_host_task_event = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(red_ev); + const sycl::context &ctx = exec_q.get_context(); + + cgh.host_task([ctx, tmp] { sycl::free(tmp, ctx); }); + }); + return cleanup_host_task_event; + } + + // 
max_max_wg prevents running out of resources on CPU + constexpr size_t max_max_wg = 2048; + size_t max_wg = std::min( + max_max_wg, + dev.get_info() / + 2); + + constexpr size_t preferred_reductions_per_wi = 8; + size_t reductions_per_wi(preferred_reductions_per_wi); + + size_t reduction_groups = + (reduction_nelems + preferred_reductions_per_wi * wg - 1) / + (preferred_reductions_per_wi * wg); + assert(reduction_groups > 1); + + resTy *partially_reduced_tmp = sycl::malloc_device( + iter_nelems * (/* temp */ reduction_nelems + + /* first reduction temp */ reduction_groups), + exec_q); + resTy *partially_reduced_tmp2 = nullptr; + + if (partially_reduced_tmp == nullptr) { + throw std::runtime_error( + "Unable to allocate device_memory"); + } + else { + partially_reduced_tmp2 = + partially_reduced_tmp + reduction_nelems * iter_nelems; + } + + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::StridedIndexer; + OuterInnerDimsIndexerT lhs_indexer( + inner_nd + lhs_outer_nd, 0, + lhs_outer_inner_shapes_strides); + OuterInnerDimsIndexerT rhs_indexer( + inner_nd + rhs_outer_nd, 0, + rhs_outer_inner_shapes_strides); + using ResIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + ResIndexerT res_indexer{}; + + size_t lws = wg_delta_n * wg_delta_m; + + size_t n_blocks = ((n + wi_delta_n * wg_delta_n - 1) / + (wi_delta_n * wg_delta_n)); + size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); + size_t m_blocks = ((m + wi_delta_m * wg_delta_m - 1) / + (wi_delta_m * wg_delta_m)); + + auto gwsRange = + sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); + auto lwsRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); + + using LocAccT1 = sycl::local_accessor; + LocAccT1 local_A_block( + sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), + cgh); + using LocAccT2 = sycl::local_accessor; + LocAccT2 local_B_block(wi_delta_k * wg_delta_m, cgh); + + using KernelName = + class gemm_tree_nm_krn; + cgh.parallel_for( + ndRange, + GemmNoAtomicFunctorThreadNM< + lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, + OuterInnerDimsIndexerT, ResIndexerT, wi_delta_n, + wi_delta_m>(lhs_tp, rhs_tp, partially_reduced_tmp, + local_A_block, local_B_block, n, + wg_delta_n, k, k_blocks, wi_delta_k, m, + m_blocks, wg_delta_m, lhs_indexer, + rhs_indexer, res_indexer)); + }); + + sycl::event red_ev = + tree_reduction_for_gemm( + exec_q, partially_reduced_tmp, partially_reduced_tmp2, + res_tp, identity_val, iter_nelems, reduction_nelems, + reduction_groups, wg, max_wg, + preferred_reductions_per_wi, reductions_per_wi, res_nd, + 0, res_shapes_strides, {gemm_ev}); + + sycl::event cleanup_host_task_event = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(red_ev); + const sycl::context &ctx = exec_q.get_context(); + + cgh.host_task([ctx, partially_reduced_tmp] { + sycl::free(partially_reduced_tmp, ctx); + }); + }); + + return cleanup_host_task_event; + } + } + } +} + +template +sycl::event gemm_contig_tree_impl(sycl::queue &exec_q, + const char *lhs_cp, + const char *rhs_cp, + char *res_cp, + size_t n, + size_t k, + size_t m, + std::vector const &depends = {}) +{ + const lhsTy *lhs_tp = reinterpret_cast(lhs_cp); + const rhsTy *rhs_tp = reinterpret_cast(rhs_cp); + resTy *res_tp = reinterpret_cast(res_cp); + + const sycl::device &dev = exec_q.get_device(); + const size_t local_mem_size = + dev.get_info(); + const size_t reserved_slm_size = 512; + + if ((k > n && k > m) || m == 1) { + // each 
group processes delta_k * n_wi + // items in a column, so no need for allocating + // temp memory if only one group is needed + size_t delta_k(4); + size_t n_wi(64); + size_t delta_n(32); + + using dpctl::tensor::type_utils::is_complex; + if constexpr (!is_complex::value) { + if (m == 1) { + constexpr int m_groups = 1; + + gemm_detail::scale_gemm_k_parameters( + local_mem_size, reserved_slm_size, + delta_k, // modified by reference + n_wi, // modified by reference + delta_n // modified by reference + ); + + sycl::event gemm_ev; + if (k <= (delta_k * n_wi)) { + gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + OuterInnerDimsIndexerT lhs_indexer{}; + OuterInnerDimsIndexerT rhs_indexer{}; + OuterInnerDimsIndexerT res_indexer{}; + + size_t n_blocks = (n + delta_n - 1) / delta_n; + size_t k_blocks = + (k + n_wi * delta_k - 1) / (n_wi * delta_k); + size_t m_blocks = (m + m_groups - 1) / m_groups; + + size_t lws = delta_n * delta_k; + + auto gRange = sycl::range<1>(n_blocks * m_blocks * + k_blocks * lws); + auto lRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gRange, lRange); + + using LocAccT = sycl::local_accessor; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + using KernelName = class gemm_tree_k_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, m_groups>; + cgh.parallel_for( + ndRange, + GemmNoAtomicFunctorThreadK< + lhsTy, rhsTy, resTy, LocAccT, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, + m_groups>(lhs_tp, rhs_tp, res_tp, workspace, + local_B_block, n, n_blocks, delta_n, + k, k_blocks, delta_k, n_wi, m, + lhs_indexer, rhs_indexer, + res_indexer)); + }); + return gemm_ev; + } + else { + using ReductionOpT = sycl::plus; + constexpr resTy identity_val = + sycl::known_identity::value; + + size_t iter_nelems = n * m; + size_t reduction_nelems = + (k + delta_k * n_wi - 1) / (delta_k * n_wi); + + // more than one work-groups is needed, requires a temporary + // delta_k * n_wi elements processed along k, so if more to + // process use multiple + const auto &sg_sizes = + dev.get_info(); + size_t wg = + choose_workgroup_size<4>(reduction_nelems, sg_sizes); + + if (reduction_nelems < wg) { + resTy *tmp = sycl::malloc_device( + iter_nelems * reduction_nelems, exec_q); + sycl::event gemm_ev = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + OuterInnerDimsIndexerT lhs_indexer{}; + OuterInnerDimsIndexerT rhs_indexer{}; + OuterInnerDimsIndexerT res_indexer{}; + + size_t n_blocks = (n + delta_n - 1) / delta_n; + size_t k_blocks = + (k + n_wi * delta_k - 1) / (n_wi * delta_k); + size_t m_blocks = (m + m_groups - 1) / m_groups; + + size_t lws = delta_n * delta_k; + + auto gRange = sycl::range<1>( + n_blocks * m_blocks * k_blocks * lws); + auto lRange = sycl::range<1>(lws); + + auto ndRange = + sycl::nd_range<1>(gRange, lRange); + + using LocAccT = sycl::local_accessor; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + using KernelName = class gemm_tree_k_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, m_groups>; + cgh.parallel_for( + ndRange, + GemmNoAtomicFunctorThreadK< + lhsTy, rhsTy, resTy, LocAccT, + OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, m_groups>( + lhs_tp, rhs_tp, tmp, workspace, + local_B_block, n, n_blocks, 
delta_n, k, + k_blocks, delta_k, n_wi, m, lhs_indexer, + rhs_indexer, res_indexer)); + }); + sycl::event red_ev = exec_q.submit([&](sycl::handler + &cgh) { + cgh.depends_on(gemm_ev); + + using NoOpIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils:: + TwoOffsets_CombinedIndexer; + using ReductionIndexerT = + dpctl::tensor::offset_utils::Strided1DIndexer; + + InputOutputIterIndexerT in_out_iter_indexer{ + NoOpIndexerT{}, NoOpIndexerT{}}; + ReductionIndexerT reduction_indexer{ + 0, static_cast(reduction_nelems), + static_cast(iter_nelems)}; + + sycl::range<1> iter_range{iter_nelems}; + + cgh.parallel_for< + class gemm_reduction_seq_strided_krn< + resTy, resTy, ReductionOpT, + InputOutputIterIndexerT, + ReductionIndexerT>>( + iter_range, + SequentialReduction( + tmp, res_tp, ReductionOpT(), identity_val, + in_out_iter_indexer, reduction_indexer, + reduction_nelems)); + }); + sycl::event cleanup_host_task_event = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(red_ev); + const sycl::context &ctx = exec_q.get_context(); + + cgh.host_task( + [ctx, tmp] { sycl::free(tmp, ctx); }); + }); + return cleanup_host_task_event; + } + // max_max_wg prevents running out of resources on CPU + constexpr size_t max_max_wg = 2048; + size_t max_wg = std::min( + max_max_wg, + dev.get_info< + sycl::info::device::max_work_group_size>() / + 2); + + constexpr size_t preferred_reductions_per_wi = 8; + size_t reductions_per_wi(preferred_reductions_per_wi); + + size_t reduction_groups = + (reduction_nelems + preferred_reductions_per_wi * wg - + 1) / + (preferred_reductions_per_wi * wg); + assert(reduction_groups > 1); + + resTy *partially_reduced_tmp = sycl::malloc_device( + iter_nelems * + (/* temp */ reduction_nelems + + /* first reduction temp */ reduction_groups), + exec_q); + resTy *partially_reduced_tmp2 = nullptr; + + if (partially_reduced_tmp == nullptr) { + throw std::runtime_error( + "Unable to allocate device_memory"); + } + else { + partially_reduced_tmp2 = partially_reduced_tmp + + reduction_nelems * iter_nelems; + } + + sycl::event gemm_ev = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + OuterInnerDimsIndexerT lhs_indexer{}; + OuterInnerDimsIndexerT rhs_indexer{}; + OuterInnerDimsIndexerT res_indexer{}; + + size_t n_blocks = (n + delta_n - 1) / delta_n; + size_t k_blocks = + (k + n_wi * delta_k - 1) / (n_wi * delta_k); + size_t m_blocks = (m + m_groups - 1) / m_groups; + + size_t lws = delta_n * delta_k; + + auto gRange = sycl::range<1>(n_blocks * m_blocks * + k_blocks * lws); + auto lRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gRange, lRange); + + using LocAccT = sycl::local_accessor; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + using KernelName = class gemm_tree_k_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, m_groups>; + cgh.parallel_for( + ndRange, + GemmNoAtomicFunctorThreadK< + lhsTy, rhsTy, resTy, LocAccT, + OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, m_groups>( + lhs_tp, rhs_tp, partially_reduced_tmp, + workspace, local_B_block, n, n_blocks, + delta_n, k, k_blocks, delta_k, n_wi, m, + lhs_indexer, rhs_indexer, res_indexer)); + }); + // tree_reduction_for_gemm_contig returns sycl::event for + // reduction + sycl::event red_ev = + tree_reduction_for_gemm_contig( + exec_q, partially_reduced_tmp, + 
partially_reduced_tmp2, res_tp, identity_val, + iter_nelems, reduction_nelems, reduction_groups, wg, + max_wg, preferred_reductions_per_wi, + reductions_per_wi, {gemm_ev}); + + sycl::event cleanup_host_task_event = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(red_ev); + const sycl::context &ctx = exec_q.get_context(); + + cgh.host_task([ctx, partially_reduced_tmp] { + sycl::free(partially_reduced_tmp, ctx); + }); + }); + + return cleanup_host_task_event; + } + } + else { + constexpr int m_groups = 2; + gemm_detail::scale_gemm_k_parameters( + local_mem_size, reserved_slm_size, + delta_k, // modified by reference + n_wi, // modified by reference + delta_n // modified by reference + ); + + sycl::event gemm_ev; + if (k <= (delta_k * n_wi)) { + gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + OuterInnerDimsIndexerT lhs_indexer{}; + OuterInnerDimsIndexerT rhs_indexer{}; + OuterInnerDimsIndexerT res_indexer{}; + + size_t n_blocks = (n + delta_n - 1) / delta_n; + size_t k_blocks = + (k + n_wi * delta_k - 1) / (n_wi * delta_k); + size_t m_blocks = (m + m_groups - 1) / m_groups; + + size_t lws = delta_n * delta_k; + + auto gRange = sycl::range<1>(n_blocks * m_blocks * + k_blocks * lws); + auto lRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gRange, lRange); + + using LocAccT = + sycl::local_accessor, 1>; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + using KernelName = class gemm_tree_k_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, m_groups>; + cgh.parallel_for( + ndRange, + GemmNoAtomicFunctorThreadK< + lhsTy, rhsTy, resTy, LocAccT, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, + m_groups>(lhs_tp, rhs_tp, res_tp, workspace, + local_B_block, n, n_blocks, delta_n, + k, k_blocks, delta_k, n_wi, m, + lhs_indexer, rhs_indexer, + res_indexer)); + }); + return gemm_ev; + } + else { + using ReductionOpT = sycl::plus; + constexpr resTy identity_val = + sycl::known_identity::value; + + size_t iter_nelems = n * m; + size_t reduction_nelems = + (k + delta_k * n_wi - 1) / (delta_k * n_wi); + + // more than one work-groups is needed, requires a temporary + // delta_k * n_wi elements processed along k, so if more to + // process use multiple + const auto &sg_sizes = + dev.get_info(); + size_t wg = + choose_workgroup_size<4>(reduction_nelems, sg_sizes); + + if (reduction_nelems < wg) { + resTy *tmp = sycl::malloc_device( + iter_nelems * reduction_nelems, exec_q); + sycl::event gemm_ev = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + OuterInnerDimsIndexerT lhs_indexer{}; + OuterInnerDimsIndexerT rhs_indexer{}; + OuterInnerDimsIndexerT res_indexer{}; + + size_t n_blocks = (n + delta_n - 1) / delta_n; + size_t k_blocks = + (k + n_wi * delta_k - 1) / (n_wi * delta_k); + size_t m_blocks = (m + m_groups - 1) / m_groups; + + size_t lws = delta_n * delta_k; + + auto gRange = sycl::range<1>( + n_blocks * m_blocks * k_blocks * lws); + auto lRange = sycl::range<1>(lws); + + auto ndRange = + sycl::nd_range<1>(gRange, lRange); + + using LocAccT = sycl::local_accessor< + sycl::vec, 1>; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + using KernelName = class gemm_tree_k_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, m_groups>; + 
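// Launch-geometry sketch for the parallel_for below (illustrative; the
// numbers use the pre-scaling defaults delta_n = 32, delta_k = 4,
// n_wi = 64 from above, which scale_gemm_k_parameters may shrink to fit
// local memory):
//
//     lws      = delta_n * delta_k             // 128 work-items per group
//     k_chunk  = delta_k * n_wi                // 256 elements of k per group
//     k_blocks = (k + k_chunk - 1) / k_chunk   // e.g. k = 1024 -> 4 blocks
//     gws      = n_blocks * m_blocks * k_blocks * lws
//
// Each group caches an n_wi * delta_k tile of the rhs in local_B_block,
// accumulates through `workspace`, and, because k > delta_k * n_wi in
// this branch, writes its k-block's partial sums into the `tmp` scratch
// buffer (m_groups = 2 adjacent result columns per work-item via
// sycl::vec); the reduction kernel submitted next combines the k_blocks
// partial slices into the contiguous result.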
cgh.parallel_for( + ndRange, + GemmNoAtomicFunctorThreadK< + lhsTy, rhsTy, resTy, LocAccT, + OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, m_groups>( + lhs_tp, rhs_tp, tmp, workspace, + local_B_block, n, n_blocks, delta_n, k, + k_blocks, delta_k, n_wi, m, lhs_indexer, + rhs_indexer, res_indexer)); + }); + sycl::event red_ev = exec_q.submit([&](sycl::handler + &cgh) { + cgh.depends_on(gemm_ev); + + using NoOpIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils:: + TwoOffsets_CombinedIndexer; + using ReductionIndexerT = + dpctl::tensor::offset_utils::Strided1DIndexer; + + InputOutputIterIndexerT in_out_iter_indexer{ + NoOpIndexerT{}, NoOpIndexerT{}}; + ReductionIndexerT reduction_indexer{ + 0, static_cast(reduction_nelems), + static_cast(iter_nelems)}; + + sycl::range<1> iter_range{iter_nelems}; + + cgh.parallel_for< + class gemm_reduction_seq_strided_krn< + resTy, resTy, ReductionOpT, + InputOutputIterIndexerT, + ReductionIndexerT>>( + iter_range, + SequentialReduction( + tmp, res_tp, ReductionOpT(), identity_val, + in_out_iter_indexer, reduction_indexer, + reduction_nelems)); + }); + sycl::event cleanup_host_task_event = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(red_ev); + const sycl::context &ctx = exec_q.get_context(); + + cgh.host_task( + [ctx, tmp] { sycl::free(tmp, ctx); }); + }); + return cleanup_host_task_event; + } + // max_max_wg prevents running out of resources on CPU + constexpr size_t max_max_wg = 2048; + size_t max_wg = std::min( + max_max_wg, + dev.get_info< + sycl::info::device::max_work_group_size>() / + 2); + + constexpr size_t preferred_reductions_per_wi = 8; + size_t reductions_per_wi(preferred_reductions_per_wi); + + size_t reduction_groups = + (reduction_nelems + preferred_reductions_per_wi * wg - + 1) / + (preferred_reductions_per_wi * wg); + assert(reduction_groups > 1); + + resTy *partially_reduced_tmp = sycl::malloc_device( + iter_nelems * + (/* temp */ reduction_nelems + + /* first reduction temp */ reduction_groups), + exec_q); + resTy *partially_reduced_tmp2 = nullptr; + + if (partially_reduced_tmp == nullptr) { + throw std::runtime_error( + "Unable to allocate device_memory"); + } + else { + partially_reduced_tmp2 = partially_reduced_tmp + + reduction_nelems * iter_nelems; + } + + sycl::event gemm_ev = exec_q.submit([&](sycl::handler + &cgh) { + cgh.depends_on(depends); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + OuterInnerDimsIndexerT lhs_indexer{}; + OuterInnerDimsIndexerT rhs_indexer{}; + OuterInnerDimsIndexerT res_indexer{}; + + size_t n_blocks = (n + delta_n - 1) / delta_n; + size_t k_blocks = + (k + n_wi * delta_k - 1) / (n_wi * delta_k); + size_t m_blocks = (m + m_groups - 1) / m_groups; + + size_t lws = delta_n * delta_k; + + auto gRange = sycl::range<1>(n_blocks * m_blocks * + k_blocks * lws); + auto lRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gRange, lRange); + + using LocAccT = + sycl::local_accessor, 1>; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + using KernelName = class gemm_tree_k_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, m_groups>; + cgh.parallel_for( + ndRange, + GemmNoAtomicFunctorThreadK< + lhsTy, rhsTy, resTy, LocAccT, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, + m_groups>(lhs_tp, rhs_tp, partially_reduced_tmp, + workspace, local_B_block, n, n_blocks, + delta_n, k, k_blocks, delta_k, n_wi, + m, 
lhs_indexer, rhs_indexer, + res_indexer)); + }); + + sycl::event red_ev = + tree_reduction_for_gemm_contig( + exec_q, partially_reduced_tmp, + partially_reduced_tmp2, res_tp, identity_val, + iter_nelems, reduction_nelems, reduction_groups, wg, + max_wg, preferred_reductions_per_wi, + reductions_per_wi, {gemm_ev}); + + sycl::event cleanup_host_task_event = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(red_ev); + const sycl::context &ctx = exec_q.get_context(); + + cgh.host_task([ctx, partially_reduced_tmp] { + sycl::free(partially_reduced_tmp, ctx); + }); + }); + + return cleanup_host_task_event; + } + } + } + else { + constexpr int m_groups = 1; + + gemm_detail::scale_gemm_k_parameters( + local_mem_size, reserved_slm_size, + delta_k, // modified by reference + n_wi, // modified by reference + delta_n // modified by reference + ); + + sycl::event gemm_ev; + if (k <= (delta_k * n_wi)) { + gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + OuterInnerDimsIndexerT lhs_indexer{}; + OuterInnerDimsIndexerT rhs_indexer{}; + OuterInnerDimsIndexerT res_indexer{}; + + size_t n_blocks = (n + delta_n - 1) / delta_n; + size_t k_blocks = + (k + n_wi * delta_k - 1) / (n_wi * delta_k); + size_t m_blocks = (m + m_groups - 1) / m_groups; + + size_t lws = delta_n * delta_k; + + auto gRange = + sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); + auto lRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gRange, lRange); + + using LocAccT = sycl::local_accessor; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + using KernelName = + class gemm_tree_k_krn; + cgh.parallel_for( + ndRange, + GemmNoAtomicFunctorThreadK( + lhs_tp, rhs_tp, res_tp, workspace, local_B_block, n, + n_blocks, delta_n, k, k_blocks, delta_k, n_wi, m, + lhs_indexer, rhs_indexer, res_indexer)); + }); + return gemm_ev; + } + else { + using ReductionOpT = sycl::plus; + constexpr resTy identity_val = + sycl::known_identity::value; + + size_t iter_nelems = n * m; + size_t reduction_nelems = + (k + delta_k * n_wi - 1) / (delta_k * n_wi); + + // more than one work-groups is needed, requires a temporary + // delta_k * n_wi elements processed along k, so if more to + // process use multiple + const auto &sg_sizes = + dev.get_info(); + size_t wg = + choose_workgroup_size<4>(reduction_nelems, sg_sizes); + + if (reduction_nelems < wg) { + resTy *tmp = sycl::malloc_device( + iter_nelems * reduction_nelems, exec_q); + sycl::event gemm_ev = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + OuterInnerDimsIndexerT lhs_indexer{}; + OuterInnerDimsIndexerT rhs_indexer{}; + OuterInnerDimsIndexerT res_indexer{}; + + size_t n_blocks = (n + delta_n - 1) / delta_n; + size_t k_blocks = + (k + n_wi * delta_k - 1) / (n_wi * delta_k); + size_t m_blocks = (m + m_groups - 1) / m_groups; + + size_t lws = delta_n * delta_k; + + auto gRange = sycl::range<1>(n_blocks * m_blocks * + k_blocks * lws); + auto lRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gRange, lRange); + + using LocAccT = sycl::local_accessor; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + using KernelName = class gemm_tree_k_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, m_groups>; + cgh.parallel_for( + ndRange, + GemmNoAtomicFunctorThreadK< 
+ lhsTy, rhsTy, resTy, LocAccT, + OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, m_groups>( + lhs_tp, rhs_tp, tmp, workspace, + local_B_block, n, n_blocks, delta_n, k, + k_blocks, delta_k, n_wi, m, lhs_indexer, + rhs_indexer, res_indexer)); + }); + + sycl::event red_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(gemm_ev); + + using NoOpIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils:: + TwoOffsets_CombinedIndexer; + using ReductionIndexerT = + dpctl::tensor::offset_utils::Strided1DIndexer; + + InputOutputIterIndexerT in_out_iter_indexer{ + NoOpIndexerT{}, NoOpIndexerT{}}; + ReductionIndexerT reduction_indexer{ + 0, static_cast(reduction_nelems), + static_cast(iter_nelems)}; + + sycl::range<1> iter_range{iter_nelems}; + + cgh.parallel_for>( + iter_range, + SequentialReduction( + tmp, res_tp, ReductionOpT(), identity_val, + in_out_iter_indexer, reduction_indexer, + reduction_nelems)); + }); + sycl::event cleanup_host_task_event = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(red_ev); + const sycl::context &ctx = exec_q.get_context(); + + cgh.host_task([ctx, tmp] { sycl::free(tmp, ctx); }); + }); + return cleanup_host_task_event; + } + // max_max_wg prevents running out of resources on CPU + constexpr size_t max_max_wg = 2048; + size_t max_wg = std::min( + max_max_wg, + dev.get_info() / + 2); + + constexpr size_t preferred_reductions_per_wi = 8; + size_t reductions_per_wi(preferred_reductions_per_wi); + + size_t reduction_groups = + (reduction_nelems + preferred_reductions_per_wi * wg - 1) / + (preferred_reductions_per_wi * wg); + assert(reduction_groups > 1); + + resTy *partially_reduced_tmp = sycl::malloc_device( + iter_nelems * (/* temp */ reduction_nelems + + /* first reduction temp */ reduction_groups), + exec_q); + resTy *partially_reduced_tmp2 = nullptr; + + if (partially_reduced_tmp == nullptr) { + throw std::runtime_error( + "Unable to allocate device_memory"); + } + else { + partially_reduced_tmp2 = + partially_reduced_tmp + reduction_nelems * iter_nelems; + } + + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + OuterInnerDimsIndexerT lhs_indexer{}; + OuterInnerDimsIndexerT rhs_indexer{}; + OuterInnerDimsIndexerT res_indexer{}; + + size_t n_blocks = (n + delta_n - 1) / delta_n; + size_t k_blocks = + (k + n_wi * delta_k - 1) / (n_wi * delta_k); + size_t m_blocks = (m + m_groups - 1) / m_groups; + + size_t lws = delta_n * delta_k; + + auto gRange = + sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); + auto lRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gRange, lRange); + + using LocAccT = sycl::local_accessor; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + using KernelName = + class gemm_tree_k_krn; + cgh.parallel_for( + ndRange, + GemmNoAtomicFunctorThreadK( + lhs_tp, rhs_tp, partially_reduced_tmp, workspace, + local_B_block, n, n_blocks, delta_n, k, k_blocks, + delta_k, n_wi, m, lhs_indexer, rhs_indexer, + res_indexer)); + }); + // tree_reduction_for_gemm returns sycl::event for reduction + sycl::event red_ev = + tree_reduction_for_gemm_contig( + exec_q, partially_reduced_tmp, partially_reduced_tmp2, + res_tp, identity_val, iter_nelems, reduction_nelems, + reduction_groups, wg, max_wg, + preferred_reductions_per_wi, reductions_per_wi, + {gemm_ev}); + + sycl::event cleanup_host_task_event = + 
exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(red_ev); + const sycl::context &ctx = exec_q.get_context(); + + cgh.host_task([ctx, partially_reduced_tmp] { + sycl::free(partially_reduced_tmp, ctx); + }); + }); + + return cleanup_host_task_event; + } + } + } + else { // m > 1, n > k or m > k + constexpr int wi_delta_n = 2; + size_t wg_delta_n(16); // rows of A processed in WG + size_t wg_delta_m(16); // rows of B processed in WG + size_t wi_delta_k(64); // Elements in K dimension processed by WI + + using dpctl::tensor::type_utils::is_complex; + if constexpr (!is_complex::value) { + constexpr int wi_delta_m = 4; + + gemm_detail::scale_gemm_nm_parameters( + local_mem_size, reserved_slm_size, wi_delta_n, + wi_delta_k, // modified by reference + wg_delta_n, // modified by reference + wg_delta_m // modified by reference + ); + + // each group processes delta_k items in a column, + // so no need to allocate temp memory if one group needed + if (k <= wi_delta_k) { + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + OuterInnerDimsIndexerT lhs_indexer{}; + OuterInnerDimsIndexerT rhs_indexer{}; + OuterInnerDimsIndexerT res_indexer{}; + + size_t lws = wg_delta_n * wg_delta_m; + + size_t n_blocks = ((n + wi_delta_n * wg_delta_n - 1) / + (wi_delta_n * wg_delta_n)); + size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); + size_t m_blocks = ((m + wi_delta_m * wg_delta_m - 1) / + (wi_delta_m * wg_delta_m)); + + auto gwsRange = + sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); + auto lwsRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); + + using LocAccT1 = sycl::local_accessor; + LocAccT1 local_A_block( + sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), + cgh); + using LocAccT2 = + sycl::local_accessor, 1>; + LocAccT2 local_B_block( + sycl::range<1>(wi_delta_k * wg_delta_m), cgh); + + using KernelName = class gemm_tree_nm_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, wi_delta_m>; + cgh.parallel_for( + ndRange, + GemmNoAtomicFunctorThreadNM< + lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, + wi_delta_n, wi_delta_m>( + lhs_tp, rhs_tp, res_tp, local_A_block, + local_B_block, n, wg_delta_n, k, k_blocks, + wi_delta_k, m, m_blocks, wg_delta_m, lhs_indexer, + rhs_indexer, res_indexer)); + }); + return gemm_ev; + } + else { + using ReductionOpT = sycl::plus; + constexpr resTy identity_val = + sycl::known_identity::value; + + size_t iter_nelems = n * m; + size_t reduction_nelems = (k + wi_delta_k - 1) / wi_delta_k; + + // more than one work-groups is needed, requires a temporary + // wi_delta_k elements processed along k, so if more to process + // use multiple + const auto &sg_sizes = + dev.get_info(); + size_t wg = + choose_workgroup_size<4>(reduction_nelems, sg_sizes); + + if (reduction_nelems < wg) { + resTy *tmp = sycl::malloc_device( + iter_nelems * reduction_nelems, exec_q); + sycl::event gemm_ev = exec_q.submit([&](sycl::handler + &cgh) { + cgh.depends_on(depends); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + OuterInnerDimsIndexerT lhs_indexer{}; + OuterInnerDimsIndexerT rhs_indexer{}; + OuterInnerDimsIndexerT res_indexer{}; + + size_t lws = wg_delta_n * wg_delta_m; + + size_t n_blocks = ((n + wi_delta_n * wg_delta_n - 1) / + (wi_delta_n * wg_delta_n)); + size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); + size_t 
m_blocks = ((m + wi_delta_m * wg_delta_m - 1) / + (wi_delta_m * wg_delta_m)); + + auto gwsRange = sycl::range<1>(n_blocks * m_blocks * + k_blocks * lws); + auto lwsRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); + + using LocAccT1 = sycl::local_accessor; + LocAccT1 local_A_block( + sycl::range<1>((wi_delta_n * wg_delta_n) * + wi_delta_k), + cgh); + using LocAccT2 = + sycl::local_accessor, + 1>; + LocAccT2 local_B_block( + sycl::range<1>(wi_delta_k * wg_delta_m), cgh); + + using KernelName = class gemm_tree_nm_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, wi_delta_m>; + cgh.parallel_for( + ndRange, + GemmNoAtomicFunctorThreadNM< + lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, + wi_delta_n, wi_delta_m>( + lhs_tp, rhs_tp, tmp, local_A_block, + local_B_block, n, wg_delta_n, k, k_blocks, + wi_delta_k, m, m_blocks, wg_delta_m, + lhs_indexer, rhs_indexer, res_indexer)); + }); + sycl::event red_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(gemm_ev); + + using NoOpIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils:: + TwoOffsets_CombinedIndexer; + using ReductionIndexerT = + dpctl::tensor::offset_utils::Strided1DIndexer; + + InputOutputIterIndexerT in_out_iter_indexer{ + NoOpIndexerT{}, NoOpIndexerT{}}; + ReductionIndexerT reduction_indexer{ + 0, static_cast(reduction_nelems), + static_cast(iter_nelems)}; + + sycl::range<1> iter_range{iter_nelems}; + + cgh.parallel_for>( + iter_range, + SequentialReduction( + tmp, res_tp, ReductionOpT(), identity_val, + in_out_iter_indexer, reduction_indexer, + reduction_nelems)); + }); + sycl::event cleanup_host_task_event = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(red_ev); + const sycl::context &ctx = exec_q.get_context(); + + cgh.host_task([ctx, tmp] { sycl::free(tmp, ctx); }); + }); + return cleanup_host_task_event; + } + + // max_max_wg prevents running out of resources on CPU + constexpr size_t max_max_wg = 2048; + size_t max_wg = std::min( + max_max_wg, + dev.get_info() / + 2); + + constexpr size_t preferred_reductions_per_wi = 8; + size_t reductions_per_wi(preferred_reductions_per_wi); + + size_t reduction_groups = + (reduction_nelems + preferred_reductions_per_wi * wg - 1) / + (preferred_reductions_per_wi * wg); + assert(reduction_groups > 1); + + resTy *partially_reduced_tmp = sycl::malloc_device( + iter_nelems * (/* temp */ reduction_nelems + + /* first reduction temp */ reduction_groups), + exec_q); + resTy *partially_reduced_tmp2 = nullptr; + + if (partially_reduced_tmp == nullptr) { + throw std::runtime_error( + "Unable to allocate device_memory"); + } + else { + partially_reduced_tmp2 = + partially_reduced_tmp + reduction_nelems * iter_nelems; + } + + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + OuterInnerDimsIndexerT lhs_indexer{}; + OuterInnerDimsIndexerT rhs_indexer{}; + OuterInnerDimsIndexerT res_indexer{}; + + size_t lws = wg_delta_n * wg_delta_m; + + size_t n_blocks = ((n + wi_delta_n * wg_delta_n - 1) / + (wi_delta_n * wg_delta_n)); + size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); + size_t m_blocks = ((m + wi_delta_m * wg_delta_m - 1) / + (wi_delta_m * wg_delta_m)); + + auto gwsRange = + sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); + auto lwsRange = sycl::range<1>(lws); + + auto ndRange = 
sycl::nd_range<1>(gwsRange, lwsRange); + + using LocAccT1 = sycl::local_accessor; + LocAccT1 local_A_block( + sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), + cgh); + using LocAccT2 = + sycl::local_accessor, 1>; + LocAccT2 local_B_block( + sycl::range<1>(wi_delta_k * wg_delta_m), cgh); + + using KernelName = class gemm_tree_nm_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, wi_delta_m>; + cgh.parallel_for( + ndRange, + GemmNoAtomicFunctorThreadNM< + lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, + wi_delta_n, wi_delta_m>( + lhs_tp, rhs_tp, partially_reduced_tmp, + local_A_block, local_B_block, n, wg_delta_n, k, + k_blocks, wi_delta_k, m, m_blocks, wg_delta_m, + lhs_indexer, rhs_indexer, res_indexer)); + }); + + sycl::event red_ev = + tree_reduction_for_gemm_contig( + exec_q, partially_reduced_tmp, partially_reduced_tmp2, + res_tp, identity_val, iter_nelems, reduction_nelems, + reduction_groups, wg, max_wg, + preferred_reductions_per_wi, reductions_per_wi, + {gemm_ev}); + + sycl::event cleanup_host_task_event = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(red_ev); + const sycl::context &ctx = exec_q.get_context(); + + cgh.host_task([ctx, partially_reduced_tmp] { + sycl::free(partially_reduced_tmp, ctx); + }); + }); + + return cleanup_host_task_event; + } + } + else { + constexpr int wi_delta_m = 1; + + gemm_detail::scale_gemm_nm_parameters( + local_mem_size, reserved_slm_size, wi_delta_n, + wi_delta_k, // modified by reference + wg_delta_n, // modified by reference + wg_delta_m // modified by reference + ); + + // each group processes delta_k items in a column, + // so no need to allocate temp memory if one group needed + if (k <= wi_delta_k) { + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + OuterInnerDimsIndexerT lhs_indexer{}; + OuterInnerDimsIndexerT rhs_indexer{}; + OuterInnerDimsIndexerT res_indexer{}; + + size_t lws = wg_delta_n * wg_delta_m; + + size_t n_blocks = ((n + wi_delta_n * wg_delta_n - 1) / + (wi_delta_n * wg_delta_n)); + size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); + size_t m_blocks = ((m + wi_delta_m * wg_delta_m - 1) / + (wi_delta_m * wg_delta_m)); + + auto gwsRange = + sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); + auto lwsRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); + + using LocAccT1 = sycl::local_accessor; + LocAccT1 local_A_block( + sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), + cgh); + using LocAccT2 = sycl::local_accessor; + LocAccT2 local_B_block(wi_delta_k * wg_delta_m, cgh); + + using KernelName = class gemm_tree_nm_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, wi_delta_m>; + cgh.parallel_for( + ndRange, + GemmNoAtomicFunctorThreadNM< + lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, + wi_delta_n, wi_delta_m>( + lhs_tp, rhs_tp, res_tp, local_A_block, + local_B_block, n, wg_delta_n, k, k_blocks, + wi_delta_k, m, m_blocks, wg_delta_m, lhs_indexer, + rhs_indexer, res_indexer)); + }); + return gemm_ev; + } + else { + using ReductionOpT = sycl::plus; + constexpr resTy identity_val = + sycl::known_identity::value; + + size_t iter_nelems = n * m; + size_t reduction_nelems = (k + wi_delta_k - 1) / wi_delta_k; + + // more than one work-groups is needed, requires a temporary + // wi_delta_k elements processed 
along k, so if more to process + // use multiple + const auto &sg_sizes = + dev.get_info(); + size_t wg = + choose_workgroup_size<4>(reduction_nelems, sg_sizes); + + if (reduction_nelems < wg) { + resTy *tmp = sycl::malloc_device( + iter_nelems * reduction_nelems, exec_q); + sycl::event gemm_ev = exec_q.submit([&](sycl::handler + &cgh) { + cgh.depends_on(depends); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + OuterInnerDimsIndexerT lhs_indexer{}; + OuterInnerDimsIndexerT rhs_indexer{}; + OuterInnerDimsIndexerT res_indexer{}; + + size_t lws = wg_delta_n * wg_delta_m; + + size_t n_blocks = ((n + wi_delta_n * wg_delta_n - 1) / + (wi_delta_n * wg_delta_n)); + size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); + size_t m_blocks = ((m + wi_delta_m * wg_delta_m - 1) / + (wi_delta_m * wg_delta_m)); + + auto gwsRange = sycl::range<1>(n_blocks * m_blocks * + k_blocks * lws); + auto lwsRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); + + using LocAccT1 = sycl::local_accessor; + LocAccT1 local_A_block( + sycl::range<1>((wi_delta_n * wg_delta_n) * + wi_delta_k), + cgh); + using LocAccT2 = sycl::local_accessor; + LocAccT2 local_B_block(wi_delta_k * wg_delta_m, cgh); + + using KernelName = class gemm_tree_nm_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, wi_delta_m>; + cgh.parallel_for( + ndRange, + GemmNoAtomicFunctorThreadNM< + lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, + wi_delta_n, wi_delta_m>( + lhs_tp, rhs_tp, tmp, local_A_block, + local_B_block, n, wg_delta_n, k, k_blocks, + wi_delta_k, m, m_blocks, wg_delta_m, + lhs_indexer, rhs_indexer, res_indexer)); + }); + sycl::event red_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(gemm_ev); + + using NoOpIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils:: + TwoOffsets_CombinedIndexer; + using ReductionIndexerT = + dpctl::tensor::offset_utils::Strided1DIndexer; + + InputOutputIterIndexerT in_out_iter_indexer{ + NoOpIndexerT{}, NoOpIndexerT{}}; + ReductionIndexerT reduction_indexer{ + 0, static_cast(reduction_nelems), + static_cast(iter_nelems)}; + + sycl::range<1> iter_range{iter_nelems}; + + cgh.parallel_for>( + iter_range, + SequentialReduction( + tmp, res_tp, ReductionOpT(), identity_val, + in_out_iter_indexer, reduction_indexer, + reduction_nelems)); + }); + sycl::event cleanup_host_task_event = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(red_ev); + const sycl::context &ctx = exec_q.get_context(); + + cgh.host_task([ctx, tmp] { sycl::free(tmp, ctx); }); + }); + return cleanup_host_task_event; + } + + // max_max_wg prevents running out of resources on CPU + constexpr size_t max_max_wg = 2048; + size_t max_wg = std::min( + max_max_wg, + dev.get_info() / + 2); + + constexpr size_t preferred_reductions_per_wi = 8; + size_t reductions_per_wi(preferred_reductions_per_wi); + + size_t reduction_groups = + (reduction_nelems + preferred_reductions_per_wi * wg - 1) / + (preferred_reductions_per_wi * wg); + assert(reduction_groups > 1); + + resTy *partially_reduced_tmp = sycl::malloc_device( + iter_nelems * (/* temp */ reduction_nelems + + /* first reduction temp */ reduction_groups), + exec_q); + resTy *partially_reduced_tmp2 = nullptr; + + if (partially_reduced_tmp == nullptr) { + throw std::runtime_error( + "Unable to allocate device_memory"); + } + else { + partially_reduced_tmp2 = + partially_reduced_tmp + 
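The single sycl::malloc_device call above is split into two regions: the first iter_nelems * reduction_nelems elements receive the per-k-slab partial products, and the following iter_nelems * reduction_groups elements hold the first level of the tree reduction; partially_reduced_tmp2 simply points past the first region. A sketch of that bookkeeping, with sizes that are only illustrative:

#include <cstddef>
#include <iostream>

int main()
{
    std::size_t iter_nelems = 4096;       // n * m
    std::size_t reduction_nelems = 1024;  // number of k-slabs
    std::size_t wg = 64;                  // work-group size picked for the reduction
    std::size_t preferred_reductions_per_wi = 8;

    std::size_t reduction_groups =
        (reduction_nelems + preferred_reductions_per_wi * wg - 1) /
        (preferred_reductions_per_wi * wg); // 2 here, satisfying the assert > 1

    std::size_t total = iter_nelems * (reduction_nelems + reduction_groups);
    std::size_t tmp2_offset = reduction_nelems * iter_nelems; // start of second region

    std::cout << "allocate " << total << " elements, second buffer starts at "
              << tmp2_offset << "\n";
}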
reduction_nelems * iter_nelems; + } + + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + OuterInnerDimsIndexerT lhs_indexer{}; + OuterInnerDimsIndexerT rhs_indexer{}; + OuterInnerDimsIndexerT res_indexer{}; + + size_t lws = wg_delta_n * wg_delta_m; + + size_t n_blocks = ((n + wi_delta_n * wg_delta_n - 1) / + (wi_delta_n * wg_delta_n)); + size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); + size_t m_blocks = ((m + wi_delta_m * wg_delta_m - 1) / + (wi_delta_m * wg_delta_m)); + + auto gwsRange = + sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); + auto lwsRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); + + using LocAccT1 = sycl::local_accessor; + LocAccT1 local_A_block( + sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), + cgh); + using LocAccT2 = sycl::local_accessor; + LocAccT2 local_B_block(wi_delta_k * wg_delta_m, cgh); + + using KernelName = class gemm_tree_nm_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, wi_delta_m>; + cgh.parallel_for( + ndRange, + GemmNoAtomicFunctorThreadNM< + lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, + wi_delta_n, wi_delta_m>( + lhs_tp, rhs_tp, partially_reduced_tmp, + local_A_block, local_B_block, n, wg_delta_n, k, + k_blocks, wi_delta_k, m, m_blocks, wg_delta_m, + lhs_indexer, rhs_indexer, res_indexer)); + }); + + sycl::event red_ev = + tree_reduction_for_gemm_contig( + exec_q, partially_reduced_tmp, partially_reduced_tmp2, + res_tp, identity_val, iter_nelems, reduction_nelems, + reduction_groups, wg, max_wg, + preferred_reductions_per_wi, reductions_per_wi, + {gemm_ev}); + + sycl::event cleanup_host_task_event = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(red_ev); + const sycl::context &ctx = exec_q.get_context(); + + cgh.host_task([ctx, partially_reduced_tmp] { + sycl::free(partially_reduced_tmp, ctx); + }); + }); + + return cleanup_host_task_event; + } + } + } +} + +template +class GemmBatchFunctorThreadNM +{ +private: + const lhsT *lhs = nullptr; + const rhsT *rhs = nullptr; + resT *res = nullptr; + LocAccT1 local_A_block; + LocAccT2 local_B_block; + size_t n = 0; + size_t wg_delta_n = 0; + size_t k = 0; + size_t k_blocks = 0; + size_t wi_delta_k = 0; + size_t m = 0; + size_t m_blocks = 0; + size_t wg_delta_m = 0; + size_t batch_nelems; + BatchDimsIndexerT batch_indexer; + OuterInnerDimsIndexerT lhs_indexer; + OuterInnerDimsIndexerT rhs_indexer; + OuterInnerDimsIndexerT res_indexer; + +public: + GemmBatchFunctorThreadNM(const lhsT *lhs_, + const rhsT *rhs_, + resT *res_, + LocAccT1 local_A_block_, + LocAccT2 local_B_block_, + size_t n_, + size_t wg_delta_n_, + size_t k_, + size_t k_blocks_, + size_t wi_delta_k_, + size_t m_, + size_t m_blocks_, + size_t wg_delta_m_, + size_t batch_nelems_, + BatchDimsIndexerT batch_indexer_, + OuterInnerDimsIndexerT lhs_indexer_, + OuterInnerDimsIndexerT rhs_indexer_, + OuterInnerDimsIndexerT res_indexer_) + : lhs(lhs_), rhs(rhs_), res(res_), local_A_block(local_A_block_), + local_B_block(local_B_block_), n(n_), wg_delta_n(wg_delta_n_), k(k_), + k_blocks(k_blocks_), wi_delta_k(wi_delta_k_), m(m_), + m_blocks(m_blocks_), wg_delta_m(wg_delta_m_), + batch_nelems(batch_nelems_), batch_indexer(batch_indexer_), + lhs_indexer(lhs_indexer_), rhs_indexer(rhs_indexer_), + res_indexer(res_indexer_) + { + } + + void operator()(sycl::nd_item<1> it) const + { + size_t m_id = 
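The batched functors recover the batch index and the within-batch group index from the flat SYCL ids, relying on the nd-range packing batch_nelems copies of the single-matrix grid back to back. A small stand-alone sketch of that arithmetic with made-up counts:

#include <cstddef>
#include <iostream>

int main()
{
    std::size_t batch_nelems = 3;
    std::size_t groups_per_batch = 40; // n_blocks * m_blocks * k_blocks
    std::size_t lws = 256;

    std::size_t global_range = batch_nelems * groups_per_batch * lws;
    std::size_t group_range = batch_nelems * groups_per_batch;

    // pick some work-item: group 95, local id 17
    std::size_t group_linear_id = 95;
    std::size_t global_linear_id = group_linear_id * lws + 17;

    std::size_t m_id = global_linear_id / (global_range / batch_nelems); // batch 2
    std::size_t gr_id = group_linear_id % (group_range / batch_nelems);  // group 15 of that batch
    std::cout << m_id << " " << gr_id << "\n";
}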
+ it.get_global_linear_id() / (it.get_global_range(0) / batch_nelems); + size_t gr_id = + it.get_group_linear_id() % (it.get_group_range(0) / batch_nelems); + + const auto &three_offsets_ = + batch_indexer(static_cast(m_id)); + + // lift group_id to (block_i, block_j, block_s), + // 0 <= block_i < n_blocks, 0 <= block_j < m_blocks, 0 <= block_s < + // k_blocks + + const auto &lhs_offset = three_offsets_.get_first_offset(); + const auto &rhs_offset = three_offsets_.get_second_offset(); + const auto &res_offset = three_offsets_.get_third_offset(); + + size_t block_i = gr_id / (m_blocks * k_blocks); + size_t block_r = gr_id - block_i * (m_blocks * k_blocks); + size_t block_j = block_r / k_blocks; + size_t block_s = block_r - block_j * k_blocks; + + size_t lid = it.get_local_linear_id(); + size_t local_i = lid / wg_delta_m; // 0<= local_i < wg_delta_n + size_t local_j = lid - local_i * wg_delta_m; // 0<= local_j < wg_delta_m + + // load A block and B blocks into SLM + + size_t i = block_i * wi_delta_n * wg_delta_n; + size_t j = block_j * wi_delta_m * wg_delta_m; + size_t s = block_s * wi_delta_k; + + const std::int64_t a_st0 = k; + const std::int64_t a_st1 = 1; + + const std::int64_t b_st0 = m; + const std::int64_t b_st1 = 1; + + const std::int64_t c_st0 = m; + const std::int64_t c_st1 = 1; + + size_t lws = it.get_local_range(0); + + for (size_t vid = lid; vid < local_A_block.size(); vid += lws) { + size_t v_i = vid / wi_delta_k; // 0<= v_i < wg_delta_n * wi_delta_n + size_t v_s = vid - v_i * wi_delta_k; // 0<= v_s < wi_delta_k + + size_t g_i = i + v_i; + size_t g_s = s + v_s; + + local_A_block[vid] = + (g_i < n && g_s < k) + ? static_cast( + lhs[lhs_offset + + lhs_indexer(g_i * a_st0 + g_s * a_st1)]) + : resT(0); + } + + for (size_t vid = lid; vid < local_B_block.size(); vid += lws) { + size_t v_j = vid / wi_delta_k; // 0<= v_i < wg_delta_m + size_t v_s = vid - v_j * wi_delta_k; // 0<= v_s < wi_delta_k + + size_t g_j0 = j + v_j * wi_delta_m; + size_t g_s = s + v_s; + + sycl::vec vec{}; +#pragma unroll + for (std::uint8_t lane_id = 0; lane_id < wi_delta_m; ++lane_id) { + size_t g_j = g_j0 + lane_id; + vec[lane_id] = + (g_j < m && g_s < k) + ? 
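The tile loads here guard every element with a bounds check and substitute resT(0), the additive identity, for anything past n, m, or k, so the full-tile multiply-accumulate that follows needs no edge handling. A plain C++ sketch of the same padding idea, with illustrative data:

#include <cstddef>
#include <iostream>
#include <vector>

int main()
{
    std::size_t k = 5, tile_k = 8;             // k-slab padded out to tile_k
    std::vector<double> row = {1, 2, 3, 4, 5}; // one row of A
    std::vector<double> col = {1, 1, 1, 1, 1}; // one column of B

    std::vector<double> a_tile(tile_k), b_tile(tile_k);
    for (std::size_t s = 0; s < tile_k; ++s) {
        a_tile[s] = (s < k) ? row[s] : 0.0; // same guard as the SLM loads
        b_tile[s] = (s < k) ? col[s] : 0.0;
    }

    double acc = 0.0;
    for (std::size_t s = 0; s < tile_k; ++s)
        acc += a_tile[s] * b_tile[s]; // the padding contributes nothing
    std::cout << acc << "\n";         // 15
}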
static_cast( + rhs[rhs_offset + + rhs_indexer(g_s * b_st0 + g_j * b_st1)]) + : resT(0); + } + + local_B_block[vid] = vec; + } + + it.barrier(sycl::access::fence_space::local_space); + + i += local_i * wi_delta_n; + j += local_j * wi_delta_m; + + size_t a_offset = local_i * wi_delta_k * wi_delta_n; + size_t b_offset = local_j * wi_delta_k; + + constexpr resT identity_(0); + + for (std::uint8_t private_i = 0; private_i < wi_delta_n; ++private_i) { + size_t a_pr_offset = private_i * wi_delta_k; + + sycl::vec local_sum(identity_); + for (size_t private_s = 0; private_s < wi_delta_k; ++private_s) { + local_sum = local_sum + + (local_A_block[a_offset + a_pr_offset + private_s] * + local_B_block[b_offset + private_s]); + } + + size_t gl_i = i + private_i; + +#pragma unroll + for (std::uint8_t lane_id = 0; lane_id < wi_delta_m; ++lane_id) { + size_t gl_j = j + lane_id; + + if (gl_i < n && gl_j < m) { + sycl::atomic_ref + aout(res[res_offset + + res_indexer(gl_i * c_st0 + gl_j * c_st1)]); + + aout += local_sum[lane_id]; + } + } + } + } +}; + +template +class GemmBatchFunctorThreadNM +{ +private: + const lhsT *lhs = nullptr; + const rhsT *rhs = nullptr; + resT *res = nullptr; + LocAccT1 local_A_block; + LocAccT2 local_B_block; + size_t n = 0; + size_t wg_delta_n = 0; + size_t k = 0; + size_t k_blocks = 0; + size_t wi_delta_k = 0; + size_t m = 0; + size_t m_blocks = 0; + size_t wg_delta_m = 0; + size_t batch_nelems; + BatchDimsIndexerT batch_indexer; + OuterInnerDimsIndexerT lhs_indexer; + OuterInnerDimsIndexerT rhs_indexer; + OuterInnerDimsIndexerT res_indexer; + +public: + GemmBatchFunctorThreadNM(const lhsT *lhs_, + const rhsT *rhs_, + resT *res_, + LocAccT1 local_A_block_, + LocAccT2 local_B_block_, + size_t n_, + size_t wg_delta_n_, + size_t k_, + size_t k_blocks_, + size_t wi_delta_k_, + size_t m_, + size_t m_blocks_, + size_t wg_delta_m_, + size_t batch_nelems_, + BatchDimsIndexerT batch_indexer_, + OuterInnerDimsIndexerT lhs_indexer_, + OuterInnerDimsIndexerT rhs_indexer_, + OuterInnerDimsIndexerT res_indexer_) + : lhs(lhs_), rhs(rhs_), res(res_), local_A_block(local_A_block_), + local_B_block(local_B_block_), n(n_), wg_delta_n(wg_delta_n_), k(k_), + k_blocks(k_blocks_), wi_delta_k(wi_delta_k_), m(m_), + m_blocks(m_blocks_), wg_delta_m(wg_delta_m_), + batch_nelems(batch_nelems_), batch_indexer(batch_indexer_), + lhs_indexer(lhs_indexer_), rhs_indexer(rhs_indexer_), + res_indexer(res_indexer_) + { + } + + void operator()(sycl::nd_item<1> it) const + { + size_t m_id = + it.get_global_linear_id() / (it.get_global_range(0) / batch_nelems); + size_t gr_id = + it.get_group_linear_id() % (it.get_group_range(0) / batch_nelems); + + const auto &three_offsets_ = + batch_indexer(static_cast(m_id)); + + // lift group_id to (block_i, block_j, block_s), + // 0 <= block_i < n_blocks, 0 <= block_j < m_blocks, 0 <= block_s < + // k_blocks + + const auto &lhs_offset = three_offsets_.get_first_offset(); + const auto &rhs_offset = three_offsets_.get_second_offset(); + const auto &res_offset = three_offsets_.get_third_offset(); + + size_t block_i = gr_id / (m_blocks * k_blocks); + size_t block_r = gr_id - block_i * (m_blocks * k_blocks); + size_t block_j = block_r / k_blocks; + size_t block_s = block_r - block_j * k_blocks; + + size_t lid = it.get_local_linear_id(); + size_t local_i = lid / wg_delta_m; // 0<= local_i < wg_delta_n + size_t local_j = lid - local_i * wg_delta_m; // 0<= local_j < wg_delta_m + + // load A block and B blocks into SLM + + size_t i = block_i * wi_delta_n * wg_delta_n; + size_t j = 
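Both NM functors unpack the flat within-batch group id into (block_i, block_j, block_s), with block_s varying fastest, then block_j, then block_i. A tiny sketch that checks the round trip, using illustrative block counts:

#include <cassert>
#include <cstddef>

int main()
{
    std::size_t m_blocks = 8, k_blocks = 5;
    std::size_t gr_id = 123;

    std::size_t block_i = gr_id / (m_blocks * k_blocks);
    std::size_t block_r = gr_id - block_i * (m_blocks * k_blocks);
    std::size_t block_j = block_r / k_blocks;
    std::size_t block_s = block_r - block_j * k_blocks;

    // 123 = (3 * 8 + 0) * 5 + 3  ->  block_i = 3, block_j = 0, block_s = 3
    assert(gr_id == (block_i * m_blocks + block_j) * k_blocks + block_s);
    return 0;
}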
block_j * wg_delta_m; + size_t s = block_s * wi_delta_k; + + const std::int64_t a_st0 = k; + const std::int64_t a_st1 = 1; + + const std::int64_t b_st0 = m; + const std::int64_t b_st1 = 1; + + const std::int64_t c_st0 = m; + const std::int64_t c_st1 = 1; + + size_t lws = it.get_local_range(0); + + for (size_t vid = lid; vid < local_A_block.size(); vid += lws) { + size_t v_i = vid / wi_delta_k; // 0<= v_i < wg_delta_n * wi_delta_n + size_t v_s = vid - v_i * wi_delta_k; // 0<= v_s < wi_delta_k + + size_t g_i = i + v_i; + size_t g_s = s + v_s; + + local_A_block[vid] = + (g_i < n && g_s < k) + ? static_cast( + lhs[lhs_offset + + lhs_indexer(g_i * a_st0 + g_s * a_st1)]) + : resT(0); + } + + for (size_t vid = lid; vid < local_B_block.size(); vid += lws) { + size_t v_j = vid / wi_delta_k; // 0<= v_i < wg_delta_m + size_t v_s = vid - v_j * wi_delta_k; // 0<= v_s < wi_delta_k + + size_t g_j0 = j + v_j; + size_t g_s = s + v_s; + + resT val = (g_j0 < m && g_s < k) + ? static_cast( + rhs[rhs_offset + + rhs_indexer(g_s * b_st0 + g_j0 * b_st1)]) + : resT(0); + + local_B_block[vid] = val; + } + + it.barrier(sycl::access::fence_space::local_space); + + i += local_i * wi_delta_n; + j += local_j; + + size_t a_offset = local_i * wi_delta_k * wi_delta_n; + size_t b_offset = local_j * wi_delta_k; + + constexpr resT identity_(0); + for (std::uint8_t private_i = 0; private_i < wi_delta_n; ++private_i) { + size_t a_pr_offset = private_i * wi_delta_k; + + resT local_sum(identity_); + for (size_t private_s = 0; private_s < wi_delta_k; ++private_s) { + local_sum = local_sum + + (local_A_block[a_offset + a_pr_offset + private_s] * + local_B_block[b_offset + private_s]); + } + + size_t gl_i = i + private_i; + + if (gl_i < n && j < m) { + sycl::atomic_ref + aout(res[res_offset + + res_indexer(gl_i * c_st0 + j * c_st1)]); + + aout += local_sum; + } + } + } +}; + +template +class GemmBatchFunctorThreadK +{ +private: + const lhsT *lhs = nullptr; + const rhsT *rhs = nullptr; + resT *res = nullptr; + LocAccT workspace; + LocAccT local_B_block; + size_t n = 0; + size_t n_blocks = 0; + size_t delta_n = 0; + size_t k = 0; + size_t k_blocks = 0; + size_t delta_k = 0; + size_t n_wi = 0; + size_t m = 0; + size_t batch_nelems = 0; + BatchDimsIndexerT batch_indexer; + OuterInnerDimsIndexerT lhs_indexer; + OuterInnerDimsIndexerT rhs_indexer; + OuterInnerDimsIndexerT res_indexer; + +public: + GemmBatchFunctorThreadK(const lhsT *lhs_, + const rhsT *rhs_, + resT *res_, + LocAccT workspace_, + LocAccT local_B_block_, + size_t n_, + size_t n_blocks_, + size_t delta_n_, + size_t k_, + size_t k_blocks_, + size_t delta_k_, + size_t n_wi_, + size_t m_, + size_t batch_nelems_, + BatchDimsIndexerT batch_indexer_, + OuterInnerDimsIndexerT lhs_indexer_, + OuterInnerDimsIndexerT rhs_indexer_, + OuterInnerDimsIndexerT res_indexer_) + : lhs(lhs_), rhs(rhs_), res(res_), workspace(workspace_), + local_B_block(local_B_block_), n(n_), n_blocks(n_blocks_), + delta_n(delta_n_), k(k_), k_blocks(k_blocks_), delta_k(delta_k_), + n_wi(n_wi_), m(m_), batch_nelems(batch_nelems_), + batch_indexer(batch_indexer_), lhs_indexer(lhs_indexer_), + rhs_indexer(rhs_indexer_), res_indexer(res_indexer_) + { + } + + void operator()(sycl::nd_item<1> it) const + { + // for batching: + // (current matrix in batch) m_id = global_id / (global_range / + // batch_nelems) for lhs, offset = m_id * (n * k) for rhs, offset = m_id + // * (k * m) for res, offset = m_id * (n * m) + size_t m_id = + it.get_global_linear_id() / (it.get_global_range(0) / batch_nelems); + size_t gr_id = + 
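The ThreadK functors reduce over k in two stages inside each work-group: every work-item accumulates a delta_k-strided slice of the k-slab into a private sum, parks it in the shared workspace, and after a barrier the work-item with local_s == 0 folds the delta_k partials for its row. A plain C++ stand-in for that pattern, with made-up sizes and a contrib array standing in for the products a[i][s] * b[s][j]:

#include <cstddef>
#include <iostream>
#include <vector>

int main()
{
    std::size_t delta_n = 2, delta_k = 4, n_wi = 3;
    std::size_t slab = n_wi * delta_k;       // k-values covered by one group
    std::vector<double> contrib(slab, 1.0);  // pretend every product equals 1

    std::vector<double> workspace(delta_n * delta_k);
    for (std::size_t local_i = 0; local_i < delta_n; ++local_i) {
        for (std::size_t local_s = 0; local_s < delta_k; ++local_s) {
            double private_sum = 0.0;
            for (std::size_t t = local_s; t < slab; t += delta_k)
                private_sum += contrib[t];   // strided walk over the slab
            workspace[local_i * delta_k + local_s] = private_sum;
        }
    }

    // what the local_s == 0 work-item does after the barrier
    for (std::size_t local_i = 0; local_i < delta_n; ++local_i) {
        double row_sum = 0.0;
        for (std::size_t t = 0; t < delta_k; ++t)
            row_sum += workspace[local_i * delta_k + t];
        std::cout << row_sum << "\n";        // 12 == n_wi * delta_k
    }
}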
it.get_group_linear_id() % (it.get_group_range(0) / batch_nelems); + size_t lid = it.get_local_linear_id(); + + const auto &three_offsets_ = + batch_indexer(static_cast(m_id)); + + const auto &lhs_offset = three_offsets_.get_first_offset(); + const auto &rhs_offset = three_offsets_.get_second_offset(); + const auto &res_offset = three_offsets_.get_third_offset(); + + // lift gr_id -> (block_i, block_j, block_s) + // block_i moves fastest, then block_s, then block_j + + size_t block_j = + gr_id / (n_blocks * k_blocks); // 0 <= block_j < m_blocks + size_t block_r = + gr_id - block_j * (n_blocks * + k_blocks); // 0 <= block_r < n_blocks * k_blocks + size_t block_s = block_r / n_blocks; // 0 <= block_s < k_blocks + size_t block_i = + block_r - block_s * n_blocks; // 0 <= block_i < n_blocks + + size_t local_i = lid / (delta_k); // 0 <= local_i < delta_n + size_t local_s = lid - local_i * (delta_k); // 0 <= local_s < delta_k + + size_t i = block_i * delta_n + local_i; + size_t j = m_groups * block_j; + size_t s = block_s * delta_k * n_wi + local_s; + + constexpr resT identity_ = resT(0); + if (local_i == 0) { + for (size_t q = 0; q < n_wi * delta_k; q += delta_k) { + size_t sq = s + q; + size_t sqmj = sq * m + j; + local_B_block[local_s + q] = sycl::vec( + (sq < k && j < m) + ? static_cast(rhs[rhs_offset + rhs_indexer(sqmj)]) + : identity_, + (sq < k && j + 1 < m) + ? static_cast( + rhs[rhs_offset + rhs_indexer(sqmj + 1)]) + : identity_); + } + } + + it.barrier(sycl::access::fence_space::local_space); + + size_t t_shift = block_s * delta_k * n_wi; + size_t global_s_offset = i * k + t_shift; + + sycl::vec private_sum(identity_); + for (size_t t = local_s; t < local_B_block.size(); t += delta_k) { + if (t + t_shift < k) { + private_sum += + (static_cast( + lhs[lhs_offset + lhs_indexer(global_s_offset + t)]) * + local_B_block[t]); + } + } + + size_t workspace_i_shift = local_i * delta_k; + workspace[workspace_i_shift + local_s] = private_sum; + + it.barrier(sycl::access::fence_space::local_space); + + if (local_s == 0 && i < n) { + sycl::vec local_sum(workspace[workspace_i_shift]); + for (size_t t = 1; t < delta_k; ++t) { + local_sum += workspace[workspace_i_shift + t]; + } + + sycl::atomic_ref + aout0(res[res_offset + res_indexer(i * m + j)]); + + aout0 += local_sum[0]; + + if (j + 1 < m) { + sycl::atomic_ref + aout1(res[res_offset + res_indexer(i * m + j + 1)]); + + aout1 += local_sum[1]; + } + } + } +}; + +template +class GemmBatchFunctorThreadK +{ +private: + const lhsT *lhs = nullptr; + const rhsT *rhs = nullptr; + resT *res = nullptr; + LocAccT workspace; + LocAccT local_B_block; + size_t n = 0; + size_t n_blocks = 0; + size_t delta_n = 0; + size_t k = 0; + size_t k_blocks = 0; + size_t delta_k = 0; + size_t n_wi = 0; + size_t m = 0; + size_t batch_nelems = 0; + BatchDimsIndexerT batch_indexer; + OuterInnerDimsIndexerT lhs_indexer; + OuterInnerDimsIndexerT rhs_indexer; + OuterInnerDimsIndexerT res_indexer; + +public: + GemmBatchFunctorThreadK(const lhsT *lhs_, + const rhsT *rhs_, + resT *res_, + LocAccT workspace_, + LocAccT local_B_block_, + size_t n_, + size_t n_blocks_, + size_t delta_n_, + size_t k_, + size_t k_blocks_, + size_t delta_k_, + size_t n_wi_, + size_t m_, + size_t batch_nelems_, + BatchDimsIndexerT batch_indexer_, + OuterInnerDimsIndexerT lhs_indexer_, + OuterInnerDimsIndexerT rhs_indexer_, + OuterInnerDimsIndexerT res_indexer_) + : lhs(lhs_), rhs(rhs_), res(res_), workspace(workspace_), + local_B_block(local_B_block_), n(n_), n_blocks(n_blocks_), + delta_n(delta_n_), k(k_), 
k_blocks(k_blocks_), delta_k(delta_k_), + n_wi(n_wi_), m(m_), batch_nelems(batch_nelems_), + batch_indexer(batch_indexer_), lhs_indexer(lhs_indexer_), + rhs_indexer(rhs_indexer_), res_indexer(res_indexer_) + { + } + + void operator()(sycl::nd_item<1> it) const + { + // for batching: + // (current matrix in batch) m_id = global_id / (global_range / + // batch_nelems) for lhs, offset = m_id * (n * k) for rhs, offset = m_id + // * (k * m) for res, offset = m_id * (n * m) + size_t m_id = + it.get_global_linear_id() / (it.get_global_range(0) / batch_nelems); + size_t gr_id = + it.get_group_linear_id() % (it.get_group_range(0) / batch_nelems); + size_t lid = it.get_local_linear_id(); + + const auto &three_offsets_ = + batch_indexer(static_cast(m_id)); + + const auto &lhs_offset = three_offsets_.get_first_offset(); + const auto &rhs_offset = three_offsets_.get_second_offset(); + const auto &res_offset = three_offsets_.get_third_offset(); + + // lift gr_id -> (block_i, block_j, block_s) + // block_i moves fastest, then block_s, then block_j + + size_t block_j = + gr_id / (n_blocks * k_blocks); // 0 <= block_j < m_blocks + size_t block_r = + gr_id - block_j * (n_blocks * + k_blocks); // 0 <= block_r < n_blocks * k_blocks + size_t block_s = block_r / n_blocks; // 0 <= block_s < k_blocks + size_t block_i = + block_r - block_s * n_blocks; // 0 <= block_i < n_blocks + + size_t local_i = lid / (delta_k); // 0 <= local_i < delta_n + size_t local_s = lid - local_i * (delta_k); // 0 <= local_s < delta_k + + size_t i = block_i * delta_n + local_i; + size_t j = block_j; + size_t s = block_s * delta_k * n_wi + local_s; + + constexpr resT identity_ = resT(0); + if (local_i == 0) { + for (size_t q = 0; q < n_wi * delta_k; q += delta_k) { + size_t sq = s + q; + size_t sqmj = sq * m + j; + local_B_block[local_s + q] = + (sq < k && j < m) + ? 
static_cast(rhs[rhs_offset + rhs_indexer(sqmj)]) + : identity_; + ; + } + } + + it.barrier(sycl::access::fence_space::local_space); + + size_t t_shift = block_s * delta_k * n_wi; + size_t global_s_offset = i * k + t_shift; + + resT private_sum(identity_); + for (size_t t = local_s; t < local_B_block.size(); t += delta_k) { + if (t + t_shift < k) { + private_sum += + (static_cast( + lhs[lhs_offset + lhs_indexer(global_s_offset + t)]) * + local_B_block[t]); + } + } + + size_t workspace_i_shift = local_i * delta_k; + workspace[workspace_i_shift + local_s] = private_sum; + + it.barrier(sycl::access::fence_space::local_space); + + if (local_s == 0 && i < n) { + resT local_sum(workspace[workspace_i_shift]); + for (size_t t = 1; t < delta_k; ++t) { + local_sum += workspace[workspace_i_shift + t]; + } + + sycl::atomic_ref + aout(res[res_offset + res_indexer(i * m + j)]); + + aout += local_sum; + } + } +}; + +template class gemm_batch_init_krn; + +template +class gemm_batch_k_krn; + +template +class gemm_batch_nm_krn; + +typedef sycl::event (*gemm_batch_impl_fn_ptr_t)( + sycl::queue &, + const char *, // lhs + const char *, // rhs + char *, // res + size_t, // batch nelems + size_t, // lhs outer nelems (n) + size_t, // inner nelems (k) + size_t, // rhs outer nelems (m) + int, // batching nd + const py::ssize_t *, // batch shape strides + py::ssize_t, // lhs batch offset + py::ssize_t, // rhs batch offset + py::ssize_t, // res batch offset + int, // inner dims + int, // lhs outer dims + const py::ssize_t *, // lhs outer and inner shape and strides + int, // rhs outer dims + const py::ssize_t *, // rhs outer and inner shape and strides + int, // res outer dims + const py::ssize_t *, // res outer and inner shape and strides + const py::ssize_t *, // res full shape and strides + std::vector const &); + +template +sycl::event gemm_batch_impl(sycl::queue &exec_q, + const char *lhs_cp, + const char *rhs_cp, + char *res_cp, + size_t batch_nelems, + size_t n, + size_t k, + size_t m, + int batch_nd, + const py::ssize_t *batch_shape_strides, + py::ssize_t lhs_batch_offset, + py::ssize_t rhs_batch_offset, + py::ssize_t res_batch_offset, + int inner_nd, + int lhs_outer_nd, + const py::ssize_t *lhs_outer_inner_shapes_strides, + int rhs_outer_nd, + const py::ssize_t *rhs_outer_inner_shapes_strides, + int res_outer_nd, + const py::ssize_t *res_outer_shapes_strides, + const py::ssize_t *res_shape_strides, + std::vector const &depends = {}) +{ + const lhsTy *lhs_tp = reinterpret_cast(lhs_cp); + const rhsTy *rhs_tp = reinterpret_cast(rhs_cp); + resTy *res_tp = reinterpret_cast(res_cp); + + const sycl::device &dev = exec_q.get_device(); + const size_t local_mem_size = + dev.get_info(); + const size_t reserved_slm_size = 512; + + sycl::event res_init_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using IndexerT = dpctl::tensor::offset_utils::StridedIndexer; + IndexerT res_indexer(batch_nd + res_outer_nd, res_batch_offset, + res_shape_strides); + using InitKernelName = class gemm_batch_init_krn; + cgh.parallel_for( + sycl::range<1>(n * m * batch_nelems), [=](sycl::id<1> id) { + auto res_offset = res_indexer(id[0]); + res_tp[res_offset] = resTy(0); + }); + }); + + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(res_init_ev); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::StridedIndexer; + OuterInnerDimsIndexerT lhs_indexer(inner_nd + lhs_outer_nd, 0, + lhs_outer_inner_shapes_strides); + OuterInnerDimsIndexerT rhs_indexer(inner_nd + 
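The StridedIndexer objects used here map a flat element id to a memory displacement by unravelling the id over the tensor shape and dotting the multi-index with the strides, plus a constant offset. The mock below only illustrates that idea; strided_offset is a hypothetical helper with invented shape and stride values, not the dpctl implementation, which works off the packed shape/strides arrays passed above:

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

std::int64_t strided_offset(std::size_t flat_id,
                            std::int64_t base_offset,
                            const std::vector<std::int64_t> &shape,
                            const std::vector<std::int64_t> &strides)
{
    std::int64_t off = base_offset;
    // unravel in C order: the last dimension varies fastest
    for (std::size_t d = shape.size(); d-- > 0;) {
        std::int64_t idx = static_cast<std::int64_t>(flat_id) % shape[d];
        flat_id /= static_cast<std::size_t>(shape[d]);
        off += idx * strides[d];
    }
    return off;
}

int main()
{
    // a 2 x 3 batch of matrices viewed with non-contiguous batch strides
    std::vector<std::int64_t> shape = {2, 3};
    std::vector<std::int64_t> strides = {300, 100};
    std::cout << strided_offset(4, 0, shape, strides) << "\n"; // id 4 -> (1, 1) -> 400
}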
rhs_outer_nd, 0, + rhs_outer_inner_shapes_strides); + OuterInnerDimsIndexerT res_indexer(res_outer_nd, 0, + res_outer_shapes_strides); + using BatchDimsIndexerT = + dpctl::tensor::offset_utils::ThreeOffsets_StridedIndexer; + BatchDimsIndexerT batch_indexer(batch_nd, lhs_batch_offset, + rhs_batch_offset, res_batch_offset, + batch_shape_strides); + if (m == 1) { + constexpr int m_groups = 1; + size_t delta_k(4); + size_t n_wi(64); + size_t delta_n(32); + + gemm_detail::scale_gemm_k_parameters( + local_mem_size, reserved_slm_size, + delta_k, // modified by reference + n_wi, // modified by reference + delta_n // modified by reference + ); + + size_t n_blocks = (n + delta_n - 1) / delta_n; + size_t m_blocks = (m + m_groups - 1) / m_groups; + size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); + + size_t lws = delta_n * delta_k; + + auto gRange = sycl::range<1>(batch_nelems * n_blocks * m_blocks * + k_blocks * lws); + auto lRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gRange, lRange); + + using LocAccT = sycl::local_accessor; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + + using KernelName = + class gemm_batch_k_krn; + cgh.parallel_for( + ndRange, GemmBatchFunctorThreadK( + lhs_tp, rhs_tp, res_tp, workspace, local_B_block, + n, n_blocks, delta_n, k, k_blocks, delta_k, n_wi, + m, batch_nelems, batch_indexer, lhs_indexer, + rhs_indexer, res_indexer)); + } + else if (k > n && k > m) { + constexpr size_t m_groups = 2; + size_t delta_k(4); + size_t n_wi(64); + size_t delta_n(32); + + gemm_detail::scale_gemm_k_parameters( + local_mem_size, reserved_slm_size, + delta_k, // modified by reference + n_wi, // modified by reference + delta_n // modified by reference + ); + + size_t n_blocks = (n + delta_n - 1) / delta_n; + size_t m_blocks = (m + m_groups - 1) / m_groups; + size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); + + size_t lws = delta_n * delta_k; + + auto gRange = sycl::range<1>(batch_nelems * n_blocks * m_blocks * + k_blocks * lws); + auto lRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gRange, lRange); + + using LocAccT = sycl::local_accessor, 1>; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + + using KernelName = + class gemm_batch_k_krn; + cgh.parallel_for( + ndRange, GemmBatchFunctorThreadK( + lhs_tp, rhs_tp, res_tp, workspace, local_B_block, + n, n_blocks, delta_n, k, k_blocks, delta_k, n_wi, + m, batch_nelems, batch_indexer, lhs_indexer, + rhs_indexer, res_indexer)); + } + else { + constexpr int wi_delta_n = 2; + constexpr int wi_delta_m = 4; + size_t wg_delta_n(16); // rows of A processed in WG + size_t wg_delta_m(16); // rows of B processed in WG + size_t wi_delta_k(64); // Elements in K dimension processed by WI + + gemm_detail::scale_gemm_nm_parameters( + local_mem_size, reserved_slm_size, wi_delta_n, + wi_delta_k, // modified by reference + wg_delta_n, // modified by reference + wg_delta_m // modified by reference + ); + + size_t lws = wg_delta_n * wg_delta_m; + + size_t n_blocks = + ((n + wi_delta_n * wg_delta_n - 1) / (wi_delta_n * wg_delta_n)); + size_t m_blocks = + ((m + wi_delta_m * wg_delta_m - 1) / (wi_delta_m * wg_delta_m)); + size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); + + auto gwsRange = sycl::range<1>(batch_nelems * n_blocks * m_blocks * + k_blocks * lws); + auto lwsRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); + + using LocAccT1 = sycl::local_accessor; + LocAccT1 
local_A_block( + sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), cgh); + using LocAccT2 = + sycl::local_accessor, 1>; + LocAccT2 local_B_block(sycl::range<1>(wi_delta_k * wg_delta_m), + cgh); + + using KernelName = + class gemm_batch_nm_krn; + cgh.parallel_for( + ndRange, + GemmBatchFunctorThreadNM( + lhs_tp, rhs_tp, res_tp, local_A_block, local_B_block, n, + wg_delta_n, k, k_blocks, wi_delta_k, m, m_blocks, + wg_delta_m, batch_nelems, batch_indexer, lhs_indexer, + rhs_indexer, res_indexer)); + } + }); + + return gemm_ev; +} + +typedef sycl::event (*gemm_batch_contig_impl_fn_ptr_t)( + sycl::queue &, + const char *, // lhs + const char *, // rhs + char *, // res + size_t, // batch nelems + size_t, // n + size_t, // k + size_t, // m + py::ssize_t, // lhs batch offset + py::ssize_t, // rhs batch offset + py::ssize_t, // res batch offset + std::vector const &); + +template +sycl::event gemm_batch_contig_impl(sycl::queue &exec_q, + const char *lhs_cp, + const char *rhs_cp, + char *res_cp, + size_t batch_nelems, + size_t n, + size_t k, + size_t m, + py::ssize_t lhs_batch_offset, + py::ssize_t rhs_batch_offset, + py::ssize_t res_batch_offset, + std::vector const &depends = {}) +{ + const lhsTy *lhs_tp = reinterpret_cast(lhs_cp); + const rhsTy *rhs_tp = reinterpret_cast(rhs_cp); + resTy *res_tp = reinterpret_cast(res_cp); + + const sycl::device &dev = exec_q.get_device(); + const size_t local_mem_size = + dev.get_info(); + const size_t reserved_slm_size = 512; + + sycl::event res_init_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + cgh.fill(res_tp, resTy(0), n * m * batch_nelems); + }); + + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(res_init_ev); + + using OuterInnerDimsIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + OuterInnerDimsIndexerT lhs_indexer{}; + OuterInnerDimsIndexerT rhs_indexer{}; + OuterInnerDimsIndexerT res_indexer{}; + using dpctl::tensor::offset_utils::Strided1DIndexer; + using dpctl::tensor::offset_utils::ThreeOffsets_CombinedIndexer; + using BatchDimsIndexerT = + ThreeOffsets_CombinedIndexer; + + BatchDimsIndexerT batch_indexer( + Strided1DIndexer{lhs_batch_offset, + static_cast(batch_nelems), + static_cast(n * k)}, + Strided1DIndexer{rhs_batch_offset, + static_cast(batch_nelems), + static_cast(k * m)}, + Strided1DIndexer{res_batch_offset, + static_cast(batch_nelems), + static_cast(n * m)}); + if (m == 1) { + constexpr int m_groups = 1; + size_t delta_k(4); + size_t n_wi(64); + size_t delta_n(32); + + gemm_detail::scale_gemm_k_parameters( + local_mem_size, reserved_slm_size, + delta_k, // modified by reference + n_wi, // modified by reference + delta_n // modified by reference + ); + + size_t n_blocks = (n + delta_n - 1) / delta_n; + size_t m_blocks = (m + m_groups - 1) / m_groups; + size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); + + size_t lws = delta_n * delta_k; + + auto gRange = sycl::range<1>(batch_nelems * n_blocks * m_blocks * + k_blocks * lws); + auto lRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gRange, lRange); + + using LocAccT = sycl::local_accessor; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + + using KernelName = + class gemm_batch_k_krn; + cgh.parallel_for( + ndRange, GemmBatchFunctorThreadK( + lhs_tp, rhs_tp, res_tp, workspace, local_B_block, + n, n_blocks, delta_n, k, k_blocks, delta_k, n_wi, + m, batch_nelems, batch_indexer, lhs_indexer, + rhs_indexer, res_indexer)); + } + else if (k > n && k > m) { + 
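For the contiguous batched case, the ThreeOffsets_CombinedIndexer built above composes three Strided1DIndexer objects so that batch element b starts at its batch offset plus b times the per-matrix size (n*k, k*m and n*m respectively). A small sketch of the offsets it yields; ThreeOffsetsExample and the batch_offsets lambda are stand-ins, and the sizes are illustrative:

#include <cstddef>
#include <iostream>

struct ThreeOffsetsExample {
    std::size_t lhs, rhs, res;
};

int main()
{
    std::size_t n = 4, k = 6, m = 3;
    std::size_t lhs_batch_offset = 0, rhs_batch_offset = 0, res_batch_offset = 0;

    auto batch_offsets = [&](std::size_t b) {
        return ThreeOffsetsExample{lhs_batch_offset + b * (n * k),
                                   rhs_batch_offset + b * (k * m),
                                   res_batch_offset + b * (n * m)};
    };

    ThreeOffsetsExample o = batch_offsets(2); // third matrix in the batch
    std::cout << o.lhs << " " << o.rhs << " " << o.res << "\n"; // 48 36 24
}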
constexpr size_t m_groups = 2; + size_t delta_k(4); + size_t n_wi(64); + size_t delta_n(32); + + gemm_detail::scale_gemm_k_parameters( + local_mem_size, reserved_slm_size, + delta_k, // modified by reference + n_wi, // modified by reference + delta_n // modified by reference + ); + + size_t n_blocks = (n + delta_n - 1) / delta_n; + size_t m_blocks = (m + m_groups - 1) / m_groups; + size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); + + size_t lws = delta_n * delta_k; + + auto gRange = sycl::range<1>(batch_nelems * n_blocks * m_blocks * + k_blocks * lws); + auto lRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gRange, lRange); + + using LocAccT = sycl::local_accessor, 1>; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + + using KernelName = + class gemm_batch_k_krn; + cgh.parallel_for( + ndRange, GemmBatchFunctorThreadK( + lhs_tp, rhs_tp, res_tp, workspace, local_B_block, + n, n_blocks, delta_n, k, k_blocks, delta_k, n_wi, + m, batch_nelems, batch_indexer, lhs_indexer, + rhs_indexer, res_indexer)); + } + else { + constexpr int wi_delta_n = 2; + constexpr int wi_delta_m = 4; + size_t wg_delta_n(16); // rows of A processed in WG + size_t wg_delta_m(16); // rows of B processed in WG + size_t wi_delta_k(64); // Elements in K dimension processed by WI + + gemm_detail::scale_gemm_nm_parameters( + local_mem_size, reserved_slm_size, wi_delta_n, + wi_delta_k, // modified by reference + wg_delta_n, // modified by reference + wg_delta_m // modified by reference + ); + + size_t lws = wg_delta_n * wg_delta_m; + + size_t n_blocks = + ((n + wi_delta_n * wg_delta_n - 1) / (wi_delta_n * wg_delta_n)); + size_t m_blocks = + ((m + wi_delta_m * wg_delta_m - 1) / (wi_delta_m * wg_delta_m)); + size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); + + auto gwsRange = sycl::range<1>(batch_nelems * n_blocks * m_blocks * + k_blocks * lws); + auto lwsRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); + + using LocAccT1 = sycl::local_accessor; + LocAccT1 local_A_block( + sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), cgh); + using LocAccT2 = + sycl::local_accessor, 1>; + LocAccT2 local_B_block(sycl::range<1>(wi_delta_k * wg_delta_m), + cgh); + + using KernelName = + class gemm_batch_nm_krn; + cgh.parallel_for( + ndRange, + GemmBatchFunctorThreadNM( + lhs_tp, rhs_tp, res_tp, local_A_block, local_B_block, n, + wg_delta_n, k, k_blocks, wi_delta_k, m, m_blocks, + wg_delta_m, batch_nelems, batch_indexer, lhs_indexer, + rhs_indexer, res_indexer)); + } + }); + + return gemm_ev; +} + +template +class GemmBatchNoAtomicFunctorThreadNM +{ +private: + const lhsT *lhs = nullptr; + const rhsT *rhs = nullptr; + resT *res = nullptr; + LocAccT1 local_A_block; + LocAccT2 local_B_block; + size_t n = 0; + size_t wg_delta_n = 0; + size_t k = 0; + size_t k_blocks = 0; + size_t wi_delta_k = 0; + size_t m = 0; + size_t m_blocks = 0; + size_t wg_delta_m = 0; + size_t batch_nelems; + BatchDimsIndexerT batch_indexer; + OuterInnerDimsIndexerT lhs_indexer; + OuterInnerDimsIndexerT rhs_indexer; + ResIndexerT res_indexer; + +public: + GemmBatchNoAtomicFunctorThreadNM(const lhsT *lhs_, + const rhsT *rhs_, + resT *res_, + LocAccT1 local_A_block_, + LocAccT2 local_B_block_, + size_t n_, + size_t wg_delta_n_, + size_t k_, + size_t k_blocks_, + size_t wi_delta_k_, + size_t m_, + size_t m_blocks_, + size_t wg_delta_m_, + size_t batch_nelems_, + BatchDimsIndexerT batch_indexer_, + OuterInnerDimsIndexerT lhs_indexer_, + 
OuterInnerDimsIndexerT rhs_indexer_, + ResIndexerT res_indexer_) + : lhs(lhs_), rhs(rhs_), res(res_), local_A_block(local_A_block_), + local_B_block(local_B_block_), n(n_), wg_delta_n(wg_delta_n_), k(k_), + k_blocks(k_blocks_), wi_delta_k(wi_delta_k_), m(m_), + m_blocks(m_blocks_), wg_delta_m(wg_delta_m_), + batch_nelems(batch_nelems_), batch_indexer(batch_indexer_), + lhs_indexer(lhs_indexer_), rhs_indexer(rhs_indexer_), + res_indexer(res_indexer_) + { + } + + void operator()(sycl::nd_item<1> it) const + { + size_t m_id = + it.get_global_linear_id() / (it.get_global_range(0) / batch_nelems); + size_t gr_id = + it.get_group_linear_id() % (it.get_group_range(0) / batch_nelems); + + const auto &three_offsets_ = + batch_indexer(static_cast(m_id)); + + // lift group_id to (block_i, block_j, block_s), + // 0 <= block_i < n_blocks, 0 <= block_j < m_blocks, 0 <= block_s < + // k_blocks + + const auto &lhs_offset = three_offsets_.get_first_offset(); + const auto &rhs_offset = three_offsets_.get_second_offset(); + const auto &res_offset = three_offsets_.get_third_offset(); + + size_t block_i = gr_id / (m_blocks * k_blocks); + size_t block_r = gr_id - block_i * (m_blocks * k_blocks); + size_t block_j = block_r / k_blocks; + size_t block_s = block_r - block_j * k_blocks; + + size_t lid = it.get_local_linear_id(); + size_t local_i = lid / wg_delta_m; // 0<= local_i < wg_delta_n + size_t local_j = lid - local_i * wg_delta_m; // 0<= local_j < wg_delta_m + + // load A block and B blocks into SLM + + size_t i = block_i * wi_delta_n * wg_delta_n; + size_t j = block_j * wi_delta_m * wg_delta_m; + size_t s = block_s * wi_delta_k; + + const std::int64_t a_st0 = k; + const std::int64_t a_st1 = 1; + + const std::int64_t b_st0 = m; + const std::int64_t b_st1 = 1; + + const std::int64_t c_st0 = m; + const std::int64_t c_st1 = 1; + + size_t lws = it.get_local_range(0); + + for (size_t vid = lid; vid < local_A_block.size(); vid += lws) { + size_t v_i = vid / wi_delta_k; // 0<= v_i < wg_delta_n * wi_delta_n + size_t v_s = vid - v_i * wi_delta_k; // 0<= v_s < wi_delta_k + + size_t g_i = i + v_i; + size_t g_s = s + v_s; + + local_A_block[vid] = + (g_i < n && g_s < k) + ? static_cast( + lhs[lhs_offset + + lhs_indexer(g_i * a_st0 + g_s * a_st1)]) + : resT(0); + } + + for (size_t vid = lid; vid < local_B_block.size(); vid += lws) { + size_t v_j = vid / wi_delta_k; // 0<= v_i < wg_delta_m + size_t v_s = vid - v_j * wi_delta_k; // 0<= v_s < wi_delta_k + + size_t g_j0 = j + v_j * wi_delta_m; + size_t g_s = s + v_s; + + sycl::vec vec{}; +#pragma unroll + for (std::uint8_t lane_id = 0; lane_id < wi_delta_m; ++lane_id) { + size_t g_j = g_j0 + lane_id; + vec[lane_id] = + (g_j < m && g_s < k) + ? 
static_cast( + rhs[rhs_offset + + rhs_indexer(g_s * b_st0 + g_j * b_st1)]) + : resT(0); + } + + local_B_block[vid] = vec; + } + + it.barrier(sycl::access::fence_space::local_space); + + i += local_i * wi_delta_n; + j += local_j * wi_delta_m; + + size_t a_offset = local_i * wi_delta_k * wi_delta_n; + size_t b_offset = local_j * wi_delta_k; + + constexpr resT identity_(0); + + for (std::uint8_t private_i = 0; private_i < wi_delta_n; ++private_i) { + size_t a_pr_offset = private_i * wi_delta_k; + + sycl::vec local_sum(identity_); + for (size_t private_s = 0; private_s < wi_delta_k; ++private_s) { + local_sum = local_sum + + (local_A_block[a_offset + a_pr_offset + private_s] * + local_B_block[b_offset + private_s]); + } + + size_t gl_i = i + private_i; + +#pragma unroll + for (std::uint8_t lane_id = 0; lane_id < wi_delta_m; ++lane_id) { + size_t gl_j = j + lane_id; + + if (gl_i < n && gl_j < m) { + res[res_offset + res_indexer(gl_i * c_st0 + gl_j * c_st1) + + (block_s * n * m * batch_nelems)] = local_sum[lane_id]; + } + } + } + } +}; + +template +class GemmBatchNoAtomicFunctorThreadNM +{ +private: + const lhsT *lhs = nullptr; + const rhsT *rhs = nullptr; + resT *res = nullptr; + LocAccT1 local_A_block; + LocAccT2 local_B_block; + size_t n = 0; + size_t wg_delta_n = 0; + size_t k = 0; + size_t k_blocks = 0; + size_t wi_delta_k = 0; + size_t m = 0; + size_t m_blocks = 0; + size_t wg_delta_m = 0; + size_t batch_nelems; + BatchDimsIndexerT batch_indexer; + OuterInnerDimsIndexerT lhs_indexer; + OuterInnerDimsIndexerT rhs_indexer; + ResIndexerT res_indexer; + +public: + GemmBatchNoAtomicFunctorThreadNM(const lhsT *lhs_, + const rhsT *rhs_, + resT *res_, + LocAccT1 local_A_block_, + LocAccT2 local_B_block_, + size_t n_, + size_t wg_delta_n_, + size_t k_, + size_t k_blocks_, + size_t wi_delta_k_, + size_t m_, + size_t m_blocks_, + size_t wg_delta_m_, + size_t batch_nelems_, + BatchDimsIndexerT batch_indexer_, + OuterInnerDimsIndexerT lhs_indexer_, + OuterInnerDimsIndexerT rhs_indexer_, + ResIndexerT res_indexer_) + : lhs(lhs_), rhs(rhs_), res(res_), local_A_block(local_A_block_), + local_B_block(local_B_block_), n(n_), wg_delta_n(wg_delta_n_), k(k_), + k_blocks(k_blocks_), wi_delta_k(wi_delta_k_), m(m_), + m_blocks(m_blocks_), wg_delta_m(wg_delta_m_), + batch_nelems(batch_nelems_), batch_indexer(batch_indexer_), + lhs_indexer(lhs_indexer_), rhs_indexer(rhs_indexer_), + res_indexer(res_indexer_) + { + } + + void operator()(sycl::nd_item<1> it) const + { + size_t m_id = + it.get_global_linear_id() / (it.get_global_range(0) / batch_nelems); + size_t gr_id = + it.get_group_linear_id() % (it.get_group_range(0) / batch_nelems); + + const auto &three_offsets_ = + batch_indexer(static_cast(m_id)); + + // lift group_id to (block_i, block_j, block_s), + // 0 <= block_i < n_blocks, 0 <= block_j < m_blocks, 0 <= block_s < + // k_blocks + + const auto &lhs_offset = three_offsets_.get_first_offset(); + const auto &rhs_offset = three_offsets_.get_second_offset(); + const auto &res_offset = three_offsets_.get_third_offset(); + + size_t block_i = gr_id / (m_blocks * k_blocks); + size_t block_r = gr_id - block_i * (m_blocks * k_blocks); + size_t block_j = block_r / k_blocks; + size_t block_s = block_r - block_j * k_blocks; + + size_t lid = it.get_local_linear_id(); + size_t local_i = lid / wg_delta_m; // 0<= local_i < wg_delta_n + size_t local_j = lid - local_i * wg_delta_m; // 0<= local_j < wg_delta_m + + // load A block and B blocks into SLM + + size_t i = block_i * wi_delta_n * wg_delta_n; + size_t j = block_j * 
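Because these functors avoid atomics, each k-slab writes its partial tile into its own full-size copy of the (batch_nelems x n x m) result: slab block_s lands at a stride of n * m * batch_nelems, and the reduction that runs afterwards sums over the slab index. A sketch of that addressing with made-up numbers:

#include <cstddef>
#include <iostream>

int main()
{
    std::size_t batch_nelems = 2, n = 3, m = 4, k_blocks = 5;
    std::size_t slice = n * m * batch_nelems; // elements per k-slab copy

    std::size_t res_offset = 12; // base of the current batch inside one slice
    std::size_t elem = 7;        // flattened (i, j) within that matrix
    std::size_t block_s = 3;     // which k-slab this work-group handled

    std::size_t where = res_offset + elem + block_s * slice;
    std::cout << "partial stored at " << where << " of "
              << k_blocks * slice << " temporaries\n"; // 91 of 120
}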
wg_delta_m; + size_t s = block_s * wi_delta_k; + + const std::int64_t a_st0 = k; + const std::int64_t a_st1 = 1; + + const std::int64_t b_st0 = m; + const std::int64_t b_st1 = 1; + + const std::int64_t c_st0 = m; + const std::int64_t c_st1 = 1; + + size_t lws = it.get_local_range(0); + + for (size_t vid = lid; vid < local_A_block.size(); vid += lws) { + size_t v_i = vid / wi_delta_k; // 0<= v_i < wg_delta_n * wi_delta_n + size_t v_s = vid - v_i * wi_delta_k; // 0<= v_s < wi_delta_k + + size_t g_i = i + v_i; + size_t g_s = s + v_s; + + local_A_block[vid] = + (g_i < n && g_s < k) + ? static_cast( + lhs[lhs_offset + + lhs_indexer(g_i * a_st0 + g_s * a_st1)]) + : resT(0); + } + + for (size_t vid = lid; vid < local_B_block.size(); vid += lws) { + size_t v_j = vid / wi_delta_k; // 0<= v_i < wg_delta_m + size_t v_s = vid - v_j * wi_delta_k; // 0<= v_s < wi_delta_k + + size_t g_j0 = j + v_j; + size_t g_s = s + v_s; + + resT val = (g_j0 < m && g_s < k) + ? static_cast( + rhs[rhs_offset + + rhs_indexer(g_s * b_st0 + g_j0 * b_st1)]) + : resT(0); + + local_B_block[vid] = val; + } + + it.barrier(sycl::access::fence_space::local_space); + + i += local_i * wi_delta_n; + j += local_j; + + size_t a_offset = local_i * wi_delta_k * wi_delta_n; + size_t b_offset = local_j * wi_delta_k; + + constexpr resT identity_(0); + for (std::uint8_t private_i = 0; private_i < wi_delta_n; ++private_i) { + size_t a_pr_offset = private_i * wi_delta_k; + + resT local_sum(identity_); + for (size_t private_s = 0; private_s < wi_delta_k; ++private_s) { + local_sum = local_sum + + (local_A_block[a_offset + a_pr_offset + private_s] * + local_B_block[b_offset + private_s]); + } + + size_t gl_i = i + private_i; + + if (gl_i < n && j < m) { + res[res_offset + res_indexer(gl_i * c_st0 + j * c_st1) + + (block_s * n * m * batch_nelems)] = local_sum; + } + } + } +}; + +template +class GemmBatchNoAtomicFunctorThreadK +{ +private: + const lhsT *lhs = nullptr; + const rhsT *rhs = nullptr; + resT *res = nullptr; + LocAccT workspace; + LocAccT local_B_block; + size_t n = 0; + size_t n_blocks = 0; + size_t delta_n = 0; + size_t k = 0; + size_t k_blocks = 0; + size_t delta_k = 0; + size_t n_wi = 0; + size_t m = 0; + size_t batch_nelems = 0; + BatchDimsIndexerT batch_indexer; + OuterInnerDimsIndexerT lhs_indexer; + OuterInnerDimsIndexerT rhs_indexer; + ResIndexerT res_indexer; + +public: + GemmBatchNoAtomicFunctorThreadK(const lhsT *lhs_, + const rhsT *rhs_, + resT *res_, + LocAccT workspace_, + LocAccT local_B_block_, + size_t n_, + size_t n_blocks_, + size_t delta_n_, + size_t k_, + size_t k_blocks_, + size_t delta_k_, + size_t n_wi_, + size_t m_, + size_t batch_nelems_, + BatchDimsIndexerT batch_indexer_, + OuterInnerDimsIndexerT lhs_indexer_, + OuterInnerDimsIndexerT rhs_indexer_, + ResIndexerT res_indexer_) + : lhs(lhs_), rhs(rhs_), res(res_), workspace(workspace_), + local_B_block(local_B_block_), n(n_), n_blocks(n_blocks_), + delta_n(delta_n_), k(k_), k_blocks(k_blocks_), delta_k(delta_k_), + n_wi(n_wi_), m(m_), batch_nelems(batch_nelems_), + batch_indexer(batch_indexer_), lhs_indexer(lhs_indexer_), + rhs_indexer(rhs_indexer_), res_indexer(res_indexer_) + { + } + + void operator()(sycl::nd_item<1> it) const + { + size_t m_id = + it.get_global_linear_id() / (it.get_global_range(0) / batch_nelems); + size_t gr_id = + it.get_group_linear_id() % (it.get_group_range(0) / batch_nelems); + size_t lid = it.get_local_linear_id(); + + const auto &three_offsets_ = + batch_indexer(static_cast(m_id)); + const auto &lhs_offset = 
three_offsets_.get_first_offset(); + const auto &rhs_offset = three_offsets_.get_second_offset(); + const auto &res_offset = three_offsets_.get_third_offset(); + + // lift gr_id -> (block_i, block_j, block_s) + // block_i moves fastest, then block_s, then block_j + + size_t block_j = + gr_id / (n_blocks * k_blocks); // 0 <= block_j < m_blocks + size_t block_r = + gr_id - block_j * (n_blocks * + k_blocks); // 0 <= block_r < n_blocks * k_blocks + size_t block_s = block_r / n_blocks; // 0 <= block_s < k_blocks + size_t block_i = + block_r - block_s * n_blocks; // 0 <= block_i < n_blocks + + size_t local_i = lid / (delta_k); // 0 <= local_i < delta_n + size_t local_s = lid - local_i * (delta_k); // 0 <= local_s < delta_k + + size_t i = block_i * delta_n + local_i; + size_t j = m_groups * block_j; + size_t s = block_s * delta_k * n_wi + local_s; + + constexpr resT identity_ = resT(0); + if (local_i == 0) { + for (size_t q = 0; q < n_wi * delta_k; q += delta_k) { + size_t sq = s + q; + size_t sqmj = sq * m + j; + local_B_block[local_s + q] = + (sq < k && j < m) + ? static_cast(rhs[rhs_offset + rhs_indexer(sqmj)]) + : identity_; + } + } + + it.barrier(sycl::access::fence_space::local_space); + + size_t t_shift = block_s * delta_k * n_wi; + size_t global_s_offset = i * k + t_shift; + + sycl::vec private_sum(identity_); + for (size_t t = local_s; t < local_B_block.size(); t += delta_k) { + if (t + t_shift < k) { + private_sum += + (static_cast( + lhs[lhs_offset + lhs_indexer(global_s_offset + t)]) * + local_B_block[t]); + } + } + + size_t workspace_i_shift = local_i * delta_k; + workspace[workspace_i_shift + local_s] = private_sum; + + it.barrier(sycl::access::fence_space::local_space); + + if (local_s == 0 && i < n) { + sycl::vec local_sum(workspace[workspace_i_shift]); + for (size_t t = 1; t < delta_k; ++t) { + local_sum += workspace[workspace_i_shift + t]; + } + + res[res_offset + res_indexer(i * m + j) + + (block_s * n * m * batch_nelems)] = local_sum[0]; + + if (j + 1 < m) { + res[res_offset + res_indexer(i * m + j + 1) + + (block_s * n * m * batch_nelems)] = local_sum[1]; + } + } + } +}; + +template +class GemmBatchNoAtomicFunctorThreadK +{ +private: + const lhsT *lhs = nullptr; + const rhsT *rhs = nullptr; + resT *res = nullptr; + LocAccT workspace; + LocAccT local_B_block; + size_t n = 0; + size_t n_blocks = 0; + size_t delta_n = 0; + size_t k = 0; + size_t k_blocks = 0; + size_t delta_k = 0; + size_t n_wi = 0; + size_t m = 0; + size_t batch_nelems = 0; + BatchDimsIndexerT batch_indexer; + OuterInnerDimsIndexerT lhs_indexer; + OuterInnerDimsIndexerT rhs_indexer; + ResIndexerT res_indexer; + +public: + GemmBatchNoAtomicFunctorThreadK(const lhsT *lhs_, + const rhsT *rhs_, + resT *res_, + LocAccT workspace_, + LocAccT local_B_block_, + size_t n_, + size_t n_blocks_, + size_t delta_n_, + size_t k_, + size_t k_blocks_, + size_t delta_k_, + size_t n_wi_, + size_t m_, + size_t batch_nelems_, + BatchDimsIndexerT batch_indexer_, + OuterInnerDimsIndexerT lhs_indexer_, + OuterInnerDimsIndexerT rhs_indexer_, + ResIndexerT res_indexer_) + : lhs(lhs_), rhs(rhs_), res(res_), workspace(workspace_), + local_B_block(local_B_block_), n(n_), n_blocks(n_blocks_), + delta_n(delta_n_), k(k_), k_blocks(k_blocks_), delta_k(delta_k_), + n_wi(n_wi_), m(m_), batch_nelems(batch_nelems_), + batch_indexer(batch_indexer_), lhs_indexer(lhs_indexer_), + rhs_indexer(rhs_indexer_), res_indexer(res_indexer_) + { + } + + void operator()(sycl::nd_item<1> it) const + { + size_t m_id = + it.get_global_linear_id() / 
(it.get_global_range(0) / batch_nelems); + size_t gr_id = + it.get_group_linear_id() % (it.get_group_range(0) / batch_nelems); + size_t lid = it.get_local_linear_id(); + + const auto &three_offsets_ = + batch_indexer(static_cast(m_id)); + const auto &lhs_offset = three_offsets_.get_first_offset(); + const auto &rhs_offset = three_offsets_.get_second_offset(); + const auto &res_offset = three_offsets_.get_third_offset(); + + // lift gr_id -> (block_i, block_j, block_s) + // block_i moves fastest, then block_s, then block_j + + size_t block_j = + gr_id / (n_blocks * k_blocks); // 0 <= block_j < m_blocks + size_t block_r = + gr_id - block_j * (n_blocks * + k_blocks); // 0 <= block_r < n_blocks * k_blocks + size_t block_s = block_r / n_blocks; // 0 <= block_s < k_blocks + size_t block_i = + block_r - block_s * n_blocks; // 0 <= block_i < n_blocks + + size_t local_i = lid / (delta_k); // 0 <= local_i < delta_n + size_t local_s = lid - local_i * (delta_k); // 0 <= local_s < delta_k + + size_t i = block_i * delta_n + local_i; + size_t j = block_j; + size_t s = block_s * delta_k * n_wi + local_s; + + constexpr resT identity_ = resT(0); + if (local_i == 0) { + for (size_t q = 0; q < n_wi * delta_k; q += delta_k) { + size_t sq = s + q; + size_t sqmj = sq * m + j; + local_B_block[local_s + q] = + (sq < k && j < m) + ? static_cast(rhs[rhs_offset + rhs_indexer(sqmj)]) + : identity_; + } + } + + it.barrier(sycl::access::fence_space::local_space); + + size_t t_shift = block_s * delta_k * n_wi; + size_t global_s_offset = i * k + t_shift; + + resT private_sum(identity_); + for (size_t t = local_s; t < local_B_block.size(); t += delta_k) { + if (t + t_shift < k) { + private_sum += + (static_cast( + lhs[lhs_offset + lhs_indexer(global_s_offset + t)]) * + local_B_block[t]); + } + } + + size_t workspace_i_shift = local_i * delta_k; + workspace[workspace_i_shift + local_s] = private_sum; + + it.barrier(sycl::access::fence_space::local_space); + + if (local_s == 0 && i < n) { + resT local_sum(workspace[workspace_i_shift]); + for (size_t t = 1; t < delta_k; ++t) { + local_sum += workspace[workspace_i_shift + t]; + } + + res[res_offset + res_indexer(i * m + j) + + (block_s * n * m * batch_nelems)] = local_sum; + } + } +}; + +template +class gemm_batch_tree_k_krn; + +template +class gemm_batch_tree_nm_krn; + +template +sycl::event +gemm_batch_tree_impl(sycl::queue &exec_q, + const char *lhs_cp, + const char *rhs_cp, + char *res_cp, + size_t batch_nelems, + size_t n, + size_t k, + size_t m, + int batch_nd, + const py::ssize_t *batch_shape_strides, + py::ssize_t lhs_batch_offset, + py::ssize_t rhs_batch_offset, + py::ssize_t res_batch_offset, + int inner_nd, + int lhs_outer_nd, + const py::ssize_t *lhs_outer_inner_shapes_strides, + int rhs_outer_nd, + const py::ssize_t *rhs_outer_inner_shapes_strides, + int res_outer_nd, + const py::ssize_t *res_outer_shapes_strides, + const py::ssize_t *res_shape_strides, + std::vector const &depends = {}) +{ + const lhsTy *lhs_tp = reinterpret_cast(lhs_cp); + const rhsTy *rhs_tp = reinterpret_cast(rhs_cp); + resTy *res_tp = reinterpret_cast(res_cp); + + const sycl::device &dev = exec_q.get_device(); + const size_t local_mem_size = + dev.get_info(); + const size_t reserved_slm_size = 512; + + if ((k > n && k > m) || m == 1) { + size_t delta_k(4); + size_t n_wi(64); + size_t delta_n(4); + + using dpctl::tensor::type_utils::is_complex; + if constexpr (!is_complex::value) { + if (m == 1) { + constexpr int m_groups = 1; + gemm_detail::scale_gemm_k_parameters( + local_mem_size, 
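The branch that follows decides between a single-pass kernel and a temporary-plus-reduction scheme based on whether k fits into one delta_k * n_wi slab; otherwise ceil(k / (delta_k * n_wi)) partials per output element must be reduced afterwards. A sketch of that decision, with parameter values that are only illustrative (the real ones come out of scale_gemm_k_parameters):

#include <cstddef>
#include <iostream>

int main()
{
    std::size_t delta_k = 4, n_wi = 64;  // a slab of 256 k-values per group
    std::size_t slab = delta_k * n_wi;

    for (std::size_t k : {200ul, 1000ul}) {
        std::size_t reduction_nelems = (k + slab - 1) / slab;
        if (k <= slab)
            std::cout << "k=" << k << ": single pass, no temporary\n";
        else
            std::cout << "k=" << k << ": " << reduction_nelems
                      << " partials per output element\n"; // 4 for k=1000
    }
}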
reserved_slm_size, + delta_k, // modified by reference + n_wi, // modified by reference + delta_n // modified by reference + ); + + if (k <= (delta_k * n_wi)) { + sycl::event gemm_ev = exec_q.submit([&](sycl::handler + &cgh) { + cgh.depends_on(depends); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::StridedIndexer; + OuterInnerDimsIndexerT lhs_indexer( + inner_nd + lhs_outer_nd, 0, + lhs_outer_inner_shapes_strides); + OuterInnerDimsIndexerT rhs_indexer( + inner_nd + rhs_outer_nd, 0, + rhs_outer_inner_shapes_strides); + OuterInnerDimsIndexerT res_indexer( + res_outer_nd, 0, res_outer_shapes_strides); + using BatchDimsIndexerT = dpctl::tensor::offset_utils:: + ThreeOffsets_StridedIndexer; + BatchDimsIndexerT batch_indexer( + batch_nd, lhs_batch_offset, rhs_batch_offset, + res_batch_offset, batch_shape_strides); + + size_t n_blocks = (n + delta_n - 1) / delta_n; + size_t k_blocks = + (k + n_wi * delta_k - 1) / (n_wi * delta_k); + size_t m_blocks = (m + m_groups - 1) / m_groups; + + size_t lws = delta_n * delta_k; + + auto gRange = sycl::range<1>(batch_nelems * n_blocks * + m_blocks * k_blocks * lws); + auto lRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gRange, lRange); + + using LocAccT = sycl::local_accessor; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + + using KernelName = class gemm_batch_tree_k_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, BatchDimsIndexerT, + m_groups>; + + cgh.parallel_for( + ndRange, + GemmBatchNoAtomicFunctorThreadK< + lhsTy, rhsTy, resTy, LocAccT, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, + BatchDimsIndexerT, m_groups>( + lhs_tp, rhs_tp, res_tp, workspace, + local_B_block, n, n_blocks, delta_n, k, + k_blocks, delta_k, n_wi, m, batch_nelems, + batch_indexer, lhs_indexer, rhs_indexer, + res_indexer)); + }); + return gemm_ev; + } + else { + using ReductionOpT = sycl::plus; + constexpr resTy identity_val = + sycl::known_identity::value; + + size_t iter_nelems = batch_nelems * n * m; + size_t reduction_nelems = + (k + delta_k * n_wi - 1) / (delta_k * n_wi); + + // more than one work-group is needed, requires a temporary + // delta_k * n_wi elements processed along k, so if more to + // process use multiple + const auto &sg_sizes = + dev.get_info(); + size_t wg = + choose_workgroup_size<4>(reduction_nelems, sg_sizes); + + if (reduction_nelems < wg) { + resTy *tmp = sycl::malloc_device( + iter_nelems * reduction_nelems, exec_q); + sycl::event gemm_ev = exec_q.submit([&](sycl::handler + &cgh) { + cgh.depends_on(depends); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::StridedIndexer; + using TmpIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + OuterInnerDimsIndexerT lhs_indexer( + inner_nd + lhs_outer_nd, 0, + lhs_outer_inner_shapes_strides); + OuterInnerDimsIndexerT rhs_indexer( + inner_nd + rhs_outer_nd, 0, + rhs_outer_inner_shapes_strides); + TmpIndexerT res_indexer{}; + using dpctl::tensor::offset_utils::StridedIndexer; + using dpctl::tensor::offset_utils:: + UnpackedStridedIndexer; + using dpctl::tensor::offset_utils::Strided1DIndexer; + using dpctl::tensor::offset_utils:: + ThreeOffsets_CombinedIndexer; + using BatchDimsIndexerT = + ThreeOffsets_CombinedIndexer< + StridedIndexer, UnpackedStridedIndexer, + Strided1DIndexer>; + StridedIndexer lhs_batch_indexer( + batch_nd, lhs_batch_offset, + batch_shape_strides); + UnpackedStridedIndexer rhs_batch_indexer( + batch_nd, rhs_batch_offset, batch_shape_strides, + 
batch_shape_strides + 2 * batch_nd); + Strided1DIndexer tmp_batch_indexer( + 0, static_cast(batch_nelems), + n * m); + BatchDimsIndexerT batch_indexer(lhs_batch_indexer, + rhs_batch_indexer, + tmp_batch_indexer); + + size_t n_blocks = (n + delta_n - 1) / delta_n; + size_t k_blocks = + (k + n_wi * delta_k - 1) / (n_wi * delta_k); + size_t m_blocks = (m + m_groups - 1) / m_groups; + + size_t lws = delta_n * delta_k; + + auto gRange = + sycl::range<1>(batch_nelems * n_blocks * + m_blocks * k_blocks * lws); + auto lRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gRange, lRange); + + using LocAccT = sycl::local_accessor; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + + using KernelName = class gemm_batch_tree_k_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + TmpIndexerT, BatchDimsIndexerT, m_groups>; + cgh.parallel_for( + ndRange, + GemmBatchNoAtomicFunctorThreadK< + lhsTy, rhsTy, resTy, LocAccT, + OuterInnerDimsIndexerT, TmpIndexerT, + BatchDimsIndexerT, m_groups>( + lhs_tp, rhs_tp, tmp, workspace, + local_B_block, n, n_blocks, delta_n, k, + k_blocks, delta_k, n_wi, m, batch_nelems, + batch_indexer, lhs_indexer, rhs_indexer, + res_indexer)); + }); + sycl::event red_ev = exec_q.submit([&](sycl::handler + &cgh) { + cgh.depends_on(gemm_ev); + + using NoOpIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + using ResIndexerT = + dpctl::tensor::offset_utils::StridedIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils:: + TwoOffsets_CombinedIndexer; + using ReductionIndexerT = + dpctl::tensor::offset_utils::Strided1DIndexer; + + ResIndexerT res_iter_indexer{ + batch_nd + res_outer_nd, + static_cast(res_batch_offset), + res_shape_strides}; + InputOutputIterIndexerT in_out_iter_indexer{ + NoOpIndexerT{}, res_iter_indexer}; + ReductionIndexerT reduction_indexer{ + 0, static_cast(reduction_nelems), + static_cast(iter_nelems)}; + + sycl::range<1> iter_range{iter_nelems}; + + cgh.parallel_for< + class gemm_reduction_seq_strided_krn< + resTy, resTy, ReductionOpT, + InputOutputIterIndexerT, + ReductionIndexerT>>( + iter_range, + SequentialReduction( + tmp, res_tp, ReductionOpT(), identity_val, + in_out_iter_indexer, reduction_indexer, + reduction_nelems)); + }); + sycl::event cleanup_host_task_event = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(red_ev); + const sycl::context &ctx = exec_q.get_context(); + + cgh.host_task( + [ctx, tmp] { sycl::free(tmp, ctx); }); + }); + return cleanup_host_task_event; + } + + // max_max_wg prevents running out of resources on CPU + constexpr size_t max_max_wg = 2048; + size_t max_wg = std::min( + max_max_wg, + dev.get_info< + sycl::info::device::max_work_group_size>() / + 2); + + constexpr size_t preferred_reductions_per_wi = 4; + size_t reductions_per_wi(preferred_reductions_per_wi); + + size_t reduction_groups = + (reduction_nelems + preferred_reductions_per_wi * wg - + 1) / + (preferred_reductions_per_wi * wg); + assert(reduction_groups > 1); + + resTy *partially_reduced_tmp = sycl::malloc_device( + iter_nelems * + (/* temp */ reduction_nelems + + /* first reduction temp */ reduction_groups), + exec_q); + resTy *partially_reduced_tmp2 = nullptr; + + if (partially_reduced_tmp == nullptr) { + throw std::runtime_error( + "Unable to allocate device_memory"); + } + else { + partially_reduced_tmp2 = partially_reduced_tmp + + reduction_nelems * iter_nelems; + } + + sycl::event gemm_ev = exec_q.submit([&](sycl::handler + &cgh) { + cgh.depends_on(depends); + + using 
OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::StridedIndexer; + using TmpIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + OuterInnerDimsIndexerT lhs_indexer( + inner_nd + lhs_outer_nd, 0, + lhs_outer_inner_shapes_strides); + OuterInnerDimsIndexerT rhs_indexer( + inner_nd + rhs_outer_nd, 0, + rhs_outer_inner_shapes_strides); + TmpIndexerT res_indexer{}; + using dpctl::tensor::offset_utils::StridedIndexer; + using dpctl::tensor::offset_utils::Strided1DIndexer; + using dpctl::tensor::offset_utils:: + ThreeOffsets_CombinedIndexer; + using BatchDimsIndexerT = ThreeOffsets_CombinedIndexer< + StridedIndexer, StridedIndexer, Strided1DIndexer>; + StridedIndexer lhs_batch_indexer( + batch_nd, lhs_batch_offset, batch_shape_strides); + StridedIndexer rhs_batch_indexer( + batch_nd, rhs_batch_offset, + batch_shape_strides + 2 * batch_nd); + Strided1DIndexer tmp_batch_indexer( + 0, static_cast(batch_nelems), n * m); + BatchDimsIndexerT batch_indexer(lhs_batch_indexer, + rhs_batch_indexer, + tmp_batch_indexer); + + size_t n_blocks = (n + delta_n - 1) / delta_n; + size_t k_blocks = + (k + n_wi * delta_k - 1) / (n_wi * delta_k); + constexpr int m_groups = 1; + size_t m_blocks = (m + m_groups - 1) / m_groups; + + size_t lws = delta_n * delta_k; + + auto gRange = sycl::range<1>(batch_nelems * n_blocks * + m_blocks * k_blocks * lws); + auto lRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gRange, lRange); + + using LocAccT = sycl::local_accessor; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + + using KernelName = class gemm_batch_tree_k_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + TmpIndexerT, BatchDimsIndexerT, m_groups>; + cgh.parallel_for( + ndRange, + GemmBatchNoAtomicFunctorThreadK< + lhsTy, rhsTy, resTy, LocAccT, + OuterInnerDimsIndexerT, TmpIndexerT, + BatchDimsIndexerT, m_groups>( + lhs_tp, rhs_tp, partially_reduced_tmp, + workspace, local_B_block, n, n_blocks, delta_n, + k, k_blocks, delta_k, n_wi, m, batch_nelems, + batch_indexer, lhs_indexer, rhs_indexer, + res_indexer)); + }); + + sycl::event red_ev = + tree_reduction_for_gemm( + exec_q, partially_reduced_tmp, + partially_reduced_tmp2, res_tp, identity_val, + iter_nelems, reduction_nelems, reduction_groups, wg, + max_wg, preferred_reductions_per_wi, + reductions_per_wi, batch_nd + res_outer_nd, + res_batch_offset, res_shape_strides, {gemm_ev}); + + sycl::event cleanup_host_task_event = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(red_ev); + const sycl::context &ctx = exec_q.get_context(); + + cgh.host_task([ctx, partially_reduced_tmp] { + sycl::free(partially_reduced_tmp, ctx); + }); + }); + + return cleanup_host_task_event; + } + } + else { + constexpr int m_groups = 2; + + gemm_detail::scale_gemm_k_parameters( + local_mem_size, reserved_slm_size, + delta_k, // modified by reference + n_wi, // modified by reference + delta_n // modified by reference + ); + + if (k <= (delta_k * n_wi)) { + sycl::event gemm_ev = exec_q.submit([&](sycl::handler + &cgh) { + cgh.depends_on(depends); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::StridedIndexer; + OuterInnerDimsIndexerT lhs_indexer( + inner_nd + lhs_outer_nd, 0, + lhs_outer_inner_shapes_strides); + OuterInnerDimsIndexerT rhs_indexer( + inner_nd + rhs_outer_nd, 0, + rhs_outer_inner_shapes_strides); + OuterInnerDimsIndexerT res_indexer( + res_outer_nd, 0, res_outer_shapes_strides); + using BatchDimsIndexerT = dpctl::tensor::offset_utils:: + ThreeOffsets_StridedIndexer; + 
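+ // ThreeOffsets_StridedIndexer turns the flat batch id into (lhs, rhs, res)
+ // offsets in one pass; batch_shape_strides is expected to pack the common
+ // batch shape followed by the lhs, rhs and res batch strides
+ // (4 * batch_nd entries in total).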
BatchDimsIndexerT batch_indexer( + batch_nd, lhs_batch_offset, rhs_batch_offset, + res_batch_offset, batch_shape_strides); + + size_t n_blocks = (n + delta_n - 1) / delta_n; + size_t k_blocks = + (k + n_wi * delta_k - 1) / (n_wi * delta_k); + size_t m_blocks = (m + m_groups - 1) / m_groups; + + size_t lws = delta_n * delta_k; + + auto gRange = sycl::range<1>(batch_nelems * n_blocks * + m_blocks * k_blocks * lws); + auto lRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gRange, lRange); + + using LocAccT = + sycl::local_accessor, 1>; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + + using KernelName = class gemm_batch_tree_k_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, BatchDimsIndexerT, + m_groups>; + + cgh.parallel_for( + ndRange, + GemmBatchNoAtomicFunctorThreadK< + lhsTy, rhsTy, resTy, LocAccT, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, + BatchDimsIndexerT, m_groups>( + lhs_tp, rhs_tp, res_tp, workspace, + local_B_block, n, n_blocks, delta_n, k, + k_blocks, delta_k, n_wi, m, batch_nelems, + batch_indexer, lhs_indexer, rhs_indexer, + res_indexer)); + }); + return gemm_ev; + } + else { + using ReductionOpT = sycl::plus; + constexpr resTy identity_val = + sycl::known_identity::value; + + size_t iter_nelems = batch_nelems * n * m; + size_t reduction_nelems = + (k + delta_k * n_wi - 1) / (delta_k * n_wi); + + // more than one work-group is needed, requires a temporary + // delta_k * n_wi elements processed along k, so if more to + // process use multiple + const auto &sg_sizes = + dev.get_info(); + size_t wg = + choose_workgroup_size<4>(reduction_nelems, sg_sizes); + + if (reduction_nelems < wg) { + resTy *tmp = sycl::malloc_device( + iter_nelems * reduction_nelems, exec_q); + sycl::event gemm_ev = exec_q.submit([&](sycl::handler + &cgh) { + cgh.depends_on(depends); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::StridedIndexer; + using TmpIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + OuterInnerDimsIndexerT lhs_indexer( + inner_nd + lhs_outer_nd, 0, + lhs_outer_inner_shapes_strides); + OuterInnerDimsIndexerT rhs_indexer( + inner_nd + rhs_outer_nd, 0, + rhs_outer_inner_shapes_strides); + TmpIndexerT res_indexer{}; + using dpctl::tensor::offset_utils::StridedIndexer; + using dpctl::tensor::offset_utils:: + UnpackedStridedIndexer; + using dpctl::tensor::offset_utils::Strided1DIndexer; + using dpctl::tensor::offset_utils:: + ThreeOffsets_CombinedIndexer; + using BatchDimsIndexerT = + ThreeOffsets_CombinedIndexer< + StridedIndexer, UnpackedStridedIndexer, + Strided1DIndexer>; + StridedIndexer lhs_batch_indexer( + batch_nd, lhs_batch_offset, + batch_shape_strides); + UnpackedStridedIndexer rhs_batch_indexer( + batch_nd, rhs_batch_offset, batch_shape_strides, + batch_shape_strides + 2 * batch_nd); + Strided1DIndexer tmp_batch_indexer( + 0, static_cast(batch_nelems), + n * m); + BatchDimsIndexerT batch_indexer(lhs_batch_indexer, + rhs_batch_indexer, + tmp_batch_indexer); + + size_t n_blocks = (n + delta_n - 1) / delta_n; + size_t k_blocks = + (k + n_wi * delta_k - 1) / (n_wi * delta_k); + size_t m_blocks = (m + m_groups - 1) / m_groups; + + size_t lws = delta_n * delta_k; + + auto gRange = + sycl::range<1>(batch_nelems * n_blocks * + m_blocks * k_blocks * lws); + auto lRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gRange, lRange); + + using LocAccT = + sycl::local_accessor, + 1>; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT 
workspace(delta_n * delta_k, cgh); + + using KernelName = class gemm_batch_tree_k_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + TmpIndexerT, BatchDimsIndexerT, m_groups>; + cgh.parallel_for( + ndRange, + GemmBatchNoAtomicFunctorThreadK< + lhsTy, rhsTy, resTy, LocAccT, + OuterInnerDimsIndexerT, TmpIndexerT, + BatchDimsIndexerT, m_groups>( + lhs_tp, rhs_tp, tmp, workspace, + local_B_block, n, n_blocks, delta_n, k, + k_blocks, delta_k, n_wi, m, batch_nelems, + batch_indexer, lhs_indexer, rhs_indexer, + res_indexer)); + }); + sycl::event red_ev = exec_q.submit([&](sycl::handler + &cgh) { + cgh.depends_on(gemm_ev); + + using NoOpIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + using ResIndexerT = + dpctl::tensor::offset_utils::StridedIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils:: + TwoOffsets_CombinedIndexer; + using ReductionIndexerT = + dpctl::tensor::offset_utils::Strided1DIndexer; + + ResIndexerT res_iter_indexer{ + batch_nd + res_outer_nd, + static_cast(res_batch_offset), + res_shape_strides}; + InputOutputIterIndexerT in_out_iter_indexer{ + NoOpIndexerT{}, res_iter_indexer}; + ReductionIndexerT reduction_indexer{ + 0, static_cast(reduction_nelems), + static_cast(iter_nelems)}; + + sycl::range<1> iter_range{iter_nelems}; + + cgh.parallel_for< + class gemm_reduction_seq_strided_krn< + resTy, resTy, ReductionOpT, + InputOutputIterIndexerT, + ReductionIndexerT>>( + iter_range, + SequentialReduction( + tmp, res_tp, ReductionOpT(), identity_val, + in_out_iter_indexer, reduction_indexer, + reduction_nelems)); + }); + sycl::event cleanup_host_task_event = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(red_ev); + const sycl::context &ctx = exec_q.get_context(); + + cgh.host_task( + [ctx, tmp] { sycl::free(tmp, ctx); }); + }); + return cleanup_host_task_event; + } + + // max_max_wg prevents running out of resources on CPU + constexpr size_t max_max_wg = 2048; + size_t max_wg = std::min( + max_max_wg, + dev.get_info< + sycl::info::device::max_work_group_size>() / + 2); + + constexpr size_t preferred_reductions_per_wi = 4; + size_t reductions_per_wi(preferred_reductions_per_wi); + + size_t reduction_groups = + (reduction_nelems + preferred_reductions_per_wi * wg - + 1) / + (preferred_reductions_per_wi * wg); + assert(reduction_groups > 1); + + resTy *partially_reduced_tmp = sycl::malloc_device( + iter_nelems * + (/* temp */ reduction_nelems + + /* first reduction temp */ reduction_groups), + exec_q); + resTy *partially_reduced_tmp2 = nullptr; + + if (partially_reduced_tmp == nullptr) { + throw std::runtime_error( + "Unable to allocate device_memory"); + } + else { + partially_reduced_tmp2 = partially_reduced_tmp + + reduction_nelems * iter_nelems; + } + + sycl::event gemm_ev = exec_q.submit([&](sycl::handler + &cgh) { + cgh.depends_on(depends); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::StridedIndexer; + using TmpIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + OuterInnerDimsIndexerT lhs_indexer( + inner_nd + lhs_outer_nd, 0, + lhs_outer_inner_shapes_strides); + OuterInnerDimsIndexerT rhs_indexer( + inner_nd + rhs_outer_nd, 0, + rhs_outer_inner_shapes_strides); + TmpIndexerT res_indexer{}; + using dpctl::tensor::offset_utils::StridedIndexer; + using dpctl::tensor::offset_utils:: + UnpackedStridedIndexer; + using dpctl::tensor::offset_utils::Strided1DIndexer; + using dpctl::tensor::offset_utils:: + ThreeOffsets_CombinedIndexer; + using BatchDimsIndexerT = + ThreeOffsets_CombinedIndexer; + StridedIndexer 
lhs_batch_indexer( + batch_nd, lhs_batch_offset, batch_shape_strides); + UnpackedStridedIndexer rhs_batch_indexer( + batch_nd, rhs_batch_offset, batch_shape_strides, + batch_shape_strides + 2 * batch_nd); + Strided1DIndexer tmp_batch_indexer( + 0, static_cast(batch_nelems), n * m); + BatchDimsIndexerT batch_indexer(lhs_batch_indexer, + rhs_batch_indexer, + tmp_batch_indexer); + + size_t n_blocks = (n + delta_n - 1) / delta_n; + size_t k_blocks = + (k + n_wi * delta_k - 1) / (n_wi * delta_k); + size_t m_blocks = (m + m_groups - 1) / m_groups; + + size_t lws = delta_n * delta_k; + + auto gRange = sycl::range<1>(batch_nelems * n_blocks * + m_blocks * k_blocks * lws); + auto lRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gRange, lRange); + + using LocAccT = + sycl::local_accessor, 1>; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + + using KernelName = class gemm_batch_tree_k_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + TmpIndexerT, BatchDimsIndexerT, m_groups>; + cgh.parallel_for( + ndRange, + GemmBatchNoAtomicFunctorThreadK< + lhsTy, rhsTy, resTy, LocAccT, + OuterInnerDimsIndexerT, TmpIndexerT, + BatchDimsIndexerT, m_groups>( + lhs_tp, rhs_tp, partially_reduced_tmp, + workspace, local_B_block, n, n_blocks, delta_n, + k, k_blocks, delta_k, n_wi, m, batch_nelems, + batch_indexer, lhs_indexer, rhs_indexer, + res_indexer)); + }); + + sycl::event red_ev = + tree_reduction_for_gemm( + exec_q, partially_reduced_tmp, + partially_reduced_tmp2, res_tp, identity_val, + iter_nelems, reduction_nelems, reduction_groups, wg, + max_wg, preferred_reductions_per_wi, + reductions_per_wi, batch_nd + res_outer_nd, + res_batch_offset, res_shape_strides, {gemm_ev}); + + sycl::event cleanup_host_task_event = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(red_ev); + const sycl::context &ctx = exec_q.get_context(); + + cgh.host_task([ctx, partially_reduced_tmp] { + sycl::free(partially_reduced_tmp, ctx); + }); + }); + + return cleanup_host_task_event; + } + } + } + else { + constexpr int m_groups = 1; + + gemm_detail::scale_gemm_k_parameters( + local_mem_size, reserved_slm_size, + delta_k, // modified by reference + n_wi, // modified by reference + delta_n // modified by reference + ); + + // each group processes delta_k * n_wi + // items in a column, so no need for allocating + // temp memory if only one group is needed + if (k <= (delta_k * n_wi)) { + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::StridedIndexer; + OuterInnerDimsIndexerT lhs_indexer( + inner_nd + lhs_outer_nd, 0, + lhs_outer_inner_shapes_strides); + OuterInnerDimsIndexerT rhs_indexer( + inner_nd + rhs_outer_nd, 0, + rhs_outer_inner_shapes_strides); + OuterInnerDimsIndexerT res_indexer( + res_outer_nd, 0, res_outer_shapes_strides); + using BatchDimsIndexerT = dpctl::tensor::offset_utils:: + ThreeOffsets_StridedIndexer; + BatchDimsIndexerT batch_indexer( + batch_nd, lhs_batch_offset, rhs_batch_offset, + res_batch_offset, batch_shape_strides); + + size_t n_blocks = (n + delta_n - 1) / delta_n; + size_t k_blocks = + (k + n_wi * delta_k - 1) / (n_wi * delta_k); + constexpr int m_groups = 1; + size_t m_blocks = (m + m_groups - 1) / m_groups; + + size_t lws = delta_n * delta_k; + + auto gRange = sycl::range<1>(batch_nelems * n_blocks * + m_blocks * k_blocks * lws); + auto lRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gRange, lRange); + + using 
LocAccT = sycl::local_accessor; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + + using KernelName = class gemm_batch_tree_k_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, BatchDimsIndexerT, m_groups>; + + cgh.parallel_for( + ndRange, + GemmBatchNoAtomicFunctorThreadK< + lhsTy, rhsTy, resTy, LocAccT, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, + BatchDimsIndexerT, m_groups>( + lhs_tp, rhs_tp, res_tp, workspace, local_B_block, n, + n_blocks, delta_n, k, k_blocks, delta_k, n_wi, m, + batch_nelems, batch_indexer, lhs_indexer, + rhs_indexer, res_indexer)); + }); + return gemm_ev; + } + else { + using ReductionOpT = sycl::plus; + constexpr resTy identity_val = + sycl::known_identity::value; + + size_t iter_nelems = batch_nelems * n * m; + size_t reduction_nelems = + (k + delta_k * n_wi - 1) / (delta_k * n_wi); + + // more than one work-group is needed, requires a temporary + // delta_k * n_wi elements processed along k, so if more to + // process use multiple + const auto &sg_sizes = + dev.get_info(); + size_t wg = + choose_workgroup_size<4>(reduction_nelems, sg_sizes); + + if (reduction_nelems < wg) { + resTy *tmp = sycl::malloc_device( + iter_nelems * reduction_nelems, exec_q); + sycl::event gemm_ev = exec_q.submit([&](sycl::handler + &cgh) { + cgh.depends_on(depends); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::StridedIndexer; + using TmpIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + OuterInnerDimsIndexerT lhs_indexer( + inner_nd + lhs_outer_nd, 0, + lhs_outer_inner_shapes_strides); + OuterInnerDimsIndexerT rhs_indexer( + inner_nd + rhs_outer_nd, 0, + rhs_outer_inner_shapes_strides); + TmpIndexerT res_indexer{}; + using dpctl::tensor::offset_utils::StridedIndexer; + using dpctl::tensor::offset_utils:: + UnpackedStridedIndexer; + using dpctl::tensor::offset_utils::Strided1DIndexer; + using dpctl::tensor::offset_utils:: + ThreeOffsets_CombinedIndexer; + using BatchDimsIndexerT = + ThreeOffsets_CombinedIndexer; + StridedIndexer lhs_batch_indexer( + batch_nd, lhs_batch_offset, batch_shape_strides); + UnpackedStridedIndexer rhs_batch_indexer( + batch_nd, rhs_batch_offset, batch_shape_strides, + batch_shape_strides + 2 * batch_nd); + Strided1DIndexer tmp_batch_indexer( + 0, static_cast(batch_nelems), n * m); + BatchDimsIndexerT batch_indexer(lhs_batch_indexer, + rhs_batch_indexer, + tmp_batch_indexer); + + size_t n_blocks = (n + delta_n - 1) / delta_n; + size_t k_blocks = + (k + n_wi * delta_k - 1) / (n_wi * delta_k); + constexpr int m_groups = 1; + size_t m_blocks = (m + m_groups - 1) / m_groups; + + size_t lws = delta_n * delta_k; + + auto gRange = sycl::range<1>(batch_nelems * n_blocks * + m_blocks * k_blocks * lws); + auto lRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gRange, lRange); + + using LocAccT = sycl::local_accessor; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + + using KernelName = class gemm_batch_tree_k_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + TmpIndexerT, BatchDimsIndexerT, m_groups>; + cgh.parallel_for( + ndRange, + GemmBatchNoAtomicFunctorThreadK< + lhsTy, rhsTy, resTy, LocAccT, + OuterInnerDimsIndexerT, TmpIndexerT, + BatchDimsIndexerT, m_groups>( + lhs_tp, rhs_tp, tmp, workspace, local_B_block, + n, n_blocks, delta_n, k, k_blocks, delta_k, + n_wi, m, batch_nelems, batch_indexer, + lhs_indexer, rhs_indexer, res_indexer)); + }); + sycl::event red_ev = exec_q.submit([&](sycl::handler 
&cgh) { + cgh.depends_on(gemm_ev); + + using NoOpIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + using ResIndexerT = + dpctl::tensor::offset_utils::StridedIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils:: + TwoOffsets_CombinedIndexer; + using ReductionIndexerT = + dpctl::tensor::offset_utils::Strided1DIndexer; + + ResIndexerT res_iter_indexer{ + batch_nd + res_outer_nd, + static_cast(res_batch_offset), + res_shape_strides}; + InputOutputIterIndexerT in_out_iter_indexer{ + NoOpIndexerT{}, res_iter_indexer}; + ReductionIndexerT reduction_indexer{ + 0, static_cast(reduction_nelems), + static_cast(iter_nelems)}; + + sycl::range<1> iter_range{iter_nelems}; + + cgh.parallel_for>( + iter_range, + SequentialReduction( + tmp, res_tp, ReductionOpT(), identity_val, + in_out_iter_indexer, reduction_indexer, + reduction_nelems)); + }); + sycl::event cleanup_host_task_event = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(red_ev); + const sycl::context &ctx = exec_q.get_context(); + + cgh.host_task([ctx, tmp] { sycl::free(tmp, ctx); }); + }); + return cleanup_host_task_event; + } + + // max_max_wg prevents running out of resources on CPU + constexpr size_t max_max_wg = 2048; + size_t max_wg = std::min( + max_max_wg, + dev.get_info() / + 2); + + constexpr size_t preferred_reductions_per_wi = 4; + size_t reductions_per_wi(preferred_reductions_per_wi); + + size_t reduction_groups = + (reduction_nelems + preferred_reductions_per_wi * wg - 1) / + (preferred_reductions_per_wi * wg); + assert(reduction_groups > 1); + + resTy *partially_reduced_tmp = sycl::malloc_device( + iter_nelems * (/* temp */ reduction_nelems + + /* first reduction temp */ reduction_groups), + exec_q); + resTy *partially_reduced_tmp2 = nullptr; + + if (partially_reduced_tmp == nullptr) { + throw std::runtime_error( + "Unable to allocate device_memory"); + } + else { + partially_reduced_tmp2 = + partially_reduced_tmp + reduction_nelems * iter_nelems; + } + + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::StridedIndexer; + using TmpIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + OuterInnerDimsIndexerT lhs_indexer( + inner_nd + lhs_outer_nd, 0, + lhs_outer_inner_shapes_strides); + OuterInnerDimsIndexerT rhs_indexer( + inner_nd + rhs_outer_nd, 0, + rhs_outer_inner_shapes_strides); + TmpIndexerT res_indexer{}; + using dpctl::tensor::offset_utils::StridedIndexer; + using dpctl::tensor::offset_utils::UnpackedStridedIndexer; + using dpctl::tensor::offset_utils::Strided1DIndexer; + using dpctl::tensor::offset_utils:: + ThreeOffsets_CombinedIndexer; + using BatchDimsIndexerT = + ThreeOffsets_CombinedIndexer; + StridedIndexer lhs_batch_indexer(batch_nd, lhs_batch_offset, + batch_shape_strides); + UnpackedStridedIndexer rhs_batch_indexer( + batch_nd, rhs_batch_offset, batch_shape_strides, + batch_shape_strides + 2 * batch_nd); + Strided1DIndexer tmp_batch_indexer( + 0, static_cast(batch_nelems), n * m); + BatchDimsIndexerT batch_indexer(lhs_batch_indexer, + rhs_batch_indexer, + tmp_batch_indexer); + + size_t n_blocks = (n + delta_n - 1) / delta_n; + size_t k_blocks = + (k + n_wi * delta_k - 1) / (n_wi * delta_k); + constexpr int m_groups = 1; + size_t m_blocks = (m + m_groups - 1) / m_groups; + + size_t lws = delta_n * delta_k; + + auto gRange = sycl::range<1>(batch_nelems * n_blocks * + m_blocks * k_blocks * lws); + auto lRange = sycl::range<1>(lws); + + auto ndRange = 
sycl::nd_range<1>(gRange, lRange); + + using LocAccT = sycl::local_accessor; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + + using KernelName = class gemm_batch_tree_k_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + TmpIndexerT, BatchDimsIndexerT, m_groups>; + cgh.parallel_for( + ndRange, + GemmBatchNoAtomicFunctorThreadK< + lhsTy, rhsTy, resTy, LocAccT, + OuterInnerDimsIndexerT, TmpIndexerT, + BatchDimsIndexerT, m_groups>( + lhs_tp, rhs_tp, partially_reduced_tmp, workspace, + local_B_block, n, n_blocks, delta_n, k, k_blocks, + delta_k, n_wi, m, batch_nelems, batch_indexer, + lhs_indexer, rhs_indexer, res_indexer)); + }); + + sycl::event red_ev = + tree_reduction_for_gemm( + exec_q, partially_reduced_tmp, partially_reduced_tmp2, + res_tp, identity_val, iter_nelems, reduction_nelems, + reduction_groups, wg, max_wg, + preferred_reductions_per_wi, reductions_per_wi, + batch_nd + res_outer_nd, res_batch_offset, + res_shape_strides, {gemm_ev}); + + sycl::event cleanup_host_task_event = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(red_ev); + const sycl::context &ctx = exec_q.get_context(); + + cgh.host_task([ctx, partially_reduced_tmp] { + sycl::free(partially_reduced_tmp, ctx); + }); + }); + + return cleanup_host_task_event; + } + } + } + else { // m > 1, n > k or m > k + constexpr int wi_delta_n = 2; + size_t wg_delta_n(16); // rows of A processed in WG + size_t wg_delta_m(16); // rows of B processed in WG + size_t wi_delta_k(64); // Elements in K dimension processed by WI + + using dpctl::tensor::type_utils::is_complex; + if constexpr (!is_complex::value) { + constexpr int wi_delta_m = 4; + + gemm_detail::scale_gemm_nm_parameters( + local_mem_size, reserved_slm_size, wi_delta_n, + wi_delta_k, // modified by reference + wg_delta_n, // modified by reference + wg_delta_m // modified by reference + ); + + // each group processes delta_k * n_wi + // items in a column, so no need for allocating + // temp memory if only one group is needed + if (k <= wi_delta_k) { + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::StridedIndexer; + OuterInnerDimsIndexerT lhs_indexer( + inner_nd + lhs_outer_nd, 0, + lhs_outer_inner_shapes_strides); + OuterInnerDimsIndexerT rhs_indexer( + inner_nd + rhs_outer_nd, 0, + rhs_outer_inner_shapes_strides); + OuterInnerDimsIndexerT res_indexer( + res_outer_nd, 0, res_outer_shapes_strides); + using BatchDimsIndexerT = dpctl::tensor::offset_utils:: + ThreeOffsets_StridedIndexer; + BatchDimsIndexerT batch_indexer( + batch_nd, lhs_batch_offset, rhs_batch_offset, + res_batch_offset, batch_shape_strides); + + size_t lws = wg_delta_n * wg_delta_m; + + size_t n_blocks = ((n + wi_delta_n * wg_delta_n - 1) / + (wi_delta_n * wg_delta_n)); + size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); + size_t m_blocks = ((m + wi_delta_m * wg_delta_m - 1) / + (wi_delta_m * wg_delta_m)); + + auto gwsRange = sycl::range<1>(batch_nelems * n_blocks * + m_blocks * k_blocks * lws); + auto lwsRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); + + using LocAccT1 = sycl::local_accessor; + LocAccT1 local_A_block( + sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), + cgh); + using LocAccT2 = + sycl::local_accessor, 1>; + LocAccT2 local_B_block( + sycl::range<1>(wi_delta_k * wg_delta_m), cgh); + + using KernelName = class gemm_batch_tree_nm_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + 
OuterInnerDimsIndexerT, BatchDimsIndexerT, wi_delta_m>; + cgh.parallel_for( + ndRange, + GemmBatchNoAtomicFunctorThreadNM< + lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, + BatchDimsIndexerT, wi_delta_n, wi_delta_m>( + lhs_tp, rhs_tp, res_tp, local_A_block, + local_B_block, n, wg_delta_n, k, k_blocks, + wi_delta_k, m, m_blocks, wg_delta_m, batch_nelems, + batch_indexer, lhs_indexer, rhs_indexer, + res_indexer)); + }); + return gemm_ev; + } + else { + using ReductionOpT = sycl::plus; + constexpr resTy identity_val = + sycl::known_identity::value; + size_t iter_nelems = batch_nelems * n * m; + size_t reduction_nelems = (k + wi_delta_k - 1) / wi_delta_k; + + // more than one work-group is needed, requires a temporary + // delta_k * n_wi elements processed along k, so if more to + // process use multiple + const auto &sg_sizes = + dev.get_info(); + size_t wg = + choose_workgroup_size<4>(reduction_nelems, sg_sizes); + + if (reduction_nelems < wg) { + resTy *tmp = sycl::malloc_device( + iter_nelems * reduction_nelems, exec_q); + sycl::event gemm_ev = exec_q.submit([&](sycl::handler + &cgh) { + cgh.depends_on(depends); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::StridedIndexer; + using TmpIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + OuterInnerDimsIndexerT lhs_indexer( + inner_nd + lhs_outer_nd, 0, + lhs_outer_inner_shapes_strides); + OuterInnerDimsIndexerT rhs_indexer( + inner_nd + rhs_outer_nd, 0, + rhs_outer_inner_shapes_strides); + TmpIndexerT res_indexer{}; + using dpctl::tensor::offset_utils::StridedIndexer; + using dpctl::tensor::offset_utils:: + UnpackedStridedIndexer; + using dpctl::tensor::offset_utils::Strided1DIndexer; + using dpctl::tensor::offset_utils:: + ThreeOffsets_CombinedIndexer; + using BatchDimsIndexerT = + ThreeOffsets_CombinedIndexer; + StridedIndexer lhs_batch_indexer( + batch_nd, lhs_batch_offset, batch_shape_strides); + UnpackedStridedIndexer rhs_batch_indexer( + batch_nd, rhs_batch_offset, batch_shape_strides, + batch_shape_strides + 2 * batch_nd); + Strided1DIndexer tmp_batch_indexer( + 0, static_cast(batch_nelems), n * m); + BatchDimsIndexerT batch_indexer(lhs_batch_indexer, + rhs_batch_indexer, + tmp_batch_indexer); + + size_t lws = wg_delta_n * wg_delta_m; + + size_t n_blocks = ((n + wi_delta_n * wg_delta_n - 1) / + (wi_delta_n * wg_delta_n)); + size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); + size_t m_blocks = ((m + wi_delta_m * wg_delta_m - 1) / + (wi_delta_m * wg_delta_m)); + + auto gwsRange = + sycl::range<1>(batch_nelems * n_blocks * m_blocks * + k_blocks * lws); + auto lwsRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); + + using LocAccT1 = sycl::local_accessor; + LocAccT1 local_A_block( + sycl::range<1>((wi_delta_n * wg_delta_n) * + wi_delta_k), + cgh); + using LocAccT2 = + sycl::local_accessor, + 1>; + LocAccT2 local_B_block( + sycl::range<1>(wi_delta_k * wg_delta_m), cgh); + + using KernelName = class gemm_batch_tree_nm_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + TmpIndexerT, BatchDimsIndexerT, wi_delta_m>; + cgh.parallel_for( + ndRange, + GemmBatchNoAtomicFunctorThreadNM< + lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, + OuterInnerDimsIndexerT, TmpIndexerT, + BatchDimsIndexerT, wi_delta_n, wi_delta_m>( + lhs_tp, rhs_tp, tmp, local_A_block, + local_B_block, n, wg_delta_n, k, k_blocks, + wi_delta_k, m, m_blocks, wg_delta_m, + batch_nelems, batch_indexer, lhs_indexer, + rhs_indexer, res_indexer)); + }); + sycl::event red_ev = 
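+ // Only reduction_nelems (< wg) k-slab partials exist per output element,
+ // so one work-item per result element sums them from tmp straight into
+ // res via SequentialReduction rather than launching a tree reduction.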
exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(gemm_ev); + + using NoOpIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + using ResIndexerT = + dpctl::tensor::offset_utils::StridedIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils:: + TwoOffsets_CombinedIndexer; + using ReductionIndexerT = + dpctl::tensor::offset_utils::Strided1DIndexer; + + ResIndexerT res_iter_indexer{ + batch_nd + res_outer_nd, + static_cast(res_batch_offset), + res_shape_strides}; + InputOutputIterIndexerT in_out_iter_indexer{ + NoOpIndexerT{}, res_iter_indexer}; + ReductionIndexerT reduction_indexer{ + 0, static_cast(reduction_nelems), + static_cast(iter_nelems)}; + + sycl::range<1> iter_range{iter_nelems}; + + cgh.parallel_for>( + iter_range, + SequentialReduction( + tmp, res_tp, ReductionOpT(), identity_val, + in_out_iter_indexer, reduction_indexer, + reduction_nelems)); + }); + sycl::event cleanup_host_task_event = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(red_ev); + const sycl::context &ctx = exec_q.get_context(); + + cgh.host_task([ctx, tmp] { sycl::free(tmp, ctx); }); + }); + return cleanup_host_task_event; + } + + // max_max_wg prevents running out of resources on CPU + constexpr size_t max_max_wg = 2048; + size_t max_wg = std::min( + max_max_wg, + dev.get_info() / + 2); + + constexpr size_t preferred_reductions_per_wi = 4; + size_t reductions_per_wi(preferred_reductions_per_wi); + + size_t reduction_groups = + (reduction_nelems + preferred_reductions_per_wi * wg - 1) / + (preferred_reductions_per_wi * wg); + assert(reduction_groups > 1); + + resTy *partially_reduced_tmp = sycl::malloc_device( + iter_nelems * (/* temp */ reduction_nelems + + /* first reduction temp */ reduction_groups), + exec_q); + resTy *partially_reduced_tmp2 = nullptr; + + if (partially_reduced_tmp == nullptr) { + throw std::runtime_error( + "Unable to allocate device_memory"); + } + else { + partially_reduced_tmp2 = + partially_reduced_tmp + reduction_nelems * iter_nelems; + } + + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::StridedIndexer; + using TmpIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + OuterInnerDimsIndexerT lhs_indexer( + inner_nd + lhs_outer_nd, 0, + lhs_outer_inner_shapes_strides); + OuterInnerDimsIndexerT rhs_indexer( + inner_nd + rhs_outer_nd, 0, + rhs_outer_inner_shapes_strides); + TmpIndexerT res_indexer{}; + using dpctl::tensor::offset_utils::StridedIndexer; + using dpctl::tensor::offset_utils::UnpackedStridedIndexer; + using dpctl::tensor::offset_utils::Strided1DIndexer; + using dpctl::tensor::offset_utils:: + ThreeOffsets_CombinedIndexer; + using BatchDimsIndexerT = + ThreeOffsets_CombinedIndexer; + StridedIndexer lhs_batch_indexer(batch_nd, lhs_batch_offset, + batch_shape_strides); + UnpackedStridedIndexer rhs_batch_indexer( + batch_nd, rhs_batch_offset, batch_shape_strides, + batch_shape_strides + 2 * batch_nd); + Strided1DIndexer tmp_batch_indexer( + 0, static_cast(batch_nelems), n * m); + BatchDimsIndexerT batch_indexer(lhs_batch_indexer, + rhs_batch_indexer, + tmp_batch_indexer); + + size_t lws = wg_delta_n * wg_delta_m; + + size_t n_blocks = ((n + wi_delta_n * wg_delta_n - 1) / + (wi_delta_n * wg_delta_n)); + size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); + size_t m_blocks = ((m + wi_delta_m * wg_delta_m - 1) / + (wi_delta_m * wg_delta_m)); + + auto gwsRange = sycl::range<1>(batch_nelems * n_blocks * + m_blocks * k_blocks * lws); + 
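+ // One work-group of wg_delta_n * wg_delta_m work-items is launched per
+ // (batch, n-block, m-block, k-block) tuple; each group computes a
+ // (wi_delta_n * wg_delta_n) x (wi_delta_m * wg_delta_m) output tile over a
+ // single wi_delta_k-wide slab of k, and the per-slab results land in
+ // partially_reduced_tmp for the follow-up tree reduction.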
auto lwsRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); + + using LocAccT1 = sycl::local_accessor; + LocAccT1 local_A_block( + sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), + cgh); + using LocAccT2 = + sycl::local_accessor, 1>; + LocAccT2 local_B_block( + sycl::range<1>(wi_delta_k * wg_delta_m), cgh); + + using KernelName = class gemm_batch_tree_nm_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + TmpIndexerT, BatchDimsIndexerT, wi_delta_m>; + cgh.parallel_for( + ndRange, + GemmBatchNoAtomicFunctorThreadNM< + lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, + OuterInnerDimsIndexerT, TmpIndexerT, + BatchDimsIndexerT, wi_delta_n, wi_delta_m>( + lhs_tp, rhs_tp, partially_reduced_tmp, + local_A_block, local_B_block, n, wg_delta_n, k, + k_blocks, wi_delta_k, m, m_blocks, wg_delta_m, + batch_nelems, batch_indexer, lhs_indexer, + rhs_indexer, res_indexer)); + }); + + sycl::event red_ev = + tree_reduction_for_gemm( + exec_q, partially_reduced_tmp, partially_reduced_tmp2, + res_tp, identity_val, iter_nelems, reduction_nelems, + reduction_groups, wg, max_wg, + preferred_reductions_per_wi, reductions_per_wi, + batch_nd + res_outer_nd, res_batch_offset, + res_shape_strides, {gemm_ev}); + + sycl::event cleanup_host_task_event = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(red_ev); + const sycl::context &ctx = exec_q.get_context(); + + cgh.host_task([ctx, partially_reduced_tmp] { + sycl::free(partially_reduced_tmp, ctx); + }); + }); + + return cleanup_host_task_event; + } + } + else { // m > 1, n > k or m > k, resTy complex + constexpr int wi_delta_m = 1; + + gemm_detail::scale_gemm_nm_parameters( + local_mem_size, reserved_slm_size, wi_delta_n, + wi_delta_k, // modified by reference + wg_delta_n, // modified by reference + wg_delta_m // modified by reference + ); + + if (k <= wi_delta_k) { + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::StridedIndexer; + OuterInnerDimsIndexerT lhs_indexer( + inner_nd + lhs_outer_nd, 0, + lhs_outer_inner_shapes_strides); + OuterInnerDimsIndexerT rhs_indexer( + inner_nd + rhs_outer_nd, 0, + rhs_outer_inner_shapes_strides); + OuterInnerDimsIndexerT res_indexer( + res_outer_nd, 0, res_outer_shapes_strides); + using BatchDimsIndexerT = dpctl::tensor::offset_utils:: + ThreeOffsets_StridedIndexer; + BatchDimsIndexerT batch_indexer( + batch_nd, lhs_batch_offset, rhs_batch_offset, + res_batch_offset, batch_shape_strides); + + size_t lws = wg_delta_n * wg_delta_m; + + size_t n_blocks = ((n + wi_delta_n * wg_delta_n - 1) / + (wi_delta_n * wg_delta_n)); + size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); + size_t m_blocks = ((m + wi_delta_m * wg_delta_m - 1) / + (wi_delta_m * wg_delta_m)); + + auto gwsRange = sycl::range<1>(batch_nelems * n_blocks * + m_blocks * k_blocks * lws); + auto lwsRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); + + using LocAccT1 = sycl::local_accessor; + LocAccT1 local_A_block( + sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), + cgh); + using LocAccT2 = sycl::local_accessor; + LocAccT2 local_B_block( + sycl::range<1>(wi_delta_k * wg_delta_m), cgh); + + using KernelName = class gemm_batch_tree_nm_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, BatchDimsIndexerT, wi_delta_m>; + cgh.parallel_for( + ndRange, + GemmBatchNoAtomicFunctorThreadNM< + lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, + OuterInnerDimsIndexerT, 
OuterInnerDimsIndexerT, + BatchDimsIndexerT, wi_delta_n, wi_delta_m>( + lhs_tp, rhs_tp, res_tp, local_A_block, + local_B_block, n, wg_delta_n, k, k_blocks, + wi_delta_k, m, m_blocks, wg_delta_m, batch_nelems, + batch_indexer, lhs_indexer, rhs_indexer, + res_indexer)); + }); + return gemm_ev; + } + else { + using ReductionOpT = sycl::plus; + constexpr resTy identity_val = + sycl::known_identity::value; + size_t iter_nelems = batch_nelems * n * m; + size_t reduction_nelems = (k + wi_delta_k - 1) / wi_delta_k; + + // more than one work-group is needed, requires a temporary + // delta_k * n_wi elements processed along k, so if more to + // process use multiple + const auto &sg_sizes = + dev.get_info(); + size_t wg = + choose_workgroup_size<4>(reduction_nelems, sg_sizes); + + if (reduction_nelems < wg) { + resTy *tmp = sycl::malloc_device( + iter_nelems * reduction_nelems, exec_q); + sycl::event gemm_ev = exec_q.submit([&](sycl::handler + &cgh) { + cgh.depends_on(depends); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::StridedIndexer; + using TmpIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + OuterInnerDimsIndexerT lhs_indexer( + inner_nd + lhs_outer_nd, 0, + lhs_outer_inner_shapes_strides); + OuterInnerDimsIndexerT rhs_indexer( + inner_nd + rhs_outer_nd, 0, + rhs_outer_inner_shapes_strides); + TmpIndexerT res_indexer{}; + using dpctl::tensor::offset_utils::StridedIndexer; + using dpctl::tensor::offset_utils:: + UnpackedStridedIndexer; + using dpctl::tensor::offset_utils::Strided1DIndexer; + using dpctl::tensor::offset_utils:: + ThreeOffsets_CombinedIndexer; + using BatchDimsIndexerT = + ThreeOffsets_CombinedIndexer; + StridedIndexer lhs_batch_indexer( + batch_nd, lhs_batch_offset, batch_shape_strides); + UnpackedStridedIndexer rhs_batch_indexer( + batch_nd, rhs_batch_offset, batch_shape_strides, + batch_shape_strides + 2 * batch_nd); + Strided1DIndexer tmp_batch_indexer( + 0, static_cast(batch_nelems), n * m); + BatchDimsIndexerT batch_indexer(lhs_batch_indexer, + rhs_batch_indexer, + tmp_batch_indexer); + + size_t lws = wg_delta_n * wg_delta_m; + + size_t n_blocks = ((n + wi_delta_n * wg_delta_n - 1) / + (wi_delta_n * wg_delta_n)); + size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); + size_t m_blocks = ((m + wi_delta_m * wg_delta_m - 1) / + (wi_delta_m * wg_delta_m)); + + auto gwsRange = + sycl::range<1>(batch_nelems * n_blocks * m_blocks * + k_blocks * lws); + auto lwsRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); + + using LocAccT1 = sycl::local_accessor; + LocAccT1 local_A_block( + sycl::range<1>((wi_delta_n * wg_delta_n) * + wi_delta_k), + cgh); + using LocAccT2 = sycl::local_accessor; + LocAccT2 local_B_block( + sycl::range<1>(wi_delta_k * wg_delta_m), cgh); + + using KernelName = class gemm_batch_tree_nm_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + TmpIndexerT, BatchDimsIndexerT, wi_delta_m>; + cgh.parallel_for( + ndRange, + GemmBatchNoAtomicFunctorThreadNM< + lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, + OuterInnerDimsIndexerT, TmpIndexerT, + BatchDimsIndexerT, wi_delta_n, wi_delta_m>( + lhs_tp, rhs_tp, tmp, local_A_block, + local_B_block, n, wg_delta_n, k, k_blocks, + wi_delta_k, m, m_blocks, wg_delta_m, + batch_nelems, batch_indexer, lhs_indexer, + rhs_indexer, res_indexer)); + }); + sycl::event red_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(gemm_ev); + + using NoOpIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + using ResIndexerT = + dpctl::tensor::offset_utils::StridedIndexer; 
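+ // The output side walks res through a strided indexer over the batch and
+ // outer result dimensions, while the reduction side steps through tmp with
+ // stride iter_nelems, i.e. the k-slab partials of one output element sit
+ // iter_nelems apart in the temporary buffer.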
+ using InputOutputIterIndexerT = + dpctl::tensor::offset_utils:: + TwoOffsets_CombinedIndexer; + using ReductionIndexerT = + dpctl::tensor::offset_utils::Strided1DIndexer; + + ResIndexerT res_iter_indexer{ + batch_nd + res_outer_nd, + static_cast(res_batch_offset), + res_shape_strides}; + InputOutputIterIndexerT in_out_iter_indexer{ + NoOpIndexerT{}, res_iter_indexer}; + ReductionIndexerT reduction_indexer{ + 0, static_cast(reduction_nelems), + static_cast(iter_nelems)}; + + sycl::range<1> iter_range{iter_nelems}; + + cgh.parallel_for>( + iter_range, + SequentialReduction( + tmp, res_tp, ReductionOpT(), identity_val, + in_out_iter_indexer, reduction_indexer, + reduction_nelems)); + }); + sycl::event cleanup_host_task_event = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(red_ev); + const sycl::context &ctx = exec_q.get_context(); + + cgh.host_task([ctx, tmp] { sycl::free(tmp, ctx); }); + }); + return cleanup_host_task_event; + } + + // max_max_wg prevents running out of resources on CPU + constexpr size_t max_max_wg = 2048; + size_t max_wg = std::min( + max_max_wg, + dev.get_info() / + 2); + + constexpr size_t preferred_reductions_per_wi = 4; + size_t reductions_per_wi(preferred_reductions_per_wi); + + size_t reduction_groups = + (reduction_nelems + preferred_reductions_per_wi * wg - 1) / + (preferred_reductions_per_wi * wg); + assert(reduction_groups > 1); + + resTy *partially_reduced_tmp = sycl::malloc_device( + iter_nelems * (/* temp */ reduction_nelems + + /* first reduction temp */ reduction_groups), + exec_q); + resTy *partially_reduced_tmp2 = nullptr; + + if (partially_reduced_tmp == nullptr) { + throw std::runtime_error( + "Unable to allocate device_memory"); + } + else { + partially_reduced_tmp2 = + partially_reduced_tmp + reduction_nelems * iter_nelems; + } + + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::StridedIndexer; + using TmpIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + OuterInnerDimsIndexerT lhs_indexer( + inner_nd + lhs_outer_nd, 0, + lhs_outer_inner_shapes_strides); + OuterInnerDimsIndexerT rhs_indexer( + inner_nd + rhs_outer_nd, 0, + rhs_outer_inner_shapes_strides); + TmpIndexerT res_indexer{}; + using dpctl::tensor::offset_utils::StridedIndexer; + using dpctl::tensor::offset_utils::UnpackedStridedIndexer; + using dpctl::tensor::offset_utils::Strided1DIndexer; + using dpctl::tensor::offset_utils:: + ThreeOffsets_CombinedIndexer; + using BatchDimsIndexerT = + ThreeOffsets_CombinedIndexer; + StridedIndexer lhs_batch_indexer(batch_nd, lhs_batch_offset, + batch_shape_strides); + UnpackedStridedIndexer rhs_batch_indexer( + batch_nd, rhs_batch_offset, batch_shape_strides, + batch_shape_strides + 2 * batch_nd); + Strided1DIndexer tmp_batch_indexer( + 0, static_cast(batch_nelems), n * m); + BatchDimsIndexerT batch_indexer(lhs_batch_indexer, + rhs_batch_indexer, + tmp_batch_indexer); + + size_t lws = wg_delta_n * wg_delta_m; + + size_t n_blocks = ((n + wi_delta_n * wg_delta_n - 1) / + (wi_delta_n * wg_delta_n)); + size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); + size_t m_blocks = ((m + wi_delta_m * wg_delta_m - 1) / + (wi_delta_m * wg_delta_m)); + + auto gwsRange = sycl::range<1>(batch_nelems * n_blocks * + m_blocks * k_blocks * lws); + auto lwsRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); + + using LocAccT1 = sycl::local_accessor; + LocAccT1 local_A_block( + sycl::range<1>((wi_delta_n * 
wg_delta_n) * wi_delta_k), + cgh); + using LocAccT2 = sycl::local_accessor; + LocAccT2 local_B_block( + sycl::range<1>(wi_delta_k * wg_delta_m), cgh); + + using KernelName = class gemm_batch_tree_nm_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + TmpIndexerT, BatchDimsIndexerT, wi_delta_m>; + cgh.parallel_for( + ndRange, + GemmBatchNoAtomicFunctorThreadNM< + lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, + OuterInnerDimsIndexerT, TmpIndexerT, + BatchDimsIndexerT, wi_delta_n, wi_delta_m>( + lhs_tp, rhs_tp, partially_reduced_tmp, + local_A_block, local_B_block, n, wg_delta_n, k, + k_blocks, wi_delta_k, m, m_blocks, wg_delta_m, + batch_nelems, batch_indexer, lhs_indexer, + rhs_indexer, res_indexer)); + }); + + sycl::event red_ev = + tree_reduction_for_gemm( + exec_q, partially_reduced_tmp, partially_reduced_tmp2, + res_tp, identity_val, iter_nelems, reduction_nelems, + reduction_groups, wg, max_wg, + preferred_reductions_per_wi, reductions_per_wi, + batch_nd + res_outer_nd, res_batch_offset, + res_shape_strides, {gemm_ev}); + + sycl::event cleanup_host_task_event = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(red_ev); + const sycl::context &ctx = exec_q.get_context(); + + cgh.host_task([ctx, partially_reduced_tmp] { + sycl::free(partially_reduced_tmp, ctx); + }); + }); + + return cleanup_host_task_event; + } + } + } +} + +template +sycl::event +gemm_batch_contig_tree_impl(sycl::queue &exec_q, + const char *lhs_cp, + const char *rhs_cp, + char *res_cp, + size_t batch_nelems, + size_t n, + size_t k, + size_t m, + py::ssize_t lhs_batch_offset, + py::ssize_t rhs_batch_offset, + py::ssize_t res_batch_offset, + std::vector const &depends = {}) +{ + const lhsTy *lhs_tp = + reinterpret_cast(lhs_cp) + lhs_batch_offset; + const rhsTy *rhs_tp = + reinterpret_cast(rhs_cp) + rhs_batch_offset; + resTy *res_tp = reinterpret_cast(res_cp) + res_batch_offset; + + const sycl::device &dev = exec_q.get_device(); + const size_t local_mem_size = + dev.get_info(); + const size_t reserved_slm_size = 512; + + if ((k > n && k > m) || m == 1) { + size_t delta_k(4); + size_t n_wi(64); + size_t delta_n(4); + + using dpctl::tensor::type_utils::is_complex; + if constexpr (!is_complex::value) { + if (m == 1) { + constexpr int m_groups = 1; + gemm_detail::scale_gemm_k_parameters( + local_mem_size, reserved_slm_size, + delta_k, // modified by reference + n_wi, // modified by reference + delta_n // modified by reference + ); + + if (k <= (delta_k * n_wi)) { + sycl::event gemm_ev = exec_q.submit([&](sycl::handler + &cgh) { + cgh.depends_on(depends); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + OuterInnerDimsIndexerT lhs_indexer{}; + OuterInnerDimsIndexerT rhs_indexer{}; + OuterInnerDimsIndexerT res_indexer{}; + using dpctl::tensor::offset_utils:: + ThreeOffsets_CombinedIndexer; + using dpctl::tensor::offset_utils::Strided1DIndexer; + using BatchDimsIndexerT = + ThreeOffsets_CombinedIndexer; + + using dpctl::tensor::offset_utils::Strided1DIndexer; + BatchDimsIndexerT batch_indexer( + Strided1DIndexer{ + lhs_batch_offset, + static_cast(batch_nelems), + static_cast(n * k)}, + Strided1DIndexer{ + rhs_batch_offset, + static_cast(batch_nelems), + static_cast(k * m)}, + Strided1DIndexer{ + res_batch_offset, + static_cast(batch_nelems), + static_cast(n * m)}); + + size_t n_blocks = (n + delta_n - 1) / delta_n; + size_t k_blocks = + (k + n_wi * delta_k - 1) / (n_wi * delta_k); + size_t m_blocks = (m + m_groups - 1) / m_groups; + + size_t lws = delta_n * delta_k; + + auto gRange = 
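+ // Contiguous batches: each Strided1DIndexer above simply advances by a
+ // whole matrix per batch element (n*k for lhs, k*m for rhs, n*m for the
+ // result), and one work-group of delta_n * delta_k work-items is launched
+ // per (batch, n-block, m-block, k-block) combination.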
sycl::range<1>(batch_nelems * n_blocks * + m_blocks * k_blocks * lws); + auto lRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gRange, lRange); + + using LocAccT = sycl::local_accessor; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + + using KernelName = class gemm_batch_tree_k_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, BatchDimsIndexerT, + m_groups>; + + cgh.parallel_for( + ndRange, + GemmBatchNoAtomicFunctorThreadK< + lhsTy, rhsTy, resTy, LocAccT, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, + BatchDimsIndexerT, m_groups>( + lhs_tp, rhs_tp, res_tp, workspace, + local_B_block, n, n_blocks, delta_n, k, + k_blocks, delta_k, n_wi, m, batch_nelems, + batch_indexer, lhs_indexer, rhs_indexer, + res_indexer)); + }); + return gemm_ev; + } + else { + using ReductionOpT = sycl::plus; + constexpr resTy identity_val = + sycl::known_identity::value; + + size_t iter_nelems = batch_nelems * n * m; + size_t reduction_nelems = + (k + delta_k * n_wi - 1) / (delta_k * n_wi); + + // more than one work-group is needed, requires a temporary + // delta_k * n_wi elements processed along k, so if more to + // process use multiple + const auto &sg_sizes = + dev.get_info(); + size_t wg = + choose_workgroup_size<4>(reduction_nelems, sg_sizes); + + if (reduction_nelems < wg) { + resTy *tmp = sycl::malloc_device( + iter_nelems * reduction_nelems, exec_q); + sycl::event gemm_ev = exec_q.submit([&](sycl::handler + &cgh) { + cgh.depends_on(depends); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + OuterInnerDimsIndexerT lhs_indexer{}; + OuterInnerDimsIndexerT rhs_indexer{}; + OuterInnerDimsIndexerT tmp_indexer{}; + using dpctl::tensor::offset_utils:: + ThreeOffsets_CombinedIndexer; + using dpctl::tensor::offset_utils::Strided1DIndexer; + using BatchDimsIndexerT = + ThreeOffsets_CombinedIndexer; + + BatchDimsIndexerT batch_indexer( + Strided1DIndexer{ + lhs_batch_offset, + static_cast(batch_nelems), + static_cast(n * k)}, + Strided1DIndexer{ + rhs_batch_offset, + static_cast(batch_nelems), + static_cast(k * m)}, + Strided1DIndexer{ + res_batch_offset, + static_cast(batch_nelems), + static_cast(n * m)}); + + size_t n_blocks = (n + delta_n - 1) / delta_n; + size_t k_blocks = + (k + n_wi * delta_k - 1) / (n_wi * delta_k); + size_t m_blocks = (m + m_groups - 1) / m_groups; + + size_t lws = delta_n * delta_k; + + auto gRange = + sycl::range<1>(batch_nelems * n_blocks * + m_blocks * k_blocks * lws); + auto lRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gRange, lRange); + + using LocAccT = sycl::local_accessor; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + + using KernelName = class gemm_batch_tree_k_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, BatchDimsIndexerT, + m_groups>; + cgh.parallel_for( + ndRange, + GemmBatchNoAtomicFunctorThreadK< + lhsTy, rhsTy, resTy, LocAccT, + OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, BatchDimsIndexerT, + m_groups>(lhs_tp, rhs_tp, tmp, workspace, + local_B_block, n, n_blocks, + delta_n, k, k_blocks, delta_k, + n_wi, m, batch_nelems, + batch_indexer, lhs_indexer, + rhs_indexer, tmp_indexer)); + }); + sycl::event red_ev = exec_q.submit([&](sycl::handler + &cgh) { + cgh.depends_on(gemm_ev); + + using NoOpIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils:: + TwoOffsets_CombinedIndexer; + using 
ReductionIndexerT = + dpctl::tensor::offset_utils::Strided1DIndexer; + + InputOutputIterIndexerT in_out_iter_indexer{ + NoOpIndexerT{}, NoOpIndexerT{}}; + ReductionIndexerT reduction_indexer{ + 0, static_cast(reduction_nelems), + static_cast(iter_nelems)}; + + sycl::range<1> iter_range{iter_nelems}; + + cgh.parallel_for< + class gemm_reduction_seq_strided_krn< + resTy, resTy, ReductionOpT, + InputOutputIterIndexerT, + ReductionIndexerT>>( + iter_range, + SequentialReduction( + tmp, res_tp, ReductionOpT(), identity_val, + in_out_iter_indexer, reduction_indexer, + reduction_nelems)); + }); + sycl::event cleanup_host_task_event = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(red_ev); + const sycl::context &ctx = exec_q.get_context(); + + cgh.host_task( + [ctx, tmp] { sycl::free(tmp, ctx); }); + }); + return cleanup_host_task_event; + } + + // max_max_wg prevents running out of resources on CPU + constexpr size_t max_max_wg = 2048; + size_t max_wg = std::min( + max_max_wg, + dev.get_info< + sycl::info::device::max_work_group_size>() / + 2); + + constexpr size_t preferred_reductions_per_wi = 4; + size_t reductions_per_wi(preferred_reductions_per_wi); + + size_t reduction_groups = + (reduction_nelems + preferred_reductions_per_wi * wg - + 1) / + (preferred_reductions_per_wi * wg); + assert(reduction_groups > 1); + + resTy *partially_reduced_tmp = sycl::malloc_device( + iter_nelems * + (/* temp */ reduction_nelems + + /* first reduction temp */ reduction_groups), + exec_q); + resTy *partially_reduced_tmp2 = nullptr; + + if (partially_reduced_tmp == nullptr) { + throw std::runtime_error( + "Unable to allocate device_memory"); + } + else { + partially_reduced_tmp2 = partially_reduced_tmp + + reduction_nelems * iter_nelems; + } + + sycl::event gemm_ev = exec_q.submit([&](sycl::handler + &cgh) { + cgh.depends_on(depends); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + OuterInnerDimsIndexerT lhs_indexer{}; + OuterInnerDimsIndexerT rhs_indexer{}; + OuterInnerDimsIndexerT tmp_indexer{}; + using dpctl::tensor::offset_utils:: + ThreeOffsets_CombinedIndexer; + using dpctl::tensor::offset_utils::Strided1DIndexer; + using BatchDimsIndexerT = + ThreeOffsets_CombinedIndexer; + + BatchDimsIndexerT batch_indexer( + Strided1DIndexer{ + lhs_batch_offset, + static_cast(batch_nelems), + static_cast(n * k)}, + Strided1DIndexer{ + rhs_batch_offset, + static_cast(batch_nelems), + static_cast(k * m)}, + Strided1DIndexer{ + res_batch_offset, + static_cast(batch_nelems), + static_cast(n * m)}); + + size_t n_blocks = (n + delta_n - 1) / delta_n; + size_t k_blocks = + (k + n_wi * delta_k - 1) / (n_wi * delta_k); + constexpr int m_groups = 1; + size_t m_blocks = (m + m_groups - 1) / m_groups; + + size_t lws = delta_n * delta_k; + + auto gRange = sycl::range<1>(batch_nelems * n_blocks * + m_blocks * k_blocks * lws); + auto lRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gRange, lRange); + + using LocAccT = sycl::local_accessor; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + + using KernelName = class gemm_batch_tree_k_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, BatchDimsIndexerT, + m_groups>; + cgh.parallel_for( + ndRange, + GemmBatchNoAtomicFunctorThreadK< + lhsTy, rhsTy, resTy, LocAccT, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, + BatchDimsIndexerT, m_groups>( + lhs_tp, rhs_tp, partially_reduced_tmp, + workspace, local_B_block, n, n_blocks, delta_n, + k, k_blocks, delta_k, 
n_wi, m, batch_nelems, + batch_indexer, lhs_indexer, rhs_indexer, + tmp_indexer)); + }); + + sycl::event red_ev = + tree_reduction_for_gemm_contig( + exec_q, partially_reduced_tmp, + partially_reduced_tmp2, res_tp, identity_val, + iter_nelems, reduction_nelems, reduction_groups, wg, + max_wg, preferred_reductions_per_wi, + reductions_per_wi, {gemm_ev}); + + sycl::event cleanup_host_task_event = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(red_ev); + const sycl::context &ctx = exec_q.get_context(); + + cgh.host_task([ctx, partially_reduced_tmp] { + sycl::free(partially_reduced_tmp, ctx); + }); + }); + + return cleanup_host_task_event; + } + } + else { + constexpr int m_groups = 2; + + gemm_detail::scale_gemm_k_parameters( + local_mem_size, reserved_slm_size, + delta_k, // modified by reference + n_wi, // modified by reference + delta_n // modified by reference + ); + + if (k <= (delta_k * n_wi)) { + sycl::event gemm_ev = exec_q.submit([&](sycl::handler + &cgh) { + cgh.depends_on(depends); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + OuterInnerDimsIndexerT lhs_indexer{}; + OuterInnerDimsIndexerT rhs_indexer{}; + OuterInnerDimsIndexerT res_indexer{}; + using dpctl::tensor::offset_utils:: + ThreeOffsets_CombinedIndexer; + using dpctl::tensor::offset_utils::Strided1DIndexer; + using BatchDimsIndexerT = + ThreeOffsets_CombinedIndexer; + + BatchDimsIndexerT batch_indexer( + Strided1DIndexer{ + lhs_batch_offset, + static_cast(batch_nelems), + static_cast(n * k)}, + Strided1DIndexer{ + rhs_batch_offset, + static_cast(batch_nelems), + static_cast(k * m)}, + Strided1DIndexer{ + res_batch_offset, + static_cast(batch_nelems), + static_cast(n * m)}); + + size_t n_blocks = (n + delta_n - 1) / delta_n; + size_t k_blocks = + (k + n_wi * delta_k - 1) / (n_wi * delta_k); + size_t m_blocks = (m + m_groups - 1) / m_groups; + + size_t lws = delta_n * delta_k; + + auto gRange = sycl::range<1>(batch_nelems * n_blocks * + m_blocks * k_blocks * lws); + auto lRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gRange, lRange); + + using LocAccT = + sycl::local_accessor, 1>; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + + using KernelName = class gemm_batch_tree_k_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, BatchDimsIndexerT, + m_groups>; + + cgh.parallel_for( + ndRange, + GemmBatchNoAtomicFunctorThreadK< + lhsTy, rhsTy, resTy, LocAccT, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, + BatchDimsIndexerT, m_groups>( + lhs_tp, rhs_tp, res_tp, workspace, + local_B_block, n, n_blocks, delta_n, k, + k_blocks, delta_k, n_wi, m, batch_nelems, + batch_indexer, lhs_indexer, rhs_indexer, + res_indexer)); + }); + return gemm_ev; + } + else { + using ReductionOpT = sycl::plus; + constexpr resTy identity_val = + sycl::known_identity::value; + + size_t iter_nelems = batch_nelems * n * m; + size_t reduction_nelems = + (k + delta_k * n_wi - 1) / (delta_k * n_wi); + + // more than one work-group is needed, requires a temporary + // delta_k * n_wi elements processed along k, so if more to + // process use multiple + const auto &sg_sizes = + dev.get_info(); + size_t wg = + choose_workgroup_size<4>(reduction_nelems, sg_sizes); + + if (reduction_nelems < wg) { + resTy *tmp = sycl::malloc_device( + iter_nelems * reduction_nelems, exec_q); + sycl::event gemm_ev = exec_q.submit([&](sycl::handler + &cgh) { + cgh.depends_on(depends); + + using OuterInnerDimsIndexerT = + 
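+ // For the contiguous kernels the flat element id is already the memory
+ // offset within each dense matrix, so NoOpIndexer stands in for the
+ // strided lhs/rhs/tmp indexers used in the strided code path.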
dpctl::tensor::offset_utils::NoOpIndexer; + OuterInnerDimsIndexerT lhs_indexer{}; + OuterInnerDimsIndexerT rhs_indexer{}; + OuterInnerDimsIndexerT tmp_indexer{}; + using dpctl::tensor::offset_utils:: + ThreeOffsets_CombinedIndexer; + using dpctl::tensor::offset_utils::Strided1DIndexer; + using BatchDimsIndexerT = + ThreeOffsets_CombinedIndexer; + + BatchDimsIndexerT batch_indexer( + Strided1DIndexer{ + lhs_batch_offset, + static_cast(batch_nelems), + static_cast(n * k)}, + Strided1DIndexer{ + rhs_batch_offset, + static_cast(batch_nelems), + static_cast(k * m)}, + Strided1DIndexer{ + res_batch_offset, + static_cast(batch_nelems), + static_cast(n * m)}); + + size_t n_blocks = (n + delta_n - 1) / delta_n; + size_t k_blocks = + (k + n_wi * delta_k - 1) / (n_wi * delta_k); + size_t m_blocks = (m + m_groups - 1) / m_groups; + + size_t lws = delta_n * delta_k; + + auto gRange = + sycl::range<1>(batch_nelems * n_blocks * + m_blocks * k_blocks * lws); + auto lRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gRange, lRange); + + using LocAccT = + sycl::local_accessor, + 1>; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + + using KernelName = class gemm_batch_tree_k_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, BatchDimsIndexerT, + m_groups>; + cgh.parallel_for( + ndRange, + GemmBatchNoAtomicFunctorThreadK< + lhsTy, rhsTy, resTy, LocAccT, + OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, BatchDimsIndexerT, + m_groups>(lhs_tp, rhs_tp, tmp, workspace, + local_B_block, n, n_blocks, + delta_n, k, k_blocks, delta_k, + n_wi, m, batch_nelems, + batch_indexer, lhs_indexer, + rhs_indexer, tmp_indexer)); + }); + sycl::event red_ev = exec_q.submit([&](sycl::handler + &cgh) { + cgh.depends_on(gemm_ev); + + using NoOpIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils:: + TwoOffsets_CombinedIndexer; + using ReductionIndexerT = + dpctl::tensor::offset_utils::Strided1DIndexer; + + InputOutputIterIndexerT in_out_iter_indexer{ + NoOpIndexerT{}, NoOpIndexerT{}}; + ReductionIndexerT reduction_indexer{ + 0, static_cast(reduction_nelems), + static_cast(iter_nelems)}; + + sycl::range<1> iter_range{iter_nelems}; + + cgh.parallel_for< + class gemm_reduction_seq_strided_krn< + resTy, resTy, ReductionOpT, + InputOutputIterIndexerT, + ReductionIndexerT>>( + iter_range, + SequentialReduction( + tmp, res_tp, ReductionOpT(), identity_val, + in_out_iter_indexer, reduction_indexer, + reduction_nelems)); + }); + sycl::event cleanup_host_task_event = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(red_ev); + const sycl::context &ctx = exec_q.get_context(); + + cgh.host_task( + [ctx, tmp] { sycl::free(tmp, ctx); }); + }); + return cleanup_host_task_event; + } + + // max_max_wg prevents running out of resources on CPU + constexpr size_t max_max_wg = 2048; + size_t max_wg = std::min( + max_max_wg, + dev.get_info< + sycl::info::device::max_work_group_size>() / + 2); + + constexpr size_t preferred_reductions_per_wi = 4; + size_t reductions_per_wi(preferred_reductions_per_wi); + + size_t reduction_groups = + (reduction_nelems + preferred_reductions_per_wi * wg - + 1) / + (preferred_reductions_per_wi * wg); + assert(reduction_groups > 1); + + resTy *partially_reduced_tmp = sycl::malloc_device( + iter_nelems * + (/* temp */ reduction_nelems + + /* first reduction temp */ reduction_groups), + exec_q); + resTy *partially_reduced_tmp2 = nullptr; + + if 
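Throughout these branches the ND-range is sized the same way: the row, column, and k block counts are ceiling divisions by the per-block extents, and the 1-D global range is the product of the batch count, the three block counts, and the work-group size delta_n * delta_k. A small sketch of that arithmetic, with purely illustrative sizes (the actual delta_n, delta_k, and n_wi come from scale_gemm_k_parameters), is:

    #include <cstddef>
    #include <cstdio>

    static std::size_t ceil_div(std::size_t a, std::size_t b)
    {
        return (a + b - 1) / b;
    }

    int main()
    {
        // Illustrative values only; the kernel derives these from device limits.
        const std::size_t n = 1000, k = 4096, m = 513, batch_nelems = 8;
        const std::size_t delta_n = 4, delta_k = 16, n_wi = 16, m_groups = 2;

        const std::size_t n_blocks = ceil_div(n, delta_n);
        const std::size_t k_blocks = ceil_div(k, n_wi * delta_k);
        const std::size_t m_blocks = ceil_div(m, m_groups);
        const std::size_t lws = delta_n * delta_k;
        const std::size_t gws = batch_nelems * n_blocks * m_blocks * k_blocks * lws;

        std::printf("n_blocks=%zu k_blocks=%zu m_blocks=%zu lws=%zu gws=%zu\n",
                    n_blocks, k_blocks, m_blocks, lws, gws);
        return 0;
    }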
(partially_reduced_tmp == nullptr) { + throw std::runtime_error( + "Unable to allocate device_memory"); + } + else { + partially_reduced_tmp2 = partially_reduced_tmp + + reduction_nelems * iter_nelems; + } + + sycl::event gemm_ev = exec_q.submit([&](sycl::handler + &cgh) { + cgh.depends_on(depends); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + OuterInnerDimsIndexerT lhs_indexer{}; + OuterInnerDimsIndexerT rhs_indexer{}; + OuterInnerDimsIndexerT tmp_indexer{}; + using dpctl::tensor::offset_utils:: + ThreeOffsets_CombinedIndexer; + using dpctl::tensor::offset_utils::Strided1DIndexer; + using BatchDimsIndexerT = + ThreeOffsets_CombinedIndexer; + + BatchDimsIndexerT batch_indexer( + Strided1DIndexer{ + lhs_batch_offset, + static_cast(batch_nelems), + static_cast(n * k)}, + Strided1DIndexer{ + rhs_batch_offset, + static_cast(batch_nelems), + static_cast(k * m)}, + Strided1DIndexer{ + res_batch_offset, + static_cast(batch_nelems), + static_cast(n * m)}); + + size_t n_blocks = (n + delta_n - 1) / delta_n; + size_t k_blocks = + (k + n_wi * delta_k - 1) / (n_wi * delta_k); + size_t m_blocks = (m + m_groups - 1) / m_groups; + + size_t lws = delta_n * delta_k; + + auto gRange = sycl::range<1>(batch_nelems * n_blocks * + m_blocks * k_blocks * lws); + auto lRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gRange, lRange); + + using LocAccT = + sycl::local_accessor, 1>; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + + using KernelName = class gemm_batch_tree_k_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, BatchDimsIndexerT, + m_groups>; + cgh.parallel_for( + ndRange, + GemmBatchNoAtomicFunctorThreadK< + lhsTy, rhsTy, resTy, LocAccT, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, + BatchDimsIndexerT, m_groups>( + lhs_tp, rhs_tp, partially_reduced_tmp, + workspace, local_B_block, n, n_blocks, delta_n, + k, k_blocks, delta_k, n_wi, m, batch_nelems, + batch_indexer, lhs_indexer, rhs_indexer, + tmp_indexer)); + }); + + sycl::event red_ev = + tree_reduction_for_gemm_contig( + exec_q, partially_reduced_tmp, + partially_reduced_tmp2, res_tp, identity_val, + iter_nelems, reduction_nelems, reduction_groups, wg, + max_wg, preferred_reductions_per_wi, + reductions_per_wi, {gemm_ev}); + + sycl::event cleanup_host_task_event = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(red_ev); + const sycl::context &ctx = exec_q.get_context(); + + cgh.host_task([ctx, partially_reduced_tmp] { + sycl::free(partially_reduced_tmp, ctx); + }); + }); + + return cleanup_host_task_event; + } + } + } + else { + constexpr int m_groups = 1; + + gemm_detail::scale_gemm_k_parameters( + local_mem_size, reserved_slm_size, + delta_k, // modified by reference + n_wi, // modified by reference + delta_n // modified by reference + ); + + // each group processes delta_k * n_wi + // items in a column, so no need for allocating + // temp memory if only one group is needed + if (k <= (delta_k * n_wi)) { + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + OuterInnerDimsIndexerT lhs_indexer{}; + OuterInnerDimsIndexerT rhs_indexer{}; + OuterInnerDimsIndexerT res_indexer{}; + using dpctl::tensor::offset_utils:: + ThreeOffsets_CombinedIndexer; + using dpctl::tensor::offset_utils::Strided1DIndexer; + using BatchDimsIndexerT = ThreeOffsets_CombinedIndexer< + Strided1DIndexer, Strided1DIndexer, 
Strided1DIndexer>; + + BatchDimsIndexerT batch_indexer( + Strided1DIndexer{lhs_batch_offset, + static_cast(batch_nelems), + static_cast(n * k)}, + Strided1DIndexer{rhs_batch_offset, + static_cast(batch_nelems), + static_cast(k * m)}, + Strided1DIndexer{res_batch_offset, + static_cast(batch_nelems), + static_cast(n * m)}); + + size_t n_blocks = (n + delta_n - 1) / delta_n; + size_t k_blocks = + (k + n_wi * delta_k - 1) / (n_wi * delta_k); + constexpr int m_groups = 1; + size_t m_blocks = (m + m_groups - 1) / m_groups; + + size_t lws = delta_n * delta_k; + + auto gRange = sycl::range<1>(batch_nelems * n_blocks * + m_blocks * k_blocks * lws); + auto lRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gRange, lRange); + + using LocAccT = sycl::local_accessor; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + + using KernelName = class gemm_batch_tree_k_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, BatchDimsIndexerT, m_groups>; + + cgh.parallel_for( + ndRange, + GemmBatchNoAtomicFunctorThreadK< + lhsTy, rhsTy, resTy, LocAccT, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, + BatchDimsIndexerT, m_groups>( + lhs_tp, rhs_tp, res_tp, workspace, local_B_block, n, + n_blocks, delta_n, k, k_blocks, delta_k, n_wi, m, + batch_nelems, batch_indexer, lhs_indexer, + rhs_indexer, res_indexer)); + }); + return gemm_ev; + } + else { + using ReductionOpT = sycl::plus; + constexpr resTy identity_val = + sycl::known_identity::value; + + size_t iter_nelems = batch_nelems * n * m; + size_t reduction_nelems = + (k + delta_k * n_wi - 1) / (delta_k * n_wi); + + // more than one work-group is needed, requires a temporary + // delta_k * n_wi elements processed along k, so if more to + // process use multiple + const auto &sg_sizes = + dev.get_info(); + size_t wg = + choose_workgroup_size<4>(reduction_nelems, sg_sizes); + + if (reduction_nelems < wg) { + resTy *tmp = sycl::malloc_device( + iter_nelems * reduction_nelems, exec_q); + sycl::event gemm_ev = exec_q.submit([&](sycl::handler + &cgh) { + cgh.depends_on(depends); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + OuterInnerDimsIndexerT lhs_indexer{}; + OuterInnerDimsIndexerT rhs_indexer{}; + OuterInnerDimsIndexerT tmp_indexer{}; + using dpctl::tensor::offset_utils:: + ThreeOffsets_CombinedIndexer; + using dpctl::tensor::offset_utils::Strided1DIndexer; + using BatchDimsIndexerT = + ThreeOffsets_CombinedIndexer; + + BatchDimsIndexerT batch_indexer( + Strided1DIndexer{ + lhs_batch_offset, + static_cast(batch_nelems), + static_cast(n * k)}, + Strided1DIndexer{ + rhs_batch_offset, + static_cast(batch_nelems), + static_cast(k * m)}, + Strided1DIndexer{ + res_batch_offset, + static_cast(batch_nelems), + static_cast(n * m)}); + + size_t n_blocks = (n + delta_n - 1) / delta_n; + size_t k_blocks = + (k + n_wi * delta_k - 1) / (n_wi * delta_k); + constexpr int m_groups = 1; + size_t m_blocks = (m + m_groups - 1) / m_groups; + + size_t lws = delta_n * delta_k; + + auto gRange = sycl::range<1>(batch_nelems * n_blocks * + m_blocks * k_blocks * lws); + auto lRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gRange, lRange); + + using LocAccT = sycl::local_accessor; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + + using KernelName = class gemm_batch_tree_k_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, BatchDimsIndexerT, + m_groups>; + cgh.parallel_for( + ndRange, + 
GemmBatchNoAtomicFunctorThreadK< + lhsTy, rhsTy, resTy, LocAccT, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, + BatchDimsIndexerT, m_groups>( + lhs_tp, rhs_tp, tmp, workspace, local_B_block, + n, n_blocks, delta_n, k, k_blocks, delta_k, + n_wi, m, batch_nelems, batch_indexer, + lhs_indexer, rhs_indexer, tmp_indexer)); + }); + sycl::event red_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(gemm_ev); + + using NoOpIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils:: + TwoOffsets_CombinedIndexer; + using ReductionIndexerT = + dpctl::tensor::offset_utils::Strided1DIndexer; + + InputOutputIterIndexerT in_out_iter_indexer{ + NoOpIndexerT{}, NoOpIndexerT{}}; + ReductionIndexerT reduction_indexer{ + 0, static_cast(reduction_nelems), + static_cast(iter_nelems)}; + + sycl::range<1> iter_range{iter_nelems}; + + cgh.parallel_for>( + iter_range, + SequentialReduction( + tmp, res_tp, ReductionOpT(), identity_val, + in_out_iter_indexer, reduction_indexer, + reduction_nelems)); + }); + sycl::event cleanup_host_task_event = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(red_ev); + const sycl::context &ctx = exec_q.get_context(); + + cgh.host_task([ctx, tmp] { sycl::free(tmp, ctx); }); + }); + return cleanup_host_task_event; + } + + // max_max_wg prevents running out of resources on CPU + constexpr size_t max_max_wg = 2048; + size_t max_wg = std::min( + max_max_wg, + dev.get_info() / + 2); + + constexpr size_t preferred_reductions_per_wi = 4; + size_t reductions_per_wi(preferred_reductions_per_wi); + + size_t reduction_groups = + (reduction_nelems + preferred_reductions_per_wi * wg - 1) / + (preferred_reductions_per_wi * wg); + assert(reduction_groups > 1); + + resTy *partially_reduced_tmp = sycl::malloc_device( + iter_nelems * (/* temp */ reduction_nelems + + /* first reduction temp */ reduction_groups), + exec_q); + resTy *partially_reduced_tmp2 = nullptr; + + if (partially_reduced_tmp == nullptr) { + throw std::runtime_error( + "Unable to allocate device_memory"); + } + else { + partially_reduced_tmp2 = + partially_reduced_tmp + reduction_nelems * iter_nelems; + } + + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + OuterInnerDimsIndexerT lhs_indexer{}; + OuterInnerDimsIndexerT rhs_indexer{}; + OuterInnerDimsIndexerT tmp_indexer{}; + using dpctl::tensor::offset_utils:: + ThreeOffsets_CombinedIndexer; + using dpctl::tensor::offset_utils::Strided1DIndexer; + using BatchDimsIndexerT = ThreeOffsets_CombinedIndexer< + Strided1DIndexer, Strided1DIndexer, Strided1DIndexer>; + + BatchDimsIndexerT batch_indexer( + Strided1DIndexer{lhs_batch_offset, + static_cast(batch_nelems), + static_cast(n * k)}, + Strided1DIndexer{rhs_batch_offset, + static_cast(batch_nelems), + static_cast(k * m)}, + Strided1DIndexer{res_batch_offset, + static_cast(batch_nelems), + static_cast(n * m)}); + + size_t n_blocks = (n + delta_n - 1) / delta_n; + size_t k_blocks = + (k + n_wi * delta_k - 1) / (n_wi * delta_k); + constexpr int m_groups = 1; + size_t m_blocks = (m + m_groups - 1) / m_groups; + + size_t lws = delta_n * delta_k; + + auto gRange = sycl::range<1>(batch_nelems * n_blocks * + m_blocks * k_blocks * lws); + auto lRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gRange, lRange); + + using LocAccT = sycl::local_accessor; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * 
delta_k, cgh); + + using KernelName = class gemm_batch_tree_k_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, BatchDimsIndexerT, m_groups>; + cgh.parallel_for( + ndRange, + GemmBatchNoAtomicFunctorThreadK< + lhsTy, rhsTy, resTy, LocAccT, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, + BatchDimsIndexerT, m_groups>( + lhs_tp, rhs_tp, partially_reduced_tmp, workspace, + local_B_block, n, n_blocks, delta_n, k, k_blocks, + delta_k, n_wi, m, batch_nelems, batch_indexer, + lhs_indexer, rhs_indexer, tmp_indexer)); + }); + + sycl::event red_ev = + tree_reduction_for_gemm_contig( + exec_q, partially_reduced_tmp, partially_reduced_tmp2, + res_tp, identity_val, iter_nelems, reduction_nelems, + reduction_groups, wg, max_wg, + preferred_reductions_per_wi, reductions_per_wi, + {gemm_ev}); + + sycl::event cleanup_host_task_event = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(red_ev); + const sycl::context &ctx = exec_q.get_context(); + + cgh.host_task([ctx, partially_reduced_tmp] { + sycl::free(partially_reduced_tmp, ctx); + }); + }); + + return cleanup_host_task_event; + } + } + } + else { // m > 1, n > k or m > k + constexpr int wi_delta_n = 2; + size_t wg_delta_n(16); // rows of A processed in WG + size_t wg_delta_m(16); // rows of B processed in WG + size_t wi_delta_k(64); // Elements in K dimension processed by WI + + using dpctl::tensor::type_utils::is_complex; + if constexpr (!is_complex::value) { + constexpr int wi_delta_m = 4; + + gemm_detail::scale_gemm_nm_parameters( + local_mem_size, reserved_slm_size, wi_delta_n, + wi_delta_k, // modified by reference + wg_delta_n, // modified by reference + wg_delta_m // modified by reference + ); + + // each group processes delta_k * n_wi + // items in a column, so no need for allocating + // temp memory if only one group is needed + if (k <= wi_delta_k) { + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + OuterInnerDimsIndexerT lhs_indexer{}; + OuterInnerDimsIndexerT rhs_indexer{}; + OuterInnerDimsIndexerT res_indexer{}; + using dpctl::tensor::offset_utils:: + ThreeOffsets_CombinedIndexer; + using dpctl::tensor::offset_utils::Strided1DIndexer; + using BatchDimsIndexerT = ThreeOffsets_CombinedIndexer< + Strided1DIndexer, Strided1DIndexer, Strided1DIndexer>; + + BatchDimsIndexerT batch_indexer( + Strided1DIndexer{lhs_batch_offset, + static_cast(batch_nelems), + static_cast(n * k)}, + Strided1DIndexer{rhs_batch_offset, + static_cast(batch_nelems), + static_cast(k * m)}, + Strided1DIndexer{res_batch_offset, + static_cast(batch_nelems), + static_cast(n * m)}); + + size_t lws = wg_delta_n * wg_delta_m; + + size_t n_blocks = ((n + wi_delta_n * wg_delta_n - 1) / + (wi_delta_n * wg_delta_n)); + size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); + size_t m_blocks = ((m + wi_delta_m * wg_delta_m - 1) / + (wi_delta_m * wg_delta_m)); + + auto gwsRange = sycl::range<1>(batch_nelems * n_blocks * + m_blocks * k_blocks * lws); + auto lwsRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); + + using LocAccT1 = sycl::local_accessor; + LocAccT1 local_A_block( + sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), + cgh); + using LocAccT2 = + sycl::local_accessor, 1>; + LocAccT2 local_B_block( + sycl::range<1>(wi_delta_k * wg_delta_m), cgh); + + using KernelName = class gemm_batch_tree_nm_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, 
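For the n/m-blocked kernel the local accessors above cache a (wi_delta_n * wg_delta_n) by wi_delta_k tile of the left operand and a wi_delta_k by wg_delta_m tile of the right operand, the latter as vec<resTy, wi_delta_m> lanes. A rough, stand-alone estimate of the shared-local-memory footprint that scale_gemm_nm_parameters has to keep below the device limit might look like the following (element size is an assumption here, standing in for sizeof(resTy)):

    #include <cstddef>
    #include <cstdio>

    int main()
    {
        // Starting tile parameters as used above; illustrative element size.
        const std::size_t wi_delta_n = 2, wg_delta_n = 16, wg_delta_m = 16;
        const std::size_t wi_delta_k = 64, wi_delta_m = 4;
        const std::size_t elem_size = sizeof(double); // stand-in for sizeof(resTy)

        const std::size_t a_tile_elems = (wi_delta_n * wg_delta_n) * wi_delta_k;
        const std::size_t b_tile_elems = wi_delta_k * wg_delta_m * wi_delta_m;
        const std::size_t slm_bytes = (a_tile_elems + b_tile_elems) * elem_size;

        std::printf("A tile: %zu elems, B tile: %zu elems, ~%zu bytes of SLM\n",
                    a_tile_elems, b_tile_elems, slm_bytes);
        return 0;
    }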
BatchDimsIndexerT, wi_delta_m>; + cgh.parallel_for( + ndRange, + GemmBatchNoAtomicFunctorThreadNM< + lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, + BatchDimsIndexerT, wi_delta_n, wi_delta_m>( + lhs_tp, rhs_tp, res_tp, local_A_block, + local_B_block, n, wg_delta_n, k, k_blocks, + wi_delta_k, m, m_blocks, wg_delta_m, batch_nelems, + batch_indexer, lhs_indexer, rhs_indexer, + res_indexer)); + }); + return gemm_ev; + } + else { + using ReductionOpT = sycl::plus; + constexpr resTy identity_val = + sycl::known_identity::value; + size_t iter_nelems = batch_nelems * n * m; + size_t reduction_nelems = (k + wi_delta_k - 1) / wi_delta_k; + + // more than one work-group is needed, requires a temporary + // delta_k * n_wi elements processed along k, so if more to + // process use multiple + const auto &sg_sizes = + dev.get_info(); + size_t wg = + choose_workgroup_size<4>(reduction_nelems, sg_sizes); + + if (reduction_nelems < wg) { + resTy *tmp = sycl::malloc_device( + iter_nelems * reduction_nelems, exec_q); + sycl::event gemm_ev = exec_q.submit([&](sycl::handler + &cgh) { + cgh.depends_on(depends); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + OuterInnerDimsIndexerT lhs_indexer{}; + OuterInnerDimsIndexerT rhs_indexer{}; + OuterInnerDimsIndexerT tmp_indexer{}; + using dpctl::tensor::offset_utils:: + ThreeOffsets_CombinedIndexer; + using dpctl::tensor::offset_utils::Strided1DIndexer; + using BatchDimsIndexerT = + ThreeOffsets_CombinedIndexer; + + BatchDimsIndexerT batch_indexer( + Strided1DIndexer{ + lhs_batch_offset, + static_cast(batch_nelems), + static_cast(n * k)}, + Strided1DIndexer{ + rhs_batch_offset, + static_cast(batch_nelems), + static_cast(k * m)}, + Strided1DIndexer{ + res_batch_offset, + static_cast(batch_nelems), + static_cast(n * m)}); + + size_t lws = wg_delta_n * wg_delta_m; + + size_t n_blocks = ((n + wi_delta_n * wg_delta_n - 1) / + (wi_delta_n * wg_delta_n)); + size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); + size_t m_blocks = ((m + wi_delta_m * wg_delta_m - 1) / + (wi_delta_m * wg_delta_m)); + + auto gwsRange = + sycl::range<1>(batch_nelems * n_blocks * m_blocks * + k_blocks * lws); + auto lwsRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); + + using LocAccT1 = sycl::local_accessor; + LocAccT1 local_A_block( + sycl::range<1>((wi_delta_n * wg_delta_n) * + wi_delta_k), + cgh); + using LocAccT2 = + sycl::local_accessor, + 1>; + LocAccT2 local_B_block( + sycl::range<1>(wi_delta_k * wg_delta_m), cgh); + + using KernelName = class gemm_batch_tree_nm_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, BatchDimsIndexerT, + wi_delta_m>; + cgh.parallel_for( + ndRange, + GemmBatchNoAtomicFunctorThreadNM< + lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, + BatchDimsIndexerT, wi_delta_n, wi_delta_m>( + lhs_tp, rhs_tp, tmp, local_A_block, + local_B_block, n, wg_delta_n, k, k_blocks, + wi_delta_k, m, m_blocks, wg_delta_m, + batch_nelems, batch_indexer, lhs_indexer, + rhs_indexer, tmp_indexer)); + }); + sycl::event red_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(gemm_ev); + + using NoOpIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils:: + TwoOffsets_CombinedIndexer; + using ReductionIndexerT = + dpctl::tensor::offset_utils::Strided1DIndexer; + + InputOutputIterIndexerT in_out_iter_indexer{ + NoOpIndexerT{}, 
NoOpIndexerT{}}; + ReductionIndexerT reduction_indexer{ + 0, static_cast(reduction_nelems), + static_cast(iter_nelems)}; + + sycl::range<1> iter_range{iter_nelems}; + + cgh.parallel_for>( + iter_range, + SequentialReduction( + tmp, res_tp, ReductionOpT(), identity_val, + in_out_iter_indexer, reduction_indexer, + reduction_nelems)); + }); + sycl::event cleanup_host_task_event = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(red_ev); + const sycl::context &ctx = exec_q.get_context(); + + cgh.host_task([ctx, tmp] { sycl::free(tmp, ctx); }); + }); + return cleanup_host_task_event; + } + + // max_max_wg prevents running out of resources on CPU + constexpr size_t max_max_wg = 2048; + size_t max_wg = std::min( + max_max_wg, + dev.get_info() / + 2); + + constexpr size_t preferred_reductions_per_wi = 4; + size_t reductions_per_wi(preferred_reductions_per_wi); + + size_t reduction_groups = + (reduction_nelems + preferred_reductions_per_wi * wg - 1) / + (preferred_reductions_per_wi * wg); + assert(reduction_groups > 1); + + resTy *partially_reduced_tmp = sycl::malloc_device( + iter_nelems * (/* temp */ reduction_nelems + + /* first reduction temp */ reduction_groups), + exec_q); + resTy *partially_reduced_tmp2 = nullptr; + + if (partially_reduced_tmp == nullptr) { + throw std::runtime_error( + "Unable to allocate device_memory"); + } + else { + partially_reduced_tmp2 = + partially_reduced_tmp + reduction_nelems * iter_nelems; + } + + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + OuterInnerDimsIndexerT lhs_indexer{}; + OuterInnerDimsIndexerT rhs_indexer{}; + OuterInnerDimsIndexerT tmp_indexer{}; + using dpctl::tensor::offset_utils:: + ThreeOffsets_CombinedIndexer; + using dpctl::tensor::offset_utils::Strided1DIndexer; + using BatchDimsIndexerT = ThreeOffsets_CombinedIndexer< + Strided1DIndexer, Strided1DIndexer, Strided1DIndexer>; + + BatchDimsIndexerT batch_indexer( + Strided1DIndexer{lhs_batch_offset, + static_cast(batch_nelems), + static_cast(n * k)}, + Strided1DIndexer{rhs_batch_offset, + static_cast(batch_nelems), + static_cast(k * m)}, + Strided1DIndexer{res_batch_offset, + static_cast(batch_nelems), + static_cast(n * m)}); + + size_t lws = wg_delta_n * wg_delta_m; + + size_t n_blocks = ((n + wi_delta_n * wg_delta_n - 1) / + (wi_delta_n * wg_delta_n)); + size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); + size_t m_blocks = ((m + wi_delta_m * wg_delta_m - 1) / + (wi_delta_m * wg_delta_m)); + + auto gwsRange = sycl::range<1>(batch_nelems * n_blocks * + m_blocks * k_blocks * lws); + auto lwsRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); + + using LocAccT1 = sycl::local_accessor; + LocAccT1 local_A_block( + sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), + cgh); + using LocAccT2 = + sycl::local_accessor, 1>; + LocAccT2 local_B_block( + sycl::range<1>(wi_delta_k * wg_delta_m), cgh); + + using KernelName = class gemm_batch_tree_nm_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, BatchDimsIndexerT, wi_delta_m>; + cgh.parallel_for( + ndRange, + GemmBatchNoAtomicFunctorThreadNM< + lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, + BatchDimsIndexerT, wi_delta_n, wi_delta_m>( + lhs_tp, rhs_tp, partially_reduced_tmp, + local_A_block, local_B_block, n, wg_delta_n, k, + k_blocks, wi_delta_k, m, m_blocks, wg_delta_m, + batch_nelems, batch_indexer, 
lhs_indexer, + rhs_indexer, tmp_indexer)); + }); + + sycl::event red_ev = + tree_reduction_for_gemm_contig( + exec_q, partially_reduced_tmp, partially_reduced_tmp2, + res_tp, identity_val, iter_nelems, reduction_nelems, + reduction_groups, wg, max_wg, + preferred_reductions_per_wi, reductions_per_wi, + {gemm_ev}); + + sycl::event cleanup_host_task_event = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(red_ev); + const sycl::context &ctx = exec_q.get_context(); + + cgh.host_task([ctx, partially_reduced_tmp] { + sycl::free(partially_reduced_tmp, ctx); + }); + }); + + return cleanup_host_task_event; + } + } + else { // m > 1, n > k or m > k, resTy not complex + constexpr int wi_delta_m = 1; + + gemm_detail::scale_gemm_nm_parameters( + local_mem_size, reserved_slm_size, wi_delta_n, + wi_delta_k, // modified by reference + wg_delta_n, // modified by reference + wg_delta_m // modified by reference + ); + + if (k <= wi_delta_k) { + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + OuterInnerDimsIndexerT lhs_indexer{}; + OuterInnerDimsIndexerT rhs_indexer{}; + OuterInnerDimsIndexerT res_indexer{}; + using dpctl::tensor::offset_utils:: + ThreeOffsets_CombinedIndexer; + using dpctl::tensor::offset_utils::Strided1DIndexer; + using BatchDimsIndexerT = ThreeOffsets_CombinedIndexer< + Strided1DIndexer, Strided1DIndexer, Strided1DIndexer>; + + BatchDimsIndexerT batch_indexer( + Strided1DIndexer{lhs_batch_offset, + static_cast(batch_nelems), + static_cast(n * k)}, + Strided1DIndexer{rhs_batch_offset, + static_cast(batch_nelems), + static_cast(k * m)}, + Strided1DIndexer{res_batch_offset, + static_cast(batch_nelems), + static_cast(n * m)}); + + size_t lws = wg_delta_n * wg_delta_m; + + size_t n_blocks = ((n + wi_delta_n * wg_delta_n - 1) / + (wi_delta_n * wg_delta_n)); + size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); + size_t m_blocks = ((m + wi_delta_m * wg_delta_m - 1) / + (wi_delta_m * wg_delta_m)); + + auto gwsRange = sycl::range<1>(batch_nelems * n_blocks * + m_blocks * k_blocks * lws); + auto lwsRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); + + using LocAccT1 = sycl::local_accessor; + LocAccT1 local_A_block( + sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), + cgh); + using LocAccT2 = sycl::local_accessor; + LocAccT2 local_B_block( + sycl::range<1>(wi_delta_k * wg_delta_m), cgh); + + using KernelName = class gemm_batch_tree_nm_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, BatchDimsIndexerT, wi_delta_m>; + cgh.parallel_for( + ndRange, + GemmBatchNoAtomicFunctorThreadNM< + lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, + BatchDimsIndexerT, wi_delta_n, wi_delta_m>( + lhs_tp, rhs_tp, res_tp, local_A_block, + local_B_block, n, wg_delta_n, k, k_blocks, + wi_delta_k, m, m_blocks, wg_delta_m, batch_nelems, + batch_indexer, lhs_indexer, rhs_indexer, + res_indexer)); + }); + return gemm_ev; + } + else { + using ReductionOpT = sycl::plus; + constexpr resTy identity_val = + sycl::known_identity::value; + size_t iter_nelems = batch_nelems * n * m; + size_t reduction_nelems = (k + wi_delta_k - 1) / wi_delta_k; + + // more than one work-group is needed, requires a temporary + // delta_k * n_wi elements processed along k, so if more to + // process use multiple + const auto &sg_sizes = + dev.get_info(); + size_t wg = + 
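Each of these fallback branches follows the same shape: the k extent is split into ceil(k / wi_delta_k) partial products per output element, and if that count is smaller than one work-group a single SequentialReduction pass folds the temporary into the result, otherwise a grouped tree reduction with preferred_reductions_per_wi = 4 is used. A compact sketch of that decision and of the first-level group count (with illustrative extents; wg normally comes from choose_workgroup_size<4>) is:

    #include <cstddef>
    #include <cstdio>

    int main()
    {
        // Illustrative extents only.
        const std::size_t k = 8192, wi_delta_k = 64, wg = 256;
        const std::size_t preferred_reductions_per_wi = 4;

        const std::size_t reduction_nelems = (k + wi_delta_k - 1) / wi_delta_k;
        if (reduction_nelems < wg) {
            std::printf("single sequential reduction pass (%zu partials)\n",
                        reduction_nelems);
        }
        else {
            const std::size_t reduction_groups =
                (reduction_nelems + preferred_reductions_per_wi * wg - 1) /
                (preferred_reductions_per_wi * wg);
            std::printf("tree reduction with %zu first-level group(s)\n",
                        reduction_groups);
        }
        return 0;
    }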
choose_workgroup_size<4>(reduction_nelems, sg_sizes); + + if (reduction_nelems < wg) { + resTy *tmp = sycl::malloc_device( + iter_nelems * reduction_nelems, exec_q); + sycl::event gemm_ev = exec_q.submit([&](sycl::handler + &cgh) { + cgh.depends_on(depends); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + OuterInnerDimsIndexerT lhs_indexer{}; + OuterInnerDimsIndexerT rhs_indexer{}; + OuterInnerDimsIndexerT tmp_indexer{}; + using dpctl::tensor::offset_utils:: + ThreeOffsets_CombinedIndexer; + using dpctl::tensor::offset_utils::Strided1DIndexer; + using BatchDimsIndexerT = + ThreeOffsets_CombinedIndexer; + + BatchDimsIndexerT batch_indexer( + Strided1DIndexer{ + lhs_batch_offset, + static_cast(batch_nelems), + static_cast(n * k)}, + Strided1DIndexer{ + rhs_batch_offset, + static_cast(batch_nelems), + static_cast(k * m)}, + Strided1DIndexer{ + res_batch_offset, + static_cast(batch_nelems), + static_cast(n * m)}); + + size_t lws = wg_delta_n * wg_delta_m; + + size_t n_blocks = ((n + wi_delta_n * wg_delta_n - 1) / + (wi_delta_n * wg_delta_n)); + size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); + size_t m_blocks = ((m + wi_delta_m * wg_delta_m - 1) / + (wi_delta_m * wg_delta_m)); + + auto gwsRange = + sycl::range<1>(batch_nelems * n_blocks * m_blocks * + k_blocks * lws); + auto lwsRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); + + using LocAccT1 = sycl::local_accessor; + LocAccT1 local_A_block( + sycl::range<1>((wi_delta_n * wg_delta_n) * + wi_delta_k), + cgh); + using LocAccT2 = sycl::local_accessor; + LocAccT2 local_B_block( + sycl::range<1>(wi_delta_k * wg_delta_m), cgh); + + using KernelName = class gemm_batch_tree_nm_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, BatchDimsIndexerT, + wi_delta_m>; + cgh.parallel_for( + ndRange, + GemmBatchNoAtomicFunctorThreadNM< + lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, + BatchDimsIndexerT, wi_delta_n, wi_delta_m>( + lhs_tp, rhs_tp, tmp, local_A_block, + local_B_block, n, wg_delta_n, k, k_blocks, + wi_delta_k, m, m_blocks, wg_delta_m, + batch_nelems, batch_indexer, lhs_indexer, + rhs_indexer, tmp_indexer)); + }); + sycl::event red_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(gemm_ev); + + using NoOpIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils:: + TwoOffsets_CombinedIndexer; + using ReductionIndexerT = + dpctl::tensor::offset_utils::Strided1DIndexer; + + InputOutputIterIndexerT in_out_iter_indexer{ + NoOpIndexerT{}, NoOpIndexerT{}}; + ReductionIndexerT reduction_indexer{ + 0, static_cast(reduction_nelems), + static_cast(iter_nelems)}; + + sycl::range<1> iter_range{iter_nelems}; + + cgh.parallel_for>( + iter_range, + SequentialReduction( + tmp, res_tp, ReductionOpT(), identity_val, + in_out_iter_indexer, reduction_indexer, + reduction_nelems)); + }); + sycl::event cleanup_host_task_event = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(red_ev); + const sycl::context &ctx = exec_q.get_context(); + + cgh.host_task([ctx, tmp] { sycl::free(tmp, ctx); }); + }); + return cleanup_host_task_event; + } + + // max_max_wg prevents running out of resources on CPU + constexpr size_t max_max_wg = 2048; + size_t max_wg = std::min( + max_max_wg, + dev.get_info() / + 2); + + constexpr size_t preferred_reductions_per_wi = 4; + size_t reductions_per_wi(preferred_reductions_per_wi); + + size_t reduction_groups = + 
(reduction_nelems + preferred_reductions_per_wi * wg - 1) / + (preferred_reductions_per_wi * wg); + assert(reduction_groups > 1); + + resTy *partially_reduced_tmp = sycl::malloc_device( + iter_nelems * (/* temp */ reduction_nelems + + /* first reduction temp */ reduction_groups), + exec_q); + resTy *partially_reduced_tmp2 = nullptr; + + if (partially_reduced_tmp == nullptr) { + throw std::runtime_error( + "Unable to allocate device_memory"); + } + else { + partially_reduced_tmp2 = + partially_reduced_tmp + reduction_nelems * iter_nelems; + } + + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + OuterInnerDimsIndexerT lhs_indexer{}; + OuterInnerDimsIndexerT rhs_indexer{}; + OuterInnerDimsIndexerT tmp_indexer{}; + using dpctl::tensor::offset_utils:: + ThreeOffsets_CombinedIndexer; + using dpctl::tensor::offset_utils::Strided1DIndexer; + using BatchDimsIndexerT = ThreeOffsets_CombinedIndexer< + Strided1DIndexer, Strided1DIndexer, Strided1DIndexer>; + + BatchDimsIndexerT batch_indexer( + Strided1DIndexer{lhs_batch_offset, + static_cast(batch_nelems), + static_cast(n * k)}, + Strided1DIndexer{rhs_batch_offset, + static_cast(batch_nelems), + static_cast(k * m)}, + Strided1DIndexer{res_batch_offset, + static_cast(batch_nelems), + static_cast(n * m)}); + + size_t lws = wg_delta_n * wg_delta_m; + + size_t n_blocks = ((n + wi_delta_n * wg_delta_n - 1) / + (wi_delta_n * wg_delta_n)); + size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); + size_t m_blocks = ((m + wi_delta_m * wg_delta_m - 1) / + (wi_delta_m * wg_delta_m)); + + auto gwsRange = sycl::range<1>(batch_nelems * n_blocks * + m_blocks * k_blocks * lws); + auto lwsRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); + + using LocAccT1 = sycl::local_accessor; + LocAccT1 local_A_block( + sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), + cgh); + using LocAccT2 = sycl::local_accessor; + LocAccT2 local_B_block( + sycl::range<1>(wi_delta_k * wg_delta_m), cgh); + + using KernelName = class gemm_batch_tree_nm_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, BatchDimsIndexerT, wi_delta_m>; + cgh.parallel_for( + ndRange, + GemmBatchNoAtomicFunctorThreadNM< + lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, + BatchDimsIndexerT, wi_delta_n, wi_delta_m>( + lhs_tp, rhs_tp, partially_reduced_tmp, + local_A_block, local_B_block, n, wg_delta_n, k, + k_blocks, wi_delta_k, m, m_blocks, wg_delta_m, + batch_nelems, batch_indexer, lhs_indexer, + rhs_indexer, tmp_indexer)); + }); + + sycl::event red_ev = + tree_reduction_for_gemm_contig( + exec_q, partially_reduced_tmp, partially_reduced_tmp2, + res_tp, identity_val, iter_nelems, reduction_nelems, + reduction_groups, wg, max_wg, + preferred_reductions_per_wi, reductions_per_wi, + {gemm_ev}); + + sycl::event cleanup_host_task_event = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(red_ev); + const sycl::context &ctx = exec_q.get_context(); + + cgh.host_task([ctx, partially_reduced_tmp] { + sycl::free(partially_reduced_tmp, ctx); + }); + }); + + return cleanup_host_task_event; + } + } + } +} + +} // namespace kernels +} // namespace tensor +} // namespace dpctl diff --git a/dpctl/tensor/libtensor/source/linalg_functions/dot.cpp b/dpctl/tensor/libtensor/source/linalg_functions/dot.cpp new file mode 100644 index 0000000000..926f5ffad6 --- /dev/null +++ 
b/dpctl/tensor/libtensor/source/linalg_functions/dot.cpp @@ -0,0 +1,857 @@ +#include "dpctl4pybind11.hpp" +#include +#include +#include +#include +#include +#include +#include + +#include "dot.hpp" +#include "dot_atomic_support.hpp" +#include "dot_dispatch.hpp" +#include "elementwise_functions/elementwise_functions_type_utils.hpp" +#include "kernels/linalg_functions/dot_product.hpp" +#include "kernels/linalg_functions/gemm.hpp" +#include "reductions/reduction_atomic_support.hpp" +#include "simplify_iteration_space.hpp" +#include "utils/memory_overlap.hpp" +#include "utils/offset_utils.hpp" + +namespace dpctl +{ +namespace tensor +{ +namespace py_internal +{ + +namespace td_ns = dpctl::tensor::type_dispatch; + +static int dot_output_id_table[td_ns::num_types][td_ns::num_types]; + +using dpctl::tensor::kernels::dot_product_impl_fn_ptr_t; +static dot_product_impl_fn_ptr_t dot_product_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + +static dot_product_impl_fn_ptr_t + dot_product_temps_dispatch_table[td_ns::num_types][td_ns::num_types]; + +using dpctl::tensor::kernels::dot_product_contig_impl_fn_ptr_t; +static dot_product_contig_impl_fn_ptr_t + dot_product_contig_dispatch_table[td_ns::num_types][td_ns::num_types]; + +static dot_product_contig_impl_fn_ptr_t + dot_product_contig_temps_dispatch_table[td_ns::num_types][td_ns::num_types]; + +using dpctl::tensor::kernels::gemm_impl_fn_ptr_t; +static gemm_impl_fn_ptr_t gemm_atomic_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + +static gemm_impl_fn_ptr_t gemm_temps_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + +using dpctl::tensor::kernels::gemm_contig_impl_fn_ptr_t; +static gemm_contig_impl_fn_ptr_t + gemm_contig_atomic_dispatch_table[td_ns::num_types][td_ns::num_types]; + +static gemm_contig_impl_fn_ptr_t + gemm_contig_temps_dispatch_table[td_ns::num_types][td_ns::num_types]; + +using dpctl::tensor::kernels::gemm_batch_impl_fn_ptr_t; +static gemm_batch_impl_fn_ptr_t + gemm_batch_atomic_dispatch_table[td_ns::num_types][td_ns::num_types]; + +static gemm_batch_impl_fn_ptr_t + gemm_batch_temps_dispatch_table[td_ns::num_types][td_ns::num_types]; + +using dpctl::tensor::kernels::gemm_batch_contig_impl_fn_ptr_t; +static gemm_batch_contig_impl_fn_ptr_t + gemm_batch_contig_atomic_dispatch_table[td_ns::num_types][td_ns::num_types]; + +static gemm_batch_contig_impl_fn_ptr_t + gemm_batch_contig_temps_dispatch_table[td_ns::num_types][td_ns::num_types]; + +void init_dot_dispatch_tables(void) +{ + using dpctl::tensor::py_internal::DotTypeMapFactory; + td_ns::DispatchTableBuilder dtb1; + dtb1.populate_dispatch_table(dot_output_id_table); + + using dpctl::tensor::py_internal::GemmBatchAtomicFactory; + td_ns::DispatchTableBuilder + dtb2; + dtb2.populate_dispatch_table(gemm_batch_atomic_dispatch_table); + + using dpctl::tensor::py_internal::GemmBatchContigAtomicFactory; + td_ns::DispatchTableBuilder + dtb3; + dtb3.populate_dispatch_table(gemm_batch_contig_atomic_dispatch_table); + + using dpctl::tensor::py_internal::GemmAtomicFactory; + td_ns::DispatchTableBuilder + dtb4; + dtb4.populate_dispatch_table(gemm_atomic_dispatch_table); + + using dpctl::tensor::py_internal::GemmContigAtomicFactory; + td_ns::DispatchTableBuilder + dtb5; + dtb5.populate_dispatch_table(gemm_contig_atomic_dispatch_table); + + using dpctl::tensor::py_internal::GemmBatchTempsFactory; + td_ns::DispatchTableBuilder + dtb6; + dtb6.populate_dispatch_table(gemm_batch_temps_dispatch_table); + + using dpctl::tensor::py_internal::GemmBatchContigTempsFactory; + 
td_ns::DispatchTableBuilder + dtb7; + dtb7.populate_dispatch_table(gemm_batch_contig_temps_dispatch_table); + + using dpctl::tensor::py_internal::GemmTempsFactory; + td_ns::DispatchTableBuilder + dtb8; + dtb8.populate_dispatch_table(gemm_temps_dispatch_table); + + using dpctl::tensor::py_internal::GemmContigTempsFactory; + td_ns::DispatchTableBuilder + dtb9; + dtb9.populate_dispatch_table(gemm_contig_temps_dispatch_table); + + using dpctl::tensor::py_internal::DotProductAtomicFactory; + td_ns::DispatchTableBuilder + dtb10; + dtb10.populate_dispatch_table(dot_product_dispatch_table); + + using dpctl::tensor::py_internal::DotProductNoAtomicFactory; + td_ns::DispatchTableBuilder + dtb11; + dtb11.populate_dispatch_table(dot_product_temps_dispatch_table); + + using dpctl::tensor::py_internal::DotProductContigAtomicFactory; + td_ns::DispatchTableBuilder + dtb12; + dtb12.populate_dispatch_table(dot_product_contig_dispatch_table); + + using dpctl::tensor::py_internal::DotProductContigNoAtomicFactory; + td_ns::DispatchTableBuilder + dtb13; + dtb13.populate_dispatch_table(dot_product_contig_temps_dispatch_table); +} + +using atomic_support::atomic_support_fn_ptr_t; +static atomic_support_fn_ptr_t dot_atomic_support_vector[td_ns::num_types]; + +void init_dot_atomic_support_vector(void) +{ + + using atomic_support::DotAtomicSupportFactory; + td_ns::DispatchVectorBuilder + dvb; + dvb.populate_dispatch_vector(dot_atomic_support_vector); +} + +std::pair +py_dot(const dpctl::tensor::usm_ndarray &x1, + const dpctl::tensor::usm_ndarray &x2, + int batch_dims, + int x1_outer_dims, + int x2_outer_dims, + int inner_dims, + const dpctl::tensor::usm_ndarray &dst, + sycl::queue &exec_q, + const std::vector &depends) +{ + + if (!dst.is_writable()) { + throw py::value_error("Output array is read-only."); + } + + if (inner_dims == 0) { + throw py::value_error("No inner dimension for dot"); + } + + if (!dpctl::utils::queues_are_compatible(exec_q, {x1, x2, dst})) { + throw py::value_error( + "Execution queue is not compatible with allocation queues"); + } + + int x1_nd = x1.get_ndim(); + int x2_nd = x2.get_ndim(); + if (x1_nd != (batch_dims + x1_outer_dims + inner_dims) || + x2_nd != (batch_dims + x2_outer_dims + inner_dims)) + { + throw py::value_error("Input arrays do not have dimensions consistent " + "with input dimensions"); + } + + int dst_nd = dst.get_ndim(); + if (dst_nd != (batch_dims + x1_outer_dims + x2_outer_dims)) { + throw py::value_error("Destination array rank does not match input " + "array rank and number of input dimensions"); + } + + const py::ssize_t *x1_shape_ptr = x1.get_shape_raw(); + const py::ssize_t *x2_shape_ptr = x2.get_shape_raw(); + const py::ssize_t *dst_shape_ptr = dst.get_shape_raw(); + + bool same_shapes = true; + size_t batches(1); + for (int i = 0; same_shapes && (i < batch_dims); ++i) { + same_shapes = same_shapes && (x1_shape_ptr[i] == dst_shape_ptr[i]) && + (x2_shape_ptr[i] == dst_shape_ptr[i]); + batches *= x1_shape_ptr[i]; + } + size_t x1_outer_nelems(1); + for (int i = batch_dims; same_shapes && (i < (batch_dims + x1_outer_dims)); + ++i) { + same_shapes = same_shapes && (x1_shape_ptr[i] == dst_shape_ptr[i]); + x1_outer_nelems *= x1_shape_ptr[i]; + } + size_t inner_nelems(1); + for (int i = batch_dims; i < (batch_dims + inner_dims); ++i) { + auto x1_shape_idx = x1_outer_dims + i; + same_shapes = + same_shapes && (x1_shape_ptr[x1_shape_idx] == x2_shape_ptr[i]); + inner_nelems *= x1_shape_ptr[x1_shape_idx]; + } + size_t x2_outer_nelems(1); + for (int i = 0; same_shapes && (i < 
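The dispatch machinery populated above is a family of num_types by num_types tables of kernel function pointers, filled in per (lhs, rhs) type pair by the factory classes; py_dot below converts NumPy type numbers to lookup ids and indexes the tables, treating a null entry as an unsupported combination. A stripped-down model of that lookup, with a hypothetical kernel signature and lookup id (not dpctl's actual values), is:

    #include <cstdio>
    #include <stdexcept>
    #include <string>

    // Hypothetical, simplified stand-ins for the real dispatch machinery.
    constexpr int num_types = 14;
    using kernel_fn_ptr_t = void (*)(); // real tables store richer signatures

    static kernel_fn_ptr_t dispatch_table[num_types][num_types] = {};

    static void float_float_kernel() { std::puts("float x float kernel"); }

    int main()
    {
        const int float_id = 10; // illustrative lookup id only
        dispatch_table[float_id][float_id] = &float_float_kernel;

        const int x1_typeid = float_id, x2_typeid = float_id;
        kernel_fn_ptr_t fn = dispatch_table[x1_typeid][x2_typeid];
        if (fn == nullptr) {
            throw std::runtime_error("Implementation is missing for x1_typeid=" +
                                     std::to_string(x1_typeid));
        }
        fn();
        return 0;
    }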
x2_outer_dims); ++i) { + auto x2_shape_idx = batch_dims + inner_dims + i; + same_shapes = + same_shapes && (x2_shape_ptr[x2_shape_idx] == + dst_shape_ptr[batch_dims + x1_outer_dims + i]); + x2_outer_nelems *= x2_shape_ptr[x2_shape_idx]; + } + if (!same_shapes) { + throw py::value_error("Input arrays to tensor dot product do not have " + "appropriate shapes"); + } + + size_t dst_nelems = batches * x1_outer_nelems * x2_outer_nelems; + if (dst_nelems == 0) { + return std::make_pair(sycl::event(), sycl::event()); + } + + if (static_cast(dst.get_size()) != dst_nelems) { + throw py::value_error("dst shape and size mismatch"); + } + + // ensure that dst is sufficiently ample + auto dst_offsets = dst.get_minmax_offsets(); + // destination must be ample enough to accommodate all elements + { + size_t range = + static_cast(dst_offsets.second - dst_offsets.first); + if (range + 1 < dst_nelems) { + throw py::value_error( + "Memory addressed by the destination array can not " + "accommodate all the " + "array elements."); + } + } + + auto const &overlap = dpctl::tensor::overlap::MemoryOverlap(); + // check that dst does not intersect with x1 or x2 + if (overlap(dst, x1) || overlap(dst, x2)) { + throw py::value_error("Result array overlaps with inputs"); + } + + int x1_typenum = x1.get_typenum(); + int x2_typenum = x2.get_typenum(); + int dst_typenum = dst.get_typenum(); + + auto const &array_types = td_ns::usm_ndarray_types(); + int x1_typeid = array_types.typenum_to_lookup_id(x1_typenum); + int x2_typeid = array_types.typenum_to_lookup_id(x2_typenum); + int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum); + + int output_typeid = dot_output_id_table[x1_typeid][x2_typeid]; + + if (output_typeid != dst_typeid) { + throw py::value_error( + "Result array has unexpected elemental data type."); + } + + void *data_ptr = dst.get_data(); + const auto &ctx = exec_q.get_context(); + auto usm_type = sycl::get_pointer_type(data_ptr, ctx); + bool supports_atomics = + dot_atomic_support_vector[output_typeid](exec_q, usm_type); + + const char *x1_data = x1.get_data(); + const char *x2_data = x2.get_data(); + char *dst_data = dst.get_data(); + + auto x1_shape_vec = x1.get_shape_vector(); + auto x1_strides_vec = x1.get_strides_vector(); + + auto x2_shape_vec = x2.get_shape_vector(); + auto x2_strides_vec = x2.get_strides_vector(); + + auto dst_shape_vec = dst.get_shape_vector(); + auto dst_strides_vec = dst.get_strides_vector(); + + bool is_x1_c_contig = x1.is_c_contiguous(); + bool is_x1_f_contig = x1.is_f_contiguous(); + bool is_x2_c_contig = x2.is_c_contiguous(); + bool is_x2_f_contig = x2.is_f_contiguous(); + bool is_dst_c_contig = dst.is_c_contiguous(); + + bool call_vecdot = ((x1_outer_dims == 0 && x1_outer_nelems == 1) && + (x2_outer_dims == 0 && x2_outer_nelems == 1)); + + bool call_batched = (batch_dims != 0 || batches > 1); + std::vector host_task_events{}; + sycl::event dot_ev; + if (call_vecdot) { + if ((is_x1_c_contig && is_x2_c_contig && is_dst_c_contig) || + ((is_x1_f_contig && is_x2_f_contig) && !call_batched)) + { + dot_product_contig_impl_fn_ptr_t fn = nullptr; + if (supports_atomics) { + fn = dot_product_contig_dispatch_table[x1_typeid][x2_typeid]; + } + else { + fn = dot_product_contig_temps_dispatch_table[x1_typeid] + [x2_typeid]; + } + if (fn != nullptr) { + constexpr py::ssize_t zero_offset = 0; + dot_ev = fn(exec_q, dst_nelems, inner_nelems, x1.get_data(), + x2.get_data(), dst.get_data(), + zero_offset, // lhs batch offset + zero_offset, // rhs batch offset + zero_offset, // res batch 
offset + zero_offset, // lhs reduction offset + zero_offset, // rhs reduction offset + depends); + return std::make_pair(dpctl::utils::keep_args_alive( + exec_q, {x1, x2, dst}, {dot_ev}), + dot_ev); + } + } + using dpctl::tensor::py_internal::simplify_iteration_space; + using dpctl::tensor::py_internal::simplify_iteration_space_3; + + int inner_nd = inner_dims; + const py::ssize_t *inner_shape_ptr = x1_shape_ptr + batch_dims; + using shT = std::vector; + shT inner_x1_strides(std::begin(x1_strides_vec) + batch_dims, + std::end(x1_strides_vec)); + shT inner_x2_strides(std::begin(x2_strides_vec) + batch_dims, + std::end(x2_strides_vec)); + + shT simplified_inner_shape; + shT simplified_inner_x1_strides; + shT simplified_inner_x2_strides; + py::ssize_t inner_x1_offset(0); + py::ssize_t inner_x2_offset(0); + + simplify_iteration_space( + inner_nd, inner_shape_ptr, inner_x1_strides, inner_x2_strides, + // output + simplified_inner_shape, simplified_inner_x1_strides, + simplified_inner_x2_strides, inner_x1_offset, inner_x2_offset); + + const py::ssize_t *batch_shape_ptr = x1_shape_ptr; + + shT batch_x1_strides(std::begin(x1_strides_vec), + std::begin(x1_strides_vec) + batch_dims); + shT batch_x2_strides(std::begin(x2_strides_vec), + std::begin(x2_strides_vec) + batch_dims); + shT const &batch_dst_strides = dst_strides_vec; + + shT simplified_batch_shape; + shT simplified_batch_x1_strides; + shT simplified_batch_x2_strides; + shT simplified_batch_dst_strides; + py::ssize_t batch_x1_offset(0); + py::ssize_t batch_x2_offset(0); + py::ssize_t batch_dst_offset(0); + + if (batch_dims == 0) { + if (dst_nelems != 1) { + throw std::runtime_error( + "batch_dims == 0, but dst_nelems != 1"); + } + batch_dims = 1; + simplified_batch_shape.push_back(1); + simplified_batch_x1_strides.push_back(0); + simplified_batch_x2_strides.push_back(0); + simplified_batch_dst_strides.push_back(0); + } + else { + simplify_iteration_space_3( + batch_dims, batch_shape_ptr, batch_x1_strides, batch_x2_strides, + batch_dst_strides, + // output + simplified_batch_shape, simplified_batch_x1_strides, + simplified_batch_x2_strides, simplified_batch_dst_strides, + batch_x1_offset, batch_x2_offset, batch_dst_offset); + } + + if (inner_nd == 1 && batch_dims == 1) { + bool dot_product_c_contig = false; + bool reduce_all_elems = false; + + if (simplified_inner_x1_strides[0] == 1 && + simplified_inner_x2_strides[0] == 1) { + reduce_all_elems = (simplified_batch_shape[0] == 1); + dot_product_c_contig = + (simplified_batch_dst_strides[0] == 1) && + (static_cast(simplified_batch_x1_strides[0]) == + inner_nelems) && + (static_cast(simplified_batch_x2_strides[0]) == + inner_nelems); + } + + if (dot_product_c_contig || reduce_all_elems) { + dot_product_contig_impl_fn_ptr_t fn = nullptr; + if (supports_atomics) { + fn = + dot_product_contig_dispatch_table[x1_typeid][x2_typeid]; + } + else { + fn = dot_product_contig_temps_dispatch_table[x1_typeid] + [x2_typeid]; + } + if (fn != nullptr) { + dot_ev = fn(exec_q, dst_nelems, inner_nelems, x1.get_data(), + x2.get_data(), dst.get_data(), + batch_x1_offset, // lhs batch offset + batch_x2_offset, // rhs batch offset + batch_dst_offset, // res batch offset + inner_x1_offset, // lhs reduction offset + inner_x2_offset, // rhs reduction offset + depends); + return std::make_pair(dpctl::utils::keep_args_alive( + exec_q, {x1, x2, dst}, {dot_ev}), + dot_ev); + } + } + } + + dot_product_impl_fn_ptr_t fn = nullptr; + if (supports_atomics) { + fn = dot_product_dispatch_table[x1_typeid][x2_typeid]; + } + if (fn == 
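When the call reduces everything (batch_dims == 0) the code above still hands the strided kernel a one-dimensional iteration space: it promotes the batch to a single length-1 axis with zero strides so that downstream indexers never see an empty shape. A tiny self-contained sketch of that promotion, with a plain integer vector standing in for std::vector<py::ssize_t>, is:

    #include <cstdio>
    #include <vector>

    int main()
    {
        using shT = std::vector<long long>; // stand-in for std::vector<py::ssize_t>
        int batch_dims = 0;
        shT batch_shape, x1_strides, x2_strides, dst_strides;

        if (batch_dims == 0) {
            // One trivial batch iteration: shape 1, all strides 0.
            batch_dims = 1;
            batch_shape.push_back(1);
            x1_strides.push_back(0);
            x2_strides.push_back(0);
            dst_strides.push_back(0);
        }
        std::printf("batch_dims=%d, batch_shape[0]=%lld\n", batch_dims,
                    batch_shape[0]);
        return 0;
    }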
nullptr) { + fn = dot_product_temps_dispatch_table[x1_typeid][x2_typeid]; + if (fn == nullptr) { + throw std::runtime_error( + "Implementation is missing for x1_typeid=" + + std::to_string(x1_typeid) + + " and x2_typeid=" + std::to_string(x2_typeid)); + } + } + + using dpctl::tensor::offset_utils::device_allocate_and_pack; + const auto &arrays_metainfo_packing_triple_ = + device_allocate_and_pack( + exec_q, host_task_events, + // iteration metadata + simplified_batch_shape, simplified_batch_x1_strides, + simplified_batch_x2_strides, simplified_batch_dst_strides, + // reduction metadata + simplified_inner_shape, simplified_inner_x1_strides, + simplified_inner_x2_strides); + py::ssize_t *temp_allocation_ptr = + std::get<0>(arrays_metainfo_packing_triple_); + if (temp_allocation_ptr == nullptr) { + throw std::runtime_error("Unable to allocate memory on device"); + } + const auto ©_metadata_ev = + std::get<2>(arrays_metainfo_packing_triple_); + + py::ssize_t *iter_shape_and_strides = temp_allocation_ptr; + py::ssize_t *inner_shape_stride = + temp_allocation_ptr + 4 * simplified_batch_shape.size(); + + std::vector all_deps; + all_deps.reserve(depends.size() + 1); + all_deps.resize(depends.size()); + std::copy(depends.begin(), depends.end(), all_deps.begin()); + all_deps.push_back(copy_metadata_ev); + + dot_ev = + fn(exec_q, dst_nelems, inner_nelems, x1.get_data(), x2.get_data(), + dst.get_data(), batch_dims, iter_shape_and_strides, + batch_x1_offset, batch_x2_offset, batch_dst_offset, + inner_nd, // number dimensions being reduced + inner_shape_stride, inner_x1_offset, inner_x2_offset, all_deps); + + sycl::event temp_cleanup_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(dot_ev); + const auto &ctx = exec_q.get_context(); + cgh.host_task([ctx, temp_allocation_ptr] { + sycl::free(temp_allocation_ptr, ctx); + }); + }); + host_task_events.push_back(temp_cleanup_ev); + } + else { // if (!call_vecdot) + if (!call_batched) { + if ((is_x1_c_contig && is_x2_c_contig && is_dst_c_contig)) { + gemm_contig_impl_fn_ptr_t fn = nullptr; + if (supports_atomics) { + fn = + gemm_contig_atomic_dispatch_table[x1_typeid][x2_typeid]; + } + else { + fn = gemm_contig_temps_dispatch_table[x1_typeid][x2_typeid]; + } + if (fn != nullptr) { + dot_ev = fn(exec_q, x1_data, x2_data, dst_data, + x1_outer_nelems, // n + inner_nelems, // k + x2_outer_nelems, // m + depends); + return std::make_pair(dpctl::utils::keep_args_alive( + exec_q, {x1, x2, dst}, {dot_ev}), + dot_ev); + } + } + gemm_impl_fn_ptr_t fn = nullptr; + if (supports_atomics) { + fn = gemm_atomic_dispatch_table[x1_typeid][x2_typeid]; + } + if (fn == nullptr) { + fn = gemm_temps_dispatch_table[x1_typeid][x2_typeid]; + if (fn == nullptr) { + throw std::runtime_error( + "Implementation is missing for x1_typeid=" + + std::to_string(x1_typeid) + + " and x2_typeid=" + std::to_string(x2_typeid)); + } + } + using dpctl::tensor::offset_utils::device_allocate_and_pack; + const auto &ptr_size_event_tuple1 = + device_allocate_and_pack( + exec_q, host_task_events, x1_shape_vec, x1_strides_vec, + x2_shape_vec, x2_strides_vec, dst_shape_vec, + dst_strides_vec); + py::ssize_t *packed_shapes_strides = + std::get<0>(ptr_size_event_tuple1); + if (packed_shapes_strides == nullptr) { + throw std::runtime_error("Unable to allocate device memory"); + } + sycl::event copy_shapes_strides_ev = + std::get<2>(ptr_size_event_tuple1); + py::ssize_t *x1_shape_strides = packed_shapes_strides; + py::ssize_t *x2_shape_strides = packed_shapes_strides + 2 * (x1_nd); + py::ssize_t 
*dst_shape_strides = + packed_shapes_strides + 2 * (x1_nd + x2_nd); + + std::vector all_deps; + all_deps.reserve(depends.size() + 1); + all_deps.insert(all_deps.end(), depends.begin(), depends.end()); + all_deps.push_back(copy_shapes_strides_ev); + + // change gemm calls to pass inner dims and outer dims separately + dot_ev = + fn(exec_q, x1_data, x2_data, dst_data, x1_outer_nelems, + inner_nelems, x2_outer_nelems, inner_dims, x1_outer_dims, + x1_shape_strides, x2_outer_dims, x2_shape_strides, + x1_outer_dims + x2_outer_dims, dst_shape_strides, all_deps); + + sycl::event cleanup_tmp_allocations_ev = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(dot_ev); + const auto &ctx = exec_q.get_context(); + cgh.host_task([ctx, packed_shapes_strides] { + sycl::free(packed_shapes_strides, ctx); + }); + }); + host_task_events.push_back(cleanup_tmp_allocations_ev); + host_task_events.push_back(dot_ev); + } + else { // if (call_batched) + using shT = std::vector; + // temporary asserts for matmul + assert(x1_outer_dims == 1); + assert(x2_outer_dims == 1); + assert(inner_dims == 1); + + if ((is_x1_c_contig && is_x2_c_contig && is_dst_c_contig)) { + gemm_batch_contig_impl_fn_ptr_t fn = nullptr; + if (supports_atomics) { + fn = gemm_batch_contig_atomic_dispatch_table[x1_typeid] + [x2_typeid]; + } + else { + fn = gemm_batch_contig_temps_dispatch_table[x1_typeid] + [x2_typeid]; + } + if (fn != nullptr) { + constexpr py::ssize_t zero_offset = 0; + dot_ev = fn(exec_q, x1_data, x2_data, dst_data, batches, + x1_outer_nelems, // n + inner_nelems, // k + x2_outer_nelems, // m + zero_offset, zero_offset, zero_offset, depends); + return std::make_pair(dpctl::utils::keep_args_alive( + exec_q, {x1, x2, dst}, {dot_ev}), + dot_ev); + } + } + + auto x1_outer_inner_dims = x1_nd - batch_dims; + auto x2_outer_inner_dims = x2_nd - batch_dims; + auto dst_outer_inner_dims = dst_nd - batch_dims; + + shT batch_x1_shape; + shT outer_inner_x1_shape; + shT batch_x1_strides; + shT outer_inner_x1_strides; + dpctl::tensor::py_internal::split_iteration_space( + x1_shape_vec, x1_strides_vec, batch_dims, + batch_dims + x1_outer_inner_dims, batch_x1_shape, + outer_inner_x1_shape, // 4 vectors modified + batch_x1_strides, outer_inner_x1_strides); + + shT batch_x2_shape; + shT outer_inner_x2_shape; + shT batch_x2_strides; + shT outer_inner_x2_strides; + dpctl::tensor::py_internal::split_iteration_space( + x2_shape_vec, x2_strides_vec, batch_dims, + batch_dims + x2_outer_inner_dims, batch_x2_shape, + outer_inner_x2_shape, // 4 vectors modified + batch_x2_strides, outer_inner_x2_strides); + + shT batch_dst_shape; + shT outer_inner_dst_shape; + shT batch_dst_strides; + shT outer_inner_dst_strides; + dpctl::tensor::py_internal::split_iteration_space( + dst_shape_vec, dst_strides_vec, batch_dims, + batch_dims + dst_outer_inner_dims, batch_dst_shape, + outer_inner_dst_shape, // 4 vectors modified + batch_dst_strides, outer_inner_dst_strides); + + using shT = std::vector; + shT simplified_batch_shape; + shT simplified_batch_x1_strides; + shT simplified_batch_x2_strides; + shT simplified_batch_dst_strides; + py::ssize_t x1_batch_offset(0); + py::ssize_t x2_batch_offset(0); + py::ssize_t dst_batch_offset(0); + + const py::ssize_t *shape = x1_shape_ptr; + + using dpctl::tensor::py_internal::simplify_iteration_space_3; + simplify_iteration_space_3( + batch_dims, shape, batch_x1_strides, batch_x2_strides, + batch_dst_strides, + // outputs + simplified_batch_shape, simplified_batch_x1_strides, + simplified_batch_x2_strides, 
simplified_batch_dst_strides, + x1_batch_offset, x2_batch_offset, dst_batch_offset); + + if (batch_dims == 1 && x1_outer_dims == 1 && x2_outer_dims == 1 && + inner_dims == 1) + { + bool gemm_batch_c_contig = false; + + if ((static_cast(outer_inner_x1_strides[0]) == + inner_nelems && + outer_inner_x1_strides[1] == 1) && + (static_cast(outer_inner_x2_strides[0]) == + inner_nelems && + outer_inner_x2_strides[1] == 1) && + (static_cast(outer_inner_dst_strides[0]) == + x2_outer_nelems && + outer_inner_dst_strides[1] == 1)) + { + gemm_batch_c_contig = + (static_cast(simplified_batch_x1_strides[0]) == + x1_outer_nelems * inner_nelems) && + (static_cast(simplified_batch_x2_strides[0]) == + x2_outer_nelems * inner_nelems) && + (static_cast(simplified_batch_dst_strides[0]) == + x1_outer_nelems * x2_outer_nelems); + } + + if (gemm_batch_c_contig) { + gemm_batch_contig_impl_fn_ptr_t fn = nullptr; + if (supports_atomics) { + fn = gemm_batch_contig_atomic_dispatch_table[x1_typeid] + [x2_typeid]; + } + else { + fn = gemm_batch_contig_temps_dispatch_table[x1_typeid] + [x2_typeid]; + } + if (fn != nullptr) { + dot_ev = fn(exec_q, x1_data, x2_data, dst_data, batches, + x1_outer_nelems, // n + inner_nelems, // k + x2_outer_nelems, // m + x1_batch_offset, x2_batch_offset, + dst_batch_offset, depends); + return std::make_pair( + dpctl::utils::keep_args_alive(exec_q, {x1, x2, dst}, + {dot_ev}), + dot_ev); + } + } + } + + gemm_batch_impl_fn_ptr_t fn = nullptr; + if (supports_atomics) { + fn = gemm_batch_atomic_dispatch_table[x1_typeid][x2_typeid]; + } + if (fn == nullptr) { + fn = gemm_batch_temps_dispatch_table[x1_typeid][x2_typeid]; + if (fn == nullptr) { + throw std::runtime_error( + "Implementation is missing for x1_typeid=" + + std::to_string(x1_typeid) + + " and x2_typeid=" + std::to_string(x2_typeid)); + } + } + using dpctl::tensor::offset_utils::device_allocate_and_pack; + const auto &ptr_size_event_tuple1 = + device_allocate_and_pack( + exec_q, host_task_events, simplified_batch_shape, + simplified_batch_x1_strides, simplified_batch_x2_strides, + simplified_batch_dst_strides, outer_inner_x1_shape, + outer_inner_x1_strides, outer_inner_x2_shape, + outer_inner_x2_strides, outer_inner_dst_shape, + outer_inner_dst_strides, + // full shape and strides of the result array + // necessary for reduction and initialization + simplified_batch_shape, outer_inner_dst_shape, + simplified_batch_dst_strides, outer_inner_dst_strides); + py::ssize_t *packed_shapes_strides = + std::get<0>(ptr_size_event_tuple1); + if (packed_shapes_strides == nullptr) { + throw std::runtime_error("Unable to allocate device memory"); + } + sycl::event copy_shapes_strides_ev = + std::get<2>(ptr_size_event_tuple1); + + auto batch_shape_strides = packed_shapes_strides; + auto x1_outer_inner_shapes_strides = + packed_shapes_strides + 4 * batch_dims; + auto x2_outer_inner_shapes_strides = packed_shapes_strides + + 4 * batch_dims + + 2 * (x1_outer_inner_dims); + auto dst_outer_shapes_strides = + packed_shapes_strides + 4 * batch_dims + + 2 * (x1_outer_inner_dims) + 2 * (x2_outer_inner_dims); + auto dst_full_shape_strides = + packed_shapes_strides + 4 * batch_dims + + 2 * (x1_outer_inner_dims) + 2 * (x2_outer_inner_dims) + + 2 * (dst_outer_inner_dims); + + std::vector all_deps; + all_deps.reserve(depends.size() + 1); + all_deps.insert(all_deps.end(), depends.begin(), depends.end()); + all_deps.push_back(copy_shapes_strides_ev); + + dot_ev = fn( + exec_q, x1_data, x2_data, dst_data, batches, x1_outer_nelems, + inner_nelems, x2_outer_nelems, 
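The single device allocation packed above is carved into sub-views by fixed offsets: 4 * batch_dims entries of batch shape/strides, then 2 * (outer + inner dims) entries for each of lhs, rhs, and dst, followed by the full result shape/strides needed for reduction and initialization. A small sketch that just recomputes those offsets, mirroring the pointer arithmetic above with illustrative ranks for a batched matmul, is:

    #include <cstddef>
    #include <cstdio>

    int main()
    {
        // Illustrative ranks: 1 batch dimension, 2-d operands.
        const std::size_t batch_dims = 1;
        const std::size_t x1_outer_inner_dims = 2;
        const std::size_t x2_outer_inner_dims = 2;

        const std::size_t batch_off = 0;
        const std::size_t x1_off = batch_off + 4 * batch_dims;
        const std::size_t x2_off = x1_off + 2 * x1_outer_inner_dims;
        const std::size_t dst_off = x2_off + 2 * x2_outer_inner_dims;

        std::printf("batch@%zu lhs@%zu rhs@%zu dst@%zu (entries into the packed "
                    "py::ssize_t buffer)\n",
                    batch_off, x1_off, x2_off, dst_off);
        return 0;
    }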
batch_dims, batch_shape_strides, + x1_batch_offset, x2_batch_offset, dst_batch_offset, inner_dims, + x1_outer_dims, x1_outer_inner_shapes_strides, x2_outer_dims, + x2_outer_inner_shapes_strides, x1_outer_dims + x2_outer_dims, + dst_outer_shapes_strides, dst_full_shape_strides, all_deps); + + sycl::event cleanup_tmp_allocations_ev = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(dot_ev); + const auto &ctx = exec_q.get_context(); + cgh.host_task([ctx, packed_shapes_strides] { + sycl::free(packed_shapes_strides, ctx); + }); + }); + host_task_events.push_back(cleanup_tmp_allocations_ev); + host_task_events.push_back(dot_ev); + } + } + return std::make_pair( + dpctl::utils::keep_args_alive(exec_q, {x1, x2, dst}, host_task_events), + dot_ev); +} + +template +py::object py_dot_result_type(const py::dtype &input1_dtype, + const py::dtype &input2_dtype, + const output_typesT &output_types_table) +{ + int tn1 = input1_dtype.num(); // NumPy type numbers are the same as in dpctl + int tn2 = input2_dtype.num(); // NumPy type numbers are the same as in dpctl + int src1_typeid = -1; + int src2_typeid = -1; + + auto array_types = td_ns::usm_ndarray_types(); + + try { + src1_typeid = array_types.typenum_to_lookup_id(tn1); + src2_typeid = array_types.typenum_to_lookup_id(tn2); + } catch (const std::exception &e) { + throw py::value_error(e.what()); + } + + if (src1_typeid < 0 || src1_typeid >= td_ns::num_types || src2_typeid < 0 || + src2_typeid >= td_ns::num_types) + { + throw std::runtime_error("binary output type lookup failed"); + } + int dst_typeid = output_types_table[src1_typeid][src2_typeid]; + + if (dst_typeid < 0) { + auto res = py::none(); + return py::cast(res); + } + else { + using dpctl::tensor::py_internal::type_utils::_dtype_from_typenum; + + auto dst_typenum_t = static_cast(dst_typeid); + auto dt = _dtype_from_typenum(dst_typenum_t); + + return py::cast(dt); + } +} + +void init_dot(py::module_ m) +{ + using dpctl::tensor::py_internal::init_dot_atomic_support_vector; + init_dot_atomic_support_vector(); + using dpctl::tensor::py_internal::init_dot_dispatch_tables; + init_dot_dispatch_tables(); + + using dpctl::tensor::py_internal::py_dot; + m.def("_dot", &py_dot, "", py::arg("x1"), py::arg("x2"), + py::arg("batch_dims"), py::arg("x1_outer_dims"), + py::arg("x2_outer_dims"), py::arg("inner_dims"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + + using dpctl::tensor::py_internal::dot_output_id_table; + auto dot_result_type_pyapi = [&](const py::dtype &dtype1, + const py::dtype &dtype2) { + using dpctl::tensor::py_internal::py_dot_result_type; + return py_dot_result_type(dtype1, dtype2, dot_output_id_table); + }; + m.def("_dot_result_type", dot_result_type_pyapi, ""); +} + +} // namespace py_internal +} // namespace tensor +} // namespace dpctl diff --git a/dpctl/tensor/libtensor/source/linalg_functions/dot.hpp b/dpctl/tensor/libtensor/source/linalg_functions/dot.hpp new file mode 100644 index 0000000000..5f8f6cf494 --- /dev/null +++ b/dpctl/tensor/libtensor/source/linalg_functions/dot.hpp @@ -0,0 +1,17 @@ +#pragma once +#include + +namespace py = pybind11; + +namespace dpctl +{ +namespace tensor +{ +namespace py_internal +{ + +extern void init_dot(py::module_ m); + +} // namespace py_internal +} // namespace tensor +} // namespace dpctl diff --git a/dpctl/tensor/libtensor/source/linalg_functions/dot_atomic_support.hpp b/dpctl/tensor/libtensor/source/linalg_functions/dot_atomic_support.hpp new file mode 100644 index 0000000000..29022342a1 --- /dev/null 
+++ b/dpctl/tensor/libtensor/source/linalg_functions/dot_atomic_support.hpp @@ -0,0 +1,34 @@ +#pragma once + +#include + +#include "reductions/reduction_atomic_support.hpp" +#include "utils/type_utils.hpp" + +namespace dpctl +{ +namespace tensor +{ +namespace py_internal +{ +namespace atomic_support +{ + +template struct DotAtomicSupportFactory +{ + fnT get() + { + using dpctl::tensor::type_utils::is_complex; + if constexpr (is_complex::value) { + return atomic_support::fixed_decision; + } + else { + return atomic_support::check_atomic_support; + } + } +}; + +} // namespace atomic_support +} // namespace py_internal +} // namespace tensor +} // namespace dpctl diff --git a/dpctl/tensor/libtensor/source/linalg_functions/dot_dispatch.hpp b/dpctl/tensor/libtensor/source/linalg_functions/dot_dispatch.hpp new file mode 100644 index 0000000000..de59450174 --- /dev/null +++ b/dpctl/tensor/libtensor/source/linalg_functions/dot_dispatch.hpp @@ -0,0 +1,336 @@ +#pragma once + +#include +#include +#include + +#include "kernels/linalg_functions/dot_product.hpp" +#include "kernels/linalg_functions/gemm.hpp" + +namespace dpctl +{ +namespace tensor +{ +namespace py_internal +{ + +template struct DotAtomicOutputType +{ + using value_type = typename std::disjunction< // disjunction is C++17 + // feature, supported by DPC++ + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::DefaultResultEntry>::result_type; +}; + +// add separate type support lists for atomic vs. temps +// gemm, gevm, and dot product share output type struct +template struct DotNoAtomicOutputType +{ + using value_type = typename std::disjunction< // disjunction is C++17 + // feature, supported by DPC++ + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + T2, + std::complex, + std::complex>, + td_ns::BinaryTypeMapResultEntry, + T2, + std::complex, + std::complex>, + td_ns::DefaultResultEntry>::result_type; +}; + +template struct DotTypeMapFactory +{ + /*! 
@brief get typeid for output type of kernels called by py_dot */ + std::enable_if_t::value, int> get() + { + using rT1 = typename DotNoAtomicOutputType::value_type; + using rT2 = typename DotAtomicOutputType::value_type; + static_assert(std::is_same_v || std::is_same_v); + return td_ns::GetTypeid{}.get(); + } +}; + +template struct GemmBatchAtomicFactory +{ + fnT get() + { + using T3 = typename DotAtomicOutputType::value_type; + if constexpr (std::is_same_v) { + fnT fn = nullptr; + return fn; + } + else { + using dpctl::tensor::kernels::gemm_batch_impl; + fnT fn = gemm_batch_impl; + return fn; + } + } +}; + +template +struct GemmBatchContigAtomicFactory +{ + fnT get() + { + using T3 = typename DotAtomicOutputType::value_type; + if constexpr (std::is_same_v) { + fnT fn = nullptr; + return fn; + } + else { + using dpctl::tensor::kernels::gemm_batch_contig_impl; + fnT fn = gemm_batch_contig_impl; + return fn; + } + } +}; + +template struct GemmAtomicFactory +{ + fnT get() + { + using T3 = typename DotAtomicOutputType::value_type; + if constexpr (std::is_same_v) { + fnT fn = nullptr; + return fn; + } + else { + using dpctl::tensor::kernels::gemm_impl; + fnT fn = gemm_impl; + return fn; + } + } +}; + +template struct GemmContigAtomicFactory +{ + fnT get() + { + using T3 = typename DotAtomicOutputType::value_type; + if constexpr (std::is_same_v) { + fnT fn = nullptr; + return fn; + } + else { + using dpctl::tensor::kernels::gemm_contig_impl; + fnT fn = gemm_contig_impl; + return fn; + } + } +}; + +template struct GemmTempsFactory +{ + fnT get() + { + using T3 = typename DotNoAtomicOutputType::value_type; + if constexpr (std::is_same_v) { + fnT fn = nullptr; + return fn; + } + else { + using dpctl::tensor::kernels::gemm_tree_impl; + fnT fn = gemm_tree_impl; + return fn; + } + } +}; + +template struct GemmContigTempsFactory +{ + fnT get() + { + using T3 = typename DotNoAtomicOutputType::value_type; + if constexpr (std::is_same_v) { + fnT fn = nullptr; + return fn; + } + else { + using dpctl::tensor::kernels::gemm_contig_tree_impl; + fnT fn = gemm_contig_tree_impl; + return fn; + } + } +}; + +template struct GemmBatchTempsFactory +{ + fnT get() + { + using T3 = typename DotNoAtomicOutputType::value_type; + if constexpr (std::is_same_v) { + fnT fn = nullptr; + return fn; + } + else { + using dpctl::tensor::kernels::gemm_batch_tree_impl; + fnT fn = gemm_batch_tree_impl; + return fn; + } + } +}; + +template +struct GemmBatchContigTempsFactory +{ + fnT get() + { + using T3 = typename DotNoAtomicOutputType::value_type; + if constexpr (std::is_same_v) { + fnT fn = nullptr; + return fn; + } + else { + using dpctl::tensor::kernels::gemm_batch_contig_tree_impl; + fnT fn = gemm_batch_contig_tree_impl; + return fn; + } + } +}; + +template struct DotProductAtomicFactory +{ + fnT get() + { + using T3 = typename DotAtomicOutputType::value_type; + if constexpr (std::is_same_v) { + fnT fn = nullptr; + return fn; + } + else { + using dpctl::tensor::kernels::dot_product_impl; + fnT fn = dot_product_impl; + return fn; + } + } +}; + +template +struct DotProductNoAtomicFactory +{ + fnT get() + { + using T3 = typename DotNoAtomicOutputType::value_type; + if constexpr (std::is_same_v) { + fnT fn = nullptr; + return fn; + } + else { + using dpctl::tensor::kernels::dot_product_tree_impl; + fnT fn = dot_product_tree_impl; + return fn; + } + } +}; + +template +struct DotProductContigAtomicFactory +{ + fnT get() + { + using T3 = typename DotAtomicOutputType::value_type; + if constexpr (std::is_same_v) { + fnT fn = nullptr; + 
return fn; + } + else { + using dpctl::tensor::kernels::dot_product_contig_impl; + fnT fn = dot_product_contig_impl; + return fn; + } + } +}; + +template +struct DotProductContigNoAtomicFactory +{ + fnT get() + { + using T3 = typename DotNoAtomicOutputType::value_type; + if constexpr (std::is_same_v) { + fnT fn = nullptr; + return fn; + } + else { + using dpctl::tensor::kernels::dot_product_contig_tree_impl; + fnT fn = dot_product_contig_tree_impl; + return fn; + } + } +}; + +} // namespace py_internal +} // namespace tensor +} // namespace dpctl diff --git a/dpctl/tensor/libtensor/source/tensor_linalg.cpp b/dpctl/tensor/libtensor/source/tensor_linalg.cpp new file mode 100644 index 0000000000..82c9893c08 --- /dev/null +++ b/dpctl/tensor/libtensor/source/tensor_linalg.cpp @@ -0,0 +1,34 @@ +//===-- tensor_linalg.cpp ---*-C++-*-/===// +// Implementation of _tensor_linalg_impl module +// +// Data Parallel Control (dpctl) +// +// Copyright 2020-2023 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_impl extensions +//===----------------------------------------------------------------------===// + +#include "linalg_functions/dot.hpp" +#include + +namespace py = pybind11; + +PYBIND11_MODULE(_tensor_linalg_impl, m) +{ + dpctl::tensor::py_internal::init_dot(m); +} diff --git a/dpctl/tests/test_usm_ndarray_linalg.py b/dpctl/tests/test_usm_ndarray_linalg.py index 4023eb8ad7..0b90e0b8fc 100644 --- a/dpctl/tests/test_usm_ndarray_linalg.py +++ b/dpctl/tests/test_usm_ndarray_linalg.py @@ -14,10 +14,29 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import itertools + +import numpy as np import pytest import dpctl.tensor as dpt -from dpctl.tests.helper import get_queue_or_skip +from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported + +_numeric_types = [ + "i1", + "u1", + "i2", + "u2", + "i4", + "u4", + "i8", + "u8", + "f2", + "f4", + "f8", + "c8", + "c16", +] def test_matrix_transpose(): @@ -46,3 +65,420 @@ def test_matrix_transpose_arg_validation(): X = dpt.empty((5, 5), dtype="i4") assert isinstance(dpt.matrix_transpose(X), dpt.usm_ndarray) + + +@pytest.mark.parametrize("dtype", _numeric_types) +def test_matmul_simple(dtype): + q = get_queue_or_skip() + skip_if_dtype_not_supported(dtype, q) + + n, m = 100, 17 + m1 = dpt.ones((m, n), dtype=dtype) + m2 = dpt.ones((n, m), dtype=dtype) + + for k in [1, 2, 3, 4, 7, 8, 9, 15, 16, 17]: + r = dpt.matmul(m1[:k, :], m2[:, :k]) + assert dpt.all(r == dpt.full((k, k), n, dtype=dtype)) + + +@pytest.mark.parametrize("dtype", _numeric_types) +def test_matmul_nilpotent1(dtype): + q = get_queue_or_skip() + skip_if_dtype_not_supported(dtype, q) + + n = 77 + N_mat = dpt.eye(n, k=1, dtype=dtype) + I_mat = dpt.eye(n, dtype=dtype) + R_mat = dpt.eye(n, dtype=dtype) + for _ in range(n + 1): + R_mat = I_mat + dpt.matmul(N_mat, R_mat) + + assert dpt.allclose(dpt.matmul(I_mat - N_mat, R_mat), I_mat) + + +@pytest.mark.parametrize("dtype", _numeric_types) +def test_matmul_nilpotent2(dtype): + q = get_queue_or_skip() + skip_if_dtype_not_supported(dtype, q) + + n = 128 + u = dpt.ones((n, 1), dtype=dtype) + v = dpt.ones((1, n), dtype=dtype) + + uv = dpt.matmul(u, v) + uv_ref = u * v + + assert dpt.allclose(uv, uv_ref) + + +def test_matmul_null_axis(): + n = 3 + + A_mat = dpt.ones((n, 0), dtype="f4") + B_mat = dpt.ones((0, 1), dtype="f4") + + R_mat = dpt.matmul(A_mat, B_mat) + assert R_mat.shape == (n, 1) + + R_mat = dpt.matmul(A_mat, B_mat[:, :0]) + assert R_mat.shape == (n, 0) + + +@pytest.mark.parametrize("dtype", ["i4", "f4"]) +def test_matmul_dims(dtype): + get_queue_or_skip() + + n, m, k, b = 4, 5, 7, 3 + v = dpt.ones(k, dtype=dtype) + m1 = dpt.ones((n, k), dtype=dtype) + m2 = dpt.ones((k, m), dtype=dtype) + st1 = dpt.ones((b, n, k), dtype=dtype) + st2 = dpt.ones((b, k, m), dtype=dtype) + + r = dpt.matmul(v, v) + assert r.shape == tuple() + assert dpt.round(r) == k + + r = dpt.matmul(m1, v) + assert r.shape == (n,) + assert dpt.all(dpt.round(r) == k) + + r = dpt.matmul(v, m2) + assert r.shape == (m,) + assert dpt.all(dpt.round(r) == k) + + r = dpt.matmul(m1, m2) + assert r.shape == ( + n, + m, + ) + assert dpt.all(dpt.round(r) == k) + + r = dpt.matmul(v, st2) + assert r.shape == ( + b, + m, + ) + assert dpt.all(dpt.round(r) == k) + + r = dpt.matmul(st1, v) + assert r.shape == ( + b, + n, + ) + assert dpt.all(dpt.round(r) == k) + + r = dpt.matmul(st1, m2) + assert r.shape == ( + b, + n, + m, + ) + assert dpt.all(dpt.round(r) == k) + + r = dpt.matmul(m1, st2) + assert r.shape == ( + b, + n, + m, + ) + assert dpt.all(dpt.round(r) == k) + + r = dpt.matmul(st1, st2) + assert r.shape == ( + b, + n, + m, + ) + assert dpt.all(dpt.round(r) == k) + + +def test_matmul_arg_validation(): + get_queue_or_skip() + + s1, s2 = dpt.ones(tuple()), dpt.zeros(tuple()) + v1, v2 = dpt.ones(16), dpt.zeros(16) + + with pytest.raises(ValueError): + dpt.matmul(s1, v2) + + with pytest.raises(ValueError): + dpt.matmul(v1, s2) + + with pytest.raises(TypeError): + dpt.matmul(dict(), v2) + + with pytest.raises(TypeError): + dpt.matmul(v2, None) + + +def test_matmul_dims_validation(): + get_queue_or_skip() + + m1 = 
dpt.ones((16, 16)) + m2 = dpt.ones((16, 16)) + + # contraction dimensions mismatch + with pytest.raises(ValueError): + dpt.matmul(m1[:, :7], m2[:3, :]) + + m1 = dpt.ones((3, 4, 5)) + m2 = dpt.ones((2, 5, 3)) + # broadcasting dimensions mismatch + with pytest.raises(ValueError): + dpt.matmul(m1, m2) + + +def test_matmul_broadcasting(): + get_queue_or_skip() + + m1 = dpt.ones((7, 11, 16)) + m2 = dpt.ones((16, 13)) + + r = dpt.matmul(m1, m2[dpt.newaxis, ...]) + + assert r.shape == (7, 11, 13) + + +@pytest.mark.parametrize("dtype", ["i4", "i8", "f4", "c8"]) +def test_matmul_strided(dtype): + get_queue_or_skip() + + m1_shape = (14, 22, 32) + m1_size = 1 + for el in m1_shape: + m1_size = m1_size * el + + m1 = dpt.remainder(dpt.arange(1, m1_size + 1, dtype="i8"), 13) + m1 = dpt.reshape(dpt.astype(m1, dtype), (14, 22, 32))[::2, ::-2, ::2] + m2 = dpt.ones((14, 16, 13), dtype=dtype)[::2, :, :] + + r = dpt.matmul(m1, m2) + + assert r.shape == (7, 11, 13) + ref = np.matmul(dpt.asnumpy(m1), dpt.asnumpy(m2)) + assert np.allclose(dpt.asnumpy(r), ref) + + +def test_matmul_out(): + get_queue_or_skip() + + m1 = ( + dpt.arange(14, dtype="f4")[:, dpt.newaxis, dpt.newaxis] + + dpt.arange(17, dtype="f4")[dpt.newaxis, :, dpt.newaxis] + + dpt.arange(128, dtype="f4")[dpt.newaxis, dpt.newaxis, :] + ) + assert m1.shape == (14, 17, 128) + m2 = dpt.tile( + dpt.reshape(dpt.asarray([1, 2], dtype="f4"), (2, 1, 1)), (7, 128, 13) + ) + assert m2.shape == (14, 128, 13) + + buf = dpt.zeros((2 * 14, 3 * 17, 13), dtype="f4") + res = dpt.matmul(m1, m2, out=buf[::-2, 1::3, :]) + + assert dpt.allclose(res, buf[::-2, 1::3, :]) + assert dpt.allclose(dpt.zeros_like(res), buf[::-2, 0::3, :]) + assert dpt.allclose(dpt.zeros_like(res), buf[::-2, 2::3, :]) + + ref = np.matmul(dpt.asnumpy(m1), dpt.asnumpy(m2)) + assert np.allclose(ref, dpt.asnumpy(res)) + + +@pytest.mark.parametrize("dtype", _numeric_types) +def test_tensordot_outer(dtype): + q = get_queue_or_skip() + skip_if_dtype_not_supported(dtype, q) + + t1 = dpt.ones((3, 8), dtype=dtype) + t2 = dpt.ones((4, 12), dtype=dtype) + + r = dpt.tensordot(t1, t2, axes=0) + assert r.shape == t1.shape + t2.shape + assert dpt.allclose(r, dpt.ones_like(r)) + + +@pytest.mark.parametrize("dtype", _numeric_types) +def test_tensordot_inner(dtype): + q = get_queue_or_skip() + skip_if_dtype_not_supported(dtype, q) + + t1 = dpt.ones((3, 8), dtype=dtype) + t2 = dpt.ones((4, 8), dtype=dtype) + + r = dpt.tensordot(t1, t2.mT, axes=1) + assert r.shape == t1.shape[:1] + t2.shape[:1] + assert dpt.allclose(r, dpt.full_like(r, fill_value=t1.shape[1])) + + +@pytest.mark.parametrize("dtype", _numeric_types) +def test_tensordot_double(dtype): + q = get_queue_or_skip() + skip_if_dtype_not_supported(dtype, q) + + t1 = dpt.ones((2, 4, 8), dtype=dtype) + t2 = dpt.ones((3, 4, 8), dtype=dtype) + + r = dpt.tensordot(t1, dpt.permute_dims(t2, (1, 2, 0)), axes=2) + assert r.shape == t1.shape[:1] + t2.shape[:1] + expected = dpt.prod(dpt.asarray(t1.shape[1:])) + assert dpt.allclose(r, dpt.full_like(r, fill_value=expected)) + + +@pytest.mark.parametrize("dtype", ["i4", "f4"]) +def test_tensordot_axes_sequence(dtype): + get_queue_or_skip() + + r = 4 + t1 = dpt.ones((2, 2, 4, 3), dtype=dtype) + t2 = dpt.ones((3, 2, 4, 3), dtype=dtype) + + assert len(t1.shape) == r + assert len(t2.shape) == r + + expected = dpt.prod(dpt.asarray(t1.shape[1:])) + ps1 = itertools.permutations(range(r)) + ps2 = itertools.permutations(range(r)) + + for p1 in ps1: + assert len(p1) == r + inv_p1 = sorted(range(r), key=p1.__getitem__) + u1 = 
dpt.permute_dims(t1, p1) + x1_axes = inv_p1[1:] + for p2 in ps2: + inv_p2 = sorted(range(r), key=p2.__getitem__) + u2 = dpt.permute_dims(t2, p2) + x2_axes = inv_p2[1:] + + tdr = dpt.tensordot(u1, u2, axes=(x1_axes, x2_axes)) + assert tdr.shape == t1.shape[:1] + t2.shape[:1] + assert dpt.allclose(tdr, dpt.full_like(tdr, fill_value=expected)) + + +@pytest.mark.parametrize("dtype", _numeric_types) +def test_vecdot_1d(dtype): + q = get_queue_or_skip() + skip_if_dtype_not_supported(dtype, q) + + n = 511 + v1 = dpt.ones(n, dtype=dtype) + + v2 = dpt.ones(n, dtype=dtype) + + r = dpt.vecdot(v1, v2) + + assert r == n + + +@pytest.mark.parametrize("dtype", _numeric_types) +def test_vecdot_3d(dtype): + q = get_queue_or_skip() + skip_if_dtype_not_supported(dtype, q) + + m1, m2, n = 7, 3, 511 + v1 = dpt.ones((m1, m2, n), dtype=dtype) + + v2 = dpt.ones((m1, m2, n), dtype=dtype) + + r = dpt.vecdot(v1, v2) + + assert r.shape == ( + m1, + m2, + ) + assert dpt.all(r == n) + + +@pytest.mark.parametrize("dtype", _numeric_types) +def test_vecdot_axis(dtype): + q = get_queue_or_skip() + skip_if_dtype_not_supported(dtype, q) + + m1, m2, n = 7, 3, 511 + v1 = dpt.ones((m1, n, m2), dtype=dtype) + + v2 = dpt.ones((m1, n, m2), dtype=dtype) + + r = dpt.vecdot(v1, v2, axis=1) + + assert r.shape == ( + m1, + m2, + ) + assert dpt.all(r == n) + + +@pytest.mark.parametrize("dtype", _numeric_types) +def test_vecdot_strided(dtype): + q = get_queue_or_skip() + skip_if_dtype_not_supported(dtype, q) + + m1, m2, n = 7, 3, 511 + list1 = [1, 0, 2, 0] + pattern1 = dpt.asarray(list1, dtype=dtype) + n_padded1 = pattern1.size * (1 + ((n - 1) // pattern1.size)) + v1 = dpt.tile(dpt.reshape(pattern1, (1, -1, 1)), (m1, n_padded1, m2))[ + ::-1, :n, : + ] + + list2 = [1, 2, 1, 2] + pattern2 = dpt.asarray(list2, dtype=dtype) + n_padded2 = pattern2.size * (1 + ((n - 1) // pattern2.size)) + v2 = dpt.tile(dpt.reshape(pattern2, (1, -1, 1)), (m1, n_padded2, m2))[ + :, :n, ::-1 + ] + + r = dpt.vecdot(v1, v2, axis=1) + + ref = sum( + el1 * el2 + for el1, el2 in zip((list1 * n_padded1)[:n], (list2 * n_padded1)[:n]) + ) + + assert r.shape == ( + m1, + m2, + ) + assert dpt.all(r == ref) + + +def test_vector_arg_validation(): + get_queue_or_skip() + + s1, s2 = dpt.ones(tuple()), dpt.zeros(tuple()) + v1, v2 = dpt.ones(16), dpt.zeros(16) + + with pytest.raises(ValueError): + dpt.vecdot(s1, v2) + + with pytest.raises(ValueError): + dpt.vecdot(v1, s2) + + with pytest.raises(TypeError): + dpt.vecdot(dict(), v2) + + with pytest.raises(TypeError): + dpt.vecdot(v2, None) + + with pytest.raises(ValueError): + dpt.vecdot(v1[:5], v2[:4]) + + with pytest.raises(ValueError): + dpt.vecdot(v1, v2, axis=2) + + +@pytest.mark.parametrize("dt1", _numeric_types) +@pytest.mark.parametrize("dt2", _numeric_types) +def test_vecdot_type_promotion(dt1, dt2): + q = get_queue_or_skip() + skip_if_dtype_not_supported(dt1, q) + skip_if_dtype_not_supported(dt2, q) + + v1 = dpt.ones(128, dtype=dt1) + v2 = dpt.ones(128, dtype=dt2) + + r = dpt.vecdot(v1, v2) + mul = v1 * v2 + assert r.shape == tuple() + assert r.dtype == mul.dtype + assert dpt.allclose(r, dpt.sum(mul, dtype=mul.dtype)) From 39cf672709ce3ae9342111fe8a2b64804519785a Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Mon, 8 Jan 2024 10:42:11 -0800 Subject: [PATCH 04/48] Tweaks to `matmul` and gemm kernels Fixes a missing indexer in gemm functor with threading along `nm` dimensions Fixes `matmul` broadcasting, which was broadcasting in some unnecessary cases --- dpctl/tensor/_linear_algebra_functions.py | 20 
+++++++++++-------- .../include/kernels/linalg_functions/gemm.hpp | 3 ++- dpctl/tests/test_usm_ndarray_linalg.py | 2 +- 3 files changed, 15 insertions(+), 10 deletions(-) diff --git a/dpctl/tensor/_linear_algebra_functions.py b/dpctl/tensor/_linear_algebra_functions.py index 2588fc0856..3801a3e3b1 100644 --- a/dpctl/tensor/_linear_algebra_functions.py +++ b/dpctl/tensor/_linear_algebra_functions.py @@ -823,9 +823,9 @@ def matmul(x1, x2, out=None, dtype=None, order="K"): sycl_queue=exec_q, order=order, ) - if x1.shape != res_shape: + if x1.shape != x1_broadcast_shape: x1 = dpt.broadcast_to(x1, x1_broadcast_shape) - if x2.shape != res_shape: + if x2.shape != x2_broadcast_shape: x2 = dpt.broadcast_to(x2, x2_broadcast_shape) ht_dot_ev, binary_ev = tli._dot( x1=x1, @@ -875,9 +875,10 @@ def matmul(x1, x2, out=None, dtype=None, order="K"): order=order, ) - if x1.shape != res_shape: + if x1.shape != x1_broadcast_shape: x1 = dpt.broadcast_to(x1, x1_broadcast_shape) - buf2 = dpt.broadcast_to(buf2, x2_broadcast_shape) + if buf2.shape != x2_broadcast_shape: + buf2 = dpt.broadcast_to(buf2, x2_broadcast_shape) ht_dot_ev, binary_ev = tli._dot( x1=x1, x2=buf2, @@ -929,8 +930,9 @@ def matmul(x1, x2, out=None, dtype=None, order="K"): order=order, ) - buf1 = dpt.broadcast_to(buf1, x1_broadcast_shape) - if x2.shape != res_shape: + if buf1.shape != x1_broadcast_shape: + buf1 = dpt.broadcast_to(buf1, x1_broadcast_shape) + if x2.shape != x2_broadcast_shape: x2 = dpt.broadcast_to(x2, x2_broadcast_shape) ht_dot_ev, binary_ev = tli._dot( x1=buf1, @@ -994,8 +996,10 @@ def matmul(x1, x2, out=None, dtype=None, order="K"): order=order, ) - buf1 = dpt.broadcast_to(buf1, x1_broadcast_shape) - buf2 = dpt.broadcast_to(buf2, x2_broadcast_shape) + if buf1.shape != x1_broadcast_shape: + buf1 = dpt.broadcast_to(buf1, x1_broadcast_shape) + if buf2.shape != x2_broadcast_shape: + buf2 = dpt.broadcast_to(buf2, x2_broadcast_shape) ht_, _ = tli._dot( x1=buf1, x2=buf2, diff --git a/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp b/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp index 68bf9be860..afe2de92c0 100644 --- a/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp +++ b/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp @@ -533,7 +533,8 @@ class GemmFunctorThreadNM size_t g_j = g_j0 + lane_id; vec[lane_id] = (g_j < m && g_s < k) - ? static_cast(rhs[g_s * b_st0 + g_j * b_st1]) + ? 
static_cast( + rhs[rhs_indexer(g_s * b_st0 + g_j * b_st1)]) : resT(0); } diff --git a/dpctl/tests/test_usm_ndarray_linalg.py b/dpctl/tests/test_usm_ndarray_linalg.py index 0b90e0b8fc..d9b0707200 100644 --- a/dpctl/tests/test_usm_ndarray_linalg.py +++ b/dpctl/tests/test_usm_ndarray_linalg.py @@ -72,7 +72,7 @@ def test_matmul_simple(dtype): q = get_queue_or_skip() skip_if_dtype_not_supported(dtype, q) - n, m = 100, 17 + n, m = 235, 17 m1 = dpt.ones((m, n), dtype=dtype) m2 = dpt.ones((n, m), dtype=dtype) From 5b32a53259c2cc5dff263fd11b8ba538c485475c Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Mon, 8 Jan 2024 11:29:35 -0800 Subject: [PATCH 05/48] Remove double-counting of batch offset in gemm batch tree reduction --- .../include/kernels/linalg_functions/gemm.hpp | 166 ++++++++---------- 1 file changed, 69 insertions(+), 97 deletions(-) diff --git a/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp b/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp index afe2de92c0..d2bde25dfb 100644 --- a/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp +++ b/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp @@ -1068,7 +1068,7 @@ sycl::event gemm_impl(sycl::queue &exec_q, constexpr size_t m_groups = 1; size_t delta_k(4); size_t n_wi(64); - size_t delta_n(32); + size_t delta_n(16); gemm_detail::scale_gemm_k_parameters( local_mem_size, reserved_slm_size, @@ -1105,7 +1105,7 @@ sycl::event gemm_impl(sycl::queue &exec_q, constexpr size_t m_groups = 2; size_t delta_k(4); size_t n_wi(64); - size_t delta_n(32); + size_t delta_n(16); gemm_detail::scale_gemm_k_parameters( local_mem_size, reserved_slm_size, @@ -1236,7 +1236,7 @@ sycl::event gemm_contig_impl(sycl::queue &exec_q, constexpr size_t m_groups = 1; size_t delta_k(4); size_t n_wi(64); - size_t delta_n(32); + size_t delta_n(16); gemm_detail::scale_gemm_k_parameters( local_mem_size, reserved_slm_size, @@ -1273,7 +1273,7 @@ sycl::event gemm_contig_impl(sycl::queue &exec_q, constexpr size_t m_groups = 2; size_t delta_k(4); size_t n_wi(64); - size_t delta_n(32); + size_t delta_n(16); gemm_detail::scale_gemm_k_parameters( local_mem_size, reserved_slm_size, @@ -1968,7 +1968,7 @@ sycl::event gemm_tree_impl(sycl::queue &exec_q, // temp memory if only one group is needed size_t delta_k(4); size_t n_wi(64); - size_t delta_n(32); + size_t delta_n(16); using dpctl::tensor::type_utils::is_complex; if constexpr (!is_complex::value) { @@ -3402,7 +3402,7 @@ sycl::event gemm_contig_tree_impl(sycl::queue &exec_q, // temp memory if only one group is needed size_t delta_k(4); size_t n_wi(64); - size_t delta_n(32); + size_t delta_n(16); using dpctl::tensor::type_utils::is_complex; if constexpr (!is_complex::value) { @@ -5472,8 +5472,8 @@ sycl::event gemm_batch_impl(sycl::queue &exec_q, if (m == 1) { constexpr int m_groups = 1; size_t delta_k(4); - size_t n_wi(64); - size_t delta_n(32); + size_t n_wi(32); + size_t delta_n(16); gemm_detail::scale_gemm_k_parameters( local_mem_size, reserved_slm_size, @@ -5514,8 +5514,8 @@ sycl::event gemm_batch_impl(sycl::queue &exec_q, else if (k > n && k > m) { constexpr size_t m_groups = 2; size_t delta_k(4); - size_t n_wi(64); - size_t delta_n(32); + size_t n_wi(32); + size_t delta_n(16); gemm_detail::scale_gemm_k_parameters( local_mem_size, reserved_slm_size, @@ -5637,9 +5637,11 @@ sycl::event gemm_batch_contig_impl(sycl::queue &exec_q, py::ssize_t res_batch_offset, std::vector const &depends = {}) { - const lhsTy *lhs_tp = reinterpret_cast(lhs_cp); - const rhsTy *rhs_tp = 
reinterpret_cast(rhs_cp); - resTy *res_tp = reinterpret_cast(res_cp); + const lhsTy *lhs_tp = + reinterpret_cast(lhs_cp) + lhs_batch_offset; + const rhsTy *rhs_tp = + reinterpret_cast(rhs_cp) + rhs_batch_offset; + resTy *res_tp = reinterpret_cast(res_cp) + res_batch_offset; const sycl::device &dev = exec_q.get_device(); const size_t local_mem_size = @@ -5665,20 +5667,17 @@ sycl::event gemm_batch_contig_impl(sycl::queue &exec_q, Strided1DIndexer>; BatchDimsIndexerT batch_indexer( - Strided1DIndexer{lhs_batch_offset, - static_cast(batch_nelems), + Strided1DIndexer{0, static_cast(batch_nelems), static_cast(n * k)}, - Strided1DIndexer{rhs_batch_offset, - static_cast(batch_nelems), + Strided1DIndexer{0, static_cast(batch_nelems), static_cast(k * m)}, - Strided1DIndexer{res_batch_offset, - static_cast(batch_nelems), + Strided1DIndexer{0, static_cast(batch_nelems), static_cast(n * m)}); if (m == 1) { constexpr int m_groups = 1; size_t delta_k(4); - size_t n_wi(64); - size_t delta_n(32); + size_t n_wi(32); + size_t delta_n(16); gemm_detail::scale_gemm_k_parameters( local_mem_size, reserved_slm_size, @@ -5719,8 +5718,8 @@ sycl::event gemm_batch_contig_impl(sycl::queue &exec_q, else if (k > n && k > m) { constexpr size_t m_groups = 2; size_t delta_k(4); - size_t n_wi(64); - size_t delta_n(32); + size_t n_wi(32); + size_t delta_n(16); gemm_detail::scale_gemm_k_parameters( local_mem_size, reserved_slm_size, @@ -6499,7 +6498,7 @@ gemm_batch_tree_impl(sycl::queue &exec_q, if ((k > n && k > m) || m == 1) { size_t delta_k(4); - size_t n_wi(64); + size_t n_wi(32); size_t delta_n(4); using dpctl::tensor::type_utils::is_complex; @@ -8205,7 +8204,7 @@ gemm_batch_contig_tree_impl(sycl::queue &exec_q, if ((k > n && k > m) || m == 1) { size_t delta_k(4); - size_t n_wi(64); + size_t n_wi(32); size_t delta_n(4); using dpctl::tensor::type_utils::is_complex; @@ -8240,16 +8239,13 @@ gemm_batch_contig_tree_impl(sycl::queue &exec_q, using dpctl::tensor::offset_utils::Strided1DIndexer; BatchDimsIndexerT batch_indexer( Strided1DIndexer{ - lhs_batch_offset, - static_cast(batch_nelems), + 0, static_cast(batch_nelems), static_cast(n * k)}, Strided1DIndexer{ - rhs_batch_offset, - static_cast(batch_nelems), + 0, static_cast(batch_nelems), static_cast(k * m)}, Strided1DIndexer{ - res_batch_offset, - static_cast(batch_nelems), + 0, static_cast(batch_nelems), static_cast(n * m)}); size_t n_blocks = (n + delta_n - 1) / delta_n; @@ -8327,16 +8323,13 @@ gemm_batch_contig_tree_impl(sycl::queue &exec_q, BatchDimsIndexerT batch_indexer( Strided1DIndexer{ - lhs_batch_offset, - static_cast(batch_nelems), + 0, static_cast(batch_nelems), static_cast(n * k)}, Strided1DIndexer{ - rhs_batch_offset, - static_cast(batch_nelems), + 0, static_cast(batch_nelems), static_cast(k * m)}, Strided1DIndexer{ - res_batch_offset, - static_cast(batch_nelems), + 0, static_cast(batch_nelems), static_cast(n * m)}); size_t n_blocks = (n + delta_n - 1) / delta_n; @@ -8471,16 +8464,13 @@ gemm_batch_contig_tree_impl(sycl::queue &exec_q, BatchDimsIndexerT batch_indexer( Strided1DIndexer{ - lhs_batch_offset, - static_cast(batch_nelems), + 0, static_cast(batch_nelems), static_cast(n * k)}, Strided1DIndexer{ - rhs_batch_offset, - static_cast(batch_nelems), + 0, static_cast(batch_nelems), static_cast(k * m)}, Strided1DIndexer{ - res_batch_offset, - static_cast(batch_nelems), + 0, static_cast(batch_nelems), static_cast(n * m)}); size_t n_blocks = (n + delta_n - 1) / delta_n; @@ -8569,16 +8559,13 @@ gemm_batch_contig_tree_impl(sycl::queue &exec_q, BatchDimsIndexerT 
batch_indexer( Strided1DIndexer{ - lhs_batch_offset, - static_cast(batch_nelems), + 0, static_cast(batch_nelems), static_cast(n * k)}, Strided1DIndexer{ - rhs_batch_offset, - static_cast(batch_nelems), + 0, static_cast(batch_nelems), static_cast(k * m)}, Strided1DIndexer{ - res_batch_offset, - static_cast(batch_nelems), + 0, static_cast(batch_nelems), static_cast(n * m)}); size_t n_blocks = (n + delta_n - 1) / delta_n; @@ -8657,16 +8644,13 @@ gemm_batch_contig_tree_impl(sycl::queue &exec_q, BatchDimsIndexerT batch_indexer( Strided1DIndexer{ - lhs_batch_offset, - static_cast(batch_nelems), + 0, static_cast(batch_nelems), static_cast(n * k)}, Strided1DIndexer{ - rhs_batch_offset, - static_cast(batch_nelems), + 0, static_cast(batch_nelems), static_cast(k * m)}, Strided1DIndexer{ - res_batch_offset, - static_cast(batch_nelems), + 0, static_cast(batch_nelems), static_cast(n * m)}); size_t n_blocks = (n + delta_n - 1) / delta_n; @@ -8803,16 +8787,13 @@ gemm_batch_contig_tree_impl(sycl::queue &exec_q, BatchDimsIndexerT batch_indexer( Strided1DIndexer{ - lhs_batch_offset, - static_cast(batch_nelems), + 0, static_cast(batch_nelems), static_cast(n * k)}, Strided1DIndexer{ - rhs_batch_offset, - static_cast(batch_nelems), + 0, static_cast(batch_nelems), static_cast(k * m)}, Strided1DIndexer{ - res_batch_offset, - static_cast(batch_nelems), + 0, static_cast(batch_nelems), static_cast(n * m)}); size_t n_blocks = (n + delta_n - 1) / delta_n; @@ -8901,13 +8882,13 @@ gemm_batch_contig_tree_impl(sycl::queue &exec_q, Strided1DIndexer, Strided1DIndexer, Strided1DIndexer>; BatchDimsIndexerT batch_indexer( - Strided1DIndexer{lhs_batch_offset, + Strided1DIndexer{0, static_cast(batch_nelems), static_cast(n * k)}, - Strided1DIndexer{rhs_batch_offset, + Strided1DIndexer{0, static_cast(batch_nelems), static_cast(k * m)}, - Strided1DIndexer{res_batch_offset, + Strided1DIndexer{0, static_cast(batch_nelems), static_cast(n * m)}); @@ -8985,16 +8966,13 @@ gemm_batch_contig_tree_impl(sycl::queue &exec_q, BatchDimsIndexerT batch_indexer( Strided1DIndexer{ - lhs_batch_offset, - static_cast(batch_nelems), + 0, static_cast(batch_nelems), static_cast(n * k)}, Strided1DIndexer{ - rhs_batch_offset, - static_cast(batch_nelems), + 0, static_cast(batch_nelems), static_cast(k * m)}, Strided1DIndexer{ - res_batch_offset, - static_cast(batch_nelems), + 0, static_cast(batch_nelems), static_cast(n * m)}); size_t n_blocks = (n + delta_n - 1) / delta_n; @@ -9116,13 +9094,13 @@ gemm_batch_contig_tree_impl(sycl::queue &exec_q, Strided1DIndexer, Strided1DIndexer, Strided1DIndexer>; BatchDimsIndexerT batch_indexer( - Strided1DIndexer{lhs_batch_offset, + Strided1DIndexer{0, static_cast(batch_nelems), static_cast(n * k)}, - Strided1DIndexer{rhs_batch_offset, + Strided1DIndexer{0, static_cast(batch_nelems), static_cast(k * m)}, - Strided1DIndexer{res_batch_offset, + Strided1DIndexer{0, static_cast(batch_nelems), static_cast(n * m)}); @@ -9217,13 +9195,13 @@ gemm_batch_contig_tree_impl(sycl::queue &exec_q, Strided1DIndexer, Strided1DIndexer, Strided1DIndexer>; BatchDimsIndexerT batch_indexer( - Strided1DIndexer{lhs_batch_offset, + Strided1DIndexer{0, static_cast(batch_nelems), static_cast(n * k)}, - Strided1DIndexer{rhs_batch_offset, + Strided1DIndexer{0, static_cast(batch_nelems), static_cast(k * m)}, - Strided1DIndexer{res_batch_offset, + Strided1DIndexer{0, static_cast(batch_nelems), static_cast(n * m)}); @@ -9304,16 +9282,13 @@ gemm_batch_contig_tree_impl(sycl::queue &exec_q, BatchDimsIndexerT batch_indexer( Strided1DIndexer{ - 
lhs_batch_offset, - static_cast(batch_nelems), + 0, static_cast(batch_nelems), static_cast(n * k)}, Strided1DIndexer{ - rhs_batch_offset, - static_cast(batch_nelems), + 0, static_cast(batch_nelems), static_cast(k * m)}, Strided1DIndexer{ - res_batch_offset, - static_cast(batch_nelems), + 0, static_cast(batch_nelems), static_cast(n * m)}); size_t lws = wg_delta_n * wg_delta_m; @@ -9444,13 +9419,13 @@ gemm_batch_contig_tree_impl(sycl::queue &exec_q, Strided1DIndexer, Strided1DIndexer, Strided1DIndexer>; BatchDimsIndexerT batch_indexer( - Strided1DIndexer{lhs_batch_offset, + Strided1DIndexer{0, static_cast(batch_nelems), static_cast(n * k)}, - Strided1DIndexer{rhs_batch_offset, + Strided1DIndexer{0, static_cast(batch_nelems), static_cast(k * m)}, - Strided1DIndexer{res_batch_offset, + Strided1DIndexer{0, static_cast(batch_nelems), static_cast(n * m)}); @@ -9540,13 +9515,13 @@ gemm_batch_contig_tree_impl(sycl::queue &exec_q, Strided1DIndexer, Strided1DIndexer, Strided1DIndexer>; BatchDimsIndexerT batch_indexer( - Strided1DIndexer{lhs_batch_offset, + Strided1DIndexer{0, static_cast(batch_nelems), static_cast(n * k)}, - Strided1DIndexer{rhs_batch_offset, + Strided1DIndexer{0, static_cast(batch_nelems), static_cast(k * m)}, - Strided1DIndexer{res_batch_offset, + Strided1DIndexer{0, static_cast(batch_nelems), static_cast(n * m)}); @@ -9626,16 +9601,13 @@ gemm_batch_contig_tree_impl(sycl::queue &exec_q, BatchDimsIndexerT batch_indexer( Strided1DIndexer{ - lhs_batch_offset, - static_cast(batch_nelems), + 0, static_cast(batch_nelems), static_cast(n * k)}, Strided1DIndexer{ - rhs_batch_offset, - static_cast(batch_nelems), + 0, static_cast(batch_nelems), static_cast(k * m)}, Strided1DIndexer{ - res_batch_offset, - static_cast(batch_nelems), + 0, static_cast(batch_nelems), static_cast(n * m)}); size_t lws = wg_delta_n * wg_delta_m; @@ -9764,13 +9736,13 @@ gemm_batch_contig_tree_impl(sycl::queue &exec_q, Strided1DIndexer, Strided1DIndexer, Strided1DIndexer>; BatchDimsIndexerT batch_indexer( - Strided1DIndexer{lhs_batch_offset, + Strided1DIndexer{0, static_cast(batch_nelems), static_cast(n * k)}, - Strided1DIndexer{rhs_batch_offset, + Strided1DIndexer{0, static_cast(batch_nelems), static_cast(k * m)}, - Strided1DIndexer{res_batch_offset, + Strided1DIndexer{0, static_cast(batch_nelems), static_cast(n * m)}); From e25b9a7744f62f876db5cf17ebbf780453b13596 Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Tue, 9 Jan 2024 01:24:34 -0800 Subject: [PATCH 06/48] Fixes missing dependency in vecdot When the first argument would not be cast and the second argument would be, the copy dependency was not appended to the list of dependencies, creating race conditions --- dpctl/tensor/_linear_algebra_functions.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/dpctl/tensor/_linear_algebra_functions.py b/dpctl/tensor/_linear_algebra_functions.py index 3801a3e3b1..4a8c19f667 100644 --- a/dpctl/tensor/_linear_algebra_functions.py +++ b/dpctl/tensor/_linear_algebra_functions.py @@ -475,6 +475,8 @@ def vecdot(x1, x2, axis=-1): ht_copy_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( src=x2, dst=buf2, sycl_queue=exec_q ) + ht_list.append(ht_copy_ev) + deps.append(copy_ev) if x1.shape != broadcast_sh: x1 = dpt.broadcast_to(x1, broadcast_sh) if buf2.shape != broadcast_sh: From 58bb4ab97a665407e4e021239a255ea2ff6ff954 Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Tue, 9 Jan 2024 01:24:52 -0800 Subject: [PATCH 07/48] Run test_matmul_simple2 in Windows before full test suite Part of triaging crashes on Windows 
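Note on the vecdot fix in the previous patch (PATCH 06): the race condition occurs because the `_dot` kernel may be submitted before the asynchronous cast-copy of `x2` into its temporary buffer has completed, unless the copy's event is recorded in the dependency list passed to the kernel. The fragment below is an illustrative sketch only, not code from the patch; it assumes a dpctl build that provides `dpctl.tensor._tensor_impl` and the `_tensor_linalg_impl` module added earlier in this series, and the helper name is hypothetical.

    import dpctl.tensor._tensor_impl as ti
    import dpctl.tensor._tensor_linalg_impl as tli

    def _dot_with_cast_rhs(x1, x2, buf2, dst, exec_q,
                           batch_dims, x1_outer_dims,
                           x2_outer_dims, inner_dims):
        # Hypothetical helper showing the event bookkeeping that
        # dpt.vecdot performs inline when only x2 needs casting.
        host_events = []   # host-task events, kept to prolong argument lifetimes
        deps = []          # device events the _dot kernel must wait on
        # Asynchronous cast of x2 into the temporary buffer buf2.
        ht_copy_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
            src=x2, dst=buf2, sycl_queue=exec_q
        )
        host_events.append(ht_copy_ev)
        # Appending copy_ev is the step the patch restores; without it the
        # kernel below could read buf2 before the copy has finished.
        deps.append(copy_ev)
        ht_dot_ev, _ = tli._dot(
            x1=x1, x2=buf2, batch_dims=batch_dims,
            x1_outer_dims=x1_outer_dims, x2_outer_dims=x2_outer_dims,
            inner_dims=inner_dims, dst=dst, sycl_queue=exec_q,
            depends=deps,
        )
        host_events.append(ht_dot_ev)
        return host_events
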
--- .github/workflows/conda-package.yml | 3 +++ dpctl/tests/test_usm_ndarray_linalg.py | 21 +++++++++++++++++++++ 2 files changed, 24 insertions(+) diff --git a/.github/workflows/conda-package.yml b/.github/workflows/conda-package.yml index f0dff8a13c..fdc301c8cf 100644 --- a/.github/workflows/conda-package.yml +++ b/.github/workflows/conda-package.yml @@ -308,6 +308,9 @@ jobs: shell: cmd /C CALL {0} run: >- conda activate dpctl_test && python -m dpctl -f + - name: Run test_matmul_simple2 + run: | + conda activate dpctl_test && python -m pytest -q --pyargs dpctl.tests.test_usm_ndarray_linalg::test_matmul_simple2 -vv || true - name: Run tests shell: cmd /C CALL {0} env: diff --git a/dpctl/tests/test_usm_ndarray_linalg.py b/dpctl/tests/test_usm_ndarray_linalg.py index d9b0707200..8e3b2a2446 100644 --- a/dpctl/tests/test_usm_ndarray_linalg.py +++ b/dpctl/tests/test_usm_ndarray_linalg.py @@ -19,6 +19,7 @@ import numpy as np import pytest +import dpctl import dpctl.tensor as dpt from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported @@ -81,6 +82,26 @@ def test_matmul_simple(dtype): assert dpt.all(r == dpt.full((k, k), n, dtype=dtype)) +@pytest.mark.parametrize("dtype", _numeric_types) +def test_matmul_simple2(dtype): + q = get_queue_or_skip() + skip_if_dtype_not_supported(dtype, q) + dev = q.sycl_device + if dev.is_cpu: + cpu_count = dev.max_compute_units + sub_devs = dev.create_sub_devices(partition=min(4, cpu_count // 2)) + ctx = dpctl.SyclContext(sub_devs[0]) + q = dpctl.SyclQueue(ctx, sub_devs[0]) + + n, m = 235, 17 + m1 = dpt.ones((m, n), dtype=dtype, sycl_queue=q) + m2 = dpt.ones((n, m), dtype=dtype, sycl_queue=q) + + for k in [1, 2, 3, 4, 7, 8, 9, 15, 16, 17]: + r = dpt.matmul(m1[:k, :], m2[:, :k]) + assert dpt.all(r == dpt.full((k, k), n, dtype=dtype, sycl_queue=q)) + + @pytest.mark.parametrize("dtype", _numeric_types) def test_matmul_nilpotent1(dtype): q = get_queue_or_skip() From 60f1d2108ae631582557d001bbd864b15b8f77f3 Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Tue, 9 Jan 2024 01:25:04 -0800 Subject: [PATCH 08/48] Test removing test_matmul_simple leaving only test_matmul_simple2 --- dpctl/tests/test_usm_ndarray_linalg.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/dpctl/tests/test_usm_ndarray_linalg.py b/dpctl/tests/test_usm_ndarray_linalg.py index 8e3b2a2446..6365ffc6b7 100644 --- a/dpctl/tests/test_usm_ndarray_linalg.py +++ b/dpctl/tests/test_usm_ndarray_linalg.py @@ -68,18 +68,18 @@ def test_matrix_transpose_arg_validation(): assert isinstance(dpt.matrix_transpose(X), dpt.usm_ndarray) -@pytest.mark.parametrize("dtype", _numeric_types) -def test_matmul_simple(dtype): - q = get_queue_or_skip() - skip_if_dtype_not_supported(dtype, q) +# @pytest.mark.parametrize("dtype", _numeric_types) +# def test_matmul_simple(dtype): +# q = get_queue_or_skip() +# skip_if_dtype_not_supported(dtype, q) - n, m = 235, 17 - m1 = dpt.ones((m, n), dtype=dtype) - m2 = dpt.ones((n, m), dtype=dtype) +# n, m = 235, 17 +# m1 = dpt.ones((m, n), dtype=dtype) +# m2 = dpt.ones((n, m), dtype=dtype) - for k in [1, 2, 3, 4, 7, 8, 9, 15, 16, 17]: - r = dpt.matmul(m1[:k, :], m2[:, :k]) - assert dpt.all(r == dpt.full((k, k), n, dtype=dtype)) +# for k in [1, 2, 3, 4, 7, 8, 9, 15, 16, 17]: +# r = dpt.matmul(m1[:k, :], m2[:, :k]) +# assert dpt.all(r == dpt.full((k, k), n, dtype=dtype)) @pytest.mark.parametrize("dtype", _numeric_types) @@ -89,7 +89,7 @@ def test_matmul_simple2(dtype): dev = q.sycl_device if dev.is_cpu: cpu_count = 
dev.max_compute_units - sub_devs = dev.create_sub_devices(partition=min(4, cpu_count // 2)) + sub_devs = dev.create_sub_devices(partition=min(2, cpu_count // 2)) ctx = dpctl.SyclContext(sub_devs[0]) q = dpctl.SyclQueue(ctx, sub_devs[0]) From ad53472783b81de633e9e7aa7bd2dd67ee9d4954 Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Tue, 9 Jan 2024 09:03:44 -0800 Subject: [PATCH 09/48] Fix incorrect comments throughtout gemm kernels Comments incorrectly stated that the third argument to `scale_gemm_k_parameters` is modified by reference --- .../include/kernels/linalg_functions/gemm.hpp | 140 ++++++++---------- 1 file changed, 60 insertions(+), 80 deletions(-) diff --git a/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp b/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp index d2bde25dfb..f4d7adf2f5 100644 --- a/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp +++ b/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp @@ -1071,10 +1071,9 @@ sycl::event gemm_impl(sycl::queue &exec_q, size_t delta_n(16); gemm_detail::scale_gemm_k_parameters( - local_mem_size, reserved_slm_size, - delta_k, // modified by reference - n_wi, // modified by reference - delta_n // modified by reference + local_mem_size, reserved_slm_size, delta_k, + n_wi, // modified by reference + delta_n // modified by reference ); size_t n_blocks = (n + delta_n - 1) / delta_n; @@ -1108,10 +1107,9 @@ sycl::event gemm_impl(sycl::queue &exec_q, size_t delta_n(16); gemm_detail::scale_gemm_k_parameters( - local_mem_size, reserved_slm_size, - delta_k, // modified by reference - n_wi, // modified by reference - delta_n // modified by reference + local_mem_size, reserved_slm_size, delta_k, + n_wi, // modified by reference + delta_n // modified by reference ); size_t n_blocks = (n + delta_n - 1) / delta_n; @@ -1239,10 +1237,9 @@ sycl::event gemm_contig_impl(sycl::queue &exec_q, size_t delta_n(16); gemm_detail::scale_gemm_k_parameters( - local_mem_size, reserved_slm_size, - delta_k, // modified by reference - n_wi, // modified by reference - delta_n // modified by reference + local_mem_size, reserved_slm_size, delta_k, + n_wi, // modified by reference + delta_n // modified by reference ); size_t n_blocks = (n + delta_n - 1) / delta_n; @@ -1276,10 +1273,9 @@ sycl::event gemm_contig_impl(sycl::queue &exec_q, size_t delta_n(16); gemm_detail::scale_gemm_k_parameters( - local_mem_size, reserved_slm_size, - delta_k, // modified by reference - n_wi, // modified by reference - delta_n // modified by reference + local_mem_size, reserved_slm_size, delta_k, + n_wi, // modified by reference + delta_n // modified by reference ); size_t n_blocks = (n + delta_n - 1) / delta_n; @@ -1976,10 +1972,9 @@ sycl::event gemm_tree_impl(sycl::queue &exec_q, constexpr int m_groups = 1; gemm_detail::scale_gemm_k_parameters( - local_mem_size, reserved_slm_size, - delta_k, // modified by reference - n_wi, // modified by reference - delta_n // modified by reference + local_mem_size, reserved_slm_size, delta_k, + n_wi, // modified by reference + delta_n // modified by reference ); sycl::event gemm_ev; @@ -2250,10 +2245,9 @@ sycl::event gemm_tree_impl(sycl::queue &exec_q, else { constexpr int m_groups = 2; gemm_detail::scale_gemm_k_parameters( - local_mem_size, reserved_slm_size, - delta_k, // modified by reference - n_wi, // modified by reference - delta_n // modified by reference + local_mem_size, reserved_slm_size, delta_k, + n_wi, // modified by reference + delta_n // modified by reference ); sycl::event 
gemm_ev; @@ -2529,10 +2523,9 @@ sycl::event gemm_tree_impl(sycl::queue &exec_q, constexpr int m_groups = 1; gemm_detail::scale_gemm_k_parameters( - local_mem_size, reserved_slm_size, - delta_k, // modified by reference - n_wi, // modified by reference - delta_n // modified by reference + local_mem_size, reserved_slm_size, delta_k, + n_wi, // modified by reference + delta_n // modified by reference ); sycl::event gemm_ev; @@ -3410,10 +3403,9 @@ sycl::event gemm_contig_tree_impl(sycl::queue &exec_q, constexpr int m_groups = 1; gemm_detail::scale_gemm_k_parameters( - local_mem_size, reserved_slm_size, - delta_k, // modified by reference - n_wi, // modified by reference - delta_n // modified by reference + local_mem_size, reserved_slm_size, delta_k, + n_wi, // modified by reference + delta_n // modified by reference ); sycl::event gemm_ev; @@ -3663,10 +3655,9 @@ sycl::event gemm_contig_tree_impl(sycl::queue &exec_q, else { constexpr int m_groups = 2; gemm_detail::scale_gemm_k_parameters( - local_mem_size, reserved_slm_size, - delta_k, // modified by reference - n_wi, // modified by reference - delta_n // modified by reference + local_mem_size, reserved_slm_size, delta_k, + n_wi, // modified by reference + delta_n // modified by reference ); sycl::event gemm_ev; @@ -3920,10 +3911,9 @@ sycl::event gemm_contig_tree_impl(sycl::queue &exec_q, constexpr int m_groups = 1; gemm_detail::scale_gemm_k_parameters( - local_mem_size, reserved_slm_size, - delta_k, // modified by reference - n_wi, // modified by reference - delta_n // modified by reference + local_mem_size, reserved_slm_size, delta_k, + n_wi, // modified by reference + delta_n // modified by reference ); sycl::event gemm_ev; @@ -5476,10 +5466,9 @@ sycl::event gemm_batch_impl(sycl::queue &exec_q, size_t delta_n(16); gemm_detail::scale_gemm_k_parameters( - local_mem_size, reserved_slm_size, - delta_k, // modified by reference - n_wi, // modified by reference - delta_n // modified by reference + local_mem_size, reserved_slm_size, delta_k, + n_wi, // modified by reference + delta_n // modified by reference ); size_t n_blocks = (n + delta_n - 1) / delta_n; @@ -5518,10 +5507,9 @@ sycl::event gemm_batch_impl(sycl::queue &exec_q, size_t delta_n(16); gemm_detail::scale_gemm_k_parameters( - local_mem_size, reserved_slm_size, - delta_k, // modified by reference - n_wi, // modified by reference - delta_n // modified by reference + local_mem_size, reserved_slm_size, delta_k, + n_wi, // modified by reference + delta_n // modified by reference ); size_t n_blocks = (n + delta_n - 1) / delta_n; @@ -5680,10 +5668,9 @@ sycl::event gemm_batch_contig_impl(sycl::queue &exec_q, size_t delta_n(16); gemm_detail::scale_gemm_k_parameters( - local_mem_size, reserved_slm_size, - delta_k, // modified by reference - n_wi, // modified by reference - delta_n // modified by reference + local_mem_size, reserved_slm_size, delta_k, + n_wi, // modified by reference + delta_n // modified by reference ); size_t n_blocks = (n + delta_n - 1) / delta_n; @@ -5722,10 +5709,9 @@ sycl::event gemm_batch_contig_impl(sycl::queue &exec_q, size_t delta_n(16); gemm_detail::scale_gemm_k_parameters( - local_mem_size, reserved_slm_size, - delta_k, // modified by reference - n_wi, // modified by reference - delta_n // modified by reference + local_mem_size, reserved_slm_size, delta_k, + n_wi, // modified by reference + delta_n // modified by reference ); size_t n_blocks = (n + delta_n - 1) / delta_n; @@ -6506,10 +6492,9 @@ gemm_batch_tree_impl(sycl::queue &exec_q, if (m == 1) { constexpr int 
m_groups = 1; gemm_detail::scale_gemm_k_parameters( - local_mem_size, reserved_slm_size, - delta_k, // modified by reference - n_wi, // modified by reference - delta_n // modified by reference + local_mem_size, reserved_slm_size, delta_k, + n_wi, // modified by reference + delta_n // modified by reference ); if (k <= (delta_k * n_wi)) { @@ -6836,10 +6821,9 @@ gemm_batch_tree_impl(sycl::queue &exec_q, constexpr int m_groups = 2; gemm_detail::scale_gemm_k_parameters( - local_mem_size, reserved_slm_size, - delta_k, // modified by reference - n_wi, // modified by reference - delta_n // modified by reference + local_mem_size, reserved_slm_size, delta_k, + n_wi, // modified by reference + delta_n // modified by reference ); if (k <= (delta_k * n_wi)) { @@ -7174,10 +7158,9 @@ gemm_batch_tree_impl(sycl::queue &exec_q, constexpr int m_groups = 1; gemm_detail::scale_gemm_k_parameters( - local_mem_size, reserved_slm_size, - delta_k, // modified by reference - n_wi, // modified by reference - delta_n // modified by reference + local_mem_size, reserved_slm_size, delta_k, + n_wi, // modified by reference + delta_n // modified by reference ); // each group processes delta_k * n_wi @@ -8212,10 +8195,9 @@ gemm_batch_contig_tree_impl(sycl::queue &exec_q, if (m == 1) { constexpr int m_groups = 1; gemm_detail::scale_gemm_k_parameters( - local_mem_size, reserved_slm_size, - delta_k, // modified by reference - n_wi, // modified by reference - delta_n // modified by reference + local_mem_size, reserved_slm_size, delta_k, + n_wi, // modified by reference + delta_n // modified by reference ); if (k <= (delta_k * n_wi)) { @@ -8533,10 +8515,9 @@ gemm_batch_contig_tree_impl(sycl::queue &exec_q, constexpr int m_groups = 2; gemm_detail::scale_gemm_k_parameters( - local_mem_size, reserved_slm_size, - delta_k, // modified by reference - n_wi, // modified by reference - delta_n // modified by reference + local_mem_size, reserved_slm_size, delta_k, + n_wi, // modified by reference + delta_n // modified by reference ); if (k <= (delta_k * n_wi)) { @@ -8857,10 +8838,9 @@ gemm_batch_contig_tree_impl(sycl::queue &exec_q, constexpr int m_groups = 1; gemm_detail::scale_gemm_k_parameters( - local_mem_size, reserved_slm_size, - delta_k, // modified by reference - n_wi, // modified by reference - delta_n // modified by reference + local_mem_size, reserved_slm_size, delta_k, + n_wi, // modified by reference + delta_n // modified by reference ); // each group processes delta_k * n_wi From 144ac0fb938f0c45e6c6a84bea85f4c19edc6073 Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Tue, 9 Jan 2024 13:09:30 -0800 Subject: [PATCH 10/48] Drastically reduced parameters used for gemm kernels which thread over k Experimental change to see if this stabilizes CI --- .../include/kernels/linalg_functions/gemm.hpp | 44 +++++++++---------- 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp b/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp index f4d7adf2f5..b1adf500fa 100644 --- a/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp +++ b/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp @@ -1067,8 +1067,8 @@ sycl::event gemm_impl(sycl::queue &exec_q, if (m == 1) { constexpr size_t m_groups = 1; size_t delta_k(4); - size_t n_wi(64); - size_t delta_n(16); + size_t n_wi(4); + size_t delta_n(4); gemm_detail::scale_gemm_k_parameters( local_mem_size, reserved_slm_size, delta_k, @@ -1103,8 +1103,8 @@ sycl::event gemm_impl(sycl::queue 
&exec_q, else if (k > n && k > m) { constexpr size_t m_groups = 2; size_t delta_k(4); - size_t n_wi(64); - size_t delta_n(16); + size_t n_wi(4); + size_t delta_n(4); gemm_detail::scale_gemm_k_parameters( local_mem_size, reserved_slm_size, delta_k, @@ -1233,8 +1233,8 @@ sycl::event gemm_contig_impl(sycl::queue &exec_q, if (m == 1) { constexpr size_t m_groups = 1; size_t delta_k(4); - size_t n_wi(64); - size_t delta_n(16); + size_t n_wi(4); + size_t delta_n(4); gemm_detail::scale_gemm_k_parameters( local_mem_size, reserved_slm_size, delta_k, @@ -1269,8 +1269,8 @@ sycl::event gemm_contig_impl(sycl::queue &exec_q, else if (k > n && k > m) { constexpr size_t m_groups = 2; size_t delta_k(4); - size_t n_wi(64); - size_t delta_n(16); + size_t n_wi(4); + size_t delta_n(4); gemm_detail::scale_gemm_k_parameters( local_mem_size, reserved_slm_size, delta_k, @@ -1963,8 +1963,8 @@ sycl::event gemm_tree_impl(sycl::queue &exec_q, // items in a column, so no need for allocating // temp memory if only one group is needed size_t delta_k(4); - size_t n_wi(64); - size_t delta_n(16); + size_t n_wi(4); + size_t delta_n(4); using dpctl::tensor::type_utils::is_complex; if constexpr (!is_complex::value) { @@ -3394,8 +3394,8 @@ sycl::event gemm_contig_tree_impl(sycl::queue &exec_q, // items in a column, so no need for allocating // temp memory if only one group is needed size_t delta_k(4); - size_t n_wi(64); - size_t delta_n(16); + size_t n_wi(4); + size_t delta_n(4); using dpctl::tensor::type_utils::is_complex; if constexpr (!is_complex::value) { @@ -5462,8 +5462,8 @@ sycl::event gemm_batch_impl(sycl::queue &exec_q, if (m == 1) { constexpr int m_groups = 1; size_t delta_k(4); - size_t n_wi(32); - size_t delta_n(16); + size_t n_wi(4); + size_t delta_n(4); gemm_detail::scale_gemm_k_parameters( local_mem_size, reserved_slm_size, delta_k, @@ -5503,8 +5503,8 @@ sycl::event gemm_batch_impl(sycl::queue &exec_q, else if (k > n && k > m) { constexpr size_t m_groups = 2; size_t delta_k(4); - size_t n_wi(32); - size_t delta_n(16); + size_t n_wi(4); + size_t delta_n(4); gemm_detail::scale_gemm_k_parameters( local_mem_size, reserved_slm_size, delta_k, @@ -5664,8 +5664,8 @@ sycl::event gemm_batch_contig_impl(sycl::queue &exec_q, if (m == 1) { constexpr int m_groups = 1; size_t delta_k(4); - size_t n_wi(32); - size_t delta_n(16); + size_t n_wi(4); + size_t delta_n(4); gemm_detail::scale_gemm_k_parameters( local_mem_size, reserved_slm_size, delta_k, @@ -5705,8 +5705,8 @@ sycl::event gemm_batch_contig_impl(sycl::queue &exec_q, else if (k > n && k > m) { constexpr size_t m_groups = 2; size_t delta_k(4); - size_t n_wi(32); - size_t delta_n(16); + size_t n_wi(4); + size_t delta_n(4); gemm_detail::scale_gemm_k_parameters( local_mem_size, reserved_slm_size, delta_k, @@ -6484,7 +6484,7 @@ gemm_batch_tree_impl(sycl::queue &exec_q, if ((k > n && k > m) || m == 1) { size_t delta_k(4); - size_t n_wi(32); + size_t n_wi(4); size_t delta_n(4); using dpctl::tensor::type_utils::is_complex; @@ -8187,7 +8187,7 @@ gemm_batch_contig_tree_impl(sycl::queue &exec_q, if ((k > n && k > m) || m == 1) { size_t delta_k(4); - size_t n_wi(32); + size_t n_wi(4); size_t delta_n(4); using dpctl::tensor::type_utils::is_complex; From 00cbc351cbb743ec3255dbc5ea45979e9cd7fa59 Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Tue, 9 Jan 2024 16:12:54 -0800 Subject: [PATCH 11/48] Test removal of k-threading gemm kernel which writes to multiple outputs atomically --- .../include/kernels/linalg_functions/gemm.hpp | 154 +++++++++--------- 1 file changed, 80 
insertions(+), 74 deletions(-) diff --git a/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp b/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp index b1adf500fa..cedcf55ae8 100644 --- a/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp +++ b/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp @@ -1064,7 +1064,7 @@ sycl::event gemm_impl(sycl::queue &exec_q, rhs_shape_strides); OuterInnerIndexerT res_indexer(res_outer_nd, 0, res_shape_strides); - if (m == 1) { + if (k > n && k > m || m == 1) { constexpr size_t m_groups = 1; size_t delta_k(4); size_t n_wi(4); @@ -1099,42 +1099,46 @@ sycl::event gemm_impl(sycl::queue &exec_q, lhs_tp, rhs_tp, res_tp, workspace, local_B_block, n, n_blocks, delta_n, k, k_blocks, delta_k, n_wi, m, lhs_indexer, rhs_indexer, res_indexer)); - } - else if (k > n && k > m) { - constexpr size_t m_groups = 2; - size_t delta_k(4); - size_t n_wi(4); - size_t delta_n(4); - - gemm_detail::scale_gemm_k_parameters( - local_mem_size, reserved_slm_size, delta_k, - n_wi, // modified by reference - delta_n // modified by reference - ); - - size_t n_blocks = (n + delta_n - 1) / delta_n; - size_t m_blocks = (m + m_groups - 1) / m_groups; - size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); - - size_t lws = delta_n * delta_k; - - auto gRange = sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); - auto lRange = sycl::range<1>(lws); - - auto ndRange = sycl::nd_range<1>(gRange, lRange); - - using LocAccT = sycl::local_accessor, 1>; - LocAccT local_B_block(n_wi * delta_k, cgh); - LocAccT workspace(delta_n * delta_k, cgh); - - using KernelName = class gemm_k_krn; - cgh.parallel_for( - ndRange, GemmFunctorThreadK( - lhs_tp, rhs_tp, res_tp, workspace, local_B_block, - n, n_blocks, delta_n, k, k_blocks, delta_k, n_wi, - m, lhs_indexer, rhs_indexer, res_indexer)); + // } + // else if (k > n && k > m) { + // constexpr size_t m_groups = 2; + // size_t delta_k(4); + // size_t n_wi(4); + // size_t delta_n(4); + + // gemm_detail::scale_gemm_k_parameters( + // local_mem_size, reserved_slm_size, delta_k, + // n_wi, // modified by reference + // delta_n // modified by reference + // ); + + // size_t n_blocks = (n + delta_n - 1) / delta_n; + // size_t m_blocks = (m + m_groups - 1) / m_groups; + // size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * + // delta_k); + + // size_t lws = delta_n * delta_k; + + // auto gRange = sycl::range<1>(n_blocks * m_blocks * k_blocks * + // lws); auto lRange = sycl::range<1>(lws); + + // auto ndRange = sycl::nd_range<1>(gRange, lRange); + + // using LocAccT = sycl::local_accessor, 1>; LocAccT local_B_block(n_wi * delta_k, cgh); + // LocAccT workspace(delta_n * delta_k, cgh); + + // using KernelName = class gemm_k_krn; + // cgh.parallel_for( + // ndRange, GemmFunctorThreadK( + // lhs_tp, rhs_tp, res_tp, workspace, + // local_B_block, n, n_blocks, delta_n, k, + // k_blocks, delta_k, n_wi, m, lhs_indexer, + // rhs_indexer, res_indexer)); } else { constexpr int wi_delta_n = 2; @@ -1230,7 +1234,7 @@ sycl::event gemm_contig_impl(sycl::queue &exec_q, OuterInnerIndexerT rhs_indexer{}; OuterInnerIndexerT res_indexer{}; - if (m == 1) { + if (k > n && k > m || m == 1) { constexpr size_t m_groups = 1; size_t delta_k(4); size_t n_wi(4); @@ -1266,42 +1270,44 @@ sycl::event gemm_contig_impl(sycl::queue &exec_q, n, n_blocks, delta_n, k, k_blocks, delta_k, n_wi, m, lhs_indexer, rhs_indexer, res_indexer)); } - else if (k > n && k > m) { - constexpr size_t m_groups = 2; - size_t delta_k(4); - size_t n_wi(4); - 
size_t delta_n(4); - - gemm_detail::scale_gemm_k_parameters( - local_mem_size, reserved_slm_size, delta_k, - n_wi, // modified by reference - delta_n // modified by reference - ); - - size_t n_blocks = (n + delta_n - 1) / delta_n; - size_t m_blocks = (m + m_groups - 1) / m_groups; - size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); - - size_t lws = delta_n * delta_k; - - auto gRange = sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); - auto lRange = sycl::range<1>(lws); - - auto ndRange = sycl::nd_range<1>(gRange, lRange); - - using LocAccT = sycl::local_accessor, 1>; - LocAccT local_B_block(n_wi * delta_k, cgh); - LocAccT workspace(delta_n * delta_k, cgh); - - using KernelName = class gemm_k_krn; - cgh.parallel_for( - ndRange, GemmFunctorThreadK( - lhs_tp, rhs_tp, res_tp, workspace, local_B_block, - n, n_blocks, delta_n, k, k_blocks, delta_k, n_wi, - m, lhs_indexer, rhs_indexer, res_indexer)); - } + // else if (k > n && k > m) { + // constexpr size_t m_groups = 2; + // size_t delta_k(4); + // size_t n_wi(4); + // size_t delta_n(4); + + // gemm_detail::scale_gemm_k_parameters( + // local_mem_size, reserved_slm_size, delta_k, + // n_wi, // modified by reference + // delta_n // modified by reference + // ); + + // size_t n_blocks = (n + delta_n - 1) / delta_n; + // size_t m_blocks = (m + m_groups - 1) / m_groups; + // size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); + + // size_t lws = delta_n * delta_k; + + // auto gRange = sycl::range<1>(n_blocks * m_blocks * k_blocks * + // lws); auto lRange = sycl::range<1>(lws); + + // auto ndRange = sycl::nd_range<1>(gRange, lRange); + + // using LocAccT = sycl::local_accessor, + // 1>; LocAccT local_B_block(n_wi * delta_k, cgh); LocAccT + // workspace(delta_n * delta_k, cgh); + + // using KernelName = class gemm_k_krn; + // cgh.parallel_for( + // ndRange, GemmFunctorThreadK( + // lhs_tp, rhs_tp, res_tp, workspace, + // local_B_block, n, n_blocks, delta_n, k, + // k_blocks, delta_k, n_wi, m, lhs_indexer, + // rhs_indexer, res_indexer)); + // } else { constexpr int wi_delta_n = 2; constexpr int wi_delta_m = 4; From 2b847436e2277426d6993389d309f9e2237f1ea3 Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Wed, 10 Jan 2024 15:44:46 -0800 Subject: [PATCH 12/48] Refactors `gemm_tree_impl` Now uses two smaller functions, `gemm_tree_k_impl` and `gemm_tree_nm_impl` for greater readability --- .../include/kernels/linalg_functions/gemm.hpp | 2099 ++++++----------- 1 file changed, 708 insertions(+), 1391 deletions(-) diff --git a/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp b/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp index cedcf55ae8..7b26590580 100644 --- a/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp +++ b/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp @@ -1064,7 +1064,7 @@ sycl::event gemm_impl(sycl::queue &exec_q, rhs_shape_strides); OuterInnerIndexerT res_indexer(res_outer_nd, 0, res_shape_strides); - if (k > n && k > m || m == 1) { + if ((k > n && k > m) || m == 1) { constexpr size_t m_groups = 1; size_t delta_k(4); size_t n_wi(4); @@ -1234,7 +1234,7 @@ sycl::event gemm_contig_impl(sycl::queue &exec_q, OuterInnerIndexerT rhs_indexer{}; OuterInnerIndexerT res_indexer{}; - if (k > n && k > m || m == 1) { + if ((k > n && k > m) || m == 1) { constexpr size_t m_groups = 1; size_t delta_k(4); size_t n_wi(4); @@ -1938,1397 +1938,481 @@ template class gemm_tree_k_krn; -template -sycl::event gemm_tree_impl(sycl::queue &exec_q, - const char *lhs_cp, - 
const char *rhs_cp, - char *res_cp, - size_t n, - size_t k, - size_t m, - int inner_nd, - int lhs_outer_nd, - const py::ssize_t *lhs_outer_inner_shapes_strides, - int rhs_outer_nd, - const py::ssize_t *rhs_outer_inner_shapes_strides, - int res_nd, - const py::ssize_t *res_shapes_strides, - std::vector const &depends = {}) +template +sycl::event gemm_tree_k_impl(sycl::queue &exec_q, + const lhsTy *lhs_tp, + const rhsTy *rhs_tp, + resTy *res_tp, + size_t n, + size_t k, + size_t m, + int inner_nd, + int lhs_outer_nd, + const py::ssize_t *lhs_outer_inner_shapes_strides, + int rhs_outer_nd, + const py::ssize_t *rhs_outer_inner_shapes_strides, + int res_nd, + const py::ssize_t *res_shapes_strides, + std::vector depends) { - const lhsTy *lhs_tp = reinterpret_cast(lhs_cp); - const rhsTy *rhs_tp = reinterpret_cast(rhs_cp); - resTy *res_tp = reinterpret_cast(res_cp); + size_t delta_k(4); + size_t n_wi(4); + size_t delta_n(4); const sycl::device &dev = exec_q.get_device(); const size_t local_mem_size = dev.get_info(); const size_t reserved_slm_size = 512; - if ((k > n && k > m) || m == 1) { - // each group processes delta_k * n_wi - // items in a column, so no need for allocating - // temp memory if only one group is needed - size_t delta_k(4); - size_t n_wi(4); - size_t delta_n(4); + gemm_detail::scale_gemm_k_parameters( + local_mem_size, reserved_slm_size, delta_k, + n_wi, // modified by reference + delta_n // modified by reference + ); + + sycl::event gemm_ev; + if (k <= (delta_k * n_wi)) { + gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::StridedIndexer; + OuterInnerDimsIndexerT lhs_indexer(inner_nd + lhs_outer_nd, 0, + lhs_outer_inner_shapes_strides); + OuterInnerDimsIndexerT rhs_indexer(inner_nd + rhs_outer_nd, 0, + rhs_outer_inner_shapes_strides); + OuterInnerDimsIndexerT res_indexer(res_nd, 0, res_shapes_strides); - using dpctl::tensor::type_utils::is_complex; - if constexpr (!is_complex::value) { - if (m == 1) { - constexpr int m_groups = 1; + size_t n_blocks = (n + delta_n - 1) / delta_n; + size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); + size_t m_blocks = (m + m_groups - 1) / m_groups; - gemm_detail::scale_gemm_k_parameters( - local_mem_size, reserved_slm_size, delta_k, - n_wi, // modified by reference - delta_n // modified by reference - ); + size_t lws = delta_n * delta_k; - sycl::event gemm_ev; - if (k <= (delta_k * n_wi)) { - gemm_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(depends); + auto gRange = sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); + auto lRange = sycl::range<1>(lws); - using OuterInnerDimsIndexerT = - dpctl::tensor::offset_utils::StridedIndexer; - OuterInnerDimsIndexerT lhs_indexer( - inner_nd + lhs_outer_nd, 0, - lhs_outer_inner_shapes_strides); - OuterInnerDimsIndexerT rhs_indexer( - inner_nd + rhs_outer_nd, 0, - rhs_outer_inner_shapes_strides); - OuterInnerDimsIndexerT res_indexer(res_nd, 0, - res_shapes_strides); + auto ndRange = sycl::nd_range<1>(gRange, lRange); - size_t n_blocks = (n + delta_n - 1) / delta_n; - size_t k_blocks = - (k + n_wi * delta_k - 1) / (n_wi * delta_k); - size_t m_blocks = (m + m_groups - 1) / m_groups; + if constexpr (m_groups == 1) { + using LocAccT = sycl::local_accessor; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + using KernelName = + class gemm_tree_k_krn; + cgh.parallel_for( + ndRange, + GemmNoAtomicFunctorThreadK< + lhsTy, rhsTy, resTy, LocAccT, 
OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, m_groups>( + lhs_tp, rhs_tp, res_tp, workspace, local_B_block, n, + n_blocks, delta_n, k, k_blocks, delta_k, n_wi, m, + lhs_indexer, rhs_indexer, res_indexer)); + } + else { + using LocAccT = + sycl::local_accessor, 1>; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + using KernelName = + class gemm_tree_k_krn; + cgh.parallel_for( + ndRange, + GemmNoAtomicFunctorThreadK< + lhsTy, rhsTy, resTy, LocAccT, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, m_groups>( + lhs_tp, rhs_tp, res_tp, workspace, local_B_block, n, + n_blocks, delta_n, k, k_blocks, delta_k, n_wi, m, + lhs_indexer, rhs_indexer, res_indexer)); + } + }); + return gemm_ev; + } + else { + using ReductionOpT = sycl::plus; + constexpr resTy identity_val = + sycl::known_identity::value; + + size_t iter_nelems = n * m; + size_t reduction_nelems = (k + delta_k * n_wi - 1) / (delta_k * n_wi); + + // more than one work-groups is needed, requires a temporary + // delta_k * n_wi elements processed along k, so if more to + // process use multiple + const auto &sg_sizes = + dev.get_info(); + size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes); + + if (reduction_nelems < wg) { + resTy *tmp = sycl::malloc_device( + iter_nelems * reduction_nelems, exec_q); + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::StridedIndexer; + OuterInnerDimsIndexerT lhs_indexer( + inner_nd + lhs_outer_nd, 0, lhs_outer_inner_shapes_strides); + OuterInnerDimsIndexerT rhs_indexer( + inner_nd + rhs_outer_nd, 0, rhs_outer_inner_shapes_strides); + using ResIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + ResIndexerT res_indexer{}; - size_t lws = delta_n * delta_k; + size_t n_blocks = (n + delta_n - 1) / delta_n; + size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); + size_t m_blocks = (m + m_groups - 1) / m_groups; - auto gRange = sycl::range<1>(n_blocks * m_blocks * - k_blocks * lws); - auto lRange = sycl::range<1>(lws); + size_t lws = delta_n * delta_k; - auto ndRange = sycl::nd_range<1>(gRange, lRange); + auto gRange = + sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); + auto lRange = sycl::range<1>(lws); - using LocAccT = sycl::local_accessor; - LocAccT local_B_block(n_wi * delta_k, cgh); - LocAccT workspace(delta_n * delta_k, cgh); - using KernelName = class gemm_tree_k_krn< - lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, - OuterInnerDimsIndexerT, m_groups>; - cgh.parallel_for( - ndRange, - GemmNoAtomicFunctorThreadK< - lhsTy, rhsTy, resTy, LocAccT, - OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, - m_groups>(lhs_tp, rhs_tp, res_tp, workspace, - local_B_block, n, n_blocks, delta_n, - k, k_blocks, delta_k, n_wi, m, - lhs_indexer, rhs_indexer, - res_indexer)); - }); - return gemm_ev; + auto ndRange = sycl::nd_range<1>(gRange, lRange); + + if constexpr (m_groups == 1) { + using LocAccT = sycl::local_accessor; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + using KernelName = + class gemm_tree_k_krn; + cgh.parallel_for( + ndRange, + GemmNoAtomicFunctorThreadK( + lhs_tp, rhs_tp, tmp, workspace, local_B_block, n, + n_blocks, delta_n, k, k_blocks, delta_k, n_wi, m, + lhs_indexer, rhs_indexer, res_indexer)); } else { - using ReductionOpT = sycl::plus; - constexpr resTy identity_val = - sycl::known_identity::value; - - size_t iter_nelems = n * m; - size_t reduction_nelems = - (k + delta_k * n_wi 
- 1) / (delta_k * n_wi); - - // more than one work-groups is needed, requires a temporary - // delta_k * n_wi elements processed along k, so if more to - // process use multiple - const auto &sg_sizes = - dev.get_info(); - size_t wg = - choose_workgroup_size<4>(reduction_nelems, sg_sizes); - - if (reduction_nelems < wg) { - resTy *tmp = sycl::malloc_device( - iter_nelems * reduction_nelems, exec_q); - sycl::event gemm_ev = - exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(depends); - - using OuterInnerDimsIndexerT = - dpctl::tensor::offset_utils::StridedIndexer; - OuterInnerDimsIndexerT lhs_indexer( - inner_nd + lhs_outer_nd, 0, - lhs_outer_inner_shapes_strides); - OuterInnerDimsIndexerT rhs_indexer( - inner_nd + rhs_outer_nd, 0, - rhs_outer_inner_shapes_strides); - using ResIndexerT = - dpctl::tensor::offset_utils::NoOpIndexer; - ResIndexerT res_indexer{}; - - size_t n_blocks = (n + delta_n - 1) / delta_n; - size_t k_blocks = - (k + n_wi * delta_k - 1) / (n_wi * delta_k); - size_t m_blocks = (m + m_groups - 1) / m_groups; - - size_t lws = delta_n * delta_k; - - auto gRange = sycl::range<1>( - n_blocks * m_blocks * k_blocks * lws); - auto lRange = sycl::range<1>(lws); - - auto ndRange = - sycl::nd_range<1>(gRange, lRange); - - using LocAccT = sycl::local_accessor; - LocAccT local_B_block(n_wi * delta_k, cgh); - LocAccT workspace(delta_n * delta_k, cgh); - using KernelName = class gemm_tree_k_krn< - lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, - ResIndexerT, m_groups>; - cgh.parallel_for( - ndRange, - GemmNoAtomicFunctorThreadK< - lhsTy, rhsTy, resTy, LocAccT, - OuterInnerDimsIndexerT, ResIndexerT, - m_groups>( - lhs_tp, rhs_tp, tmp, workspace, - local_B_block, n, n_blocks, delta_n, k, - k_blocks, delta_k, n_wi, m, lhs_indexer, - rhs_indexer, res_indexer)); - }); - sycl::event red_ev = exec_q.submit([&](sycl::handler - &cgh) { - cgh.depends_on(gemm_ev); + using LocAccT = + sycl::local_accessor, 1>; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + using KernelName = + class gemm_tree_k_krn; + cgh.parallel_for( + ndRange, + GemmNoAtomicFunctorThreadK( + lhs_tp, rhs_tp, tmp, workspace, local_B_block, n, + n_blocks, delta_n, k, k_blocks, delta_k, n_wi, m, + lhs_indexer, rhs_indexer, res_indexer)); + } + }); + sycl::event red_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(gemm_ev); - using NoOpIndexerT = - dpctl::tensor::offset_utils::NoOpIndexer; - using ResIndexerT = - dpctl::tensor::offset_utils::StridedIndexer; - using InputOutputIterIndexerT = - dpctl::tensor::offset_utils:: - TwoOffsets_CombinedIndexer; - using ReductionIndexerT = - dpctl::tensor::offset_utils::Strided1DIndexer; + using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using ResIndexerT = dpctl::tensor::offset_utils::StridedIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + NoOpIndexerT, ResIndexerT>; + using ReductionIndexerT = + dpctl::tensor::offset_utils::Strided1DIndexer; - ResIndexerT res_iter_indexer{res_nd, 0, - res_shapes_strides}; - InputOutputIterIndexerT in_out_iter_indexer{ - NoOpIndexerT{}, res_iter_indexer}; - ReductionIndexerT reduction_indexer{ - 0, static_cast(reduction_nelems), - static_cast(iter_nelems)}; + ResIndexerT res_iter_indexer{res_nd, 0, res_shapes_strides}; + InputOutputIterIndexerT in_out_iter_indexer{NoOpIndexerT{}, + res_iter_indexer}; + ReductionIndexerT reduction_indexer{ + 0, static_cast(reduction_nelems), + static_cast(iter_nelems)}; - sycl::range<1> 
iter_range{iter_nelems}; + sycl::range<1> iter_range{iter_nelems}; - cgh.parallel_for< - class gemm_reduction_seq_strided_krn< - resTy, resTy, ReductionOpT, - InputOutputIterIndexerT, - ReductionIndexerT>>( - iter_range, - SequentialReduction>( + iter_range, SequentialReduction( tmp, res_tp, ReductionOpT(), identity_val, in_out_iter_indexer, reduction_indexer, reduction_nelems)); - }); - sycl::event cleanup_host_task_event = - exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(red_ev); - const sycl::context &ctx = exec_q.get_context(); - - cgh.host_task( - [ctx, tmp] { sycl::free(tmp, ctx); }); - }); - return cleanup_host_task_event; - } - // max_max_wg prevents running out of resources on CPU - constexpr size_t max_max_wg = 2048; - size_t max_wg = std::min( - max_max_wg, - dev.get_info< - sycl::info::device::max_work_group_size>() / - 2); + }); + sycl::event cleanup_host_task_event = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(red_ev); + const sycl::context &ctx = exec_q.get_context(); - constexpr size_t preferred_reductions_per_wi = 8; - size_t reductions_per_wi(preferred_reductions_per_wi); + cgh.host_task([ctx, tmp] { sycl::free(tmp, ctx); }); + }); + return cleanup_host_task_event; + } + // max_max_wg prevents running out of resources on CPU + constexpr size_t max_max_wg = 2048; + size_t max_wg = std::min( + max_max_wg, + dev.get_info() / 2); - size_t reduction_groups = - (reduction_nelems + preferred_reductions_per_wi * wg - - 1) / - (preferred_reductions_per_wi * wg); - assert(reduction_groups > 1); + constexpr size_t preferred_reductions_per_wi = 8; + size_t reductions_per_wi(preferred_reductions_per_wi); - resTy *partially_reduced_tmp = sycl::malloc_device( - iter_nelems * - (/* temp */ reduction_nelems + - /* first reduction temp */ reduction_groups), - exec_q); - resTy *partially_reduced_tmp2 = nullptr; + size_t reduction_groups = + (reduction_nelems + preferred_reductions_per_wi * wg - 1) / + (preferred_reductions_per_wi * wg); + assert(reduction_groups > 1); + + resTy *partially_reduced_tmp = sycl::malloc_device( + iter_nelems * (/* temp */ reduction_nelems + + /* first reduction temp */ reduction_groups), + exec_q); + resTy *partially_reduced_tmp2 = nullptr; + + if (partially_reduced_tmp == nullptr) { + throw std::runtime_error("Unable to allocate device_memory"); + } + else { + partially_reduced_tmp2 = + partially_reduced_tmp + reduction_nelems * iter_nelems; + } - if (partially_reduced_tmp == nullptr) { - throw std::runtime_error( - "Unable to allocate device_memory"); - } - else { - partially_reduced_tmp2 = partially_reduced_tmp + - reduction_nelems * iter_nelems; - } + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); - sycl::event gemm_ev = exec_q.submit([&](sycl::handler - &cgh) { - cgh.depends_on(depends); + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::StridedIndexer; + OuterInnerDimsIndexerT lhs_indexer(inner_nd + lhs_outer_nd, 0, + lhs_outer_inner_shapes_strides); + OuterInnerDimsIndexerT rhs_indexer(inner_nd + rhs_outer_nd, 0, + rhs_outer_inner_shapes_strides); + using ResIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + ResIndexerT res_indexer{}; - using OuterInnerDimsIndexerT = - dpctl::tensor::offset_utils::StridedIndexer; - OuterInnerDimsIndexerT lhs_indexer( - inner_nd + lhs_outer_nd, 0, - lhs_outer_inner_shapes_strides); - OuterInnerDimsIndexerT rhs_indexer( - inner_nd + rhs_outer_nd, 0, - rhs_outer_inner_shapes_strides); - using ResIndexerT = - 
dpctl::tensor::offset_utils::NoOpIndexer; - ResIndexerT res_indexer{}; - - size_t n_blocks = (n + delta_n - 1) / delta_n; - size_t k_blocks = - (k + n_wi * delta_k - 1) / (n_wi * delta_k); - size_t m_blocks = (m + m_groups - 1) / m_groups; - - size_t lws = delta_n * delta_k; - - auto gRange = sycl::range<1>(n_blocks * m_blocks * - k_blocks * lws); - auto lRange = sycl::range<1>(lws); - - auto ndRange = sycl::nd_range<1>(gRange, lRange); - - using LocAccT = sycl::local_accessor; - LocAccT local_B_block(n_wi * delta_k, cgh); - LocAccT workspace(delta_n * delta_k, cgh); - using KernelName = - class gemm_tree_k_krn; - cgh.parallel_for( - ndRange, - GemmNoAtomicFunctorThreadK< - lhsTy, rhsTy, resTy, LocAccT, - OuterInnerDimsIndexerT, ResIndexerT, m_groups>( - lhs_tp, rhs_tp, partially_reduced_tmp, - workspace, local_B_block, n, n_blocks, delta_n, - k, k_blocks, delta_k, n_wi, m, lhs_indexer, - rhs_indexer, res_indexer)); - }); - // tree_reduction_for_gemm returns sycl::event for reduction - sycl::event red_ev = - tree_reduction_for_gemm( - exec_q, partially_reduced_tmp, - partially_reduced_tmp2, res_tp, identity_val, - iter_nelems, reduction_nelems, reduction_groups, wg, - max_wg, preferred_reductions_per_wi, - reductions_per_wi, res_nd, 0, res_shapes_strides, - {gemm_ev}); - - sycl::event cleanup_host_task_event = - exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(red_ev); - const sycl::context &ctx = exec_q.get_context(); - - cgh.host_task([ctx, partially_reduced_tmp] { - sycl::free(partially_reduced_tmp, ctx); - }); - }); - - return cleanup_host_task_event; - } - } - else { - constexpr int m_groups = 2; - gemm_detail::scale_gemm_k_parameters( - local_mem_size, reserved_slm_size, delta_k, - n_wi, // modified by reference - delta_n // modified by reference - ); - - sycl::event gemm_ev; - if (k <= (delta_k * n_wi)) { - gemm_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(depends); - - using OuterInnerDimsIndexerT = - dpctl::tensor::offset_utils::StridedIndexer; - OuterInnerDimsIndexerT lhs_indexer( - inner_nd + lhs_outer_nd, 0, - lhs_outer_inner_shapes_strides); - OuterInnerDimsIndexerT rhs_indexer( - inner_nd + rhs_outer_nd, 0, - rhs_outer_inner_shapes_strides); - OuterInnerDimsIndexerT res_indexer(res_nd, 0, - res_shapes_strides); - - size_t n_blocks = (n + delta_n - 1) / delta_n; - size_t k_blocks = - (k + n_wi * delta_k - 1) / (n_wi * delta_k); - size_t m_blocks = (m + m_groups - 1) / m_groups; - - size_t lws = delta_n * delta_k; - - auto gRange = sycl::range<1>(n_blocks * m_blocks * - k_blocks * lws); - auto lRange = sycl::range<1>(lws); - - auto ndRange = sycl::nd_range<1>(gRange, lRange); - - using LocAccT = - sycl::local_accessor, 1>; - LocAccT local_B_block(n_wi * delta_k, cgh); - LocAccT workspace(delta_n * delta_k, cgh); - using KernelName = class gemm_tree_k_krn< - lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, - OuterInnerDimsIndexerT, m_groups>; - cgh.parallel_for( - ndRange, - GemmNoAtomicFunctorThreadK< - lhsTy, rhsTy, resTy, LocAccT, - OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, - m_groups>(lhs_tp, rhs_tp, res_tp, workspace, - local_B_block, n, n_blocks, delta_n, - k, k_blocks, delta_k, n_wi, m, - lhs_indexer, rhs_indexer, - res_indexer)); - }); - return gemm_ev; - } - else { - using ReductionOpT = sycl::plus; - constexpr resTy identity_val = - sycl::known_identity::value; - - size_t iter_nelems = n * m; - size_t reduction_nelems = - (k + delta_k * n_wi - 1) / (delta_k * n_wi); - - // more than one work-groups is needed, requires a temporary - // delta_k 
* n_wi elements processed along k, so if more to - // process use multiple - const auto &sg_sizes = - dev.get_info(); - size_t wg = - choose_workgroup_size<4>(reduction_nelems, sg_sizes); - - if (reduction_nelems < wg) { - resTy *tmp = sycl::malloc_device( - iter_nelems * reduction_nelems, exec_q); - sycl::event gemm_ev = - exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(depends); - - using OuterInnerDimsIndexerT = - dpctl::tensor::offset_utils::StridedIndexer; - OuterInnerDimsIndexerT lhs_indexer( - inner_nd + lhs_outer_nd, 0, - lhs_outer_inner_shapes_strides); - OuterInnerDimsIndexerT rhs_indexer( - inner_nd + rhs_outer_nd, 0, - rhs_outer_inner_shapes_strides); - using ResIndexerT = - dpctl::tensor::offset_utils::NoOpIndexer; - ResIndexerT res_indexer{}; - - size_t n_blocks = (n + delta_n - 1) / delta_n; - size_t k_blocks = - (k + n_wi * delta_k - 1) / (n_wi * delta_k); - size_t m_blocks = (m + m_groups - 1) / m_groups; - - size_t lws = delta_n * delta_k; - - auto gRange = sycl::range<1>( - n_blocks * m_blocks * k_blocks * lws); - auto lRange = sycl::range<1>(lws); - - auto ndRange = - sycl::nd_range<1>(gRange, lRange); - - using LocAccT = sycl::local_accessor< - sycl::vec, 1>; - LocAccT local_B_block(n_wi * delta_k, cgh); - LocAccT workspace(delta_n * delta_k, cgh); - using KernelName = class gemm_tree_k_krn< - lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, - ResIndexerT, m_groups>; - cgh.parallel_for( - ndRange, - GemmNoAtomicFunctorThreadK< - lhsTy, rhsTy, resTy, LocAccT, - OuterInnerDimsIndexerT, ResIndexerT, - m_groups>( - lhs_tp, rhs_tp, tmp, workspace, - local_B_block, n, n_blocks, delta_n, k, - k_blocks, delta_k, n_wi, m, lhs_indexer, - rhs_indexer, res_indexer)); - }); - sycl::event red_ev = exec_q.submit([&](sycl::handler - &cgh) { - cgh.depends_on(gemm_ev); - - using NoOpIndexerT = - dpctl::tensor::offset_utils::NoOpIndexer; - using ResIndexerT = - dpctl::tensor::offset_utils::StridedIndexer; - using InputOutputIterIndexerT = - dpctl::tensor::offset_utils:: - TwoOffsets_CombinedIndexer; - using ReductionIndexerT = - dpctl::tensor::offset_utils::Strided1DIndexer; - - ResIndexerT res_iter_indexer{res_nd, 0, - res_shapes_strides}; - InputOutputIterIndexerT in_out_iter_indexer{ - NoOpIndexerT{}, res_iter_indexer}; - ReductionIndexerT reduction_indexer{ - 0, static_cast(reduction_nelems), - static_cast(iter_nelems)}; - - sycl::range<1> iter_range{iter_nelems}; - - cgh.parallel_for< - class gemm_reduction_seq_strided_krn< - resTy, resTy, ReductionOpT, - InputOutputIterIndexerT, - ReductionIndexerT>>( - iter_range, - SequentialReduction( - tmp, res_tp, ReductionOpT(), identity_val, - in_out_iter_indexer, reduction_indexer, - reduction_nelems)); - }); - sycl::event cleanup_host_task_event = - exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(red_ev); - const sycl::context &ctx = exec_q.get_context(); - - cgh.host_task( - [ctx, tmp] { sycl::free(tmp, ctx); }); - }); - return cleanup_host_task_event; - } - // max_max_wg prevents running out of resources on CPU - constexpr size_t max_max_wg = 2048; - size_t max_wg = std::min( - max_max_wg, - dev.get_info< - sycl::info::device::max_work_group_size>() / - 2); - - constexpr size_t preferred_reductions_per_wi = 8; - size_t reductions_per_wi(preferred_reductions_per_wi); - - size_t reduction_groups = - (reduction_nelems + preferred_reductions_per_wi * wg - - 1) / - (preferred_reductions_per_wi * wg); - assert(reduction_groups > 1); - - resTy *partially_reduced_tmp = sycl::malloc_device( - iter_nelems * - (/* temp */ 
reduction_nelems + - /* first reduction temp */ reduction_groups), - exec_q); - resTy *partially_reduced_tmp2 = nullptr; - - if (partially_reduced_tmp == nullptr) { - throw std::runtime_error( - "Unable to allocate device_memory"); - } - else { - partially_reduced_tmp2 = partially_reduced_tmp + - reduction_nelems * iter_nelems; - } - - sycl::event gemm_ev = exec_q.submit([&](sycl::handler - &cgh) { - cgh.depends_on(depends); - - using OuterInnerDimsIndexerT = - dpctl::tensor::offset_utils::StridedIndexer; - OuterInnerDimsIndexerT lhs_indexer( - inner_nd + lhs_outer_nd, 0, - lhs_outer_inner_shapes_strides); - OuterInnerDimsIndexerT rhs_indexer( - inner_nd + rhs_outer_nd, 0, - rhs_outer_inner_shapes_strides); - using ResIndexerT = - dpctl::tensor::offset_utils::NoOpIndexer; - ResIndexerT res_indexer{}; - - size_t n_blocks = (n + delta_n - 1) / delta_n; - size_t k_blocks = - (k + n_wi * delta_k - 1) / (n_wi * delta_k); - size_t m_blocks = (m + m_groups - 1) / m_groups; - - size_t lws = delta_n * delta_k; - - auto gRange = sycl::range<1>(n_blocks * m_blocks * - k_blocks * lws); - auto lRange = sycl::range<1>(lws); - - auto ndRange = sycl::nd_range<1>(gRange, lRange); - - using LocAccT = - sycl::local_accessor, 1>; - LocAccT local_B_block(n_wi * delta_k, cgh); - LocAccT workspace(delta_n * delta_k, cgh); - using KernelName = - class gemm_tree_k_krn; - cgh.parallel_for( - ndRange, - GemmNoAtomicFunctorThreadK< - lhsTy, rhsTy, resTy, LocAccT, - OuterInnerDimsIndexerT, ResIndexerT, m_groups>( - lhs_tp, rhs_tp, partially_reduced_tmp, - workspace, local_B_block, n, n_blocks, delta_n, - k, k_blocks, delta_k, n_wi, m, lhs_indexer, - rhs_indexer, res_indexer)); - }); - - sycl::event red_ev = - tree_reduction_for_gemm( - exec_q, partially_reduced_tmp, - partially_reduced_tmp2, res_tp, identity_val, - iter_nelems, reduction_nelems, reduction_groups, wg, - max_wg, preferred_reductions_per_wi, - reductions_per_wi, res_nd, 0, res_shapes_strides, - {gemm_ev}); - - sycl::event cleanup_host_task_event = - exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(red_ev); - const sycl::context &ctx = exec_q.get_context(); - - cgh.host_task([ctx, partially_reduced_tmp] { - sycl::free(partially_reduced_tmp, ctx); - }); - }); - - return cleanup_host_task_event; - } - } - } - else { - constexpr int m_groups = 1; - - gemm_detail::scale_gemm_k_parameters( - local_mem_size, reserved_slm_size, delta_k, - n_wi, // modified by reference - delta_n // modified by reference - ); - - sycl::event gemm_ev; - if (k <= (delta_k * n_wi)) { - gemm_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(depends); - - using OuterInnerDimsIndexerT = - dpctl::tensor::offset_utils::StridedIndexer; - OuterInnerDimsIndexerT lhs_indexer( - inner_nd + lhs_outer_nd, 0, - lhs_outer_inner_shapes_strides); - OuterInnerDimsIndexerT rhs_indexer( - inner_nd + rhs_outer_nd, 0, - rhs_outer_inner_shapes_strides); - OuterInnerDimsIndexerT res_indexer(res_nd, 0, - res_shapes_strides); - - size_t n_blocks = (n + delta_n - 1) / delta_n; - size_t k_blocks = - (k + n_wi * delta_k - 1) / (n_wi * delta_k); - size_t m_blocks = (m + m_groups - 1) / m_groups; - - size_t lws = delta_n * delta_k; - - auto gRange = - sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); - auto lRange = sycl::range<1>(lws); - - auto ndRange = sycl::nd_range<1>(gRange, lRange); - - using LocAccT = sycl::local_accessor; - LocAccT local_B_block(n_wi * delta_k, cgh); - LocAccT workspace(delta_n * delta_k, cgh); - using KernelName = - class gemm_tree_k_krn; - cgh.parallel_for( - 
ndRange, - GemmNoAtomicFunctorThreadK( - lhs_tp, rhs_tp, res_tp, workspace, local_B_block, n, - n_blocks, delta_n, k, k_blocks, delta_k, n_wi, m, - lhs_indexer, rhs_indexer, res_indexer)); - }); - return gemm_ev; - } - else { - using ReductionOpT = sycl::plus; - constexpr resTy identity_val = - sycl::known_identity::value; - - size_t iter_nelems = n * m; - size_t reduction_nelems = - (k + delta_k * n_wi - 1) / (delta_k * n_wi); - - // more than one work-groups is needed, requires a temporary - // delta_k * n_wi elements processed along k, so if more to - // process use multiple - const auto &sg_sizes = - dev.get_info(); - size_t wg = - choose_workgroup_size<4>(reduction_nelems, sg_sizes); - - if (reduction_nelems < wg) { - resTy *tmp = sycl::malloc_device( - iter_nelems * reduction_nelems, exec_q); - sycl::event gemm_ev = exec_q.submit([&](sycl::handler - &cgh) { - cgh.depends_on(depends); - - using OuterInnerDimsIndexerT = - dpctl::tensor::offset_utils::StridedIndexer; - OuterInnerDimsIndexerT lhs_indexer( - inner_nd + lhs_outer_nd, 0, - lhs_outer_inner_shapes_strides); - OuterInnerDimsIndexerT rhs_indexer( - inner_nd + rhs_outer_nd, 0, - rhs_outer_inner_shapes_strides); - using ResIndexerT = - dpctl::tensor::offset_utils::NoOpIndexer; - ResIndexerT res_indexer{}; - - size_t n_blocks = (n + delta_n - 1) / delta_n; - size_t k_blocks = - (k + n_wi * delta_k - 1) / (n_wi * delta_k); - size_t m_blocks = (m + m_groups - 1) / m_groups; - - size_t lws = delta_n * delta_k; - - auto gRange = sycl::range<1>(n_blocks * m_blocks * - k_blocks * lws); - auto lRange = sycl::range<1>(lws); - - auto ndRange = sycl::nd_range<1>(gRange, lRange); - - using LocAccT = sycl::local_accessor; - LocAccT local_B_block(n_wi * delta_k, cgh); - LocAccT workspace(delta_n * delta_k, cgh); - using KernelName = - class gemm_tree_k_krn; - cgh.parallel_for( - ndRange, - GemmNoAtomicFunctorThreadK< - lhsTy, rhsTy, resTy, LocAccT, - OuterInnerDimsIndexerT, ResIndexerT, m_groups>( - lhs_tp, rhs_tp, tmp, workspace, local_B_block, - n, n_blocks, delta_n, k, k_blocks, delta_k, - n_wi, m, lhs_indexer, rhs_indexer, - res_indexer)); - }); - - sycl::event red_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(gemm_ev); - - using NoOpIndexerT = - dpctl::tensor::offset_utils::NoOpIndexer; - using ResIndexerT = - dpctl::tensor::offset_utils::StridedIndexer; - using InputOutputIterIndexerT = - dpctl::tensor::offset_utils:: - TwoOffsets_CombinedIndexer; - using ReductionIndexerT = - dpctl::tensor::offset_utils::Strided1DIndexer; - - ResIndexerT res_iter_indexer{res_nd, 0, - res_shapes_strides}; - InputOutputIterIndexerT in_out_iter_indexer{ - NoOpIndexerT{}, res_iter_indexer}; - ReductionIndexerT reduction_indexer{ - 0, static_cast(reduction_nelems), - static_cast(iter_nelems)}; - - sycl::range<1> iter_range{iter_nelems}; - - cgh.parallel_for>( - iter_range, - SequentialReduction( - tmp, res_tp, ReductionOpT(), identity_val, - in_out_iter_indexer, reduction_indexer, - reduction_nelems)); - }); - sycl::event cleanup_host_task_event = - exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(red_ev); - const sycl::context &ctx = exec_q.get_context(); - - cgh.host_task([ctx, tmp] { sycl::free(tmp, ctx); }); - }); - return cleanup_host_task_event; - } - // max_max_wg prevents running out of resources on CPU - constexpr size_t max_max_wg = 2048; - size_t max_wg = std::min( - max_max_wg, - dev.get_info() / - 2); - - constexpr size_t preferred_reductions_per_wi = 8; - size_t reductions_per_wi(preferred_reductions_per_wi); - - size_t 
reduction_groups = - (reduction_nelems + preferred_reductions_per_wi * wg - 1) / - (preferred_reductions_per_wi * wg); - assert(reduction_groups > 1); - - resTy *partially_reduced_tmp = sycl::malloc_device( - iter_nelems * (/* temp */ reduction_nelems + - /* first reduction temp */ reduction_groups), - exec_q); - resTy *partially_reduced_tmp2 = nullptr; - - if (partially_reduced_tmp == nullptr) { - throw std::runtime_error( - "Unable to allocate device_memory"); - } - else { - partially_reduced_tmp2 = - partially_reduced_tmp + reduction_nelems * iter_nelems; - } - - sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(depends); - - using OuterInnerDimsIndexerT = - dpctl::tensor::offset_utils::StridedIndexer; - OuterInnerDimsIndexerT lhs_indexer( - inner_nd + lhs_outer_nd, 0, - lhs_outer_inner_shapes_strides); - OuterInnerDimsIndexerT rhs_indexer( - inner_nd + rhs_outer_nd, 0, - rhs_outer_inner_shapes_strides); - using ResIndexerT = - dpctl::tensor::offset_utils::NoOpIndexer; - ResIndexerT res_indexer{}; - - size_t n_blocks = (n + delta_n - 1) / delta_n; - size_t k_blocks = - (k + n_wi * delta_k - 1) / (n_wi * delta_k); - size_t m_blocks = (m + m_groups - 1) / m_groups; - - size_t lws = delta_n * delta_k; - - auto gRange = - sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); - auto lRange = sycl::range<1>(lws); - - auto ndRange = sycl::nd_range<1>(gRange, lRange); - - using LocAccT = sycl::local_accessor; - LocAccT local_B_block(n_wi * delta_k, cgh); - LocAccT workspace(delta_n * delta_k, cgh); - using KernelName = - class gemm_tree_k_krn; - cgh.parallel_for( - ndRange, - GemmNoAtomicFunctorThreadK( - lhs_tp, rhs_tp, partially_reduced_tmp, workspace, - local_B_block, n, n_blocks, delta_n, k, k_blocks, - delta_k, n_wi, m, lhs_indexer, rhs_indexer, - res_indexer)); - }); - // tree_reduction_for_gemm returns sycl::event for reduction - sycl::event red_ev = - tree_reduction_for_gemm( - exec_q, partially_reduced_tmp, partially_reduced_tmp2, - res_tp, identity_val, iter_nelems, reduction_nelems, - reduction_groups, wg, max_wg, - preferred_reductions_per_wi, reductions_per_wi, res_nd, - 0, res_shapes_strides, {gemm_ev}); - - sycl::event cleanup_host_task_event = - exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(red_ev); - const sycl::context &ctx = exec_q.get_context(); - - cgh.host_task([ctx, partially_reduced_tmp] { - sycl::free(partially_reduced_tmp, ctx); - }); - }); - - return cleanup_host_task_event; - } - } - } - else { // m > 1, n > k or m > k - constexpr int wi_delta_n = 2; - size_t wg_delta_n(16); // rows of A processed in WG - size_t wg_delta_m(16); // rows of B processed in WG - size_t wi_delta_k(64); // Elements in K dimension processed by WI - - using dpctl::tensor::type_utils::is_complex; - if constexpr (!is_complex::value) { - constexpr int wi_delta_m = 4; - - gemm_detail::scale_gemm_nm_parameters( - local_mem_size, reserved_slm_size, wi_delta_n, - wi_delta_k, // modified by reference - wg_delta_n, // modified by reference - wg_delta_m // modified by reference - ); - - // each group processes delta_k items in a column, - // so no need to allocate temp memory if one group needed - if (k <= wi_delta_k) { - sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(depends); - - using OuterInnerDimsIndexerT = - dpctl::tensor::offset_utils::StridedIndexer; - OuterInnerDimsIndexerT lhs_indexer( - inner_nd + lhs_outer_nd, 0, - lhs_outer_inner_shapes_strides); - OuterInnerDimsIndexerT rhs_indexer( - inner_nd + rhs_outer_nd, 0, - 
rhs_outer_inner_shapes_strides); - OuterInnerDimsIndexerT res_indexer(res_nd, 0, - res_shapes_strides); - - size_t lws = wg_delta_n * wg_delta_m; - - size_t n_blocks = ((n + wi_delta_n * wg_delta_n - 1) / - (wi_delta_n * wg_delta_n)); - size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); - size_t m_blocks = ((m + wi_delta_m * wg_delta_m - 1) / - (wi_delta_m * wg_delta_m)); - - auto gwsRange = - sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); - auto lwsRange = sycl::range<1>(lws); - - auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); - - using LocAccT1 = sycl::local_accessor; - LocAccT1 local_A_block( - sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), - cgh); - using LocAccT2 = - sycl::local_accessor, 1>; - LocAccT2 local_B_block( - sycl::range<1>(wi_delta_k * wg_delta_m), cgh); - - using KernelName = class gemm_tree_nm_krn< - lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, - OuterInnerDimsIndexerT, wi_delta_m>; - cgh.parallel_for( - ndRange, - GemmNoAtomicFunctorThreadNM< - lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, - OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, - wi_delta_n, wi_delta_m>( - lhs_tp, rhs_tp, res_tp, local_A_block, - local_B_block, n, wg_delta_n, k, k_blocks, - wi_delta_k, m, m_blocks, wg_delta_m, lhs_indexer, - rhs_indexer, res_indexer)); - }); - return gemm_ev; - } - else { - using ReductionOpT = sycl::plus; - constexpr resTy identity_val = - sycl::known_identity::value; - - size_t iter_nelems = n * m; - size_t reduction_nelems = (k + wi_delta_k - 1) / wi_delta_k; - - // more than one work-groups is needed, requires a temporary - // wi_delta_k elements processed along k, so if more to process - // use multiple - const auto &sg_sizes = - dev.get_info(); - size_t wg = - choose_workgroup_size<4>(reduction_nelems, sg_sizes); - - if (reduction_nelems < wg) { - resTy *tmp = sycl::malloc_device( - iter_nelems * reduction_nelems, exec_q); - sycl::event gemm_ev = exec_q.submit([&](sycl::handler - &cgh) { - cgh.depends_on(depends); - - using OuterInnerDimsIndexerT = - dpctl::tensor::offset_utils::StridedIndexer; - OuterInnerDimsIndexerT lhs_indexer( - inner_nd + lhs_outer_nd, 0, - lhs_outer_inner_shapes_strides); - OuterInnerDimsIndexerT rhs_indexer( - inner_nd + rhs_outer_nd, 0, - rhs_outer_inner_shapes_strides); - using ResIndexerT = - dpctl::tensor::offset_utils::NoOpIndexer; - ResIndexerT res_indexer{}; - - size_t lws = wg_delta_n * wg_delta_m; - - size_t n_blocks = ((n + wi_delta_n * wg_delta_n - 1) / - (wi_delta_n * wg_delta_n)); - size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); - size_t m_blocks = ((m + wi_delta_m * wg_delta_m - 1) / - (wi_delta_m * wg_delta_m)); - - auto gwsRange = sycl::range<1>(n_blocks * m_blocks * - k_blocks * lws); - auto lwsRange = sycl::range<1>(lws); - - auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); - - using LocAccT1 = sycl::local_accessor; - LocAccT1 local_A_block( - sycl::range<1>((wi_delta_n * wg_delta_n) * - wi_delta_k), - cgh); - using LocAccT2 = - sycl::local_accessor, - 1>; - LocAccT2 local_B_block( - sycl::range<1>(wi_delta_k * wg_delta_m), cgh); - - using KernelName = - class gemm_tree_nm_krn; - cgh.parallel_for( - ndRange, - GemmNoAtomicFunctorThreadNM< - lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, - OuterInnerDimsIndexerT, ResIndexerT, wi_delta_n, - wi_delta_m>(lhs_tp, rhs_tp, tmp, local_A_block, - local_B_block, n, wg_delta_n, k, - k_blocks, wi_delta_k, m, m_blocks, - wg_delta_m, lhs_indexer, - rhs_indexer, res_indexer)); - }); - sycl::event red_ev = exec_q.submit([&](sycl::handler &cgh) { - 
cgh.depends_on(gemm_ev); - - using NoOpIndexerT = - dpctl::tensor::offset_utils::NoOpIndexer; - using ResIndexerT = - dpctl::tensor::offset_utils::StridedIndexer; - using InputOutputIterIndexerT = - dpctl::tensor::offset_utils:: - TwoOffsets_CombinedIndexer; - using ReductionIndexerT = - dpctl::tensor::offset_utils::Strided1DIndexer; - - ResIndexerT res_iter_indexer{res_nd, 0, - res_shapes_strides}; - InputOutputIterIndexerT in_out_iter_indexer{ - NoOpIndexerT{}, res_iter_indexer}; - ReductionIndexerT reduction_indexer{ - 0, static_cast(reduction_nelems), - static_cast(iter_nelems)}; - - sycl::range<1> iter_range{iter_nelems}; - - cgh.parallel_for>( - iter_range, - SequentialReduction( - tmp, res_tp, ReductionOpT(), identity_val, - in_out_iter_indexer, reduction_indexer, - reduction_nelems)); - }); - sycl::event cleanup_host_task_event = - exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(red_ev); - const sycl::context &ctx = exec_q.get_context(); - - cgh.host_task([ctx, tmp] { sycl::free(tmp, ctx); }); - }); - return cleanup_host_task_event; - } - - // max_max_wg prevents running out of resources on CPU - constexpr size_t max_max_wg = 2048; - size_t max_wg = std::min( - max_max_wg, - dev.get_info() / - 2); - - constexpr size_t preferred_reductions_per_wi = 8; - size_t reductions_per_wi(preferred_reductions_per_wi); - - size_t reduction_groups = - (reduction_nelems + preferred_reductions_per_wi * wg - 1) / - (preferred_reductions_per_wi * wg); - assert(reduction_groups > 1); - - resTy *partially_reduced_tmp = sycl::malloc_device( - iter_nelems * (/* temp */ reduction_nelems + - /* first reduction temp */ reduction_groups), - exec_q); - resTy *partially_reduced_tmp2 = nullptr; - - if (partially_reduced_tmp == nullptr) { - throw std::runtime_error( - "Unable to allocate device_memory"); - } - else { - partially_reduced_tmp2 = - partially_reduced_tmp + reduction_nelems * iter_nelems; - } - - sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(depends); - - using OuterInnerDimsIndexerT = - dpctl::tensor::offset_utils::StridedIndexer; - OuterInnerDimsIndexerT lhs_indexer( - inner_nd + lhs_outer_nd, 0, - lhs_outer_inner_shapes_strides); - OuterInnerDimsIndexerT rhs_indexer( - inner_nd + rhs_outer_nd, 0, - rhs_outer_inner_shapes_strides); - using ResIndexerT = - dpctl::tensor::offset_utils::NoOpIndexer; - ResIndexerT res_indexer{}; - - size_t lws = wg_delta_n * wg_delta_m; - - size_t n_blocks = ((n + wi_delta_n * wg_delta_n - 1) / - (wi_delta_n * wg_delta_n)); - size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); - size_t m_blocks = ((m + wi_delta_m * wg_delta_m - 1) / - (wi_delta_m * wg_delta_m)); + size_t n_blocks = (n + delta_n - 1) / delta_n; + size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); + size_t m_blocks = (m + m_groups - 1) / m_groups; - auto gwsRange = - sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); - auto lwsRange = sycl::range<1>(lws); + size_t lws = delta_n * delta_k; - auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); + auto gRange = sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); + auto lRange = sycl::range<1>(lws); - using LocAccT1 = sycl::local_accessor; - LocAccT1 local_A_block( - sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), - cgh); - using LocAccT2 = - sycl::local_accessor, 1>; - LocAccT2 local_B_block( - sycl::range<1>(wi_delta_k * wg_delta_m), cgh); + auto ndRange = sycl::nd_range<1>(gRange, lRange); - using KernelName = - class gemm_tree_nm_krn; + LocAccT local_B_block(n_wi * delta_k, cgh); + 
LocAccT workspace(delta_n * delta_k, cgh); + using KernelName = class gemm_tree_k_krn; + cgh.parallel_for( + ndRange, + GemmNoAtomicFunctorThreadK; - cgh.parallel_for( - ndRange, - GemmNoAtomicFunctorThreadNM< - lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, - OuterInnerDimsIndexerT, ResIndexerT, wi_delta_n, - wi_delta_m>(lhs_tp, rhs_tp, partially_reduced_tmp, - local_A_block, local_B_block, n, - wg_delta_n, k, k_blocks, wi_delta_k, m, - m_blocks, wg_delta_m, lhs_indexer, - rhs_indexer, res_indexer)); - }); - - sycl::event red_ev = - tree_reduction_for_gemm( - exec_q, partially_reduced_tmp, partially_reduced_tmp2, - res_tp, identity_val, iter_nelems, reduction_nelems, - reduction_groups, wg, max_wg, - preferred_reductions_per_wi, reductions_per_wi, res_nd, - 0, res_shapes_strides, {gemm_ev}); - - sycl::event cleanup_host_task_event = - exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(red_ev); - const sycl::context &ctx = exec_q.get_context(); - - cgh.host_task([ctx, partially_reduced_tmp] { - sycl::free(partially_reduced_tmp, ctx); - }); - }); - - return cleanup_host_task_event; - } - } - else { - constexpr int wi_delta_m = 1; - - gemm_detail::scale_gemm_nm_parameters( - local_mem_size, reserved_slm_size, wi_delta_n, - wi_delta_k, // modified by reference - wg_delta_n, // modified by reference - wg_delta_m // modified by reference - ); - - // each group processes delta_k items in a column, - // so no need to allocate temp memory if one group needed - if (k <= wi_delta_k) { - sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(depends); - - using OuterInnerDimsIndexerT = - dpctl::tensor::offset_utils::StridedIndexer; - OuterInnerDimsIndexerT lhs_indexer( - inner_nd + lhs_outer_nd, 0, - lhs_outer_inner_shapes_strides); - OuterInnerDimsIndexerT rhs_indexer( - inner_nd + rhs_outer_nd, 0, - rhs_outer_inner_shapes_strides); - OuterInnerDimsIndexerT res_indexer(res_nd, 0, - res_shapes_strides); - - size_t lws = wg_delta_n * wg_delta_m; - - size_t n_blocks = ((n + wi_delta_n * wg_delta_n - 1) / - (wi_delta_n * wg_delta_n)); - size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); - size_t m_blocks = ((m + wi_delta_m * wg_delta_m - 1) / - (wi_delta_m * wg_delta_m)); - - auto gwsRange = - sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); - auto lwsRange = sycl::range<1>(lws); - - auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); - - using LocAccT1 = sycl::local_accessor; - LocAccT1 local_A_block( - sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), - cgh); - using LocAccT2 = sycl::local_accessor; - LocAccT2 local_B_block(wi_delta_k * wg_delta_m, cgh); - - using KernelName = class gemm_tree_nm_krn< - lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, - OuterInnerDimsIndexerT, wi_delta_m>; - cgh.parallel_for( - ndRange, - GemmNoAtomicFunctorThreadNM< - lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, - OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, - wi_delta_n, wi_delta_m>( - lhs_tp, rhs_tp, res_tp, local_A_block, - local_B_block, n, wg_delta_n, k, k_blocks, - wi_delta_k, m, m_blocks, wg_delta_m, lhs_indexer, - rhs_indexer, res_indexer)); - }); - return gemm_ev; + ResIndexerT, m_groups>( + lhs_tp, rhs_tp, partially_reduced_tmp, workspace, + local_B_block, n, n_blocks, delta_n, k, k_blocks, + delta_k, n_wi, m, lhs_indexer, rhs_indexer, + res_indexer)); } else { - using ReductionOpT = sycl::plus; - constexpr resTy identity_val = - sycl::known_identity::value; - - size_t iter_nelems = n * m; - size_t reduction_nelems = (k + wi_delta_k - 1) / wi_delta_k; - - // more than one 
work-groups is needed, requires a temporary - // wi_delta_k elements processed along k, so if more to process - // use multiple - const auto &sg_sizes = - dev.get_info(); - size_t wg = - choose_workgroup_size<4>(reduction_nelems, sg_sizes); - - if (reduction_nelems < wg) { - resTy *tmp = sycl::malloc_device( - iter_nelems * reduction_nelems, exec_q); - sycl::event gemm_ev = exec_q.submit([&](sycl::handler - &cgh) { - cgh.depends_on(depends); - - using OuterInnerDimsIndexerT = - dpctl::tensor::offset_utils::StridedIndexer; - OuterInnerDimsIndexerT lhs_indexer( - inner_nd + lhs_outer_nd, 0, - lhs_outer_inner_shapes_strides); - OuterInnerDimsIndexerT rhs_indexer( - inner_nd + rhs_outer_nd, 0, - rhs_outer_inner_shapes_strides); - using ResIndexerT = - dpctl::tensor::offset_utils::NoOpIndexer; - ResIndexerT res_indexer{}; - - size_t lws = wg_delta_n * wg_delta_m; - - size_t n_blocks = ((n + wi_delta_n * wg_delta_n - 1) / - (wi_delta_n * wg_delta_n)); - size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); - size_t m_blocks = ((m + wi_delta_m * wg_delta_m - 1) / - (wi_delta_m * wg_delta_m)); - - auto gwsRange = sycl::range<1>(n_blocks * m_blocks * - k_blocks * lws); - auto lwsRange = sycl::range<1>(lws); - - auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); - - using LocAccT1 = sycl::local_accessor; - LocAccT1 local_A_block( - sycl::range<1>((wi_delta_n * wg_delta_n) * - wi_delta_k), - cgh); - using LocAccT2 = sycl::local_accessor; - LocAccT2 local_B_block(wi_delta_k * wg_delta_m, cgh); - - using KernelName = - class gemm_tree_nm_krn; - cgh.parallel_for( - ndRange, - GemmNoAtomicFunctorThreadNM< - lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, - OuterInnerDimsIndexerT, ResIndexerT, wi_delta_n, - wi_delta_m>(lhs_tp, rhs_tp, tmp, local_A_block, - local_B_block, n, wg_delta_n, k, - k_blocks, wi_delta_k, m, m_blocks, - wg_delta_m, lhs_indexer, - rhs_indexer, res_indexer)); - }); - sycl::event red_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(gemm_ev); + using LocAccT = + sycl::local_accessor, 1>; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + using KernelName = class gemm_tree_k_krn; + cgh.parallel_for( + ndRange, + GemmNoAtomicFunctorThreadK( + lhs_tp, rhs_tp, partially_reduced_tmp, workspace, + local_B_block, n, n_blocks, delta_n, k, k_blocks, + delta_k, n_wi, m, lhs_indexer, rhs_indexer, + res_indexer)); + } + }); + // tree_reduction_for_gemm returns sycl::event for reduction + sycl::event red_ev = tree_reduction_for_gemm( + exec_q, partially_reduced_tmp, partially_reduced_tmp2, res_tp, + identity_val, iter_nelems, reduction_nelems, reduction_groups, wg, + max_wg, preferred_reductions_per_wi, reductions_per_wi, res_nd, 0, + res_shapes_strides, {gemm_ev}); + + sycl::event cleanup_host_task_event = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(red_ev); + const sycl::context &ctx = exec_q.get_context(); - using NoOpIndexerT = - dpctl::tensor::offset_utils::NoOpIndexer; - using ResIndexerT = - dpctl::tensor::offset_utils::StridedIndexer; - using InputOutputIterIndexerT = - dpctl::tensor::offset_utils:: - TwoOffsets_CombinedIndexer; - using ReductionIndexerT = - dpctl::tensor::offset_utils::Strided1DIndexer; + cgh.host_task([ctx, partially_reduced_tmp] { + sycl::free(partially_reduced_tmp, ctx); + }); + }); - ResIndexerT res_iter_indexer{res_nd, 0, - res_shapes_strides}; - InputOutputIterIndexerT in_out_iter_indexer{ - NoOpIndexerT{}, res_iter_indexer}; - ReductionIndexerT reduction_indexer{ - 0, static_cast(reduction_nelems), - 
static_cast(iter_nelems)}; + return cleanup_host_task_event; + } +} - sycl::range<1> iter_range{iter_nelems}; +template +sycl::event gemm_tree_nm_impl(sycl::queue &exec_q, + const lhsTy *lhs_tp, + const rhsTy *rhs_tp, + resTy *res_tp, + size_t n, + size_t k, + size_t m, + int inner_nd, + int lhs_outer_nd, + const py::ssize_t *lhs_outer_inner_shapes_strides, + int rhs_outer_nd, + const py::ssize_t *rhs_outer_inner_shapes_strides, + int res_nd, + const py::ssize_t *res_shapes_strides, + std::vector depends) +{ + constexpr int wi_delta_n = 2; + size_t wg_delta_n(16); // rows of A processed in WG + size_t wg_delta_m(16); // rows of B processed in WG + size_t wi_delta_k(64); // Elements in K dimension processed by WI - cgh.parallel_for>( - iter_range, - SequentialReduction( - tmp, res_tp, ReductionOpT(), identity_val, - in_out_iter_indexer, reduction_indexer, - reduction_nelems)); - }); - sycl::event cleanup_host_task_event = - exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(red_ev); - const sycl::context &ctx = exec_q.get_context(); + const sycl::device &dev = exec_q.get_device(); + const size_t local_mem_size = + dev.get_info(); + const size_t reserved_slm_size = 512; - cgh.host_task([ctx, tmp] { sycl::free(tmp, ctx); }); - }); - return cleanup_host_task_event; - } + gemm_detail::scale_gemm_nm_parameters( + local_mem_size, reserved_slm_size, wi_delta_n, + wi_delta_k, // modified by reference + wg_delta_n, // modified by reference + wg_delta_m // modified by reference + ); + + // each group processes delta_k items in a column, + // so no need to allocate temp memory if one group needed + if (k <= wi_delta_k) { + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::StridedIndexer; + OuterInnerDimsIndexerT lhs_indexer(inner_nd + lhs_outer_nd, 0, + lhs_outer_inner_shapes_strides); + OuterInnerDimsIndexerT rhs_indexer(inner_nd + rhs_outer_nd, 0, + rhs_outer_inner_shapes_strides); + OuterInnerDimsIndexerT res_indexer(res_nd, 0, res_shapes_strides); - // max_max_wg prevents running out of resources on CPU - constexpr size_t max_max_wg = 2048; - size_t max_wg = std::min( - max_max_wg, - dev.get_info() / - 2); + size_t lws = wg_delta_n * wg_delta_m; - constexpr size_t preferred_reductions_per_wi = 8; - size_t reductions_per_wi(preferred_reductions_per_wi); + size_t n_blocks = + ((n + wi_delta_n * wg_delta_n - 1) / (wi_delta_n * wg_delta_n)); + size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); + size_t m_blocks = + ((m + wi_delta_m * wg_delta_m - 1) / (wi_delta_m * wg_delta_m)); - size_t reduction_groups = - (reduction_nelems + preferred_reductions_per_wi * wg - 1) / - (preferred_reductions_per_wi * wg); - assert(reduction_groups > 1); + auto gwsRange = + sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); + auto lwsRange = sycl::range<1>(lws); - resTy *partially_reduced_tmp = sycl::malloc_device( - iter_nelems * (/* temp */ reduction_nelems + - /* first reduction temp */ reduction_groups), - exec_q); - resTy *partially_reduced_tmp2 = nullptr; + auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); - if (partially_reduced_tmp == nullptr) { - throw std::runtime_error( - "Unable to allocate device_memory"); - } - else { - partially_reduced_tmp2 = - partially_reduced_tmp + reduction_nelems * iter_nelems; - } + if constexpr (wi_delta_m == 1) { + using LocAccT1 = sycl::local_accessor; + LocAccT1 local_A_block( + sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), + cgh); + using LocAccT2 
= sycl::local_accessor; + LocAccT2 local_B_block(wi_delta_k * wg_delta_m, cgh); - sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(depends); + using KernelName = + class gemm_tree_nm_krn; + cgh.parallel_for( + ndRange, + GemmNoAtomicFunctorThreadNM< + lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, + wi_delta_n, wi_delta_m>( + lhs_tp, rhs_tp, res_tp, local_A_block, local_B_block, n, + wg_delta_n, k, k_blocks, wi_delta_k, m, m_blocks, + wg_delta_m, lhs_indexer, rhs_indexer, res_indexer)); + } + else { + using LocAccT1 = sycl::local_accessor; + LocAccT1 local_A_block( + sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), + cgh); + using LocAccT2 = + sycl::local_accessor, 1>; + LocAccT2 local_B_block(sycl::range<1>(wi_delta_k * wg_delta_m), + cgh); - using OuterInnerDimsIndexerT = - dpctl::tensor::offset_utils::StridedIndexer; - OuterInnerDimsIndexerT lhs_indexer( - inner_nd + lhs_outer_nd, 0, - lhs_outer_inner_shapes_strides); - OuterInnerDimsIndexerT rhs_indexer( - inner_nd + rhs_outer_nd, 0, - rhs_outer_inner_shapes_strides); - using ResIndexerT = - dpctl::tensor::offset_utils::NoOpIndexer; - ResIndexerT res_indexer{}; + using KernelName = + class gemm_tree_nm_krn; + cgh.parallel_for( + ndRange, + GemmNoAtomicFunctorThreadNM< + lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, + wi_delta_n, wi_delta_m>( + lhs_tp, rhs_tp, res_tp, local_A_block, local_B_block, n, + wg_delta_n, k, k_blocks, wi_delta_k, m, m_blocks, + wg_delta_m, lhs_indexer, rhs_indexer, res_indexer)); + } + }); + return gemm_ev; + } + else { + using ReductionOpT = sycl::plus; + constexpr resTy identity_val = + sycl::known_identity::value; + + size_t iter_nelems = n * m; + size_t reduction_nelems = (k + wi_delta_k - 1) / wi_delta_k; + + // more than one work-groups is needed, requires a temporary + // wi_delta_k elements processed along k, so if more to + // process use multiple + const auto &sg_sizes = + dev.get_info(); + size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes); + + if (reduction_nelems < wg) { + resTy *tmp = sycl::malloc_device( + iter_nelems * reduction_nelems, exec_q); + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::StridedIndexer; + OuterInnerDimsIndexerT lhs_indexer( + inner_nd + lhs_outer_nd, 0, lhs_outer_inner_shapes_strides); + OuterInnerDimsIndexerT rhs_indexer( + inner_nd + rhs_outer_nd, 0, rhs_outer_inner_shapes_strides); + using ResIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + ResIndexerT res_indexer{}; - size_t lws = wg_delta_n * wg_delta_m; + size_t lws = wg_delta_n * wg_delta_m; - size_t n_blocks = ((n + wi_delta_n * wg_delta_n - 1) / - (wi_delta_n * wg_delta_n)); - size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); - size_t m_blocks = ((m + wi_delta_m * wg_delta_m - 1) / - (wi_delta_m * wg_delta_m)); + size_t n_blocks = ((n + wi_delta_n * wg_delta_n - 1) / + (wi_delta_n * wg_delta_n)); + size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); + size_t m_blocks = ((m + wi_delta_m * wg_delta_m - 1) / + (wi_delta_m * wg_delta_m)); - auto gwsRange = - sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); - auto lwsRange = sycl::range<1>(lws); + auto gwsRange = + sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); + auto lwsRange = sycl::range<1>(lws); - auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); + auto ndRange = sycl::nd_range<1>(gwsRange, 
lwsRange); + if constexpr (wi_delta_m == 1) { using LocAccT1 = sycl::local_accessor; LocAccT1 local_A_block( sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), @@ -3345,34 +2429,265 @@ sycl::event gemm_tree_impl(sycl::queue &exec_q, GemmNoAtomicFunctorThreadNM< lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, OuterInnerDimsIndexerT, ResIndexerT, wi_delta_n, - wi_delta_m>(lhs_tp, rhs_tp, partially_reduced_tmp, - local_A_block, local_B_block, n, - wg_delta_n, k, k_blocks, wi_delta_k, m, - m_blocks, wg_delta_m, lhs_indexer, - rhs_indexer, res_indexer)); + wi_delta_m>( + lhs_tp, rhs_tp, tmp, local_A_block, local_B_block, + n, wg_delta_n, k, k_blocks, wi_delta_k, m, m_blocks, + wg_delta_m, lhs_indexer, rhs_indexer, res_indexer)); + } + else { + using LocAccT1 = sycl::local_accessor; + LocAccT1 local_A_block( + sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), + cgh); + using LocAccT2 = + sycl::local_accessor, 1>; + LocAccT2 local_B_block( + sycl::range<1>(wi_delta_k * wg_delta_m), cgh); + + using KernelName = + class gemm_tree_nm_krn; + cgh.parallel_for( + ndRange, + GemmNoAtomicFunctorThreadNM< + lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, + OuterInnerDimsIndexerT, ResIndexerT, wi_delta_n, + wi_delta_m>( + lhs_tp, rhs_tp, tmp, local_A_block, local_B_block, + n, wg_delta_n, k, k_blocks, wi_delta_k, m, m_blocks, + wg_delta_m, lhs_indexer, rhs_indexer, res_indexer)); + } + }); + sycl::event red_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(gemm_ev); + + using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using ResIndexerT = dpctl::tensor::offset_utils::StridedIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + NoOpIndexerT, ResIndexerT>; + using ReductionIndexerT = + dpctl::tensor::offset_utils::Strided1DIndexer; + + ResIndexerT res_iter_indexer{res_nd, 0, res_shapes_strides}; + InputOutputIterIndexerT in_out_iter_indexer{NoOpIndexerT{}, + res_iter_indexer}; + ReductionIndexerT reduction_indexer{ + 0, static_cast(reduction_nelems), + static_cast(iter_nelems)}; + + sycl::range<1> iter_range{iter_nelems}; + + cgh.parallel_for>( + iter_range, SequentialReduction( + tmp, res_tp, ReductionOpT(), identity_val, + in_out_iter_indexer, reduction_indexer, + reduction_nelems)); + }); + sycl::event cleanup_host_task_event = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(red_ev); + const sycl::context &ctx = exec_q.get_context(); + + cgh.host_task([ctx, tmp] { sycl::free(tmp, ctx); }); }); + return cleanup_host_task_event; + } - sycl::event red_ev = - tree_reduction_for_gemm( - exec_q, partially_reduced_tmp, partially_reduced_tmp2, - res_tp, identity_val, iter_nelems, reduction_nelems, - reduction_groups, wg, max_wg, - preferred_reductions_per_wi, reductions_per_wi, res_nd, - 0, res_shapes_strides, {gemm_ev}); + // max_max_wg prevents running out of resources on CPU + constexpr size_t max_max_wg = 2048; + size_t max_wg = std::min( + max_max_wg, + dev.get_info() / 2); - sycl::event cleanup_host_task_event = - exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(red_ev); - const sycl::context &ctx = exec_q.get_context(); + constexpr size_t preferred_reductions_per_wi = 8; + size_t reductions_per_wi(preferred_reductions_per_wi); - cgh.host_task([ctx, partially_reduced_tmp] { - sycl::free(partially_reduced_tmp, ctx); - }); - }); + size_t reduction_groups = + (reduction_nelems + preferred_reductions_per_wi * wg - 1) / + (preferred_reductions_per_wi * wg); + assert(reduction_groups > 1); + + resTy *partially_reduced_tmp 
= sycl::malloc_device( + iter_nelems * (/* temp */ reduction_nelems + + /* first reduction temp */ reduction_groups), + exec_q); + resTy *partially_reduced_tmp2 = nullptr; + + if (partially_reduced_tmp == nullptr) { + throw std::runtime_error("Unable to allocate device_memory"); + } + else { + partially_reduced_tmp2 = + partially_reduced_tmp + reduction_nelems * iter_nelems; + } - return cleanup_host_task_event; + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::StridedIndexer; + OuterInnerDimsIndexerT lhs_indexer(inner_nd + lhs_outer_nd, 0, + lhs_outer_inner_shapes_strides); + OuterInnerDimsIndexerT rhs_indexer(inner_nd + rhs_outer_nd, 0, + rhs_outer_inner_shapes_strides); + using ResIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + ResIndexerT res_indexer{}; + + size_t lws = wg_delta_n * wg_delta_m; + + size_t n_blocks = + ((n + wi_delta_n * wg_delta_n - 1) / (wi_delta_n * wg_delta_n)); + size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); + size_t m_blocks = + ((m + wi_delta_m * wg_delta_m - 1) / (wi_delta_m * wg_delta_m)); + + auto gwsRange = + sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); + auto lwsRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); + + if constexpr (wi_delta_m == 1) { + using LocAccT1 = sycl::local_accessor; + LocAccT1 local_A_block( + sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), + cgh); + using LocAccT2 = sycl::local_accessor; + LocAccT2 local_B_block(wi_delta_k * wg_delta_m, cgh); + + using KernelName = + class gemm_tree_nm_krn; + cgh.parallel_for( + ndRange, GemmNoAtomicFunctorThreadNM< + lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, + OuterInnerDimsIndexerT, ResIndexerT, + wi_delta_n, wi_delta_m>( + lhs_tp, rhs_tp, partially_reduced_tmp, + local_A_block, local_B_block, n, wg_delta_n, k, + k_blocks, wi_delta_k, m, m_blocks, wg_delta_m, + lhs_indexer, rhs_indexer, res_indexer)); + } + else { + using LocAccT1 = sycl::local_accessor; + LocAccT1 local_A_block( + sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), + cgh); + using LocAccT2 = + sycl::local_accessor, 1>; + LocAccT2 local_B_block(sycl::range<1>(wi_delta_k * wg_delta_m), + cgh); + + using KernelName = + class gemm_tree_nm_krn; + cgh.parallel_for( + ndRange, GemmNoAtomicFunctorThreadNM< + lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, + OuterInnerDimsIndexerT, ResIndexerT, + wi_delta_n, wi_delta_m>( + lhs_tp, rhs_tp, partially_reduced_tmp, + local_A_block, local_B_block, n, wg_delta_n, k, + k_blocks, wi_delta_k, m, m_blocks, wg_delta_m, + lhs_indexer, rhs_indexer, res_indexer)); + } + }); + + sycl::event red_ev = tree_reduction_for_gemm( + exec_q, partially_reduced_tmp, partially_reduced_tmp2, res_tp, + identity_val, iter_nelems, reduction_nelems, reduction_groups, wg, + max_wg, preferred_reductions_per_wi, reductions_per_wi, res_nd, 0, + res_shapes_strides, {gemm_ev}); + + sycl::event cleanup_host_task_event = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(red_ev); + const sycl::context &ctx = exec_q.get_context(); + + cgh.host_task([ctx, partially_reduced_tmp] { + sycl::free(partially_reduced_tmp, ctx); + }); + }); + + return cleanup_host_task_event; + } +} + +template +sycl::event gemm_tree_impl(sycl::queue &exec_q, + const char *lhs_cp, + const char *rhs_cp, + char *res_cp, + size_t n, + size_t k, + size_t m, + int inner_nd, + int lhs_outer_nd, + const py::ssize_t *lhs_outer_inner_shapes_strides, + int rhs_outer_nd, + const 
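The multi-group path above packs two logical buffers into one USM allocation: iter_nelems * reduction_nelems slots for the per-group partial products, followed by iter_nelems * reduction_groups slots of scratch for the first level of tree_reduction_for_gemm. A sketch restating only what the allocation and the pointer offset above already say, with the struct name introduced for illustration:

#include <cstddef>

// Layout of the single device allocation used by the tree reduction.
struct partial_reduction_layout
{
    std::size_t partials_count; // iter_nelems * reduction_nelems
    std::size_t scratch_count;  // iter_nelems * reduction_groups
    std::size_t total_count;    // element count passed to sycl::malloc_device
    std::size_t scratch_offset; // where partially_reduced_tmp2 points
};

constexpr partial_reduction_layout
reduction_buffers(std::size_t iter_nelems, std::size_t reduction_nelems,
                  std::size_t reduction_groups)
{
    const std::size_t partials = iter_nelems * reduction_nelems;
    const std::size_t scratch = iter_nelems * reduction_groups;
    return {partials, scratch, partials + scratch, partials};
}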
py::ssize_t *rhs_outer_inner_shapes_strides, + int res_nd, + const py::ssize_t *res_shapes_strides, + std::vector const &depends = {}) +{ + const lhsTy *lhs_tp = reinterpret_cast(lhs_cp); + const rhsTy *rhs_tp = reinterpret_cast(rhs_cp); + resTy *res_tp = reinterpret_cast(res_cp); + + if ((k > n && k > m) || m == 1) { + // each group processes delta_k * n_wi + // items in a column, so no need for allocating + // temp memory if only one group is needed + + using dpctl::tensor::type_utils::is_complex; + if constexpr (!is_complex::value) { + if (m == 1) { + return gemm_tree_k_impl( + exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, inner_nd, + lhs_outer_nd, lhs_outer_inner_shapes_strides, rhs_outer_nd, + rhs_outer_inner_shapes_strides, res_nd, res_shapes_strides, + depends); + } + else { + return gemm_tree_k_impl( + exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, inner_nd, + lhs_outer_nd, lhs_outer_inner_shapes_strides, rhs_outer_nd, + rhs_outer_inner_shapes_strides, res_nd, res_shapes_strides, + depends); } } + else { + return gemm_tree_k_impl( + exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, inner_nd, lhs_outer_nd, + lhs_outer_inner_shapes_strides, rhs_outer_nd, + rhs_outer_inner_shapes_strides, res_nd, res_shapes_strides, + depends); + } + } + else { // m > 1, n > k or m > k + using dpctl::tensor::type_utils::is_complex; + if constexpr (!is_complex::value) { + return gemm_tree_nm_impl( + exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, inner_nd, lhs_outer_nd, + lhs_outer_inner_shapes_strides, rhs_outer_nd, + rhs_outer_inner_shapes_strides, res_nd, res_shapes_strides, + depends); + } + else { + return gemm_tree_nm_impl( + exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, inner_nd, lhs_outer_nd, + lhs_outer_inner_shapes_strides, rhs_outer_nd, + rhs_outer_inner_shapes_strides, res_nd, res_shapes_strides, + depends); + } } } @@ -3466,9 +2781,9 @@ sycl::event gemm_contig_tree_impl(sycl::queue &exec_q, size_t reduction_nelems = (k + delta_k * n_wi - 1) / (delta_k * n_wi); - // more than one work-groups is needed, requires a temporary - // delta_k * n_wi elements processed along k, so if more to - // process use multiple + // more than one work-groups is needed, requires a + // temporary delta_k * n_wi elements processed along k, + // so if more to process use multiple const auto &sg_sizes = dev.get_info(); size_t wg = @@ -3635,8 +2950,8 @@ sycl::event gemm_contig_tree_impl(sycl::queue &exec_q, delta_n, k, k_blocks, delta_k, n_wi, m, lhs_indexer, rhs_indexer, res_indexer)); }); - // tree_reduction_for_gemm_contig returns sycl::event for - // reduction + // tree_reduction_for_gemm_contig returns sycl::event + // for reduction sycl::event red_ev = tree_reduction_for_gemm_contig( exec_q, partially_reduced_tmp, @@ -3719,9 +3034,9 @@ sycl::event gemm_contig_tree_impl(sycl::queue &exec_q, size_t reduction_nelems = (k + delta_k * n_wi - 1) / (delta_k * n_wi); - // more than one work-groups is needed, requires a temporary - // delta_k * n_wi elements processed along k, so if more to - // process use multiple + // more than one work-groups is needed, requires a + // temporary delta_k * n_wi elements processed along k, + // so if more to process use multiple const auto &sg_sizes = dev.get_info(); size_t wg = @@ -4236,8 +3551,8 @@ sycl::event gemm_contig_tree_impl(sycl::queue &exec_q, size_t reduction_nelems = (k + wi_delta_k - 1) / wi_delta_k; // more than one work-groups is needed, requires a temporary - // wi_delta_k elements processed along k, so if more to process - // use multiple + // wi_delta_k elements processed along k, so if more to 
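gemm_tree_impl above reduces to a small decision table: the K variant when the reduction dimension dominates or the result is a single column, the NM variant otherwise, with the wider per-work-item tile reserved for non-complex result types. A compact restatement of that table, with the enum introduced only for illustration:

#include <cstddef>

enum class gemm_variant { k_wide, k_narrow, nm_wide, nm_narrow };

constexpr gemm_variant
pick_tree_variant(std::size_t n, std::size_t k, std::size_t m, bool is_complex)
{
    if ((k > n && k > m) || m == 1) {
        if (is_complex || m == 1)
            return gemm_variant::k_narrow;  // gemm_tree_k_impl, m_groups = 1
        return gemm_variant::k_wide;        // gemm_tree_k_impl, m_groups = 2
    }
    if (is_complex)
        return gemm_variant::nm_narrow;     // gemm_tree_nm_impl, wi_delta_m = 1
    return gemm_variant::nm_wide;           // gemm_tree_nm_impl, wi_delta_m = 4
}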
+ // process use multiple const auto &sg_sizes = dev.get_info(); size_t wg = @@ -4502,8 +3817,8 @@ sycl::event gemm_contig_tree_impl(sycl::queue &exec_q, size_t reduction_nelems = (k + wi_delta_k - 1) / wi_delta_k; // more than one work-groups is needed, requires a temporary - // wi_delta_k elements processed along k, so if more to process - // use multiple + // wi_delta_k elements processed along k, so if more to + // process use multiple const auto &sg_sizes = dev.get_info(); size_t wg = @@ -4769,8 +4084,8 @@ class GemmBatchFunctorThreadNM batch_indexer(static_cast(m_id)); // lift group_id to (block_i, block_j, block_s), - // 0 <= block_i < n_blocks, 0 <= block_j < m_blocks, 0 <= block_s < - // k_blocks + // 0 <= block_i < n_blocks, 0 <= block_j < m_blocks, 0 <= block_s + // < k_blocks const auto &lhs_offset = three_offsets_.get_first_offset(); const auto &rhs_offset = three_offsets_.get_second_offset(); @@ -4957,8 +4272,8 @@ class GemmBatchFunctorThreadNM(m_id)); // lift group_id to (block_i, block_j, block_s), - // 0 <= block_i < n_blocks, 0 <= block_j < m_blocks, 0 <= block_s < - // k_blocks + // 0 <= block_i < n_blocks, 0 <= block_j < m_blocks, 0 <= block_s + // < k_blocks const auto &lhs_offset = three_offsets_.get_first_offset(); const auto &rhs_offset = three_offsets_.get_second_offset(); @@ -5116,7 +4431,8 @@ class GemmBatchFunctorThreadK { // for batching: // (current matrix in batch) m_id = global_id / (global_range / - // batch_nelems) for lhs, offset = m_id * (n * k) for rhs, offset = m_id + // batch_nelems) for lhs, offset = m_id * (n * k) for rhs, offset = + // m_id // * (k * m) for res, offset = m_id * (n * m) size_t m_id = it.get_global_linear_id() / (it.get_global_range(0) / batch_nelems); @@ -5277,7 +4593,8 @@ class GemmBatchFunctorThreadK(m_id)); // lift group_id to (block_i, block_j, block_s), - // 0 <= block_i < n_blocks, 0 <= block_j < m_blocks, 0 <= block_s < - // k_blocks + // 0 <= block_i < n_blocks, 0 <= block_j < m_blocks, 0 <= block_s + // < k_blocks const auto &lhs_offset = three_offsets_.get_first_offset(); const auto &rhs_offset = three_offsets_.get_second_offset(); @@ -6062,8 +5379,8 @@ class GemmBatchNoAtomicFunctorThreadNM(m_id)); // lift group_id to (block_i, block_j, block_s), - // 0 <= block_i < n_blocks, 0 <= block_j < m_blocks, 0 <= block_s < - // k_blocks + // 0 <= block_i < n_blocks, 0 <= block_j < m_blocks, 0 <= block_s + // < k_blocks const auto &lhs_offset = three_offsets_.get_first_offset(); const auto &rhs_offset = three_offsets_.get_second_offset(); @@ -6569,9 +5886,9 @@ gemm_batch_tree_impl(sycl::queue &exec_q, size_t reduction_nelems = (k + delta_k * n_wi - 1) / (delta_k * n_wi); - // more than one work-group is needed, requires a temporary - // delta_k * n_wi elements processed along k, so if more to - // process use multiple + // more than one work-group is needed, requires a + // temporary delta_k * n_wi elements processed along k, + // so if more to process use multiple const auto &sg_sizes = dev.get_info(); size_t wg = @@ -6899,9 +6216,9 @@ gemm_batch_tree_impl(sycl::queue &exec_q, size_t reduction_nelems = (k + delta_k * n_wi - 1) / (delta_k * n_wi); - // more than one work-group is needed, requires a temporary - // delta_k * n_wi elements processed along k, so if more to - // process use multiple + // more than one work-group is needed, requires a + // temporary delta_k * n_wi elements processed along k, + // so if more to process use multiple const auto &sg_sizes = dev.get_info(); size_t wg = @@ -8281,9 +7598,9 @@ 
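The batched functors touched by the comment rewraps above locate the current matrix of the batch from the flat work-item id and then apply fixed per-matrix offsets, exactly as their comments state. A minimal sketch of that arithmetic, with the helper names introduced only for illustration:

#include <cstddef>

// Per-matrix base offsets for a contiguous batch, as described in the
// GemmBatchFunctorThreadK comments above.
struct batch_base_offsets
{
    std::size_t lhs; // m_id * (n * k)
    std::size_t rhs; // m_id * (k * m)
    std::size_t res; // m_id * (n * m)
};

constexpr batch_base_offsets
batch_offsets(std::size_t global_id, std::size_t global_range,
              std::size_t batch_nelems,
              std::size_t n, std::size_t k, std::size_t m)
{
    // which matrix of the batch this work-item serves
    const std::size_t m_id = global_id / (global_range / batch_nelems);
    return {m_id * (n * k), m_id * (k * m), m_id * (n * m)};
}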
gemm_batch_contig_tree_impl(sycl::queue &exec_q, size_t reduction_nelems = (k + delta_k * n_wi - 1) / (delta_k * n_wi); - // more than one work-group is needed, requires a temporary - // delta_k * n_wi elements processed along k, so if more to - // process use multiple + // more than one work-group is needed, requires a + // temporary delta_k * n_wi elements processed along k, + // so if more to process use multiple const auto &sg_sizes = dev.get_info(); size_t wg = @@ -8601,9 +7918,9 @@ gemm_batch_contig_tree_impl(sycl::queue &exec_q, size_t reduction_nelems = (k + delta_k * n_wi - 1) / (delta_k * n_wi); - // more than one work-group is needed, requires a temporary - // delta_k * n_wi elements processed along k, so if more to - // process use multiple + // more than one work-group is needed, requires a + // temporary delta_k * n_wi elements processed along k, + // so if more to process use multiple const auto &sg_sizes = dev.get_info(); size_t wg = From 01ff619a99968e633849d14023a695492f947ec1 Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Wed, 10 Jan 2024 15:45:16 -0800 Subject: [PATCH 13/48] Reverse order of numeric types passed to test_matmul_simple2 May improve stability on CPU --- dpctl/tests/test_usm_ndarray_linalg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dpctl/tests/test_usm_ndarray_linalg.py b/dpctl/tests/test_usm_ndarray_linalg.py index 6365ffc6b7..29b9ee17d3 100644 --- a/dpctl/tests/test_usm_ndarray_linalg.py +++ b/dpctl/tests/test_usm_ndarray_linalg.py @@ -82,7 +82,7 @@ def test_matrix_transpose_arg_validation(): # assert dpt.all(r == dpt.full((k, k), n, dtype=dtype)) -@pytest.mark.parametrize("dtype", _numeric_types) +@pytest.mark.parametrize("dtype", _numeric_types[::-1]) def test_matmul_simple2(dtype): q = get_queue_or_skip() skip_if_dtype_not_supported(dtype, q) From 1b861616246e8fa7fd4d3ad3f478ab4312bb3c77 Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Wed, 10 Jan 2024 19:22:44 -0800 Subject: [PATCH 14/48] Refactors `gemm_contig_tree_impl` `gemm_contig_tree_impl` now calls new functions `gemm_contig_tree_k_impl` and `gemm_contig_tree_nm_impl` --- .../include/kernels/linalg_functions/gemm.hpp | 1839 ++++++----------- 1 file changed, 609 insertions(+), 1230 deletions(-) diff --git a/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp b/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp index 7b26590580..f712407b86 100644 --- a/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp +++ b/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp @@ -1953,7 +1953,7 @@ sycl::event gemm_tree_k_impl(sycl::queue &exec_q, const py::ssize_t *rhs_outer_inner_shapes_strides, int res_nd, const py::ssize_t *res_shapes_strides, - std::vector depends) + const std::vector &depends) { size_t delta_k(4); size_t n_wi(4); @@ -2272,7 +2272,7 @@ sycl::event gemm_tree_nm_impl(sycl::queue &exec_q, const py::ssize_t *rhs_outer_inner_shapes_strides, int res_nd, const py::ssize_t *res_shapes_strides, - std::vector depends) + const std::vector &depends) { constexpr int wi_delta_n = 2; size_t wg_delta_n(16); // rows of A processed in WG @@ -2691,1282 +2691,457 @@ sycl::event gemm_tree_impl(sycl::queue &exec_q, } } -template -sycl::event gemm_contig_tree_impl(sycl::queue &exec_q, - const char *lhs_cp, - const char *rhs_cp, - char *res_cp, - size_t n, - size_t k, - size_t m, - std::vector const &depends = {}) +template +sycl::event gemm_contig_tree_k_impl(sycl::queue &exec_q, + const lhsTy *lhs_tp, + const rhsTy *rhs_tp, 
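The commit above splits the monolithic contiguous path into gemm_contig_tree_k_impl and gemm_contig_tree_nm_impl, leaving the outer gemm_contig_tree_impl as a dispatcher over the same decision table used by the strided gemm_tree_impl. The sketch below shows the expected shape of that dispatcher; the exact template parameter lists are an assumption, since the refactored body is not yet visible at this point in the diff, and the sketch relies on the surrounding header's includes and helpers:

template <typename lhsTy, typename rhsTy, typename resTy>
sycl::event gemm_contig_tree_impl_sketch(sycl::queue &exec_q,
                                         const char *lhs_cp,
                                         const char *rhs_cp,
                                         char *res_cp,
                                         size_t n, size_t k, size_t m,
                                         std::vector<sycl::event> const &depends = {})
{
    const lhsTy *lhs_tp = reinterpret_cast<const lhsTy *>(lhs_cp);
    const rhsTy *rhs_tp = reinterpret_cast<const rhsTy *>(rhs_cp);
    resTy *res_tp = reinterpret_cast<resTy *>(res_cp);

    using dpctl::tensor::type_utils::is_complex;
    if ((k > n && k > m) || m == 1) {
        // K-dominated shapes (or a single output column): K variant
        if constexpr (!is_complex<resTy>::value) {
            if (m == 1) {
                return gemm_contig_tree_k_impl<lhsTy, rhsTy, resTy, 1>(
                    exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, depends);
            }
            return gemm_contig_tree_k_impl<lhsTy, rhsTy, resTy, 2>(
                exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, depends);
        }
        else {
            return gemm_contig_tree_k_impl<lhsTy, rhsTy, resTy, 1>(
                exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, depends);
        }
    }
    else {
        // m > 1 and N or M dominates: NM variant, wider tile if not complex
        if constexpr (!is_complex<resTy>::value) {
            return gemm_contig_tree_nm_impl<lhsTy, rhsTy, resTy, 4>(
                exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, depends);
        }
        else {
            return gemm_contig_tree_nm_impl<lhsTy, rhsTy, resTy, 1>(
                exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, depends);
        }
    }
}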
+ resTy *res_tp, + size_t n, + size_t k, + size_t m, + std::vector const &depends) { - const lhsTy *lhs_tp = reinterpret_cast(lhs_cp); - const rhsTy *rhs_tp = reinterpret_cast(rhs_cp); - resTy *res_tp = reinterpret_cast(res_cp); + size_t delta_k(4); + size_t n_wi(4); + size_t delta_n(4); const sycl::device &dev = exec_q.get_device(); const size_t local_mem_size = dev.get_info(); const size_t reserved_slm_size = 512; - if ((k > n && k > m) || m == 1) { - // each group processes delta_k * n_wi - // items in a column, so no need for allocating - // temp memory if only one group is needed - size_t delta_k(4); - size_t n_wi(4); - size_t delta_n(4); + gemm_detail::scale_gemm_k_parameters( + local_mem_size, reserved_slm_size, delta_k, + n_wi, // modified by reference + delta_n // modified by reference + ); - using dpctl::tensor::type_utils::is_complex; - if constexpr (!is_complex::value) { - if (m == 1) { - constexpr int m_groups = 1; + sycl::event gemm_ev; + if (k <= (delta_k * n_wi)) { + gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); - gemm_detail::scale_gemm_k_parameters( - local_mem_size, reserved_slm_size, delta_k, - n_wi, // modified by reference - delta_n // modified by reference - ); + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + OuterInnerDimsIndexerT lhs_indexer{}; + OuterInnerDimsIndexerT rhs_indexer{}; + OuterInnerDimsIndexerT res_indexer{}; - sycl::event gemm_ev; - if (k <= (delta_k * n_wi)) { - gemm_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(depends); + size_t n_blocks = (n + delta_n - 1) / delta_n; + size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); + size_t m_blocks = (m + m_groups - 1) / m_groups; - using OuterInnerDimsIndexerT = - dpctl::tensor::offset_utils::NoOpIndexer; - OuterInnerDimsIndexerT lhs_indexer{}; - OuterInnerDimsIndexerT rhs_indexer{}; - OuterInnerDimsIndexerT res_indexer{}; + size_t lws = delta_n * delta_k; - size_t n_blocks = (n + delta_n - 1) / delta_n; - size_t k_blocks = - (k + n_wi * delta_k - 1) / (n_wi * delta_k); - size_t m_blocks = (m + m_groups - 1) / m_groups; + auto gRange = sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); + auto lRange = sycl::range<1>(lws); - size_t lws = delta_n * delta_k; + auto ndRange = sycl::nd_range<1>(gRange, lRange); - auto gRange = sycl::range<1>(n_blocks * m_blocks * - k_blocks * lws); - auto lRange = sycl::range<1>(lws); + if constexpr (m_groups == 1) { + using LocAccT = sycl::local_accessor; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + using KernelName = + class gemm_tree_k_krn; + cgh.parallel_for( + ndRange, + GemmNoAtomicFunctorThreadK< + lhsTy, rhsTy, resTy, LocAccT, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, m_groups>( + lhs_tp, rhs_tp, res_tp, workspace, local_B_block, n, + n_blocks, delta_n, k, k_blocks, delta_k, n_wi, m, + lhs_indexer, rhs_indexer, res_indexer)); + } + else { + using LocAccT = + sycl::local_accessor, 1>; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + using KernelName = + class gemm_tree_k_krn; + cgh.parallel_for( + ndRange, + GemmNoAtomicFunctorThreadK< + lhsTy, rhsTy, resTy, LocAccT, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, m_groups>( + lhs_tp, rhs_tp, res_tp, workspace, local_B_block, n, + n_blocks, delta_n, k, k_blocks, delta_k, n_wi, m, + lhs_indexer, rhs_indexer, res_indexer)); + } + }); + return gemm_ev; + } + else { + using ReductionOpT = sycl::plus; + constexpr resTy identity_val = + 
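In the contiguous specialization above all three operands use NoOpIndexer, so no shape or stride metadata travels to the kernel and the flat work-item id doubles as the memory offset. A tiny sketch of what such an indexer amounts to, assuming (as its use here implies) that it forwards the id unchanged; the real NoOpIndexer is defined in utils/offset_utils.hpp:

#include <cstdint>

// Assumed behaviour of a no-op indexer: identity map from linear id to offset.
struct noop_indexer_sketch
{
    constexpr std::int64_t operator()(std::int64_t gid) const
    {
        return gid; // contiguous data: offset == linear index
    }
};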
sycl::known_identity::value; - auto ndRange = sycl::nd_range<1>(gRange, lRange); + size_t iter_nelems = n * m; + size_t reduction_nelems = (k + delta_k * n_wi - 1) / (delta_k * n_wi); - using LocAccT = sycl::local_accessor; - LocAccT local_B_block(n_wi * delta_k, cgh); - LocAccT workspace(delta_n * delta_k, cgh); - using KernelName = class gemm_tree_k_krn< - lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, - OuterInnerDimsIndexerT, m_groups>; - cgh.parallel_for( - ndRange, - GemmNoAtomicFunctorThreadK< - lhsTy, rhsTy, resTy, LocAccT, - OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, - m_groups>(lhs_tp, rhs_tp, res_tp, workspace, - local_B_block, n, n_blocks, delta_n, - k, k_blocks, delta_k, n_wi, m, - lhs_indexer, rhs_indexer, - res_indexer)); - }); - return gemm_ev; - } - else { - using ReductionOpT = sycl::plus; - constexpr resTy identity_val = - sycl::known_identity::value; + // more than one work-groups is needed, requires a + // temporary delta_k * n_wi elements processed along k, + // so if more to process use multiple + const auto &sg_sizes = + dev.get_info(); + size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes); - size_t iter_nelems = n * m; - size_t reduction_nelems = - (k + delta_k * n_wi - 1) / (delta_k * n_wi); + if (reduction_nelems < wg) { + resTy *tmp = sycl::malloc_device( + iter_nelems * reduction_nelems, exec_q); + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); - // more than one work-groups is needed, requires a - // temporary delta_k * n_wi elements processed along k, - // so if more to process use multiple - const auto &sg_sizes = - dev.get_info(); - size_t wg = - choose_workgroup_size<4>(reduction_nelems, sg_sizes); + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + OuterInnerDimsIndexerT lhs_indexer{}; + OuterInnerDimsIndexerT rhs_indexer{}; + OuterInnerDimsIndexerT res_indexer{}; - if (reduction_nelems < wg) { - resTy *tmp = sycl::malloc_device( - iter_nelems * reduction_nelems, exec_q); - sycl::event gemm_ev = - exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(depends); - - using OuterInnerDimsIndexerT = - dpctl::tensor::offset_utils::NoOpIndexer; - OuterInnerDimsIndexerT lhs_indexer{}; - OuterInnerDimsIndexerT rhs_indexer{}; - OuterInnerDimsIndexerT res_indexer{}; - - size_t n_blocks = (n + delta_n - 1) / delta_n; - size_t k_blocks = - (k + n_wi * delta_k - 1) / (n_wi * delta_k); - size_t m_blocks = (m + m_groups - 1) / m_groups; - - size_t lws = delta_n * delta_k; - - auto gRange = sycl::range<1>( - n_blocks * m_blocks * k_blocks * lws); - auto lRange = sycl::range<1>(lws); - - auto ndRange = - sycl::nd_range<1>(gRange, lRange); - - using LocAccT = sycl::local_accessor; - LocAccT local_B_block(n_wi * delta_k, cgh); - LocAccT workspace(delta_n * delta_k, cgh); - using KernelName = class gemm_tree_k_krn< - lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, - OuterInnerDimsIndexerT, m_groups>; - cgh.parallel_for( - ndRange, - GemmNoAtomicFunctorThreadK< - lhsTy, rhsTy, resTy, LocAccT, - OuterInnerDimsIndexerT, - OuterInnerDimsIndexerT, m_groups>( - lhs_tp, rhs_tp, tmp, workspace, - local_B_block, n, n_blocks, delta_n, k, - k_blocks, delta_k, n_wi, m, lhs_indexer, - rhs_indexer, res_indexer)); - }); - sycl::event red_ev = exec_q.submit([&](sycl::handler - &cgh) { - cgh.depends_on(gemm_ev); + size_t n_blocks = (n + delta_n - 1) / delta_n; + size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); + size_t m_blocks = (m + m_groups - 1) / m_groups; - using NoOpIndexerT = - 
dpctl::tensor::offset_utils::NoOpIndexer; - using InputOutputIterIndexerT = - dpctl::tensor::offset_utils:: - TwoOffsets_CombinedIndexer; - using ReductionIndexerT = - dpctl::tensor::offset_utils::Strided1DIndexer; + size_t lws = delta_n * delta_k; - InputOutputIterIndexerT in_out_iter_indexer{ - NoOpIndexerT{}, NoOpIndexerT{}}; - ReductionIndexerT reduction_indexer{ - 0, static_cast(reduction_nelems), - static_cast(iter_nelems)}; + auto gRange = + sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); + auto lRange = sycl::range<1>(lws); - sycl::range<1> iter_range{iter_nelems}; + auto ndRange = sycl::nd_range<1>(gRange, lRange); + if constexpr (m_groups == 1) { + using LocAccT = sycl::local_accessor; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + using KernelName = + class gemm_tree_k_krn; + cgh.parallel_for( + ndRange, + GemmNoAtomicFunctorThreadK( + lhs_tp, rhs_tp, tmp, workspace, local_B_block, n, + n_blocks, delta_n, k, k_blocks, delta_k, n_wi, m, + lhs_indexer, rhs_indexer, res_indexer)); + } + else { + using LocAccT = + sycl::local_accessor, 1>; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + using KernelName = + class gemm_tree_k_krn; + cgh.parallel_for( + ndRange, + GemmNoAtomicFunctorThreadK( + lhs_tp, rhs_tp, tmp, workspace, local_B_block, n, + n_blocks, delta_n, k, k_blocks, delta_k, n_wi, m, + lhs_indexer, rhs_indexer, res_indexer)); + } + }); + sycl::event red_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(gemm_ev); - cgh.parallel_for< - class gemm_reduction_seq_strided_krn< - resTy, resTy, ReductionOpT, - InputOutputIterIndexerT, - ReductionIndexerT>>( - iter_range, - SequentialReduction; + using ReductionIndexerT = + dpctl::tensor::offset_utils::Strided1DIndexer; + + InputOutputIterIndexerT in_out_iter_indexer{NoOpIndexerT{}, + NoOpIndexerT{}}; + ReductionIndexerT reduction_indexer{ + 0, static_cast(reduction_nelems), + static_cast(iter_nelems)}; + + sycl::range<1> iter_range{iter_nelems}; + + cgh.parallel_for>( + iter_range, SequentialReduction( tmp, res_tp, ReductionOpT(), identity_val, in_out_iter_indexer, reduction_indexer, reduction_nelems)); - }); - sycl::event cleanup_host_task_event = - exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(red_ev); - const sycl::context &ctx = exec_q.get_context(); + }); + sycl::event cleanup_host_task_event = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(red_ev); + const sycl::context &ctx = exec_q.get_context(); - cgh.host_task( - [ctx, tmp] { sycl::free(tmp, ctx); }); - }); - return cleanup_host_task_event; - } - // max_max_wg prevents running out of resources on CPU - constexpr size_t max_max_wg = 2048; - size_t max_wg = std::min( - max_max_wg, - dev.get_info< - sycl::info::device::max_work_group_size>() / - 2); - - constexpr size_t preferred_reductions_per_wi = 8; - size_t reductions_per_wi(preferred_reductions_per_wi); - - size_t reduction_groups = - (reduction_nelems + preferred_reductions_per_wi * wg - - 1) / - (preferred_reductions_per_wi * wg); - assert(reduction_groups > 1); - - resTy *partially_reduced_tmp = sycl::malloc_device( - iter_nelems * - (/* temp */ reduction_nelems + - /* first reduction temp */ reduction_groups), - exec_q); - resTy *partially_reduced_tmp2 = nullptr; + cgh.host_task([ctx, tmp] { sycl::free(tmp, ctx); }); + }); + return cleanup_host_task_event; + } + // max_max_wg prevents running out of resources on CPU + constexpr size_t max_max_wg = 2048; + size_t max_wg = std::min( + 
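The sequential fallback above walks the partial results with a Strided1DIndexer built from {0, reduction_nelems, iter_nelems}; assuming those arguments are offset, extent and stride, as their values suggest, the j-th partial of output element i sits at tmp[i + j * iter_nelems]. A one-line sketch of that addressing:

#include <cstddef>

// Assumed address of the j-th partial result for output element i in the
// temporary written by the gemm kernel above.
constexpr std::size_t
partial_offset(std::size_t i, std::size_t j, std::size_t iter_nelems)
{
    return i + j * iter_nelems;
}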
max_max_wg, + dev.get_info() / 2); - if (partially_reduced_tmp == nullptr) { - throw std::runtime_error( - "Unable to allocate device_memory"); - } - else { - partially_reduced_tmp2 = partially_reduced_tmp + - reduction_nelems * iter_nelems; - } + constexpr size_t preferred_reductions_per_wi = 8; + size_t reductions_per_wi(preferred_reductions_per_wi); - sycl::event gemm_ev = - exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(depends); + size_t reduction_groups = + (reduction_nelems + preferred_reductions_per_wi * wg - 1) / + (preferred_reductions_per_wi * wg); + assert(reduction_groups > 1); - using OuterInnerDimsIndexerT = - dpctl::tensor::offset_utils::NoOpIndexer; - OuterInnerDimsIndexerT lhs_indexer{}; - OuterInnerDimsIndexerT rhs_indexer{}; - OuterInnerDimsIndexerT res_indexer{}; + resTy *partially_reduced_tmp = sycl::malloc_device( + iter_nelems * (/* temp */ reduction_nelems + + /* first reduction temp */ reduction_groups), + exec_q); + resTy *partially_reduced_tmp2 = nullptr; - size_t n_blocks = (n + delta_n - 1) / delta_n; - size_t k_blocks = - (k + n_wi * delta_k - 1) / (n_wi * delta_k); - size_t m_blocks = (m + m_groups - 1) / m_groups; + if (partially_reduced_tmp == nullptr) { + throw std::runtime_error("Unable to allocate device_memory"); + } + else { + partially_reduced_tmp2 = + partially_reduced_tmp + reduction_nelems * iter_nelems; + } - size_t lws = delta_n * delta_k; + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); - auto gRange = sycl::range<1>(n_blocks * m_blocks * - k_blocks * lws); - auto lRange = sycl::range<1>(lws); + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + OuterInnerDimsIndexerT lhs_indexer{}; + OuterInnerDimsIndexerT rhs_indexer{}; + OuterInnerDimsIndexerT res_indexer{}; - auto ndRange = sycl::nd_range<1>(gRange, lRange); + size_t n_blocks = (n + delta_n - 1) / delta_n; + size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); + size_t m_blocks = (m + m_groups - 1) / m_groups; - using LocAccT = sycl::local_accessor; - LocAccT local_B_block(n_wi * delta_k, cgh); - LocAccT workspace(delta_n * delta_k, cgh); - using KernelName = class gemm_tree_k_krn< - lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, - OuterInnerDimsIndexerT, m_groups>; - cgh.parallel_for( - ndRange, - GemmNoAtomicFunctorThreadK< - lhsTy, rhsTy, resTy, LocAccT, - OuterInnerDimsIndexerT, - OuterInnerDimsIndexerT, m_groups>( - lhs_tp, rhs_tp, partially_reduced_tmp, - workspace, local_B_block, n, n_blocks, - delta_n, k, k_blocks, delta_k, n_wi, m, - lhs_indexer, rhs_indexer, res_indexer)); - }); - // tree_reduction_for_gemm_contig returns sycl::event - // for reduction - sycl::event red_ev = - tree_reduction_for_gemm_contig( - exec_q, partially_reduced_tmp, - partially_reduced_tmp2, res_tp, identity_val, - iter_nelems, reduction_nelems, reduction_groups, wg, - max_wg, preferred_reductions_per_wi, - reductions_per_wi, {gemm_ev}); + size_t lws = delta_n * delta_k; - sycl::event cleanup_host_task_event = - exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(red_ev); - const sycl::context &ctx = exec_q.get_context(); + auto gRange = sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); + auto lRange = sycl::range<1>(lws); - cgh.host_task([ctx, partially_reduced_tmp] { - sycl::free(partially_reduced_tmp, ctx); - }); - }); + auto ndRange = sycl::nd_range<1>(gRange, lRange); - return cleanup_host_task_event; - } + if constexpr (m_groups == 1) { + using LocAccT = sycl::local_accessor; + LocAccT local_B_block(n_wi * 
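The first reduction level above sizes its group count by letting every work-item fold preferred_reductions_per_wi partial values before the in-group reduction, while max_wg caps the group size so CPU devices do not exhaust resources. A worked example with assumed sizes:

#include <cstddef>

// Illustrative values only: 4096 partials per output element, work-group
// of 256, 8 values folded per work-item.
constexpr std::size_t reduction_nelems = 4096;
constexpr std::size_t wg = 256;
constexpr std::size_t preferred_reductions_per_wi = 8;

constexpr std::size_t reduction_groups =
    (reduction_nelems + preferred_reductions_per_wi * wg - 1) /
    (preferred_reductions_per_wi * wg); // ceil(4096 / 2048) == 2

static_assert(reduction_groups == 2,
              "two first-level groups feed the next tree level");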
delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + using KernelName = + class gemm_tree_k_krn; + cgh.parallel_for( + ndRange, + GemmNoAtomicFunctorThreadK< + lhsTy, rhsTy, resTy, LocAccT, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, m_groups>( + lhs_tp, rhs_tp, partially_reduced_tmp, workspace, + local_B_block, n, n_blocks, delta_n, k, k_blocks, + delta_k, n_wi, m, lhs_indexer, rhs_indexer, + res_indexer)); } else { - constexpr int m_groups = 2; - gemm_detail::scale_gemm_k_parameters( - local_mem_size, reserved_slm_size, delta_k, - n_wi, // modified by reference - delta_n // modified by reference - ); + using LocAccT = + sycl::local_accessor, 1>; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + using KernelName = + class gemm_tree_k_krn; + cgh.parallel_for( + ndRange, + GemmNoAtomicFunctorThreadK< + lhsTy, rhsTy, resTy, LocAccT, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, m_groups>( + lhs_tp, rhs_tp, partially_reduced_tmp, workspace, + local_B_block, n, n_blocks, delta_n, k, k_blocks, + delta_k, n_wi, m, lhs_indexer, rhs_indexer, + res_indexer)); + } + }); + // tree_reduction_for_gemm_contig returns sycl::event + // for reduction + sycl::event red_ev = + tree_reduction_for_gemm_contig( + exec_q, partially_reduced_tmp, partially_reduced_tmp2, res_tp, + identity_val, iter_nelems, reduction_nelems, reduction_groups, + wg, max_wg, preferred_reductions_per_wi, reductions_per_wi, + {gemm_ev}); - sycl::event gemm_ev; - if (k <= (delta_k * n_wi)) { - gemm_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(depends); + sycl::event cleanup_host_task_event = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(red_ev); + const sycl::context &ctx = exec_q.get_context(); - using OuterInnerDimsIndexerT = - dpctl::tensor::offset_utils::NoOpIndexer; - OuterInnerDimsIndexerT lhs_indexer{}; - OuterInnerDimsIndexerT rhs_indexer{}; - OuterInnerDimsIndexerT res_indexer{}; + cgh.host_task([ctx, partially_reduced_tmp] { + sycl::free(partially_reduced_tmp, ctx); + }); + }); - size_t n_blocks = (n + delta_n - 1) / delta_n; - size_t k_blocks = - (k + n_wi * delta_k - 1) / (n_wi * delta_k); - size_t m_blocks = (m + m_groups - 1) / m_groups; + return cleanup_host_task_event; + } +} - size_t lws = delta_n * delta_k; +template +sycl::event gemm_contig_tree_nm_impl(sycl::queue &exec_q, + const lhsTy *lhs_tp, + const rhsTy *rhs_tp, + resTy *res_tp, + size_t n, + size_t k, + size_t m, + std::vector const &depends) +{ + constexpr int wi_delta_n = 2; + size_t wg_delta_n(16); // rows of A processed in WG + size_t wg_delta_m(16); // rows of B processed in WG + size_t wi_delta_k(64); // Elements in K dimension processed by WI - auto gRange = sycl::range<1>(n_blocks * m_blocks * - k_blocks * lws); - auto lRange = sycl::range<1>(lws); + const sycl::device &dev = exec_q.get_device(); + const size_t local_mem_size = + dev.get_info(); + const size_t reserved_slm_size = 512; - auto ndRange = sycl::nd_range<1>(gRange, lRange); + gemm_detail::scale_gemm_nm_parameters( + local_mem_size, reserved_slm_size, wi_delta_n, + wi_delta_k, // modified by reference + wg_delta_n, // modified by reference + wg_delta_m // modified by reference + ); - using LocAccT = - sycl::local_accessor, 1>; - LocAccT local_B_block(n_wi * delta_k, cgh); - LocAccT workspace(delta_n * delta_k, cgh); - using KernelName = class gemm_tree_k_krn< - lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, - OuterInnerDimsIndexerT, m_groups>; - cgh.parallel_for( - ndRange, - GemmNoAtomicFunctorThreadK< 
- lhsTy, rhsTy, resTy, LocAccT, - OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, - m_groups>(lhs_tp, rhs_tp, res_tp, workspace, - local_B_block, n, n_blocks, delta_n, - k, k_blocks, delta_k, n_wi, m, - lhs_indexer, rhs_indexer, - res_indexer)); - }); - return gemm_ev; - } - else { - using ReductionOpT = sycl::plus; - constexpr resTy identity_val = - sycl::known_identity::value; + // each group processes delta_k items in a column, + // so no need to allocate temp memory if one group needed + if (k <= wi_delta_k) { + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); - size_t iter_nelems = n * m; - size_t reduction_nelems = - (k + delta_k * n_wi - 1) / (delta_k * n_wi); + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + OuterInnerDimsIndexerT lhs_indexer{}; + OuterInnerDimsIndexerT rhs_indexer{}; + OuterInnerDimsIndexerT res_indexer{}; - // more than one work-groups is needed, requires a - // temporary delta_k * n_wi elements processed along k, - // so if more to process use multiple - const auto &sg_sizes = - dev.get_info(); - size_t wg = - choose_workgroup_size<4>(reduction_nelems, sg_sizes); + size_t lws = wg_delta_n * wg_delta_m; - if (reduction_nelems < wg) { - resTy *tmp = sycl::malloc_device( - iter_nelems * reduction_nelems, exec_q); - sycl::event gemm_ev = - exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(depends); - - using OuterInnerDimsIndexerT = - dpctl::tensor::offset_utils::NoOpIndexer; - OuterInnerDimsIndexerT lhs_indexer{}; - OuterInnerDimsIndexerT rhs_indexer{}; - OuterInnerDimsIndexerT res_indexer{}; - - size_t n_blocks = (n + delta_n - 1) / delta_n; - size_t k_blocks = - (k + n_wi * delta_k - 1) / (n_wi * delta_k); - size_t m_blocks = (m + m_groups - 1) / m_groups; - - size_t lws = delta_n * delta_k; - - auto gRange = sycl::range<1>( - n_blocks * m_blocks * k_blocks * lws); - auto lRange = sycl::range<1>(lws); - - auto ndRange = - sycl::nd_range<1>(gRange, lRange); - - using LocAccT = sycl::local_accessor< - sycl::vec, 1>; - LocAccT local_B_block(n_wi * delta_k, cgh); - LocAccT workspace(delta_n * delta_k, cgh); - using KernelName = class gemm_tree_k_krn< - lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, - OuterInnerDimsIndexerT, m_groups>; - cgh.parallel_for( - ndRange, - GemmNoAtomicFunctorThreadK< - lhsTy, rhsTy, resTy, LocAccT, - OuterInnerDimsIndexerT, - OuterInnerDimsIndexerT, m_groups>( - lhs_tp, rhs_tp, tmp, workspace, - local_B_block, n, n_blocks, delta_n, k, - k_blocks, delta_k, n_wi, m, lhs_indexer, - rhs_indexer, res_indexer)); - }); - sycl::event red_ev = exec_q.submit([&](sycl::handler - &cgh) { - cgh.depends_on(gemm_ev); + size_t n_blocks = + ((n + wi_delta_n * wg_delta_n - 1) / (wi_delta_n * wg_delta_n)); + size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); + size_t m_blocks = + ((m + wi_delta_m * wg_delta_m - 1) / (wi_delta_m * wg_delta_m)); - using NoOpIndexerT = - dpctl::tensor::offset_utils::NoOpIndexer; - using InputOutputIterIndexerT = - dpctl::tensor::offset_utils:: - TwoOffsets_CombinedIndexer; - using ReductionIndexerT = - dpctl::tensor::offset_utils::Strided1DIndexer; + auto gwsRange = + sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); + auto lwsRange = sycl::range<1>(lws); - InputOutputIterIndexerT in_out_iter_indexer{ - NoOpIndexerT{}, NoOpIndexerT{}}; - ReductionIndexerT reduction_indexer{ - 0, static_cast(reduction_nelems), - static_cast(iter_nelems)}; + auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); - sycl::range<1> iter_range{iter_nelems}; + if 
constexpr (wi_delta_m == 1) { + using LocAccT1 = sycl::local_accessor; + LocAccT1 local_A_block( + sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), + cgh); + using LocAccT2 = sycl::local_accessor; + LocAccT2 local_B_block(wi_delta_k * wg_delta_m, cgh); - cgh.parallel_for< - class gemm_reduction_seq_strided_krn< - resTy, resTy, ReductionOpT, - InputOutputIterIndexerT, - ReductionIndexerT>>( - iter_range, - SequentialReduction( - tmp, res_tp, ReductionOpT(), identity_val, - in_out_iter_indexer, reduction_indexer, - reduction_nelems)); - }); - sycl::event cleanup_host_task_event = - exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(red_ev); - const sycl::context &ctx = exec_q.get_context(); - - cgh.host_task( - [ctx, tmp] { sycl::free(tmp, ctx); }); - }); - return cleanup_host_task_event; - } - // max_max_wg prevents running out of resources on CPU - constexpr size_t max_max_wg = 2048; - size_t max_wg = std::min( - max_max_wg, - dev.get_info< - sycl::info::device::max_work_group_size>() / - 2); - - constexpr size_t preferred_reductions_per_wi = 8; - size_t reductions_per_wi(preferred_reductions_per_wi); - - size_t reduction_groups = - (reduction_nelems + preferred_reductions_per_wi * wg - - 1) / - (preferred_reductions_per_wi * wg); - assert(reduction_groups > 1); - - resTy *partially_reduced_tmp = sycl::malloc_device( - iter_nelems * - (/* temp */ reduction_nelems + - /* first reduction temp */ reduction_groups), - exec_q); - resTy *partially_reduced_tmp2 = nullptr; - - if (partially_reduced_tmp == nullptr) { - throw std::runtime_error( - "Unable to allocate device_memory"); - } - else { - partially_reduced_tmp2 = partially_reduced_tmp + - reduction_nelems * iter_nelems; - } - - sycl::event gemm_ev = exec_q.submit([&](sycl::handler - &cgh) { - cgh.depends_on(depends); - - using OuterInnerDimsIndexerT = - dpctl::tensor::offset_utils::NoOpIndexer; - OuterInnerDimsIndexerT lhs_indexer{}; - OuterInnerDimsIndexerT rhs_indexer{}; - OuterInnerDimsIndexerT res_indexer{}; - - size_t n_blocks = (n + delta_n - 1) / delta_n; - size_t k_blocks = - (k + n_wi * delta_k - 1) / (n_wi * delta_k); - size_t m_blocks = (m + m_groups - 1) / m_groups; - - size_t lws = delta_n * delta_k; - - auto gRange = sycl::range<1>(n_blocks * m_blocks * - k_blocks * lws); - auto lRange = sycl::range<1>(lws); - - auto ndRange = sycl::nd_range<1>(gRange, lRange); - - using LocAccT = - sycl::local_accessor, 1>; - LocAccT local_B_block(n_wi * delta_k, cgh); - LocAccT workspace(delta_n * delta_k, cgh); - using KernelName = class gemm_tree_k_krn< - lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, - OuterInnerDimsIndexerT, m_groups>; - cgh.parallel_for( - ndRange, - GemmNoAtomicFunctorThreadK< - lhsTy, rhsTy, resTy, LocAccT, - OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, - m_groups>(lhs_tp, rhs_tp, partially_reduced_tmp, - workspace, local_B_block, n, n_blocks, - delta_n, k, k_blocks, delta_k, n_wi, - m, lhs_indexer, rhs_indexer, - res_indexer)); - }); - - sycl::event red_ev = - tree_reduction_for_gemm_contig( - exec_q, partially_reduced_tmp, - partially_reduced_tmp2, res_tp, identity_val, - iter_nelems, reduction_nelems, reduction_groups, wg, - max_wg, preferred_reductions_per_wi, - reductions_per_wi, {gemm_ev}); - - sycl::event cleanup_host_task_event = - exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(red_ev); - const sycl::context &ctx = exec_q.get_context(); - - cgh.host_task([ctx, partially_reduced_tmp] { - sycl::free(partially_reduced_tmp, ctx); - }); - }); - - return cleanup_host_task_event; - } - } - } 
- else { - constexpr int m_groups = 1; - - gemm_detail::scale_gemm_k_parameters( - local_mem_size, reserved_slm_size, delta_k, - n_wi, // modified by reference - delta_n // modified by reference - ); - - sycl::event gemm_ev; - if (k <= (delta_k * n_wi)) { - gemm_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(depends); - - using OuterInnerDimsIndexerT = - dpctl::tensor::offset_utils::NoOpIndexer; - OuterInnerDimsIndexerT lhs_indexer{}; - OuterInnerDimsIndexerT rhs_indexer{}; - OuterInnerDimsIndexerT res_indexer{}; - - size_t n_blocks = (n + delta_n - 1) / delta_n; - size_t k_blocks = - (k + n_wi * delta_k - 1) / (n_wi * delta_k); - size_t m_blocks = (m + m_groups - 1) / m_groups; - - size_t lws = delta_n * delta_k; - - auto gRange = - sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); - auto lRange = sycl::range<1>(lws); - - auto ndRange = sycl::nd_range<1>(gRange, lRange); - - using LocAccT = sycl::local_accessor; - LocAccT local_B_block(n_wi * delta_k, cgh); - LocAccT workspace(delta_n * delta_k, cgh); - using KernelName = - class gemm_tree_k_krn; - cgh.parallel_for( - ndRange, - GemmNoAtomicFunctorThreadK( - lhs_tp, rhs_tp, res_tp, workspace, local_B_block, n, - n_blocks, delta_n, k, k_blocks, delta_k, n_wi, m, - lhs_indexer, rhs_indexer, res_indexer)); - }); - return gemm_ev; - } - else { - using ReductionOpT = sycl::plus; - constexpr resTy identity_val = - sycl::known_identity::value; - - size_t iter_nelems = n * m; - size_t reduction_nelems = - (k + delta_k * n_wi - 1) / (delta_k * n_wi); - - // more than one work-groups is needed, requires a temporary - // delta_k * n_wi elements processed along k, so if more to - // process use multiple - const auto &sg_sizes = - dev.get_info(); - size_t wg = - choose_workgroup_size<4>(reduction_nelems, sg_sizes); - - if (reduction_nelems < wg) { - resTy *tmp = sycl::malloc_device( - iter_nelems * reduction_nelems, exec_q); - sycl::event gemm_ev = - exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(depends); - - using OuterInnerDimsIndexerT = - dpctl::tensor::offset_utils::NoOpIndexer; - OuterInnerDimsIndexerT lhs_indexer{}; - OuterInnerDimsIndexerT rhs_indexer{}; - OuterInnerDimsIndexerT res_indexer{}; - - size_t n_blocks = (n + delta_n - 1) / delta_n; - size_t k_blocks = - (k + n_wi * delta_k - 1) / (n_wi * delta_k); - size_t m_blocks = (m + m_groups - 1) / m_groups; - - size_t lws = delta_n * delta_k; - - auto gRange = sycl::range<1>(n_blocks * m_blocks * - k_blocks * lws); - auto lRange = sycl::range<1>(lws); - - auto ndRange = sycl::nd_range<1>(gRange, lRange); - - using LocAccT = sycl::local_accessor; - LocAccT local_B_block(n_wi * delta_k, cgh); - LocAccT workspace(delta_n * delta_k, cgh); - using KernelName = class gemm_tree_k_krn< - lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, - OuterInnerDimsIndexerT, m_groups>; - cgh.parallel_for( - ndRange, - GemmNoAtomicFunctorThreadK< - lhsTy, rhsTy, resTy, LocAccT, - OuterInnerDimsIndexerT, - OuterInnerDimsIndexerT, m_groups>( - lhs_tp, rhs_tp, tmp, workspace, - local_B_block, n, n_blocks, delta_n, k, - k_blocks, delta_k, n_wi, m, lhs_indexer, - rhs_indexer, res_indexer)); - }); - - sycl::event red_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(gemm_ev); - - using NoOpIndexerT = - dpctl::tensor::offset_utils::NoOpIndexer; - using InputOutputIterIndexerT = - dpctl::tensor::offset_utils:: - TwoOffsets_CombinedIndexer; - using ReductionIndexerT = - dpctl::tensor::offset_utils::Strided1DIndexer; - - InputOutputIterIndexerT in_out_iter_indexer{ - NoOpIndexerT{}, 
NoOpIndexerT{}}; - ReductionIndexerT reduction_indexer{ - 0, static_cast(reduction_nelems), - static_cast(iter_nelems)}; - - sycl::range<1> iter_range{iter_nelems}; - - cgh.parallel_for>( - iter_range, - SequentialReduction( - tmp, res_tp, ReductionOpT(), identity_val, - in_out_iter_indexer, reduction_indexer, - reduction_nelems)); - }); - sycl::event cleanup_host_task_event = - exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(red_ev); - const sycl::context &ctx = exec_q.get_context(); - - cgh.host_task([ctx, tmp] { sycl::free(tmp, ctx); }); - }); - return cleanup_host_task_event; - } - // max_max_wg prevents running out of resources on CPU - constexpr size_t max_max_wg = 2048; - size_t max_wg = std::min( - max_max_wg, - dev.get_info() / - 2); - - constexpr size_t preferred_reductions_per_wi = 8; - size_t reductions_per_wi(preferred_reductions_per_wi); - - size_t reduction_groups = - (reduction_nelems + preferred_reductions_per_wi * wg - 1) / - (preferred_reductions_per_wi * wg); - assert(reduction_groups > 1); - - resTy *partially_reduced_tmp = sycl::malloc_device( - iter_nelems * (/* temp */ reduction_nelems + - /* first reduction temp */ reduction_groups), - exec_q); - resTy *partially_reduced_tmp2 = nullptr; - - if (partially_reduced_tmp == nullptr) { - throw std::runtime_error( - "Unable to allocate device_memory"); - } - else { - partially_reduced_tmp2 = - partially_reduced_tmp + reduction_nelems * iter_nelems; - } - - sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(depends); - - using OuterInnerDimsIndexerT = - dpctl::tensor::offset_utils::NoOpIndexer; - OuterInnerDimsIndexerT lhs_indexer{}; - OuterInnerDimsIndexerT rhs_indexer{}; - OuterInnerDimsIndexerT res_indexer{}; - - size_t n_blocks = (n + delta_n - 1) / delta_n; - size_t k_blocks = - (k + n_wi * delta_k - 1) / (n_wi * delta_k); - size_t m_blocks = (m + m_groups - 1) / m_groups; - - size_t lws = delta_n * delta_k; - - auto gRange = - sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); - auto lRange = sycl::range<1>(lws); - - auto ndRange = sycl::nd_range<1>(gRange, lRange); - - using LocAccT = sycl::local_accessor; - LocAccT local_B_block(n_wi * delta_k, cgh); - LocAccT workspace(delta_n * delta_k, cgh); - using KernelName = - class gemm_tree_k_krn; - cgh.parallel_for( - ndRange, - GemmNoAtomicFunctorThreadK( - lhs_tp, rhs_tp, partially_reduced_tmp, workspace, - local_B_block, n, n_blocks, delta_n, k, k_blocks, - delta_k, n_wi, m, lhs_indexer, rhs_indexer, - res_indexer)); - }); - // tree_reduction_for_gemm returns sycl::event for reduction - sycl::event red_ev = - tree_reduction_for_gemm_contig( - exec_q, partially_reduced_tmp, partially_reduced_tmp2, - res_tp, identity_val, iter_nelems, reduction_nelems, - reduction_groups, wg, max_wg, - preferred_reductions_per_wi, reductions_per_wi, - {gemm_ev}); - - sycl::event cleanup_host_task_event = - exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(red_ev); - const sycl::context &ctx = exec_q.get_context(); - - cgh.host_task([ctx, partially_reduced_tmp] { - sycl::free(partially_reduced_tmp, ctx); - }); - }); - - return cleanup_host_task_event; - } - } - } - else { // m > 1, n > k or m > k - constexpr int wi_delta_n = 2; - size_t wg_delta_n(16); // rows of A processed in WG - size_t wg_delta_m(16); // rows of B processed in WG - size_t wi_delta_k(64); // Elements in K dimension processed by WI - - using dpctl::tensor::type_utils::is_complex; - if constexpr (!is_complex::value) { - constexpr int wi_delta_m = 4; - - 
gemm_detail::scale_gemm_nm_parameters( - local_mem_size, reserved_slm_size, wi_delta_n, - wi_delta_k, // modified by reference - wg_delta_n, // modified by reference - wg_delta_m // modified by reference - ); - - // each group processes delta_k items in a column, - // so no need to allocate temp memory if one group needed - if (k <= wi_delta_k) { - sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(depends); - - using OuterInnerDimsIndexerT = - dpctl::tensor::offset_utils::NoOpIndexer; - OuterInnerDimsIndexerT lhs_indexer{}; - OuterInnerDimsIndexerT rhs_indexer{}; - OuterInnerDimsIndexerT res_indexer{}; - - size_t lws = wg_delta_n * wg_delta_m; - - size_t n_blocks = ((n + wi_delta_n * wg_delta_n - 1) / - (wi_delta_n * wg_delta_n)); - size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); - size_t m_blocks = ((m + wi_delta_m * wg_delta_m - 1) / - (wi_delta_m * wg_delta_m)); - - auto gwsRange = - sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); - auto lwsRange = sycl::range<1>(lws); - - auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); - - using LocAccT1 = sycl::local_accessor; - LocAccT1 local_A_block( - sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), - cgh); - using LocAccT2 = - sycl::local_accessor, 1>; - LocAccT2 local_B_block( - sycl::range<1>(wi_delta_k * wg_delta_m), cgh); - - using KernelName = class gemm_tree_nm_krn< - lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, - OuterInnerDimsIndexerT, wi_delta_m>; - cgh.parallel_for( - ndRange, - GemmNoAtomicFunctorThreadNM< - lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, - OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, - wi_delta_n, wi_delta_m>( - lhs_tp, rhs_tp, res_tp, local_A_block, - local_B_block, n, wg_delta_n, k, k_blocks, - wi_delta_k, m, m_blocks, wg_delta_m, lhs_indexer, - rhs_indexer, res_indexer)); - }); - return gemm_ev; - } - else { - using ReductionOpT = sycl::plus; - constexpr resTy identity_val = - sycl::known_identity::value; - - size_t iter_nelems = n * m; - size_t reduction_nelems = (k + wi_delta_k - 1) / wi_delta_k; - - // more than one work-groups is needed, requires a temporary - // wi_delta_k elements processed along k, so if more to - // process use multiple - const auto &sg_sizes = - dev.get_info(); - size_t wg = - choose_workgroup_size<4>(reduction_nelems, sg_sizes); - - if (reduction_nelems < wg) { - resTy *tmp = sycl::malloc_device( - iter_nelems * reduction_nelems, exec_q); - sycl::event gemm_ev = exec_q.submit([&](sycl::handler - &cgh) { - cgh.depends_on(depends); - - using OuterInnerDimsIndexerT = - dpctl::tensor::offset_utils::NoOpIndexer; - OuterInnerDimsIndexerT lhs_indexer{}; - OuterInnerDimsIndexerT rhs_indexer{}; - OuterInnerDimsIndexerT res_indexer{}; - - size_t lws = wg_delta_n * wg_delta_m; - - size_t n_blocks = ((n + wi_delta_n * wg_delta_n - 1) / - (wi_delta_n * wg_delta_n)); - size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); - size_t m_blocks = ((m + wi_delta_m * wg_delta_m - 1) / - (wi_delta_m * wg_delta_m)); - - auto gwsRange = sycl::range<1>(n_blocks * m_blocks * - k_blocks * lws); - auto lwsRange = sycl::range<1>(lws); - - auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); - - using LocAccT1 = sycl::local_accessor; - LocAccT1 local_A_block( - sycl::range<1>((wi_delta_n * wg_delta_n) * - wi_delta_k), - cgh); - using LocAccT2 = - sycl::local_accessor, - 1>; - LocAccT2 local_B_block( - sycl::range<1>(wi_delta_k * wg_delta_m), cgh); - - using KernelName = class gemm_tree_nm_krn< - lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, - OuterInnerDimsIndexerT, 
wi_delta_m>; - cgh.parallel_for( - ndRange, - GemmNoAtomicFunctorThreadNM< - lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, - OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, - wi_delta_n, wi_delta_m>( - lhs_tp, rhs_tp, tmp, local_A_block, - local_B_block, n, wg_delta_n, k, k_blocks, - wi_delta_k, m, m_blocks, wg_delta_m, - lhs_indexer, rhs_indexer, res_indexer)); - }); - sycl::event red_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(gemm_ev); - - using NoOpIndexerT = - dpctl::tensor::offset_utils::NoOpIndexer; - using InputOutputIterIndexerT = - dpctl::tensor::offset_utils:: - TwoOffsets_CombinedIndexer; - using ReductionIndexerT = - dpctl::tensor::offset_utils::Strided1DIndexer; - - InputOutputIterIndexerT in_out_iter_indexer{ - NoOpIndexerT{}, NoOpIndexerT{}}; - ReductionIndexerT reduction_indexer{ - 0, static_cast(reduction_nelems), - static_cast(iter_nelems)}; - - sycl::range<1> iter_range{iter_nelems}; - - cgh.parallel_for>( - iter_range, - SequentialReduction( - tmp, res_tp, ReductionOpT(), identity_val, - in_out_iter_indexer, reduction_indexer, - reduction_nelems)); - }); - sycl::event cleanup_host_task_event = - exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(red_ev); - const sycl::context &ctx = exec_q.get_context(); - - cgh.host_task([ctx, tmp] { sycl::free(tmp, ctx); }); - }); - return cleanup_host_task_event; - } - - // max_max_wg prevents running out of resources on CPU - constexpr size_t max_max_wg = 2048; - size_t max_wg = std::min( - max_max_wg, - dev.get_info() / - 2); - - constexpr size_t preferred_reductions_per_wi = 8; - size_t reductions_per_wi(preferred_reductions_per_wi); - - size_t reduction_groups = - (reduction_nelems + preferred_reductions_per_wi * wg - 1) / - (preferred_reductions_per_wi * wg); - assert(reduction_groups > 1); - - resTy *partially_reduced_tmp = sycl::malloc_device( - iter_nelems * (/* temp */ reduction_nelems + - /* first reduction temp */ reduction_groups), - exec_q); - resTy *partially_reduced_tmp2 = nullptr; - - if (partially_reduced_tmp == nullptr) { - throw std::runtime_error( - "Unable to allocate device_memory"); - } - else { - partially_reduced_tmp2 = - partially_reduced_tmp + reduction_nelems * iter_nelems; - } - - sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(depends); - - using OuterInnerDimsIndexerT = - dpctl::tensor::offset_utils::NoOpIndexer; - OuterInnerDimsIndexerT lhs_indexer{}; - OuterInnerDimsIndexerT rhs_indexer{}; - OuterInnerDimsIndexerT res_indexer{}; - - size_t lws = wg_delta_n * wg_delta_m; - - size_t n_blocks = ((n + wi_delta_n * wg_delta_n - 1) / - (wi_delta_n * wg_delta_n)); - size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); - size_t m_blocks = ((m + wi_delta_m * wg_delta_m - 1) / - (wi_delta_m * wg_delta_m)); - - auto gwsRange = - sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); - auto lwsRange = sycl::range<1>(lws); - - auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); - - using LocAccT1 = sycl::local_accessor; - LocAccT1 local_A_block( - sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), - cgh); - using LocAccT2 = - sycl::local_accessor, 1>; - LocAccT2 local_B_block( - sycl::range<1>(wi_delta_k * wg_delta_m), cgh); - - using KernelName = class gemm_tree_nm_krn< - lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, - OuterInnerDimsIndexerT, wi_delta_m>; - cgh.parallel_for( - ndRange, - GemmNoAtomicFunctorThreadNM< - lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, - OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, - wi_delta_n, wi_delta_m>( - lhs_tp, rhs_tp, 
partially_reduced_tmp, - local_A_block, local_B_block, n, wg_delta_n, k, - k_blocks, wi_delta_k, m, m_blocks, wg_delta_m, - lhs_indexer, rhs_indexer, res_indexer)); - }); - - sycl::event red_ev = - tree_reduction_for_gemm_contig( - exec_q, partially_reduced_tmp, partially_reduced_tmp2, - res_tp, identity_val, iter_nelems, reduction_nelems, - reduction_groups, wg, max_wg, - preferred_reductions_per_wi, reductions_per_wi, - {gemm_ev}); - - sycl::event cleanup_host_task_event = - exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(red_ev); - const sycl::context &ctx = exec_q.get_context(); - - cgh.host_task([ctx, partially_reduced_tmp] { - sycl::free(partially_reduced_tmp, ctx); - }); - }); - - return cleanup_host_task_event; - } - } - else { - constexpr int wi_delta_m = 1; - - gemm_detail::scale_gemm_nm_parameters( - local_mem_size, reserved_slm_size, wi_delta_n, - wi_delta_k, // modified by reference - wg_delta_n, // modified by reference - wg_delta_m // modified by reference - ); - - // each group processes delta_k items in a column, - // so no need to allocate temp memory if one group needed - if (k <= wi_delta_k) { - sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(depends); - - using OuterInnerDimsIndexerT = - dpctl::tensor::offset_utils::NoOpIndexer; - OuterInnerDimsIndexerT lhs_indexer{}; - OuterInnerDimsIndexerT rhs_indexer{}; - OuterInnerDimsIndexerT res_indexer{}; - - size_t lws = wg_delta_n * wg_delta_m; - - size_t n_blocks = ((n + wi_delta_n * wg_delta_n - 1) / - (wi_delta_n * wg_delta_n)); - size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); - size_t m_blocks = ((m + wi_delta_m * wg_delta_m - 1) / - (wi_delta_m * wg_delta_m)); - - auto gwsRange = - sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); - auto lwsRange = sycl::range<1>(lws); - - auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); - - using LocAccT1 = sycl::local_accessor; - LocAccT1 local_A_block( - sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), - cgh); - using LocAccT2 = sycl::local_accessor; - LocAccT2 local_B_block(wi_delta_k * wg_delta_m, cgh); - - using KernelName = class gemm_tree_nm_krn< - lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, - OuterInnerDimsIndexerT, wi_delta_m>; - cgh.parallel_for( - ndRange, - GemmNoAtomicFunctorThreadNM< - lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, - OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, - wi_delta_n, wi_delta_m>( - lhs_tp, rhs_tp, res_tp, local_A_block, - local_B_block, n, wg_delta_n, k, k_blocks, - wi_delta_k, m, m_blocks, wg_delta_m, lhs_indexer, - rhs_indexer, res_indexer)); - }); - return gemm_ev; + using KernelName = + class gemm_tree_nm_krn; + cgh.parallel_for( + ndRange, + GemmNoAtomicFunctorThreadNM< + lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, + wi_delta_n, wi_delta_m>( + lhs_tp, rhs_tp, res_tp, local_A_block, local_B_block, n, + wg_delta_n, k, k_blocks, wi_delta_k, m, m_blocks, + wg_delta_m, lhs_indexer, rhs_indexer, res_indexer)); } else { - using ReductionOpT = sycl::plus; - constexpr resTy identity_val = - sycl::known_identity::value; - - size_t iter_nelems = n * m; - size_t reduction_nelems = (k + wi_delta_k - 1) / wi_delta_k; - - // more than one work-groups is needed, requires a temporary - // wi_delta_k elements processed along k, so if more to - // process use multiple - const auto &sg_sizes = - dev.get_info(); - size_t wg = - choose_workgroup_size<4>(reduction_nelems, sg_sizes); - - if (reduction_nelems < wg) { - resTy *tmp = sycl::malloc_device( - 
iter_nelems * reduction_nelems, exec_q); - sycl::event gemm_ev = exec_q.submit([&](sycl::handler - &cgh) { - cgh.depends_on(depends); - - using OuterInnerDimsIndexerT = - dpctl::tensor::offset_utils::NoOpIndexer; - OuterInnerDimsIndexerT lhs_indexer{}; - OuterInnerDimsIndexerT rhs_indexer{}; - OuterInnerDimsIndexerT res_indexer{}; - - size_t lws = wg_delta_n * wg_delta_m; - - size_t n_blocks = ((n + wi_delta_n * wg_delta_n - 1) / - (wi_delta_n * wg_delta_n)); - size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); - size_t m_blocks = ((m + wi_delta_m * wg_delta_m - 1) / - (wi_delta_m * wg_delta_m)); - - auto gwsRange = sycl::range<1>(n_blocks * m_blocks * - k_blocks * lws); - auto lwsRange = sycl::range<1>(lws); - - auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); - - using LocAccT1 = sycl::local_accessor; - LocAccT1 local_A_block( - sycl::range<1>((wi_delta_n * wg_delta_n) * - wi_delta_k), - cgh); - using LocAccT2 = sycl::local_accessor; - LocAccT2 local_B_block(wi_delta_k * wg_delta_m, cgh); - - using KernelName = class gemm_tree_nm_krn< - lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, - OuterInnerDimsIndexerT, wi_delta_m>; - cgh.parallel_for( - ndRange, - GemmNoAtomicFunctorThreadNM< - lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, - OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, - wi_delta_n, wi_delta_m>( - lhs_tp, rhs_tp, tmp, local_A_block, - local_B_block, n, wg_delta_n, k, k_blocks, - wi_delta_k, m, m_blocks, wg_delta_m, - lhs_indexer, rhs_indexer, res_indexer)); - }); - sycl::event red_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(gemm_ev); - - using NoOpIndexerT = - dpctl::tensor::offset_utils::NoOpIndexer; - using InputOutputIterIndexerT = - dpctl::tensor::offset_utils:: - TwoOffsets_CombinedIndexer; - using ReductionIndexerT = - dpctl::tensor::offset_utils::Strided1DIndexer; - - InputOutputIterIndexerT in_out_iter_indexer{ - NoOpIndexerT{}, NoOpIndexerT{}}; - ReductionIndexerT reduction_indexer{ - 0, static_cast(reduction_nelems), - static_cast(iter_nelems)}; - - sycl::range<1> iter_range{iter_nelems}; - - cgh.parallel_for>( - iter_range, - SequentialReduction( - tmp, res_tp, ReductionOpT(), identity_val, - in_out_iter_indexer, reduction_indexer, - reduction_nelems)); - }); - sycl::event cleanup_host_task_event = - exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(red_ev); - const sycl::context &ctx = exec_q.get_context(); - - cgh.host_task([ctx, tmp] { sycl::free(tmp, ctx); }); - }); - return cleanup_host_task_event; - } - - // max_max_wg prevents running out of resources on CPU - constexpr size_t max_max_wg = 2048; - size_t max_wg = std::min( - max_max_wg, - dev.get_info() / - 2); - - constexpr size_t preferred_reductions_per_wi = 8; - size_t reductions_per_wi(preferred_reductions_per_wi); + using LocAccT1 = sycl::local_accessor; + LocAccT1 local_A_block( + sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), + cgh); + using LocAccT2 = + sycl::local_accessor, 1>; + LocAccT2 local_B_block(sycl::range<1>(wi_delta_k * wg_delta_m), + cgh); - size_t reduction_groups = - (reduction_nelems + preferred_reductions_per_wi * wg - 1) / - (preferred_reductions_per_wi * wg); - assert(reduction_groups > 1); + using KernelName = + class gemm_tree_nm_krn; + cgh.parallel_for( + ndRange, + GemmNoAtomicFunctorThreadNM< + lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, + wi_delta_n, wi_delta_m>( + lhs_tp, rhs_tp, res_tp, local_A_block, local_B_block, n, + wg_delta_n, k, k_blocks, wi_delta_k, m, m_blocks, + wg_delta_m, 
lhs_indexer, rhs_indexer, res_indexer)); + } + }); + return gemm_ev; + } + else { + using ReductionOpT = sycl::plus; + constexpr resTy identity_val = + sycl::known_identity::value; - resTy *partially_reduced_tmp = sycl::malloc_device( - iter_nelems * (/* temp */ reduction_nelems + - /* first reduction temp */ reduction_groups), - exec_q); - resTy *partially_reduced_tmp2 = nullptr; + size_t iter_nelems = n * m; + size_t reduction_nelems = (k + wi_delta_k - 1) / wi_delta_k; - if (partially_reduced_tmp == nullptr) { - throw std::runtime_error( - "Unable to allocate device_memory"); - } - else { - partially_reduced_tmp2 = - partially_reduced_tmp + reduction_nelems * iter_nelems; - } + // more than one work-groups is needed, requires a temporary + // wi_delta_k elements processed along k, so if more to + // process use multiple + const auto &sg_sizes = + dev.get_info(); + size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes); - sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(depends); + if (reduction_nelems < wg) { + resTy *tmp = sycl::malloc_device( + iter_nelems * reduction_nelems, exec_q); + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); - using OuterInnerDimsIndexerT = - dpctl::tensor::offset_utils::NoOpIndexer; - OuterInnerDimsIndexerT lhs_indexer{}; - OuterInnerDimsIndexerT rhs_indexer{}; - OuterInnerDimsIndexerT res_indexer{}; + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + OuterInnerDimsIndexerT lhs_indexer{}; + OuterInnerDimsIndexerT rhs_indexer{}; + OuterInnerDimsIndexerT res_indexer{}; - size_t lws = wg_delta_n * wg_delta_m; + size_t lws = wg_delta_n * wg_delta_m; - size_t n_blocks = ((n + wi_delta_n * wg_delta_n - 1) / - (wi_delta_n * wg_delta_n)); - size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); - size_t m_blocks = ((m + wi_delta_m * wg_delta_m - 1) / - (wi_delta_m * wg_delta_m)); + size_t n_blocks = ((n + wi_delta_n * wg_delta_n - 1) / + (wi_delta_n * wg_delta_n)); + size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); + size_t m_blocks = ((m + wi_delta_m * wg_delta_m - 1) / + (wi_delta_m * wg_delta_m)); - auto gwsRange = - sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); - auto lwsRange = sycl::range<1>(lws); + auto gwsRange = + sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); + auto lwsRange = sycl::range<1>(lws); - auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); + auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); + if constexpr (wi_delta_m == 1) { using LocAccT1 = sycl::local_accessor; LocAccT1 local_A_block( sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), @@ -3983,33 +3158,237 @@ sycl::event gemm_contig_tree_impl(sycl::queue &exec_q, lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, wi_delta_n, wi_delta_m>( - lhs_tp, rhs_tp, partially_reduced_tmp, - local_A_block, local_B_block, n, wg_delta_n, k, - k_blocks, wi_delta_k, m, m_blocks, wg_delta_m, - lhs_indexer, rhs_indexer, res_indexer)); + lhs_tp, rhs_tp, tmp, local_A_block, local_B_block, + n, wg_delta_n, k, k_blocks, wi_delta_k, m, m_blocks, + wg_delta_m, lhs_indexer, rhs_indexer, res_indexer)); + } + else { + using LocAccT1 = sycl::local_accessor; + LocAccT1 local_A_block( + sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), + cgh); + using LocAccT2 = + sycl::local_accessor, 1>; + LocAccT2 local_B_block( + sycl::range<1>(wi_delta_k * wg_delta_m), cgh); + + using KernelName = class gemm_tree_nm_krn< + lhsTy, rhsTy, resTy, 
OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, wi_delta_m>; + cgh.parallel_for( + ndRange, + GemmNoAtomicFunctorThreadNM< + lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, + wi_delta_n, wi_delta_m>( + lhs_tp, rhs_tp, tmp, local_A_block, local_B_block, + n, wg_delta_n, k, k_blocks, wi_delta_k, m, m_blocks, + wg_delta_m, lhs_indexer, rhs_indexer, res_indexer)); + } + }); + sycl::event red_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(gemm_ev); + + using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + NoOpIndexerT, NoOpIndexerT>; + using ReductionIndexerT = + dpctl::tensor::offset_utils::Strided1DIndexer; + + InputOutputIterIndexerT in_out_iter_indexer{NoOpIndexerT{}, + NoOpIndexerT{}}; + ReductionIndexerT reduction_indexer{ + 0, static_cast(reduction_nelems), + static_cast(iter_nelems)}; + + sycl::range<1> iter_range{iter_nelems}; + + cgh.parallel_for>( + iter_range, SequentialReduction( + tmp, res_tp, ReductionOpT(), identity_val, + in_out_iter_indexer, reduction_indexer, + reduction_nelems)); + }); + sycl::event cleanup_host_task_event = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(red_ev); + const sycl::context &ctx = exec_q.get_context(); + + cgh.host_task([ctx, tmp] { sycl::free(tmp, ctx); }); }); + return cleanup_host_task_event; + } - sycl::event red_ev = - tree_reduction_for_gemm_contig( - exec_q, partially_reduced_tmp, partially_reduced_tmp2, - res_tp, identity_val, iter_nelems, reduction_nelems, - reduction_groups, wg, max_wg, - preferred_reductions_per_wi, reductions_per_wi, - {gemm_ev}); + // max_max_wg prevents running out of resources on CPU + constexpr size_t max_max_wg = 2048; + size_t max_wg = std::min( + max_max_wg, + dev.get_info() / 2); - sycl::event cleanup_host_task_event = - exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(red_ev); - const sycl::context &ctx = exec_q.get_context(); + constexpr size_t preferred_reductions_per_wi = 8; + size_t reductions_per_wi(preferred_reductions_per_wi); - cgh.host_task([ctx, partially_reduced_tmp] { - sycl::free(partially_reduced_tmp, ctx); - }); - }); + size_t reduction_groups = + (reduction_nelems + preferred_reductions_per_wi * wg - 1) / + (preferred_reductions_per_wi * wg); + assert(reduction_groups > 1); - return cleanup_host_task_event; + resTy *partially_reduced_tmp = sycl::malloc_device( + iter_nelems * (/* temp */ reduction_nelems + + /* first reduction temp */ reduction_groups), + exec_q); + resTy *partially_reduced_tmp2 = nullptr; + + if (partially_reduced_tmp == nullptr) { + throw std::runtime_error("Unable to allocate device_memory"); + } + else { + partially_reduced_tmp2 = + partially_reduced_tmp + reduction_nelems * iter_nelems; + } + + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + OuterInnerDimsIndexerT lhs_indexer{}; + OuterInnerDimsIndexerT rhs_indexer{}; + OuterInnerDimsIndexerT res_indexer{}; + + size_t lws = wg_delta_n * wg_delta_m; + + size_t n_blocks = + ((n + wi_delta_n * wg_delta_n - 1) / (wi_delta_n * wg_delta_n)); + size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); + size_t m_blocks = + ((m + wi_delta_m * wg_delta_m - 1) / (wi_delta_m * wg_delta_m)); + + auto gwsRange = + sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); + auto lwsRange = sycl::range<1>(lws); + + auto ndRange = 
sycl::nd_range<1>(gwsRange, lwsRange); + + if constexpr (wi_delta_m == 1) { + using LocAccT1 = sycl::local_accessor; + LocAccT1 local_A_block( + sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), + cgh); + using LocAccT2 = sycl::local_accessor; + LocAccT2 local_B_block(wi_delta_k * wg_delta_m, cgh); + + using KernelName = + class gemm_tree_nm_krn; + cgh.parallel_for( + ndRange, GemmNoAtomicFunctorThreadNM< + lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, + wi_delta_n, wi_delta_m>( + lhs_tp, rhs_tp, partially_reduced_tmp, + local_A_block, local_B_block, n, wg_delta_n, k, + k_blocks, wi_delta_k, m, m_blocks, wg_delta_m, + lhs_indexer, rhs_indexer, res_indexer)); + } + else { + using LocAccT1 = sycl::local_accessor; + LocAccT1 local_A_block( + sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), + cgh); + using LocAccT2 = + sycl::local_accessor, 1>; + LocAccT2 local_B_block(sycl::range<1>(wi_delta_k * wg_delta_m), + cgh); + + using KernelName = + class gemm_tree_nm_krn; + cgh.parallel_for( + ndRange, GemmNoAtomicFunctorThreadNM< + lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, + wi_delta_n, wi_delta_m>( + lhs_tp, rhs_tp, partially_reduced_tmp, + local_A_block, local_B_block, n, wg_delta_n, k, + k_blocks, wi_delta_k, m, m_blocks, wg_delta_m, + lhs_indexer, rhs_indexer, res_indexer)); + } + }); + + sycl::event red_ev = + tree_reduction_for_gemm_contig( + exec_q, partially_reduced_tmp, partially_reduced_tmp2, res_tp, + identity_val, iter_nelems, reduction_nelems, reduction_groups, + wg, max_wg, preferred_reductions_per_wi, reductions_per_wi, + {gemm_ev}); + + sycl::event cleanup_host_task_event = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(red_ev); + const sycl::context &ctx = exec_q.get_context(); + + cgh.host_task([ctx, partially_reduced_tmp] { + sycl::free(partially_reduced_tmp, ctx); + }); + }); + + return cleanup_host_task_event; + } +} + +template +sycl::event gemm_contig_tree_impl(sycl::queue &exec_q, + const char *lhs_cp, + const char *rhs_cp, + char *res_cp, + size_t n, + size_t k, + size_t m, + std::vector const &depends = {}) +{ + const lhsTy *lhs_tp = reinterpret_cast(lhs_cp); + const rhsTy *rhs_tp = reinterpret_cast(rhs_cp); + resTy *res_tp = reinterpret_cast(res_cp); + + if ((k > n && k > m) || m == 1) { + // each group processes delta_k * n_wi + // items in a column, so no need for allocating + // temp memory if only one group is needed + + using dpctl::tensor::type_utils::is_complex; + if constexpr (!is_complex::value) { + if (m == 1) { + return gemm_contig_tree_k_impl( + exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, depends); + } + else { + return gemm_contig_tree_k_impl( + exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, depends); } } + else { + return gemm_contig_tree_k_impl( + exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, depends); + } + } + else { // m > 1, n > k or m > k + using dpctl::tensor::type_utils::is_complex; + if constexpr (!is_complex::value) { + return gemm_contig_tree_nm_impl( + exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, depends); + } + else { + return gemm_contig_tree_nm_impl( + exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, depends); + } } } From 459c0efc31800a537e8fbfac2aabd9ab65c5ee7a Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Thu, 11 Jan 2024 01:19:47 -0800 Subject: [PATCH 15/48] Refactoring `gemm_batch_tree` functions Adds new functions for calling `nm` threading and `k` threading kernels to improve readability --- .../include/kernels/linalg_functions/gemm.hpp | 
4676 ++++++----------- 1 file changed, 1493 insertions(+), 3183 deletions(-) diff --git a/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp b/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp index f712407b86..1cc7e9db65 100644 --- a/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp +++ b/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp @@ -24,7 +24,7 @@ namespace kernels namespace gemm_detail { -template +template void scale_gemm_k_parameters(const size_t &local_mem_size, const size_t &reserved_slm_size, const size_t delta_k, @@ -68,32 +68,6 @@ void scale_gemm_nm_parameters(const size_t &local_mem_size, } } // namespace gemm_detail -// template -// struct ThreeOffsets_CombinedIndexer -// { -// private: -// FirstIndexerT first_indexer_; -// SecondIndexerT second_indexer_; -// ThirdIndexerT third_indexer_; - -// public: -// ThreeOffsets_CombinedIndexer(const FirstIndexerT &first_indexer, -// const SecondIndexerT &second_indexer, -// const ThirdIndexerT &third_indexer) -// : first_indexer_(first_indexer), second_indexer_(second_indexer), -// third_indexer_(third_indexer) -// { -// } - -// ThreeOffsets operator()(py::ssize_t gid) const -// { -// return ThreeOffsets( -// first_indexer_(gid), second_indexer_(gid), third_indexer_(gid)); -// } -// }; - using dpctl::tensor::sycl_utils::choose_workgroup_size; template @@ -426,8 +400,8 @@ template + int wi_delta_n, + int wi_delta_m> class GemmFunctorThreadNM { private: @@ -586,7 +560,7 @@ template + int wi_delta_n> class GemmFunctorThreadNM + int wi_delta_n, + int wi_delta_m> class GemmNoAtomicFunctorThreadNM { private: @@ -1524,7 +1498,7 @@ template + int wi_delta_n> class GemmNoAtomicFunctorThreadNM +template sycl::event gemm_tree_nm_impl(sycl::queue &exec_q, const lhsTy *lhs_tp, const rhsTy *rhs_tp, @@ -2998,7 +2972,7 @@ sycl::event gemm_contig_tree_k_impl(sycl::queue &exec_q, } } -template +template sycl::event gemm_contig_tree_nm_impl(sycl::queue &exec_q, const lhsTy *lhs_tp, const rhsTy *rhs_tp, @@ -3399,8 +3373,8 @@ template + int wi_delta_n, + int wi_delta_m> class GemmBatchFunctorThreadNM { private: @@ -3580,7 +3554,7 @@ template + int wi_delta_n> class GemmBatchFunctorThreadNM(batch_nelems), static_cast(n * m)}); if (m == 1) { - constexpr int m_groups = 1; + constexpr size_t m_groups = 1; size_t delta_k(4); size_t n_wi(4); size_t delta_n(4); @@ -4509,8 +4483,8 @@ template + int wi_delta_n, + int wi_delta_m> class GemmBatchNoAtomicFunctorThreadNM { private: @@ -4686,7 +4660,7 @@ template + int wi_delta_n> class GemmBatchNoAtomicFunctorThreadNM class gemm_batch_tree_nm_krn; -template +template sycl::event -gemm_batch_tree_impl(sycl::queue &exec_q, - const char *lhs_cp, - const char *rhs_cp, - char *res_cp, - size_t batch_nelems, - size_t n, - size_t k, - size_t m, - int batch_nd, - const py::ssize_t *batch_shape_strides, - py::ssize_t lhs_batch_offset, - py::ssize_t rhs_batch_offset, - py::ssize_t res_batch_offset, - int inner_nd, - int lhs_outer_nd, - const py::ssize_t *lhs_outer_inner_shapes_strides, - int rhs_outer_nd, - const py::ssize_t *rhs_outer_inner_shapes_strides, - int res_outer_nd, - const py::ssize_t *res_outer_shapes_strides, - const py::ssize_t *res_shape_strides, - std::vector const &depends = {}) +gemm_batch_tree_k_impl(sycl::queue &exec_q, + const lhsTy *lhs_tp, + const rhsTy *rhs_tp, + resTy *res_tp, + size_t batch_nelems, + size_t n, + size_t k, + size_t m, + int batch_nd, + const py::ssize_t *batch_shape_strides, + py::ssize_t lhs_batch_offset, + py::ssize_t 
rhs_batch_offset, + py::ssize_t res_batch_offset, + int inner_nd, + int lhs_outer_nd, + const py::ssize_t *lhs_outer_inner_shapes_strides, + int rhs_outer_nd, + const py::ssize_t *rhs_outer_inner_shapes_strides, + int res_outer_nd, + const py::ssize_t *res_outer_shapes_strides, + const py::ssize_t *res_shape_strides, + std::vector const &depends) { - const lhsTy *lhs_tp = reinterpret_cast(lhs_cp); - const rhsTy *rhs_tp = reinterpret_cast(rhs_cp); - resTy *res_tp = reinterpret_cast(res_cp); + size_t delta_k(4); + size_t n_wi(4); + size_t delta_n(4); const sycl::device &dev = exec_q.get_device(); const size_t local_mem_size = dev.get_info(); const size_t reserved_slm_size = 512; - if ((k > n && k > m) || m == 1) { - size_t delta_k(4); - size_t n_wi(4); - size_t delta_n(4); + gemm_detail::scale_gemm_k_parameters( + local_mem_size, reserved_slm_size, delta_k, + n_wi, // modified by reference + delta_n // modified by reference + ); - using dpctl::tensor::type_utils::is_complex; - if constexpr (!is_complex::value) { - if (m == 1) { - constexpr int m_groups = 1; - gemm_detail::scale_gemm_k_parameters( - local_mem_size, reserved_slm_size, delta_k, - n_wi, // modified by reference - delta_n // modified by reference - ); - - if (k <= (delta_k * n_wi)) { - sycl::event gemm_ev = exec_q.submit([&](sycl::handler - &cgh) { - cgh.depends_on(depends); - - using OuterInnerDimsIndexerT = - dpctl::tensor::offset_utils::StridedIndexer; - OuterInnerDimsIndexerT lhs_indexer( - inner_nd + lhs_outer_nd, 0, - lhs_outer_inner_shapes_strides); - OuterInnerDimsIndexerT rhs_indexer( - inner_nd + rhs_outer_nd, 0, - rhs_outer_inner_shapes_strides); - OuterInnerDimsIndexerT res_indexer( - res_outer_nd, 0, res_outer_shapes_strides); - using BatchDimsIndexerT = dpctl::tensor::offset_utils:: - ThreeOffsets_StridedIndexer; - BatchDimsIndexerT batch_indexer( - batch_nd, lhs_batch_offset, rhs_batch_offset, - res_batch_offset, batch_shape_strides); - - size_t n_blocks = (n + delta_n - 1) / delta_n; - size_t k_blocks = - (k + n_wi * delta_k - 1) / (n_wi * delta_k); - size_t m_blocks = (m + m_groups - 1) / m_groups; - - size_t lws = delta_n * delta_k; - - auto gRange = sycl::range<1>(batch_nelems * n_blocks * - m_blocks * k_blocks * lws); - auto lRange = sycl::range<1>(lws); - - auto ndRange = sycl::nd_range<1>(gRange, lRange); - - using LocAccT = sycl::local_accessor; - LocAccT local_B_block(n_wi * delta_k, cgh); - LocAccT workspace(delta_n * delta_k, cgh); - - using KernelName = class gemm_batch_tree_k_krn< - lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, - OuterInnerDimsIndexerT, BatchDimsIndexerT, - m_groups>; - - cgh.parallel_for( - ndRange, - GemmBatchNoAtomicFunctorThreadK< - lhsTy, rhsTy, resTy, LocAccT, - OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, - BatchDimsIndexerT, m_groups>( - lhs_tp, rhs_tp, res_tp, workspace, - local_B_block, n, n_blocks, delta_n, k, - k_blocks, delta_k, n_wi, m, batch_nelems, - batch_indexer, lhs_indexer, rhs_indexer, - res_indexer)); - }); - return gemm_ev; - } - else { - using ReductionOpT = sycl::plus; - constexpr resTy identity_val = - sycl::known_identity::value; - - size_t iter_nelems = batch_nelems * n * m; - size_t reduction_nelems = - (k + delta_k * n_wi - 1) / (delta_k * n_wi); - - // more than one work-group is needed, requires a - // temporary delta_k * n_wi elements processed along k, - // so if more to process use multiple - const auto &sg_sizes = - dev.get_info(); - size_t wg = - choose_workgroup_size<4>(reduction_nelems, sg_sizes); - - if (reduction_nelems < wg) { - resTy *tmp = 
sycl::malloc_device( - iter_nelems * reduction_nelems, exec_q); - sycl::event gemm_ev = exec_q.submit([&](sycl::handler - &cgh) { - cgh.depends_on(depends); - - using OuterInnerDimsIndexerT = - dpctl::tensor::offset_utils::StridedIndexer; - using TmpIndexerT = - dpctl::tensor::offset_utils::NoOpIndexer; - OuterInnerDimsIndexerT lhs_indexer( - inner_nd + lhs_outer_nd, 0, - lhs_outer_inner_shapes_strides); - OuterInnerDimsIndexerT rhs_indexer( - inner_nd + rhs_outer_nd, 0, - rhs_outer_inner_shapes_strides); - TmpIndexerT res_indexer{}; - using dpctl::tensor::offset_utils::StridedIndexer; - using dpctl::tensor::offset_utils:: - UnpackedStridedIndexer; - using dpctl::tensor::offset_utils::Strided1DIndexer; - using dpctl::tensor::offset_utils:: - ThreeOffsets_CombinedIndexer; - using BatchDimsIndexerT = - ThreeOffsets_CombinedIndexer< - StridedIndexer, UnpackedStridedIndexer, - Strided1DIndexer>; - StridedIndexer lhs_batch_indexer( - batch_nd, lhs_batch_offset, - batch_shape_strides); - UnpackedStridedIndexer rhs_batch_indexer( - batch_nd, rhs_batch_offset, batch_shape_strides, - batch_shape_strides + 2 * batch_nd); - Strided1DIndexer tmp_batch_indexer( - 0, static_cast(batch_nelems), - n * m); - BatchDimsIndexerT batch_indexer(lhs_batch_indexer, - rhs_batch_indexer, - tmp_batch_indexer); - - size_t n_blocks = (n + delta_n - 1) / delta_n; - size_t k_blocks = - (k + n_wi * delta_k - 1) / (n_wi * delta_k); - size_t m_blocks = (m + m_groups - 1) / m_groups; - - size_t lws = delta_n * delta_k; - - auto gRange = - sycl::range<1>(batch_nelems * n_blocks * - m_blocks * k_blocks * lws); - auto lRange = sycl::range<1>(lws); - - auto ndRange = sycl::nd_range<1>(gRange, lRange); - - using LocAccT = sycl::local_accessor; - LocAccT local_B_block(n_wi * delta_k, cgh); - LocAccT workspace(delta_n * delta_k, cgh); - - using KernelName = class gemm_batch_tree_k_krn< - lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, - TmpIndexerT, BatchDimsIndexerT, m_groups>; - cgh.parallel_for( - ndRange, - GemmBatchNoAtomicFunctorThreadK< - lhsTy, rhsTy, resTy, LocAccT, - OuterInnerDimsIndexerT, TmpIndexerT, - BatchDimsIndexerT, m_groups>( - lhs_tp, rhs_tp, tmp, workspace, - local_B_block, n, n_blocks, delta_n, k, - k_blocks, delta_k, n_wi, m, batch_nelems, - batch_indexer, lhs_indexer, rhs_indexer, - res_indexer)); - }); - sycl::event red_ev = exec_q.submit([&](sycl::handler - &cgh) { - cgh.depends_on(gemm_ev); - - using NoOpIndexerT = - dpctl::tensor::offset_utils::NoOpIndexer; - using ResIndexerT = - dpctl::tensor::offset_utils::StridedIndexer; - using InputOutputIterIndexerT = - dpctl::tensor::offset_utils:: - TwoOffsets_CombinedIndexer; - using ReductionIndexerT = - dpctl::tensor::offset_utils::Strided1DIndexer; - - ResIndexerT res_iter_indexer{ - batch_nd + res_outer_nd, - static_cast(res_batch_offset), - res_shape_strides}; - InputOutputIterIndexerT in_out_iter_indexer{ - NoOpIndexerT{}, res_iter_indexer}; - ReductionIndexerT reduction_indexer{ - 0, static_cast(reduction_nelems), - static_cast(iter_nelems)}; - - sycl::range<1> iter_range{iter_nelems}; - - cgh.parallel_for< - class gemm_reduction_seq_strided_krn< - resTy, resTy, ReductionOpT, - InputOutputIterIndexerT, - ReductionIndexerT>>( - iter_range, - SequentialReduction( - tmp, res_tp, ReductionOpT(), identity_val, - in_out_iter_indexer, reduction_indexer, - reduction_nelems)); - }); - sycl::event cleanup_host_task_event = - exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(red_ev); - const sycl::context &ctx = exec_q.get_context(); - - cgh.host_task( - [ctx, 
tmp] { sycl::free(tmp, ctx); }); - }); - return cleanup_host_task_event; - } - - // max_max_wg prevents running out of resources on CPU - constexpr size_t max_max_wg = 2048; - size_t max_wg = std::min( - max_max_wg, - dev.get_info< - sycl::info::device::max_work_group_size>() / - 2); - - constexpr size_t preferred_reductions_per_wi = 4; - size_t reductions_per_wi(preferred_reductions_per_wi); - - size_t reduction_groups = - (reduction_nelems + preferred_reductions_per_wi * wg - - 1) / - (preferred_reductions_per_wi * wg); - assert(reduction_groups > 1); - - resTy *partially_reduced_tmp = sycl::malloc_device( - iter_nelems * - (/* temp */ reduction_nelems + - /* first reduction temp */ reduction_groups), - exec_q); - resTy *partially_reduced_tmp2 = nullptr; - - if (partially_reduced_tmp == nullptr) { - throw std::runtime_error( - "Unable to allocate device_memory"); - } - else { - partially_reduced_tmp2 = partially_reduced_tmp + - reduction_nelems * iter_nelems; - } - - sycl::event gemm_ev = exec_q.submit([&](sycl::handler - &cgh) { - cgh.depends_on(depends); - - using OuterInnerDimsIndexerT = - dpctl::tensor::offset_utils::StridedIndexer; - using TmpIndexerT = - dpctl::tensor::offset_utils::NoOpIndexer; - OuterInnerDimsIndexerT lhs_indexer( - inner_nd + lhs_outer_nd, 0, - lhs_outer_inner_shapes_strides); - OuterInnerDimsIndexerT rhs_indexer( - inner_nd + rhs_outer_nd, 0, - rhs_outer_inner_shapes_strides); - TmpIndexerT res_indexer{}; - using dpctl::tensor::offset_utils::StridedIndexer; - using dpctl::tensor::offset_utils::Strided1DIndexer; - using dpctl::tensor::offset_utils:: - ThreeOffsets_CombinedIndexer; - using BatchDimsIndexerT = ThreeOffsets_CombinedIndexer< - StridedIndexer, StridedIndexer, Strided1DIndexer>; - StridedIndexer lhs_batch_indexer( - batch_nd, lhs_batch_offset, batch_shape_strides); - StridedIndexer rhs_batch_indexer( - batch_nd, rhs_batch_offset, - batch_shape_strides + 2 * batch_nd); - Strided1DIndexer tmp_batch_indexer( - 0, static_cast(batch_nelems), n * m); - BatchDimsIndexerT batch_indexer(lhs_batch_indexer, - rhs_batch_indexer, - tmp_batch_indexer); - - size_t n_blocks = (n + delta_n - 1) / delta_n; - size_t k_blocks = - (k + n_wi * delta_k - 1) / (n_wi * delta_k); - constexpr int m_groups = 1; - size_t m_blocks = (m + m_groups - 1) / m_groups; - - size_t lws = delta_n * delta_k; - - auto gRange = sycl::range<1>(batch_nelems * n_blocks * - m_blocks * k_blocks * lws); - auto lRange = sycl::range<1>(lws); - - auto ndRange = sycl::nd_range<1>(gRange, lRange); - - using LocAccT = sycl::local_accessor; - LocAccT local_B_block(n_wi * delta_k, cgh); - LocAccT workspace(delta_n * delta_k, cgh); - - using KernelName = class gemm_batch_tree_k_krn< - lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, - TmpIndexerT, BatchDimsIndexerT, m_groups>; - cgh.parallel_for( - ndRange, - GemmBatchNoAtomicFunctorThreadK< - lhsTy, rhsTy, resTy, LocAccT, - OuterInnerDimsIndexerT, TmpIndexerT, - BatchDimsIndexerT, m_groups>( - lhs_tp, rhs_tp, partially_reduced_tmp, - workspace, local_B_block, n, n_blocks, delta_n, - k, k_blocks, delta_k, n_wi, m, batch_nelems, - batch_indexer, lhs_indexer, rhs_indexer, - res_indexer)); - }); - - sycl::event red_ev = - tree_reduction_for_gemm( - exec_q, partially_reduced_tmp, - partially_reduced_tmp2, res_tp, identity_val, - iter_nelems, reduction_nelems, reduction_groups, wg, - max_wg, preferred_reductions_per_wi, - reductions_per_wi, batch_nd + res_outer_nd, - res_batch_offset, res_shape_strides, {gemm_ev}); - - sycl::event cleanup_host_task_event = - 
exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(red_ev); - const sycl::context &ctx = exec_q.get_context(); - - cgh.host_task([ctx, partially_reduced_tmp] { - sycl::free(partially_reduced_tmp, ctx); - }); - }); - - return cleanup_host_task_event; - } + if (k <= (delta_k * n_wi)) { + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::StridedIndexer; + OuterInnerDimsIndexerT lhs_indexer(inner_nd + lhs_outer_nd, 0, + lhs_outer_inner_shapes_strides); + OuterInnerDimsIndexerT rhs_indexer(inner_nd + rhs_outer_nd, 0, + rhs_outer_inner_shapes_strides); + OuterInnerDimsIndexerT res_indexer(res_outer_nd, 0, + res_outer_shapes_strides); + using BatchDimsIndexerT = + dpctl::tensor::offset_utils::ThreeOffsets_StridedIndexer; + BatchDimsIndexerT batch_indexer(batch_nd, lhs_batch_offset, + rhs_batch_offset, res_batch_offset, + batch_shape_strides); + + size_t n_blocks = (n + delta_n - 1) / delta_n; + size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); + size_t m_blocks = (m + m_groups - 1) / m_groups; + + size_t lws = delta_n * delta_k; + + auto gRange = sycl::range<1>(batch_nelems * n_blocks * m_blocks * + k_blocks * lws); + auto lRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gRange, lRange); + + if constexpr (m_groups == 1) { + using LocAccT = sycl::local_accessor; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + + using KernelName = class gemm_batch_tree_k_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, BatchDimsIndexerT, m_groups>; + + cgh.parallel_for( + ndRange, + GemmBatchNoAtomicFunctorThreadK< + lhsTy, rhsTy, resTy, LocAccT, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, BatchDimsIndexerT, m_groups>( + lhs_tp, rhs_tp, res_tp, workspace, local_B_block, n, + n_blocks, delta_n, k, k_blocks, delta_k, n_wi, m, + batch_nelems, batch_indexer, lhs_indexer, rhs_indexer, + res_indexer)); } else { - constexpr int m_groups = 2; - - gemm_detail::scale_gemm_k_parameters( - local_mem_size, reserved_slm_size, delta_k, - n_wi, // modified by reference - delta_n // modified by reference - ); - - if (k <= (delta_k * n_wi)) { - sycl::event gemm_ev = exec_q.submit([&](sycl::handler - &cgh) { - cgh.depends_on(depends); - - using OuterInnerDimsIndexerT = - dpctl::tensor::offset_utils::StridedIndexer; - OuterInnerDimsIndexerT lhs_indexer( - inner_nd + lhs_outer_nd, 0, - lhs_outer_inner_shapes_strides); - OuterInnerDimsIndexerT rhs_indexer( - inner_nd + rhs_outer_nd, 0, - rhs_outer_inner_shapes_strides); - OuterInnerDimsIndexerT res_indexer( - res_outer_nd, 0, res_outer_shapes_strides); - using BatchDimsIndexerT = dpctl::tensor::offset_utils:: - ThreeOffsets_StridedIndexer; - BatchDimsIndexerT batch_indexer( - batch_nd, lhs_batch_offset, rhs_batch_offset, - res_batch_offset, batch_shape_strides); - - size_t n_blocks = (n + delta_n - 1) / delta_n; - size_t k_blocks = - (k + n_wi * delta_k - 1) / (n_wi * delta_k); - size_t m_blocks = (m + m_groups - 1) / m_groups; - - size_t lws = delta_n * delta_k; - - auto gRange = sycl::range<1>(batch_nelems * n_blocks * - m_blocks * k_blocks * lws); - auto lRange = sycl::range<1>(lws); - - auto ndRange = sycl::nd_range<1>(gRange, lRange); - - using LocAccT = - sycl::local_accessor, 1>; - LocAccT local_B_block(n_wi * delta_k, cgh); - LocAccT workspace(delta_n * delta_k, cgh); - - using KernelName = class gemm_batch_tree_k_krn< - lhsTy, rhsTy, resTy, 
OuterInnerDimsIndexerT, - OuterInnerDimsIndexerT, BatchDimsIndexerT, - m_groups>; - - cgh.parallel_for( - ndRange, - GemmBatchNoAtomicFunctorThreadK< - lhsTy, rhsTy, resTy, LocAccT, - OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, - BatchDimsIndexerT, m_groups>( - lhs_tp, rhs_tp, res_tp, workspace, - local_B_block, n, n_blocks, delta_n, k, - k_blocks, delta_k, n_wi, m, batch_nelems, - batch_indexer, lhs_indexer, rhs_indexer, - res_indexer)); - }); - return gemm_ev; - } - else { - using ReductionOpT = sycl::plus; - constexpr resTy identity_val = - sycl::known_identity::value; - - size_t iter_nelems = batch_nelems * n * m; - size_t reduction_nelems = - (k + delta_k * n_wi - 1) / (delta_k * n_wi); - - // more than one work-group is needed, requires a - // temporary delta_k * n_wi elements processed along k, - // so if more to process use multiple - const auto &sg_sizes = - dev.get_info(); - size_t wg = - choose_workgroup_size<4>(reduction_nelems, sg_sizes); - - if (reduction_nelems < wg) { - resTy *tmp = sycl::malloc_device( - iter_nelems * reduction_nelems, exec_q); - sycl::event gemm_ev = exec_q.submit([&](sycl::handler - &cgh) { - cgh.depends_on(depends); - - using OuterInnerDimsIndexerT = - dpctl::tensor::offset_utils::StridedIndexer; - using TmpIndexerT = - dpctl::tensor::offset_utils::NoOpIndexer; - OuterInnerDimsIndexerT lhs_indexer( - inner_nd + lhs_outer_nd, 0, - lhs_outer_inner_shapes_strides); - OuterInnerDimsIndexerT rhs_indexer( - inner_nd + rhs_outer_nd, 0, - rhs_outer_inner_shapes_strides); - TmpIndexerT res_indexer{}; - using dpctl::tensor::offset_utils::StridedIndexer; - using dpctl::tensor::offset_utils:: - UnpackedStridedIndexer; - using dpctl::tensor::offset_utils::Strided1DIndexer; - using dpctl::tensor::offset_utils:: - ThreeOffsets_CombinedIndexer; - using BatchDimsIndexerT = - ThreeOffsets_CombinedIndexer< - StridedIndexer, UnpackedStridedIndexer, - Strided1DIndexer>; - StridedIndexer lhs_batch_indexer( - batch_nd, lhs_batch_offset, - batch_shape_strides); - UnpackedStridedIndexer rhs_batch_indexer( - batch_nd, rhs_batch_offset, batch_shape_strides, - batch_shape_strides + 2 * batch_nd); - Strided1DIndexer tmp_batch_indexer( - 0, static_cast(batch_nelems), - n * m); - BatchDimsIndexerT batch_indexer(lhs_batch_indexer, - rhs_batch_indexer, - tmp_batch_indexer); - - size_t n_blocks = (n + delta_n - 1) / delta_n; - size_t k_blocks = - (k + n_wi * delta_k - 1) / (n_wi * delta_k); - size_t m_blocks = (m + m_groups - 1) / m_groups; - - size_t lws = delta_n * delta_k; - - auto gRange = - sycl::range<1>(batch_nelems * n_blocks * - m_blocks * k_blocks * lws); - auto lRange = sycl::range<1>(lws); - - auto ndRange = sycl::nd_range<1>(gRange, lRange); - - using LocAccT = - sycl::local_accessor, - 1>; - LocAccT local_B_block(n_wi * delta_k, cgh); - LocAccT workspace(delta_n * delta_k, cgh); - - using KernelName = class gemm_batch_tree_k_krn< - lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, - TmpIndexerT, BatchDimsIndexerT, m_groups>; - cgh.parallel_for( - ndRange, - GemmBatchNoAtomicFunctorThreadK< - lhsTy, rhsTy, resTy, LocAccT, - OuterInnerDimsIndexerT, TmpIndexerT, - BatchDimsIndexerT, m_groups>( - lhs_tp, rhs_tp, tmp, workspace, - local_B_block, n, n_blocks, delta_n, k, - k_blocks, delta_k, n_wi, m, batch_nelems, - batch_indexer, lhs_indexer, rhs_indexer, - res_indexer)); - }); - sycl::event red_ev = exec_q.submit([&](sycl::handler - &cgh) { - cgh.depends_on(gemm_ev); - - using NoOpIndexerT = - dpctl::tensor::offset_utils::NoOpIndexer; - using ResIndexerT = - 
dpctl::tensor::offset_utils::StridedIndexer; - using InputOutputIterIndexerT = - dpctl::tensor::offset_utils:: - TwoOffsets_CombinedIndexer; - using ReductionIndexerT = - dpctl::tensor::offset_utils::Strided1DIndexer; - - ResIndexerT res_iter_indexer{ - batch_nd + res_outer_nd, - static_cast(res_batch_offset), - res_shape_strides}; - InputOutputIterIndexerT in_out_iter_indexer{ - NoOpIndexerT{}, res_iter_indexer}; - ReductionIndexerT reduction_indexer{ - 0, static_cast(reduction_nelems), - static_cast(iter_nelems)}; - - sycl::range<1> iter_range{iter_nelems}; - - cgh.parallel_for< - class gemm_reduction_seq_strided_krn< - resTy, resTy, ReductionOpT, - InputOutputIterIndexerT, - ReductionIndexerT>>( - iter_range, - SequentialReduction( - tmp, res_tp, ReductionOpT(), identity_val, - in_out_iter_indexer, reduction_indexer, - reduction_nelems)); - }); - sycl::event cleanup_host_task_event = - exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(red_ev); - const sycl::context &ctx = exec_q.get_context(); - - cgh.host_task( - [ctx, tmp] { sycl::free(tmp, ctx); }); - }); - return cleanup_host_task_event; - } - - // max_max_wg prevents running out of resources on CPU - constexpr size_t max_max_wg = 2048; - size_t max_wg = std::min( - max_max_wg, - dev.get_info< - sycl::info::device::max_work_group_size>() / - 2); - - constexpr size_t preferred_reductions_per_wi = 4; - size_t reductions_per_wi(preferred_reductions_per_wi); - - size_t reduction_groups = - (reduction_nelems + preferred_reductions_per_wi * wg - - 1) / - (preferred_reductions_per_wi * wg); - assert(reduction_groups > 1); - - resTy *partially_reduced_tmp = sycl::malloc_device( - iter_nelems * - (/* temp */ reduction_nelems + - /* first reduction temp */ reduction_groups), - exec_q); - resTy *partially_reduced_tmp2 = nullptr; - - if (partially_reduced_tmp == nullptr) { - throw std::runtime_error( - "Unable to allocate device_memory"); - } - else { - partially_reduced_tmp2 = partially_reduced_tmp + - reduction_nelems * iter_nelems; - } - - sycl::event gemm_ev = exec_q.submit([&](sycl::handler - &cgh) { - cgh.depends_on(depends); - - using OuterInnerDimsIndexerT = - dpctl::tensor::offset_utils::StridedIndexer; - using TmpIndexerT = - dpctl::tensor::offset_utils::NoOpIndexer; - OuterInnerDimsIndexerT lhs_indexer( - inner_nd + lhs_outer_nd, 0, - lhs_outer_inner_shapes_strides); - OuterInnerDimsIndexerT rhs_indexer( - inner_nd + rhs_outer_nd, 0, - rhs_outer_inner_shapes_strides); - TmpIndexerT res_indexer{}; - using dpctl::tensor::offset_utils::StridedIndexer; - using dpctl::tensor::offset_utils:: - UnpackedStridedIndexer; - using dpctl::tensor::offset_utils::Strided1DIndexer; - using dpctl::tensor::offset_utils:: - ThreeOffsets_CombinedIndexer; - using BatchDimsIndexerT = - ThreeOffsets_CombinedIndexer; - StridedIndexer lhs_batch_indexer( - batch_nd, lhs_batch_offset, batch_shape_strides); - UnpackedStridedIndexer rhs_batch_indexer( - batch_nd, rhs_batch_offset, batch_shape_strides, - batch_shape_strides + 2 * batch_nd); - Strided1DIndexer tmp_batch_indexer( - 0, static_cast(batch_nelems), n * m); - BatchDimsIndexerT batch_indexer(lhs_batch_indexer, - rhs_batch_indexer, - tmp_batch_indexer); - - size_t n_blocks = (n + delta_n - 1) / delta_n; - size_t k_blocks = - (k + n_wi * delta_k - 1) / (n_wi * delta_k); - size_t m_blocks = (m + m_groups - 1) / m_groups; - - size_t lws = delta_n * delta_k; - - auto gRange = sycl::range<1>(batch_nelems * n_blocks * - m_blocks * k_blocks * lws); - auto lRange = sycl::range<1>(lws); - - auto ndRange = 
sycl::nd_range<1>(gRange, lRange); - - using LocAccT = - sycl::local_accessor, 1>; - LocAccT local_B_block(n_wi * delta_k, cgh); - LocAccT workspace(delta_n * delta_k, cgh); - - using KernelName = class gemm_batch_tree_k_krn< - lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, - TmpIndexerT, BatchDimsIndexerT, m_groups>; - cgh.parallel_for( - ndRange, - GemmBatchNoAtomicFunctorThreadK< - lhsTy, rhsTy, resTy, LocAccT, - OuterInnerDimsIndexerT, TmpIndexerT, - BatchDimsIndexerT, m_groups>( - lhs_tp, rhs_tp, partially_reduced_tmp, - workspace, local_B_block, n, n_blocks, delta_n, - k, k_blocks, delta_k, n_wi, m, batch_nelems, - batch_indexer, lhs_indexer, rhs_indexer, - res_indexer)); - }); - - sycl::event red_ev = - tree_reduction_for_gemm( - exec_q, partially_reduced_tmp, - partially_reduced_tmp2, res_tp, identity_val, - iter_nelems, reduction_nelems, reduction_groups, wg, - max_wg, preferred_reductions_per_wi, - reductions_per_wi, batch_nd + res_outer_nd, - res_batch_offset, res_shape_strides, {gemm_ev}); - - sycl::event cleanup_host_task_event = - exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(red_ev); - const sycl::context &ctx = exec_q.get_context(); - - cgh.host_task([ctx, partially_reduced_tmp] { - sycl::free(partially_reduced_tmp, ctx); - }); - }); - - return cleanup_host_task_event; - } + using LocAccT = + sycl::local_accessor, 1>; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + + using KernelName = class gemm_batch_tree_k_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, BatchDimsIndexerT, m_groups>; + + cgh.parallel_for( + ndRange, + GemmBatchNoAtomicFunctorThreadK< + lhsTy, rhsTy, resTy, LocAccT, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, BatchDimsIndexerT, m_groups>( + lhs_tp, rhs_tp, res_tp, workspace, local_B_block, n, + n_blocks, delta_n, k, k_blocks, delta_k, n_wi, m, + batch_nelems, batch_indexer, lhs_indexer, rhs_indexer, + res_indexer)); } - } - else { - constexpr int m_groups = 1; + }); + return gemm_ev; + } + else { + using ReductionOpT = sycl::plus; + constexpr resTy identity_val = + sycl::known_identity::value; - gemm_detail::scale_gemm_k_parameters( - local_mem_size, reserved_slm_size, delta_k, - n_wi, // modified by reference - delta_n // modified by reference - ); + size_t iter_nelems = batch_nelems * n * m; + size_t reduction_nelems = (k + delta_k * n_wi - 1) / (delta_k * n_wi); + + // more than one work-group is needed, requires a + // temporary delta_k * n_wi elements processed along k, + // so if more to process use multiple + const auto &sg_sizes = + dev.get_info(); + size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes); - // each group processes delta_k * n_wi - // items in a column, so no need for allocating - // temp memory if only one group is needed - if (k <= (delta_k * n_wi)) { - sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(depends); - - using OuterInnerDimsIndexerT = - dpctl::tensor::offset_utils::StridedIndexer; - OuterInnerDimsIndexerT lhs_indexer( - inner_nd + lhs_outer_nd, 0, - lhs_outer_inner_shapes_strides); - OuterInnerDimsIndexerT rhs_indexer( - inner_nd + rhs_outer_nd, 0, - rhs_outer_inner_shapes_strides); - OuterInnerDimsIndexerT res_indexer( - res_outer_nd, 0, res_outer_shapes_strides); - using BatchDimsIndexerT = dpctl::tensor::offset_utils:: - ThreeOffsets_StridedIndexer; - BatchDimsIndexerT batch_indexer( - batch_nd, lhs_batch_offset, rhs_batch_offset, - res_batch_offset, batch_shape_strides); - - size_t 
n_blocks = (n + delta_n - 1) / delta_n; - size_t k_blocks = - (k + n_wi * delta_k - 1) / (n_wi * delta_k); - constexpr int m_groups = 1; - size_t m_blocks = (m + m_groups - 1) / m_groups; - - size_t lws = delta_n * delta_k; - - auto gRange = sycl::range<1>(batch_nelems * n_blocks * - m_blocks * k_blocks * lws); - auto lRange = sycl::range<1>(lws); - - auto ndRange = sycl::nd_range<1>(gRange, lRange); + if (reduction_nelems < wg) { + resTy *tmp = sycl::malloc_device( + iter_nelems * reduction_nelems, exec_q); + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::StridedIndexer; + using TmpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + OuterInnerDimsIndexerT lhs_indexer( + inner_nd + lhs_outer_nd, 0, lhs_outer_inner_shapes_strides); + OuterInnerDimsIndexerT rhs_indexer( + inner_nd + rhs_outer_nd, 0, rhs_outer_inner_shapes_strides); + TmpIndexerT res_indexer{}; + using dpctl::tensor::offset_utils::StridedIndexer; + using dpctl::tensor::offset_utils::UnpackedStridedIndexer; + using dpctl::tensor::offset_utils::Strided1DIndexer; + using dpctl::tensor::offset_utils::ThreeOffsets_CombinedIndexer; + using BatchDimsIndexerT = ThreeOffsets_CombinedIndexer< + StridedIndexer, UnpackedStridedIndexer, Strided1DIndexer>; + StridedIndexer lhs_batch_indexer(batch_nd, lhs_batch_offset, + batch_shape_strides); + UnpackedStridedIndexer rhs_batch_indexer( + batch_nd, rhs_batch_offset, batch_shape_strides, + batch_shape_strides + 2 * batch_nd); + Strided1DIndexer tmp_batch_indexer( + 0, static_cast(batch_nelems), n * m); + BatchDimsIndexerT batch_indexer( + lhs_batch_indexer, rhs_batch_indexer, tmp_batch_indexer); + + size_t n_blocks = (n + delta_n - 1) / delta_n; + size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); + size_t m_blocks = (m + m_groups - 1) / m_groups; + size_t lws = delta_n * delta_k; + + auto gRange = sycl::range<1>(batch_nelems * n_blocks * + m_blocks * k_blocks * lws); + auto lRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gRange, lRange); + + if constexpr (m_groups == 1) { using LocAccT = sycl::local_accessor; LocAccT local_B_block(n_wi * delta_k, cgh); LocAccT workspace(delta_n * delta_k, cgh); using KernelName = class gemm_batch_tree_k_krn< lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, - OuterInnerDimsIndexerT, BatchDimsIndexerT, m_groups>; - + TmpIndexerT, BatchDimsIndexerT, m_groups>; cgh.parallel_for( ndRange, GemmBatchNoAtomicFunctorThreadK< lhsTy, rhsTy, resTy, LocAccT, - OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, TmpIndexerT, BatchDimsIndexerT, m_groups>( - lhs_tp, rhs_tp, res_tp, workspace, local_B_block, n, + lhs_tp, rhs_tp, tmp, workspace, local_B_block, n, n_blocks, delta_n, k, k_blocks, delta_k, n_wi, m, batch_nelems, batch_indexer, lhs_indexer, rhs_indexer, res_indexer)); - }); - return gemm_ev; - } - else { - using ReductionOpT = sycl::plus; - constexpr resTy identity_val = - sycl::known_identity::value; - - size_t iter_nelems = batch_nelems * n * m; - size_t reduction_nelems = - (k + delta_k * n_wi - 1) / (delta_k * n_wi); - - // more than one work-group is needed, requires a temporary - // delta_k * n_wi elements processed along k, so if more to - // process use multiple - const auto &sg_sizes = - dev.get_info(); - size_t wg = - choose_workgroup_size<4>(reduction_nelems, sg_sizes); - - if (reduction_nelems < wg) { - resTy *tmp = sycl::malloc_device( - iter_nelems * reduction_nelems, exec_q); - sycl::event 
gemm_ev = exec_q.submit([&](sycl::handler - &cgh) { - cgh.depends_on(depends); - - using OuterInnerDimsIndexerT = - dpctl::tensor::offset_utils::StridedIndexer; - using TmpIndexerT = - dpctl::tensor::offset_utils::NoOpIndexer; - OuterInnerDimsIndexerT lhs_indexer( - inner_nd + lhs_outer_nd, 0, - lhs_outer_inner_shapes_strides); - OuterInnerDimsIndexerT rhs_indexer( - inner_nd + rhs_outer_nd, 0, - rhs_outer_inner_shapes_strides); - TmpIndexerT res_indexer{}; - using dpctl::tensor::offset_utils::StridedIndexer; - using dpctl::tensor::offset_utils:: - UnpackedStridedIndexer; - using dpctl::tensor::offset_utils::Strided1DIndexer; - using dpctl::tensor::offset_utils:: - ThreeOffsets_CombinedIndexer; - using BatchDimsIndexerT = - ThreeOffsets_CombinedIndexer; - StridedIndexer lhs_batch_indexer( - batch_nd, lhs_batch_offset, batch_shape_strides); - UnpackedStridedIndexer rhs_batch_indexer( - batch_nd, rhs_batch_offset, batch_shape_strides, - batch_shape_strides + 2 * batch_nd); - Strided1DIndexer tmp_batch_indexer( - 0, static_cast(batch_nelems), n * m); - BatchDimsIndexerT batch_indexer(lhs_batch_indexer, - rhs_batch_indexer, - tmp_batch_indexer); - - size_t n_blocks = (n + delta_n - 1) / delta_n; - size_t k_blocks = - (k + n_wi * delta_k - 1) / (n_wi * delta_k); - constexpr int m_groups = 1; - size_t m_blocks = (m + m_groups - 1) / m_groups; - - size_t lws = delta_n * delta_k; - - auto gRange = sycl::range<1>(batch_nelems * n_blocks * - m_blocks * k_blocks * lws); - auto lRange = sycl::range<1>(lws); - - auto ndRange = sycl::nd_range<1>(gRange, lRange); - - using LocAccT = sycl::local_accessor; - LocAccT local_B_block(n_wi * delta_k, cgh); - LocAccT workspace(delta_n * delta_k, cgh); - - using KernelName = class gemm_batch_tree_k_krn< - lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, - TmpIndexerT, BatchDimsIndexerT, m_groups>; - cgh.parallel_for( - ndRange, - GemmBatchNoAtomicFunctorThreadK< - lhsTy, rhsTy, resTy, LocAccT, - OuterInnerDimsIndexerT, TmpIndexerT, - BatchDimsIndexerT, m_groups>( - lhs_tp, rhs_tp, tmp, workspace, local_B_block, - n, n_blocks, delta_n, k, k_blocks, delta_k, - n_wi, m, batch_nelems, batch_indexer, - lhs_indexer, rhs_indexer, res_indexer)); - }); - sycl::event red_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(gemm_ev); - - using NoOpIndexerT = - dpctl::tensor::offset_utils::NoOpIndexer; - using ResIndexerT = - dpctl::tensor::offset_utils::StridedIndexer; - using InputOutputIterIndexerT = - dpctl::tensor::offset_utils:: - TwoOffsets_CombinedIndexer; - using ReductionIndexerT = - dpctl::tensor::offset_utils::Strided1DIndexer; - - ResIndexerT res_iter_indexer{ - batch_nd + res_outer_nd, - static_cast(res_batch_offset), - res_shape_strides}; - InputOutputIterIndexerT in_out_iter_indexer{ - NoOpIndexerT{}, res_iter_indexer}; - ReductionIndexerT reduction_indexer{ - 0, static_cast(reduction_nelems), - static_cast(iter_nelems)}; - - sycl::range<1> iter_range{iter_nelems}; - - cgh.parallel_for>( - iter_range, - SequentialReduction( - tmp, res_tp, ReductionOpT(), identity_val, - in_out_iter_indexer, reduction_indexer, - reduction_nelems)); - }); - sycl::event cleanup_host_task_event = - exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(red_ev); - const sycl::context &ctx = exec_q.get_context(); - - cgh.host_task([ctx, tmp] { sycl::free(tmp, ctx); }); - }); - return cleanup_host_task_event; - } - - // max_max_wg prevents running out of resources on CPU - constexpr size_t max_max_wg = 2048; - size_t max_wg = std::min( - max_max_wg, - dev.get_info() / - 2); - 
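Both the monolithic path being removed here and the refactored k- and nm-threaded paths size their second-stage reduction the same way: the reduction work-group size is capped at half the device maximum (and never above 2048, so the CPU does not run out of resources), and the number of first-stage reduction groups is a ceiling division of reduction_nelems by preferred_reductions_per_wi * wg. The following is a minimal standalone sketch of that arithmetic, not code from the patch: the function name and the max_work_group_size parameter are stand-ins for the value normally obtained via dev.get_info<sycl::info::device::max_work_group_size>(), wg stands in for the result of choose_workgroup_size<4>(...), and preferred_reductions_per_wi is 4 as in the batch paths (the contiguous gemm path earlier uses 8).

#include <algorithm>
#include <cstddef>

// Sketch only: reproduces the reduction-sizing arithmetic used by the
// tree-reduction gemm paths. max_work_group_size stands in for the device
// query; wg stands in for the work-group size picked by choose_workgroup_size.
std::size_t reduction_group_count(std::size_t reduction_nelems,
                                  std::size_t wg,
                                  std::size_t max_work_group_size)
{
    // max_max_wg prevents running out of resources on CPU
    constexpr std::size_t max_max_wg = 2048;
    const std::size_t max_wg = std::min(max_max_wg, max_work_group_size / 2);
    static_cast<void>(max_wg); // bounds later reduction passes, not the first-stage count

    constexpr std::size_t preferred_reductions_per_wi = 4;
    // ceiling division: groups needed so that each work-item consumes roughly
    // preferred_reductions_per_wi of the k-partials produced by the gemm kernel
    return (reduction_nelems + preferred_reductions_per_wi * wg - 1) /
           (preferred_reductions_per_wi * wg);
}

With this count in hand, the temporary buffer allocated below holds iter_nelems * (reduction_nelems + reduction_groups) elements: the per-k partial products written by the gemm kernel plus scratch for the first reduction pass consumed by tree_reduction_for_gemm.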
- constexpr size_t preferred_reductions_per_wi = 4; - size_t reductions_per_wi(preferred_reductions_per_wi); - - size_t reduction_groups = - (reduction_nelems + preferred_reductions_per_wi * wg - 1) / - (preferred_reductions_per_wi * wg); - assert(reduction_groups > 1); - - resTy *partially_reduced_tmp = sycl::malloc_device( - iter_nelems * (/* temp */ reduction_nelems + - /* first reduction temp */ reduction_groups), - exec_q); - resTy *partially_reduced_tmp2 = nullptr; - - if (partially_reduced_tmp == nullptr) { - throw std::runtime_error( - "Unable to allocate device_memory"); } else { - partially_reduced_tmp2 = - partially_reduced_tmp + reduction_nelems * iter_nelems; - } - - sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(depends); - - using OuterInnerDimsIndexerT = - dpctl::tensor::offset_utils::StridedIndexer; - using TmpIndexerT = - dpctl::tensor::offset_utils::NoOpIndexer; - OuterInnerDimsIndexerT lhs_indexer( - inner_nd + lhs_outer_nd, 0, - lhs_outer_inner_shapes_strides); - OuterInnerDimsIndexerT rhs_indexer( - inner_nd + rhs_outer_nd, 0, - rhs_outer_inner_shapes_strides); - TmpIndexerT res_indexer{}; - using dpctl::tensor::offset_utils::StridedIndexer; - using dpctl::tensor::offset_utils::UnpackedStridedIndexer; - using dpctl::tensor::offset_utils::Strided1DIndexer; - using dpctl::tensor::offset_utils:: - ThreeOffsets_CombinedIndexer; - using BatchDimsIndexerT = - ThreeOffsets_CombinedIndexer; - StridedIndexer lhs_batch_indexer(batch_nd, lhs_batch_offset, - batch_shape_strides); - UnpackedStridedIndexer rhs_batch_indexer( - batch_nd, rhs_batch_offset, batch_shape_strides, - batch_shape_strides + 2 * batch_nd); - Strided1DIndexer tmp_batch_indexer( - 0, static_cast(batch_nelems), n * m); - BatchDimsIndexerT batch_indexer(lhs_batch_indexer, - rhs_batch_indexer, - tmp_batch_indexer); - - size_t n_blocks = (n + delta_n - 1) / delta_n; - size_t k_blocks = - (k + n_wi * delta_k - 1) / (n_wi * delta_k); - constexpr int m_groups = 1; - size_t m_blocks = (m + m_groups - 1) / m_groups; - - size_t lws = delta_n * delta_k; - - auto gRange = sycl::range<1>(batch_nelems * n_blocks * - m_blocks * k_blocks * lws); - auto lRange = sycl::range<1>(lws); - - auto ndRange = sycl::nd_range<1>(gRange, lRange); - - using LocAccT = sycl::local_accessor; + using LocAccT = + sycl::local_accessor, 1>; LocAccT local_B_block(n_wi * delta_k, cgh); LocAccT workspace(delta_n * delta_k, cgh); @@ -6149,426 +5327,368 @@ gemm_batch_tree_impl(sycl::queue &exec_q, lhsTy, rhsTy, resTy, LocAccT, OuterInnerDimsIndexerT, TmpIndexerT, BatchDimsIndexerT, m_groups>( - lhs_tp, rhs_tp, partially_reduced_tmp, workspace, - local_B_block, n, n_blocks, delta_n, k, k_blocks, - delta_k, n_wi, m, batch_nelems, batch_indexer, - lhs_indexer, rhs_indexer, res_indexer)); - }); - - sycl::event red_ev = - tree_reduction_for_gemm( - exec_q, partially_reduced_tmp, partially_reduced_tmp2, - res_tp, identity_val, iter_nelems, reduction_nelems, - reduction_groups, wg, max_wg, - preferred_reductions_per_wi, reductions_per_wi, - batch_nd + res_outer_nd, res_batch_offset, - res_shape_strides, {gemm_ev}); - - sycl::event cleanup_host_task_event = - exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(red_ev); - const sycl::context &ctx = exec_q.get_context(); - - cgh.host_task([ctx, partially_reduced_tmp] { - sycl::free(partially_reduced_tmp, ctx); - }); - }); - - return cleanup_host_task_event; - } - } - } - else { // m > 1, n > k or m > k - constexpr int wi_delta_n = 2; - size_t wg_delta_n(16); // rows of A 
processed in WG - size_t wg_delta_m(16); // rows of B processed in WG - size_t wi_delta_k(64); // Elements in K dimension processed by WI + lhs_tp, rhs_tp, tmp, workspace, local_B_block, n, + n_blocks, delta_n, k, k_blocks, delta_k, n_wi, m, + batch_nelems, batch_indexer, lhs_indexer, + rhs_indexer, res_indexer)); + } + }); + sycl::event red_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(gemm_ev); - using dpctl::tensor::type_utils::is_complex; - if constexpr (!is_complex::value) { - constexpr int wi_delta_m = 4; + using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using ResIndexerT = dpctl::tensor::offset_utils::StridedIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + NoOpIndexerT, ResIndexerT>; + using ReductionIndexerT = + dpctl::tensor::offset_utils::Strided1DIndexer; - gemm_detail::scale_gemm_nm_parameters( - local_mem_size, reserved_slm_size, wi_delta_n, - wi_delta_k, // modified by reference - wg_delta_n, // modified by reference - wg_delta_m // modified by reference - ); + ResIndexerT res_iter_indexer{ + batch_nd + res_outer_nd, + static_cast(res_batch_offset), + res_shape_strides}; + InputOutputIterIndexerT in_out_iter_indexer{NoOpIndexerT{}, + res_iter_indexer}; + ReductionIndexerT reduction_indexer{ + 0, static_cast(reduction_nelems), + static_cast(iter_nelems)}; - // each group processes delta_k * n_wi - // items in a column, so no need for allocating - // temp memory if only one group is needed - if (k <= wi_delta_k) { - sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(depends); - - using OuterInnerDimsIndexerT = - dpctl::tensor::offset_utils::StridedIndexer; - OuterInnerDimsIndexerT lhs_indexer( - inner_nd + lhs_outer_nd, 0, - lhs_outer_inner_shapes_strides); - OuterInnerDimsIndexerT rhs_indexer( - inner_nd + rhs_outer_nd, 0, - rhs_outer_inner_shapes_strides); - OuterInnerDimsIndexerT res_indexer( - res_outer_nd, 0, res_outer_shapes_strides); - using BatchDimsIndexerT = dpctl::tensor::offset_utils:: - ThreeOffsets_StridedIndexer; - BatchDimsIndexerT batch_indexer( - batch_nd, lhs_batch_offset, rhs_batch_offset, - res_batch_offset, batch_shape_strides); - - size_t lws = wg_delta_n * wg_delta_m; - - size_t n_blocks = ((n + wi_delta_n * wg_delta_n - 1) / - (wi_delta_n * wg_delta_n)); - size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); - size_t m_blocks = ((m + wi_delta_m * wg_delta_m - 1) / - (wi_delta_m * wg_delta_m)); - - auto gwsRange = sycl::range<1>(batch_nelems * n_blocks * - m_blocks * k_blocks * lws); - auto lwsRange = sycl::range<1>(lws); - - auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); + sycl::range<1> iter_range{iter_nelems}; - using LocAccT1 = sycl::local_accessor; - LocAccT1 local_A_block( - sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), - cgh); - using LocAccT2 = - sycl::local_accessor, 1>; - LocAccT2 local_B_block( - sycl::range<1>(wi_delta_k * wg_delta_m), cgh); + cgh.parallel_for>( + iter_range, SequentialReduction( + tmp, res_tp, ReductionOpT(), identity_val, + in_out_iter_indexer, reduction_indexer, + reduction_nelems)); + }); + sycl::event cleanup_host_task_event = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(red_ev); + const sycl::context &ctx = exec_q.get_context(); - using KernelName = class gemm_batch_tree_nm_krn< - lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, - OuterInnerDimsIndexerT, BatchDimsIndexerT, wi_delta_m>; - cgh.parallel_for( - ndRange, - GemmBatchNoAtomicFunctorThreadNM< - lhsTy, rhsTy, resTy, 
LocAccT1, LocAccT2, - OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, - BatchDimsIndexerT, wi_delta_n, wi_delta_m>( - lhs_tp, rhs_tp, res_tp, local_A_block, - local_B_block, n, wg_delta_n, k, k_blocks, - wi_delta_k, m, m_blocks, wg_delta_m, batch_nelems, - batch_indexer, lhs_indexer, rhs_indexer, - res_indexer)); + cgh.host_task([ctx, tmp] { sycl::free(tmp, ctx); }); }); - return gemm_ev; - } - else { - using ReductionOpT = sycl::plus; - constexpr resTy identity_val = - sycl::known_identity::value; - size_t iter_nelems = batch_nelems * n * m; - size_t reduction_nelems = (k + wi_delta_k - 1) / wi_delta_k; - - // more than one work-group is needed, requires a temporary - // delta_k * n_wi elements processed along k, so if more to - // process use multiple - const auto &sg_sizes = - dev.get_info(); - size_t wg = - choose_workgroup_size<4>(reduction_nelems, sg_sizes); - - if (reduction_nelems < wg) { - resTy *tmp = sycl::malloc_device( - iter_nelems * reduction_nelems, exec_q); - sycl::event gemm_ev = exec_q.submit([&](sycl::handler - &cgh) { - cgh.depends_on(depends); - - using OuterInnerDimsIndexerT = - dpctl::tensor::offset_utils::StridedIndexer; - using TmpIndexerT = - dpctl::tensor::offset_utils::NoOpIndexer; - OuterInnerDimsIndexerT lhs_indexer( - inner_nd + lhs_outer_nd, 0, - lhs_outer_inner_shapes_strides); - OuterInnerDimsIndexerT rhs_indexer( - inner_nd + rhs_outer_nd, 0, - rhs_outer_inner_shapes_strides); - TmpIndexerT res_indexer{}; - using dpctl::tensor::offset_utils::StridedIndexer; - using dpctl::tensor::offset_utils:: - UnpackedStridedIndexer; - using dpctl::tensor::offset_utils::Strided1DIndexer; - using dpctl::tensor::offset_utils:: - ThreeOffsets_CombinedIndexer; - using BatchDimsIndexerT = - ThreeOffsets_CombinedIndexer; - StridedIndexer lhs_batch_indexer( - batch_nd, lhs_batch_offset, batch_shape_strides); - UnpackedStridedIndexer rhs_batch_indexer( - batch_nd, rhs_batch_offset, batch_shape_strides, - batch_shape_strides + 2 * batch_nd); - Strided1DIndexer tmp_batch_indexer( - 0, static_cast(batch_nelems), n * m); - BatchDimsIndexerT batch_indexer(lhs_batch_indexer, - rhs_batch_indexer, - tmp_batch_indexer); - - size_t lws = wg_delta_n * wg_delta_m; - - size_t n_blocks = ((n + wi_delta_n * wg_delta_n - 1) / - (wi_delta_n * wg_delta_n)); - size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); - size_t m_blocks = ((m + wi_delta_m * wg_delta_m - 1) / - (wi_delta_m * wg_delta_m)); - - auto gwsRange = - sycl::range<1>(batch_nelems * n_blocks * m_blocks * - k_blocks * lws); - auto lwsRange = sycl::range<1>(lws); - - auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); - - using LocAccT1 = sycl::local_accessor; - LocAccT1 local_A_block( - sycl::range<1>((wi_delta_n * wg_delta_n) * - wi_delta_k), - cgh); - using LocAccT2 = - sycl::local_accessor, - 1>; - LocAccT2 local_B_block( - sycl::range<1>(wi_delta_k * wg_delta_m), cgh); - - using KernelName = class gemm_batch_tree_nm_krn< - lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, - TmpIndexerT, BatchDimsIndexerT, wi_delta_m>; - cgh.parallel_for( - ndRange, - GemmBatchNoAtomicFunctorThreadNM< - lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, - OuterInnerDimsIndexerT, TmpIndexerT, - BatchDimsIndexerT, wi_delta_n, wi_delta_m>( - lhs_tp, rhs_tp, tmp, local_A_block, - local_B_block, n, wg_delta_n, k, k_blocks, - wi_delta_k, m, m_blocks, wg_delta_m, - batch_nelems, batch_indexer, lhs_indexer, - rhs_indexer, res_indexer)); - }); - sycl::event red_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(gemm_ev); - - using NoOpIndexerT = - 
dpctl::tensor::offset_utils::NoOpIndexer; - using ResIndexerT = - dpctl::tensor::offset_utils::StridedIndexer; - using InputOutputIterIndexerT = - dpctl::tensor::offset_utils:: - TwoOffsets_CombinedIndexer; - using ReductionIndexerT = - dpctl::tensor::offset_utils::Strided1DIndexer; - - ResIndexerT res_iter_indexer{ - batch_nd + res_outer_nd, - static_cast(res_batch_offset), - res_shape_strides}; - InputOutputIterIndexerT in_out_iter_indexer{ - NoOpIndexerT{}, res_iter_indexer}; - ReductionIndexerT reduction_indexer{ - 0, static_cast(reduction_nelems), - static_cast(iter_nelems)}; - - sycl::range<1> iter_range{iter_nelems}; - - cgh.parallel_for>( - iter_range, - SequentialReduction( - tmp, res_tp, ReductionOpT(), identity_val, - in_out_iter_indexer, reduction_indexer, - reduction_nelems)); - }); - sycl::event cleanup_host_task_event = - exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(red_ev); - const sycl::context &ctx = exec_q.get_context(); - - cgh.host_task([ctx, tmp] { sycl::free(tmp, ctx); }); - }); - return cleanup_host_task_event; - } + return cleanup_host_task_event; + } - // max_max_wg prevents running out of resources on CPU - constexpr size_t max_max_wg = 2048; - size_t max_wg = std::min( - max_max_wg, - dev.get_info() / - 2); - - constexpr size_t preferred_reductions_per_wi = 4; - size_t reductions_per_wi(preferred_reductions_per_wi); - - size_t reduction_groups = - (reduction_nelems + preferred_reductions_per_wi * wg - 1) / - (preferred_reductions_per_wi * wg); - assert(reduction_groups > 1); - - resTy *partially_reduced_tmp = sycl::malloc_device( - iter_nelems * (/* temp */ reduction_nelems + - /* first reduction temp */ reduction_groups), - exec_q); - resTy *partially_reduced_tmp2 = nullptr; - - if (partially_reduced_tmp == nullptr) { - throw std::runtime_error( - "Unable to allocate device_memory"); - } - else { - partially_reduced_tmp2 = - partially_reduced_tmp + reduction_nelems * iter_nelems; - } + // max_max_wg prevents running out of resources on CPU + constexpr size_t max_max_wg = 2048; + size_t max_wg = std::min( + max_max_wg, + dev.get_info() / 2); - sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(depends); - - using OuterInnerDimsIndexerT = - dpctl::tensor::offset_utils::StridedIndexer; - using TmpIndexerT = - dpctl::tensor::offset_utils::NoOpIndexer; - OuterInnerDimsIndexerT lhs_indexer( - inner_nd + lhs_outer_nd, 0, - lhs_outer_inner_shapes_strides); - OuterInnerDimsIndexerT rhs_indexer( - inner_nd + rhs_outer_nd, 0, - rhs_outer_inner_shapes_strides); - TmpIndexerT res_indexer{}; - using dpctl::tensor::offset_utils::StridedIndexer; - using dpctl::tensor::offset_utils::UnpackedStridedIndexer; - using dpctl::tensor::offset_utils::Strided1DIndexer; - using dpctl::tensor::offset_utils:: - ThreeOffsets_CombinedIndexer; - using BatchDimsIndexerT = - ThreeOffsets_CombinedIndexer; - StridedIndexer lhs_batch_indexer(batch_nd, lhs_batch_offset, - batch_shape_strides); - UnpackedStridedIndexer rhs_batch_indexer( - batch_nd, rhs_batch_offset, batch_shape_strides, - batch_shape_strides + 2 * batch_nd); - Strided1DIndexer tmp_batch_indexer( - 0, static_cast(batch_nelems), n * m); - BatchDimsIndexerT batch_indexer(lhs_batch_indexer, - rhs_batch_indexer, - tmp_batch_indexer); - - size_t lws = wg_delta_n * wg_delta_m; - - size_t n_blocks = ((n + wi_delta_n * wg_delta_n - 1) / - (wi_delta_n * wg_delta_n)); - size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); - size_t m_blocks = ((m + wi_delta_m * wg_delta_m - 1) / - (wi_delta_m * 
wg_delta_m)); - - auto gwsRange = sycl::range<1>(batch_nelems * n_blocks * - m_blocks * k_blocks * lws); - auto lwsRange = sycl::range<1>(lws); - - auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); + constexpr size_t preferred_reductions_per_wi = 4; + size_t reductions_per_wi(preferred_reductions_per_wi); - using LocAccT1 = sycl::local_accessor; - LocAccT1 local_A_block( - sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), - cgh); - using LocAccT2 = - sycl::local_accessor, 1>; - LocAccT2 local_B_block( - sycl::range<1>(wi_delta_k * wg_delta_m), cgh); + size_t reduction_groups = + (reduction_nelems + preferred_reductions_per_wi * wg - 1) / + (preferred_reductions_per_wi * wg); + assert(reduction_groups > 1); - using KernelName = class gemm_batch_tree_nm_krn< - lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, - TmpIndexerT, BatchDimsIndexerT, wi_delta_m>; - cgh.parallel_for( - ndRange, - GemmBatchNoAtomicFunctorThreadNM< - lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, - OuterInnerDimsIndexerT, TmpIndexerT, - BatchDimsIndexerT, wi_delta_n, wi_delta_m>( - lhs_tp, rhs_tp, partially_reduced_tmp, - local_A_block, local_B_block, n, wg_delta_n, k, - k_blocks, wi_delta_k, m, m_blocks, wg_delta_m, - batch_nelems, batch_indexer, lhs_indexer, - rhs_indexer, res_indexer)); + resTy *partially_reduced_tmp = sycl::malloc_device( + iter_nelems * (/* temp */ reduction_nelems + + /* first reduction temp */ reduction_groups), + exec_q); + resTy *partially_reduced_tmp2 = nullptr; + + if (partially_reduced_tmp == nullptr) { + throw std::runtime_error("Unable to allocate device_memory"); + } + else { + partially_reduced_tmp2 = + partially_reduced_tmp + reduction_nelems * iter_nelems; + } + + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::StridedIndexer; + using TmpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + OuterInnerDimsIndexerT lhs_indexer(inner_nd + lhs_outer_nd, 0, + lhs_outer_inner_shapes_strides); + OuterInnerDimsIndexerT rhs_indexer(inner_nd + rhs_outer_nd, 0, + rhs_outer_inner_shapes_strides); + TmpIndexerT res_indexer{}; + using dpctl::tensor::offset_utils::StridedIndexer; + using dpctl::tensor::offset_utils::Strided1DIndexer; + using dpctl::tensor::offset_utils::ThreeOffsets_CombinedIndexer; + using BatchDimsIndexerT = + ThreeOffsets_CombinedIndexer; + StridedIndexer lhs_batch_indexer(batch_nd, lhs_batch_offset, + batch_shape_strides); + StridedIndexer rhs_batch_indexer( + batch_nd, rhs_batch_offset, batch_shape_strides + 2 * batch_nd); + Strided1DIndexer tmp_batch_indexer( + 0, static_cast(batch_nelems), n * m); + BatchDimsIndexerT batch_indexer( + lhs_batch_indexer, rhs_batch_indexer, tmp_batch_indexer); + + size_t n_blocks = (n + delta_n - 1) / delta_n; + size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); + size_t m_blocks = (m + m_groups - 1) / m_groups; + + size_t lws = delta_n * delta_k; + + auto gRange = sycl::range<1>(batch_nelems * n_blocks * m_blocks * + k_blocks * lws); + auto lRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gRange, lRange); + + if constexpr (m_groups == 1) { + using LocAccT = sycl::local_accessor; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + + using KernelName = class gemm_batch_tree_k_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, TmpIndexerT, + BatchDimsIndexerT, m_groups>; + + cgh.parallel_for( + ndRange, + GemmBatchNoAtomicFunctorThreadK< + lhsTy, rhsTy, resTy, LocAccT, 
OuterInnerDimsIndexerT, + TmpIndexerT, BatchDimsIndexerT, m_groups>( + lhs_tp, rhs_tp, partially_reduced_tmp, workspace, + local_B_block, n, n_blocks, delta_n, k, k_blocks, + delta_k, n_wi, m, batch_nelems, batch_indexer, + lhs_indexer, rhs_indexer, res_indexer)); + } + else { + using LocAccT = + sycl::local_accessor, 1>; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + + using KernelName = class gemm_batch_tree_k_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, TmpIndexerT, + BatchDimsIndexerT, m_groups>; + cgh.parallel_for( + ndRange, + GemmBatchNoAtomicFunctorThreadK< + lhsTy, rhsTy, resTy, LocAccT, OuterInnerDimsIndexerT, + TmpIndexerT, BatchDimsIndexerT, m_groups>( + lhs_tp, rhs_tp, partially_reduced_tmp, workspace, + local_B_block, n, n_blocks, delta_n, k, k_blocks, + delta_k, n_wi, m, batch_nelems, batch_indexer, + lhs_indexer, rhs_indexer, res_indexer)); + } + }); + + sycl::event red_ev = tree_reduction_for_gemm( + exec_q, partially_reduced_tmp, partially_reduced_tmp2, res_tp, + identity_val, iter_nelems, reduction_nelems, reduction_groups, wg, + max_wg, preferred_reductions_per_wi, reductions_per_wi, + batch_nd + res_outer_nd, res_batch_offset, res_shape_strides, + {gemm_ev}); + + sycl::event cleanup_host_task_event = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(red_ev); + const sycl::context &ctx = exec_q.get_context(); + + cgh.host_task([ctx, partially_reduced_tmp] { + sycl::free(partially_reduced_tmp, ctx); }); + }); + + return cleanup_host_task_event; + } +} + +template +sycl::event +gemm_batch_tree_nm_impl(sycl::queue &exec_q, + const lhsTy *lhs_tp, + const rhsTy *rhs_tp, + resTy *res_tp, + size_t batch_nelems, + size_t n, + size_t k, + size_t m, + int batch_nd, + const py::ssize_t *batch_shape_strides, + py::ssize_t lhs_batch_offset, + py::ssize_t rhs_batch_offset, + py::ssize_t res_batch_offset, + int inner_nd, + int lhs_outer_nd, + const py::ssize_t *lhs_outer_inner_shapes_strides, + int rhs_outer_nd, + const py::ssize_t *rhs_outer_inner_shapes_strides, + int res_outer_nd, + const py::ssize_t *res_outer_shapes_strides, + const py::ssize_t *res_shape_strides, + std::vector const &depends) +{ + constexpr int wi_delta_n = 2; + size_t wg_delta_n(16); // rows of A processed in WG + size_t wg_delta_m(16); // rows of B processed in WG + size_t wi_delta_k(64); // Elements in K dimension processed by WI + + const sycl::device &dev = exec_q.get_device(); + const size_t local_mem_size = + dev.get_info(); + const size_t reserved_slm_size = 512; + + gemm_detail::scale_gemm_nm_parameters( + local_mem_size, reserved_slm_size, wi_delta_n, + wi_delta_k, // modified by reference + wg_delta_n, // modified by reference + wg_delta_m // modified by reference + ); + + // each group processes delta_k * n_wi + // items in a column, so no need for allocating + // temp memory if only one group is needed + if (k <= wi_delta_k) { + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::StridedIndexer; + OuterInnerDimsIndexerT lhs_indexer(inner_nd + lhs_outer_nd, 0, + lhs_outer_inner_shapes_strides); + OuterInnerDimsIndexerT rhs_indexer(inner_nd + rhs_outer_nd, 0, + rhs_outer_inner_shapes_strides); + OuterInnerDimsIndexerT res_indexer(res_outer_nd, 0, + res_outer_shapes_strides); + using BatchDimsIndexerT = + dpctl::tensor::offset_utils::ThreeOffsets_StridedIndexer; + BatchDimsIndexerT batch_indexer(batch_nd, lhs_batch_offset, + 
rhs_batch_offset, res_batch_offset, + batch_shape_strides); + + size_t lws = wg_delta_n * wg_delta_m; + + size_t n_blocks = + ((n + wi_delta_n * wg_delta_n - 1) / (wi_delta_n * wg_delta_n)); + size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); + size_t m_blocks = + ((m + wi_delta_m * wg_delta_m - 1) / (wi_delta_m * wg_delta_m)); + + auto gwsRange = sycl::range<1>(batch_nelems * n_blocks * m_blocks * + k_blocks * lws); + auto lwsRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); + + if constexpr (wi_delta_m == 1) { + using LocAccT1 = sycl::local_accessor; + LocAccT1 local_A_block( + sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), + cgh); + using LocAccT2 = sycl::local_accessor; + LocAccT2 local_B_block(sycl::range<1>(wi_delta_k * wg_delta_m), + cgh); - sycl::event red_ev = - tree_reduction_for_gemm( - exec_q, partially_reduced_tmp, partially_reduced_tmp2, - res_tp, identity_val, iter_nelems, reduction_nelems, - reduction_groups, wg, max_wg, - preferred_reductions_per_wi, reductions_per_wi, - batch_nd + res_outer_nd, res_batch_offset, - res_shape_strides, {gemm_ev}); - - sycl::event cleanup_host_task_event = - exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(red_ev); - const sycl::context &ctx = exec_q.get_context(); - - cgh.host_task([ctx, partially_reduced_tmp] { - sycl::free(partially_reduced_tmp, ctx); - }); - }); - - return cleanup_host_task_event; + using KernelName = class gemm_batch_tree_nm_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, BatchDimsIndexerT, wi_delta_m>; + cgh.parallel_for( + ndRange, + GemmBatchNoAtomicFunctorThreadNM< + lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, + BatchDimsIndexerT, wi_delta_n, wi_delta_m>( + lhs_tp, rhs_tp, res_tp, local_A_block, local_B_block, n, + wg_delta_n, k, k_blocks, wi_delta_k, m, m_blocks, + wg_delta_m, batch_nelems, batch_indexer, lhs_indexer, + rhs_indexer, res_indexer)); } - } - else { // m > 1, n > k or m > k, resTy complex - constexpr int wi_delta_m = 1; + else { + using LocAccT1 = sycl::local_accessor; + LocAccT1 local_A_block( + sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), + cgh); + using LocAccT2 = + sycl::local_accessor, 1>; + LocAccT2 local_B_block(sycl::range<1>(wi_delta_k * wg_delta_m), + cgh); - gemm_detail::scale_gemm_nm_parameters( - local_mem_size, reserved_slm_size, wi_delta_n, - wi_delta_k, // modified by reference - wg_delta_n, // modified by reference - wg_delta_m // modified by reference - ); + using KernelName = class gemm_batch_tree_nm_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, BatchDimsIndexerT, wi_delta_m>; + cgh.parallel_for( + ndRange, + GemmBatchNoAtomicFunctorThreadNM< + lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, + BatchDimsIndexerT, wi_delta_n, wi_delta_m>( + lhs_tp, rhs_tp, res_tp, local_A_block, local_B_block, n, + wg_delta_n, k, k_blocks, wi_delta_k, m, m_blocks, + wg_delta_m, batch_nelems, batch_indexer, lhs_indexer, + rhs_indexer, res_indexer)); + } + }); + return gemm_ev; + } + else { + using ReductionOpT = sycl::plus; + constexpr resTy identity_val = + sycl::known_identity::value; + size_t iter_nelems = batch_nelems * n * m; + size_t reduction_nelems = (k + wi_delta_k - 1) / wi_delta_k; + + // more than one work-group is needed, requires a temporary + // delta_k * n_wi elements processed along k, so if more to + // process use multiple + const auto &sg_sizes = + dev.get_info(); + 
size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes); + + if (reduction_nelems < wg) { + resTy *tmp = sycl::malloc_device( + iter_nelems * reduction_nelems, exec_q); + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::StridedIndexer; + using TmpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + OuterInnerDimsIndexerT lhs_indexer( + inner_nd + lhs_outer_nd, 0, lhs_outer_inner_shapes_strides); + OuterInnerDimsIndexerT rhs_indexer( + inner_nd + rhs_outer_nd, 0, rhs_outer_inner_shapes_strides); + TmpIndexerT res_indexer{}; + using dpctl::tensor::offset_utils::StridedIndexer; + using dpctl::tensor::offset_utils::UnpackedStridedIndexer; + using dpctl::tensor::offset_utils::Strided1DIndexer; + using dpctl::tensor::offset_utils::ThreeOffsets_CombinedIndexer; + using BatchDimsIndexerT = ThreeOffsets_CombinedIndexer< + StridedIndexer, UnpackedStridedIndexer, Strided1DIndexer>; + StridedIndexer lhs_batch_indexer(batch_nd, lhs_batch_offset, + batch_shape_strides); + UnpackedStridedIndexer rhs_batch_indexer( + batch_nd, rhs_batch_offset, batch_shape_strides, + batch_shape_strides + 2 * batch_nd); + Strided1DIndexer tmp_batch_indexer( + 0, static_cast(batch_nelems), n * m); + BatchDimsIndexerT batch_indexer( + lhs_batch_indexer, rhs_batch_indexer, tmp_batch_indexer); + + size_t lws = wg_delta_n * wg_delta_m; + + size_t n_blocks = ((n + wi_delta_n * wg_delta_n - 1) / + (wi_delta_n * wg_delta_n)); + size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); + size_t m_blocks = ((m + wi_delta_m * wg_delta_m - 1) / + (wi_delta_m * wg_delta_m)); + + auto gwsRange = sycl::range<1>(batch_nelems * n_blocks * + m_blocks * k_blocks * lws); + auto lwsRange = sycl::range<1>(lws); - if (k <= wi_delta_k) { - sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(depends); - - using OuterInnerDimsIndexerT = - dpctl::tensor::offset_utils::StridedIndexer; - OuterInnerDimsIndexerT lhs_indexer( - inner_nd + lhs_outer_nd, 0, - lhs_outer_inner_shapes_strides); - OuterInnerDimsIndexerT rhs_indexer( - inner_nd + rhs_outer_nd, 0, - rhs_outer_inner_shapes_strides); - OuterInnerDimsIndexerT res_indexer( - res_outer_nd, 0, res_outer_shapes_strides); - using BatchDimsIndexerT = dpctl::tensor::offset_utils:: - ThreeOffsets_StridedIndexer; - BatchDimsIndexerT batch_indexer( - batch_nd, lhs_batch_offset, rhs_batch_offset, - res_batch_offset, batch_shape_strides); - - size_t lws = wg_delta_n * wg_delta_m; - - size_t n_blocks = ((n + wi_delta_n * wg_delta_n - 1) / - (wi_delta_n * wg_delta_n)); - size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); - size_t m_blocks = ((m + wi_delta_m * wg_delta_m - 1) / - (wi_delta_m * wg_delta_m)); - - auto gwsRange = sycl::range<1>(batch_nelems * n_blocks * - m_blocks * k_blocks * lws); - auto lwsRange = sycl::range<1>(lws); - - auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); + auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); + if constexpr (wi_delta_m == 1) { using LocAccT1 = sycl::local_accessor; LocAccT1 local_A_block( sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), @@ -6579,244 +5699,25 @@ gemm_batch_tree_impl(sycl::queue &exec_q, using KernelName = class gemm_batch_tree_nm_krn< lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, - OuterInnerDimsIndexerT, BatchDimsIndexerT, wi_delta_m>; + TmpIndexerT, BatchDimsIndexerT, wi_delta_m>; cgh.parallel_for( ndRange, GemmBatchNoAtomicFunctorThreadNM< lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, - 
OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, TmpIndexerT, BatchDimsIndexerT, wi_delta_n, wi_delta_m>( - lhs_tp, rhs_tp, res_tp, local_A_block, - local_B_block, n, wg_delta_n, k, k_blocks, - wi_delta_k, m, m_blocks, wg_delta_m, batch_nelems, - batch_indexer, lhs_indexer, rhs_indexer, - res_indexer)); - }); - return gemm_ev; - } - else { - using ReductionOpT = sycl::plus; - constexpr resTy identity_val = - sycl::known_identity::value; - size_t iter_nelems = batch_nelems * n * m; - size_t reduction_nelems = (k + wi_delta_k - 1) / wi_delta_k; - - // more than one work-group is needed, requires a temporary - // delta_k * n_wi elements processed along k, so if more to - // process use multiple - const auto &sg_sizes = - dev.get_info(); - size_t wg = - choose_workgroup_size<4>(reduction_nelems, sg_sizes); - - if (reduction_nelems < wg) { - resTy *tmp = sycl::malloc_device( - iter_nelems * reduction_nelems, exec_q); - sycl::event gemm_ev = exec_q.submit([&](sycl::handler - &cgh) { - cgh.depends_on(depends); - - using OuterInnerDimsIndexerT = - dpctl::tensor::offset_utils::StridedIndexer; - using TmpIndexerT = - dpctl::tensor::offset_utils::NoOpIndexer; - OuterInnerDimsIndexerT lhs_indexer( - inner_nd + lhs_outer_nd, 0, - lhs_outer_inner_shapes_strides); - OuterInnerDimsIndexerT rhs_indexer( - inner_nd + rhs_outer_nd, 0, - rhs_outer_inner_shapes_strides); - TmpIndexerT res_indexer{}; - using dpctl::tensor::offset_utils::StridedIndexer; - using dpctl::tensor::offset_utils:: - UnpackedStridedIndexer; - using dpctl::tensor::offset_utils::Strided1DIndexer; - using dpctl::tensor::offset_utils:: - ThreeOffsets_CombinedIndexer; - using BatchDimsIndexerT = - ThreeOffsets_CombinedIndexer; - StridedIndexer lhs_batch_indexer( - batch_nd, lhs_batch_offset, batch_shape_strides); - UnpackedStridedIndexer rhs_batch_indexer( - batch_nd, rhs_batch_offset, batch_shape_strides, - batch_shape_strides + 2 * batch_nd); - Strided1DIndexer tmp_batch_indexer( - 0, static_cast(batch_nelems), n * m); - BatchDimsIndexerT batch_indexer(lhs_batch_indexer, - rhs_batch_indexer, - tmp_batch_indexer); - - size_t lws = wg_delta_n * wg_delta_m; - - size_t n_blocks = ((n + wi_delta_n * wg_delta_n - 1) / - (wi_delta_n * wg_delta_n)); - size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); - size_t m_blocks = ((m + wi_delta_m * wg_delta_m - 1) / - (wi_delta_m * wg_delta_m)); - - auto gwsRange = - sycl::range<1>(batch_nelems * n_blocks * m_blocks * - k_blocks * lws); - auto lwsRange = sycl::range<1>(lws); - - auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); - - using LocAccT1 = sycl::local_accessor; - LocAccT1 local_A_block( - sycl::range<1>((wi_delta_n * wg_delta_n) * - wi_delta_k), - cgh); - using LocAccT2 = sycl::local_accessor; - LocAccT2 local_B_block( - sycl::range<1>(wi_delta_k * wg_delta_m), cgh); - - using KernelName = class gemm_batch_tree_nm_krn< - lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, - TmpIndexerT, BatchDimsIndexerT, wi_delta_m>; - cgh.parallel_for( - ndRange, - GemmBatchNoAtomicFunctorThreadNM< - lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, - OuterInnerDimsIndexerT, TmpIndexerT, - BatchDimsIndexerT, wi_delta_n, wi_delta_m>( - lhs_tp, rhs_tp, tmp, local_A_block, - local_B_block, n, wg_delta_n, k, k_blocks, - wi_delta_k, m, m_blocks, wg_delta_m, - batch_nelems, batch_indexer, lhs_indexer, - rhs_indexer, res_indexer)); - }); - sycl::event red_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(gemm_ev); - - using NoOpIndexerT = - dpctl::tensor::offset_utils::NoOpIndexer; - using 
ResIndexerT = - dpctl::tensor::offset_utils::StridedIndexer; - using InputOutputIterIndexerT = - dpctl::tensor::offset_utils:: - TwoOffsets_CombinedIndexer; - using ReductionIndexerT = - dpctl::tensor::offset_utils::Strided1DIndexer; - - ResIndexerT res_iter_indexer{ - batch_nd + res_outer_nd, - static_cast(res_batch_offset), - res_shape_strides}; - InputOutputIterIndexerT in_out_iter_indexer{ - NoOpIndexerT{}, res_iter_indexer}; - ReductionIndexerT reduction_indexer{ - 0, static_cast(reduction_nelems), - static_cast(iter_nelems)}; - - sycl::range<1> iter_range{iter_nelems}; - - cgh.parallel_for>( - iter_range, - SequentialReduction( - tmp, res_tp, ReductionOpT(), identity_val, - in_out_iter_indexer, reduction_indexer, - reduction_nelems)); - }); - sycl::event cleanup_host_task_event = - exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(red_ev); - const sycl::context &ctx = exec_q.get_context(); - - cgh.host_task([ctx, tmp] { sycl::free(tmp, ctx); }); - }); - return cleanup_host_task_event; - } - - // max_max_wg prevents running out of resources on CPU - constexpr size_t max_max_wg = 2048; - size_t max_wg = std::min( - max_max_wg, - dev.get_info() / - 2); - - constexpr size_t preferred_reductions_per_wi = 4; - size_t reductions_per_wi(preferred_reductions_per_wi); - - size_t reduction_groups = - (reduction_nelems + preferred_reductions_per_wi * wg - 1) / - (preferred_reductions_per_wi * wg); - assert(reduction_groups > 1); - - resTy *partially_reduced_tmp = sycl::malloc_device( - iter_nelems * (/* temp */ reduction_nelems + - /* first reduction temp */ reduction_groups), - exec_q); - resTy *partially_reduced_tmp2 = nullptr; - - if (partially_reduced_tmp == nullptr) { - throw std::runtime_error( - "Unable to allocate device_memory"); + lhs_tp, rhs_tp, tmp, local_A_block, local_B_block, + n, wg_delta_n, k, k_blocks, wi_delta_k, m, m_blocks, + wg_delta_m, batch_nelems, batch_indexer, + lhs_indexer, rhs_indexer, res_indexer)); } else { - partially_reduced_tmp2 = - partially_reduced_tmp + reduction_nelems * iter_nelems; - } - - sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(depends); - - using OuterInnerDimsIndexerT = - dpctl::tensor::offset_utils::StridedIndexer; - using TmpIndexerT = - dpctl::tensor::offset_utils::NoOpIndexer; - OuterInnerDimsIndexerT lhs_indexer( - inner_nd + lhs_outer_nd, 0, - lhs_outer_inner_shapes_strides); - OuterInnerDimsIndexerT rhs_indexer( - inner_nd + rhs_outer_nd, 0, - rhs_outer_inner_shapes_strides); - TmpIndexerT res_indexer{}; - using dpctl::tensor::offset_utils::StridedIndexer; - using dpctl::tensor::offset_utils::UnpackedStridedIndexer; - using dpctl::tensor::offset_utils::Strided1DIndexer; - using dpctl::tensor::offset_utils:: - ThreeOffsets_CombinedIndexer; - using BatchDimsIndexerT = - ThreeOffsets_CombinedIndexer; - StridedIndexer lhs_batch_indexer(batch_nd, lhs_batch_offset, - batch_shape_strides); - UnpackedStridedIndexer rhs_batch_indexer( - batch_nd, rhs_batch_offset, batch_shape_strides, - batch_shape_strides + 2 * batch_nd); - Strided1DIndexer tmp_batch_indexer( - 0, static_cast(batch_nelems), n * m); - BatchDimsIndexerT batch_indexer(lhs_batch_indexer, - rhs_batch_indexer, - tmp_batch_indexer); - - size_t lws = wg_delta_n * wg_delta_m; - - size_t n_blocks = ((n + wi_delta_n * wg_delta_n - 1) / - (wi_delta_n * wg_delta_n)); - size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); - size_t m_blocks = ((m + wi_delta_m * wg_delta_m - 1) / - (wi_delta_m * wg_delta_m)); - - auto gwsRange = sycl::range<1>(batch_nelems 
* n_blocks * - m_blocks * k_blocks * lws); - auto lwsRange = sycl::range<1>(lws); - - auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); - using LocAccT1 = sycl::local_accessor; LocAccT1 local_A_block( sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), cgh); - using LocAccT2 = sycl::local_accessor; + using LocAccT2 = + sycl::local_accessor, 1>; LocAccT2 local_B_block( sycl::range<1>(wi_delta_k * wg_delta_m), cgh); @@ -6829,765 +5730,433 @@ gemm_batch_tree_impl(sycl::queue &exec_q, lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, OuterInnerDimsIndexerT, TmpIndexerT, BatchDimsIndexerT, wi_delta_n, wi_delta_m>( - lhs_tp, rhs_tp, partially_reduced_tmp, - local_A_block, local_B_block, n, wg_delta_n, k, - k_blocks, wi_delta_k, m, m_blocks, wg_delta_m, - batch_nelems, batch_indexer, lhs_indexer, - rhs_indexer, res_indexer)); - }); - - sycl::event red_ev = - tree_reduction_for_gemm( - exec_q, partially_reduced_tmp, partially_reduced_tmp2, - res_tp, identity_val, iter_nelems, reduction_nelems, - reduction_groups, wg, max_wg, - preferred_reductions_per_wi, reductions_per_wi, - batch_nd + res_outer_nd, res_batch_offset, - res_shape_strides, {gemm_ev}); - - sycl::event cleanup_host_task_event = - exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(red_ev); - const sycl::context &ctx = exec_q.get_context(); - - cgh.host_task([ctx, partially_reduced_tmp] { - sycl::free(partially_reduced_tmp, ctx); - }); - }); - - return cleanup_host_task_event; - } - } - } -} + lhs_tp, rhs_tp, tmp, local_A_block, local_B_block, + n, wg_delta_n, k, k_blocks, wi_delta_k, m, m_blocks, + wg_delta_m, batch_nelems, batch_indexer, + lhs_indexer, rhs_indexer, res_indexer)); + } + }); + sycl::event red_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(gemm_ev); -template -sycl::event -gemm_batch_contig_tree_impl(sycl::queue &exec_q, - const char *lhs_cp, - const char *rhs_cp, - char *res_cp, - size_t batch_nelems, - size_t n, - size_t k, - size_t m, - py::ssize_t lhs_batch_offset, - py::ssize_t rhs_batch_offset, - py::ssize_t res_batch_offset, - std::vector const &depends = {}) -{ - const lhsTy *lhs_tp = - reinterpret_cast(lhs_cp) + lhs_batch_offset; - const rhsTy *rhs_tp = - reinterpret_cast(rhs_cp) + rhs_batch_offset; - resTy *res_tp = reinterpret_cast(res_cp) + res_batch_offset; + using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using ResIndexerT = dpctl::tensor::offset_utils::StridedIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + NoOpIndexerT, ResIndexerT>; + using ReductionIndexerT = + dpctl::tensor::offset_utils::Strided1DIndexer; - const sycl::device &dev = exec_q.get_device(); - const size_t local_mem_size = - dev.get_info(); - const size_t reserved_slm_size = 512; + ResIndexerT res_iter_indexer{ + batch_nd + res_outer_nd, + static_cast(res_batch_offset), + res_shape_strides}; + InputOutputIterIndexerT in_out_iter_indexer{NoOpIndexerT{}, + res_iter_indexer}; + ReductionIndexerT reduction_indexer{ + 0, static_cast(reduction_nelems), + static_cast(iter_nelems)}; - if ((k > n && k > m) || m == 1) { - size_t delta_k(4); - size_t n_wi(4); - size_t delta_n(4); + sycl::range<1> iter_range{iter_nelems}; - using dpctl::tensor::type_utils::is_complex; - if constexpr (!is_complex::value) { - if (m == 1) { - constexpr int m_groups = 1; - gemm_detail::scale_gemm_k_parameters( - local_mem_size, reserved_slm_size, delta_k, - n_wi, // modified by reference - delta_n // modified by reference - ); - - if (k <= (delta_k * n_wi)) { - sycl::event gemm_ev 
= exec_q.submit([&](sycl::handler - &cgh) { - cgh.depends_on(depends); - - using OuterInnerDimsIndexerT = - dpctl::tensor::offset_utils::NoOpIndexer; - OuterInnerDimsIndexerT lhs_indexer{}; - OuterInnerDimsIndexerT rhs_indexer{}; - OuterInnerDimsIndexerT res_indexer{}; - using dpctl::tensor::offset_utils:: - ThreeOffsets_CombinedIndexer; - using dpctl::tensor::offset_utils::Strided1DIndexer; - using BatchDimsIndexerT = - ThreeOffsets_CombinedIndexer; - - using dpctl::tensor::offset_utils::Strided1DIndexer; - BatchDimsIndexerT batch_indexer( - Strided1DIndexer{ - 0, static_cast(batch_nelems), - static_cast(n * k)}, - Strided1DIndexer{ - 0, static_cast(batch_nelems), - static_cast(k * m)}, - Strided1DIndexer{ - 0, static_cast(batch_nelems), - static_cast(n * m)}); - - size_t n_blocks = (n + delta_n - 1) / delta_n; - size_t k_blocks = - (k + n_wi * delta_k - 1) / (n_wi * delta_k); - size_t m_blocks = (m + m_groups - 1) / m_groups; - - size_t lws = delta_n * delta_k; - - auto gRange = sycl::range<1>(batch_nelems * n_blocks * - m_blocks * k_blocks * lws); - auto lRange = sycl::range<1>(lws); - - auto ndRange = sycl::nd_range<1>(gRange, lRange); - - using LocAccT = sycl::local_accessor; - LocAccT local_B_block(n_wi * delta_k, cgh); - LocAccT workspace(delta_n * delta_k, cgh); - - using KernelName = class gemm_batch_tree_k_krn< - lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, - OuterInnerDimsIndexerT, BatchDimsIndexerT, - m_groups>; - - cgh.parallel_for( - ndRange, - GemmBatchNoAtomicFunctorThreadK< - lhsTy, rhsTy, resTy, LocAccT, - OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, - BatchDimsIndexerT, m_groups>( - lhs_tp, rhs_tp, res_tp, workspace, - local_B_block, n, n_blocks, delta_n, k, - k_blocks, delta_k, n_wi, m, batch_nelems, - batch_indexer, lhs_indexer, rhs_indexer, - res_indexer)); - }); - return gemm_ev; - } - else { - using ReductionOpT = sycl::plus; - constexpr resTy identity_val = - sycl::known_identity::value; - - size_t iter_nelems = batch_nelems * n * m; - size_t reduction_nelems = - (k + delta_k * n_wi - 1) / (delta_k * n_wi); - - // more than one work-group is needed, requires a - // temporary delta_k * n_wi elements processed along k, - // so if more to process use multiple - const auto &sg_sizes = - dev.get_info(); - size_t wg = - choose_workgroup_size<4>(reduction_nelems, sg_sizes); - - if (reduction_nelems < wg) { - resTy *tmp = sycl::malloc_device( - iter_nelems * reduction_nelems, exec_q); - sycl::event gemm_ev = exec_q.submit([&](sycl::handler - &cgh) { - cgh.depends_on(depends); - - using OuterInnerDimsIndexerT = - dpctl::tensor::offset_utils::NoOpIndexer; - OuterInnerDimsIndexerT lhs_indexer{}; - OuterInnerDimsIndexerT rhs_indexer{}; - OuterInnerDimsIndexerT tmp_indexer{}; - using dpctl::tensor::offset_utils:: - ThreeOffsets_CombinedIndexer; - using dpctl::tensor::offset_utils::Strided1DIndexer; - using BatchDimsIndexerT = - ThreeOffsets_CombinedIndexer; - - BatchDimsIndexerT batch_indexer( - Strided1DIndexer{ - 0, static_cast(batch_nelems), - static_cast(n * k)}, - Strided1DIndexer{ - 0, static_cast(batch_nelems), - static_cast(k * m)}, - Strided1DIndexer{ - 0, static_cast(batch_nelems), - static_cast(n * m)}); - - size_t n_blocks = (n + delta_n - 1) / delta_n; - size_t k_blocks = - (k + n_wi * delta_k - 1) / (n_wi * delta_k); - size_t m_blocks = (m + m_groups - 1) / m_groups; - - size_t lws = delta_n * delta_k; - - auto gRange = - sycl::range<1>(batch_nelems * n_blocks * - m_blocks * k_blocks * lws); - auto lRange = sycl::range<1>(lws); - - auto ndRange = 
sycl::nd_range<1>(gRange, lRange); - - using LocAccT = sycl::local_accessor; - LocAccT local_B_block(n_wi * delta_k, cgh); - LocAccT workspace(delta_n * delta_k, cgh); - - using KernelName = class gemm_batch_tree_k_krn< - lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, - OuterInnerDimsIndexerT, BatchDimsIndexerT, - m_groups>; - cgh.parallel_for( - ndRange, - GemmBatchNoAtomicFunctorThreadK< - lhsTy, rhsTy, resTy, LocAccT, - OuterInnerDimsIndexerT, - OuterInnerDimsIndexerT, BatchDimsIndexerT, - m_groups>(lhs_tp, rhs_tp, tmp, workspace, - local_B_block, n, n_blocks, - delta_n, k, k_blocks, delta_k, - n_wi, m, batch_nelems, - batch_indexer, lhs_indexer, - rhs_indexer, tmp_indexer)); - }); - sycl::event red_ev = exec_q.submit([&](sycl::handler - &cgh) { - cgh.depends_on(gemm_ev); - - using NoOpIndexerT = - dpctl::tensor::offset_utils::NoOpIndexer; - using InputOutputIterIndexerT = - dpctl::tensor::offset_utils:: - TwoOffsets_CombinedIndexer; - using ReductionIndexerT = - dpctl::tensor::offset_utils::Strided1DIndexer; - - InputOutputIterIndexerT in_out_iter_indexer{ - NoOpIndexerT{}, NoOpIndexerT{}}; - ReductionIndexerT reduction_indexer{ - 0, static_cast(reduction_nelems), - static_cast(iter_nelems)}; - - sycl::range<1> iter_range{iter_nelems}; - - cgh.parallel_for< - class gemm_reduction_seq_strided_krn< - resTy, resTy, ReductionOpT, - InputOutputIterIndexerT, - ReductionIndexerT>>( - iter_range, - SequentialReduction>( + iter_range, SequentialReduction( tmp, res_tp, ReductionOpT(), identity_val, in_out_iter_indexer, reduction_indexer, reduction_nelems)); - }); - sycl::event cleanup_host_task_event = - exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(red_ev); - const sycl::context &ctx = exec_q.get_context(); - - cgh.host_task( - [ctx, tmp] { sycl::free(tmp, ctx); }); - }); - return cleanup_host_task_event; - } - - // max_max_wg prevents running out of resources on CPU - constexpr size_t max_max_wg = 2048; - size_t max_wg = std::min( - max_max_wg, - dev.get_info< - sycl::info::device::max_work_group_size>() / - 2); - - constexpr size_t preferred_reductions_per_wi = 4; - size_t reductions_per_wi(preferred_reductions_per_wi); - - size_t reduction_groups = - (reduction_nelems + preferred_reductions_per_wi * wg - - 1) / - (preferred_reductions_per_wi * wg); - assert(reduction_groups > 1); - - resTy *partially_reduced_tmp = sycl::malloc_device( - iter_nelems * - (/* temp */ reduction_nelems + - /* first reduction temp */ reduction_groups), - exec_q); - resTy *partially_reduced_tmp2 = nullptr; - - if (partially_reduced_tmp == nullptr) { - throw std::runtime_error( - "Unable to allocate device_memory"); - } - else { - partially_reduced_tmp2 = partially_reduced_tmp + - reduction_nelems * iter_nelems; - } - - sycl::event gemm_ev = exec_q.submit([&](sycl::handler - &cgh) { - cgh.depends_on(depends); - - using OuterInnerDimsIndexerT = - dpctl::tensor::offset_utils::NoOpIndexer; - OuterInnerDimsIndexerT lhs_indexer{}; - OuterInnerDimsIndexerT rhs_indexer{}; - OuterInnerDimsIndexerT tmp_indexer{}; - using dpctl::tensor::offset_utils:: - ThreeOffsets_CombinedIndexer; - using dpctl::tensor::offset_utils::Strided1DIndexer; - using BatchDimsIndexerT = - ThreeOffsets_CombinedIndexer; - - BatchDimsIndexerT batch_indexer( - Strided1DIndexer{ - 0, static_cast(batch_nelems), - static_cast(n * k)}, - Strided1DIndexer{ - 0, static_cast(batch_nelems), - static_cast(k * m)}, - Strided1DIndexer{ - 0, static_cast(batch_nelems), - static_cast(n * m)}); - - size_t n_blocks = (n + delta_n - 1) / delta_n; - size_t 
k_blocks = - (k + n_wi * delta_k - 1) / (n_wi * delta_k); - constexpr int m_groups = 1; - size_t m_blocks = (m + m_groups - 1) / m_groups; - - size_t lws = delta_n * delta_k; - - auto gRange = sycl::range<1>(batch_nelems * n_blocks * - m_blocks * k_blocks * lws); - auto lRange = sycl::range<1>(lws); - - auto ndRange = sycl::nd_range<1>(gRange, lRange); - - using LocAccT = sycl::local_accessor; - LocAccT local_B_block(n_wi * delta_k, cgh); - LocAccT workspace(delta_n * delta_k, cgh); - - using KernelName = class gemm_batch_tree_k_krn< - lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, - OuterInnerDimsIndexerT, BatchDimsIndexerT, - m_groups>; - cgh.parallel_for( - ndRange, - GemmBatchNoAtomicFunctorThreadK< - lhsTy, rhsTy, resTy, LocAccT, - OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, - BatchDimsIndexerT, m_groups>( - lhs_tp, rhs_tp, partially_reduced_tmp, - workspace, local_B_block, n, n_blocks, delta_n, - k, k_blocks, delta_k, n_wi, m, batch_nelems, - batch_indexer, lhs_indexer, rhs_indexer, - tmp_indexer)); - }); - - sycl::event red_ev = - tree_reduction_for_gemm_contig( - exec_q, partially_reduced_tmp, - partially_reduced_tmp2, res_tp, identity_val, - iter_nelems, reduction_nelems, reduction_groups, wg, - max_wg, preferred_reductions_per_wi, - reductions_per_wi, {gemm_ev}); - - sycl::event cleanup_host_task_event = - exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(red_ev); - const sycl::context &ctx = exec_q.get_context(); - - cgh.host_task([ctx, partially_reduced_tmp] { - sycl::free(partially_reduced_tmp, ctx); - }); - }); - - return cleanup_host_task_event; - } + }); + sycl::event cleanup_host_task_event = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(red_ev); + const sycl::context &ctx = exec_q.get_context(); + + cgh.host_task([ctx, tmp] { sycl::free(tmp, ctx); }); + }); + return cleanup_host_task_event; + } + + // max_max_wg prevents running out of resources on CPU + constexpr size_t max_max_wg = 2048; + size_t max_wg = std::min( + max_max_wg, + dev.get_info() / 2); + + constexpr size_t preferred_reductions_per_wi = 4; + size_t reductions_per_wi(preferred_reductions_per_wi); + + size_t reduction_groups = + (reduction_nelems + preferred_reductions_per_wi * wg - 1) / + (preferred_reductions_per_wi * wg); + assert(reduction_groups > 1); + + resTy *partially_reduced_tmp = sycl::malloc_device( + iter_nelems * (/* temp */ reduction_nelems + + /* first reduction temp */ reduction_groups), + exec_q); + resTy *partially_reduced_tmp2 = nullptr; + + if (partially_reduced_tmp == nullptr) { + throw std::runtime_error("Unable to allocate device_memory"); + } + else { + partially_reduced_tmp2 = + partially_reduced_tmp + reduction_nelems * iter_nelems; + } + + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::StridedIndexer; + using TmpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + OuterInnerDimsIndexerT lhs_indexer(inner_nd + lhs_outer_nd, 0, + lhs_outer_inner_shapes_strides); + OuterInnerDimsIndexerT rhs_indexer(inner_nd + rhs_outer_nd, 0, + rhs_outer_inner_shapes_strides); + TmpIndexerT res_indexer{}; + using dpctl::tensor::offset_utils::StridedIndexer; + using dpctl::tensor::offset_utils::UnpackedStridedIndexer; + using dpctl::tensor::offset_utils::Strided1DIndexer; + using dpctl::tensor::offset_utils::ThreeOffsets_CombinedIndexer; + using BatchDimsIndexerT = ThreeOffsets_CombinedIndexer< + StridedIndexer, UnpackedStridedIndexer, Strided1DIndexer>; + 
StridedIndexer lhs_batch_indexer(batch_nd, lhs_batch_offset, + batch_shape_strides); + UnpackedStridedIndexer rhs_batch_indexer( + batch_nd, rhs_batch_offset, batch_shape_strides, + batch_shape_strides + 2 * batch_nd); + Strided1DIndexer tmp_batch_indexer( + 0, static_cast(batch_nelems), n * m); + BatchDimsIndexerT batch_indexer( + lhs_batch_indexer, rhs_batch_indexer, tmp_batch_indexer); + + size_t lws = wg_delta_n * wg_delta_m; + + size_t n_blocks = + ((n + wi_delta_n * wg_delta_n - 1) / (wi_delta_n * wg_delta_n)); + size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); + size_t m_blocks = + ((m + wi_delta_m * wg_delta_m - 1) / (wi_delta_m * wg_delta_m)); + + auto gwsRange = sycl::range<1>(batch_nelems * n_blocks * m_blocks * + k_blocks * lws); + auto lwsRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); + + if constexpr (wi_delta_m == 1) { + using LocAccT1 = sycl::local_accessor; + LocAccT1 local_A_block( + sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), + cgh); + using LocAccT2 = sycl::local_accessor; + LocAccT2 local_B_block(sycl::range<1>(wi_delta_k * wg_delta_m), + cgh); + + using KernelName = class gemm_batch_tree_nm_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, TmpIndexerT, + BatchDimsIndexerT, wi_delta_m>; + cgh.parallel_for( + ndRange, + GemmBatchNoAtomicFunctorThreadNM< + lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, + OuterInnerDimsIndexerT, TmpIndexerT, BatchDimsIndexerT, + wi_delta_n, wi_delta_m>( + lhs_tp, rhs_tp, partially_reduced_tmp, local_A_block, + local_B_block, n, wg_delta_n, k, k_blocks, wi_delta_k, + m, m_blocks, wg_delta_m, batch_nelems, batch_indexer, + lhs_indexer, rhs_indexer, res_indexer)); } - else { - constexpr int m_groups = 2; - - gemm_detail::scale_gemm_k_parameters( - local_mem_size, reserved_slm_size, delta_k, - n_wi, // modified by reference - delta_n // modified by reference - ); - - if (k <= (delta_k * n_wi)) { - sycl::event gemm_ev = exec_q.submit([&](sycl::handler - &cgh) { - cgh.depends_on(depends); - - using OuterInnerDimsIndexerT = - dpctl::tensor::offset_utils::NoOpIndexer; - OuterInnerDimsIndexerT lhs_indexer{}; - OuterInnerDimsIndexerT rhs_indexer{}; - OuterInnerDimsIndexerT res_indexer{}; - using dpctl::tensor::offset_utils:: - ThreeOffsets_CombinedIndexer; - using dpctl::tensor::offset_utils::Strided1DIndexer; - using BatchDimsIndexerT = - ThreeOffsets_CombinedIndexer; - - BatchDimsIndexerT batch_indexer( - Strided1DIndexer{ - 0, static_cast(batch_nelems), - static_cast(n * k)}, - Strided1DIndexer{ - 0, static_cast(batch_nelems), - static_cast(k * m)}, - Strided1DIndexer{ - 0, static_cast(batch_nelems), - static_cast(n * m)}); - - size_t n_blocks = (n + delta_n - 1) / delta_n; - size_t k_blocks = - (k + n_wi * delta_k - 1) / (n_wi * delta_k); - size_t m_blocks = (m + m_groups - 1) / m_groups; - - size_t lws = delta_n * delta_k; - - auto gRange = sycl::range<1>(batch_nelems * n_blocks * - m_blocks * k_blocks * lws); - auto lRange = sycl::range<1>(lws); - - auto ndRange = sycl::nd_range<1>(gRange, lRange); - - using LocAccT = - sycl::local_accessor, 1>; - LocAccT local_B_block(n_wi * delta_k, cgh); - LocAccT workspace(delta_n * delta_k, cgh); - - using KernelName = class gemm_batch_tree_k_krn< - lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, - OuterInnerDimsIndexerT, BatchDimsIndexerT, - m_groups>; - - cgh.parallel_for( - ndRange, - GemmBatchNoAtomicFunctorThreadK< - lhsTy, rhsTy, resTy, LocAccT, - OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, - BatchDimsIndexerT, m_groups>( - lhs_tp, rhs_tp, 
res_tp, workspace, - local_B_block, n, n_blocks, delta_n, k, - k_blocks, delta_k, n_wi, m, batch_nelems, - batch_indexer, lhs_indexer, rhs_indexer, - res_indexer)); - }); - return gemm_ev; - } else { - using ReductionOpT = sycl::plus; - constexpr resTy identity_val = - sycl::known_identity::value; - - size_t iter_nelems = batch_nelems * n * m; - size_t reduction_nelems = - (k + delta_k * n_wi - 1) / (delta_k * n_wi); - - // more than one work-group is needed, requires a - // temporary delta_k * n_wi elements processed along k, - // so if more to process use multiple - const auto &sg_sizes = - dev.get_info(); - size_t wg = - choose_workgroup_size<4>(reduction_nelems, sg_sizes); - - if (reduction_nelems < wg) { - resTy *tmp = sycl::malloc_device( - iter_nelems * reduction_nelems, exec_q); - sycl::event gemm_ev = exec_q.submit([&](sycl::handler - &cgh) { - cgh.depends_on(depends); - - using OuterInnerDimsIndexerT = - dpctl::tensor::offset_utils::NoOpIndexer; - OuterInnerDimsIndexerT lhs_indexer{}; - OuterInnerDimsIndexerT rhs_indexer{}; - OuterInnerDimsIndexerT tmp_indexer{}; - using dpctl::tensor::offset_utils:: - ThreeOffsets_CombinedIndexer; - using dpctl::tensor::offset_utils::Strided1DIndexer; - using BatchDimsIndexerT = - ThreeOffsets_CombinedIndexer; - - BatchDimsIndexerT batch_indexer( - Strided1DIndexer{ - 0, static_cast(batch_nelems), - static_cast(n * k)}, - Strided1DIndexer{ - 0, static_cast(batch_nelems), - static_cast(k * m)}, - Strided1DIndexer{ - 0, static_cast(batch_nelems), - static_cast(n * m)}); - - size_t n_blocks = (n + delta_n - 1) / delta_n; - size_t k_blocks = - (k + n_wi * delta_k - 1) / (n_wi * delta_k); - size_t m_blocks = (m + m_groups - 1) / m_groups; - - size_t lws = delta_n * delta_k; - - auto gRange = - sycl::range<1>(batch_nelems * n_blocks * - m_blocks * k_blocks * lws); - auto lRange = sycl::range<1>(lws); - - auto ndRange = sycl::nd_range<1>(gRange, lRange); - - using LocAccT = - sycl::local_accessor, - 1>; - LocAccT local_B_block(n_wi * delta_k, cgh); - LocAccT workspace(delta_n * delta_k, cgh); - - using KernelName = class gemm_batch_tree_k_krn< - lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, - OuterInnerDimsIndexerT, BatchDimsIndexerT, - m_groups>; - cgh.parallel_for( - ndRange, - GemmBatchNoAtomicFunctorThreadK< - lhsTy, rhsTy, resTy, LocAccT, - OuterInnerDimsIndexerT, - OuterInnerDimsIndexerT, BatchDimsIndexerT, - m_groups>(lhs_tp, rhs_tp, tmp, workspace, - local_B_block, n, n_blocks, - delta_n, k, k_blocks, delta_k, - n_wi, m, batch_nelems, - batch_indexer, lhs_indexer, - rhs_indexer, tmp_indexer)); - }); - sycl::event red_ev = exec_q.submit([&](sycl::handler - &cgh) { - cgh.depends_on(gemm_ev); - - using NoOpIndexerT = - dpctl::tensor::offset_utils::NoOpIndexer; - using InputOutputIterIndexerT = - dpctl::tensor::offset_utils:: - TwoOffsets_CombinedIndexer; - using ReductionIndexerT = - dpctl::tensor::offset_utils::Strided1DIndexer; - - InputOutputIterIndexerT in_out_iter_indexer{ - NoOpIndexerT{}, NoOpIndexerT{}}; - ReductionIndexerT reduction_indexer{ - 0, static_cast(reduction_nelems), - static_cast(iter_nelems)}; - - sycl::range<1> iter_range{iter_nelems}; - - cgh.parallel_for< - class gemm_reduction_seq_strided_krn< - resTy, resTy, ReductionOpT, - InputOutputIterIndexerT, - ReductionIndexerT>>( - iter_range, - SequentialReduction( - tmp, res_tp, ReductionOpT(), identity_val, - in_out_iter_indexer, reduction_indexer, - reduction_nelems)); - }); - sycl::event cleanup_host_task_event = - exec_q.submit([&](sycl::handler &cgh) { - 
cgh.depends_on(red_ev); - const sycl::context &ctx = exec_q.get_context(); - - cgh.host_task( - [ctx, tmp] { sycl::free(tmp, ctx); }); - }); - return cleanup_host_task_event; - } - - // max_max_wg prevents running out of resources on CPU - constexpr size_t max_max_wg = 2048; - size_t max_wg = std::min( - max_max_wg, - dev.get_info< - sycl::info::device::max_work_group_size>() / - 2); - - constexpr size_t preferred_reductions_per_wi = 4; - size_t reductions_per_wi(preferred_reductions_per_wi); - - size_t reduction_groups = - (reduction_nelems + preferred_reductions_per_wi * wg - - 1) / - (preferred_reductions_per_wi * wg); - assert(reduction_groups > 1); - - resTy *partially_reduced_tmp = sycl::malloc_device( - iter_nelems * - (/* temp */ reduction_nelems + - /* first reduction temp */ reduction_groups), - exec_q); - resTy *partially_reduced_tmp2 = nullptr; - - if (partially_reduced_tmp == nullptr) { - throw std::runtime_error( - "Unable to allocate device_memory"); - } - else { - partially_reduced_tmp2 = partially_reduced_tmp + - reduction_nelems * iter_nelems; - } - - sycl::event gemm_ev = exec_q.submit([&](sycl::handler - &cgh) { - cgh.depends_on(depends); - - using OuterInnerDimsIndexerT = - dpctl::tensor::offset_utils::NoOpIndexer; - OuterInnerDimsIndexerT lhs_indexer{}; - OuterInnerDimsIndexerT rhs_indexer{}; - OuterInnerDimsIndexerT tmp_indexer{}; - using dpctl::tensor::offset_utils:: - ThreeOffsets_CombinedIndexer; - using dpctl::tensor::offset_utils::Strided1DIndexer; - using BatchDimsIndexerT = - ThreeOffsets_CombinedIndexer; - - BatchDimsIndexerT batch_indexer( - Strided1DIndexer{ - 0, static_cast(batch_nelems), - static_cast(n * k)}, - Strided1DIndexer{ - 0, static_cast(batch_nelems), - static_cast(k * m)}, - Strided1DIndexer{ - 0, static_cast(batch_nelems), - static_cast(n * m)}); - - size_t n_blocks = (n + delta_n - 1) / delta_n; - size_t k_blocks = - (k + n_wi * delta_k - 1) / (n_wi * delta_k); - size_t m_blocks = (m + m_groups - 1) / m_groups; - - size_t lws = delta_n * delta_k; - - auto gRange = sycl::range<1>(batch_nelems * n_blocks * - m_blocks * k_blocks * lws); - auto lRange = sycl::range<1>(lws); - - auto ndRange = sycl::nd_range<1>(gRange, lRange); - - using LocAccT = - sycl::local_accessor, 1>; - LocAccT local_B_block(n_wi * delta_k, cgh); - LocAccT workspace(delta_n * delta_k, cgh); - - using KernelName = class gemm_batch_tree_k_krn< - lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, - OuterInnerDimsIndexerT, BatchDimsIndexerT, - m_groups>; - cgh.parallel_for( - ndRange, - GemmBatchNoAtomicFunctorThreadK< - lhsTy, rhsTy, resTy, LocAccT, - OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, - BatchDimsIndexerT, m_groups>( - lhs_tp, rhs_tp, partially_reduced_tmp, - workspace, local_B_block, n, n_blocks, delta_n, - k, k_blocks, delta_k, n_wi, m, batch_nelems, - batch_indexer, lhs_indexer, rhs_indexer, - tmp_indexer)); - }); - - sycl::event red_ev = - tree_reduction_for_gemm_contig( - exec_q, partially_reduced_tmp, - partially_reduced_tmp2, res_tp, identity_val, - iter_nelems, reduction_nelems, reduction_groups, wg, - max_wg, preferred_reductions_per_wi, - reductions_per_wi, {gemm_ev}); - - sycl::event cleanup_host_task_event = - exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(red_ev); - const sycl::context &ctx = exec_q.get_context(); - - cgh.host_task([ctx, partially_reduced_tmp] { - sycl::free(partially_reduced_tmp, ctx); - }); - }); - - return cleanup_host_task_event; - } + using LocAccT1 = sycl::local_accessor; + LocAccT1 local_A_block( + 
sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), cgh); + using LocAccT2 = + sycl::local_accessor, 1>; + LocAccT2 local_B_block(sycl::range<1>(wi_delta_k * wg_delta_m), + cgh); + using KernelName = class gemm_batch_tree_nm_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, TmpIndexerT, + BatchDimsIndexerT, wi_delta_m>; + cgh.parallel_for( + ndRange, + GemmBatchNoAtomicFunctorThreadNM< + lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, + OuterInnerDimsIndexerT, TmpIndexerT, BatchDimsIndexerT, + wi_delta_n, wi_delta_m>( + lhs_tp, rhs_tp, partially_reduced_tmp, local_A_block, + local_B_block, n, wg_delta_n, k, k_blocks, wi_delta_k, + m, m_blocks, wg_delta_m, batch_nelems, batch_indexer, + lhs_indexer, rhs_indexer, res_indexer)); + } + }); + + sycl::event red_ev = tree_reduction_for_gemm( + exec_q, partially_reduced_tmp, partially_reduced_tmp2, res_tp, + identity_val, iter_nelems, reduction_nelems, reduction_groups, wg, + max_wg, preferred_reductions_per_wi, reductions_per_wi, + batch_nd + res_outer_nd, res_batch_offset, res_shape_strides, + {gemm_ev}); + + sycl::event cleanup_host_task_event = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(red_ev); + const sycl::context &ctx = exec_q.get_context(); + + cgh.host_task([ctx, partially_reduced_tmp] { + sycl::free(partially_reduced_tmp, ctx); + }); + }); + + return cleanup_host_task_event; + } +} + +template +sycl::event +gemm_batch_tree_impl(sycl::queue &exec_q, + const char *lhs_cp, + const char *rhs_cp, + char *res_cp, + size_t batch_nelems, + size_t n, + size_t k, + size_t m, + int batch_nd, + const py::ssize_t *batch_shape_strides, + py::ssize_t lhs_batch_offset, + py::ssize_t rhs_batch_offset, + py::ssize_t res_batch_offset, + int inner_nd, + int lhs_outer_nd, + const py::ssize_t *lhs_outer_inner_shapes_strides, + int rhs_outer_nd, + const py::ssize_t *rhs_outer_inner_shapes_strides, + int res_outer_nd, + const py::ssize_t *res_outer_shapes_strides, + const py::ssize_t *res_shape_strides, + std::vector const &depends = {}) +{ + const lhsTy *lhs_tp = reinterpret_cast(lhs_cp); + const rhsTy *rhs_tp = reinterpret_cast(rhs_cp); + resTy *res_tp = reinterpret_cast(res_cp); + + if ((k > n && k > m) || m == 1) { + using dpctl::tensor::type_utils::is_complex; + if constexpr (!is_complex::value) { + if (m == 1) { + return gemm_batch_tree_k_impl( + exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, + batch_nd, batch_shape_strides, lhs_batch_offset, + rhs_batch_offset, res_batch_offset, inner_nd, lhs_outer_nd, + lhs_outer_inner_shapes_strides, rhs_outer_nd, + rhs_outer_inner_shapes_strides, res_outer_nd, + res_outer_shapes_strides, res_shape_strides, depends); + } + else { + return gemm_batch_tree_k_impl( + exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, + batch_nd, batch_shape_strides, lhs_batch_offset, + rhs_batch_offset, res_batch_offset, inner_nd, lhs_outer_nd, + lhs_outer_inner_shapes_strides, rhs_outer_nd, + rhs_outer_inner_shapes_strides, res_outer_nd, + res_outer_shapes_strides, res_shape_strides, depends); } } else { - constexpr int m_groups = 1; + return gemm_batch_tree_k_impl( + exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, batch_nd, + batch_shape_strides, lhs_batch_offset, rhs_batch_offset, + res_batch_offset, inner_nd, lhs_outer_nd, + lhs_outer_inner_shapes_strides, rhs_outer_nd, + rhs_outer_inner_shapes_strides, res_outer_nd, + res_outer_shapes_strides, res_shape_strides, depends); + } + } + else { // m > 1, n > k or m > k + using dpctl::tensor::type_utils::is_complex; + if constexpr (!is_complex::value) { + 
return gemm_batch_tree_nm_impl( + exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, batch_nd, + batch_shape_strides, lhs_batch_offset, rhs_batch_offset, + res_batch_offset, inner_nd, lhs_outer_nd, + lhs_outer_inner_shapes_strides, rhs_outer_nd, + rhs_outer_inner_shapes_strides, res_outer_nd, + res_outer_shapes_strides, res_shape_strides, depends); + } + else { // m > 1, n > k or m > k, resTy complex + return gemm_batch_tree_nm_impl( + exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, batch_nd, + batch_shape_strides, lhs_batch_offset, rhs_batch_offset, + res_batch_offset, inner_nd, lhs_outer_nd, + lhs_outer_inner_shapes_strides, rhs_outer_nd, + rhs_outer_inner_shapes_strides, res_outer_nd, + res_outer_shapes_strides, res_shape_strides, depends); + } + } +} - gemm_detail::scale_gemm_k_parameters( - local_mem_size, reserved_slm_size, delta_k, - n_wi, // modified by reference - delta_n // modified by reference - ); +template +sycl::event +gemm_batch_contig_tree_k_impl(sycl::queue &exec_q, + const lhsTy *lhs_tp, + const rhsTy *rhs_tp, + resTy *res_tp, + size_t batch_nelems, + size_t n, + size_t k, + size_t m, + std::vector const &depends) +{ + size_t delta_k(4); + size_t n_wi(4); + size_t delta_n(4); + + const sycl::device &dev = exec_q.get_device(); + const size_t local_mem_size = + dev.get_info(); + const size_t reserved_slm_size = 512; + + gemm_detail::scale_gemm_k_parameters( + local_mem_size, reserved_slm_size, delta_k, + n_wi, // modified by reference + delta_n // modified by reference + ); + + if (k <= (delta_k * n_wi)) { + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + OuterInnerDimsIndexerT lhs_indexer{}; + OuterInnerDimsIndexerT rhs_indexer{}; + OuterInnerDimsIndexerT res_indexer{}; + using dpctl::tensor::offset_utils::ThreeOffsets_CombinedIndexer; + using dpctl::tensor::offset_utils::Strided1DIndexer; + using BatchDimsIndexerT = + ThreeOffsets_CombinedIndexer; + + using dpctl::tensor::offset_utils::Strided1DIndexer; + BatchDimsIndexerT batch_indexer( + Strided1DIndexer{0, static_cast(batch_nelems), + static_cast(n * k)}, + Strided1DIndexer{0, static_cast(batch_nelems), + static_cast(k * m)}, + Strided1DIndexer{0, static_cast(batch_nelems), + static_cast(n * m)}); + + size_t n_blocks = (n + delta_n - 1) / delta_n; + size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); + size_t m_blocks = (m + m_groups - 1) / m_groups; + + size_t lws = delta_n * delta_k; + + auto gRange = sycl::range<1>(batch_nelems * n_blocks * m_blocks * + k_blocks * lws); + auto lRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gRange, lRange); + + if constexpr (m_groups == 1) { + using LocAccT = sycl::local_accessor; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + + using KernelName = class gemm_batch_tree_k_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, BatchDimsIndexerT, m_groups>; + + cgh.parallel_for( + ndRange, + GemmBatchNoAtomicFunctorThreadK< + lhsTy, rhsTy, resTy, LocAccT, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, BatchDimsIndexerT, m_groups>( + lhs_tp, rhs_tp, res_tp, workspace, local_B_block, n, + n_blocks, delta_n, k, k_blocks, delta_k, n_wi, m, + batch_nelems, batch_indexer, lhs_indexer, rhs_indexer, + res_indexer)); + } + else { + using LocAccT = + sycl::local_accessor, 1>; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, 
cgh); + + using KernelName = class gemm_batch_tree_k_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, BatchDimsIndexerT, m_groups>; + + cgh.parallel_for( + ndRange, + GemmBatchNoAtomicFunctorThreadK< + lhsTy, rhsTy, resTy, LocAccT, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, BatchDimsIndexerT, m_groups>( + lhs_tp, rhs_tp, res_tp, workspace, local_B_block, n, + n_blocks, delta_n, k, k_blocks, delta_k, n_wi, m, + batch_nelems, batch_indexer, lhs_indexer, rhs_indexer, + res_indexer)); + } + }); + return gemm_ev; + } + else { + using ReductionOpT = sycl::plus; + constexpr resTy identity_val = + sycl::known_identity::value; + + size_t iter_nelems = batch_nelems * n * m; + size_t reduction_nelems = (k + delta_k * n_wi - 1) / (delta_k * n_wi); + + // more than one work-group is needed, requires a + // temporary delta_k * n_wi elements processed along k, + // so if more to process use multiple + const auto &sg_sizes = + dev.get_info(); + size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes); + + if (reduction_nelems < wg) { + resTy *tmp = sycl::malloc_device( + iter_nelems * reduction_nelems, exec_q); + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + OuterInnerDimsIndexerT lhs_indexer{}; + OuterInnerDimsIndexerT rhs_indexer{}; + OuterInnerDimsIndexerT tmp_indexer{}; + using dpctl::tensor::offset_utils::ThreeOffsets_CombinedIndexer; + using dpctl::tensor::offset_utils::Strided1DIndexer; + using BatchDimsIndexerT = ThreeOffsets_CombinedIndexer< + Strided1DIndexer, Strided1DIndexer, Strided1DIndexer>; + + BatchDimsIndexerT batch_indexer( + Strided1DIndexer{0, static_cast(batch_nelems), + static_cast(n * k)}, + Strided1DIndexer{0, static_cast(batch_nelems), + static_cast(k * m)}, + Strided1DIndexer{0, static_cast(batch_nelems), + static_cast(n * m)}); + + size_t n_blocks = (n + delta_n - 1) / delta_n; + size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); + size_t m_blocks = (m + m_groups - 1) / m_groups; + + size_t lws = delta_n * delta_k; + + auto gRange = sycl::range<1>(batch_nelems * n_blocks * + m_blocks * k_blocks * lws); + auto lRange = sycl::range<1>(lws); - // each group processes delta_k * n_wi - // items in a column, so no need for allocating - // temp memory if only one group is needed - if (k <= (delta_k * n_wi)) { - sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(depends); - - using OuterInnerDimsIndexerT = - dpctl::tensor::offset_utils::NoOpIndexer; - OuterInnerDimsIndexerT lhs_indexer{}; - OuterInnerDimsIndexerT rhs_indexer{}; - OuterInnerDimsIndexerT res_indexer{}; - using dpctl::tensor::offset_utils:: - ThreeOffsets_CombinedIndexer; - using dpctl::tensor::offset_utils::Strided1DIndexer; - using BatchDimsIndexerT = ThreeOffsets_CombinedIndexer< - Strided1DIndexer, Strided1DIndexer, Strided1DIndexer>; - - BatchDimsIndexerT batch_indexer( - Strided1DIndexer{0, - static_cast(batch_nelems), - static_cast(n * k)}, - Strided1DIndexer{0, - static_cast(batch_nelems), - static_cast(k * m)}, - Strided1DIndexer{0, - static_cast(batch_nelems), - static_cast(n * m)}); - - size_t n_blocks = (n + delta_n - 1) / delta_n; - size_t k_blocks = - (k + n_wi * delta_k - 1) / (n_wi * delta_k); - constexpr int m_groups = 1; - size_t m_blocks = (m + m_groups - 1) / m_groups; - - size_t lws = delta_n * delta_k; - - auto gRange = sycl::range<1>(batch_nelems * n_blocks * - m_blocks * k_blocks * lws); - auto 
lRange = sycl::range<1>(lws); - - auto ndRange = sycl::nd_range<1>(gRange, lRange); + auto ndRange = sycl::nd_range<1>(gRange, lRange); + if constexpr (m_groups == 1) { using LocAccT = sycl::local_accessor; LocAccT local_B_block(n_wi * delta_k, cgh); LocAccT workspace(delta_n * delta_k, cgh); @@ -7595,212 +6164,20 @@ gemm_batch_contig_tree_impl(sycl::queue &exec_q, using KernelName = class gemm_batch_tree_k_krn< lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, BatchDimsIndexerT, m_groups>; - cgh.parallel_for( ndRange, GemmBatchNoAtomicFunctorThreadK< lhsTy, rhsTy, resTy, LocAccT, OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, BatchDimsIndexerT, m_groups>( - lhs_tp, rhs_tp, res_tp, workspace, local_B_block, n, + lhs_tp, rhs_tp, tmp, workspace, local_B_block, n, n_blocks, delta_n, k, k_blocks, delta_k, n_wi, m, batch_nelems, batch_indexer, lhs_indexer, - rhs_indexer, res_indexer)); - }); - return gemm_ev; - } - else { - using ReductionOpT = sycl::plus; - constexpr resTy identity_val = - sycl::known_identity::value; - - size_t iter_nelems = batch_nelems * n * m; - size_t reduction_nelems = - (k + delta_k * n_wi - 1) / (delta_k * n_wi); - - // more than one work-group is needed, requires a temporary - // delta_k * n_wi elements processed along k, so if more to - // process use multiple - const auto &sg_sizes = - dev.get_info(); - size_t wg = - choose_workgroup_size<4>(reduction_nelems, sg_sizes); - - if (reduction_nelems < wg) { - resTy *tmp = sycl::malloc_device( - iter_nelems * reduction_nelems, exec_q); - sycl::event gemm_ev = exec_q.submit([&](sycl::handler - &cgh) { - cgh.depends_on(depends); - - using OuterInnerDimsIndexerT = - dpctl::tensor::offset_utils::NoOpIndexer; - OuterInnerDimsIndexerT lhs_indexer{}; - OuterInnerDimsIndexerT rhs_indexer{}; - OuterInnerDimsIndexerT tmp_indexer{}; - using dpctl::tensor::offset_utils:: - ThreeOffsets_CombinedIndexer; - using dpctl::tensor::offset_utils::Strided1DIndexer; - using BatchDimsIndexerT = - ThreeOffsets_CombinedIndexer; - - BatchDimsIndexerT batch_indexer( - Strided1DIndexer{ - 0, static_cast(batch_nelems), - static_cast(n * k)}, - Strided1DIndexer{ - 0, static_cast(batch_nelems), - static_cast(k * m)}, - Strided1DIndexer{ - 0, static_cast(batch_nelems), - static_cast(n * m)}); - - size_t n_blocks = (n + delta_n - 1) / delta_n; - size_t k_blocks = - (k + n_wi * delta_k - 1) / (n_wi * delta_k); - constexpr int m_groups = 1; - size_t m_blocks = (m + m_groups - 1) / m_groups; - - size_t lws = delta_n * delta_k; - - auto gRange = sycl::range<1>(batch_nelems * n_blocks * - m_blocks * k_blocks * lws); - auto lRange = sycl::range<1>(lws); - - auto ndRange = sycl::nd_range<1>(gRange, lRange); - - using LocAccT = sycl::local_accessor; - LocAccT local_B_block(n_wi * delta_k, cgh); - LocAccT workspace(delta_n * delta_k, cgh); - - using KernelName = class gemm_batch_tree_k_krn< - lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, - OuterInnerDimsIndexerT, BatchDimsIndexerT, - m_groups>; - cgh.parallel_for( - ndRange, - GemmBatchNoAtomicFunctorThreadK< - lhsTy, rhsTy, resTy, LocAccT, - OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, - BatchDimsIndexerT, m_groups>( - lhs_tp, rhs_tp, tmp, workspace, local_B_block, - n, n_blocks, delta_n, k, k_blocks, delta_k, - n_wi, m, batch_nelems, batch_indexer, - lhs_indexer, rhs_indexer, tmp_indexer)); - }); - sycl::event red_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(gemm_ev); - - using NoOpIndexerT = - dpctl::tensor::offset_utils::NoOpIndexer; - using InputOutputIterIndexerT = - 
dpctl::tensor::offset_utils:: - TwoOffsets_CombinedIndexer; - using ReductionIndexerT = - dpctl::tensor::offset_utils::Strided1DIndexer; - - InputOutputIterIndexerT in_out_iter_indexer{ - NoOpIndexerT{}, NoOpIndexerT{}}; - ReductionIndexerT reduction_indexer{ - 0, static_cast(reduction_nelems), - static_cast(iter_nelems)}; - - sycl::range<1> iter_range{iter_nelems}; - - cgh.parallel_for>( - iter_range, - SequentialReduction( - tmp, res_tp, ReductionOpT(), identity_val, - in_out_iter_indexer, reduction_indexer, - reduction_nelems)); - }); - sycl::event cleanup_host_task_event = - exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(red_ev); - const sycl::context &ctx = exec_q.get_context(); - - cgh.host_task([ctx, tmp] { sycl::free(tmp, ctx); }); - }); - return cleanup_host_task_event; - } - - // max_max_wg prevents running out of resources on CPU - constexpr size_t max_max_wg = 2048; - size_t max_wg = std::min( - max_max_wg, - dev.get_info() / - 2); - - constexpr size_t preferred_reductions_per_wi = 4; - size_t reductions_per_wi(preferred_reductions_per_wi); - - size_t reduction_groups = - (reduction_nelems + preferred_reductions_per_wi * wg - 1) / - (preferred_reductions_per_wi * wg); - assert(reduction_groups > 1); - - resTy *partially_reduced_tmp = sycl::malloc_device( - iter_nelems * (/* temp */ reduction_nelems + - /* first reduction temp */ reduction_groups), - exec_q); - resTy *partially_reduced_tmp2 = nullptr; - - if (partially_reduced_tmp == nullptr) { - throw std::runtime_error( - "Unable to allocate device_memory"); + rhs_indexer, tmp_indexer)); } else { - partially_reduced_tmp2 = - partially_reduced_tmp + reduction_nelems * iter_nelems; - } - - sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(depends); - - using OuterInnerDimsIndexerT = - dpctl::tensor::offset_utils::NoOpIndexer; - OuterInnerDimsIndexerT lhs_indexer{}; - OuterInnerDimsIndexerT rhs_indexer{}; - OuterInnerDimsIndexerT tmp_indexer{}; - using dpctl::tensor::offset_utils:: - ThreeOffsets_CombinedIndexer; - using dpctl::tensor::offset_utils::Strided1DIndexer; - using BatchDimsIndexerT = ThreeOffsets_CombinedIndexer< - Strided1DIndexer, Strided1DIndexer, Strided1DIndexer>; - - BatchDimsIndexerT batch_indexer( - Strided1DIndexer{0, - static_cast(batch_nelems), - static_cast(n * k)}, - Strided1DIndexer{0, - static_cast(batch_nelems), - static_cast(k * m)}, - Strided1DIndexer{0, - static_cast(batch_nelems), - static_cast(n * m)}); - - size_t n_blocks = (n + delta_n - 1) / delta_n; - size_t k_blocks = - (k + n_wi * delta_k - 1) / (n_wi * delta_k); - constexpr int m_groups = 1; - size_t m_blocks = (m + m_groups - 1) / m_groups; - - size_t lws = delta_n * delta_k; - - auto gRange = sycl::range<1>(batch_nelems * n_blocks * - m_blocks * k_blocks * lws); - auto lRange = sycl::range<1>(lws); - - auto ndRange = sycl::nd_range<1>(gRange, lRange); - - using LocAccT = sycl::local_accessor; + using LocAccT = + sycl::local_accessor, 1>; LocAccT local_B_block(n_wi * delta_k, cgh); LocAccT workspace(delta_n * delta_k, cgh); @@ -7813,100 +6190,349 @@ gemm_batch_contig_tree_impl(sycl::queue &exec_q, lhsTy, rhsTy, resTy, LocAccT, OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, BatchDimsIndexerT, m_groups>( - lhs_tp, rhs_tp, partially_reduced_tmp, workspace, - local_B_block, n, n_blocks, delta_n, k, k_blocks, - delta_k, n_wi, m, batch_nelems, batch_indexer, - lhs_indexer, rhs_indexer, tmp_indexer)); + lhs_tp, rhs_tp, tmp, workspace, local_B_block, n, + n_blocks, delta_n, k, k_blocks, delta_k, n_wi, m, + 
batch_nelems, batch_indexer, lhs_indexer, + rhs_indexer, tmp_indexer)); + } + }); + sycl::event red_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(gemm_ev); + + using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + NoOpIndexerT, NoOpIndexerT>; + using ReductionIndexerT = + dpctl::tensor::offset_utils::Strided1DIndexer; + + InputOutputIterIndexerT in_out_iter_indexer{NoOpIndexerT{}, + NoOpIndexerT{}}; + ReductionIndexerT reduction_indexer{ + 0, static_cast(reduction_nelems), + static_cast(iter_nelems)}; + + sycl::range<1> iter_range{iter_nelems}; + + cgh.parallel_for>( + iter_range, SequentialReduction( + tmp, res_tp, ReductionOpT(), identity_val, + in_out_iter_indexer, reduction_indexer, + reduction_nelems)); + }); + sycl::event cleanup_host_task_event = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(red_ev); + const sycl::context &ctx = exec_q.get_context(); + + cgh.host_task([ctx, tmp] { sycl::free(tmp, ctx); }); + }); + return cleanup_host_task_event; + } + + // max_max_wg prevents running out of resources on CPU + constexpr size_t max_max_wg = 2048; + size_t max_wg = std::min( + max_max_wg, + dev.get_info() / 2); + + constexpr size_t preferred_reductions_per_wi = 4; + size_t reductions_per_wi(preferred_reductions_per_wi); + + size_t reduction_groups = + (reduction_nelems + preferred_reductions_per_wi * wg - 1) / + (preferred_reductions_per_wi * wg); + assert(reduction_groups > 1); + + resTy *partially_reduced_tmp = sycl::malloc_device( + iter_nelems * (/* temp */ reduction_nelems + + /* first reduction temp */ reduction_groups), + exec_q); + resTy *partially_reduced_tmp2 = nullptr; + + if (partially_reduced_tmp == nullptr) { + throw std::runtime_error("Unable to allocate device_memory"); + } + else { + partially_reduced_tmp2 = + partially_reduced_tmp + reduction_nelems * iter_nelems; + } + + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + OuterInnerDimsIndexerT lhs_indexer{}; + OuterInnerDimsIndexerT rhs_indexer{}; + OuterInnerDimsIndexerT tmp_indexer{}; + using dpctl::tensor::offset_utils::ThreeOffsets_CombinedIndexer; + using dpctl::tensor::offset_utils::Strided1DIndexer; + using BatchDimsIndexerT = + ThreeOffsets_CombinedIndexer; + + BatchDimsIndexerT batch_indexer( + Strided1DIndexer{0, static_cast(batch_nelems), + static_cast(n * k)}, + Strided1DIndexer{0, static_cast(batch_nelems), + static_cast(k * m)}, + Strided1DIndexer{0, static_cast(batch_nelems), + static_cast(n * m)}); + + size_t n_blocks = (n + delta_n - 1) / delta_n; + size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); + size_t m_blocks = (m + m_groups - 1) / m_groups; + + size_t lws = delta_n * delta_k; + + auto gRange = sycl::range<1>(batch_nelems * n_blocks * m_blocks * + k_blocks * lws); + auto lRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gRange, lRange); + + if constexpr (m_groups == 1) { + using LocAccT = sycl::local_accessor; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + + using KernelName = class gemm_batch_tree_k_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, BatchDimsIndexerT, m_groups>; + cgh.parallel_for( + ndRange, + GemmBatchNoAtomicFunctorThreadK< + lhsTy, rhsTy, resTy, LocAccT, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, BatchDimsIndexerT, 
m_groups>( + lhs_tp, rhs_tp, partially_reduced_tmp, workspace, + local_B_block, n, n_blocks, delta_n, k, k_blocks, + delta_k, n_wi, m, batch_nelems, batch_indexer, + lhs_indexer, rhs_indexer, tmp_indexer)); + } + else { + using LocAccT = + sycl::local_accessor, 1>; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + + using KernelName = class gemm_batch_tree_k_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, BatchDimsIndexerT, m_groups>; + cgh.parallel_for( + ndRange, + GemmBatchNoAtomicFunctorThreadK< + lhsTy, rhsTy, resTy, LocAccT, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, BatchDimsIndexerT, m_groups>( + lhs_tp, rhs_tp, partially_reduced_tmp, workspace, + local_B_block, n, n_blocks, delta_n, k, k_blocks, + delta_k, n_wi, m, batch_nelems, batch_indexer, + lhs_indexer, rhs_indexer, tmp_indexer)); + } + }); + + sycl::event red_ev = + tree_reduction_for_gemm_contig( + exec_q, partially_reduced_tmp, partially_reduced_tmp2, res_tp, + identity_val, iter_nelems, reduction_nelems, reduction_groups, + wg, max_wg, preferred_reductions_per_wi, reductions_per_wi, + {gemm_ev}); + + sycl::event cleanup_host_task_event = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(red_ev); + const sycl::context &ctx = exec_q.get_context(); + + cgh.host_task([ctx, partially_reduced_tmp] { + sycl::free(partially_reduced_tmp, ctx); }); + }); + + return cleanup_host_task_event; + } +} + +template +sycl::event +gemm_batch_contig_tree_nm_impl(sycl::queue &exec_q, + const lhsTy *lhs_tp, + const rhsTy *rhs_tp, + resTy *res_tp, + size_t batch_nelems, + size_t n, + size_t k, + size_t m, + std::vector const &depends) +{ + constexpr int wi_delta_n = 2; + size_t wg_delta_n(16); // rows of A processed in WG + size_t wg_delta_m(16); // rows of B processed in WG + size_t wi_delta_k(64); // Elements in K dimension processed by WI + + const sycl::device &dev = exec_q.get_device(); + const size_t local_mem_size = + dev.get_info(); + const size_t reserved_slm_size = 512; + + gemm_detail::scale_gemm_nm_parameters( + local_mem_size, reserved_slm_size, wi_delta_n, + wi_delta_k, // modified by reference + wg_delta_n, // modified by reference + wg_delta_m // modified by reference + ); + + // each group processes delta_k * n_wi + // items in a column, so no need for allocating + // temp memory if only one group is needed + if (k <= wi_delta_k) { + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + OuterInnerDimsIndexerT lhs_indexer{}; + OuterInnerDimsIndexerT rhs_indexer{}; + OuterInnerDimsIndexerT res_indexer{}; + using dpctl::tensor::offset_utils::ThreeOffsets_CombinedIndexer; + using dpctl::tensor::offset_utils::Strided1DIndexer; + using BatchDimsIndexerT = + ThreeOffsets_CombinedIndexer; + + BatchDimsIndexerT batch_indexer( + Strided1DIndexer{0, static_cast(batch_nelems), + static_cast(n * k)}, + Strided1DIndexer{0, static_cast(batch_nelems), + static_cast(k * m)}, + Strided1DIndexer{0, static_cast(batch_nelems), + static_cast(n * m)}); + + size_t lws = wg_delta_n * wg_delta_m; + + size_t n_blocks = + ((n + wi_delta_n * wg_delta_n - 1) / (wi_delta_n * wg_delta_n)); + size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); + size_t m_blocks = + ((m + wi_delta_m * wg_delta_m - 1) / (wi_delta_m * wg_delta_m)); + + auto gwsRange = sycl::range<1>(batch_nelems * n_blocks * m_blocks * + k_blocks * lws); + auto lwsRange = sycl::range<1>(lws); 
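// Annotation (editor's note, not part of the upstream patch): as read from
// this hunk, the launch decomposition is one work-group of
// lws = wg_delta_n * wg_delta_m work-items per (batch element, n-block,
// m-block, k-block), where an n-block spans wi_delta_n * wg_delta_n rows and
// an m-block spans wi_delta_m * wg_delta_m columns of the result
// (wi_delta_m being the compile-time parameter tested by the `if constexpr`
// just below). Because this branch runs only when k <= wi_delta_k, a single
// k-block suffices and the functor accumulates straight into res_tp with no
// temporary buffer or follow-up reduction.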
+ + auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); + + if constexpr (wi_delta_m == 1) { + using LocAccT1 = sycl::local_accessor; + LocAccT1 local_A_block( + sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), + cgh); + using LocAccT2 = sycl::local_accessor; + LocAccT2 local_B_block(sycl::range<1>(wi_delta_k * wg_delta_m), + cgh); + + using KernelName = class gemm_batch_tree_nm_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, BatchDimsIndexerT, wi_delta_m>; + cgh.parallel_for( + ndRange, + GemmBatchNoAtomicFunctorThreadNM< + lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, + BatchDimsIndexerT, wi_delta_n, wi_delta_m>( + lhs_tp, rhs_tp, res_tp, local_A_block, local_B_block, n, + wg_delta_n, k, k_blocks, wi_delta_k, m, m_blocks, + wg_delta_m, batch_nelems, batch_indexer, lhs_indexer, + rhs_indexer, res_indexer)); + } + else { + using LocAccT1 = sycl::local_accessor; + LocAccT1 local_A_block( + sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), + cgh); + using LocAccT2 = + sycl::local_accessor, 1>; + LocAccT2 local_B_block(sycl::range<1>(wi_delta_k * wg_delta_m), + cgh); + + using KernelName = class gemm_batch_tree_nm_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, BatchDimsIndexerT, wi_delta_m>; + cgh.parallel_for( + ndRange, + GemmBatchNoAtomicFunctorThreadNM< + lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, + BatchDimsIndexerT, wi_delta_n, wi_delta_m>( + lhs_tp, rhs_tp, res_tp, local_A_block, local_B_block, n, + wg_delta_n, k, k_blocks, wi_delta_k, m, m_blocks, + wg_delta_m, batch_nelems, batch_indexer, lhs_indexer, + rhs_indexer, res_indexer)); + } + }); + return gemm_ev; + } + else { + using ReductionOpT = sycl::plus; + constexpr resTy identity_val = + sycl::known_identity::value; + size_t iter_nelems = batch_nelems * n * m; + size_t reduction_nelems = (k + wi_delta_k - 1) / wi_delta_k; + + // more than one work-group is needed, requires a temporary + // delta_k * n_wi elements processed along k, so if more to + // process use multiple + const auto &sg_sizes = + dev.get_info(); + size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes); + + if (reduction_nelems < wg) { + resTy *tmp = sycl::malloc_device( + iter_nelems * reduction_nelems, exec_q); + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + OuterInnerDimsIndexerT lhs_indexer{}; + OuterInnerDimsIndexerT rhs_indexer{}; + OuterInnerDimsIndexerT tmp_indexer{}; + using dpctl::tensor::offset_utils::ThreeOffsets_CombinedIndexer; + using dpctl::tensor::offset_utils::Strided1DIndexer; + using BatchDimsIndexerT = ThreeOffsets_CombinedIndexer< + Strided1DIndexer, Strided1DIndexer, Strided1DIndexer>; + + BatchDimsIndexerT batch_indexer( + Strided1DIndexer{0, static_cast(batch_nelems), + static_cast(n * k)}, + Strided1DIndexer{0, static_cast(batch_nelems), + static_cast(k * m)}, + Strided1DIndexer{0, static_cast(batch_nelems), + static_cast(n * m)}); - sycl::event red_ev = - tree_reduction_for_gemm_contig( - exec_q, partially_reduced_tmp, partially_reduced_tmp2, - res_tp, identity_val, iter_nelems, reduction_nelems, - reduction_groups, wg, max_wg, - preferred_reductions_per_wi, reductions_per_wi, - {gemm_ev}); - - sycl::event cleanup_host_task_event = - exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(red_ev); - const sycl::context &ctx = 
exec_q.get_context(); - - cgh.host_task([ctx, partially_reduced_tmp] { - sycl::free(partially_reduced_tmp, ctx); - }); - }); - - return cleanup_host_task_event; - } - } - } - else { // m > 1, n > k or m > k - constexpr int wi_delta_n = 2; - size_t wg_delta_n(16); // rows of A processed in WG - size_t wg_delta_m(16); // rows of B processed in WG - size_t wi_delta_k(64); // Elements in K dimension processed by WI + size_t lws = wg_delta_n * wg_delta_m; - using dpctl::tensor::type_utils::is_complex; - if constexpr (!is_complex::value) { - constexpr int wi_delta_m = 4; + size_t n_blocks = ((n + wi_delta_n * wg_delta_n - 1) / + (wi_delta_n * wg_delta_n)); + size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); + size_t m_blocks = ((m + wi_delta_m * wg_delta_m - 1) / + (wi_delta_m * wg_delta_m)); - gemm_detail::scale_gemm_nm_parameters( - local_mem_size, reserved_slm_size, wi_delta_n, - wi_delta_k, // modified by reference - wg_delta_n, // modified by reference - wg_delta_m // modified by reference - ); + auto gwsRange = sycl::range<1>(batch_nelems * n_blocks * + m_blocks * k_blocks * lws); + auto lwsRange = sycl::range<1>(lws); - // each group processes delta_k * n_wi - // items in a column, so no need for allocating - // temp memory if only one group is needed - if (k <= wi_delta_k) { - sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(depends); - - using OuterInnerDimsIndexerT = - dpctl::tensor::offset_utils::NoOpIndexer; - OuterInnerDimsIndexerT lhs_indexer{}; - OuterInnerDimsIndexerT rhs_indexer{}; - OuterInnerDimsIndexerT res_indexer{}; - using dpctl::tensor::offset_utils:: - ThreeOffsets_CombinedIndexer; - using dpctl::tensor::offset_utils::Strided1DIndexer; - using BatchDimsIndexerT = ThreeOffsets_CombinedIndexer< - Strided1DIndexer, Strided1DIndexer, Strided1DIndexer>; - - BatchDimsIndexerT batch_indexer( - Strided1DIndexer{0, - static_cast(batch_nelems), - static_cast(n * k)}, - Strided1DIndexer{0, - static_cast(batch_nelems), - static_cast(k * m)}, - Strided1DIndexer{0, - static_cast(batch_nelems), - static_cast(n * m)}); - - size_t lws = wg_delta_n * wg_delta_m; - - size_t n_blocks = ((n + wi_delta_n * wg_delta_n - 1) / - (wi_delta_n * wg_delta_n)); - size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); - size_t m_blocks = ((m + wi_delta_m * wg_delta_m - 1) / - (wi_delta_m * wg_delta_m)); - - auto gwsRange = sycl::range<1>(batch_nelems * n_blocks * - m_blocks * k_blocks * lws); - auto lwsRange = sycl::range<1>(lws); - - auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); + auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); + if constexpr (wi_delta_m == 1) { using LocAccT1 = sycl::local_accessor; LocAccT1 local_A_block( sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), cgh); - using LocAccT2 = - sycl::local_accessor, 1>; + using LocAccT2 = sycl::local_accessor; LocAccT2 local_B_block( sycl::range<1>(wi_delta_k * wg_delta_m), cgh); @@ -7919,212 +6545,12 @@ gemm_batch_contig_tree_impl(sycl::queue &exec_q, lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, BatchDimsIndexerT, wi_delta_n, wi_delta_m>( - lhs_tp, rhs_tp, res_tp, local_A_block, - local_B_block, n, wg_delta_n, k, k_blocks, - wi_delta_k, m, m_blocks, wg_delta_m, batch_nelems, - batch_indexer, lhs_indexer, rhs_indexer, - res_indexer)); - }); - return gemm_ev; - } - else { - using ReductionOpT = sycl::plus; - constexpr resTy identity_val = - sycl::known_identity::value; - size_t iter_nelems = batch_nelems * n * m; - size_t reduction_nelems = (k + 
wi_delta_k - 1) / wi_delta_k; - - // more than one work-group is needed, requires a temporary - // delta_k * n_wi elements processed along k, so if more to - // process use multiple - const auto &sg_sizes = - dev.get_info(); - size_t wg = - choose_workgroup_size<4>(reduction_nelems, sg_sizes); - - if (reduction_nelems < wg) { - resTy *tmp = sycl::malloc_device( - iter_nelems * reduction_nelems, exec_q); - sycl::event gemm_ev = exec_q.submit([&](sycl::handler - &cgh) { - cgh.depends_on(depends); - - using OuterInnerDimsIndexerT = - dpctl::tensor::offset_utils::NoOpIndexer; - OuterInnerDimsIndexerT lhs_indexer{}; - OuterInnerDimsIndexerT rhs_indexer{}; - OuterInnerDimsIndexerT tmp_indexer{}; - using dpctl::tensor::offset_utils:: - ThreeOffsets_CombinedIndexer; - using dpctl::tensor::offset_utils::Strided1DIndexer; - using BatchDimsIndexerT = - ThreeOffsets_CombinedIndexer; - - BatchDimsIndexerT batch_indexer( - Strided1DIndexer{ - 0, static_cast(batch_nelems), - static_cast(n * k)}, - Strided1DIndexer{ - 0, static_cast(batch_nelems), - static_cast(k * m)}, - Strided1DIndexer{ - 0, static_cast(batch_nelems), - static_cast(n * m)}); - - size_t lws = wg_delta_n * wg_delta_m; - - size_t n_blocks = ((n + wi_delta_n * wg_delta_n - 1) / - (wi_delta_n * wg_delta_n)); - size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); - size_t m_blocks = ((m + wi_delta_m * wg_delta_m - 1) / - (wi_delta_m * wg_delta_m)); - - auto gwsRange = - sycl::range<1>(batch_nelems * n_blocks * m_blocks * - k_blocks * lws); - auto lwsRange = sycl::range<1>(lws); - - auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); - - using LocAccT1 = sycl::local_accessor; - LocAccT1 local_A_block( - sycl::range<1>((wi_delta_n * wg_delta_n) * - wi_delta_k), - cgh); - using LocAccT2 = - sycl::local_accessor, - 1>; - LocAccT2 local_B_block( - sycl::range<1>(wi_delta_k * wg_delta_m), cgh); - - using KernelName = class gemm_batch_tree_nm_krn< - lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, - OuterInnerDimsIndexerT, BatchDimsIndexerT, - wi_delta_m>; - cgh.parallel_for( - ndRange, - GemmBatchNoAtomicFunctorThreadNM< - lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, - OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, - BatchDimsIndexerT, wi_delta_n, wi_delta_m>( - lhs_tp, rhs_tp, tmp, local_A_block, - local_B_block, n, wg_delta_n, k, k_blocks, - wi_delta_k, m, m_blocks, wg_delta_m, - batch_nelems, batch_indexer, lhs_indexer, - rhs_indexer, tmp_indexer)); - }); - sycl::event red_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(gemm_ev); - - using NoOpIndexerT = - dpctl::tensor::offset_utils::NoOpIndexer; - using InputOutputIterIndexerT = - dpctl::tensor::offset_utils:: - TwoOffsets_CombinedIndexer; - using ReductionIndexerT = - dpctl::tensor::offset_utils::Strided1DIndexer; - - InputOutputIterIndexerT in_out_iter_indexer{ - NoOpIndexerT{}, NoOpIndexerT{}}; - ReductionIndexerT reduction_indexer{ - 0, static_cast(reduction_nelems), - static_cast(iter_nelems)}; - - sycl::range<1> iter_range{iter_nelems}; - - cgh.parallel_for>( - iter_range, - SequentialReduction( - tmp, res_tp, ReductionOpT(), identity_val, - in_out_iter_indexer, reduction_indexer, - reduction_nelems)); - }); - sycl::event cleanup_host_task_event = - exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(red_ev); - const sycl::context &ctx = exec_q.get_context(); - - cgh.host_task([ctx, tmp] { sycl::free(tmp, ctx); }); - }); - return cleanup_host_task_event; - } - - // max_max_wg prevents running out of resources on CPU - constexpr size_t max_max_wg = 2048; - size_t max_wg = 
std::min( - max_max_wg, - dev.get_info() / - 2); - - constexpr size_t preferred_reductions_per_wi = 4; - size_t reductions_per_wi(preferred_reductions_per_wi); - - size_t reduction_groups = - (reduction_nelems + preferred_reductions_per_wi * wg - 1) / - (preferred_reductions_per_wi * wg); - assert(reduction_groups > 1); - - resTy *partially_reduced_tmp = sycl::malloc_device( - iter_nelems * (/* temp */ reduction_nelems + - /* first reduction temp */ reduction_groups), - exec_q); - resTy *partially_reduced_tmp2 = nullptr; - - if (partially_reduced_tmp == nullptr) { - throw std::runtime_error( - "Unable to allocate device_memory"); + lhs_tp, rhs_tp, tmp, local_A_block, local_B_block, + n, wg_delta_n, k, k_blocks, wi_delta_k, m, m_blocks, + wg_delta_m, batch_nelems, batch_indexer, + lhs_indexer, rhs_indexer, tmp_indexer)); } else { - partially_reduced_tmp2 = - partially_reduced_tmp + reduction_nelems * iter_nelems; - } - - sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(depends); - - using OuterInnerDimsIndexerT = - dpctl::tensor::offset_utils::NoOpIndexer; - OuterInnerDimsIndexerT lhs_indexer{}; - OuterInnerDimsIndexerT rhs_indexer{}; - OuterInnerDimsIndexerT tmp_indexer{}; - using dpctl::tensor::offset_utils:: - ThreeOffsets_CombinedIndexer; - using dpctl::tensor::offset_utils::Strided1DIndexer; - using BatchDimsIndexerT = ThreeOffsets_CombinedIndexer< - Strided1DIndexer, Strided1DIndexer, Strided1DIndexer>; - - BatchDimsIndexerT batch_indexer( - Strided1DIndexer{0, - static_cast(batch_nelems), - static_cast(n * k)}, - Strided1DIndexer{0, - static_cast(batch_nelems), - static_cast(k * m)}, - Strided1DIndexer{0, - static_cast(batch_nelems), - static_cast(n * m)}); - - size_t lws = wg_delta_n * wg_delta_m; - - size_t n_blocks = ((n + wi_delta_n * wg_delta_n - 1) / - (wi_delta_n * wg_delta_n)); - size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); - size_t m_blocks = ((m + wi_delta_m * wg_delta_m - 1) / - (wi_delta_m * wg_delta_m)); - - auto gwsRange = sycl::range<1>(batch_nelems * n_blocks * - m_blocks * k_blocks * lws); - auto lwsRange = sycl::range<1>(lws); - - auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); - using LocAccT1 = sycl::local_accessor; LocAccT1 local_A_block( sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), @@ -8143,349 +6569,233 @@ gemm_batch_contig_tree_impl(sycl::queue &exec_q, lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, BatchDimsIndexerT, wi_delta_n, wi_delta_m>( - lhs_tp, rhs_tp, partially_reduced_tmp, - local_A_block, local_B_block, n, wg_delta_n, k, - k_blocks, wi_delta_k, m, m_blocks, wg_delta_m, - batch_nelems, batch_indexer, lhs_indexer, - rhs_indexer, tmp_indexer)); + lhs_tp, rhs_tp, tmp, local_A_block, local_B_block, + n, wg_delta_n, k, k_blocks, wi_delta_k, m, m_blocks, + wg_delta_m, batch_nelems, batch_indexer, + lhs_indexer, rhs_indexer, tmp_indexer)); + } + }); + sycl::event red_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(gemm_ev); + + using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + NoOpIndexerT, NoOpIndexerT>; + using ReductionIndexerT = + dpctl::tensor::offset_utils::Strided1DIndexer; + + InputOutputIterIndexerT in_out_iter_indexer{NoOpIndexerT{}, + NoOpIndexerT{}}; + ReductionIndexerT reduction_indexer{ + 0, static_cast(reduction_nelems), + static_cast(iter_nelems)}; + + sycl::range<1> iter_range{iter_nelems}; + + cgh.parallel_for>( + 
iter_range, SequentialReduction( + tmp, res_tp, ReductionOpT(), identity_val, + in_out_iter_indexer, reduction_indexer, + reduction_nelems)); + }); + sycl::event cleanup_host_task_event = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(red_ev); + const sycl::context &ctx = exec_q.get_context(); + + cgh.host_task([ctx, tmp] { sycl::free(tmp, ctx); }); }); + return cleanup_host_task_event; + } - sycl::event red_ev = - tree_reduction_for_gemm_contig( - exec_q, partially_reduced_tmp, partially_reduced_tmp2, - res_tp, identity_val, iter_nelems, reduction_nelems, - reduction_groups, wg, max_wg, - preferred_reductions_per_wi, reductions_per_wi, - {gemm_ev}); - - sycl::event cleanup_host_task_event = - exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(red_ev); - const sycl::context &ctx = exec_q.get_context(); - - cgh.host_task([ctx, partially_reduced_tmp] { - sycl::free(partially_reduced_tmp, ctx); - }); - }); - - return cleanup_host_task_event; - } + // max_max_wg prevents running out of resources on CPU + constexpr size_t max_max_wg = 2048; + size_t max_wg = std::min( + max_max_wg, + dev.get_info() / 2); + + constexpr size_t preferred_reductions_per_wi = 4; + size_t reductions_per_wi(preferred_reductions_per_wi); + + size_t reduction_groups = + (reduction_nelems + preferred_reductions_per_wi * wg - 1) / + (preferred_reductions_per_wi * wg); + assert(reduction_groups > 1); + + resTy *partially_reduced_tmp = sycl::malloc_device( + iter_nelems * (/* temp */ reduction_nelems + + /* first reduction temp */ reduction_groups), + exec_q); + resTy *partially_reduced_tmp2 = nullptr; + + if (partially_reduced_tmp == nullptr) { + throw std::runtime_error("Unable to allocate device_memory"); + } + else { + partially_reduced_tmp2 = + partially_reduced_tmp + reduction_nelems * iter_nelems; } - else { // m > 1, n > k or m > k, resTy not complex - constexpr int wi_delta_m = 1; - gemm_detail::scale_gemm_nm_parameters( - local_mem_size, reserved_slm_size, wi_delta_n, - wi_delta_k, // modified by reference - wg_delta_n, // modified by reference - wg_delta_m // modified by reference - ); + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); - if (k <= wi_delta_k) { - sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(depends); - - using OuterInnerDimsIndexerT = - dpctl::tensor::offset_utils::NoOpIndexer; - OuterInnerDimsIndexerT lhs_indexer{}; - OuterInnerDimsIndexerT rhs_indexer{}; - OuterInnerDimsIndexerT res_indexer{}; - using dpctl::tensor::offset_utils:: - ThreeOffsets_CombinedIndexer; - using dpctl::tensor::offset_utils::Strided1DIndexer; - using BatchDimsIndexerT = ThreeOffsets_CombinedIndexer< - Strided1DIndexer, Strided1DIndexer, Strided1DIndexer>; - - BatchDimsIndexerT batch_indexer( - Strided1DIndexer{0, - static_cast(batch_nelems), - static_cast(n * k)}, - Strided1DIndexer{0, - static_cast(batch_nelems), - static_cast(k * m)}, - Strided1DIndexer{0, - static_cast(batch_nelems), - static_cast(n * m)}); - - size_t lws = wg_delta_n * wg_delta_m; - - size_t n_blocks = ((n + wi_delta_n * wg_delta_n - 1) / - (wi_delta_n * wg_delta_n)); - size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); - size_t m_blocks = ((m + wi_delta_m * wg_delta_m - 1) / - (wi_delta_m * wg_delta_m)); - - auto gwsRange = sycl::range<1>(batch_nelems * n_blocks * - m_blocks * k_blocks * lws); - auto lwsRange = sycl::range<1>(lws); - - auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); + using OuterInnerDimsIndexerT = + 
dpctl::tensor::offset_utils::NoOpIndexer; + OuterInnerDimsIndexerT lhs_indexer{}; + OuterInnerDimsIndexerT rhs_indexer{}; + OuterInnerDimsIndexerT tmp_indexer{}; + using dpctl::tensor::offset_utils::ThreeOffsets_CombinedIndexer; + using dpctl::tensor::offset_utils::Strided1DIndexer; + using BatchDimsIndexerT = + ThreeOffsets_CombinedIndexer; + + BatchDimsIndexerT batch_indexer( + Strided1DIndexer{0, static_cast(batch_nelems), + static_cast(n * k)}, + Strided1DIndexer{0, static_cast(batch_nelems), + static_cast(k * m)}, + Strided1DIndexer{0, static_cast(batch_nelems), + static_cast(n * m)}); - using LocAccT1 = sycl::local_accessor; - LocAccT1 local_A_block( - sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), - cgh); - using LocAccT2 = sycl::local_accessor; - LocAccT2 local_B_block( - sycl::range<1>(wi_delta_k * wg_delta_m), cgh); + size_t lws = wg_delta_n * wg_delta_m; - using KernelName = class gemm_batch_tree_nm_krn< - lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, - OuterInnerDimsIndexerT, BatchDimsIndexerT, wi_delta_m>; - cgh.parallel_for( - ndRange, - GemmBatchNoAtomicFunctorThreadNM< - lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, - OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, - BatchDimsIndexerT, wi_delta_n, wi_delta_m>( - lhs_tp, rhs_tp, res_tp, local_A_block, - local_B_block, n, wg_delta_n, k, k_blocks, - wi_delta_k, m, m_blocks, wg_delta_m, batch_nelems, - batch_indexer, lhs_indexer, rhs_indexer, - res_indexer)); - }); - return gemm_ev; + size_t n_blocks = + ((n + wi_delta_n * wg_delta_n - 1) / (wi_delta_n * wg_delta_n)); + size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); + size_t m_blocks = + ((m + wi_delta_m * wg_delta_m - 1) / (wi_delta_m * wg_delta_m)); + + auto gwsRange = sycl::range<1>(batch_nelems * n_blocks * m_blocks * + k_blocks * lws); + auto lwsRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); + + if constexpr (wi_delta_m == 1) { + using LocAccT1 = sycl::local_accessor; + LocAccT1 local_A_block( + sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), + cgh); + using LocAccT2 = sycl::local_accessor; + LocAccT2 local_B_block(sycl::range<1>(wi_delta_k * wg_delta_m), + cgh); + + using KernelName = class gemm_batch_tree_nm_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, BatchDimsIndexerT, wi_delta_m>; + cgh.parallel_for( + ndRange, + GemmBatchNoAtomicFunctorThreadNM< + lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, + BatchDimsIndexerT, wi_delta_n, wi_delta_m>( + lhs_tp, rhs_tp, partially_reduced_tmp, local_A_block, + local_B_block, n, wg_delta_n, k, k_blocks, wi_delta_k, + m, m_blocks, wg_delta_m, batch_nelems, batch_indexer, + lhs_indexer, rhs_indexer, tmp_indexer)); } else { - using ReductionOpT = sycl::plus; - constexpr resTy identity_val = - sycl::known_identity::value; - size_t iter_nelems = batch_nelems * n * m; - size_t reduction_nelems = (k + wi_delta_k - 1) / wi_delta_k; - - // more than one work-group is needed, requires a temporary - // delta_k * n_wi elements processed along k, so if more to - // process use multiple - const auto &sg_sizes = - dev.get_info(); - size_t wg = - choose_workgroup_size<4>(reduction_nelems, sg_sizes); - - if (reduction_nelems < wg) { - resTy *tmp = sycl::malloc_device( - iter_nelems * reduction_nelems, exec_q); - sycl::event gemm_ev = exec_q.submit([&](sycl::handler - &cgh) { - cgh.depends_on(depends); - - using OuterInnerDimsIndexerT = - dpctl::tensor::offset_utils::NoOpIndexer; - OuterInnerDimsIndexerT 
lhs_indexer{}; - OuterInnerDimsIndexerT rhs_indexer{}; - OuterInnerDimsIndexerT tmp_indexer{}; - using dpctl::tensor::offset_utils:: - ThreeOffsets_CombinedIndexer; - using dpctl::tensor::offset_utils::Strided1DIndexer; - using BatchDimsIndexerT = - ThreeOffsets_CombinedIndexer; - - BatchDimsIndexerT batch_indexer( - Strided1DIndexer{ - 0, static_cast(batch_nelems), - static_cast(n * k)}, - Strided1DIndexer{ - 0, static_cast(batch_nelems), - static_cast(k * m)}, - Strided1DIndexer{ - 0, static_cast(batch_nelems), - static_cast(n * m)}); - - size_t lws = wg_delta_n * wg_delta_m; - - size_t n_blocks = ((n + wi_delta_n * wg_delta_n - 1) / - (wi_delta_n * wg_delta_n)); - size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); - size_t m_blocks = ((m + wi_delta_m * wg_delta_m - 1) / - (wi_delta_m * wg_delta_m)); - - auto gwsRange = - sycl::range<1>(batch_nelems * n_blocks * m_blocks * - k_blocks * lws); - auto lwsRange = sycl::range<1>(lws); - - auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); - - using LocAccT1 = sycl::local_accessor; - LocAccT1 local_A_block( - sycl::range<1>((wi_delta_n * wg_delta_n) * - wi_delta_k), - cgh); - using LocAccT2 = sycl::local_accessor; - LocAccT2 local_B_block( - sycl::range<1>(wi_delta_k * wg_delta_m), cgh); - - using KernelName = class gemm_batch_tree_nm_krn< - lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, - OuterInnerDimsIndexerT, BatchDimsIndexerT, - wi_delta_m>; - cgh.parallel_for( - ndRange, - GemmBatchNoAtomicFunctorThreadNM< - lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, - OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, - BatchDimsIndexerT, wi_delta_n, wi_delta_m>( - lhs_tp, rhs_tp, tmp, local_A_block, - local_B_block, n, wg_delta_n, k, k_blocks, - wi_delta_k, m, m_blocks, wg_delta_m, - batch_nelems, batch_indexer, lhs_indexer, - rhs_indexer, tmp_indexer)); - }); - sycl::event red_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(gemm_ev); - - using NoOpIndexerT = - dpctl::tensor::offset_utils::NoOpIndexer; - using InputOutputIterIndexerT = - dpctl::tensor::offset_utils:: - TwoOffsets_CombinedIndexer; - using ReductionIndexerT = - dpctl::tensor::offset_utils::Strided1DIndexer; - - InputOutputIterIndexerT in_out_iter_indexer{ - NoOpIndexerT{}, NoOpIndexerT{}}; - ReductionIndexerT reduction_indexer{ - 0, static_cast(reduction_nelems), - static_cast(iter_nelems)}; - - sycl::range<1> iter_range{iter_nelems}; - - cgh.parallel_for>( - iter_range, - SequentialReduction( - tmp, res_tp, ReductionOpT(), identity_val, - in_out_iter_indexer, reduction_indexer, - reduction_nelems)); - }); - sycl::event cleanup_host_task_event = - exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(red_ev); - const sycl::context &ctx = exec_q.get_context(); - - cgh.host_task([ctx, tmp] { sycl::free(tmp, ctx); }); - }); - return cleanup_host_task_event; - } + using LocAccT1 = sycl::local_accessor; + LocAccT1 local_A_block( + sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), + cgh); + using LocAccT2 = + sycl::local_accessor, 1>; + LocAccT2 local_B_block(sycl::range<1>(wi_delta_k * wg_delta_m), + cgh); - // max_max_wg prevents running out of resources on CPU - constexpr size_t max_max_wg = 2048; - size_t max_wg = std::min( - max_max_wg, - dev.get_info() / - 2); - - constexpr size_t preferred_reductions_per_wi = 4; - size_t reductions_per_wi(preferred_reductions_per_wi); - - size_t reduction_groups = - (reduction_nelems + preferred_reductions_per_wi * wg - 1) / - (preferred_reductions_per_wi * wg); - assert(reduction_groups > 1); - - resTy *partially_reduced_tmp = 
sycl::malloc_device( - iter_nelems * (/* temp */ reduction_nelems + - /* first reduction temp */ reduction_groups), - exec_q); - resTy *partially_reduced_tmp2 = nullptr; - - if (partially_reduced_tmp == nullptr) { - throw std::runtime_error( - "Unable to allocate device_memory"); - } - else { - partially_reduced_tmp2 = - partially_reduced_tmp + reduction_nelems * iter_nelems; - } + using KernelName = class gemm_batch_tree_nm_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, BatchDimsIndexerT, wi_delta_m>; + cgh.parallel_for( + ndRange, + GemmBatchNoAtomicFunctorThreadNM< + lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, + BatchDimsIndexerT, wi_delta_n, wi_delta_m>( + lhs_tp, rhs_tp, partially_reduced_tmp, local_A_block, + local_B_block, n, wg_delta_n, k, k_blocks, wi_delta_k, + m, m_blocks, wg_delta_m, batch_nelems, batch_indexer, + lhs_indexer, rhs_indexer, tmp_indexer)); + } + }); - sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(depends); - - using OuterInnerDimsIndexerT = - dpctl::tensor::offset_utils::NoOpIndexer; - OuterInnerDimsIndexerT lhs_indexer{}; - OuterInnerDimsIndexerT rhs_indexer{}; - OuterInnerDimsIndexerT tmp_indexer{}; - using dpctl::tensor::offset_utils:: - ThreeOffsets_CombinedIndexer; - using dpctl::tensor::offset_utils::Strided1DIndexer; - using BatchDimsIndexerT = ThreeOffsets_CombinedIndexer< - Strided1DIndexer, Strided1DIndexer, Strided1DIndexer>; - - BatchDimsIndexerT batch_indexer( - Strided1DIndexer{0, - static_cast(batch_nelems), - static_cast(n * k)}, - Strided1DIndexer{0, - static_cast(batch_nelems), - static_cast(k * m)}, - Strided1DIndexer{0, - static_cast(batch_nelems), - static_cast(n * m)}); - - size_t lws = wg_delta_n * wg_delta_m; - - size_t n_blocks = ((n + wi_delta_n * wg_delta_n - 1) / - (wi_delta_n * wg_delta_n)); - size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); - size_t m_blocks = ((m + wi_delta_m * wg_delta_m - 1) / - (wi_delta_m * wg_delta_m)); - - auto gwsRange = sycl::range<1>(batch_nelems * n_blocks * - m_blocks * k_blocks * lws); - auto lwsRange = sycl::range<1>(lws); - - auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); + sycl::event red_ev = + tree_reduction_for_gemm_contig( + exec_q, partially_reduced_tmp, partially_reduced_tmp2, res_tp, + identity_val, iter_nelems, reduction_nelems, reduction_groups, + wg, max_wg, preferred_reductions_per_wi, reductions_per_wi, + {gemm_ev}); - using LocAccT1 = sycl::local_accessor; - LocAccT1 local_A_block( - sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), - cgh); - using LocAccT2 = sycl::local_accessor; - LocAccT2 local_B_block( - sycl::range<1>(wi_delta_k * wg_delta_m), cgh); + sycl::event cleanup_host_task_event = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(red_ev); + const sycl::context &ctx = exec_q.get_context(); - using KernelName = class gemm_batch_tree_nm_krn< - lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, - OuterInnerDimsIndexerT, BatchDimsIndexerT, wi_delta_m>; - cgh.parallel_for( - ndRange, - GemmBatchNoAtomicFunctorThreadNM< - lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, - OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, - BatchDimsIndexerT, wi_delta_n, wi_delta_m>( - lhs_tp, rhs_tp, partially_reduced_tmp, - local_A_block, local_B_block, n, wg_delta_n, k, - k_blocks, wi_delta_k, m, m_blocks, wg_delta_m, - batch_nelems, batch_indexer, lhs_indexer, - rhs_indexer, tmp_indexer)); + cgh.host_task([ctx, partially_reduced_tmp] { + sycl::free(partially_reduced_tmp, 
ctx); }); + }); + + return cleanup_host_task_event; + } +} + +template +sycl::event +gemm_batch_contig_tree_impl(sycl::queue &exec_q, + const char *lhs_cp, + const char *rhs_cp, + char *res_cp, + size_t batch_nelems, + size_t n, + size_t k, + size_t m, + py::ssize_t lhs_batch_offset, + py::ssize_t rhs_batch_offset, + py::ssize_t res_batch_offset, + std::vector const &depends = {}) +{ + const lhsTy *lhs_tp = + reinterpret_cast(lhs_cp) + lhs_batch_offset; + const rhsTy *rhs_tp = + reinterpret_cast(rhs_cp) + rhs_batch_offset; + resTy *res_tp = reinterpret_cast(res_cp) + res_batch_offset; - sycl::event red_ev = - tree_reduction_for_gemm_contig( - exec_q, partially_reduced_tmp, partially_reduced_tmp2, - res_tp, identity_val, iter_nelems, reduction_nelems, - reduction_groups, wg, max_wg, - preferred_reductions_per_wi, reductions_per_wi, - {gemm_ev}); - - sycl::event cleanup_host_task_event = - exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(red_ev); - const sycl::context &ctx = exec_q.get_context(); - - cgh.host_task([ctx, partially_reduced_tmp] { - sycl::free(partially_reduced_tmp, ctx); - }); - }); - - return cleanup_host_task_event; + if ((k > n && k > m) || m == 1) { + using dpctl::tensor::type_utils::is_complex; + if constexpr (!is_complex::value) { + if (m == 1) { + return gemm_batch_contig_tree_k_impl( + exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, + depends); } + else { + return gemm_batch_contig_tree_k_impl( + exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, + depends); + } + } + else { + return gemm_batch_contig_tree_k_impl( + exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, depends); + } + } + else { // m > 1, n > k or m > k + using dpctl::tensor::type_utils::is_complex; + if constexpr (!is_complex::value) { + return gemm_batch_contig_tree_nm_impl( + exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, depends); + } + else { // m > 1, n > k or m > k, resTy complex + return gemm_batch_contig_tree_nm_impl( + exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, depends); } } } From eaa048a888230c0f19d978afdd92789c1b8747b3 Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Thu, 11 Jan 2024 01:20:16 -0800 Subject: [PATCH 16/48] Test reversing data types for `test_matmul_strided` --- dpctl/tests/test_usm_ndarray_linalg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dpctl/tests/test_usm_ndarray_linalg.py b/dpctl/tests/test_usm_ndarray_linalg.py index 29b9ee17d3..70ab1092a9 100644 --- a/dpctl/tests/test_usm_ndarray_linalg.py +++ b/dpctl/tests/test_usm_ndarray_linalg.py @@ -261,7 +261,7 @@ def test_matmul_broadcasting(): assert r.shape == (7, 11, 13) -@pytest.mark.parametrize("dtype", ["i4", "i8", "f4", "c8"]) +@pytest.mark.parametrize("dtype", ["i4", "i8", "f4", "c8"][::-1]) def test_matmul_strided(dtype): get_queue_or_skip() From 6c57d2b818efc338a2ff00633bfa27d9fc2848d2 Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Thu, 11 Jan 2024 07:39:56 -0800 Subject: [PATCH 17/48] pre-commit fixes in `gemm.hpp` --- .../libtensor/include/kernels/linalg_functions/gemm.hpp | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp b/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp index 1cc7e9db65..b8095491aa 100644 --- a/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp +++ b/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp @@ -5871,14 +5871,15 @@ gemm_batch_tree_nm_impl(sycl::queue &exec_q, m, m_blocks, 
wg_delta_m, batch_nelems, batch_indexer, lhs_indexer, rhs_indexer, res_indexer)); } - else { + else { using LocAccT1 = sycl::local_accessor; LocAccT1 local_A_block( - sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), cgh); + sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), + cgh); using LocAccT2 = sycl::local_accessor, 1>; LocAccT2 local_B_block(sycl::range<1>(wi_delta_k * wg_delta_m), - cgh); + cgh); using KernelName = class gemm_batch_tree_nm_krn< lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, TmpIndexerT, BatchDimsIndexerT, wi_delta_m>; From 7875f3818352475481f99796b52e4bc9ef19714e Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Thu, 11 Jan 2024 12:30:10 -0600 Subject: [PATCH 18/48] Check if malloc_device return nullptr (#1493) --- .../include/kernels/linalg_functions/gemm.hpp | 26 ++++++++++++++++++- 1 file changed, 25 insertions(+), 1 deletion(-) diff --git a/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp b/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp index b8095491aa..de06fcda55 100644 --- a/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp +++ b/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp @@ -2024,6 +2024,9 @@ sycl::event gemm_tree_k_impl(sycl::queue &exec_q, if (reduction_nelems < wg) { resTy *tmp = sycl::malloc_device( iter_nelems * reduction_nelems, exec_q); + if (!tmp) { + throw std::runtime_error("Unable to allocate device memory"); + } sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { cgh.depends_on(depends); @@ -2144,7 +2147,7 @@ sycl::event gemm_tree_k_impl(sycl::queue &exec_q, resTy *partially_reduced_tmp2 = nullptr; if (partially_reduced_tmp == nullptr) { - throw std::runtime_error("Unable to allocate device_memory"); + throw std::runtime_error("Unable to allocate device memory"); } else { partially_reduced_tmp2 = @@ -2360,6 +2363,9 @@ sycl::event gemm_tree_nm_impl(sycl::queue &exec_q, if (reduction_nelems < wg) { resTy *tmp = sycl::malloc_device( iter_nelems * reduction_nelems, exec_q); + if (!tmp) { + throw std::runtime_error("Unable to allocate device memory"); + } sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { cgh.depends_on(depends); @@ -2768,6 +2774,9 @@ sycl::event gemm_contig_tree_k_impl(sycl::queue &exec_q, if (reduction_nelems < wg) { resTy *tmp = sycl::malloc_device( iter_nelems * reduction_nelems, exec_q); + if (!tmp) { + throw std::runtime_error("Unable to allocate device memory"); + } sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { cgh.depends_on(depends); @@ -3092,6 +3101,9 @@ sycl::event gemm_contig_tree_nm_impl(sycl::queue &exec_q, if (reduction_nelems < wg) { resTy *tmp = sycl::malloc_device( iter_nelems * reduction_nelems, exec_q); + if (!tmp) { + throw std::runtime_error("Unable to allocate device memory"); + } sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { cgh.depends_on(depends); @@ -5254,6 +5266,9 @@ gemm_batch_tree_k_impl(sycl::queue &exec_q, if (reduction_nelems < wg) { resTy *tmp = sycl::malloc_device( iter_nelems * reduction_nelems, exec_q); + if (!tmp) { + throw std::runtime_error("Unable to allocate device memory"); + } sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { cgh.depends_on(depends); @@ -5647,6 +5662,9 @@ gemm_batch_tree_nm_impl(sycl::queue &exec_q, if (reduction_nelems < wg) { resTy *tmp = sycl::malloc_device( iter_nelems * reduction_nelems, exec_q); + if (!tmp) { + throw std::runtime_error("Unable to allocate device memory"); + } sycl::event gemm_ev = 
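// Annotation (editor's note, not part of the upstream patch): the checks
// added in this commit guard every temporary allocation because
// sycl::malloc_device reports failure by returning a null pointer rather
// than by throwing; raising std::runtime_error at the allocation site
// surfaces the out-of-memory condition before the pointer is captured by
// the kernel submission that follows.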
exec_q.submit([&](sycl::handler &cgh) { cgh.depends_on(depends); @@ -6124,6 +6142,9 @@ gemm_batch_contig_tree_k_impl(sycl::queue &exec_q, if (reduction_nelems < wg) { resTy *tmp = sycl::malloc_device( iter_nelems * reduction_nelems, exec_q); + if (!tmp) { + throw std::runtime_error("Unable to allocate device memory"); + } sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { cgh.depends_on(depends); @@ -6493,6 +6514,9 @@ gemm_batch_contig_tree_nm_impl(sycl::queue &exec_q, if (reduction_nelems < wg) { resTy *tmp = sycl::malloc_device( iter_nelems * reduction_nelems, exec_q); + if (!tmp) { + throw std::runtime_error("Unable to allocate device memory"); + } sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { cgh.depends_on(depends); From f0079cfb93a56a10d4511d1d3ae02f74b717d0c0 Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Thu, 11 Jan 2024 10:33:59 -0800 Subject: [PATCH 19/48] Add step to Linux conda package workflow to run `test_matmul_strided` under gdb Part of triaging CPU crashes --- .github/workflows/conda-package.yml | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/.github/workflows/conda-package.yml b/.github/workflows/conda-package.yml index fdc301c8cf..5ff5f4961b 100644 --- a/.github/workflows/conda-package.yml +++ b/.github/workflows/conda-package.yml @@ -191,6 +191,11 @@ jobs: . $CONDA/etc/profile.d/conda.sh conda activate test_dpctl gdb --batch -ex r -ex 'info sharedlibrary' -ex 'set print elements 1000' -ex bt --args ${CONDA_PREFIX}/bin/python -m pytest -q -ra --disable-warnings --pyargs dpctl.tests.elementwise.test_trigonometric::test_trig_order -vv || true + - name: Run test_matmul_strided under gdb + run: | + . $CONDA/etc/profile.d/conda.sh + conda activate test_dpctl + gdb --batch -ex r -ex 'info sharedlibrary' -ex 'set print elements 1000' -ex bt --args ${CONDA_PREFIX}/bin/python -m pytest -q -ra --disable-warnings --pyargs dpctl.tests.test_usm_ndarray_linalg::test_matmul_strided -vv || true - name: Run tests env: SYCL_CACHE_PERSISTENT: 1 From e05b8053bac82139897893238eb4c39b71d30bf2 Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Thu, 11 Jan 2024 11:52:48 -0800 Subject: [PATCH 20/48] Remove unnecessary comments --- .../libtensor/include/kernels/linalg_functions/gemm.hpp | 8 -------- 1 file changed, 8 deletions(-) diff --git a/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp b/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp index de06fcda55..b21a3568e0 100644 --- a/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp +++ b/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp @@ -2623,10 +2623,6 @@ sycl::event gemm_tree_impl(sycl::queue &exec_q, resTy *res_tp = reinterpret_cast(res_cp); if ((k > n && k > m) || m == 1) { - // each group processes delta_k * n_wi - // items in a column, so no need for allocating - // temp memory if only one group is needed - using dpctl::tensor::type_utils::is_complex; if constexpr (!is_complex::value) { if (m == 1) { @@ -3345,10 +3341,6 @@ sycl::event gemm_contig_tree_impl(sycl::queue &exec_q, resTy *res_tp = reinterpret_cast(res_cp); if ((k > n && k > m) || m == 1) { - // each group processes delta_k * n_wi - // items in a column, so no need for allocating - // temp memory if only one group is needed - using dpctl::tensor::type_utils::is_complex; if constexpr (!is_complex::value) { if (m == 1) { From cb06dedcf0f457135c1393bbe676bfc433b42172 Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Thu, 11 Jan 2024 13:46:28 -0800 Subject: [PATCH 21/48] 
Adds a fast-path for empty (k = 0) gemm kernels --- .../include/kernels/linalg_functions/gemm.hpp | 76 +++++++++++++++++++ 1 file changed, 76 insertions(+) diff --git a/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp b/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp index b21a3568e0..a6465770d9 100644 --- a/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp +++ b/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp @@ -1023,6 +1023,10 @@ sycl::event gemm_impl(sycl::queue &exec_q, }); }); + if (k == 0) { + return res_init_ev; + } + const sycl::device &dev = exec_q.get_device(); const size_t local_mem_size = dev.get_info(); @@ -1195,6 +1199,10 @@ sycl::event gemm_contig_impl(sycl::queue &exec_q, cgh.fill(res_tp, resTy(0), n * m); }); + if (k == 0) { + return res_init_ev; + } + const sycl::device &dev = exec_q.get_device(); const size_t local_mem_size = dev.get_info(); @@ -2601,6 +2609,8 @@ sycl::event gemm_tree_nm_impl(sycl::queue &exec_q, } } +template class gemm_tree_empty_krn; + template sycl::event gemm_tree_impl(sycl::queue &exec_q, const char *lhs_cp, @@ -2622,6 +2632,24 @@ sycl::event gemm_tree_impl(sycl::queue &exec_q, const rhsTy *rhs_tp = reinterpret_cast(rhs_cp); resTy *res_tp = reinterpret_cast(res_cp); + if (k == 0) { + sycl::event gemm_no_reduction_ev = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using IndexerT = dpctl::tensor::offset_utils::StridedIndexer; + IndexerT res_indexer(res_nd, 0, res_shapes_strides); + using InitKernelName = + class gemm_tree_empty_krn; + cgh.parallel_for( + sycl::range<1>(n * m), [=](sycl::id<1> id) { + auto res_offset = res_indexer(id[0]); + res_tp[res_offset] = resTy(0); + }); + }); + return gemm_no_reduction_ev; + } + if ((k > n && k > m) || m == 1) { using dpctl::tensor::type_utils::is_complex; if constexpr (!is_complex::value) { @@ -3340,6 +3368,15 @@ sycl::event gemm_contig_tree_impl(sycl::queue &exec_q, const rhsTy *rhs_tp = reinterpret_cast(rhs_cp); resTy *res_tp = reinterpret_cast(res_cp); + if (k == 0) { + sycl::event gemm_no_reduction_ev = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + cgh.fill(res_tp, resTy(0), n * m); + }); + return gemm_no_reduction_ev; + } + if ((k > n && k > m) || m == 1) { using dpctl::tensor::type_utils::is_complex; if constexpr (!is_complex::value) { @@ -4123,6 +4160,10 @@ sycl::event gemm_batch_impl(sycl::queue &exec_q, }); }); + if (k == 0) { + return res_init_ev; + } + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { cgh.depends_on(res_init_ev); @@ -4321,6 +4362,10 @@ sycl::event gemm_batch_contig_impl(sycl::queue &exec_q, cgh.fill(res_tp, resTy(0), n * m * batch_nelems); }); + if (k == 0) { + return res_init_ev; + } + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { cgh.depends_on(res_init_ev); @@ -5927,6 +5972,9 @@ gemm_batch_tree_nm_impl(sycl::queue &exec_q, } } +template +class gemm_batch_tree_empty_krn; + template sycl::event gemm_batch_tree_impl(sycl::queue &exec_q, @@ -5956,6 +6004,25 @@ gemm_batch_tree_impl(sycl::queue &exec_q, const rhsTy *rhs_tp = reinterpret_cast(rhs_cp); resTy *res_tp = reinterpret_cast(res_cp); + if (k == 0) { + sycl::event gemm_batch_no_reduction_ev = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using IndexerT = dpctl::tensor::offset_utils::StridedIndexer; + IndexerT res_indexer(batch_nd + res_outer_nd, res_batch_offset, + res_shape_strides); + using InitKernelName = + class gemm_batch_tree_empty_krn; + 
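// Annotation (editor's note, not part of the upstream patch): with k == 0
// the sum over the shared dimension is empty, so the product is all zeros by
// definition. This fast path therefore only launches the fill-style kernel
// below, writing resTy(0) through the strided result indexer for all
// n * m * batch_nelems elements, and skips the gemm and reduction machinery
// entirely; the contiguous overloads achieve the same with cgh.fill.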
cgh.parallel_for( + sycl::range<1>(n * m * batch_nelems), [=](sycl::id<1> id) { + auto res_offset = res_indexer(id[0]); + res_tp[res_offset] = resTy(0); + }); + }); + return gemm_batch_no_reduction_ev; + } + if ((k > n && k > m) || m == 1) { using dpctl::tensor::type_utils::is_complex; if constexpr (!is_complex::value) { @@ -6785,6 +6852,15 @@ gemm_batch_contig_tree_impl(sycl::queue &exec_q, reinterpret_cast(rhs_cp) + rhs_batch_offset; resTy *res_tp = reinterpret_cast(res_cp) + res_batch_offset; + if (k == 0) { + sycl::event gemm_batch_no_reduction_ev = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + cgh.fill(res_tp, resTy(0), n * m * batch_nelems); + }); + return gemm_batch_no_reduction_ev; + } + if ((k > n && k > m) || m == 1) { using dpctl::tensor::type_utils::is_complex; if constexpr (!is_complex::value) { From a06bb2d9f25d39959021404bfe6b5d44a4cd6b21 Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Thu, 11 Jan 2024 17:54:32 -0800 Subject: [PATCH 22/48] Adds logic that avoids certain kernels on CPU that are known to be problematic Specifically uses logic to always avoid paths which would call k threaded functors on CPU with m_groups > 1 --- .../include/kernels/linalg_functions/gemm.hpp | 387 +++++++++++------- 1 file changed, 245 insertions(+), 142 deletions(-) diff --git a/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp b/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp index a6465770d9..070fc000a1 100644 --- a/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp +++ b/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp @@ -1042,7 +1042,7 @@ sycl::event gemm_impl(sycl::queue &exec_q, rhs_shape_strides); OuterInnerIndexerT res_indexer(res_outer_nd, 0, res_shape_strides); - if ((k > n && k > m) || m == 1) { + if (m == 1) { constexpr size_t m_groups = 1; size_t delta_k(4); size_t n_wi(4); @@ -1077,46 +1077,42 @@ sycl::event gemm_impl(sycl::queue &exec_q, lhs_tp, rhs_tp, res_tp, workspace, local_B_block, n, n_blocks, delta_n, k, k_blocks, delta_k, n_wi, m, lhs_indexer, rhs_indexer, res_indexer)); - // } - // else if (k > n && k > m) { - // constexpr size_t m_groups = 2; - // size_t delta_k(4); - // size_t n_wi(4); - // size_t delta_n(4); - - // gemm_detail::scale_gemm_k_parameters( - // local_mem_size, reserved_slm_size, delta_k, - // n_wi, // modified by reference - // delta_n // modified by reference - // ); - - // size_t n_blocks = (n + delta_n - 1) / delta_n; - // size_t m_blocks = (m + m_groups - 1) / m_groups; - // size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * - // delta_k); - - // size_t lws = delta_n * delta_k; - - // auto gRange = sycl::range<1>(n_blocks * m_blocks * k_blocks * - // lws); auto lRange = sycl::range<1>(lws); - - // auto ndRange = sycl::nd_range<1>(gRange, lRange); - - // using LocAccT = sycl::local_accessor, 1>; LocAccT local_B_block(n_wi * delta_k, cgh); - // LocAccT workspace(delta_n * delta_k, cgh); - - // using KernelName = class gemm_k_krn; - // cgh.parallel_for( - // ndRange, GemmFunctorThreadK( - // lhs_tp, rhs_tp, res_tp, workspace, - // local_B_block, n, n_blocks, delta_n, k, - // k_blocks, delta_k, n_wi, m, lhs_indexer, - // rhs_indexer, res_indexer)); + } + else if (k > n && k > m && !exec_q.get_device().is_cpu()) { + constexpr size_t m_groups = 2; + size_t delta_k(4); + size_t n_wi(4); + size_t delta_n(4); + + gemm_detail::scale_gemm_k_parameters( + local_mem_size, reserved_slm_size, delta_k, + n_wi, // modified by reference + delta_n // modified by reference + 
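// Annotation (editor's note, not part of the upstream patch): the
// !exec_q.get_device().is_cpu() guard introduced here keeps the
// m_groups == 2 ThreadK kernel off CPU devices; per the commit message, the
// k-threaded functors with m_groups > 1 are problematic on CPU. Dispatch
// sketch, as read from this hunk (illustrative only):
//   if (m == 1)                        -> ThreadK, m_groups = 1
//   else if (k > n && k > m && !cpu)   -> ThreadK, m_groups = 2
//   else                               -> ThreadNM tile kernel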
); + + size_t n_blocks = (n + delta_n - 1) / delta_n; + size_t m_blocks = (m + m_groups - 1) / m_groups; + size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); + + size_t lws = delta_n * delta_k; + + auto gRange = sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); + auto lRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gRange, lRange); + + using LocAccT = sycl::local_accessor, 1>; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + + using KernelName = class gemm_k_krn; + cgh.parallel_for( + ndRange, GemmFunctorThreadK( + lhs_tp, rhs_tp, res_tp, workspace, local_B_block, + n, n_blocks, delta_n, k, k_blocks, delta_k, n_wi, + m, lhs_indexer, rhs_indexer, res_indexer)); } else { constexpr int wi_delta_n = 2; @@ -1216,7 +1212,7 @@ sycl::event gemm_contig_impl(sycl::queue &exec_q, OuterInnerIndexerT rhs_indexer{}; OuterInnerIndexerT res_indexer{}; - if ((k > n && k > m) || m == 1) { + if (m == 1) { constexpr size_t m_groups = 1; size_t delta_k(4); size_t n_wi(4); @@ -1252,44 +1248,42 @@ sycl::event gemm_contig_impl(sycl::queue &exec_q, n, n_blocks, delta_n, k, k_blocks, delta_k, n_wi, m, lhs_indexer, rhs_indexer, res_indexer)); } - // else if (k > n && k > m) { - // constexpr size_t m_groups = 2; - // size_t delta_k(4); - // size_t n_wi(4); - // size_t delta_n(4); - - // gemm_detail::scale_gemm_k_parameters( - // local_mem_size, reserved_slm_size, delta_k, - // n_wi, // modified by reference - // delta_n // modified by reference - // ); - - // size_t n_blocks = (n + delta_n - 1) / delta_n; - // size_t m_blocks = (m + m_groups - 1) / m_groups; - // size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); - - // size_t lws = delta_n * delta_k; - - // auto gRange = sycl::range<1>(n_blocks * m_blocks * k_blocks * - // lws); auto lRange = sycl::range<1>(lws); - - // auto ndRange = sycl::nd_range<1>(gRange, lRange); - - // using LocAccT = sycl::local_accessor, - // 1>; LocAccT local_B_block(n_wi * delta_k, cgh); LocAccT - // workspace(delta_n * delta_k, cgh); - - // using KernelName = class gemm_k_krn; - // cgh.parallel_for( - // ndRange, GemmFunctorThreadK( - // lhs_tp, rhs_tp, res_tp, workspace, - // local_B_block, n, n_blocks, delta_n, k, - // k_blocks, delta_k, n_wi, m, lhs_indexer, - // rhs_indexer, res_indexer)); - // } + else if (k > n && k > m && !exec_q.get_device().is_cpu()) { + constexpr size_t m_groups = 2; + size_t delta_k(4); + size_t n_wi(4); + size_t delta_n(4); + + gemm_detail::scale_gemm_k_parameters( + local_mem_size, reserved_slm_size, delta_k, + n_wi, // modified by reference + delta_n // modified by reference + ); + + size_t n_blocks = (n + delta_n - 1) / delta_n; + size_t m_blocks = (m + m_groups - 1) / m_groups; + size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); + + size_t lws = delta_n * delta_k; + + auto gRange = sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); + auto lRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gRange, lRange); + + using LocAccT = sycl::local_accessor, 1>; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + + using KernelName = class gemm_k_krn; + cgh.parallel_for( + ndRange, GemmFunctorThreadK( + lhs_tp, rhs_tp, res_tp, workspace, local_B_block, + n, n_blocks, delta_n, k, k_blocks, delta_k, n_wi, + m, lhs_indexer, rhs_indexer, res_indexer)); + } else { constexpr int wi_delta_n = 2; constexpr int wi_delta_m = 4; @@ -2650,7 +2644,7 @@ sycl::event gemm_tree_impl(sycl::queue &exec_q, return 
gemm_no_reduction_ev; } - if ((k > n && k > m) || m == 1) { + if (exec_q.get_device().is_cpu()) { using dpctl::tensor::type_utils::is_complex; if constexpr (!is_complex::value) { if (m == 1) { @@ -2661,7 +2655,7 @@ sycl::event gemm_tree_impl(sycl::queue &exec_q, depends); } else { - return gemm_tree_k_impl( + return gemm_tree_nm_impl( exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, inner_nd, lhs_outer_nd, lhs_outer_inner_shapes_strides, rhs_outer_nd, rhs_outer_inner_shapes_strides, res_nd, res_shapes_strides, @@ -2669,28 +2663,56 @@ sycl::event gemm_tree_impl(sycl::queue &exec_q, } } else { - return gemm_tree_k_impl( + return gemm_tree_nm_impl( exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, inner_nd, lhs_outer_nd, lhs_outer_inner_shapes_strides, rhs_outer_nd, rhs_outer_inner_shapes_strides, res_nd, res_shapes_strides, depends); } } - else { // m > 1, n > k or m > k - using dpctl::tensor::type_utils::is_complex; - if constexpr (!is_complex::value) { - return gemm_tree_nm_impl( - exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, inner_nd, lhs_outer_nd, - lhs_outer_inner_shapes_strides, rhs_outer_nd, - rhs_outer_inner_shapes_strides, res_nd, res_shapes_strides, - depends); + else { + if ((k > n && k > m) || m == 1) { + using dpctl::tensor::type_utils::is_complex; + if constexpr (!is_complex::value) { + if (m == 1) { + return gemm_tree_k_impl( + exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, inner_nd, + lhs_outer_nd, lhs_outer_inner_shapes_strides, + rhs_outer_nd, rhs_outer_inner_shapes_strides, res_nd, + res_shapes_strides, depends); + } + else { + return gemm_tree_k_impl( + exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, inner_nd, + lhs_outer_nd, lhs_outer_inner_shapes_strides, + rhs_outer_nd, rhs_outer_inner_shapes_strides, res_nd, + res_shapes_strides, depends); + } + } + else { + return gemm_tree_k_impl( + exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, inner_nd, + lhs_outer_nd, lhs_outer_inner_shapes_strides, rhs_outer_nd, + rhs_outer_inner_shapes_strides, res_nd, res_shapes_strides, + depends); + } } - else { - return gemm_tree_nm_impl( - exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, inner_nd, lhs_outer_nd, - lhs_outer_inner_shapes_strides, rhs_outer_nd, - rhs_outer_inner_shapes_strides, res_nd, res_shapes_strides, - depends); + else { // m > 1, n > k or m > k + using dpctl::tensor::type_utils::is_complex; + if constexpr (!is_complex::value) { + return gemm_tree_nm_impl( + exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, inner_nd, + lhs_outer_nd, lhs_outer_inner_shapes_strides, rhs_outer_nd, + rhs_outer_inner_shapes_strides, res_nd, res_shapes_strides, + depends); + } + else { + return gemm_tree_nm_impl( + exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, inner_nd, + lhs_outer_nd, lhs_outer_inner_shapes_strides, rhs_outer_nd, + rhs_outer_inner_shapes_strides, res_nd, res_shapes_strides, + depends); + } } } } @@ -3377,7 +3399,7 @@ sycl::event gemm_contig_tree_impl(sycl::queue &exec_q, return gemm_no_reduction_ev; } - if ((k > n && k > m) || m == 1) { + if (exec_q.get_device().is_cpu()) { using dpctl::tensor::type_utils::is_complex; if constexpr (!is_complex::value) { if (m == 1) { @@ -3385,24 +3407,43 @@ sycl::event gemm_contig_tree_impl(sycl::queue &exec_q, exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, depends); } else { - return gemm_contig_tree_k_impl( + return gemm_contig_tree_nm_impl( exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, depends); } } else { - return gemm_contig_tree_k_impl( + return gemm_contig_tree_nm_impl( exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, depends); } } - else { // m > 1, n > k or m > k - using dpctl::tensor::type_utils::is_complex; - 
if constexpr (!is_complex::value) { - return gemm_contig_tree_nm_impl( - exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, depends); + else { + if ((k > n && k > m) || m == 1) { + using dpctl::tensor::type_utils::is_complex; + if constexpr (!is_complex::value) { + if (m == 1) { + return gemm_contig_tree_k_impl( + exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, depends); + } + else { + return gemm_contig_tree_k_impl( + exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, depends); + } + } + else { + return gemm_contig_tree_k_impl( + exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, depends); + } } - else { - return gemm_contig_tree_nm_impl( - exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, depends); + else { // m > 1, n > k or m > k + using dpctl::tensor::type_utils::is_complex; + if constexpr (!is_complex::value) { + return gemm_contig_tree_nm_impl( + exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, depends); + } + else { + return gemm_contig_tree_nm_impl( + exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, depends); + } } } } @@ -4221,7 +4262,7 @@ sycl::event gemm_batch_impl(sycl::queue &exec_q, m, batch_nelems, batch_indexer, lhs_indexer, rhs_indexer, res_indexer)); } - else if (k > n && k > m) { + else if (k > n && k > m && !exec_q.get_device().is_cpu()) { constexpr size_t m_groups = 2; size_t delta_k(4); size_t n_wi(4); @@ -4427,7 +4468,7 @@ sycl::event gemm_batch_contig_impl(sycl::queue &exec_q, m, batch_nelems, batch_indexer, lhs_indexer, rhs_indexer, res_indexer)); } - else if (k > n && k > m) { + else if (k > n && k > m && !exec_q.get_device().is_cpu()) { constexpr size_t m_groups = 2; size_t delta_k(4); size_t n_wi(4); @@ -6023,7 +6064,7 @@ gemm_batch_tree_impl(sycl::queue &exec_q, return gemm_batch_no_reduction_ev; } - if ((k > n && k > m) || m == 1) { + if (exec_q.get_device().is_cpu()) { using dpctl::tensor::type_utils::is_complex; if constexpr (!is_complex::value) { if (m == 1) { @@ -6036,7 +6077,7 @@ gemm_batch_tree_impl(sycl::queue &exec_q, res_outer_shapes_strides, res_shape_strides, depends); } else { - return gemm_batch_tree_k_impl( + return gemm_batch_tree_nm_impl( exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, batch_nd, batch_shape_strides, lhs_batch_offset, rhs_batch_offset, res_batch_offset, inner_nd, lhs_outer_nd, @@ -6046,7 +6087,7 @@ gemm_batch_tree_impl(sycl::queue &exec_q, } } else { - return gemm_batch_tree_k_impl( + return gemm_batch_tree_nm_impl( exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, batch_nd, batch_shape_strides, lhs_batch_offset, rhs_batch_offset, res_batch_offset, inner_nd, lhs_outer_nd, @@ -6055,25 +6096,61 @@ gemm_batch_tree_impl(sycl::queue &exec_q, res_outer_shapes_strides, res_shape_strides, depends); } } - else { // m > 1, n > k or m > k - using dpctl::tensor::type_utils::is_complex; - if constexpr (!is_complex::value) { - return gemm_batch_tree_nm_impl( - exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, batch_nd, - batch_shape_strides, lhs_batch_offset, rhs_batch_offset, - res_batch_offset, inner_nd, lhs_outer_nd, - lhs_outer_inner_shapes_strides, rhs_outer_nd, - rhs_outer_inner_shapes_strides, res_outer_nd, - res_outer_shapes_strides, res_shape_strides, depends); + else { + if ((k > n && k > m) || m == 1) { + using dpctl::tensor::type_utils::is_complex; + if constexpr (!is_complex::value) { + if (m == 1) { + return gemm_batch_tree_k_impl( + exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, + batch_nd, batch_shape_strides, lhs_batch_offset, + rhs_batch_offset, res_batch_offset, inner_nd, + lhs_outer_nd, lhs_outer_inner_shapes_strides, + rhs_outer_nd, 
rhs_outer_inner_shapes_strides, + res_outer_nd, res_outer_shapes_strides, + res_shape_strides, depends); + } + else { + return gemm_batch_tree_k_impl( + exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, + batch_nd, batch_shape_strides, lhs_batch_offset, + rhs_batch_offset, res_batch_offset, inner_nd, + lhs_outer_nd, lhs_outer_inner_shapes_strides, + rhs_outer_nd, rhs_outer_inner_shapes_strides, + res_outer_nd, res_outer_shapes_strides, + res_shape_strides, depends); + } + } + else { + return gemm_batch_tree_k_impl( + exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, + batch_nd, batch_shape_strides, lhs_batch_offset, + rhs_batch_offset, res_batch_offset, inner_nd, lhs_outer_nd, + lhs_outer_inner_shapes_strides, rhs_outer_nd, + rhs_outer_inner_shapes_strides, res_outer_nd, + res_outer_shapes_strides, res_shape_strides, depends); + } } - else { // m > 1, n > k or m > k, resTy complex - return gemm_batch_tree_nm_impl( - exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, batch_nd, - batch_shape_strides, lhs_batch_offset, rhs_batch_offset, - res_batch_offset, inner_nd, lhs_outer_nd, - lhs_outer_inner_shapes_strides, rhs_outer_nd, - rhs_outer_inner_shapes_strides, res_outer_nd, - res_outer_shapes_strides, res_shape_strides, depends); + else { // m > 1, n > k or m > k + using dpctl::tensor::type_utils::is_complex; + if constexpr (!is_complex::value) { + return gemm_batch_tree_nm_impl( + exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, + batch_nd, batch_shape_strides, lhs_batch_offset, + rhs_batch_offset, res_batch_offset, inner_nd, lhs_outer_nd, + lhs_outer_inner_shapes_strides, rhs_outer_nd, + rhs_outer_inner_shapes_strides, res_outer_nd, + res_outer_shapes_strides, res_shape_strides, depends); + } + else { // m > 1, n > k or m > k, resTy complex + return gemm_batch_tree_nm_impl( + exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, + batch_nd, batch_shape_strides, lhs_batch_offset, + rhs_batch_offset, res_batch_offset, inner_nd, lhs_outer_nd, + lhs_outer_inner_shapes_strides, rhs_outer_nd, + rhs_outer_inner_shapes_strides, res_outer_nd, + res_outer_shapes_strides, res_shape_strides, depends); + } } } } @@ -6861,7 +6938,7 @@ gemm_batch_contig_tree_impl(sycl::queue &exec_q, return gemm_batch_no_reduction_ev; } - if ((k > n && k > m) || m == 1) { + if (exec_q.get_device().is_cpu()) { using dpctl::tensor::type_utils::is_complex; if constexpr (!is_complex::value) { if (m == 1) { @@ -6870,25 +6947,51 @@ gemm_batch_contig_tree_impl(sycl::queue &exec_q, depends); } else { - return gemm_batch_contig_tree_k_impl( + return gemm_batch_contig_tree_nm_impl( exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, depends); } } else { - return gemm_batch_contig_tree_k_impl( + return gemm_batch_contig_tree_nm_impl( exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, depends); } } - else { // m > 1, n > k or m > k - using dpctl::tensor::type_utils::is_complex; - if constexpr (!is_complex::value) { - return gemm_batch_contig_tree_nm_impl( - exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, depends); + else { + if ((k > n && k > m) || m == 1) { + using dpctl::tensor::type_utils::is_complex; + if constexpr (!is_complex::value) { + if (m == 1) { + return gemm_batch_contig_tree_k_impl( + exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, + depends); + } + else { + return gemm_batch_contig_tree_k_impl( + exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, + depends); + } + } + else { + return gemm_batch_contig_tree_k_impl( + exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, + 
depends); + } } - else { // m > 1, n > k or m > k, resTy complex - return gemm_batch_contig_tree_nm_impl( - exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, depends); + else { // m > 1, n > k or m > k + using dpctl::tensor::type_utils::is_complex; + if constexpr (!is_complex::value) { + return gemm_batch_contig_tree_nm_impl( + exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, + depends); + } + else { // m > 1, n > k or m > k, resTy complex + return gemm_batch_contig_tree_nm_impl( + exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, + depends); + } } } } From 00ec8e6c495dcbae8e82dbfd61f3e27bda870350 Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Mon, 15 Jan 2024 11:46:02 -0600 Subject: [PATCH 23/48] Also access memory if indices are in range This prevents out-of-bound access that was responsible for crashes observed in CI. --- .../include/kernels/linalg_functions/gemm.hpp | 92 ++++++++++--------- 1 file changed, 48 insertions(+), 44 deletions(-) diff --git a/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp b/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp index 070fc000a1..0c56fdb766 100644 --- a/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp +++ b/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp @@ -804,12 +804,13 @@ class GemmFunctorThreadK size_t global_s_offset = i * k + t_shift; sycl::vec private_sum(identity_); + constexpr sycl::vec vec_identity_(identity_); for (size_t t = local_s; t < local_B_block.size(); t += delta_k) { - if (t + t_shift < k) { - private_sum += - (static_cast(lhs[lhs_indexer(global_s_offset + t)]) * - local_B_block[t]); - } + private_sum += ((i < n) && (t + t_shift < k)) + ? (static_cast( + lhs[lhs_indexer(global_s_offset + t)]) * + local_B_block[t]) + : vec_identity_; } size_t workspace_i_shift = local_i * delta_k; @@ -936,11 +937,11 @@ class GemmFunctorThreadK resT private_sum(identity_); for (size_t t = local_s; t < local_B_block.size(); t += delta_k) { - if (t + t_shift < k) { - private_sum += - (static_cast(lhs[lhs_indexer(global_s_offset + t)]) * - local_B_block[t]); - } + private_sum += ((i < n) && (t + t_shift < k)) + ? (static_cast( + lhs[lhs_indexer(global_s_offset + t)]) * + local_B_block[t]) + : identity_; } size_t workspace_i_shift = local_i * delta_k; @@ -1743,12 +1744,13 @@ class GemmNoAtomicFunctorThreadK size_t global_s_offset = i * k + t_shift; sycl::vec private_sum(identity_); + constexpr sycl::vec vec_identity_(identity_); for (size_t t = local_s; t < local_B_block.size(); t += delta_k) { - if (t + t_shift < k) { - private_sum += - (static_cast(lhs[lhs_indexer(global_s_offset + t)]) * - local_B_block[t]); - } + private_sum += ((i < n) && (t + t_shift < k)) + ? (static_cast( + lhs[lhs_indexer(global_s_offset + t)]) * + local_B_block[t]) + : vec_identity_; } size_t workspace_i_shift = local_i * delta_k; @@ -1872,11 +1874,11 @@ class GemmNoAtomicFunctorThreadK(lhs[lhs_indexer(global_s_offset + t)]) * - local_B_block[t]); - } + private_sum += ((i < n) && (t + t_shift < k)) + ? 
(static_cast( + lhs[lhs_indexer(global_s_offset + t)]) * + local_B_block[t]) + : identity_; } size_t workspace_i_shift = local_i * delta_k; @@ -3923,13 +3925,14 @@ class GemmBatchFunctorThreadK size_t global_s_offset = i * k + t_shift; sycl::vec private_sum(identity_); + constexpr sycl::vec vec_identity_(identity_); for (size_t t = local_s; t < local_B_block.size(); t += delta_k) { - if (t + t_shift < k) { - private_sum += - (static_cast( - lhs[lhs_offset + lhs_indexer(global_s_offset + t)]) * - local_B_block[t]); - } + private_sum += + ((i < n) && (t + t_shift < k)) + ? (static_cast( + lhs[lhs_offset + lhs_indexer(global_s_offset + t)]) * + local_B_block[t]) + : vec_identity_; } size_t workspace_i_shift = local_i * delta_k; @@ -4083,12 +4086,12 @@ class GemmBatchFunctorThreadK( - lhs[lhs_offset + lhs_indexer(global_s_offset + t)]) * - local_B_block[t]); - } + private_sum += + ((i < n) && (t + t_shift < k)) + ? (static_cast( + lhs[lhs_offset + lhs_indexer(global_s_offset + t)]) * + local_B_block[t]) + : identity_; } size_t workspace_i_shift = local_i * delta_k; @@ -5024,13 +5027,14 @@ class GemmBatchNoAtomicFunctorThreadK size_t global_s_offset = i * k + t_shift; sycl::vec private_sum(identity_); + constexpr sycl::vec vec_identity_(identity_); for (size_t t = local_s; t < local_B_block.size(); t += delta_k) { - if (t + t_shift < k) { - private_sum += - (static_cast( - lhs[lhs_offset + lhs_indexer(global_s_offset + t)]) * - local_B_block[t]); - } + private_sum += + ((i < n) && (t + t_shift < k)) + ? (static_cast( + lhs[lhs_offset + lhs_indexer(global_s_offset + t)]) * + local_B_block[t]) + : vec_identity_; } size_t workspace_i_shift = local_i * delta_k; @@ -5171,12 +5175,12 @@ class GemmBatchNoAtomicFunctorThreadK( - lhs[lhs_offset + lhs_indexer(global_s_offset + t)]) * - local_B_block[t]); - } + private_sum += + ((i < n) && ((t + t_shift < k))) + ? (static_cast( + lhs[lhs_offset + lhs_indexer(global_s_offset + t)]) * + local_B_block[t]) + : identity_; } size_t workspace_i_shift = local_i * delta_k; From d97a9c2bd2c6c8220a133c40ba3b2fa73b2ee5a6 Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Mon, 15 Jan 2024 12:31:31 -0600 Subject: [PATCH 24/48] Simplified computation of m_id/gr_id in kernels No need to use both it.get_global_linear_id() and it.get_group_linear_id() to compute batch id and group id. 
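For reference, a minimal standalone sketch of the decomposition now used in the
batch kernels. The names BatchGroupIds and decompose_group_id and the small
driver are illustrative only and are not part of this patch; inside the kernels
the three inputs come from it.get_group_linear_id(), it.get_group_range(0) and
batch_nelems, and the group range is an exact multiple of batch_nelems by
construction of the launch configuration:

    #include <cassert>
    #include <cstddef>

    struct BatchGroupIds
    {
        std::size_t m_id;  // which batch this work-group handles
        std::size_t gr_id; // id of the work-group within that batch
    };

    // Mirrors the arithmetic in the patched operator()(): one division
    // plus one multiply/subtract on the group linear id, instead of an
    // additional division on the global linear id.
    BatchGroupIds decompose_group_id(std::size_t group_linear_id,
                                     std::size_t group_range,
                                     std::size_t batch_nelems)
    {
        const std::size_t n_groups_per_batch = group_range / batch_nelems;
        const std::size_t m_id = group_linear_id / n_groups_per_batch;
        const std::size_t gr_id =
            group_linear_id - m_id * n_groups_per_batch;
        return {m_id, gr_id};
    }

    int main()
    {
        // 3 batches, 4 work-groups per batch -> 12 work-groups in total;
        // work-group 7 belongs to batch 1 and is group 3 within it.
        const BatchGroupIds ids = decompose_group_id(7, 12, 3);
        assert(ids.m_id == 1 && ids.gr_id == 3);
        return 0;
    }

The previous code derived m_id from it.get_global_linear_id() divided by the
per-batch share of the global range and gr_id from a modulo on the group id;
computing both quantities from the group linear id alone yields the same result
with less index arithmetic.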
--- .../include/kernels/linalg_functions/gemm.hpp | 64 +++++++++---------- 1 file changed, 32 insertions(+), 32 deletions(-) diff --git a/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp b/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp index 0c56fdb766..b4401ffc4e 100644 --- a/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp +++ b/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp @@ -3512,10 +3512,10 @@ class GemmBatchFunctorThreadNM void operator()(sycl::nd_item<1> it) const { - size_t m_id = - it.get_global_linear_id() / (it.get_global_range(0) / batch_nelems); - size_t gr_id = - it.get_group_linear_id() % (it.get_group_range(0) / batch_nelems); + const size_t n_groups_per_batch = it.get_group_range(0) / batch_nelems; + const size_t m_id = it.get_group_linear_id() / n_groups_per_batch; + const size_t gr_id = + it.get_group_linear_id() - m_id * n_groups_per_batch; const auto &three_offsets_ = batch_indexer(static_cast(m_id)); @@ -3700,10 +3700,10 @@ class GemmBatchFunctorThreadNM it) const { - size_t m_id = - it.get_global_linear_id() / (it.get_global_range(0) / batch_nelems); - size_t gr_id = - it.get_group_linear_id() % (it.get_group_range(0) / batch_nelems); + const size_t n_groups_per_batch = it.get_group_range(0) / batch_nelems; + const size_t m_id = it.get_group_linear_id() / n_groups_per_batch; + const size_t gr_id = + it.get_group_linear_id() - m_id * n_groups_per_batch; const auto &three_offsets_ = batch_indexer(static_cast(m_id)); @@ -3871,10 +3871,10 @@ class GemmBatchFunctorThreadK // batch_nelems) for lhs, offset = m_id * (n * k) for rhs, offset = // m_id // * (k * m) for res, offset = m_id * (n * m) - size_t m_id = - it.get_global_linear_id() / (it.get_global_range(0) / batch_nelems); - size_t gr_id = - it.get_group_linear_id() % (it.get_group_range(0) / batch_nelems); + const size_t n_groups_per_batch = it.get_group_range(0) / batch_nelems; + const size_t m_id = it.get_group_linear_id() / n_groups_per_batch; + const size_t gr_id = + it.get_group_linear_id() - m_id * n_groups_per_batch; size_t lid = it.get_local_linear_id(); const auto &three_offsets_ = @@ -4034,10 +4034,10 @@ class GemmBatchFunctorThreadK it) const { - size_t m_id = - it.get_global_linear_id() / (it.get_global_range(0) / batch_nelems); - size_t gr_id = - it.get_group_linear_id() % (it.get_group_range(0) / batch_nelems); + const size_t n_groups_per_batch = it.get_group_range(0) / batch_nelems; + const size_t m_id = it.get_group_linear_id() / n_groups_per_batch; + const size_t gr_id = + it.get_group_linear_id() - m_id * n_groups_per_batch; const auto &three_offsets_ = batch_indexer(static_cast(m_id)); @@ -4816,10 +4816,10 @@ class GemmBatchNoAtomicFunctorThreadNM it) const { - size_t m_id = - it.get_global_linear_id() / (it.get_global_range(0) / batch_nelems); - size_t gr_id = - it.get_group_linear_id() % (it.get_group_range(0) / batch_nelems); + const size_t n_groups_per_batch = it.get_group_range(0) / batch_nelems; + const size_t m_id = it.get_group_linear_id() / n_groups_per_batch; + const size_t gr_id = + it.get_group_linear_id() - m_id * n_groups_per_batch; const auto &three_offsets_ = batch_indexer(static_cast(m_id)); @@ -4978,10 +4978,10 @@ class GemmBatchNoAtomicFunctorThreadK void operator()(sycl::nd_item<1> it) const { - size_t m_id = - it.get_global_linear_id() / (it.get_global_range(0) / batch_nelems); - size_t gr_id = - it.get_group_linear_id() % (it.get_group_range(0) / batch_nelems); + const size_t n_groups_per_batch = 
it.get_group_range(0) / batch_nelems; + const size_t m_id = it.get_group_linear_id() / n_groups_per_batch; + const size_t gr_id = + it.get_group_linear_id() - m_id * n_groups_per_batch; size_t lid = it.get_local_linear_id(); const auto &three_offsets_ = @@ -5125,10 +5125,10 @@ class GemmBatchNoAtomicFunctorThreadK it) const { - size_t m_id = - it.get_global_linear_id() / (it.get_global_range(0) / batch_nelems); - size_t gr_id = - it.get_group_linear_id() % (it.get_group_range(0) / batch_nelems); + const size_t n_groups_per_batch = it.get_group_range(0) / batch_nelems; + const size_t m_id = it.get_group_linear_id() / n_groups_per_batch; + const size_t gr_id = + it.get_group_linear_id() - m_id * n_groups_per_batch; size_t lid = it.get_local_linear_id(); const auto &three_offsets_ = From 1dc2541ae4bc4030fdbed568050ef3847d9864bb Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Mon, 15 Jan 2024 13:50:16 -0600 Subject: [PATCH 25/48] Change generic kernels to work for any value of m_groups, not just m_groups=2 --- .../include/kernels/linalg_functions/gemm.hpp | 57 ++++++++++++------- 1 file changed, 36 insertions(+), 21 deletions(-) diff --git a/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp b/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp index b4401ffc4e..25345c2349 100644 --- a/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp +++ b/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp @@ -831,13 +831,16 @@ class GemmFunctorThreadK aout0 += local_sum[0]; - if (j + 1 < m) { - sycl::atomic_ref - aout1(res[res_indexer(i * m + j + 1)]); +#pragma unroll + for (size_t vec_id = 1; vec_id < m_groups; ++vec_id) { + if (j + vec_id < m) { + sycl::atomic_ref + aout1(res[res_indexer(i * m + j + vec_id)]); - aout1 += local_sum[1]; + aout1 += local_sum[vec_id]; + } } } } @@ -1764,11 +1767,15 @@ class GemmNoAtomicFunctorThreadK local_sum += workspace[workspace_i_shift + t]; } - res[res_indexer(i * m + j) + (block_s * n * m)] = local_sum[0]; + const size_t res_offset = (block_s * n * m); + res[res_indexer(i * m + j) + res_offset] = local_sum[0]; - if (j + 1 < m) { - res[res_indexer(i * m + j + 1) + (block_s * n * m)] = - local_sum[1]; +#pragma unroll + for (size_t vec_id = 1; vec_id < m_groups; ++vec_id) { + if (j + vec_id < m) { + res[res_indexer(i * m + j + vec_id) + res_offset] = + local_sum[vec_id]; + } } } } @@ -3953,13 +3960,17 @@ class GemmBatchFunctorThreadK aout0 += local_sum[0]; - if (j + 1 < m) { - sycl::atomic_ref - aout1(res[res_offset + res_indexer(i * m + j + 1)]); +#pragma unroll + for (size_t vec_id = 1; vec_id < m_groups; ++vec_id) { + if (j + vec_id < m) { + sycl::atomic_ref + aout1( + res[res_offset + res_indexer(i * m + j + vec_id)]); - aout1 += local_sum[1]; + aout1 += local_sum[vec_id]; + } } } } @@ -5048,12 +5059,16 @@ class GemmBatchNoAtomicFunctorThreadK local_sum += workspace[workspace_i_shift + t]; } - res[res_offset + res_indexer(i * m + j) + - (block_s * n * m * batch_nelems)] = local_sum[0]; + const size_t total_offset = + res_offset + (block_s * n * m * batch_nelems); + res[total_offset + res_indexer(i * m + j)] = local_sum[0]; - if (j + 1 < m) { - res[res_offset + res_indexer(i * m + j + 1) + - (block_s * n * m * batch_nelems)] = local_sum[1]; +#pragma unroll + for (size_t vec_id = 1; vec_id < m_groups; ++vec_id) { + if (j + vec_id < m) { + res[total_offset + res_indexer(i * m + j + vec_id)] = + local_sum[1]; + } } } } From d930b2ee22314a8b4544731d73fec6dee58f6112 Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk 
Date: Mon, 15 Jan 2024 13:51:02 -0600 Subject: [PATCH 26/48] Remove work-arounds/special-casing for CPUs --- .../include/kernels/linalg_functions/gemm.hpp | 62 +++++-------------- 1 file changed, 17 insertions(+), 45 deletions(-) diff --git a/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp b/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp index 25345c2349..f7cca9b3f8 100644 --- a/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp +++ b/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp @@ -2653,7 +2653,7 @@ sycl::event gemm_tree_impl(sycl::queue &exec_q, return gemm_no_reduction_ev; } - if (exec_q.get_device().is_cpu()) { + if ((k > n && k > m) || m == 1) { using dpctl::tensor::type_utils::is_complex; if constexpr (!is_complex::value) { if (m == 1) { @@ -2664,7 +2664,7 @@ sycl::event gemm_tree_impl(sycl::queue &exec_q, depends); } else { - return gemm_tree_nm_impl( + return gemm_tree_k_impl( exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, inner_nd, lhs_outer_nd, lhs_outer_inner_shapes_strides, rhs_outer_nd, rhs_outer_inner_shapes_strides, res_nd, res_shapes_strides, @@ -2672,56 +2672,28 @@ sycl::event gemm_tree_impl(sycl::queue &exec_q, } } else { - return gemm_tree_nm_impl( + return gemm_tree_k_impl( exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, inner_nd, lhs_outer_nd, lhs_outer_inner_shapes_strides, rhs_outer_nd, rhs_outer_inner_shapes_strides, res_nd, res_shapes_strides, depends); } } - else { - if ((k > n && k > m) || m == 1) { - using dpctl::tensor::type_utils::is_complex; - if constexpr (!is_complex::value) { - if (m == 1) { - return gemm_tree_k_impl( - exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, inner_nd, - lhs_outer_nd, lhs_outer_inner_shapes_strides, - rhs_outer_nd, rhs_outer_inner_shapes_strides, res_nd, - res_shapes_strides, depends); - } - else { - return gemm_tree_k_impl( - exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, inner_nd, - lhs_outer_nd, lhs_outer_inner_shapes_strides, - rhs_outer_nd, rhs_outer_inner_shapes_strides, res_nd, - res_shapes_strides, depends); - } - } - else { - return gemm_tree_k_impl( - exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, inner_nd, - lhs_outer_nd, lhs_outer_inner_shapes_strides, rhs_outer_nd, - rhs_outer_inner_shapes_strides, res_nd, res_shapes_strides, - depends); - } + else { // m > 1, n > k or m > k + using dpctl::tensor::type_utils::is_complex; + if constexpr (!is_complex::value) { + return gemm_tree_nm_impl( + exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, inner_nd, lhs_outer_nd, + lhs_outer_inner_shapes_strides, rhs_outer_nd, + rhs_outer_inner_shapes_strides, res_nd, res_shapes_strides, + depends); } - else { // m > 1, n > k or m > k - using dpctl::tensor::type_utils::is_complex; - if constexpr (!is_complex::value) { - return gemm_tree_nm_impl( - exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, inner_nd, - lhs_outer_nd, lhs_outer_inner_shapes_strides, rhs_outer_nd, - rhs_outer_inner_shapes_strides, res_nd, res_shapes_strides, - depends); - } - else { - return gemm_tree_nm_impl( - exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, inner_nd, - lhs_outer_nd, lhs_outer_inner_shapes_strides, rhs_outer_nd, - rhs_outer_inner_shapes_strides, res_nd, res_shapes_strides, - depends); - } + else { + return gemm_tree_nm_impl( + exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, inner_nd, lhs_outer_nd, + lhs_outer_inner_shapes_strides, rhs_outer_nd, + rhs_outer_inner_shapes_strides, res_nd, res_shapes_strides, + depends); } } } From 7a277cbcd8230213dcc9ed46a00cc575b8283002 Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Mon, 15 Jan 2024 14:33:41 
-0600 Subject: [PATCH 27/48] Extended test_matmul_strided, reverted work-arounds --- dpctl/tests/test_usm_ndarray_linalg.py | 57 ++++++++++++-------------- 1 file changed, 27 insertions(+), 30 deletions(-) diff --git a/dpctl/tests/test_usm_ndarray_linalg.py b/dpctl/tests/test_usm_ndarray_linalg.py index 70ab1092a9..c65cab2d93 100644 --- a/dpctl/tests/test_usm_ndarray_linalg.py +++ b/dpctl/tests/test_usm_ndarray_linalg.py @@ -19,7 +19,6 @@ import numpy as np import pytest -import dpctl import dpctl.tensor as dpt from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported @@ -68,38 +67,18 @@ def test_matrix_transpose_arg_validation(): assert isinstance(dpt.matrix_transpose(X), dpt.usm_ndarray) -# @pytest.mark.parametrize("dtype", _numeric_types) -# def test_matmul_simple(dtype): -# q = get_queue_or_skip() -# skip_if_dtype_not_supported(dtype, q) - -# n, m = 235, 17 -# m1 = dpt.ones((m, n), dtype=dtype) -# m2 = dpt.ones((n, m), dtype=dtype) - -# for k in [1, 2, 3, 4, 7, 8, 9, 15, 16, 17]: -# r = dpt.matmul(m1[:k, :], m2[:, :k]) -# assert dpt.all(r == dpt.full((k, k), n, dtype=dtype)) - - -@pytest.mark.parametrize("dtype", _numeric_types[::-1]) -def test_matmul_simple2(dtype): +@pytest.mark.parametrize("dtype", _numeric_types) +def test_matmul_simple(dtype): q = get_queue_or_skip() skip_if_dtype_not_supported(dtype, q) - dev = q.sycl_device - if dev.is_cpu: - cpu_count = dev.max_compute_units - sub_devs = dev.create_sub_devices(partition=min(2, cpu_count // 2)) - ctx = dpctl.SyclContext(sub_devs[0]) - q = dpctl.SyclQueue(ctx, sub_devs[0]) n, m = 235, 17 - m1 = dpt.ones((m, n), dtype=dtype, sycl_queue=q) - m2 = dpt.ones((n, m), dtype=dtype, sycl_queue=q) + m1 = dpt.ones((m, n), dtype=dtype) + m2 = dpt.ones((n, m), dtype=dtype) for k in [1, 2, 3, 4, 7, 8, 9, 15, 16, 17]: r = dpt.matmul(m1[:k, :], m2[:, :k]) - assert dpt.all(r == dpt.full((k, k), n, dtype=dtype, sycl_queue=q)) + assert dpt.all(r == dpt.full((k, k), n, dtype=dtype)) @pytest.mark.parametrize("dtype", _numeric_types) @@ -261,7 +240,7 @@ def test_matmul_broadcasting(): assert r.shape == (7, 11, 13) -@pytest.mark.parametrize("dtype", ["i4", "i8", "f4", "c8"][::-1]) +@pytest.mark.parametrize("dtype", ["i4", "i8", "f4", "c8"]) def test_matmul_strided(dtype): get_queue_or_skip() @@ -271,12 +250,30 @@ def test_matmul_strided(dtype): m1_size = m1_size * el m1 = dpt.remainder(dpt.arange(1, m1_size + 1, dtype="i8"), 13) - m1 = dpt.reshape(dpt.astype(m1, dtype), (14, 22, 32))[::2, ::-2, ::2] - m2 = dpt.ones((14, 16, 13), dtype=dtype)[::2, :, :] + m1_orig = dpt.reshape(dpt.astype(m1, dtype), m1_shape) + m2_orig = dpt.ones((14, 16, 13), dtype=dtype) + m1 = m1_orig[::2, ::-2, ::2] + m2 = m2_orig[::2, :, :] r = dpt.matmul(m1, m2) - assert r.shape == (7, 11, 13) + assert r.shape == m1.shape[:2] + m2.shape[-1:] + ref = np.matmul(dpt.asnumpy(m1), dpt.asnumpy(m2)) + assert np.allclose(dpt.asnumpy(r), ref) + + m1 = m1_orig[::2, ::2, ::-2] + m2 = m2_orig[::2, :, :] + r = dpt.matmul(m1, m2) + + assert r.shape == m1.shape[:2] + m2.shape[-1:] + ref = np.matmul(dpt.asnumpy(m1), dpt.asnumpy(m2)) + assert np.allclose(dpt.asnumpy(r), ref) + + m1 = m1_orig[::-2, ::2, ::2] + m2 = m2_orig[::-2, :, :] + r = dpt.matmul(m1, m2) + + assert r.shape == m1.shape[:2] + m2.shape[-1:] ref = np.matmul(dpt.asnumpy(m1), dpt.asnumpy(m2)) assert np.allclose(dpt.asnumpy(r), ref) From 686276a160eda0e95f5d5fd558b1327c541f87d2 Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Mon, 15 Jan 2024 16:03:46 -0800 Subject: [PATCH 28/48] Revert remaining gemm 
work-arounds This commit removes remaining checks for if a kernel is called on CPU as well as reverting hyperparameters for gemm kernels to their original values --- .../include/kernels/linalg_functions/gemm.hpp | 175 +++++------------- 1 file changed, 47 insertions(+), 128 deletions(-) diff --git a/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp b/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp index f7cca9b3f8..6963608709 100644 --- a/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp +++ b/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp @@ -1082,7 +1082,7 @@ sycl::event gemm_impl(sycl::queue &exec_q, n, n_blocks, delta_n, k, k_blocks, delta_k, n_wi, m, lhs_indexer, rhs_indexer, res_indexer)); } - else if (k > n && k > m && !exec_q.get_device().is_cpu()) { + else if (k > n && k > m) { constexpr size_t m_groups = 2; size_t delta_k(4); size_t n_wi(4); @@ -1252,7 +1252,7 @@ sycl::event gemm_contig_impl(sycl::queue &exec_q, n, n_blocks, delta_n, k, k_blocks, delta_k, n_wi, m, lhs_indexer, rhs_indexer, res_indexer)); } - else if (k > n && k > m && !exec_q.get_device().is_cpu()) { + else if (k > n && k > m) { constexpr size_t m_groups = 2; size_t delta_k(4); size_t n_wi(4); @@ -3380,7 +3380,7 @@ sycl::event gemm_contig_tree_impl(sycl::queue &exec_q, return gemm_no_reduction_ev; } - if (exec_q.get_device().is_cpu()) { + if ((k > n && k > m) || m == 1) { using dpctl::tensor::type_utils::is_complex; if constexpr (!is_complex::value) { if (m == 1) { @@ -3388,43 +3388,24 @@ sycl::event gemm_contig_tree_impl(sycl::queue &exec_q, exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, depends); } else { - return gemm_contig_tree_nm_impl( + return gemm_contig_tree_k_impl( exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, depends); } } else { - return gemm_contig_tree_nm_impl( + return gemm_contig_tree_k_impl( exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, depends); } } - else { - if ((k > n && k > m) || m == 1) { - using dpctl::tensor::type_utils::is_complex; - if constexpr (!is_complex::value) { - if (m == 1) { - return gemm_contig_tree_k_impl( - exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, depends); - } - else { - return gemm_contig_tree_k_impl( - exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, depends); - } - } - else { - return gemm_contig_tree_k_impl( - exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, depends); - } + else { // m > 1, n > k or m > k + using dpctl::tensor::type_utils::is_complex; + if constexpr (!is_complex::value) { + return gemm_contig_tree_nm_impl( + exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, depends); } - else { // m > 1, n > k or m > k - using dpctl::tensor::type_utils::is_complex; - if constexpr (!is_complex::value) { - return gemm_contig_tree_nm_impl( - exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, depends); - } - else { - return gemm_contig_tree_nm_impl( - exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, depends); - } + else { + return gemm_contig_tree_nm_impl( + exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, depends); } } } @@ -4248,7 +4229,7 @@ sycl::event gemm_batch_impl(sycl::queue &exec_q, m, batch_nelems, batch_indexer, lhs_indexer, rhs_indexer, res_indexer)); } - else if (k > n && k > m && !exec_q.get_device().is_cpu()) { + else if (k > n && k > m) { constexpr size_t m_groups = 2; size_t delta_k(4); size_t n_wi(4); @@ -4454,7 +4435,7 @@ sycl::event gemm_batch_contig_impl(sycl::queue &exec_q, m, batch_nelems, batch_indexer, lhs_indexer, rhs_indexer, res_indexer)); } - else if (k > n && k > m && !exec_q.get_device().is_cpu()) { + else if (k > n && k > m) { constexpr size_t 
m_groups = 2; size_t delta_k(4); size_t n_wi(4); @@ -6055,7 +6036,7 @@ gemm_batch_tree_impl(sycl::queue &exec_q, return gemm_batch_no_reduction_ev; } - if (exec_q.get_device().is_cpu()) { + if ((k > n && k > m) || m == 1) { using dpctl::tensor::type_utils::is_complex; if constexpr (!is_complex::value) { if (m == 1) { @@ -6068,7 +6049,7 @@ gemm_batch_tree_impl(sycl::queue &exec_q, res_outer_shapes_strides, res_shape_strides, depends); } else { - return gemm_batch_tree_nm_impl( + return gemm_batch_tree_k_impl( exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, batch_nd, batch_shape_strides, lhs_batch_offset, rhs_batch_offset, res_batch_offset, inner_nd, lhs_outer_nd, @@ -6078,7 +6059,7 @@ gemm_batch_tree_impl(sycl::queue &exec_q, } } else { - return gemm_batch_tree_nm_impl( + return gemm_batch_tree_k_impl( exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, batch_nd, batch_shape_strides, lhs_batch_offset, rhs_batch_offset, res_batch_offset, inner_nd, lhs_outer_nd, @@ -6087,61 +6068,25 @@ gemm_batch_tree_impl(sycl::queue &exec_q, res_outer_shapes_strides, res_shape_strides, depends); } } - else { - if ((k > n && k > m) || m == 1) { - using dpctl::tensor::type_utils::is_complex; - if constexpr (!is_complex::value) { - if (m == 1) { - return gemm_batch_tree_k_impl( - exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, - batch_nd, batch_shape_strides, lhs_batch_offset, - rhs_batch_offset, res_batch_offset, inner_nd, - lhs_outer_nd, lhs_outer_inner_shapes_strides, - rhs_outer_nd, rhs_outer_inner_shapes_strides, - res_outer_nd, res_outer_shapes_strides, - res_shape_strides, depends); - } - else { - return gemm_batch_tree_k_impl( - exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, - batch_nd, batch_shape_strides, lhs_batch_offset, - rhs_batch_offset, res_batch_offset, inner_nd, - lhs_outer_nd, lhs_outer_inner_shapes_strides, - rhs_outer_nd, rhs_outer_inner_shapes_strides, - res_outer_nd, res_outer_shapes_strides, - res_shape_strides, depends); - } - } - else { - return gemm_batch_tree_k_impl( - exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, - batch_nd, batch_shape_strides, lhs_batch_offset, - rhs_batch_offset, res_batch_offset, inner_nd, lhs_outer_nd, - lhs_outer_inner_shapes_strides, rhs_outer_nd, - rhs_outer_inner_shapes_strides, res_outer_nd, - res_outer_shapes_strides, res_shape_strides, depends); - } + else { // m > 1, n > k or m > k + using dpctl::tensor::type_utils::is_complex; + if constexpr (!is_complex::value) { + return gemm_batch_tree_nm_impl( + exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, batch_nd, + batch_shape_strides, lhs_batch_offset, rhs_batch_offset, + res_batch_offset, inner_nd, lhs_outer_nd, + lhs_outer_inner_shapes_strides, rhs_outer_nd, + rhs_outer_inner_shapes_strides, res_outer_nd, + res_outer_shapes_strides, res_shape_strides, depends); } - else { // m > 1, n > k or m > k - using dpctl::tensor::type_utils::is_complex; - if constexpr (!is_complex::value) { - return gemm_batch_tree_nm_impl( - exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, - batch_nd, batch_shape_strides, lhs_batch_offset, - rhs_batch_offset, res_batch_offset, inner_nd, lhs_outer_nd, - lhs_outer_inner_shapes_strides, rhs_outer_nd, - rhs_outer_inner_shapes_strides, res_outer_nd, - res_outer_shapes_strides, res_shape_strides, depends); - } - else { // m > 1, n > k or m > k, resTy complex - return gemm_batch_tree_nm_impl( - exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, - batch_nd, batch_shape_strides, lhs_batch_offset, - rhs_batch_offset, res_batch_offset, 
inner_nd, lhs_outer_nd, - lhs_outer_inner_shapes_strides, rhs_outer_nd, - rhs_outer_inner_shapes_strides, res_outer_nd, - res_outer_shapes_strides, res_shape_strides, depends); - } + else { // m > 1, n > k or m > k, resTy complex + return gemm_batch_tree_nm_impl( + exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, batch_nd, + batch_shape_strides, lhs_batch_offset, rhs_batch_offset, + res_batch_offset, inner_nd, lhs_outer_nd, + lhs_outer_inner_shapes_strides, rhs_outer_nd, + rhs_outer_inner_shapes_strides, res_outer_nd, + res_outer_shapes_strides, res_shape_strides, depends); } } } @@ -6929,7 +6874,7 @@ gemm_batch_contig_tree_impl(sycl::queue &exec_q, return gemm_batch_no_reduction_ev; } - if (exec_q.get_device().is_cpu()) { + if ((k > n && k > m) || m == 1) { using dpctl::tensor::type_utils::is_complex; if constexpr (!is_complex::value) { if (m == 1) { @@ -6938,51 +6883,25 @@ gemm_batch_contig_tree_impl(sycl::queue &exec_q, depends); } else { - return gemm_batch_contig_tree_nm_impl( + return gemm_batch_contig_tree_k_impl( exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, depends); } } else { - return gemm_batch_contig_tree_nm_impl( + return gemm_batch_contig_tree_k_impl( exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, depends); } } - else { - if ((k > n && k > m) || m == 1) { - using dpctl::tensor::type_utils::is_complex; - if constexpr (!is_complex::value) { - if (m == 1) { - return gemm_batch_contig_tree_k_impl( - exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, - depends); - } - else { - return gemm_batch_contig_tree_k_impl( - exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, - depends); - } - } - else { - return gemm_batch_contig_tree_k_impl( - exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, - depends); - } + else { // m > 1, n > k or m > k + using dpctl::tensor::type_utils::is_complex; + if constexpr (!is_complex::value) { + return gemm_batch_contig_tree_nm_impl( + exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, depends); } - else { // m > 1, n > k or m > k - using dpctl::tensor::type_utils::is_complex; - if constexpr (!is_complex::value) { - return gemm_batch_contig_tree_nm_impl( - exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, - depends); - } - else { // m > 1, n > k or m > k, resTy complex - return gemm_batch_contig_tree_nm_impl( - exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, - depends); - } + else { // m > 1, n > k or m > k, resTy complex + return gemm_batch_contig_tree_nm_impl( + exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, depends); } } } From 303e7dbc97c7198495901c1eefeb435726baf689 Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Mon, 15 Jan 2024 19:11:31 -0800 Subject: [PATCH 29/48] Revert tuning down of `gemm` kernel parameters --- .../include/kernels/linalg_functions/gemm.hpp | 48 +++++++++---------- 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp b/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp index 6963608709..a811f20d17 100644 --- a/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp +++ b/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp @@ -1049,8 +1049,8 @@ sycl::event gemm_impl(sycl::queue &exec_q, if (m == 1) { constexpr size_t m_groups = 1; size_t delta_k(4); - size_t n_wi(4); - size_t delta_n(4); + size_t n_wi(64); + size_t delta_n(16); gemm_detail::scale_gemm_k_parameters( local_mem_size, reserved_slm_size, delta_k, @@ -1085,8 +1085,8 @@ sycl::event 
gemm_impl(sycl::queue &exec_q, else if (k > n && k > m) { constexpr size_t m_groups = 2; size_t delta_k(4); - size_t n_wi(4); - size_t delta_n(4); + size_t n_wi(64); + size_t delta_n(16); gemm_detail::scale_gemm_k_parameters( local_mem_size, reserved_slm_size, delta_k, @@ -1219,8 +1219,8 @@ sycl::event gemm_contig_impl(sycl::queue &exec_q, if (m == 1) { constexpr size_t m_groups = 1; size_t delta_k(4); - size_t n_wi(4); - size_t delta_n(4); + size_t n_wi(64); + size_t delta_n(16); gemm_detail::scale_gemm_k_parameters( local_mem_size, reserved_slm_size, delta_k, @@ -1255,8 +1255,8 @@ sycl::event gemm_contig_impl(sycl::queue &exec_q, else if (k > n && k > m) { constexpr size_t m_groups = 2; size_t delta_k(4); - size_t n_wi(4); - size_t delta_n(4); + size_t n_wi(64); + size_t delta_n(16); gemm_detail::scale_gemm_k_parameters( local_mem_size, reserved_slm_size, delta_k, @@ -1941,8 +1941,8 @@ sycl::event gemm_tree_k_impl(sycl::queue &exec_q, const std::vector &depends) { size_t delta_k(4); - size_t n_wi(4); - size_t delta_n(4); + size_t n_wi(64); + size_t delta_n(16); const sycl::device &dev = exec_q.get_device(); const size_t local_mem_size = @@ -2709,8 +2709,8 @@ sycl::event gemm_contig_tree_k_impl(sycl::queue &exec_q, std::vector const &depends) { size_t delta_k(4); - size_t n_wi(4); - size_t delta_n(4); + size_t n_wi(64); + size_t delta_n(16); const sycl::device &dev = exec_q.get_device(); const size_t local_mem_size = @@ -4191,8 +4191,8 @@ sycl::event gemm_batch_impl(sycl::queue &exec_q, if (m == 1) { constexpr size_t m_groups = 1; size_t delta_k(4); - size_t n_wi(4); - size_t delta_n(4); + size_t n_wi(64); + size_t delta_n(16); gemm_detail::scale_gemm_k_parameters( local_mem_size, reserved_slm_size, delta_k, @@ -4232,8 +4232,8 @@ sycl::event gemm_batch_impl(sycl::queue &exec_q, else if (k > n && k > m) { constexpr size_t m_groups = 2; size_t delta_k(4); - size_t n_wi(4); - size_t delta_n(4); + size_t n_wi(64); + size_t delta_n(16); gemm_detail::scale_gemm_k_parameters( local_mem_size, reserved_slm_size, delta_k, @@ -4397,8 +4397,8 @@ sycl::event gemm_batch_contig_impl(sycl::queue &exec_q, if (m == 1) { constexpr size_t m_groups = 1; size_t delta_k(4); - size_t n_wi(4); - size_t delta_n(4); + size_t n_wi(64); + size_t delta_n(16); gemm_detail::scale_gemm_k_parameters( local_mem_size, reserved_slm_size, delta_k, @@ -4438,8 +4438,8 @@ sycl::event gemm_batch_contig_impl(sycl::queue &exec_q, else if (k > n && k > m) { constexpr size_t m_groups = 2; size_t delta_k(4); - size_t n_wi(4); - size_t delta_n(4); + size_t n_wi(64); + size_t delta_n(16); gemm_detail::scale_gemm_k_parameters( local_mem_size, reserved_slm_size, delta_k, @@ -5212,8 +5212,8 @@ gemm_batch_tree_k_impl(sycl::queue &exec_q, std::vector const &depends) { size_t delta_k(4); - size_t n_wi(4); - size_t delta_n(4); + size_t n_wi(64); + size_t delta_n(16); const sycl::device &dev = exec_q.get_device(); const size_t local_mem_size = @@ -6104,8 +6104,8 @@ gemm_batch_contig_tree_k_impl(sycl::queue &exec_q, std::vector const &depends) { size_t delta_k(4); - size_t n_wi(4); - size_t delta_n(4); + size_t n_wi(64); + size_t delta_n(16); const sycl::device &dev = exec_q.get_device(); const size_t local_mem_size = From 0e14ba83daa05a3bf096d4d66ff92b42389ea406 Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Tue, 16 Jan 2024 08:03:08 -0600 Subject: [PATCH 30/48] Removed logically dead code from _linear_algebra_functions.py --- dpctl/tensor/_linear_algebra_functions.py | 60 ++++++----------------- 1 file changed, 16 insertions(+), 44 
deletions(-) diff --git a/dpctl/tensor/_linear_algebra_functions.py b/dpctl/tensor/_linear_algebra_functions.py index 4a8c19f667..502c48b33f 100644 --- a/dpctl/tensor/_linear_algebra_functions.py +++ b/dpctl/tensor/_linear_algebra_functions.py @@ -105,32 +105,18 @@ def tensordot(x1, x2, axes=2): raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x2)}") q1, x1_usm_type = x1.sycl_queue, x1.usm_type q2, x2_usm_type = x2.sycl_queue, x2.usm_type - if q1 is None and q2 is None: + exec_q = dpctl.utils.get_execution_queue((q1, q2)) + if exec_q is None: raise ExecutionPlacementError( "Execution placement can not be unambiguously inferred " - "from input arguments. " - "One of the arguments must represent USM allocation and " - "expose `__sycl_usm_array_interface__` property" + "from input arguments." ) - if q1 is None: - exec_q = q2 - res_usm_type = x2_usm_type - elif q2 is None: - exec_q = q1 - res_usm_type = x1_usm_type - else: - exec_q = dpctl.utils.get_execution_queue((q1, q2)) - if exec_q is None: - raise ExecutionPlacementError( - "Execution placement can not be unambiguously inferred " - "from input arguments." - ) - res_usm_type = dpctl.utils.get_coerced_usm_type( - ( - x1_usm_type, - x2_usm_type, - ) + res_usm_type = dpctl.utils.get_coerced_usm_type( + ( + x1_usm_type, + x2_usm_type, ) + ) dpctl.utils.validate_usm_type(res_usm_type, allow_none=False) # handle axes and shapes validation x1_nd = x1.ndim @@ -345,32 +331,18 @@ def vecdot(x1, x2, axis=-1): raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x2)}") q1, x1_usm_type = x1.sycl_queue, x1.usm_type q2, x2_usm_type = x2.sycl_queue, x2.usm_type - if q1 is None and q2 is None: + exec_q = dpctl.utils.get_execution_queue((q1, q2)) + if exec_q is None: raise ExecutionPlacementError( "Execution placement can not be unambiguously inferred " - "from input arguments. " - "One of the arguments must represent USM allocation and " - "expose `__sycl_usm_array_interface__` property" + "from input arguments." ) - if q1 is None: - exec_q = q2 - res_usm_type = x2_usm_type - elif q2 is None: - exec_q = q1 - res_usm_type = x1_usm_type - else: - exec_q = dpctl.utils.get_execution_queue((q1, q2)) - if exec_q is None: - raise ExecutionPlacementError( - "Execution placement can not be unambiguously inferred " - "from input arguments." 
- ) - res_usm_type = dpctl.utils.get_coerced_usm_type( - ( - x1_usm_type, - x2_usm_type, - ) + res_usm_type = dpctl.utils.get_coerced_usm_type( + ( + x1_usm_type, + x2_usm_type, ) + ) dpctl.utils.validate_usm_type(res_usm_type, allow_none=False) # axis and shape validation x1_nd = x1.ndim From 05a71ee9657608cc754edc8636df02b81d65fc3d Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Tue, 16 Jan 2024 08:05:55 -0600 Subject: [PATCH 31/48] Added more tests to improve coverage of _linear_algebra_functions --- dpctl/tests/test_usm_ndarray_linalg.py | 110 +++++++++++++++++++++++++ 1 file changed, 110 insertions(+) diff --git a/dpctl/tests/test_usm_ndarray_linalg.py b/dpctl/tests/test_usm_ndarray_linalg.py index c65cab2d93..b49975fc8d 100644 --- a/dpctl/tests/test_usm_ndarray_linalg.py +++ b/dpctl/tests/test_usm_ndarray_linalg.py @@ -19,6 +19,7 @@ import numpy as np import pytest +import dpctl import dpctl.tensor as dpt from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported @@ -303,6 +304,35 @@ def test_matmul_out(): assert np.allclose(ref, dpt.asnumpy(res)) +def test_matmul_dtype(): + get_queue_or_skip() + + m1 = dpt.ones((10, 10), dtype="i4") + m2 = dpt.ones((10, 10), dtype="i8") + + r = dpt.matmul(m1, m2, dtype="f4") + assert r.dtype == dpt.float32 + + +@pytest.mark.parametrize("dt1", _numeric_types) +@pytest.mark.parametrize("dt2", _numeric_types) +def test_matmul_type_promotion(dt1, dt2): + get_queue_or_skip() + + q = get_queue_or_skip() + skip_if_dtype_not_supported(dt1, q) + skip_if_dtype_not_supported(dt2, q) + + m1 = dpt.ones((10, 10), dtype=dt1) + m2 = dpt.ones((10, 10), dtype=dt2) + + r = dpt.matmul(m1, m2) + assert r.shape == ( + 10, + 10, + ) + + @pytest.mark.parametrize("dtype", _numeric_types) def test_tensordot_outer(dtype): q = get_queue_or_skip() @@ -373,6 +403,60 @@ def test_tensordot_axes_sequence(dtype): assert dpt.allclose(tdr, dpt.full_like(tdr, fill_value=expected)) +def test_tensordot_validation(): + get_queue_or_skip() + + with pytest.raises(TypeError): + dpt.tensordot(dict(), dict()) + + t1 = dpt.empty((10, 10, 10)) + with pytest.raises(TypeError): + dpt.tensordot(t1, dict()) + + t2 = dpt.empty((10, 10, 10)) + q = dpctl.SyclQueue(t2.sycl_context, t2.sycl_device, property="in_order") + with pytest.raises(dpctl.utils.ExecutionPlacementError): + dpt.tensordot(t1, t2.to_device(q)) + + invalid_axes = ( + 1, + 2, + 3, + ) + with pytest.raises(ValueError): + dpt.tensordot(t1, t2, axes=invalid_axes) + + invalid_axes = 5.2 + with pytest.raises(TypeError): + dpt.tensordot(t1, t2, axes=invalid_axes) + + invalid_axes = ( + (1,), + ( + 0, + 2, + ), + ) + with pytest.raises(ValueError): + dpt.tensordot(t1, t2, axes=invalid_axes) + + with pytest.raises(ValueError): + dpt.tensordot(t1[..., :5], t2) + + +def test_tensordot_promotion(): + get_queue_or_skip() + + t1 = dpt.zeros((10, 10), dtype="i4") + t2 = dpt.zeros((10, 10), dtype="i8") + + r1 = dpt.tensordot(t1, t2) + assert r1.dtype == t2.dtype + + r2 = dpt.tensordot(t2, t1) + assert r2.dtype == t2.dtype + + @pytest.mark.parametrize("dtype", _numeric_types) def test_vecdot_1d(dtype): q = get_queue_or_skip() @@ -484,6 +568,32 @@ def test_vector_arg_validation(): with pytest.raises(ValueError): dpt.vecdot(v1, v2, axis=2) + q = dpctl.SyclQueue( + v2.sycl_context, v2.sycl_device, property="enable_profiling" + ) + with pytest.raises(dpctl.utils.ExecutionPlacementError): + dpt.vecdot(v1, v2.to_device(q)) + + m1 = dpt.empty((10, 5)) + m2 = dpt.empty((5, 5)) + with pytest.raises(ValueError): + dpt.vecdot(m1, m2, 
axis=-1) + + +def test_vecdot_broadcast(): + get_queue_or_skip() + + for dt1, dt2 in [ + (dpt.int32, dpt.int32), + (dpt.int32, dpt.int64), + (dpt.int64, dpt.int32), + ]: + m1 = dpt.zeros((1, 5), dtype=dt1) + m2 = dpt.zeros((5, 5), dtype=dt2) + r1 = dpt.vecdot(m1, m2, axis=-1) + r2 = dpt.vecdot(m2, m1, axis=-1) + assert r1.shape == r2.shape + @pytest.mark.parametrize("dt1", _numeric_types) @pytest.mark.parametrize("dt2", _numeric_types) From 7e428e0d133394ccc185fa8cb9275fd62679f684 Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Tue, 16 Jan 2024 09:16:04 -0600 Subject: [PATCH 32/48] Fixed "UnboundLocalError: local variable 'buf1_dt' referenced before assignment" Initialized buf1_dt and buf2_dt to None --- dpctl/tensor/_linear_algebra_functions.py | 1 + 1 file changed, 1 insertion(+) diff --git a/dpctl/tensor/_linear_algebra_functions.py b/dpctl/tensor/_linear_algebra_functions.py index 502c48b33f..a1642fe9a0 100644 --- a/dpctl/tensor/_linear_algebra_functions.py +++ b/dpctl/tensor/_linear_algebra_functions.py @@ -719,6 +719,7 @@ def matmul(x1, x2, out=None, dtype=None, order="K"): else: res_dt = dpt.dtype(dtype) res_dt = _to_device_supported_dtype(res_dt, sycl_dev) + buf1_dt, buf2_dt = None, None if x1_dtype != res_dt: if dpt.can_cast(x1_dtype, res_dt, casting="same_kind"): buf1_dt = res_dt From 1e689d7a9fcafdd67ef5e09cd8fb4bcc6acdf6ee Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Tue, 16 Jan 2024 09:57:27 -0600 Subject: [PATCH 33/48] More tests to improve coverage --- dpctl/tests/test_usm_ndarray_linalg.py | 24 +++++++++++++++++++----- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/dpctl/tests/test_usm_ndarray_linalg.py b/dpctl/tests/test_usm_ndarray_linalg.py index b49975fc8d..fa80521a7f 100644 --- a/dpctl/tests/test_usm_ndarray_linalg.py +++ b/dpctl/tests/test_usm_ndarray_linalg.py @@ -300,18 +300,32 @@ def test_matmul_out(): assert dpt.allclose(dpt.zeros_like(res), buf[::-2, 0::3, :]) assert dpt.allclose(dpt.zeros_like(res), buf[::-2, 2::3, :]) - ref = np.matmul(dpt.asnumpy(m1), dpt.asnumpy(m2)) + m1_np = dpt.asnumpy(m1) + ref = np.matmul(m1_np, dpt.asnumpy(m2)) + assert np.allclose(ref, dpt.asnumpy(res)) + + res = dpt.matmul(m1[:, :10, :10], m1[:, :10, :10].mT, out=m1[:, :10, :10]) + ref = np.matmul( + m1_np[:, :10, :10], np.transpose(m1_np[:, :10, :10], (0, 2, 1)) + ) assert np.allclose(ref, dpt.asnumpy(res)) def test_matmul_dtype(): get_queue_or_skip() - m1 = dpt.ones((10, 10), dtype="i4") - m2 = dpt.ones((10, 10), dtype="i8") + for dt1, dt2 in [ + (dpt.int32, dpt.int16), + (dpt.int16, dpt.int32), + (dpt.float32, dpt.int16), + (dpt.int32, dpt.float32), + ]: + m1 = dpt.ones((10, 10), dtype=dt1) + m2 = dpt.ones((10, 10), dtype=dt2) - r = dpt.matmul(m1, m2, dtype="f4") - assert r.dtype == dpt.float32 + for ord in ["C", "A", "F", "K"]: + r = dpt.matmul(m1, m2, dtype=dpt.float32, order=ord) + assert r.dtype == dpt.float32 @pytest.mark.parametrize("dt1", _numeric_types) From 35cc4584db932ab178ba78ea18fd0eeebf733412 Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Tue, 16 Jan 2024 13:38:20 -0800 Subject: [PATCH 34/48] Removed more dead branches in _linear_algebra_functions.py --- dpctl/tensor/_linear_algebra_functions.py | 32 +++++++---------------- 1 file changed, 9 insertions(+), 23 deletions(-) diff --git a/dpctl/tensor/_linear_algebra_functions.py b/dpctl/tensor/_linear_algebra_functions.py index a1642fe9a0..c50d5daa3d 100644 --- a/dpctl/tensor/_linear_algebra_functions.py +++ b/dpctl/tensor/_linear_algebra_functions.py @@ -638,32 +638,18 @@ def 
matmul(x1, x2, out=None, dtype=None, order="K"): order = "K" q1, x1_usm_type = x1.sycl_queue, x1.usm_type q2, x2_usm_type = x2.sycl_queue, x2.usm_type - if q1 is None and q2 is None: + exec_q = dpctl.utils.get_execution_queue((q1, q2)) + if exec_q is None: raise ExecutionPlacementError( "Execution placement can not be unambiguously inferred " - "from input arguments. " - "One of the arguments must represent USM allocation and " - "expose `__sycl_usm_array_interface__` property" - ) - if q1 is None: - exec_q = q2 - res_usm_type = x2_usm_type - elif q2 is None: - exec_q = q1 - res_usm_type = x1_usm_type - else: - exec_q = dpctl.utils.get_execution_queue((q1, q2)) - if exec_q is None: - raise ExecutionPlacementError( - "Execution placement can not be unambiguously inferred " - "from input arguments." - ) - res_usm_type = dpctl.utils.get_coerced_usm_type( - ( - x1_usm_type, - x2_usm_type, - ) + "from input arguments." ) + res_usm_type = dpctl.utils.get_coerced_usm_type( + ( + x1_usm_type, + x2_usm_type, + ) + ) dpctl.utils.validate_usm_type(res_usm_type, allow_none=False) x1_nd = x1.ndim From d8659d4c1bc3849afc5a6d6a9f60629510285b5a Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Tue, 16 Jan 2024 13:39:41 -0800 Subject: [PATCH 35/48] `tensordot` now properly handles negative `axes` As per array API, negative axes are not permitted --- dpctl/tensor/_linear_algebra_functions.py | 26 ++++++++++++++--------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/dpctl/tensor/_linear_algebra_functions.py b/dpctl/tensor/_linear_algebra_functions.py index c50d5daa3d..1f844a89f9 100644 --- a/dpctl/tensor/_linear_algebra_functions.py +++ b/dpctl/tensor/_linear_algebra_functions.py @@ -124,9 +124,11 @@ def tensordot(x1, x2, axes=2): x1_shape = x1.shape x2_shape = x2.shape if isinstance(axes, int): + if axes < 0: + raise ValueError("`axes` integer is expected to be non-negative") n_axes1 = axes n_axes2 = axes - axes1 = tuple(range(-axes, 0)) + axes1 = normalize_axis_tuple(tuple(range(-axes, 0)), x1_nd) axes2 = tuple(range(0, axes)) elif isinstance(axes, tuple): if len(axes) != 2: @@ -151,9 +153,13 @@ def tensordot(x1, x2, axes=2): else: same_shapes = True for i in range(n_axes1): - same_shapes = same_shapes and ( - x1_shape[axes1[i]] == x2_shape[axes2[i]] - ) + axis1 = axes1[i] + if axis1 < 0: + raise ValueError("`axes` must be non-negative") + axis2 = axes2[i] + if axis2 < 0: + raise ValueError("`axes` must be non-negative") + same_shapes = same_shapes and (x1_shape[axis1] == x2_shape[axis2]) if not same_shapes: raise ValueError("shape mismatch in contracted `tensordot` axes") axes1 = normalize_axis_tuple(axes1, x1_nd) @@ -788,7 +794,7 @@ def matmul(x1, x2, out=None, dtype=None, order="K"): x1 = dpt.broadcast_to(x1, x1_broadcast_shape) if x2.shape != x2_broadcast_shape: x2 = dpt.broadcast_to(x2, x2_broadcast_shape) - ht_dot_ev, binary_ev = tli._dot( + ht_dot_ev, dot_ev = tli._dot( x1=x1, x2=x2, batch_dims=len(res_shape[:-2]), @@ -804,7 +810,7 @@ def matmul(x1, x2, out=None, dtype=None, order="K"): src=out, dst=orig_out, sycl_queue=exec_q, - depends=[binary_ev], + depends=[dot_ev], ) ht_copy_out_ev.wait() out = orig_out @@ -840,7 +846,7 @@ def matmul(x1, x2, out=None, dtype=None, order="K"): x1 = dpt.broadcast_to(x1, x1_broadcast_shape) if buf2.shape != x2_broadcast_shape: buf2 = dpt.broadcast_to(buf2, x2_broadcast_shape) - ht_dot_ev, binary_ev = tli._dot( + ht_dot_ev, dot_ev = tli._dot( x1=x1, x2=buf2, batch_dims=len(res_shape[:-2]), @@ -857,7 +863,7 @@ def matmul(x1, x2, out=None, 
dtype=None, order="K"): src=out, dst=orig_out, sycl_queue=exec_q, - depends=[binary_ev], + depends=[dot_ev], ) ht_copy_out_ev.wait() out = orig_out @@ -895,7 +901,7 @@ def matmul(x1, x2, out=None, dtype=None, order="K"): buf1 = dpt.broadcast_to(buf1, x1_broadcast_shape) if x2.shape != x2_broadcast_shape: x2 = dpt.broadcast_to(x2, x2_broadcast_shape) - ht_dot_ev, binary_ev = tli._dot( + ht_dot_ev, dot_ev = tli._dot( x1=buf1, x2=x2, batch_dims=len(res_shape[:-2]), @@ -912,7 +918,7 @@ def matmul(x1, x2, out=None, dtype=None, order="K"): src=out, dst=orig_out, sycl_queue=exec_q, - depends=[binary_ev], + depends=[dot_ev], ) ht_copy_out_ev.wait() out = orig_out From 11b710c6bfcbb848389bea9a45c7b65cbd8859d6 Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Tue, 16 Jan 2024 13:42:24 -0800 Subject: [PATCH 36/48] Adds `test_tensordot_type_matrix` to `test_usm_ndarray_linalg.py` --- dpctl/tests/test_usm_ndarray_linalg.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/dpctl/tests/test_usm_ndarray_linalg.py b/dpctl/tests/test_usm_ndarray_linalg.py index fa80521a7f..59a9daf6fc 100644 --- a/dpctl/tests/test_usm_ndarray_linalg.py +++ b/dpctl/tests/test_usm_ndarray_linalg.py @@ -471,6 +471,25 @@ def test_tensordot_promotion(): assert r2.dtype == t2.dtype +@pytest.mark.parametrize("dt1", _numeric_types) +@pytest.mark.parametrize("dt2", _numeric_types) +def test_tensordot_type_promotion2(dt1, dt2): + get_queue_or_skip() + + q = get_queue_or_skip() + skip_if_dtype_not_supported(dt1, q) + skip_if_dtype_not_supported(dt2, q) + + m1 = dpt.ones((10, 10), dtype=dt1) + m2 = dpt.ones((10, 10), dtype=dt2) + + r = dpt.tensordot(m1, m2, axes=1) + assert r.shape == ( + 10, + 10, + ) + + @pytest.mark.parametrize("dtype", _numeric_types) def test_vecdot_1d(dtype): q = get_queue_or_skip() From 2e448dc8a804a31dcd590891fcfee385e3800e59 Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Tue, 16 Jan 2024 17:26:02 -0800 Subject: [PATCH 37/48] Addresses flaws in gemm tree kernel logic Previously, assertions for calling a full tree reduction with only a single work-group of elements could be tripped The kernel logic has been changed such that this is no longer possible --- .../include/kernels/linalg_functions/gemm.hpp | 2397 +++++++++-------- 1 file changed, 1220 insertions(+), 1177 deletions(-) diff --git a/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp b/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp index a811f20d17..255e008777 100644 --- a/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp +++ b/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp @@ -71,7 +71,202 @@ void scale_gemm_nm_parameters(const size_t &local_mem_size, using dpctl::tensor::sycl_utils::choose_workgroup_size; template -class gemm_reduction_over_group_temps_strided_krn; +class gemm_seq_reduction_krn; + +template +class gemm_tree_reduction_krn; + +template +sycl::event single_reduction_for_gemm(sycl::queue &exec_q, + T *tmp_tp, + T *res_tp, + T identity_val, + size_t iter_nelems, + size_t reduction_nelems, + size_t reduction_groups, + size_t wg, + size_t max_wg, + size_t preferred_reductions_per_wi, + size_t reductions_per_wi, + int res_nd, + py::ssize_t res_offset, + const py::ssize_t *res_shapes_strides, + const std::vector &depends) +{ + sycl::event red_ev; + if (reduction_nelems < wg) { + red_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using ResIndexerT = 
dpctl::tensor::offset_utils::StridedIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + NoOpIndexerT, ResIndexerT>; + using ReductionIndexerT = + dpctl::tensor::offset_utils::Strided1DIndexer; + + ResIndexerT res_iter_indexer{res_nd, 0, res_shapes_strides}; + InputOutputIterIndexerT in_out_iter_indexer{NoOpIndexerT{}, + res_iter_indexer}; + ReductionIndexerT reduction_indexer{ + 0, static_cast(reduction_nelems), + static_cast(iter_nelems)}; + + sycl::range<1> iter_range{iter_nelems}; + + cgh.parallel_for>( + iter_range, + SequentialReduction( + tmp_tp, res_tp, ReductionOpT(), identity_val, + in_out_iter_indexer, reduction_indexer, reduction_nelems)); + }); + } + else { + red_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using ResIndexerT = dpctl::tensor::offset_utils::StridedIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + NoOpIndexerT, ResIndexerT>; + using ReductionIndexerT = + dpctl::tensor::offset_utils::Strided1DIndexer; + + ResIndexerT res_iter_indexer{res_nd, 0, res_shapes_strides}; + InputOutputIterIndexerT in_out_iter_indexer{NoOpIndexerT{}, + res_iter_indexer}; + ReductionIndexerT reduction_indexer{ + 0, static_cast(reduction_nelems), + static_cast(iter_nelems)}; + + if (iter_nelems == 1) { + // increase GPU occupancy + wg = max_wg; + } + reductions_per_wi = + std::max(1, (reduction_nelems + wg - 1) / wg); + + size_t reduction_groups = + (reduction_nelems + reductions_per_wi * wg - 1) / + (reductions_per_wi * wg); + assert(reduction_groups == 1); + + auto globalRange = + sycl::range<1>{iter_nelems * reduction_groups * wg}; + auto localRange = sycl::range<1>{wg}; + + using KernelName = class gemm_tree_reduction_krn< + T, T, ReductionOpT, InputOutputIterIndexerT, ReductionIndexerT>; + cgh.parallel_for( + sycl::nd_range<1>(globalRange, localRange), + ReductionOverGroupNoAtomicFunctor( + tmp_tp, res_tp, ReductionOpT(), identity_val, + in_out_iter_indexer, reduction_indexer, reduction_nelems, + iter_nelems, reductions_per_wi)); + }); + } + return red_ev; +} + +template +sycl::event +single_reduction_for_gemm_contig(sycl::queue &exec_q, + T *tmp_tp, + T *res_tp, + T identity_val, + size_t iter_nelems, + size_t reduction_nelems, + size_t reduction_groups, + size_t wg, + size_t max_wg, + size_t preferred_reductions_per_wi, + size_t reductions_per_wi, + const std::vector &depends) +{ + sycl::event red_ev; + if (reduction_nelems < wg) { + red_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + NoOpIndexerT, NoOpIndexerT>; + using ReductionIndexerT = + dpctl::tensor::offset_utils::Strided1DIndexer; + + InputOutputIterIndexerT in_out_iter_indexer{NoOpIndexerT{}, + NoOpIndexerT{}}; + ReductionIndexerT reduction_indexer{ + 0, static_cast(reduction_nelems), + static_cast(iter_nelems)}; + + sycl::range<1> iter_range{iter_nelems}; + + cgh.parallel_for>( + iter_range, + SequentialReduction( + tmp_tp, res_tp, ReductionOpT(), identity_val, + in_out_iter_indexer, reduction_indexer, reduction_nelems)); + }); + } + else { + red_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputIterIndexerT = + 
dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + NoOpIndexerT, NoOpIndexerT>; + using ReductionIndexerT = + dpctl::tensor::offset_utils::Strided1DIndexer; + + InputOutputIterIndexerT in_out_iter_indexer{NoOpIndexerT{}, + NoOpIndexerT{}}; + ReductionIndexerT reduction_indexer{ + 0, static_cast(reduction_nelems), + static_cast(iter_nelems)}; + + if (iter_nelems == 1) { + // increase GPU occupancy + wg = max_wg; + } + reductions_per_wi = + std::max(1, (reduction_nelems + wg - 1) / wg); + + size_t reduction_groups = + (reduction_nelems + reductions_per_wi * wg - 1) / + (reductions_per_wi * wg); + assert(reduction_groups == 1); + + auto globalRange = + sycl::range<1>{iter_nelems * reduction_groups * wg}; + auto localRange = sycl::range<1>{wg}; + + using KernelName = class gemm_tree_reduction_krn< + T, T, ReductionOpT, InputOutputIterIndexerT, ReductionIndexerT>; + cgh.parallel_for( + sycl::nd_range<1>(globalRange, localRange), + ReductionOverGroupNoAtomicFunctor( + tmp_tp, res_tp, ReductionOpT(), identity_val, + in_out_iter_indexer, reduction_indexer, reduction_nelems, + iter_nelems, reductions_per_wi)); + }); + } + return red_ev; +} template sycl::event tree_reduction_for_gemm(sycl::queue &exec_q, @@ -89,7 +284,7 @@ sycl::event tree_reduction_for_gemm(sycl::queue &exec_q, int res_nd, py::ssize_t res_offset, const py::ssize_t *res_shape_strides, - std::vector depends = {}) + const std::vector &depends) { const sycl::event &first_reduction_ev = exec_q.submit([&](sycl::handler @@ -116,7 +311,7 @@ sycl::event tree_reduction_for_gemm(sycl::queue &exec_q, auto globalRange = sycl::range<1>{iter_nelems * reduction_groups * wg}; auto localRange = sycl::range<1>{wg}; - using KernelName = class gemm_reduction_over_group_temps_strided_krn< + using KernelName = class gemm_tree_reduction_krn< T, T, ReductionOpT, InputOutputIterIndexerT, ReductionIndexerT>; cgh.parallel_for( sycl::nd_range<1>(globalRange, localRange), @@ -140,47 +335,43 @@ sycl::event tree_reduction_for_gemm(sycl::queue &exec_q, assert(reduction_groups_ > 1); // keep reducing - sycl::event partial_reduction_ev = - exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(dependent_ev); + sycl::event partial_reduction_ev = exec_q.submit([&](sycl::handler + &cgh) { + cgh.depends_on(dependent_ev); - using InputIndexerT = - dpctl::tensor::offset_utils::Strided1DIndexer; - using ResIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; - using InputOutputIterIndexerT = - dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< - InputIndexerT, ResIndexerT>; - using ReductionIndexerT = - dpctl::tensor::offset_utils::NoOpIndexer; + using InputIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer; + using ResIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + InputIndexerT, ResIndexerT>; + using ReductionIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; - InputIndexerT inp_indexer{ - 0, static_cast(iter_nelems), - static_cast(reduction_groups_)}; - ResIndexerT res_iter_indexer{}; + InputIndexerT inp_indexer{ + 0, static_cast(iter_nelems), + static_cast(reduction_groups_)}; + ResIndexerT res_iter_indexer{}; - InputOutputIterIndexerT in_out_iter_indexer{inp_indexer, - res_iter_indexer}; + InputOutputIterIndexerT in_out_iter_indexer{inp_indexer, + res_iter_indexer}; - ReductionIndexerT reduction_indexer{}; + ReductionIndexerT reduction_indexer{}; - auto globalRange = - sycl::range<1>{iter_nelems * reduction_groups_ * wg}; - auto localRange = 
sycl::range<1>{wg}; + auto globalRange = + sycl::range<1>{iter_nelems * reduction_groups_ * wg}; + auto localRange = sycl::range<1>{wg}; - using KernelName = - class gemm_reduction_over_group_temps_strided_krn< - T, T, ReductionOpT, InputOutputIterIndexerT, - ReductionIndexerT>; - cgh.parallel_for( - sycl::nd_range<1>(globalRange, localRange), - ReductionOverGroupNoAtomicFunctor( - temp_arg, temp2_arg, ReductionOpT(), identity_val, - in_out_iter_indexer, reduction_indexer, - remaining_reduction_nelems, iter_nelems, - reductions_per_wi)); - }); + using KernelName = class gemm_tree_reduction_krn< + T, T, ReductionOpT, InputOutputIterIndexerT, ReductionIndexerT>; + cgh.parallel_for( + sycl::nd_range<1>(globalRange, localRange), + ReductionOverGroupNoAtomicFunctor( + temp_arg, temp2_arg, ReductionOpT(), identity_val, + in_out_iter_indexer, reduction_indexer, + remaining_reduction_nelems, iter_nelems, + reductions_per_wi)); + }); remaining_reduction_nelems = reduction_groups_; std::swap(temp_arg, temp2_arg); @@ -220,7 +411,7 @@ sycl::event tree_reduction_for_gemm(sycl::queue &exec_q, auto globalRange = sycl::range<1>{iter_nelems * reduction_groups * wg}; auto localRange = sycl::range<1>{wg}; - using KernelName = class gemm_reduction_over_group_temps_strided_krn< + using KernelName = class gemm_tree_reduction_krn< T, T, ReductionOpT, InputOutputIterIndexerT, ReductionIndexerT>; cgh.parallel_for( sycl::nd_range<1>(globalRange, localRange), @@ -251,7 +442,7 @@ tree_reduction_for_gemm_contig(sycl::queue &exec_q, size_t max_wg, size_t preferred_reductions_per_wi, size_t reductions_per_wi, - std::vector depends = {}) + const std::vector &depends) { const sycl::event &first_reduction_ev = exec_q.submit([&](sycl::handler @@ -1904,9 +2095,6 @@ class GemmNoAtomicFunctorThreadK -class gemm_reduction_seq_strided_krn; - template (); size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes); - if (reduction_nelems < wg) { + constexpr size_t preferred_reductions_per_wi = 8; + size_t reductions_per_wi(preferred_reductions_per_wi); + + size_t reduction_groups = + (reduction_nelems + preferred_reductions_per_wi * wg - 1) / + (preferred_reductions_per_wi * wg); + + // max_max_wg prevents running out of resources on CPU + constexpr size_t max_max_wg = 2048; + size_t max_wg = std::min( + max_max_wg, + dev.get_info() / 2); + + if (reduction_nelems <= preferred_reductions_per_wi * max_wg) { resTy *tmp = sycl::malloc_device( iter_nelems * reduction_nelems, exec_q); if (!tmp) { @@ -2098,36 +2299,13 @@ sycl::event gemm_tree_k_impl(sycl::queue &exec_q, lhs_indexer, rhs_indexer, res_indexer)); } }); - sycl::event red_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(gemm_ev); - - using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; - using ResIndexerT = dpctl::tensor::offset_utils::StridedIndexer; - using InputOutputIterIndexerT = - dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< - NoOpIndexerT, ResIndexerT>; - using ReductionIndexerT = - dpctl::tensor::offset_utils::Strided1DIndexer; - - ResIndexerT res_iter_indexer{res_nd, 0, res_shapes_strides}; - InputOutputIterIndexerT in_out_iter_indexer{NoOpIndexerT{}, - res_iter_indexer}; - ReductionIndexerT reduction_indexer{ - 0, static_cast(reduction_nelems), - static_cast(iter_nelems)}; - - sycl::range<1> iter_range{iter_nelems}; - - cgh.parallel_for>( - iter_range, SequentialReduction( - tmp, res_tp, ReductionOpT(), identity_val, - in_out_iter_indexer, reduction_indexer, - reduction_nelems)); - }); + + sycl::event red_ev = 
single_reduction_for_gemm( + exec_q, tmp, res_tp, identity_val, iter_nelems, + reduction_nelems, reduction_groups, wg, max_wg, + preferred_reductions_per_wi, reductions_per_wi, res_nd, 0, + res_shapes_strides, {gemm_ev}); + sycl::event cleanup_host_task_event = exec_q.submit([&](sycl::handler &cgh) { cgh.depends_on(red_ev); @@ -2137,111 +2315,104 @@ sycl::event gemm_tree_k_impl(sycl::queue &exec_q, }); return cleanup_host_task_event; } - // max_max_wg prevents running out of resources on CPU - constexpr size_t max_max_wg = 2048; - size_t max_wg = std::min( - max_max_wg, - dev.get_info() / 2); - - constexpr size_t preferred_reductions_per_wi = 8; - size_t reductions_per_wi(preferred_reductions_per_wi); + else { + assert(reduction_groups > 1); - size_t reduction_groups = - (reduction_nelems + preferred_reductions_per_wi * wg - 1) / - (preferred_reductions_per_wi * wg); - assert(reduction_groups > 1); + resTy *partially_reduced_tmp = sycl::malloc_device( + iter_nelems * (/* temp */ reduction_nelems + + /* first reduction temp */ reduction_groups), + exec_q); + resTy *partially_reduced_tmp2 = nullptr; - resTy *partially_reduced_tmp = sycl::malloc_device( - iter_nelems * (/* temp */ reduction_nelems + - /* first reduction temp */ reduction_groups), - exec_q); - resTy *partially_reduced_tmp2 = nullptr; + if (partially_reduced_tmp == nullptr) { + throw std::runtime_error("Unable to allocate device memory"); + } + else { + partially_reduced_tmp2 = + partially_reduced_tmp + reduction_nelems * iter_nelems; + } - if (partially_reduced_tmp == nullptr) { - throw std::runtime_error("Unable to allocate device memory"); - } - else { - partially_reduced_tmp2 = - partially_reduced_tmp + reduction_nelems * iter_nelems; - } + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); - sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(depends); + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::StridedIndexer; + OuterInnerDimsIndexerT lhs_indexer( + inner_nd + lhs_outer_nd, 0, lhs_outer_inner_shapes_strides); + OuterInnerDimsIndexerT rhs_indexer( + inner_nd + rhs_outer_nd, 0, rhs_outer_inner_shapes_strides); + using ResIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + ResIndexerT res_indexer{}; - using OuterInnerDimsIndexerT = - dpctl::tensor::offset_utils::StridedIndexer; - OuterInnerDimsIndexerT lhs_indexer(inner_nd + lhs_outer_nd, 0, - lhs_outer_inner_shapes_strides); - OuterInnerDimsIndexerT rhs_indexer(inner_nd + rhs_outer_nd, 0, - rhs_outer_inner_shapes_strides); - using ResIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; - ResIndexerT res_indexer{}; + size_t n_blocks = (n + delta_n - 1) / delta_n; + size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); + size_t m_blocks = (m + m_groups - 1) / m_groups; - size_t n_blocks = (n + delta_n - 1) / delta_n; - size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); - size_t m_blocks = (m + m_groups - 1) / m_groups; + size_t lws = delta_n * delta_k; - size_t lws = delta_n * delta_k; + auto gRange = + sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); + auto lRange = sycl::range<1>(lws); - auto gRange = sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); - auto lRange = sycl::range<1>(lws); + auto ndRange = sycl::nd_range<1>(gRange, lRange); - auto ndRange = sycl::nd_range<1>(gRange, lRange); + if constexpr (m_groups == 1) { + using LocAccT = sycl::local_accessor; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + using 
KernelName = + class gemm_tree_k_krn; + cgh.parallel_for( + ndRange, + GemmNoAtomicFunctorThreadK( + lhs_tp, rhs_tp, partially_reduced_tmp, workspace, + local_B_block, n, n_blocks, delta_n, k, k_blocks, + delta_k, n_wi, m, lhs_indexer, rhs_indexer, + res_indexer)); + } + else { + using LocAccT = + sycl::local_accessor, 1>; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + using KernelName = + class gemm_tree_k_krn; + cgh.parallel_for( + ndRange, + GemmNoAtomicFunctorThreadK( + lhs_tp, rhs_tp, partially_reduced_tmp, workspace, + local_B_block, n, n_blocks, delta_n, k, k_blocks, + delta_k, n_wi, m, lhs_indexer, rhs_indexer, + res_indexer)); + } + }); + // tree_reduction_for_gemm returns sycl::event for reduction + sycl::event red_ev = tree_reduction_for_gemm( + exec_q, partially_reduced_tmp, partially_reduced_tmp2, res_tp, + identity_val, iter_nelems, reduction_nelems, reduction_groups, + wg, max_wg, preferred_reductions_per_wi, reductions_per_wi, + res_nd, 0, res_shapes_strides, {gemm_ev}); - if constexpr (m_groups == 1) { - using LocAccT = sycl::local_accessor; - LocAccT local_B_block(n_wi * delta_k, cgh); - LocAccT workspace(delta_n * delta_k, cgh); - using KernelName = class gemm_tree_k_krn; - cgh.parallel_for( - ndRange, - GemmNoAtomicFunctorThreadK( - lhs_tp, rhs_tp, partially_reduced_tmp, workspace, - local_B_block, n, n_blocks, delta_n, k, k_blocks, - delta_k, n_wi, m, lhs_indexer, rhs_indexer, - res_indexer)); - } - else { - using LocAccT = - sycl::local_accessor, 1>; - LocAccT local_B_block(n_wi * delta_k, cgh); - LocAccT workspace(delta_n * delta_k, cgh); - using KernelName = class gemm_tree_k_krn; - cgh.parallel_for( - ndRange, - GemmNoAtomicFunctorThreadK( - lhs_tp, rhs_tp, partially_reduced_tmp, workspace, - local_B_block, n, n_blocks, delta_n, k, k_blocks, - delta_k, n_wi, m, lhs_indexer, rhs_indexer, - res_indexer)); - } - }); - // tree_reduction_for_gemm returns sycl::event for reduction - sycl::event red_ev = tree_reduction_for_gemm( - exec_q, partially_reduced_tmp, partially_reduced_tmp2, res_tp, - identity_val, iter_nelems, reduction_nelems, reduction_groups, wg, - max_wg, preferred_reductions_per_wi, reductions_per_wi, res_nd, 0, - res_shapes_strides, {gemm_ev}); - - sycl::event cleanup_host_task_event = - exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(red_ev); - const sycl::context &ctx = exec_q.get_context(); + sycl::event cleanup_host_task_event = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(red_ev); + const sycl::context &ctx = exec_q.get_context(); - cgh.host_task([ctx, partially_reduced_tmp] { - sycl::free(partially_reduced_tmp, ctx); + cgh.host_task([ctx, partially_reduced_tmp] { + sycl::free(partially_reduced_tmp, ctx); + }); }); - }); - return cleanup_host_task_event; + return cleanup_host_task_event; + } } } @@ -2371,7 +2542,20 @@ sycl::event gemm_tree_nm_impl(sycl::queue &exec_q, dev.get_info(); size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes); - if (reduction_nelems < wg) { + constexpr size_t preferred_reductions_per_wi = 8; + size_t reductions_per_wi(preferred_reductions_per_wi); + + size_t reduction_groups = + (reduction_nelems + preferred_reductions_per_wi * wg - 1) / + (preferred_reductions_per_wi * wg); + + // max_max_wg prevents running out of resources on CPU + constexpr size_t max_max_wg = 2048; + size_t max_wg = std::min( + max_max_wg, + dev.get_info() / 2); + + if (reduction_nelems <= preferred_reductions_per_wi * max_wg) { resTy *tmp = sycl::malloc_device( 
iter_nelems * reduction_nelems, exec_q); if (!tmp) { @@ -2450,36 +2634,13 @@ sycl::event gemm_tree_nm_impl(sycl::queue &exec_q, wg_delta_m, lhs_indexer, rhs_indexer, res_indexer)); } }); - sycl::event red_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(gemm_ev); - - using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; - using ResIndexerT = dpctl::tensor::offset_utils::StridedIndexer; - using InputOutputIterIndexerT = - dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< - NoOpIndexerT, ResIndexerT>; - using ReductionIndexerT = - dpctl::tensor::offset_utils::Strided1DIndexer; - - ResIndexerT res_iter_indexer{res_nd, 0, res_shapes_strides}; - InputOutputIterIndexerT in_out_iter_indexer{NoOpIndexerT{}, - res_iter_indexer}; - ReductionIndexerT reduction_indexer{ - 0, static_cast(reduction_nelems), - static_cast(iter_nelems)}; - - sycl::range<1> iter_range{iter_nelems}; - - cgh.parallel_for>( - iter_range, SequentialReduction( - tmp, res_tp, ReductionOpT(), identity_val, - in_out_iter_indexer, reduction_indexer, - reduction_nelems)); - }); + + sycl::event red_ev = single_reduction_for_gemm( + exec_q, tmp, res_tp, identity_val, iter_nelems, + reduction_nelems, reduction_groups, wg, max_wg, + preferred_reductions_per_wi, reductions_per_wi, res_nd, 0, + res_shapes_strides, {gemm_ev}); + sycl::event cleanup_host_task_event = exec_q.submit([&](sycl::handler &cgh) { cgh.depends_on(red_ev); @@ -2489,126 +2650,117 @@ sycl::event gemm_tree_nm_impl(sycl::queue &exec_q, }); return cleanup_host_task_event; } + else { + assert(reduction_groups > 1); - // max_max_wg prevents running out of resources on CPU - constexpr size_t max_max_wg = 2048; - size_t max_wg = std::min( - max_max_wg, - dev.get_info() / 2); - - constexpr size_t preferred_reductions_per_wi = 8; - size_t reductions_per_wi(preferred_reductions_per_wi); - - size_t reduction_groups = - (reduction_nelems + preferred_reductions_per_wi * wg - 1) / - (preferred_reductions_per_wi * wg); - assert(reduction_groups > 1); + resTy *partially_reduced_tmp = sycl::malloc_device( + iter_nelems * (/* temp */ reduction_nelems + + /* first reduction temp */ reduction_groups), + exec_q); + resTy *partially_reduced_tmp2 = nullptr; - resTy *partially_reduced_tmp = sycl::malloc_device( - iter_nelems * (/* temp */ reduction_nelems + - /* first reduction temp */ reduction_groups), - exec_q); - resTy *partially_reduced_tmp2 = nullptr; - - if (partially_reduced_tmp == nullptr) { - throw std::runtime_error("Unable to allocate device_memory"); - } - else { - partially_reduced_tmp2 = - partially_reduced_tmp + reduction_nelems * iter_nelems; - } + if (partially_reduced_tmp == nullptr) { + throw std::runtime_error("Unable to allocate device_memory"); + } + else { + partially_reduced_tmp2 = + partially_reduced_tmp + reduction_nelems * iter_nelems; + } - sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(depends); + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); - using OuterInnerDimsIndexerT = - dpctl::tensor::offset_utils::StridedIndexer; - OuterInnerDimsIndexerT lhs_indexer(inner_nd + lhs_outer_nd, 0, - lhs_outer_inner_shapes_strides); - OuterInnerDimsIndexerT rhs_indexer(inner_nd + rhs_outer_nd, 0, - rhs_outer_inner_shapes_strides); - using ResIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; - ResIndexerT res_indexer{}; + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::StridedIndexer; + OuterInnerDimsIndexerT lhs_indexer( + inner_nd + lhs_outer_nd, 0, 
lhs_outer_inner_shapes_strides); + OuterInnerDimsIndexerT rhs_indexer( + inner_nd + rhs_outer_nd, 0, rhs_outer_inner_shapes_strides); + using ResIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + ResIndexerT res_indexer{}; - size_t lws = wg_delta_n * wg_delta_m; + size_t lws = wg_delta_n * wg_delta_m; - size_t n_blocks = - ((n + wi_delta_n * wg_delta_n - 1) / (wi_delta_n * wg_delta_n)); - size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); - size_t m_blocks = - ((m + wi_delta_m * wg_delta_m - 1) / (wi_delta_m * wg_delta_m)); + size_t n_blocks = ((n + wi_delta_n * wg_delta_n - 1) / + (wi_delta_n * wg_delta_n)); + size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); + size_t m_blocks = ((m + wi_delta_m * wg_delta_m - 1) / + (wi_delta_m * wg_delta_m)); - auto gwsRange = - sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); - auto lwsRange = sycl::range<1>(lws); + auto gwsRange = + sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); + auto lwsRange = sycl::range<1>(lws); - auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); + auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); - if constexpr (wi_delta_m == 1) { - using LocAccT1 = sycl::local_accessor; - LocAccT1 local_A_block( - sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), - cgh); - using LocAccT2 = sycl::local_accessor; - LocAccT2 local_B_block(wi_delta_k * wg_delta_m, cgh); + if constexpr (wi_delta_m == 1) { + using LocAccT1 = sycl::local_accessor; + LocAccT1 local_A_block( + sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), + cgh); + using LocAccT2 = sycl::local_accessor; + LocAccT2 local_B_block(wi_delta_k * wg_delta_m, cgh); - using KernelName = - class gemm_tree_nm_krn; - cgh.parallel_for( - ndRange, GemmNoAtomicFunctorThreadNM< - lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, - OuterInnerDimsIndexerT, ResIndexerT, - wi_delta_n, wi_delta_m>( - lhs_tp, rhs_tp, partially_reduced_tmp, - local_A_block, local_B_block, n, wg_delta_n, k, - k_blocks, wi_delta_k, m, m_blocks, wg_delta_m, - lhs_indexer, rhs_indexer, res_indexer)); - } - else { - using LocAccT1 = sycl::local_accessor; - LocAccT1 local_A_block( - sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), - cgh); - using LocAccT2 = - sycl::local_accessor, 1>; - LocAccT2 local_B_block(sycl::range<1>(wi_delta_k * wg_delta_m), - cgh); + using KernelName = + class gemm_tree_nm_krn; + cgh.parallel_for( + ndRange, + GemmNoAtomicFunctorThreadNM< + lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, + OuterInnerDimsIndexerT, ResIndexerT, wi_delta_n, + wi_delta_m>(lhs_tp, rhs_tp, partially_reduced_tmp, + local_A_block, local_B_block, n, + wg_delta_n, k, k_blocks, wi_delta_k, m, + m_blocks, wg_delta_m, lhs_indexer, + rhs_indexer, res_indexer)); + } + else { + using LocAccT1 = sycl::local_accessor; + LocAccT1 local_A_block( + sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), + cgh); + using LocAccT2 = + sycl::local_accessor, 1>; + LocAccT2 local_B_block( + sycl::range<1>(wi_delta_k * wg_delta_m), cgh); - using KernelName = - class gemm_tree_nm_krn; - cgh.parallel_for( - ndRange, GemmNoAtomicFunctorThreadNM< - lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, - OuterInnerDimsIndexerT, ResIndexerT, - wi_delta_n, wi_delta_m>( - lhs_tp, rhs_tp, partially_reduced_tmp, - local_A_block, local_B_block, n, wg_delta_n, k, - k_blocks, wi_delta_k, m, m_blocks, wg_delta_m, - lhs_indexer, rhs_indexer, res_indexer)); - } - }); + using KernelName = + class gemm_tree_nm_krn; + cgh.parallel_for( + ndRange, + GemmNoAtomicFunctorThreadNM< + lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, + 
OuterInnerDimsIndexerT, ResIndexerT, wi_delta_n, + wi_delta_m>(lhs_tp, rhs_tp, partially_reduced_tmp, + local_A_block, local_B_block, n, + wg_delta_n, k, k_blocks, wi_delta_k, m, + m_blocks, wg_delta_m, lhs_indexer, + rhs_indexer, res_indexer)); + } + }); - sycl::event red_ev = tree_reduction_for_gemm( - exec_q, partially_reduced_tmp, partially_reduced_tmp2, res_tp, - identity_val, iter_nelems, reduction_nelems, reduction_groups, wg, - max_wg, preferred_reductions_per_wi, reductions_per_wi, res_nd, 0, - res_shapes_strides, {gemm_ev}); + sycl::event red_ev = tree_reduction_for_gemm( + exec_q, partially_reduced_tmp, partially_reduced_tmp2, res_tp, + identity_val, iter_nelems, reduction_nelems, reduction_groups, + wg, max_wg, preferred_reductions_per_wi, reductions_per_wi, + res_nd, 0, res_shapes_strides, {gemm_ev}); - sycl::event cleanup_host_task_event = - exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(red_ev); - const sycl::context &ctx = exec_q.get_context(); + sycl::event cleanup_host_task_event = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(red_ev); + const sycl::context &ctx = exec_q.get_context(); - cgh.host_task([ctx, partially_reduced_tmp] { - sycl::free(partially_reduced_tmp, ctx); + cgh.host_task([ctx, partially_reduced_tmp] { + sycl::free(partially_reduced_tmp, ctx); + }); }); - }); - return cleanup_host_task_event; + return cleanup_host_task_event; + } } } @@ -2798,7 +2950,20 @@ sycl::event gemm_contig_tree_k_impl(sycl::queue &exec_q, dev.get_info(); size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes); - if (reduction_nelems < wg) { + constexpr size_t preferred_reductions_per_wi = 8; + size_t reductions_per_wi(preferred_reductions_per_wi); + + size_t reduction_groups = + (reduction_nelems + preferred_reductions_per_wi * wg - 1) / + (preferred_reductions_per_wi * wg); + + // max_max_wg prevents running out of resources on CPU + constexpr size_t max_max_wg = 2048; + size_t max_wg = std::min( + max_max_wg, + dev.get_info() / 2); + + if (reduction_nelems <= preferred_reductions_per_wi * max_wg) { resTy *tmp = sycl::malloc_device( iter_nelems * reduction_nelems, exec_q); if (!tmp) { @@ -2862,34 +3027,13 @@ sycl::event gemm_contig_tree_k_impl(sycl::queue &exec_q, lhs_indexer, rhs_indexer, res_indexer)); } }); - sycl::event red_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(gemm_ev); - - using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; - using InputOutputIterIndexerT = - dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< - NoOpIndexerT, NoOpIndexerT>; - using ReductionIndexerT = - dpctl::tensor::offset_utils::Strided1DIndexer; - - InputOutputIterIndexerT in_out_iter_indexer{NoOpIndexerT{}, - NoOpIndexerT{}}; - ReductionIndexerT reduction_indexer{ - 0, static_cast(reduction_nelems), - static_cast(iter_nelems)}; - - sycl::range<1> iter_range{iter_nelems}; - - cgh.parallel_for>( - iter_range, SequentialReduction( - tmp, res_tp, ReductionOpT(), identity_val, - in_out_iter_indexer, reduction_indexer, - reduction_nelems)); - }); + + sycl::event red_ev = + single_reduction_for_gemm_contig( + exec_q, tmp, res_tp, identity_val, iter_nelems, + reduction_nelems, reduction_groups, wg, max_wg, + preferred_reductions_per_wi, reductions_per_wi, {gemm_ev}); + sycl::event cleanup_host_task_event = exec_q.submit([&](sycl::handler &cgh) { cgh.depends_on(red_ev); @@ -2899,145 +3043,138 @@ sycl::event gemm_contig_tree_k_impl(sycl::queue &exec_q, }); return cleanup_host_task_event; } - // max_max_wg prevents running out of resources on 
CPU - constexpr size_t max_max_wg = 2048; - size_t max_wg = std::min( - max_max_wg, - dev.get_info() / 2); + else { + assert(reduction_groups > 1); - constexpr size_t preferred_reductions_per_wi = 8; - size_t reductions_per_wi(preferred_reductions_per_wi); + resTy *partially_reduced_tmp = sycl::malloc_device( + iter_nelems * (/* temp */ reduction_nelems + + /* first reduction temp */ reduction_groups), + exec_q); + resTy *partially_reduced_tmp2 = nullptr; - size_t reduction_groups = - (reduction_nelems + preferred_reductions_per_wi * wg - 1) / - (preferred_reductions_per_wi * wg); - assert(reduction_groups > 1); + if (partially_reduced_tmp == nullptr) { + throw std::runtime_error("Unable to allocate device_memory"); + } + else { + partially_reduced_tmp2 = + partially_reduced_tmp + reduction_nelems * iter_nelems; + } + + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); - resTy *partially_reduced_tmp = sycl::malloc_device( - iter_nelems * (/* temp */ reduction_nelems + - /* first reduction temp */ reduction_groups), - exec_q); - resTy *partially_reduced_tmp2 = nullptr; + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + OuterInnerDimsIndexerT lhs_indexer{}; + OuterInnerDimsIndexerT rhs_indexer{}; + OuterInnerDimsIndexerT res_indexer{}; - if (partially_reduced_tmp == nullptr) { - throw std::runtime_error("Unable to allocate device_memory"); - } - else { - partially_reduced_tmp2 = - partially_reduced_tmp + reduction_nelems * iter_nelems; + size_t n_blocks = (n + delta_n - 1) / delta_n; + size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); + size_t m_blocks = (m + m_groups - 1) / m_groups; + + size_t lws = delta_n * delta_k; + + auto gRange = + sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); + auto lRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gRange, lRange); + + if constexpr (m_groups == 1) { + using LocAccT = sycl::local_accessor; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + using KernelName = + class gemm_tree_k_krn; + cgh.parallel_for( + ndRange, + GemmNoAtomicFunctorThreadK( + lhs_tp, rhs_tp, partially_reduced_tmp, workspace, + local_B_block, n, n_blocks, delta_n, k, k_blocks, + delta_k, n_wi, m, lhs_indexer, rhs_indexer, + res_indexer)); + } + else { + using LocAccT = + sycl::local_accessor, 1>; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + using KernelName = + class gemm_tree_k_krn; + cgh.parallel_for( + ndRange, + GemmNoAtomicFunctorThreadK( + lhs_tp, rhs_tp, partially_reduced_tmp, workspace, + local_B_block, n, n_blocks, delta_n, k, k_blocks, + delta_k, n_wi, m, lhs_indexer, rhs_indexer, + res_indexer)); + } + }); + // tree_reduction_for_gemm_contig returns sycl::event + // for reduction + sycl::event red_ev = + tree_reduction_for_gemm_contig( + exec_q, partially_reduced_tmp, partially_reduced_tmp2, + res_tp, identity_val, iter_nelems, reduction_nelems, + reduction_groups, wg, max_wg, preferred_reductions_per_wi, + reductions_per_wi, {gemm_ev}); + + sycl::event cleanup_host_task_event = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(red_ev); + const sycl::context &ctx = exec_q.get_context(); + + cgh.host_task([ctx, partially_reduced_tmp] { + sycl::free(partially_reduced_tmp, ctx); + }); + }); + + return cleanup_host_task_event; } + } +} - sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(depends); - - using OuterInnerDimsIndexerT = - 
dpctl::tensor::offset_utils::NoOpIndexer; - OuterInnerDimsIndexerT lhs_indexer{}; - OuterInnerDimsIndexerT rhs_indexer{}; - OuterInnerDimsIndexerT res_indexer{}; - - size_t n_blocks = (n + delta_n - 1) / delta_n; - size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); - size_t m_blocks = (m + m_groups - 1) / m_groups; - - size_t lws = delta_n * delta_k; - - auto gRange = sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); - auto lRange = sycl::range<1>(lws); - - auto ndRange = sycl::nd_range<1>(gRange, lRange); - - if constexpr (m_groups == 1) { - using LocAccT = sycl::local_accessor; - LocAccT local_B_block(n_wi * delta_k, cgh); - LocAccT workspace(delta_n * delta_k, cgh); - using KernelName = - class gemm_tree_k_krn; - cgh.parallel_for( - ndRange, - GemmNoAtomicFunctorThreadK< - lhsTy, rhsTy, resTy, LocAccT, OuterInnerDimsIndexerT, - OuterInnerDimsIndexerT, m_groups>( - lhs_tp, rhs_tp, partially_reduced_tmp, workspace, - local_B_block, n, n_blocks, delta_n, k, k_blocks, - delta_k, n_wi, m, lhs_indexer, rhs_indexer, - res_indexer)); - } - else { - using LocAccT = - sycl::local_accessor, 1>; - LocAccT local_B_block(n_wi * delta_k, cgh); - LocAccT workspace(delta_n * delta_k, cgh); - using KernelName = - class gemm_tree_k_krn; - cgh.parallel_for( - ndRange, - GemmNoAtomicFunctorThreadK< - lhsTy, rhsTy, resTy, LocAccT, OuterInnerDimsIndexerT, - OuterInnerDimsIndexerT, m_groups>( - lhs_tp, rhs_tp, partially_reduced_tmp, workspace, - local_B_block, n, n_blocks, delta_n, k, k_blocks, - delta_k, n_wi, m, lhs_indexer, rhs_indexer, - res_indexer)); - } - }); - // tree_reduction_for_gemm_contig returns sycl::event - // for reduction - sycl::event red_ev = - tree_reduction_for_gemm_contig( - exec_q, partially_reduced_tmp, partially_reduced_tmp2, res_tp, - identity_val, iter_nelems, reduction_nelems, reduction_groups, - wg, max_wg, preferred_reductions_per_wi, reductions_per_wi, - {gemm_ev}); - - sycl::event cleanup_host_task_event = - exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(red_ev); - const sycl::context &ctx = exec_q.get_context(); - - cgh.host_task([ctx, partially_reduced_tmp] { - sycl::free(partially_reduced_tmp, ctx); - }); - }); - - return cleanup_host_task_event; - } -} - -template -sycl::event gemm_contig_tree_nm_impl(sycl::queue &exec_q, - const lhsTy *lhs_tp, - const rhsTy *rhs_tp, - resTy *res_tp, - size_t n, - size_t k, - size_t m, - std::vector const &depends) -{ - constexpr int wi_delta_n = 2; - size_t wg_delta_n(16); // rows of A processed in WG - size_t wg_delta_m(16); // rows of B processed in WG - size_t wi_delta_k(64); // Elements in K dimension processed by WI - - const sycl::device &dev = exec_q.get_device(); - const size_t local_mem_size = - dev.get_info(); - const size_t reserved_slm_size = 512; - - gemm_detail::scale_gemm_nm_parameters( - local_mem_size, reserved_slm_size, wi_delta_n, - wi_delta_k, // modified by reference - wg_delta_n, // modified by reference - wg_delta_m // modified by reference - ); - - // each group processes delta_k items in a column, - // so no need to allocate temp memory if one group needed - if (k <= wi_delta_k) { +template +sycl::event gemm_contig_tree_nm_impl(sycl::queue &exec_q, + const lhsTy *lhs_tp, + const rhsTy *rhs_tp, + resTy *res_tp, + size_t n, + size_t k, + size_t m, + std::vector const &depends) +{ + constexpr int wi_delta_n = 2; + size_t wg_delta_n(16); // rows of A processed in WG + size_t wg_delta_m(16); // rows of B processed in WG + size_t wi_delta_k(64); // Elements in K dimension processed by WI + + 
const sycl::device &dev = exec_q.get_device(); + const size_t local_mem_size = + dev.get_info(); + const size_t reserved_slm_size = 512; + + gemm_detail::scale_gemm_nm_parameters( + local_mem_size, reserved_slm_size, wi_delta_n, + wi_delta_k, // modified by reference + wg_delta_n, // modified by reference + wg_delta_m // modified by reference + ); + + // each group processes delta_k items in a column, + // so no need to allocate temp memory if one group needed + if (k <= wi_delta_k) { sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { cgh.depends_on(depends); @@ -3125,7 +3262,20 @@ sycl::event gemm_contig_tree_nm_impl(sycl::queue &exec_q, dev.get_info(); size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes); - if (reduction_nelems < wg) { + constexpr size_t preferred_reductions_per_wi = 8; + size_t reductions_per_wi(preferred_reductions_per_wi); + + size_t reduction_groups = + (reduction_nelems + preferred_reductions_per_wi * wg - 1) / + (preferred_reductions_per_wi * wg); + + // max_max_wg prevents running out of resources on CPU + constexpr size_t max_max_wg = 2048; + size_t max_wg = std::min( + max_max_wg, + dev.get_info() / 2); + + if (reduction_nelems <= preferred_reductions_per_wi * max_wg) { resTy *tmp = sycl::malloc_device( iter_nelems * reduction_nelems, exec_q); if (!tmp) { @@ -3199,34 +3349,13 @@ sycl::event gemm_contig_tree_nm_impl(sycl::queue &exec_q, wg_delta_m, lhs_indexer, rhs_indexer, res_indexer)); } }); - sycl::event red_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(gemm_ev); - - using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; - using InputOutputIterIndexerT = - dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< - NoOpIndexerT, NoOpIndexerT>; - using ReductionIndexerT = - dpctl::tensor::offset_utils::Strided1DIndexer; - - InputOutputIterIndexerT in_out_iter_indexer{NoOpIndexerT{}, - NoOpIndexerT{}}; - ReductionIndexerT reduction_indexer{ - 0, static_cast(reduction_nelems), - static_cast(iter_nelems)}; - - sycl::range<1> iter_range{iter_nelems}; - - cgh.parallel_for>( - iter_range, SequentialReduction( - tmp, res_tp, ReductionOpT(), identity_val, - in_out_iter_indexer, reduction_indexer, - reduction_nelems)); - }); + + sycl::event red_ev = + single_reduction_for_gemm_contig( + exec_q, tmp, res_tp, identity_val, iter_nelems, + reduction_nelems, reduction_groups, wg, max_wg, + preferred_reductions_per_wi, reductions_per_wi, {gemm_ev}); + sycl::event cleanup_host_task_event = exec_q.submit([&](sycl::handler &cgh) { cgh.depends_on(red_ev); @@ -3236,124 +3365,113 @@ sycl::event gemm_contig_tree_nm_impl(sycl::queue &exec_q, }); return cleanup_host_task_event; } + else { + assert(reduction_groups > 1); - // max_max_wg prevents running out of resources on CPU - constexpr size_t max_max_wg = 2048; - size_t max_wg = std::min( - max_max_wg, - dev.get_info() / 2); - - constexpr size_t preferred_reductions_per_wi = 8; - size_t reductions_per_wi(preferred_reductions_per_wi); - - size_t reduction_groups = - (reduction_nelems + preferred_reductions_per_wi * wg - 1) / - (preferred_reductions_per_wi * wg); - assert(reduction_groups > 1); - - resTy *partially_reduced_tmp = sycl::malloc_device( - iter_nelems * (/* temp */ reduction_nelems + - /* first reduction temp */ reduction_groups), - exec_q); - resTy *partially_reduced_tmp2 = nullptr; + resTy *partially_reduced_tmp = sycl::malloc_device( + iter_nelems * (/* temp */ reduction_nelems + + /* first reduction temp */ reduction_groups), + exec_q); + resTy *partially_reduced_tmp2 = 
nullptr; - if (partially_reduced_tmp == nullptr) { - throw std::runtime_error("Unable to allocate device_memory"); - } - else { - partially_reduced_tmp2 = - partially_reduced_tmp + reduction_nelems * iter_nelems; - } + if (partially_reduced_tmp == nullptr) { + throw std::runtime_error("Unable to allocate device_memory"); + } + else { + partially_reduced_tmp2 = + partially_reduced_tmp + reduction_nelems * iter_nelems; + } - sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(depends); + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); - using OuterInnerDimsIndexerT = - dpctl::tensor::offset_utils::NoOpIndexer; - OuterInnerDimsIndexerT lhs_indexer{}; - OuterInnerDimsIndexerT rhs_indexer{}; - OuterInnerDimsIndexerT res_indexer{}; + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + OuterInnerDimsIndexerT lhs_indexer{}; + OuterInnerDimsIndexerT rhs_indexer{}; + OuterInnerDimsIndexerT res_indexer{}; - size_t lws = wg_delta_n * wg_delta_m; + size_t lws = wg_delta_n * wg_delta_m; - size_t n_blocks = - ((n + wi_delta_n * wg_delta_n - 1) / (wi_delta_n * wg_delta_n)); - size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); - size_t m_blocks = - ((m + wi_delta_m * wg_delta_m - 1) / (wi_delta_m * wg_delta_m)); + size_t n_blocks = ((n + wi_delta_n * wg_delta_n - 1) / + (wi_delta_n * wg_delta_n)); + size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); + size_t m_blocks = ((m + wi_delta_m * wg_delta_m - 1) / + (wi_delta_m * wg_delta_m)); - auto gwsRange = - sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); - auto lwsRange = sycl::range<1>(lws); + auto gwsRange = + sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); + auto lwsRange = sycl::range<1>(lws); - auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); + auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); - if constexpr (wi_delta_m == 1) { - using LocAccT1 = sycl::local_accessor; - LocAccT1 local_A_block( - sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), - cgh); - using LocAccT2 = sycl::local_accessor; - LocAccT2 local_B_block(wi_delta_k * wg_delta_m, cgh); + if constexpr (wi_delta_m == 1) { + using LocAccT1 = sycl::local_accessor; + LocAccT1 local_A_block( + sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), + cgh); + using LocAccT2 = sycl::local_accessor; + LocAccT2 local_B_block(wi_delta_k * wg_delta_m, cgh); - using KernelName = - class gemm_tree_nm_krn; - cgh.parallel_for( - ndRange, GemmNoAtomicFunctorThreadNM< - lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, - OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, - wi_delta_n, wi_delta_m>( - lhs_tp, rhs_tp, partially_reduced_tmp, - local_A_block, local_B_block, n, wg_delta_n, k, - k_blocks, wi_delta_k, m, m_blocks, wg_delta_m, - lhs_indexer, rhs_indexer, res_indexer)); - } - else { - using LocAccT1 = sycl::local_accessor; - LocAccT1 local_A_block( - sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), - cgh); - using LocAccT2 = - sycl::local_accessor, 1>; - LocAccT2 local_B_block(sycl::range<1>(wi_delta_k * wg_delta_m), - cgh); + using KernelName = class gemm_tree_nm_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, wi_delta_m>; + cgh.parallel_for( + ndRange, + GemmNoAtomicFunctorThreadNM< + lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, + wi_delta_n, wi_delta_m>( + lhs_tp, rhs_tp, partially_reduced_tmp, + local_A_block, local_B_block, n, wg_delta_n, k, + k_blocks, wi_delta_k, m, m_blocks, wg_delta_m, + 
lhs_indexer, rhs_indexer, res_indexer)); + } + else { + using LocAccT1 = sycl::local_accessor; + LocAccT1 local_A_block( + sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), + cgh); + using LocAccT2 = + sycl::local_accessor, 1>; + LocAccT2 local_B_block( + sycl::range<1>(wi_delta_k * wg_delta_m), cgh); - using KernelName = - class gemm_tree_nm_krn; - cgh.parallel_for( - ndRange, GemmNoAtomicFunctorThreadNM< - lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, - OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, - wi_delta_n, wi_delta_m>( - lhs_tp, rhs_tp, partially_reduced_tmp, - local_A_block, local_B_block, n, wg_delta_n, k, - k_blocks, wi_delta_k, m, m_blocks, wg_delta_m, - lhs_indexer, rhs_indexer, res_indexer)); - } - }); + using KernelName = class gemm_tree_nm_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, wi_delta_m>; + cgh.parallel_for( + ndRange, + GemmNoAtomicFunctorThreadNM< + lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, + wi_delta_n, wi_delta_m>( + lhs_tp, rhs_tp, partially_reduced_tmp, + local_A_block, local_B_block, n, wg_delta_n, k, + k_blocks, wi_delta_k, m, m_blocks, wg_delta_m, + lhs_indexer, rhs_indexer, res_indexer)); + } + }); - sycl::event red_ev = - tree_reduction_for_gemm_contig( - exec_q, partially_reduced_tmp, partially_reduced_tmp2, res_tp, - identity_val, iter_nelems, reduction_nelems, reduction_groups, - wg, max_wg, preferred_reductions_per_wi, reductions_per_wi, - {gemm_ev}); + sycl::event red_ev = + tree_reduction_for_gemm_contig( + exec_q, partially_reduced_tmp, partially_reduced_tmp2, + res_tp, identity_val, iter_nelems, reduction_nelems, + reduction_groups, wg, max_wg, preferred_reductions_per_wi, + reductions_per_wi, {gemm_ev}); - sycl::event cleanup_host_task_event = - exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(red_ev); - const sycl::context &ctx = exec_q.get_context(); + sycl::event cleanup_host_task_event = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(red_ev); + const sycl::context &ctx = exec_q.get_context(); - cgh.host_task([ctx, partially_reduced_tmp] { - sycl::free(partially_reduced_tmp, ctx); + cgh.host_task([ctx, partially_reduced_tmp] { + sycl::free(partially_reduced_tmp, ctx); + }); }); - }); - return cleanup_host_task_event; + return cleanup_host_task_event; + } } } @@ -5313,23 +5431,37 @@ gemm_batch_tree_k_impl(sycl::queue &exec_q, dev.get_info(); size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes); - if (reduction_nelems < wg) { - resTy *tmp = sycl::malloc_device( - iter_nelems * reduction_nelems, exec_q); - if (!tmp) { - throw std::runtime_error("Unable to allocate device memory"); - } - sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(depends); + constexpr size_t preferred_reductions_per_wi = 4; + size_t reductions_per_wi(preferred_reductions_per_wi); - using OuterInnerDimsIndexerT = - dpctl::tensor::offset_utils::StridedIndexer; - using TmpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + size_t reduction_groups = + (reduction_nelems + preferred_reductions_per_wi * wg - 1) / + (preferred_reductions_per_wi * wg); + + // max_max_wg prevents running out of resources on CPU + constexpr size_t max_max_wg = 2048; + size_t max_wg = std::min( + max_max_wg, + dev.get_info() / 2); + + if (reduction_nelems <= preferred_reductions_per_wi * max_wg) { + resTy *tmp = sycl::malloc_device( + iter_nelems * reduction_nelems, exec_q); + if (!tmp) { + throw std::runtime_error("Unable to allocate device memory"); + } 
+ sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::StridedIndexer; + using TmpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; OuterInnerDimsIndexerT lhs_indexer( inner_nd + lhs_outer_nd, 0, lhs_outer_inner_shapes_strides); OuterInnerDimsIndexerT rhs_indexer( inner_nd + rhs_outer_nd, 0, rhs_outer_inner_shapes_strides); TmpIndexerT res_indexer{}; + using dpctl::tensor::offset_utils::StridedIndexer; using dpctl::tensor::offset_utils::UnpackedStridedIndexer; using dpctl::tensor::offset_utils::Strided1DIndexer; @@ -5398,39 +5530,14 @@ gemm_batch_tree_k_impl(sycl::queue &exec_q, rhs_indexer, res_indexer)); } }); - sycl::event red_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(gemm_ev); - - using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; - using ResIndexerT = dpctl::tensor::offset_utils::StridedIndexer; - using InputOutputIterIndexerT = - dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< - NoOpIndexerT, ResIndexerT>; - using ReductionIndexerT = - dpctl::tensor::offset_utils::Strided1DIndexer; - - ResIndexerT res_iter_indexer{ - batch_nd + res_outer_nd, - static_cast(res_batch_offset), - res_shape_strides}; - InputOutputIterIndexerT in_out_iter_indexer{NoOpIndexerT{}, - res_iter_indexer}; - ReductionIndexerT reduction_indexer{ - 0, static_cast(reduction_nelems), - static_cast(iter_nelems)}; - - sycl::range<1> iter_range{iter_nelems}; - - cgh.parallel_for>( - iter_range, SequentialReduction( - tmp, res_tp, ReductionOpT(), identity_val, - in_out_iter_indexer, reduction_indexer, - reduction_nelems)); - }); + + sycl::event red_ev = single_reduction_for_gemm( + exec_q, tmp, res_tp, identity_val, iter_nelems, + reduction_nelems, reduction_groups, wg, max_wg, + preferred_reductions_per_wi, reductions_per_wi, + batch_nd + res_outer_nd, res_batch_offset, res_shape_strides, + {gemm_ev}); + sycl::event cleanup_host_task_event = exec_q.submit([&](sycl::handler &cgh) { cgh.depends_on(red_ev); @@ -5440,131 +5547,123 @@ gemm_batch_tree_k_impl(sycl::queue &exec_q, }); return cleanup_host_task_event; } + else { + assert(reduction_groups > 1); - // max_max_wg prevents running out of resources on CPU - constexpr size_t max_max_wg = 2048; - size_t max_wg = std::min( - max_max_wg, - dev.get_info() / 2); - - constexpr size_t preferred_reductions_per_wi = 4; - size_t reductions_per_wi(preferred_reductions_per_wi); - - size_t reduction_groups = - (reduction_nelems + preferred_reductions_per_wi * wg - 1) / - (preferred_reductions_per_wi * wg); - assert(reduction_groups > 1); - - resTy *partially_reduced_tmp = sycl::malloc_device( - iter_nelems * (/* temp */ reduction_nelems + - /* first reduction temp */ reduction_groups), - exec_q); - resTy *partially_reduced_tmp2 = nullptr; + resTy *partially_reduced_tmp = sycl::malloc_device( + iter_nelems * (/* temp */ reduction_nelems + + /* first reduction temp */ reduction_groups), + exec_q); + resTy *partially_reduced_tmp2 = nullptr; - if (partially_reduced_tmp == nullptr) { - throw std::runtime_error("Unable to allocate device_memory"); - } - else { - partially_reduced_tmp2 = - partially_reduced_tmp + reduction_nelems * iter_nelems; - } + if (partially_reduced_tmp == nullptr) { + throw std::runtime_error("Unable to allocate device_memory"); + } + else { + partially_reduced_tmp2 = + partially_reduced_tmp + reduction_nelems * iter_nelems; + } - sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(depends); + 
sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); - using OuterInnerDimsIndexerT = - dpctl::tensor::offset_utils::StridedIndexer; - using TmpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; - OuterInnerDimsIndexerT lhs_indexer(inner_nd + lhs_outer_nd, 0, - lhs_outer_inner_shapes_strides); - OuterInnerDimsIndexerT rhs_indexer(inner_nd + rhs_outer_nd, 0, - rhs_outer_inner_shapes_strides); - TmpIndexerT res_indexer{}; - using dpctl::tensor::offset_utils::StridedIndexer; - using dpctl::tensor::offset_utils::Strided1DIndexer; - using dpctl::tensor::offset_utils::ThreeOffsets_CombinedIndexer; - using BatchDimsIndexerT = - ThreeOffsets_CombinedIndexer; - StridedIndexer lhs_batch_indexer(batch_nd, lhs_batch_offset, - batch_shape_strides); - StridedIndexer rhs_batch_indexer( - batch_nd, rhs_batch_offset, batch_shape_strides + 2 * batch_nd); - Strided1DIndexer tmp_batch_indexer( - 0, static_cast(batch_nelems), n * m); - BatchDimsIndexerT batch_indexer( - lhs_batch_indexer, rhs_batch_indexer, tmp_batch_indexer); + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::StridedIndexer; + using TmpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + OuterInnerDimsIndexerT lhs_indexer( + inner_nd + lhs_outer_nd, 0, lhs_outer_inner_shapes_strides); + OuterInnerDimsIndexerT rhs_indexer( + inner_nd + rhs_outer_nd, 0, rhs_outer_inner_shapes_strides); + TmpIndexerT res_indexer{}; + using dpctl::tensor::offset_utils::StridedIndexer; + using dpctl::tensor::offset_utils::Strided1DIndexer; + using dpctl::tensor::offset_utils::ThreeOffsets_CombinedIndexer; + using BatchDimsIndexerT = + ThreeOffsets_CombinedIndexer; + StridedIndexer lhs_batch_indexer(batch_nd, lhs_batch_offset, + batch_shape_strides); + StridedIndexer rhs_batch_indexer(batch_nd, rhs_batch_offset, + batch_shape_strides + + 2 * batch_nd); + Strided1DIndexer tmp_batch_indexer( + 0, static_cast(batch_nelems), n * m); + BatchDimsIndexerT batch_indexer( + lhs_batch_indexer, rhs_batch_indexer, tmp_batch_indexer); - size_t n_blocks = (n + delta_n - 1) / delta_n; - size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); - size_t m_blocks = (m + m_groups - 1) / m_groups; + size_t n_blocks = (n + delta_n - 1) / delta_n; + size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); + size_t m_blocks = (m + m_groups - 1) / m_groups; - size_t lws = delta_n * delta_k; + size_t lws = delta_n * delta_k; - auto gRange = sycl::range<1>(batch_nelems * n_blocks * m_blocks * - k_blocks * lws); - auto lRange = sycl::range<1>(lws); + auto gRange = sycl::range<1>(batch_nelems * n_blocks * + m_blocks * k_blocks * lws); + auto lRange = sycl::range<1>(lws); - auto ndRange = sycl::nd_range<1>(gRange, lRange); + auto ndRange = sycl::nd_range<1>(gRange, lRange); - if constexpr (m_groups == 1) { - using LocAccT = sycl::local_accessor; - LocAccT local_B_block(n_wi * delta_k, cgh); - LocAccT workspace(delta_n * delta_k, cgh); + if constexpr (m_groups == 1) { + using LocAccT = sycl::local_accessor; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); - using KernelName = class gemm_batch_tree_k_krn< - lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, TmpIndexerT, - BatchDimsIndexerT, m_groups>; + using KernelName = class gemm_batch_tree_k_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + TmpIndexerT, BatchDimsIndexerT, m_groups>; - cgh.parallel_for( - ndRange, - GemmBatchNoAtomicFunctorThreadK< - lhsTy, rhsTy, resTy, LocAccT, OuterInnerDimsIndexerT, - TmpIndexerT, BatchDimsIndexerT, 
m_groups>( - lhs_tp, rhs_tp, partially_reduced_tmp, workspace, - local_B_block, n, n_blocks, delta_n, k, k_blocks, - delta_k, n_wi, m, batch_nelems, batch_indexer, - lhs_indexer, rhs_indexer, res_indexer)); - } - else { - using LocAccT = - sycl::local_accessor, 1>; - LocAccT local_B_block(n_wi * delta_k, cgh); - LocAccT workspace(delta_n * delta_k, cgh); + cgh.parallel_for( + ndRange, + GemmBatchNoAtomicFunctorThreadK< + lhsTy, rhsTy, resTy, LocAccT, + OuterInnerDimsIndexerT, TmpIndexerT, + BatchDimsIndexerT, m_groups>( + lhs_tp, rhs_tp, partially_reduced_tmp, workspace, + local_B_block, n, n_blocks, delta_n, k, k_blocks, + delta_k, n_wi, m, batch_nelems, batch_indexer, + lhs_indexer, rhs_indexer, res_indexer)); + } + else { + using LocAccT = + sycl::local_accessor, 1>; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); - using KernelName = class gemm_batch_tree_k_krn< - lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, TmpIndexerT, - BatchDimsIndexerT, m_groups>; - cgh.parallel_for( - ndRange, - GemmBatchNoAtomicFunctorThreadK< - lhsTy, rhsTy, resTy, LocAccT, OuterInnerDimsIndexerT, - TmpIndexerT, BatchDimsIndexerT, m_groups>( - lhs_tp, rhs_tp, partially_reduced_tmp, workspace, - local_B_block, n, n_blocks, delta_n, k, k_blocks, - delta_k, n_wi, m, batch_nelems, batch_indexer, - lhs_indexer, rhs_indexer, res_indexer)); - } - }); + using KernelName = class gemm_batch_tree_k_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + TmpIndexerT, BatchDimsIndexerT, m_groups>; + cgh.parallel_for( + ndRange, + GemmBatchNoAtomicFunctorThreadK< + lhsTy, rhsTy, resTy, LocAccT, + OuterInnerDimsIndexerT, TmpIndexerT, + BatchDimsIndexerT, m_groups>( + lhs_tp, rhs_tp, partially_reduced_tmp, workspace, + local_B_block, n, n_blocks, delta_n, k, k_blocks, + delta_k, n_wi, m, batch_nelems, batch_indexer, + lhs_indexer, rhs_indexer, res_indexer)); + } + }); - sycl::event red_ev = tree_reduction_for_gemm( - exec_q, partially_reduced_tmp, partially_reduced_tmp2, res_tp, - identity_val, iter_nelems, reduction_nelems, reduction_groups, wg, - max_wg, preferred_reductions_per_wi, reductions_per_wi, - batch_nd + res_outer_nd, res_batch_offset, res_shape_strides, - {gemm_ev}); + sycl::event red_ev = tree_reduction_for_gemm( + exec_q, partially_reduced_tmp, partially_reduced_tmp2, res_tp, + identity_val, iter_nelems, reduction_nelems, reduction_groups, + wg, max_wg, preferred_reductions_per_wi, reductions_per_wi, + batch_nd + res_outer_nd, res_batch_offset, res_shape_strides, + {gemm_ev}); - sycl::event cleanup_host_task_event = - exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(red_ev); - const sycl::context &ctx = exec_q.get_context(); + sycl::event cleanup_host_task_event = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(red_ev); + const sycl::context &ctx = exec_q.get_context(); - cgh.host_task([ctx, partially_reduced_tmp] { - sycl::free(partially_reduced_tmp, ctx); + cgh.host_task([ctx, partially_reduced_tmp] { + sycl::free(partially_reduced_tmp, ctx); + }); }); - }); - return cleanup_host_task_event; + return cleanup_host_task_event; + } } } @@ -5709,7 +5808,20 @@ gemm_batch_tree_nm_impl(sycl::queue &exec_q, dev.get_info(); size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes); - if (reduction_nelems < wg) { + constexpr size_t preferred_reductions_per_wi = 4; + size_t reductions_per_wi(preferred_reductions_per_wi); + + size_t reduction_groups = + (reduction_nelems + preferred_reductions_per_wi * wg - 1) / + (preferred_reductions_per_wi * wg); + + // 
max_max_wg prevents running out of resources on CPU + constexpr size_t max_max_wg = 2048; + size_t max_wg = std::min( + max_max_wg, + dev.get_info() / 2); + + if (reduction_nelems <= preferred_reductions_per_wi * max_wg) { resTy *tmp = sycl::malloc_device( iter_nelems * reduction_nelems, exec_q); if (!tmp) { @@ -5726,6 +5838,7 @@ gemm_batch_tree_nm_impl(sycl::queue &exec_q, OuterInnerDimsIndexerT rhs_indexer( inner_nd + rhs_outer_nd, 0, rhs_outer_inner_shapes_strides); TmpIndexerT res_indexer{}; + using dpctl::tensor::offset_utils::StridedIndexer; using dpctl::tensor::offset_utils::UnpackedStridedIndexer; using dpctl::tensor::offset_utils::Strided1DIndexer; @@ -5804,39 +5917,14 @@ gemm_batch_tree_nm_impl(sycl::queue &exec_q, lhs_indexer, rhs_indexer, res_indexer)); } }); - sycl::event red_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(gemm_ev); - - using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; - using ResIndexerT = dpctl::tensor::offset_utils::StridedIndexer; - using InputOutputIterIndexerT = - dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< - NoOpIndexerT, ResIndexerT>; - using ReductionIndexerT = - dpctl::tensor::offset_utils::Strided1DIndexer; - - ResIndexerT res_iter_indexer{ - batch_nd + res_outer_nd, - static_cast(res_batch_offset), - res_shape_strides}; - InputOutputIterIndexerT in_out_iter_indexer{NoOpIndexerT{}, - res_iter_indexer}; - ReductionIndexerT reduction_indexer{ - 0, static_cast(reduction_nelems), - static_cast(iter_nelems)}; - - sycl::range<1> iter_range{iter_nelems}; - - cgh.parallel_for>( - iter_range, SequentialReduction( - tmp, res_tp, ReductionOpT(), identity_val, - in_out_iter_indexer, reduction_indexer, - reduction_nelems)); - }); + + sycl::event red_ev = single_reduction_for_gemm( + exec_q, tmp, res_tp, identity_val, iter_nelems, + reduction_nelems, reduction_groups, wg, max_wg, + preferred_reductions_per_wi, reductions_per_wi, + batch_nd + res_outer_nd, res_batch_offset, res_shape_strides, + {gemm_ev}); + sycl::event cleanup_host_task_event = exec_q.submit([&](sycl::handler &cgh) { cgh.depends_on(red_ev); @@ -5846,142 +5934,133 @@ gemm_batch_tree_nm_impl(sycl::queue &exec_q, }); return cleanup_host_task_event; } + else { + assert(reduction_groups > 1); - // max_max_wg prevents running out of resources on CPU - constexpr size_t max_max_wg = 2048; - size_t max_wg = std::min( - max_max_wg, - dev.get_info() / 2); - - constexpr size_t preferred_reductions_per_wi = 4; - size_t reductions_per_wi(preferred_reductions_per_wi); - - size_t reduction_groups = - (reduction_nelems + preferred_reductions_per_wi * wg - 1) / - (preferred_reductions_per_wi * wg); - assert(reduction_groups > 1); + resTy *partially_reduced_tmp = sycl::malloc_device( + iter_nelems * (/* temp */ reduction_nelems + + /* first reduction temp */ reduction_groups), + exec_q); + resTy *partially_reduced_tmp2 = nullptr; - resTy *partially_reduced_tmp = sycl::malloc_device( - iter_nelems * (/* temp */ reduction_nelems + - /* first reduction temp */ reduction_groups), - exec_q); - resTy *partially_reduced_tmp2 = nullptr; + if (partially_reduced_tmp == nullptr) { + throw std::runtime_error("Unable to allocate device_memory"); + } + else { + partially_reduced_tmp2 = + partially_reduced_tmp + reduction_nelems * iter_nelems; + } - if (partially_reduced_tmp == nullptr) { - throw std::runtime_error("Unable to allocate device_memory"); - } - else { - partially_reduced_tmp2 = - partially_reduced_tmp + reduction_nelems * iter_nelems; - } + sycl::event gemm_ev = 
exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); - sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(depends); + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::StridedIndexer; + using TmpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + OuterInnerDimsIndexerT lhs_indexer( + inner_nd + lhs_outer_nd, 0, lhs_outer_inner_shapes_strides); + OuterInnerDimsIndexerT rhs_indexer( + inner_nd + rhs_outer_nd, 0, rhs_outer_inner_shapes_strides); + TmpIndexerT res_indexer{}; + using dpctl::tensor::offset_utils::StridedIndexer; + using dpctl::tensor::offset_utils::UnpackedStridedIndexer; + using dpctl::tensor::offset_utils::Strided1DIndexer; + using dpctl::tensor::offset_utils::ThreeOffsets_CombinedIndexer; + using BatchDimsIndexerT = ThreeOffsets_CombinedIndexer< + StridedIndexer, UnpackedStridedIndexer, Strided1DIndexer>; + StridedIndexer lhs_batch_indexer(batch_nd, lhs_batch_offset, + batch_shape_strides); + UnpackedStridedIndexer rhs_batch_indexer( + batch_nd, rhs_batch_offset, batch_shape_strides, + batch_shape_strides + 2 * batch_nd); + Strided1DIndexer tmp_batch_indexer( + 0, static_cast(batch_nelems), n * m); + BatchDimsIndexerT batch_indexer( + lhs_batch_indexer, rhs_batch_indexer, tmp_batch_indexer); - using OuterInnerDimsIndexerT = - dpctl::tensor::offset_utils::StridedIndexer; - using TmpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; - OuterInnerDimsIndexerT lhs_indexer(inner_nd + lhs_outer_nd, 0, - lhs_outer_inner_shapes_strides); - OuterInnerDimsIndexerT rhs_indexer(inner_nd + rhs_outer_nd, 0, - rhs_outer_inner_shapes_strides); - TmpIndexerT res_indexer{}; - using dpctl::tensor::offset_utils::StridedIndexer; - using dpctl::tensor::offset_utils::UnpackedStridedIndexer; - using dpctl::tensor::offset_utils::Strided1DIndexer; - using dpctl::tensor::offset_utils::ThreeOffsets_CombinedIndexer; - using BatchDimsIndexerT = ThreeOffsets_CombinedIndexer< - StridedIndexer, UnpackedStridedIndexer, Strided1DIndexer>; - StridedIndexer lhs_batch_indexer(batch_nd, lhs_batch_offset, - batch_shape_strides); - UnpackedStridedIndexer rhs_batch_indexer( - batch_nd, rhs_batch_offset, batch_shape_strides, - batch_shape_strides + 2 * batch_nd); - Strided1DIndexer tmp_batch_indexer( - 0, static_cast(batch_nelems), n * m); - BatchDimsIndexerT batch_indexer( - lhs_batch_indexer, rhs_batch_indexer, tmp_batch_indexer); + size_t lws = wg_delta_n * wg_delta_m; - size_t lws = wg_delta_n * wg_delta_m; + size_t n_blocks = ((n + wi_delta_n * wg_delta_n - 1) / + (wi_delta_n * wg_delta_n)); + size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); + size_t m_blocks = ((m + wi_delta_m * wg_delta_m - 1) / + (wi_delta_m * wg_delta_m)); - size_t n_blocks = - ((n + wi_delta_n * wg_delta_n - 1) / (wi_delta_n * wg_delta_n)); - size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); - size_t m_blocks = - ((m + wi_delta_m * wg_delta_m - 1) / (wi_delta_m * wg_delta_m)); + auto gwsRange = sycl::range<1>(batch_nelems * n_blocks * + m_blocks * k_blocks * lws); + auto lwsRange = sycl::range<1>(lws); - auto gwsRange = sycl::range<1>(batch_nelems * n_blocks * m_blocks * - k_blocks * lws); - auto lwsRange = sycl::range<1>(lws); + auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); - auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); - - if constexpr (wi_delta_m == 1) { - using LocAccT1 = sycl::local_accessor; - LocAccT1 local_A_block( - sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), - cgh); - using LocAccT2 = sycl::local_accessor; - LocAccT2 
local_B_block(sycl::range<1>(wi_delta_k * wg_delta_m), - cgh); + if constexpr (wi_delta_m == 1) { + using LocAccT1 = sycl::local_accessor; + LocAccT1 local_A_block( + sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), + cgh); + using LocAccT2 = sycl::local_accessor; + LocAccT2 local_B_block( + sycl::range<1>(wi_delta_k * wg_delta_m), cgh); - using KernelName = class gemm_batch_tree_nm_krn< - lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, TmpIndexerT, - BatchDimsIndexerT, wi_delta_m>; - cgh.parallel_for( - ndRange, - GemmBatchNoAtomicFunctorThreadNM< - lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, - OuterInnerDimsIndexerT, TmpIndexerT, BatchDimsIndexerT, - wi_delta_n, wi_delta_m>( - lhs_tp, rhs_tp, partially_reduced_tmp, local_A_block, - local_B_block, n, wg_delta_n, k, k_blocks, wi_delta_k, - m, m_blocks, wg_delta_m, batch_nelems, batch_indexer, - lhs_indexer, rhs_indexer, res_indexer)); - } - else { - using LocAccT1 = sycl::local_accessor; - LocAccT1 local_A_block( - sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), - cgh); - using LocAccT2 = - sycl::local_accessor, 1>; - LocAccT2 local_B_block(sycl::range<1>(wi_delta_k * wg_delta_m), - cgh); - using KernelName = class gemm_batch_tree_nm_krn< - lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, TmpIndexerT, - BatchDimsIndexerT, wi_delta_m>; - cgh.parallel_for( - ndRange, - GemmBatchNoAtomicFunctorThreadNM< - lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, - OuterInnerDimsIndexerT, TmpIndexerT, BatchDimsIndexerT, - wi_delta_n, wi_delta_m>( - lhs_tp, rhs_tp, partially_reduced_tmp, local_A_block, - local_B_block, n, wg_delta_n, k, k_blocks, wi_delta_k, - m, m_blocks, wg_delta_m, batch_nelems, batch_indexer, - lhs_indexer, rhs_indexer, res_indexer)); - } - }); + using KernelName = class gemm_batch_tree_nm_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + TmpIndexerT, BatchDimsIndexerT, wi_delta_m>; + cgh.parallel_for( + ndRange, + GemmBatchNoAtomicFunctorThreadNM< + lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, + OuterInnerDimsIndexerT, TmpIndexerT, + BatchDimsIndexerT, wi_delta_n, wi_delta_m>( + lhs_tp, rhs_tp, partially_reduced_tmp, + local_A_block, local_B_block, n, wg_delta_n, k, + k_blocks, wi_delta_k, m, m_blocks, wg_delta_m, + batch_nelems, batch_indexer, lhs_indexer, + rhs_indexer, res_indexer)); + } + else { + using LocAccT1 = sycl::local_accessor; + LocAccT1 local_A_block( + sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), + cgh); + using LocAccT2 = + sycl::local_accessor, 1>; + LocAccT2 local_B_block( + sycl::range<1>(wi_delta_k * wg_delta_m), cgh); + using KernelName = class gemm_batch_tree_nm_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + TmpIndexerT, BatchDimsIndexerT, wi_delta_m>; + cgh.parallel_for( + ndRange, + GemmBatchNoAtomicFunctorThreadNM< + lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, + OuterInnerDimsIndexerT, TmpIndexerT, + BatchDimsIndexerT, wi_delta_n, wi_delta_m>( + lhs_tp, rhs_tp, partially_reduced_tmp, + local_A_block, local_B_block, n, wg_delta_n, k, + k_blocks, wi_delta_k, m, m_blocks, wg_delta_m, + batch_nelems, batch_indexer, lhs_indexer, + rhs_indexer, res_indexer)); + } + }); - sycl::event red_ev = tree_reduction_for_gemm( - exec_q, partially_reduced_tmp, partially_reduced_tmp2, res_tp, - identity_val, iter_nelems, reduction_nelems, reduction_groups, wg, - max_wg, preferred_reductions_per_wi, reductions_per_wi, - batch_nd + res_outer_nd, res_batch_offset, res_shape_strides, - {gemm_ev}); + sycl::event red_ev = tree_reduction_for_gemm( + exec_q, partially_reduced_tmp, partially_reduced_tmp2, res_tp, + 
identity_val, iter_nelems, reduction_nelems, reduction_groups, + wg, max_wg, preferred_reductions_per_wi, reductions_per_wi, + batch_nd + res_outer_nd, res_batch_offset, res_shape_strides, + {gemm_ev}); - sycl::event cleanup_host_task_event = - exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(red_ev); - const sycl::context &ctx = exec_q.get_context(); + sycl::event cleanup_host_task_event = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(red_ev); + const sycl::context &ctx = exec_q.get_context(); - cgh.host_task([ctx, partially_reduced_tmp] { - sycl::free(partially_reduced_tmp, ctx); + cgh.host_task([ctx, partially_reduced_tmp] { + sycl::free(partially_reduced_tmp, ctx); + }); }); - }); - return cleanup_host_task_event; + return cleanup_host_task_event; + } } } @@ -6211,7 +6290,20 @@ gemm_batch_contig_tree_k_impl(sycl::queue &exec_q, dev.get_info(); size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes); - if (reduction_nelems < wg) { + constexpr size_t preferred_reductions_per_wi = 4; + size_t reductions_per_wi(preferred_reductions_per_wi); + + size_t reduction_groups = + (reduction_nelems + preferred_reductions_per_wi * wg - 1) / + (preferred_reductions_per_wi * wg); + + // max_max_wg prevents running out of resources on CPU + constexpr size_t max_max_wg = 2048; + size_t max_wg = std::min( + max_max_wg, + dev.get_info() / 2); + + if (reduction_nelems <= preferred_reductions_per_wi * max_wg) { resTy *tmp = sycl::malloc_device( iter_nelems * reduction_nelems, exec_q); if (!tmp) { @@ -6290,34 +6382,13 @@ gemm_batch_contig_tree_k_impl(sycl::queue &exec_q, rhs_indexer, tmp_indexer)); } }); - sycl::event red_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(gemm_ev); - - using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; - using InputOutputIterIndexerT = - dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< - NoOpIndexerT, NoOpIndexerT>; - using ReductionIndexerT = - dpctl::tensor::offset_utils::Strided1DIndexer; - - InputOutputIterIndexerT in_out_iter_indexer{NoOpIndexerT{}, - NoOpIndexerT{}}; - ReductionIndexerT reduction_indexer{ - 0, static_cast(reduction_nelems), - static_cast(iter_nelems)}; - - sycl::range<1> iter_range{iter_nelems}; - - cgh.parallel_for>( - iter_range, SequentialReduction( - tmp, res_tp, ReductionOpT(), identity_val, - in_out_iter_indexer, reduction_indexer, - reduction_nelems)); - }); + + sycl::event red_ev = + single_reduction_for_gemm_contig( + exec_q, tmp, res_tp, identity_val, iter_nelems, + reduction_nelems, reduction_groups, wg, max_wg, + preferred_reductions_per_wi, reductions_per_wi, {gemm_ev}); + sycl::event cleanup_host_task_event = exec_q.submit([&](sycl::handler &cgh) { cgh.depends_on(red_ev); @@ -6327,126 +6398,116 @@ gemm_batch_contig_tree_k_impl(sycl::queue &exec_q, }); return cleanup_host_task_event; } + else { + assert(reduction_groups > 1); - // max_max_wg prevents running out of resources on CPU - constexpr size_t max_max_wg = 2048; - size_t max_wg = std::min( - max_max_wg, - dev.get_info() / 2); - - constexpr size_t preferred_reductions_per_wi = 4; - size_t reductions_per_wi(preferred_reductions_per_wi); - - size_t reduction_groups = - (reduction_nelems + preferred_reductions_per_wi * wg - 1) / - (preferred_reductions_per_wi * wg); - assert(reduction_groups > 1); - - resTy *partially_reduced_tmp = sycl::malloc_device( - iter_nelems * (/* temp */ reduction_nelems + - /* first reduction temp */ reduction_groups), - exec_q); - resTy *partially_reduced_tmp2 = nullptr; + resTy 
*partially_reduced_tmp = sycl::malloc_device( + iter_nelems * (/* temp */ reduction_nelems + + /* first reduction temp */ reduction_groups), + exec_q); + resTy *partially_reduced_tmp2 = nullptr; - if (partially_reduced_tmp == nullptr) { - throw std::runtime_error("Unable to allocate device_memory"); - } - else { - partially_reduced_tmp2 = - partially_reduced_tmp + reduction_nelems * iter_nelems; - } + if (partially_reduced_tmp == nullptr) { + throw std::runtime_error("Unable to allocate device_memory"); + } + else { + partially_reduced_tmp2 = + partially_reduced_tmp + reduction_nelems * iter_nelems; + } - sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(depends); + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); - using OuterInnerDimsIndexerT = - dpctl::tensor::offset_utils::NoOpIndexer; - OuterInnerDimsIndexerT lhs_indexer{}; - OuterInnerDimsIndexerT rhs_indexer{}; - OuterInnerDimsIndexerT tmp_indexer{}; - using dpctl::tensor::offset_utils::ThreeOffsets_CombinedIndexer; - using dpctl::tensor::offset_utils::Strided1DIndexer; - using BatchDimsIndexerT = - ThreeOffsets_CombinedIndexer; + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + OuterInnerDimsIndexerT lhs_indexer{}; + OuterInnerDimsIndexerT rhs_indexer{}; + OuterInnerDimsIndexerT tmp_indexer{}; + using dpctl::tensor::offset_utils::ThreeOffsets_CombinedIndexer; + using dpctl::tensor::offset_utils::Strided1DIndexer; + using BatchDimsIndexerT = ThreeOffsets_CombinedIndexer< + Strided1DIndexer, Strided1DIndexer, Strided1DIndexer>; - BatchDimsIndexerT batch_indexer( - Strided1DIndexer{0, static_cast(batch_nelems), - static_cast(n * k)}, - Strided1DIndexer{0, static_cast(batch_nelems), - static_cast(k * m)}, - Strided1DIndexer{0, static_cast(batch_nelems), - static_cast(n * m)}); + BatchDimsIndexerT batch_indexer( + Strided1DIndexer{0, static_cast(batch_nelems), + static_cast(n * k)}, + Strided1DIndexer{0, static_cast(batch_nelems), + static_cast(k * m)}, + Strided1DIndexer{0, static_cast(batch_nelems), + static_cast(n * m)}); - size_t n_blocks = (n + delta_n - 1) / delta_n; - size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); - size_t m_blocks = (m + m_groups - 1) / m_groups; + size_t n_blocks = (n + delta_n - 1) / delta_n; + size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); + size_t m_blocks = (m + m_groups - 1) / m_groups; - size_t lws = delta_n * delta_k; + size_t lws = delta_n * delta_k; - auto gRange = sycl::range<1>(batch_nelems * n_blocks * m_blocks * - k_blocks * lws); - auto lRange = sycl::range<1>(lws); + auto gRange = sycl::range<1>(batch_nelems * n_blocks * + m_blocks * k_blocks * lws); + auto lRange = sycl::range<1>(lws); - auto ndRange = sycl::nd_range<1>(gRange, lRange); + auto ndRange = sycl::nd_range<1>(gRange, lRange); - if constexpr (m_groups == 1) { - using LocAccT = sycl::local_accessor; - LocAccT local_B_block(n_wi * delta_k, cgh); - LocAccT workspace(delta_n * delta_k, cgh); + if constexpr (m_groups == 1) { + using LocAccT = sycl::local_accessor; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); - using KernelName = class gemm_batch_tree_k_krn< - lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, - OuterInnerDimsIndexerT, BatchDimsIndexerT, m_groups>; - cgh.parallel_for( - ndRange, - GemmBatchNoAtomicFunctorThreadK< - lhsTy, rhsTy, resTy, LocAccT, OuterInnerDimsIndexerT, - OuterInnerDimsIndexerT, BatchDimsIndexerT, m_groups>( - lhs_tp, rhs_tp, 
partially_reduced_tmp, workspace, - local_B_block, n, n_blocks, delta_n, k, k_blocks, - delta_k, n_wi, m, batch_nelems, batch_indexer, - lhs_indexer, rhs_indexer, tmp_indexer)); - } - else { - using LocAccT = - sycl::local_accessor, 1>; - LocAccT local_B_block(n_wi * delta_k, cgh); - LocAccT workspace(delta_n * delta_k, cgh); + using KernelName = class gemm_batch_tree_k_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, BatchDimsIndexerT, m_groups>; + cgh.parallel_for( + ndRange, + GemmBatchNoAtomicFunctorThreadK< + lhsTy, rhsTy, resTy, LocAccT, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, + BatchDimsIndexerT, m_groups>( + lhs_tp, rhs_tp, partially_reduced_tmp, workspace, + local_B_block, n, n_blocks, delta_n, k, k_blocks, + delta_k, n_wi, m, batch_nelems, batch_indexer, + lhs_indexer, rhs_indexer, tmp_indexer)); + } + else { + using LocAccT = + sycl::local_accessor, 1>; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); - using KernelName = class gemm_batch_tree_k_krn< - lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, - OuterInnerDimsIndexerT, BatchDimsIndexerT, m_groups>; - cgh.parallel_for( - ndRange, - GemmBatchNoAtomicFunctorThreadK< - lhsTy, rhsTy, resTy, LocAccT, OuterInnerDimsIndexerT, - OuterInnerDimsIndexerT, BatchDimsIndexerT, m_groups>( - lhs_tp, rhs_tp, partially_reduced_tmp, workspace, - local_B_block, n, n_blocks, delta_n, k, k_blocks, - delta_k, n_wi, m, batch_nelems, batch_indexer, - lhs_indexer, rhs_indexer, tmp_indexer)); - } - }); + using KernelName = class gemm_batch_tree_k_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, BatchDimsIndexerT, m_groups>; + cgh.parallel_for( + ndRange, + GemmBatchNoAtomicFunctorThreadK< + lhsTy, rhsTy, resTy, LocAccT, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, + BatchDimsIndexerT, m_groups>( + lhs_tp, rhs_tp, partially_reduced_tmp, workspace, + local_B_block, n, n_blocks, delta_n, k, k_blocks, + delta_k, n_wi, m, batch_nelems, batch_indexer, + lhs_indexer, rhs_indexer, tmp_indexer)); + } + }); - sycl::event red_ev = - tree_reduction_for_gemm_contig( - exec_q, partially_reduced_tmp, partially_reduced_tmp2, res_tp, - identity_val, iter_nelems, reduction_nelems, reduction_groups, - wg, max_wg, preferred_reductions_per_wi, reductions_per_wi, - {gemm_ev}); + sycl::event red_ev = + tree_reduction_for_gemm_contig( + exec_q, partially_reduced_tmp, partially_reduced_tmp2, + res_tp, identity_val, iter_nelems, reduction_nelems, + reduction_groups, wg, max_wg, preferred_reductions_per_wi, + reductions_per_wi, {gemm_ev}); - sycl::event cleanup_host_task_event = - exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(red_ev); - const sycl::context &ctx = exec_q.get_context(); + sycl::event cleanup_host_task_event = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(red_ev); + const sycl::context &ctx = exec_q.get_context(); - cgh.host_task([ctx, partially_reduced_tmp] { - sycl::free(partially_reduced_tmp, ctx); + cgh.host_task([ctx, partially_reduced_tmp] { + sycl::free(partially_reduced_tmp, ctx); + }); }); - }); - return cleanup_host_task_event; + return cleanup_host_task_event; + } } } @@ -6583,7 +6644,20 @@ gemm_batch_contig_tree_nm_impl(sycl::queue &exec_q, dev.get_info(); size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes); - if (reduction_nelems < wg) { + constexpr size_t preferred_reductions_per_wi = 4; + size_t reductions_per_wi(preferred_reductions_per_wi); + + size_t reduction_groups = + (reduction_nelems + 
preferred_reductions_per_wi * wg - 1) / + (preferred_reductions_per_wi * wg); + + // max_max_wg prevents running out of resources on CPU + constexpr size_t max_max_wg = 2048; + size_t max_wg = std::min( + max_max_wg, + dev.get_info() / 2); + + if (reduction_nelems <= preferred_reductions_per_wi * max_wg) { resTy *tmp = sycl::malloc_device( iter_nelems * reduction_nelems, exec_q); if (!tmp) { @@ -6672,34 +6746,13 @@ gemm_batch_contig_tree_nm_impl(sycl::queue &exec_q, lhs_indexer, rhs_indexer, tmp_indexer)); } }); - sycl::event red_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(gemm_ev); - - using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; - using InputOutputIterIndexerT = - dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< - NoOpIndexerT, NoOpIndexerT>; - using ReductionIndexerT = - dpctl::tensor::offset_utils::Strided1DIndexer; - - InputOutputIterIndexerT in_out_iter_indexer{NoOpIndexerT{}, - NoOpIndexerT{}}; - ReductionIndexerT reduction_indexer{ - 0, static_cast(reduction_nelems), - static_cast(iter_nelems)}; - - sycl::range<1> iter_range{iter_nelems}; - - cgh.parallel_for>( - iter_range, SequentialReduction( - tmp, res_tp, ReductionOpT(), identity_val, - in_out_iter_indexer, reduction_indexer, - reduction_nelems)); - }); + + sycl::event red_ev = + single_reduction_for_gemm_contig( + exec_q, tmp, res_tp, identity_val, iter_nelems, + reduction_nelems, reduction_groups, wg, max_wg, + preferred_reductions_per_wi, reductions_per_wi, {gemm_ev}); + sycl::event cleanup_host_task_event = exec_q.submit([&](sycl::handler &cgh) { cgh.depends_on(red_ev); @@ -6709,138 +6762,128 @@ gemm_batch_contig_tree_nm_impl(sycl::queue &exec_q, }); return cleanup_host_task_event; } + else { + assert(reduction_groups > 1); - // max_max_wg prevents running out of resources on CPU - constexpr size_t max_max_wg = 2048; - size_t max_wg = std::min( - max_max_wg, - dev.get_info() / 2); - - constexpr size_t preferred_reductions_per_wi = 4; - size_t reductions_per_wi(preferred_reductions_per_wi); - - size_t reduction_groups = - (reduction_nelems + preferred_reductions_per_wi * wg - 1) / - (preferred_reductions_per_wi * wg); - assert(reduction_groups > 1); - - resTy *partially_reduced_tmp = sycl::malloc_device( - iter_nelems * (/* temp */ reduction_nelems + - /* first reduction temp */ reduction_groups), - exec_q); - resTy *partially_reduced_tmp2 = nullptr; + resTy *partially_reduced_tmp = sycl::malloc_device( + iter_nelems * (/* temp */ reduction_nelems + + /* first reduction temp */ reduction_groups), + exec_q); + resTy *partially_reduced_tmp2 = nullptr; - if (partially_reduced_tmp == nullptr) { - throw std::runtime_error("Unable to allocate device_memory"); - } - else { - partially_reduced_tmp2 = - partially_reduced_tmp + reduction_nelems * iter_nelems; - } + if (partially_reduced_tmp == nullptr) { + throw std::runtime_error("Unable to allocate device_memory"); + } + else { + partially_reduced_tmp2 = + partially_reduced_tmp + reduction_nelems * iter_nelems; + } - sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(depends); + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); - using OuterInnerDimsIndexerT = - dpctl::tensor::offset_utils::NoOpIndexer; - OuterInnerDimsIndexerT lhs_indexer{}; - OuterInnerDimsIndexerT rhs_indexer{}; - OuterInnerDimsIndexerT tmp_indexer{}; - using dpctl::tensor::offset_utils::ThreeOffsets_CombinedIndexer; - using dpctl::tensor::offset_utils::Strided1DIndexer; - using BatchDimsIndexerT = - 
ThreeOffsets_CombinedIndexer; + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + OuterInnerDimsIndexerT lhs_indexer{}; + OuterInnerDimsIndexerT rhs_indexer{}; + OuterInnerDimsIndexerT tmp_indexer{}; + using dpctl::tensor::offset_utils::ThreeOffsets_CombinedIndexer; + using dpctl::tensor::offset_utils::Strided1DIndexer; + using BatchDimsIndexerT = ThreeOffsets_CombinedIndexer< + Strided1DIndexer, Strided1DIndexer, Strided1DIndexer>; - BatchDimsIndexerT batch_indexer( - Strided1DIndexer{0, static_cast(batch_nelems), - static_cast(n * k)}, - Strided1DIndexer{0, static_cast(batch_nelems), - static_cast(k * m)}, - Strided1DIndexer{0, static_cast(batch_nelems), - static_cast(n * m)}); + BatchDimsIndexerT batch_indexer( + Strided1DIndexer{0, static_cast(batch_nelems), + static_cast(n * k)}, + Strided1DIndexer{0, static_cast(batch_nelems), + static_cast(k * m)}, + Strided1DIndexer{0, static_cast(batch_nelems), + static_cast(n * m)}); - size_t lws = wg_delta_n * wg_delta_m; + size_t lws = wg_delta_n * wg_delta_m; - size_t n_blocks = - ((n + wi_delta_n * wg_delta_n - 1) / (wi_delta_n * wg_delta_n)); - size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); - size_t m_blocks = - ((m + wi_delta_m * wg_delta_m - 1) / (wi_delta_m * wg_delta_m)); + size_t n_blocks = ((n + wi_delta_n * wg_delta_n - 1) / + (wi_delta_n * wg_delta_n)); + size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); + size_t m_blocks = ((m + wi_delta_m * wg_delta_m - 1) / + (wi_delta_m * wg_delta_m)); - auto gwsRange = sycl::range<1>(batch_nelems * n_blocks * m_blocks * - k_blocks * lws); - auto lwsRange = sycl::range<1>(lws); + auto gwsRange = sycl::range<1>(batch_nelems * n_blocks * + m_blocks * k_blocks * lws); + auto lwsRange = sycl::range<1>(lws); - auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); + auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); - if constexpr (wi_delta_m == 1) { - using LocAccT1 = sycl::local_accessor; - LocAccT1 local_A_block( - sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), - cgh); - using LocAccT2 = sycl::local_accessor; - LocAccT2 local_B_block(sycl::range<1>(wi_delta_k * wg_delta_m), - cgh); + if constexpr (wi_delta_m == 1) { + using LocAccT1 = sycl::local_accessor; + LocAccT1 local_A_block( + sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), + cgh); + using LocAccT2 = sycl::local_accessor; + LocAccT2 local_B_block( + sycl::range<1>(wi_delta_k * wg_delta_m), cgh); - using KernelName = class gemm_batch_tree_nm_krn< - lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, - OuterInnerDimsIndexerT, BatchDimsIndexerT, wi_delta_m>; - cgh.parallel_for( - ndRange, - GemmBatchNoAtomicFunctorThreadNM< - lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, - OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, - BatchDimsIndexerT, wi_delta_n, wi_delta_m>( - lhs_tp, rhs_tp, partially_reduced_tmp, local_A_block, - local_B_block, n, wg_delta_n, k, k_blocks, wi_delta_k, - m, m_blocks, wg_delta_m, batch_nelems, batch_indexer, - lhs_indexer, rhs_indexer, tmp_indexer)); - } - else { - using LocAccT1 = sycl::local_accessor; - LocAccT1 local_A_block( - sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), - cgh); - using LocAccT2 = - sycl::local_accessor, 1>; - LocAccT2 local_B_block(sycl::range<1>(wi_delta_k * wg_delta_m), - cgh); + using KernelName = class gemm_batch_tree_nm_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, BatchDimsIndexerT, wi_delta_m>; + cgh.parallel_for( + ndRange, + GemmBatchNoAtomicFunctorThreadNM< + lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, + 
OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, + BatchDimsIndexerT, wi_delta_n, wi_delta_m>( + lhs_tp, rhs_tp, partially_reduced_tmp, + local_A_block, local_B_block, n, wg_delta_n, k, + k_blocks, wi_delta_k, m, m_blocks, wg_delta_m, + batch_nelems, batch_indexer, lhs_indexer, + rhs_indexer, tmp_indexer)); + } + else { + using LocAccT1 = sycl::local_accessor; + LocAccT1 local_A_block( + sycl::range<1>((wi_delta_n * wg_delta_n) * wi_delta_k), + cgh); + using LocAccT2 = + sycl::local_accessor, 1>; + LocAccT2 local_B_block( + sycl::range<1>(wi_delta_k * wg_delta_m), cgh); - using KernelName = class gemm_batch_tree_nm_krn< - lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, - OuterInnerDimsIndexerT, BatchDimsIndexerT, wi_delta_m>; - cgh.parallel_for( - ndRange, - GemmBatchNoAtomicFunctorThreadNM< - lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, - OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, - BatchDimsIndexerT, wi_delta_n, wi_delta_m>( - lhs_tp, rhs_tp, partially_reduced_tmp, local_A_block, - local_B_block, n, wg_delta_n, k, k_blocks, wi_delta_k, - m, m_blocks, wg_delta_m, batch_nelems, batch_indexer, - lhs_indexer, rhs_indexer, tmp_indexer)); - } - }); + using KernelName = class gemm_batch_tree_nm_krn< + lhsTy, rhsTy, resTy, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, BatchDimsIndexerT, wi_delta_m>; + cgh.parallel_for( + ndRange, + GemmBatchNoAtomicFunctorThreadNM< + lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, + BatchDimsIndexerT, wi_delta_n, wi_delta_m>( + lhs_tp, rhs_tp, partially_reduced_tmp, + local_A_block, local_B_block, n, wg_delta_n, k, + k_blocks, wi_delta_k, m, m_blocks, wg_delta_m, + batch_nelems, batch_indexer, lhs_indexer, + rhs_indexer, tmp_indexer)); + } + }); - sycl::event red_ev = - tree_reduction_for_gemm_contig( - exec_q, partially_reduced_tmp, partially_reduced_tmp2, res_tp, - identity_val, iter_nelems, reduction_nelems, reduction_groups, - wg, max_wg, preferred_reductions_per_wi, reductions_per_wi, - {gemm_ev}); + sycl::event red_ev = + tree_reduction_for_gemm_contig( + exec_q, partially_reduced_tmp, partially_reduced_tmp2, + res_tp, identity_val, iter_nelems, reduction_nelems, + reduction_groups, wg, max_wg, preferred_reductions_per_wi, + reductions_per_wi, {gemm_ev}); - sycl::event cleanup_host_task_event = - exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(red_ev); - const sycl::context &ctx = exec_q.get_context(); + sycl::event cleanup_host_task_event = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(red_ev); + const sycl::context &ctx = exec_q.get_context(); - cgh.host_task([ctx, partially_reduced_tmp] { - sycl::free(partially_reduced_tmp, ctx); + cgh.host_task([ctx, partially_reduced_tmp] { + sycl::free(partially_reduced_tmp, ctx); + }); }); - }); - return cleanup_host_task_event; + return cleanup_host_task_event; + } } } From 71ef29443500e8159996b1cd7c354c33b6e51389 Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Tue, 16 Jan 2024 18:21:14 -0800 Subject: [PATCH 38/48] Implements `__matmul__`, `__imatmul__`, and `__rmatmul__` operators for usm_ndarray --- dpctl/tensor/_usmarray.pyx | 17 +++++------------ 1 file changed, 5 insertions(+), 12 deletions(-) diff --git a/dpctl/tensor/_usmarray.pyx b/dpctl/tensor/_usmarray.pyx index 284de1cbe1..3f2f999e1d 100644 --- a/dpctl/tensor/_usmarray.pyx +++ b/dpctl/tensor/_usmarray.pyx @@ -992,7 +992,7 @@ cdef class usm_ndarray: return dpctl.tensor.less(self, other) def __matmul__(first, other): - return NotImplemented + return dpctl.tensor.matmul(first, other) 
def __mod__(first, other): return dpctl.tensor.remainder(first, other) @@ -1012,11 +1012,8 @@ cdef class usm_ndarray: def __pos__(self): return dpctl.tensor.positive(self) - def __pow__(first, other, mod): - if mod is None: - return dpctl.tensor.pow(first, other) - else: - return NotImplemented + def __pow__(first, other): + return dpctl.tensor.pow(first, other) def __rshift__(first, other): return dpctl.tensor.bitwise_right_shift(first, other) @@ -1131,7 +1128,7 @@ cdef class usm_ndarray: return dpctl.tensor.bitwise_left_shift(other, self) def __rmatmul__(self, other): - return NotImplemented + return dpctl.tensor.matmul(other, self) def __rmod__(self, other): return dpctl.tensor.remainder(other, self) @@ -1170,11 +1167,7 @@ cdef class usm_ndarray: return dpctl.tensor.bitwise_left_shift(self, other, out=self) def __imatmul__(self, other): - res = self.__matmul__(other) - if res is NotImplemented: - return res - self.__setitem__(Ellipsis, res) - return self + return dpctl.tensor.matmul(self, other, out=self) def __imod__(self, other): return dpctl.tensor.remainder(self, other, out=self) From 4ccb6fd4952d2bc3eb55f6d386ec2233d9a91f6b Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Tue, 16 Jan 2024 18:21:26 -0800 Subject: [PATCH 39/48] Makes usm_ndarray operator argument names consistent --- dpctl/tensor/_usmarray.pyx | 53 +++++++++++++++++++------------------- 1 file changed, 26 insertions(+), 27 deletions(-) diff --git a/dpctl/tensor/_usmarray.pyx b/dpctl/tensor/_usmarray.pyx index 3f2f999e1d..67e144f798 100644 --- a/dpctl/tensor/_usmarray.pyx +++ b/dpctl/tensor/_usmarray.pyx @@ -907,15 +907,15 @@ cdef class usm_ndarray: def __abs__(self): return dpctl.tensor.abs(self) - def __add__(first, other): + def __add__(self, other): """ Implementation for operator.add """ - return dpctl.tensor.add(first, other) + return dpctl.tensor.add(self, other) - def __and__(first, other): + def __and__(self, other): "Implementation for operator.and" - return dpctl.tensor.bitwise_and(first, other) + return dpctl.tensor.bitwise_and(self, other) def __dlpack__(self, stream=None): """ @@ -963,8 +963,8 @@ cdef class usm_ndarray: def __eq__(self, other): return dpctl.tensor.equal(self, other) - def __floordiv__(first, other): - return dpctl.tensor.floor_divide(first, other) + def __floordiv__(self, other): + return dpctl.tensor.floor_divide(self, other) def __ge__(self, other): return dpctl.tensor.greater_equal(self, other) @@ -984,21 +984,20 @@ cdef class usm_ndarray: else: raise TypeError("len() of unsized object") - def __lshift__(first, other): - "See comment in __add__" - return dpctl.tensor.bitwise_left_shift(first, other) + def __lshift__(self, other): + return dpctl.tensor.bitwise_left_shift(self, other) def __lt__(self, other): return dpctl.tensor.less(self, other) - def __matmul__(first, other): - return dpctl.tensor.matmul(first, other) + def __matmul__(self, other): + return dpctl.tensor.matmul(self, other) - def __mod__(first, other): - return dpctl.tensor.remainder(first, other) + def __mod__(self, other): + return dpctl.tensor.remainder(self, other) - def __mul__(first, other): - return dpctl.tensor.multiply(first, other) + def __mul__(self, other): + return dpctl.tensor.multiply(self, other) def __ne__(self, other): return dpctl.tensor.not_equal(self, other) @@ -1006,17 +1005,17 @@ cdef class usm_ndarray: def __neg__(self): return dpctl.tensor.negative(self) - def __or__(first, other): - return dpctl.tensor.bitwise_or(first, other) + def __or__(self, other): + return 
dpctl.tensor.bitwise_or(self, other) def __pos__(self): return dpctl.tensor.positive(self) - def __pow__(first, other): - return dpctl.tensor.pow(first, other) + def __pow__(self, other): + return dpctl.tensor.pow(self, other) - def __rshift__(first, other): - return dpctl.tensor.bitwise_right_shift(first, other) + def __rshift__(self, other): + return dpctl.tensor.bitwise_right_shift(self, other) def __setitem__(self, key, rhs): cdef tuple _meta @@ -1106,14 +1105,14 @@ cdef class usm_ndarray: return - def __sub__(first, other): - return dpctl.tensor.subtract(first, other) + def __sub__(self, other): + return dpctl.tensor.subtract(self, other) - def __truediv__(first, other): - return dpctl.tensor.divide(first, other) + def __truediv__(self, other): + return dpctl.tensor.divide(self, other) - def __xor__(first, other): - return dpctl.tensor.bitwise_xor(first, other) + def __xor__(self, other): + return dpctl.tensor.bitwise_xor(self, other) def __radd__(self, other): return dpctl.tensor.add(other, self) From 877c7621017876d386309fa224c2af77f5c7397a Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Tue, 16 Jan 2024 20:32:02 -0800 Subject: [PATCH 40/48] Test changes for `tensordot` Adds a test for axes errors in `tensordot` for negative axes Incorporates test for `tensordot` promotion of both inputs into `test_tensordot_type_promotion` --- dpctl/tests/test_usm_ndarray_linalg.py | 27 +++++++++++++------------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/dpctl/tests/test_usm_ndarray_linalg.py b/dpctl/tests/test_usm_ndarray_linalg.py index 59a9daf6fc..6fffcb4ee1 100644 --- a/dpctl/tests/test_usm_ndarray_linalg.py +++ b/dpctl/tests/test_usm_ndarray_linalg.py @@ -470,24 +470,25 @@ def test_tensordot_promotion(): r2 = dpt.tensordot(t2, t1) assert r2.dtype == t2.dtype + t3 = dpt.zeros((10, 10), dtype="u4") + r3 = dpt.tensordot(t1, t2) + assert r3.dtype == dpt.result_type(t1.dtype, t3.dtype) -@pytest.mark.parametrize("dt1", _numeric_types) -@pytest.mark.parametrize("dt2", _numeric_types) -def test_tensordot_type_promotion2(dt1, dt2): + +def test_tensordot_axes_errors(): get_queue_or_skip() - q = get_queue_or_skip() - skip_if_dtype_not_supported(dt1, q) - skip_if_dtype_not_supported(dt2, q) + m1 = dpt.zeros((10, 10), dtype="i4") + m2 = dpt.zeros((10, 10), dtype="i4") - m1 = dpt.ones((10, 10), dtype=dt1) - m2 = dpt.ones((10, 10), dtype=dt2) + with pytest.raises(ValueError): + dpt.tensordot(m1, m2, axes=-1) - r = dpt.tensordot(m1, m2, axes=1) - assert r.shape == ( - 10, - 10, - ) + with pytest.raises(ValueError): + dpt.tensordot(m1, m2, axes=((-1,), (1,))) + + with pytest.raises(ValueError): + dpt.tensordot(m1, m2, axes=((1,), (-1,))) @pytest.mark.parametrize("dtype", _numeric_types) From 8ad7ca20cdafa2a637674ba791f71dd93dbc4a32 Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Tue, 16 Jan 2024 20:32:56 -0800 Subject: [PATCH 41/48] Reverts running certain `matmul` tests under gdb --- .github/workflows/conda-package.yml | 8 -------- 1 file changed, 8 deletions(-) diff --git a/.github/workflows/conda-package.yml b/.github/workflows/conda-package.yml index e9c4089602..6d04a43ce6 100644 --- a/.github/workflows/conda-package.yml +++ b/.github/workflows/conda-package.yml @@ -191,11 +191,6 @@ jobs: . 
$CONDA/etc/profile.d/conda.sh conda activate test_dpctl gdb --batch -ex r -ex 'info sharedlibrary' -ex 'set print elements 1000' -ex bt --args ${CONDA_PREFIX}/bin/python -m pytest -q -ra --disable-warnings --pyargs dpctl.tests.elementwise.test_trigonometric::test_trig_order -vv || true - - name: Run test_matmul_strided under gdb - run: | - . $CONDA/etc/profile.d/conda.sh - conda activate test_dpctl - gdb --batch -ex r -ex 'info sharedlibrary' -ex 'set print elements 1000' -ex bt --args ${CONDA_PREFIX}/bin/python -m pytest -q -ra --disable-warnings --pyargs dpctl.tests.test_usm_ndarray_linalg::test_matmul_strided -vv || true - name: Run tests env: SYCL_CACHE_PERSISTENT: 1 @@ -313,9 +308,6 @@ jobs: shell: cmd /C CALL {0} run: >- conda activate dpctl_test && python -m dpctl -f - - name: Run test_matmul_simple2 - run: | - conda activate dpctl_test && python -m pytest -q --pyargs dpctl.tests.test_usm_ndarray_linalg::test_matmul_simple2 -vv || true - name: Run tests shell: cmd /C CALL {0} env: From dc34e1d89052ed38141d92480fbf713ede20e001 Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Tue, 16 Jan 2024 23:13:09 -0800 Subject: [PATCH 42/48] Fix to typo in `test_tensordot_promotion` --- dpctl/tests/test_usm_ndarray_linalg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dpctl/tests/test_usm_ndarray_linalg.py b/dpctl/tests/test_usm_ndarray_linalg.py index 6fffcb4ee1..dcaeabf2bd 100644 --- a/dpctl/tests/test_usm_ndarray_linalg.py +++ b/dpctl/tests/test_usm_ndarray_linalg.py @@ -471,7 +471,7 @@ def test_tensordot_promotion(): assert r2.dtype == t2.dtype t3 = dpt.zeros((10, 10), dtype="u4") - r3 = dpt.tensordot(t1, t2) + r3 = dpt.tensordot(t1, t3) assert r3.dtype == dpt.result_type(t1.dtype, t3.dtype) From d03d16eeb0a6192a0b9fbf2e6189f41710c63f8f Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Wed, 17 Jan 2024 10:12:22 -0800 Subject: [PATCH 43/48] Removes unnecessary input type checks in `matmul` --- dpctl/tensor/_linear_algebra_functions.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/dpctl/tensor/_linear_algebra_functions.py b/dpctl/tensor/_linear_algebra_functions.py index 1f844a89f9..0894ac2077 100644 --- a/dpctl/tensor/_linear_algebra_functions.py +++ b/dpctl/tensor/_linear_algebra_functions.py @@ -754,15 +754,13 @@ def matmul(x1, x2, out=None, dtype=None, order="K"): "Input and output allocation queues are not compatible" ) - if isinstance(x1, dpt.usm_ndarray): - if ti._array_overlap(x1, out) and buf1_dt is None: - out = dpt.empty_like(out) - - if isinstance(x2, dpt.usm_ndarray): - if ti._array_overlap(x2, out) and buf2_dt is None: - # should not reach if out is reallocated - # after being checked against x1 - out = dpt.empty_like(out) + if ti._array_overlap(x1, out) and buf1_dt is None: + out = dpt.empty_like(out) + + if ti._array_overlap(x2, out) and buf2_dt is None: + # should not reach if out is reallocated + # after being checked against x1 + out = dpt.empty_like(out) if buf1_dt is None and buf2_dt is None: if out is None: From 15fa9522df6545371a1a4f2001566ffa31443a6a Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Wed, 17 Jan 2024 11:06:06 -0800 Subject: [PATCH 44/48] More tests added to `test_usm_linalg.py` Adds several tests for `matmul` and expands some `tensordot` and `vecdot` tests to improve coverage --- dpctl/tests/test_usm_ndarray_linalg.py | 193 ++++++++++++++++++++++++- 1 file changed, 189 insertions(+), 4 deletions(-) diff --git a/dpctl/tests/test_usm_ndarray_linalg.py 
b/dpctl/tests/test_usm_ndarray_linalg.py index dcaeabf2bd..3d304b27d8 100644 --- a/dpctl/tests/test_usm_ndarray_linalg.py +++ b/dpctl/tests/test_usm_ndarray_linalg.py @@ -22,6 +22,7 @@ import dpctl import dpctl.tensor as dpt from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported +from dpctl.utils import ExecutionPlacementError _numeric_types = [ "i1", @@ -233,12 +234,17 @@ def test_matmul_dims_validation(): def test_matmul_broadcasting(): get_queue_or_skip() - m1 = dpt.ones((7, 11, 16)) - m2 = dpt.ones((16, 13)) + for dt1, dt2 in [ + (dpt.int16, dpt.int32), + (dpt.float32, dpt.int16), + (dpt.int32, dpt.uint32), + ]: + m1 = dpt.ones((7, 11, 16), dtype=dt1) + m2 = dpt.ones((16, 13), dtype=dt2) - r = dpt.matmul(m1, m2[dpt.newaxis, ...]) + r = dpt.matmul(m1, m2[dpt.newaxis, ...]) - assert r.shape == (7, 11, 13) + assert r.shape == (7, 11, 13) @pytest.mark.parametrize("dtype", ["i4", "i8", "f4", "c8"]) @@ -347,6 +353,184 @@ def test_matmul_type_promotion(dt1, dt2): ) +def test_matmul_invalid_dtype(): + get_queue_or_skip() + + m1 = dpt.zeros((10, 10), dtype="f4") + m2 = dpt.zeros((10, 10), dtype="f4") + m3 = dpt.zeros((10, 10), dtype="i4") + + with pytest.raises(ValueError): + dpt.matmul(m1, m2, dtype="i4") + + with pytest.raises(ValueError): + dpt.matmul(m1, m3, dtype="i4") + + with pytest.raises(ValueError): + dpt.matmul(m3, m1, dtype="i4") + + +def test_matmul_out_errors(): + q1 = get_queue_or_skip() + q2 = dpctl.SyclQueue() + + sh = (10, 10) + dt = "i4" + m1 = dpt.zeros(sh, dtype=dt, sycl_queue=q1) + m2 = dpt.zeros(sh, dtype=dt, sycl_queue=q1) + + with pytest.raises(TypeError): + dpt.matmul(m1, m2, out=dict()) + + with pytest.raises(ValueError): + dpt.matmul(m1, m2, out=dpt.empty((10,), dtype=dt, sycl_queue=q1)) + + with pytest.raises(ValueError): + dpt.matmul(m1, m2, out=dpt.empty(sh, dtype="f4", sycl_queue=q1)) + + with pytest.raises(ExecutionPlacementError): + dpt.matmul(m1, m2, out=dpt.empty(sh, dtype=dt, sycl_queue=q2)) + + +def test_matmul_order(): + get_queue_or_skip() + + sh = ( + 10, + 10, + ) + sh2 = tuple(2 * dim for dim in sh) + n = sh[-1] + + for dt1, dt2 in zip(["i4", "i4", "f4"], ["i4", "f4", "i4"]): + ar1 = dpt.ones(sh, dtype=dt1, order="C") + ar2 = dpt.ones(sh, dtype=dt2, order="C") + r1 = dpt.matmul(ar1, ar2, order="C") + assert r1.flags.c_contiguous + r2 = dpt.matmul(ar1, ar2, order="F") + assert r2.flags.f_contiguous + r3 = dpt.matmul(ar1, ar2, order="A") + assert r3.flags.c_contiguous + r4 = dpt.matmul(ar1, ar2, order="K") + assert r4.flags.c_contiguous + + ar1 = dpt.ones(sh, dtype=dt1, order="F") + ar2 = dpt.ones(sh, dtype=dt2, order="F") + r1 = dpt.matmul(ar1, ar2, order="C") + assert r1.flags.c_contiguous + r2 = dpt.matmul(ar1, ar2, order="F") + assert r2.flags.f_contiguous + r3 = dpt.matmul(ar1, ar2, order="A") + assert r3.flags.f_contiguous + r4 = dpt.matmul(ar1, ar2, order="K") + assert r4.flags.f_contiguous + + ar1 = dpt.ones(sh2, dtype=dt1, order="C")[:10, ::-2] + ar2 = dpt.ones(sh2, dtype=dt2, order="C")[:10, ::-2] + r4 = dpt.matmul(ar1, ar2, order="K") + assert r4.strides == (n, -1) + r5 = dpt.matmul(ar1, ar2, order="C") + assert r5.strides == (n, 1) + + ar1 = dpt.ones(sh2, dtype=dt1, order="C")[:10, ::-2].mT + ar2 = dpt.ones(sh2, dtype=dt2, order="C")[:10, ::-2].mT + r4 = dpt.matmul(ar1, ar2, order="K") + assert r4.strides == (-1, n) + r5 = dpt.matmul(ar1, ar2, order="C") + assert r5.strides == (n, 1) + + +def test_matmul_invalid_order(): + get_queue_or_skip() + + sh = ( + 10, + 10, + ) + dt = "i4" + + ar1 = dpt.ones(sh, dtype=dt, 
order="C") + ar2 = dpt.ones(sh, dtype=dt, order="C") + r = dpt.matmul(ar1, ar2, order="invalid") + assert r.flags.c_contiguous + + ar1 = dpt.ones(sh, dtype=dt, order="F") + ar2 = dpt.ones(sh, dtype=dt, order="F") + r = dpt.matmul(ar1, ar2, order="invalid") + assert r.flags.f_contiguous + + +def test_matmul_compute_follows_data(): + q1 = get_queue_or_skip() + q2 = dpctl.SyclQueue() + + sh = ( + 10, + 10, + ) + dt = "i4" + m1 = dpt.zeros(sh, dtype=dt, sycl_queue=q1) + m2 = dpt.zeros(sh, dtype=dt, sycl_queue=q2) + + with pytest.raises(ExecutionPlacementError): + dpt.matmul(m1, m2) + + +def test_matmul_inplace_broadcasting(): + get_queue_or_skip() + + sh = (3, 5, 5) + dt = "i4" + + m1 = dpt.ones((3, 5, 5), dtype=dt) + m2 = dpt.ones((1, 5, 5), dtype=dt) + m1 @= m2 + assert dpt.all(m1 == dpt.full(sh, 5, dtype=dt)) + + +def test_matmul_prepend_dims(): + get_queue_or_skip() + + n = 5 + for dt1, dt2 in [ + (dpt.int32, dpt.int32), + (dpt.int32, dpt.int64), + (dpt.int64, dpt.int32), + (dpt.int32, dpt.uint32), + ]: + m = dpt.ones((n, 4), dtype=dt1) + v = dpt.ones((4,), dtype=dt2) + r = dpt.matmul(m, v) + assert r.shape == (n,) + + r = dpt.matmul(v, m.mT) + assert r.shape == (n,) + + +def test_matmul_inplace_same_tensors(): + get_queue_or_skip() + + n = 5 + sh = ( + n, + n, + ) + + ar1 = dpt.ones(sh, dtype="i4") + ar1 @= ar1 + assert dpt.all(ar1 == dpt.full(sh, n, dtype="i4")) + + ar1 = dpt.ones(sh, dtype="i8") + ar2 = dpt.ones(sh, dtype="i4") + dpt.matmul(ar1, ar2, out=ar1) + assert dpt.all(ar1 == dpt.full(sh, n, dtype=ar1.dtype)) + + ar1 = dpt.ones(sh, dtype="i4") + ar2 = dpt.ones(sh, dtype="i8") + dpt.matmul(ar1, ar2, out=ar2) + assert dpt.all(ar2 == dpt.full(sh, n, dtype=ar2.dtype)) + + @pytest.mark.parametrize("dtype", _numeric_types) def test_tensordot_outer(dtype): q = get_queue_or_skip() @@ -621,6 +805,7 @@ def test_vecdot_broadcast(): (dpt.int32, dpt.int32), (dpt.int32, dpt.int64), (dpt.int64, dpt.int32), + (dpt.int32, dpt.uint32), ]: m1 = dpt.zeros((1, 5), dtype=dt1) m2 = dpt.zeros((5, 5), dtype=dt2) From 3ce9b59dd29faaca1d15c6376f39f15ee7943d65 Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Wed, 17 Jan 2024 12:54:17 -0600 Subject: [PATCH 45/48] Use result_type with tensors to take device capability into account --- dpctl/tests/test_usm_ndarray_linalg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dpctl/tests/test_usm_ndarray_linalg.py b/dpctl/tests/test_usm_ndarray_linalg.py index 3d304b27d8..dab31ec829 100644 --- a/dpctl/tests/test_usm_ndarray_linalg.py +++ b/dpctl/tests/test_usm_ndarray_linalg.py @@ -656,7 +656,7 @@ def test_tensordot_promotion(): t3 = dpt.zeros((10, 10), dtype="u4") r3 = dpt.tensordot(t1, t3) - assert r3.dtype == dpt.result_type(t1.dtype, t3.dtype) + assert r3.dtype == dpt.result_type(t1, t3) def test_tensordot_axes_errors(): From 03c36eb3d91675da299c0c35fadbf6b83aee79bf Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Wed, 17 Jan 2024 13:08:20 -0600 Subject: [PATCH 46/48] Use order keyword in test of type promotion for matmul --- dpctl/tests/test_usm_ndarray_linalg.py | 31 +++++++++++++++++++------- 1 file changed, 23 insertions(+), 8 deletions(-) diff --git a/dpctl/tests/test_usm_ndarray_linalg.py b/dpctl/tests/test_usm_ndarray_linalg.py index dab31ec829..881729136d 100644 --- a/dpctl/tests/test_usm_ndarray_linalg.py +++ b/dpctl/tests/test_usm_ndarray_linalg.py @@ -336,21 +336,36 @@ def test_matmul_dtype(): @pytest.mark.parametrize("dt1", _numeric_types) @pytest.mark.parametrize("dt2", _numeric_types) -def 
test_matmul_type_promotion(dt1, dt2): +@pytest.mark.parametrize("order", ["C", "K"]) +def test_matmul_type_promotion(dt1, dt2, order): get_queue_or_skip() q = get_queue_or_skip() skip_if_dtype_not_supported(dt1, q) skip_if_dtype_not_supported(dt2, q) - m1 = dpt.ones((10, 10), dtype=dt1) - m2 = dpt.ones((10, 10), dtype=dt2) + b, n, k, m = 8, 10, 17, 10 + m1 = dpt.ones((1, n, k), dtype=dt1) + m2 = dpt.ones((b, k, m), dtype=dt2) + expected_dt = dpt.result_type(m1, m2) - r = dpt.matmul(m1, m2) - assert r.shape == ( - 10, - 10, - ) + r = dpt.matmul(m1, m2, order=order) + assert r.shape == (b, n, m) + assert r.dtype == expected_dt + + m1 = dpt.ones((b, n, k), dtype=dt1) + m2 = dpt.ones((1, k, m), dtype=dt2) + + r = dpt.matmul(m1, m2, order=order) + assert r.shape == (b, n, m) + assert r.dtype == expected_dt + + m1 = dpt.ones((n, k), dtype=dt1) + m2 = dpt.ones((k, m), dtype=dt2) + + r = dpt.matmul(m1, m2, order=order) + assert r.shape == (n, m) + assert r.dtype == expected_dt def test_matmul_invalid_dtype(): From 1eaadb65d3094d6e443b8f80ecac7e9f01c2684d Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Thu, 18 Jan 2024 09:18:32 -0800 Subject: [PATCH 47/48] Make generic k-threaded kernels handle arbitrary m_groups Also increases hyper-parameters for k-threaded kernels to improve performance --- .../include/kernels/linalg_functions/gemm.hpp | 90 +++++++++++-------- 1 file changed, 52 insertions(+), 38 deletions(-) diff --git a/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp b/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp index 255e008777..40fa89e583 100644 --- a/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp +++ b/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp @@ -979,13 +979,16 @@ class GemmFunctorThreadK for (size_t q = 0; q < n_wi * delta_k; q += delta_k) { size_t sq = s + q; size_t sqmj = sq * m + j; - local_B_block[local_s + q] = sycl::vec( - (sq < k && j < m) - ? static_cast(rhs[rhs_indexer(sqmj)]) - : identity_, - (sq < k && j + 1 < m) - ? static_cast(rhs[rhs_indexer(sqmj + 1)]) - : identity_); + sycl::vec local_B_vec; +#pragma unroll + for (size_t vec_idx = 0; vec_idx < m_groups; ++vec_idx) { + local_B_vec[vec_idx] = + (sq < k && j + vec_idx < m) + ? 
static_cast( + rhs[rhs_indexer(sqmj + vec_idx)]) + : identity_; + } + local_B_block[local_s + q] = local_B_vec; } } @@ -1241,7 +1244,7 @@ sycl::event gemm_impl(sycl::queue &exec_q, constexpr size_t m_groups = 1; size_t delta_k(4); size_t n_wi(64); - size_t delta_n(16); + size_t delta_n(32); gemm_detail::scale_gemm_k_parameters( local_mem_size, reserved_slm_size, delta_k, @@ -1277,7 +1280,7 @@ sycl::event gemm_impl(sycl::queue &exec_q, constexpr size_t m_groups = 2; size_t delta_k(4); size_t n_wi(64); - size_t delta_n(16); + size_t delta_n(32); gemm_detail::scale_gemm_k_parameters( local_mem_size, reserved_slm_size, delta_k, @@ -1411,7 +1414,7 @@ sycl::event gemm_contig_impl(sycl::queue &exec_q, constexpr size_t m_groups = 1; size_t delta_k(4); size_t n_wi(64); - size_t delta_n(16); + size_t delta_n(32); gemm_detail::scale_gemm_k_parameters( local_mem_size, reserved_slm_size, delta_k, @@ -1447,7 +1450,7 @@ sycl::event gemm_contig_impl(sycl::queue &exec_q, constexpr size_t m_groups = 2; size_t delta_k(4); size_t n_wi(64); - size_t delta_n(16); + size_t delta_n(32); gemm_detail::scale_gemm_k_parameters( local_mem_size, reserved_slm_size, delta_k, @@ -1922,13 +1925,16 @@ class GemmNoAtomicFunctorThreadK for (size_t q = 0; q < n_wi * delta_k; q += delta_k) { size_t sq = s + q; size_t sqmj = sq * m + j; - local_B_block[local_s + q] = sycl::vec( - (sq < k && j < m) - ? static_cast(rhs[rhs_indexer(sqmj)]) - : identity_, - (sq < k && j + 1 < m) - ? static_cast(rhs[rhs_indexer(sqmj + 1)]) - : identity_); + sycl::vec local_B_vec; +#pragma unroll + for (size_t vec_idx = 0; vec_idx < m_groups; ++vec_idx) { + local_B_vec[vec_idx] = + (sq < k && j + vec_idx < m) + ? static_cast( + rhs[rhs_indexer(sqmj + vec_idx)]) + : identity_; + } + local_B_block[local_s + q] = local_B_vec; } } @@ -2130,7 +2136,7 @@ sycl::event gemm_tree_k_impl(sycl::queue &exec_q, { size_t delta_k(4); size_t n_wi(64); - size_t delta_n(16); + size_t delta_n(32); const sycl::device &dev = exec_q.get_device(); const size_t local_mem_size = @@ -2862,7 +2868,7 @@ sycl::event gemm_contig_tree_k_impl(sycl::queue &exec_q, { size_t delta_k(4); size_t n_wi(64); - size_t delta_n(16); + size_t delta_n(32); const sycl::device &dev = exec_q.get_device(); const size_t local_mem_size = @@ -3986,14 +3992,16 @@ class GemmBatchFunctorThreadK for (size_t q = 0; q < n_wi * delta_k; q += delta_k) { size_t sq = s + q; size_t sqmj = sq * m + j; - local_B_block[local_s + q] = sycl::vec( - (sq < k && j < m) - ? static_cast(rhs[rhs_offset + rhs_indexer(sqmj)]) - : identity_, - (sq < k && j + 1 < m) - ? static_cast( - rhs[rhs_offset + rhs_indexer(sqmj + 1)]) - : identity_); + sycl::vec local_B_vec; +#pragma unroll + for (size_t vec_idx = 0; vec_idx < m_groups; ++vec_idx) { + local_B_vec[vec_idx] = + (sq < k && j + vec_idx < m) + ? 
static_cast( + rhs[rhs_offset + rhs_indexer(sqmj + vec_idx)]) + : identity_; + } + local_B_block[local_s + q] = local_B_vec; } } @@ -4310,7 +4318,7 @@ sycl::event gemm_batch_impl(sycl::queue &exec_q, constexpr size_t m_groups = 1; size_t delta_k(4); size_t n_wi(64); - size_t delta_n(16); + size_t delta_n(32); gemm_detail::scale_gemm_k_parameters( local_mem_size, reserved_slm_size, delta_k, @@ -4351,7 +4359,7 @@ sycl::event gemm_batch_impl(sycl::queue &exec_q, constexpr size_t m_groups = 2; size_t delta_k(4); size_t n_wi(64); - size_t delta_n(16); + size_t delta_n(32); gemm_detail::scale_gemm_k_parameters( local_mem_size, reserved_slm_size, delta_k, @@ -4516,7 +4524,7 @@ sycl::event gemm_batch_contig_impl(sycl::queue &exec_q, constexpr size_t m_groups = 1; size_t delta_k(4); size_t n_wi(64); - size_t delta_n(16); + size_t delta_n(32); gemm_detail::scale_gemm_k_parameters( local_mem_size, reserved_slm_size, delta_k, @@ -4557,7 +4565,7 @@ sycl::event gemm_batch_contig_impl(sycl::queue &exec_q, constexpr size_t m_groups = 2; size_t delta_k(4); size_t n_wi(64); - size_t delta_n(16); + size_t delta_n(32); gemm_detail::scale_gemm_k_parameters( local_mem_size, reserved_slm_size, delta_k, @@ -5096,10 +5104,16 @@ class GemmBatchNoAtomicFunctorThreadK for (size_t q = 0; q < n_wi * delta_k; q += delta_k) { size_t sq = s + q; size_t sqmj = sq * m + j; - local_B_block[local_s + q] = - (sq < k && j < m) - ? static_cast(rhs[rhs_offset + rhs_indexer(sqmj)]) - : identity_; + sycl::vec local_B_vec; +#pragma unroll + for (size_t vec_idx = 0; vec_idx < m_groups; ++vec_idx) { + local_B_vec[vec_idx] = + (sq < k && j + vec_idx < m) + ? static_cast( + rhs[rhs_offset + rhs_indexer(sqmj + vec_idx)]) + : identity_; + } + local_B_block[local_s + q] = local_B_vec; } } @@ -5331,7 +5345,7 @@ gemm_batch_tree_k_impl(sycl::queue &exec_q, { size_t delta_k(4); size_t n_wi(64); - size_t delta_n(16); + size_t delta_n(32); const sycl::device &dev = exec_q.get_device(); const size_t local_mem_size = @@ -6184,7 +6198,7 @@ gemm_batch_contig_tree_k_impl(sycl::queue &exec_q, { size_t delta_k(4); size_t n_wi(64); - size_t delta_n(16); + size_t delta_n(32); const sycl::device &dev = exec_q.get_device(); const size_t local_mem_size = From 879b8bb8f08e318fdf6053ab502b1f0cc1856c85 Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Mon, 22 Jan 2024 13:15:59 -0800 Subject: [PATCH 48/48] Adjusted dispatch logic for gemm kernels Now uses m_groups = 4 when m > 4, and otherwise, m_groups = 1 to improve performance --- .../include/kernels/linalg_functions/gemm.hpp | 40 +++++++++---------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp b/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp index 40fa89e583..a4a5d3b929 100644 --- a/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp +++ b/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp @@ -1240,7 +1240,7 @@ sycl::event gemm_impl(sycl::queue &exec_q, rhs_shape_strides); OuterInnerIndexerT res_indexer(res_outer_nd, 0, res_shape_strides); - if (m == 1) { + if (m < 4) { constexpr size_t m_groups = 1; size_t delta_k(4); size_t n_wi(64); @@ -1277,7 +1277,7 @@ sycl::event gemm_impl(sycl::queue &exec_q, m, lhs_indexer, rhs_indexer, res_indexer)); } else if (k > n && k > m) { - constexpr size_t m_groups = 2; + constexpr size_t m_groups = 4; size_t delta_k(4); size_t n_wi(64); size_t delta_n(32); @@ -1410,7 +1410,7 @@ sycl::event gemm_contig_impl(sycl::queue &exec_q, 
OuterInnerIndexerT rhs_indexer{}; OuterInnerIndexerT res_indexer{}; - if (m == 1) { + if (m < 4) { constexpr size_t m_groups = 1; size_t delta_k(4); size_t n_wi(64); @@ -1447,7 +1447,7 @@ sycl::event gemm_contig_impl(sycl::queue &exec_q, m, lhs_indexer, rhs_indexer, res_indexer)); } else if (k > n && k > m) { - constexpr size_t m_groups = 2; + constexpr size_t m_groups = 4; size_t delta_k(4); size_t n_wi(64); size_t delta_n(32); @@ -2811,10 +2811,10 @@ sycl::event gemm_tree_impl(sycl::queue &exec_q, return gemm_no_reduction_ev; } - if ((k > n && k > m) || m == 1) { + if ((k > n && k > m) || m < 4) { using dpctl::tensor::type_utils::is_complex; if constexpr (!is_complex::value) { - if (m == 1) { + if (m < 4) { return gemm_tree_k_impl( exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, inner_nd, lhs_outer_nd, lhs_outer_inner_shapes_strides, rhs_outer_nd, @@ -2822,7 +2822,7 @@ sycl::event gemm_tree_impl(sycl::queue &exec_q, depends); } else { - return gemm_tree_k_impl( + return gemm_tree_k_impl( exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, inner_nd, lhs_outer_nd, lhs_outer_inner_shapes_strides, rhs_outer_nd, rhs_outer_inner_shapes_strides, res_nd, res_shapes_strides, @@ -3504,15 +3504,15 @@ sycl::event gemm_contig_tree_impl(sycl::queue &exec_q, return gemm_no_reduction_ev; } - if ((k > n && k > m) || m == 1) { + if ((k > n && k > m) || m < 4) { using dpctl::tensor::type_utils::is_complex; if constexpr (!is_complex::value) { - if (m == 1) { + if (m < 4) { return gemm_contig_tree_k_impl( exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, depends); } else { - return gemm_contig_tree_k_impl( + return gemm_contig_tree_k_impl( exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, depends); } } @@ -4314,7 +4314,7 @@ sycl::event gemm_batch_impl(sycl::queue &exec_q, BatchDimsIndexerT batch_indexer(batch_nd, lhs_batch_offset, rhs_batch_offset, res_batch_offset, batch_shape_strides); - if (m == 1) { + if (m < 4) { constexpr size_t m_groups = 1; size_t delta_k(4); size_t n_wi(64); @@ -4356,7 +4356,7 @@ sycl::event gemm_batch_impl(sycl::queue &exec_q, rhs_indexer, res_indexer)); } else if (k > n && k > m) { - constexpr size_t m_groups = 2; + constexpr size_t m_groups = 4; size_t delta_k(4); size_t n_wi(64); size_t delta_n(32); @@ -4520,7 +4520,7 @@ sycl::event gemm_batch_contig_impl(sycl::queue &exec_q, static_cast(k * m)}, Strided1DIndexer{0, static_cast(batch_nelems), static_cast(n * m)}); - if (m == 1) { + if (m < 4) { constexpr size_t m_groups = 1; size_t delta_k(4); size_t n_wi(64); @@ -4562,7 +4562,7 @@ sycl::event gemm_batch_contig_impl(sycl::queue &exec_q, rhs_indexer, res_indexer)); } else if (k > n && k > m) { - constexpr size_t m_groups = 2; + constexpr size_t m_groups = 4; size_t delta_k(4); size_t n_wi(64); size_t delta_n(32); @@ -6129,10 +6129,10 @@ gemm_batch_tree_impl(sycl::queue &exec_q, return gemm_batch_no_reduction_ev; } - if ((k > n && k > m) || m == 1) { + if ((k > n && k > m) || m < 4) { using dpctl::tensor::type_utils::is_complex; if constexpr (!is_complex::value) { - if (m == 1) { + if (m < 4) { return gemm_batch_tree_k_impl( exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, batch_nd, batch_shape_strides, lhs_batch_offset, @@ -6142,7 +6142,7 @@ gemm_batch_tree_impl(sycl::queue &exec_q, res_outer_shapes_strides, res_shape_strides, depends); } else { - return gemm_batch_tree_k_impl( + return gemm_batch_tree_k_impl( exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, batch_nd, batch_shape_strides, lhs_batch_offset, rhs_batch_offset, res_batch_offset, inner_nd, lhs_outer_nd, @@ -6931,16 +6931,16 @@ 
gemm_batch_contig_tree_impl(sycl::queue &exec_q, return gemm_batch_no_reduction_ev; } - if ((k > n && k > m) || m == 1) { + if ((k > n && k > m) || m < 4) { using dpctl::tensor::type_utils::is_complex; if constexpr (!is_complex::value) { - if (m == 1) { + if (m < 4) { return gemm_batch_contig_tree_k_impl( exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, depends); } else { - return gemm_batch_contig_tree_k_impl( + return gemm_batch_contig_tree_k_impl( exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, depends); }
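
// A minimal, host-side sketch of the guarded vector load that patch 47/48
// generalizes: rather than a hard-coded two-element sycl::vec constructor,
// an unrolled loop fills an m_groups-wide vector and pads out-of-range
// lanes with the reduction identity, so the same k-threaded kernel body
// works for any m_groups. The names here (load_rhs_vec, identity, the
// row-major indexing sq * m + j) are illustrative assumptions, not dpctl's
// actual template parameters; compile with a SYCL compiler such as DPC++.
#include <sycl/sycl.hpp>

#include <cstddef>
#include <iostream>

template <typename T, int m_groups>
sycl::vec<T, m_groups> load_rhs_vec(const T *rhs,
                                    std::size_t sq, // current k index
                                    std::size_t j,  // starting m index
                                    std::size_t k,
                                    std::size_t m,
                                    T identity)
{
    sycl::vec<T, m_groups> v;
#pragma unroll
    for (int vec_idx = 0; vec_idx < m_groups; ++vec_idx) {
        // Guard both extents: lanes falling past the k or m edge receive the
        // identity value and therefore do not perturb the accumulation.
        v[vec_idx] = (sq < k && j + vec_idx < m)
                         ? rhs[sq * m + j + vec_idx]
                         : identity;
    }
    return v;
}

int main()
{
    constexpr int m_groups = 4;
    constexpr std::size_t k = 3, m = 6;
    float rhs[k * m];
    for (std::size_t i = 0; i < k * m; ++i) {
        rhs[i] = static_cast<float>(i);
    }

    // Starting at column j = 4 of row sq = 1 reads {10, 11, pad, pad}.
    auto v = load_rhs_vec<float, m_groups>(rhs, 1, 4, k, m, 0.0f);
    for (int i = 0; i < m_groups; ++i) {
        std::cout << v[i] << " ";
    }
    std::cout << std::endl;

    return 0;
}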
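
// A standalone sketch of the adjusted dispatch heuristic from patch 48/48:
// the scalar (m_groups = 1) k-threaded kernel now covers every result with
// fewer than 4 columns (previously only m == 1), and the k-dominant branch
// uses 4-wide vector loads (previously 2). The enum and function name are
// hypothetical stand-ins for the kernel selection performed inside
// gemm_impl and the related tree/batch entry points.
#include <cstddef>
#include <iostream>

enum class GemmVariant { ThreadK_MGroups1, ThreadK_MGroups4, ThreadNM };

GemmVariant pick_gemm_variant(std::size_t n, std::size_t k, std::size_t m)
{
    if (m < 4) {
        // Narrow results: thread over k with scalar accumulators.
        return GemmVariant::ThreadK_MGroups1;
    }
    else if (k > n && k > m) {
        // Reduction dimension dominates: thread over k, 4-wide rhs loads.
        return GemmVariant::ThreadK_MGroups4;
    }
    // Otherwise fall back to the n/m-threaded kernels.
    return GemmVariant::ThreadNM;
}

int main()
{
    std::cout << static_cast<int>(pick_gemm_variant(128, 4096, 2)) << "\n";  // 0
    std::cout << static_cast<int>(pick_gemm_variant(128, 4096, 64)) << "\n"; // 1
    std::cout << static_cast<int>(pick_gemm_variant(512, 64, 512)) << "\n";  // 2
    return 0;
}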