Skip to content

[mlir][tensor] add gather decompose pattern #119805

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -189,4 +189,15 @@ def TypeConversionCastShapeDynamicDimsOp : Op<Transform_Dialect,
"(`ignore_dynamic_info` $ignore_dynamic_info^)? attr-dict";
}

def ApplyDecomposeTensorGatherPatternsOp : Op<Transform_Dialect,
"apply_patterns.tensor.decompose_gather",
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
let description = [{
Indicates that tensor.gather ops should be decomposed into a chain of
tensor.extract_slice and linalg.generic to extract the element from source.
}];

let assemblyFormat = "attr-dict";
}

#endif // TENSOR_TRANSFORM_OPS
7 changes: 7 additions & 0 deletions mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,13 @@ using ControlFoldFn = std::function<bool(OpOperand *)>;
void populateRewriteAsConstantPatterns(RewritePatternSet &patterns,
const ControlFoldFn &controlFn);

/// Populates `patterns` with patterns that decompose `tensor.gather` into
/// `tensor.empty` and `linalg.geric`, followed by a chain
/// of `tensor.extract_slice` operations on the inputs. This is intended to be
/// used as a tensor -> linalg lowering that decomposes gather such
/// that it can be bufferized into a sequence of bufferized op.
void populateDecomposeTensorGatherPatterns(RewritePatternSet &patterns);

//===----------------------------------------------------------------------===//
// Transform helpers
//===----------------------------------------------------------------------===//
Expand Down
5 changes: 5 additions & 0 deletions mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,11 @@ void transform::ApplyRewriteTensorOpsAsConstantPatternsOp::populatePatterns(
tensor::populateRewriteAsConstantPatterns(patterns, defaultControlFn);
}

void transform::ApplyDecomposeTensorGatherPatternsOp::populatePatterns(
RewritePatternSet &patterns) {
tensor::populateDecomposeTensorGatherPatterns(patterns);
}

