1 //===- BuiltinTypes.h - MLIR Builtin Type Classes ---------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #ifndef MLIR_IR_BUILTINTYPES_H
10 #define MLIR_IR_BUILTINTYPES_H
11
12 #include "mlir/IR/Types.h"
13
14 namespace llvm {
15 struct fltSemantics;
16 } // namespace llvm
17
18 namespace mlir {
19 class AffineExpr;
20 class AffineMap;
21 class FloatType;
22 class Identifier;
23 class IndexType;
24 class IntegerType;
25 class Location;
26 class MLIRContext;
27 class TypeRange;
28
29 namespace detail {
30
31 struct BaseMemRefTypeStorage;
32 struct MemRefTypeStorage;
33 struct RankedTensorTypeStorage;
34 struct ShapedTypeStorage;
35 struct UnrankedMemRefTypeStorage;
36 struct UnrankedTensorTypeStorage;
37 struct VectorTypeStorage;
38
39 } // namespace detail
40
41 //===----------------------------------------------------------------------===//
42 // FloatType
43 //===----------------------------------------------------------------------===//
44
45 class FloatType : public Type {
46 public:
47 using Type::Type;
48
49 // Convenience factories.
50 static FloatType getBF16(MLIRContext *ctx);
51 static FloatType getF16(MLIRContext *ctx);
52 static FloatType getF32(MLIRContext *ctx);
53 static FloatType getF64(MLIRContext *ctx);
54 static FloatType getF80(MLIRContext *ctx);
55 static FloatType getF128(MLIRContext *ctx);
56
57 /// Methods for support type inquiry through isa, cast, and dyn_cast.
58 static bool classof(Type type);
59
60 /// Return the bitwidth of this float type.
61 unsigned getWidth();
62
63 /// Get or create a new FloatType with bitwidth scaled by `scale`.
64 /// Return null if the scaled element type cannot be represented.
65 FloatType scaleElementBitwidth(unsigned scale);
66
67 /// Return the floating semantics of this float type.
68 const llvm::fltSemantics &getFloatSemantics();
69 };
70
71 //===----------------------------------------------------------------------===//
72 // ShapedType
73 //===----------------------------------------------------------------------===//
74
75 /// This is a common base class between Vector, UnrankedTensor, RankedTensor,
76 /// and MemRef types because they share behavior and semantics around shape,
77 /// rank, and fixed element type. Any type with these semantics should inherit
78 /// from ShapedType.
79 class ShapedType : public Type {
80 public:
81 using ImplType = detail::ShapedTypeStorage;
82 using Type::Type;
83
84 // TODO: merge these two special values in a single one used everywhere.
85 // Unfortunately, uses of `-1` have crept deep into the codebase now and are
86 // hard to track.
87 static constexpr int64_t kDynamicSize = -1;
88 static constexpr int64_t kDynamicStrideOrOffset =
89 std::numeric_limits<int64_t>::min();
90
91 /// Return the element type.
92 Type getElementType() const;
93
94 /// If an element type is an integer or a float, return its width. Otherwise,
95 /// abort.
96 unsigned getElementTypeBitWidth() const;
97
98 /// If it has static shape, return the number of elements. Otherwise, abort.
99 int64_t getNumElements() const;
100
101 /// If this is a ranked type, return the rank. Otherwise, abort.
102 int64_t getRank() const;
103
104 /// Whether or not this is a ranked type. Memrefs, vectors and ranked tensors
105 /// have a rank, while unranked tensors do not.
106 bool hasRank() const;
107
108 /// If this is a ranked type, return the shape. Otherwise, abort.
109 ArrayRef<int64_t> getShape() const;
110
111 /// If this is unranked type or any dimension has unknown size (<0), it
112 /// doesn't have static shape. If all dimensions have known size (>= 0), it
113 /// has static shape.
114 bool hasStaticShape() const;
115
116 /// If this has a static shape and the shape is equal to `shape` return true.
117 bool hasStaticShape(ArrayRef<int64_t> shape) const;
118
119 /// If this is a ranked type, return the number of dimensions with dynamic
120 /// size. Otherwise, abort.
121 int64_t getNumDynamicDims() const;
122
123 /// If this is ranked type, return the size of the specified dimension.
124 /// Otherwise, abort.
125 int64_t getDimSize(unsigned idx) const;
126
127 /// Returns true if this dimension has a dynamic size (for ranked types);
128 /// aborts for unranked types.
129 bool isDynamicDim(unsigned idx) const;
130
131 /// Returns the position of the dynamic dimension relative to just the dynamic
132 /// dimensions, given its `index` within the shape.
133 unsigned getDynamicDimIndex(unsigned index) const;
134
135 /// Get the total amount of bits occupied by a value of this type. This does
136 /// not take into account any memory layout or widening constraints, e.g. a
137 /// vector<3xi57> is reported to occupy 3x57=171 bit, even though in practice
138 /// it will likely be stored as in a 4xi64 vector register. Fail an assertion
139 /// if the size cannot be computed statically, i.e. if the type has a dynamic
140 /// shape or if its elemental type does not have a known bit width.
141 int64_t getSizeInBits() const;
142
143 /// Methods for support type inquiry through isa, cast, and dyn_cast.
144 static bool classof(Type type);
145
146 /// Whether the given dimension size indicates a dynamic dimension.
isDynamic(int64_t dSize)147 static constexpr bool isDynamic(int64_t dSize) {
148 return dSize == kDynamicSize;
149 }
isDynamicStrideOrOffset(int64_t dStrideOrOffset)150 static constexpr bool isDynamicStrideOrOffset(int64_t dStrideOrOffset) {
151 return dStrideOrOffset == kDynamicStrideOrOffset;
152 }
153 };
154
155 //===----------------------------------------------------------------------===//
156 // VectorType
157 //===----------------------------------------------------------------------===//
158
159 /// Vector types represent multi-dimensional SIMD vectors, and have a fixed
160 /// known constant shape with one or more dimension.
161 class VectorType
162 : public Type::TypeBase<VectorType, ShapedType, detail::VectorTypeStorage> {
163 public:
164 using Base::Base;
165
166 /// Get or create a new VectorType of the provided shape and element type.
167 /// Assumes the arguments define a well-formed VectorType.
168 static VectorType get(ArrayRef<int64_t> shape, Type elementType);
169
170 /// Get or create a new VectorType of the provided shape and element type
171 /// declared at the given, potentially unknown, location. If the VectorType
172 /// defined by the arguments would be ill-formed, emit errors and return
173 /// nullptr-wrapping type.
174 static VectorType getChecked(Location location, ArrayRef<int64_t> shape,
175 Type elementType);
176
177 /// Verify the construction of a vector type.
178 static LogicalResult verifyConstructionInvariants(Location loc,
179 ArrayRef<int64_t> shape,
180 Type elementType);
181
182 /// Returns true of the given type can be used as an element of a vector type.
183 /// In particular, vectors can consist of integer or float primitives.
isValidElementType(Type t)184 static bool isValidElementType(Type t) {
185 return t.isa<IntegerType, FloatType>();
186 }
187
188 ArrayRef<int64_t> getShape() const;
189
190 /// Get or create a new VectorType with the same shape as `this` and an
191 /// element type of bitwidth scaled by `scale`.
192 /// Return null if the scaled element type cannot be represented.
193 VectorType scaleElementBitwidth(unsigned scale);
194 };
195
196 //===----------------------------------------------------------------------===//
197 // TensorType
198 //===----------------------------------------------------------------------===//
199
200 /// Tensor types represent multi-dimensional arrays, and have two variants:
201 /// RankedTensorType and UnrankedTensorType.
202 class TensorType : public ShapedType {
203 public:
204 using ShapedType::ShapedType;
205
206 /// Return true if the specified element type is ok in a tensor.
207 static bool isValidElementType(Type type);
208
209 /// Methods for support type inquiry through isa, cast, and dyn_cast.
210 static bool classof(Type type);
211 };
212
213 //===----------------------------------------------------------------------===//
214 // RankedTensorType
215
216 /// Ranked tensor types represent multi-dimensional arrays that have a shape
217 /// with a fixed number of dimensions. Each shape element can be a non-negative
218 /// integer or unknown (represented by -1).
219 class RankedTensorType
220 : public Type::TypeBase<RankedTensorType, TensorType,
221 detail::RankedTensorTypeStorage> {
222 public:
223 using Base::Base;
224
225 /// Get or create a new RankedTensorType of the provided shape and element
226 /// type. Assumes the arguments define a well-formed type.
227 static RankedTensorType get(ArrayRef<int64_t> shape, Type elementType);
228
229 /// Get or create a new RankedTensorType of the provided shape and element
230 /// type declared at the given, potentially unknown, location. If the
231 /// RankedTensorType defined by the arguments would be ill-formed, emit errors
232 /// and return a nullptr-wrapping type.
233 static RankedTensorType getChecked(Location location, ArrayRef<int64_t> shape,
234 Type elementType);
235
236 /// Verify the construction of a ranked tensor type.
237 static LogicalResult verifyConstructionInvariants(Location loc,
238 ArrayRef<int64_t> shape,
239 Type elementType);
240
241 ArrayRef<int64_t> getShape() const;
242 };
243
244 //===----------------------------------------------------------------------===//
245 // UnrankedTensorType
246
247 /// Unranked tensor types represent multi-dimensional arrays that have an
248 /// unknown shape.
249 class UnrankedTensorType
250 : public Type::TypeBase<UnrankedTensorType, TensorType,
251 detail::UnrankedTensorTypeStorage> {
252 public:
253 using Base::Base;
254
255 /// Get or create a new UnrankedTensorType of the provided shape and element
256 /// type. Assumes the arguments define a well-formed type.
257 static UnrankedTensorType get(Type elementType);
258
259 /// Get or create a new UnrankedTensorType of the provided shape and element
260 /// type declared at the given, potentially unknown, location. If the
261 /// UnrankedTensorType defined by the arguments would be ill-formed, emit
262 /// errors and return a nullptr-wrapping type.
263 static UnrankedTensorType getChecked(Location location, Type elementType);
264
265 /// Verify the construction of a unranked tensor type.
266 static LogicalResult verifyConstructionInvariants(Location loc,
267 Type elementType);
268
getShape()269 ArrayRef<int64_t> getShape() const { return llvm::None; }
270 };
271
272 //===----------------------------------------------------------------------===//
273 // BaseMemRefType
274 //===----------------------------------------------------------------------===//
275
276 /// Base MemRef for Ranked and Unranked variants
277 class BaseMemRefType : public ShapedType {
278 public:
279 using ImplType = detail::BaseMemRefTypeStorage;
280 using ShapedType::ShapedType;
281
282 /// Return true if the specified element type is ok in a memref.
283 static bool isValidElementType(Type type);
284
285 /// Methods for support type inquiry through isa, cast, and dyn_cast.
286 static bool classof(Type type);
287
288 /// Returns the memory space in which data referred to by this memref resides.
289 unsigned getMemorySpace() const;
290 };
291
292 //===----------------------------------------------------------------------===//
293 // MemRefType
294
295 /// MemRef types represent a region of memory that have a shape with a fixed
296 /// number of dimensions. Each shape element can be a non-negative integer or
297 /// unknown (represented by -1). MemRef types also have an affine map
298 /// composition, represented as an array AffineMap pointers.
299 class MemRefType : public Type::TypeBase<MemRefType, BaseMemRefType,
300 detail::MemRefTypeStorage> {
301 public:
302 /// This is a builder type that keeps local references to arguments. Arguments
303 /// that are passed into the builder must out-live the builder.
304 class Builder {
305 public:
306 // Build from another MemRefType.
Builder(MemRefType other)307 explicit Builder(MemRefType other)
308 : shape(other.getShape()), elementType(other.getElementType()),
309 affineMaps(other.getAffineMaps()),
310 memorySpace(other.getMemorySpace()) {}
311
312 // Build from scratch.
Builder(ArrayRef<int64_t> shape,Type elementType)313 Builder(ArrayRef<int64_t> shape, Type elementType)
314 : shape(shape), elementType(elementType), affineMaps(), memorySpace(0) {
315 }
316
setShape(ArrayRef<int64_t> newShape)317 Builder &setShape(ArrayRef<int64_t> newShape) {
318 shape = newShape;
319 return *this;
320 }
321
setElementType(Type newElementType)322 Builder &setElementType(Type newElementType) {
323 elementType = newElementType;
324 return *this;
325 }
326
setAffineMaps(ArrayRef<AffineMap> newAffineMaps)327 Builder &setAffineMaps(ArrayRef<AffineMap> newAffineMaps) {
328 affineMaps = newAffineMaps;
329 return *this;
330 }
331
setMemorySpace(unsigned newMemorySpace)332 Builder &setMemorySpace(unsigned newMemorySpace) {
333 memorySpace = newMemorySpace;
334 return *this;
335 }
336
MemRefType()337 operator MemRefType() {
338 return MemRefType::get(shape, elementType, affineMaps, memorySpace);
339 }
340
341 private:
342 ArrayRef<int64_t> shape;
343 Type elementType;
344 ArrayRef<AffineMap> affineMaps;
345 unsigned memorySpace;
346 };
347
348 using Base::Base;
349
350 /// Get or create a new MemRefType based on shape, element type, affine
351 /// map composition, and memory space. Assumes the arguments define a
352 /// well-formed MemRef type. Use getChecked to gracefully handle MemRefType
353 /// construction failures.
354 static MemRefType get(ArrayRef<int64_t> shape, Type elementType,
355 ArrayRef<AffineMap> affineMapComposition = {},
356 unsigned memorySpace = 0);
357
358 /// Get or create a new MemRefType based on shape, element type, affine
359 /// map composition, and memory space declared at the given location.
360 /// If the location is unknown, the last argument should be an instance of
361 /// UnknownLoc. If the MemRefType defined by the arguments would be
362 /// ill-formed, emits errors (to the handler registered with the context or to
363 /// the error stream) and returns nullptr.
364 static MemRefType getChecked(Location location, ArrayRef<int64_t> shape,
365 Type elementType,
366 ArrayRef<AffineMap> affineMapComposition,
367 unsigned memorySpace);
368
369 ArrayRef<int64_t> getShape() const;
370
371 /// Returns an array of affine map pointers representing the memref affine
372 /// map composition.
373 ArrayRef<AffineMap> getAffineMaps() const;
374
375 // TODO: merge these two special values in a single one used everywhere.
376 // Unfortunately, uses of `-1` have crept deep into the codebase now and are
377 // hard to track.
getDynamicStrideOrOffset()378 static int64_t getDynamicStrideOrOffset() {
379 return ShapedType::kDynamicStrideOrOffset;
380 }
381
382 private:
383 /// Get or create a new MemRefType defined by the arguments. If the resulting
384 /// type would be ill-formed, return nullptr. If the location is provided,
385 /// emit detailed error messages.
386 static MemRefType getImpl(ArrayRef<int64_t> shape, Type elementType,
387 ArrayRef<AffineMap> affineMapComposition,
388 unsigned memorySpace, Optional<Location> location);
389 using Base::getImpl;
390 };
391
392 //===----------------------------------------------------------------------===//
393 // UnrankedMemRefType
394
395 /// Unranked MemRef type represent multi-dimensional MemRefs that
396 /// have an unknown rank.
397 class UnrankedMemRefType
398 : public Type::TypeBase<UnrankedMemRefType, BaseMemRefType,
399 detail::UnrankedMemRefTypeStorage> {
400 public:
401 using Base::Base;
402
403 /// Get or create a new UnrankedMemRefType of the provided element
404 /// type and memory space
405 static UnrankedMemRefType get(Type elementType, unsigned memorySpace);
406
407 /// Get or create a new UnrankedMemRefType of the provided element
408 /// type and memory space declared at the given, potentially unknown,
409 /// location. If the UnrankedMemRefType defined by the arguments would be
410 /// ill-formed, emit errors and return a nullptr-wrapping type.
411 static UnrankedMemRefType getChecked(Location location, Type elementType,
412 unsigned memorySpace);
413
414 /// Verify the construction of a unranked memref type.
415 static LogicalResult verifyConstructionInvariants(Location loc,
416 Type elementType,
417 unsigned memorySpace);
418
getShape()419 ArrayRef<int64_t> getShape() const { return llvm::None; }
420 };
421 } // end namespace mlir
422
423 //===----------------------------------------------------------------------===//
424 // Tablegen Type Declarations
425 //===----------------------------------------------------------------------===//
426
427 #define GET_TYPEDEF_CLASSES
428 #include "mlir/IR/BuiltinTypes.h.inc"
429
430 //===----------------------------------------------------------------------===//
431 // Deferred Method Definitions
432 //===----------------------------------------------------------------------===//
433
434 namespace mlir {
classof(Type type)435 inline bool BaseMemRefType::classof(Type type) {
436 return type.isa<MemRefType, UnrankedMemRefType>();
437 }
438
isValidElementType(Type type)439 inline bool BaseMemRefType::isValidElementType(Type type) {
440 return type.isIntOrIndexOrFloat() || type.isa<ComplexType, VectorType>();
441 }
442
classof(Type type)443 inline bool FloatType::classof(Type type) {
444 return type.isa<BFloat16Type, Float16Type, Float32Type, Float64Type,
445 Float80Type, Float128Type>();
446 }
447
getBF16(MLIRContext * ctx)448 inline FloatType FloatType::getBF16(MLIRContext *ctx) {
449 return BFloat16Type::get(ctx);
450 }
451
getF16(MLIRContext * ctx)452 inline FloatType FloatType::getF16(MLIRContext *ctx) {
453 return Float16Type::get(ctx);
454 }
455
getF32(MLIRContext * ctx)456 inline FloatType FloatType::getF32(MLIRContext *ctx) {
457 return Float32Type::get(ctx);
458 }
459
getF64(MLIRContext * ctx)460 inline FloatType FloatType::getF64(MLIRContext *ctx) {
461 return Float64Type::get(ctx);
462 }
463
getF80(MLIRContext * ctx)464 inline FloatType FloatType::getF80(MLIRContext *ctx) {
465 return Float80Type::get(ctx);
466 }
467
getF128(MLIRContext * ctx)468 inline FloatType FloatType::getF128(MLIRContext *ctx) {
469 return Float128Type::get(ctx);
470 }
471
classof(Type type)472 inline bool ShapedType::classof(Type type) {
473 return type.isa<RankedTensorType, VectorType, UnrankedTensorType,
474 UnrankedMemRefType, MemRefType>();
475 }
476
classof(Type type)477 inline bool TensorType::classof(Type type) {
478 return type.isa<RankedTensorType, UnrankedTensorType>();
479 }
480
481 //===----------------------------------------------------------------------===//
482 // Type Utilities
483 //===----------------------------------------------------------------------===//
484
485 /// Returns the strides of the MemRef if the layout map is in strided form.
486 /// MemRefs with layout maps in strided form include:
487 /// 1. empty or identity layout map, in which case the stride information is
488 /// the canonical form computed from sizes;
489 /// 2. single affine map layout of the form `K + k0 * d0 + ... kn * dn`,
490 /// where K and ki's are constants or symbols.
491 ///
492 /// A stride specification is a list of integer values that are either static
493 /// or dynamic (encoded with getDynamicStrideOrOffset()). Strides encode the
494 /// distance in the number of elements between successive entries along a
495 /// particular dimension. For example, `memref<42x16xf32, (64 * d0 + d1)>`
496 /// specifies a view into a non-contiguous memory region of `42` by `16` `f32`
497 /// elements in which the distance between two consecutive elements along the
498 /// outer dimension is `1` and the distance between two consecutive elements
499 /// along the inner dimension is `64`.
500 ///
501 /// Returns whether a simple strided form can be extracted from the composition
502 /// of the layout map.
503 ///
504 /// The convention is that the strides for dimensions d0, .. dn appear in
505 /// order to make indexing intuitive into the result.
506 LogicalResult getStridesAndOffset(MemRefType t,
507 SmallVectorImpl<int64_t> &strides,
508 int64_t &offset);
509 LogicalResult getStridesAndOffset(MemRefType t,
510 SmallVectorImpl<AffineExpr> &strides,
511 AffineExpr &offset);
512
513 /// Given a list of strides (in which MemRefType::getDynamicStrideOrOffset()
514 /// represents a dynamic value), return the single result AffineMap which
515 /// represents the linearized strided layout map. Dimensions correspond to the
516 /// offset followed by the strides in order. Symbols are inserted for each
517 /// dynamic dimension in order. A stride cannot take value `0`.
518 ///
519 /// Examples:
520 /// =========
521 ///
522 /// 1. For offset: 0 strides: ?, ?, 1 return
523 /// (i, j, k)[M, N]->(M * i + N * j + k)
524 ///
525 /// 2. For offset: 3 strides: 32, ?, 16 return
526 /// (i, j, k)[M]->(3 + 32 * i + M * j + 16 * k)
527 ///
528 /// 3. For offset: ? strides: ?, ?, ? return
529 /// (i, j, k)[off, M, N, P]->(off + M * i + N * j + P * k)
530 AffineMap makeStridedLinearLayoutMap(ArrayRef<int64_t> strides, int64_t offset,
531 MLIRContext *context);
532
533 /// Return a version of `t` with identity layout if it can be determined
534 /// statically that the layout is the canonical contiguous strided layout.
535 /// Otherwise pass `t`'s layout into `simplifyAffineMap` and return a copy of
536 /// `t` with simplified layout.
537 MemRefType canonicalizeStridedLayout(MemRefType t);
538
539 /// Return a version of `t` with a layout that has all dynamic offset and
540 /// strides. This is used to erase the static layout.
541 MemRefType eraseStridedLayout(MemRefType t);
542
543 /// Given MemRef `sizes` that are either static or dynamic, returns the
544 /// canonical "contiguous" strides AffineExpr. Strides are multiplicative and
545 /// once a dynamic dimension is encountered, all canonical strides become
546 /// dynamic and need to be encoded with a different symbol.
547 /// For canonical strides expressions, the offset is always 0 and and fastest
548 /// varying stride is always `1`.
549 ///
550 /// Examples:
551 /// - memref<3x4x5xf32> has canonical stride expression
552 /// `20*exprs[0] + 5*exprs[1] + exprs[2]`.
553 /// - memref<3x?x5xf32> has canonical stride expression
554 /// `s0*exprs[0] + 5*exprs[1] + exprs[2]`.
555 /// - memref<3x4x?xf32> has canonical stride expression
556 /// `s1*exprs[0] + s0*exprs[1] + exprs[2]`.
557 AffineExpr makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
558 ArrayRef<AffineExpr> exprs,
559 MLIRContext *context);
560
561 /// Return the result of makeCanonicalStrudedLayoutExpr for the common case
562 /// where `exprs` is {d0, d1, .., d_(sizes.size()-1)}
563 AffineExpr makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
564 MLIRContext *context);
565
566 /// Return true if the layout for `t` is compatible with strided semantics.
567 bool isStrided(MemRefType t);
568
569 } // end namespace mlir
570
571 #endif // MLIR_IR_BUILTINTYPES_H
572