
Commit 060ea58

Fuse shuffle (#183)
* Fuse the shuffle pattern and leverage the oneDNN implementation, because PyTorch does not provide a shuffle primitive.
* Fix code style issues (formatted by clang-format).
1 parent f4af2b9 commit 060ea58
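
Background for the fusion: PyTorch has no dedicated shuffle primitive, so channel shuffle is written in eager mode as a view, a transpose, and a second view; this commit matches that subgraph and replaces it with a single oneDNN channel-shuffle call. A minimal sketch of the eager pattern (illustrative only, not code from this commit):

import torch

def channel_shuffle(x, groups):
    # x: (N, C, H, W) with C divisible by groups.
    n, c, h, w = x.shape
    x = x.view(n, groups, c // groups, h, w)  # split channels into groups
    x = x.transpose(1, 2).contiguous()        # swap group and per-group dims
    return x.view(n, c, h, w)                 # flatten back to (N, C, H, W)

Fusing the three ops into one primitive should also avoid materializing the intermediate contiguous copy.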

File tree: 4 files changed, +384 −341 lines


tests/cpu/test_jit.py (6 additions, 0 deletions)

@@ -872,6 +872,12 @@ def test_output_linear_gelu(self):
             prec=5e-3,
             levels=['O0'])

+    def test_channel_shuffle(self):
+        self._test_output(
+            ChannelShuffle(10, 16, 50, 50, 4),
+            torch.rand(10, 16, 50, 50),
+            kind_in_graph="ipex::shuffle_2d")
+
     def test_jit_function(self):
         # test that trace and script both work for a function
         def fn(input, weight, bias):
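
The ChannelShuffle helper used above is defined elsewhere in the test suite and is not part of this diff. A plausible reconstruction, hypothetical but consistent with the call ChannelShuffle(10, 16, 50, 50, 4) and the (10, 16, 50, 50) input:

import torch
import torch.nn as nn

class ChannelShuffle(nn.Module):
    # Hypothetical reconstruction: emits the view -> transpose -> view
    # subgraph that the ipex::shuffle_2d fusion pass is expected to match.
    def __init__(self, batchsize, num_channels, height, width, groups):
        super(ChannelShuffle, self).__init__()
        self.batchsize = batchsize
        self.num_channels = num_channels
        self.height = height
        self.width = width
        self.groups = groups

    def forward(self, x):
        g = self.groups
        x = x.view(self.batchsize, g, self.num_channels // g,
                   self.height, self.width)
        x = torch.transpose(x, 1, 2).contiguous()
        return x.view(self.batchsize, -1, self.height, self.width)

The kind_in_graph="ipex::shuffle_2d" argument asserts that the fused node actually appears in the traced graph.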

torch_ipex/csrc/cpu/CustomOPs.cpp (39 additions, 30 deletions)
@@ -79,16 +79,16 @@ at::Tensor AtenIpexJITDev::dil_convolution_sigmoid(
 }

 /**
-* Dispatch at::matmul + at::div pattern to ipex for jit inference, but only one-element
-* tensor and channel dim boadcast is enabled in oneDNN 2.2.0 now. So, for simplicity,this path is just
-* a fallback path now.
-* output(out) = (tensor1 * tensor2).div(div_input)
-*
-* @param tensor1
-* @param tensor2
-* @param out Optinal output provided by user for matmul
-* @param div_input Input Tensor for div
-* @return Value for the fusion pattern output.
+* Dispatch at::matmul + at::div pattern to ipex for jit inference, but only
+* one-element tensors and channel-dim broadcast are enabled in oneDNN 2.2.0 now.
+* So, for simplicity, this path is just a fallback path now. output(out) =
+* (tensor1 * tensor2).div(div_input)
+*
+* @param tensor1
+* @param tensor2
+* @param out Optional output provided by the user for matmul
+* @param div_input Input tensor for div
+* @return Value for the fusion pattern output.
 */
 at::Tensor AtenIpexJITDev::dil_matmul_div(
     const at::Tensor& tensor1,
@@ -101,19 +101,18 @@ at::Tensor AtenIpexJITDev::dil_matmul_div(
   if (out.defined()) {
     at::matmul_out(out, tensor1, tensor2);
     return out.div(div_input);
-  }
+  }
   auto output = at::matmul(tensor1, tensor2);
   return output.div(div_input);
-
-
 }

 /**
-*Dispatch at::matmul + at::div pattern to ipex for jit inference, but only bmm with same shape for
-*tensor1 and tensor2 and scalar input for div will be dispatched to oneDNN kernel. Otherwise will fallback.
-*For oneDNN kernel, scalar input will be used as the scale attribute for matmul primitive.
+*Dispatch at::matmul + at::div pattern to ipex for jit inference, but only bmm
+*with the same shape for tensor1 and tensor2 and a scalar input for div will be
+*dispatched to the oneDNN kernel; otherwise it will fall back. For the oneDNN
+*kernel, the scalar input is used as the scale attribute for the matmul primitive.
 *output(out) = (tensor1 * tensor2).div(div_input_scalar).
-*ToDo: matmul + div scalar for matmul with other shape
+*ToDo: matmul + div scalar for matmul with other shapes
 *
 *@param tensor1
 *@param tensor2
@@ -131,8 +130,8 @@ at::Tensor AtenIpexJITDev::dil_matmul_div(
 #endif
   auto dim_tensor1 = tensor1.dim();
   auto dim_tensor2 = tensor2.dim();
-  if (dim_tensor1 == dim_tensor2 && dim_tensor1 >= 3) {
-    float scale = 1.0 / div_input.to<float>();
+  if (dim_tensor1 == dim_tensor2 && dim_tensor1 >= 3) {
+    float scale = 1.0f / div_input.to<float>();
     return bmm_impl(tensor1, tensor2, out, ideep::attr_t(), scale);
   } else {
     return AtenIpexJITDev::dil_matmul_div(tensor1, tensor2, out, at::native::wrapped_scalar_tensor(div_input));
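
For reference, both dil_matmul_div overloads serve the eager-mode pattern sketched below (illustrative, not from this commit). When the inputs have equal rank of at least 3 and the divisor is a scalar, the division folds into the oneDNN matmul as a scale of 1/div; every other case takes the at::matmul followed by div fallback above.

import math
import torch

def matmul_div(tensor1, tensor2, div_input):
    # The JIT pattern both kernels serve: (tensor1 @ tensor2) / div_input.
    return torch.matmul(tensor1, tensor2).div(div_input)

# A typical source of the pattern: scaling attention scores.
q = torch.rand(2, 8, 16, 64)
k = torch.rand(2, 8, 16, 64)
# Scalar divisor and equal-rank inputs: eligible for the scale-attribute path.
scores = matmul_div(q, k.transpose(-2, -1), math.sqrt(64))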
@@ -309,26 +308,24 @@ at::Tensor AtenIpexJITDev::dil_linear_fuse_eltwise(
   return linear_impl(self, weight, bias, attr);
 }

-
 /**
 *Dispatch Linear + Add fusion pattern to ipex oneDNN kernel for inference mode.
 *This feature might improve performance for cases like residual learning blocks
-*Pattern: accum = accum * alpha + Linear(self, weight, bias)
+*Pattern: accum = accum * alpha + Linear(self, weight, bias)
 *
-*@param self Activatin input for Linear
+*@param self Activation input for Linear
 *@param weight Weight for Linear
 *@param bias Bias for Linear
 *@param accum One input of the add operation; the other is the output of Linear
-*@param alpha Scale for accum when doing add operation.
+*@param alpha Scale for accum when doing the add operation.
 *
-*@return Value for the fusion pattern output.
+*@return Value for the fusion pattern output.
 */
-at::Tensor AtenIpexJITDev::dil_linear_add(
-    const at::Tensor& self,
-    const at::Tensor& weight,
-    const at::Tensor& bias,
-    at::Tensor& accumu,
-    at::Scalar alpha) {
+at::Tensor AtenIpexJITDev::dil_linear_add(const at::Tensor &self,
+                                          const at::Tensor &weight,
+                                          const at::Tensor &bias,
+                                          at::Tensor &accumu,
+                                          at::Scalar alpha) {
 #if defined(IPEX_PROFILE_OP)
   RECORD_FUNCTION("AtenIpexJITDev::dil_linear_add", std::vector<c10::IValue>({}));
 #endif
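
The pattern handled here is the residual form named in the comment, accum = accum * alpha + Linear(self, weight, bias). A minimal eager-mode equivalent (illustrative sketch, not from this commit):

import torch
import torch.nn.functional as F

def linear_add(x, weight, bias, accum, alpha=1.0):
    # Residual-style pattern matched by the dil_linear_add fusion; the
    # kernel computes the Linear and the scaled add together instead of
    # materializing the Linear output and adding in a second pass.
    return accum * alpha + F.linear(x, weight, bias)

x = torch.rand(4, 64)
weight = torch.rand(128, 64)  # (out_features, in_features)
bias = torch.rand(128)
accum = torch.rand(4, 128)
out = linear_add(x, weight, bias, accum)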
@@ -468,5 +465,17 @@ at::Tensor AtenIpexJITDev::dil_layernorm(
       at::native_layer_norm(input, normalized_shape, weight, bias, eps));
 }

+at::Tensor AtenIpexJITDev::dil_shuffle(const at::Tensor &self,
+                                       at::IntArrayRef view_shape, int64_t dim0,
+                                       int64_t dim1) {
+  ideep::tensor _self = itensor_view_from_dense(self);
+  auto group_dim = dim0 < dim1 ? dim0 : dim1;
+  auto groups = view_shape[group_dim];
+  auto output = at::empty_like(self);
+  ideep::tensor _output = itensor_view_from_dense(output);
+  ideep::channel_shuffle_forward::compute(_self, _output, groups, group_dim);
+  return output;
+}
+
 } // namespace cpu
 } // namespace torch_ipex
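
A note on how dil_shuffle recovers the oneDNN parameters from the matched subgraph: the transpose swaps the group dimension and the channels-per-group dimension of the 5-D view, so the smaller of the two transposed dims is the group dimension, and the view shape at that dim is the group count. The same mapping in Python (illustrative, mirroring the C++ above):

def shuffle_params(view_shape, dim0, dim1):
    # For a view such as (N, g, C // g, H, W) transposed over dims (1, 2),
    # the group dimension is the smaller transposed dim and the group
    # count is the view size at that dimension.
    group_dim = min(dim0, dim1)
    groups = view_shape[group_dim]
    return groups, group_dim

# e.g. 16 channels shuffled in 4 groups: x.view(10, 4, 4, 50, 50).transpose(1, 2)
assert shuffle_params([10, 4, 4, 50, 50], 1, 2) == (4, 1)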

torch_ipex/csrc/cpu/CustomOPs.h (4 additions, 0 deletions)

@@ -194,6 +194,10 @@ class AtenIpexJITDev {
       bool weight_channels_last, bool weight_prepacked, at::Tensor &accumu,
       at::Scalar alpha);

+  static at::Tensor dil_shuffle(const at::Tensor &self,
+                                at::IntArrayRef view_shape, int64_t dim0,
+                                int64_t dim1);
+
   // int8 op
   static at::Tensor dil_qembeddingbag(const at::Tensor weight,
                                       const at::Tensor indices,
