Skip to content

Commit 39fd990

Browse files
committed
enable DNNL Python ops (adaptive_avg_pool2d, max_pool2d, max_pool3d) to fall back to CPU.
1 parent de92da8 commit 39fd990

File tree

3 files changed

+29
-27
lines changed

3 files changed

+29
-27
lines changed

intel_pytorch_extension_py/ops/linear.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def backward(ctx, grad_output):
2424
return (grad_input, grad_weight, grad_bias)
2525

2626
def linear(input, weight, bias=None):
    """Dispatch to the DNNL linear kernel for dpcpp-resident tensors when
    auto-DNNL is enabled; otherwise call the saved stock F.linear.

    Signature mirrors torch.nn.functional.linear so this can transparently
    replace it.
    """
    use_dnnl = input.device.type == 'dpcpp' and core.get_auto_dnnl()
    if use_dnnl:
        return LinearFunction.apply(input, weight, bias)
    return F_linear(input, weight, bias)
3030

intel_pytorch_extension_py/ops/pooling.py

Lines changed: 27 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,16 @@
22
from torch.autograd import Function
33
import torch.nn.functional as F
44
import _torch_ipex as core
5+
from torch.nn.modules.utils import _single
56

6-
F_adaptive_avg_pool2d = F.adaptive_avg_pool2d
7+
torch_adaptive_avg_pool2d = torch._C._nn.adaptive_avg_pool2d
78
torch_max_pool2d = torch.max_pool2d
89
torch_max_pool3d = torch.max_pool3d
910

1011
class AdaptiveAvgPool2dFunction(Function):
1112
    @staticmethod
    def forward(ctx, input, output_size):
        # Run the DNNL adaptive average pooling kernel. `_single` normalizes
        # `output_size` to tuple form (an int n becomes the 1-tuple (n,)).
        # NOTE(review): for a 2-d pool an int conventionally means an (n, n)
        # output; presumably core.adaptive_avg_pool2d expands a 1-element
        # size itself — confirm against the C++ op.
        output = core.adaptive_avg_pool2d(input, _single(output_size))
        # Only the input tensor is stashed for the backward pass.
        ctx.save_for_backward(input)
        return output
1717

@@ -25,44 +25,46 @@ def backward(ctx, grad_output):
2525
class MaxPoolingFunction(Function):
    """Autograd wrapper around the DNNL max-pooling forward/backward kernels."""

    @staticmethod
    def forward(ctx, input, kernel_size, stride, padding, dilation, ceil_mode):
        # Normalize every hyper-parameter to tuple form once, and keep the
        # normalized values on ctx so backward() sees exactly what the
        # forward kernel was given.
        ctx.kernel_size = _single(kernel_size)
        ctx.stride = _single(stride)
        ctx.padding = _single(padding)
        ctx.dilation = _single(dilation)
        ctx.ceil_mode = ceil_mode
        output = core.max_pooling(input, ctx.kernel_size, ctx.stride,
                                  ctx.padding, ctx.dilation, ctx.ceil_mode)
        # Both the pooled output and the original input are handed to the
        # backward kernel, so save both tensors.
        ctx.save_for_backward(output, input)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        pooled, original_input = ctx.saved_tensors
        grad_output = grad_output.contiguous()
        grad_input = core.max_pooling_backward(
            grad_output, pooled, original_input,
            ctx.kernel_size, ctx.stride, ctx.padding, ctx.dilation,
            ctx.ceil_mode)
        # One gradient slot per forward argument; the non-tensor
        # hyper-parameters receive None.
        return (grad_input, None, None, None, None, None)
4343

