Skip to content

Commit b658b7e

Browse files
add UT for autocast with jit trace (#21)
1 parent aa597e8 commit b658b7e

File tree

1 file changed

+61
-30
lines changed

1 file changed

+61
-30
lines changed

tests/cpu/test_autocast.py

Lines changed: 61 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import intel_pytorch_extension as ipex
55
from common_utils import TestCase
66
import time, sys
7+
from torch.testing._core import _get_default_tolerance
78

89
def get_rand_seed():
910
return int(time.time() * 1000000000)
@@ -72,36 +73,66 @@ def test_conv2d_backward(self):
7273
self.assertEqual(in_autocast.grad.dtype, torch.float)
7374
self.assertEqual(_in_cpu.grad, in_autocast.grad, 1e-2)
7475

75-
class SimpleNet(torch.nn.Module):
76-
def __init__(self):
77-
super(SimpleNet, self).__init__()
78-
self.conv = torch.nn.Conv2d(3, 16, (3, 3), stride=(2, 2), padding=(3, 3), bias=False)
79-
self.bn = torch.nn.BatchNorm2d(16)
80-
self.relu = torch.nn.ReLU(inplace=True)
81-
82-
def forward(self, x):
83-
x = self.conv(x)
84-
x = self.bn(x)
85-
x = self.relu(x)
86-
return x
87-
88-
class TestSimpleNet(TestCase):
89-
def test_generate_jit_trace_model(self):
90-
rand_seed = int(get_rand_seed())
91-
print("{} rand sed: {}".format(sys._getframe().f_code.co_name, rand_seed))
92-
torch.manual_seed(rand_seed)
93-
94-
model = SimpleNet()
95-
model.eval()
96-
#ipex.core.disable_jit_opt()
97-
x = torch.rand((1, 3, 224, 224))
98-
with ipex.amp.autocast(enabled=True, configure=ipex.conf.AmpConf(torch.bfloat16)), torch.no_grad():
99-
traced_model = torch.jit.trace(model, x)
100-
with torch.no_grad():
101-
y = traced_model(x)
102-
#print(traced_model.graph_for(x))
103-
#ipex.core.enable_jit_opt()
104-
self.assertEqual(y.dtype, torch.float) #conv whitelist, bn blacklist, relu fallthrough
76+
class TestAutocastWithJit(TestCase):
77+
def setUp(self):
78+
super(TestAutocastWithJit, self).setUp()
79+
from test_jit import Conv_Bn_Relu, BatchNorm_Conv_BatchNorm, ConvBatchNorm_Fixed, ConvReshapeBatchNorm,\
80+
CascadedConvBnSumRelu, LinearBn, Linear_Reshape_Bn
81+
self.models = [Conv_Bn_Relu(2, 3, 32, kernel_size=3, stride=1), BatchNorm_Conv_BatchNorm(2, 3, 32, kernel_size=3, stride=1),\
82+
ConvBatchNorm_Fixed(2, 3, 32, kernel_size=3, stride=1), ConvBatchNorm_Fixed(3, 3, 32, kernel_size=3, stride=1),\
83+
ConvReshapeBatchNorm(2, 3, 32, (64, 16, 62, 62), kernel_size=3, stride=1),\
84+
CascadedConvBnSumRelu(2, 3, 64, 32, kernel_size=3, stride=1),\
85+
LinearBn(2 ,32, 32, bias=True),\
86+
Linear_Reshape_Bn(2 ,32, 32,(1,1,64,16),bias=True)]
87+
self.inputs = [torch.randn(32, 3, 64, 64), torch.randn(32, 3, 64, 64),\
88+
torch.randn(32, 3, 64, 64), torch.randn(32, 3, 32, 32, 32),\
89+
torch.randn(32, 3, 64, 64),\
90+
torch.rand(32, 3, 64, 64),\
91+
torch.rand(1, 1, 32, 32),\
92+
torch.rand(1, 1, 32, 32)]
93+
94+
def test_generate_autocast_jit_trace_model(self):
95+
def test_generate_autocast_jit_trace_model(model, x):
96+
model.eval()
97+
ipex.core.disable_jit_opt()
98+
with ipex.amp.autocast(enabled=True, configure=ipex.conf.AmpConf(torch.bfloat16)), torch.no_grad():
99+
traced_model = torch.jit.trace(model, x)
100+
ipex.core.enable_jit_opt()
101+
with ipex.amp.autocast(enabled=True, configure=ipex.conf.AmpConf(torch.bfloat16)), torch.no_grad():
102+
traced_model2 = torch.jit.trace(model, x.clone())
103+
for i in range(self.models.__len__()):
104+
test_generate_autocast_jit_trace_model(self.models[i], self.inputs[i])
105+
106+
def test_nchw_autocast_jit_trace_model(self):
107+
def test_nchw_autocast_jit_trace_model(model, x):
108+
model.eval()
109+
ipex.core.disable_jit_opt()
110+
with ipex.amp.autocast(enabled=True, configure=ipex.conf.AmpConf(torch.bfloat16)), torch.no_grad():
111+
traced_model = torch.jit.trace(model, x)
112+
with ipex.amp.autocast(enabled=True, configure=ipex.conf.AmpConf(torch.bfloat16)), torch.no_grad():
113+
y = traced_model(x.clone())
114+
y2 = model(x.clone())
115+
ipex.core.enable_jit_opt()
116+
torch.testing.assert_allclose(y.double(), y2.double(), rtol=1e-05, atol=_get_default_tolerance(y, y2)[1])
117+
for i in range(self.models.__len__()):
118+
test_nchw_autocast_jit_trace_model(self.models[i], self.inputs[i])
119+
120+
def test_nhwc_autocast_jit_trace_model(self):
121+
def test_nhwc_autocast_jit_trace_model(model, x):
122+
model.eval()
123+
ipex.core.disable_jit_opt()
124+
with ipex.amp.autocast(enabled=True, configure=ipex.conf.AmpConf(torch.bfloat16)), torch.no_grad():
125+
traced_model = torch.jit.trace(model, x.to(memory_format=torch.channels_last))
126+
with ipex.amp.autocast(enabled=True, configure=ipex.conf.AmpConf(torch.bfloat16)), torch.no_grad():
127+
y = traced_model(x.clone().to(memory_format=torch.channels_last))
128+
y2 = model(x.clone().to(memory_format=torch.channels_last))
129+
ipex.core.enable_jit_opt()
130+
torch.testing.assert_allclose(y.double(), y2.double(), rtol=1e-05, atol=_get_default_tolerance(y, y2)[1])
131+
for i in range(self.models.__len__()):
132+
if self.inputs[i].size().__len__() == 5:
133+
# NHWC 3D case not support yet
134+
continue
135+
test_nhwc_autocast_jit_trace_model(self.models[i], self.inputs[i])
105136

106137
class TestCustomerOps(TestCase):
107138
def test_interaction_op(self):

0 commit comments

Comments
 (0)