Skip to content

Commit 79953a5

Browse files
authored
enable the fallback to cpu for adaptive_avg_pool2d and max_pool2d (#84)
1. move "bn folding" and "prepack conv weight" to the hooked jit script function 2. add check on fused node in the graph
1 parent 1d432c2 commit 79953a5

File tree

2 files changed

+167
-18
lines changed

2 files changed

+167
-18
lines changed

tests/cpu/test_lazy_reorder.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -641,6 +641,48 @@ def test_adaptive_avg_pool2d_backward(self):
641641
y_cpu.backward()
642642
y_dpcpp.backward()
643643
self.assertEqual(x_cpu.grad, x_dpcpp.grad)
644+
645+
def test_adaptive_avg_pool2d_not_divisible(self):
646+
ipex.enable_auto_dnnl()
647+
rand_seed = int(get_rand_seed())
648+
print("{} rand sed: {}".format(sys._getframe().f_code.co_name, rand_seed))
649+
torch.manual_seed(rand_seed)
650+
N = torch.randint(3, 10, (1,)).item()
651+
C = torch.randint(3, 10, (1,)).item()
652+
x_cpu = torch.randn(N, C, 224, 224, dtype=torch.float32) * 100
653+
x_dpcpp = x_cpu.to(device=device)
654+
# test the fallback to cpu when the input size is not divisible by the output size
655+
adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d(6)
656+
657+
y_cpu = adaptive_avg_pool2d(x_cpu)
658+
y_dpcpp = adaptive_avg_pool2d(x_dpcpp)
659+
660+
self.assertEqual(
661+
y_cpu,
662+
y_dpcpp)
663+
664+
self.assertEqual(device, y_dpcpp.device.type)
665+
666+
def test_adaptive_avg_pool2d_backward_not_divisible(self):
667+
ipex.enable_auto_dnnl()
668+
rand_seed = int(get_rand_seed())
669+
print("{} rand sed: {}".format(sys._getframe().f_code.co_name, rand_seed))
670+
torch.manual_seed(rand_seed)
671+
x = torch.randn(10, 3, 224, 224, dtype=torch.float32) * 100
672+
673+
x_cpu = x.clone().requires_grad_()
674+
x_dpcpp = x.clone().to(device=device).requires_grad_()
675+
# test the fallback to cpu when the input size is not divisible by the output size
676+
adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d(6)
677+
678+
y_cpu = adaptive_avg_pool2d(x_cpu).sum()
679+
y_dpcpp = adaptive_avg_pool2d(x_dpcpp).sum()
680+
y_cpu.backward()
681+
y_dpcpp.backward()
682+
self.assertEqual(x_cpu.grad, x_dpcpp.grad)
683+
684+
self.assertEqual(device, x_dpcpp.grad.device.type)
685+
self.assertEqual(device, y_dpcpp.device.type)
644686

645687
def test_max_pool2d(self):
646688
ipex.core.enable_auto_dnnl()
@@ -663,6 +705,33 @@ def test_max_pool2d(self):
663705
ceil_mode=ceil_mode)
664706

665707
self.assertEqual(max_pool2d(x_cpu), max_pool2d(x_dpcpp))
708+
709+
def test_max_pool2d_double(self):
710+
ipex.enable_auto_dnnl()
711+
rand_seed = int(get_rand_seed())
712+
print("{} rand sed: {}".format(sys._getframe().f_code.co_name, rand_seed))
713+
torch.manual_seed(rand_seed)
714+
N = torch.randint(3, 10, (1,)).item()
715+
C = torch.randint(3, 10, (1,)).item()
716+
717+
for stride in [1, 2, 3]:
718+
for H, W in [(64, 64), (35, 39), (16, 19), [7, 8]]:
719+
# test the fallback to cpu when the input is double
720+
x_cpu = torch.randn(N, C, H, W, dtype=torch.double) * 10
721+
x_dpcpp = x_cpu.to(device=device)
722+
723+
for ceil_mode in [False, True]:
724+
max_pool2d = torch.nn.MaxPool2d(
725+
kernel_size=3 if not ceil_mode else 7,
726+
stride=stride,
727+
padding=1,
728+
ceil_mode=ceil_mode)
729+
730+
y_cpu = max_pool2d(x_cpu)
731+
y_dpcpp = max_pool2d(x_dpcpp)
732+
self.assertEqual(y_cpu, y_dpcpp)
733+
734+
self.assertEqual(device, y_dpcpp.device.type)
666735

