1 //===- BuiltinTypes.h - MLIR Builtin Type Classes ---------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_IR_BUILTINTYPES_H
10 #define MLIR_IR_BUILTINTYPES_H
11 
12 #include "mlir/IR/Types.h"
13 
14 namespace llvm {
15 struct fltSemantics;
16 } // namespace llvm
17 
18 namespace mlir {
19 class AffineExpr;
20 class AffineMap;
21 class FloatType;
22 class Identifier;
23 class IndexType;
24 class IntegerType;
25 class Location;
26 class MLIRContext;
27 class TypeRange;
28 
29 namespace detail {
30 
31 struct BaseMemRefTypeStorage;
32 struct MemRefTypeStorage;
33 struct RankedTensorTypeStorage;
34 struct ShapedTypeStorage;
35 struct UnrankedMemRefTypeStorage;
36 struct UnrankedTensorTypeStorage;
37 struct VectorTypeStorage;
38 
39 } // namespace detail
40 
41 //===----------------------------------------------------------------------===//
42 // FloatType
43 //===----------------------------------------------------------------------===//
44 
45 class FloatType : public Type {
46 public:
47   using Type::Type;
48 
49   // Convenience factories.
50   static FloatType getBF16(MLIRContext *ctx);
51   static FloatType getF16(MLIRContext *ctx);
52   static FloatType getF32(MLIRContext *ctx);
53   static FloatType getF64(MLIRContext *ctx);
54   static FloatType getF80(MLIRContext *ctx);
55   static FloatType getF128(MLIRContext *ctx);
56 
57   /// Methods for support type inquiry through isa, cast, and dyn_cast.
58   static bool classof(Type type);
59 
60   /// Return the bitwidth of this float type.
61   unsigned getWidth();
62 
63   /// Get or create a new FloatType with bitwidth scaled by `scale`.
64   /// Return null if the scaled element type cannot be represented.
65   FloatType scaleElementBitwidth(unsigned scale);
66 
67   /// Return the floating semantics of this float type.
68   const llvm::fltSemantics &getFloatSemantics();
69 };
70 
71 //===----------------------------------------------------------------------===//
72 // ShapedType
73 //===----------------------------------------------------------------------===//
74 
75 /// This is a common base class between Vector, UnrankedTensor, RankedTensor,
76 /// and MemRef types because they share behavior and semantics around shape,
77 /// rank, and fixed element type. Any type with these semantics should inherit
78 /// from ShapedType.
79 class ShapedType : public Type {
80 public:
81   using ImplType = detail::ShapedTypeStorage;
82   using Type::Type;
83 
84   // TODO: merge these two special values in a single one used everywhere.
85   // Unfortunately, uses of `-1` have crept deep into the codebase now and are
86   // hard to track.
87   static constexpr int64_t kDynamicSize = -1;
88   static constexpr int64_t kDynamicStrideOrOffset =
89       std::numeric_limits<int64_t>::min();
90 
91   /// Return the element type.
92   Type getElementType() const;
93 
94   /// If an element type is an integer or a float, return its width. Otherwise,
95   /// abort.
96   unsigned getElementTypeBitWidth() const;
97 
98   /// If it has static shape, return the number of elements. Otherwise, abort.
99   int64_t getNumElements() const;
100 
101   /// If this is a ranked type, return the rank. Otherwise, abort.
102   int64_t getRank() const;
103 
104   /// Whether or not this is a ranked type. Memrefs, vectors and ranked tensors
105   /// have a rank, while unranked tensors do not.
106   bool hasRank() const;
107 
108   /// If this is a ranked type, return the shape. Otherwise, abort.
109   ArrayRef<int64_t> getShape() const;
110 
111   /// If this is unranked type or any dimension has unknown size (<0), it
112   /// doesn't have static shape. If all dimensions have known size (>= 0), it
113   /// has static shape.
114   bool hasStaticShape() const;
115 
116   /// If this has a static shape and the shape is equal to `shape` return true.
117   bool hasStaticShape(ArrayRef<int64_t> shape) const;
118 
119   /// If this is a ranked type, return the number of dimensions with dynamic
120   /// size. Otherwise, abort.
121   int64_t getNumDynamicDims() const;
122 
123   /// If this is ranked type, return the size of the specified dimension.
124   /// Otherwise, abort.
125   int64_t getDimSize(unsigned idx) const;
126 
127   /// Returns true if this dimension has a dynamic size (for ranked types);
128   /// aborts for unranked types.
129   bool isDynamicDim(unsigned idx) const;
130 
131   /// Returns the position of the dynamic dimension relative to just the dynamic
132   /// dimensions, given its `index` within the shape.
133   unsigned getDynamicDimIndex(unsigned index) const;
134 
135   /// Get the total amount of bits occupied by a value of this type.  This does
136   /// not take into account any memory layout or widening constraints, e.g. a
137   /// vector<3xi57> is reported to occupy 3x57=171 bit, even though in practice
138   /// it will likely be stored as in a 4xi64 vector register.  Fail an assertion
139   /// if the size cannot be computed statically, i.e. if the type has a dynamic
140   /// shape or if its elemental type does not have a known bit width.
141   int64_t getSizeInBits() const;
142 
143   /// Methods for support type inquiry through isa, cast, and dyn_cast.
144   static bool classof(Type type);
145 
146   /// Whether the given dimension size indicates a dynamic dimension.
isDynamic(int64_t dSize)147   static constexpr bool isDynamic(int64_t dSize) {
148     return dSize == kDynamicSize;
149   }
isDynamicStrideOrOffset(int64_t dStrideOrOffset)150   static constexpr bool isDynamicStrideOrOffset(int64_t dStrideOrOffset) {
151     return dStrideOrOffset == kDynamicStrideOrOffset;
152   }
153 };
154 
155 //===----------------------------------------------------------------------===//
156 // VectorType
157 //===----------------------------------------------------------------------===//
158 
159 /// Vector types represent multi-dimensional SIMD vectors, and have a fixed
160 /// known constant shape with one or more dimension.
161 class VectorType
162     : public Type::TypeBase<VectorType, ShapedType, detail::VectorTypeStorage> {
163 public:
164   using Base::Base;
165 
166   /// Get or create a new VectorType of the provided shape and element type.
167   /// Assumes the arguments define a well-formed VectorType.
168   static VectorType get(ArrayRef<int64_t> shape, Type elementType);
169 
170   /// Get or create a new VectorType of the provided shape and element type
171   /// declared at the given, potentially unknown, location.  If the VectorType
172   /// defined by the arguments would be ill-formed, emit errors and return
173   /// nullptr-wrapping type.
174   static VectorType getChecked(Location location, ArrayRef<int64_t> shape,
175                                Type elementType);
176 
177   /// Verify the construction of a vector type.
178   static LogicalResult verifyConstructionInvariants(Location loc,
179                                                     ArrayRef<int64_t> shape,
180                                                     Type elementType);
181 
182   /// Returns true of the given type can be used as an element of a vector type.
183   /// In particular, vectors can consist of integer or float primitives.
isValidElementType(Type t)184   static bool isValidElementType(Type t) {
185     return t.isa<IntegerType, FloatType>();
186   }
187 
188   ArrayRef<int64_t> getShape() const;
189 
190   /// Get or create a new VectorType with the same shape as `this` and an
191   /// element type of bitwidth scaled by `scale`.
192   /// Return null if the scaled element type cannot be represented.
193   VectorType scaleElementBitwidth(unsigned scale);
194 };
195 
196 //===----------------------------------------------------------------------===//
197 // TensorType
198 //===----------------------------------------------------------------------===//
199 
200 /// Tensor types represent multi-dimensional arrays, and have two variants:
201 /// RankedTensorType and UnrankedTensorType.
202 class TensorType : public ShapedType {
203 public:
204   using ShapedType::ShapedType;
205 
206   /// Return true if the specified element type is ok in a tensor.
207   static bool isValidElementType(Type type);
208 
209   /// Methods for support type inquiry through isa, cast, and dyn_cast.
210   static bool classof(Type type);
211 };
212 
213 //===----------------------------------------------------------------------===//
214 // RankedTensorType
215 
216 /// Ranked tensor types represent multi-dimensional arrays that have a shape
217 /// with a fixed number of dimensions. Each shape element can be a non-negative
218 /// integer or unknown (represented by -1).
219 class RankedTensorType
220     : public Type::TypeBase<RankedTensorType, TensorType,
221                             detail::RankedTensorTypeStorage> {
222 public:
223   using Base::Base;
224 
225   /// Get or create a new RankedTensorType of the provided shape and element
226   /// type. Assumes the arguments define a well-formed type.
227   static RankedTensorType get(ArrayRef<int64_t> shape, Type elementType);
228 
229   /// Get or create a new RankedTensorType of the provided shape and element
230   /// type declared at the given, potentially unknown, location.  If the
231   /// RankedTensorType defined by the arguments would be ill-formed, emit errors
232   /// and return a nullptr-wrapping type.
233   static RankedTensorType getChecked(Location location, ArrayRef<int64_t> shape,
234                                      Type elementType);
235 
236   /// Verify the construction of a ranked tensor type.
237   static LogicalResult verifyConstructionInvariants(Location loc,
238                                                     ArrayRef<int64_t> shape,
239                                                     Type elementType);
240 
241   ArrayRef<int64_t> getShape() const;
242 };
243 
244 //===----------------------------------------------------------------------===//
245 // UnrankedTensorType
246 
247 /// Unranked tensor types represent multi-dimensional arrays that have an
248 /// unknown shape.
249 class UnrankedTensorType
250     : public Type::TypeBase<UnrankedTensorType, TensorType,
251                             detail::UnrankedTensorTypeStorage> {
252 public:
253   using Base::Base;
254 
255   /// Get or create a new UnrankedTensorType of the provided shape and element
256   /// type. Assumes the arguments define a well-formed type.
257   static UnrankedTensorType get(Type elementType);
258 
259   /// Get or create a new UnrankedTensorType of the provided shape and element
260   /// type declared at the given, potentially unknown, location.  If the
261   /// UnrankedTensorType defined by the arguments would be ill-formed, emit
262   /// errors and return a nullptr-wrapping type.
263   static UnrankedTensorType getChecked(Location location, Type elementType);
264 
265   /// Verify the construction of a unranked tensor type.
266   static LogicalResult verifyConstructionInvariants(Location loc,
267                                                     Type elementType);
268 
getShape()269   ArrayRef<int64_t> getShape() const { return llvm::None; }
270 };
271 
272 //===----------------------------------------------------------------------===//
273 // BaseMemRefType
274 //===----------------------------------------------------------------------===//
275 
276 /// Base MemRef for Ranked and Unranked variants
277 class BaseMemRefType : public ShapedType {
278 public:
279   using ImplType = detail::BaseMemRefTypeStorage;
280   using ShapedType::ShapedType;
281 
282   /// Return true if the specified element type is ok in a memref.
283   static bool isValidElementType(Type type);
284 
285   /// Methods for support type inquiry through isa, cast, and dyn_cast.
286   static bool classof(Type type);
287 
288   /// Returns the memory space in which data referred to by this memref resides.
289   unsigned getMemorySpace() const;
290 };
291 
292 //===----------------------------------------------------------------------===//
293 // MemRefType
294 
295 /// MemRef types represent a region of memory that have a shape with a fixed
296 /// number of dimensions. Each shape element can be a non-negative integer or
297 /// unknown (represented by -1). MemRef types also have an affine map
298 /// composition, represented as an array AffineMap pointers.
299 class MemRefType : public Type::TypeBase<MemRefType, BaseMemRefType,
300                                          detail::MemRefTypeStorage> {
301 public:
302   /// This is a builder type that keeps local references to arguments. Arguments
303   /// that are passed into the builder must out-live the builder.
304   class Builder {
305   public:
306     // Build from another MemRefType.
Builder(MemRefType other)307     explicit Builder(MemRefType other)
308         : shape(other.getShape()), elementType(other.getElementType()),
309           affineMaps(other.getAffineMaps()),
310           memorySpace(other.getMemorySpace()) {}
311 
312     // Build from scratch.
Builder(ArrayRef<int64_t> shape,Type elementType)313     Builder(ArrayRef<int64_t> shape, Type elementType)
314         : shape(shape), elementType(elementType), affineMaps(), memorySpace(0) {
315     }
316 
setShape(ArrayRef<int64_t> newShape)317     Builder &setShape(ArrayRef<int64_t> newShape) {
318       shape = newShape;
319       return *this;
320     }
321 
setElementType(Type newElementType)322     Builder &setElementType(Type newElementType) {
323       elementType = newElementType;
324       return *this;
325     }
326 
setAffineMaps(ArrayRef<AffineMap> newAffineMaps)327     Builder &setAffineMaps(ArrayRef<AffineMap> newAffineMaps) {
328       affineMaps = newAffineMaps;
329       return *this;
330     }
331 
setMemorySpace(unsigned newMemorySpace)332     Builder &setMemorySpace(unsigned newMemorySpace) {
333       memorySpace = newMemorySpace;
334       return *this;
335     }
336 
MemRefType()337     operator MemRefType() {
338       return MemRefType::get(shape, elementType, affineMaps, memorySpace);
339     }
340 
341   private:
342     ArrayRef<int64_t> shape;
343     Type elementType;
344     ArrayRef<AffineMap> affineMaps;
345     unsigned memorySpace;
346   };
347 
348   using Base::Base;
349 
350   /// Get or create a new MemRefType based on shape, element type, affine
351   /// map composition, and memory space.  Assumes the arguments define a
352   /// well-formed MemRef type.  Use getChecked to gracefully handle MemRefType
353   /// construction failures.
354   static MemRefType get(ArrayRef<int64_t> shape, Type elementType,
355                         ArrayRef<AffineMap> affineMapComposition = {},
356                         unsigned memorySpace = 0);
357 
358   /// Get or create a new MemRefType based on shape, element type, affine
359   /// map composition, and memory space declared at the given location.
360   /// If the location is unknown, the last argument should be an instance of
361   /// UnknownLoc.  If the MemRefType defined by the arguments would be
362   /// ill-formed, emits errors (to the handler registered with the context or to
363   /// the error stream) and returns nullptr.
364   static MemRefType getChecked(Location location, ArrayRef<int64_t> shape,
365                                Type elementType,
366                                ArrayRef<AffineMap> affineMapComposition,
367                                unsigned memorySpace);
368 
369   ArrayRef<int64_t> getShape() const;
370 
371   /// Returns an array of affine map pointers representing the memref affine
372   /// map composition.
373   ArrayRef<AffineMap> getAffineMaps() const;
374 
375   // TODO: merge these two special values in a single one used everywhere.
376   // Unfortunately, uses of `-1` have crept deep into the codebase now and are
377   // hard to track.
getDynamicStrideOrOffset()378   static int64_t getDynamicStrideOrOffset() {
379     return ShapedType::kDynamicStrideOrOffset;
380   }
381 
382 private:
383   /// Get or create a new MemRefType defined by the arguments.  If the resulting
384   /// type would be ill-formed, return nullptr.  If the location is provided,
385   /// emit detailed error messages.
386   static MemRefType getImpl(ArrayRef<int64_t> shape, Type elementType,
387                             ArrayRef<AffineMap> affineMapComposition,
388                             unsigned memorySpace, Optional<Location> location);
389   using Base::getImpl;
390 };
391 
392 //===----------------------------------------------------------------------===//
393 // UnrankedMemRefType
394 
395 /// Unranked MemRef type represent multi-dimensional MemRefs that
396 /// have an unknown rank.
397 class UnrankedMemRefType
398     : public Type::TypeBase<UnrankedMemRefType, BaseMemRefType,
399                             detail::UnrankedMemRefTypeStorage> {
400 public:
401   using Base::Base;
402 
403   /// Get or create a new UnrankedMemRefType of the provided element
404   /// type and memory space
405   static UnrankedMemRefType get(Type elementType, unsigned memorySpace);
406 
407   /// Get or create a new UnrankedMemRefType of the provided element
408   /// type and memory space declared at the given, potentially unknown,
409   /// location. If the UnrankedMemRefType defined by the arguments would be
410   /// ill-formed, emit errors and return a nullptr-wrapping type.
411   static UnrankedMemRefType getChecked(Location location, Type elementType,
412                                        unsigned memorySpace);
413 
414   /// Verify the construction of a unranked memref type.
415   static LogicalResult verifyConstructionInvariants(Location loc,
416                                                     Type elementType,
417                                                     unsigned memorySpace);
418 
getShape()419   ArrayRef<int64_t> getShape() const { return llvm::None; }
420 };
421 } // end namespace mlir
422 
423 //===----------------------------------------------------------------------===//
424 // Tablegen Type Declarations
425 //===----------------------------------------------------------------------===//
426 
427 #define GET_TYPEDEF_CLASSES
428 #include "mlir/IR/BuiltinTypes.h.inc"
429 
430 //===----------------------------------------------------------------------===//
431 // Deferred Method Definitions
432 //===----------------------------------------------------------------------===//
433 
434 namespace mlir {
classof(Type type)435 inline bool BaseMemRefType::classof(Type type) {
436   return type.isa<MemRefType, UnrankedMemRefType>();
437 }
438 
isValidElementType(Type type)439 inline bool BaseMemRefType::isValidElementType(Type type) {
440   return type.isIntOrIndexOrFloat() || type.isa<ComplexType, VectorType>();
441 }
442 
classof(Type type)443 inline bool FloatType::classof(Type type) {
444   return type.isa<BFloat16Type, Float16Type, Float32Type, Float64Type,
445                   Float80Type, Float128Type>();
446 }
447 
getBF16(MLIRContext * ctx)448 inline FloatType FloatType::getBF16(MLIRContext *ctx) {
449   return BFloat16Type::get(ctx);
450 }
451 
getF16(MLIRContext * ctx)452 inline FloatType FloatType::getF16(MLIRContext *ctx) {
453   return Float16Type::get(ctx);
454 }
455 
getF32(MLIRContext * ctx)456 inline FloatType FloatType::getF32(MLIRContext *ctx) {
457   return Float32Type::get(ctx);
458 }
459 
getF64(MLIRContext * ctx)460 inline FloatType FloatType::getF64(MLIRContext *ctx) {
461   return Float64Type::get(ctx);
462 }
463 
getF80(MLIRContext * ctx)464 inline FloatType FloatType::getF80(MLIRContext *ctx) {
465   return Float80Type::get(ctx);
466 }
467 
getF128(MLIRContext * ctx)468 inline FloatType FloatType::getF128(MLIRContext *ctx) {
469   return Float128Type::get(ctx);
470 }
471 
classof(Type type)472 inline bool ShapedType::classof(Type type) {
473   return type.isa<RankedTensorType, VectorType, UnrankedTensorType,
474                   UnrankedMemRefType, MemRefType>();
475 }
476 
classof(Type type)477 inline bool TensorType::classof(Type type) {
478   return type.isa<RankedTensorType, UnrankedTensorType>();
479 }
480 
481 //===----------------------------------------------------------------------===//
482 // Type Utilities
483 //===----------------------------------------------------------------------===//
484 
/// Returns the strides of the MemRef if the layout map is in strided form.
/// MemRefs with layout maps in strided form include:
///   1. empty or identity layout map, in which case the stride information is
///      the canonical form computed from sizes;
///   2. single affine map layout of the form `K + k0 * d0 + ... kn * dn`,
///      where K and ki's are constants or symbols.
///
/// A stride specification is a list of integer values that are either static
/// or dynamic (encoded with MemRefType::getDynamicStrideOrOffset()). Strides
/// encode the distance, in number of elements, between successive entries
/// along a particular dimension. For example,
/// `memref<42x16xf32, (64 * d0 + d1)>` specifies a view into a non-contiguous
/// memory region of `42` by `16` `f32` elements in which the distance between
/// two consecutive elements along the outer dimension is `64` and the distance
/// between two consecutive elements along the inner dimension is `1`.
///
/// Returns success if a simple strided form can be extracted from the
/// composition of the layout map, and failure otherwise.
///
/// The convention is that the strides for dimensions d0, .. dn appear in
/// order to make indexing intuitive into the result.
LogicalResult getStridesAndOffset(MemRefType t,
                                  SmallVectorImpl<int64_t> &strides,
                                  int64_t &offset);
/// Same as above, but reports the strides and the offset as AffineExprs
/// instead of integer values.
LogicalResult getStridesAndOffset(MemRefType t,
                                  SmallVectorImpl<AffineExpr> &strides,
                                  AffineExpr &offset);
512 
/// Given a list of strides and an offset (in which
/// MemRefType::getDynamicStrideOrOffset() represents a dynamic value), return
/// the single result AffineMap which represents the linearized strided layout
/// map. The map has one dimension per stride. Symbols are introduced, in
/// order, for a dynamic offset followed by each dynamic stride. A stride
/// cannot take value `0`.
///
/// Examples:
/// =========
///
///   1. For offset: 0 strides: ?, ?, 1 return
///         (i, j, k)[M, N]->(M * i + N * j + k)
///
///   2. For offset: 3 strides: 32, ?, 16 return
///         (i, j, k)[M]->(3 + 32 * i + M * j + 16 * k)
///
///   3. For offset: ? strides: ?, ?, ? return
///         (i, j, k)[off, M, N, P]->(off + M * i + N * j + P * k)
AffineMap makeStridedLinearLayoutMap(ArrayRef<int64_t> strides, int64_t offset,
                                     MLIRContext *context);
532 
/// Return a version of `t` with identity layout if it can be determined
/// statically that the layout is the canonical contiguous strided layout.
/// Otherwise, simplify `t`'s layout with `simplifyAffineMap` and return a copy
/// of `t` with the simplified layout.
MemRefType canonicalizeStridedLayout(MemRefType t);

/// Return a version of `t` whose layout has a fully dynamic offset and fully
/// dynamic strides. This is used to erase static layout information.
MemRefType eraseStridedLayout(MemRefType t);
542 
/// Given MemRef `sizes` that are either static or dynamic, returns the
/// canonical "contiguous" strides AffineExpr. Strides are multiplicative:
/// walking from the innermost dimension outwards, once a dynamic size is
/// encountered, the strides of all dimensions outer to it become dynamic, and
/// each of them is encoded with a different symbol.
/// For canonical strides expressions, the offset is always 0 and the fastest
/// varying stride is always `1`.
///
/// Examples:
///   - memref<3x4x5xf32> has canonical stride expression
///         `20*exprs[0] + 5*exprs[1] + exprs[2]`.
///   - memref<3x?x5xf32> has canonical stride expression
///         `s0*exprs[0] + 5*exprs[1] + exprs[2]`.
///   - memref<3x4x?xf32> has canonical stride expression
///         `s1*exprs[0] + s0*exprs[1] + exprs[2]`.
AffineExpr makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
                                          ArrayRef<AffineExpr> exprs,
                                          MLIRContext *context);

/// Return the result of makeCanonicalStridedLayoutExpr for the common case
/// where `exprs` is {d0, d1, .., d_(sizes.size()-1)}.
AffineExpr makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
                                          MLIRContext *context);
565 
/// Return true if the layout of `t` is compatible with strided semantics.
bool isStrided(MemRefType t);
568 
569 } // end namespace mlir
570 
571 #endif // MLIR_IR_BUILTINTYPES_H
572