Skip to content

Commit d5acd80

Browse files
Embedding Bag Sum Fast Support for FP32 and BFloat16 (#23)
1 parent b051d72 commit d5acd80

File tree

9 files changed

+592
-14
lines changed

9 files changed

+592
-14
lines changed

intel_pytorch_extension_py/ops/embeddingbag.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
from torch.autograd import Function
44
import _torch_ipex as core
55

6+
'''
7+
# extension for BF16 fast path only
68
torch_embedding_bag = torch.embedding_bag
79
def embeddingbag(weights, inputs, offsets, scale_grad_by_freq, mode, sparse, per_sample_weights, include_last_offset):
810
if weights.dtype == torch.float:
@@ -12,21 +14,41 @@ def embeddingbag(weights, inputs, offsets, scale_grad_by_freq, mode, sparse, per
1214
ret = (ret, None, None, None)
1315
else:
1416
assert(0, "unimplement embeddingbag path in extension")
15-
17+
'''
18+
def embeddingbag(weights, indices, offsets, scale_grad_by_freq, mode, sparse, per_sample_weights, include_last_offset):
19+
ret = EmbeddingBagFunction.apply(weights, indices, offsets, scale_grad_by_freq, mode, sparse, per_sample_weights, include_last_offset)
1620
return ret
1721

1822

1923
class EmbeddingBagFunction(Function):
24+
'''
2025
@staticmethod
2126
def forward(ctx, weights, inputs, offsets):
2227
ctx.save_for_backward(weights, inputs, offsets)
2328
output = core.embedding_bag_forward(weights, inputs, offsets)
2429
return output
30+
'''
31+
@staticmethod
32+
def forward(ctx, weight, indices, offsets, scale_grad_by_freq, mode, sparse, per_sample_weights, include_last_offset):
33+
ctx.scale_grad_by_freq = scale_grad_by_freq
34+
ctx.mode = mode
35+
ctx.sparse = sparse
36+
ctx.num_weight = weight.size(0)
37+
ctx.save_for_backward(indices, offsets, per_sample_weights)
38+
ret = core.embedding_bag_forward(weight, indices, offsets, scale_grad_by_freq, mode, sparse, per_sample_weights, include_last_offset)
39+
return ret
2540

41+
'''
2642
@staticmethod
2743
def backward(ctx, grad_out):
2844
weights, inputs, offsets = ctx.saved_tensors
2945
grad_weight = core.embedding_bag_backward(grad_out, weights, inputs, offsets)
3046
return (grad_weight, None, None)
47+
'''
48+
@staticmethod
49+
def backward(ctx, grad, offset2bag, bag_size, maximum_indices):
50+
indices, offsets, per_sample_weights = ctx.saved_tensors
51+
grad_weight = core.embedding_bag_backward(grad, indices, offsets, offset2bag, bag_size, maximum_indices, ctx.num_weight, ctx.scale_grad_by_freq, ctx.mode, ctx.sparse, per_sample_weights)
52+
return grad_weight, None, None, None, None, None, None, None
3153

3254
torch.embedding_bag = embeddingbag

torch_ipex/csrc/cpu/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
FILE(GLOB _CPU_SRCS *.cpp dbl/*.cpp bf16/*.cpp)
1+
FILE(GLOB _CPU_SRCS *.cpp dbl/*.cpp bf16/*.cpp aten/operators/*.cpp)
22
LIST(APPEND DPCPP_CPU_SRCS ${_CPU_SRCS})
33

44
# Pass to parent

torch_ipex/csrc/cpu/ExtendOPs.cpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include "ExtendOPs.h"
77
#include "bf16/vec/bf16_vec_kernel.h"
88
#include "dil/dil.hpp"
9+
#include "aten/aten.hpp"
910
#include "xsmm/libxsmm_utils.h"
1011
#include "../utils.h"
1112
#include "DevOPs.h"
@@ -323,6 +324,7 @@ std::vector<at::Tensor> AtenIpexTypeExt::interaction_backward(const at::Tensor &
323324
}
324325
}
325326

327+
#if 0
326328
template<typename T>
327329
static inline at::Tensor _embedding_bag_forward(const at::Tensor &weights, const at::Tensor &inputs, const at::Tensor &offsets) {
328330
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(weights.is_contiguous());
@@ -422,6 +424,30 @@ at::Tensor AtenIpexTypeExt::embedding_bag_backward(const at::Tensor &grad_out,
422424
return _embedding_bag_backward<at::BFloat16>(grad_out, weights, inputs, offsets);
423425
}
424426
}
427+
#endif
428+
429+
std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor>
430+
AtenIpexTypeExt::embedding_bag_forward(const at::Tensor& weight, const at::Tensor& indices,
431+
const at::Tensor& offsets, bool scale_grad_by_freq, int64_t mode, bool sparse,
432+
const c10::optional<at::Tensor>& per_sample_weights, bool include_last_offset) {
433+
at::Tensor _per_sample_weights;
434+
if(per_sample_weights.has_value()) {
435+
_per_sample_weights = per_sample_weights.value();
436+
}
437+
return cpu::aten::embedding_bag::embedding_bag_impl(weight, indices, offsets, scale_grad_by_freq, mode, sparse, _per_sample_weights, include_last_offset);
438+
}
439+
440+
at::Tensor
441+
AtenIpexTypeExt::embedding_bag_backward(const at::Tensor& grad, const at::Tensor& indices,
442+
const at::Tensor& offsets, const at::Tensor& offset2bag, const at::Tensor& bag_size, const at::Tensor& maximum_indices,
443+
int64_t num_weights, bool scale_grad_by_freq, int64_t mode, bool sparse,
444+
const c10::optional<at::Tensor>& per_sample_weights) {
445+
at::Tensor _per_sample_weights;
446+
if(per_sample_weights.has_value()) {
447+
_per_sample_weights = per_sample_weights.value();
448+
}
449+
return cpu::aten::embedding_bag::embedding_bag_backward_impl(grad, indices, offsets, offset2bag, bag_size, maximum_indices, num_weights, scale_grad_by_freq, mode, sparse, _per_sample_weights);
450+
}
425451

426452
at::Tensor AtenIpexTypeExt::linear(const at::Tensor& input, const at::Tensor& weight, const c10::optional<at::Tensor>& bias) {
427453
return cpu::AtenIpexCPUDev::dil_linear(input, weight, bias);

torch_ipex/csrc/cpu/ExtendOPs.h

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,19 @@ class AtenIpexTypeExt {
1010
static void packed_add_(at::Tensor & top_half, at::Tensor & bot_half, const at::Tensor & grad, float alpha);
1111
static at::Tensor interaction_forward(const std::vector<at::Tensor> & input);
1212
static std::vector<at::Tensor> interaction_backward(const at::Tensor & grad_out, const std::vector<at::Tensor> & input);
13-
static at::Tensor embedding_bag_forward(const at::Tensor &weights, const at::Tensor &inputs, const at::Tensor &offsets);
14-
static at::Tensor embedding_bag_backward(const at::Tensor &grad_out, const at::Tensor &weights, const at::Tensor &inputs, const at::Tensor &offsets);
13+
//static at::Tensor embedding_bag_forward(const at::Tensor &weights, const at::Tensor &inputs, const at::Tensor &offsets);
14+
//static at::Tensor embedding_bag_backward(const at::Tensor &grad_out, const at::Tensor &weights, const at::Tensor &inputs, const at::Tensor &offsets);
15+
static std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor>
16+
embedding_bag_forward(const at::Tensor & weight, const at::Tensor & indices,
17+
const at::Tensor & offsets, bool scale_grad_by_freq, int64_t mode, bool sparse,
18+
const c10::optional<at::Tensor>& per_sample_weights, bool include_last_offset);
19+
20+
static at::Tensor
21+
embedding_bag_backward(const at::Tensor & grad, const at::Tensor & indices, const at::Tensor & offsets,
22+
const at::Tensor & offset2bag, const at::Tensor & bag_size, const at::Tensor & maximum_indices,
23+
int64_t num_weights, bool scale_grad_by_freq, int64_t mode, bool sparse,
24+
const c10::optional<at::Tensor>& per_sample_weights);
25+
1526
static at::Tensor linear(const at::Tensor& input, const at::Tensor& weight, const c10::optional<at::Tensor>& bias);
1627
static std::tuple<at::Tensor, at::Tensor, at::Tensor> linear_backward(const at::Tensor& input, const at::Tensor& grad_output, const at::Tensor& weight, std::array<bool,3> output_mask);
1728
static at::Tensor adaptive_avg_pool2d(at::Tensor const& input, at::IntArrayRef output_size);

torch_ipex/csrc/cpu/aten/aten.hpp

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
/*
2+
*Copyright (c) 2018 Intel Corporation.
3+
*
4+
*Permission is hereby granted, free of charge, to any person obtaining a copy
5+
*of this software and associated documentation files (the "Software"), to deal
6+
*in the Software without restriction, including without limitation the rights
7+
*to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
8+
*copies of the Software, and to permit persons to whom the Software is
9+
*furnished to do so, subject to the following conditions:
10+
*
11+
*The above copyright notice and this permission notice shall be included in
12+
*all copies or substantial portions of the Software.
13+
*
14+
*THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15+
*IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16+
*FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17+
*AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18+
*LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
19+
*OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
20+
*THE SOFTWARE.
21+
*
22+
*/
23+
24+
#ifndef _ATEN_HPP
25+
#define _ATEN_HPP
26+
27+
#include <cstdlib>
28+
#include <algorithm>
29+
#include <memory>
30+
#include <map>
31+
#include <vector>
32+
#include <iterator>
33+
#include <string>
34+
#include <cstring>
35+
#include <numeric>
36+
#include <functional>
37+
#include <iostream>
38+
39+
#include "operators/embedding_bag.hpp"
40+
41+
#endif

0 commit comments

Comments
 (0)