
Commit a262ee1

jit: enable conv_sum and conv_sum_relu fusion
1 parent 8c9fb3d commit a262ee1

8 files changed: +290 additions, −81 deletions


tests/cpu/test_jit.py

Lines changed: 81 additions & 20 deletions
@@ -56,6 +56,7 @@
 
 import torch
 import torch.nn as nn
+from torch.jit._recursive import wrap_cpp_module
 import copy
 
 import intel_pytorch_extension
@@ -82,29 +83,89 @@
 torch._C._jit_set_profiling_mode(False)
 torch._C._jit_set_profiling_executor(False)
 
-class Conv_relu(nn.Module):
-    def __init__(self):
-        super(Conv_relu, self).__init__()
+def test_output(model, x):
+    modelName = model.__class__.__name__
+    core.disable_jit()
+
+    model = model.to('dpcpp').eval()
+    x = x.to('dpcpp')
+    with torch.no_grad():
+        result = model(x)
+
+    smodel = torch.jit.script(model)
+    smodel.eval()
+    with torch.no_grad():
+        sresult = smodel(x)
+
+    print(f'\nAre {modelName} and Scripted{modelName} outputs the same: ',
+          torch.allclose(
+              sresult, result, rtol=1e-05, atol=1e-06, equal_nan=False))
+
+    core.enable_jit()
+    pmodel = torch.jit.script(model)
+    # bn folding
+    pmodel = wrap_cpp_module(torch._C._jit_pass_fold_convbn(pmodel._c))
+    with torch.no_grad():
+        # conv relu fusion, conv sum fusion or conv sum relu fusion
+        print(pmodel.graph_for(x))
+        presult = pmodel(x)
+
+    # print(result)
+    # print(sresult)
+    # print(presult)
+
+    print(f'\nWith or without pyrys, are Scripted{modelName} outputs the same: ',
+          torch.allclose(
+              sresult, presult, rtol=1e-05, atol=1e-06, equal_nan=False))
+
+class Conv2dRelu_Fixed(nn.Module):
+    def __init__(self, in_channels, out_channels, **kwargs):
+        super(Conv2dRelu_Fixed, self).__init__()
+        seed = 2018
+        torch.manual_seed(seed)
+        self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
+
+    def forward(self, x):
+        return F.relu(self.conv(x), inplace=True)
+
+class CascadedConv2dBnSumRelu(nn.Module):
+    def __init__(self, in_channels, mid_channels, out_channels, **kwargs):
+        super(CascadedConv2dBnSumRelu, self).__init__()
         torch.manual_seed(2018)
-        self.conv = torch.nn.Conv2d(20, 20, 5)
+        self.conv = nn.Conv2d(in_channels, mid_channels, bias=False, **kwargs)
+        self.conv1 = nn.Conv2d(
+            mid_channels, out_channels, bias=False, padding=1, **kwargs)
+        self.conv2 = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
+        self.bn = nn.BatchNorm2d(mid_channels, eps=0.001)
+        self.bn1 = nn.BatchNorm2d(out_channels, eps=0.001)
+        self.bn2 = nn.BatchNorm2d(out_channels, eps=0.001)
 
     def forward(self, x):
-        x = self.conv(x)
-        return x.relu()
-
-class TestJITOP(TestCase):
-    def test_conv_relu_fusion(self):
-        x = torch.randn(1, 20, 20, 20).to('dpcpp')
-
-        model = Conv_relu().to('dpcpp').eval()
-
-        with torch.no_grad():
-            core.disable_jit()
-            y1 = model(x)
-            core.enable_jit()
-            script_model = torch.jit.script(model)
-            y2 = script_model(x)
-            self.assertEqual(y1, y2)
+        a = self.conv(x)
+        a = self.bn(a)
+        a = F.relu(a, inplace=True)
+        a = self.conv1(a)
+        a = self.bn1(a)
+        b = self.conv2(x)
+        b = self.bn2(b)
+        return F.relu(a.add_(b), inplace=True)
+
+class Tester(TestCase):
+    n = 32
+    c = 3
+    h = 224
+    w = 224
+    print('input size: (%d, %d, %d, %d)' % (n, c, h, w))
+
+    def test_output_conv_relu(self):
+        test_output(
+            Conv2dRelu_Fixed(self.c, 32, kernel_size=3, stride=1),
+            torch.rand(self.n, self.c, self.h, self.w))
+
+    def test_output_cascaded_conv2d_bn_sum_relu(self):
+        test_output(
+            CascadedConv2dBnSumRelu(self.c, 64, 32, kernel_size=3, stride=1),
+            torch.rand(self.n, self.c, self.h, self.w))
 
 if __name__ == '__main__':
     core.enable_auto_dnnl()
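
For reference, the flow driven by the new test_output helper reduces to the standalone sketch below (assumptions: the extension's core module is importable as shown, exposes the enable_jit()/disable_jit() switches, and the 'dpcpp' device used in this test file is available):

    import torch
    import intel_pytorch_extension
    from intel_pytorch_extension import core  # assumption: 'core' is exposed this way
    from torch.jit._recursive import wrap_cpp_module

    model = CascadedConv2dBnSumRelu(3, 64, 32, kernel_size=3, stride=1).to('dpcpp').eval()
    x = torch.rand(32, 3, 224, 224).to('dpcpp')

    core.enable_jit()
    fused = torch.jit.script(model)
    # fold conv + batchnorm before fusion, as test_output does
    fused = wrap_cpp_module(torch._C._jit_pass_fold_convbn(fused._c))
    with torch.no_grad():
        print(fused.graph_for(x))  # expected to show the fused ipex::conv2d_sum_relu node
        out = fused(x)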

torch_ipex/csrc/cpu/FusionOPs.cpp

Lines changed: 72 additions & 0 deletions
@@ -55,5 +55,77 @@ at::Tensor AtenIpexJITDev::dil_convolution_relu(
   return dbl::comm::gen_aten_tensor_by(dil_output);
 }
 
+static at::Tensor& dil_convolution_inplace_fusion(
+    const at::Tensor& input,
+    const at::Tensor& weight,
+    const at::Tensor& bias,
+    at::Tensor& accumu,
+    at::IntArrayRef stride,
+    at::IntArrayRef padding,
+    at::IntArrayRef dilation,
+    int64_t groups,
+    const dil::attr_t& attr) {
+  dil::tensor dil_input;
+  dil::tensor dil_weight;
+  dil::tensor dil_output;
+  c10::optional<dil::tensor> dil_bias{c10::nullopt};
+
+  auto input_contiguous = input.contiguous();
+  auto weight_contiguous = weight.contiguous();
+  auto output_contiguous = accumu.contiguous();
+
+  dil_input = dbl::comm::try_gen_dil_tensor(input_contiguous);
+  dil_weight = dbl::comm::try_gen_dil_tensor(weight_contiguous);
+  dil_output = dbl::comm::try_gen_dil_tensor(output_contiguous);
+  if (bias.defined()) {
+    auto bias_contiguous = bias.contiguous();
+    dil_bias = dbl::comm::try_gen_dil_tensor(bias_contiguous);
+  }
+
+  dbl::conv::conv2d_inplace_impl(
+    dil_input,
+    dil_weight,
+    dil_bias,
+    dil_output,
+    padding,
+    stride,
+    dilation,
+    groups,
+    attr);
+
+  dbl::comm::sync_shape_from_dil_to_aten(accumu, dil_output);
+  return accumu;
+}
+
+at::Tensor& AtenIpexJITDev::dil_convolution_sum(
+    const at::Tensor & input,
+    const at::Tensor & weight,
+    const at::Tensor & bias,
+    at::IntArrayRef stride,
+    at::IntArrayRef padding,
+    at::IntArrayRef dilation,
+    int64_t groups,
+    at::Tensor& accumu,
+    at::Scalar alpha) {
+  auto scale = alpha.to<float>();
+  return dil_convolution_inplace_fusion(input, weight, bias, accumu, stride, padding,
+      dilation, groups, dil::attr_t::fuse_sum(scale));
+}
+
+at::Tensor& AtenIpexJITDev::dil_convolution_sum_relu(
+    const at::Tensor & input,
+    const at::Tensor & weight,
+    const at::Tensor & bias,
+    at::IntArrayRef stride,
+    at::IntArrayRef padding,
+    at::IntArrayRef dilation,
+    int64_t groups,
+    at::Tensor& accumu,
+    at::Scalar alpha) {
+  auto scale = alpha.to<float>();
+  return dil_convolution_inplace_fusion(input, weight, bias, accumu, stride, padding,
+      dilation, groups, dil::attr_t::residual(scale));
+}
+
 } // namespace cpu
 } // namespace torch_ipex
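
The two new entry points differ only in the oneDNN post-op attribute they pass to the shared helper: fuse_sum(scale) for conv2d_sum and residual(scale) (sum followed by ReLU) for conv2d_sum_relu. As a rough, non-fused reference for the intended semantics (an assumption: the sum post-op accumulates the convolution result into the alpha-scaled accumu buffer), in plain PyTorch:

    import torch
    import torch.nn.functional as F

    def conv_sum_reference(input, weight, bias, stride, padding, dilation, groups,
                           accumu, alpha, relu=False):
        # Unfused reference: conv output accumulated into accumu (scaled by alpha),
        # optionally followed by an in-place ReLU.
        y = F.conv2d(input, weight, bias, stride, padding, dilation, groups)
        accumu.mul_(float(alpha)).add_(y)
        return accumu.relu_() if relu else accumu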

torch_ipex/csrc/cpu/FusionOPs.h

Lines changed: 9 additions & 6 deletions
@@ -11,12 +11,11 @@ namespace torch { namespace jit {
 // XXX: PyTorch does not support nesting namespace
 // And the alias analysis is not working for namespace other than aten ...
 // So we fake some op namespaces to workaround that.
-namespace dnnl {
-  static auto conv2d_relu = Symbol::fromQualString("dnnl::conv2d_relu");
-  static auto conv2d_sum = Symbol::fromQualString("dnnl::conv2d_sum");
-  static auto conv2d_relu_sum = Symbol::fromQualString("dnnl::conv2d_relu_sum");
-  static auto conv2d_sum_relu = Symbol::fromQualString("dnnl::conv2d_sum_relu");
-
+namespace ipex {
+  static auto conv2d_relu = Symbol::fromQualString("ipex::conv2d_relu");
+  static auto conv2d_sum = Symbol::fromQualString("ipex::conv2d_sum");
+  static auto conv2d_relu_sum = Symbol::fromQualString("ipex::conv2d_relu_sum");
+  static auto conv2d_sum_relu = Symbol::fromQualString("ipex::conv2d_sum_relu");
 }
 
 }} // namespace torch::jit
@@ -29,6 +28,10 @@ class AtenIpexJITDev {
   // for JIT ops
   static at::Tensor dil_convolution_relu(const at::Tensor & input, const at::Tensor & weight, const at::Tensor & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, int64_t groups);
 
+  static at::Tensor& dil_convolution_sum(const at::Tensor& input, const at::Tensor& weight, const at::Tensor& bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, int64_t groups, at::Tensor& accumu, at::Scalar alpha);
+
+  static at::Tensor& dil_convolution_sum_relu(const at::Tensor& input, const at::Tensor& weight, const at::Tensor& bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, int64_t groups, at::Tensor& accumu, at::Scalar alpha);
+
 };
 
 } // namespace cpu

torch_ipex/csrc/cpu/dbl/Common.cpp

Lines changed: 0 additions & 2 deletions
@@ -91,15 +91,13 @@ void sync_shape_from_dil_to_aten(const at::Tensor& ipex_tensor, const dil::tenso
   dil::dims sizes = dil_tensor.get_dims();
   if (dil_tensor.is_public_format()) {
     dil::dims strides = dil_tensor.get_strides();
-    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(ipex_tensor.device().type() == at::DeviceType::DPCPP);
     auto* _tensor_impl = (IPEXTensorImpl *)ipex_tensor.unsafeGetTensorImpl();
     _tensor_impl->force_set_strided(sizes, strides);
   } else {
     // Blockformat does not inlcude stride information
     TORCH_INTERNAL_ASSERT_DEBUG_ONLY(sizes.size() != 1 || sizes[0] != 0);
     ipex_tensor.unsafeGetTensorImpl()->set_sizes_contiguous(sizes);
   }
-
 }
 
 } // namespace comm

torch_ipex/csrc/cpu/dbl/Conv.cpp

Lines changed: 61 additions & 0 deletions
@@ -86,6 +86,67 @@ dil::tensor conv2d_impl(
   return y;
 }
 
+void conv2d_inplace_impl(
+    const dil::tensor& x,
+    const dil::tensor& w,
+    const c10::optional<dil::tensor>& b,
+    dil::tensor& y,
+    at::IntArrayRef padding,
+    at::IntArrayRef stride,
+    at::IntArrayRef dilation,
+    int64_t groups,
+    const dil::attr_t& attr) {
+  std::vector<int64_t> kernel_size(x.ndims());
+  // mkldnn conv2d weights could have been re-ordered to 5d by
+  // mkldnn_reorder_conv2d_weight
+  if (w.ndims() == x.ndims() + 1) {
+    AT_ASSERTM(
+        groups > 1,
+        "Only group _mkldnn_conv2d weights could have been reordered to 5d");
+    kernel_size[0] = w.get_dim(0) * w.get_dim(1);
+    std::copy_n(w.get_dims().cbegin() + 2, x.ndims() - 1, kernel_size.begin() + 1);
+  } else {
+    std::copy_n(w.get_dims().cbegin(), x.ndims(), kernel_size.begin());
+  }
+
+  const dil::dims x_dims = x.get_dims();
+  std::vector<int64_t> input_size{x_dims.cbegin(), x_dims.cend()};
+  std::vector<int64_t> output_sizes = calc_conv_output_size(input_size, kernel_size, padding, stride, dilation);
+
+  if (b.has_value()) {
+    dil::convolution_forward::compute(
+      x,
+      w,
+      b.value(),
+      {output_sizes.cbegin(), output_sizes.cend()},
+      y,
+      {stride.begin(), stride.end()},
+      {dilation.begin(), dilation.end()},
+      {padding.begin(), padding.end()},
+      {padding.begin(), padding.end()},
+      groups,
+      dil::scale_t(),
+      dil::scale_t(),
+      dil::scale_t(),
+      attr);
+  } else {
+    dil::convolution_forward::compute(
+      x,
+      w,
+      {output_sizes.cbegin(), output_sizes.cend()},
+      y,
+      {stride.begin(), stride.end()},
+      {dilation.begin(), dilation.end()},
+      {padding.begin(), padding.end()},
+      {padding.begin(), padding.end()},
+      groups,
+      dil::scale_t(),
+      dil::scale_t(),
+      dil::scale_t(),
+      attr);
+  }
+}
+
 } // namespace conv
 } // namespace dbl
 } // namespace cpu

torch_ipex/csrc/cpu/dbl/Conv.h

Lines changed: 11 additions & 0 deletions
@@ -28,6 +28,17 @@ dil::tensor conv2d_impl(
     int64_t groups,
     const dil::attr_t& attr = dil::attr_t());
 
+void conv2d_inplace_impl(
+    const dil::tensor& x,
+    const dil::tensor& w,
+    const c10::optional<dil::tensor>& b,
+    dil::tensor& y,
+    at::IntArrayRef padding,
+    at::IntArrayRef stride,
+    at::IntArrayRef dilation,
+    int64_t groups,
+    const dil::attr_t& attr = dil::attr_t());
+
 } // namespace conv
 } // namespace dbl
 } // namespace cpu

torch_ipex/csrc/jit/fusion_pass.cpp

Lines changed: 8 additions & 10 deletions
@@ -275,16 +275,14 @@ class OpFuser {
 
 // TODO: These rules should be more scalable
 OpFuser::RuleTab OpFuser::dnnlRules = {
-  {{aten::conv2d, aten::relu}, dnnl::conv2d_relu},
-  {{aten::conv2d, Symbol::fromQualString("aten::relu_")}, dnnl::conv2d_relu},
-  /*
-  {{AtenIpexCPUDev::conv2d_sum, AtenIpexCPUDev::relu}, AtenIpexCPUDev::conv2d_sum_relu},
-  {{AtenIpexCPUDev::conv2d_sum, dnnl::relu_}, AtenIpexCPUDev::conv2d_sum_relu},
-
-  {{aten::conv2d, aten::add}, AtenIpexCPUDev::conv2d_sum},
-  {{aten::conv2d, aten::add_}, AtenIpexCPUDev::conv2d_sum},
-  {{AtenIpexCPUDev::conv2d_relu, aten::add}, AtenIpexCPUDev::conv2d_relu_sum}
-  */
+  {{aten::conv2d, aten::relu}, ipex::conv2d_relu},
+  {{aten::conv2d, Symbol::fromQualString("aten::relu_")}, ipex::conv2d_relu},
+  {{ipex::conv2d_sum, aten::relu}, ipex::conv2d_sum_relu},
+  {{ipex::conv2d_sum, Symbol::fromQualString("aten::relu_")}, ipex::conv2d_sum_relu},
+
+  {{aten::conv2d, aten::add}, ipex::conv2d_sum},
+  {{aten::conv2d, aten::add_}, ipex::conv2d_sum},
+  //{{dnnl::conv2d_relu, aten::add}, dnnl::conv2d_relu_sum}
 };
 
 void FusionPass(std::shared_ptr<Graph> &graph) {
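
With these rules, fusion happens in two steps: aten::conv2d followed by aten::add / aten::add_ is rewritten to ipex::conv2d_sum, and an ipex::conv2d_sum followed by aten::relu / aten::relu_ is rewritten to ipex::conv2d_sum_relu. A minimal TorchScript module that produces exactly that operator chain (an illustrative sketch, not code from this commit; module name and shapes are arbitrary):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class ConvSumRelu(nn.Module):
        def __init__(self):
            super(ConvSumRelu, self).__init__()
            self.conv_a = nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=False)
            self.conv_b = nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=False)

        def forward(self, x):
            a = self.conv_a(x)
            b = self.conv_b(x)
            # conv2d -> add_ -> relu_: the chain the two-step rules collapse
            # into a single ipex::conv2d_sum_relu node
            return F.relu(a.add_(b), inplace=True)

    scripted = torch.jit.script(ConvSumRelu().eval())
    print(scripted.graph)  # shows aten::conv2d, aten::add_, aten::relu_ before fusion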
