1 #pragma once
2
3 #include "catch.hpp"
4 #include <objbase.h>
5 #include <wil/wistd_functional.h>
6 #include <wrl/implements.h>
7
8 #if WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_DESKTOP | WINAPI_PARTITION_SYSTEM)
9
10 // IMallocSpy requires you to implement all methods, but we often only want one or two...
11 struct MallocSpy : Microsoft::WRL::RuntimeClass<Microsoft::WRL::RuntimeClassFlags<Microsoft::WRL::ClassicCom>, IMallocSpy>
12 {
13 wistd::function<SIZE_T(SIZE_T)> PreAllocCallback;
PreAllocMallocSpy14 virtual SIZE_T STDMETHODCALLTYPE PreAlloc(SIZE_T requestSize) override
15 {
16 if (PreAllocCallback)
17 {
18 return PreAllocCallback(requestSize);
19 }
20
21 return requestSize;
22 }
23
24 wistd::function<void*(void*)> PostAllocCallback;
PostAllocMallocSpy25 virtual void* STDMETHODCALLTYPE PostAlloc(void* ptr) override
26 {
27 if (PostAllocCallback)
28 {
29 return PostAllocCallback(ptr);
30 }
31
32 return ptr;
33 }
34
35 wistd::function<void*(void*)> PreFreeCallback;
PreFreeMallocSpy36 virtual void* STDMETHODCALLTYPE PreFree(void* ptr, BOOL wasSpyed) override
37 {
38 if (wasSpyed && PreFreeCallback)
39 {
40 return PreFreeCallback(ptr);
41 }
42
43 return ptr;
44 }
45
PostFreeMallocSpy46 virtual void STDMETHODCALLTYPE PostFree(BOOL /*wasSpyed*/) override
47 {
48 }
49
50 wistd::function<SIZE_T(void*, SIZE_T, void**)> PreReallocCallback;
PreReallocMallocSpy51 virtual SIZE_T STDMETHODCALLTYPE PreRealloc(void* ptr, SIZE_T requestSize, void** newPtr, BOOL wasSpyed) override
52 {
53 *newPtr = ptr;
54 if (wasSpyed && PreReallocCallback)
55 {
56 return PreReallocCallback(ptr, requestSize, newPtr);
57 }
58
59 return requestSize;
60 }
61
62 wistd::function<void*(void*)> PostReallocCallback;
PostReallocMallocSpy63 virtual void* STDMETHODCALLTYPE PostRealloc(void* ptr, BOOL wasSpyed) override
64 {
65 if (wasSpyed && PostReallocCallback)
66 {
67 return PostReallocCallback(ptr);
68 }
69
70 return ptr;
71 }
72
73 wistd::function<void*(void*)> PreGetSizeCallback;
PreGetSizeMallocSpy74 virtual void* STDMETHODCALLTYPE PreGetSize(void* ptr, BOOL wasSpyed) override
75 {
76 if (wasSpyed && PreGetSizeCallback)
77 {
78 return PreGetSizeCallback(ptr);
79 }
80
81 return ptr;
82 }
83
84 wistd::function<SIZE_T(SIZE_T)> PostGetSizeCallback;
PostGetSizeMallocSpy85 virtual SIZE_T STDMETHODCALLTYPE PostGetSize(SIZE_T size, BOOL wasSpyed) override
86 {
87 if (wasSpyed && PostGetSizeCallback)
88 {
89 return PostGetSizeCallback(size);
90 }
91
92 return size;
93 }
94
95 wistd::function<void*(void*)> PreDidAllocCallback;
PreDidAllocMallocSpy96 virtual void* STDMETHODCALLTYPE PreDidAlloc(void* ptr, BOOL wasSpyed) override
97 {
98 if (wasSpyed && PreDidAllocCallback)
99 {
100 return PreDidAllocCallback(ptr);
101 }
102
103 return ptr;
104 }
105
PostDidAllocMallocSpy106 virtual int STDMETHODCALLTYPE PostDidAlloc(void* /*ptr*/, BOOL /*wasSpyed*/, int result) override
107 {
108 return result;
109 }
110
PreHeapMinimizeMallocSpy111 virtual void STDMETHODCALLTYPE PreHeapMinimize() override
112 {
113 }
114
PostHeapMinimizeMallocSpy115 virtual void STDMETHODCALLTYPE PostHeapMinimize() override
116 {
117 }
118 };
119
MakeSecureDeleterMallocSpy()120 Microsoft::WRL::ComPtr<MallocSpy> MakeSecureDeleterMallocSpy()
121 {
122 using namespace Microsoft::WRL;
123 auto result = Make<MallocSpy>();
124 REQUIRE(result);
125
126 result->PreFreeCallback = [](void* ptr)
127 {
128 ComPtr<IMalloc> malloc;
129 if (SUCCEEDED(::CoGetMalloc(1, &malloc)))
130 {
131 auto size = malloc->GetSize(ptr);
132 auto buffer = static_cast<byte*>(ptr);
133 for (size_t i = 0; i < size; ++i)
134 {
135 REQUIRE(buffer[i] == 0);
136 }
137 }
138
139 return ptr;
140 };
141
142 return result;
143 }
144
145 #endif
146