From a8af68ab89b60b77e20448a4fb1ac6c11fabe459 Mon Sep 17 00:00:00 2001 From: pinzhenx Date: Fri, 19 Jun 2020 16:34:56 +0000 Subject: [PATCH] refine all mm and binary ops --- intel_pytorch_extension_py/ops/jit_script.py | 3 - scripts/cpu/gen-dense-cpu-ops.py | 1 + tests/cpu/test_bf16_lazy_reorder.py | 5 +- tests/cpu/test_lazy_reorder.py | 91 +++++ torch_ipex/csrc/cpu/DevOPs.cpp | 355 +++++++++---------- torch_ipex/csrc/cpu/DevOPs.h | 1 + torch_ipex/csrc/cpu/dbl/Common.cpp | 26 +- 7 files changed, 287 insertions(+), 195 deletions(-) diff --git a/intel_pytorch_extension_py/ops/jit_script.py b/intel_pytorch_extension_py/ops/jit_script.py index 613b7c590..b0c0a00a3 100644 --- a/intel_pytorch_extension_py/ops/jit_script.py +++ b/intel_pytorch_extension_py/ops/jit_script.py @@ -13,10 +13,7 @@ def script_(obj, optimize=None, _frames_up=0, _rcb=None): torch.jit.script = script_ if core.get_jit_opt(): - # bypass buggy broadcastable ops in dnnl during folding - core.disable_auto_dnnl() jit_m = wrap_cpp_module(torch._C._jit_pass_fold_convbn(jit_m._c)) - core.enable_auto_dnnl() return jit_m diff --git a/scripts/cpu/gen-dense-cpu-ops.py b/scripts/cpu/gen-dense-cpu-ops.py index 745a5df61..a0b4b5d1f 100755 --- a/scripts/cpu/gen-dense-cpu-ops.py +++ b/scripts/cpu/gen-dense-cpu-ops.py @@ -63,6 +63,7 @@ 'aten::addbmm.out(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)', 'aten::convolution_overrideable(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups) -> Tensor', 'aten::convolution_backward_overrideable(Tensor grad_output, Tensor input, Tensor weight, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool[3] output_mask) -> (Tensor grad_input, Tensor grad_weight, Tensor grad_bias)', + 'aten::size.int(Tensor self, int dim) -> int', 'aten::clone(Tensor self, *, MemoryFormat? 
memory_format=None) -> Tensor', 'aten::gelu(Tensor self) -> Tensor', 'aten::gelu_backward(Tensor grad, Tensor self) -> Tensor', diff --git a/tests/cpu/test_bf16_lazy_reorder.py b/tests/cpu/test_bf16_lazy_reorder.py index 87228e1a8..c9b65e588 100644 --- a/tests/cpu/test_bf16_lazy_reorder.py +++ b/tests/cpu/test_bf16_lazy_reorder.py @@ -463,7 +463,7 @@ def test_mm_out(self): def test_bmm(self): rand_seed = int(get_rand_seed()) print("{} rand sed: {}".format(sys._getframe().f_code.co_name, rand_seed)) - x_auto_mix_a, x_auto_mix_b, _, x_man_bf16_a, x_man_bf16_b, _ = self._gen_mm_tensor(rand_seed) + x_auto_mix_a, x_auto_mix_b, _, x_man_bf16_a, x_man_bf16_b, _ = self._gen_mm_tensor(rand_seed, batches=16) with AutoDNNL(True), AutoMixPrecision(False): res_man_bf16 = torch.bmm(x_man_bf16_a, x_man_bf16_b) @@ -477,8 +477,7 @@ def test_bmm(self): def test_bmm_out(self): rand_seed = int(get_rand_seed()) print("{} rand sed: {}".format(sys._getframe().f_code.co_name, rand_seed)) - x_auto_mix_a, x_auto_mix_b, res_auto_mix, x_man_bf16_a, x_man_bf16_b, res_man_bf16 = self._gen_mm_tensor(rand_seed) - + x_auto_mix_a, x_auto_mix_b, res_auto_mix, x_man_bf16_a, x_man_bf16_b, res_man_bf16 = self._gen_mm_tensor(rand_seed, batches=16) with AutoDNNL(True), AutoMixPrecision(False): torch.bmm(x_man_bf16_a, x_man_bf16_b, out=res_man_bf16) self.assertEqual(res_man_bf16.dtype, torch.bfloat16) diff --git a/tests/cpu/test_lazy_reorder.py b/tests/cpu/test_lazy_reorder.py index 9039b86b9..8e74c9b35 100644 --- a/tests/cpu/test_lazy_reorder.py +++ b/tests/cpu/test_lazy_reorder.py @@ -33,6 +33,13 @@ def get_rand_seed(): return int(time.time() * 1000000000) device = ipex.DEVICE + +def convert_blocked(t): + assert t.dim() == 4, "only support converting 4d tensor" + c = t.size(1) + t = t.clone().to(device) + return F.conv2d(t, torch.ones(c, 1, 1, 1).to(device), groups=c) + class TestConv(TestCase): def test_Conv2d_with_cpu(self): rand_seed = int(get_rand_seed()) @@ -202,6 +209,78 @@ def test_mul_(self): a2 = self._test_mul_('cpu', rand_seed) self.assertEqual(a2, a1.to('cpu')) + def test_mixed_format(self): + ipex.core.enable_auto_dnnl() + rand_seed = int(get_rand_seed()) + print("{} rand sed: {}".format(sys._getframe().f_code.co_name, rand_seed)) + torch.manual_seed(rand_seed) + + shape = (2, 3, 4, 5) + + for fname in ['add', 'mul']: + + x_cpu = torch.ones(shape) * 5 + y_cpu = torch.ones(shape) * 4 + + # block tensor is a dpcpp tensor + x_plain = x_cpu.clone().to(device) + y_plain = y_cpu.clone().to(device) + x_block = convert_blocked(x_cpu.clone()) + y_block = convert_blocked(y_cpu.clone()) + + fn = getattr(torch, fname) + ref = fn(x_cpu, y_cpu) + + # test add, mul + def test_outplace(a, b): + a = a.clone() + b = b.clone() + self.assertEqual(fn(a, b), ref) + + test_outplace(x_plain, y_plain) + test_outplace(x_plain, y_block) + test_outplace(y_block, x_plain) + test_outplace(x_block, y_block) + + # test add_out, mul_out + def test_out(a, b, o): + a = a.clone() + b = b.clone() + o = o.clone() + y = fn(a, b, out=o) + self.assertEqual(y, ref) + self.assertEqual(o, ref) + + out = torch.ones(shape).to(device) + test_out(x_plain, y_plain, out) + test_out(x_plain, y_block, out) + test_out(y_block, x_plain, out) + test_out(x_block, y_block, out) + out = torch.ones(1).to(device) + test_out(x_plain, y_plain, out) + test_out(x_plain, y_block, out) + test_out(y_block, x_plain, out) + test_out(x_block, y_block, out) + + # test add_, mul_ + def test_inplace(a, b): + a = a.clone() + b = b.clone() + y = getattr(a, fname + '_')(b) + 
self.assertEqual(a, ref) + self.assertEqual(y, ref) + + test_inplace(x_plain, y_plain) + test_inplace(x_plain, y_block) + test_inplace(y_block, x_plain) + test_inplace(x_block, y_block) + + # test broadcast + scalar = torch.ones(1).to(device) + self.assertEqual(fn(x_plain, scalar), fn(x_cpu, scalar)) + self.assertEqual(fn(scalar, x_plain), fn(scalar, x_cpu)) + + class TestRelu(TestCase): def _test_relu_(self, device, rand_seed): torch.manual_seed(rand_seed) @@ -388,6 +467,11 @@ def test_addmm(self): torch.addmm(input=res_dpcpp, mat1=b1_dpcpp, mat2=b2_dpcpp, alpha=alpha, beta=beta, out=y_dpcpp) self.assertEqual(y_cpu, y_dpcpp) + res_cpu.addmm_(mat1=b1_cpu, mat2=b2_cpu, alpha=alpha, beta=beta) + res_dpcpp.addmm_(mat1=b1_cpu, mat2=b2_cpu, alpha=alpha, beta=beta) + self.assertEqual(res_cpu, res_dpcpp) + + def test_addbmm(self): ipex.core.enable_auto_dnnl() rand_seed = int(get_rand_seed()) @@ -415,6 +499,10 @@ def test_addbmm(self): torch.addbmm(res_dpcpp, b1_dpcpp, b2_dpcpp, beta=beta, alpha=alpha, out=y_dpcpp) self.assertEqual(y_cpu, y_dpcpp, 1e-4) + res_cpu.addbmm_(b1_cpu, b2_cpu, beta=beta, alpha=alpha) + res_dpcpp.addbmm_(b1_dpcpp, b2_dpcpp, beta=beta, alpha=alpha) + self.assertEqual(res_cpu, res_dpcpp, 1e-4) + def test_baddbmm(self): ipex.core.enable_auto_dnnl() rand_seed = int(get_rand_seed()) @@ -441,6 +529,9 @@ def test_baddbmm(self): torch.baddbmm(res_cpu, b1_cpu, b2_cpu, alpha=alpha, beta=beta, out=y_cpu), torch.baddbmm(res_dpcpp, b1_dpcpp, b2_dpcpp, alpha=alpha, beta=beta, out=y_dpcpp), self.assertEqual(y_cpu, y_dpcpp) + res_cpu.baddbmm_(b1_cpu, b2_cpu, alpha=alpha, beta=beta) + res_dpcpp.baddbmm_(b1_cpu, b2_cpu, alpha=alpha, beta=beta) + self.assertEqual(res_cpu, res_dpcpp) class TestLinear(TestCase): def test_linear(self): diff --git a/torch_ipex/csrc/cpu/DevOPs.cpp b/torch_ipex/csrc/cpu/DevOPs.cpp index 5d116591e..feea64399 100644 --- a/torch_ipex/csrc/cpu/DevOPs.cpp +++ b/torch_ipex/csrc/cpu/DevOPs.cpp @@ -254,107 +254,95 @@ std::tuple AtenIpexCPUDev::mkldnn_convolution_ return std::tuple(bridge::shallowUpgradeToDPCPPTensor(std::get<0>(_ipex_result)), bridge::shallowUpgradeToDPCPPTensor(std::get<1>(_ipex_result)), bridge::shallowUpgradeToDPCPPTensor(std::get<2>(_ipex_result))); } -at::Tensor& AtenIpexCPUDev::dil_add_out( +template +at::Tensor& dil_add_common( at::Tensor& result, const at::Tensor& self, const at::Tensor& other, at::Scalar alpha) { - DEBUG("AtenIpexCPUDev::dil_add_out\n"); CHECK_DNNL_OP_PRE_COND(self); CHECK_DNNL_OP_PRE_COND(other); + TORCH_CHECK(self.sizes().equals(other.sizes()), + "dil add not support broadcast yet"); + dbl::comm::reorder_to_bf16_for_mix_prec(self); dbl::comm::reorder_to_bf16_for_mix_prec(other); - dbl::comm::reorder_to_bf16_for_mix_prec(result); - dil::tensor x = dbl::comm::try_gen_dil_tensor(self); - dil::tensor y = dbl::comm::try_gen_dil_tensor(other); + auto x = dbl::comm::try_gen_dil_tensor(self); + auto y = dbl::comm::try_gen_dil_tensor(other); + auto z = inplace ? 
x : dil::tensor(); - dil::tensor z = dbl::comm::try_gen_dil_tensor(result); - const std::vector scales{1.0, alpha.to()}; - dil::sum::compute(scales, {x, y}, z); + dil::sum::compute({1.0, alpha.to()}, {x, y}, z); - TORCH_INTERNAL_ASSERT_DEBUG_ONLY(z.is_public_format() || check_tensor_own_whole_storage(result)); - dbl::comm::sync_shape_from_dil_to_aten(result, z); + if (!inplace) { + dbl::comm::equip_dil_buffer(result, z); + } return result; } -at::Tensor AtenIpexCPUDev::dil_add(const at::Tensor& self, const at::Tensor& other, at::Scalar alpha) { - DEBUG("AtenIpexCPUDev::dil_add\n"); - CHECK_DNNL_OP_PRE_COND(self); - CHECK_DNNL_OP_PRE_COND(other); - - dbl::comm::reorder_to_bf16_for_mix_prec(self); - dbl::comm::reorder_to_bf16_for_mix_prec(other); +at::Tensor& AtenIpexCPUDev::dil_add_out(at::Tensor& result, const at::Tensor& self, const at::Tensor& other, at::Scalar alpha) { + DEBUG("AtenIpexCPUDev::dil_add_out\n"); - dil::tensor x = dbl::comm::try_gen_dil_tensor(self); - dil::tensor y = dbl::comm::try_gen_dil_tensor(other); + return dil_add_common(result, self, other, alpha); +} - dil::tensor z; - const std::vector scales{1.0, alpha.to()}; - dil::sum::compute(scales, {x, y}, z); +at::Tensor AtenIpexCPUDev::dil_add(const at::Tensor& self, const at::Tensor& other, at::Scalar alpha) { + DEBUG("AtenIpexCPUDev::dil_add\n"); - return dbl::comm::gen_aten_tensor_by(std::move(z)); + auto result = dbl::comm::empty_dil_tensor({0}, self.options()); + return dil_add_common(result, self, other, alpha); } at::Tensor & AtenIpexCPUDev::dil_add_(at::Tensor& self, const at::Tensor& other, at::Scalar alpha) { DEBUG("AtenIpexCPUDev::dil_add_\n"); + + return dil_add_common(self, self, other, alpha); +} + +template +at::Tensor& dil_mul_common( + at::Tensor& result, + const at::Tensor& self, + const at::Tensor& other) { CHECK_DNNL_OP_PRE_COND(self); CHECK_DNNL_OP_PRE_COND(other); + TORCH_CHECK(self.sizes().equals(other.sizes()), + "dil mul not support broadcast yet"); + dbl::comm::reorder_to_bf16_for_mix_prec(self); dbl::comm::reorder_to_bf16_for_mix_prec(other); - auto dil_self = dbl::comm::try_gen_dil_tensor(self); - auto dil_other = dbl::comm::try_gen_dil_tensor(other); + auto x = dbl::comm::try_gen_dil_tensor(self); + auto y = dbl::comm::try_gen_dil_tensor(other); + auto z = inplace ? 
x : dil::tensor(); - const std::vector scales{1.0, alpha.to()}; - dil::sum::compute(scales, {dil_self, dil_other}, dil_self); + dil::binary::compute(x, y, z, dil::algorithm::binary_mul); - TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dil_self.is_public_format() || check_tensor_own_whole_storage(self)); - dbl::comm::sync_shape_from_dil_to_aten(self, dil_self); - return self; + if (!inplace) { + dbl::comm::equip_dil_buffer(result, z); + } + return result; } at::Tensor& AtenIpexCPUDev::dil_mul_out(at::Tensor& result, const at::Tensor& self, const at::Tensor& other) { DEBUG("AtenIpexCPUDev::dil_mul_out\n"); - CHECK_DNNL_OP_PRE_COND(result); - CHECK_DNNL_OP_PRE_COND(self); - CHECK_DNNL_OP_PRE_COND(other); - - dbl::comm::reorder_to_bf16_for_mix_prec(self); - dbl::comm::reorder_to_bf16_for_mix_prec(other); - dbl::comm::reorder_to_bf16_for_mix_prec(result); - - auto dil_result = dbl::comm::try_gen_dil_tensor(result); - auto dil_self = dbl::comm::try_gen_dil_tensor(self); - auto dil_other = dbl::comm::try_gen_dil_tensor(other); - dil::binary::compute(dil_self, dil_other, dil_result, dil::algorithm::binary_mul); - - TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dil_result.is_public_format() || check_tensor_own_whole_storage(result)); - dbl::comm::sync_shape_from_dil_to_aten(result, dil_result); - return result; + return dil_mul_common(result, self, other); } at::Tensor AtenIpexCPUDev::dil_mul(const at::Tensor& self, const at::Tensor& other) { DEBUG("AtenIpexCPUDev::dil_mul\n"); - dbl::comm::reorder_to_bf16_for_mix_prec(self); - dbl::comm::reorder_to_bf16_for_mix_prec(other); - - at::Tensor result = dbl::comm::empty_dil_tensor(self.sizes(), self.options()); - - return dil_mul_out(result, self, other); + auto result = dbl::comm::empty_dil_tensor({0}, self.options()); + return dil_mul_common(result, self, other); } at::Tensor& AtenIpexCPUDev::dil_mul_(at::Tensor& self, const at::Tensor& other) { DEBUG("AtenIpexCPUDev::dil_mul_\n"); - dbl::comm::reorder_to_bf16_for_mix_prec(self); - dbl::comm::reorder_to_bf16_for_mix_prec(other); - - return dil_mul_out(self, self, other); + return dil_mul_common(self, self, other); } void matmul_common( @@ -382,86 +370,83 @@ void matmul_common( dil::scale_t(), dil::scale_t(), dil::scale_t(), attr); } -at::Tensor AtenIpexCPUDev::dil_bmm( - const at::Tensor& self, - const at::Tensor& mat2) { +at::Tensor AtenIpexCPUDev::dil_bmm(const at::Tensor& self, const at::Tensor& mat2) { DEBUG("AtenIpexCPUDev::dil_bmm\n"); - dbl::comm::reorder_to_bf16_for_mix_prec(self); - dbl::comm::reorder_to_bf16_for_mix_prec(mat2); - - auto self_size = self.sizes(); - std::vector result_size(self_size.begin(), self_size.end()-1); - result_size.push_back(mat2.size(-1)); - at::Tensor result = dbl::comm::empty_dil_tensor(result_size, self.options()); + auto result = dbl::comm::empty_dil_tensor({0}, self.options()); return dil_bmm_out(result, self, mat2); } -at::Tensor& AtenIpexCPUDev::dil_bmm_out( - at::Tensor &result, - const at::Tensor& batch1, - const at::Tensor& batch2) { +at::Tensor& AtenIpexCPUDev::dil_bmm_out(at::Tensor &result, const at::Tensor& batch1, const at::Tensor& batch2) { DEBUG("AtenIpexCPUDev::dil_bmm_out\n"); CHECK_DNNL_OP_PRE_COND(batch1); CHECK_DNNL_OP_PRE_COND(batch2); - dbl::comm::reorder_to_bf16_for_mix_prec(result); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(batch1.dim() == 3 && batch2.dim() == 3); + dil::dims inferred_size{batch1.size(0), batch1.size(1), batch2.size(2)}; + dbl::comm::reorder_to_bf16_for_mix_prec(batch1); dbl::comm::reorder_to_bf16_for_mix_prec(batch2); - const dil::tensor x = 
dbl::comm::try_gen_dil_tensor(batch1); - const dil::tensor w = dbl::comm::try_gen_dil_tensor(batch2); - dil::tensor y = dbl::comm::try_gen_dil_tensor(result); + auto x = dbl::comm::try_gen_dil_tensor(batch1); + auto w = dbl::comm::try_gen_dil_tensor(batch2); + dil::tensor y; matmul_common(x, w, dil::tensor(), y); - TORCH_INTERNAL_ASSERT_DEBUG_ONLY(y.is_public_format() || check_tensor_own_whole_storage(result)); - dbl::comm::sync_shape_from_dil_to_aten(result, y); + dbl::comm::equip_dil_buffer(result, y); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(result.sizes().equals(inferred_size)); return result; } -at::Tensor AtenIpexCPUDev::dil_mm( - const at::Tensor& self, - const at::Tensor& mat2) { +at::Tensor AtenIpexCPUDev::dil_mm(const at::Tensor& self, const at::Tensor& mat2) { DEBUG("AtenIpexCPUDev::dil_mm\n"); - dbl::comm::reorder_to_bf16_for_mix_prec(self); - dbl::comm::reorder_to_bf16_for_mix_prec(mat2); - - return dil_bmm(self, mat2); + auto result = dbl::comm::empty_dil_tensor({0}, self.options()); + return dil_mm_out(result, self, mat2); } -at::Tensor& AtenIpexCPUDev::dil_mm_out( - at::Tensor& result, - const at::Tensor& self, - const at::Tensor& mat2) { +at::Tensor& AtenIpexCPUDev::dil_mm_out(at::Tensor& result, const at::Tensor& self, const at::Tensor& mat2) { DEBUG("AtenIpexCPUDev::dil_mm_out\n"); - dbl::comm::reorder_to_bf16_for_mix_prec(result); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(self.dim() == 2 && mat2.dim() == 2); + dil::dims inferred_size{self.size(0), mat2.size(1)}; + dbl::comm::reorder_to_bf16_for_mix_prec(self); dbl::comm::reorder_to_bf16_for_mix_prec(mat2); - return dil_bmm_out(result, self, mat2); + auto x = dbl::comm::try_gen_dil_tensor(self); + auto w = dbl::comm::try_gen_dil_tensor(mat2); + dil::tensor y; + matmul_common(x, w, dil::tensor(), y); + + dbl::comm::equip_dil_buffer(result, y); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(result.sizes().equals(inferred_size)); + return result; } -at::Tensor& AtenIpexCPUDev::dil_baddbmm_out( +template +at::Tensor& dil_baddbmm_common( at::Tensor &result, const at::Tensor& self, const at::Tensor& batch1, const at::Tensor& batch2, at::Scalar beta, at::Scalar alpha) { - DEBUG("AtenIpexCPUDev::dil_baddbmm_out\n"); CHECK_DNNL_OP_PRE_COND(self); CHECK_DNNL_OP_PRE_COND(batch1); CHECK_DNNL_OP_PRE_COND(batch2); - dbl::comm::reorder_to_bf16_for_mix_prec(result); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(batch1.dim() == 3 && batch2.dim() == 3); + dil::dims inferred_size{batch1.size(0), batch1.size(1), batch2.size(2)}; + TORCH_CHECK(self.sizes().equals(inferred_size), + "dil baddbmm not support broadcast yet"); + dbl::comm::reorder_to_bf16_for_mix_prec(self); dbl::comm::reorder_to_bf16_for_mix_prec(batch1); dbl::comm::reorder_to_bf16_for_mix_prec(batch2); - const dil::tensor x = dbl::comm::try_gen_dil_tensor(batch1); - const dil::tensor w = dbl::comm::try_gen_dil_tensor(batch2); + auto x = dbl::comm::try_gen_dil_tensor(batch1); + auto w = dbl::comm::try_gen_dil_tensor(batch2); dil::tensor bias; if (self.numel() != 0) { bias = dbl::comm::try_gen_dil_tensor(self); @@ -471,110 +456,121 @@ at::Tensor& AtenIpexCPUDev::dil_baddbmm_out( bias.reshape(bias_dims); } } - dil::tensor y = dbl::comm::try_gen_dil_tensor(result); + auto y = inplace ? 
dbl::comm::try_gen_dil_tensor(self) : dil::tensor(); auto attr_ = dil::attr_t::fuse_sum(); matmul_common(x, w, bias, y, beta, alpha, attr_); - TORCH_INTERNAL_ASSERT_DEBUG_ONLY(y.is_public_format() || check_tensor_own_whole_storage(result)); - dbl::comm::sync_shape_from_dil_to_aten(result, y); + if (!inplace) { + dbl::comm::equip_dil_buffer(result, y); + } + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(result.sizes().equals(inferred_size)); return result; } -at::Tensor AtenIpexCPUDev::dil_baddbmm( +at::Tensor& AtenIpexCPUDev::dil_baddbmm_out( + at::Tensor &result, const at::Tensor& self, const at::Tensor& batch1, - const at::Tensor & batch2, + const at::Tensor& batch2, at::Scalar beta, at::Scalar alpha) { - DEBUG("AtenIpexCPUDev::dil_baddbmm\n"); + DEBUG("AtenIpexCPUDev::dil_baddbmm_out\n"); - dbl::comm::reorder_to_bf16_for_mix_prec(self); - dbl::comm::reorder_to_bf16_for_mix_prec(batch1); - dbl::comm::reorder_to_bf16_for_mix_prec(batch2); + return dil_baddbmm_common(result, self, batch1, batch2, beta, alpha); +} - auto self_size = batch1.sizes(); - std::vector result_size(self_size.begin(), self_size.end()-1); - result_size.push_back(batch2.size(-1)); - at::Tensor result = dbl::comm::empty_dil_tensor(result_size, self.options()); - return dil_baddbmm_out(result, self, batch1, batch2, beta, alpha); +at::Tensor AtenIpexCPUDev::dil_baddbmm(const at::Tensor& self, const at::Tensor& batch1, const at::Tensor & batch2, at::Scalar beta, at::Scalar alpha) { + DEBUG("AtenIpexCPUDev::dil_baddbmm\n"); + + auto result = dbl::comm::empty_dil_tensor({0}, self.options()); + return dil_baddbmm_common(result, self, batch1, batch2, beta, alpha); } -at::Tensor& AtenIpexCPUDev::dil_baddbmm_( - at::Tensor& self, - const at::Tensor& batch1, - const at::Tensor& batch2, - at::Scalar beta, - at::Scalar alpha) { +at::Tensor& AtenIpexCPUDev::dil_baddbmm_(at::Tensor& self, const at::Tensor& batch1, const at::Tensor& batch2, at::Scalar beta, at::Scalar alpha) { DEBUG("AtenIpexCPUDev::dil_baddbmm_\n"); - dbl::comm::reorder_to_bf16_for_mix_prec(self); - dbl::comm::reorder_to_bf16_for_mix_prec(batch1); - dbl::comm::reorder_to_bf16_for_mix_prec(batch2); - - at::Tensor result = at::empty({0}, self.options()); - return dil_baddbmm_out(self, result, batch1, batch2, beta, alpha); + return dil_baddbmm_out(self, self, batch1, batch2, beta, alpha); } -at::Tensor& AtenIpexCPUDev::dil_addmm_out( +template +at::Tensor& dil_addmm_common( at::Tensor& result, const at::Tensor& self, const at::Tensor& mat1, const at::Tensor& mat2, at::Scalar beta, at::Scalar alpha) { - DEBUG("AtenIpexCPUDev::dil_addmm_out\n"); + CHECK_DNNL_OP_PRE_COND(self); + CHECK_DNNL_OP_PRE_COND(mat1); + CHECK_DNNL_OP_PRE_COND(mat2); + + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(mat1.dim() == 2 && mat2.dim() == 2); + dil::dims inferred_size{mat1.size(0), mat2.size(1)}; + TORCH_CHECK(self.sizes().equals(inferred_size), + "dil addmm not support broadcast yet"); - dbl::comm::reorder_to_bf16_for_mix_prec(result); dbl::comm::reorder_to_bf16_for_mix_prec(self); dbl::comm::reorder_to_bf16_for_mix_prec(mat1); dbl::comm::reorder_to_bf16_for_mix_prec(mat2); - return dil_baddbmm_out(result, self, mat1, mat2, beta, alpha); + auto x = dbl::comm::try_gen_dil_tensor(mat1); + auto w = dbl::comm::try_gen_dil_tensor(mat2); + dil::tensor bias; + if (self.numel() != 0) { + bias = dbl::comm::try_gen_dil_tensor(self); + if (bias.ndims() < x.ndims()) { + auto bias_dims = bias.get_dims(); + bias_dims.insert(bias_dims.begin(), 1); + bias.reshape(bias_dims); + } + } + auto y = inplace ? 
dbl::comm::try_gen_dil_tensor(self) : dil::tensor(); + auto attr_ = dil::attr_t::fuse_sum(); + matmul_common(x, w, bias, y, beta, alpha, attr_); + + if (!inplace) { + dbl::comm::equip_dil_buffer(result, y); + } + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(result.sizes().equals(inferred_size)); + return result; } -at::Tensor AtenIpexCPUDev::dil_addmm( - const at::Tensor& self, - const at::Tensor& batch1, - const at::Tensor & batch2, - at::Scalar beta, - at::Scalar alpha) { - DEBUG("AtenIpexCPUDev::dil_addmm\n"); +at::Tensor& AtenIpexCPUDev::dil_addmm_out(at::Tensor& result, const at::Tensor& self, const at::Tensor& mat1, const at::Tensor& mat2, at::Scalar beta, at::Scalar alpha) { + DEBUG("AtenIpexCPUDev::dil_addmm_out\n"); - dbl::comm::reorder_to_bf16_for_mix_prec(self); - dbl::comm::reorder_to_bf16_for_mix_prec(batch1); - dbl::comm::reorder_to_bf16_for_mix_prec(batch2); + return dil_addmm_common(result, self, mat1, mat2, beta, alpha); +} - return dil_baddbmm(self, batch1, batch2, beta, alpha); +at::Tensor AtenIpexCPUDev::dil_addmm(const at::Tensor& self, const at::Tensor& mat1, const at::Tensor & mat2, at::Scalar beta, at::Scalar alpha) { + DEBUG("AtenIpexCPUDev::dil_addmm\n"); + + auto result = dbl::comm::empty_dil_tensor({0}, self.options()); + return dil_addmm_common(result, self, mat1, mat2, beta, alpha); } -at::Tensor& AtenIpexCPUDev::dil_addmm_( - at::Tensor& self, - const at::Tensor& batch1, - const at::Tensor & batch2, - at::Scalar beta, - at::Scalar alpha) { +at::Tensor& AtenIpexCPUDev::dil_addmm_(at::Tensor& self, const at::Tensor& mat1, const at::Tensor & mat2, at::Scalar beta, at::Scalar alpha) { DEBUG("AtenIpexCPUDev::dil_addmm_\n"); - dbl::comm::reorder_to_bf16_for_mix_prec(self); - dbl::comm::reorder_to_bf16_for_mix_prec(batch1); - dbl::comm::reorder_to_bf16_for_mix_prec(batch2); - - return dil_baddbmm_(self, batch1, batch2, beta, alpha); + return dil_addmm_common(self, self, mat1, mat2, beta, alpha); } -at::Tensor& AtenIpexCPUDev::dil_addbmm_out( +template +at::Tensor& dil_addbmm_common( at::Tensor& result, const at::Tensor &self, const at::Tensor &batch1, const at::Tensor &batch2, at::Scalar beta, at::Scalar alpha) { - DEBUG("AtenIpexCPUDev::dil_addbmm_out\n"); CHECK_DNNL_OP_PRE_COND(self); CHECK_DNNL_OP_PRE_COND(batch1); CHECK_DNNL_OP_PRE_COND(batch2); - dbl::comm::reorder_to_bf16_for_mix_prec(result); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(batch1.dim() == 3 && batch2.dim() == 3); + dil::dims inferred_size{batch1.size(1), batch2.size(2)}; + TORCH_CHECK(self.sizes().equals(inferred_size), + "dil addbmm not support broadcast yet"); + dbl::comm::reorder_to_bf16_for_mix_prec(self); dbl::comm::reorder_to_bf16_for_mix_prec(batch1); dbl::comm::reorder_to_bf16_for_mix_prec(batch2); @@ -583,18 +579,16 @@ at::Tensor& AtenIpexCPUDev::dil_addbmm_out( // [n, b*m] * [b*m, p] = [n, p] // For batch1: reorder from [b, n, m] to [n, b, m], reshape to [n, b*m] // For batch2: reshape from [b, m, p] to [b*m, p] - const dil::tensor x = dbl::comm::try_gen_dil_tensor(batch1); - dil::tensor w = dbl::comm::try_gen_dil_tensor(batch2); + auto x = dbl::comm::try_gen_dil_tensor(batch1); + auto w = dbl::comm::try_gen_dil_tensor(batch2); auto x_ = x; if (x.get_dim(0) > 1) { x_ = x.transpose(0, 1); } - dil::dims x_dims = {x.get_dim(1), x.get_dim(0) * x.get_dim(2)}; - x_ = x_.reshape(x_dims); - dil::dims w_dims = {w.get_dim(0) * w.get_dim(1), w.get_dim(2)}; - auto w_ = w.reshape(w_dims); - dil::tensor y = dbl::comm::try_gen_dil_tensor(result); + x_ = x_.reshape({x.get_dim(1), x.get_dim(0) * x.get_dim(2)}); + auto w_ = 
w.reshape({w.get_dim(0) * w.get_dim(1), w.get_dim(2)}); + auto y = inplace ? dbl::comm::try_gen_dil_tensor(self) : dil::tensor(); auto attr_ = dil::attr_t::fuse_sum(); dil::tensor bias; @@ -608,41 +602,30 @@ at::Tensor& AtenIpexCPUDev::dil_addbmm_out( } matmul_common(x_, w_, bias, y, beta, alpha, attr_); - TORCH_INTERNAL_ASSERT_DEBUG_ONLY(y.is_public_format() || check_tensor_own_whole_storage(result)); - dbl::comm::sync_shape_from_dil_to_aten(result, y); + if (!inplace) { + dbl::comm::equip_dil_buffer(result, y); + } + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(result.sizes().equals(inferred_size)); return result; } -at::Tensor AtenIpexCPUDev::dil_addbmm( - const at::Tensor &self, - const at::Tensor &batch1, - const at::Tensor &batch2, - at::Scalar beta, - at::Scalar alpha) { - DEBUG("AtenIpexCPUDev::dil_addbmm\n"); +at::Tensor& AtenIpexCPUDev::dil_addbmm_out(at::Tensor& result, const at::Tensor &self, const at::Tensor &batch1, const at::Tensor &batch2, at::Scalar beta, at::Scalar alpha) { + DEBUG("AtenIpexCPUDev::dil_addbmm_out\n"); - dbl::comm::reorder_to_bf16_for_mix_prec(self); - dbl::comm::reorder_to_bf16_for_mix_prec(batch1); - dbl::comm::reorder_to_bf16_for_mix_prec(batch2); + return dil_addbmm_common(result, self, batch1, batch2, beta, alpha); +} + +at::Tensor AtenIpexCPUDev::dil_addbmm(const at::Tensor &self, const at::Tensor &batch1, const at::Tensor &batch2, at::Scalar beta, at::Scalar alpha) { + DEBUG("AtenIpexCPUDev::dil_addbmm\n"); - at::Tensor result = dbl::comm::empty_dil_tensor(self.sizes(), self.options()); - return dil_addbmm_out(result, self, batch1, batch2, beta, alpha); + auto result = dbl::comm::empty_dil_tensor({0}, self.options()); + return dil_addbmm_common(result, self, batch1, batch2, beta, alpha); } -at::Tensor& AtenIpexCPUDev::dil_addbmm_( - at::Tensor& self, - const at::Tensor& batch1, - const at::Tensor& batch2, - at::Scalar beta, - at::Scalar alpha) { +at::Tensor& AtenIpexCPUDev::dil_addbmm_(at::Tensor& self, const at::Tensor& batch1, const at::Tensor& batch2, at::Scalar beta, at::Scalar alpha) { DEBUG("AtenIpexCPUDev::dil_addbmm_\n"); - dbl::comm::reorder_to_bf16_for_mix_prec(self); - dbl::comm::reorder_to_bf16_for_mix_prec(batch1); - dbl::comm::reorder_to_bf16_for_mix_prec(batch2); - - at::Tensor result = at::empty({0}, self.options()); - return dil_addbmm_out(self, result, batch1, batch2, beta, alpha); + return dil_addbmm_common(self, self, batch1, batch2, beta, alpha); } at::Tensor AtenIpexCPUDev::dil_linear( @@ -1356,6 +1339,14 @@ at::Tensor AtenIpexCPUDev::dil_reshape(const at::Tensor& self, at::IntArrayRef s return dbl::comm::gen_aten_tensor_by(std::move(y)); } +int64_t AtenIpexCPUDev::dil_size(const at::Tensor & self, int64_t dim) { + DEBUG("AtenIpexCPUDev::dil_size\n"); + CHECK_DNNL_OP_PRE_COND(self); + + dim = at::maybe_wrap_dim(dim, self.dim(), false); + return self.sizes()[dim]; +} + at::Tensor AtenIpexCPUDev::dil_clone(const at::Tensor& self, c10::optional optional_memory_format) { DEBUG("AtenIpexCPUDev::dil_clone\n"); CHECK_DNNL_OP_PRE_COND(self); diff --git a/torch_ipex/csrc/cpu/DevOPs.h b/torch_ipex/csrc/cpu/DevOPs.h index d3745212f..58ddc3846 100644 --- a/torch_ipex/csrc/cpu/DevOPs.h +++ b/torch_ipex/csrc/cpu/DevOPs.h @@ -65,6 +65,7 @@ class AtenIpexCPUDev { static at::Tensor& dil_sigmoid_(at::Tensor& self); static at::Tensor dil_sigmoid_backward(const at::Tensor& grad_output, const at::Tensor& output); static at::Tensor dil_reshape(const at::Tensor& self, at::IntArrayRef size); + static int64_t dil_size(const at::Tensor & self, int64_t dim); static 
at::Tensor dil_clone(const at::Tensor& self, c10::optional optional_memory_format); static at::Tensor dil_transpose(const at::Tensor & self, int64_t dim0, int64_t dim1); static at::Tensor& dil_cat_out(at::Tensor& result, at::TensorList tensors, int64_t dim); diff --git a/torch_ipex/csrc/cpu/dbl/Common.cpp b/torch_ipex/csrc/cpu/dbl/Common.cpp index bb19f4826..3c6c27993 100644 --- a/torch_ipex/csrc/cpu/dbl/Common.cpp +++ b/torch_ipex/csrc/cpu/dbl/Common.cpp @@ -89,6 +89,15 @@ void reorder_to_desc(const at::Tensor& tensor, const dil::tensor::desc& expected } void equip_dil_buffer(const at::Tensor& tensor, dil::tensor dil_tensor_buffer) { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + tensor.device().is_dpcpp(), + "dil buffer can only be equipped to dpcpp tensor"); + + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + check_tensor_own_whole_storage(tensor), + "dil buffer can only be equipped to tensors that own the whole storage, " + "as dil buffer is going to replace the original storage"); + // Build new shade data context cpu::ShadeDataContext *new_shade_data_context = cpu::ShadeDataContext::allocShadeDataContext(); new_shade_data_context->data_type = cpu::SHADE_DATA_TYPE::DIL; @@ -97,13 +106,10 @@ void equip_dil_buffer(const at::Tensor& tensor, dil::tensor dil_tensor_buffer) { void *tensor_data = nullptr; if (dil_tensor_buffer.get_data_type() != get_dil_data_type(tensor.scalar_type())) { new_shade_data_context->mix_prec_type = cpu::MIX_PREC_TYPE::MIX_BF16_FP32; - } else { - if (dil_tensor_buffer.is_public_format()) { - tensor_data = dil_tensor_buffer.get_data_handle(); - new_shade_data_context->cpu_raw_data = tensor_data; - new_shade_data_context->cpu_del_fun = &(c10::detail::deleteNothing); - sync_shape_from_dil_to_aten(tensor, dil_tensor_buffer); - } + } else if (dil_tensor_buffer.is_public_format()) { + tensor_data = dil_tensor_buffer.get_data_handle(); + new_shade_data_context->cpu_raw_data = tensor_data; + new_shade_data_context->cpu_del_fun = &(c10::detail::deleteNothing); } // Create a new DataPtr instances because the DataPtr class does not support set @@ -116,6 +122,12 @@ void equip_dil_buffer(const at::Tensor& tensor, dil::tensor dil_tensor_buffer) { IPEXTensorImpl* ipex_tensor_impl = (IPEXTensorImpl *)tensor.unsafeGetTensorImpl(); ipex_tensor_impl->storage().set_data_ptr(std::move(shade_data_ptr)); + + // After equip_dil_buffer(), whole storage should be managed by dil tensor, + // and thus storage metadata should be overwritten by dil tensor + // Note: Storage::set_numel() might be removed later + ipex_tensor_impl->storage().set_numel(dil_tensor_buffer.get_nelems()); + cpu::dbl::comm::sync_shape_from_dil_to_aten(tensor, dil_tensor_buffer); } dil::tensor try_gen_dil_tensor(const at::Tensor &input) {
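
A minimal end-to-end sketch of the behavior the new test_mixed_format case exercises, assuming an IPEX build that includes this patch and the same intel_pytorch_extension import, ipex.DEVICE, and ipex.core.enable_auto_dnnl() helpers used in tests/cpu/test_lazy_reorder.py; the refined binary ops accept any mix of plain and DNNL-blocked dpcpp operands, and an undersized out= tensor is re-equipped with the dil result buffer:

    import torch
    import torch.nn.functional as F
    import intel_pytorch_extension as ipex

    ipex.core.enable_auto_dnnl()
    device = ipex.DEVICE

    def convert_blocked(t):
        # Force a DNNL blocked layout via a depthwise conv with identity weights,
        # mirroring the convert_blocked() helper added in tests/cpu/test_lazy_reorder.py.
        c = t.size(1)
        return F.conv2d(t.to(device), torch.ones(c, 1, 1, 1).to(device), groups=c)

    x_cpu = torch.ones(2, 3, 4, 5) * 5
    y_cpu = torch.ones(2, 3, 4, 5) * 4

    x_block = convert_blocked(x_cpu)   # blocked dpcpp tensor, values unchanged
    y_plain = y_cpu.to(device)         # plain dpcpp tensor

    # Mixed blocked/plain operands in either order.
    print(torch.allclose(torch.add(x_block, y_plain).to('cpu'), x_cpu + y_cpu))
    print(torch.allclose(torch.mul(y_plain, x_block).to('cpu'), x_cpu * y_cpu))

    # An out= tensor of the wrong size is replaced by the dil result buffer.
    out = torch.ones(1).to(device)
    torch.add(x_block, y_plain, out=out)
    print(out.size())  # torch.Size([2, 3, 4, 5])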