1
- import torch
2
1
import unittest
3
2
import itertools
3
+ from functools import wraps
4
+
5
+ import torch
4
6
import torch .nn as nn
5
7
import torch .nn .functional as F
6
8
from test_jit_llga_utils import JitLlgaTestCase , run_tests , LLGA_FUSION_GROUP
7
9
from torch .testing ._internal .common_utils import TEST_SCIPY
8
10
11
+ import intel_pytorch_extension as ipex
12
+
9
13
try :
10
14
import torchvision
11
15
HAS_TORCHVISION = True
@@ -23,7 +27,21 @@ def get_eltwise_fn(name):
23
27
else :
24
28
raise NameError ('Eltwise function %s not found' % name )
25
29
30
+ # For LLGA UT, disable the PyTorch profiling executor and the IPEX JIT opt
31
+ def llga_test_env (func ):
32
+ @wraps (func )
33
+ def wrapTheFunction (* args ):
34
+ torch ._C ._jit_set_profiling_mode (False )
35
+ torch ._C ._jit_set_profiling_executor (False )
36
+ ipex .core .disable_jit_opt ()
37
+ func (* args )
38
+ ipex .core .enable_jit_opt ()
39
+ torch ._C ._jit_set_profiling_mode (True )
40
+ torch ._C ._jit_set_profiling_executor (True )
41
+ return wrapTheFunction
42
+
26
43
class TestOp (JitLlgaTestCase ):
44
+ @llga_test_env
27
45
def test_conv2d (self ):
28
46
for [
29
47
spatial ,
@@ -68,6 +86,7 @@ def test_conv2d(self):
68
86
]
69
87
self .checkPatterns (graph , patterns )
70
88
89
+ @llga_test_env
71
90
def test_linear (self ):
72
91
for bias in [True , False ]:
73
92
x = torch .rand (32 , 28 )
@@ -86,6 +105,7 @@ def test_linear(self):
86
105
self .checkPatterns (graph , patterns )
87
106
88
107
class TestFusionPattern (JitLlgaTestCase ):
108
+ @llga_test_env
89
109
def test_conv2d_eltwise (self ):
90
110
class M (nn .Module ):
91
111
def __init__ (self , eltwise_fn ):
@@ -122,6 +142,7 @@ def forward(self, x):
122
142
]
123
143
self .checkPatterns (graph , patterns )
124
144
145
+ @llga_test_env
125
146
def test_conv2d_bn (self ):
126
147
class M (nn .Module ):
127
148
def __init__ (self , bias ):
@@ -151,6 +172,7 @@ def forward(self, x):
151
172
]
152
173
self .checkPatterns (graph , patterns )
153
174
175
+ @llga_test_env
154
176
def test_conv2d_bn_relu (self ):
155
177
class M (nn .Module ):
156
178
def __init__ (self ):
@@ -179,6 +201,7 @@ def forward(self, x):
179
201
]
180
202
self .checkPatterns (graph , patterns )
181
203
204
+ @llga_test_env
182
205
def test_linear_eltwise (self ):
183
206
class M (nn .Module ):
184
207
def __init__ (self , eltwise_fn , bias ):
@@ -216,6 +239,7 @@ def forward(self, x):
216
239
]
217
240
self .checkPatterns (graph , patterns )
218
241
242
+ @llga_test_env
219
243
def test_conv2d_sum (self ):
220
244
class M (nn .Module ):
221
245
def __init__ (self , bias = False ):
@@ -260,6 +284,7 @@ def forward(self, x, y):
260
284
# ]
261
285
# self.checkPatterns(graph, patterns)
262
286
287
+ @llga_test_env
263
288
def test_linear_dropout_sum (self ):
264
289
class M (nn .Module ):
265
290
def __init__ (self ):
@@ -295,6 +320,7 @@ def forward(self, x, y):
295
320
296
321
class TestModel (JitLlgaTestCase ):
297
322
@skipIfNoTorchVision
323
+ @llga_test_env
298
324
def _test_vision (self , model_name ):
299
325
m = getattr (torchvision .models , model_name )().eval ()
300
326
x = torch .rand (1 , 3 , 224 , 224 ) / 10
0 commit comments