From 3e499cb308c8f45aa33f60d79e90697e43a5ad27 Mon Sep 17 00:00:00 2001
From: "Zhu, Haozhe"
Date: Mon, 25 May 2020 10:19:59 +0800
Subject: [PATCH 1/2] Enable Linear+ReLU fusion via oneDNN

rename some functions
hide dil from the frontend and reuse the attr argument instead of a fuse_relu flag
remove unused header files since relu's function body moved to DevOPs.cpp
add unit test for the fused linear+relu op
move the unit test to test_lazy_reorder
---
 intel_pytorch_extension_py/ops/__init__.py    |  2 +
 .../ops/linear_fuse_relu.py                   | 55 +++++++++++++++++++
 tests/cpu/test_lazy_reorder.py                | 50 ++++++++++++++++-
 torch_ipex/csrc/cpu/DevOPs.cpp                | 51 +++++++++++++++++
 torch_ipex/csrc/cpu/DevOPs.h                  |  2 +
 torch_ipex/csrc/cpu/ExtendOPs.cpp             | 12 ++++
 torch_ipex/csrc/cpu/ExtendOPs.h               |  2 +
 .../cpu/dil/dil/operators/inner_product.hpp   |  4 +-
 torch_ipex/csrc/init_python_bindings.cpp      |  8 +++
 9 files changed, 183 insertions(+), 3 deletions(-)
 create mode 100644 intel_pytorch_extension_py/ops/linear_fuse_relu.py

diff --git a/intel_pytorch_extension_py/ops/__init__.py b/intel_pytorch_extension_py/ops/__init__.py
index cfb79efbc..d652ed89f 100644
--- a/intel_pytorch_extension_py/ops/__init__.py
+++ b/intel_pytorch_extension_py/ops/__init__.py
@@ -4,3 +4,5 @@
 from .pooling import *
 from .reshape import *
 from .mlp import *
+from .linear_fuse_relu import *
+
diff --git a/intel_pytorch_extension_py/ops/linear_fuse_relu.py b/intel_pytorch_extension_py/ops/linear_fuse_relu.py
new file mode 100644
index 000000000..d15e1b67f
--- /dev/null
+++ b/intel_pytorch_extension_py/ops/linear_fuse_relu.py
@@ -0,0 +1,55 @@
+import torch
+from torch import nn
+from torch.nn.parameter import Parameter
+from torch.nn import init
+from torch.autograd import Function
+import math
+import _torch_ipex as core
+
+class LinearFuseReluFC(Function):
+    @staticmethod
+    def forward(ctx, input, weight, bias):
+        output = core.linear_fuse_relu(input, weight, bias)
+        ctx.save_for_backward(input, weight, bias, output)
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        input, weight, bias, output = ctx.saved_tensors
+        grad_output = grad_output.contiguous()
+        if bias is None:
+            output_mask = (input.requires_grad, weight.requires_grad, 0)
+        else:
+            output_mask = (input.requires_grad, weight.requires_grad, bias.requires_grad)
+        grad_output = core.relu_use_dst_backward(grad_output, output)
+        grad_input, grad_weight, grad_bias = core.linear_backward(input, grad_output, weight, output_mask)
+        return (grad_input, grad_weight, grad_bias)
+
+class LinearFuseRelu(nn.Module):
+    r"""DNNL Linear module that runs the ReLU-fused DNNL kernel"""
+
+    __constants__ = ['bias']
+
+    def __init__(self, in_features, out_features, bias=True):
+        super(LinearFuseRelu, self).__init__()
+        self.in_features = in_features
+        self.out_features = out_features
+        self.weight = Parameter(torch.Tensor(out_features, in_features))
+
+        if bias:
+            self.bias = Parameter(torch.Tensor(out_features))
+        else:
+            self.register_parameter('bias', None)
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
+        if self.bias is not None:
+            bound = 1 / math.sqrt(self.in_features)
+            init.uniform_(self.bias, -bound, bound)
+
+    def forward(self, input):
+        # print(self.weight.shape)
+        output = LinearFuseReluFC.apply(input, self.weight, self.bias)
+        return output
+
diff --git a/tests/cpu/test_lazy_reorder.py b/tests/cpu/test_lazy_reorder.py
index 34a5429c9..8d21a0551 100644
--- a/tests/cpu/test_lazy_reorder.py
+++ b/tests/cpu/test_lazy_reorder.py
@@ -13,7 +13,7 @@
 import torch
 import _torch_ipex as ipex
 ipex._initialize_aten_bindings()
-import intel_pytorch_extension
+import intel_pytorch_extension_py
 import torch.nn as nn
 import torch.backends.cudnn as cudnn
@@ -438,6 +438,54 @@ def test_linear_backward(self):
         y2.backward()
         self.assertEqual(x1.grad, x2.grad)
 
+class TestLinearFuseRelu(TestCase):
+    def test_linear_fuse_relu_forward(self):
+        ipex.enable_auto_dnnl()
+        rand_seed = int(get_rand_seed())
+        print("{} rand seed: {}".format(sys._getframe().f_code.co_name, rand_seed))
+        torch.manual_seed(rand_seed)
+        in_features = torch.randint(3, 10, (1,)).item()
+        out_features = torch.randint(3, 100, (1,)).item()
+        for dtype in [torch.bfloat16, torch.float]:
+            x = torch.randn(3, in_features) * 10
+            x = x.to(dtype).to('dpcpp')
+            for bias in [True, False]:
+                linear = torch.nn.Linear(in_features, out_features, bias=bias).to('dpcpp').to(dtype)
+                relu = torch.nn.ReLU()
+                linear_fuse_relu = intel_pytorch_extension_py.LinearFuseRelu(in_features, out_features, bias=bias)
+                linear_fuse_relu.weight.data = linear.weight.clone()
+                if bias:
+                    linear_fuse_relu.bias.data = linear.bias.clone()
+                self.assertEqual(relu(linear(x)).float(), linear_fuse_relu(x).float())
+
+    def test_linear_fuse_relu_backward(self):
+        ipex.enable_auto_dnnl()
+        rand_seed = int(get_rand_seed())
+        print("{} rand seed: {}".format(sys._getframe().f_code.co_name, rand_seed))
+        torch.manual_seed(rand_seed)
+        in_features = torch.randint(3, 10, (1,)).item()
+        out_features = torch.randint(3, 100, (1,)).item()
+        for dtype in [torch.bfloat16, torch.float]:
+            x = torch.randn(3, in_features) * 10
+            x = x.to(dtype).to('dpcpp')
+            for bias in [True, False]:
+                linear = torch.nn.Linear(in_features, out_features, bias=bias).to('dpcpp').to(dtype)
+                relu = torch.nn.ReLU()
+                linear_fuse_relu = intel_pytorch_extension_py.LinearFuseRelu(in_features, out_features, bias=bias)
+                linear_fuse_relu.weight.data = linear.weight.clone()
+                if bias:
+                    linear_fuse_relu.bias.data = linear.bias.clone()
+                x1 = x.clone().requires_grad_()
+                x2 = x.clone().requires_grad_()
+                y1 = relu(linear(x1).float()).sum()
+                y2 = linear_fuse_relu(x2).sum()
+                y1.backward()
+                y2.backward()
+                self.assertEqual(x1.grad.float(), x2.grad.float())
+                self.assertEqual(linear.weight.grad.float(), linear_fuse_relu.weight.grad.float())
+                if bias:
+                    self.assertEqual(linear.bias.grad.float(), linear_fuse_relu.bias.grad.float())
+
 class TestPool(TestCase):
     def test_avg_pool2d(self):
         ipex.enable_auto_dnnl()
diff --git a/torch_ipex/csrc/cpu/DevOPs.cpp b/torch_ipex/csrc/cpu/DevOPs.cpp
index 9f9c24806..e8cfa5b04 100644
--- a/torch_ipex/csrc/cpu/DevOPs.cpp
+++ b/torch_ipex/csrc/cpu/DevOPs.cpp
@@ -549,6 +549,48 @@ at::Tensor AtenIpexCPUDev::dil_linear(
   return dbl::comm::gen_aten_tensor_by(y);
 }
 
+at::Tensor AtenIpexCPUDev::dil_linear_fuse_relu(
+    const at::Tensor& self,
+    const at::Tensor& weight,
+    const c10::optional<at::Tensor>& bias) {
+  DEBUG("AtenIpexCPUDev::dil_linear_fuse_relu\n");
+  CHECK_DNNL_OP_PRE_COND(self);
+  CHECK_DNNL_OP_PRE_COND(weight);
+  TORCH_CHECK(self.dim() >= 2,
+      "dil_linear_fuse_relu: input needs to have dim at least 2, input dim ", self.dim());
+
+  // reshape first if input dim is greater than 2; the reshape will cost a memory copy.
+  auto self_reshaped = self.dim() > 2 ? self.reshape({-1, self.size(self.dim() - 1)}) : self;
+  const dil::tensor x = dbl::comm::try_gen_dil_tensor(self_reshaped);
+  const dil::tensor w = dbl::comm::try_gen_dil_tensor(weight);
+
+  dil::tensor y;
+  if (bias.has_value()) {
+    at::Tensor bias_vec = bias.value();
+    const dil::tensor b = dbl::comm::try_gen_dil_tensor(bias_vec);
+    dil::inner_product_forward::compute(x, w, b, y,
+        /*src_scales=*/dil::scale_t(),
+        /*weight_scales=*/dil::scale_t(),
+        /*dst_scales=*/dil::scale_t(),
+        /*attr*/dil::attr_t::fuse_relu());
+  } else {
+    dil::inner_product_forward::compute(x, w, y,
+        /*src_scales=*/dil::scale_t(),
+        /*weight_scales=*/dil::scale_t(),
+        /*dst_scales=*/dil::scale_t(),
+        /*attr*/dil::attr_t::fuse_relu());
+  }
+
+  auto input_size = self.sizes();
+  std::vector<int64_t> output_size(input_size.begin(), input_size.end() - 1);
+  output_size.push_back(weight.size(0));
+
+  if (self.dim() > 2) {
+    return dbl::comm::gen_aten_tensor_by(y).reshape(output_size);
+  }
+  return dbl::comm::gen_aten_tensor_by(y);
+}
+
 at::Tensor dil_linear_backward_input(
     at::IntArrayRef input_size, const at::Tensor& grad_output, const at::Tensor& weight){
   DEBUG("AtenIpexCPUDev::dil_linear_backward_input\n");
@@ -978,6 +1020,15 @@ at::Tensor& AtenIpexCPUDev::dil_relu_(at::Tensor& input) {
   return input;
 }
 
+at::Tensor AtenIpexCPUDev::dil_relu_use_dst_for_bwd(const at::Tensor& grad_output, const at::Tensor& output) {
+  const dil::tensor& y = dbl::comm::try_gen_dil_tensor(output);
+  dil::tensor grady = dbl::comm::try_gen_dil_tensor(grad_output);
+  dil::tensor gradx;
+  dil::eltwise_backward::compute(y, grady, gradx,
+      dil::algorithm::eltwise_relu_use_dst_for_bwd, /*alpha*/ 0.0);
+  return dbl::comm::gen_aten_tensor_by(gradx);
+}
+
 at::Tensor AtenIpexCPUDev::dil_threshold_backward(const at::Tensor& grad_output, const at::Tensor& input, at::Scalar threshold) {
   DEBUG("AtenIpexCPUDev::dil_threshold_backward\n");
   CHECK_DNNL_OP_PRE_COND(grad_output);
diff --git a/torch_ipex/csrc/cpu/DevOPs.h b/torch_ipex/csrc/cpu/DevOPs.h
index 4a7b7ee54..7c76873e6 100644
--- a/torch_ipex/csrc/cpu/DevOPs.h
+++ b/torch_ipex/csrc/cpu/DevOPs.h
@@ -39,6 +39,7 @@ class AtenIpexCPUDev {
   static at::Tensor& dil_addbmm_(at::Tensor& self, const at::Tensor& batch1, const at::Tensor& batch2, at::Scalar beta, at::Scalar alpha);
   static at::Tensor& dil_addbmm_out(at::Tensor& result, const at::Tensor &self, const at::Tensor &batch1, const at::Tensor &batch2, at::Scalar beta, at::Scalar alpha);
   static at::Tensor dil_linear(const at::Tensor& self, const at::Tensor& weight, const c10::optional<at::Tensor>& bias);
+  static at::Tensor dil_linear_fuse_relu(const at::Tensor& self, const at::Tensor& weight, const c10::optional<at::Tensor>& bias);
   static std::tuple<at::Tensor, at::Tensor, at::Tensor> dil_linear_backward(const at::Tensor& input, const at::Tensor& grad_output, const at::Tensor& weight, std::array<bool,3> output_mask);
   static at::Tensor dil_dropout(const at::Tensor& self, double ratio, bool train);
   static at::Tensor dil_dropout_backward(const at::Tensor& grady, const at::Tensor& mask, double ratio);
@@ -54,6 +55,7 @@ class AtenIpexCPUDev {
   static at::Tensor dil_adaptive_avg_pool2d_backward(const at::Tensor& grad_output, const at::Tensor& input);
   static at::Tensor dil_relu(const at::Tensor& input);
   static at::Tensor& dil_relu_(at::Tensor& input);
+  static at::Tensor dil_relu_use_dst_for_bwd(const at::Tensor& grad_output, const at::Tensor& output);
   static at::Tensor dil_threshold_backward(const at::Tensor& grad_output, const at::Tensor& input, at::Scalar threshold);
   static at::Tensor dil__softmax(const at::Tensor& self, const int64_t dim, bool half_to_float);
   static at::Tensor dil__softmax_backward_data(const at::Tensor& grad_output, const at::Tensor& output, int64_t dim, const at::Tensor& self);
diff --git a/torch_ipex/csrc/cpu/ExtendOPs.cpp b/torch_ipex/csrc/cpu/ExtendOPs.cpp
index 69d08d1bb..fbc50ad7f 100644
--- a/torch_ipex/csrc/cpu/ExtendOPs.cpp
+++ b/torch_ipex/csrc/cpu/ExtendOPs.cpp
@@ -427,7 +427,13 @@ at::Tensor AtenIpexTypeExt::linear(const at::Tensor& input, const at::Tensor& we
   return cpu::AtenIpexCPUDev::dil_linear(input, weight, bias);
 }
 
+at::Tensor AtenIpexTypeExt::linear_fuse_relu(const at::Tensor& input, const at::Tensor& weight, const c10::optional<at::Tensor>& bias) {
+  RECORD_FUNCTION("linear_fuse_relu", std::vector<c10::IValue>({input, weight, bias}), torch::autograd::Node::peek_at_next_sequence_nr());
+  return cpu::AtenIpexCPUDev::dil_linear_fuse_relu(input, weight, bias);
+}
+
 std::tuple<at::Tensor, at::Tensor, at::Tensor> AtenIpexTypeExt::linear_backward(const at::Tensor& input, const at::Tensor& grad_output, const at::Tensor& weight, std::array<bool,3> output_mask) {
+  RECORD_FUNCTION("linear_backward", std::vector<c10::IValue>({input, grad_output, weight}), torch::autograd::Node::peek_at_next_sequence_nr());
   return cpu::AtenIpexCPUDev::dil_linear_backward(input, grad_output, weight, output_mask);
 }
 
@@ -451,5 +457,11 @@ at::Tensor AtenIpexTypeExt::reshape(const at::Tensor& input, at::IntArrayRef siz
   return cpu::AtenIpexCPUDev::dil_reshape(input, size);
 }
 
+
+at::Tensor AtenIpexTypeExt::relu_use_dst_for_bwd(const at::Tensor& grad_output, const at::Tensor& output) {
+  RECORD_FUNCTION("dil_relu_use_dst_for_bwd", std::vector<c10::IValue>({grad_output, output}), torch::autograd::Node::peek_at_next_sequence_nr());
+  return cpu::AtenIpexCPUDev::dil_relu_use_dst_for_bwd(grad_output, output);
+}
+
 }  // namespace torch_ipex
 
diff --git a/torch_ipex/csrc/cpu/ExtendOPs.h b/torch_ipex/csrc/cpu/ExtendOPs.h
index aa462d9a0..7615d42db 100644
--- a/torch_ipex/csrc/cpu/ExtendOPs.h
+++ b/torch_ipex/csrc/cpu/ExtendOPs.h
@@ -13,7 +13,9 @@ class AtenIpexTypeExt {
   static at::Tensor embedding_bag_forward(const at::Tensor &weights, const at::Tensor &inputs, const at::Tensor &offsets);
   static at::Tensor embedding_bag_backward(const at::Tensor &grad_out, const at::Tensor &weights, const at::Tensor &inputs, const at::Tensor &offsets);
   static at::Tensor linear(const at::Tensor& input, const at::Tensor& weight, const c10::optional<at::Tensor>& bias);
+  static at::Tensor linear_fuse_relu(const at::Tensor& input, const at::Tensor& weight, const c10::optional<at::Tensor>& bias);
   static std::tuple<at::Tensor, at::Tensor, at::Tensor> linear_backward(const at::Tensor& input, const at::Tensor& grad_output, const at::Tensor& weight, std::array<bool,3> output_mask);
+  static at::Tensor relu_use_dst_for_bwd(const at::Tensor& grad_output, const at::Tensor& output);
   static at::Tensor adaptive_avg_pool2d(at::Tensor const& input, at::IntArrayRef output_size);
   static at::Tensor adaptive_avg_pool2d_backward(const at::Tensor& grad_output, const at::Tensor& input);
   static at::Tensor max_pooling(const at::Tensor& input, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode);
diff --git a/torch_ipex/csrc/cpu/dil/dil/operators/inner_product.hpp b/torch_ipex/csrc/cpu/dil/dil/operators/inner_product.hpp
index c57e1c3f0..ad129ca9c 100644
--- a/torch_ipex/csrc/cpu/dil/dil/operators/inner_product.hpp
+++ b/torch_ipex/csrc/cpu/dil/dil/operators/inner_product.hpp
@@ -163,7 +163,7 @@ struct inner_product_forward : public dnnl::inner_product_forward {
         }
       }
     } else {
-       op_attr = attr;
+      op_attr = attr;
       src_desc = {src.get_dims(), data_type::f32, format_tag::any};
       if (src.has_scale()) {
         auto src_scale = src.get_scale();
@@ -233,7 +233,7 @@ struct inner_product_backward_data : public dnnl::inner_product_backward_data {
                       tensor& diff_src,
                       const engine& aengine = engine::cpu_engine()) {
     auto weights_ = weights;
-    if (diff_dst.get_data_type() == data_type::bf16) {
+    if (diff_dst.get_data_type() == data_type::bf16 && weights.get_data_type() != data_type::bf16) {
       weights_.init(weights.get_desc().to_type(data_type::bf16));
       weights_.reorder_from(weights);
     }
diff --git a/torch_ipex/csrc/init_python_bindings.cpp b/torch_ipex/csrc/init_python_bindings.cpp
index 6cb5b7c2f..758e5d930 100644
--- a/torch_ipex/csrc/init_python_bindings.cpp
+++ b/torch_ipex/csrc/init_python_bindings.cpp
@@ -92,10 +92,18 @@ void InitIpexModuleBindings(py::module m) {
         [](const at::Tensor& input, const at::Tensor& weight, const c10::optional<at::Tensor>& bias) {
           return AtenIpexTypeExt::linear(input, weight, bias);
         });
+  m.def("linear_fuse_relu",
+        [](const at::Tensor& input, const at::Tensor& weight, const c10::optional<at::Tensor>& bias) {
+          return AtenIpexTypeExt::linear_fuse_relu(input, weight, bias);
+        });
   m.def("linear_backward",
         [](const at::Tensor& input, const at::Tensor& grad_output, const at::Tensor& weight, std::array<bool,3> output_mask) {
           return AtenIpexTypeExt::linear_backward(input, grad_output, weight, output_mask);
         });
+  m.def("relu_use_dst_backward",
+        [](const at::Tensor& grad_output, const at::Tensor& output) {
+          return AtenIpexTypeExt::relu_use_dst_for_bwd(grad_output, output);
+        });
   m.def("adaptive_avg_pool2d",
         [](at::Tensor const& input, at::IntArrayRef output_size) {
           return AtenIpexTypeExt::adaptive_avg_pool2d(input, output_size);

From 81d80ea4382a92947e1e21b45a4cae0d0366291b Mon Sep 17 00:00:00 2001
From: "Zhu, Haozhe"
Date: Thu, 28 May 2020 14:07:53 +0800
Subject: [PATCH 2/2] remove the _py suffix from the test import

---
 tests/cpu/test_lazy_reorder.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/cpu/test_lazy_reorder.py b/tests/cpu/test_lazy_reorder.py
index 8d21a0551..41d548f1e 100644
--- a/tests/cpu/test_lazy_reorder.py
+++ b/tests/cpu/test_lazy_reorder.py
@@ -13,7 +13,7 @@
 import torch
 import _torch_ipex as ipex
 ipex._initialize_aten_bindings()
-import intel_pytorch_extension_py
+import intel_pytorch_extension
 import torch.nn as nn
 import torch.backends.cudnn as cudnn
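
Usage sketch (not part of the patches): the snippet below is distilled from the TestLinearFuseRelu cases above and shows how the new frontend module is expected to be driven. The 'dpcpp' device string, enable_auto_dnnl(), and the weight/bias copy follow the test code; the tensor sizes are arbitrary.

    import torch
    import _torch_ipex as ipex
    import intel_pytorch_extension_py

    ipex._initialize_aten_bindings()
    ipex.enable_auto_dnnl()

    # Reference Linear on the extension device; the fused module reuses its parameters,
    # mirroring the unit tests.
    linear = torch.nn.Linear(64, 128, bias=True).to('dpcpp')
    fused = intel_pytorch_extension_py.LinearFuseRelu(64, 128, bias=True)
    fused.weight.data = linear.weight.clone()
    fused.bias.data = linear.bias.clone()

    x = torch.randn(4, 64).to('dpcpp')
    y = fused(x)            # relu(linear(x)) in a single oneDNN inner_product call
    y.sum().backward()      # backward runs relu_use_dst_backward, then linear_backward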