1 //===- Pattern.h - Pattern for conversion to the LLVM dialect ---*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_CONVERSION_LLVMCOMMON_PATTERN_H
10 #define MLIR_CONVERSION_LLVMCOMMON_PATTERN_H
11 
12 #include "mlir/Conversion/LLVMCommon/MemRefBuilder.h"
13 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
14 #include "mlir/Transforms/DialectConversion.h"
15 
16 namespace mlir {
17 
18 namespace LLVM {
19 namespace detail {
20 /// Replaces the given operation "op" with a new operation of type "targetOp"
21 /// and given operands.
22 LogicalResult oneToOneRewrite(Operation *op, StringRef targetOp,
23                               ValueRange operands,
24                               LLVMTypeConverter &typeConverter,
25                               ConversionPatternRewriter &rewriter);
26 } // namespace detail
27 } // namespace LLVM
28 
29 /// Base class for operation conversions targeting the LLVM IR dialect. It
30 /// provides the conversion patterns with access to the LLVMTypeConverter and
31 /// the LowerToLLVMOptions. The class captures the LLVMTypeConverter and the
32 /// LowerToLLVMOptions by reference meaning the references have to remain alive
33 /// during the entire pattern lifetime.
34 class ConvertToLLVMPattern : public ConversionPattern {
35 public:
36   ConvertToLLVMPattern(StringRef rootOpName, MLIRContext *context,
37                        LLVMTypeConverter &typeConverter,
38                        PatternBenefit benefit = 1);
39 
40 protected:
41   /// Returns the LLVM dialect.
42   LLVM::LLVMDialect &getDialect() const;
43 
44   LLVMTypeConverter *getTypeConverter() const;
45 
46   /// Gets the MLIR type wrapping the LLVM integer type whose bit width is
47   /// defined by the used type converter.
48   Type getIndexType() const;
49 
50   /// Gets the MLIR type wrapping the LLVM integer type whose bit width
51   /// corresponds to that of a LLVM pointer type.
52   Type getIntPtrType(unsigned addressSpace = 0) const;
53 
54   /// Gets the MLIR type wrapping the LLVM void type.
55   Type getVoidType() const;
56 
57   /// Get the MLIR type wrapping the LLVM i8* type.
58   Type getVoidPtrType() const;
59 
60   /// Create a constant Op producing a value of `resultType` from an index-typed
61   /// integer attribute.
62   static Value createIndexAttrConstant(OpBuilder &builder, Location loc,
63                                        Type resultType, int64_t value);
64 
65   /// Create an LLVM dialect operation defining the given index constant.
66   Value createIndexConstant(ConversionPatternRewriter &builder, Location loc,
67                             uint64_t value) const;
68 
69   // This is a strided getElementPtr variant that linearizes subscripts as:
70   //   `base_offset + index_0 * stride_0 + ... + index_n * stride_n`.
71   Value getStridedElementPtr(Location loc, MemRefType type, Value memRefDesc,
72                              ValueRange indices,
73                              ConversionPatternRewriter &rewriter) const;
74 
75   /// Returns if the given memref has identity maps and the element type is
76   /// convertible to LLVM.
77   bool isConvertibleAndHasIdentityMaps(MemRefType type) const;
78 
79   /// Returns the type of a pointer to an element of the memref.
80   Type getElementPtrType(MemRefType type) const;
81 
82   /// Computes sizes, strides and buffer size in bytes of `memRefType` with
83   /// identity layout. Emits constant ops for the static sizes of `memRefType`,
84   /// and uses `dynamicSizes` for the others. Emits instructions to compute
85   /// strides and buffer size from these sizes.
86   ///
87   /// For example, memref<4x?xf32> emits:
88   /// `sizes[0]`   = llvm.mlir.constant(4 : index) : i64
89   /// `sizes[1]`   = `dynamicSizes[0]`
90   /// `strides[1]` = llvm.mlir.constant(1 : index) : i64
91   /// `strides[0]` = `sizes[0]`
92   /// %size        = llvm.mul `sizes[0]`, `sizes[1]` : i64
93   /// %nullptr     = llvm.mlir.null : !llvm.ptr<f32>
94   /// %gep         = llvm.getelementptr %nullptr[%size]
95   ///                  : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
96   /// `sizeBytes`  = llvm.ptrtoint %gep : !llvm.ptr<f32> to i64
97   void getMemRefDescriptorSizes(Location loc, MemRefType memRefType,
98                                 ValueRange dynamicSizes,
99                                 ConversionPatternRewriter &rewriter,
100                                 SmallVectorImpl<Value> &sizes,
101                                 SmallVectorImpl<Value> &strides,
102                                 Value &sizeBytes) const;
103 
104   /// Computes the size of type in bytes.
105   Value getSizeInBytes(Location loc, Type type,
106                        ConversionPatternRewriter &rewriter) const;
107 
108   /// Computes total number of elements for the given shape.
109   Value getNumElements(Location loc, ArrayRef<Value> shape,
110                        ConversionPatternRewriter &rewriter) const;
111 
112   /// Creates and populates a canonical memref descriptor struct.
113   MemRefDescriptor
114   createMemRefDescriptor(Location loc, MemRefType memRefType,
115                          Value allocatedPtr, Value alignedPtr,
116                          ArrayRef<Value> sizes, ArrayRef<Value> strides,
117                          ConversionPatternRewriter &rewriter) const;
118 
119   /// Copies the memory descriptor for any operands that were unranked
120   /// descriptors originally to heap-allocated memory (if toDynamic is true) or
121   /// to stack-allocated memory (otherwise). Also frees the previously used
122   /// memory (that is assumed to be heap-allocated) if toDynamic is false.
123   LogicalResult copyUnrankedDescriptors(OpBuilder &builder, Location loc,
124                                         TypeRange origTypes,
125                                         SmallVectorImpl<Value> &operands,
126                                         bool toDynamic) const;
127 };
128 
129 /// Utility class for operation conversions targeting the LLVM dialect that
130 /// match exactly one source operation.
131 template <typename SourceOp>
132 class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
133 public:
134   using OpAdaptor = typename SourceOp::Adaptor;
135 
136   explicit ConvertOpToLLVMPattern(LLVMTypeConverter &typeConverter,
137                                   PatternBenefit benefit = 1)
138       : ConvertToLLVMPattern(SourceOp::getOperationName(),
139                              &typeConverter.getContext(), typeConverter,
140                              benefit) {}
141 
142   /// Wrappers around the RewritePattern methods that pass the derived op type.
rewrite(Operation * op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter)143   void rewrite(Operation *op, ArrayRef<Value> operands,
144                ConversionPatternRewriter &rewriter) const final {
145     rewrite(cast<SourceOp>(op), OpAdaptor(operands, op->getAttrDictionary()),
146             rewriter);
147   }
match(Operation * op)148   LogicalResult match(Operation *op) const final {
149     return match(cast<SourceOp>(op));
150   }
151   LogicalResult
matchAndRewrite(Operation * op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter)152   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
153                   ConversionPatternRewriter &rewriter) const final {
154     return matchAndRewrite(cast<SourceOp>(op),
155                            OpAdaptor(operands, op->getAttrDictionary()),
156                            rewriter);
157   }
158 
159   /// Rewrite and Match methods that operate on the SourceOp type. These must be
160   /// overridden by the derived pattern class.
161   /// NOTICE: These methods are deprecated and will be removed. All new code
162   /// should use the adaptor methods below instead.
rewrite(SourceOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter)163   virtual void rewrite(SourceOp op, ArrayRef<Value> operands,
164                        ConversionPatternRewriter &rewriter) const {
165     llvm_unreachable("must override rewrite or matchAndRewrite");
166   }
167   virtual LogicalResult
matchAndRewrite(SourceOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter)168   matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
169                   ConversionPatternRewriter &rewriter) const {
170     if (succeeded(match(op))) {
171       rewrite(op, OpAdaptor(operands, op->getAttrDictionary()), rewriter);
172       return success();
173     }
174     return failure();
175   }
176 
177   /// Rewrite and Match methods that operate on the SourceOp type. These must be
178   /// overridden by the derived pattern class.
match(SourceOp op)179   virtual LogicalResult match(SourceOp op) const {
180     llvm_unreachable("must override match or matchAndRewrite");
181   }
rewrite(SourceOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter)182   virtual void rewrite(SourceOp op, OpAdaptor adaptor,
183                        ConversionPatternRewriter &rewriter) const {
184     ValueRange operands = adaptor.getOperands();
185     rewrite(op,
186             ArrayRef<Value>(operands.getBase().get<const Value *>(),
187                             operands.size()),
188             rewriter);
189   }
190   virtual LogicalResult
matchAndRewrite(SourceOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter)191   matchAndRewrite(SourceOp op, OpAdaptor adaptor,
192                   ConversionPatternRewriter &rewriter) const {
193     ValueRange operands = adaptor.getOperands();
194     return matchAndRewrite(
195         op,
196         ArrayRef<Value>(operands.getBase().get<const Value *>(),
197                         operands.size()),
198         rewriter);
199   }
200 
201 private:
202   using ConvertToLLVMPattern::match;
203   using ConvertToLLVMPattern::matchAndRewrite;
204 };
205 
206 /// Generic implementation of one-to-one conversion from "SourceOp" to
207 /// "TargetOp" where the latter belongs to the LLVM dialect or an equivalent.
208 /// Upholds a convention that multi-result operations get converted into an
209 /// operation returning the LLVM IR structure type, in which case individual
210 /// values must be extracted from using LLVM::ExtractValueOp before being used.
211 template <typename SourceOp, typename TargetOp>
212 class OneToOneConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
213 public:
214   using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
215   using Super = OneToOneConvertToLLVMPattern<SourceOp, TargetOp>;
216 
217   /// Converts the type of the result to an LLVM type, pass operands as is,
218   /// preserve attributes.
219   LogicalResult
matchAndRewrite(SourceOp op,typename SourceOp::Adaptor adaptor,ConversionPatternRewriter & rewriter)220   matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
221                   ConversionPatternRewriter &rewriter) const override {
222     return LLVM::detail::oneToOneRewrite(op, TargetOp::getOperationName(),
223                                          adaptor.getOperands(),
224                                          *this->getTypeConverter(), rewriter);
225   }
226 };
227 
228 } // namespace mlir
229 
230 #endif // MLIR_CONVERSION_LLVMCOMMON_PATTERN_H
231