Commit fc686f6

Enable Linear+ReLU fuse by oneDNN (#20)

* Enable Linear+ReLU fusion via oneDNN
* Rename some functions
* Hide dil from the frontend; reuse attr args instead of a separate fuse_relu flag
* Remove unused header files, since relu's function body moved to DevOPs.cpp
* Add a unit test for the fused linear+relu op
* Move the unit test into test_lazy_reorder
* remove py
1 parent ce71457 commit fc686f6

File tree

9 files changed: +182 -2 lines

intel_pytorch_extension_py/ops/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -4,3 +4,5 @@
 from .pooling import *
 from .reshape import *
 from .mlp import *
+from .linear_fuse_relu import *
+

intel_pytorch_extension_py/ops/linear_fuse_relu.py (new file; path inferred from the import added above)

Lines changed: 55 additions & 0 deletions
@@ -0,0 +1,55 @@
+import torch
+from torch import nn
+from torch.nn.parameter import Parameter
+from torch.nn import init
+from torch.autograd import Function
+import math
+import _torch_ipex as core
+
+class LinearFuseReluFC(Function):
+    @staticmethod
+    def forward(ctx, input, weight, bias):
+        output = core.linear_fuse_relu(input, weight, bias)
+        ctx.save_for_backward(input, weight, bias, output)
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        input, weight, bias, output = ctx.saved_tensors
+        grad_output = grad_output.contiguous()
+        if bias is None:
+            output_mask = (input.requires_grad, weight.requires_grad, 0)
+        else:
+            output_mask = (input.requires_grad, weight.requires_grad, bias.requires_grad)
+        grad_output = core.relu_use_dst_backward(grad_output, output)
+        grad_input, grad_weight, grad_bias = core.linear_backward(input, grad_output, weight, output_mask)
+        return (grad_input, grad_weight, grad_bias)
+
+class LinearFuseRelu(nn.Module):
+    r"""DNNL Linear module using the ReLU-fused DNNL kernel"""
+
+    __constants__ = ['bias']
+
+    def __init__(self, in_features, out_features, bias=True):
+        super(LinearFuseRelu, self).__init__()
+        self.in_features = in_features
+        self.out_features = out_features
+        self.weight = Parameter(torch.Tensor(out_features, in_features))
+
+        if bias:
+            self.bias = Parameter(torch.Tensor(out_features))
+        else:
+            self.register_parameter('bias', None)
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
+        if self.bias is not None:
+            bound = 1 / math.sqrt(self.in_features)
+            init.uniform_(self.bias, -bound, bound)
+
+    def forward(self, input):
+        # print(self.weight.shape)
+        output = LinearFuseReluFC.apply(input, self.weight, self.bias)
+        return output
+

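As context for the new module above, a minimal usage sketch (illustrative only, not part of this commit). It assumes the package is importable as intel_pytorch_extension_py and that the extension's 'dpcpp' device, used by the tests below, is available:

import torch
import intel_pytorch_extension_py

# Hypothetical sizes; LinearFuseRelu mirrors nn.Linear's constructor signature.
m = intel_pytorch_extension_py.LinearFuseRelu(in_features=8, out_features=16, bias=True)
m = m.to('dpcpp')                  # assumption: parameters can be moved like nn.Linear's
x = torch.randn(4, 8).to('dpcpp')
y = m(x)                           # one fused oneDNN call: relu(x @ W.T + b)
print(y.shape)                     # torch.Size([4, 16]); every element is >= 0 after ReLU
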
tests/cpu/test_lazy_reorder.py

Lines changed: 48 additions & 0 deletions
@@ -453,6 +453,54 @@ def test_linear_backward(self):
         y2.backward()
         self.assertEqual(x1.grad, x2.grad)

+class TestLinearFuseRelu(TestCase):
+    def test_linear_fuse_relu_forward(self):
+        ipex.enable_auto_dnnl()
+        rand_seed = int(get_rand_seed())
+        print("{} rand seed: {}".format(sys._getframe().f_code.co_name, rand_seed))
+        torch.manual_seed(rand_seed)
+        in_features = torch.randint(3, 10, (1,)).item()
+        out_features = torch.randint(3, 100, (1,)).item()
+        for dtype in [torch.bfloat16, torch.float]:
+            x = torch.randn(3, in_features) * 10
+            x = x.to(dtype).to('dpcpp')
+            for bias in [True, False]:
+                linear = torch.nn.Linear(in_features, out_features, bias=bias).to('dpcpp').to(dtype)
+                relu = torch.nn.ReLU()
+                linear_fuse_relu = intel_pytorch_extension_py.LinearFuseRelu(in_features, out_features, bias=bias)
+                linear_fuse_relu.weight.data = linear.weight.clone()
+                if bias:
+                    linear_fuse_relu.bias.data = linear.bias.clone()
+                self.assertEqual(relu(linear(x)).float(), linear_fuse_relu(x).float())
+
+    def test_linear_fuse_relu_backward(self):
+        ipex.enable_auto_dnnl()
+        rand_seed = int(get_rand_seed())
+        print("{} rand seed: {}".format(sys._getframe().f_code.co_name, rand_seed))
+        torch.manual_seed(rand_seed)
+        in_features = torch.randint(3, 10, (1,)).item()
+        out_features = torch.randint(3, 100, (1,)).item()
+        for dtype in [torch.bfloat16, torch.float]:
+            x = torch.randn(3, in_features) * 10
+            x = x.to(dtype).to('dpcpp')
+            for bias in [True, False]:
+                linear = torch.nn.Linear(in_features, out_features, bias=bias).to('dpcpp').to(dtype)
+                relu = torch.nn.ReLU()
+                linear_fuse_relu = intel_pytorch_extension_py.LinearFuseRelu(in_features, out_features, bias=bias)
+                linear_fuse_relu.weight.data = linear.weight.clone()
+                if bias:
+                    linear_fuse_relu.bias.data = linear.bias.clone()
+                x1 = x.clone().requires_grad_()
+                x2 = x.clone().requires_grad_()
+                y1 = relu(linear(x1).float()).sum()
+                y2 = linear_fuse_relu(x2).sum()
+                y1.backward()
+                y2.backward()
+                self.assertEqual(x1.grad.float(), x2.grad.float())
+                self.assertEqual(linear.weight.grad.float(), linear_fuse_relu.weight.grad.float())
+                if bias:
+                    self.assertEqual(linear.bias.grad.float(), linear_fuse_relu.bias.grad.float())
+
 class TestPool(TestCase):
     def test_avg_pool2d(self):
         ipex.enable_auto_dnnl()

torch_ipex/csrc/cpu/DevOPs.cpp

Lines changed: 51 additions & 0 deletions
@@ -557,6 +557,48 @@ at::Tensor AtenIpexCPUDev::dil_linear(
   return dbl::comm::gen_aten_tensor_by(std::move(y));
 }

+at::Tensor AtenIpexCPUDev::dil_linear_fuse_relu(
+    const at::Tensor& self,
+    const at::Tensor& weight,
+    const c10::optional<at::Tensor>& bias) {
+  DEBUG("AtenIpexCPUDev::dil_linear_fuse_relu\n");
+  CHECK_DNNL_OP_PRE_COND(self);
+  CHECK_DNNL_OP_PRE_COND(weight);
+  TORCH_CHECK(self.dim() >= 2,
+      "dil_linear_fuse_relu: input needs to have dim at least 2, input dim ", self.dim());
+
+  // Reshape first if the input dim is greater than 2; the reshape costs a memory copy.
+  auto self_reshaped = self.dim() > 2 ? self.reshape({-1, self.size(self.dim() - 1)}) : self;
+  const dil::tensor x = dbl::comm::try_gen_dil_tensor(self_reshaped);
+  const dil::tensor w = dbl::comm::try_gen_dil_tensor(weight);
+
+  dil::tensor y;
+  if (bias.has_value()) {
+    at::Tensor bias_vec = bias.value();
+    const dil::tensor b = dbl::comm::try_gen_dil_tensor(bias_vec);
+    dil::inner_product_forward::compute(x, w, b, y,
+        /*src_scales=*/dil::scale_t(),
+        /*weight_scales=*/dil::scale_t(),
+        /*dst_scales=*/dil::scale_t(),
+        /*attr*/dil::attr_t::fuse_relu());
+  } else {
+    dil::inner_product_forward::compute(x, w, y,
+        /*src_scales=*/dil::scale_t(),
+        /*weight_scales=*/dil::scale_t(),
+        /*dst_scales=*/dil::scale_t(),
+        /*attr*/dil::attr_t::fuse_relu());
+  }
+
+  auto input_size = self.sizes();
+  std::vector<int64_t> output_size(input_size.begin(), input_size.end() - 1);
+  output_size.push_back(weight.size(0));
+
+  if (self.dim() > 2) {
+    return dbl::comm::gen_aten_tensor_by(y).reshape(output_size);
+  }
+  return dbl::comm::gen_aten_tensor_by(y);
+}
+
 at::Tensor dil_linear_backward_input(
     at::IntArrayRef input_size, const at::Tensor& grad_output, const at::Tensor& weight){
   DEBUG("AtenIpexCPUDev::dil_linear_backward_input\n");

@@ -988,6 +1030,15 @@ at::Tensor& AtenIpexCPUDev::dil_relu_(at::Tensor& input) {
   return input;
 }

+at::Tensor AtenIpexCPUDev::dil_relu_use_dst_for_bwd(const at::Tensor& grad_output, const at::Tensor& output) {
+  const dil::tensor& y = dbl::comm::try_gen_dil_tensor(output);
+  dil::tensor grady = dbl::comm::try_gen_dil_tensor(grad_output);
+  dil::tensor gradx;
+  dil::eltwise_backward::compute(y, grady, gradx,
+      dil::algorithm::eltwise_relu_use_dst_for_bwd, /*alpha*/ 0.0);
+  return dbl::comm::gen_aten_tensor_by(gradx);
+}
+
 at::Tensor AtenIpexCPUDev::dil_threshold_backward(const at::Tensor& grad_output, const at::Tensor& input, at::Scalar threshold) {
   DEBUG("AtenIpexCPUDev::dil_threshold_backward\n");
   CHECK_DNNL_OP_PRE_COND(grad_output);

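For reference, the semantics of the two kernels added above can be sketched in plain PyTorch (an illustrative reference only; the commit's actual path goes through oneDNN inner_product with a fuse_relu post-op and eltwise_relu_use_dst_for_bwd):

import torch

def linear_fuse_relu_reference(x, weight, bias=None):
    # Reference semantics of dil_linear_fuse_relu: a linear layer with the ReLU
    # applied by the same kernel (oneDNN fuses the ReLU as a post-op).
    return torch.relu(torch.nn.functional.linear(x, weight, bias))

def relu_use_dst_backward_reference(grad_output, output):
    # Reference semantics of dil_relu_use_dst_for_bwd: because the ReLU output is
    # saved by the autograd Function, the gradient mask can be recovered from the
    # output alone (output > 0 exactly where the pre-activation was > 0), so the
    # pre-activation input never needs to be kept around.
    return grad_output * (output > 0).to(grad_output.dtype)
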
torch_ipex/csrc/cpu/DevOPs.h

Lines changed: 2 additions & 0 deletions
@@ -39,6 +39,7 @@ class AtenIpexCPUDev {
   static at::Tensor& dil_addbmm_(at::Tensor& self, const at::Tensor& batch1, const at::Tensor& batch2, at::Scalar beta, at::Scalar alpha);
   static at::Tensor& dil_addbmm_out(at::Tensor& result, const at::Tensor &self, const at::Tensor &batch1, const at::Tensor &batch2, at::Scalar beta, at::Scalar alpha);
   static at::Tensor dil_linear(const at::Tensor& self, const at::Tensor& weight, const c10::optional<at::Tensor>& bias);
+  static at::Tensor dil_linear_fuse_relu(const at::Tensor& self, const at::Tensor& weight, const c10::optional<at::Tensor>& bias);
   static std::tuple<at::Tensor, at::Tensor, at::Tensor> dil_linear_backward(const at::Tensor& input, const at::Tensor& grad_output, const at::Tensor& weight, std::array<bool,3> output_mask);
   static at::Tensor dil_dropout(const at::Tensor& self, double ratio, bool train);
   static at::Tensor dil_dropout_backward(const at::Tensor& grady, const at::Tensor& mask, double ratio);

@@ -54,6 +55,7 @@ class AtenIpexCPUDev {
   static at::Tensor dil_adaptive_avg_pool2d_backward(const at::Tensor& grad_output, const at::Tensor& input);
   static at::Tensor dil_relu(const at::Tensor& input);
   static at::Tensor& dil_relu_(at::Tensor& input);
+  static at::Tensor dil_relu_use_dst_for_bwd(const at::Tensor& grad_output, const at::Tensor& output);
   static at::Tensor dil_threshold_backward(const at::Tensor& grad_output, const at::Tensor& input, at::Scalar threshold);
   static at::Tensor dil__softmax(const at::Tensor& self, const int64_t dim, bool half_to_float);
   static at::Tensor dil__softmax_backward_data(const at::Tensor& grad_output, const at::Tensor& output, int64_t dim, const at::Tensor& self);

torch_ipex/csrc/cpu/ExtendOPs.cpp

Lines changed: 12 additions & 0 deletions
@@ -453,7 +453,13 @@ at::Tensor AtenIpexTypeExt::linear(const at::Tensor& input, const at::Tensor& we
   return cpu::AtenIpexCPUDev::dil_linear(input, weight, bias);
 }

+at::Tensor AtenIpexTypeExt::linear_fuse_relu(const at::Tensor& input, const at::Tensor& weight, const c10::optional<at::Tensor>& bias) {
+  RECORD_FUNCTION("linear_fuse_relu", std::vector<c10::IValue>({input, weight, bias}), torch::autograd::Node::peek_at_next_sequence_nr());
+  return cpu::AtenIpexCPUDev::dil_linear_fuse_relu(input, weight, bias);
+}
+
 std::tuple<at::Tensor, at::Tensor, at::Tensor> AtenIpexTypeExt::linear_backward(const at::Tensor& input, const at::Tensor& grad_output, const at::Tensor& weight, std::array<bool,3> output_mask) {
+  RECORD_FUNCTION("linear_backward", std::vector<c10::IValue>({input, grad_output, weight}), torch::autograd::Node::peek_at_next_sequence_nr());
   return cpu::AtenIpexCPUDev::dil_linear_backward(input, grad_output, weight, output_mask);
 }

@@ -477,5 +483,11 @@ at::Tensor AtenIpexTypeExt::reshape(const at::Tensor& input, at::IntArrayRef siz
   return cpu::AtenIpexCPUDev::dil_reshape(input, size);
 }

+
+at::Tensor AtenIpexTypeExt::relu_use_dst_for_bwd(const at::Tensor& grad_output, const at::Tensor& output) {
+  RECORD_FUNCTION("dil_relu_use_dst_for_bwd", std::vector<c10::IValue>({grad_output, output}), torch::autograd::Node::peek_at_next_sequence_nr());
+  return cpu::AtenIpexCPUDev::dil_relu_use_dst_for_bwd(grad_output, output);
+}
+
 } // namespace torch_ipex

torch_ipex/csrc/cpu/ExtendOPs.h

Lines changed: 2 additions & 0 deletions
@@ -24,7 +24,9 @@ class AtenIpexTypeExt {
                                        const c10::optional<at::Tensor>& per_sample_weights);

   static at::Tensor linear(const at::Tensor& input, const at::Tensor& weight, const c10::optional<at::Tensor>& bias);
+  static at::Tensor linear_fuse_relu(const at::Tensor& input, const at::Tensor& weight, const c10::optional<at::Tensor>& bias);
   static std::tuple<at::Tensor, at::Tensor, at::Tensor> linear_backward(const at::Tensor& input, const at::Tensor& grad_output, const at::Tensor& weight, std::array<bool,3> output_mask);
+  static at::Tensor relu_use_dst_for_bwd(const at::Tensor& grad_output, const at::Tensor& output);
   static at::Tensor adaptive_avg_pool2d(at::Tensor const& input, at::IntArrayRef output_size);
   static at::Tensor adaptive_avg_pool2d_backward(const at::Tensor& grad_output, const at::Tensor& input);
   static at::Tensor max_pooling(const at::Tensor& input, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode);

torch_ipex/csrc/cpu/dil/dil/operators/inner_product.hpp

Lines changed: 2 additions & 2 deletions
@@ -163,7 +163,7 @@ struct inner_product_forward : public dnnl::inner_product_forward {
       }
     }
   } else {
-    op_attr = attr;
+    op_attr = attr;
     src_desc = {src.get_dims(), data_type::f32, format_tag::any};
     if (src.has_scale()) {
       auto src_scale = src.get_scale();

(The removed and added op_attr lines above appear to differ only in whitespace in the original diff.)

@@ -233,7 +233,7 @@ struct inner_product_backward_data : public dnnl::inner_product_backward_data {
                       tensor& diff_src,
                       const engine& aengine = engine::cpu_engine()) {
     auto weights_ = weights;
-    if (diff_dst.get_data_type() == data_type::bf16) {
+    if (diff_dst.get_data_type() == data_type::bf16 && weights.get_data_type() != data_type::bf16) {
       weights_.init(weights.get_desc().to_type(data_type::bf16));
       weights_.reorder_from(weights);
     }

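The second hunk narrows the bf16 weight reorder in inner_product_backward_data: weights are converted only when the incoming gradient is bf16 and the weights are not already bf16 (previously they were reordered whenever the gradient was bf16). A rough PyTorch-level restatement of that guard (illustrative only; the real code uses dil tensor init/reorder, not a dtype cast):

import torch

def maybe_cast_weights(weights, diff_dst):
    # Only cast when the gradient is bf16 and the weights are not already bf16,
    # mirroring the guard added in inner_product_backward_data.
    if diff_dst.dtype == torch.bfloat16 and weights.dtype != torch.bfloat16:
        return weights.to(torch.bfloat16)
    return weights
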
torch_ipex/csrc/init_python_bindings.cpp

Lines changed: 8 additions & 0 deletions
@@ -91,10 +91,18 @@ void InitIpexModuleBindings(py::module m) {
         [](const at::Tensor& input, const at::Tensor& weight, const c10::optional<at::Tensor>& bias) {
           return AtenIpexTypeExt::linear(input, weight, bias);
         });
+  m.def("linear_fuse_relu",
+        [](const at::Tensor& input, const at::Tensor& weight, const c10::optional<at::Tensor>& bias) {
+          return AtenIpexTypeExt::linear_fuse_relu(input, weight, bias);
+        });
   m.def("linear_backward",
         [](const at::Tensor& input, const at::Tensor& grad_output, const at::Tensor& weight, std::array<bool,3> output_mask) {
           return AtenIpexTypeExt::linear_backward(input, grad_output, weight, output_mask);
         });
+  m.def("relu_use_dst_backward",
+        [](const at::Tensor& grad_output, const at::Tensor& output) {
+          return AtenIpexTypeExt::relu_use_dst_for_bwd(grad_output, output);
+        });
   m.def("adaptive_avg_pool2d",
         [](at::Tensor const& input, at::IntArrayRef output_size) {
           return AtenIpexTypeExt::adaptive_avg_pool2d(input, output_size);

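A hedged sketch of how these bindings are reached from Python (importing the extension package is assumed to register the 'dpcpp' device, and auto-DNNL is assumed to be enabled as in the tests; _torch_ipex matches the import used in linear_fuse_relu.py above):

import torch
import intel_pytorch_extension_py          # assumption: importing the extension sets up 'dpcpp'
import _torch_ipex as core                 # pybind11 module populated in init_python_bindings.cpp

x = torch.randn(4, 8).to('dpcpp')
w = torch.randn(16, 8).to('dpcpp')
b = torch.randn(16).to('dpcpp')

y = core.linear_fuse_relu(x, w, b)                  # relu(x @ w.T + b) in a single fused call
grad_y = torch.ones_like(y)
grad_pre = core.relu_use_dst_backward(grad_y, y)    # ReLU grad recovered from the output y alone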