diff --git a/cmake/CPU.cmake b/cmake/CPU.cmake index 4440f42f2..589bbb97a 100644 --- a/cmake/CPU.cmake +++ b/cmake/CPU.cmake @@ -16,7 +16,7 @@ add_subdirectory(${DPCPP_THIRD_PARTY_ROOT}/mkl-dnn) # Define build type IF(CMAKE_BUILD_TYPE MATCHES Debug) message("Debug build.") - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O0 -g") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O0 -g -D_DEBUG") ELSE() message("Release build.") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O2") diff --git a/tests/cpu/test_lazy_reorder.py b/tests/cpu/test_lazy_reorder.py index b8818b71d..39b79fdb6 100644 --- a/tests/cpu/test_lazy_reorder.py +++ b/tests/cpu/test_lazy_reorder.py @@ -749,6 +749,33 @@ def test_transpose(self): x_dpcpp.transpose(dim1, dim2), ) + def test_view(self): + ipex.enable_auto_dnnl() + old_shape = (4, 16) + new_shape = (1, 4, 4, 4) + + x_cpu = torch.randn(old_shape) + x_dpcpp = x_cpu.to(device=device).clone() + print(x_dpcpp.size()) + + x_cpu_view = x_cpu.view(new_shape) + print(x_cpu_view.size()) + x_dpcpp_view = x_dpcpp.view(new_shape) + print(x_dpcpp_view.size()) + + y = torch.randn(new_shape) + out_cpu = x_cpu_view * y + # test if the shape of x_dpcpp_view is compatible with y + out_dpcpp = x_dpcpp_view * y + self.assertEqual(out_cpu, out_dpcpp) + + # test if metadata of x_dpcpp has not been altered + y = torch.randn(old_shape) + out_cpu = x_cpu * y + out_dpcpp = x_dpcpp * y + self.assertEqual(out_cpu, out_dpcpp) + + class TestSoftMax(TestCase): def test_softmax(self): ipex.enable_auto_dnnl() diff --git a/tests/cpu/test_rn50_cpu_ops.py b/tests/cpu/test_rn50_cpu_ops.py index 3a4252519..4e0e53139 100644 --- a/tests/cpu/test_rn50_cpu_ops.py +++ b/tests/cpu/test_rn50_cpu_ops.py @@ -416,6 +416,9 @@ def test_view(self): self.assertRaises(RuntimeError, lambda: tensor.view(7, -1)) self.assertRaises(RuntimeError, lambda: tensor.view(15, -1, -1)) + # TODO(Eikan): DNNL OP does not support >6 dim tensor, so we disable it temporarily.
When we fix it, we will open it + old_dnnl_conf = ipex.get_auto_dnnl() + ipex.disable_auto_dnnl() # test view when tensor is not contiguous in every dimension, but only # contiguous dimensions are touched. tensor = torch.rand(4, 2, 5, 1, 6, 2, 9, 3, device=device).transpose(-1, 2).transpose(-2, 3) @@ -441,6 +444,10 @@ def test_view(self): # adding size 1 dims view_size = [1, 1, 2, 1, 4, 3, 1, 1, 9, 1, 2, 1, 2, 3, 1, 5, 1, 1] self.assertEqual(tensor.view(*view_size), contig_tensor.view(*view_size)) + if old_dnnl_conf: + ipex.enable_auto_dnnl() + else: + ipex.disable_auto_dnnl() # invalid views self.assertRaises(RuntimeError, lambda: tensor.view(-1)) diff --git a/tests/cpu/test_torch.py b/tests/cpu/test_torch.py index af1472e19..6752ed3da 100644 --- a/tests/cpu/test_torch.py +++ b/tests/cpu/test_torch.py @@ -81,7 +81,7 @@ from multiprocessing.reduction import ForkingPickler from common_device_type import instantiate_device_type_tests, \ skipIf, skipCPUIfNoLapack, skipCUDAIfNoMagma, skipCUDAIfRocm, onlyCUDA, onlyCPU, \ - dtypes, dtypesIfCUDA, deviceCountAtLeast, skipCUDAIf, precisionOverride + dtypes, dtypesIfCUDA, deviceCountAtLeast, skipCUDAIf, precisionOverride, ipex import torch.backends.quantized @@ -8725,7 +8725,10 @@ def test_diagflat(self, device): # Noncontig input x = torch.randn((2, 3, 4), dtype=dtype, device=device).transpose(2, 0) - self.assertFalse(x.is_contiguous()) + if ipex.get_auto_dnnl(): + self.assertTrue(x.is_contiguous()) + else: + self.assertFalse(x.is_contiguous()) result = torch.diagflat(x) expected = torch.diag(x.contiguous().view(-1)) self.assertEqual(result, expected) @@ -9773,8 +9776,12 @@ def test_cdist_non_contiguous(self, device): y = torch.randn(5, 3, device=device).transpose(-1, -2) actual = torch.cdist(x, y, p=1, compute_mode=cm) expected = brute_cdist(x, y, p=1) - self.assertFalse(x.is_contiguous()) - self.assertFalse(y.is_contiguous()) + if ipex.get_auto_dnnl(): + self.assertTrue(x.is_contiguous()) + 
self.assertTrue(y.is_contiguous()) + else: + self.assertFalse(x.is_contiguous()) + self.assertFalse(y.is_contiguous()) self.assertTrue(torch.allclose(expected, actual)) x = torch.randn(7, 5, device=device) @@ -9799,8 +9806,12 @@ def test_cdist_non_contiguous_batch(self, device): y = torch.randn(4, 3, 2, 5, 3, device=device).transpose(-1, -2) actual = torch.cdist(x, y, p=1, compute_mode=cm) expected = brute_cdist(x, y, p=1) - self.assertFalse(x.is_contiguous()) - self.assertFalse(y.is_contiguous()) + if ipex.get_auto_dnnl(): + self.assertTrue(x.is_contiguous()) + self.assertTrue(y.is_contiguous()) + else: + self.assertFalse(x.is_contiguous()) + self.assertFalse(y.is_contiguous()) self.assertTrue(torch.allclose(expected, actual)) x = torch.randn(7, 2, 7, 5, device=device) @@ -9808,14 +9819,20 @@ def test_cdist_non_contiguous_batch(self, device): actual = torch.cdist(x, y, p=1, compute_mode=cm) expected = brute_cdist(x, y, p=1) self.assertTrue(x.is_contiguous()) - self.assertFalse(y.is_contiguous()) + if ipex.get_auto_dnnl(): + self.assertTrue(y.is_contiguous()) + else: + self.assertFalse(y.is_contiguous()) self.assertTrue(torch.allclose(expected, actual)) x = torch.randn(4, 5, 7, device=device).transpose(-1, -2) y = torch.randn(4, 3, 5, device=device) actual = torch.cdist(x, y, p=1, compute_mode=cm) expected = brute_cdist(x, y, p=1) - self.assertFalse(x.is_contiguous()) + if ipex.get_auto_dnnl(): + self.assertTrue(x.is_contiguous()) + else: + self.assertFalse(x.is_contiguous()) self.assertTrue(y.is_contiguous()) self.assertTrue(torch.allclose(expected, actual)) @@ -10249,6 +10266,7 @@ def test_unfold_scalars(self, device): def test_copy_all_dtypes_and_devices(self, device): from copy import copy + ipex.enable_auto_dnnl() for dt in torch.testing.get_all_dtypes(): x = torch.tensor([1, 2, 3, 4], dtype=dt, device=device) x_clone = x.clone() @@ -10264,6 +10282,7 @@ def test_copy_all_dtypes_and_devices(self, device): # copy is a shallow copy, only copies the tensor view, # 
not the data self.assertEqual(x, y) + ipex.enable_auto_dnnl() def test_resize_all_dtypes_and_devices(self, device): shape = (2, 2) @@ -10761,7 +10780,8 @@ def test_tensor_shape_empty(self, device): self.assertEqual([(0, 1, 0, 0), (0, 1, 1, 0), (0, 1, 2, 0)], [z.shape for z in torch.split(x, (0, 1, 2), dim=2)]) - self.assertRaises(RuntimeError, lambda: torch.split(x, 0, dim=1)) + with self.assertRaises(RuntimeError): + torch.split(x, 0, dim=1) # This is strange because the split size is larger than the dim size, but consistent with # how split handles that case generally (when no 0s are involved). self.assertEqual([(0, 1, 3, 0)], [z.shape for z in torch.split(x, 1, dim=0)]) @@ -12764,8 +12784,12 @@ def _test_memory_format_transformations(self, device, input_generator_fn, transf clone = transformation_fn(xc) if default_is_preserve: - self.assertFalse(clone.is_contiguous()) - self.assertTrue(clone.is_contiguous(memory_format=memory_format)) + if ipex.get_auto_dnnl(): + self.assertTrue(clone.is_contiguous()) + self.assertFalse(clone.is_contiguous(memory_format=memory_format)) + else: + self.assertFalse(clone.is_contiguous()) + self.assertTrue(clone.is_contiguous(memory_format=memory_format)) else: self.assertTrue(clone.is_contiguous()) self.assertFalse(clone.is_contiguous(memory_format=memory_format)) @@ -14398,7 +14422,6 @@ def fn(self, device, dtype): # Runs the tensor op on CPU and device cpu_result = getattr(cpu_tensor, op_str)(*cpu_args) device_result = getattr(device_tensor, op_str)(*device_args) - # Compares CPU and device inputs and outputs precision = half_precision if dtype == torch.half else float_precision @@ -14512,4 +14535,5 @@ class TestTorch(TestCase, _TestTorchMixin): instantiate_device_type_tests(TestTensorDeviceOps, globals(), except_for='cpu') if __name__ == '__main__': + ipex.enable_auto_dnnl() run_tests() diff --git a/torch_ipex/csrc/aten_ipex_bridge.cpp b/torch_ipex/csrc/aten_ipex_bridge.cpp index d3039ccbd..c31b38968 100644 --- 
a/torch_ipex/csrc/aten_ipex_bridge.cpp +++ b/torch_ipex/csrc/aten_ipex_bridge.cpp @@ -18,6 +18,7 @@ namespace torch_ipex { namespace bridge { +#if defined(_DEBUG) #define CHECK_TENSOR(a, b) \ TORCH_INTERNAL_ASSERT(a.numel() == b.numel()); \ TORCH_INTERNAL_ASSERT(a.dtype() == b.dtype()); \ @@ -30,13 +31,21 @@ namespace bridge { TORCH_INTERNAL_ASSERT(a.unsafeGetTensorImpl()->is_wrapped_number() == b.unsafeGetTensorImpl()->is_wrapped_number()); \ TORCH_INTERNAL_ASSERT(a.unsafeGetTensorImpl()->version_counter().current_version() == b.unsafeGetTensorImpl()->version_counter().current_version()); \ TORCH_INTERNAL_ASSERT(a.unsafeGetTensorImpl()->allow_tensor_metadata_change() == b.unsafeGetTensorImpl()->allow_tensor_metadata_change()) +#else +#define CHECK_TENSOR(a, b) ((void) 0) +#endif +#if defined(_DEBUG) #define CHECK_TENSOR_CRITICAL(a, b, may_alias) \ TORCH_INTERNAL_ASSERT(!may_alias || a.data_ptr() == b.data_ptr()); \ TORCH_INTERNAL_ASSERT(a.unsafeGetTensorImpl()->strides() == b.unsafeGetTensorImpl()->strides()); \ TORCH_INTERNAL_ASSERT(a.unsafeGetTensorImpl()->storage_offset() == b.unsafeGetTensorImpl()->storage_offset()); \ CHECK_TENSOR(a, b) +#else +#define CHECK_TENSOR_CRITICAL(a, b, may_alias) ((void) 0) +#endif +#if defined(_DEBUG) #define CHECK_SPARSE_TENSOR_CRITICAL(a, b, may_alias) \ TORCH_INTERNAL_ASSERT(!may_alias || a._indices().data_ptr() == b._indices().data_ptr()); \ TORCH_INTERNAL_ASSERT(!may_alias || a._values().data_ptr() == b._values().data_ptr()); \ @@ -46,43 +55,54 @@ namespace bridge { TORCH_INTERNAL_ASSERT(a.is_coalesced() == b.is_coalesced()); \ CHECK_TENSOR(a._indices(), b._indices()); \ CHECK_TENSOR(a._values(), b._values()) - +#else +#define CHECK_SPARSE_TENSOR_CRITICAL(a, b, may_alias) ((void) 0) +#endif at::Tensor shallowFallbackToCPUTensorImpl(const at::Tensor& ipexTensor); void reorderDilTensorToPublic(const at::Tensor& ipexTensor) { void *data_ctx = ipexTensor.unsafeGetTensorImpl()->storage().data_ptr().get_context(); 
cpu::ShadeDataContext *shade_data_context = (cpu::ShadeDataContext*)data_ctx; - // All aten::tensor with dnnl::tensor should be contiguous +#if defined(_DEBUG) TORCH_WARN(ipexTensor.is_contiguous()); TORCH_INTERNAL_ASSERT(! (shade_data_context->dil_tensor.is_empty())); +#endif dil::tensor &dil_tensor = shade_data_context->dil_tensor; - dil::dims sizes = dil_tensor.get_dims(); - dil::dims strides; - if (dil_tensor.is_public_format()) { +#if defined(_DEBUG) TORCH_INTERNAL_ASSERT(shade_data_context->cpu_raw_data == shade_data_context->dil_tensor.get_data_handle()); TORCH_INTERNAL_ASSERT(shade_data_context->cpu_raw_data != nullptr); TORCH_INTERNAL_ASSERT(shade_data_context->cpu_del_fun != nullptr); - strides = dil_tensor.get_strides(); +#endif } else { - auto dims = dil_tensor.get_dims(); - // NOTE: int32_t dims from ideep::tensor but sizes needs int64_t - at::Tensor cpu_tensor = at::empty( - sizes, ipexTensor.options().device(c10::kCPU).layout(c10::kStrided)); - TORCH_INTERNAL_ASSERT(cpu_tensor.scalar_type() == get_at_data_type(dil_tensor.get_data_type())); - auto pub_tensor = dil_tensor.to_public(cpu_tensor.data_ptr(), dil_tensor.get_data_type()); - strides = pub_tensor.get_strides(); - at::DataPtr& cpu_tensor_data_ptr = cpu_tensor.unsafeGetTensorImpl()->storage().unsafeGetStorageImpl()->data_ptr(); - ipexTensor.unsafeGetTensorImpl()->storage().set_data_ptr(std::move(cpu_tensor_data_ptr)); - // The tensor has been reset to new DataPtr, then we need to attach new shade data context. 
- attachShadeDataConext(ipexTensor); +#if defined(_DEBUG) + auto& data_ptr = ipexTensor.storage().unsafeGetStorageImpl()->data_ptr(); + TORCH_INTERNAL_ASSERT(data_ptr.get_deleter() == &(cpu::ShadeDataContext::freeShadeDataContext)); + TORCH_INTERNAL_ASSERT(shade_data_context->cpu_del_fun == nullptr); +#endif + auto pub_tensor = dil_tensor.to_public(nullptr, dil_tensor.get_data_type()); + + cpu::ShadeDataContext *new_shade_data_context = cpu::ShadeDataContext::allocShadeDataContext(); + new_shade_data_context->data_type = cpu::SHADE_DATA_TYPE::DIL; + new_shade_data_context->dil_tensor = pub_tensor; + // Share with DNNL raw data because it is plain format now + new_shade_data_context->cpu_raw_data = pub_tensor.get_data_handle(); + // Cannot free CPU data because the data is owned by DNNL + new_shade_data_context->cpu_del_fun = &(c10::detail::deleteNothing); + + // Create a new DataPtr instance because the DataPtr class does not support setting + // its data or context directly + c10::DataPtr shade_data_ptr( + pub_tensor.get_data_handle(), + new_shade_data_context, + &(cpu::ShadeDataContext::freeShadeDataContext), + ipexTensor.device().type()); + + ipexTensor.unsafeGetTensorImpl()->storage().set_data_ptr(std::move(shade_data_ptr)); TORCH_INTERNAL_ASSERT(ipexTensor.is_contiguous()); } - - auto* ipexTensorImpl = (IPEXTensorImpl *)ipexTensor.unsafeGetTensorImpl(); - ipexTensorImpl->force_set_strided(sizes, strides); } @@ -279,32 +299,6 @@ at::Tensor upgradeToDPCPPTensor(const at::Tensor& cpuTensor) { return _tensor; } -at::Tensor shallowUpgradeToDPCPPShadeTensor(const at::Tensor& cpuTensor) { - if (!(cpuTensor.defined())) { - return at::Tensor(); - } - TORCH_INTERNAL_ASSERT(cpuTensor.device().type() == at::DeviceType::CPU); - if (cpuTensor.is_sparse()) shallowUpgradeToDPCPPTensor(cpuTensor); - - auto cpu_storage_impl = cpuTensor.storage().unsafeGetStorageImpl(); - auto& data_ptr = cpu_storage_impl->data_ptr(); - auto cur_del_fn = data_ptr.get_deleter(); - bool res =
data_ptr.compare_exchange_deleter(cur_del_fn, &(c10::detail::deleteNothing)); - TORCH_INTERNAL_ASSERT(res); - // Make sure that does not triger free resource for set_ptr - cpu::ShadeDataContext *shade_data_context = cpu::ShadeDataContext::allocShadeDataContext(); - shade_data_context->cpu_raw_data = data_ptr.get(); - shade_data_context->cpu_del_fun = cur_del_fn; - shade_data_context->data_type = cpu::SHADE_DATA_TYPE::CPU_RAW; - c10::DataPtr shade_data_ptr( - data_ptr.get(), - shade_data_context, - cpu::ShadeDataContext::freeShadeDataContext, - at::DeviceType::CPU); - cpuTensor.unsafeGetTensorImpl()->storage().set_data_ptr(std::move(shade_data_ptr)); - return shallowUpgradeToDPCPPTensor(cpuTensor); -} - // Upgrade CPU tensor to DPCPP Tensor with shallow copy // It will create an new DPCPP tensor but shares CPU tensor buffer // [NOTE]: Device info of Dense CPU tensor is polluted. diff --git a/torch_ipex/csrc/cpu/DevOPs.cpp b/torch_ipex/csrc/cpu/DevOPs.cpp index c6d1dd32f..ac7caaa39 100644 --- a/torch_ipex/csrc/cpu/DevOPs.cpp +++ b/torch_ipex/csrc/cpu/DevOPs.cpp @@ -239,6 +239,7 @@ at::Tensor& AtenIpexCPUDev::dil_add_out( const std::vector scales{1.0, alpha.to()}; dil::sum::compute(scales, {x, y}, z); + dbl::comm::sync_shape_from_dil_to_aten(result, z); return result; } @@ -254,7 +255,6 @@ at::Tensor AtenIpexCPUDev::dil_add(const at::Tensor& self, const at::Tensor& oth dil::sum::compute(scales, {x, y}, z); return dbl::comm::gen_aten_tensor_by(z); - } at::Tensor & AtenIpexCPUDev::dil_add_(at::Tensor& self, const at::Tensor& other, at::Scalar alpha) { @@ -267,6 +267,7 @@ at::Tensor & AtenIpexCPUDev::dil_add_(at::Tensor& self, const at::Tensor& other, const std::vector scales{1.0, alpha.to()}; dil::sum::compute(scales, {dil_self, dil_other}, dil_self); + dbl::comm::sync_shape_from_dil_to_aten(self, dil_self); return self; } @@ -282,6 +283,7 @@ at::Tensor& AtenIpexCPUDev::dil_mul_out(at::Tensor& result, const at::Tensor& se dil::binary::compute(dil_self, dil_other, 
dil_result, dil::algorithm::binary_mul); + dbl::comm::sync_shape_from_dil_to_aten(result, dil_result); return result; } @@ -343,6 +345,8 @@ at::Tensor& AtenIpexCPUDev::dil_bmm_out( const dil::tensor w = dbl::comm::try_gen_dil_tensor(batch2); dil::tensor y = dbl::comm::try_gen_dil_tensor(result); matmul_common(x, w, dil::tensor(), y); + + dbl::comm::sync_shape_from_dil_to_aten(result, y); return result; } @@ -386,6 +390,7 @@ at::Tensor& AtenIpexCPUDev::dil_baddbmm_out( dil::tensor y = dbl::comm::try_gen_dil_tensor(result); auto attr_ = dil::attr_t::fuse_sum(); matmul_common(x, w, bias, y, beta, alpha, attr_); + dbl::comm::sync_shape_from_dil_to_aten(result, y); return result; } @@ -484,6 +489,7 @@ at::Tensor& AtenIpexCPUDev::dil_addbmm_out( } } matmul_common(x_, w_, bias, y, beta, alpha, attr_); + dbl::comm::sync_shape_from_dil_to_aten(result, y); return result; } @@ -968,6 +974,7 @@ at::Tensor& AtenIpexCPUDev::dil_relu_(at::Tensor& input) { dil::algorithm::eltwise_relu, dil::prop_kind::forward_training, /*alpha*/ 0.0); + dbl::comm::sync_shape_from_dil_to_aten(input, dil_self); return input; } @@ -1034,6 +1041,7 @@ at::Tensor& AtenIpexCPUDev::dil_sigmoid_(at::Tensor& self) { dil::tensor x = dbl::comm::try_gen_dil_tensor(self); dil::eltwise_forward::compute( x, x, dil::algorithm::eltwise_logistic_use_dst_for_bwd, dil::prop_kind::forward); + dbl::comm::sync_shape_from_dil_to_aten(self, x); return self; } @@ -1081,10 +1089,13 @@ at::Tensor AtenIpexCPUDev::dil_clone(const at::Tensor& self, c10::optional 0, "DNNL transpose cannot generate DNNL tensor for the input aten Tensor. 
input tensor dim: ", self.dim()); dil::tensor y; std::vector axes(x.ndims()); std::iota(axes.begin(), axes.end(), 0); + dim0 = at::maybe_wrap_dim(dim0, self.dim()); + dim1 = at::maybe_wrap_dim(dim1, self.dim()); std::swap(axes[dim0], axes[dim1]); y.transpose_from(x, axes); return dbl::comm::gen_aten_tensor_by(y); @@ -1102,7 +1113,7 @@ at::Tensor& AtenIpexCPUDev::dil_cat_out(at::Tensor& result, at::TensorList tenso DEBUG("AtenIpexCPUDev::dil_cat_out\n"); CHECK_DNNL_OP_PRE_COND(result); check_cat_no_zero_dim(tensors); - dim = legacy_cat_wrap_dim(dim, tensors); + dim = at::legacy_cat_wrap_dim(dim, tensors); std::vector x; for (auto i =0; i< tensors.size(); i++) { TORCH_CHECK(!(tensors[i].dim() == 1 && tensors[i].sizes()[0] == 0), @@ -1111,13 +1122,14 @@ at::Tensor& AtenIpexCPUDev::dil_cat_out(at::Tensor& result, at::TensorList tenso } dil::tensor y = dbl::comm::try_gen_dil_tensor(result); dil::concat::compute(x, dim, y); + dbl::comm::sync_shape_from_dil_to_aten(result, y); return result; } at::Tensor AtenIpexCPUDev::dil_cat(at::TensorList tensors, int64_t dim) { DEBUG("AtenIpexCPUDev::dil_cat\n"); check_cat_no_zero_dim(tensors); - dim = legacy_cat_wrap_dim(dim, tensors); + dim = at::legacy_cat_wrap_dim(dim, tensors); std::vector x; at::Tensor tensors_contiguous[tensors.size()]; for (auto i = 0; i < tensors.size(); i++) { @@ -1145,6 +1157,8 @@ std::vector AtenIpexCPUDev::dil_split_with_sizes(const at::Tensor& s "entries, but got split_sizes=", split_sizes); sizes.push_back((int32_t)length); } + + dim = at::maybe_wrap_dim(dim, self.dim()); auto y = dil::spliter::compute(x, sizes, dim, false); for (auto j = 0; j < num_splits; j++) { splits[j] = dbl::comm::gen_aten_tensor_by(y[j]); @@ -1155,6 +1169,7 @@ std::vector AtenIpexCPUDev::dil_split_with_sizes(const at::Tensor& s std::vector AtenIpexCPUDev::dil_split(const at::Tensor& self, int64_t split_size, int64_t dim) { DEBUG("AtenIpexCPUDev::dil_split\n"); CHECK_DNNL_OP_PRE_COND(self); + dim = at::maybe_wrap_dim(dim, 
self.dim()); int64_t dim_size = self.size(dim); int64_t num_splits = 1; if (split_size != 0) { diff --git a/torch_ipex/csrc/cpu/ShadeDataContext.h b/torch_ipex/csrc/cpu/ShadeDataContext.h index 634b5df49..55bee10fa 100644 --- a/torch_ipex/csrc/cpu/ShadeDataContext.h +++ b/torch_ipex/csrc/cpu/ShadeDataContext.h @@ -90,7 +90,9 @@ struct ShadeDataContext { TORCH_INTERNAL_ASSERT((data_type == SHADE_DATA_TYPE::CPU_RAW) || (data_type == SHADE_DATA_TYPE::DIL)); if (data_type == SHADE_DATA_TYPE::DIL) { +#if defined(_DEBUG) TORCH_WARN(tensor.is_contiguous()); +#endif auto raw_cpu_data = tensor.storage().data_ptr().get(); if (raw_cpu_data == nullptr) { // the dnnl tensor does not share data with raw tensor data. diff --git a/torch_ipex/csrc/cpu/dbl/Common.cpp b/torch_ipex/csrc/cpu/dbl/Common.cpp index b49b19a70..3ce82914a 100644 --- a/torch_ipex/csrc/cpu/dbl/Common.cpp +++ b/torch_ipex/csrc/cpu/dbl/Common.cpp @@ -86,6 +86,14 @@ at::Tensor empty_dil_tensor(at::IntArrayRef sizes, const at::TensorOptions& opti return gen_aten_tensor_by(it); } +void sync_shape_from_dil_to_aten(const at::Tensor& ipex_tensor, const dil::tensor &dil_tensor) { + dil::dims sizes = dil_tensor.get_dims(); + dil::dims strides = dil_tensor.get_strides(); + TORCH_INTERNAL_ASSERT(ipex_tensor.device().type() == at::DeviceType::DPCPP); + auto* _tensor_impl = (IPEXTensorImpl *)ipex_tensor.unsafeGetTensorImpl(); + _tensor_impl->force_set_strided(sizes, strides); +} + } // namespace comm } // namespace dbl } // namespace cpu diff --git a/torch_ipex/csrc/cpu/dbl/Common.h b/torch_ipex/csrc/cpu/dbl/Common.h index a350ba2d8..15a47fffb 100644 --- a/torch_ipex/csrc/cpu/dbl/Common.h +++ b/torch_ipex/csrc/cpu/dbl/Common.h @@ -14,6 +14,7 @@ at::Tensor dil_tensor_to_dense(const at::Tensor& tensor); dil::tensor try_gen_dil_tensor(const at::Tensor &input); at::Tensor gen_aten_tensor_by(dil::tensor tensor); at::Tensor empty_dil_tensor(at::IntArrayRef sizes, const at::TensorOptions& options); +void 
sync_shape_from_dil_to_aten(const at::Tensor& ipex_tensor, const dil::tensor &dil_tensor); } // namespace comm } // namespace dbl diff --git a/torch_ipex/csrc/cpu/dbl/DNNLChecker.cpp b/torch_ipex/csrc/cpu/dbl/DNNLChecker.cpp index e9f2efff6..7afef50f9 100644 --- a/torch_ipex/csrc/cpu/dbl/DNNLChecker.cpp +++ b/torch_ipex/csrc/cpu/dbl/DNNLChecker.cpp @@ -9,12 +9,13 @@ namespace dbl { namespace chk { bool dnnl_support_the_tensors(const std::vector &tensor_vec) { - return dnnl_support_the_dimension_of(tensor_vec) && + return dnnl_tensor_has_data(tensor_vec) && + dnnl_support_the_dimension_of(tensor_vec) && dnnl_support_the_data_type_of(tensor_vec); } bool dnnl_inplace_support_the_tensors(const std::vector &tensor_vec) { - return dnnl_support_the_dimension_of(tensor_vec) && + return dnnl_tensor_has_data(tensor_vec) && dnnl_support_the_data_type_of(tensor_vec) && dnnl_support_the_memory_layout_of(tensor_vec); } @@ -53,6 +54,14 @@ bool dnnl_support_the_dimension_of(const std::vector &tensor_vec) { return true; } +bool dnnl_tensor_has_data(const std::vector &tensor_vec) { + for (auto it = tensor_vec.begin(); it != tensor_vec.end(); ++it) + if (it->data_ptr() == nullptr) + return false; + + return true; +} + } // namespace chk } // namespace dbl } // namespace cpu diff --git a/torch_ipex/csrc/cpu/dbl/DNNLChecker.h b/torch_ipex/csrc/cpu/dbl/DNNLChecker.h index 609195fcc..fc6eae28a 100644 --- a/torch_ipex/csrc/cpu/dbl/DNNLChecker.h +++ b/torch_ipex/csrc/cpu/dbl/DNNLChecker.h @@ -61,6 +61,14 @@ bool dnnl_support_the_data_type_of(const std::vector &tensor_vec); */ bool dnnl_support_the_dimension_of(const std::vector &tensor_vec); +/** + * Check if every input tensor has data + * + * @param tensor_vec input tensors + * @return true only if data_ptr() is non-null for every tensor in tensor_vec + */ +bool dnnl_tensor_has_data(const std::vector &tensor_vec); + } // namespace chk } // namespace dbl } // namespace cpu diff --git a/torch_ipex/csrc/utils.cpp b/torch_ipex/csrc/utils.cpp index 0477f3b8e..f30256165 100644 ---
a/torch_ipex/csrc/utils.cpp +++ b/torch_ipex/csrc/utils.cpp @@ -80,7 +80,9 @@ dil::data_type get_dil_data_type(at::ScalarType at_dt) { } else if (at_dt == at::ScalarType::QUInt8) { return dil::data_type::u8; } else { +#if defined(_DEBUG) TORCH_WARN("DNNL does not support current data type."); +#endif return dil::data_type::undef; } } @@ -109,7 +111,8 @@ bool check_tensor_own_whole_storage(const at::Tensor& tensor) { return false; return (tensor.storage_offset() == 0) && - (tensor.numel() == tensor.storage().numel()); + (tensor.numel() == tensor.storage().numel()) && + (tensor.itemsize() == tensor.storage().itemsize()); } bool check_tensor_own_shade_context(const at::Tensor& tensor) {