1 //===- MemRefBuilder.cpp - Helper for LLVM MemRef equivalents -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Conversion/LLVMCommon/MemRefBuilder.h"
10 #include "MemRefDescriptor.h"
11 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
12 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
13 #include "mlir/IR/Builders.h"
14 #include "mlir/Support/MathExtras.h"
15
16 using namespace mlir;
17
18 //===----------------------------------------------------------------------===//
19 // MemRefDescriptor implementation
20 //===----------------------------------------------------------------------===//
21
22 /// Construct a helper for the given descriptor value.
MemRefDescriptor(Value descriptor)23 MemRefDescriptor::MemRefDescriptor(Value descriptor)
24 : StructBuilder(descriptor) {
25 assert(value != nullptr && "value cannot be null");
26 indexType = value.getType()
27 .cast<LLVM::LLVMStructType>()
28 .getBody()[kOffsetPosInMemRefDescriptor];
29 }
30
31 /// Builds IR creating an `undef` value of the descriptor type.
undef(OpBuilder & builder,Location loc,Type descriptorType)32 MemRefDescriptor MemRefDescriptor::undef(OpBuilder &builder, Location loc,
33 Type descriptorType) {
34
35 Value descriptor = builder.create<LLVM::UndefOp>(loc, descriptorType);
36 return MemRefDescriptor(descriptor);
37 }
38
39 /// Builds IR creating a MemRef descriptor that represents `type` and
40 /// populates it with static shape and stride information extracted from the
41 /// type.
42 MemRefDescriptor
fromStaticShape(OpBuilder & builder,Location loc,LLVMTypeConverter & typeConverter,MemRefType type,Value memory)43 MemRefDescriptor::fromStaticShape(OpBuilder &builder, Location loc,
44 LLVMTypeConverter &typeConverter,
45 MemRefType type, Value memory) {
46 assert(type.hasStaticShape() && "unexpected dynamic shape");
47
48 // Extract all strides and offsets and verify they are static.
49 int64_t offset;
50 SmallVector<int64_t, 4> strides;
51 auto result = getStridesAndOffset(type, strides, offset);
52 (void)result;
53 assert(succeeded(result) && "unexpected failure in stride computation");
54 assert(!MemRefType::isDynamicStrideOrOffset(offset) &&
55 "expected static offset");
56 assert(!llvm::any_of(strides, [](int64_t stride) {
57 return MemRefType::isDynamicStrideOrOffset(stride);
58 }) && "expected static strides");
59
60 auto convertedType = typeConverter.convertType(type);
61 assert(convertedType && "unexpected failure in memref type conversion");
62
63 auto descr = MemRefDescriptor::undef(builder, loc, convertedType);
64 descr.setAllocatedPtr(builder, loc, memory);
65 descr.setAlignedPtr(builder, loc, memory);
66 descr.setConstantOffset(builder, loc, offset);
67
68 // Fill in sizes and strides
69 for (unsigned i = 0, e = type.getRank(); i != e; ++i) {
70 descr.setConstantSize(builder, loc, i, type.getDimSize(i));
71 descr.setConstantStride(builder, loc, i, strides[i]);
72 }
73 return descr;
74 }
75
76 /// Builds IR extracting the allocated pointer from the descriptor.
allocatedPtr(OpBuilder & builder,Location loc)77 Value MemRefDescriptor::allocatedPtr(OpBuilder &builder, Location loc) {
78 return extractPtr(builder, loc, kAllocatedPtrPosInMemRefDescriptor);
79 }
80
81 /// Builds IR inserting the allocated pointer into the descriptor.
setAllocatedPtr(OpBuilder & builder,Location loc,Value ptr)82 void MemRefDescriptor::setAllocatedPtr(OpBuilder &builder, Location loc,
83 Value ptr) {
84 setPtr(builder, loc, kAllocatedPtrPosInMemRefDescriptor, ptr);
85 }
86
87 /// Builds IR extracting the aligned pointer from the descriptor.
alignedPtr(OpBuilder & builder,Location loc)88 Value MemRefDescriptor::alignedPtr(OpBuilder &builder, Location loc) {
89 return extractPtr(builder, loc, kAlignedPtrPosInMemRefDescriptor);
90 }
91
92 /// Builds IR inserting the aligned pointer into the descriptor.
setAlignedPtr(OpBuilder & builder,Location loc,Value ptr)93 void MemRefDescriptor::setAlignedPtr(OpBuilder &builder, Location loc,
94 Value ptr) {
95 setPtr(builder, loc, kAlignedPtrPosInMemRefDescriptor, ptr);
96 }
97
98 // Creates a constant Op producing a value of `resultType` from an index-typed
99 // integer attribute.
createIndexAttrConstant(OpBuilder & builder,Location loc,Type resultType,int64_t value)100 static Value createIndexAttrConstant(OpBuilder &builder, Location loc,
101 Type resultType, int64_t value) {
102 return builder.create<LLVM::ConstantOp>(
103 loc, resultType, builder.getIntegerAttr(builder.getIndexType(), value));
104 }
105
106 /// Builds IR extracting the offset from the descriptor.
offset(OpBuilder & builder,Location loc)107 Value MemRefDescriptor::offset(OpBuilder &builder, Location loc) {
108 return builder.create<LLVM::ExtractValueOp>(
109 loc, indexType, value,
110 builder.getI64ArrayAttr(kOffsetPosInMemRefDescriptor));
111 }
112
113 /// Builds IR inserting the offset into the descriptor.
setOffset(OpBuilder & builder,Location loc,Value offset)114 void MemRefDescriptor::setOffset(OpBuilder &builder, Location loc,
115 Value offset) {
116 value = builder.create<LLVM::InsertValueOp>(
117 loc, structType, value, offset,
118 builder.getI64ArrayAttr(kOffsetPosInMemRefDescriptor));
119 }
120
121 /// Builds IR inserting the offset into the descriptor.
setConstantOffset(OpBuilder & builder,Location loc,uint64_t offset)122 void MemRefDescriptor::setConstantOffset(OpBuilder &builder, Location loc,
123 uint64_t offset) {
124 setOffset(builder, loc,
125 createIndexAttrConstant(builder, loc, indexType, offset));
126 }
127
128 /// Builds IR extracting the pos-th size from the descriptor.
size(OpBuilder & builder,Location loc,unsigned pos)129 Value MemRefDescriptor::size(OpBuilder &builder, Location loc, unsigned pos) {
130 return builder.create<LLVM::ExtractValueOp>(
131 loc, indexType, value,
132 builder.getI64ArrayAttr({kSizePosInMemRefDescriptor, pos}));
133 }
134
size(OpBuilder & builder,Location loc,Value pos,int64_t rank)135 Value MemRefDescriptor::size(OpBuilder &builder, Location loc, Value pos,
136 int64_t rank) {
137 auto indexPtrTy = LLVM::LLVMPointerType::get(indexType);
138 auto arrayTy = LLVM::LLVMArrayType::get(indexType, rank);
139 auto arrayPtrTy = LLVM::LLVMPointerType::get(arrayTy);
140
141 // Copy size values to stack-allocated memory.
142 auto zero = createIndexAttrConstant(builder, loc, indexType, 0);
143 auto one = createIndexAttrConstant(builder, loc, indexType, 1);
144 auto sizes = builder.create<LLVM::ExtractValueOp>(
145 loc, arrayTy, value,
146 builder.getI64ArrayAttr({kSizePosInMemRefDescriptor}));
147 auto sizesPtr =
148 builder.create<LLVM::AllocaOp>(loc, arrayPtrTy, one, /*alignment=*/0);
149 builder.create<LLVM::StoreOp>(loc, sizes, sizesPtr);
150
151 // Load an return size value of interest.
152 auto resultPtr = builder.create<LLVM::GEPOp>(loc, indexPtrTy, sizesPtr,
153 ValueRange({zero, pos}));
154 return builder.create<LLVM::LoadOp>(loc, resultPtr);
155 }
156
157 /// Builds IR inserting the pos-th size into the descriptor
setSize(OpBuilder & builder,Location loc,unsigned pos,Value size)158 void MemRefDescriptor::setSize(OpBuilder &builder, Location loc, unsigned pos,
159 Value size) {
160 value = builder.create<LLVM::InsertValueOp>(
161 loc, structType, value, size,
162 builder.getI64ArrayAttr({kSizePosInMemRefDescriptor, pos}));
163 }
164
setConstantSize(OpBuilder & builder,Location loc,unsigned pos,uint64_t size)165 void MemRefDescriptor::setConstantSize(OpBuilder &builder, Location loc,
166 unsigned pos, uint64_t size) {
167 setSize(builder, loc, pos,
168 createIndexAttrConstant(builder, loc, indexType, size));
169 }
170
171 /// Builds IR extracting the pos-th stride from the descriptor.
stride(OpBuilder & builder,Location loc,unsigned pos)172 Value MemRefDescriptor::stride(OpBuilder &builder, Location loc, unsigned pos) {
173 return builder.create<LLVM::ExtractValueOp>(
174 loc, indexType, value,
175 builder.getI64ArrayAttr({kStridePosInMemRefDescriptor, pos}));
176 }
177
178 /// Builds IR inserting the pos-th stride into the descriptor
setStride(OpBuilder & builder,Location loc,unsigned pos,Value stride)179 void MemRefDescriptor::setStride(OpBuilder &builder, Location loc, unsigned pos,
180 Value stride) {
181 value = builder.create<LLVM::InsertValueOp>(
182 loc, structType, value, stride,
183 builder.getI64ArrayAttr({kStridePosInMemRefDescriptor, pos}));
184 }
185
setConstantStride(OpBuilder & builder,Location loc,unsigned pos,uint64_t stride)186 void MemRefDescriptor::setConstantStride(OpBuilder &builder, Location loc,
187 unsigned pos, uint64_t stride) {
188 setStride(builder, loc, pos,
189 createIndexAttrConstant(builder, loc, indexType, stride));
190 }
191
getElementPtrType()192 LLVM::LLVMPointerType MemRefDescriptor::getElementPtrType() {
193 return value.getType()
194 .cast<LLVM::LLVMStructType>()
195 .getBody()[kAlignedPtrPosInMemRefDescriptor]
196 .cast<LLVM::LLVMPointerType>();
197 }
198
199 /// Creates a MemRef descriptor structure from a list of individual values
200 /// composing that descriptor, in the following order:
201 /// - allocated pointer;
202 /// - aligned pointer;
203 /// - offset;
204 /// - <rank> sizes;
205 /// - <rank> shapes;
206 /// where <rank> is the MemRef rank as provided in `type`.
pack(OpBuilder & builder,Location loc,LLVMTypeConverter & converter,MemRefType type,ValueRange values)207 Value MemRefDescriptor::pack(OpBuilder &builder, Location loc,
208 LLVMTypeConverter &converter, MemRefType type,
209 ValueRange values) {
210 Type llvmType = converter.convertType(type);
211 auto d = MemRefDescriptor::undef(builder, loc, llvmType);
212
213 d.setAllocatedPtr(builder, loc, values[kAllocatedPtrPosInMemRefDescriptor]);
214 d.setAlignedPtr(builder, loc, values[kAlignedPtrPosInMemRefDescriptor]);
215 d.setOffset(builder, loc, values[kOffsetPosInMemRefDescriptor]);
216
217 int64_t rank = type.getRank();
218 for (unsigned i = 0; i < rank; ++i) {
219 d.setSize(builder, loc, i, values[kSizePosInMemRefDescriptor + i]);
220 d.setStride(builder, loc, i, values[kSizePosInMemRefDescriptor + rank + i]);
221 }
222
223 return d;
224 }
225
226 /// Builds IR extracting individual elements of a MemRef descriptor structure
227 /// and returning them as `results` list.
unpack(OpBuilder & builder,Location loc,Value packed,MemRefType type,SmallVectorImpl<Value> & results)228 void MemRefDescriptor::unpack(OpBuilder &builder, Location loc, Value packed,
229 MemRefType type,
230 SmallVectorImpl<Value> &results) {
231 int64_t rank = type.getRank();
232 results.reserve(results.size() + getNumUnpackedValues(type));
233
234 MemRefDescriptor d(packed);
235 results.push_back(d.allocatedPtr(builder, loc));
236 results.push_back(d.alignedPtr(builder, loc));
237 results.push_back(d.offset(builder, loc));
238 for (int64_t i = 0; i < rank; ++i)
239 results.push_back(d.size(builder, loc, i));
240 for (int64_t i = 0; i < rank; ++i)
241 results.push_back(d.stride(builder, loc, i));
242 }
243
244 /// Returns the number of non-aggregate values that would be produced by
245 /// `unpack`.
getNumUnpackedValues(MemRefType type)246 unsigned MemRefDescriptor::getNumUnpackedValues(MemRefType type) {
247 // Two pointers, offset, <rank> sizes, <rank> shapes.
248 return 3 + 2 * type.getRank();
249 }
250
251 //===----------------------------------------------------------------------===//
252 // MemRefDescriptorView implementation.
253 //===----------------------------------------------------------------------===//
254
MemRefDescriptorView(ValueRange range)255 MemRefDescriptorView::MemRefDescriptorView(ValueRange range)
256 : rank((range.size() - kSizePosInMemRefDescriptor) / 2), elements(range) {}
257
allocatedPtr()258 Value MemRefDescriptorView::allocatedPtr() {
259 return elements[kAllocatedPtrPosInMemRefDescriptor];
260 }
261
alignedPtr()262 Value MemRefDescriptorView::alignedPtr() {
263 return elements[kAlignedPtrPosInMemRefDescriptor];
264 }
265
offset()266 Value MemRefDescriptorView::offset() {
267 return elements[kOffsetPosInMemRefDescriptor];
268 }
269
size(unsigned pos)270 Value MemRefDescriptorView::size(unsigned pos) {
271 return elements[kSizePosInMemRefDescriptor + pos];
272 }
273
stride(unsigned pos)274 Value MemRefDescriptorView::stride(unsigned pos) {
275 return elements[kSizePosInMemRefDescriptor + rank + pos];
276 }
277
278 //===----------------------------------------------------------------------===//
279 // UnrankedMemRefDescriptor implementation
280 //===----------------------------------------------------------------------===//
281
282 /// Construct a helper for the given descriptor value.
UnrankedMemRefDescriptor(Value descriptor)283 UnrankedMemRefDescriptor::UnrankedMemRefDescriptor(Value descriptor)
284 : StructBuilder(descriptor) {}
285
286 /// Builds IR creating an `undef` value of the descriptor type.
undef(OpBuilder & builder,Location loc,Type descriptorType)287 UnrankedMemRefDescriptor UnrankedMemRefDescriptor::undef(OpBuilder &builder,
288 Location loc,
289 Type descriptorType) {
290 Value descriptor = builder.create<LLVM::UndefOp>(loc, descriptorType);
291 return UnrankedMemRefDescriptor(descriptor);
292 }
rank(OpBuilder & builder,Location loc)293 Value UnrankedMemRefDescriptor::rank(OpBuilder &builder, Location loc) {
294 return extractPtr(builder, loc, kRankInUnrankedMemRefDescriptor);
295 }
setRank(OpBuilder & builder,Location loc,Value v)296 void UnrankedMemRefDescriptor::setRank(OpBuilder &builder, Location loc,
297 Value v) {
298 setPtr(builder, loc, kRankInUnrankedMemRefDescriptor, v);
299 }
memRefDescPtr(OpBuilder & builder,Location loc)300 Value UnrankedMemRefDescriptor::memRefDescPtr(OpBuilder &builder,
301 Location loc) {
302 return extractPtr(builder, loc, kPtrInUnrankedMemRefDescriptor);
303 }
setMemRefDescPtr(OpBuilder & builder,Location loc,Value v)304 void UnrankedMemRefDescriptor::setMemRefDescPtr(OpBuilder &builder,
305 Location loc, Value v) {
306 setPtr(builder, loc, kPtrInUnrankedMemRefDescriptor, v);
307 }
308
309 /// Builds IR populating an unranked MemRef descriptor structure from a list
310 /// of individual constituent values in the following order:
311 /// - rank of the memref;
312 /// - pointer to the memref descriptor.
pack(OpBuilder & builder,Location loc,LLVMTypeConverter & converter,UnrankedMemRefType type,ValueRange values)313 Value UnrankedMemRefDescriptor::pack(OpBuilder &builder, Location loc,
314 LLVMTypeConverter &converter,
315 UnrankedMemRefType type,
316 ValueRange values) {
317 Type llvmType = converter.convertType(type);
318 auto d = UnrankedMemRefDescriptor::undef(builder, loc, llvmType);
319
320 d.setRank(builder, loc, values[kRankInUnrankedMemRefDescriptor]);
321 d.setMemRefDescPtr(builder, loc, values[kPtrInUnrankedMemRefDescriptor]);
322 return d;
323 }
324
325 /// Builds IR extracting individual elements that compose an unranked memref
326 /// descriptor and returns them as `results` list.
unpack(OpBuilder & builder,Location loc,Value packed,SmallVectorImpl<Value> & results)327 void UnrankedMemRefDescriptor::unpack(OpBuilder &builder, Location loc,
328 Value packed,
329 SmallVectorImpl<Value> &results) {
330 UnrankedMemRefDescriptor d(packed);
331 results.reserve(results.size() + 2);
332 results.push_back(d.rank(builder, loc));
333 results.push_back(d.memRefDescPtr(builder, loc));
334 }
335
computeSizes(OpBuilder & builder,Location loc,LLVMTypeConverter & typeConverter,ArrayRef<UnrankedMemRefDescriptor> values,SmallVectorImpl<Value> & sizes)336 void UnrankedMemRefDescriptor::computeSizes(
337 OpBuilder &builder, Location loc, LLVMTypeConverter &typeConverter,
338 ArrayRef<UnrankedMemRefDescriptor> values, SmallVectorImpl<Value> &sizes) {
339 if (values.empty())
340 return;
341
342 // Cache the index type.
343 Type indexType = typeConverter.getIndexType();
344
345 // Initialize shared constants.
346 Value one = createIndexAttrConstant(builder, loc, indexType, 1);
347 Value two = createIndexAttrConstant(builder, loc, indexType, 2);
348 Value pointerSize = createIndexAttrConstant(
349 builder, loc, indexType, ceilDiv(typeConverter.getPointerBitwidth(), 8));
350 Value indexSize =
351 createIndexAttrConstant(builder, loc, indexType,
352 ceilDiv(typeConverter.getIndexTypeBitwidth(), 8));
353
354 sizes.reserve(sizes.size() + values.size());
355 for (UnrankedMemRefDescriptor desc : values) {
356 // Emit IR computing the memory necessary to store the descriptor. This
357 // assumes the descriptor to be
358 // { type*, type*, index, index[rank], index[rank] }
359 // and densely packed, so the total size is
360 // 2 * sizeof(pointer) + (1 + 2 * rank) * sizeof(index).
361 // TODO: consider including the actual size (including eventual padding due
362 // to data layout) into the unranked descriptor.
363 Value doublePointerSize =
364 builder.create<LLVM::MulOp>(loc, indexType, two, pointerSize);
365
366 // (1 + 2 * rank) * sizeof(index)
367 Value rank = desc.rank(builder, loc);
368 Value doubleRank = builder.create<LLVM::MulOp>(loc, indexType, two, rank);
369 Value doubleRankIncremented =
370 builder.create<LLVM::AddOp>(loc, indexType, doubleRank, one);
371 Value rankIndexSize = builder.create<LLVM::MulOp>(
372 loc, indexType, doubleRankIncremented, indexSize);
373
374 // Total allocation size.
375 Value allocationSize = builder.create<LLVM::AddOp>(
376 loc, indexType, doublePointerSize, rankIndexSize);
377 sizes.push_back(allocationSize);
378 }
379 }
380
allocatedPtr(OpBuilder & builder,Location loc,Value memRefDescPtr,Type elemPtrPtrType)381 Value UnrankedMemRefDescriptor::allocatedPtr(OpBuilder &builder, Location loc,
382 Value memRefDescPtr,
383 Type elemPtrPtrType) {
384
385 Value elementPtrPtr =
386 builder.create<LLVM::BitcastOp>(loc, elemPtrPtrType, memRefDescPtr);
387 return builder.create<LLVM::LoadOp>(loc, elementPtrPtr);
388 }
389
setAllocatedPtr(OpBuilder & builder,Location loc,Value memRefDescPtr,Type elemPtrPtrType,Value allocatedPtr)390 void UnrankedMemRefDescriptor::setAllocatedPtr(OpBuilder &builder, Location loc,
391 Value memRefDescPtr,
392 Type elemPtrPtrType,
393 Value allocatedPtr) {
394 Value elementPtrPtr =
395 builder.create<LLVM::BitcastOp>(loc, elemPtrPtrType, memRefDescPtr);
396 builder.create<LLVM::StoreOp>(loc, allocatedPtr, elementPtrPtr);
397 }
398
alignedPtr(OpBuilder & builder,Location loc,LLVMTypeConverter & typeConverter,Value memRefDescPtr,Type elemPtrPtrType)399 Value UnrankedMemRefDescriptor::alignedPtr(OpBuilder &builder, Location loc,
400 LLVMTypeConverter &typeConverter,
401 Value memRefDescPtr,
402 Type elemPtrPtrType) {
403 Value elementPtrPtr =
404 builder.create<LLVM::BitcastOp>(loc, elemPtrPtrType, memRefDescPtr);
405
406 Value one =
407 createIndexAttrConstant(builder, loc, typeConverter.getIndexType(), 1);
408 Value alignedGep = builder.create<LLVM::GEPOp>(
409 loc, elemPtrPtrType, elementPtrPtr, ValueRange({one}));
410 return builder.create<LLVM::LoadOp>(loc, alignedGep);
411 }
412
setAlignedPtr(OpBuilder & builder,Location loc,LLVMTypeConverter & typeConverter,Value memRefDescPtr,Type elemPtrPtrType,Value alignedPtr)413 void UnrankedMemRefDescriptor::setAlignedPtr(OpBuilder &builder, Location loc,
414 LLVMTypeConverter &typeConverter,
415 Value memRefDescPtr,
416 Type elemPtrPtrType,
417 Value alignedPtr) {
418 Value elementPtrPtr =
419 builder.create<LLVM::BitcastOp>(loc, elemPtrPtrType, memRefDescPtr);
420
421 Value one =
422 createIndexAttrConstant(builder, loc, typeConverter.getIndexType(), 1);
423 Value alignedGep = builder.create<LLVM::GEPOp>(
424 loc, elemPtrPtrType, elementPtrPtr, ValueRange({one}));
425 builder.create<LLVM::StoreOp>(loc, alignedPtr, alignedGep);
426 }
427
offset(OpBuilder & builder,Location loc,LLVMTypeConverter & typeConverter,Value memRefDescPtr,Type elemPtrPtrType)428 Value UnrankedMemRefDescriptor::offset(OpBuilder &builder, Location loc,
429 LLVMTypeConverter &typeConverter,
430 Value memRefDescPtr,
431 Type elemPtrPtrType) {
432 Value elementPtrPtr =
433 builder.create<LLVM::BitcastOp>(loc, elemPtrPtrType, memRefDescPtr);
434
435 Value two =
436 createIndexAttrConstant(builder, loc, typeConverter.getIndexType(), 2);
437 Value offsetGep = builder.create<LLVM::GEPOp>(
438 loc, elemPtrPtrType, elementPtrPtr, ValueRange({two}));
439 offsetGep = builder.create<LLVM::BitcastOp>(
440 loc, LLVM::LLVMPointerType::get(typeConverter.getIndexType()), offsetGep);
441 return builder.create<LLVM::LoadOp>(loc, offsetGep);
442 }
443
setOffset(OpBuilder & builder,Location loc,LLVMTypeConverter & typeConverter,Value memRefDescPtr,Type elemPtrPtrType,Value offset)444 void UnrankedMemRefDescriptor::setOffset(OpBuilder &builder, Location loc,
445 LLVMTypeConverter &typeConverter,
446 Value memRefDescPtr,
447 Type elemPtrPtrType, Value offset) {
448 Value elementPtrPtr =
449 builder.create<LLVM::BitcastOp>(loc, elemPtrPtrType, memRefDescPtr);
450
451 Value two =
452 createIndexAttrConstant(builder, loc, typeConverter.getIndexType(), 2);
453 Value offsetGep = builder.create<LLVM::GEPOp>(
454 loc, elemPtrPtrType, elementPtrPtr, ValueRange({two}));
455 offsetGep = builder.create<LLVM::BitcastOp>(
456 loc, LLVM::LLVMPointerType::get(typeConverter.getIndexType()), offsetGep);
457 builder.create<LLVM::StoreOp>(loc, offset, offsetGep);
458 }
459
sizeBasePtr(OpBuilder & builder,Location loc,LLVMTypeConverter & typeConverter,Value memRefDescPtr,LLVM::LLVMPointerType elemPtrPtrType)460 Value UnrankedMemRefDescriptor::sizeBasePtr(
461 OpBuilder &builder, Location loc, LLVMTypeConverter &typeConverter,
462 Value memRefDescPtr, LLVM::LLVMPointerType elemPtrPtrType) {
463 Type elemPtrTy = elemPtrPtrType.getElementType();
464 Type indexTy = typeConverter.getIndexType();
465 Type structPtrTy =
466 LLVM::LLVMPointerType::get(LLVM::LLVMStructType::getLiteral(
467 indexTy.getContext(), {elemPtrTy, elemPtrTy, indexTy, indexTy}));
468 Value structPtr =
469 builder.create<LLVM::BitcastOp>(loc, structPtrTy, memRefDescPtr);
470
471 Type int32_type = typeConverter.convertType(builder.getI32Type());
472 Value zero =
473 createIndexAttrConstant(builder, loc, typeConverter.getIndexType(), 0);
474 Value three = builder.create<LLVM::ConstantOp>(loc, int32_type,
475 builder.getI32IntegerAttr(3));
476 return builder.create<LLVM::GEPOp>(loc, LLVM::LLVMPointerType::get(indexTy),
477 structPtr, ValueRange({zero, three}));
478 }
479
size(OpBuilder & builder,Location loc,LLVMTypeConverter typeConverter,Value sizeBasePtr,Value index)480 Value UnrankedMemRefDescriptor::size(OpBuilder &builder, Location loc,
481 LLVMTypeConverter typeConverter,
482 Value sizeBasePtr, Value index) {
483 Type indexPtrTy = LLVM::LLVMPointerType::get(typeConverter.getIndexType());
484 Value sizeStoreGep = builder.create<LLVM::GEPOp>(loc, indexPtrTy, sizeBasePtr,
485 ValueRange({index}));
486 return builder.create<LLVM::LoadOp>(loc, sizeStoreGep);
487 }
488
setSize(OpBuilder & builder,Location loc,LLVMTypeConverter typeConverter,Value sizeBasePtr,Value index,Value size)489 void UnrankedMemRefDescriptor::setSize(OpBuilder &builder, Location loc,
490 LLVMTypeConverter typeConverter,
491 Value sizeBasePtr, Value index,
492 Value size) {
493 Type indexPtrTy = LLVM::LLVMPointerType::get(typeConverter.getIndexType());
494 Value sizeStoreGep = builder.create<LLVM::GEPOp>(loc, indexPtrTy, sizeBasePtr,
495 ValueRange({index}));
496 builder.create<LLVM::StoreOp>(loc, size, sizeStoreGep);
497 }
498
strideBasePtr(OpBuilder & builder,Location loc,LLVMTypeConverter & typeConverter,Value sizeBasePtr,Value rank)499 Value UnrankedMemRefDescriptor::strideBasePtr(OpBuilder &builder, Location loc,
500 LLVMTypeConverter &typeConverter,
501 Value sizeBasePtr, Value rank) {
502 Type indexPtrTy = LLVM::LLVMPointerType::get(typeConverter.getIndexType());
503 return builder.create<LLVM::GEPOp>(loc, indexPtrTy, sizeBasePtr,
504 ValueRange({rank}));
505 }
506
stride(OpBuilder & builder,Location loc,LLVMTypeConverter typeConverter,Value strideBasePtr,Value index,Value stride)507 Value UnrankedMemRefDescriptor::stride(OpBuilder &builder, Location loc,
508 LLVMTypeConverter typeConverter,
509 Value strideBasePtr, Value index,
510 Value stride) {
511 Type indexPtrTy = LLVM::LLVMPointerType::get(typeConverter.getIndexType());
512 Value strideStoreGep = builder.create<LLVM::GEPOp>(
513 loc, indexPtrTy, strideBasePtr, ValueRange({index}));
514 return builder.create<LLVM::LoadOp>(loc, strideStoreGep);
515 }
516
setStride(OpBuilder & builder,Location loc,LLVMTypeConverter typeConverter,Value strideBasePtr,Value index,Value stride)517 void UnrankedMemRefDescriptor::setStride(OpBuilder &builder, Location loc,
518 LLVMTypeConverter typeConverter,
519 Value strideBasePtr, Value index,
520 Value stride) {
521 Type indexPtrTy = LLVM::LLVMPointerType::get(typeConverter.getIndexType());
522 Value strideStoreGep = builder.create<LLVM::GEPOp>(
523 loc, indexPtrTy, strideBasePtr, ValueRange({index}));
524 builder.create<LLVM::StoreOp>(loc, stride, strideStoreGep);
525 }
526