1 #pragma once
2 
3 #include "catch.hpp"
4 #include <objbase.h>
5 #include <wil/wistd_functional.h>
6 #include <wrl/implements.h>
7 
8 #if WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_DESKTOP | WINAPI_PARTITION_SYSTEM)
9 
10 // IMallocSpy requires you to implement all methods, but we often only want one or two...
11 struct MallocSpy : Microsoft::WRL::RuntimeClass<Microsoft::WRL::RuntimeClassFlags<Microsoft::WRL::ClassicCom>, IMallocSpy>
12 {
13     wistd::function<SIZE_T(SIZE_T)> PreAllocCallback;
PreAllocMallocSpy14     virtual SIZE_T STDMETHODCALLTYPE PreAlloc(SIZE_T requestSize) override
15     {
16         if (PreAllocCallback)
17         {
18             return PreAllocCallback(requestSize);
19         }
20 
21         return requestSize;
22     }
23 
24     wistd::function<void*(void*)> PostAllocCallback;
PostAllocMallocSpy25     virtual void* STDMETHODCALLTYPE PostAlloc(void* ptr) override
26     {
27         if (PostAllocCallback)
28         {
29             return PostAllocCallback(ptr);
30         }
31 
32         return ptr;
33     }
34 
35     wistd::function<void*(void*)> PreFreeCallback;
PreFreeMallocSpy36     virtual void* STDMETHODCALLTYPE PreFree(void* ptr, BOOL wasSpyed) override
37     {
38         if (wasSpyed && PreFreeCallback)
39         {
40             return PreFreeCallback(ptr);
41         }
42 
43         return ptr;
44     }
45 
PostFreeMallocSpy46     virtual void STDMETHODCALLTYPE PostFree(BOOL /*wasSpyed*/) override
47     {
48     }
49 
50     wistd::function<SIZE_T(void*, SIZE_T, void**)> PreReallocCallback;
PreReallocMallocSpy51     virtual SIZE_T STDMETHODCALLTYPE PreRealloc(void* ptr, SIZE_T requestSize, void** newPtr, BOOL wasSpyed) override
52     {
53         *newPtr = ptr;
54         if (wasSpyed && PreReallocCallback)
55         {
56             return PreReallocCallback(ptr, requestSize, newPtr);
57         }
58 
59         return requestSize;
60     }
61 
62     wistd::function<void*(void*)> PostReallocCallback;
PostReallocMallocSpy63     virtual void* STDMETHODCALLTYPE PostRealloc(void* ptr, BOOL wasSpyed) override
64     {
65         if (wasSpyed && PostReallocCallback)
66         {
67             return PostReallocCallback(ptr);
68         }
69 
70         return ptr;
71     }
72 
73     wistd::function<void*(void*)> PreGetSizeCallback;
PreGetSizeMallocSpy74     virtual void* STDMETHODCALLTYPE PreGetSize(void* ptr, BOOL wasSpyed) override
75     {
76         if (wasSpyed && PreGetSizeCallback)
77         {
78             return PreGetSizeCallback(ptr);
79         }
80 
81         return ptr;
82     }
83 
84     wistd::function<SIZE_T(SIZE_T)> PostGetSizeCallback;
PostGetSizeMallocSpy85     virtual SIZE_T STDMETHODCALLTYPE PostGetSize(SIZE_T size, BOOL wasSpyed) override
86     {
87         if (wasSpyed && PostGetSizeCallback)
88         {
89             return PostGetSizeCallback(size);
90         }
91 
92         return size;
93     }
94 
95     wistd::function<void*(void*)> PreDidAllocCallback;
PreDidAllocMallocSpy96     virtual void* STDMETHODCALLTYPE PreDidAlloc(void* ptr, BOOL wasSpyed) override
97     {
98         if (wasSpyed && PreDidAllocCallback)
99         {
100             return PreDidAllocCallback(ptr);
101         }
102 
103         return ptr;
104     }
105 
PostDidAllocMallocSpy106     virtual int STDMETHODCALLTYPE PostDidAlloc(void* /*ptr*/, BOOL /*wasSpyed*/, int result) override
107     {
108         return result;
109     }
110 
PreHeapMinimizeMallocSpy111     virtual void STDMETHODCALLTYPE PreHeapMinimize() override
112     {
113     }
114 
PostHeapMinimizeMallocSpy115     virtual void STDMETHODCALLTYPE PostHeapMinimize() override
116     {
117     }
118 };
119 
MakeSecureDeleterMallocSpy()120 Microsoft::WRL::ComPtr<MallocSpy> MakeSecureDeleterMallocSpy()
121 {
122     using namespace Microsoft::WRL;
123     auto result = Make<MallocSpy>();
124     REQUIRE(result);
125 
126     result->PreFreeCallback = [](void* ptr)
127     {
128         ComPtr<IMalloc> malloc;
129         if (SUCCEEDED(::CoGetMalloc(1, &malloc)))
130         {
131             auto size = malloc->GetSize(ptr);
132             auto buffer = static_cast<byte*>(ptr);
133             for (size_t i = 0; i < size; ++i)
134             {
135                 REQUIRE(buffer[i] == 0);
136             }
137         }
138 
139         return ptr;
140     };
141 
142     return result;
143 }
144 
145 #endif
146