1 /**************************************************************************
2  *
3  * Copyright 2016-2018 VMware, Inc.
4  * Copyright 2011-2012 Jose Fonseca
5  * All Rights Reserved.
6  *
7  * Permission is hereby granted, free of charge, to any person obtaining a copy
8  * of this software and associated documentation files (the "Software"), to deal
9  * in the Software without restriction, including without limitation the rights
10  * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11  * copies of the Software, and to permit persons to whom the Software is
12  * furnished to do so, subject to the following conditions:
13  *
14  * The above copyright notice and this permission notice shall be included in
15  * all copies or substantial portions of the Software.
16  *
17  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
23  * THE SOFTWARE.
24  *
25  **************************************************************************/
26 
27 
28 /*
29  * Code for the DLL that will be injected in the target process.
30  *
31  * The injected DLL will manipulate the import tables to hook the
32  * modules/functions of interest.
33  *
34  * See also:
35  * - http://www.codeproject.com/KB/system/api_spying_hack.aspx
36  * - http://www.codeproject.com/KB/threads/APIHooking.aspx
37  * - http://msdn.microsoft.com/en-us/magazine/cc301808.aspx
38  */
39 
40 
41 #include <assert.h>
42 #include <stdio.h>
43 #include <stdarg.h>
44 #include <string.h>
45 
46 #include <algorithm>
47 #include <set>
48 #include <map>
49 #include <functional>
50 #include <iterator>
51 
52 #include <windows.h>
53 #include <tlhelp32.h>
54 #include <delayimp.h>
55 
56 #include "inject.h"
57 
58 
/*
 * Diagnostic verbosity.  Code in this file gates extra debugPrintf() output
 * on VERBOSITY > 0, >= 2 and >= 4; warnings/errors are printed regardless.
 */
static int VERBOSITY = 0;
/* NOTE(review): NOOP is not referenced in the visible part of this file. */
#define NOOP 0


/*
 * Serializes IAT writes in replaceAddress() and insertions into
 * g_hHookedModules in patchModule().  Presumably initialized at DLL attach
 * time (not visible here) — TODO confirm.
 */
static CRITICAL_SECTION g_Mutex;



/*
 * Handle of this injected DLL: never patched by patchModule(), and its path is
 * what MyCreateProcessCommon() injects into child processes.
 */
static HMODULE g_hThisModule = NULL;
/* Handle of the hook (wrapper) module: never patched by patchModule(). */
static HMODULE g_hHookModule = NULL;
69 
70 
71 static void
debugPrintf(const char * format,...)72 debugPrintf(const char *format, ...)
73 {
74     char buf[512];
75 
76     va_list ap;
77     va_start(ap, format);
78     _vsnprintf(buf, sizeof buf, format, ap);
79     va_end(ap);
80 
81     OutputDebugStringA(buf);
82 }
83 
84 
85 EXTERN_C void
_assert(const char * _Message,const char * _File,unsigned _Line)86 _assert(const char *_Message, const char *_File, unsigned _Line)
87 {
88     debugPrintf("Assertion failed: %s, file %s, line %u\n", _Message, _File, _Line);
89     TerminateProcess(GetCurrentProcess(), 1);
90 }
91 
92 
93 EXTERN_C void
_wassert(const wchar_t * _Message,const wchar_t * _File,unsigned _Line)94 _wassert(const wchar_t * _Message, const wchar_t *_File, unsigned _Line)
95 {
96     debugPrintf("Assertion failed: %S, file %S, line %u\n", _Message, _File, _Line);
97     TerminateProcess(GetCurrentProcess(), 1);
98 }
99 
100 
/*
 * Forward declarations of the loader-related hook functions; their
 * definitions appear later in this file.
 */
static HMODULE WINAPI
MyLoadLibraryA(LPCSTR lpLibFileName);

static HMODULE WINAPI
MyLoadLibraryW(LPCWSTR lpLibFileName);

static HMODULE WINAPI
MyLoadLibraryExA(LPCSTR lpFileName, HANDLE hFile, DWORD dwFlags);

static HMODULE WINAPI
MyLoadLibraryExW(LPCWSTR lpFileName, HANDLE hFile, DWORD dwFlags);

static FARPROC WINAPI
MyGetProcAddress(HMODULE hModule, LPCSTR lpProcName);
115 
116 
117 static void
MyCreateProcessCommon(BOOL bRet,DWORD dwCreationFlags,LPPROCESS_INFORMATION lpProcessInformation)118 MyCreateProcessCommon(BOOL bRet,
119                       DWORD dwCreationFlags,
120                       LPPROCESS_INFORMATION lpProcessInformation)
121 {
122     if (!bRet) {
123         debugPrintf("inject: warning: failed to create child process\n");
124         return;
125     }
126 
127     DWORD dwLastError = GetLastError();
128 
129     if (isDifferentArch(lpProcessInformation->hProcess)) {
130         debugPrintf("inject: error: child process %lu has different architecture\n",
131                     GetProcessId(lpProcessInformation->hProcess));
132     } else {
133         char szDllPath[MAX_PATH];
134         GetModuleFileNameA(g_hThisModule, szDllPath, sizeof szDllPath);
135 
136         if (!injectDll(lpProcessInformation->hProcess, szDllPath)) {
137             debugPrintf("inject: warning: failed to inject into child process %lu\n",
138                         GetProcessId(lpProcessInformation->hProcess));
139         }
140     }
141 
142     if (!(dwCreationFlags & CREATE_SUSPENDED)) {
143         ResumeThread(lpProcessInformation->hThread);
144     }
145 
146     SetLastError(dwLastError);
147 }
148 
149 static BOOL WINAPI
MyCreateProcessA(LPCSTR lpApplicationName,LPSTR lpCommandLine,LPSECURITY_ATTRIBUTES lpProcessAttributes,LPSECURITY_ATTRIBUTES lpThreadAttributes,BOOL bInheritHandles,DWORD dwCreationFlags,LPVOID lpEnvironment,LPCSTR lpCurrentDirectory,LPSTARTUPINFOA lpStartupInfo,LPPROCESS_INFORMATION lpProcessInformation)150 MyCreateProcessA(LPCSTR lpApplicationName,
151                  LPSTR lpCommandLine,
152                  LPSECURITY_ATTRIBUTES lpProcessAttributes,
153                  LPSECURITY_ATTRIBUTES lpThreadAttributes,
154                  BOOL bInheritHandles,
155                  DWORD dwCreationFlags,
156                  LPVOID lpEnvironment,
157                  LPCSTR lpCurrentDirectory,
158                  LPSTARTUPINFOA lpStartupInfo,
159                  LPPROCESS_INFORMATION lpProcessInformation)
160 {
161     if (VERBOSITY >= 2) {
162         debugPrintf("inject: intercepting %s(\"%s\", \"%s\", ...)\n",
163                     __FUNCTION__,
164                     lpApplicationName,
165                     lpCommandLine);
166     }
167 
168     BOOL bRet;
169     bRet = CreateProcessA(lpApplicationName,
170                           lpCommandLine,
171                           lpProcessAttributes,
172                           lpThreadAttributes,
173                           bInheritHandles,
174                           dwCreationFlags | CREATE_SUSPENDED,
175                           lpEnvironment,
176                           lpCurrentDirectory,
177                           lpStartupInfo,
178                           lpProcessInformation);
179 
180     MyCreateProcessCommon(bRet, dwCreationFlags, lpProcessInformation);
181 
182     return bRet;
183 }
184 
185 static BOOL WINAPI
MyCreateProcessW(LPCWSTR lpApplicationName,LPWSTR lpCommandLine,LPSECURITY_ATTRIBUTES lpProcessAttributes,LPSECURITY_ATTRIBUTES lpThreadAttributes,BOOL bInheritHandles,DWORD dwCreationFlags,LPVOID lpEnvironment,LPCWSTR lpCurrentDirectory,LPSTARTUPINFOW lpStartupInfo,LPPROCESS_INFORMATION lpProcessInformation)186 MyCreateProcessW(LPCWSTR lpApplicationName,
187                  LPWSTR lpCommandLine,
188                  LPSECURITY_ATTRIBUTES lpProcessAttributes,
189                  LPSECURITY_ATTRIBUTES lpThreadAttributes,
190                  BOOL bInheritHandles,
191                  DWORD dwCreationFlags,
192                  LPVOID lpEnvironment,
193                  LPCWSTR lpCurrentDirectory,
194                  LPSTARTUPINFOW lpStartupInfo,
195                  LPPROCESS_INFORMATION lpProcessInformation)
196 {
197     if (VERBOSITY >= 2) {
198         debugPrintf("inject: intercepting %s(\"%S\", \"%S\", ...)\n",
199                     __FUNCTION__,
200                     lpApplicationName,
201                     lpCommandLine);
202     }
203 
204     BOOL bRet;
205     bRet = CreateProcessW(lpApplicationName,
206                           lpCommandLine,
207                           lpProcessAttributes,
208                           lpThreadAttributes,
209                           bInheritHandles,
210                           dwCreationFlags | CREATE_SUSPENDED,
211                           lpEnvironment,
212                           lpCurrentDirectory,
213                           lpStartupInfo,
214                           lpProcessInformation);
215 
216     MyCreateProcessCommon(bRet, dwCreationFlags, lpProcessInformation);
217 
218     return bRet;
219 }
220 
/*
 * Pointer type matching the CreateProcessAsUserW signature.
 *
 * The real function is called through pfnCreateProcessAsUserW rather than
 * linked directly; the hook notes that some WINE versions (at least 1.6.2)
 * don't export kernel32.dll!CreateProcessAsUserW.  Where the pointer is
 * assigned is not visible here — presumably at DLL attach time; TODO confirm.
 */
