6
6
#include < ATen/Tensor.h>
7
7
#include < torch/script.h>
8
8
#include < c10/util/Optional.h>
9
+ #include " torch_ipex/csrc/aten_ipex_bridge.h"
9
10
#include " torch_ipex/csrc/utils.h"
10
11
#include " DevOPs.h"
11
12
@@ -68,17 +69,29 @@ class NewMaxPoolingOp : public torch::autograd::Function<NewMaxPoolingOp> {
68
69
ctx->saved_data [" dilation" ] = dilation;
69
70
ctx->saved_data [" ceil_mode" ] = ceil_mode;
70
71
71
- if (torch_ipex::check_auto_dnnl () && input.device ().type () == c10::DeviceType::DPCPP) {
72
- at::Tensor output = torch_ipex::cpu::AtenIpexCPUDev::dil_max_pooling (input.is_contiguous () ? input : input.contiguous (), kernel_size, stride,
73
- padding, dilation, ceil_mode);
74
- ctx->save_for_backward ({input, output});
75
- return output;
72
+ try {
73
+ if (torch_ipex::check_auto_dnnl () && input.device ().type () == c10::DeviceType::DPCPP) {
74
+ at::Tensor output = torch_ipex::cpu::AtenIpexCPUDev::dil_max_pooling (input.is_contiguous () ? input : input.contiguous (), kernel_size, stride,
75
+ padding, dilation, ceil_mode);
76
+ ctx->save_for_backward ({input, output});
77
+ return output;
78
+ }
79
+ } catch (std::exception& e) {
80
+ #if defined(_DEBUG)
81
+ TORCH_WARN (e.what ());
82
+ #endif
83
+ }
84
+ at::Tensor output, indices;
85
+ if (input.device ().type () == c10::DeviceType::DPCPP) {
86
+ auto && _ipex_input = torch_ipex::bridge::shallowFallbackToCPUTensor (input);
87
+ auto && _ipex_result = at::max_pool2d_with_indices (_ipex_input, kernel_size, stride, padding, dilation, ceil_mode);
88
+ static_cast <void >(_ipex_result);
89
+ std::tie (output, indices) = std::tuple<at::Tensor,at::Tensor>(torch_ipex::bridge::shallowUpgradeToDPCPPTensor (std::get<0 >(_ipex_result)), torch_ipex::bridge::shallowUpgradeToDPCPPTensor (std::get<1 >(_ipex_result)));
76
90
} else {
77
- at::Tensor output, indices;
78
91
std::tie (output, indices) = at::max_pool2d_with_indices (input, kernel_size, stride, padding, dilation, ceil_mode);
79
- ctx->save_for_backward ({input, indices});
80
- return output;
81
92
}
93
+ ctx->save_for_backward ({input, indices});
94
+ return output;
82
95
}
83
96
84
97
static torch::autograd::tensor_list backward (
@@ -97,9 +110,26 @@ class NewMaxPoolingOp : public torch::autograd::Function<NewMaxPoolingOp> {
97
110
std::vector<int64_t > dilation = ctx->saved_data [" dilation" ].toIntVector ();
98
111
bool ceil_mode = ctx->saved_data [" ceil_mode" ].toBool ();
99
112
100
- if (torch_ipex::check_auto_dnnl () && input.device ().type () == c10::DeviceType::DPCPP) {
101
- grad_input = torch_ipex::cpu::AtenIpexCPUDev::dil_max_pooling_backward (
102
- grad_output.is_contiguous () ? grad_output : grad_output.contiguous (), indices.is_contiguous () ? indices : indices.contiguous (), input.is_contiguous () ? input : input.contiguous (), kernel_size, stride, padding, dilation, ceil_mode);
113
+
114
+ try {
115
+ if (torch_ipex::check_auto_dnnl () && input.device ().type () == c10::DeviceType::DPCPP) {
116
+ grad_input = torch_ipex::cpu::AtenIpexCPUDev::dil_max_pooling_backward (
117
+ grad_output.is_contiguous () ? grad_output : grad_output.contiguous (), indices.is_contiguous () ? indices : indices.contiguous (), input.is_contiguous () ? input : input.contiguous (), kernel_size, stride, padding, dilation, ceil_mode);
118
+ return {grad_input, at::Tensor (), at::Tensor (), at::Tensor (), at::Tensor (), at::Tensor ()};
119
+ }
120
+ } catch (std::exception& e) {
121
+ #if defined(_DEBUG)
122
+ TORCH_WARN (e.what ());
123
+ #endif
124
+ }
125
+ if (input.device ().type () == c10::DeviceType::DPCPP) {
126
+ auto && _ipex_grad_output = torch_ipex::bridge::shallowFallbackToCPUTensor (grad_output);
127
+ auto && _ipex_input = torch_ipex::bridge::shallowFallbackToCPUTensor (input);
128
+ auto && _ipex_indices = torch_ipex::bridge::shallowFallbackToCPUTensor (indices);
129
+ auto && _ipex_grad_input = at::max_pool2d_with_indices_backward (_ipex_grad_output, _ipex_input, kernel_size,
130
+ stride, padding, dilation, ceil_mode, _ipex_indices);
131
+ static_cast <void >(_ipex_grad_input);
132
+ grad_input = torch_ipex::bridge::shallowUpgradeToDPCPPTensor (_ipex_grad_input);
103
133
} else {
104
134
grad_input = at::max_pool2d_with_indices_backward (grad_output, input, kernel_size,
105
135
stride, padding, dilation, ceil_mode, indices);
@@ -116,13 +146,23 @@ class NewApaptiveAvgPoolingOp : public torch::autograd::Function<NewApaptiveAvgP
116
146
at::IntArrayRef output_size) {
117
147
ctx->save_for_backward ({input});
118
148
119
- at::Tensor output;
120
- if (torch_ipex::check_auto_dnnl () && input.device ().type () == c10::DeviceType::DPCPP) {
121
- output = torch_ipex::cpu::AtenIpexCPUDev::dil_adaptive_avg_pool2d (input.is_contiguous () ? input : input.contiguous (), output_size);
149
+ try {
150
+ if (torch_ipex::check_auto_dnnl () && input.device ().type () == c10::DeviceType::DPCPP) {
151
+ return torch_ipex::cpu::AtenIpexCPUDev::dil_adaptive_avg_pool2d (input.is_contiguous () ? input : input.contiguous (), output_size);
152
+ }
153
+ } catch (std::exception& e) {
154
+ #if defined(_DEBUG)
155
+ TORCH_WARN (e.what ());
156
+ #endif
157
+ }
158
+ if (input.device ().type () == c10::DeviceType::DPCPP) {
159
+ auto && _ipex_input = torch_ipex::bridge::shallowFallbackToCPUTensor (input);
160
+ auto && _ipex_result = at::_adaptive_avg_pool2d (_ipex_input, output_size);
161
+ static_cast <void >(_ipex_result); // Avoid warnings in case not used
162
+ return torch_ipex::bridge::shallowUpgradeToDPCPPTensor (_ipex_result);
122
163
} else {
123
- output = at::_adaptive_avg_pool2d (input, output_size);
164
+ return at::_adaptive_avg_pool2d (input, output_size);
124
165
}
125
- return output;
126
166
}
127
167
128
168
static torch::autograd::tensor_list backward (
@@ -134,8 +174,22 @@ class NewApaptiveAvgPoolingOp : public torch::autograd::Function<NewApaptiveAvgP
134
174
at::Tensor grad_output = grad_outputs[0 ];
135
175
at::Tensor grad_input;
136
176
137
- if (torch_ipex::check_auto_dnnl () && input.device ().type () == c10::DeviceType::DPCPP) {
138
- grad_input = torch_ipex::cpu::AtenIpexCPUDev::dil_adaptive_avg_pool2d_backward (grad_output.is_contiguous () ? grad_output : grad_output.contiguous (), input.is_contiguous () ? input : input.contiguous ());
177
+ try {
178
+ if (torch_ipex::check_auto_dnnl () && input.device ().type () == c10::DeviceType::DPCPP) {
179
+ grad_input = torch_ipex::cpu::AtenIpexCPUDev::dil_adaptive_avg_pool2d_backward (grad_output.is_contiguous () ? grad_output : grad_output.contiguous (), input.is_contiguous () ? input : input.contiguous ());
180
+ return {grad_input, at::Tensor ()};
181
+ }
182
+ } catch (std::exception& e) {
183
+ #if defined(_DEBUG)
184
+ TORCH_WARN (e.what ());
185
+ #endif
186
+ }
187
+ if (input.device ().type () == c10::DeviceType::DPCPP) {
188
+ auto && _ipex_grad_output = torch_ipex::bridge::shallowFallbackToCPUTensor (grad_output);
189
+ auto && _ipex_input = torch_ipex::bridge::shallowFallbackToCPUTensor (input);
190
+ auto && _ipex_result = at::_adaptive_avg_pool2d_backward (_ipex_grad_output, _ipex_input);
191
+ static_cast <void >(_ipex_result); // Avoid warnings in case not used
192
+ grad_input = torch_ipex::bridge::shallowUpgradeToDPCPPTensor (_ipex_result);
139
193
} else {
140
194
grad_input = at::_adaptive_avg_pool2d_backward (grad_output, input);
141
195
}
0 commit comments