7
7
// ===----------------------------------------------------------------------===//
8
8
9
9
#include " mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h"
10
+ #include " mlir/Analysis/DataLayoutAnalysis.h"
11
+ #include " mlir/Dialect/LLVMIR/FunctionCallUtils.h"
10
12
#include " mlir/Dialect/LLVMIR/LLVMDialect.h"
11
13
12
14
using namespace mlir ;
13
15
14
- Value AllocLikeOpLLVMLowering::createAligned (
16
+ namespace {
17
+ // TODO: Fix the LLVM utilities for looking up functions to take Operation*
18
+ // with SymbolTable trait instead of ModuleOp and make similar change here. This
19
+ // allows call sites to use getParentWithTrait<OpTrait::SymbolTable> instead
20
+ // of getParentOfType<ModuleOp> to pass down the operation.
21
+ LLVM::LLVMFuncOp getNotalignedAllocFn (LLVMTypeConverter *typeConverter,
22
+ ModuleOp module , Type indexType) {
23
+ bool useGenericFn = typeConverter->getOptions ().useGenericFunctions ;
24
+
25
+ if (useGenericFn)
26
+ return LLVM::lookupOrCreateGenericAllocFn (module , indexType);
27
+
28
+ return LLVM::lookupOrCreateMallocFn (module , indexType);
29
+ }
30
+
31
+ LLVM::LLVMFuncOp getAlignedAllocFn (LLVMTypeConverter *typeConverter,
32
+ ModuleOp module , Type indexType) {
33
+ bool useGenericFn = typeConverter->getOptions ().useGenericFunctions ;
34
+
35
+ if (useGenericFn)
36
+ return LLVM::lookupOrCreateGenericAlignedAllocFn (module , indexType);
37
+
38
+ return LLVM::lookupOrCreateAlignedAllocFn (module , indexType);
39
+ }
40
+
41
+ } // end namespace
42
+
43
+ Value AllocationOpLLVMLowering::createAligned (
15
44
ConversionPatternRewriter &rewriter, Location loc, Value input,
16
45
Value alignment) {
17
46
Value one = createIndexAttrConstant (rewriter, loc, alignment.getType (), 1 );
@@ -21,6 +50,88 @@ Value AllocLikeOpLLVMLowering::createAligned(
21
50
return rewriter.create <LLVM::SubOp>(loc, bumped, mod);
22
51
}
23
52
53
+ std::tuple<Value, Value> AllocationOpLLVMLowering::allocateBufferManuallyAlign (
54
+ ConversionPatternRewriter &rewriter, Location loc, Value sizeBytes,
55
+ Operation *op, Value alignment) const {
56
+ if (alignment) {
57
+ // Adjust the allocation size to consider alignment.
58
+ sizeBytes = rewriter.create <LLVM::AddOp>(loc, sizeBytes, alignment);
59
+ }
60
+
61
+ MemRefType memRefType = getMemRefResultType (op);
62
+ // Allocate the underlying buffer.
63
+ Type elementPtrType = this ->getElementPtrType (memRefType);
64
+ LLVM::LLVMFuncOp allocFuncOp = getNotalignedAllocFn (
65
+ getTypeConverter (), op->getParentOfType <ModuleOp>(), getIndexType ());
66
+ auto results = rewriter.create <LLVM::CallOp>(loc, allocFuncOp, sizeBytes);
67
+ Value allocatedPtr = rewriter.create <LLVM::BitcastOp>(loc, elementPtrType,
68
+ results.getResult ());
69
+
70
+ Value alignedPtr = allocatedPtr;
71
+ if (alignment) {
72
+ // Compute the aligned pointer.
73
+ Value allocatedInt =
74
+ rewriter.create <LLVM::PtrToIntOp>(loc, getIndexType (), allocatedPtr);
75
+ Value alignmentInt = createAligned (rewriter, loc, allocatedInt, alignment);
76
+ alignedPtr =
77
+ rewriter.create <LLVM::IntToPtrOp>(loc, elementPtrType, alignmentInt);
78
+ }
79
+
80
+ return std::make_tuple (allocatedPtr, alignedPtr);
81
+ }
82
+
83
+ unsigned AllocationOpLLVMLowering::getMemRefEltSizeInBytes (
84
+ MemRefType memRefType, Operation *op,
85
+ const DataLayout *defaultLayout) const {
86
+ const DataLayout *layout = defaultLayout;
87
+ if (const DataLayoutAnalysis *analysis =
88
+ getTypeConverter ()->getDataLayoutAnalysis ()) {
89
+ layout = &analysis->getAbove (op);
90
+ }
91
+ Type elementType = memRefType.getElementType ();
92
+ if (auto memRefElementType = elementType.dyn_cast <MemRefType>())
93
+ return getTypeConverter ()->getMemRefDescriptorSize (memRefElementType,
94
+ *layout);
95
+ if (auto memRefElementType = elementType.dyn_cast <UnrankedMemRefType>())
96
+ return getTypeConverter ()->getUnrankedMemRefDescriptorSize (
97
+ memRefElementType, *layout);
98
+ return layout->getTypeSize (elementType);
99
+ }
100
+
101
+ bool AllocationOpLLVMLowering::isMemRefSizeMultipleOf (
102
+ MemRefType type, uint64_t factor, Operation *op,
103
+ const DataLayout *defaultLayout) const {
104
+ uint64_t sizeDivisor = getMemRefEltSizeInBytes (type, op, defaultLayout);
105
+ for (unsigned i = 0 , e = type.getRank (); i < e; i++) {
106
+ if (ShapedType::isDynamic (type.getDimSize (i)))
107
+ continue ;
108
+ sizeDivisor = sizeDivisor * type.getDimSize (i);
109
+ }
110
+ return sizeDivisor % factor == 0 ;
111
+ }
112
+
113
+ Value AllocationOpLLVMLowering::allocateBufferAutoAlign (
114
+ ConversionPatternRewriter &rewriter, Location loc, Value sizeBytes,
115
+ Operation *op, const DataLayout *defaultLayout, int64_t alignment) const {
116
+ Value allocAlignment = createIndexConstant (rewriter, loc, alignment);
117
+
118
+ MemRefType memRefType = getMemRefResultType (op);
119
+ // Function aligned_alloc requires size to be a multiple of alignment; we pad
120
+ // the size to the next multiple if necessary.
121
+ if (!isMemRefSizeMultipleOf (memRefType, alignment, op, defaultLayout))
122
+ sizeBytes = createAligned (rewriter, loc, sizeBytes, allocAlignment);
123
+
124
+ Type elementPtrType = this ->getElementPtrType (memRefType);
125
+ LLVM::LLVMFuncOp allocFuncOp = getAlignedAllocFn (
126
+ getTypeConverter (), op->getParentOfType <ModuleOp>(), getIndexType ());
127
+ auto results = rewriter.create <LLVM::CallOp>(
128
+ loc, allocFuncOp, ValueRange ({allocAlignment, sizeBytes}));
129
+ Value allocatedPtr = rewriter.create <LLVM::BitcastOp>(loc, elementPtrType,
130
+ results.getResult ());
131
+
132
+ return allocatedPtr;
133
+ }
134
+
24
135
LogicalResult AllocLikeOpLLVMLowering::matchAndRewrite (
25
136
Operation *op, ArrayRef<Value> operands,
26
137
ConversionPatternRewriter &rewriter) const {
0 commit comments