Skip to content

Commit 1d432c2

Browse files
authored
Fix over-flow issue when calculate dil tensor elements (#86)
If the input tensor is with large number of elements, the nelemnts api will be over flow
1 parent 3fe5370 commit 1d432c2

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

torch_ipex/csrc/cpu/dil/dil/tensor.hpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ class tensor : public memory {
8989
inline dim_t nelems(bool with_padding = false) const {
9090
if (is_zero()) return 0;
9191
auto dims = with_padding ? data.padded_dims : data.dims;
92-
return std::accumulate(dims, dims + data.ndims, 1,
92+
return std::accumulate(dims, dims + data.ndims, (dim_t)1,
9393
std::multiplies<dim_t>());
9494
}
9595

@@ -155,7 +155,7 @@ class tensor : public memory {
155155
// compute compatible block_dims with v0.x
156156
dims block_dims(data.ndims, 1);
157157
for (auto i = 0; i < blk.inner_nblks; i++) {
158-
block_dims[blk.inner_idxs[i]] *= blk.inner_blks[i];
158+
block_dims[blk.inner_idxs[i]] *= blk.inner_blks[i];
159159
}
160160
for (auto i = 0; i < data.ndims; i++) {
161161
if (data.dims[i] < block_dims[i]) continue;
@@ -749,7 +749,7 @@ class tensor : public memory {
749749
data_type dst_type = data_type::f32) const {
750750
auto dst_desc = get_desc();
751751

752-
// If we get a non-plain blocking format, say `Acdb16A`, we may not be able
752+
// If we get a non-plain blocking format, say `Acdb16A`, we may not be able
753753
// to recover it to its "unblocked" format `acdb`. Instead, we will convert
754754
// it to its default format `abcd` based on its dimensions.
755755
if (!is_public_format()) {
@@ -828,7 +828,7 @@ class tensor : public memory {
828828
// TODO(xpz): support per-channel dequantize
829829
DIL_ENFORCE(get_scale().size() == 1, "Incorrect scale size");
830830
dst.feed_from(*this);
831-
return dst;
831+
return dst;
832832
}
833833

834834
// reorder src to part of this tensor
@@ -875,9 +875,9 @@ class tensor : public memory {
875875
/// Return whether the param has a scale
876876
bool has_scale() const { return scale_ != nullptr && !scale_->empty(); }
877877

878-
/// Return whether the param has a zero_point
878+
/// Return whether the param has a zero_point
879879
bool has_zero_point() const { return zero_point_ != nullptr && !zero_point_->empty(); }
880-
880+
881881
/// Return the zero_point of this param.
882882
const std::vector<int32_t> &get_zero_point() const { return *zero_point_.get(); }
883883

0 commit comments

Comments
 (0)