Skip to content

Commit 9f7093c

Browse files
committed
Sync shape info between dil tensor and aten tensor
1 parent 8418986 commit 9f7093c

File tree

4 files changed

+22
-12
lines changed

4 files changed

+22
-12
lines changed

scripts/cpu/gen-dense-cpu-ops.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -306,7 +306,8 @@ def is_out_func(fname):
306306
if param_var == 'out' and is_out_func(fname):
307307
code += ' TORCH_INTERNAL_ASSERT({}.is_contiguous());\n'.format(param_var)
308308
else:
309-
param_seq_str = '{}.is_contiguous() ? {} : {}.contiguous()'.format(param_var, param_var, param_var)
309+
# param_seq_str = '{}.is_contiguous() ? {} : {}.contiguous()'.format(param_var, param_var, param_var)
310+
None
310311
param_seq_str_vec.append(param_seq_str)
311312
code += ' if (dbl::chk::dnnl_support_the_tensors(dnnl_input_tensors))\n'
312313
code += ' return AtenIpexCPUDev::dil_{}({});\n'.format(fname, ', '.join(param_seq_str_vec))

tests/cpu/test_torch.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12785,12 +12785,8 @@ def _test_memory_format_transformations(self, device, input_generator_fn, transf
1278512785
clone = transformation_fn(xc)
1278612786

1278712787
if default_is_preserve:
12788-
if ipex.get_auto_dnnl():
12789-
self.assertTrue(clone.is_contiguous())
12790-
self.assertFalse(clone.is_contiguous(memory_format=memory_format))
12791-
else:
12792-
self.assertFalse(clone.is_contiguous())
12793-
self.assertTrue(clone.is_contiguous(memory_format=memory_format))
12788+
self.assertFalse(clone.is_contiguous())
12789+
self.assertTrue(clone.is_contiguous(memory_format=memory_format))
1279412790
else:
1279512791
self.assertTrue(clone.is_contiguous())
1279612792
self.assertFalse(clone.is_contiguous(memory_format=memory_format))

torch_ipex/csrc/cpu/dbl/Common.cpp

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ dil::tensor dil_tensor_from_dense(const at::Tensor& tensor) {
2020
tensor.layout() == at::Layout::Strided,
2121
"dil_tensor_view_from_dense expects dense tensor input");
2222
at::ScalarType cur_type = tensor.scalar_type();
23-
return {tensor.sizes().vec(), get_dil_data_type(cur_type), tensor.data_ptr()};
23+
return {tensor.sizes().vec(), get_dil_data_type(cur_type), tensor.strides().vec(), tensor.data_ptr()};
2424
}
2525

2626
at::Tensor dil_tensor_to_dense(const at::Tensor& tensor) {
@@ -38,7 +38,6 @@ dil::tensor try_gen_dil_tensor(const at::Tensor &input) {
3838
if (cpu::ShadeDataContext::isDilTensor(input)) {
3939
return cpu::ShadeDataContext::getDilTensor(input);
4040
} else {
41-
TORCH_INTERNAL_ASSERT(input.is_contiguous());
4241
return dil_tensor_from_dense(input);
4342
}
4443
}
@@ -60,7 +59,6 @@ at::Tensor gen_aten_tensor_by(dil::tensor dil_tensor) {
6059
shade_data_context,
6160
cpu::ShadeDataContext::freeShadeDataContext,
6261
at::DeviceType::DPCPP);
63-
auto tensor_sizes = dil_tensor.get_dims();
6462
auto at_data_type = get_at_data_type(dil_tensor.get_data_type());
6563
auto storage_impl = c10::make_intrusive<at::StorageImpl>(
6664
at::scalarTypeToTypeMeta(at_data_type),
@@ -69,10 +67,14 @@ at::Tensor gen_aten_tensor_by(dil::tensor dil_tensor) {
6967
nullptr,
7068
/*resizeable=*/false);
7169
auto _tensor = at::detail::make_tensor<torch_ipex::IPEXTensorImpl>(storage_impl, at::DispatchKey::DPCPPTensorId);
72-
if (tensor_sizes.size() != 1 || tensor_sizes[0] != 0) {
70+
if (dil_tensor.is_public_format()) {
71+
dbl::comm::sync_shape_from_dil_to_aten(_tensor, dil_tensor);
72+
} else {
73+
// Blockformat does not inlcude stride information
74+
auto tensor_sizes = dil_tensor.get_dims();
75+
TORCH_INTERNAL_ASSERT(tensor_sizes.size() != 1 || tensor_sizes[0] != 0);
7376
_tensor.unsafeGetTensorImpl()->set_sizes_contiguous(tensor_sizes);
7477
}
75-
TORCH_INTERNAL_ASSERT(_tensor.is_contiguous());
7678
TORCH_INTERNAL_ASSERT(_tensor.layout() == c10::kStrided);
7779
return _tensor;
7880
}

torch_ipex/csrc/cpu/dil/dil/tensor.hpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -449,6 +449,12 @@ class tensor : public memory {
449449
init(adims, adata_type, ahandle, aengine);
450450
}
451451

452+
// no format_tb, strides, buffer
453+
tensor(const dims &adims, data_type adata_type, const dims &astrides,
454+
void *ahandle, const engine &aengine = engine::cpu_engine()) {
455+
init(adims, adata_type, astrides, ahandle, aengine);
456+
}
457+
452458
// no format_tag, no buffer
453459
tensor(const dims &adims, data_type adata_type,
454460
const engine &aengine = engine::cpu_engine()) {
@@ -480,6 +486,11 @@ class tensor : public memory {
480486
init({adims, adata_type, aformat_tag}, ahandle, aengine);
481487
}
482488

489+
void init(const dims &adims, data_type adata_type, const dims &astrides,
490+
void *ahandle, const engine &aengine = engine::cpu_engine()) {
491+
init({adims, adata_type, astrides}, ahandle, aengine);
492+
}
493+
483494
// format_tag, no buffer
484495
void init(const dims &adims, data_type adata_type, format_tag aformat_tag,
485496
const engine &aengine = engine::cpu_engine()) {

0 commit comments

Comments
 (0)