Skip to content

Embedding Bag Sum Fast Support for FP32 and BFloat16 #23

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 23 additions & 1 deletion intel_pytorch_extension_py/ops/embeddingbag.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from torch.autograd import Function
import _torch_ipex as core

'''
# extension for BF16 fast path only
torch_embedding_bag = torch.embedding_bag
def embeddingbag(weights, inputs, offsets, scale_grad_by_freq, mode, sparse, per_sample_weights, include_last_offset):
if weights.dtype == torch.float:
Expand All @@ -12,21 +14,41 @@ def embeddingbag(weights, inputs, offsets, scale_grad_by_freq, mode, sparse, per
ret = (ret, None, None, None)
else:
assert(0, "unimplement embeddingbag path in extension")

'''
def embeddingbag(weights, indices, offsets, scale_grad_by_freq, mode, sparse, per_sample_weights, include_last_offset):
ret = EmbeddingBagFunction.apply(weights, indices, offsets, scale_grad_by_freq, mode, sparse, per_sample_weights, include_last_offset)
return ret


class EmbeddingBagFunction(Function):
'''
@staticmethod
def forward(ctx, weights, inputs, offsets):
ctx.save_for_backward(weights, inputs, offsets)
output = core.embedding_bag_forward(weights, inputs, offsets)
return output
'''
@staticmethod
def forward(ctx, weight, indices, offsets, scale_grad_by_freq, mode, sparse, per_sample_weights, include_last_offset):
ctx.scale_grad_by_freq = scale_grad_by_freq
ctx.mode = mode
ctx.sparse = sparse
ctx.num_weight = weight.size(0)
ctx.save_for_backward(indices, offsets, per_sample_weights)
ret = core.embedding_bag_forward(weight, indices, offsets, scale_grad_by_freq, mode, sparse, per_sample_weights, include_last_offset)
return ret

'''
@staticmethod
def backward(ctx, grad_out):
weights, inputs, offsets = ctx.saved_tensors
grad_weight = core.embedding_bag_backward(grad_out, weights, inputs, offsets)
return (grad_weight, None, None)
'''
@staticmethod
def backward(ctx, grad, offset2bag, bag_size, maximum_indices):
indices, offsets, per_sample_weights = ctx.saved_tensors
grad_weight = core.embedding_bag_backward(grad, indices, offsets, offset2bag, bag_size, maximum_indices, ctx.num_weight, ctx.scale_grad_by_freq, ctx.mode, ctx.sparse, per_sample_weights)
return grad_weight, None, None, None, None, None, None, None

