1 //===- InterfaceSupport.h - MLIR Interface Support Classes ------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines several support classes for defining interfaces.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef MLIR_SUPPORT_INTERFACESUPPORT_H
14 #define MLIR_SUPPORT_INTERFACESUPPORT_H
15 
#include "mlir/Support/TypeID.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/Support/TypeName.h"
#include <cstdlib>
#include <new>
#include <tuple>
#include <type_traits>
#include <utility>
20 
21 namespace mlir {
22 namespace detail {
23 //===----------------------------------------------------------------------===//
24 // Interface
25 //===----------------------------------------------------------------------===//
26 
27 /// This class represents an abstract interface. An interface is a simplified
28 /// mechanism for attaching concept based polymorphism to a class hierarchy. An
29 /// interface is comprised of two components:
30 /// * The derived interface class: This is what users interact with, and invoke
31 ///   methods on.
32 /// * An interface `Trait` class: This is the class that is attached to the
33 ///   object implementing the interface. It is the mechanism with which models
34 ///   are specialized.
35 ///
36 /// Derived interfaces types must provide the following template types:
37 /// * ConcreteType: The CRTP derived type.
38 /// * ValueT: The opaque type the derived interface operates on. For example
39 ///           `Operation*` for operation interfaces, or `Attribute` for
40 ///           attribute interfaces.
41 /// * Traits: A class that contains definitions for a 'Concept' and a 'Model'
42 ///           class. The 'Concept' class defines an abstract virtual interface,
43 ///           where as the 'Model' class implements this interface for a
44 ///           specific derived T type. Both of these classes *must* not contain
45 ///           non-static data. A simple example is shown below:
46 ///
47 /// ```c++
48 ///    struct ExampleInterfaceTraits {
49 ///      struct Concept {
50 ///        virtual unsigned getNumInputs(T t) const = 0;
51 ///      };
52 ///      template <typename DerivedT> class Model {
53 ///        unsigned getNumInputs(T t) const final {
54 ///          return cast<DerivedT>(t).getNumInputs();
55 ///        }
56 ///      };
57 ///    };
58 /// ```
59 ///
60 /// * BaseType: A desired base type for the interface. This is a class that
61 ///             provides that provides specific functionality for the `ValueT`
62 ///             value. For instance the specific `Op` that will wrap the
63 ///             `Operation*` for an `OpInterface`.
64 /// * BaseTrait: The base type for the interface trait. This is the base class
65 ///              to use for the interface trait that will be attached to each
66 ///              instance of `ValueT` that implements this interface.
67 ///
68 template <typename ConcreteType, typename ValueT, typename Traits,
69           typename BaseType,
70           template <typename, template <typename> class> class BaseTrait>
71 class Interface : public BaseType {
72 public:
73   using Concept = typename Traits::Concept;
74   template <typename T> using Model = typename Traits::template Model<T>;
75   template <typename T>
76   using FallbackModel = typename Traits::template FallbackModel<T>;
77   using InterfaceBase =
78       Interface<ConcreteType, ValueT, Traits, BaseType, BaseTrait>;
79   template <typename T, typename U>
80   using ExternalModel = typename Traits::template ExternalModel<T, U>;
81 
82   /// This is a special trait that registers a given interface with an object.
83   template <typename ConcreteT>
84   struct Trait : public BaseTrait<ConcreteT, Trait> {
85     using ModelT = Model<ConcreteT>;
86 
87     /// Define an accessor for the ID of this interface.
getInterfaceIDTrait88     static TypeID getInterfaceID() { return TypeID::get<ConcreteType>(); }
89   };
90 
91   /// Construct an interface from an instance of the value type.
92   Interface(ValueT t = ValueT())
BaseType(t)93       : BaseType(t), impl(t ? ConcreteType::getInterfaceFor(t) : nullptr) {
94     assert((!t || impl) && "expected value to provide interface instance");
95   }
Interface(std::nullptr_t)96   Interface(std::nullptr_t) : BaseType(ValueT()), impl(nullptr) {}
97 
98   /// Construct an interface instance from a type that implements this
99   /// interface's trait.
100   template <typename T, typename std::enable_if_t<
101                             std::is_base_of<Trait<T>, T>::value> * = nullptr>
Interface(T t)102   Interface(T t)
103       : BaseType(t), impl(t ? ConcreteType::getInterfaceFor(t) : nullptr) {
104     assert((!t || impl) && "expected value to provide interface instance");
105   }
106 
107   /// Support 'classof' by checking if the given object defines the concrete
108   /// interface.
classof(ValueT t)109   static bool classof(ValueT t) { return ConcreteType::getInterfaceFor(t); }
110 
111   /// Define an accessor for the ID of this interface.
getInterfaceID()112   static TypeID getInterfaceID() { return TypeID::get<ConcreteType>(); }
113 
114 protected:
115   /// Get the raw concept in the correct derived concept type.
getImpl()116   const Concept *getImpl() const { return impl; }
getImpl()117   Concept *getImpl() { return impl; }
118 
119 private:
120   /// A pointer to the impl concept object.
121   Concept *impl;
122 };
123 
124 //===----------------------------------------------------------------------===//
125 // InterfaceMap
126 //===----------------------------------------------------------------------===//
127 
/// Helper for `FilterTypes`: `FilterTypeT<Keep>::type<E>` is a one-element
/// tuple `std::tuple<E>` when `Keep` is true, and the empty tuple otherwise.
template <bool Keep>
struct FilterTypeT {
  template <class E>
  using type = std::tuple<>;
};
template <>
struct FilterTypeT<true> {
  template <class E>
  using type = std::tuple<E>;
};

/// Utility to filter a given sequence of types based upon a predicate.
/// `FilterTypes<Pred, Es...>::type` is a `std::tuple` of exactly those `Es`
/// for which `Pred<E>::value` holds, in their original order.
template <template <class> class Pred, class... Es>
struct FilterTypes {
private:
  // Wrap each candidate into either a singleton or an empty tuple ...
  template <class E>
  using Wrapped = typename FilterTypeT<Pred<E>::value>::template type<E>;

public:
  // ... and concatenate the wrappers to obtain the filtered sequence.
  using type = decltype(std::tuple_cat(std::declval<Wrapped<Es>>()...));
};
145 
/// Type trait indicating whether all template arguments are
/// trivially-destructible. `value` is true for an empty argument list.
///
/// This trait is declared directly in this namespace rather than in an
/// anonymous namespace: an unnamed namespace in a header gives every
/// translation unit its own distinct copy, and the inline
/// `InterfaceMap::insert` template that uses it would then refer to a
/// different entity in each translation unit (an ODR hazard).
template <typename... Args>
struct all_trivially_destructible;

template <typename Arg, typename... Args>
struct all_trivially_destructible<Arg, Args...> {
  static constexpr const bool value =
      std::is_trivially_destructible<Arg>::value &&
      all_trivially_destructible<Args...>::value;
};

template <>
struct all_trivially_destructible<> {
  static constexpr const bool value = true;
};
164 
165 /// This class provides an efficient mapping between a given `Interface` type,
166 /// and a particular implementation of its concept.
167 class InterfaceMap {
168   /// Trait to check if T provides a static 'getInterfaceID' method.
169   template <typename T, typename... Args>
170   using has_get_interface_id = decltype(T::getInterfaceID());
171   template <typename T>
172   using detect_get_interface_id = llvm::is_detected<has_get_interface_id, T>;
173   template <typename... Types>
174   using num_interface_types = typename std::tuple_size<
175       typename FilterTypes<detect_get_interface_id, Types...>::type>;
176 
177 public:
178   InterfaceMap(InterfaceMap &&) = default;
179   ~InterfaceMap() {
180     for (auto &it : interfaces)
181       free(it.second);
182   }
183 
184   /// Construct an InterfaceMap with the given set of template types. For
185   /// convenience given that object trait lists may contain other non-interface
186   /// types, not all of the types need to be interfaces. The provided types that
187   /// do not represent interfaces are not added to the interface map.
188   template <typename... Types>
189   static std::enable_if_t<num_interface_types<Types...>::value != 0,
190                           InterfaceMap>
191   get() {
192     // Filter the provided types for those that are interfaces.
193     using FilteredTupleType =
194         typename FilterTypes<detect_get_interface_id, Types...>::type;
195     return getImpl((FilteredTupleType *)nullptr);
196   }
197 
198   template <typename... Types>
199   static std::enable_if_t<num_interface_types<Types...>::value == 0,
200                           InterfaceMap>
201   get() {
202     return InterfaceMap();
203   }
204 
205   /// Returns an instance of the concept object for the given interface if it
206   /// was registered to this map, null otherwise.
207   template <typename T> typename T::Concept *lookup() const {
208     return reinterpret_cast<typename T::Concept *>(lookup(T::getInterfaceID()));
209   }
210 
211   /// Returns true if the interface map contains an interface for the given id.
212   bool contains(TypeID interfaceID) const { return lookup(interfaceID); }
213 
214   /// Create an InterfaceMap given with the implementation of the interfaces.
215   /// The use of this constructor is in general discouraged in favor of
216   /// 'InterfaceMap::get<InterfaceA, ...>()'.
217   InterfaceMap(MutableArrayRef<std::pair<TypeID, void *>> elements)
218       : interfaces(elements.begin(), elements.end()) {
219     llvm::sort(interfaces, [](const auto &lhs, const auto &rhs) {
220       return compare(lhs.first, rhs.first);
221     });
222   }
223 
224   /// Insert the given models as implementations of the corresponding interfaces
225   /// for the concrete attribute class.
226   template <typename... IfaceModels>
227   void insert() {
228     static_assert(all_trivially_destructible<IfaceModels...>::value,
229                   "interface models must be trivially destructible");
230     std::pair<TypeID, void *> elements[] = {
231         std::make_pair(IfaceModels::Interface::getInterfaceID(),
232                        new (malloc(sizeof(IfaceModels))) IfaceModels())...};
233     insert(elements);
234   }
235 
236 private:
237   /// Compare two TypeID instances by comparing the underlying pointer.
238   static bool compare(TypeID lhs, TypeID rhs) {
239     return lhs.getAsOpaquePointer() < rhs.getAsOpaquePointer();
240   }
241 
242   InterfaceMap() = default;
243 
244   void insert(ArrayRef<std::pair<TypeID, void *>> elements);
245 
246   template <typename... Ts>
247   static InterfaceMap getImpl(std::tuple<Ts...> *) {
248     std::pair<TypeID, void *> elements[] = {std::make_pair(
249         Ts::getInterfaceID(),
250         new (malloc(sizeof(typename Ts::ModelT))) typename Ts::ModelT())...};
251     return InterfaceMap(elements);
252   }
253 
254   /// Returns an instance of the concept object for the given interface id if it
255   /// was registered to this map, null otherwise.
256   void *lookup(TypeID id) const {
257     auto it = llvm::lower_bound(interfaces, id, [](const auto &it, TypeID id) {
258       return compare(it.first, id);
259     });
260     return (it != interfaces.end() && it->first == id) ? it->second : nullptr;
261   }
262 
263   /// A list of interface instances, sorted by TypeID.
264   SmallVector<std::pair<TypeID, void *>> interfaces;
265 };
266 
267 } // end namespace detail
268 } // end namespace mlir
269 
270 #endif
271