@@ -29,22 +29,23 @@ class AffineDimFinder : public AffineExprVisitor<AffineDimFinder> {
29
29
explicit AffineDimFinder (ArrayRef<utils::IteratorType> itTypes)
30
30
: iterTypes(itTypes) {}
31
31
32
- // Override method from AffineExprVisitor.
32
+ // / Overrides the visit method from AffineExprVisitor.
33
33
void visitDimExpr (AffineDimExpr expr) {
34
34
if (pickedDim == nullptr || pickIterType == iterTypes[expr.getPosition ()])
35
35
pickedDim = expr;
36
36
}
37
37
38
- // / Set the desired iterator type that we want to pick.
38
+ // / Sets the desired iterator type that we want to pick.
39
39
void setPickedIterType (utils::IteratorType iterType) {
40
40
pickIterType = iterType;
41
41
}
42
42
43
- // / Get the desired AffineDimExpr.
43
+ // / Gets the desired AffineDimExpr.
44
44
AffineDimExpr getDimExpr () const {
45
45
return llvm::cast<AffineDimExpr>(pickedDim);
46
46
}
47
47
48
+ // / Walks the graph in post order to find dim expr.
48
49
void walkPostOrder (AffineExpr expr) {
49
50
pickedDim = nullptr ;
50
51
AffineExprVisitor<AffineDimFinder>::walkPostOrder (expr);
@@ -55,11 +56,11 @@ class AffineDimFinder : public AffineExprVisitor<AffineDimFinder> {
55
56
AffineExpr pickedDim;
56
57
// / The iterator type that we want.
57
58
utils::IteratorType pickIterType;
58
- // / The mapping between dim=> iterator type .
59
+ // / The mapping between levels and iterator types .
59
60
ArrayRef<utils::IteratorType> iterTypes;
60
61
};
61
62
62
- // Flattens an affine expression into a list of AffineDimExprs.
63
+ // / Flattens an affine expression into a list of AffineDimExprs.
63
64
struct AffineDimCollector : public AffineExprVisitor <AffineDimCollector> {
64
65
// Overrides method from AffineExprVisitor.
65
66
void visitDimExpr (AffineDimExpr expr) { dims.push_back (expr); }
@@ -97,8 +98,8 @@ AffineMap IterationGraphSorter::topoSort() {
97
98
98
99
SmallVector<unsigned > loopOrder;
99
100
while (!redIt.empty () || !parIt.empty ()) {
100
- // We always prefer parallel loop over reduction loop because putting
101
- // reduction loop early might make the loop sequence inadmissible.
101
+ // We always prefer a parallel loop over a reduction loop because putting
102
+ // a reduction loop early might make the loop sequence inadmissible.
102
103
auto &it = !parIt.empty () ? parIt : redIt;
103
104
auto src = it.back ();
104
105
loopOrder.push_back (src);
@@ -114,6 +115,7 @@ AffineMap IterationGraphSorter::topoSort() {
114
115
}
115
116
}
116
117
118
+ // Return the topological sort on success.
117
119
if (loopOrder.size () == numLoops)
118
120
return AffineMap::getPermutationMap (loopOrder, out.getContext ());
119
121
@@ -164,13 +166,14 @@ IterationGraphSorter::IterationGraphSorter(
164
166
}
165
167
166
168
AffineMap IterationGraphSorter::sort (SortMask mask, Value ignored) {
167
- // Reset the interation graph.
169
+ // Reset the adjacency matrix that represents the iteration graph.
168
170
for (auto &row : itGraph)
169
171
std::fill (row.begin (), row.end (), false );
170
172
171
- // Reset cached in-degree.
173
+ // Reset in-degree.
172
174
std::fill (inDegree.begin (), inDegree.end (), 0 );
173
175
176
+ // Add the constraints for the loop to level map.
174
177
for (auto [in, map] : llvm::zip (ins, loop2InsLvl)) {
175
178
// Get map and encoding.
176
179
const auto enc = getSparseTensorEncoding (in.getType ());
@@ -180,11 +183,12 @@ AffineMap IterationGraphSorter::sort(SortMask mask, Value ignored) {
180
183
addConstraints (in, map);
181
184
}
182
185
183
- // Get map and encoding .
186
+ // Add the constraints for the output map .
184
187
const auto enc = getSparseTensorEncoding (out.getType ());
185
188
if ((enc || includesDenseOutput (mask)) && out != ignored)
186
189
addConstraints (out, loop2OutLvl);
187
190
191
+ // Return the topological sort (empty for cyclic).
188
192
return topoSort ();
189
193
}
190
194
@@ -196,6 +200,7 @@ void IterationGraphSorter::addConstraints(Value t, AffineMap loop2LvlMap) {
196
200
}
197
201
};
198
202
203
+ // Set up a reduction finder.
199
204
AffineDimFinder finder (iterTypes);
200
205
finder.setPickedIterType (utils::IteratorType::reduction);
201
206
0 commit comments