Skip to content

Commit 97d4606

Browse files
bf16: replace aten::max_pool2d with ipex::max_pool2d for better performance (#7)
1 parent dac3224 commit 97d4606

File tree

7 files changed

+250
-0
lines changed

7 files changed

+250
-0
lines changed

torch_ipex/csrc/cpu/FusionOPs.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#include "torch_ipex/csrc/utils.h"
33
#include "Conv.h"
44
#include "Linear.h"
5+
#include "Pooling.h"
56

67
#include <ATen/Context.h>
78
#include <ATen/InferSize.h>

torch_ipex/csrc/cpu/Pooling.cpp

Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,183 @@
1+
#include "mkldnn/MKLDNNCommon.h"
2+
#include "torch_ipex/csrc/utils.h"
3+
4+
#include <ATen/native/Pool.h>
5+
6+
namespace torch_ipex {
7+
namespace cpu {
8+
9+
// Broadcast a single-value parameter list to `expected_dim` entries, or
// validate that an explicitly given list already has the expected length.
// Mirrors ATen's convention for kernel_size/stride/padding/dilation args.
inline std::vector<int64_t> expand_param_if_needed(
    at::IntArrayRef list_param,
    const char* param_name,
    int64_t expected_dim) {
  const auto given = static_cast<int64_t>(list_param.size());
  if (given == 1) {
    // A single value applies to every spatial dimension.
    return std::vector<int64_t>(expected_dim, list_param[0]);
  }
  if (given == expected_dim) {
    return list_param.vec();
  }
  std::ostringstream msg;
  msg << "expected " << param_name << " to be a single integer value or a "
      << "list of " << expected_dim << " values to match the convolution "
      << "dimensions, but got " << param_name << "=" << list_param;
  AT_ERROR(msg.str());
}
25+
26+
// Compute the pooled output shape for an N-D input (N, C, spatial...),
// honoring asymmetric left/right padding. Batch and channel dimensions
// pass through unchanged; each spatial dimension is derived via ATen's
// pooling_output_shape_pad_lr helper.
std::vector<int64_t> pool_output_sizes(
    at::IntArrayRef input_size,
    at::IntArrayRef kernel_size,
    at::IntArrayRef stride,
    at::IntArrayRef padding_l,
    at::IntArrayRef padding_r,
    at::IntArrayRef dilation,
    bool ceil_mode) {
  const size_t ndim = input_size.size();
  std::vector<int64_t> out(ndim);
  // N and C are untouched by pooling.
  out[0] = input_size[0];
  out[1] = input_size[1];

  for (size_t dim = 2; dim < ndim; ++dim) {
    const size_t s = dim - 2; // index into the spatial parameter lists
    out[dim] = at::native::pooling_output_shape_pad_lr<int64_t>(
        input_size[dim],
        kernel_size[s],
        padding_l[s],
        padding_r[s],
        stride[s],
        dilation[s],
        ceil_mode);
  }

  return out;
}
53+
54+
// Run a oneDNN (MKL-DNN) pooling forward over `input`.
//
// Fixes vs. previous revision: removed the unused local `ideep::tensor y`
// (dead code) and the redundant copy of the padding vector that the old
// NOLINT(performance-unnecessary-copy-initialization) annotation flagged —
// only the right-side padding is ever mutated, so only it needs a copy.
//
// `stride` may be empty, meaning "stride == kernel_size" (ATen convention).
// `algo` selects max vs. average pooling. Returns a dense ATen tensor; for
// channels-last inputs the result is written directly into an ATen buffer
// to avoid an extra conversion.
static at::Tensor _mkldnn_pooling(
    const at::Tensor& input,
    at::IntArrayRef kernel_size,
    at::IntArrayRef stride,
    at::IntArrayRef padding,
    at::IntArrayRef dilation,
    bool ceil_mode,
    ideep::algorithm algo) {

  const int64_t dims = input.dim() - 2;
  auto kernel_size_vec = expand_param_if_needed(kernel_size, "kernel_size", dims);
  // Empty stride defaults to the kernel size (same convention as ATen).
  if (stride.empty()) stride = kernel_size;
  auto stride_vec = expand_param_if_needed(stride, "stride", dims);
  // Left padding stays as given; right padding may be grown below to
  // emulate ceil_mode, which oneDNN does not support natively.
  auto padding_vec_l = expand_param_if_needed(padding, "padding", dims);
  auto padding_vec_r = padding_vec_l;
  auto dilation_vec = expand_param_if_needed(dilation, "dilation", dims);

  // TODO: the input will be actively converted to channels last format
  // after the 5-D tensor supports channels last format.
  const ideep::tensor mkldnn_input = at::native::itensor_view_from_dense(input);
  std::vector<int64_t> output_sizes;

  if (ceil_mode) {
    // oneDNN has no ceil mode: compute the target (ceil) output shape,
    // then grow the right-side padding until floor-mode pooling produces
    // the same shape.
    const std::vector<int64_t> output_sizes_ceil = pool_output_sizes(
        input.sizes(),
        kernel_size_vec,
        stride_vec,
        padding_vec_l,
        padding_vec_r,
        dilation_vec,
        true /* ceil_mode */);

    // adjust padding until output sizes agree
    bool all_equal = false;
    while (!all_equal) {
      output_sizes = pool_output_sizes(
          input.sizes(),
          kernel_size_vec,
          stride_vec,
          padding_vec_l,
          padding_vec_r,
          dilation_vec,
          false /*ceil_mode */);

      all_equal = true;
      for (size_t i = 2; i < input.sizes().size(); ++i) {
        if (output_sizes[i] < output_sizes_ceil[i]) {
          padding_vec_r[i - 2]++;
          all_equal = false;
        }
      }
    }
  } else {
    output_sizes = pool_output_sizes(
        input.sizes(),
        kernel_size_vec,
        stride_vec,
        padding_vec_l,
        padding_vec_r,
        dilation_vec,
        false /*ceil_mode */);
  }

  const bool is_channels_last =
      input.suggest_memory_format() == at::MemoryFormat::ChannelsLast;
  auto output = at::empty({0}, input.options());
  ideep::tensor mkldnn_output;
  if (is_channels_last) {
    // Let oneDNN write straight into a channels-last ATen tensor so no
    // mkldnn->dense conversion is needed afterwards.
    output.resize_(output_sizes, input.suggest_memory_format());
    mkldnn_output = at::native::itensor_view_from_dense(output);
  }

  auto aprop_kind = ideep::prop_kind::forward;
  // for max_pool, prop_kind::forward will save indices as workspace for backward use,
  // for inference, don't need the indices, set aprop_kind to prop_kind::forward_inference
  // can reduce the memory use.
  if (ideep::algorithm::pooling_max == algo
      && !(input.requires_grad() && at::GradMode::is_enabled())) {
    aprop_kind = ideep::prop_kind::forward_inference;
  }

  ideep::pooling_forward::compute(
      mkldnn_input,
      {output_sizes.cbegin(), output_sizes.cend()},
      mkldnn_output,
      {stride_vec.cbegin(), stride_vec.cend()},
      {kernel_size_vec.cbegin(), kernel_size_vec.cend()},
      {padding_vec_l.cbegin(), padding_vec_l.cend()},
      {padding_vec_r.cbegin(), padding_vec_r.cend()},
      algo,
      aprop_kind);

  if (is_channels_last) {
    return output;
  } else {
    // oneDNN allocated `mkldnn_output` itself; wrap it and convert back
    // to a dense (plain-layout) ATen tensor.
    return at::native::mkldnn_to_dense(
        at::native::new_with_itensor_mkldnn(std::move(mkldnn_output),
            optTypeMetaToScalarType(input.options().dtype_opt()),
            input.options().device_opt()));
  }
}
159+
160+
// JIT entry point for ipex::max_pool2d: runs 2D max pooling through the
// oneDNN pooling primitive. Rejects any dilation other than 1 (oneDNN
// limitation) and makes the input contiguous if it is not already in a
// contiguous (any) memory format.
at::Tensor dil_max_pool2d(
    const at::Tensor& input,
    at::IntArrayRef kernel_size,
    at::IntArrayRef stride,
    at::IntArrayRef padding,
    at::IntArrayRef dilation,
    bool ceil_mode) {
#if defined(IPEX_PROFILE_OP)
  RECORD_FUNCTION("AtenIpexJITDev::dil_max_pool2d", std::vector<c10::IValue>({}));
#endif
  const bool dilation_is_trivial = std::all_of(
      dilation.cbegin(), dilation.cend(), [](int64_t d) { return 1 == d; });
  TORCH_CHECK(dilation_is_trivial,
      "mkldnn_max_pool2d does not support dilation case");
  at::Tensor src = IS_CONTIGUOUS_ANY(input) ? input : input.contiguous();
  return _mkldnn_pooling(
      src,
      kernel_size,
      stride,
      padding,
      dilation,
      ceil_mode,
      ideep::algorithm::pooling_max);
}
181+
182+
} // namespace cpu
183+
} // namespace torch_ipex

torch_ipex/csrc/cpu/Pooling.h

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
#pragma once

#include <ATen/Tensor.h>

#include "ideep/ideep.hpp"

#include <vector>

namespace torch { namespace jit {

// JIT symbol for the custom ipex::max_pool2d operator; referenced by the
// graph-rewrite pass and the operator registration.
namespace ipex {
// NOTE(review): `static` in a header gives each translation unit its own
// copy of this Symbol — presumably intentional (TorchScript headers use the
// same pattern), but an `inline` variable would avoid the duplication;
// confirm before changing.
static auto max_pool2d = Symbol::fromQualString("ipex::max_pool2d");
}

}} // namespace torch::jit

namespace torch_ipex {
namespace cpu {

// 2D max pooling backed by oneDNN (MKL-DNN).
// `stride` may be empty, in which case it defaults to `kernel_size`;
// dilation values other than 1 are rejected by the implementation.
at::Tensor dil_max_pool2d(
    const at::Tensor& input,
    at::IntArrayRef kernel_size,
    at::IntArrayRef stride,
    at::IntArrayRef padding,
    at::IntArrayRef dilation,
    bool ceil_mode);

} // namespace cpu
} // namespace torch_ipex

torch_ipex/csrc/jit/fusion_pass.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#include "graph_rewrite.h"
44

55
#include "cpu/FusionOPs.h"
6+
#include "cpu/Pooling.h"
67

78
#include <c10/util/hash.h>
89
#include <torch/csrc/jit/runtime/operator.h>
@@ -319,6 +320,9 @@ void FusionPass(std::shared_ptr<Graph> &graph) {
319320

320321
// replace aten conv with ipex conv
321322
graph_rewrite::replaceAtenConvolutionWithIpexConv(graph);
323+
324+
// replace aten max_pool2d with ipex max_pool2d
325+
graph_rewrite::replaceAtenMaxPool2dWithIpexMaxPool2d(graph);
322326
// TODO: Some post processing?? ECS/EDC/Peephole???
323327
ConstantPropagation(graph);
324328
}

torch_ipex/csrc/jit/graph_rewrite.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -382,6 +382,20 @@ void replaceAtenConvolutionWithIpexConv(std::shared_ptr<Graph>& graph) {
382382
rewriter_conv2d.runOnGraph(graph);
383383
}
384384

385+
void replaceAtenMaxPool2dWithIpexMaxPool2d(std::shared_ptr<Graph>& graph) {
386+
std::string max_pool2d = R"(
387+
graph(%a, %kernel_size:int[], %stride:int[], %padding:int[], %dilation:int[], %ceil_mode:bool):
388+
%r = aten::max_pool2d(%a, %kernel_size, %stride, %padding, %dilation, %ceil_mode)
389+
return (%r) )";
390+
std::string ipex_max_pool2d = R"(
391+
graph(%a, %kernel_size:int[], %stride:int[], %padding:int[], %dilation:int[], %ceil_mode:bool):
392+
%r = ipex::max_pool2d(%a, %kernel_size, %stride, %padding, %dilation, %ceil_mode)
393+
return (%r) )";
394+
SubgraphRewriter rewriter_max_pool2d;
395+
rewriter_max_pool2d.RegisterRewritePattern(max_pool2d, ipex_max_pool2d);
396+
rewriter_max_pool2d.runOnGraph(graph);
397+
}
398+
385399
} // namespace graph_rewrite
386400
} // namespace jit
387401
} // namespace torch

