|
4 | 4 | import intel_pytorch_extension as ipex
|
5 | 5 | from common_utils import TestCase
|
6 | 6 | import time, sys
|
| 7 | +from torch.testing._core import _get_default_tolerance |
7 | 8 |
|
8 | 9 | def get_rand_seed():
|
9 | 10 | return int(time.time() * 1000000000)
|
@@ -72,36 +73,66 @@ def test_conv2d_backward(self):
|
72 | 73 | self.assertEqual(in_autocast.grad.dtype, torch.float)
|
73 | 74 | self.assertEqual(_in_cpu.grad, in_autocast.grad, 1e-2)
|
74 | 75 |
|
75 |
| -class SimpleNet(torch.nn.Module): |
76 |
| - def __init__(self): |
77 |
| - super(SimpleNet, self).__init__() |
78 |
| - self.conv = torch.nn.Conv2d(3, 16, (3, 3), stride=(2, 2), padding=(3, 3), bias=False) |
79 |
| - self.bn = torch.nn.BatchNorm2d(16) |
80 |
| - self.relu = torch.nn.ReLU(inplace=True) |
81 |
| - |
82 |
| - def forward(self, x): |
83 |
| - x = self.conv(x) |
84 |
| - x = self.bn(x) |
85 |
| - x = self.relu(x) |
86 |
| - return x |
87 |
| - |
88 |
| -class TestSimpleNet(TestCase): |
89 |
| - def test_generate_jit_trace_model(self): |
90 |
| - rand_seed = int(get_rand_seed()) |
91 |
| - print("{} rand sed: {}".format(sys._getframe().f_code.co_name, rand_seed)) |
92 |
| - torch.manual_seed(rand_seed) |
93 |
| - |
94 |
| - model = SimpleNet() |
95 |
| - model.eval() |
96 |
| - #ipex.core.disable_jit_opt() |
97 |
| - x = torch.rand((1, 3, 224, 224)) |
98 |
| - with ipex.amp.autocast(enabled=True, configure=ipex.conf.AmpConf(torch.bfloat16)), torch.no_grad(): |
99 |
| - traced_model = torch.jit.trace(model, x) |
100 |
| - with torch.no_grad(): |
101 |
| - y = traced_model(x) |
102 |
| - #print(traced_model.graph_for(x)) |
103 |
| - #ipex.core.enable_jit_opt() |
104 |
| - self.assertEqual(y.dtype, torch.float) #conv whitelist, bn blacklist, relu fallthrough |
| 76 | +class TestAutocastWithJit(TestCase): |
| 77 | + def setUp(self): |
| 78 | + super(TestAutocastWithJit, self).setUp() |
| 79 | + from test_jit import Conv_Bn_Relu, BatchNorm_Conv_BatchNorm, ConvBatchNorm_Fixed, ConvReshapeBatchNorm,\ |
| 80 | + CascadedConvBnSumRelu, LinearBn, Linear_Reshape_Bn |
| 81 | + self.models = [Conv_Bn_Relu(2, 3, 32, kernel_size=3, stride=1), BatchNorm_Conv_BatchNorm(2, 3, 32, kernel_size=3, stride=1),\ |
| 82 | + ConvBatchNorm_Fixed(2, 3, 32, kernel_size=3, stride=1), ConvBatchNorm_Fixed(3, 3, 32, kernel_size=3, stride=1),\ |
| 83 | + ConvReshapeBatchNorm(2, 3, 32, (64, 16, 62, 62), kernel_size=3, stride=1),\ |
| 84 | + CascadedConvBnSumRelu(2, 3, 64, 32, kernel_size=3, stride=1),\ |
| 85 | + LinearBn(2 ,32, 32, bias=True),\ |
| 86 | + Linear_Reshape_Bn(2 ,32, 32,(1,1,64,16),bias=True)] |
| 87 | + self.inputs = [torch.randn(32, 3, 64, 64), torch.randn(32, 3, 64, 64),\ |
| 88 | + torch.randn(32, 3, 64, 64), torch.randn(32, 3, 32, 32, 32),\ |
| 89 | + torch.randn(32, 3, 64, 64),\ |
| 90 | + torch.rand(32, 3, 64, 64),\ |
| 91 | + torch.rand(1, 1, 32, 32),\ |
| 92 | + torch.rand(1, 1, 32, 32)] |
| 93 | + |
| 94 | + def test_generate_autocast_jit_trace_model(self): |
| 95 | + def test_generate_autocast_jit_trace_model(model, x): |
| 96 | + model.eval() |
| 97 | + ipex.core.disable_jit_opt() |
| 98 | + with ipex.amp.autocast(enabled=True, configure=ipex.conf.AmpConf(torch.bfloat16)), torch.no_grad(): |
| 99 | + traced_model = torch.jit.trace(model, x) |
| 100 | + ipex.core.enable_jit_opt() |
| 101 | + with ipex.amp.autocast(enabled=True, configure=ipex.conf.AmpConf(torch.bfloat16)), torch.no_grad(): |
| 102 | + traced_model2 = torch.jit.trace(model, x.clone()) |
| 103 | + for i in range(self.models.__len__()): |
| 104 | + test_generate_autocast_jit_trace_model(self.models[i], self.inputs[i]) |
| 105 | + |
| 106 | + def test_nchw_autocast_jit_trace_model(self): |
| 107 | + def test_nchw_autocast_jit_trace_model(model, x): |
| 108 | + model.eval() |
| 109 | + ipex.core.disable_jit_opt() |
| 110 | + with ipex.amp.autocast(enabled=True, configure=ipex.conf.AmpConf(torch.bfloat16)), torch.no_grad(): |
| 111 | + traced_model = torch.jit.trace(model, x) |
| 112 | + with ipex.amp.autocast(enabled=True, configure=ipex.conf.AmpConf(torch.bfloat16)), torch.no_grad(): |
| 113 | + y = traced_model(x.clone()) |
| 114 | + y2 = model(x.clone()) |
| 115 | + ipex.core.enable_jit_opt() |
| 116 | + torch.testing.assert_allclose(y.double(), y2.double(), rtol=1e-05, atol=_get_default_tolerance(y, y2)[1]) |
| 117 | + for i in range(self.models.__len__()): |
| 118 | + test_nchw_autocast_jit_trace_model(self.models[i], self.inputs[i]) |
| 119 | + |
| 120 | + def test_nhwc_autocast_jit_trace_model(self): |
| 121 | + def test_nhwc_autocast_jit_trace_model(model, x): |
| 122 | + model.eval() |
| 123 | + ipex.core.disable_jit_opt() |
| 124 | + with ipex.amp.autocast(enabled=True, configure=ipex.conf.AmpConf(torch.bfloat16)), torch.no_grad(): |
| 125 | + traced_model = torch.jit.trace(model, x.to(memory_format=torch.channels_last)) |
| 126 | + with ipex.amp.autocast(enabled=True, configure=ipex.conf.AmpConf(torch.bfloat16)), torch.no_grad(): |
| 127 | + y = traced_model(x.clone().to(memory_format=torch.channels_last)) |
| 128 | + y2 = model(x.clone().to(memory_format=torch.channels_last)) |
| 129 | + ipex.core.enable_jit_opt() |
| 130 | + torch.testing.assert_allclose(y.double(), y2.double(), rtol=1e-05, atol=_get_default_tolerance(y, y2)[1]) |
| 131 | + for i in range(self.models.__len__()): |
| 132 | + if self.inputs[i].size().__len__() == 5: |
| 133 | + # NHWC 3D case not support yet |
| 134 | + continue |
| 135 | + test_nhwc_autocast_jit_trace_model(self.models[i], self.inputs[i]) |
105 | 136 |
|
106 | 137 | class TestCustomerOps(TestCase):
|
107 | 138 | def test_interaction_op(self):
|
|
0 commit comments