
Commit 804110a

committed
add unit test for linear fuse relu
1 parent c5e411c commit 804110a

File tree

1 file changed: +82 -0 lines changed


tests/cpu/test_linear_fuse_relu.py

Lines changed: 82 additions & 0 deletions
import torch
import time
import intel_pytorch_extension_py as ipex

K = 1   # 128
C = 16  # 64
MB = 28

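These constants set the shapes used throughout the test: the linear layer maps C input features to K output features, and each batch holds MB samples. A quick illustration of what that implies for the tensor shapes, using a plain nn.Linear with the same values (separate from the test file):

import torch

C, K, MB = 16, 1, 28           # mirrors the test's defaults
layer = torch.nn.Linear(C, K)  # weight shape (K, C), bias shape (K,)
x = torch.randn(MB, C)         # one batch of MB samples
print(layer(x).shape)          # torch.Size([28, 1])
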
class Cast(torch.nn.Module):
    __constants__ = ['to_dtype']

    def __init__(self, to_dtype):
        super(Cast, self).__init__()
        self.to_dtype = to_dtype

    def forward(self, input):
        # Cast the input tensor to the target dtype inside the module graph.
        return input.to(self.to_dtype)

    def extra_repr(self):
        return 'to(%s)' % self.to_dtype

def get_rand_seed():
    # Nanosecond-resolution wall-clock time as a fresh RNG seed.
    return int(time.time() * 1000000000)

def _ipxex_linear_relu(random_seed, data_type=torch.float32):
    # Fused Linear+ReLU module from the extension, placed on the 'dpcpp' device.
    torch.manual_seed(random_seed)
    fc = ipex.LinearFuseRelu(C, K).to(data_type).to('dpcpp')
    return fc

def _cpu_linear_relu(random_seed, data_type=torch.float32):
    # Unfused reference path: Linear followed by ReLU on the CPU. For bfloat16,
    # the activation is cast back to float32 before the ReLU.
    torch.manual_seed(random_seed)
    fc = torch.nn.ModuleList()
    fc.append(torch.nn.Linear(C, K).to(data_type))
    if data_type == torch.bfloat16:
        fc.append(Cast(torch.float32))
    fc.append(torch.nn.ReLU())
    return torch.nn.Sequential(*fc)

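The fused module under test is expected to be numerically equivalent to this unfused reference: a linear layer followed by ReLU, with an extra cast back to float32 in the bfloat16 case. A minimal sketch of that equivalence in plain PyTorch (illustrative names, not part of the file):

import torch

lin = torch.nn.Linear(16, 1)
x = torch.randn(28, 16)
# What a fused Linear+ReLU is expected to compute, written out by hand:
fused_reference = torch.relu(lin(x))
manual = torch.relu(x @ lin.weight.t() + lin.bias)
print(torch.allclose(fused_reference, manual))  # expected: True
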
def _run_mlp(random_seed, fc_module, data_type=torch.float32, device='cpu'):
    # Run one forward/backward pass and return the input, weight and bias gradients.
    torch.manual_seed(random_seed)
    x1 = torch.randn(MB, C).to(data_type).to(device).requires_grad_()
    y1 = fc_module(x1)
    z1 = y1.mean()
    z1.backward()
    if type(fc_module) == torch.nn.modules.container.Sequential:
        return x1.grad, fc_module[0].weight.grad, fc_module[0].bias.grad
    return x1.grad, fc_module.weight.grad, fc_module.bias.grad

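Because both runs below seed the generator with the same value immediately before torch.randn, the fused and reference modules see identical inputs (and the two constructors above are seeded the same way, so their initial parameters should match as well), which is what makes the gradients directly comparable. A small, standalone sketch of that same-seed reproducibility:

import torch

seed = 12345
torch.manual_seed(seed)
a = torch.randn(4, 3)
torch.manual_seed(seed)
b = torch.randn(4, 3)
print(torch.equal(a, b))  # True: same seed, same draw
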
for data_type in [torch.float32, torch.bfloat16]:
    seed = get_rand_seed()
    ipex_fc = _ipxex_linear_relu(seed, data_type)
    cpu_fc = _cpu_linear_relu(seed, data_type)

    # Looser tolerances for bfloat16, which has far fewer mantissa bits than float32.
    rtol = 1e-5
    atol = rtol
    if data_type == torch.bfloat16:
        rtol = 1e-2
        atol = rtol

    seed = get_rand_seed()
    input_grad_ipex, weight_grad_ipex, bias_grad_ipex = _run_mlp(seed, ipex_fc, data_type, device='dpcpp')
    input_grad_cpu, weight_grad_cpu, bias_grad_cpu = _run_mlp(seed, cpu_fc, data_type)

    if input_grad_ipex is None:
        if input_grad_cpu is not None:
            print("##################### {} linear fuse relu input grad FAIL".format(str(data_type)))
        else:
            print("##################### {} linear fuse relu input grad PASS".format(str(data_type)))
    else:
        if not input_grad_ipex.to(torch.float32).allclose(input_grad_cpu.to(torch.float32), rtol=rtol, atol=atol):
            print("##################### {} linear fuse relu input grad FAIL".format(str(data_type)))
        else:
            print("##################### {} linear fuse relu input grad PASS".format(str(data_type)))

    if not weight_grad_ipex.to(torch.float32).allclose(weight_grad_cpu.to(torch.float32), rtol=rtol, atol=atol):
        print("##################### {} linear fuse relu weight grad FAIL".format(str(data_type)))
    else:
        print("##################### {} linear fuse relu weight grad PASS".format(str(data_type)))

    if not bias_grad_ipex.to(torch.float32).allclose(bias_grad_cpu.to(torch.float32), rtol=rtol, atol=atol):
        print("##################### {} linear fuse relu bias grad FAIL".format(str(data_type)))
    else:
        print("##################### {} linear fuse relu bias grad PASS".format(str(data_type)))
