@@ -186,6 +186,7 @@ class SparseTensorStorageBase {
186
186
187
187
protected:
188
188
const MapRef map; // non-owning pointers into dim2lvl/lvl2dim vectors
189
+ const bool allDense;
189
190
};
190
191
191
192
// / A memory-resident sparse tensor using a storage scheme based on
@@ -293,8 +294,6 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
293
294
// / Partially specialize lexicographical insertions based on template types.
294
295
void lexInsert (const uint64_t *lvlCoords, V val) final {
295
296
assert (lvlCoords);
296
- bool allDense = std::all_of (getLvlTypes ().begin (), getLvlTypes ().end (),
297
- [](LevelType lt) { return isDenseLT (lt); });
298
297
if (allDense) {
299
298
uint64_t lvlRank = getLvlRank ();
300
299
uint64_t valIdx = 0 ;
@@ -363,10 +362,12 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
363
362
364
363
// / Finalizes lexicographic insertions.
365
364
void endLexInsert () final {
366
- if (values.empty ())
367
- finalizeSegment (0 );
368
- else
369
- endPath (0 );
365
+ if (!allDense) {
366
+ if (values.empty ())
367
+ finalizeSegment (0 );
368
+ else
369
+ endPath (0 );
370
+ }
370
371
}
371
372
372
373
// / Allocates a new COO object and initializes it with the contents.
@@ -705,31 +706,26 @@ SparseTensorStorage<P, C, V>::SparseTensorStorage(
705
706
// we reserve position/coordinate space based on all previous dense
706
707
// levels, which works well up to first sparse level; but we should
707
708
// really use nnz and dense/sparse distribution.
708
- bool allDense = true ;
709
709
uint64_t sz = 1 ;
710
710
for (uint64_t l = 0 ; l < lvlRank; l++) {
711
711
if (isCompressedLvl (l)) {
712
712
positions[l].reserve (sz + 1 );
713
713
positions[l].push_back (0 );
714
714
coordinates[l].reserve (sz);
715
715
sz = 1 ;
716
- allDense = false ;
717
716
} else if (isLooseCompressedLvl (l)) {
718
717
positions[l].reserve (2 * sz + 1 ); // last one unused
719
718
positions[l].push_back (0 );
720
719
coordinates[l].reserve (sz);
721
720
sz = 1 ;
722
- allDense = false ;
723
721
} else if (isSingletonLvl (l)) {
724
722
coordinates[l].reserve (sz);
725
723
sz = 1 ;
726
- allDense = false ;
727
724
} else if (is2OutOf4Lvl (l)) {
728
- assert (allDense && l == lvlRank - 1 && " unexpected 2:4 usage" );
725
+ assert (l == lvlRank - 1 && " unexpected 2:4 usage" );
729
726
sz = detail::checkedMul (sz, lvlSizes[l]) / 2 ;
730
727
coordinates[l].reserve (sz);
731
728
values.reserve (sz);
732
- allDense = false ;
733
729
} else { // Dense level.
734
730
assert (isDenseLvl (l));
735
731
sz = detail::checkedMul (sz, lvlSizes[l]);
0 commit comments