@@ -168,12 +168,15 @@ Operator makeDequantOp(Node* node, Node* input_node) {
   }
 }
 
-Operator createOperator(Node* node) {
+Operator LlgaGraphHelper::createOperator(Node* node) const {
   // switch does not allow non-constexpr function, to make the Symbol constexpr,
   // we must add them to the list in aten/src/ATen/core/interned_strings.h to
   // explicitly use interned strings as symbols. Thus, we use if-else here
   // instead of switch to avoid having to apply patch on PyTorch.
-  if (node->kind() == Symbol::aten("conv2d")) {
+  auto nodeKind = node->kind();
+  // Calling node->kind() only once so that the compiler would create a
+  // jump-table
+  if (nodeKind == Symbol::aten("conv2d")) {
     fixConvOptionalBias(node);
     return Operator(node, opkind::Convolution)
         .setInput(0, 1, 2)
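The comment in the hunk above refers to the fact that `Symbol::aten("conv2d")` is an ordinary runtime call, not a constant expression, so the node kinds cannot be used directly as `switch` case labels; the PR instead hoists `node->kind()` into `nodeKind` and keeps the if-else chain. A minimal standalone sketch of that dispatch pattern (the `kindName` helper is hypothetical and only illustrates the shape of the real `createOperator`):

// Sketch only, assuming the usual torch JIT IR headers.
#include <string>
#include <torch/csrc/jit/ir/ir.h>

using torch::jit::Node;
using torch::jit::Symbol;

// Hypothetical helper: Symbol::aten("...") is evaluated at runtime, so an
// if-else chain (not a switch) is used, and node->kind() is read exactly once.
std::string kindName(Node* node) {
  auto nodeKind = node->kind(); // hoisted, queried a single time
  if (nodeKind == Symbol::aten("conv2d")) {
    return "convolution";
  } else if (nodeKind == Symbol::aten("relu")) {
    return "eltwise-relu";
  }
  return "unsupported"; // falls through to a wildcard op in the real code
}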
@@ -185,7 +188,7 @@ Operator createOperator(Node* node) {
         .setAttr("dilations", Operator::Ints, 5)
         .setAttr("groups", Operator::Int, 6)
         .setAttr("filter_format", std::string("OIX"));
-  } else if (node->kind() == Symbol::aten("_convolution")) {
+  } else if (nodeKind == Symbol::aten("_convolution")) {
     bool transposed = Operator::Bool(node, 6);
     REQ(!transposed);
 
@@ -199,7 +202,7 @@ Operator createOperator(Node* node) {
         .setAttr("dilations", Operator::Ints, 5)
         .setAttr("groups", Operator::Int, 8)
         .setAttr("filter_format", std::string("OIX"));
-  } else if (node->kind() == Symbol::aten("batch_norm")) {
+  } else if (nodeKind == Symbol::aten("batch_norm")) {
     auto training = toIValue(node->input(5));
     REQ(training.has_value()); // cannot get training status in script mode
     REQ(!training->toBool()); // TODO: support bn training
@@ -208,40 +211,40 @@ Operator createOperator(Node* node) {
         .setOutput(0)
         .setAttr("data_format", std::string("NCX"))
         .setAttr("epsilon", Operator::Float, 7);
-  } else if (node->kind() == Symbol::aten("layer_norm")) {
+  } else if (nodeKind == Symbol::aten("layer_norm")) {
     auto normalized_shape = Operator::Ints(node, 1);
     REQ(normalized_shape.size() == 1);
     return Operator(node, opkind::LayerNorm)
         .setInput(0, 2, 3)
         .setOutput(0)
         .setAttr("epsilon", Operator::Float, 4)
         .setAttr("keep_stats", false);
-  } else if (node->kind() == Symbol::aten("add")) {
+  } else if (nodeKind == Symbol::aten("add")) {
     return makeBinaryOp(node, opkind::Add);
-  } else if (node->kind() == Symbol::aten("div")) {
+  } else if (nodeKind == Symbol::aten("div")) {
     return makeBinaryOp(node, opkind::Divide);
-  } else if (node->kind() == Symbol::aten("tanh")) {
+  } else if (nodeKind == Symbol::aten("tanh")) {
     return makeEltwiseOp(node, opkind::Tanh);
-  } else if (node->kind() == Symbol::aten("relu")) {
+  } else if (nodeKind == Symbol::aten("relu")) {
     return makeEltwiseOp(node, opkind::ReLU);
-  } else if (node->kind() == Symbol::aten("elu")) {
+  } else if (nodeKind == Symbol::aten("elu")) {
     return makeEltwiseOp(node, opkind::Elu)
         .setAttr("alpha", Operator::Float, 1);
-  } else if (node->kind() == Symbol::aten("sigmoid")) {
+  } else if (nodeKind == Symbol::aten("sigmoid")) {
     return makeEltwiseOp(node, opkind::Sigmoid);
-  } else if (node->kind() == Symbol::aten("gelu")) {
+  } else if (nodeKind == Symbol::aten("gelu")) {
     return makeEltwiseOp(node, opkind::GELU);
-  } else if (node->kind() == Symbol::aten("sqrt")) {
+  } else if (nodeKind == Symbol::aten("sqrt")) {
     return makeEltwiseOp(node, opkind::Sqrt);
-  } else if (node->kind() == Symbol::aten("abs")) {
+  } else if (nodeKind == Symbol::aten("abs")) {
     return makeEltwiseOp(node, opkind::Abs);
-  } else if (node->kind() == Symbol::aten("square")) {
+  } else if (nodeKind == Symbol::aten("square")) {
     return makeEltwiseOp(node, opkind::Square);
-  } else if (node->kind() == Symbol::aten("hardtanh")) {
+  } else if (nodeKind == Symbol::aten("hardtanh")) {
     return makeEltwiseOp(node, opkind::HardTanh)
         .setAttr("min", Operator::Float, 1)
         .setAttr("max", Operator::Float, 2);
-  } else if (node->kind() == Symbol::aten("softmax")) {
+  } else if (nodeKind == Symbol::aten("softmax")) {
     auto dim0 = getDimensions(node->input(0));
     REQ(dim0.has_value());
 
@@ -253,7 +256,7 @@ Operator createOperator(Node* node) {
         .setInput(0)
         .setOutput(0)
         .setAttr("axis", axis);
-  } else if (node->kind() == Symbol::aten("cat")) {
+  } else if (nodeKind == Symbol::aten("cat")) {
     return makeWildcardOp(node); // TODO: remove once Concat is supported
 
     auto o = Operator(node, opkind::Concat);
@@ -273,7 +276,7 @@ Operator createOperator(Node* node) {
     for (auto input : listConstruct->inputs())
       o.setInputValue(input);
     return o.setOutput(0).setAttr("axis", Operator::Int, 1);
-  } else if (node->kind() == Symbol::aten("max_pool2d")) {
+  } else if (nodeKind == Symbol::aten("max_pool2d")) {
     auto rounding_type = Operator::Bool(node, 5) ? "ceil" : "floor";
     return Operator(node, opkind::MaxPool)
         .setInput(0)
@@ -285,7 +288,7 @@ Operator createOperator(Node* node) {
         .setAttr("pads_end", Operator::Ints, 3)
         .setAttr("dilations", Operator::Ints, 4)
         .setAttr("rounding_type", std::string(rounding_type));
-  } else if (node->kind() == Symbol::aten("avg_pool2d")) {
+  } else if (nodeKind == Symbol::aten("avg_pool2d")) {
     auto rounding_type = Operator::Bool(node, 4) ? "ceil" : "floor";
     auto divisor_override = toIValue(node->input(6));
     REQ(divisor_override->isNone());
@@ -299,17 +302,17 @@ Operator createOperator(Node* node) {
         .setAttr("pads_end", Operator::Ints, 3)
         .setAttr("exclude_pad", !Operator::Bool(node, 5))
         .setAttr("rounding_type", std::string(rounding_type));
-  } else if (node->kind() == Symbol::aten("matmul")) {
+  } else if (nodeKind == Symbol::aten("matmul")) {
     auto dim0 = getDimensions(node->input(0)).value_or(-1);
     auto dim1 = getDimensions(node->input(1)).value_or(-1);
     // TODO: support all shape combinations
     REQ((dim0 == 2 && dim1 == 2) || (dim0 == 4 && dim1 == 4) ||
         (dim0 == 3 && dim1 == 2));
     // fall through
     return Operator(node, opkind::MatMul).setInput(0, 1).setOutput(0);
-  } else if (node->kind() == Symbol::aten("mm")) {
+  } else if (nodeKind == Symbol::aten("mm")) {
     return Operator(node, opkind::MatMul).setInput(0, 1).setOutput(0);
-  } else if (node->kind() == Symbol::aten("linear")) {
+  } else if (nodeKind == Symbol::aten("linear")) {
     auto dim0 = getDimensions(node->input(0)).value_or(-1);
     auto dim1 = getDimensions(node->input(1)).value_or(-1);
     // REQ(dim1 == 2);
@@ -318,9 +321,9 @@ Operator createOperator(Node* node) {
         .setInput(0, 1, 2)
         .setOutput(0)
         .setAttr("transpose_b", true);
-  } else if (node->kind() == Symbol::aten("to")) {
+  } else if (nodeKind == Symbol::aten("to")) {
     return Operator(node, opkind::TypeCast).setInput(0).setOutput(0);
-  } else if (node->kind() == Symbol::aten("quantize_per_tensor")) {
+  } else if (nodeKind == Symbol::aten("quantize_per_tensor")) {
     // TODO: how to handle this case
     REQ(node->input(1)->node()->kind() != Symbol::aten("q_scale"));
 
@@ -337,7 +340,7 @@ Operator createOperator(Node* node) {
         .setAttr("zps", Operator::IntToVector, 2)
         .setAttr("out_type", Operator::String, 3)
         .setAttr("qtype", std::string("per_tensor"));
-  } else if (node->kind() == Symbol::aten("quantize_per_channel")) {
+  } else if (nodeKind == Symbol::aten("quantize_per_channel")) {
     return Operator(node, opkind::Quantize)
         .setInput(0)
         .setOutput(0)
@@ -346,7 +349,7 @@ Operator createOperator(Node* node) {
         .setAttr("axis", Operator::Int, 3)
         .setAttr("out_type", Operator::String, 4)
         .setAttr("qtype", std::string("per_channel"));
-  } else if (node->kind() == Symbol::aten("dequantize")) {
+  } else if (nodeKind == Symbol::aten("dequantize")) {
     if (node->numAttributes() == 0) {
       Node* input_node = node->input(0)->node();
       TORCH_CHECK(
@@ -384,17 +387,36 @@ Operator createOperator(Node* node) {
             .setAttr("qtype", node->s(Symbol::attr("qtype")));
       }
     }
+  } else if (nodeKind == Symbol::aten("permute")) {
+    REQ(aliasDb_->hasInputWriters(node) == false) {
+      return Operator(node, opkind::StaticTranspose)
+          .setInput(0)
+          .setOutput(0)
+          .setAttr("order", toIValue(node->input(1))->toIntVector());
+    }
+  } else if (nodeKind == Symbol::aten("contiguous")) {
+    // Contiguous should only be mapped to oneDNN Graph if the destination
+    // memory-layout is different than the source memory-format
+    // Strides would be different, but shape would be same
+    auto typeOfInput = node->input(0)->type()->expect<TensorType>();
+    auto typeOfOutput = node->output(0)->type()->expect<TensorType>();
+    auto inputStrides = typeOfInput->strides().concrete_sizes();
+    auto outputStrides = typeOfOutput->strides().concrete_sizes();
+    REQ(inputStrides != outputStrides);
+    return Operator(node, opkind::Reorder).setInput(0).setOutput(0);
   }
+
+  GRAPH_DEBUG("Making ", nodeKind.toQualString(), " a wildcard");
   return makeWildcardOp(node);
 }
 
-dnnl::graph::op createLlgaOp(Node* node) {
+dnnl::graph::op LlgaGraphHelper::createLlgaOp(Node* node) {
   return createOperator(node).llgaOp();
 }
 
-bool isSupported(Node* node) {
+bool LlgaGraphHelper::isSupported(Node* node) const {
   return createOperator(node).kind() != opkind::Wildcard;
-};
+}
 
 DeviceType inferDeviceFromValue(Value* v) {
   auto tt = v->type()->cast<TensorType>();
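Throughout createOperator, the REQ(...) checks bail out to a wildcard op whenever a node cannot be mapped to oneDNN Graph, and isSupported simply tests for that wildcard kind. The macro itself is not part of this diff; a plausible definition, consistent with how it is used here (an assumption, not quoted from the file), is roughly:

// Assumed shape of the REQ macro (not shown in this diff): if the condition
// fails, log it and return a Wildcard operator for the current node, leaving
// the node to the default PyTorch execution path.
#define REQ(cond)                                     \
  if (!(cond)) {                                      \
    GRAPH_DEBUG("Unsupported condition " #cond "\n"); \
    return makeWildcardOp(node);                      \
  }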
@@ -470,22 +492,21 @@ LlgaGraphHelper::LlgaGraphHelper(
   auto deviceType = inferDevice(graph);
   auto engineKind = getLlgaEngineKind(deviceType);
   dnnl::graph::graph g{engineKind};
-
+  aliasDb_ = torch::make_unique<torch::jit::AliasDb>(graph);
   GRAPH_DEBUG("Constructing LLGA graph");
   // TODO: select nodes in top-level block for now
   for (auto* node : graph->block()->nodes()) {
+    auto kindOfNode = node->kind();
     auto op = createLlgaOp(node);
 
     try {
       g.add_op(op);
+      GRAPH_DEBUG("Added node ", kindOfNode.toQualString());
     } catch (std::exception& e) {
-      GRAPH_DEBUG(
-          "The backend failed to add node ", node->kind().toQualString());
+      GRAPH_DEBUG("The backend failed to add node ", kindOfNode.toQualString());
       g.add_op(makeWildcardOp(node).llgaOp());
     }
 
-    GRAPH_DEBUG("Added node ", node->kind().toQualString());
-
     for (Value* input : node->inputs()) {
       tensorIdToValue_.emplace(input->unique(), input);
     }
@@ -528,20 +549,6 @@ bool LlgaGraphHelper::shouldMerge(Node* toMerge, Node* subgraph) {
       opToOwningPartition_.get(subgraph);
 }
 
-bool isViewOp(Node* n) {
-  switch (n->kind()) {
-    case aten::view:
-    case aten::view_as:
-    case aten::reshape:
-    case aten::reshape_as:
-    case aten::transpose:
-    case aten::expand:
-    case aten::expand_as:
-      return true;
-  }
-  return false;
-}
-
 void checkAndRemoveAttr(Node* n, std::string attr) {
   TORCH_CHECK(
       n->hasAttributeS(attr),
@@ -584,9 +591,6 @@ bool LlgaGraphHelper::shouldConsiderForMerge(Node* node) {
   if (isLlgaSubgraph(node)) {
     return true;
   }
-  if (isViewOp(node)) {
-    return false;
-  }
   // For a partition composed of 1 single quant, 1 single dequant or 1 single to
   // do not rewrite it in the bridge, so that the FWK may have chances
   // to optimize single int8/bf16 op that LLGA does not support
@@ -662,9 +666,11 @@ size_t LlgaGraphHelper::countSupportedOps(
     const std::shared_ptr<Graph>& graph) const {
   // TODO: count nodes in top-level block for now
   size_t cnt = 0;
-  for (auto* node : graph->block()->nodes())
-    if (isSupported(node))
+  for (auto* node : graph->block()->nodes()) {
+    if (isSupported(node)) {
       cnt++;
+    }
+  }
   return cnt;
 }
 