|
1 | 1 | import torch
|
2 | 2 | import torch.nn as nn
|
3 |
| -import intel_pytorch_extension as ipex |
4 | 3 | import unittest
|
5 | 4 | import copy
|
6 | 5 | from common_utils import TestCase
|
7 | 6 |
|
8 | 7 | class TestEMB(TestCase):
|
9 |
| - def test_emb(self): |
| 8 | + def _test_emb(self, mode): |
10 | 9 | #E = nn.EmbeddingBag(10, 5, mode="sum", sparse=True)
|
11 |
| - cpu_emb = nn.EmbeddingBag(10, 3, mode='sum', sparse=True) |
12 |
| - dpcpp_emb = copy.deepcopy(cpu_emb) |
13 |
| - bf16_emb = copy.deepcopy(cpu_emb).bfloat16() |
| 10 | + aten_emb = nn.EmbeddingBag(10, 3, mode=mode, sparse=True) |
| 11 | + ipex_emb = copy.deepcopy(aten_emb) |
| 12 | + bf16_emb = copy.deepcopy(aten_emb).bfloat16() |
14 | 13 | # a batch of 2 samples of 4 indices each
|
15 |
| - cpu_input = torch.LongTensor([1,2,4,5,4,3,2,9]) |
16 |
| - dpcpp_input = cpu_input.clone().detach() |
17 |
| - |
18 |
| - cpu_offsets = torch.LongTensor([0,1,2,3,4,5,6,7]) |
19 |
| - dpcpp_offsets = cpu_offsets.clone().detach() |
20 |
| - |
21 |
| - cpu_out = cpu_emb(cpu_input, cpu_offsets) |
22 |
| - |
23 |
| - #torch.embedding_bag = ipex.embeddingbag |
24 |
| - dpcpp_out = dpcpp_emb(dpcpp_input, dpcpp_offsets) |
25 |
| - bf16_out = bf16_emb(dpcpp_input, dpcpp_offsets) |
26 |
| - |
27 |
| - self.assertEqual(cpu_out, dpcpp_out.to('cpu')) |
28 |
| - self.assertEqual(cpu_out, bf16_out.to('cpu').float(), 0.01) |
29 |
| - |
30 |
| - cpu_out.mean().backward() |
31 |
| - dpcpp_out.mean().backward() |
32 |
| - bf16_out.float().mean().backward() |
33 |
| - |
34 |
| - self.assertEqual(cpu_emb.weight.grad.data._nnz(), dpcpp_emb.weight.grad.data._nnz()) |
35 |
| - self.assertEqual(cpu_emb.weight.grad.data.sparse_dim(), dpcpp_emb.weight.grad.data.sparse_dim()) |
36 |
| - self.assertEqual(cpu_emb.weight.grad.data.dense_dim(), dpcpp_emb.weight.grad.data.dense_dim()) |
37 |
| - self.assertEqual(cpu_emb.weight.grad.data.is_coalesced(), dpcpp_emb.weight.grad.data.is_coalesced()) |
38 |
| - self.assertEqual(cpu_emb.weight.grad.data._indices(), dpcpp_emb.weight.grad.data._indices().to('cpu')) |
39 |
| - self.assertEqual(cpu_emb.weight.grad.data._values(), dpcpp_emb.weight.grad.data._values().to('cpu')) |
40 |
| - |
41 |
| - self.assertEqual(cpu_emb.weight.grad.data._values(), dpcpp_emb.weight.grad.data._values().to('cpu'), 0.01) |
42 |
| - self.assertEqual(bf16_emb.weight.grad.data._values().dtype, torch.bfloat16) |
| 14 | + input = torch.LongTensor([1,2,4,5,4,3,2,9]) |
| 15 | + offsets = torch.LongTensor([0,1,2,3,4,5,6,7]) |
| 16 | + # aten path |
| 17 | + aten_out = aten_emb(input, offsets) |
| 18 | + aten_out.mean().backward() |
| 19 | + |
| 20 | + # ipex fast path (both fp32/bf16) |
| 21 | + import intel_pytorch_extension |
| 22 | + ipex_out = ipex_emb(input, offsets) |
| 23 | + ipex_out.mean().backward() |
| 24 | + if mode == 'sum': |
| 25 | + bf16_out = bf16_emb(input, offsets) |
| 26 | + bf16_out.mean().backward() |
| 27 | + self.assertEqual(aten_out, bf16_out.float(), 0.01) |
| 28 | + self.assertEqual(bf16_emb.weight.grad.data._values().dtype, torch.bfloat16) |
| 29 | + del(intel_pytorch_extension) |
| 30 | + |
| 31 | + self.assertEqual(aten_out, ipex_out) |
| 32 | + |
| 33 | + self.assertEqual(aten_emb.weight.grad.data._nnz(), ipex_emb.weight.grad.data._nnz()) |
| 34 | + self.assertEqual(aten_emb.weight.grad.data.sparse_dim(), ipex_emb.weight.grad.data.sparse_dim()) |
| 35 | + self.assertEqual(aten_emb.weight.grad.data.dense_dim(), ipex_emb.weight.grad.data.dense_dim()) |
| 36 | + self.assertEqual(aten_emb.weight.grad.data.is_coalesced(), ipex_emb.weight.grad.data.is_coalesced()) |
| 37 | + self.assertEqual(aten_emb.weight.grad.data._indices(), ipex_emb.weight.grad.data._indices()) |
| 38 | + self.assertEqual(aten_emb.weight.grad.data._values(), ipex_emb.weight.grad.data._values()) |
| 39 | + self.assertEqual(aten_emb.weight.grad.data._values(), ipex_emb.weight.grad.data._values(), 0.01) |
| 40 | + |
| 41 | + def test_emb_fast_path(self): |
| 42 | + self._test_emb(mode='mean') |
| 43 | + |
| 44 | + def test_emb_fallback_path(self): |
| 45 | + self._test_emb(mode='sum') |
43 | 46 |
|
44 | 47 | if __name__ == '__main__':
|
45 | 48 | test = unittest.main()
|
0 commit comments