1 //===- Pattern.cpp - Conversion pattern to the LLVM dialect ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Conversion/LLVMCommon/Pattern.h"
10 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
11 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
12 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
13 #include "mlir/IR/AffineMap.h"
14
15 using namespace mlir;
16
17 //===----------------------------------------------------------------------===//
18 // ConvertToLLVMPattern
19 //===----------------------------------------------------------------------===//
20
ConvertToLLVMPattern(StringRef rootOpName,MLIRContext * context,LLVMTypeConverter & typeConverter,PatternBenefit benefit)21 ConvertToLLVMPattern::ConvertToLLVMPattern(StringRef rootOpName,
22 MLIRContext *context,
23 LLVMTypeConverter &typeConverter,
24 PatternBenefit benefit)
25 : ConversionPattern(typeConverter, rootOpName, benefit, context) {}
26
getTypeConverter() const27 LLVMTypeConverter *ConvertToLLVMPattern::getTypeConverter() const {
28 return static_cast<LLVMTypeConverter *>(
29 ConversionPattern::getTypeConverter());
30 }
31
getDialect() const32 LLVM::LLVMDialect &ConvertToLLVMPattern::getDialect() const {
33 return *getTypeConverter()->getDialect();
34 }
35
getIndexType() const36 Type ConvertToLLVMPattern::getIndexType() const {
37 return getTypeConverter()->getIndexType();
38 }
39
getIntPtrType(unsigned addressSpace) const40 Type ConvertToLLVMPattern::getIntPtrType(unsigned addressSpace) const {
41 return IntegerType::get(&getTypeConverter()->getContext(),
42 getTypeConverter()->getPointerBitwidth(addressSpace));
43 }
44
getVoidType() const45 Type ConvertToLLVMPattern::getVoidType() const {
46 return LLVM::LLVMVoidType::get(&getTypeConverter()->getContext());
47 }
48
getVoidPtrType() const49 Type ConvertToLLVMPattern::getVoidPtrType() const {
50 return LLVM::LLVMPointerType::get(
51 IntegerType::get(&getTypeConverter()->getContext(), 8));
52 }
53
createIndexAttrConstant(OpBuilder & builder,Location loc,Type resultType,int64_t value)54 Value ConvertToLLVMPattern::createIndexAttrConstant(OpBuilder &builder,
55 Location loc,
56 Type resultType,
57 int64_t value) {
58 return builder.create<LLVM::ConstantOp>(
59 loc, resultType, builder.getIntegerAttr(builder.getIndexType(), value));
60 }
61
createIndexConstant(ConversionPatternRewriter & builder,Location loc,uint64_t value) const62 Value ConvertToLLVMPattern::createIndexConstant(
63 ConversionPatternRewriter &builder, Location loc, uint64_t value) const {
64 return createIndexAttrConstant(builder, loc, getIndexType(), value);
65 }
66
getStridedElementPtr(Location loc,MemRefType type,Value memRefDesc,ValueRange indices,ConversionPatternRewriter & rewriter) const67 Value ConvertToLLVMPattern::getStridedElementPtr(
68 Location loc, MemRefType type, Value memRefDesc, ValueRange indices,
69 ConversionPatternRewriter &rewriter) const {
70
71 int64_t offset;
72 SmallVector<int64_t, 4> strides;
73 auto successStrides = getStridesAndOffset(type, strides, offset);
74 assert(succeeded(successStrides) && "unexpected non-strided memref");
75 (void)successStrides;
76
77 MemRefDescriptor memRefDescriptor(memRefDesc);
78 Value base = memRefDescriptor.alignedPtr(rewriter, loc);
79
80 Value index;
81 if (offset != 0) // Skip if offset is zero.
82 index = MemRefType::isDynamicStrideOrOffset(offset)
83 ? memRefDescriptor.offset(rewriter, loc)
84 : createIndexConstant(rewriter, loc, offset);
85
86 for (int i = 0, e = indices.size(); i < e; ++i) {
87 Value increment = indices[i];
88 if (strides[i] != 1) { // Skip if stride is 1.
89 Value stride = MemRefType::isDynamicStrideOrOffset(strides[i])
90 ? memRefDescriptor.stride(rewriter, loc, i)
91 : createIndexConstant(rewriter, loc, strides[i]);
92 increment = rewriter.create<LLVM::MulOp>(loc, increment, stride);
93 }
94 index =
95 index ? rewriter.create<LLVM::AddOp>(loc, index, increment) : increment;
96 }
97
98 Type elementPtrType = memRefDescriptor.getElementPtrType();
99 return index ? rewriter.create<LLVM::GEPOp>(loc, elementPtrType, base, index)
100 : base;
101 }
102
103 // Check if the MemRefType `type` is supported by the lowering. We currently
104 // only support memrefs with identity maps.
isConvertibleAndHasIdentityMaps(MemRefType type) const105 bool ConvertToLLVMPattern::isConvertibleAndHasIdentityMaps(
106 MemRefType type) const {
107 if (!typeConverter->convertType(type.getElementType()))
108 return false;
109 return type.getAffineMaps().empty() ||
110 llvm::all_of(type.getAffineMaps(),
111 [](AffineMap map) { return map.isIdentity(); });
112 }
113
getElementPtrType(MemRefType type) const114 Type ConvertToLLVMPattern::getElementPtrType(MemRefType type) const {
115 auto elementType = type.getElementType();
116 auto structElementType = typeConverter->convertType(elementType);
117 return LLVM::LLVMPointerType::get(structElementType,
118 type.getMemorySpaceAsInt());
119 }
120
getMemRefDescriptorSizes(Location loc,MemRefType memRefType,ValueRange dynamicSizes,ConversionPatternRewriter & rewriter,SmallVectorImpl<Value> & sizes,SmallVectorImpl<Value> & strides,Value & sizeBytes) const121 void ConvertToLLVMPattern::getMemRefDescriptorSizes(
122 Location loc, MemRefType memRefType, ValueRange dynamicSizes,
123 ConversionPatternRewriter &rewriter, SmallVectorImpl<Value> &sizes,
124 SmallVectorImpl<Value> &strides, Value &sizeBytes) const {
125 assert(isConvertibleAndHasIdentityMaps(memRefType) &&
126 "layout maps must have been normalized away");
127 assert(count(memRefType.getShape(), ShapedType::kDynamicSize) ==
128 static_cast<ssize_t>(dynamicSizes.size()) &&
129 "dynamicSizes size doesn't match dynamic sizes count in memref shape");
130
131 sizes.reserve(memRefType.getRank());
132 unsigned dynamicIndex = 0;
133 for (int64_t size : memRefType.getShape()) {
134 sizes.push_back(size == ShapedType::kDynamicSize
135 ? dynamicSizes[dynamicIndex++]
136 : createIndexConstant(rewriter, loc, size));
137 }
138
139 // Strides: iterate sizes in reverse order and multiply.
140 int64_t stride = 1;
141 Value runningStride = createIndexConstant(rewriter, loc, 1);
142 strides.resize(memRefType.getRank());
143 for (auto i = memRefType.getRank(); i-- > 0;) {
144 strides[i] = runningStride;
145
146 int64_t size = memRefType.getShape()[i];
147 if (size == 0)
148 continue;
149 bool useSizeAsStride = stride == 1;
150 if (size == ShapedType::kDynamicSize)
151 stride = ShapedType::kDynamicSize;
152 if (stride != ShapedType::kDynamicSize)
153 stride *= size;
154
155 if (useSizeAsStride)
156 runningStride = sizes[i];
157 else if (stride == ShapedType::kDynamicSize)
158 runningStride =
159 rewriter.create<LLVM::MulOp>(loc, runningStride, sizes[i]);
160 else
161 runningStride = createIndexConstant(rewriter, loc, stride);
162 }
163
164 // Buffer size in bytes.
165 Type elementPtrType = getElementPtrType(memRefType);
166 Value nullPtr = rewriter.create<LLVM::NullOp>(loc, elementPtrType);
167 Value gepPtr = rewriter.create<LLVM::GEPOp>(
168 loc, elementPtrType, ArrayRef<Value>{nullPtr, runningStride});
169 sizeBytes = rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gepPtr);
170 }
171
getSizeInBytes(Location loc,Type type,ConversionPatternRewriter & rewriter) const172 Value ConvertToLLVMPattern::getSizeInBytes(
173 Location loc, Type type, ConversionPatternRewriter &rewriter) const {
174 // Compute the size of an individual element. This emits the MLIR equivalent
175 // of the following sizeof(...) implementation in LLVM IR:
176 // %0 = getelementptr %elementType* null, %indexType 1
177 // %1 = ptrtoint %elementType* %0 to %indexType
178 // which is a common pattern of getting the size of a type in bytes.
179 auto convertedPtrType =
180 LLVM::LLVMPointerType::get(typeConverter->convertType(type));
181 auto nullPtr = rewriter.create<LLVM::NullOp>(loc, convertedPtrType);
182 auto gep = rewriter.create<LLVM::GEPOp>(
183 loc, convertedPtrType,
184 ArrayRef<Value>{nullPtr, createIndexConstant(rewriter, loc, 1)});
185 return rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gep);
186 }
187
getNumElements(Location loc,ArrayRef<Value> shape,ConversionPatternRewriter & rewriter) const188 Value ConvertToLLVMPattern::getNumElements(
189 Location loc, ArrayRef<Value> shape,
190 ConversionPatternRewriter &rewriter) const {
191 // Compute the total number of memref elements.
192 Value numElements =
193 shape.empty() ? createIndexConstant(rewriter, loc, 1) : shape.front();
194 for (unsigned i = 1, e = shape.size(); i < e; ++i)
195 numElements = rewriter.create<LLVM::MulOp>(loc, numElements, shape[i]);
196 return numElements;
197 }
198
199 /// Creates and populates the memref descriptor struct given all its fields.
createMemRefDescriptor(Location loc,MemRefType memRefType,Value allocatedPtr,Value alignedPtr,ArrayRef<Value> sizes,ArrayRef<Value> strides,ConversionPatternRewriter & rewriter) const200 MemRefDescriptor ConvertToLLVMPattern::createMemRefDescriptor(
201 Location loc, MemRefType memRefType, Value allocatedPtr, Value alignedPtr,
202 ArrayRef<Value> sizes, ArrayRef<Value> strides,
203 ConversionPatternRewriter &rewriter) const {
204 auto structType = typeConverter->convertType(memRefType);
205 auto memRefDescriptor = MemRefDescriptor::undef(rewriter, loc, structType);
206
207 // Field 1: Allocated pointer, used for malloc/free.
208 memRefDescriptor.setAllocatedPtr(rewriter, loc, allocatedPtr);
209
210 // Field 2: Actual aligned pointer to payload.
211 memRefDescriptor.setAlignedPtr(rewriter, loc, alignedPtr);
212
213 // Field 3: Offset in aligned pointer.
214 memRefDescriptor.setOffset(rewriter, loc,
215 createIndexConstant(rewriter, loc, 0));
216
217 // Fields 4: Sizes.
218 for (auto en : llvm::enumerate(sizes))
219 memRefDescriptor.setSize(rewriter, loc, en.index(), en.value());
220
221 // Field 5: Strides.
222 for (auto en : llvm::enumerate(strides))
223 memRefDescriptor.setStride(rewriter, loc, en.index(), en.value());
224
225 return memRefDescriptor;
226 }
227
copyUnrankedDescriptors(OpBuilder & builder,Location loc,TypeRange origTypes,SmallVectorImpl<Value> & operands,bool toDynamic) const228 LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors(
229 OpBuilder &builder, Location loc, TypeRange origTypes,
230 SmallVectorImpl<Value> &operands, bool toDynamic) const {
231 assert(origTypes.size() == operands.size() &&
232 "expected as may original types as operands");
233
234 // Find operands of unranked memref type and store them.
235 SmallVector<UnrankedMemRefDescriptor, 4> unrankedMemrefs;
236 for (unsigned i = 0, e = operands.size(); i < e; ++i)
237 if (origTypes[i].isa<UnrankedMemRefType>())
238 unrankedMemrefs.emplace_back(operands[i]);
239
240 if (unrankedMemrefs.empty())
241 return success();
242
243 // Compute allocation sizes.
244 SmallVector<Value, 4> sizes;
245 UnrankedMemRefDescriptor::computeSizes(builder, loc, *getTypeConverter(),
246 unrankedMemrefs, sizes);
247
248 // Get frequently used types.
249 MLIRContext *context = builder.getContext();
250 Type voidPtrType = LLVM::LLVMPointerType::get(IntegerType::get(context, 8));
251 auto i1Type = IntegerType::get(context, 1);
252 Type indexType = getTypeConverter()->getIndexType();
253
254 // Find the malloc and free, or declare them if necessary.
255 auto module = builder.getInsertionPoint()->getParentOfType<ModuleOp>();
256 LLVM::LLVMFuncOp freeFunc, mallocFunc;
257 if (toDynamic)
258 mallocFunc = LLVM::lookupOrCreateMallocFn(module, indexType);
259 if (!toDynamic)
260 freeFunc = LLVM::lookupOrCreateFreeFn(module);
261
262 // Initialize shared constants.
263 Value zero =
264 builder.create<LLVM::ConstantOp>(loc, i1Type, builder.getBoolAttr(false));
265
266 unsigned unrankedMemrefPos = 0;
267 for (unsigned i = 0, e = operands.size(); i < e; ++i) {
268 Type type = origTypes[i];
269 if (!type.isa<UnrankedMemRefType>())
270 continue;
271 Value allocationSize = sizes[unrankedMemrefPos++];
272 UnrankedMemRefDescriptor desc(operands[i]);
273
274 // Allocate memory, copy, and free the source if necessary.
275 Value memory =
276 toDynamic
277 ? builder.create<LLVM::CallOp>(loc, mallocFunc, allocationSize)
278 .getResult(0)
279 : builder.create<LLVM::AllocaOp>(loc, voidPtrType, allocationSize,
280 /*alignment=*/0);
281 Value source = desc.memRefDescPtr(builder, loc);
282 builder.create<LLVM::MemcpyOp>(loc, memory, source, allocationSize, zero);
283 if (!toDynamic)
284 builder.create<LLVM::CallOp>(loc, freeFunc, source);
285
286 // Create a new descriptor. The same descriptor can be returned multiple
287 // times, attempting to modify its pointer can lead to memory leaks
288 // (allocated twice and overwritten) or double frees (the caller does not
289 // know if the descriptor points to the same memory).
290 Type descriptorType = getTypeConverter()->convertType(type);
291 if (!descriptorType)
292 return failure();
293 auto updatedDesc =
294 UnrankedMemRefDescriptor::undef(builder, loc, descriptorType);
295 Value rank = desc.rank(builder, loc);
296 updatedDesc.setRank(builder, loc, rank);
297 updatedDesc.setMemRefDescPtr(builder, loc, memory);
298
299 operands[i] = updatedDesc;
300 }
301
302 return success();
303 }
304
305 //===----------------------------------------------------------------------===//
306 // Detail methods
307 //===----------------------------------------------------------------------===//
308
309 /// Replaces the given operation "op" with a new operation of type "targetOp"
310 /// and given operands.
oneToOneRewrite(Operation * op,StringRef targetOp,ValueRange operands,LLVMTypeConverter & typeConverter,ConversionPatternRewriter & rewriter)311 LogicalResult LLVM::detail::oneToOneRewrite(
312 Operation *op, StringRef targetOp, ValueRange operands,
313 LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter) {
314 unsigned numResults = op->getNumResults();
315
316 Type packedType;
317 if (numResults != 0) {
318 packedType = typeConverter.packFunctionResults(op->getResultTypes());
319 if (!packedType)
320 return failure();
321 }
322
323 // Create the operation through state since we don't know its C++ type.
324 OperationState state(op->getLoc(), targetOp);
325 state.addTypes(packedType);
326 state.addOperands(operands);
327 state.addAttributes(op->getAttrs());
328 Operation *newOp = rewriter.createOperation(state);
329
330 // If the operation produced 0 or 1 result, return them immediately.
331 if (numResults == 0)
332 return rewriter.eraseOp(op), success();
333 if (numResults == 1)
334 return rewriter.replaceOp(op, newOp->getResult(0)), success();
335
336 // Otherwise, it had been converted to an operation producing a structure.
337 // Extract individual results from the structure and return them as list.
338 SmallVector<Value, 4> results;
339 results.reserve(numResults);
340 for (unsigned i = 0; i < numResults; ++i) {
341 auto type = typeConverter.convertType(op->getResult(i).getType());
342 results.push_back(rewriter.create<LLVM::ExtractValueOp>(
343 op->getLoc(), type, newOp->getResult(0), rewriter.getI64ArrayAttr(i)));
344 }
345 rewriter.replaceOp(op, results);
346 return success();
347 }
348