1 //===- FunctionExtras.h - Function type erasure utilities -------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 /// \file
9 /// This file provides a collection of function (or more generally, callable)
10 /// type erasure utilities supplementing those provided by the standard library
/// in `<functional>`.
12 ///
13 /// It provides `unique_function`, which works like `std::function` but supports
14 /// move-only callable objects and const-qualification.
15 ///
16 /// Future plans:
17 /// - Add a `function` that provides ref-qualified support, which doesn't work
18 ///   with `std::function`.
19 /// - Provide support for specifying multiple signatures to type erase callable
20 ///   objects with an overload set, such as those produced by generic lambdas.
21 /// - Expand to include a copyable utility that directly replaces std::function
22 ///   but brings the above improvements.
23 ///
24 /// Note that LLVM's utilities are greatly simplified by not supporting
25 /// allocators.
26 ///
27 /// If the standard library ever begins to provide comparable facilities we can
28 /// consider switching to those.
29 ///
30 //===----------------------------------------------------------------------===//
31 
32 #ifndef LLVM_ADT_FUNCTION_EXTRAS_H
33 #define LLVM_ADT_FUNCTION_EXTRAS_H
34 
#include "llvm/ADT/PointerIntPair.h"
#include "llvm/ADT/PointerUnion.h"
#include "llvm/Support/MemAlloc.h"
#include "llvm/Support/type_traits.h"
#include <cassert>
#include <cstring>
#include <memory>
#include <type_traits>
41 
42 namespace llvm {
43 
/// unique_function is a type-erasing callable wrapper in the spirit of
/// std::function, with two differences:
///
/// - It may own move-only callables (for instance, a lambda that captures a
///   std::unique_ptr). As a consequence, unique_function itself can be moved
///   but never copied.
/// - The signature may carry a const qualifier:
///   * unique_function<int() const> exposes a const operator() and accepts only
///     callables whose own operator() is const.
///   * unique_function<int()> exposes a non-const operator() and therefore also
///     accepts callables that mutate themselves, such as mutable lambdas.
///
/// Only the partial specializations for `R(P...)` and `R(P...) const` are
/// defined; the primary template is intentionally left incomplete.
template <class FunctionT> class unique_function;
55 
56 namespace detail {
57 
58 template <typename T>
59 using EnableIfTrivial =
60     std::enable_if_t<llvm::is_trivially_move_constructible<T>::value &&
61                      std::is_trivially_destructible<T>::value>;
62 
63 template <typename ReturnT, typename... ParamTs> class UniqueFunctionBase {
64 protected:
65   static constexpr size_t InlineStorageSize = sizeof(void *) * 3;
66 
67   // MSVC has a bug and ICEs if we give it a particular dependent value
68   // expression as part of the `std::conditional` below. To work around this,
69   // we build that into a template struct's constexpr bool.
70   template <typename T> struct IsSizeLessThanThresholdT {
71     static constexpr bool value = sizeof(T) <= (2 * sizeof(void *));
72   };
73 
74   // Provide a type function to map parameters that won't observe extra copies
75   // or moves and which are small enough to likely pass in register to values
76   // and all other types to l-value reference types. We use this to compute the
77   // types used in our erased call utility to minimize copies and moves unless
78   // doing so would force things unnecessarily into memory.
79   //
80   // The heuristic used is related to common ABI register passing conventions.
81   // It doesn't have to be exact though, and in one way it is more strict
82   // because we want to still be able to observe either moves *or* copies.
83   template <typename T>
84   using AdjustedParamT = typename std::conditional<
85       !std::is_reference<T>::value &&
86           llvm::is_trivially_copy_constructible<T>::value &&
87           llvm::is_trivially_move_constructible<T>::value &&
88           IsSizeLessThanThresholdT<T>::value,
89       T, T &>::type;
90 
91   // The type of the erased function pointer we use as a callback to dispatch to
92   // the stored callable when it is trivial to move and destroy.
93   using CallPtrT = ReturnT (*)(void *CallableAddr,
94                                AdjustedParamT<ParamTs>... Params);
95   using MovePtrT = void (*)(void *LHSCallableAddr, void *RHSCallableAddr);
96   using DestroyPtrT = void (*)(void *CallableAddr);
97 
98   /// A struct to hold a single trivial callback with sufficient alignment for
99   /// our bitpacking.
100   struct alignas(8) TrivialCallback {
101     CallPtrT CallPtr;
102   };
103 
104   /// A struct we use to aggregate three callbacks when we need full set of
105   /// operations.
106   struct alignas(8) NonTrivialCallbacks {
107     CallPtrT CallPtr;
108     MovePtrT MovePtr;
109     DestroyPtrT DestroyPtr;
110   };
111 
112   // Create a pointer union between either a pointer to a static trivial call
113   // pointer in a struct or a pointer to a static struct of the call, move, and
114   // destroy pointers.
115   using CallbackPointerUnionT =
116       PointerUnion<TrivialCallback *, NonTrivialCallbacks *>;
117 
118   // The main storage buffer. This will either have a pointer to out-of-line
119   // storage or an inline buffer storing the callable.
120   union StorageUnionT {
121     // For out-of-line storage we keep a pointer to the underlying storage and
122     // the size. This is enough to deallocate the memory.
123     struct OutOfLineStorageT {
124       void *StoragePtr;
125       size_t Size;
126       size_t Alignment;
127     } OutOfLineStorage;
128     static_assert(
129         sizeof(OutOfLineStorageT) <= InlineStorageSize,
130         "Should always use all of the out-of-line storage for inline storage!");
131 
132     // For in-line storage, we just provide an aligned character buffer. We
133     // provide three pointers worth of storage here.
134     // This is mutable as an inlined `const unique_function<void() const>` may
135     // still modify its own mutable members.
136     mutable
137         typename std::aligned_storage<InlineStorageSize, alignof(void *)>::type
138             InlineStorage;
139   } StorageUnion;
140 
141   // A compressed pointer to either our dispatching callback or our table of
142   // dispatching callbacks and the flag for whether the callable itself is
143   // stored inline or not.
144   PointerIntPair<CallbackPointerUnionT, 1, bool> CallbackAndInlineFlag;
145 
146   bool isInlineStorage() const { return CallbackAndInlineFlag.getInt(); }
147 
148   bool isTrivialCallback() const {
149     return CallbackAndInlineFlag.getPointer().template is<TrivialCallback *>();
150   }
151 
152   CallPtrT getTrivialCallback() const {
153     return CallbackAndInlineFlag.getPointer().template get<TrivialCallback *>()->CallPtr;
154   }
155 
156   NonTrivialCallbacks *getNonTrivialCallbacks() const {
157     return CallbackAndInlineFlag.getPointer()
158         .template get<NonTrivialCallbacks *>();
159   }
160 
161   CallPtrT getCallPtr() const {
162     return isTrivialCallback() ? getTrivialCallback()
163                                : getNonTrivialCallbacks()->CallPtr;
164   }
165 
166   // These three functions are only const in the narrow sense. They return
167   // mutable pointers to function state.
168   // This allows unique_function<T const>::operator() to be const, even if the
169   // underlying functor may be internally mutable.
170   //
171   // const callers must ensure they're only used in const-correct ways.
172   void *getCalleePtr() const {
173     return isInlineStorage() ? getInlineStorage() : getOutOfLineStorage();
174   }
175   void *getInlineStorage() const { return &StorageUnion.InlineStorage; }
176   void *getOutOfLineStorage() const {
177     return StorageUnion.OutOfLineStorage.StoragePtr;
178   }
179 
180   size_t getOutOfLineStorageSize() const {
181     return StorageUnion.OutOfLineStorage.Size;
182   }
183   size_t getOutOfLineStorageAlignment() const {
184     return StorageUnion.OutOfLineStorage.Alignment;
185   }
186 
187   void setOutOfLineStorage(void *Ptr, size_t Size, size_t Alignment) {
188     StorageUnion.OutOfLineStorage = {Ptr, Size, Alignment};
189   }
190 
191   template <typename CalledAsT>
192   static ReturnT CallImpl(void *CallableAddr,
193                           AdjustedParamT<ParamTs>... Params) {
194     auto &Func = *reinterpret_cast<CalledAsT *>(CallableAddr);
195     return Func(std::forward<ParamTs>(Params)...);
196   }
197 
198   template <typename CallableT>
199   static void MoveImpl(void *LHSCallableAddr, void *RHSCallableAddr) noexcept {
200     new (LHSCallableAddr)
201         CallableT(std::move(*reinterpret_cast<CallableT *>(RHSCallableAddr)));
202   }
203 
204   template <typename CallableT>
205   static void DestroyImpl(void *CallableAddr) noexcept {
206     reinterpret_cast<CallableT *>(CallableAddr)->~CallableT();
207   }
208 
209   // The pointers to call/move/destroy functions are determined for each
210   // callable type (and called-as type, which determines the overload chosen).
211   // (definitions are out-of-line).
212 
213   // By default, we need an object that contains all the different
214   // type erased behaviors needed. Create a static instance of the struct type
215   // here and each instance will contain a pointer to it.
216   // Wrap in a struct to avoid https://gcc.gnu.org/PR71954
217   template <typename CallableT, typename CalledAs, typename Enable = void>
218   struct CallbacksHolder {
219     static NonTrivialCallbacks Callbacks;
220   };
221   // See if we can create a trivial callback. We need the callable to be
222   // trivially moved and trivially destroyed so that we don't have to store
223   // type erased callbacks for those operations.
224   template <typename CallableT, typename CalledAs>
225   struct CallbacksHolder<CallableT, CalledAs, EnableIfTrivial<CallableT>> {
226     static TrivialCallback Callbacks;
227   };
228 
229   // A simple tag type so the call-as type to be passed to the constructor.
230   template <typename T> struct CalledAs {};
231 
232   // Essentially the "main" unique_function constructor, but subclasses
233   // provide the qualified type to be used for the call.
234   // (We always store a T, even if the call will use a pointer to const T).
235   template <typename CallableT, typename CalledAsT>
236   UniqueFunctionBase(CallableT Callable, CalledAs<CalledAsT>) {
237     bool IsInlineStorage = true;
238     void *CallableAddr = getInlineStorage();
239     if (sizeof(CallableT) > InlineStorageSize ||
240         alignof(CallableT) > alignof(decltype(StorageUnion.InlineStorage))) {
241       IsInlineStorage = false;
242       // Allocate out-of-line storage. FIXME: Use an explicit alignment
243       // parameter in C++17 mode.
244       auto Size = sizeof(CallableT);
245       auto Alignment = alignof(CallableT);
246       CallableAddr = allocate_buffer(Size, Alignment);
247       setOutOfLineStorage(CallableAddr, Size, Alignment);
248     }
249 
250     // Now move into the storage.
251     new (CallableAddr) CallableT(std::move(Callable));
252     CallbackAndInlineFlag.setPointerAndInt(
253         &CallbacksHolder<CallableT, CalledAsT>::Callbacks, IsInlineStorage);
254   }
255 
256   ~UniqueFunctionBase() {
257     if (!CallbackAndInlineFlag.getPointer())
258       return;
259 
260     // Cache this value so we don't re-check it after type-erased operations.
261     bool IsInlineStorage = isInlineStorage();
262 
263     if (!isTrivialCallback())
264       getNonTrivialCallbacks()->DestroyPtr(
265           IsInlineStorage ? getInlineStorage() : getOutOfLineStorage());
266 
267     if (!IsInlineStorage)
268       deallocate_buffer(getOutOfLineStorage(), getOutOfLineStorageSize(),
269                         getOutOfLineStorageAlignment());
270   }
271 
272   UniqueFunctionBase(UniqueFunctionBase &&RHS) noexcept {
273     // Copy the callback and inline flag.
274     CallbackAndInlineFlag = RHS.CallbackAndInlineFlag;
275 
276     // If the RHS is empty, just copying the above is sufficient.
277     if (!RHS)
278       return;
279 
280     if (!isInlineStorage()) {
281       // The out-of-line case is easiest to move.
282       StorageUnion.OutOfLineStorage = RHS.StorageUnion.OutOfLineStorage;
283     } else if (isTrivialCallback()) {
284       // Move is trivial, just memcpy the bytes across.
285       memcpy(getInlineStorage(), RHS.getInlineStorage(), InlineStorageSize);
286     } else {
287       // Non-trivial move, so dispatch to a type-erased implementation.
288       getNonTrivialCallbacks()->MovePtr(getInlineStorage(),
289                                         RHS.getInlineStorage());
290     }
291 
292     // Clear the old callback and inline flag to get back to as-if-null.
293     RHS.CallbackAndInlineFlag = {};
294 
295 #ifndef NDEBUG
296     // In debug builds, we also scribble across the rest of the storage.
297     memset(RHS.getInlineStorage(), 0xAD, InlineStorageSize);
298 #endif
299   }
300 
301   UniqueFunctionBase &operator=(UniqueFunctionBase &&RHS) noexcept {
302     if (this == &RHS)
303       return *this;
304 
305     // Because we don't try to provide any exception safety guarantees we can
306     // implement move assignment very simply by first destroying the current
307     // object and then move-constructing over top of it.
308     this->~UniqueFunctionBase();
309     new (this) UniqueFunctionBase(std::move(RHS));
310     return *this;
311   }
312 
313   UniqueFunctionBase() = default;
314 
315 public:
316   explicit operator bool() const {
317     return (bool)CallbackAndInlineFlag.getPointer();
318   }
319 };
320 
// Out-of-line definition of the static callback table for the primary
// CallbacksHolder template, i.e. the case not matched by the EnableIfTrivial
// partial specialization. One table exists per (CallableT, CalledAsT) pair.
// Calls go through CalledAsT (which may be const-qualified), while move and
// destroy always operate on the stored CallableT.
template <typename R, typename... P>
template <typename CallableT, typename CalledAsT, typename Enable>
typename UniqueFunctionBase<R, P...>::NonTrivialCallbacks UniqueFunctionBase<
    R, P...>::CallbacksHolder<CallableT, CalledAsT, Enable>::Callbacks = {
    &CallImpl<CalledAsT>, &MoveImpl<CallableT>, &DestroyImpl<CallableT>};
326 
// Out-of-line definition of the static callback table for the CallbacksHolder
// partial specialization selected by EnableIfTrivial. Only the call entry
// point is recorded; the owning UniqueFunctionBase moves such callables with
// memcpy and skips destruction.
template <typename R, typename... P>
template <typename CallableT, typename CalledAsT>
typename UniqueFunctionBase<R, P...>::TrivialCallback
    UniqueFunctionBase<R, P...>::CallbacksHolder<
        CallableT, CalledAsT, EnableIfTrivial<CallableT>>::Callbacks{
        &CallImpl<CalledAsT>};
333 
334 } // namespace detail
335 
336 template <typename R, typename... P>
337 class unique_function<R(P...)> : public detail::UniqueFunctionBase<R, P...> {
338   using Base = detail::UniqueFunctionBase<R, P...>;
339 
340 public:
341   unique_function() = default;
342   unique_function(std::nullptr_t) {}
343   unique_function(unique_function &&) = default;
344   unique_function(const unique_function &) = delete;
345   unique_function &operator=(unique_function &&) = default;
346   unique_function &operator=(const unique_function &) = delete;
347 
348   template <typename CallableT>
349   unique_function(CallableT Callable)
350       : Base(std::forward<CallableT>(Callable),
351              typename Base::template CalledAs<CallableT>{}) {}
352 
353   R operator()(P... Params) {
354     return this->getCallPtr()(this->getCalleePtr(), Params...);
355   }
356 };
357 
358 template <typename R, typename... P>
359 class unique_function<R(P...) const>
360     : public detail::UniqueFunctionBase<R, P...> {
361   using Base = detail::UniqueFunctionBase<R, P...>;
362 
363 public:
364   unique_function() = default;
365   unique_function(std::nullptr_t) {}
366   unique_function(unique_function &&) = default;
367   unique_function(const unique_function &) = delete;
368   unique_function &operator=(unique_function &&) = default;
369   unique_function &operator=(const unique_function &) = delete;
370 
371   template <typename CallableT>
372   unique_function(CallableT Callable)
373       : Base(std::forward<CallableT>(Callable),
374              typename Base::template CalledAs<const CallableT>{}) {}
375 
376   R operator()(P... Params) const {
377     return this->getCallPtr()(this->getCalleePtr(), Params...);
378   }
379 };
380 
381 } // end namespace llvm
382 
#endif // LLVM_ADT_FUNCTION_EXTRAS_H
384