Commit 688dadc

[LLGA] map div to llga and change scalar input to tensor (#192)
1 parent 709111a

File tree

4 files changed: +85 -8 lines changed

tests/cpu/test_jit_llga_quantization_fuser.py

Lines changed: 71 additions & 0 deletions
@@ -400,6 +400,77 @@ def forward(self, x):
         self.assertFused(graph, ['aten::_convolution', 'aten::relu', 'aten::quantize_per_channel'])
         self.checkPatterns(graph, patterns)
 
+    @llga_test_env
+    def test_bmm_div_scalar(self):
+        class M(nn.Module):
+            def __init__(self, div_value):
+                super(M, self).__init__()
+                self.div_value = div_value
+
+            def forward(self, x, y):
+                mm_res = torch.matmul(x, y)
+                return mm_res.div(self.div_value)
+
+        x = torch.randn(128, 16, 384, 64)
+        y = torch.randn(128, 1, 64, 384)
+        patterns = [
+            ["aten::dequantize", "aten::matmul", "aten::div"],
+        ]
+        m = M(8.)
+        graph = self.checkQuantizeTrace(m, [x, y], atol=2e-1, config_name="bmm_div_scalar", qscheme=torch.per_tensor_affine)
+        # TODO: enable the below check when matmul-div fusion is supported in the backend
+        # self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
+        # self.assertFused(graph, ['aten::matmul', 'aten::div'])
+        # self.checkPatterns(graph, patterns)
+
+    @llga_test_env
+    def test_bmm_div_identity(self):
+        class M(nn.Module):
+            def __init__(self, div_value):
+                super(M, self).__init__()
+                self.div_value = div_value
+
+            def forward(self, x, y):
+                mm_res = torch.matmul(x, y)
+                return mm_res.div(self.div_value)
+
+        x = torch.randn(128, 16, 384, 64)
+        y = torch.randn(128, 1, 64, 384)
+        patterns = [
+            ["aten::dequantize", "aten::matmul"],
+        ]
+        m = M(1.)
+        graph = self.checkQuantizeTrace(m, [x, y], atol=2e-1, config_name="bmm_div_identity", qscheme=torch.per_tensor_affine)
+        # divide by 1 should be removed by Constant Propagation
+        self.assertGraphContainsExactly(graph, "aten::div", 0, consider_subgraphs=True)
+        self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
+        self.assertFused(graph, ['aten::matmul'])
+        # TODO: enable this check when int8 matmul is supported in the backend
+        # self.checkPatterns(graph, patterns)
+
+    @llga_test_env
+    def test_bmm_div_tensor(self):
+        class M(nn.Module):
+            def __init__(self):
+                super(M, self).__init__()
+
+            def forward(self, x, y, z):
+                mm_res = torch.matmul(x, y)
+                return mm_res.div(z)
+
+        x = torch.randn(128, 16, 384, 64)
+        y = torch.randn(128, 1, 64, 384)
+        patterns = [
+            ["aten::dequantize", "aten::matmul", "aten::div"],
+        ]
+        for z in [torch.randn(384), torch.randn(128, 16, 384, 384)]:
+            m = M()
+            graph = self.checkQuantizeTrace(m, [x, y, z], atol=2e-1, config_name="bmm_div_tensor", qscheme=torch.per_tensor_affine)
+            # TODO: enable the below check when matmul-div fusion is supported in the backend
+            # self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
+            # self.assertFused(graph, ['aten::matmul', 'aten::div'])
+            # self.checkPatterns(graph, patterns)
+
 class TestShapeFallback(JitLlgaTestCase):
     @unittest.skipIf(True, 'Size peephole optimization not enabled yet')
     @llga_test_env
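
For orientation, the matmul-then-div chain these tests quantize can be reproduced with stock PyTorch tracing alone. A minimal sketch (nothing IPEX-specific assumed; shapes shrunk for speed):

import torch
import torch.nn as nn

class BmmDiv(nn.Module):
    def forward(self, x, y):
        # Batched matmul followed by div, mirroring the test modules above.
        return torch.matmul(x, y).div(8.)

x = torch.randn(2, 4, 16, 8)
y = torch.randn(2, 1, 8, 16)
traced = torch.jit.trace(BmmDiv(), (x, y))

# The traced IR contains aten::matmul followed by aten::div -- the chain
# the LLGA fuser can target once aten::div is mapped to the backend.
kinds = [n.kind() for n in traced.graph.nodes()]
assert "aten::matmul" in kinds and "aten::div" in kinds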

torch_ipex/csrc/jit/codegen/onednn/graph_helper.cpp

Lines changed: 2 additions & 0 deletions
@@ -159,6 +159,8 @@ Operator createOperator(Node* node) {
         .setAttr("keep_stats", false);
   } else if (node->kind() == Symbol::aten("add")) {
     return makeBinaryOp(node, opkind::Add);
+  } else if (node->kind() == Symbol::aten("div")) {
+    return makeBinaryOp(node, opkind::Divide);
   } else if (node->kind() == Symbol::aten("tanh")) {
     return makeEltwiseOp(node, opkind::Tanh);
   } else if (node->kind() == Symbol::aten("relu")) {
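
The dispatch above keys purely on the TorchScript node kind. A small stock-PyTorch sketch showing that tensor division surfaces in the scripted IR as aten::div, the kind that now maps to opkind::Divide:

import torch

@torch.jit.script
def f(x: torch.Tensor, y: torch.Tensor):
    return x / y

# The C++ check node->kind() == Symbol::aten("div") corresponds to the
# "aten::div" node kind visible from Python in the scripted graph.
assert any(n.kind() == "aten::div" for n in f.graph.nodes())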

torch_ipex/csrc/jit/codegen/onednn/prepare_binary.cpp

Lines changed: 8 additions & 6 deletions
@@ -38,7 +38,8 @@ static void ConvertScalarToTensor(Block* block) {
       ConvertScalarToTensor(sub);
     }
 
-    if (node->kind() == aten::add || node->kind() == aten::mul) {
+    if (node->kind() == aten::add || node->kind() == aten::mul ||
+        node->kind() == aten::div) {
       mayConvertScalarInputToTensor(node);
     }
   }
@@ -71,24 +72,25 @@ static void DecomposeFusedAdd(Block* block) {
   }
 }
 
-static void EliminateIdentityMulAdd(Block* block) {
+static void EliminateIdentityMulAddDiv(Block* block) {
   for (auto node : block->nodes()) {
     for (auto sub : node->blocks()) {
-      EliminateIdentityMulAdd(sub);
+      EliminateIdentityMulAddDiv(sub);
     }
 
     if ((node->kind() == aten::add && compareConstValue(node->input(1), 0.0)) ||
-        (node->kind() == aten::mul && compareConstValue(node->input(1), 1.0))) {
+        (node->kind() == aten::mul && compareConstValue(node->input(1), 1.0)) ||
+        (node->kind() == aten::div && compareConstValue(node->input(1), 1.0))) {
       node->output()->replaceAllUsesWith(node->input(0));
     }
   }
 }
 
 void PrepareBinaryForLLGA(const std::shared_ptr<Graph>& graph) {
   DecomposeFusedAdd(graph->block());
-  EliminateIdentityMulAdd(graph->block());
+  EliminateIdentityMulAddDiv(graph->block());
   EliminateDeadCode(graph);
-  // ConvertScalarToTensor must be placed after EliminateIdentityMulAdd
+  // ConvertScalarToTensor must be placed after EliminateIdentityMulAddDiv
   ConvertScalarToTensor(graph->block());
   // TODO: after conv-bn folding, bias will become bias? (Optional) after this pass
   // and will lose it when using mustNotBeNone to check Optional Bias
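
Both rewrites in this pass rest on plain tensor semantics: a scalar operand and a dimension-[1] Float tensor broadcast identically, and dividing by 1 is an identity. A minimal sketch of the two invariants in stock PyTorch (no IPEX required):

import torch

x = torch.randn(4, 3)

# Invariant behind ConvertScalarToTensor: a scalar operand and a
# dimension-[1] Float tensor broadcast to the same result.
assert torch.equal(x.div(8.), x.div(torch.tensor([8.])))

# Invariant behind EliminateIdentityMulAddDiv: adding 0, multiplying by 1,
# and dividing by 1 leave the tensor unchanged, so the node's output can
# safely be replaced with its first input.
for same in (x + 0.0, x * 1.0, x / 1.0):
    assert torch.equal(same, x)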

torch_ipex/csrc/jit/codegen/onednn/prepare_binary.h

Lines changed: 4 additions & 2 deletions
@@ -11,12 +11,14 @@ namespace onednn {
 //
 // The pass does the following:
 //
-// - (1). Convert scalar input of aten::add and aten::mul into Float tensor with
+// - (1). Convert scalar input of aten::add, aten::mul and aten::div into Float
+//        tensor with
 //        dimension [1]
 //
 // - (2). Decompose fused add into aten::mul + aten::add when alpha != 1.0
 //
-// - (3). Eliminate identity add/mul, i.e., tensor + 0, tensor * 1
+// - (3). Eliminate identity add/mul/div, i.e., tensor + 0, tensor * 1,
+//        tensor / 1
 //
 // (1) and (2) are in the purpose of aligning with the OP spec of LLGA.
 // (3) is an optimization pass to remove the redundant calculation
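
Step (2) is untouched by this commit, but for completeness it rests on the identity add(a, b, alpha) = a + alpha * b. A minimal stock-PyTorch sketch of that decomposition:

import torch

a = torch.randn(5)
b = torch.randn(5)

# Step (2): fused add with alpha != 1.0 is equivalent to mul followed by
# add, i.e. torch.add(a, b, alpha) == a + alpha * b.
alpha = 0.5
assert torch.allclose(torch.add(a, b, alpha=alpha), a + alpha * b)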
