
Commit de92da8

Merge pull request #1 from EikanWang/master
Add data type reorder for DNNL OP
2 parents: aa86342 + 5be717a

7 files changed: 96 additions and 18 deletions

torch_ipex/csrc/aten_ipex_bridge.cpp

Lines changed: 42 additions & 15 deletions
@@ -50,7 +50,7 @@ namespace bridge {
 
 at::Tensor shallowFallbackToCPUTensorImpl(const at::Tensor& ipexTensor);
 
-void reorderDilTensor(const at::Tensor& ipexTensor) {
+void reorderDilTensorToPublic(const at::Tensor& ipexTensor) {
   void *data_ctx = ipexTensor.unsafeGetTensorImpl()->storage().data_ptr().get_context();
   cpu::ShadeDataContext *shade_data_context = (cpu::ShadeDataContext*)data_ctx;
   // All aten::tensor with dnnl::tensor should be contiguous
@@ -89,12 +89,11 @@ void reorderDilTensor(const at::Tensor& ipexTensor) {
 void attachShadeDataConext(const at::Tensor& tensor) {
   auto tensor_storage_impl = tensor.storage().unsafeGetStorageImpl();
   auto& data_ptr = tensor_storage_impl->data_ptr();
-  // [NOTE]: We assume the real data of storage should be as same as its context.
-  // Then we use the assumption to check if current tensor has contained
-  // shade data context.
-  if (data_ptr.get() != data_ptr.get_context()) {
+
+  // Has contained shade context
+  if (check_tensor_own_shade_context(tensor))
     return;
-  }
+
   auto cur_del_fn = data_ptr.get_deleter();
   bool res = data_ptr.compare_exchange_deleter(cur_del_fn, &(c10::detail::deleteNothing));
   TORCH_INTERNAL_ASSERT(res);
@@ -189,7 +188,7 @@ at::Tensor shallowFallbackToCPUTensor(const at::Tensor& ipexTensor) {
     cpu::ShadeDataContext *shade_data_context = (cpu::ShadeDataContext*)data_ctx;
     // Branch 2.1: Dense + Dil Tensor
     if (cpu::ShadeDataContext::isDilTensor(ipexTensor)) {
-      reorderDilTensor(ipexTensor);
+      reorderDilTensorToPublic(ipexTensor);
     }
 
     // Branch 2.2: Dense + CPU Tensor
@@ -496,24 +495,51 @@ std::vector<at::Tensor> shallowFallbackToCPUTensorList(const at::TensorList& ten
   return dpcpp_tensor_vec;
 }
 
-void cvtTensorToScalaraType(const at::Tensor& ipexTensor, at::ScalarType dstScalarType) {
+
+void reorderTensorToScalarTypeForDNNL(const at::Tensor& ipexTensor, at::ScalarType dstScalarType) {
+  TORCH_CHECK(dstScalarType == at::kBFloat16 || dstScalarType == at::kFloat);
+  auto tensor_dtype = ipexTensor.scalar_type();
+  TORCH_CHECK(tensor_dtype == at::kBFloat16 || tensor_dtype == at::kFloat);
+  if (tensor_dtype == dstScalarType)
+    return;
+
+  if (check_tensor_own_shade_context(ipexTensor)) {
+    // Shade data context has been attached
+    if (cpu::ShadeDataContext::isDilTensor(ipexTensor)) {
+      cpu::ShadeDataContext *shade_context = (cpu::ShadeDataContext*)(ipexTensor.storage().data_ptr().get_context());
+      shade_context->dil_tensor.to_type(get_dil_data_type(dstScalarType));
+      IPEXTensorImpl* ipex_tensor_impl = (IPEXTensorImpl *)ipexTensor.unsafeGetTensorImpl();
+      ipex_tensor_impl->reset_data_type(dstScalarType);
+      ipex_tensor_impl->storage().unsafeGetStorageImpl()->set_dtype(at::scalarTypeToTypeMeta(dstScalarType));
+      return;
+    }
+  }
+
+  return reorderTensorToScalaraType(ipexTensor, dstScalarType);
+}
+
+
+void reorderTensorToScalaraType(const at::Tensor& ipexTensor, at::ScalarType dstScalarType) {
   if (!(ipexTensor.defined()))
     return;
 
   TORCH_CHECK(dstScalarType == at::kBFloat16 || dstScalarType == at::kFloat);
-  if (ipexTensor.scalar_type() == dstScalarType)
+
+  auto tensor_dtype = ipexTensor.scalar_type();
+  TORCH_CHECK(tensor_dtype == at::kBFloat16 || tensor_dtype == at::kFloat);
+  if (tensor_dtype == dstScalarType)
     return;
 
-  if (check_data_is_part_of_storage(ipexTensor))
+  if (!check_tensor_own_whole_storage(ipexTensor)) {
     return;
+  } else {
+    TORCH_INTERNAL_ASSERT(false);
+  }
 
-  void *data_ptr = ipexTensor.unsafeGetTensorImpl()->storage().data_ptr().get();
-  void *data_ctx = ipexTensor.unsafeGetTensorImpl()->storage().data_ptr().get_context();
-  if ((data_ptr != data_ctx) && (data_ctx != nullptr)) {
+  if (check_tensor_own_shade_context(ipexTensor)) {
     // Shade data context has been attached
-    cpu::ShadeDataContext *shade_data_context = (cpu::ShadeDataContext*)data_ctx;
     if (cpu::ShadeDataContext::isDilTensor(ipexTensor)) {
-      reorderDilTensor(ipexTensor);
+      reorderDilTensorToPublic(ipexTensor);
     }
   }
 
@@ -528,6 +554,7 @@ void cvtTensorToScalaraType(const at::Tensor& ipexTensor, at::ScalarType dstScal
     allocator,
     /*resizeable=*/true);
 
+  void *data_ptr = ipexTensor.unsafeGetTensorImpl()->storage().data_ptr().get();
   if (dstScalarType == at::kBFloat16) {
     torch_ipex::cpu::bf16::converter::fp32_to_bf16(storage_impl->data_ptr().get(), data_ptr, nelements);
   } else {

torch_ipex/csrc/aten_ipex_bridge.h

Lines changed: 33 additions & 1 deletion
@@ -16,7 +16,39 @@ std::vector<at::Tensor> fallbackToCPUTensorList(const at::TensorList&);
 std::vector<at::Tensor> shallowFallbackToCPUTensorList(const at::TensorList&);
 
 void attachShadeDataConext(const at::Tensor& tensor);
-void cvtTensorToScalaraType(const at::Tensor& ipexTensor, at::ScalarType dstScalarType);
+
+/**
+ * Reorder the DNNL tensor to the public format if the input tensor contains a DNNL tensor.
+ *
+ * @param[in] ipexTensor The input ipex tensor whose DNNL tensor will be reordered to the public format
+ */
+void reorderDilTensorToPublic(const at::Tensor& ipexTensor);
+
+/**
+ * Reorder the input tensor to the specified scalar type. This is an optimized version for
+ * DNNL OPs: if DNNL supports the current OP, call this API; otherwise, call
+ * @sa @ref reorderTensorToScalaraType
+ *
+ * @param[in] ipexTensor The input ipex tensor to be reordered to the specified scalar type
+ * @param[in] dstScalarType The scalar type which the input ipex tensor will be reordered to. It should
+ *            be at::kBFloat16 or at::kFloat
+ *
+ * @note
+ * If the input aten tensor is a DNNL tensor and DNNL supports the current OP, we only
+ * need to set the data type of the DNNL tensor descriptor to the specified scalar type. This
+ * avoids a memory copy and improves performance. We also need to reset the type meta of
+ * IPEXTensorImpl and StorageImpl to the corresponding type meta of the specified scalar type.
+ */
+void reorderTensorToScalarTypeForDNNL(const at::Tensor& ipexTensor, at::ScalarType dstScalarType);
+
+/**
+ * Reorder the input tensor to the specified scalar type.
+ *
+ * @param[in] ipexTensor The input ipex tensor to be reordered to the specified scalar type
+ * @param[in] dstScalarType The scalar type which the input ipex tensor will be reordered to. It should
+ *            be at::kBFloat16 or at::kFloat
+ */
+void reorderTensorToScalaraType(const at::Tensor& ipexTensor, at::ScalarType dstScalarType);
 
 // Convert CPU tensor to DPCPP tensor
 at::Tensor upgradeToDPCPPTensor(const at::Tensor& ipexTensor);
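
The header above splits the old cvtTensorToScalaraType entry point into a DNNL-aware path and a generic path. A minimal caller sketch, assuming the declarations live in namespace torch_ipex::bridge (as the @@ context in aten_ipex_bridge.cpp suggests); prepare_dnnl_input, the dnnl_supports_op flag, and the include path are hypothetical names used only for illustration:

#include <ATen/ATen.h>
#include "torch_ipex/csrc/aten_ipex_bridge.h"  // assumed include path

// Hypothetical helper: pick the reorder entry point for one op input.
static at::Tensor prepare_dnnl_input(const at::Tensor& ipex_input, bool dnnl_supports_op) {
  if (dnnl_supports_op) {
    // DNNL path: only the dil tensor descriptor and the type meta are rewritten,
    // so no buffer copy happens for a DNNL-backed tensor.
    torch_ipex::bridge::reorderTensorToScalarTypeForDNNL(ipex_input, at::kBFloat16);
  } else {
    // Generic path: falls back to a real fp32 <-> bf16 conversion of the storage.
    torch_ipex::bridge::reorderTensorToScalaraType(ipex_input, at::kBFloat16);
  }
  return ipex_input;
}

Per the diff, reorderDilTensorToPublic is the lower-level routine the generic path (and the CPU fallback) uses to bring a DNNL tensor back to the public format first.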

torch_ipex/csrc/cpu/ShadeDataContext.h

Lines changed: 1 addition & 0 deletions
@@ -87,6 +87,7 @@ struct ShadeDataContext {
     void *storage_context = tensor.storage().data_ptr().get_context();
     ShadeDataContext *shade_data_context = (ShadeDataContext*)storage_context;
     auto data_type = shade_data_context->data_type;
+    TORCH_INTERNAL_ASSERT((data_type == SHADE_DATA_TYPE::CPU_RAW) || (data_type == SHADE_DATA_TYPE::DIL));
 
     if (data_type == SHADE_DATA_TYPE::DIL) {
       TORCH_WARN(tensor.is_contiguous());

torch_ipex/csrc/ipex_tensor_impl.cpp

Lines changed: 4 additions & 0 deletions
@@ -64,6 +64,10 @@ void IPEXTensorImpl::set_dpcpp_tensor_id() {
   this->key_set_.add(at::DispatchKey::VariableTensorId);
 }
 
+void IPEXTensorImpl::reset_data_type(at::ScalarType dst_type) {
+  this->data_type_ = at::scalarTypeToTypeMeta(dst_type);
+}
+
 void IPEXTensorImpl::copy_auto_grad(c10::TensorImpl *src_impl) {
   if (! src_impl->requires_grad()) {
     TORCH_INTERNAL_ASSERT(! this->requires_grad());

torch_ipex/csrc/ipex_tensor_impl.h

Lines changed: 1 addition & 0 deletions
@@ -24,6 +24,7 @@ class IPEXTensorImpl : public c10::TensorImpl {
   void set_storage_data_ptr(c10::DataPtr);
   void set_dpcpp_tensor_id();
   void force_set_strided(at::IntArrayRef size, at::IntArrayRef stride /*, optional<int64_t> storage_offset_*/);
+  void reset_data_type(at::ScalarType dst_type);
 
   c10::Storage& get_storage() {
     return this->storage_;

torch_ipex/csrc/utils.cpp

Lines changed: 13 additions & 1 deletion
@@ -104,12 +104,24 @@ bool check_auto_dnnl() {
   return AutoOptConfig::singleton().get_auto_dnnl();
 }
 
-bool check_data_is_part_of_storage(const at::Tensor& tensor) {
+bool check_tensor_own_whole_storage(const at::Tensor& tensor) {
   if (!(tensor.defined()))
     return false;
 
   return (tensor.storage_offset() == 0) &&
          (tensor.numel() == tensor.storage().numel());
 }
 
+bool check_tensor_own_shade_context(const at::Tensor& tensor) {
+  if (!(tensor.defined()))
+    return false;
+
+  // [NOTE]: We assume the storage's real data pointer is the same as its context
+  // pointer unless a shade data context has been attached. We use this assumption
+  // to check whether the current tensor already carries a shade data context.
+  void *data_ptr = tensor.unsafeGetTensorImpl()->storage().data_ptr().get();
+  void *data_ctx = tensor.unsafeGetTensorImpl()->storage().data_ptr().get_context();
+  return (data_ptr != data_ctx) && (data_ctx != nullptr);
+}
+
 } // namespace torch_ipex
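
The [NOTE] above states the invariant these helpers rely on: a storage without a shade context keeps its data pointer and context pointer identical, while an attached ShadeDataContext makes them diverge. A sketch of how a caller might combine the two new predicates, mirroring the guard in reorderTensorToScalaraType; can_reorder_storage is a hypothetical name and the include path is assumed:

#include <ATen/ATen.h>
#include "torch_ipex/csrc/utils.h"  // assumed include path

namespace torch_ipex {

// Hypothetical predicate: is it safe and meaningful to reorder this tensor's storage?
static bool can_reorder_storage(const at::Tensor& tensor) {
  // A view (non-zero storage offset, or fewer elements than its storage) must not have
  // its shared storage rewritten behind the backs of the other views.
  if (!check_tensor_own_whole_storage(tensor))
    return false;
  // data_ptr != context means a ShadeDataContext wraps the raw buffer, so the storage
  // may actually be backed by a DNNL (dil) tensor that needs reordering.
  return check_tensor_own_shade_context(tensor);
}

} // namespace torch_ipex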

torch_ipex/csrc/utils.h

Lines changed: 2 additions & 1 deletion
@@ -18,6 +18,7 @@ bool get_device_count(c10::Device dev_type, c10::DeviceIndex *count);
 dil::data_type get_dil_data_type(at::ScalarType);
 at::ScalarType get_at_data_type(dil::data_type);
 bool check_auto_dnnl();
-bool check_data_is_part_of_storage(const at::Tensor& tensor);
+bool check_tensor_own_whole_storage(const at::Tensor& tensor);
+bool check_tensor_own_shade_context(const at::Tensor& tensor);
 
 } // namespace torch_ipex
