@@ -11,23 +11,21 @@ struct inner_product_forward : public dnnl::inner_product_forward {
                       const tensor& weights,
                       const tensor& bias,
                       tensor& dst,
-                      const bool fuse_relu = false,
                       const scale_t& src_scales = scale_t(),
                       const scale_t& weights_scales = scale_t(),
                       const scale_t& dst_scales = scale_t(),
                       const attr_t& attr = attr_t(),
                       const prop_kind aprop_kind = prop_kind::forward,
                       const lowp_kind alowp_kind = u8s8,
                       const engine& aengine = engine::cpu_engine()) {
-    compute_impl</*with_bias=*/true>(src, weights, bias, dst, fuse_relu, src_scales,
+    compute_impl</*with_bias=*/true>(src, weights, bias, dst, src_scales,
                                      weights_scales, dst_scales, attr,
                                      aprop_kind, alowp_kind, aengine);
   }
 
   static void compute(const tensor& src,
                       const tensor& weights,
                       tensor& dst,
-                      const bool fuse_relu = false,
                       const scale_t& src_scales = scale_t(),
                       const scale_t& weights_scales = scale_t(),
                       const scale_t& dst_scales = scale_t(),
@@ -36,7 +34,7 @@ struct inner_product_forward : public dnnl::inner_product_forward {
                       const lowp_kind alowp_kind = u8s8,
                       const engine& aengine = engine::cpu_engine()) {
     static tensor dummy_bias;
-    compute_impl</*with_bias=*/false>(src, weights, dummy_bias, dst, fuse_relu, src_scales,
+    compute_impl</*with_bias=*/false>(src, weights, dummy_bias, dst, src_scales,
                                       weights_scales, dst_scales, attr,
                                       aprop_kind, alowp_kind, aengine);
   }
@@ -71,7 +69,6 @@ struct inner_product_forward : public dnnl::inner_product_forward {
                            const tensor& weights,
                            const tensor& bias,
                            tensor& dst,
-                           const bool fuse_relu,
                            const scale_t& src_scales,
                            const scale_t& weights_scales,
                            const scale_t& dst_scales,
@@ -87,7 +84,7 @@ struct inner_product_forward : public dnnl::inner_product_forward {
       new_dims[0] = src.get_dim(0);
       src_.reshape(new_dims);
     }
-    compute_impl_<with_bias>(src_, weights, bias, dst, fuse_relu, src_scales,
+    compute_impl_<with_bias>(src_, weights, bias, dst, src_scales,
                              weights_scales, dst_scales, attr, aprop_kind,
                              alowp_kind, aengine);
   }
@@ -97,7 +94,6 @@ struct inner_product_forward : public dnnl::inner_product_forward {
                             const tensor& weights,
                             const tensor& bias,
                             tensor& dst,
-                            const bool fuse_relu,
                             const scale_t& src_scales,
                             const scale_t& weights_scales,
                             const scale_t& dst_scales,
@@ -167,7 +163,7 @@ struct inner_product_forward : public dnnl::inner_product_forward {
         }
       }
     } else {
-      op_attr = fuse_relu ? attr_t::fuse_relu() : attr;
+      op_attr = attr;
       src_desc = {src.get_dims(), data_type::f32, format_tag::any};
       if (src.has_scale()) {
         auto src_scale = src.get_scale();
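Note for callers: the removed fuse_relu flag is subsumed by the existing attr argument of the
compute overloads above. A minimal migration sketch (the tensor variables and surrounding setup
are placeholders; attr_t::fuse_relu() and the scale_t() defaults are taken from the diff):

    // Before this change: ReLU fusion was requested through the boolean flag,
    // with the remaining arguments left at their defaults.
    //   inner_product_forward::compute(src, weights, bias, dst, /*fuse_relu=*/true);

    // After this change: pass a fused-ReLU attribute instead. The three scale
    // arguments keep their scale_t() defaults so attr lands in the right slot.
    inner_product_forward::compute(src, weights, bias, dst,
                                   scale_t(), scale_t(), scale_t(),
                                   attr_t::fuse_relu());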