diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 8545c7b9af8f7..52d7005470232 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -2835,6 +2835,91 @@ struct InsertSliceOpSourceCastInserter final
     return success();
   }
 };
+
+/// If the destination tensor of a slice insertion has the same number
+/// of elements as the inserted slice and a shape that differs only by
+/// a prefix of unit-sized dimensions, and if the insertion uses zero
+/// offsets, unit strides, and sizes matching the destination shape,
+/// then the insertion overwrites every element of the destination.
+/// Such an insertion is therefore equivalent to expanding the shape of
+/// the slice to the type of the destination.
+///
+/// Example:
+/// ```mlir
+/// %0 = tensor.insert_slice %slice into
+///        %x[0, 0, 0, 0, 0][1, 1, 1, 16, 32][1, 1, 1, 1, 1] :
+///        tensor<16x32xf32> into tensor<1x1x1x16x32xf32>
+/// ```
+///
+/// folds into:
+///
+/// ```mlir
+/// %0 = tensor.expand_shape %slice[[0,1,2,3], [4]] :
+///        tensor<16x32xf32> into tensor<1x1x1x16x32xf32>
+/// ```
+struct InsertSliceOpFullRewriteCanonicalizer final
+    : public OpRewritePattern<InsertSliceOp> {
+  using OpRewritePattern<InsertSliceOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(InsertSliceOp insertSliceOp,
+                                PatternRewriter &rewriter) const override {
+    RankedTensorType sourceType = insertSliceOp.getSourceType();
+    RankedTensorType resultType = insertSliceOp.getType();
+
+    if (sourceType != resultType && sourceType.hasStaticShape() &&
+        resultType.hasStaticShape() &&
+        isSameSizedSuffixShape(resultType.getShape(), sourceType.getShape()) &&
+        succeeded(foldIdentityOffsetSizeAndStrideOpInterface(insertSliceOp,
+                                                             resultType))) {
+      SmallVector<ReassociationIndices> reassocIndices;
+
+      // Number of leading dimensions with unit size that are not
+      // shared with the source type.
+      size_t unitPrefixLength =
+          resultType.getShape().size() - sourceType.getShape().size();
+
+      // Compose the mapping of the leading unit-sized dimensions and
+      // the first common dimension to the first dimension of the
+      // source tensor.
+      ReassociationIndices unitPrefixExpansion;
+
+      size_t dim;
+      for (dim = 0; dim < unitPrefixLength; dim++)
+        unitPrefixExpansion.push_back(dim);
+
+      unitPrefixExpansion.push_back(unitPrefixLength);
+      reassocIndices.push_back(unitPrefixExpansion);
+
+      // Map the remaining common dimensions of the source to the target.
+      for (dim = dim + 1; dim < resultType.getShape().size(); dim++) {
+        reassocIndices.push_back({static_cast<int64_t>(dim)});
+      }
+
+      rewriter.replaceOpWithNewOp<ExpandShapeOp>(
+          insertSliceOp, insertSliceOp.getType(), insertSliceOp.getSource(),
+          reassocIndices);
+
+      return mlir::success();
+    }
+
+    return mlir::failure();
+  }
+
+private:
+  /// Checks whether `suffix` is a suffix of `shape` and all preceding
+  /// elements of `shape` are ones.
+  static bool isSameSizedSuffixShape(ArrayRef<int64_t> shape,
+                                     ArrayRef<int64_t> suffix) {
+    if (shape.size() >= suffix.size()) {
+      ArrayRef<int64_t> prefix = shape.take_front(shape.size() - suffix.size());
+      ArrayRef<int64_t> remainder = shape.take_back(suffix.size());
+
+      return llvm::all_of(prefix, [](int64_t d) { return d == 1; }) &&
+             remainder == suffix;
+    }
+
+    return false;
+  }
+};
 } // namespace
 
 llvm::SmallBitVector InsertSliceOp::getDroppedDims() {
@@ -2845,7 +2930,8 @@ void InsertSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                 MLIRContext *context) {
   results.add<InsertSliceOpConstantArgumentFolder<InsertSliceOp>,
               InsertSliceOpCastFolder<InsertSliceOp>,
-              InsertSliceOpSourceCastInserter<InsertSliceOp>>(context);
+              InsertSliceOpSourceCastInserter<InsertSliceOp>,
+              InsertSliceOpFullRewriteCanonicalizer>(context);
 }
 
 Value mlir::tensor::createCanonicalRankReducingInsertSliceOp(OpBuilder &b,
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 914e5e8b8c4b8..8e66ef9f89c74 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -542,6 +542,18 @@ func.func @trivial_insert_slice(%arg0 : tensor<4x6x16x32xi8>, %arg1 : tensor<4x6
 
 // -----
+
+// CHECK-LABEL: func @trivial_insert_slice_unit_prefix
+// CHECK-SAME: %[[ARG0:.[a-z0-9A-Z_]+]]: tensor<4x6x16x32xi8>
+// CHECK-NOT: tensor.insert_slice
+// CHECK: %[[EXPANDED:.[a-z0-9A-Z_]+]] = tensor.expand_shape %[[ARG0]] {{\[\[0, 1, 2, 3\], \[4\], \[5\], \[6\]\] output}}_shape {{\[1, 1, 1, 4, 6, 16, 32\]}} : tensor<4x6x16x32xi8> into tensor<1x1x1x4x6x16x32xi8>
+// CHECK: return %[[EXPANDED]] : tensor<1x1x1x4x6x16x32xi8>
+func.func @trivial_insert_slice_unit_prefix(%arg0 : tensor<4x6x16x32xi8>, %arg1 : tensor<1x1x1x4x6x16x32xi8>) -> tensor<1x1x1x4x6x16x32xi8> {
+  %0 = tensor.insert_slice %arg0 into %arg1[0, 0, 0, 0, 0, 0, 0] [1, 1, 1, 4, 6, 16, 32] [1, 1, 1, 1, 1, 1, 1] : tensor<4x6x16x32xi8> into tensor<1x1x1x4x6x16x32xi8>
+  return %0 : tensor<1x1x1x4x6x16x32xi8>
+}
+
+// -----
 
 // CHECK-LABEL: func @empty_insert_slice
 // CHECK-SAME: %[[ARG0:.[a-z0-9A-Z_]+]]: tensor<0x2xi8>
 // CHECK-SAME: %[[ARG1:.[a-z0-9A-Z_]+]]: tensor<3x3xi8>
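
For reviewers who want to exercise the new pattern in isolation, here is a minimal sketch of a reproducer. The function name `@full_cover`, the file name, and the `mlir-opt --canonicalize` invocation are illustrative assumptions and not part of the patch; the expected output follows the pattern's own documentation and the CHECK lines added above.

```mlir
// Run through the canonicalizer, e.g. `mlir-opt --canonicalize repro.mlir`.
// The insertion uses zero offsets, unit strides, and sizes equal to the
// destination shape, and the destination shape only adds a 1x1x1 prefix to
// the slice shape, so the new pattern should apply.
func.func @full_cover(%slice : tensor<16x32xf32>, %dest : tensor<1x1x1x16x32xf32>) -> tensor<1x1x1x16x32xf32> {
  %0 = tensor.insert_slice %slice into %dest[0, 0, 0, 0, 0] [1, 1, 1, 16, 32] [1, 1, 1, 1, 1]
      : tensor<16x32xf32> into tensor<1x1x1x16x32xf32>
  return %0 : tensor<1x1x1x16x32xf32>
}

// Expected canonical form: the insert_slice is rewritten to an expand_shape
// of the source, grouping the unit prefix with the source's first dimension.
//   %expanded = tensor.expand_shape %slice [[0, 1, 2, 3], [4]]
//       output_shape [1, 1, 1, 16, 32]
//       : tensor<16x32xf32> into tensor<1x1x1x16x32xf32>
```

Note that the predicate only accepts a prefix of all-unit dimensions: inserting the same `tensor<16x32xf32>` slice so that it covers a `tensor<16x1x32xf32>` destination is not matched by this pattern, because `isSameSizedSuffixShape` rejects the interleaved unit dimension.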