Commit 52a39fd

fix max_pool2d backward floating point exception issue
1 parent f2253a8 commit 52a39fd

File tree

3 files changed (+12, −12 lines)


torch_ipex/csrc/cpu/CustomerOps.h renamed to torch_ipex/csrc/cpu/CustomOPs.h

Lines changed: 10 additions & 10 deletions
@@ -9,8 +9,6 @@
 #include "torch_ipex/csrc/utils.h"
 #include "DevOPs.h"
 
-using namespace at;
-
 class NewLinearOp : public torch::autograd::Function<NewLinearOp> {
   public:
     static at::Tensor forward(
@@ -40,9 +38,9 @@ class NewLinearOp : public torch::autograd::Function<NewLinearOp> {
 
     if (torch_ipex::check_auto_dnnl() && input.device().type() == c10::DeviceType::DPCPP) {
       grad_input = torch_ipex::cpu::AtenIpexCPUDev::dil_linear_backward_input(
-        input.sizes(), grad_output, weight);
+        input.sizes(), grad_output.contiguous(), weight);
       std::tie(grad_weight, grad_bias) = torch_ipex::cpu::AtenIpexCPUDev::dil_linear_backward_weights(
-        grad_output, input, weight, bias.defined());
+        grad_output.contiguous(), input, weight, bias.defined());
     } else {
       auto grad_input = grad_output.mm(weight);
       auto grad_weight = grad_output.t().mm(input);
@@ -69,6 +67,7 @@ class NewMaxPoolingOp : public torch::autograd::Function<NewMaxPoolingOp> {
     ctx->saved_data["padding"] = padding;
     ctx->saved_data["dilation"] = dilation;
     ctx->saved_data["ceil_mode"] = ceil_mode;
+
     if (torch_ipex::check_auto_dnnl() && input.device().type() == c10::DeviceType::DPCPP) {
       at::Tensor output = torch_ipex::cpu::AtenIpexCPUDev::dil_max_pooling(input, kernel_size, stride,
         padding, dilation, ceil_mode);
@@ -89,12 +88,13 @@
     at::Tensor input = saved[0];
     at::Tensor indices = saved[1];
 
-    at::Tensor grad_output = grad_outputs[0];
+    at::Tensor grad_output = grad_outputs[0].contiguous();
     at::Tensor grad_input;
-    at::IntArrayRef kernel_size = at::IntArrayRef(ctx->saved_data["kernel_size"].toIntVector());
-    at::IntArrayRef stride = at::IntArrayRef(ctx->saved_data["stride"].toIntVector());
-    at::IntArrayRef padding = at::IntArrayRef(ctx->saved_data["padding"].toIntVector());
-    at::IntArrayRef dilation = at::IntArrayRef(ctx->saved_data["dilation"].toIntVector());
+
+    std::vector<int64_t> kernel_size = ctx->saved_data["kernel_size"].toIntVector();
+    std::vector<int64_t> stride = ctx->saved_data["stride"].toIntVector();
+    std::vector<int64_t> padding = ctx->saved_data["padding"].toIntVector();
+    std::vector<int64_t> dilation = ctx->saved_data["dilation"].toIntVector();
     bool ceil_mode = ctx->saved_data["ceil_mode"].toBool();
 
     if (torch_ipex::check_auto_dnnl() && input.device().type() == c10::DeviceType::DPCPP) {
@@ -104,6 +104,6 @@ class NewMaxPoolingOp : public torch::autograd::Function<NewMaxPoolingOp> {
       grad_input = at::max_pool2d_with_indices_backward(grad_output, input, kernel_size,
         stride, padding, dilation, ceil_mode, indices);
     }
-    return {grad_input};
+    return {grad_input, at::Tensor(), at::Tensor(), at::Tensor(), at::Tensor(), at::Tensor()};
   }
 };
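
Taken together, the changes in this file line up with three conventions for custom autograd ops in libtorch: make the incoming gradient contiguous before handing it to a DNNL-backed kernel, keep values recovered from saved_data alive in named locals rather than wrapping the temporary returned by toIntVector() in an at::IntArrayRef (the old code left that ref dangling, so the pooling kernel could read garbage sizes, which is a plausible source of the floating point exception, though the commit message does not say so explicitly), and return exactly one gradient slot per forward argument, with undefined tensors for the non-differentiable ones. A minimal sketch with a hypothetical ShiftOp (illustration only, not code from this repo) shows all three:

// Hypothetical ShiftOp: pads in forward, crops in backward.
#include <torch/torch.h>

struct ShiftOp : public torch::autograd::Function<ShiftOp> {
  static at::Tensor forward(torch::autograd::AutogradContext* ctx,
                            const at::Tensor& input,
                            at::IntArrayRef pad) {
    ctx->saved_data["pad"] = pad.vec();  // store the sizes by value
    return at::constant_pad_nd(input, pad, 0);
  }

  static torch::autograd::variable_list backward(
      torch::autograd::AutogradContext* ctx,
      torch::autograd::variable_list grad_outputs) {
    // A named local keeps the vector alive; wrapping the temporary from
    // toIntVector() in an at::IntArrayRef would dangle, as in the old code.
    std::vector<int64_t> pad = ctx->saved_data["pad"].toIntVector();
    // Dense layout for downstream kernels, mirroring the fix above.
    at::Tensor grad_output = grad_outputs[0].contiguous();
    std::vector<int64_t> neg(pad.size());
    for (size_t i = 0; i < pad.size(); ++i) neg[i] = -pad[i];
    at::Tensor grad_input = at::constant_pad_nd(grad_output, neg, 0);
    // One slot per forward argument: `input` gets a gradient, the
    // non-differentiable `pad` gets an undefined tensor.
    return {grad_input, at::Tensor()};
  }
};

Called as ShiftOp::apply(x, std::vector<int64_t>{1, 1, 1, 1}). The autograd engine checks that backward returns as many entries as forward took arguments, which is why NewMaxPoolingOp now returns six values instead of one.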

torch_ipex/csrc/cpu/ExtendOPs.cpp

Lines changed: 1 addition & 1 deletion
@@ -10,7 +10,7 @@
 #include "xsmm/libxsmm_utils.h"
 #include "../utils.h"
 #include "DevOPs.h"
-#include "CustomerOps.h"
+#include "CustomOps.h"
 
 namespace torch_ipex {
 
torch_ipex/csrc/cpu/FusionOPs.cpp

Lines changed: 1 addition & 1 deletion
@@ -52,7 +52,7 @@ at::Tensor AtenIpexJITDev::dil_convolution_relu(
       groups,
       dil::attr_t::fuse_relu());
 
-  return dbl::comm::gen_aten_tensor_by(dil_output);
+  return dbl::comm::gen_aten_tensor_by(std::move(dil_output));
 }
 
 static at::Tensor& dil_convolution_inplace_fusion(
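
This last change is unrelated to the crash fix: it avoids copying the freshly built dil_output when converting it to an ATen tensor. Below is a sketch of the pattern with stand-in types, since dil::tensor and dbl::comm::gen_aten_tensor_by are repo-internal; the only assumption is that the callee takes its argument by value (or rvalue reference) and can move-construct it.

#include <utility>
#include <vector>

struct Payload {           // stand-in for a tensor-like handle
  std::vector<float> data;
};

Payload wrap(Payload p) {  // by-value parameter: copy- or move-constructed
  return p;
}

Payload make() {
  Payload local{std::vector<float>(1024, 0.0f)};
  // `local` is not used again; std::move casts it to an rvalue so
  // wrap()'s parameter steals the buffer instead of copying it.
  return wrap(std::move(local));
}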
