Skip to content

Commit c0931a8

Browse files
committed
hide dil from frontend and reuse attr args instead of fuse_relu
1 parent 6a614f0 commit c0931a8

File tree

4 files changed

+19
-15
lines changed

4 files changed

+19
-15
lines changed

intel_pytorch_extension_py/ops/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,5 +4,5 @@
44
from .pooling import *
55
from .reshape import *
66
from .mlp import *
7-
from .dil_linear_relu import *
7+
from .linear_fuse_relu import *
88

intel_pytorch_extension_py/ops/dil_linear_relu.py renamed to intel_pytorch_extension_py/ops/linear_fuse_relu.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import math
77
import _torch_ipex as core
88

9-
class dilLinearFuseReluFC(Function):
9+
class LinearFuseReluFC(Function):
1010
@staticmethod
1111
def forward(ctx, input, weight, bias):
1212
output = core.linear_fuse_relu(input, weight, bias)
@@ -25,13 +25,13 @@ def backward(ctx, grad_output):
2525
grad_input, grad_weight, grad_bias = core.linear_backward(input, grad_output, weight, output_mask)
2626
return (grad_input, grad_weight, grad_bias)
2727

28-
class dilLinearFuseRelu(nn.Module):
28+
class LinearFuseRelu(nn.Module):
2929
r"""DNNL Linear module for using relu fused DNNL kernel"""
3030

3131
__constants__ = ['bias']
3232

3333
def __init__(self, in_features, out_features, bias=True):
34-
super(dilLinearFuseRelu, self).__init__()
34+
super(LinearFuseRelu, self).__init__()
3535
self.in_features = in_features
3636
self.out_features = out_features
3737
self.weight = Parameter(torch.Tensor(out_features, in_features))
@@ -50,6 +50,6 @@ def reset_parameters(self):
5050

5151
def forward(self, input):
5252
# print(self.weight.shape)
53-
output = dilLinearFuseReluFC.apply(input, self.weight, self.bias)
53+
output = LinearFuseReluFC.apply(input, self.weight, self.bias)
5454
return output
5555

torch_ipex/csrc/cpu/DevOPs.cpp

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -568,9 +568,17 @@ at::Tensor AtenIpexCPUDev::dil_linear_fuse_relu(
568568
if (bias.has_value()) {
569569
at::Tensor bias_vec = bias.value();
570570
const dil::tensor b = dbl::comm::try_gen_dil_tensor(bias_vec);
571-
dil::inner_product_forward::compute(x, w, b, y, /*fuse_relu=*/true);
571+
dil::inner_product_forward::compute(x, w, b, y,
572+
/*src_scales=*/dil::scale_t(),
573+
/*weight_scales=*/dil::scale_t(),
574+
/*dst_scales=*/dil::scale_t(),
575+
/*attr*/dil::attr_t::fuse_relu());
572576
} else {
573-
dil::inner_product_forward::compute(x, w, y, /*fuse_relu=*/true);
577+
dil::inner_product_forward::compute(x, w, y,
578+
/*src_scales=*/dil::scale_t(),
579+
/*weight_scales=*/dil::scale_t(),
580+
/*dst_scales=*/dil::scale_t(),
581+
/*attr*/dil::attr_t::fuse_relu());
574582
}
575583

576584
auto input_size = self.sizes();

torch_ipex/csrc/cpu/dil/dil/operators/inner_product.hpp

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,23 +11,21 @@ struct inner_product_forward : public dnnl::inner_product_forward {
1111
const tensor& weights,
1212
const tensor& bias,
1313
tensor& dst,
14-
const bool fuse_relu = false,
1514
const scale_t& src_scales = scale_t(),
1615
const scale_t& weights_scales = scale_t(),
1716
const scale_t& dst_scales = scale_t(),
1817
const attr_t& attr = attr_t(),
1918
const prop_kind aprop_kind = prop_kind::forward,
2019
const lowp_kind alowp_kind = u8s8,
2120
const engine& aengine = engine::cpu_engine()) {
22-
compute_impl</*with_bias=*/true>(src, weights, bias, dst, fuse_relu, src_scales,
21+
compute_impl</*with_bias=*/true>(src, weights, bias, dst, src_scales,
2322
weights_scales, dst_scales, attr,
2423
aprop_kind, alowp_kind, aengine);
2524
}
2625

2726
static void compute(const tensor& src,
2827
const tensor& weights,
2928
tensor& dst,
30-
const bool fuse_relu = false,
3129
const scale_t& src_scales = scale_t(),
3230
const scale_t& weights_scales = scale_t(),
3331
const scale_t& dst_scales = scale_t(),
@@ -36,7 +34,7 @@ struct inner_product_forward : public dnnl::inner_product_forward {
3634
const lowp_kind alowp_kind = u8s8,
3735
const engine& aengine = engine::cpu_engine()) {
3836
static tensor dummy_bias;
39-
compute_impl</*with_bias=*/false>(src, weights, dummy_bias, dst, fuse_relu, src_scales,
37+
compute_impl</*with_bias=*/false>(src, weights, dummy_bias, dst, src_scales,
4038
weights_scales, dst_scales, attr,
4139
aprop_kind, alowp_kind, aengine);
4240
}
@@ -71,7 +69,6 @@ struct inner_product_forward : public dnnl::inner_product_forward {
7169
const tensor& weights,
7270
const tensor& bias,
7371
tensor& dst,
74-
const bool fuse_relu,
7572
const scale_t& src_scales,
7673
const scale_t& weights_scales,
7774
const scale_t& dst_scales,
@@ -87,7 +84,7 @@ struct inner_product_forward : public dnnl::inner_product_forward {
8784
new_dims[0] = src.get_dim(0);
8885
src_.reshape(new_dims);
8986
}
90-
compute_impl_<with_bias>(src_, weights, bias, dst, fuse_relu, src_scales,
87+
compute_impl_<with_bias>(src_, weights, bias, dst, src_scales,
9188
weights_scales, dst_scales, attr, aprop_kind,
9289
alowp_kind, aengine);
9390
}
@@ -97,7 +94,6 @@ struct inner_product_forward : public dnnl::inner_product_forward {
9794
const tensor& weights,
9895
const tensor& bias,
9996
tensor& dst,
100-
const bool fuse_relu,
10197
const scale_t& src_scales,
10298
const scale_t& weights_scales,
10399
const scale_t& dst_scales,
@@ -167,7 +163,7 @@ struct inner_product_forward : public dnnl::inner_product_forward {
167163
}
168164
}
169165
} else {
170-
op_attr = fuse_relu ? attr_t::fuse_relu() : attr;
166+
op_attr = attr;
171167
src_desc = {src.get_dims(), data_type::f32, format_tag::any};
172168
if (src.has_scale()) {
173169
auto src_scale = src.get_scale();

0 commit comments

Comments
 (0)