Skip to content

Commit c5dee18

Browse files
committed
[mlir][memref] Add support for erasing dead allocations.
Reviewed By: springerm Differential Revision: https://reviews.llvm.org/D159135
1 parent 7248e57 commit c5dee18

File tree

6 files changed

+152
-0
lines changed

6 files changed

+152
-0
lines changed

mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,44 @@ def MemRefMultiBufferOp : Op<Transform_Dialect, "memref.multibuffer",
159159
"$target attr-dict `:` functional-type(operands, results)";
160160
}
161161

162+
def MemRefEraseDeadAllocAndStoresOp
163+
: Op<Transform_Dialect, "memref.erase_dead_alloc_and_stores", [
164+
TransformEachOpTrait, TransformOpInterface,
165+
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
166+
ReportTrackingListenerFailuresOpTrait
167+
]> {
168+
let description = [{
169+
This applies memory optimization on memref. In particular it does store to
170+
load forwarding, dead store elimination and dead alloc elimination.
171+
172+
#### Return modes
173+
174+
This operation applies a set of memory optimization on the whole region of
175+
the operand.
176+
177+
The transformation does not consume the target handle. It modifies the
178+
payload. Dead allocations, loads and stores are silently dropped from all
179+
mappings.
180+
}];
181+
182+
let arguments = (ins TransformHandleTypeInterface:$target);
183+
let results = (outs);
184+
185+
let assemblyFormat = "$target attr-dict `:` functional-type($target, results)";
186+
187+
let skipDefaultBuilders = 1;
188+
let builders = [
189+
OpBuilder<(ins "Value":$target)>
190+
];
191+
let extraClassDeclaration = [{
192+
::mlir::DiagnosedSilenceableFailure applyToOne(
193+
::mlir::transform::TransformRewriter &rewriter,
194+
::mlir::Operation *target,
195+
::mlir::transform::ApplyToEachResultList &results,
196+
::mlir::transform::TransformState &state);
197+
}];
198+
}
199+
162200
def MemRefMakeLoopIndependentOp
163201
: Op<Transform_Dialect, "memref.make_loop_independent",
164202
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,

mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,10 @@ getLinearizedMemRefOffsetAndSize(OpBuilder &builder, Location loc, int srcBits,
6060
int dstBits, OpFoldResult offset,
6161
ArrayRef<OpFoldResult> sizes);
6262

63+
// Track temporary allocations that are never read from. If this is the case
64+
// it means both the allocations and associated stores can be removed.
65+
void eraseDeadAllocAndStores(RewriterBase &rewriter, Operation *parentOp);
66+
6367
} // namespace memref
6468
} // namespace mlir
6569

mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,13 @@
1414
#include "mlir/Dialect/MemRef/IR/MemRef.h"
1515
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
1616
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
17+
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
1718
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
1819
#include "mlir/Dialect/SCF/IR/SCF.h"
1920
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
2021
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
2122
#include "mlir/Dialect/Vector/IR/VectorOps.h"
23+
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
2224
#include "mlir/Interfaces/LoopLikeInterface.h"
2325
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
2426
#include "llvm/Support/Debug.h"
@@ -132,6 +134,32 @@ DiagnosedSilenceableFailure transform::MemRefMultiBufferOp::apply(
132134
return DiagnosedSilenceableFailure::success();
133135
}
134136

137+
//===----------------------------------------------------------------------===//
138+
// MemRefEraseDeadAllocAndStoresOp
139+
//===----------------------------------------------------------------------===//
140+
141+
DiagnosedSilenceableFailure
142+
transform::MemRefEraseDeadAllocAndStoresOp::applyToOne(
143+
transform::TransformRewriter &rewriter, Operation *target,
144+
transform::ApplyToEachResultList &results,
145+
transform::TransformState &state) {
146+
// Apply store to load forwarding and dead store elimination.
147+
vector::transferOpflowOpt(rewriter, target);
148+
memref::eraseDeadAllocAndStores(rewriter, target);
149+
return DiagnosedSilenceableFailure::success();
150+
}
151+
152+
void transform::MemRefEraseDeadAllocAndStoresOp::getEffects(
153+
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
154+
transform::onlyReadsHandle(getTarget(), effects);
155+
transform::modifiesPayload(effects);
156+
}
157+
void transform::MemRefEraseDeadAllocAndStoresOp::build(OpBuilder &builder,
158+
OperationState &result,
159+
Value target) {
160+
result.addOperands(target);
161+
}
162+
135163
//===----------------------------------------------------------------------===//
136164
// MemRefMakeLoopIndependentOp
137165
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "mlir/Dialect/Affine/IR/AffineOps.h"
1515
#include "mlir/Dialect/Arith/Utils/Utils.h"
1616
#include "mlir/Dialect/MemRef/IR/MemRef.h"
17+
#include "mlir/Dialect/Vector/IR/VectorOps.h"
1718

