
Commit 4e84df6

make rewritten AdaptiveAvgPool2d op traceable
1 parent: 52a39fd

4 files changed: 42 additions & 25 deletions

intel_pytorch_extension_py/ops/pooling.py

Lines changed: 2 additions & 23 deletions

@@ -7,24 +7,8 @@
 
 Vector = List[int]
 
-torch_adaptive_avg_pool2d = torch._C._nn.adaptive_avg_pool2d
-torch_max_pool2d = torch.max_pool2d
 torch_max_pool3d = torch.max_pool3d
 
-class AdaptiveAvgPool2dFunction(Function):
-    @staticmethod
-    def forward(ctx, input, output_size):
-        output = core.adaptive_avg_pool2d(input, _single(output_size))
-        ctx.save_for_backward(input)
-        return output
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        (input,) = ctx.saved_tensors
-        grad_output = grad_output.contiguous()
-        grad_input = core.adaptive_avg_pool2d_backward(grad_output, input)
-        return (grad_input, None)
-
 class MaxPoolingFunction(Function):
     @staticmethod
     def forward(ctx, input, kernel_size, stride, padding, dilation, ceil_mode):
@@ -44,13 +28,8 @@ def backward(ctx, grad_output):
         grad_input = core.max_pooling_backward(grad_output, output, input, ctx.kernel_size, ctx.stride, ctx.padding, ctx.dilation, ctx.ceil_mode)
         return (grad_input, None, None, None, None, None)
 
-def adaptive_avg_pool2d(input, output_size):
-    try:
-        if input.device.type == 'dpcpp' and core.get_auto_dnnl():
-            return AdaptiveAvgPool2dFunction.apply(input, output_size)
-    except RuntimeError:
-        pass
-    return torch_adaptive_avg_pool2d(input, output_size)
+def adaptive_avg_pool2d(input, output_size: Vector):
+    return torch.ops.torch_ipex.adaptive_avg_pool2d(input, _pair(output_size))
 
 def max_pool3d(input, kernel_size, stride, padding, dilation, ceil_mode):
     try:
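Why this change makes the op traceable: the old wrapper dispatched through AdaptiveAvgPool2dFunction.apply, a Python autograd.Function that torch.jit.trace cannot record as a graph node, while the new wrapper is a single call into the custom op registered in RegisterOps.cpp below. A minimal sketch, not part of the diff, assuming the torch_ipex extension library has already been loaded (normally done by importing the Python package):

# Sketch, not part of the commit: assumes the torch_ipex shared library is
# already loaded so that torch.ops.torch_ipex.adaptive_avg_pool2d resolves.
import torch

def pool(x):
    # The rewritten wrapper reduces to this single registered-op call,
    # which the tracer records as one torch_ipex::adaptive_avg_pool2d node.
    return torch.ops.torch_ipex.adaptive_avg_pool2d(x, [1, 1])

x = torch.randn(1, 64, 7, 7)            # illustrative input shape
traced = torch.jit.trace(pool, (x,))
print(traced.graph)                      # graph now shows a torch_ipex::adaptive_avg_pool2d call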

torch_ipex/csrc/cpu/CustomOPs.h

Lines changed: 35 additions & 0 deletions

@@ -107,3 +107,38 @@ class NewMaxPoolingOp : public torch::autograd::Function<NewMaxPoolingOp> {
     return {grad_input, at::Tensor(), at::Tensor(), at::Tensor(), at::Tensor(), at::Tensor()};
   }
 };
+
+class NewApaptiveAvgPoolingOp : public torch::autograd::Function<NewApaptiveAvgPoolingOp> {
+ public:
+  static at::Tensor forward(
+      torch::autograd::AutogradContext* ctx,
+      at::Tensor input,
+      at::IntArrayRef output_size) {
+    ctx->save_for_backward({input});
+
+    at::Tensor output;
+    if (torch_ipex::check_auto_dnnl() && input.device().type() == c10::DeviceType::DPCPP) {
+      output = torch_ipex::cpu::AtenIpexCPUDev::dil_adaptive_avg_pool2d(input, output_size);
+    } else {
+      output = at::_adaptive_avg_pool2d(input, output_size);
+    }
+    return output;
+  }
+
+  static torch::autograd::tensor_list backward(
+      torch::autograd::AutogradContext* ctx,
+      torch::autograd::tensor_list grad_outputs) {
+    auto saved = ctx->get_saved_variables();
+    at::Tensor input = saved[0];
+
+    at::Tensor grad_output = grad_outputs[0].contiguous();
+    at::Tensor grad_input;
+
+    if (torch_ipex::check_auto_dnnl() && input.device().type() == c10::DeviceType::DPCPP) {
+      grad_input = torch_ipex::cpu::AtenIpexCPUDev::dil_adaptive_avg_pool2d_backward(grad_output, input);
+    } else {
+      grad_input = at::_adaptive_avg_pool2d_backward(grad_output, input);
+    }
+    return {grad_input, at::Tensor()};
+  }
+};
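The new autograd Function mirrors the existing NewMaxPoolingOp: the DNNL kernel is used only when auto-dnnl is enabled and the tensor lives on the DPCPP device; otherwise it falls back to the stock ATen kernels. A hedged sanity-check sketch, not part of the diff (assumes the extension library is loaded; on a plain CPU tensor the fallback branches run):

# Sketch, not part of the commit. On a plain CPU tensor the DPCPP/auto-dnnl
# condition is false, so forward and backward take the ATen fallback branches.
import torch
import torch.nn.functional as F

x = torch.randn(2, 3, 8, 8, requires_grad=True)
y = torch.ops.torch_ipex.adaptive_avg_pool2d(x, [2, 2])
print(torch.allclose(y, F.adaptive_avg_pool2d(x, (2, 2))))  # True: same math as the stock op

y.sum().backward()            # routed through NewApaptiveAvgPoolingOp::backward
print(x.grad.shape)           # torch.Size([2, 3, 8, 8])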

torch_ipex/csrc/cpu/ExtendOPs.cpp

Lines changed: 2 additions & 2 deletions

@@ -10,7 +10,7 @@
 #include "xsmm/libxsmm_utils.h"
 #include "../utils.h"
 #include "DevOPs.h"
-#include "CustomOps.h"
+#include "CustomOPs.h"
 
 namespace torch_ipex {
 
@@ -460,7 +460,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> AtenIpexTypeExt::linear_backward(
 }
 
 at::Tensor AtenIpexTypeExt::adaptive_avg_pool2d(at::Tensor const& input, at::IntArrayRef output_size) {
-  return cpu::AtenIpexCPUDev::dil_adaptive_avg_pool2d(input, output_size);
+  return NewApaptiveAvgPoolingOp::apply(input, output_size);
 }
 
 at::Tensor AtenIpexTypeExt::adaptive_avg_pool2d_backward(const at::Tensor& grad_output, const at::Tensor& input) {

torch_ipex/csrc/cpu/RegisterOps.cpp

Lines changed: 3 additions & 0 deletions

@@ -7,5 +7,8 @@ static auto registry =
     .op("torch_ipex::max_pool2d", [](const at::Tensor& self, c10::List<int64_t> kernel_size,
         c10::List<int64_t> stride, c10::List<int64_t> padding, c10::List<int64_t> dilation, bool ceil_mode=false){
       return torch_ipex::AtenIpexTypeExt::max_pooling(self, kernel_size.vec(), stride.vec(), padding.vec(), dilation.vec(), ceil_mode);
+    })
+    .op("torch_ipex::adaptive_avg_pool2d", [](const at::Tensor&self, c10::List<int64_t> output_size) {
+      return torch_ipex::AtenIpexTypeExt::adaptive_avg_pool2d(self, output_size.vec());
     });
 
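The lambda above is registered with an int[] output_size argument, which is why the Python wrapper in pooling.py passes _pair(output_size): callers can keep handing in a single int, as with nn.AdaptiveAvgPool2d. A small illustrative sketch, not part of the diff (the library path is hypothetical; in practice the shared object is loaded when the extension package is imported):

# Sketch, not part of the commit; the .so path below is illustrative only.
import torch
from torch.nn.modules.utils import _pair

torch.ops.load_library("libtorch_ipex.so")   # hypothetical path; normally done by the package import

x = torch.randn(1, 3, 14, 14)
out = torch.ops.torch_ipex.adaptive_avg_pool2d(x, _pair(7))   # _pair(7) -> (7, 7)
print(out.shape)                                              # torch.Size([1, 3, 7, 7])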
