Skip to content

Commit a310c3b

Browse files
committed
add dil api for setting strides
1 parent 51c8b48 commit a310c3b

File tree

1 file changed

+25
-13
lines changed

1 file changed

+25
-13
lines changed

torch_ipex/csrc/cpu/dil/dil/tensor.hpp

Lines changed: 25 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -81,18 +81,6 @@ class tensor : public memory {
8181
return static_cast<data_type>(data.data_type);
8282
}
8383

84-
inline dims get_strides() const {
85-
DIL_ENFORCE(is_plain(), "Call to_public() before get_strides()");
86-
const auto& strides = blocking_strides();
87-
if (!is_grouped()) {
88-
return dims(strides, strides + data.ndims);
89-
} else {
90-
auto ret = dims(strides + 1, strides + data.ndims);
91-
ret[0] = std::min(strides[0], strides[1]);
92-
return ret;
93-
}
94-
}
95-
9684
/** returns true if memory descriptor is zero */
9785
bool is_zero() const { return data.ndims == 0; }
9886

@@ -379,6 +367,17 @@ class tensor : public memory {
379367
return const_cast<dnnl_memory_desc_t&>(data).format_desc.blocking.strides;
380368
}
381369

370+
inline dims get_strides() const {
371+
const auto& strides = blocking_strides();
372+
if (!is_grouped()) {
373+
return dims(strides, strides + data.ndims);
374+
} else {
375+
auto ret = dims(strides + 1, strides + data.ndims);
376+
ret[0] = std::min(strides[0], strides[1]);
377+
return ret;
378+
}
379+
}
380+
382381
void set_g(dim groups) {
383382
auto reserved_size = sizeof(((dnnl_memory_extra_desc_t *)0)->reserved);
384383
auto offset = reserved_size / sizeof(dim) - 1;
@@ -582,7 +581,20 @@ class tensor : public memory {
582581
/// Returns dimension vector
583582
inline dims get_dims() const { return get_desc().get_dims(); }
584583

585-
inline dims get_strides() const { return get_desc().get_strides(); }
584+
inline dims get_strides() const {
585+
DIL_ENFORCE(is_public_format(), "Call to_public() before get_strides()");
586+
return get_desc().get_strides();
587+
}
588+
589+
inline void set_dims_and_strides(const dims &adims, const dims &astrides) {
590+
DIL_ENFORCE(is_public_format(), "Call to_public() before set_dims_and_strides()");
591+
DIL_ENFORCE(adims.size() == astrides.size(), "Dims and strides must have the same size");
592+
if (get_dims() == adims && get_strides() == astrides)
593+
return;
594+
auto new_desc = desc(adims, get_data_type(), astrides);
595+
DIL_ENFORCE(get_size() == new_desc.get_size(), "Invalid dims and strides for the original desc");
596+
set_desc(new_desc);
597+
}
586598

587599
/// Return element number of the param.
588600
/// The number is the meaning values for a tensor, instead of whole buffer.

0 commit comments

Comments
 (0)