Skip to content

Commit 7b34979

Browse files
authored
[LLGA] fix set_size of LlgaTensorImpl (#157)
* [LLGA] fix set_size * [LLGA] do not support set_size for LlgaTensorImpl * fix clang format
1 parent 6595260 commit 7b34979

File tree

2 files changed

+9
-4
lines changed

2 files changed

+9
-4
lines changed

torch_ipex/csrc/LlgaTensorImpl.cpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,9 @@ LlgaTensorImpl::LlgaTensorImpl(
1313
c10::DispatchKeySet(DispatchKey::MkldnnCPU),
1414
data_type),
1515
desc_(desc) {
16-
for (int64_t i = 0; i < desc.sizes().size(); i++) {
17-
c10::TensorImpl::set_size(i, desc.sizes()[i]);
18-
}
19-
}
16+
sizes_and_strides_.set_sizes(desc.sizes());
17+
refresh_numel();
18+
}
2019

2120
// The following are publically exposed as methods of Tensor
2221
IntArrayRef LlgaTensorImpl::strides() const {
@@ -35,6 +34,11 @@ int64_t LlgaTensorImpl::storage_offset() const {
3534
TORCH_CHECK(false, "Cannot access the storage_offset() of LlgaTensorImpl");
3635
}
3736

37+
// The following are some internal inherited methods that we do not support.
38+
// They should never get called.
39+
void LlgaTensorImpl::set_size(int64_t dim, int64_t new_size) {
40+
TORCH_INTERNAL_ASSERT(false, "Cannot set_size for LlgaTensorImpl");
41+
}
3842
void LlgaTensorImpl::set_stride(int64_t dim, int64_t new_stride) {
3943
TORCH_INTERNAL_ASSERT(false, "Cannot set_stride for LlgaTensorImpl");
4044
}

torch_ipex/csrc/LlgaTensorImpl.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,7 @@ struct TORCH_API LlgaTensorImpl : public c10::TensorImpl {
197197
at::MemoryFormat::Contiguous) const override;
198198
IntArrayRef strides() const override;
199199
int64_t stride(int64_t d) const override;
200+
void set_size(int64_t dim, int64_t new_size) override;
200201
void set_stride(int64_t dim, int64_t new_stride) override;
201202
void set_storage_offset(int64_t storage_offset) override;
202203
bool has_storage() const override;

0 commit comments

Comments
 (0)