Commit c023ec0
remove unnecessary training reorder (#553)
* remove unnecessary training reorder
* clang-format
1 parent 5973d9f commit c023ec0

File tree

  • intel_extension_for_pytorch/csrc

2 files changed, 24 insertions(+), 22 deletions(-)


intel_extension_for_pytorch/csrc/aten/cpu/Conv.cpp

Lines changed: 21 additions & 12 deletions
@@ -233,9 +233,18 @@ at::Tensor convolution_backward_input(
       "Only support 2d or 3d convolution for convolution_backward_input");
 
   const ideep::tensor mkldnn_grad_output = itensor_view_from_dense(grad_output);
-  bool is_channels_last =
-      grad_output.suggest_memory_format() == at::MemoryFormat::ChannelsLast ||
-      grad_output.suggest_memory_format() == at::MemoryFormat::ChannelsLast3d;
+  bool is_channels_last_contiguous =
+      grad_output.is_contiguous(at::MemoryFormat::ChannelsLast) ||
+      grad_output.is_contiguous(at::MemoryFormat::ChannelsLast3d);
+
+  auto memory_format = at::MemoryFormat::Contiguous;
+  if (is_channels_last_contiguous) {
+    if (input_size.size() == 4) {
+      memory_format = at::MemoryFormat::ChannelsLast;
+    } else {
+      memory_format = at::MemoryFormat::ChannelsLast3d;
+    }
+  }
 
   std::vector<int64_t> origin_weight_dims;
   origin_weight_dims.push_back(grad_output.size(1));
@@ -256,11 +265,10 @@ at::Tensor convolution_backward_input(
       {},
       ideep::attr_t());
 
-  auto grad_input = at::empty(
-      input_size,
-      grad_output.options().memory_format(grad_output.suggest_memory_format()));
+  auto grad_input =
+      at::empty(input_size, grad_output.options().memory_format(memory_format));
   ideep::tensor mkldnn_grad_input;
-  if (is_channels_last) {
+  if (is_channels_last_contiguous) {
     mkldnn_grad_input = itensor_view_from_dense(grad_input);
   }
 
@@ -275,7 +283,7 @@ at::Tensor convolution_backward_input(
       padding.vec(),
       groups);
 
-  if (is_channels_last) {
+  if (is_channels_last_contiguous) {
     return grad_input;
   } else {
     return mkldnn_to_dense(new_with_itensor_mkldnn(
@@ -302,9 +310,10 @@ std::tuple<at::Tensor, at::Tensor> convolution_backward_weights(
       "Only support 2d or 3d convolution for convolution_backward_weights");
   const ideep::tensor mkldnn_grad_output = itensor_view_from_dense(grad_output);
   const ideep::tensor mkldnn_input = itensor_view_from_dense(input);
-  bool is_channels_last =
-      grad_output.suggest_memory_format() == at::MemoryFormat::ChannelsLast ||
-      grad_output.suggest_memory_format() == at::MemoryFormat::ChannelsLast3d;
+
+  bool is_channels_last_contiguous =
+      grad_output.is_contiguous(at::MemoryFormat::ChannelsLast) ||
+      grad_output.is_contiguous(at::MemoryFormat::ChannelsLast3d);
 
   auto grad_weight = at::empty(weight_size, grad_output.options());
   at::Tensor grad_bias;
@@ -361,7 +370,7 @@ std::tuple<at::Tensor, at::Tensor> convolution_backward_weights(
   if (weight_packed) {
     return std::make_tuple(grad_weight, grad_bias);
   } else {
-    if (is_channels_last) {
+    if (is_channels_last_contiguous) {
       auto memory_format = input.dim() == 4 ? at::MemoryFormat::ChannelsLast
                                             : at::MemoryFormat::ChannelsLast3d;
       return std::make_tuple(
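
Note on the Conv.cpp change: suggest_memory_format() is a heuristic that infers a preferred layout from a tensor's strides, while is_contiguous(at::MemoryFormat::ChannelsLast) verifies that the strides actually match the channels-last contiguous layout. A minimal standalone sketch of the difference, assuming libtorch is available (the shapes here are made up for illustration, not taken from this commit):

#include <torch/torch.h>
#include <iostream>

int main() {
  // Hypothetical NCHW-contiguous gradient tensor.
  auto nchw = torch::randn({2, 3, 4, 5});
  // The same data copied into channels-last (NHWC) strides.
  auto nhwc = nchw.contiguous(at::MemoryFormat::ChannelsLast);

  // The new check: a property of the buffer itself.
  std::cout << std::boolalpha
            << nchw.is_contiguous(at::MemoryFormat::ChannelsLast) << '\n'   // false
            << nhwc.is_contiguous(at::MemoryFormat::ChannelsLast) << '\n';  // true

  // The old check: compares the heuristic's suggestion instead.
  std::cout << (nhwc.suggest_memory_format() == at::MemoryFormat::ChannelsLast)
            << '\n';  // true
  return 0;
}

With the stricter check, grad_input is allocated with a memory_format derived from the actual strides of grad_output, so the itensor_view_from_dense() view taken on it should stay consistent with the layout oneDNN writes into.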

intel_extension_for_pytorch/csrc/cpu/ideep/ideep/operators/conv.hpp

Lines changed: 3 additions & 10 deletions
@@ -697,8 +697,7 @@ struct convolution_forward
     // it will be removed after block format reorder performance improved.
     if (!weights.get_desc().is_plain() &&
         weights.get_desc() != pd.weights_desc()) {
-      auto temp = weights.to_public(nullptr, weights.get_data_type());
-      expected_weights = temp.reorder_if_differ_in(pd.weights_desc());
+      expected_weights = weights.reorder_if_differ_in(pd.weights_desc());
     } else {
       expected_weights = weights.make_grouped_weights(param.groups)
                              .reorder_if_differ_in(pd.weights_desc());
@@ -763,8 +762,7 @@ struct convolution_forward
     // it will be removed after block format reorder performance improved.
     if (!weights.get_desc().is_plain() &&
         weights.get_desc() != pd.weights_desc()) {
-      auto temp = weights.to_public(nullptr, weights.get_data_type());
-      expected_weights = temp.reorder_if_differ_in(pd.weights_desc());
+      expected_weights = weights.reorder_if_differ_in(pd.weights_desc());
     } else {
       expected_weights = weights.make_grouped_weights(param.groups)
                              .reorder_if_differ_in(pd.weights_desc());
@@ -1074,12 +1072,7 @@ struct convolution_backward_weights
     // diff_weights has been init in FW side, but has diff desc with
     // expected_diff_weights.
     if (diff_weights.get_desc() != expected_diff_weights_desc) {
-      // TODO: there has an issue when reorder block to block,
-      // will be removed after
-      // https://jira.devtools.intel.com/browse/MFDNN-5557 is fixed.
-      auto temp = expected_diff_weights.to_public(
-          nullptr, expected_diff_weights.get_data_type());
-      diff_weights.feed_from(temp);
+      diff_weights.feed_from(expected_diff_weights);
     }
   }
 };
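
Note on the conv.hpp change: the deleted path first copied the weights to a "public" (plain-layout) tensor via to_public() and reordered from there; the new path reorders directly between the two descriptors. A rough oneDNN-level sketch of that direct path, assuming the standalone dnnl C++ API (ideep wraps these primitives; the shapes and format tags below are illustrative, not from this commit):

#include <dnnl.hpp>
#include <vector>

int main() {
  using namespace dnnl;
  engine eng(engine::kind::cpu, 0);
  stream strm(eng);

  // Hypothetical 8x8x3x3 convolution weights.
  memory::dims shape = {8, 8, 3, 3};
  memory::desc plain_md(shape, memory::data_type::f32, memory::format_tag::oihw);
  memory::desc blocked_md(shape, memory::data_type::f32, memory::format_tag::OIhw8i8o);

  std::vector<float> buf(8 * 8 * 3 * 3, 1.f);
  memory src(plain_md, eng, buf.data());
  memory dst(blocked_md, eng);

  // The essence of reorder_if_differ_in(): when the descriptors differ, run
  // a single reorder primitive straight from the source layout to the target
  // layout. The removed code bounced through a plain-format copy first,
  // i.e. two passes over the weights where one suffices.
  if (plain_md != blocked_md) {
    reorder(src, dst).execute(strm, src, dst);
    strm.wait();
  }
  return 0;
}

The backward-weights hunk follows the same idea: feed_from(expected_diff_weights) now reorders block-to-block directly, and removing the workaround that the old TODO tied to oneDNN issue MFDNN-5557 implies the direct reorder is considered safe here.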