//===----------------------------------------------------------------------===//
// TypeConversionCastTensorShapeOp
//===----------------------------------------------------------------------===//
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ add_mlir_dialect_library(MLIRTensorTransforms
EmptyOpPatterns.cpp
ExtractSliceFromReshapeUtils.cpp
FoldTensorSubsetOps.cpp
GatherOpPatterns.cpp
IndependenceTransforms.cpp
MergeConsecutiveInsertExtractSlicePatterns.cpp
PackAndUnpackPatterns.cpp
Expand Down
166 changes: 166 additions & 0 deletions mlir/lib/Dialect/Tensor/Transforms/GatherOpPatterns.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
//===- GatherOpPatterns.cpp - Patterns related to tensor.concat lowering --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;
using namespace mlir::tensor;

namespace {

/// Decompose `tensor.gather` into `linalg.generic`.
///
/// %2 = tensor.gather %0[%1] gather_dims([0]) : (tensor<7x128xf16>,
/// tensor<1x7x1xindex>) -> tensor<1x7x128xf16>
///
/// Becomes
///
/// %empty = tensor.empty() : tensor<1x7x128xf16>
/// %14 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1,
/// 0)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types =
/// ["parallel", "parallel", "parallel"]} ins(%expanded : tensor<1x7x1xindex>)
/// outs(%13 : tensor<1x7x128xf16>) {
/// ^bb0(%in: index, %out: f16):
/// %17 = linalg.index 2 : index
/// %extracted = tensor.extract %0[%in, %17] : tensor<7x128xf16>
/// linalg.yield %extracted : f16
/// } -> tensor<1x7x128xf16>
struct DecomposeTensorGatherOp : public OpRewritePattern<tensor::GatherOp> {
using OpRewritePattern<tensor::GatherOp>::OpRewritePattern;

SmallVector<OpFoldResult> getDstMixedSizes(PatternRewriter &rewriter,
Location loc,
tensor::GatherOp gatherOp) const {
SmallVector<OpFoldResult> dstSize =
tensor::getMixedSizes(rewriter, loc, gatherOp.getResult());
SmallVector<OpFoldResult> indexSize =
tensor::getMixedSizes(rewriter, loc, gatherOp.getIndices());
SmallVector<OpFoldResult> srcSize =
tensor::getMixedSizes(rewriter, loc, gatherOp.getSource());
SmallVector<int64_t> gatherDims(gatherOp.getGatherDims());
bool isShrinkDst = (indexSize.size() - 1) + srcSize.size() ==
dstSize.size() + gatherDims.size();
for (size_t i = 0; i < indexSize.size() - 1; i++) {
dstSize[i] = indexSize[i];
}
auto cnt = 0;
for (size_t i = indexSize.size() - 1; i < dstSize.size(); i++) {
while (isShrinkDst && llvm::find(gatherDims, cnt) != gatherDims.end()) {
cnt++;
}
dstSize[i] = llvm::find(gatherDims, cnt) == gatherDims.end()
? srcSize[cnt]
: getAsIndexOpFoldResult(rewriter.getContext(), 1);
cnt++;
}
return dstSize;
}

LogicalResult matchAndRewrite(tensor::GatherOp gatherOp,
PatternRewriter &rewriter) const override {
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(gatherOp);
Location loc = gatherOp.getLoc();
SmallVector<int64_t> gatherDims(gatherOp.getGatherDims());

// create destination tensor for linalg out
RankedTensorType dstType = gatherOp.getResultType();
Value dstTensor = rewriter.create<tensor::EmptyOp>(
loc, getDstMixedSizes(rewriter, loc, gatherOp),
dstType.getElementType());

// split index tensor to create the linalg input
SmallVector<Value> indexTensors;
Value originIndexTensor = gatherOp.getIndices();
SmallVector<OpFoldResult> indexTensorSize =
tensor::getMixedSizes(rewriter, loc, originIndexTensor);
SmallVector<OpFoldResult> indexTensorStride(
indexTensorSize.size(),
getAsIndexOpFoldResult(rewriter.getContext(), 1));
SmallVector<OpFoldResult> indexTensorOffset(
indexTensorSize.size(),
getAsIndexOpFoldResult(rewriter.getContext(), 0));
indexTensorSize[indexTensorSize.size() - 1] =
getAsIndexOpFoldResult(rewriter.getContext(), 1);

for (size_t cnt = 0; cnt < gatherDims.size(); cnt++) {
indexTensorOffset[indexTensorSize.size() - 1] =
getAsIndexOpFoldResult(rewriter.getContext(), cnt);
Value indexTensor = rewriter.create<tensor::ExtractSliceOp>(
loc, originIndexTensor, indexTensorOffset, indexTensorSize,
indexTensorStride);
indexTensors.emplace_back(indexTensor);
}

// create the affine map
SmallVector<AffineMap> affineMaps;
SmallVector<AffineExpr> dimExprs;
size_t dstRank = dstType.getShape().size();
for (unsigned i = 0; i < indexTensorSize.size() - 1; ++i)
dimExprs.push_back(rewriter.getAffineDimExpr(i));
dimExprs.push_back(getAffineConstantExpr(0, rewriter.getContext()));

for (size_t cnt = 0; cnt < gatherDims.size(); cnt++) {
AffineMap currentMap =
AffineMap::get(/*dimCount=*/dstRank, /*symbolCount=*/0, dimExprs,
rewriter.getContext());
affineMaps.emplace_back(currentMap);
}
affineMaps.emplace_back(rewriter.getMultiDimIdentityMap(dstRank));

// create iterater types array
SmallVector<utils::IteratorType> iteratorTypesArray(
dstRank, utils::IteratorType::parallel);

// check whether the gather op is valid
size_t srcRank = gatherOp.getSourceType().getShape().size();
assert(((indexTensorSize.size() - 1) + srcRank == dstRank ||
(indexTensorSize.size() - 1) + srcRank ==
dstRank + gatherDims.size()) &&
"Expected: index_size - 1 + source_size == dst_size or dst_szie - "
"gather_size. \n");
rewriter.replaceOpWithNewOp<linalg::GenericOp>(
gatherOp, TypeRange(dstType), indexTensors, ValueRange{dstTensor},
affineMaps, iteratorTypesArray,
[&](OpBuilder &b, Location loc, ValueRange args) {
SmallVector<Value> indexValues(srcRank);
bool isShrinkDst = (indexTensorSize.size() - 1) + srcRank ==
dstRank + gatherDims.size();
int cnt = 0;
for (auto i = indexTensorSize.size() - 1; i < dstRank; i++) {
while (isShrinkDst &&
llvm::find(gatherDims, cnt) != gatherDims.end()) {
cnt++;
}
indexValues[cnt] = b.create<linalg::IndexOp>(loc, i);
cnt++;
}
for (auto &&[i, dim] : llvm::enumerate(gatherDims)) {
indexValues[dim] = args[i];
}

Value extract = b.create<tensor::ExtractOp>(loc, gatherOp.getSource(),
indexValues);
b.create<linalg::YieldOp>(loc, extract);
});
return success();
}
};

} // namespace

void mlir::tensor::populateDecomposeTensorGatherPatterns(
RewritePatternSet &patterns) {
patterns.add<DecomposeTensorGatherOp>(patterns.getContext());
}
66 changes: 66 additions & 0 deletions mlir/test/Dialect/Tensor/decompose-gather.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
// RUN: mlir-opt -split-input-file -transform-interpreter -cse --mlir-print-local-scope %s | FileCheck %s

/// CHECK-LABEL: @gather_single_gather_dim
func.func @gather_single_gather_dim(%arg0: tensor<2x2x2x2xf32>, %arg1: tensor<2x3x1xindex>) -> tensor<2x3x2x2x2xf32> {
/// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<2x3x2x2x2xf32>
/// CHECK: linalg.generic {{.*}} iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG1:.*]] : tensor<2x3x1xindex>) outs(%[[EMPTY:.*]] : tensor<2x3x2x2x2xf32>)
%1 = tensor.gather %arg0[%arg1] gather_dims([1]) : (tensor<2x2x2x2xf32>, tensor<2x3x1xindex>) -> tensor<2x3x2x2x2xf32>
return %1 : tensor<2x3x2x2x2xf32>
}

/// CHECK-LABEL: @gather_single_gather_dim_no_shrink
func.func @gather_single_gather_dim_no_shrink(%arg0: tensor<2x2x2x2xf32>, %arg1: tensor<2x3x1xindex>) -> tensor<2x3x2x1x2x2xf32> {
/// CHECK: %[[EMPTY1:.*]] = tensor.empty() : tensor<2x3x2x1x2x2xf32>
/// CHECK: linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, 0)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG1:.*]] : tensor<2x3x1xindex>) outs(%[[EMPTY1:.*]] : tensor<2x3x2x1x2x2xf32>)
%1 = tensor.gather %arg0[%arg1] gather_dims([1]) : (tensor<2x2x2x2xf32>, tensor<2x3x1xindex>) -> tensor<2x3x2x1x2x2xf32>
return %1 : tensor<2x3x2x1x2x2xf32>
}

/// CHECK-LABEL: @gather_multiple_gather_dim
func.func @gather_multiple_gather_dim(%arg0: tensor<2x2x2x2xf32>, %arg1: tensor<2x3x2xindex>) -> tensor<2x3x2x2xf32> {
// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<2x3x2x2xf32>
/// CHECK: %[[EXTRACTSLICE1:.*]] = tensor.extract_slice %[[ARG1:.*]][0, 0, 0] [2, 3, 1] [1, 1, 1] : tensor<2x3x2xindex> to tensor<2x3x1xindex>
/// CHECK: %[[EXTRACTSLICE2:.*]] = tensor.extract_slice %[[ARG1:.*]][0, 0, 1] [2, 3, 1] [1, 1, 1] : tensor<2x3x2xindex> to tensor<2x3x1xindex>
/// CHECK: linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, 0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[EXTRACTSLICE1:.*]], %[[EXTRACTSLICE2:.*]] : tensor<2x3x1xindex>, tensor<2x3x1xindex>) outs(%[[EMPTY:.*]] : tensor<2x3x2x2xf32>)
%1 = tensor.gather %arg0[%arg1] gather_dims([1, 2]) : (tensor<2x2x2x2xf32>, tensor<2x3x2xindex>) -> tensor<2x3x2x2xf32>
return %1 : tensor<2x3x2x2xf32>
}

/// CHECK-LABEL: @gather_multiple_gather_dim_no_shrink
func.func @gather_multiple_gather_dim_no_shrink(%arg0: tensor<2x2x2x2xf32>, %arg1: tensor<2x3x2xindex>) -> tensor<2x3x2x1x1x2xf32> {
%1 = tensor.gather %arg0[%arg1] gather_dims([1, 2]) : (tensor<2x2x2x2xf32>, tensor<2x3x2xindex>) -> tensor<2x3x2x1x1x2xf32>
return %1 : tensor<2x3x2x1x1x2xf32>
}

/// CHECK-LABEL: @gather_single_gather_dim_dynamic
func.func @gather_single_gather_dim_dynamic(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<2x3x1xindex>) -> tensor<2x3x?x?x?xf32> {
/// CHECK: %[[DIM1:.*]] = tensor.dim
/// CHECK: %[[DIM2:.*]] = tensor.dim
/// CHECK: %[[DIM3:.*]] = tensor.dim
/// CHECK: %[[EMPTY:.*]] = tensor.empty(%[[DIM1:.*]], %[[DIM2:.*]], %[[DIM3:.*]]) : tensor<2x3x?x?x?xf32>
/// CHECK: linalg.generic {{.*}} iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG0:.*]] : tensor<2x3x1xindex>) outs(%[[EMPTY:.*]] : tensor<2x3x?x?x?xf32>)
%1 = tensor.gather %arg0[%arg1] gather_dims([1]) : (tensor<?x?x?x?xf32>, tensor<2x3x1xindex>) -> tensor<2x3x?x?x?xf32>
return %1 : tensor<2x3x?x?x?xf32>
}

/// CHECK-LABEL: @gather_multiple_gather_dim_no_shrink_dynamic
func.func @gather_multiple_gather_dim_no_shrink_dynamic(%arg0: tensor<2x2x2x2xf32>, %arg1: tensor<?x?x2xindex>) -> tensor<?x?x2x1x1x2xf32> {
/// CHECK: %[[DIM1:.*]] = tensor.dim
/// CHECK: %[[DIM2:.*]] = tensor.dim
/// CHECK: %[[EMPTY:.*]] = tensor.empty(%[[DIM1:.*]], %[[DIM2:.*]]) : tensor<?x?x2x1x1x2xf32>
/// CHECK: %[[EXTRACTSLICE1:.*]] = tensor.extract_slice %[[ARG1:.*]][0, 0, 0] [%[[DIM1:.*]], %[[DIM2:.*]], 1] [1, 1, 1] : tensor<?x?x2xindex> to tensor<?x?x1xindex>
/// CHECK: %[[EXTRACTSLICE2:.*]] = tensor.extract_slice %[[ARG1:.*]][0, 0, 1] [%[[DIM1:.*]], %[[DIM2:.*]], 1] [1, 1, 1] : tensor<?x?x2xindex> to tensor<?x?x1xindex>
/// CHECK: linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, 0)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, 0)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%[[EXTRACTSLICE1:.*]], %[[EXTRACTSLICE2:.*]] : tensor<?x?x1xindex>, tensor<?x?x1xindex>) outs(%[[EMPTY:.*]] : tensor<?x?x2x1x1x2xf32>)
%1 = tensor.gather %arg0[%arg1] gather_dims([1, 2]) : (tensor<2x2x2x2xf32>, tensor<?x?x2xindex>) -> tensor<?x?x2x1x1x2xf32>
return %1 : tensor<?x?x2x1x1x2xf32>
}

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
%func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func">
transform.apply_patterns to %func_op {
transform.apply_patterns.tensor.decompose_gather
} : !transform.op<"func.func">
transform.yield
}
}
Loading