Skip to content

Commit b6176e5

Browse files
committed
enable DNNL Python ops (adaptive_avg_pool2d, max_pool2d, max_pool3d) to fall back to CPU.
1 parent de92da8 commit b6176e5

File tree

3 files changed

+33
-22
lines changed

3 files changed

+33
-22
lines changed

intel_pytorch_extension_py/ops/linear.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def backward(ctx, grad_output):
2424
return (grad_input, grad_weight, grad_bias)
2525

2626
def linear(input, weight, bias=None):
    """F.linear replacement: route to the DNNL kernel for dpcpp tensors.

    Uses LinearFunction (DNNL) only when the input lives on the 'dpcpp'
    device AND auto-DNNL is enabled; otherwise defers to the original
    torch.nn.functional.linear saved in F_linear.
    """
    use_dnnl = input.device.type == 'dpcpp' and core.get_auto_dnnl()
    if use_dnnl:
        return LinearFunction.apply(input, weight, bias)
    return F_linear(input, weight, bias)

intel_pytorch_extension_py/ops/pooling.py

Lines changed: 31 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,16 @@
33
import torch.nn.functional as F
44
import _torch_ipex as core
55

6-
F_adaptive_avg_pool2d = F.adaptive_avg_pool2d
6+
torch_adaptive_avg_pool2d = torch._C._nn.adaptive_avg_pool2d
77
torch_max_pool2d = torch.max_pool2d
88
torch_max_pool3d = torch.max_pool3d
99

1010
class AdaptiveAvgPool2dFunction(Function):
1111
@staticmethod
1212
def forward(ctx, input, output_size):
13-
_output_size = _list_with_default(output_size, input.size())
14-
output = core.adaptive_avg_pool2d(input, _output_size)
13+
if type(output_size) is int:
14+
output_size = (output_size,)
15+
output = core.adaptive_avg_pool2d(input, output_size)
1516
ctx.save_for_backward(input)
1617
return output
1718

@@ -25,7 +26,15 @@ def backward(ctx, grad_output):
2526
class MaxPoolingFunction(Function):
2627
@staticmethod
2728
def forward(ctx, input, kernel_size, stride, padding, dilation, ceil_mode):
28-
output = core.max_pooling(input, (kernel_size,), (stride,), (padding,), (dilation,), ceil_mode)
29+
if type(kernel_size) is int:
30+
kernel_size = (kernel_size,)
31+
if type(stride) is int:
32+
stride = (stride,)
33+
if type(padding) is int:
34+
padding = (padding,)
35+
if type(dilation) is int:
36+
dilation = (dilation,)
37+
output = core.max_pooling(input, kernel_size, stride, padding, dilation, ceil_mode)
2938
ctx.save_for_backward(output, input)
3039
ctx.kernel_size = kernel_size
3140
ctx.stride = stride
@@ -38,31 +47,33 @@ def forward(ctx, input, kernel_size, stride, padding, dilation, ceil_mode):
3847
def backward(ctx, grad_output):
3948
output, input= ctx.saved_tensors
4049
grad_output = grad_output.contiguous()
41-
grad_input = core.max_pooling_backward(grad_output, output, input, (ctx.kernel_size,), (ctx.stride,), (ctx.padding,), (ctx.dilation,), ctx.ceil_mode)
50+
grad_input = core.max_pooling_backward(grad_output, output, input, ctx.kernel_size, ctx.stride, ctx.padding, ctx.dilation, ctx.ceil_mode)
4251
return (grad_input, None, None, None, None, None)
4352

44-
def _list_with_default(out_size, defaults):
45-
if isinstance(out_size, int):
46-
return (out_size,)
47-
if len(defaults) <= len(out_size):
48-
raise ValueError('Input dimension should be at least {}'.format(len(out_size) + 1))
49-
return [v if v is not None else d for v, d in zip(out_size, defaults[-len(out_size):])]
50-
5153
def adaptive_avg_pool2d(input, output_size):
    """adaptive_avg_pool2d with DNNL fast path and CPU fallback.

    Tries the DNNL implementation for dpcpp tensors when auto-DNNL is on;
    any RuntimeError from that path (unsupported shape/config) falls back
    to the original torch kernel, as does the non-dpcpp case.
    """
    try:
        use_dnnl = input.device.type == 'dpcpp' and core.get_auto_dnnl()
        if use_dnnl:
            return AdaptiveAvgPool2dFunction.apply(input, output_size)
    except RuntimeError:
        # DNNL rejected this configuration — fall through to the native op.
        pass
    return torch_adaptive_avg_pool2d(input, output_size)
5560

5661
def max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode):
    """max_pool2d with DNNL fast path and fallback to torch.max_pool2d.

    The DNNL path is attempted only for dpcpp tensors with auto-DNNL
    enabled; a RuntimeError there (unsupported case) triggers the same
    native fallback as the plain CPU/CUDA case.
    """
    try:
        use_dnnl = input.device.type == 'dpcpp' and core.get_auto_dnnl()
        if use_dnnl:
            return MaxPoolingFunction.apply(
                input, kernel_size, stride, padding, dilation, ceil_mode)
    except RuntimeError:
        # DNNL rejected this configuration — fall through to the native op.
        pass
    return torch_max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)
6068

6169
def max_pool3d(input, kernel_size, stride, padding, dilation, ceil_mode):
    """max_pool3d with DNNL fast path and fallback to torch.max_pool3d.

    Mirrors max_pool2d: DNNL only for dpcpp tensors with auto-DNNL on,
    and any RuntimeError from the DNNL path falls back to the native op.
    """
    try:
        use_dnnl = input.device.type == 'dpcpp' and core.get_auto_dnnl()
        if use_dnnl:
            return MaxPoolingFunction.apply(
                input, kernel_size, stride, padding, dilation, ceil_mode)
    except RuntimeError:
        # DNNL rejected this configuration — fall through to the native op.
        pass
    return torch_max_pool3d(input, kernel_size, stride, padding, dilation, ceil_mode)
6576

66-
F.adaptive_avg_pool2d = adaptive_avg_pool2d
77+
torch._C._nn.adaptive_avg_pool2d = adaptive_avg_pool2d
6778
torch.max_pool2d = max_pool2d
6879
torch.max_pool3d = max_pool3d

intel_pytorch_extension_py/ops/reshape.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ def forward(ctx, input, size):
1111
return output
1212

1313
def reshape(input, size):
    """torch.reshape replacement: use the DNNL op for dpcpp tensors.

    Dispatches to ReshapeFunction only when the input is on 'dpcpp' and
    auto-DNNL is enabled; otherwise calls the original torch.reshape
    saved in torch_reshape.
    """
    use_dnnl = input.device.type == 'dpcpp' and core.get_auto_dnnl()
    if use_dnnl:
        return ReshapeFunction.apply(input, size)
    return torch_reshape(input, size)
1717

0 commit comments

Comments
 (0)