Skip to content

Commit 6fb7c2d

Browse files
authored
[mlir][sparse] bug fix on all-dense lex insertion (#73987)
Fixes a bug that appended values after insertion completed. Also slight optimization by avoiding all-Dense computation for every lexInsert call
1 parent b80b5f1 commit 6fb7c2d

File tree

2 files changed

+19
-15
lines changed

2 files changed

+19
-15
lines changed

mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,7 @@ class SparseTensorStorageBase {
186186

187187
protected:
188188
const MapRef map; // non-owning pointers into dim2lvl/lvl2dim vectors
189+
const bool allDense;
189190
};
190191

191192
/// A memory-resident sparse tensor using a storage scheme based on
@@ -293,8 +294,6 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
293294
/// Partially specialize lexicographical insertions based on template types.
294295
void lexInsert(const uint64_t *lvlCoords, V val) final {
295296
assert(lvlCoords);
296-
bool allDense = std::all_of(getLvlTypes().begin(), getLvlTypes().end(),
297-
[](LevelType lt) { return isDenseLT(lt); });
298297
if (allDense) {
299298
uint64_t lvlRank = getLvlRank();
300299
uint64_t valIdx = 0;
@@ -363,10 +362,12 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
363362

364363
/// Finalizes lexicographic insertions.
365364
void endLexInsert() final {
366-
if (values.empty())
367-
finalizeSegment(0);
368-
else
369-
endPath(0);
365+
if (!allDense) {
366+
if (values.empty())
367+
finalizeSegment(0);
368+
else
369+
endPath(0);
370+
}
370371
}
371372

372373
/// Allocates a new COO object and initializes it with the contents.
@@ -705,31 +706,26 @@ SparseTensorStorage<P, C, V>::SparseTensorStorage(
705706
// we reserve position/coordinate space based on all previous dense
706707
// levels, which works well up to first sparse level; but we should
707708
// really use nnz and dense/sparse distribution.
708-
bool allDense = true;
709709
uint64_t sz = 1;
710710
for (uint64_t l = 0; l < lvlRank; l++) {
711711
if (isCompressedLvl(l)) {
712712
positions[l].reserve(sz + 1);
713713
positions[l].push_back(0);
714714
coordinates[l].reserve(sz);
715715
sz = 1;
716-
allDense = false;
717716
} else if (isLooseCompressedLvl(l)) {
718717
positions[l].reserve(2 * sz + 1); // last one unused
719718
positions[l].push_back(0);
720719
coordinates[l].reserve(sz);
721720
sz = 1;
722-
allDense = false;
723721
} else if (isSingletonLvl(l)) {
724722
coordinates[l].reserve(sz);
725723
sz = 1;
726-
allDense = false;
727724
} else if (is2OutOf4Lvl(l)) {
728-
assert(allDense && l == lvlRank - 1 && "unexpected 2:4 usage");
725+
assert(l == lvlRank - 1 && "unexpected 2:4 usage");
729726
sz = detail::checkedMul(sz, lvlSizes[l]) / 2;
730727
coordinates[l].reserve(sz);
731728
values.reserve(sz);
732-
allDense = false;
733729
} else { // Dense level.
734730
assert(isDenseLvl(l));
735731
sz = detail::checkedMul(sz, lvlSizes[l]);

mlir/lib/ExecutionEngine/SparseTensor/Storage.cpp

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,13 @@
1717

1818
using namespace mlir::sparse_tensor;
1919

20+
static inline bool isAllDense(uint64_t lvlRank, const LevelType *lvlTypes) {
21+
for (uint64_t l = 0; l < lvlRank; l++)
22+
if (!isDenseLT(lvlTypes[l]))
23+
return false;
24+
return true;
25+
}
26+
2027
SparseTensorStorageBase::SparseTensorStorageBase( // NOLINT
2128
uint64_t dimRank, const uint64_t *dimSizes, uint64_t lvlRank,
2229
const uint64_t *lvlSizes, const LevelType *lvlTypes,
@@ -26,15 +33,16 @@ SparseTensorStorageBase::SparseTensorStorageBase( // NOLINT
2633
lvlTypes(lvlTypes, lvlTypes + lvlRank),
2734
dim2lvlVec(dim2lvl, dim2lvl + lvlRank),
2835
lvl2dimVec(lvl2dim, lvl2dim + dimRank),
29-
map(dimRank, lvlRank, dim2lvlVec.data(), lvl2dimVec.data()) {
36+
map(dimRank, lvlRank, dim2lvlVec.data(), lvl2dimVec.data()),
37+
allDense(isAllDense(lvlRank, lvlTypes)) {
3038
assert(dimSizes && lvlSizes && lvlTypes && dim2lvl && lvl2dim);
3139
// Validate dim-indexed parameters.
3240
assert(dimRank > 0 && "Trivial shape is unsupported");
33-
for (uint64_t d = 0; d < dimRank; ++d)
41+
for (uint64_t d = 0; d < dimRank; d++)
3442
assert(dimSizes[d] > 0 && "Dimension size zero has trivial storage");
3543
// Validate lvl-indexed parameters.
3644
assert(lvlRank > 0 && "Trivial shape is unsupported");
37-
for (uint64_t l = 0; l < lvlRank; ++l) {
45+
for (uint64_t l = 0; l < lvlRank; l++) {
3846
assert(lvlSizes[l] > 0 && "Level size zero has trivial storage");
3947
assert(isDenseLvl(l) || isCompressedLvl(l) || isLooseCompressedLvl(l) ||
4048
isSingletonLvl(l) || is2OutOf4Lvl(l));

0 commit comments

Comments
 (0)