@@ -23,46 +23,68 @@ def get_eltwise_fn(name):
     else:
         raise NameError('Eltwise function %s not found' % name)
 
-# TODO: UTs could run separately but not all together due to IPEX config cache
-# enable all the UTs once the IPEX config cache is fixed
-
-# class TestOp(JitLlgaTestCase):
-#     def test_conv2d(self):
-#         for [
-#                 spatial,
-#                 in_channels,
-#                 out_channels,
-#                 kernel,
-#                 padding,
-#                 stride,
-#                 dilation,
-#                 g,
-#                 bias
-#         ] in itertools.product(
-#                 [7],
-#                 [8],
-#                 [7],
-#                 [3],
-#                 [0],
-#                 [1],
-#                 [1],
-#                 [1],
-#                 [True, False]):
-
-#             m = nn.Conv2d(in_channels=in_channels * g,
-#                           out_channels=out_channels * g,
-#                           kernel_size=kernel,
-#                           padding=padding,
-#                           stride=stride,
-#                           dilation=dilation,
-#                           groups=g,
-#                           bias=bias)
-#             x = torch.rand(1, in_channels * g, spatial, spatial)
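+# NOTE: the re-enabled tests below pass a unique config_name to
+# checkQuantizeTrace so each test writes its own int8 configuration file,
+# presumably sidestepping the shared IPEX config cache noted in the removed TODO.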
+class TestOp(JitLlgaTestCase):
+    def test_conv2d(self):
+        for [
+                spatial,
+                in_channels,
+                out_channels,
+                kernel,
+                padding,
+                stride,
+                dilation,
+                g,
+                bias
+        ] in itertools.product(
+                [7, 8],
+                [8, 15],
+                [7, 16],
+                [3, 4],
+                [0, 2],
+                [1, 2],
+                [1, 2],
+                [1, 2],
+                [True, False]):
+
+            m = nn.Conv2d(in_channels=in_channels * g,
+                          out_channels=out_channels * g,
+                          kernel_size=kernel,
+                          padding=padding,
+                          stride=stride,
+                          dilation=dilation,
+                          groups=g,
+                          bias=bias)
+            x = torch.rand(1, in_channels * g, spatial, spatial)
+
+            graph = self.checkQuantizeTrace(m, [x], atol=2e-1, config_name="conv2d")
+            self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 4)
+            self.assertFused(graph, ['aten::_convolution', 'aten::quantize_per_tensor', 'aten::quantize_per_channel', 'aten::dequantize'])
 
-#             graph = self.checkQuantizeTrace(m, x, atol=2e-1)
-#             self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 4)
-#             self.assertFused(graph, ['aten::conv2d', 'aten::quantize_per_tensor', 'aten::quantize_per_channel', 'aten::dequantize'])
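+            # Four partitions are expected above: quantize the input, quantize the
+            # weight per-channel, the fused dequantize->convolution->quantize group,
+            # and the trailing dequantize of the output.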
+            patterns = [
+                ["aten::quantize_per_tensor"],
+                ["aten::quantize_per_channel"],
+                ["aten::dequantize", "aten::_convolution", "aten::quantize_per_tensor"],
+                ["aten::dequantize"]
+            ]
+            self.checkPatterns(graph, patterns)
+
+    def test_linear(self):
+        for bias in [True, False]:
+            x = torch.rand(32, 28)
+            m = torch.nn.Linear(in_features=28, out_features=64, bias=bias)
 
+            graph = self.checkQuantizeTrace(m, [x], atol=1e-1, config_name="linear")
+            self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 4)
+            self.assertFused(graph, ['aten::linear', 'aten::quantize_per_tensor', 'aten::quantize_per_channel', 'aten::dequantize'])
+
+            patterns = [
+                ["aten::quantize_per_tensor"],
+                ["aten::quantize_per_channel"],
+                ["aten::dequantize", "aten::linear", "aten::quantize_per_tensor"],
+                ["aten::dequantize"]
+            ]
+            self.checkPatterns(graph, patterns)
+
 class TestFusionPattern(JitLlgaTestCase):
     def test_conv2d_eltwise(self):
         class M(nn.Module):
@@ -76,72 +98,193 @@ def forward(self, x):
                 x = self.conv1(x)
                 x = self.eltwise(x)
                 x = self.conv2(x)
-                x = self.eltwise(x)
                 return x
 
         for eltwise in ['relu']: # TODO: ['sigmoid', 'sqrt', 'abs', 'square', 'hardtanh']
-            for inplace in [False]:
+            for inplace in [False, True]:
                 eltwise_fn_name = eltwise + '_' if inplace else eltwise
                 eltwise_fn = get_eltwise_fn(eltwise_fn_name)
 
                 m = M(eltwise_fn)
                 x = torch.rand(1, 32, 28, 28)
 
-                graph = self.checkQuantizeTrace(m, x, atol=2e-1)
+                graph = self.checkQuantizeTrace(m, [x], atol=2e-1, config_name="conv2d_eltwise")
                 self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 6)
-                self.assertFused(graph, ['aten::conv2d', 'aten::' + eltwise, 'aten::quantize_per_tensor', 'aten::quantize_per_channel', 'aten::dequantize'])
-
-    # def test_conv2d_bn(self):
-    #     class M(nn.Module):
-    #         def __init__(self):
-    #             super(M, self).__init__()
-    #             self.conv1 = nn.Conv2d(32, 32, 3, padding=1, bias=False) # TODO: bias=True
-    #             self.bn1 = nn.BatchNorm2d(32)
-
-    #         def forward(self, x):
-    #             x = self.conv1(x)
-    #             x = self.bn1(x)
-    #             return x
-
-    #     m = M().eval()
-    #     x = torch.rand(1, 32, 28, 28)
+                self.assertFused(graph, ['aten::_convolution', 'aten::' + eltwise, 'aten::quantize_per_tensor', 'aten::quantize_per_channel', 'aten::dequantize'])
+
+                patterns = [
+                    ["aten::quantize_per_tensor"],
+                    ["aten::quantize_per_channel"],
+                    ["aten::dequantize", "aten::_convolution", 'aten::' + eltwise, "aten::quantize_per_tensor"], # an in-place op becomes an out-of-place op on the JIT graph
+                    ["aten::quantize_per_channel"],
+                    ["aten::dequantize", "aten::_convolution", "aten::quantize_per_tensor"],
+                    ["aten::dequantize"]
+                ]
+                self.checkPatterns(graph, patterns)
+
+    def test_conv2d_bn(self):
+        class M(nn.Module):
+            def __init__(self, bias):
+                super(M, self).__init__()
+                self.conv1 = nn.Conv2d(32, 5, 3, padding=1, bias=False)
+                self.bn1 = nn.BatchNorm2d(5)
+
+            def forward(self, x):
+                x = self.conv1(x)
+                x = self.bn1(x)
+                return x
+        for bias in [False, True]:
+            m = M(bias).eval()
+            x = torch.rand(1, 32, 16, 16)
+            # TODO: This shape will fail
+            # x = torch.rand(1, 32, 28, 28)
+
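+            # folding=True asks checkQuantizeTrace to fold BatchNorm into the
+            # preceding convolution before quantization (inferred from the flag).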
+            graph = self.checkQuantizeTrace(m, [x], atol=1e-1, folding=True, config_name="conv2d_bn")
+            self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 4)
+            self.assertFused(graph, ['aten::_convolution', 'aten::quantize_per_tensor', 'aten::quantize_per_channel', 'aten::dequantize'])
+
+            patterns = [
+                ["aten::quantize_per_tensor"],
+                ["aten::quantize_per_channel"],
+                ["aten::dequantize", "aten::_convolution", "aten::quantize_per_tensor"],
+                ["aten::dequantize"]
+            ]
+            self.checkPatterns(graph, patterns)
+
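+    # With folding, conv + bn collapse into a single aten::_convolution, so the
+    # next test expects relu to join that same fusion partition (see its patterns).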
+    def test_conv2d_bn_relu(self):
+        class M(nn.Module):
+            def __init__(self):
+                super(M, self).__init__()
+                self.conv1 = nn.Conv2d(32, 32, 3, padding=1, bias=True)
+                self.bn1 = nn.BatchNorm2d(32)
+
+            def forward(self, x):
+                x = self.conv1(x)
+                x = self.bn1(x)
+                x = F.relu(x)
+                return x
+
+        m = M().eval()
+        x = torch.rand(1, 32, 28, 28)
+        graph = self.checkQuantizeTrace(m, [x], atol=1e-1, folding=True, config_name="conv2d_bn_relu")
+        self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 4)
+        self.assertFused(graph, ['aten::_convolution', 'aten::relu',
+                                 'aten::quantize_per_tensor', 'aten::quantize_per_channel', 'aten::dequantize'])
+
+        patterns = [
+            ["aten::quantize_per_tensor"],
+            ["aten::quantize_per_channel"],
+            ["aten::dequantize", "aten::_convolution", "aten::relu", "aten::quantize_per_tensor"],
+            ["aten::dequantize"]
+        ]
+        self.checkPatterns(graph, patterns)
+
+    def test_linear_eltwise(self):
+        class M(nn.Module):
+            def __init__(self, eltwise_fn, bias):
+                super(M, self).__init__()
+                self.linear = nn.Linear(28, 64, bias)
+                self.eltwise = eltwise_fn
+
+            def forward(self, x):
+                x = self.linear(x)
+                x = self.eltwise(x)
+                return x
+
+        # TODO: use itertools.product once all combinations are supported
+        for [has_bias, eltwise] in [
+            [True, 'relu'],
+            [False, 'relu'],
+            # [True, 'gelu'], # TODO: enable it once the linear_gelu default recipe is fixed
+            # [False, 'gelu'], # TODO: enable it once the linear_gelu default recipe is fixed
+            [True, 'sigmoid'],
+            [False, 'sigmoid'],
+        ]:
+            eltwise_fn = get_eltwise_fn(eltwise)
+            m = M(eltwise_fn, has_bias)
+            x = torch.rand(32, 28, requires_grad=False)
+
+            graph = self.checkQuantizeTrace(m, [x], atol=1e-1, config_name="linear_eltwise")
+            self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 4)
+            self.assertFused(graph, ['aten::' + eltwise])
+
+            patterns = [
+                ["aten::quantize_per_tensor"],
+                ["aten::quantize_per_channel"],
+                ["aten::dequantize", "aten::linear", "aten::" + eltwise, "aten::quantize_per_tensor"],
+                ["aten::dequantize"]
+            ]
+            self.checkPatterns(graph, patterns)
+
+    def test_conv2d_sum(self):
+        class M(nn.Module):
+            def __init__(self, bias=False):
+                super(M, self).__init__()
+                self.conv1 = nn.Conv2d(32, 32, 3, padding=1, bias=bias)
+                self.bn1 = nn.BatchNorm2d(32)
+                self.conv2 = nn.Conv2d(32, 32, 3, padding=1, bias=bias)
+                self.bn2 = nn.BatchNorm2d(32)
+                self.relu = nn.ReLU()
+                self.conv3 = nn.Conv2d(32, 32, 3, padding=1, bias=bias)
+                self.bn3 = nn.BatchNorm2d(32)
+
+            def forward(self, x, y):
+                x = self.conv1(x)
+                x = self.bn1(x)
+                y = self.conv2(y)
+                y = self.bn2(y)
+                z = self.relu(x + y)
+                z = self.conv3(z)
+                z = self.bn3(z)
+                return z
+
+        for bias in [True, False]:
+            m = M(bias).eval()
+            x = torch.rand(1, 32, 16, 16, requires_grad=False)
+            y = torch.rand(1, 32, 16, 16, requires_grad=False)
+            graph = self.checkQuantizeTrace(m, [x, y], folding=True, atol=1e-1, config_name="conv2d_sum")
+            self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 13) # TODO: nb FUSION_GROUP=10 once oneDNN supports sum post_ops with zps
+
+            # TODO: check patterns once oneDNN supports sum post_ops with zps
+            # patterns = [
+            #     ["aten::quantize_per_tensor"],
+            #     ["aten::quantize_per_channel"],
+            #     ["aten::dequantize", "aten::_convolution", "aten::quantize_per_tensor"],
+            #     ["aten::quantize_per_channel"],
+            #     ["aten::dequantize", "aten::_convolution", "aten::quantize_per_tensor"],
+            #     ["aten::quantize_per_channel"],
+            #     ["aten::dequantize", "aten::_convolution", "aten::relu", "aten::add", "aten::quantize_per_tensor"],
+            #     ["aten::quantize_per_channel"],
+            #     ["aten::dequantize", "aten::_convolution", "aten::quantize_per_tensor"],
+            #     ["aten::dequantize"]
+            # ]
+            # self.checkPatterns(graph, patterns)
+
+class TestModel(JitLlgaTestCase):
+    @skipIfNoTorchVision
+    def _test_vision(self, model_name):
+        m = getattr(torchvision.models, model_name)().eval()
+        x = torch.rand(1, 3, 224, 224) / 10
+
+        graph = self.checkQuantizeTrace(m, [x], atol=2e-1, folding=True, config_name=model_name)
 
-    #     graph = self.checkQuantizeTrace(m, x, atol=1e-1, folding=True)
-    #     self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 4)
-    #     self.assertFused(graph, ['aten::conv2d', 'aten::quantize_per_tensor', 'aten::quantize_per_channel', 'aten::dequantize'])
-
-
-# class TestModel(JitLlgaTestCase):
-#     @skipIfNoTorchVision
-#     def _test_vision(self, model_name):
-#         m = getattr(torchvision.models, model_name)().eval()
-#         x = torch.rand(1, 3, 224, 224) / 10
-
-#         graph = self.checkQuantizeTrace(m, x, atol=2e-1, folding=True)
-#         # self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 116)
-#         self.assertFused(graph, ['aten::conv2d', 'aten::batch_norm',
-#                                  'aten::relu', 'aten::mm', 'aten::add',
-#                                  'aten::avg_pool2d', 'aten::max_pool2d',
-#                                  'aten::linear'
-#                                  'aten::quantize_per_tensor', 'aten::quantize_per_channel',
-#                                  'aten::dequantize'])
-#         # self.assertFused(graph, ['aten::conv2d', 'aten::batch_norm',
-#         #                          'aten::relu', 'aten::mm', 'aten::add',
-#         #                          'aten::avg_pool2d', 'aten::max_pool2d',
-#         #                          'aten::quantize_per_tensor', 'aten::quantize_per_channel',
-#         #                          'aten::dequantize'])
-
-
-# for model_name, enabled in [
-#     ['resnet50', True],
-# ]:
-#     def wrapper(mname):
-#         @unittest.skipIf(not enabled, 'Disabled')
-#         def test(self):
-#             return self._test_vision(mname)
-#         return test
-
-#     setattr(TestModel, 'test_vision_%s' % model_name, wrapper(model_name))
+        # TODO: aten::adaptive_avg_pool2d also needs to be fused once the backend supports it
+        self.assertFused(graph, ['aten::_convolution', 'aten::relu',
+                                 'aten::max_pool2d', 'aten::linear',
+                                 'aten::quantize_per_tensor', 'aten::quantize_per_channel',
+                                 'aten::dequantize'])
+
+
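+# Dynamically register one test_vision_<model_name> method per entry below; the
+# wrapper closure binds mname so each generated test traces its own torchvision
+# model, and @unittest.skipIf disables entries whose 'enabled' flag is False.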
+for model_name, enabled in [
+    ['resnet50', True],
+]:
+    def wrapper(mname):
+        @unittest.skipIf(not enabled, 'Disabled')
+        def test(self):
+            return self._test_vision(mname)
+        return test
+
+    setattr(TestModel, 'test_vision_%s' % model_name, wrapper(model_name))
 
 if __name__ == '__main__':
     run_tests()