//===- FunctionExtras.h - Function type erasure utilities -------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
/// \file
/// This file provides a collection of function (or more generally, callable)
/// type erasure utilities supplementing those provided by the standard library
/// in `<function>`.
///
/// It provides `unique_function`, which works like `std::function` but supports
/// move-only callable objects and const-qualification.
///
/// Future plans:
/// - Add a `function` that provides ref-qualified support, which doesn't work
///   with `std::function`.
/// - Provide support for specifying multiple signatures to type erase callable
///   objects with an overload set, such as those produced by generic lambdas.
/// - Expand to include a copyable utility that directly replaces std::function
///   but brings the above improvements.
///
/// Note that LLVM's utilities are greatly simplified by not supporting
/// allocators.
///
/// If the standard library ever begins to provide comparable facilities we can
/// consider switching to those.
///
//===----------------------------------------------------------------------===//

#ifndef LLVM_ADT_FUNCTION_EXTRAS_H
#define LLVM_ADT_FUNCTION_EXTRAS_H

#include "llvm/ADT/PointerIntPair.h"
#include "llvm/ADT/PointerUnion.h"
#include "llvm/Support/MemAlloc.h"
#include "llvm/Support/type_traits.h"
#include <cstring>
#include <memory>
#include <type_traits>

namespace llvm {

/// unique_function is a type-erasing functor similar to std::function.
///
/// It can hold move-only function objects, like lambdas capturing unique_ptrs.
47 /// Accordingly, it is movable but not copyable. 48 /// 49 /// It supports const-qualification: 50 /// - unique_function<int() const> has a const operator(). 51 /// It can only hold functions which themselves have a const operator(). 52 /// - unique_function<int()> has a non-const operator(). 53 /// It can hold functions with a non-const operator(), like mutable lambdas. 54 template <typename FunctionT> class unique_function; 55 56 namespace detail { 57 58 template <typename T> 59 using EnableIfTrivial = 60 std::enable_if_t<llvm::is_trivially_move_constructible<T>::value && 61 std::is_trivially_destructible<T>::value>; 62 63 template <typename ReturnT, typename... ParamTs> class UniqueFunctionBase { 64 protected: 65 static constexpr size_t InlineStorageSize = sizeof(void *) * 3; 66 67 // MSVC has a bug and ICEs if we give it a particular dependent value 68 // expression as part of the `std::conditional` below. To work around this, 69 // we build that into a template struct's constexpr bool. 70 template <typename T> struct IsSizeLessThanThresholdT { 71 static constexpr bool value = sizeof(T) <= (2 * sizeof(void *)); 72 }; 73 74 // Provide a type function to map parameters that won't observe extra copies 75 // or moves and which are small enough to likely pass in register to values 76 // and all other types to l-value reference types. We use this to compute the 77 // types used in our erased call utility to minimize copies and moves unless 78 // doing so would force things unnecessarily into memory. 79 // 80 // The heuristic used is related to common ABI register passing conventions. 81 // It doesn't have to be exact though, and in one way it is more strict 82 // because we want to still be able to observe either moves *or* copies. 
83 template <typename T> 84 using AdjustedParamT = typename std::conditional< 85 !std::is_reference<T>::value && 86 llvm::is_trivially_copy_constructible<T>::value && 87 llvm::is_trivially_move_constructible<T>::value && 88 IsSizeLessThanThresholdT<T>::value, 89 T, T &>::type; 90 91 // The type of the erased function pointer we use as a callback to dispatch to 92 // the stored callable when it is trivial to move and destroy. 93 using CallPtrT = ReturnT (*)(void *CallableAddr, 94 AdjustedParamT<ParamTs>... Params); 95 using MovePtrT = void (*)(void *LHSCallableAddr, void *RHSCallableAddr); 96 using DestroyPtrT = void (*)(void *CallableAddr); 97 98 /// A struct to hold a single trivial callback with sufficient alignment for 99 /// our bitpacking. 100 struct alignas(8) TrivialCallback { 101 CallPtrT CallPtr; 102 }; 103 104 /// A struct we use to aggregate three callbacks when we need full set of 105 /// operations. 106 struct alignas(8) NonTrivialCallbacks { 107 CallPtrT CallPtr; 108 MovePtrT MovePtr; 109 DestroyPtrT DestroyPtr; 110 }; 111 112 // Create a pointer union between either a pointer to a static trivial call 113 // pointer in a struct or a pointer to a static struct of the call, move, and 114 // destroy pointers. 115 using CallbackPointerUnionT = 116 PointerUnion<TrivialCallback *, NonTrivialCallbacks *>; 117 118 // The main storage buffer. This will either have a pointer to out-of-line 119 // storage or an inline buffer storing the callable. 120 union StorageUnionT { 121 // For out-of-line storage we keep a pointer to the underlying storage and 122 // the size. This is enough to deallocate the memory. 123 struct OutOfLineStorageT { 124 void *StoragePtr; 125 size_t Size; 126 size_t Alignment; 127 } OutOfLineStorage; 128 static_assert( 129 sizeof(OutOfLineStorageT) <= InlineStorageSize, 130 "Should always use all of the out-of-line storage for inline storage!"); 131 132 // For in-line storage, we just provide an aligned character buffer. 
We 133 // provide three pointers worth of storage here. 134 // This is mutable as an inlined `const unique_function<void() const>` may 135 // still modify its own mutable members. 136 mutable 137 typename std::aligned_storage<InlineStorageSize, alignof(void *)>::type 138 InlineStorage; 139 } StorageUnion; 140 141 // A compressed pointer to either our dispatching callback or our table of 142 // dispatching callbacks and the flag for whether the callable itself is 143 // stored inline or not. 144 PointerIntPair<CallbackPointerUnionT, 1, bool> CallbackAndInlineFlag; 145 146 bool isInlineStorage() const { return CallbackAndInlineFlag.getInt(); } 147 148 bool isTrivialCallback() const { 149 return CallbackAndInlineFlag.getPointer().template is<TrivialCallback *>(); 150 } 151 152 CallPtrT getTrivialCallback() const { 153 return CallbackAndInlineFlag.getPointer().template get<TrivialCallback *>()->CallPtr; 154 } 155 156 NonTrivialCallbacks *getNonTrivialCallbacks() const { 157 return CallbackAndInlineFlag.getPointer() 158 .template get<NonTrivialCallbacks *>(); 159 } 160 161 CallPtrT getCallPtr() const { 162 return isTrivialCallback() ? getTrivialCallback() 163 : getNonTrivialCallbacks()->CallPtr; 164 } 165 166 // These three functions are only const in the narrow sense. They return 167 // mutable pointers to function state. 168 // This allows unique_function<T const>::operator() to be const, even if the 169 // underlying functor may be internally mutable. 170 // 171 // const callers must ensure they're only used in const-correct ways. 172 void *getCalleePtr() const { 173 return isInlineStorage() ? 
getInlineStorage() : getOutOfLineStorage(); 174 } 175 void *getInlineStorage() const { return &StorageUnion.InlineStorage; } 176 void *getOutOfLineStorage() const { 177 return StorageUnion.OutOfLineStorage.StoragePtr; 178 } 179 180 size_t getOutOfLineStorageSize() const { 181 return StorageUnion.OutOfLineStorage.Size; 182 } 183 size_t getOutOfLineStorageAlignment() const { 184 return StorageUnion.OutOfLineStorage.Alignment; 185 } 186 187 void setOutOfLineStorage(void *Ptr, size_t Size, size_t Alignment) { 188 StorageUnion.OutOfLineStorage = {Ptr, Size, Alignment}; 189 } 190 191 template <typename CalledAsT> 192 static ReturnT CallImpl(void *CallableAddr, 193 AdjustedParamT<ParamTs>... Params) { 194 auto &Func = *reinterpret_cast<CalledAsT *>(CallableAddr); 195 return Func(std::forward<ParamTs>(Params)...); 196 } 197 198 template <typename CallableT> 199 static void MoveImpl(void *LHSCallableAddr, void *RHSCallableAddr) noexcept { 200 new (LHSCallableAddr) 201 CallableT(std::move(*reinterpret_cast<CallableT *>(RHSCallableAddr))); 202 } 203 204 template <typename CallableT> 205 static void DestroyImpl(void *CallableAddr) noexcept { 206 reinterpret_cast<CallableT *>(CallableAddr)->~CallableT(); 207 } 208 209 // The pointers to call/move/destroy functions are determined for each 210 // callable type (and called-as type, which determines the overload chosen). 211 // (definitions are out-of-line). 212 213 // By default, we need an object that contains all the different 214 // type erased behaviors needed. Create a static instance of the struct type 215 // here and each instance will contain a pointer to it. 216 // Wrap in a struct to avoid https://gcc.gnu.org/PR71954 217 template <typename CallableT, typename CalledAs, typename Enable = void> 218 struct CallbacksHolder { 219 static NonTrivialCallbacks Callbacks; 220 }; 221 // See if we can create a trivial callback. 
We need the callable to be 222 // trivially moved and trivially destroyed so that we don't have to store 223 // type erased callbacks for those operations. 224 template <typename CallableT, typename CalledAs> 225 struct CallbacksHolder<CallableT, CalledAs, EnableIfTrivial<CallableT>> { 226 static TrivialCallback Callbacks; 227 }; 228 229 // A simple tag type so the call-as type to be passed to the constructor. 230 template <typename T> struct CalledAs {}; 231 232 // Essentially the "main" unique_function constructor, but subclasses 233 // provide the qualified type to be used for the call. 234 // (We always store a T, even if the call will use a pointer to const T). 235 template <typename CallableT, typename CalledAsT> 236 UniqueFunctionBase(CallableT Callable, CalledAs<CalledAsT>) { 237 bool IsInlineStorage = true; 238 void *CallableAddr = getInlineStorage(); 239 if (sizeof(CallableT) > InlineStorageSize || 240 alignof(CallableT) > alignof(decltype(StorageUnion.InlineStorage))) { 241 IsInlineStorage = false; 242 // Allocate out-of-line storage. FIXME: Use an explicit alignment 243 // parameter in C++17 mode. 244 auto Size = sizeof(CallableT); 245 auto Alignment = alignof(CallableT); 246 CallableAddr = allocate_buffer(Size, Alignment); 247 setOutOfLineStorage(CallableAddr, Size, Alignment); 248 } 249 250 // Now move into the storage. 251 new (CallableAddr) CallableT(std::move(Callable)); 252 CallbackAndInlineFlag.setPointerAndInt( 253 &CallbacksHolder<CallableT, CalledAsT>::Callbacks, IsInlineStorage); 254 } 255 256 ~UniqueFunctionBase() { 257 if (!CallbackAndInlineFlag.getPointer()) 258 return; 259 260 // Cache this value so we don't re-check it after type-erased operations. 261 bool IsInlineStorage = isInlineStorage(); 262 263 if (!isTrivialCallback()) 264 getNonTrivialCallbacks()->DestroyPtr( 265 IsInlineStorage ? 
getInlineStorage() : getOutOfLineStorage()); 266 267 if (!IsInlineStorage) 268 deallocate_buffer(getOutOfLineStorage(), getOutOfLineStorageSize(), 269 getOutOfLineStorageAlignment()); 270 } 271 272 UniqueFunctionBase(UniqueFunctionBase &&RHS) noexcept { 273 // Copy the callback and inline flag. 274 CallbackAndInlineFlag = RHS.CallbackAndInlineFlag; 275 276 // If the RHS is empty, just copying the above is sufficient. 277 if (!RHS) 278 return; 279 280 if (!isInlineStorage()) { 281 // The out-of-line case is easiest to move. 282 StorageUnion.OutOfLineStorage = RHS.StorageUnion.OutOfLineStorage; 283 } else if (isTrivialCallback()) { 284 // Move is trivial, just memcpy the bytes across. 285 memcpy(getInlineStorage(), RHS.getInlineStorage(), InlineStorageSize); 286 } else { 287 // Non-trivial move, so dispatch to a type-erased implementation. 288 getNonTrivialCallbacks()->MovePtr(getInlineStorage(), 289 RHS.getInlineStorage()); 290 } 291 292 // Clear the old callback and inline flag to get back to as-if-null. 293 RHS.CallbackAndInlineFlag = {}; 294 295 #ifndef NDEBUG 296 // In debug builds, we also scribble across the rest of the storage. 297 memset(RHS.getInlineStorage(), 0xAD, InlineStorageSize); 298 #endif 299 } 300 301 UniqueFunctionBase &operator=(UniqueFunctionBase &&RHS) noexcept { 302 if (this == &RHS) 303 return *this; 304 305 // Because we don't try to provide any exception safety guarantees we can 306 // implement move assignment very simply by first destroying the current 307 // object and then move-constructing over top of it. 308 this->~UniqueFunctionBase(); 309 new (this) UniqueFunctionBase(std::move(RHS)); 310 return *this; 311 } 312 313 UniqueFunctionBase() = default; 314 315 public: 316 explicit operator bool() const { 317 return (bool)CallbackAndInlineFlag.getPointer(); 318 } 319 }; 320 321 template <typename R, typename... 
P> 322 template <typename CallableT, typename CalledAsT, typename Enable> 323 typename UniqueFunctionBase<R, P...>::NonTrivialCallbacks UniqueFunctionBase< 324 R, P...>::CallbacksHolder<CallableT, CalledAsT, Enable>::Callbacks = { 325 &CallImpl<CalledAsT>, &MoveImpl<CallableT>, &DestroyImpl<CallableT>}; 326 327 template <typename R, typename... P> 328 template <typename CallableT, typename CalledAsT> 329 typename UniqueFunctionBase<R, P...>::TrivialCallback 330 UniqueFunctionBase<R, P...>::CallbacksHolder< 331 CallableT, CalledAsT, EnableIfTrivial<CallableT>>::Callbacks{ 332 &CallImpl<CalledAsT>}; 333 334 } // namespace detail 335 336 template <typename R, typename... P> 337 class unique_function<R(P...)> : public detail::UniqueFunctionBase<R, P...> { 338 using Base = detail::UniqueFunctionBase<R, P...>; 339 340 public: 341 unique_function() = default; 342 unique_function(std::nullptr_t) {} 343 unique_function(unique_function &&) = default; 344 unique_function(const unique_function &) = delete; 345 unique_function &operator=(unique_function &&) = default; 346 unique_function &operator=(const unique_function &) = delete; 347 348 template <typename CallableT> 349 unique_function(CallableT Callable) 350 : Base(std::forward<CallableT>(Callable), 351 typename Base::template CalledAs<CallableT>{}) {} 352 353 R operator()(P... Params) { 354 return this->getCallPtr()(this->getCalleePtr(), Params...); 355 } 356 }; 357 358 template <typename R, typename... P> 359 class unique_function<R(P...) 
const> 360 : public detail::UniqueFunctionBase<R, P...> { 361 using Base = detail::UniqueFunctionBase<R, P...>; 362 363 public: 364 unique_function() = default; 365 unique_function(std::nullptr_t) {} 366 unique_function(unique_function &&) = default; 367 unique_function(const unique_function &) = delete; 368 unique_function &operator=(unique_function &&) = default; 369 unique_function &operator=(const unique_function &) = delete; 370 371 template <typename CallableT> 372 unique_function(CallableT Callable) 373 : Base(std::forward<CallableT>(Callable), 374 typename Base::template CalledAs<const CallableT>{}) {} 375 376 R operator()(P... Params) const { 377 return this->getCallPtr()(this->getCalleePtr(), Params...); 378 } 379 }; 380 381 } // end namespace llvm 382 383 #endif // LLVM_ADT_FUNCTION_H 384