
Commit 56d94b3

[mlir] Extract offsets-sizes-strides computation from makeTiledShape(s).
This change separates the computation of the actual parameters of the subset from the materialization of the subview/extract_slice. That way, users can still reuse the Linalg tiling logic even if they use different operations to materialize the subsets.

Differential Revision: https://reviews.llvm.org/D131053
1 parent 57a9bcc commit 56d94b3


8 files changed: +168, -87 lines changed


mlir/include/mlir/Dialect/Linalg/Utils/Utils.h

Lines changed: 38 additions & 0 deletions
@@ -214,6 +214,44 @@ Value materializeOpFoldResult(ImplicitLocOpBuilder &builder,
 Value materializeOpFoldResult(OpBuilder &b, Location loc,
                               OpFoldResult opFoldResult);
 
+/// A struct containing offsets-sizes-strides arguments of the tiled shape.
+struct SliceParameters {
+  SmallVector<OpFoldResult, 3> offsets;
+  SmallVector<OpFoldResult, 3> sizes;
+  SmallVector<OpFoldResult, 3> strides;
+};
+
+/// Computes SliceParameters for a single `valueToTile`. `omitPartialTileCheck`
+/// controls whether to omit the partial/boundary tile condition check in cases
+/// where we statically know that it is unnecessary.
+SliceParameters
+computeSliceParameters(OpBuilder &builder, Location loc, Value valueToTile,
+                       ArrayRef<OpFoldResult> tileSizes, AffineMap map,
+                       ArrayRef<OpFoldResult> lbs, ArrayRef<OpFoldResult> ubs,
+                       ArrayRef<OpFoldResult> subShapeSizes,
+                       bool omitPartialTileCheck);
+
+/// Computes SliceParameters for all `valuesToTile` of the given `linalgOp`,
+/// assuming `linalgOp` is being fused into a loop nest for tiling with the
+/// given induction variables `ivs` and tile sizes `tileSizes`. `sizeBounds`
+/// are the iteration space bounds for *all* the implicit loops in `linalgOp`.
+/// `omitPartialTileCheck` controls whether to omit the partial/boundary tile
+/// condition check in cases where we statically know that it is unnecessary.
+///
+/// Note that a constant zero in `tileSizes` means no tiling at that implicit
+/// loop. The number of non-zero values in `tileSizes` should be equal to the
+/// number of values in `ivs`.
+///
+/// Some of the `valuesToTile` won't be affected by tiling. For these values,
+/// llvm::None will be returned.
+SmallVector<Optional<SliceParameters>>
+computeAllSliceParameters(OpBuilder &builder, Location loc, LinalgOp linalgOp,
+                          ValueRange valuesToTile, ArrayRef<OpFoldResult> ivs,
+                          ArrayRef<OpFoldResult> tileSizes,
+                          ArrayRef<OpFoldResult> sizeBounds,
+                          bool omitPartialTileCheck);
+
 /// Creates an extract_slice/subview op for a single `valueToTile` with
 /// `builder`. This new operation extracts a tile of `valueToTile`, starting
 /// at offsets `lbs` and with sizes `subShapeSizes`. `omitPartialTileCheck`
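
The point of the split is visible from these declarations: a downstream tiling driver can compute the offsets-sizes-strides once and then materialize the subset with whatever operation it prefers instead of memref.subview / tensor.extract_slice. A minimal sketch of that usage, using only the declarations above; the helper name and the createSubsetOp callback are illustrative, not part of this patch:

// Illustrative sketch: reuse the Linalg offsets-sizes-strides computation,
// but let the caller pick the op that materializes the subset.
static Value tileWithCustomSubsetOp(
    OpBuilder &builder, Location loc, Value valueToTile,
    ArrayRef<OpFoldResult> tileSizes, AffineMap map, ArrayRef<OpFoldResult> lbs,
    ArrayRef<OpFoldResult> ubs, ArrayRef<OpFoldResult> subShapeSizes,
    function_ref<Value(OpBuilder &, Location, Value, ArrayRef<OpFoldResult>,
                       ArrayRef<OpFoldResult>, ArrayRef<OpFoldResult>)>
        createSubsetOp) {
  // Compute offsets/sizes/strides exactly as Linalg tiling would.
  SliceParameters params = computeSliceParameters(
      builder, loc, valueToTile, tileSizes, map, lbs, ubs, subShapeSizes,
      /*omitPartialTileCheck=*/false);
  // Materialize the subset with the caller-provided operation.
  return createSubsetOp(builder, loc, valueToTile, params.offsets,
                        params.sizes, params.strides);
}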

mlir/lib/Dialect/Linalg/Utils/Utils.cpp

Lines changed: 86 additions & 43 deletions
@@ -802,28 +802,61 @@ void GenerateLoopNest<scf::ParallelOp>::doit(
   assert(ivs.size() == iteratorTypes.size() && "did not generate enough loops");
 }
 
+static Value materializeTiledShape(OpBuilder &builder, Location loc,
+                                   Value valueToTile,
+                                   const SliceParameters &sliceParams) {
+  auto shapedType = valueToTile.getType().dyn_cast<ShapedType>();
+  auto *sliceOp = TypeSwitch<ShapedType, Operation *>(shapedType)
+                      .Case([&](MemRefType) {
+                        return builder.create<memref::SubViewOp>(
+                            loc, valueToTile, sliceParams.offsets,
+                            sliceParams.sizes, sliceParams.strides);
+                      })
+                      .Case([&](RankedTensorType) {
+                        return makeComposedExtractSliceOp(
+                            builder, loc, valueToTile, sliceParams.offsets,
+                            sliceParams.sizes, sliceParams.strides);
+                      })
+                      .Default([](ShapedType) -> Operation * {
+                        llvm_unreachable("Unexpected shaped type");
+                      });
+  return sliceOp->getResult(0);
+}
+
 Value makeTiledShape(OpBuilder &builder, Location loc, Value valueToTile,
                      ArrayRef<OpFoldResult> tileSizes, AffineMap map,
                      ArrayRef<OpFoldResult> lbs, ArrayRef<OpFoldResult> ubs,
                      ArrayRef<OpFoldResult> subShapeSizes,
                      bool omitPartialTileCheck) {
+  SliceParameters sliceParams =
+      computeSliceParameters(builder, loc, valueToTile, tileSizes, map, lbs,
+                             ubs, subShapeSizes, omitPartialTileCheck);
+  return materializeTiledShape(builder, loc, valueToTile, sliceParams);
+}
+
+SliceParameters
+computeSliceParameters(OpBuilder &builder, Location loc, Value valueToTile,
+                       ArrayRef<OpFoldResult> tileSizes, AffineMap map,
+                       ArrayRef<OpFoldResult> lbs, ArrayRef<OpFoldResult> ubs,
+                       ArrayRef<OpFoldResult> subShapeSizes,
+                       bool omitPartialTileCheck) {
   auto shapedType = valueToTile.getType().dyn_cast<ShapedType>();
   assert(shapedType && "only shaped types can be tiled");
   ArrayRef<int64_t> shape = shapedType.getShape();
   int64_t rank = shapedType.getRank();
 
   // Construct a new subview / extract_slice for the tile.
-  SmallVector<OpFoldResult, 4> offsets, sizes, strides;
-  offsets.reserve(rank);
-  sizes.reserve(rank);
-  strides.reserve(rank);
+  SliceParameters sliceParams;
+  sliceParams.offsets.reserve(rank);
+  sliceParams.sizes.reserve(rank);
+  sliceParams.strides.reserve(rank);
   for (unsigned r = 0; r < rank; ++r) {
-    LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: for dim#" << r);
+    LLVM_DEBUG(llvm::dbgs() << "computeSliceParameters: for dim#" << r);
     if (!isTiled(map.getSubMap({r}), tileSizes)) {
-      offsets.push_back(builder.getIndexAttr(0));
+      sliceParams.offsets.push_back(builder.getIndexAttr(0));
       OpFoldResult dim = createFoldedDimOp(builder, loc, valueToTile, r);
-      sizes.push_back(dim);
-      strides.push_back(builder.getIndexAttr(1));
+      sliceParams.sizes.push_back(dim);
+      sliceParams.strides.push_back(builder.getIndexAttr(1));
       LLVM_DEBUG(llvm::dbgs() << ": not tiled: use size: " << dim << "\n");
       continue;
     }
@@ -832,26 +865,27 @@ Value makeTiledShape(OpBuilder &builder, Location loc, Value valueToTile,
     // Tiling creates a new slice at the proper index, the slice step is 1
     // (i.e. the op does not subsample, stepping occurs in the loop).
     auto m = map.getSubMap({r});
-    LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: submap: " << m << "\n");
+    LLVM_DEBUG(llvm::dbgs() << "computeSliceParameters: submap: " << m << "\n");
     IRRewriter rewriter(builder);
     OpFoldResult offset = makeComposedFoldedAffineApply(rewriter, loc, m, lbs);
-    offsets.push_back(offset);
+    sliceParams.offsets.push_back(offset);
     OpFoldResult closedIntSize =
         makeComposedFoldedAffineApply(rewriter, loc, m, subShapeSizes);
     // Resulting size needs to be made half open interval again.
     AffineExpr s0 = getAffineSymbolExpr(0, builder.getContext());
     OpFoldResult size =
         makeComposedFoldedAffineApply(rewriter, loc, s0 + 1, closedIntSize);
-    LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: raw size: " << size << "\n");
     LLVM_DEBUG(llvm::dbgs()
-               << "makeTiledShape: new offset: " << offset << "\n");
-    strides.push_back(builder.getIndexAttr(1));
+               << "computeSliceParameters: raw size: " << size << "\n");
+    LLVM_DEBUG(llvm::dbgs()
+               << "computeSliceParameters: new offset: " << offset << "\n");
+    sliceParams.strides.push_back(builder.getIndexAttr(1));
 
     if (omitPartialTileCheck) {
       // We statically know that the partial/boundary tile condition is
       // unnecessary.
       LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: new size: " << size << "\n");
-      sizes.push_back(size);
+      sliceParams.sizes.push_back(size);
       continue;
     }
 
@@ -903,22 +937,9 @@ Value makeTiledShape(OpBuilder &builder, Location loc, Value valueToTile,
           makeComposedFoldedAffineMin(rewriter, loc, minMap, {size, d, offset});
     }
     LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: new size: " << size << "\n");
-    sizes.push_back(size);
+    sliceParams.sizes.push_back(size);
   }
-
-  auto *sliceOp = TypeSwitch<ShapedType, Operation *>(shapedType)
-                      .Case([&](MemRefType) {
-                        return builder.create<memref::SubViewOp>(
-                            loc, valueToTile, offsets, sizes, strides);
-                      })
-                      .Case([&](RankedTensorType) {
-                        return makeComposedExtractSliceOp(
-                            builder, loc, valueToTile, offsets, sizes, strides);
-                      })
-                      .Default([](ShapedType) -> Operation * {
-                        llvm_unreachable("Unexpected shaped type");
-                      });
-  return sliceOp->getResult(0);
+  return sliceParams;
 }
 
 SmallVector<OpFoldResult> computeTileOffsets(OpBuilder &b, Location loc,
@@ -1003,28 +1024,29 @@ Value materializeOpFoldResult(OpBuilder &builder, Location loc,
   return materializeOpFoldResult(b, opFoldResult);
 }
 
-SmallVector<Value> makeTiledShapes(OpBuilder &b, Location loc,
-                                   LinalgOp linalgOp, ValueRange valuesToTile,
-                                   ArrayRef<OpFoldResult> ivs,
-                                   ArrayRef<OpFoldResult> tileSizes,
-                                   ArrayRef<OpFoldResult> sizeBounds,
-                                   bool omitPartialTileCheck) {
+SmallVector<Optional<SliceParameters>>
+computeAllSliceParameters(OpBuilder &builder, Location loc, LinalgOp linalgOp,
+                          ValueRange valuesToTile, ArrayRef<OpFoldResult> ivs,
+                          ArrayRef<OpFoldResult> tileSizes,
+                          ArrayRef<OpFoldResult> sizeBounds,
+                          bool omitPartialTileCheck) {
   assert(ivs.size() == static_cast<size_t>(llvm::count_if(
                            llvm::make_range(tileSizes.begin(), tileSizes.end()),
                            [](OpFoldResult v) { return !isZero(v); })) &&
          "expected as many ivs as non-zero sizes");
 
   // Construct (potentially temporary) mins and maxes on which to apply maps
   // that define tile subshapes.
-  SmallVector<OpFoldResult> lbs = computeTileOffsets(b, loc, ivs, tileSizes);
+  SmallVector<OpFoldResult> lbs =
+      computeTileOffsets(builder, loc, ivs, tileSizes);
   SmallVector<OpFoldResult> subShapeSizes =
-      computeTileSizes(b, loc, tileSizes, sizeBounds);
+      computeTileSizes(builder, loc, tileSizes, sizeBounds);
 
   assert(static_cast<int64_t>(valuesToTile.size()) ==
              linalgOp.getNumInputsAndOutputs() &&
          "expected one value to tile for every operand");
-  SmallVector<Value> tiledShapes;
-  tiledShapes.reserve(valuesToTile.size());
+  SmallVector<Optional<SliceParameters>> allSliceParams;
+  allSliceParams.reserve(valuesToTile.size());
   for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) {
     Value shapedOp = valuesToTile[opOperand->getOperandNumber()];
     LLVM_DEBUG(llvm::dbgs() << "makeTiledShapes: for operand " << shapedOp);
@@ -1035,18 +1057,39 @@ SmallVector<Value> makeTiledShapes(OpBuilder &b, Location loc,
     // extract/insert slice pairs make the accessed iteration argument
     // subdomains explicit.
     if (!isTiled(map, tileSizes) && !linalgOp.isOutputTensor(opOperand)) {
-      tiledShapes.push_back(shapedOp);
+      allSliceParams.push_back(llvm::None);
       LLVM_DEBUG(llvm::dbgs() << ": not tiled: use shape: "
                               << opOperand->get().getType() << "\n");
       continue;
     }
     LLVM_DEBUG(llvm::dbgs() << ": tiled: figure out subshape...\n");
 
-    tiledShapes.push_back(makeTiledShape(b, loc, shapedOp, tileSizes, map, lbs,
-                                         sizeBounds, subShapeSizes,
-                                         omitPartialTileCheck));
+    allSliceParams.push_back(computeSliceParameters(
+        builder, loc, shapedOp, tileSizes, map, lbs, sizeBounds, subShapeSizes,
+        omitPartialTileCheck));
   }
 
+  return allSliceParams;
+}
+
+SmallVector<Value> makeTiledShapes(OpBuilder &builder, Location loc,
+                                   LinalgOp linalgOp, ValueRange valuesToTile,
+                                   ArrayRef<OpFoldResult> ivs,
+                                   ArrayRef<OpFoldResult> tileSizes,
+                                   ArrayRef<OpFoldResult> sizeBounds,
+                                   bool omitPartialTileCheck) {
+  SmallVector<Optional<SliceParameters>> allSliceParameter =
+      computeAllSliceParameters(builder, loc, linalgOp, valuesToTile, ivs,
+                                tileSizes, sizeBounds, omitPartialTileCheck);
+  SmallVector<Value> tiledShapes;
+  for (auto item : llvm::zip(valuesToTile, allSliceParameter)) {
+    Value valueToTile = std::get<0>(item);
+    Optional<SliceParameters> sliceParams = std::get<1>(item);
+    tiledShapes.push_back(
+        sliceParams.hasValue()
+            ? materializeTiledShape(builder, loc, valueToTile, *sliceParams)
+            : valueToTile);
+  }
   return tiledShapes;
 }
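
The same split works at the whole-op level: computeAllSliceParameters returns one Optional<SliceParameters> per operand, with llvm::None marking operands that tiling does not touch. A hedged sketch of a caller-side loop that mirrors the new makeTiledShapes above but leaves the materialization op to the caller; memref.subview is shown only as one possible choice, and the variable names are illustrative:

// Illustrative caller-side loop; assumes builder, loc, linalgOp, valuesToTile,
// ivs, tileSizes, sizeBounds, and omitPartialTileCheck are provided by the
// surrounding tiling driver, as in makeTiledShapes above.
SmallVector<Optional<SliceParameters>> allSliceParams =
    computeAllSliceParameters(builder, loc, linalgOp, valuesToTile, ivs,
                              tileSizes, sizeBounds, omitPartialTileCheck);
SmallVector<Value> tiledOperands;
for (auto item : llvm::zip(valuesToTile, allSliceParams)) {
  Value valueToTile = std::get<0>(item);
  Optional<SliceParameters> params = std::get<1>(item);
  if (!params.hasValue()) {
    // Not affected by tiling: pass the full value through unchanged.
    tiledOperands.push_back(valueToTile);
    continue;
  }
  // Materialize the subset with any op; memref.subview is one possibility.
  Value slice = builder.create<memref::SubViewOp>(
      loc, valueToTile, params->offsets, params->sizes, params->strides);
  tiledOperands.push_back(slice);
}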

mlir/test/Dialect/Linalg/tile-and-distribute.mlir

Lines changed: 11 additions & 11 deletions
@@ -16,12 +16,12 @@ func.func @gemm1(%a : memref<?x?xf32>, %b : memref<?x?xf32>, %c : memref<?x?xf32
 // CHECK-DAG: %[[BIDX:.*]] = gpu.block_id x
 // CHECK: scf.for %[[ARG3:.*]] =
 // CHECK: %[[OFFSETY:.*]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
-// CHECK: %[[SV1:.*]] = memref.subview %[[ARG0]][%[[OFFSETY]], %[[ARG3]]]
 // CHECK: %[[OFFSETX:.*]] = affine.apply #[[MAP0]]()[%[[BIDX]]]
-// CHECK: %[[SV2:.*]] = memref.subview %[[ARG1]][%[[ARG3]], %[[OFFSETX]]]
 // CHECK: %[[OFFSETY_2:.*]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
-// CHECK: %[[OFFSETX:.*]] = affine.apply #[[MAP0]]()[%[[BIDX]]]
-// CHECK: %[[SV3:.*]] = memref.subview %[[ARG2]][%[[OFFSETY_2]], %[[OFFSETX]]]
+// CHECK: %[[OFFSETX_2:.*]] = affine.apply #[[MAP0]]()[%[[BIDX]]]
+// CHECK: %[[SV1:.*]] = memref.subview %[[ARG0]][%[[OFFSETY]], %[[ARG3]]]
+// CHECK: %[[SV2:.*]] = memref.subview %[[ARG1]][%[[ARG3]], %[[OFFSETX]]]
+// CHECK: %[[SV3:.*]] = memref.subview %[[ARG2]][%[[OFFSETY_2]], %[[OFFSETX_2]]]
 // CHECK: linalg.matmul ins(%[[SV1]], %[[SV2]]{{.*}} outs(%[[SV3]]
 
 // -----

@@ -48,11 +48,11 @@ func.func @gemm2(%a : memref<?x?xf32>, %b : memref<?x?xf32>, %c : memref<?x?xf32
 // CHECK: scf.if %[[INBOUNDS]]
 // CHECK: scf.for %[[ARG3:.*]] =
 // CHECK: %[[OFFSETY:.*]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
-// CHECK: %[[SV1:.*]] = memref.subview %[[ARG0]][%[[OFFSETY]], %[[ARG3]]]
 // CHECK: %[[OFFSETX:.*]] = affine.apply #[[MAP0]]()[%[[BIDX]]]
-// CHECK: %[[SV2:.*]] = memref.subview %[[ARG1]][%[[ARG3]], %[[OFFSETX]]]
 // CHECK: %[[OFFSETY_2:.*]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
 // CHECK: %[[OFFSETX_2:.*]] = affine.apply #[[MAP0]]()[%[[BIDX]]]
+// CHECK: %[[SV1:.*]] = memref.subview %[[ARG0]][%[[OFFSETY]], %[[ARG3]]]
+// CHECK: %[[SV2:.*]] = memref.subview %[[ARG1]][%[[ARG3]], %[[OFFSETX]]]
 // CHECK: %[[SV3:.*]] = memref.subview %[[ARG2]][%[[OFFSETY_2]], %[[OFFSETX_2]]]
 // CHECK: linalg.matmul ins(%[[SV1]], %[[SV2]]{{.*}} outs(%[[SV3]]
 
@@ -106,11 +106,11 @@ func.func @gemm4(%a : memref<?x?xf32>, %b : memref<?x?xf32>, %c : memref<?x?xf32
 // CHECK: scf.if %[[INBOUNDS]]
 // CHECK: scf.for %[[ARG3:.*]] =
 // CHECK: %[[OFFSETY:.*]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
-// CHECK: %[[SV1:.*]] = memref.subview %[[ARG0]][%[[OFFSETY]], %[[ARG3]]]
 // CHECK: %[[OFFSETX:.*]] = affine.apply #[[MAP0]]()[%[[BIDX]]]
-// CHECK: %[[SV2:.*]] = memref.subview %[[ARG1]][%[[ARG3]], %[[OFFSETX]]]
 // CHECK: %[[OFFSETY_2:.*]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
 // CHECK: %[[OFFSETX_2:.*]] = affine.apply #[[MAP0]]()[%[[BIDX]]]
+// CHECK: %[[SV1:.*]] = memref.subview %[[ARG0]][%[[OFFSETY]], %[[ARG3]]]
+// CHECK: %[[SV2:.*]] = memref.subview %[[ARG1]][%[[ARG3]], %[[OFFSETX]]]
 // CHECK: %[[SV3:.*]] = memref.subview %[[ARG2]][%[[OFFSETY_2]], %[[OFFSETX_2]]]
 // CHECK: linalg.matmul ins(%[[SV1]], %[[SV2]]{{.*}} outs(%[[SV3]]
 
@@ -139,9 +139,9 @@ func.func @gemm5(%a : memref<?x?xf32>, %b : memref<?x?xf32>, %c : memref<?x?xf32
 // CHECK: scf.parallel (%[[ARG3:.*]]) = (%[[LBX]]) to (%{{.*}}) step (%[[STEPX]])
 // CHECK: scf.for %[[ARG4:.*]] =
 // CHECK: %[[OFFSETY:.*]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
+// CHECK: %[[OFFSETY_2:.*]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
 // CHECK: %[[SV1:.*]] = memref.subview %[[ARG0]][%[[OFFSETY]], %[[ARG4]]]
 // CHECK: %[[SV2:.*]] = memref.subview %[[ARG1]][%[[ARG4]], %[[ARG3]]]
-// CHECK: %[[OFFSETY_2:.*]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
 // CHECK: %[[SV3:.*]] = memref.subview %[[ARG2]][%[[OFFSETY_2]], %[[ARG3]]]
 // CHECK: linalg.matmul ins(%[[SV1]], %[[SV2]]{{.*}} outs(%[[SV3]]
 
@@ -166,10 +166,10 @@ func.func @gemm6(%a : memref<?x?xf32>, %b : memref<?x?xf32>, %c : memref<?x?xf32
 // CHECK: %[[STEPY:.*]] = affine.apply #[[MAP0]]()[%[[NBLOCKSY]]]
 // CHECK: scf.parallel (%[[ARG3:.*]]) = (%[[LBY]]) to (%{{.*}}) step (%[[STEPY]])
 // CHECK: scf.for %[[ARG4:.*]] =
-// CHECK: %[[SV1:.*]] = memref.subview %[[ARG0]][%[[ARG3]], %[[ARG4]]]
 // CHECK: %[[OFFSETX:.*]] = affine.apply #[[MAP0]]()[%[[BIDX]]]
-// CHECK: %[[SV2:.*]] = memref.subview %[[ARG1]][%[[ARG4]], %[[OFFSETX]]]
 // CHECK: %[[OFFSETX_2:.*]] = affine.apply #[[MAP0]]()[%[[BIDX]]]
+// CHECK: %[[SV1:.*]] = memref.subview %[[ARG0]][%[[ARG3]], %[[ARG4]]]
+// CHECK: %[[SV2:.*]] = memref.subview %[[ARG1]][%[[ARG4]], %[[OFFSETX]]]
 // CHECK: %[[SV3:.*]] = memref.subview %[[ARG2]][%[[ARG3]], %[[OFFSETX_2]]]
 // CHECK: linalg.matmul ins(%[[SV1]], %[[SV2]]{{.*}} outs(%[[SV3]]

mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir

Lines changed: 1 addition & 1 deletion
@@ -241,9 +241,9 @@ func.func @conv_tensors_dynamic(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?
 // CHECK-NEXT: %[[SIZE_ELEM_OC:.+]] = affine.min #[[BOUND2_MAP]](%[[IV2]])[%[[ELEM_OC]]]
 // CHECK-NEXT: %[[OFFSET_OW:.+]] = affine.apply #[[X2_MAP]](%[[IV2]])
 // CHECK-NEXT: %[[SIZE_INPUT_W:.+]] = affine.min #[[INPUT_BOUND]](%[[IV2]], %[[SIZE_ELEM_OW]])[%[[FILL_W]], %[[FILTER_W]]]
+// CHECK-NEXT: %[[SIZE_ELEM_OW_2:.+]] = affine.min #[[BOUND4_MAP_2]](%[[IV2]])[%[[FILL_W]], %[[ELEM_OW]]]
 // CHECK-NEXT: %[[ST_INPUT:.+]] = tensor.extract_slice %[[INPUT]][%[[IV0]], %[[OFFSET_OH]], %[[OFFSET_OW]], 0]
 // CHECK-SAME:   [%[[SIZE_INPUT_N]], %[[SIZE_INPUT_H]], %[[SIZE_INPUT_W]], %[[INPUT_C]]]
-// CHECK-NEXT: %[[SIZE_ELEM_OW_2:.+]] = affine.min #[[BOUND4_MAP_2]](%[[IV2]])[%[[FILL_W]], %[[ELEM_OW]]]
 // CHECK-NEXT: scf.for %[[IV3:.+]] = %{{.+}} to %[[ELEM_OC]] step %{{.+}} iter_args(%[[ARG:[a-z0-9]+]]
 // CHECK-NEXT: %[[ST_ELEM:.+]] = tensor.extract_slice %[[ELEM]][%[[IV0]], %[[IV1]], %[[IV2]], %[[IV3]]]
 // CHECK-SAME:   [%[[SIZE_ELEM_N]], %[[SIZE_ELEM_OH]], %[[SIZE_ELEM_OW]], %[[SIZE_ELEM_OC]]]

mlir/test/Dialect/Linalg/tile-conv.mlir

Lines changed: 1 addition & 1 deletion
@@ -26,9 +26,9 @@ func.func @conv(%arg0 : memref<?x?xf32>, %arg1 : memref<?x?xf32>, %arg2 : memref
 // CHECK: scf.for %[[ARG4:.*]] = %[[C0]] to %[[T3]] step %[[C3]]
 // CHECK: %[[T4:.*]] = affine.min #[[MAP0]](%[[ARG3]])[%[[T2]], %[[T0]]]
 // CHECK: %[[T5:.*]] = affine.min #[[MAP1]](%[[ARG4]])[%[[T3]], %[[T1]]]
-// CHECK: %[[SV1:.*]] = memref.subview %[[ARG0]][%[[ARG3]], %[[ARG4]]] [%[[T4]], %[[T5]]]
 // CHECK: %[[T6:.*]] = affine.min #[[MAP2]](%[[ARG3]])[%[[T2]]
 // CHECK: %[[T7:.*]] = affine.min #[[MAP3]](%[[ARG4]])[%[[T3]]]
+// CHECK: %[[SV1:.*]] = memref.subview %[[ARG0]][%[[ARG3]], %[[ARG4]]] [%[[T4]], %[[T5]]]
 // CHECK: %[[SV2:.*]] = memref.subview %[[ARG2]][%[[ARG3]], %[[ARG4]]] [%[[T6]], %[[T7]]]
 // CHECK: linalg.conv_2d
 // CHECK-SAME: ins(%[[SV1]], %[[ARG1]]
