@@ -233,9 +233,18 @@ at::Tensor convolution_backward_input(
       "Only support 2d or 3d convolution for convolution_backward_input");
 
   const ideep::tensor mkldnn_grad_output = itensor_view_from_dense(grad_output);
-  bool is_channels_last =
-      grad_output.suggest_memory_format() == at::MemoryFormat::ChannelsLast ||
-      grad_output.suggest_memory_format() == at::MemoryFormat::ChannelsLast3d;
+  bool is_channels_last_contiguous =
+      grad_output.is_contiguous(at::MemoryFormat::ChannelsLast) ||
+      grad_output.is_contiguous(at::MemoryFormat::ChannelsLast3d);
+
+  auto memory_format = at::MemoryFormat::Contiguous;
+  if (is_channels_last_contiguous) {
+    if (input_size.size() == 4) {
+      memory_format = at::MemoryFormat::ChannelsLast;
+    } else {
+      memory_format = at::MemoryFormat::ChannelsLast3d;
+    }
+  }
 
   std::vector<int64_t> origin_weight_dims;
   origin_weight_dims.push_back(grad_output.size(1));
@@ -256,11 +265,10 @@ at::Tensor convolution_backward_input(
       {},
       ideep::attr_t());
 
-  auto grad_input = at::empty(
-      input_size,
-      grad_output.options().memory_format(grad_output.suggest_memory_format()));
+  auto grad_input =
+      at::empty(input_size, grad_output.options().memory_format(memory_format));
   ideep::tensor mkldnn_grad_input;
-  if (is_channels_last) {
+  if (is_channels_last_contiguous) {
     mkldnn_grad_input = itensor_view_from_dense(grad_input);
   }
 
@@ -275,7 +283,7 @@ at::Tensor convolution_backward_input(
       padding.vec(),
       groups);
 
-  if (is_channels_last) {
+  if (is_channels_last_contiguous) {
     return grad_input;
   } else {
     return mkldnn_to_dense(new_with_itensor_mkldnn(
@@ -302,9 +310,10 @@ std::tuple<at::Tensor, at::Tensor> convolution_backward_weights(
       "Only support 2d or 3d convolution for convolution_backward_weights");
   const ideep::tensor mkldnn_grad_output = itensor_view_from_dense(grad_output);
   const ideep::tensor mkldnn_input = itensor_view_from_dense(input);
-  bool is_channels_last =
-      grad_output.suggest_memory_format() == at::MemoryFormat::ChannelsLast ||
-      grad_output.suggest_memory_format() == at::MemoryFormat::ChannelsLast3d;
+
+  bool is_channels_last_contiguous =
+      grad_output.is_contiguous(at::MemoryFormat::ChannelsLast) ||
+      grad_output.is_contiguous(at::MemoryFormat::ChannelsLast3d);
 
   auto grad_weight = at::empty(weight_size, grad_output.options());
   at::Tensor grad_bias;
@@ -361,7 +370,7 @@ std::tuple<at::Tensor, at::Tensor> convolution_backward_weights(
   if (weight_packed) {
     return std::make_tuple(grad_weight, grad_bias);
   } else {
-    if (is_channels_last) {
+    if (is_channels_last_contiguous) {
      auto memory_format = input.dim() == 4 ? at::MemoryFormat::ChannelsLast
                                             : at::MemoryFormat::ChannelsLast3d;
      return std::make_tuple(
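For context on the check being swapped in this diff: `Tensor::suggest_memory_format()` is a stride-based heuristic that guesses a layout, while `Tensor::is_contiguous(at::MemoryFormat::ChannelsLast)` only returns true when the tensor is actually dense in that layout, so the two can disagree on tensors with ambiguous strides. A minimal standalone sketch of that difference, assuming an ATen C++ build (the shapes and the `main` harness are illustrative only, not part of this PR):

```cpp
#include <ATen/ATen.h>
#include <iostream>

int main() {
  // Unambiguous case: an NCHW tensor materialized in channels-last order.
  // The strict check and the heuristic agree that it is channels-last.
  auto x = at::randn({2, 3, 4, 5}).contiguous(at::MemoryFormat::ChannelsLast);
  std::cout << x.is_contiguous(at::MemoryFormat::ChannelsLast) << "\n";
  std::cout << (x.suggest_memory_format() == at::MemoryFormat::ChannelsLast) << "\n";

  // Ambiguous case: with a single channel the strides satisfy both the
  // default-contiguous and the channels-last layout. The strict
  // is_contiguous(ChannelsLast) check reports true, while the
  // suggest_memory_format() heuristic can fall back to Contiguous,
  // so the two calls may disagree here.
  auto y = at::randn({2, 1, 4, 5});
  std::cout << y.is_contiguous(at::MemoryFormat::ChannelsLast) << "\n";
  std::cout << (y.suggest_memory_format() == at::MemoryFormat::ChannelsLast) << "\n";
  return 0;
}
```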