
Commit aaed398

fix emb (#6)
1 parent a252efc commit aaed398

5 files changed: +132 additions, -192 deletions

Lines changed: 11 additions & 8 deletions
@@ -1,15 +1,18 @@
 import torch
-from torch import nn
-from torch.autograd import Function
 import _torch_ipex as core
+import warnings

-# # extension for BF16 fast path only
+torch_embedding_bag = torch.embedding_bag

-
-def embeddingbag(weights, indices, offsets, scale_grad_by_freq, mode, sparse, per_sample_weights, include_last_offset):
-    ret = torch.ops.torch_ipex.embedding_bag(weights, indices, offsets, scale_grad_by_freq, mode, sparse, per_sample_weights, include_last_offset)
-    if len(ret)==1:
-        ret += [torch.Tensor(), torch.Tensor(), torch.Tensor()]
+def embeddingbag(weights, indices, offsets, scale_grad_by_freq, mode, sparse, per_sample_weights, include_last_offset, padding_idx):
+    if core.embedding_bag_fast_path_sum(weights, per_sample_weights, mode, padding_idx):
+        ret = torch.ops.torch_ipex.embedding_bag(weights, indices, offsets, sparse, include_last_offset)
+        # torch.embedding_bag is expected to return 4 tensors;
+        # only one real tensor is returned here since the other three are not needed on the fast path
+        ret = [ret, torch.Tensor(), torch.Tensor(), torch.Tensor()]
+    else:
+        warnings.warn('Fallback to torch.embedding_bag')
+        ret = torch_embedding_bag(weights, indices, offsets, scale_grad_by_freq, mode, sparse, per_sample_weights, include_last_offset, padding_idx)
     return ret

 torch.embedding_bag = embeddingbag
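
Note on the hook above: `torch.embedding_bag` is expected to return four tensors (output, offset2bag, bag_size, max_indices), so the fast path wraps its single pooled output with empty placeholders. The gate itself, `core.embedding_bag_fast_path_sum`, is implemented in C++ (declared in torch_ipex/csrc/cpu/ExtendOPs.h below); the Python sketch here only approximates the kind of conditions it can check, judging by the arguments it receives, and the exact rules (including any dtype requirements) are assumptions rather than the verbatim implementation.

import torch

# Illustrative approximation of the fast-path gate; not the actual C++ check.
MODE_SUM = 0  # ATen encodes EmbeddingBag modes as integers: 0 = sum, 1 = mean, 2 = max

def fast_path_sum_sketch(weight, per_sample_weights, mode, padding_idx):
    return (
        mode == MODE_SUM                 # only the "sum" reduction has a fast kernel
        and per_sample_weights is None   # no per-sample scaling (assumed)
        and padding_idx is None          # no padding-index handling (assumed)
        and weight.dtype in (torch.float32, torch.bfloat16)  # supported dtypes (assumed)
    )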

tests/cpu/test_emb.py

Lines changed: 36 additions & 33 deletions
@@ -1,45 +1,48 @@
 import torch
 import torch.nn as nn
-import intel_pytorch_extension as ipex
 import unittest
 import copy
 from common_utils import TestCase

 class TestEMB(TestCase):
-    def test_emb(self):
+    def _test_emb(self, mode):
         #E = nn.EmbeddingBag(10, 5, mode="sum", sparse=True)
-        cpu_emb = nn.EmbeddingBag(10, 3, mode='sum', sparse=True)
-        dpcpp_emb = copy.deepcopy(cpu_emb)
-        bf16_emb = copy.deepcopy(cpu_emb).bfloat16()
+        aten_emb = nn.EmbeddingBag(10, 3, mode=mode, sparse=True)
+        ipex_emb = copy.deepcopy(aten_emb)
+        bf16_emb = copy.deepcopy(aten_emb).bfloat16()
         # a batch of 2 samples of 4 indices each
-        cpu_input = torch.LongTensor([1,2,4,5,4,3,2,9])
-        dpcpp_input = cpu_input.clone().detach()
-
-        cpu_offsets = torch.LongTensor([0,1,2,3,4,5,6,7])
-        dpcpp_offsets = cpu_offsets.clone().detach()
-
-        cpu_out = cpu_emb(cpu_input, cpu_offsets)
-
-        #torch.embedding_bag = ipex.embeddingbag
-        dpcpp_out = dpcpp_emb(dpcpp_input, dpcpp_offsets)
-        bf16_out = bf16_emb(dpcpp_input, dpcpp_offsets)
-
-        self.assertEqual(cpu_out, dpcpp_out.to('cpu'))
-        self.assertEqual(cpu_out, bf16_out.to('cpu').float(), 0.01)
-
-        cpu_out.mean().backward()
-        dpcpp_out.mean().backward()
-        bf16_out.float().mean().backward()
-
-        self.assertEqual(cpu_emb.weight.grad.data._nnz(), dpcpp_emb.weight.grad.data._nnz())
-        self.assertEqual(cpu_emb.weight.grad.data.sparse_dim(), dpcpp_emb.weight.grad.data.sparse_dim())
-        self.assertEqual(cpu_emb.weight.grad.data.dense_dim(), dpcpp_emb.weight.grad.data.dense_dim())
-        self.assertEqual(cpu_emb.weight.grad.data.is_coalesced(), dpcpp_emb.weight.grad.data.is_coalesced())
-        self.assertEqual(cpu_emb.weight.grad.data._indices(), dpcpp_emb.weight.grad.data._indices().to('cpu'))
-        self.assertEqual(cpu_emb.weight.grad.data._values(), dpcpp_emb.weight.grad.data._values().to('cpu'))
-
-        self.assertEqual(cpu_emb.weight.grad.data._values(), dpcpp_emb.weight.grad.data._values().to('cpu'), 0.01)
-        self.assertEqual(bf16_emb.weight.grad.data._values().dtype, torch.bfloat16)
+        input = torch.LongTensor([1,2,4,5,4,3,2,9])
+        offsets = torch.LongTensor([0,1,2,3,4,5,6,7])
+        # aten path
+        aten_out = aten_emb(input, offsets)
+        aten_out.mean().backward()
+
+        # ipex fast path (both fp32/bf16)
+        import intel_pytorch_extension
+        ipex_out = ipex_emb(input, offsets)
+        ipex_out.mean().backward()
+        if mode == 'sum':
+            bf16_out = bf16_emb(input, offsets)
+            bf16_out.mean().backward()
+            self.assertEqual(aten_out, bf16_out.float(), 0.01)
+            self.assertEqual(bf16_emb.weight.grad.data._values().dtype, torch.bfloat16)
+        del(intel_pytorch_extension)
+
+        self.assertEqual(aten_out, ipex_out)
+
+        self.assertEqual(aten_emb.weight.grad.data._nnz(), ipex_emb.weight.grad.data._nnz())
+        self.assertEqual(aten_emb.weight.grad.data.sparse_dim(), ipex_emb.weight.grad.data.sparse_dim())
+        self.assertEqual(aten_emb.weight.grad.data.dense_dim(), ipex_emb.weight.grad.data.dense_dim())
+        self.assertEqual(aten_emb.weight.grad.data.is_coalesced(), ipex_emb.weight.grad.data.is_coalesced())
+        self.assertEqual(aten_emb.weight.grad.data._indices(), ipex_emb.weight.grad.data._indices())
+        self.assertEqual(aten_emb.weight.grad.data._values(), ipex_emb.weight.grad.data._values())
+        self.assertEqual(aten_emb.weight.grad.data._values(), ipex_emb.weight.grad.data._values(), 0.01)
+
+    def test_emb_fast_path(self):
+        self._test_emb(mode='mean')
+
+    def test_emb_fallback_path(self):
+        self._test_emb(mode='sum')

 if __name__ == '__main__':
     test = unittest.main()
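
For reference, the fast path is exercised end to end simply by importing the extension (which installs the hook shown in the first file) and using a sum-mode `nn.EmbeddingBag`. A minimal usage sketch mirroring the test above, assuming `intel_pytorch_extension` is installed:

import torch
import torch.nn as nn
import intel_pytorch_extension  # importing the package patches torch.embedding_bag

# sum-mode bag with sparse gradients, as in the test
emb = nn.EmbeddingBag(10, 3, mode='sum', sparse=True)
input = torch.LongTensor([1, 2, 4, 5, 4, 3, 2, 9])
offsets = torch.LongTensor([0, 1, 2, 3, 4, 5, 6, 7])

out = emb(input, offsets)  # dispatched to the IPEX kernel when the gate passes
out.mean().backward()      # produces a sparse gradient on emb.weight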

torch_ipex/csrc/cpu/ExtendOPs.h

Lines changed: 12 additions & 4 deletions
@@ -39,11 +39,19 @@ class AtenIpexTypeExt {
   static std::vector<at::Tensor> interaction_backward(const at::Tensor & grad_out,
                                                        const std::vector<at::Tensor> & input);

-  static std::vector<at::Tensor> embedding_bag(
-      const at::Tensor &weight, const at::Tensor &indices,
-      const at::Tensor &offsets, bool scale_grad_by_freq, int64_t mode,
-      bool sparse, const c10::optional<at::Tensor> &per_sample_weights,
+  static at::Tensor embedding_bag(
+      const at::Tensor &weight,
+      const at::Tensor &indices,
+      const at::Tensor &offsets,
+      bool sparse,
       bool include_last_offset);
+
+  static bool embedding_bag_fast_path_sum(
+      const at::Tensor weight,
+      const c10::optional<at::Tensor> per_sample_weights,
+      int64_t mode,
+      const c10::optional<int64_t> padding_idx);
+
 };

 } // namespace torch_ipex
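
The header change narrows `embedding_bag` to the arguments the sum fast path actually uses (dropping `scale_grad_by_freq`, `mode`, and `per_sample_weights`) and returns a single `at::Tensor` instead of a `std::vector<at::Tensor>`; the new `embedding_bag_fast_path_sum` predicate is what the Python hook queries before dispatching. A minimal sketch of how that single return value is adapted back to the four-tensor convention (the helper name below is illustrative, not part of the commit):

import torch

def adapt_fast_path_result(output):
    # The fast path computes only the pooled output; offset2bag, bag_size and
    # max_indices are padded with empty tensors so callers that unpack the
    # usual 4-tuple from torch.embedding_bag keep working.
    return [output, torch.Tensor(), torch.Tensor(), torch.Tensor()]

# e.g. out, offset2bag, bag_size, max_indices = adapt_fast_path_result(pooled)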
