Skip to content

Commit 6cac792

Browse files
authored
[mlir][Vector] Improve vector.mask verifier (#139823)
This PR improves the `vector.mask` verifier to make sure it's not applying masking semantics to operations defined outside of the `vector.mask` region. Documentation is updated to emphasize that and make it clearer, even though it already stated that. As part of this change, the logic that ensures that a terminator is present in the region mask has been simplified to make it less surprising to the user when a `vector.yield` is explicitly provided in the IR.
1 parent dfc65ef commit 6cac792

File tree

3 files changed

+59
-18
lines changed

3 files changed

+59
-18
lines changed

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2482,8 +2482,13 @@ def Vector_MaskOp : Vector_Op<"mask", [
24822482
masked. Values used within the region are captured from above. Only one
24832483
*maskable* operation can be masked with a `vector.mask` operation at a time.
24842484
An operation is *maskable* if it implements the `MaskableOpInterface`. The
2485-
terminator yields all results of the maskable operation to the result of
2486-
this operation.
2485+
terminator yields all results from the maskable operation to the result of
2486+
this operation. No other values are allowed to be yielded.
2487+
2488+
An empty `vector.mask` operation is currently legal to enable optimizations
2489+
across the `vector.mask` region. However, this might change in the future
2490+
once vector transformations gain better support for `vector.mask`.
2491+
TODO: Consider making empty `vector.mask` illegal.
24872492

24882493
The vector mask argument holds a bit for each vector lane and determines
24892494
which vector lanes should execute the maskable operation and which ones

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 24 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -6550,29 +6550,33 @@ void mlir::vector::MaskOp::print(OpAsmPrinter &p) {
65506550
}
65516551

65526552
void MaskOp::ensureTerminator(Region &region, Builder &builder, Location loc) {
6553-
OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl<
6554-
MaskOp>::ensureTerminator(region, builder, loc);
6555-
// Keep the default yield terminator if the number of masked operations is not
6556-
// the expected. This case will trigger a verification failure.
6553+
// 1. For an empty `vector.mask`, create a default terminator.
6554+
if (region.empty() || region.front().empty()) {
6555+
OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl<
6556+
MaskOp>::ensureTerminator(region, builder, loc);
6557+
return;
6558+
}
6559+
6560+
// 2. For a non-empty `vector.mask` with an explicit terminator, do nothing.
65576561
Block &block = region.front();
6558-
if (block.getOperations().size() != 2)
6562+
if (isa<vector::YieldOp>(block.back()))
65596563
return;
65606564

6561-
// Replace default yield terminator with a new one that returns the results
6562-
// from the masked operation.
6563-
OpBuilder opBuilder(builder.getContext());
6564-
Operation *maskedOp = &block.front();
6565-
Operation *oldYieldOp = &block.back();
6566-
assert(isa<vector::YieldOp>(oldYieldOp) && "Expected vector::YieldOp");
6565+
// 3. For a non-empty `vector.mask` without an explicit terminator:
65676566

6568-
// Empty vector.mask op.
6569-
if (maskedOp == oldYieldOp)
6567+
// Create default terminator if the number of masked operations is not
6568+
// one. This case will trigger a verification failure.
6569+
if (block.getOperations().size() != 1) {
6570+
OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl<
6571+
MaskOp>::ensureTerminator(region, builder, loc);
65706572
return;
6573+
}
65716574

6572-
opBuilder.setInsertionPoint(oldYieldOp);
6575+
// Create a terminator that yields the results from the masked operation.
6576+
OpBuilder opBuilder(builder.getContext());
6577+
Operation *maskedOp = &block.front();
6578+
opBuilder.setInsertionPointToEnd(&block);
65736579
opBuilder.create<vector::YieldOp>(loc, maskedOp->getResults());
6574-
oldYieldOp->dropAllReferences();
6575-
oldYieldOp->erase();
65766580
}
65776581

65786582
LogicalResult MaskOp::verify() {
@@ -6607,6 +6611,10 @@ LogicalResult MaskOp::verify() {
66076611
return emitOpError("expects number of results to match maskable operation "
66086612
"number of results");
66096613

6614+
if (!llvm::equal(maskableOp->getResults(), terminator.getOperands()))
6615+
return emitOpError("expects all the results from the MaskableOpInterface "
6616+
"to match all the values returned by the terminator");
6617+
66106618
if (!llvm::equal(maskableOp->getResultTypes(), getResultTypes()))
66116619
return emitOpError(
66126620
"expects result type to match maskable operation result type");

mlir/test/Dialect/Vector/invalid.mlir

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1756,6 +1756,20 @@ func.func @vector_mask_empty_passthru_no_return_type(%mask : vector<8xi1>,
17561756

17571757
// -----
17581758

1759+
func.func @vector_mask_non_empty_external_return(%t: tensor<?xf32>, %idx: index,
1760+
%m: vector<16xi1>, %ext: vector<16xf32>) -> vector<16xf32> {
1761+
%ft0 = arith.constant 0.0 : f32
1762+
// expected-error@+1 {{'vector.mask' op expects all the results from the MaskableOpInterface to match all the values returned by the terminator}}
1763+
%0 = vector.mask %m {
1764+
%1 =vector.transfer_read %t[%idx], %ft0 : tensor<?xf32>, vector<16xf32>
1765+
vector.yield %ext : vector<16xf32>
1766+
} : vector<16xi1> -> vector<16xf32>
1767+
1768+
return %0 : vector<16xf32>
1769+
}
1770+
1771+
// -----
1772+
17591773
func.func @vector_mask_empty_passthru_empty_return_type(%mask : vector<8xi1>,
17601774
%passthru : vector<8xi32>) {
17611775
// expected-error@+1 {{'vector.mask' expects a result if passthru operand is provided}}
@@ -1765,6 +1779,20 @@ func.func @vector_mask_empty_passthru_empty_return_type(%mask : vector<8xi1>,
17651779

17661780
// -----
17671781

1782+
func.func @vector_mask_non_empty_mixed_return(%t: tensor<?xf32>, %idx: index,
1783+
%m: vector<16xi1>, %ext: vector<16xf32>) -> (vector<16xf32>, vector<16xf32>) {
1784+
%ft0 = arith.constant 0.0 : f32
1785+
// expected-error@+1 {{'vector.mask' op expects number of results to match maskable operation number of results}}
1786+
%0:2 = vector.mask %m {
1787+
%1 =vector.transfer_read %t[%idx], %ft0 : tensor<?xf32>, vector<16xf32>
1788+
vector.yield %1, %ext : vector<16xf32>, vector<16xf32>
1789+
} : vector<16xi1> -> (vector<16xf32>, vector<16xf32>)
1790+
1791+
return %0#0, %0#1 : vector<16xf32>, vector<16xf32>
1792+
}
1793+
1794+
// -----
1795+
17681796
func.func @vector_scalable_insert_unaligned(%subv: vector<4xi32>, %vec: vector<[16]xi32>) {
17691797
// expected-error@+1 {{op failed to verify that position is a multiple of the source length.}}
17701798
%0 = vector.scalable.insert %subv, %vec[2] : vector<4xi32> into vector<[16]xi32>

0 commit comments

Comments
 (0)