1 //===- MemRefBuilder.h - Helper for LLVM MemRef equivalents -----*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// Provides a convenience API for emitting IR that inspects or constructs values
// of the LLVM dialect structure types that correspond to ranked or unranked
// memrefs.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #ifndef MLIR_CONVERSION_LLVMCOMMON_MEMREFBUILDER_H
15 #define MLIR_CONVERSION_LLVMCOMMON_MEMREFBUILDER_H
16 
17 #include "mlir/Conversion/LLVMCommon/StructBuilder.h"
18 #include "mlir/IR/OperationSupport.h"
19 
20 namespace mlir {
21 
// Forward declarations of types that the helpers below only mention in
// declarations (as parameters, return types or references).
namespace LLVM {
class LLVMPointerType;
} // namespace LLVM

class LLVMTypeConverter;
class MemRefType;
class UnrankedMemRefType;
29 
30 /// Helper class to produce LLVM dialect operations extracting or inserting
31 /// elements of a MemRef descriptor. Wraps a Value pointing to the descriptor.
32 /// The Value may be null, in which case none of the operations are valid.
33 class MemRefDescriptor : public StructBuilder {
34 public:
35   /// Construct a helper for the given descriptor value.
36   explicit MemRefDescriptor(Value descriptor);
37   /// Builds IR creating an `undef` value of the descriptor type.
38   static MemRefDescriptor undef(OpBuilder &builder, Location loc,
39                                 Type descriptorType);
40   /// Builds IR creating a MemRef descriptor that represents `type` and
41   /// populates it with static shape and stride information extracted from the
42   /// type.
43   static MemRefDescriptor fromStaticShape(OpBuilder &builder, Location loc,
44                                           LLVMTypeConverter &typeConverter,
45                                           MemRefType type, Value memory);
46 
47   /// Builds IR extracting the allocated pointer from the descriptor.
48   Value allocatedPtr(OpBuilder &builder, Location loc);
49   /// Builds IR inserting the allocated pointer into the descriptor.
50   void setAllocatedPtr(OpBuilder &builder, Location loc, Value ptr);
51 
52   /// Builds IR extracting the aligned pointer from the descriptor.
53   Value alignedPtr(OpBuilder &builder, Location loc);
54 
55   /// Builds IR inserting the aligned pointer into the descriptor.
56   void setAlignedPtr(OpBuilder &builder, Location loc, Value ptr);
57 
58   /// Builds IR extracting the offset from the descriptor.
59   Value offset(OpBuilder &builder, Location loc);
60 
61   /// Builds IR inserting the offset into the descriptor.
62   void setOffset(OpBuilder &builder, Location loc, Value offset);
63   void setConstantOffset(OpBuilder &builder, Location loc, uint64_t offset);
64 
65   /// Builds IR extracting the pos-th size from the descriptor.
66   Value size(OpBuilder &builder, Location loc, unsigned pos);
67   Value size(OpBuilder &builder, Location loc, Value pos, int64_t rank);
68 
69   /// Builds IR inserting the pos-th size into the descriptor
70   void setSize(OpBuilder &builder, Location loc, unsigned pos, Value size);
71   void setConstantSize(OpBuilder &builder, Location loc, unsigned pos,
72                        uint64_t size);
73 
74   /// Builds IR extracting the pos-th size from the descriptor.
75   Value stride(OpBuilder &builder, Location loc, unsigned pos);
76 
77   /// Builds IR inserting the pos-th stride into the descriptor
78   void setStride(OpBuilder &builder, Location loc, unsigned pos, Value stride);
79   void setConstantStride(OpBuilder &builder, Location loc, unsigned pos,
80                          uint64_t stride);
81 
82   /// Returns the (LLVM) pointer type this descriptor contains.
83   LLVM::LLVMPointerType getElementPtrType();
84 
85   /// Builds IR populating a MemRef descriptor structure from a list of
86   /// individual values composing that descriptor, in the following order:
87   /// - allocated pointer;
88   /// - aligned pointer;
89   /// - offset;
90   /// - <rank> sizes;
91   /// - <rank> shapes;
92   /// where <rank> is the MemRef rank as provided in `type`.
93   static Value pack(OpBuilder &builder, Location loc,
94                     LLVMTypeConverter &converter, MemRefType type,
95                     ValueRange values);
96 
97   /// Builds IR extracting individual elements of a MemRef descriptor structure
98   /// and returning them as `results` list.
99   static void unpack(OpBuilder &builder, Location loc, Value packed,
100                      MemRefType type, SmallVectorImpl<Value> &results);
101 
102   /// Returns the number of non-aggregate values that would be produced by
103   /// `unpack`.
104   static unsigned getNumUnpackedValues(MemRefType type);
105 
106 private:
107   // Cached index type.
108   Type indexType;
109 };
110 
111 /// Helper class allowing the user to access a range of Values that correspond
112 /// to an unpacked memref descriptor using named accessors. This does not own
113 /// the values.
114 class MemRefDescriptorView {
115 public:
116   /// Constructs the view from a range of values. Infers the rank from the size
117   /// of the range.
118   explicit MemRefDescriptorView(ValueRange range);
119 
120   /// Returns the allocated pointer Value.
121   Value allocatedPtr();
122 
123   /// Returns the aligned pointer Value.
124   Value alignedPtr();
125 
126   /// Returns the offset Value.
127   Value offset();
128 
129   /// Returns the pos-th size Value.
130   Value size(unsigned pos);
131 
132   /// Returns the pos-th stride Value.
133   Value stride(unsigned pos);
134 
135 private:
136   /// Rank of the memref the descriptor is pointing to.
137   int rank;
138   /// Underlying range of Values.
139   ValueRange elements;
140 };
141 
142 class UnrankedMemRefDescriptor : public StructBuilder {
143 public:
144   /// Construct a helper for the given descriptor value.
145   explicit UnrankedMemRefDescriptor(Value descriptor);
146   /// Builds IR creating an `undef` value of the descriptor type.
147   static UnrankedMemRefDescriptor undef(OpBuilder &builder, Location loc,
148                                         Type descriptorType);
149 
150   /// Builds IR extracting the rank from the descriptor
151   Value rank(OpBuilder &builder, Location loc);
152   /// Builds IR setting the rank in the descriptor
153   void setRank(OpBuilder &builder, Location loc, Value value);
154   /// Builds IR extracting ranked memref descriptor ptr
155   Value memRefDescPtr(OpBuilder &builder, Location loc);
156   /// Builds IR setting ranked memref descriptor ptr
157   void setMemRefDescPtr(OpBuilder &builder, Location loc, Value value);
158 
159   /// Builds IR populating an unranked MemRef descriptor structure from a list
160   /// of individual constituent values in the following order:
161   /// - rank of the memref;
162   /// - pointer to the memref descriptor.
163   static Value pack(OpBuilder &builder, Location loc,
164                     LLVMTypeConverter &converter, UnrankedMemRefType type,
165                     ValueRange values);
166 
167   /// Builds IR extracting individual elements that compose an unranked memref
168   /// descriptor and returns them as `results` list.
169   static void unpack(OpBuilder &builder, Location loc, Value packed,
170                      SmallVectorImpl<Value> &results);
171 
172   /// Returns the number of non-aggregate values that would be produced by
173   /// `unpack`.
getNumUnpackedValues()174   static unsigned getNumUnpackedValues() { return 2; }
175 
176   /// Builds IR computing the sizes in bytes (suitable for opaque allocation)
177   /// and appends the corresponding values into `sizes`.
178   static void computeSizes(OpBuilder &builder, Location loc,
179                            LLVMTypeConverter &typeConverter,
180                            ArrayRef<UnrankedMemRefDescriptor> values,
181                            SmallVectorImpl<Value> &sizes);
182 
183   /// TODO: The following accessors don't take alignment rules between elements
184   /// of the descriptor struct into account. For some architectures, it might be
185   /// necessary to extend them and to use `llvm::DataLayout` contained in
186   /// `LLVMTypeConverter`.
187 
188   /// Builds IR extracting the allocated pointer from the descriptor.
189   static Value allocatedPtr(OpBuilder &builder, Location loc,
190                             Value memRefDescPtr, Type elemPtrPtrType);
191   /// Builds IR inserting the allocated pointer into the descriptor.
192   static void setAllocatedPtr(OpBuilder &builder, Location loc,
193                               Value memRefDescPtr, Type elemPtrPtrType,
194                               Value allocatedPtr);
195 
196   /// Builds IR extracting the aligned pointer from the descriptor.
197   static Value alignedPtr(OpBuilder &builder, Location loc,
198                           LLVMTypeConverter &typeConverter, Value memRefDescPtr,
199                           Type elemPtrPtrType);
200   /// Builds IR inserting the aligned pointer into the descriptor.
201   static void setAlignedPtr(OpBuilder &builder, Location loc,
202                             LLVMTypeConverter &typeConverter,
203                             Value memRefDescPtr, Type elemPtrPtrType,
204                             Value alignedPtr);
205 
206   /// Builds IR extracting the offset from the descriptor.
207   static Value offset(OpBuilder &builder, Location loc,
208                       LLVMTypeConverter &typeConverter, Value memRefDescPtr,
209                       Type elemPtrPtrType);
210   /// Builds IR inserting the offset into the descriptor.
211   static void setOffset(OpBuilder &builder, Location loc,
212                         LLVMTypeConverter &typeConverter, Value memRefDescPtr,
213                         Type elemPtrPtrType, Value offset);
214 
215   /// Builds IR extracting the pointer to the first element of the size array.
216   static Value sizeBasePtr(OpBuilder &builder, Location loc,
217                            LLVMTypeConverter &typeConverter,
218                            Value memRefDescPtr,
219                            LLVM::LLVMPointerType elemPtrPtrType);
220   /// Builds IR extracting the size[index] from the descriptor.
221   static Value size(OpBuilder &builder, Location loc,
222                     LLVMTypeConverter typeConverter, Value sizeBasePtr,
223                     Value index);
224   /// Builds IR inserting the size[index] into the descriptor.
225   static void setSize(OpBuilder &builder, Location loc,
226                       LLVMTypeConverter typeConverter, Value sizeBasePtr,
227                       Value index, Value size);
228 
229   /// Builds IR extracting the pointer to the first element of the stride array.
230   static Value strideBasePtr(OpBuilder &builder, Location loc,
231                              LLVMTypeConverter &typeConverter,
232                              Value sizeBasePtr, Value rank);
233   /// Builds IR extracting the stride[index] from the descriptor.
234   static Value stride(OpBuilder &builder, Location loc,
235                       LLVMTypeConverter typeConverter, Value strideBasePtr,
236                       Value index, Value stride);
237   /// Builds IR inserting the stride[index] into the descriptor.
238   static void setStride(OpBuilder &builder, Location loc,
239                         LLVMTypeConverter typeConverter, Value strideBasePtr,
240                         Value index, Value stride);
241 };
242 
243 } // namespace mlir
244 
#endif // MLIR_CONVERSION_LLVMCOMMON_MEMREFBUILDER_H
246