Commit 5ed7ba8

[LLGA] fix UTs (#11)
* [LLGA] fix unique id
* clear cached weights_scales_
* [LLGA] save the dequant info on the graph instead of in a global variable
* [LLGA] save per_channel zps and scales as tensor instead of vector on the graph
* [LLGA] enable more UTs
* [LLGA] add ut for conv_eltwise inplace
* [LLGA] add check on fusion patterns in the subgraphs
* [LLGA] add config name for different UTs
* [LLGA] use temporary file path for the config in UTs
1 parent 330e277 commit 5ed7ba8
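
The last two bullets (per-UT config names and a temporary config file path) are about keeping each unit test's calibration config isolated so the UTs can run together. A minimal sketch of that idea, for illustration only and not code from this commit (the real logic lives in the JitLlgaTestCase helpers and may differ):

import os
import tempfile

def temp_config_path(config_name):
    # Hypothetical helper: give each UT its own throwaway config file,
    # e.g. /tmp/llga_configs_XXXX/conv2d.json, so tests no longer
    # overwrite a shared cached quantization config.
    config_dir = tempfile.mkdtemp(prefix="llga_configs_")
    return os.path.join(config_dir, config_name + ".json")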

File tree

5 files changed: +334 -192 lines


tests/cpu/test_jit_llga_quantization_fuser.py

Lines changed: 236 additions & 93 deletions
@@ -23,46 +23,68 @@ def get_eltwise_fn(name):
     else:
         raise NameError('Eltwise function %s not found' % name)
 
-# TODO: UTs could run separately but not all together due to IPEX config cache
-#       enable all the UTs once the IPEX config cache is fixed
-
-# class TestOp(JitLlgaTestCase):
-#     def test_conv2d(self):
-#         for [
-#                 spatial,
-#                 in_channels,
-#                 out_channels,
-#                 kernel,
-#                 padding,
-#                 stride,
-#                 dilation,
-#                 g,
-#                 bias
-#             ] in itertools.product(
-#                 [7],
-#                 [8],
-#                 [7],
-#                 [3],
-#                 [0],
-#                 [1],
-#                 [1],
-#                 [1],
-#                 [True, False]):
-
-#             m = nn.Conv2d(in_channels=in_channels * g,
-#                           out_channels=out_channels * g,
-#                           kernel_size=kernel,
-#                           padding=padding,
-#                           stride=stride,
-#                           dilation=dilation,
-#                           groups=g,
-#                           bias=bias)
-#             x = torch.rand(1, in_channels * g, spatial, spatial)
+class TestOp(JitLlgaTestCase):
+    def test_conv2d(self):
+        for [
+                spatial,
+                in_channels,
+                out_channels,
+                kernel,
+                padding,
+                stride,
+                dilation,
+                g,
+                bias
+            ] in itertools.product(
+                [7, 8],
+                [8, 15],
+                [7, 16],
+                [3, 4],
+                [0, 2],
+                [1, 2],
+                [1, 2],
+                [1, 2],
+                [True, False]):
+
+            m = nn.Conv2d(in_channels=in_channels * g,
+                          out_channels=out_channels * g,
+                          kernel_size=kernel,
+                          padding=padding,
+                          stride=stride,
+                          dilation=dilation,
+                          groups=g,
+                          bias=bias)
+            x = torch.rand(1, in_channels * g, spatial, spatial)
+
+            graph = self.checkQuantizeTrace(m, [x], atol=2e-1, config_name="conv2d")
+            self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 4)
+            self.assertFused(graph, ['aten::_convolution', 'aten::quantize_per_tensor', 'aten::quantize_per_channel', 'aten::dequantize'])
 
-#             graph = self.checkQuantizeTrace(m, x, atol=2e-1)
-#             self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 4)
-#             self.assertFused(graph, ['aten::conv2d', 'aten::quantize_per_tensor', 'aten::quantize_per_channel', 'aten::dequantize'])
+            patterns = [
+                ["aten::quantize_per_tensor"],
+                ["aten::quantize_per_channel"],
+                ["aten::dequantize", "aten::_convolution", "aten::quantize_per_tensor"],
+                ["aten::dequantize"]
+            ]
+            self.checkPatterns(graph, patterns)
+
+    def test_linear(self):
+        for bias in [True, False]:
+            x = torch.rand(32, 28)
+            m = torch.nn.Linear(in_features=28, out_features=64, bias=bias)
 
+            graph = self.checkQuantizeTrace(m, [x], atol=1e-1, config_name="linear")
+            self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 4)
+            self.assertFused(graph, ['aten::linear', 'aten::quantize_per_tensor', 'aten::quantize_per_channel', 'aten::dequantize'])
+
+            patterns = [
+                ["aten::quantize_per_tensor"],
+                ["aten::quantize_per_channel"],
+                ["aten::dequantize", "aten::linear", "aten::quantize_per_tensor"],
+                ["aten::dequantize"]
+            ]
+            self.checkPatterns(graph, patterns)
+
 class TestFusionPattern(JitLlgaTestCase):
     def test_conv2d_eltwise(self):
         class M(nn.Module):
@@ -76,72 +98,193 @@ def forward(self, x):
                 x = self.conv1(x)
                 x = self.eltwise(x)
                 x = self.conv2(x)
-                x = self.eltwise(x)
                 return x
 
         for eltwise in ['relu']: # TODO: ['sigmoid', 'sqrt', 'abs', 'square', 'hardtanh']
-            for inplace in [False]:
+            for inplace in [False, True]:
                 eltwise_fn_name = eltwise + '_' if inplace else eltwise
                 eltwise_fn = get_eltwise_fn(eltwise_fn_name)
 
                 m = M(eltwise_fn)
                 x = torch.rand(1, 32, 28, 28)
 
-                graph = self.checkQuantizeTrace(m, x, atol=2e-1)
+                graph = self.checkQuantizeTrace(m, [x], atol=2e-1, config_name="conv2d_eltwise")
                 self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 6)
-                self.assertFused(graph, ['aten::conv2d', 'aten::' + eltwise, 'aten::quantize_per_tensor', 'aten::quantize_per_channel', 'aten::dequantize'])
-
-    # def test_conv2d_bn(self):
-    #     class M(nn.Module):
-    #         def __init__(self):
-    #             super(M, self).__init__()
-    #             self.conv1 = nn.Conv2d(32, 32, 3, padding=1, bias=False) # TODO: bias=True
-    #             self.bn1 = nn.BatchNorm2d(32)
-
-    #         def forward(self, x):
-    #             x = self.conv1(x)
-    #             x = self.bn1(x)
-    #             return x
-
-    #     m = M().eval()
-    #     x = torch.rand(1, 32, 28, 28)
+                self.assertFused(graph, ['aten::_convolution', 'aten::' + eltwise, 'aten::quantize_per_tensor', 'aten::quantize_per_channel', 'aten::dequantize'])
+
+                patterns = [
+                    ["aten::quantize_per_tensor"],
+                    ["aten::quantize_per_channel"],
+                    ["aten::dequantize", "aten::_convolution", 'aten::' + eltwise, "aten::quantize_per_tensor"], # inplace op will become outplace op on the JIT graph
+                    ["aten::quantize_per_channel"],
+                    ["aten::dequantize", "aten::_convolution", "aten::quantize_per_tensor"],
+                    ["aten::dequantize"]
+                ]
+                self.checkPatterns(graph, patterns)
+
+    def test_conv2d_bn(self):
+        class M(nn.Module):
+            def __init__(self, bias):
+                super(M, self).__init__()
+                self.conv1 = nn.Conv2d(32, 5, 3, padding=1, bias=False)
+                self.bn1 = nn.BatchNorm2d(5)
+
+            def forward(self, x):
+                x = self.conv1(x)
+                x = self.bn1(x)
+                return x
+        for bias in [False, True]:
+            m = M(bias).eval()
+            x = torch.rand(1, 32, 16, 16)
+            # TODO: This shape will fail
+            # x = torch.rand(1, 32, 28, 28)
+
+            graph = self.checkQuantizeTrace(m, [x], atol=1e-1, folding=True, config_name="conv2d_bn")
+            self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 4)
+            self.assertFused(graph, ['aten::_convolution', 'aten::quantize_per_tensor', 'aten::quantize_per_channel', 'aten::dequantize'])
+
+            patterns = [
+                ["aten::quantize_per_tensor"],
+                ["aten::quantize_per_channel"],
+                ["aten::dequantize", "aten::_convolution", "aten::quantize_per_tensor"],
+                ["aten::dequantize"]
+            ]
+            self.checkPatterns(graph, patterns)
+
+    def test_conv2d_bn_relu(self):
+        class M(nn.Module):
+            def __init__(self):
+                super(M, self).__init__()
+                self.conv1 = nn.Conv2d(32, 32, 3, padding=1, bias=True)
+                self.bn1 = nn.BatchNorm2d(32)
+
+            def forward(self, x):
+                x = self.conv1(x)
+                x = self.bn1(x)
+                x = F.relu(x)
+                return x
+
+        m = M().eval()
+        x = torch.rand(1, 32, 28, 28)
+        graph = self.checkQuantizeTrace(m, [x], atol=1e-1, folding=True, config_name="conv2d_bn_relu")
+        self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 4)
+        self.assertFused(graph, ['aten::_convolution', 'aten::relu',
+                                 'aten::quantize_per_tensor', 'aten::quantize_per_channel', 'aten::dequantize'])
+
+        patterns = [
+            ["aten::quantize_per_tensor"],
+            ["aten::quantize_per_channel"],
+            ["aten::dequantize", "aten::_convolution", "aten::relu", "aten::quantize_per_tensor"],
+            ["aten::dequantize"]
+        ]
+        self.checkPatterns(graph, patterns)
+
+    def test_linear_eltwise(self):
+        class M(nn.Module):
+            def __init__(self, eltwise_fn, bias):
+                super(M, self).__init__()
+                self.linear = nn.Linear(28, 64, bias)
+                self.eltwise = eltwise_fn
+
+            def forward(self, x):
+                x = self.linear(x)
+                x = self.eltwise(x)
+                return x
+
+        # TODO: use itertools.product once all combinations is supported
+        for [has_bias, eltwise] in [
+            [True, 'relu'],
+            [False, 'relu'],
+            # [True, 'gelu'], # TODO: enable it once linear_gelu default recipe is fixed
+            # [False, 'gelu'], # TODO: enable it once linear_gelu default recipe is fixed
+            [True, 'sigmoid'],
+            [False, 'sigmoid'],
+        ]:
+            eltwise_fn = get_eltwise_fn(eltwise)
+            m = M(eltwise_fn, has_bias)
+            x = torch.rand(32, 28, requires_grad=False)
+
+            graph = self.checkQuantizeTrace(m, [x], atol=1e-1, config_name="linear_eltwise")
+            self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 4)
+            self.assertFused(graph, ['aten::' + eltwise])
+
+            patterns = [
+                ["aten::quantize_per_tensor"],
+                ["aten::quantize_per_channel"],
+                ["aten::dequantize", "aten::linear", "aten::" + eltwise, "aten::quantize_per_tensor"],
+                ["aten::dequantize"]
+            ]
+            self.checkPatterns(graph, patterns)
+
+    def test_conv2d_sum(self):
+        class M(nn.Module):
+            def __init__(self, bias=False):
+                super(M, self).__init__()
+                self.conv1 = nn.Conv2d(32, 32, 3, padding=1, bias=bias)
+                self.bn1 = nn.BatchNorm2d(32)
+                self.conv2 = nn.Conv2d(32, 32, 3, padding=1, bias=bias)
+                self.bn2 = nn.BatchNorm2d(32)
+                self.relu = nn.ReLU()
+                self.conv3 = nn.Conv2d(32, 32, 3, padding=1, bias=bias)
+                self.bn3 = nn.BatchNorm2d(32)
+
+            def forward(self, x, y):
+                x = self.conv1(x)
+                x = self.bn1(x)
+                y = self.conv2(y)
+                y = self.bn2(y)
+                z = self.relu(x + y)
+                z = self.conv3(z)
+                z = self.bn3(z)
+                return z
+
+        for bias in [True, False]:
+            m = M(bias).eval()
+            x = torch.rand(1, 32, 16, 16, requires_grad=False)
+            y = torch.rand(1, 32, 16, 16, requires_grad=False)
+            graph = self.checkQuantizeTrace(m, [x, y], folding=True, atol=1e-1, config_name="conv2d_sum")
+            self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 13) # TODO: nb FUSION_GROUP=10 when oneDNN support sum post_ops with zps
+
+            # TODO: check patterns when oneDNN support sum post_ops with zps
+            # patterns = [
+            #     ["aten::quantize_per_tensor"],
+            #     ["aten::quantize_per_channel"],
+            #     ["aten::dequantize", "aten::_convolution", "aten::quantize_per_tensor"],
+            #     ["aten::quantize_per_channel"],
+            #     ["aten::dequantize", "aten::_convolution", "aten::quantize_per_tensor"],
+            #     ["aten::quantize_per_channel"],
+            #     ["aten::dequantize", "aten::_convolution", "aten::relu", "aten::add", "aten::quantize_per_tensor"],
+            #     ["aten::quantize_per_channel"],
+            #     ["aten::dequantize", "aten::_convolution", "aten::quantize_per_tensor"],
+            #     ["aten::dequantize"]
+            # ]
+            # self.checkPatterns(graph, patterns)
+
+class TestModel(JitLlgaTestCase):
+    @skipIfNoTorchVision
+    def _test_vision(self, model_name):
+        m = getattr(torchvision.models, model_name)().eval()
+        x = torch.rand(1, 3, 224, 224) / 10
+
+        graph = self.checkQuantizeTrace(m, [x], atol=2e-1, folding=True, config_name=model_name)
 
-    #     graph = self.checkQuantizeTrace(m, x, atol=1e-1, folding=True)
-    #     self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 4)
-    #     self.assertFused(graph, ['aten::conv2d', 'aten::quantize_per_tensor', 'aten::quantize_per_channel', 'aten::dequantize'])
-
-
-# class TestModel(JitLlgaTestCase):
-#     @skipIfNoTorchVision
-#     def _test_vision(self, model_name):
-#         m = getattr(torchvision.models, model_name)().eval()
-#         x = torch.rand(1, 3, 224, 224) / 10
-
-#         graph = self.checkQuantizeTrace(m, x, atol=2e-1, folding=True)
-#         # self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 116)
-#         self.assertFused(graph, ['aten::conv2d', 'aten::batch_norm',
-#                                  'aten::relu', 'aten::mm', 'aten::add',
-#                                  'aten::avg_pool2d', 'aten::max_pool2d',
-#                                  'aten::linear'
-#                                  'aten::quantize_per_tensor', 'aten::quantize_per_channel',
-#                                  'aten::dequantize'])
-#         # self.assertFused(graph, ['aten::conv2d', 'aten::batch_norm',
-#         #                          'aten::relu', 'aten::mm', 'aten::add',
-#         #                          'aten::avg_pool2d', 'aten::max_pool2d',
-#         #                          'aten::quantize_per_tensor', 'aten::quantize_per_channel',
-#         #                          'aten::dequantize'])
-
-
-# for model_name, enabled in [
-#     ['resnet50', True],
-# ]:
-#     def wrapper(mname):
-#         @unittest.skipIf(not enabled, 'Disabled')
-#         def test(self):
-#             return self._test_vision(mname)
-#         return test
-
-#     setattr(TestModel, 'test_vision_%s' % model_name, wrapper(model_name))
+        # TODO: aten::adaptive_avg_pool2d also need to be fused once backend supported it
+        self.assertFused(graph, ['aten::_convolution', 'aten::relu',
+                                 'aten::max_pool2d', 'aten::linear'
+                                 'aten::quantize_per_tensor', 'aten::quantize_per_channel',
+                                 'aten::dequantize'])
+
+
+for model_name, enabled in [
+    ['resnet50', True],
+]:
+    def wrapper(mname):
+        @unittest.skipIf(not enabled, 'Disabled')
+        def test(self):
+            return self._test_vision(mname)
+        return test
+
+    setattr(TestModel, 'test_vision_%s' % model_name, wrapper(model_name))
 
 if __name__ == '__main__':
     run_tests()
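
With the helpers exercised in this diff, a new UT would follow the same shape as the tests above. The sketch below is illustrative only (not part of the commit) and assumes the test file's existing imports (torch, nn, JitLlgaTestCase, LLGA_FUSION_GROUP):

class TestLinearSketch(JitLlgaTestCase):
    def test_single_linear(self):
        # Same conventions as the committed tests: inputs passed as a list,
        # a per-test config_name, and fusion patterns checked per LLGA subgraph.
        m = nn.Linear(in_features=28, out_features=64, bias=True)
        x = torch.rand(32, 28)

        graph = self.checkQuantizeTrace(m, [x], atol=1e-1, config_name="single_linear_sketch")
        self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 4)

        patterns = [
            ["aten::quantize_per_tensor"],
            ["aten::quantize_per_channel"],
            ["aten::dequantize", "aten::linear", "aten::quantize_per_tensor"],
            ["aten::dequantize"]
        ]
        self.checkPatterns(graph, patterns)

The file keeps its run_tests() entry point, so the suite can still be run directly as a script.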
