//===- MemRefToLLVM.cpp - MemRef to LLVM dialect conversion ---------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
#include "../PassDetail.h"
#include "mlir/Analysis/DataLayoutAnalysis.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h"
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BlockAndValueMapping.h"

using namespace mlir;

namespace {

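// Lowers memref.alloc to a call to `malloc`, optionally followed by manual
// alignment of the returned pointer. As a rough sketch (the exact IR depends
// on the index bitwidth and element type), an alloc with an alignment
// attribute becomes approximately:
//
//   %mem = llvm.call @malloc(%paddedSize) : (i64) -> !llvm.ptr<i8>
//   %ptr = llvm.bitcast %mem : !llvm.ptr<i8> to !llvm.ptr<f32>
//   ... ptrtoint/align-up/inttoptr sequence producing the aligned pointer ...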
struct AllocOpLowering : public AllocLikeOpLLVMLowering {
  AllocOpLowering(LLVMTypeConverter &converter)
      : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
                                converter) {}

  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
                                          Location loc, Value sizeBytes,
                                          Operation *op) const override {
    // Heap allocations.
    memref::AllocOp allocOp = cast<memref::AllocOp>(op);
    MemRefType memRefType = allocOp.getType();

    Value alignment;
    if (auto alignmentAttr = allocOp.alignment()) {
      alignment = createIndexConstant(rewriter, loc, *alignmentAttr);
    } else if (!memRefType.getElementType().isSignlessIntOrIndexOrFloat()) {
      // In the case where no alignment is specified, we may want to override
      // `malloc`'s behavior. `malloc` typically aligns at the size of the
      // biggest scalar on the target HW. For non-scalars, use the natural
      // alignment of the LLVM type given by the LLVM DataLayout.
      alignment = getSizeInBytes(loc, memRefType.getElementType(), rewriter);
    }

    if (alignment) {
      // Adjust the allocation size to consider alignment.
      sizeBytes = rewriter.create<LLVM::AddOp>(loc, sizeBytes, alignment);
    }

    // Allocate the underlying buffer and store a pointer to it in the MemRef
    // descriptor.
    Type elementPtrType = this->getElementPtrType(memRefType);
    auto allocFuncOp = LLVM::lookupOrCreateMallocFn(
        allocOp->getParentOfType<ModuleOp>(), getIndexType());
    auto results = createLLVMCall(rewriter, loc, allocFuncOp, {sizeBytes},
                                  getVoidPtrType());
    Value allocatedPtr =
        rewriter.create<LLVM::BitcastOp>(loc, elementPtrType, results[0]);

    Value alignedPtr = allocatedPtr;
    if (alignment) {
      // Compute the aligned pointer.
      Value allocatedInt =
          rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), allocatedPtr);
      Value alignmentInt =
          createAligned(rewriter, loc, allocatedInt, alignment);
      alignedPtr =
          rewriter.create<LLVM::IntToPtrOp>(loc, elementPtrType, alignmentInt);
    }

    return std::make_tuple(allocatedPtr, alignedPtr);
  }
};

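// Lowers memref.alloc to a call to `aligned_alloc` instead of `malloc`. As a
// rough sketch (assuming a 64-bit index type and the default minimum
// alignment of 16):
//
//   %align = llvm.mlir.constant(16 : i64) : i64
//   %mem = llvm.call @aligned_alloc(%align, %size) : (i64, i64) -> !llvm.ptr<i8>
//
// with %size padded up to a multiple of the alignment when it is not
// statically known to be one.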
struct AlignedAllocOpLowering : public AllocLikeOpLLVMLowering {
  AlignedAllocOpLowering(LLVMTypeConverter &converter)
      : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
                                converter) {}

  /// Returns the memref's element size in bytes using the data layout active
  /// at `op`.
  // TODO: there are other places where this is used. Expose publicly?
  unsigned getMemRefEltSizeInBytes(MemRefType memRefType, Operation *op) const {
    const DataLayout *layout = &defaultLayout;
    if (const DataLayoutAnalysis *analysis =
            getTypeConverter()->getDataLayoutAnalysis()) {
      layout = &analysis->getAbove(op);
    }
    Type elementType = memRefType.getElementType();
    if (auto memRefElementType = elementType.dyn_cast<MemRefType>())
      return getTypeConverter()->getMemRefDescriptorSize(memRefElementType,
                                                         *layout);
    if (auto memRefElementType = elementType.dyn_cast<UnrankedMemRefType>())
      return getTypeConverter()->getUnrankedMemRefDescriptorSize(
          memRefElementType, *layout);
    return layout->getTypeSize(elementType);
  }

  /// Returns true if the memref size in bytes is known to be a multiple of
  /// `factor`, assuming the data layout active at `op`.
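  /// For example (illustrative): for memref<4x?xf32>, the divisor is
  /// 4 (elements) * 4 (bytes) = 16, so the size in bytes is known to be a
  /// multiple of 16 whatever the dynamic dimension turns out to be.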
  bool isMemRefSizeMultipleOf(MemRefType type, uint64_t factor,
                              Operation *op) const {
    uint64_t sizeDivisor = getMemRefEltSizeInBytes(type, op);
    for (unsigned i = 0, e = type.getRank(); i < e; i++) {
      if (type.isDynamic(type.getDimSize(i)))
        continue;
      sizeDivisor = sizeDivisor * type.getDimSize(i);
    }
    return sizeDivisor % factor == 0;
  }

  /// Returns the alignment to be used for the allocation call itself.
  /// aligned_alloc requires the alignment to be a power of two, and the
  /// allocation size to be a multiple of that alignment.
  int64_t getAllocationAlignment(memref::AllocOp allocOp) const {
    if (Optional<uint64_t> alignment = allocOp.alignment())
      return *alignment;

    // When no alignment is specified, we will use an alignment consistent
    // with the element type; since the alignment has to be a power of two,
    // we bump the element size to the next power of two if it isn't one
    // already.
    auto eltSizeBytes = getMemRefEltSizeInBytes(allocOp.getType(), allocOp);
    return std::max(kMinAlignedAllocAlignment,
                    llvm::PowerOf2Ceil(eltSizeBytes));
  }

  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
                                          Location loc, Value sizeBytes,
                                          Operation *op) const override {
    // Heap allocations.
    memref::AllocOp allocOp = cast<memref::AllocOp>(op);
    MemRefType memRefType = allocOp.getType();
    int64_t alignment = getAllocationAlignment(allocOp);
    Value allocAlignment = createIndexConstant(rewriter, loc, alignment);

    // aligned_alloc requires the size to be a multiple of the alignment; we
    // will pad the size to the next multiple if necessary.
    if (!isMemRefSizeMultipleOf(memRefType, alignment, op))
      sizeBytes = createAligned(rewriter, loc, sizeBytes, allocAlignment);

    Type elementPtrType = this->getElementPtrType(memRefType);
    auto allocFuncOp = LLVM::lookupOrCreateAlignedAllocFn(
        allocOp->getParentOfType<ModuleOp>(), getIndexType());
    auto results =
        createLLVMCall(rewriter, loc, allocFuncOp, {allocAlignment, sizeBytes},
                       getVoidPtrType());
    Value allocatedPtr =
        rewriter.create<LLVM::BitcastOp>(loc, elementPtrType, results[0]);

    return std::make_tuple(allocatedPtr, allocatedPtr);
  }

  /// The minimum alignment to use with aligned_alloc (has to be a power of 2).
  static constexpr uint64_t kMinAlignedAllocAlignment = 16UL;

  /// Default layout to use in absence of the corresponding analysis.
  DataLayout defaultLayout;
};

// Out-of-line definition, required until C++17.
constexpr uint64_t AlignedAllocOpLowering::kMinAlignedAllocAlignment;

struct AllocaOpLowering : public AllocLikeOpLLVMLowering {
  AllocaOpLowering(LLVMTypeConverter &converter)
      : AllocLikeOpLLVMLowering(memref::AllocaOp::getOperationName(),
                                converter) {}

  /// Allocates the underlying buffer with an llvm.alloca. Since the alloca
  /// itself honors the requested alignment, the same pointer serves as both
  /// the allocated and the aligned pointer.
  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
                                          Location loc, Value sizeBytes,
                                          Operation *op) const override {

    // With alloca, one gets a pointer to the element type right away.
    auto allocaOp = cast<memref::AllocaOp>(op);
    auto elementPtrType = this->getElementPtrType(allocaOp.getType());

    auto allocatedElementPtr = rewriter.create<LLVM::AllocaOp>(
        loc, elementPtrType, sizeBytes,
        allocaOp.alignment() ? *allocaOp.alignment() : 0);

    return std::make_tuple(allocatedElementPtr, allocatedElementPtr);
  }
};

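// Lowers memref.alloca_scope by inlining its body between a pair of stack
// save/restore operations. A rough sketch of the resulting CFG (block names
// are illustrative):
//
//   ^current:
//     %sp = llvm.intr.stacksave : !llvm.ptr<i8>
//     llvm.br ^body
//   ^body:
//     ...  // inlined region; allocas in here are scoped
//     llvm.intr.stackrestore %sp
//     llvm.br ^continue(...)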
struct AllocaScopeOpLowering
    : public ConvertOpToLLVMPattern<memref::AllocaScopeOp> {
  using ConvertOpToLLVMPattern<memref::AllocaScopeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::AllocaScopeOp allocaScopeOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    OpBuilder::InsertionGuard guard(rewriter);
    Location loc = allocaScopeOp.getLoc();

    // Split the current block before the AllocaScopeOp to create the inlining
    // point.
    auto *currentBlock = rewriter.getInsertionBlock();
    auto *remainingOpsBlock =
        rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
    Block *continueBlock;
    if (allocaScopeOp.getNumResults() == 0) {
      continueBlock = remainingOpsBlock;
    } else {
      continueBlock = rewriter.createBlock(remainingOpsBlock,
                                           allocaScopeOp.getResultTypes());
      rewriter.create<LLVM::BrOp>(loc, ValueRange(), remainingOpsBlock);
    }

    // Inline the body region.
    Block *beforeBody = &allocaScopeOp.bodyRegion().front();
    Block *afterBody = &allocaScopeOp.bodyRegion().back();
    rewriter.inlineRegionBefore(allocaScopeOp.bodyRegion(), continueBlock);

    // Save the stack, then branch into the body of the region.
    rewriter.setInsertionPointToEnd(currentBlock);
    auto stackSaveOp =
        rewriter.create<LLVM::StackSaveOp>(loc, getVoidPtrType());
    rewriter.create<LLVM::BrOp>(loc, ValueRange(), beforeBody);

    // Replace the alloca_scope return with a branch that jumps out of the
    // body.
    rewriter.setInsertionPointToEnd(afterBody);
    auto returnOp =
        cast<memref::AllocaScopeReturnOp>(afterBody->getTerminator());
    auto branchOp = rewriter.replaceOpWithNewOp<LLVM::BrOp>(
        returnOp, returnOp.results(), continueBlock);

    // Insert a stack restore before jumping out of the body of the region.
    rewriter.setInsertionPoint(branchOp);
    rewriter.create<LLVM::StackRestoreOp>(loc, stackSaveOp);

    // Replace the op with the values returned from the body region.
    rewriter.replaceOp(allocaScopeOp, continueBlock->getArguments());

    return success();
  }
};

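// Lowers memref.assume_alignment to an alignment assumption on the aligned
// pointer. A rough sketch of the emitted IR for an alignment of 16 (constant
// names are illustrative):
//
//   %int = llvm.ptrtoint %alignedPtr : !llvm.ptr<f32> to i64
//   %masked = llvm.and %int, %c15 : i64
//   %cond = llvm.icmp "eq" %masked, %c0 : i64
//   "llvm.intr.assume"(%cond) : (i1) -> ()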
struct AssumeAlignmentOpLowering
    : public ConvertOpToLLVMPattern<memref::AssumeAlignmentOp> {
  using ConvertOpToLLVMPattern<
      memref::AssumeAlignmentOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::AssumeAlignmentOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    memref::AssumeAlignmentOp::Adaptor transformed(operands);
    Value memref = transformed.memref();
    unsigned alignment = op.alignment();
    auto loc = op.getLoc();

    MemRefDescriptor memRefDescriptor(memref);
    Value ptr = memRefDescriptor.alignedPtr(rewriter, memref.getLoc());

    // Emit llvm.assume(memref.alignedPtr & (alignment - 1) == 0). Notice that
    // the asserted memref.alignedPtr isn't used anywhere else, as the real
    // users like load/store/views always re-extract memref.alignedPtr as they
    // get lowered.
    //
    // This relies on LLVM's CSE optimization (potentially after SROA), since
    // after CSE all memref.alignedPtr instances get de-duplicated into the same
    // pointer SSA value.
    auto intPtrType =
        getIntPtrType(memRefDescriptor.getElementPtrType().getAddressSpace());
    Value zero = createIndexAttrConstant(rewriter, loc, intPtrType, 0);
    Value mask =
        createIndexAttrConstant(rewriter, loc, intPtrType, alignment - 1);
    Value ptrValue = rewriter.create<LLVM::PtrToIntOp>(loc, intPtrType, ptr);
    rewriter.create<LLVM::AssumeOp>(
        loc, rewriter.create<LLVM::ICmpOp>(
                 loc, LLVM::ICmpPredicate::eq,
                 rewriter.create<LLVM::AndOp>(loc, ptrValue, mask), zero));

    rewriter.eraseOp(op);
    return success();
  }
};

// A `dealloc` is converted into a call to `free` on the underlying data
// buffer. The memref descriptor being an SSA value, there is no need to clean
// it up in any way.
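// For example (sketch): `memref.dealloc %m : memref<4xf32>` becomes, roughly,
// an extraction of the allocated pointer from the descriptor, a bitcast to
// i8*, and a call to `free`.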
struct DeallocOpLowering : public ConvertOpToLLVMPattern<memref::DeallocOp> {
  using ConvertOpToLLVMPattern<memref::DeallocOp>::ConvertOpToLLVMPattern;

  explicit DeallocOpLowering(LLVMTypeConverter &converter)
      : ConvertOpToLLVMPattern<memref::DeallocOp>(converter) {}

  LogicalResult
  matchAndRewrite(memref::DeallocOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    assert(operands.size() == 1 && "dealloc takes one operand");
    memref::DeallocOp::Adaptor transformed(operands);

    // Insert the `free` declaration if it is not already present.
    auto freeFunc = LLVM::lookupOrCreateFreeFn(op->getParentOfType<ModuleOp>());
    MemRefDescriptor memref(transformed.memref());
    Value casted = rewriter.create<LLVM::BitcastOp>(
        op.getLoc(), getVoidPtrType(),
        memref.allocatedPtr(rewriter, op.getLoc()));
    rewriter.replaceOpWithNewOp<LLVM::CallOp>(
        op, TypeRange(), rewriter.getSymbolRefAttr(freeFunc), casted);
    return success();
  }
};

// A `dim` is converted to a constant for static sizes and to an access to the
// size stored in the memref descriptor for dynamic sizes.
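// For example (sketch): `memref.dim %m, %c1 : memref<4x?xf32>` loads the
// dynamic size from the descriptor, while querying dimension 0 folds to a
// constant 4.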
struct DimOpLowering : public ConvertOpToLLVMPattern<memref::DimOp> {
  using ConvertOpToLLVMPattern<memref::DimOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::DimOp dimOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    Type operandType = dimOp.source().getType();
    if (operandType.isa<UnrankedMemRefType>()) {
      rewriter.replaceOp(dimOp, {extractSizeOfUnrankedMemRef(
                                    operandType, dimOp, operands, rewriter)});

      return success();
    }
    if (operandType.isa<MemRefType>()) {
      rewriter.replaceOp(dimOp, {extractSizeOfRankedMemRef(
                                    operandType, dimOp, operands, rewriter)});
      return success();
    }
    llvm_unreachable("expected MemRefType or UnrankedMemRefType");
  }

private:
  Value extractSizeOfUnrankedMemRef(Type operandType, memref::DimOp dimOp,
                                    ArrayRef<Value> operands,
                                    ConversionPatternRewriter &rewriter) const {
    Location loc = dimOp.getLoc();
    memref::DimOp::Adaptor transformed(operands);

    auto unrankedMemRefType = operandType.cast<UnrankedMemRefType>();
    auto scalarMemRefType =
        MemRefType::get({}, unrankedMemRefType.getElementType());
    unsigned addressSpace = unrankedMemRefType.getMemorySpaceAsInt();

    // Extract pointer to the underlying ranked descriptor and bitcast it to a
    // memref<element_type> descriptor pointer to minimize the number of GEP
    // operations.
    UnrankedMemRefDescriptor unrankedDesc(transformed.source());
    Value underlyingRankedDesc = unrankedDesc.memRefDescPtr(rewriter, loc);
    Value scalarMemRefDescPtr = rewriter.create<LLVM::BitcastOp>(
        loc,
        LLVM::LLVMPointerType::get(typeConverter->convertType(scalarMemRefType),
                                   addressSpace),
        underlyingRankedDesc);

    // Get pointer to offset field of memref<element_type> descriptor.
    Type indexPtrTy = LLVM::LLVMPointerType::get(
        getTypeConverter()->getIndexType(), addressSpace);
    Value two = rewriter.create<LLVM::ConstantOp>(
        loc, typeConverter->convertType(rewriter.getI32Type()),
        rewriter.getI32IntegerAttr(2));
    Value offsetPtr = rewriter.create<LLVM::GEPOp>(
        loc, indexPtrTy, scalarMemRefDescPtr,
        ValueRange({createIndexConstant(rewriter, loc, 0), two}));

    // The size value that we have to extract can be obtained using GEPOp with
    // `dimOp.index() + 1` as the index argument.
    Value idxPlusOne = rewriter.create<LLVM::AddOp>(
        loc, createIndexConstant(rewriter, loc, 1), transformed.index());
    Value sizePtr = rewriter.create<LLVM::GEPOp>(loc, indexPtrTy, offsetPtr,
                                                 ValueRange({idxPlusOne}));
    return rewriter.create<LLVM::LoadOp>(loc, sizePtr);
  }

  Optional<int64_t> getConstantDimIndex(memref::DimOp dimOp) const {
    if (Optional<int64_t> idx = dimOp.getConstantIndex())
      return idx;

    if (auto constantOp = dimOp.index().getDefiningOp<LLVM::ConstantOp>())
      return constantOp.value().cast<IntegerAttr>().getValue().getSExtValue();

    return llvm::None;
  }

  Value extractSizeOfRankedMemRef(Type operandType, memref::DimOp dimOp,
                                  ArrayRef<Value> operands,
                                  ConversionPatternRewriter &rewriter) const {
    Location loc = dimOp.getLoc();
    memref::DimOp::Adaptor transformed(operands);
    // Take advantage of a constant index when available.
    MemRefType memRefType = operandType.cast<MemRefType>();
    if (Optional<int64_t> index = getConstantDimIndex(dimOp)) {
      int64_t i = index.getValue();
      if (memRefType.isDynamicDim(i)) {
        // Extract the dynamic size from the memref descriptor.
        MemRefDescriptor descriptor(transformed.source());
        return descriptor.size(rewriter, loc, i);
      }
      // Use a constant for the static size.
      int64_t dimSize = memRefType.getDimSize(i);
      return createIndexConstant(rewriter, loc, dimSize);
    }
    Value index = transformed.index();
    int64_t rank = memRefType.getRank();
    MemRefDescriptor memrefDescriptor(transformed.source());
    return memrefDescriptor.size(rewriter, loc, index, rank);
  }
};

/// Returns the LLVM type of the global variable given the memref type `type`.
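/// For example: memref<4x8xf32> maps to !llvm.array<4 x array<8 x f32>>.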
static Type convertGlobalMemrefTypeToLLVM(MemRefType type,
                                          LLVMTypeConverter &typeConverter) {
  // The LLVM type for a global memref will be a multi-dimensional array. For
  // declarations or uninitialized global memrefs, we could potentially flatten
  // this to a 1-D array. However, for memref.global ops with an initial value,
  // we do not intend to flatten the ElementsAttribute when going from std ->
  // LLVM dialect, so the LLVM type needs to be a multi-dimensional array.
  Type elementType = typeConverter.convertType(type.getElementType());
  Type arrayTy = elementType;
  // The shape has the outermost dim at index 0, so walk it backwards.
  for (int64_t dim : llvm::reverse(type.getShape()))
    arrayTy = LLVM::LLVMArrayType::get(arrayTy, dim);
  return arrayTy;
}

/// GlobalMemrefOp is lowered to an LLVM global variable.
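/// For example (sketch):
///
///   memref.global "private" constant @cst : memref<2xf32> = dense<[1.0, 2.0]>
///
/// becomes, roughly, an llvm.mlir.global of type !llvm.array<2 x f32> with
/// the same dense initializer.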
struct GlobalMemrefOpLowering
    : public ConvertOpToLLVMPattern<memref::GlobalOp> {
  using ConvertOpToLLVMPattern<memref::GlobalOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::GlobalOp global, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType type = global.type().cast<MemRefType>();
    if (!isConvertibleAndHasIdentityMaps(type))
      return failure();

    Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());

    LLVM::Linkage linkage =
        global.isPublic() ? LLVM::Linkage::External : LLVM::Linkage::Private;

    Attribute initialValue = nullptr;
    if (!global.isExternal() && !global.isUninitialized()) {
      auto elementsAttr = global.initial_value()->cast<ElementsAttr>();
      initialValue = elementsAttr;

      // For scalar memrefs, the global variable created is of the element
      // type, so unpack the elements attribute to extract the value.
      if (type.getRank() == 0)
        initialValue = elementsAttr.getValue({});
    }

    rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
        global, arrayTy, global.constant(), linkage, global.sym_name(),
        initialValue, /*alignment=*/0, type.getMemorySpaceAsInt());
    return success();
  }
};

/// GetGlobalMemrefOp is lowered into a MemRef descriptor with the pointer to
/// the first element stashed into the descriptor. This derives from
/// `AllocLikeOpLLVMLowering` to reuse the MemRef descriptor construction.
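/// For example (sketch): `memref.get_global @gv : memref<2x3xf32>` emits an
/// llvm.mlir.addressof of @gv followed by a GEP with rank + 1 zero indices
/// to obtain the address of the first element.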
struct GetGlobalMemrefOpLowering : public AllocLikeOpLLVMLowering {
  GetGlobalMemrefOpLowering(LLVMTypeConverter &converter)
      : AllocLikeOpLLVMLowering(memref::GetGlobalOp::getOperationName(),
                                converter) {}

  /// Buffer "allocation" for the memref.get_global op consists of getting the
  /// address of the referenced global variable.
  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
                                          Location loc, Value sizeBytes,
                                          Operation *op) const override {
    auto getGlobalOp = cast<memref::GetGlobalOp>(op);
    MemRefType type = getGlobalOp.result().getType().cast<MemRefType>();
    unsigned memSpace = type.getMemorySpaceAsInt();

    Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
    auto addressOf = rewriter.create<LLVM::AddressOfOp>(
        loc, LLVM::LLVMPointerType::get(arrayTy, memSpace), getGlobalOp.name());

    // Get the address of the first element in the array by creating a GEP with
    // the address of the GV as the base, and (rank + 1) number of 0 indices.
    Type elementType = typeConverter->convertType(type.getElementType());
    Type elementPtrType = LLVM::LLVMPointerType::get(elementType, memSpace);

    SmallVector<Value, 4> operands = {addressOf};
    operands.insert(operands.end(), type.getRank() + 1,
                    createIndexConstant(rewriter, loc, 0));
    auto gep = rewriter.create<LLVM::GEPOp>(loc, elementPtrType, operands);

    // We do not expect the memref obtained using `memref.get_global` to be
    // ever deallocated. Set the allocated pointer to a known bad value to
    // help debug if that ever happens.
    auto intPtrType = getIntPtrType(memSpace);
    Value deadBeefConst =
        createIndexAttrConstant(rewriter, op->getLoc(), intPtrType, 0xdeadbeef);
    auto deadBeefPtr =
        rewriter.create<LLVM::IntToPtrOp>(loc, elementPtrType, deadBeefConst);

    // Both the allocated and aligned pointers are the same. We could
    // potentially stash a nullptr for the allocated pointer since we do not
    // expect any dealloc.
    return std::make_tuple(deadBeefPtr, gep);
  }
};

// Common base for load and store operations on MemRefs. Restricts the match
// to supported MemRef types. Provides functionality to emit code accessing a
// specific element of the underlying data buffer.
template <typename Derived>
struct LoadStoreOpLowering : public ConvertOpToLLVMPattern<Derived> {
  using ConvertOpToLLVMPattern<Derived>::ConvertOpToLLVMPattern;
  using ConvertOpToLLVMPattern<Derived>::isConvertibleAndHasIdentityMaps;
  using Base = LoadStoreOpLowering<Derived>;

  LogicalResult match(Derived op) const override {
    MemRefType type = op.getMemRefType();
    return isConvertibleAndHasIdentityMaps(type) ? success() : failure();
  }
};

// Load operation is lowered to obtaining a pointer to the indexed element
// and loading it.
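// For example (sketch): `memref.load %m[%i, %j]` computes the address
// `alignedPtr + offset + %i * stride0 + %j * stride1` via GEP and replaces
// the op with an llvm.load of that address.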
struct LoadOpLowering : public LoadStoreOpLowering<memref::LoadOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::LoadOp loadOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    memref::LoadOp::Adaptor transformed(operands);
    auto type = loadOp.getMemRefType();

    Value dataPtr =
        getStridedElementPtr(loadOp.getLoc(), type, transformed.memref(),
                             transformed.indices(), rewriter);
    rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, dataPtr);
    return success();
  }
};

// Store operation is lowered to obtaining a pointer to the indexed element,
// and storing the given value to it.
struct StoreOpLowering : public LoadStoreOpLowering<memref::StoreOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::StoreOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = op.getMemRefType();
    memref::StoreOp::Adaptor transformed(operands);

    Value dataPtr =
        getStridedElementPtr(op.getLoc(), type, transformed.memref(),
                             transformed.indices(), rewriter);
    rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, transformed.value(),
                                               dataPtr);
    return success();
  }
};

// The prefetch operation is lowered in a way similar to the load operation
// except that the llvm.prefetch operation is used for replacement.
struct PrefetchOpLowering : public LoadStoreOpLowering<memref::PrefetchOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::PrefetchOp prefetchOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    memref::PrefetchOp::Adaptor transformed(operands);
    auto type = prefetchOp.getMemRefType();
    auto loc = prefetchOp.getLoc();

    Value dataPtr = getStridedElementPtr(loc, type, transformed.memref(),
                                         transformed.indices(), rewriter);

    // Replace with llvm.prefetch.
    auto llvmI32Type = typeConverter->convertType(rewriter.getIntegerType(32));
    auto isWrite = rewriter.create<LLVM::ConstantOp>(
        loc, llvmI32Type, rewriter.getI32IntegerAttr(prefetchOp.isWrite()));
    auto localityHint = rewriter.create<LLVM::ConstantOp>(
        loc, llvmI32Type,
        rewriter.getI32IntegerAttr(prefetchOp.localityHint()));
    auto isData = rewriter.create<LLVM::ConstantOp>(
        loc, llvmI32Type, rewriter.getI32IntegerAttr(prefetchOp.isDataCache()));

    rewriter.replaceOpWithNewOp<LLVM::Prefetch>(prefetchOp, dataPtr, isWrite,
                                                localityHint, isData);
    return success();
  }
};

struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<memref::CastOp> {
  using ConvertOpToLLVMPattern<memref::CastOp>::ConvertOpToLLVMPattern;

  LogicalResult match(memref::CastOp memRefCastOp) const override {
    Type srcType = memRefCastOp.getOperand().getType();
    Type dstType = memRefCastOp.getType();

    // memref::CastOp reduces to a bitcast in the ranked MemRef case and can
    // be used for type erasure. For now it must preserve the underlying
    // element type and requires the source and result types to have the same
    // rank. Therefore, perform a sanity check that the underlying structs are
    // the same. Once op semantics are relaxed we can revisit.
    if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>())
      return success(typeConverter->convertType(srcType) ==
                     typeConverter->convertType(dstType));

    // At least one of the operands is of unranked type.
    assert(srcType.isa<UnrankedMemRefType>() ||
           dstType.isa<UnrankedMemRefType>());

    // Unranked-to-unranked casts are disallowed.
    return !(srcType.isa<UnrankedMemRefType>() &&
             dstType.isa<UnrankedMemRefType>())
               ? success()
               : failure();
  }

  void rewrite(memref::CastOp memRefCastOp, ArrayRef<Value> operands,
               ConversionPatternRewriter &rewriter) const override {
    memref::CastOp::Adaptor transformed(operands);

    auto srcType = memRefCastOp.getOperand().getType();
    auto dstType = memRefCastOp.getType();
    auto targetStructType = typeConverter->convertType(memRefCastOp.getType());
    auto loc = memRefCastOp.getLoc();

    // For the ranked/ranked case, just keep the original descriptor.
    if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>())
      return rewriter.replaceOp(memRefCastOp, {transformed.source()});

    if (srcType.isa<MemRefType>() && dstType.isa<UnrankedMemRefType>()) {
      // Casting ranked to unranked memref type:
      // - set the rank in the destination from the memref type;
      // - allocate space on the stack and copy the src memref descriptor;
      // - set the ptr in the destination to the stack space.
      auto srcMemRefType = srcType.cast<MemRefType>();
      int64_t rank = srcMemRefType.getRank();
      // ptr = AllocaOp sizeof(MemRefDescriptor)
      auto ptr = getTypeConverter()->promoteOneMemRefDescriptor(
          loc, transformed.source(), rewriter);
      // voidptr = BitCastOp srcType* to void*
      auto voidPtr =
          rewriter.create<LLVM::BitcastOp>(loc, getVoidPtrType(), ptr)
              .getResult();
      // rank = ConstantOp srcRank
      auto rankVal = rewriter.create<LLVM::ConstantOp>(
          loc, typeConverter->convertType(rewriter.getIntegerType(64)),
          rewriter.getI64IntegerAttr(rank));
      // undef = UndefOp
      UnrankedMemRefDescriptor memRefDesc =
          UnrankedMemRefDescriptor::undef(rewriter, loc, targetStructType);
      // d1 = InsertValueOp undef, rank, 0
      memRefDesc.setRank(rewriter, loc, rankVal);
      // d2 = InsertValueOp d1, voidptr, 1
      memRefDesc.setMemRefDescPtr(rewriter, loc, voidPtr);
      rewriter.replaceOp(memRefCastOp, (Value)memRefDesc);

    } else if (srcType.isa<UnrankedMemRefType>() && dstType.isa<MemRefType>()) {
      // Casting from unranked type to ranked.
      // The operation is assumed to be doing a correct cast. If the
      // destination type mismatches the type of the underlying ranked
      // descriptor, the behavior is undefined.
      UnrankedMemRefDescriptor memRefDesc(transformed.source());
      // ptr = ExtractValueOp src, 1
      auto ptr = memRefDesc.memRefDescPtr(rewriter, loc);
      // castPtr = BitCastOp i8* to structTy*
      auto castPtr =
          rewriter
              .create<LLVM::BitcastOp>(
                  loc, LLVM::LLVMPointerType::get(targetStructType), ptr)
              .getResult();
      // struct = LoadOp castPtr
      auto loadOp = rewriter.create<LLVM::LoadOp>(loc, castPtr);
      rewriter.replaceOp(memRefCastOp, loadOp.getResult());
    } else {
      llvm_unreachable("Unsupported unranked memref to unranked memref cast");
    }
  }
};

struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
  using ConvertOpToLLVMPattern<memref::CopyOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::CopyOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    memref::CopyOp::Adaptor adaptor(operands);
    auto srcType = op.source().getType().cast<BaseMemRefType>();
    auto targetType = op.target().getType().cast<BaseMemRefType>();

    // First make sure we have an unranked memref descriptor representation.
    auto makeUnranked = [&, this](Value ranked, BaseMemRefType type) {
      auto rank = rewriter.create<LLVM::ConstantOp>(
          loc, getIndexType(), rewriter.getIndexAttr(type.getRank()));
      auto *typeConverter = getTypeConverter();
      auto ptr =
          typeConverter->promoteOneMemRefDescriptor(loc, ranked, rewriter);
      auto voidPtr =
          rewriter.create<LLVM::BitcastOp>(loc, getVoidPtrType(), ptr)
              .getResult();
      auto unrankedType =
          UnrankedMemRefType::get(type.getElementType(), type.getMemorySpace());
      return UnrankedMemRefDescriptor::pack(rewriter, loc, *typeConverter,
                                            unrankedType,
                                            ValueRange{rank, voidPtr});
    };

    Value unrankedSource = srcType.hasRank()
                               ? makeUnranked(adaptor.source(), srcType)
                               : adaptor.source();
    Value unrankedTarget = targetType.hasRank()
                               ? makeUnranked(adaptor.target(), targetType)
                               : adaptor.target();

    // Now promote the unranked descriptors to the stack.
    auto one = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
                                                 rewriter.getIndexAttr(1));
    auto promote = [&](Value desc) {
      auto ptrType = LLVM::LLVMPointerType::get(desc.getType());
      auto allocated =
          rewriter.create<LLVM::AllocaOp>(loc, ptrType, ValueRange{one});
      rewriter.create<LLVM::StoreOp>(loc, desc, allocated);
      return allocated;
    };

    auto sourcePtr = promote(unrankedSource);
    auto targetPtr = promote(unrankedTarget);

    auto elemSize = rewriter.create<LLVM::ConstantOp>(
        loc, getIndexType(),
        rewriter.getIndexAttr(srcType.getElementTypeBitWidth() / 8));
    auto copyFn = LLVM::lookupOrCreateMemRefCopyFn(
        op->getParentOfType<ModuleOp>(), getIndexType(), sourcePtr.getType());
    rewriter.create<LLVM::CallOp>(loc, copyFn,
                                  ValueRange{elemSize, sourcePtr, targetPtr});
    rewriter.eraseOp(op);

    return success();
  }
};

/// Extracts the allocated and aligned pointers and the offset from a ranked
/// or unranked memref type. In the unranked case, the fields are extracted
/// from the underlying ranked descriptor.
static void extractPointersAndOffset(Location loc,
                                     ConversionPatternRewriter &rewriter,
                                     LLVMTypeConverter &typeConverter,
                                     Value originalOperand,
                                     Value convertedOperand,
                                     Value *allocatedPtr, Value *alignedPtr,
                                     Value *offset = nullptr) {
  Type operandType = originalOperand.getType();
  if (operandType.isa<MemRefType>()) {
    MemRefDescriptor desc(convertedOperand);
    *allocatedPtr = desc.allocatedPtr(rewriter, loc);
    *alignedPtr = desc.alignedPtr(rewriter, loc);
    if (offset != nullptr)
      *offset = desc.offset(rewriter, loc);
    return;
  }

  unsigned memorySpace =
      operandType.cast<UnrankedMemRefType>().getMemorySpaceAsInt();
  Type elementType = operandType.cast<UnrankedMemRefType>().getElementType();
  Type llvmElementType = typeConverter.convertType(elementType);
  Type elementPtrPtrType = LLVM::LLVMPointerType::get(
      LLVM::LLVMPointerType::get(llvmElementType, memorySpace));

  // Extract pointer to the underlying ranked memref descriptor and cast it to
  // ElemType**.
  UnrankedMemRefDescriptor unrankedDesc(convertedOperand);
  Value underlyingDescPtr = unrankedDesc.memRefDescPtr(rewriter, loc);

  *allocatedPtr = UnrankedMemRefDescriptor::allocatedPtr(
      rewriter, loc, underlyingDescPtr, elementPtrPtrType);
  *alignedPtr = UnrankedMemRefDescriptor::alignedPtr(
      rewriter, loc, typeConverter, underlyingDescPtr, elementPtrPtrType);
  if (offset != nullptr) {
    *offset = UnrankedMemRefDescriptor::offset(
        rewriter, loc, typeConverter, underlyingDescPtr, elementPtrPtrType);
  }
}

struct MemRefReinterpretCastOpLowering
    : public ConvertOpToLLVMPattern<memref::ReinterpretCastOp> {
  using ConvertOpToLLVMPattern<
      memref::ReinterpretCastOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::ReinterpretCastOp castOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    memref::ReinterpretCastOp::Adaptor adaptor(operands,
                                               castOp->getAttrDictionary());
    Type srcType = castOp.source().getType();

    Value descriptor;
    if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, castOp,
                                               adaptor, &descriptor)))
      return failure();
    rewriter.replaceOp(castOp, {descriptor});
    return success();
  }

private:
  LogicalResult convertSourceMemRefToDescriptor(
      ConversionPatternRewriter &rewriter, Type srcType,
      memref::ReinterpretCastOp castOp,
      memref::ReinterpretCastOp::Adaptor adaptor, Value *descriptor) const {
    MemRefType targetMemRefType =
        castOp.getResult().getType().cast<MemRefType>();
    auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType)
                                      .dyn_cast_or_null<LLVM::LLVMStructType>();
    if (!llvmTargetDescriptorTy)
      return failure();

    // Create descriptor.
    Location loc = castOp.getLoc();
    auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);

    // Set allocated and aligned pointers.
    Value allocatedPtr, alignedPtr;
    extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
                             castOp.source(), adaptor.source(), &allocatedPtr,
                             &alignedPtr);
    desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
    desc.setAlignedPtr(rewriter, loc, alignedPtr);

    // Set offset.
    if (castOp.isDynamicOffset(0))
      desc.setOffset(rewriter, loc, adaptor.offsets()[0]);
    else
      desc.setConstantOffset(rewriter, loc, castOp.getStaticOffset(0));

    // Set sizes and strides.
    unsigned dynSizeId = 0;
    unsigned dynStrideId = 0;
    for (unsigned i = 0, e = targetMemRefType.getRank(); i < e; ++i) {
      if (castOp.isDynamicSize(i))
        desc.setSize(rewriter, loc, i, adaptor.sizes()[dynSizeId++]);
      else
        desc.setConstantSize(rewriter, loc, i, castOp.getStaticSize(i));

      if (castOp.isDynamicStride(i))
        desc.setStride(rewriter, loc, i, adaptor.strides()[dynStrideId++]);
      else
        desc.setConstantStride(rewriter, loc, i, castOp.getStaticStride(i));
    }
    *descriptor = desc;
    return success();
  }
};

struct MemRefReshapeOpLowering
    : public ConvertOpToLLVMPattern<memref::ReshapeOp> {
  using ConvertOpToLLVMPattern<memref::ReshapeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::ReshapeOp reshapeOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto *op = reshapeOp.getOperation();
    memref::ReshapeOp::Adaptor adaptor(operands, op->getAttrDictionary());
    Type srcType = reshapeOp.source().getType();

    Value descriptor;
    if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, reshapeOp,
                                               adaptor, &descriptor)))
      return failure();
    rewriter.replaceOp(op, {descriptor});
    return success();
  }

private:
  LogicalResult
  convertSourceMemRefToDescriptor(ConversionPatternRewriter &rewriter,
                                  Type srcType, memref::ReshapeOp reshapeOp,
                                  memref::ReshapeOp::Adaptor adaptor,
                                  Value *descriptor) const {
    // Conversion for statically-known shape args is performed via
    // `memref_reinterpret_cast`.
    auto shapeMemRefType = reshapeOp.shape().getType().cast<MemRefType>();
    if (shapeMemRefType.hasStaticShape())
      return failure();

    // The shape is a rank-1 memref with unknown length.
    Location loc = reshapeOp.getLoc();
    MemRefDescriptor shapeDesc(adaptor.shape());
    Value resultRank = shapeDesc.size(rewriter, loc, 0);

    // Extract address space and element type.
    auto targetType =
        reshapeOp.getResult().getType().cast<UnrankedMemRefType>();
    unsigned addressSpace = targetType.getMemorySpaceAsInt();
    Type elementType = targetType.getElementType();

    // Create the unranked memref descriptor that holds the ranked one. The
    // inner descriptor is allocated on the stack.
    auto targetDesc = UnrankedMemRefDescriptor::undef(
        rewriter, loc, typeConverter->convertType(targetType));
    targetDesc.setRank(rewriter, loc, resultRank);
    SmallVector<Value, 4> sizes;
    UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(),
                                           targetDesc, sizes);
    Value underlyingDescPtr = rewriter.create<LLVM::AllocaOp>(
        loc, getVoidPtrType(), sizes.front(), llvm::None);
    targetDesc.setMemRefDescPtr(rewriter, loc, underlyingDescPtr);

    // Extract pointers and offset from the source memref.
    Value allocatedPtr, alignedPtr, offset;
    extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
                             reshapeOp.source(), adaptor.source(),
                             &allocatedPtr, &alignedPtr, &offset);

    // Set pointers and offset.
    Type llvmElementType = typeConverter->convertType(elementType);
    auto elementPtrPtrType = LLVM::LLVMPointerType::get(
        LLVM::LLVMPointerType::get(llvmElementType, addressSpace));
    UnrankedMemRefDescriptor::setAllocatedPtr(rewriter, loc, underlyingDescPtr,
                                              elementPtrPtrType, allocatedPtr);
    UnrankedMemRefDescriptor::setAlignedPtr(rewriter, loc, *getTypeConverter(),
                                            underlyingDescPtr,
                                            elementPtrPtrType, alignedPtr);
    UnrankedMemRefDescriptor::setOffset(rewriter, loc, *getTypeConverter(),
                                        underlyingDescPtr, elementPtrPtrType,
                                        offset);

    // Use the offset pointer as base for further addressing. Copy over the new
    // shape and compute strides. For this, we create a loop from rank - 1 down
    // to 0.
    Value targetSizesBase = UnrankedMemRefDescriptor::sizeBasePtr(
        rewriter, loc, *getTypeConverter(), underlyingDescPtr,
        elementPtrPtrType);
    Value targetStridesBase = UnrankedMemRefDescriptor::strideBasePtr(
        rewriter, loc, *getTypeConverter(), targetSizesBase, resultRank);
    Value shapeOperandPtr = shapeDesc.alignedPtr(rewriter, loc);
    Value oneIndex = createIndexConstant(rewriter, loc, 1);
    Value resultRankMinusOne =
        rewriter.create<LLVM::SubOp>(loc, resultRank, oneIndex);

    Block *initBlock = rewriter.getInsertionBlock();
    Type indexType = getTypeConverter()->getIndexType();
    Block::iterator remainingOpsIt = std::next(rewriter.getInsertionPoint());

    Block *condBlock = rewriter.createBlock(initBlock->getParent(), {},
                                            {indexType, indexType});

    // Move the remaining initBlock ops to condBlock.
    Block *remainingBlock = rewriter.splitBlock(initBlock, remainingOpsIt);
    rewriter.mergeBlocks(remainingBlock, condBlock, ValueRange());

    rewriter.setInsertionPointToEnd(initBlock);
    rewriter.create<LLVM::BrOp>(loc, ValueRange({resultRankMinusOne, oneIndex}),
                                condBlock);
    rewriter.setInsertionPointToStart(condBlock);
    Value indexArg = condBlock->getArgument(0);
    Value strideArg = condBlock->getArgument(1);

    Value zeroIndex = createIndexConstant(rewriter, loc, 0);
    Value pred = rewriter.create<LLVM::ICmpOp>(
        loc, IntegerType::get(rewriter.getContext(), 1),
        LLVM::ICmpPredicate::sge, indexArg, zeroIndex);

    Block *bodyBlock =
        rewriter.splitBlock(condBlock, rewriter.getInsertionPoint());
    rewriter.setInsertionPointToStart(bodyBlock);

    // Copy size from shape to descriptor.
    Type llvmIndexPtrType = LLVM::LLVMPointerType::get(indexType);
    Value sizeLoadGep = rewriter.create<LLVM::GEPOp>(
        loc, llvmIndexPtrType, shapeOperandPtr, ValueRange{indexArg});
    Value size = rewriter.create<LLVM::LoadOp>(loc, sizeLoadGep);
    UnrankedMemRefDescriptor::setSize(rewriter, loc, *getTypeConverter(),
                                      targetSizesBase, indexArg, size);

    // Write the stride value and compute the next one.
    UnrankedMemRefDescriptor::setStride(rewriter, loc, *getTypeConverter(),
                                        targetStridesBase, indexArg, strideArg);
    Value nextStride = rewriter.create<LLVM::MulOp>(loc, strideArg, size);

    // Decrement loop counter and branch back.
    Value decrement = rewriter.create<LLVM::SubOp>(loc, indexArg, oneIndex);
    rewriter.create<LLVM::BrOp>(loc, ValueRange({decrement, nextStride}),
                                condBlock);

    Block *remainder =
        rewriter.splitBlock(bodyBlock, rewriter.getInsertionPoint());

    // Hook up the cond exit to the remainder.
    rewriter.setInsertionPointToEnd(condBlock);
    rewriter.create<LLVM::CondBrOp>(loc, pred, bodyBlock, llvm::None, remainder,
                                    llvm::None);

    // Reset position to the beginning of the new remainder block.
    rewriter.setInsertionPointToStart(remainder);

    *descriptor = targetDesc;
    return success();
  }
};

/// Helper function to convert a vector of `OpFoldResult`s into a vector of
/// `Value`s.
static SmallVector<Value> getAsValues(OpBuilder &b, Location loc,
                                      Type &llvmIndexType,
                                      ArrayRef<OpFoldResult> valueOrAttrVec) {
  return llvm::to_vector<4>(
      llvm::map_range(valueOrAttrVec, [&](OpFoldResult value) -> Value {
        if (auto attr = value.dyn_cast<Attribute>())
          return b.create<LLVM::ConstantOp>(loc, llvmIndexType, attr);
        return value.get<Value>();
      }));
}

/// Computes a map that, for a given dimension of the expanded type, gives the
/// dimension in the collapsed type it maps to. Essentially, it is the inverse
/// of the `reassociation` maps.
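/// For example (illustrative): the reassociation [[0, 1], [2]] yields the map
/// {0 -> 0, 1 -> 0, 2 -> 1}.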
static DenseMap<int64_t, int64_t>
getExpandedDimToCollapsedDimMap(ArrayRef<ReassociationIndices> reassociation) {
  llvm::DenseMap<int64_t, int64_t> expandedDimToCollapsedDim;
  for (auto &en : enumerate(reassociation)) {
    for (auto dim : en.value())
      expandedDimToCollapsedDim[dim] = en.index();
  }
  return expandedDimToCollapsedDim;
}

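/// Computes the size of one dimension of the expanded output. Static sizes
/// are returned as attributes; a dynamic output size is obtained by dividing
/// the corresponding input dimension by the product of the other (necessarily
/// static) output dimensions in its reassociation group. For example
/// (illustrative): expanding memref<?xf32> into memref<?x4xf32> computes
/// output dim 0 as inputSize / 4.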
static OpFoldResult
getExpandedOutputDimSize(OpBuilder &b, Location loc, Type &llvmIndexType,
                         int64_t outDimIndex, ArrayRef<int64_t> outStaticShape,
                         MemRefDescriptor &inDesc,
                         ArrayRef<int64_t> inStaticShape,
                         ArrayRef<ReassociationIndices> reassociation,
                         DenseMap<int64_t, int64_t> &outDimToInDimMap) {
  int64_t outDimSize = outStaticShape[outDimIndex];
  if (!ShapedType::isDynamic(outDimSize))
    return b.getIndexAttr(outDimSize);

  // Calculate the product of all the output dim sizes except the current one.
  int64_t inDimIndex = outDimToInDimMap[outDimIndex];
  int64_t otherDimSizesMul = 1;
  for (auto otherDimIndex : reassociation[inDimIndex]) {
    if (otherDimIndex == outDimIndex)
      continue;
    int64_t otherDimSize = outStaticShape[otherDimIndex];
    assert(!ShapedType::isDynamic(otherDimSize) &&
           "single dimension cannot be expanded into multiple dynamic "
           "dimensions");
    otherDimSizesMul *= otherDimSize;
  }

  // outDimSize = inDimSize / otherOutDimSizesMul
  int64_t inDimSize = inStaticShape[inDimIndex];
  Value inDimSizeDynamic =
      ShapedType::isDynamic(inDimSize)
          ? inDesc.size(b, loc, inDimIndex)
          : b.create<LLVM::ConstantOp>(loc, llvmIndexType,
                                       b.getIndexAttr(inDimSize));
  Value outDimSizeDynamic = b.create<LLVM::SDivOp>(
      loc, inDimSizeDynamic,
      b.create<LLVM::ConstantOp>(loc, llvmIndexType,
                                 b.getIndexAttr(otherDimSizesMul)));
  return outDimSizeDynamic;
}

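/// Computes the size of one dimension of the collapsed output as the product
/// of the sizes in its reassociation group. For example (illustrative):
/// collapsing memref<?x4xf32> into memref<?xf32> computes the output size as
/// inputSize0 * 4.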
getCollapsedOutputDimSize(OpBuilder & b,Location loc,Type & llvmIndexType,int64_t outDimIndex,int64_t outDimSize,ArrayRef<int64_t> inStaticShape,MemRefDescriptor & inDesc,ArrayRef<ReassociationIndices> reassocation)1068 static OpFoldResult getCollapsedOutputDimSize(
1069     OpBuilder &b, Location loc, Type &llvmIndexType, int64_t outDimIndex,
1070     int64_t outDimSize, ArrayRef<int64_t> inStaticShape,
1071     MemRefDescriptor &inDesc, ArrayRef<ReassociationIndices> reassocation) {
1072   if (!ShapedType::isDynamic(outDimSize))
1073     return b.getIndexAttr(outDimSize);
1074 
1075   Value c1 = b.create<LLVM::ConstantOp>(loc, llvmIndexType, b.getIndexAttr(1));
1076   Value outDimSizeDynamic = c1;
1077   for (auto inDimIndex : reassocation[outDimIndex]) {
1078     int64_t inDimSize = inStaticShape[inDimIndex];
1079     Value inDimSizeDynamic =
1080         ShapedType::isDynamic(inDimSize)
1081             ? inDesc.size(b, loc, inDimIndex)
1082             : b.create<LLVM::ConstantOp>(loc, llvmIndexType,
1083                                          b.getIndexAttr(inDimSize));
1084     outDimSizeDynamic =
1085         b.create<LLVM::MulOp>(loc, outDimSizeDynamic, inDimSizeDynamic);
1086   }
1087   return outDimSizeDynamic;
1088 }
1089 
1090 static SmallVector<OpFoldResult, 4>
getCollapsedOutputShape(OpBuilder & b,Location loc,Type & llvmIndexType,ArrayRef<ReassociationIndices> reassocation,ArrayRef<int64_t> inStaticShape,MemRefDescriptor & inDesc,ArrayRef<int64_t> outStaticShape)1091 getCollapsedOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType,
1092                         ArrayRef<ReassociationIndices> reassocation,
1093                         ArrayRef<int64_t> inStaticShape,
1094                         MemRefDescriptor &inDesc,
1095                         ArrayRef<int64_t> outStaticShape) {
1096   return llvm::to_vector<4>(llvm::map_range(
1097       llvm::seq<int64_t>(0, outStaticShape.size()), [&](int64_t outDimIndex) {
1098         return getCollapsedOutputDimSize(b, loc, llvmIndexType, outDimIndex,
1099                                          outStaticShape[outDimIndex],
1100                                          inStaticShape, inDesc, reassocation);
1101       }));
1102 }

static SmallVector<OpFoldResult, 4>
getExpandedOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType,
                       ArrayRef<ReassociationIndices> reassociation,
                       ArrayRef<int64_t> inStaticShape,
                       MemRefDescriptor &inDesc,
                       ArrayRef<int64_t> outStaticShape) {
  DenseMap<int64_t, int64_t> outDimToInDimMap =
      getExpandedDimToCollapsedDimMap(reassociation);
  return llvm::to_vector<4>(llvm::map_range(
      llvm::seq<int64_t>(0, outStaticShape.size()), [&](int64_t outDimIndex) {
        return getExpandedOutputDimSize(b, loc, llvmIndexType, outDimIndex,
                                        outStaticShape, inDesc, inStaticShape,
                                        reassociation, outDimToInDimMap);
      }));
}

static SmallVector<Value>
getDynamicOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType,
                      ArrayRef<ReassociationIndices> reassociation,
                      ArrayRef<int64_t> inStaticShape, MemRefDescriptor &inDesc,
                      ArrayRef<int64_t> outStaticShape) {
  return outStaticShape.size() < inStaticShape.size()
             ? getAsValues(b, loc, llvmIndexType,
                           getCollapsedOutputShape(b, loc, llvmIndexType,
                                                   reassociation, inStaticShape,
                                                   inDesc, outStaticShape))
             : getAsValues(b, loc, llvmIndexType,
                           getExpandedOutputShape(b, loc, llvmIndexType,
                                                  reassociation, inStaticShape,
                                                  inDesc, outStaticShape));
}

// ReshapeOp creates a new view descriptor of the proper rank.
// Only the identity (empty) layout map is supported for both source and
// target; output sizes may be static or dynamic, and dynamic strides are
// recomputed for a contiguous row-major buffer.
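//
// Illustrative sketch (simplified; the exact lowered IR depends on the type
// converter): lowering
//   %1 = memref.collapse_shape %0 [[0, 1]] : memref<4x?xf32> into memref<?xf32>
// reuses %0's allocated/aligned pointers and offset, sets the single output
// size to 4 * size(%0, 1), and sets the innermost stride to the constant 1.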
template <typename ReshapeOp>
class ReassociatingReshapeOpConversion
    : public ConvertOpToLLVMPattern<ReshapeOp> {
public:
  using ConvertOpToLLVMPattern<ReshapeOp>::ConvertOpToLLVMPattern;
  using ReshapeOpAdaptor = typename ReshapeOp::Adaptor;

  LogicalResult
  matchAndRewrite(ReshapeOp reshapeOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType dstType = reshapeOp.getResultType();
    MemRefType srcType = reshapeOp.getSrcType();
    if (!srcType.getAffineMaps().empty() || !dstType.getAffineMaps().empty()) {
      return rewriter.notifyMatchFailure(reshapeOp,
                                         "only empty layout map is supported");
    }

    int64_t offset;
    SmallVector<int64_t, 4> strides;
    if (failed(getStridesAndOffset(dstType, strides, offset))) {
      return rewriter.notifyMatchFailure(
          reshapeOp, "failed to get stride and offset exprs");
    }

    ReshapeOpAdaptor adaptor(operands);
    MemRefDescriptor srcDesc(adaptor.src());
    Location loc = reshapeOp->getLoc();
    auto dstDesc = MemRefDescriptor::undef(
        rewriter, loc, this->typeConverter->convertType(dstType));
    dstDesc.setAllocatedPtr(rewriter, loc, srcDesc.allocatedPtr(rewriter, loc));
    dstDesc.setAlignedPtr(rewriter, loc, srcDesc.alignedPtr(rewriter, loc));
    dstDesc.setOffset(rewriter, loc, srcDesc.offset(rewriter, loc));

    ArrayRef<int64_t> srcStaticShape = srcType.getShape();
    ArrayRef<int64_t> dstStaticShape = dstType.getShape();
    Type llvmIndexType =
        this->typeConverter->convertType(rewriter.getIndexType());
    SmallVector<Value> dstShape = getDynamicOutputShape(
        rewriter, loc, llvmIndexType, reshapeOp.getReassociationIndices(),
        srcStaticShape, srcDesc, dstStaticShape);
    for (auto &en : llvm::enumerate(dstShape))
      dstDesc.setSize(rewriter, loc, en.index(), en.value());

    auto isStaticStride = [](int64_t stride) {
      return !ShapedType::isDynamicStrideOrOffset(stride);
    };
    if (llvm::all_of(strides, isStaticStride)) {
      for (auto &en : llvm::enumerate(strides))
        dstDesc.setConstantStride(rewriter, loc, en.index(), en.value());
    } else {
      Value c1 = rewriter.create<LLVM::ConstantOp>(loc, llvmIndexType,
                                                   rewriter.getIndexAttr(1));
      Value stride = c1;
      for (auto dimIndex :
           llvm::reverse(llvm::seq<int64_t>(0, dstShape.size()))) {
        dstDesc.setStride(rewriter, loc, dimIndex, stride);
        stride = rewriter.create<LLVM::MulOp>(loc, dstShape[dimIndex], stride);
      }
    }
    rewriter.replaceOp(reshapeOp, {dstDesc});
    return success();
  }
};

/// Conversion pattern that transforms a subview op into:
///   1. An `llvm.mlir.undef` operation to create a memref descriptor
///   2. Updates to the descriptor to introduce the data ptr, offset, size
///      and stride.
/// The subview op is replaced by the descriptor.
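///
/// Illustrative sketch (simplified; the result layout map is elided): for
///   %1 = memref.subview %0[%i, %j][4, 4][1, 1]
///        : memref<8x8xf32> to memref<4x4xf32, #map>
/// the lowered descriptor reuses %0's pointers and computes the offset as
///   offset(%0) + %i * 8 + %j * 1,
/// while the size and stride fields become the constants {4, 4} and {8, 1}.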
struct SubViewOpLowering : public ConvertOpToLLVMPattern<memref::SubViewOp> {
  using ConvertOpToLLVMPattern<memref::SubViewOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::SubViewOp subViewOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = subViewOp.getLoc();

    auto sourceMemRefType = subViewOp.source().getType().cast<MemRefType>();
    auto sourceElementTy =
        typeConverter->convertType(sourceMemRefType.getElementType());

    auto viewMemRefType = subViewOp.getType();
    auto inferredType = memref::SubViewOp::inferResultType(
                            subViewOp.getSourceType(),
                            extractFromI64ArrayAttr(subViewOp.static_offsets()),
                            extractFromI64ArrayAttr(subViewOp.static_sizes()),
                            extractFromI64ArrayAttr(subViewOp.static_strides()))
                            .cast<MemRefType>();
    auto targetElementTy =
        typeConverter->convertType(viewMemRefType.getElementType());
    auto targetDescTy = typeConverter->convertType(viewMemRefType);
    if (!sourceElementTy || !targetDescTy || !targetElementTy ||
        !LLVM::isCompatibleType(sourceElementTy) ||
        !LLVM::isCompatibleType(targetElementTy) ||
        !LLVM::isCompatibleType(targetDescTy))
      return failure();

    // Extract the offset and strides from the type.
    int64_t offset;
    SmallVector<int64_t, 4> strides;
    auto successStrides = getStridesAndOffset(inferredType, strides, offset);
    if (failed(successStrides))
      return failure();

    // Create the descriptor.
    if (!LLVM::isCompatibleType(operands.front().getType()))
      return failure();
    MemRefDescriptor sourceMemRef(operands.front());
    auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);

    // Copy the buffer pointer from the old descriptor to the new one.
    Value extracted = sourceMemRef.allocatedPtr(rewriter, loc);
    Value bitcastPtr = rewriter.create<LLVM::BitcastOp>(
        loc,
        LLVM::LLVMPointerType::get(targetElementTy,
                                   viewMemRefType.getMemorySpaceAsInt()),
        extracted);
    targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr);

    // Copy the aligned pointer from the old descriptor to the new one.
    extracted = sourceMemRef.alignedPtr(rewriter, loc);
    bitcastPtr = rewriter.create<LLVM::BitcastOp>(
        loc,
        LLVM::LLVMPointerType::get(targetElementTy,
                                   viewMemRefType.getMemorySpaceAsInt()),
        extracted);
    targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);

    auto shape = viewMemRefType.getShape();
    auto inferredShape = inferredType.getShape();
    size_t inferredShapeRank = inferredShape.size();
    size_t resultShapeRank = shape.size();
    llvm::SmallDenseSet<unsigned> unusedDims =
        computeRankReductionMask(inferredShape, shape).getValue();

    // Extract strides needed to compute offset.
    SmallVector<Value, 4> strideValues;
    strideValues.reserve(inferredShapeRank);
    for (unsigned i = 0; i < inferredShapeRank; ++i)
      strideValues.push_back(sourceMemRef.stride(rewriter, loc, i));

    // Offset.
    auto llvmIndexType = typeConverter->convertType(rewriter.getIndexType());
    if (!ShapedType::isDynamicStrideOrOffset(offset)) {
      targetMemRef.setConstantOffset(rewriter, loc, offset);
    } else {
      Value baseOffset = sourceMemRef.offset(rewriter, loc);
      // `inferredShapeRank` may be larger than the number of offset operands
      // because of trailing semantics. In this case, the offset is guaranteed
      // to be interpreted as 0 and we can just skip the extra dimensions.
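      // The resulting offset is offset(src) + sum_i(offset_i * stride_i),
      // accumulated one dimension at a time by the loop below.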
      for (unsigned i = 0, e = std::min(inferredShapeRank,
                                        subViewOp.getMixedOffsets().size());
           i < e; ++i) {
        Value offset =
            // TODO: need OpFoldResult ODS adaptor to clean this up.
            subViewOp.isDynamicOffset(i)
                ? operands[subViewOp.getIndexOfDynamicOffset(i)]
                : rewriter.create<LLVM::ConstantOp>(
                      loc, llvmIndexType,
                      rewriter.getI64IntegerAttr(subViewOp.getStaticOffset(i)));
        Value mul = rewriter.create<LLVM::MulOp>(loc, offset, strideValues[i]);
        baseOffset = rewriter.create<LLVM::AddOp>(loc, baseOffset, mul);
      }
      targetMemRef.setOffset(rewriter, loc, baseOffset);
    }

    // Update sizes and strides.
    SmallVector<OpFoldResult> mixedSizes = subViewOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = subViewOp.getMixedStrides();
    assert(mixedSizes.size() == mixedStrides.size() &&
           "expected sizes and strides of equal length");
    for (int i = inferredShapeRank - 1, j = resultShapeRank - 1;
         i >= 0 && j >= 0; --i) {
      if (unusedDims.contains(i))
        continue;

      // `i` may exceed the number of entries in `mixedSizes` because of
      // trailing semantics. In this case, the size is guaranteed to be
      // interpreted as the full source dimension and the stride as 1.
      Value size, stride;
      if (static_cast<unsigned>(i) >= mixedSizes.size()) {
        // If the static size is available, use it directly. This is similar to
        // the folding of dim(constant-op) but removes the need for dim to be
        // aware of LLVM constants and for this pass to be aware of std
        // constants.
        int64_t staticSize =
            subViewOp.source().getType().cast<MemRefType>().getShape()[i];
        if (staticSize != ShapedType::kDynamicSize) {
          size = rewriter.create<LLVM::ConstantOp>(
              loc, llvmIndexType, rewriter.getI64IntegerAttr(staticSize));
        } else {
          Value pos = rewriter.create<LLVM::ConstantOp>(
              loc, llvmIndexType, rewriter.getI64IntegerAttr(i));
          Value dim =
              rewriter.create<memref::DimOp>(loc, subViewOp.source(), pos);
          auto cast = rewriter.create<UnrealizedConversionCastOp>(
              loc, llvmIndexType, dim);
          size = cast.getResult(0);
        }
        stride = rewriter.create<LLVM::ConstantOp>(
            loc, llvmIndexType, rewriter.getI64IntegerAttr(1));
      } else {
        // TODO: need OpFoldResult ODS adaptor to clean this up.
        size =
            subViewOp.isDynamicSize(i)
                ? operands[subViewOp.getIndexOfDynamicSize(i)]
                : rewriter.create<LLVM::ConstantOp>(
                      loc, llvmIndexType,
                      rewriter.getI64IntegerAttr(subViewOp.getStaticSize(i)));
        if (!ShapedType::isDynamicStrideOrOffset(strides[i])) {
          stride = rewriter.create<LLVM::ConstantOp>(
              loc, llvmIndexType, rewriter.getI64IntegerAttr(strides[i]));
        } else {
          stride = subViewOp.isDynamicStride(i)
                       ? operands[subViewOp.getIndexOfDynamicStride(i)]
                       : rewriter.create<LLVM::ConstantOp>(
                             loc, llvmIndexType,
                             rewriter.getI64IntegerAttr(
                                 subViewOp.getStaticStride(i)));
          stride = rewriter.create<LLVM::MulOp>(loc, stride, strideValues[i]);
        }
      }
      targetMemRef.setSize(rewriter, loc, j, size);
      targetMemRef.setStride(rewriter, loc, j, stride);
      j--;
    }

    rewriter.replaceOp(subViewOp, {targetMemRef});
    return success();
  }
};

/// Conversion pattern that transforms a transpose op into:
///   1. An `llvm.mlir.undef` operation to create a memref descriptor.
///   2. Updates to the descriptor to introduce the data ptr, offset, and the
///      permuted sizes and strides.
/// The transpose op is replaced by the descriptor.
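///
/// Illustrative sketch (simplified; the result layout map is elided): for
///   %1 = memref.transpose %0 (i, j) -> (j, i) : memref<?x?xf32> to ...
/// the lowering swaps the two size fields and the two stride fields of %0's
/// descriptor and copies the pointers and offset unchanged.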
class TransposeOpLowering : public ConvertOpToLLVMPattern<memref::TransposeOp> {
public:
  using ConvertOpToLLVMPattern<memref::TransposeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::TransposeOp transposeOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = transposeOp.getLoc();
    memref::TransposeOpAdaptor adaptor(operands);
    MemRefDescriptor viewMemRef(adaptor.in());

    // No permutation, early exit.
    if (transposeOp.permutation().isIdentity())
      return rewriter.replaceOp(transposeOp, {viewMemRef}), success();

    auto targetMemRef = MemRefDescriptor::undef(
        rewriter, loc, typeConverter->convertType(transposeOp.getShapedType()));

    // Copy the base and aligned pointers from the old descriptor to the new
    // one.
    targetMemRef.setAllocatedPtr(rewriter, loc,
                                 viewMemRef.allocatedPtr(rewriter, loc));
    targetMemRef.setAlignedPtr(rewriter, loc,
                               viewMemRef.alignedPtr(rewriter, loc));

    // Copy the offset from the old descriptor to the new one.
    targetMemRef.setOffset(rewriter, loc, viewMemRef.offset(rewriter, loc));

    // Iterate over the dimensions and apply size/stride permutation.
    for (auto en : llvm::enumerate(transposeOp.permutation().getResults())) {
      int sourcePos = en.index();
      int targetPos = en.value().cast<AffineDimExpr>().getPosition();
      targetMemRef.setSize(rewriter, loc, targetPos,
                           viewMemRef.size(rewriter, loc, sourcePos));
      targetMemRef.setStride(rewriter, loc, targetPos,
                             viewMemRef.stride(rewriter, loc, sourcePos));
    }

    rewriter.replaceOp(transposeOp, {targetMemRef});
    return success();
  }
};

/// Conversion pattern that transforms a view op into:
///   1. An `llvm.mlir.undef` operation to create a memref descriptor
///   2. Updates to the descriptor to introduce the data ptr, offset, size
///      and stride.
/// The view op is replaced by the descriptor.
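///
/// Illustrative sketch (simplified): viewing a byte buffer as a 2-D f32
/// memref,
///   %1 = memref.view %0[%shift][%m] : memref<?xi8> to memref<?x4xf32>
/// bitcasts the source pointers to the f32 element pointer type, advances the
/// aligned pointer by %shift bytes with a GEP, zeroes the offset, and sets
/// sizes {%m, 4} with strides {4, 1}.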
struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {
  using ConvertOpToLLVMPattern<memref::ViewOp>::ConvertOpToLLVMPattern;

  // Build and return the value for the idx^th shape dimension, either by
  // returning the constant shape dimension or counting the proper dynamic size.
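  // For example, with shape {?, 4, ?} and idx == 2 there is one dynamic
  // dimension before index 2, so dynamicSizes[1] is returned.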
  Value getSize(ConversionPatternRewriter &rewriter, Location loc,
                ArrayRef<int64_t> shape, ValueRange dynamicSizes,
                unsigned idx) const {
    assert(idx < shape.size());
    if (!ShapedType::isDynamic(shape[idx]))
      return createIndexConstant(rewriter, loc, shape[idx]);
    // Count the number of dynamic dims in range [0, idx).
    unsigned nDynamic = llvm::count_if(shape.take_front(idx), [](int64_t v) {
      return ShapedType::isDynamic(v);
    });
    return dynamicSizes[nDynamic];
  }

  // Build and return the idx^th stride, either by returning the constant stride
  // or by computing the dynamic stride from the current `runningStride` and
  // `nextSize`. The caller should keep a running stride and update it with the
  // result returned by this function.
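  //
  // Example: for a contiguous memref<?x?x8xf32> the strides are [?, 8, 1];
  // walking from the innermost dimension outwards yields the constants 1 and
  // 8, and the outermost stride is computed dynamically as 8 * size(dim 1).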
  Value getStride(ConversionPatternRewriter &rewriter, Location loc,
                  ArrayRef<int64_t> strides, Value nextSize,
                  Value runningStride, unsigned idx) const {
    assert(idx < strides.size());
    if (!MemRefType::isDynamicStrideOrOffset(strides[idx]))
      return createIndexConstant(rewriter, loc, strides[idx]);
    if (nextSize)
      return runningStride
                 ? rewriter.create<LLVM::MulOp>(loc, runningStride, nextSize)
                 : nextSize;
    assert(!runningStride);
    return createIndexConstant(rewriter, loc, 1);
  }

  LogicalResult
  matchAndRewrite(memref::ViewOp viewOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = viewOp.getLoc();
    memref::ViewOpAdaptor adaptor(operands);

    auto viewMemRefType = viewOp.getType();
    auto targetElementTy =
        typeConverter->convertType(viewMemRefType.getElementType());
    auto targetDescTy = typeConverter->convertType(viewMemRefType);
    if (!targetDescTy || !targetElementTy ||
        !LLVM::isCompatibleType(targetElementTy) ||
        !LLVM::isCompatibleType(targetDescTy))
      return viewOp.emitWarning("Target descriptor type not converted to LLVM"),
             failure();

    int64_t offset;
    SmallVector<int64_t, 4> strides;
    auto successStrides = getStridesAndOffset(viewMemRefType, strides, offset);
    if (failed(successStrides))
      return viewOp.emitWarning("cannot cast to non-strided shape"), failure();
    assert(offset == 0 && "expected offset to be 0");

    // Create the descriptor.
    MemRefDescriptor sourceMemRef(adaptor.source());
    auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);

    // Field 1: Copy the allocated pointer, used for malloc/free.
    Value allocatedPtr = sourceMemRef.allocatedPtr(rewriter, loc);
    auto srcMemRefType = viewOp.source().getType().cast<MemRefType>();
    Value bitcastPtr = rewriter.create<LLVM::BitcastOp>(
        loc,
        LLVM::LLVMPointerType::get(targetElementTy,
                                   srcMemRefType.getMemorySpaceAsInt()),
        allocatedPtr);
    targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr);

    // Field 2: Copy the actual aligned pointer to payload.
    Value alignedPtr = sourceMemRef.alignedPtr(rewriter, loc);
    alignedPtr = rewriter.create<LLVM::GEPOp>(loc, alignedPtr.getType(),
                                              alignedPtr, adaptor.byte_shift());
    bitcastPtr = rewriter.create<LLVM::BitcastOp>(
        loc,
        LLVM::LLVMPointerType::get(targetElementTy,
                                   srcMemRefType.getMemorySpaceAsInt()),
        alignedPtr);
    targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);

    // Field 3: The offset in the resulting type must be 0. This is because of
    // the type change: an offset on srcType* may not be expressible as an
    // offset on dstType*.
    targetMemRef.setOffset(rewriter, loc,
                           createIndexConstant(rewriter, loc, offset));

    // Early exit for 0-D corner case.
    if (viewMemRefType.getRank() == 0)
      return rewriter.replaceOp(viewOp, {targetMemRef}), success();

    // Fields 4 and 5: Update sizes and strides.
    if (strides.back() != 1)
      return viewOp.emitWarning("cannot cast to non-contiguous shape"),
             failure();
    Value stride = nullptr, nextSize = nullptr;
    for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) {
      // Update size.
      Value size =
          getSize(rewriter, loc, viewMemRefType.getShape(), adaptor.sizes(), i);
      targetMemRef.setSize(rewriter, loc, i, size);
      // Update stride.
      stride = getStride(rewriter, loc, strides, nextSize, stride, i);
      targetMemRef.setStride(rewriter, loc, i, stride);
      nextSize = size;
    }

    rewriter.replaceOp(viewOp, {targetMemRef});
    return success();
  }
};

} // namespace

void mlir::populateMemRefToLLVMConversionPatterns(LLVMTypeConverter &converter,
                                                  RewritePatternSet &patterns) {
  // clang-format off
  patterns.add<
      AllocaOpLowering,
      AllocaScopeOpLowering,
      AssumeAlignmentOpLowering,
      DimOpLowering,
      DeallocOpLowering,
      GlobalMemrefOpLowering,
      GetGlobalMemrefOpLowering,
      LoadOpLowering,
      MemRefCastOpLowering,
      MemRefCopyOpLowering,
      MemRefReinterpretCastOpLowering,
      MemRefReshapeOpLowering,
      PrefetchOpLowering,
      ReassociatingReshapeOpConversion<memref::ExpandShapeOp>,
      ReassociatingReshapeOpConversion<memref::CollapseShapeOp>,
      StoreOpLowering,
      SubViewOpLowering,
      TransposeOpLowering,
      ViewOpLowering>(converter);
  // clang-format on
  auto allocLowering = converter.getOptions().allocLowering;
  if (allocLowering == LowerToLLVMOptions::AllocLowering::AlignedAlloc)
    patterns.add<AlignedAllocOpLowering, DeallocOpLowering>(converter);
  else if (allocLowering == LowerToLLVMOptions::AllocLowering::Malloc)
    patterns.add<AllocOpLowering, DeallocOpLowering>(converter);
}

namespace {
struct MemRefToLLVMPass : public ConvertMemRefToLLVMBase<MemRefToLLVMPass> {
  MemRefToLLVMPass() = default;

  void runOnOperation() override {
    Operation *op = getOperation();
    const auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
    LowerToLLVMOptions options(&getContext(),
                               dataLayoutAnalysis.getAtOrAbove(op));
    options.allocLowering =
        (useAlignedAlloc ? LowerToLLVMOptions::AllocLowering::AlignedAlloc
                         : LowerToLLVMOptions::AllocLowering::Malloc);
    if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
      options.overrideIndexBitwidth(indexBitwidth);

    LLVMTypeConverter typeConverter(&getContext(), options,
                                    &dataLayoutAnalysis);
    RewritePatternSet patterns(&getContext());
    populateMemRefToLLVMConversionPatterns(typeConverter, patterns);
    LLVMConversionTarget target(getContext());
    target.addLegalOp<FuncOp>();
    if (failed(applyPartialConversion(op, target, std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace

std::unique_ptr<Pass> mlir::createMemRefToLLVMPass() {
  return std::make_unique<MemRefToLLVMPass>();
}