@@ -370,99 +370,77 @@ void matmul_common(
       dil::scale_t(), dil::scale_t(), dil::scale_t(), attr);
 }
 
-at::Tensor AtenIpexCPUDev::dil_bmm(
-    const at::Tensor& self,
-    const at::Tensor& mat2) {
+at::Tensor AtenIpexCPUDev::dil_bmm(const at::Tensor& self, const at::Tensor& mat2) {
   DEBUG("AtenIpexCPUDev::dil_bmm\n");
 
-  at::Tensor result = at::empty({0}, self.options());
+  auto result = dbl::comm::empty_dil_tensor({0}, self.options());
   return dil_bmm_out(result, self, mat2);
 }
 
-at::Tensor& AtenIpexCPUDev::dil_bmm_out(
-    at::Tensor &result,
-    const at::Tensor& batch1,
-    const at::Tensor& batch2) {
+at::Tensor& AtenIpexCPUDev::dil_bmm_out(at::Tensor &result, const at::Tensor& batch1, const at::Tensor& batch2) {
   DEBUG("AtenIpexCPUDev::dil_bmm_out\n");
   CHECK_DNNL_OP_PRE_COND(batch1);
   CHECK_DNNL_OP_PRE_COND(batch2);
 
   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(batch1.dim() == 3 && batch2.dim() == 3);
-  at::IntArrayRef inferred_size{batch1.size(0), batch1.size(1), batch2.size(2)};
-  if (!result.sizes().equals(inferred_size)) {
-    result.resize_(inferred_size);
-  }
+  dil::dims inferred_size{batch1.size(0), batch1.size(1), batch2.size(2)};
 
-  dbl::comm::reorder_to_bf16_for_mix_prec(result);
   dbl::comm::reorder_to_bf16_for_mix_prec(batch1);
   dbl::comm::reorder_to_bf16_for_mix_prec(batch2);
 
   auto x = dbl::comm::try_gen_dil_tensor(batch1);
   auto w = dbl::comm::try_gen_dil_tensor(batch2);
-  auto y = dbl::comm::try_gen_dil_tensor(result);
+  dil::tensor y;
   matmul_common(x, w, dil::tensor(), y);
 
-  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(y.is_public_format() || check_tensor_own_whole_storage(result));
-  dbl::comm::sync_shape_from_dil_to_aten(result, y);
+  dbl::comm::equip_dil_buffer(result, y);
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(result.sizes().equals(inferred_size));
   return result;
 }
 
-at::Tensor AtenIpexCPUDev::dil_mm(
-    const at::Tensor& self,
-    const at::Tensor& mat2) {
+at::Tensor AtenIpexCPUDev::dil_mm(const at::Tensor& self, const at::Tensor& mat2) {
   DEBUG("AtenIpexCPUDev::dil_mm\n");
 
-  at::Tensor result = at::empty({0}, self.options());
+  auto result = dbl::comm::empty_dil_tensor({0}, self.options());
   return dil_mm_out(result, self, mat2);
 }
 
-at::Tensor& AtenIpexCPUDev::dil_mm_out(
-    at::Tensor& result,
-    const at::Tensor& self,
-    const at::Tensor& mat2) {
+at::Tensor& AtenIpexCPUDev::dil_mm_out(at::Tensor& result, const at::Tensor& self, const at::Tensor& mat2) {
   DEBUG("AtenIpexCPUDev::dil_mm_out\n");
 
   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(self.dim() == 2 && mat2.dim() == 2);
-  at::IntArrayRef inferred_size{self.size(0), mat2.size(1)};
-  if (!result.sizes().equals(inferred_size)) {
-    result.resize_(inferred_size);
-  }
+  dil::dims inferred_size{self.size(0), mat2.size(1)};
 
-  dbl::comm::reorder_to_bf16_for_mix_prec(result);
   dbl::comm::reorder_to_bf16_for_mix_prec(self);
   dbl::comm::reorder_to_bf16_for_mix_prec(mat2);
 
   auto x = dbl::comm::try_gen_dil_tensor(self);
   auto w = dbl::comm::try_gen_dil_tensor(mat2);
-  auto y = dbl::comm::try_gen_dil_tensor(result);
+  dil::tensor y;
   matmul_common(x, w, dil::tensor(), y);
 
-  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(y.is_public_format() || check_tensor_own_whole_storage(result));
-  dbl::comm::sync_shape_from_dil_to_aten(result, y);
+  dbl::comm::equip_dil_buffer(result, y);
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(result.sizes().equals(inferred_size));
   return result;
 }
 
-at::Tensor& AtenIpexCPUDev::dil_baddbmm_out(
+template<bool inplace>
+at::Tensor& dil_baddbmm_common(
     at::Tensor &result,
     const at::Tensor& self,
     const at::Tensor& batch1,
     const at::Tensor& batch2,
     at::Scalar beta,
     at::Scalar alpha) {
-  DEBUG("AtenIpexCPUDev::dil_baddbmm_out\n");
   CHECK_DNNL_OP_PRE_COND(self);
   CHECK_DNNL_OP_PRE_COND(batch1);
   CHECK_DNNL_OP_PRE_COND(batch2);
 
   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(batch1.dim() == 3 && batch2.dim() == 3);
-  at::IntArrayRef inferred_size{batch1.size(0), batch1.size(1), batch2.size(2)};
-  if (!result.sizes().equals(inferred_size)) {
-    result.resize_(inferred_size);
-  }
+  dil::dims inferred_size{batch1.size(0), batch1.size(1), batch2.size(2)};
   TORCH_CHECK(self.sizes().equals(inferred_size),
-      "dil_baddbmm not support broadcast yet");
+      "dil baddbmm not support broadcast yet");
 
-  dbl::comm::reorder_to_bf16_for_mix_prec(result);
   dbl::comm::reorder_to_bf16_for_mix_prec(self);
   dbl::comm::reorder_to_bf16_for_mix_prec(batch1);
   dbl::comm::reorder_to_bf16_for_mix_prec(batch2);
@@ -478,60 +456,59 @@ at::Tensor& AtenIpexCPUDev::dil_baddbmm_out(
       bias.reshape(bias_dims);
     }
   }
-  auto y = dbl::comm::try_gen_dil_tensor(result);
+  auto y = inplace ? dbl::comm::try_gen_dil_tensor(self) : dil::tensor();
   auto attr_ = dil::attr_t::fuse_sum();
   matmul_common(x, w, bias, y, beta, alpha, attr_);
 
-  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(y.is_public_format() || check_tensor_own_whole_storage(result));
-  dbl::comm::sync_shape_from_dil_to_aten(result, y);
+  if (!inplace) {
+    dbl::comm::equip_dil_buffer(result, y);
+  }
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(result.sizes().equals(inferred_size));
   return result;
 }
 
-at::Tensor AtenIpexCPUDev::dil_baddbmm(
+at::Tensor& AtenIpexCPUDev::dil_baddbmm_out(
+    at::Tensor &result,
     const at::Tensor& self,
     const at::Tensor& batch1,
-    const at::Tensor & batch2,
+    const at::Tensor& batch2,
     at::Scalar beta,
     at::Scalar alpha) {
+  DEBUG("AtenIpexCPUDev::dil_baddbmm_out\n");
+
+  return dil_baddbmm_common</*inplace=*/false>(result, self, batch1, batch2, beta, alpha);
+}
+
+at::Tensor AtenIpexCPUDev::dil_baddbmm(const at::Tensor& self, const at::Tensor& batch1, const at::Tensor & batch2, at::Scalar beta, at::Scalar alpha) {
   DEBUG("AtenIpexCPUDev::dil_baddbmm\n");
 
-  at::Tensor result = at::empty({0}, self.options());
-  return dil_baddbmm_out(result, self, batch1, batch2, beta, alpha);
+  auto result = dbl::comm::empty_dil_tensor({0}, self.options());
+  return dil_baddbmm_common</*inplace=*/false>(result, self, batch1, batch2, beta, alpha);
 }
 
-at::Tensor& AtenIpexCPUDev::dil_baddbmm_(
-    at::Tensor& self,
-    const at::Tensor& batch1,
-    const at::Tensor& batch2,
-    at::Scalar beta,
-    at::Scalar alpha) {
+at::Tensor& AtenIpexCPUDev::dil_baddbmm_(at::Tensor& self, const at::Tensor& batch1, const at::Tensor& batch2, at::Scalar beta, at::Scalar alpha) {
   DEBUG("AtenIpexCPUDev::dil_baddbmm_\n");
 
-  at::Tensor result = at::empty({0}, self.options());
-  return dil_baddbmm_out(self, result, batch1, batch2, beta, alpha);
+  return dil_baddbmm_out(self, self, batch1, batch2, beta, alpha);
 }
 
-at::Tensor& AtenIpexCPUDev::dil_addmm_out(
+template<bool inplace>
+at::Tensor& dil_addmm_common(
     at::Tensor& result,
     const at::Tensor& self,
     const at::Tensor& mat1,
     const at::Tensor& mat2,
     at::Scalar beta,
     at::Scalar alpha) {
-  DEBUG("AtenIpexCPUDev::dil_addmm_out\n");
   CHECK_DNNL_OP_PRE_COND(self);
   CHECK_DNNL_OP_PRE_COND(mat1);
   CHECK_DNNL_OP_PRE_COND(mat2);
 
   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(mat1.dim() == 2 && mat2.dim() == 2);
-  at::IntArrayRef inferred_size{mat1.size(0), mat2.size(1)};
-  if (!result.sizes().equals(inferred_size)) {
-    result.resize_(inferred_size);
-  }
+  dil::dims inferred_size{mat1.size(0), mat2.size(1)};
   TORCH_CHECK(self.sizes().equals(inferred_size),
-      "dil_addmm not support broadcast yet");
+      "dil addmm not support broadcast yet");
 
-  dbl::comm::reorder_to_bf16_for_mix_prec(result);
   dbl::comm::reorder_to_bf16_for_mix_prec(self);
   dbl::comm::reorder_to_bf16_for_mix_prec(mat1);
   dbl::comm::reorder_to_bf16_for_mix_prec(mat2);
@@ -547,60 +524,53 @@ at::Tensor& AtenIpexCPUDev::dil_addmm_out(
       bias.reshape(bias_dims);
     }
   }
-  auto y = dbl::comm::try_gen_dil_tensor(result);
+  auto y = inplace ? dbl::comm::try_gen_dil_tensor(self) : dil::tensor();
   auto attr_ = dil::attr_t::fuse_sum();
   matmul_common(x, w, bias, y, beta, alpha, attr_);
 
-  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(y.is_public_format() || check_tensor_own_whole_storage(result));
-  dbl::comm::sync_shape_from_dil_to_aten(result, y);
+  if (!inplace) {
+    dbl::comm::equip_dil_buffer(result, y);
+  }
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(result.sizes().equals(inferred_size));
   return result;
 }
 
-at::Tensor AtenIpexCPUDev::dil_addmm(
-    const at::Tensor& self,
-    const at::Tensor& batch1,
-    const at::Tensor & batch2,
-    at::Scalar beta,
-    at::Scalar alpha) {
+at::Tensor& AtenIpexCPUDev::dil_addmm_out(at::Tensor& result, const at::Tensor& self, const at::Tensor& mat1, const at::Tensor& mat2, at::Scalar beta, at::Scalar alpha) {
+  DEBUG("AtenIpexCPUDev::dil_addmm_out\n");
+
+  return dil_addmm_common</*inplace=*/false>(result, self, mat1, mat2, beta, alpha);
+}
+
+at::Tensor AtenIpexCPUDev::dil_addmm(const at::Tensor& self, const at::Tensor& mat1, const at::Tensor & mat2, at::Scalar beta, at::Scalar alpha) {
   DEBUG("AtenIpexCPUDev::dil_addmm\n");
 
-  at::Tensor result = at::empty({0}, self.options());
-  return dil_addmm_out(result, self, batch1, batch2, beta, alpha);
+  auto result = dbl::comm::empty_dil_tensor({0}, self.options());
+  return dil_addmm_common</*inplace=*/false>(result, self, mat1, mat2, beta, alpha);
 }
 
-at::Tensor& AtenIpexCPUDev::dil_addmm_(
-    at::Tensor& self,
-    const at::Tensor& batch1,
-    const at::Tensor & batch2,
-    at::Scalar beta,
-    at::Scalar alpha) {
+at::Tensor& AtenIpexCPUDev::dil_addmm_(at::Tensor& self, const at::Tensor& mat1, const at::Tensor & mat2, at::Scalar beta, at::Scalar alpha) {
   DEBUG("AtenIpexCPUDev::dil_addmm_\n");
 
-  at::Tensor result = at::empty({0}, self.options());
-  return dil_addmm_out(self, result, batch1, batch2, beta, alpha);
+  return dil_addmm_common</*inplace=*/false>(self, self, mat1, mat2, beta, alpha);
 }
 
-at::Tensor& AtenIpexCPUDev::dil_addbmm_out(
+template<bool inplace>
+at::Tensor& dil_addbmm_common(
    at::Tensor& result,
     const at::Tensor &self,
     const at::Tensor &batch1,
     const at::Tensor &batch2,
     at::Scalar beta,
     at::Scalar alpha) {
-  DEBUG("AtenIpexCPUDev::dil_addbmm_out\n");
   CHECK_DNNL_OP_PRE_COND(self);
   CHECK_DNNL_OP_PRE_COND(batch1);
   CHECK_DNNL_OP_PRE_COND(batch2);
 
   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(batch1.dim() == 3 && batch2.dim() == 3);
-  at::IntArrayRef inferred_size{batch1.size(1), batch2.size(2)};
-  if (!result.sizes().equals(inferred_size)) {
-    result.resize_(inferred_size);
-  }
+  dil::dims inferred_size{batch1.size(1), batch2.size(2)};
   TORCH_CHECK(self.sizes().equals(inferred_size),
-      "dil_addbmm not support broadcast yet");
+      "dil addbmm not support broadcast yet");
 
-  dbl::comm::reorder_to_bf16_for_mix_prec(result);
   dbl::comm::reorder_to_bf16_for_mix_prec(self);
   dbl::comm::reorder_to_bf16_for_mix_prec(batch1);
   dbl::comm::reorder_to_bf16_for_mix_prec(batch2);
@@ -616,11 +586,9 @@ at::Tensor& AtenIpexCPUDev::dil_addbmm_out(
   if (x.get_dim(0) > 1) {
     x_ = x.transpose(0, 1);
   }
-  dil::dims x_dims = {x.get_dim(1), x.get_dim(0) * x.get_dim(2)};
-  x_ = x_.reshape(x_dims);
-  dil::dims w_dims = {w.get_dim(0) * w.get_dim(1), w.get_dim(2)};
-  auto w_ = w.reshape(w_dims);
-  auto y = dbl::comm::try_gen_dil_tensor(result);
+  x_ = x_.reshape({x.get_dim(1), x.get_dim(0) * x.get_dim(2)});
+  auto w_ = w.reshape({w.get_dim(0) * w.get_dim(1), w.get_dim(2)});
+  auto y = inplace ? dbl::comm::try_gen_dil_tensor(self) : dil::tensor();
   auto attr_ = dil::attr_t::fuse_sum();
 
   dil::tensor bias;
@@ -634,33 +602,30 @@ at::Tensor& AtenIpexCPUDev::dil_addbmm_out(
   }
   matmul_common(x_, w_, bias, y, beta, alpha, attr_);
 
-  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(y.is_public_format() || check_tensor_own_whole_storage(result));
-  dbl::comm::sync_shape_from_dil_to_aten(result, y);
+  if (!inplace) {
+    dbl::comm::equip_dil_buffer(result, y);
+  }
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(result.sizes().equals(inferred_size));
   return result;
 }
 
-at::Tensor AtenIpexCPUDev::dil_addbmm(
-    const at::Tensor &self,
-    const at::Tensor &batch1,
-    const at::Tensor &batch2,
-    at::Scalar beta,
-    at::Scalar alpha) {
+at::Tensor& AtenIpexCPUDev::dil_addbmm_out(at::Tensor& result, const at::Tensor &self, const at::Tensor &batch1, const at::Tensor &batch2, at::Scalar beta, at::Scalar alpha) {
+  DEBUG("AtenIpexCPUDev::dil_addbmm_out\n");
+
+  return dil_addbmm_common</*inplace=*/false>(result, self, batch1, batch2, beta, alpha);
+}
+
+at::Tensor AtenIpexCPUDev::dil_addbmm(const at::Tensor &self, const at::Tensor &batch1, const at::Tensor &batch2, at::Scalar beta, at::Scalar alpha) {
   DEBUG("AtenIpexCPUDev::dil_addbmm\n");
 
-  at::Tensor result = at::empty({0}, self.options());
-  return dil_addbmm_out(result, self, batch1, batch2, beta, alpha);
+  auto result = dbl::comm::empty_dil_tensor({0}, self.options());
+  return dil_addbmm_common</*inplace=*/false>(result, self, batch1, batch2, beta, alpha);
 }
 
-at::Tensor& AtenIpexCPUDev::dil_addbmm_(
-    at::Tensor& self,
-    const at::Tensor& batch1,
-    const at::Tensor& batch2,
-    at::Scalar beta,
-    at::Scalar alpha) {
+at::Tensor& AtenIpexCPUDev::dil_addbmm_(at::Tensor& self, const at::Tensor& batch1, const at::Tensor& batch2, at::Scalar beta, at::Scalar alpha) {
   DEBUG("AtenIpexCPUDev::dil_addbmm_\n");
 
-  at::Tensor result = at::empty({0}, self.options());
-  return dil_addbmm_out(self, result, batch1, batch2, beta, alpha);
+  return dil_addbmm_common</*inplace=*/true>(self, self, batch1, batch2, beta, alpha);
 }
 
 at::Tensor AtenIpexCPUDev::dil_linear(