Commit f7fbd8a

Merge pull request #14 from pinzhenx/sync_strides
Sync shape info between dil tensor and aten tensor
2 parents: 8418986 + f4732b8

File tree

5 files changed (+53, -27 lines)


scripts/cpu/gen-dense-cpu-ops.py

Lines changed: 2 additions & 1 deletion

@@ -306,7 +306,8 @@ def is_out_func(fname):
       if param_var == 'out' and is_out_func(fname):
         code += '  TORCH_INTERNAL_ASSERT({}.is_contiguous());\n'.format(param_var)
       else:
-        param_seq_str = '{}.is_contiguous() ? {} : {}.contiguous()'.format(param_var, param_var, param_var)
+        # param_seq_str = '{}.is_contiguous() ? {} : {}.contiguous()'.format(param_var, param_var, param_var)
+        None
     param_seq_str_vec.append(param_seq_str)
   code += '  if (dbl::chk::dnnl_support_the_tensors(dnnl_input_tensors))\n'
   code += '    return AtenIpexCPUDev::dil_{}({});\n'.format(fname, ', '.join(param_seq_str_vec))
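
The codegen change stops routing inputs through `.contiguous()` before the DNNL dispatch, so non-contiguous views now reach the DIL bridge with their strides intact. As a self-contained illustration of what that preserves (not IPEX code), the sketch below computes row-major strides and shows how a transpose reuses the same buffer with permuted strides, exactly the metadata a forced `.contiguous()` copy would have discarded:

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

// Contiguous (row-major) strides: stride[i] = product of sizes[i+1..n).
std::vector<int64_t> contiguous_strides(const std::vector<int64_t>& sizes) {
  std::vector<int64_t> strides(sizes.size(), 1);
  for (int i = static_cast<int>(sizes.size()) - 2; i >= 0; --i)
    strides[i] = strides[i + 1] * sizes[i + 1];
  return strides;
}

int main() {
  std::vector<int64_t> sizes = {2, 3, 4};
  auto strides = contiguous_strides(sizes);   // {12, 4, 1}
  // A transpose of dims 0 and 1 swaps sizes and strides but copies no data:
  std::swap(sizes[0], sizes[1]);              // {3, 2, 4}
  std::swap(strides[0], strides[1]);          // {4, 12, 1} -- no longer contiguous
  for (auto s : strides) std::cout << s << ' ';
  std::cout << '\n';
}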

tests/cpu/test_torch.py

Lines changed: 2 additions & 6 deletions

@@ -12785,12 +12785,8 @@ def _test_memory_format_transformations(self, device, input_generator_fn, transf
         clone = transformation_fn(xc)

         if default_is_preserve:
-            if ipex.get_auto_dnnl():
-                self.assertTrue(clone.is_contiguous())
-                self.assertFalse(clone.is_contiguous(memory_format=memory_format))
-            else:
-                self.assertFalse(clone.is_contiguous())
-                self.assertTrue(clone.is_contiguous(memory_format=memory_format))
+            self.assertFalse(clone.is_contiguous())
+            self.assertTrue(clone.is_contiguous(memory_format=memory_format))
         else:
             self.assertTrue(clone.is_contiguous())
             self.assertFalse(clone.is_contiguous(memory_format=memory_format))
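
With the stride sync in place, the test no longer needs the `get_auto_dnnl()` special case: a preserving transformation keeps the source memory format instead of collapsing to plain contiguous. What `is_contiguous(memory_format=...)` decides can be modeled as a stride comparison; the sketch below is an illustration, not the ATen implementation, treating channels-last as contiguity under the dimension order N, H, W, C:

#include <cstdint>
#include <iostream>
#include <vector>

// True if `strides` matches the contiguous layout obtained by laying the
// dimensions out in `order` (innermost last). order = {0,1,2,3} is plain
// NCHW-contiguous; order = {0,2,3,1} is channels-last for a 4-D tensor.
bool is_contiguous_for_order(const std::vector<int64_t>& sizes,
                             const std::vector<int64_t>& strides,
                             const std::vector<int64_t>& order) {
  int64_t expected = 1;
  for (auto it = order.rbegin(); it != order.rend(); ++it) {
    if (strides[*it] != expected) return false;
    expected *= sizes[*it];
  }
  return true;
}

int main() {
  std::vector<int64_t> sizes = {2, 3, 4, 5};          // N, C, H, W
  std::vector<int64_t> cl_strides = {60, 1, 15, 3};   // channels-last layout
  std::cout << is_contiguous_for_order(sizes, cl_strides, {0, 1, 2, 3})   // 0
            << is_contiguous_for_order(sizes, cl_strides, {0, 2, 3, 1})   // 1
            << '\n';
}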

torch_ipex/csrc/cpu/dbl/Common.cpp

Lines changed: 12 additions & 6 deletions

@@ -20,7 +20,7 @@ dil::tensor dil_tensor_from_dense(const at::Tensor& tensor) {
     tensor.layout() == at::Layout::Strided,
     "dil_tensor_view_from_dense expects dense tensor input");
   at::ScalarType cur_type = tensor.scalar_type();
-  return {tensor.sizes().vec(), get_dil_data_type(cur_type), tensor.data_ptr()};
+  return {tensor.sizes().vec(), get_dil_data_type(cur_type), tensor.strides().vec(), tensor.data_ptr()};
 }

 at::Tensor dil_tensor_to_dense(const at::Tensor& tensor) {
@@ -36,9 +36,12 @@ at::Tensor dil_tensor_to_dense(const at::Tensor& tensor) {

 dil::tensor try_gen_dil_tensor(const at::Tensor &input) {
   if (cpu::ShadeDataContext::isDilTensor(input)) {
-    return cpu::ShadeDataContext::getDilTensor(input);
+    auto dil_tensor = cpu::ShadeDataContext::getDilTensor(input);
+    if (dil_tensor.is_public_format()) {
+      dil_tensor.set_dims_and_strides(input.sizes().vec(), input.strides().vec());
+    }
+    return dil_tensor;
   } else {
-    TORCH_INTERNAL_ASSERT(input.is_contiguous());
     return dil_tensor_from_dense(input);
   }
 }
@@ -60,7 +63,6 @@ at::Tensor gen_aten_tensor_by(dil::tensor dil_tensor) {
     shade_data_context,
     cpu::ShadeDataContext::freeShadeDataContext,
     at::DeviceType::DPCPP);
-  auto tensor_sizes = dil_tensor.get_dims();
   auto at_data_type = get_at_data_type(dil_tensor.get_data_type());
   auto storage_impl = c10::make_intrusive<at::StorageImpl>(
     at::scalarTypeToTypeMeta(at_data_type),
@@ -69,10 +71,14 @@ at::Tensor gen_aten_tensor_by(dil::tensor dil_tensor) {
     nullptr,
     /*resizeable=*/false);
   auto _tensor = at::detail::make_tensor<torch_ipex::IPEXTensorImpl>(storage_impl, at::DispatchKey::DPCPPTensorId);
-  if (tensor_sizes.size() != 1 || tensor_sizes[0] != 0) {
+  if (dil_tensor.is_public_format()) {
+    dbl::comm::sync_shape_from_dil_to_aten(_tensor, dil_tensor);
+  } else {
+    // Block format does not include stride information
+    auto tensor_sizes = dil_tensor.get_dims();
+    TORCH_INTERNAL_ASSERT(tensor_sizes.size() != 1 || tensor_sizes[0] != 0);
     _tensor.unsafeGetTensorImpl()->set_sizes_contiguous(tensor_sizes);
   }
-  TORCH_INTERNAL_ASSERT(_tensor.is_contiguous());
   TORCH_INTERNAL_ASSERT(_tensor.layout() == c10::kStrided);
   return _tensor;
 }
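
`try_gen_dil_tensor` now refreshes a public-format DIL tensor's descriptor from the ATen view before handing it to DNNL, and `gen_aten_tensor_by` pushes shape info the other way via `sync_shape_from_dil_to_aten`; blocked layouts keep their opaque descriptors and are only given contiguous sizes. A minimal model of the ATen-to-DIL direction, using hypothetical `AtenMeta`/`DilMeta` stand-ins rather than the real `at::Tensor`/`dil::tensor` types, looks like this:

#include <cstdint>
#include <iostream>
#include <vector>

// Minimal stand-ins for the two metadata views involved (hypothetical types,
// not the real dil::tensor / at::Tensor API).
struct AtenMeta { std::vector<int64_t> sizes, strides; };
struct DilMeta  { std::vector<int64_t> dims, strides; bool public_format; };

// Mirrors the direction of try_gen_dil_tensor: a plain-format DIL tensor
// adopts the ATen view's sizes and strides; a blocked tensor keeps its own
// opaque layout and must be converted (to_public) before strides make sense.
void sync_aten_to_dil(const AtenMeta& at, DilMeta& dil) {
  if (dil.public_format) {
    dil.dims = at.sizes;
    dil.strides = at.strides;
  }
  // else: blocked layout -- leave the descriptor untouched.
}

int main() {
  AtenMeta at{{3, 2, 4}, {4, 12, 1}};       // a transposed (non-contiguous) view
  DilMeta dil{{2, 3, 4}, {12, 4, 1}, true};
  sync_aten_to_dil(at, dil);
  std::cout << dil.strides[0] << '\n';      // 4 -- view metadata propagated
}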

torch_ipex/csrc/cpu/dbl/DNNLChecker.cpp

Lines changed: 1 addition & 1 deletion

@@ -56,7 +56,7 @@ bool dnnl_support_the_dimension_of(const std::vector<at::Tensor> &tensor_vec) {

 bool dnnl_tensor_has_data(const std::vector<at::Tensor> &tensor_vec) {
   for (auto it = tensor_vec.begin(); it != tensor_vec.end(); ++it)
-    if (it->data_ptr() == nullptr)
+    if (it->numel() == 0)
      return false;

   return true;
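
Testing `numel() == 0` instead of `data_ptr() == nullptr` matters because an empty tensor (one with any zero-sized dimension) can still be backed by a non-null allocation, so the pointer test lets tensors with nothing to compute slip onto the DNNL path. A small stand-alone sketch of the element count as a product of sizes:

#include <cstdint>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

// numel is the product of all dimension sizes; any zero-sized dimension
// makes the whole tensor empty even though a buffer pointer may be non-null.
int64_t numel(const std::vector<int64_t>& sizes) {
  return std::accumulate(sizes.begin(), sizes.end(), int64_t{1},
                         std::multiplies<int64_t>());
}

int main() {
  std::cout << numel({2, 0, 4}) << ' '    // 0: empty, but storage may exist
            << numel({2, 3, 4}) << '\n';  // 24
}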

torch_ipex/csrc/cpu/dil/dil/tensor.hpp

Lines changed: 36 additions & 13 deletions

@@ -81,18 +81,6 @@ class tensor : public memory {
     return static_cast<data_type>(data.data_type);
   }

-  inline dims get_strides() const {
-    DIL_ENFORCE(is_plain(), "Call to_public() before get_strides()");
-    const auto& strides = blocking_strides();
-    if (!is_grouped()) {
-      return dims(strides, strides + data.ndims);
-    } else {
-      auto ret = dims(strides + 1, strides + data.ndims);
-      ret[0] = std::min(strides[0], strides[1]);
-      return ret;
-    }
-  }
-
   /** returns true if memory descriptor is zero */
   bool is_zero() const { return data.ndims == 0; }

@@ -379,6 +367,17 @@ class tensor : public memory {
     return const_cast<dnnl_memory_desc_t&>(data).format_desc.blocking.strides;
   }

+  inline dims get_strides() const {
+    const auto& strides = blocking_strides();
+    if (!is_grouped()) {
+      return dims(strides, strides + data.ndims);
+    } else {
+      auto ret = dims(strides + 1, strides + data.ndims);
+      ret[0] = std::min(strides[0], strides[1]);
+      return ret;
+    }
+  }
+
   void set_g(dim groups) {
     auto reserved_size = sizeof(((dnnl_memory_extra_desc_t *)0)->reserved);
     auto offset = reserved_size / sizeof(dim) - 1;
@@ -449,6 +448,12 @@ class tensor : public memory {
     init(adims, adata_type, ahandle, aengine);
   }

+  // no format_tag; strides, buffer
+  tensor(const dims &adims, data_type adata_type, const dims &astrides,
+         void *ahandle, const engine &aengine = engine::cpu_engine()) {
+    init(adims, adata_type, astrides, ahandle, aengine);
+  }
+
   // no format_tag, no buffer
   tensor(const dims &adims, data_type adata_type,
          const engine &aengine = engine::cpu_engine()) {
@@ -480,6 +485,11 @@ class tensor : public memory {
     init({adims, adata_type, aformat_tag}, ahandle, aengine);
   }

+  void init(const dims &adims, data_type adata_type, const dims &astrides,
+            void *ahandle, const engine &aengine = engine::cpu_engine()) {
+    init({adims, adata_type, astrides}, ahandle, aengine);
+  }
+
   // format_tag, no buffer
   void init(const dims &adims, data_type adata_type, format_tag aformat_tag,
             const engine &aengine = engine::cpu_engine()) {
@@ -571,7 +581,20 @@ class tensor : public memory {
   /// Returns dimension vector
   inline dims get_dims() const { return get_desc().get_dims(); }

-  inline dims get_strides() const { return get_desc().get_strides(); }
+  inline dims get_strides() const {
+    DIL_ENFORCE(is_public_format(), "Call to_public() before get_strides()");
+    return get_desc().get_strides();
+  }
+
+  inline void set_dims_and_strides(const dims &adims, const dims &astrides) {
+    DIL_ENFORCE(is_public_format(), "Call to_public() before set_dims_and_strides()");
+    DIL_ENFORCE(adims.size() == astrides.size(), "Dims and strides must have the same size");
+    if (get_dims() == adims && get_strides() == astrides)
+      return;
+    auto new_desc = desc(adims, get_data_type(), astrides);
+    DIL_ENFORCE(get_size() == new_desc.get_size(), "Invalid dims and strides for the original desc");
+    set_desc(new_desc);
+  }

   /// Return element number of the param.
   /// The number is the meaning values for a tensor, instead of whole buffer.
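
`set_dims_and_strides` only swaps in a new plain descriptor when the replacement covers exactly the same buffer, which is what the `get_size()` comparison enforces. The sketch below models that footprint computation for plain layouts (a simplification of what a memory descriptor's size query reports, not the DNNL implementation): a transposed view passes the check, while growing a dimension would not:

#include <cstdint>
#include <iostream>
#include <vector>

// Bytes needed to back a strided view: offset of the last element plus one,
// times the element size (a simplified model of get_size() for plain layouts).
int64_t required_bytes(const std::vector<int64_t>& dims,
                       const std::vector<int64_t>& strides,
                       int64_t elem_size) {
  int64_t last = 0;
  for (size_t i = 0; i < dims.size(); ++i)
    last += (dims[i] - 1) * strides[i];
  return (last + 1) * elem_size;
}

int main() {
  // {2,3,4} contiguous and its {3,2,4} transpose cover the same 96 bytes,
  // so swapping the descriptor's dims/strides is safe; growing a dim is not.
  std::cout << required_bytes({2, 3, 4}, {12, 4, 1}, 4) << ' '    // 96
            << required_bytes({3, 2, 4}, {4, 12, 1}, 4) << ' '    // 96
            << required_bytes({4, 3, 4}, {12, 4, 1}, 4) << '\n';  // 192
}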
