1 //===- InterfaceSupport.h - MLIR Interface Support Classes ------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines several support classes for defining interfaces. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #ifndef MLIR_SUPPORT_INTERFACESUPPORT_H 14 #define MLIR_SUPPORT_INTERFACESUPPORT_H 15 16 #include "mlir/Support/TypeID.h" 17 #include "llvm/ADT/ArrayRef.h" 18 #include "llvm/ADT/DenseMap.h" 19 #include "llvm/Support/TypeName.h" 20 21 namespace mlir { 22 namespace detail { 23 //===----------------------------------------------------------------------===// 24 // Interface 25 //===----------------------------------------------------------------------===// 26 27 /// This class represents an abstract interface. An interface is a simplified 28 /// mechanism for attaching concept based polymorphism to a class hierarchy. An 29 /// interface is comprised of two components: 30 /// * The derived interface class: This is what users interact with, and invoke 31 /// methods on. 32 /// * An interface `Trait` class: This is the class that is attached to the 33 /// object implementing the interface. It is the mechanism with which models 34 /// are specialized. 35 /// 36 /// Derived interfaces types must provide the following template types: 37 /// * ConcreteType: The CRTP derived type. 38 /// * ValueT: The opaque type the derived interface operates on. For example 39 /// `Operation*` for operation interfaces, or `Attribute` for 40 /// attribute interfaces. 41 /// * Traits: A class that contains definitions for a 'Concept' and a 'Model' 42 /// class. The 'Concept' class defines an abstract virtual interface, 43 /// where as the 'Model' class implements this interface for a 44 /// specific derived T type. Both of these classes *must* not contain 45 /// non-static data. A simple example is shown below: 46 /// 47 /// ```c++ 48 /// struct ExampleInterfaceTraits { 49 /// struct Concept { 50 /// virtual unsigned getNumInputs(T t) const = 0; 51 /// }; 52 /// template <typename DerivedT> class Model { 53 /// unsigned getNumInputs(T t) const final { 54 /// return cast<DerivedT>(t).getNumInputs(); 55 /// } 56 /// }; 57 /// }; 58 /// ``` 59 /// 60 /// * BaseType: A desired base type for the interface. This is a class that 61 /// provides that provides specific functionality for the `ValueT` 62 /// value. For instance the specific `Op` that will wrap the 63 /// `Operation*` for an `OpInterface`. 64 /// * BaseTrait: The base type for the interface trait. This is the base class 65 /// to use for the interface trait that will be attached to each 66 /// instance of `ValueT` that implements this interface. 67 /// 68 template <typename ConcreteType, typename ValueT, typename Traits, 69 typename BaseType, 70 template <typename, template <typename> class> class BaseTrait> 71 class Interface : public BaseType { 72 public: 73 using Concept = typename Traits::Concept; 74 template <typename T> using Model = typename Traits::template Model<T>; 75 template <typename T> 76 using FallbackModel = typename Traits::template FallbackModel<T>; 77 using InterfaceBase = 78 Interface<ConcreteType, ValueT, Traits, BaseType, BaseTrait>; 79 template <typename T, typename U> 80 using ExternalModel = typename Traits::template ExternalModel<T, U>; 81 82 /// This is a special trait that registers a given interface with an object. 83 template <typename ConcreteT> 84 struct Trait : public BaseTrait<ConcreteT, Trait> { 85 using ModelT = Model<ConcreteT>; 86 87 /// Define an accessor for the ID of this interface. getInterfaceIDTrait88 static TypeID getInterfaceID() { return TypeID::get<ConcreteType>(); } 89 }; 90 91 /// Construct an interface from an instance of the value type. 92 Interface(ValueT t = ValueT()) BaseType(t)93 : BaseType(t), impl(t ? ConcreteType::getInterfaceFor(t) : nullptr) { 94 assert((!t || impl) && "expected value to provide interface instance"); 95 } Interface(std::nullptr_t)96 Interface(std::nullptr_t) : BaseType(ValueT()), impl(nullptr) {} 97 98 /// Construct an interface instance from a type that implements this 99 /// interface's trait. 100 template <typename T, typename std::enable_if_t< 101 std::is_base_of<Trait<T>, T>::value> * = nullptr> Interface(T t)102 Interface(T t) 103 : BaseType(t), impl(t ? ConcreteType::getInterfaceFor(t) : nullptr) { 104 assert((!t || impl) && "expected value to provide interface instance"); 105 } 106 107 /// Support 'classof' by checking if the given object defines the concrete 108 /// interface. classof(ValueT t)109 static bool classof(ValueT t) { return ConcreteType::getInterfaceFor(t); } 110 111 /// Define an accessor for the ID of this interface. getInterfaceID()112 static TypeID getInterfaceID() { return TypeID::get<ConcreteType>(); } 113 114 protected: 115 /// Get the raw concept in the correct derived concept type. getImpl()116 const Concept *getImpl() const { return impl; } getImpl()117 Concept *getImpl() { return impl; } 118 119 private: 120 /// A pointer to the impl concept object. 121 Concept *impl; 122 }; 123 124 //===----------------------------------------------------------------------===// 125 // InterfaceMap 126 //===----------------------------------------------------------------------===// 127 128 /// Utility to filter a given sequence of types base upon a predicate. 129 template <bool> 130 struct FilterTypeT { 131 template <class E> 132 using type = std::tuple<E>; 133 }; 134 template <> 135 struct FilterTypeT<false> { 136 template <class E> 137 using type = std::tuple<>; 138 }; 139 template <template <class> class Pred, class... Es> 140 struct FilterTypes { 141 using type = decltype(std::tuple_cat( 142 std::declval< 143 typename FilterTypeT<Pred<Es>::value>::template type<Es>>()...)); 144 }; 145 146 namespace { 147 /// Type trait indicating whether all template arguments are 148 /// trivially-destructible. 149 template <typename... Args> 150 struct all_trivially_destructible; 151 152 template <typename Arg, typename... Args> 153 struct all_trivially_destructible<Arg, Args...> { 154 static constexpr const bool value = 155 std::is_trivially_destructible<Arg>::value && 156 all_trivially_destructible<Args...>::value; 157 }; 158 159 template <> 160 struct all_trivially_destructible<> { 161 static constexpr const bool value = true; 162 }; 163 } // namespace 164 165 /// This class provides an efficient mapping between a given `Interface` type, 166 /// and a particular implementation of its concept. 167 class InterfaceMap { 168 /// Trait to check if T provides a static 'getInterfaceID' method. 169 template <typename T, typename... Args> 170 using has_get_interface_id = decltype(T::getInterfaceID()); 171 template <typename T> 172 using detect_get_interface_id = llvm::is_detected<has_get_interface_id, T>; 173 template <typename... Types> 174 using num_interface_types = typename std::tuple_size< 175 typename FilterTypes<detect_get_interface_id, Types...>::type>; 176 177 public: 178 InterfaceMap(InterfaceMap &&) = default; 179 ~InterfaceMap() { 180 for (auto &it : interfaces) 181 free(it.second); 182 } 183 184 /// Construct an InterfaceMap with the given set of template types. For 185 /// convenience given that object trait lists may contain other non-interface 186 /// types, not all of the types need to be interfaces. The provided types that 187 /// do not represent interfaces are not added to the interface map. 188 template <typename... Types> 189 static std::enable_if_t<num_interface_types<Types...>::value != 0, 190 InterfaceMap> 191 get() { 192 // Filter the provided types for those that are interfaces. 193 using FilteredTupleType = 194 typename FilterTypes<detect_get_interface_id, Types...>::type; 195 return getImpl((FilteredTupleType *)nullptr); 196 } 197 198 template <typename... Types> 199 static std::enable_if_t<num_interface_types<Types...>::value == 0, 200 InterfaceMap> 201 get() { 202 return InterfaceMap(); 203 } 204 205 /// Returns an instance of the concept object for the given interface if it 206 /// was registered to this map, null otherwise. 207 template <typename T> typename T::Concept *lookup() const { 208 return reinterpret_cast<typename T::Concept *>(lookup(T::getInterfaceID())); 209 } 210 211 /// Returns true if the interface map contains an interface for the given id. 212 bool contains(TypeID interfaceID) const { return lookup(interfaceID); } 213 214 /// Create an InterfaceMap given with the implementation of the interfaces. 215 /// The use of this constructor is in general discouraged in favor of 216 /// 'InterfaceMap::get<InterfaceA, ...>()'. 217 InterfaceMap(MutableArrayRef<std::pair<TypeID, void *>> elements) 218 : interfaces(elements.begin(), elements.end()) { 219 llvm::sort(interfaces, [](const auto &lhs, const auto &rhs) { 220 return compare(lhs.first, rhs.first); 221 }); 222 } 223 224 /// Insert the given models as implementations of the corresponding interfaces 225 /// for the concrete attribute class. 226 template <typename... IfaceModels> 227 void insert() { 228 static_assert(all_trivially_destructible<IfaceModels...>::value, 229 "interface models must be trivially destructible"); 230 std::pair<TypeID, void *> elements[] = { 231 std::make_pair(IfaceModels::Interface::getInterfaceID(), 232 new (malloc(sizeof(IfaceModels))) IfaceModels())...}; 233 insert(elements); 234 } 235 236 private: 237 /// Compare two TypeID instances by comparing the underlying pointer. 238 static bool compare(TypeID lhs, TypeID rhs) { 239 return lhs.getAsOpaquePointer() < rhs.getAsOpaquePointer(); 240 } 241 242 InterfaceMap() = default; 243 244 void insert(ArrayRef<std::pair<TypeID, void *>> elements); 245 246 template <typename... Ts> 247 static InterfaceMap getImpl(std::tuple<Ts...> *) { 248 std::pair<TypeID, void *> elements[] = {std::make_pair( 249 Ts::getInterfaceID(), 250 new (malloc(sizeof(typename Ts::ModelT))) typename Ts::ModelT())...}; 251 return InterfaceMap(elements); 252 } 253 254 /// Returns an instance of the concept object for the given interface id if it 255 /// was registered to this map, null otherwise. 256 void *lookup(TypeID id) const { 257 auto it = llvm::lower_bound(interfaces, id, [](const auto &it, TypeID id) { 258 return compare(it.first, id); 259 }); 260 return (it != interfaces.end() && it->first == id) ? it->second : nullptr; 261 } 262 263 /// A list of interface instances, sorted by TypeID. 264 SmallVector<std::pair<TypeID, void *>> interfaces; 265 }; 266 267 } // end namespace detail 268 } // end namespace mlir 269 270 #endif 271