Add data type reorder for DNNL OP #1

Merged 3 commits on May 11, 2020
57 changes: 42 additions & 15 deletions torch_ipex/csrc/aten_ipex_bridge.cpp
@@ -50,7 +50,7 @@ namespace bridge {

at::Tensor shallowFallbackToCPUTensorImpl(const at::Tensor& ipexTensor);

void reorderDilTensor(const at::Tensor& ipexTensor) {
void reorderDilTensorToPublic(const at::Tensor& ipexTensor) {
void *data_ctx = ipexTensor.unsafeGetTensorImpl()->storage().data_ptr().get_context();
cpu::ShadeDataContext *shade_data_context = (cpu::ShadeDataContext*)data_ctx;
// All aten::tensor with dnnl::tensor should be contiguous
@@ -89,12 +89,11 @@ void reorderDilTensor(const at::Tensor& ipexTensor) {
void attachShadeDataConext(const at::Tensor& tensor) {
auto tensor_storage_impl = tensor.storage().unsafeGetStorageImpl();
auto& data_ptr = tensor_storage_impl->data_ptr();
// [NOTE]: We assume the real data of storage should be as same as its context.
// Then we use the assumption to check if current tensor has contained
// shade data context.
if (data_ptr.get() != data_ptr.get_context()) {

// Shade data context has already been attached
if (check_tensor_own_shade_context(tensor))
return;
}

auto cur_del_fn = data_ptr.get_deleter();
bool res = data_ptr.compare_exchange_deleter(cur_del_fn, &(c10::detail::deleteNothing));
TORCH_INTERNAL_ASSERT(res);
@@ -189,7 +188,7 @@ at::Tensor shallowFallbackToCPUTensor(const at::Tensor& ipexTensor) {
cpu::ShadeDataContext *shade_data_context = (cpu::ShadeDataContext*)data_ctx;
// Branch 2.1: Dense + Dil Tensor
if (cpu::ShadeDataContext::isDilTensor(ipexTensor)) {
reorderDilTensor(ipexTensor);
reorderDilTensorToPublic(ipexTensor);
}

// Branch 2.2: Dense + CPU Tensor
@@ -496,24 +495,51 @@ std::vector<at::Tensor> shallowFallbackToCPUTensorList(const at::TensorList& ten
return dpcpp_tensor_vec;
}

void cvtTensorToScalaraType(const at::Tensor& ipexTensor, at::ScalarType dstScalarType) {

void reorderTensorToScalarTypeForDNNL(const at::Tensor& ipexTensor, at::ScalarType dstScalarType) {
TORCH_CHECK(dstScalarType == at::kBFloat16 || dstScalarType == at::kFloat);
auto tensor_dtype = ipexTensor.scalar_type();
TORCH_CHECK(tensor_dtype == at::kBFloat16 || tensor_dtype == at::kFloat);
if (tensor_dtype == dstScalarType)
return;

if (check_tensor_own_shade_context(ipexTensor)) {
// Shade data context has been attached
if (cpu::ShadeDataContext::isDilTensor(ipexTensor)) {
cpu::ShadeDataContext *shade_context = (cpu::ShadeDataContext*)(ipexTensor.storage().data_ptr().get_context());
shade_context->dil_tensor.to_type(get_dil_data_type(dstScalarType));
IPEXTensorImpl* ipex_tensor_impl = (IPEXTensorImpl *)ipexTensor.unsafeGetTensorImpl();
ipex_tensor_impl->reset_data_type(dstScalarType);
ipex_tensor_impl->storage().unsafeGetStorageImpl()->set_dtype(at::scalarTypeToTypeMeta(dstScalarType));
return;
}
}

return reorderTensorToScalaraType(ipexTensor, dstScalarType);
}


void reorderTensorToScalaraType(const at::Tensor& ipexTensor, at::ScalarType dstScalarType) {
if (!(ipexTensor.defined()))
return;

TORCH_CHECK(dstScalarType == at::kBFloat16 || dstScalarType == at::kFloat);
if (ipexTensor.scalar_type() == dstScalarType)

auto tensor_dtype = ipexTensor.scalar_type();
TORCH_CHECK(tensor_dtype == at::kBFloat16 || tensor_dtype == at::kFloat);
if (tensor_dtype == dstScalarType)
return;

if (check_data_is_part_of_storage(ipexTensor))
if (!check_tensor_own_whole_storage(ipexTensor)) {
return;
} else {
TORCH_INTERNAL_ASSERT(false);
}

void *data_ptr = ipexTensor.unsafeGetTensorImpl()->storage().data_ptr().get();
void *data_ctx = ipexTensor.unsafeGetTensorImpl()->storage().data_ptr().get_context();
if ((data_ptr != data_ctx) && (data_ctx != nullptr)) {
if (check_tensor_own_shade_context(ipexTensor)) {
// Shade data context has been attached
cpu::ShadeDataContext *shade_data_context = (cpu::ShadeDataContext*)data_ctx;
if (cpu::ShadeDataContext::isDilTensor(ipexTensor)) {
reorderDilTensor(ipexTensor);
reorderDilTensorToPublic(ipexTensor);
}
}

@@ -528,6 +554,7 @@ void cvtTensorToScalaraType(const at::Tensor& ipexTensor, at::ScalarType dstScal
allocator,
/*resizeable=*/true);

void *data_ptr = ipexTensor.unsafeGetTensorImpl()->storage().data_ptr().get();
if (dstScalarType == at::kBFloat16) {
torch_ipex::cpu::bf16::converter::fp32_to_bf16(storage_impl->data_ptr().get(), data_ptr, nelements);
} else {
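The branch shown above converts the buffer with torch_ipex::cpu::bf16::converter::fp32_to_bf16 (the opposite branch is collapsed in this view). As a rough, hypothetical illustration of what such an element-wise conversion involves (BF16 keeps the upper 16 bits of an IEEE-754 float), a minimal scalar sketch might look like this; it is not the repository's converter:

#include <cstdint>
#include <cstring>

// Hypothetical scalar fp32 -> bf16 conversion, shown only to illustrate the idea.
// NaN handling is omitted; real converters special-case it.
static inline uint16_t fp32_to_bf16_scalar(float value) {
  uint32_t bits;
  std::memcpy(&bits, &value, sizeof(bits));               // bit-copy the float
  uint32_t rounding_bias = 0x7FFFu + ((bits >> 16) & 1u); // round to nearest even
  return static_cast<uint16_t>((bits + rounding_bias) >> 16);
}
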
34 changes: 33 additions & 1 deletion torch_ipex/csrc/aten_ipex_bridge.h
@@ -16,7 +16,39 @@ std::vector<at::Tensor> fallbackToCPUTensorList(const at::TensorList&);
std::vector<at::Tensor> shallowFallbackToCPUTensorList(const at::TensorList&);

void attachShadeDataConext(const at::Tensor& tensor);
void cvtTensorToScalaraType(const at::Tensor& ipexTensor, at::ScalarType dstScalarType);

/**
* Reorder the DNNL tensor to the public format if the input tensor contains a DNNL tensor.
*
* @param[in] ipexTensor The input ipex tensor whose DNNL tensor will be reordered to the public format
*/
void reorderDilTensorToPublic(const at::Tensor& ipexTensor);

/**
* Reorder the input tensor to the specified scalar type. It is an optimized version for
* DNNL OPs. If DNNL supports the current OP, you should call this API; otherwise, you
* should call @sa @ref reorderTensorToScalaraType

[Review comment, Contributor] @sa @ref ??

[Reply, Contributor/Author] These are Doxygen commands: "@sa" means "see also" and "@ref" creates a link to the referenced symbol.

*
* @param[in] ipexTensor The input ipex tensor to be reordered to the specified scalar type
* @param[in] dstScalarType The scalar type to which the input ipex tensor will be reordered. It should
* be at::kBFloat16 or at::kFloat
*
* @note
* If the input aten tensor is a DNNL tensor and DNNL supports the current OP, we only
* need to set the data type of the DNNL tensor descriptor to the specified scalar type. This
* avoids a memory copy and improves performance. We also need to reset the type meta of
* IPEXTensorImpl and StorageImpl to the corresponding type meta of the specified scalar type.
*/
void reorderTensorToScalarTypeForDNNL(const at::Tensor& ipexTensor, at::ScalarType dstScalarType);

/**
* Reorder the input tensor to the specified scalar type.
*
* @param[in] ipexTensor The input ipex tensor to be reordered to the specified scalar type
* @param[in] dstScalarType The scalar type to which the input ipex tensor will be reordered. It should
* be at::kBFloat16 or at::kFloat
*/
void reorderTensorToScalaraType(const at::Tensor& ipexTensor, at::ScalarType dstScalarType);

// Convert CPU tensor to DPCPP tensor
at::Tensor upgradeToDPCPPTensor(const at::Tensor& ipexTensor);
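As a sketch of how the two declarations above are meant to be used together (per their comments), a caller that knows whether DNNL supports the current OP could dispatch as follows. prepare_input_for_op and the dnnl_supports_current_op flag are hypothetical, and the torch_ipex::bridge namespace and include path are assumptions:

#include <ATen/ATen.h>
#include "torch_ipex/csrc/aten_ipex_bridge.h"  // include path assumed

// Hypothetical helper: picks the cheap descriptor-only reorder when DNNL can
// consume the tensor directly, and the general reorder otherwise.
void prepare_input_for_op(const at::Tensor& ipex_tensor, bool dnnl_supports_current_op) {
  if (dnnl_supports_current_op) {
    // Fast path: for a DNNL tensor this only retags the dil descriptor and the
    // IPEXTensorImpl/StorageImpl type meta; no data copy.
    torch_ipex::bridge::reorderTensorToScalarTypeForDNNL(ipex_tensor, at::kBFloat16);
  } else {
    // General path: may reorder the dil tensor to the public format and convert
    // the buffer between at::kFloat and at::kBFloat16.
    torch_ipex::bridge::reorderTensorToScalaraType(ipex_tensor, at::kBFloat16);
  }
}
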
1 change: 1 addition & 0 deletions torch_ipex/csrc/cpu/ShadeDataContext.h
@@ -87,6 +87,7 @@ struct ShadeDataContext {
void *storage_context = tensor.storage().data_ptr().get_context();
ShadeDataContext *shade_data_context = (ShadeDataContext*)storage_context;
auto data_type = shade_data_context->data_type;
TORCH_INTERNAL_ASSERT((data_type == SHADE_DATA_TYPE::CPU_RAW) || (data_type == SHADE_DATA_TYPE::DIL));

if (data_type == SHADE_DATA_TYPE::DIL) {
TORCH_WARN(tensor.is_contiguous());
4 changes: 4 additions & 0 deletions torch_ipex/csrc/ipex_tensor_impl.cpp
@@ -64,6 +64,10 @@ void IPEXTensorImpl::set_dpcpp_tensor_id() {
this->key_set_.add(at::DispatchKey::VariableTensorId);
}

void IPEXTensorImpl::reset_data_type(at::ScalarType dst_type) {
this->data_type_ = at::scalarTypeToTypeMeta(dst_type);
}

void IPEXTensorImpl::copy_auto_grad(c10::TensorImpl *src_impl) {
if (! src_impl->requires_grad()) {
TORCH_INTERNAL_ASSERT(! this->requires_grad());
1 change: 1 addition & 0 deletions torch_ipex/csrc/ipex_tensor_impl.h
@@ -24,6 +24,7 @@ class IPEXTensorImpl : public c10::TensorImpl {
void set_storage_data_ptr(c10::DataPtr);
void set_dpcpp_tensor_id();
void force_set_strided(at::IntArrayRef size, at::IntArrayRef stride /*, optional<int64_t> storage_offset_*/);
void reset_data_type(at::ScalarType dst_type);

c10::Storage& get_storage() {
return this->storage_;
14 changes: 13 additions & 1 deletion torch_ipex/csrc/utils.cpp
@@ -104,12 +104,24 @@ bool check_auto_dnnl() {
return AutoOptConfig::singleton().get_auto_dnnl();
}

bool check_data_is_part_of_storage(const at::Tensor& tensor) {
bool check_tensor_own_whole_storage(const at::Tensor& tensor) {
if (!(tensor.defined()))
return false;

return (tensor.storage_offset() == 0) &&
(tensor.numel() == tensor.storage().numel());
}

bool check_tensor_own_shade_context(const at::Tensor& tensor) {
if (!(tensor.defined()))
return false;

// [NOTE]: We assume the storage's real data pointer is the same as its context
// when no shade data context is attached. We use this assumption to check whether
// the current tensor already carries a shade data context.
void *data_ptr = tensor.unsafeGetTensorImpl()->storage().data_ptr().get();
void *data_ctx = tensor.unsafeGetTensorImpl()->storage().data_ptr().get_context();
return (data_ptr != data_ctx) && (data_ctx != nullptr);
}

} // namespace torch_ipex
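
For illustration only (not taken from the repository's tests), here is a hedged sketch of what the two new predicates report for an ordinary CPU tensor and a narrowed view of it; the include path is an assumption:

#include <ATen/ATen.h>
#include "torch_ipex/csrc/utils.h"  // include path assumed

void predicate_sketch() {
  at::Tensor whole = at::empty({8, 8}, at::kFloat);
  at::Tensor view  = whole.narrow(/*dim=*/0, /*start=*/2, /*length=*/4);

  // The view shares the storage but starts at a non-zero offset and covers fewer
  // elements than the storage holds, so it does not own the whole storage.
  bool owns_all  = torch_ipex::check_tensor_own_whole_storage(whole); // expected: true
  bool owns_part = torch_ipex::check_tensor_own_whole_storage(view);  // expected: false

  // A freshly allocated CPU tensor has its data pointer equal to its context
  // (per the [NOTE] above), so no shade data context is attached yet.
  bool has_shade = torch_ipex::check_tensor_own_shade_context(whole); // expected: false
  (void)owns_all; (void)owns_part; (void)has_shade;
}
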
3 changes: 2 additions & 1 deletion torch_ipex/csrc/utils.h
@@ -18,6 +18,7 @@ bool get_device_count(c10::Device dev_type, c10::DeviceIndex *count);
dil::data_type get_dil_data_type(at::ScalarType);
at::ScalarType get_at_data_type(dil::data_type);
bool check_auto_dnnl();
bool check_data_is_part_of_storage(const at::Tensor& tensor);
bool check_tensor_own_whole_storage(const at::Tensor& tensor);
bool check_tensor_own_shade_context(const at::Tensor& tensor);

} // namespace torch_ipex