Skip to content

Commit 22a1302

Browse files
authored
[mlir][vector] Add more tests for ConvertVectorToLLVM (1/n) (#101936)
Adds tests with scalable vectors for the Vector-To-LLVM conversion pass. Covers the following Ops: * vector.bitcast * vector.broadcast Note, this has uncovered some missing logic in `BroadcastOpLowering`. This PR fixes the most basic cases where the scalable flags were dropped and the generated code was incorrect. Also, the conditions in `vector::isBroadcastableTo` are relaxed to allow cases like this: ```mlir %0 = vector.broadcast %arg0 : vector<1xf32> to vector<[4]xf32> ``` The `BroadcastOpLowering` pattern is effectively disabled for scalable vectors in more complex cases where an SCF loop would be required to loop over the scalable dims, e.g.: ```mlir %0 = vector.broadcast %arg0 : vector<[4]x1x2xf32> to vector<[4]x3x2xf32> ``` These cases are marked as "Stretch not at start" in the code. In those cases, support for scalable vectors is left as a TODO.
1 parent 37a94b7 commit 22a1302

File tree

3 files changed

+246
-2
lines changed

3 files changed

+246
-2
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2405,7 +2405,10 @@ BroadcastableToResult mlir::vector::isBroadcastableTo(
24052405
bool srcDimScalableFlag = srcVectorType.getScalableDims()[dimIdx];
24062406
bool dstDimScalableFlag = dstVectorType.getScalableDims()[lead + dimIdx];
24072407
if ((srcDim == 1 && srcDimScalableFlag && dstDim != 1) ||
2408-
(srcDimScalableFlag != dstDimScalableFlag))
2408+
// 1 -> [N] is fine, everything else should be rejected when mixing
2409+
// fixed-width and scalable dims
2410+
(srcDimScalableFlag != dstDimScalableFlag &&
2411+
(srcDim != 1 || srcDimScalableFlag)))
24092412
foundMismatchingDims = true;
24102413

24112414
if (foundMismatchingDims) {

mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,8 @@ class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
125125
// ..
126126
// %x = [%a,%b,%c,%d]
127127
VectorType resType =
128-
VectorType::get(dstType.getShape().drop_front(), eltType);
128+
VectorType::get(dstType.getShape().drop_front(), eltType,
129+
dstType.getScalableDims().drop_front());
129130
Value result = rewriter.create<arith::ConstantOp>(
130131
loc, dstType, rewriter.getZeroAttr(dstType));
131132
if (m == 0) {
@@ -136,6 +137,10 @@ class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
136137
result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
137138
} else {
138139
// Stetch not at start.
140+
if (dstType.getScalableDims()[0]) {
141+
// TODO: For scalable vectors we should emit an scf.for loop.
142+
return failure();
143+
}
139144
for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d) {
140145
Value ext = rewriter.create<vector::ExtractOp>(loc, op.getSource(), d);
141146
Value bcst = rewriter.create<vector::BroadcastOp>(loc, resType, ext);

0 commit comments

Comments
 (0)