     IS_SANDCASTLE, load_tests, brute_pdist, brute_cdist, slowTest, \
     skipCUDANonDefaultStreamIf, skipCUDAMemoryLeakCheckIf
 
-torch._C._jit_set_profiling_mode(False)
-torch._C._jit_set_profiling_executor(False)
-
-# device = ipex.DEVICE
 device = 'cpu:0'
 SIZE = 100
 
@@ -117,6 +113,7 @@ def __init__(self, dim, in_channels, out_channels, **kwargs):
         self.conv = conv_module[dim](in_channels, out_channels, bias=False, **kwargs)
         self.bn1 = bn_module[dim](in_channels, eps=0.001)
         self.bn2 = bn_module[dim](out_channels, eps=0.001)
+
     def forward(self, x):
         return self.bn2(self.conv(self.bn1(x)))
@@ -140,6 +137,7 @@ def __init__(self, dim, in_channels, out_channels, **kwargs):
         torch.manual_seed(seed)
         self.conv1 = conv_module[dim](in_channels, out_channels, bias=False, **kwargs)
         self.conv2 = conv_module[dim](in_channels, out_channels, bias=False, **kwargs)
+
     def forward(self, x):
         return torch.cat((self.conv1(x),self.conv2(x)))
@@ -159,6 +157,7 @@ def __init__(self, dim, in_channels, out_channels, **kwargs):
         seed = 2018
         torch.manual_seed(seed)
         self.conv = conv_module[dim](in_channels, out_channels, bias=False, **kwargs)
+
     def forward(self, x):
         return torch.add(F.relu(self.conv(x), inplace=True),self.conv(x))
@@ -169,6 +168,7 @@ def __init__(self, dim, in_channels, out_channels, **kwargs):
         torch.manual_seed(seed)
         self.conv = conv_module[dim](in_channels, out_channels, bias=False, **kwargs)
         self.bn = bn_module[dim](out_channels, eps=0.001)
+
     def forward(self, x):
         return F.relu(self.bn(self.conv(x)), inplace=True)
@@ -204,6 +204,7 @@ def __init__(self, dim, in_channels, out_channels, dest_shape, **kwargs):
         self.dest_shape = dest_shape
         self.conv1 = conv_module[dim](in_channels, out_channels, bias=False, **kwargs)
         self.conv2 = conv_module[dim](in_channels, out_channels, bias=False, **kwargs)
+
     def forward(self, x):
         a = torch.reshape(self.conv1(x), self.dest_shape)
         b = torch.reshape(self.conv2(x), self.dest_shape)
@@ -268,6 +269,7 @@ def __init__(self, in_channels, out_channels,dest_shape, **kwargs):
         torch.manual_seed(seed)
         self.linear = nn.Linear(in_channels, out_channels, **kwargs)
         self.dest_shape = dest_shape
+
     def forward(self, x):
         return F.relu(torch.reshape(self.linear(x),self.dest_shape), inplace=True)
@@ -288,6 +290,7 @@ def __init__(self,dim,in_channels, out_channels, **kwargs):
         torch.manual_seed(seed)
         self.linear = nn.Linear(in_channels, out_channels, **kwargs)
         self.bn = bn_module[dim](1, eps=0.001)
+
     def forward(self, x):
         return self.bn(self.linear(x))
@@ -299,6 +302,7 @@ def __init__(self,dim,in_channels, out_channels,dest_shape,**kwargs):
         self.linear = nn.Linear(in_channels, out_channels, **kwargs)
         self.bn = bn_module[dim](1, eps=0.001)
         self.dest_shape = dest_shape
+
     def forward(self, x):
         return self.bn(torch.reshape(self.linear(x),self.dest_shape))
@@ -409,61 +413,45 @@ class Tester(TestCase):
     def _test_output(self, model, x, kind_in_graph=None, kind_not_in_graph=None):
         modelName = model.__class__.__name__
         core.disable_jit_opt()
-        # core.disable_mix_bf16_fp32()
-
         model = model.eval()
+        model = ipex.optimize(model, dtype=torch.float32)
         if x.dim() == 4:
             x = x.to(memory_format=torch.channels_last)
         with torch.no_grad():
             result = model(x)
 
-        script_model = torch.jit.script(model)
-        script_model.eval()
-
-        trace_model = torch.jit.trace(model, x)
-        trace_model.eval()
+        traced_model = torch.jit.trace(model, x)
+        traced_model.eval()
         with torch.no_grad():
-            sresult = script_model(x)
-            tresult = trace_model(x)
+            tresult = traced_model(x)
 
-        self.assertEqual(result, sresult)
         self.assertEqual(result, tresult)
 
         core.enable_jit_opt()
-        script_fused_model = torch.jit.script(model)
         trace_fused_model = torch.jit.trace(model, x)
         with torch.no_grad():
             # conv relu fusion, conv sum fusion or conv sum relu fusion
-            script_graph = script_fused_model.graph_for(x)
-            # print(script_graph)
-            fused_sresult = script_fused_model(x)
-
             trace_graph = trace_fused_model.graph_for(x)
             # print(trace_graph)
             fused_tresult = trace_fused_model(x)
 
-        self.assertEqual(result, fused_sresult)
         self.assertEqual(result, fused_tresult)
 
         # check if the fused node exists in the graph
         if kind_in_graph is not None:
-            self.assertTrue(any(n.kind() == kind_in_graph for n in script_graph.nodes()))
             self.assertTrue(any(n.kind() == kind_in_graph for n in trace_graph.nodes()))
 
         # check if certain node does not exist in the graph
         if kind_not_in_graph is not None:
-            self.assertTrue(all(n.kind() != kind_not_in_graph for n in script_graph.nodes()))
             self.assertTrue(all(n.kind() != kind_not_in_graph for n in trace_graph.nodes()))
 
 
     def _test_output_bf16(self, model, x, kind_in_graph=None, kind_not_in_graph=None, prec=None):
         modelName = model.__class__.__name__
 
-        # core.enable_auto_dnnl()
         core.enable_jit_opt()
-        # core.enable_mix_bf16_fp32()
-
         model = model.eval()
+        model = ipex.optimize(model, dtype=torch.bfloat16)
         if x.dim() == 4:
             x = x.to(memory_format=torch.channels_last)
         x2 = x.clone()
@@ -472,37 +460,24 @@ def _test_output_bf16(self, model, x, kind_in_graph=None, kind_not_in_graph=None
         with ipex.amp.autocast(enabled=True, configure=ipex.conf.AmpConf(torch.bfloat16)), torch.no_grad():
             # bf16, native path
             result = model(x)
-            # script_fused_model = torch.jit.script(copy.deepcopy(model))
             trace_fused_model = torch.jit.trace(copy.deepcopy(model), x3)
-            # bf16, jit script path
-            # script_graph = script_fused_model.graph_for(x2)
-            # fused_sresult = script_fused_model(x2)
-            # bf 16, jit trace path
+            # bf16, jit trace path
             trace_graph = trace_fused_model.graph_for(x3)
             fused_tresult = trace_fused_model(x3)
 
-        # disable mix_bf16_fp32 when the calculation is done
-        # to avoid affecting other scripts
-        # core.disable_mix_bf16_fp32()
-
-        # self.assertEqual(fused_sresult, result, prec=prec)
         self.assertEqual(fused_tresult, result, prec=prec)
-        # self.assertEqual(result.dtype, torch.bfloat16)
-        # self.assertEqual(fused_sresult.dtype, torch.bfloat16)
         self.assertEqual(fused_tresult.dtype, torch.bfloat16)
 
         # check if the fused node exists in the graph
         if kind_in_graph is not None:
-            # self.assertTrue(any(n.kind() == kind_in_graph for n in script_graph.nodes()))
             self.assertTrue(any(n.kind() == kind_in_graph for n in trace_graph.nodes()))
 
         # check if certain node does not exist in the graph
         if kind_not_in_graph is not None:
-            # self.assertTrue(all(n.kind() != kind_not_in_graph for n in script_graph.nodes()))
             self.assertTrue(all(n.kind() != kind_not_in_graph for n in trace_graph.nodes()))
 
+
     def test_conv2d_fusion(self):
-        # ipex.core.disable_jit_opt()
         batch_size = 32
         out_channels = 64
         in_channels = 3
@@ -694,7 +669,6 @@ def test_output_conv_sum_2d(self):
             kind_in_graph="ipex::conv2d_sum",
             prec=0.1)
 
-
     def test_output_conv_sum_3d(self):
         self._test_output(
             ConvSum(3, 3, 32, kernel_size=3, stride=1),
@@ -706,7 +680,6 @@ def test_output_conv_sum_3d(self):
             kind_in_graph="ipex::conv3d_sum",
             prec=0.04)
 
-
     def test_output_cascaded_conv_bn_sum_relu_2d(self):
         self._test_output(
             CascadedConvBnSumRelu(2, 3, 64, 32, kernel_size=3, stride=1),
@@ -720,7 +693,6 @@ def test_output_cascaded_conv_bn_sum_relu_2d(self):
             kind_not_in_graph="aten::batch_norm",
             prec=0.02)
 
-
    def test_output_cascaded_conv_bn_sum_relu_3d(self):
         self._test_output(
             CascadedConvBnSumRelu(3, 3, 64, 32, kernel_size=3, stride=1),
@@ -734,7 +706,6 @@ def test_output_cascaded_conv_bn_sum_relu_3d(self):
             kind_not_in_graph="aten::batch_norm",
             prec=0.02)
 
-
     def test_output_linear_relu(self):
         self._test_output(
             LinearRelu(3, 32, bias=True),
@@ -790,28 +761,20 @@ def test_output_linear_gelu(self):
             LinearGelu(3, 32, bias=True),
             torch.rand(32, 3),
             kind_in_graph="ipex::linear_gelu")
-        # self._test_output_bf16(
-        #     LinearGelu(3, 32, bias=True),
-        #     torch.rand(32, 3),
-        #     kind_in_graph="ipex::linear_gelu",
-        #     prec=5e-3)
+        self._test_output_bf16(
+            LinearGelu(3, 32, bias=True),
+            torch.rand(32, 3),
+            kind_in_graph="ipex::linear_gelu",
+            prec=5e-3)
         self._test_output(
             LinearGelu(3, 32, bias=False),
             torch.rand(32, 3),
             kind_in_graph="ipex::linear_gelu")
-        # self._test_output_bf16(
-        #     LinearGelu(3, 32, bias=False),
-        #     torch.rand(32, 3),
-        #     kind_in_graph="ipex::linear_gelu",
-        #     prec=5e-3)
-
-
-    # def test_channel_shuffle(self):
-    #     self._test_output(
-    #         ChannelShuffle(10, 16, 50, 50, 4),
-    #         torch.rand(10, 16, 50, 50),
-    #         kind_in_graph="ipex::shuffle_2d")
-
+        self._test_output_bf16(
+            LinearGelu(3, 32, bias=False),
+            torch.rand(32, 3),
+            kind_in_graph="ipex::linear_gelu",
+            prec=5e-3)
 
     def test_jit_function(self):
         # test hool trace and script can works for function
@@ -840,20 +803,7 @@ def test_jit_conv_sum_in_diff_block(self):
             torch.rand(32, 3, 64, 64),
             kind_not_in_graph="ipex::conv2d_sum")
 
-    # def test_manmually_fused_linear_relu(self):
-    #     m = LinearRelu(3, 32, bias=True).eval()
-    #     x = torch.rand(32, 3)
-    #     with torch.no_grad():
-    #         result = m(x)
-    #     fused_m = ipex.LinearRelu(3, 32)
-    #     fused_m.weight = m.linear.weight
-    #     fused_m.bias = m.linear.bias
-    #     with torch.no_grad():
-    #         fused_result = fused_m(x)
-    #     self.assertEqual(fused_result, result)
-
 
 if __name__ == '__main__':
     torch.manual_seed(2020)
-    # core.enable_auto_dnnl()
     test = unittest.main()
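
Editor's note, not part of the diff: the sketch below illustrates, in isolation, the pattern the reworked _test_output helper follows after this change: run the eager model, prepare it with ipex.optimize, trace it with torch.jit.trace, then inspect the traced graph for a fused IPEX op kind. The ConvRelu toy module, the check_fused name, the "ipex::conv2d_relu" kind string, the tolerance, and the import alias are illustrative assumptions rather than code from this file, and the fused node only appears when the IPEX graph-rewrite pass is active (the real test toggles it with core.enable_jit_opt()), so the graph check is reported instead of asserted.

import torch
import torch.nn as nn
import torch.nn.functional as F
import intel_extension_for_pytorch as ipex  # assumed import alias


class ConvRelu(nn.Module):
    # toy module (assumption): conv followed by relu, a pattern IPEX can fuse
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 64, kernel_size=3, stride=1, bias=False)

    def forward(self, x):
        return F.relu(self.conv(x), inplace=True)


def check_fused(model, x, kind_in_graph):
    model = model.eval()
    # same preparation step the updated helper adds for the fp32 path
    model = ipex.optimize(model, dtype=torch.float32)
    if x.dim() == 4:
        x = x.to(memory_format=torch.channels_last)
    with torch.no_grad():
        expected = model(x)            # eager reference result
        traced = torch.jit.trace(model, x)
        traced.eval()
        graph = traced.graph_for(x)    # optimized graph used for this input
        fused_out = traced(x)
    # traced output should match the eager result within a small tolerance
    assert torch.allclose(expected, fused_out, atol=1e-4)
    found = any(n.kind() == kind_in_graph for n in graph.nodes())
    print(f"{kind_in_graph} present in traced graph: {found}")
    return found


if __name__ == "__main__":
    # "ipex::conv2d_relu" is an assumed fused-op kind used only for illustration
    check_fused(ConvRelu(), torch.rand(32, 3, 64, 64), "ipex::conv2d_relu")

The design point visible in the diff is that the torch.jit.script path is dropped and only the traced path is validated, for fp32 via ipex.optimize and for bf16 via ipex.amp.autocast with ipex.conf.AmpConf(torch.bfloat16).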