Enable Linear+ReLU fuse by OneDNNL #20

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged · 2 commits · May 28, 2020
2 changes: 2 additions & 0 deletions intel_pytorch_extension_py/ops/__init__.py
@@ -4,3 +4,5 @@
from .pooling import *
from .reshape import *
from .mlp import *
from .linear_fuse_relu import *

55 changes: 55 additions & 0 deletions intel_pytorch_extension_py/ops/linear_fuse_relu.py
@@ -0,0 +1,55 @@
import torch
from torch import nn
from torch.nn.parameter import Parameter
from torch.nn import init
from torch.autograd import Function
import math
import _torch_ipex as core

class LinearFuseReluFC(Function):
    @staticmethod
    def forward(ctx, input, weight, bias):
        output = core.linear_fuse_relu(input, weight, bias)
        ctx.save_for_backward(input, weight, bias, output)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, weight, bias, output = ctx.saved_tensors
        grad_output = grad_output.contiguous()
        if bias is None:
            output_mask = (input.requires_grad, weight.requires_grad, False)
        else:
            output_mask = (input.requires_grad, weight.requires_grad, bias.requires_grad)
        # Backprop through the fused ReLU using the saved output, then through the linear op.
        grad_output = core.relu_use_dst_backward(grad_output, output)
        grad_input, grad_weight, grad_bias = core.linear_backward(input, grad_output, weight, output_mask)
        return (grad_input, grad_weight, grad_bias)

class LinearFuseRelu(nn.Module):
    r"""Linear module that runs the DNNL linear kernel with a fused ReLU."""

    __constants__ = ['bias']

    def __init__(self, in_features, out_features, bias=True):
        super(LinearFuseRelu, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.Tensor(out_features, in_features))

        if bias:
            self.bias = Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            bound = 1 / math.sqrt(self.in_features)
            init.uniform_(self.bias, -bound, bound)

    def forward(self, input):
        output = LinearFuseReluFC.apply(input, self.weight, self.bias)
        return output
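
For orientation (not part of the diff), here is a minimal usage sketch of the new module. It mirrors the tests below, so it assumes `ipex.enable_auto_dnnl()`, the `'dpcpp'` device, and `intel_pytorch_extension_py.LinearFuseRelu` are available; the exact import spelling for the `ipex` alias is an assumption.

import torch
import intel_pytorch_extension as ipex   # assumed import; the tests only show the alias `ipex`
import intel_pytorch_extension_py

ipex.enable_auto_dnnl()

in_features, out_features = 16, 32
x = torch.randn(4, in_features).to('dpcpp')

# Reference path: separate Linear and ReLU modules.
linear = torch.nn.Linear(in_features, out_features).to('dpcpp')
relu = torch.nn.ReLU()

# Fused path: copy the same parameters into LinearFuseRelu, one fused kernel call per forward.
fused = intel_pytorch_extension_py.LinearFuseRelu(in_features, out_features)
fused.weight.data = linear.weight.clone()
fused.bias.data = linear.bias.clone()

print(torch.allclose(relu(linear(x)).to('cpu'), fused(x).to('cpu')))  # expected: True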

48 changes: 48 additions & 0 deletions tests/cpu/test_lazy_reorder.py
@@ -438,6 +438,54 @@ def test_linear_backward(self):
        y2.backward()
        self.assertEqual(x1.grad, x2.grad)

class TestLinearFuseRelu(TestCase):
    def test_linear_fuse_relu_forward(self):
        ipex.enable_auto_dnnl()
        rand_seed = int(get_rand_seed())
        print("{} rand seed: {}".format(sys._getframe().f_code.co_name, rand_seed))
        torch.manual_seed(rand_seed)
        in_features = torch.randint(3, 10, (1,)).item()
        out_features = torch.randint(3, 100, (1,)).item()
        for dtype in [torch.bfloat16, torch.float]:
            x = torch.randn(3, in_features) * 10
            x = x.to(dtype).to('dpcpp')
            for bias in [True, False]:
                linear = torch.nn.Linear(in_features, out_features, bias=bias).to('dpcpp').to(dtype)
                relu = torch.nn.ReLU()
                linear_fuse_relu = intel_pytorch_extension_py.LinearFuseRelu(in_features, out_features, bias=bias)
                linear_fuse_relu.weight.data = linear.weight.clone()
                if bias:
                    linear_fuse_relu.bias.data = linear.bias.clone()
                self.assertEqual(relu(linear(x)).float(), linear_fuse_relu(x).float())

    def test_linear_fuse_relu_backward(self):
        ipex.enable_auto_dnnl()
        rand_seed = int(get_rand_seed())
        print("{} rand seed: {}".format(sys._getframe().f_code.co_name, rand_seed))
        torch.manual_seed(rand_seed)
        in_features = torch.randint(3, 10, (1,)).item()
        out_features = torch.randint(3, 100, (1,)).item()
        for dtype in [torch.bfloat16, torch.float]:
            x = torch.randn(3, in_features) * 10
            x = x.to(dtype).to('dpcpp')
            for bias in [True, False]:
                linear = torch.nn.Linear(in_features, out_features, bias=bias).to('dpcpp').to(dtype)
                relu = torch.nn.ReLU()
                linear_fuse_relu = intel_pytorch_extension_py.LinearFuseRelu(in_features, out_features, bias=bias)
                linear_fuse_relu.weight.data = linear.weight.clone()
                if bias:
                    linear_fuse_relu.bias.data = linear.bias.clone()
                x1 = x.clone().requires_grad_()
                x2 = x.clone().requires_grad_()
                y1 = relu(linear(x1).float()).sum()
                y2 = linear_fuse_relu(x2).sum()
                y1.backward()
                y2.backward()
                self.assertEqual(x1.grad.float(), x2.grad.float())
                self.assertEqual(linear.weight.grad.float(), linear_fuse_relu.weight.grad.float())
                if bias:
                    self.assertEqual(linear.bias.grad.float(), linear_fuse_relu.bias.grad.float())

class TestPool(TestCase):
    def test_avg_pool2d(self):
        ipex.enable_auto_dnnl()
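To exercise only the new tests, something along these lines should work, assuming a standard pytest setup is available (the repository may prefer its own test runner):

python -m pytest tests/cpu/test_lazy_reorder.py -k LinearFuseRelu -v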
51 changes: 51 additions & 0 deletions torch_ipex/csrc/cpu/DevOPs.cpp
@@ -549,6 +549,48 @@ at::Tensor AtenIpexCPUDev::dil_linear(
  return dbl::comm::gen_aten_tensor_by(y);
}

at::Tensor AtenIpexCPUDev::dil_linear_fuse_relu(
    const at::Tensor& self,
    const at::Tensor& weight,
    const c10::optional<at::Tensor>& bias) {
  DEBUG("AtenIpexCPUDev::dil_linear_fuse_relu\n");
  CHECK_DNNL_OP_PRE_COND(self);
  CHECK_DNNL_OP_PRE_COND(weight);
  TORCH_CHECK(self.dim() >= 2,
      "dil_linear_fuse_relu: input needs to have dim at least 2, input dim ", self.dim());

  // Reshape first if the input dim is greater than 2; the reshape incurs a memory copy.
  auto self_reshaped = self.dim() > 2 ? self.reshape({-1, self.size(self.dim() - 1)}) : self;
  const dil::tensor x = dbl::comm::try_gen_dil_tensor(self_reshaped);
  const dil::tensor w = dbl::comm::try_gen_dil_tensor(weight);

  dil::tensor y;
  if (bias.has_value()) {
    at::Tensor bias_vec = bias.value();
    const dil::tensor b = dbl::comm::try_gen_dil_tensor(bias_vec);
    dil::inner_product_forward::compute(x, w, b, y,
        /*src_scales=*/dil::scale_t(),
        /*weight_scales=*/dil::scale_t(),
        /*dst_scales=*/dil::scale_t(),
        /*attr*/dil::attr_t::fuse_relu());
  } else {
    dil::inner_product_forward::compute(x, w, y,
        /*src_scales=*/dil::scale_t(),
        /*weight_scales=*/dil::scale_t(),
        /*dst_scales=*/dil::scale_t(),
        /*attr*/dil::attr_t::fuse_relu());
  }

  auto input_size = self.sizes();
  std::vector<int64_t> output_size(input_size.begin(), input_size.end() - 1);
  output_size.push_back(weight.size(0));

  if (self.dim() > 2) {
    return dbl::comm::gen_aten_tensor_by(y).reshape(output_size);
  }
  return dbl::comm::gen_aten_tensor_by(y);
}

at::Tensor dil_linear_backward_input(
    at::IntArrayRef input_size, const at::Tensor& grad_output, const at::Tensor& weight){
  DEBUG("AtenIpexCPUDev::dil_linear_backward_input\n");
@@ -978,6 +1020,15 @@ at::Tensor& AtenIpexCPUDev::dil_relu_(at::Tensor& input) {
  return input;
}

at::Tensor AtenIpexCPUDev::dil_relu_use_dst_for_bwd(const at::Tensor& grad_output, const at::Tensor& output) {
  const dil::tensor& y = dbl::comm::try_gen_dil_tensor(output);
  dil::tensor grady = dbl::comm::try_gen_dil_tensor(grad_output);
  dil::tensor gradx;
  dil::eltwise_backward::compute(y, grady, gradx,
      dil::algorithm::eltwise_relu_use_dst_for_bwd, /*alpha*/ 0.0);
  return dbl::comm::gen_aten_tensor_by(gradx);
}

at::Tensor AtenIpexCPUDev::dil_threshold_backward(const at::Tensor& grad_output, const at::Tensor& input, at::Scalar threshold) {
  DEBUG("AtenIpexCPUDev::dil_threshold_backward\n");
  CHECK_DNNL_OP_PRE_COND(grad_output);
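An aside (not part of the diff) on why `eltwise_relu_use_dst_for_bwd` works: ReLU's gradient mask is fully determined by its output, so the backward pass can be computed from the saved forward result instead of the pre-activation input. A small PyTorch sketch of that identity:

import torch

x = torch.randn(8, requires_grad=True)
y = torch.relu(x)
grad_out = torch.randn_like(y)

# ReLU backward computed from the output (dst) only:
# grad_in = grad_out where y > 0, else 0.
grad_in_from_dst = grad_out * (y > 0).to(grad_out.dtype)

# Reference: let autograd differentiate through the input.
y.backward(grad_out)
print(torch.allclose(x.grad, grad_in_from_dst))  # expected: True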
2 changes: 2 additions & 0 deletions torch_ipex/csrc/cpu/DevOPs.h
@@ -39,6 +39,7 @@ class AtenIpexCPUDev {
static at::Tensor& dil_addbmm_(at::Tensor& self, const at::Tensor& batch1, const at::Tensor& batch2, at::Scalar beta, at::Scalar alpha);
static at::Tensor& dil_addbmm_out(at::Tensor& result, const at::Tensor &self, const at::Tensor &batch1, const at::Tensor &batch2, at::Scalar beta, at::Scalar alpha);
static at::Tensor dil_linear(const at::Tensor& self, const at::Tensor& weight, const c10::optional<at::Tensor>& bias);
static at::Tensor dil_linear_fuse_relu(const at::Tensor& self, const at::Tensor& weight, const c10::optional<at::Tensor>& bias);
static std::tuple<at::Tensor, at::Tensor, at::Tensor> dil_linear_backward(const at::Tensor& input, const at::Tensor& grad_output, const at::Tensor& weight, std::array<bool,3> output_mask);
static at::Tensor dil_dropout(const at::Tensor& self, double ratio, bool train);
static at::Tensor dil_dropout_backward(const at::Tensor& grady, const at::Tensor& mask, double ratio);
@@ -54,6 +55,7 @@ class AtenIpexCPUDev {
static at::Tensor dil_adaptive_avg_pool2d_backward(const at::Tensor& grad_output, const at::Tensor& input);
static at::Tensor dil_relu(const at::Tensor& input);
static at::Tensor& dil_relu_(at::Tensor& input);
static at::Tensor dil_relu_use_dst_for_bwd(const at::Tensor& grad_output, const at::Tensor& output);
static at::Tensor dil_threshold_backward(const at::Tensor& grad_output, const at::Tensor& input, at::Scalar threshold);
static at::Tensor dil__softmax(const at::Tensor& self, const int64_t dim, bool half_to_float);
static at::Tensor dil__softmax_backward_data(const at::Tensor& grad_output, const at::Tensor& output, int64_t dim, const at::Tensor& self);
12 changes: 12 additions & 0 deletions torch_ipex/csrc/cpu/ExtendOPs.cpp
@@ -427,7 +433,13 @@ at::Tensor AtenIpexTypeExt::linear(const at::Tensor& input, const at::Tensor& weight, const c10::optional<at::Tensor>& bias) {
  return cpu::AtenIpexCPUDev::dil_linear(input, weight, bias);
}

at::Tensor AtenIpexTypeExt::linear_fuse_relu(const at::Tensor& input, const at::Tensor& weight, const c10::optional<at::Tensor>& bias) {
  RECORD_FUNCTION("linear_fuse_relu", std::vector<c10::IValue>({input, weight, bias}), torch::autograd::Node::peek_at_next_sequence_nr());
  return cpu::AtenIpexCPUDev::dil_linear_fuse_relu(input, weight, bias);
}

std::tuple<at::Tensor, at::Tensor, at::Tensor> AtenIpexTypeExt::linear_backward(const at::Tensor& input, const at::Tensor& grad_output, const at::Tensor& weight, std::array<bool,3> output_mask) {
  RECORD_FUNCTION("linear_backward", std::vector<c10::IValue>({input, grad_output, weight}), torch::autograd::Node::peek_at_next_sequence_nr());
  return cpu::AtenIpexCPUDev::dil_linear_backward(input, grad_output, weight, output_mask);
}

@@ -451,5 +457,11 @@ at::Tensor AtenIpexTypeExt::reshape(const at::Tensor& input, at::IntArrayRef size) {
  return cpu::AtenIpexCPUDev::dil_reshape(input, size);
}

at::Tensor AtenIpexTypeExt::relu_use_dst_for_bwd(const at::Tensor& grad_output, const at::Tensor& output) {
  RECORD_FUNCTION("dil_relu_use_dst_for_bwd", std::vector<c10::IValue>({grad_output, output}), torch::autograd::Node::peek_at_next_sequence_nr());
  return cpu::AtenIpexCPUDev::dil_relu_use_dst_for_bwd(grad_output, output);
}

} // namespace torch_ipex

2 changes: 2 additions & 0 deletions torch_ipex/csrc/cpu/ExtendOPs.h
@@ -13,7 +13,9 @@ class AtenIpexTypeExt {
static at::Tensor embedding_bag_forward(const at::Tensor &weights, const at::Tensor &inputs, const at::Tensor &offsets);
static at::Tensor embedding_bag_backward(const at::Tensor &grad_out, const at::Tensor &weights, const at::Tensor &inputs, const at::Tensor &offsets);
static at::Tensor linear(const at::Tensor& input, const at::Tensor& weight, const c10::optional<at::Tensor>& bias);
static at::Tensor linear_fuse_relu(const at::Tensor& input, const at::Tensor& weight, const c10::optional<at::Tensor>& bias);
static std::tuple<at::Tensor, at::Tensor, at::Tensor> linear_backward(const at::Tensor& input, const at::Tensor& grad_output, const at::Tensor& weight, std::array<bool,3> output_mask);
static at::Tensor relu_use_dst_for_bwd(const at::Tensor& grad_output, const at::Tensor& output);
static at::Tensor adaptive_avg_pool2d(at::Tensor const& input, at::IntArrayRef output_size);
static at::Tensor adaptive_avg_pool2d_backward(const at::Tensor& grad_output, const at::Tensor& input);
static at::Tensor max_pooling(const at::Tensor& input, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode);
4 changes: 2 additions & 2 deletions torch_ipex/csrc/cpu/dil/dil/operators/inner_product.hpp
@@ -163,7 +163,7 @@ struct inner_product_forward : public dnnl::inner_product_forward {
}
}
} else {
op_attr = attr;
op_attr = attr;
src_desc = {src.get_dims(), data_type::f32, format_tag::any};
if (src.has_scale()) {
auto src_scale = src.get_scale();
@@ -233,7 +233,7 @@ struct inner_product_backward_data : public dnnl::inner_product_backward_data {
tensor& diff_src,
const engine& aengine = engine::cpu_engine()) {
auto weights_ = weights;
if (diff_dst.get_data_type() == data_type::bf16) {
if (diff_dst.get_data_type() == data_type::bf16 && weights.get_data_type() != data_type::bf16) {
weights_.init(weights.get_desc().to_type(data_type::bf16));
weights_.reorder_from(weights);
}
8 changes: 8 additions & 0 deletions torch_ipex/csrc/init_python_bindings.cpp
@@ -92,10 +92,18 @@ void InitIpexModuleBindings(py::module m) {
      [](const at::Tensor& input, const at::Tensor& weight, const c10::optional<at::Tensor>& bias) {
        return AtenIpexTypeExt::linear(input, weight, bias);
      });
  m.def("linear_fuse_relu",
      [](const at::Tensor& input, const at::Tensor& weight, const c10::optional<at::Tensor>& bias) {
        return AtenIpexTypeExt::linear_fuse_relu(input, weight, bias);
      });
  m.def("linear_backward",
      [](const at::Tensor& input, const at::Tensor& grad_output, const at::Tensor& weight, std::array<bool,3> output_mask) {
        return AtenIpexTypeExt::linear_backward(input, grad_output, weight, output_mask);
      });
  m.def("relu_use_dst_backward",
      [](const at::Tensor& grad_output, const at::Tensor& output) {
        return AtenIpexTypeExt::relu_use_dst_for_bwd(grad_output, output);
      });
  m.def("adaptive_avg_pool2d",
      [](at::Tensor const& input, at::IntArrayRef output_size) {
        return AtenIpexTypeExt::adaptive_avg_pool2d(input, output_size);