@@ -131,7 +131,7 @@ def __init__(self, dim, in_channels, out_channels, dest_shape, **kwargs):
 
     def forward(self, x):
         conv_output = self.conv(x)
-        return self.bn(torch.reshape(conv_output, self.dest_shape))
+        return self.bn(torch.reshape(conv_output, self.dest_shape))
 
 class Conv_Conv_Concat(nn.Module):
     def __init__(self, dim, in_channels, out_channels, **kwargs):
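For context beyond this hunk: the forward above implements a conv → reshape → batch-norm chain, where dest_shape decides the layout the BatchNorm sees. A self-contained sketch of such a module, runnable on its own (the constructor, the conv_module/bn_module maps, and the concrete shapes are assumptions, not lines from this file):

import torch
import torch.nn as nn

# Assumed dim -> module maps, mirroring the conv_module/bn_module
# convention the constructors in this file appear to use.
conv_module = {2: nn.Conv2d}
bn_module = {2: nn.BatchNorm2d}

class ConvReshapeBnSketch(nn.Module):
    def __init__(self, dim, in_channels, out_channels, dest_shape, **kwargs):
        super().__init__()
        self.conv = conv_module[dim](in_channels, out_channels, bias=False, **kwargs)
        self.dest_shape = dest_shape
        # BatchNorm runs on the reshaped tensor, so its channel count
        # comes from dest_shape[1], not from out_channels.
        self.bn = bn_module[dim](dest_shape[1])

    def forward(self, x):
        conv_output = self.conv(x)
        return self.bn(torch.reshape(conv_output, self.dest_shape))

# (1, 8, 16, 16) from the conv has the same element count as (1, 16, 16, 8).
m = ConvReshapeBnSketch(2, 3, 8, (1, 16, 16, 8), kernel_size=3).eval()
print(m(torch.rand(1, 3, 18, 18)).shape)  # torch.Size([1, 16, 16, 8])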
@@ -240,7 +240,7 @@ def __init__(self, in_channels, out_channels, **kwargs):
 
     def forward(self, x):
         return F.relu(self.linear(x), inplace=True)
-
+
 class LinearGelu(nn.Module):
     def __init__(self, in_channels, out_channels, **kwargs):
         super(LinearGelu, self).__init__()
@@ -308,7 +308,7 @@ def __init__(self, dim, in_channels, out_channels, **kwargs):
         seed = 2018
         torch.manual_seed(seed)
         self.conv = conv_module[dim](in_channels, out_channels, bias=False, **kwargs)
-
+
     def forward(self, x):
         y = self.conv(x)
         if y.size(1) != x.size(1):
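This hunk ends inside forward, right at the channel-count check. Purely as a hedged sketch, one plausible continuation is to zero-pad the input channels so the residual add stays well-formed; the 2D-only case, the padding branch, and the module name below are all assumptions (the real body lives outside this diff):

import torch
import torch.nn as nn

class ConvSumSketch(nn.Module):
    # Hypothetical stand-in: only the lines visible in the hunk are taken
    # from the file; the padding branch below is an assumed continuation.
    def __init__(self, in_channels, out_channels, **kwargs):
        super().__init__()
        torch.manual_seed(2018)
        self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)

    def forward(self, x):
        y = self.conv(x)
        if y.size(1) != x.size(1):
            # Zero-pad the input channels so the elementwise residual add
            # lines up with the conv's output channels.
            pad = torch.zeros(x.size(0), y.size(1) - x.size(1), *x.shape[2:],
                              dtype=x.dtype, device=x.device)
            x = torch.cat([x, pad], dim=1)
        return y + x

# kernel_size=1 keeps the spatial dims equal so the add is valid.
m = ConvSumSketch(3, 8, kernel_size=1).eval()
print(m(torch.rand(1, 3, 16, 16)).shape)  # torch.Size([1, 8, 16, 16])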
@@ -754,36 +754,36 @@ def test_output_linear_relu(self):
             torch.rand(32, 3),
             kind_in_graph="ipex::linear_relu",
             prec=0.02)
-
+
     def test_output_linear_add(self):
         self._test_output(
             LinearAdd(3, 32, bias=True),
             torch.rand(32, 3),
-            kind_in_graph="aten::linear")
+            kind_in_graph="ipex::linear")
 
     def test_output_linear_reshape_relu(self):
         self._test_output(
             Linear_Reshape_Relu(3, 32, (64, 16), bias=True),
             torch.rand(32, 3),
-            kind_in_graph="aten::linear")
+            kind_in_graph="ipex::linear")
 
     def test_output_linear_sigmoid(self):
         self._test_output(
             LinearSigmoid(3, 32, bias=True),
             torch.rand(32, 3),
-            kind_in_graph="aten::linear")
+            kind_in_graph="ipex::linear")
 
     def test_output_linear_bn(self):
         self._test_output(
             LinearBn(2, 32, 32, bias=True),
             torch.rand(1, 1, 32, 32),
-            kind_in_graph="aten::linear")
+            kind_in_graph="ipex::linear")
 
     def test_output_linear_reshape_bn(self):
         self._test_output(
             Linear_Reshape_Bn(2, 32, 32, (1, 1, 64, 16), bias=True),
             torch.rand(1, 1, 32, 32),
-            kind_in_graph="aten::linear")
+            kind_in_graph="ipex::linear")
 
     def test_output_linear_gelu(self):
         self._test_output(
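The assertions above switch the expected node kind from aten::linear to ipex::linear, i.e. the optimized graph should now contain the IPEX-fused linear op rather than the stock ATen one. The _test_output helper is not part of this diff; as a rough, hedged illustration of what a kind_in_graph check amounts to (names and flow below are assumptions):

import torch
import torch.nn as nn

def kind_in_graph_sketch(model, x, kind):
    # Trace the model, run it once so JIT optimization passes execute,
    # then scan the optimized graph for the expected node kind.
    model = model.eval()
    with torch.no_grad():
        traced = torch.jit.trace(model, x)
        traced(x)
        graph = traced.graph_for(x)
    return any(node.kind() == kind for node in graph.nodes())

# With intel_extension_for_pytorch imported and its fusion passes active,
# the expectation in these tests is "ipex::linear"; in stock PyTorch the
# same model keeps the plain ATen kind.
print(kind_in_graph_sketch(nn.Linear(3, 32, bias=True), torch.rand(32, 3), "aten::linear"))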
@@ -856,4 +856,4 @@ def test_jit_conv_sum_in_diff_block(self):
 if __name__ == '__main__':
     torch.manual_seed(2020)
     # core.enable_auto_dnnl()
-    test = unittest.main()
+    test = unittest.main()