9
9
#include " torch_ipex/csrc/utils.h"
10
10
#include " DevOPs.h"
11
11
12
- using namespace at ;
13
-
14
12
class NewLinearOp : public torch ::autograd::Function<NewLinearOp> {
15
13
public:
16
14
static at::Tensor forward (
@@ -40,9 +38,9 @@ class NewLinearOp : public torch::autograd::Function<NewLinearOp> {
40
38
41
39
if (torch_ipex::check_auto_dnnl () && input.device ().type () == c10::DeviceType::DPCPP) {
42
40
grad_input = torch_ipex::cpu::AtenIpexCPUDev::dil_linear_backward_input (
43
- input.sizes (), grad_output, weight);
41
+ input.sizes (), grad_output. contiguous () , weight);
44
42
std::tie (grad_weight, grad_bias) = torch_ipex::cpu::AtenIpexCPUDev::dil_linear_backward_weights (
45
- grad_output, input, weight, bias.defined ());
43
+ grad_output. contiguous () , input, weight, bias.defined ());
46
44
} else {
47
45
auto grad_input = grad_output.mm (weight);
48
46
auto grad_weight = grad_output.t ().mm (input);
@@ -69,6 +67,7 @@ class NewMaxPoolingOp : public torch::autograd::Function<NewMaxPoolingOp> {
69
67
ctx->saved_data [" padding" ] = padding;
70
68
ctx->saved_data [" dilation" ] = dilation;
71
69
ctx->saved_data [" ceil_mode" ] = ceil_mode;
70
+
72
71
if (torch_ipex::check_auto_dnnl () && input.device ().type () == c10::DeviceType::DPCPP) {
73
72
at::Tensor output = torch_ipex::cpu::AtenIpexCPUDev::dil_max_pooling (input, kernel_size, stride,
74
73
padding, dilation, ceil_mode);
@@ -89,12 +88,13 @@ class NewMaxPoolingOp : public torch::autograd::Function<NewMaxPoolingOp> {
89
88
at::Tensor input = saved[0 ];
90
89
at::Tensor indices = saved[1 ];
91
90
92
- at::Tensor grad_output = grad_outputs[0 ];
91
+ at::Tensor grad_output = grad_outputs[0 ]. contiguous () ;
93
92
at::Tensor grad_input;
94
- at::IntArrayRef kernel_size = at::IntArrayRef (ctx->saved_data [" kernel_size" ].toIntVector ());
95
- at::IntArrayRef stride = at::IntArrayRef (ctx->saved_data [" stride" ].toIntVector ());
96
- at::IntArrayRef padding = at::IntArrayRef (ctx->saved_data [" padding" ].toIntVector ());
97
- at::IntArrayRef dilation = at::IntArrayRef (ctx->saved_data [" dilation" ].toIntVector ());
93
+
94
+ std::vector<int64_t > kernel_size = ctx->saved_data [" kernel_size" ].toIntVector ();
95
+ std::vector<int64_t > stride = ctx->saved_data [" stride" ].toIntVector ();
96
+ std::vector<int64_t > padding = ctx->saved_data [" padding" ].toIntVector ();
97
+ std::vector<int64_t > dilation = ctx->saved_data [" dilation" ].toIntVector ();
98
98
bool ceil_mode = ctx->saved_data [" ceil_mode" ].toBool ();
99
99
100
100
if (torch_ipex::check_auto_dnnl () && input.device ().type () == c10::DeviceType::DPCPP) {
@@ -104,6 +104,6 @@ class NewMaxPoolingOp : public torch::autograd::Function<NewMaxPoolingOp> {
104
104
grad_input = at::max_pool2d_with_indices_backward (grad_output, input, kernel_size,
105
105
stride, padding, dilation, ceil_mode, indices);
106
106
}
107
- return {grad_input};
107
+ return {grad_input, at::Tensor (), at::Tensor (), at::Tensor (), at::Tensor (), at::Tensor () };
108
108
}
109
109
};
0 commit comments