1 //===- ConvertStandardToLLVM.h - Convert to the LLVM dialect ----*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Provides a dialect conversion targeting the LLVM IR dialect.  By default, it
10 // converts Standard ops and types and provides hooks for dialect-specific
11 // extensions to the conversion.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #ifndef MLIR_CONVERSION_STANDARDTOLLVM_CONVERTSTANDARDTOLLVM_H
16 #define MLIR_CONVERSION_STANDARDTOLLVM_CONVERTSTANDARDTOLLVM_H
17 
18 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
19 #include "mlir/Transforms/DialectConversion.h"
20 
// Forward declarations of LLVM core classes, so that this header does not have
// to include their defining headers.
// NOTE(review): none of these names is referenced with an `llvm::` qualifier
// in this header; they may be kept for the benefit of includers - confirm
// before removing.
namespace llvm {
class IntegerType;
class LLVMContext;
class Module;
class Type;
} // namespace llvm
27 
28 namespace mlir {
29 
// Forward declarations of the MLIR types that appear in the signatures below.
// LLVMTypeConverter is declared early because the function argument type
// converter callbacks take it by reference before the class is defined.
class BaseMemRefType;
class ComplexType;
class LLVMTypeConverter;
class UnrankedMemRefType;

// Forward declarations of LLVM dialect entities used by the type converter and
// the descriptor helper classes.
namespace LLVM {
class LLVMDialect;
class LLVMPointerType;
} // namespace LLVM
39 
/// Callback to convert function argument types. It converts a MemRef function
/// argument to a list of non-aggregate types containing descriptor
/// information, and an UnrankedMemRef function argument to a list containing
/// the rank and a pointer to a descriptor struct. The converted types are
/// delivered through `result`.
LogicalResult structFuncArgTypeConverter(LLVMTypeConverter &converter,
                                         Type type,
                                         SmallVectorImpl<Type> &result);
47 
/// Callback to convert function argument types. It converts MemRef function
/// arguments to bare pointers to the MemRef element type. The converted types
/// are delivered through `result`.
LogicalResult barePtrFuncArgTypeConverter(LLVMTypeConverter &converter,
                                          Type type,
                                          SmallVectorImpl<Type> &result);
53 
54 /// Conversion from types in the Standard dialect to the LLVM IR dialect.
55 class LLVMTypeConverter : public TypeConverter {
56   /// Give structFuncArgTypeConverter access to memref-specific functions.
57   friend LogicalResult
58   structFuncArgTypeConverter(LLVMTypeConverter &converter, Type type,
59                              SmallVectorImpl<Type> &result);
60 
61 public:
62   using TypeConverter::convertType;
63 
64   /// Create an LLVMTypeConverter using the default LowerToLLVMOptions.
65   LLVMTypeConverter(MLIRContext *ctx);
66 
67   /// Create an LLVMTypeConverter using custom LowerToLLVMOptions.
68   LLVMTypeConverter(MLIRContext *ctx, const LowerToLLVMOptions &options);
69 
70   /// Convert a function type.  The arguments and results are converted one by
71   /// one and results are packed into a wrapped LLVM IR structure type. `result`
72   /// is populated with argument mapping.
73   Type convertFunctionSignature(FunctionType funcTy, bool isVariadic,
74                                 SignatureConversion &result);
75 
76   /// Convert a non-empty list of types to be returned from a function into a
77   /// supported LLVM IR type.  In particular, if more than one value is
78   /// returned, create an LLVM IR structure type with elements that correspond
79   /// to each of the MLIR types converted with `convertType`.
80   Type packFunctionResults(ArrayRef<Type> types);
81 
82   /// Convert a type in the context of the default or bare pointer calling
83   /// convention. Calling convention sensitive types, such as MemRefType and
84   /// UnrankedMemRefType, are converted following the specific rules for the
85   /// calling convention. Calling convention independent types are converted
86   /// following the default LLVM type conversions.
87   Type convertCallingConventionType(Type type);
88 
89   /// Promote the bare pointers in 'values' that resulted from memrefs to
90   /// descriptors. 'stdTypes' holds the types of 'values' before the conversion
91   /// to the LLVM-IR dialect (i.e., MemRefType, or any other builtin type).
92   void promoteBarePtrsToDescriptors(ConversionPatternRewriter &rewriter,
93                                     Location loc, ArrayRef<Type> stdTypes,
94                                     SmallVectorImpl<Value> &values);
95 
96   /// Returns the MLIR context.
97   MLIRContext &getContext();
98 
99   /// Returns the LLVM dialect.
getDialect()100   LLVM::LLVMDialect *getDialect() { return llvmDialect; }
101 
getOptions()102   const LowerToLLVMOptions &getOptions() const { return options; }
103 
104   /// Promote the LLVM representation of all operands including promoting MemRef
105   /// descriptors to stack and use pointers to struct to avoid the complexity
106   /// of the platform-specific C/C++ ABI lowering related to struct argument
107   /// passing.
108   SmallVector<Value, 4> promoteOperands(Location loc, ValueRange opOperands,
109                                         ValueRange operands,
110                                         OpBuilder &builder);
111 
112   /// Promote the LLVM struct representation of one MemRef descriptor to stack
113   /// and use pointer to struct to avoid the complexity of the platform-specific
114   /// C/C++ ABI lowering related to struct argument passing.
115   Value promoteOneMemRefDescriptor(Location loc, Value operand,
116                                    OpBuilder &builder);
117 
118   /// Converts the function type to a C-compatible format, in particular using
119   /// pointers to memref descriptors for arguments.
120   Type convertFunctionTypeCWrapper(FunctionType type);
121 
122   /// Returns the data layout to use during and after conversion.
getDataLayout()123   const llvm::DataLayout &getDataLayout() { return options.dataLayout; }
124 
125   /// Gets the LLVM representation of the index type. The returned type is an
126   /// integer type with the size configured for this type converter.
127   Type getIndexType();
128 
129   /// Gets the bitwidth of the index type when converted to LLVM.
getIndexTypeBitwidth()130   unsigned getIndexTypeBitwidth() { return options.indexBitwidth; }
131 
132   /// Gets the pointer bitwidth.
133   unsigned getPointerBitwidth(unsigned addressSpace = 0);
134 
135 protected:
136   /// Pointer to the LLVM dialect.
137   LLVM::LLVMDialect *llvmDialect;
138 
139 private:
140   /// Convert a function type.  The arguments and results are converted one by
141   /// one.  Additionally, if the function returns more than one value, pack the
142   /// results into an LLVM IR structure type so that the converted function type
143   /// returns at most one result.
144   Type convertFunctionType(FunctionType type);
145 
146   /// Convert the index type.  Uses llvmModule data layout to create an integer
147   /// of the pointer bitwidth.
148   Type convertIndexType(IndexType type);
149 
150   /// Convert an integer type `i*` to `!llvm<"i*">`.
151   Type convertIntegerType(IntegerType type);
152 
153   /// Convert a floating point type: `f16` to `f16`, `f32` to
154   /// `f32` and `f64` to `f64`.  `bf16` is not supported
155   /// by LLVM.
156   Type convertFloatType(FloatType type);
157 
158   /// Convert complex number type: `complex<f16>` to `!llvm<"{ half, half }">`,
159   /// `complex<f32>` to `!llvm<"{ float, float }">`, and `complex<f64>` to
160   /// `!llvm<"{ double, double }">`. `complex<bf16>` is not supported.
161   Type convertComplexType(ComplexType type);
162 
163   /// Convert a memref type into an LLVM type that captures the relevant data.
164   Type convertMemRefType(MemRefType type);
165 
166   /// Convert a memref type into a list of LLVM IR types that will form the
167   /// memref descriptor. If `unpackAggregates` is true the `sizes` and `strides`
168   /// arrays in the descriptors are unpacked to individual index-typed elements,
169   /// else they are are kept as rank-sized arrays of index type. In particular,
170   /// the list will contain:
171   /// - two pointers to the memref element type, followed by
172   /// - an index-typed offset, followed by
173   /// - (if unpackAggregates = true)
174   ///    - one index-typed size per dimension of the memref, followed by
175   ///    - one index-typed stride per dimension of the memref.
176   /// - (if unpackArrregates = false)
177   ///   - one rank-sized array of index-type for the size of each dimension
178   ///   - one rank-sized array of index-type for the stride of each dimension
179   ///
180   /// For example, memref<?x?xf32> is converted to the following list:
181   /// - `!llvm<"float*">` (allocated pointer),
182   /// - `!llvm<"float*">` (aligned pointer),
183   /// - `i64` (offset),
184   /// - `i64`, `i64` (sizes),
185   /// - `i64`, `i64` (strides).
186   /// These types can be recomposed to a memref descriptor struct.
187   SmallVector<Type, 5> getMemRefDescriptorFields(MemRefType type,
188                                                  bool unpackAggregates);
189 
190   /// Convert an unranked memref type into a list of non-aggregate LLVM IR types
191   /// that will form the unranked memref descriptor. In particular, this list
192   /// contains:
193   /// - an integer rank, followed by
194   /// - a pointer to the memref descriptor struct.
195   /// For example, memref<*xf32> is converted to the following list:
196   /// i64 (rank)
197   /// !llvm<"i8*"> (type-erased pointer).
198   /// These types can be recomposed to a unranked memref descriptor struct.
199   SmallVector<Type, 2> getUnrankedMemRefDescriptorFields();
200 
201   // Convert an unranked memref type to an LLVM type that captures the
202   // runtime rank and a pointer to the static ranked memref desc
203   Type convertUnrankedMemRefType(UnrankedMemRefType type);
204 
205   /// Convert a memref type to a bare pointer to the memref element type.
206   Type convertMemRefToBarePtr(BaseMemRefType type);
207 
208   // Convert a 1D vector type into an LLVM vector type.
209   Type convertVectorType(VectorType type);
210 
211   /// Options for customizing the llvm lowering.
212   LowerToLLVMOptions options;
213 };
214 
215 /// Helper class to produce LLVM dialect operations extracting or inserting
216 /// values to a struct.
217 class StructBuilder {
218 public:
219   /// Construct a helper for the given value.
220   explicit StructBuilder(Value v);
221   /// Builds IR creating an `undef` value of the descriptor type.
222   static StructBuilder undef(OpBuilder &builder, Location loc,
223                              Type descriptorType);
224 
Value()225   /*implicit*/ operator Value() { return value; }
226 
227 protected:
228   // LLVM value
229   Value value;
230   // Cached struct type.
231   Type structType;
232 
233 protected:
234   /// Builds IR to extract a value from the struct at position pos
235   Value extractPtr(OpBuilder &builder, Location loc, unsigned pos);
236   /// Builds IR to set a value in the struct at position pos
237   void setPtr(OpBuilder &builder, Location loc, unsigned pos, Value ptr);
238 };
239 
/// StructBuilder specialization with named accessors for the real and
/// imaginary parts of a complex number struct.
class ComplexStructBuilder : public StructBuilder {
public:
  /// Construct a helper for the given complex number value.
  using StructBuilder::StructBuilder;
  /// Build IR creating an `undef` value of the complex number type.
  static ComplexStructBuilder undef(OpBuilder &builder, Location loc,
                                    Type type);

  /// Build IR extracting the real value from the complex number struct.
  Value real(OpBuilder &builder, Location loc);
  /// Build IR inserting the real value into the complex number struct.
  void setReal(OpBuilder &builder, Location loc, Value real);

  /// Build IR extracting the imaginary value from the complex number struct.
  Value imaginary(OpBuilder &builder, Location loc);
  /// Build IR inserting the imaginary value into the complex number struct.
  void setImaginary(OpBuilder &builder, Location loc, Value imaginary);
};
258 
/// Helper class to produce LLVM dialect operations extracting or inserting
/// elements of a MemRef descriptor. Wraps a Value pointing to the descriptor.
/// The Value may be null, in which case none of the operations are valid.
class MemRefDescriptor : public StructBuilder {
public:
  /// Construct a helper for the given descriptor value.
  explicit MemRefDescriptor(Value descriptor);
  /// Builds IR creating an `undef` value of the descriptor type.
  static MemRefDescriptor undef(OpBuilder &builder, Location loc,
                                Type descriptorType);
  /// Builds IR creating a MemRef descriptor that represents `type` and
  /// populates it with static shape and stride information extracted from the
  /// type.
  static MemRefDescriptor fromStaticShape(OpBuilder &builder, Location loc,
                                          LLVMTypeConverter &typeConverter,
                                          MemRefType type, Value memory);

  /// Builds IR extracting the allocated pointer from the descriptor.
  Value allocatedPtr(OpBuilder &builder, Location loc);
  /// Builds IR inserting the allocated pointer into the descriptor.
  void setAllocatedPtr(OpBuilder &builder, Location loc, Value ptr);

  /// Builds IR extracting the aligned pointer from the descriptor.
  Value alignedPtr(OpBuilder &builder, Location loc);

  /// Builds IR inserting the aligned pointer into the descriptor.
  void setAlignedPtr(OpBuilder &builder, Location loc, Value ptr);

  /// Builds IR extracting the offset from the descriptor.
  Value offset(OpBuilder &builder, Location loc);

  /// Builds IR inserting the offset into the descriptor.
  void setOffset(OpBuilder &builder, Location loc, Value offset);
  /// Variant of `setOffset` taking the offset as a compile-time integer.
  /// NOTE(review): presumably materializes the constant before inserting it -
  /// confirm in the implementation.
  void setConstantOffset(OpBuilder &builder, Location loc, uint64_t offset);

  /// Builds IR extracting the pos-th size from the descriptor.
  Value size(OpBuilder &builder, Location loc, unsigned pos);
  /// Overload taking the position as a Value together with the memref rank.
  /// NOTE(review): presumably for positions only known at runtime - confirm.
  Value size(OpBuilder &builder, Location loc, Value pos, int64_t rank);

  /// Builds IR inserting the pos-th size into the descriptor
  void setSize(OpBuilder &builder, Location loc, unsigned pos, Value size);
  /// Variant of `setSize` taking the size as a compile-time integer.
  void setConstantSize(OpBuilder &builder, Location loc, unsigned pos,
                       uint64_t size);

  /// Builds IR extracting the pos-th stride from the descriptor.
  Value stride(OpBuilder &builder, Location loc, unsigned pos);

  /// Builds IR inserting the pos-th stride into the descriptor
  void setStride(OpBuilder &builder, Location loc, unsigned pos, Value stride);
  /// Variant of `setStride` taking the stride as a compile-time integer.
  void setConstantStride(OpBuilder &builder, Location loc, unsigned pos,
                         uint64_t stride);

  /// Returns the (LLVM) pointer type this descriptor contains.
  LLVM::LLVMPointerType getElementPtrType();

  /// Builds IR populating a MemRef descriptor structure from a list of
  /// individual values composing that descriptor, in the following order:
  /// - allocated pointer;
  /// - aligned pointer;
  /// - offset;
  /// - <rank> sizes;
  /// - <rank> strides;
  /// where <rank> is the MemRef rank as provided in `type`.
  static Value pack(OpBuilder &builder, Location loc,
                    LLVMTypeConverter &converter, MemRefType type,
                    ValueRange values);

  /// Builds IR extracting individual elements of a MemRef descriptor structure
  /// and returning them as `results` list.
  static void unpack(OpBuilder &builder, Location loc, Value packed,
                     MemRefType type, SmallVectorImpl<Value> &results);

  /// Returns the number of non-aggregate values that would be produced by
  /// `unpack`.
  static unsigned getNumUnpackedValues(MemRefType type);

private:
  /// Cached index type.
  Type indexType;
};
339 
/// Helper class allowing the user to access a range of Values that correspond
/// to an unpacked memref descriptor using named accessors. This does not own
/// the values.
class MemRefDescriptorView {
public:
  /// Constructs the view from a range of values. Infers the rank from the size
  /// of the range.
  explicit MemRefDescriptorView(ValueRange range);

  /// Returns the allocated pointer Value.
  Value allocatedPtr();

  /// Returns the aligned pointer Value.
  Value alignedPtr();

  /// Returns the offset Value.
  Value offset();

  /// Returns the pos-th size Value.
  Value size(unsigned pos);

  /// Returns the pos-th stride Value.
  Value stride(unsigned pos);

private:
  /// Rank of the memref the descriptor is pointing to.
  int rank;
  /// Underlying range of Values. Stored as a non-owning ValueRange, so the
  /// viewed values must outlive this view.
  ValueRange elements;
};
370 
/// Helper class to produce LLVM dialect operations extracting or inserting
/// elements of an unranked MemRef descriptor, i.e. a struct holding the rank
/// and a pointer to a ranked memref descriptor. The static helpers below
/// operate on the pointed-to ranked descriptor through `memRefDescPtr`.
class UnrankedMemRefDescriptor : public StructBuilder {
public:
  /// Construct a helper for the given descriptor value.
  explicit UnrankedMemRefDescriptor(Value descriptor);
  /// Builds IR creating an `undef` value of the descriptor type.
  static UnrankedMemRefDescriptor undef(OpBuilder &builder, Location loc,
                                        Type descriptorType);

  /// Builds IR extracting the rank from the descriptor
  Value rank(OpBuilder &builder, Location loc);
  /// Builds IR setting the rank in the descriptor
  void setRank(OpBuilder &builder, Location loc, Value value);
  /// Builds IR extracting ranked memref descriptor ptr
  Value memRefDescPtr(OpBuilder &builder, Location loc);
  /// Builds IR setting ranked memref descriptor ptr
  void setMemRefDescPtr(OpBuilder &builder, Location loc, Value value);

  /// Builds IR populating an unranked MemRef descriptor structure from a list
  /// of individual constituent values in the following order:
  /// - rank of the memref;
  /// - pointer to the memref descriptor.
  static Value pack(OpBuilder &builder, Location loc,
                    LLVMTypeConverter &converter, UnrankedMemRefType type,
                    ValueRange values);

  /// Builds IR extracting individual elements that compose an unranked memref
  /// descriptor and returns them as `results` list.
  static void unpack(OpBuilder &builder, Location loc, Value packed,
                     SmallVectorImpl<Value> &results);

  /// Returns the number of non-aggregate values that would be produced by
  /// `unpack`: the rank and the descriptor pointer.
  static unsigned getNumUnpackedValues() { return 2; }

  /// Builds IR computing the sizes in bytes (suitable for opaque allocation)
  /// and appends the corresponding values into `sizes`.
  static void computeSizes(OpBuilder &builder, Location loc,
                           LLVMTypeConverter &typeConverter,
                           ArrayRef<UnrankedMemRefDescriptor> values,
                           SmallVectorImpl<Value> &sizes);

  /// TODO: The following accessors don't take alignment rules between elements
  /// of the descriptor struct into account. For some architectures, it might be
  /// necessary to extend them and to use `llvm::DataLayout` contained in
  /// `LLVMTypeConverter`.

  /// Builds IR extracting the allocated pointer from the descriptor.
  static Value allocatedPtr(OpBuilder &builder, Location loc,
                            Value memRefDescPtr, Type elemPtrPtrType);
  /// Builds IR inserting the allocated pointer into the descriptor.
  static void setAllocatedPtr(OpBuilder &builder, Location loc,
                              Value memRefDescPtr, Type elemPtrPtrType,
                              Value allocatedPtr);

  /// Builds IR extracting the aligned pointer from the descriptor.
  static Value alignedPtr(OpBuilder &builder, Location loc,
                          LLVMTypeConverter &typeConverter, Value memRefDescPtr,
                          Type elemPtrPtrType);
  /// Builds IR inserting the aligned pointer into the descriptor.
  static void setAlignedPtr(OpBuilder &builder, Location loc,
                            LLVMTypeConverter &typeConverter,
                            Value memRefDescPtr, Type elemPtrPtrType,
                            Value alignedPtr);

  /// Builds IR extracting the offset from the descriptor.
  static Value offset(OpBuilder &builder, Location loc,
                      LLVMTypeConverter &typeConverter, Value memRefDescPtr,
                      Type elemPtrPtrType);
  /// Builds IR inserting the offset into the descriptor.
  static void setOffset(OpBuilder &builder, Location loc,
                        LLVMTypeConverter &typeConverter, Value memRefDescPtr,
                        Type elemPtrPtrType, Value offset);

  /// Builds IR extracting the pointer to the first element of the size array.
  static Value sizeBasePtr(OpBuilder &builder, Location loc,
                           LLVMTypeConverter &typeConverter,
                           Value memRefDescPtr,
                           LLVM::LLVMPointerType elemPtrPtrType);
  /// Builds IR extracting the size[index] from the descriptor.
  /// NOTE(review): `typeConverter` is taken by value here, unlike the
  /// by-reference parameters of the sibling accessors, so every call copies
  /// the converter. Changing it requires updating the out-of-line definition.
  static Value size(OpBuilder &builder, Location loc,
                    LLVMTypeConverter typeConverter, Value sizeBasePtr,
                    Value index);
  /// Builds IR inserting the size[index] into the descriptor.
  /// NOTE(review): `typeConverter` is taken by value (see `size`).
  static void setSize(OpBuilder &builder, Location loc,
                      LLVMTypeConverter typeConverter, Value sizeBasePtr,
                      Value index, Value size);

  /// Builds IR extracting the pointer to the first element of the stride array.
  static Value strideBasePtr(OpBuilder &builder, Location loc,
                             LLVMTypeConverter &typeConverter,
                             Value sizeBasePtr, Value rank);
  /// Builds IR extracting the stride[index] from the descriptor.
  /// NOTE(review): `typeConverter` is taken by value (see `size`), and the
  /// trailing `Value stride` parameter has no counterpart in the `size`
  /// getter - it looks like a leftover from `setStride`; confirm against the
  /// implementation.
  static Value stride(OpBuilder &builder, Location loc,
                      LLVMTypeConverter typeConverter, Value strideBasePtr,
                      Value index, Value stride);
  /// Builds IR inserting the stride[index] into the descriptor.
  /// NOTE(review): `typeConverter` is taken by value (see `size`).
  static void setStride(OpBuilder &builder, Location loc,
                        LLVMTypeConverter typeConverter, Value strideBasePtr,
                        Value index, Value stride);
};
471 
/// Base class for operation conversions targeting the LLVM IR dialect. It
/// provides the conversion patterns with access to the LLVMTypeConverter and
/// the LowerToLLVMOptions. The class captures the LLVMTypeConverter and the
/// LowerToLLVMOptions by reference meaning the references have to remain alive
/// during the entire pattern lifetime.
class ConvertToLLVMPattern : public ConversionPattern {
public:
  /// Creates a pattern rooted at the operation named `rootOpName`, using
  /// `typeConverter` for type conversion and the given pattern benefit.
  ConvertToLLVMPattern(StringRef rootOpName, MLIRContext *context,
                       LLVMTypeConverter &typeConverter,
                       PatternBenefit benefit = 1);

protected:
  /// Returns the LLVM dialect.
  LLVM::LLVMDialect &getDialect() const;

  /// Returns the type converter as a pointer to LLVMTypeConverter.
  LLVMTypeConverter *getTypeConverter() const;

  /// Gets the MLIR type wrapping the LLVM integer type whose bit width is
  /// defined by the used type converter.
  Type getIndexType() const;

  /// Gets the MLIR type wrapping the LLVM integer type whose bit width
  /// corresponds to that of a LLVM pointer type.
  Type getIntPtrType(unsigned addressSpace = 0) const;

  /// Gets the MLIR type wrapping the LLVM void type.
  Type getVoidType() const;

  /// Get the MLIR type wrapping the LLVM i8* type.
  Type getVoidPtrType() const;

  /// Create an LLVM dialect operation defining the given index constant.
  Value createIndexConstant(ConversionPatternRewriter &builder, Location loc,
                            uint64_t value) const;

  /// This is a strided getElementPtr variant that linearizes subscripts as:
  ///   `base_offset + index_0 * stride_0 + ... + index_n * stride_n`.
  Value getStridedElementPtr(Location loc, MemRefType type, Value memRefDesc,
                             ValueRange indices,
                             ConversionPatternRewriter &rewriter) const;

  /// Returns if the given memref has identity maps and the element type is
  /// convertible to LLVM.
  bool isConvertibleAndHasIdentityMaps(MemRefType type) const;

  /// Returns the type of a pointer to an element of the memref.
  Type getElementPtrType(MemRefType type) const;

  /// Computes sizes, strides and buffer size in bytes of `memRefType` with
  /// identity layout. Emits constant ops for the static sizes of `memRefType`,
  /// and uses `dynamicSizes` for the others. Emits instructions to compute
  /// strides and buffer size from these sizes.
  ///
  /// For example, memref<4x?xf32> emits:
  /// `sizes[0]`   = llvm.mlir.constant(4 : index) : i64
  /// `sizes[1]`   = `dynamicSizes[0]`
  /// `strides[1]` = llvm.mlir.constant(1 : index) : i64
  /// `strides[0]` = `sizes[0]`
  /// %size        = llvm.mul `sizes[0]`, `sizes[1]` : i64
  /// %nullptr     = llvm.mlir.null : !llvm.ptr<f32>
  /// %gep         = llvm.getelementptr %nullptr[%size]
  ///                  : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
  /// `sizeBytes`  = llvm.ptrtoint %gep : !llvm.ptr<f32> to i64
  void getMemRefDescriptorSizes(Location loc, MemRefType memRefType,
                                ValueRange dynamicSizes,
                                ConversionPatternRewriter &rewriter,
                                SmallVectorImpl<Value> &sizes,
                                SmallVectorImpl<Value> &strides,
                                Value &sizeBytes) const;

  /// Computes the size of type in bytes.
  Value getSizeInBytes(Location loc, Type type,
                       ConversionPatternRewriter &rewriter) const;

  /// Computes total number of elements for the given shape.
  Value getNumElements(Location loc, ArrayRef<Value> shape,
                       ConversionPatternRewriter &rewriter) const;

  /// Creates and populates a canonical memref descriptor struct from the
  /// given allocated and aligned pointers, sizes and strides.
  MemRefDescriptor
  createMemRefDescriptor(Location loc, MemRefType memRefType,
                         Value allocatedPtr, Value alignedPtr,
                         ArrayRef<Value> sizes, ArrayRef<Value> strides,
                         ConversionPatternRewriter &rewriter) const;
};
557 
558 /// Utility class for operation conversions targeting the LLVM dialect that
559 /// match exactly one source operation.
560 template <typename SourceOp>
561 class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
562 public:
563   explicit ConvertOpToLLVMPattern(LLVMTypeConverter &typeConverter,
564                                   PatternBenefit benefit = 1)
565       : ConvertToLLVMPattern(SourceOp::getOperationName(),
566                              &typeConverter.getContext(), typeConverter,
567                              benefit) {}
568 
569   /// Wrappers around the RewritePattern methods that pass the derived op type.
rewrite(Operation * op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter)570   void rewrite(Operation *op, ArrayRef<Value> operands,
571                ConversionPatternRewriter &rewriter) const final {
572     rewrite(cast<SourceOp>(op), operands, rewriter);
573   }
match(Operation * op)574   LogicalResult match(Operation *op) const final {
575     return match(cast<SourceOp>(op));
576   }
577   LogicalResult
matchAndRewrite(Operation * op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter)578   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
579                   ConversionPatternRewriter &rewriter) const final {
580     return matchAndRewrite(cast<SourceOp>(op), operands, rewriter);
581   }
582 
583   /// Rewrite and Match methods that operate on the SourceOp type. These must be
584   /// overridden by the derived pattern class.
rewrite(SourceOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter)585   virtual void rewrite(SourceOp op, ArrayRef<Value> operands,
586                        ConversionPatternRewriter &rewriter) const {
587     llvm_unreachable("must override rewrite or matchAndRewrite");
588   }
match(SourceOp op)589   virtual LogicalResult match(SourceOp op) const {
590     llvm_unreachable("must override match or matchAndRewrite");
591   }
592   virtual LogicalResult
matchAndRewrite(SourceOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter)593   matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
594                   ConversionPatternRewriter &rewriter) const {
595     if (succeeded(match(op))) {
596       rewrite(op, operands, rewriter);
597       return success();
598     }
599     return failure();
600   }
601 
602 private:
603   using ConvertToLLVMPattern::match;
604   using ConvertToLLVMPattern::matchAndRewrite;
605 };
606 
namespace LLVM {
namespace detail {
/// Replaces the given operation "op" with a new operation of type "targetOp"
/// and given operands.
LogicalResult oneToOneRewrite(Operation *op, StringRef targetOp,
                              ValueRange operands,
                              LLVMTypeConverter &typeConverter,
                              ConversionPatternRewriter &rewriter);

/// Counterpart of `oneToOneRewrite` with the same parameter list, used by
/// VectorConvertToLLVMPattern.
/// NOTE(review): presumably the variant that also handles higher-dimensional
/// vector types - confirm against the implementation.
LogicalResult vectorOneToOneRewrite(Operation *op, StringRef targetOp,
                                    ValueRange operands,
                                    LLVMTypeConverter &typeConverter,
                                    ConversionPatternRewriter &rewriter);
} // namespace detail
} // namespace LLVM
622 
623 /// Generic implementation of one-to-one conversion from "SourceOp" to
624 /// "TargetOp" where the latter belongs to the LLVM dialect or an equivalent.
625 /// Upholds a convention that multi-result operations get converted into an
626 /// operation returning the LLVM IR structure type, in which case individual
627 /// values must be extracted from using LLVM::ExtractValueOp before being used.
628 template <typename SourceOp, typename TargetOp>
629 class OneToOneConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
630 public:
631   using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
632   using Super = OneToOneConvertToLLVMPattern<SourceOp, TargetOp>;
633 
634   /// Converts the type of the result to an LLVM type, pass operands as is,
635   /// preserve attributes.
636   LogicalResult
matchAndRewrite(SourceOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter)637   matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
638                   ConversionPatternRewriter &rewriter) const override {
639     return LLVM::detail::oneToOneRewrite(op, TargetOp::getOperationName(),
640                                          operands, *this->getTypeConverter(),
641                                          rewriter);
642   }
643 };
644 
645 /// Basic lowering implementation to rewrite Ops with just one result to the
646 /// LLVM Dialect. This supports higher-dimensional vector types.
647 template <typename SourceOp, typename TargetOp>
648 class VectorConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
649 public:
650   using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
651   using Super = VectorConvertToLLVMPattern<SourceOp, TargetOp>;
652 
653   LogicalResult
matchAndRewrite(SourceOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter)654   matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
655                   ConversionPatternRewriter &rewriter) const override {
656     static_assert(
657         std::is_base_of<OpTrait::OneResult<SourceOp>, SourceOp>::value,
658         "expected single result op");
659     static_assert(std::is_base_of<OpTrait::SameOperandsAndResultType<SourceOp>,
660                                   SourceOp>::value,
661                   "expected same operands and result type");
662     return LLVM::detail::vectorOneToOneRewrite(
663         op, TargetOp::getOperationName(), operands, *this->getTypeConverter(),
664         rewriter);
665   }
666 };
667 
/// Derived class that automatically populates legalization information for
/// different LLVM ops.
class LLVMConversionTarget : public ConversionTarget {
public:
  /// Constructs the conversion target for the given context.
  explicit LLVMConversionTarget(MLIRContext &ctx);
};
674 
675 } // namespace mlir
676 
677 #endif // MLIR_CONVERSION_STANDARDTOLLVM_CONVERTSTANDARDTOLLVM_H
678