diff --git a/intel_pytorch_extension_py/ops/embeddingbag.py b/intel_pytorch_extension_py/ops/embeddingbag.py
index 2a0016d93..6b8a46528 100644
--- a/intel_pytorch_extension_py/ops/embeddingbag.py
+++ b/intel_pytorch_extension_py/ops/embeddingbag.py
@@ -3,6 +3,8 @@ from torch.autograd import Function
 import _torch_ipex as core
 
+'''
+# extension for BF16 fast path only
 torch_embedding_bag = torch.embedding_bag
 
 def embeddingbag(weights, inputs, offsets, scale_grad_by_freq, mode, sparse, per_sample_weights, include_last_offset):
     if weights.dtype == torch.float:
@@ -12,21 +14,41 @@ def embeddingbag(weights, inputs, offsets, scale_grad_by_freq, mode, sparse, per
         ret = (ret, None, None, None)
     else:
         assert(0, "unimplement embeddingbag path in extension")
-
+'''
+def embeddingbag(weights, indices, offsets, scale_grad_by_freq, mode, sparse, per_sample_weights, include_last_offset):
+    ret = EmbeddingBagFunction.apply(weights, indices, offsets, scale_grad_by_freq, mode, sparse, per_sample_weights, include_last_offset)
     return ret
 
 class EmbeddingBagFunction(Function):
+    '''
     @staticmethod
     def forward(ctx, weights, inputs, offsets):
         ctx.save_for_backward(weights, inputs, offsets)
         output = core.embedding_bag_forward(weights, inputs, offsets)
         return output
+    '''
+    @staticmethod
+    def forward(ctx, weight, indices, offsets, scale_grad_by_freq, mode, sparse, per_sample_weights, include_last_offset):
+        ctx.scale_grad_by_freq = scale_grad_by_freq
+        ctx.mode = mode
+        ctx.sparse = sparse
+        ctx.num_weight = weight.size(0)
+        ctx.save_for_backward(indices, offsets, per_sample_weights)
+        ret = core.embedding_bag_forward(weight, indices, offsets, scale_grad_by_freq, mode, sparse, per_sample_weights, include_last_offset)
+        return ret
 
+    '''
     @staticmethod
     def backward(ctx, grad_out):
         weights, inputs, offsets = ctx.saved_tensors
         grad_weight = core.embedding_bag_backward(grad_out, weights, inputs, offsets)
         return (grad_weight, None, None)
+    '''
+    @staticmethod
+    def backward(ctx, grad, offset2bag, bag_size, maximum_indices):
+        indices, offsets, per_sample_weights = ctx.saved_tensors
+        grad_weight = core.embedding_bag_backward(grad, indices, offsets, offset2bag, bag_size, maximum_indices, ctx.num_weight, ctx.scale_grad_by_freq, ctx.mode, ctx.sparse, per_sample_weights)
+        return grad_weight, None, None, None, None, None, None, None
 
 torch.embedding_bag = embeddingbag
diff --git a/torch_ipex/csrc/cpu/CMakeLists.txt b/torch_ipex/csrc/cpu/CMakeLists.txt
index a7c7751ee..bf76efe6e 100644
--- a/torch_ipex/csrc/cpu/CMakeLists.txt
+++ b/torch_ipex/csrc/cpu/CMakeLists.txt
@@ -1,4 +1,4 @@
-FILE(GLOB _CPU_SRCS *.cpp dbl/*.cpp bf16/*.cpp)
+FILE(GLOB _CPU_SRCS *.cpp dbl/*.cpp bf16/*.cpp aten/operators/*.cpp)
 LIST(APPEND DPCPP_CPU_SRCS ${_CPU_SRCS})
 
 # Pass to parent
diff --git a/torch_ipex/csrc/cpu/ExtendOPs.cpp b/torch_ipex/csrc/cpu/ExtendOPs.cpp
index 69d08d1bb..12e9d16b4 100644
--- a/torch_ipex/csrc/cpu/ExtendOPs.cpp
+++ b/torch_ipex/csrc/cpu/ExtendOPs.cpp
@@ -6,6 +6,7 @@
 #include "ExtendOPs.h"
 #include "bf16/vec/bf16_vec_kernel.h"
 #include "dil/dil.hpp"
+#include "aten/aten.hpp"
 #include "xsmm/libxsmm_utils.h"
 #include "../utils.h"
 #include "DevOPs.h"
@@ -323,6 +324,7 @@ std::vector<at::Tensor> AtenIpexTypeExt::interaction_backward(const at::Tensor &
   }
 }
 
+#if 0
 template<typename T>
 static inline at::Tensor _embedding_bag_forward(const at::Tensor &weights, const at::Tensor &inputs, const at::Tensor &offsets) {
   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(weights.is_contiguous());
@@ -422,6 +424,30 @@ at::Tensor AtenIpexTypeExt::embedding_bag_backward(const at::Tensor &grad_out,
     return _embedding_bag_backward(grad_out, weights, inputs, offsets);
   }
 }
+#endif
+
+std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
+AtenIpexTypeExt::embedding_bag_forward(const at::Tensor& weight, const at::Tensor& indices,
+    const at::Tensor& offsets, bool scale_grad_by_freq, int64_t mode, bool sparse,
+    const c10::optional<at::Tensor>& per_sample_weights, bool include_last_offset) {
+  at::Tensor _per_sample_weights;
+  if(per_sample_weights.has_value()) {
+    _per_sample_weights = per_sample_weights.value();
+  }
+  return cpu::aten::embedding_bag::embedding_bag_impl(weight, indices, offsets, scale_grad_by_freq, mode, sparse, _per_sample_weights, include_last_offset);
+}
+
+at::Tensor
+AtenIpexTypeExt::embedding_bag_backward(const at::Tensor& grad, const at::Tensor& indices,
+    const at::Tensor& offsets, const at::Tensor& offset2bag, const at::Tensor& bag_size, const at::Tensor& maximum_indices,
+    int64_t num_weights, bool scale_grad_by_freq, int64_t mode, bool sparse,
+    const c10::optional<at::Tensor>& per_sample_weights) {
+  at::Tensor _per_sample_weights;
+  if(per_sample_weights.has_value()) {
+    _per_sample_weights = per_sample_weights.value();
+  }
+  return cpu::aten::embedding_bag::embedding_bag_backward_impl(grad, indices, offsets, offset2bag, bag_size, maximum_indices, num_weights, scale_grad_by_freq, mode, sparse, _per_sample_weights);
+}
 
 at::Tensor AtenIpexTypeExt::linear(const at::Tensor& input, const at::Tensor& weight, const c10::optional<at::Tensor>& bias) {
   return cpu::AtenIpexCPUDev::dil_linear(input, weight, bias);
diff --git a/torch_ipex/csrc/cpu/ExtendOPs.h b/torch_ipex/csrc/cpu/ExtendOPs.h
index aa462d9a0..0b4bfc6df 100644
--- a/torch_ipex/csrc/cpu/ExtendOPs.h
+++ b/torch_ipex/csrc/cpu/ExtendOPs.h
@@ -10,8 +10,19 @@ class AtenIpexTypeExt {
   static void packed_add_(at::Tensor & top_half, at::Tensor & bot_half, const at::Tensor & grad, float alpha);
   static at::Tensor interaction_forward(const std::vector<at::Tensor> & input);
   static std::vector<at::Tensor> interaction_backward(const at::Tensor & grad_out, const std::vector<at::Tensor> & input);
-  static at::Tensor embedding_bag_forward(const at::Tensor &weights, const at::Tensor &inputs, const at::Tensor &offsets);
-  static at::Tensor embedding_bag_backward(const at::Tensor &grad_out, const at::Tensor &weights, const at::Tensor &inputs, const at::Tensor &offsets);
+  //static at::Tensor embedding_bag_forward(const at::Tensor &weights, const at::Tensor &inputs, const at::Tensor &offsets);
+  //static at::Tensor embedding_bag_backward(const at::Tensor &grad_out, const at::Tensor &weights, const at::Tensor &inputs, const at::Tensor &offsets);
+  static std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
+  embedding_bag_forward(const at::Tensor & weight, const at::Tensor & indices,
+      const at::Tensor & offsets, bool scale_grad_by_freq, int64_t mode, bool sparse,
+      const c10::optional<at::Tensor>& per_sample_weights, bool include_last_offset);
+
+  static at::Tensor
+  embedding_bag_backward(const at::Tensor & grad, const at::Tensor & indices, const at::Tensor & offsets,
+      const at::Tensor & offset2bag, const at::Tensor & bag_size, const at::Tensor & maximum_indices,
+      int64_t num_weights, bool scale_grad_by_freq, int64_t mode, bool sparse,
+      const c10::optional<at::Tensor>& per_sample_weights);
+
   static at::Tensor linear(const at::Tensor& input, const at::Tensor& weight, const c10::optional<at::Tensor>& bias);
   static std::tuple<at::Tensor, at::Tensor, at::Tensor> linear_backward(const at::Tensor& input, const at::Tensor& grad_output, const at::Tensor& weight, std::array<bool, 3> output_mask);
   static at::Tensor adaptive_avg_pool2d(at::Tensor const& input, at::IntArrayRef output_size);
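For reviewers who want to poke at the widened entry points directly, a minimal C++ smoke test might look like the sketch below. Only the signatures are taken from `ExtendOPs.h` above; the `torch_ipex` namespace qualification, the include path, and the tensor values are assumptions for illustration, and the test has to be built against libtorch plus the extension.

```cpp
// Illustrative smoke test, not part of the patch. Signatures come from
// ExtendOPs.h; namespace and include path are assumed.
#include <ATen/ATen.h>
#include "torch_ipex/csrc/cpu/ExtendOPs.h" // assumed include path

void embedding_bag_smoke_test() {
  auto weight  = at::randn({100, 16});                 // FP32 embedding table
  weight.requires_grad_(true);                         // so the fast path fills bag_size/offset2bag
  auto indices = at::tensor({1, 3, 3, 7}, at::kLong);  // 4 lookups
  auto offsets = at::tensor({0, 2}, at::kLong);        // 2 bags
  const int64_t mode_sum = 0;                          // MODE_SUM

  // forward returns (output, offset2bag, bag_size, max_indices), mirroring at::embedding_bag.
  auto fwd = torch_ipex::AtenIpexTypeExt::embedding_bag_forward(
      weight, indices, offsets,
      /*scale_grad_by_freq=*/false, mode_sum, /*sparse=*/false,
      /*per_sample_weights=*/c10::nullopt, /*include_last_offset=*/false);
  auto output = std::get<0>(fwd);

  // backward consumes the output gradient plus the bookkeeping tensors.
  auto grad = at::ones_like(output);
  auto grad_weight = torch_ipex::AtenIpexTypeExt::embedding_bag_backward(
      grad, indices, offsets, std::get<1>(fwd), std::get<2>(fwd), std::get<3>(fwd),
      weight.size(0), /*scale_grad_by_freq=*/false, mode_sum, /*sparse=*/false,
      c10::nullopt);
  (void)grad_weight; // expected shape: {100, 16}
}
```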
diff --git a/torch_ipex/csrc/cpu/aten/aten.hpp b/torch_ipex/csrc/cpu/aten/aten.hpp
new file mode 100644
index 000000000..b033cd3f1
--- /dev/null
+++ b/torch_ipex/csrc/cpu/aten/aten.hpp
@@ -0,0 +1,41 @@
+/*
+ *Copyright (c) 2018 Intel Corporation.
+ *
+ *Permission is hereby granted, free of charge, to any person obtaining a copy
+ *of this software and associated documentation files (the "Software"), to deal
+ *in the Software without restriction, including without limitation the rights
+ *to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ *copies of the Software, and to permit persons to whom the Software is
+ *furnished to do so, subject to the following conditions:
+ *
+ *The above copyright notice and this permission notice shall be included in
+ *all copies or substantial portions of the Software.
+ *
+ *THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ *IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ *FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ *AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ *LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ *OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+ *THE SOFTWARE.
+ *
+ */
+
+#ifndef _ATEN_HPP
+#define _ATEN_HPP
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include "operators/embedding_bag.hpp"
+
+#endif
diff --git a/torch_ipex/csrc/cpu/aten/operators/embedding_bag.cpp b/torch_ipex/csrc/cpu/aten/operators/embedding_bag.cpp
new file mode 100755
index 000000000..bab6988f1
--- /dev/null
+++ b/torch_ipex/csrc/cpu/aten/operators/embedding_bag.cpp
@@ -0,0 +1,313 @@
+#include "embedding_bag.hpp"
+#include "aten_ipex_bridge.h"
+#include "cpu/bf16/vec/bf16_vec_kernel.h"
+
+namespace torch_ipex {
+namespace cpu {
+namespace aten {
+namespace embedding_bag {
+
+const int MODE_SUM = 0;
+const int MODE_MEAN = 1;
+const int MODE_MAX = 2;
+
+static inline void
+make_offset2bag(const at::Tensor &offsets, const at::Tensor &indices, at::Tensor& offset2bag) {
+  offset2bag.index_add_(0, offsets, at::ones_like(offsets)); // offset2bag = [1 0 1 0 1]
+  offset2bag[0] -= 1;                                        // offset2bag = [0 0 1 0 1]
+  offset2bag = offset2bag.cumsum(0);                         // offset2bag = [0 0 1 1 2]
+}
+
+// To save compute, if we are going to go down the fast path case for the 'sum'
+// mode, we skip calculating offset2bag, since it is not going to be used.
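The `index_add_`/`cumsum` trick in `make_offset2bag` can be reproduced in isolation. The sketch below (illustrative, not part of the patch) uses libtorch directly and mirrors the `[1 0 1 0 1] -> [0 0 1 1 2]` example from the comments above.

```cpp
// Standalone sketch of the make_offset2bag trick; values mirror the comments above.
#include <torch/torch.h>
#include <iostream>

int main() {
  // 5 lookups grouped into 3 bags starting at positions 0, 2 and 4.
  auto indices = torch::tensor({10, 3, 7, 7, 42}, torch::kLong);
  auto offsets = torch::tensor({0, 2, 4}, torch::kLong);

  // One slot per lookup (the patch allocates one extra slot and resizes afterwards).
  auto offset2bag = torch::zeros({indices.size(0)}, offsets.options());
  offset2bag.index_add_(0, offsets, torch::ones_like(offsets)); // [1, 0, 1, 0, 1]
  offset2bag[0] -= 1;                                           // [0, 0, 1, 0, 1]
  offset2bag = offset2bag.cumsum(0);                            // [0, 0, 1, 1, 2]

  std::cout << offset2bag << std::endl; // maps each lookup position to its bag id
  return 0;
}
```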
+static inline bool is_bfloat16_tensor(const at::Tensor tensor_) {
+  if (tensor_.scalar_type() == at::kBFloat16) return true;
+  return false;
+}
+
+static inline bool embedding_bag_fast_path_sum(const at::Tensor weight, const at::Tensor per_sample_weights, int64_t mode) {
+  if ((mode != MODE_SUM) || (weight.stride(1) != 1) || per_sample_weights.defined()) return false;
+  if ((weight.scalar_type() != at::kFloat) && (weight.scalar_type() != at::kBFloat16)) return false;
+  return true;
+}
+
+template<typename T>
+static inline at::Tensor _embedding_bag_index_add_select_fast(const at::Tensor select_indices,
+    const at::Tensor src, const at::Tensor offsets, bool include_last_offset) {
+  int64_t ddim = src.size(1);
+  auto* src_data = src.data_ptr<T>();
+  int64_t output_size = offsets.numel() - 1;
+  auto* offsets_data = offsets.data_ptr<int64_t>();
+  std::vector<int64_t> offsets_include_last;
+
+  if (!include_last_offset) {
+    output_size = offsets.numel();
+    offsets_include_last.resize(output_size + 1);
+    std::memcpy(offsets_include_last.data(), offsets_data, sizeof(int64_t) * output_size);
+    offsets_include_last[output_size] = select_indices.numel();
+    offsets_data = offsets_include_last.data();
+  }
+
+  at::Tensor output = at::empty({output_size, src.size(1)}, src.options());
+  auto* output_data = output.data_ptr<T>();
+  auto indices_accessor = select_indices.accessor<int64_t, 1>();
+  at::parallel_for(0, output_size, 16, [&](int64_t start, int64_t end) {
+    for (int64_t i = start; i < end; i++) {
+      auto* out_data_ptr = &output_data[i * ddim];
+      zero_ker((T*)out_data_ptr, ddim);
+      auto inputs_start = offsets_data[i];
+      auto inputs_end = offsets_data[i + 1];
+      for (int64_t s = inputs_start; s < inputs_end; s++) {
+        T* select_data_ptr = &src_data[indices_accessor[s] * ddim];
+        add_ker((T *)out_data_ptr, (T *)select_data_ptr, ddim);
+      }
+    }
+  });
+
+  return output;
+}
+
+std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
+embedding_bag_impl(const at::Tensor & weight, const at::Tensor & indices,
+    const at::Tensor & offsets, bool scale_grad_by_freq, int64_t mode, bool sparse,
+    const at::Tensor & per_sample_weights, bool include_last_offset) {
+
+  at::Tensor offsets_ = offsets.contiguous();
+  if (embedding_bag_fast_path_sum(weight, per_sample_weights, mode)) {
+    at::Tensor bag_size;
+    at::Tensor offset2bag;
+    if (weight.requires_grad()) {
+      // in MODE_SUM, only initialize bag_size if we need gradients
+      bag_size = at::native::full(offsets_.sizes(), 0, indices.options());
+      offset2bag = at::empty({0}, offsets_.options());
+    }
+
+    at::Tensor output;
+    if (is_bfloat16_tensor(weight)) {
+      output = _embedding_bag_index_add_select_fast<at::BFloat16>(indices, weight, offsets_, include_last_offset);
+    } else {
+      output = _embedding_bag_index_add_select_fast<float>(indices, weight, offsets_, include_last_offset);
+    }
+    return std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>(output, offset2bag, bag_size, bag_size);
+  }
+
+  // May need full support for BFloat16
+  auto&& _ipex_weight = bridge::shallowFallbackToCPUTensor(weight);
+  auto&& _ipex_indices = bridge::shallowFallbackToCPUTensor(indices);
+  auto&& _ipex_offsets = bridge::shallowFallbackToCPUTensor(offsets_);
+  auto&& _ipex_per_sample_weights = bridge::shallowFallbackToCPUTensor(per_sample_weights);
+  auto&& _ipex_result = at::embedding_bag(_ipex_weight, _ipex_indices, _ipex_offsets, scale_grad_by_freq, mode, sparse, _ipex_per_sample_weights, include_last_offset);
+  static_cast<void>(_ipex_result); // Avoid warnings in case not used
+  return std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>(
+      bridge::shallowUpgradeToDPCPPTensor(std::get<0>(_ipex_result)),
+      bridge::shallowUpgradeToDPCPPTensor(std::get<1>(_ipex_result)),
+      bridge::shallowUpgradeToDPCPPTensor(std::get<2>(_ipex_result)),
+      bridge::shallowUpgradeToDPCPPTensor(std::get<3>(_ipex_result)));
+}
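Stripped of vectorization and threading, the fast-path forward above is a per-bag gather-and-sum delimited by `offsets`. A plain scalar sketch of the same loop structure, including the `include_last_offset` handling (names are local to the example):

```cpp
// Scalar sketch of what _embedding_bag_index_add_select_fast computes.
// "emb_dim" plays the role of ddim in the patch.
#include <cstdint>
#include <vector>

// weight: [num_rows x emb_dim] row-major, indices: flat lookup list,
// offsets: bag start positions. If include_last_offset is false, a final
// offset equal to indices.size() is appended, mirroring offsets_include_last.
std::vector<float> bag_sum(const std::vector<float>& weight, int64_t emb_dim,
                           const std::vector<int64_t>& indices,
                           std::vector<int64_t> offsets,
                           bool include_last_offset) {
  if (!include_last_offset) {
    offsets.push_back(static_cast<int64_t>(indices.size()));
  }
  const int64_t num_bags = static_cast<int64_t>(offsets.size()) - 1;
  std::vector<float> out(num_bags * emb_dim, 0.0f);            // zero_ker equivalent
  for (int64_t bag = 0; bag < num_bags; ++bag) {
    for (int64_t s = offsets[bag]; s < offsets[bag + 1]; ++s) {
      const float* row = &weight[indices[s] * emb_dim];
      float* dst = &out[bag * emb_dim];
      for (int64_t d = 0; d < emb_dim; ++d) {                  // add_ker equivalent
        dst[d] += row[d];
      }
    }
  }
  return out;
}
```

The real kernel replaces the inner loops with `zero_ker`/`add_ker` and parallelizes over bags with `at::parallel_for`.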
+
+static inline at::Tensor expand_values_if_needed(const at::Tensor& values) {
+  // expand
+  if (values.dim() == 0) {
+    // Mimic Numpy behavior here and treat it as a 1D tensor
+    return values.expand({1});
+  }
+
+  return values;
+}
+
+static inline
+at::Tensor _sparse_coo_tensor_unsafe(const at::Tensor& indices, const at::Tensor& values_, c10::ArrayRef<int64_t> size, const at::TensorOptions& options) {
+
+  at::Tensor values = expand_values_if_needed(values_);
+  assert(options.has_layout() && options.layout() == c10::kSparse);
+  int64_t sparse_dim = indices.size(0);
+  int64_t dense_dim = values.dim() - 1;
+  return at::native::new_with_dims_and_tensor_sparse(sparse_dim, dense_dim, size, indices, values, values.options().layout(c10::kSparse));
+}
+
+template<typename T>
+static inline at::Tensor embedding_bag_sparse_backward_sum_fast(
+    const at::Tensor grad, const at::Tensor indices,
+    const at::Tensor offsets, int num_weights, int mode) {
+
+  assert((mode == MODE_SUM) && (grad.stride(1) == 1));
+
+  int64_t indices_size0 = indices.size(0);
+  int64_t ddim = grad.size(1);
+  at::Tensor index_grad = at::empty({indices_size0, ddim}, grad.options());
+  int grad_stride0 = grad.stride(0);
+
+  auto offsets_accessor = offsets.accessor<int64_t, 1>();
+  auto offset_numel = offsets.numel();
+
+  T* gradout_data = index_grad.data_ptr<T>();
+  T* grad_data = grad.data_ptr<T>();
+  at::parallel_for(0, offset_numel, 16, [&](int64_t start, int64_t end) {
+    for(auto mb = start; mb < end; mb++) {
+      int64_t select_off_start = offsets_accessor[mb];
+      int64_t select_off_end = (mb < (offset_numel - 1) ? offsets_accessor[mb + 1] : indices_size0);
+      auto grad_block = grad_data + grad_stride0 * mb;
+      for (int64_t s = select_off_start; s < select_off_end; s++) {
+        move_ker((T*)(gradout_data + ddim * s), (T*)grad_block, ddim);
+      }
+    }
+  });
+
+  int64_t num_features = index_grad.size(-1);
+  auto weight_size = std::array<int64_t, 2>{{ num_weights, num_features }};
+  auto dense_options = index_grad.options();
+
+  if (index_grad.numel() == 0) {
+    return _sparse_coo_tensor_unsafe(at::empty({1, 0}, indices.options()),
+                                     at::empty({0, num_features}, dense_options),
+                                     weight_size);
+  }
+
+  auto index = indices.reshape({1, -1});
+  auto values = index_grad.reshape({-1, num_features});
+
+  return _sparse_coo_tensor_unsafe(index, values, weight_size);
+}
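What `embedding_bag_sparse_backward_sum_fast` builds is an uncoalesced COO gradient: one value row per lookup, with the raw (possibly repeated) indices as the COO index. A libtorch sketch of the same idea, using the public `torch::sparse_coo_tensor` instead of the internal `_sparse_coo_tensor_unsafe` helper (illustrative, not the patch's code path):

```cpp
// Illustrative only: builds the sum-mode sparse gradient the slow way.
#include <torch/torch.h>

torch::Tensor sparse_sum_grad(const torch::Tensor& grad,    // [num_bags, dim]
                              const torch::Tensor& indices, // [n] lookups (int64)
                              const torch::Tensor& offsets, // [num_bags] bag starts (int64)
                              int64_t num_weights) {
  const int64_t n = indices.size(0);
  const int64_t dim = grad.size(1);
  auto values = torch::empty({n, dim}, grad.options());
  auto offs = offsets.accessor<int64_t, 1>();
  for (int64_t bag = 0; bag < offsets.size(0); ++bag) {
    int64_t end = (bag + 1 < offsets.size(0)) ? offs[bag + 1] : n;
    for (int64_t s = offs[bag]; s < end; ++s) {
      // each lookup inherits its bag's output gradient (move_ker in the patch)
      values.select(0, s).copy_(grad.select(0, bag));
    }
  }
  // Duplicate indices are allowed: the result stays uncoalesced, just like the
  // fast path, and rows for the same index are summed once it is coalesced.
  return torch::sparse_coo_tensor(indices.reshape({1, -1}), values,
                                  {num_weights, dim});
}
```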
+
+static inline int64_t
+count_and_map_uniq(const at::TensorAccessor<int64_t, 1>& indices_accessor, int64_t indices_length, std::vector<int64_t>& indices_to_index, std::vector<int64_t>& index_to_indices) {
+  int64_t u = 0;
+  for (int64_t i = 0; i < indices_length; i++) {
+    int64_t indices = indices_accessor[i];
+    if (indices_to_index[indices] == -1ull) {
+      indices_to_index[indices] = u;
+      index_to_indices[u] = indices;
+      u++;
+    }
+  }
+  return u;
+}
+
+template<typename T>
+static inline at::Tensor embedding_bag_dense_backward_sum_fast(const at::Tensor grad, const at::Tensor indices, const at::Tensor offsets, const at::Tensor offset2bag, int num_weights, int mode) {
+
+  assert((mode == MODE_SUM) && (grad.stride(1) == 1));
+
+  auto offset_numel = offsets.numel();
+  int64_t indices_numel = indices.numel();
+  auto indices_accessor = indices.accessor<int64_t, 1>();
+  std::vector<int64_t> indices_to_index(num_weights, -1ull);
+  std::vector<int64_t> index_to_indices;
+  index_to_indices.resize(num_weights);
+  int64_t unique_indices = count_and_map_uniq(indices_accessor, indices_numel, indices_to_index, index_to_indices);
+
+  int max_threads = at::get_num_threads();
+  max_threads = (unique_indices < max_threads) ? unique_indices : max_threads;
+  int64_t avg_chunk_down = unique_indices / max_threads;
+  std::vector<int64_t> chuck_size(max_threads);
+  std::vector<int64_t> chuck_sum_size(max_threads + 1);
+  for (auto i = 0; i < max_threads; i++) {
+    chuck_size[i] = avg_chunk_down;
+  }
+  // balance the chunks across threads by spreading the remainder (e.g. a 2/1/1 split)
+  for (auto i = 0 ; i < unique_indices % max_threads ; i++) {
+    chuck_size[i] += 1;
+  }
+  chuck_sum_size[0] = 0;
+  for (auto i = 1; i < max_threads; i++) {
+    chuck_sum_size[i] = chuck_sum_size[i - 1] + chuck_size[i - 1];
+  }
+  chuck_sum_size[max_threads] = unique_indices;
+
+  int64_t ddim = grad.size(1);
+
+  at::Tensor index_grad_weight = at::empty({num_weights, ddim}, grad.options());
+  T* gradout_data = index_grad_weight.data_ptr<T>();
+  zero_ker((T*)gradout_data, num_weights * ddim);
+
+  std::vector<float> temp_grad_weight(unique_indices * ddim);
+  float* temp_output = temp_grad_weight.data();
+  zero_ker(temp_output, unique_indices * ddim);
+
+  int64_t* offset2bag_data = offset2bag.data_ptr<int64_t>();
+  T* grad_data = grad.data_ptr<T>();
+  at::parallel_for(0, max_threads, 0, [&](int64_t start, int64_t end) {
+    for (int k = start; k < end; k++) {
+      int64_t chunk_start = chuck_sum_size[k];
+      int64_t chunk_end = chuck_sum_size[k + 1];
+      for (int64_t mb = 0; mb < indices_numel; mb++) {
+        int64_t indices_num = indices_accessor[mb];
+        int64_t index = indices_to_index[indices_num];
+        if (index >= chunk_start && index < chunk_end) {
+          auto s = offset2bag_data[mb];
+          add_ker((float*)(temp_output + index * ddim), (T*)(grad_data + s * ddim), ddim);
+        }
+      }
+      for (int64_t index = chunk_start; index < chunk_end; index++) {
+        auto indices = index_to_indices[index];
+        move_ker((T*)(gradout_data + indices * ddim), (float*)(temp_output + index * ddim), ddim);
+      }
+    }
+  });
+
+  return index_grad_weight;
+}
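The dense fast path divides the unique rows into one contiguous chunk per thread: floor(U/T) rows each, with the remainder spread one-per-thread starting from thread 0, and an exclusive prefix sum marking the boundaries. A standalone sketch of just that arithmetic (names local to the example):

```cpp
// Standalone sketch of the chunk balancing used in embedding_bag_dense_backward_sum_fast.
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  const int64_t unique_rows = 10; // e.g. 10 distinct embedding rows to reduce
  const int64_t threads = 4;      // worker count (at::get_num_threads() in the patch)

  std::vector<int64_t> chunk(threads, unique_rows / threads); // [2,2,2,2]
  for (int64_t i = 0; i < unique_rows % threads; ++i) {
    chunk[i] += 1;                                            // [3,3,2,2]
  }

  // Exclusive prefix sum: boundaries[t]..boundaries[t+1] is thread t's range.
  std::vector<int64_t> boundaries(threads + 1, 0);
  for (int64_t t = 0; t < threads; ++t) {
    boundaries[t + 1] = boundaries[t] + chunk[t];
  }
  for (auto b : boundaries) std::cout << b << ' '; // 0 3 6 8 10
  std::cout << '\n';
  return 0;
}
```

Because each thread owns a disjoint, contiguous range of unique rows, no two threads ever write the same output row and no atomics are needed.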
+
+static inline bool embedding_bag_backward_fast_path_sum(const at::Tensor grad, const at::Tensor indices, const at::Tensor per_sample_weights, bool scale_grad_by_freq, int64_t mode) {
+
+  if ((grad.scalar_type() != at::kFloat) && (grad.scalar_type() != at::kBFloat16)) return false;
+  if ((mode != MODE_SUM) || (grad.stride(1) != 1)) return false;
+  if (per_sample_weights.defined() || scale_grad_by_freq) return false;
+
+  return true;
+}
+
+static inline at::Tensor
+embedding_bag_get_offset2bag(const at::Tensor indices, const at::Tensor & offsets, const at::Tensor & offset2bag)
+{
+  auto offset_numel = offsets.numel();
+  int64_t indices_numel = indices.numel();
+  at::Tensor offset2bag_;
+  if (indices_numel != 0 && offset2bag.numel() == 0) {
+    if (indices_numel != offset_numel) {
+      offset2bag_ = at::native::full({indices.sizes()[0] + 1}, 0, indices.options());
+      make_offset2bag(offsets, indices, offset2bag_);
+      offset2bag_.resize_({indices.sizes()[0]});
+    } else {
+      offset2bag_ = offsets.contiguous();
+    }
+  } else {
+    offset2bag_ = offset2bag;
+  }
+  return offset2bag_;
+}
+
+at::Tensor embedding_bag_backward_impl(const at::Tensor & grad, const at::Tensor & indices,
+    const at::Tensor & offsets, const at::Tensor & offset2bag, const at::Tensor & bag_size, const at::Tensor & maximum_indices,
+    int64_t num_weights, bool scale_grad_by_freq, int64_t mode, bool sparse,
+    const at::Tensor & per_sample_weights) {
+  if (sparse) {
+    if (embedding_bag_backward_fast_path_sum(grad, indices, per_sample_weights, scale_grad_by_freq, mode)) {
+      if (is_bfloat16_tensor(grad)) {
+        return embedding_bag_sparse_backward_sum_fast<at::BFloat16>(grad, indices, offsets, num_weights, mode);
+      } else {
+        return embedding_bag_sparse_backward_sum_fast<float>(grad, indices, offsets, num_weights, mode);
+      }
+    } else {
+      // May need full support for BFloat16
+      at::Tensor offset2bag_ = embedding_bag_get_offset2bag(indices, offsets, offset2bag);
+      return at::_embedding_bag_sparse_backward(grad, indices, offsets, offset2bag_,
+          bag_size, num_weights, scale_grad_by_freq, mode, per_sample_weights);
+    }
+  } else {
+    at::Tensor offset2bag_ = embedding_bag_get_offset2bag(indices, offsets, offset2bag);
+    auto grad_c = grad.contiguous();
+    if (embedding_bag_backward_fast_path_sum(grad_c, indices, per_sample_weights, scale_grad_by_freq, mode)) {
+      if (is_bfloat16_tensor(grad)) {
+        return embedding_bag_dense_backward_sum_fast<at::BFloat16>(grad_c, indices, offsets, offset2bag_, num_weights, mode);
+      } else {
+        return embedding_bag_dense_backward_sum_fast<float>(grad_c, indices, offsets, offset2bag_, num_weights, mode);
+      }
+    } else {
+      // May need full support for BFloat16
+      return at::_embedding_bag_dense_backward(grad_c, indices, offsets, offset2bag_, bag_size,
+          maximum_indices, num_weights, scale_grad_by_freq, mode, per_sample_weights);
+    }
+  }
+}
+
+} // namespace embedding_bag
+} // namespace aten
+} // namespace cpu
+} // namespace torch_ipex
diff --git a/torch_ipex/csrc/cpu/aten/operators/embedding_bag.hpp b/torch_ipex/csrc/cpu/aten/operators/embedding_bag.hpp
new file mode 100755
index 000000000..d85be761b
--- /dev/null
+++ b/torch_ipex/csrc/cpu/aten/operators/embedding_bag.hpp
@@ -0,0 +1,28 @@
+#pragma once
+
+#include
+#include
+#include
+
+#include
+
+
+namespace torch_ipex {
+namespace cpu {
+namespace aten {
+namespace embedding_bag {
+
+std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
+embedding_bag_impl(const at::Tensor & weight, const at::Tensor & indices,
+    const at::Tensor & offsets, bool scale_grad_by_freq, int64_t mode, bool sparse,
+    const at::Tensor & per_sample_weights, bool include_last_offset);
+
+at::Tensor embedding_bag_backward_impl(const at::Tensor & grad, const at::Tensor & indices,
+    const at::Tensor & offsets, const at::Tensor & offset2bag, const at::Tensor & bag_size, const at::Tensor & maximum_indices,
+    int64_t num_weights, bool scale_grad_by_freq, int64_t mode, bool sparse,
+    const at::Tensor & per_sample_weights);
+
+} // namespace embedding_bag
+} // namespace aten
+} // namespace cpu
+} // namespace torch_ipex
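A recurring choice in both the dense backward above and the BF16 kernels in the next file is to accumulate in FP32 scratch and convert to BF16 only once per output row. The standalone sketch below (software-emulated BF16 rounding, not the AVX-512 path used in the patch) shows the drift that repeated BF16 accumulation introduces.

```cpp
// Illustration of FP32 accumulation vs. accumulating directly in BF16.
#include <cstdint>
#include <cstring>
#include <iostream>

// Round-to-nearest-even conversion float -> bfloat16 -> float.
static float to_bf16(float x) {
  uint32_t bits;
  std::memcpy(&bits, &x, sizeof(bits));
  uint32_t lsb = (bits >> 16) & 1u;
  bits = (bits + 0x7FFFu + lsb) & 0xFFFF0000u;
  float y;
  std::memcpy(&y, &bits, sizeof(y));
  return y;
}

int main() {
  const int n = 10000;
  const float small = 0.0001f; // true sum is ~1.0

  float naive = 0.0f;    // accumulate directly in BF16
  float scratch = 0.0f;  // accumulate in FP32, convert once at the end
  for (int i = 0; i < n; ++i) {
    naive = to_bf16(naive + to_bf16(small));
    scratch += to_bf16(small);
  }
  std::cout << "bf16 accumulation:             " << naive << '\n';
  std::cout << "fp32 scratch + one conversion: " << to_bf16(scratch) << '\n';
  return 0;
}
```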
diff --git a/torch_ipex/csrc/cpu/bf16/vec/bf16_vec_kernel.h b/torch_ipex/csrc/cpu/bf16/vec/bf16_vec_kernel.h
index f58620a55..fd90899a9 100644
--- a/torch_ipex/csrc/cpu/bf16/vec/bf16_vec_kernel.h
+++ b/torch_ipex/csrc/cpu/bf16/vec/bf16_vec_kernel.h
@@ -42,6 +42,7 @@ inline void packed_bf16_add_ker(at::BFloat16 *a1, at::BFloat16 *a2, at::BFloat16
 
 inline void add_ker(at::BFloat16 *inout, at::BFloat16 *in, int len) {
   int i = 0;
+  #pragma unroll(2)
   for(; i < len - 15; i += 16) {
     auto x1 = cvt_bf16_to_fp32(_mm256_loadu_si256((__m256i*)(inout + i)));
     auto x2 = cvt_bf16_to_fp32(_mm256_loadu_si256((__m256i*)(in + i)));
@@ -59,12 +60,25 @@ inline void add_ker(at::BFloat16 *inout, at::BFloat16 *in, int len) {
 
 inline void add_ker(float *inout, float *in, int len) {
   int i = 0;
-  for(; i < len - 15; i += 16) {
-    auto x1 = _mm512_loadu_ps(inout + i);
-    auto x2 = _mm512_loadu_ps(in + i);
-    x1 = _mm512_add_ps(x1, x2);
-    _mm512_storeu_ps(inout + i, x1);
+  for(; i < len - 31; i += 32) {
+    auto out1 = _mm512_loadu_ps(inout + i);
+    auto out2 = _mm512_loadu_ps(inout + i + 16);
+    auto in1 = _mm512_loadu_ps(in + i);
+    auto in2 = _mm512_loadu_ps(in + i + 16);
+    out1 = _mm512_add_ps(out1, in1);
+    out2 = _mm512_add_ps(out2, in2);
+    _mm512_storeu_ps(inout + i, out1);
+    _mm512_storeu_ps(inout + i + 16, out2);
+  }
+
+  if (i < len - 15) {
+    auto out1 = _mm512_loadu_ps(inout + i);
+    auto in1 = _mm512_loadu_ps(in + i);
+    out1 = _mm512_add_ps(out1, in1);
+    _mm512_storeu_ps(inout + i, out1);
+    i += 16;
   }
+
   if(i < len) {
     auto mask = (1 << (len - i)) - 1;
     auto x1 = _mm512_maskz_loadu_ps(mask, inout + i);
@@ -73,3 +87,125 @@ inline void add_ker(float *inout, float *in, int len) {
     _mm512_mask_storeu_ps(inout + i, mask, x1);
   }
 }
+
+inline void add_ker(float *inout, at::BFloat16 *in, int len) {
+  int i = 0;
+  for(; i < len - 31; i += 32) {
+    auto out1 = _mm512_loadu_ps(inout + i);
+    auto out2 = _mm512_loadu_ps(inout + i + 16);
+    auto in1 = cvt_bf16_to_fp32(_mm256_loadu_si256((__m256i*)(in + i)));
+    auto in2 = cvt_bf16_to_fp32(_mm256_loadu_si256((__m256i*)(in + i + 16)));
+    out1 = _mm512_add_ps(out1, in1);
+    out2 = _mm512_add_ps(out2, in2);
+    _mm512_storeu_ps(inout + i, out1);
+    _mm512_storeu_ps(inout + i + 16, out2);
+  }
+
+  if (i < len - 15) {
+    auto out1 = _mm512_loadu_ps(inout + i);
+    auto in1 = cvt_bf16_to_fp32(_mm256_loadu_si256((__m256i*)(in + i)));
+    out1 = _mm512_add_ps(out1, in1);
+    _mm512_storeu_ps(inout + i, out1);
+    i += 16;
+  }
+
+  if(i < len) {
+    auto mask = (1 << (len - i)) - 1;
+    auto x1 = _mm512_maskz_loadu_ps(mask, inout + i);
+    auto x2 = cvt_bf16_to_fp32(_mm256_maskz_loadu_epi16(mask, in + i));
+    x1 = _mm512_add_ps(x1, x2);
+    _mm512_mask_storeu_ps(inout + i, mask, x1);
+  }
+}
+
+static inline void move_ker(at::BFloat16 *out, float *in, int64_t len) {
+  int64_t i = 0;
+  for (; i < len - 31; i += 32) {
+    auto in0 = cvt_fp32_to_bf16(_mm512_loadu_ps(in + i));
+    auto in1 = cvt_fp32_to_bf16(_mm512_loadu_ps(in + i + 16));
+    _mm256_storeu_si256((__m256i *)(out + i), in0);
+    _mm256_storeu_si256((__m256i *)(out + i + 16), in1);
+  }
+
+  if (i < len - 15) {
+    auto in0 = cvt_fp32_to_bf16(_mm512_loadu_ps(in + i));
+    _mm256_storeu_si256((__m256i *)(out + i), in0);
+    i += 16;
+  }
+
+  if (i < len) {
+    auto mask = ((1 << (len - i)) - 1);
+    auto in0 = cvt_fp32_to_bf16(_mm512_maskz_loadu_ps(mask, in + i));
+    _mm256_mask_storeu_epi16((__m256i *)(out + i), mask, in0);
+  }
+}
+
+static inline void move_ker(float *out, float *in, int64_t len) {
+  int64_t i = 0;
+  for (; i < len - 31 ; i += 32) {
+    auto in0 = _mm512_loadu_ps(in + i);
+    auto in1 = _mm512_loadu_ps(in + i + 16);
+    _mm512_storeu_ps(out + i, in0);
+    _mm512_storeu_ps(out + i + 16, in1);
+  }
+
+  if (i < len - 15) {
+    auto in0 = _mm512_loadu_ps(in + i);
+    _mm512_storeu_ps(out + i, in0);
+    i += 16;
+  }
+
+  if (i < len) {
+    auto mask = ((1 << (len - i)) - 1);
+    auto in0 = _mm512_maskz_loadu_ps(mask, in + i);
+    _mm512_mask_storeu_ps(out + i, mask, in0);
+  }
+}
+
+static inline void move_ker(at::BFloat16 *out, at::BFloat16 *in, int64_t len) {
+  int64_t i = 0;
+  for (; i < len - 63; i += 64) {
+    auto in0 = _mm512_loadu_si512(in + i);
+    auto in1 = _mm512_loadu_si512(in + i + 32);
+    _mm512_storeu_si512(out + i, in0);
+    _mm512_storeu_si512(out + i + 32, in1);
+  }
+
+  if (i < len - 31) {
+    auto in0 = _mm512_loadu_si512(in + i);
+    _mm512_storeu_si512(out + i, in0);
+    i += 32;
+  }
+
+  if (i < len) {
+    auto mask = (1 << (len - i)) - 1;
+    auto in0 = _mm512_maskz_loadu_epi16(mask, in + i);
+    _mm512_mask_storeu_epi16(out + i, mask, in0);
+  }
+}
+
+static inline void zero_ker(float *out, int64_t len) {
+  int64_t i = 0;
+  __m512 zero_512 = _mm512_setzero_ps();
+  for (; i < len - 15; i += 16) {
+    _mm512_storeu_ps(out + i, zero_512);
+  }
+
+  if (i < len) {
+    auto mask = ((1 << (len - i)) - 1);
+    _mm512_mask_storeu_ps(out + i, mask, zero_512);
+  }
+}
+
+static inline void zero_ker(at::BFloat16 *out, int64_t len) {
+  int64_t i = 0;
+  __m512i zero_512 = _mm512_setzero_si512();
+  for (; i < len - 31; i += 32) {
+    _mm512_storeu_si512(out + i, zero_512);
+  }
+
+  if (i < len) {
+    auto mask = ((1 << (len - i)) - 1);
+    _mm512_mask_storeu_epi16(out + i, mask, zero_512);
+  }
+}
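All of the kernels above share the same remainder handling: full 16-float (or 32-bfloat16) vectors first, then a single masked load/store for the tail. A minimal standalone example of that pattern (assumes an AVX-512F capable CPU and `-mavx512f`; names local to the example):

```cpp
// Masked-remainder pattern used by add_ker/move_ker/zero_ker, in isolation.
#include <immintrin.h>
#include <cstdio>

void add_tail_masked(float* inout, const float* in, int len) {
  int i = 0;
  for (; i < len - 15; i += 16) {            // full 512-bit lanes
    __m512 a = _mm512_loadu_ps(inout + i);
    __m512 b = _mm512_loadu_ps(in + i);
    _mm512_storeu_ps(inout + i, _mm512_add_ps(a, b));
  }
  if (i < len) {                             // 1..15 leftover floats
    __mmask16 mask = (1 << (len - i)) - 1;   // low bits select the valid lanes
    __m512 a = _mm512_maskz_loadu_ps(mask, inout + i); // masked-off lanes read as 0
    __m512 b = _mm512_maskz_loadu_ps(mask, in + i);
    _mm512_mask_storeu_ps(inout + i, mask, _mm512_add_ps(a, b)); // only valid lanes written
  }
}

int main() {
  float x[21], y[21];
  for (int k = 0; k < 21; ++k) { x[k] = 1.0f; y[k] = 0.5f; }
  add_tail_masked(x, y, 21);
  std::printf("%f %f\n", x[0], x[20]); // both 1.5
  return 0;
}
```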
i, mask, zero_512); + } +} + +static inline void zero_ker(at::BFloat16 *out, int64_t len) { + int64_t i = 0; + __m512i zero_512 = _mm512_setzero_si512(); + for (; i < len - 31; i += 32) { + _mm512_storeu_si512(out + i, zero_512); + } + + if (i < len) { + auto mask = ((1 << (len - i)) - 1); + _mm512_mask_storeu_epi16(out + i, mask, zero_512); + } +} diff --git a/torch_ipex/csrc/init_python_bindings.cpp b/torch_ipex/csrc/init_python_bindings.cpp index 6cb5b7c2f..26f6f02f8 100644 --- a/torch_ipex/csrc/init_python_bindings.cpp +++ b/torch_ipex/csrc/init_python_bindings.cpp @@ -80,14 +80,15 @@ void InitIpexModuleBindings(py::module m) { return AtenIpexTypeExt::interaction_backward(grad_out, input); }); m.def("embedding_bag_forward", - [](const at::Tensor &weights, const at::Tensor &inputs, const at::Tensor &offsets) { - return AtenIpexTypeExt::embedding_bag_forward(weights, inputs, offsets); + [](const at::Tensor& weight, const at::Tensor& indices, const at::Tensor& offsets, bool scale_grad_by_freq, int64_t mode, bool sparse, const c10::optional& per_sample_weights, bool include_last_offset) { + return AtenIpexTypeExt::embedding_bag_forward(weight, indices, offsets, scale_grad_by_freq, mode, sparse, per_sample_weights, include_last_offset); }); + m.def("embedding_bag_backward", - [](const at::Tensor &grad_out, const at::Tensor &weights, - const at::Tensor &inputs, const at::Tensor &offsets) { - return AtenIpexTypeExt::embedding_bag_backward(grad_out, weights, inputs, offsets); + [](const at::Tensor& grad, const at::Tensor& indices, const at::Tensor& offsets, const at::Tensor offset2bag, const at::Tensor& bag_size, const at::Tensor& maximum_indices, int64_t num_weights, bool scale_grad_by_freq, int64_t mode, bool sparse, const c10::optional& per_sample_weights) { + return AtenIpexTypeExt::embedding_bag_backward(grad, indices, offsets, offset2bag, bag_size, maximum_indices, num_weights, scale_grad_by_freq, mode, sparse, per_sample_weights); }); + m.def("linear", [](const at::Tensor& input, const at::Tensor& weight, const c10::optional& bias) { return AtenIpexTypeExt::linear(input, weight, bias);