Skip to content

Commit ff47102

Browse files
committed
Merge remote-tracking branch 'ipex-github/master'
Conflicts: torch_ipex/csrc/cpu/dbl/Common.cpp
2 parents ca17113 + f7fbd8a commit ff47102

File tree

2 files changed

+27
-16
lines changed

2 files changed

+27
-16
lines changed

torch_ipex/csrc/cpu/dbl/Common.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,8 @@ at::Tensor dil_tensor_to_dense(const at::Tensor& tensor) {
3737
dil::tensor try_gen_dil_tensor(const at::Tensor &input) {
3838
if (cpu::ShadeDataContext::isDilTensor(input)) {
3939
auto dil_tensor = cpu::ShadeDataContext::getDilTensor(input);
40-
if (!check_aten_dil_shape_info(input, dil_tensor)) {
41-
//TODO(Eikan): Pinzhen will fix the issue here
42-
TORCH_INTERNAL_ASSERT(false);
40+
if (dil_tensor.is_public_format()) {
41+
dil_tensor.set_dims_and_strides(input.sizes().vec(), input.strides().vec());
4342
}
4443
return dil_tensor;
4544
} else {

torch_ipex/csrc/cpu/dil/dil/tensor.hpp

Lines changed: 25 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -81,18 +81,6 @@ class tensor : public memory {
8181
return static_cast<data_type>(data.data_type);
8282
}
8383

84-
inline dims get_strides() const {
85-
DIL_ENFORCE(is_plain(), "Call to_public() before get_strides()");
86-
const auto& strides = blocking_strides();
87-
if (!is_grouped()) {
88-
return dims(strides, strides + data.ndims);
89-
} else {
90-
auto ret = dims(strides + 1, strides + data.ndims);
91-
ret[0] = std::min(strides[0], strides[1]);
92-
return ret;
93-
}
94-
}
95-
9684
/** returns true if memory descriptor is zero */
9785
bool is_zero() const { return data.ndims == 0; }
9886

@@ -379,6 +367,17 @@ class tensor : public memory {
379367
return const_cast<dnnl_memory_desc_t&>(data).format_desc.blocking.strides;
380368
}
381369

370+
inline dims get_strides() const {
371+
const auto& strides = blocking_strides();
372+
if (!is_grouped()) {
373+
return dims(strides, strides + data.ndims);
374+
} else {
375+
auto ret = dims(strides + 1, strides + data.ndims);
376+
ret[0] = std::min(strides[0], strides[1]);
377+
return ret;
378+
}
379+
}
380+
382381
void set_g(dim groups) {
383382
auto reserved_size = sizeof(((dnnl_memory_extra_desc_t *)0)->reserved);
384383
auto offset = reserved_size / sizeof(dim) - 1;
@@ -582,7 +581,20 @@ class tensor : public memory {
582581
/// Returns dimension vector
583582
inline dims get_dims() const { return get_desc().get_dims(); }
584583

585-
inline dims get_strides() const { return get_desc().get_strides(); }
584+
inline dims get_strides() const {
585+
DIL_ENFORCE(is_public_format(), "Call to_public() before get_strides()");
586+
return get_desc().get_strides();
587+
}
588+
589+
inline void set_dims_and_strides(const dims &adims, const dims &astrides) {
590+
DIL_ENFORCE(is_public_format(), "Call to_public() before set_dims_and_strides()");
591+
DIL_ENFORCE(adims.size() == astrides.size(), "Dims and strides must have the same size");
592+
if (get_dims() == adims && get_strides() == astrides)
593+
return;
594+
auto new_desc = desc(adims, get_data_type(), astrides);
595+
DIL_ENFORCE(get_size() == new_desc.get_size(), "Invalid dims and strides for the original desc");
596+
set_desc(new_desc);
597+
}
586598

587599
/// Return element number of the param.
588600
/// The number is the meaning values for a tensor, instead of whole buffer.

0 commit comments

Comments
 (0)