1819
namespace mlir {
1920
namespace memref {
@@ -120,5 +121,39 @@ getLinearizedMemRefOffsetAndSize(OpBuilder &builder, Location loc, int srcBits,
120121
return linearizedMemRefInfo;
121122
}
122123

124+
/// Returns true if all the uses of op are not read/load.
125+
/// There can be SubviewOp users as long as all its users are also
126+
/// StoreOp/transfer_write. If return true it also fills out the uses, if it
127+
/// returns false uses is unchanged.
128+
static bool resultIsNotRead(Operation *op, std::vector<Operation *> &uses) {
129+
std::vector<Operation *> opUses;
130+
for (OpOperand &use : op->getUses()) {
131+
Operation *useOp = use.getOwner();
132+
if (isa<memref::DeallocOp>(useOp) ||
133+
(useOp->getNumResults() == 0 && useOp->getNumRegions() == 0 &&
134+
!mlir::hasEffect<MemoryEffects::Read>(useOp)) ||
135+
(isa<memref::SubViewOp>(useOp) && resultIsNotRead(useOp, opUses))) {
136+
opUses.push_back(useOp);
137+
continue;
138+
}
139+
return false;
140+
}
141+
uses.insert(uses.end(), opUses.begin(), opUses.end());
142+
return true;
143+
}
144+
145+
void eraseDeadAllocAndStores(RewriterBase &rewriter, Operation *parentOp) {
146+
std::vector<Operation *> opToErase;
147+
parentOp->walk([&](memref::AllocOp op) {
148+
std::vector<Operation *> candidates;
149+
if (resultIsNotRead(op, candidates)) {
150+
opToErase.insert(opToErase.end(), candidates.begin(), candidates.end());
151+
opToErase.push_back(op.getOperation());
152+
}
153+
});
154+
for (Operation *op : opToErase)
155+
rewriter.eraseOp(op);
156+
}
157+
123158
} // namespace memref
124159
} // namespace mlir

mlir/test/Dialect/MemRef/transform-ops.mlir

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,50 @@ transform.sequence failures(propagate) {
259259

260260
// -----
261261

262+
// CHECK-LABEL: func.func @dead_alloc
263+
func.func @dead_alloc() {
264+
// CHECK-NOT: %{{.+}} = memref.alloc
265+
%0 = memref.alloc() : memref<8x64xf32, 3>
266+
%1 = memref.subview %0[0, 0] [8, 4] [1, 1] : memref<8x64xf32, 3> to
267+
memref<8x4xf32, affine_map<(d0, d1) -> (d0 * 64 + d1)>, 3>
268+
%c0 = arith.constant 0 : index
269+
%cst_0 = arith.constant dense<0.000000e+00> : vector<1x4xf32>
270+
vector.transfer_write %cst_0, %1[%c0, %c0] {in_bounds = [true, true]} :
271+
vector<1x4xf32>, memref<8x4xf32, affine_map<(d0, d1) -> (d0 * 64 + d1)>, 3>
272+
return
273+
}
274+
275+
transform.sequence failures(propagate) {
276+
^bb1(%arg1: !transform.any_op):
277+
%0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
278+
transform.memref.erase_dead_alloc_and_stores %0 : (!transform.any_op) -> ()
279+
}
280+
281+
// -----
282+
283+
// CHECK-LABEL: @store_to_load
284+
// CHECK-SAME: (%[[ARG:.+]]: vector<4xf32>)
285+
// CHECK-NOT: memref.alloc()
286+
// CHECK-NOT: vector.transfer_write
287+
// CHECK-NOT: vector.transfer_read
288+
// CHECK: return %[[ARG]] : vector<4xf32>
289+
func.func @store_to_load(%arg: vector<4xf32>) -> vector<4xf32> {
290+
%c0 = arith.constant 0 : index
291+
%cst_1 = arith.constant 0.000000e+00 : f32
292+
%alloc = memref.alloc() {alignment = 64 : i64} : memref<64xf32>
293+
vector.transfer_write %arg, %alloc[%c0] {in_bounds = [true]} : vector<4xf32>, memref<64xf32>
294+
%r = vector.transfer_read %alloc[%c0], %cst_1 {in_bounds = [true]} : memref<64xf32>, vector<4xf32>
295+
return %r : vector<4xf32>
296+
}
297+
298+
transform.sequence failures(propagate) {
299+
^bb1(%arg1: !transform.any_op):
300+
%0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
301+
transform.memref.erase_dead_alloc_and_stores %0 : (!transform.any_op) -> ()
302+
}
303+
304+
// -----
305+
262306
// CHECK-LABEL: func @lower_to_llvm
263307
// CHECK-NOT: memref.alloc
264308
// CHECK: llvm.call @malloc

utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11561,6 +11561,7 @@ cc_library(
1156111561
":AffineDialect",
1156211562
":ArithUtils",
1156311563
":MemRefDialect",
11564+
":VectorDialect",
1156411565
],
1156511566
)
1156611567

@@ -11664,11 +11665,13 @@ cc_library(
1166411665
":MemRefDialect",
1166511666
":MemRefTransformOpsIncGen",
1166611667
":MemRefTransforms",
11668+
":MemRefUtils",
1166711669
":NVGPUDialect",
1166811670
":SCFDialect",
1166911671
":TransformDialect",
1167011672
":TransformUtils",
1167111673
":VectorDialect",
11674+
":VectorTransforms",
1167211675
"//llvm:Support",
1167311676
],
1167411677
)

0 commit comments

Comments
 (0)