Skip to content

Commit a158a0e

Browse files
committed
Add assert to check if the dnnl buffer is block format then it should occupy whole storage buffer
1 parent fb8b9df commit a158a0e

File tree

3 files changed

+30
-12
lines changed

3 files changed

+30
-12
lines changed

torch_ipex/csrc/cpu/DevOPs.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,7 @@ at::Tensor& AtenIpexCPUDev::dil_add_out(
239239
const std::vector<float> scales{1.0, alpha.to<float>()};
240240
dil::sum::compute(scales, {x, y}, z);
241241

242+
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(z.is_public_format() || check_tensor_own_whole_storage(result));
242243
dbl::comm::sync_shape_from_dil_to_aten(result, z);
243244
return result;
244245
}
@@ -267,6 +268,7 @@ at::Tensor & AtenIpexCPUDev::dil_add_(at::Tensor& self, const at::Tensor& other,
267268
const std::vector<float> scales{1.0, alpha.to<float>()};
268269
dil::sum::compute(scales, {dil_self, dil_other}, dil_self);
269270

271+
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dil_self.is_public_format() || check_tensor_own_whole_storage(self));
270272
dbl::comm::sync_shape_from_dil_to_aten(self, dil_self);
271273
return self;
272274
}
@@ -283,6 +285,7 @@ at::Tensor& AtenIpexCPUDev::dil_mul_out(at::Tensor& result, const at::Tensor& se
283285

284286
dil::binary::compute(dil_self, dil_other, dil_result, dil::algorithm::binary_mul);
285287

288+
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dil_result.is_public_format() || check_tensor_own_whole_storage(result));
286289
dbl::comm::sync_shape_from_dil_to_aten(result, dil_result);
287290
return result;
288291
}
@@ -346,6 +349,7 @@ at::Tensor& AtenIpexCPUDev::dil_bmm_out(
346349
dil::tensor y = dbl::comm::try_gen_dil_tensor(result);
347350
matmul_common(x, w, dil::tensor(), y);
348351

352+
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(y.is_public_format() || check_tensor_own_whole_storage(result));
349353
dbl::comm::sync_shape_from_dil_to_aten(result, y);
350354
return result;
351355
}
@@ -390,6 +394,8 @@ at::Tensor& AtenIpexCPUDev::dil_baddbmm_out(
390394
dil::tensor y = dbl::comm::try_gen_dil_tensor(result);
391395
auto attr_ = dil::attr_t::fuse_sum();
392396
matmul_common(x, w, bias, y, beta, alpha, attr_);
397+
398+
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(y.is_public_format() || check_tensor_own_whole_storage(result));
393399
dbl::comm::sync_shape_from_dil_to_aten(result, y);
394400
return result;
395401
}
@@ -489,6 +495,8 @@ at::Tensor& AtenIpexCPUDev::dil_addbmm_out(
489495
}
490496
}
491497
matmul_common(x_, w_, bias, y, beta, alpha, attr_);
498+
499+
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(y.is_public_format() || check_tensor_own_whole_storage(result));
492500
dbl::comm::sync_shape_from_dil_to_aten(result, y);
493501
return result;
494502
}
@@ -974,6 +982,8 @@ at::Tensor& AtenIpexCPUDev::dil_relu_(at::Tensor& input) {
974982
dil::algorithm::eltwise_relu,
975983
dil::prop_kind::forward_training,
976984
/*alpha*/ 0.0);
985+
986+
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dil_self.is_public_format() || check_tensor_own_whole_storage(input));
977987
dbl::comm::sync_shape_from_dil_to_aten(input, dil_self);
978988
return input;
979989
}
@@ -1041,6 +1051,8 @@ at::Tensor& AtenIpexCPUDev::dil_sigmoid_(at::Tensor& self) {
10411051
dil::tensor x = dbl::comm::try_gen_dil_tensor(self);
10421052
dil::eltwise_forward::compute(
10431053
x, x, dil::algorithm::eltwise_logistic_use_dst_for_bwd, dil::prop_kind::forward);
1054+
1055+
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(x.is_public_format() || check_tensor_own_whole_storage(self));
10441056
dbl::comm::sync_shape_from_dil_to_aten(self, x);
10451057
return self;
10461058
}
@@ -1122,6 +1134,8 @@ at::Tensor& AtenIpexCPUDev::dil_cat_out(at::Tensor& result, at::TensorList tenso
11221134
}
11231135
dil::tensor y = dbl::comm::try_gen_dil_tensor(result);
11241136
dil::concat::compute(x, dim, y);
1137+
1138+
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(y.is_public_format() || check_tensor_own_whole_storage(result));
11251139
dbl::comm::sync_shape_from_dil_to_aten(result, y);
11261140
return result;
11271141
}

