diff --git a/tests/cpu/test_lazy_reorder.py b/tests/cpu/test_lazy_reorder.py
index 34a5429c9..34d4f1fc0 100644
--- a/tests/cpu/test_lazy_reorder.py
+++ b/tests/cpu/test_lazy_reorder.py
@@ -252,6 +252,22 @@ def _test_conv_add_relu_(self, device, rand_seed):
 
         return conv_op_output, conv_op_input, add_src
 
+    def _test_conv_relu_(self, device, rand_seed):
+        ipex.enable_auto_dnnl()
+        torch.manual_seed(rand_seed)
+        conv_op = torch.nn.Conv2d(1, 1, (7, 7)).to(device=device)
+        conv_op_input = torch.rand((1, 1, 10, 10)).to(device=device)
+        conv_op_output = conv_op(conv_op_input)
+        conv_op_output.relu_()
+        return conv_op_output
+
+    def test_conv_relu_(self):
+        rand_seed = int(get_rand_seed())
+        res_dcpp_dnnl = self._test_conv_relu_("dpcpp:0", rand_seed)
+        self.assertTrue(ipex.is_dil_tensor(res_dcpp_dnnl))
+        res_cpu = self._test_conv_relu_("cpu", rand_seed)
+        self.assertEqual(res_cpu, res_dcpp_dnnl.to('cpu'))
+
     def test_conv_add_relu_(self):
         ipex.enable_auto_dnnl()
         rand_seed = int(get_rand_seed())
@@ -260,18 +276,18 @@ def test_conv_add_relu_(self):
 
         ipex.disable_auto_dnnl()
         res_dcpp_cpu, input_dpcpp_cpu, _ = self._test_conv_add_relu_("dpcpp:0", rand_seed)
-        
+
         res_cpu, input_cpu, _ = self._test_conv_add_relu_("cpu", rand_seed)
         self.assertEqual(res_cpu, res_dcpp_cpu.to('cpu'))
         self.assertEqual(res_cpu, res_dcpp_dnnl.to('cpu'))
 
         ipex.enable_auto_dnnl()
-        res_dcpp_dnnl.sum()#.backward()
-        res_dcpp_cpu.sum()#.backward()
-        res_cpu.sum()#.backward()
+        res_dcpp_dnnl.sum().backward()
+        res_dcpp_cpu.sum().backward()
+        res_cpu.sum().backward()
 
-        #self.assertEqual(input_dpcpp_dnnl.grad.to('cpu'), input_cpu.grad, prec=0.0)
-        #self.assertEqual(input_dpcpp_cpu.grad.to('cpu'), input_cpu.grad, prec=0.0)
+        self.assertEqual(input_dpcpp_dnnl.grad.to('cpu'), input_cpu.grad, prec=0.0)
+        self.assertEqual(input_dpcpp_cpu.grad.to('cpu'), input_cpu.grad, prec=0.0)
 
 class TestLinearAlgebraOps(TestCase):
     def test_mm(self):
diff --git a/torch_ipex/csrc/cpu/DevOPs.cpp b/torch_ipex/csrc/cpu/DevOPs.cpp
index 9f9c24806..6e3283b3b 100644
--- a/torch_ipex/csrc/cpu/DevOPs.cpp
+++ b/torch_ipex/csrc/cpu/DevOPs.cpp
@@ -239,6 +239,7 @@ at::Tensor& AtenIpexCPUDev::dil_add_out(
   const std::vector<float> scales{1.0, alpha.to<float>()};
   dil::sum::compute(scales, {x, y}, z);
 
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(z.is_public_format() || check_tensor_own_whole_storage(result));
   dbl::comm::sync_shape_from_dil_to_aten(result, z);
   return result;
 }
@@ -267,6 +268,7 @@ at::Tensor & AtenIpexCPUDev::dil_add_(at::Tensor& self, const at::Tensor& other,
   const std::vector<float> scales{1.0, alpha.to<float>()};
   dil::sum::compute(scales, {dil_self, dil_other}, dil_self);
 
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dil_self.is_public_format() || check_tensor_own_whole_storage(self));
   dbl::comm::sync_shape_from_dil_to_aten(self, dil_self);
   return self;
 }
@@ -283,6 +285,7 @@ at::Tensor& AtenIpexCPUDev::dil_mul_out(at::Tensor& result, const at::Tensor& se
 
   dil::binary::compute(dil_self, dil_other, dil_result, dil::algorithm::binary_mul);
 
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dil_result.is_public_format() || check_tensor_own_whole_storage(result));
   dbl::comm::sync_shape_from_dil_to_aten(result, dil_result);
   return result;
 }
@@ -346,6 +349,7 @@ at::Tensor& AtenIpexCPUDev::dil_bmm_out(
   dil::tensor y = dbl::comm::try_gen_dil_tensor(result);
   matmul_common(x, w, dil::tensor(), y);
 
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(y.is_public_format() || check_tensor_own_whole_storage(result));
   dbl::comm::sync_shape_from_dil_to_aten(result, y);
   return result;
 }
@@ -390,6 +394,8 @@ at::Tensor& AtenIpexCPUDev::dil_baddbmm_out(
   dil::tensor y = dbl::comm::try_gen_dil_tensor(result);
   auto attr_ = dil::attr_t::fuse_sum();
   matmul_common(x, w, bias, y, beta, alpha, attr_);
+
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(y.is_public_format() || check_tensor_own_whole_storage(result));
   dbl::comm::sync_shape_from_dil_to_aten(result, y);
   return result;
 }
@@ -489,6 +495,8 @@ at::Tensor& AtenIpexCPUDev::dil_addbmm_out(
     }
   }
   matmul_common(x_, w_, bias, y, beta, alpha, attr_);
