1 //===- llvm/ADT/PointerSumType.h --------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef LLVM_ADT_POINTERSUMTYPE_H
10 #define LLVM_ADT_POINTERSUMTYPE_H
11 
12 #include "llvm/ADT/bit.h"
13 #include "llvm/ADT/DenseMapInfo.h"
14 #include "llvm/Support/PointerLikeTypeTraits.h"
15 #include <cassert>
16 #include <cstdint>
17 #include <type_traits>
18 
19 namespace llvm {
20 
21 /// A compile time pair of an integer tag and the pointer-like type which it
22 /// indexes within a sum type. Also allows the user to specify a particular
23 /// traits class for pointer types with custom behavior such as over-aligned
24 /// allocation.
25 template <uintptr_t N, typename PointerArgT,
26           typename TraitsArgT = PointerLikeTypeTraits<PointerArgT>>
27 struct PointerSumTypeMember {
28   enum { Tag = N };
29   using PointerT = PointerArgT;
30   using TraitsT = TraitsArgT;
31 };
32 
namespace detail {

// Forward declaration only: PointerSumType names this helper in its body, and
// the full definition of the helper appears later in this header.
template <typename TagT, typename... MemberTs> struct PointerSumTypeHelper;

} // end namespace detail
38 
39 /// A sum type over pointer-like types.
40 ///
41 /// This is a normal tagged union across pointer-like types that uses the low
42 /// bits of the pointers to store the tag.
43 ///
44 /// Each member of the sum type is specified by passing a \c
45 /// PointerSumTypeMember specialization in the variadic member argument list.
46 /// This allows the user to control the particular tag value associated with
47 /// a particular type, use the same type for multiple different tags, and
48 /// customize the pointer-like traits used for a particular member. Note that
49 /// these *must* be specializations of \c PointerSumTypeMember, no other type
50 /// will suffice, even if it provides a compatible interface.
51 ///
52 /// This type implements all of the comparison operators and even hash table
53 /// support by comparing the underlying storage of the pointer values. It
54 /// doesn't support delegating to particular members for comparisons.
55 ///
/// It also default constructs to a zero tag with a null pointer, whatever that
/// would be. This means that the zero value of the tag type is significant: it
/// is worth assigning tag zero to the member whose null state is the most
/// desirable default-constructed value.
60 ///
61 /// Having a supported zero-valued tag also enables getting the address of a
62 /// pointer stored with that tag provided it is stored in its natural bit
63 /// representation. This works because in the case of a zero-valued tag, the
64 /// pointer's value is directly stored into this object and we can expose the
65 /// address of that internal storage. This is especially useful when building an
66 /// `ArrayRef` of a single pointer stored in a sum type.
67 ///
68 /// There is no support for constructing or accessing with a dynamic tag as
69 /// that would fundamentally violate the type safety provided by the sum type.
70 template <typename TagT, typename... MemberTs> class PointerSumType {
71   using HelperT = detail::PointerSumTypeHelper<TagT, MemberTs...>;
72 
73   // We keep both the raw value and the min tag value's pointer in a union. When
74   // the minimum tag value is zero, this allows code below to cleanly expose the
75   // address of the zero-tag pointer instead of just the zero-tag pointer
76   // itself. This is especially useful when building `ArrayRef`s out of a single
77   // pointer. However, we have to carefully access the union due to the active
78   // member potentially changing. When we *store* a new value, we directly
79   // access the union to allow us to store using the obvious types. However,
80   // when we *read* a value, we copy the underlying storage out to avoid relying
81   // on one member or the other being active.
82   union StorageT {
83     // Ensure we get a null default constructed value. We don't use a member
84     // initializer because some compilers seem to not implement those correctly
85     // for a union.
StorageT()86     StorageT() : Value(0) {}
87 
88     uintptr_t Value;
89 
90     typename HelperT::template Lookup<HelperT::MinTag>::PointerT MinTagPointer;
91   };
92 
93   StorageT Storage;
94 
95 public:
96   constexpr PointerSumType() = default;
97 
98   /// A typed setter to a given tagged member of the sum type.
99   template <TagT N>
set(typename HelperT::template Lookup<N>::PointerT Pointer)100   void set(typename HelperT::template Lookup<N>::PointerT Pointer) {
101     void *V = HelperT::template Lookup<N>::TraitsT::getAsVoidPointer(Pointer);
102     assert((reinterpret_cast<uintptr_t>(V) & HelperT::TagMask) == 0 &&
103            "Pointer is insufficiently aligned to store the discriminant!");
104     Storage.Value = reinterpret_cast<uintptr_t>(V) | N;
105   }
106 
107   /// A typed constructor for a specific tagged member of the sum type.
108   template <TagT N>
109   static PointerSumType
create(typename HelperT::template Lookup<N>::PointerT Pointer)110   create(typename HelperT::template Lookup<N>::PointerT Pointer) {
111     PointerSumType Result;
112     Result.set<N>(Pointer);
113     return Result;
114   }
115 
116   /// Clear the value to null with the min tag type.
clear()117   void clear() { set<HelperT::MinTag>(nullptr); }
118 
getTag()119   TagT getTag() const {
120     return static_cast<TagT>(getOpaqueValue() & HelperT::TagMask);
121   }
122 
is()123   template <TagT N> bool is() const { return N == getTag(); }
124 
get()125   template <TagT N> typename HelperT::template Lookup<N>::PointerT get() const {
126     void *P = is<N>() ? getVoidPtr() : nullptr;
127     return HelperT::template Lookup<N>::TraitsT::getFromVoidPointer(P);
128   }
129 
130   template <TagT N>
cast()131   typename HelperT::template Lookup<N>::PointerT cast() const {
132     assert(is<N>() && "This instance has a different active member.");
133     return HelperT::template Lookup<N>::TraitsT::getFromVoidPointer(
134         getVoidPtr());
135   }
136 
137   /// If the tag is zero and the pointer's value isn't changed when being
138   /// stored, get the address of the stored value type-punned to the zero-tag's
139   /// pointer type.
140   typename HelperT::template Lookup<HelperT::MinTag>::PointerT const *
getAddrOfZeroTagPointer()141   getAddrOfZeroTagPointer() const {
142     return const_cast<PointerSumType *>(this)->getAddrOfZeroTagPointer();
143   }
144 
145   /// If the tag is zero and the pointer's value isn't changed when being
146   /// stored, get the address of the stored value type-punned to the zero-tag's
147   /// pointer type.
148   typename HelperT::template Lookup<HelperT::MinTag>::PointerT *
getAddrOfZeroTagPointer()149   getAddrOfZeroTagPointer() {
150     static_assert(HelperT::MinTag == 0, "Non-zero minimum tag value!");
151     assert(is<HelperT::MinTag>() && "The active tag is not zero!");
152     // Store the initial value of the pointer when read out of our storage.
153     auto InitialPtr = get<HelperT::MinTag>();
154     // Now update the active member of the union to be the actual pointer-typed
155     // member so that accessing it indirectly through the returned address is
156     // valid.
157     Storage.MinTagPointer = InitialPtr;
158     // Finally, validate that this was a no-op as expected by reading it back
159     // out using the same underlying-storage read as above.
160     assert(InitialPtr == get<HelperT::MinTag>() &&
161            "Switching to typed storage changed the pointer returned!");
162     // Now we can correctly return an address to typed storage.
163     return &Storage.MinTagPointer;
164   }
165 
166   explicit operator bool() const {
167     return getOpaqueValue() & HelperT::PointerMask;
168   }
169   bool operator==(const PointerSumType &R) const {
170     return getOpaqueValue() == R.getOpaqueValue();
171   }
172   bool operator!=(const PointerSumType &R) const {
173     return getOpaqueValue() != R.getOpaqueValue();
174   }
175   bool operator<(const PointerSumType &R) const {
176     return getOpaqueValue() < R.getOpaqueValue();
177   }
178   bool operator>(const PointerSumType &R) const {
179     return getOpaqueValue() > R.getOpaqueValue();
180   }
181   bool operator<=(const PointerSumType &R) const {
182     return getOpaqueValue() <= R.getOpaqueValue();
183   }
184   bool operator>=(const PointerSumType &R) const {
185     return getOpaqueValue() >= R.getOpaqueValue();
186   }
187 
getOpaqueValue()188   uintptr_t getOpaqueValue() const {
189     // Read the underlying storage of the union, regardless of the active
190     // member.
191     return bit_cast<uintptr_t>(Storage);
192   }
193 
194 protected:
getVoidPtr()195   void *getVoidPtr() const {
196     return reinterpret_cast<void *>(getOpaqueValue() & HelperT::PointerMask);
197   }
198 };
199 
200 namespace detail {
201 
202 /// A helper template for implementing \c PointerSumType. It provides fast
203 /// compile-time lookup of the member from a particular tag value, along with
204 /// useful constants and compile time checking infrastructure..
205 template <typename TagT, typename... MemberTs>
206 struct PointerSumTypeHelper : MemberTs... {
207   // First we use a trick to allow quickly looking up information about
208   // a particular member of the sum type. This works because we arranged to
209   // have this type derive from all of the member type templates. We can select
210   // the matching member for a tag using type deduction during overload
211   // resolution.
212   template <TagT N, typename PointerT, typename TraitsT>
213   static PointerSumTypeMember<N, PointerT, TraitsT>
214   LookupOverload(PointerSumTypeMember<N, PointerT, TraitsT> *);
215   template <TagT N> static void LookupOverload(...);
216   template <TagT N> struct Lookup {
217     // Compute a particular member type by resolving the lookup helper overload.
218     using MemberT = decltype(
219         LookupOverload<N>(static_cast<PointerSumTypeHelper *>(nullptr)));
220 
221     /// The Nth member's pointer type.
222     using PointerT = typename MemberT::PointerT;
223 
224     /// The Nth member's traits type.
225     using TraitsT = typename MemberT::TraitsT;
226   };
227 
228   // Next we need to compute the number of bits available for the discriminant
229   // by taking the min of the bits available for each member. Much of this
230   // would be amazingly easier with good constexpr support.
231   template <uintptr_t V, uintptr_t... Vs>
232   struct Min : std::integral_constant<
233                    uintptr_t, (V < Min<Vs...>::value ? V : Min<Vs...>::value)> {
234   };
235   template <uintptr_t V>
236   struct Min<V> : std::integral_constant<uintptr_t, V> {};
237   enum { NumTagBits = Min<MemberTs::TraitsT::NumLowBitsAvailable...>::value };
238 
239   // Also compute the smallest discriminant and various masks for convenience.
240   constexpr static TagT MinTag =
241       static_cast<TagT>(Min<MemberTs::Tag...>::value);
242   enum : uint64_t {
243     PointerMask = static_cast<uint64_t>(-1) << NumTagBits,
244     TagMask = ~PointerMask
245   };
246 
247   // Finally we need a recursive template to do static checks of each
248   // member.
249   template <typename MemberT, typename... InnerMemberTs>
250   struct Checker : Checker<InnerMemberTs...> {
251     static_assert(MemberT::Tag < (1 << NumTagBits),
252                   "This discriminant value requires too many bits!");
253   };
254   template <typename MemberT> struct Checker<MemberT> : std::true_type {
255     static_assert(MemberT::Tag < (1 << NumTagBits),
256                   "This discriminant value requires too many bits!");
257   };
258   static_assert(Checker<MemberTs...>::value,
259                 "Each member must pass the checker.");
260 };
261 
262 } // end namespace detail
263 
264 // Teach DenseMap how to use PointerSumTypes as keys.
265 template <typename TagT, typename... MemberTs>
266 struct DenseMapInfo<PointerSumType<TagT, MemberTs...>> {
267   using SumType = PointerSumType<TagT, MemberTs...>;
268   using HelperT = detail::PointerSumTypeHelper<TagT, MemberTs...>;
269   enum { SomeTag = HelperT::MinTag };
270   using SomePointerT =
271       typename HelperT::template Lookup<HelperT::MinTag>::PointerT;
272   using SomePointerInfo = DenseMapInfo<SomePointerT>;
273 
274   static inline SumType getEmptyKey() {
275     return SumType::create<SomeTag>(SomePointerInfo::getEmptyKey());
276   }
277 
278   static inline SumType getTombstoneKey() {
279     return SumType::create<SomeTag>(SomePointerInfo::getTombstoneKey());
280   }
281 
282   static unsigned getHashValue(const SumType &Arg) {
283     uintptr_t OpaqueValue = Arg.getOpaqueValue();
284     return DenseMapInfo<uintptr_t>::getHashValue(OpaqueValue);
285   }
286 
287   static bool isEqual(const SumType &LHS, const SumType &RHS) {
288     return LHS == RHS;
289   }
290 };
291 
292 } // end namespace llvm
293 
294 #endif // LLVM_ADT_POINTERSUMTYPE_H
295