1 //===- MemRefBuilder.h - Helper for LLVM MemRef equivalents -----*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Provides a convenience API for emitting IR that inspects or constructs values 10 // of LLVM dialect structure type that correspond to ranked or unranked memref. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #ifndef MLIR_CONVERSION_LLVMCOMMON_MEMREFBUILDER_H 15 #define MLIR_CONVERSION_LLVMCOMMON_MEMREFBUILDER_H 16 17 #include "mlir/Conversion/LLVMCommon/StructBuilder.h" 18 #include "mlir/IR/OperationSupport.h" 19 20 namespace mlir { 21 22 class LLVMTypeConverter; 23 class MemRefType; 24 class UnrankedMemRefType; 25 26 namespace LLVM { 27 class LLVMPointerType; 28 } // namespace LLVM 29 30 /// Helper class to produce LLVM dialect operations extracting or inserting 31 /// elements of a MemRef descriptor. Wraps a Value pointing to the descriptor. 32 /// The Value may be null, in which case none of the operations are valid. 33 class MemRefDescriptor : public StructBuilder { 34 public: 35 /// Construct a helper for the given descriptor value. 36 explicit MemRefDescriptor(Value descriptor); 37 /// Builds IR creating an `undef` value of the descriptor type. 38 static MemRefDescriptor undef(OpBuilder &builder, Location loc, 39 Type descriptorType); 40 /// Builds IR creating a MemRef descriptor that represents `type` and 41 /// populates it with static shape and stride information extracted from the 42 /// type. 43 static MemRefDescriptor fromStaticShape(OpBuilder &builder, Location loc, 44 LLVMTypeConverter &typeConverter, 45 MemRefType type, Value memory); 46 47 /// Builds IR extracting the allocated pointer from the descriptor. 48 Value allocatedPtr(OpBuilder &builder, Location loc); 49 /// Builds IR inserting the allocated pointer into the descriptor. 50 void setAllocatedPtr(OpBuilder &builder, Location loc, Value ptr); 51 52 /// Builds IR extracting the aligned pointer from the descriptor. 53 Value alignedPtr(OpBuilder &builder, Location loc); 54 55 /// Builds IR inserting the aligned pointer into the descriptor. 56 void setAlignedPtr(OpBuilder &builder, Location loc, Value ptr); 57 58 /// Builds IR extracting the offset from the descriptor. 59 Value offset(OpBuilder &builder, Location loc); 60 61 /// Builds IR inserting the offset into the descriptor. 62 void setOffset(OpBuilder &builder, Location loc, Value offset); 63 void setConstantOffset(OpBuilder &builder, Location loc, uint64_t offset); 64 65 /// Builds IR extracting the pos-th size from the descriptor. 66 Value size(OpBuilder &builder, Location loc, unsigned pos); 67 Value size(OpBuilder &builder, Location loc, Value pos, int64_t rank); 68 69 /// Builds IR inserting the pos-th size into the descriptor 70 void setSize(OpBuilder &builder, Location loc, unsigned pos, Value size); 71 void setConstantSize(OpBuilder &builder, Location loc, unsigned pos, 72 uint64_t size); 73 74 /// Builds IR extracting the pos-th size from the descriptor. 75 Value stride(OpBuilder &builder, Location loc, unsigned pos); 76 77 /// Builds IR inserting the pos-th stride into the descriptor 78 void setStride(OpBuilder &builder, Location loc, unsigned pos, Value stride); 79 void setConstantStride(OpBuilder &builder, Location loc, unsigned pos, 80 uint64_t stride); 81 82 /// Returns the (LLVM) pointer type this descriptor contains. 83 LLVM::LLVMPointerType getElementPtrType(); 84 85 /// Builds IR populating a MemRef descriptor structure from a list of 86 /// individual values composing that descriptor, in the following order: 87 /// - allocated pointer; 88 /// - aligned pointer; 89 /// - offset; 90 /// - <rank> sizes; 91 /// - <rank> shapes; 92 /// where <rank> is the MemRef rank as provided in `type`. 93 static Value pack(OpBuilder &builder, Location loc, 94 LLVMTypeConverter &converter, MemRefType type, 95 ValueRange values); 96 97 /// Builds IR extracting individual elements of a MemRef descriptor structure 98 /// and returning them as `results` list. 99 static void unpack(OpBuilder &builder, Location loc, Value packed, 100 MemRefType type, SmallVectorImpl<Value> &results); 101 102 /// Returns the number of non-aggregate values that would be produced by 103 /// `unpack`. 104 static unsigned getNumUnpackedValues(MemRefType type); 105 106 private: 107 // Cached index type. 108 Type indexType; 109 }; 110 111 /// Helper class allowing the user to access a range of Values that correspond 112 /// to an unpacked memref descriptor using named accessors. This does not own 113 /// the values. 114 class MemRefDescriptorView { 115 public: 116 /// Constructs the view from a range of values. Infers the rank from the size 117 /// of the range. 118 explicit MemRefDescriptorView(ValueRange range); 119 120 /// Returns the allocated pointer Value. 121 Value allocatedPtr(); 122 123 /// Returns the aligned pointer Value. 124 Value alignedPtr(); 125 126 /// Returns the offset Value. 127 Value offset(); 128 129 /// Returns the pos-th size Value. 130 Value size(unsigned pos); 131 132 /// Returns the pos-th stride Value. 133 Value stride(unsigned pos); 134 135 private: 136 /// Rank of the memref the descriptor is pointing to. 137 int rank; 138 /// Underlying range of Values. 139 ValueRange elements; 140 }; 141 142 class UnrankedMemRefDescriptor : public StructBuilder { 143 public: 144 /// Construct a helper for the given descriptor value. 145 explicit UnrankedMemRefDescriptor(Value descriptor); 146 /// Builds IR creating an `undef` value of the descriptor type. 147 static UnrankedMemRefDescriptor undef(OpBuilder &builder, Location loc, 148 Type descriptorType); 149 150 /// Builds IR extracting the rank from the descriptor 151 Value rank(OpBuilder &builder, Location loc); 152 /// Builds IR setting the rank in the descriptor 153 void setRank(OpBuilder &builder, Location loc, Value value); 154 /// Builds IR extracting ranked memref descriptor ptr 155 Value memRefDescPtr(OpBuilder &builder, Location loc); 156 /// Builds IR setting ranked memref descriptor ptr 157 void setMemRefDescPtr(OpBuilder &builder, Location loc, Value value); 158 159 /// Builds IR populating an unranked MemRef descriptor structure from a list 160 /// of individual constituent values in the following order: 161 /// - rank of the memref; 162 /// - pointer to the memref descriptor. 163 static Value pack(OpBuilder &builder, Location loc, 164 LLVMTypeConverter &converter, UnrankedMemRefType type, 165 ValueRange values); 166 167 /// Builds IR extracting individual elements that compose an unranked memref 168 /// descriptor and returns them as `results` list. 169 static void unpack(OpBuilder &builder, Location loc, Value packed, 170 SmallVectorImpl<Value> &results); 171 172 /// Returns the number of non-aggregate values that would be produced by 173 /// `unpack`. getNumUnpackedValues()174 static unsigned getNumUnpackedValues() { return 2; } 175 176 /// Builds IR computing the sizes in bytes (suitable for opaque allocation) 177 /// and appends the corresponding values into `sizes`. 178 static void computeSizes(OpBuilder &builder, Location loc, 179 LLVMTypeConverter &typeConverter, 180 ArrayRef<UnrankedMemRefDescriptor> values, 181 SmallVectorImpl<Value> &sizes); 182 183 /// TODO: The following accessors don't take alignment rules between elements 184 /// of the descriptor struct into account. For some architectures, it might be 185 /// necessary to extend them and to use `llvm::DataLayout` contained in 186 /// `LLVMTypeConverter`. 187 188 /// Builds IR extracting the allocated pointer from the descriptor. 189 static Value allocatedPtr(OpBuilder &builder, Location loc, 190 Value memRefDescPtr, Type elemPtrPtrType); 191 /// Builds IR inserting the allocated pointer into the descriptor. 192 static void setAllocatedPtr(OpBuilder &builder, Location loc, 193 Value memRefDescPtr, Type elemPtrPtrType, 194 Value allocatedPtr); 195 196 /// Builds IR extracting the aligned pointer from the descriptor. 197 static Value alignedPtr(OpBuilder &builder, Location loc, 198 LLVMTypeConverter &typeConverter, Value memRefDescPtr, 199 Type elemPtrPtrType); 200 /// Builds IR inserting the aligned pointer into the descriptor. 201 static void setAlignedPtr(OpBuilder &builder, Location loc, 202 LLVMTypeConverter &typeConverter, 203 Value memRefDescPtr, Type elemPtrPtrType, 204 Value alignedPtr); 205 206 /// Builds IR extracting the offset from the descriptor. 207 static Value offset(OpBuilder &builder, Location loc, 208 LLVMTypeConverter &typeConverter, Value memRefDescPtr, 209 Type elemPtrPtrType); 210 /// Builds IR inserting the offset into the descriptor. 211 static void setOffset(OpBuilder &builder, Location loc, 212 LLVMTypeConverter &typeConverter, Value memRefDescPtr, 213 Type elemPtrPtrType, Value offset); 214 215 /// Builds IR extracting the pointer to the first element of the size array. 216 static Value sizeBasePtr(OpBuilder &builder, Location loc, 217 LLVMTypeConverter &typeConverter, 218 Value memRefDescPtr, 219 LLVM::LLVMPointerType elemPtrPtrType); 220 /// Builds IR extracting the size[index] from the descriptor. 221 static Value size(OpBuilder &builder, Location loc, 222 LLVMTypeConverter typeConverter, Value sizeBasePtr, 223 Value index); 224 /// Builds IR inserting the size[index] into the descriptor. 225 static void setSize(OpBuilder &builder, Location loc, 226 LLVMTypeConverter typeConverter, Value sizeBasePtr, 227 Value index, Value size); 228 229 /// Builds IR extracting the pointer to the first element of the stride array. 230 static Value strideBasePtr(OpBuilder &builder, Location loc, 231 LLVMTypeConverter &typeConverter, 232 Value sizeBasePtr, Value rank); 233 /// Builds IR extracting the stride[index] from the descriptor. 234 static Value stride(OpBuilder &builder, Location loc, 235 LLVMTypeConverter typeConverter, Value strideBasePtr, 236 Value index, Value stride); 237 /// Builds IR inserting the stride[index] into the descriptor. 238 static void setStride(OpBuilder &builder, Location loc, 239 LLVMTypeConverter typeConverter, Value strideBasePtr, 240 Value index, Value stride); 241 }; 242 243 } // namespace mlir 244 245 #endif // MLIR_CONVERSION_LLVMCOMMON_MEMREFBUILDER_H_ 246