1 //===- ConvertStandardToLLVM.h - Convert to the LLVM dialect ----*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Provides a dialect conversion targeting the LLVM IR dialect. By default, it 10 // converts Standard ops and types and provides hooks for dialect-specific 11 // extensions to the conversion. 12 // 13 //===----------------------------------------------------------------------===// 14 15 #ifndef MLIR_CONVERSION_STANDARDTOLLVM_CONVERTSTANDARDTOLLVM_H 16 #define MLIR_CONVERSION_STANDARDTOLLVM_CONVERTSTANDARDTOLLVM_H 17 18 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" 19 #include "mlir/Transforms/DialectConversion.h" 20 21 namespace llvm { 22 class IntegerType; 23 class LLVMContext; 24 class Module; 25 class Type; 26 } // namespace llvm 27 28 namespace mlir { 29 30 class BaseMemRefType; 31 class ComplexType; 32 class LLVMTypeConverter; 33 class UnrankedMemRefType; 34 35 namespace LLVM { 36 class LLVMDialect; 37 class LLVMPointerType; 38 } // namespace LLVM 39 40 /// Callback to convert function argument types. It converts a MemRef function 41 /// argument to a list of non-aggregate types containing descriptor 42 /// information, and an UnrankedmemRef function argument to a list containing 43 /// the rank and a pointer to a descriptor struct. 44 LogicalResult structFuncArgTypeConverter(LLVMTypeConverter &converter, 45 Type type, 46 SmallVectorImpl<Type> &result); 47 48 /// Callback to convert function argument types. It converts MemRef function 49 /// arguments to bare pointers to the MemRef element type. 50 LogicalResult barePtrFuncArgTypeConverter(LLVMTypeConverter &converter, 51 Type type, 52 SmallVectorImpl<Type> &result); 53 54 /// Conversion from types in the Standard dialect to the LLVM IR dialect. 55 class LLVMTypeConverter : public TypeConverter { 56 /// Give structFuncArgTypeConverter access to memref-specific functions. 57 friend LogicalResult 58 structFuncArgTypeConverter(LLVMTypeConverter &converter, Type type, 59 SmallVectorImpl<Type> &result); 60 61 public: 62 using TypeConverter::convertType; 63 64 /// Create an LLVMTypeConverter using the default LowerToLLVMOptions. 65 LLVMTypeConverter(MLIRContext *ctx); 66 67 /// Create an LLVMTypeConverter using custom LowerToLLVMOptions. 68 LLVMTypeConverter(MLIRContext *ctx, const LowerToLLVMOptions &options); 69 70 /// Convert a function type. The arguments and results are converted one by 71 /// one and results are packed into a wrapped LLVM IR structure type. `result` 72 /// is populated with argument mapping. 73 Type convertFunctionSignature(FunctionType funcTy, bool isVariadic, 74 SignatureConversion &result); 75 76 /// Convert a non-empty list of types to be returned from a function into a 77 /// supported LLVM IR type. In particular, if more than one value is 78 /// returned, create an LLVM IR structure type with elements that correspond 79 /// to each of the MLIR types converted with `convertType`. 80 Type packFunctionResults(ArrayRef<Type> types); 81 82 /// Convert a type in the context of the default or bare pointer calling 83 /// convention. Calling convention sensitive types, such as MemRefType and 84 /// UnrankedMemRefType, are converted following the specific rules for the 85 /// calling convention. Calling convention independent types are converted 86 /// following the default LLVM type conversions. 87 Type convertCallingConventionType(Type type); 88 89 /// Promote the bare pointers in 'values' that resulted from memrefs to 90 /// descriptors. 'stdTypes' holds the types of 'values' before the conversion 91 /// to the LLVM-IR dialect (i.e., MemRefType, or any other builtin type). 92 void promoteBarePtrsToDescriptors(ConversionPatternRewriter &rewriter, 93 Location loc, ArrayRef<Type> stdTypes, 94 SmallVectorImpl<Value> &values); 95 96 /// Returns the MLIR context. 97 MLIRContext &getContext(); 98 99 /// Returns the LLVM dialect. getDialect()100 LLVM::LLVMDialect *getDialect() { return llvmDialect; } 101 getOptions()102 const LowerToLLVMOptions &getOptions() const { return options; } 103 104 /// Promote the LLVM representation of all operands including promoting MemRef 105 /// descriptors to stack and use pointers to struct to avoid the complexity 106 /// of the platform-specific C/C++ ABI lowering related to struct argument 107 /// passing. 108 SmallVector<Value, 4> promoteOperands(Location loc, ValueRange opOperands, 109 ValueRange operands, 110 OpBuilder &builder); 111 112 /// Promote the LLVM struct representation of one MemRef descriptor to stack 113 /// and use pointer to struct to avoid the complexity of the platform-specific 114 /// C/C++ ABI lowering related to struct argument passing. 115 Value promoteOneMemRefDescriptor(Location loc, Value operand, 116 OpBuilder &builder); 117 118 /// Converts the function type to a C-compatible format, in particular using 119 /// pointers to memref descriptors for arguments. 120 Type convertFunctionTypeCWrapper(FunctionType type); 121 122 /// Returns the data layout to use during and after conversion. getDataLayout()123 const llvm::DataLayout &getDataLayout() { return options.dataLayout; } 124 125 /// Gets the LLVM representation of the index type. The returned type is an 126 /// integer type with the size configured for this type converter. 127 Type getIndexType(); 128 129 /// Gets the bitwidth of the index type when converted to LLVM. getIndexTypeBitwidth()130 unsigned getIndexTypeBitwidth() { return options.indexBitwidth; } 131 132 /// Gets the pointer bitwidth. 133 unsigned getPointerBitwidth(unsigned addressSpace = 0); 134 135 protected: 136 /// Pointer to the LLVM dialect. 137 LLVM::LLVMDialect *llvmDialect; 138 139 private: 140 /// Convert a function type. The arguments and results are converted one by 141 /// one. Additionally, if the function returns more than one value, pack the 142 /// results into an LLVM IR structure type so that the converted function type 143 /// returns at most one result. 144 Type convertFunctionType(FunctionType type); 145 146 /// Convert the index type. Uses llvmModule data layout to create an integer 147 /// of the pointer bitwidth. 148 Type convertIndexType(IndexType type); 149 150 /// Convert an integer type `i*` to `!llvm<"i*">`. 151 Type convertIntegerType(IntegerType type); 152 153 /// Convert a floating point type: `f16` to `f16`, `f32` to 154 /// `f32` and `f64` to `f64`. `bf16` is not supported 155 /// by LLVM. 156 Type convertFloatType(FloatType type); 157 158 /// Convert complex number type: `complex<f16>` to `!llvm<"{ half, half }">`, 159 /// `complex<f32>` to `!llvm<"{ float, float }">`, and `complex<f64>` to 160 /// `!llvm<"{ double, double }">`. `complex<bf16>` is not supported. 161 Type convertComplexType(ComplexType type); 162 163 /// Convert a memref type into an LLVM type that captures the relevant data. 164 Type convertMemRefType(MemRefType type); 165 166 /// Convert a memref type into a list of LLVM IR types that will form the 167 /// memref descriptor. If `unpackAggregates` is true the `sizes` and `strides` 168 /// arrays in the descriptors are unpacked to individual index-typed elements, 169 /// else they are are kept as rank-sized arrays of index type. In particular, 170 /// the list will contain: 171 /// - two pointers to the memref element type, followed by 172 /// - an index-typed offset, followed by 173 /// - (if unpackAggregates = true) 174 /// - one index-typed size per dimension of the memref, followed by 175 /// - one index-typed stride per dimension of the memref. 176 /// - (if unpackArrregates = false) 177 /// - one rank-sized array of index-type for the size of each dimension 178 /// - one rank-sized array of index-type for the stride of each dimension 179 /// 180 /// For example, memref<?x?xf32> is converted to the following list: 181 /// - `!llvm<"float*">` (allocated pointer), 182 /// - `!llvm<"float*">` (aligned pointer), 183 /// - `i64` (offset), 184 /// - `i64`, `i64` (sizes), 185 /// - `i64`, `i64` (strides). 186 /// These types can be recomposed to a memref descriptor struct. 187 SmallVector<Type, 5> getMemRefDescriptorFields(MemRefType type, 188 bool unpackAggregates); 189 190 /// Convert an unranked memref type into a list of non-aggregate LLVM IR types 191 /// that will form the unranked memref descriptor. In particular, this list 192 /// contains: 193 /// - an integer rank, followed by 194 /// - a pointer to the memref descriptor struct. 195 /// For example, memref<*xf32> is converted to the following list: 196 /// i64 (rank) 197 /// !llvm<"i8*"> (type-erased pointer). 198 /// These types can be recomposed to a unranked memref descriptor struct. 199 SmallVector<Type, 2> getUnrankedMemRefDescriptorFields(); 200 201 // Convert an unranked memref type to an LLVM type that captures the 202 // runtime rank and a pointer to the static ranked memref desc 203 Type convertUnrankedMemRefType(UnrankedMemRefType type); 204 205 /// Convert a memref type to a bare pointer to the memref element type. 206 Type convertMemRefToBarePtr(BaseMemRefType type); 207 208 // Convert a 1D vector type into an LLVM vector type. 209 Type convertVectorType(VectorType type); 210 211 /// Options for customizing the llvm lowering. 212 LowerToLLVMOptions options; 213 }; 214 215 /// Helper class to produce LLVM dialect operations extracting or inserting 216 /// values to a struct. 217 class StructBuilder { 218 public: 219 /// Construct a helper for the given value. 220 explicit StructBuilder(Value v); 221 /// Builds IR creating an `undef` value of the descriptor type. 222 static StructBuilder undef(OpBuilder &builder, Location loc, 223 Type descriptorType); 224 Value()225 /*implicit*/ operator Value() { return value; } 226 227 protected: 228 // LLVM value 229 Value value; 230 // Cached struct type. 231 Type structType; 232 233 protected: 234 /// Builds IR to extract a value from the struct at position pos 235 Value extractPtr(OpBuilder &builder, Location loc, unsigned pos); 236 /// Builds IR to set a value in the struct at position pos 237 void setPtr(OpBuilder &builder, Location loc, unsigned pos, Value ptr); 238 }; 239 240 class ComplexStructBuilder : public StructBuilder { 241 public: 242 /// Construct a helper for the given complex number value. 243 using StructBuilder::StructBuilder; 244 /// Build IR creating an `undef` value of the complex number type. 245 static ComplexStructBuilder undef(OpBuilder &builder, Location loc, 246 Type type); 247 248 // Build IR extracting the real value from the complex number struct. 249 Value real(OpBuilder &builder, Location loc); 250 // Build IR inserting the real value into the complex number struct. 251 void setReal(OpBuilder &builder, Location loc, Value real); 252 253 // Build IR extracting the imaginary value from the complex number struct. 254 Value imaginary(OpBuilder &builder, Location loc); 255 // Build IR inserting the imaginary value into the complex number struct. 256 void setImaginary(OpBuilder &builder, Location loc, Value imaginary); 257 }; 258 259 /// Helper class to produce LLVM dialect operations extracting or inserting 260 /// elements of a MemRef descriptor. Wraps a Value pointing to the descriptor. 261 /// The Value may be null, in which case none of the operations are valid. 262 class MemRefDescriptor : public StructBuilder { 263 public: 264 /// Construct a helper for the given descriptor value. 265 explicit MemRefDescriptor(Value descriptor); 266 /// Builds IR creating an `undef` value of the descriptor type. 267 static MemRefDescriptor undef(OpBuilder &builder, Location loc, 268 Type descriptorType); 269 /// Builds IR creating a MemRef descriptor that represents `type` and 270 /// populates it with static shape and stride information extracted from the 271 /// type. 272 static MemRefDescriptor fromStaticShape(OpBuilder &builder, Location loc, 273 LLVMTypeConverter &typeConverter, 274 MemRefType type, Value memory); 275 276 /// Builds IR extracting the allocated pointer from the descriptor. 277 Value allocatedPtr(OpBuilder &builder, Location loc); 278 /// Builds IR inserting the allocated pointer into the descriptor. 279 void setAllocatedPtr(OpBuilder &builder, Location loc, Value ptr); 280 281 /// Builds IR extracting the aligned pointer from the descriptor. 282 Value alignedPtr(OpBuilder &builder, Location loc); 283 284 /// Builds IR inserting the aligned pointer into the descriptor. 285 void setAlignedPtr(OpBuilder &builder, Location loc, Value ptr); 286 287 /// Builds IR extracting the offset from the descriptor. 288 Value offset(OpBuilder &builder, Location loc); 289 290 /// Builds IR inserting the offset into the descriptor. 291 void setOffset(OpBuilder &builder, Location loc, Value offset); 292 void setConstantOffset(OpBuilder &builder, Location loc, uint64_t offset); 293 294 /// Builds IR extracting the pos-th size from the descriptor. 295 Value size(OpBuilder &builder, Location loc, unsigned pos); 296 Value size(OpBuilder &builder, Location loc, Value pos, int64_t rank); 297 298 /// Builds IR inserting the pos-th size into the descriptor 299 void setSize(OpBuilder &builder, Location loc, unsigned pos, Value size); 300 void setConstantSize(OpBuilder &builder, Location loc, unsigned pos, 301 uint64_t size); 302 303 /// Builds IR extracting the pos-th size from the descriptor. 304 Value stride(OpBuilder &builder, Location loc, unsigned pos); 305 306 /// Builds IR inserting the pos-th stride into the descriptor 307 void setStride(OpBuilder &builder, Location loc, unsigned pos, Value stride); 308 void setConstantStride(OpBuilder &builder, Location loc, unsigned pos, 309 uint64_t stride); 310 311 /// Returns the (LLVM) pointer type this descriptor contains. 312 LLVM::LLVMPointerType getElementPtrType(); 313 314 /// Builds IR populating a MemRef descriptor structure from a list of 315 /// individual values composing that descriptor, in the following order: 316 /// - allocated pointer; 317 /// - aligned pointer; 318 /// - offset; 319 /// - <rank> sizes; 320 /// - <rank> shapes; 321 /// where <rank> is the MemRef rank as provided in `type`. 322 static Value pack(OpBuilder &builder, Location loc, 323 LLVMTypeConverter &converter, MemRefType type, 324 ValueRange values); 325 326 /// Builds IR extracting individual elements of a MemRef descriptor structure 327 /// and returning them as `results` list. 328 static void unpack(OpBuilder &builder, Location loc, Value packed, 329 MemRefType type, SmallVectorImpl<Value> &results); 330 331 /// Returns the number of non-aggregate values that would be produced by 332 /// `unpack`. 333 static unsigned getNumUnpackedValues(MemRefType type); 334 335 private: 336 // Cached index type. 337 Type indexType; 338 }; 339 340 /// Helper class allowing the user to access a range of Values that correspond 341 /// to an unpacked memref descriptor using named accessors. This does not own 342 /// the values. 343 class MemRefDescriptorView { 344 public: 345 /// Constructs the view from a range of values. Infers the rank from the size 346 /// of the range. 347 explicit MemRefDescriptorView(ValueRange range); 348 349 /// Returns the allocated pointer Value. 350 Value allocatedPtr(); 351 352 /// Returns the aligned pointer Value. 353 Value alignedPtr(); 354 355 /// Returns the offset Value. 356 Value offset(); 357 358 /// Returns the pos-th size Value. 359 Value size(unsigned pos); 360 361 /// Returns the pos-th stride Value. 362 Value stride(unsigned pos); 363 364 private: 365 /// Rank of the memref the descriptor is pointing to. 366 int rank; 367 /// Underlying range of Values. 368 ValueRange elements; 369 }; 370 371 class UnrankedMemRefDescriptor : public StructBuilder { 372 public: 373 /// Construct a helper for the given descriptor value. 374 explicit UnrankedMemRefDescriptor(Value descriptor); 375 /// Builds IR creating an `undef` value of the descriptor type. 376 static UnrankedMemRefDescriptor undef(OpBuilder &builder, Location loc, 377 Type descriptorType); 378 379 /// Builds IR extracting the rank from the descriptor 380 Value rank(OpBuilder &builder, Location loc); 381 /// Builds IR setting the rank in the descriptor 382 void setRank(OpBuilder &builder, Location loc, Value value); 383 /// Builds IR extracting ranked memref descriptor ptr 384 Value memRefDescPtr(OpBuilder &builder, Location loc); 385 /// Builds IR setting ranked memref descriptor ptr 386 void setMemRefDescPtr(OpBuilder &builder, Location loc, Value value); 387 388 /// Builds IR populating an unranked MemRef descriptor structure from a list 389 /// of individual constituent values in the following order: 390 /// - rank of the memref; 391 /// - pointer to the memref descriptor. 392 static Value pack(OpBuilder &builder, Location loc, 393 LLVMTypeConverter &converter, UnrankedMemRefType type, 394 ValueRange values); 395 396 /// Builds IR extracting individual elements that compose an unranked memref 397 /// descriptor and returns them as `results` list. 398 static void unpack(OpBuilder &builder, Location loc, Value packed, 399 SmallVectorImpl<Value> &results); 400 401 /// Returns the number of non-aggregate values that would be produced by 402 /// `unpack`. getNumUnpackedValues()403 static unsigned getNumUnpackedValues() { return 2; } 404 405 /// Builds IR computing the sizes in bytes (suitable for opaque allocation) 406 /// and appends the corresponding values into `sizes`. 407 static void computeSizes(OpBuilder &builder, Location loc, 408 LLVMTypeConverter &typeConverter, 409 ArrayRef<UnrankedMemRefDescriptor> values, 410 SmallVectorImpl<Value> &sizes); 411 412 /// TODO: The following accessors don't take alignment rules between elements 413 /// of the descriptor struct into account. For some architectures, it might be 414 /// necessary to extend them and to use `llvm::DataLayout` contained in 415 /// `LLVMTypeConverter`. 416 417 /// Builds IR extracting the allocated pointer from the descriptor. 418 static Value allocatedPtr(OpBuilder &builder, Location loc, 419 Value memRefDescPtr, Type elemPtrPtrType); 420 /// Builds IR inserting the allocated pointer into the descriptor. 421 static void setAllocatedPtr(OpBuilder &builder, Location loc, 422 Value memRefDescPtr, Type elemPtrPtrType, 423 Value allocatedPtr); 424 425 /// Builds IR extracting the aligned pointer from the descriptor. 426 static Value alignedPtr(OpBuilder &builder, Location loc, 427 LLVMTypeConverter &typeConverter, Value memRefDescPtr, 428 Type elemPtrPtrType); 429 /// Builds IR inserting the aligned pointer into the descriptor. 430 static void setAlignedPtr(OpBuilder &builder, Location loc, 431 LLVMTypeConverter &typeConverter, 432 Value memRefDescPtr, Type elemPtrPtrType, 433 Value alignedPtr); 434 435 /// Builds IR extracting the offset from the descriptor. 436 static Value offset(OpBuilder &builder, Location loc, 437 LLVMTypeConverter &typeConverter, Value memRefDescPtr, 438 Type elemPtrPtrType); 439 /// Builds IR inserting the offset into the descriptor. 440 static void setOffset(OpBuilder &builder, Location loc, 441 LLVMTypeConverter &typeConverter, Value memRefDescPtr, 442 Type elemPtrPtrType, Value offset); 443 444 /// Builds IR extracting the pointer to the first element of the size array. 445 static Value sizeBasePtr(OpBuilder &builder, Location loc, 446 LLVMTypeConverter &typeConverter, 447 Value memRefDescPtr, 448 LLVM::LLVMPointerType elemPtrPtrType); 449 /// Builds IR extracting the size[index] from the descriptor. 450 static Value size(OpBuilder &builder, Location loc, 451 LLVMTypeConverter typeConverter, Value sizeBasePtr, 452 Value index); 453 /// Builds IR inserting the size[index] into the descriptor. 454 static void setSize(OpBuilder &builder, Location loc, 455 LLVMTypeConverter typeConverter, Value sizeBasePtr, 456 Value index, Value size); 457 458 /// Builds IR extracting the pointer to the first element of the stride array. 459 static Value strideBasePtr(OpBuilder &builder, Location loc, 460 LLVMTypeConverter &typeConverter, 461 Value sizeBasePtr, Value rank); 462 /// Builds IR extracting the stride[index] from the descriptor. 463 static Value stride(OpBuilder &builder, Location loc, 464 LLVMTypeConverter typeConverter, Value strideBasePtr, 465 Value index, Value stride); 466 /// Builds IR inserting the stride[index] into the descriptor. 467 static void setStride(OpBuilder &builder, Location loc, 468 LLVMTypeConverter typeConverter, Value strideBasePtr, 469 Value index, Value stride); 470 }; 471 472 /// Base class for operation conversions targeting the LLVM IR dialect. It 473 /// provides the conversion patterns with access to the LLVMTypeConverter and 474 /// the LowerToLLVMOptions. The class captures the LLVMTypeConverter and the 475 /// LowerToLLVMOptions by reference meaning the references have to remain alive 476 /// during the entire pattern lifetime. 477 class ConvertToLLVMPattern : public ConversionPattern { 478 public: 479 ConvertToLLVMPattern(StringRef rootOpName, MLIRContext *context, 480 LLVMTypeConverter &typeConverter, 481 PatternBenefit benefit = 1); 482 483 protected: 484 /// Returns the LLVM dialect. 485 LLVM::LLVMDialect &getDialect() const; 486 487 LLVMTypeConverter *getTypeConverter() const; 488 489 /// Gets the MLIR type wrapping the LLVM integer type whose bit width is 490 /// defined by the used type converter. 491 Type getIndexType() const; 492 493 /// Gets the MLIR type wrapping the LLVM integer type whose bit width 494 /// corresponds to that of a LLVM pointer type. 495 Type getIntPtrType(unsigned addressSpace = 0) const; 496 497 /// Gets the MLIR type wrapping the LLVM void type. 498 Type getVoidType() const; 499 500 /// Get the MLIR type wrapping the LLVM i8* type. 501 Type getVoidPtrType() const; 502 503 /// Create an LLVM dialect operation defining the given index constant. 504 Value createIndexConstant(ConversionPatternRewriter &builder, Location loc, 505 uint64_t value) const; 506 507 // This is a strided getElementPtr variant that linearizes subscripts as: 508 // `base_offset + index_0 * stride_0 + ... + index_n * stride_n`. 509 Value getStridedElementPtr(Location loc, MemRefType type, Value memRefDesc, 510 ValueRange indices, 511 ConversionPatternRewriter &rewriter) const; 512 513 /// Returns if the given memref has identity maps and the element type is 514 /// convertible to LLVM. 515 bool isConvertibleAndHasIdentityMaps(MemRefType type) const; 516 517 /// Returns the type of a pointer to an element of the memref. 518 Type getElementPtrType(MemRefType type) const; 519 520 /// Computes sizes, strides and buffer size in bytes of `memRefType` with 521 /// identity layout. Emits constant ops for the static sizes of `memRefType`, 522 /// and uses `dynamicSizes` for the others. Emits instructions to compute 523 /// strides and buffer size from these sizes. 524 /// 525 /// For example, memref<4x?xf32> emits: 526 /// `sizes[0]` = llvm.mlir.constant(4 : index) : i64 527 /// `sizes[1]` = `dynamicSizes[0]` 528 /// `strides[1]` = llvm.mlir.constant(1 : index) : i64 529 /// `strides[0]` = `sizes[0]` 530 /// %size = llvm.mul `sizes[0]`, `sizes[1]` : i64 531 /// %nullptr = llvm.mlir.null : !llvm.ptr<f32> 532 /// %gep = llvm.getelementptr %nullptr[%size] 533 /// : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32> 534 /// `sizeBytes` = llvm.ptrtoint %gep : !llvm.ptr<f32> to i64 535 void getMemRefDescriptorSizes(Location loc, MemRefType memRefType, 536 ValueRange dynamicSizes, 537 ConversionPatternRewriter &rewriter, 538 SmallVectorImpl<Value> &sizes, 539 SmallVectorImpl<Value> &strides, 540 Value &sizeBytes) const; 541 542 /// Computes the size of type in bytes. 543 Value getSizeInBytes(Location loc, Type type, 544 ConversionPatternRewriter &rewriter) const; 545 546 /// Computes total number of elements for the given shape. 547 Value getNumElements(Location loc, ArrayRef<Value> shape, 548 ConversionPatternRewriter &rewriter) const; 549 550 /// Creates and populates a canonical memref descriptor struct. 551 MemRefDescriptor 552 createMemRefDescriptor(Location loc, MemRefType memRefType, 553 Value allocatedPtr, Value alignedPtr, 554 ArrayRef<Value> sizes, ArrayRef<Value> strides, 555 ConversionPatternRewriter &rewriter) const; 556 }; 557 558 /// Utility class for operation conversions targeting the LLVM dialect that 559 /// match exactly one source operation. 560 template <typename SourceOp> 561 class ConvertOpToLLVMPattern : public ConvertToLLVMPattern { 562 public: 563 explicit ConvertOpToLLVMPattern(LLVMTypeConverter &typeConverter, 564 PatternBenefit benefit = 1) 565 : ConvertToLLVMPattern(SourceOp::getOperationName(), 566 &typeConverter.getContext(), typeConverter, 567 benefit) {} 568 569 /// Wrappers around the RewritePattern methods that pass the derived op type. rewrite(Operation * op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter)570 void rewrite(Operation *op, ArrayRef<Value> operands, 571 ConversionPatternRewriter &rewriter) const final { 572 rewrite(cast<SourceOp>(op), operands, rewriter); 573 } match(Operation * op)574 LogicalResult match(Operation *op) const final { 575 return match(cast<SourceOp>(op)); 576 } 577 LogicalResult matchAndRewrite(Operation * op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter)578 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 579 ConversionPatternRewriter &rewriter) const final { 580 return matchAndRewrite(cast<SourceOp>(op), operands, rewriter); 581 } 582 583 /// Rewrite and Match methods that operate on the SourceOp type. These must be 584 /// overridden by the derived pattern class. rewrite(SourceOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter)585 virtual void rewrite(SourceOp op, ArrayRef<Value> operands, 586 ConversionPatternRewriter &rewriter) const { 587 llvm_unreachable("must override rewrite or matchAndRewrite"); 588 } match(SourceOp op)589 virtual LogicalResult match(SourceOp op) const { 590 llvm_unreachable("must override match or matchAndRewrite"); 591 } 592 virtual LogicalResult matchAndRewrite(SourceOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter)593 matchAndRewrite(SourceOp op, ArrayRef<Value> operands, 594 ConversionPatternRewriter &rewriter) const { 595 if (succeeded(match(op))) { 596 rewrite(op, operands, rewriter); 597 return success(); 598 } 599 return failure(); 600 } 601 602 private: 603 using ConvertToLLVMPattern::match; 604 using ConvertToLLVMPattern::matchAndRewrite; 605 }; 606 607 namespace LLVM { 608 namespace detail { 609 /// Replaces the given operation "op" with a new operation of type "targetOp" 610 /// and given operands. 611 LogicalResult oneToOneRewrite(Operation *op, StringRef targetOp, 612 ValueRange operands, 613 LLVMTypeConverter &typeConverter, 614 ConversionPatternRewriter &rewriter); 615 616 LogicalResult vectorOneToOneRewrite(Operation *op, StringRef targetOp, 617 ValueRange operands, 618 LLVMTypeConverter &typeConverter, 619 ConversionPatternRewriter &rewriter); 620 } // namespace detail 621 } // namespace LLVM 622 623 /// Generic implementation of one-to-one conversion from "SourceOp" to 624 /// "TargetOp" where the latter belongs to the LLVM dialect or an equivalent. 625 /// Upholds a convention that multi-result operations get converted into an 626 /// operation returning the LLVM IR structure type, in which case individual 627 /// values must be extracted from using LLVM::ExtractValueOp before being used. 628 template <typename SourceOp, typename TargetOp> 629 class OneToOneConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> { 630 public: 631 using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern; 632 using Super = OneToOneConvertToLLVMPattern<SourceOp, TargetOp>; 633 634 /// Converts the type of the result to an LLVM type, pass operands as is, 635 /// preserve attributes. 636 LogicalResult matchAndRewrite(SourceOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter)637 matchAndRewrite(SourceOp op, ArrayRef<Value> operands, 638 ConversionPatternRewriter &rewriter) const override { 639 return LLVM::detail::oneToOneRewrite(op, TargetOp::getOperationName(), 640 operands, *this->getTypeConverter(), 641 rewriter); 642 } 643 }; 644 645 /// Basic lowering implementation to rewrite Ops with just one result to the 646 /// LLVM Dialect. This supports higher-dimensional vector types. 647 template <typename SourceOp, typename TargetOp> 648 class VectorConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> { 649 public: 650 using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern; 651 using Super = VectorConvertToLLVMPattern<SourceOp, TargetOp>; 652 653 LogicalResult matchAndRewrite(SourceOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter)654 matchAndRewrite(SourceOp op, ArrayRef<Value> operands, 655 ConversionPatternRewriter &rewriter) const override { 656 static_assert( 657 std::is_base_of<OpTrait::OneResult<SourceOp>, SourceOp>::value, 658 "expected single result op"); 659 static_assert(std::is_base_of<OpTrait::SameOperandsAndResultType<SourceOp>, 660 SourceOp>::value, 661 "expected same operands and result type"); 662 return LLVM::detail::vectorOneToOneRewrite( 663 op, TargetOp::getOperationName(), operands, *this->getTypeConverter(), 664 rewriter); 665 } 666 }; 667 668 /// Derived class that automatically populates legalization information for 669 /// different LLVM ops. 670 class LLVMConversionTarget : public ConversionTarget { 671 public: 672 explicit LLVMConversionTarget(MLIRContext &ctx); 673 }; 674 675 } // namespace mlir 676 677 #endif // MLIR_CONVERSION_STANDARDTOLLVM_CONVERTSTANDARDTOLLVM_H 678