667736
def test_max_pool3d(self):
668737
ipex.core.enable_auto_dnnl()
@@ -707,6 +776,32 @@ def test_max_pool2d_backward(self):
707776
y1.backward()
708777
y2.backward()
709778
self.assertEqual(x1.grad, x2.grad)
779+
780+
def test_max_pool2d_backward_double(self):
781+
ipex.enable_auto_dnnl()
782+
rand_seed = int(get_rand_seed())
783+
print("{} rand sed: {}".format(sys._getframe().f_code.co_name, rand_seed))
784+
torch.manual_seed(rand_seed)
785+
# test the fallback to cpu when the input is double
786+
x = torch.randn(10, 3, 64, 64, dtype=torch.double) * 10
787+
for ceil_mode in [True]:
788+
max_pool2d = torch.nn.MaxPool2d(
789+
kernel_size=3,
790+
stride=2,
791+
padding=1,
792+
ceil_mode=ceil_mode)
793+
794+
x1 = x.clone().requires_grad_()
795+
x2 = x.clone().to(device=device).requires_grad_()
796+
797+
y1 = max_pool2d(x1).sum()
798+
y2 = max_pool2d(x2).sum()
799+
y1.backward()
800+
y2.backward()
801+
self.assertEqual(x1.grad, x2.grad)
802+
803+
self.assertEqual(device, x2.grad.device.type)
804+
self.assertEqual(device, y2.device.type)
710805

711806
def test_max_pool3d_backward(self):
712807
ipex.core.enable_auto_dnnl()

torch_ipex/csrc/cpu/CustomOPs.h

Lines changed: 72 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include <ATen/Tensor.h>
77
#include <torch/script.h>
88
#include <c10/util/Optional.h>
9+
#include "torch_ipex/csrc/aten_ipex_bridge.h"
910
#include "torch_ipex/csrc/utils.h"
1011
#include "DevOPs.h"
1112