44-
def _list_with_default(out_size, defaults):
45-
if isinstance(out_size, int):
46-
return (out_size,)
47-
if len(defaults) <= len(out_size):
48-
raise ValueError('Input dimension should be at least {}'.format(len(out_size) + 1))
49-
return [v if v is not None else d for v, d in zip(out_size, defaults[-len(out_size):])]
50-
5144
def adaptive_avg_pool2d(input, output_size):
    """DNNL-aware replacement for the stock adaptive_avg_pool2d entry point.

    Tries the DNNL path for dpcpp tensors with auto-DNNL enabled; if the
    eligibility check or the DNNL op raises RuntimeError, or the tensor is
    ineligible, falls back to the saved ATen implementation.
    """
    try:
        eligible = input.device.type == 'dpcpp' and core.get_auto_dnnl()
        if eligible:
            return AdaptiveAvgPool2dFunction.apply(input, output_size)
    except RuntimeError:
        # DNNL rejected this configuration — fall through to the stock op.
        pass
    return torch_adaptive_avg_pool2d(input, output_size)
5551

5652
def max_pool2d(input, kernel_size, stride=None, padding=0, dilation=1,
               ceil_mode=False):
    """DNNL-aware replacement for torch.max_pool2d with CPU fallback.

    Defaults mirror the stock operator (stride=None means "use kernel_size",
    per torch.nn.functional.max_pool2d), so callers that relied on
    torch.max_pool2d's optional arguments keep working after this function
    is monkey-patched over it. Previously all six arguments were required,
    which broke such callers.

    Tries the DNNL path for dpcpp tensors when auto-DNNL is enabled; any
    RuntimeError from the eligibility check or the DNNL op triggers a
    fallback to the saved ATen implementation.
    """
    if stride is None:
        # PyTorch semantics: the default stride equals the kernel size.
        stride = kernel_size
    try:
        if input.device.type == 'dpcpp' and core.get_auto_dnnl():
            return MaxPoolingFunction.apply(input, kernel_size, stride,
                                            padding, dilation, ceil_mode)
    except RuntimeError:
        # DNNL rejected this configuration — fall through to the stock op.
        pass
    return torch_max_pool2d(input, kernel_size, stride, padding, dilation,
                            ceil_mode)
6059

6160
def max_pool3d(input, kernel_size, stride=None, padding=0, dilation=1,
               ceil_mode=False):
    """DNNL-aware replacement for torch.max_pool3d with CPU fallback.

    Defaults mirror the stock operator (stride=None means "use kernel_size",
    per torch.nn.functional.max_pool3d), so callers that relied on
    torch.max_pool3d's optional arguments keep working after this function
    is monkey-patched over it. Previously all six arguments were required,
    which broke such callers.

    Tries the DNNL path for dpcpp tensors when auto-DNNL is enabled; any
    RuntimeError from the eligibility check or the DNNL op triggers a
    fallback to the saved ATen implementation.
    """
    if stride is None:
        # PyTorch semantics: the default stride equals the kernel size.
        stride = kernel_size
    try:
        if input.device.type == 'dpcpp' and core.get_auto_dnnl():
            return MaxPoolingFunction.apply(input, kernel_size, stride,
                                            padding, dilation, ceil_mode)
    except RuntimeError:
        # DNNL rejected this configuration — fall through to the stock op.
        pass
    return torch_max_pool3d(input, kernel_size, stride, padding, dilation,
                            ceil_mode)
6567

66-
F.adaptive_avg_pool2d = adaptive_avg_pool2d
68+
torch._C._nn.adaptive_avg_pool2d = adaptive_avg_pool2d
6769
torch.max_pool2d = max_pool2d
6870
torch.max_pool3d = max_pool3d

intel_pytorch_extension_py/ops/reshape.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ def forward(ctx, input, size):
1111
return output
1212

1313
def reshape(input, size):
    """Route dpcpp-resident tensors through the DNNL reshape when auto-DNNL
    is enabled; otherwise use the saved torch_reshape implementation."""
    on_dnnl_path = input.device.type == 'dpcpp' and core.get_auto_dnnl()
    if on_dnnl_path:
        return ReshapeFunction.apply(input, size)
    return torch_reshape(input, size)
1717

0 commit comments

Comments (0)