1 //===- MemRefBuilder.cpp - Helper for LLVM MemRef equivalents -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/LLVMCommon/MemRefBuilder.h"
10 #include "MemRefDescriptor.h"
11 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
12 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
13 #include "mlir/IR/Builders.h"
14 #include "mlir/Support/MathExtras.h"
15 
16 using namespace mlir;
17 
18 //===----------------------------------------------------------------------===//
19 // MemRefDescriptor implementation
20 //===----------------------------------------------------------------------===//
21 
22 /// Construct a helper for the given descriptor value.
MemRefDescriptor(Value descriptor)23 MemRefDescriptor::MemRefDescriptor(Value descriptor)
24     : StructBuilder(descriptor) {
25   assert(value != nullptr && "value cannot be null");
26   indexType = value.getType()
27                   .cast<LLVM::LLVMStructType>()
28                   .getBody()[kOffsetPosInMemRefDescriptor];
29 }
30 
31 /// Builds IR creating an `undef` value of the descriptor type.
undef(OpBuilder & builder,Location loc,Type descriptorType)32 MemRefDescriptor MemRefDescriptor::undef(OpBuilder &builder, Location loc,
33                                          Type descriptorType) {
34 
35   Value descriptor = builder.create<LLVM::UndefOp>(loc, descriptorType);
36   return MemRefDescriptor(descriptor);
37 }
38 
39 /// Builds IR creating a MemRef descriptor that represents `type` and
40 /// populates it with static shape and stride information extracted from the
41 /// type.
42 MemRefDescriptor
fromStaticShape(OpBuilder & builder,Location loc,LLVMTypeConverter & typeConverter,MemRefType type,Value memory)43 MemRefDescriptor::fromStaticShape(OpBuilder &builder, Location loc,
44                                   LLVMTypeConverter &typeConverter,
45                                   MemRefType type, Value memory) {
46   assert(type.hasStaticShape() && "unexpected dynamic shape");
47 
48   // Extract all strides and offsets and verify they are static.
49   int64_t offset;
50   SmallVector<int64_t, 4> strides;
51   auto result = getStridesAndOffset(type, strides, offset);
52   (void)result;
53   assert(succeeded(result) && "unexpected failure in stride computation");
54   assert(!MemRefType::isDynamicStrideOrOffset(offset) &&
55          "expected static offset");
56   assert(!llvm::any_of(strides, [](int64_t stride) {
57     return MemRefType::isDynamicStrideOrOffset(stride);
58   }) && "expected static strides");
59 
60   auto convertedType = typeConverter.convertType(type);
61   assert(convertedType && "unexpected failure in memref type conversion");
62 
63   auto descr = MemRefDescriptor::undef(builder, loc, convertedType);
64   descr.setAllocatedPtr(builder, loc, memory);
65   descr.setAlignedPtr(builder, loc, memory);
66   descr.setConstantOffset(builder, loc, offset);
67 
68   // Fill in sizes and strides
69   for (unsigned i = 0, e = type.getRank(); i != e; ++i) {
70     descr.setConstantSize(builder, loc, i, type.getDimSize(i));
71     descr.setConstantStride(builder, loc, i, strides[i]);
72   }
73   return descr;
74 }
75 
76 /// Builds IR extracting the allocated pointer from the descriptor.
allocatedPtr(OpBuilder & builder,Location loc)77 Value MemRefDescriptor::allocatedPtr(OpBuilder &builder, Location loc) {
78   return extractPtr(builder, loc, kAllocatedPtrPosInMemRefDescriptor);
79 }
80 
81 /// Builds IR inserting the allocated pointer into the descriptor.
setAllocatedPtr(OpBuilder & builder,Location loc,Value ptr)82 void MemRefDescriptor::setAllocatedPtr(OpBuilder &builder, Location loc,
83                                        Value ptr) {
84   setPtr(builder, loc, kAllocatedPtrPosInMemRefDescriptor, ptr);
85 }
86 
87 /// Builds IR extracting the aligned pointer from the descriptor.
alignedPtr(OpBuilder & builder,Location loc)88 Value MemRefDescriptor::alignedPtr(OpBuilder &builder, Location loc) {
89   return extractPtr(builder, loc, kAlignedPtrPosInMemRefDescriptor);
90 }
91 
92 /// Builds IR inserting the aligned pointer into the descriptor.
setAlignedPtr(OpBuilder & builder,Location loc,Value ptr)93 void MemRefDescriptor::setAlignedPtr(OpBuilder &builder, Location loc,
94                                      Value ptr) {
95   setPtr(builder, loc, kAlignedPtrPosInMemRefDescriptor, ptr);
96 }
97 
98 // Creates a constant Op producing a value of `resultType` from an index-typed
99 // integer attribute.
createIndexAttrConstant(OpBuilder & builder,Location loc,Type resultType,int64_t value)100 static Value createIndexAttrConstant(OpBuilder &builder, Location loc,
101                                      Type resultType, int64_t value) {
102   return builder.create<LLVM::ConstantOp>(
103       loc, resultType, builder.getIntegerAttr(builder.getIndexType(), value));
104 }
105 
106 /// Builds IR extracting the offset from the descriptor.
offset(OpBuilder & builder,Location loc)107 Value MemRefDescriptor::offset(OpBuilder &builder, Location loc) {
108   return builder.create<LLVM::ExtractValueOp>(
109       loc, indexType, value,
110       builder.getI64ArrayAttr(kOffsetPosInMemRefDescriptor));
111 }
112 
113 /// Builds IR inserting the offset into the descriptor.
setOffset(OpBuilder & builder,Location loc,Value offset)114 void MemRefDescriptor::setOffset(OpBuilder &builder, Location loc,
115                                  Value offset) {
116   value = builder.create<LLVM::InsertValueOp>(
117       loc, structType, value, offset,
118       builder.getI64ArrayAttr(kOffsetPosInMemRefDescriptor));
119 }
120 
121 /// Builds IR inserting the offset into the descriptor.
setConstantOffset(OpBuilder & builder,Location loc,uint64_t offset)122 void MemRefDescriptor::setConstantOffset(OpBuilder &builder, Location loc,
123                                          uint64_t offset) {
124   setOffset(builder, loc,
125             createIndexAttrConstant(builder, loc, indexType, offset));
126 }
127 
128 /// Builds IR extracting the pos-th size from the descriptor.
size(OpBuilder & builder,Location loc,unsigned pos)129 Value MemRefDescriptor::size(OpBuilder &builder, Location loc, unsigned pos) {
130   return builder.create<LLVM::ExtractValueOp>(
131       loc, indexType, value,
132       builder.getI64ArrayAttr({kSizePosInMemRefDescriptor, pos}));
133 }
134 
size(OpBuilder & builder,Location loc,Value pos,int64_t rank)135 Value MemRefDescriptor::size(OpBuilder &builder, Location loc, Value pos,
136                              int64_t rank) {
137   auto indexPtrTy = LLVM::LLVMPointerType::get(indexType);
138   auto arrayTy = LLVM::LLVMArrayType::get(indexType, rank);
139   auto arrayPtrTy = LLVM::LLVMPointerType::get(arrayTy);
140 
141   // Copy size values to stack-allocated memory.
142   auto zero = createIndexAttrConstant(builder, loc, indexType, 0);
143   auto one = createIndexAttrConstant(builder, loc, indexType, 1);
144   auto sizes = builder.create<LLVM::ExtractValueOp>(
145       loc, arrayTy, value,
146       builder.getI64ArrayAttr({kSizePosInMemRefDescriptor}));
147   auto sizesPtr =
148       builder.create<LLVM::AllocaOp>(loc, arrayPtrTy, one, /*alignment=*/0);
149   builder.create<LLVM::StoreOp>(loc, sizes, sizesPtr);
150 
151   // Load an return size value of interest.
152   auto resultPtr = builder.create<LLVM::GEPOp>(loc, indexPtrTy, sizesPtr,
153                                                ValueRange({zero, pos}));
154   return builder.create<LLVM::LoadOp>(loc, resultPtr);
155 }
156 
157 /// Builds IR inserting the pos-th size into the descriptor
setSize(OpBuilder & builder,Location loc,unsigned pos,Value size)158 void MemRefDescriptor::setSize(OpBuilder &builder, Location loc, unsigned pos,
159                                Value size) {
160   value = builder.create<LLVM::InsertValueOp>(
161       loc, structType, value, size,
162       builder.getI64ArrayAttr({kSizePosInMemRefDescriptor, pos}));
163 }
164 
setConstantSize(OpBuilder & builder,Location loc,unsigned pos,uint64_t size)165 void MemRefDescriptor::setConstantSize(OpBuilder &builder, Location loc,
166                                        unsigned pos, uint64_t size) {
167   setSize(builder, loc, pos,
168           createIndexAttrConstant(builder, loc, indexType, size));
169 }
170 
171 /// Builds IR extracting the pos-th stride from the descriptor.
stride(OpBuilder & builder,Location loc,unsigned pos)172 Value MemRefDescriptor::stride(OpBuilder &builder, Location loc, unsigned pos) {
173   return builder.create<LLVM::ExtractValueOp>(
174       loc, indexType, value,
175       builder.getI64ArrayAttr({kStridePosInMemRefDescriptor, pos}));
176 }
177 
178 /// Builds IR inserting the pos-th stride into the descriptor
setStride(OpBuilder & builder,Location loc,unsigned pos,Value stride)179 void MemRefDescriptor::setStride(OpBuilder &builder, Location loc, unsigned pos,
180                                  Value stride) {
181   value = builder.create<LLVM::InsertValueOp>(
182       loc, structType, value, stride,
183       builder.getI64ArrayAttr({kStridePosInMemRefDescriptor, pos}));
184 }
185 
setConstantStride(OpBuilder & builder,Location loc,unsigned pos,uint64_t stride)186 void MemRefDescriptor::setConstantStride(OpBuilder &builder, Location loc,
187                                          unsigned pos, uint64_t stride) {
188   setStride(builder, loc, pos,
189             createIndexAttrConstant(builder, loc, indexType, stride));
190 }
191 
getElementPtrType()192 LLVM::LLVMPointerType MemRefDescriptor::getElementPtrType() {
193   return value.getType()
194       .cast<LLVM::LLVMStructType>()
195       .getBody()[kAlignedPtrPosInMemRefDescriptor]
196       .cast<LLVM::LLVMPointerType>();
197 }
198 
199 /// Creates a MemRef descriptor structure from a list of individual values
200 /// composing that descriptor, in the following order:
201 /// - allocated pointer;
202 /// - aligned pointer;
203 /// - offset;
204 /// - <rank> sizes;
205 /// - <rank> shapes;
206 /// where <rank> is the MemRef rank as provided in `type`.
pack(OpBuilder & builder,Location loc,LLVMTypeConverter & converter,MemRefType type,ValueRange values)207 Value MemRefDescriptor::pack(OpBuilder &builder, Location loc,
208                              LLVMTypeConverter &converter, MemRefType type,
209                              ValueRange values) {
210   Type llvmType = converter.convertType(type);
211   auto d = MemRefDescriptor::undef(builder, loc, llvmType);
212 
213   d.setAllocatedPtr(builder, loc, values[kAllocatedPtrPosInMemRefDescriptor]);
214   d.setAlignedPtr(builder, loc, values[kAlignedPtrPosInMemRefDescriptor]);
215   d.setOffset(builder, loc, values[kOffsetPosInMemRefDescriptor]);
216 
217   int64_t rank = type.getRank();
218   for (unsigned i = 0; i < rank; ++i) {
219     d.setSize(builder, loc, i, values[kSizePosInMemRefDescriptor + i]);
220     d.setStride(builder, loc, i, values[kSizePosInMemRefDescriptor + rank + i]);
221   }
222 
223   return d;
224 }
225 
226 /// Builds IR extracting individual elements of a MemRef descriptor structure
227 /// and returning them as `results` list.
unpack(OpBuilder & builder,Location loc,Value packed,MemRefType type,SmallVectorImpl<Value> & results)228 void MemRefDescriptor::unpack(OpBuilder &builder, Location loc, Value packed,
229                               MemRefType type,
230                               SmallVectorImpl<Value> &results) {
231   int64_t rank = type.getRank();
232   results.reserve(results.size() + getNumUnpackedValues(type));
233 
234   MemRefDescriptor d(packed);
235   results.push_back(d.allocatedPtr(builder, loc));
236   results.push_back(d.alignedPtr(builder, loc));
237   results.push_back(d.offset(builder, loc));
238   for (int64_t i = 0; i < rank; ++i)
239     results.push_back(d.size(builder, loc, i));
240   for (int64_t i = 0; i < rank; ++i)
241     results.push_back(d.stride(builder, loc, i));
242 }
243 
244 /// Returns the number of non-aggregate values that would be produced by
245 /// `unpack`.
getNumUnpackedValues(MemRefType type)246 unsigned MemRefDescriptor::getNumUnpackedValues(MemRefType type) {
247   // Two pointers, offset, <rank> sizes, <rank> shapes.
248   return 3 + 2 * type.getRank();
249 }
250 
251 //===----------------------------------------------------------------------===//
252 // MemRefDescriptorView implementation.
253 //===----------------------------------------------------------------------===//
254 
MemRefDescriptorView(ValueRange range)255 MemRefDescriptorView::MemRefDescriptorView(ValueRange range)
256     : rank((range.size() - kSizePosInMemRefDescriptor) / 2), elements(range) {}
257 
allocatedPtr()258 Value MemRefDescriptorView::allocatedPtr() {
259   return elements[kAllocatedPtrPosInMemRefDescriptor];
260 }
261 
alignedPtr()262 Value MemRefDescriptorView::alignedPtr() {
263   return elements[kAlignedPtrPosInMemRefDescriptor];
264 }
265 
offset()266 Value MemRefDescriptorView::offset() {
267   return elements[kOffsetPosInMemRefDescriptor];
268 }
269 
size(unsigned pos)270 Value MemRefDescriptorView::size(unsigned pos) {
271   return elements[kSizePosInMemRefDescriptor + pos];
272 }
273 
stride(unsigned pos)274 Value MemRefDescriptorView::stride(unsigned pos) {
275   return elements[kSizePosInMemRefDescriptor + rank + pos];
276 }
277 
278 //===----------------------------------------------------------------------===//
279 // UnrankedMemRefDescriptor implementation
280 //===----------------------------------------------------------------------===//
281 
282 /// Construct a helper for the given descriptor value.
UnrankedMemRefDescriptor(Value descriptor)283 UnrankedMemRefDescriptor::UnrankedMemRefDescriptor(Value descriptor)
284     : StructBuilder(descriptor) {}
285 
286 /// Builds IR creating an `undef` value of the descriptor type.
undef(OpBuilder & builder,Location loc,Type descriptorType)287 UnrankedMemRefDescriptor UnrankedMemRefDescriptor::undef(OpBuilder &builder,
288                                                          Location loc,
289                                                          Type descriptorType) {
290   Value descriptor = builder.create<LLVM::UndefOp>(loc, descriptorType);
291   return UnrankedMemRefDescriptor(descriptor);
292 }
rank(OpBuilder & builder,Location loc)293 Value UnrankedMemRefDescriptor::rank(OpBuilder &builder, Location loc) {
294   return extractPtr(builder, loc, kRankInUnrankedMemRefDescriptor);
295 }
setRank(OpBuilder & builder,Location loc,Value v)296 void UnrankedMemRefDescriptor::setRank(OpBuilder &builder, Location loc,
297                                        Value v) {
298   setPtr(builder, loc, kRankInUnrankedMemRefDescriptor, v);
299 }
memRefDescPtr(OpBuilder & builder,Location loc)300 Value UnrankedMemRefDescriptor::memRefDescPtr(OpBuilder &builder,
301                                               Location loc) {
302   return extractPtr(builder, loc, kPtrInUnrankedMemRefDescriptor);
303 }
setMemRefDescPtr(OpBuilder & builder,Location loc,Value v)304 void UnrankedMemRefDescriptor::setMemRefDescPtr(OpBuilder &builder,
305                                                 Location loc, Value v) {
306   setPtr(builder, loc, kPtrInUnrankedMemRefDescriptor, v);
307 }
308 
309 /// Builds IR populating an unranked MemRef descriptor structure from a list
310 /// of individual constituent values in the following order:
311 /// - rank of the memref;
312 /// - pointer to the memref descriptor.
pack(OpBuilder & builder,Location loc,LLVMTypeConverter & converter,UnrankedMemRefType type,ValueRange values)313 Value UnrankedMemRefDescriptor::pack(OpBuilder &builder, Location loc,
314                                      LLVMTypeConverter &converter,
315                                      UnrankedMemRefType type,
316                                      ValueRange values) {
317   Type llvmType = converter.convertType(type);
318   auto d = UnrankedMemRefDescriptor::undef(builder, loc, llvmType);
319 
320   d.setRank(builder, loc, values[kRankInUnrankedMemRefDescriptor]);
321   d.setMemRefDescPtr(builder, loc, values[kPtrInUnrankedMemRefDescriptor]);
322   return d;
323 }
324 
325 /// Builds IR extracting individual elements that compose an unranked memref
326 /// descriptor and returns them as `results` list.
unpack(OpBuilder & builder,Location loc,Value packed,SmallVectorImpl<Value> & results)327 void UnrankedMemRefDescriptor::unpack(OpBuilder &builder, Location loc,
328                                       Value packed,
329                                       SmallVectorImpl<Value> &results) {
330   UnrankedMemRefDescriptor d(packed);
331   results.reserve(results.size() + 2);
332   results.push_back(d.rank(builder, loc));
333   results.push_back(d.memRefDescPtr(builder, loc));
334 }
335 
computeSizes(OpBuilder & builder,Location loc,LLVMTypeConverter & typeConverter,ArrayRef<UnrankedMemRefDescriptor> values,SmallVectorImpl<Value> & sizes)336 void UnrankedMemRefDescriptor::computeSizes(
337     OpBuilder &builder, Location loc, LLVMTypeConverter &typeConverter,
338     ArrayRef<UnrankedMemRefDescriptor> values, SmallVectorImpl<Value> &sizes) {
339   if (values.empty())
340     return;
341 
342   // Cache the index type.
343   Type indexType = typeConverter.getIndexType();
344 
345   // Initialize shared constants.
346   Value one = createIndexAttrConstant(builder, loc, indexType, 1);
347   Value two = createIndexAttrConstant(builder, loc, indexType, 2);
348   Value pointerSize = createIndexAttrConstant(
349       builder, loc, indexType, ceilDiv(typeConverter.getPointerBitwidth(), 8));
350   Value indexSize =
351       createIndexAttrConstant(builder, loc, indexType,
352                               ceilDiv(typeConverter.getIndexTypeBitwidth(), 8));
353 
354   sizes.reserve(sizes.size() + values.size());
355   for (UnrankedMemRefDescriptor desc : values) {
356     // Emit IR computing the memory necessary to store the descriptor. This
357     // assumes the descriptor to be
358     //   { type*, type*, index, index[rank], index[rank] }
359     // and densely packed, so the total size is
360     //   2 * sizeof(pointer) + (1 + 2 * rank) * sizeof(index).
361     // TODO: consider including the actual size (including eventual padding due
362     // to data layout) into the unranked descriptor.
363     Value doublePointerSize =
364         builder.create<LLVM::MulOp>(loc, indexType, two, pointerSize);
365 
366     // (1 + 2 * rank) * sizeof(index)
367     Value rank = desc.rank(builder, loc);
368     Value doubleRank = builder.create<LLVM::MulOp>(loc, indexType, two, rank);
369     Value doubleRankIncremented =
370         builder.create<LLVM::AddOp>(loc, indexType, doubleRank, one);
371     Value rankIndexSize = builder.create<LLVM::MulOp>(
372         loc, indexType, doubleRankIncremented, indexSize);
373 
374     // Total allocation size.
375     Value allocationSize = builder.create<LLVM::AddOp>(
376         loc, indexType, doublePointerSize, rankIndexSize);
377     sizes.push_back(allocationSize);
378   }
379 }
380 
allocatedPtr(OpBuilder & builder,Location loc,Value memRefDescPtr,Type elemPtrPtrType)381 Value UnrankedMemRefDescriptor::allocatedPtr(OpBuilder &builder, Location loc,
382                                              Value memRefDescPtr,
383                                              Type elemPtrPtrType) {
384 
385   Value elementPtrPtr =
386       builder.create<LLVM::BitcastOp>(loc, elemPtrPtrType, memRefDescPtr);
387   return builder.create<LLVM::LoadOp>(loc, elementPtrPtr);
388 }
389 
setAllocatedPtr(OpBuilder & builder,Location loc,Value memRefDescPtr,Type elemPtrPtrType,Value allocatedPtr)390 void UnrankedMemRefDescriptor::setAllocatedPtr(OpBuilder &builder, Location loc,
391                                                Value memRefDescPtr,
392                                                Type elemPtrPtrType,
393                                                Value allocatedPtr) {
394   Value elementPtrPtr =
395       builder.create<LLVM::BitcastOp>(loc, elemPtrPtrType, memRefDescPtr);
396   builder.create<LLVM::StoreOp>(loc, allocatedPtr, elementPtrPtr);
397 }
398 
alignedPtr(OpBuilder & builder,Location loc,LLVMTypeConverter & typeConverter,Value memRefDescPtr,Type elemPtrPtrType)399 Value UnrankedMemRefDescriptor::alignedPtr(OpBuilder &builder, Location loc,
400                                            LLVMTypeConverter &typeConverter,
401                                            Value memRefDescPtr,
402                                            Type elemPtrPtrType) {
403   Value elementPtrPtr =
404       builder.create<LLVM::BitcastOp>(loc, elemPtrPtrType, memRefDescPtr);
405 
406   Value one =
407       createIndexAttrConstant(builder, loc, typeConverter.getIndexType(), 1);
408   Value alignedGep = builder.create<LLVM::GEPOp>(
409       loc, elemPtrPtrType, elementPtrPtr, ValueRange({one}));
410   return builder.create<LLVM::LoadOp>(loc, alignedGep);
411 }
412 
setAlignedPtr(OpBuilder & builder,Location loc,LLVMTypeConverter & typeConverter,Value memRefDescPtr,Type elemPtrPtrType,Value alignedPtr)413 void UnrankedMemRefDescriptor::setAlignedPtr(OpBuilder &builder, Location loc,
414                                              LLVMTypeConverter &typeConverter,
415                                              Value memRefDescPtr,
416                                              Type elemPtrPtrType,
417                                              Value alignedPtr) {
418   Value elementPtrPtr =
419       builder.create<LLVM::BitcastOp>(loc, elemPtrPtrType, memRefDescPtr);
420 
421   Value one =
422       createIndexAttrConstant(builder, loc, typeConverter.getIndexType(), 1);
423   Value alignedGep = builder.create<LLVM::GEPOp>(
424       loc, elemPtrPtrType, elementPtrPtr, ValueRange({one}));
425   builder.create<LLVM::StoreOp>(loc, alignedPtr, alignedGep);
426 }
427 
offset(OpBuilder & builder,Location loc,LLVMTypeConverter & typeConverter,Value memRefDescPtr,Type elemPtrPtrType)428 Value UnrankedMemRefDescriptor::offset(OpBuilder &builder, Location loc,
429                                        LLVMTypeConverter &typeConverter,
430                                        Value memRefDescPtr,
431                                        Type elemPtrPtrType) {
432   Value elementPtrPtr =
433       builder.create<LLVM::BitcastOp>(loc, elemPtrPtrType, memRefDescPtr);
434 
435   Value two =
436       createIndexAttrConstant(builder, loc, typeConverter.getIndexType(), 2);
437   Value offsetGep = builder.create<LLVM::GEPOp>(
438       loc, elemPtrPtrType, elementPtrPtr, ValueRange({two}));
439   offsetGep = builder.create<LLVM::BitcastOp>(
440       loc, LLVM::LLVMPointerType::get(typeConverter.getIndexType()), offsetGep);
441   return builder.create<LLVM::LoadOp>(loc, offsetGep);
442 }
443 
setOffset(OpBuilder & builder,Location loc,LLVMTypeConverter & typeConverter,Value memRefDescPtr,Type elemPtrPtrType,Value offset)444 void UnrankedMemRefDescriptor::setOffset(OpBuilder &builder, Location loc,
445                                          LLVMTypeConverter &typeConverter,
446                                          Value memRefDescPtr,
447                                          Type elemPtrPtrType, Value offset) {
448   Value elementPtrPtr =
449       builder.create<LLVM::BitcastOp>(loc, elemPtrPtrType, memRefDescPtr);
450 
451   Value two =
452       createIndexAttrConstant(builder, loc, typeConverter.getIndexType(), 2);
453   Value offsetGep = builder.create<LLVM::GEPOp>(
454       loc, elemPtrPtrType, elementPtrPtr, ValueRange({two}));
455   offsetGep = builder.create<LLVM::BitcastOp>(
456       loc, LLVM::LLVMPointerType::get(typeConverter.getIndexType()), offsetGep);
457   builder.create<LLVM::StoreOp>(loc, offset, offsetGep);
458 }
459 
sizeBasePtr(OpBuilder & builder,Location loc,LLVMTypeConverter & typeConverter,Value memRefDescPtr,LLVM::LLVMPointerType elemPtrPtrType)460 Value UnrankedMemRefDescriptor::sizeBasePtr(
461     OpBuilder &builder, Location loc, LLVMTypeConverter &typeConverter,
462     Value memRefDescPtr, LLVM::LLVMPointerType elemPtrPtrType) {
463   Type elemPtrTy = elemPtrPtrType.getElementType();
464   Type indexTy = typeConverter.getIndexType();
465   Type structPtrTy =
466       LLVM::LLVMPointerType::get(LLVM::LLVMStructType::getLiteral(
467           indexTy.getContext(), {elemPtrTy, elemPtrTy, indexTy, indexTy}));
468   Value structPtr =
469       builder.create<LLVM::BitcastOp>(loc, structPtrTy, memRefDescPtr);
470 
471   Type int32_type = typeConverter.convertType(builder.getI32Type());
472   Value zero =
473       createIndexAttrConstant(builder, loc, typeConverter.getIndexType(), 0);
474   Value three = builder.create<LLVM::ConstantOp>(loc, int32_type,
475                                                  builder.getI32IntegerAttr(3));
476   return builder.create<LLVM::GEPOp>(loc, LLVM::LLVMPointerType::get(indexTy),
477                                      structPtr, ValueRange({zero, three}));
478 }
479 
size(OpBuilder & builder,Location loc,LLVMTypeConverter typeConverter,Value sizeBasePtr,Value index)480 Value UnrankedMemRefDescriptor::size(OpBuilder &builder, Location loc,
481                                      LLVMTypeConverter typeConverter,
482                                      Value sizeBasePtr, Value index) {
483   Type indexPtrTy = LLVM::LLVMPointerType::get(typeConverter.getIndexType());
484   Value sizeStoreGep = builder.create<LLVM::GEPOp>(loc, indexPtrTy, sizeBasePtr,
485                                                    ValueRange({index}));
486   return builder.create<LLVM::LoadOp>(loc, sizeStoreGep);
487 }
488 
setSize(OpBuilder & builder,Location loc,LLVMTypeConverter typeConverter,Value sizeBasePtr,Value index,Value size)489 void UnrankedMemRefDescriptor::setSize(OpBuilder &builder, Location loc,
490                                        LLVMTypeConverter typeConverter,
491                                        Value sizeBasePtr, Value index,
492                                        Value size) {
493   Type indexPtrTy = LLVM::LLVMPointerType::get(typeConverter.getIndexType());
494   Value sizeStoreGep = builder.create<LLVM::GEPOp>(loc, indexPtrTy, sizeBasePtr,
495                                                    ValueRange({index}));
496   builder.create<LLVM::StoreOp>(loc, size, sizeStoreGep);
497 }
498 
strideBasePtr(OpBuilder & builder,Location loc,LLVMTypeConverter & typeConverter,Value sizeBasePtr,Value rank)499 Value UnrankedMemRefDescriptor::strideBasePtr(OpBuilder &builder, Location loc,
500                                               LLVMTypeConverter &typeConverter,
501                                               Value sizeBasePtr, Value rank) {
502   Type indexPtrTy = LLVM::LLVMPointerType::get(typeConverter.getIndexType());
503   return builder.create<LLVM::GEPOp>(loc, indexPtrTy, sizeBasePtr,
504                                      ValueRange({rank}));
505 }
506 
stride(OpBuilder & builder,Location loc,LLVMTypeConverter typeConverter,Value strideBasePtr,Value index,Value stride)507 Value UnrankedMemRefDescriptor::stride(OpBuilder &builder, Location loc,
508                                        LLVMTypeConverter typeConverter,
509                                        Value strideBasePtr, Value index,
510                                        Value stride) {
511   Type indexPtrTy = LLVM::LLVMPointerType::get(typeConverter.getIndexType());
512   Value strideStoreGep = builder.create<LLVM::GEPOp>(
513       loc, indexPtrTy, strideBasePtr, ValueRange({index}));
514   return builder.create<LLVM::LoadOp>(loc, strideStoreGep);
515 }
516 
setStride(OpBuilder & builder,Location loc,LLVMTypeConverter typeConverter,Value strideBasePtr,Value index,Value stride)517 void UnrankedMemRefDescriptor::setStride(OpBuilder &builder, Location loc,
518                                          LLVMTypeConverter typeConverter,
519                                          Value strideBasePtr, Value index,
520                                          Value stride) {
521   Type indexPtrTy = LLVM::LLVMPointerType::get(typeConverter.getIndexType());
522   Value strideStoreGep = builder.create<LLVM::GEPOp>(
523       loc, indexPtrTy, strideBasePtr, ValueRange({index}));
524   builder.create<LLVM::StoreOp>(loc, stride, strideStoreGep);
525 }
526