1 //===- MemRefToLLVM.cpp - MemRef to LLVM dialect conversion ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
10 #include "../PassDetail.h"
11 #include "mlir/Analysis/DataLayoutAnalysis.h"
12 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
13 #include "mlir/Conversion/LLVMCommon/Pattern.h"
14 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
15 #include "mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h"
16 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
17 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
18 #include "mlir/Dialect/MemRef/IR/MemRef.h"
19 #include "mlir/IR/AffineMap.h"
20 #include "mlir/IR/BlockAndValueMapping.h"
21
22 using namespace mlir;
23
24 namespace {
25
26 struct AllocOpLowering : public AllocLikeOpLLVMLowering {
AllocOpLowering__anona8fa999b0111::AllocOpLowering27 AllocOpLowering(LLVMTypeConverter &converter)
28 : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
29 converter) {}
30
allocateBuffer__anona8fa999b0111::AllocOpLowering31 std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
32 Location loc, Value sizeBytes,
33 Operation *op) const override {
34 // Heap allocations.
35 memref::AllocOp allocOp = cast<memref::AllocOp>(op);
36 MemRefType memRefType = allocOp.getType();
37
38 Value alignment;
39 if (auto alignmentAttr = allocOp.alignment()) {
40 alignment = createIndexConstant(rewriter, loc, *alignmentAttr);
41 } else if (!memRefType.getElementType().isSignlessIntOrIndexOrFloat()) {
42 // In the case where no alignment is specified, we may want to override
43 // `malloc's` behavior. `malloc` typically aligns at the size of the
44 // biggest scalar on a target HW. For non-scalars, use the natural
45 // alignment of the LLVM type given by the LLVM DataLayout.
46 alignment = getSizeInBytes(loc, memRefType.getElementType(), rewriter);
47 }
48
49 if (alignment) {
50 // Adjust the allocation size to consider alignment.
51 sizeBytes = rewriter.create<LLVM::AddOp>(loc, sizeBytes, alignment);
52 }
53
54 // Allocate the underlying buffer and store a pointer to it in the MemRef
55 // descriptor.
56 Type elementPtrType = this->getElementPtrType(memRefType);
57 auto allocFuncOp = LLVM::lookupOrCreateMallocFn(
58 allocOp->getParentOfType<ModuleOp>(), getIndexType());
59 auto results = createLLVMCall(rewriter, loc, allocFuncOp, {sizeBytes},
60 getVoidPtrType());
61 Value allocatedPtr =
62 rewriter.create<LLVM::BitcastOp>(loc, elementPtrType, results[0]);
63
64 Value alignedPtr = allocatedPtr;
65 if (alignment) {
66 // Compute the aligned type pointer.
67 Value allocatedInt =
68 rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), allocatedPtr);
69 Value alignmentInt =
70 createAligned(rewriter, loc, allocatedInt, alignment);
71 alignedPtr =
72 rewriter.create<LLVM::IntToPtrOp>(loc, elementPtrType, alignmentInt);
73 }
74
75 return std::make_tuple(allocatedPtr, alignedPtr);
76 }
77 };
78
79 struct AlignedAllocOpLowering : public AllocLikeOpLLVMLowering {
AlignedAllocOpLowering__anona8fa999b0111::AlignedAllocOpLowering80 AlignedAllocOpLowering(LLVMTypeConverter &converter)
81 : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
82 converter) {}
83
84 /// Returns the memref's element size in bytes using the data layout active at
85 /// `op`.
86 // TODO: there are other places where this is used. Expose publicly?
getMemRefEltSizeInBytes__anona8fa999b0111::AlignedAllocOpLowering87 unsigned getMemRefEltSizeInBytes(MemRefType memRefType, Operation *op) const {
88 const DataLayout *layout = &defaultLayout;
89 if (const DataLayoutAnalysis *analysis =
90 getTypeConverter()->getDataLayoutAnalysis()) {
91 layout = &analysis->getAbove(op);
92 }
93 Type elementType = memRefType.getElementType();
94 if (auto memRefElementType = elementType.dyn_cast<MemRefType>())
95 return getTypeConverter()->getMemRefDescriptorSize(memRefElementType,
96 *layout);
97 if (auto memRefElementType = elementType.dyn_cast<UnrankedMemRefType>())
98 return getTypeConverter()->getUnrankedMemRefDescriptorSize(
99 memRefElementType, *layout);
100 return layout->getTypeSize(elementType);
101 }
102
103 /// Returns true if the memref size in bytes is known to be a multiple of
104 /// factor assuming the data layout active at `op`.
isMemRefSizeMultipleOf__anona8fa999b0111::AlignedAllocOpLowering105 bool isMemRefSizeMultipleOf(MemRefType type, uint64_t factor,
106 Operation *op) const {
107 uint64_t sizeDivisor = getMemRefEltSizeInBytes(type, op);
108 for (unsigned i = 0, e = type.getRank(); i < e; i++) {
109 if (type.isDynamic(type.getDimSize(i)))
110 continue;
111 sizeDivisor = sizeDivisor * type.getDimSize(i);
112 }
113 return sizeDivisor % factor == 0;
114 }
115
116 /// Returns the alignment to be used for the allocation call itself.
117 /// aligned_alloc requires the allocation size to be a power of two, and the
118 /// allocation size to be a multiple of alignment,
getAllocationAlignment__anona8fa999b0111::AlignedAllocOpLowering119 int64_t getAllocationAlignment(memref::AllocOp allocOp) const {
120 if (Optional<uint64_t> alignment = allocOp.alignment())
121 return *alignment;
122
123 // Whenever we don't have alignment set, we will use an alignment
124 // consistent with the element type; since the allocation size has to be a
125 // power of two, we will bump to the next power of two if it already isn't.
126 auto eltSizeBytes = getMemRefEltSizeInBytes(allocOp.getType(), allocOp);
127 return std::max(kMinAlignedAllocAlignment,
128 llvm::PowerOf2Ceil(eltSizeBytes));
129 }
130
allocateBuffer__anona8fa999b0111::AlignedAllocOpLowering131 std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
132 Location loc, Value sizeBytes,
133 Operation *op) const override {
134 // Heap allocations.
135 memref::AllocOp allocOp = cast<memref::AllocOp>(op);
136 MemRefType memRefType = allocOp.getType();
137 int64_t alignment = getAllocationAlignment(allocOp);
138 Value allocAlignment = createIndexConstant(rewriter, loc, alignment);
139
140 // aligned_alloc requires size to be a multiple of alignment; we will pad
141 // the size to the next multiple if necessary.
142 if (!isMemRefSizeMultipleOf(memRefType, alignment, op))
143 sizeBytes = createAligned(rewriter, loc, sizeBytes, allocAlignment);
144
145 Type elementPtrType = this->getElementPtrType(memRefType);
146 auto allocFuncOp = LLVM::lookupOrCreateAlignedAllocFn(
147 allocOp->getParentOfType<ModuleOp>(), getIndexType());
148 auto results =
149 createLLVMCall(rewriter, loc, allocFuncOp, {allocAlignment, sizeBytes},
150 getVoidPtrType());
151 Value allocatedPtr =
152 rewriter.create<LLVM::BitcastOp>(loc, elementPtrType, results[0]);
153
154 return std::make_tuple(allocatedPtr, allocatedPtr);
155 }
156
157 /// The minimum alignment to use with aligned_alloc (has to be a power of 2).
158 static constexpr uint64_t kMinAlignedAllocAlignment = 16UL;
159
160 /// Default layout to use in absence of the corresponding analysis.
161 DataLayout defaultLayout;
162 };
163
// Out-of-line definition of the static member; required until C++17, where
// `static constexpr` members become implicitly inline.
constexpr uint64_t AlignedAllocOpLowering::kMinAlignedAllocAlignment;
166
167 struct AllocaOpLowering : public AllocLikeOpLLVMLowering {
AllocaOpLowering__anona8fa999b0111::AllocaOpLowering168 AllocaOpLowering(LLVMTypeConverter &converter)
169 : AllocLikeOpLLVMLowering(memref::AllocaOp::getOperationName(),
170 converter) {}
171
172 /// Allocates the underlying buffer using the right call. `allocatedBytePtr`
173 /// is set to null for stack allocations. `accessAlignment` is set if
174 /// alignment is needed post allocation (for eg. in conjunction with malloc).
allocateBuffer__anona8fa999b0111::AllocaOpLowering175 std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
176 Location loc, Value sizeBytes,
177 Operation *op) const override {
178
179 // With alloca, one gets a pointer to the element type right away.
180 // For stack allocations.
181 auto allocaOp = cast<memref::AllocaOp>(op);
182 auto elementPtrType = this->getElementPtrType(allocaOp.getType());
183
184 auto allocatedElementPtr = rewriter.create<LLVM::AllocaOp>(
185 loc, elementPtrType, sizeBytes,
186 allocaOp.alignment() ? *allocaOp.alignment() : 0);
187
188 return std::make_tuple(allocatedElementPtr, allocatedElementPtr);
189 }
190 };
191
/// Lowers memref.alloca_scope by inlining its body between an
/// llvm.intr.stacksave / llvm.intr.stackrestore pair, so that allocas emitted
/// inside the scope are released when control leaves the region.
struct AllocaScopeOpLowering
    : public ConvertOpToLLVMPattern<memref::AllocaScopeOp> {
  using ConvertOpToLLVMPattern<memref::AllocaScopeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::AllocaScopeOp allocaScopeOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    OpBuilder::InsertionGuard guard(rewriter);
    Location loc = allocaScopeOp.getLoc();

    // Split the current block before the AllocaScopeOp to create the inlining
    // point.
    auto *currentBlock = rewriter.getInsertionBlock();
    auto *remainingOpsBlock =
        rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
    // When the scope produces results, a dedicated continuation block carrying
    // them as block arguments is created; otherwise the split block suffices.
    Block *continueBlock;
    if (allocaScopeOp.getNumResults() == 0) {
      continueBlock = remainingOpsBlock;
    } else {
      continueBlock = rewriter.createBlock(remainingOpsBlock,
                                           allocaScopeOp.getResultTypes());
      rewriter.create<LLVM::BrOp>(loc, ValueRange(), remainingOpsBlock);
    }

    // Inline body region. Capture its entry/exit blocks first, since the
    // region is emptied by inlineRegionBefore.
    Block *beforeBody = &allocaScopeOp.bodyRegion().front();
    Block *afterBody = &allocaScopeOp.bodyRegion().back();
    rewriter.inlineRegionBefore(allocaScopeOp.bodyRegion(), continueBlock);

    // Save stack and then branch into the body of the region.
    rewriter.setInsertionPointToEnd(currentBlock);
    auto stackSaveOp =
        rewriter.create<LLVM::StackSaveOp>(loc, getVoidPtrType());
    rewriter.create<LLVM::BrOp>(loc, ValueRange(), beforeBody);

    // Replace the alloca_scope return with a branch that jumps out of the body.
    // Stack restore before leaving the body region.
    rewriter.setInsertionPointToEnd(afterBody);
    auto returnOp =
        cast<memref::AllocaScopeReturnOp>(afterBody->getTerminator());
    auto branchOp = rewriter.replaceOpWithNewOp<LLVM::BrOp>(
        returnOp, returnOp.results(), continueBlock);

    // Insert stack restore before jumping out the body of the region.
    rewriter.setInsertionPoint(branchOp);
    rewriter.create<LLVM::StackRestoreOp>(loc, stackSaveOp);

    // Replace the op with values return from the body region.
    rewriter.replaceOp(allocaScopeOp, continueBlock->getArguments());

    return success();
  }
};
245
246 struct AssumeAlignmentOpLowering
247 : public ConvertOpToLLVMPattern<memref::AssumeAlignmentOp> {
248 using ConvertOpToLLVMPattern<
249 memref::AssumeAlignmentOp>::ConvertOpToLLVMPattern;
250
251 LogicalResult
matchAndRewrite__anona8fa999b0111::AssumeAlignmentOpLowering252 matchAndRewrite(memref::AssumeAlignmentOp op, ArrayRef<Value> operands,
253 ConversionPatternRewriter &rewriter) const override {
254 memref::AssumeAlignmentOp::Adaptor transformed(operands);
255 Value memref = transformed.memref();
256 unsigned alignment = op.alignment();
257 auto loc = op.getLoc();
258
259 MemRefDescriptor memRefDescriptor(memref);
260 Value ptr = memRefDescriptor.alignedPtr(rewriter, memref.getLoc());
261
262 // Emit llvm.assume(memref.alignedPtr & (alignment - 1) == 0). Notice that
263 // the asserted memref.alignedPtr isn't used anywhere else, as the real
264 // users like load/store/views always re-extract memref.alignedPtr as they
265 // get lowered.
266 //
267 // This relies on LLVM's CSE optimization (potentially after SROA), since
268 // after CSE all memref.alignedPtr instances get de-duplicated into the same
269 // pointer SSA value.
270 auto intPtrType =
271 getIntPtrType(memRefDescriptor.getElementPtrType().getAddressSpace());
272 Value zero = createIndexAttrConstant(rewriter, loc, intPtrType, 0);
273 Value mask =
274 createIndexAttrConstant(rewriter, loc, intPtrType, alignment - 1);
275 Value ptrValue = rewriter.create<LLVM::PtrToIntOp>(loc, intPtrType, ptr);
276 rewriter.create<LLVM::AssumeOp>(
277 loc, rewriter.create<LLVM::ICmpOp>(
278 loc, LLVM::ICmpPredicate::eq,
279 rewriter.create<LLVM::AndOp>(loc, ptrValue, mask), zero));
280
281 rewriter.eraseOp(op);
282 return success();
283 }
284 };
285
286 // A `dealloc` is converted into a call to `free` on the underlying data buffer.
287 // The memref descriptor being an SSA value, there is no need to clean it up
288 // in any way.
289 struct DeallocOpLowering : public ConvertOpToLLVMPattern<memref::DeallocOp> {
290 using ConvertOpToLLVMPattern<memref::DeallocOp>::ConvertOpToLLVMPattern;
291
DeallocOpLowering__anona8fa999b0111::DeallocOpLowering292 explicit DeallocOpLowering(LLVMTypeConverter &converter)
293 : ConvertOpToLLVMPattern<memref::DeallocOp>(converter) {}
294
295 LogicalResult
matchAndRewrite__anona8fa999b0111::DeallocOpLowering296 matchAndRewrite(memref::DeallocOp op, ArrayRef<Value> operands,
297 ConversionPatternRewriter &rewriter) const override {
298 assert(operands.size() == 1 && "dealloc takes one operand");
299 memref::DeallocOp::Adaptor transformed(operands);
300
301 // Insert the `free` declaration if it is not already present.
302 auto freeFunc = LLVM::lookupOrCreateFreeFn(op->getParentOfType<ModuleOp>());
303 MemRefDescriptor memref(transformed.memref());
304 Value casted = rewriter.create<LLVM::BitcastOp>(
305 op.getLoc(), getVoidPtrType(),
306 memref.allocatedPtr(rewriter, op.getLoc()));
307 rewriter.replaceOpWithNewOp<LLVM::CallOp>(
308 op, TypeRange(), rewriter.getSymbolRefAttr(freeFunc), casted);
309 return success();
310 }
311 };
312
313 // A `dim` is converted to a constant for static sizes and to an access to the
314 // size stored in the memref descriptor for dynamic sizes.
315 struct DimOpLowering : public ConvertOpToLLVMPattern<memref::DimOp> {
316 using ConvertOpToLLVMPattern<memref::DimOp>::ConvertOpToLLVMPattern;
317
318 LogicalResult
matchAndRewrite__anona8fa999b0111::DimOpLowering319 matchAndRewrite(memref::DimOp dimOp, ArrayRef<Value> operands,
320 ConversionPatternRewriter &rewriter) const override {
321 Type operandType = dimOp.source().getType();
322 if (operandType.isa<UnrankedMemRefType>()) {
323 rewriter.replaceOp(dimOp, {extractSizeOfUnrankedMemRef(
324 operandType, dimOp, operands, rewriter)});
325
326 return success();
327 }
328 if (operandType.isa<MemRefType>()) {
329 rewriter.replaceOp(dimOp, {extractSizeOfRankedMemRef(
330 operandType, dimOp, operands, rewriter)});
331 return success();
332 }
333 llvm_unreachable("expected MemRefType or UnrankedMemRefType");
334 }
335
336 private:
extractSizeOfUnrankedMemRef__anona8fa999b0111::DimOpLowering337 Value extractSizeOfUnrankedMemRef(Type operandType, memref::DimOp dimOp,
338 ArrayRef<Value> operands,
339 ConversionPatternRewriter &rewriter) const {
340 Location loc = dimOp.getLoc();
341 memref::DimOp::Adaptor transformed(operands);
342
343 auto unrankedMemRefType = operandType.cast<UnrankedMemRefType>();
344 auto scalarMemRefType =
345 MemRefType::get({}, unrankedMemRefType.getElementType());
346 unsigned addressSpace = unrankedMemRefType.getMemorySpaceAsInt();
347
348 // Extract pointer to the underlying ranked descriptor and bitcast it to a
349 // memref<element_type> descriptor pointer to minimize the number of GEP
350 // operations.
351 UnrankedMemRefDescriptor unrankedDesc(transformed.source());
352 Value underlyingRankedDesc = unrankedDesc.memRefDescPtr(rewriter, loc);
353 Value scalarMemRefDescPtr = rewriter.create<LLVM::BitcastOp>(
354 loc,
355 LLVM::LLVMPointerType::get(typeConverter->convertType(scalarMemRefType),
356 addressSpace),
357 underlyingRankedDesc);
358
359 // Get pointer to offset field of memref<element_type> descriptor.
360 Type indexPtrTy = LLVM::LLVMPointerType::get(
361 getTypeConverter()->getIndexType(), addressSpace);
362 Value two = rewriter.create<LLVM::ConstantOp>(
363 loc, typeConverter->convertType(rewriter.getI32Type()),
364 rewriter.getI32IntegerAttr(2));
365 Value offsetPtr = rewriter.create<LLVM::GEPOp>(
366 loc, indexPtrTy, scalarMemRefDescPtr,
367 ValueRange({createIndexConstant(rewriter, loc, 0), two}));
368
369 // The size value that we have to extract can be obtained using GEPop with
370 // `dimOp.index() + 1` index argument.
371 Value idxPlusOne = rewriter.create<LLVM::AddOp>(
372 loc, createIndexConstant(rewriter, loc, 1), transformed.index());
373 Value sizePtr = rewriter.create<LLVM::GEPOp>(loc, indexPtrTy, offsetPtr,
374 ValueRange({idxPlusOne}));
375 return rewriter.create<LLVM::LoadOp>(loc, sizePtr);
376 }
377
getConstantDimIndex__anona8fa999b0111::DimOpLowering378 Optional<int64_t> getConstantDimIndex(memref::DimOp dimOp) const {
379 if (Optional<int64_t> idx = dimOp.getConstantIndex())
380 return idx;
381
382 if (auto constantOp = dimOp.index().getDefiningOp<LLVM::ConstantOp>())
383 return constantOp.value().cast<IntegerAttr>().getValue().getSExtValue();
384
385 return llvm::None;
386 }
387
extractSizeOfRankedMemRef__anona8fa999b0111::DimOpLowering388 Value extractSizeOfRankedMemRef(Type operandType, memref::DimOp dimOp,
389 ArrayRef<Value> operands,
390 ConversionPatternRewriter &rewriter) const {
391 Location loc = dimOp.getLoc();
392 memref::DimOp::Adaptor transformed(operands);
393 // Take advantage if index is constant.
394 MemRefType memRefType = operandType.cast<MemRefType>();
395 if (Optional<int64_t> index = getConstantDimIndex(dimOp)) {
396 int64_t i = index.getValue();
397 if (memRefType.isDynamicDim(i)) {
398 // extract dynamic size from the memref descriptor.
399 MemRefDescriptor descriptor(transformed.source());
400 return descriptor.size(rewriter, loc, i);
401 }
402 // Use constant for static size.
403 int64_t dimSize = memRefType.getDimSize(i);
404 return createIndexConstant(rewriter, loc, dimSize);
405 }
406 Value index = transformed.index();
407 int64_t rank = memRefType.getRank();
408 MemRefDescriptor memrefDescriptor(transformed.source());
409 return memrefDescriptor.size(rewriter, loc, index, rank);
410 }
411 };
412
413 /// Returns the LLVM type of the global variable given the memref type `type`.
convertGlobalMemrefTypeToLLVM(MemRefType type,LLVMTypeConverter & typeConverter)414 static Type convertGlobalMemrefTypeToLLVM(MemRefType type,
415 LLVMTypeConverter &typeConverter) {
416 // LLVM type for a global memref will be a multi-dimension array. For
417 // declarations or uninitialized global memrefs, we can potentially flatten
418 // this to a 1D array. However, for memref.global's with an initial value,
419 // we do not intend to flatten the ElementsAttribute when going from std ->
420 // LLVM dialect, so the LLVM type needs to me a multi-dimension array.
421 Type elementType = typeConverter.convertType(type.getElementType());
422 Type arrayTy = elementType;
423 // Shape has the outermost dim at index 0, so need to walk it backwards
424 for (int64_t dim : llvm::reverse(type.getShape()))
425 arrayTy = LLVM::LLVMArrayType::get(arrayTy, dim);
426 return arrayTy;
427 }
428
429 /// GlobalMemrefOp is lowered to a LLVM Global Variable.
430 struct GlobalMemrefOpLowering
431 : public ConvertOpToLLVMPattern<memref::GlobalOp> {
432 using ConvertOpToLLVMPattern<memref::GlobalOp>::ConvertOpToLLVMPattern;
433
434 LogicalResult
matchAndRewrite__anona8fa999b0111::GlobalMemrefOpLowering435 matchAndRewrite(memref::GlobalOp global, ArrayRef<Value> operands,
436 ConversionPatternRewriter &rewriter) const override {
437 MemRefType type = global.type().cast<MemRefType>();
438 if (!isConvertibleAndHasIdentityMaps(type))
439 return failure();
440
441 Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
442
443 LLVM::Linkage linkage =
444 global.isPublic() ? LLVM::Linkage::External : LLVM::Linkage::Private;
445
446 Attribute initialValue = nullptr;
447 if (!global.isExternal() && !global.isUninitialized()) {
448 auto elementsAttr = global.initial_value()->cast<ElementsAttr>();
449 initialValue = elementsAttr;
450
451 // For scalar memrefs, the global variable created is of the element type,
452 // so unpack the elements attribute to extract the value.
453 if (type.getRank() == 0)
454 initialValue = elementsAttr.getValue({});
455 }
456
457 rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
458 global, arrayTy, global.constant(), linkage, global.sym_name(),
459 initialValue, /*alignment=*/0, type.getMemorySpaceAsInt());
460 return success();
461 }
462 };
463
464 /// GetGlobalMemrefOp is lowered into a Memref descriptor with the pointer to
465 /// the first element stashed into the descriptor. This reuses
466 /// `AllocLikeOpLowering` to reuse the Memref descriptor construction.
467 struct GetGlobalMemrefOpLowering : public AllocLikeOpLLVMLowering {
GetGlobalMemrefOpLowering__anona8fa999b0111::GetGlobalMemrefOpLowering468 GetGlobalMemrefOpLowering(LLVMTypeConverter &converter)
469 : AllocLikeOpLLVMLowering(memref::GetGlobalOp::getOperationName(),
470 converter) {}
471
472 /// Buffer "allocation" for memref.get_global op is getting the address of
473 /// the global variable referenced.
allocateBuffer__anona8fa999b0111::GetGlobalMemrefOpLowering474 std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
475 Location loc, Value sizeBytes,
476 Operation *op) const override {
477 auto getGlobalOp = cast<memref::GetGlobalOp>(op);
478 MemRefType type = getGlobalOp.result().getType().cast<MemRefType>();
479 unsigned memSpace = type.getMemorySpaceAsInt();
480
481 Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
482 auto addressOf = rewriter.create<LLVM::AddressOfOp>(
483 loc, LLVM::LLVMPointerType::get(arrayTy, memSpace), getGlobalOp.name());
484
485 // Get the address of the first element in the array by creating a GEP with
486 // the address of the GV as the base, and (rank + 1) number of 0 indices.
487 Type elementType = typeConverter->convertType(type.getElementType());
488 Type elementPtrType = LLVM::LLVMPointerType::get(elementType, memSpace);
489
490 SmallVector<Value, 4> operands = {addressOf};
491 operands.insert(operands.end(), type.getRank() + 1,
492 createIndexConstant(rewriter, loc, 0));
493 auto gep = rewriter.create<LLVM::GEPOp>(loc, elementPtrType, operands);
494
495 // We do not expect the memref obtained using `memref.get_global` to be
496 // ever deallocated. Set the allocated pointer to be known bad value to
497 // help debug if that ever happens.
498 auto intPtrType = getIntPtrType(memSpace);
499 Value deadBeefConst =
500 createIndexAttrConstant(rewriter, op->getLoc(), intPtrType, 0xdeadbeef);
501 auto deadBeefPtr =
502 rewriter.create<LLVM::IntToPtrOp>(loc, elementPtrType, deadBeefConst);
503
504 // Both allocated and aligned pointers are same. We could potentially stash
505 // a nullptr for the allocated pointer since we do not expect any dealloc.
506 return std::make_tuple(deadBeefPtr, gep);
507 }
508 };
509
510 // Common base for load and store operations on MemRefs. Restricts the match
511 // to supported MemRef types. Provides functionality to emit code accessing a
512 // specific element of the underlying data buffer.
513 template <typename Derived>
514 struct LoadStoreOpLowering : public ConvertOpToLLVMPattern<Derived> {
515 using ConvertOpToLLVMPattern<Derived>::ConvertOpToLLVMPattern;
516 using ConvertOpToLLVMPattern<Derived>::isConvertibleAndHasIdentityMaps;
517 using Base = LoadStoreOpLowering<Derived>;
518
match__anona8fa999b0111::LoadStoreOpLowering519 LogicalResult match(Derived op) const override {
520 MemRefType type = op.getMemRefType();
521 return isConvertibleAndHasIdentityMaps(type) ? success() : failure();
522 }
523 };
524
525 // Load operation is lowered to obtaining a pointer to the indexed element
526 // and loading it.
527 struct LoadOpLowering : public LoadStoreOpLowering<memref::LoadOp> {
528 using Base::Base;
529
530 LogicalResult
matchAndRewrite__anona8fa999b0111::LoadOpLowering531 matchAndRewrite(memref::LoadOp loadOp, ArrayRef<Value> operands,
532 ConversionPatternRewriter &rewriter) const override {
533 memref::LoadOp::Adaptor transformed(operands);
534 auto type = loadOp.getMemRefType();
535
536 Value dataPtr =
537 getStridedElementPtr(loadOp.getLoc(), type, transformed.memref(),
538 transformed.indices(), rewriter);
539 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, dataPtr);
540 return success();
541 }
542 };
543
544 // Store operation is lowered to obtaining a pointer to the indexed element,
545 // and storing the given value to it.
546 struct StoreOpLowering : public LoadStoreOpLowering<memref::StoreOp> {
547 using Base::Base;
548
549 LogicalResult
matchAndRewrite__anona8fa999b0111::StoreOpLowering550 matchAndRewrite(memref::StoreOp op, ArrayRef<Value> operands,
551 ConversionPatternRewriter &rewriter) const override {
552 auto type = op.getMemRefType();
553 memref::StoreOp::Adaptor transformed(operands);
554
555 Value dataPtr =
556 getStridedElementPtr(op.getLoc(), type, transformed.memref(),
557 transformed.indices(), rewriter);
558 rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, transformed.value(),
559 dataPtr);
560 return success();
561 }
562 };
563
564 // The prefetch operation is lowered in a way similar to the load operation
565 // except that the llvm.prefetch operation is used for replacement.
566 struct PrefetchOpLowering : public LoadStoreOpLowering<memref::PrefetchOp> {
567 using Base::Base;
568
569 LogicalResult
matchAndRewrite__anona8fa999b0111::PrefetchOpLowering570 matchAndRewrite(memref::PrefetchOp prefetchOp, ArrayRef<Value> operands,
571 ConversionPatternRewriter &rewriter) const override {
572 memref::PrefetchOp::Adaptor transformed(operands);
573 auto type = prefetchOp.getMemRefType();
574 auto loc = prefetchOp.getLoc();
575
576 Value dataPtr = getStridedElementPtr(loc, type, transformed.memref(),
577 transformed.indices(), rewriter);
578
579 // Replace with llvm.prefetch.
580 auto llvmI32Type = typeConverter->convertType(rewriter.getIntegerType(32));
581 auto isWrite = rewriter.create<LLVM::ConstantOp>(
582 loc, llvmI32Type, rewriter.getI32IntegerAttr(prefetchOp.isWrite()));
583 auto localityHint = rewriter.create<LLVM::ConstantOp>(
584 loc, llvmI32Type,
585 rewriter.getI32IntegerAttr(prefetchOp.localityHint()));
586 auto isData = rewriter.create<LLVM::ConstantOp>(
587 loc, llvmI32Type, rewriter.getI32IntegerAttr(prefetchOp.isDataCache()));
588
589 rewriter.replaceOpWithNewOp<LLVM::Prefetch>(prefetchOp, dataPtr, isWrite,
590 localityHint, isData);
591 return success();
592 }
593 };
594
/// Lowers memref.cast: ranked->ranked casts keep the original descriptor,
/// ranked->unranked casts wrap the descriptor behind a stack-promoted pointer,
/// and unranked->ranked casts load the descriptor back out.
struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<memref::CastOp> {
  using ConvertOpToLLVMPattern<memref::CastOp>::ConvertOpToLLVMPattern;

  LogicalResult match(memref::CastOp memRefCastOp) const override {
    Type srcType = memRefCastOp.getOperand().getType();
    Type dstType = memRefCastOp.getType();

    // memref::CastOp reduce to bitcast in the ranked MemRef case and can be
    // used for type erasure. For now they must preserve underlying element type
    // and require source and result type to have the same rank. Therefore,
    // perform a sanity check that the underlying structs are the same. Once op
    // semantics are relaxed we can revisit.
    if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>())
      return success(typeConverter->convertType(srcType) ==
                     typeConverter->convertType(dstType));

    // At least one of the operands is unranked type
    assert(srcType.isa<UnrankedMemRefType>() ||
           dstType.isa<UnrankedMemRefType>());

    // Unranked to unranked cast is disallowed
    return !(srcType.isa<UnrankedMemRefType>() &&
             dstType.isa<UnrankedMemRefType>())
               ? success()
               : failure();
  }

  void rewrite(memref::CastOp memRefCastOp, ArrayRef<Value> operands,
               ConversionPatternRewriter &rewriter) const override {
    memref::CastOp::Adaptor transformed(operands);

    auto srcType = memRefCastOp.getOperand().getType();
    auto dstType = memRefCastOp.getType();
    auto targetStructType = typeConverter->convertType(memRefCastOp.getType());
    auto loc = memRefCastOp.getLoc();

    // For ranked/ranked case, just keep the original descriptor.
    if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>())
      return rewriter.replaceOp(memRefCastOp, {transformed.source()});

    if (srcType.isa<MemRefType>() && dstType.isa<UnrankedMemRefType>()) {
      // Casting ranked to unranked memref type
      // Set the rank in the destination from the memref type
      // Allocate space on the stack and copy the src memref descriptor
      // Set the ptr in the destination to the stack space
      auto srcMemRefType = srcType.cast<MemRefType>();
      int64_t rank = srcMemRefType.getRank();
      // ptr = AllocaOp sizeof(MemRefDescriptor)
      auto ptr = getTypeConverter()->promoteOneMemRefDescriptor(
          loc, transformed.source(), rewriter);
      // voidptr = BitCastOp srcType* to void*
      auto voidPtr =
          rewriter.create<LLVM::BitcastOp>(loc, getVoidPtrType(), ptr)
              .getResult();
      // rank = ConstantOp srcRank
      auto rankVal = rewriter.create<LLVM::ConstantOp>(
          loc, typeConverter->convertType(rewriter.getIntegerType(64)),
          rewriter.getI64IntegerAttr(rank));
      // undef = UndefOp
      UnrankedMemRefDescriptor memRefDesc =
          UnrankedMemRefDescriptor::undef(rewriter, loc, targetStructType);
      // d1 = InsertValueOp undef, rank, 0
      memRefDesc.setRank(rewriter, loc, rankVal);
      // d2 = InsertValueOp d1, voidptr, 1
      memRefDesc.setMemRefDescPtr(rewriter, loc, voidPtr);
      rewriter.replaceOp(memRefCastOp, (Value)memRefDesc);

    } else if (srcType.isa<UnrankedMemRefType>() && dstType.isa<MemRefType>()) {
      // Casting from unranked type to ranked.
      // The operation is assumed to be doing a correct cast. If the destination
      // type mismatches the unranked the type, it is undefined behavior.
      UnrankedMemRefDescriptor memRefDesc(transformed.source());
      // ptr = ExtractValueOp src, 1
      auto ptr = memRefDesc.memRefDescPtr(rewriter, loc);
      // castPtr = BitCastOp i8* to structTy*
      auto castPtr =
          rewriter
              .create<LLVM::BitcastOp>(
                  loc, LLVM::LLVMPointerType::get(targetStructType), ptr)
              .getResult();
      // struct = LoadOp castPtr
      auto loadOp = rewriter.create<LLVM::LoadOp>(loc, castPtr);
      rewriter.replaceOp(memRefCastOp, loadOp.getResult());
    } else {
      llvm_unreachable("Unsupported unranked memref to unranked memref cast");
    }
  }
};
683
/// Lowers memref.copy to a call to the memref-copy runtime helper, after
/// normalizing both operands to stack-promoted unranked descriptors.
struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
  using ConvertOpToLLVMPattern<memref::CopyOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::CopyOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    memref::CopyOp::Adaptor adaptor(operands);
    auto srcType = op.source().getType().cast<BaseMemRefType>();
    auto targetType = op.target().getType().cast<BaseMemRefType>();

    // First make sure we have an unranked memref descriptor representation:
    // wrap a ranked descriptor as (rank, void* to stack copy).
    auto makeUnranked = [&, this](Value ranked, BaseMemRefType type) {
      auto rank = rewriter.create<LLVM::ConstantOp>(
          loc, getIndexType(), rewriter.getIndexAttr(type.getRank()));
      auto *typeConverter = getTypeConverter();
      auto ptr =
          typeConverter->promoteOneMemRefDescriptor(loc, ranked, rewriter);
      auto voidPtr =
          rewriter.create<LLVM::BitcastOp>(loc, getVoidPtrType(), ptr)
              .getResult();
      auto unrankedType =
          UnrankedMemRefType::get(type.getElementType(), type.getMemorySpace());
      return UnrankedMemRefDescriptor::pack(rewriter, loc, *typeConverter,
                                            unrankedType,
                                            ValueRange{rank, voidPtr});
    };

    // Already-unranked operands are passed through unchanged.
    Value unrankedSource = srcType.hasRank()
                               ? makeUnranked(adaptor.source(), srcType)
                               : adaptor.source();
    Value unrankedTarget = targetType.hasRank()
                               ? makeUnranked(adaptor.target(), targetType)
                               : adaptor.target();

    // Now promote the unranked descriptors to the stack, since the runtime
    // helper takes them by pointer.
    auto one = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
                                                 rewriter.getIndexAttr(1));
    auto promote = [&](Value desc) {
      auto ptrType = LLVM::LLVMPointerType::get(desc.getType());
      auto allocated =
          rewriter.create<LLVM::AllocaOp>(loc, ptrType, ValueRange{one});
      rewriter.create<LLVM::StoreOp>(loc, desc, allocated);
      return allocated;
    };

    auto sourcePtr = promote(unrankedSource);
    auto targetPtr = promote(unrankedTarget);

    // NOTE(review): for sub-byte element types (e.g. i1) this integer division
    // yields 0 bytes — presumably only byte-multiple element widths are
    // expected here; confirm against the callers/runtime helper.
    auto elemSize = rewriter.create<LLVM::ConstantOp>(
        loc, getIndexType(),
        rewriter.getIndexAttr(srcType.getElementTypeBitWidth() / 8));
    auto copyFn = LLVM::lookupOrCreateMemRefCopyFn(
        op->getParentOfType<ModuleOp>(), getIndexType(), sourcePtr.getType());
    rewriter.create<LLVM::CallOp>(loc, copyFn,
                                  ValueRange{elemSize, sourcePtr, targetPtr});
    rewriter.eraseOp(op);

    return success();
  }
};
745
746 /// Extracts allocated, aligned pointers and offset from a ranked or unranked
747 /// memref type. In unranked case, the fields are extracted from the underlying
748 /// ranked descriptor.
extractPointersAndOffset(Location loc,ConversionPatternRewriter & rewriter,LLVMTypeConverter & typeConverter,Value originalOperand,Value convertedOperand,Value * allocatedPtr,Value * alignedPtr,Value * offset=nullptr)749 static void extractPointersAndOffset(Location loc,
750 ConversionPatternRewriter &rewriter,
751 LLVMTypeConverter &typeConverter,
752 Value originalOperand,
753 Value convertedOperand,
754 Value *allocatedPtr, Value *alignedPtr,
755 Value *offset = nullptr) {
756 Type operandType = originalOperand.getType();
757 if (operandType.isa<MemRefType>()) {
758 MemRefDescriptor desc(convertedOperand);
759 *allocatedPtr = desc.allocatedPtr(rewriter, loc);
760 *alignedPtr = desc.alignedPtr(rewriter, loc);
761 if (offset != nullptr)
762 *offset = desc.offset(rewriter, loc);
763 return;
764 }
765
766 unsigned memorySpace =
767 operandType.cast<UnrankedMemRefType>().getMemorySpaceAsInt();
768 Type elementType = operandType.cast<UnrankedMemRefType>().getElementType();
769 Type llvmElementType = typeConverter.convertType(elementType);
770 Type elementPtrPtrType = LLVM::LLVMPointerType::get(
771 LLVM::LLVMPointerType::get(llvmElementType, memorySpace));
772
773 // Extract pointer to the underlying ranked memref descriptor and cast it to
774 // ElemType**.
775 UnrankedMemRefDescriptor unrankedDesc(convertedOperand);
776 Value underlyingDescPtr = unrankedDesc.memRefDescPtr(rewriter, loc);
777
778 *allocatedPtr = UnrankedMemRefDescriptor::allocatedPtr(
779 rewriter, loc, underlyingDescPtr, elementPtrPtrType);
780 *alignedPtr = UnrankedMemRefDescriptor::alignedPtr(
781 rewriter, loc, typeConverter, underlyingDescPtr, elementPtrPtrType);
782 if (offset != nullptr) {
783 *offset = UnrankedMemRefDescriptor::offset(
784 rewriter, loc, typeConverter, underlyingDescPtr, elementPtrPtrType);
785 }
786 }
787
788 struct MemRefReinterpretCastOpLowering
789 : public ConvertOpToLLVMPattern<memref::ReinterpretCastOp> {
790 using ConvertOpToLLVMPattern<
791 memref::ReinterpretCastOp>::ConvertOpToLLVMPattern;
792
793 LogicalResult
matchAndRewrite__anona8fa999b0111::MemRefReinterpretCastOpLowering794 matchAndRewrite(memref::ReinterpretCastOp castOp, ArrayRef<Value> operands,
795 ConversionPatternRewriter &rewriter) const override {
796 memref::ReinterpretCastOp::Adaptor adaptor(operands,
797 castOp->getAttrDictionary());
798 Type srcType = castOp.source().getType();
799
800 Value descriptor;
801 if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, castOp,
802 adaptor, &descriptor)))
803 return failure();
804 rewriter.replaceOp(castOp, {descriptor});
805 return success();
806 }
807
808 private:
convertSourceMemRefToDescriptor__anona8fa999b0111::MemRefReinterpretCastOpLowering809 LogicalResult convertSourceMemRefToDescriptor(
810 ConversionPatternRewriter &rewriter, Type srcType,
811 memref::ReinterpretCastOp castOp,
812 memref::ReinterpretCastOp::Adaptor adaptor, Value *descriptor) const {
813 MemRefType targetMemRefType =
814 castOp.getResult().getType().cast<MemRefType>();
815 auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType)
816 .dyn_cast_or_null<LLVM::LLVMStructType>();
817 if (!llvmTargetDescriptorTy)
818 return failure();
819
820 // Create descriptor.
821 Location loc = castOp.getLoc();
822 auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
823
824 // Set allocated and aligned pointers.
825 Value allocatedPtr, alignedPtr;
826 extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
827 castOp.source(), adaptor.source(), &allocatedPtr,
828 &alignedPtr);
829 desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
830 desc.setAlignedPtr(rewriter, loc, alignedPtr);
831
832 // Set offset.
833 if (castOp.isDynamicOffset(0))
834 desc.setOffset(rewriter, loc, adaptor.offsets()[0]);
835 else
836 desc.setConstantOffset(rewriter, loc, castOp.getStaticOffset(0));
837
838 // Set sizes and strides.
839 unsigned dynSizeId = 0;
840 unsigned dynStrideId = 0;
841 for (unsigned i = 0, e = targetMemRefType.getRank(); i < e; ++i) {
842 if (castOp.isDynamicSize(i))
843 desc.setSize(rewriter, loc, i, adaptor.sizes()[dynSizeId++]);
844 else
845 desc.setConstantSize(rewriter, loc, i, castOp.getStaticSize(i));
846
847 if (castOp.isDynamicStride(i))
848 desc.setStride(rewriter, loc, i, adaptor.strides()[dynStrideId++]);
849 else
850 desc.setConstantStride(rewriter, loc, i, castOp.getStaticStride(i));
851 }
852 *descriptor = desc;
853 return success();
854 }
855 };
856
struct MemRefReshapeOpLowering
    : public ConvertOpToLLVMPattern<memref::ReshapeOp> {
  using ConvertOpToLLVMPattern<memref::ReshapeOp>::ConvertOpToLLVMPattern;

  /// Lowers memref.reshape with a dynamically-shaped shape operand. The
  /// result is an unranked descriptor whose inner ranked descriptor is
  /// filled in by an explicit runtime loop over the shape operand.
  LogicalResult
  matchAndRewrite(memref::ReshapeOp reshapeOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto *op = reshapeOp.getOperation();
    memref::ReshapeOp::Adaptor adaptor(operands, op->getAttrDictionary());
    Type srcType = reshapeOp.source().getType();

    Value descriptor;
    if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, reshapeOp,
                                               adaptor, &descriptor)))
      return failure();
    rewriter.replaceOp(op, {descriptor});
    return success();
  }

