Skip to content

Commit 00d1a1a

Browse files
committed
[mlir] Add ReplicateOp to the Transform dialect
This handle manipulation operation allows one to define a new handle that is associated with a the same payload IR operations N times, where N can be driven by the size of payload IR operation list associated with another handle. This can be seen as a sort of broadcast that can be used to ensure the lists associated with two handles have equal numbers of payload IR ops as expected by many pairwise transform operations. Introduce an additional "expensive" check that guards against consuming a handle that is assocaited with the same payload IR operation more than once as this is likely to lead to double-free or other undesired effects. Depends On D129110 Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D129216
1 parent d4c5320 commit 00d1a1a

File tree

10 files changed

+295
-4
lines changed

10 files changed

+295
-4
lines changed

mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -845,6 +845,27 @@ transposeResults(const SmallVector<SmallVector<Operation *>, 1> &m) {
845845
return res;
846846
}
847847
} // namespace detail
848+
849+
/// Populates `effects` with the memory effects indicating the operation on the
850+
/// given handle value:
851+
/// - consumes = Read + Free,
852+
/// - produces = Allocate + Write,
853+
/// - onlyReads = Read.
854+
void consumesHandle(ValueRange handles,
855+
SmallVectorImpl<MemoryEffects::EffectInstance> &effects);
856+
void producesHandle(ValueRange handles,
857+
SmallVectorImpl<MemoryEffects::EffectInstance> &effects);
858+
void onlyReadsHandle(ValueRange handles,
859+
SmallVectorImpl<MemoryEffects::EffectInstance> &effects);
860+
861+
/// Checks whether the transform op consumes the given handle.
862+
bool isHandleConsumed(Value handle, transform::TransformOpInterface transform);
863+
864+
/// Populates `effects` with the memory effects indicating the access to payload
865+
/// IR resource.
866+
void modifiesPayload(SmallVectorImpl<MemoryEffects::EffectInstance> &effects);
867+
void onlyReadsPayload(SmallVectorImpl<MemoryEffects::EffectInstance> &effects);
868+
848869
} // namespace transform
849870
} // namespace mlir
850871

mlir/include/mlir/Dialect/Transform/IR/TransformOps.td

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,42 @@ def PDLMatchOp : TransformDialectOp<"pdl_match",
174174
let assemblyFormat = "$pattern_name `in` $root attr-dict";
175175
}
176176

