diff --git a/scripts/cpu/gen-dense-cpu-ops.py b/scripts/cpu/gen-dense-cpu-ops.py
index 7a8b7253b..9784e8910 100755
--- a/scripts/cpu/gen-dense-cpu-ops.py
+++ b/scripts/cpu/gen-dense-cpu-ops.py
@@ -306,8 +306,7 @@ def is_out_func(fname):
         if param_var == 'out' and is_out_func(fname):
           code += '    TORCH_INTERNAL_ASSERT({}.is_contiguous());\n'.format(param_var)
         else:
-          # param_seq_str = '{}.is_contiguous() ? {} : {}.contiguous()'.format(param_var, param_var, param_var)
-          None
+          param_seq_str = '{}.is_contiguous() ? {} : {}.contiguous()'.format(param_var, param_var, param_var)
         param_seq_str_vec.append(param_seq_str)
     code += '    if (dbl::chk::dnnl_support_the_tensors(dnnl_input_tensors))\n'
     code += '      return AtenIpexCPUDev::dil_{}({});\n'.format(fname, ', '.join(param_seq_str_vec))
@@ -315,6 +314,9 @@ def is_out_func(fname):
     code += '    }\n'
     code += '  } catch (std::exception& e) {\n'
+    code += '#if defined(_DEBUG)\n'
+    code += '    TORCH_WARN(e.what());\n'
+    code += '#endif\n'
     code += '  }\n\n'
diff --git a/tests/cpu/test_lazy_reorder.py b/tests/cpu/test_lazy_reorder.py
index 4943e537a..34a5429c9 100644
--- a/tests/cpu/test_lazy_reorder.py
+++ b/tests/cpu/test_lazy_reorder.py
@@ -368,12 +368,12 @@ def test_addbmm(self):
         addbmm_cpu = torch.addbmm(res_cpu, b1_cpu, b2_cpu, beta=beta, alpha=alpha)
         addbmm_dpcpp = torch.addbmm(res_dpcpp, b1_dpcpp, b2_dpcpp, beta=beta, alpha=alpha)
-        self.assertEqual(addbmm_cpu, addbmm_dpcpp)
+        self.assertEqual(addbmm_cpu, addbmm_dpcpp, 1e-4)
         y_cpu = torch.randn(M, O, dtype=torch.float32)
         y_dpcpp = y_cpu.to(device=device)
         torch.addbmm(res_cpu, b1_cpu, b2_cpu, beta=beta, alpha=alpha, out=y_cpu)
         torch.addbmm(res_dpcpp, b1_dpcpp, b2_dpcpp, beta=beta, alpha=alpha, out=y_dpcpp)
-        self.assertEqual(y_cpu, y_dpcpp)
+        self.assertEqual(y_cpu, y_dpcpp, 1e-4)

     def test_baddbmm(self):
         ipex.enable_auto_dnnl()
@@ -683,7 +683,6 @@ def test_batch_norm2d_backward(self):
         bn = torch.nn.BatchNorm2d(3)
         bn_dpcpp = copy.deepcopy(bn).to(device=device)
-
         y_cpu = bn(x_cpu).sum()
         y_dpcpp = bn_dpcpp(x_dpcpp).sum()
         y_cpu.backward()
@@ -756,17 +755,24 @@ def test_view(self):
         x_cpu = torch.randn(old_shape)
         x_dpcpp = x_cpu.to(device=device).clone()
-        print(x_dpcpp.size())
+        self.assertTrue(ipex.is_dil_tensor(x_dpcpp))
+        self.assertEqual(ipex.get_dil_tensor_sizes(x_dpcpp), [4, 16])
+        self.assertEqual(ipex.get_dil_tensor_strides(x_dpcpp), [16, 1])

         x_cpu_view = x_cpu.view(new_shape)
-        print(x_cpu_view.size())
+        self.assertEqual(x_cpu_view.size(), [1, 4, 4, 4])
+        self.assertEqual(x_cpu_view.stride(), [64, 16, 4, 1])
+
         x_dpcpp_view = x_dpcpp.view(new_shape)
-        print(x_dpcpp_view.size())
+        self.assertTrue(ipex.is_dil_tensor(x_dpcpp_view))

         y = torch.randn(new_shape)
         out_cpu = x_cpu_view * y
         # test if the shape of x_dpcpp_view is compatible with y
         out_dpcpp = x_dpcpp_view * y
+        self.assertTrue(ipex.is_dil_tensor(out_dpcpp))
+        self.assertEqual(ipex.get_dil_tensor_sizes(out_dpcpp), [1, 4, 4, 4])
+        self.assertEqual(ipex.get_dil_tensor_strides(out_dpcpp), [64, 16, 4, 1])
         self.assertEqual(out_cpu, out_dpcpp)

         # test if metadata of x_dpcpp has not been altered
diff --git a/torch_ipex/csrc/aten_ipex_bridge.cpp b/torch_ipex/csrc/aten_ipex_bridge.cpp
index c31b38968..62b187c98 100644
--- a/torch_ipex/csrc/aten_ipex_bridge.cpp
+++ b/torch_ipex/csrc/aten_ipex_bridge.cpp
@@ -11,6 +11,7 @@
 #include "ipex_tensor_impl.h"
 #include "ipex_sparse_tensor_impl.h"
+#include "cpu/dbl/Common.h"
 #include "cpu/ShadeDataContext.h"
 #include "cpu/bf16/Converter.h"
 #include "utils.h"
@@ -65,7 +66,6 @@ void reorderDilTensorToPublic(const at::Tensor& ipexTensor) {
   void *data_ctx = ipexTensor.unsafeGetTensorImpl()->storage().data_ptr().get_context();
   cpu::ShadeDataContext *shade_data_context = (cpu::ShadeDataContext*)data_ctx;
 #if defined(_DEBUG)
-  TORCH_WARN(ipexTensor.is_contiguous());
   TORCH_INTERNAL_ASSERT(! (shade_data_context->dil_tensor.is_empty()));
 #endif
   dil::tensor &dil_tensor = shade_data_context->dil_tensor;
@@ -101,7 +101,7 @@ void reorderDilTensorToPublic(const at::Tensor& ipexTensor) {
       ipexTensor.device().type());
     ipexTensor.unsafeGetTensorImpl()->storage().set_data_ptr(std::move(shade_data_ptr));
-    TORCH_INTERNAL_ASSERT(ipexTensor.is_contiguous());
+    cpu::dbl::comm::sync_shape_from_dil_to_aten(ipexTensor, pub_tensor);
   }
 }
diff --git a/torch_ipex/csrc/cpu/ShadeDataContext.h b/torch_ipex/csrc/cpu/ShadeDataContext.h
index 55bee10fa..ace9ce5dd 100644
--- a/torch_ipex/csrc/cpu/ShadeDataContext.h
+++ b/torch_ipex/csrc/cpu/ShadeDataContext.h
@@ -90,9 +90,6 @@ struct ShadeDataContext {
     TORCH_INTERNAL_ASSERT((data_type == SHADE_DATA_TYPE::CPU_RAW) || (data_type == SHADE_DATA_TYPE::DIL));

     if (data_type == SHADE_DATA_TYPE::DIL) {
-#if defined(_DEBUG)
-      TORCH_WARN(tensor.is_contiguous());
-#endif
       auto raw_cpu_data = tensor.storage().data_ptr().get();
       if (raw_cpu_data == nullptr) {
         // the dnnl tensor does not share data with raw tensor data.
@@ -113,15 +110,11 @@ struct ShadeDataContext {
         // C = A[4:7, :]
         // All these tensors share same buffer of Tensor A with different storge offsets and elements.
         // So the context modification will impact all these tensors.
-        if ((shade_data_context->dil_tensor.get_data_handle() == raw_cpu_data) &&
-            (shade_data_context->dil_tensor.get_nelems() == tensor.storage().numel()) &&
-            (shade_data_context->dil_tensor.get_data_type() == get_dil_data_type(tensor.scalar_type()))) {
-          //TODO: Do we need to check strides here?
+        if (check_tensor_own_whole_storage(tensor)) {
           TORCH_INTERNAL_ASSERT(shade_data_context->dil_tensor.get_size() == tensor.storage().capacity());
           return true;
         }
       }
-      TORCH_INTERNAL_ASSERT(false);
     }

     return false;
@@ -148,13 +141,9 @@ struct ShadeDataContext {
     TORCH_INTERNAL_ASSERT(tensor.has_storage());
     void *raw_context = tensor.storage().data_ptr().get_context();
     TORCH_INTERNAL_ASSERT(raw_context != nullptr);
-    if (isDilTensor(tensor)) {
-      ShadeDataContext *shade_data_context = (ShadeDataContext*)raw_context;
-      return shade_data_context->dil_tensor;
-    } else {
-      TORCH_INTERNAL_ASSERT(false);
-      return dil::tensor();
-    }
+    TORCH_INTERNAL_ASSERT(isDilTensor(tensor));
+    ShadeDataContext *shade_data_context = (ShadeDataContext*)raw_context;
+    return shade_data_context->dil_tensor;
   }

   /**
diff --git a/torch_ipex/csrc/cpu/dbl/Common.cpp b/torch_ipex/csrc/cpu/dbl/Common.cpp
index c27d941d8..bf338402d 100644
--- a/torch_ipex/csrc/cpu/dbl/Common.cpp
+++ b/torch_ipex/csrc/cpu/dbl/Common.cpp
@@ -37,7 +37,7 @@ at::Tensor dil_tensor_to_dense(const at::Tensor& tensor) {
 dil::tensor try_gen_dil_tensor(const at::Tensor &input) {
   if (cpu::ShadeDataContext::isDilTensor(input)) {
     auto dil_tensor = cpu::ShadeDataContext::getDilTensor(input);
-    if (dil_tensor.is_public_format()) {
+    if ((!check_aten_dil_shape_info(input, dil_tensor)) && dil_tensor.is_public_format()) {
       dil_tensor.set_dims_and_strides(input.sizes().vec(), input.strides().vec());
     }
     return dil_tensor;
diff --git a/torch_ipex/csrc/init_python_bindings.cpp b/torch_ipex/csrc/init_python_bindings.cpp
index dba2a398a..6cb5b7c2f 100644
--- a/torch_ipex/csrc/init_python_bindings.cpp
+++ b/torch_ipex/csrc/init_python_bindings.cpp
@@ -12,6 +12,8 @@
 #include "aten_ipex_type.h"
 #include "auto_opt_config.h"

+#include "cpu/dil/dil.hpp"
+#include "cpu/ShadeDataContext.h"
 #include "cpu/ExtendOPs.h"
 #include "cpu/MlpOPs.h"
@@ -29,6 +31,28 @@ void setAutoDNNL(bool val) {
   AutoOptConfig::singleton().set_auto_dnnl(val);
 }

+/// **** Only for unit test ****
+bool isDilTensor(const at::Tensor &tensor) {
+  return cpu::ShadeDataContext::isDilTensor(tensor);
+}
+
+dil::dims getDilTensorSizes(const at::Tensor &tensor) {
+  if (isDilTensor(tensor)) {
+    auto dil_tensor = cpu::ShadeDataContext::getDilTensor(tensor);
+    return dil_tensor.get_dims();
+  }
+  return dil::dims();
+}
+
+dil::dims getDilTensorStrides(const at::Tensor &tensor) {
+  if (isDilTensor(tensor)) {
+    auto dil_tensor = cpu::ShadeDataContext::getDilTensor(tensor);
+    return dil_tensor.get_strides();
+  }
+  return dil::dims();
+}
+/// ****************************
+
 void InitIpexModuleBindings(py::module m) {
   m.def("_initialize_aten_bindings",
         []() { AtenIpexType::InitializeAtenBindings(); });
@@ -97,6 +121,10 @@ void InitIpexModuleBindings(py::module m) {
   m.def("mlp_create_handle", &AtenIpexTypeMLPExt::create_handle);
   m.def("mlp_set_relu_mask", &AtenIpexTypeMLPExt::set_relu_mask);
   m.def("mlp_release_handle", &AtenIpexTypeMLPExt::release_handle);
+
+  m.def("is_dil_tensor", &isDilTensor);
+  m.def("get_dil_tensor_sizes", &getDilTensorSizes);
+  m.def("get_dil_tensor_strides", &getDilTensorStrides);
 }

 } // namespace
diff --git a/torch_ipex/csrc/utils.cpp b/torch_ipex/csrc/utils.cpp
index f30256165..245ee4523 100644
--- a/torch_ipex/csrc/utils.cpp
+++ b/torch_ipex/csrc/utils.cpp
@@ -127,4 +127,13 @@ bool check_tensor_own_shade_context(const at::Tensor& tensor) {
   return (data_ptr != data_ctx) && (data_ctx != nullptr);
 }

+bool check_aten_dil_shape_info(const at::Tensor& ipex_tensor, const dil::tensor &dil_tensor) {
+  if (dil_tensor.is_public_format()) {
+    return ipex_tensor.sizes().vec() == dil_tensor.get_dims() &&
+           ipex_tensor.strides().vec() == dil_tensor.get_strides();
+  } else {
+    return ipex_tensor.sizes().vec() == dil_tensor.get_dims();
+  }
+}
+
 } // namespace torch_ipex
diff --git a/torch_ipex/csrc/utils.h b/torch_ipex/csrc/utils.h
index 0e3ccfebb..36edb1b48 100644
--- a/torch_ipex/csrc/utils.h
+++ b/torch_ipex/csrc/utils.h
@@ -20,5 +20,6 @@ at::ScalarType get_at_data_type(dil::data_type);
 bool check_auto_dnnl();
 bool check_tensor_own_whole_storage(const at::Tensor& tensor);
 bool check_tensor_own_shade_context(const at::Tensor& tensor);
+bool check_aten_dil_shape_info(const at::Tensor& ipex_tensor, const dil::tensor &dil_tensor);

 } // namespace torch_ipex
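
Not part of the patch: a minimal usage sketch of the unit-test-only bindings registered in init_python_bindings.cpp (is_dil_tensor, get_dil_tensor_sizes, get_dil_tensor_strides), mirroring the assertions added to test_view above. The module import name and the device string are assumptions here; tests/cpu/test_lazy_reorder.py defines both for the real test run.

    import torch
    import intel_pytorch_extension as ipex  # assumption: the extension module imported as "ipex", as in the tests

    ipex.enable_auto_dnnl()                       # route supported ops through DNNL so DIL tensors are produced
    x = torch.randn(4, 16).to(device="dpcpp").clone()  # assumption: device string used by the test suite

    if ipex.is_dil_tensor(x):
        # DIL-side metadata that the aten tensor is expected to mirror after sync_shape_from_dil_to_aten
        print(ipex.get_dil_tensor_sizes(x))    # e.g. [4, 16]
        print(ipex.get_dil_tensor_strides(x))  # e.g. [16, 1]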