private:
  LogicalResult
  convertSourceMemRefToDescriptor(ConversionPatternRewriter &rewriter,
                                  Type srcType, memref::ReshapeOp reshapeOp,
                                  memref::ReshapeOp::Adaptor adaptor,
                                  Value *descriptor) const {
    // Conversion for statically-known shape args is performed via
    // `memref_reinterpret_cast`.
    auto shapeMemRefType = reshapeOp.shape().getType().cast<MemRefType>();
    if (shapeMemRefType.hasStaticShape())
      return failure();

    // The shape is a rank-1 tensor with unknown length; its length is the
    // rank of the result.
    Location loc = reshapeOp.getLoc();
    MemRefDescriptor shapeDesc(adaptor.shape());
    Value resultRank = shapeDesc.size(rewriter, loc, 0);

    // Extract address space and element type.
    auto targetType =
        reshapeOp.getResult().getType().cast<UnrankedMemRefType>();
    unsigned addressSpace = targetType.getMemorySpaceAsInt();
    Type elementType = targetType.getElementType();

    // Create the unranked memref descriptor that holds the ranked one. The
    // inner descriptor is allocated on stack.
    auto targetDesc = UnrankedMemRefDescriptor::undef(
        rewriter, loc, typeConverter->convertType(targetType));
    targetDesc.setRank(rewriter, loc, resultRank);
    SmallVector<Value, 4> sizes;
    UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(),
                                           targetDesc, sizes);
    Value underlyingDescPtr = rewriter.create<LLVM::AllocaOp>(
        loc, getVoidPtrType(), sizes.front(), llvm::None);
    targetDesc.setMemRefDescPtr(rewriter, loc, underlyingDescPtr);

    // Extract pointers and offset from the source memref.
    Value allocatedPtr, alignedPtr, offset;
    extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
                             reshapeOp.source(), adaptor.source(),
                             &allocatedPtr, &alignedPtr, &offset);

    // Set pointers and offset in the inner (stack-allocated) descriptor.
    Type llvmElementType = typeConverter->convertType(elementType);
    auto elementPtrPtrType = LLVM::LLVMPointerType::get(
        LLVM::LLVMPointerType::get(llvmElementType, addressSpace));
    UnrankedMemRefDescriptor::setAllocatedPtr(rewriter, loc, underlyingDescPtr,
                                              elementPtrPtrType, allocatedPtr);
    UnrankedMemRefDescriptor::setAlignedPtr(rewriter, loc, *getTypeConverter(),
                                            underlyingDescPtr,
                                            elementPtrPtrType, alignedPtr);
    UnrankedMemRefDescriptor::setOffset(rewriter, loc, *getTypeConverter(),
                                        underlyingDescPtr, elementPtrPtrType,
                                        offset);

    // Use the offset pointer as base for further addressing. Copy over the new
    // shape and compute strides. For this, we create a loop from rank-1 to 0.
    Value targetSizesBase = UnrankedMemRefDescriptor::sizeBasePtr(
        rewriter, loc, *getTypeConverter(), underlyingDescPtr,
        elementPtrPtrType);
    Value targetStridesBase = UnrankedMemRefDescriptor::strideBasePtr(
        rewriter, loc, *getTypeConverter(), targetSizesBase, resultRank);
    Value shapeOperandPtr = shapeDesc.alignedPtr(rewriter, loc);
    Value oneIndex = createIndexConstant(rewriter, loc, 1);
    Value resultRankMinusOne =
        rewriter.create<LLVM::SubOp>(loc, resultRank, oneIndex);

    // The loop is built manually at the LLVM-block level:
    //   init -> cond(index, stride) -> body -> cond -> ... -> remainder
    // where `index` counts down from rank-1 and `stride` is the running
    // product of sizes (innermost stride starts at 1).
    Block *initBlock = rewriter.getInsertionBlock();
    Type indexType = getTypeConverter()->getIndexType();
    Block::iterator remainingOpsIt = std::next(rewriter.getInsertionPoint());

    Block *condBlock = rewriter.createBlock(initBlock->getParent(), {},
                                            {indexType, indexType});

    // Move the remaining initBlock ops to condBlock.
    Block *remainingBlock = rewriter.splitBlock(initBlock, remainingOpsIt);
    rewriter.mergeBlocks(remainingBlock, condBlock, ValueRange());

    rewriter.setInsertionPointToEnd(initBlock);
    rewriter.create<LLVM::BrOp>(loc, ValueRange({resultRankMinusOne, oneIndex}),
                                condBlock);
    rewriter.setInsertionPointToStart(condBlock);
    Value indexArg = condBlock->getArgument(0);
    Value strideArg = condBlock->getArgument(1);

    // Loop while the descending dimension index is still >= 0 (signed).
    Value zeroIndex = createIndexConstant(rewriter, loc, 0);
    Value pred = rewriter.create<LLVM::ICmpOp>(
        loc, IntegerType::get(rewriter.getContext(), 1),
        LLVM::ICmpPredicate::sge, indexArg, zeroIndex);

    Block *bodyBlock =
        rewriter.splitBlock(condBlock, rewriter.getInsertionPoint());
    rewriter.setInsertionPointToStart(bodyBlock);

    // Copy size from shape to descriptor.
    Type llvmIndexPtrType = LLVM::LLVMPointerType::get(indexType);
    Value sizeLoadGep = rewriter.create<LLVM::GEPOp>(
        loc, llvmIndexPtrType, shapeOperandPtr, ValueRange{indexArg});
    Value size = rewriter.create<LLVM::LoadOp>(loc, sizeLoadGep);
    UnrankedMemRefDescriptor::setSize(rewriter, loc, *getTypeConverter(),
                                      targetSizesBase, indexArg, size);

    // Write stride value and compute next one.
    UnrankedMemRefDescriptor::setStride(rewriter, loc, *getTypeConverter(),
                                        targetStridesBase, indexArg, strideArg);
    Value nextStride = rewriter.create<LLVM::MulOp>(loc, strideArg, size);

    // Decrement loop counter and branch back.
    Value decrement = rewriter.create<LLVM::SubOp>(loc, indexArg, oneIndex);
    rewriter.create<LLVM::BrOp>(loc, ValueRange({decrement, nextStride}),
                                condBlock);

    Block *remainder =
        rewriter.splitBlock(bodyBlock, rewriter.getInsertionPoint());

    // Hook up the cond exit to the remainder.
    rewriter.setInsertionPointToEnd(condBlock);
    rewriter.create<LLVM::CondBrOp>(loc, pred, bodyBlock, llvm::None, remainder,
                                    llvm::None);

    // Reset position to beginning of new remainder block.
    rewriter.setInsertionPointToStart(remainder);

    *descriptor = targetDesc;
    return success();
  }
};
1002
1003 /// Helper function to convert a vector of `OpFoldResult`s into a vector of
1004 /// `Value`s.
getAsValues(OpBuilder & b,Location loc,Type & llvmIndexType,ArrayRef<OpFoldResult> valueOrAttrVec)1005 static SmallVector<Value> getAsValues(OpBuilder &b, Location loc,
1006 Type &llvmIndexType,
1007 ArrayRef<OpFoldResult> valueOrAttrVec) {
1008 return llvm::to_vector<4>(
1009 llvm::map_range(valueOrAttrVec, [&](OpFoldResult value) -> Value {
1010 if (auto attr = value.dyn_cast<Attribute>())
1011 return b.create<LLVM::ConstantOp>(loc, llvmIndexType, attr);
1012 return value.get<Value>();
1013 }));
1014 }
1015
/// Compute a map that for a given dimension of the expanded type gives the
/// dimension in the collapsed type it maps to. Essentially it is the inverse
/// of the `reassociation` maps.
1019 static DenseMap<int64_t, int64_t>
getExpandedDimToCollapsedDimMap(ArrayRef<ReassociationIndices> reassociation)1020 getExpandedDimToCollapsedDimMap(ArrayRef<ReassociationIndices> reassociation) {
1021 llvm::DenseMap<int64_t, int64_t> expandedDimToCollapsedDim;
1022 for (auto &en : enumerate(reassociation)) {
1023 for (auto dim : en.value())
1024 expandedDimToCollapsedDim[dim] = en.index();
1025 }
1026 return expandedDimToCollapsedDim;
1027 }
1028
1029 static OpFoldResult
getExpandedOutputDimSize(OpBuilder & b,Location loc,Type & llvmIndexType,int64_t outDimIndex,ArrayRef<int64_t> outStaticShape,MemRefDescriptor & inDesc,ArrayRef<int64_t> inStaticShape,ArrayRef<ReassociationIndices> reassocation,DenseMap<int64_t,int64_t> & outDimToInDimMap)1030 getExpandedOutputDimSize(OpBuilder &b, Location loc, Type &llvmIndexType,
1031 int64_t outDimIndex, ArrayRef<int64_t> outStaticShape,
1032 MemRefDescriptor &inDesc,
1033 ArrayRef<int64_t> inStaticShape,
1034 ArrayRef<ReassociationIndices> reassocation,
1035 DenseMap<int64_t, int64_t> &outDimToInDimMap) {
1036 int64_t outDimSize = outStaticShape[outDimIndex];
1037 if (!ShapedType::isDynamic(outDimSize))
1038 return b.getIndexAttr(outDimSize);
1039
1040 // Calculate the multiplication of all the out dim sizes except the
1041 // current dim.
1042 int64_t inDimIndex = outDimToInDimMap[outDimIndex];
1043 int64_t otherDimSizesMul = 1;
1044 for (auto otherDimIndex : reassocation[inDimIndex]) {
1045 if (otherDimIndex == static_cast<unsigned>(outDimIndex))
1046 continue;
1047 int64_t otherDimSize = outStaticShape[otherDimIndex];
1048 assert(!ShapedType::isDynamic(otherDimSize) &&
1049 "single dimension cannot be expanded into multiple dynamic "
1050 "dimensions");
1051 otherDimSizesMul *= otherDimSize;
1052 }
1053
1054 // outDimSize = inDimSize / otherOutDimSizesMul
1055 int64_t inDimSize = inStaticShape[inDimIndex];
1056 Value inDimSizeDynamic =
1057 ShapedType::isDynamic(inDimSize)
1058 ? inDesc.size(b, loc, inDimIndex)
1059 : b.create<LLVM::ConstantOp>(loc, llvmIndexType,
1060 b.getIndexAttr(inDimSize));
1061 Value outDimSizeDynamic = b.create<LLVM::SDivOp>(
1062 loc, inDimSizeDynamic,
1063 b.create<LLVM::ConstantOp>(loc, llvmIndexType,
1064 b.getIndexAttr(otherDimSizesMul)));
1065 return outDimSizeDynamic;
1066 }
1067
getCollapsedOutputDimSize(OpBuilder & b,Location loc,Type & llvmIndexType,int64_t outDimIndex,int64_t outDimSize,ArrayRef<int64_t> inStaticShape,MemRefDescriptor & inDesc,ArrayRef<ReassociationIndices> reassocation)1068 static OpFoldResult getCollapsedOutputDimSize(
1069 OpBuilder &b, Location loc, Type &llvmIndexType, int64_t outDimIndex,
1070 int64_t outDimSize, ArrayRef<int64_t> inStaticShape,
1071 MemRefDescriptor &inDesc, ArrayRef<ReassociationIndices> reassocation) {
1072 if (!ShapedType::isDynamic(outDimSize))
1073 return b.getIndexAttr(outDimSize);
1074
1075 Value c1 = b.create<LLVM::ConstantOp>(loc, llvmIndexType, b.getIndexAttr(1));
1076 Value outDimSizeDynamic = c1;
1077 for (auto inDimIndex : reassocation[outDimIndex]) {
1078 int64_t inDimSize = inStaticShape[inDimIndex];
1079 Value inDimSizeDynamic =
1080 ShapedType::isDynamic(inDimSize)
1081 ? inDesc.size(b, loc, inDimIndex)
1082 : b.create<LLVM::ConstantOp>(loc, llvmIndexType,
1083 b.getIndexAttr(inDimSize));
1084 outDimSizeDynamic =
1085 b.create<LLVM::MulOp>(loc, outDimSizeDynamic, inDimSizeDynamic);
1086 }
1087 return outDimSizeDynamic;
1088 }
1089
1090 static SmallVector<OpFoldResult, 4>
getCollapsedOutputShape(OpBuilder & b,Location loc,Type & llvmIndexType,ArrayRef<ReassociationIndices> reassocation,ArrayRef<int64_t> inStaticShape,MemRefDescriptor & inDesc,ArrayRef<int64_t> outStaticShape)1091 getCollapsedOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType,
1092 ArrayRef<ReassociationIndices> reassocation,
1093 ArrayRef<int64_t> inStaticShape,
1094 MemRefDescriptor &inDesc,
1095 ArrayRef<int64_t> outStaticShape) {
1096 return llvm::to_vector<4>(llvm::map_range(
1097 llvm::seq<int64_t>(0, outStaticShape.size()), [&](int64_t outDimIndex) {
1098 return getCollapsedOutputDimSize(b, loc, llvmIndexType, outDimIndex,
1099 outStaticShape[outDimIndex],
1100 inStaticShape, inDesc, reassocation);
1101 }));
1102 }
1103
1104 static SmallVector<OpFoldResult, 4>
getExpandedOutputShape(OpBuilder & b,Location loc,Type & llvmIndexType,ArrayRef<ReassociationIndices> reassocation,ArrayRef<int64_t> inStaticShape,MemRefDescriptor & inDesc,ArrayRef<int64_t> outStaticShape)1105 getExpandedOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType,
1106 ArrayRef<ReassociationIndices> reassocation,
1107 ArrayRef<int64_t> inStaticShape,
1108 MemRefDescriptor &inDesc,
1109 ArrayRef<int64_t> outStaticShape) {
1110 DenseMap<int64_t, int64_t> outDimToInDimMap =
1111 getExpandedDimToCollapsedDimMap(reassocation);
1112 return llvm::to_vector<4>(llvm::map_range(
1113 llvm::seq<int64_t>(0, outStaticShape.size()), [&](int64_t outDimIndex) {
1114 return getExpandedOutputDimSize(b, loc, llvmIndexType, outDimIndex,
1115 outStaticShape, inDesc, inStaticShape,
1116 reassocation, outDimToInDimMap);
1117 }));
1118 }
1119
1120 static SmallVector<Value>
getDynamicOutputShape(OpBuilder & b,Location loc,Type & llvmIndexType,ArrayRef<ReassociationIndices> reassocation,ArrayRef<int64_t> inStaticShape,MemRefDescriptor & inDesc,ArrayRef<int64_t> outStaticShape)1121 getDynamicOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType,
1122 ArrayRef<ReassociationIndices> reassocation,
1123 ArrayRef<int64_t> inStaticShape, MemRefDescriptor &inDesc,
1124 ArrayRef<int64_t> outStaticShape) {
1125 return outStaticShape.size() < inStaticShape.size()
1126 ? getAsValues(b, loc, llvmIndexType,
1127 getCollapsedOutputShape(b, loc, llvmIndexType,
1128 reassocation, inStaticShape,
1129 inDesc, outStaticShape))
1130 : getAsValues(b, loc, llvmIndexType,
1131 getExpandedOutputShape(b, loc, llvmIndexType,
1132 reassocation, inStaticShape,
1133 inDesc, outStaticShape));
1134 }
1135
// ReshapeOp creates a new view descriptor of the proper rank.
// Only memrefs with the default (empty) layout map are supported; the result
// sizes and strides may be static or dynamic.
1139 template <typename ReshapeOp>
1140 class ReassociatingReshapeOpConversion
1141 : public ConvertOpToLLVMPattern<ReshapeOp> {
1142 public:
1143 using ConvertOpToLLVMPattern<ReshapeOp>::ConvertOpToLLVMPattern;
1144 using ReshapeOpAdaptor = typename ReshapeOp::Adaptor;
1145
1146 LogicalResult
matchAndRewrite(ReshapeOp reshapeOp,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const1147 matchAndRewrite(ReshapeOp reshapeOp, ArrayRef<Value> operands,
1148 ConversionPatternRewriter &rewriter) const override {
1149 MemRefType dstType = reshapeOp.getResultType();
1150 MemRefType srcType = reshapeOp.getSrcType();
1151 if (!srcType.getAffineMaps().empty() || !dstType.getAffineMaps().empty()) {
1152 return rewriter.notifyMatchFailure(reshapeOp,
1153 "only empty layout map is supported");
1154 }
1155
1156 int64_t offset;
1157 SmallVector<int64_t, 4> strides;
1158 if (failed(getStridesAndOffset(dstType, strides, offset))) {
1159 return rewriter.notifyMatchFailure(
1160 reshapeOp, "failed to get stride and offset exprs");
1161 }
1162
1163 ReshapeOpAdaptor adaptor(operands);
1164 MemRefDescriptor srcDesc(adaptor.src());
1165 Location loc = reshapeOp->getLoc();
1166 auto dstDesc = MemRefDescriptor::undef(
1167 rewriter, loc, this->typeConverter->convertType(dstType));
1168 dstDesc.setAllocatedPtr(rewriter, loc, srcDesc.allocatedPtr(rewriter, loc));
1169 dstDesc.setAlignedPtr(rewriter, loc, srcDesc.alignedPtr(rewriter, loc));
1170 dstDesc.setOffset(rewriter, loc, srcDesc.offset(rewriter, loc));
1171
1172 ArrayRef<int64_t> srcStaticShape = srcType.getShape();
1173 ArrayRef<int64_t> dstStaticShape = dstType.getShape();
1174 Type llvmIndexType =
1175 this->typeConverter->convertType(rewriter.getIndexType());
1176 SmallVector<Value> dstShape = getDynamicOutputShape(
1177 rewriter, loc, llvmIndexType, reshapeOp.getReassociationIndices(),
1178 srcStaticShape, srcDesc, dstStaticShape);
1179 for (auto &en : llvm::enumerate(dstShape))
1180 dstDesc.setSize(rewriter, loc, en.index(), en.value());
1181
1182 auto isStaticStride = [](int64_t stride) {
1183 return !ShapedType::isDynamicStrideOrOffset(stride);
1184 };
1185 if (llvm::all_of(strides, isStaticStride)) {
1186 for (auto &en : llvm::enumerate(strides))
1187 dstDesc.setConstantStride(rewriter, loc, en.index(), en.value());
1188 } else {
1189 Value c1 = rewriter.create<LLVM::ConstantOp>(loc, llvmIndexType,
1190 rewriter.getIndexAttr(1));
1191 Value stride = c1;
1192 for (auto dimIndex :
1193 llvm::reverse(llvm::seq<int64_t>(0, dstShape.size()))) {
1194 dstDesc.setStride(rewriter, loc, dimIndex, stride);
1195 stride = rewriter.create<LLVM::MulOp>(loc, dstShape[dimIndex], stride);
1196 }
1197 }
1198 rewriter.replaceOp(reshapeOp, {dstDesc});
1199 return success();
1200 }
1201 };
1202
1203 /// Conversion pattern that transforms a subview op into:
1204 /// 1. An `llvm.mlir.undef` operation to create a memref descriptor
1205 /// 2. Updates to the descriptor to introduce the data ptr, offset, size
1206 /// and stride.
1207 /// The subview op is replaced by the descriptor.
struct SubViewOpLowering : public ConvertOpToLLVMPattern<memref::SubViewOp> {
  using ConvertOpToLLVMPattern<memref::SubViewOp>::ConvertOpToLLVMPattern;

  /// Lowers memref.subview to a new descriptor: pointers are bitcast over,
  /// the offset is accumulated from the per-dimension offsets times source
  /// strides, and sizes/strides are filled per result dimension, skipping
  /// dimensions dropped by rank reduction.
  LogicalResult
  matchAndRewrite(memref::SubViewOp subViewOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = subViewOp.getLoc();

    auto sourceMemRefType = subViewOp.source().getType().cast<MemRefType>();
    auto sourceElementTy =
        typeConverter->convertType(sourceMemRefType.getElementType());

    // The inferred (non-rank-reduced) result type drives stride/offset
    // computation; the declared result type may have fewer dimensions.
    auto viewMemRefType = subViewOp.getType();
    auto inferredType = memref::SubViewOp::inferResultType(
                            subViewOp.getSourceType(),
                            extractFromI64ArrayAttr(subViewOp.static_offsets()),
                            extractFromI64ArrayAttr(subViewOp.static_sizes()),
                            extractFromI64ArrayAttr(subViewOp.static_strides()))
                            .cast<MemRefType>();
    auto targetElementTy =
        typeConverter->convertType(viewMemRefType.getElementType());
    auto targetDescTy = typeConverter->convertType(viewMemRefType);
    if (!sourceElementTy || !targetDescTy || !targetElementTy ||
        !LLVM::isCompatibleType(sourceElementTy) ||
        !LLVM::isCompatibleType(targetElementTy) ||
        !LLVM::isCompatibleType(targetDescTy))
      return failure();

    // Extract the offset and strides from the type.
    int64_t offset;
    SmallVector<int64_t, 4> strides;
    auto successStrides = getStridesAndOffset(inferredType, strides, offset);
    if (failed(successStrides))
      return failure();

    // Create the descriptor.
    if (!LLVM::isCompatibleType(operands.front().getType()))
      return failure();
    MemRefDescriptor sourceMemRef(operands.front());
    auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);

    // Copy the buffer pointer from the old descriptor to the new one.
    Value extracted = sourceMemRef.allocatedPtr(rewriter, loc);
    Value bitcastPtr = rewriter.create<LLVM::BitcastOp>(
        loc,
        LLVM::LLVMPointerType::get(targetElementTy,
                                   viewMemRefType.getMemorySpaceAsInt()),
        extracted);
    targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr);

    // Copy the aligned pointer from the old descriptor to the new one.
    extracted = sourceMemRef.alignedPtr(rewriter, loc);
    bitcastPtr = rewriter.create<LLVM::BitcastOp>(
        loc,
        LLVM::LLVMPointerType::get(targetElementTy,
                                   viewMemRefType.getMemorySpaceAsInt()),
        extracted);
    targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);

    // `unusedDims` holds the inferred-type dimensions that the (possibly
    // rank-reduced) result type drops.
    auto shape = viewMemRefType.getShape();
    auto inferredShape = inferredType.getShape();
    size_t inferredShapeRank = inferredShape.size();
    size_t resultShapeRank = shape.size();
    llvm::SmallDenseSet<unsigned> unusedDims =
        computeRankReductionMask(inferredShape, shape).getValue();

    // Extract strides needed to compute offset.
    SmallVector<Value, 4> strideValues;
    strideValues.reserve(inferredShapeRank);
    for (unsigned i = 0; i < inferredShapeRank; ++i)
      strideValues.push_back(sourceMemRef.stride(rewriter, loc, i));

    // Offset: either a static constant, or the source offset plus the sum of
    // per-dimension (offset * stride) products.
    auto llvmIndexType = typeConverter->convertType(rewriter.getIndexType());
    if (!ShapedType::isDynamicStrideOrOffset(offset)) {
      targetMemRef.setConstantOffset(rewriter, loc, offset);
    } else {
      Value baseOffset = sourceMemRef.offset(rewriter, loc);
      // `inferredShapeRank` may be larger than the number of offset operands
      // because of trailing semantics. In this case, the offset is guaranteed
      // to be interpreted as 0 and we can just skip the extra dimensions.
      for (unsigned i = 0, e = std::min(inferredShapeRank,
                                        subViewOp.getMixedOffsets().size());
           i < e; ++i) {
        Value offset =
            // TODO: need OpFoldResult ODS adaptor to clean this up.
            subViewOp.isDynamicOffset(i)
                ? operands[subViewOp.getIndexOfDynamicOffset(i)]
                : rewriter.create<LLVM::ConstantOp>(
                      loc, llvmIndexType,
                      rewriter.getI64IntegerAttr(subViewOp.getStaticOffset(i)));
        Value mul = rewriter.create<LLVM::MulOp>(loc, offset, strideValues[i]);
        baseOffset = rewriter.create<LLVM::AddOp>(loc, baseOffset, mul);
      }
      targetMemRef.setOffset(rewriter, loc, baseOffset);
    }

    // Update sizes and strides. `i` walks the inferred (full-rank) dims while
    // `j` walks the result dims; `j` only advances for dims that survive
    // rank reduction.
    SmallVector<OpFoldResult> mixedSizes = subViewOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = subViewOp.getMixedStrides();
    assert(mixedSizes.size() == mixedStrides.size() &&
           "expected sizes and strides of equal length");
    for (int i = inferredShapeRank - 1, j = resultShapeRank - 1;
         i >= 0 && j >= 0; --i) {
      if (unusedDims.contains(i))
        continue;

      // `i` may overflow subViewOp.getMixedSizes because of trailing semantics.
      // In this case, the size is guaranteed to be interpreted as Dim and the
      // stride as 1.
      Value size, stride;
      if (static_cast<unsigned>(i) >= mixedSizes.size()) {
        // If the static size is available, use it directly. This is similar to
        // the folding of dim(constant-op) but removes the need for dim to be
        // aware of LLVM constants and for this pass to be aware of std
        // constants.
        int64_t staticSize =
            subViewOp.source().getType().cast<MemRefType>().getShape()[i];
        if (staticSize != ShapedType::kDynamicSize) {
          size = rewriter.create<LLVM::ConstantOp>(
              loc, llvmIndexType, rewriter.getI64IntegerAttr(staticSize));
        } else {
          // Dynamic source dim: query it via memref.dim and cast the index
          // into the converted LLVM index type.
          Value pos = rewriter.create<LLVM::ConstantOp>(
              loc, llvmIndexType, rewriter.getI64IntegerAttr(i));
          Value dim =
              rewriter.create<memref::DimOp>(loc, subViewOp.source(), pos);
          auto cast = rewriter.create<UnrealizedConversionCastOp>(
              loc, llvmIndexType, dim);
          size = cast.getResult(0);
        }
        stride = rewriter.create<LLVM::ConstantOp>(
            loc, llvmIndexType, rewriter.getI64IntegerAttr(1));
      } else {
        // TODO: need OpFoldResult ODS adaptor to clean this up.
        size =
            subViewOp.isDynamicSize(i)
                ? operands[subViewOp.getIndexOfDynamicSize(i)]
                : rewriter.create<LLVM::ConstantOp>(
                      loc, llvmIndexType,
                      rewriter.getI64IntegerAttr(subViewOp.getStaticSize(i)));
        if (!ShapedType::isDynamicStrideOrOffset(strides[i])) {
          stride = rewriter.create<LLVM::ConstantOp>(
              loc, llvmIndexType, rewriter.getI64IntegerAttr(strides[i]));
        } else {
          // Dynamic stride: the subview stride is relative to the source, so
          // scale it by the source's stride in this dimension.
          stride = subViewOp.isDynamicStride(i)
                       ? operands[subViewOp.getIndexOfDynamicStride(i)]
                       : rewriter.create<LLVM::ConstantOp>(
                             loc, llvmIndexType,
                             rewriter.getI64IntegerAttr(
                                 subViewOp.getStaticStride(i)));
          stride = rewriter.create<LLVM::MulOp>(loc, stride, strideValues[i]);
        }
      }
      targetMemRef.setSize(rewriter, loc, j, size);
      targetMemRef.setStride(rewriter, loc, j, stride);
      j--;
    }

    rewriter.replaceOp(subViewOp, {targetMemRef});
    return success();
  }
};
1370
1371 /// Conversion pattern that transforms a transpose op into:
1372 /// 1. A function entry `alloca` operation to allocate a ViewDescriptor.
1373 /// 2. A load of the ViewDescriptor from the pointer allocated in 1.
1374 /// 3. Updates to the ViewDescriptor to introduce the data ptr, offset, size
1375 /// and stride. Size and stride are permutations of the original values.
1376 /// 4. A store of the resulting ViewDescriptor to the alloca'ed pointer.
1377 /// The transpose op is replaced by the alloca'ed pointer.
1378 class TransposeOpLowering : public ConvertOpToLLVMPattern<memref::TransposeOp> {
1379 public:
1380 using ConvertOpToLLVMPattern<memref::TransposeOp>::ConvertOpToLLVMPattern;
1381
1382 LogicalResult
matchAndRewrite(memref::TransposeOp transposeOp,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const1383 matchAndRewrite(memref::TransposeOp transposeOp, ArrayRef<Value> operands,
1384 ConversionPatternRewriter &rewriter) const override {
1385 auto loc = transposeOp.getLoc();
1386 memref::TransposeOpAdaptor adaptor(operands);
1387 MemRefDescriptor viewMemRef(adaptor.in());
1388
1389 // No permutation, early exit.
1390 if (transposeOp.permutation().isIdentity())
1391 return rewriter.replaceOp(transposeOp, {viewMemRef}), success();
1392
1393 auto targetMemRef = MemRefDescriptor::undef(
1394 rewriter, loc, typeConverter->convertType(transposeOp.getShapedType()));
1395
1396 // Copy the base and aligned pointers from the old descriptor to the new
1397 // one.
1398 targetMemRef.setAllocatedPtr(rewriter, loc,
1399 viewMemRef.allocatedPtr(rewriter, loc));
1400 targetMemRef.setAlignedPtr(rewriter, loc,
1401 viewMemRef.alignedPtr(rewriter, loc));
1402
1403 // Copy the offset pointer from the old descriptor to the new one.
1404 targetMemRef.setOffset(rewriter, loc, viewMemRef.offset(rewriter, loc));
1405
1406 // Iterate over the dimensions and apply size/stride permutation.
1407 for (auto en : llvm::enumerate(transposeOp.permutation().getResults())) {
1408 int sourcePos = en.index();
1409 int targetPos = en.value().cast<AffineDimExpr>().getPosition();
1410 targetMemRef.setSize(rewriter, loc, targetPos,
1411 viewMemRef.size(rewriter, loc, sourcePos));
1412 targetMemRef.setStride(rewriter, loc, targetPos,
1413 viewMemRef.stride(rewriter, loc, sourcePos));
1414 }
1415
1416 rewriter.replaceOp(transposeOp, {targetMemRef});
1417 return success();
1418 }
1419 };
1420
1421 /// Conversion pattern that transforms an op into:
1422 /// 1. An `llvm.mlir.undef` operation to create a memref descriptor
1423 /// 2. Updates to the descriptor to introduce the data ptr, offset, size
1424 /// and stride.
1425 /// The view op is replaced by the descriptor.
struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {
  using ConvertOpToLLVMPattern<memref::ViewOp>::ConvertOpToLLVMPattern;

  // Build and return the value for the idx^th shape dimension, either by
  // returning the constant shape dimension or counting the proper dynamic size.
  // `dynamicSizes` holds one operand per dynamic dimension, in order, so the
  // operand for dimension `idx` is found by counting the dynamic dims before it.
  Value getSize(ConversionPatternRewriter &rewriter, Location loc,
                ArrayRef<int64_t> shape, ValueRange dynamicSizes,
                unsigned idx) const {
    assert(idx < shape.size());
    if (!ShapedType::isDynamic(shape[idx]))
      return createIndexConstant(rewriter, loc, shape[idx]);
    // Count the number of dynamic dims in range [0, idx]
    unsigned nDynamic = llvm::count_if(shape.take_front(idx), [](int64_t v) {
      return ShapedType::isDynamic(v);
    });
    return dynamicSizes[nDynamic];
  }

  // Build and return the idx^th stride, either by returning the constant stride
  // or by computing the dynamic stride from the current `runningStride` and
  // `nextSize`. The caller should keep a running stride and update it with the
  // result returned by this function.
  // Invariant used by the caller (who walks dims innermost-to-outermost):
  //   - innermost call has nextSize == runningStride == nullptr -> stride 1;
  //   - later calls multiply the running stride by the next inner size.
  Value getStride(ConversionPatternRewriter &rewriter, Location loc,
                  ArrayRef<int64_t> strides, Value nextSize,
                  Value runningStride, unsigned idx) const {
    assert(idx < strides.size());
    if (!MemRefType::isDynamicStrideOrOffset(strides[idx]))
      return createIndexConstant(rewriter, loc, strides[idx]);
    if (nextSize)
      return runningStride
                 ? rewriter.create<LLVM::MulOp>(loc, runningStride, nextSize)
                 : nextSize;
    assert(!runningStride);
    return createIndexConstant(rewriter, loc, 1);
  }

  // Lowers memref.view to a fresh descriptor over the source buffer: both
  // pointers are bitcast to the target element type, the aligned pointer is
  // advanced by `byte_shift`, and sizes/strides are rebuilt for the new shape.
  LogicalResult
  matchAndRewrite(memref::ViewOp viewOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = viewOp.getLoc();
    memref::ViewOpAdaptor adaptor(operands);

    auto viewMemRefType = viewOp.getType();
    auto targetElementTy =
        typeConverter->convertType(viewMemRefType.getElementType());
    auto targetDescTy = typeConverter->convertType(viewMemRefType);
    // Bail out (as a match failure, not an error) if the result type cannot
    // be expressed in the LLVM dialect.
    if (!targetDescTy || !targetElementTy ||
        !LLVM::isCompatibleType(targetElementTy) ||
        !LLVM::isCompatibleType(targetDescTy))
      return viewOp.emitWarning("Target descriptor type not converted to LLVM"),
             failure();

    int64_t offset;
    SmallVector<int64_t, 4> strides;
    auto successStrides = getStridesAndOffset(viewMemRefType, strides, offset);
    if (failed(successStrides))
      return viewOp.emitWarning("cannot cast to non-strided shape"), failure();
    assert(offset == 0 && "expected offset to be 0");

    // Create the descriptor.
    MemRefDescriptor sourceMemRef(adaptor.source());
    auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);

    // Field 1: Copy the allocated pointer, used for malloc/free.
    // The pointer is bitcast because source and result element types differ;
    // the memory space is taken from the source memref.
    Value allocatedPtr = sourceMemRef.allocatedPtr(rewriter, loc);
    auto srcMemRefType = viewOp.source().getType().cast<MemRefType>();
    Value bitcastPtr = rewriter.create<LLVM::BitcastOp>(
        loc,
        LLVM::LLVMPointerType::get(targetElementTy,
                                   srcMemRefType.getMemorySpaceAsInt()),
        allocatedPtr);
    targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr);

    // Field 2: Copy the actual aligned pointer to payload.
    // The GEP applies `byte_shift` BEFORE the bitcast, i.e. while the pointer
    // still has the source element type; the shift operand is in bytes only
    // if the source element type is i8 (memref.view requires that).
    Value alignedPtr = sourceMemRef.alignedPtr(rewriter, loc);
    alignedPtr = rewriter.create<LLVM::GEPOp>(loc, alignedPtr.getType(),
                                              alignedPtr, adaptor.byte_shift());
    bitcastPtr = rewriter.create<LLVM::BitcastOp>(
        loc,
        LLVM::LLVMPointerType::get(targetElementTy,
                                   srcMemRefType.getMemorySpaceAsInt()),
        alignedPtr);
    targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);

    // Field 3: The offset in the resulting type must be 0. This is because of
    // the type change: an offset on srcType* may not be expressible as an
    // offset on dstType*.
    targetMemRef.setOffset(rewriter, loc,
                           createIndexConstant(rewriter, loc, offset));

    // Early exit for 0-D corner case.
    if (viewMemRefType.getRank() == 0)
      return rewriter.replaceOp(viewOp, {targetMemRef}), success();

    // Fields 4 and 5: Update sizes and strides.
    // The innermost stride must be the static value 1 for the running-stride
    // computation below to be valid (contiguous row-major result).
    if (strides.back() != 1)
      return viewOp.emitWarning("cannot cast to non-contiguous shape"),
             failure();
    // Walk dimensions innermost-to-outermost, carrying the running stride.
    Value stride = nullptr, nextSize = nullptr;
    for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) {
      // Update size.
      Value size =
          getSize(rewriter, loc, viewMemRefType.getShape(), adaptor.sizes(), i);
      targetMemRef.setSize(rewriter, loc, i, size);
      // Update stride.
      stride = getStride(rewriter, loc, strides, nextSize, stride, i);
      targetMemRef.setStride(rewriter, loc, i, stride);
      nextSize = size;
    }

    rewriter.replaceOp(viewOp, {targetMemRef});
    return success();
  }
};
1540
1541 } // namespace
1542
populateMemRefToLLVMConversionPatterns(LLVMTypeConverter & converter,RewritePatternSet & patterns)1543 void mlir::populateMemRefToLLVMConversionPatterns(LLVMTypeConverter &converter,
1544 RewritePatternSet &patterns) {
1545 // clang-format off
1546 patterns.add<
1547 AllocaOpLowering,
1548 AllocaScopeOpLowering,
1549 AssumeAlignmentOpLowering,
1550 DimOpLowering,
1551 DeallocOpLowering,
1552 GlobalMemrefOpLowering,
1553 GetGlobalMemrefOpLowering,
1554 LoadOpLowering,
1555 MemRefCastOpLowering,
1556 MemRefCopyOpLowering,
1557 MemRefReinterpretCastOpLowering,
1558 MemRefReshapeOpLowering,
1559 PrefetchOpLowering,
1560 ReassociatingReshapeOpConversion<memref::ExpandShapeOp>,
1561 ReassociatingReshapeOpConversion<memref::CollapseShapeOp>,
1562 StoreOpLowering,
1563 SubViewOpLowering,
1564 TransposeOpLowering,
1565 ViewOpLowering>(converter);
1566 // clang-format on
1567 auto allocLowering = converter.getOptions().allocLowering;
1568 if (allocLowering == LowerToLLVMOptions::AllocLowering::AlignedAlloc)
1569 patterns.add<AlignedAllocOpLowering, DeallocOpLowering>(converter);
1570 else if (allocLowering == LowerToLLVMOptions::AllocLowering::Malloc)
1571 patterns.add<AllocOpLowering, DeallocOpLowering>(converter);
1572 }
1573
1574 namespace {
1575 struct MemRefToLLVMPass : public ConvertMemRefToLLVMBase<MemRefToLLVMPass> {
1576 MemRefToLLVMPass() = default;
1577
runOnOperation__anona8fa999b0911::MemRefToLLVMPass1578 void runOnOperation() override {
1579 Operation *op = getOperation();
1580 const auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
1581 LowerToLLVMOptions options(&getContext(),
1582 dataLayoutAnalysis.getAtOrAbove(op));
1583 options.allocLowering =
1584 (useAlignedAlloc ? LowerToLLVMOptions::AllocLowering::AlignedAlloc
1585 : LowerToLLVMOptions::AllocLowering::Malloc);
1586 if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
1587 options.overrideIndexBitwidth(indexBitwidth);
1588
1589 LLVMTypeConverter typeConverter(&getContext(), options,
1590 &dataLayoutAnalysis);
1591 RewritePatternSet patterns(&getContext());
1592 populateMemRefToLLVMConversionPatterns(typeConverter, patterns);
1593 LLVMConversionTarget target(getContext());
1594 target.addLegalOp<FuncOp>();
1595 if (failed(applyPartialConversion(op, target, std::move(patterns))))
1596 signalPassFailure();
1597 }
1598 };
1599 } // namespace
1600
createMemRefToLLVMPass()1601 std::unique_ptr<Pass> mlir::createMemRefToLLVMPass() {
1602 return std::make_unique<MemRefToLLVMPass>();
1603 }
1604