@@ -89,7 +89,7 @@ class tensor : public memory {
89
89
inline dim_t nelems (bool with_padding = false ) const {
90
90
if (is_zero ()) return 0 ;
91
91
auto dims = with_padding ? data.padded_dims : data.dims ;
92
- return std::accumulate (dims, dims + data.ndims , 1 ,
92
+ return std::accumulate (dims, dims + data.ndims , ( dim_t ) 1 ,
93
93
std::multiplies<dim_t >());
94
94
}
95
95
@@ -155,7 +155,7 @@ class tensor : public memory {
155
155
// compute compatible block_dims with v0.x
156
156
dims block_dims (data.ndims , 1 );
157
157
for (auto i = 0 ; i < blk.inner_nblks ; i++) {
158
- block_dims[blk.inner_idxs [i]] *= blk.inner_blks [i];
158
+ block_dims[blk.inner_idxs [i]] *= blk.inner_blks [i];
159
159
}
160
160
for (auto i = 0 ; i < data.ndims ; i++) {
161
161
if (data.dims [i] < block_dims[i]) continue ;
@@ -749,7 +749,7 @@ class tensor : public memory {
749
749
data_type dst_type = data_type::f32 ) const {
750
750
auto dst_desc = get_desc ();
751
751
752
- // If we get a non-plain blocking format, say `Acdb16A`, we may not be able
752
+ // If we get a non-plain blocking format, say `Acdb16A`, we may not be able
753
753
// to recover it to its "unblocked" format `acdb`. Instead, we will convert
754
754
// it to its default format `abcd` based on its dimensions.
755
755
if (!is_public_format ()) {
@@ -828,7 +828,7 @@ class tensor : public memory {
828
828
// TODO(xpz): support per-channel dequantize
829
829
DIL_ENFORCE (get_scale ().size () == 1 , " Incorrect scale size" );
830
830
dst.feed_from (*this );
831
- return dst;
831
+ return dst;
832
832
}
833
833
834
834
// reorder src to part of this tensor
@@ -875,9 +875,9 @@ class tensor : public memory {
875
875
// / Return whether the param has a scale
876
876
bool has_scale () const { return scale_ != nullptr && !scale_->empty (); }
877
877
878
- // / Return whether the param has a zero_point
878
+ // / Return whether the param has a zero_point
879
879
bool has_zero_point () const { return zero_point_ != nullptr && !zero_point_->empty (); }
880
-
880
+
881
881
// / Return the zero_point of this param.
882
882
const std::vector<int32_t > &get_zero_point () const { return *zero_point_.get (); }
883
883
0 commit comments