1 //===- Pattern.h - Pattern for conversion to the LLVM dialect ---*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #ifndef MLIR_CONVERSION_LLVMCOMMON_PATTERN_H 10 #define MLIR_CONVERSION_LLVMCOMMON_PATTERN_H 11 12 #include "mlir/Conversion/LLVMCommon/MemRefBuilder.h" 13 #include "mlir/Conversion/LLVMCommon/TypeConverter.h" 14 #include "mlir/Transforms/DialectConversion.h" 15 16 namespace mlir { 17 18 namespace LLVM { 19 namespace detail { 20 /// Replaces the given operation "op" with a new operation of type "targetOp" 21 /// and given operands. 22 LogicalResult oneToOneRewrite(Operation *op, StringRef targetOp, 23 ValueRange operands, 24 LLVMTypeConverter &typeConverter, 25 ConversionPatternRewriter &rewriter); 26 } // namespace detail 27 } // namespace LLVM 28 29 /// Base class for operation conversions targeting the LLVM IR dialect. It 30 /// provides the conversion patterns with access to the LLVMTypeConverter and 31 /// the LowerToLLVMOptions. The class captures the LLVMTypeConverter and the 32 /// LowerToLLVMOptions by reference meaning the references have to remain alive 33 /// during the entire pattern lifetime. 34 class ConvertToLLVMPattern : public ConversionPattern { 35 public: 36 ConvertToLLVMPattern(StringRef rootOpName, MLIRContext *context, 37 LLVMTypeConverter &typeConverter, 38 PatternBenefit benefit = 1); 39 40 protected: 41 /// Returns the LLVM dialect. 42 LLVM::LLVMDialect &getDialect() const; 43 44 LLVMTypeConverter *getTypeConverter() const; 45 46 /// Gets the MLIR type wrapping the LLVM integer type whose bit width is 47 /// defined by the used type converter. 48 Type getIndexType() const; 49 50 /// Gets the MLIR type wrapping the LLVM integer type whose bit width 51 /// corresponds to that of a LLVM pointer type. 52 Type getIntPtrType(unsigned addressSpace = 0) const; 53 54 /// Gets the MLIR type wrapping the LLVM void type. 55 Type getVoidType() const; 56 57 /// Get the MLIR type wrapping the LLVM i8* type. 58 Type getVoidPtrType() const; 59 60 /// Create a constant Op producing a value of `resultType` from an index-typed 61 /// integer attribute. 62 static Value createIndexAttrConstant(OpBuilder &builder, Location loc, 63 Type resultType, int64_t value); 64 65 /// Create an LLVM dialect operation defining the given index constant. 66 Value createIndexConstant(ConversionPatternRewriter &builder, Location loc, 67 uint64_t value) const; 68 69 // This is a strided getElementPtr variant that linearizes subscripts as: 70 // `base_offset + index_0 * stride_0 + ... + index_n * stride_n`. 71 Value getStridedElementPtr(Location loc, MemRefType type, Value memRefDesc, 72 ValueRange indices, 73 ConversionPatternRewriter &rewriter) const; 74 75 /// Returns if the given memref has identity maps and the element type is 76 /// convertible to LLVM. 77 bool isConvertibleAndHasIdentityMaps(MemRefType type) const; 78 79 /// Returns the type of a pointer to an element of the memref. 80 Type getElementPtrType(MemRefType type) const; 81 82 /// Computes sizes, strides and buffer size in bytes of `memRefType` with 83 /// identity layout. Emits constant ops for the static sizes of `memRefType`, 84 /// and uses `dynamicSizes` for the others. Emits instructions to compute 85 /// strides and buffer size from these sizes. 86 /// 87 /// For example, memref<4x?xf32> emits: 88 /// `sizes[0]` = llvm.mlir.constant(4 : index) : i64 89 /// `sizes[1]` = `dynamicSizes[0]` 90 /// `strides[1]` = llvm.mlir.constant(1 : index) : i64 91 /// `strides[0]` = `sizes[0]` 92 /// %size = llvm.mul `sizes[0]`, `sizes[1]` : i64 93 /// %nullptr = llvm.mlir.null : !llvm.ptr<f32> 94 /// %gep = llvm.getelementptr %nullptr[%size] 95 /// : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32> 96 /// `sizeBytes` = llvm.ptrtoint %gep : !llvm.ptr<f32> to i64 97 void getMemRefDescriptorSizes(Location loc, MemRefType memRefType, 98 ValueRange dynamicSizes, 99 ConversionPatternRewriter &rewriter, 100 SmallVectorImpl<Value> &sizes, 101 SmallVectorImpl<Value> &strides, 102 Value &sizeBytes) const; 103 104 /// Computes the size of type in bytes. 105 Value getSizeInBytes(Location loc, Type type, 106 ConversionPatternRewriter &rewriter) const; 107 108 /// Computes total number of elements for the given shape. 109 Value getNumElements(Location loc, ArrayRef<Value> shape, 110 ConversionPatternRewriter &rewriter) const; 111 112 /// Creates and populates a canonical memref descriptor struct. 113 MemRefDescriptor 114 createMemRefDescriptor(Location loc, MemRefType memRefType, 115 Value allocatedPtr, Value alignedPtr, 116 ArrayRef<Value> sizes, ArrayRef<Value> strides, 117 ConversionPatternRewriter &rewriter) const; 118 119 /// Copies the memory descriptor for any operands that were unranked 120 /// descriptors originally to heap-allocated memory (if toDynamic is true) or 121 /// to stack-allocated memory (otherwise). Also frees the previously used 122 /// memory (that is assumed to be heap-allocated) if toDynamic is false. 123 LogicalResult copyUnrankedDescriptors(OpBuilder &builder, Location loc, 124 TypeRange origTypes, 125 SmallVectorImpl<Value> &operands, 126 bool toDynamic) const; 127 }; 128 129 /// Utility class for operation conversions targeting the LLVM dialect that 130 /// match exactly one source operation. 131 template <typename SourceOp> 132 class ConvertOpToLLVMPattern : public ConvertToLLVMPattern { 133 public: 134 using OpAdaptor = typename SourceOp::Adaptor; 135 136 explicit ConvertOpToLLVMPattern(LLVMTypeConverter &typeConverter, 137 PatternBenefit benefit = 1) 138 : ConvertToLLVMPattern(SourceOp::getOperationName(), 139 &typeConverter.getContext(), typeConverter, 140 benefit) {} 141 142 /// Wrappers around the RewritePattern methods that pass the derived op type. rewrite(Operation * op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter)143 void rewrite(Operation *op, ArrayRef<Value> operands, 144 ConversionPatternRewriter &rewriter) const final { 145 rewrite(cast<SourceOp>(op), OpAdaptor(operands, op->getAttrDictionary()), 146 rewriter); 147 } match(Operation * op)148 LogicalResult match(Operation *op) const final { 149 return match(cast<SourceOp>(op)); 150 } 151 LogicalResult matchAndRewrite(Operation * op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter)152 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 153 ConversionPatternRewriter &rewriter) const final { 154 return matchAndRewrite(cast<SourceOp>(op), 155 OpAdaptor(operands, op->getAttrDictionary()), 156 rewriter); 157 } 158 159 /// Rewrite and Match methods that operate on the SourceOp type. These must be 160 /// overridden by the derived pattern class. 161 /// NOTICE: These methods are deprecated and will be removed. All new code 162 /// should use the adaptor methods below instead. rewrite(SourceOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter)163 virtual void rewrite(SourceOp op, ArrayRef<Value> operands, 164 ConversionPatternRewriter &rewriter) const { 165 llvm_unreachable("must override rewrite or matchAndRewrite"); 166 } 167 virtual LogicalResult matchAndRewrite(SourceOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter)168 matchAndRewrite(SourceOp op, ArrayRef<Value> operands, 169 ConversionPatternRewriter &rewriter) const { 170 if (succeeded(match(op))) { 171 rewrite(op, OpAdaptor(operands, op->getAttrDictionary()), rewriter); 172 return success(); 173 } 174 return failure(); 175 } 176 177 /// Rewrite and Match methods that operate on the SourceOp type. These must be 178 /// overridden by the derived pattern class. match(SourceOp op)179 virtual LogicalResult match(SourceOp op) const { 180 llvm_unreachable("must override match or matchAndRewrite"); 181 } rewrite(SourceOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter)182 virtual void rewrite(SourceOp op, OpAdaptor adaptor, 183 ConversionPatternRewriter &rewriter) const { 184 ValueRange operands = adaptor.getOperands(); 185 rewrite(op, 186 ArrayRef<Value>(operands.getBase().get<const Value *>(), 187 operands.size()), 188 rewriter); 189 } 190 virtual LogicalResult matchAndRewrite(SourceOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter)191 matchAndRewrite(SourceOp op, OpAdaptor adaptor, 192 ConversionPatternRewriter &rewriter) const { 193 ValueRange operands = adaptor.getOperands(); 194 return matchAndRewrite( 195 op, 196 ArrayRef<Value>(operands.getBase().get<const Value *>(), 197 operands.size()), 198 rewriter); 199 } 200 201 private: 202 using ConvertToLLVMPattern::match; 203 using ConvertToLLVMPattern::matchAndRewrite; 204 }; 205 206 /// Generic implementation of one-to-one conversion from "SourceOp" to 207 /// "TargetOp" where the latter belongs to the LLVM dialect or an equivalent. 208 /// Upholds a convention that multi-result operations get converted into an 209 /// operation returning the LLVM IR structure type, in which case individual 210 /// values must be extracted from using LLVM::ExtractValueOp before being used. 211 template <typename SourceOp, typename TargetOp> 212 class OneToOneConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> { 213 public: 214 using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern; 215 using Super = OneToOneConvertToLLVMPattern<SourceOp, TargetOp>; 216 217 /// Converts the type of the result to an LLVM type, pass operands as is, 218 /// preserve attributes. 219 LogicalResult matchAndRewrite(SourceOp op,typename SourceOp::Adaptor adaptor,ConversionPatternRewriter & rewriter)220 matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor, 221 ConversionPatternRewriter &rewriter) const override { 222 return LLVM::detail::oneToOneRewrite(op, TargetOp::getOperationName(), 223 adaptor.getOperands(), 224 *this->getTypeConverter(), rewriter); 225 } 226 }; 227 228 } // namespace mlir 229 230 #endif // MLIR_CONVERSION_LLVMCOMMON_PATTERN_H 231