diff --git a/intel_pytorch_extension_py/ops/linear.py b/intel_pytorch_extension_py/ops/linear.py
index 55fb7bd4f..05a90b23b 100644
--- a/intel_pytorch_extension_py/ops/linear.py
+++ b/intel_pytorch_extension_py/ops/linear.py
@@ -24,7 +24,7 @@ def backward(ctx, grad_output):
         return (grad_input, grad_weight, grad_bias)
 
 def linear(input, weight, bias=None):
-    if input.device.type == 'dpcpp':
+    if input.device.type == 'dpcpp' and core.get_auto_dnnl():
         return LinearFunction.apply(input, weight, bias)
     return F_linear(input, weight, bias)
 
diff --git a/intel_pytorch_extension_py/ops/pooling.py b/intel_pytorch_extension_py/ops/pooling.py
index 7e38cae1b..7ff457d56 100644
--- a/intel_pytorch_extension_py/ops/pooling.py
+++ b/intel_pytorch_extension_py/ops/pooling.py
@@ -2,16 +2,16 @@
 from torch.autograd import Function
 import torch.nn.functional as F
 import _torch_ipex as core
+from torch.nn.modules.utils import _single
 
-F_adaptive_avg_pool2d = F.adaptive_avg_pool2d
+torch_adaptive_avg_pool2d = torch._C._nn.adaptive_avg_pool2d
 torch_max_pool2d = torch.max_pool2d
 torch_max_pool3d = torch.max_pool3d
 
 class AdaptiveAvgPool2dFunction(Function):
     @staticmethod
     def forward(ctx, input, output_size):
-        _output_size = _list_with_default(output_size, input.size())
-        output = core.adaptive_avg_pool2d(input, _output_size)
+        output = core.adaptive_avg_pool2d(input, _single(output_size))
         ctx.save_for_backward(input)
         return output
 
@@ -25,44 +25,46 @@ def backward(ctx, grad_output):
 class MaxPoolingFunction(Function):
     @staticmethod
     def forward(ctx, input, kernel_size, stride, padding, dilation, ceil_mode):
-        output = core.max_pooling(input, (kernel_size,), (stride,), (padding,), (dilation,), ceil_mode)
-        ctx.save_for_backward(output, input)
-        ctx.kernel_size = kernel_size
-        ctx.stride = stride
-        ctx.padding = padding
-        ctx.dilation = dilation
+        ctx.kernel_size = _single(kernel_size)
+        ctx.stride = _single(stride)
+        ctx.padding = _single(padding)
+        ctx.dilation = _single(dilation)
         ctx.ceil_mode = ceil_mode
+        output = core.max_pooling(input, ctx.kernel_size, ctx.stride, ctx.padding, ctx.dilation, ctx.ceil_mode)
+        ctx.save_for_backward(output, input)
         return output
 
     @staticmethod
     def backward(ctx, grad_output):
         output, input= ctx.saved_tensors
         grad_output = grad_output.contiguous()
-        grad_input = core.max_pooling_backward(grad_output, output, input, (ctx.kernel_size,), (ctx.stride,), (ctx.padding,), (ctx.dilation,), ctx.ceil_mode)
+        grad_input = core.max_pooling_backward(grad_output, output, input, ctx.kernel_size, ctx.stride, ctx.padding, ctx.dilation, ctx.ceil_mode)
         return (grad_input, None, None, None, None, None)
 
-def _list_with_default(out_size, defaults):
-    if isinstance(out_size, int):
-        return (out_size,)
-    if len(defaults) <= len(out_size):
-        raise ValueError('Input dimension should be at least {}'.format(len(out_size) + 1))
-    return [v if v is not None else d for v, d in zip(out_size, defaults[-len(out_size):])]
-
 def adaptive_avg_pool2d(input, output_size):
-    if input.device.type == 'dpcpp':
-        return AdaptiveAvgPool2dFunction.apply(input, output_size)
-    return F_adaptive_avg_pool2d(input, output_size)
+    try:
+        if input.device.type == 'dpcpp' and core.get_auto_dnnl():
+            return AdaptiveAvgPool2dFunction.apply(input, output_size)
+    except RuntimeError:
+        pass
+    return torch_adaptive_avg_pool2d(input, output_size)
 
 def max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode):
-    if input.device.type == 'dpcpp':
-        return MaxPoolingFunction.apply(input, kernel_size, stride, padding, dilation, ceil_mode)
+    try:
+        if input.device.type == 'dpcpp' and core.get_auto_dnnl():
+            return MaxPoolingFunction.apply(input, kernel_size, stride, padding, dilation, ceil_mode)
+    except RuntimeError:
+        pass
     return torch_max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)
 
 def max_pool3d(input, kernel_size, stride, padding, dilation, ceil_mode):
-    if input.device.type == 'dpcpp':
-        return MaxPoolingFunction.apply(input, kernel_size, stride, padding, dilation, ceil_mode)
+    try:
+        if input.device.type == 'dpcpp' and core.get_auto_dnnl():
+            return MaxPoolingFunction.apply(input, kernel_size, stride, padding, dilation, ceil_mode)
+    except RuntimeError:
+        pass
     return torch_max_pool3d(input, kernel_size, stride, padding, dilation, ceil_mode)
 
-F.adaptive_avg_pool2d = adaptive_avg_pool2d
+torch._C._nn.adaptive_avg_pool2d = adaptive_avg_pool2d
 torch.max_pool2d = max_pool2d
 torch.max_pool3d = max_pool3d
\ No newline at end of file
diff --git a/intel_pytorch_extension_py/ops/reshape.py b/intel_pytorch_extension_py/ops/reshape.py
index fe00d5795..dd1122959 100644
--- a/intel_pytorch_extension_py/ops/reshape.py
+++ b/intel_pytorch_extension_py/ops/reshape.py
@@ -11,7 +11,7 @@ def forward(ctx, input, size):
         return output
 
 def reshape(input, size):
-    if input.device.type == 'dpcpp':
+    if input.device.type == 'dpcpp' and core.get_auto_dnnl():
         return ReshapeFunction.apply(input, size)
     return torch_reshape(input, size)
 
diff --git a/tests/cpu/test_lazy_reorder.py b/tests/cpu/test_lazy_reorder.py
index 39b79fdb6..4943e537a 100644
--- a/tests/cpu/test_lazy_reorder.py
+++ b/tests/cpu/test_lazy_reorder.py
@@ -13,7 +13,7 @@
 import torch
 import _torch_ipex as ipex
 ipex._initialize_aten_bindings()
-import intel_pytorch_extension_py
+import intel_pytorch_extension
 
 import torch.nn as nn
 import torch.backends.cudnn as cudnn
diff --git a/tests/cpu/test_rn50_cpu_ops.py b/tests/cpu/test_rn50_cpu_ops.py
index 4e0e53139..c53c78a3f 100644
--- a/tests/cpu/test_rn50_cpu_ops.py
+++ b/tests/cpu/test_rn50_cpu_ops.py
@@ -57,6 +57,7 @@
 import torch
 import _torch_ipex as ipex
 ipex._initialize_aten_bindings()
+import intel_pytorch_extension
 
 import torch.nn as nn
 import torch.backends.cudnn as cudnn
diff --git a/tests/cpu/test_torch.py b/tests/cpu/test_torch.py
index 6752ed3da..c5fbab570 100644
--- a/tests/cpu/test_torch.py
+++ b/tests/cpu/test_torch.py
@@ -83,6 +83,7 @@
     skipIf, skipCPUIfNoLapack, skipCUDAIfNoMagma, skipCUDAIfRocm, onlyCUDA, onlyCPU, \
     dtypes, dtypesIfCUDA, deviceCountAtLeast, skipCUDAIf, precisionOverride, ipex
 import torch.backends.quantized
+import intel_pytorch_extension
 
 # load_tests from common_utils is used to automatically filter tests for
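For reference, the dispatch pattern this patch applies to each op (try the DNNL-backed path, fall back to the stock PyTorch kernel on RuntimeError) can be sketched standalone. This is a minimal illustration, not part of the patch: the hypothetical _dnnl_max_pool2d stands in for core.max_pooling, and the device check is relaxed to 'cpu' so the sketch runs without the extension installed.

import torch

# Keep a handle to the stock kernel before monkey-patching, as pooling.py does.
torch_max_pool2d = torch.max_pool2d

def _dnnl_max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode):
    # Hypothetical stand-in for core.max_pooling(); a real backend may raise
    # RuntimeError for configurations it cannot handle.
    raise RuntimeError("unsupported by the accelerated backend")

def max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode=False):
    try:
        # The patch gates on input.device.type == 'dpcpp' and core.get_auto_dnnl();
        # 'cpu' is used here only so the sketch is runnable standalone.
        if input.device.type == 'cpu':
            return _dnnl_max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)
    except RuntimeError:
        pass  # fall through to the stock implementation
    return torch_max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)

torch.max_pool2d = max_pool2d  # rebind the public entry point, as the patch does at module scope

x = torch.randn(1, 3, 8, 8)
print(torch.max_pool2d(x, 2, 2, 0, 1).shape)  # falls back: torch.Size([1, 3, 4, 4])

Catching RuntimeError here means inputs the accelerated backend rejects are served transparently by the stock kernel instead of erroring out when auto-DNNL is enabled.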