Skip to content

Commit a6151f4

Browse files
[mlir][IR] Move match and rewrite functions into separate class (llvm#129861)
The vast majority of rewrite / conversion patterns uses a combined `matchAndRewrite` instead of separate `match` and `rewrite` functions. This PR optimizes the code base for the most common case where users implement a combined `matchAndRewrite`. There are no longer any `match` and `rewrite` functions in `RewritePattern`, `ConversionPattern` and their derived classes. Instead, there is a `SplitMatchAndRewriteImpl` class that implements `matchAndRewrite` in terms of `match` and `rewrite`. Details: * The `RewritePattern` and `ConversionPattern` classes are simpler (fewer functions). Especially the `ConversionPattern` class, which now has 5 fewer functions. (There were various `rewrite` overloads to account for 1:1 / 1:N patterns.) * There is a new class `SplitMatchAndRewriteImpl` that derives from `RewritePattern` / `OpRewritePatern` / ..., along with a type alias `RewritePattern::SplitMatchAndRewrite` for convenience. * Fewer `llvm_unreachable` are needed throughout the code base. Instead, we can use pure virtual functions. (In cases where users previously had to implement `rewrite` or `matchAndRewrite`, etc.) * This PR may also improve the number of [`-Woverload-virtual` warnings](https://discourse.llvm.org/t/matchandrewrite-hiding-virtual-functions/84933) that are produced by GCC. (To be confirmed...) Note for LLVM integration: Patterns with separate `match` / `rewrite` implementations, must derive from `X::SplitMatchAndRewrite` instead of `X`. --------- Co-authored-by: River Riddle <riddleriver@gmail.com>
1 parent 87976ca commit a6151f4

File tree

12 files changed

+175
-226
lines changed

12 files changed

+175
-226
lines changed

flang/include/flang/Optimizer/CodeGen/FIROpPatterns.h

Lines changed: 2 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,6 @@ class ConvertFIRToLLVMPattern : public mlir::ConvertToLLVMPattern {
187187

188188
const fir::FIRToLLVMPassOptions &options;
189189

190-
using ConvertToLLVMPattern::match;
191190
using ConvertToLLVMPattern::matchAndRewrite;
192191
};
193192

@@ -206,20 +205,6 @@ class FIROpConversion : public ConvertFIRToLLVMPattern {
206205
options, benefit) {}
207206

208207
/// Wrappers around the RewritePattern methods that pass the derived op type.
209-
void rewrite(mlir::Operation *op, mlir::ArrayRef<mlir::Value> operands,
210-
mlir::ConversionPatternRewriter &rewriter) const final {
211-
rewrite(mlir::cast<SourceOp>(op),
212-
OpAdaptor(operands, mlir::cast<SourceOp>(op)), rewriter);
213-
}
214-
void rewrite(mlir::Operation *op, mlir::ArrayRef<mlir::ValueRange> operands,
215-
mlir::ConversionPatternRewriter &rewriter) const final {
216-
auto sourceOp = llvm::cast<SourceOp>(op);
217-
rewrite(llvm::cast<SourceOp>(op), OneToNOpAdaptor(operands, sourceOp),
218-
rewriter);
219-
}
220-
llvm::LogicalResult match(mlir::Operation *op) const final {
221-
return match(mlir::cast<SourceOp>(op));
222-
}
223208
llvm::LogicalResult
224209
matchAndRewrite(mlir::Operation *op, mlir::ArrayRef<mlir::Value> operands,
225210
mlir::ConversionPatternRewriter &rewriter) const final {
@@ -235,28 +220,12 @@ class FIROpConversion : public ConvertFIRToLLVMPattern {
235220
return matchAndRewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp),
236221
rewriter);
237222
}
238-
/// Rewrite and Match methods that operate on the SourceOp type. These must be
223+
/// Methods that operate on the SourceOp type. These must be
239224
/// overridden by the derived pattern class.
240-
virtual llvm::LogicalResult match(SourceOp op) const {
241-
llvm_unreachable("must override match or matchAndRewrite");
242-
}
243-
virtual void rewrite(SourceOp op, OpAdaptor adaptor,
244-
mlir::ConversionPatternRewriter &rewriter) const {
245-
llvm_unreachable("must override rewrite or matchAndRewrite");
246-
}
247-
virtual void rewrite(SourceOp op, OneToNOpAdaptor adaptor,
248-
mlir::ConversionPatternRewriter &rewriter) const {
249-
llvm::SmallVector<mlir::Value> oneToOneOperands =
250-
getOneToOneAdaptorOperands(adaptor.getOperands());
251-
rewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
252-
}
253225
virtual llvm::LogicalResult
254226
matchAndRewrite(SourceOp op, OpAdaptor adaptor,
255227
mlir::ConversionPatternRewriter &rewriter) const {
256-
if (mlir::failed(match(op)))
257-
return mlir::failure();
258-
rewrite(op, adaptor, rewriter);
259-
return mlir::success();
228+
llvm_unreachable("matchAndRewrite is not implemented");
260229
}
261230
virtual llvm::LogicalResult
262231
matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor,
@@ -268,7 +237,6 @@ class FIROpConversion : public ConvertFIRToLLVMPattern {
268237

269238
private:
270239
using ConvertFIRToLLVMPattern::matchAndRewrite;
271-
using ConvertToLLVMPattern::match;
272240
};
273241

274242
/// FIR conversion pattern template

mlir/docs/PatternRewriter.md

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -38,22 +38,23 @@ possible cost and use the predicate to guard the match.
3838
### Root Operation Name (Optional)
3939

4040
The name of the root operation that this pattern matches against. If specified,
41-
only operations with the given root name will be provided to the `match` and
42-
`rewrite` implementation. If not specified, any operation type may be provided.
43-
The root operation name should be provided whenever possible, because it
44-
simplifies the analysis of patterns when applying a cost model. To match any
41+
only operations with the given root name will be provided to the
42+
`matchAndRewrite` implementation. If not specified, any operation type may be
43+
provided. The root operation name should be provided whenever possible, because
44+
it simplifies the analysis of patterns when applying a cost model. To match any
4545
operation type, a special tag must be provided to make the intent explicit:
4646
`MatchAnyOpTypeTag`.
4747

48-
### `match` and `rewrite` implementation
48+
### `matchAndRewrite` implementation
4949

5050
This is the chunk of code that matches a given root `Operation` and performs a
5151
rewrite of the IR. A `RewritePattern` can specify this implementation either via
52-
separate `match` and `rewrite` methods, or via a combined `matchAndRewrite`
53-
method. When using the combined `matchAndRewrite` method, no IR mutation should
54-
take place before the match is deemed successful. The combined `matchAndRewrite`
55-
is useful when non-trivially recomputable information is required by the
56-
matching and rewriting phase. See below for examples:
52+
the `matchAndRewrite` method or via separate `match` and `rewrite` methods when
53+
deriving from `RewritePattern::SplitMatchAndRewrite`. When using the combined
54+
`matchAndRewrite` method, no IR mutation should take place before the match is
55+
deemed successful. The combined `matchAndRewrite` is useful when non-trivially
56+
recomputable information is required by the matching and rewriting phase. See
57+
below for examples:
5758

5859
```c++
5960
class MyPattern : public RewritePattern {
@@ -105,6 +106,10 @@ Within the `rewrite` section of a pattern, the following constraints apply:
105106
`eraseOp`) should be used instead.
106107
* The root operation is required to either be: updated in-place, replaced, or
107108
erased.
109+
* `matchAndRewrite` must return "success" if and only if the IR was modified.
110+
`match` must return "success" if and only if the IR is going to be modified
111+
during `rewrite`.
112+
108113
109114
### Application Recursion
110115

mlir/include/mlir/Conversion/LLVMCommon/Pattern.h

Lines changed: 8 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,9 @@ LogicalResult oneToOneRewrite(
4040
/// during the entire pattern lifetime.
4141
class ConvertToLLVMPattern : public ConversionPattern {
4242
public:
43+
using SplitMatchAndRewrite =
44+
detail::ConversionSplitMatchAndRewriteImpl<ConvertToLLVMPattern>;
45+
4346
ConvertToLLVMPattern(StringRef rootOpName, MLIRContext *context,
4447
const LLVMTypeConverter &typeConverter,
4548
PatternBenefit benefit = 1);
@@ -142,9 +145,12 @@ class ConvertToLLVMPattern : public ConversionPattern {
142145
template <typename SourceOp>
143146
class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
144147
public:
148+
using OperationT = SourceOp;
145149
using OpAdaptor = typename SourceOp::Adaptor;
146150
using OneToNOpAdaptor =
147151
typename SourceOp::template GenericAdaptor<ArrayRef<ValueRange>>;
152+
using SplitMatchAndRewrite = detail::ConversionSplitMatchAndRewriteImpl<
153+
ConvertOpToLLVMPattern<SourceOp>>;
148154

149155
explicit ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter,
150156
PatternBenefit benefit = 1)
@@ -153,19 +159,6 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
153159
benefit) {}
154160

155161
/// Wrappers around the RewritePattern methods that pass the derived op type.
156-
void rewrite(Operation *op, ArrayRef<Value> operands,
157-
ConversionPatternRewriter &rewriter) const final {
158-
auto sourceOp = cast<SourceOp>(op);
159-
rewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
160-
}
161-
void rewrite(Operation *op, ArrayRef<ValueRange> operands,
162-
ConversionPatternRewriter &rewriter) const final {
163-
auto sourceOp = cast<SourceOp>(op);
164-
rewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp), rewriter);
165-
}
166-
LogicalResult match(Operation *op) const final {
167-
return match(cast<SourceOp>(op));
168-
}
169162
LogicalResult
170163
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
171164
ConversionPatternRewriter &rewriter) const final {
@@ -180,28 +173,12 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
180173
rewriter);
181174
}
182175

183-
/// Rewrite and Match methods that operate on the SourceOp type. These must be
176+
/// Methods that operate on the SourceOp type. One of these must be
184177
/// overridden by the derived pattern class.
185-
virtual LogicalResult match(SourceOp op) const {
186-
llvm_unreachable("must override match or matchAndRewrite");
187-
}
188-
virtual void rewrite(SourceOp op, OpAdaptor adaptor,
189-
ConversionPatternRewriter &rewriter) const {
190-
llvm_unreachable("must override rewrite or matchAndRewrite");
191-
}
192-
virtual void rewrite(SourceOp op, OneToNOpAdaptor adaptor,
193-
ConversionPatternRewriter &rewriter) const {
194-
SmallVector<Value> oneToOneOperands =
195-
getOneToOneAdaptorOperands(adaptor.getOperands());
196-
rewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
197-
}
198178
virtual LogicalResult
199179
matchAndRewrite(SourceOp op, OpAdaptor adaptor,
200180
ConversionPatternRewriter &rewriter) const {
201-
if (failed(match(op)))
202-
return failure();
203-
rewrite(op, adaptor, rewriter);
204-
return success();
181+
llvm_unreachable("matchAndRewrite is not implemented");
205182
}
206183
virtual LogicalResult
207184
matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor,
@@ -212,7 +189,6 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
212189
}
213190

214191
private:
215-
using ConvertToLLVMPattern::match;
216192
using ConvertToLLVMPattern::matchAndRewrite;
217193
};
218194

mlir/include/mlir/IR/PatternMatch.h

Lines changed: 47 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -234,41 +234,52 @@ class Pattern {
234234
// RewritePattern
235235
//===----------------------------------------------------------------------===//
236236

237-
/// RewritePattern is the common base class for all DAG to DAG replacements.
238-
/// There are two possible usages of this class:
239-
/// * Multi-step RewritePattern with "match" and "rewrite"
240-
/// - By overloading the "match" and "rewrite" functions, the user can
241-
/// separate the concerns of matching and rewriting.
242-
/// * Single-step RewritePattern with "matchAndRewrite"
243-
/// - By overloading the "matchAndRewrite" function, the user can perform
244-
/// the rewrite in the same call as the match.
245-
///
246-
class RewritePattern : public Pattern {
247-
public:
248-
virtual ~RewritePattern() = default;
237+
namespace detail {
238+
/// Helper class that derives from a RewritePattern class and provides separate
239+
/// `match` and `rewrite` entry points instead of a combined `matchAndRewrite`.
240+
template <typename PatternT>
241+
class SplitMatchAndRewriteImpl : public PatternT {
242+
using PatternT::PatternT;
243+
244+
/// Attempt to match against IR rooted at the specified operation, which is
245+
/// the same operation kind as getRootKind().
246+
///
247+
/// Note: This function must not modify the IR.
248+
virtual LogicalResult match(typename PatternT::OperationT op) const = 0;
249249

250250
/// Rewrite the IR rooted at the specified operation with the result of
251251
/// this pattern, generating any new operations with the specified
252-
/// builder. If an unexpected error is encountered (an internal
253-
/// compiler error), it is emitted through the normal MLIR diagnostic
254-
/// hooks and the IR is left in a valid state.
255-
virtual void rewrite(Operation *op, PatternRewriter &rewriter) const;
256-
257-
/// Attempt to match against code rooted at the specified operation,
258-
/// which is the same operation code as getRootKind().
259-
virtual LogicalResult match(Operation *op) const;
252+
/// rewriter.
253+
virtual void rewrite(typename PatternT::OperationT op,
254+
PatternRewriter &rewriter) const = 0;
260255

261-
/// Attempt to match against code rooted at the specified operation,
262-
/// which is the same operation code as getRootKind(). If successful, this
263-
/// function will automatically perform the rewrite.
264-
virtual LogicalResult matchAndRewrite(Operation *op,
265-
PatternRewriter &rewriter) const {
256+
LogicalResult matchAndRewrite(typename PatternT::OperationT op,
257+
PatternRewriter &rewriter) const final {
266258
if (succeeded(match(op))) {
267259
rewrite(op, rewriter);
268260
return success();
269261
}
270262
return failure();
271263
}
264+
};
265+
} // namespace detail
266+
267+
/// RewritePattern is the common base class for all DAG to DAG replacements.
268+
class RewritePattern : public Pattern {
269+
public:
270+
using OperationT = Operation *;
271+
using SplitMatchAndRewrite = detail::SplitMatchAndRewriteImpl<RewritePattern>;
272+
273+
virtual ~RewritePattern() = default;
274+
275+
/// Attempt to match against code rooted at the specified operation,
276+
/// which is the same operation code as getRootKind(). If successful, perform
277+
/// the rewrite.
278+
///
279+
/// Note: Implementations must modify the IR if and only if the function
280+
/// returns "success".
281+
virtual LogicalResult matchAndRewrite(Operation *op,
282+
PatternRewriter &rewriter) const = 0;
272283

273284
/// This method provides a convenient interface for creating and initializing
274285
/// derived rewrite patterns of the given type `T`.
@@ -317,36 +328,19 @@ namespace detail {
317328
/// class or Interface.
318329
template <typename SourceOp>
319330
struct OpOrInterfaceRewritePatternBase : public RewritePattern {
331+
using OperationT = SourceOp;
320332
using RewritePattern::RewritePattern;
321333

322-
/// Wrappers around the RewritePattern methods that pass the derived op type.
323-
void rewrite(Operation *op, PatternRewriter &rewriter) const final {
324-
rewrite(cast<SourceOp>(op), rewriter);
325-
}
326-
LogicalResult match(Operation *op) const final {
327-
return match(cast<SourceOp>(op));
328-
}
334+
/// Wrapper around the RewritePattern method that passes the derived op type.
329335
LogicalResult matchAndRewrite(Operation *op,
330336
PatternRewriter &rewriter) const final {
331337
return matchAndRewrite(cast<SourceOp>(op), rewriter);
332338
}
333339

334-
/// Rewrite and Match methods that operate on the SourceOp type. These must be
335-
/// overridden by the derived pattern class.
336-
virtual void rewrite(SourceOp op, PatternRewriter &rewriter) const {
337-
llvm_unreachable("must override rewrite or matchAndRewrite");
338-
}
339-
virtual LogicalResult match(SourceOp op) const {
340-
llvm_unreachable("must override match or matchAndRewrite");
341-
}
340+
/// Method that operates on the SourceOp type. Must be overridden by the
341+
/// derived pattern class.
342342
virtual LogicalResult matchAndRewrite(SourceOp op,
343-
PatternRewriter &rewriter) const {
344-
if (succeeded(match(op))) {
345-
rewrite(op, rewriter);
346-
return success();
347-
}
348-
return failure();
349-
}
343+
PatternRewriter &rewriter) const = 0;
350344
};
351345
} // namespace detail
352346

@@ -356,6 +350,9 @@ struct OpOrInterfaceRewritePatternBase : public RewritePattern {
356350
template <typename SourceOp>
357351
struct OpRewritePattern
358352
: public detail::OpOrInterfaceRewritePatternBase<SourceOp> {
353+
using SplitMatchAndRewrite =
354+
detail::SplitMatchAndRewriteImpl<OpRewritePattern<SourceOp>>;
355+
359356
/// Patterns must specify the root operation name they match against, and can
360357
/// also specify the benefit of the pattern matching and a list of generated
361358
/// ops.
@@ -371,6 +368,9 @@ struct OpRewritePattern
371368
template <typename SourceOp>
372369
struct OpInterfaceRewritePattern
373370
: public detail::OpOrInterfaceRewritePatternBase<SourceOp> {
371+
using SplitMatchAndRewrite =
372+
detail::SplitMatchAndRewriteImpl<OpInterfaceRewritePattern<SourceOp>>;
373+
374374
OpInterfaceRewritePattern(MLIRContext *context, PatternBenefit benefit = 1)
375375
: detail::OpOrInterfaceRewritePatternBase<SourceOp>(
376376
Pattern::MatchInterfaceOpTypeTag(), SourceOp::getInterfaceID(),

0 commit comments

Comments
 (0)