1 //===- Types.h - MLIR Type Classes ------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_IR_TYPES_H
10 #define MLIR_IR_TYPES_H
11 
12 #include "mlir/IR/TypeSupport.h"
13 #include "llvm/ADT/ArrayRef.h"
14 #include "llvm/ADT/DenseMapInfo.h"
15 #include "llvm/Support/PointerLikeTypeTraits.h"
16 
17 namespace mlir {
18 class FloatType;
19 class Identifier;
20 class IndexType;
21 class IntegerType;
22 class MLIRContext;
23 class TypeStorage;
24 class TypeRange;
25 
26 namespace detail {
27 struct FunctionTypeStorage;
28 struct OpaqueTypeStorage;
29 } // namespace detail
30 
31 /// Instances of the Type class are uniqued, have an immutable identifier and an
32 /// optional mutable component.  They wrap a pointer to the storage object owned
33 /// by MLIRContext.  Therefore, instances of Type are passed around by value.
34 ///
35 /// Some types are "primitives" meaning they do not have any parameters, for
36 /// example the Index type.  Parametric types have additional information that
37 /// differentiates the types of the same class, for example the Integer type has
38 /// bitwidth, making i8 and i16 belong to the same kind by be different
39 /// instances of the IntegerType. Type parameters are part of the unique
40 /// immutable key.  The mutable component of the type can be modified after the
41 /// type is created, but cannot affect the identity of the type.
42 ///
43 /// Types are constructed and uniqued via the 'detail::TypeUniquer' class.
44 ///
45 /// Derived type classes are expected to implement several required
46 /// implementation hooks:
47 ///  * Optional:
48 ///    - static LogicalResult verifyConstructionInvariants(Location loc,
49 ///                                                        Args... args)
50 ///      * This method is invoked when calling the 'TypeBase::get/getChecked'
51 ///        methods to ensure that the arguments passed in are valid to construct
52 ///        a type instance with.
53 ///      * This method is expected to return failure if a type cannot be
54 ///        constructed with 'args', success otherwise.
55 ///      * 'args' must correspond with the arguments passed into the
56 ///        'TypeBase::get' call.
57 ///
58 ///
59 /// Type storage objects inherit from TypeStorage and contain the following:
60 ///    - The dialect that defined the type.
61 ///    - Any parameters of the type.
62 ///    - An optional mutable component.
63 /// For non-parametric types, a convenience DefaultTypeStorage is provided.
64 /// Parametric storage types must derive TypeStorage and respect the following:
65 ///    - Define a type alias, KeyTy, to a type that uniquely identifies the
66 ///      instance of the type.
67 ///      * The key type must be constructible from the values passed into the
68 ///        detail::TypeUniquer::get call.
69 ///      * If the KeyTy does not have an llvm::DenseMapInfo specialization, the
70 ///        storage class must define a hashing method:
71 ///         'static unsigned hashKey(const KeyTy &)'
72 ///
73 ///    - Provide a method, 'bool operator==(const KeyTy &) const', to
74 ///      compare the storage instance against an instance of the key type.
75 ///
76 ///    - Provide a static construction method:
77 ///        'DerivedStorage *construct(TypeStorageAllocator &, const KeyTy &key)'
78 ///      that builds a unique instance of the derived storage. The arguments to
79 ///      this function are an allocator to store any uniqued data within the
80 ///      context and the key type for this storage.
81 ///
82 ///    - If they have a mutable component, this component must not be a part of
83 //       the key.
84 class Type {
85 public:
86   /// Utility class for implementing types.
87   template <typename ConcreteType, typename BaseType, typename StorageType,
88             template <typename T> class... Traits>
89   using TypeBase = detail::StorageUserBase<ConcreteType, BaseType, StorageType,
90                                            detail::TypeUniquer, Traits...>;
91 
92   using ImplType = TypeStorage;
93 
Type()94   constexpr Type() : impl(nullptr) {}
Type(const ImplType * impl)95   /* implicit */ Type(const ImplType *impl)
96       : impl(const_cast<ImplType *>(impl)) {}
97 
98   Type(const Type &other) = default;
99   Type &operator=(const Type &other) = default;
100 
101   bool operator==(Type other) const { return impl == other.impl; }
102   bool operator!=(Type other) const { return !(*this == other); }
103   explicit operator bool() const { return impl; }
104 
105   bool operator!() const { return impl == nullptr; }
106 
107   template <typename U> bool isa() const;
108   template <typename First, typename Second, typename... Rest>
109   bool isa() const;
110   template <typename U> U dyn_cast() const;
111   template <typename U> U dyn_cast_or_null() const;
112   template <typename U> U cast() const;
113 
114   // Support type casting Type to itself.
classof(Type)115   static bool classof(Type) { return true; }
116 
117   /// Return a unique identifier for the concrete type. This is used to support
118   /// dynamic type casting.
getTypeID()119   TypeID getTypeID() { return impl->getAbstractType().getTypeID(); }
120 
121   /// Return the LLVMContext in which this type was uniqued.
122   MLIRContext *getContext() const;
123 
124   /// Get the dialect this type is registered to.
125   Dialect &getDialect() const;
126 
127   // Convenience predicates.  This is only for floating point types,
128   // derived types should use isa/dyn_cast.
129   bool isIndex();
130   bool isBF16();
131   bool isF16();
132   bool isF32();
133   bool isF64();
134 
135   /// Return true if this is an integer type with the specified width.
136   bool isInteger(unsigned width);
137   /// Return true if this is a signless integer type (with the specified width).
138   bool isSignlessInteger();
139   bool isSignlessInteger(unsigned width);
140   /// Return true if this is a signed integer type (with the specified width).
141   bool isSignedInteger();
142   bool isSignedInteger(unsigned width);
143   /// Return true if this is an unsigned integer type (with the specified
144   /// width).
145   bool isUnsignedInteger();
146   bool isUnsignedInteger(unsigned width);
147 
148   /// Return the bit width of an integer or a float type, assert failure on
149   /// other types.
150   unsigned getIntOrFloatBitWidth();
151 
152   /// Return true if this is a signless integer or index type.
153   bool isSignlessIntOrIndex();
154   /// Return true if this is a signless integer, index, or float type.
155   bool isSignlessIntOrIndexOrFloat();
156   /// Return true of this is a signless integer or a float type.
157   bool isSignlessIntOrFloat();
158 
159   /// Return true if this is an integer (of any signedness) or an index type.
160   bool isIntOrIndex();
161   /// Return true if this is an integer (of any signedness) or a float type.
162   bool isIntOrFloat();
163   /// Return true if this is an integer (of any signedness), index, or float
164   /// type.
165   bool isIntOrIndexOrFloat();
166 
167   /// Print the current type.
168   void print(raw_ostream &os);
169   void dump();
170 
171   friend ::llvm::hash_code hash_value(Type arg);
172 
173   /// Methods for supporting PointerLikeTypeTraits.
getAsOpaquePointer()174   const void *getAsOpaquePointer() const {
175     return static_cast<const void *>(impl);
176   }
getFromOpaquePointer(const void * pointer)177   static Type getFromOpaquePointer(const void *pointer) {
178     return Type(reinterpret_cast<ImplType *>(const_cast<void *>(pointer)));
179   }
180 
181   /// Return the abstract type descriptor for this type.
getAbstractType()182   const AbstractType &getAbstractType() { return impl->getAbstractType(); }
183 
184 protected:
185   ImplType *impl;
186 };
187 
188 inline raw_ostream &operator<<(raw_ostream &os, Type type) {
189   type.print(os);
190   return os;
191 }
192 
193 //===----------------------------------------------------------------------===//
194 // TypeTraitBase
195 //===----------------------------------------------------------------------===//
196 
197 namespace TypeTrait {
198 /// This class represents the base of a type trait.
199 template <typename ConcreteType, template <typename> class TraitType>
200 using TraitBase = detail::StorageUserTraitBase<ConcreteType, TraitType>;
201 } // namespace TypeTrait
202 
203 //===----------------------------------------------------------------------===//
204 // TypeInterface
205 //===----------------------------------------------------------------------===//
206 
207 /// This class represents the base of a type interface. See the definition  of
208 /// `detail::Interface` for requirements on the `Traits` type.
209 template <typename ConcreteType, typename Traits>
210 class TypeInterface : public detail::Interface<ConcreteType, Type, Traits, Type,
211                                                TypeTrait::TraitBase> {
212 public:
213   using Base = TypeInterface<ConcreteType, Traits>;
214   using InterfaceBase =
215       detail::Interface<ConcreteType, Type, Traits, Type, TypeTrait::TraitBase>;
216   using InterfaceBase::InterfaceBase;
217 
218 private:
219   /// Returns the impl interface instance for the given type.
getInterfaceFor(Type type)220   static typename InterfaceBase::Concept *getInterfaceFor(Type type) {
221     return type.getAbstractType().getInterface<ConcreteType>();
222   }
223 
224   /// Allow access to 'getInterfaceFor'.
225   friend InterfaceBase;
226 };
227 
228 //===----------------------------------------------------------------------===//
229 // FunctionType
230 //===----------------------------------------------------------------------===//
231 
232 /// Function types map from a list of inputs to a list of results.
233 class FunctionType
234     : public Type::TypeBase<FunctionType, Type, detail::FunctionTypeStorage> {
235 public:
236   using Base::Base;
237 
238   static FunctionType get(TypeRange inputs, TypeRange results,
239                           MLIRContext *context);
240 
241   /// Input types.
242   unsigned getNumInputs() const;
getInput(unsigned i)243   Type getInput(unsigned i) const { return getInputs()[i]; }
244   ArrayRef<Type> getInputs() const;
245 
246   /// Result types.
247   unsigned getNumResults() const;
getResult(unsigned i)248   Type getResult(unsigned i) const { return getResults()[i]; }
249   ArrayRef<Type> getResults() const;
250 
251   /// Returns a new function type without the specified arguments and results.
252   FunctionType getWithoutArgsAndResults(ArrayRef<unsigned> argIndices,
253                                         ArrayRef<unsigned> resultIndices);
254 };
255 
256 //===----------------------------------------------------------------------===//
257 // OpaqueType
258 //===----------------------------------------------------------------------===//
259 
260 /// Opaque types represent types of non-registered dialects. These are types
261 /// represented in their raw string form, and can only usefully be tested for
262 /// type equality.
263 class OpaqueType
264     : public Type::TypeBase<OpaqueType, Type, detail::OpaqueTypeStorage> {
265 public:
266   using Base::Base;
267 
268   /// Get or create a new OpaqueType with the provided dialect and string data.
269   static OpaqueType get(Identifier dialect, StringRef typeData,
270                         MLIRContext *context);
271 
272   /// Get or create a new OpaqueType with the provided dialect and string data.
273   /// If the given identifier is not a valid namespace for a dialect, then a
274   /// null type is returned.
275   static OpaqueType getChecked(Identifier dialect, StringRef typeData,
276                                MLIRContext *context, Location location);
277 
278   /// Returns the dialect namespace of the opaque type.
279   Identifier getDialectNamespace() const;
280 
281   /// Returns the raw type data of the opaque type.
282   StringRef getTypeData() const;
283 
284   /// Verify the construction of an opaque type.
285   static LogicalResult verifyConstructionInvariants(Location loc,
286                                                     Identifier dialect,
287                                                     StringRef typeData);
288 };
289 
290 // Make Type hashable.
hash_value(Type arg)291 inline ::llvm::hash_code hash_value(Type arg) {
292   return ::llvm::hash_value(arg.impl);
293 }
294 
isa()295 template <typename U> bool Type::isa() const {
296   assert(impl && "isa<> used on a null type.");
297   return U::classof(*this);
298 }
299 
300 template <typename First, typename Second, typename... Rest>
isa()301 bool Type::isa() const {
302   return isa<First>() || isa<Second, Rest...>();
303 }
304 
dyn_cast()305 template <typename U> U Type::dyn_cast() const {
306   return isa<U>() ? U(impl) : U(nullptr);
307 }
dyn_cast_or_null()308 template <typename U> U Type::dyn_cast_or_null() const {
309   return (impl && isa<U>()) ? U(impl) : U(nullptr);
310 }
cast()311 template <typename U> U Type::cast() const {
312   assert(isa<U>());
313   return U(impl);
314 }
315 
316 } // end namespace mlir
317 
318 namespace llvm {
319 
320 // Type hash just like pointers.
321 template <> struct DenseMapInfo<mlir::Type> {
322   static mlir::Type getEmptyKey() {
323     auto pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
324     return mlir::Type(static_cast<mlir::Type::ImplType *>(pointer));
325   }
326   static mlir::Type getTombstoneKey() {
327     auto pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
328     return mlir::Type(static_cast<mlir::Type::ImplType *>(pointer));
329   }
330   static unsigned getHashValue(mlir::Type val) { return mlir::hash_value(val); }
331   static bool isEqual(mlir::Type LHS, mlir::Type RHS) { return LHS == RHS; }
332 };
333 
334 /// We align TypeStorage by 8, so allow LLVM to steal the low bits.
335 template <> struct PointerLikeTypeTraits<mlir::Type> {
336 public:
337   static inline void *getAsVoidPointer(mlir::Type I) {
338     return const_cast<void *>(I.getAsOpaquePointer());
339   }
340   static inline mlir::Type getFromVoidPointer(void *P) {
341     return mlir::Type::getFromOpaquePointer(P);
342   }
343   static constexpr int NumLowBitsAvailable = 3;
344 };
345 
346 } // namespace llvm
347 
348 #endif // MLIR_IR_TYPES_H
349