Skip to content

Commit 2045238

Browse files
[mlir][tensor] Add runtime verification for cast/dim/extract/insert/extract_slice
1 parent 4b97c6e commit 2045238

File tree

8 files changed

+462
-0
lines changed

8 files changed

+462
-0
lines changed
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
//===- RuntimeOpVerification.h - Op Verification ----------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_DIALECT_TENSOR_RUNTIMEOPVERIFICATION_H
10+
#define MLIR_DIALECT_TENSOR_RUNTIMEOPVERIFICATION_H
11+
12+
namespace mlir {
13+
class DialectRegistry;
14+
15+
namespace tensor {
16+
/// Registers with `registry` the external models that implement
/// RuntimeVerifiableOpInterface for tensor dialect ops.
void registerRuntimeVerifiableOpInterfaceExternalModels(
17+
DialectRegistry &registry);
18+
} // namespace tensor
19+
} // namespace mlir
20+
21+
#endif // MLIR_DIALECT_TENSOR_RUNTIMEOPVERIFICATION_H

mlir/include/mlir/InitAllDialects.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@
8484
#include "mlir/Dialect/Tensor/IR/ValueBoundsOpInterfaceImpl.h"
8585
#include "mlir/Dialect/Tensor/TransformOps/TensorTransformOps.h"
8686
#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
87+
#include "mlir/Dialect/Tensor/Transforms/RuntimeOpVerification.h"
8788
#include "mlir/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.h"
8889
#include "mlir/Dialect/Tosa/IR/ShardingInterfaceImpl.h"
8990
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
@@ -186,6 +187,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
186187
tensor::registerBufferizableOpInterfaceExternalModels(registry);
187188
tensor::registerFindPayloadReplacementOpInterfaceExternalModels(registry);
188189
tensor::registerInferTypeOpInterfaceExternalModels(registry);
190+
tensor::registerRuntimeVerifiableOpInterfaceExternalModels(registry);
189191
tensor::registerSubsetOpInterfaceExternalModels(registry);
190192
tensor::registerTilingInterfaceExternalModels(registry);
191193
tensor::registerValueBoundsOpInterfaceExternalModels(registry);

mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ add_mlir_dialect_library(MLIRTensorTransforms
88
MergeConsecutiveInsertExtractSlicePatterns.cpp
99
ReshapePatterns.cpp
1010
RewriteAsConstant.cpp
11+
RuntimeOpVerification.cpp
1112
SwapExtractSliceWithProducerPatterns.cpp
1213
SubsetInsertionOpInterfaceImpl.cpp
1314

Lines changed: 208 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,208 @@
1+
//===- RuntimeOpVerification.cpp - Op Verification ------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Dialect/Tensor/Transforms/RuntimeOpVerification.h"
10+
11+
#include "mlir/Dialect/Arith/IR/Arith.h"
12+
#include "mlir/Dialect/Arith/Utils/Utils.h"
13+
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
14+
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
15+
#include "mlir/Dialect/Tensor/IR/Tensor.h"
16+
#include "mlir/Dialect/Utils/IndexingUtils.h"
17+
#include "mlir/Interfaces/RuntimeVerifiableOpInterface.h"
18+
19+
using namespace mlir;
20+
21+
namespace mlir {
22+
namespace tensor {
23+
namespace {
24+
/// Generate a runtime check for lb <= value < ub.
25+
Value generateInBoundsCheck(OpBuilder &builder, Location loc, Value value,
26+
Value lb, Value ub) {
27+
Value inBounds1 = builder.createOrFold<arith::CmpIOp>(
28+
loc, arith::CmpIPredicate::sge, value, lb);
29+
Value inBounds2 = builder.createOrFold<arith::CmpIOp>(
30+
loc, arith::CmpIPredicate::slt, value, ub);
31+
Value inBounds =
32+
builder.createOrFold<arith::AndIOp>(loc, inBounds1, inBounds2);
33+
return inBounds;
34+
}
35+
36+
struct CastOpInterface
37+
: public RuntimeVerifiableOpInterface::ExternalModel<CastOpInterface,
38+
CastOp> {
39+
void generateRuntimeVerification(Operation *op, OpBuilder &builder,
40+
Location loc) const {
41+
auto castOp = cast<CastOp>(op);
42+
auto srcType = cast<TensorType>(castOp.getSource().getType());
43+
44+
// Nothing to check if the result is an unranked tensor.
45+
auto resultType = dyn_cast<RankedTensorType>(castOp.getType());
46+
if (!resultType)
47+
return;
48+
49+
if (isa<UnrankedTensorType>(srcType)) {
50+
// Check rank.
51+
Value srcRank = builder.create<RankOp>(loc, castOp.getSource());
52+
Value resultRank =
53+
builder.create<arith::ConstantIndexOp>(loc, resultType.getRank());
54+
Value isSameRank = builder.create<arith::CmpIOp>(
55+
loc, arith::CmpIPredicate::eq, srcRank, resultRank);
56+
builder.create<cf::AssertOp>(
57+
loc, isSameRank,
58+
RuntimeVerifiableOpInterface::generateErrorMessage(op,
59+
"rank mismatch"));
60+
}
61+
62+
// Check dimension sizes.
63+
for (const auto &it : llvm::enumerate(resultType.getShape())) {
64+
// Static dim size -> static/dynamic dim size does not need verification.
65+
if (auto rankedSrcType = dyn_cast<RankedTensorType>(srcType))
66+
if (!rankedSrcType.isDynamicDim(it.index()))
67+
continue;
68+
69+
// Static/dynamic dim size -> dynamic dim size does not need verification.
70+
if (resultType.isDynamicDim(it.index()))
71+
continue;
72+
73+
Value srcDimSz =
74+
builder.create<DimOp>(loc, castOp.getSource(), it.index());
75+
Value resultDimSz =
76+
builder.create<arith::ConstantIndexOp>(loc, it.value());
77+
Value isSameSz = builder.create<arith::CmpIOp>(
78+
loc, arith::CmpIPredicate::eq, srcDimSz, resultDimSz);
79+
builder.create<cf::AssertOp>(
80+
loc, isSameSz,
81+
RuntimeVerifiableOpInterface::generateErrorMessage(
82+
op, "size mismatch of dim " + std::to_string(it.index())));
83+
}
84+
}
85+
};
86+
87+
struct DimOpInterface
88+
: public RuntimeVerifiableOpInterface::ExternalModel<DimOpInterface,
89+
DimOp> {
90+
void generateRuntimeVerification(Operation *op, OpBuilder &builder,
91+
Location loc) const {
92+
auto dimOp = cast<DimOp>(op);
93+
Value rank = builder.create<RankOp>(loc, dimOp.getSource());
94+
Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
95+
builder.create<cf::AssertOp>(
96+
loc, generateInBoundsCheck(builder, loc, dimOp.getIndex(), zero, rank),
97+
RuntimeVerifiableOpInterface::generateErrorMessage(
98+
op, "index is out of bounds"));
99+
}
100+
};
101+
102+
/// Verifies that the indices on extract/insert ops are in-bounds of the
103+
/// tensor's index space: 0 <= index#i < dim#i
104+
template <typename OpTy>
105+
struct ExtractInsertOpInterface
106+
: public RuntimeVerifiableOpInterface::ExternalModel<
107+
ExtractInsertOpInterface<OpTy>, OpTy> {
108+
void generateRuntimeVerification(Operation *op, OpBuilder &builder,
109+
Location loc) const {
110+
auto extractInsertOp = cast<OpTy>(op);
111+
112+
Value tensor;
113+
if constexpr (std::is_same_v<OpTy, ExtractOp>) {
114+
tensor = extractInsertOp.getTensor();
115+
} else if constexpr (std::is_same_v<OpTy, InsertOp>) {
116+
tensor = extractInsertOp.getDest();
117+
} else {
118+
llvm_unreachable("invalid op");
119+
}
120+
auto tensorType = cast<RankedTensorType>(tensor.getType());
121+
auto rank = tensorType.getRank();
122+
if (rank == 0) {
123+
// Nothing to check for 0-d tensors.
124+
return;
125+
}
126+
127+
auto indices = extractInsertOp.getIndices();
128+
auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
129+
Value assertCond;
130+
for (auto i : llvm::seq<int64_t>(0, rank)) {
131+
Value dimOp = builder.createOrFold<tensor::DimOp>(loc, tensor, i);
132+
Value inBounds =
133+
generateInBoundsCheck(builder, loc, indices[i], zero, dimOp);
134+
assertCond =
135+
i > 0 ? builder.createOrFold<arith::AndIOp>(loc, assertCond, inBounds)
136+
: inBounds;
137+
}
138+
builder.create<cf::AssertOp>(
139+
loc, assertCond,
140+
RuntimeVerifiableOpInterface::generateErrorMessage(
141+
op, "out-of-bounds access"));
142+
}
143+
};
144+
145+
struct ExtractSliceOpInterface
146+
: public RuntimeVerifiableOpInterface::ExternalModel<
147+
ExtractSliceOpInterface, ExtractSliceOp> {
148+
void generateRuntimeVerification(Operation *op, OpBuilder &builder,
149+
Location loc) const {
150+
auto extractSliceOp = cast<ExtractSliceOp>(op);
151+
RankedTensorType sourceType = extractSliceOp.getSource().getType();
152+
153+
// For each dimension, assert that:
154+
// 0 <= offset < dim_size
155+
// 0 <= offset + (size - 1) * stride < dim_size
156+
Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
157+
Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
158+
for (int64_t i = 0, e = sourceType.getRank(); i < e; ++i) {
159+
Value offset = getValueOrCreateConstantIndexOp(
160+
builder, loc, extractSliceOp.getMixedOffsets()[i]);
161+
Value size = getValueOrCreateConstantIndexOp(
162+
builder, loc, extractSliceOp.getMixedSizes()[i]);
163+
Value stride = getValueOrCreateConstantIndexOp(
164+
builder, loc, extractSliceOp.getMixedStrides()[i]);
165+
166+
// Verify that offset is in-bounds.
167+
Value dimSize = builder.createOrFold<tensor::DimOp>(
168+
loc, extractSliceOp.getSource(), i);
169+
Value offsetInBounds =
170+
generateInBoundsCheck(builder, loc, offset, zero, dimSize);
171+
builder.create<cf::AssertOp>(
172+
loc, offsetInBounds,
173+
RuntimeVerifiableOpInterface::generateErrorMessage(
174+
op, "offset " + std::to_string(i) + " is out-of-bounds"));
175+
176+
// Verify that slice does not run out-of-bounds.
177+
Value sizeMinusOne = builder.create<arith::SubIOp>(loc, size, one);
178+
Value sizeMinusOneTimesStride =
179+
builder.create<arith::MulIOp>(loc, sizeMinusOne, stride);
180+
Value lastPos =
181+
builder.create<arith::AddIOp>(loc, offset, sizeMinusOneTimesStride);
182+
Value lastPosInBounds =
183+
generateInBoundsCheck(builder, loc, lastPos, zero, dimSize);
184+
builder.create<cf::AssertOp>(
185+
loc, lastPosInBounds,
186+
RuntimeVerifiableOpInterface::generateErrorMessage(
187+
op, "extract_slice runs out-of-bounds along dimension " +
188+
std::to_string(i)));
189+
}
190+
}
191+
};
192+
} // namespace
193+
} // namespace tensor
194+
} // namespace mlir
195+
196+
/// Adds a tensor-dialect extension to `registry` that attaches the
/// runtime-verification external models defined in this file to
/// tensor.cast, tensor.dim, tensor.extract, tensor.extract_slice and
/// tensor.insert.
void mlir::tensor::registerRuntimeVerifiableOpInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) {
    using ExtractModel = ExtractInsertOpInterface<ExtractOp>;
    using InsertModel = ExtractInsertOpInterface<InsertOp>;

    // Element accesses.
    ExtractOp::attachInterface<ExtractModel>(*ctx);
    InsertOp::attachInterface<InsertModel>(*ctx);

    // Shape queries, casts and slices.
    DimOp::attachInterface<DimOpInterface>(*ctx);
    CastOp::attachInterface<CastOpInterface>(*ctx);
    ExtractSliceOp::attachInterface<ExtractSliceOpInterface>(*ctx);

    // The generated checks create arith and cf ops, so those dialects have
    // to be loaded as well.
    ctx->loadDialect<arith::ArithDialect, cf::ControlFlowDialect>();
  });
}
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
// RUN: mlir-opt %s -generate-runtime-verification \
2+
// RUN: -one-shot-bufferize="bufferize-function-boundaries" \
3+
// RUN: -buffer-deallocation-pipeline=private-function-dynamic-ownership \
4+
// RUN: -test-cf-assert \
5+
// RUN: -convert-scf-to-cf \
6+
// RUN: -convert-to-llvm | \
7+
// RUN: mlir-runner -e main -entry-point-result=void \
8+
// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \
9+
// RUN: FileCheck %s
10+
11+
func.func private @cast_to_static_dim(%t: tensor<?xf32>) -> tensor<10xf32> {
12+
%0 = tensor.cast %t : tensor<?xf32> to tensor<10xf32>
13+
return %0 : tensor<10xf32>
14+
}
15+
16+
func.func private @cast_to_ranked(%t: tensor<*xf32>) -> tensor<f32> {
17+
%0 = tensor.cast %t : tensor<*xf32> to tensor<f32>
18+
return %0 : tensor<f32>
19+
}
20+
21+
func.func private @valid_cast(%t: tensor<*xf32>) -> tensor<?xf32> {
22+
%0 = tensor.cast %t : tensor<*xf32> to tensor<?xf32>
23+
return %0 : tensor<?xf32>
24+
}
25+
26+
func.func @main() {
27+
// All casts inside the called functions are invalid at runtime, except for
28+
// the last one.
29+
%alloc = tensor.empty() : tensor<5xf32>
30+
31+
// CHECK: ERROR: Runtime op verification failed
32+
// CHECK-NEXT: "tensor.cast"(%{{.*}}) : (tensor<?xf32>) -> tensor<10xf32>
33+
// CHECK-NEXT: ^ size mismatch of dim 0
34+
// CHECK-NEXT: Location: loc({{.*}})
35+
%1 = tensor.cast %alloc : tensor<5xf32> to tensor<?xf32>
36+
func.call @cast_to_static_dim(%1) : (tensor<?xf32>) -> (tensor<10xf32>)
37+
38+
// CHECK-NEXT: ERROR: Runtime op verification failed
39+
// CHECK-NEXT: "tensor.cast"(%{{.*}}) : (tensor<*xf32>) -> tensor<f32>
40+
// CHECK-NEXT: ^ rank mismatch
41+
// CHECK-NEXT: Location: loc({{.*}})
42+
%3 = tensor.cast %alloc : tensor<5xf32> to tensor<*xf32>
43+
func.call @cast_to_ranked(%3) : (tensor<*xf32>) -> (tensor<f32>)
44+
45+
// A last cast that actually succeeds.
46+
// CHECK-NOT: ERROR: Runtime op verification failed
47+
func.call @valid_cast(%3) : (tensor<*xf32>) -> (tensor<?xf32>)
48+
49+
return
50+
}
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
// RUN: mlir-opt %s -generate-runtime-verification \
2+
// RUN: -one-shot-bufferize \
3+
// RUN: -buffer-deallocation-pipeline \
4+
// RUN: -test-cf-assert \
5+
// RUN: -convert-to-llvm | \
6+
// RUN: mlir-runner -e main -entry-point-result=void \
7+
// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \
8+
// RUN: FileCheck %s
9+
10+
func.func @main() {
11+
%c4 = arith.constant 4 : index
12+
%tensor = tensor.empty() : tensor<1xf32>
13+
14+
// CHECK: ERROR: Runtime op verification failed
15+
// CHECK-NEXT: "tensor.dim"(%{{.*}}, %{{.*}}) : (tensor<1xf32>, index) -> index
16+
// CHECK-NEXT: ^ index is out of bounds
17+
// CHECK-NEXT: Location: loc({{.*}})
18+
%dim = tensor.dim %tensor, %c4 : tensor<1xf32>
19+
20+
return
21+
}
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
// RUN: mlir-opt %s -generate-runtime-verification \
2+
// RUN: -one-shot-bufferize="bufferize-function-boundaries" \
3+
// RUN: -buffer-deallocation-pipeline=private-function-dynamic-ownership \
4+
// RUN: -test-cf-assert \
5+
// RUN: -convert-scf-to-cf \
6+
// RUN: -convert-to-llvm | \
7+
// RUN: mlir-runner -e main -entry-point-result=void \
8+
// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \
9+
// RUN: FileCheck %s
10+
11+
func.func @extract(%tensor: tensor<1xf32>, %index: index) {
12+
tensor.extract %tensor[%index] : tensor<1xf32>
13+
return
14+
}
15+
16+
func.func @extract_dynamic(%tensor: tensor<?xf32>, %index: index) {
17+
tensor.extract %tensor[%index] : tensor<?xf32>
18+
return
19+
}
20+
21+
func.func @extract_nd_dynamic(%tensor: tensor<?x?x?xf32>, %index0: index, %index1: index, %index2: index) {
22+
tensor.extract %tensor[%index0, %index1, %index2] : tensor<?x?x?xf32>
23+
return
24+
}
25+
26+
func.func @main() {
27+
%0 = arith.constant 0 : index
28+
%1 = arith.constant 1 : index
29+
%n1 = arith.constant -1 : index
30+
%2 = arith.constant 2 : index
31+
%alloca_1 = tensor.empty() : tensor<1xf32>
32+
%alloc_1 = tensor.empty(%1) : tensor<?xf32>
33+
%alloc_2x2x2 = tensor.empty(%2, %2, %2) : tensor<?x?x?xf32>
34+
35+
// CHECK: ERROR: Runtime op verification failed
36+
// CHECK-NEXT: "tensor.extract"(%{{.*}}, %{{.*}}) : (tensor<1xf32>, index) -> f32
37+
// CHECK-NEXT: ^ out-of-bounds access
38+
// CHECK-NEXT: Location: loc({{.*}})
39+
func.call @extract(%alloca_1, %1) : (tensor<1xf32>, index) -> ()
40+
41+
// CHECK: ERROR: Runtime op verification failed
42+
// CHECK-NEXT: "tensor.extract"(%{{.*}}, %{{.*}}) : (tensor<?xf32>, index) -> f32
43+
// CHECK-NEXT: ^ out-of-bounds access
44+
// CHECK-NEXT: Location: loc({{.*}})
45+
func.call @extract_dynamic(%alloc_1, %1) : (tensor<?xf32>, index) -> ()
46+
47+
// CHECK: ERROR: Runtime op verification failed
48+
// CHECK-NEXT: "tensor.extract"(%{{.*}}, %{{.*}}) : (tensor<?x?x?xf32>, index, index, index) -> f32
49+
// CHECK-NEXT: ^ out-of-bounds access
50+
// CHECK-NEXT: Location: loc({{.*}})
51+
func.call @extract_nd_dynamic(%alloc_2x2x2, %1, %n1, %0) : (tensor<?x?x?xf32>, index, index, index) -> ()
52+
53+
// CHECK-NOT: ERROR: Runtime op verification failed
54+
func.call @extract(%alloca_1, %0) : (tensor<1xf32>, index) -> ()
55+
56+
// CHECK-NOT: ERROR: Runtime op verification failed
57+
func.call @extract_dynamic(%alloc_1, %0) : (tensor<?xf32>, index) -> ()
58+
59+
// CHECK-NOT: ERROR: Runtime op verification failed
60+
func.call @extract_nd_dynamic(%alloc_2x2x2, %1, %1, %0) : (tensor<?x?x?xf32>, index, index, index) -> ()
61+
62+
return
63+
}
64+

0 commit comments

Comments
 (0)