1 //===----------------------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef COUNT_NEW_H
10 #define COUNT_NEW_H
11 
12 # include <cstdlib>
13 # include <cassert>
14 # include <new>
15 
16 #include "test_macros.h"
17 
18 #if defined(TEST_HAS_SANITIZERS)
19 #define DISABLE_NEW_COUNT
20 #endif
21 
22 namespace detail
23 {
24    TEST_NORETURN
throw_bad_alloc_helper()25    inline void throw_bad_alloc_helper() {
26 #ifndef TEST_HAS_NO_EXCEPTIONS
27        throw std::bad_alloc();
28 #else
29        std::abort();
30 #endif
31    }
32 }
33 
34 class MemCounter
35 {
36 public:
37     // Make MemCounter super hard to accidentally construct or copy.
38     class MemCounterCtorArg_ {};
MemCounter(MemCounterCtorArg_)39     explicit MemCounter(MemCounterCtorArg_) { reset(); }
40 
41 private:
42     MemCounter(MemCounter const &);
43     MemCounter & operator=(MemCounter const &);
44 
45 public:
46     // All checks return true when disable_checking is enabled.
47     static const bool disable_checking;
48 
49     // Disallow any allocations from occurring. Useful for testing that
50     // code doesn't perform any allocations.
51     bool disable_allocations;
52 
53     // number of allocations to throw after. Default (unsigned)-1. If
54     // throw_after has the default value it will never be decremented.
55     static const unsigned never_throw_value = static_cast<unsigned>(-1);
56     unsigned throw_after;
57 
58     int outstanding_new;
59     int new_called;
60     int delete_called;
61     int aligned_new_called;
62     int aligned_delete_called;
63     std::size_t last_new_size;
64     std::size_t last_new_align;
65     std::size_t last_delete_align;
66 
67     int outstanding_array_new;
68     int new_array_called;
69     int delete_array_called;
70     int aligned_new_array_called;
71     int aligned_delete_array_called;
72     std::size_t last_new_array_size;
73     std::size_t last_new_array_align;
74     std::size_t last_delete_array_align;
75 
76 public:
newCalled(std::size_t s)77     void newCalled(std::size_t s)
78     {
79         assert(disable_allocations == false);
80         assert(s);
81         if (throw_after == 0) {
82             throw_after = never_throw_value;
83             detail::throw_bad_alloc_helper();
84         } else if (throw_after != never_throw_value) {
85             --throw_after;
86         }
87         ++new_called;
88         ++outstanding_new;
89         last_new_size = s;
90     }
91 
alignedNewCalled(std::size_t s,std::size_t a)92     void alignedNewCalled(std::size_t s, std::size_t a) {
93       newCalled(s);
94       ++aligned_new_called;
95       last_new_align = a;
96     }
97 
deleteCalled(void * p)98     void deleteCalled(void * p)
99     {
100         assert(p);
101         --outstanding_new;
102         ++delete_called;
103     }
104 
alignedDeleteCalled(void * p,std::size_t a)105     void alignedDeleteCalled(void *p, std::size_t a) {
106       deleteCalled(p);
107       ++aligned_delete_called;
108       last_delete_align = a;
109     }
110 
newArrayCalled(std::size_t s)111     void newArrayCalled(std::size_t s)
112     {
113         assert(disable_allocations == false);
114         assert(s);
115         if (throw_after == 0) {
116             throw_after = never_throw_value;
117             detail::throw_bad_alloc_helper();
118         } else {
119             // don't decrement throw_after here. newCalled will end up doing that.
120         }
121         ++outstanding_array_new;
122         ++new_array_called;
123         last_new_array_size = s;
124     }
125 
alignedNewArrayCalled(std::size_t s,std::size_t a)126     void alignedNewArrayCalled(std::size_t s, std::size_t a) {
127       newArrayCalled(s);
128       ++aligned_new_array_called;
129       last_new_array_align = a;
130     }
131 
deleteArrayCalled(void * p)132     void deleteArrayCalled(void * p)
133     {
134         assert(p);
135         --outstanding_array_new;
136         ++delete_array_called;
137     }
138 
alignedDeleteArrayCalled(void * p,std::size_t a)139     void alignedDeleteArrayCalled(void * p, std::size_t a) {
140       deleteArrayCalled(p);
141       ++aligned_delete_array_called;
142       last_delete_array_align = a;
143     }
144 
disableAllocations()145     void disableAllocations()
146     {
147         disable_allocations = true;
148     }
149 
enableAllocations()150     void enableAllocations()
151     {
152         disable_allocations = false;
153     }
154 
reset()155     void reset()
156     {
157         disable_allocations = false;
158         throw_after = never_throw_value;
159 
160         outstanding_new = 0;
161         new_called = 0;
162         delete_called = 0;
163         aligned_new_called = 0;
164         aligned_delete_called = 0;
165         last_new_size = 0;
166         last_new_align = 0;
167 
168         outstanding_array_new = 0;
169         new_array_called = 0;
170         delete_array_called = 0;
171         aligned_new_array_called = 0;
172         aligned_delete_array_called = 0;
173         last_new_array_size = 0;
174         last_new_array_align = 0;
175     }
176 
177 public:
checkOutstandingNewEq(int n)178     bool checkOutstandingNewEq(int n) const
179     {
180         return disable_checking || n == outstanding_new;
181     }
182 
checkOutstandingNewNotEq(int n)183     bool checkOutstandingNewNotEq(int n) const
184     {
185         return disable_checking || n != outstanding_new;
186     }
187 
checkNewCalledEq(int n)188     bool checkNewCalledEq(int n) const
189     {
190         return disable_checking || n == new_called;
191     }
192 
checkNewCalledNotEq(int n)193     bool checkNewCalledNotEq(int n) const
194     {
195         return disable_checking || n != new_called;
196     }
197 
checkNewCalledGreaterThan(int n)198     bool checkNewCalledGreaterThan(int n) const
199     {
200         return disable_checking || new_called > n;
201     }
202 
checkDeleteCalledEq(int n)203     bool checkDeleteCalledEq(int n) const
204     {
205         return disable_checking || n == delete_called;
206     }
207 
checkDeleteCalledNotEq(int n)208     bool checkDeleteCalledNotEq(int n) const
209     {
210         return disable_checking || n != delete_called;
211     }
212 
checkAlignedNewCalledEq(int n)213     bool checkAlignedNewCalledEq(int n) const
214     {
215         return disable_checking || n == aligned_new_called;
216     }
217 
checkAlignedNewCalledNotEq(int n)218     bool checkAlignedNewCalledNotEq(int n) const
219     {
220         return disable_checking || n != aligned_new_called;
221     }
222 
checkAlignedNewCalledGreaterThan(int n)223     bool checkAlignedNewCalledGreaterThan(int n) const
224     {
225         return disable_checking || aligned_new_called > n;
226     }
227 
checkAlignedDeleteCalledEq(int n)228     bool checkAlignedDeleteCalledEq(int n) const
229     {
230         return disable_checking || n == aligned_delete_called;
231     }
232 
checkAlignedDeleteCalledNotEq(int n)233     bool checkAlignedDeleteCalledNotEq(int n) const
234     {
235         return disable_checking || n != aligned_delete_called;
236     }
237 
checkLastNewSizeEq(std::size_t n)238     bool checkLastNewSizeEq(std::size_t n) const
239     {
240         return disable_checking || n == last_new_size;
241     }
242 
checkLastNewSizeNotEq(std::size_t n)243     bool checkLastNewSizeNotEq(std::size_t n) const
244     {
245         return disable_checking || n != last_new_size;
246     }
247 
checkLastNewAlignEq(std::size_t n)248     bool checkLastNewAlignEq(std::size_t n) const
249     {
250         return disable_checking || n == last_new_align;
251     }
252 
checkLastNewAlignNotEq(std::size_t n)253     bool checkLastNewAlignNotEq(std::size_t n) const
254     {
255         return disable_checking || n != last_new_align;
256     }
257 
checkLastDeleteAlignEq(std::size_t n)258     bool checkLastDeleteAlignEq(std::size_t n) const
259     {
260         return disable_checking || n == last_delete_align;
261     }
262 
checkLastDeleteAlignNotEq(std::size_t n)263     bool checkLastDeleteAlignNotEq(std::size_t n) const
264     {
265         return disable_checking || n != last_delete_align;
266     }
267 
checkOutstandingArrayNewEq(int n)268     bool checkOutstandingArrayNewEq(int n) const
269     {
270         return disable_checking || n == outstanding_array_new;
271     }
272 
checkOutstandingArrayNewNotEq(int n)273     bool checkOutstandingArrayNewNotEq(int n) const
274     {
275         return disable_checking || n != outstanding_array_new;
276     }
277 
checkNewArrayCalledEq(int n)278     bool checkNewArrayCalledEq(int n) const
279     {
280         return disable_checking || n == new_array_called;
281     }
282 
checkNewArrayCalledNotEq(int n)283     bool checkNewArrayCalledNotEq(int n) const
284     {
285         return disable_checking || n != new_array_called;
286     }
287 
checkDeleteArrayCalledEq(int n)288     bool checkDeleteArrayCalledEq(int n) const
289     {
290         return disable_checking || n == delete_array_called;
291     }
292 
checkDeleteArrayCalledNotEq(int n)293     bool checkDeleteArrayCalledNotEq(int n) const
294     {
295         return disable_checking || n != delete_array_called;
296     }
297 
checkAlignedNewArrayCalledEq(int n)298     bool checkAlignedNewArrayCalledEq(int n) const
299     {
300         return disable_checking || n == aligned_new_array_called;
301     }
302 
checkAlignedNewArrayCalledNotEq(int n)303     bool checkAlignedNewArrayCalledNotEq(int n) const
304     {
305         return disable_checking || n != aligned_new_array_called;
306     }
307 
checkAlignedNewArrayCalledGreaterThan(int n)308     bool checkAlignedNewArrayCalledGreaterThan(int n) const
309     {
310         return disable_checking || aligned_new_array_called > n;
311     }
312 
checkAlignedDeleteArrayCalledEq(int n)313     bool checkAlignedDeleteArrayCalledEq(int n) const
314     {
315         return disable_checking || n == aligned_delete_array_called;
316     }
317 
checkAlignedDeleteArrayCalledNotEq(int n)318     bool checkAlignedDeleteArrayCalledNotEq(int n) const
319     {
320         return disable_checking || n != aligned_delete_array_called;
321     }
322 
checkLastNewArraySizeEq(std::size_t n)323     bool checkLastNewArraySizeEq(std::size_t n) const
324     {
325         return disable_checking || n == last_new_array_size;
326     }
327 
checkLastNewArraySizeNotEq(std::size_t n)328     bool checkLastNewArraySizeNotEq(std::size_t n) const
329     {
330         return disable_checking || n != last_new_array_size;
331     }
332 
checkLastNewArrayAlignEq(std::size_t n)333     bool checkLastNewArrayAlignEq(std::size_t n) const
334     {
335         return disable_checking || n == last_new_array_align;
336     }
337 
checkLastNewArrayAlignNotEq(std::size_t n)338     bool checkLastNewArrayAlignNotEq(std::size_t n) const
339     {
340         return disable_checking || n != last_new_array_align;
341     }
342 };
343 
344 #ifdef DISABLE_NEW_COUNT
345   const bool MemCounter::disable_checking = true;
346 #else
347   const bool MemCounter::disable_checking = false;
348 #endif
349 
getGlobalMemCounter()350 inline MemCounter* getGlobalMemCounter() {
351   static MemCounter counter((MemCounter::MemCounterCtorArg_()));
352   return &counter;
353 }
354 
355 MemCounter &globalMemCounter = *getGlobalMemCounter();
356 
357 #ifndef DISABLE_NEW_COUNT
new(std::size_t s)358 void* operator new(std::size_t s) TEST_THROW_SPEC(std::bad_alloc)
359 {
360     getGlobalMemCounter()->newCalled(s);
361     void* ret = std::malloc(s);
362     if (ret == nullptr)
363         detail::throw_bad_alloc_helper();
364     return ret;
365 }
366 
delete(void * p)367 void  operator delete(void* p) TEST_NOEXCEPT
368 {
369     getGlobalMemCounter()->deleteCalled(p);
370     std::free(p);
371 }
372 
TEST_THROW_SPEC(std::bad_alloc)373 void* operator new[](std::size_t s) TEST_THROW_SPEC(std::bad_alloc)
374 {
375     getGlobalMemCounter()->newArrayCalled(s);
376     return operator new(s);
377 }
378 
379 void operator delete[](void* p) TEST_NOEXCEPT
380 {
381     getGlobalMemCounter()->deleteArrayCalled(p);
382     operator delete(p);
383 }
384 
385 #ifndef TEST_HAS_NO_ALIGNED_ALLOCATION
// Selects _aligned_malloc/_aligned_free instead of posix_memalign/free for
// aligned allocation: on MSVC-CRT-like libc++ configurations, and on Windows
// builds not using libc++ (presumably because those CRTs lack
// posix_memalign — confirm).
#if defined(_LIBCPP_MSVCRT_LIKE)
#  define USE_ALIGNED_ALLOC
#elif defined(_WIN32) && !defined(_LIBCPP_VERSION)
#  define USE_ALIGNED_ALLOC
#endif
390 
new(std::size_t s,std::align_val_t av)391 void* operator new(std::size_t s, std::align_val_t av) TEST_THROW_SPEC(std::bad_alloc) {
392   const std::size_t a = static_cast<std::size_t>(av);
393   getGlobalMemCounter()->alignedNewCalled(s, a);
394   void *ret;
395 #ifdef USE_ALIGNED_ALLOC
396   ret = _aligned_malloc(s, a);
397 #else
398   posix_memalign(&ret, a, s);
399 #endif
400   if (ret == nullptr)
401     detail::throw_bad_alloc_helper();
402   return ret;
403 }
404 
delete(void * p,std::align_val_t av)405 void operator delete(void *p, std::align_val_t av) TEST_NOEXCEPT {
406   const std::size_t a = static_cast<std::size_t>(av);
407   getGlobalMemCounter()->alignedDeleteCalled(p, a);
408   if (p) {
409 #ifdef USE_ALIGNED_ALLOC
410     ::_aligned_free(p);
411 #else
412     ::free(p);
413 #endif
414   }
415 }
416 
TEST_THROW_SPEC(std::bad_alloc)417 void* operator new[](std::size_t s, std::align_val_t av) TEST_THROW_SPEC(std::bad_alloc) {
418   const std::size_t a = static_cast<std::size_t>(av);
419   getGlobalMemCounter()->alignedNewArrayCalled(s, a);
420   return operator new(s, av);
421 }
422 
423 void operator delete[](void *p, std::align_val_t av) TEST_NOEXCEPT {
424   const std::size_t a = static_cast<std::size_t>(av);
425   getGlobalMemCounter()->alignedDeleteArrayCalled(p, a);
426   return operator delete(p, av);
427 }
428 
429 #endif // TEST_HAS_NO_ALIGNED_ALLOCATION
430 
431 #endif // DISABLE_NEW_COUNT
432 
433 struct DisableAllocationGuard {
m_disabledDisableAllocationGuard434     explicit DisableAllocationGuard(bool disable = true) : m_disabled(disable)
435     {
436         // Don't re-disable if already disabled.
437         if (globalMemCounter.disable_allocations == true) m_disabled = false;
438         if (m_disabled) globalMemCounter.disableAllocations();
439     }
440 
releaseDisableAllocationGuard441     void release() {
442         if (m_disabled) globalMemCounter.enableAllocations();
443         m_disabled = false;
444     }
445 
~DisableAllocationGuardDisableAllocationGuard446     ~DisableAllocationGuard() {
447         release();
448     }
449 
450 private:
451     bool m_disabled;
452 
453     DisableAllocationGuard(DisableAllocationGuard const&);
454     DisableAllocationGuard& operator=(DisableAllocationGuard const&);
455 };
456 
457 struct RequireAllocationGuard {
458     explicit RequireAllocationGuard(std::size_t RequireAtLeast = 1)
m_req_allocRequireAllocationGuard459             : m_req_alloc(RequireAtLeast),
460               m_new_count_on_init(globalMemCounter.new_called),
461               m_outstanding_new_on_init(globalMemCounter.outstanding_new),
462               m_exactly(false)
463     {
464     }
465 
requireAtLeastRequireAllocationGuard466     void requireAtLeast(std::size_t N) { m_req_alloc = N; m_exactly = false; }
requireExactlyRequireAllocationGuard467     void requireExactly(std::size_t N) { m_req_alloc = N; m_exactly = true; }
468 
~RequireAllocationGuardRequireAllocationGuard469     ~RequireAllocationGuard() {
470         assert(globalMemCounter.checkOutstandingNewEq(static_cast<int>(m_outstanding_new_on_init)));
471         std::size_t Expect = m_new_count_on_init + m_req_alloc;
472         assert(globalMemCounter.checkNewCalledEq(static_cast<int>(Expect)) ||
473                (!m_exactly && globalMemCounter.checkNewCalledGreaterThan(static_cast<int>(Expect))));
474     }
475 
476 private:
477     std::size_t m_req_alloc;
478     const std::size_t m_new_count_on_init;
479     const std::size_t m_outstanding_new_on_init;
480     bool m_exactly;
481     RequireAllocationGuard(RequireAllocationGuard const&);
482     RequireAllocationGuard& operator=(RequireAllocationGuard const&);
483 };
484 
485 #endif /* COUNT_NEW_H */
486