use optional dil tensor & move semantics #24

Merged: 2 commits, May 26, 2020
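This PR makes two orthogonal changes. First, ShadeDataContext's dil_tensor member becomes a c10::optional<dil::tensor>, so "no DNNL buffer attached" is represented explicitly rather than by a default-constructed empty tensor; member access accordingly switches from '.' to '->'. Second, temporary dil tensors are handed to dbl::comm::gen_aten_tensor_by with std::move to avoid copying the handle. A minimal, self-contained sketch of the combined pattern using std::optional (Buffer, Context, and consume are illustrative stand-ins, not the real IPEX types):

#include <cassert>
#include <optional>
#include <utility>

struct Buffer {                       // stand-in for dil::tensor
  bool is_public_format() const { return true; }
};

struct Context {                      // stand-in for cpu::ShadeDataContext
  std::optional<Buffer> buf;          // was: Buffer buf; "empty" was the sentinel
};

Buffer consume(Buffer&& b) {          // stand-in for gen_aten_tensor_by
  return std::move(b);
}

int main() {
  Context ctx;
  assert(!ctx.buf.has_value());       // absence is now explicit
  ctx.buf = Buffer{};
  if (ctx.buf->is_public_format()) {  // '->' instead of '.' after the change
    Buffer out = consume(std::move(*ctx.buf));  // move the temporary, no copy
    (void)out;
  }
}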
16 changes: 8 additions & 8 deletions torch_ipex/csrc/aten_ipex_bridge.cpp
@@ -66,13 +66,13 @@ void reorderDilTensorToPublic(const at::Tensor& ipexTensor) {
void *data_ctx = ipexTensor.unsafeGetTensorImpl()->storage().data_ptr().get_context();
cpu::ShadeDataContext *shade_data_context = (cpu::ShadeDataContext*)data_ctx;
#if defined(_DEBUG)
-TORCH_INTERNAL_ASSERT_DEBUG_ONLY(! (shade_data_context->dil_tensor.is_empty()));
+TORCH_INTERNAL_ASSERT_DEBUG_ONLY(! (shade_data_context->dil_tensor->is_empty()));
#endif
-dil::tensor &dil_tensor = shade_data_context->dil_tensor;
+dil::tensor &dil_tensor = *shade_data_context->dil_tensor;

if (dil_tensor.is_public_format()) {
#if defined(_DEBUG)
-TORCH_INTERNAL_ASSERT_DEBUG_ONLY(shade_data_context->cpu_raw_data == shade_data_context->dil_tensor.get_data_handle());
+TORCH_INTERNAL_ASSERT_DEBUG_ONLY(shade_data_context->cpu_raw_data == shade_data_context->dil_tensor->get_data_handle());
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(shade_data_context->cpu_raw_data != nullptr);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(shade_data_context->cpu_del_fun != nullptr);
#endif
@@ -106,7 +106,7 @@ void reorderDilTensorToPublic(const at::Tensor& ipexTensor) {
}


-void attachShadeDataConext(const at::Tensor& tensor) {
+void attachShadeDataContext(const at::Tensor& tensor) {
auto tensor_storage_impl = tensor.storage().unsafeGetStorageImpl();
auto& data_ptr = tensor_storage_impl->data_ptr();

@@ -272,7 +272,7 @@ at::Tensor shallowUpgradeToDPCPPTensor(const at::Tensor& cpuTensor) {
CHECK_TENSOR_CRITICAL(_tensor, cpuTensor, true);
//TODO: Cannot set reserved_
// dest_impl->reserved_ = src_impl->reserved_;
-attachShadeDataConext(_tensor);
+attachShadeDataContext(_tensor);
return _tensor;
}
}
@@ -303,7 +303,7 @@ at::Tensor shallowUpgradeToDPCPPTensorA(const at::Tensor& ipexTensor, const at::Tensor& cpuTensor) {
ipex_impl->copy_auto_grad(cpuTensor.unsafeGetTensorImpl());
CHECK_TENSOR_CRITICAL(_tensor, cpuTensor, true);

-attachShadeDataConext(_tensor);
+attachShadeDataContext(_tensor);
return _tensor;
}

@@ -388,7 +388,7 @@ const at::Tensor& shallowUpgradeToDPCPPTensorAW(const at::Tensor& ipexTensor, const at::Tensor& cpuTensor) {
ipex_tensor_impl->copy_meta_info(cpuTensor.unsafeGetTensorImpl());
ipex_tensor_impl->copy_auto_grad(cpuTensor.unsafeGetTensorImpl());
CHECK_TENSOR_CRITICAL(ipexTensor, cpuTensor, true);
-attachShadeDataConext(ipexTensor);
+attachShadeDataContext(ipexTensor);
return ipexTensor;
}
}
@@ -417,7 +417,7 @@ void reorderTensorToScalarTypeForDNNL(const at::Tensor& ipexTensor, at::ScalarType dstScalarType) {
// Shade data context has been attached
if (cpu::ShadeDataContext::isDilTensor(ipexTensor)) {
cpu::ShadeDataContext *shade_context = (cpu::ShadeDataContext*)(ipexTensor.storage().data_ptr().get_context());
-shade_context->dil_tensor.to_type(get_dil_data_type(dstScalarType));
+shade_context->dil_tensor->to_type(get_dil_data_type(dstScalarType));
IPEXTensorImpl* ipex_tensor_impl = (IPEXTensorImpl *)ipexTensor.unsafeGetTensorImpl();
ipex_tensor_impl->reset_data_type(dstScalarType);
ipex_tensor_impl->storage().unsafeGetStorageImpl()->set_dtype(at::scalarTypeToTypeMeta(dstScalarType));
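Most of the churn above is mechanical fallout of the optional member: reads go through operator-> and the one spot that binds a dil::tensor& dereferences the optional first. A tiny std::optional illustration of the two access forms (c10::optional mirrors both):

#include <optional>
#include <string>

int main() {
  std::optional<std::string> s{"abc"};
  bool empty = s->empty();   // '->' reaches into the contained value
  std::string& ref = *s;     // '*' dereferences to bind a reference
  ref += "def";              // mutates the value inside the optional
  (void)empty;
}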
2 changes: 1 addition & 1 deletion torch_ipex/csrc/aten_ipex_bridge.h
@@ -13,7 +13,7 @@ namespace bridge {
at::Tensor shallowFallbackToCPUTensor(const at::Tensor& ipexTensor);
std::vector<at::Tensor> shallowFallbackToCPUTensorList(const at::TensorList&);

-void attachShadeDataConext(const at::Tensor& tensor);
+void attachShadeDataContext(const at::Tensor& tensor);

/**
* Reorder the DNNL tensor to the public format if the input tensor contains a DNNL tensor.
68 changes: 34 additions & 34 deletions torch_ipex/csrc/cpu/DevOPs.cpp
@@ -64,7 +64,7 @@ at::Tensor AtenIpexCPUDev::dil_convolution(
dilation,
groups);

-return dbl::comm::gen_aten_tensor_by(dil_output);
+return dbl::comm::gen_aten_tensor_by(std::move(dil_output));
}

at::Tensor dil_convolution_backward_input(
@@ -87,7 +87,7 @@ at::Tensor dil_convolution_backward_input(
padding.vec(),
padding.vec(),
groups);
-return dbl::comm::gen_aten_tensor_by(dil_grad_input);
+return dbl::comm::gen_aten_tensor_by(std::move(dil_grad_input));
}

std::tuple<at::Tensor, at::Tensor> dil_convolution_backward_weights(
@@ -117,8 +117,8 @@ std::tuple<at::Tensor, at::Tensor> dil_convolution_backward_weights(
groups,
diff_weight_type);
return std::make_tuple(
-dbl::comm::gen_aten_tensor_by(dil_grad_weight),
-dbl::comm::gen_aten_tensor_by(dil_grad_bias));
+dbl::comm::gen_aten_tensor_by(std::move(dil_grad_weight)),
+dbl::comm::gen_aten_tensor_by(std::move(dil_grad_bias)));
} else {
dil::convolution_backward_weights::compute(
dil_input,
@@ -132,7 +132,7 @@ std::tuple<at::Tensor, at::Tensor> dil_convolution_backward_weights(
groups,
diff_weight_type);
return std::make_tuple(
-dbl::comm::gen_aten_tensor_by(dil_grad_weight),
+dbl::comm::gen_aten_tensor_by(std::move(dil_grad_weight)),
at::Tensor());
}
}
@@ -255,7 +255,7 @@ at::Tensor AtenIpexCPUDev::dil_add(const at::Tensor& self, const at::Tensor& other, at::Scalar alpha) {
const std::vector<float> scales{1.0, alpha.to<float>()};
dil::sum::compute(scales, {x, y}, z);

-return dbl::comm::gen_aten_tensor_by(z);
+return dbl::comm::gen_aten_tensor_by(std::move(z));
}

at::Tensor & AtenIpexCPUDev::dil_add_(at::Tensor& self, const at::Tensor& other, at::Scalar alpha) {
@@ -552,9 +552,9 @@ at::Tensor AtenIpexCPUDev::dil_linear(
output_size.push_back(weight.size(0));

if (self.dim() > 2) {
-return dbl::comm::gen_aten_tensor_by(y).reshape(output_size);
+return dbl::comm::gen_aten_tensor_by(std::move(y)).reshape(output_size);
}
-return dbl::comm::gen_aten_tensor_by(y);
+return dbl::comm::gen_aten_tensor_by(std::move(y));
}

at::Tensor dil_linear_backward_input(
@@ -574,9 +574,9 @@ at::Tensor dil_linear_backward_input(
grady, w, {input_reshaped_size.begin(), input_reshaped_size.end()}, gradx);

if (input_size.size() > 2) {
-return dbl::comm::gen_aten_tensor_by(gradx).reshape(input_size);
+return dbl::comm::gen_aten_tensor_by(std::move(gradx)).reshape(input_size);
}
-return dbl::comm::gen_aten_tensor_by(gradx);
+return dbl::comm::gen_aten_tensor_by(std::move(gradx));
}

std::tuple<at::Tensor, at::Tensor> dil_linear_backward_weights(
@@ -593,12 +593,12 @@ std::tuple<at::Tensor, at::Tensor> dil_linear_backward_weights(
if (bias_defined) {
dil::inner_product_backward_weights::compute(x, grady, gradw, gradb, diff_weight_type);
return std::tuple<at::Tensor, at::Tensor>{
-dbl::comm::gen_aten_tensor_by(gradw),
-dbl::comm::gen_aten_tensor_by(gradb)};
+dbl::comm::gen_aten_tensor_by(std::move(gradw)),
+dbl::comm::gen_aten_tensor_by(std::move(gradb))};
} else {
dil::inner_product_backward_weights::compute(x, grady, gradw, diff_weight_type);
return std::tuple<at::Tensor, at::Tensor>{
-dbl::comm::gen_aten_tensor_by(gradw),
+dbl::comm::gen_aten_tensor_by(std::move(gradw)),
at::Tensor()};
}
}
@@ -632,8 +632,8 @@ std::tuple<at::Tensor, at::Tensor> _dil_dropout(
dil::tensor y;
dil::dropout_forward::compute(x, ratio, y, mask);
return std::tuple<at::Tensor, at::Tensor>{
-dbl::comm::gen_aten_tensor_by(y),
-dbl::comm::gen_aten_tensor_by(mask)};
+dbl::comm::gen_aten_tensor_by(std::move(y)),
+dbl::comm::gen_aten_tensor_by(std::move(mask))};
}

at::Tensor AtenIpexCPUDev::dil_dropout(const at::Tensor& self, double ratio, bool train) {
@@ -657,7 +657,7 @@ at::Tensor AtenIpexCPUDev::dil_dropout_backward(

dil::tensor dX;
dil::dropout_backward::compute(mask_dil, dY, dX);
-return dbl::comm::gen_aten_tensor_by(dX);
+return dbl::comm::gen_aten_tensor_by(std::move(dX));
}

std::tuple<at::Tensor, at::Tensor, at::Tensor> AtenIpexCPUDev::dil_native_batch_norm(
@@ -696,9 +696,9 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> AtenIpexCPUDev::dil_native_batch_norm(
dil::sum::compute(scales_var, {v, saved_var}, v);
}
return std::make_tuple(
-dbl::comm::gen_aten_tensor_by(y),
-dbl::comm::gen_aten_tensor_by(saved_mean),
-dbl::comm::gen_aten_tensor_by(saved_var));
+dbl::comm::gen_aten_tensor_by(std::move(y)),
+dbl::comm::gen_aten_tensor_by(std::move(saved_mean)),
+dbl::comm::gen_aten_tensor_by(std::move(saved_var)));
} else {
if (use_running_stat) {
dil::tensor m = dbl::comm::try_gen_dil_tensor(running_mean);
@@ -710,7 +710,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> AtenIpexCPUDev::dil_native_batch_norm(
x, w, b, y, eps);
}
return std::make_tuple(
-dbl::comm::gen_aten_tensor_by(y),
+dbl::comm::gen_aten_tensor_by(std::move(y)),
at::Tensor(),
at::Tensor());
}
@@ -742,9 +742,9 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> AtenIpexCPUDev::dil_native_batch_
x, m, v, grady, w, gradx, gradw, gradb, eps);

return std::make_tuple(
-dbl::comm::gen_aten_tensor_by(gradx),
-dbl::comm::gen_aten_tensor_by(gradw),
-dbl::comm::gen_aten_tensor_by(gradb));
+dbl::comm::gen_aten_tensor_by(std::move(gradx)),
+dbl::comm::gen_aten_tensor_by(std::move(gradw)),
+dbl::comm::gen_aten_tensor_by(std::move(gradb)));
}

at::Tensor AtenIpexCPUDev::dil_max_pooling(
@@ -969,7 +969,7 @@ at::Tensor AtenIpexCPUDev::dil_relu(const at::Tensor& input) {
dil::tensor y;
dil::eltwise_forward::compute(
x, y, dil::algorithm::eltwise_relu, dil::prop_kind::forward_training, /*alpha*/ 0.0);
-return dbl::comm::gen_aten_tensor_by(y);
+return dbl::comm::gen_aten_tensor_by(std::move(y));
}

at::Tensor& AtenIpexCPUDev::dil_relu_(at::Tensor& input) {
@@ -998,7 +998,7 @@ at::Tensor AtenIpexCPUDev::dil_threshold_backward(const at::Tensor& grad_output,
dil::tensor gradx;
dil::eltwise_backward::compute(x, grady, gradx,
dil::algorithm::eltwise_relu, /*alpha*/ 0.0);
-return dbl::comm::gen_aten_tensor_by(gradx);
+return dbl::comm::gen_aten_tensor_by(std::move(gradx));
}

at::Tensor AtenIpexCPUDev::dil__softmax(
@@ -1014,7 +1014,7 @@ at::Tensor AtenIpexCPUDev::dil__softmax(
dil::tensor x = dbl::comm::try_gen_dil_tensor(self);
dil::tensor y;
dil::softmax_forward::compute(x, y, wrapped_dim);
-return dbl::comm::gen_aten_tensor_by(y);
+return dbl::comm::gen_aten_tensor_by(std::move(y));
}

at::Tensor AtenIpexCPUDev::dil__softmax_backward_data(
@@ -1032,7 +1032,7 @@ at::Tensor AtenIpexCPUDev::dil__softmax_backward_data(
dil::tensor grady = dbl::comm::try_gen_dil_tensor(grad_output_contiguous);
dil::tensor gradx;
dil::softmax_backward::compute(y, grady, gradx, wrapped_dim);
-return dbl::comm::gen_aten_tensor_by(gradx);
+return dbl::comm::gen_aten_tensor_by(std::move(gradx));
}

at::Tensor AtenIpexCPUDev::dil_sigmoid(const at::Tensor& self) {
@@ -1042,7 +1042,7 @@ at::Tensor AtenIpexCPUDev::dil_sigmoid(const at::Tensor& self) {
dil::tensor y;
dil::eltwise_forward::compute(
x, y, dil::algorithm::eltwise_logistic_use_dst_for_bwd, dil::prop_kind::forward);
-return dbl::comm::gen_aten_tensor_by(y);
+return dbl::comm::gen_aten_tensor_by(std::move(y));
}

at::Tensor& AtenIpexCPUDev::dil_sigmoid_(at::Tensor& self) {
@@ -1069,7 +1069,7 @@ at::Tensor AtenIpexCPUDev::dil_sigmoid_backward(
dil::tensor gx;
dil::eltwise_backward::compute(y, gy, gx,
dil::algorithm::eltwise_logistic_use_dst_for_bwd);
-return dbl::comm::gen_aten_tensor_by(gx);
+return dbl::comm::gen_aten_tensor_by(std::move(gx));
}

at::Tensor AtenIpexCPUDev::dil_reshape(const at::Tensor& self, at::IntArrayRef size) {
@@ -1082,7 +1082,7 @@ at::Tensor AtenIpexCPUDev::dil_reshape(const at::Tensor& self, at::IntArrayRef size) {
const dil::tensor x = dbl::comm::try_gen_dil_tensor(self);
dil::tensor y{x};
y.reshape(inferred_size);
-return dbl::comm::gen_aten_tensor_by(y);
+return dbl::comm::gen_aten_tensor_by(std::move(y));
}

at::Tensor AtenIpexCPUDev::dil_clone(const at::Tensor& self, c10::optional<c10::MemoryFormat> optional_memory_format) {
@@ -1095,7 +1095,7 @@ at::Tensor AtenIpexCPUDev::dil_clone(const at::Tensor& self, c10::optional<c10::MemoryFormat> optional_memory_format) {
dil::tensor src = dbl::comm::try_gen_dil_tensor(self);
dil::tensor dst;
dil::direct_copy::compute(src, dst);
-return dbl::comm::gen_aten_tensor_by(dst);
+return dbl::comm::gen_aten_tensor_by(std::move(dst));
}

at::Tensor AtenIpexCPUDev::dil_transpose(const at::Tensor & self, int64_t dim0, int64_t dim1) {
@@ -1110,7 +1110,7 @@ at::Tensor AtenIpexCPUDev::dil_transpose(const at::Tensor & self, int64_t dim0, int64_t dim1) {
dim1 = at::maybe_wrap_dim(dim1, self.dim());
std::swap(axes[dim0], axes[dim1]);
y.transpose_from(x, axes);
-return dbl::comm::gen_aten_tensor_by(y);
+return dbl::comm::gen_aten_tensor_by(std::move(y));
}

inline void check_cat_no_zero_dim(at::TensorList tensors) {
@@ -1154,7 +1154,7 @@ at::Tensor AtenIpexCPUDev::dil_cat(at::TensorList tensors, int64_t dim) {
}
dil::tensor y;
dil::concat::compute(x, dim, y);
-return dbl::comm::gen_aten_tensor_by(y);
+return dbl::comm::gen_aten_tensor_by(std::move(y));
}

std::vector<at::Tensor> AtenIpexCPUDev::dil_split_with_sizes(const at::Tensor& self, at::IntArrayRef split_sizes, int64_t dim) {
@@ -1175,7 +1175,7 @@ std::vector<at::Tensor> AtenIpexCPUDev::dil_split_with_sizes(const at::Tensor& self, at::IntArrayRef split_sizes, int64_t dim) {
dim = at::maybe_wrap_dim(dim, self.dim());
auto y = dil::spliter::compute(x, sizes, dim, false);
for (auto j = 0; j < num_splits; j++) {
-splits[j] = dbl::comm::gen_aten_tensor_by(y[j]);
+splits[j] = dbl::comm::gen_aten_tensor_by(std::move(y[j]));
}
return splits;
}
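Every rewritten call site in DevOPs.cpp hands gen_aten_tensor_by a local dil::tensor that dies at the end of the function, so the value can be moved instead of copied. Assuming dil::tensor is a cheap-to-move handle over shared storage, which is the shape of the real type, the cost difference looks like this (Handle and sink are hypothetical stand-ins, not the real signatures):

#include <memory>
#include <utility>

struct Handle {                               // stand-in for dil::tensor
  std::shared_ptr<int> storage = std::make_shared<int>(0);
};

Handle sink(Handle h) { return h; }           // by-value sink: caller picks copy or move

int main() {
  Handle y;
  Handle copied = sink(y);                    // copy: atomic refcount bump
  Handle moved  = sink(std::move(y));         // move: steals the pointer, no bump
  (void)copied; (void)moved;
}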
22 changes: 13 additions & 9 deletions torch_ipex/csrc/cpu/ShadeDataContext.h
@@ -2,6 +2,7 @@

#include <ATen/Tensor.h>
#include <c10/util/Exception.h>
+#include <c10/util/Optional.h>

#include "dil/dil.hpp"

@@ -13,9 +14,9 @@ namespace cpu {
enum SHADE_DATA_TYPE {CPU_RAW, DIL};

struct ShadeDataContext {
+c10::optional<dil::tensor> dil_tensor; ///< DNNL memory buffer for lazy reorder
void *cpu_raw_data; ///< The raw memory buffer of storage
c10::DeleterFnPtr cpu_del_fun; ///< Delete function to release cpu_raw_data
-dil::tensor dil_tensor; ///< DNNL memory buffer for lazy reorder

SHADE_DATA_TYPE data_type; ///< Memory buffer type

@@ -26,9 +27,10 @@ struct ShadeDataContext {

~ShadeDataContext() {
if (this->data_type == SHADE_DATA_TYPE::DIL) { // DIL Tensor
-if (this->dil_tensor.is_public_format()) {
+TORCH_INTERNAL_ASSERT_DEBUG_ONLY(this->dil_tensor.has_value());
+if (this->dil_tensor->is_public_format()) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(this->cpu_raw_data != nullptr);
-TORCH_INTERNAL_ASSERT_DEBUG_ONLY(this->dil_tensor.get_data_handle() == this->cpu_raw_data);
+TORCH_INTERNAL_ASSERT_DEBUG_ONLY(this->dil_tensor->get_data_handle() == this->cpu_raw_data);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(this->cpu_del_fun == &(c10::detail::deleteNothing));
} else {
// If dil tensor is block format, the cpu raw data means nothing here.
@@ -90,16 +92,17 @@ struct ShadeDataContext {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY((data_type == SHADE_DATA_TYPE::CPU_RAW) || (data_type == SHADE_DATA_TYPE::DIL));

if (data_type == SHADE_DATA_TYPE::DIL) {
+TORCH_INTERNAL_ASSERT_DEBUG_ONLY(shade_data_context->dil_tensor.has_value());
auto raw_cpu_data = tensor.storage().data_ptr().get();
if (raw_cpu_data == nullptr) {
// the dnnl tensor does not share data with raw tensor data.
-TORCH_INTERNAL_ASSERT_DEBUG_ONLY(! (shade_data_context->dil_tensor.is_empty()));
-TORCH_INTERNAL_ASSERT_DEBUG_ONLY(! (shade_data_context->dil_tensor.is_public_format()));
+TORCH_INTERNAL_ASSERT_DEBUG_ONLY(! (shade_data_context->dil_tensor->is_empty()));
+TORCH_INTERNAL_ASSERT_DEBUG_ONLY(! (shade_data_context->dil_tensor->is_public_format()));
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(check_tensor_own_whole_storage(tensor));
return true;
} else {
// The dnnl tensor shares some data with raw tensor.
-TORCH_INTERNAL_ASSERT_DEBUG_ONLY(shade_data_context->dil_tensor.is_public_format());
+TORCH_INTERNAL_ASSERT_DEBUG_ONLY(shade_data_context->dil_tensor->is_public_format());

// For the case:
// 1. There is a tensor named A
@@ -113,7 +116,7 @@ struct ShadeDataContext {
// All these tensors share the same buffer of Tensor A with different storage offsets and elements.
// So the context modification will impact all these tensors.
if (check_tensor_own_whole_storage(tensor)) {
-TORCH_INTERNAL_ASSERT_DEBUG_ONLY(shade_data_context->dil_tensor.get_size() == tensor.storage().capacity());
+TORCH_INTERNAL_ASSERT_DEBUG_ONLY(shade_data_context->dil_tensor->get_size() == tensor.storage().capacity());
return true;
}
}
@@ -139,13 +142,14 @@
* @return If the input tensor does not contain a DNNL buffer, the function returns
* an empty DNNL buffer. The caller should check whether the returned buffer is empty.
*/
-static inline dil::tensor getDilTensor(const at::Tensor &tensor) {
+static inline dil::tensor& getDilTensor(const at::Tensor &tensor) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(tensor.has_storage());
void *raw_context = tensor.storage().data_ptr().get_context();
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(raw_context != nullptr);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(isDilTensor(tensor));
ShadeDataContext *shade_data_context = (ShadeDataContext*)raw_context;
-return shade_data_context->dil_tensor;
+TORCH_INTERNAL_ASSERT_DEBUG_ONLY(shade_data_context->dil_tensor.has_value());
+return *(shade_data_context->dil_tensor);
}

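A related consequence of the optional member: getDilTensor now returns dil::tensor& instead of a copy, after asserting that the optional holds a value, so callers can mutate the stored tensor in place. A minimal std::optional analogue (Blob, Ctx, and getBlob are illustrative names, not the real types):

#include <cassert>
#include <optional>

struct Blob { int bytes = 0; };               // stand-in for dil::tensor

struct Ctx { std::optional<Blob> blob; };     // stand-in for ShadeDataContext

Blob& getBlob(Ctx& ctx) {
  assert(ctx.blob.has_value());               // debug-only assert in the real code
  return *ctx.blob;                           // reference into the optional, no copy
}

int main() {
  Ctx c;
  c.blob = Blob{64};
  getBlob(c).bytes = 128;                     // mutation lands on the stored value
  assert(c.blob->bytes == 128);
}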