Commit 3e499cb

Enable Linear+ReLU fusion via oneDNN

- rename some functions, hide dil from the frontend, and reuse attr args instead of a fuse_relu flag
- remove unneeded header files, since the relu function bodies were moved to DevOPs.cpp
- add a unit test for the fused Linear+ReLU
- move the unit test to test_lazy_reorder
1 parent fb8b9df commit 3e499cb
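
For context, the fused op evaluates a linear layer and ReLU in a single oneDNN inner_product call (see dil_linear_fuse_relu below). A minimal plain-PyTorch sketch of the intended semantics, not the fused kernel itself:

import torch

def linear_relu_reference(x, weight, bias=None):
    # reference semantics of the fused forward: relu(x @ weight.t() + bias)
    return torch.relu(torch.nn.functional.linear(x, weight, bias))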

File tree

9 files changed: +183 −3 lines changed

intel_pytorch_extension_py/ops/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -4,3 +4,5 @@
 from .pooling import *
 from .reshape import *
 from .mlp import *
+from .linear_fuse_relu import *
+

intel_pytorch_extension_py/ops/linear_fuse_relu.py

Lines changed: 55 additions & 0 deletions

@@ -0,0 +1,55 @@
+import torch
+from torch import nn
+from torch.nn.parameter import Parameter
+from torch.nn import init
+from torch.autograd import Function
+import math
+import _torch_ipex as core
+
+class LinearFuseReluFC(Function):
+    @staticmethod
+    def forward(ctx, input, weight, bias):
+        output = core.linear_fuse_relu(input, weight, bias)
+        ctx.save_for_backward(input, weight, bias, output)
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        input, weight, bias, output = ctx.saved_tensors
+        grad_output = grad_output.contiguous()
+        if bias is None:
+            output_mask = (input.requires_grad, weight.requires_grad, 0)
+        else:
+            output_mask = (input.requires_grad, weight.requires_grad, bias.requires_grad)
+        grad_output = core.relu_use_dst_backward(grad_output, output)
+        grad_input, grad_weight, grad_bias = core.linear_backward(input, grad_output, weight, output_mask)
+        return (grad_input, grad_weight, grad_bias)
+
+class LinearFuseRelu(nn.Module):
+    r"""DNNL Linear module using the ReLU-fused DNNL kernel"""
+
+    __constants__ = ['bias']
+
+    def __init__(self, in_features, out_features, bias=True):
+        super(LinearFuseRelu, self).__init__()
+        self.in_features = in_features
+        self.out_features = out_features
+        self.weight = Parameter(torch.Tensor(out_features, in_features))
+
+        if bias:
+            self.bias = Parameter(torch.Tensor(out_features))
+        else:
+            self.register_parameter('bias', None)
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
+        if self.bias is not None:
+            bound = 1 / math.sqrt(self.in_features)
+            init.uniform_(self.bias, -bound, bound)
+
+    def forward(self, input):
+        # print(self.weight.shape)
+        output = LinearFuseReluFC.apply(input, self.weight, self.bias)
+        return output
+

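A minimal usage sketch of the new module, assuming auto-dnnl is enabled and the 'dpcpp' device is available as in the tests below; the feature sizes are illustrative:

import torch
import intel_pytorch_extension_py

# drop-in replacement for nn.Linear followed by nn.ReLU;
# move parameters to the dpcpp device (the tests copy weights from a dpcpp nn.Linear instead)
fc = intel_pytorch_extension_py.LinearFuseRelu(128, 64, bias=True).to('dpcpp')
x = torch.randn(4, 128).to('dpcpp')
y = fc(x)            # forward runs the ReLU-fused oneDNN inner_product kernel
y.sum().backward()   # backward applies relu_use_dst_backward, then linear_backward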
tests/cpu/test_lazy_reorder.py

Lines changed: 49 additions & 1 deletion

@@ -13,7 +13,7 @@
 import torch
 import _torch_ipex as ipex
 ipex._initialize_aten_bindings()
-import intel_pytorch_extension
+import intel_pytorch_extension_py
 
 import torch.nn as nn
 import torch.backends.cudnn as cudnn
@@ -438,6 +438,54 @@ def test_linear_backward(self):
         y2.backward()
         self.assertEqual(x1.grad, x2.grad)
 
+class TestLinearFuseRelu(TestCase):
+    def test_linear_fuse_relu_forward(self):
+        ipex.enable_auto_dnnl()
+        rand_seed = int(get_rand_seed())
+        print("{} rand seed: {}".format(sys._getframe().f_code.co_name, rand_seed))
+        torch.manual_seed(rand_seed)
+        in_features = torch.randint(3, 10, (1,)).item()
+        out_features = torch.randint(3, 100, (1,)).item()
+        for dtype in [torch.bfloat16, torch.float]:
+            x = torch.randn(3, in_features) * 10
+            x = x.to(dtype).to('dpcpp')
+            for bias in [True, False]:
+                linear = torch.nn.Linear(in_features, out_features, bias=bias).to('dpcpp').to(dtype)
+                relu = torch.nn.ReLU()
+                linear_fuse_relu = intel_pytorch_extension_py.LinearFuseRelu(in_features, out_features, bias=bias)
+                linear_fuse_relu.weight.data = linear.weight.clone()
+                if bias:
+                    linear_fuse_relu.bias.data = linear.bias.clone()
+                self.assertEqual(relu(linear(x)).float(), linear_fuse_relu(x).float())
+
+    def test_linear_fuse_relu_backward(self):
+        ipex.enable_auto_dnnl()
+        rand_seed = int(get_rand_seed())
+        print("{} rand seed: {}".format(sys._getframe().f_code.co_name, rand_seed))
+        torch.manual_seed(rand_seed)
+        in_features = torch.randint(3, 10, (1,)).item()
+        out_features = torch.randint(3, 100, (1,)).item()
+        for dtype in [torch.bfloat16, torch.float]:
+            x = torch.randn(3, in_features) * 10
+            x = x.to(dtype).to('dpcpp')
+            for bias in [True, False]:
+                linear = torch.nn.Linear(in_features, out_features, bias=bias).to('dpcpp').to(dtype)
+                relu = torch.nn.ReLU()
+                linear_fuse_relu = intel_pytorch_extension_py.LinearFuseRelu(in_features, out_features, bias=bias)
+                linear_fuse_relu.weight.data = linear.weight.clone()
+                if bias:
+                    linear_fuse_relu.bias.data = linear.bias.clone()
+                x1 = x.clone().requires_grad_()
+                x2 = x.clone().requires_grad_()
+                y1 = relu(linear(x1).float()).sum()
+                y2 = linear_fuse_relu(x2).sum()
+                y1.backward()
+                y2.backward()
+                self.assertEqual(x1.grad.float(), x2.grad.float())
+                self.assertEqual(linear.weight.grad.float(), linear_fuse_relu.weight.grad.float())
+                if bias:
+                    self.assertEqual(linear.bias.grad.float(), linear_fuse_relu.bias.grad.float())
+
 class TestPool(TestCase):
     def test_avg_pool2d(self):
         ipex.enable_auto_dnnl()

torch_ipex/csrc/cpu/DevOPs.cpp

Lines changed: 51 additions & 0 deletions

@@ -549,6 +549,48 @@ at::Tensor AtenIpexCPUDev::dil_linear(
   return dbl::comm::gen_aten_tensor_by(y);
 }
 
+at::Tensor AtenIpexCPUDev::dil_linear_fuse_relu(
+    const at::Tensor& self,
+    const at::Tensor& weight,
+    const c10::optional<at::Tensor>& bias) {
+  DEBUG("AtenIpexCPUDev::dil_linear_fuse_relu\n");
+  CHECK_DNNL_OP_PRE_COND(self);
+  CHECK_DNNL_OP_PRE_COND(weight);
+  TORCH_CHECK(self.dim() >= 2,
+      "dil_linear: input needs to have dim at least 2, input dim ", self.dim());
+
+  // reshape first if input dim is greater than 2; the reshape costs a memory copy.
+  auto self_reshaped = self.dim() > 2 ? self.reshape({-1, self.size(self.dim() - 1)}) : self;
+  const dil::tensor x = dbl::comm::try_gen_dil_tensor(self_reshaped);
+  const dil::tensor w = dbl::comm::try_gen_dil_tensor(weight);
+
+  dil::tensor y;
+  if (bias.has_value()) {
+    at::Tensor bias_vec = bias.value();
+    const dil::tensor b = dbl::comm::try_gen_dil_tensor(bias_vec);
+    dil::inner_product_forward::compute(x, w, b, y,
+        /*src_scales=*/dil::scale_t(),
+        /*weight_scales=*/dil::scale_t(),
+        /*dst_scales=*/dil::scale_t(),
+        /*attr=*/dil::attr_t::fuse_relu());
+  } else {
+    dil::inner_product_forward::compute(x, w, y,
+        /*src_scales=*/dil::scale_t(),
+        /*weight_scales=*/dil::scale_t(),
+        /*dst_scales=*/dil::scale_t(),
+        /*attr=*/dil::attr_t::fuse_relu());
+  }
+
+  auto input_size = self.sizes();
+  std::vector<int64_t> output_size(input_size.begin(), input_size.end() - 1);
+  output_size.push_back(weight.size(0));
+
+  if (self.dim() > 2) {
+    return dbl::comm::gen_aten_tensor_by(y).reshape(output_size);
+  }
+  return dbl::comm::gen_aten_tensor_by(y);
+}
+
 at::Tensor dil_linear_backward_input(
     at::IntArrayRef input_size, const at::Tensor& grad_output, const at::Tensor& weight){
   DEBUG("AtenIpexCPUDev::dil_linear_backward_input\n");
@@ -978,6 +1020,15 @@ at::Tensor& AtenIpexCPUDev::dil_relu_(at::Tensor& input) {
   return input;
 }
 
+at::Tensor AtenIpexCPUDev::dil_relu_use_dst_for_bwd(const at::Tensor& grad_output, const at::Tensor& output) {
+  const dil::tensor& y = dbl::comm::try_gen_dil_tensor(output);
+  dil::tensor grady = dbl::comm::try_gen_dil_tensor(grad_output);
+  dil::tensor gradx;
+  dil::eltwise_backward::compute(y, grady, gradx,
+      dil::algorithm::eltwise_relu_use_dst_for_bwd, /*alpha=*/0.0);
+  return dbl::comm::gen_aten_tensor_by(gradx);
+}
+
 at::Tensor AtenIpexCPUDev::dil_threshold_backward(const at::Tensor& grad_output, const at::Tensor& input, at::Scalar threshold) {
   DEBUG("AtenIpexCPUDev::dil_threshold_backward\n");
   CHECK_DNNL_OP_PRE_COND(grad_output);

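The dst-based ReLU backward above works because the gradient of ReLU can be recovered from its output alone: for y = relu(x), dx = dy where y > 0 and 0 elsewhere. A plain-PyTorch sketch of the same math (not the dil::eltwise_backward call used in the kernel):

import torch

def relu_use_dst_backward_reference(grad_output, output):
    # dL/dx = dL/dy where the forward output y is positive, 0 otherwise
    return grad_output * (output > 0).to(grad_output.dtype)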
torch_ipex/csrc/cpu/DevOPs.h

Lines changed: 2 additions & 0 deletions

@@ -39,6 +39,7 @@ class AtenIpexCPUDev {
   static at::Tensor& dil_addbmm_(at::Tensor& self, const at::Tensor& batch1, const at::Tensor& batch2, at::Scalar beta, at::Scalar alpha);
   static at::Tensor& dil_addbmm_out(at::Tensor& result, const at::Tensor &self, const at::Tensor &batch1, const at::Tensor &batch2, at::Scalar beta, at::Scalar alpha);
   static at::Tensor dil_linear(const at::Tensor& self, const at::Tensor& weight, const c10::optional<at::Tensor>& bias);
+  static at::Tensor dil_linear_fuse_relu(const at::Tensor& self, const at::Tensor& weight, const c10::optional<at::Tensor>& bias);
   static std::tuple<at::Tensor, at::Tensor, at::Tensor> dil_linear_backward(const at::Tensor& input, const at::Tensor& grad_output, const at::Tensor& weight, std::array<bool,3> output_mask);
   static at::Tensor dil_dropout(const at::Tensor& self, double ratio, bool train);
   static at::Tensor dil_dropout_backward(const at::Tensor& grady, const at::Tensor& mask, double ratio);
@@ -54,6 +55,7 @@
   static at::Tensor dil_adaptive_avg_pool2d_backward(const at::Tensor& grad_output, const at::Tensor& input);
   static at::Tensor dil_relu(const at::Tensor& input);
   static at::Tensor& dil_relu_(at::Tensor& input);
+  static at::Tensor dil_relu_use_dst_for_bwd(const at::Tensor& grad_output, const at::Tensor& output);
   static at::Tensor dil_threshold_backward(const at::Tensor& grad_output, const at::Tensor& input, at::Scalar threshold);
   static at::Tensor dil__softmax(const at::Tensor& self, const int64_t dim, bool half_to_float);
   static at::Tensor dil__softmax_backward_data(const at::Tensor& grad_output, const at::Tensor& output, int64_t dim, const at::Tensor& self);

torch_ipex/csrc/cpu/ExtendOPs.cpp

Lines changed: 12 additions & 0 deletions

@@ -427,7 +427,13 @@ at::Tensor AtenIpexTypeExt::linear(const at::Tensor& input, const at::Tensor& we
   return cpu::AtenIpexCPUDev::dil_linear(input, weight, bias);
 }
 
+at::Tensor AtenIpexTypeExt::linear_fuse_relu(const at::Tensor& input, const at::Tensor& weight, const c10::optional<at::Tensor>& bias) {
+  RECORD_FUNCTION("linear_fuse_relu", std::vector<c10::IValue>({input, weight, bias}), torch::autograd::Node::peek_at_next_sequence_nr());
+  return cpu::AtenIpexCPUDev::dil_linear_fuse_relu(input, weight, bias);
+}
+
 std::tuple<at::Tensor, at::Tensor, at::Tensor> AtenIpexTypeExt::linear_backward(const at::Tensor& input, const at::Tensor& grad_output, const at::Tensor& weight, std::array<bool,3> output_mask) {
+  RECORD_FUNCTION("linear_backward", std::vector<c10::IValue>({input, grad_output, weight}), torch::autograd::Node::peek_at_next_sequence_nr());
   return cpu::AtenIpexCPUDev::dil_linear_backward(input, grad_output, weight, output_mask);
 }
 
@@ -451,5 +457,11 @@ at::Tensor AtenIpexTypeExt::reshape(const at::Tensor& input, at::IntArrayRef siz
   return cpu::AtenIpexCPUDev::dil_reshape(input, size);
 }
 
+
+at::Tensor AtenIpexTypeExt::relu_use_dst_for_bwd(const at::Tensor& grad_output, const at::Tensor& output) {
+  RECORD_FUNCTION("dil_relu_use_dst_for_bwd", std::vector<c10::IValue>({grad_output, output}), torch::autograd::Node::peek_at_next_sequence_nr());
+  return cpu::AtenIpexCPUDev::dil_relu_use_dst_for_bwd(grad_output, output);
+}
+
 } // namespace torch_ipex

torch_ipex/csrc/cpu/ExtendOPs.h

Lines changed: 2 additions & 0 deletions

@@ -13,7 +13,9 @@ class AtenIpexTypeExt {
   static at::Tensor embedding_bag_forward(const at::Tensor &weights, const at::Tensor &inputs, const at::Tensor &offsets);
   static at::Tensor embedding_bag_backward(const at::Tensor &grad_out, const at::Tensor &weights, const at::Tensor &inputs, const at::Tensor &offsets);
   static at::Tensor linear(const at::Tensor& input, const at::Tensor& weight, const c10::optional<at::Tensor>& bias);
+  static at::Tensor linear_fuse_relu(const at::Tensor& input, const at::Tensor& weight, const c10::optional<at::Tensor>& bias);
   static std::tuple<at::Tensor, at::Tensor, at::Tensor> linear_backward(const at::Tensor& input, const at::Tensor& grad_output, const at::Tensor& weight, std::array<bool,3> output_mask);
+  static at::Tensor relu_use_dst_for_bwd(const at::Tensor& grad_output, const at::Tensor& output);
   static at::Tensor adaptive_avg_pool2d(at::Tensor const& input, at::IntArrayRef output_size);
   static at::Tensor adaptive_avg_pool2d_backward(const at::Tensor& grad_output, const at::Tensor& input);
   static at::Tensor max_pooling(const at::Tensor& input, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode);

torch_ipex/csrc/cpu/dil/dil/operators/inner_product.hpp

Lines changed: 2 additions & 2 deletions

@@ -163,7 +163,7 @@ struct inner_product_forward : public dnnl::inner_product_forward {
       }
     }
   } else {
-    op_attr = attr;
+    op_attr = attr;
     src_desc = {src.get_dims(), data_type::f32, format_tag::any};
     if (src.has_scale()) {
       auto src_scale = src.get_scale();
@@ -233,7 +233,7 @@ struct inner_product_backward_data : public dnnl::inner_product_backward_data {
                       tensor& diff_src,
                       const engine& aengine = engine::cpu_engine()) {
     auto weights_ = weights;
-    if (diff_dst.get_data_type() == data_type::bf16) {
+    if (diff_dst.get_data_type() == data_type::bf16 && weights.get_data_type() != data_type::bf16) {
      weights_.init(weights.get_desc().to_type(data_type::bf16));
      weights_.reorder_from(weights);
     }

torch_ipex/csrc/init_python_bindings.cpp

Lines changed: 8 additions & 0 deletions

@@ -92,10 +92,18 @@ void InitIpexModuleBindings(py::module m) {
         [](const at::Tensor& input, const at::Tensor& weight, const c10::optional<at::Tensor>& bias) {
           return AtenIpexTypeExt::linear(input, weight, bias);
         });
+  m.def("linear_fuse_relu",
+        [](const at::Tensor& input, const at::Tensor& weight, const c10::optional<at::Tensor>& bias) {
+          return AtenIpexTypeExt::linear_fuse_relu(input, weight, bias);
+        });
   m.def("linear_backward",
        [](const at::Tensor& input, const at::Tensor& grad_output, const at::Tensor& weight, std::array<bool,3> output_mask) {
           return AtenIpexTypeExt::linear_backward(input, grad_output, weight, output_mask);
         });
+  m.def("relu_use_dst_backward",
+        [](const at::Tensor& grad_output, const at::Tensor& output) {
+          return AtenIpexTypeExt::relu_use_dst_for_bwd(grad_output, output);
+        });
   m.def("adaptive_avg_pool2d",
         [](at::Tensor const& input, at::IntArrayRef output_size) {
           return AtenIpexTypeExt::adaptive_avg_pool2d(input, output_size);

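For completeness, a small sketch of exercising the new bindings directly from Python; this mirrors what LinearFuseReluFC does internally. The names follow the m.def registrations above; the tensors are assumed to live on the 'dpcpp' device with auto-dnnl enabled, and the sizes are illustrative:

import torch
import _torch_ipex as core

x = torch.randn(4, 128).to('dpcpp')
w = torch.randn(64, 128).to('dpcpp')
b = torch.randn(64).to('dpcpp')

y = core.linear_fuse_relu(x, w, b)        # fused forward: relu(x @ w.t() + b)
dy = torch.ones_like(y)
dy = core.relu_use_dst_backward(dy, y)    # ReLU backward computed from the forward output
dx, dw, db = core.linear_backward(x, dy, w, (True, True, True))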