1 //===- StandardTypes.h - MLIR Standard Type Classes -------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_IR_STANDARDTYPES_H
10 #define MLIR_IR_STANDARDTYPES_H
11 
12 #include "mlir/IR/Types.h"
13 
14 namespace llvm {
15 struct fltSemantics;
16 } // namespace llvm
17 
18 namespace mlir {
19 class AffineExpr;
20 class AffineMap;
21 class FloatType;
22 class IndexType;
23 class IntegerType;
24 class Location;
25 class MLIRContext;
26 
namespace detail {

// Storage classes backing the parametric types declared in this header. They
// are only forward-declared here.
struct ComplexTypeStorage;
struct IntegerTypeStorage;
struct MemRefTypeStorage;
struct RankedTensorTypeStorage;
struct ShapedTypeStorage;
struct TupleTypeStorage;
struct UnrankedMemRefTypeStorage;
struct UnrankedTensorTypeStorage;
struct VectorTypeStorage;

} // namespace detail
40 
/// Kind values identifying each standard type declared in this header. The
/// numbering starts at Type::Kind::FIRST_STANDARD_TYPE; each type's
/// `kindof`/`classof` hook compares a Type's kind against these values.
namespace StandardTypes {
enum Kind {
  // Floating point. These kinds must stay contiguous because FloatType::kindof
  // tests the [FIRST_FLOATING_POINT_TYPE, LAST_FLOATING_POINT_TYPE] range.
  BF16 = Type::Kind::FIRST_STANDARD_TYPE,
  F16,
  F32,
  F64,
  // Range markers that alias existing enumerators. They do not consume new
  // values, so `Index` below is numbered F64 + 1.
  FIRST_FLOATING_POINT_TYPE = BF16,
  LAST_FLOATING_POINT_TYPE = F64,

  // Target pointer sized integer, used (e.g.) in affine mappings.
  Index,

  // Derived types.
  Integer,
  Vector,
  RankedTensor,
  UnrankedTensor,
  MemRef,
  UnrankedMemRef,
  Complex,
  Tuple,
  None,
};

} // namespace StandardTypes
67 
68 //===----------------------------------------------------------------------===//
69 // ComplexType
70 //===----------------------------------------------------------------------===//
71 
72 /// The 'complex' type represents a complex number with a parameterized element
73 /// type, which is composed of a real and imaginary value of that element type.
74 ///
75 /// The element must be a floating point or integer scalar type.
76 ///
77 class ComplexType
78     : public Type::TypeBase<ComplexType, Type, detail::ComplexTypeStorage> {
79 public:
80   using Base::Base;
81 
82   /// Get or create a ComplexType with the provided element type.
83   static ComplexType get(Type elementType);
84 
85   /// Get or create a ComplexType with the provided element type.  This emits
86   /// and error at the specified location and returns null if the element type
87   /// isn't supported.
88   static ComplexType getChecked(Type elementType, Location location);
89 
90   /// Verify the construction of an integer type.
91   static LogicalResult verifyConstructionInvariants(Location loc,
92                                                     Type elementType);
93 
94   Type getElementType();
95 
kindof(unsigned kind)96   static bool kindof(unsigned kind) { return kind == StandardTypes::Complex; }
97 };
98 
99 //===----------------------------------------------------------------------===//
100 // IndexType
101 //===----------------------------------------------------------------------===//
102 
103 /// Index is a special integer-like type with unknown platform-dependent bit
104 /// width.
105 class IndexType : public Type::TypeBase<IndexType, Type, TypeStorage> {
106 public:
107   using Base::Base;
108 
109   /// Get an instance of the IndexType.
110   static IndexType get(MLIRContext *context);
111 
112   /// Support method to enable LLVM-style type casting.
kindof(unsigned kind)113   static bool kindof(unsigned kind) { return kind == StandardTypes::Index; }
114 
115   /// Storage bit width used for IndexType by internal compiler data structures.
116   static constexpr unsigned kInternalStorageBitWidth = 64;
117 };
118 
119 //===----------------------------------------------------------------------===//
120 // IntegerType
121 //===----------------------------------------------------------------------===//
122 
123 /// Integer types can have arbitrary bitwidth up to a large fixed limit.
124 class IntegerType
125     : public Type::TypeBase<IntegerType, Type, detail::IntegerTypeStorage> {
126 public:
127   using Base::Base;
128 
129   /// Signedness semantics.
130   enum SignednessSemantics {
131     Signless, /// No signedness semantics
132     Signed,   /// Signed integer
133     Unsigned, /// Unsigned integer
134   };
135 
136   /// Get or create a new IntegerType of the given width within the context.
137   /// The created IntegerType is signless (i.e., no signedness semantics).
138   /// Assume the width is within the allowed range and assert on failures. Use
139   /// getChecked to handle failures gracefully.
140   static IntegerType get(unsigned width, MLIRContext *context);
141 
142   /// Get or create a new IntegerType of the given width within the context.
143   /// The created IntegerType has signedness semantics as indicated via
144   /// `signedness`. Assume the width is within the allowed range and assert on
145   /// failures. Use getChecked to handle failures gracefully.
146   static IntegerType get(unsigned width, SignednessSemantics signedness,
147                          MLIRContext *context);
148 
149   /// Get or create a new IntegerType of the given width within the context,
150   /// defined at the given, potentially unknown, location.  The created
151   /// IntegerType is signless (i.e., no signedness semantics). If the width is
152   /// outside the allowed range, emit errors and return a null type.
153   static IntegerType getChecked(unsigned width, Location location);
154 
155   /// Get or create a new IntegerType of the given width within the context,
156   /// defined at the given, potentially unknown, location. The created
157   /// IntegerType has signedness semantics as indicated via `signedness`. If the
158   /// width is outside the allowed range, emit errors and return a null type.
159   static IntegerType getChecked(unsigned width, SignednessSemantics signedness,
160                                 Location location);
161 
162   /// Verify the construction of an integer type.
163   static LogicalResult
164   verifyConstructionInvariants(Location loc, unsigned width,
165                                SignednessSemantics signedness);
166 
167   /// Return the bitwidth of this integer type.
168   unsigned getWidth() const;
169 
170   /// Return the signedness semantics of this integer type.
171   SignednessSemantics getSignedness() const;
172 
173   /// Return true if this is a signless integer type.
isSignless()174   bool isSignless() const { return getSignedness() == Signless; }
175   /// Return true if this is a signed integer type.
isSigned()176   bool isSigned() const { return getSignedness() == Signed; }
177   /// Return true if this is an unsigned integer type.
isUnsigned()178   bool isUnsigned() const { return getSignedness() == Unsigned; }
179 
180   /// Methods for support type inquiry through isa, cast, and dyn_cast.
kindof(unsigned kind)181   static bool kindof(unsigned kind) { return kind == StandardTypes::Integer; }
182 
183   /// Integer representation maximal bitwidth.
184   static constexpr unsigned kMaxWidth = 4096;
185 };
186 
187 //===----------------------------------------------------------------------===//
188 // FloatType
189 //===----------------------------------------------------------------------===//
190 
191 class FloatType : public Type::TypeBase<FloatType, Type, TypeStorage> {
192 public:
193   using Base::Base;
194 
195   static FloatType get(StandardTypes::Kind kind, MLIRContext *context);
196 
197   // Convenience factories.
getBF16(MLIRContext * ctx)198   static FloatType getBF16(MLIRContext *ctx) {
199     return get(StandardTypes::BF16, ctx);
200   }
getF16(MLIRContext * ctx)201   static FloatType getF16(MLIRContext *ctx) {
202     return get(StandardTypes::F16, ctx);
203   }
getF32(MLIRContext * ctx)204   static FloatType getF32(MLIRContext *ctx) {
205     return get(StandardTypes::F32, ctx);
206   }
getF64(MLIRContext * ctx)207   static FloatType getF64(MLIRContext *ctx) {
208     return get(StandardTypes::F64, ctx);
209   }
210 
211   /// Methods for support type inquiry through isa, cast, and dyn_cast.
kindof(unsigned kind)212   static bool kindof(unsigned kind) {
213     return kind >= StandardTypes::FIRST_FLOATING_POINT_TYPE &&
214            kind <= StandardTypes::LAST_FLOATING_POINT_TYPE;
215   }
216 
217   /// Return the bitwidth of this float type.
218   unsigned getWidth();
219 
220   /// Return the floating semantics of this float type.
221   const llvm::fltSemantics &getFloatSemantics();
222 };
223 
224 //===----------------------------------------------------------------------===//
225 // NoneType
226 //===----------------------------------------------------------------------===//
227 
228 /// NoneType is a unit type, i.e. a type with exactly one possible value, where
229 /// its value does not have a defined dynamic representation.
230 class NoneType : public Type::TypeBase<NoneType, Type, TypeStorage> {
231 public:
232   using Base::Base;
233 
234   /// Get an instance of the NoneType.
235   static NoneType get(MLIRContext *context);
236 
kindof(unsigned kind)237   static bool kindof(unsigned kind) { return kind == StandardTypes::None; }
238 };
239 
240 //===----------------------------------------------------------------------===//
241 // ShapedType
242 //===----------------------------------------------------------------------===//
243 
244 /// This is a common base class between Vector, UnrankedTensor, RankedTensor,
245 /// and MemRef types because they share behavior and semantics around shape,
246 /// rank, and fixed element type. Any type with these semantics should inherit
247 /// from ShapedType.
248 class ShapedType : public Type {
249 public:
250   using ImplType = detail::ShapedTypeStorage;
251   using Type::Type;
252 
253   // TODO: merge these two special values in a single one used everywhere.
254   // Unfortunately, uses of `-1` have crept deep into the codebase now and are
255   // hard to track.
256   static constexpr int64_t kDynamicSize = -1;
257   static constexpr int64_t kDynamicStrideOrOffset =
258       std::numeric_limits<int64_t>::min();
259 
260   /// Return the element type.
261   Type getElementType() const;
262 
263   /// If an element type is an integer or a float, return its width. Otherwise,
264   /// abort.
265   unsigned getElementTypeBitWidth() const;
266 
267   /// If it has static shape, return the number of elements. Otherwise, abort.
268   int64_t getNumElements() const;
269 
270   /// If this is a ranked type, return the rank. Otherwise, abort.
271   int64_t getRank() const;
272 
273   /// Whether or not this is a ranked type. Memrefs, vectors and ranked tensors
274   /// have a rank, while unranked tensors do not.
275   bool hasRank() const;
276 
277   /// If this is a ranked type, return the shape. Otherwise, abort.
278   ArrayRef<int64_t> getShape() const;
279 
280   /// If this is unranked type or any dimension has unknown size (<0), it
281   /// doesn't have static shape. If all dimensions have known size (>= 0), it
282   /// has static shape.
283   bool hasStaticShape() const;
284 
285   /// If this has a static shape and the shape is equal to `shape` return true.
286   bool hasStaticShape(ArrayRef<int64_t> shape) const;
287 
288   /// If this is a ranked type, return the number of dimensions with dynamic
289   /// size. Otherwise, abort.
290   int64_t getNumDynamicDims() const;
291 
292   /// If this is ranked type, return the size of the specified dimension.
293   /// Otherwise, abort.
294   int64_t getDimSize(unsigned idx) const;
295 
296   /// Returns true if this dimension has a dynamic size (for ranked types);
297   /// aborts for unranked types.
298   bool isDynamicDim(unsigned idx) const;
299 
300   /// Returns the position of the dynamic dimension relative to just the dynamic
301   /// dimensions, given its `index` within the shape.
302   unsigned getDynamicDimIndex(unsigned index) const;
303 
304   /// Get the total amount of bits occupied by a value of this type.  This does
305   /// not take into account any memory layout or widening constraints, e.g. a
306   /// vector<3xi57> is reported to occupy 3x57=171 bit, even though in practice
307   /// it will likely be stored as in a 4xi64 vector register.  Fail an assertion
308   /// if the size cannot be computed statically, i.e. if the type has a dynamic
309   /// shape or if its elemental type does not have a known bit width.
310   int64_t getSizeInBits() const;
311 
312   /// Methods for support type inquiry through isa, cast, and dyn_cast.
classof(Type type)313   static bool classof(Type type) {
314     return type.getKind() == StandardTypes::Vector ||
315            type.getKind() == StandardTypes::RankedTensor ||
316            type.getKind() == StandardTypes::UnrankedTensor ||
317            type.getKind() == StandardTypes::UnrankedMemRef ||
318            type.getKind() == StandardTypes::MemRef;
319   }
320 
321   /// Whether the given dimension size indicates a dynamic dimension.
isDynamic(int64_t dSize)322   static constexpr bool isDynamic(int64_t dSize) {
323     return dSize == kDynamicSize;
324   }
isDynamicStrideOrOffset(int64_t dStrideOrOffset)325   static constexpr bool isDynamicStrideOrOffset(int64_t dStrideOrOffset) {
326     return dStrideOrOffset == kDynamicStrideOrOffset;
327   }
328 };
329 
330 //===----------------------------------------------------------------------===//
331 // VectorType
332 //===----------------------------------------------------------------------===//
333 
334 /// Vector types represent multi-dimensional SIMD vectors, and have a fixed
335 /// known constant shape with one or more dimension.
336 class VectorType
337     : public Type::TypeBase<VectorType, ShapedType, detail::VectorTypeStorage> {
338 public:
339   using Base::Base;
340 
341   /// Get or create a new VectorType of the provided shape and element type.
342   /// Assumes the arguments define a well-formed VectorType.
343   static VectorType get(ArrayRef<int64_t> shape, Type elementType);
344 
345   /// Get or create a new VectorType of the provided shape and element type
346   /// declared at the given, potentially unknown, location.  If the VectorType
347   /// defined by the arguments would be ill-formed, emit errors and return
348   /// nullptr-wrapping type.
349   static VectorType getChecked(ArrayRef<int64_t> shape, Type elementType,
350                                Location location);
351 
352   /// Verify the construction of a vector type.
353   static LogicalResult verifyConstructionInvariants(Location loc,
354                                                     ArrayRef<int64_t> shape,
355                                                     Type elementType);
356 
357   /// Returns true of the given type can be used as an element of a vector type.
358   /// In particular, vectors can consist of integer or float primitives.
isValidElementType(Type t)359   static bool isValidElementType(Type t) {
360     return t.isa<IntegerType, FloatType>();
361   }
362 
363   ArrayRef<int64_t> getShape() const;
364 
365   /// Methods for support type inquiry through isa, cast, and dyn_cast.
kindof(unsigned kind)366   static bool kindof(unsigned kind) { return kind == StandardTypes::Vector; }
367 };
368 
369 //===----------------------------------------------------------------------===//
370 // TensorType
371 //===----------------------------------------------------------------------===//
372 
373 /// Tensor types represent multi-dimensional arrays, and have two variants:
374 /// RankedTensorType and UnrankedTensorType.
375 class TensorType : public ShapedType {
376 public:
377   using ShapedType::ShapedType;
378 
379   /// Return true if the specified element type is ok in a tensor.
isValidElementType(Type type)380   static bool isValidElementType(Type type) {
381     // Note: Non standard/builtin types are allowed to exist within tensor
382     // types. Dialects are expected to verify that tensor types have a valid
383     // element type within that dialect.
384     return type.isa<ComplexType, FloatType, IntegerType, OpaqueType, VectorType,
385                     IndexType>() ||
386            (type.getKind() > Type::Kind::LAST_STANDARD_TYPE);
387   }
388 
389   /// Methods for support type inquiry through isa, cast, and dyn_cast.
classof(Type type)390   static bool classof(Type type) {
391     return type.getKind() == StandardTypes::RankedTensor ||
392            type.getKind() == StandardTypes::UnrankedTensor;
393   }
394 };
395 
396 //===----------------------------------------------------------------------===//
397 // RankedTensorType
398 
399 /// Ranked tensor types represent multi-dimensional arrays that have a shape
400 /// with a fixed number of dimensions. Each shape element can be a non-negative
401 /// integer or unknown (represented by -1).
402 class RankedTensorType
403     : public Type::TypeBase<RankedTensorType, TensorType,
404                             detail::RankedTensorTypeStorage> {
405 public:
406   using Base::Base;
407 
408   /// Get or create a new RankedTensorType of the provided shape and element
409   /// type. Assumes the arguments define a well-formed type.
410   static RankedTensorType get(ArrayRef<int64_t> shape, Type elementType);
411 
412   /// Get or create a new RankedTensorType of the provided shape and element
413   /// type declared at the given, potentially unknown, location.  If the
414   /// RankedTensorType defined by the arguments would be ill-formed, emit errors
415   /// and return a nullptr-wrapping type.
416   static RankedTensorType getChecked(ArrayRef<int64_t> shape, Type elementType,
417                                      Location location);
418 
419   /// Verify the construction of a ranked tensor type.
420   static LogicalResult verifyConstructionInvariants(Location loc,
421                                                     ArrayRef<int64_t> shape,
422                                                     Type elementType);
423 
424   ArrayRef<int64_t> getShape() const;
425 
kindof(unsigned kind)426   static bool kindof(unsigned kind) {
427     return kind == StandardTypes::RankedTensor;
428   }
429 };
430 
431 //===----------------------------------------------------------------------===//
432 // UnrankedTensorType
433 
434 /// Unranked tensor types represent multi-dimensional arrays that have an
435 /// unknown shape.
436 class UnrankedTensorType
437     : public Type::TypeBase<UnrankedTensorType, TensorType,
438                             detail::UnrankedTensorTypeStorage> {
439 public:
440   using Base::Base;
441 
442   /// Get or create a new UnrankedTensorType of the provided shape and element
443   /// type. Assumes the arguments define a well-formed type.
444   static UnrankedTensorType get(Type elementType);
445 
446   /// Get or create a new UnrankedTensorType of the provided shape and element
447   /// type declared at the given, potentially unknown, location.  If the
448   /// UnrankedTensorType defined by the arguments would be ill-formed, emit
449   /// errors and return a nullptr-wrapping type.
450   static UnrankedTensorType getChecked(Type elementType, Location location);
451 
452   /// Verify the construction of a unranked tensor type.
453   static LogicalResult verifyConstructionInvariants(Location loc,
454                                                     Type elementType);
455 
getShape()456   ArrayRef<int64_t> getShape() const { return llvm::None; }
457 
kindof(unsigned kind)458   static bool kindof(unsigned kind) {
459     return kind == StandardTypes::UnrankedTensor;
460   }
461 };
462 
463 //===----------------------------------------------------------------------===//
464 // BaseMemRefType
465 //===----------------------------------------------------------------------===//
466 
467 /// Base MemRef for Ranked and Unranked variants
468 class BaseMemRefType : public ShapedType {
469 public:
470   using ShapedType::ShapedType;
471 
472   /// Methods for support type inquiry through isa, cast, and dyn_cast.
classof(Type type)473   static bool classof(Type type) {
474     return type.getKind() == StandardTypes::MemRef ||
475            type.getKind() == StandardTypes::UnrankedMemRef;
476   }
477 };
478 
479 //===----------------------------------------------------------------------===//
480 // MemRefType
481 
482 /// MemRef types represent a region of memory that have a shape with a fixed
483 /// number of dimensions. Each shape element can be a non-negative integer or
484 /// unknown (represented by -1). MemRef types also have an affine map
485 /// composition, represented as an array AffineMap pointers.
486 class MemRefType : public Type::TypeBase<MemRefType, BaseMemRefType,
487                                          detail::MemRefTypeStorage> {
488 public:
489   /// This is a builder type that keeps local references to arguments. Arguments
490   /// that are passed into the builder must out-live the builder.
491   class Builder {
492   public:
493     // Build from another MemRefType.
Builder(MemRefType other)494     explicit Builder(MemRefType other)
495         : shape(other.getShape()), elementType(other.getElementType()),
496           affineMaps(other.getAffineMaps()),
497           memorySpace(other.getMemorySpace()) {}
498 
499     // Build from scratch.
Builder(ArrayRef<int64_t> shape,Type elementType)500     Builder(ArrayRef<int64_t> shape, Type elementType)
501         : shape(shape), elementType(elementType), affineMaps(), memorySpace(0) {
502     }
503 
setShape(ArrayRef<int64_t> newShape)504     Builder &setShape(ArrayRef<int64_t> newShape) {
505       shape = newShape;
506       return *this;
507     }
508 
setElementType(Type newElementType)509     Builder &setElementType(Type newElementType) {
510       elementType = newElementType;
511       return *this;
512     }
513 
setAffineMaps(ArrayRef<AffineMap> newAffineMaps)514     Builder &setAffineMaps(ArrayRef<AffineMap> newAffineMaps) {
515       affineMaps = newAffineMaps;
516       return *this;
517     }
518 
setMemorySpace(unsigned newMemorySpace)519     Builder &setMemorySpace(unsigned newMemorySpace) {
520       memorySpace = newMemorySpace;
521       return *this;
522     }
523 
MemRefType()524     operator MemRefType() {
525       return MemRefType::get(shape, elementType, affineMaps, memorySpace);
526     }
527 
528   private:
529     ArrayRef<int64_t> shape;
530     Type elementType;
531     ArrayRef<AffineMap> affineMaps;
532     unsigned memorySpace;
533   };
534 
535   using Base::Base;
536 
537   /// Get or create a new MemRefType based on shape, element type, affine
538   /// map composition, and memory space.  Assumes the arguments define a
539   /// well-formed MemRef type.  Use getChecked to gracefully handle MemRefType
540   /// construction failures.
541   static MemRefType get(ArrayRef<int64_t> shape, Type elementType,
542                         ArrayRef<AffineMap> affineMapComposition = {},
543                         unsigned memorySpace = 0);
544 
545   /// Get or create a new MemRefType based on shape, element type, affine
546   /// map composition, and memory space declared at the given location.
547   /// If the location is unknown, the last argument should be an instance of
548   /// UnknownLoc.  If the MemRefType defined by the arguments would be
549   /// ill-formed, emits errors (to the handler registered with the context or to
550   /// the error stream) and returns nullptr.
551   static MemRefType getChecked(ArrayRef<int64_t> shape, Type elementType,
552                                ArrayRef<AffineMap> affineMapComposition,
553                                unsigned memorySpace, Location location);
554 
555   ArrayRef<int64_t> getShape() const;
556 
557   /// Returns an array of affine map pointers representing the memref affine
558   /// map composition.
559   ArrayRef<AffineMap> getAffineMaps() const;
560 
561   /// Returns the memory space in which data referred to by this memref resides.
562   unsigned getMemorySpace() const;
563 
564   // TODO: merge these two special values in a single one used everywhere.
565   // Unfortunately, uses of `-1` have crept deep into the codebase now and are
566   // hard to track.
567   static constexpr int64_t kDynamicSize = -1;
getDynamicStrideOrOffset()568   static int64_t getDynamicStrideOrOffset() {
569     return ShapedType::kDynamicStrideOrOffset;
570   }
571 
kindof(unsigned kind)572   static bool kindof(unsigned kind) { return kind == StandardTypes::MemRef; }
573 
574 private:
575   /// Get or create a new MemRefType defined by the arguments.  If the resulting
576   /// type would be ill-formed, return nullptr.  If the location is provided,
577   /// emit detailed error messages.
578   static MemRefType getImpl(ArrayRef<int64_t> shape, Type elementType,
579                             ArrayRef<AffineMap> affineMapComposition,
580                             unsigned memorySpace, Optional<Location> location);
581   using Base::getImpl;
582 };
583 
584 //===----------------------------------------------------------------------===//
585 // UnrankedMemRefType
586 
587 /// Unranked MemRef type represent multi-dimensional MemRefs that
588 /// have an unknown rank.
589 class UnrankedMemRefType
590     : public Type::TypeBase<UnrankedMemRefType, BaseMemRefType,
591                             detail::UnrankedMemRefTypeStorage> {
592 public:
593   using Base::Base;
594 
595   /// Get or create a new UnrankedMemRefType of the provided element
596   /// type and memory space
597   static UnrankedMemRefType get(Type elementType, unsigned memorySpace);
598 
599   /// Get or create a new UnrankedMemRefType of the provided element
600   /// type and memory space declared at the given, potentially unknown,
601   /// location. If the UnrankedMemRefType defined by the arguments would be
602   /// ill-formed, emit errors and return a nullptr-wrapping type.
603   static UnrankedMemRefType getChecked(Type elementType, unsigned memorySpace,
604                                        Location location);
605 
606   /// Verify the construction of a unranked memref type.
607   static LogicalResult verifyConstructionInvariants(Location loc,
608                                                     Type elementType,
609                                                     unsigned memorySpace);
610 
getShape()611   ArrayRef<int64_t> getShape() const { return llvm::None; }
612 
613   /// Returns the memory space in which data referred to by this memref resides.
614   unsigned getMemorySpace() const;
kindof(unsigned kind)615   static bool kindof(unsigned kind) {
616     return kind == StandardTypes::UnrankedMemRef;
617   }
618 };
619 
620 //===----------------------------------------------------------------------===//
621 // TupleType
622 //===----------------------------------------------------------------------===//
623 
624 /// Tuple types represent a collection of other types. Note: This type merely
625 /// provides a common mechanism for representing tuples in MLIR. It is up to
626 /// dialect authors to provides operations for manipulating them, e.g.
627 /// extract_tuple_element. When possible, users should prefer multi-result
628 /// operations in the place of tuples.
629 class TupleType
630     : public Type::TypeBase<TupleType, Type, detail::TupleTypeStorage> {
631 public:
632   using Base::Base;
633 
634   /// Get or create a new TupleType with the provided element types. Assumes the
635   /// arguments define a well-formed type.
636   static TupleType get(ArrayRef<Type> elementTypes, MLIRContext *context);
637 
638   /// Get or create an empty tuple type.
get(MLIRContext * context)639   static TupleType get(MLIRContext *context) { return get({}, context); }
640 
641   /// Return the elements types for this tuple.
642   ArrayRef<Type> getTypes() const;
643 
644   /// Accumulate the types contained in this tuple and tuples nested within it.
645   /// Note that this only flattens nested tuples, not any other container type,
646   /// e.g. a tuple<i32, tensor<i32>, tuple<f32, tuple<i64>>> is flattened to
647   /// (i32, tensor<i32>, f32, i64)
648   void getFlattenedTypes(SmallVectorImpl<Type> &types);
649 
650   /// Return the number of held types.
651   size_t size() const;
652 
653   /// Iterate over the held elements.
654   using iterator = ArrayRef<Type>::iterator;
begin()655   iterator begin() const { return getTypes().begin(); }
end()656   iterator end() const { return getTypes().end(); }
657 
658   /// Return the element type at index 'index'.
getType(size_t index)659   Type getType(size_t index) const {
660     assert(index < size() && "invalid index for tuple type");
661     return getTypes()[index];
662   }
663 
kindof(unsigned kind)664   static bool kindof(unsigned kind) { return kind == StandardTypes::Tuple; }
665 };
666 
667 //===----------------------------------------------------------------------===//
668 // Type Utilities
669 //===----------------------------------------------------------------------===//
670 
671 /// Returns the strides of the MemRef if the layout map is in strided form.
672 /// MemRefs with layout maps in strided form include:
673 ///   1. empty or identity layout map, in which case the stride information is
674 ///      the canonical form computed from sizes;
675 ///   2. single affine map layout of the form `K + k0 * d0 + ... kn * dn`,
676 ///      where K and ki's are constants or symbols.
677 ///
678 /// A stride specification is a list of integer values that are either static
679 /// or dynamic (encoded with getDynamicStrideOrOffset()). Strides encode the
680 /// distance in the number of elements between successive entries along a
681 /// particular dimension. For example, `memref<42x16xf32, (64 * d0 + d1)>`
682 /// specifies a view into a non-contiguous memory region of `42` by `16` `f32`
683 /// elements in which the distance between two consecutive elements along the
684 /// outer dimension is `1` and the distance between two consecutive elements
685 /// along the inner dimension is `64`.
686 ///
687 /// If a simple strided form cannot be extracted from the composition of the
688 /// layout map, returns llvm::None.
689 ///
690 /// The convention is that the strides for dimensions d0, .. dn appear in
691 /// order to make indexing intuitive into the result.
692 LogicalResult getStridesAndOffset(MemRefType t,
693                                   SmallVectorImpl<int64_t> &strides,
694                                   int64_t &offset);
695 LogicalResult getStridesAndOffset(MemRefType t,
696                                   SmallVectorImpl<AffineExpr> &strides,
697                                   AffineExpr &offset);
698 
699 /// Given a list of strides (in which MemRefType::getDynamicStrideOrOffset()
700 /// represents a dynamic value), return the single result AffineMap which
701 /// represents the linearized strided layout map. Dimensions correspond to the
702 /// offset followed by the strides in order. Symbols are inserted for each
703 /// dynamic dimension in order. A stride cannot take value `0`.
704 ///
705 /// Examples:
706 /// =========
707 ///
708 ///   1. For offset: 0 strides: ?, ?, 1 return
709 ///         (i, j, k)[M, N]->(M * i + N * j + k)
710 ///
711 ///   2. For offset: 3 strides: 32, ?, 16 return
712 ///         (i, j, k)[M]->(3 + 32 * i + M * j + 16 * k)
713 ///
714 ///   3. For offset: ? strides: ?, ?, ? return
715 ///         (i, j, k)[off, M, N, P]->(off + M * i + N * j + P * k)
716 AffineMap makeStridedLinearLayoutMap(ArrayRef<int64_t> strides, int64_t offset,
717                                      MLIRContext *context);
718 
719 /// Return a version of `t` with identity layout if it can be determined
720 /// statically that the layout is the canonical contiguous strided layout.
721 /// Otherwise pass `t`'s layout into `simplifyAffineMap` and return a copy of
722 /// `t` with simplified layout.
723 MemRefType canonicalizeStridedLayout(MemRefType t);
724 
725 /// Return a version of `t` with a layout that has all dynamic offset and
726 /// strides. This is used to erase the static layout.
727 MemRefType eraseStridedLayout(MemRefType t);
728 
729 /// Given MemRef `sizes` that are either static or dynamic, returns the
730 /// canonical "contiguous" strides AffineExpr. Strides are multiplicative and
731 /// once a dynamic dimension is encountered, all canonical strides become
732 /// dynamic and need to be encoded with a different symbol.
733 /// For canonical strides expressions, the offset is always 0 and and fastest
734 /// varying stride is always `1`.
735 ///
736 /// Examples:
737 ///   - memref<3x4x5xf32> has canonical stride expression
738 ///         `20*exprs[0] + 5*exprs[1] + exprs[2]`.
739 ///   - memref<3x?x5xf32> has canonical stride expression
740 ///         `s0*exprs[0] + 5*exprs[1] + exprs[2]`.
741 ///   - memref<3x4x?xf32> has canonical stride expression
742 ///         `s1*exprs[0] + s0*exprs[1] + exprs[2]`.
743 AffineExpr makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
744                                           ArrayRef<AffineExpr> exprs,
745                                           MLIRContext *context);
746 
747 /// Return the result of makeCanonicalStrudedLayoutExpr for the common case
748 /// where `exprs` is {d0, d1, .., d_(sizes.size()-1)}
749 AffineExpr makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
750                                           MLIRContext *context);
751 
752 /// Return true if the layout for `t` is compatible with strided semantics.
753 bool isStrided(MemRefType t);
754 
755 } // end namespace mlir
756 
757 #endif // MLIR_IR_STANDARDTYPES_H
758