1
1
#include " Conv.h"
2
2
#include " mkldnn/MKLDNNCommon.h"
3
3
#include " torch_ipex/csrc/utils.h"
4
+ #include " WeightPrepack.h"
4
5
5
6
namespace torch_ipex {
6
7
namespace cpu {
7
8
8
- namespace {
9
-
10
- using weakref_type = c10::weak_intrusive_ptr<c10::TensorImpl, c10::UndefinedTensorImpl>;
11
- using val_blocked = std::tuple<weakref_type, ideep::tensor>;
12
- thread_local std::unordered_map<c10::TensorImpl *, val_blocked> cached_weights;
13
-
14
- } // namespace
15
-
16
9
std::vector<int64_t > calc_conv_output_size (
17
10
at::IntArrayRef input_size,
18
11
at::IntArrayRef kernel_size,
@@ -30,61 +23,6 @@ std::vector<int64_t> calc_conv_output_size(
30
23
return output_size;
31
24
}
32
25
33
- ideep::tensor get_prepack_conv_weights (
34
- const ideep::tensor& input,
35
- const at::Tensor& weight,
36
- at::IntArrayRef stride,
37
- at::IntArrayRef padding,
38
- at::IntArrayRef dilation,
39
- int64_t groups,
40
- const ideep::attr_t & attr) {
41
- auto it = cached_weights.find (weight.unsafeGetTensorImpl ());
42
- if (it != cached_weights.end ()) {
43
- return std::get<1 >(it->second );
44
- } else {
45
- ideep::tensor w = at::native::itensor_view_from_dense (weight);
46
- // TODO: 3d check
47
- bool is_channels_last = input.get_desc ().is_nhwc ();
48
- ideep::tensor::desc packed_desc;
49
- if (is_channels_last) {
50
- packed_desc = ideep::convolution_forward::expected_weights_desc<true >(
51
- w.get_dims (),
52
- w.get_data_type (),
53
- stride.vec (),
54
- padding.vec (),
55
- padding.vec (),
56
- dilation.vec (),
57
- groups,
58
- ideep::algorithm::convolution_direct,
59
- ideep::prop_kind::forward,
60
- input.get_data_type (),
61
- input.get_dims (),
62
- attr);
63
- } else {
64
- packed_desc = ideep::convolution_forward::expected_weights_desc<false >(
65
- w.get_dims (),
66
- w.get_data_type (),
67
- stride.vec (),
68
- padding.vec (),
69
- padding.vec (),
70
- dilation.vec (),
71
- groups,
72
- ideep::algorithm::convolution_direct,
73
- ideep::prop_kind::forward,
74
- input.get_data_type (),
75
- input.get_dims (),
76
- attr);
77
- }
78
- ideep::tensor result;
79
- result.init (packed_desc);
80
- result.feed_from (w);
81
- cached_weights.emplace (
82
- weight.unsafeGetTensorImpl (),
83
- val_blocked{weakref_type (weight.getIntrusivePtr ()), result});
84
- return result;
85
- }
86
- }
87
-
88
26
at::Tensor convolution_impl (
89
27
const at::Tensor& input,
90
28
const at::Tensor& weight,
@@ -96,22 +34,24 @@ at::Tensor convolution_impl(
96
34
const ideep::attr_t & attr) {
97
35
// TODO: the input will be actively converted to channels last format
98
36
// after the 5-D tensor supports channels last format.
99
- const ideep::tensor mkldnn_input = at::native::itensor_view_from_dense (input);
100
- ideep::tensor mkldnn_weight = get_prepack_conv_weights (mkldnn_input, weight, stride, padding, dilation, groups, attr);
37
+ auto input_ = IS_CONTIGUOUS_ANY (input) ? input : input.contiguous ();
38
+ const ideep::tensor mkldnn_input = at::native::itensor_view_from_dense (input_);
39
+ ideep::tensor mkldnn_weight = get_conv_prepacked_weight (mkldnn_input, weight, stride, padding, dilation, groups, attr);
101
40
auto kernel_size = mkldnn_weight.get_dims ();
102
41
std::vector<int64_t > input_size = mkldnn_input.get_dims ();
103
42
std::vector<int64_t > output_sizes =
104
43
calc_conv_output_size (input_size, kernel_size, padding, stride, dilation);
105
44
106
- bool is_channels_last = input .suggest_memory_format () == at::MemoryFormat::ChannelsLast;
107
- auto output = at::empty (output_sizes, input .options ().memory_format (input .suggest_memory_format ()));
45
+ bool is_channels_last = input_ .suggest_memory_format () == at::MemoryFormat::ChannelsLast;
46
+ auto output = at::empty (output_sizes, input_ .options ().memory_format (input_ .suggest_memory_format ()));
108
47
ideep::tensor mkldnn_output;
109
48
if (is_channels_last) {
110
49
mkldnn_output = at::native::itensor_view_from_dense (output);
111
50
}
112
51
113
52
if (bias.defined ()) {
114
- const ideep::tensor mkldnn_bias = at::native::itensor_view_from_dense (bias);
53
+ auto bias_ = IS_CONTIGUOUS_ANY (bias) ? bias : bias.contiguous ();
54
+ const ideep::tensor mkldnn_bias = at::native::itensor_view_from_dense (bias_);
115
55
ideep::convolution_forward::compute (
116
56
mkldnn_input,
117
57
mkldnn_weight,
@@ -165,20 +105,22 @@ void convolution_inplace_impl(
165
105
const ideep::attr_t & attr) {
166
106
// TODO: the input will be actively converted to channels last format
167
107
// after the 5-D tensor supports channels last format.
168
- const ideep::tensor mkldnn_input = at::native::itensor_view_from_dense (input);
169
- ideep::tensor mkldnn_weight = get_prepack_conv_weights (mkldnn_input, weight, stride, padding, dilation, groups, attr);
108
+ auto input_ = IS_CONTIGUOUS_ANY (input) ? input : input.contiguous ();
109
+ const ideep::tensor mkldnn_input = at::native::itensor_view_from_dense (input_);
110
+ ideep::tensor mkldnn_weight = get_conv_prepacked_weight (mkldnn_input, weight, stride, padding, dilation, groups, attr);
170
111
auto kernel_size = mkldnn_weight.get_dims ();
171
112
std::vector<int64_t > input_size = mkldnn_input.get_dims ();
172
113
std::vector<int64_t > output_sizes =
173
114
calc_conv_output_size (input_size, kernel_size, padding, stride, dilation);
174
115
175
- bool is_channels_last = input .suggest_memory_format () == at::MemoryFormat::ChannelsLast;
116
+ bool is_channels_last = input_ .suggest_memory_format () == at::MemoryFormat::ChannelsLast;
176
117
output = IS_CONTIGUOUS_ANY (output) ? output : output.contiguous ();
177
- output = output.to (input .suggest_memory_format ());
118
+ output = output.to (input_ .suggest_memory_format ());
178
119
ideep::tensor mkldnn_output = at::native::itensor_view_from_dense (output);
179
120
180
121
if (bias.defined ()) {
181
- const ideep::tensor mkldnn_bias = at::native::itensor_view_from_dense (bias);
122
+ auto bias_ = IS_CONTIGUOUS_ANY (bias) ? bias : bias.contiguous ();
123
+ const ideep::tensor mkldnn_bias = at::native::itensor_view_from_dense (bias_);
182
124
ideep::convolution_forward::compute (
183
125
mkldnn_input,
184
126
mkldnn_weight,
0 commit comments