typedef BOOL
(WINAPI *PFNCREATEPROCESSASUSERW) (HANDLE, LPCWSTR, LPWSTR,
        LPSECURITY_ATTRIBUTES, LPSECURITY_ATTRIBUTES, BOOL, DWORD, LPVOID,
        LPCWSTR, LPSTARTUPINFOW, LPPROCESS_INFORMATION);

/* Real CreateProcessAsUserW; may be NULL where it is not exported. */
static PFNCREATEPROCESSASUSERW pfnCreateProcessAsUserW;
227 
228 static BOOL WINAPI
MyCreateProcessAsUserW(HANDLE hToken,LPCWSTR lpApplicationName,LPWSTR lpCommandLine,LPSECURITY_ATTRIBUTES lpProcessAttributes,LPSECURITY_ATTRIBUTES lpThreadAttributes,BOOL bInheritHandles,DWORD dwCreationFlags,LPVOID lpEnvironment,LPCWSTR lpCurrentDirectory,LPSTARTUPINFOW lpStartupInfo,LPPROCESS_INFORMATION lpProcessInformation)229 MyCreateProcessAsUserW(HANDLE hToken,
230                        LPCWSTR lpApplicationName,
231                        LPWSTR lpCommandLine,
232                        LPSECURITY_ATTRIBUTES lpProcessAttributes,
233                        LPSECURITY_ATTRIBUTES lpThreadAttributes,
234                        BOOL bInheritHandles,
235                        DWORD dwCreationFlags,
236                        LPVOID lpEnvironment,
237                        LPCWSTR lpCurrentDirectory,
238                        LPSTARTUPINFOW lpStartupInfo,
239                        LPPROCESS_INFORMATION lpProcessInformation)
240 {
241     if (VERBOSITY >= 2) {
242         debugPrintf("inject: intercepting %s(\"%S\", \"%S\", ...)\n",
243                     __FUNCTION__,
244                     lpApplicationName,
245                     lpCommandLine);
246     }
247 
248     // Certain WINE versions (at least 1.6.2) don't export
249     // kernel32.dll!CreateProcessAsUserW
250     assert(pfnCreateProcessAsUserW);
251 
252     BOOL bRet;
253     bRet = pfnCreateProcessAsUserW(hToken,
254                                    lpApplicationName,
255                                    lpCommandLine,
256                                    lpProcessAttributes,
257                                    lpThreadAttributes,
258                                    bInheritHandles,
259                                    dwCreationFlags,
260                                    lpEnvironment,
261                                    lpCurrentDirectory,
262                                    lpStartupInfo,
263                                    lpProcessInformation);
264 
265     MyCreateProcessCommon(bRet, dwCreationFlags, lpProcessInformation);
266 
267     return bRet;
268 }
269 
270 
271 template< class T, class I >
272 inline T *
rvaToVa(HMODULE hModule,I rva)273 rvaToVa(HMODULE hModule, I rva)
274 {
275     assert(rva != 0);
276     return reinterpret_cast<T *>(reinterpret_cast<PBYTE>(hModule) + rva);
277 }
278 
279 
280 static const char *
getDescriptorName(HMODULE hModule,const PIMAGE_IMPORT_DESCRIPTOR pImportDescriptor)281 getDescriptorName(HMODULE hModule,
282                         const PIMAGE_IMPORT_DESCRIPTOR pImportDescriptor)
283 {
284     return rvaToVa<const char>(hModule, pImportDescriptor->Name);
285 }
286 
287 static const char *
getDescriptorName(HMODULE hModule,const PImgDelayDescr pDelayDescriptor)288 getDescriptorName(HMODULE hModule,
289                   const PImgDelayDescr pDelayDescriptor)
290 {
291     if (pDelayDescriptor->grAttrs & dlattrRva) {
292         return rvaToVa<const char>(hModule, pDelayDescriptor->rvaDLLName);
293     } else {
294 #ifdef _WIN64
295         assert(pDelayDescriptor->grAttrs & dlattrRva);
296         return "???";
297 #else
298         // old-style, with ImgDelayDescr::szName being a LPCSTR
299         return reinterpret_cast<LPCSTR>(pDelayDescriptor->rvaDLLName);
300 #endif
301     }
302 }
303 
304 
305 static PIMAGE_OPTIONAL_HEADER
getOptionalHeader(HMODULE hModule,const char * szModule)306 getOptionalHeader(HMODULE hModule,
307                   const char *szModule)
308 {
309     PIMAGE_DOS_HEADER pDosHeader = reinterpret_cast<PIMAGE_DOS_HEADER>(hModule);
310     if (pDosHeader->e_magic != IMAGE_DOS_SIGNATURE) {
311         debugPrintf("inject: warning: %s: unexpected DOS header magic (0x%04x)\n",
312                     szModule, pDosHeader->e_magic);
313         return NULL;
314     }
315     PIMAGE_NT_HEADERS pNtHeaders = rvaToVa<IMAGE_NT_HEADERS>(hModule, pDosHeader->e_lfanew);
316     if (pNtHeaders->Signature != IMAGE_NT_SIGNATURE) {
317         debugPrintf("inject: warning: %s: unexpected NT header signature (0x%08lx)\n",
318                     szModule, pNtHeaders->Signature);
319         return NULL;
320     }
321 
322     /*
323      * Handle gracefully DLL might have been loaded for resources,
324      * LOAD_LIBRARY_AS_DATAFILE.
325      */
326     const WORD Machine =
327 #ifdef _WIN64
328         IMAGE_FILE_MACHINE_AMD64
329 #else
330         IMAGE_FILE_MACHINE_I386
331 #endif
332     ;
333     if (pNtHeaders->FileHeader.Machine != Machine) {
334         debugPrintf("inject: warning: %s: ignoring different machine (0x%02x)\n",
335                     szModule, pNtHeaders->FileHeader.Machine);
336         return nullptr;
337     }
338     if (pNtHeaders->FileHeader.SizeOfOptionalHeader < sizeof pNtHeaders->OptionalHeader) {
339         debugPrintf("inject: warning: %s: SizeOfOptionalHeader too small (%u)\n",
340                     szModule, pNtHeaders->FileHeader.SizeOfOptionalHeader);
341         return nullptr;
342     }
343 
344     PIMAGE_OPTIONAL_HEADER pOptionalHeader = &pNtHeaders->OptionalHeader;
345     return pOptionalHeader;
346 }
347 
348 template< typename T >
349 static bool
getImageDirectoryEntry(HMODULE hModule,const char * szModule,UINT Entry,T ** ppEntry)350 getImageDirectoryEntry(HMODULE hModule,
351                        const char *szModule,
352                        UINT Entry,
353                        T ** ppEntry)
354 {
355     MEMORY_BASIC_INFORMATION MemoryInfo;
356     if (VirtualQuery(hModule, &MemoryInfo, sizeof MemoryInfo) != sizeof MemoryInfo) {
357         debugPrintf("inject: warning: %s: VirtualQuery failed\n", szModule);
358         return false;
359     }
360     if (MemoryInfo.Protect & (PAGE_NOACCESS | PAGE_EXECUTE)) {
361         debugPrintf("inject: warning: %s: no read access (Protect = 0x%08lx)\n", szModule, MemoryInfo.Protect);
362         return false;
363     }
364 
365     PIMAGE_OPTIONAL_HEADER pOptionalHeader = getOptionalHeader(hModule, szModule);
366     if (!pOptionalHeader) {
367         return false;
368     }
369 
370     assert(pOptionalHeader->NumberOfRvaAndSizes > Entry);
371     if (pOptionalHeader->DataDirectory[Entry].Size == 0) {
372         return false;
373     }
374 
375     assert(pOptionalHeader->DataDirectory[Entry].Size >= sizeof(T));
376     UINT_PTR ImportAddress = pOptionalHeader->DataDirectory[Entry].VirtualAddress;
377     if (!ImportAddress) {
378         return false;
379     }
380 
381     *ppEntry = rvaToVa<T>(hModule, ImportAddress);
382     return true;
383 }
384 
385 
386 static PIMAGE_IMPORT_DESCRIPTOR
getFirstImportDescriptor(HMODULE hModule,const char * szModule)387 getFirstImportDescriptor(HMODULE hModule, const char *szModule)
388 {
389     PIMAGE_IMPORT_DESCRIPTOR pEntry = nullptr;
390     return getImageDirectoryEntry(hModule, szModule, IMAGE_DIRECTORY_ENTRY_IMPORT, &pEntry) ? pEntry : nullptr;
391 }
392 
393 
394 static PImgDelayDescr
getDelayImportDescriptor(HMODULE hModule,const char * szModule)395 getDelayImportDescriptor(HMODULE hModule, const char *szModule)
396 {
397     PImgDelayDescr pEntry = nullptr;
398     return getImageDirectoryEntry(hModule, szModule, IMAGE_DIRECTORY_ENTRY_DELAY_IMPORT, &pEntry) ? pEntry : nullptr;
399 }
400 
401 
402 static PIMAGE_EXPORT_DIRECTORY
getExportDescriptor(HMODULE hModule)403 getExportDescriptor(HMODULE hModule)
404 {
405     PIMAGE_EXPORT_DIRECTORY pEntry = nullptr;
406     return getImageDirectoryEntry(hModule, "(wrapper)", IMAGE_DIRECTORY_ENTRY_EXPORT, &pEntry) ? pEntry : nullptr;
407 }
408 
409 
410 static BOOL
replaceAddress(LPVOID * lpOldAddress,LPVOID lpNewAddress)411 replaceAddress(LPVOID *lpOldAddress, LPVOID lpNewAddress)
412 {
413     DWORD flOldProtect;
414 
415     if (*lpOldAddress == lpNewAddress) {
416         return TRUE;
417     }
418 
419     EnterCriticalSection(&g_Mutex);
420 
421     if (!(VirtualProtect(lpOldAddress, sizeof *lpOldAddress, PAGE_READWRITE, &flOldProtect))) {
422         LeaveCriticalSection(&g_Mutex);
423         return FALSE;
424     }
425 
426     *lpOldAddress = lpNewAddress;
427 
428     if (!(VirtualProtect(lpOldAddress, sizeof *lpOldAddress, flOldProtect, &flOldProtect))) {
429         LeaveCriticalSection(&g_Mutex);
430         return FALSE;
431     }
432 
433     LeaveCriticalSection(&g_Mutex);
434     return TRUE;
435 }
436 
437 
438 /* Return pointer to patcheable function address.
439  *
440  * See also:
441  *
442  * - An In-Depth Look into the Win32 Portable Executable File Format, Part 2, Matt Pietrek,
443  *   http://msdn.microsoft.com/en-gb/magazine/cc301808.aspx
444  *
445  * - http://www.microsoft.com/msj/1298/hood/hood1298.aspx
446  *
447  */
448 static LPVOID *
getPatchAddress(HMODULE hModule,const char * szDescriptorName,DWORD OriginalFirstThunk,DWORD FirstThunk,const char * pszFunctionName,LPVOID lpOldAddress)449 getPatchAddress(HMODULE hModule,
450                 const char *szDescriptorName,
451                 DWORD OriginalFirstThunk,
452                 DWORD FirstThunk,
453                 const char* pszFunctionName,
454                 LPVOID lpOldAddress)
455 {
456     if (VERBOSITY >= 4) {
457         debugPrintf("inject: %s(%s, %s)\n", __FUNCTION__,
458                     szDescriptorName,
459                     pszFunctionName);
460     }
461 
462     PIMAGE_THUNK_DATA pThunkIAT = rvaToVa<IMAGE_THUNK_DATA>(hModule, FirstThunk);
463 
464     UINT_PTR pOldFunction = (UINT_PTR)lpOldAddress;
465 
466     PIMAGE_THUNK_DATA pThunk;
467     if (OriginalFirstThunk) {
468         pThunk = rvaToVa<IMAGE_THUNK_DATA>(hModule, OriginalFirstThunk);
469     } else {
470         pThunk = pThunkIAT;
471     }
472 
473     while (pThunk->u1.Function) {
474         if (OriginalFirstThunk == 0 ||
475             pThunk->u1.Ordinal & IMAGE_ORDINAL_FLAG) {
476             // No name -- search by the real function address
477             if (!pOldFunction) {
478                 return NULL;
479             }
480             if (pThunkIAT->u1.Function == pOldFunction) {
481                 return (LPVOID *)(&pThunkIAT->u1.Function);
482             }
483         } else {
484             // Search by name
485             PIMAGE_IMPORT_BY_NAME pImport = rvaToVa<IMAGE_IMPORT_BY_NAME>(hModule, pThunk->u1.AddressOfData);
486             const char* szName = (const char* )pImport->Name;
487             if (strcmp(pszFunctionName, szName) == 0) {
488                 return (LPVOID *)(&pThunkIAT->u1.Function);
489             }
490         }
491         ++pThunk;
492         ++pThunkIAT;
493     }
494 
495     return NULL;
496 }
497 
498 
499 static LPVOID *
getPatchAddress(HMODULE hModule,PIMAGE_IMPORT_DESCRIPTOR pImportDescriptor,const char * pszFunctionName,LPVOID lpOldAddress)500 getPatchAddress(HMODULE hModule,
501                 PIMAGE_IMPORT_DESCRIPTOR pImportDescriptor,
502                 const char* pszFunctionName,
503                 LPVOID lpOldAddress)
504 {
505     assert(pImportDescriptor->TimeDateStamp != 0 || pImportDescriptor->Name != 0);
506 
507     return getPatchAddress(hModule,
508                            getDescriptorName(hModule, pImportDescriptor),
509                            pImportDescriptor->OriginalFirstThunk,
510                            pImportDescriptor->FirstThunk,
511                            pszFunctionName,
512                            lpOldAddress);
513 }
514 
515 
516 // See
517 // http://www.microsoft.com/msj/1298/hood/hood1298.aspx
518 // http://msdn.microsoft.com/en-us/library/16b2dyk5.aspx
519 static LPVOID *
getPatchAddress(HMODULE hModule,PImgDelayDescr pDelayDescriptor,const char * pszFunctionName,LPVOID lpOldAddress)520 getPatchAddress(HMODULE hModule,
521                 PImgDelayDescr pDelayDescriptor,
522                 const char* pszFunctionName,
523                 LPVOID lpOldAddress)
524 {
525     assert(pDelayDescriptor->rvaDLLName != 0);
526 
527     return getPatchAddress(hModule,
528                            getDescriptorName(hModule, pDelayDescriptor),
529                            pDelayDescriptor->rvaINT,
530                            pDelayDescriptor->rvaIAT,
531                            pszFunctionName,
532                            lpOldAddress);
533 }
534 
535 
536 template< class T >
537 static BOOL
patchFunction(HMODULE hModule,const char * szModule,const char * pszDllName,T pImportDescriptor,const char * pszFunctionName,LPVOID lpOldAddress,LPVOID lpNewAddress)538 patchFunction(HMODULE hModule,
539               const char *szModule,
540               const char *pszDllName,
541               T pImportDescriptor,
542               const char *pszFunctionName,
543               LPVOID lpOldAddress,
544               LPVOID lpNewAddress)
545 {
546     LPVOID* lpPatchAddress = getPatchAddress(hModule, pImportDescriptor, pszFunctionName, lpOldAddress);
547     if (lpPatchAddress == NULL) {
548         return FALSE;
549     }
550 
551     if (*lpPatchAddress == lpNewAddress) {
552         return TRUE;
553     }
554 
555     DWORD Offset = (DWORD)(UINT_PTR)lpPatchAddress - (UINT_PTR)hModule;
556     if (VERBOSITY > 0) {
557         debugPrintf("inject: patching %s!0x%lx -> %s!%s\n", szModule, Offset, pszDllName, pszFunctionName);
558     }
559 
560     BOOL bRet;
561     bRet = replaceAddress(lpPatchAddress, lpNewAddress);
562     if (!bRet) {
563         debugPrintf("inject: failed to patch %s!0x%lx -> %s!%s\n", szModule, Offset, pszDllName, pszFunctionName);
564     }
565 
566     return bRet;
567 }
568 
569 
570 
/*
 * Strict-weak ordering of NUL-terminated C strings by content (strcmp), so
 * std::map keys compare by string value rather than by pointer.
 *
 * Not derived from std::binary_function: that helper was deprecated in C++11
 * and removed in C++17, and std::map does not need its typedefs.
 */
