1 //===- llvm/ADT/PointerUnion.h - Discriminated Union of 2 Ptrs --*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 ///
9 /// \file
10 /// This file defines the PointerUnion class, which is a discriminated union of
11 /// pointer types.
12 ///
13 //===----------------------------------------------------------------------===//
14 
15 #ifndef LLVM_ADT_POINTERUNION_H
16 #define LLVM_ADT_POINTERUNION_H
17 
#include "llvm/ADT/DenseMapInfo.h"
#include "llvm/ADT/PointerIntPair.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/PointerLikeTypeTraits.h"
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <functional>
26 
27 namespace llvm {
28 
29 namespace pointer_union_detail {
30   /// Determine the number of bits required to store integers with values < n.
31   /// This is ceil(log2(n)).
32   constexpr int bitsRequired(unsigned n) {
33     return n > 1 ? 1 + bitsRequired((n + 1) / 2) : 0;
34   }
35 
36   template <typename... Ts> constexpr int lowBitsAvailable() {
37     return std::min<int>({PointerLikeTypeTraits<Ts>::NumLowBitsAvailable...});
38   }
39 
40   /// Find the first type in a list of types.
41   template <typename T, typename...> struct GetFirstType {
42     using type = T;
43   };
44 
45   /// Provide PointerLikeTypeTraits for void* that is used by PointerUnion
46   /// for the template arguments.
47   template <typename ...PTs> class PointerUnionUIntTraits {
48   public:
49     static inline void *getAsVoidPointer(void *P) { return P; }
50     static inline void *getFromVoidPointer(void *P) { return P; }
51     static constexpr int NumLowBitsAvailable = lowBitsAvailable<PTs...>();
52   };
53 
54   template <typename Derived, typename ValTy, int I, typename ...Types>
55   class PointerUnionMembers;
56 
57   template <typename Derived, typename ValTy, int I>
58   class PointerUnionMembers<Derived, ValTy, I> {
59   protected:
60     ValTy Val;
61     PointerUnionMembers() = default;
62     PointerUnionMembers(ValTy Val) : Val(Val) {}
63 
64     friend struct PointerLikeTypeTraits<Derived>;
65   };
66 
67   template <typename Derived, typename ValTy, int I, typename Type,
68             typename ...Types>
69   class PointerUnionMembers<Derived, ValTy, I, Type, Types...>
70       : public PointerUnionMembers<Derived, ValTy, I + 1, Types...> {
71     using Base = PointerUnionMembers<Derived, ValTy, I + 1, Types...>;
72   public:
73     using Base::Base;
74     PointerUnionMembers() = default;
75     PointerUnionMembers(Type V)
76         : Base(ValTy(const_cast<void *>(
77                          PointerLikeTypeTraits<Type>::getAsVoidPointer(V)),
78                      I)) {}
79 
80     using Base::operator=;
81     Derived &operator=(Type V) {
82       this->Val = ValTy(
83           const_cast<void *>(PointerLikeTypeTraits<Type>::getAsVoidPointer(V)),
84           I);
85       return static_cast<Derived &>(*this);
86     };
87   };
88 }
89 
90 /// A discriminated union of two or more pointer types, with the discriminator
91 /// in the low bit of the pointer.
92 ///
93 /// This implementation is extremely efficient in space due to leveraging the
94 /// low bits of the pointer, while exposing a natural and type-safe API.
95 ///
96 /// Common use patterns would be something like this:
97 ///    PointerUnion<int*, float*> P;
98 ///    P = (int*)0;
99 ///    printf("%d %d", P.is<int*>(), P.is<float*>());  // prints "1 0"
100 ///    X = P.get<int*>();     // ok.
101 ///    Y = P.get<float*>();   // runtime assertion failure.
102 ///    Z = P.get<double*>();  // compile time failure.
103 ///    P = (float*)0;
104 ///    Y = P.get<float*>();   // ok.
105 ///    X = P.get<int*>();     // runtime assertion failure.
106 ///    PointerUnion<int*, int*> Q; // compile time failure.
107 template <typename... PTs>
108 class PointerUnion
109     : public pointer_union_detail::PointerUnionMembers<
110           PointerUnion<PTs...>,
111           PointerIntPair<
112               void *, pointer_union_detail::bitsRequired(sizeof...(PTs)), int,
113               pointer_union_detail::PointerUnionUIntTraits<PTs...>>,
114           0, PTs...> {
115   static_assert(TypesAreDistinct<PTs...>::value,
116                 "PointerUnion alternative types cannot be repeated");
117   // The first type is special because we want to directly cast a pointer to a
118   // default-initialized union to a pointer to the first type. But we don't
119   // want PointerUnion to be a 'template <typename First, typename ...Rest>'
120   // because it's much more convenient to have a name for the whole pack. So
121   // split off the first type here.
122   using First = TypeAtIndex<0, PTs...>;
123   using Base = typename PointerUnion::PointerUnionMembers;
124 
125 public:
126   PointerUnion() = default;
127 
128   PointerUnion(std::nullptr_t) : PointerUnion() {}
129   using Base::Base;
130 
131   /// Test if the pointer held in the union is null, regardless of
132   /// which type it is.
133   bool isNull() const { return !this->Val.getPointer(); }
134 
135   explicit operator bool() const { return !isNull(); }
136 
137   /// Test if the Union currently holds the type matching T.
138   template <typename T> bool is() const {
139     return this->Val.getInt() == FirstIndexOfType<T, PTs...>::value;
140   }
141 
142   /// Returns the value of the specified pointer type.
143   ///
144   /// If the specified pointer type is incorrect, assert.
145   template <typename T> T get() const {
146     assert(is<T>() && "Invalid accessor called");
147     return PointerLikeTypeTraits<T>::getFromVoidPointer(this->Val.getPointer());
148   }
149 
150   /// Returns the current pointer if it is of the specified pointer type,
151   /// otherwise returns null.
152   template <typename T> T dyn_cast() const {
153     if (is<T>())
154       return get<T>();
155     return T();
156   }
157 
158   /// If the union is set to the first pointer type get an address pointing to
159   /// it.
160   First const *getAddrOfPtr1() const {
161     return const_cast<PointerUnion *>(this)->getAddrOfPtr1();
162   }
163 
164   /// If the union is set to the first pointer type get an address pointing to
165   /// it.
166   First *getAddrOfPtr1() {
167     assert(is<First>() && "Val is not the first pointer");
168     assert(
169         PointerLikeTypeTraits<First>::getAsVoidPointer(get<First>()) ==
170             this->Val.getPointer() &&
171         "Can't get the address because PointerLikeTypeTraits changes the ptr");
172     return const_cast<First *>(
173         reinterpret_cast<const First *>(this->Val.getAddrOfPointer()));
174   }
175 
176   /// Assignment from nullptr which just clears the union.
177   const PointerUnion &operator=(std::nullptr_t) {
178     this->Val.initWithPointer(nullptr);
179     return *this;
180   }
181 
182   /// Assignment from elements of the union.
183   using Base::operator=;
184 
185   void *getOpaqueValue() const { return this->Val.getOpaqueValue(); }
186   static inline PointerUnion getFromOpaqueValue(void *VP) {
187     PointerUnion V;
188     V.Val = decltype(V.Val)::getFromOpaqueValue(VP);
189     return V;
190   }
191 };
192 
193 template <typename ...PTs>
194 bool operator==(PointerUnion<PTs...> lhs, PointerUnion<PTs...> rhs) {
195   return lhs.getOpaqueValue() == rhs.getOpaqueValue();
196 }
197 
198 template <typename ...PTs>
199 bool operator!=(PointerUnion<PTs...> lhs, PointerUnion<PTs...> rhs) {
200   return lhs.getOpaqueValue() != rhs.getOpaqueValue();
201 }
202 
203 template <typename ...PTs>
204 bool operator<(PointerUnion<PTs...> lhs, PointerUnion<PTs...> rhs) {
205   return lhs.getOpaqueValue() < rhs.getOpaqueValue();
206 }
207 
208 // Teach SmallPtrSet that PointerUnion is "basically a pointer", that has
209 // # low bits available = min(PT1bits,PT2bits)-1.
210 template <typename ...PTs>
211 struct PointerLikeTypeTraits<PointerUnion<PTs...>> {
212   static inline void *getAsVoidPointer(const PointerUnion<PTs...> &P) {
213     return P.getOpaqueValue();
214   }
215 
216   static inline PointerUnion<PTs...> getFromVoidPointer(void *P) {
217     return PointerUnion<PTs...>::getFromOpaqueValue(P);
218   }
219 
220   // The number of bits available are the min of the pointer types minus the
221   // bits needed for the discriminator.
222   static constexpr int NumLowBitsAvailable = PointerLikeTypeTraits<decltype(
223       PointerUnion<PTs...>::Val)>::NumLowBitsAvailable;
224 };
225 
226 // Teach DenseMap how to use PointerUnions as keys.
227 template <typename ...PTs> struct DenseMapInfo<PointerUnion<PTs...>> {
228   using Union = PointerUnion<PTs...>;
229   using FirstInfo =
230       DenseMapInfo<typename pointer_union_detail::GetFirstType<PTs...>::type>;
231 
232   static inline Union getEmptyKey() { return Union(FirstInfo::getEmptyKey()); }
233 
234   static inline Union getTombstoneKey() {
235     return Union(FirstInfo::getTombstoneKey());
236   }
237 
238   static unsigned getHashValue(const Union &UnionVal) {
239     intptr_t key = (intptr_t)UnionVal.getOpaqueValue();
240     return DenseMapInfo<intptr_t>::getHashValue(key);
241   }
242 
243   static bool isEqual(const Union &LHS, const Union &RHS) {
244     return LHS == RHS;
245   }
246 };
247 
248 } // end namespace llvm
249 
250 #endif // LLVM_ADT_POINTERUNION_H
251