Skip to content

Commit 9f13b93

Browse files
committed
[mlir][memref] Add realloc op.
Add memref.realloc and canonicalization of the op. Add conversion patterns for lowering the op to LLVM using unaligned alloc or aligned alloc based on the conversion option. Add filecheck tests for parsing and converting the op. Add an integration test. Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D133424
1 parent e664dea commit 9f13b93

File tree

11 files changed

+850
-153
lines changed

11 files changed

+850
-153
lines changed

mlir/include/mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h

Lines changed: 84 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -13,34 +13,107 @@
1313

1414
namespace mlir {
1515

16-
/// Lowering for AllocOp and AllocaOp.
17-
struct AllocLikeOpLLVMLowering : public ConvertToLLVMPattern {
16+
/// Lowering for memory allocation ops.
17+
struct AllocationOpLLVMLowering : public ConvertToLLVMPattern {
1818
using ConvertToLLVMPattern::createIndexConstant;
1919
using ConvertToLLVMPattern::getIndexType;
2020
using ConvertToLLVMPattern::getVoidPtrType;
2121

22-
explicit AllocLikeOpLLVMLowering(StringRef opName,
23-
LLVMTypeConverter &converter)
22+
explicit AllocationOpLLVMLowering(StringRef opName,
23+
LLVMTypeConverter &converter)
2424
: ConvertToLLVMPattern(opName, &converter.getContext(), converter) {}
2525

2626
protected:
27-
// Returns 'input' aligned up to 'alignment'. Computes
28-
// bumped = input + alignement - 1
29-
// aligned = bumped - bumped % alignment
27+
/// Computes the aligned value for 'input' as follows:
28+
/// bumped = input + alignment - 1
29+
/// aligned = bumped - bumped % alignment
3030
static Value createAligned(ConversionPatternRewriter &rewriter, Location loc,
3131
Value input, Value alignment);
3232

33+
static MemRefType getMemRefResultType(Operation *op) {
34+
return op->getResult(0).getType().cast<MemRefType>();
35+
}
36+
37+
/// Computes the alignment for the given memory allocation op.
38+
template <typename OpType>
39+
Value getAlignment(ConversionPatternRewriter &rewriter, Location loc,
40+
OpType op) const {
41+
MemRefType memRefType = op.getType();
42+
Value alignment;
43+
if (auto alignmentAttr = op.getAlignment()) {
44+
alignment = createIndexConstant(rewriter, loc, *alignmentAttr);
45+
} else if (!memRefType.getElementType().isSignlessIntOrIndexOrFloat()) {
46+
// In the case where no alignment is specified, we may want to override
47+
// `malloc`'s behavior. `malloc` typically aligns at the size of the
48+
// biggest scalar on a target HW. For non-scalars, use the natural
49+
// alignment of the LLVM type given by the LLVM DataLayout.
50+
alignment = getSizeInBytes(loc, memRefType.getElementType(), rewriter);
51+
}
52+
return alignment;
53+
}
54+
55+
/// Computes the alignment for aligned_alloc used to allocate the buffer for
56+
/// the memory allocation op.
57+
///
58+
/// Aligned_alloc requires the alignment to be a power of two, and the
59+
/// allocation size to be a multiple of the alignment.
60+
template <typename OpType>
61+
int64_t alignedAllocationGetAlignment(ConversionPatternRewriter &rewriter,
62+
Location loc, OpType op,
63+
const DataLayout *defaultLayout) const {
64+
if (Optional<uint64_t> alignment = op.getAlignment())
65+
return *alignment;
66+
67+
// Whenever we don't have alignment set, we will use an alignment
68+
// consistent with the element type; since the alignment has to be a
69+
// power of two, we will bump to the next power of two if it isn't.
70+
unsigned eltSizeBytes =
71+
getMemRefEltSizeInBytes(op.getType(), op, defaultLayout);
72+
return std::max(kMinAlignedAllocAlignment,
73+
llvm::PowerOf2Ceil(eltSizeBytes));
74+
}
75+
76+
/// Allocates a memory buffer using an allocation method that doesn't
77+
/// guarantee alignment. Returns the pointer and its aligned value.
78+
std::tuple<Value, Value>
79+
allocateBufferManuallyAlign(ConversionPatternRewriter &rewriter, Location loc,
80+
Value sizeBytes, Operation *op,
81+
Value alignment) const;
82+
83+
/// Allocates a memory buffer using an aligned allocation method.
84+
Value allocateBufferAutoAlign(ConversionPatternRewriter &rewriter,
85+
Location loc, Value sizeBytes, Operation *op,
86+
const DataLayout *defaultLayout,
87+
int64_t alignment) const;
88+
89+
private:
90+
/// Computes the byte size for the MemRef element type.
91+
unsigned getMemRefEltSizeInBytes(MemRefType memRefType, Operation *op,
92+
const DataLayout *defaultLayout) const;
93+
94+
/// Returns true if the memref size in bytes is known to be a multiple of
95+
/// factor.
96+
bool isMemRefSizeMultipleOf(MemRefType type, uint64_t factor, Operation *op,
97+
const DataLayout *defaultLayout) const;
98+
99+
/// The minimum alignment to use with aligned_alloc (has to be a power of 2).
100+
static constexpr uint64_t kMinAlignedAllocAlignment = 16UL;
101+
};
102+
103+
/// Lowering for AllocOp and AllocaOp.
104+
struct AllocLikeOpLLVMLowering : public AllocationOpLLVMLowering {
105+
explicit AllocLikeOpLLVMLowering(StringRef opName,
106+
LLVMTypeConverter &converter)
107+
: AllocationOpLLVMLowering(opName, converter) {}
108+
109+
protected:
33110
/// Allocates the underlying buffer. Returns the allocated pointer and the
34111
/// aligned pointer.
35112
virtual std::tuple<Value, Value>
36113
allocateBuffer(ConversionPatternRewriter &rewriter, Location loc,
37114
Value sizeBytes, Operation *op) const = 0;
38115

39116
private:
40-
static MemRefType getMemRefResultType(Operation *op) {
41-
return op->getResult(0).getType().cast<MemRefType>();
42-
}
43-
44117
// An `alloc` is converted into a definition of a memref descriptor value and
45118
// a call to `malloc` to allocate the underlying data buffer. The memref
46119
// descriptor is of the LLVM structure type where:

mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,99 @@ def MemRef_AllocOp : AllocLikeOp<"alloc", DefaultResource, []> {
178178
let hasVerifier = 1;
179179
}
180180

181+
//===----------------------------------------------------------------------===//
182+
// ReallocOp
183+
//===----------------------------------------------------------------------===//
184+
185+
186+
def MemRef_ReallocOp : MemRef_Op<"realloc"> {
187+
let summary = "memory reallocation operation";
188+
let description = [{
189+
The `realloc` operation changes the size of a memory region. The memory
190+
region is specified by a 1D source memref and the size of the new memory
191+
region is specified by a 1D result memref type and an optional dynamic Value
192+
of `Index` type. The source and the result memref must be in the same memory
193+
space and have the same element type.
194+
195+
The operation may move the memory region to a new location. In this case,
196+
the content of the memory block is preserved up to the lesser of the new
197+
and old sizes. If the new size is larger, the value of the extended memory
198+
is undefined. This is consistent with the ISO C realloc.
199+
200+
The operation returns an SSA value for the memref.
201+
202+
Example:
203+
204+
```mlir
205+
%0 = memref.realloc %src : memref<64xf32> to memref<124xf32>
206+
```
207+
208+
The source memref may have a dynamic shape, in which case, the compiler will
209+
generate code to extract its size from the runtime data structure for the
210+
memref.
211+
212+
```mlir
213+
%1 = memref.realloc %src : memref<?xf32> to memref<124xf32>
214+
```
215+
216+
If the result memref has a dynamic shape, a result dimension operand is
217+
needed to specify its dynamic dimension. In the example below, the SSA value
218+
'%d' specifies the unknown dimension of the result memref.
219+
220+
```mlir
221+
%2 = memref.realloc %src(%d) : memref<?xf32> to memref<?xf32>
222+
```
223+
224+
An optional `alignment` attribute may be specified to ensure that the
225+
region of memory that will be indexed is aligned at the specified byte
226+
boundary. This is consistent with the fact that memref.alloc supports such
227+
an optional alignment attribute. Note that in the ISO C standard, neither malloc
228+
nor realloc supports alignment, though there is aligned_alloc but not
229+
aligned_realloc.
230+
231+
```mlir
232+
%3 = memref.realloc %src {alignment = 8} : memref<64xf32> to memref<124xf32>
233+
```
234+
235+
Referencing the memref through the old SSA value after realloc is undefined
236+
behavior.
237+
238+
```mlir
239+
%new = memref.realloc %old : memref<64xf32> to memref<124xf32>
240+
%4 = memref.load %new[%index] // ok
241+
%5 = memref.load %old[%index] // undefined behavior
242+
```
243+
}];
244+
245+
let arguments = (ins MemRefRankOf<[AnyType], [1]>:$source,
246+
Optional<Index>:$dynamicResultSize,
247+
ConfinedAttr<OptionalAttr<I64Attr>,
248+
[IntMinValue<0>]>:$alignment);
249+
250+
let results = (outs MemRefRankOf<[AnyType], [1]>);
251+
252+
let builders = [
253+
OpBuilder<(ins "MemRefType":$resultType,
254+
"Value":$source,
255+
CArg<"Value", "Value()">:$dynamicResultSize), [{
256+
return build($_builder, $_state, resultType, source, dynamicResultSize,
257+
IntegerAttr());
258+
}]>];
259+
260+
let extraClassDeclaration = [{
261+
/// The result of a realloc is always a memref.
262+
MemRefType getType() { return getResult().getType().cast<MemRefType>(); }
263+
}];
264+
265+
let assemblyFormat = [{
266+
$source (`(` $dynamicResultSize^ `)`)? attr-dict
267+
`:` type($source) `to` type(results)
268+
}];
269+
270+
let hasCanonicalizer = 1;
271+
let hasVerifier = 1;
272+
}
273+
181274
//===----------------------------------------------------------------------===//
182275
// AllocaOp
183276
//===----------------------------------------------------------------------===//

mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp

Lines changed: 112 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,40 @@
77
//===----------------------------------------------------------------------===//
88

99
#include "mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h"
10+
#include "mlir/Analysis/DataLayoutAnalysis.h"
11+
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
1012
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
1113

1214
using namespace mlir;
1315

14-
Value AllocLikeOpLLVMLowering::createAligned(
16+
namespace {
17+
// TODO: Fix the LLVM utilities for looking up functions to take Operation*
18+
// with SymbolTable trait instead of ModuleOp and make similar change here. This
19+
// allows call sites to use getParentWithTrait<OpTrait::SymbolTable> instead
20+
// of getParentOfType<ModuleOp> to pass down the operation.
21+
LLVM::LLVMFuncOp getNotalignedAllocFn(LLVMTypeConverter *typeConverter,
22+
ModuleOp module, Type indexType) {
23+
bool useGenericFn = typeConverter->getOptions().useGenericFunctions;
24+
25+
if (useGenericFn)
26+
return LLVM::lookupOrCreateGenericAllocFn(module, indexType);
27+
28+
return LLVM::lookupOrCreateMallocFn(module, indexType);
29+
}
30+
31+
LLVM::LLVMFuncOp getAlignedAllocFn(LLVMTypeConverter *typeConverter,
32+
ModuleOp module, Type indexType) {
33+
bool useGenericFn = typeConverter->getOptions().useGenericFunctions;
34+
35+
if (useGenericFn)
36+
return LLVM::lookupOrCreateGenericAlignedAllocFn(module, indexType);
37+
38+
return LLVM::lookupOrCreateAlignedAllocFn(module, indexType);
39+
}
40+
41+
} // end namespace
42+
43+
Value AllocationOpLLVMLowering::createAligned(
1544
ConversionPatternRewriter &rewriter, Location loc, Value input,
1645
Value alignment) {
1746
Value one = createIndexAttrConstant(rewriter, loc, alignment.getType(), 1);
@@ -21,6 +50,88 @@ Value AllocLikeOpLLVMLowering::createAligned(
2150
return rewriter.create<LLVM::SubOp>(loc, bumped, mod);
2251
}
2352

53+
std::tuple<Value, Value> AllocationOpLLVMLowering::allocateBufferManuallyAlign(
54+
ConversionPatternRewriter &rewriter, Location loc, Value sizeBytes,
55+
Operation *op, Value alignment) const {
56+
if (alignment) {
57+
// Adjust the allocation size to consider alignment.
58+
sizeBytes = rewriter.create<LLVM::AddOp>(loc, sizeBytes, alignment);
59+
}
60+
61+
MemRefType memRefType = getMemRefResultType(op);
62+
// Allocate the underlying buffer.
63+
Type elementPtrType = this->getElementPtrType(memRefType);
64+
LLVM::LLVMFuncOp allocFuncOp = getNotalignedAllocFn(
65+
getTypeConverter(), op->getParentOfType<ModuleOp>(), getIndexType());
66+
auto results = rewriter.create<LLVM::CallOp>(loc, allocFuncOp, sizeBytes);
67+
Value allocatedPtr = rewriter.create<LLVM::BitcastOp>(loc, elementPtrType,
68+
results.getResult());
69+
70+
Value alignedPtr = allocatedPtr;
71+
if (alignment) {
72+
// Compute the aligned pointer.
73+
Value allocatedInt =
74+
rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), allocatedPtr);
75+
Value alignmentInt = createAligned(rewriter, loc, allocatedInt, alignment);
76+
alignedPtr =
77+
rewriter.create<LLVM::IntToPtrOp>(loc, elementPtrType, alignmentInt);
78+
}
79+
80+
return std::make_tuple(allocatedPtr, alignedPtr);
81+
}
82+
83+
unsigned AllocationOpLLVMLowering::getMemRefEltSizeInBytes(
84+
MemRefType memRefType, Operation *op,
85+
const DataLayout *defaultLayout) const {
86+
const DataLayout *layout = defaultLayout;
87+
if (const DataLayoutAnalysis *analysis =
88+
getTypeConverter()->getDataLayoutAnalysis()) {
89+
layout = &analysis->getAbove(op);
90+
}
91+
Type elementType = memRefType.getElementType();
92+
if (auto memRefElementType = elementType.dyn_cast<MemRefType>())
93+
return getTypeConverter()->getMemRefDescriptorSize(memRefElementType,
94+
*layout);
95+
if (auto memRefElementType = elementType.dyn_cast<UnrankedMemRefType>())
96+
return getTypeConverter()->getUnrankedMemRefDescriptorSize(
97+
memRefElementType, *layout);
98+
return layout->getTypeSize(elementType);
99+
}
100+
101+
bool AllocationOpLLVMLowering::isMemRefSizeMultipleOf(
102+
MemRefType type, uint64_t factor, Operation *op,
103+
const DataLayout *defaultLayout) const {
104+
uint64_t sizeDivisor = getMemRefEltSizeInBytes(type, op, defaultLayout);
105+
for (unsigned i = 0, e = type.getRank(); i < e; i++) {
106+
if (ShapedType::isDynamic(type.getDimSize(i)))
107+
continue;
108+
sizeDivisor = sizeDivisor * type.getDimSize(i);
109+
}
110+
return sizeDivisor % factor == 0;
111+
}
112+
113+
Value AllocationOpLLVMLowering::allocateBufferAutoAlign(
114+
ConversionPatternRewriter &rewriter, Location loc, Value sizeBytes,
115+
Operation *op, const DataLayout *defaultLayout, int64_t alignment) const {
116+
Value allocAlignment = createIndexConstant(rewriter, loc, alignment);
117+
118+
MemRefType memRefType = getMemRefResultType(op);
119+
// Function aligned_alloc requires size to be a multiple of alignment; we pad
120+
// the size to the next multiple if necessary.
121+
if (!isMemRefSizeMultipleOf(memRefType, alignment, op, defaultLayout))
122+
sizeBytes = createAligned(rewriter, loc, sizeBytes, allocAlignment);
123+
124+
Type elementPtrType = this->getElementPtrType(memRefType);
125+
LLVM::LLVMFuncOp allocFuncOp = getAlignedAllocFn(
126+
getTypeConverter(), op->getParentOfType<ModuleOp>(), getIndexType());
127+
auto results = rewriter.create<LLVM::CallOp>(
128+
loc, allocFuncOp, ValueRange({allocAlignment, sizeBytes}));
129+
Value allocatedPtr = rewriter.create<LLVM::BitcastOp>(loc, elementPtrType,
130+
results.getResult());
131+
132+
return allocatedPtr;
133+
}
134+
24135
LogicalResult AllocLikeOpLLVMLowering::matchAndRewrite(
25136
Operation *op, ArrayRef<Value> operands,
26137
ConversionPatternRewriter &rewriter) const {

0 commit comments

Comments
 (0)