Commit 98a7d86

jit: enable conv_relu fusion (#15)
* jit: enable conv_relu fusion
* only jit fusion for extension path
* jit: enable conv_sum and conv_sum_relu fusion
* make the rewritten linear op traceable
* make the rewritten max_pool2d op traceable
* fix max_pool2d backward floating point exception issue
* make the rewritten AdaptiveAvgPool2d op traceable
* fix linear issue when bias is None
* fix max_pool2d issue with the stride=None case
* add prepack_weight API
1 parent fb9f850 · commit 98a7d86
21 files changed: +791 / -428 lines
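
A minimal sketch of how the new fusion path is exercised, distilled from the tests/cpu/test_jit.py added below; the model, shapes, and the 'dpcpp' device are illustrative, and a build of the extension is assumed:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    import intel_pytorch_extension
    from intel_pytorch_extension import core

    # the new test disables the profiling executor before scripting
    torch._C._jit_set_profiling_mode(False)
    torch._C._jit_set_profiling_executor(False)

    class ConvRelu(nn.Module):
        def __init__(self):
            super(ConvRelu, self).__init__()
            self.conv = nn.Conv2d(3, 32, kernel_size=3, stride=1, bias=False)

        def forward(self, x):
            return F.relu(self.conv(x), inplace=True)

    core.enable_auto_dnnl()
    core.enable_jit()                  # turn on the extension's JIT fusion passes

    model = ConvRelu().to('dpcpp').eval()
    x = torch.rand(32, 3, 224, 224).to('dpcpp')

    smodel = torch.jit.script(model)
    with torch.no_grad():
        print(smodel.graph_for(x))     # the optimized graph should contain the fused conv+relu op
        y = smodel(x)

The test below additionally folds batch norm into convolution via torch._C._jit_pass_fold_convbn before checking that fused and unfused outputs match.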

cmake/CPU.cmake

Lines changed: 3 additions & 1 deletion
@@ -136,9 +136,11 @@ include_directories(${DPCPP_THIRD_PARTY_ROOT}/xsmm/include)
 set(DPCPP_SRCS)
 set(DPCPP_COMMON_SRCS)
 set(DPCPP_CPU_SRCS)
+set(DPCPP_JIT_SRCS)

 add_subdirectory(${DPCPP_ROOT})
 add_subdirectory(${DPCPP_ROOT}/cpu)
+add_subdirectory(${DPCPP_ROOT}/jit)

 # libxsmm
 include(${CMAKE_ROOT}/Modules/ExternalProject.cmake)
@@ -153,7 +155,7 @@ ExternalProject_Add(xsmm
   INSTALL_COMMAND ""
 )
 # Compile code with pybind11
-set(DPCPP_SRCS ${DPCPP_ATEN_SRCS} ${DPCPP_COMMON_SRCS} ${DPCPP_CPU_SRCS})
+set(DPCPP_SRCS ${DPCPP_ATEN_SRCS} ${DPCPP_COMMON_SRCS} ${DPCPP_CPU_SRCS} ${DPCPP_JIT_SRCS})
 pybind11_add_module(${PLUGIN_NAME} SHARED ${DPCPP_SRCS})
 target_link_libraries(${PLUGIN_NAME} PRIVATE ${DPCPP_THIRD_PARTY_ROOT}/xsmm/lib/libxsmm.a)


intel_pytorch_extension_py/ops/linear.py

Lines changed: 5 additions & 24 deletions
@@ -2,30 +2,11 @@
 from torch.autograd import Function
 import torch.nn.functional as F
 import _torch_ipex as core
+from typing import Optional

-F_linear = F.linear
-
-class LinearFunction(Function):
-    @staticmethod
-    def forward(ctx, input, weight, bias):
-        output = core.linear(input, weight, bias)
-        ctx.save_for_backward(input, weight, bias)
-        return output
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        input, weight, bias = ctx.saved_tensors
-        grad_output = grad_output.contiguous()
-        if bias == None:
-            output_mask = (input.requires_grad, weight.requires_grad, 0)
-        else:
-            output_mask = (input.requires_grad, weight.requires_grad, bias.requires_grad)
-        grad_input, grad_weight, grad_bias = core.linear_backward(input, grad_output, weight, output_mask)
-        return (grad_input, grad_weight, grad_bias)
-
-def linear(input, weight, bias=None):
-    if input.device.type == 'dpcpp' and core.get_auto_dnnl():
-        return LinearFunction.apply(input, weight, bias)
-    return F_linear(input, weight, bias)
+def linear(input, weight, bias: Optional[torch.Tensor] = None):
+    if bias is None:
+        bias = torch.zeros(weight.size(0))
+    return torch.ops.torch_ipex.linear(input, weight, bias)

 F.linear = linear
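
The rewritten linear is now a plain function with TorchScript-friendly annotations that forwards to the registered torch_ipex op, substituting a zero bias when none is given so the op always receives three tensors; per the commit message, this is what makes the rewritten op traceable. A small usage sketch (hedged: shapes and the 'dpcpp' device are illustrative, and importing intel_pytorch_extension is assumed to install the F.linear override, as the new test does):

    import torch
    import torch.nn as nn

    import intel_pytorch_extension  # assumed to apply the F.linear patch on import

    layer = nn.Linear(128, 64, bias=False).to('dpcpp')
    x = torch.rand(4, 128).to('dpcpp')

    # nn.Linear calls F.linear, which now routes to torch.ops.torch_ipex.linear;
    # the missing bias is replaced by torch.zeros(weight.size(0)) inside the override.
    y = layer(x)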

intel_pytorch_extension_py/ops/pooling.py

Lines changed: 12 additions & 33 deletions
@@ -2,25 +2,12 @@
 from torch.autograd import Function
 import torch.nn.functional as F
 import _torch_ipex as core
-from torch.nn.modules.utils import _single
+from torch.nn.modules.utils import _single, _pair
+from typing import List

-torch_adaptive_avg_pool2d = torch._C._nn.adaptive_avg_pool2d
-torch_max_pool2d = torch.max_pool2d
-torch_max_pool3d = torch.max_pool3d
+Vector = List[int]

-class AdaptiveAvgPool2dFunction(Function):
-    @staticmethod
-    def forward(ctx, input, output_size):
-        output = core.adaptive_avg_pool2d(input, _single(output_size))
-        ctx.save_for_backward(input)
-        return output
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        (input,) = ctx.saved_tensors
-        grad_output = grad_output.contiguous()
-        grad_input = core.adaptive_avg_pool2d_backward(grad_output, input)
-        return (grad_input, None)
+torch_max_pool3d = torch.max_pool3d

 class MaxPoolingFunction(Function):
     @staticmethod
@@ -41,21 +28,8 @@ def backward(ctx, grad_output):
         grad_input = core.max_pooling_backward(grad_output, output, input, ctx.kernel_size, ctx.stride, ctx.padding, ctx.dilation, ctx.ceil_mode)
         return (grad_input, None, None, None, None, None)

-def adaptive_avg_pool2d(input, output_size):
-    try:
-        if input.device.type == 'dpcpp' and core.get_auto_dnnl():
-            return AdaptiveAvgPool2dFunction.apply(input, output_size)
-    except RuntimeError:
-        pass
-    return torch_adaptive_avg_pool2d(input, output_size)
-
-def max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode):
-    try:
-        if input.device.type == 'dpcpp' and core.get_auto_dnnl():
-            return MaxPoolingFunction.apply(input, kernel_size, stride, padding, dilation, ceil_mode)
-    except RuntimeError:
-        pass
-    return torch_max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)
+def adaptive_avg_pool2d(input, output_size: Vector):
+    return torch.ops.torch_ipex.adaptive_avg_pool2d(input, _pair(output_size))

 def max_pool3d(input, kernel_size, stride, padding, dilation, ceil_mode):
     try:
@@ -65,6 +39,11 @@ def max_pool3d(input, kernel_size, stride, padding, dilation, ceil_mode):
         pass
     return torch_max_pool3d(input, kernel_size, stride, padding, dilation, ceil_mode)

+def max_pool2d(input, kernel_size: Vector, stride: Vector, padding: Vector, dilation: Vector, ceil_mode: bool):
+    if not stride:
+        stride = kernel_size
+    return torch.ops.torch_ipex.max_pool2d(input, _pair(kernel_size), _pair(stride), _pair(padding), _pair(dilation), ceil_mode)
+
 torch._C._nn.adaptive_avg_pool2d = adaptive_avg_pool2d
 torch.max_pool2d = max_pool2d
-torch.max_pool3d = max_pool3d
+torch.max_pool3d = max_pool3d
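
The pooling overrides follow the same pattern: typed free functions that call the registered torch_ipex ops directly, with _pair() normalizing the size arguments and an unset stride falling back to kernel_size (the stride=None fix from the commit message). An illustrative sketch (the 'dpcpp' device and shapes are assumptions):

    import torch

    import intel_pytorch_extension  # assumed to patch torch.max_pool2d and adaptive_avg_pool2d on import

    x = torch.rand(1, 3, 56, 56).to('dpcpp')

    # An unset stride now defaults to kernel_size inside the override instead of failing.
    y = torch.max_pool2d(x, kernel_size=[3, 3], stride=None, padding=[1, 1],
                         dilation=[1, 1], ceil_mode=False)

    # output_size may be an int or a pair; _pair() normalizes it before calling
    # torch.ops.torch_ipex.adaptive_avg_pool2d.
    z = torch._C._nn.adaptive_avg_pool2d(x, 1)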

tests/cpu/test_jit.py

Lines changed: 172 additions & 0 deletions
@@ -0,0 +1,172 @@
+from __future__ import division
+from __future__ import print_function
+
+'''
+From PyTorch:
+
+Copyright (c) 2016- Facebook, Inc (Adam Paszke)
+Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
+Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
+Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)
+Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
+Copyright (c) 2011-2013 NYU (Clement Farabet)
+Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
+Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
+Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
+
+From Caffe2:
+
+Copyright (c) 2016-present, Facebook Inc. All rights reserved.
+
+All contributions by Facebook:
+Copyright (c) 2016 Facebook Inc.
+
+All contributions by Google:
+Copyright (c) 2015 Google Inc.
+All rights reserved.
+
+All contributions by Yangqing Jia:
+Copyright (c) 2015 Yangqing Jia
+All rights reserved.
+
+All contributions from Caffe:
+Copyright(c) 2013, 2014, 2015, the respective contributors
+All rights reserved.
+
+All other contributions:
+Copyright(c) 2015, 2016 the respective contributors
+All rights reserved.
+
+Caffe2 uses a copyright model similar to Caffe: each contributor holds
+copyright over their contributions to Caffe2. The project versioning records
+all such contribution and copyright details. If a contributor wants to further
+mark their specific copyright on a particular contribution, they should
+indicate their copyright solely in the commit message of the change when it is
+committed.
+
+All rights reserved.
+'''
+
+"""Tests for rn50."""
+
+import math
+import random
+import unittest
+from functools import reduce
+
+import torch
+import torch.nn as nn
+from torch.jit._recursive import wrap_cpp_module
+import copy
+
+import intel_pytorch_extension
+from intel_pytorch_extension import core
+
+import torch.nn as nn
+import torch.backends.cudnn as cudnn
+from torch.nn import Parameter
+import torch.nn.functional as F
+from torch.autograd import gradcheck
+from torch.autograd.gradcheck import gradgradcheck
+from torch._six import inf, nan
+
+from common_utils import TestCase, iter_indices, TEST_NUMPY, TEST_SCIPY, TEST_MKL, \
+    TEST_LIBROSA, run_tests, download_file, skipIfNoLapack, suppress_warnings, \
+    IS_WINDOWS, PY3, NO_MULTIPROCESSING_SPAWN, do_test_dtypes, do_test_empty_full, \
+    IS_SANDCASTLE, load_tests, brute_pdist, brute_cdist, slowTest, \
+    skipCUDANonDefaultStreamIf, skipCUDAMemoryLeakCheckIf
+
+device = 'dpcpp:0'
+#device = 'cpu:0'
+SIZE = 100
+
+torch._C._jit_set_profiling_mode(False)
+torch._C._jit_set_profiling_executor(False)
+
+def test_output(model, x):
+    modelName = model.__class__.__name__
+    core.disable_jit()
+
+    model = model.to('dpcpp').eval()
+    x = x.to('dpcpp')
+    with torch.no_grad():
+        result = model(x)
+
+    smodel = torch.jit.script(model)
+    smodel.eval()
+    with torch.no_grad():
+        sresult = smodel(x)
+
+    print(f'\nAre {modelName} and Scripted{modelName} outputs the same: ',
+          torch.allclose(
+              sresult, result, rtol=1e-05, atol=1e-06, equal_nan=False))
+
+    core.enable_jit()
+    pmodel = torch.jit.script(model)
+    # bn folding
+    pmodel = wrap_cpp_module(torch._C._jit_pass_fold_convbn(pmodel._c))
+    with torch.no_grad():
+        # conv relu fusion, conv sum fusion or conv sum relu fusion
+        print(pmodel.graph_for(x))
+        presult = pmodel(x)
+
+    # print(result)
+    # print(sresult)
+    # print(presult)
+
+    print(f'\nWith or without fusion passes, are Scripted{modelName} outputs the same: ',
+          torch.allclose(
+              sresult, presult, rtol=1e-05, atol=1e-06, equal_nan=False))
+
+class Conv2dRelu_Fixed(nn.Module):
+    def __init__(self, in_channels, out_channels, **kwargs):
+        super(Conv2dRelu_Fixed, self).__init__()
+        seed = 2018
+        torch.manual_seed(seed)
+        self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
+
+    def forward(self, x):
+        return F.relu(self.conv(x), inplace=True)
+
+class CascadedConv2dBnSumRelu(nn.Module):
+    def __init__(self, in_channels, mid_channels, out_channels, **kwargs):
+        super(CascadedConv2dBnSumRelu, self).__init__()
+        torch.manual_seed(2018)
+        self.conv = nn.Conv2d(in_channels, mid_channels, bias=False, **kwargs)
+        self.conv1 = nn.Conv2d(
+            mid_channels, out_channels, bias=False, padding=1, **kwargs)
+        self.conv2 = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
+        self.bn = nn.BatchNorm2d(mid_channels, eps=0.001)
+        self.bn1 = nn.BatchNorm2d(out_channels, eps=0.001)
+        self.bn2 = nn.BatchNorm2d(out_channels, eps=0.001)
+
+    def forward(self, x):
+        a = self.conv(x)
+        a = self.bn(a)
+        a = F.relu(a, inplace=True)
+        a = self.conv1(a)
+        a = self.bn1(a)
+        b = self.conv2(x)
+        b = self.bn2(b)
+        return F.relu(a.add_(b), inplace=True)
+
+class Tester(TestCase):
+    n = 32
+    c = 3
+    h = 224
+    w = 224
+    print('input size: (%d, %d, %d, %d)' % (n, c, h, w))
+
+    def test_output_conv_relu(self):
+        test_output(
+            Conv2dRelu_Fixed(self.c, 32, kernel_size=3, stride=1),
+            torch.rand(self.n, self.c, self.h, self.w))
+
+    def test_output_cascaded_conv2d_bn_sum_relu(self):
+        test_output(
+            CascadedConv2dBnSumRelu(self.c, 64, 32, kernel_size=3, stride=1),
+            torch.rand(self.n, self.c, self.h, self.w))
+
+if __name__ == '__main__':
+    core.enable_auto_dnnl()
+    test = unittest.main()
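
Since the new test enables auto-DNNL and calls unittest.main() in its __main__ block, it can presumably be run directly (e.g. python tests/cpu/test_jit.py, with common_utils available alongside it); it compares eager, scripted, and fusion-enabled scripted outputs with torch.allclose and prints the optimized graph so the fused conv_relu / conv_sum / conv_sum_relu nodes can be inspected.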

torch_ipex/csrc/auto_opt_config.h

Lines changed: 9 additions & 0 deletions
@@ -17,6 +17,14 @@ class AutoOptConfig {
     return auto_dnnl_;
   }

+  inline void set_jit_fuse(bool jit_fuse) {
+    jit_fuse_ = jit_fuse;
+  }
+
+  inline bool get_jit_fuse() {
+    return jit_fuse_;
+  }
+
   inline void set_mix_bf16_fp32(bool value) {
     mix_bf16_fp32_ = value;
   }
@@ -39,6 +47,7 @@ class AutoOptConfig {

 private:
   bool auto_dnnl_;
+  bool jit_fuse_;
   bool mix_bf16_fp32_;
   bool pure_bf16_;
 };
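
The new jit_fuse_ flag is the C++ side of the switch the test toggles from Python; presumably core.enable_jit() and core.disable_jit() are bound to set_jit_fuse(true/false), and the fusion passes consult get_jit_fuse(). A hedged sketch of the Python-side toggle (function names taken from the new test):

    from intel_pytorch_extension import core

    core.disable_jit()   # script and run the model without the extension's fusion passes
    # ... baseline scripted run ...

    core.enable_jit()    # re-enable fusion; models scripted afterwards pick up the fused ops
    # ... fused scripted run ...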
