Skip to content

Commit 88f4292

Browse files
[mlir][bufferization] OneShotBufferizeOp: Add options to use linalg.copy
This new option allows users to specify a custom memcpy op. Differential Revision: https://reviews.llvm.org/D155280
1 parent 9ff7181 commit 88f4292

File tree

5 files changed

+54
-1
lines changed

5 files changed

+54
-1
lines changed

mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.td

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,10 +58,12 @@ def OneShotBufferizeOp
5858
DefaultValuedAttr<BoolAttr, "false">:$bufferize_function_boundaries,
5959
DefaultValuedAttr<BoolAttr, "true">:$create_deallocs,
6060
DefaultValuedAttr<BoolAttr, "false">:$test_analysis_only,
61-
DefaultValuedAttr<BoolAttr, "false">:$print_conflicts);
61+
DefaultValuedAttr<BoolAttr, "false">:$print_conflicts,
62+
DefaultValuedAttr<StrAttr, "\"memref.copy\"">:$memcpy_op);
6263

6364
let results = (outs TransformHandleTypeInterface:$transformed);
6465

66+
let hasVerifier = 1;
6567
let assemblyFormat = [{
6668
(`layout` `{` $function_boundary_type_conversion^ `}`)?
6769
$target attr-dict `:` functional-type($target, results)

mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
1313
#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
1414
#include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
15+
#include "mlir/Dialect/Linalg/IR/Linalg.h"
1516
#include "mlir/Dialect/MemRef/IR/MemRef.h"
1617
#include "mlir/Dialect/Tensor/IR/Tensor.h"
1718
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
@@ -25,6 +26,12 @@ using namespace mlir::transform;
2526
// OneShotBufferizeOp
2627
//===----------------------------------------------------------------------===//
2728

29+
LogicalResult transform::OneShotBufferizeOp::verify() {
30+
if (getMemcpyOp() != "memref.copy" && getMemcpyOp() != "linalg.copy")
31+
return emitOpError() << "unsupported memcpy op";
32+
return success();
33+
}
34+
2835
DiagnosedSilenceableFailure
2936
transform::OneShotBufferizeOp::apply(transform::TransformRewriter &rewriter,
3037
TransformResults &transformResults,
@@ -39,6 +46,19 @@ transform::OneShotBufferizeOp::apply(transform::TransformRewriter &rewriter,
3946
if (getFunctionBoundaryTypeConversion().has_value())
4047
options.setFunctionBoundaryTypeConversion(
4148
*getFunctionBoundaryTypeConversion());
49+
if (getMemcpyOp() == "memref.copy") {
50+
options.memCpyFn = [](OpBuilder &b, Location loc, Value from, Value to) {
51+
b.create<memref::CopyOp>(loc, from, to);
52+
return success();
53+
};
54+
} else if (getMemcpyOp() == "linalg.copy") {
55+
options.memCpyFn = [](OpBuilder &b, Location loc, Value from, Value to) {
56+
b.create<linalg::CopyOp>(loc, from, to);
57+
return success();
58+
};
59+
} else {
60+
llvm_unreachable("invalid copy op");
61+
}
4262

4363
auto payloadOps = state.getPayloadOps(getTarget());
4464
for (Operation *target : payloadOps) {

mlir/lib/Dialect/Bufferization/TransformOps/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ add_mlir_dialect_library(MLIRBufferizationTransformOps
1111
MLIRIR
1212
MLIRBufferizationDialect
1313
MLIRBufferizationTransforms
14+
MLIRLinalgDialect
1415
MLIRParser
1516
MLIRPDLDialect
1617
MLIRSideEffectInterfaces

mlir/test/Dialect/Bufferization/Transforms/transform-ops.mlir

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,35 @@ func.func @test_function(%A : tensor<?xf32>, %v : vector<4xf32>) -> (tensor<?xf3
2828

2929
// -----
3030

31+
// Emit linalg.copy instead of memref.copy.
32+
33+
transform.sequence failures(propagate) {
34+
^bb0(%arg1: !transform.any_op):
35+
%0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
36+
%1 = transform.bufferization.one_shot_bufferize %0 {memcpy_op = "linalg.copy"} : (!transform.any_op) -> !transform.any_op
37+
}
38+
39+
// CHECK-LABEL: func @test_function(
40+
// CHECK-SAME: %[[A:.*]]: tensor<?xf32>
41+
// CHECK-NOT: memref.copy
42+
func.func @test_function(%A : tensor<?xf32>, %v : vector<4xf32>) -> (tensor<?xf32>) {
43+
%c0 = arith.constant 0 : index
44+
45+
// CHECK: %[[A_memref:.*]] = bufferization.to_memref %[[A]]
46+
// CHECK: %[[dim:.*]] = memref.dim %[[A_memref]]
47+
// CHECK: %[[alloc:.*]] = memref.alloc(%[[dim]])
48+
// CHECK: linalg.copy ins(%[[A_memref]] : memref<{{.*}}>) outs(%[[alloc]]
49+
// CHECK: vector.transfer_write %{{.*}}, %[[alloc]]
50+
// CHECK: %[[res_tensor:.*]] = bufferization.to_tensor %[[alloc]]
51+
%0 = vector.transfer_write %v, %A[%c0] : vector<4xf32>, tensor<?xf32>
52+
53+
// CHECK: memref.dealloc %[[alloc]]
54+
// CHECK: return %[[res_tensor]]
55+
return %0 : tensor<?xf32>
56+
}
57+
58+
// -----
59+
3160
// Test analysis of One-Shot Bufferize only.
3261

3362
transform.sequence failures(propagate) {

utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11477,6 +11477,7 @@ cc_library(
1147711477
":BufferizationTransformOpsIncGen",
1147811478
":BufferizationTransforms",
1147911479
":IR",
11480+
":LinalgDialect",
1148011481
":MemRefDialect",
1148111482
":Parser",
1148211483
":SideEffectInterfaces",

0 commit comments

Comments
 (0)