1 /* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
2 /* vim: set ts=8 sts=2 et sw=2 tw=80: */
3 /* This Source Code Form is subject to the terms of the Mozilla Public
4  * License, v. 2.0. If a copy of the MPL was not distributed with this file,
5  * You can obtain one at http://mozilla.org/MPL/2.0/. */
6 
7 #include <type_traits>
8 
9 #include "mozilla/NotNull.h"
10 #include "mozilla/RefPtr.h"
11 #include "mozilla/UniquePtr.h"
12 #include "mozilla/Unused.h"
13 
14 using mozilla::MakeNotNull;
15 using mozilla::NotNull;
16 using mozilla::UniquePtr;
17 using mozilla::WrapNotNull;
18 
19 #define CHECK MOZ_RELEASE_ASSERT
20 
// Minimal class used to exercise '->' member access through the various
// pointer wrappers tested below.
class Blah {
 public:
  Blah() : mX{0} {}

  // Intentionally empty; it only needs to be callable through '->'.
  void blah() {}

  int mX;
};
27 
28 // A simple smart pointer that implicity converts to and from T*.
29 template <typename T>
30 class MyPtr {
31   T* mRawPtr;
32 
33  public:
MyPtr()34   MyPtr() : mRawPtr(nullptr) {}
MyPtr(T * aRawPtr)35   MOZ_IMPLICIT MyPtr(T* aRawPtr) : mRawPtr(aRawPtr) {}
36 
get() const37   T* get() const { return mRawPtr; }
operator T*() const38   operator T*() const { return get(); }
39 
operator ->() const40   T* operator->() const { return get(); }
41 };
42 
43 // A simple class that works with RefPtr. It keeps track of the maximum
44 // refcount value for testing purposes.
45 class MyRefType {
46   int mExpectedMaxRefCnt;
47   int mMaxRefCnt;
48   int mRefCnt;
49 
50  public:
MyRefType(int aExpectedMaxRefCnt)51   explicit MyRefType(int aExpectedMaxRefCnt)
52       : mExpectedMaxRefCnt(aExpectedMaxRefCnt), mMaxRefCnt(0), mRefCnt(0) {}
53 
~MyRefType()54   ~MyRefType() { CHECK(mMaxRefCnt == mExpectedMaxRefCnt); }
55 
AddRef()56   uint32_t AddRef() {
57     mRefCnt++;
58     if (mRefCnt > mMaxRefCnt) {
59       mMaxRefCnt = mRefCnt;
60     }
61     return mRefCnt;
62   }
63 
Release()64   uint32_t Release() {
65     CHECK(mRefCnt > 0);
66     mRefCnt--;
67     if (mRefCnt == 0) {
68       delete this;
69       return 0;
70     }
71     return mRefCnt;
72   }
73 };
74 
// Sinks that take each pointer flavour by value. Their bodies are empty, so
// the only effect of calling one is checking at compile time that the
// argument converts to the parameter type.
void f_i(int* aPtr) {}
void f_my(MyPtr<int> aPtr) {}

