1 // Licensed to the .NET Foundation under one or more agreements.
2 // The .NET Foundation licenses this file to you under the MIT license.
3 // See the LICENSE file in the project root for more information.
4 
5 #ifndef __SOFTWARE_WRITE_WATCH_H__
6 #define __SOFTWARE_WRITE_WATCH_H__
7 
8 #include "gcinterface.h"
9 #include "gc.h"
10 
11 #ifdef FEATURE_USE_SOFTWARE_WRITE_WATCH_FOR_GC_HEAP
12 #ifndef DACCESS_COMPILE
13 
// Globals backing SoftwareWriteWatch. They are only declared here (with C linkage) and are defined elsewhere.

// Table containing the dirty state. This table is translated to exclude the lowest address it represents (see
// TranslateTableToExcludeHeapStartAddress), so it is indexed directly by the shifted address.
extern "C" uint8_t *g_gc_sw_ww_table;

// Whether write watch is currently enabled for the GC heap. Write watch may be disabled when it is not needed (between
// GCs, for instance).
extern "C" bool g_gc_sw_ww_enabled_for_gc_heap;
23 
24 class SoftwareWriteWatch
25 {
26 private:
27     // The granularity of dirty state in the table is one page. Dirtiness is tracked per byte of the table so that
28     // synchronization is not required when changing the dirty state. Shifting-right an address by the following value yields
29     // the byte index of the address into the write watch table. For instance,
30     // GetTable()[address >> AddressToTableByteIndexShift] is the byte that represents the region of memory for 'address'.
31     static const uint8_t AddressToTableByteIndexShift = SOFTWARE_WRITE_WATCH_AddressToTableByteIndexShift;
32 
33 private:
34     static void VerifyCreated();
35     static void VerifyMemoryRegion(void *baseAddress, size_t regionByteSize);
36     static void VerifyMemoryRegion(void *baseAddress, size_t regionByteSize, void *heapStartAddress, void *heapEndAddress);
37 
38 public:
39     static uint8_t *GetTable();
40 private:
41     static uint8_t *GetUntranslatedTable();
42     static uint8_t *GetUntranslatedTable(uint8_t *table, void *heapStartAddress);
43     static uint8_t *GetUntranslatedTableEnd();
44     static uint8_t *GetUntranslatedTableEnd(uint8_t *table, void *heapEndAddress);
45 public:
46     static void InitializeUntranslatedTable(uint8_t *untranslatedTable, void *heapStartAddress);
47 private:
48     static void SetUntranslatedTable(uint8_t *untranslatedTable, void *heapStartAddress);
49 public:
50     static void SetResizedUntranslatedTable(uint8_t *untranslatedTable, void *heapStartAddress, void *heapEndAddress);
51     static bool IsEnabledForGCHeap();
52     static void EnableForGCHeap();
53     static void DisableForGCHeap();
54 private:
55     static void *GetHeapStartAddress();
56     static void *GetHeapEndAddress();
57 
58 public:
59     static void StaticClose();
60 
61 private:
62     static size_t GetTableByteIndex(void *address);
63     static void *GetPageAddress(size_t tableByteIndex);
64 public:
65     static size_t GetTableByteSize(void *heapStartAddress, void *heapEndAddress);
66     static size_t GetTableStartByteOffset(size_t byteSizeBeforeTable);
67 private:
68     static uint8_t *TranslateTableToExcludeHeapStartAddress(uint8_t *table, void *heapStartAddress);
69     static void TranslateToTableRegion(void *baseAddress, size_t regionByteSize, uint8_t **tableBaseAddressRef, size_t *tableRegionByteSizeRef);
70 
71 public:
72     static void ClearDirty(void *baseAddress, size_t regionByteSize);
73     static void SetDirty(void *address, size_t writeByteSize);
74     static void SetDirtyRegion(void *baseAddress, size_t regionByteSize);
75 private:
76     static bool GetDirtyFromBlock(uint8_t *block, uint8_t *firstPageAddressInBlock, size_t startByteIndex, size_t endByteIndex, void **dirtyPages, size_t *dirtyPageIndexRef, size_t dirtyPageCount, bool clearDirty);
77 public:
78     static void GetDirty(void *baseAddress, size_t regionByteSize, void **dirtyPages, size_t *dirtyPageCountRef, bool clearDirty, bool isRuntimeSuspended);
79 };
80 
// Sanity check that the write watch table has been set up: the translated table pointer must be non-null and the heap
// range it covers ([GetHeapStartAddress(), GetHeapEndAddress())) must be non-empty. The body consists solely of
// assertions.
inline void SoftwareWriteWatch::VerifyCreated()
{
    assert(GetTable() != nullptr);
    assert(GetHeapStartAddress() != nullptr);
    assert(GetHeapEndAddress() != nullptr);
    assert(GetHeapStartAddress() < GetHeapEndAddress());
}
88 
VerifyMemoryRegion(void * baseAddress,size_t regionByteSize)89 inline void SoftwareWriteWatch::VerifyMemoryRegion(void *baseAddress, size_t regionByteSize)
90 {
91     VerifyMemoryRegion(baseAddress, regionByteSize, GetHeapStartAddress(), GetHeapEndAddress());
92 }
93 
// Asserts that:
//   - the table has been created (VerifyCreated);
//   - [heapStartAddress, heapEndAddress) lies within the range currently covered by the table;
//   - [baseAddress, baseAddress + regionByteSize) is non-empty and lies entirely within [heapStartAddress, heapEndAddress).
// The body consists solely of assertions.
inline void SoftwareWriteWatch::VerifyMemoryRegion(
    void *baseAddress,
    size_t regionByteSize,
    void *heapStartAddress,
    void *heapEndAddress)
{
    VerifyCreated();
    assert(baseAddress != nullptr);
    assert(heapStartAddress != nullptr);
    assert(heapStartAddress >= GetHeapStartAddress());
    assert(heapEndAddress != nullptr);
    assert(heapEndAddress <= GetHeapEndAddress());
    assert(baseAddress >= heapStartAddress);
    assert(baseAddress < heapEndAddress);
    assert(regionByteSize != 0);
    // This subtraction cannot wrap because baseAddress < heapEndAddress is asserted above. A form based on adding
    // regionByteSize to baseAddress could overflow instead.
    assert(regionByteSize <= reinterpret_cast<size_t>(heapEndAddress) - reinterpret_cast<size_t>(baseAddress));
}
111 
GetTable()112 inline uint8_t *SoftwareWriteWatch::GetTable()
113 {
114     return g_gc_sw_ww_table;
115 }
116 
GetUntranslatedTable()117 inline uint8_t *SoftwareWriteWatch::GetUntranslatedTable()
118 {
119     VerifyCreated();
120     return GetUntranslatedTable(GetTable(), GetHeapStartAddress());
121 }
122 
GetUntranslatedTable(uint8_t * table,void * heapStartAddress)123 inline uint8_t *SoftwareWriteWatch::GetUntranslatedTable(uint8_t *table, void *heapStartAddress)
124 {
125     assert(table != nullptr);
126     assert(heapStartAddress != nullptr);
127     assert(heapStartAddress >= GetHeapStartAddress());
128 
129     uint8_t *untranslatedTable = table + GetTableByteIndex(heapStartAddress);
130     assert(ALIGN_DOWN(untranslatedTable, sizeof(size_t)) == untranslatedTable);
131     return untranslatedTable;
132 }
133 
GetUntranslatedTableEnd()134 inline uint8_t *SoftwareWriteWatch::GetUntranslatedTableEnd()
135 {
136     VerifyCreated();
137     return GetUntranslatedTableEnd(GetTable(), GetHeapEndAddress());
138 }
139 
GetUntranslatedTableEnd(uint8_t * table,void * heapEndAddress)140 inline uint8_t *SoftwareWriteWatch::GetUntranslatedTableEnd(uint8_t *table, void *heapEndAddress)
141 {
142     assert(table != nullptr);
143     assert(heapEndAddress != nullptr);
144     assert(heapEndAddress <= GetHeapEndAddress());
145 
146     return ALIGN_UP(&table[GetTableByteIndex(reinterpret_cast<uint8_t *>(heapEndAddress) - 1) + 1], sizeof(size_t));
147 }
148 
// One-time installation of the initial table. Asserts that no table has been installed yet, then installs
// untranslatedTable (whose first byte corresponds to heapStartAddress) through SetUntranslatedTable.
inline void SoftwareWriteWatch::InitializeUntranslatedTable(uint8_t *untranslatedTable, void *heapStartAddress)
{
    assert(GetTable() == nullptr); // must not already be initialized
    SetUntranslatedTable(untranslatedTable, heapStartAddress);
}
154 
SetUntranslatedTable(uint8_t * untranslatedTable,void * heapStartAddress)155 inline void SoftwareWriteWatch::SetUntranslatedTable(uint8_t *untranslatedTable, void *heapStartAddress)
156 {
157     assert(untranslatedTable != nullptr);
158     assert(ALIGN_DOWN(untranslatedTable, sizeof(size_t)) == untranslatedTable);
159     assert(heapStartAddress != nullptr);
160 
161     g_gc_sw_ww_table = TranslateTableToExcludeHeapStartAddress(untranslatedTable, heapStartAddress);
162 }
163 
SetResizedUntranslatedTable(uint8_t * untranslatedTable,void * heapStartAddress,void * heapEndAddress)164 inline void SoftwareWriteWatch::SetResizedUntranslatedTable(
165     uint8_t *untranslatedTable,
166     void *heapStartAddress,
167     void *heapEndAddress)
168 {
169     // The runtime needs to be suspended during this call, and background GC threads need to synchronize calls to ClearDirty()
170     // and GetDirty() such that they are not called concurrently with this function
171 
172     VerifyCreated();
173     assert(untranslatedTable != nullptr);
174     assert(ALIGN_DOWN(untranslatedTable, sizeof(size_t)) == untranslatedTable);
175     assert(heapStartAddress != nullptr);
176     assert(heapEndAddress != nullptr);
177     assert(heapStartAddress <= GetHeapStartAddress());
178     assert(heapEndAddress >= GetHeapEndAddress());
179     assert(heapStartAddress < GetHeapStartAddress() || heapEndAddress > GetHeapEndAddress());
180 
181     uint8_t *oldUntranslatedTable = GetUntranslatedTable();
182     void *oldTableHeapStartAddress = GetHeapStartAddress();
183     size_t oldTableByteSize = GetTableByteSize(oldTableHeapStartAddress, GetHeapEndAddress());
184     SetUntranslatedTable(untranslatedTable, heapStartAddress);
185 
186     uint8_t *tableRegionStart = &GetTable()[GetTableByteIndex(oldTableHeapStartAddress)];
187     memcpy(tableRegionStart, oldUntranslatedTable, oldTableByteSize);
188 }
189 
IsEnabledForGCHeap()190 inline bool SoftwareWriteWatch::IsEnabledForGCHeap()
191 {
192     return g_gc_sw_ww_enabled_for_gc_heap;
193 }
194 
EnableForGCHeap()195 inline void SoftwareWriteWatch::EnableForGCHeap()
196 {
197     // The runtime needs to be suspended during this call. This is how it currently guarantees that GC heap writes from other
198     // threads between calls to EnableForGCHeap() and DisableForGCHeap() will be tracked.
199 
200     VerifyCreated();
201     assert(!IsEnabledForGCHeap());
202     g_gc_sw_ww_enabled_for_gc_heap = true;
203 
204     WriteBarrierParameters args = {};
205     args.operation = WriteBarrierOp::SwitchToWriteWatch;
206     args.write_watch_table = g_gc_sw_ww_table;
207     args.is_runtime_suspended = true;
208     GCToEEInterface::StompWriteBarrier(&args);
209 }
210 
DisableForGCHeap()211 inline void SoftwareWriteWatch::DisableForGCHeap()
212 {
213     // The runtime needs to be suspended during this call. This is how it currently guarantees that GC heap writes from other
214     // threads between calls to EnableForGCHeap() and DisableForGCHeap() will be tracked.
215 
216     VerifyCreated();
217     assert(IsEnabledForGCHeap());
218     g_gc_sw_ww_enabled_for_gc_heap = false;
219 
220     WriteBarrierParameters args = {};
221     args.operation = WriteBarrierOp::SwitchToNonWriteWatch;
222     args.is_runtime_suspended = true;
223     GCToEEInterface::StompWriteBarrier(&args);
224 }
225 
GetHeapStartAddress()226 inline void *SoftwareWriteWatch::GetHeapStartAddress()
227 {
228     return g_gc_lowest_address;
229 }
230 
GetHeapEndAddress()231 inline void *SoftwareWriteWatch::GetHeapEndAddress()
232 {
233     return g_gc_highest_address;
234 }
235 
GetTableByteIndex(void * address)236 inline size_t SoftwareWriteWatch::GetTableByteIndex(void *address)
237 {
238     assert(address != nullptr);
239 
240     size_t tableByteIndex = reinterpret_cast<size_t>(address) >> AddressToTableByteIndexShift;
241     assert(tableByteIndex != 0);
242     return tableByteIndex;
243 }
244 
GetPageAddress(size_t tableByteIndex)245 inline void *SoftwareWriteWatch::GetPageAddress(size_t tableByteIndex)
246 {
247     assert(tableByteIndex != 0);
248 
249     void *pageAddress = reinterpret_cast<void *>(tableByteIndex << AddressToTableByteIndexShift);
250     assert(pageAddress >= GetHeapStartAddress());
251     assert(pageAddress < GetHeapEndAddress());
252     assert(ALIGN_DOWN(pageAddress, OS_PAGE_SIZE) == pageAddress);
253     return pageAddress;
254 }
255 
GetTableByteSize(void * heapStartAddress,void * heapEndAddress)256 inline size_t SoftwareWriteWatch::GetTableByteSize(void *heapStartAddress, void *heapEndAddress)
257 {
258     assert(heapStartAddress != nullptr);
259     assert(heapEndAddress != nullptr);
260     assert(heapStartAddress < heapEndAddress);
261 
262     size_t tableByteSize =
263         GetTableByteIndex(reinterpret_cast<uint8_t *>(heapEndAddress) - 1) - GetTableByteIndex(heapStartAddress) + 1;
264     tableByteSize = ALIGN_UP(tableByteSize, sizeof(size_t));
265     return tableByteSize;
266 }
267 
GetTableStartByteOffset(size_t byteSizeBeforeTable)268 inline size_t SoftwareWriteWatch::GetTableStartByteOffset(size_t byteSizeBeforeTable)
269 {
270     return ALIGN_UP(byteSizeBeforeTable, sizeof(size_t)); // start of the table needs to be aligned to size_t
271 }
272 
TranslateTableToExcludeHeapStartAddress(uint8_t * table,void * heapStartAddress)273 inline uint8_t *SoftwareWriteWatch::TranslateTableToExcludeHeapStartAddress(uint8_t *table, void *heapStartAddress)
274 {
275     assert(table != nullptr);
276     assert(heapStartAddress != nullptr);
277 
278     // Exclude the table byte index corresponding to the heap start address from the table pointer, so that each lookup in the
279     // table by address does not have to calculate (address - heapStartAddress)
280     return table - GetTableByteIndex(heapStartAddress);
281 }
282 
TranslateToTableRegion(void * baseAddress,size_t regionByteSize,uint8_t ** tableBaseAddressRef,size_t * tableRegionByteSizeRef)283 inline void SoftwareWriteWatch::TranslateToTableRegion(
284     void *baseAddress,
285     size_t regionByteSize,
286     uint8_t **tableBaseAddressRef,
287     size_t *tableRegionByteSizeRef)
288 {
289     VerifyCreated();
290     VerifyMemoryRegion(baseAddress, regionByteSize);
291     assert(tableBaseAddressRef != nullptr);
292     assert(tableRegionByteSizeRef != nullptr);
293 
294     size_t baseAddressTableByteIndex = GetTableByteIndex(baseAddress);
295     *tableBaseAddressRef = &GetTable()[baseAddressTableByteIndex];
296     *tableRegionByteSizeRef =
297         GetTableByteIndex(reinterpret_cast<uint8_t *>(baseAddress) + (regionByteSize - 1)) - baseAddressTableByteIndex + 1;
298 }
299 
ClearDirty(void * baseAddress,size_t regionByteSize)300 inline void SoftwareWriteWatch::ClearDirty(void *baseAddress, size_t regionByteSize)
301 {
302     VerifyCreated();
303     VerifyMemoryRegion(baseAddress, regionByteSize);
304 
305     uint8_t *tableBaseAddress;
306     size_t tableRegionByteSize;
307     TranslateToTableRegion(baseAddress, regionByteSize, &tableBaseAddress, &tableRegionByteSize);
308     memset(tableBaseAddress, 0, tableRegionByteSize);
309 }
310 
SetDirty(void * address,size_t writeByteSize)311 inline void SoftwareWriteWatch::SetDirty(void *address, size_t writeByteSize)
312 {
313     VerifyCreated();
314     VerifyMemoryRegion(address, writeByteSize);
315     assert(address != nullptr);
316     assert(writeByteSize <= sizeof(void *));
317 
318     size_t tableByteIndex = GetTableByteIndex(address);
319     assert(GetTableByteIndex(reinterpret_cast<uint8_t *>(address) + (writeByteSize - 1)) == tableByteIndex);
320 
321     uint8_t *tableByteAddress = &GetTable()[tableByteIndex];
322     if (*tableByteAddress == 0)
323     {
324         *tableByteAddress = 0xff;
325     }
326 }
327 
SetDirtyRegion(void * baseAddress,size_t regionByteSize)328 inline void SoftwareWriteWatch::SetDirtyRegion(void *baseAddress, size_t regionByteSize)
329 {
330     VerifyCreated();
331     VerifyMemoryRegion(baseAddress, regionByteSize);
332 
333     uint8_t *tableBaseAddress;
334     size_t tableRegionByteSize;
335     TranslateToTableRegion(baseAddress, regionByteSize, &tableBaseAddress, &tableRegionByteSize);
336     memset(tableBaseAddress, ~0, tableRegionByteSize);
337 }
338 
339 #endif // !DACCESS_COMPILE
340 #endif // FEATURE_USE_SOFTWARE_WRITE_WATCH_FOR_GC_HEAP
341 #endif // !__SOFTWARE_WRITE_WATCH_H__
342