Skip to content

Commit f79f430

Browse files
committed
Fold Tensor.extract_slice into a constant splat.
Fold arith.extract_slice into arith.constant when the source is a constant splat and the result type is statically shaped.
1 parent 210bb04 commit f79f430

File tree

4 files changed

+34
-1
lines changed

4 files changed

+34
-1
lines changed

mlir/include/mlir/IR/BuiltinAttributes.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -655,6 +655,11 @@ class DenseElementsAttr : public Attribute {
655655
/// same total number of elements as well as element type.
656656
DenseElementsAttr reshape(ShapedType newType);
657657

658+
/// Return a new DenseElementsAttr that has the same data as the current
659+
/// attribute, but with a different shape for a splat type. The new type must
660+
/// have the same element type.
661+
DenseElementsAttr resizeSplat(ShapedType newType);
662+
658663
/// Return a new DenseElementsAttr that has the same data as the current
659664
/// attribute, but has bitcast elements to 'newElType'. The new type must have
660665
/// the same bitwidth as the current element type.

mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1227,7 +1227,12 @@ static Value foldExtractAfterInsertSlice(ExtractSliceOp extractOp) {
12271227
return {};
12281228
}
12291229

1230-
OpFoldResult ExtractSliceOp::fold(ArrayRef<Attribute>) {
1230+
OpFoldResult ExtractSliceOp::fold(ArrayRef<Attribute> operands) {
1231+
if (auto splat = operands[0].dyn_cast_or_null<SplatElementsAttr>()) {
1232+
auto resultType = result().getType().cast<ShapedType>();
1233+
if (resultType.hasStaticShape())
1234+
return splat.resizeSplat(resultType);
1235+
}
12311236
if (getSourceType() == getType() &&
12321237
succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
12331238
return this->source();

mlir/lib/IR/BuiltinAttributes.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -967,6 +967,18 @@ DenseElementsAttr DenseElementsAttr::reshape(ShapedType newType) {
967967
return DenseIntOrFPElementsAttr::getRaw(newType, getRawData(), isSplat());
968968
}
969969

970+
DenseElementsAttr DenseElementsAttr::resizeSplat(ShapedType newType) {
971+
assert(isSplat() && "expected a splat type");
972+
973+
ShapedType curType = getType();
974+
if (curType == newType)
975+
return *this;
976+
977+
assert(newType.getElementType() == curType.getElementType() &&
978+
"expected the same element type");
979+
return DenseIntOrFPElementsAttr::getRaw(newType, getRawData(), true);
980+
}
981+
970982
/// Return a new DenseElementsAttr that has the same data as the current
971983
/// attribute, but has bitcast elements such that it is now 'newType'. The new
972984
/// type must have the same shape and element types of the same bitwidth as the

mlir/test/Dialect/Tensor/canonicalize.mlir

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -621,6 +621,17 @@ func @fold_extract_insert(%input : tensor<?x?x?xf32>, %slice: tensor<4x?x8xf32>,
621621

622622
// -----
623623

624+
// CHECK-LABEL: func @fold_extract_constant_splat
625+
// CHECK-NOT: tensor.extract_slice
626+
// CHECK: arith.constant dense<42> : tensor<4x4xi32>
627+
func @fold_extract_constant_splat() -> (tensor<4x4xi32>) {
628+
%cst = arith.constant dense<42> : tensor<1024x1024xi32>
629+
%1 = tensor.extract_slice %cst[0,0] [4,4] [1, 1] : tensor<1024x1024xi32> to tensor<4x4xi32>
630+
return %1 : tensor<4x4xi32>
631+
}
632+
633+
// -----
634+
624635
// CHECK-LABEL: func @fold_overlapping_insert
625636
// CHECK-SAME: %[[INPUT:.+]]: tensor<?x?x?xf32>, %{{.+}}: tensor<4x?x8xf32>, %[[SLICE2:.+]]: tensor<4x?x8xf32>
626637
func @fold_overlapping_insert(%input : tensor<?x?x?xf32>, %slice1: tensor<4x?x8xf32>, %slice2: tensor<4x?x8xf32>, %i: index, %size: index) -> (tensor<?x?x?xf32>) {

0 commit comments

Comments
 (0)