Commit 750d619

[JIT] turn on the profiling executor (#41)
remove convbn folding in JIT path
1 parent fcab70d commit 750d619
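
For reference, the flags this commit stops overriding are TorchScript executor toggles. A minimal sketch of the two modes (standard torch._C internals, shown only to illustrate what "turning on the profiling executor" means here):

import torch

# Legacy (fixed) executor -- what the old test setup forced:
#   torch._C._jit_set_profiling_executor(False)
#   torch._C._jit_set_profiling_mode(False)

# Profiling executor -- the PyTorch default this commit relies on.
# Early runs of a scripted/traced module record shape and dtype
# profiles; later runs execute a type-specialized, optimized graph.
torch._C._jit_set_profiling_executor(True)
torch._C._jit_set_profiling_mode(True)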

File tree: 4 files changed (+28, -103 lines)

intel_pytorch_extension_py/ops/__init__.py

Lines changed: 0 additions & 1 deletion
@@ -3,4 +3,3 @@
 from .lstm import *
 from .interaction import *
 from .embeddingbag import *
-from .jit import *

intel_pytorch_extension_py/ops/jit.py

Lines changed: 0 additions & 27 deletions
This file was deleted.
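
The deleted jit.py carried the Python-side conv+bn folding hook that this commit drops from the JIT path (its 27 lines are not shown in this view). For context, conv+bn folding is the standard transformation sketched below; this is a generic illustration, not the deleted file's contents:

import torch

def fold_conv_bn(conv_w, conv_b, bn_mean, bn_var, bn_eps, bn_w, bn_b):
    # Fold y = bn_w * (conv(x) - mean) / sqrt(var + eps) + bn_b
    # into a single convolution with rescaled weight and bias.
    if conv_b is None:
        conv_b = torch.zeros_like(bn_mean)
    scale = bn_w / torch.sqrt(bn_var + bn_eps)
    folded_w = conv_w * scale.reshape([-1] + [1] * (conv_w.dim() - 1))
    folded_b = (conv_b - bn_mean) * scale + bn_b
    return folded_w, folded_b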

tests/cpu/test_jit.py

Lines changed: 25 additions & 75 deletions
@@ -76,10 +76,6 @@
     IS_SANDCASTLE, load_tests, brute_pdist, brute_cdist, slowTest, \
     skipCUDANonDefaultStreamIf, skipCUDAMemoryLeakCheckIf
 
-torch._C._jit_set_profiling_mode(False)
-torch._C._jit_set_profiling_executor(False)
-
-# device = ipex.DEVICE
 device = 'cpu:0'
 SIZE = 100
 
@@ -117,6 +113,7 @@ def __init__(self, dim, in_channels, out_channels, **kwargs):
         self.conv = conv_module[dim](in_channels, out_channels, bias=False, **kwargs)
         self.bn1 = bn_module[dim](in_channels, eps=0.001)
         self.bn2 = bn_module[dim](out_channels, eps=0.001)
+
     def forward(self, x):
         return self.bn2(self.conv(self.bn1(x)))
 
@@ -140,6 +137,7 @@ def __init__(self, dim, in_channels, out_channels, **kwargs):
         torch.manual_seed(seed)
         self.conv1 = conv_module[dim](in_channels, out_channels, bias=False, **kwargs)
         self.conv2 = conv_module[dim](in_channels, out_channels, bias=False, **kwargs)
+
     def forward(self, x):
         return torch.cat((self.conv1(x),self.conv2(x)))
 
@@ -159,6 +157,7 @@ def __init__(self, dim, in_channels, out_channels, **kwargs):
         seed = 2018
         torch.manual_seed(seed)
         self.conv = conv_module[dim](in_channels, out_channels, bias=False, **kwargs)
+
    def forward(self, x):
         return torch.add(F.relu(self.conv(x), inplace=True),self.conv(x))
 
@@ -169,6 +168,7 @@ def __init__(self, dim, in_channels, out_channels, **kwargs):
         torch.manual_seed(seed)
         self.conv = conv_module[dim](in_channels, out_channels, bias=False, **kwargs)
         self.bn = bn_module[dim](out_channels, eps=0.001)
+
     def forward(self, x):
         return F.relu(self.bn(self.conv(x)), inplace=True)
 
@@ -204,6 +204,7 @@ def __init__(self, dim, in_channels, out_channels, dest_shape, **kwargs):
         self.dest_shape = dest_shape
         self.conv1 = conv_module[dim](in_channels, out_channels, bias=False, **kwargs)
         self.conv2 = conv_module[dim](in_channels, out_channels, bias=False, **kwargs)
+
     def forward(self, x):
         a=torch.reshape(self.conv1(x), self.dest_shape)
         b=torch.reshape(self.conv2(x), self.dest_shape)
@@ -268,6 +269,7 @@ def __init__(self, in_channels, out_channels,dest_shape, **kwargs):
         torch.manual_seed(seed)
         self.linear = nn.Linear(in_channels, out_channels, **kwargs)
         self.dest_shape = dest_shape
+
     def forward(self, x):
         return F.relu(torch.reshape(self.linear(x),self.dest_shape), inplace=True)
 
@@ -288,6 +290,7 @@ def __init__(self,dim,in_channels, out_channels, **kwargs):
         torch.manual_seed(seed)
         self.linear = nn.Linear(in_channels, out_channels, **kwargs)
         self.bn = bn_module[dim](1, eps=0.001)
+
     def forward(self, x):
         return self.bn(self.linear(x))
 
@@ -299,6 +302,7 @@ def __init__(self,dim,in_channels, out_channels,dest_shape,**kwargs):
         self.linear = nn.Linear(in_channels, out_channels, **kwargs)
         self.bn = bn_module[dim](1, eps=0.001)
         self.dest_shape = dest_shape
+
     def forward(self, x):
         return self.bn(torch.reshape(self.linear(x),self.dest_shape))
 
@@ -409,61 +413,45 @@ class Tester(TestCase):
     def _test_output(self, model, x, kind_in_graph=None, kind_not_in_graph=None):
         modelName = model.__class__.__name__
         core.disable_jit_opt()
-        # core.disable_mix_bf16_fp32()
-
         model = model.eval()
+        model = ipex.optimize(model, dtype=torch.float32)
         if x.dim() == 4:
             x = x.to(memory_format=torch.channels_last)
         with torch.no_grad():
             result = model(x)
 
-        script_model = torch.jit.script(model)
-        script_model.eval()
-
-        trace_model = torch.jit.trace(model, x)
-        trace_model.eval()
+        traced_model = torch.jit.trace(model, x)
+        traced_model.eval()
         with torch.no_grad():
-            sresult = script_model(x)
-            tresult = trace_model(x)
+            tresult = traced_model(x)
 
-        self.assertEqual(result, sresult)
         self.assertEqual(result, tresult)
 
         core.enable_jit_opt()
-        script_fused_model = torch.jit.script(model)
         trace_fused_model = torch.jit.trace(model, x)
         with torch.no_grad():
             # conv relu fusion, conv sum fusion or conv sum relu fusion
-            script_graph = script_fused_model.graph_for(x)
-            # print(script_graph)
-            fused_sresult = script_fused_model(x)
-
             trace_graph = trace_fused_model.graph_for(x)
             # print(trace_graph)
             fused_tresult = trace_fused_model(x)
 
-        self.assertEqual(result, fused_sresult)
         self.assertEqual(result, fused_tresult)
 
         # check if the fused node exists in the graph
         if kind_in_graph is not None:
-            self.assertTrue(any(n.kind() == kind_in_graph for n in script_graph.nodes()))
             self.assertTrue(any(n.kind() == kind_in_graph for n in trace_graph.nodes()))
 
         # check if certain node does not exist in the graph
         if kind_not_in_graph is not None:
-            self.assertTrue(all(n.kind() != kind_not_in_graph for n in script_graph.nodes()))
             self.assertTrue(all(n.kind() != kind_not_in_graph for n in trace_graph.nodes()))
 
 
     def _test_output_bf16(self, model, x, kind_in_graph=None, kind_not_in_graph=None, prec=None):
         modelName = model.__class__.__name__
 
-        # core.enable_auto_dnnl()
         core.enable_jit_opt()
-        # core.enable_mix_bf16_fp32()
-
         model = model.eval()
+        model = ipex.optimize(model, dtype=torch.bfloat16)
         if x.dim() == 4:
             x = x.to(memory_format=torch.channels_last)
         x2 = x.clone()
@@ -472,37 +460,24 @@ def _test_output_bf16(self, model, x, kind_in_graph=None, kind_not_in_graph=None
         with ipex.amp.autocast(enabled=True, configure=ipex.conf.AmpConf(torch.bfloat16)), torch.no_grad():
             # bf16, native path
             result = model(x)
-            # script_fused_model = torch.jit.script(copy.deepcopy(model))
             trace_fused_model = torch.jit.trace(copy.deepcopy(model), x3)
-            # bf16, jit script path
-            # script_graph = script_fused_model.graph_for(x2)
-            # fused_sresult = script_fused_model(x2)
-            # bf 16, jit trace path
+            # bf16, jit trace path
             trace_graph = trace_fused_model.graph_for(x3)
             fused_tresult = trace_fused_model(x3)
 
-        # disable mix_bf16_fp32 when the calculation is done
-        # to avoid affecting other scripts
-        # core.disable_mix_bf16_fp32()
-
-        # self.assertEqual(fused_sresult, result, prec=prec)
         self.assertEqual(fused_tresult, result, prec=prec)
-        # self.assertEqual(result.dtype, torch.bfloat16)
-        # self.assertEqual(fused_sresult.dtype, torch.bfloat16)
         self.assertEqual(fused_tresult.dtype, torch.bfloat16)
 
         # check if the fused node exists in the graph
         if kind_in_graph is not None:
-            # self.assertTrue(any(n.kind() == kind_in_graph for n in script_graph.nodes()))
             self.assertTrue(any(n.kind() == kind_in_graph for n in trace_graph.nodes()))
 
         # check if certain node does not exist in the graph
         if kind_not_in_graph is not None:
-            # self.assertTrue(all(n.kind() != kind_not_in_graph for n in script_graph.nodes()))
             self.assertTrue(all(n.kind() != kind_not_in_graph for n in trace_graph.nodes()))
 
+
     def test_conv2d_fusion(self):
-        # ipex.core.disable_jit_opt()
         batch_size = 32
         out_channels = 64
         in_channels = 3
@@ -694,7 +669,6 @@ def test_output_conv_sum_2d(self):
             kind_in_graph="ipex::conv2d_sum",
             prec=0.1)
 
-
     def test_output_conv_sum_3d(self):
         self._test_output(
             ConvSum(3, 3, 32, kernel_size=3, stride=1),
@@ -706,7 +680,6 @@ def test_output_conv_sum_3d(self):
             kind_in_graph="ipex::conv3d_sum",
             prec=0.04)
 
-
     def test_output_cascaded_conv_bn_sum_relu_2d(self):
         self._test_output(
             CascadedConvBnSumRelu(2, 3, 64, 32, kernel_size=3, stride=1),
@@ -720,7 +693,6 @@ def test_output_cascaded_conv_bn_sum_relu_2d(self):
             kind_not_in_graph="aten::batch_norm",
             prec=0.02)
 
-
     def test_output_cascaded_conv_bn_sum_relu_3d(self):
         self._test_output(
             CascadedConvBnSumRelu(3, 3, 64, 32, kernel_size=3, stride=1),
@@ -734,7 +706,6 @@ def test_output_cascaded_conv_bn_sum_relu_3d(self):
             kind_not_in_graph="aten::batch_norm",
             prec=0.02)
 
-
     def test_output_linear_relu(self):
         self._test_output(
             LinearRelu(3, 32, bias=True),
@@ -790,28 +761,20 @@ def test_output_linear_gelu(self):
             LinearGelu(3, 32, bias=True),
             torch.rand(32, 3),
             kind_in_graph="ipex::linear_gelu")
-        # self._test_output_bf16(
-        #     LinearGelu(3, 32, bias=True),
-        #     torch.rand(32, 3),
-        #     kind_in_graph="ipex::linear_gelu",
-        #     prec=5e-3)
+        self._test_output_bf16(
+            LinearGelu(3, 32, bias=True),
+            torch.rand(32, 3),
+            kind_in_graph="ipex::linear_gelu",
+            prec=5e-3)
         self._test_output(
             LinearGelu(3, 32, bias=False),
            torch.rand(32, 3),
             kind_in_graph="ipex::linear_gelu")
-        # self._test_output_bf16(
-        #     LinearGelu(3, 32, bias=False),
-        #     torch.rand(32, 3),
-        #     kind_in_graph="ipex::linear_gelu",
-        #     prec=5e-3)
-
-
-        # def test_channel_shuffle(self):
-        #     self._test_output(
-        #         ChannelShuffle(10, 16, 50, 50, 4),
-        #         torch.rand(10, 16, 50, 50),
-        #         kind_in_graph="ipex::shuffle_2d")
-
+        self._test_output_bf16(
+            LinearGelu(3, 32, bias=False),
+            torch.rand(32, 3),
+            kind_in_graph="ipex::linear_gelu",
+            prec=5e-3)
 
     def test_jit_function(self):
         # test hool trace and script can works for function
@@ -840,20 +803,7 @@ def test_jit_conv_sum_in_diff_block(self):
             torch.rand(32, 3, 64, 64),
             kind_not_in_graph="ipex::conv2d_sum")
 
-    # def test_manmually_fused_linear_relu(self):
-    #     m = LinearRelu(3, 32, bias=True).eval()
-    #     x = torch.rand(32, 3)
-    #     with torch.no_grad():
-    #         result = m(x)
-    #     fused_m = ipex.LinearRelu(3, 32)
-    #     fused_m.weight = m.linear.weight
-    #     fused_m.bias = m.linear.bias
-    #     with torch.no_grad():
-    #         fused_result = fused_m(x)
-    #     self.assertEqual(fused_result, result)
-
 
 if __name__ == '__main__':
     torch.manual_seed(2020)
-    # core.enable_auto_dnnl()
     test = unittest.main()
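
Taken together, the rewritten helpers follow a trace-then-inspect pattern: optimize the eval-mode model, trace it, run it, and assert on the node kinds in graph_for. A standalone sketch of that flow (the import alias, model name, and warm-up count are assumptions; ipex.optimize and graph_for are used exactly as in the diff above):

import torch
import intel_pytorch_extension as ipex  # import alias assumed

model = SomeConvModel().eval()          # SomeConvModel: any eval-mode nn.Module
model = ipex.optimize(model, dtype=torch.float32)
x = torch.rand(32, 3, 64, 64).to(memory_format=torch.channels_last)

with torch.no_grad():
    traced = torch.jit.trace(model, x)
    traced(x)                           # warm-up: the profiling executor
    traced(x)                           # records shapes/dtypes on early runs
    graph = traced.graph_for(x)         # optimized graph for these inputs

# fusion shows up as custom node kinds, e.g. "ipex::conv2d_relu"
print([n.kind() for n in graph.nodes() if n.kind().startswith("ipex::")])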

torch_ipex/csrc/jit/fusion_pass.cpp

Lines changed: 3 additions & 0 deletions
@@ -11,6 +11,7 @@
 #include <torch/csrc/jit/ir/alias_analysis.h>
 #include <torch/csrc/jit/passes/constant_propagation.h>
 #include <torch/csrc/jit/frontend/error_report.h>
+#include <torch/csrc/jit/passes/tensorexpr_fuser.h>
 
 using namespace torch::jit;
 
@@ -304,6 +305,8 @@ OpFuser::RuleTab OpFuser::dnnlRules = {
 };
 
 void FusionPass(std::shared_ptr<Graph> &graph) {
+  RemoveProfileNodesAndSpecializeTypes(graph);
+  RemoveTensorTypeSpecializations(graph);
   // Replace _convolution with conv2d or conv3d
   graph_rewrite::replaceConvolutionWithAtenConv(graph);
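
Why the two new calls come first: with the profiling executor on, the graph handed to a custom pass can still contain prim::profile nodes and tensor-type specializations from profiling runs. RemoveProfileNodesAndSpecializeTypes bakes the recorded types into the graph and strips the profile nodes, and RemoveTensorTypeSpecializations then erases the specialized types, so the dnnlRules patterns can match plain aten ops again. A plain-PyTorch snippet (no IPEX) showing where that profile information comes from:

import torch

def conv_relu(x, w):
    return torch.relu(torch.conv2d(x, w))

scripted = torch.jit.script(conv_relu)
x = torch.rand(1, 3, 8, 8)
w = torch.rand(8, 3, 3, 3)
scripted(x, w)  # profiling run: the executor instruments the graph

# The executor's graph now carries recorded shape/dtype info (and,
# mid-profiling, prim::profile nodes) that a custom fusion pass must
# strip before pattern matching -- which is what the two added passes
# do at the top of FusionPass.
print(torch.jit.last_executed_optimized_graph())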
