1 //===- StandardTypes.h - MLIR Standard Type Classes -------------*- C++ -*-===// 2 // 3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #ifndef MLIR_IR_STANDARDTYPES_H 10 #define MLIR_IR_STANDARDTYPES_H 11 12 #include "mlir/IR/Types.h" 13 14 namespace llvm { 15 struct fltSemantics; 16 } // namespace llvm 17 18 namespace mlir { 19 class AffineExpr; 20 class AffineMap; 21 class FloatType; 22 class IndexType; 23 class IntegerType; 24 class Location; 25 class MLIRContext; 26 27 namespace detail { 28 29 struct IntegerTypeStorage; 30 struct ShapedTypeStorage; 31 struct VectorTypeStorage; 32 struct RankedTensorTypeStorage; 33 struct UnrankedTensorTypeStorage; 34 struct MemRefTypeStorage; 35 struct UnrankedMemRefTypeStorage; 36 struct ComplexTypeStorage; 37 struct TupleTypeStorage; 38 39 } // namespace detail 40 41 namespace StandardTypes { 42 enum Kind { 43 // Floating point. 44 BF16 = Type::Kind::FIRST_STANDARD_TYPE, 45 F16, 46 F32, 47 F64, 48 FIRST_FLOATING_POINT_TYPE = BF16, 49 LAST_FLOATING_POINT_TYPE = F64, 50 51 // Target pointer sized integer, used (e.g.) in affine mappings. 52 Index, 53 54 // Derived types. 55 Integer, 56 Vector, 57 RankedTensor, 58 UnrankedTensor, 59 MemRef, 60 UnrankedMemRef, 61 Complex, 62 Tuple, 63 None, 64 }; 65 66 } // namespace StandardTypes 67 68 /// Index is a special integer-like type with unknown platform-dependent bit 69 /// width. 70 class IndexType : public Type::TypeBase<IndexType, Type> { 71 public: 72 using Base::Base; 73 74 /// Get an instance of the IndexType. 75 static IndexType get(MLIRContext *context); 76 77 /// Support method to enable LLVM-style type casting. kindof(unsigned kind)78 static bool kindof(unsigned kind) { return kind == StandardTypes::Index; } 79 }; 80 81 /// Integer types can have arbitrary bitwidth up to a large fixed limit. 82 class IntegerType 83 : public Type::TypeBase<IntegerType, Type, detail::IntegerTypeStorage> { 84 public: 85 using Base::Base; 86 87 /// Get or create a new IntegerType of the given width within the context. 88 /// Assume the width is within the allowed range and assert on failures. 89 /// Use getChecked to handle failures gracefully. 90 static IntegerType get(unsigned width, MLIRContext *context); 91 92 /// Get or create a new IntegerType of the given width within the context, 93 /// defined at the given, potentially unknown, location. If the width is 94 /// outside the allowed range, emit errors and return a null type. 95 static IntegerType getChecked(unsigned width, MLIRContext *context, 96 Location location); 97 98 /// Verify the construction of an integer type. 99 static LogicalResult verifyConstructionInvariants(Optional<Location> loc, 100 MLIRContext *context, 101 unsigned width); 102 103 /// Return the bitwidth of this integer type. 104 unsigned getWidth() const; 105 106 /// Methods for support type inquiry through isa, cast, and dyn_cast. kindof(unsigned kind)107 static bool kindof(unsigned kind) { return kind == StandardTypes::Integer; } 108 109 /// Integer representation maximal bitwidth. 110 static constexpr unsigned kMaxWidth = 4096; 111 }; 112 113 class FloatType : public Type::TypeBase<FloatType, Type> { 114 public: 115 using Base::Base; 116 117 static FloatType get(StandardTypes::Kind kind, MLIRContext *context); 118 119 // Convenience factories. getBF16(MLIRContext * ctx)120 static FloatType getBF16(MLIRContext *ctx) { 121 return get(StandardTypes::BF16, ctx); 122 } getF16(MLIRContext * ctx)123 static FloatType getF16(MLIRContext *ctx) { 124 return get(StandardTypes::F16, ctx); 125 } getF32(MLIRContext * ctx)126 static FloatType getF32(MLIRContext *ctx) { 127 return get(StandardTypes::F32, ctx); 128 } getF64(MLIRContext * ctx)129 static FloatType getF64(MLIRContext *ctx) { 130 return get(StandardTypes::F64, ctx); 131 } 132 133 /// Methods for support type inquiry through isa, cast, and dyn_cast. kindof(unsigned kind)134 static bool kindof(unsigned kind) { 135 return kind >= StandardTypes::FIRST_FLOATING_POINT_TYPE && 136 kind <= StandardTypes::LAST_FLOATING_POINT_TYPE; 137 } 138 139 /// Return the bitwidth of this float type. 140 unsigned getWidth(); 141 142 /// Return the floating semantics of this float type. 143 const llvm::fltSemantics &getFloatSemantics(); 144 }; 145 146 /// The 'complex' type represents a complex number with a parameterized element 147 /// type, which is composed of a real and imaginary value of that element type. 148 /// 149 /// The element must be a floating point or integer scalar type. 150 /// 151 class ComplexType 152 : public Type::TypeBase<ComplexType, Type, detail::ComplexTypeStorage> { 153 public: 154 using Base::Base; 155 156 /// Get or create a ComplexType with the provided element type. 157 static ComplexType get(Type elementType); 158 159 /// Get or create a ComplexType with the provided element type. This emits 160 /// and error at the specified location and returns null if the element type 161 /// isn't supported. 162 static ComplexType getChecked(Type elementType, Location location); 163 164 /// Verify the construction of an integer type. 165 static LogicalResult verifyConstructionInvariants(Optional<Location> loc, 166 MLIRContext *context, 167 Type elementType); 168 169 Type getElementType(); 170 kindof(unsigned kind)171 static bool kindof(unsigned kind) { return kind == StandardTypes::Complex; } 172 }; 173 174 /// This is a common base class between Vector, UnrankedTensor, RankedTensor, 175 /// and MemRef types because they share behavior and semantics around shape, 176 /// rank, and fixed element type. Any type with these semantics should inherit 177 /// from ShapedType. 178 class ShapedType : public Type { 179 public: 180 using ImplType = detail::ShapedTypeStorage; 181 using Type::Type; 182 183 // TODO(ntv): merge these two special values in a single one used everywhere. 184 // Unfortunately, uses of `-1` have crept deep into the codebase now and are 185 // hard to track. 186 static constexpr int64_t kDynamicSize = -1; 187 static constexpr int64_t kDynamicStrideOrOffset = 188 std::numeric_limits<int64_t>::min(); 189 190 /// Return the element type. 191 Type getElementType() const; 192 193 /// If an element type is an integer or a float, return its width. Otherwise, 194 /// abort. 195 unsigned getElementTypeBitWidth() const; 196 197 /// If it has static shape, return the number of elements. Otherwise, abort. 198 int64_t getNumElements() const; 199 200 /// If this is a ranked type, return the rank. Otherwise, abort. 201 int64_t getRank() const; 202 203 /// Whether or not this is a ranked type. Memrefs, vectors and ranked tensors 204 /// have a rank, while unranked tensors do not. 205 bool hasRank() const; 206 207 /// If this is a ranked type, return the shape. Otherwise, abort. 208 ArrayRef<int64_t> getShape() const; 209 210 /// If this is unranked type or any dimension has unknown size (<0), it 211 /// doesn't have static shape. If all dimensions have known size (>= 0), it 212 /// has static shape. 213 bool hasStaticShape() const; 214 215 /// If this has a static shape and the shape is equal to `shape` return true. 216 bool hasStaticShape(ArrayRef<int64_t> shape) const; 217 218 /// If this is a ranked type, return the number of dimensions with dynamic 219 /// size. Otherwise, abort. 220 int64_t getNumDynamicDims() const; 221 222 /// If this is ranked type, return the size of the specified dimension. 223 /// Otherwise, abort. 224 int64_t getDimSize(int64_t i) const; 225 226 /// Returns the position of the dynamic dimension relative to just the dynamic 227 /// dimensions, given its `index` within the shape. 228 unsigned getDynamicDimIndex(unsigned index) const; 229 230 /// Get the total amount of bits occupied by a value of this type. This does 231 /// not take into account any memory layout or widening constraints, e.g. a 232 /// vector<3xi57> is reported to occupy 3x57=171 bit, even though in practice 233 /// it will likely be stored as in a 4xi64 vector register. Fail an assertion 234 /// if the size cannot be computed statically, i.e. if the type has a dynamic 235 /// shape or if its elemental type does not have a known bit width. 236 int64_t getSizeInBits() const; 237 238 /// Methods for support type inquiry through isa, cast, and dyn_cast. classof(Type type)239 static bool classof(Type type) { 240 return type.getKind() == StandardTypes::Vector || 241 type.getKind() == StandardTypes::RankedTensor || 242 type.getKind() == StandardTypes::UnrankedTensor || 243 type.getKind() == StandardTypes::UnrankedMemRef || 244 type.getKind() == StandardTypes::MemRef; 245 } 246 247 /// Whether the given dimension size indicates a dynamic dimension. isDynamic(int64_t dSize)248 static constexpr bool isDynamic(int64_t dSize) { return dSize < 0; } isDynamicStrideOrOffset(int64_t dStrideOrOffset)249 static constexpr bool isDynamicStrideOrOffset(int64_t dStrideOrOffset) { 250 return dStrideOrOffset == kDynamicStrideOrOffset; 251 } 252 }; 253 254 /// Vector types represent multi-dimensional SIMD vectors, and have a fixed 255 /// known constant shape with one or more dimension. 256 class VectorType 257 : public Type::TypeBase<VectorType, ShapedType, detail::VectorTypeStorage> { 258 public: 259 using Base::Base; 260 261 /// Get or create a new VectorType of the provided shape and element type. 262 /// Assumes the arguments define a well-formed VectorType. 263 static VectorType get(ArrayRef<int64_t> shape, Type elementType); 264 265 /// Get or create a new VectorType of the provided shape and element type 266 /// declared at the given, potentially unknown, location. If the VectorType 267 /// defined by the arguments would be ill-formed, emit errors and return 268 /// nullptr-wrapping type. 269 static VectorType getChecked(ArrayRef<int64_t> shape, Type elementType, 270 Location location); 271 272 /// Verify the construction of a vector type. 273 static LogicalResult verifyConstructionInvariants(Optional<Location> loc, 274 MLIRContext *context, 275 ArrayRef<int64_t> shape, 276 Type elementType); 277 278 /// Returns true of the given type can be used as an element of a vector type. 279 /// In particular, vectors can consist of integer or float primitives. isValidElementType(Type t)280 static bool isValidElementType(Type t) { return t.isIntOrFloat(); } 281 282 ArrayRef<int64_t> getShape() const; 283 284 /// Methods for support type inquiry through isa, cast, and dyn_cast. kindof(unsigned kind)285 static bool kindof(unsigned kind) { return kind == StandardTypes::Vector; } 286 }; 287 288 /// Tensor types represent multi-dimensional arrays, and have two variants: 289 /// RankedTensorType and UnrankedTensorType. 290 class TensorType : public ShapedType { 291 public: 292 using ShapedType::ShapedType; 293 294 /// Return true if the specified element type is ok in a tensor. isValidElementType(Type type)295 static bool isValidElementType(Type type) { 296 // Note: Non standard/builtin types are allowed to exist within tensor 297 // types. Dialects are expected to verify that tensor types have a valid 298 // element type within that dialect. 299 return type.isIntOrFloat() || type.isa<ComplexType>() || 300 type.isa<VectorType>() || type.isa<OpaqueType>() || 301 (type.getKind() > Type::Kind::LAST_STANDARD_TYPE); 302 } 303 304 /// Methods for support type inquiry through isa, cast, and dyn_cast. classof(Type type)305 static bool classof(Type type) { 306 return type.getKind() == StandardTypes::RankedTensor || 307 type.getKind() == StandardTypes::UnrankedTensor; 308 } 309 }; 310 311 /// Ranked tensor types represent multi-dimensional arrays that have a shape 312 /// with a fixed number of dimensions. Each shape element can be a positive 313 /// integer or unknown (represented -1). 314 class RankedTensorType 315 : public Type::TypeBase<RankedTensorType, TensorType, 316 detail::RankedTensorTypeStorage> { 317 public: 318 using Base::Base; 319 320 /// Get or create a new RankedTensorType of the provided shape and element 321 /// type. Assumes the arguments define a well-formed type. 322 static RankedTensorType get(ArrayRef<int64_t> shape, Type elementType); 323 324 /// Get or create a new RankedTensorType of the provided shape and element 325 /// type declared at the given, potentially unknown, location. If the 326 /// RankedTensorType defined by the arguments would be ill-formed, emit errors 327 /// and return a nullptr-wrapping type. 328 static RankedTensorType getChecked(ArrayRef<int64_t> shape, Type elementType, 329 Location location); 330 331 /// Verify the construction of a ranked tensor type. 332 static LogicalResult verifyConstructionInvariants(Optional<Location> loc, 333 MLIRContext *context, 334 ArrayRef<int64_t> shape, 335 Type elementType); 336 337 ArrayRef<int64_t> getShape() const; 338 kindof(unsigned kind)339 static bool kindof(unsigned kind) { 340 return kind == StandardTypes::RankedTensor; 341 } 342 }; 343 344 /// Unranked tensor types represent multi-dimensional arrays that have an 345 /// unknown shape. 346 class UnrankedTensorType 347 : public Type::TypeBase<UnrankedTensorType, TensorType, 348 detail::UnrankedTensorTypeStorage> { 349 public: 350 using Base::Base; 351 352 /// Get or create a new UnrankedTensorType of the provided shape and element 353 /// type. Assumes the arguments define a well-formed type. 354 static UnrankedTensorType get(Type elementType); 355 356 /// Get or create a new UnrankedTensorType of the provided shape and element 357 /// type declared at the given, potentially unknown, location. If the 358 /// UnrankedTensorType defined by the arguments would be ill-formed, emit 359 /// errors and return a nullptr-wrapping type. 360 static UnrankedTensorType getChecked(Type elementType, Location location); 361 362 /// Verify the construction of a unranked tensor type. 363 static LogicalResult verifyConstructionInvariants(Optional<Location> loc, 364 MLIRContext *context, 365 Type elementType); 366 getShape()367 ArrayRef<int64_t> getShape() const { return llvm::None; } 368 kindof(unsigned kind)369 static bool kindof(unsigned kind) { 370 return kind == StandardTypes::UnrankedTensor; 371 } 372 }; 373 374 /// Base MemRef for Ranked and Unranked variants 375 class BaseMemRefType : public ShapedType { 376 public: 377 using ShapedType::ShapedType; 378 379 /// Methods for support type inquiry through isa, cast, and dyn_cast. classof(Type type)380 static bool classof(Type type) { 381 return type.getKind() == StandardTypes::MemRef || 382 type.getKind() == StandardTypes::UnrankedMemRef; 383 } 384 }; 385 386 /// MemRef types represent a region of memory that have a shape with a fixed 387 /// number of dimensions. Each shape element can be a non-negative integer or 388 /// unknown (represented by any negative integer). MemRef types also have an 389 /// affine map composition, represented as an array AffineMap pointers. 390 class MemRefType : public Type::TypeBase<MemRefType, BaseMemRefType, 391 detail::MemRefTypeStorage> { 392 public: 393 using Base::Base; 394 395 /// Get or create a new MemRefType based on shape, element type, affine 396 /// map composition, and memory space. Assumes the arguments define a 397 /// well-formed MemRef type. Use getChecked to gracefully handle MemRefType 398 /// construction failures. 399 static MemRefType get(ArrayRef<int64_t> shape, Type elementType, 400 ArrayRef<AffineMap> affineMapComposition = {}, 401 unsigned memorySpace = 0); 402 403 /// Get or create a new MemRefType based on shape, element type, affine 404 /// map composition, and memory space declared at the given location. 405 /// If the location is unknown, the last argument should be an instance of 406 /// UnknownLoc. If the MemRefType defined by the arguments would be 407 /// ill-formed, emits errors (to the handler registered with the context or to 408 /// the error stream) and returns nullptr. 409 static MemRefType getChecked(ArrayRef<int64_t> shape, Type elementType, 410 ArrayRef<AffineMap> affineMapComposition, 411 unsigned memorySpace, Location location); 412 413 ArrayRef<int64_t> getShape() const; 414 415 /// Returns an array of affine map pointers representing the memref affine 416 /// map composition. 417 ArrayRef<AffineMap> getAffineMaps() const; 418 419 /// Returns the memory space in which data referred to by this memref resides. 420 unsigned getMemorySpace() const; 421 422 // TODO(ntv): merge these two special values in a single one used everywhere. 423 // Unfortunately, uses of `-1` have crept deep into the codebase now and are 424 // hard to track. 425 static constexpr int64_t kDynamicSize = -1; getDynamicStrideOrOffset()426 static int64_t getDynamicStrideOrOffset() { 427 return ShapedType::kDynamicStrideOrOffset; 428 } 429 kindof(unsigned kind)430 static bool kindof(unsigned kind) { return kind == StandardTypes::MemRef; } 431 432 private: 433 /// Get or create a new MemRefType defined by the arguments. If the resulting 434 /// type would be ill-formed, return nullptr. If the location is provided, 435 /// emit detailed error messages. 436 static MemRefType getImpl(ArrayRef<int64_t> shape, Type elementType, 437 ArrayRef<AffineMap> affineMapComposition, 438 unsigned memorySpace, Optional<Location> location); 439 using Base::getImpl; 440 }; 441 442 /// Unranked MemRef type represent multi-dimensional MemRefs that 443 /// have an unknown rank. 444 class UnrankedMemRefType 445 : public Type::TypeBase<UnrankedMemRefType, BaseMemRefType, 446 detail::UnrankedMemRefTypeStorage> { 447 public: 448 using Base::Base; 449 450 /// Get or create a new UnrankedMemRefType of the provided element 451 /// type and memory space 452 static UnrankedMemRefType get(Type elementType, unsigned memorySpace); 453 454 /// Get or create a new UnrankedMemRefType of the provided element 455 /// type and memory space declared at the given, potentially unknown, 456 /// location. If the UnrankedMemRefType defined by the arguments would be 457 /// ill-formed, emit errors and return a nullptr-wrapping type. 458 static UnrankedMemRefType getChecked(Type elementType, unsigned memorySpace, 459 Location location); 460 461 /// Verify the construction of a unranked memref type. 462 static LogicalResult verifyConstructionInvariants(Optional<Location> loc, 463 MLIRContext *context, 464 Type elementType, 465 unsigned memorySpace); 466 getShape()467 ArrayRef<int64_t> getShape() const { return llvm::None; } 468 469 /// Returns the memory space in which data referred to by this memref resides. 470 unsigned getMemorySpace() const; kindof(unsigned kind)471 static bool kindof(unsigned kind) { 472 return kind == StandardTypes::UnrankedMemRef; 473 } 474 }; 475 476 /// Tuple types represent a collection of other types. Note: This type merely 477 /// provides a common mechanism for representing tuples in MLIR. It is up to 478 /// dialect authors to provides operations for manipulating them, e.g. 479 /// extract_tuple_element. When possible, users should prefer multi-result 480 /// operations in the place of tuples. 481 class TupleType 482 : public Type::TypeBase<TupleType, Type, detail::TupleTypeStorage> { 483 public: 484 using Base::Base; 485 486 /// Get or create a new TupleType with the provided element types. Assumes the 487 /// arguments define a well-formed type. 488 static TupleType get(ArrayRef<Type> elementTypes, MLIRContext *context); 489 490 /// Get or create an empty tuple type. get(MLIRContext * context)491 static TupleType get(MLIRContext *context) { return get({}, context); } 492 493 /// Return the elements types for this tuple. 494 ArrayRef<Type> getTypes() const; 495 496 /// Accumulate the types contained in this tuple and tuples nested within it. 497 /// Note that this only flattens nested tuples, not any other container type, 498 /// e.g. a tuple<i32, tensor<i32>, tuple<f32, tuple<i64>>> is flattened to 499 /// (i32, tensor<i32>, f32, i64) 500 void getFlattenedTypes(SmallVectorImpl<Type> &types); 501 502 /// Return the number of held types. 503 size_t size() const; 504 505 /// Iterate over the held elements. 506 using iterator = ArrayRef<Type>::iterator; begin()507 iterator begin() const { return getTypes().begin(); } end()508 iterator end() const { return getTypes().end(); } 509 510 /// Return the element type at index 'index'. getType(size_t index)511 Type getType(size_t index) const { 512 assert(index < size() && "invalid index for tuple type"); 513 return getTypes()[index]; 514 } 515 kindof(unsigned kind)516 static bool kindof(unsigned kind) { return kind == StandardTypes::Tuple; } 517 }; 518 519 /// NoneType is a unit type, i.e. a type with exactly one possible value, where 520 /// its value does not have a defined dynamic representation. 521 class NoneType : public Type::TypeBase<NoneType, Type> { 522 public: 523 using Base::Base; 524 525 /// Get an instance of the NoneType. 526 static NoneType get(MLIRContext *context); 527 kindof(unsigned kind)528 static bool kindof(unsigned kind) { return kind == StandardTypes::None; } 529 }; 530 531 /// Returns the strides of the MemRef if the layout map is in strided form. 532 /// MemRefs with layout maps in strided form include: 533 /// 1. empty or identity layout map, in which case the stride information is 534 /// the canonical form computed from sizes; 535 /// 2. single affine map layout of the form `K + k0 * d0 + ... kn * dn`, 536 /// where K and ki's are constants or symbols. 537 /// 538 /// A stride specification is a list of integer values that are either static 539 /// or dynamic (encoded with getDynamicStrideOrOffset()). Strides encode the 540 /// distance in the number of elements between successive entries along a 541 /// particular dimension. For example, `memref<42x16xf32, (64 * d0 + d1)>` 542 /// specifies a view into a non-contiguous memory region of `42` by `16` `f32` 543 /// elements in which the distance between two consecutive elements along the 544 /// outer dimension is `1` and the distance between two consecutive elements 545 /// along the inner dimension is `64`. 546 /// 547 /// If a simple strided form cannot be extracted from the composition of the 548 /// layout map, returns llvm::None. 549 /// 550 /// The convention is that the strides for dimensions d0, .. dn appear in 551 /// order to make indexing intuitive into the result. 552 LogicalResult getStridesAndOffset(MemRefType t, 553 SmallVectorImpl<int64_t> &strides, 554 int64_t &offset); 555 LogicalResult getStridesAndOffset(MemRefType t, 556 SmallVectorImpl<AffineExpr> &strides, 557 AffineExpr &offset); 558 559 /// Given a list of strides (in which MemRefType::getDynamicStrideOrOffset() 560 /// represents a dynamic value), return the single result AffineMap which 561 /// represents the linearized strided layout map. Dimensions correspond to the 562 /// offset followed by the strides in order. Symbols are inserted for each 563 /// dynamic dimension in order. A stride cannot take value `0`. 564 /// 565 /// Examples: 566 /// ========= 567 /// 568 /// 1. For offset: 0 strides: ?, ?, 1 return 569 /// (i, j, k)[M, N]->(M * i + N * j + k) 570 /// 571 /// 2. For offset: 3 strides: 32, ?, 16 return 572 /// (i, j, k)[M]->(3 + 32 * i + M * j + 16 * k) 573 /// 574 /// 3. For offset: ? strides: ?, ?, ? return 575 /// (i, j, k)[off, M, N, P]->(off + M * i + N * j + P * k) 576 AffineMap makeStridedLinearLayoutMap(ArrayRef<int64_t> strides, int64_t offset, 577 MLIRContext *context); 578 579 /// Return a version of `t` with identity layout if it can be determined 580 /// statically that the layout is the canonical contiguous strided layout. 581 /// Otherwise pass `t`'s layout into `simplifyAffineMap` and return a copy of 582 /// `t` with simplifed layout. 583 MemRefType canonicalizeStridedLayout(MemRefType t); 584 585 /// Return true if the layout for `t` is compatible with strided semantics. 586 bool isStrided(MemRefType t); 587 588 } // end namespace mlir 589 590 #endif // MLIR_IR_STANDARDTYPES_H 591