struct StrCompare {
    bool operator() (const char * s1, const char * s2) const {
        return strcmp(s1, s2) < 0;
    }
};
576 
577 typedef std::map<const char *, LPVOID, StrCompare> FunctionMap;
578 
579 struct StrICompare : public std::binary_function<const char *, const char *, bool> {
operator ()StrICompare580     bool operator() (const char * s1, const char * s2) const {
581         return stricmp(s1, s2) < 0;
582     }
583 };
584 
585 struct Module {
586     bool bInternal;
587     FunctionMap functionMap;
588 };
589 
590 typedef std::map<const char *, Module, StrICompare> ModulesMap;
591 
592 /* This is only modified at DLL_PROCESS_ATTACH time. */
593 static ModulesMap modulesMap;
594 
595 
596 static inline bool
isMatchModuleName(const char * szModuleName)597 isMatchModuleName(const char *szModuleName)
598 {
599     ModulesMap::const_iterator modIt = modulesMap.find(szModuleName);
600     return modIt != modulesMap.end();
601 }
602 
603 
/*
 * Set of previously hooked modules, so patchModule() hooks each module only
 * once.  Insertions are guarded by g_Mutex.
 */
static std::set<HMODULE>
g_hHookedModules;


/* What patchModule()/patchDescriptor() should do to matching IAT slots. */
enum Action {
    ACTION_HOOK,    // replace real function addresses with hook addresses
    ACTION_UNHOOK,  // restore real function addresses

};
614 
615 
616 template< class T >
617 void
patchDescriptor(HMODULE hModule,const char * szModule,T pImportDescriptor,Action action)618 patchDescriptor(HMODULE hModule,
619                 const char *szModule,
620                 T pImportDescriptor,
621                 Action action)
622 {
623     const char* szDescriptorName = getDescriptorName(hModule, pImportDescriptor);
624 
625     ModulesMap::const_iterator modIt = modulesMap.find(szDescriptorName);
626     if (modIt != modulesMap.end()) {
627         const char *szMatchModule = modIt->first; // same as szDescriptorName
628         const Module & module = modIt->second;
629 
630         const FunctionMap & functionMap = module.functionMap;
631         FunctionMap::const_iterator fnIt;
632         for (fnIt = functionMap.begin(); fnIt != functionMap.end(); ++fnIt) {
633             const char *szFunctionName = fnIt->first;
634             LPVOID lpHookAddress = fnIt->second;
635 
636             // Knowning the real address is useful when patching imports by ordinal
637             LPVOID lpRealAddress = NULL;
638             HMODULE hRealModule = GetModuleHandleA(szDescriptorName);
639             if (hRealModule) {
640                 // FIXME: this assertion can fail when the wrapper name is the same as the original DLL
641                 //assert(hRealModule != g_hHookModule);
642                 if (hRealModule != g_hHookModule) {
643                     lpRealAddress = (LPVOID)GetProcAddress(hRealModule, szFunctionName);
644                 }
645             }
646 
647             LPVOID lpOldAddress = lpRealAddress;
648             LPVOID lpNewAddress = lpHookAddress;
649 
650             if (action == ACTION_UNHOOK) {
651                 std::swap(lpOldAddress, lpNewAddress);
652             }
653 
654             BOOL bPatched;
655             bPatched = patchFunction(hModule, szModule, szMatchModule, pImportDescriptor, szFunctionName, lpOldAddress, lpNewAddress);
656             if (action == ACTION_HOOK && bPatched && !module.bInternal && pSharedMem) {
657                 pSharedMem->bReplaced = TRUE;
658             }
659         }
660     }
661 }
662 
663 
664 static void
patchModule(HMODULE hModule,const char * szModule,Action action)665 patchModule(HMODULE hModule,
666             const char *szModule,
667             Action action)
668 {
669     /* Never patch this module */
670     if (hModule == g_hThisModule) {
671         return;
672     }
673 
674     /* Never patch our hook module */
675     if (hModule == g_hHookModule) {
676         return;
677     }
678 
679     /* Hook modules only once */
680     if (action == ACTION_HOOK) {
681         std::pair< std::set<HMODULE>::iterator, bool > ret;
682         EnterCriticalSection(&g_Mutex);
683         ret = g_hHookedModules.insert(hModule);
684         LeaveCriticalSection(&g_Mutex);
685         if (!ret.second) {
686             return;
687         }
688     }
689 
690     const char *szBaseName = getBaseName(szModule);
691 
692     /* Don't hook our replacement modules to avoid tracing internal APIs */
693     /* XXX: is this really a good idea? */
694     if (isMatchModuleName(szBaseName)) {
695         return;
696     }
697 
698     /* Leave these modules alone.
699      *
700      * Hooking other injection DLLs easily leads to infinite recursion (and
701      * stack overflow), especially when those libraries use techniques like
702      * modifying the hooked functions prolog (instead of patching IAT like we
703      * do).
704      *
705      * See also:
706      * - http://www.nynaeve.net/?p=62
707      */
708     if (stricmp(szBaseName, "kernel32.dll") == 0 ||
709         stricmp(szBaseName, "AcLayers.dll") == 0 ||
710         stricmp(szBaseName, "ConEmuHk.dll") == 0 ||
711         stricmp(szBaseName, "gameoverlayrenderer.dll") == 0 ||
712         stricmp(szBaseName, "gameoverlayrenderer64.dll") == 0) {
713         return;
714     }
715 
716     if (VERBOSITY > 0) {
717         debugPrintf("inject: found module %s\n", szModule);
718     }
719 
720     PIMAGE_IMPORT_DESCRIPTOR pImportDescriptor = getFirstImportDescriptor(hModule, szModule);
721     if (pImportDescriptor) {
722         while (pImportDescriptor->FirstThunk) {
723 
724             patchDescriptor(hModule, szModule, pImportDescriptor, action);
725 
726             ++pImportDescriptor;
727         }
728     }
729 
730     PImgDelayDescr pDelayDescriptor = getDelayImportDescriptor(hModule, szModule);
731     if (pDelayDescriptor) {
732         while (pDelayDescriptor->rvaDLLName) {
733             if (pDelayDescriptor->grAttrs > 1) {
734                 debugPrintf("inject: warning: ignoring delay import section (grAttrs = 0x%08lx)\n",
735                             pDelayDescriptor->grAttrs);
736                 break;
737             }
738 
739             if (VERBOSITY > 1) {
740                 const char* szName = getDescriptorName(hModule, pDelayDescriptor);
741                 debugPrintf("inject: found %sdelay-load import entry for module %s\n",
742                             pDelayDescriptor->grAttrs & dlattrRva ? "" : "old-style ",
743                             szName);
744             }
745 
746 #ifdef _WIN64
747             assert(pDelayDescriptor->grAttrs & dlattrRva);
748 #endif
749 
750             patchDescriptor(hModule, szModule, pDelayDescriptor, action);
751 
752             ++pDelayDescriptor;
753         }
754     }
755 }
756 
757 
758 static void
patchAllModules(Action action)759 patchAllModules(Action action)
760 {
761     HANDLE hModuleSnap = CreateToolhelp32Snapshot(TH32CS_SNAPMODULE, GetCurrentProcessId());
762     if (hModuleSnap == INVALID_HANDLE_VALUE) {
763         return;
764     }
765 
766     MODULEENTRY32 me32;
767     me32.dwSize = sizeof me32;
768     if (Module32First(hModuleSnap, &me32)) {
769         do  {
770             patchModule(me32.hModule, me32.szExePath, action);
771         } while (Module32Next(hModuleSnap, &me32));
772     }
773 
774     CloseHandle(hModuleSnap);
775 }
776 
777 
778 static HMODULE WINAPI
MyLoadLibraryA(LPCSTR lpLibFileName)779 MyLoadLibraryA(LPCSTR lpLibFileName)
780 {
781     HMODULE hModule = LoadLibraryA(lpLibFileName);
782     DWORD dwLastError = GetLastError();
783 
784     if (VERBOSITY >= 2) {
785         debugPrintf("inject: intercepting %s(\"%s\") = 0x%p\n",
786                     __FUNCTION__ + 2, lpLibFileName, hModule);
787     }
788 
789     if (VERBOSITY > 0) {
790         const char *szBaseName = getBaseName(lpLibFileName);
791         if (isMatchModuleName(szBaseName)) {
792             if (VERBOSITY < 2) {
793                 debugPrintf("inject: intercepting %s(\"%s\")\n", __FUNCTION__, lpLibFileName);
794             }
795 #ifdef __GNUC__
796             void *caller = __builtin_return_address (0);
797 
798             HMODULE hModule = 0;
799             BOOL bRet = GetModuleHandleEx(GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS |
800                                           GET_MODULE_HANDLE_EX_FLAG_UNCHANGED_REFCOUNT,
801                                           (LPCTSTR)caller,
802                                           &hModule);
803             assert(bRet);
804             char szCaller[MAX_PATH];
805             DWORD dwRet = GetModuleFileNameA(hModule, szCaller, sizeof szCaller);
806             assert(dwRet);
807             debugPrintf("inject: called from %s\n", szCaller);
808 #endif
809         }
810     }
811 
812     // Hook all new modules (and not just this one, to pick up any dependencies)
813     patchAllModules(ACTION_HOOK);
814 
815     SetLastError(dwLastError);
816     return hModule;
817 }
818 
819 static HMODULE WINAPI
MyLoadLibraryW(LPCWSTR lpLibFileName)820 MyLoadLibraryW(LPCWSTR lpLibFileName)
821 {
822     HMODULE hModule = LoadLibraryW(lpLibFileName);
823     DWORD dwLastError = GetLastError();
824 
825     if (VERBOSITY >= 2) {
826         debugPrintf("inject: intercepting %s(L\"%S\") = 0x%p\n",
827                     __FUNCTION__ + 2, lpLibFileName, hModule);
828     }
829 
830     // Hook all new modules (and not just this one, to pick up any dependencies)
831     patchAllModules(ACTION_HOOK);
832 
833     SetLastError(dwLastError);
834     return hModule;
835 }
836 
/*
 * LoadLibraryEx LOAD_LIBRARY_SEARCH_* flags used by adjustFlags(); defined
 * here only when the Windows headers in use don't provide them.
 */
