Commit 6b835d2

fix *mm ops and resizing behavior
1 parent e24f42d commit 6b835d2

File tree: 2 files changed, 90 additions & 113 deletions
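A note on the "resizing behavior" in the commit title: PyTorch's *mm ops are expected to resize an out= tensor to the inferred result shape rather than write through a stale buffer, and the DNNL paths below now follow the same contract. A minimal sketch of the stock PyTorch semantics (hypothetical shapes, plain CPU tensors):

    import torch

    mat1 = torch.randn(4, 3)
    mat2 = torch.randn(3, 5)

    # An out= tensor of the wrong shape is resized to the inferred
    # (4, 5) result shape instead of being written through as-is.
    out = torch.empty(0)
    torch.mm(mat1, mat2, out=out)
    assert out.shape == (4, 5)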

tests/cpu/test_lazy_reorder.py

Lines changed: 12 additions & 0 deletions

@@ -467,6 +467,11 @@ def test_addmm(self):
         torch.addmm(input=res_dpcpp, mat1=b1_dpcpp, mat2=b2_dpcpp, alpha=alpha, beta=beta, out=y_dpcpp)
         self.assertEqual(y_cpu, y_dpcpp)
 
+        res_cpu.addmm_(mat1=b1_cpu, mat2=b2_cpu, alpha=alpha, beta=beta)
+        res_dpcpp.addmm_(mat1=b1_cpu, mat2=b2_cpu, alpha=alpha, beta=beta)
+        self.assertEqual(res_cpu, res_dpcpp)
+
+
     def test_addbmm(self):
         ipex.enable_auto_dnnl()
         rand_seed = int(get_rand_seed())
@@ -494,6 +499,10 @@ def test_addbmm(self):
         torch.addbmm(res_dpcpp, b1_dpcpp, b2_dpcpp, beta=beta, alpha=alpha, out=y_dpcpp)
         self.assertEqual(y_cpu, y_dpcpp, 1e-4)
 
+        res_cpu.addbmm_(b1_cpu, b2_cpu, beta=beta, alpha=alpha)
+        res_dpcpp.addbmm_(b1_dpcpp, b2_dpcpp, beta=beta, alpha=alpha)
+        self.assertEqual(res_cpu, res_dpcpp, 1e-4)
+
     def test_baddbmm(self):
         ipex.enable_auto_dnnl()
         rand_seed = int(get_rand_seed())
@@ -520,6 +529,9 @@ def test_baddbmm(self):
         torch.baddbmm(res_cpu, b1_cpu, b2_cpu, alpha=alpha, beta=beta, out=y_cpu),
         torch.baddbmm(res_dpcpp, b1_dpcpp, b2_dpcpp, alpha=alpha, beta=beta, out=y_dpcpp),
         self.assertEqual(y_cpu, y_dpcpp)
+        res_cpu.baddbmm_(b1_cpu, b2_cpu, alpha=alpha, beta=beta)
+        res_dpcpp.baddbmm_(b1_cpu, b2_cpu, alpha=alpha, beta=beta)
+        self.assertEqual(res_cpu, res_dpcpp)
 
 class TestLinear(TestCase):
     def test_linear(self):
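The new assertions exercise the in-place variants, which previously went through the broken resizing path. As a standalone reminder of the semantics under test (a minimal sketch with made-up shapes, independent of the ipex fixtures; addmm_, addbmm_, and baddbmm_ follow the standard PyTorch definitions):

    import torch

    b, n, m, p = 2, 3, 4, 5
    res = torch.randn(n, p)
    mat1 = torch.randn(n, m)
    mat2 = torch.randn(m, p)

    # addmm_: res <- beta * res + alpha * (mat1 @ mat2), in place.
    expected = 0.5 * res + 2.0 * (mat1 @ mat2)
    res.addmm_(mat1=mat1, mat2=mat2, beta=0.5, alpha=2.0)
    assert torch.allclose(res, expected, atol=1e-6)

    # addbmm_ folds the batch dimension into a 2D result:
    # res2 <- beta * res2 + alpha * sum_i(b1[i] @ b2[i]).
    res2 = torch.randn(n, p)
    b1 = torch.randn(b, n, m)
    b2 = torch.randn(b, m, p)
    expected2 = res2 + (b1 @ b2).sum(dim=0)
    res2.addbmm_(b1, b2)
    assert torch.allclose(res2, expected2, atol=1e-6)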

torch_ipex/csrc/cpu/DevOPs.cpp

Lines changed: 78 additions & 113 deletions

@@ -370,99 +370,77 @@ void matmul_common(
       dil::scale_t(), dil::scale_t(), dil::scale_t(), attr);
 }
 
-at::Tensor AtenIpexCPUDev::dil_bmm(
-    const at::Tensor& self,
-    const at::Tensor& mat2) {
+at::Tensor AtenIpexCPUDev::dil_bmm(const at::Tensor& self, const at::Tensor& mat2) {
   DEBUG("AtenIpexCPUDev::dil_bmm\n");
 
-  at::Tensor result = at::empty({0}, self.options());
+  auto result = dbl::comm::empty_dil_tensor({0}, self.options());
   return dil_bmm_out(result, self, mat2);
 }
 
-at::Tensor& AtenIpexCPUDev::dil_bmm_out(
-    at::Tensor &result,
-    const at::Tensor& batch1,
-    const at::Tensor& batch2) {
+at::Tensor& AtenIpexCPUDev::dil_bmm_out(at::Tensor &result, const at::Tensor& batch1, const at::Tensor& batch2) {
   DEBUG("AtenIpexCPUDev::dil_bmm_out\n");
   CHECK_DNNL_OP_PRE_COND(batch1);
   CHECK_DNNL_OP_PRE_COND(batch2);
 
   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(batch1.dim() == 3 && batch2.dim() == 3);
-  at::IntArrayRef inferred_size{batch1.size(0), batch1.size(1), batch2.size(2)};
-  if (!result.sizes().equals(inferred_size)) {
-    result.resize_(inferred_size);
-  }
+  dil::dims inferred_size{batch1.size(0), batch1.size(1), batch2.size(2)};
 
-  dbl::comm::reorder_to_bf16_for_mix_prec(result);
   dbl::comm::reorder_to_bf16_for_mix_prec(batch1);
   dbl::comm::reorder_to_bf16_for_mix_prec(batch2);
 
   auto x = dbl::comm::try_gen_dil_tensor(batch1);
   auto w = dbl::comm::try_gen_dil_tensor(batch2);
-  auto y = dbl::comm::try_gen_dil_tensor(result);
+  dil::tensor y;
   matmul_common(x, w, dil::tensor(), y);
 
-  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(y.is_public_format() || check_tensor_own_whole_storage(result));
-  dbl::comm::sync_shape_from_dil_to_aten(result, y);
+  dbl::comm::equip_dil_buffer(result, y);
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(result.sizes().equals(inferred_size));
   return result;
 }
 
-at::Tensor AtenIpexCPUDev::dil_mm(
-    const at::Tensor& self,
-    const at::Tensor& mat2) {
+at::Tensor AtenIpexCPUDev::dil_mm(const at::Tensor& self, const at::Tensor& mat2) {
   DEBUG("AtenIpexCPUDev::dil_mm\n");
 
-  at::Tensor result = at::empty({0}, self.options());
+  auto result = dbl::comm::empty_dil_tensor({0}, self.options());
   return dil_mm_out(result, self, mat2);
 }
 
-at::Tensor& AtenIpexCPUDev::dil_mm_out(
-    at::Tensor& result,
-    const at::Tensor& self,
-    const at::Tensor& mat2) {
+at::Tensor& AtenIpexCPUDev::dil_mm_out(at::Tensor& result, const at::Tensor& self, const at::Tensor& mat2) {
   DEBUG("AtenIpexCPUDev::dil_mm_out\n");
 
   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(self.dim() == 2 && mat2.dim() == 2);
-  at::IntArrayRef inferred_size{self.size(0), mat2.size(1)};
-  if (!result.sizes().equals(inferred_size)) {
-    result.resize_(inferred_size);
-  }
+  dil::dims inferred_size{self.size(0), mat2.size(1)};
 
-  dbl::comm::reorder_to_bf16_for_mix_prec(result);
   dbl::comm::reorder_to_bf16_for_mix_prec(self);
   dbl::comm::reorder_to_bf16_for_mix_prec(mat2);
 
   auto x = dbl::comm::try_gen_dil_tensor(self);
   auto w = dbl::comm::try_gen_dil_tensor(mat2);
-  auto y = dbl::comm::try_gen_dil_tensor(result);
+  dil::tensor y;
   matmul_common(x, w, dil::tensor(), y);
 
-  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(y.is_public_format() || check_tensor_own_whole_storage(result));
-  dbl::comm::sync_shape_from_dil_to_aten(result, y);
+  dbl::comm::equip_dil_buffer(result, y);
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(result.sizes().equals(inferred_size));
   return result;
 }
 
-at::Tensor& AtenIpexCPUDev::dil_baddbmm_out(
+template <bool inplace>
+at::Tensor& dil_baddbmm_common(
     at::Tensor &result,
     const at::Tensor& self,
     const at::Tensor& batch1,
     const at::Tensor& batch2,
     at::Scalar beta,
     at::Scalar alpha) {
-  DEBUG("AtenIpexCPUDev::dil_baddbmm_out\n");
   CHECK_DNNL_OP_PRE_COND(self);
   CHECK_DNNL_OP_PRE_COND(batch1);
   CHECK_DNNL_OP_PRE_COND(batch2);
 
   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(batch1.dim() == 3 && batch2.dim() == 3);
-  at::IntArrayRef inferred_size{batch1.size(0), batch1.size(1), batch2.size(2)};
-  if (!result.sizes().equals(inferred_size)) {
-    result.resize_(inferred_size);
-  }
+  dil::dims inferred_size{batch1.size(0), batch1.size(1), batch2.size(2)};
   TORCH_CHECK(self.sizes().equals(inferred_size),
-      "dil_baddbmm not support broadcast yet");
+      "dil baddbmm not support broadcast yet");
 
-  dbl::comm::reorder_to_bf16_for_mix_prec(result);
   dbl::comm::reorder_to_bf16_for_mix_prec(self);
   dbl::comm::reorder_to_bf16_for_mix_prec(batch1);
   dbl::comm::reorder_to_bf16_for_mix_prec(batch2);
@@ -478,60 +456,59 @@ at::Tensor& AtenIpexCPUDev::dil_baddbmm_out(
       bias.reshape(bias_dims);
     }
   }
-  auto y = dbl::comm::try_gen_dil_tensor(result);
+  auto y = inplace ? dbl::comm::try_gen_dil_tensor(self) : dil::tensor();
   auto attr_ = dil::attr_t::fuse_sum();
   matmul_common(x, w, bias, y, beta, alpha, attr_);
 
-  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(y.is_public_format() || check_tensor_own_whole_storage(result));
-  dbl::comm::sync_shape_from_dil_to_aten(result, y);
+  if (!inplace) {
+    dbl::comm::equip_dil_buffer(result, y);
+  }
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(result.sizes().equals(inferred_size));
   return result;
 }
 
-at::Tensor AtenIpexCPUDev::dil_baddbmm(
+at::Tensor& AtenIpexCPUDev::dil_baddbmm_out(
+    at::Tensor &result,
     const at::Tensor& self,
     const at::Tensor& batch1,
-    const at::Tensor & batch2,
+    const at::Tensor& batch2,
     at::Scalar beta,
     at::Scalar alpha) {
+  DEBUG("AtenIpexCPUDev::dil_baddbmm_out\n");
+
+  return dil_baddbmm_common</*inplace=*/false>(result, self, batch1, batch2, beta, alpha);
+}
+
+at::Tensor AtenIpexCPUDev::dil_baddbmm(const at::Tensor& self, const at::Tensor& batch1, const at::Tensor & batch2, at::Scalar beta, at::Scalar alpha) {
   DEBUG("AtenIpexCPUDev::dil_baddbmm\n");
 
-  at::Tensor result = at::empty({0}, self.options());
-  return dil_baddbmm_out(result, self, batch1, batch2, beta, alpha);
+  auto result = dbl::comm::empty_dil_tensor({0}, self.options());
+  return dil_baddbmm_common</*inplace=*/false>(result, self, batch1, batch2, beta, alpha);
 }
 
-at::Tensor& AtenIpexCPUDev::dil_baddbmm_(
-    at::Tensor& self,
-    const at::Tensor& batch1,
-    const at::Tensor& batch2,
-    at::Scalar beta,
-    at::Scalar alpha) {
+at::Tensor& AtenIpexCPUDev::dil_baddbmm_(at::Tensor& self, const at::Tensor& batch1, const at::Tensor& batch2, at::Scalar beta, at::Scalar alpha) {
   DEBUG("AtenIpexCPUDev::dil_baddbmm_\n");
 
-  at::Tensor result = at::empty({0}, self.options());
-  return dil_baddbmm_out(self, result, batch1, batch2, beta, alpha);
+  return dil_baddbmm_out(self, self, batch1, batch2, beta, alpha);
 }
 
-at::Tensor& AtenIpexCPUDev::dil_addmm_out(
+template<bool inplace>
+at::Tensor& dil_addmm_common(
     at::Tensor& result,
     const at::Tensor& self,
     const at::Tensor& mat1,
     const at::Tensor& mat2,
     at::Scalar beta,
     at::Scalar alpha) {
-  DEBUG("AtenIpexCPUDev::dil_addmm_out\n");
   CHECK_DNNL_OP_PRE_COND(self);
   CHECK_DNNL_OP_PRE_COND(mat1);
   CHECK_DNNL_OP_PRE_COND(mat2);
 
   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(mat1.dim() == 2 && mat2.dim() == 2);
-  at::IntArrayRef inferred_size{mat1.size(0), mat2.size(1)};
-  if (!result.sizes().equals(inferred_size)) {
-    result.resize_(inferred_size);
-  }
+  dil::dims inferred_size{mat1.size(0), mat2.size(1)};
   TORCH_CHECK(self.sizes().equals(inferred_size),
-      "dil_addmm not support broadcast yet");
+      "dil addmm not support broadcast yet");
 
-  dbl::comm::reorder_to_bf16_for_mix_prec(result);
   dbl::comm::reorder_to_bf16_for_mix_prec(self);
   dbl::comm::reorder_to_bf16_for_mix_prec(mat1);
   dbl::comm::reorder_to_bf16_for_mix_prec(mat2);
@@ -547,60 +524,53 @@ at::Tensor& AtenIpexCPUDev::dil_addmm_out(
       bias.reshape(bias_dims);
     }
   }
-  auto y = dbl::comm::try_gen_dil_tensor(result);
+  auto y = inplace ? dbl::comm::try_gen_dil_tensor(self) : dil::tensor();
   auto attr_ = dil::attr_t::fuse_sum();
   matmul_common(x, w, bias, y, beta, alpha, attr_);
 
-  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(y.is_public_format() || check_tensor_own_whole_storage(result));
-  dbl::comm::sync_shape_from_dil_to_aten(result, y);
+  if (!inplace) {
+    dbl::comm::equip_dil_buffer(result, y);
+  }
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(result.sizes().equals(inferred_size));
   return result;
 }
 
-at::Tensor AtenIpexCPUDev::dil_addmm(
-    const at::Tensor& self,
-    const at::Tensor& batch1,
-    const at::Tensor & batch2,
-    at::Scalar beta,
-    at::Scalar alpha) {
+at::Tensor& AtenIpexCPUDev::dil_addmm_out(at::Tensor& result, const at::Tensor& self, const at::Tensor& mat1, const at::Tensor& mat2, at::Scalar beta, at::Scalar alpha) {
+  DEBUG("AtenIpexCPUDev::dil_addmm_out\n");
+
+  return dil_addmm_common</*inplace=*/false>(result, self, mat1, mat2, beta, alpha);
+}
+
+at::Tensor AtenIpexCPUDev::dil_addmm(const at::Tensor& self, const at::Tensor& mat1, const at::Tensor & mat2, at::Scalar beta, at::Scalar alpha) {
   DEBUG("AtenIpexCPUDev::dil_addmm\n");
 
-  at::Tensor result = at::empty({0}, self.options());
-  return dil_addmm_out(result, self, batch1, batch2, beta, alpha);
+  auto result = dbl::comm::empty_dil_tensor({0}, self.options());
+  return dil_addmm_common</*inplace=*/false>(result, self, mat1, mat2, beta, alpha);
 }
 
-at::Tensor& AtenIpexCPUDev::dil_addmm_(
-    at::Tensor& self,
-    const at::Tensor& batch1,
-    const at::Tensor & batch2,
-    at::Scalar beta,
-    at::Scalar alpha) {
+at::Tensor& AtenIpexCPUDev::dil_addmm_(at::Tensor& self, const at::Tensor& mat1, const at::Tensor & mat2, at::Scalar beta, at::Scalar alpha) {
   DEBUG("AtenIpexCPUDev::dil_addmm_\n");
 
-  at::Tensor result = at::empty({0}, self.options());
-  return dil_addmm_out(self, result, batch1, batch2, beta, alpha);
+  return dil_addmm_common</*inplace=*/false>(self, self, mat1, mat2, beta, alpha);
 }
 
-at::Tensor& AtenIpexCPUDev::dil_addbmm_out(
+template<bool inplace>
+at::Tensor& dil_addbmm_common(
    at::Tensor& result,
     const at::Tensor &self,
     const at::Tensor &batch1,
     const at::Tensor &batch2,
     at::Scalar beta,
     at::Scalar alpha) {
-  DEBUG("AtenIpexCPUDev::dil_addbmm_out\n");
   CHECK_DNNL_OP_PRE_COND(self);
   CHECK_DNNL_OP_PRE_COND(batch1);
   CHECK_DNNL_OP_PRE_COND(batch2);
 
   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(batch1.dim() == 3 && batch2.dim() == 3);
-  at::IntArrayRef inferred_size{batch1.size(1), batch2.size(2)};
-  if (!result.sizes().equals(inferred_size)) {
-    result.resize_(inferred_size);
-  }
+  dil::dims inferred_size{batch1.size(1), batch2.size(2)};
   TORCH_CHECK(self.sizes().equals(inferred_size),
-      "dil_addbmm not support broadcast yet");
+      "dil addbmm not support broadcast yet");
 
-  dbl::comm::reorder_to_bf16_for_mix_prec(result);
   dbl::comm::reorder_to_bf16_for_mix_prec(self);
   dbl::comm::reorder_to_bf16_for_mix_prec(batch1);
   dbl::comm::reorder_to_bf16_for_mix_prec(batch2);
@@ -616,11 +586,9 @@ at::Tensor& AtenIpexCPUDev::dil_addbmm_out(
   if (x.get_dim(0) > 1) {
     x_ = x.transpose(0, 1);
   }
-  dil::dims x_dims = {x.get_dim(1), x.get_dim(0) * x.get_dim(2)};
-  x_ = x_.reshape(x_dims);
-  dil::dims w_dims = {w.get_dim(0) * w.get_dim(1), w.get_dim(2)};
-  auto w_ = w.reshape(w_dims);
-  auto y = dbl::comm::try_gen_dil_tensor(result);
+  x_ = x_.reshape({x.get_dim(1), x.get_dim(0) * x.get_dim(2)});
+  auto w_ = w.reshape({w.get_dim(0) * w.get_dim(1), w.get_dim(2)});
+  auto y = inplace ? dbl::comm::try_gen_dil_tensor(self) : dil::tensor();
   auto attr_ = dil::attr_t::fuse_sum();
 
   dil::tensor bias;
@@ -634,33 +602,30 @@ at::Tensor& AtenIpexCPUDev::dil_addbmm_out(
   }
   matmul_common(x_, w_, bias, y, beta, alpha, attr_);
 
-  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(y.is_public_format() || check_tensor_own_whole_storage(result));
-  dbl::comm::sync_shape_from_dil_to_aten(result, y);
+  if (!inplace) {
+    dbl::comm::equip_dil_buffer(result, y);
+  }
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(result.sizes().equals(inferred_size));
   return result;
 }
 
-at::Tensor AtenIpexCPUDev::dil_addbmm(
-    const at::Tensor &self,
-    const at::Tensor &batch1,
-    const at::Tensor &batch2,
-    at::Scalar beta,
-    at::Scalar alpha) {
+at::Tensor& AtenIpexCPUDev::dil_addbmm_out(at::Tensor& result, const at::Tensor &self, const at::Tensor &batch1, const at::Tensor &batch2, at::Scalar beta, at::Scalar alpha) {
+  DEBUG("AtenIpexCPUDev::dil_addbmm_out\n");
+
+  return dil_addbmm_common</*inplace=*/false>(result, self, batch1, batch2, beta, alpha);
+}
+
+at::Tensor AtenIpexCPUDev::dil_addbmm(const at::Tensor &self, const at::Tensor &batch1, const at::Tensor &batch2, at::Scalar beta, at::Scalar alpha) {
   DEBUG("AtenIpexCPUDev::dil_addbmm\n");
 
-  at::Tensor result = at::empty({0}, self.options());
-  return dil_addbmm_out(result, self, batch1, batch2, beta, alpha);
+  auto result = dbl::comm::empty_dil_tensor({0}, self.options());
+  return dil_addbmm_common</*inplace=*/false>(result, self, batch1, batch2, beta, alpha);
 }
 
-at::Tensor& AtenIpexCPUDev::dil_addbmm_(
-    at::Tensor& self,
-    const at::Tensor& batch1,
-    const at::Tensor& batch2,
-    at::Scalar beta,
-    at::Scalar alpha) {
+at::Tensor& AtenIpexCPUDev::dil_addbmm_(at::Tensor& self, const at::Tensor& batch1, const at::Tensor& batch2, at::Scalar beta, at::Scalar alpha) {
   DEBUG("AtenIpexCPUDev::dil_addbmm_\n");
 
-  at::Tensor result = at::empty({0}, self.options());
-  return dil_addbmm_out(self, result, batch1, batch2, beta, alpha);
+  return dil_addbmm_common</*inplace=*/true>(self, self, batch1, batch2, beta, alpha);
 }
 
 at::Tensor AtenIpexCPUDev::dil_linear(
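One detail worth calling out in dil_addbmm_common above: the reshape pair runs addbmm as a single flat GEMM by folding the batch dimension, turning batch1 of shape (b, n, m) into (n, b*m) via a transpose, and batch2 of shape (b, m, p) into (b*m, p); the flat product then equals the batch-summed matmul that addbmm computes. A quick numerical check of that identity (a plain-PyTorch sketch with made-up sizes):

    import torch

    b, n, m, p = 2, 3, 4, 5
    b1 = torch.randn(b, n, m)
    b2 = torch.randn(b, m, p)

    # Fold the batch dim: (b, n, m) -> (n, b, m) -> (n, b*m).
    x = b1.transpose(0, 1).reshape(n, b * m)
    # (b, m, p) -> (b*m, p); the shared index now runs over b*m.
    w = b2.reshape(b * m, p)

    # One flat GEMM equals the sum of per-batch matmuls.
    assert torch.allclose(x @ w, (b1 @ b2).sum(dim=0), atol=1e-6)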
