Skip to content

Commit a6dda7f

Browse files
fix jit test issue (#32)
1 parent 8a5f2ee commit a6dda7f

File tree

1 file changed

+10
-10
lines changed

1 file changed

+10
-10
lines changed

tests/cpu/test_jit.py

Lines changed: 10 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -131,7 +131,7 @@ def __init__(self, dim, in_channels, out_channels, dest_shape, **kwargs):
131131

132132
def forward(self, x):
133133
conv_output = self.conv(x)
134-
return self.bn(torch.reshape(conv_output, self.dest_shape))
134+
return self.bn(torch.reshape(conv_output, self.dest_shape))
135135

136136
class Conv_Conv_Concat(nn.Module):
137137
def __init__(self, dim, in_channels, out_channels, **kwargs):
@@ -240,7 +240,7 @@ def __init__(self, in_channels, out_channels, **kwargs):
240240

241241
def forward(self, x):
242242
return F.relu(self.linear(x), inplace=True)
243-
243+
244244
class LinearGelu(nn.Module):
245245
def __init__(self, in_channels, out_channels, **kwargs):
246246
super(LinearGelu, self).__init__()
@@ -308,7 +308,7 @@ def __init__(self, dim, in_channels, out_channels, **kwargs):
308308
seed = 2018
309309
torch.manual_seed(seed)
310310
self.conv = conv_module[dim](in_channels, out_channels, bias=False, **kwargs)
311-
311+
312312
def forward(self, x):
313313
y = self.conv(x)
314314
if y.size(1) != x.size(1):
@@ -754,36 +754,36 @@ def test_output_linear_relu(self):
754754
torch.rand(32, 3),
755755
kind_in_graph="ipex::linear_relu",
756756
prec=0.02)
757-
757+
758758
def test_output_linear_add(self):
759759
self._test_output(
760760
LinearAdd(3, 32, bias=True),
761761
torch.rand(32, 3),
762-
kind_in_graph="aten::linear")
762+
kind_in_graph="ipex::linear")
763763

764764
def test_output_linear_reshape_relu(self):
765765
self._test_output(
766766
Linear_Reshape_Relu(3, 32,(64,16),bias=True),
767767
torch.rand(32, 3),
768-
kind_in_graph="aten::linear")
768+
kind_in_graph="ipex::linear")
769769

770770
def test_output_linear_sigmoid(self):
771771
self._test_output(
772772
LinearSigmoid(3, 32, bias=True),
773773
torch.rand(32, 3),
774-
kind_in_graph="aten::linear")
774+
kind_in_graph="ipex::linear")
775775

776776
def test_output_linear_bn(self):
777777
self._test_output(
778778
LinearBn(2 ,32, 32, bias=True),
779779
torch.rand(1, 1, 32, 32),
780-
kind_in_graph="aten::linear")
780+
kind_in_graph="ipex::linear")
781781

782782
def test_output_linear_reshape_bn(self):
783783
self._test_output(
784784
Linear_Reshape_Bn(2 ,32, 32,(1,1,64,16),bias=True),
785785
torch.rand(1, 1, 32, 32),
786-
kind_in_graph="aten::linear")
786+
kind_in_graph="ipex::linear")
787787

788788
def test_output_linear_gelu(self):
789789
self._test_output(
@@ -856,4 +856,4 @@ def test_jit_conv_sum_in_diff_block(self):
856856
if __name__ == '__main__':
857857
torch.manual_seed(2020)
858858
# core.enable_auto_dnnl()
859-
test = unittest.main()
859+
test = unittest.main()

0 commit comments

Comments (0)