@@ -68,17 +69,29 @@ class NewMaxPoolingOp : public torch::autograd::Function<NewMaxPoolingOp> {
6869
ctx->saved_data["dilation"] = dilation;
6970
ctx->saved_data["ceil_mode"] = ceil_mode;
7071

71-
if (torch_ipex::check_auto_dnnl() && input.device().type() == c10::DeviceType::DPCPP) {
72-
at::Tensor output = torch_ipex::cpu::AtenIpexCPUDev::dil_max_pooling(input.is_contiguous() ? input : input.contiguous(), kernel_size, stride,
73-
padding, dilation, ceil_mode);
74-
ctx->save_for_backward({input, output});
75-
return output;
72+
try {
73+
if (torch_ipex::check_auto_dnnl() && input.device().type() == c10::DeviceType::DPCPP) {
74+
at::Tensor output = torch_ipex::cpu::AtenIpexCPUDev::dil_max_pooling(input.is_contiguous() ? input : input.contiguous(), kernel_size, stride,
75+
padding, dilation, ceil_mode);
76+
ctx->save_for_backward({input, output});
77+
return output;
78+
}
79+
} catch (std::exception& e) {
80+
#if defined(_DEBUG)
81+
TORCH_WARN(e.what());
82+
#endif
83+
}
84+
at::Tensor output, indices;
85+
if (input.device().type() == c10::DeviceType::DPCPP) {
86+
auto&& _ipex_input = torch_ipex::bridge::shallowFallbackToCPUTensor(input);
87+
auto&& _ipex_result = at::max_pool2d_with_indices(_ipex_input, kernel_size, stride, padding, dilation, ceil_mode);
88+
static_cast<void>(_ipex_result);
89+
std::tie(output, indices) = std::tuple<at::Tensor,at::Tensor>(torch_ipex::bridge::shallowUpgradeToDPCPPTensor(std::get<0>(_ipex_result)), torch_ipex::bridge::shallowUpgradeToDPCPPTensor(std::get<1>(_ipex_result)));
7690
} else {
77-
at::Tensor output, indices;
7891
std::tie(output, indices) = at::max_pool2d_with_indices(input, kernel_size, stride, padding, dilation, ceil_mode);
79-
ctx->save_for_backward({input, indices});
80-
return output;
8192
}
93+
ctx->save_for_backward({input, indices});
94+
return output;
8295
}
8396

8497
static torch::autograd::tensor_list backward(
@@ -97,9 +110,26 @@ class NewMaxPoolingOp : public torch::autograd::Function<NewMaxPoolingOp> {
97110
std::vector<int64_t> dilation = ctx->saved_data["dilation"].toIntVector();
98111
bool ceil_mode = ctx->saved_data["ceil_mode"].toBool();
99112

100-
if (torch_ipex::check_auto_dnnl() && input.device().type() == c10::DeviceType::DPCPP) {
101-
grad_input = torch_ipex::cpu::AtenIpexCPUDev::dil_max_pooling_backward(
102-
grad_output.is_contiguous() ? grad_output : grad_output.contiguous(), indices.is_contiguous() ? indices : indices.contiguous(), input.is_contiguous() ? input : input.contiguous(), kernel_size, stride, padding, dilation, ceil_mode);
113+
114+
try {
115+
if (torch_ipex::check_auto_dnnl() && input.device().type() == c10::DeviceType::DPCPP) {
116+
grad_input = torch_ipex::cpu::AtenIpexCPUDev::dil_max_pooling_backward(
117+
grad_output.is_contiguous() ? grad_output : grad_output.contiguous(), indices.is_contiguous() ? indices : indices.contiguous(), input.is_contiguous() ? input : input.contiguous(), kernel_size, stride, padding, dilation, ceil_mode);
118+
return {grad_input, at::Tensor(), at::Tensor(), at::Tensor(), at::Tensor(), at::Tensor()};
119+
}
120+
} catch (std::exception& e) {
121+
#if defined(_DEBUG)
122+
TORCH_WARN(e.what());
123+
#endif
124+
}
125+
if (input.device().type() == c10::DeviceType::DPCPP) {
126+
auto&& _ipex_grad_output = torch_ipex::bridge::shallowFallbackToCPUTensor(grad_output);
127+
auto&& _ipex_input = torch_ipex::bridge::shallowFallbackToCPUTensor(input);
128+
auto&& _ipex_indices = torch_ipex::bridge::shallowFallbackToCPUTensor(indices);
129+
auto&& _ipex_grad_input = at::max_pool2d_with_indices_backward(_ipex_grad_output, _ipex_input, kernel_size,
130+
stride, padding, dilation, ceil_mode, _ipex_indices);
131+
static_cast<void>(_ipex_grad_input);
132+
grad_input = torch_ipex::bridge::shallowUpgradeToDPCPPTensor(_ipex_grad_input);
103133
} else {
104134
grad_input = at::max_pool2d_with_indices_backward(grad_output, input, kernel_size,
105135
stride, padding, dilation, ceil_mode, indices);
@@ -116,13 +146,23 @@ class NewApaptiveAvgPoolingOp : public torch::autograd::Function<NewApaptiveAvgP
116146
at::IntArrayRef output_size) {
117147
ctx->save_for_backward({input});
118148

119-
at::Tensor output;
120-
if (torch_ipex::check_auto_dnnl() && input.device().type() == c10::DeviceType::DPCPP) {
121-
output = torch_ipex::cpu::AtenIpexCPUDev::dil_adaptive_avg_pool2d(input.is_contiguous() ? input : input.contiguous(), output_size);
149+
try{
150+
if (torch_ipex::check_auto_dnnl() && input.device().type() == c10::DeviceType::DPCPP) {
151+
return torch_ipex::cpu::AtenIpexCPUDev::dil_adaptive_avg_pool2d(input.is_contiguous() ? input : input.contiguous(), output_size);
152+
}
153+
} catch (std::exception& e) {
154+
#if defined(_DEBUG)
155+
TORCH_WARN(e.what());
156+
#endif
157+
}
158+
if (input.device().type() == c10::DeviceType::DPCPP) {
159+
auto&& _ipex_input = torch_ipex::bridge::shallowFallbackToCPUTensor(input);
160+
auto&& _ipex_result = at::_adaptive_avg_pool2d(_ipex_input, output_size);
161+
static_cast<void>(_ipex_result); // Avoid warnings in case not used
162+
return torch_ipex::bridge::shallowUpgradeToDPCPPTensor(_ipex_result);
122163
} else {
123-
output = at::_adaptive_avg_pool2d(input, output_size);
164+
return at::_adaptive_avg_pool2d(input, output_size);
124165
}
125-
return output;
126166
}
127167

128168
static torch::autograd::tensor_list backward(
@@ -134,8 +174,22 @@ class NewApaptiveAvgPoolingOp : public torch::autograd::Function<NewApaptiveAvgP
134174
at::Tensor grad_output = grad_outputs[0];
135175
at::Tensor grad_input;
136176

137-
if (torch_ipex::check_auto_dnnl() && input.device().type() == c10::DeviceType::DPCPP) {
138-
grad_input = torch_ipex::cpu::AtenIpexCPUDev::dil_adaptive_avg_pool2d_backward(grad_output.is_contiguous() ? grad_output : grad_output.contiguous(), input.is_contiguous() ? input : input.contiguous());
177+
try {
178+
if (torch_ipex::check_auto_dnnl() && input.device().type() == c10::DeviceType::DPCPP) {
179+
grad_input = torch_ipex::cpu::AtenIpexCPUDev::dil_adaptive_avg_pool2d_backward(grad_output.is_contiguous() ? grad_output : grad_output.contiguous(), input.is_contiguous() ? input : input.contiguous());
180+
return {grad_input, at::Tensor()};
181+
}
182+
} catch (std::exception& e) {
183+
#if defined(_DEBUG)
184+
TORCH_WARN(e.what());
185+
#endif
186+
}
187+
if (input.device().type() == c10::DeviceType::DPCPP) {
188+
auto&& _ipex_grad_output = torch_ipex::bridge::shallowFallbackToCPUTensor(grad_output);
189+
auto&& _ipex_input = torch_ipex::bridge::shallowFallbackToCPUTensor(input);
190+
auto&& _ipex_result = at::_adaptive_avg_pool2d_backward(_ipex_grad_output, _ipex_input);
191+
static_cast<void>(_ipex_result); // Avoid warnings in case not used
192+
grad_input = torch_ipex::bridge::shallowUpgradeToDPCPPTensor(_ipex_result);
139193
} else {
140194
grad_input = at::_adaptive_avg_pool2d_backward(grad_output, input);
141195
}

0 commit comments

Comments
 (0)