@@ -50,7 +50,7 @@ namespace bridge {
 
 at::Tensor shallowFallbackToCPUTensorImpl(const at::Tensor& ipexTensor);
 
-void reorderDilTensor(const at::Tensor& ipexTensor) {
+void reorderDilTensorToPublic(const at::Tensor& ipexTensor) {
   void *data_ctx = ipexTensor.unsafeGetTensorImpl()->storage().data_ptr().get_context();
   cpu::ShadeDataContext *shade_data_context = (cpu::ShadeDataContext*)data_ctx;
   // All aten::tensor with dnnl::tensor should be contiguous
@@ -89,12 +89,11 @@ void reorderDilTensor(const at::Tensor& ipexTensor) {
 void attachShadeDataConext(const at::Tensor& tensor) {
   auto tensor_storage_impl = tensor.storage().unsafeGetStorageImpl();
   auto& data_ptr = tensor_storage_impl->data_ptr();
-  // [NOTE]: We assume the real data of storage should be as same as its context.
-  // Then we use the assumption to check if current tensor has contained
-  // shade data context.
-  if (data_ptr.get() != data_ptr.get_context()) {
+
+  // Has contained shade context
+  if (check_tensor_own_shade_context(tensor))
     return;
-  }
+
   auto cur_del_fn = data_ptr.get_deleter();
   bool res = data_ptr.compare_exchange_deleter(cur_del_fn, &(c10::detail::deleteNothing));
   TORCH_INTERNAL_ASSERT(res);
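
Note: `check_tensor_own_shade_context` is defined outside this diff. A minimal sketch of what it presumably checks, reconstructed from the inline test it replaces (the storage's data pointer versus its context pointer); the body below is an assumption, not the repository's actual implementation:

bool check_tensor_own_shade_context(const at::Tensor& tensor) {
  // Assumed behavior: a ShadeDataContext is attached when the storage's
  // context pointer exists and differs from the raw data pointer.
  void *data_ptr = tensor.unsafeGetTensorImpl()->storage().data_ptr().get();
  void *data_ctx = tensor.unsafeGetTensorImpl()->storage().data_ptr().get_context();
  return (data_ctx != nullptr) && (data_ptr != data_ctx);
}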
@@ -189,7 +188,7 @@ at::Tensor shallowFallbackToCPUTensor(const at::Tensor& ipexTensor) {
     cpu::ShadeDataContext *shade_data_context = (cpu::ShadeDataContext*)data_ctx;
     // Branch 2.1: Dense + Dil Tensor
     if (cpu::ShadeDataContext::isDilTensor(ipexTensor)) {
-      reorderDilTensor(ipexTensor);
+      reorderDilTensorToPublic(ipexTensor);
     }
 
     // Branch 2.2: Dense + CPU Tensor
@@ -496,24 +495,51 @@ std::vector<at::Tensor> shallowFallbackToCPUTensorList(const at::TensorList& ten
   return dpcpp_tensor_vec;
 }
 
-void cvtTensorToScalaraType(const at::Tensor& ipexTensor, at::ScalarType dstScalarType) {
+
+void reorderTensorToScalarTypeForDNNL(const at::Tensor& ipexTensor, at::ScalarType dstScalarType) {
+  TORCH_CHECK(dstScalarType == at::kBFloat16 || dstScalarType == at::kFloat);
+  auto tensor_dtype = ipexTensor.scalar_type();
+  TORCH_CHECK(tensor_dtype == at::kBFloat16 || tensor_dtype == at::kFloat);
+  if (tensor_dtype == dstScalarType)
+    return;
+
+  if (check_tensor_own_shade_context(ipexTensor)) {
+    // Shade data context has been attached
+    if (cpu::ShadeDataContext::isDilTensor(ipexTensor)) {
+      cpu::ShadeDataContext *shade_context = (cpu::ShadeDataContext*)(ipexTensor.storage().data_ptr().get_context());
+      shade_context->dil_tensor.to_type(get_dil_data_type(dstScalarType));
+      IPEXTensorImpl* ipex_tensor_impl = (IPEXTensorImpl *)ipexTensor.unsafeGetTensorImpl();
+      ipex_tensor_impl->reset_data_type(dstScalarType);
+      ipex_tensor_impl->storage().unsafeGetStorageImpl()->set_dtype(at::scalarTypeToTypeMeta(dstScalarType));
+      return;
+    }
+  }
+
+  return reorderTensorToScalaraType(ipexTensor, dstScalarType);
+}
+
+
+void reorderTensorToScalaraType(const at::Tensor& ipexTensor, at::ScalarType dstScalarType) {
   if (!(ipexTensor.defined()))
     return;
 
   TORCH_CHECK(dstScalarType == at::kBFloat16 || dstScalarType == at::kFloat);
-  if (ipexTensor.scalar_type() == dstScalarType)
+
+  auto tensor_dtype = ipexTensor.scalar_type();
+  TORCH_CHECK(tensor_dtype == at::kBFloat16 || tensor_dtype == at::kFloat);
+  if (tensor_dtype == dstScalarType)
     return;
 
-  if (check_data_is_part_of_storage(ipexTensor))
+  if (!check_tensor_own_whole_storage(ipexTensor)) {
     return;
+  } else {
+    TORCH_INTERNAL_ASSERT(false);
+  }
 
-  void *data_ptr = ipexTensor.unsafeGetTensorImpl()->storage().data_ptr().get();
-  void *data_ctx = ipexTensor.unsafeGetTensorImpl()->storage().data_ptr().get_context();
-  if ((data_ptr != data_ctx) && (data_ctx != nullptr)) {
+  if (check_tensor_own_shade_context(ipexTensor)) {
     // Shade data context has been attached
-    cpu::ShadeDataContext *shade_data_context = (cpu::ShadeDataContext*)data_ctx;
     if (cpu::ShadeDataContext::isDilTensor(ipexTensor)) {
-      reorderDilTensor(ipexTensor);
+      reorderDilTensorToPublic(ipexTensor);
     }
   }
 
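
Note: `check_tensor_own_whole_storage` is likewise not part of this diff; it replaces `check_data_is_part_of_storage` with the opposite sense, so the reorder only proceeds when the tensor's view spans its entire storage. A hypothetical sketch under that assumption (the name reuse and the size query are illustrative, not the repository's actual code):

bool check_tensor_own_whole_storage(const at::Tensor& tensor) {
  // Assumed behavior: the tensor starts at storage offset 0 and its element
  // range covers every byte the storage holds.
  auto* impl = tensor.unsafeGetTensorImpl();
  size_t tensor_bytes = tensor.numel() * tensor.element_size();
  return impl->storage_offset() == 0 &&
         tensor_bytes == impl->storage().nbytes();
}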
@@ -528,6 +554,7 @@ void cvtTensorToScalaraType(const at::Tensor& ipexTensor, at::ScalarType dstScal
     allocator,
     /*resizeable=*/true);
 
+  void *data_ptr = ipexTensor.unsafeGetTensorImpl()->storage().data_ptr().get();
   if (dstScalarType == at::kBFloat16) {
     torch_ipex::cpu::bf16::converter::fp32_to_bf16(storage_impl->data_ptr().get(), data_ptr, nelements);
   } else {