diff --git a/cmake/CPU.cmake b/cmake/CPU.cmake
index 72693d419..5d57ccfb9 100644
--- a/cmake/CPU.cmake
+++ b/cmake/CPU.cmake
@@ -136,9 +136,11 @@ include_directories(${DPCPP_THIRD_PARTY_ROOT}/xsmm/include)
 set(DPCPP_SRCS)
 set(DPCPP_COMMON_SRCS)
 set(DPCPP_CPU_SRCS)
+set(DPCPP_JIT_SRCS)
 
 add_subdirectory(${DPCPP_ROOT})
 add_subdirectory(${DPCPP_ROOT}/cpu)
+add_subdirectory(${DPCPP_ROOT}/jit)
 
 # libxsmm
 include(${CMAKE_ROOT}/Modules/ExternalProject.cmake)
@@ -153,7 +155,7 @@ ExternalProject_Add(xsmm
   INSTALL_COMMAND ""
 )
 
 # Compile code with pybind11
-set(DPCPP_SRCS ${DPCPP_ATEN_SRCS} ${DPCPP_COMMON_SRCS} ${DPCPP_CPU_SRCS})
+set(DPCPP_SRCS ${DPCPP_ATEN_SRCS} ${DPCPP_COMMON_SRCS} ${DPCPP_CPU_SRCS} ${DPCPP_JIT_SRCS})
 pybind11_add_module(${PLUGIN_NAME} SHARED ${DPCPP_SRCS})
 target_link_libraries(${PLUGIN_NAME} PRIVATE ${DPCPP_THIRD_PARTY_ROOT}/xsmm/lib/libxsmm.a)
diff --git a/intel_pytorch_extension_py/ops/linear.py b/intel_pytorch_extension_py/ops/linear.py
index 05a90b23b..ab9a5480e 100644
--- a/intel_pytorch_extension_py/ops/linear.py
+++ b/intel_pytorch_extension_py/ops/linear.py
@@ -2,30 +2,11 @@
 from torch.autograd import Function
 import torch.nn.functional as F
 import _torch_ipex as core
+from typing import Optional
 
-F_linear = F.linear
-
-class LinearFunction(Function):
-    @staticmethod
-    def forward(ctx, input, weight, bias):
-        output = core.linear(input, weight, bias)
-        ctx.save_for_backward(input, weight, bias)
-        return output
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        input, weight, bias = ctx.saved_tensors
-        grad_output = grad_output.contiguous()
-        if bias == None:
-            output_mask = (input.requires_grad, weight.requires_grad, 0)
-        else:
-            output_mask = (input.requires_grad, weight.requires_grad, bias.requires_grad)
-        grad_input, grad_weight, grad_bias = core.linear_backward(input, grad_output, weight, output_mask)
-        return (grad_input, grad_weight, grad_bias)
-
-def linear(input, weight, bias=None):
-    if input.device.type == 'dpcpp' and core.get_auto_dnnl():
-        return LinearFunction.apply(input, weight, bias)
-    return F_linear(input, weight, bias)
+def linear(input, weight, bias: Optional[torch.Tensor] = None):
+    if bias is None:
+        bias = torch.zeros(weight.size(0))
+    return torch.ops.torch_ipex.linear(input, weight, bias)
 
 F.linear = linear
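The shim above reroutes F.linear through the registered torch_ipex::linear custom op as soon as the extension is imported. A minimal sanity check, assuming the extension builds and a 'dpcpp' device is available (shapes are illustrative); note that the None-bias fallback builds torch.zeros(weight.size(0)) with the default dtype and device, so passing an explicit bias, as below, is the safer route for non-default dtypes or devices:

    import torch
    import torch.nn.functional as F
    import intel_pytorch_extension  # rebinds F.linear on import

    x = torch.rand(8, 64).to('dpcpp')
    w = torch.rand(32, 64).to('dpcpp')
    b = torch.zeros(32).to('dpcpp')
    y = F.linear(x, w, b)  # dispatches to torch.ops.torch_ipex.linear
    assert y.shape == (8, 32)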
diff --git a/intel_pytorch_extension_py/ops/pooling.py b/intel_pytorch_extension_py/ops/pooling.py
index 7ff457d56..12114a91f 100644
--- a/intel_pytorch_extension_py/ops/pooling.py
+++ b/intel_pytorch_extension_py/ops/pooling.py
@@ -2,25 +2,12 @@
 from torch.autograd import Function
 import torch.nn.functional as F
 import _torch_ipex as core
-from torch.nn.modules.utils import _single
+from torch.nn.modules.utils import _single, _pair
+from typing import List
 
-torch_adaptive_avg_pool2d = torch._C._nn.adaptive_avg_pool2d
-torch_max_pool2d = torch.max_pool2d
-torch_max_pool3d = torch.max_pool3d
+Vector = List[int]
 
-class AdaptiveAvgPool2dFunction(Function):
-    @staticmethod
-    def forward(ctx, input, output_size):
-        output = core.adaptive_avg_pool2d(input, _single(output_size))
-        ctx.save_for_backward(input)
-        return output
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        (input,) = ctx.saved_tensors
-        grad_output = grad_output.contiguous()
-        grad_input = core.adaptive_avg_pool2d_backward(grad_output, input)
-        return (grad_input, None)
+torch_max_pool3d = torch.max_pool3d
 
 class MaxPoolingFunction(Function):
     @staticmethod
@@ -41,21 +28,8 @@ def backward(ctx, grad_output):
         grad_input = core.max_pooling_backward(grad_output, output, input,
             ctx.kernel_size, ctx.stride, ctx.padding, ctx.dilation, ctx.ceil_mode)
         return (grad_input, None, None, None, None, None)
 
-def adaptive_avg_pool2d(input, output_size):
-    try:
-        if input.device.type == 'dpcpp' and core.get_auto_dnnl():
-            return AdaptiveAvgPool2dFunction.apply(input, output_size)
-    except RuntimeError:
-        pass
-    return torch_adaptive_avg_pool2d(input, output_size)
-
-def max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode):
-    try:
-        if input.device.type == 'dpcpp' and core.get_auto_dnnl():
-            return MaxPoolingFunction.apply(input, kernel_size, stride, padding, dilation, ceil_mode)
-    except RuntimeError:
-        pass
-    return torch_max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)
+def adaptive_avg_pool2d(input, output_size: Vector):
+    return torch.ops.torch_ipex.adaptive_avg_pool2d(input, _pair(output_size))
 
 def max_pool3d(input, kernel_size, stride, padding, dilation, ceil_mode):
     try:
@@ -65,6 +39,11 @@ def max_pool3d(input, kernel_size, stride, padding, dilation, ceil_mode):
         pass
     return torch_max_pool3d(input, kernel_size, stride, padding, dilation, ceil_mode)
 
+def max_pool2d(input, kernel_size: Vector, stride: Vector, padding: Vector, dilation: Vector, ceil_mode: bool):
+    if not stride:
+        stride = kernel_size
+    return torch.ops.torch_ipex.max_pool2d(input, _pair(kernel_size), _pair(stride), _pair(padding), _pair(dilation), ceil_mode)
+
 torch._C._nn.adaptive_avg_pool2d = adaptive_avg_pool2d
 torch.max_pool2d = max_pool2d
-torch.max_pool3d = max_pool3d
\ No newline at end of file
+torch.max_pool3d = max_pool3d
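The pooling shim mirrors the linear one: torch.max_pool2d and torch._C._nn.adaptive_avg_pool2d are rebound to the torch_ipex ops, and an empty or None stride falls back to kernel_size, matching the aten default. A quick illustration (hypothetical shapes; assumes the extension is importable):

    import torch
    import intel_pytorch_extension  # rebinds the pooling entry points on import

    x = torch.rand(1, 3, 8, 8).to('dpcpp')
    y = torch.max_pool2d(x, kernel_size=2, stride=[], padding=0, dilation=1, ceil_mode=False)
    z = torch._C._nn.adaptive_avg_pool2d(x, (1, 1))
    print(y.shape, z.shape)  # torch.Size([1, 3, 4, 4]) torch.Size([1, 3, 1, 1])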
diff --git a/tests/cpu/test_jit.py b/tests/cpu/test_jit.py
new file mode 100644
index 000000000..f585da98d
--- /dev/null
+++ b/tests/cpu/test_jit.py
@@ -0,0 +1,172 @@
+from __future__ import division
+from __future__ import print_function
+
+'''
+From PyTorch:
+
+Copyright (c) 2016- Facebook, Inc (Adam Paszke)
+Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
+Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
+Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)
+Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
+Copyright (c) 2011-2013 NYU (Clement Farabet)
+Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
+Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
+Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
+
+From Caffe2:
+
+Copyright (c) 2016-present, Facebook Inc. All rights reserved.
+
+All contributions by Facebook:
+Copyright (c) 2016 Facebook Inc.
+
+All contributions by Google:
+Copyright (c) 2015 Google Inc.
+All rights reserved.
+
+All contributions by Yangqing Jia:
+Copyright (c) 2015 Yangqing Jia
+All rights reserved.
+
+All contributions from Caffe:
+Copyright(c) 2013, 2014, 2015, the respective contributors
+All rights reserved.
+
+All other contributions:
+Copyright(c) 2015, 2016 the respective contributors
+All rights reserved.
+
+Caffe2 uses a copyright model similar to Caffe: each contributor holds
+copyright over their contributions to Caffe2. The project versioning records
+all such contribution and copyright details. If a contributor wants to further
+mark their specific copyright on a particular contribution, they should
+indicate their copyright solely in the commit message of the change when it is
+committed.
+
+All rights reserved.
+'''
+
+"""Tests for rn50."""
+
+import math
+import random
+import unittest
+from functools import reduce
+
+import torch
+import torch.nn as nn
+from torch.jit._recursive import wrap_cpp_module
+import copy
+
+import intel_pytorch_extension
+from intel_pytorch_extension import core
+
+import torch.nn as nn
+import torch.backends.cudnn as cudnn
+from torch.nn import Parameter
+import torch.nn.functional as F
+from torch.autograd import gradcheck
+from torch.autograd.gradcheck import gradgradcheck
+from torch._six import inf, nan
+
+from common_utils import TestCase, iter_indices, TEST_NUMPY, TEST_SCIPY, TEST_MKL, \
+    TEST_LIBROSA, run_tests, download_file, skipIfNoLapack, suppress_warnings, \
+    IS_WINDOWS, PY3, NO_MULTIPROCESSING_SPAWN, do_test_dtypes, do_test_empty_full, \
+    IS_SANDCASTLE, load_tests, brute_pdist, brute_cdist, slowTest, \
+    skipCUDANonDefaultStreamIf, skipCUDAMemoryLeakCheckIf
+
+device = 'dpcpp:0'
+#device = 'cpu:0'
+SIZE = 100
+
+torch._C._jit_set_profiling_mode(False)
+torch._C._jit_set_profiling_executor(False)
+
+def test_output(model, x):
+    modelName = model.__class__.__name__
+    core.disable_jit()
+
+    model = model.to('dpcpp').eval()
+    x = x.to('dpcpp')
+    with torch.no_grad():
+        result = model(x)
+
+    smodel = torch.jit.script(model)
+    smodel.eval()
+    with torch.no_grad():
+        sresult = smodel(x)
+
+    print(f'\nAre {modelName} and Scripted{modelName} outputs the same: ',
+          torch.allclose(
+              sresult, result, rtol=1e-05, atol=1e-06, equal_nan=False))
+
+    core.enable_jit()
+    pmodel = torch.jit.script(model)
+    # bn folding
+    pmodel = wrap_cpp_module(torch._C._jit_pass_fold_convbn(pmodel._c))
+    with torch.no_grad():
+        # conv relu fusion, conv sum fusion or conv sum relu fusion
+        print(pmodel.graph_for(x))
+        presult = pmodel(x)
+
+    # print(result)
+    # print(sresult)
+    # print(presult)
+
+    print(f'\nWith and without the JIT fusion pass, are Scripted{modelName} outputs the same: ',
+          torch.allclose(
+              sresult, presult, rtol=1e-05, atol=1e-06, equal_nan=False))
+
+class Conv2dRelu_Fixed(nn.Module):
+    def __init__(self, in_channels, out_channels, **kwargs):
+        super(Conv2dRelu_Fixed, self).__init__()
+        seed = 2018
+        torch.manual_seed(seed)
+        self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
+
+    def forward(self, x):
+        return F.relu(self.conv(x), inplace=True)
+
+class CascadedConv2dBnSumRelu(nn.Module):
+    def __init__(self, in_channels, mid_channels, out_channels, **kwargs):
+        super(CascadedConv2dBnSumRelu, self).__init__()
+        torch.manual_seed(2018)
+        self.conv = nn.Conv2d(in_channels, mid_channels, bias=False, **kwargs)
+        self.conv1 = nn.Conv2d(
+            mid_channels, out_channels, bias=False, padding=1, **kwargs)
+        self.conv2 = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
+        self.bn = nn.BatchNorm2d(mid_channels, eps=0.001)
+        self.bn1 = nn.BatchNorm2d(out_channels, eps=0.001)
+        self.bn2 = nn.BatchNorm2d(out_channels, eps=0.001)
+
+    def forward(self, x):
+        a = self.conv(x)
+        a = self.bn(a)
+        a = F.relu(a, inplace=True)
+        a = self.conv1(a)
+        a = self.bn1(a)
+        b = self.conv2(x)
+        b = self.bn2(b)
+        return F.relu(a.add_(b), inplace=True)
+
+class Tester(TestCase):
+    n = 32
+    c = 3
+    h = 224
+    w = 224
+    print('input size: (%d, %d, %d, %d)' % (n, c, h, w))
+
+    def test_output_conv_relu(self):
+        test_output(
+            Conv2dRelu_Fixed(self.c, 32, kernel_size=3, stride=1),
+            torch.rand(self.n, self.c, self.h, self.w))
+
+    def test_output_cascaded_conv2d_bn_sum_relu(self):
+        test_output(
+            CascadedConv2dBnSumRelu(self.c, 64, 32, kernel_size=3, stride=1),
+            torch.rand(self.n, self.c, self.h, self.w))
+
+if __name__ == '__main__':
+    core.enable_auto_dnnl()
+    test = unittest.main()
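The test drives the new path end to end; the essential recipe it encodes is: script the model, fold conv+bn, then let the registered fusion pass rewrite the graph while core.enable_jit() is on. Distilled into a standalone sketch (assumes an nn.Module `model` and an input `x` already on 'dpcpp'):

    import torch
    from torch.jit._recursive import wrap_cpp_module
    from intel_pytorch_extension import core

    core.enable_auto_dnnl()
    core.enable_jit()
    smodel = torch.jit.script(model.eval())
    smodel = wrap_cpp_module(torch._C._jit_pass_fold_convbn(smodel._c))
    with torch.no_grad():
        print(smodel.graph_for(x))  # expect ipex::conv2d_relu / ipex::conv2d_sum_relu nodes
        out = smodel(x)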
diff --git a/torch_ipex/csrc/auto_opt_config.h b/torch_ipex/csrc/auto_opt_config.h
index 333a0adfd..2f950edef 100644
--- a/torch_ipex/csrc/auto_opt_config.h
+++ b/torch_ipex/csrc/auto_opt_config.h
@@ -17,6 +17,14 @@ class AutoOptConfig {
     return auto_dnnl_;
   }
 
+  inline void set_jit_fuse(bool jit_fuse) {
+    jit_fuse_ = jit_fuse;
+  }
+
+  inline bool get_jit_fuse() {
+    return jit_fuse_;
+  }
+
   inline void set_mix_bf16_fp32(bool value) {
     mix_bf16_fp32_ = value;
   }
@@ -39,6 +47,7 @@ class AutoOptConfig {
 
  private:
   bool auto_dnnl_;
+  bool jit_fuse_;
   bool mix_bf16_fp32_;
   bool pure_bf16_;
 };
diff --git a/torch_ipex/csrc/cpu/CustomOPs.h b/torch_ipex/csrc/cpu/CustomOPs.h
new file mode 100644
index 000000000..2cea2ad05
--- /dev/null
+++ b/torch_ipex/csrc/cpu/CustomOPs.h
@@ -0,0 +1,144 @@
+#pragma once
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include "torch_ipex/csrc/utils.h"
+#include "DevOPs.h"
+
+class NewLinearOp : public torch::autograd::Function<NewLinearOp> {
+ public:
+  static at::Tensor forward(
+      torch::autograd::AutogradContext* ctx,
+      at::Tensor input,
+      at::Tensor weight,
+      at::Tensor bias = at::Tensor()) {
+    ctx->save_for_backward({input, weight, bias});
+    if (torch_ipex::check_auto_dnnl() && input.device().type() == c10::DeviceType::DPCPP) {
+      return torch_ipex::cpu::AtenIpexCPUDev::dil_linear(input, weight, bias);
+    } else {
+      return at::linear(input, weight, bias);
+    }
+  }
+
+  static torch::autograd::tensor_list backward(
+      torch::autograd::AutogradContext* ctx,
+      torch::autograd::tensor_list grad_outputs) {
+    auto saved = ctx->get_saved_variables();
+    at::Tensor input = saved[0];
+    at::Tensor weight = saved[1];
+    at::Tensor bias = saved[2];
+
+    at::Tensor grad_output = grad_outputs[0];
+    at::Tensor grad_input, grad_weight;
+    at::Tensor grad_bias = torch::Tensor();
+
+    if (torch_ipex::check_auto_dnnl() && input.device().type() == c10::DeviceType::DPCPP) {
+      grad_input = torch_ipex::cpu::AtenIpexCPUDev::dil_linear_backward_input(
+          input.sizes(), grad_output.contiguous(), weight);
+      std::tie(grad_weight, grad_bias) = torch_ipex::cpu::AtenIpexCPUDev::dil_linear_backward_weights(
+          grad_output.contiguous(), input, weight, bias.defined());
+    } else {
+      grad_input = grad_output.mm(weight);
+      grad_weight = grad_output.t().mm(input);
+      if (bias.defined()) {
+        grad_bias = grad_output.sum(0);
+      }
+    }
+    return {grad_input, grad_weight, grad_bias};
+  }
+};
+
+class NewMaxPoolingOp : public torch::autograd::Function<NewMaxPoolingOp> {
+ public:
+  static at::Tensor forward(
+      torch::autograd::AutogradContext* ctx,
+      at::Tensor input,
+      at::IntArrayRef kernel_size,
+      at::IntArrayRef stride,
+      at::IntArrayRef padding,
+      at::IntArrayRef dilation,
+      bool ceil_mode) {
+    ctx->saved_data["kernel_size"] = kernel_size;
+    ctx->saved_data["stride"] = stride;
+    ctx->saved_data["padding"] = padding;
+    ctx->saved_data["dilation"] = dilation;
+    ctx->saved_data["ceil_mode"] = ceil_mode;
+
+    if (torch_ipex::check_auto_dnnl() && input.device().type() == c10::DeviceType::DPCPP) {
+      at::Tensor output = torch_ipex::cpu::AtenIpexCPUDev::dil_max_pooling(input, kernel_size, stride,
+          padding, dilation, ceil_mode);
+      ctx->save_for_backward({input, output});
+      return output;
+    } else {
+      at::Tensor output, indices;
+      std::tie(output, indices) = at::max_pool2d_with_indices(input, kernel_size, stride, padding, dilation, ceil_mode);
+      ctx->save_for_backward({input, indices});
+      return output;
+    }
+  }
+
+  static torch::autograd::tensor_list backward(
+      torch::autograd::AutogradContext* ctx,
+      torch::autograd::tensor_list grad_outputs) {
+    auto saved = ctx->get_saved_variables();
+    at::Tensor input = saved[0];
+    at::Tensor indices = saved[1];
+
+    at::Tensor grad_output = grad_outputs[0].contiguous();
+    at::Tensor grad_input;
+
+    std::vector<int64_t> kernel_size = ctx->saved_data["kernel_size"].toIntVector();
+    std::vector<int64_t> stride = ctx->saved_data["stride"].toIntVector();
+    std::vector<int64_t> padding = ctx->saved_data["padding"].toIntVector();
+    std::vector<int64_t> dilation = ctx->saved_data["dilation"].toIntVector();
+    bool ceil_mode = ctx->saved_data["ceil_mode"].toBool();
+
+    if (torch_ipex::check_auto_dnnl() && input.device().type() == c10::DeviceType::DPCPP) {
+      grad_input = torch_ipex::cpu::AtenIpexCPUDev::dil_max_pooling_backward(
+          grad_output, indices, input, kernel_size, stride, padding, dilation, ceil_mode);
+    } else {
+      grad_input = at::max_pool2d_with_indices_backward(grad_output, input, kernel_size,
+          stride, padding, dilation, ceil_mode, indices);
+    }
+    return {grad_input, at::Tensor(), at::Tensor(), at::Tensor(), at::Tensor(), at::Tensor()};
+  }
+};
+
+class NewAdaptiveAvgPoolingOp : public torch::autograd::Function<NewAdaptiveAvgPoolingOp> {
+ public:
+  static at::Tensor forward(
+      torch::autograd::AutogradContext* ctx,
+      at::Tensor input,
+      at::IntArrayRef output_size) {
+    ctx->save_for_backward({input});
+
+    at::Tensor output;
+    if (torch_ipex::check_auto_dnnl() && input.device().type() == c10::DeviceType::DPCPP) {
+      output = torch_ipex::cpu::AtenIpexCPUDev::dil_adaptive_avg_pool2d(input, output_size);
+    } else {
+      output = at::_adaptive_avg_pool2d(input, output_size);
+    }
+    return output;
+  }
+
+  static torch::autograd::tensor_list backward(
+      torch::autograd::AutogradContext* ctx,
+      torch::autograd::tensor_list grad_outputs) {
+    auto saved = ctx->get_saved_variables();
+    at::Tensor input = saved[0];
+
+    at::Tensor grad_output = grad_outputs[0].contiguous();
+    at::Tensor grad_input;
+
+    if (torch_ipex::check_auto_dnnl() && input.device().type() == c10::DeviceType::DPCPP) {
+      grad_input = torch_ipex::cpu::AtenIpexCPUDev::dil_adaptive_avg_pool2d_backward(grad_output, input);
+    } else {
+      grad_input = at::_adaptive_avg_pool2d_backward(grad_output, input);
+    }
+    return {grad_input, at::Tensor()};
+  }
+};
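These wrappers keep the custom ops differentiable: forward picks the DNNL kernel for DPCPP tensors and the stock ATen kernel otherwise, while backward replays through the matching dil_*_backward entry points or the ATen fallbacks. A small gradient smoke test (hypothetical sizes; on non-DPCPP tensors this exercises the ATen fallback branch):

    import torch
    import intel_pytorch_extension

    x = torch.rand(4, 16, requires_grad=True)
    w = torch.rand(8, 16, requires_grad=True)
    b = torch.zeros(8, requires_grad=True)
    y = torch.ops.torch_ipex.linear(x, w, b)
    y.sum().backward()  # routed through NewLinearOp::backward
    print(x.grad.shape, w.grad.shape, b.grad.shape)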
DEBUG("AtenIpexCPUDev::dil_linear_backward_input\n"); auto grad_output_reshaped = grad_output.dim() > 2 ? @@ -621,7 +620,7 @@ at::Tensor dil_linear_backward_input( return dbl::comm::gen_aten_tensor_by(std::move(gradx)); } -std::tuple dil_linear_backward_weights( +std::tuple AtenIpexCPUDev::dil_linear_backward_weights( const at::Tensor& grad_output, const at::Tensor& input, const at::Tensor& weight, bool bias_defined) { DEBUG("AtenIpexCPUDev::dil_linear_backward_weights\n"); auto grad_output_reshaped = grad_output.dim() > 2 ? diff --git a/torch_ipex/csrc/cpu/DevOPs.h b/torch_ipex/csrc/cpu/DevOPs.h index 7c76873e6..49c47a199 100644 --- a/torch_ipex/csrc/cpu/DevOPs.h +++ b/torch_ipex/csrc/cpu/DevOPs.h @@ -38,8 +38,10 @@ class AtenIpexCPUDev { static at::Tensor dil_addbmm(const at::Tensor &self, const at::Tensor &batch1, const at::Tensor &batch2, at::Scalar beta, at::Scalar alpha); static at::Tensor& dil_addbmm_(at::Tensor& self, const at::Tensor& batch1, const at::Tensor& batch2, at::Scalar beta, at::Scalar alpha); static at::Tensor& dil_addbmm_out(at::Tensor& result, const at::Tensor &self, const at::Tensor &batch1, const at::Tensor &batch2, at::Scalar beta, at::Scalar alpha); - static at::Tensor dil_linear(const at::Tensor& self, const at::Tensor& weight, const c10::optional& bias); static at::Tensor dil_linear_fuse_relu(const at::Tensor& self, const at::Tensor& weight, const c10::optional& bias); + static at::Tensor dil_linear(const at::Tensor& self, const at::Tensor& weight, const at::Tensor& bias); + static at::Tensor dil_linear_backward_input(at::IntArrayRef input_size, const at::Tensor& grad_output, const at::Tensor& weight); + static std::tuple dil_linear_backward_weights(const at::Tensor& grad_output, const at::Tensor& input, const at::Tensor& weight, bool bias_defined); static std::tuple dil_linear_backward(const at::Tensor& input, const at::Tensor& grad_output, const at::Tensor& weight, std::array output_mask); static at::Tensor dil_dropout(const at::Tensor& self, double ratio, bool train); static at::Tensor dil_dropout_backward(const at::Tensor& grady, const at::Tensor& mask, double ratio); @@ -69,6 +71,7 @@ class AtenIpexCPUDev { static at::Tensor dil_cat(at::TensorList tensors, int64_t dim); static std::vector dil_split_with_sizes(const at::Tensor& self, at::IntArrayRef split_sizes, int64_t dim); static std::vector dil_split(const at::Tensor& self, int64_t split_size, int64_t dim); + }; } // namespace cpu diff --git a/torch_ipex/csrc/cpu/ExtendOPs.cpp b/torch_ipex/csrc/cpu/ExtendOPs.cpp index bb11a869f..a0cacd084 100644 --- a/torch_ipex/csrc/cpu/ExtendOPs.cpp +++ b/torch_ipex/csrc/cpu/ExtendOPs.cpp @@ -10,6 +10,7 @@ #include "xsmm/libxsmm_utils.h" #include "../utils.h" #include "DevOPs.h" +#include "CustomOPs.h" namespace torch_ipex { @@ -449,8 +450,9 @@ AtenIpexTypeExt::embedding_bag_backward(const at::Tensor& grad, const at::Tensor return cpu::aten::embedding_bag::embedding_bag_backward_impl(grad, indices, offsets, offset2bag, bag_size, maximum_indices, num_weights, scale_grad_by_freq, mode, sparse, _per_sample_weights); } -at::Tensor AtenIpexTypeExt::linear(const at::Tensor& input, const at::Tensor& weight, const c10::optional& bias) { - return cpu::AtenIpexCPUDev::dil_linear(input, weight, bias); + +at::Tensor AtenIpexTypeExt::linear(const at::Tensor& input, const at::Tensor& weight, const at::Tensor& bias) { + return NewLinearOp::apply(input, weight, bias); } at::Tensor AtenIpexTypeExt::linear_fuse_relu(const at::Tensor& input, const at::Tensor& weight, const c10::optional& 
bias) { @@ -464,7 +466,7 @@ std::tuple AtenIpexTypeExt::linear_backward( } at::Tensor AtenIpexTypeExt::adaptive_avg_pool2d(at::Tensor const& input, at::IntArrayRef output_size) { - return cpu::AtenIpexCPUDev::dil_adaptive_avg_pool2d(input, output_size); + return NewApaptiveAvgPoolingOp::apply(input, output_size); } at::Tensor AtenIpexTypeExt::adaptive_avg_pool2d_backward(const at::Tensor& grad_output, const at::Tensor& input) { @@ -472,7 +474,7 @@ at::Tensor AtenIpexTypeExt::adaptive_avg_pool2d_backward(const at::Tensor& grad_ } at::Tensor AtenIpexTypeExt::max_pooling(const at::Tensor& input, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode) { - return cpu::AtenIpexCPUDev::dil_max_pooling(input, kernel_size, stride, padding, dilation, ceil_mode); + return NewMaxPoolingOp::apply(input, kernel_size, stride, padding, dilation, ceil_mode); } at::Tensor AtenIpexTypeExt::max_pooling_backward(const at::Tensor& grad_output, const at::Tensor& output, const at::Tensor& input, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode) { diff --git a/torch_ipex/csrc/cpu/ExtendOPs.h b/torch_ipex/csrc/cpu/ExtendOPs.h index 9305e454b..dedc3e2a4 100644 --- a/torch_ipex/csrc/cpu/ExtendOPs.h +++ b/torch_ipex/csrc/cpu/ExtendOPs.h @@ -23,8 +23,8 @@ class AtenIpexTypeExt { int64_t num_weights, bool scale_grad_by_freq, int64_t mode, bool sparse, const c10::optional& per_sample_weights); - static at::Tensor linear(const at::Tensor& input, const at::Tensor& weight, const c10::optional& bias); static at::Tensor linear_fuse_relu(const at::Tensor& input, const at::Tensor& weight, const c10::optional& bias); + static at::Tensor linear(const at::Tensor& input, const at::Tensor& weight, const at::Tensor& bias = at::Tensor()); static std::tuple linear_backward(const at::Tensor& input, const at::Tensor& grad_output, const at::Tensor& weight, std::array output_mask); static at::Tensor relu_use_dst_for_bwd(const at::Tensor& grad_output, const at::Tensor& output); static at::Tensor adaptive_avg_pool2d(at::Tensor const& input, at::IntArrayRef output_size); diff --git a/torch_ipex/csrc/cpu/FusionOPs.cpp b/torch_ipex/csrc/cpu/FusionOPs.cpp new file mode 100644 index 000000000..d9fec98fa --- /dev/null +++ b/torch_ipex/csrc/cpu/FusionOPs.cpp @@ -0,0 +1,131 @@ +#include "torch_ipex/csrc/cpu/FusionOPs.h" + +#include +#include +#include +#include +#include + +#include + +#include "torch_ipex/csrc/aten_ipex_bridge.h" +#include "torch_ipex/csrc/ipex_tensor_impl.h" +#include "torch_ipex/csrc/utils.h" +#include "dbl/Common.h" +#include "dbl/Conv.h" +#include "ShadeDataContext.h" + +#include "dil/dil.hpp" + +namespace torch_ipex { +namespace cpu { + +at::Tensor AtenIpexJITDev::dil_convolution_relu( + const at::Tensor & input, + const at::Tensor & weight, + const at::Tensor & bias, + at::IntArrayRef stride, + at::IntArrayRef padding, + at::IntArrayRef dilation, + int64_t groups) { + dil::tensor dil_input; + dil::tensor dil_weight; + c10::optional dil_bias{c10::nullopt}; + + auto input_contiguous = input.contiguous(); + auto weight_contiguous = weight.contiguous(); + + dil_input = dbl::comm::try_gen_dil_tensor(input_contiguous); + dil_weight = dbl::comm::try_gen_dil_tensor(weight_contiguous); + if (bias.defined()) { + auto bias_contiguous = bias.contiguous(); + dil_bias = dbl::comm::try_gen_dil_tensor(bias_contiguous); + } + + dil::tensor dil_output = dbl::conv::conv2d_impl( + dil_input, + dil_weight, + 
dil_bias, + padding, + stride, + dilation, + groups, + dil::attr_t::fuse_relu()); + + return dbl::comm::gen_aten_tensor_by(std::move(dil_output)); +} + +static at::Tensor& dil_convolution_inplace_fusion( + const at::Tensor& input, + const at::Tensor& weight, + const at::Tensor& bias, + at::Tensor& accumu, + at::IntArrayRef stride, + at::IntArrayRef padding, + at::IntArrayRef dilation, + int64_t groups, + const dil::attr_t& attr) { + dil::tensor dil_input; + dil::tensor dil_weight; + dil::tensor dil_output; + c10::optional dil_bias{c10::nullopt}; + + auto input_contiguous = input.contiguous(); + auto weight_contiguous = weight.contiguous(); + auto output_contiguous = accumu.contiguous(); + + dil_input = dbl::comm::try_gen_dil_tensor(input_contiguous); + dil_weight = dbl::comm::try_gen_dil_tensor(weight_contiguous); + dil_output = dbl::comm::try_gen_dil_tensor(output_contiguous); + if (bias.defined()) { + auto bias_contiguous = bias.contiguous(); + dil_bias = dbl::comm::try_gen_dil_tensor(bias_contiguous); + } + + dbl::conv::conv2d_inplace_impl( + dil_input, + dil_weight, + dil_bias, + dil_output, + padding, + stride, + dilation, + groups, + attr); + + dbl::comm::sync_shape_from_dil_to_aten(accumu, dil_output); + return accumu; +} + +at::Tensor& AtenIpexJITDev::dil_convolution_sum( + const at::Tensor & input, + const at::Tensor & weight, + const at::Tensor & bias, + at::IntArrayRef stride, + at::IntArrayRef padding, + at::IntArrayRef dilation, + int64_t groups, + at::Tensor& accumu, + at::Scalar alpha) { + auto scale = alpha.to(); + return dil_convolution_inplace_fusion(input, weight, bias, accumu, stride, padding, + dilation, groups, dil::attr_t::fuse_sum(scale)); +} + +at::Tensor& AtenIpexJITDev::dil_convolution_sum_relu( + const at::Tensor & input, + const at::Tensor & weight, + const at::Tensor & bias, + at::IntArrayRef stride, + at::IntArrayRef padding, + at::IntArrayRef dilation, + int64_t groups, + at::Tensor& accumu, + at::Scalar alpha) { + auto scale = alpha.to(); + return dil_convolution_inplace_fusion(input, weight, bias, accumu, stride, padding, + dilation, groups, dil::attr_t::residual(scale)); +} + +} // namespace cpu +} // namespace torch_ipex diff --git a/torch_ipex/csrc/cpu/FusionOPs.h b/torch_ipex/csrc/cpu/FusionOPs.h new file mode 100644 index 000000000..dcab1ea66 --- /dev/null +++ b/torch_ipex/csrc/cpu/FusionOPs.h @@ -0,0 +1,38 @@ +#pragma once + +#include + +#include + +#include "dil/dil.hpp" + +namespace torch { namespace jit { + +// XXX: PyTorch does not support nesting namespace +// And the alias analysis is not working for namespace other than aten ... +// So we fake some op namespaces to workaround that. 
+namespace ipex { + static auto conv2d_relu = Symbol::fromQualString("ipex::conv2d_relu"); + static auto conv2d_sum = Symbol::fromQualString("ipex::conv2d_sum"); + static auto conv2d_relu_sum = Symbol::fromQualString("ipex::conv2d_relu_sum"); + static auto conv2d_sum_relu = Symbol::fromQualString("ipex::conv2d_sum_relu"); +} + +}} // namespace torch::jit + +namespace torch_ipex { +namespace cpu { + +class AtenIpexJITDev { + public: + // for JIT ops + static at::Tensor dil_convolution_relu(const at::Tensor & input, const at::Tensor & weight, const at::Tensor & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, int64_t groups); + + static at::Tensor& dil_convolution_sum(const at::Tensor& input, const at::Tensor& weight, const at::Tensor& bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, int64_t groups, at::Tensor& accumu, at::Scalar alpha); + + static at::Tensor& dil_convolution_sum_relu( const at::Tensor& input, const at::Tensor& weight, const at::Tensor& bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, int64_t groups, at::Tensor& accumu, at::Scalar alpha); + +}; + +} // namespace cpu +} // namespace torch_ipex diff --git a/torch_ipex/csrc/cpu/RegisterOps.cpp b/torch_ipex/csrc/cpu/RegisterOps.cpp new file mode 100644 index 000000000..694d0b9de --- /dev/null +++ b/torch_ipex/csrc/cpu/RegisterOps.cpp @@ -0,0 +1,14 @@ +#include +#include "ExtendOPs.h" + +static auto registry = + torch::RegisterOperators() + .op("torch_ipex::linear", &torch_ipex::AtenIpexTypeExt::linear) + .op("torch_ipex::max_pool2d", [](const at::Tensor& self, c10::List kernel_size, + c10::List stride, c10::List padding, c10::List dilation, bool ceil_mode=false){ + return torch_ipex::AtenIpexTypeExt::max_pooling(self, kernel_size.vec(), stride.vec(), padding.vec(), dilation.vec(), ceil_mode); + }) + .op("torch_ipex::adaptive_avg_pool2d", [](const at::Tensor&self, c10::List output_size) { + return torch_ipex::AtenIpexTypeExt::adaptive_avg_pool2d(self, output_size.vec()); + }); + diff --git a/torch_ipex/csrc/cpu/dbl/Common.cpp b/torch_ipex/csrc/cpu/dbl/Common.cpp index 3be05955d..13cabe94e 100644 --- a/torch_ipex/csrc/cpu/dbl/Common.cpp +++ b/torch_ipex/csrc/cpu/dbl/Common.cpp @@ -99,7 +99,6 @@ void sync_shape_from_dil_to_aten(const at::Tensor& ipex_tensor, const dil::tenso TORCH_INTERNAL_ASSERT_DEBUG_ONLY(sizes.size() != 1 || sizes[0] != 0); ipex_tensor.unsafeGetTensorImpl()->set_sizes_contiguous(sizes); } - } } // namespace comm diff --git a/torch_ipex/csrc/cpu/dbl/Conv.cpp b/torch_ipex/csrc/cpu/dbl/Conv.cpp index b8576e669..c3fce71ca 100644 --- a/torch_ipex/csrc/cpu/dbl/Conv.cpp +++ b/torch_ipex/csrc/cpu/dbl/Conv.cpp @@ -31,7 +31,8 @@ dil::tensor conv2d_impl( at::IntArrayRef padding, at::IntArrayRef stride, at::IntArrayRef dilation, - int64_t groups) { + int64_t groups, + const dil::attr_t& attr) { std::vector kernel_size(x.ndims()); // mkldnn conv2d weights could have been re-ordered to 5d by // mkldnn_reorder_conv2d_weight @@ -61,7 +62,11 @@ dil::tensor conv2d_impl( {dilation.begin(), dilation.end()}, {padding.begin(), padding.end()}, {padding.begin(), padding.end()}, - groups); + groups, + dil::scale_t(), + dil::scale_t(), + dil::scale_t(), + attr); } else { dil::convolution_forward::compute( x, @@ -72,11 +77,76 @@ dil::tensor conv2d_impl( {dilation.begin(), dilation.end()}, {padding.begin(), padding.end()}, {padding.begin(), padding.end()}, - groups); + groups, + dil::scale_t(), + dil::scale_t(), + dil::scale_t(), + attr); } 
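Registration through torch::RegisterOperators is what makes the Python shims above resolve: each registered name becomes callable as torch.ops.torch_ipex.<name> and is visible to TorchScript. For instance (shapes illustrative; assumes the extension is importable):

    import torch
    import intel_pytorch_extension

    x = torch.rand(1, 3, 32, 32).to('dpcpp')
    y = torch.ops.torch_ipex.max_pool2d(x, [3, 3], [2, 2], [1, 1], [1, 1], False)
    z = torch.ops.torch_ipex.adaptive_avg_pool2d(x, [7, 7])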
diff --git a/torch_ipex/csrc/cpu/dbl/Common.cpp b/torch_ipex/csrc/cpu/dbl/Common.cpp
index 3be05955d..13cabe94e 100644
--- a/torch_ipex/csrc/cpu/dbl/Common.cpp
+++ b/torch_ipex/csrc/cpu/dbl/Common.cpp
@@ -99,7 +99,6 @@ void sync_shape_from_dil_to_aten(const at::Tensor& ipex_tensor, const dil::tensor
     TORCH_INTERNAL_ASSERT_DEBUG_ONLY(sizes.size() != 1 || sizes[0] != 0);
     ipex_tensor.unsafeGetTensorImpl()->set_sizes_contiguous(sizes);
   }
-
 }
 
 } // namespace comm
diff --git a/torch_ipex/csrc/cpu/dbl/Conv.cpp b/torch_ipex/csrc/cpu/dbl/Conv.cpp
index b8576e669..c3fce71ca 100644
--- a/torch_ipex/csrc/cpu/dbl/Conv.cpp
+++ b/torch_ipex/csrc/cpu/dbl/Conv.cpp
@@ -31,7 +31,8 @@ dil::tensor conv2d_impl(
     at::IntArrayRef padding,
     at::IntArrayRef stride,
     at::IntArrayRef dilation,
-    int64_t groups) {
+    int64_t groups,
+    const dil::attr_t& attr) {
   std::vector<int64_t> kernel_size(x.ndims());
   // mkldnn conv2d weights could have been re-ordered to 5d by
   // mkldnn_reorder_conv2d_weight
@@ -61,7 +62,11 @@ dil::tensor conv2d_impl(
         {dilation.begin(), dilation.end()},
         {padding.begin(), padding.end()},
         {padding.begin(), padding.end()},
-        groups);
+        groups,
+        dil::scale_t(),
+        dil::scale_t(),
+        dil::scale_t(),
+        attr);
   } else {
     dil::convolution_forward::compute(
         x,
@@ -72,11 +77,76 @@ dil::tensor conv2d_impl(
         {dilation.begin(), dilation.end()},
         {padding.begin(), padding.end()},
         {padding.begin(), padding.end()},
-        groups);
+        groups,
+        dil::scale_t(),
+        dil::scale_t(),
+        dil::scale_t(),
+        attr);
   }
 
   return y;
 }
 
+void conv2d_inplace_impl(
+    const dil::tensor& x,
+    const dil::tensor& w,
+    const c10::optional<dil::tensor>& b,
+    dil::tensor& y,
+    at::IntArrayRef padding,
+    at::IntArrayRef stride,
+    at::IntArrayRef dilation,
+    int64_t groups,
+    const dil::attr_t& attr) {
+  std::vector<int64_t> kernel_size(x.ndims());
+  // mkldnn conv2d weights could have been re-ordered to 5d by
+  // mkldnn_reorder_conv2d_weight
+  if (w.ndims() == x.ndims() + 1) {
+    AT_ASSERTM(
+        groups > 1,
+        "Only group _mkldnn_conv2d weights could have been reordered to 5d");
+    kernel_size[0] = w.get_dim(0) * w.get_dim(1);
+    std::copy_n(w.get_dims().cbegin() + 2, x.ndims() - 1, kernel_size.begin() + 1);
+  } else {
+    std::copy_n(w.get_dims().cbegin(), x.ndims(), kernel_size.begin());
+  }
+
+  const dil::dims x_dims = x.get_dims();
+  std::vector<int64_t> input_size{x_dims.cbegin(), x_dims.cend()};
+  std::vector<int64_t> output_sizes = calc_conv_output_size(input_size, kernel_size, padding, stride, dilation);
+
+  if (b.has_value()) {
+    dil::convolution_forward::compute(
+        x,
+        w,
+        b.value(),
+        {output_sizes.cbegin(), output_sizes.cend()},
+        y,
+        {stride.begin(), stride.end()},
+        {dilation.begin(), dilation.end()},
+        {padding.begin(), padding.end()},
+        {padding.begin(), padding.end()},
+        groups,
+        dil::scale_t(),
+        dil::scale_t(),
+        dil::scale_t(),
+        attr);
+  } else {
+    dil::convolution_forward::compute(
+        x,
+        w,
+        {output_sizes.cbegin(), output_sizes.cend()},
+        y,
+        {stride.begin(), stride.end()},
+        {dilation.begin(), dilation.end()},
+        {padding.begin(), padding.end()},
+        {padding.begin(), padding.end()},
+        groups,
+        dil::scale_t(),
+        dil::scale_t(),
+        dil::scale_t(),
+        attr);
+  }
+}
+
 } // namespace conv
 } // namespace dbl
 } // namespace cpu
diff --git a/torch_ipex/csrc/cpu/dbl/Conv.h b/torch_ipex/csrc/cpu/dbl/Conv.h
index 224551ca4..5f954f330 100644
--- a/torch_ipex/csrc/cpu/dbl/Conv.h
+++ b/torch_ipex/csrc/cpu/dbl/Conv.h
@@ -25,7 +25,19 @@ dil::tensor conv2d_impl(
     at::IntArrayRef padding,
     at::IntArrayRef stride,
     at::IntArrayRef dilation,
-    int64_t groups);
+    int64_t groups,
+    const dil::attr_t& attr = dil::attr_t());
+
+void conv2d_inplace_impl(
+    const dil::tensor& x,
+    const dil::tensor& w,
+    const c10::optional<dil::tensor>& b,
+    dil::tensor& y,
+    at::IntArrayRef padding,
+    at::IntArrayRef stride,
+    at::IntArrayRef dilation,
+    int64_t groups,
+    const dil::attr_t& attr = dil::attr_t());
 
 } // namespace conv
 } // namespace dbl
diff --git a/torch_ipex/csrc/init_python_bindings.cpp b/torch_ipex/csrc/init_python_bindings.cpp
index b50eca837..33066cd5a 100644
--- a/torch_ipex/csrc/init_python_bindings.cpp
+++ b/torch_ipex/csrc/init_python_bindings.cpp
@@ -5,6 +5,12 @@
 #include
 #include
 
+#include
+#include
+#include
+#include
+#include "jit/fusion_pass.h"
+
 #include
 #include
 #include
@@ -88,7 +94,7 @@ void InitIpexModuleBindings(py::module m) {
   });
   m.def("linear",
-        [](const at::Tensor& input, const at::Tensor& weight, const c10::optional<at::Tensor>& bias) {
+        [](const at::Tensor& input, const at::Tensor& weight, const at::Tensor& bias) {
           return AtenIpexTypeExt::linear(input, weight, bias);
         });
   m.def("linear_fuse_relu",
@@ -128,15 +134,26 @@ void InitIpexModuleBindings(py::module m) {
   m.def("mlp_create_handle", &AtenIpexTypeMLPExt::create_handle);
   m.def("mlp_set_relu_mask", &AtenIpexTypeMLPExt::set_relu_mask);
   m.def("mlp_release_handle", &AtenIpexTypeMLPExt::release_handle);
-
   m.def("is_dil_tensor", &isDilTensor);
   m.def("get_dil_tensor_sizes", &getDilTensorSizes);
   m.def("get_dil_tensor_strides", &getDilTensorStrides);
+  m.def("enable_jit", []() { AutoOptConfig::singleton().set_jit_fuse(true); });
+  m.def("disable_jit", []() { AutoOptConfig::singleton().set_jit_fuse(false); });
+  m.def("get_jit", []() { return AutoOptConfig::singleton().get_jit_fuse(); });
 }
 
 } // namespace
 
-void InitIpexBindings(py::module m) { InitIpexModuleBindings(m); }
+using namespace torch::jit;
+
+void InitIpexBindings(py::module m) {
+  InitIpexModuleBindings(m);
+  // jit fusion pass
+  RegisterPass pass([](std::shared_ptr<Graph>& g) {
+    if (AutoOptConfig::singleton().get_jit_fuse()) {
+      torch::jit::FusionPass(g);
+    }
+  });
+}
 
 } // namespace torch_ipex
diff --git a/torch_ipex/csrc/jit/CMakeLists.txt b/torch_ipex/csrc/jit/CMakeLists.txt
new file mode 100644
index 000000000..3f313b336
--- /dev/null
+++ b/torch_ipex/csrc/jit/CMakeLists.txt
@@ -0,0 +1,8 @@
+LIST(APPEND DPCPP_JIT_SRCS
+  ${DPCPP_ROOT}/jit/fusion_pass.cpp
+  ${DPCPP_ROOT}/jit/register_dnnl_jit_ops.cpp
+
+)
+
+# Pass to parent
+set(DPCPP_JIT_SRCS ${DPCPP_JIT_SRCS} PARENT_SCOPE)
diff --git a/torch_ipex/csrc/jit/fusion_pass.cpp b/torch_ipex/csrc/jit/fusion_pass.cpp
index 62ca9c86c..2661c7844 100644
--- a/torch_ipex/csrc/jit/fusion_pass.cpp
+++ b/torch_ipex/csrc/jit/fusion_pass.cpp
@@ -1,7 +1,8 @@
 #include
-#include "graph_ext.h"
 #include "fusion_pass.h"
-#include "accelerated_ops.h"
+
+#include "cpu/FusionOPs.h"
+
 #include
 #include
 #include
@@ -80,84 +81,38 @@ class OpFuser {
     aliasDb_ = std::make_unique<AliasDb>(graph_);
   }
 
-  Node* fuseNodes(Node *curr, Value *path, Rule rule) {
-    return fuseOpsWithNewKind(curr, path, curr->owningGraph(), rule->second);
-  }
-
-  //
-  // currently we only have to fold conv2d + batch_norm
-  //
-  bool isFoldable(Node* node, Node* prev) {
-    bool foldable = (node->kind() == dnnl::batch_norm
-        && prev->kind() == dnnl::conv2d);
-
-    //
-    // Check whether all the sources are constant ???
-    // Does performance improve no matter we do it pre-compiling or runtime?
-    //
-    auto* conv2d = reinterpret_cast<NodeExt*>(prev)->cast<Conv2dNode>();
-    auto* batch_norm = reinterpret_cast<NodeExt*>(node)->cast<BatchNorm2dNode>();
-
-    foldable = foldable
-      && conv2d->hasConstantParams()
-      && batch_norm->hasConstantParams();
-    return foldable;
-  }
-
-  Node* foldNodes(Node *conv2d, Node *batch_norm) {
-    // Change weight/bias source
-    auto* fold_weight = createBatchNormFoldWeight(conv2d, batch_norm);
-    fold_weight->insertBefore(conv2d);
-    conv2d->replaceInput(1, fold_weight->output());
+  Node* fuseOpsWithNewKind(Node *curr, Value *v, Graph *g, NodeKind kind) {
+    auto newNode = g->create(kind);
+    auto prev = v->node();
+    newNode->insertBefore(prev);
+    newNode->setScope(prev->scope());
+    newNode->copyAttributes(*prev);
 
-    auto* fold_bias = createBatchNormFoldBias(conv2d, batch_norm);
-    fold_bias->insertBefore(conv2d);
-    conv2d->replaceInput(2, fold_bias->output());
+    for (auto input : prev->inputs()) {
+      newNode->addInput(input);
+    }
 
-    batch_norm->replaceAllUsesWith(conv2d);
-    batch_norm->destroy();
-    return conv2d;
-  }
+    for (auto input : curr->inputs()) {
+      if (input != v) {
+        newNode->addInput(input);
+      }
+    }
 
-  Node* createBatchNormFoldWeight(Node *conv2d, Node *batch_norm) {
-    auto g = conv2d->owningGraph();
-    auto newNode = g->create(dnnl::fold_weight);
-    newNode->setScope(conv2d->scope());
+    // Copy output metadata from prev, whose output the fused node replaces.
+    newNode->output()->copyMetadata(prev->output());
+    newNode->output()->setType(prev->output()->type());
 
-    // We need following parameters
-    newNode->addInput(conv2d->input(1)); // Conv2d weights
-    newNode->addInput(batch_norm->input(1)); // Batch norm weights
-    newNode->addInput(batch_norm->input(4)); // running_var (delta)
-    newNode->addInput(batch_norm->input(7)); // eps
+    v->replaceAllUsesWith(newNode->output());
+    curr->replaceAllUsesWith(newNode);
 
-    // We get meta and type from conv2d weight value
-    newNode->output()->copyMetadata(conv2d->input(1));
-    newNode->output()->setType(conv2d->input(1)->type());
-    newNode->output()->setDebugName(conv2d->input(1)->debugName() + ".bn_folded");
+    prev->destroy();
+    curr->destroy();
     return newNode;
   }
 
-  Node* createBatchNormFoldBias(Node *conv2d, Node *batch_norm) {
-    auto g = conv2d->owningGraph();
-    auto newNode = g->create(dnnl::fold_bias);
-    newNode->setScope(conv2d->scope());
-
-    // We need following information
-    newNode->addInput(conv2d->input(1)); // Conv weight
-    newNode->addInput(conv2d->input(2)); // Conv bias
-    newNode->addInput(batch_norm->input(1)); // batch norm weight
-    newNode->addInput(batch_norm->input(2)); // batch norm bias
-    newNode->addInput(batch_norm->input(3)); // running_mean (mu)
-    newNode->addInput(batch_norm->input(4)); // running_var (delta)
-    newNode->addInput(batch_norm->input(7)); // eps
-
-    // We get meta and type from conv2d bias value
-    newNode->output()->copyMetadata(conv2d->input(2));
-    newNode->output()->setType(conv2d->input(2)->type());
-    newNode->output()->setDebugName(conv2d->input(2)->debugName() + ".bn_folded");
-
-    return newNode;
+  Node* fuseNodes(Node *curr, Value *path, Rule rule) {
+    return fuseOpsWithNewKind(curr, path, curr->owningGraph(), rule->second);
   }
 
   bool aliasIsSafeForSquashingValue(Node *node, Value *v) {
@@ -198,7 +153,7 @@ class OpFuser {
     }
 
     // throw
-    auto er = script::ErrorReport(node->sourceRange());
+    auto er = ErrorReport(node->sourceRange());
     er << "Schema not found for fusion process. \n";
     er << "Prev: " << *prev << "\n";
     er << "Node: " << *node << "\n";
@@ -295,51 +250,39 @@ class OpFuser {
   }
 
   std::pair<graph_node_list::iterator, bool> processNode(Node *node) {
-    auto nodeExt = reinterpret_cast<NodeExt*>(node);
     Node* pos = node;
     bool changed = false;
 
-    if (nodeExt->isDNNLOps()) {
-      //
-      // Check whether we could fuse to one certain value path
-      //
-      for (auto *v : node->inputs()) {
-        auto prev = v->node();
-        auto fuseRule = isFusable(node, prev);
-
-        // We can fuse only one path
-        if (fuseRule && aliasIsSafeForFusion(node, v, fuseRule)) {
-          pos = fuseNodes(node, v, fuseRule.value());
-          changed = true;
-          break;
-        } else if (isFoldable(node, prev)
-            && aliasIsSafeForSquashingValue(node, v)) {
-          pos = foldNodes(prev, node);
-          changed = true;
-          break;
-        }
+    //
+    // Check whether we could fuse to one certain value path
+    //
+    for (auto *v : node->inputs()) {
+      auto prev = v->node();
+      auto fuseRule = isFusable(node, prev);
+
+      // We can fuse only one path
+      if (fuseRule && aliasIsSafeForFusion(node, v, fuseRule)) {
+        pos = fuseNodes(node, v, fuseRule.value());
+        changed = true;
+        break;
       }
     }
     return std::make_pair(++pos->iterator(), changed);
-}
+  }
+};
 
 // TODO: These rules should be more scalable
 OpFuser::RuleTab OpFuser::dnnlRules = {
-  {{dnnl::conv2d, dnnl::relu}, dnnl::conv2d_relu},
-  {{dnnl::conv2d, dnnl::relu_}, dnnl::conv2d_relu},
-  /*
-  {{dnnl::batch_norm, dnnl::relu}, dnnl::batch_norm_relu},
-  {{dnnl::batch_norm, dnnl::relu_}, dnnl::batch_norm_relu},
-  */
-  {{dnnl::conv2d_sum, dnnl::relu}, dnnl::conv2d_sum_relu},
-  {{dnnl::conv2d_sum, dnnl::relu_}, dnnl::conv2d_sum_relu},
-
-  {{dnnl::conv2d, dnnl::sum}, dnnl::conv2d_sum},
-  {{dnnl::conv2d, dnnl::sum_}, dnnl::conv2d_sum},
-  // {{dnnl::conv2d_relu, dnnl::sum}, dnnl::conv2d_relu_sum}
+  {{aten::conv2d, aten::relu}, ipex::conv2d_relu},
+  {{aten::conv2d, Symbol::fromQualString("aten::relu_")}, ipex::conv2d_relu},
+  {{ipex::conv2d_sum, aten::relu}, ipex::conv2d_sum_relu},
+  {{ipex::conv2d_sum, Symbol::fromQualString("aten::relu_")}, ipex::conv2d_sum_relu},
+
+  {{aten::conv2d, aten::add}, ipex::conv2d_sum},
+  {{aten::conv2d, aten::add_}, ipex::conv2d_sum},
+  //{{dnnl::conv2d_relu, aten::add}, dnnl::conv2d_relu_sum}
 };
 
 void FusionPass(std::shared_ptr<Graph> &graph) {
@@ -351,4 +294,5 @@ void FusionPass(std::shared_ptr<Graph> &graph) {
   // TODO: Some post processing?? ECS/EDC/Peephole???
   ConstantPropagation(graph);
 }
+
 }} // namespace torch::jit
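With the rule table now keyed on aten symbols, the pass rewrites plain TorchScript IR instead of pre-converted dnnl nodes. A sketch of the intended rewrite for the conv+relu rule (IR shown as comments; exact value names will differ per graph):

    # before the pass:
    #   %y = aten::conv2d(%x, %w, %b, %stride, %padding, %dilation, %groups)
    #   %z = aten::relu_(%y)
    # after OpFuser applies {{aten::conv2d, aten::relu_}, ipex::conv2d_relu}:
    #   %z = ipex::conv2d_relu(%x, %w, %b, %stride, %padding, %dilation, %groups)

A second application of the table can then fold the add+relu tail of a residual block into ipex::conv2d_sum_relu via the intermediate ipex::conv2d_sum rule.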
diff --git a/torch_ipex/csrc/jit/graph_ext.cpp b/torch_ipex/csrc/jit/graph_ext.cpp
index 46b6ef6bf..efbec2cf8 100644
--- a/torch_ipex/csrc/jit/graph_ext.cpp
+++ b/torch_ipex/csrc/jit/graph_ext.cpp
@@ -207,7 +207,6 @@ formatTag Conv2dNode::expectedWeightFormat(
   return desc.get_internal_format();
 }
 
-
 void Conv2dNode::fixWeightFormatIfPossible() {
   if (couldInferFormats()) {
     auto tensor = toIValue(this->input(1))->toTensor();
diff --git a/torch_ipex/csrc/jit/register_dnnl_jit_ops.cpp b/torch_ipex/csrc/jit/register_dnnl_jit_ops.cpp
index 487e2711d..2d5102b77 100644
--- a/torch_ipex/csrc/jit/register_dnnl_jit_ops.cpp
+++ b/torch_ipex/csrc/jit/register_dnnl_jit_ops.cpp
@@ -1,16 +1,17 @@
-#include "torch/csrc/jit/runtime/operator.h"
-#include "torch/csrc/jit/runtime/custom_operator.h"
-#include "accelerated_ops.h"
-#include "graph_ext.h"
-#include "dnnl_ops.h"
+#include
+
+#include
+#include
+
+#include "torch_ipex/csrc/utils.h"
+#include "cpu/FusionOPs.h"
+
 
 namespace torch {
 namespace jit {
 
-c10::OperatorOptions aliasAnalysisFromSchema() {
-  c10::OperatorOptions result;
-  result.setAliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA);
-  return result;
+c10::AliasAnalysisKind aliasAnalysisFromSchema() {
+  return c10::AliasAnalysisKind::FROM_SCHEMA;
 }
 
 at::Tensor toOptionalTensor(const IValue& v) {
@@ -20,260 +21,99 @@ at::Tensor toOptionalTensor(const IValue& v) {
   return v.toTensor();
 }
 
-using namespace at::native;
+using namespace torch_ipex::cpu;
 
 RegisterOperators op({
     Operator(
-      "dnnl::reorder(Tensor self) -> Tensor",
-      [](const Node* node) -> Operation {
-        return [node] (Stack& stack) {
-          auto* enode = reinterpret_cast<const NodeExt*>(node);
-          auto from = enode->inputFormat(0);
-          auto to = enode->inputFormat(1);
-          auto groups = enode->getGroupInfo();
-
-          auto result = dnnl_reorder(
-              (std::move(peek(stack, 0, 1))).toTensor(), from, to, groups);
-          drop(stack, 1);
-          pack(stack, std::move(result));
-          return 0;
-        };
-      },
-      aliasAnalysisFromSchema()
-      ),
-    Operator(
-      "dnnl::relu(Tensor self) -> Tensor",
-      [](const Node* node) -> Operation {
-        return [] (Stack& stack) {
-          auto result = dnnl_relu(
-              (std::move(peek(stack, 0, 1))).toTensor());
-          drop(stack, 1);
-          pack(stack, std::move(result));
-          return 0;
-        };
-      },
-      aliasAnalysisFromSchema()
-      ),
-    Operator(
-      "dnnl::relu_(Tensor(a!) self) -> Tensor(a!)",
-      [] (const Node* node) -> Operation {
-        return [] (Stack& stack) {
-          at::Tensor input;
-          pop(stack, input);
-          auto result = dnnl_relu_(input);
-          push(stack, std::move(result));
-          return 0;
-        };
-      },
-      aliasAnalysisFromSchema()
-      ),
-    Operator(
-      "dnnl::conv2d(Tensor input, Tensor weight, Tensor? bias=None, int[2] stride=1, int[2] padding=0, int[2] dilation=1, int groups=1) -> Tensor",
-      [] (const Node* node) -> Operation {
-        return [] (Stack& stack) {
-          auto result = dnnl_conv2d(
-              (std::move(peek(stack, 0, 7))).toTensor(),
-              (std::move(peek(stack, 1, 7))).toTensor(),
-              toOptionalTensor(std::move(peek(stack, 2, 7))),
-              (std::move(peek(stack, 3, 7))).toIntVector(),
-              (std::move(peek(stack, 4, 7))).toIntVector(),
-              (std::move(peek(stack, 5, 7))).toIntVector(),
-              (std::move(peek(stack, 6, 7))).toInt());
-          drop(stack, 7);
-          pack(stack, std::move(result));
-          return 0;
-        };
-      },
-      aliasAnalysisFromSchema()
-      ),
-    Operator(
-      "dnnl::conv2d_relu(Tensor input, Tensor weight, Tensor? bias=None, int[2] stride=1, int[2] padding=0, int[2] dilation=1, int groups=1) -> Tensor",
-      [] (const Node* node) ->Operation {
-        return [] (Stack& stack) {
-          auto result = dnnl_conv2d_relu(
-              (std::move(peek(stack, 0, 7))).toTensor(),
-              (std::move(peek(stack, 1, 7))).toTensor(),
-              toOptionalTensor(std::move(peek(stack, 2, 7))),
-              (std::move(peek(stack, 3, 7))).toIntVector(),
-              (std::move(peek(stack, 4, 7))).toIntVector(),
-              (std::move(peek(stack, 5, 7))).toIntVector(),
-              (std::move(peek(stack, 6, 7))).toInt());
-          drop(stack, 7);
-          pack(stack, std::move(result));
-          return 0;
-        };
-      },
-      aliasAnalysisFromSchema()
-      ),
-    Operator(
-      "dnnl::batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor",
+      "ipex::conv2d_relu(Tensor input, Tensor weight, Tensor? bias=None, int[2] stride=1, int[2] padding=0, int[2] dilation=1, int groups=1) -> Tensor",
       [] (const Node* node) ->Operation {
-        return [] (Stack& stack) {
-          auto result = dnnl_batch_norm(
-              (std::move(peek(stack, 0, 9))).toTensor(),
-              toOptionalTensor(std::move(peek(stack, 1, 9))),
-              toOptionalTensor(std::move(peek(stack, 2, 9))),
-              toOptionalTensor(std::move(peek(stack, 3, 9))),
-              toOptionalTensor(std::move(peek(stack, 4, 9))),
-              (std::move(peek(stack, 5, 9))).toBool(),
-              (std::move(peek(stack, 6, 9))).toDouble(),
-              (std::move(peek(stack, 7, 9))).toDouble(),
-              (std::move(peek(stack, 8, 9))).toBool());
-          drop(stack, 9);
-          pack(stack, std::move(result));
-          return 0;
-        };
-      },
-      aliasAnalysisFromSchema()
-      ),
-    Operator(
-      "dnnl::fold_weight(Tensor weight, Tensor? bn_weight, Tensor? running_var, float eps) -> Tensor",
-      [] (const Node* node) -> Operation {
-        return [] (Stack& stack) {
-          auto result = dnnl_fold_weight(
-              (std::move(peek(stack, 0, 4))).toTensor(),
-              toOptionalTensor(std::move(peek(stack, 1, 4))),
-              toOptionalTensor(std::move(peek(stack, 2, 4))),
-              (std::move(peek(stack, 3, 4))).toDouble());
-          drop(stack, 4);
-          pack(stack, std::move(result));
-          return 0;
-        };
+      if (torch_ipex::check_auto_dnnl()) {
+        return [] (Stack& stack) {
+          auto result = AtenIpexJITDev::dil_convolution_relu(
+              (std::move(peek(stack, 0, 7))).toTensor(),
+              (std::move(peek(stack, 1, 7))).toTensor(),
+              toOptionalTensor(std::move(peek(stack, 2, 7))),
+              (std::move(peek(stack, 3, 7))).toIntVector(),
+              (std::move(peek(stack, 4, 7))).toIntVector(),
+              (std::move(peek(stack, 5, 7))).toIntVector(),
+              (std::move(peek(stack, 6, 7))).toInt());
+          drop(stack, 7);
+          pack(stack, std::move(result));
+          return 0;
+        };
+      } else {
+        TORCH_CHECK(false, "The PyTorch native path does not support convolution relu fusion yet");
+      }
      },
      aliasAnalysisFromSchema()
      ),
    Operator(
-      "dnnl::fold_bias(Tensor weight, Tensor? bias, Tensor? bn_weight, Tensor? bn_bias, Tensor? running_mean, Tensor? running_var, float eps) -> Tensor",
-      [] (const Node* node) -> Operation{
-        return [] (Stack& stack) {
-          auto result = dnnl_fold_bias(
-              (std::move(peek(stack, 0, 7))).toTensor(),
-              toOptionalTensor(std::move(peek(stack, 1, 7))),
-              toOptionalTensor(std::move(peek(stack, 2, 7))),
-              toOptionalTensor(std::move(peek(stack, 3, 7))),
-              toOptionalTensor(std::move(peek(stack, 4, 7))),
-              toOptionalTensor(std::move(peek(stack, 5, 7))),
-              (std::move(peek(stack, 6, 7))).toDouble());
-          drop(stack, 7);
-          pack(stack, std::move(result));
-          return 0;
-        };
-      },
-      aliasAnalysisFromSchema()
-      ),
-    Operator(
-      "dnnl::sum(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor",
+      "ipex::conv2d_sum(Tensor input, Tensor weight, Tensor? bias, int[2] stride, int[2] padding, int[2] dilation, int groups, Tensor(a!) accumu, *, Scalar alpha) -> Tensor(a!)",
       [] (const Node* node) ->Operation {
-        return [] (Stack& stack) {
-          auto result = dnnl_sum(
-              (std::move(peek(stack, 0, 3))).toTensor(),
-              (std::move(peek(stack, 1, 3))).toTensor(),
-              (std::move(peek(stack, 2, 3))).toScalar()
-          );
-          drop(stack, 3);
-          pack(stack, std::move(result));
-          return 0;
-        };
+      if (torch_ipex::check_auto_dnnl()) {
+        return [] (Stack& stack) {
+          auto output = (std::move(peek(stack, 7, 9))).toTensor();
+          auto result = AtenIpexJITDev::dil_convolution_sum(
+              (std::move(peek(stack, 0, 9))).toTensor(),
+              (std::move(peek(stack, 1, 9))).toTensor(),
+              toOptionalTensor(std::move(peek(stack, 2, 9))),
+              (std::move(peek(stack, 3, 9))).toIntVector(),
+              (std::move(peek(stack, 4, 9))).toIntVector(),
+              (std::move(peek(stack, 5, 9))).toIntVector(),
+              (std::move(peek(stack, 6, 9))).toInt(),
+              output,
+              (std::move(peek(stack, 8, 9))).toScalar()
+          );
+          drop(stack, 9);
+          pack(stack, std::move(result));
+          return 0;
+        };
+      } else {
+        TORCH_CHECK(false, "The PyTorch native path does not support convolution sum fusion yet");
+      }
      },
      aliasAnalysisFromSchema()
      ),
    Operator(
-      "dnnl::sum_(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)",
-      [] (const Node* node) ->Operation{
-        return [](Stack &stack) {
-          auto self = (std::move(peek(stack, 0, 3))).toTensor();
-          auto result = dnnl_sum_(
-              self,
-              (std::move(peek(stack, 1, 3))).toTensor(),
-              (std::move(peek(stack, 2, 3))).toScalar());
-          drop(stack, 3);
-          pack(stack, std::move(result));
-          return 0;
-        };
-      },
-      aliasAnalysisFromSchema()
-      ),
-    Operator(
-      "dnnl::conv2d_sum(Tensor input, Tensor weight, Tensor? bias=None, int[2] stride=1, int[2] padding=0, int[2] dilation=1, int groups=1, Tensor(a!) accumu, *, Scalar alpha=1) -> Tensor(a!)",
+      "ipex::conv2d_sum_relu(Tensor input, Tensor weight, Tensor? bias, int[2] stride, int[2] padding, int[2] dilation, int groups, Tensor(a!) accumu, *, Scalar alpha) -> Tensor(a!)",
       [] (const Node* node) ->Operation {
-        return [] (Stack& stack) {
-          auto output = (std::move(peek(stack, 7, 9))).toTensor();
-          auto result = dnnl_conv2d_sum(
-              (std::move(peek(stack, 0, 9))).toTensor(),
-              (std::move(peek(stack, 1, 9))).toTensor(),
-              toOptionalTensor(std::move(peek(stack, 2, 9))),
-              (std::move(peek(stack, 3, 9))).toIntVector(),
-              (std::move(peek(stack, 4, 9))).toIntVector(),
-              (std::move(peek(stack, 5, 9))).toIntVector(),
-              (std::move(peek(stack, 6, 9))).toInt(),
-              output,
-              (std::move(peek(stack, 8, 9))).toScalar()
-          );
-          drop(stack, 9);
-          pack(stack, std::move(result));
-          return 0;
-        };
+      if (torch_ipex::check_auto_dnnl()) {
+        return [] (Stack& stack) {
+          auto output = (std::move(peek(stack, 7, 9))).toTensor();
+          auto result = AtenIpexJITDev::dil_convolution_sum_relu(
+              (std::move(peek(stack, 0, 9))).toTensor(),
+              (std::move(peek(stack, 1, 9))).toTensor(),
+              toOptionalTensor(std::move(peek(stack, 2, 9))),
+              (std::move(peek(stack, 3, 9))).toIntVector(),
+              (std::move(peek(stack, 4, 9))).toIntVector(),
+              (std::move(peek(stack, 5, 9))).toIntVector(),
+              (std::move(peek(stack, 6, 9))).toInt(),
+              output,
+              (std::move(peek(stack, 8, 9))).toScalar()
+          );
+          drop(stack, 9);
+          pack(stack, std::move(result));
+          return 0;
+        };
+      } else {
+        TORCH_CHECK(false, "The PyTorch native path does not support convolution sum relu fusion yet");
+      }
      },
      aliasAnalysisFromSchema()
      ),
    Operator(
-      "dnnl::conv2d_sum_relu(Tensor input, Tensor weight, Tensor? bias=None, int[2] stride=1, int[2] padding=0, int[2] dilation=1, int groups=1, Tensor(a!) accumu, *, Scalar alpha=1) -> Tensor(a!)",
+      "ipex::prepack_weight(Tensor input, Tensor weight, Tensor? bias, int[2] stride, int[2] padding, int[2] dilation, int groups) -> Tensor(a!)",
       [] (const Node* node) ->Operation {
-        return [] (Stack& stack) {
-          auto output = (std::move(peek(stack, 7, 9))).toTensor();
-          auto result = dnnl_conv2d_sum_relu(
-              (std::move(peek(stack, 0, 9))).toTensor(),
-              (std::move(peek(stack, 1, 9))).toTensor(),
-              toOptionalTensor(std::move(peek(stack, 2, 9))),
-              (std::move(peek(stack, 3, 9))).toIntVector(),
-              (std::move(peek(stack, 4, 9))).toIntVector(),
-              (std::move(peek(stack, 5, 9))).toIntVector(),
-              (std::move(peek(stack, 6, 9))).toInt(),
-              output,
-              (std::move(peek(stack, 8, 9))).toScalar()
-          );
-          drop(stack, 9);
-          pack(stack, std::move(result));
-          return 0;
-        };
+      if (torch_ipex::check_auto_dnnl()) {
+        return [] (Stack& stack) {
+          return 0;
+        };
+      } else {
+        TORCH_CHECK(false, "The PyTorch native path does not support weight prepacking yet");
+      }
      },
      aliasAnalysisFromSchema()
-      ),
-    Operator(
-      "dnnl::pooling_max_2d(Tensor input, int[2] kernel_size, int[2] stride=1, int[2] padding=0, int[2] dilation=1, bool ceil_mode=0) -> Tensor(a!)",
-      [] (const Node *node) ->Operation {
-        return [] (Stack& stack) {
-          auto result = dnnl_pooling_max_2d(
-              (std::move(peek(stack, 0, 6))).toTensor(),    // Input tensor
-              (std::move(peek(stack, 1, 6))).toIntVector(), // Kernel size
-              (std::move(peek(stack, 2, 6))).toIntVector(), // Stride
-              (std::move(peek(stack, 3, 6))).toIntVector(), // Padding
-              (std::move(peek(stack, 4, 6))).toIntVector(), // Dilation
-              (std::move(peek(stack, 5, 6))).toBool());     // Ceil mode
-          drop(stack, 6);
-          pack(stack, std::move(result));
-          return 0;
-        };
-      },
-      aliasAnalysisFromSchema()
-      ),
-    Operator(
-      "dnnl::pooling_avg_2d(Tensor input, int[2] kernel_size, int[2] stride=1, int[2] padding=0, bool ceil_mode=0, bool count_include_pad=True, int? divisor_override=None) -> Tensor(a!)",
-      [] (const Node *node) ->Operation {
-        return [] (Stack& stack) {
-          auto result = dnnl_pooling_avg_2d(
-              (std::move(peek(stack, 0, 7))).toTensor(),    // Input tensor
-              (std::move(peek(stack, 1, 7))).toIntVector(), // Kernel size
-              (std::move(peek(stack, 2, 7))).toIntVector(), // Stride
-              (std::move(peek(stack, 3, 7))).toIntVector(), // Padding
-              (std::move(peek(stack, 4, 7))).toBool());     // Ceil mode
-          drop(stack, 7);
-          pack(stack, std::move(result));
-          return 0;
-        };
-      },
-      aliasAnalysisFromSchema()
-      ),
+      )
 });
 }
 }
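Taken together, the Python-visible switches added by this change are enable_jit, disable_jit, and get_jit on the pybind module, alongside the fused ops themselves. A final end-to-end sketch (assumes the module builds as _torch_ipex and is re-exported by intel_pytorch_extension; MyConvReluModel is a hypothetical nn.Module):

    import torch
    import intel_pytorch_extension
    from intel_pytorch_extension import core

    core.enable_auto_dnnl()
    core.enable_jit()            # arm the RegisterPass hook
    assert core.get_jit()
    model = torch.jit.script(MyConvReluModel().to('dpcpp').eval())
    with torch.no_grad():
        out = model(torch.rand(1, 3, 224, 224).to('dpcpp'))
    core.disable_jit()           # later graphs compile without fusion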