1 /**************************************************************************
2  *
3  * Copyright 2016 VMware, Inc.
4  * Copyright 2011-2012 Jose Fonseca
5  * All Rights Reserved.
6  *
7  * Permission is hereby granted, free of charge, to any person obtaining a copy
8  * of this software and associated documentation files (the "Software"), to deal
9  * in the Software without restriction, including without limitation the rights
10  * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11  * copies of the Software, and to permit persons to whom the Software is
12  * furnished to do so, subject to the following conditions:
13  *
14  * The above copyright notice and this permission notice shall be included in
15  * all copies or substantial portions of the Software.
16  *
17  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
23  * THE SOFTWARE.
24  *
25  **************************************************************************/
26 
27 
28 /*
29  * Code for the DLL that will be injected in the target process.
30  *
31  * The injected DLL will manipulate the import tables to hook the
32  * modules/functions of interest.
33  *
34  * See also:
35  * - http://www.codeproject.com/KB/system/api_spying_hack.aspx
36  * - http://www.codeproject.com/KB/threads/APIHooking.aspx
37  * - http://msdn.microsoft.com/en-us/magazine/cc301808.aspx
38  */
39 
40 
41 #include <assert.h>
42 #include <stdio.h>
43 #include <stdarg.h>
44 #include <string.h>
45 
46 #include <algorithm>
47 #include <set>
48 #include <map>
49 #include <functional>
50 #include <iterator>
51 
52 #include <windows.h>
53 #include <tlhelp32.h>
54 #include <delayimp.h>
55 
56 #include "inject.h"
57 #include "mhook.h"
58 
59 #include "os_symbols.hpp"
60 
61 
62 static int VERBOSITY = 0;
63 #define NOOP 0
64 
65 
66 static CRITICAL_SECTION g_Mutex;
67 
68 
69 
70 static HMODULE g_hThisModule = NULL;
71 static HMODULE g_hHookModule = NULL;
72 
73 
74 static std::map<PVOID, PVOID>
75 g_pRealFunctions;
76 
77 
78 typedef HMODULE
79 (WINAPI *PFNLOADLIBRARYA)(LPCSTR);
80 
81 typedef HMODULE
82 (WINAPI *PFNLOADLIBRARYW)(LPCWSTR);
83 
84 typedef HMODULE
85 (WINAPI *PFNLOADLIBRARYEXA)(LPCSTR, HANDLE, DWORD);
86 
87 typedef HMODULE
88 (WINAPI *PFNLOADLIBRARYEXW)(LPCWSTR, HANDLE, DWORD);
89 
90 typedef BOOL
91 (WINAPI *PFNFREELIBRARY)(HMODULE);
92 
93 static PFNLOADLIBRARYA RealLoadLibraryA = LoadLibraryA;
94 static PFNLOADLIBRARYW RealLoadLibraryW = LoadLibraryW;
95 static PFNLOADLIBRARYEXA RealLoadLibraryExA = LoadLibraryExA;
96 static PFNLOADLIBRARYEXW RealLoadLibraryExW = LoadLibraryExW;
97 static PFNFREELIBRARY RealFreeLibrary = FreeLibrary;
98 
99 typedef FARPROC (WINAPI * PFNGETPROCADDRESS)(HMODULE hModule, LPCSTR lpProcName);
100 
101 static PFNGETPROCADDRESS RealGetProcAddress = GetProcAddress;
102 
103 typedef BOOL
104 (WINAPI *PFNCREATEPROCESSA) (LPCSTR, LPSTR,
105         LPSECURITY_ATTRIBUTES, LPSECURITY_ATTRIBUTES, BOOL, DWORD, LPVOID,
106         LPCSTR, LPSTARTUPINFOA, LPPROCESS_INFORMATION);
107 
108 static PFNCREATEPROCESSA RealCreateProcessA = CreateProcessA;
109 
110 typedef BOOL
111 (WINAPI *PFNCREATEPROCESSW) (LPCWSTR, LPWSTR,
112         LPSECURITY_ATTRIBUTES, LPSECURITY_ATTRIBUTES, BOOL, DWORD, LPVOID,
113         LPCWSTR, LPSTARTUPINFOW, LPPROCESS_INFORMATION);
114 
115 static PFNCREATEPROCESSW RealCreateProcessW = CreateProcessW;
116 
117 typedef BOOL
118 (WINAPI *PFNCREATEPROCESSASUSERW) (HANDLE, LPCWSTR, LPWSTR,
119         LPSECURITY_ATTRIBUTES, LPSECURITY_ATTRIBUTES, BOOL, DWORD, LPVOID,
120         LPCWSTR, LPSTARTUPINFOW, LPPROCESS_INFORMATION);
121 
122 static PFNCREATEPROCESSASUSERW RealCreateProcessAsUserW = CreateProcessAsUserW;
123 
124 
125 static void
debugPrintf(const char * format,...)126 debugPrintf(const char *format, ...)
127 {
128     char buf[512];
129 
130     va_list ap;
131     va_start(ap, format);
132     _vsnprintf(buf, sizeof buf, format, ap);
133     va_end(ap);
134 
135     OutputDebugStringA(buf);
136 }
137 
138 
139 EXTERN_C void
_assert(const char * _Message,const char * _File,unsigned _Line)140 _assert(const char *_Message, const char *_File, unsigned _Line)
141 {
142     debugPrintf("Assertion failed: %s, file %s, line %u\n", _Message, _File, _Line);
143     TerminateProcess(GetCurrentProcess(), 1);
144 }
145 
146 
147 EXTERN_C void
_wassert(const wchar_t * _Message,const wchar_t * _File,unsigned _Line)148 _wassert(const wchar_t * _Message, const wchar_t *_File, unsigned _Line)
149 {
150     debugPrintf("Assertion failed: %S, file %S, line %u\n", _Message, _File, _Line);
151     TerminateProcess(GetCurrentProcess(), 1);
152 }
153 
154 
155 static void
MyCreateProcessCommon(BOOL bRet,DWORD dwCreationFlags,LPPROCESS_INFORMATION lpProcessInformation)156 MyCreateProcessCommon(BOOL bRet,
157                       DWORD dwCreationFlags,
158                       LPPROCESS_INFORMATION lpProcessInformation)
159 {
160     if (!bRet) {
161         debugPrintf("inject: warning: failed to create child process\n");
162         return;
163     }
164 
165     DWORD dwLastError = GetLastError();
166 
167     if (isDifferentArch(lpProcessInformation->hProcess)) {
168         debugPrintf("inject: error: child process %lu has different architecture\n",
169                     GetProcessId(lpProcessInformation->hProcess));
170     } else {
171         char szDllPath[MAX_PATH];
172         GetModuleFileNameA(g_hThisModule, szDllPath, sizeof szDllPath);
173 
174         if (!injectDll(lpProcessInformation->hProcess, szDllPath)) {
175             debugPrintf("inject: warning: failed to inject into child process %lu\n",
176                         GetProcessId(lpProcessInformation->hProcess));
177         }
178     }
179 
180     if (!(dwCreationFlags & CREATE_SUSPENDED)) {
181         ResumeThread(lpProcessInformation->hThread);
182     }
183 
184     SetLastError(dwLastError);
185 }
186 
187 static BOOL WINAPI
MyCreateProcessA(LPCSTR lpApplicationName,LPSTR lpCommandLine,LPSECURITY_ATTRIBUTES lpProcessAttributes,LPSECURITY_ATTRIBUTES lpThreadAttributes,BOOL bInheritHandles,DWORD dwCreationFlags,LPVOID lpEnvironment,LPCSTR lpCurrentDirectory,LPSTARTUPINFOA lpStartupInfo,LPPROCESS_INFORMATION lpProcessInformation)188 MyCreateProcessA(LPCSTR lpApplicationName,
189                  LPSTR lpCommandLine,
190                  LPSECURITY_ATTRIBUTES lpProcessAttributes,
191                  LPSECURITY_ATTRIBUTES lpThreadAttributes,
192                  BOOL bInheritHandles,
193                  DWORD dwCreationFlags,
194                  LPVOID lpEnvironment,
195                  LPCSTR lpCurrentDirectory,
196                  LPSTARTUPINFOA lpStartupInfo,
197                  LPPROCESS_INFORMATION lpProcessInformation)
198 {
199     if (VERBOSITY >= 2) {
200         debugPrintf("inject: intercepting %s(\"%s\", \"%s\", ...)\n",
201                     __FUNCTION__,
202                     lpApplicationName,
203                     lpCommandLine);
204     }
205 
206     BOOL bRet;
207     bRet = RealCreateProcessA(lpApplicationName,
208                               lpCommandLine,
209                               lpProcessAttributes,
210                               lpThreadAttributes,
211                               bInheritHandles,
212                               dwCreationFlags | CREATE_SUSPENDED,
213                               lpEnvironment,
214                               lpCurrentDirectory,
215                               lpStartupInfo,
216                               lpProcessInformation);
217 
218     MyCreateProcessCommon(bRet, dwCreationFlags, lpProcessInformation);
219 
220     return bRet;
221 }
222 
223 static BOOL WINAPI
MyCreateProcessW(LPCWSTR lpApplicationName,LPWSTR lpCommandLine,LPSECURITY_ATTRIBUTES lpProcessAttributes,LPSECURITY_ATTRIBUTES lpThreadAttributes,BOOL bInheritHandles,DWORD dwCreationFlags,LPVOID lpEnvironment,LPCWSTR lpCurrentDirectory,LPSTARTUPINFOW lpStartupInfo,LPPROCESS_INFORMATION lpProcessInformation)224 MyCreateProcessW(LPCWSTR lpApplicationName,
225                  LPWSTR lpCommandLine,
226                  LPSECURITY_ATTRIBUTES lpProcessAttributes,
227                  LPSECURITY_ATTRIBUTES lpThreadAttributes,
228                  BOOL bInheritHandles,
229                  DWORD dwCreationFlags,
230                  LPVOID lpEnvironment,
231                  LPCWSTR lpCurrentDirectory,
232                  LPSTARTUPINFOW lpStartupInfo,
233                  LPPROCESS_INFORMATION lpProcessInformation)
234 {
235     if (VERBOSITY >= 2) {
236         debugPrintf("inject: intercepting %s(\"%S\", \"%S\", ...)\n",
237                     __FUNCTION__,
238                     lpApplicationName,
239                     lpCommandLine);
240     }
241 
242     BOOL bRet;
243     bRet = RealCreateProcessW(lpApplicationName,
244                               lpCommandLine,
245                               lpProcessAttributes,
246                               lpThreadAttributes,
247                               bInheritHandles,
248                               dwCreationFlags | CREATE_SUSPENDED,
249                               lpEnvironment,
250                               lpCurrentDirectory,
251                               lpStartupInfo,
252                               lpProcessInformation);
253 
254     MyCreateProcessCommon(bRet, dwCreationFlags, lpProcessInformation);
255 
256     return bRet;
257 }
258 
259 static BOOL WINAPI
MyCreateProcessAsUserW(HANDLE hToken,LPCWSTR lpApplicationName,LPWSTR lpCommandLine,LPSECURITY_ATTRIBUTES lpProcessAttributes,LPSECURITY_ATTRIBUTES lpThreadAttributes,BOOL bInheritHandles,DWORD dwCreationFlags,LPVOID lpEnvironment,LPCWSTR lpCurrentDirectory,LPSTARTUPINFOW lpStartupInfo,LPPROCESS_INFORMATION lpProcessInformation)260 MyCreateProcessAsUserW(HANDLE hToken,
261                        LPCWSTR lpApplicationName,
262                        LPWSTR lpCommandLine,
263                        LPSECURITY_ATTRIBUTES lpProcessAttributes,
264                        LPSECURITY_ATTRIBUTES lpThreadAttributes,
265                        BOOL bInheritHandles,
266                        DWORD dwCreationFlags,
267                        LPVOID lpEnvironment,
268                        LPCWSTR lpCurrentDirectory,
269                        LPSTARTUPINFOW lpStartupInfo,
270                        LPPROCESS_INFORMATION lpProcessInformation)
271 {
272     if (VERBOSITY >= 2) {
273         debugPrintf("inject: intercepting %s(\"%S\", \"%S\", ...)\n",
274                     __FUNCTION__,
275                     lpApplicationName,
276                     lpCommandLine);
277     }
278 
279     BOOL bRet;
280     bRet = RealCreateProcessAsUserW(hToken,
281                                     lpApplicationName,
282                                     lpCommandLine,
283                                     lpProcessAttributes,
284                                     lpThreadAttributes,
285                                     bInheritHandles,
286                                     dwCreationFlags,
287                                     lpEnvironment,
288                                     lpCurrentDirectory,
289                                     lpStartupInfo,
290                                     lpProcessInformation);
291 
292     MyCreateProcessCommon(bRet, dwCreationFlags, lpProcessInformation);
293 
294     return bRet;
295 }
296 
297 
298 template< class T, class I >
299 inline T *
rvaToVa(HMODULE hModule,I rva)300 rvaToVa(HMODULE hModule, I rva)
301 {
302     assert(rva != 0);
303     return reinterpret_cast<T *>(reinterpret_cast<PBYTE>(hModule) + rva);
304 }
305 
306 
307 static PIMAGE_OPTIONAL_HEADER
getOptionalHeader(HMODULE hModule,const char * szModule)308 getOptionalHeader(HMODULE hModule,
309                   const char *szModule)
310 {
311     PIMAGE_DOS_HEADER pDosHeader = reinterpret_cast<PIMAGE_DOS_HEADER>(hModule);
312     if (pDosHeader->e_magic != IMAGE_DOS_SIGNATURE) {
313         debugPrintf("inject: warning: %s: unexpected DOS header magic (0x%04x)\n",
314                     szModule, pDosHeader->e_magic);
315         return NULL;
316     }
317     PIMAGE_NT_HEADERS pNtHeaders = rvaToVa<IMAGE_NT_HEADERS>(hModule, pDosHeader->e_lfanew);
318     if (pNtHeaders->Signature != IMAGE_NT_SIGNATURE) {
319         debugPrintf("inject: warning: %s: unexpected NT header signature (0x%08lx)\n",
320                     szModule, pNtHeaders->Signature);
321         return NULL;
322     }
323     PIMAGE_OPTIONAL_HEADER pOptionalHeader = &pNtHeaders->OptionalHeader;
324     return pOptionalHeader;
325 }
326 
327 static PVOID
getImageDirectoryEntry(HMODULE hModule,const char * szModule,UINT Entry)328 getImageDirectoryEntry(HMODULE hModule,
329                        const char *szModule,
330                        UINT Entry)
331 {
332     MEMORY_BASIC_INFORMATION MemoryInfo;
333     if (VirtualQuery(hModule, &MemoryInfo, sizeof MemoryInfo) != sizeof MemoryInfo) {
334         debugPrintf("inject: warning: %s: VirtualQuery failed\n", szModule);
335         return NULL;
336     }
337     if (MemoryInfo.Protect & (PAGE_NOACCESS | PAGE_EXECUTE)) {
338         debugPrintf("inject: warning: %s: no read access (Protect = 0x%08lx)\n", szModule, MemoryInfo.Protect);
339         return NULL;
340     }
341 
342     PIMAGE_OPTIONAL_HEADER pOptionalHeader = getOptionalHeader(hModule, szModule);
343     if (!pOptionalHeader ||
344         pOptionalHeader->DataDirectory[Entry].Size == 0) {
345         return NULL;
346     }
347 
348     UINT_PTR ImportAddress = pOptionalHeader->DataDirectory[Entry].VirtualAddress;
349     if (!ImportAddress) {
350         return NULL;
351     }
352 
353     return rvaToVa<VOID>(hModule, ImportAddress);
354 }
355 
356 
357 static PIMAGE_EXPORT_DIRECTORY
getExportDescriptor(HMODULE hModule)358 getExportDescriptor(HMODULE hModule)
359 {
360     PVOID pEntry = getImageDirectoryEntry(hModule, "(wrapper)", IMAGE_DIRECTORY_ENTRY_EXPORT);
361     return reinterpret_cast<PIMAGE_EXPORT_DIRECTORY>(pEntry);
362 }
363 
364 
365 /* Set of previously hooked modules */
366 static std::set<HMODULE>
367 g_hHookedModules;
368 
369 
370 enum Action {
371     ACTION_HOOK,
372     ACTION_UNHOOK,
373 
374 };
375 
376 
377 static void
patchModule(HMODULE hModule,const char * szModule,Action action)378 patchModule(HMODULE hModule,
379             const char *szModule,
380             Action action)
381 {
382     /* Never patch this module */
383     if (hModule == g_hThisModule) {
384         return;
385     }
386 
387     /* Never patch our hook module */
388     if (hModule == g_hHookModule) {
389         return;
390     }
391 
392     /* Hook modules only once */
393     if (action == ACTION_HOOK) {
394         std::pair< std::set<HMODULE>::iterator, bool > ret;
395         EnterCriticalSection(&g_Mutex);
396         ret = g_hHookedModules.insert(hModule);
397         LeaveCriticalSection(&g_Mutex);
398         if (!ret.second) {
399             return;
400         }
401     }
402 
403     if (VERBOSITY > 0) {
404         debugPrintf("inject: found module %s\n", szModule);
405     }
406 
407     const char *szBaseName = getBaseName(szModule);
408 
409     if (stricmp(szBaseName, "opengl32.dll") != 0 &&
410         stricmp(szBaseName, "dxgi.dll") != 0 &&
411         stricmp(szBaseName, "d3d10.dll") != 0 &&
412         stricmp(szBaseName, "d3d10_1.dll") != 0 &&
413         stricmp(szBaseName, "d3d11.dll") != 0 &&
414         stricmp(szBaseName, "d3d9.dll") != 0 &&
415         stricmp(szBaseName, "d3d8.dll") != 0 &&
416         stricmp(szBaseName, "ddraw.dll") != 0 &&
417         stricmp(szBaseName, "dcomp.dll") != 0 &&
418         stricmp(szBaseName, "dxva2.dll") != 0 &&
419         stricmp(szBaseName, "d2d1.dll") != 0 &&
420         stricmp(szBaseName, "dwrite.dll") != 0) {
421         return;
422     }
423 
424     PIMAGE_EXPORT_DIRECTORY pExportDescriptor = getExportDescriptor(g_hHookModule);
425     assert(pExportDescriptor);
426 
427     DWORD *pAddressOfNames = (DWORD *)((BYTE *)g_hHookModule + pExportDescriptor->AddressOfNames);
428     for (DWORD i = 0; i < pExportDescriptor->NumberOfNames; ++i) {
429         const char *szFunctionName = (const char *)((BYTE *)g_hHookModule + pAddressOfNames[i]);
430 
431         LPVOID lpOrigAddress = (LPVOID)RealGetProcAddress(hModule, szFunctionName);
432         if (lpOrigAddress) {
433             LPVOID lpHookAddress = (LPVOID)RealGetProcAddress(g_hHookModule, szFunctionName);
434             assert(lpHookAddress);
435 
436             // With mhook we intercept the inner gl* calls, so no need to trace the
437             // outer wglUseFont* calls.
438             if (strncmp(szFunctionName, "wglUseFont", strlen("wglUseFont")) == 0) {
439                 debugPrintf("inject: not hooking %s!%s\n", szBaseName, szFunctionName);
440                 continue;
441             }
442 
443             if (VERBOSITY > 0) {
444                 debugPrintf("inject: hooking %s!%s\n", szModule, szFunctionName);
445             }
446 
447             LPVOID lpRealAddress = lpOrigAddress;
448             if (!Mhook_SetHook(&lpRealAddress, lpHookAddress)) {
449                 debugPrintf("inject: error: failed to hook %s!%s\n", szModule, szFunctionName);
450             }
451 
452             EnterCriticalSection(&g_Mutex);
453             g_pRealFunctions[lpOrigAddress] = lpRealAddress;
454             LeaveCriticalSection(&g_Mutex);
455 
456             pSharedMem->bReplaced = TRUE;
457         }
458     }
459 }
460 
461 
462 static void
patchAllModules(Action action)463 patchAllModules(Action action)
464 {
465     HANDLE hModuleSnap = CreateToolhelp32Snapshot(TH32CS_SNAPMODULE, GetCurrentProcessId());
466     if (hModuleSnap == INVALID_HANDLE_VALUE) {
467         return;
468     }
469 
470     MODULEENTRY32 me32;
471     me32.dwSize = sizeof me32;
472     if (Module32First(hModuleSnap, &me32)) {
473         do  {
474             patchModule(me32.hModule, me32.szExePath, action);
475         } while (Module32Next(hModuleSnap, &me32));
476     }
477 
478     CloseHandle(hModuleSnap);
479 }
480 
481 
482 static HMODULE WINAPI
MyLoadLibraryA(LPCSTR lpLibFileName)483 MyLoadLibraryA(LPCSTR lpLibFileName)
484 {
485     HMODULE hModule = RealLoadLibraryA(lpLibFileName);
486     DWORD dwLastError = GetLastError();
487 
488     if (VERBOSITY >= 2) {
489         debugPrintf("inject: intercepting %s(\"%s\") = 0x%p\n",
490                     __FUNCTION__ + 2, lpLibFileName, hModule);
491     }
492 
493     // Hook all new modules (and not just this one, to pick up any dependencies)
494     patchAllModules(ACTION_HOOK);
495 
496     SetLastError(dwLastError);
497     return hModule;
498 }
499 
500 static HMODULE WINAPI
MyLoadLibraryW(LPCWSTR lpLibFileName)501 MyLoadLibraryW(LPCWSTR lpLibFileName)
502 {
503     HMODULE hModule = RealLoadLibraryW(lpLibFileName);
504     DWORD dwLastError = GetLastError();
505 
506     if (VERBOSITY >= 2) {
507         debugPrintf("inject: intercepting %s(L\"%S\") = 0x%p\n",
508                     __FUNCTION__ + 2, lpLibFileName, hModule);
509     }
510 
511     // Hook all new modules (and not just this one, to pick up any dependencies)
512     patchAllModules(ACTION_HOOK);
513 
514     SetLastError(dwLastError);
515     return hModule;
516 }
517 
518 #ifndef LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR
519 #define LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR    0x00000100
520 #endif
521 #ifndef LOAD_LIBRARY_SEARCH_APPLICATION_DIR
522 #define LOAD_LIBRARY_SEARCH_APPLICATION_DIR 0x00000200
523 #endif
524 #ifndef LOAD_LIBRARY_SEARCH_USER_DIRS
525 #define LOAD_LIBRARY_SEARCH_USER_DIRS       0x00000400
526 #endif
527 #ifndef LOAD_LIBRARY_SEARCH_SYSTEM32
528 #define LOAD_LIBRARY_SEARCH_SYSTEM32        0x00000800
529 #endif
530 #ifndef LOAD_LIBRARY_SEARCH_DEFAULT_DIRS
531 #define LOAD_LIBRARY_SEARCH_DEFAULT_DIRS    0x00001000
532 #endif
533 
534 static inline DWORD
adjustFlags(DWORD dwFlags)535 adjustFlags(DWORD dwFlags)
536 {
537     /*
538      * XXX: LoadLibraryEx seems to interpret "application directory" in respect
539      * to the module that's calling it.  So when the application restricts the
540      * search path to application directory via
541      * LOAD_LIBRARY_SEARCH_APPLICATION_DIR or LOAD_LIBRARY_SEARCH_DEFAULT_DIRS
542      * flags, kernel32.dll ends up searching on the directory of the inject.dll
543      * module.
544      *
545      * XXX: What about SetDefaultDllDirectories?
546      *
547      */
548     if (dwFlags & (LOAD_LIBRARY_SEARCH_APPLICATION_DIR |
549                    LOAD_LIBRARY_SEARCH_DEFAULT_DIRS)) {
550         dwFlags &= ~(LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR |
551                      LOAD_LIBRARY_SEARCH_APPLICATION_DIR |
552                      LOAD_LIBRARY_SEARCH_USER_DIRS |
553                      LOAD_LIBRARY_SEARCH_SYSTEM32 |
554                      LOAD_LIBRARY_SEARCH_DEFAULT_DIRS);
555     }
556 
557     return dwFlags;
558 }
559 
560 static HMODULE WINAPI
MyLoadLibraryExA(LPCSTR lpLibFileName,HANDLE hFile,DWORD dwFlags)561 MyLoadLibraryExA(LPCSTR lpLibFileName, HANDLE hFile, DWORD dwFlags)
562 {
563     HMODULE hModule = RealLoadLibraryExA(lpLibFileName, hFile, adjustFlags(dwFlags));
564     DWORD dwLastError = GetLastError();
565 
566     if (VERBOSITY >= 2) {
567         debugPrintf("inject: intercepting %s(\"%s\", 0x%p, 0x%lx) = 0x%p\n",
568                     __FUNCTION__ + 2, lpLibFileName, hFile, dwFlags, hModule);
569     }
570 
571     // Hook all new modules (and not just this one, to pick up any dependencies)
572     patchAllModules(ACTION_HOOK);
573 
574     SetLastError(dwLastError);
575     return hModule;
576 }
577 
578 static HMODULE WINAPI
MyLoadLibraryExW(LPCWSTR lpLibFileName,HANDLE hFile,DWORD dwFlags)579 MyLoadLibraryExW(LPCWSTR lpLibFileName, HANDLE hFile, DWORD dwFlags)
580 {
581     HMODULE hModule = RealLoadLibraryExW(lpLibFileName, hFile, adjustFlags(dwFlags));
582     DWORD dwLastError = GetLastError();
583 
584     if (VERBOSITY >= 2) {
585         debugPrintf("inject: intercepting %s(L\"%S\", 0x%p, 0x%lx) = 0x%p\n",
586                     __FUNCTION__ + 2, lpLibFileName, hFile, dwFlags, hModule);
587     }
588 
589     // Hook all new modules (and not just this one, to pick up any dependencies)
590     patchAllModules(ACTION_HOOK);
591 
592     SetLastError(dwLastError);
593     return hModule;
594 }
595 
596 
597 static void
logGetProcAddress(PCSTR szAction,HMODULE hModule,LPCSTR lpProcName,HMODULE hCallerModule)598 logGetProcAddress(PCSTR szAction, HMODULE hModule, LPCSTR lpProcName, HMODULE hCallerModule)
599 {
600     char szCaller[MAX_PATH];
601     DWORD dwRet = GetModuleFileNameA(hCallerModule, szCaller, sizeof szCaller);
602     assert(dwRet);
603 
604     if (HIWORD(lpProcName) == 0) {
605         debugPrintf("inject: %s %s(%u) from %s\n", szAction, "GetProcAddress", LOWORD(lpProcName), szCaller);
606     } else {
607         debugPrintf("inject: %s %s(\"%s\") from %s\n", szAction, "GetProcAddress", lpProcName, szCaller);
608     }
609 }
610 
611 static HMODULE
GetModuleFromAddress(PVOID pAddress)612 GetModuleFromAddress(PVOID pAddress)
613 {
614     HMODULE hModule = nullptr;
615     BOOL bRet = GetModuleHandleEx(GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS |
616                                   GET_MODULE_HANDLE_EX_FLAG_UNCHANGED_REFCOUNT,
617                                   (LPCTSTR)pAddress,
618                                   &hModule);
619     return bRet ? hModule : nullptr;
620 }
621 
622 
623 static FARPROC WINAPI
MyGetProcAddress(HMODULE hModule,LPCSTR lpProcName)624 MyGetProcAddress(HMODULE hModule, LPCSTR lpProcName)
625 {
626     FARPROC lpProcAddress = RealGetProcAddress(hModule, lpProcName);
627     LPCSTR szAction = "ignoring";
628 
629     void *pCallerAddr = ReturnAddress();
630     HMODULE hCallerModule = GetModuleFromAddress(pCallerAddr);
631 
632     if (lpProcAddress) {
633         if (hCallerModule == g_hThisModule ||
634             hCallerModule == g_hHookModule) {
635 
636             assert(HIWORD(lpProcName) != 0);
637 
638             if (hCallerModule == g_hHookModule) {
639                 EnterCriticalSection(&g_Mutex);
640                 auto search = g_pRealFunctions.find((PVOID)lpProcAddress);
641                 if (search != g_pRealFunctions.end()) {
642                     szAction = "overriding";
643                     lpProcAddress = (FARPROC)search->second;
644                 }
645                 LeaveCriticalSection(&g_Mutex);
646             }
647 
648             // Check for recursion
649             HMODULE hResultModule = GetModuleFromAddress((PVOID)lpProcAddress);
650             if (hResultModule == g_hThisModule ||
651                 hResultModule == g_hHookModule) {
652                 debugPrintf("inject: error: recursion in GetProcAddress(\"%s\")\n", lpProcName);
653                 TerminateProcess(GetCurrentProcess(), 1);
654             }
655         }
656     }
657 
658     if (VERBOSITY >= 3) {
659         /* XXX this can cause segmentation faults */
660         logGetProcAddress(szAction, hModule, lpProcName, hCallerModule);
661     }
662 
663     return lpProcAddress;
664 }
665 
666 
667 static BOOL WINAPI
MyFreeLibrary(HMODULE hModule)668 MyFreeLibrary(HMODULE hModule)
669 {
670     if (VERBOSITY >= 2) {
671         debugPrintf("inject: intercepting %s(0x%p)\n", __FUNCTION__, hModule);
672     }
673 
674     BOOL bRet = RealFreeLibrary(hModule);
675     DWORD dwLastError = GetLastError();
676 
677     std::set<HMODULE> hCurrentModules;
678     HANDLE hModuleSnap = CreateToolhelp32Snapshot(TH32CS_SNAPMODULE, GetCurrentProcessId());
679     if (hModuleSnap != INVALID_HANDLE_VALUE) {
680         MODULEENTRY32 me32;
681         me32.dwSize = sizeof me32;
682         if (Module32First(hModuleSnap, &me32)) {
683             do  {
684                 hCurrentModules.insert(me32.hModule);
685             } while (Module32Next(hModuleSnap, &me32));
686         }
687         CloseHandle(hModuleSnap);
688     }
689 
690     // Clear the modules that have been freed
691     EnterCriticalSection(&g_Mutex);
692     std::set<HMODULE> hIntersectedModules;
693     std::set_intersection(g_hHookedModules.begin(), g_hHookedModules.end(),
694                           hCurrentModules.begin(), hCurrentModules.end(),
695                           std::inserter(hIntersectedModules, hIntersectedModules.begin()));
696     g_hHookedModules = std::move(hIntersectedModules);
697     LeaveCriticalSection(&g_Mutex);
698 
699     SetLastError(dwLastError);
700     return bRet;
701 }
702 
703 
704 static void
setHooks(void)705 setHooks(void)
706 {
707     HMODULE hKernel32 = GetModuleHandleA("kernel32");
708     assert(hKernel32);
709 
710 #   define SET_HOOK(_name) \
711         Real##_name = reinterpret_cast<decltype(Real##_name)>(RealGetProcAddress(hKernel32, #_name)); \
712         assert(Real##_name); \
713         assert(Real##_name != My##_name); \
714         if (!Mhook_SetHook((PVOID*)&Real##_name, (PVOID)My##_name)) { \
715             debugPrintf("inject: error: failed to hook " #_name "\n"); \
716         }
717 
718     SET_HOOK(LoadLibraryA)
719     SET_HOOK(LoadLibraryW)
720     SET_HOOK(LoadLibraryExA)
721     SET_HOOK(LoadLibraryExW)
722     SET_HOOK(FreeLibrary)
723     SET_HOOK(GetProcAddress)
724     SET_HOOK(CreateProcessA)
725     SET_HOOK(CreateProcessW)
726     SET_HOOK(CreateProcessAsUserW)
727 
728 #   undef SET_HOOK
729 }
730 
731 
732 EXTERN_C BOOL WINAPI
DllMain(HINSTANCE hinstDLL,DWORD fdwReason,LPVOID lpvReserved)733 DllMain(HINSTANCE hinstDLL, DWORD fdwReason, LPVOID lpvReserved)
734 {
735     const char *szNewDllName = NULL;
736 
737     switch (fdwReason) {
738     case DLL_PROCESS_ATTACH:
739         InitializeCriticalSection(&g_Mutex);
740 
741         g_hThisModule = hinstDLL;
742 
743         /*
744          * Calling LoadLibrary inside DllMain is strongly discouraged.  But it
745          * works quite well, provided that the loaded DLL does not require or do
746          * anything special in its DllMain, which seems to be the general case.
747          *
748          * See also:
749          * - http://stackoverflow.com/questions/4370812/calling-loadlibrary-from-dllmain
750          * - http://msdn.microsoft.com/en-us/library/ms682583
751          */
752 
753         if (!USE_SHARED_MEM) {
754             szNewDllName = getenv("INJECT_DLL");
755             if (!szNewDllName) {
756                 debugPrintf("inject: warning: INJECT_DLL not set\n");
757                 return FALSE;
758             }
759         } else {
760             SharedMem *pSharedMem = OpenSharedMemory(NULL);
761             if (!pSharedMem) {
762                 debugPrintf("inject: error: failed to open shared memory\n");
763                 return FALSE;
764             }
765 
766             VERBOSITY = pSharedMem->cVerbosity;
767 
768             static char szSharedMemCopy[MAX_PATH];
769             strncpy(szSharedMemCopy, pSharedMem->szDllName, _countof(szSharedMemCopy) - 1);
770             szSharedMemCopy[_countof(szSharedMemCopy) - 1] = '\0';
771 
772             szNewDllName = szSharedMemCopy;
773         }
774 
775         if (VERBOSITY > 0) {
776             debugPrintf("inject: DLL_PROCESS_ATTACH\n");
777         }
778 
779         if (VERBOSITY > 0) {
780             char szProcess[MAX_PATH];
781             GetModuleFileNameA(NULL, szProcess, sizeof szProcess);
782             debugPrintf("inject: attached to process %s\n", szProcess);
783         }
784 
785         if (VERBOSITY > 0) {
786             debugPrintf("inject: loading %s\n", szNewDllName);
787         }
788 
789         g_hHookModule = LoadLibraryA(szNewDllName);
790         if (!g_hHookModule) {
791             debugPrintf("inject: warning: failed to load %s\n", szNewDllName);
792             return FALSE;
793         }
794 
795         patchAllModules(ACTION_HOOK);
796 
797         setHooks();
798 
799         break;
800 
801     case DLL_THREAD_ATTACH:
802         break;
803 
804     case DLL_THREAD_DETACH:
805         break;
806 
807     case DLL_PROCESS_DETACH:
808         if (VERBOSITY > 0) {
809             debugPrintf("inject: DLL_PROCESS_DETACH\n");
810         }
811 
812         assert(!lpvReserved);
813 
814         patchAllModules(ACTION_UNHOOK);
815 
816         if (g_hHookModule) {
817             FreeLibrary(g_hHookModule);
818         }
819 
820         Mhook_Unhook((PVOID*)&RealGetProcAddress);
821 
822         break;
823     }
824     return TRUE;
825 }
826 
827 
/*
 * Prevent the C/C++ runtime from destroying things when the program
 * terminates.
 *
 * There is no effective way to control the order in which DLLs receive
 * DLL_PROCESS_DETACH -- patched DLLs might get detached after we are -- and
 * unpatching our hooks doesn't always work.  So instead just do nothing (and
 * prevent the C/C++ runtime from doing anything either), so that our hooks
 * keep working after we have been detached.
 */
838 
839 #ifdef _MSC_VER
840 #  define DLLMAIN_CRT_STARTUP _DllMainCRTStartup
841 #else
842 #  define DLLMAIN_CRT_STARTUP DllMainCRTStartup
843 #  pragma GCC optimize ("no-stack-protector")
844 #endif
845 
846 EXTERN_C BOOL WINAPI
847 DLLMAIN_CRT_STARTUP(HANDLE hDllHandle, DWORD dwReason, LPVOID lpvReserved);
848 
849 EXTERN_C BOOL WINAPI
DllMainStartup(HANDLE hDllHandle,DWORD dwReason,LPVOID lpvReserved)850 DllMainStartup(HANDLE hDllHandle, DWORD dwReason, LPVOID lpvReserved)
851 {
852     if (dwReason == DLL_PROCESS_DETACH && lpvReserved) {
853         if (VERBOSITY > 0) {
854             debugPrintf("inject: DLL_PROCESS_DETACH\n");
855         }
856         return TRUE;
857     }
858 
859     return DLLMAIN_CRT_STARTUP(hDllHandle, dwReason, lpvReserved);
860 }
861