torch.embedding_bag = embeddingbag
2 changes: 1 addition & 1 deletion torch_ipex/csrc/cpu/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
FILE(GLOB _CPU_SRCS *.cpp dbl/*.cpp bf16/*.cpp)
FILE(GLOB _CPU_SRCS *.cpp dbl/*.cpp bf16/*.cpp aten/operators/*.cpp)
LIST(APPEND DPCPP_CPU_SRCS ${_CPU_SRCS})

# Pass to parent
Expand Down
26 changes: 26 additions & 0 deletions torch_ipex/csrc/cpu/ExtendOPs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include "ExtendOPs.h"
#include "bf16/vec/bf16_vec_kernel.h"
#include "dil/dil.hpp"
#include "aten/aten.hpp"
#include "xsmm/libxsmm_utils.h"
#include "../utils.h"
#include "DevOPs.h"
Expand Down Expand Up @@ -323,6 +324,7 @@ std::vector<at::Tensor> AtenIpexTypeExt::interaction_backward(const at::Tensor &
}
}

#if 0
template<typename T>
static inline at::Tensor _embedding_bag_forward(const at::Tensor &weights, const at::Tensor &inputs, const at::Tensor &offsets) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(weights.is_contiguous());
Expand Down Expand Up @@ -422,6 +424,30 @@ at::Tensor AtenIpexTypeExt::embedding_bag_backward(const at::Tensor &grad_out,
return _embedding_bag_backward<at::BFloat16>(grad_out, weights, inputs, offsets);
}
}
#endif

std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor>
AtenIpexTypeExt::embedding_bag_forward(const at::Tensor& weight, const at::Tensor& indices,
const at::Tensor& offsets, bool scale_grad_by_freq, int64_t mode, bool sparse,
const c10::optional<at::Tensor>& per_sample_weights, bool include_last_offset) {
at::Tensor _per_sample_weights;
if(per_sample_weights.has_value()) {
_per_sample_weights = per_sample_weights.value();
}
return cpu::aten::embedding_bag::embedding_bag_impl(weight, indices, offsets, scale_grad_by_freq, mode, sparse, _per_sample_weights, include_last_offset);
}

at::Tensor
AtenIpexTypeExt::embedding_bag_backward(const at::Tensor& grad, const at::Tensor& indices,
const at::Tensor& offsets, const at::Tensor& offset2bag, const at::Tensor& bag_size, const at::Tensor& maximum_indices,
int64_t num_weights, bool scale_grad_by_freq, int64_t mode, bool sparse,
const c10::optional<at::Tensor>& per_sample_weights) {
at::Tensor _per_sample_weights;
if(per_sample_weights.has_value()) {
_per_sample_weights = per_sample_weights.value();
}
return cpu::aten::embedding_bag::embedding_bag_backward_impl(grad, indices, offsets, offset2bag, bag_size, maximum_indices, num_weights, scale_grad_by_freq, mode, sparse, _per_sample_weights);
}

at::Tensor AtenIpexTypeExt::linear(const at::Tensor& input, const at::Tensor& weight, const c10::optional<at::Tensor>& bias) {
return cpu::AtenIpexCPUDev::dil_linear(input, weight, bias);
Expand Down
15 changes: 13 additions & 2 deletions torch_ipex/csrc/cpu/ExtendOPs.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,19 @@ class AtenIpexTypeExt {
static void packed_add_(at::Tensor & top_half, at::Tensor & bot_half, const at::Tensor & grad, float alpha);
static at::Tensor interaction_forward(const std::vector<at::Tensor> & input);
static std::vector<at::Tensor> interaction_backward(const at::Tensor & grad_out, const std::vector<at::Tensor> & input);
static at::Tensor embedding_bag_forward(const at::Tensor &weights, const at::Tensor &inputs, const at::Tensor &offsets);
static at::Tensor embedding_bag_backward(const at::Tensor &grad_out, const at::Tensor &weights, const at::Tensor &inputs, const at::Tensor &offsets);
//static at::Tensor embedding_bag_forward(const at::Tensor &weights, const at::Tensor &inputs, const at::Tensor &offsets);
//static at::Tensor embedding_bag_backward(const at::Tensor &grad_out, const at::Tensor &weights, const at::Tensor &inputs, const at::Tensor &offsets);
static std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor>
embedding_bag_forward(const at::Tensor & weight, const at::Tensor & indices,
const at::Tensor & offsets, bool scale_grad_by_freq, int64_t mode, bool sparse,
const c10::optional<at::Tensor>& per_sample_weights, bool include_last_offset);

static at::Tensor
embedding_bag_backward(const at::Tensor & grad, const at::Tensor & indices, const at::Tensor & offsets,
const at::Tensor & offset2bag, const at::Tensor & bag_size, const at::Tensor & maximum_indices,
int64_t num_weights, bool scale_grad_by_freq, int64_t mode, bool sparse,
const c10::optional<at::Tensor>& per_sample_weights);

static at::Tensor linear(const at::Tensor& input, const at::Tensor& weight, const c10::optional<at::Tensor>& bias);
static std::tuple<at::Tensor, at::Tensor, at::Tensor> linear_backward(const at::Tensor& input, const at::Tensor& grad_output, const at::Tensor& weight, std::array<bool,3> output_mask);
static at::Tensor adaptive_avg_pool2d(at::Tensor const& input, at::IntArrayRef output_size);
Expand Down
41 changes: 41 additions & 0 deletions torch_ipex/csrc/cpu/aten/aten.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
/*
*Copyright (c) 2018 Intel Corporation.
*
*Permission is hereby granted, free of charge, to any person obtaining a copy
*of this software and associated documentation files (the "Software"), to deal
*in the Software without restriction, including without limitation the rights
*to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
*copies of the Software, and to permit persons to whom the Software is
*furnished to do so, subject to the following conditions:
*
*The above copyright notice and this permission notice shall be included in
*all copies or substantial portions of the Software.
*
*THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
*IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
*FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
*AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
*LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
*OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
*THE SOFTWARE.
*
*/

#ifndef _ATEN_HPP
#define _ATEN_HPP

#include <cstdlib>
#include <algorithm>
#include <memory>
#include <map>
#include <vector>
#include <iterator>
#include <string>
#include <cstring>
#include <numeric>
#include <functional>
#include <iostream>

#include "operators/embedding_bag.hpp"

#endif
Loading