@@ -53,3 +53,57 @@ class NewLinearOp : public torch::autograd::Function<NewLinearOp> {
53
53
return {grad_input, grad_weight, grad_bias};
54
54
}
55
55
};
56
+
57
+ class NewMaxPoolingOp : public torch ::autograd::Function<NewMaxPoolingOp> {
58
+ public:
59
+ static at::Tensor forward (
60
+ torch::autograd::AutogradContext* ctx,
61
+ at::Tensor input,
62
+ at::IntArrayRef kernel_size,
63
+ at::IntArrayRef stride,
64
+ at::IntArrayRef padding,
65
+ at::IntArrayRef dilation,
66
+ bool ceil_mode) {
67
+ ctx->saved_data [" kernel_size" ] = kernel_size;
68
+ ctx->saved_data [" stride" ] = stride;
69
+ ctx->saved_data [" padding" ] = padding;
70
+ ctx->saved_data [" dilation" ] = dilation;
71
+ ctx->saved_data [" ceil_mode" ] = ceil_mode;
72
+ if (torch_ipex::check_auto_dnnl () && input.device ().type () == c10::DeviceType::DPCPP) {
73
+ at::Tensor output = torch_ipex::cpu::AtenIpexCPUDev::dil_max_pooling (input, kernel_size, stride,
74
+ padding, dilation, ceil_mode);
75
+ ctx->save_for_backward ({input, output});
76
+ return output;
77
+ } else {
78
+ at::Tensor output, indices;
79
+ std::tie (output, indices) = at::max_pool2d_with_indices (input, kernel_size, stride, padding, dilation, ceil_mode);
80
+ ctx->save_for_backward ({input, indices});
81
+ return output;
82
+ }
83
+ }
84
+
85
+ static torch::autograd::tensor_list backward (
86
+ torch::autograd::AutogradContext* ctx,
87
+ torch::autograd::tensor_list grad_outputs) {
88
+ auto saved = ctx->get_saved_variables ();
89
+ at::Tensor input = saved[0 ];
90
+ at::Tensor indices = saved[1 ];
91
+
92
+ at::Tensor grad_output = grad_outputs[0 ];
93
+ at::Tensor grad_input;
94
+ at::IntArrayRef kernel_size = at::IntArrayRef (ctx->saved_data [" kernel_size" ].toIntVector ());
95
+ at::IntArrayRef stride = at::IntArrayRef (ctx->saved_data [" stride" ].toIntVector ());
96
+ at::IntArrayRef padding = at::IntArrayRef (ctx->saved_data [" padding" ].toIntVector ());
97
+ at::IntArrayRef dilation = at::IntArrayRef (ctx->saved_data [" dilation" ].toIntVector ());
98
+ bool ceil_mode = ctx->saved_data [" ceil_mode" ].toBool ();
99
+
100
+ if (torch_ipex::check_auto_dnnl () && input.device ().type () == c10::DeviceType::DPCPP) {
101
+ grad_input = torch_ipex::cpu::AtenIpexCPUDev::dil_max_pooling_backward (
102
+ grad_output, indices, input, kernel_size, stride, padding, dilation, ceil_mode);
103
+ } else {
104
+ grad_input = at::max_pool2d_with_indices_backward (grad_output, input, kernel_size,
105
+ stride, padding, dilation, ceil_mode, indices);
106
+ }
107
+ return {grad_input};
108
+ }
109
+ };
0 commit comments