Skip to content

Commit 8a5f2ee

Browse files
linear weight prepack cache for jit path (#30)
* linear weight prepack cache for jit path * add description of weight prepack * fix conv bn fusion error when some model can't use FX * change the description of weight prepack
1 parent 05543d5 commit 8a5f2ee

File tree

17 files changed

+300
-227
lines changed

17 files changed

+300
-227
lines changed

ideep/ideep/operators/inner_product.hpp

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -178,8 +178,8 @@ struct inner_product_forward : public dnnl::inner_product_forward {
178178
// align weights data type with src
179179
dst_data_type = src.get_data_type() == data_type::bf16 ? data_type::bf16
180180
: data_type::f32;
181-
src_desc = src.get_desc().to_type(dst_data_type);
182-
weights_desc = weights.get_desc().to_type(dst_data_type);
181+
src_desc = src.get_desc().to_type(dst_data_type).to_format_any();
182+
weights_desc = weights.get_desc().to_type(dst_data_type).to_format_any();
183183
if (with_bias) {
184184
IDEEP_ENFORCE(utils::one_of(bias.get_data_type(),
185185
data_type::f32, data_type::bf16),
@@ -197,9 +197,16 @@ struct inner_product_forward : public dnnl::inner_product_forward {
197197

198198
auto expected_src = src.reorder_if_differ_in(pd.src_desc(), src_attr);
199199
auto expected_weights = weights.reorder_if_differ_in(pd.weights_desc(), weights_attr);
200-
dst.reinit_if_possible(pd.dst_desc());
201-
if (!dst_scales.empty() && dst.get_data_type() != data_type::f32) {
202-
dst.set_scale(dst_scales_in);
200+
// [ Note output buffer ]
201+
// If dst is an empty ideep tensor, it can be (re-)initialized here.
202+
// If dst is not empty, ideep must write the result to dst's memory, and it is the caller's duty to
203+
// make sure dst is big enough to hold the result
204+
if (dst.is_empty()) {
205+
dst.init(pd.dst_desc());
206+
}
207+
auto expected_dst = dst.reorder_if_differ_in(pd.dst_desc());
208+
if (!dst_scales.empty() && utils::one_of(dst.get_data_type(), data_type::u8, data_type::s8)) {
209+
expected_dst.set_scale(dst_scales_in);
203210
}
204211

205212
if (with_bias){
@@ -208,17 +215,19 @@ struct inner_product_forward : public dnnl::inner_product_forward {
208215
{{DNNL_ARG_SRC, expected_src},
209216
{DNNL_ARG_WEIGHTS, expected_weights},
210217
{DNNL_ARG_BIAS, expected_bias},
211-
{DNNL_ARG_DST, dst}});
218+
{DNNL_ARG_DST, expected_dst}});
212219
} else {
213220
super(pd).execute(stream::default_stream(),
214221
{{DNNL_ARG_SRC, expected_src},
215222
{DNNL_ARG_WEIGHTS, expected_weights},
216-
{DNNL_ARG_DST, dst}});
223+
{DNNL_ARG_DST, expected_dst}});
217224
}
218225

219-
if (attr.non_negitive_output() && dst.get_data_type() == data_type::s8) {
220-
dst.to_type(data_type::u8);
226+
if (attr.non_negitive_output() && expected_dst.get_data_type() == data_type::s8) {
227+
expected_dst.to_type(data_type::u8);
221228
}
229+
// reorder back to dst's buffer if needed
230+
expected_dst.reorder_to_if_differ_from(dst);
222231
}
223232
};
224233

ideep/ideep/tensor.hpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -663,6 +663,14 @@ class tensor : public memory {
663663
}
664664
}
665665

