
Commit f2253a8

Make the rewritten max_pool2d op traceable
1 parent 3fd406b commit f2253a8

4 files changed, +67 −15 lines changed

intel_pytorch_extension_py/ops/pooling.py

Lines changed: 8 additions & 10 deletions
@@ -2,7 +2,10 @@
 from torch.autograd import Function
 import torch.nn.functional as F
 import _torch_ipex as core
-from torch.nn.modules.utils import _single
+from torch.nn.modules.utils import _single, _pair
+from typing import List
+
+Vector = List[int]
 
 torch_adaptive_avg_pool2d = torch._C._nn.adaptive_avg_pool2d
 torch_max_pool2d = torch.max_pool2d
@@ -49,14 +52,6 @@ def adaptive_avg_pool2d(input, output_size):
         pass
     return torch_adaptive_avg_pool2d(input, output_size)
 
-def max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode):
-    try:
-        if input.device.type == 'dpcpp' and core.get_auto_dnnl():
-            return MaxPoolingFunction.apply(input, kernel_size, stride, padding, dilation, ceil_mode)
-    except RuntimeError:
-        pass
-    return torch_max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)
-
 def max_pool3d(input, kernel_size, stride, padding, dilation, ceil_mode):
     try:
         if input.device.type == 'dpcpp' and core.get_auto_dnnl():
@@ -65,6 +60,9 @@ def max_pool3d(input, kernel_size, stride, padding, dilation, ceil_mode):
         pass
     return torch_max_pool3d(input, kernel_size, stride, padding, dilation, ceil_mode)
 
+def max_pool2d(input, kernel_size: Vector, stride: Vector, padding: Vector, dilation: Vector, ceil_mode: bool):
+    return torch.ops.torch_ipex.max_pool2d(input, _pair(kernel_size), _pair(stride), _pair(padding), _pair(dilation), ceil_mode)
+
 torch._C._nn.adaptive_avg_pool2d = adaptive_avg_pool2d
 torch.max_pool2d = max_pool2d
-torch.max_pool3d = max_pool3d
+torch.max_pool3d = max_pool3d
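
The Python-side rebinding above is what the commit title refers to: torch.max_pool2d now forwards to the registered custom op torch.ops.torch_ipex.max_pool2d instead of wrapping a torch.autograd.Function, so torch.jit.trace can record it as a graph node. A minimal sketch of how this could be exercised, assuming the extension is importable as intel_pytorch_extension (the import name and the tiny module below are illustrative, not part of this commit):

import torch
import intel_pytorch_extension  # assumed import name; applies the torch.max_pool2d rebinding above

class TinyNet(torch.nn.Module):
    def forward(self, x):
        # torch.max_pool2d is now the wrapper from pooling.py, which calls
        # torch.ops.torch_ipex.max_pool2d with _pair()-normalized arguments
        return torch.max_pool2d(x, kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)

x = torch.randn(1, 3, 32, 32)
traced = torch.jit.trace(TinyNet(), x)
print(traced.graph)  # the traced graph should now contain a torch_ipex::max_pool2d node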

torch_ipex/csrc/cpu/CustomerOps.h

Lines changed: 54 additions & 0 deletions
@@ -53,3 +53,57 @@ class NewLinearOp : public torch::autograd::Function<NewLinearOp> {
     return {grad_input, grad_weight, grad_bias};
   }
 };
+
+class NewMaxPoolingOp : public torch::autograd::Function<NewMaxPoolingOp> {
+  public:
+    static at::Tensor forward(
+        torch::autograd::AutogradContext* ctx,
+        at::Tensor input,
+        at::IntArrayRef kernel_size,
+        at::IntArrayRef stride,
+        at::IntArrayRef padding,
+        at::IntArrayRef dilation,
+        bool ceil_mode) {
+      ctx->saved_data["kernel_size"] = kernel_size;
+      ctx->saved_data["stride"] = stride;
+      ctx->saved_data["padding"] = padding;
+      ctx->saved_data["dilation"] = dilation;
+      ctx->saved_data["ceil_mode"] = ceil_mode;
+      if (torch_ipex::check_auto_dnnl() && input.device().type() == c10::DeviceType::DPCPP) {
+        at::Tensor output = torch_ipex::cpu::AtenIpexCPUDev::dil_max_pooling(input, kernel_size, stride,
+            padding, dilation, ceil_mode);
+        ctx->save_for_backward({input, output});
+        return output;
+      } else {
+        at::Tensor output, indices;
+        std::tie(output, indices) = at::max_pool2d_with_indices(input, kernel_size, stride, padding, dilation, ceil_mode);
+        ctx->save_for_backward({input, indices});
+        return output;
+      }
+    }
+
+    static torch::autograd::tensor_list backward(
+        torch::autograd::AutogradContext* ctx,
+        torch::autograd::tensor_list grad_outputs) {
+      auto saved = ctx->get_saved_variables();
+      at::Tensor input = saved[0];
+      at::Tensor indices = saved[1];
+
+      at::Tensor grad_output = grad_outputs[0];
+      at::Tensor grad_input;
+      at::IntArrayRef kernel_size = at::IntArrayRef(ctx->saved_data["kernel_size"].toIntVector());
+      at::IntArrayRef stride = at::IntArrayRef(ctx->saved_data["stride"].toIntVector());
+      at::IntArrayRef padding = at::IntArrayRef(ctx->saved_data["padding"].toIntVector());
+      at::IntArrayRef dilation = at::IntArrayRef(ctx->saved_data["dilation"].toIntVector());
+      bool ceil_mode = ctx->saved_data["ceil_mode"].toBool();
+
+      if (torch_ipex::check_auto_dnnl() && input.device().type() == c10::DeviceType::DPCPP) {
+        grad_input = torch_ipex::cpu::AtenIpexCPUDev::dil_max_pooling_backward(
+            grad_output, indices, input, kernel_size, stride, padding, dilation, ceil_mode);
+      } else {
+        grad_input = at::max_pool2d_with_indices_backward(grad_output, input, kernel_size,
+            stride, padding, dilation, ceil_mode, indices);
+      }
+      return {grad_input};
+    }
+};
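
NewMaxPoolingOp stores the pooling hyperparameters in saved_data and chooses between the DNNL kernels and the stock at::max_pool2d_with_indices path, so backward replays whichever branch ran in forward. A rough way to check from Python that gradients flow through the new op, again assuming the intel_pytorch_extension import (shapes are arbitrary):

import torch
import intel_pytorch_extension  # assumed import name; registers torch_ipex::max_pool2d

x = torch.randn(1, 3, 8, 8, requires_grad=True)
# Arguments mirror the registered schema: Tensor, four int[] lists, bool
y = torch.ops.torch_ipex.max_pool2d(x, [2, 2], [2, 2], [0, 0], [1, 1], False)
y.sum().backward()  # exercises NewMaxPoolingOp::backward (DNNL path or ATen fallback)
print(x.grad.shape)  # expected: torch.Size([1, 3, 8, 8])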

torch_ipex/csrc/cpu/ExtendOPs.cpp

Lines changed: 1 addition & 1 deletion
@@ -468,7 +468,7 @@ at::Tensor AtenIpexTypeExt::adaptive_avg_pool2d_backward(const at::Tensor& grad_
 }
 
 at::Tensor AtenIpexTypeExt::max_pooling(const at::Tensor& input, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode) {
-  return cpu::AtenIpexCPUDev::dil_max_pooling(input, kernel_size, stride, padding, dilation, ceil_mode);
+  return NewMaxPoolingOp::apply(input, kernel_size, stride, padding, dilation, ceil_mode);
 }
 
 at::Tensor AtenIpexTypeExt::max_pooling_backward(const at::Tensor& grad_output, const at::Tensor& output, const at::Tensor& input, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode) {

torch_ipex/csrc/cpu/RegisterOps.cpp

Lines changed: 4 additions & 4 deletions
@@ -3,9 +3,9 @@
 
 static auto registry =
     torch::RegisterOperators()
-        .op("torch_ipex::linear",
-            [](const at::Tensor& input, const at::Tensor& weight, const at::Tensor& bias) {
-              return torch_ipex::AtenIpexTypeExt::linear(input, weight, bias);
+        .op("torch_ipex::linear", &torch_ipex::AtenIpexTypeExt::linear)
+        .op("torch_ipex::max_pool2d", [](const at::Tensor& self, c10::List<int64_t> kernel_size,
+            c10::List<int64_t> stride, c10::List<int64_t> padding, c10::List<int64_t> dilation, bool ceil_mode=false){
+          return torch_ipex::AtenIpexTypeExt::max_pooling(self, kernel_size.vec(), stride.vec(), padding.vec(), dilation.vec(), ceil_mode);
         });
 
-
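
The lambda registered here lets the operator registry infer a schema taking a Tensor, four int[] arguments, and a bool, with the c10::List<int64_t> parameters accepting plain Python lists. A quick sanity check of the registration from Python, under the same intel_pytorch_extension import assumption (output shape follows from 2x2 pooling with stride 2 on a 4x4 input):

import torch
import intel_pytorch_extension  # assumed import name; runs the registration in RegisterOps.cpp

out = torch.ops.torch_ipex.max_pool2d(torch.randn(1, 1, 4, 4), [2, 2], [2, 2], [0, 0], [1, 1], False)
print(out.shape)  # expected: torch.Size([1, 1, 2, 2])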
