1 /* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
2 /* vim: set ts=8 sts=2 et sw=2 tw=80: */
3 /* This Source Code Form is subject to the terms of the Mozilla Public
4  * License, v. 2.0. If a copy of the MPL was not distributed with this file,
5  * You can obtain one at http://mozilla.org/MPL/2.0/. */
6 
7 #include <utility>
8 
9 #include "mozilla/Assertions.h"
10 #include "mozilla/NonDereferenceable.h"
11 
12 using mozilla::NonDereferenceable;
13 
14 #define CHECK MOZ_RELEASE_ASSERT
15 
TestNonDereferenceableSimple()16 void TestNonDereferenceableSimple() {
17   // Default construction.
18   NonDereferenceable<int> nd0;
19   CHECK(!nd0);
20   CHECK(!nd0.value());
21 
22   int i = 1;
23   int i2 = 2;
24 
25   // Construction with pointer.
26   NonDereferenceable<int> nd1(&i);
27   CHECK(!!nd1);
28   CHECK(nd1.value() == reinterpret_cast<uintptr_t>(&i));
29 
30   // Assignment with pointer.
31   nd1 = &i2;
32   CHECK(nd1.value() == reinterpret_cast<uintptr_t>(&i2));
33 
34   // Copy-construction.
35   NonDereferenceable<int> nd2(nd1);
36   CHECK(nd2.value() == reinterpret_cast<uintptr_t>(&i2));
37 
38   // Copy-assignment.
39   nd2 = nd0;
40   CHECK(!nd2.value());
41 
42   // Move-construction.
43   NonDereferenceable<int> nd3{NonDereferenceable<int>(&i)};
44   CHECK(nd3.value() == reinterpret_cast<uintptr_t>(&i));
45 
46   // Move-assignment.
47   nd3 = std::move(nd1);
48   CHECK(nd3.value() == reinterpret_cast<uintptr_t>(&i2));
49   // Note: Not testing nd1's value because we don't want to assume what state
50   // it is left in after move. But at least it should be reusable:
51   nd1 = &i;
52   CHECK(nd1.value() == reinterpret_cast<uintptr_t>(&i));
53 }
54 
TestNonDereferenceableHierarchy()55 void TestNonDereferenceableHierarchy() {
56   struct Base1 {
57     // Member variable, to make sure Base1 is not empty.
58     int x1;
59   };
60   struct Base2 {
61     int x2;
62   };
63   struct Derived : Base1, Base2 {};
64 
65   Derived d;
66 
67   // Construct NonDereferenceable from raw pointer.
68   NonDereferenceable<Derived> ndd = NonDereferenceable<Derived>(&d);
69   CHECK(ndd);
70   CHECK(ndd.value() == reinterpret_cast<uintptr_t>(&d));
71 
72   // Cast Derived to Base1.
73   NonDereferenceable<Base1> ndb1 = ndd;
74   CHECK(ndb1);
75   CHECK(ndb1.value() == reinterpret_cast<uintptr_t>(static_cast<Base1*>(&d)));
76 
77   // Cast Base1 back to Derived.
78   NonDereferenceable<Derived> nddb1 = ndb1;
79   CHECK(nddb1.value() == reinterpret_cast<uintptr_t>(&d));
80 
81   // Cast Derived to Base2.
82   NonDereferenceable<Base2> ndb2 = ndd;
83   CHECK(ndb2);
84   CHECK(ndb2.value() == reinterpret_cast<uintptr_t>(static_cast<Base2*>(&d)));
85   // Sanity check that Base2 should be offset from the start of Derived.
86   CHECK(ndb2.value() != ndd.value());
87 
88   // Cast Base2 back to Derived.
89   NonDereferenceable<Derived> nddb2 = ndb2;
90   CHECK(nddb2.value() == reinterpret_cast<uintptr_t>(&d));
91 
92   // Note that it's not possible to jump between bases, as they're not obviously
93   // related, i.e.: `NonDereferenceable<Base2> ndb22 = ndb1;` doesn't compile.
94   // However it's possible to explicitly navigate through the derived object:
95   NonDereferenceable<Base2> ndb22 = NonDereferenceable<Derived>(ndb1);
96   CHECK(ndb22.value() == reinterpret_cast<uintptr_t>(static_cast<Base2*>(&d)));
97 
98   // Handling nullptr; should stay nullptr even for offset bases.
99   ndd = nullptr;
100   CHECK(!ndd);
101   CHECK(!ndd.value());
102   ndb1 = ndd;
103   CHECK(!ndb1);
104   CHECK(!ndb1.value());
105   ndb2 = ndd;
106   CHECK(!ndb2);
107   CHECK(!ndb2.value());
108   nddb2 = ndb2;
109   CHECK(!nddb2);
110   CHECK(!nddb2.value());
111 }
112 
113 template <typename T, size_t Index>
114 struct CRTPBase {
115   // Convert `this` from `CRTPBase*` to `T*` while construction is still in
116   // progress; normally UBSan -fsanitize=vptr would catch this, but using
117   // NonDereferenceable should keep UBSan happy.
CRTPBaseCRTPBase118   CRTPBase() : mDerived(this) {}
119   NonDereferenceable<T> mDerived;
120 };
121 
TestNonDereferenceableCRTP()122 void TestNonDereferenceableCRTP() {
123   struct Derived : CRTPBase<Derived, 1>, CRTPBase<Derived, 2> {};
124   using Base1 = Derived::CRTPBase<Derived, 1>;
125   using Base2 = Derived::CRTPBase<Derived, 2>;
126 
127   Derived d;
128   // Verify that base constructors have correctly captured the address of the
129   // (at the time still incomplete) derived object.
130   CHECK(d.Base1::mDerived.value() == reinterpret_cast<uintptr_t>(&d));
131   CHECK(d.Base2::mDerived.value() == reinterpret_cast<uintptr_t>(&d));
132 
133   // Construct NonDereferenceable from raw pointer.
134   NonDereferenceable<Derived> ndd = NonDereferenceable<Derived>(&d);
135   CHECK(ndd);
136   CHECK(ndd.value() == reinterpret_cast<uintptr_t>(&d));
137 
138   // Cast Derived to Base1.
139   NonDereferenceable<Base1> ndb1 = ndd;
140   CHECK(ndb1);
141   CHECK(ndb1.value() == reinterpret_cast<uintptr_t>(static_cast<Base1*>(&d)));
142 
143   // Cast Base1 back to Derived.
144   NonDereferenceable<Derived> nddb1 = ndb1;
145   CHECK(nddb1.value() == reinterpret_cast<uintptr_t>(&d));
146 
147   // Cast Derived to Base2.
148   NonDereferenceable<Base2> ndb2 = ndd;
149   CHECK(ndb2);
150   CHECK(ndb2.value() == reinterpret_cast<uintptr_t>(static_cast<Base2*>(&d)));
151   // Sanity check that Base2 should be offset from the start of Derived.
152   CHECK(ndb2.value() != ndd.value());
153 
154   // Cast Base2 back to Derived.
155   NonDereferenceable<Derived> nddb2 = ndb2;
156   CHECK(nddb2.value() == reinterpret_cast<uintptr_t>(&d));
157 
158   // Note that it's not possible to jump between bases, as they're not obviously
159   // related, i.e.: `NonDereferenceable<Base2> ndb22 = ndb1;` doesn't compile.
160   // However it's possible to explicitly navigate through the derived object:
161   NonDereferenceable<Base2> ndb22 = NonDereferenceable<Derived>(ndb1);
162   CHECK(ndb22.value() == reinterpret_cast<uintptr_t>(static_cast<Base2*>(&d)));
163 }
164 
main()165 int main() {
166   TestNonDereferenceableSimple();
167   TestNonDereferenceableHierarchy();
168   TestNonDereferenceableCRTP();
169 
170   return 0;
171 }
172