From 6b3a3e109f738c9199a483e4b4d457797a4d2ef2 Mon Sep 17 00:00:00 2001 From: Andi Drebes Date: Tue, 21 May 2024 14:42:39 +0200 Subject: [PATCH] [MLIR][Tensor] Canonicalize fully covering slice insertions into tensors with unit prefixes If the destination tensor of the insertion of a slice has the same number of elements as the slice, but with a shape that only differs by a prefix of unit-sized dimensions, and if the insertion happens at zero offsets, unit strides and with a size matching the size of the destination, the insertion covers all elements of the destination. The result of such an insertion is equivalent to the slice, with its shape expanded to the type of the destination. Example: ```mlir %0 = tensor.insert_slice %slice into %x[0, 0, 0, 0, 0][1, 1, 1, 16, 32][1, 1, 1, 1, 1] : tensor<16x32xf32> into tensor<1x1x1x16x32xf32> ``` folds into: ```mlir %0 = tensor.expand_shape %slice[[0,1,2,3], [4]] : tensor<16x32xf32> into tensor<1x1x1x16x32xf32> ``` This commit adds a canonicalization pattern for `InsertSliceOp` that implements this pattern. --- mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 88 +++++++++++++++++++++- mlir/test/Dialect/Tensor/canonicalize.mlir | 12 +++ 2 files changed, 99 insertions(+), 1 deletion(-) diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp index 8545c7b9af8f7..52d7005470232 100644 --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -2835,6 +2835,91 @@ struct InsertSliceOpSourceCastInserter final return success(); } }; + +/// If the destination tensor of the insertion of a slice has the same +/// number of elements as the slice, but with a shape that only +/// differs by a prefix of unit-sized dimensions, and if the insertion +/// happens at zero offsets, unit strides and with a size matching the +/// size of the destination, the insertion covers all elements of the +/// destination. 
The result of such an insertion is equivalent to the +/// slice, with its shape expanded to the type of the destination. +/// +/// Example: +/// ```mlir +/// %0 = tensor.insert_slice %slice into +/// %x[0, 0, 0, 0, 0][1, 1, 1, 16, 32][1, 1, 1, 1, 1] : +/// tensor<16x32xf32> into tensor<1x1x1x16x32xf32> +/// ``` +/// +/// folds into: +/// +/// ```mlir +/// %0 = tensor.expand_shape %slice[[0,1,2,3], [4]] : +/// tensor<16x32xf32> into tensor<1x1x1x16x32xf32> +/// ``` +struct InsertSliceOpFullRewriteCanonicalizer final + : public OpRewritePattern<InsertSliceOp> { + using OpRewritePattern<InsertSliceOp>::OpRewritePattern; + LogicalResult matchAndRewrite(InsertSliceOp insertSliceOp, + PatternRewriter &rewriter) const override { + RankedTensorType sourceType = insertSliceOp.getSourceType(); + RankedTensorType resultType = insertSliceOp.getType(); + + if (sourceType != resultType && sourceType.hasStaticShape() && + resultType.hasStaticShape() && + isSameSizedSuffixShape(resultType.getShape(), sourceType.getShape()) && + succeeded(foldIdentityOffsetSizeAndStrideOpInterface(insertSliceOp, + resultType))) { + SmallVector<ReassociationIndices> reassocIndices; + + // Number of leading dimensions with unit size that are not + // shared with the source type + size_t unitPrefixLength = + resultType.getShape().size() - sourceType.getShape().size(); + + // Compose mapping of leading dimensions with unit size and the + // first common dimension to the first dimension of the source + // tensor + ReassociationIndices unitPrefixExpansion; + + size_t dim; + for (dim = 0; dim < unitPrefixLength; dim++) + unitPrefixExpansion.push_back(dim); + + unitPrefixExpansion.push_back(unitPrefixLength); + reassocIndices.push_back(unitPrefixExpansion); + + // Map remaining common dimensions of the source to the target + for (dim = dim + 1; dim < resultType.getShape().size(); dim++) { + reassocIndices.push_back({static_cast<int64_t>(dim)}); + } + + rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>( + insertSliceOp, insertSliceOp.getType(), insertSliceOp.getSource(), + reassocIndices); 
+ + return mlir::success(); + } + + return mlir::failure(); + } + +private: + /// Checks if `suffix` is a suffix of `shape` and all preceding + /// elements in `shape` are ones. + static bool isSameSizedSuffixShape(ArrayRef<int64_t> shape, + ArrayRef<int64_t> suffix) { + if (shape.size() >= suffix.size()) { + ArrayRef<int64_t> prefix = shape.take_front(shape.size() - suffix.size()); + ArrayRef<int64_t> remainder = shape.take_back(suffix.size()); + + return llvm::all_of(prefix, [](int64_t d) { return d == 1; }) && + remainder == suffix; + } + + return false; + } +}; } // namespace llvm::SmallBitVector InsertSliceOp::getDroppedDims() { @@ -2845,7 +2930,8 @@ void InsertSliceOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { results.add<InsertSliceOpConstantArgumentFolder<InsertSliceOp>, InsertSliceOpCastFolder, - InsertSliceOpSourceCastInserter<InsertSliceOp>>(context); + InsertSliceOpSourceCastInserter<InsertSliceOp>, + InsertSliceOpFullRewriteCanonicalizer>(context); } Value mlir::tensor::createCanonicalRankReducingInsertSliceOp(OpBuilder &b, diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir index 914e5e8b8c4b8..8e66ef9f89c74 100644 --- a/mlir/test/Dialect/Tensor/canonicalize.mlir +++ b/mlir/test/Dialect/Tensor/canonicalize.mlir @@ -542,6 +542,18 @@ func.func @trivial_insert_slice(%arg0 : tensor<4x6x16x32xi8>, %arg1 : tensor<4x6 // ----- +// CHECK-LABEL: func @trivial_insert_slice_unit_prefix +// CHECK-SAME: %[[ARG0:.[a-z0-9A-Z_]+]]: tensor<4x6x16x32xi8> +// CHECK-NOT: tensor.insert_slice +// CHECK: %[[EXPANDED:.[a-z0-9A-Z_]+]] = tensor.expand_shape %[[ARG0]] {{\[\[0, 1, 2, 3\], \[4\], \[5\], \[6\]\] output}}_shape {{\[1, 1, 1, 4, 6, 16, 32\]}} : tensor<4x6x16x32xi8> into tensor<1x1x1x4x6x16x32xi8> +// CHECK: return %[[EXPANDED]] : tensor<1x1x1x4x6x16x32xi8> +func.func @trivial_insert_slice_unit_prefix(%arg0 : tensor<4x6x16x32xi8>, %arg1 : tensor<1x1x1x4x6x16x32xi8>) -> tensor<1x1x1x4x6x16x32xi8> { + %0 = tensor.insert_slice %arg0 into %arg1[0, 0, 0, 0, 0, 0, 0] [1, 1, 1, 4, 6, 16, 32] [1, 1, 1, 1, 
1, 1, 1] : tensor<4x6x16x32xi8> into tensor<1x1x1x4x6x16x32xi8> + return %0 : tensor<1x1x1x4x6x16x32xi8> +} + +// ----- + // CHECK-LABEL: func @empty_insert_slice // CHECK-SAME: %[[ARG0:.[a-z0-9A-Z_]+]]: tensor<0x2xi8> // CHECK-SAME: %[[ARG1:.[a-z0-9A-Z_]+]]: tensor<3x3xi8>