Skip to content

Commit 7039bd2

Browse files
committed
1 parent a703d15 commit 7039bd2

File tree

2 files changed

+23
-20
lines changed

2 files changed

+23
-20
lines changed

mlir/include/mlir/Dialect/MemRef/IR/MemRefBase.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ def MemRef_Dialect : Dialect {
2121
}];
2222
let dependentDialects = ["arith::ArithDialect"];
2323
let hasConstantMaterializer = 1;
24+
let useFoldAPI = kEmitFoldAdaptorFolder;
2425
}
2526

2627
#endif // MEMREF_BASE

mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp

Lines changed: 22 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -808,7 +808,7 @@ bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
808808
return false;
809809
}
810810

811-
OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
811+
OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
812812
return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
813813
}
814814

@@ -883,7 +883,7 @@ void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results,
883883
results.add<FoldCopyOfCast, FoldSelfCopy>(context);
884884
}
885885

886-
LogicalResult CopyOp::fold(ArrayRef<Attribute> cstOperands,
886+
LogicalResult CopyOp::fold(FoldAdaptor adaptor,
887887
SmallVectorImpl<OpFoldResult> &results) {
888888
/// copy(memrefcast) -> copy
889889
bool folded = false;
@@ -902,7 +902,7 @@ LogicalResult CopyOp::fold(ArrayRef<Attribute> cstOperands,
902902
// DeallocOp
903903
//===----------------------------------------------------------------------===//
904904

905-
LogicalResult DeallocOp::fold(ArrayRef<Attribute> cstOperands,
905+
LogicalResult DeallocOp::fold(FoldAdaptor adaptor,
906906
SmallVectorImpl<OpFoldResult> &results) {
907907
/// dealloc(memrefcast) -> dealloc
908908
return foldMemRefCast(*this);
@@ -1056,9 +1056,9 @@ llvm::SmallBitVector SubViewOp::getDroppedDims() {
10561056
return *unusedDims;
10571057
}
10581058

1059-
OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
1059+
OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
10601060
// All forms of folding require a known index.
1061-
auto index = operands[1].dyn_cast_or_null<IntegerAttr>();
1061+
auto index = adaptor.getIndex().dyn_cast_or_null<IntegerAttr>();
10621062
if (!index)
10631063
return {};
10641064

@@ -1322,7 +1322,7 @@ LogicalResult DmaStartOp::verify() {
13221322
return success();
13231323
}
13241324

1325-
LogicalResult DmaStartOp::fold(ArrayRef<Attribute> cstOperands,
1325+
LogicalResult DmaStartOp::fold(FoldAdaptor adaptor,
13261326
SmallVectorImpl<OpFoldResult> &results) {
13271327
/// dma_start(memrefcast) -> dma_start
13281328
return foldMemRefCast(*this);
@@ -1332,7 +1332,7 @@ LogicalResult DmaStartOp::fold(ArrayRef<Attribute> cstOperands,
13321332
// DmaWaitOp
13331333
// ---------------------------------------------------------------------------
13341334

1335-
LogicalResult DmaWaitOp::fold(ArrayRef<Attribute> cstOperands,
1335+
LogicalResult DmaWaitOp::fold(FoldAdaptor adaptor,
13361336
SmallVectorImpl<OpFoldResult> &results) {
13371337
/// dma_wait(memrefcast) -> dma_wait
13381338
return foldMemRefCast(*this);
@@ -1433,7 +1433,7 @@ static bool replaceConstantUsesOf(OpBuilder &rewriter, Location loc,
14331433
}
14341434

14351435
LogicalResult
1436-
ExtractStridedMetadataOp::fold(ArrayRef<Attribute> cstOperands,
1436+
ExtractStridedMetadataOp::fold(FoldAdaptor adaptor,
14371437
SmallVectorImpl<OpFoldResult> &results) {
14381438
OpBuilder builder(*this);
14391439

@@ -1677,7 +1677,7 @@ LogicalResult LoadOp::verify() {
16771677
return success();
16781678
}
16791679

1680-
OpFoldResult LoadOp::fold(ArrayRef<Attribute> cstOperands) {
1680+
OpFoldResult LoadOp::fold(FoldAdaptor adaptor) {
16811681
/// load(memrefcast) -> load
16821682
if (succeeded(foldMemRefCast(*this)))
16831683
return getResult();
@@ -1747,7 +1747,7 @@ LogicalResult PrefetchOp::verify() {
17471747
return success();
17481748
}
17491749

1750-
LogicalResult PrefetchOp::fold(ArrayRef<Attribute> cstOperands,
1750+
LogicalResult PrefetchOp::fold(FoldAdaptor adaptor,
17511751
SmallVectorImpl<OpFoldResult> &results) {
17521752
// prefetch(memrefcast) -> prefetch
17531753
return foldMemRefCast(*this);
@@ -1757,7 +1757,7 @@ LogicalResult PrefetchOp::fold(ArrayRef<Attribute> cstOperands,
17571757
// RankOp
17581758
//===----------------------------------------------------------------------===//
17591759

1760-
OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) {
1760+
OpFoldResult RankOp::fold(FoldAdaptor adaptor) {
17611761
// Constant fold rank when the rank of the operand is known.
17621762
auto type = getOperand().getType();
17631763
auto shapedType = type.dyn_cast<ShapedType>();
@@ -1881,7 +1881,7 @@ LogicalResult ReinterpretCastOp::verify() {
18811881
return success();
18821882
}
18831883

1884-
OpFoldResult ReinterpretCastOp::fold(ArrayRef<Attribute> /*operands*/) {
1884+
OpFoldResult ReinterpretCastOp::fold(FoldAdaptor /*operands*/) {
18851885
Value src = getSource();
18861886
auto getPrevSrc = [&]() -> Value {
18871887
// reinterpret_cast(reinterpret_cast(x)) -> reinterpret_cast(x).
@@ -2465,12 +2465,14 @@ void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
24652465
CollapseShapeOpMemRefCastFolder>(context);
24662466
}
24672467

2468-
OpFoldResult ExpandShapeOp::fold(ArrayRef<Attribute> operands) {
2469-
return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this, operands);
2468+
OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
2469+
return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this,
2470+
adaptor.getOperands());
24702471
}
24712472

2472-
OpFoldResult CollapseShapeOp::fold(ArrayRef<Attribute> operands) {
2473-
return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this, operands);
2473+
OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
2474+
return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this,
2475+
adaptor.getOperands());
24742476
}
24752477

24762478
//===----------------------------------------------------------------------===//
@@ -2522,7 +2524,7 @@ LogicalResult StoreOp::verify() {
25222524
return success();
25232525
}
25242526

2525-
LogicalResult StoreOp::fold(ArrayRef<Attribute> cstOperands,
2527+
LogicalResult StoreOp::fold(FoldAdaptor adaptor,
25262528
SmallVectorImpl<OpFoldResult> &results) {
25272529
/// store(memrefcast) -> store
25282530
return foldMemRefCast(*this, getValueToStore());
@@ -3101,7 +3103,7 @@ void SubViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
31013103
SubViewOpMemRefCastFolder, TrivialSubViewOpFolder>(context);
31023104
}
31033105

3104-
OpFoldResult SubViewOp::fold(ArrayRef<Attribute> operands) {
3106+
OpFoldResult SubViewOp::fold(FoldAdaptor adaptor) {
31053107
auto resultShapedType = getResult().getType().cast<ShapedType>();
31063108
auto sourceShapedType = getSource().getType().cast<ShapedType>();
31073109

@@ -3217,7 +3219,7 @@ LogicalResult TransposeOp::verify() {
32173219
return success();
32183220
}
32193221

3220-
OpFoldResult TransposeOp::fold(ArrayRef<Attribute>) {
3222+
OpFoldResult TransposeOp::fold(FoldAdaptor) {
32213223
if (succeeded(foldMemRefCast(*this)))
32223224
return getResult();
32233225
return {};
@@ -3393,7 +3395,7 @@ LogicalResult AtomicRMWOp::verify() {
33933395
return success();
33943396
}
33953397

3396-
OpFoldResult AtomicRMWOp::fold(ArrayRef<Attribute> operands) {
3398+
OpFoldResult AtomicRMWOp::fold(FoldAdaptor adaptor) {
33973399
/// atomicrmw(memrefcast) -> atomicrmw
33983400
if (succeeded(foldMemRefCast(*this, getValue())))
33993401
return getResult();

0 commit comments

Comments
 (0)