
Commit 59b29ab

sanchitintel, yifeizh2, and chunyuan-w authored
Graph Compiler support (#523)
* Prepare the bridge-code for SC MHA Fusion pattern type #2
* Map aten::permute
* For now, match ipex::softmax_, although ideally, LLGA shouldn't see it
* temporarily update LLGA branch to Yifei's
* Do not skip passing any PyTorch ops to oneDNN Graph
* Remove redundant temporary change
* Only map permute to oneDNN Graph if it's safe
* Trivial changes
* Fix logic for mapping permute
* [skip-ci] Add comment & remove extra newline
* Cherry-picked LLVM auto-detection on top of oneDNN dev-graph-preview4
* Update third_party/llga
* Do not build SC by default
* Update third_party llga with public repo
* Make required changes
* Fix style
* Avoid unnecessary function call
* Fix contiguous bug & formatting
* Add unit-test for contiguous
* Use AliasDb to safely map view ops to oneDNN Graph
* Temporarily change llga gitmodule path
* Fix formatting
* Add test_mha_pattern for int8 not working
* Revert temp changes & skip SC unit-tests
* Revert temp change in .gitmodules
* Update skip reason for UT
* fix int8-bf16 & int8-fp32 UT
* update skip reason
* Update unit-tests
* Change "Semi-Compiler" to "Graph Compiler"
* Rectify padding example
* Revert removal of isSupported
* Fix harmless typo
* Remove unused variable that should've been removed earlier
* Fix clang-format

Co-authored-by: yifeizh2 <yifei.zhang@intel.com>
Co-authored-by: Chunyuan WU <chunyuan.wu@intel.com>
1 parent cf4ea24 commit 59b29ab

5 files changed: 210 additions, 56 deletions

intel_extension_for_pytorch/csrc/jit/codegen/onednn/graph_helper.cpp

Lines changed: 60 additions & 54 deletions
@@ -168,12 +168,15 @@ Operator makeDequantOp(Node* node, Node* input_node) {
   }
 }
 
-Operator createOperator(Node* node) {
+Operator LlgaGraphHelper::createOperator(Node* node) const {
   // switch does not allow non-constexpr function, to make the Symbol constexpr,
   // we must add them to the list in aten/src/ATen/core/interned_strings.h to
   // explicitly use interned strings as symbols. Thus, we use if-else here
   // instead of switch to avoid having to apply patch on PyTorch.
-  if (node->kind() == Symbol::aten("conv2d")) {
+  auto nodeKind = node->kind();
+  // Calling node->kind() only once so that the compiler would create a
+  // jump-table
+  if (nodeKind == Symbol::aten("conv2d")) {
     fixConvOptionalBias(node);
     return Operator(node, opkind::Convolution)
         .setInput(0, 1, 2)
@@ -185,7 +188,7 @@ Operator createOperator(Node* node) {
         .setAttr("dilations", Operator::Ints, 5)
         .setAttr("groups", Operator::Int, 6)
         .setAttr("filter_format", std::string("OIX"));
-  } else if (node->kind() == Symbol::aten("_convolution")) {
+  } else if (nodeKind == Symbol::aten("_convolution")) {
     bool transposed = Operator::Bool(node, 6);
     REQ(!transposed);
 
@@ -199,7 +202,7 @@ Operator createOperator(Node* node) {
         .setAttr("dilations", Operator::Ints, 5)
         .setAttr("groups", Operator::Int, 8)
         .setAttr("filter_format", std::string("OIX"));
-  } else if (node->kind() == Symbol::aten("batch_norm")) {
+  } else if (nodeKind == Symbol::aten("batch_norm")) {
     auto training = toIValue(node->input(5));
     REQ(training.has_value()); // cannot get training status in script mode
     REQ(!training->toBool()); // TODO: support bn training
@@ -208,40 +211,40 @@ Operator createOperator(Node* node) {
         .setOutput(0)
         .setAttr("data_format", std::string("NCX"))
         .setAttr("epsilon", Operator::Float, 7);
-  } else if (node->kind() == Symbol::aten("layer_norm")) {
+  } else if (nodeKind == Symbol::aten("layer_norm")) {
     auto normalized_shape = Operator::Ints(node, 1);
     REQ(normalized_shape.size() == 1);
     return Operator(node, opkind::LayerNorm)
         .setInput(0, 2, 3)
         .setOutput(0)
         .setAttr("epsilon", Operator::Float, 4)
         .setAttr("keep_stats", false);
-  } else if (node->kind() == Symbol::aten("add")) {
+  } else if (nodeKind == Symbol::aten("add")) {
     return makeBinaryOp(node, opkind::Add);
-  } else if (node->kind() == Symbol::aten("div")) {
+  } else if (nodeKind == Symbol::aten("div")) {
     return makeBinaryOp(node, opkind::Divide);
-  } else if (node->kind() == Symbol::aten("tanh")) {
+  } else if (nodeKind == Symbol::aten("tanh")) {
     return makeEltwiseOp(node, opkind::Tanh);
-  } else if (node->kind() == Symbol::aten("relu")) {
+  } else if (nodeKind == Symbol::aten("relu")) {
     return makeEltwiseOp(node, opkind::ReLU);
-  } else if (node->kind() == Symbol::aten("elu")) {
+  } else if (nodeKind == Symbol::aten("elu")) {
     return makeEltwiseOp(node, opkind::Elu)
         .setAttr("alpha", Operator::Float, 1);
-  } else if (node->kind() == Symbol::aten("sigmoid")) {
+  } else if (nodeKind == Symbol::aten("sigmoid")) {
     return makeEltwiseOp(node, opkind::Sigmoid);
-  } else if (node->kind() == Symbol::aten("gelu")) {
+  } else if (nodeKind == Symbol::aten("gelu")) {
     return makeEltwiseOp(node, opkind::GELU);
-  } else if (node->kind() == Symbol::aten("sqrt")) {
+  } else if (nodeKind == Symbol::aten("sqrt")) {
     return makeEltwiseOp(node, opkind::Sqrt);
-  } else if (node->kind() == Symbol::aten("abs")) {
+  } else if (nodeKind == Symbol::aten("abs")) {
     return makeEltwiseOp(node, opkind::Abs);
-  } else if (node->kind() == Symbol::aten("square")) {
+  } else if (nodeKind == Symbol::aten("square")) {
     return makeEltwiseOp(node, opkind::Square);
-  } else if (node->kind() == Symbol::aten("hardtanh")) {
+  } else if (nodeKind == Symbol::aten("hardtanh")) {
     return makeEltwiseOp(node, opkind::HardTanh)
         .setAttr("min", Operator::Float, 1)
         .setAttr("max", Operator::Float, 2);
-  } else if (node->kind() == Symbol::aten("softmax")) {
+  } else if (nodeKind == Symbol::aten("softmax")) {
     auto dim0 = getDimensions(node->input(0));
     REQ(dim0.has_value());
 
@@ -253,7 +256,7 @@ Operator createOperator(Node* node) {
         .setInput(0)
         .setOutput(0)
         .setAttr("axis", axis);
-  } else if (node->kind() == Symbol::aten("cat")) {
+  } else if (nodeKind == Symbol::aten("cat")) {
     return makeWildcardOp(node); // TODO: remove once Concat is supported
 
     auto o = Operator(node, opkind::Concat);
@@ -273,7 +276,7 @@ Operator createOperator(Node* node) {
     for (auto input : listConstruct->inputs())
       o.setInputValue(input);
     return o.setOutput(0).setAttr("axis", Operator::Int, 1);
-  } else if (node->kind() == Symbol::aten("max_pool2d")) {
+  } else if (nodeKind == Symbol::aten("max_pool2d")) {
     auto rounding_type = Operator::Bool(node, 5) ? "ceil" : "floor";
     return Operator(node, opkind::MaxPool)
         .setInput(0)
@@ -285,7 +288,7 @@ Operator createOperator(Node* node) {
         .setAttr("pads_end", Operator::Ints, 3)
         .setAttr("dilations", Operator::Ints, 4)
         .setAttr("rounding_type", std::string(rounding_type));
-  } else if (node->kind() == Symbol::aten("avg_pool2d")) {
+  } else if (nodeKind == Symbol::aten("avg_pool2d")) {
     auto rounding_type = Operator::Bool(node, 4) ? "ceil" : "floor";
     auto divisor_override = toIValue(node->input(6));
     REQ(divisor_override->isNone());
@@ -299,17 +302,17 @@ Operator createOperator(Node* node) {
         .setAttr("pads_end", Operator::Ints, 3)
         .setAttr("exclude_pad", !Operator::Bool(node, 5))
         .setAttr("rounding_type", std::string(rounding_type));
-  } else if (node->kind() == Symbol::aten("matmul")) {
+  } else if (nodeKind == Symbol::aten("matmul")) {
     auto dim0 = getDimensions(node->input(0)).value_or(-1);
     auto dim1 = getDimensions(node->input(1)).value_or(-1);
     // TODO: support all shape combinations
     REQ((dim0 == 2 && dim1 == 2) || (dim0 == 4 && dim1 == 4) ||
         (dim0 == 3 && dim1 == 2));
     // fall through
     return Operator(node, opkind::MatMul).setInput(0, 1).setOutput(0);
-  } else if (node->kind() == Symbol::aten("mm")) {
+  } else if (nodeKind == Symbol::aten("mm")) {
     return Operator(node, opkind::MatMul).setInput(0, 1).setOutput(0);
-  } else if (node->kind() == Symbol::aten("linear")) {
+  } else if (nodeKind == Symbol::aten("linear")) {
     auto dim0 = getDimensions(node->input(0)).value_or(-1);
     auto dim1 = getDimensions(node->input(1)).value_or(-1);
     // REQ(dim1 == 2);
@@ -318,9 +321,9 @@ Operator createOperator(Node* node) {
         .setInput(0, 1, 2)
         .setOutput(0)
         .setAttr("transpose_b", true);
-  } else if (node->kind() == Symbol::aten("to")) {
+  } else if (nodeKind == Symbol::aten("to")) {
     return Operator(node, opkind::TypeCast).setInput(0).setOutput(0);
-  } else if (node->kind() == Symbol::aten("quantize_per_tensor")) {
+  } else if (nodeKind == Symbol::aten("quantize_per_tensor")) {
     // TODO: how to handle this case
     REQ(node->input(1)->node()->kind() != Symbol::aten("q_scale"));
 
@@ -337,7 +340,7 @@ Operator createOperator(Node* node) {
         .setAttr("zps", Operator::IntToVector, 2)
         .setAttr("out_type", Operator::String, 3)
         .setAttr("qtype", std::string("per_tensor"));
-  } else if (node->kind() == Symbol::aten("quantize_per_channel")) {
+  } else if (nodeKind == Symbol::aten("quantize_per_channel")) {
     return Operator(node, opkind::Quantize)
         .setInput(0)
         .setOutput(0)
@@ -346,7 +349,7 @@ Operator createOperator(Node* node) {
         .setAttr("axis", Operator::Int, 3)
         .setAttr("out_type", Operator::String, 4)
         .setAttr("qtype", std::string("per_channel"));
-  } else if (node->kind() == Symbol::aten("dequantize")) {
+  } else if (nodeKind == Symbol::aten("dequantize")) {
     if (node->numAttributes() == 0) {
       Node* input_node = node->input(0)->node();
       TORCH_CHECK(
@@ -384,17 +387,36 @@ Operator createOperator(Node* node) {
           .setAttr("qtype", node->s(Symbol::attr("qtype")));
       }
     }
+  } else if (nodeKind == Symbol::aten("permute")) {
+    REQ(aliasDb_->hasInputWriters(node) == false) {
+      return Operator(node, opkind::StaticTranspose)
+          .setInput(0)
+          .setOutput(0)
+          .setAttr("order", toIValue(node->input(1))->toIntVector());
+    }
+  } else if (nodeKind == Symbol::aten("contiguous")) {
+    // Contiguous should only be mapped to oneDNN Graph if the destination
+    // memory-layout is different than the source memory-format
+    // Strides would be different, but shape would be same
+    auto typeOfInput = node->input(0)->type()->expect<TensorType>();
+    auto typeOfOutput = node->output(0)->type()->expect<TensorType>();
+    auto inputStrides = typeOfInput->strides().concrete_sizes();
+    auto outputStrides = typeOfOutput->strides().concrete_sizes();
+    REQ(inputStrides != outputStrides);
+    return Operator(node, opkind::Reorder).setInput(0).setOutput(0);
   }
+
+  GRAPH_DEBUG("Making ", nodeKind.toQualString(), " a wildcard");
   return makeWildcardOp(node);
 }
 
-dnnl::graph::op createLlgaOp(Node* node) {
+dnnl::graph::op LlgaGraphHelper::createLlgaOp(Node* node) {
   return createOperator(node).llgaOp();
 }
 
-bool isSupported(Node* node) {
+bool LlgaGraphHelper::isSupported(Node* node) const {
   return createOperator(node).kind() != opkind::Wildcard;
-};
+}
 
 DeviceType inferDeviceFromValue(Value* v) {
   auto tt = v->type()->cast<TensorType>();
@@ -470,22 +492,21 @@ LlgaGraphHelper::LlgaGraphHelper(
   auto deviceType = inferDevice(graph);
   auto engineKind = getLlgaEngineKind(deviceType);
   dnnl::graph::graph g{engineKind};
-
+  aliasDb_ = torch::make_unique<torch::jit::AliasDb>(graph);
   GRAPH_DEBUG("Constructing LLGA graph");
   // TODO: select nodes in top-level block for now
   for (auto* node : graph->block()->nodes()) {
+    auto kindOfNode = node->kind();
     auto op = createLlgaOp(node);
 
     try {
      g.add_op(op);
+      GRAPH_DEBUG(" Added node ", kindOfNode.toQualString());
    } catch (std::exception& e) {
-      GRAPH_DEBUG(
-          "The backend failed to add node ", node->kind().toQualString());
+      GRAPH_DEBUG("The backend failed to add node ", kindOfNode.toQualString());
       g.add_op(makeWildcardOp(node).llgaOp());
     }
 
-    GRAPH_DEBUG(" Added node ", node->kind().toQualString());
-
     for (Value* input : node->inputs()) {
       tensorIdToValue_.emplace(input->unique(), input);
     }
@@ -528,20 +549,6 @@ bool LlgaGraphHelper::shouldMerge(Node* toMerge, Node* subgraph) {
       opToOwningPartition_.get(subgraph);
 }
 
-bool isViewOp(Node* n) {
-  switch (n->kind()) {
-    case aten::view:
-    case aten::view_as:
-    case aten::reshape:
-    case aten::reshape_as:
-    case aten::transpose:
-    case aten::expand:
-    case aten::expand_as:
-      return true;
-  }
-  return false;
-}
-
 void checkAndRemoveAttr(Node* n, std::string attr) {
   TORCH_CHECK(
       n->hasAttributeS(attr),
@@ -584,9 +591,6 @@ bool LlgaGraphHelper::shouldConsiderForMerge(Node* node) {
   if (isLlgaSubgraph(node)) {
     return true;
   }
-  if (isViewOp(node)) {
-    return false;
-  }
   // For a partition composed of 1 single quant, 1 single dequant or 1 single to
   // do not rewrite it in the bridge, so that the FWK may have chances
   // to optimize single int8/bf16 op that LLGA does not support
@@ -662,9 +666,11 @@ size_t LlgaGraphHelper::countSupportedOps(
     const std::shared_ptr<Graph>& graph) const {
   // TODO: count nodes in top-level block for now
   size_t cnt = 0;
-  for (auto* node : graph->block()->nodes())
-    if (isSupported(node))
+  for (auto* node : graph->block()->nodes()) {
+    if (isSupported(node)) {
       cnt++;
+    }
+  }
   return cnt;
 }
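Note on the new aten::contiguous mapping above: the bridge only hands contiguous() to oneDNN Graph as a Reorder when the input and output strides differ; otherwise the node is left to the framework. The following is a minimal standalone libtorch sketch of that criterion, using eager-mode tensors as a stand-in for the TensorType stride metadata the bridge actually inspects (not part of this commit):

// Eager-mode illustration of the stride criterion used for aten::contiguous.
#include <torch/torch.h>
#include <iostream>

int main() {
  auto x = torch::randn({2, 3, 4});
  // Already contiguous: contiguous() keeps the same strides, so the
  // REQ(inputStrides != outputStrides) check would leave such a node unmapped.
  std::cout << x.strides() << " -> " << x.contiguous().strides() << "\n";

  // Transposed view: contiguous() produces the same shape with different
  // strides, which is the case mapped to a oneDNN Graph Reorder.
  auto t = x.transpose(1, 2);
  std::cout << t.strides() << " -> " << t.contiguous().strides() << "\n";
  return 0;
}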

intel_extension_for_pytorch/csrc/jit/codegen/onednn/graph_helper.h

Lines changed: 9 additions & 0 deletions
@@ -2,6 +2,7 @@
 
 #include <oneapi/dnnl/dnnl_graph.hpp>
 #include <torch/csrc/jit/ir/ir.h>
+#include <torch/csrc/jit/passes/utils/subgraph_utils.h>
 #include "jit/codegen/onednn/operator.h"
 
 namespace torch {
@@ -62,11 +63,19 @@ class LlgaGraphHelper {
 
   std::map<size_t, Value*> getTensorIdToValue() const;
 
+  dnnl::graph::op createLlgaOp(Node* node);
+
+  Operator createOperator(Node* node) const;
+
+  bool isSupported(Node* node) const;
+
  private:
   size_t countSupportedOps(const std::shared_ptr<Graph>& graph) const;
 
   bool isSingleQuantDequantTo(Node* node);
 
+  std::unique_ptr<AliasDb> aliasDb_ = nullptr;
+
   OpPartitionMap opToOwningPartition_;
   std::vector<dnnl::graph::partition> partitions_;
   std::map<size_t, Value*>
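The new aliasDb_ member backs the permute guard in graph_helper.cpp: aten::permute is only mapped to StaticTranspose when alias analysis finds no writers to its input. A small standalone libtorch sketch (eager-mode tensors, not the JIT pass itself) of why that check matters: permute returns a view, so a later in-place write to its input remains visible through the permuted output, which an out-of-place transpose inside a fused partition would miss.

// Illustration only: permute produces a view, not a copy.
#include <torch/torch.h>
#include <iostream>

int main() {
  auto x = torch::zeros({2, 3});
  auto y = x.permute({1, 0});  // y aliases x's storage
  x.add_(1.0);                 // in-place write to the permute input
  // The write is visible through y; mapping this permute into a fused
  // partition that copies data up front would lose the update.
  std::cout << y << std::endl;
  return 0;
}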

intel_extension_for_pytorch/csrc/jit/codegen/onednn/prepare_binary.cpp

Lines changed: 6 additions & 0 deletions
@@ -27,7 +27,13 @@ void mayConvertScalarInputToTensor(Node* node) {
   auto t = g->insert(
       aten::as_tensor, {scalar}, {{"dtype", at::ScalarType::Float}});
   // tensor(42.0) : Float([]) --> tensor([42.0]) : Float([1])
+  c10::optional<size_t> t_dim = 1;
+  auto target_type = TensorTypePtr(
+      TensorType::create(at::ScalarType::Float, at::kCPU, t_dim, false));
+  target_type = target_type->withSizes({1});
+  t->setType(target_type);
   auto unsqueezed = g->insert(aten::unsqueeze, {t, 0});
+  unsqueezed->setType(target_type);
   node->replaceInput(1, unsqueezed);
   // Add a mark here and convert tensor back to scalar later on for unfused
   // add/div
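The added lines annotate the inserted nodes with the concrete type described in the comment: tensor(42.0) : Float([]) becomes tensor([42.0]) : Float([1]) after the unsqueeze. A minimal eager-mode sketch of that same shape change (plain libtorch calls rather than graph inserts, shown only for illustration):

// Eager-mode stand-in for the scalar-to-tensor conversion annotated above.
#include <torch/torch.h>
#include <iostream>

int main() {
  auto t = torch::scalar_tensor(42.0, torch::dtype(torch::kFloat));  // 0-dim: Float([])
  auto unsqueezed = t.unsqueeze(0);                                  // 1-dim: Float([1])
  std::cout << t.dim() << " " << t.sizes() << "\n";                   // 0 []
  std::cout << unsqueezed.dim() << " " << unsqueezed.sizes() << "\n"; // 1 [1]
  return 0;
}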

0 commit comments
