Skip to content

Commit 3ed3e43

Browse files
author
MaheshRavishankar
committed
[mlir] Move memref.dim canonicalization using InferShapedTypeOpInterface to a separate pass.
Based on dicussion in [this](https://llvm.discourse.group/t/remove-canonicalizer-for-memref-dim-via-shapedtypeopinterface/3641) thread the pattern to resolve the `memref.dim` of a value that is a result of an operation that implements the `InferShapedTypeOpInterface` is moved to a separate pass instead of running it as a canonicalization pass. This allows shape resolution to happen when explicitly required, instead of automatically through a canonicalization. Differential Revision: https://reviews.llvm.org/D104321
1 parent 838490d commit 3ed3e43

File tree

15 files changed

+635
-407
lines changed

15 files changed

+635
-407
lines changed

mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,15 @@
1616
#include "mlir/Pass/Pass.h"
1717

1818
namespace mlir {
19+
20+
class AffineDialect;
21+
namespace tensor {
22+
class TensorDialect;
23+
} // namespace tensor
24+
namespace vector {
25+
class VectorDialect;
26+
} // namespace vector
27+
1928
namespace memref {
2029

2130
//===----------------------------------------------------------------------===//
@@ -26,6 +35,11 @@ namespace memref {
2635
/// into `patterns`.
2736
void populateFoldSubViewOpPatterns(RewritePatternSet &patterns);
2837

38+
/// Appends patterns that resolve `memref.dim` operations with values that are
39+
/// defined by operations that implement the `InferShapedTypeOpInterface`, in
40+
/// terms of shapes of its input operands.
41+
void populateResolveShapedTypeResultDimsPatterns(RewritePatternSet &patterns);
42+
2943
//===----------------------------------------------------------------------===//
3044
// Passes
3145
//===----------------------------------------------------------------------===//
@@ -34,6 +48,11 @@ void populateFoldSubViewOpPatterns(RewritePatternSet &patterns);
3448
/// load/store ops into `patterns`.
3549
std::unique_ptr<Pass> createFoldSubViewOpsPass();
3650

51+
/// Creates an operation pass to resolve `memref.dim` operations with values
52+
/// that are defined by operations that implement the
53+
/// `InferShapedTypeOpInterface`, in terms of shapes of its input operands.
54+
std::unique_ptr<Pass> createResolveShapedTypeResultDimsPass();
55+
3756
//===----------------------------------------------------------------------===//
3857
// Registration
3958
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,18 @@ def FoldSubViewOps : Pass<"fold-memref-subview-ops"> {
2323
];
2424
}
2525

26+
def ResolveShapedTypeResultDims : Pass<"resolve-shaped-type-result-dims"> {
27+
let summary = "Resolve memref.dim of result values";
28+
let description = [{
29+
The pass resolves memref.dim of result of operations that
30+
implement the `InferShapedTypeOpInterface` in terms of shapes of
31+
its operands.
32+
}];
33+
let constructor = "mlir::memref::createResolveShapedTypeResultDimsPass()";
34+
let dependentDialects = [
35+
"memref::MemRefDialect", "tensor::TensorDialect"
36+
];
37+
}
2638

2739
#endif // MLIR_DIALECT_MEMREF_TRANSFORMS_PASSES
2840

mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp

Lines changed: 1 addition & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -794,84 +794,12 @@ struct DimOfCastOp : public OpRewritePattern<DimOp> {
794794
return success();
795795
}
796796
};
797-
798-
/// Helper method to get the `Value` that is the shape of the `resultIdx`-th
799-
/// result at dimension `dimIndex` from the `ShapedTypeOpInterface`.
800-
/// TODO(ravishankarm): This is better put as a interface utility method
801-
/// somewhere, but that would imply the interface will depend on the `tensor`
802-
/// dialect. Ideally maybe a utility method in the `tensor` dialect.
803-
static Value getResultDimFromShapeInterface(OpBuilder &builder, OpResult result,
804-
int64_t dimIndex) {
805-
unsigned resultNumber = result.getResultNumber();
806-
auto shapedTypeOp = dyn_cast<InferShapedTypeOpInterface>(result.getOwner());
807-
Location loc = result.getOwner()->getLoc();
808-
if (!shapedTypeOp)
809-
return nullptr;
810-
811-
// The interface exposes two methods, one that returns the shape of all the
812-
// results as `Value` and other that returns the shape as a list of
813-
// `SmallVector<Value>`. The former takes precedence over the latter. So first
814-
// check if the op implements the first interface method or the second, and
815-
// get the value to use appropriately.
816-
SmallVector<Value> reifiedResultShapes;
817-
if (succeeded(shapedTypeOp.reifyReturnTypeShapes(
818-
builder, result.getOwner()->getOperands(), reifiedResultShapes))) {
819-
if (reifiedResultShapes.size() <= resultNumber)
820-
return nullptr;
821-
Value resultShape = reifiedResultShapes[resultNumber];
822-
auto resultShapeType = resultShape.getType().dyn_cast<RankedTensorType>();
823-
if (!resultShapeType || !resultShapeType.getElementType().isa<IndexType>())
824-
return nullptr;
825-
return builder.create<tensor::ExtractOp>(
826-
loc, resultShape, builder.createOrFold<ConstantIndexOp>(loc, dimIndex));
827-
}
828-
829-
SmallVector<SmallVector<Value>> reifiedResultShapesPerDim;
830-
if (failed(shapedTypeOp.reifyReturnTypeShapesPerResultDim(
831-
builder, reifiedResultShapesPerDim)))
832-
return nullptr;
833-
if (reifiedResultShapesPerDim.size() <= resultNumber ||
834-
reifiedResultShapesPerDim[resultNumber].size() !=
835-
static_cast<size_t>(result.getType().cast<ShapedType>().getRank()))
836-
return nullptr;
837-
OpFoldResult valueOrAttr = reifiedResultShapesPerDim[resultNumber][dimIndex];
838-
if (auto attr = valueOrAttr.dyn_cast<Attribute>())
839-
return builder.createOrFold<ConstantIndexOp>(
840-
loc, attr.cast<IntegerAttr>().getInt());
841-
return valueOrAttr.get<Value>();
842-
}
843-
844-
/// Fold dim of an operation that implements the InferShapedTypeOpInterface
845-
struct DimOfShapedTypeOpInterface : public OpRewritePattern<DimOp> {
846-
using OpRewritePattern<DimOp>::OpRewritePattern;
847-
848-
LogicalResult matchAndRewrite(DimOp dimOp,
849-
PatternRewriter &rewriter) const override {
850-
OpResult dimValue = dimOp.memrefOrTensor().dyn_cast<OpResult>();
851-
if (!dimValue)
852-
return failure();
853-
auto shapedTypeOp =
854-
dyn_cast<InferShapedTypeOpInterface>(dimValue.getOwner());
855-
if (!shapedTypeOp)
856-
return failure();
857-
858-
Optional<int64_t> dimIndex = dimOp.getConstantIndex();
859-
if (!dimIndex)
860-
return failure();
861-
Value replacement =
862-
getResultDimFromShapeInterface(rewriter, dimValue, *dimIndex);
863-
if (!replacement)
864-
return failure();
865-
rewriter.replaceOp(dimOp, replacement);
866-
return success();
867-
}
868-
};
869797
} // end anonymous namespace.
870798

871799
void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
872800
MLIRContext *context) {
873801
results.add<DimOfMemRefReshape, DimOfCastOp<BufferCastOp>,
874-
DimOfCastOp<tensor::CastOp>, DimOfShapedTypeOpInterface>(context);
802+
DimOfCastOp<tensor::CastOp>>(context);
875803
}
876804

