
Commit ffb9bbf

[mlir][MemRef] Changed AssumeAlignment into a Pure ViewLikeOp (#139521)
Made AssumeAlignment a ViewLikeOp that returns a new SSA memref equal to its memref argument, and gave it the Pure trait. This gives the op a defined memory effect that matches what it does in practice, and lets it interact cleanly with optimizations: they will now only remove it when its result is unused.
1 parent a0a2a1e commit ffb9bbf
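
For illustration, a minimal before/after sketch of what this change means for IR using the op (value names are hypothetical):

    // Before this commit: no result; later uses kept referring to %buf.
    memref.assume_alignment %buf, 64 : memref<128xf32>
    %v = memref.load %buf[%i] : memref<128xf32>

    // After this commit: the op returns an aligned view to be used instead.
    %aligned = memref.assume_alignment %buf, 64 : memref<128xf32>
    %v = memref.load %aligned[%i] : memref<128xf32>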

File tree: 11 files changed, +87 -41 lines changed


mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td

Lines changed: 23 additions & 8 deletions
@@ -142,22 +142,37 @@ class AllocLikeOp<string mnemonic,
 // AssumeAlignmentOp
 //===----------------------------------------------------------------------===//

-def AssumeAlignmentOp : MemRef_Op<"assume_alignment"> {
+def AssumeAlignmentOp : MemRef_Op<"assume_alignment", [
+      DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+      Pure,
+      ViewLikeOpInterface,
+      SameOperandsAndResultType
+    ]> {
   let summary =
-      "assertion that gives alignment information to the input memref";
+      "assumption that gives alignment information to the input memref";
   let description = [{
-    The `assume_alignment` operation takes a memref and an integer of alignment
-    value, and internally annotates the buffer with the given alignment. If
-    the buffer isn't aligned to the given alignment, the behavior is undefined.
+    The `assume_alignment` operation takes a memref and an integer alignment
+    value. It returns a new SSA value of the same memref type, but associated
+    with the assumption that the underlying buffer is aligned to the given
+    alignment.

-    This operation doesn't affect the semantics of a correct program. It's for
-    optimization only, and the optimization is best-effort.
+    If the buffer isn't aligned to the given alignment, its result is poison.
+    This operation doesn't affect the semantics of a program where the
+    alignment assumption holds true. It is intended for optimization purposes,
+    allowing the compiler to generate more efficient code based on the
+    alignment assumption. The optimization is best-effort.
   }];
   let arguments = (ins AnyMemRef:$memref,
                        ConfinedAttr<I32Attr, [IntPositive]>:$alignment);
-  let results = (outs);
+  let results = (outs AnyMemRef:$result);

   let assemblyFormat = "$memref `,` $alignment attr-dict `:` type($memref)";
+  let extraClassDeclaration = [{
+    MemRefType getType() { return ::llvm::cast<MemRefType>(getResult().getType()); }
+
+    Value getViewSource() { return getMemref(); }
+  }];
+
   let hasVerifier = 1;
 }
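
A consequence of the Pure trait worth noting: an assume_alignment whose result is unused is trivially dead, so canonicalization and DCE may now erase it. A minimal sketch (hypothetical IR):

    %0 = memref.assume_alignment %arg0, 16 : memref<4x4xf16>
    // If %0 has no uses, the op carries no live alignment fact and may be removed.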

mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp

Lines changed: 1 addition & 2 deletions
@@ -432,8 +432,7 @@ struct AssumeAlignmentOpLowering
         createIndexAttrConstant(rewriter, loc, getIndexType(), alignment);
     rewriter.create<LLVM::AssumeOp>(loc, trueCond, LLVM::AssumeAlignTag(), ptr,
                                     alignmentConst);
-
-    rewriter.eraseOp(op);
+    rewriter.replaceOp(op, memref);
     return success();
   }
 };
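
The lowering still emits the alignment assumption, but now also replaces the op's result with the original memref descriptor instead of just erasing the op. As a rough sketch of the emitted LLVM-dialect IR (assumed syntax; the exact assembly of llvm.intr.assume with an "align" operand bundle may differ, and %ptr/%align are illustrative):

    %true = llvm.mlir.constant(true) : i1
    llvm.intr.assume %true ["align"(%ptr, %align : !llvm.ptr, i64)] : i1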

mlir/lib/Dialect/GPU/Transforms/EliminateBarriers.cpp

Lines changed: 0 additions & 11 deletions
@@ -44,13 +44,6 @@ using namespace mlir::gpu;
 // The functions below provide interface-like verification, but are too specific
 // to barrier elimination to become interfaces.

-/// Implement the MemoryEffectsOpInterface in the suitable way.
-static bool isKnownNoEffectsOpWithoutInterface(Operation *op) {
-  // memref::AssumeAlignment is conceptually pure, but marking it as such would
-  // make DCE immediately remove it.
-  return isa<memref::AssumeAlignmentOp>(op);
-}
-
 /// Returns `true` if the op defines the parallel region that is subject to
 /// barrier synchronization.
 static bool isParallelRegionBoundary(Operation *op) {
@@ -101,10 +94,6 @@ collectEffects(Operation *op,
   if (ignoreBarriers && isa<BarrierOp>(op))
     return true;

-  // Skip over ops that we know have no effects.
-  if (isKnownNoEffectsOpWithoutInterface(op))
-    return true;
-
   // Collect effect instances of the operation. Note that the implementation of
   // getEffects erases all effect instances that have a type other than the
   // template parameter, so we collect them first in a local buffer and then

mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp

Lines changed: 5 additions & 0 deletions
@@ -527,6 +527,11 @@ LogicalResult AssumeAlignmentOp::verify() {
   return success();
 }

+void AssumeAlignmentOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  setNameFn(getResult(), "assume_align");
+}
+
 //===----------------------------------------------------------------------===//
 // CastOp
 //===----------------------------------------------------------------------===//
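
With this hook, pretty-printed IR gives the op's result a descriptive name; e.g. (a sketch, the %alloc operand name is illustrative):

    %assume_align = memref.assume_alignment %alloc, 64 : memref<128xi8>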

mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp

Lines changed: 1 addition & 1 deletion
@@ -229,7 +229,7 @@ struct ConvertMemRefAssumeAlignment final
     }

     rewriter.replaceOpWithNewOp<memref::AssumeAlignmentOp>(
-        op, adaptor.getMemref(), adaptor.getAlignmentAttr());
+        op, newTy, adaptor.getMemref(), adaptor.getAlignmentAttr());
     return success();
   }
 };

mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp

Lines changed: 31 additions & 0 deletions
@@ -919,6 +919,35 @@ struct ExtractStridedMetadataOpGetGlobalFolder
   }
 };

+/// Pattern to replace `extract_strided_metadata(assume_alignment)`
+///
+/// With
+/// \verbatim
+/// extract_strided_metadata(memref)
+/// \endverbatim
+///
+/// Since `assume_alignment` is a view-like op that does not modify the
+/// underlying buffer, offset, sizes, or strides, extracting strided metadata
+/// from its result is equivalent to extracting it from its source. This
+/// canonicalization removes the unnecessary indirection.
+struct ExtractStridedMetadataOpAssumeAlignmentFolder
+    : public OpRewritePattern<memref::ExtractStridedMetadataOp> {
+public:
+  using OpRewritePattern<memref::ExtractStridedMetadataOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
+                                PatternRewriter &rewriter) const override {
+    auto assumeAlignmentOp =
+        op.getSource().getDefiningOp<memref::AssumeAlignmentOp>();
+    if (!assumeAlignmentOp)
+      return failure();
+
+    rewriter.replaceOpWithNewOp<memref::ExtractStridedMetadataOp>(
+        op, assumeAlignmentOp.getViewSource());
+    return success();
+  }
+};
+
 /// Rewrite memref.extract_aligned_pointer_as_index of a ViewLikeOp to the
 /// source of the ViewLikeOp.
 class RewriteExtractAlignedPointerAsIndexOfViewLikeOp
@@ -1185,6 +1214,7 @@ void memref::populateExpandStridedMetadataPatterns(
       ExtractStridedMetadataOpSubviewFolder,
       ExtractStridedMetadataOpCastFolder,
       ExtractStridedMetadataOpMemorySpaceCastFolder,
+      ExtractStridedMetadataOpAssumeAlignmentFolder,
       ExtractStridedMetadataOpExtractStridedMetadataFolder>(
       patterns.getContext());
 }
@@ -1201,6 +1231,7 @@ void memref::populateResolveExtractStridedMetadataPatterns(
       ExtractStridedMetadataOpReinterpretCastFolder,
       ExtractStridedMetadataOpCastFolder,
       ExtractStridedMetadataOpMemorySpaceCastFolder,
+      ExtractStridedMetadataOpAssumeAlignmentFolder,
       ExtractStridedMetadataOpExtractStridedMetadataFolder>(
       patterns.getContext());
 }
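
In IR terms, the new folder performs roughly the following rewrite (a sketch with hypothetical value names):

    %aligned = memref.assume_alignment %src, 64 : memref<?xf32>
    %base, %off, %sz, %str = memref.extract_strided_metadata %aligned
        : memref<?xf32> -> memref<f32>, index, index, index

becomes

    %base, %off, %sz, %str = memref.extract_strided_metadata %src
        : memref<?xf32> -> memref<f32>, index, index, index

with %aligned left in place for any other users (or for DCE if there are none).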

mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir

Lines changed: 2 additions & 2 deletions
@@ -683,7 +683,7 @@ func.func @load_and_assume(
     %arg0: memref<?x?xf32, strided<[?, ?], offset: ?>>,
     %i0: index, %i1: index)
   -> f32 {
-  memref.assume_alignment %arg0, 16 : memref<?x?xf32, strided<[?, ?], offset: ?>>
-  %2 = memref.load %arg0[%i0, %i1] : memref<?x?xf32, strided<[?, ?], offset: ?>>
+  %arg0_align = memref.assume_alignment %arg0, 16 : memref<?x?xf32, strided<[?, ?], offset: ?>>
+  %2 = memref.load %arg0_align[%i0, %i1] : memref<?x?xf32, strided<[?, ?], offset: ?>>
   func.return %2 : f32
 }

mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/misc-other.mlir

Lines changed: 7 additions & 0 deletions
@@ -10,4 +10,11 @@ func.func @func_with_assert(%arg0: index, %arg1: index) {
   %0 = arith.cmpi slt, %arg0, %arg1 : index
   cf.assert %0, "%arg0 must be less than %arg1"
   return
+}
+
+// CHECK-LABEL: func @func_with_assume_alignment(
+//       CHECK: %0 = memref.assume_alignment %arg0, 64 : memref<128xi8>
+func.func @func_with_assume_alignment(%arg0: memref<128xi8>) {
+  %0 = memref.assume_alignment %arg0, 64 : memref<128xi8>
+  return
 }

mlir/test/Dialect/MemRef/emulate-narrow-type.mlir

Lines changed: 14 additions & 14 deletions
@@ -63,8 +63,8 @@ func.func @memref_load_i4(%arg0: index) -> i4 {

 func.func @memref_load_i4_rank2(%arg0: index, %arg1: index) -> i4 {
   %0 = memref.alloc() : memref<3x125xi4>
-  memref.assume_alignment %0, 64 : memref<3x125xi4>
-  %1 = memref.load %0[%arg0,%arg1] : memref<3x125xi4>
+  %align0 = memref.assume_alignment %0, 64 : memref<3x125xi4>
+  %1 = memref.load %align0[%arg0,%arg1] : memref<3x125xi4>
   return %1 : i4
 }
 // CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * 125 + s1) floordiv 2)>
@@ -73,9 +73,9 @@ func.func @memref_load_i4_rank2(%arg0: index, %arg1: index) -> i4 {
 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: index
 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
 // CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<188xi8>
-// CHECK: memref.assume_alignment %[[ALLOC]], 64 : memref<188xi8>
+// CHECK: %[[ASSUME:.+]] = memref.assume_alignment %[[ALLOC]], 64 : memref<188xi8>
 // CHECK: %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]]
-// CHECK: %[[LOAD:.+]] = memref.load %[[ALLOC]][%[[INDEX]]]
+// CHECK: %[[LOAD:.+]] = memref.load %[[ASSUME]][%[[INDEX]]]
 // CHECK: %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]], %[[ARG1]]]
 // CHECK: %[[CAST:.+]] = arith.index_cast %[[BITOFFSET]] : index to i8
 // CHECK: %[[SHIFTRT:.+]] = arith.shrsi %[[LOAD]], %[[CAST]]
@@ -88,9 +88,9 @@ func.func @memref_load_i4_rank2(%arg0: index, %arg1: index) -> i4 {
 // CHECK32-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: index
 // CHECK32-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
 // CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<47xi32>
-// CHECK32: memref.assume_alignment %[[ALLOC]], 64 : memref<47xi32>
+// CHECK32: %[[ASSUME:.+]] = memref.assume_alignment %[[ALLOC]], 64 : memref<47xi32>
 // CHECK32: %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]]
-// CHECK32: %[[LOAD:.+]] = memref.load %[[ALLOC]][%[[INDEX]]]
+// CHECK32: %[[LOAD:.+]] = memref.load %[[ASSUME]][%[[INDEX]]]
 // CHECK32: %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]], %[[ARG1]]]
 // CHECK32: %[[CAST:.+]] = arith.index_cast %[[BITOFFSET]] : index to i32
 // CHECK32: %[[SHIFTRT:.+]] = arith.shrsi %[[LOAD]], %[[CAST]]
@@ -350,16 +350,16 @@ func.func @memref_store_i4(%arg0: index, %arg1: i4) -> () {

 func.func @memref_store_i4_rank2(%arg0: index, %arg1: index, %arg2: i4) -> () {
   %0 = memref.alloc() : memref<3x125xi4>
-  memref.assume_alignment %0, 64 : memref<3x125xi4>
-  memref.store %arg2, %0[%arg0,%arg1] : memref<3x125xi4>
+  %align0 = memref.assume_alignment %0, 64 : memref<3x125xi4>
+  memref.store %arg2, %align0[%arg0,%arg1] : memref<3x125xi4>
   return
 }
 // CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * 125 + s1) floordiv 2)>
 // CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s0 * 500 + s1 * 4 - ((s0 * 125 + s1) floordiv 2) * 8)>
 // CHECK: func @memref_store_i4_rank2(
 // CHECK-SAME: %[[ARG0:.+]]: index, %[[ARG1:.+]]: index, %[[ARG2:.+]]: i4
 // CHECK-DAG: %[[ALLOC:.+]] = memref.alloc() : memref<188xi8>
-// CHECK-DAG: memref.assume_alignment %[[ALLOC]], 64 : memref<188xi8>
+// CHECK-DAG: %[[ASSUME:.+]] = memref.assume_alignment %[[ALLOC]], 64 : memref<188xi8>
 // CHECK-DAG: %[[EXTUI:.+]] = arith.extui %[[ARG2]] : i4 to i8
 // CHECK-DAG: %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]]
 // CHECK-DAG: %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]], %[[ARG1]]]
@@ -369,16 +369,16 @@ func.func @memref_store_i4_rank2(%arg0: index, %arg1: index, %arg2: i4) -> () {
 // CHECK-DAG: %[[CST_NEG_ONE:.+]] = arith.constant -1 : i8
 // CHECK-DAG: %[[MASK:.+]] = arith.xori %[[MASK_SHIFTED]], %[[CST_NEG_ONE]] : i8
 // CHECK-DAG: %[[SHIFTED_VAL:.+]] = arith.shli %[[EXTUI]], %[[BITOFFSET_I8]] : i8
-// CHECK: %[[CLEAR_RMW:.+]] = memref.atomic_rmw andi %[[MASK]], %[[ALLOC]][%[[INDEX]]] : (i8, memref<188xi8>) -> i8
-// CHECK: %[[WRITE_RMW:.+]] = memref.atomic_rmw ori %[[SHIFTED_VAL]], %[[ALLOC]][%[[INDEX]]] : (i8, memref<188xi8>) -> i8
+// CHECK: %[[CLEAR_RMW:.+]] = memref.atomic_rmw andi %[[MASK]], %[[ASSUME]][%[[INDEX]]] : (i8, memref<188xi8>) -> i8
+// CHECK: %[[WRITE_RMW:.+]] = memref.atomic_rmw ori %[[SHIFTED_VAL]], %[[ASSUME]][%[[INDEX]]] : (i8, memref<188xi8>) -> i8
 // CHECK: return

 // CHECK32-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * 125 + s1) floordiv 8)>
 // CHECK32-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s0 * 500 + s1 * 4 - ((s0 * 125 + s1) floordiv 8) * 32)>
 // CHECK32: func @memref_store_i4_rank2(
 // CHECK32-SAME: %[[ARG0:.+]]: index, %[[ARG1:.+]]: index, %[[ARG2:.+]]: i4
 // CHECK32-DAG: %[[ALLOC:.+]] = memref.alloc() : memref<47xi32>
-// CHECK32-DAG: memref.assume_alignment %[[ALLOC]], 64 : memref<47xi32>
+// CHECK32-DAG: %[[ASSUME:.+]] = memref.assume_alignment %[[ALLOC]], 64 : memref<47xi32>
 // CHECK32-DAG: %[[EXTUI:.+]] = arith.extui %[[ARG2]] : i4 to i32
 // CHECK32-DAG: %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]]
 // CHECK32-DAG: %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]], %[[ARG1]]]
@@ -388,8 +388,8 @@ func.func @memref_store_i4_rank2(%arg0: index, %arg1: index, %arg2: i4) -> () {
 // CHECK32-DAG: %[[CST_NEG_ONE:.+]] = arith.constant -1 : i32
 // CHECK32-DAG: %[[MASK:.+]] = arith.xori %[[MASK_SHIFTED]], %[[CST_NEG_ONE]] : i32
 // CHECK32-DAG: %[[SHIFTED_VAL:.+]] = arith.shli %[[EXTUI]], %[[BITOFFSET_I32]] : i32
-// CHECK32: %[[CLEAR_RMW:.+]] = memref.atomic_rmw andi %[[MASK]], %[[ALLOC]][%[[INDEX]]] : (i32, memref<47xi32>) -> i32
-// CHECK32: %[[WRITE_RMW:.+]] = memref.atomic_rmw ori %[[SHIFTED_VAL]], %[[ALLOC]][%[[INDEX]]] : (i32, memref<47xi32>) -> i32
+// CHECK32: %[[CLEAR_RMW:.+]] = memref.atomic_rmw andi %[[MASK]], %[[ASSUME]][%[[INDEX]]] : (i32, memref<47xi32>) -> i32
+// CHECK32: %[[WRITE_RMW:.+]] = memref.atomic_rmw ori %[[SHIFTED_VAL]], %[[ASSUME]][%[[INDEX]]] : (i32, memref<47xi32>) -> i32
 // CHECK32: return

 // -----

mlir/test/Dialect/MemRef/invalid.mlir

Lines changed: 2 additions & 2 deletions
@@ -878,7 +878,7 @@ func.func @invalid_memref_cast() {
 // alignment is not power of 2.
 func.func @assume_alignment(%0: memref<4x4xf16>) {
   // expected-error@+1 {{alignment must be power of 2}}
-  memref.assume_alignment %0, 12 : memref<4x4xf16>
+  %1 = memref.assume_alignment %0, 12 : memref<4x4xf16>
   return
 }

@@ -887,7 +887,7 @@ func.func @assume_alignment(%0: memref<4x4xf16>) {
 // 0 alignment value.
 func.func @assume_alignment(%0: memref<4x4xf16>) {
   // expected-error@+1 {{attribute 'alignment' failed to satisfy constraint: 32-bit signless integer attribute whose value is positive}}
-  memref.assume_alignment %0, 0 : memref<4x4xf16>
+  %1 = memref.assume_alignment %0, 0 : memref<4x4xf16>
   return
 }

mlir/test/Dialect/MemRef/ops.mlir

Lines changed: 1 addition & 1 deletion
@@ -284,7 +284,7 @@ func.func @memref_view(%arg0 : index, %arg1 : index, %arg2 : index) {
 // CHECK-SAME: %[[MEMREF:.*]]: memref<4x4xf16>
 func.func @assume_alignment(%0: memref<4x4xf16>) {
   // CHECK: memref.assume_alignment %[[MEMREF]], 16 : memref<4x4xf16>
-  memref.assume_alignment %0, 16 : memref<4x4xf16>
+  %1 = memref.assume_alignment %0, 16 : memref<4x4xf16>
   return
 }