torch_ipex/csrc/cpu/ShadeDataContext.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,8 @@ struct ShadeDataContext {
9494
if (raw_cpu_data == nullptr) {
9595
// the dnnl tensor does not share data with raw tensor data.
9696
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(! (shade_data_context->dil_tensor.is_empty()));
97+
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(! (shade_data_context->dil_tensor.is_public_format()));
98+
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(check_tensor_own_whole_storage(tensor));
9799
return true;
98100
} else {
99101
// The dnnl tensor shares some data with raw tensor.

torch_ipex/csrc/cpu/dbl/Common.cpp

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ dil::tensor try_gen_dil_tensor(const at::Tensor &input) {
4040
if ((!check_aten_dil_shape_info(input, dil_tensor)) && dil_tensor.is_public_format()) {
4141
dil_tensor.set_dims_and_strides(input.sizes().vec(), input.strides().vec());
4242
}
43+
// Does not support the case if the dil tensor is block format but it is just a part of tensor buffer
44+
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dil_tensor.is_public_format() || check_tensor_own_whole_storage(input));
4345
return dil_tensor;
4446
} else {
4547
return dil_tensor_from_dense(input);
@@ -71,14 +73,7 @@ at::Tensor gen_aten_tensor_by(dil::tensor dil_tensor) {
7173
nullptr,
7274
/*resizeable=*/false);
7375
auto _tensor = at::detail::make_tensor<torch_ipex::IPEXTensorImpl>(storage_impl, at::DispatchKey::DPCPPTensorId);
74-
if (dil_tensor.is_public_format()) {
75-
dbl::comm::sync_shape_from_dil_to_aten(_tensor, dil_tensor);
76-
} else {
77-
// Blockformat does not inlcude stride information
78-
auto tensor_sizes = dil_tensor.get_dims();
79-
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(tensor_sizes.size() != 1 || tensor_sizes[0] != 0);
80-
_tensor.unsafeGetTensorImpl()->set_sizes_contiguous(tensor_sizes);
81-
}
76+
dbl::comm::sync_shape_from_dil_to_aten(_tensor, dil_tensor);
8277
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(_tensor.layout() == c10::kStrided);
8378
return _tensor;
8479
}
@@ -94,10 +89,17 @@ at::Tensor empty_dil_tensor(at::IntArrayRef sizes, const at::TensorOptions& opti
9489

9590
void sync_shape_from_dil_to_aten(const at::Tensor& ipex_tensor, const dil::tensor &dil_tensor) {
9691
dil::dims sizes = dil_tensor.get_dims();
97-
dil::dims strides = dil_tensor.get_strides();
98-
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(ipex_tensor.device().type() == at::DeviceType::DPCPP);
99-
auto* _tensor_impl = (IPEXTensorImpl *)ipex_tensor.unsafeGetTensorImpl();
100-
_tensor_impl->force_set_strided(sizes, strides);
92+
if (dil_tensor.is_public_format()) {
93+
dil::dims strides = dil_tensor.get_strides();
94+
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(ipex_tensor.device().type() == at::DeviceType::DPCPP);
95+
auto* _tensor_impl = (IPEXTensorImpl *)ipex_tensor.unsafeGetTensorImpl();
96+
_tensor_impl->force_set_strided(sizes, strides);
97+
} else {
98+
// Blockformat does not inlcude stride information
99+
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(sizes.size() != 1 || sizes[0] != 0);
100+
ipex_tensor.unsafeGetTensorImpl()->set_sizes_contiguous(sizes);
101+
}
102+
101103
}
102104

103105
} // namespace comm

0 commit comments

Comments
 (0)