877805
// ---------------------------------------------------------------------------

mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
add_mlir_dialect_library(MLIRMemRefTransforms
22
FoldSubViewOps.cpp
3+
ResolveShapedTypeResultDims.cpp
34

45
ADDITIONAL_HEADER_DIRS
56
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/MemRef
@@ -9,9 +10,11 @@ add_mlir_dialect_library(MLIRMemRefTransforms
910

1011
LINK_LIBS PUBLIC
1112
MLIRAffine
13+
MLIRInferTypeOpInterface
1214
MLIRMemRef
1315
MLIRPass
1416
MLIRStandard
17+
MLIRTensor
1518
MLIRTransforms
1619
MLIRVector
1720
)
Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
//===- ResolveShapedTypeResultDims.cpp - Resolve memref.dim ops of result values
2+
//-------===//
3+
//
4+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5+
// See https://llvm.org/LICENSE.txt for license information.
6+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7+
//
8+
//===----------------------------------------------------------------------===//
9+
//
10+
// This pass resolves `memref.dim` operations of result values in terms of
11+
// shapes of their operands using the `InferShapedTypeOpInterface`.
12+
//
13+
//===----------------------------------------------------------------------===//
14+
15+
#include "mlir/Dialect/MemRef/IR/MemRef.h"
16+
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
17+
#include "mlir/Dialect/StandardOps/IR/Ops.h"
18+
#include "mlir/Dialect/Tensor/IR/Tensor.h"
19+
#include "mlir/Interfaces/InferTypeOpInterface.h"
20+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
21+
22+
using namespace mlir;
23+
24+
/// Helper method to get the `Value` that is the shape of the `resultIdx`-th
25+
/// result at dimension `dimIndex` from the `ShapedTypeOpInterface`.
26+
/// TODO(ravishankarm): This is better put as a interface utility method
27+
/// somewhere, but that would imply the interface will depend on the `tensor`
28+
/// dialect. Ideally maybe a utility method in the `tensor` dialect.
29+
static Value getResultDimFromShapeInterface(OpBuilder &builder, OpResult result,
30+
int64_t dimIndex) {
31+
unsigned resultNumber = result.getResultNumber();
32+
auto shapedTypeOp = dyn_cast<InferShapedTypeOpInterface>(result.getOwner());
33+
Location loc = result.getOwner()->getLoc();
34+
if (!shapedTypeOp)
35+
return nullptr;
36+
37+
// The interface exposes two methods, one that returns the shape of all the
38+
// results as `Value` and other that returns the shape as a list of
39+
// `SmallVector<Value>`. The former takes precedence over the latter. So first
40+
// check if the op implements the first interface method or the second, and
41+
// get the value to use appropriately.
42+
SmallVector<Value> reifiedResultShapes;
43+
if (succeeded(shapedTypeOp.reifyReturnTypeShapes(
44+
builder, result.getOwner()->getOperands(), reifiedResultShapes))) {
45+
if (reifiedResultShapes.size() <= resultNumber)
46+
return nullptr;
47+
Value resultShape = reifiedResultShapes[resultNumber];
48+
auto resultShapeType = resultShape.getType().dyn_cast<RankedTensorType>();
49+
if (!resultShapeType || !resultShapeType.getElementType().isa<IndexType>())
50+
return nullptr;
51+
return builder.create<tensor::ExtractOp>(
52+
loc, resultShape, builder.createOrFold<ConstantIndexOp>(loc, dimIndex));
53+
}
54+
55+
SmallVector<SmallVector<Value>> reifiedResultShapesPerDim;
56+
if (failed(shapedTypeOp.reifyReturnTypeShapesPerResultDim(
57+
builder, reifiedResultShapesPerDim)))
58+
return nullptr;
59+
if (reifiedResultShapesPerDim.size() <= resultNumber ||
60+
reifiedResultShapesPerDim[resultNumber].size() !=
61+
static_cast<size_t>(result.getType().cast<ShapedType>().getRank()))
62+
return nullptr;
63+
OpFoldResult valueOrAttr = reifiedResultShapesPerDim[resultNumber][dimIndex];
64+
if (auto attr = valueOrAttr.dyn_cast<Attribute>())
65+
return builder.createOrFold<ConstantIndexOp>(
66+
loc, attr.cast<IntegerAttr>().getInt());
67+
return valueOrAttr.get<Value>();
68+
}
69+
70+
namespace {
71+
/// Fold dim of an operation that implements the InferShapedTypeOpInterface
72+
struct DimOfShapedTypeOpInterface : public OpRewritePattern<memref::DimOp> {
73+
using OpRewritePattern<memref::DimOp>::OpRewritePattern;
74+
75+
LogicalResult matchAndRewrite(memref::DimOp dimOp,
76+
PatternRewriter &rewriter) const override {
77+
OpResult dimValue = dimOp.memrefOrTensor().dyn_cast<OpResult>();
78+
if (!dimValue)
79+
return failure();
80+
auto shapedTypeOp =
81+
dyn_cast<InferShapedTypeOpInterface>(dimValue.getOwner());
82+
if (!shapedTypeOp)
83+
return failure();
84+
85+
Optional<int64_t> dimIndex = dimOp.getConstantIndex();
86+
if (!dimIndex)
87+
return failure();
88+
Value replacement =
89+
getResultDimFromShapeInterface(rewriter, dimValue, *dimIndex);
90+
if (!replacement)
91+
return failure();
92+
rewriter.replaceOp(dimOp, replacement);
93+
return success();
94+
}
95+
};
96+
} // namespace
97+
98+
//===----------------------------------------------------------------------===//
99+
// Pass registration
100+
//===----------------------------------------------------------------------===//
101+
102+
namespace {
103+
#define GEN_PASS_CLASSES
104+
#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
105+
106+
struct ResolveShapedTypeResultDimsPass final
107+
: public ResolveShapedTypeResultDimsBase<ResolveShapedTypeResultDimsPass> {
108+
void runOnOperation() override;
109+
};
110+
} // namespace
111+
112+
void memref::populateResolveShapedTypeResultDimsPatterns(
113+
RewritePatternSet &patterns) {
114+
patterns.add<DimOfShapedTypeOpInterface>(patterns.getContext());
115+
}
116+
117+
void ResolveShapedTypeResultDimsPass::runOnOperation() {
118+
RewritePatternSet patterns(&getContext());
119+
memref::populateResolveShapedTypeResultDimsPatterns(patterns);
120+
if (failed(applyPatternsAndFoldGreedily(getOperation()->getRegions(),
121+
std::move(patterns))))
122+
return signalPassFailure();
123+
}
124+
125+
std::unique_ptr<Pass> memref::createResolveShapedTypeResultDimsPass() {
126+
return std::make_unique<ResolveShapedTypeResultDimsPass>();
127+
}

0 commit comments

Comments
 (0)