177+
def ReplicateOp : TransformDialectOp<"replicate",
178+
[DeclareOpInterfaceMethods<TransformOpInterface>,
179+
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
180+
let summary = "Lists payload ops multiple times in the new handle";
181+
let description = [{
182+
Produces a new handle associated with a list of payload IR ops that is
183+
computed by repeating the list of payload IR ops associated with the
184+
operand handle as many times as the "pattern" handle has associated
185+
operations. For example, if pattern is associated with [op1, op2] and the
186+
operand handle is associated with [op3, op4, op5], the resulting handle
187+
will be associated with [op3, op4, op5, op3, op4, op5].
188+
189+
This transformation is useful to "align" the sizes of payload IR lists
190+
before a transformation that expects, e.g., identically-sized lists. For
191+
example, a transformation may be parameterized by same notional per-target
192+
size computed at runtime and supplied as another handle, the replication
193+
allows this size to be computed only once and used for every target instead
194+
of replicating the computation itself.
195+
196+
Note that it is undesirable to pass a handle with duplicate operations to
197+
an operation that consumes the handle. Handle consumption often indicates
198+
that the associated payload IR ops are destroyed, so having the same op
199+
listed more than once will lead to double-free. Single-operand
200+
MergeHandlesOp may be used to deduplicate the associated list of payload IR
201+
ops when necessary. Furthermore, a combination of ReplicateOp and
202+
MergeHandlesOp can be used to construct arbitrary lists with repetitions.
203+
}];
204+
205+
let arguments = (ins PDL_Operation:$pattern,
206+
Variadic<PDL_Operation>:$handles);
207+
let results = (outs Variadic<PDL_Operation>:$replicated);
208+
let assemblyFormat =
209+
"`num` `(` $pattern `)` $handles "
210+
"custom<PDLOpTypedResults>(type($replicated), ref($handles)) attr-dict";
211+
}
212+
177213
def SequenceOp : TransformDialectOp<"sequence",
178214
[DeclareOpInterfaceMethods<RegionBranchOpInterface,
179215
["getSuccessorEntryOperands", "getSuccessorRegions",

mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp

Lines changed: 86 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ Value transform::TransformState::getHandleForPayloadOp(Operation *op) const {
5555
LogicalResult transform::TransformState::tryEmplaceReverseMapping(
5656
Mappings &map, Operation *operation, Value handle) {
5757
auto insertionResult = map.reverse.insert({operation, handle});
58-
if (!insertionResult.second) {
58+
if (!insertionResult.second && insertionResult.first->second != handle) {
5959
InFlightDiagnostic diag = operation->emitError()
6060
<< "operation tracked by two handles";
6161
diag.attachNote(handle.getLoc()) << "handle";
@@ -191,9 +191,27 @@ LogicalResult transform::TransformState::checkAndRecordHandleInvalidation(
191191
DiagnosedSilenceableFailure
192192
transform::TransformState::applyTransform(TransformOpInterface transform) {
193193
LLVM_DEBUG(DBGS() << "applying: " << transform << "\n");
194-
if (options.getExpensiveChecksEnabled() &&
195-
failed(checkAndRecordHandleInvalidation(transform))) {
196-
return DiagnosedSilenceableFailure::definiteFailure();
194+
if (options.getExpensiveChecksEnabled()) {
195+
if (failed(checkAndRecordHandleInvalidation(transform)))
196+
return DiagnosedSilenceableFailure::definiteFailure();
197+
198+
for (OpOperand &operand : transform->getOpOperands()) {
199+
if (!isHandleConsumed(operand.get(), transform))
200+
continue;
201+
202+
DenseSet<Operation *> seen;
203+
for (Operation *op : getPayloadOps(operand.get())) {
204+
if (!seen.insert(op).second) {
205+
DiagnosedSilenceableFailure diag =
206+
transform.emitSilenceableError()
207+
<< "a handle passed as operand #" << operand.getOperandNumber()
208+
<< " and consumed by this operation points to a payload "
209+
"operation more than once";
210+
diag.attachNote(op->getLoc()) << "repeated target op";
211+
return diag;
212+
}
213+
}
214+
}
197215
}
198216

199217
transform::TransformResults results(transform->getNumResults());
@@ -326,6 +344,70 @@ transform::detail::verifyPossibleTopLevelTransformOpTrait(Operation *op) {
326344
return success();
327345
}
328346

347+
//===----------------------------------------------------------------------===//
348+
// Memory effects.
349+
//===----------------------------------------------------------------------===//
350+
351+
void transform::consumesHandle(
352+
ValueRange handles,
353+
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
354+
for (Value handle : handles) {
355+
effects.emplace_back(MemoryEffects::Read::get(), handle,
356+
TransformMappingResource::get());
357+
effects.emplace_back(MemoryEffects::Free::get(), handle,
358+
TransformMappingResource::get());
359+
}
360+
}
361+
362+
/// Returns `true` if the given list of effects instances contains an instance
363+
/// with the effect type specified as template parameter.
364+
template <typename EffectTy>
365+
static bool hasEffect(ArrayRef<MemoryEffects::EffectInstance> effects) {
366+
return llvm::any_of(effects, [](const MemoryEffects::EffectInstance &effect) {
367+
return isa<EffectTy>(effect.getEffect());
368+
});
369+
}
370+
371+
bool transform::isHandleConsumed(Value handle,
372+
transform::TransformOpInterface transform) {
373+
auto iface = cast<MemoryEffectOpInterface>(transform.getOperation());
374+
SmallVector<MemoryEffects::EffectInstance> effects;
375+
iface.getEffectsOnValue(handle, effects);
376+
return hasEffect<MemoryEffects::Read>(effects) &&
377+
hasEffect<MemoryEffects::Free>(effects);
378+
}
379+
380+
void transform::producesHandle(
381+
ValueRange handles,
382+
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
383+
for (Value handle : handles) {
384+
effects.emplace_back(MemoryEffects::Allocate::get(), handle,
385+
TransformMappingResource::get());
386+
effects.emplace_back(MemoryEffects::Write::get(), handle,
387+
TransformMappingResource::get());
388+
}
389+
}
390+
391+
void transform::onlyReadsHandle(
392+
ValueRange handles,
393+
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
394+
for (Value handle : handles) {
395+
effects.emplace_back(MemoryEffects::Read::get(), handle,
396+
TransformMappingResource::get());
397+
}
398+
}
399+
400+
void transform::modifiesPayload(
401+
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
402+
effects.emplace_back(MemoryEffects::Read::get(), PayloadIRResource::get());
403+
effects.emplace_back(MemoryEffects::Write::get(), PayloadIRResource::get());
404+
}
405+
406+
void transform::onlyReadsPayload(
407+
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
408+
effects.emplace_back(MemoryEffects::Read::get(), PayloadIRResource::get());
409+
}
410+
329411
//===----------------------------------------------------------------------===//
330412
// Generated interface implementation.
331413
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Transform/IR/TransformOps.cpp

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,16 @@
2323

2424
using namespace mlir;
2525

26+
static ParseResult parsePDLOpTypedResults(
27+
OpAsmParser &parser, SmallVectorImpl<Type> &types,
28+
const SmallVectorImpl<OpAsmParser::UnresolvedOperand> &handles) {
29+
types.resize(handles.size(), pdl::OperationType::get(parser.getContext()));
30+
return success();
31+
}
32+
33+
static void printPDLOpTypedResults(OpAsmPrinter &, Operation *, TypeRange,
34+
ValueRange) {}
35+
2636
#define GET_OP_CLASSES
2737
#include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc"
2838

@@ -354,6 +364,33 @@ transform::PDLMatchOp::apply(transform::TransformResults &results,
354364
return DiagnosedSilenceableFailure::success();
355365
}
356366

367+
//===----------------------------------------------------------------------===//
368+
// ReplicateOp
369+
//===----------------------------------------------------------------------===//
370+
371+
DiagnosedSilenceableFailure
372+
transform::ReplicateOp::apply(transform::TransformResults &results,
373+
transform::TransformState &state) {
374+
unsigned numRepetitions = state.getPayloadOps(getPattern()).size();
375+
for (const auto &en : llvm::enumerate(getHandles())) {
376+
Value handle = en.value();
377+
ArrayRef<Operation *> current = state.getPayloadOps(handle);
378+
SmallVector<Operation *> payload;
379+
payload.reserve(numRepetitions * current.size());
380+
for (unsigned i = 0; i < numRepetitions; ++i)
381+
llvm::append_range(payload, current);
382+
results.set(getReplicated()[en.index()].cast<OpResult>(), payload);
383+
}
384+
return DiagnosedSilenceableFailure::success();
385+
}
386+
387+
void transform::ReplicateOp::getEffects(
388+
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
389+
onlyReadsHandle(getPattern(), effects);
390+
consumesHandle(getHandles(), effects);
391+
producesHandle(getReplicated(), effects);
392+
}
393+
357394
//===----------------------------------------------------------------------===//
358395
// SequenceOp
359396
//===----------------------------------------------------------------------===//

mlir/python/mlir/dialects/_transform_ops_ext.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,22 @@ def __init__(self,
5959
ip=ip)
6060

6161

62+
class ReplicateOp:
63+
64+
def __init__(self,
65+
pattern: Union[Operation, Value],
66+
handles: Sequence[Union[Operation, Value]],
67+
*,
68+
loc=None,
69+
ip=None):
70+
super().__init__(
71+
[pdl.OperationType.get()] * len(handles),
72+
_get_op_result_or_value(pattern),
73+
[_get_op_result_or_value(h) for h in handles],
74+
loc=loc,
75+
ip=ip)
76+
77+
6278
class SequenceOp:
6379

6480
@overload

mlir/test/Dialect/Transform/expensive-checks.mlir

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,3 +25,37 @@ transform.with_pdl_patterns {
2525
test_print_remark_at_operand %0, "remark"
2626
}
2727
}
28+
29+
// -----
30+
31+
func.func @func1() {
32+
// expected-note @below {{repeated target op}}
33+
return
34+
}
35+
func.func private @func2()
36+
37+
transform.with_pdl_patterns {
38+
^bb0(%arg0: !pdl.operation):
39+
pdl.pattern @func : benefit(1) {
40+
%0 = operands
41+
%1 = types
42+
%2 = operation "func.func"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
43+
rewrite %2 with "transform.dialect"
44+
}
45+
pdl.pattern @return : benefit(1) {
46+
%0 = operands
47+
%1 = types
48+
%2 = operation "func.return"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
49+
rewrite %2 with "transform.dialect"
50+
}
51+
52+
sequence %arg0 {
53+
^bb1(%arg1: !pdl.operation):
54+
%0 = pdl_match @func in %arg1
55+
%1 = pdl_match @return in %arg1
56+
%2 = replicate num(%0) %1
57+
// expected-error @below {{a handle passed as operand #0 and consumed by this operation points to a payload operation more than once}}
58+
test_consume_operand %2
59+
test_print_remark_at_operand %0, "remark"
60+
}
61+
}

mlir/test/Dialect/Transform/test-interpreter.mlir

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -569,3 +569,31 @@ transform.with_pdl_patterns {
569569
transform.test_mixed_sucess_and_silenceable %0
570570
}
571571
}
572+
573+
// -----
574+
575+
module {
576+
func.func private @foo()
577+
func.func private @bar()
578+
579+
transform.with_pdl_patterns {
580+
^bb0(%arg0: !pdl.operation):
581+
pdl.pattern @func : benefit(1) {
582+
%0 = pdl.operands
583+
%1 = pdl.types
584+
%2 = pdl.operation "func.func"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
585+
pdl.rewrite %2 with "transform.dialect"
586+
}
587+
588+
transform.sequence %arg0 {
589+
^bb0(%arg1: !pdl.operation):
590+
%0 = pdl_match @func in %arg1
591+
%1 = replicate num(%0) %arg1
592+
// expected-remark @below {{2}}
593+
test_print_number_of_associated_payload_ir_ops %1
594+
%2 = replicate num(%0) %1
595+
// expected-remark @below {{4}}
596+
test_print_number_of_associated_payload_ir_ops %2
597+
}
598+
}
599+
}

mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -275,6 +275,18 @@ mlir::test::TestMixedSuccessAndSilenceableOp::applyToOne(
275275
return emitDefaultSilenceableFailure(target);
276276
}
277277

278+
DiagnosedSilenceableFailure
279+
mlir::test::TestPrintNumberOfAssociatedPayloadIROps::apply(
280+
transform::TransformResults &results, transform::TransformState &state) {
281+
emitRemark() << state.getPayloadOps(getHandle()).size();
282+
return DiagnosedSilenceableFailure::success();
283+
}
284+
285+
void mlir::test::TestPrintNumberOfAssociatedPayloadIROps::getEffects(
286+
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
287+
transform::onlyReadsHandle(getHandle(), effects);
288+
}
289+
278290
namespace {
279291
/// Test extension of the Transform dialect. Registers additional ops and
280292
/// declares PDL as dependent dialect since the additional ops are using PDL

mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,4 +212,13 @@ def TestMixedSuccessAndSilenceableOp
212212
}];
213213
}
214214

215+
def TestPrintNumberOfAssociatedPayloadIROps
216+
: Op<Transform_Dialect, "test_print_number_of_associated_payload_ir_ops",
217+
[DeclareOpInterfaceMethods<TransformOpInterface>,
218+
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
219+
let arguments = (ins PDL_Operation:$handle);
220+
let assemblyFormat = "$handle attr-dict";
221+
let cppNamespace = "::mlir::test";
222+
}
223+
215224
#endif // MLIR_TESTTRANSFORMDIALECTEXTENSION_TD

mlir/test/python/dialects/transform.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,3 +94,19 @@ def testMergeHandlesOp():
9494
# CHECK: transform.sequence
9595
# CHECK: ^{{.*}}(%[[ARG1:.+]]: !pdl.operation):
9696
# CHECK: = merge_handles %[[ARG1]]
97+
98+
99+
@run
100+
def testReplicateOp():
101+
with_pdl = transform.WithPDLPatternsOp()
102+
with InsertionPoint(with_pdl.body):
103+
sequence = transform.SequenceOp(with_pdl.bodyTarget)
104+
with InsertionPoint(sequence.body):
105+
m1 = transform.PDLMatchOp(sequence.bodyTarget, "first")
106+
m2 = transform.PDLMatchOp(sequence.bodyTarget, "second")
107+
transform.ReplicateOp(m1, [m2])
108+
transform.YieldOp()
109+
# CHECK-LABEL: TEST: testReplicateOp
110+
# CHECK: %[[FIRST:.+]] = pdl_match
111+
# CHECK: %[[SECOND:.+]] = pdl_match
112+
# CHECK: %{{.*}} = replicate num(%[[FIRST]]) %[[SECOND]]

0 commit comments

Comments
 (0)