torch_ipex/csrc/jit/graph_rewrite.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ void replaceConvolutionWithAtenConv(std::shared_ptr<Graph>& graph);
2323
void replaceAtenConvolutionWithIpexConv(std::shared_ptr<Graph>& graph);
2424
void FuseConvolutionWithEltwise(std::shared_ptr<Graph>& graph);
2525
void FuseShuffle(std::shared_ptr<Graph>& graph);
26+
void replaceAtenMaxPool2dWithIpexMaxPool2d(std::shared_ptr<Graph>& graph);
2627

2728
} // namespace graph_rewrite_helper
2829
} // namespace jit

torch_ipex/csrc/jit/register_dnnl_jit_ops.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
#include "torch_ipex/csrc/cpu/FusionOPs.h"
77
#include "torch_ipex/csrc/utils.h"
8+
#include "torch_ipex/csrc/cpu/Pooling.h"
89

910
namespace torch {
1011
namespace jit {
@@ -146,6 +147,23 @@ RegisterOperators op({
146147
};
147148
},
148149
aliasAnalysisFromSchema()),
150+
Operator(
151+
"ipex::max_pool2d(Tensor input, int[2] kernel_size, int[2] stride, int[2] padding, int[2] dilation, bool ceil_mode) -> Tensor",
152+
[](const Node* node) -> Operation {
153+
return [](Stack* stack) {
154+
auto result = torch_ipex::cpu::dil_max_pool2d(
155+
(std::move(peek(stack, 0, 6))).toTensor(),
156+
(std::move(peek(stack, 1, 6))).toIntVector(),
157+
(std::move(peek(stack, 2, 6))).toIntVector(),
158+
(std::move(peek(stack, 3, 6))).toIntVector(),
159+
(std::move(peek(stack, 4, 6))).toIntVector(),
160+
(std::move(peek(stack, 5, 6))).toBool());
161+
drop(stack, 6);
162+
pack(stack, std::move(result));
163+
return 0;
164+
};
165+
},
166+
aliasAnalysisFromSchema()),
149167
});
150168
} // namespace jit
151169
} // namespace torch

0 commit comments

Comments
 (0)