Commit d6541fc

[mlir][tensor] Fold padding expand_shape into insert_slice (#93018)

1 parent fa63771
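
In effect, the new pattern lets insert_slice consume the source of an expand_shape that only pads with static unit dimensions. A minimal before/after sketch, adapted from the tests below (value names are illustrative):

  // Before: the expand_shape only adds static size-1 dimensions.
  %e = tensor.expand_shape %t [[0, 1], [2, 3]] output_shape [1, %sz0, 1, %sz1]
      : tensor<?x?xf32> into tensor<1x?x1x?xf32>
  %r = tensor.insert_slice %e into %d[%x, %y, 0, 0][1, %sz0, 1, %sz1][1, 1, 1, 1]
      : tensor<1x?x1x?xf32> into tensor<?x?x?x?xf32>

  // After: the unit dimensions are already implied by the slice sizes, so the
  // lower-rank source is inserted directly (a rank-reducing insert_slice).
  %r = tensor.insert_slice %t into %d[%x, %y, 0, 0][1, %sz0, 1, %sz1][1, 1, 1, 1]
      : tensor<?x?xf32> into tensor<?x?x?x?xf32>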

File tree

2 files changed: +133, -1 lines

mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp

Lines changed: 31 additions & 1 deletion
@@ -79,12 +79,42 @@ struct FoldInsertOfRankReducingInsert : public OpRewritePattern<OpTy> {
     return success();
   }
 };
+
+/// Fold expand_shape which only adds static dimensions of size `1`
+/// into insert_slice.
+template <typename OpTy>
+struct FoldPaddingExpandIntoInsert : public OpRewritePattern<OpTy> {
+  using OpRewritePattern<OpTy>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(OpTy insertSliceOp,
+                                PatternRewriter &rewriter) const override {
+    auto expandShapeOp = insertSliceOp.getSource()
+                             .template getDefiningOp<tensor::ExpandShapeOp>();
+    if (!expandShapeOp)
+      return failure();
+
+    // Only fold away simple expansion where all added dimensions have static
+    // size `1`.
+    SliceVerificationResult res = isRankReducedType(
+        expandShapeOp.getResultType(), expandShapeOp.getSrcType());
+    if (res != SliceVerificationResult::Success)
+      return rewriter.notifyMatchFailure(insertSliceOp,
+                                         "expected rank increasing expansion");
+
+    rewriter.modifyOpInPlace(insertSliceOp, [&]() {
+      insertSliceOp.getSourceMutable().assign(expandShapeOp.getSrc());
+    });
+    return success();
+  }
+};
 } // namespace
 
 void mlir::tensor::populateReassociativeReshapeFoldingPatterns(
     RewritePatternSet &patterns) {
   patterns.add<FoldExpandOfRankReducingExtract,
                FoldInsertOfRankReducingInsert<tensor::InsertSliceOp>,
-               FoldInsertOfRankReducingInsert<tensor::ParallelInsertSliceOp>>(
+               FoldInsertOfRankReducingInsert<tensor::ParallelInsertSliceOp>,
+               FoldPaddingExpandIntoInsert<tensor::InsertSliceOp>,
+               FoldPaddingExpandIntoInsert<tensor::ParallelInsertSliceOp>>(
       patterns.getContext());
 }
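
These patterns are exposed only through the populate function above; a caller such as a test pass would typically run them with MLIR's greedy rewrite driver. A minimal sketch, assuming it lives inside an OperationPass (the pass boilerplate and includes are assumptions, not part of this commit):

  // Assumed includes (not shown in this diff):
  //   #include "mlir/Dialect/Tensor/Transforms/Transforms.h"
  //   #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
  void runOnOperation() /* override, inside a hypothetical OperationPass */ {
    Operation *op = getOperation();
    RewritePatternSet patterns(op->getContext());
    // Registers FoldExpandOfRankReducingExtract, FoldInsertOfRankReducingInsert,
    // and the new FoldPaddingExpandIntoInsert patterns.
    mlir::tensor::populateReassociativeReshapeFoldingPatterns(patterns);
    // Apply the patterns to a fixed point; this fails only if the rewrite
    // does not converge.
    if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
      signalPassFailure();
  }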

mlir/test/Dialect/Tensor/fold-reassociative-reshapes.mlir

Lines changed: 102 additions & 0 deletions
@@ -54,3 +54,105 @@ func.func @rank_reducing_parallel_insert_of_collapse_shape(
   }
   return %1 : tensor<?x?x?x?xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @insert_of_padding_expand_shape(
+// CHECK-SAME:  %[[t:.*]]: tensor<?x?xf32>
+// CHECK-SAME:  %[[d:.*]]: tensor<?x?x?x?xf32>
+// CHECK-SAME:  %[[x:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:  %[[y:[a-zA-Z0-9_]+]]: index
+// CHECK:       %[[insert:.*]] = tensor.insert_slice %[[t]] into %[[d]][%[[x]], %[[y]], 0, 0] [1, %{{.*}}, 1, %{{.*}}] [1, 1, 1, 1] : tensor<?x?xf32> into tensor<?x?x?x?xf32>
+// CHECK:       return %[[insert]]
+func.func @insert_of_padding_expand_shape(
+    %t: tensor<?x?xf32>, %d: tensor<?x?x?x?xf32>, %x: index, %y: index)
+    -> tensor<?x?x?x?xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %sz0 = tensor.dim %t, %c0 : tensor<?x?xf32>
+  %sz1 = tensor.dim %t, %c1 : tensor<?x?xf32>
+  %0 = tensor.expand_shape %t [[0, 1], [2, 3]] output_shape [1, %sz0, 1, %sz1]
+      : tensor<?x?xf32> into tensor<1x?x1x?xf32>
+  %1 = tensor.insert_slice %0 into %d[%x, %y, 0, 0][1, %sz0, 1, %sz1][1, 1, 1, 1]
+      : tensor<1x?x1x?xf32> into tensor<?x?x?x?xf32>
+  return %1 : tensor<?x?x?x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @insert_of_non_padding_expand_shape(
+// CHECK-SAME:  %[[t:.*]]: tensor<?x?xf32>
+// CHECK-SAME:  %[[d:.*]]: tensor<?x?x?x?xf32>
+// CHECK-SAME:  %[[x:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:  %[[y:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:  %[[sz:[a-zA-Z0-9_]+]]: index
+// CHECK:       %[[expand:.*]] = tensor.expand_shape %[[t]] {{\[}}[0, 1], [2]] output_shape [%[[sz]], %{{.*}}, %{{.*}}] : tensor<?x?xf32> into tensor<?x?x?xf32>
+// CHECK:       %[[insert:.*]] = tensor.insert_slice %[[expand]] into %[[d]][%[[x]], %[[y]], 0, 0] [%[[sz]], 1, %{{.*}}, %{{.*}}] [1, 1, 1, 1] : tensor<?x?x?xf32> into tensor<?x?x?x?xf32>
+// CHECK:       return %[[insert]]
+func.func @insert_of_non_padding_expand_shape(
+    %t: tensor<?x?xf32>, %d: tensor<?x?x?x?xf32>, %x: index, %y: index, %sz: index)
+    -> tensor<?x?x?x?xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %sz0 = tensor.dim %t, %c0 : tensor<?x?xf32>
+  %sz1 = tensor.dim %t, %c1 : tensor<?x?xf32>
+  %0 = tensor.expand_shape %t [[0, 1], [2]] output_shape [%sz, %sz0, %sz1]
+      : tensor<?x?xf32> into tensor<?x?x?xf32>
+  %1 = tensor.insert_slice %0 into %d[%x, %y, 0, 0][%sz, 1, %sz0, %sz1][1, 1, 1, 1]
+      : tensor<?x?x?xf32> into tensor<?x?x?x?xf32>
+  return %1 : tensor<?x?x?x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @parallel_insert_of_padding_expand_shape(
+// CHECK-SAME:  %[[t:.*]]: tensor<?x?xf32>
+// CHECK-SAME:  %[[d:.*]]: tensor<?x?x?x?xf32>
+// CHECK-SAME:  %[[x:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:  %[[y:[a-zA-Z0-9_]+]]: index
+// CHECK:       tensor.parallel_insert_slice %[[t]] into %{{.*}}[%{{.*}}, %{{.*}}, 0, 0] [1, %{{.*}}, 1, %{{.*}}] [1, 1, 1, 1] : tensor<?x?xf32> into tensor<?x?x?x?xf32>
+func.func @parallel_insert_of_padding_expand_shape(
+    %t: tensor<?x?xf32>, %d: tensor<?x?x?x?xf32>, %x: index, %y: index)
+    -> tensor<?x?x?x?xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %sz0 = tensor.dim %t, %c0 : tensor<?x?xf32>
+  %sz1 = tensor.dim %t, %c1 : tensor<?x?xf32>
+  %0 = tensor.expand_shape %t [[0, 1], [2, 3]] output_shape [1, %sz0, 1, %sz1]
+      : tensor<?x?xf32> into tensor<1x?x1x?xf32>
+  %1 = scf.forall (%i, %j) in (%x, %y) shared_outs(%o = %d) -> (tensor<?x?x?x?xf32>) {
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %0 into %o[%i, %j, 0, 0][1, %sz0, 1, %sz1][1, 1, 1, 1]
+        : tensor<1x?x1x?xf32> into tensor<?x?x?x?xf32>
+    }
+  }
+  return %1 : tensor<?x?x?x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @parallel_insert_of_non_padding_expand_shape(
+// CHECK-SAME:  %[[t:.*]]: tensor<?x?xf32>
+// CHECK-SAME:  %[[d:.*]]: tensor<?x?x?x?xf32>
+// CHECK-SAME:  %[[x:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:  %[[y:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:  %[[sz:[a-zA-Z0-9_]+]]: index
+// CHECK:       %[[expand:.*]] = tensor.expand_shape %[[t]] {{\[}}[0, 1], [2]] output_shape [%[[sz]], %{{.*}}, %{{.*}}] : tensor<?x?xf32> into tensor<?x?x?xf32>
+// CHECK:       tensor.parallel_insert_slice %[[expand]] into %{{.*}}[%{{.*}}, %{{.*}}, 0, 0] [%[[sz]], 1, %{{.*}}, %{{.*}}] [1, 1, 1, 1] : tensor<?x?x?xf32> into tensor<?x?x?x?xf32>
+func.func @parallel_insert_of_non_padding_expand_shape(
+    %t: tensor<?x?xf32>, %d: tensor<?x?x?x?xf32>, %x: index, %y: index, %sz: index)
+    -> tensor<?x?x?x?xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %sz0 = tensor.dim %t, %c0 : tensor<?x?xf32>
+  %sz1 = tensor.dim %t, %c1 : tensor<?x?xf32>
+  %0 = tensor.expand_shape %t [[0, 1], [2]] output_shape [%sz, %sz0, %sz1]
+      : tensor<?x?xf32> into tensor<?x?x?xf32>
+  %1 = scf.forall (%i, %j) in (%x, %y) shared_outs(%o = %d) -> (tensor<?x?x?x?xf32>) {
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %0 into %o[%i, %j, 0, 0][%sz, 1, %sz0, %sz1][1, 1, 1, 1]
+        : tensor<?x?x?xf32> into tensor<?x?x?x?xf32>
+    }
+  }
+  return %1 : tensor<?x?x?x?xf32>
+}
