1 //===----------------------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef COUNT_NEW_H
10 #define COUNT_NEW_H
11 
12 # include <cstdlib>
13 # include <cassert>
14 # include <new>
15 
16 #include "test_macros.h"
17 
18 #if defined(TEST_HAS_SANITIZERS)
19 #define DISABLE_NEW_COUNT
20 #endif
21 
22 namespace detail
23 {
24    TEST_NORETURN
throw_bad_alloc_helper()25    inline void throw_bad_alloc_helper() {
26 #ifndef TEST_HAS_NO_EXCEPTIONS
27        throw std::bad_alloc();
28 #else
29        std::abort();
30 #endif
31    }
32 }
33 
34 class MemCounter
35 {
36 public:
37     // Make MemCounter super hard to accidentally construct or copy.
38     class MemCounterCtorArg_ {};
MemCounter(MemCounterCtorArg_)39     explicit MemCounter(MemCounterCtorArg_) { reset(); }
40 
41 private:
42     MemCounter(MemCounter const &);
43     MemCounter & operator=(MemCounter const &);
44 
45 public:
46     // All checks return true when disable_checking is enabled.
47     static const bool disable_checking;
48 
49     // Disallow any allocations from occurring. Useful for testing that
50     // code doesn't perform any allocations.
51     bool disable_allocations;
52 
53     // number of allocations to throw after. Default (unsigned)-1. If
54     // throw_after has the default value it will never be decremented.
55     static const unsigned never_throw_value = static_cast<unsigned>(-1);
56     unsigned throw_after;
57 
58     int outstanding_new;
59     int new_called;
60     int delete_called;
61     int aligned_new_called;
62     int aligned_delete_called;
63     std::size_t last_new_size;
64     std::size_t last_new_align;
65     std::size_t last_delete_align;
66 
67     int outstanding_array_new;
68     int new_array_called;
69     int delete_array_called;
70     int aligned_new_array_called;
71     int aligned_delete_array_called;
72     std::size_t last_new_array_size;
73     std::size_t last_new_array_align;
74     std::size_t last_delete_array_align;
75 
76 public:
newCalled(std::size_t s)77     void newCalled(std::size_t s)
78     {
79         assert(disable_allocations == false);
80         assert(s);
81         if (throw_after == 0) {
82             throw_after = never_throw_value;
83             detail::throw_bad_alloc_helper();
84         } else if (throw_after != never_throw_value) {
85             --throw_after;
86         }
87         ++new_called;
88         ++outstanding_new;
89         last_new_size = s;
90     }
91 
alignedNewCalled(std::size_t s,std::size_t a)92     void alignedNewCalled(std::size_t s, std::size_t a) {
93       newCalled(s);
94       ++aligned_new_called;
95       last_new_align = a;
96     }
97 
deleteCalled(void * p)98     void deleteCalled(void * p)
99     {
100         assert(p);
101         --outstanding_new;
102         ++delete_called;
103     }
104 
alignedDeleteCalled(void * p,std::size_t a)105     void alignedDeleteCalled(void *p, std::size_t a) {
106       deleteCalled(p);
107       ++aligned_delete_called;
108       last_delete_align = a;
109     }
110 
newArrayCalled(std::size_t s)111     void newArrayCalled(std::size_t s)
112     {
113         assert(disable_allocations == false);
114         assert(s);
115         if (throw_after == 0) {
116             throw_after = never_throw_value;
117             detail::throw_bad_alloc_helper();
118         } else {
119             // don't decrement throw_after here. newCalled will end up doing that.
120         }
121         ++outstanding_array_new;
122         ++new_array_called;
123         last_new_array_size = s;
124     }
125 
alignedNewArrayCalled(std::size_t s,std::size_t a)126     void alignedNewArrayCalled(std::size_t s, std::size_t a) {
127       newArrayCalled(s);
128       ++aligned_new_array_called;
129       last_new_array_align = a;
130     }
131 
deleteArrayCalled(void * p)132     void deleteArrayCalled(void * p)
133     {
134         assert(p);
135         --outstanding_array_new;
136         ++delete_array_called;
137     }
138 
alignedDeleteArrayCalled(void * p,std::size_t a)139     void alignedDeleteArrayCalled(void * p, std::size_t a) {
140       deleteArrayCalled(p);
141       ++aligned_delete_array_called;
142       last_delete_array_align = a;
143     }
144 
disableAllocations()145     void disableAllocations()
146     {
147         disable_allocations = true;
148     }
149 
enableAllocations()150     void enableAllocations()
151     {
152         disable_allocations = false;
153     }
154 
reset()155     void reset()
156     {
157         disable_allocations = false;
158         throw_after = never_throw_value;
159 
160         outstanding_new = 0;
161         new_called = 0;
162         delete_called = 0;
163         aligned_new_called = 0;
164         aligned_delete_called = 0;
165         last_new_size = 0;
166         last_new_align = 0;
167 
168         outstanding_array_new = 0;
169         new_array_called = 0;
170         delete_array_called = 0;
171         aligned_new_array_called = 0;
172         aligned_delete_array_called = 0;
173         last_new_array_size = 0;
174         last_new_array_align = 0;
175     }
176 
177 public:
checkOutstandingNewEq(int n)178     bool checkOutstandingNewEq(int n) const
179     {
180         return disable_checking || n == outstanding_new;
181     }
182 
checkOutstandingNewNotEq(int n)183     bool checkOutstandingNewNotEq(int n) const
184     {
185         return disable_checking || n != outstanding_new;
186     }
187 
checkNewCalledEq(int n)188     bool checkNewCalledEq(int n) const
189     {
190         return disable_checking || n == new_called;
191     }
192 
checkNewCalledNotEq(int n)193     bool checkNewCalledNotEq(int n) const
194     {
195         return disable_checking || n != new_called;
196     }
197 
checkNewCalledGreaterThan(int n)198     bool checkNewCalledGreaterThan(int n) const
199     {
200         return disable_checking || new_called > n;
201     }
202 
checkDeleteCalledEq(int n)203     bool checkDeleteCalledEq(int n) const
204     {
205         return disable_checking || n == delete_called;
206     }
207 
checkDeleteCalledNotEq(int n)208     bool checkDeleteCalledNotEq(int n) const
209     {
210         return disable_checking || n != delete_called;
211     }
212 
checkAlignedNewCalledEq(int n)213     bool checkAlignedNewCalledEq(int n) const
214     {
215         return disable_checking || n == aligned_new_called;
216     }
217 
checkAlignedNewCalledNotEq(int n)218     bool checkAlignedNewCalledNotEq(int n) const
219     {
220         return disable_checking || n != aligned_new_called;
221     }
222 
checkAlignedNewCalledGreaterThan(int n)223     bool checkAlignedNewCalledGreaterThan(int n) const
224     {
225         return disable_checking || aligned_new_called > n;
226     }
227 
checkAlignedDeleteCalledEq(int n)228     bool checkAlignedDeleteCalledEq(int n) const
229     {
230         return disable_checking || n == aligned_delete_called;
231     }
232 
checkAlignedDeleteCalledNotEq(int n)233     bool checkAlignedDeleteCalledNotEq(int n) const
234     {
235         return disable_checking || n != aligned_delete_called;
236     }
237 
checkLastNewSizeEq(std::size_t n)238     bool checkLastNewSizeEq(std::size_t n) const
239     {
240         return disable_checking || n == last_new_size;
241     }
242 
checkLastNewSizeNotEq(std::size_t n)243     bool checkLastNewSizeNotEq(std::size_t n) const
244     {
245         return disable_checking || n != last_new_size;
246     }
247 
checkLastNewAlignEq(std::size_t n)248     bool checkLastNewAlignEq(std::size_t n) const
249     {
250         return disable_checking || n == last_new_align;
251     }
252 
checkLastNewAlignNotEq(std::size_t n)253     bool checkLastNewAlignNotEq(std::size_t n) const
254     {
255         return disable_checking || n != last_new_align;
256     }
257 
checkLastDeleteAlignEq(std::size_t n)258     bool checkLastDeleteAlignEq(std::size_t n) const
259     {
260         return disable_checking || n == last_delete_align;
261     }
262 
checkLastDeleteAlignNotEq(std::size_t n)263     bool checkLastDeleteAlignNotEq(std::size_t n) const
264     {
265         return disable_checking || n != last_delete_align;
266     }
267 
checkOutstandingArrayNewEq(int n)268     bool checkOutstandingArrayNewEq(int n) const
269     {
270         return disable_checking || n == outstanding_array_new;
271     }
272 
checkOutstandingArrayNewNotEq(int n)273     bool checkOutstandingArrayNewNotEq(int n) const
274     {
275         return disable_checking || n != outstanding_array_new;
276     }
277 
checkNewArrayCalledEq(int n)278     bool checkNewArrayCalledEq(int n) const
279     {
280         return disable_checking || n == new_array_called;
281     }
282 
checkNewArrayCalledNotEq(int n)283     bool checkNewArrayCalledNotEq(int n) const
284     {
285         return disable_checking || n != new_array_called;
286     }
287 
checkDeleteArrayCalledEq(int n)288     bool checkDeleteArrayCalledEq(int n) const
289     {
290         return disable_checking || n == delete_array_called;
291     }
292 
checkDeleteArrayCalledNotEq(int n)293     bool checkDeleteArrayCalledNotEq(int n) const
294     {
295         return disable_checking || n != delete_array_called;
296     }
297 
checkAlignedNewArrayCalledEq(int n)298     bool checkAlignedNewArrayCalledEq(int n) const
299     {
300         return disable_checking || n == aligned_new_array_called;
301     }
302 
checkAlignedNewArrayCalledNotEq(int n)303     bool checkAlignedNewArrayCalledNotEq(int n) const
304     {
305         return disable_checking || n != aligned_new_array_called;
306     }
307 
checkAlignedNewArrayCalledGreaterThan(int n)308     bool checkAlignedNewArrayCalledGreaterThan(int n) const
309     {
310         return disable_checking || aligned_new_array_called > n;
311     }
312 
checkAlignedDeleteArrayCalledEq(int n)313     bool checkAlignedDeleteArrayCalledEq(int n) const
314     {
315         return disable_checking || n == aligned_delete_array_called;
316     }
317 
checkAlignedDeleteArrayCalledNotEq(int n)318     bool checkAlignedDeleteArrayCalledNotEq(int n) const
319     {
320         return disable_checking || n != aligned_delete_array_called;
321     }
322 
checkLastNewArraySizeEq(std::size_t n)323     bool checkLastNewArraySizeEq(std::size_t n) const
324     {
325         return disable_checking || n == last_new_array_size;
326     }
327 
checkLastNewArraySizeNotEq(std::size_t n)328     bool checkLastNewArraySizeNotEq(std::size_t n) const
329     {
330         return disable_checking || n != last_new_array_size;
331     }
332 
checkLastNewArrayAlignEq(std::size_t n)333     bool checkLastNewArrayAlignEq(std::size_t n) const
334     {
335         return disable_checking || n == last_new_array_align;
336     }
337 
checkLastNewArrayAlignNotEq(std::size_t n)338     bool checkLastNewArrayAlignNotEq(std::size_t n) const
339     {
340         return disable_checking || n != last_new_array_align;
341     }
342 };
343 
344 #ifdef DISABLE_NEW_COUNT
345   const bool MemCounter::disable_checking = true;
346 #else
347   const bool MemCounter::disable_checking = false;
348 #endif
349 
350 #ifdef _MSC_VER
351 #pragma warning(push)
352 #pragma warning(disable: 4640) // '%s' construction of local static object is not thread safe (/Zc:threadSafeInit-)
353 #endif // _MSC_VER
getGlobalMemCounter()354 inline MemCounter* getGlobalMemCounter() {
355   static MemCounter counter((MemCounter::MemCounterCtorArg_()));
356   return &counter;
357 }
358 #ifdef _MSC_VER
359 #pragma warning(pop)
360 #endif
361 
362 MemCounter &globalMemCounter = *getGlobalMemCounter();
363 
364 #ifndef DISABLE_NEW_COUNT
new(std::size_t s)365 void* operator new(std::size_t s) TEST_THROW_SPEC(std::bad_alloc)
366 {
367     getGlobalMemCounter()->newCalled(s);
368     void* ret = std::malloc(s);
369     if (ret == nullptr)
370         detail::throw_bad_alloc_helper();
371     return ret;
372 }
373 
delete(void * p)374 void  operator delete(void* p) TEST_NOEXCEPT
375 {
376     getGlobalMemCounter()->deleteCalled(p);
377     std::free(p);
378 }
379 
TEST_THROW_SPEC(std::bad_alloc)380 void* operator new[](std::size_t s) TEST_THROW_SPEC(std::bad_alloc)
381 {
382     getGlobalMemCounter()->newArrayCalled(s);
383     return operator new(s);
384 }
385 
386 void operator delete[](void* p) TEST_NOEXCEPT
387 {
388     getGlobalMemCounter()->deleteArrayCalled(p);
389     operator delete(p);
390 }
391 
392 #ifndef TEST_HAS_NO_ALIGNED_ALLOCATION
393 #if defined(_LIBCPP_MSVCRT_LIKE) || \
394   (!defined(_LIBCPP_VERSION) && defined(_WIN32))
395 #define USE_ALIGNED_ALLOC
396 #endif
397 
new(std::size_t s,std::align_val_t av)398 void* operator new(std::size_t s, std::align_val_t av) TEST_THROW_SPEC(std::bad_alloc) {
399   const std::size_t a = static_cast<std::size_t>(av);
400   getGlobalMemCounter()->alignedNewCalled(s, a);
401   void *ret;
402 #ifdef USE_ALIGNED_ALLOC
403   ret = _aligned_malloc(s, a);
404 #else
405   posix_memalign(&ret, a, s);
406 #endif
407   if (ret == nullptr)
408     detail::throw_bad_alloc_helper();
409   return ret;
410 }
411 
delete(void * p,std::align_val_t av)412 void operator delete(void *p, std::align_val_t av) TEST_NOEXCEPT {
413   const std::size_t a = static_cast<std::size_t>(av);
414   getGlobalMemCounter()->alignedDeleteCalled(p, a);
415   if (p) {
416 #ifdef USE_ALIGNED_ALLOC
417     ::_aligned_free(p);
418 #else
419     ::free(p);
420 #endif
421   }
422 }
423 
TEST_THROW_SPEC(std::bad_alloc)424 void* operator new[](std::size_t s, std::align_val_t av) TEST_THROW_SPEC(std::bad_alloc) {
425   const std::size_t a = static_cast<std::size_t>(av);
426   getGlobalMemCounter()->alignedNewArrayCalled(s, a);
427   return operator new(s, av);
428 }
429 
430 void operator delete[](void *p, std::align_val_t av) TEST_NOEXCEPT {
431   const std::size_t a = static_cast<std::size_t>(av);
432   getGlobalMemCounter()->alignedDeleteArrayCalled(p, a);
433   return operator delete(p, av);
434 }
435 
436 #endif // TEST_HAS_NO_ALIGNED_ALLOCATION
437 
438 #endif // DISABLE_NEW_COUNT
439 
440 struct DisableAllocationGuard {
m_disabledDisableAllocationGuard441     explicit DisableAllocationGuard(bool disable = true) : m_disabled(disable)
442     {
443         // Don't re-disable if already disabled.
444         if (globalMemCounter.disable_allocations == true) m_disabled = false;
445         if (m_disabled) globalMemCounter.disableAllocations();
446     }
447 
releaseDisableAllocationGuard448     void release() {
449         if (m_disabled) globalMemCounter.enableAllocations();
450         m_disabled = false;
451     }
452 
~DisableAllocationGuardDisableAllocationGuard453     ~DisableAllocationGuard() {
454         release();
455     }
456 
457 private:
458     bool m_disabled;
459 
460     DisableAllocationGuard(DisableAllocationGuard const&);
461     DisableAllocationGuard& operator=(DisableAllocationGuard const&);
462 };
463 
464 struct RequireAllocationGuard {
465     explicit RequireAllocationGuard(std::size_t RequireAtLeast = 1)
m_req_allocRequireAllocationGuard466             : m_req_alloc(RequireAtLeast),
467               m_new_count_on_init(globalMemCounter.new_called),
468               m_outstanding_new_on_init(globalMemCounter.outstanding_new),
469               m_exactly(false)
470     {
471     }
472 
requireAtLeastRequireAllocationGuard473     void requireAtLeast(std::size_t N) { m_req_alloc = N; m_exactly = false; }
requireExactlyRequireAllocationGuard474     void requireExactly(std::size_t N) { m_req_alloc = N; m_exactly = true; }
475 
~RequireAllocationGuardRequireAllocationGuard476     ~RequireAllocationGuard() {
477         assert(globalMemCounter.checkOutstandingNewEq(static_cast<int>(m_outstanding_new_on_init)));
478         std::size_t Expect = m_new_count_on_init + m_req_alloc;
479         assert(globalMemCounter.checkNewCalledEq(static_cast<int>(Expect)) ||
480                (!m_exactly && globalMemCounter.checkNewCalledGreaterThan(static_cast<int>(Expect))));
481     }
482 
483 private:
484     std::size_t m_req_alloc;
485     const std::size_t m_new_count_on_init;
486     const std::size_t m_outstanding_new_on_init;
487     bool m_exactly;
488     RequireAllocationGuard(RequireAllocationGuard const&);
489     RequireAllocationGuard& operator=(RequireAllocationGuard const&);
490 };
491 
492 #endif /* COUNT_NEW_H */
493