Skip to content

Commit 7060422

Browse files
authored
[mlir][Linalg]: Optimize linalg generic in transform::PromoteOp to avoid unnecessary copies (#68555)
If an operand is not used in the payload of a linalg generic operation, there is no need to copy it into the promoted buffer before the operation.
1 parent a653749 commit 7060422

File tree

3 files changed

+11
-1
lines changed

3 files changed

+11
-1
lines changed

mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
#include "mlir/Transforms/FoldUtils.h"
2929
#include "llvm/ADT/MapVector.h"
3030
#include "llvm/ADT/SmallBitVector.h"
31+
#include "llvm/ADT/SmallSet.h"
3132
#include "llvm/ADT/TypeSwitch.h"
3233
#include "llvm/Support/CommandLine.h"
3334
#include "llvm/Support/Debug.h"
@@ -142,6 +143,8 @@ struct LinalgOpInstancePromotionOptions {
142143
const LinalgPromotionOptions &options);
143144
/// SubViews to promote.
144145
MapVector<int64_t, Value> subViews;
146+
/// Operand numbers of the subviews to copy in using copyInFn.
147+
llvm::SmallSet<int64_t, 4> operandsNumbersToCopyIn;
145148
/// True if the full view should be used for the promoted buffer.
146149
DenseMap<Value, bool> useFullTileBuffers;
147150

@@ -174,6 +177,11 @@ LinalgOpInstancePromotionOptions::LinalgOpInstancePromotionOptions(
174177
Operation *op = opOperand.get().getDefiningOp();
175178
if (auto sv = dyn_cast_or_null<memref::SubViewOp>(op)) {
176179
subViews[operandNumber] = sv;
180+
// In case of linalg generic, copy in only if subview is used in linalg
181+
// payload.
182+
if (!isa<linalg::GenericOp>(linalgOp) ||
183+
linalgOp.payloadUsesValueFromOperand(&opOperand))
184+
operandsNumbersToCopyIn.insert(operandNumber);
177185
useFullTileBuffers[sv] = vUseFullTileBuffers[operandNumber];
178186
}
179187
}
@@ -324,6 +332,8 @@ promoteSubViews(ImplicitLocOpBuilder &b,
324332
auto info = promotionInfoMap.find(v.first);
325333
if (info == promotionInfoMap.end())
326334
continue;
335+
if (options.operandsNumbersToCopyIn.count(v.first) == 0)
336+
continue;
327337
if (failed(options.copyInFn(
328338
b, cast<memref::SubViewOp>(v.second.getDefiningOp()),
329339
info->second.partialLocalView)))

mlir/test/Dialect/GPU/promotion.mlir

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
12
// RUN: mlir-opt -allow-unregistered-dialect -pass-pipeline='builtin.module(gpu.module(gpu.func(test-gpu-memory-promotion)))' -split-input-file %s | FileCheck %s
23

34
gpu.module @foo {

mlir/test/Dialect/Linalg/promote.mlir

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -353,7 +353,6 @@ func.func @linalg_generic_update_all_function_inputs_outputs(%arg0: memref<3x4xf
353353
// CHECK: %[[VAL_62:.*]] = memref.subview %[[VAL_61]][0, 0] {{\[}}%[[VAL_52]], %[[VAL_55]]] [1, 1] : memref<?x?xf32, #gpu.address_space<workgroup>> to memref<?x?xf32, strided<[?, 1], offset: ?>, #gpu.address_space<workgroup>>
354354
// CHECK: memref.copy %[[VAL_3]], %[[VAL_24]] : memref<4x3xf32, strided<[4, 1]>, 1> to memref<?x?xf32, strided<[?, 1], offset: ?>, #gpu.address_space<workgroup>>
355355
// CHECK: memref.copy %[[VAL_4]], %[[VAL_43]] : memref<4x3xf32, strided<[4, 1]>, 1> to memref<?x?xf32, strided<[?, 1], offset: ?>, #gpu.address_space<workgroup>>
356-
// CHECK: memref.copy %[[VAL_5]], %[[VAL_62]] : memref<4x3xf32, strided<[4, 1]>, 1> to memref<?x?xf32, strided<[?, 1], offset: ?>, #gpu.address_space<workgroup>>
357356
// CHECK: linalg.generic {doc = "", indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"], library_call = ""} ins(%[[VAL_24]], %[[VAL_43]] : memref<?x?xf32, strided<[?, 1], offset: ?>, #gpu.address_space<workgroup>>, memref<?x?xf32, strided<[?, 1], offset: ?>, #gpu.address_space<workgroup>>) outs(%[[VAL_62]] : memref<?x?xf32, strided<[?, 1], offset: ?>, #gpu.address_space<workgroup>>) {
358357
// CHECK: ^bb0(%[[VAL_63:.*]]: f32, %[[VAL_64:.*]]: f32, %[[VAL_65:.*]]: f32):
359358
// CHECK: %[[VAL_66:.*]] = arith.addf %[[VAL_63]], %[[VAL_64]] : f32

0 commit comments

Comments
 (0)