Skip to content

Commit 7df7612

Browse files
committed
[mlir][NFC] Migrate rest of the dialects to the new fold API
1 parent 7039bd2 commit 7df7612

File tree

31 files changed

+72
-80
lines changed

31 files changed

+72
-80
lines changed

mlir/include/mlir/Dialect/Affine/IR/AffineOps.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ def Affine_Dialect : Dialect {
2424
let cppNamespace = "mlir";
2525
let hasConstantMaterializer = 1;
2626
let dependentDialects = ["arith::ArithDialect"];
27+
let useFoldAPI = kEmitFoldAdaptorFolder;
2728
}
2829

2930
// Base class for Affine dialect ops.

mlir/include/mlir/Dialect/Bufferization/IR/BufferizationBase.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ def Bufferization_Dialect : Dialect {
6969
kEscapeAttrName = "bufferization.escape";
7070
}];
7171
let hasOperationAttrVerify = 1;
72+
let useFoldAPI = kEmitFoldAdaptorFolder;
7273
}
7374

7475
#endif // BUFFERIZATION_BASE

mlir/include/mlir/Dialect/Complex/IR/ComplexBase.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ def Complex_Dialect : Dialect {
2222
let dependentDialects = ["arith::ArithDialect"];
2323
let hasConstantMaterializer = 1;
2424
let useDefaultAttributePrinterParser = 1;
25+
let useFoldAPI = kEmitFoldAdaptorFolder;
2526
}
2627

2728
#endif // COMPLEX_BASE

mlir/include/mlir/Dialect/EmitC/IR/EmitCBase.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ def EmitC_Dialect : Dialect {
3131
let hasConstantMaterializer = 1;
3232
let useDefaultTypePrinterParser = 1;
3333
let useDefaultAttributePrinterParser = 1;
34+
let useFoldAPI = kEmitFoldAdaptorFolder;
3435
}
3536

3637
#endif // MLIR_DIALECT_EMITC_IR_EMITCBASE

mlir/include/mlir/Dialect/Func/IR/FuncOps.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ def Func_Dialect : Dialect {
2323
let cppNamespace = "::mlir::func";
2424
let dependentDialects = ["cf::ControlFlowDialect"];
2525
let hasConstantMaterializer = 1;
26+
let useFoldAPI = kEmitFoldAdaptorFolder;
2627
}
2728

2829
// Base class for Func dialect ops.

mlir/include/mlir/Dialect/GPU/IR/GPUBase.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ def GPU_Dialect : Dialect {
5656
let dependentDialects = ["arith::ArithDialect"];
5757
let useDefaultAttributePrinterParser = 1;
5858
let useDefaultTypePrinterParser = 1;
59+
let useFoldAPI = kEmitFoldAdaptorFolder;
5960
}
6061

6162
def GPU_AsyncToken : DialectType<

mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ def LLVM_Dialect : Dialect {
3131
let hasRegionArgAttrVerify = 1;
3232
let hasRegionResultAttrVerify = 1;
3333
let hasOperationAttrVerify = 1;
34+
let useFoldAPI = kEmitFoldAdaptorFolder;
3435

3536
let extraClassDeclaration = [{
3637
/// Name of the data layout attributes.

mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ def Linalg_Dialect : Dialect {
4646
let hasCanonicalizer = 1;
4747
let hasOperationAttrVerify = 1;
4848
let hasConstantMaterializer = 1;
49+
let useFoldAPI = kEmitFoldAdaptorFolder;
4950
let extraClassDeclaration = [{
5051
/// Attribute name used to to memoize indexing maps for named ops.
5152
constexpr const static ::llvm::StringLiteral

mlir/include/mlir/Dialect/Quant/QuantOpsBase.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ def Quantization_Dialect : Dialect {
2020
let cppNamespace = "::mlir::quant";
2121

2222
let useDefaultTypePrinterParser = 1;
23+
let useFoldAPI = kEmitFoldAdaptorFolder;
2324
}
2425

2526
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/SCF/IR/SCFOps.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ def SCF_Dialect : Dialect {
2525
let name = "scf";
2626
let cppNamespace = "::mlir::scf";
2727
let dependentDialects = ["arith::ArithDialect"];
28+
let useFoldAPI = kEmitFoldAdaptorFolder;
2829
}
2930

3031
// Base class for SCF dialect ops.

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorBase.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ def SparseTensor_Dialect : Dialect {
8383

8484
let useDefaultAttributePrinterParser = 1;
8585
let useDefaultTypePrinterParser = 1;
86+
let useFoldAPI = kEmitFoldAdaptorFolder;
8687
}
8788

8889
#endif // SPARSETENSOR_BASE

mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ def Transform_Dialect : Dialect {
2323
"::mlir::pdl_interp::PDLInterpDialect",
2424
];
2525

26+
let useFoldAPI = kEmitFoldAdaptorFolder;
27+
2628
let extraClassDeclaration = [{
2729
/// Returns the named PDL constraint functions available in the dialect
2830
/// as a map from their name to the function.

mlir/include/mlir/IR/BuiltinDialect.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ def Builtin_Dialect : Dialect {
3434

3535
public:
3636
}];
37+
38+
let useFoldAPI = kEmitFoldAdaptorFolder;
3739
}
3840

3941
#endif // BUILTIN_BASE

mlir/lib/Dialect/Affine/IR/AffineOps.cpp

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -562,7 +562,7 @@ bool AffineApplyOp::isValidSymbol(Region *region) {
562562
});
563563
}
564564

565-
OpFoldResult AffineApplyOp::fold(ArrayRef<Attribute> operands) {
565+
OpFoldResult AffineApplyOp::fold(FoldAdaptor adaptor) {
566566
auto map = getAffineMap();
567567

568568
// Fold dims and symbols to existing values.
@@ -574,7 +574,7 @@ OpFoldResult AffineApplyOp::fold(ArrayRef<Attribute> operands) {
574574

575575
// Otherwise, default to folding the map.
576576
SmallVector<Attribute, 1> result;
577-
if (failed(map.constantFold(operands, result)))
577+
if (failed(map.constantFold(adaptor.getMapOperands(), result)))
578578
return {};
579579
return result[0];
580580
}
@@ -2135,7 +2135,7 @@ static bool hasTrivialZeroTripCount(AffineForOp op) {
21352135
return tripCount && *tripCount == 0;
21362136
}
21372137

2138-
LogicalResult AffineForOp::fold(ArrayRef<Attribute> operands,
2138+
LogicalResult AffineForOp::fold(FoldAdaptor adaptor,
21392139
SmallVectorImpl<OpFoldResult> &results) {
21402140
bool folded = succeeded(foldLoopBounds(*this));
21412141
folded |= succeeded(canonicalizeLoopBounds(*this));
@@ -2723,7 +2723,7 @@ static void composeSetAndOperands(IntegerSet &set,
27232723
}
27242724

27252725
/// Canonicalize an affine if op's conditional (integer set + operands).
2726-
LogicalResult AffineIfOp::fold(ArrayRef<Attribute>,
2726+
LogicalResult AffineIfOp::fold(FoldAdaptor,
27272727
SmallVectorImpl<OpFoldResult> &) {
27282728
auto set = getIntegerSet();
27292729
SmallVector<Value, 4> operands(getOperands());
@@ -2858,7 +2858,7 @@ void AffineLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
28582858
results.add<SimplifyAffineOp<AffineLoadOp>>(context);
28592859
}
28602860

2861-
OpFoldResult AffineLoadOp::fold(ArrayRef<Attribute> cstOperands) {
2861+
OpFoldResult AffineLoadOp::fold(FoldAdaptor adaptor) {
28622862
/// load(memrefcast) -> load
28632863
if (succeeded(memref::foldMemRefCast(*this)))
28642864
return getResult();
@@ -2975,7 +2975,7 @@ void AffineStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
29752975
results.add<SimplifyAffineOp<AffineStoreOp>>(context);
29762976
}
29772977

2978-
LogicalResult AffineStoreOp::fold(ArrayRef<Attribute> cstOperands,
2978+
LogicalResult AffineStoreOp::fold(FoldAdaptor adaptor,
29792979
SmallVectorImpl<OpFoldResult> &results) {
29802980
/// store(memrefcast) -> store
29812981
return memref::foldMemRefCast(*this, getValueToStore());
@@ -3282,8 +3282,8 @@ struct CanonicalizeSingleResultAffineMinMaxOp : public OpRewritePattern<T> {
32823282
// %0 = affine.min (d0) -> (1000, d0 + 512) (%i0)
32833283
//
32843284

3285-
OpFoldResult AffineMinOp::fold(ArrayRef<Attribute> operands) {
3286-
return foldMinMaxOp(*this, operands);
3285+
OpFoldResult AffineMinOp::fold(FoldAdaptor adaptor) {
3286+
return foldMinMaxOp(*this, adaptor.getOperands());
32873287
}
32883288

32893289
void AffineMinOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
@@ -3310,8 +3310,8 @@ void AffineMinOp::print(OpAsmPrinter &p) { printAffineMinMaxOp(p, *this); }
33103310
// %0 = affine.max (d0) -> (1000, d0 + 512) (%i0)
33113311
//
33123312

3313-
OpFoldResult AffineMaxOp::fold(ArrayRef<Attribute> operands) {
3314-
return foldMinMaxOp(*this, operands);
3313+
OpFoldResult AffineMaxOp::fold(FoldAdaptor adaptor) {
3314+
return foldMinMaxOp(*this, adaptor.getOperands());
33153315
}
33163316

33173317
void AffineMaxOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
@@ -3431,7 +3431,7 @@ void AffinePrefetchOp::getCanonicalizationPatterns(RewritePatternSet &results,
34313431
results.add<SimplifyAffineOp<AffinePrefetchOp>>(context);
34323432
}
34333433

3434-
LogicalResult AffinePrefetchOp::fold(ArrayRef<Attribute> cstOperands,
3434+
LogicalResult AffinePrefetchOp::fold(FoldAdaptor adaptor,
34353435
SmallVectorImpl<OpFoldResult> &results) {
34363436
/// prefetch(memrefcast) -> prefetch
34373437
return memref::foldMemRefCast(*this);
@@ -3705,7 +3705,7 @@ static LogicalResult canonicalizeLoopBounds(AffineParallelOp op) {
37053705
return success();
37063706
}
37073707

3708-
LogicalResult AffineParallelOp::fold(ArrayRef<Attribute> operands,
3708+
LogicalResult AffineParallelOp::fold(FoldAdaptor adaptor,
37093709
SmallVectorImpl<OpFoldResult> &results) {
37103710
return canonicalizeLoopBounds(*this);
37113711
}

mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -458,7 +458,7 @@ void CloneOp::getEffects(
458458
SideEffects::DefaultResource::get());
459459
}
460460

461-
OpFoldResult CloneOp::fold(ArrayRef<Attribute> operands) {
461+
OpFoldResult CloneOp::fold(FoldAdaptor adaptor) {
462462
return succeeded(memref::foldMemRefCast(*this)) ? getResult() : Value();
463463
}
464464

@@ -560,7 +560,7 @@ LogicalResult DeallocTensorOp::bufferize(RewriterBase &rewriter,
560560
// ToTensorOp
561561
//===----------------------------------------------------------------------===//
562562

563-
OpFoldResult ToTensorOp::fold(ArrayRef<Attribute>) {
563+
OpFoldResult ToTensorOp::fold(FoldAdaptor) {
564564
if (auto toMemref = getMemref().getDefiningOp<ToMemrefOp>())
565565
// Approximate alias analysis by conservatively folding only when no there
566566
// is no interleaved operation.
@@ -596,7 +596,7 @@ void ToTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
596596
// ToMemrefOp
597597
//===----------------------------------------------------------------------===//
598598

599-
OpFoldResult ToMemrefOp::fold(ArrayRef<Attribute>) {
599+
OpFoldResult ToMemrefOp::fold(FoldAdaptor) {
600600
if (auto memrefToTensor = getTensor().getDefiningOp<ToTensorOp>())
601601
if (memrefToTensor.getMemref().getType() == getType())
602602
return memrefToTensor.getMemref();

mlir/lib/Dialect/Complex/IR/ComplexOps.cpp

Lines changed: 12 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,7 @@ using namespace mlir::complex;
1717
// ConstantOp
1818
//===----------------------------------------------------------------------===//
1919

20-
OpFoldResult ConstantOp::fold(ArrayRef<Attribute> operands) {
21-
assert(operands.empty() && "constant has no operands");
20+
OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) {
2221
return getValue();
2322
}
2423

@@ -68,8 +67,7 @@ LogicalResult ConstantOp::verify() {
6867
// CreateOp
6968
//===----------------------------------------------------------------------===//
7069

71-
OpFoldResult CreateOp::fold(ArrayRef<Attribute> operands) {
72-
assert(operands.size() == 2 && "binary op takes two operands");
70+
OpFoldResult CreateOp::fold(FoldAdaptor adaptor) {
7371
// Fold complex.create(complex.re(op), complex.im(op)).
7472
if (auto reOp = getOperand(0).getDefiningOp<ReOp>()) {
7573
if (auto imOp = getOperand(1).getDefiningOp<ImOp>()) {
@@ -85,9 +83,8 @@ OpFoldResult CreateOp::fold(ArrayRef<Attribute> operands) {
8583
// ImOp
8684
//===----------------------------------------------------------------------===//
8785

88-
OpFoldResult ImOp::fold(ArrayRef<Attribute> operands) {
89-
assert(operands.size() == 1 && "unary op takes 1 operand");
90-
ArrayAttr arrayAttr = operands[0].dyn_cast_or_null<ArrayAttr>();
86+
OpFoldResult ImOp::fold(FoldAdaptor adaptor) {
87+
ArrayAttr arrayAttr = adaptor.getComplex().dyn_cast_or_null<ArrayAttr>();
9188
if (arrayAttr && arrayAttr.size() == 2)
9289
return arrayAttr[1];
9390
if (auto createOp = getOperand().getDefiningOp<CreateOp>())
@@ -99,9 +96,8 @@ OpFoldResult ImOp::fold(ArrayRef<Attribute> operands) {
9996
// ReOp
10097
//===----------------------------------------------------------------------===//
10198

102-
OpFoldResult ReOp::fold(ArrayRef<Attribute> operands) {
103-
assert(operands.size() == 1 && "unary op takes 1 operand");
104-
ArrayAttr arrayAttr = operands[0].dyn_cast_or_null<ArrayAttr>();
99+
OpFoldResult ReOp::fold(FoldAdaptor adaptor) {
100+
ArrayAttr arrayAttr = adaptor.getComplex().dyn_cast_or_null<ArrayAttr>();
105101
if (arrayAttr && arrayAttr.size() == 2)
106102
return arrayAttr[0];
107103
if (auto createOp = getOperand().getDefiningOp<CreateOp>())
@@ -113,9 +109,7 @@ OpFoldResult ReOp::fold(ArrayRef<Attribute> operands) {
113109
// AddOp
114110
//===----------------------------------------------------------------------===//
115111

116-
OpFoldResult AddOp::fold(ArrayRef<Attribute> operands) {
117-
assert(operands.size() == 2 && "binary op takes 2 operands");
118-
112+
OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
119113
// complex.add(complex.sub(a, b), b) -> a
120114
if (auto sub = getLhs().getDefiningOp<SubOp>())
121115
if (getRhs() == sub.getRhs())
@@ -142,9 +136,7 @@ OpFoldResult AddOp::fold(ArrayRef<Attribute> operands) {
142136
// SubOp
143137
//===----------------------------------------------------------------------===//
144138

145-
OpFoldResult SubOp::fold(ArrayRef<Attribute> operands) {
146-
assert(operands.size() == 2 && "binary op takes 2 operands");
147-
139+
OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
148140
// complex.sub(complex.add(a, b), b) -> a
149141
if (auto add = getLhs().getDefiningOp<AddOp>())
150142
if (getRhs() == add.getRhs())
@@ -166,9 +158,7 @@ OpFoldResult SubOp::fold(ArrayRef<Attribute> operands) {
166158
// NegOp
167159
//===----------------------------------------------------------------------===//
168160

169-
OpFoldResult NegOp::fold(ArrayRef<Attribute> operands) {
170-
assert(operands.size() == 1 && "unary op takes 1 operand");
171-
161+
OpFoldResult NegOp::fold(FoldAdaptor adaptor) {
172162
// complex.neg(complex.neg(a)) -> a
173163
if (auto negOp = getOperand().getDefiningOp<NegOp>())
174164
return negOp.getOperand();
@@ -180,9 +170,7 @@ OpFoldResult NegOp::fold(ArrayRef<Attribute> operands) {
180170
// LogOp
181171
//===----------------------------------------------------------------------===//
182172

183-
OpFoldResult LogOp::fold(ArrayRef<Attribute> operands) {
184-
assert(operands.size() == 1 && "unary op takes 1 operand");
185-
173+
OpFoldResult LogOp::fold(FoldAdaptor adaptor) {
186174
// complex.log(complex.exp(a)) -> a
187175
if (auto expOp = getOperand().getDefiningOp<ExpOp>())
188176
return expOp.getOperand();
@@ -194,9 +182,7 @@ OpFoldResult LogOp::fold(ArrayRef<Attribute> operands) {
194182
// ExpOp
195183
//===----------------------------------------------------------------------===//
196184

197-
OpFoldResult ExpOp::fold(ArrayRef<Attribute> operands) {
198-
assert(operands.size() == 1 && "unary op takes 1 operand");
199-
185+
OpFoldResult ExpOp::fold(FoldAdaptor adaptor) {
200186
// complex.exp(complex.log(a)) -> a
201187
if (auto logOp = getOperand().getDefiningOp<LogOp>())
202188
return logOp.getOperand();
@@ -208,9 +194,7 @@ OpFoldResult ExpOp::fold(ArrayRef<Attribute> operands) {
208194
// ConjOp
209195
//===----------------------------------------------------------------------===//
210196

211-
OpFoldResult ConjOp::fold(ArrayRef<Attribute> operands) {
212-
assert(operands.size() == 1 && "unary op takes 1 operand");
213-
197+
OpFoldResult ConjOp::fold(FoldAdaptor adaptor) {
214198
// complex.conj(complex.conj(a)) -> a
215199
if (auto conjOp = getOperand().getDefiningOp<ConjOp>())
216200
return conjOp.getOperand();

mlir/lib/Dialect/EmitC/IR/EmitC.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -129,8 +129,7 @@ LogicalResult emitc::ConstantOp::verify() {
129129
return success();
130130
}
131131

132-
OpFoldResult emitc::ConstantOp::fold(ArrayRef<Attribute> operands) {
133-
assert(operands.empty() && "constant has no operands");
132+
OpFoldResult emitc::ConstantOp::fold(FoldAdaptor adaptor) {
134133
return getValue();
135134
}
136135

mlir/lib/Dialect/Func/IR/FuncOps.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -201,8 +201,7 @@ LogicalResult ConstantOp::verify() {
201201
return success();
202202
}
203203

204-
OpFoldResult ConstantOp::fold(ArrayRef<Attribute> operands) {
205-
assert(operands.empty() && "constant has no operands");
204+
OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) {
206205
return getValueAttr();
207206
}
208207

mlir/lib/Dialect/GPU/IR/GPUDialect.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1286,12 +1286,12 @@ LogicalResult SubgroupMmaComputeOp::verify() {
12861286
return success();
12871287
}
12881288

1289-
LogicalResult MemcpyOp::fold(ArrayRef<Attribute> operands,
1289+
LogicalResult MemcpyOp::fold(FoldAdaptor adaptor,
12901290
SmallVectorImpl<::mlir::OpFoldResult> &results) {
12911291
return memref::foldMemRefCast(*this);
12921292
}
12931293

1294-
LogicalResult MemsetOp::fold(ArrayRef<Attribute> operands,
1294+
LogicalResult MemsetOp::fold(FoldAdaptor adaptor,
12951295
SmallVectorImpl<::mlir::OpFoldResult> &results) {
12961296
return memref::foldMemRefCast(*this);
12971297
}

0 commit comments

Comments
 (0)