666+
// Reorder data from *this to dst if dst's memory desc (size, stride, format, etc.) differs from that of *this.
667+
void reorder_to_if_differ_from(tensor &dst, const attr_t &aattr = attr_t()) const {
668+
if (dst.get_desc() != get_desc()) {
669+
this->reorder_to(dst, aattr);
670+
}
671+
return;
672+
}
673+
666674
// workaround for issue intel/mkl-dnn#588
667675
desc _get_unblocked_desc_if_4c_blocked() const {
668676
auto desc = get_desc();

intel_pytorch_extension_py/utils.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,11 @@ def convert_module_data_type(module, dtype):
2323
return module
2424

2525
def optimize(model, dtype=torch.bfloat16, level='O1'):
26-
optimized_model = conv_bn_fuse(model)
26+
try:
27+
optimized_model = conv_bn_fuse(model)
28+
except:
29+
warnings.warn("Conv BN folding failed during the optimize process.")
30+
optimized_model = model
2731
if dtype == torch.bfloat16:
2832
optimized_model = convert_module_data_type(optimized_model, torch.bfloat16)
2933
return optimized_model

torch_ipex/csrc/cpu/Conv.cpp

Lines changed: 15 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,11 @@
11
#include "Conv.h"
22
#include "mkldnn/MKLDNNCommon.h"
33
#include "torch_ipex/csrc/utils.h"
4+
#include "WeightPrepack.h"
45

56
namespace torch_ipex {
67
namespace cpu {
78

8-
namespace {
9-
10-
using weakref_type = c10::weak_intrusive_ptr<c10::TensorImpl, c10::UndefinedTensorImpl>;
11-
using val_blocked = std::tuple<weakref_type, ideep::tensor>;
12-
thread_local std::unordered_map<c10::TensorImpl *, val_blocked> cached_weights;
13-
14-
} // namespace
15-
169
std::vector<int64_t> calc_conv_output_size(
1710
at::IntArrayRef input_size,
1811
at::IntArrayRef kernel_size,
@@ -30,61 +23,6 @@ std::vector<int64_t> calc_conv_output_size(
3023
return output_size;
3124
}
3225

33-
ideep::tensor get_prepack_conv_weights(
34-
const ideep::tensor& input,
35-
const at::Tensor& weight,
36-
at::IntArrayRef stride,
37-
at::IntArrayRef padding,
38-
at::IntArrayRef dilation,
39-
int64_t groups,
40-
const ideep::attr_t& attr) {
41-
auto it = cached_weights.find(weight.unsafeGetTensorImpl());
42-
if (it != cached_weights.end()) {
43-
return std::get<1>(it->second);
44-
} else {
45-
ideep::tensor w = at::native::itensor_view_from_dense(weight);
46-
// TODO: 3d check
47-
bool is_channels_last = input.get_desc().is_nhwc();
48-
ideep::tensor::desc packed_desc;
49-
if (is_channels_last) {
50-
packed_desc = ideep::convolution_forward::expected_weights_desc<true>(
51-
w.get_dims(),
52-
w.get_data_type(),
53-
stride.vec(),
54-
padding.vec(),
55-
padding.vec(),
56-
dilation.vec(),
57-
groups,
58-
ideep::algorithm::convolution_direct,
59-
ideep::prop_kind::forward,
60-
input.get_data_type(),
61-
input.get_dims(),
62-
attr);
63-
} else {
64-
packed_desc = ideep::convolution_forward::expected_weights_desc<false>(
65-
w.get_dims(),
66-
w.get_data_type(),
67-
stride.vec(),
68-
padding.vec(),
69-
padding.vec(),
70-
dilation.vec(),
71-
groups,
72-
ideep::algorithm::convolution_direct,
73-
ideep::prop_kind::forward,
74-
input.get_data_type(),
75-
input.get_dims(),
76-
attr);
77-
}
78-
ideep::tensor result;
79-
result.init(packed_desc);
80-
result.feed_from(w);
81-
cached_weights.emplace(
82-
weight.unsafeGetTensorImpl(),
83-
val_blocked{weakref_type(weight.getIntrusivePtr()), result});
84-
return result;
85-
}
86-
}
87-
8826
at::Tensor convolution_impl(
8927
const at::Tensor& input,
9028
const at::Tensor& weight,
@@ -96,22 +34,24 @@ at::Tensor convolution_impl(
9634
const ideep::attr_t& attr) {
9735
// TODO: the input will be actively converted to channels last format
9836
// after the 5-D tensor supports channels last format.
99-
const ideep::tensor mkldnn_input = at::native::itensor_view_from_dense(input);
100-
ideep::tensor mkldnn_weight = get_prepack_conv_weights(mkldnn_input, weight, stride, padding, dilation, groups, attr);
37+
auto input_ = IS_CONTIGUOUS_ANY(input) ? input : input.contiguous();
38+
const ideep::tensor mkldnn_input = at::native::itensor_view_from_dense(input_);
39+
ideep::tensor mkldnn_weight = get_conv_prepacked_weight(mkldnn_input, weight, stride, padding, dilation, groups, attr);
10140
auto kernel_size = mkldnn_weight.get_dims();
10241
std::vector<int64_t> input_size = mkldnn_input.get_dims();
10342
std::vector<int64_t> output_sizes =
10443
calc_conv_output_size(input_size, kernel_size, padding, stride, dilation);
10544

106-
bool is_channels_last = input.suggest_memory_format() == at::MemoryFormat::ChannelsLast;
107-
auto output = at::empty(output_sizes, input.options().memory_format(input.suggest_memory_format()));
45+
bool is_channels_last = input_.suggest_memory_format() == at::MemoryFormat::ChannelsLast;
46+
auto output = at::empty(output_sizes, input_.options().memory_format(input_.suggest_memory_format()));
10847
ideep::tensor mkldnn_output;
10948
if (is_channels_last) {
11049
mkldnn_output = at::native::itensor_view_from_dense(output);
11150
}
11251

11352
if (bias.defined()) {
114-
const ideep::tensor mkldnn_bias = at::native::itensor_view_from_dense(bias);
53+
auto bias_ = IS_CONTIGUOUS_ANY(bias) ? bias : bias.contiguous();
54+
const ideep::tensor mkldnn_bias = at::native::itensor_view_from_dense(bias_);
11555
ideep::convolution_forward::compute(
11656
mkldnn_input,
11757
mkldnn_weight,
@@ -165,20 +105,22 @@ void convolution_inplace_impl(
165105
const ideep::attr_t& attr) {
166106
// TODO: the input will be actively converted to channels last format
167107
// after the 5-D tensor supports channels last format.
168-
const ideep::tensor mkldnn_input = at::native::itensor_view_from_dense(input);
169-
ideep::tensor mkldnn_weight = get_prepack_conv_weights(mkldnn_input, weight, stride, padding, dilation, groups, attr);
108+
auto input_ = IS_CONTIGUOUS_ANY(input) ? input : input.contiguous();
109+
const ideep::tensor mkldnn_input = at::native::itensor_view_from_dense(input_);
110+
ideep::tensor mkldnn_weight = get_conv_prepacked_weight(mkldnn_input, weight, stride, padding, dilation, groups, attr);
170111
auto kernel_size = mkldnn_weight.get_dims();
171112
std::vector<int64_t> input_size = mkldnn_input.get_dims();
172113
std::vector<int64_t> output_sizes =
173114
calc_conv_output_size(input_size, kernel_size, padding, stride, dilation);
174115

175-
bool is_channels_last = input.suggest_memory_format() == at::MemoryFormat::ChannelsLast;
116+
bool is_channels_last = input_.suggest_memory_format() == at::MemoryFormat::ChannelsLast;
176117
output = IS_CONTIGUOUS_ANY(output) ? output : output.contiguous();
177-
output = output.to(input.suggest_memory_format());
118+
output = output.to(input_.suggest_memory_format());
178119
ideep::tensor mkldnn_output = at::native::itensor_view_from_dense(output);
179120

180121
if (bias.defined()) {
181-
const ideep::tensor mkldnn_bias = at::native::itensor_view_from_dense(bias);
122+
auto bias_ = IS_CONTIGUOUS_ANY(bias) ? bias : bias.contiguous();
123+
const ideep::tensor mkldnn_bias = at::native::itensor_view_from_dense(bias_);
182124
ideep::convolution_forward::compute(
183125
mkldnn_input,
184126
mkldnn_weight,

torch_ipex/csrc/cpu/Conv.h

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,8 @@
11
#pragma once
22

33
#include <ATen/Tensor.h>
4-
54
#include "ideep/ideep.hpp"
65

7-
#include <vector>
8-
96
namespace torch_ipex {
107
namespace cpu {
118

0 commit comments

Comments
 (0)