void f_nni(NotNull<int*> aPtr) {}
void f_nnmy(NotNull<MyPtr<int>> aPtr) {}
80 
// Exercises NotNull<int*> and NotNull<MyPtr<int>>: construction, assignment,
// equality/inequality, implicit and explicit conversions to the sink
// functions, '->' and '*' dereferencing, and indexing. Commented-out lines
// document usages that NotNull is meant to reject.
void TestNotNullWithMyPtr() {
  int i4 = 4;
  int i5 = 5;

  MyPtr<int> my4 = &i4;
  MyPtr<int> my5 = &i5;

  NotNull<int*> nni4 = WrapNotNull(&i4);
  NotNull<int*> nni5 = WrapNotNull(&i5);
  NotNull<MyPtr<int>> nnmy4 = WrapNotNull(my4);

  // WrapNotNull(nullptr);                       // no wrapping from nullptr
  // WrapNotNull(0);                             // no wrapping from zero

  // NotNull<int*> construction combinations
  // NotNull<int*> nni4a;                        // no default
  // NotNull<int*> nni4a(nullptr);               // no nullptr
  // NotNull<int*> nni4a(0);                     // no zero
  // NotNull<int*> nni4a(&i4);                   // no int*
  // NotNull<int*> nni4a(my4);                   // no MyPtr<int>
  NotNull<int*> nni4b(WrapNotNull(&i4));  // WrapNotNull(int*)
  NotNull<int*> nni4c(WrapNotNull(my4));  // WrapNotNull(MyPtr<int>)
  NotNull<int*> nni4d(nni4);              // NotNull<int*>
  NotNull<int*> nni4e(nnmy4);             // NotNull<MyPtr<int>>
  CHECK(*nni4b == 4);
  CHECK(*nni4c == 4);
  CHECK(*nni4d == 4);
  CHECK(*nni4e == 4);

  // NotNull<MyPtr<int>> construction combinations
  // NotNull<MyPtr<int>> nnmy4a;                 // no default
  // NotNull<MyPtr<int>> nnmy4a(nullptr);        // no nullptr
  // NotNull<MyPtr<int>> nnmy4a(0);              // no zero
  // NotNull<MyPtr<int>> nnmy4a(&i4);            // no int*
  // NotNull<MyPtr<int>> nnmy4a(my4);            // no MyPtr<int>
  NotNull<MyPtr<int>> nnmy4b(WrapNotNull(&i4));  // WrapNotNull(int*)
  NotNull<MyPtr<int>> nnmy4c(WrapNotNull(my4));  // WrapNotNull(MyPtr<int>)
  NotNull<MyPtr<int>> nnmy4d(nni4);              // NotNull<int*>
  NotNull<MyPtr<int>> nnmy4e(nnmy4);             // NotNull<MyPtr<int>>
  CHECK(*nnmy4b == 4);
  CHECK(*nnmy4c == 4);
  CHECK(*nnmy4d == 4);
  CHECK(*nnmy4e == 4);

  // NotNull<int*> assignment combinations
  // nni4b = nullptr;                            // no nullptr
  // nni4b = 0;                                  // no zero
  // nni4b = &i4;                                // no int*
  // nni4b = my4;                                // no MyPtr<int>
  nni4b = WrapNotNull(&i4);  // WrapNotNull(int*)
  nni4c = WrapNotNull(my4);  // WrapNotNull(MyPtr<int>)
  nni4d = nni4;              // NotNull<int*>
  nni4e = nnmy4;             // NotNull<MyPtr<int>>
  CHECK(*nni4b == 4);
  CHECK(*nni4c == 4);
  CHECK(*nni4d == 4);
  CHECK(*nni4e == 4);

  // NotNull<MyPtr<int>> assignment combinations
  // nnmy4b = nullptr;                           // no nullptr
  // nnmy4b = 0;                                 // no zero
  // nnmy4b = &i4;                               // no int*
  // nnmy4b = my4;                               // no MyPtr<int>
  nnmy4b = WrapNotNull(&i4);  // WrapNotNull(int*)
  nnmy4c = WrapNotNull(my4);  // WrapNotNull(MyPtr<int>)
  nnmy4d = nni4;              // NotNull<int*>
  nnmy4e = nnmy4;             // NotNull<MyPtr<int>>
  CHECK(*nnmy4b == 4);
  CHECK(*nnmy4c == 4);
  CHECK(*nnmy4d == 4);
  CHECK(*nnmy4e == 4);

  NotNull<MyPtr<int>> nnmy5 = WrapNotNull(&i5);
  CHECK(*nnmy5 == 5);
  CHECK(nnmy5 == &i5);    // NotNull<MyPtr<int>> == int*
  CHECK(nnmy5 == my5);    // NotNull<MyPtr<int>> == MyPtr<int>
  CHECK(nnmy5 == nni5);   // NotNull<MyPtr<int>> == NotNull<int*>
  CHECK(nnmy5 == nnmy5);  // NotNull<MyPtr<int>> == NotNull<MyPtr<int>>
  CHECK(&i5 == nnmy5);    // int*                == NotNull<MyPtr<int>>
  CHECK(my5 == nnmy5);    // MyPtr<int>          == NotNull<MyPtr<int>>
  CHECK(nni5 == nnmy5);   // NotNull<int*>       == NotNull<MyPtr<int>>
  // NOTE(review): this repeats the nnmy5 == nnmy5 check four lines above.
  CHECK(nnmy5 == nnmy5);  // NotNull<MyPtr<int>> == NotNull<MyPtr<int>>
  // CHECK(nni5 == nullptr);  // no comparisons with nullptr
  // CHECK(nullptr == nni5);  // no comparisons with nullptr
  // CHECK(nni5 == 0);        // no comparisons with zero
  // CHECK(0 == nni5);        // no comparisons with zero

  CHECK(*nnmy5 == 5);
  CHECK(nnmy5 != &i4);    // NotNull<MyPtr<int>> != int*
  CHECK(nnmy5 != my4);    // NotNull<MyPtr<int>> != MyPtr<int>
  CHECK(nnmy5 != nni4);   // NotNull<MyPtr<int>> != NotNull<int*>
  CHECK(nnmy5 != nnmy4);  // NotNull<MyPtr<int>> != NotNull<MyPtr<int>>
  CHECK(&i4 != nnmy5);    // int*                != NotNull<MyPtr<int>>
  CHECK(my4 != nnmy5);    // MyPtr<int>          != NotNull<MyPtr<int>>
  CHECK(nni4 != nnmy5);   // NotNull<int*>       != NotNull<MyPtr<int>>
  CHECK(nnmy4 != nnmy5);  // NotNull<MyPtr<int>> != NotNull<MyPtr<int>>
  // CHECK(nni4 != nullptr);  // no comparisons with nullptr
  // CHECK(nullptr != nni4);  // no comparisons with nullptr
  // CHECK(nni4 != 0);        // no comparisons with zero
  // CHECK(0 != nni4);        // no comparisons with zero

  // int* parameter
  f_i(&i4);         // identity int*                        --> int*
  f_i(my4);         // implicit MyPtr<int>                  --> int*
  f_i(my4.get());   // explicit MyPtr<int>                  --> int*
  f_i(nni4);        // implicit NotNull<int*>               --> int*
  f_i(nni4.get());  // explicit NotNull<int*>               --> int*
  // f_i(nnmy4);         // no implicit NotNull<MyPtr<int>>      --> int*
  f_i(nnmy4.get());        // explicit NotNull<MyPtr<int>>         --> int*
  f_i(nnmy4.get().get());  // doubly-explicit NotNull<MyPtr<int>> --> int*

  // MyPtr<int> parameter
  f_my(&i4);        // implicit int*                         --> MyPtr<int>
  f_my(my4);        // identity MyPtr<int>                   --> MyPtr<int>
  f_my(my4.get());  // explicit MyPtr<int>                   --> MyPtr<int>
  // f_my(nni4);         // no implicit NotNull<int*>             --> MyPtr<int>
  f_my(nni4.get());   // explicit NotNull<int*>                --> MyPtr<int>
  f_my(nnmy4);        // implicit NotNull<MyPtr<int>>          --> MyPtr<int>
  f_my(nnmy4.get());  // explicit NotNull<MyPtr<int>>          --> MyPtr<int>
  f_my(
      nnmy4.get().get());  // doubly-explicit NotNull<MyPtr<int>> --> MyPtr<int>

  // NotNull<int*> parameter
  f_nni(nni4);   // identity NotNull<int*>       --> NotNull<int*>
  f_nni(nnmy4);  // implicit NotNull<MyPtr<int>> --> NotNull<int*>

  // NotNull<MyPtr<int>> parameter
  f_nnmy(nni4);   // implicit NotNull<int*>       --> NotNull<MyPtr<int>>
  f_nnmy(nnmy4);  // identity NotNull<MyPtr<int>> --> NotNull<MyPtr<int>>

  // CHECK(nni4);        // disallow boolean conversion / unary expression usage
  // CHECK(nnmy4);       // ditto

  // '->' dereferencing.
  Blah blah;
  MyPtr<Blah> myblah = &blah;
  NotNull<Blah*> nnblah = WrapNotNull(&blah);
  NotNull<MyPtr<Blah>> nnmyblah = WrapNotNull(myblah);
  (&blah)->blah();   // Blah*
  myblah->blah();    // MyPtr<Blah>
  nnblah->blah();    // NotNull<Blah*>
  nnmyblah->blah();  // NotNull<MyPtr<Blah>>

  // All four pointers refer to |blah|, so each write replaces the last one.
  (&blah)->mX = 1;
  CHECK((&blah)->mX == 1);
  myblah->mX = 2;
  CHECK(myblah->mX == 2);
  nnblah->mX = 3;
  CHECK(nnblah->mX == 3);
  nnmyblah->mX = 4;
  CHECK(nnmyblah->mX == 4);

  // '*' dereferencing (lvalues and rvalues)
  // my4, nni4 and nnmy4 all point at i4, so each write overwrites i4.
  *(&i4) = 7;  // int*
  CHECK(*(&i4) == 7);
  *my4 = 6;  // MyPtr<int>
  CHECK(*my4 == 6);
  *nni4 = 5;  // NotNull<int*>
  CHECK(*nni4 == 5);
  *nnmy4 = 4;  // NotNull<MyPtr<int>>
  CHECK(*nnmy4 == 4);

  // Non-null arrays.
  // |a| decays to int* when wrapped; nna[i] must read and write |a| in place.
  static const int N = 20;
  int a[N];
  NotNull<int*> nna = WrapNotNull(a);
  for (int i = 0; i < N; i++) {
    nna[i] = i;
  }
  for (int i = 0; i < N; i++) {
    nna[i] *= 2;
  }
  for (int i = 0; i < N; i++) {
    CHECK(nna[i] == i * 2);
  }
}
257 
f_ref(NotNull<MyRefType * > aR)258 void f_ref(NotNull<MyRefType*> aR) { NotNull<RefPtr<MyRefType>> r = aR; }
259 
// Tracks the refcount of a single MyRefType through a sequence of NotNull,
// raw and RefPtr handles. MyRefType's destructor asserts that the peak
// refcount equals the constructor argument (5 here), so the statement order
// below is significant.
void TestNotNullWithRefPtr() {
  // This MyRefType object will have a maximum refcount of 5.
  NotNull<RefPtr<MyRefType>> r1 = WrapNotNull(new MyRefType(5));

  // At this point the refcount is 1.

  NotNull<RefPtr<MyRefType>> r2 = r1;

  // At this point the refcount is 2.

  // A raw NotNull<MyRefType*> does not hold a reference.
  NotNull<MyRefType*> r3 = r2;
  (void)r3;

  // At this point the refcount is still 2.

  RefPtr<MyRefType> r4 = r2;
  mozilla::Unused << r4;

  // At this point the refcount is 3.

  RefPtr<MyRefType> r5 = r3.get();
  mozilla::Unused << r5;

  // At this point the refcount is 4.

  // No change to the refcount occurs because of the argument passing (f_ref
  // takes a raw NotNull<MyRefType*>). Within f_ref() the refcount temporarily
  // hits 5, due to the local RefPtr.
  f_ref(r2);

  // At this point the refcount is 4.

  NotNull<RefPtr<MyRefType>> r6 = std::move(r2);
  mozilla::Unused << r6;

  // r2 is deliberately used after std::move: because NotNull is not movable,
  // the "move" above acts as a copy and r2 must still be non-null.
  CHECK(r2.get());
  CHECK(r6.get());

  // At this point the refcount is 5 again, since NotNull is not movable.

  // At the function's end all RefPtrs are destroyed, the refcount drops to 0,
  // and the MyRefType is destroyed; its destructor then checks that the peak
  // refcount was 5.
}
302 
TestMakeNotNull()303 void TestMakeNotNull() {
304   // Raw pointer.
305   auto nni = MakeNotNull<int*>(11);
306   static_assert(std::is_same_v<NotNull<int*>, decltype(nni)>,
307                 "MakeNotNull<int*> should return NotNull<int*>");
308   CHECK(*nni == 11);
309   delete nni;
310 
311   // Raw pointer to const.
312   auto nnci = MakeNotNull<const int*>(12);
313   static_assert(std::is_same_v<NotNull<const int*>, decltype(nnci)>,
314                 "MakeNotNull<const int*> should return NotNull<const int*>");
315   CHECK(*nnci == 12);
316   delete nnci;
317 
318   // Create a derived object and store its base pointer.
319   struct Base {
320     virtual ~Base() = default;
321     virtual bool IsDerived() const { return false; }
322   };
323   struct Derived : Base {
324     bool IsDerived() const override { return true; }
325   };
326   auto nnd = MakeNotNull<Derived*>();
327   static_assert(std::is_same_v<NotNull<Derived*>, decltype(nnd)>,
328                 "MakeNotNull<Derived*> should return NotNull<Derived*>");
329   CHECK(nnd->IsDerived());
330   delete nnd;
331   NotNull<Base*> nnb = MakeNotNull<Derived*>();
332   static_assert(std::is_same_v<NotNull<Base*>, decltype(nnb)>,
333                 "MakeNotNull<Derived*> should be assignable to NotNull<Base*>");
334   // Check that we have really built a Derived object.
335   CHECK(nnb->IsDerived());
336   delete nnb;
337 
338   // Allow smart pointers.
339   auto nnmi = MakeNotNull<MyPtr<int>>(23);
340   static_assert(std::is_same_v<NotNull<MyPtr<int>>, decltype(nnmi)>,
341                 "MakeNotNull<MyPtr<int>> should return NotNull<MyPtr<int>>");
342   CHECK(*nnmi == 23);
343   delete nnmi.get().get();
344 
345   auto nnui = MakeNotNull<UniquePtr<int>>(24);
346   static_assert(
347       std::is_same_v<NotNull<UniquePtr<int>>, decltype(nnui)>,
348       "MakeNotNull<UniquePtr<int>> should return NotNull<UniquePtr<int>>");
349   CHECK(*nnui == 24);
350 
351   // Expect only 1 RefCnt (from construction).
352   auto nnr = MakeNotNull<RefPtr<MyRefType>>(1);
353   static_assert(std::is_same_v<NotNull<RefPtr<MyRefType>>, decltype(nnr)>,
354                 "MakeNotNull<RefPtr<MyRefType>> should return "
355                 "NotNull<RefPtr<MyRefType>>");
356   mozilla::Unused << nnr;
357 }
358 
CreateNotNullUniquePtr()359 mozilla::MovingNotNull<UniquePtr<int>> CreateNotNullUniquePtr() {
360   return mozilla::WrapMovingNotNull(mozilla::MakeUnique<int>(42));
361 }
362 
TestMovingNotNull()363 void TestMovingNotNull() {
364   UniquePtr<int> x1 = CreateNotNullUniquePtr();
365   CHECK(x1);
366   CHECK(42 == *x1);
367 
368   NotNull<UniquePtr<int>> x2 = CreateNotNullUniquePtr();
369   CHECK(42 == *x2);
370 
371   // Must not compile:
372   // auto y = CreateNotNullUniquePtr();
373 }
374 
main()375 int main() {
376   TestNotNullWithMyPtr();
377   TestNotNullWithRefPtr();
378   TestMakeNotNull();
379   TestMovingNotNull();
380 
381   return 0;
382 }
383