#ifndef LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR
#define LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR    0x00000100
#endif
#ifndef LOAD_LIBRARY_SEARCH_APPLICATION_DIR
#define LOAD_LIBRARY_SEARCH_APPLICATION_DIR 0x00000200
#endif
#ifndef LOAD_LIBRARY_SEARCH_USER_DIRS
#define LOAD_LIBRARY_SEARCH_USER_DIRS       0x00000400
#endif
#ifndef LOAD_LIBRARY_SEARCH_SYSTEM32
#define LOAD_LIBRARY_SEARCH_SYSTEM32        0x00000800
#endif
#ifndef LOAD_LIBRARY_SEARCH_DEFAULT_DIRS
#define LOAD_LIBRARY_SEARCH_DEFAULT_DIRS    0x00001000
#endif
852 
853 static inline DWORD
adjustFlags(DWORD dwFlags)854 adjustFlags(DWORD dwFlags)
855 {
856     /*
857      * XXX: LoadLibraryEx seems to interpret "application directory" in respect
858      * to the module that's calling it.  So when the application restricts the
859      * search path to application directory via
860      * LOAD_LIBRARY_SEARCH_APPLICATION_DIR or LOAD_LIBRARY_SEARCH_DEFAULT_DIRS
861      * flags, kernel32.dll ends up searching on the directory of the inject.dll
862      * module.
863      *
864      * XXX: What about SetDefaultDllDirectories?
865      *
866      */
867     if (dwFlags & (LOAD_LIBRARY_SEARCH_APPLICATION_DIR |
868                    LOAD_LIBRARY_SEARCH_DEFAULT_DIRS)) {
869         dwFlags &= ~(LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR |
870                      LOAD_LIBRARY_SEARCH_APPLICATION_DIR |
871                      LOAD_LIBRARY_SEARCH_USER_DIRS |
872                      LOAD_LIBRARY_SEARCH_SYSTEM32 |
873                      LOAD_LIBRARY_SEARCH_DEFAULT_DIRS);
874     }
875 
876     return dwFlags;
877 }
878 
879 static HMODULE WINAPI
MyLoadLibraryExA(LPCSTR lpLibFileName,HANDLE hFile,DWORD dwFlags)880 MyLoadLibraryExA(LPCSTR lpLibFileName, HANDLE hFile, DWORD dwFlags)
881 {
882     HMODULE hModule = LoadLibraryExA(lpLibFileName, hFile, adjustFlags(dwFlags));
883     DWORD dwLastError = GetLastError();
884 
885     if (VERBOSITY >= 2) {
886         debugPrintf("inject: intercepting %s(\"%s\", 0x%p, 0x%lx) = 0x%p\n",
887                     __FUNCTION__ + 2, lpLibFileName, hFile, dwFlags, hModule);
888     }
889 
890     // Hook all new modules (and not just this one, to pick up any dependencies)
891     if ((dwFlags & (DONT_RESOLVE_DLL_REFERENCES |
892                     LOAD_LIBRARY_AS_DATAFILE |
893                     LOAD_LIBRARY_AS_DATAFILE_EXCLUSIVE)) == 0) {
894         patchAllModules(ACTION_HOOK);
895     }
896 
897     SetLastError(dwLastError);
898     return hModule;
899 }
900 
901 static HMODULE WINAPI
MyLoadLibraryExW(LPCWSTR lpLibFileName,HANDLE hFile,DWORD dwFlags)902 MyLoadLibraryExW(LPCWSTR lpLibFileName, HANDLE hFile, DWORD dwFlags)
903 {
904     HMODULE hModule = LoadLibraryExW(lpLibFileName, hFile, adjustFlags(dwFlags));
905     DWORD dwLastError = GetLastError();
906 
907     if (VERBOSITY >= 2) {
908         debugPrintf("inject: intercepting %s(L\"%S\", 0x%p, 0x%lx) = 0x%p\n",
909                     __FUNCTION__ + 2, lpLibFileName, hFile, dwFlags, hModule);
910     }
911 
912     // Hook all new modules (and not just this one, to pick up any dependencies)
913     if ((dwFlags & (DONT_RESOLVE_DLL_REFERENCES |
914                     LOAD_LIBRARY_AS_DATAFILE |
915                     LOAD_LIBRARY_AS_DATAFILE_EXCLUSIVE)) == 0) {
916         patchAllModules(ACTION_HOOK);
917     }
918 
919     SetLastError(dwLastError);
920     return hModule;
921 }
922 
923 
924 static void
logGetProcAddress(HMODULE hModule,LPCSTR lpProcName)925 logGetProcAddress(HMODULE hModule, LPCSTR lpProcName) {
926     if (HIWORD(lpProcName) == 0) {
927         debugPrintf("inject: intercepting %s(%u)\n", "GetProcAddress", LOWORD(lpProcName));
928     } else {
929         debugPrintf("inject: intercepting %s(\"%s\")\n", "GetProcAddress", lpProcName);
930     }
931 }
932 
933 static FARPROC WINAPI
MyGetProcAddress(HMODULE hModule,LPCSTR lpProcName)934 MyGetProcAddress(HMODULE hModule, LPCSTR lpProcName) {
935 
936     if (VERBOSITY >= 3) {
937         /* XXX this can cause segmentation faults */
938         logGetProcAddress(hModule, lpProcName);
939     }
940 
941     if (!NOOP) {
942         char szModule[MAX_PATH];
943         DWORD dwRet = GetModuleFileNameA(hModule, szModule, sizeof szModule);
944         assert(dwRet);
945         const char *szBaseName = getBaseName(szModule);
946 
947         ModulesMap::const_iterator modIt;
948         modIt = modulesMap.find(szBaseName);
949         if (modIt != modulesMap.end()) {
950             if (VERBOSITY > 1 && VERBOSITY < 3) {
951                 logGetProcAddress(hModule, lpProcName);
952             }
953 
954             const Module & module = modIt->second;
955             const FunctionMap & functionMap = module.functionMap;
956             FunctionMap::const_iterator fnIt;
957 
958             if (HIWORD(lpProcName) == 0) {
959                 FARPROC proc = GetProcAddress(hModule, lpProcName);
960                 if (!proc) {
961                     return proc;
962                 }
963 
964                 for (fnIt = functionMap.begin(); fnIt != functionMap.end(); ++fnIt) {
965                     FARPROC pRealProc = GetProcAddress(hModule, fnIt->first);
966                     if (proc == pRealProc) {
967                         if (VERBOSITY > 0) {
968                             debugPrintf("inject: replacing %s!%s\n", szBaseName, lpProcName);
969                         }
970                         return (FARPROC)fnIt->second;
971                     }
972 
973                 }
974 
975                 debugPrintf("inject: ignoring %s!@%u\n", szBaseName, LOWORD(lpProcName));
976 
977                 return proc;
978             }
979 
980             fnIt = functionMap.find(lpProcName);
981 
982             if (fnIt != functionMap.end()) {
983                 LPVOID pProcAddress = fnIt->second;
984                 if (VERBOSITY > 0) {
985                     debugPrintf("inject: replacing %s!%s\n", szBaseName, lpProcName);
986                 }
987                 if (!module.bInternal && pSharedMem) {
988                     pSharedMem->bReplaced = TRUE;
989                 }
990                 return (FARPROC)pProcAddress;
991             } else {
992                 if (VERBOSITY > 0 && !module.bInternal) {
993                     debugPrintf("inject: ignoring %s!%s\n", szBaseName, lpProcName);
994                 }
995             }
996         }
997     }
998 
999     return GetProcAddress(hModule, lpProcName);
1000 }
1001 
1002 
1003 static BOOL WINAPI
MyFreeLibrary(HMODULE hModule)1004 MyFreeLibrary(HMODULE hModule)
1005 {
1006     if (VERBOSITY >= 2) {
1007         debugPrintf("inject: intercepting %s(0x%p)\n", __FUNCTION__, hModule);
1008     }
1009 
1010     BOOL bRet = FreeLibrary(hModule);
1011     DWORD dwLastError = GetLastError();
1012 
1013     std::set<HMODULE> hCurrentModules;
1014     HANDLE hModuleSnap = CreateToolhelp32Snapshot(TH32CS_SNAPMODULE, GetCurrentProcessId());
1015     if (hModuleSnap != INVALID_HANDLE_VALUE) {
1016         MODULEENTRY32 me32;
1017         me32.dwSize = sizeof me32;
1018         if (Module32First(hModuleSnap, &me32)) {
1019             do  {
1020                 hCurrentModules.insert(me32.hModule);
1021             } while (Module32Next(hModuleSnap, &me32));
1022         }
1023         CloseHandle(hModuleSnap);
1024     }
1025 
1026     // Clear the modules that have been freed
1027     EnterCriticalSection(&g_Mutex);
1028     std::set<HMODULE> hIntersectedModules;
1029     std::set_intersection(g_hHookedModules.begin(), g_hHookedModules.end(),
1030                           hCurrentModules.begin(), hCurrentModules.end(),
1031                           std::inserter(hIntersectedModules, hIntersectedModules.begin()));
1032     g_hHookedModules = std::move(hIntersectedModules);
1033     LeaveCriticalSection(&g_Mutex);
1034 
1035     SetLastError(dwLastError);
1036     return bRet;
1037 }
1038 
1039 
1040 static void
registerLibraryLoaderHooks(const char * szMatchModule)1041 registerLibraryLoaderHooks(const char *szMatchModule)
1042 {
1043     Module & module = modulesMap[szMatchModule];
1044     module.bInternal = true;
1045     FunctionMap & functionMap = module.functionMap;
1046     functionMap["LoadLibraryA"]   = (LPVOID)MyLoadLibraryA;
1047     functionMap["LoadLibraryW"]   = (LPVOID)MyLoadLibraryW;
1048     functionMap["LoadLibraryExA"] = (LPVOID)MyLoadLibraryExA;
1049     functionMap["LoadLibraryExW"] = (LPVOID)MyLoadLibraryExW;
1050     functionMap["GetProcAddress"] = (LPVOID)MyGetProcAddress;
1051     functionMap["FreeLibrary"]    = (LPVOID)MyFreeLibrary;
1052 }
1053 
1054 static void
registerProcessThreadsHooks(const char * szMatchModule)1055 registerProcessThreadsHooks(const char *szMatchModule)
1056 {
1057     Module & module = modulesMap[szMatchModule];
1058     module.bInternal = true;
1059     FunctionMap & functionMap = module.functionMap;
1060     functionMap["CreateProcessA"]       = (LPVOID)MyCreateProcessA;
1061     functionMap["CreateProcessW"]       = (LPVOID)MyCreateProcessW;
1062     // NOTE: CreateProcessAsUserA is implemented by advapi32.dll
1063     functionMap["CreateProcessAsUserW"] = (LPVOID)MyCreateProcessAsUserW;
1064     // TODO: CreateProcessWithTokenW
1065 }
1066 
1067 static void
registerModuleHooks(const char * szMatchModule,HMODULE hReplaceModule)1068 registerModuleHooks(const char *szMatchModule, HMODULE hReplaceModule)
1069 {
1070     Module & module = modulesMap[szMatchModule];
1071     module.bInternal = false;
1072     FunctionMap & functionMap = module.functionMap;
1073 
1074     PIMAGE_EXPORT_DIRECTORY pExportDescriptor = getExportDescriptor(hReplaceModule);
1075     assert(pExportDescriptor);
1076 
1077     DWORD *pAddressOfNames = (DWORD *)((BYTE *)hReplaceModule + pExportDescriptor->AddressOfNames);
1078     for (DWORD i = 0; i < pExportDescriptor->NumberOfNames; ++i) {
1079         const char *szFunctionName = (const char *)((BYTE *)hReplaceModule + pAddressOfNames[i]);
1080         LPVOID lpNewAddress = (LPVOID)GetProcAddress(hReplaceModule, szFunctionName);
1081         assert(lpNewAddress);
1082 
1083         functionMap[szFunctionName] = lpNewAddress;
1084     }
1085 }
1086 
1087 static void
dumpRegisteredHooks(void)1088 dumpRegisteredHooks(void)
1089 {
1090     if (VERBOSITY > 1) {
1091         ModulesMap::const_iterator modIt;
1092         for (modIt = modulesMap.begin(); modIt != modulesMap.end(); ++modIt) {
1093             const char *szMatchModule = modIt->first;
1094             const Module & module = modIt->second;
1095             const FunctionMap & functionMap = module.functionMap;
1096             FunctionMap::const_iterator fnIt;
1097             for (fnIt = functionMap.begin(); fnIt != functionMap.end(); ++fnIt) {
1098                 const char *szFunctionName = fnIt->first;
1099                 debugPrintf("inject: registered hook for %s!%s%s\n",
1100                             szMatchModule, szFunctionName,
1101                             module.bInternal ? " (internal)" : "");
1102             }
1103         }
1104     }
1105 }
1106 
1107 
1108 EXTERN_C BOOL WINAPI
DllMain(HINSTANCE hinstDLL,DWORD fdwReason,LPVOID lpvReserved)1109 DllMain(HINSTANCE hinstDLL, DWORD fdwReason, LPVOID lpvReserved)
1110 {
1111     const char *szNewDllName = NULL;
1112     const char *szNewDllBaseName;
1113 
1114     switch (fdwReason) {
1115     case DLL_PROCESS_ATTACH:
1116         InitializeCriticalSection(&g_Mutex);
1117 
1118         g_hThisModule = hinstDLL;
1119 
1120         /*
1121          * Calling LoadLibrary inside DllMain is strongly discouraged.  But it
1122          * works quite well, provided that the loaded DLL does not require or do
1123          * anything special in its DllMain, which seems to be the general case.
1124          *
1125          * See also:
1126          * - http://stackoverflow.com/questions/4370812/calling-loadlibrary-from-dllmain
1127          * - http://msdn.microsoft.com/en-us/library/ms682583
1128          */
1129 
1130         if (!USE_SHARED_MEM) {
1131             szNewDllName = getenv("INJECT_DLL");
1132             if (!szNewDllName) {
1133                 debugPrintf("inject: warning: INJECT_DLL not set\n");
1134                 return FALSE;
1135             }
1136         } else {
1137             SharedMem *pSharedMem = OpenSharedMemory(NULL);
1138             if (!pSharedMem) {
1139                 debugPrintf("inject: error: failed to open shared memory\n");
1140                 return FALSE;
1141             }
1142 
1143             VERBOSITY = pSharedMem->cVerbosity;
1144 
1145             static char szSharedMemCopy[MAX_PATH];
1146             strncpy(szSharedMemCopy, pSharedMem->szDllName, _countof(szSharedMemCopy) - 1);
1147             szSharedMemCopy[_countof(szSharedMemCopy) - 1] = '\0';
1148 
1149             szNewDllName = szSharedMemCopy;
1150         }
1151 
1152         if (VERBOSITY > 0) {
1153             debugPrintf("inject: DLL_PROCESS_ATTACH\n");
1154         }
1155 
1156         if (VERBOSITY > 0) {
1157             char szProcess[MAX_PATH];
1158             GetModuleFileNameA(NULL, szProcess, sizeof szProcess);
1159             debugPrintf("inject: attached to process %s\n", szProcess);
1160         }
1161 
1162         if (VERBOSITY > 0) {
1163             debugPrintf("inject: loading %s\n", szNewDllName);
1164         }
1165 
1166         g_hHookModule = LoadLibraryA(szNewDllName);
1167         if (!g_hHookModule) {
1168             debugPrintf("inject: warning: failed to load %s\n", szNewDllName);
1169             return FALSE;
1170         }
1171 
1172         // Ensure we use kernel32.dll's CreateProcessAsUserW, and not advapi32.dll's.
1173         {
1174             HMODULE hKernel32 = GetModuleHandleA("kernel32.dll");
1175             assert(hKernel32);
1176             pfnCreateProcessAsUserW = (PFNCREATEPROCESSASUSERW)GetProcAddress(hKernel32, "CreateProcessAsUserW");
1177         }
1178 
1179         /*
1180          * Hook kernel32.dll functions, and its respective Windows API Set.
1181          *
1182          * https://msdn.microsoft.com/en-us/library/dn505783.aspx (Windows 8.1)
1183          * https://msdn.microsoft.com/en-us/library/hh802935.aspx (Windows 8)
1184          * https://docs.microsoft.com/en-us/uwp/win32-and-com/win32-apis
1185          */
1186 
1187         registerLibraryLoaderHooks("kernel32.dll");
1188         registerLibraryLoaderHooks("api-ms-win-core-libraryloader-l1-1-0.dll");
1189         registerLibraryLoaderHooks("api-ms-win-core-libraryloader-l1-1-1.dll");
1190         registerLibraryLoaderHooks("api-ms-win-core-libraryloader-l1-2-0.dll");
1191         registerLibraryLoaderHooks("api-ms-win-core-libraryloader-l1-2-1.dll");
1192         registerLibraryLoaderHooks("api-ms-win-core-libraryloader-l1-2-2.dll");
1193         registerLibraryLoaderHooks("api-ms-win-core-libraryloader-l2-1-0.dll");
1194         registerLibraryLoaderHooks("api-ms-win-core-kernel32-legacy-l1-1-0.dll");
1195         registerLibraryLoaderHooks("api-ms-win-core-kernel32-legacy-l1-1-1.dll");
1196         registerLibraryLoaderHooks("api-ms-win-core-kernel32-legacy-l1-1-2.dll");
1197 
1198         registerProcessThreadsHooks("kernel32.dll");
1199         registerProcessThreadsHooks("api-ms-win-core-processthreads-l1-1-0.dll");
1200         registerProcessThreadsHooks("api-ms-win-core-processthreads-l1-1-1.dll");
1201         registerProcessThreadsHooks("api-ms-win-core-processthreads-l1-1-2.dll");
1202         registerProcessThreadsHooks("api-ms-win-core-processthreads-l1-1-2.dll");
1203         registerProcessThreadsHooks("api-ms-win-core-processthreads-l1-1-3.dll");
1204 
1205         szNewDllBaseName = getBaseName(szNewDllName);
1206         if (stricmp(szNewDllBaseName, "dxgitrace.dll") == 0) {
1207             registerModuleHooks("dxgi.dll",    g_hHookModule);
1208             registerModuleHooks("d3d10.dll",   g_hHookModule);
1209             registerModuleHooks("d3d10_1.dll", g_hHookModule);
1210             registerModuleHooks("d3d11.dll",   g_hHookModule);
1211             registerModuleHooks("d3d9.dll",    g_hHookModule); // for D3DPERF_*
1212             registerModuleHooks("dcomp.dll",   g_hHookModule);
1213         } else if (stricmp(szNewDllBaseName, "d3d9.dll") == 0) {
1214             registerModuleHooks("d3d9.dll",    g_hHookModule);
1215             registerModuleHooks("dxva2.dll",   g_hHookModule);
1216         } else if (stricmp(szNewDllBaseName, "d2d1trace.dll") == 0) {
1217             registerModuleHooks("d2d1.dll",    g_hHookModule);
1218             registerModuleHooks("dwrite.dll",  g_hHookModule);
1219         } else {
1220             registerModuleHooks(szNewDllBaseName, g_hHookModule);
1221         }
1222 
1223         dumpRegisteredHooks();
1224 
1225         patchAllModules(ACTION_HOOK);
1226         break;
1227 
1228     case DLL_THREAD_ATTACH:
1229         break;
1230 
1231     case DLL_THREAD_DETACH:
1232         break;
1233 
1234     case DLL_PROCESS_DETACH:
1235         if (VERBOSITY > 0) {
1236             debugPrintf("inject: DLL_PROCESS_DETACH\n");
1237         }
1238 
1239         assert(!lpvReserved);
1240 
1241         patchAllModules(ACTION_UNHOOK);
1242 
1243         if (g_hHookModule) {
1244             FreeLibrary(g_hHookModule);
1245         }
1246         break;
1247     }
1248     return TRUE;
1249 }
1250 
1251 
1252 /*
1253  * Prevent the C/C++ runtime from destroying things when the program
1254  * terminates.
1255  *
1256  * There is no effective way to control the order DLLs receive
1257  * DLL_PROCESS_DETACH -- patched DLLs might get detacched after we are --, and
1258  * unpatching our hooks doesn't always work.  So instead just do nothing (and
1259  * prevent C/C++ runtime from doing anything too), so our hooks can still work
1260  * after we are dettached.
1261  */
1262 
/*
 * Name of the C runtime's DLL entry point that DllMainStartup chains to:
 * _DllMainCRTStartup with MSVC, DllMainCRTStartup with other compilers
 * (e.g. MinGW).
 */
