Skip to content

Commit 31f3014

Browse files
authored
[LLGA] use decorator to change settings for LLGA UT (#31)
1 parent a6dda7f commit 31f3014

File tree

2 files changed

+27
-8
lines changed

2 files changed

+27
-8
lines changed

tests/cpu/test_jit_llga_quantization_fuser.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,15 @@
1-
import torch
21
import unittest
32
import itertools
3+
from functools import wraps
4+
5+
import torch
46
import torch.nn as nn
57
import torch.nn.functional as F
68
from test_jit_llga_utils import JitLlgaTestCase, run_tests, LLGA_FUSION_GROUP
79
from torch.testing._internal.common_utils import TEST_SCIPY
810

11+
import intel_pytorch_extension as ipex
12+
913
try:
1014
import torchvision
1115
HAS_TORCHVISION = True
@@ -23,7 +27,21 @@ def get_eltwise_fn(name):
2327
else:
2428
raise NameError('Eltwise function %s not found' % name)
2529

30+
# For LLGA UT, disable the PyTorch profiling executor and the IPEX JIT opt
31+
def llga_test_env(func):
32+
@wraps(func)
33+
def wrapTheFunction(*args):
34+
torch._C._jit_set_profiling_mode(False)
35+
torch._C._jit_set_profiling_executor(False)
36+
ipex.core.disable_jit_opt()
37+
func(*args)
38+
ipex.core.enable_jit_opt()
39+
torch._C._jit_set_profiling_mode(True)
40+
torch._C._jit_set_profiling_executor(True)
41+
return wrapTheFunction
42+
2643
class TestOp(JitLlgaTestCase):
44+
@llga_test_env
2745
def test_conv2d(self):
2846
for [
2947
spatial,
@@ -68,6 +86,7 @@ def test_conv2d(self):
6886
]
6987
self.checkPatterns(graph, patterns)
7088

89+
@llga_test_env
7190
def test_linear(self):
7291
for bias in [True, False]:
7392
x = torch.rand(32, 28)
@@ -86,6 +105,7 @@ def test_linear(self):
86105
self.checkPatterns(graph, patterns)
87106

88107
class TestFusionPattern(JitLlgaTestCase):
108+
@llga_test_env
89109
def test_conv2d_eltwise(self):
90110
class M(nn.Module):
91111
def __init__(self, eltwise_fn):
@@ -122,6 +142,7 @@ def forward(self, x):
122142
]
123143
self.checkPatterns(graph, patterns)
124144

145+
@llga_test_env
125146
def test_conv2d_bn(self):
126147
class M(nn.Module):
127148
def __init__(self, bias):
@@ -151,6 +172,7 @@ def forward(self, x):
151172
]
152173
self.checkPatterns(graph, patterns)
153174

175+
@llga_test_env
154176
def test_conv2d_bn_relu(self):
155177
class M(nn.Module):
156178
def __init__(self):
@@ -179,6 +201,7 @@ def forward(self, x):
179201
]
180202
self.checkPatterns(graph, patterns)
181203

204+
@llga_test_env
182205
def test_linear_eltwise(self):
183206
class M(nn.Module):
184207
def __init__(self, eltwise_fn, bias):
@@ -216,6 +239,7 @@ def forward(self, x):
216239
]
217240
self.checkPatterns(graph, patterns)
218241

242+
@llga_test_env
219243
def test_conv2d_sum(self):
220244
class M(nn.Module):
221245
def __init__(self, bias=False):
@@ -260,6 +284,7 @@ def forward(self, x, y):
260284
# ]
261285
# self.checkPatterns(graph, patterns)
262286

287+
@llga_test_env
263288
def test_linear_dropout_sum(self):
264289
class M(nn.Module):
265290
def __init__(self):
@@ -295,6 +320,7 @@ def forward(self, x, y):
295320

296321
class TestModel(JitLlgaTestCase):
297322
@skipIfNoTorchVision
323+
@llga_test_env
298324
def _test_vision(self, model_name):
299325
m = getattr(torchvision.models, model_name)().eval()
300326
x = torch.rand(1, 3, 224, 224) / 10

tests/cpu/test_jit_llga_utils.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,6 @@
1414

1515
LLGA_FUSION_GROUP = 'ipex::LlgaFusionGroup'
1616

17-
# disable PyTorch jit profiling
18-
torch._C._jit_set_profiling_mode(False)
19-
torch._C._jit_set_profiling_executor(False)
20-
21-
# disbale ipex jit optimization for fp32 and bf16 path
22-
ipex.core.disable_jit_opt()
23-
2417
def all_backward_graphs(module):
2518
ge_state = module.get_debug_state()
2619
fwd_plan = get_execution_plan(ge_state)

0 commit comments

Comments
 (0)