Commit d600af5

Merge remote-tracking branch 'gitlab/master'
2 parents: 5bc355e + 41caea8

26 files changed: +885 / -237 lines

scripts/cpu/gen-dense-cpu-ops.py

Lines changed: 4 additions & 2 deletions
@@ -95,7 +95,9 @@
   'aten::upsample_linear1d(Tensor self, int[1] output_size, bool align_corners, float? scales=None) -> Tensor',
   'aten::upsample_linear1d_backward(Tensor grad_output, int[1] output_size, int[3] input_size, bool align_corners, float? scales=None) -> Tensor',
   'aten::upsample_bilinear2d(Tensor self, int[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor',
+  'aten::upsample_bilinear2d.vec(Tensor input, int[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor',
   'aten::upsample_bilinear2d_backward(Tensor grad_output, int[2] output_size, int[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor',
+  'aten::upsample_bilinear2d_backward.vec(Tensor grad_output, int[]? output_size, int[] input_size, bool align_corners, float[]? scale_factors) -> Tensor',
   'aten::upsample_trilinear3d(Tensor self, int[3] output_size, bool align_corners, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor',
   'aten::upsample_trilinear3d_backward(Tensor grad_output, int[3] output_size, int[5] input_size, bool align_corners, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor',
   'aten::unsqueeze(Tensor(a) self, int dim) -> Tensor(a)',
@@ -578,7 +580,7 @@ def is_conv_overrideable_func(fname):
 
   # Gen OP Name
   code += '#if defined(IPEX_DISP_OP)\n'
-  code += ' printf("{}::{}\\n");\n'.format(_IPEX_OP_FUNC_NS, cpp_sig.def_name)
+  code += ' printf("{}::{}\\n");\n'.format(_IPEX_OP_FUNC_NS, new_cpp_func_name)
   code += '#endif\n'
 
   # Gen profile info
@@ -587,7 +589,7 @@ def is_conv_overrideable_func(fname):
     if param.core_type in ['Tensor', 'Scalar']:
       profiler_inputs.append(param.name)
   code += '#if defined(IPEX_PROFILE_OP)\n'
-  code += ' RECORD_FUNCTION("{ns}::{name}", std::vector<c10::IValue>({{{input_names}}}));\n'.format(ns=_IPEX_OP_FUNC_NS, name=cpp_sig.def_name, input_names=', '.join(profiler_inputs))
+  code += ' RECORD_FUNCTION("{ns}::{name}", std::vector<c10::IValue>({{}}));\n'.format(ns=_IPEX_OP_FUNC_NS, name=new_cpp_func_name)
   code += '#endif\n'
 
   if is_conv_overrideable_func(cpp_sig.def_name):
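
Note: both format strings above now key the dispatch banner and the profiler record on new_cpp_func_name rather than cpp_sig.def_name, presumably so that overloaded schemas such as aten::upsample_bilinear2d.vec get their own generated wrapper names, and the RECORD_FUNCTION call no longer forwards the captured inputs. A minimal sketch (not part of this commit) of what the two strings expand to, using hypothetical stand-in values:

    # Sketch only: shows how the updated format strings expand.
    # _IPEX_OP_FUNC_NS and new_cpp_func_name are hypothetical stand-ins here;
    # the real values are computed inside gen-dense-cpu-ops.py.
    _IPEX_OP_FUNC_NS = 'torch_ipex'
    new_cpp_func_name = 'upsample_bilinear2d_vec'  # assumed name for the ".vec" overload

    code = ''
    code += '#if defined(IPEX_DISP_OP)\n'
    code += ' printf("{}::{}\\n");\n'.format(_IPEX_OP_FUNC_NS, new_cpp_func_name)
    code += '#endif\n'
    code += '#if defined(IPEX_PROFILE_OP)\n'
    code += ' RECORD_FUNCTION("{ns}::{name}", std::vector<c10::IValue>({{}}));\n'.format(ns=_IPEX_OP_FUNC_NS, name=new_cpp_func_name)
    code += '#endif\n'
    print(code)
    # Prints (given the stand-ins above):
    # #if defined(IPEX_DISP_OP)
    #  printf("torch_ipex::upsample_bilinear2d_vec\n");
    # #endif
    # #if defined(IPEX_PROFILE_OP)
    #  RECORD_FUNCTION("torch_ipex::upsample_bilinear2d_vec", std::vector<c10::IValue>({}));
    # #endif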

tests/cpu/test_int8.py

Lines changed: 54 additions & 7 deletions
@@ -8,6 +8,7 @@
 import itertools
 import time
 import json
+import sys
 
 import torch
 import torch.nn as nn
@@ -22,6 +23,9 @@
 
 from common_utils import TestCase
 
+def get_rand_seed():
+    return int(time.time() * 1000000000)
+
 device = ipex.DEVICE
 class TestQuantizationConfigueTune(TestCase):
     def test_quantization_status(self):
@@ -71,7 +75,7 @@ def test_quantization_status(self):
 
 
 class TestQuantization(TestCase):
-    def compare_fp32_int8(self, model, x):
+    def _compare_fp32_int8(self, model, x):
         conf = ipex.AmpConf(torch.int8)
         with ipex.AutoMixPrecision(conf, running_mode='calibration'):
             ref = model(x)
@@ -85,6 +89,25 @@ def compare_fp32_int8(self, model, x):
         self.assertEqual(ref, y, prec=0.1)
         os.remove('configure.json')
 
+    def _lstm_compare_fp32_int8(self, model, *args):
+        conf = ipex.AmpConf(torch.int8)
+        with ipex.AutoMixPrecision(conf, running_mode='calibration'):
+            with torch.no_grad():
+                ref, hy_ref = model(*args)
+            conf.save('configure.json')
+
+        conf = ipex.AmpConf(torch.int8, 'configure.json')
+        with ipex.AutoMixPrecision(conf, running_mode='inference'):
+            with torch.no_grad():
+                y, hy = model(*args)
+
+        self.assertTrue(ipex.core.is_int8_dil_tensor(y))
+
+        self.assertEqual(ref, y, prec=0.1)
+        self.assertEqual(hy_ref[0], hy[0], prec=0.01)
+        self.assertEqual(hy_ref[1], hy[1], prec=0.01)
+        os.remove('configure.json')
+
     def test_conv2d(self):
         options = itertools.product([1, 4], [True, False], [1, 2])
         for groups, bias, dilation in options:
@@ -100,12 +123,12 @@ def test_conv2d(self):
                 dilation=dilation,
                 bias=bias,
                 groups=groups).float().to(device)
-            self.compare_fp32_int8(conv2d, x)
+            self._compare_fp32_int8(conv2d, x)
 
     def test_relu(self):
         x = torch.randn((4, 5), dtype=torch.float32).to(device)
         relu = nn.ReLU()
-        self.compare_fp32_int8(relu, x)
+        self._compare_fp32_int8(relu, x)
 
     def test_max_pool2d(self):
         N = torch.randint(3, 10, (1,)).item()
@@ -118,7 +141,7 @@ def test_max_pool2d(self):
                 stride=stride,
                 padding=1,
                 ceil_mode=ceil_mode)
-            self.compare_fp32_int8(max_pool2d, x)
+            self._compare_fp32_int8(max_pool2d, x)
 
     def test_avg_pool2d(self):
         N = torch.randint(3, 10, (1,)).item()
@@ -131,15 +154,15 @@ def test_avg_pool2d(self):
                 stride=2,
                 padding=1,
                 count_include_pad=count_include_pad)
-            self.compare_fp32_int8(avg_pool2d, x)
+            self._compare_fp32_int8(avg_pool2d, x)
 
     def test_adaptive_avg_pool2d(self):
         N = torch.randint(3, 10, (1,)).item()
         C = torch.randint(3, 10, (1,)).item()
         x = torch.randn(N, C, 224, 224, dtype=torch.float32).to(device)
 
         adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d(7)
-        self.compare_fp32_int8(adaptive_avg_pool2d, x)
+        self._compare_fp32_int8(adaptive_avg_pool2d, x)
 
     def test_linear(self):
         in_features = torch.randint(3, 10, (1,)).item()
@@ -148,8 +171,32 @@ def test_linear(self):
         for bias in [True, False]:
             x = torch.randn(3, in_features, dtype=torch.float32).to(device)
             linear = torch.nn.Linear(in_features, out_features, bias=bias).float().to(device)
-            self.compare_fp32_int8(linear, x)
+            self._compare_fp32_int8(linear, x)
+
+    def _lstm_int8(self, seq_len, batch_size, input_size, hidden_size, num_layers, bidirectional, bias, empty_state):
+        rand_seed = int(get_rand_seed())
+
+        print("{} rand sed: {}".format(sys._getframe().f_code.co_name, rand_seed))
+        torch.manual_seed(rand_seed)
+
+        num_directions = 2 if bidirectional else 1
+
+        input_dpcpp = torch.FloatTensor(seq_len, batch_size, input_size).uniform_(-1, 1).to(device=device)
+        h0_dpcpp = torch.FloatTensor(num_layers * num_directions, batch_size, hidden_size).uniform_(-1, 1).to(device=device)
+        c0_dpcpp = torch.FloatTensor(num_layers * num_directions, batch_size, hidden_size).uniform_(-1, 1).to(device=device)
+        model_dpcpp = torch.nn.LSTM(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, bidirectional=bidirectional, bias=bias).to(device=device).eval()
+
+        self._lstm_compare_fp32_int8(model_dpcpp, input_dpcpp)
 
+    def test_lstm(self):
+        self._lstm_int8(seq_len=5, batch_size=2, input_size=16, hidden_size=16, num_layers=1, bidirectional=False, bias=True, empty_state=False)
+
+        self._lstm_int8(seq_len=5, batch_size=2, input_size=16, hidden_size=16, num_layers=1, bidirectional=True, bias=True, empty_state=False)
+
+        self._lstm_int8(seq_len=5, batch_size=2, input_size=16, hidden_size=16, num_layers=1, bidirectional=False, bias=False, empty_state=False)
+
+        self._lstm_int8(seq_len=5, batch_size=2, input_size=16, hidden_size=16, num_layers=1, bidirectional=True, bias=False, empty_state=False)
+
 if __name__ == '__main__':
     rand_seed = int(time.time() * 1000000000)
     torch.manual_seed(rand_seed)
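
The new _lstm_compare_fp32_int8 helper follows the same two-phase int8 flow as the existing _compare_fp32_int8: one pass in 'calibration' mode to collect quantization parameters, conf.save() to write them to configure.json, then a reload of that file and a second pass in 'inference' mode, with the int8 output checked against the fp32 reference. A standalone sketch of that flow (not part of this commit; the LSTM shapes are arbitrary and only roughly mirror the test):

    import torch
    import intel_pytorch_extension as ipex

    model = torch.nn.LSTM(input_size=16, hidden_size=16, num_layers=1).to(ipex.DEVICE).eval()
    x = torch.randn(5, 2, 16).to(ipex.DEVICE)   # (seq_len, batch, input_size)

    conf = ipex.AmpConf(torch.int8)
    with ipex.AutoMixPrecision(conf, running_mode='calibration'):
        with torch.no_grad():
            ref, _ = model(x)                    # fp32 pass collects calibration stats
        conf.save('configure.json')              # persist the observed quantization parameters

    conf = ipex.AmpConf(torch.int8, 'configure.json')
    with ipex.AutoMixPrecision(conf, running_mode='inference'):
        with torch.no_grad():
            y, _ = model(x)                      # int8 path; the test expects it to match ref within 0.1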

tests/cpu/test_lazy_reorder.py

Lines changed: 54 additions & 14 deletions
@@ -13,6 +13,8 @@
 import itertools
 import torch
 import intel_pytorch_extension as ipex
+import contextlib
+import io
 
 from common_ipex_conf import AutoMixPrecision, AutoDNNL
 
@@ -1303,6 +1305,33 @@ def test_unsqueeze(self):
         x_dpcpp = x.clone().to(device=device)
         self.assertEqual(x_dpcpp.unsqueeze(1), x.unsqueeze(1))
 
+        with AutoDNNL(True):
+            x = torch.randn(3, 64, 64, dtype=torch.float32)
+            x_xpu = x.clone().to(device=device)
+            conv2d_cpu = torch.nn.Conv2d(3, 6, (3, 3))
+            conv2d_xpu = copy.deepcopy(conv2d_cpu).to(device=device)
+            x_nchw = x.unsqueeze(0)
+            x_xpu_nchw = x_xpu.unsqueeze(0)
+            self.assertEqual(conv2d_cpu(x_nchw), conv2d_xpu(x_xpu_nchw))
+
+            conv2d_cpu = torch.nn.Conv2d(3, 1, (3, 3))
+            conv2d_xpu = copy.deepcopy(conv2d_cpu).to(ipex.DEVICE)
+            # reshape the conv2d weight to chw
+            conv2d_weight_seq = conv2d_xpu.weight.clone().squeeze()
+            # reshape the conv2d weight to nchw
+            conv2d_weight_unseq = torch.unsqueeze(conv2d_weight_seq, 0)
+
+            conv2d_xpu.weight.data = conv2d_weight_unseq
+
+            a = torch.randn(1, 3, 10, 10).to(ipex.DEVICE)
+            # Make sure the conv2d_xpu.weight is blocked format
+            conv2d_xpu(a)
+            # Make sure the unsqueeze does not trigger reorder
+            conv2d_weight_unseq = torch.unsqueeze(conv2d_weight_seq, 0)
+            self.assertEqual(conv2d_xpu(a), conv2d_cpu(a.to("cpu")))
+
+
+
 class TestSoftMax(TestCase):
     def test_softmax(self):
         with AutoDNNL(True):
@@ -1580,7 +1609,7 @@ def _lstm_params_list(self, cell):
         if cell == "RNN":
             params_dict["nonlinearity"] = ["tanh"] # ["tanh", "relu"] TODO relu has accuracy issue
         elif cell == "GRU":
-            params_dict["nonlinearity"] = [""] 
+            params_dict["nonlinearity"] = [""]
 
         params_list = []
 
@@ -1592,16 +1621,16 @@ def _test_lstm(self, training):
         rand_seed = int(get_rand_seed())
         print("{} rand sed: {}".format(sys._getframe().f_code.co_name, rand_seed))
         torch.manual_seed(rand_seed)
-        
+
         params_list = self._lstm_params_list("LSTM")
 
         for input_size, hidden_size, num_layers, bidirectional, bias, empty_state, batch_first, dropout, batch_size, seq_len in itertools.product(*params_list):
             # dropout option adds dropout after all but last recurrent layer, so non-zero dropout expects num_layers greater than 1
             if dropout > 0 and num_layers == 1:
                 continue
-            
+
             num_directions = 2 if bidirectional else 1
-            
+
             if batch_first:
                 input = torch.randn(batch_size, seq_len, input_size)
             else:
@@ -1649,7 +1678,7 @@
                 hy_cpu[0].sum().backward(retain_graph=True)
                 hy_dpcpp[0].sum().backward(retain_graph=True)
                 self.assertEqual(h0_dpcpp.grad.to('cpu'), h_cpu.grad)
-                
+
                 hy_cpu[1].sum().backward(retain_graph=True)
                 hy_dpcpp[1].sum().backward(retain_graph=True)
                 self.assertEqual(c0_dpcpp.grad.to('cpu'), c_cpu.grad)
@@ -1658,16 +1687,16 @@ def _test_rnn(self, cell, training):
         rand_seed = int(get_rand_seed())
         print("{} rand sed: {}".format(sys._getframe().f_code.co_name, rand_seed))
         torch.manual_seed(rand_seed)
-        
+
         params_list = self._lstm_params_list(cell)
 
         for input_size, hidden_size, num_layers, bidirectional, bias, empty_state, batch_first, dropout, batch_size, seq_len, nonlinearity in itertools.product(*params_list):
             # dropout option adds dropout after all but last recurrent layer, so non-zero dropout expects num_layers greater than 1
             if dropout > 0 and num_layers == 1:
                 continue
-            
+
             num_directions = 2 if bidirectional else 1
-            
+
             if batch_first:
                 input = torch.randn(batch_size, seq_len, input_size)
             else:
@@ -1683,7 +1712,7 @@
                 model_cpu = torch.nn.RNN(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, bidirectional=bidirectional, bias=bias, dropout=dropout, batch_first=batch_first, nonlinearity=nonlinearity)
             elif cell == "GRU":
                 model_cpu = torch.nn.GRU(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, bidirectional=bidirectional, bias=bias, dropout=dropout, batch_first=batch_first)
-            
+
             model_cpu.train() if training else model_cpu.eval()
 
             input_dpcpp = input.clone().to(device=device).requires_grad_(training)
@@ -1720,7 +1749,7 @@ def _test_pack_padded_sequence_lstm(self, training):
         rand_seed = int(get_rand_seed())
         print("{} rand sed: {}".format(sys._getframe().f_code.co_name, rand_seed))
         torch.manual_seed(rand_seed)
-        
+
         embedding_dim = 1024
         hidden_dim = 10
         batch_size = 24
@@ -1755,7 +1784,7 @@ def _test_pack_padded_sequence_lstm(self, training):
 
         lstm_out, hidden_out = lstm(embeds, (hidden_0, hidden_1))
         lstm_out, _ = torch.nn.utils.rnn.pad_packed_sequence(lstm_out, batch_first=True)
-        
+
         with AutoDNNL(True):
             lstm_out_dpcpp, hidden_out_dpcpp = lstm_dpcpp(embeds_dpcpp, (hidden_0_dpcpp, hidden_1_dpcpp))
             lstm_out_dpcpp, _ = torch.nn.utils.rnn.pad_packed_sequence(lstm_out_dpcpp, batch_first=True)
@@ -1770,16 +1799,16 @@ def _test_pack_padded_sequence_lstm(self, training):
             self.assertEqual(sentences_dpcpp.grad.to('cpu'), sentences.grad)
             self.assertEqual(lstm_dpcpp.weight_ih_l0.grad.to('cpu'), lstm.weight_ih_l0.grad)
             self.assertEqual(lstm_dpcpp.weight_hh_l0.grad.to('cpu'), lstm.weight_hh_l0.grad)
-            
+
             self.assertEqual(lstm_dpcpp.bias_ih_l0.grad.to('cpu'), lstm.bias_ih_l0.grad)
             self.assertEqual(lstm_dpcpp.bias_hh_l0.grad.to('cpu'), lstm.bias_hh_l0.grad)
-            
+
             self.assertEqual(hidden_0_dpcpp.grad.to('cpu'), hidden_0.grad)
             self.assertEqual(hidden_1_dpcpp.grad.to('cpu'), hidden_1.grad)
 
     def test_lstm_inference(self):
         self._test_lstm(training=False)
-    
+
     def test_lstm_training(self):
         self._test_lstm(training=True)
 
@@ -1937,6 +1966,17 @@ def test_upsample_bilinear2d_scale_factor(self):
             y_dpcpp.sum().backward()
             self.assertEqual(x_cpu.grad, x_dpcpp.grad)
 
+        with AutoDNNL(True):
+            x = torch.randn(2, 2, 4, 4)
+            x_cpu = x.clone().requires_grad_()
+            x_dpcpp = x.clone().to(device=device).requires_grad_()
+            y_cpu = F.interpolate(x_cpu, scale_factor = [2, 3], mode='bilinear', align_corners=False, recompute_scale_factor=False)
+            y_dpcpp = F.interpolate(x_dpcpp, scale_factor = [2, 3], mode='bilinear', align_corners=False, recompute_scale_factor=False)
+            self.assertEqual(y_cpu, y_dpcpp)
+            y_cpu.sum().backward()
+            y_dpcpp.sum().backward()
+            self.assertEqual(x_cpu.grad, x_dpcpp.grad)
+
     def test_upsample_bilinear2d_size(self):
         rand_seed = int(get_rand_seed())
         print("{} rand sed: {}".format(sys._getframe().f_code.co_name, rand_seed))
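
The branch added above is the test-side counterpart of the aten::upsample_bilinear2d.vec / _backward.vec schemas registered in gen-dense-cpu-ops.py: calling F.interpolate with a list scale_factor and recompute_scale_factor=False keeps the per-dimension scale factors, which appears to be the case the new vector overloads cover. A minimal standalone repro (sketch only, outside the test harness; it omits the AutoDNNL toggle the test uses):

    import torch
    import torch.nn.functional as F
    import intel_pytorch_extension as ipex

    x = torch.randn(2, 2, 4, 4)
    x_cpu = x.clone().requires_grad_()
    x_xpu = x.clone().to(ipex.DEVICE).requires_grad_()

    # bilinear upsample with different scale factors per spatial dimension
    y_cpu = F.interpolate(x_cpu, scale_factor=[2, 3], mode='bilinear',
                          align_corners=False, recompute_scale_factor=False)
    y_xpu = F.interpolate(x_xpu, scale_factor=[2, 3], mode='bilinear',
                          align_corners=False, recompute_scale_factor=False)

    # forward results and input gradients should agree between CPU and the extension device
    assert torch.allclose(y_cpu, y_xpu.to('cpu'))
    y_cpu.sum().backward()
    y_xpu.sum().backward()
    assert torch.allclose(x_cpu.grad, x_xpu.grad.to('cpu'))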
