1 //===- StandardTypes.h - MLIR Standard Type Classes -------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #ifndef MLIR_IR_STANDARDTYPES_H 10 #define MLIR_IR_STANDARDTYPES_H 11 12 #include "mlir/IR/Types.h" 13 14 namespace llvm { 15 struct fltSemantics; 16 } // namespace llvm 17 18 namespace mlir { 19 class AffineExpr; 20 class AffineMap; 21 class FloatType; 22 class IndexType; 23 class IntegerType; 24 class Location; 25 class MLIRContext; 26 27 namespace detail { 28 29 struct IntegerTypeStorage; 30 struct ShapedTypeStorage; 31 struct VectorTypeStorage; 32 struct RankedTensorTypeStorage; 33 struct UnrankedTensorTypeStorage; 34 struct MemRefTypeStorage; 35 struct UnrankedMemRefTypeStorage; 36 struct ComplexTypeStorage; 37 struct TupleTypeStorage; 38 39 } // namespace detail 40 41 namespace StandardTypes { 42 enum Kind { 43 // Floating point. 44 BF16 = Type::Kind::FIRST_STANDARD_TYPE, 45 F16, 46 F32, 47 F64, 48 FIRST_FLOATING_POINT_TYPE = BF16, 49 LAST_FLOATING_POINT_TYPE = F64, 50 51 // Target pointer sized integer, used (e.g.) in affine mappings. 52 Index, 53 54 // Derived types. 55 Integer, 56 Vector, 57 RankedTensor, 58 UnrankedTensor, 59 MemRef, 60 UnrankedMemRef, 61 Complex, 62 Tuple, 63 None, 64 }; 65 66 } // namespace StandardTypes 67 68 //===----------------------------------------------------------------------===// 69 // ComplexType 70 //===----------------------------------------------------------------------===// 71 72 /// The 'complex' type represents a complex number with a parameterized element 73 /// type, which is composed of a real and imaginary value of that element type. 74 /// 75 /// The element must be a floating point or integer scalar type. 76 /// 77 class ComplexType 78 : public Type::TypeBase<ComplexType, Type, detail::ComplexTypeStorage> { 79 public: 80 using Base::Base; 81 82 /// Get or create a ComplexType with the provided element type. 83 static ComplexType get(Type elementType); 84 85 /// Get or create a ComplexType with the provided element type. This emits 86 /// and error at the specified location and returns null if the element type 87 /// isn't supported. 88 static ComplexType getChecked(Type elementType, Location location); 89 90 /// Verify the construction of an integer type. 91 static LogicalResult verifyConstructionInvariants(Location loc, 92 Type elementType); 93 94 Type getElementType(); 95 kindof(unsigned kind)96 static bool kindof(unsigned kind) { return kind == StandardTypes::Complex; } 97 }; 98 99 //===----------------------------------------------------------------------===// 100 // IndexType 101 //===----------------------------------------------------------------------===// 102 103 /// Index is a special integer-like type with unknown platform-dependent bit 104 /// width. 105 class IndexType : public Type::TypeBase<IndexType, Type, TypeStorage> { 106 public: 107 using Base::Base; 108 109 /// Get an instance of the IndexType. 110 static IndexType get(MLIRContext *context); 111 112 /// Support method to enable LLVM-style type casting. kindof(unsigned kind)113 static bool kindof(unsigned kind) { return kind == StandardTypes::Index; } 114 115 /// Storage bit width used for IndexType by internal compiler data structures. 116 static constexpr unsigned kInternalStorageBitWidth = 64; 117 }; 118 119 //===----------------------------------------------------------------------===// 120 // IntegerType 121 //===----------------------------------------------------------------------===// 122 123 /// Integer types can have arbitrary bitwidth up to a large fixed limit. 124 class IntegerType 125 : public Type::TypeBase<IntegerType, Type, detail::IntegerTypeStorage> { 126 public: 127 using Base::Base; 128 129 /// Signedness semantics. 130 enum SignednessSemantics { 131 Signless, /// No signedness semantics 132 Signed, /// Signed integer 133 Unsigned, /// Unsigned integer 134 }; 135 136 /// Get or create a new IntegerType of the given width within the context. 137 /// The created IntegerType is signless (i.e., no signedness semantics). 138 /// Assume the width is within the allowed range and assert on failures. Use 139 /// getChecked to handle failures gracefully. 140 static IntegerType get(unsigned width, MLIRContext *context); 141 142 /// Get or create a new IntegerType of the given width within the context. 143 /// The created IntegerType has signedness semantics as indicated via 144 /// `signedness`. Assume the width is within the allowed range and assert on 145 /// failures. Use getChecked to handle failures gracefully. 146 static IntegerType get(unsigned width, SignednessSemantics signedness, 147 MLIRContext *context); 148 149 /// Get or create a new IntegerType of the given width within the context, 150 /// defined at the given, potentially unknown, location. The created 151 /// IntegerType is signless (i.e., no signedness semantics). If the width is 152 /// outside the allowed range, emit errors and return a null type. 153 static IntegerType getChecked(unsigned width, Location location); 154 155 /// Get or create a new IntegerType of the given width within the context, 156 /// defined at the given, potentially unknown, location. The created 157 /// IntegerType has signedness semantics as indicated via `signedness`. If the 158 /// width is outside the allowed range, emit errors and return a null type. 159 static IntegerType getChecked(unsigned width, SignednessSemantics signedness, 160 Location location); 161 162 /// Verify the construction of an integer type. 163 static LogicalResult 164 verifyConstructionInvariants(Location loc, unsigned width, 165 SignednessSemantics signedness); 166 167 /// Return the bitwidth of this integer type. 168 unsigned getWidth() const; 169 170 /// Return the signedness semantics of this integer type. 171 SignednessSemantics getSignedness() const; 172 173 /// Return true if this is a signless integer type. isSignless()174 bool isSignless() const { return getSignedness() == Signless; } 175 /// Return true if this is a signed integer type. isSigned()176 bool isSigned() const { return getSignedness() == Signed; } 177 /// Return true if this is an unsigned integer type. isUnsigned()178 bool isUnsigned() const { return getSignedness() == Unsigned; } 179 180 /// Methods for support type inquiry through isa, cast, and dyn_cast. kindof(unsigned kind)181 static bool kindof(unsigned kind) { return kind == StandardTypes::Integer; } 182 183 /// Integer representation maximal bitwidth. 184 static constexpr unsigned kMaxWidth = 4096; 185 }; 186 187 //===----------------------------------------------------------------------===// 188 // FloatType 189 //===----------------------------------------------------------------------===// 190 191 class FloatType : public Type::TypeBase<FloatType, Type, TypeStorage> { 192 public: 193 using Base::Base; 194 195 static FloatType get(StandardTypes::Kind kind, MLIRContext *context); 196 197 // Convenience factories. getBF16(MLIRContext * ctx)198 static FloatType getBF16(MLIRContext *ctx) { 199 return get(StandardTypes::BF16, ctx); 200 } getF16(MLIRContext * ctx)201 static FloatType getF16(MLIRContext *ctx) { 202 return get(StandardTypes::F16, ctx); 203 } getF32(MLIRContext * ctx)204 static FloatType getF32(MLIRContext *ctx) { 205 return get(StandardTypes::F32, ctx); 206 } getF64(MLIRContext * ctx)207 static FloatType getF64(MLIRContext *ctx) { 208 return get(StandardTypes::F64, ctx); 209 } 210 211 /// Methods for support type inquiry through isa, cast, and dyn_cast. kindof(unsigned kind)212 static bool kindof(unsigned kind) { 213 return kind >= StandardTypes::FIRST_FLOATING_POINT_TYPE && 214 kind <= StandardTypes::LAST_FLOATING_POINT_TYPE; 215 } 216 217 /// Return the bitwidth of this float type. 218 unsigned getWidth(); 219 220 /// Return the floating semantics of this float type. 221 const llvm::fltSemantics &getFloatSemantics(); 222 }; 223 224 //===----------------------------------------------------------------------===// 225 // NoneType 226 //===----------------------------------------------------------------------===// 227 228 /// NoneType is a unit type, i.e. a type with exactly one possible value, where 229 /// its value does not have a defined dynamic representation. 230 class NoneType : public Type::TypeBase<NoneType, Type, TypeStorage> { 231 public: 232 using Base::Base; 233 234 /// Get an instance of the NoneType. 235 static NoneType get(MLIRContext *context); 236 kindof(unsigned kind)237 static bool kindof(unsigned kind) { return kind == StandardTypes::None; } 238 }; 239 240 //===----------------------------------------------------------------------===// 241 // ShapedType 242 //===----------------------------------------------------------------------===// 243 244 /// This is a common base class between Vector, UnrankedTensor, RankedTensor, 245 /// and MemRef types because they share behavior and semantics around shape, 246 /// rank, and fixed element type. Any type with these semantics should inherit 247 /// from ShapedType. 248 class ShapedType : public Type { 249 public: 250 using ImplType = detail::ShapedTypeStorage; 251 using Type::Type; 252 253 // TODO: merge these two special values in a single one used everywhere. 254 // Unfortunately, uses of `-1` have crept deep into the codebase now and are 255 // hard to track. 256 static constexpr int64_t kDynamicSize = -1; 257 static constexpr int64_t kDynamicStrideOrOffset = 258 std::numeric_limits<int64_t>::min(); 259 260 /// Return the element type. 261 Type getElementType() const; 262 263 /// If an element type is an integer or a float, return its width. Otherwise, 264 /// abort. 265 unsigned getElementTypeBitWidth() const; 266 267 /// If it has static shape, return the number of elements. Otherwise, abort. 268 int64_t getNumElements() const; 269 270 /// If this is a ranked type, return the rank. Otherwise, abort. 271 int64_t getRank() const; 272 273 /// Whether or not this is a ranked type. Memrefs, vectors and ranked tensors 274 /// have a rank, while unranked tensors do not. 275 bool hasRank() const; 276 277 /// If this is a ranked type, return the shape. Otherwise, abort. 278 ArrayRef<int64_t> getShape() const; 279 280 /// If this is unranked type or any dimension has unknown size (<0), it 281 /// doesn't have static shape. If all dimensions have known size (>= 0), it 282 /// has static shape. 283 bool hasStaticShape() const; 284 285 /// If this has a static shape and the shape is equal to `shape` return true. 286 bool hasStaticShape(ArrayRef<int64_t> shape) const; 287 288 /// If this is a ranked type, return the number of dimensions with dynamic 289 /// size. Otherwise, abort. 290 int64_t getNumDynamicDims() const; 291 292 /// If this is ranked type, return the size of the specified dimension. 293 /// Otherwise, abort. 294 int64_t getDimSize(unsigned idx) const; 295 296 /// Returns true if this dimension has a dynamic size (for ranked types); 297 /// aborts for unranked types. 298 bool isDynamicDim(unsigned idx) const; 299 300 /// Returns the position of the dynamic dimension relative to just the dynamic 301 /// dimensions, given its `index` within the shape. 302 unsigned getDynamicDimIndex(unsigned index) const; 303 304 /// Get the total amount of bits occupied by a value of this type. This does 305 /// not take into account any memory layout or widening constraints, e.g. a 306 /// vector<3xi57> is reported to occupy 3x57=171 bit, even though in practice 307 /// it will likely be stored as in a 4xi64 vector register. Fail an assertion 308 /// if the size cannot be computed statically, i.e. if the type has a dynamic 309 /// shape or if its elemental type does not have a known bit width. 310 int64_t getSizeInBits() const; 311 312 /// Methods for support type inquiry through isa, cast, and dyn_cast. classof(Type type)313 static bool classof(Type type) { 314 return type.getKind() == StandardTypes::Vector || 315 type.getKind() == StandardTypes::RankedTensor || 316 type.getKind() == StandardTypes::UnrankedTensor || 317 type.getKind() == StandardTypes::UnrankedMemRef || 318 type.getKind() == StandardTypes::MemRef; 319 } 320 321 /// Whether the given dimension size indicates a dynamic dimension. isDynamic(int64_t dSize)322 static constexpr bool isDynamic(int64_t dSize) { 323 return dSize == kDynamicSize; 324 } isDynamicStrideOrOffset(int64_t dStrideOrOffset)325 static constexpr bool isDynamicStrideOrOffset(int64_t dStrideOrOffset) { 326 return dStrideOrOffset == kDynamicStrideOrOffset; 327 } 328 }; 329 330 //===----------------------------------------------------------------------===// 331 // VectorType 332 //===----------------------------------------------------------------------===// 333 334 /// Vector types represent multi-dimensional SIMD vectors, and have a fixed 335 /// known constant shape with one or more dimension. 336 class VectorType 337 : public Type::TypeBase<VectorType, ShapedType, detail::VectorTypeStorage> { 338 public: 339 using Base::Base; 340 341 /// Get or create a new VectorType of the provided shape and element type. 342 /// Assumes the arguments define a well-formed VectorType. 343 static VectorType get(ArrayRef<int64_t> shape, Type elementType); 344 345 /// Get or create a new VectorType of the provided shape and element type 346 /// declared at the given, potentially unknown, location. If the VectorType 347 /// defined by the arguments would be ill-formed, emit errors and return 348 /// nullptr-wrapping type. 349 static VectorType getChecked(ArrayRef<int64_t> shape, Type elementType, 350 Location location); 351 352 /// Verify the construction of a vector type. 353 static LogicalResult verifyConstructionInvariants(Location loc, 354 ArrayRef<int64_t> shape, 355 Type elementType); 356 357 /// Returns true of the given type can be used as an element of a vector type. 358 /// In particular, vectors can consist of integer or float primitives. isValidElementType(Type t)359 static bool isValidElementType(Type t) { 360 return t.isa<IntegerType, FloatType>(); 361 } 362 363 ArrayRef<int64_t> getShape() const; 364 365 /// Methods for support type inquiry through isa, cast, and dyn_cast. kindof(unsigned kind)366 static bool kindof(unsigned kind) { return kind == StandardTypes::Vector; } 367 }; 368 369 //===----------------------------------------------------------------------===// 370 // TensorType 371 //===----------------------------------------------------------------------===// 372 373 /// Tensor types represent multi-dimensional arrays, and have two variants: 374 /// RankedTensorType and UnrankedTensorType. 375 class TensorType : public ShapedType { 376 public: 377 using ShapedType::ShapedType; 378 379 /// Return true if the specified element type is ok in a tensor. isValidElementType(Type type)380 static bool isValidElementType(Type type) { 381 // Note: Non standard/builtin types are allowed to exist within tensor 382 // types. Dialects are expected to verify that tensor types have a valid 383 // element type within that dialect. 384 return type.isa<ComplexType, FloatType, IntegerType, OpaqueType, VectorType, 385 IndexType>() || 386 (type.getKind() > Type::Kind::LAST_STANDARD_TYPE); 387 } 388 389 /// Methods for support type inquiry through isa, cast, and dyn_cast. classof(Type type)390 static bool classof(Type type) { 391 return type.getKind() == StandardTypes::RankedTensor || 392 type.getKind() == StandardTypes::UnrankedTensor; 393 } 394 }; 395 396 //===----------------------------------------------------------------------===// 397 // RankedTensorType 398 399 /// Ranked tensor types represent multi-dimensional arrays that have a shape 400 /// with a fixed number of dimensions. Each shape element can be a non-negative 401 /// integer or unknown (represented by -1). 402 class RankedTensorType 403 : public Type::TypeBase<RankedTensorType, TensorType, 404 detail::RankedTensorTypeStorage> { 405 public: 406 using Base::Base; 407 408 /// Get or create a new RankedTensorType of the provided shape and element 409 /// type. Assumes the arguments define a well-formed type. 410 static RankedTensorType get(ArrayRef<int64_t> shape, Type elementType); 411 412 /// Get or create a new RankedTensorType of the provided shape and element 413 /// type declared at the given, potentially unknown, location. If the 414 /// RankedTensorType defined by the arguments would be ill-formed, emit errors 415 /// and return a nullptr-wrapping type. 416 static RankedTensorType getChecked(ArrayRef<int64_t> shape, Type elementType, 417 Location location); 418 419 /// Verify the construction of a ranked tensor type. 420 static LogicalResult verifyConstructionInvariants(Location loc, 421 ArrayRef<int64_t> shape, 422 Type elementType); 423 424 ArrayRef<int64_t> getShape() const; 425 kindof(unsigned kind)426 static bool kindof(unsigned kind) { 427 return kind == StandardTypes::RankedTensor; 428 } 429 }; 430 431 //===----------------------------------------------------------------------===// 432 // UnrankedTensorType 433 434 /// Unranked tensor types represent multi-dimensional arrays that have an 435 /// unknown shape. 436 class UnrankedTensorType 437 : public Type::TypeBase<UnrankedTensorType, TensorType, 438 detail::UnrankedTensorTypeStorage> { 439 public: 440 using Base::Base; 441 442 /// Get or create a new UnrankedTensorType of the provided shape and element 443 /// type. Assumes the arguments define a well-formed type. 444 static UnrankedTensorType get(Type elementType); 445 446 /// Get or create a new UnrankedTensorType of the provided shape and element 447 /// type declared at the given, potentially unknown, location. If the 448 /// UnrankedTensorType defined by the arguments would be ill-formed, emit 449 /// errors and return a nullptr-wrapping type. 450 static UnrankedTensorType getChecked(Type elementType, Location location); 451 452 /// Verify the construction of a unranked tensor type. 453 static LogicalResult verifyConstructionInvariants(Location loc, 454 Type elementType); 455 getShape()456 ArrayRef<int64_t> getShape() const { return llvm::None; } 457 kindof(unsigned kind)458 static bool kindof(unsigned kind) { 459 return kind == StandardTypes::UnrankedTensor; 460 } 461 }; 462 463 //===----------------------------------------------------------------------===// 464 // BaseMemRefType 465 //===----------------------------------------------------------------------===// 466 467 /// Base MemRef for Ranked and Unranked variants 468 class BaseMemRefType : public ShapedType { 469 public: 470 using ShapedType::ShapedType; 471 472 /// Methods for support type inquiry through isa, cast, and dyn_cast. classof(Type type)473 static bool classof(Type type) { 474 return type.getKind() == StandardTypes::MemRef || 475 type.getKind() == StandardTypes::UnrankedMemRef; 476 } 477 }; 478 479 //===----------------------------------------------------------------------===// 480 // MemRefType 481 482 /// MemRef types represent a region of memory that have a shape with a fixed 483 /// number of dimensions. Each shape element can be a non-negative integer or 484 /// unknown (represented by -1). MemRef types also have an affine map 485 /// composition, represented as an array AffineMap pointers. 486 class MemRefType : public Type::TypeBase<MemRefType, BaseMemRefType, 487 detail::MemRefTypeStorage> { 488 public: 489 /// This is a builder type that keeps local references to arguments. Arguments 490 /// that are passed into the builder must out-live the builder. 491 class Builder { 492 public: 493 // Build from another MemRefType. Builder(MemRefType other)494 explicit Builder(MemRefType other) 495 : shape(other.getShape()), elementType(other.getElementType()), 496 affineMaps(other.getAffineMaps()), 497 memorySpace(other.getMemorySpace()) {} 498 499 // Build from scratch. Builder(ArrayRef<int64_t> shape,Type elementType)500 Builder(ArrayRef<int64_t> shape, Type elementType) 501 : shape(shape), elementType(elementType), affineMaps(), memorySpace(0) { 502 } 503 setShape(ArrayRef<int64_t> newShape)504 Builder &setShape(ArrayRef<int64_t> newShape) { 505 shape = newShape; 506 return *this; 507 } 508 setElementType(Type newElementType)509 Builder &setElementType(Type newElementType) { 510 elementType = newElementType; 511 return *this; 512 } 513 setAffineMaps(ArrayRef<AffineMap> newAffineMaps)514 Builder &setAffineMaps(ArrayRef<AffineMap> newAffineMaps) { 515 affineMaps = newAffineMaps; 516 return *this; 517 } 518 setMemorySpace(unsigned newMemorySpace)519 Builder &setMemorySpace(unsigned newMemorySpace) { 520 memorySpace = newMemorySpace; 521 return *this; 522 } 523 MemRefType()524 operator MemRefType() { 525 return MemRefType::get(shape, elementType, affineMaps, memorySpace); 526 } 527 528 private: 529 ArrayRef<int64_t> shape; 530 Type elementType; 531 ArrayRef<AffineMap> affineMaps; 532 unsigned memorySpace; 533 }; 534 535 using Base::Base; 536 537 /// Get or create a new MemRefType based on shape, element type, affine 538 /// map composition, and memory space. Assumes the arguments define a 539 /// well-formed MemRef type. Use getChecked to gracefully handle MemRefType 540 /// construction failures. 541 static MemRefType get(ArrayRef<int64_t> shape, Type elementType, 542 ArrayRef<AffineMap> affineMapComposition = {}, 543 unsigned memorySpace = 0); 544 545 /// Get or create a new MemRefType based on shape, element type, affine 546 /// map composition, and memory space declared at the given location. 547 /// If the location is unknown, the last argument should be an instance of 548 /// UnknownLoc. If the MemRefType defined by the arguments would be 549 /// ill-formed, emits errors (to the handler registered with the context or to 550 /// the error stream) and returns nullptr. 551 static MemRefType getChecked(ArrayRef<int64_t> shape, Type elementType, 552 ArrayRef<AffineMap> affineMapComposition, 553 unsigned memorySpace, Location location); 554 555 ArrayRef<int64_t> getShape() const; 556 557 /// Returns an array of affine map pointers representing the memref affine 558 /// map composition. 559 ArrayRef<AffineMap> getAffineMaps() const; 560 561 /// Returns the memory space in which data referred to by this memref resides. 562 unsigned getMemorySpace() const; 563 564 // TODO: merge these two special values in a single one used everywhere. 565 // Unfortunately, uses of `-1` have crept deep into the codebase now and are 566 // hard to track. 567 static constexpr int64_t kDynamicSize = -1; getDynamicStrideOrOffset()568 static int64_t getDynamicStrideOrOffset() { 569 return ShapedType::kDynamicStrideOrOffset; 570 } 571 kindof(unsigned kind)572 static bool kindof(unsigned kind) { return kind == StandardTypes::MemRef; } 573 574 private: 575 /// Get or create a new MemRefType defined by the arguments. If the resulting 576 /// type would be ill-formed, return nullptr. If the location is provided, 577 /// emit detailed error messages. 578 static MemRefType getImpl(ArrayRef<int64_t> shape, Type elementType, 579 ArrayRef<AffineMap> affineMapComposition, 580 unsigned memorySpace, Optional<Location> location); 581 using Base::getImpl; 582 }; 583 584 //===----------------------------------------------------------------------===// 585 // UnrankedMemRefType 586 587 /// Unranked MemRef type represent multi-dimensional MemRefs that 588 /// have an unknown rank. 589 class UnrankedMemRefType 590 : public Type::TypeBase<UnrankedMemRefType, BaseMemRefType, 591 detail::UnrankedMemRefTypeStorage> { 592 public: 593 using Base::Base; 594 595 /// Get or create a new UnrankedMemRefType of the provided element 596 /// type and memory space 597 static UnrankedMemRefType get(Type elementType, unsigned memorySpace); 598 599 /// Get or create a new UnrankedMemRefType of the provided element 600 /// type and memory space declared at the given, potentially unknown, 601 /// location. If the UnrankedMemRefType defined by the arguments would be 602 /// ill-formed, emit errors and return a nullptr-wrapping type. 603 static UnrankedMemRefType getChecked(Type elementType, unsigned memorySpace, 604 Location location); 605 606 /// Verify the construction of a unranked memref type. 607 static LogicalResult verifyConstructionInvariants(Location loc, 608 Type elementType, 609 unsigned memorySpace); 610 getShape()611 ArrayRef<int64_t> getShape() const { return llvm::None; } 612 613 /// Returns the memory space in which data referred to by this memref resides. 614 unsigned getMemorySpace() const; kindof(unsigned kind)615 static bool kindof(unsigned kind) { 616 return kind == StandardTypes::UnrankedMemRef; 617 } 618 }; 619 620 //===----------------------------------------------------------------------===// 621 // TupleType 622 //===----------------------------------------------------------------------===// 623 624 /// Tuple types represent a collection of other types. Note: This type merely 625 /// provides a common mechanism for representing tuples in MLIR. It is up to 626 /// dialect authors to provides operations for manipulating them, e.g. 627 /// extract_tuple_element. When possible, users should prefer multi-result 628 /// operations in the place of tuples. 629 class TupleType 630 : public Type::TypeBase<TupleType, Type, detail::TupleTypeStorage> { 631 public: 632 using Base::Base; 633 634 /// Get or create a new TupleType with the provided element types. Assumes the 635 /// arguments define a well-formed type. 636 static TupleType get(ArrayRef<Type> elementTypes, MLIRContext *context); 637 638 /// Get or create an empty tuple type. get(MLIRContext * context)639 static TupleType get(MLIRContext *context) { return get({}, context); } 640 641 /// Return the elements types for this tuple. 642 ArrayRef<Type> getTypes() const; 643 644 /// Accumulate the types contained in this tuple and tuples nested within it. 645 /// Note that this only flattens nested tuples, not any other container type, 646 /// e.g. a tuple<i32, tensor<i32>, tuple<f32, tuple<i64>>> is flattened to 647 /// (i32, tensor<i32>, f32, i64) 648 void getFlattenedTypes(SmallVectorImpl<Type> &types); 649 650 /// Return the number of held types. 651 size_t size() const; 652 653 /// Iterate over the held elements. 654 using iterator = ArrayRef<Type>::iterator; begin()655 iterator begin() const { return getTypes().begin(); } end()656 iterator end() const { return getTypes().end(); } 657 658 /// Return the element type at index 'index'. getType(size_t index)659 Type getType(size_t index) const { 660 assert(index < size() && "invalid index for tuple type"); 661 return getTypes()[index]; 662 } 663 kindof(unsigned kind)664 static bool kindof(unsigned kind) { return kind == StandardTypes::Tuple; } 665 }; 666 667 //===----------------------------------------------------------------------===// 668 // Type Utilities 669 //===----------------------------------------------------------------------===// 670 671 /// Returns the strides of the MemRef if the layout map is in strided form. 672 /// MemRefs with layout maps in strided form include: 673 /// 1. empty or identity layout map, in which case the stride information is 674 /// the canonical form computed from sizes; 675 /// 2. single affine map layout of the form `K + k0 * d0 + ... kn * dn`, 676 /// where K and ki's are constants or symbols. 677 /// 678 /// A stride specification is a list of integer values that are either static 679 /// or dynamic (encoded with getDynamicStrideOrOffset()). Strides encode the 680 /// distance in the number of elements between successive entries along a 681 /// particular dimension. For example, `memref<42x16xf32, (64 * d0 + d1)>` 682 /// specifies a view into a non-contiguous memory region of `42` by `16` `f32` 683 /// elements in which the distance between two consecutive elements along the 684 /// outer dimension is `1` and the distance between two consecutive elements 685 /// along the inner dimension is `64`. 686 /// 687 /// If a simple strided form cannot be extracted from the composition of the 688 /// layout map, returns llvm::None. 689 /// 690 /// The convention is that the strides for dimensions d0, .. dn appear in 691 /// order to make indexing intuitive into the result. 692 LogicalResult getStridesAndOffset(MemRefType t, 693 SmallVectorImpl<int64_t> &strides, 694 int64_t &offset); 695 LogicalResult getStridesAndOffset(MemRefType t, 696 SmallVectorImpl<AffineExpr> &strides, 697 AffineExpr &offset); 698 699 /// Given a list of strides (in which MemRefType::getDynamicStrideOrOffset() 700 /// represents a dynamic value), return the single result AffineMap which 701 /// represents the linearized strided layout map. Dimensions correspond to the 702 /// offset followed by the strides in order. Symbols are inserted for each 703 /// dynamic dimension in order. A stride cannot take value `0`. 704 /// 705 /// Examples: 706 /// ========= 707 /// 708 /// 1. For offset: 0 strides: ?, ?, 1 return 709 /// (i, j, k)[M, N]->(M * i + N * j + k) 710 /// 711 /// 2. For offset: 3 strides: 32, ?, 16 return 712 /// (i, j, k)[M]->(3 + 32 * i + M * j + 16 * k) 713 /// 714 /// 3. For offset: ? strides: ?, ?, ? return 715 /// (i, j, k)[off, M, N, P]->(off + M * i + N * j + P * k) 716 AffineMap makeStridedLinearLayoutMap(ArrayRef<int64_t> strides, int64_t offset, 717 MLIRContext *context); 718 719 /// Return a version of `t` with identity layout if it can be determined 720 /// statically that the layout is the canonical contiguous strided layout. 721 /// Otherwise pass `t`'s layout into `simplifyAffineMap` and return a copy of 722 /// `t` with simplified layout. 723 MemRefType canonicalizeStridedLayout(MemRefType t); 724 725 /// Return a version of `t` with a layout that has all dynamic offset and 726 /// strides. This is used to erase the static layout. 727 MemRefType eraseStridedLayout(MemRefType t); 728 729 /// Given MemRef `sizes` that are either static or dynamic, returns the 730 /// canonical "contiguous" strides AffineExpr. Strides are multiplicative and 731 /// once a dynamic dimension is encountered, all canonical strides become 732 /// dynamic and need to be encoded with a different symbol. 733 /// For canonical strides expressions, the offset is always 0 and and fastest 734 /// varying stride is always `1`. 735 /// 736 /// Examples: 737 /// - memref<3x4x5xf32> has canonical stride expression 738 /// `20*exprs[0] + 5*exprs[1] + exprs[2]`. 739 /// - memref<3x?x5xf32> has canonical stride expression 740 /// `s0*exprs[0] + 5*exprs[1] + exprs[2]`. 741 /// - memref<3x4x?xf32> has canonical stride expression 742 /// `s1*exprs[0] + s0*exprs[1] + exprs[2]`. 743 AffineExpr makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes, 744 ArrayRef<AffineExpr> exprs, 745 MLIRContext *context); 746 747 /// Return the result of makeCanonicalStrudedLayoutExpr for the common case 748 /// where `exprs` is {d0, d1, .., d_(sizes.size()-1)} 749 AffineExpr makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes, 750 MLIRContext *context); 751 752 /// Return true if the layout for `t` is compatible with strided semantics. 753 bool isStrided(MemRefType t); 754 755 } // end namespace mlir 756 757 #endif // MLIR_IR_STANDARDTYPES_H 758