diff --git a/torch_ipex/csrc/aten_ipex_bridge.cpp b/torch_ipex/csrc/aten_ipex_bridge.cpp
index 266955201..aa295f3bc 100644
--- a/torch_ipex/csrc/aten_ipex_bridge.cpp
+++ b/torch_ipex/csrc/aten_ipex_bridge.cpp
@@ -66,13 +66,13 @@ void reorderDilTensorToPublic(const at::Tensor& ipexTensor) {
   void *data_ctx = ipexTensor.unsafeGetTensorImpl()->storage().data_ptr().get_context();
   cpu::ShadeDataContext *shade_data_context = (cpu::ShadeDataContext*)data_ctx;
 #if defined(_DEBUG)
-  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(! (shade_data_context->dil_tensor.is_empty()));
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(! (shade_data_context->dil_tensor->is_empty()));
 #endif
-  dil::tensor &dil_tensor = shade_data_context->dil_tensor;
+  dil::tensor &dil_tensor = *shade_data_context->dil_tensor;

   if (dil_tensor.is_public_format()) {
 #if defined(_DEBUG)
-    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(shade_data_context->cpu_raw_data == shade_data_context->dil_tensor.get_data_handle());
+    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(shade_data_context->cpu_raw_data == shade_data_context->dil_tensor->get_data_handle());
     TORCH_INTERNAL_ASSERT_DEBUG_ONLY(shade_data_context->cpu_raw_data != nullptr);
     TORCH_INTERNAL_ASSERT_DEBUG_ONLY(shade_data_context->cpu_del_fun != nullptr);
 #endif
@@ -106,7 +106,7 @@ void reorderDilTensorToPublic(const at::Tensor& ipexTensor) {
 }

-void attachShadeDataConext(const at::Tensor& tensor) {
+void attachShadeDataContext(const at::Tensor& tensor) {
   auto tensor_storage_impl = tensor.storage().unsafeGetStorageImpl();
   auto& data_ptr = tensor_storage_impl->data_ptr();
@@ -272,7 +272,7 @@ at::Tensor shallowUpgradeToDPCPPTensor(const at::Tensor& cpuTensor) {
     CHECK_TENSOR_CRITICAL(_tensor, cpuTensor, true);
     //TODO: Cannot set reserved_
     //      dest_impl->reserved_ = src_impl->reserved_;
-    attachShadeDataConext(_tensor);
+    attachShadeDataContext(_tensor);
     return _tensor;
   }
 }
@@ -303,7 +303,7 @@ at::Tensor shallowUpgradeToDPCPPTensorA(const at::Tensor& ipexTensor, const at::
   ipex_impl->copy_auto_grad(cpuTensor.unsafeGetTensorImpl());
   CHECK_TENSOR_CRITICAL(_tensor, cpuTensor, true);
-  attachShadeDataConext(_tensor);
+  attachShadeDataContext(_tensor);
   return _tensor;
 }
@@ -388,7 +388,7 @@ const at::Tensor& shallowUpgradeToDPCPPTensorAW(const at::Tensor& ipexTensor, co
     ipex_tensor_impl->copy_meta_info(cpuTensor.unsafeGetTensorImpl());
     ipex_tensor_impl->copy_auto_grad(cpuTensor.unsafeGetTensorImpl());
     CHECK_TENSOR_CRITICAL(ipexTensor, cpuTensor, true);
-    attachShadeDataConext(ipexTensor);
+    attachShadeDataContext(ipexTensor);
     return ipexTensor;
   }
 }
@@ -417,7 +417,7 @@ void reorderTensorToScalarTypeForDNNL(const at::Tensor& ipexTensor, at::ScalarTy
   // Shade data context has been attached
   if (cpu::ShadeDataContext::isDilTensor(ipexTensor)) {
     cpu::ShadeDataContext *shade_context = (cpu::ShadeDataContext*)(ipexTensor.storage().data_ptr().get_context());
-    shade_context->dil_tensor.to_type(get_dil_data_type(dstScalarType));
+    shade_context->dil_tensor->to_type(get_dil_data_type(dstScalarType));
     IPEXTensorImpl* ipex_tensor_impl = (IPEXTensorImpl *)ipexTensor.unsafeGetTensorImpl();
     ipex_tensor_impl->reset_data_type(dstScalarType);
     ipex_tensor_impl->storage().unsafeGetStorageImpl()->set_dtype(at::scalarTypeToTypeMeta(dstScalarType));
diff --git a/torch_ipex/csrc/aten_ipex_bridge.h b/torch_ipex/csrc/aten_ipex_bridge.h
index bed667198..e7cf4f1d4 100644
--- a/torch_ipex/csrc/aten_ipex_bridge.h
+++ b/torch_ipex/csrc/aten_ipex_bridge.h
@@ -13,7 +13,7 @@ namespace bridge {
 at::Tensor shallowFallbackToCPUTensor(const at::Tensor& ipexTensor);
 std::vector<at::Tensor> shallowFallbackToCPUTensorList(const at::TensorList&);

-void attachShadeDataConext(const at::Tensor& tensor);
+void attachShadeDataContext(const at::Tensor& tensor);

 /**
  * Reorder the DNNL tensor to the public format if the input tensor contains DNNL tensor.
diff --git a/torch_ipex/csrc/cpu/DevOPs.cpp b/torch_ipex/csrc/cpu/DevOPs.cpp
index 6e3283b3b..b0e58d1fd 100644
--- a/torch_ipex/csrc/cpu/DevOPs.cpp
+++ b/torch_ipex/csrc/cpu/DevOPs.cpp
@@ -64,7 +64,7 @@ at::Tensor AtenIpexCPUDev::dil_convolution(
     dilation,
     groups);

-  return dbl::comm::gen_aten_tensor_by(dil_output);
+  return dbl::comm::gen_aten_tensor_by(std::move(dil_output));
 }

 at::Tensor dil_convolution_backward_input(
@@ -87,7 +87,7 @@ at::Tensor dil_convolution_backward_input(
       padding.vec(),
       padding.vec(),
       groups);
-  return dbl::comm::gen_aten_tensor_by(dil_grad_input);
+  return dbl::comm::gen_aten_tensor_by(std::move(dil_grad_input));
 }

 std::tuple<at::Tensor, at::Tensor> dil_convolution_backward_weights(
@@ -117,8 +117,8 @@ std::tuple<at::Tensor, at::Tensor> dil_convolution_backward_weights(
         groups,
         diff_weight_type);
     return std::make_tuple(
-        dbl::comm::gen_aten_tensor_by(dil_grad_weight),
-        dbl::comm::gen_aten_tensor_by(dil_grad_bias));
+        dbl::comm::gen_aten_tensor_by(std::move(dil_grad_weight)),
+        dbl::comm::gen_aten_tensor_by(std::move(dil_grad_bias)));
   } else {
     dil::convolution_backward_weights::compute(
         dil_input,
@@ -132,7 +132,7 @@ std::tuple<at::Tensor, at::Tensor> dil_convolution_backward_weights(
         groups,
         diff_weight_type);
     return std::make_tuple(
-        dbl::comm::gen_aten_tensor_by(dil_grad_weight),
+        dbl::comm::gen_aten_tensor_by(std::move(dil_grad_weight)),
         at::Tensor());
   }
 }
@@ -255,7 +255,7 @@ at::Tensor AtenIpexCPUDev::dil_add(const at::Tensor& self, const at::Tensor& oth
   const std::vector<float> scales{1.0, alpha.to<float>()};
   dil::sum::compute(scales, {x, y}, z);

-  return dbl::comm::gen_aten_tensor_by(z);
+  return dbl::comm::gen_aten_tensor_by(std::move(z));
 }

 at::Tensor & AtenIpexCPUDev::dil_add_(at::Tensor& self, const at::Tensor& other, at::Scalar alpha) {
@@ -552,9 +552,9 @@ at::Tensor AtenIpexCPUDev::dil_linear(
   output_size.push_back(weight.size(0));

   if (self.dim() > 2) {
-    return dbl::comm::gen_aten_tensor_by(y).reshape(output_size);
+    return dbl::comm::gen_aten_tensor_by(std::move(y)).reshape(output_size);
   }
-  return dbl::comm::gen_aten_tensor_by(y);
+  return dbl::comm::gen_aten_tensor_by(std::move(y));
 }

 at::Tensor dil_linear_backward_input(
@@ -574,9 +574,9 @@ at::Tensor dil_linear_backward_input(
     grady, w, {input_reshaped_size.begin(), input_reshaped_size.end()}, gradx);

   if (input_size.size() > 2) {
-    return dbl::comm::gen_aten_tensor_by(gradx).reshape(input_size);
+    return dbl::comm::gen_aten_tensor_by(std::move(gradx)).reshape(input_size);
   }
-  return dbl::comm::gen_aten_tensor_by(gradx);
+  return dbl::comm::gen_aten_tensor_by(std::move(gradx));
 }

 std::tuple<at::Tensor, at::Tensor> dil_linear_backward_weights(
@@ -593,12 +593,12 @@ std::tuple<at::Tensor, at::Tensor> dil_linear_backward_weights(
   if (bias_defined) {
     dil::inner_product_backward_weights::compute(x, grady, gradw, gradb, diff_weight_type);
     return std::tuple<at::Tensor, at::Tensor>{
-        dbl::comm::gen_aten_tensor_by(gradw),
-        dbl::comm::gen_aten_tensor_by(gradb)};
+        dbl::comm::gen_aten_tensor_by(std::move(gradw)),
+        dbl::comm::gen_aten_tensor_by(std::move(gradb))};
   } else {
     dil::inner_product_backward_weights::compute(x, grady, gradw, diff_weight_type);
     return std::tuple<at::Tensor, at::Tensor>{
-        dbl::comm::gen_aten_tensor_by(gradw),
+        dbl::comm::gen_aten_tensor_by(std::move(gradw)),
         at::Tensor()};
   }
 }
@@ -632,8 +632,8 @@ std::tuple<at::Tensor, at::Tensor> _dil_dropout(
   dil::tensor y;
   dil::dropout_forward::compute(x, ratio, y, mask);
   return std::tuple<at::Tensor, at::Tensor>{
-      dbl::comm::gen_aten_tensor_by(y),
-      dbl::comm::gen_aten_tensor_by(mask)};
+      dbl::comm::gen_aten_tensor_by(std::move(y)),
+      dbl::comm::gen_aten_tensor_by(std::move(mask))};
 }

 at::Tensor AtenIpexCPUDev::dil_dropout(const at::Tensor& self, double ratio, bool train) {
@@ -657,7 +657,7 @@ at::Tensor AtenIpexCPUDev::dil_dropout_backward(
   dil::tensor dX;
   dil::dropout_backward::compute(mask_dil, dY, dX);
-  return dbl::comm::gen_aten_tensor_by(dX);
+  return dbl::comm::gen_aten_tensor_by(std::move(dX));
 }

 std::tuple<at::Tensor, at::Tensor, at::Tensor> AtenIpexCPUDev::dil_native_batch_norm(
@@ -696,9 +696,9 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> AtenIpexCPUDev::dil_native_batch_
       dil::sum::compute(scales_var, {v, saved_var}, v);
     }
     return std::make_tuple(
-        dbl::comm::gen_aten_tensor_by(y),
-        dbl::comm::gen_aten_tensor_by(saved_mean),
-        dbl::comm::gen_aten_tensor_by(saved_var));
+        dbl::comm::gen_aten_tensor_by(std::move(y)),
+        dbl::comm::gen_aten_tensor_by(std::move(saved_mean)),
+        dbl::comm::gen_aten_tensor_by(std::move(saved_var)));
   } else {
     if (use_running_stat) {
       dil::tensor m = dbl::comm::try_gen_dil_tensor(running_mean);
@@ -710,7 +710,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> AtenIpexCPUDev::dil_native_batch_
           x, w, b, y, eps);
     }
     return std::make_tuple(
-        dbl::comm::gen_aten_tensor_by(y),
+        dbl::comm::gen_aten_tensor_by(std::move(y)),
         at::Tensor(),
         at::Tensor());
   }
@@ -742,9 +742,9 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> AtenIpexCPUDev::dil_native_batch_
       x, m, v, grady, w, gradx, gradw, gradb, eps);

   return std::make_tuple(
-      dbl::comm::gen_aten_tensor_by(gradx),
-      dbl::comm::gen_aten_tensor_by(gradw),
-      dbl::comm::gen_aten_tensor_by(gradb));
+      dbl::comm::gen_aten_tensor_by(std::move(gradx)),
+      dbl::comm::gen_aten_tensor_by(std::move(gradw)),
+      dbl::comm::gen_aten_tensor_by(std::move(gradb)));
 }

 at::Tensor AtenIpexCPUDev::dil_max_pooling(
@@ -969,7 +969,7 @@ at::Tensor AtenIpexCPUDev::dil_relu(const at::Tensor& input) {
   dil::tensor y;
   dil::eltwise_forward::compute(
       x, y, dil::algorithm::eltwise_relu, dil::prop_kind::forward_training, /*alpha*/ 0.0);
-  return dbl::comm::gen_aten_tensor_by(y);
+  return dbl::comm::gen_aten_tensor_by(std::move(y));
 }

 at::Tensor& AtenIpexCPUDev::dil_relu_(at::Tensor& input) {
@@ -998,7 +998,7 @@ at::Tensor AtenIpexCPUDev::dil_threshold_backward(const at::Tensor& grad_output,
   dil::tensor gradx;
   dil::eltwise_backward::compute(x, grady, gradx,
       dil::algorithm::eltwise_relu, /*alpha*/ 0.0);
-  return dbl::comm::gen_aten_tensor_by(gradx);
+  return dbl::comm::gen_aten_tensor_by(std::move(gradx));
 }

 at::Tensor AtenIpexCPUDev::dil__softmax(
@@ -1014,7 +1014,7 @@ at::Tensor AtenIpexCPUDev::dil__softmax(
   dil::tensor x = dbl::comm::try_gen_dil_tensor(self);
   dil::tensor y;
   dil::softmax_forward::compute(x, y, wrapped_dim);
-  return dbl::comm::gen_aten_tensor_by(y);
+  return dbl::comm::gen_aten_tensor_by(std::move(y));
 }

 at::Tensor AtenIpexCPUDev::dil__softmax_backward_data(
@@ -1032,7 +1032,7 @@ at::Tensor AtenIpexCPUDev::dil__softmax_backward_data(
   dil::tensor grady = dbl::comm::try_gen_dil_tensor(grad_output_contiguous);
   dil::tensor gradx;
   dil::softmax_backward::compute(y, grady, gradx, wrapped_dim);
-  return dbl::comm::gen_aten_tensor_by(gradx);
+  return dbl::comm::gen_aten_tensor_by(std::move(gradx));
 }

 at::Tensor AtenIpexCPUDev::dil_sigmoid(const at::Tensor& self) {
@@ -1042,7 +1042,7 @@ at::Tensor AtenIpexCPUDev::dil_sigmoid(const at::Tensor& self) {
   dil::tensor y;
   dil::eltwise_forward::compute(
       x, y, dil::algorithm::eltwise_logistic_use_dst_for_bwd, dil::prop_kind::forward);
-  return dbl::comm::gen_aten_tensor_by(y);
+  return dbl::comm::gen_aten_tensor_by(std::move(y));
 }

 at::Tensor& AtenIpexCPUDev::dil_sigmoid_(at::Tensor& self) {
@@ -1069,7 +1069,7 @@ at::Tensor AtenIpexCPUDev::dil_sigmoid_backward(
   dil::tensor gx;
   dil::eltwise_backward::compute(y, gy, gx,
       dil::algorithm::eltwise_logistic_use_dst_for_bwd);
-  return dbl::comm::gen_aten_tensor_by(gx);
+  return dbl::comm::gen_aten_tensor_by(std::move(gx));
 }

 at::Tensor AtenIpexCPUDev::dil_reshape(const at::Tensor& self, at::IntArrayRef size) {
@@ -1082,7 +1082,7 @@ at::Tensor AtenIpexCPUDev::dil_reshape(const at::Tensor& self, at::IntArrayRef s
   const dil::tensor x = dbl::comm::try_gen_dil_tensor(self);
   dil::tensor y{x};
   y.reshape(inferred_size);
-  return dbl::comm::gen_aten_tensor_by(y);
+  return dbl::comm::gen_aten_tensor_by(std::move(y));
 }

 at::Tensor AtenIpexCPUDev::dil_clone(const at::Tensor& self, c10::optional<c10::MemoryFormat> optional_memory_format) {
@@ -1095,7 +1095,7 @@ at::Tensor AtenIpexCPUDev::dil_clone(const at::Tensor& self, c10::optional<c10::
   dil::tensor src = dbl::comm::try_gen_dil_tensor(self);
   dil::tensor dst;
   dil::direct_copy::compute(src, dst);
-  return dbl::comm::gen_aten_tensor_by(dst);
+  return dbl::comm::gen_aten_tensor_by(std::move(dst));
 }

 std::vector<at::Tensor> AtenIpexCPUDev::dil_split_with_sizes(const at::Tensor& self, at::IntArrayRef split_sizes, int64_t dim) {
@@ -1175,7 +1175,7 @@ std::vector<at::Tensor> AtenIpexCPUDev::dil_split_with_sizes(const at::Tensor& s
   dim = at::maybe_wrap_dim(dim, self.dim());
   auto y = dil::spliter::compute(x, sizes, dim, false);
   for (auto j = 0; j < num_splits; j++) {
-    splits[j] = dbl::comm::gen_aten_tensor_by(y[j]);
+    splits[j] = dbl::comm::gen_aten_tensor_by(std::move(y[j]));
   }
   return splits;
 }
diff --git a/torch_ipex/csrc/cpu/ShadeDataContext.h b/torch_ipex/csrc/cpu/ShadeDataContext.h
index 9e323c3a6..b2d2fbc45 100644
--- a/torch_ipex/csrc/cpu/ShadeDataContext.h
+++ b/torch_ipex/csrc/cpu/ShadeDataContext.h
@@ -2,6 +2,7 @@

 #include
 #include
+#include <c10/util/Optional.h>

 #include "dil/dil.hpp"
@@ -13,9 +14,9 @@ namespace cpu {
 enum SHADE_DATA_TYPE {CPU_RAW, DIL};

 struct ShadeDataContext {
+  c10::optional<dil::tensor> dil_tensor; ///< DNNL memory buffer for lazy reorder
   void *cpu_raw_data; ///< The raw memory buffer of storage
   c10::DeleterFnPtr cpu_del_fun; ///< Delete function to release cpu_raw_data
-  dil::tensor dil_tensor; ///< DNNL memory buffer for lazy reorder

   SHADE_DATA_TYPE data_type; ///< Memory buffer type
@@ -26,9 +27,10 @@ struct ShadeDataContext {
   ~ShadeDataContext() {
     if (this->data_type == SHADE_DATA_TYPE::DIL) { // DIL Tensor
-      if (this->dil_tensor.is_public_format()) {
+      TORCH_INTERNAL_ASSERT_DEBUG_ONLY(this->dil_tensor.has_value());
+      if (this->dil_tensor->is_public_format()) {
         TORCH_INTERNAL_ASSERT_DEBUG_ONLY(this->cpu_raw_data != nullptr);
-        TORCH_INTERNAL_ASSERT_DEBUG_ONLY(this->dil_tensor.get_data_handle() == this->cpu_raw_data);
+        TORCH_INTERNAL_ASSERT_DEBUG_ONLY(this->dil_tensor->get_data_handle() == this->cpu_raw_data);
         TORCH_INTERNAL_ASSERT_DEBUG_ONLY(this->cpu_del_fun == &(c10::detail::deleteNothing));
       } else {
         // If dil tensor is block format, the cpu raw data means nothing here.
@@ -90,16 +92,17 @@ struct ShadeDataContext {
     TORCH_INTERNAL_ASSERT_DEBUG_ONLY((data_type == SHADE_DATA_TYPE::CPU_RAW) || (data_type == SHADE_DATA_TYPE::DIL));

     if (data_type == SHADE_DATA_TYPE::DIL) {
+      TORCH_INTERNAL_ASSERT_DEBUG_ONLY(shade_data_context->dil_tensor.has_value());
       auto raw_cpu_data = tensor.storage().data_ptr().get();
       if (raw_cpu_data == nullptr) {
         // the dnnl tensor does not share data with raw tensor data.
-        TORCH_INTERNAL_ASSERT_DEBUG_ONLY(! (shade_data_context->dil_tensor.is_empty()));
-        TORCH_INTERNAL_ASSERT_DEBUG_ONLY(! (shade_data_context->dil_tensor.is_public_format()));
+        TORCH_INTERNAL_ASSERT_DEBUG_ONLY(! (shade_data_context->dil_tensor->is_empty()));
+        TORCH_INTERNAL_ASSERT_DEBUG_ONLY(! (shade_data_context->dil_tensor->is_public_format()));
         TORCH_INTERNAL_ASSERT_DEBUG_ONLY(check_tensor_own_whole_storage(tensor));
         return true;
       } else {
         // The dnnl tensor shares some data with raw tensor.
-        TORCH_INTERNAL_ASSERT_DEBUG_ONLY(shade_data_context->dil_tensor.is_public_format());
+        TORCH_INTERNAL_ASSERT_DEBUG_ONLY(shade_data_context->dil_tensor->is_public_format());

         // For the case:
         //   1. There is a tensor named A
@@ -113,7 +116,7 @@ struct ShadeDataContext {
         // All these tensors share same buffer of Tensor A with different storge offsets and elements.
         // So the context modification will impact all these tensors.
         if (check_tensor_own_whole_storage(tensor)) {
-          TORCH_INTERNAL_ASSERT_DEBUG_ONLY(shade_data_context->dil_tensor.get_size() == tensor.storage().capacity());
+          TORCH_INTERNAL_ASSERT_DEBUG_ONLY(shade_data_context->dil_tensor->get_size() == tensor.storage().capacity());
           return true;
         }
       }
@@ -139,13 +142,14 @@ struct ShadeDataContext {
    * @return If the input tensor does not contain DNNL buffer, the function will return
    * an empty DNNL buffer. The caller should check the return buffer is empty or not.
    */
-  static inline dil::tensor getDilTensor(const at::Tensor &tensor) {
+  static inline dil::tensor& getDilTensor(const at::Tensor &tensor) {
     TORCH_INTERNAL_ASSERT_DEBUG_ONLY(tensor.has_storage());
     void *raw_context = tensor.storage().data_ptr().get_context();
     TORCH_INTERNAL_ASSERT_DEBUG_ONLY(raw_context != nullptr);
     TORCH_INTERNAL_ASSERT_DEBUG_ONLY(isDilTensor(tensor));
     ShadeDataContext *shade_data_context = (ShadeDataContext*)raw_context;
-    return shade_data_context->dil_tensor;
+    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(shade_data_context->dil_tensor.has_value());
+    return *(shade_data_context->dil_tensor);
   }

   /**
diff --git a/torch_ipex/csrc/cpu/dbl/Common.cpp b/torch_ipex/csrc/cpu/dbl/Common.cpp
index ae8f95d3a..07c55fbd7 100644
--- a/torch_ipex/csrc/cpu/dbl/Common.cpp
+++ b/torch_ipex/csrc/cpu/dbl/Common.cpp
@@ -48,10 +48,10 @@ dil::tensor try_gen_dil_tensor(const at::Tensor &input) {
   }
 }

-at::Tensor gen_aten_tensor_by(dil::tensor dil_tensor) {
+at::Tensor gen_aten_tensor_by(dil::tensor&& dil_tensor) {
   // Generate new CPU Tensor and store dil tensor at its storage
   cpu::ShadeDataContext *shade_data_context = cpu::ShadeDataContext::allocShadeDataContext();
-  shade_data_context->dil_tensor = dil_tensor;
+  shade_data_context->dil_tensor = std::forward<dil::tensor>(dil_tensor);
   shade_data_context->data_type = cpu::SHADE_DATA_TYPE::DIL;
   void *tensor_data = nullptr;
   if (dil_tensor.is_public_format()) {
@@ -84,7 +84,7 @@ at::Tensor empty_dil_tensor(at::IntArrayRef sizes, const at::TensorOptions& opti
     "'memory_format' argument is incompatible with mkldnn tensor");*/
   auto data_type = get_dil_data_type(at::typeMetaToScalarType(options.dtype()));
   dil::tensor it {sizes.vec(), data_type};
-  return gen_aten_tensor_by(it);
+  return gen_aten_tensor_by(std::move(it));
 }

 void sync_shape_from_dil_to_aten(const at::Tensor& ipex_tensor, const dil::tensor &dil_tensor) {
diff --git a/torch_ipex/csrc/cpu/dbl/Common.h b/torch_ipex/csrc/cpu/dbl/Common.h
index 15a47fffb..8a346da47 100644
--- a/torch_ipex/csrc/cpu/dbl/Common.h
+++ b/torch_ipex/csrc/cpu/dbl/Common.h
@@ -12,7 +12,7 @@ namespace comm {
 dil::tensor dil_tensor_from_dense(const at::Tensor& tensor);
 at::Tensor dil_tensor_to_dense(const at::Tensor& tensor);
 dil::tensor try_gen_dil_tensor(const at::Tensor &input);
-at::Tensor gen_aten_tensor_by(dil::tensor tensor);
+at::Tensor gen_aten_tensor_by(dil::tensor&& tensor);
 at::Tensor empty_dil_tensor(at::IntArrayRef sizes, const at::TensorOptions& options);

 void sync_shape_from_dil_to_aten(const at::Tensor& ipex_tensor, const dil::tensor &dil_tensor);
diff --git a/torch_ipex/csrc/cpu/dbl/Pool.cpp b/torch_ipex/csrc/cpu/dbl/Pool.cpp
index e2ea9d4de..b303115f1 100644
--- a/torch_ipex/csrc/cpu/dbl/Pool.cpp
+++ b/torch_ipex/csrc/cpu/dbl/Pool.cpp
@@ -151,7 +151,7 @@ at::Tensor _dil_pooling(
       algo,
       dil::prop_kind::forward);

-  return dbl::comm::gen_aten_tensor_by(y);
+  return dbl::comm::gen_aten_tensor_by(std::move(y));
 }

 at::Tensor _dil_pooling_backward(
@@ -223,7 +223,7 @@ at::Tensor _dil_pooling_backward(
       {padding_vec_r.cbegin(), padding_vec_r.cend()},
       algo);

-  return dbl::comm::gen_aten_tensor_by(gradx);
+  return dbl::comm::gen_aten_tensor_by(std::move(gradx));
 }

 }  // namespace pool
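
The patch makes two recurring changes. First, ShadeDataContext::dil_tensor becomes c10::optional<dil::tensor>, so a context that only wraps raw CPU memory (SHADE_DATA_TYPE::CPU_RAW) never has to hold a default-constructed DNNL tensor, and every DIL-path access is now guarded by a has_value() debug assert and goes through operator-> or operator*. A minimal standalone sketch of that idiom, with hypothetical names and std::optional standing in for c10::optional:

#include <optional>

// Hypothetical stand-in for dil::tensor; constructing one may allocate or
// register DNNL resources that a plain CPU_RAW context never needs.
struct DilTensor {
  bool is_public_format() const { return true; }
};

struct ShadeDataContextSketch {
  // Before: `DilTensor dil_tensor;` was always constructed, even when unused.
  // optional makes the "no DNNL buffer yet" state explicit and lazy.
  std::optional<DilTensor> dil_tensor;
};

int main() {
  ShadeDataContextSketch ctx;            // no DilTensor constructed yet
  if (!ctx.dil_tensor.has_value()) {     // the state the new asserts check
    ctx.dil_tensor.emplace();            // construct only when needed
  }
  return ctx.dil_tensor->is_public_format() ? 0 : 1;  // operator-> access
}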
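Second, gen_aten_tensor_by changes from pass-by-value to an rvalue-reference sink, and every call site wraps its argument in std::move. Because the parameter is a plain rvalue reference rather than a forwarding reference, the std::forward<dil::tensor>(dil_tensor) in Common.cpp behaves the same as std::move(dil_tensor): the dil::tensor is moved, not copied, into the ShadeDataContext. A sketch of the pattern with a hypothetical Buffer type that logs deep copies:

#include <cstddef>
#include <iostream>
#include <utility>
#include <vector>

// Hypothetical heavy buffer standing in for dil::tensor.
struct Buffer {
  std::vector<float> data;
  explicit Buffer(std::size_t n) : data(n) {}
  Buffer(const Buffer& other) : data(other.data) { std::cout << "deep copy\n"; }
  Buffer(Buffer&& other) noexcept : data(std::move(other.data)) {}
};

// Rvalue-reference sink: callers must hand over ownership with std::move,
// so the buffer is moved into the result, never copied.
Buffer wrap(Buffer&& b) {
  return std::move(b);
}

int main() {
  Buffer local(1u << 20);
  Buffer wrapped = wrap(std::move(local));  // prints nothing: no deep copy
  // wrap(local) would not compile: an lvalue cannot bind to Buffer&&.
  return wrapped.data.empty() ? 1 : 0;
}

With the old pass-by-value signature a caller could silently deep-copy the tensor at the call site; the && signature turns that into a compile error and makes the ownership transfer explicit, which is what the std::move churn across DevOPs.cpp and Pool.cpp accomplishes.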