
[mlir][tensor] add gather decompose pattern #119805


Open
wants to merge 1 commit into main

Conversation

zhczhong
Member

Currently, tensor.gather cannot be bufferized and further lowered. This patch adds a decomposition pattern that rewrites tensor.gather into a series of bufferizable ops (tensor.empty, linalg.generic, tensor.extract_slice).
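For illustration, the rewrite (taken from the pattern's documentation comment in GatherOpPatterns.cpp; SSA names are illustrative) turns

%2 = tensor.gather %0[%1] gather_dims([0]) : (tensor<7x128xf16>, tensor<1x7x1xindex>) -> tensor<1x7x128xf16>

into

%empty = tensor.empty() : tensor<1x7x128xf16>
%result = linalg.generic {
    indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, 0)>,
                     affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
    iterator_types = ["parallel", "parallel", "parallel"]}
    ins(%1 : tensor<1x7x1xindex>) outs(%empty : tensor<1x7x128xf16>) {
  ^bb0(%in: index, %out: f16):
    %idx = linalg.index 2 : index
    %extracted = tensor.extract %0[%in, %idx] : tensor<7x128xf16>
    linalg.yield %extracted : f16
} -> tensor<1x7x128xf16>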

@llvmbot
Member

llvmbot commented Dec 13, 2024

@llvm/pr-subscribers-mlir-tensor

@llvm/pr-subscribers-mlir

Author: zhicong zhong (zhczhong)

Changes

Currently, tensor.gather cannot be bufferized and further lowered. This patch adds a decomposition pattern that rewrites tensor.gather into a series of bufferizable ops (tensor.empty, linalg.generic, tensor.extract_slice).
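When there is more than one gather dimension, the pattern first splits the index tensor with tensor.extract_slice, one slice per gather dimension, and feeds all slices into a single linalg.generic. A minimal sketch with hypothetical SSA names, mirroring the @gather_multiple_gather_dim test below:

// One tensor<2x3x1xindex> slice per gather dim, taken from the last axis
// of the tensor<2x3x2xindex> index operand.
%slice0 = tensor.extract_slice %indices[0, 0, 0] [2, 3, 1] [1, 1, 1] : tensor<2x3x2xindex> to tensor<2x3x1xindex>
%slice1 = tensor.extract_slice %indices[0, 0, 1] [2, 3, 1] [1, 1, 1] : tensor<2x3x2xindex> to tensor<2x3x1xindex>
// %slice0 and %slice1 then become the ins operands of the linalg.generic.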


Full diff: https://github.com/llvm/llvm-project/pull/119805.diff

6 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td (+11)
  • (modified) mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h (+7)
  • (modified) mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp (+5)
  • (modified) mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt (+1)
  • (added) mlir/lib/Dialect/Tensor/Transforms/GatherOpPatterns.cpp (+166)
  • (added) mlir/test/Dialect/Tensor/decompose-gather.mlir (+66)
diff --git a/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td b/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td
index 81bab1b0c82f7a..2be2d019e11228 100644
--- a/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td
@@ -189,4 +189,15 @@ def TypeConversionCastShapeDynamicDimsOp : Op<Transform_Dialect,
       "(`ignore_dynamic_info` $ignore_dynamic_info^)? attr-dict";
 }
 
+def ApplyDecomposeTensorGatherPatternsOp : Op<Transform_Dialect,
+    "apply_patterns.tensor.decompose_gather",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Indicates that tensor.gather ops should be decomposed into a chain of
+    tensor.extract_slice and linalg.generic ops that extract the elements from the source.
+  }];
+
+  let assemblyFormat = "attr-dict";
+}
+
 #endif // TENSOR_TRANSFORM_OPS
diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
index ae695e0326ca1a..fa73f74d0be66d 100644
--- a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
@@ -102,6 +102,13 @@ using ControlFoldFn = std::function<bool(OpOperand *)>;
 void populateRewriteAsConstantPatterns(RewritePatternSet &patterns,
                                        const ControlFoldFn &controlFn);
 
+/// Populates `patterns` with patterns that decompose `tensor.gather` into
+/// `tensor.empty` and `linalg.generic`, followed by a chain
+/// of `tensor.extract_slice` operations on the inputs. This is intended to be
+/// used as a tensor -> linalg lowering that decomposes gather such
+/// that it can be bufferized into a sequence of bufferizable ops.
+void populateDecomposeTensorGatherPatterns(RewritePatternSet &patterns);
+
 //===----------------------------------------------------------------------===//
 // Transform helpers
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp b/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp
index 99199252710f99..cb2d01df40b8d8 100644
--- a/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp
+++ b/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp
@@ -143,6 +143,11 @@ void transform::ApplyRewriteTensorOpsAsConstantPatternsOp::populatePatterns(
     tensor::populateRewriteAsConstantPatterns(patterns, defaultControlFn);
 }
 
+void transform::ApplyDecomposeTensorGatherPatternsOp::populatePatterns(
+    RewritePatternSet &patterns) {
+  tensor::populateDecomposeTensorGatherPatterns(patterns);
+}
+
 //===----------------------------------------------------------------------===//
 // TypeConversionCastTensorShapeOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
index cc6275fee671aa..f1a23e5e3bfbfc 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
@@ -4,6 +4,7 @@ add_mlir_dialect_library(MLIRTensorTransforms
   EmptyOpPatterns.cpp
   ExtractSliceFromReshapeUtils.cpp
   FoldTensorSubsetOps.cpp
+  GatherOpPatterns.cpp
   IndependenceTransforms.cpp
   MergeConsecutiveInsertExtractSlicePatterns.cpp
   PackAndUnpackPatterns.cpp
diff --git a/mlir/lib/Dialect/Tensor/Transforms/GatherOpPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/GatherOpPatterns.cpp
new file mode 100644
index 00000000000000..5905ee049228a5
--- /dev/null
+++ b/mlir/lib/Dialect/Tensor/Transforms/GatherOpPatterns.cpp
@@ -0,0 +1,166 @@
+//===- GatherOpPatterns.cpp - Patterns related to tensor.gather lowering -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
+#include "mlir/IR/PatternMatch.h"
+
+using namespace mlir;
+using namespace mlir::tensor;
+
+namespace {
+
+/// Decompose `tensor.gather` into `linalg.generic`.
+///
+/// %2 = tensor.gather %0[%1] gather_dims([0]) : (tensor<7x128xf16>,
+/// tensor<1x7x1xindex>) -> tensor<1x7x128xf16>
+///
+/// Becomes
+///
+/// %empty = tensor.empty() : tensor<1x7x128xf16>
+/// %result = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1,
+/// 0)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types =
+/// ["parallel", "parallel", "parallel"]} ins(%1 : tensor<1x7x1xindex>)
+/// outs(%empty : tensor<1x7x128xf16>) {
+///    ^bb0(%in: index, %out: f16):
+///      %idx = linalg.index 2 : index
+///      %extracted = tensor.extract %0[%in, %idx] : tensor<7x128xf16>
+///      linalg.yield %extracted : f16
+///    } -> tensor<1x7x128xf16>
+struct DecomposeTensorGatherOp : public OpRewritePattern<tensor::GatherOp> {
+  using OpRewritePattern<tensor::GatherOp>::OpRewritePattern;
+
+  SmallVector<OpFoldResult> getDstMixedSizes(PatternRewriter &rewriter,
+                                             Location loc,
+                                             tensor::GatherOp gatherOp) const {
+    SmallVector<OpFoldResult> dstSize =
+        tensor::getMixedSizes(rewriter, loc, gatherOp.getResult());
+    SmallVector<OpFoldResult> indexSize =
+        tensor::getMixedSizes(rewriter, loc, gatherOp.getIndices());
+    SmallVector<OpFoldResult> srcSize =
+        tensor::getMixedSizes(rewriter, loc, gatherOp.getSource());
+    SmallVector<int64_t> gatherDims(gatherOp.getGatherDims());
+    bool isShrinkDst = (indexSize.size() - 1) + srcSize.size() ==
+                       dstSize.size() + gatherDims.size();
+    for (size_t i = 0; i < indexSize.size() - 1; i++) {
+      dstSize[i] = indexSize[i];
+    }
+    auto cnt = 0;
+    for (size_t i = indexSize.size() - 1; i < dstSize.size(); i++) {
+      while (isShrinkDst && llvm::find(gatherDims, cnt) != gatherDims.end()) {
+        cnt++;
+      }
+      dstSize[i] = llvm::find(gatherDims, cnt) == gatherDims.end()
+                       ? srcSize[cnt]
+                       : getAsIndexOpFoldResult(rewriter.getContext(), 1);
+      cnt++;
+    }
+    return dstSize;
+  }
+
+  LogicalResult matchAndRewrite(tensor::GatherOp gatherOp,
+                                PatternRewriter &rewriter) const override {
+    OpBuilder::InsertionGuard g(rewriter);
+    rewriter.setInsertionPoint(gatherOp);
+    Location loc = gatherOp.getLoc();
+    SmallVector<int64_t> gatherDims(gatherOp.getGatherDims());
+
+    // Create the destination tensor for the linalg.generic output.
+    RankedTensorType dstType = gatherOp.getResultType();
+    Value dstTensor = rewriter.create<tensor::EmptyOp>(
+        loc, getDstMixedSizes(rewriter, loc, gatherOp),
+        dstType.getElementType());
+
+    // Split the index tensor to create the linalg.generic inputs.
+    SmallVector<Value> indexTensors;
+    Value originIndexTensor = gatherOp.getIndices();
+    SmallVector<OpFoldResult> indexTensorSize =
+        tensor::getMixedSizes(rewriter, loc, originIndexTensor);
+    SmallVector<OpFoldResult> indexTensorStride(
+        indexTensorSize.size(),
+        getAsIndexOpFoldResult(rewriter.getContext(), 1));
+    SmallVector<OpFoldResult> indexTensorOffset(
+        indexTensorSize.size(),
+        getAsIndexOpFoldResult(rewriter.getContext(), 0));
+    indexTensorSize[indexTensorSize.size() - 1] =
+        getAsIndexOpFoldResult(rewriter.getContext(), 1);
+
+    for (size_t cnt = 0; cnt < gatherDims.size(); cnt++) {
+      indexTensorOffset[indexTensorSize.size() - 1] =
+          getAsIndexOpFoldResult(rewriter.getContext(), cnt);
+      Value indexTensor = rewriter.create<tensor::ExtractSliceOp>(
+          loc, originIndexTensor, indexTensorOffset, indexTensorSize,
+          indexTensorStride);
+      indexTensors.emplace_back(indexTensor);
+    }
+
+    // Create the affine maps.
+    SmallVector<AffineMap> affineMaps;
+    SmallVector<AffineExpr> dimExprs;
+    size_t dstRank = dstType.getShape().size();
+    for (unsigned i = 0; i < indexTensorSize.size() - 1; ++i)
+      dimExprs.push_back(rewriter.getAffineDimExpr(i));
+    dimExprs.push_back(getAffineConstantExpr(0, rewriter.getContext()));
+
+    for (size_t cnt = 0; cnt < gatherDims.size(); cnt++) {
+      AffineMap currentMap =
+          AffineMap::get(/*dimCount=*/dstRank, /*symbolCount=*/0, dimExprs,
+                         rewriter.getContext());
+      affineMaps.emplace_back(currentMap);
+    }
+    affineMaps.emplace_back(rewriter.getMultiDimIdentityMap(dstRank));
+
+    // Create the iterator types array.
+    SmallVector<utils::IteratorType> iteratorTypesArray(
+        dstRank, utils::IteratorType::parallel);
+
+    // Check whether the gather op is valid.
+    size_t srcRank = gatherOp.getSourceType().getShape().size();
+    assert(((indexTensorSize.size() - 1) + srcRank == dstRank ||
+            (indexTensorSize.size() - 1) + srcRank ==
+                dstRank + gatherDims.size()) &&
+           "Expected: index_size - 1 + source_size == dst_size or dst_szie - "
+           "gather_size. \n");
+    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
+        gatherOp, TypeRange(dstType), indexTensors, ValueRange{dstTensor},
+        affineMaps, iteratorTypesArray,
+        [&](OpBuilder &b, Location loc, ValueRange args) {
+          SmallVector<Value> indexValues(srcRank);
+          bool isShrinkDst = (indexTensorSize.size() - 1) + srcRank ==
+                             dstRank + gatherDims.size();
+          int cnt = 0;
+          for (auto i = indexTensorSize.size() - 1; i < dstRank; i++) {
+            while (isShrinkDst &&
+                   llvm::find(gatherDims, cnt) != gatherDims.end()) {
+              cnt++;
+            }
+            indexValues[cnt] = b.create<linalg::IndexOp>(loc, i);
+            cnt++;
+          }
+          for (auto &&[i, dim] : llvm::enumerate(gatherDims)) {
+            indexValues[dim] = args[i];
+          }
+
+          Value extract = b.create<tensor::ExtractOp>(loc, gatherOp.getSource(),
+                                                      indexValues);
+          b.create<linalg::YieldOp>(loc, extract);
+        });
+    return success();
+  }
+};
+
+} // namespace
+
+void mlir::tensor::populateDecomposeTensorGatherPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<DecomposeTensorGatherOp>(patterns.getContext());
+}
diff --git a/mlir/test/Dialect/Tensor/decompose-gather.mlir b/mlir/test/Dialect/Tensor/decompose-gather.mlir
new file mode 100644
index 00000000000000..587dfc8cc7e2fc
--- /dev/null
+++ b/mlir/test/Dialect/Tensor/decompose-gather.mlir
@@ -0,0 +1,66 @@
+// RUN: mlir-opt -split-input-file -transform-interpreter -cse --mlir-print-local-scope %s | FileCheck %s
+
+/// CHECK-LABEL: @gather_single_gather_dim
+func.func @gather_single_gather_dim(%arg0: tensor<2x2x2x2xf32>, %arg1: tensor<2x3x1xindex>) -> tensor<2x3x2x2x2xf32> {
+    /// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<2x3x2x2x2xf32>
+    /// CHECK: linalg.generic {{.*}} iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG1:.*]] : tensor<2x3x1xindex>) outs(%[[EMPTY]] : tensor<2x3x2x2x2xf32>)
+    %1 = tensor.gather %arg0[%arg1] gather_dims([1]) : (tensor<2x2x2x2xf32>, tensor<2x3x1xindex>) -> tensor<2x3x2x2x2xf32>
+  return %1 : tensor<2x3x2x2x2xf32>
+}
+
+/// CHECK-LABEL: @gather_single_gather_dim_no_shrink
+func.func @gather_single_gather_dim_no_shrink(%arg0: tensor<2x2x2x2xf32>, %arg1: tensor<2x3x1xindex>) -> tensor<2x3x2x1x2x2xf32> {
+    /// CHECK: %[[EMPTY1:.*]] = tensor.empty() : tensor<2x3x2x1x2x2xf32>
+    /// CHECK: linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, 0)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG1:.*]] : tensor<2x3x1xindex>) outs(%[[EMPTY1]] : tensor<2x3x2x1x2x2xf32>)
+    %1 = tensor.gather %arg0[%arg1] gather_dims([1]) : (tensor<2x2x2x2xf32>, tensor<2x3x1xindex>) -> tensor<2x3x2x1x2x2xf32>
+  return %1 : tensor<2x3x2x1x2x2xf32>
+}
+
+/// CHECK-LABEL: @gather_multiple_gather_dim
+func.func @gather_multiple_gather_dim(%arg0: tensor<2x2x2x2xf32>, %arg1: tensor<2x3x2xindex>) -> tensor<2x3x2x2xf32> {
+    /// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<2x3x2x2xf32>
+    /// CHECK: %[[EXTRACTSLICE1:.*]] = tensor.extract_slice %[[ARG1:.*]][0, 0, 0] [2, 3, 1] [1, 1, 1] : tensor<2x3x2xindex> to tensor<2x3x1xindex>
+    /// CHECK: %[[EXTRACTSLICE2:.*]] = tensor.extract_slice %[[ARG1]][0, 0, 1] [2, 3, 1] [1, 1, 1] : tensor<2x3x2xindex> to tensor<2x3x1xindex>
+    /// CHECK: linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[EXTRACTSLICE1]], %[[EXTRACTSLICE2]] : tensor<2x3x1xindex>, tensor<2x3x1xindex>) outs(%[[EMPTY]] : tensor<2x3x2x2xf32>)
+    %1 = tensor.gather %arg0[%arg1] gather_dims([1, 2]) : (tensor<2x2x2x2xf32>, tensor<2x3x2xindex>) -> tensor<2x3x2x2xf32>
+  return %1 : tensor<2x3x2x2xf32>
+}
+
+/// CHECK-LABEL: @gather_multiple_gather_dim_no_shrink
+func.func @gather_multiple_gather_dim_no_shrink(%arg0: tensor<2x2x2x2xf32>, %arg1: tensor<2x3x2xindex>) -> tensor<2x3x2x1x1x2xf32> {
+    %1 = tensor.gather %arg0[%arg1] gather_dims([1, 2]) : (tensor<2x2x2x2xf32>, tensor<2x3x2xindex>) -> tensor<2x3x2x1x1x2xf32>
+  return %1 : tensor<2x3x2x1x1x2xf32>
+}
+
+/// CHECK-LABEL: @gather_single_gather_dim_dynamic
+func.func @gather_single_gather_dim_dynamic(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<2x3x1xindex>) -> tensor<2x3x?x?x?xf32> {
+    /// CHECK: %[[DIM1:.*]] = tensor.dim
+    /// CHECK: %[[DIM2:.*]] = tensor.dim
+    /// CHECK: %[[DIM3:.*]] = tensor.dim
+    /// CHECK: %[[EMPTY:.*]] = tensor.empty(%[[DIM1]], %[[DIM2]], %[[DIM3]]) : tensor<2x3x?x?x?xf32>
+    /// CHECK: linalg.generic {{.*}} iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG1:.*]] : tensor<2x3x1xindex>) outs(%[[EMPTY]] : tensor<2x3x?x?x?xf32>)
+    %1 = tensor.gather %arg0[%arg1] gather_dims([1]) : (tensor<?x?x?x?xf32>, tensor<2x3x1xindex>) -> tensor<2x3x?x?x?xf32>
+  return %1 : tensor<2x3x?x?x?xf32>
+}
+
+/// CHECK-LABEL: @gather_multiple_gather_dim_no_shrink_dynamic
+func.func @gather_multiple_gather_dim_no_shrink_dynamic(%arg0: tensor<2x2x2x2xf32>, %arg1: tensor<?x?x2xindex>) -> tensor<?x?x2x1x1x2xf32> {
+    /// CHECK: %[[DIM1:.*]] = tensor.dim
+    /// CHECK: %[[DIM2:.*]] = tensor.dim
+    /// CHECK: %[[EMPTY:.*]] = tensor.empty(%[[DIM1]], %[[DIM2]]) : tensor<?x?x2x1x1x2xf32>
+    /// CHECK: %[[EXTRACTSLICE1:.*]] = tensor.extract_slice %[[ARG1:.*]][0, 0, 0] [%[[DIM1]], %[[DIM2]], 1] [1, 1, 1] : tensor<?x?x2xindex> to tensor<?x?x1xindex>
+    /// CHECK: %[[EXTRACTSLICE2:.*]] = tensor.extract_slice %[[ARG1]][0, 0, 1] [%[[DIM1]], %[[DIM2]], 1] [1, 1, 1] : tensor<?x?x2xindex> to tensor<?x?x1xindex>
+    /// CHECK: linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, 0)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, 0)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%[[EXTRACTSLICE1]], %[[EXTRACTSLICE2]] : tensor<?x?x1xindex>, tensor<?x?x1xindex>) outs(%[[EMPTY]] : tensor<?x?x2x1x1x2xf32>)
+    %1 = tensor.gather %arg0[%arg1] gather_dims([1, 2]) : (tensor<2x2x2x2xf32>, tensor<?x?x2xindex>) -> tensor<?x?x2x1x1x2xf32>
+  return %1 : tensor<?x?x2x1x1x2xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
+    %func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func">
+    transform.apply_patterns to %func_op {
+      transform.apply_patterns.tensor.decompose_gather
+    } : !transform.op<"func.func">
+    transform.yield
+  }
+}
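
For reference, a sketch of the IR the first test case (@gather_single_gather_dim) should produce after the rewrite, reconstructed from the pattern's documentation comment (hypothetical SSA names, not verbatim pass output):

// Identity slice of the index tensor (one slice per gather dim).
%idx = tensor.extract_slice %arg1[0, 0, 0] [2, 3, 1] [1, 1, 1] : tensor<2x3x1xindex> to tensor<2x3x1xindex>
%empty = tensor.empty() : tensor<2x3x2x2x2xf32>
%res = linalg.generic {
    indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, 0)>,
                     affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>],
    iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]}
    ins(%idx : tensor<2x3x1xindex>) outs(%empty : tensor<2x3x2x2x2xf32>) {
  ^bb0(%in: index, %out: f32):
    // Non-gathered source dims come from linalg.index; the gathered dim
    // (dim 1) comes from the index tensor element %in.
    %i0 = linalg.index 2 : index
    %i2 = linalg.index 3 : index
    %i3 = linalg.index 4 : index
    %extracted = tensor.extract %arg0[%i0, %in, %i2, %i3] : tensor<2x2x2x2xf32>
    linalg.yield %extracted : f32
} -> tensor<2x3x2x2x2xf32>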