#ifdef _MSC_VER
#  define DLLMAIN_CRT_STARTUP _DllMainCRTStartup
#else
#  define DLLMAIN_CRT_STARTUP DllMainCRTStartup
   // NOTE(review): presumably disables GCC's stack protector for the
   // functions that follow (notably DllMainStartup), which can run outside
   // the CRT's normal initialization -- confirm.
#  pragma GCC optimize ("no-stack-protector")
#endif

// Declaration only; the definition is presumably supplied by the C runtime
// startup code, not by this file.
EXTERN_C BOOL WINAPI
DLLMAIN_CRT_STARTUP(HANDLE hDllHandle, DWORD dwReason, LPVOID lpvReserved);
1272 
1273 EXTERN_C BOOL WINAPI
DllMainStartup(HANDLE hDllHandle,DWORD dwReason,LPVOID lpvReserved)1274 DllMainStartup(HANDLE hDllHandle, DWORD dwReason, LPVOID lpvReserved)
1275 {
1276     if (dwReason == DLL_PROCESS_DETACH && lpvReserved) {
1277         if (VERBOSITY > 0) {
1278             debugPrintf("inject: DLL_PROCESS_DETACH\n");
1279         }
1280         return TRUE;
1281     }
1282 
1283     return DLLMAIN_CRT_STARTUP(hDllHandle, dwReason, lpvReserved);
1284 }
1285