+
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(y.is_public_format() || check_tensor_own_whole_storage(result));
   dbl::comm::sync_shape_from_dil_to_aten(result, y);
   return result;
 }
@@ -974,6 +982,8 @@ at::Tensor& AtenIpexCPUDev::dil_relu_(at::Tensor& input) {
       dil::algorithm::eltwise_relu,
       dil::prop_kind::forward_training,
       /*alpha*/ 0.0);
+
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dil_self.is_public_format() || check_tensor_own_whole_storage(input));
   dbl::comm::sync_shape_from_dil_to_aten(input, dil_self);
   return input;
 }
@@ -1041,6 +1051,8 @@ at::Tensor& AtenIpexCPUDev::dil_sigmoid_(at::Tensor& self) {
   dil::tensor x = dbl::comm::try_gen_dil_tensor(self);
   dil::eltwise_forward::compute(
       x, x, dil::algorithm::eltwise_logistic_use_dst_for_bwd, dil::prop_kind::forward);
+
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(x.is_public_format() || check_tensor_own_whole_storage(self));
   dbl::comm::sync_shape_from_dil_to_aten(self, x);
   return self;
 }
@@ -1122,6 +1134,8 @@ at::Tensor& AtenIpexCPUDev::dil_cat_out(at::Tensor& result, at::TensorList tenso
   }
   dil::tensor y = dbl::comm::try_gen_dil_tensor(result);
   dil::concat::compute(x, dim, y);
+
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(y.is_public_format() || check_tensor_own_whole_storage(result));
   dbl::comm::sync_shape_from_dil_to_aten(result, y);
   return result;
 }
diff --git a/torch_ipex/csrc/cpu/ShadeDataContext.h b/torch_ipex/csrc/cpu/ShadeDataContext.h
index 37216a83e..9e323c3a6 100644
--- a/torch_ipex/csrc/cpu/ShadeDataContext.h
+++ b/torch_ipex/csrc/cpu/ShadeDataContext.h
@@ -94,6 +94,8 @@ struct ShadeDataContext {
       if (raw_cpu_data == nullptr) {
         // the dnnl tensor does not share data with raw tensor data.
         TORCH_INTERNAL_ASSERT_DEBUG_ONLY(! (shade_data_context->dil_tensor.is_empty()));
+        TORCH_INTERNAL_ASSERT_DEBUG_ONLY(! (shade_data_context->dil_tensor.is_public_format()));
+        TORCH_INTERNAL_ASSERT_DEBUG_ONLY(check_tensor_own_whole_storage(tensor));
         return true;
       } else {
         // The dnnl tensor shares some data with raw tensor.
diff --git a/torch_ipex/csrc/cpu/dbl/Common.cpp b/torch_ipex/csrc/cpu/dbl/Common.cpp
index 7432d3a9e..ae8f95d3a 100644
--- a/torch_ipex/csrc/cpu/dbl/Common.cpp
+++ b/torch_ipex/csrc/cpu/dbl/Common.cpp
@@ -40,6 +40,8 @@ dil::tensor try_gen_dil_tensor(const at::Tensor &input) {
     if ((!check_aten_dil_shape_info(input, dil_tensor)) && dil_tensor.is_public_format()) {
       dil_tensor.set_dims_and_strides(input.sizes().vec(), input.strides().vec());
     }
+    // Does not support the case where the dil tensor is in blocked format but only covers part of the tensor buffer
+    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dil_tensor.is_public_format() || check_tensor_own_whole_storage(input));
     return dil_tensor;
   } else {
     return dil_tensor_from_dense(input);
@@ -71,14 +73,7 @@ at::Tensor gen_aten_tensor_by(dil::tensor dil_tensor) {
     nullptr,
     /*resizeable=*/false);
   auto _tensor = at::detail::make_tensor<IPEXTensorImpl>(storage_impl, at::DispatchKey::DPCPPTensorId);
-  if (dil_tensor.is_public_format()) {
-    dbl::comm::sync_shape_from_dil_to_aten(_tensor, dil_tensor);
-  } else {
-    // Blockformat does not inlcude stride information
-    auto tensor_sizes = dil_tensor.get_dims();
-    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(tensor_sizes.size() != 1 || tensor_sizes[0] != 0);
-    _tensor.unsafeGetTensorImpl()->set_sizes_contiguous(tensor_sizes);
-  }
+  dbl::comm::sync_shape_from_dil_to_aten(_tensor, dil_tensor);
   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(_tensor.layout() == c10::kStrided);
   return _tensor;
 }
@@ -94,10 +89,17 @@ at::Tensor empty_dil_tensor(at::IntArrayRef sizes, const at::TensorOptions& opti
 
 void sync_shape_from_dil_to_aten(const at::Tensor& ipex_tensor, const dil::tensor &dil_tensor) {
   dil::dims sizes = dil_tensor.get_dims();
-  dil::dims strides = dil_tensor.get_strides();
-  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(ipex_tensor.device().type() == at::DeviceType::DPCPP);
-  auto* _tensor_impl = (IPEXTensorImpl *)ipex_tensor.unsafeGetTensorImpl();
-  _tensor_impl->force_set_strided(sizes, strides);
+  if (dil_tensor.is_public_format()) {
+    dil::dims strides = dil_tensor.get_strides();
+    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(ipex_tensor.device().type() == at::DeviceType::DPCPP);
+    auto* _tensor_impl = (IPEXTensorImpl *)ipex_tensor.unsafeGetTensorImpl();
+    _tensor_impl->force_set_strided(sizes, strides);
+  } else {
+    // Blocked format does not include stride information
+    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(sizes.size() != 1 || sizes[0] != 0);
+    ipex_tensor.unsafeGetTensorImpl()->set_sizes_contiguous(sizes);
+  }
+
 }
 
 }  // namespace comm