|
| 1 | +//===- ResolveShapedTypeResultDims.cpp - Resolve memref.dim ops of result values |
| 2 | +//-------===// |
| 3 | +// |
| 4 | +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 5 | +// See https://llvm.org/LICENSE.txt for license information. |
| 6 | +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 7 | +// |
| 8 | +//===----------------------------------------------------------------------===// |
| 9 | +// |
| 10 | +// This pass resolves `memref.dim` operations of result values in terms of |
| 11 | +// shapes of their operands using the `InferShapedTypeOpInterface`. |
| 12 | +// |
| 13 | +//===----------------------------------------------------------------------===// |
| 14 | + |
| 15 | +#include "mlir/Dialect/MemRef/IR/MemRef.h" |
| 16 | +#include "mlir/Dialect/MemRef/Transforms/Passes.h" |
| 17 | +#include "mlir/Dialect/StandardOps/IR/Ops.h" |
| 18 | +#include "mlir/Dialect/Tensor/IR/Tensor.h" |
| 19 | +#include "mlir/Interfaces/InferTypeOpInterface.h" |
| 20 | +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
| 21 | + |
| 22 | +using namespace mlir; |
| 23 | + |
| 24 | +/// Helper method to get the `Value` that is the shape of the `resultIdx`-th |
| 25 | +/// result at dimension `dimIndex` from the `ShapedTypeOpInterface`. |
| 26 | +/// TODO(ravishankarm): This is better put as a interface utility method |
| 27 | +/// somewhere, but that would imply the interface will depend on the `tensor` |
| 28 | +/// dialect. Ideally maybe a utility method in the `tensor` dialect. |
| 29 | +static Value getResultDimFromShapeInterface(OpBuilder &builder, OpResult result, |
| 30 | + int64_t dimIndex) { |
| 31 | + unsigned resultNumber = result.getResultNumber(); |
| 32 | + auto shapedTypeOp = dyn_cast<InferShapedTypeOpInterface>(result.getOwner()); |
| 33 | + Location loc = result.getOwner()->getLoc(); |
| 34 | + if (!shapedTypeOp) |
| 35 | + return nullptr; |
| 36 | + |
| 37 | + // The interface exposes two methods, one that returns the shape of all the |
| 38 | + // results as `Value` and other that returns the shape as a list of |
| 39 | + // `SmallVector<Value>`. The former takes precedence over the latter. So first |
| 40 | + // check if the op implements the first interface method or the second, and |
| 41 | + // get the value to use appropriately. |
| 42 | + SmallVector<Value> reifiedResultShapes; |
| 43 | + if (succeeded(shapedTypeOp.reifyReturnTypeShapes( |
| 44 | + builder, result.getOwner()->getOperands(), reifiedResultShapes))) { |
| 45 | + if (reifiedResultShapes.size() <= resultNumber) |
| 46 | + return nullptr; |
| 47 | + Value resultShape = reifiedResultShapes[resultNumber]; |
| 48 | + auto resultShapeType = resultShape.getType().dyn_cast<RankedTensorType>(); |
| 49 | + if (!resultShapeType || !resultShapeType.getElementType().isa<IndexType>()) |
| 50 | + return nullptr; |
| 51 | + return builder.create<tensor::ExtractOp>( |
| 52 | + loc, resultShape, builder.createOrFold<ConstantIndexOp>(loc, dimIndex)); |
| 53 | + } |
| 54 | + |
| 55 | + SmallVector<SmallVector<Value>> reifiedResultShapesPerDim; |
| 56 | + if (failed(shapedTypeOp.reifyReturnTypeShapesPerResultDim( |
| 57 | + builder, reifiedResultShapesPerDim))) |
| 58 | + return nullptr; |
| 59 | + if (reifiedResultShapesPerDim.size() <= resultNumber || |
| 60 | + reifiedResultShapesPerDim[resultNumber].size() != |
| 61 | + static_cast<size_t>(result.getType().cast<ShapedType>().getRank())) |
| 62 | + return nullptr; |
| 63 | + OpFoldResult valueOrAttr = reifiedResultShapesPerDim[resultNumber][dimIndex]; |
| 64 | + if (auto attr = valueOrAttr.dyn_cast<Attribute>()) |
| 65 | + return builder.createOrFold<ConstantIndexOp>( |
| 66 | + loc, attr.cast<IntegerAttr>().getInt()); |
| 67 | + return valueOrAttr.get<Value>(); |
| 68 | +} |
| 69 | + |
| 70 | +namespace { |
| 71 | +/// Fold dim of an operation that implements the InferShapedTypeOpInterface |
| 72 | +struct DimOfShapedTypeOpInterface : public OpRewritePattern<memref::DimOp> { |
| 73 | + using OpRewritePattern<memref::DimOp>::OpRewritePattern; |
| 74 | + |
| 75 | + LogicalResult matchAndRewrite(memref::DimOp dimOp, |
| 76 | + PatternRewriter &rewriter) const override { |
| 77 | + OpResult dimValue = dimOp.memrefOrTensor().dyn_cast<OpResult>(); |
| 78 | + if (!dimValue) |
| 79 | + return failure(); |
| 80 | + auto shapedTypeOp = |
| 81 | + dyn_cast<InferShapedTypeOpInterface>(dimValue.getOwner()); |
| 82 | + if (!shapedTypeOp) |
| 83 | + return failure(); |
| 84 | + |
| 85 | + Optional<int64_t> dimIndex = dimOp.getConstantIndex(); |
| 86 | + if (!dimIndex) |
| 87 | + return failure(); |
| 88 | + Value replacement = |
| 89 | + getResultDimFromShapeInterface(rewriter, dimValue, *dimIndex); |
| 90 | + if (!replacement) |
| 91 | + return failure(); |
| 92 | + rewriter.replaceOp(dimOp, replacement); |
| 93 | + return success(); |
| 94 | + } |
| 95 | +}; |
| 96 | +} // namespace |
| 97 | + |
| 98 | +//===----------------------------------------------------------------------===// |
| 99 | +// Pass registration |
| 100 | +//===----------------------------------------------------------------------===// |
| 101 | + |
| 102 | +namespace { |
| 103 | +#define GEN_PASS_CLASSES |
| 104 | +#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc" |
| 105 | + |
| 106 | +struct ResolveShapedTypeResultDimsPass final |
| 107 | + : public ResolveShapedTypeResultDimsBase<ResolveShapedTypeResultDimsPass> { |
| 108 | + void runOnOperation() override; |
| 109 | +}; |
| 110 | +} // namespace |
| 111 | + |
| 112 | +void memref::populateResolveShapedTypeResultDimsPatterns( |
| 113 | + RewritePatternSet &patterns) { |
| 114 | + patterns.add<DimOfShapedTypeOpInterface>(patterns.getContext()); |
| 115 | +} |
| 116 | + |
| 117 | +void ResolveShapedTypeResultDimsPass::runOnOperation() { |
| 118 | + RewritePatternSet patterns(&getContext()); |
| 119 | + memref::populateResolveShapedTypeResultDimsPatterns(patterns); |
| 120 | + if (failed(applyPatternsAndFoldGreedily(getOperation()->getRegions(), |
| 121 | + std::move(patterns)))) |
| 122 | + return signalPassFailure(); |
| 123 | +} |
| 124 | + |
| 125 | +std::unique_ptr<Pass> memref::createResolveShapedTypeResultDimsPass() { |
| 126 | + return std::make_unique<ResolveShapedTypeResultDimsPass>(); |
| 127 | +} |
0 commit comments