1 /* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
2 /* vim: set ts=8 sts=2 et sw=2 tw=80: */
3 /* This Source Code Form is subject to the terms of the Mozilla Public
4  * License, v. 2.0. If a copy of the MPL was not distributed with this
5  * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
6 
7 #ifndef NS_WINDOWS_DLL_INTERCEPTOR_H_
8 #define NS_WINDOWS_DLL_INTERCEPTOR_H_
9 
10 #include "mozilla/Assertions.h"
11 #include "mozilla/ArrayUtils.h"
12 #include "mozilla/UniquePtr.h"
13 #include "nsWindowsHelpers.h"
14 
15 #include <wchar.h>
16 #include <windows.h>
17 #include <winternl.h>
18 
19 /*
20  * Simple function interception.
21  *
22  * We have two separate mechanisms for intercepting a function: We can use the
23  * built-in nop space, if it exists, or we can create a detour.
24  *
 * Using the built-in nop space works as follows: On x86-32, DLL functions
 * begin with a two-byte nop (mov edi, edi) and are preceded by five bytes of
 * padding (NOP or INT 3 instructions).
28  *
29  * When we detect a function with this prelude, we do the following:
30  *
31  * 1. Write a long jump to our interceptor function into the five bytes of NOPs
32  *    before the function.
33  *
34  * 2. Write a short jump -5 into the two-byte nop at the beginning of the function.
35  *
36  * This mechanism is nice because it's thread-safe.  It's even safe to do if
37  * another thread is currently running the function we're modifying!
38  *
39  * When the WindowsDllNopSpacePatcher is destroyed, we overwrite the short jump
40  * but not the long jump, so re-intercepting the same function won't work,
41  * because its prelude won't match.
42  *
43  *
44  * Unfortunately nop space patching doesn't work on functions which don't have
45  * this magic prelude (and in particular, x86-64 never has the prelude).  So
46  * when we can't use the built-in nop space, we fall back to using a detour,
47  * which works as follows:
48  *
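 * 1. Save first N bytes of OrigFunction to trampoline, where N is a
 *    number of bytes that ends on an instruction boundary and is at least
 *    5 on x86-32 (13 on x86-64).
 *
 * 2. Replace the first bytes of OrigFunction with a jump to the Hook
 *    function (a 5-byte JMP on x86-32; on x86-64, a longer absolute jump
 *    sequence that starts with MOV R11, imm64).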
54  *
55  * 3. After N bytes of the trampoline, add a jump to OrigFunction+N to
56  *    continue original program flow.
57  *
58  * 4. Hook function needs to call the trampoline during its execution,
59  *    to invoke the original function (so address of trampoline is
60  *    returned).
61  *
 * When the WindowsDllDetourPatcher object is destroyed, OrigFunction is
63  * patched again to jump directly to the trampoline instead of going through
64  * the hook function. As such, re-intercepting the same function won't work, as
65  * jump instructions are not supported.
66  *
67  * Note that this is not thread-safe.  Sad day.
68  *
69  */
70 
71 #include <stdint.h>
72 
73 namespace mozilla {
74 namespace internal {
75 
76 class AutoVirtualProtect
77 {
78 public:
AutoVirtualProtect(void * aFunc,size_t aSize,DWORD aProtect)79   AutoVirtualProtect(void* aFunc, size_t aSize, DWORD aProtect)
80     : mFunc(aFunc), mSize(aSize), mNewProtect(aProtect), mOldProtect(0),
81       mSuccess(false)
82   {}
83 
~AutoVirtualProtect()84   ~AutoVirtualProtect()
85   {
86     if (mSuccess) {
87       VirtualProtectEx(GetCurrentProcess(), mFunc, mSize, mOldProtect,
88                        &mOldProtect);
89     }
90   }
91 
Protect()92   bool Protect()
93   {
94     mSuccess = !!VirtualProtectEx(GetCurrentProcess(), mFunc, mSize,
95                                   mNewProtect, &mOldProtect);
96     if (!mSuccess) {
97       // printf("VirtualProtectEx failed! %d\n", GetLastError());
98     }
99     return mSuccess;
100   }
101 
102 private:
103   void* const mFunc;
104   size_t const mSize;
105   DWORD const mNewProtect;
106   DWORD mOldProtect;
107   bool mSuccess;
108 };
109 
110 class WindowsDllNopSpacePatcher
111 {
112   typedef uint8_t* byteptr_t;
113   HMODULE mModule;
114 
115   // Dumb array for remembering the addresses of functions we've patched.
116   // (This should be nsTArray, but non-XPCOM code uses this class.)
117   static const size_t maxPatchedFns = 128;
118   byteptr_t mPatchedFns[maxPatchedFns];
119   int mPatchedFnsLen;
120 
121 public:
WindowsDllNopSpacePatcher()122   WindowsDllNopSpacePatcher()
123     : mModule(0)
124     , mPatchedFnsLen(0)
125   {}
126 
127 #if defined(_M_IX86)
~WindowsDllNopSpacePatcher()128   ~WindowsDllNopSpacePatcher()
129   {
130     // Restore the mov edi, edi to the beginning of each function we patched.
131 
132     for (int i = 0; i < mPatchedFnsLen; i++) {
133       byteptr_t fn = mPatchedFns[i];
134 
135       // Ensure we can write to the code.
136       AutoVirtualProtect protect(fn, 2, PAGE_EXECUTE_READWRITE);
137       if (!protect.Protect()) {
138         continue;
139       }
140 
141       // mov edi, edi
142       *((uint16_t*)fn) = 0xff8b;
143 
144       // I don't think this is actually necessary, but it can't hurt.
145       FlushInstructionCache(GetCurrentProcess(),
146                             /* ignored */ nullptr,
147                             /* ignored */ 0);
148     }
149   }
150 
Init(const char * aModuleName)151   void Init(const char* aModuleName)
152   {
153     if (!IsCompatible()) {
154 #if defined(MOZILLA_INTERNAL_API)
155       NS_WARNING("NOP space patching is unavailable for compatibility reasons");
156 #endif
157       return;
158     }
159 
160     mModule = LoadLibraryExA(aModuleName, nullptr, 0);
161     if (!mModule) {
162       //printf("LoadLibraryEx for '%s' failed\n", aModuleName);
163       return;
164     }
165   }
166 
167   /**
168    * NVIDIA Optimus drivers utilize Microsoft Detours 2.x to patch functions
169    * in our address space. There is a bug in Detours 2.x that causes it to
170    * patch at the wrong address when attempting to detour code that is already
171    * NOP space patched. This function is an effort to detect the presence of
172    * this NVIDIA code in our address space and disable NOP space patching if it
173    * is. We also check AppInit_DLLs since this is the mechanism that the Optimus
174    * drivers use to inject into our process.
175    */
IsCompatible()176   static bool IsCompatible()
177   {
178     // These DLLs are known to have bad interactions with this style of patching
179     const wchar_t* kIncompatibleDLLs[] = {
180       L"detoured.dll",
181       L"_etoured.dll",
182       L"nvd3d9wrap.dll",
183       L"nvdxgiwrap.dll"
184     };
185     // See if the infringing DLLs are already loaded
186     for (unsigned int i = 0; i < mozilla::ArrayLength(kIncompatibleDLLs); ++i) {
187       if (GetModuleHandleW(kIncompatibleDLLs[i])) {
188         return false;
189       }
190     }
191     if (GetModuleHandleW(L"user32.dll")) {
192       // user32 is loaded but the infringing DLLs are not, assume we're safe to
193       // proceed.
194       return true;
195     }
196     // If user32 has not loaded yet, check AppInit_DLLs to ensure that Optimus
197     // won't be loaded once user32 is initialized.
198     HKEY hkey = NULL;
199     if (!RegOpenKeyExW(HKEY_LOCAL_MACHINE,
200           L"SOFTWARE\\Microsoft\\Windows NT\\CurrentVersion\\Windows",
201           0, KEY_QUERY_VALUE, &hkey)) {
202       nsAutoRegKey key(hkey);
203       DWORD numBytes = 0;
204       const wchar_t kAppInitDLLs[] = L"AppInit_DLLs";
205       // Query for required buffer size
206       LONG status = RegQueryValueExW(hkey, kAppInitDLLs, nullptr,
207                                      nullptr, nullptr, &numBytes);
208       mozilla::UniquePtr<wchar_t[]> data;
209       if (!status) {
210         // Allocate the buffer and query for the actual data
211         data = mozilla::MakeUnique<wchar_t[]>(numBytes / sizeof(wchar_t));
212         status = RegQueryValueExW(hkey, kAppInitDLLs, nullptr,
213                                   nullptr, (LPBYTE)data.get(), &numBytes);
214       }
215       if (!status) {
216         // For each token, split up the filename components and then check the
217         // name of the file.
218         const wchar_t kDelimiters[] = L", ";
219         wchar_t* tokenContext = nullptr;
220         wchar_t* token = wcstok_s(data.get(), kDelimiters, &tokenContext);
221         while (token) {
222           wchar_t fname[_MAX_FNAME] = {0};
223           if (!_wsplitpath_s(token, nullptr, 0, nullptr, 0,
224                              fname, mozilla::ArrayLength(fname),
225                              nullptr, 0)) {
226             // nvinit.dll is responsible for bootstrapping the DLL injection, so
227             // that is the library that we check for here
228             const wchar_t kNvInitName[] = L"nvinit";
229             if (!_wcsnicmp(fname, kNvInitName,
230                            mozilla::ArrayLength(kNvInitName))) {
231               return false;
232             }
233           }
234           token = wcstok_s(nullptr, kDelimiters, &tokenContext);
235         }
236       }
237     }
238     return true;
239   }
240 
AddHook(const char * aName,intptr_t aHookDest,void ** aOrigFunc)241   bool AddHook(const char* aName, intptr_t aHookDest, void** aOrigFunc)
242   {
243     if (!mModule) {
244       return false;
245     }
246 
247     if (!IsCompatible()) {
248 #if defined(MOZILLA_INTERNAL_API)
249       NS_WARNING("NOP space patching is unavailable for compatibility reasons");
250 #endif
251       return false;
252     }
253 
254     if (mPatchedFnsLen == maxPatchedFns) {
255       // printf ("No space for hook in mPatchedFns.\n");
256       return false;
257     }
258 
259     byteptr_t fn = reinterpret_cast<byteptr_t>(GetProcAddress(mModule, aName));
260     if (!fn) {
261       //printf ("GetProcAddress failed\n");
262       return false;
263     }
264 
265     fn = ResolveRedirectedAddress(fn);
266 
267     // Ensure we can read and write starting at fn - 5 (for the long jmp we're
268     // going to write) and ending at fn + 2 (for the short jmp up to the long
269     // jmp). These bytes may span two pages with different protection.
270     AutoVirtualProtect protectBefore(fn - 5, 5, PAGE_EXECUTE_READWRITE);
271     AutoVirtualProtect protectAfter(fn, 2, PAGE_EXECUTE_READWRITE);
272     if (!protectBefore.Protect() || !protectAfter.Protect()) {
273       return false;
274     }
275 
276     bool rv = WriteHook(fn, aHookDest, aOrigFunc);
277 
278     if (rv) {
279       mPatchedFns[mPatchedFnsLen] = fn;
280       mPatchedFnsLen++;
281     }
282 
283     return rv;
284   }
285 
WriteHook(byteptr_t aFn,intptr_t aHookDest,void ** aOrigFunc)286   bool WriteHook(byteptr_t aFn, intptr_t aHookDest, void** aOrigFunc)
287   {
288     // Check that the 5 bytes before aFn are NOP's or INT 3's,
289     // and that the 2 bytes after aFn are mov(edi, edi).
290     //
291     // It's safe to read aFn[-5] because we set it to PAGE_EXECUTE_READWRITE
292     // before calling WriteHook.
293 
294     for (int i = -5; i <= -1; i++) {
295       if (aFn[i] != 0x90 && aFn[i] != 0xcc) { // nop or int 3
296         return false;
297       }
298     }
299 
300     // mov edi, edi.  Yes, there are two ways to encode the same thing:
301     //
302     //   0x89ff == mov r/m, r
303     //   0x8bff == mov r, r/m
304     //
305     // where "r" is register and "r/m" is register or memory.  Windows seems to
306     // use 8bff; I include 89ff out of paranoia.
307     if ((aFn[0] != 0x8b && aFn[0] != 0x89) || aFn[1] != 0xff) {
308       return false;
309     }
310 
311     // Write a long jump into the space above the function.
312     aFn[-5] = 0xe9; // jmp
313     *((intptr_t*)(aFn - 4)) = aHookDest - (uintptr_t)(aFn); // target displacement
314 
315     // Set aOrigFunc here, because after this point, aHookDest might be called,
316     // and aHookDest might use the aOrigFunc pointer.
317     *aOrigFunc = aFn + 2;
318 
319     // Short jump up into our long jump.
320     *((uint16_t*)(aFn)) = 0xf9eb; // jmp $-5
321 
322     // I think this routine is safe without this, but it can't hurt.
323     FlushInstructionCache(GetCurrentProcess(),
324                           /* ignored */ nullptr,
325                           /* ignored */ 0);
326 
327     return true;
328   }
329 
330 private:
ResolveRedirectedAddress(const byteptr_t aOriginalFunction)331   static byteptr_t ResolveRedirectedAddress(const byteptr_t aOriginalFunction)
332   {
333     // If function entry is jmp [disp32] such as used by kernel32,
334     // we resolve redirected address from import table.
335     if (aOriginalFunction[0] == 0xff && aOriginalFunction[1] == 0x25) {
336       return (byteptr_t)(**((uint32_t**) (aOriginalFunction + 2)));
337     }
338 
339     return aOriginalFunction;
340   }
341 #else
Init(const char * aModuleName)342   void Init(const char* aModuleName)
343   {
344     // Not implemented except on x86-32.
345   }
346 
AddHook(const char * aName,intptr_t aHookDest,void ** aOrigFunc)347   bool AddHook(const char* aName, intptr_t aHookDest, void** aOrigFunc)
348   {
349     // Not implemented except on x86-32.
350     return false;
351   }
352 #endif
353 };
354 
355 class WindowsDllDetourPatcher
356 {
357   typedef unsigned char* byteptr_t;
358 public:
WindowsDllDetourPatcher()359   WindowsDllDetourPatcher()
360     : mModule(0), mHookPage(0), mMaxHooks(0), mCurHooks(0)
361   {
362   }
363 
~WindowsDllDetourPatcher()364   ~WindowsDllDetourPatcher()
365   {
366     int i;
367     byteptr_t p;
368     for (i = 0, p = mHookPage; i < mCurHooks; i++, p += kHookSize) {
369 #if defined(_M_IX86)
370       size_t nBytes = 1 + sizeof(intptr_t);
371 #elif defined(_M_X64)
372       size_t nBytes = 2 + sizeof(intptr_t);
373 #else
374 #error "Unknown processor type"
375 #endif
376       byteptr_t origBytes = (byteptr_t)DecodePointer(*((byteptr_t*)p));
377 
378       // ensure we can modify the original code
379       AutoVirtualProtect protect(origBytes, nBytes, PAGE_EXECUTE_READWRITE);
380       if (!protect.Protect()) {
381         continue;
382       }
383 
384       // Remove the hook by making the original function jump directly
385       // in the trampoline.
386       intptr_t dest = (intptr_t)(p + sizeof(void*));
387 #if defined(_M_IX86)
388       // Ensure the JMP from CreateTrampoline is where we expect it to be.
389       if (origBytes[0] != 0xE9)
390         continue;
391       *((intptr_t*)(origBytes + 1)) =
392         dest - (intptr_t)(origBytes + 5); // target displacement
393 #elif defined(_M_X64)
394       // Ensure the MOV R11 from CreateTrampoline is where we expect it to be.
395       if (origBytes[0] != 0x49 || origBytes[1] != 0xBB)
396         continue;
397       *((intptr_t*)(origBytes + 2)) = dest;
398 #else
399 #error "Unknown processor type"
400 #endif
401     }
402   }
403 
404   void Init(const char* aModuleName, int aNumHooks = 0)
405   {
406     if (mModule) {
407       return;
408     }
409 
410     mModule = LoadLibraryExA(aModuleName, nullptr, 0);
411     if (!mModule) {
412       //printf("LoadLibraryEx for '%s' failed\n", aModuleName);
413       return;
414     }
415 
416     int hooksPerPage = 4096 / kHookSize;
417     if (aNumHooks == 0) {
418       aNumHooks = hooksPerPage;
419     }
420 
421     mMaxHooks = aNumHooks + (hooksPerPage % aNumHooks);
422 
423     mHookPage = (byteptr_t)VirtualAllocEx(GetCurrentProcess(), nullptr,
424                                           mMaxHooks * kHookSize,
425                                           MEM_COMMIT | MEM_RESERVE,
426                                           PAGE_EXECUTE_READ);
427     if (!mHookPage) {
428       mModule = 0;
429       return;
430     }
431   }
432 
Initialized()433   bool Initialized() { return !!mModule; }
434 
AddHook(const char * aName,intptr_t aHookDest,void ** aOrigFunc)435   bool AddHook(const char* aName, intptr_t aHookDest, void** aOrigFunc)
436   {
437     if (!mModule) {
438       return false;
439     }
440 
441     void* pAddr = (void*)GetProcAddress(mModule, aName);
442     if (!pAddr) {
443       //printf ("GetProcAddress failed\n");
444       return false;
445     }
446 
447     pAddr = ResolveRedirectedAddress((byteptr_t)pAddr);
448 
449     CreateTrampoline(pAddr, aHookDest, aOrigFunc);
450     if (!*aOrigFunc) {
451       //printf ("CreateTrampoline failed\n");
452       return false;
453     }
454 
455     return true;
456   }
457 
458 protected:
459   const static int kPageSize = 4096;
460   const static int kHookSize = 128;
461 
462   HMODULE mModule;
463   byteptr_t mHookPage;
464   int mMaxHooks;
465   int mCurHooks;
466 
467   // rex bits
468   static const BYTE kMaskHighNibble = 0xF0;
469   static const BYTE kRexOpcode = 0x40;
470   static const BYTE kMaskRexW = 0x08;
471   static const BYTE kMaskRexR = 0x04;
472   static const BYTE kMaskRexX = 0x02;
473   static const BYTE kMaskRexB = 0x01;
474 
475   // mod r/m bits
476   static const BYTE kRegFieldShift = 3;
477   static const BYTE kMaskMod = 0xC0;
478   static const BYTE kMaskReg = 0x38;
479   static const BYTE kMaskRm = 0x07;
480   static const BYTE kRmNeedSib = 0x04;
481   static const BYTE kModReg = 0xC0;
482   static const BYTE kModDisp32 = 0x80;
483   static const BYTE kModDisp8 = 0x40;
484   static const BYTE kModNoRegDisp = 0x00;
485   static const BYTE kRmNoRegDispDisp32 = 0x05;
486 
487   // sib bits
488   static const BYTE kMaskSibScale = 0xC0;
489   static const BYTE kMaskSibIndex = 0x38;
490   static const BYTE kMaskSibBase = 0x07;
491   static const BYTE kSibBaseEbp = 0x05;
492 
493   int CountModRmSib(const BYTE *aModRm, BYTE* aSubOpcode = nullptr)
494   {
495     if (!aModRm) {
496       return -1;
497     }
498     int numBytes = 1; // Start with 1 for mod r/m byte itself
499     switch (*aModRm & kMaskMod) {
500       case kModReg:
501         return numBytes;
502       case kModDisp8:
503         numBytes += 1;
504         break;
505       case kModDisp32:
506         numBytes += 4;
507         break;
508       case kModNoRegDisp:
509         if ((*aModRm & kMaskRm) == kRmNoRegDispDisp32) {
510 #if defined(_M_X64)
511           // RIP-relative on AMD64, currently unsupported
512           return -1;
513 #else
514           // On IA-32, all ModR/M instruction modes address memory relative to 0
515           numBytes += 4;
516 #endif
517         } else if (((*aModRm & kMaskRm) == kRmNeedSib &&
518              (*(aModRm + 1) & kMaskSibBase) == kSibBaseEbp)) {
519           numBytes += 4;
520         }
521         break;
522       default:
523         // This should not be reachable
524         MOZ_ASSERT_UNREACHABLE("Impossible value for modr/m byte mod bits");
525         return -1;
526     }
527     if ((*aModRm & kMaskRm) == kRmNeedSib) {
528       // SIB byte
529       numBytes += 1;
530     }
531     if (aSubOpcode) {
532       *aSubOpcode = (*aModRm & kMaskReg) >> kRegFieldShift;
533     }
534     return numBytes;
535   }
536 
537 #if defined(_M_X64)
// To patch for JMP and JE: relative jumps copied into a trampoline are
// re-emitted as absolute 64-bit jumps, because the trampoline does not live
// at the original code's address.

// Kind of relative jump being relocated.
enum JumpType {
 Je,
 Jmp
};

struct JumpPatch {
  JumpPatch()
    : mHookOffset(0), mJumpAddress(0), mType(JumpType::Jmp)
  {
  }

  JumpPatch(size_t aOffset, intptr_t aAddress, JumpType aType = JumpType::Jmp)
    : mHookOffset(aOffset), mJumpAddress(aAddress), mType(aType)
  {
  }

  // Record that the jump found at trampoline offset aHookOffset targets the
  // absolute address aAbsJumpAddress.
  void AddJumpPatch(size_t aHookOffset, intptr_t aAbsJumpAddress,
                   JumpType aType = JumpType::Jmp)
  {
    mHookOffset = aHookOffset;
    mJumpAddress = aAbsJumpAddress;
    mType = aType;
  }

  /**
   * Write the absolute-jump replacement into aCode starting at mHookOffset.
   *
   * Emits `JMP [RIP+0]` followed by the 8-byte target (14 bytes). For a JE it
   * is preceded by a 2-byte `JNE +14` that skips over that jump.
   *
   * @return the offset just past the emitted code.
   */
  size_t GenerateJump(uint8_t* aCode)
  {
    size_t offset = mHookOffset;
    if (mType == JumpType::Je) {
      // JNE RIP+14
      aCode[offset]     = 0x75;
      aCode[offset + 1] = 14;
      offset += 2;
    }

    // JMP [RIP+0]
    aCode[offset] = 0xff;
    aCode[offset + 1] = 0x25;
    // Multi-byte fields are copied with memcpy rather than stored through a
    // reinterpret_cast pointer: aCode is a byte buffer, and the store may be
    // misaligned.
    const int32_t ripDisplacement = 0;
    memcpy(aCode + offset + 2, &ripDisplacement, sizeof(ripDisplacement));

    // Jump table
    const int64_t target = mJumpAddress;
    memcpy(aCode + offset + 2 + 4, &target, sizeof(target));

    return offset + 2 + 4 + 8;
  }

  bool HasJumpPatch() const
  {
    return !!mJumpAddress;
  }

  size_t mHookOffset;   // where in the trampoline the jump is emitted
  intptr_t mJumpAddress; // absolute target; 0 means "no jump recorded"
  JumpType mType;
};
594 
595 #endif
596 
// Bit flags recording which legacy-prefix groups CountPrefixBytes has seen.
enum ePrefixGroupBits
{
  eNoPrefixes   = 0x0,
  ePrefixGroup1 = 0x1, // LOCK / REP family
  ePrefixGroup2 = 0x2, // segment overrides / branch hints
  ePrefixGroup3 = 0x4, // operand-size override
  ePrefixGroup4 = 0x8  // address-size override
};
605 
CountPrefixBytes(byteptr_t aBytes,const int aBytesIndex,unsigned char * aOutGroupBits)606   int CountPrefixBytes(byteptr_t aBytes, const int aBytesIndex,
607                        unsigned char* aOutGroupBits)
608   {
609     unsigned char& groupBits = *aOutGroupBits;
610     groupBits = eNoPrefixes;
611     int index = aBytesIndex;
612     while (true) {
613       switch (aBytes[index]) {
614         // Group 1
615         case 0xF0: // LOCK
616         case 0xF2: // REPNZ
617         case 0xF3: // REP / REPZ
618           if (groupBits & ePrefixGroup1) {
619             return -1;
620           }
621           groupBits |= ePrefixGroup1;
622           ++index;
623           break;
624 
625         // Group 2
626         case 0x2E: // CS override / branch not taken
627         case 0x36: // SS override
628         case 0x3E: // DS override / branch taken
629         case 0x64: // FS override
630         case 0x65: // GS override
631           if (groupBits & ePrefixGroup2) {
632             return -1;
633           }
634           groupBits |= ePrefixGroup2;
635           ++index;
636           break;
637 
638         // Group 3
639         case 0x66: // operand size override
640           if (groupBits & ePrefixGroup3) {
641             return -1;
642           }
643           groupBits |= ePrefixGroup3;
644           ++index;
645           break;
646 
647         // Group 4
648         case 0x67: // Address size override
649           if (groupBits & ePrefixGroup4) {
650             return -1;
651           }
652           groupBits |= ePrefixGroup4;
653           ++index;
654           break;
655 
656         default:
657           return index - aBytesIndex;
658       }
659     }
660   }
661 
CreateTrampoline(void * aOrigFunction,intptr_t aDest,void ** aOutTramp)662   void CreateTrampoline(void* aOrigFunction, intptr_t aDest, void** aOutTramp)
663   {
664     *aOutTramp = nullptr;
665 
666     AutoVirtualProtect protectHookPage(mHookPage, mMaxHooks * kHookSize,
667                                        PAGE_EXECUTE_READWRITE);
668     if (!protectHookPage.Protect()) {
669       return;
670     }
671 
672     byteptr_t tramp = FindTrampolineSpace();
673     if (!tramp) {
674       return;
675     }
676 
677     byteptr_t origBytes = (byteptr_t)aOrigFunction;
678 
679     int nBytes = 0;
680 
681 #if defined(_M_IX86)
682     int pJmp32 = -1;
683     while (nBytes < 5) {
684       // Understand some simple instructions that might be found in a
685       // prologue; we might need to extend this as necessary.
686       //
687       // Note!  If we ever need to understand jump instructions, we'll
688       // need to rewrite the displacement argument.
689       unsigned char prefixGroups;
690       int numPrefixBytes = CountPrefixBytes(origBytes, nBytes, &prefixGroups);
691       if (numPrefixBytes < 0 || (prefixGroups & (ePrefixGroup3 | ePrefixGroup4))) {
692         // Either the prefix sequence was bad, or there are prefixes that
693         // we don't currently support (groups 3 and 4)
694         return;
695       }
696       nBytes += numPrefixBytes;
697       if (origBytes[nBytes] >= 0x88 && origBytes[nBytes] <= 0x8B) {
698         // various MOVs
699         ++nBytes;
700         int len = CountModRmSib(origBytes + nBytes);
701         if (len < 0) {
702           return;
703         }
704         nBytes += len;
705       } else if (origBytes[nBytes] == 0xA1) {
706         // MOV eax, [seg:offset]
707         nBytes += 5;
708       } else if (origBytes[nBytes] == 0xB8) {
709         // MOV 0xB8: http://ref.x86asm.net/coder32.html#xB8
710         nBytes += 5;
711       } else if (origBytes[nBytes] == 0x83) {
712         // ADD|ODR|ADC|SBB|AND|SUB|XOR|CMP r/m, imm8
713         unsigned char b = origBytes[nBytes + 1];
714         if ((b & 0xc0) == 0xc0) {
715           // ADD|ODR|ADC|SBB|AND|SUB|XOR|CMP r, imm8
716           nBytes += 3;
717         } else {
718           // bail
719           return;
720         }
721       } else if (origBytes[nBytes] == 0x68) {
722         // PUSH with 4-byte operand
723         nBytes += 5;
724       } else if ((origBytes[nBytes] & 0xf0) == 0x50) {
725         // 1-byte PUSH/POP
726         nBytes++;
727       } else if (origBytes[nBytes] == 0x6A) {
728         // PUSH imm8
729         nBytes += 2;
730       } else if (origBytes[nBytes] == 0xe9) {
731         pJmp32 = nBytes;
732         // jmp 32bit offset
733         nBytes += 5;
734       } else if (origBytes[nBytes] == 0xff && origBytes[nBytes + 1] == 0x25) {
735         // jmp [disp32]
736         nBytes += 6;
737       } else {
738         //printf ("Unknown x86 instruction byte 0x%02x, aborting trampoline\n", origBytes[nBytes]);
739         return;
740       }
741     }
742 #elif defined(_M_X64)
743     JumpPatch jump;
744 
745     while (nBytes < 13) {
746 
747       // if found JMP 32bit offset, next bytes must be NOP or INT3
748       if (jump.HasJumpPatch()) {
749         if (origBytes[nBytes] == 0x90 || origBytes[nBytes] == 0xcc) {
750           nBytes++;
751           continue;
752         }
753         return;
754       }
755       if (origBytes[nBytes] == 0x0f) {
756         nBytes++;
757         if (origBytes[nBytes] == 0x1f) {
758           // nop (multibyte)
759           nBytes++;
760           if ((origBytes[nBytes] & 0xc0) == 0x40 &&
761               (origBytes[nBytes] & 0x7) == 0x04) {
762             nBytes += 3;
763           } else {
764             return;
765           }
766         } else if (origBytes[nBytes] == 0x05) {
767           // syscall
768           nBytes++;
769         } else if (origBytes[nBytes] == 0x84) {
770           // je rel32
771           jump.AddJumpPatch(nBytes - 1,
772                             (intptr_t)
773                               origBytes + nBytes + 5 +
774                             *(reinterpret_cast<int32_t*>(origBytes +
775                                                          nBytes + 1)),
776                             JumpType::Je);
777           nBytes += 5;
778         } else {
779           return;
780         }
781       } else if (origBytes[nBytes] == 0x40 ||
782                  origBytes[nBytes] == 0x41) {
783         // Plain REX or REX.B
784         nBytes++;
785 
786         if ((origBytes[nBytes] & 0xf0) == 0x50) {
787           // push/pop with Rx register
788           nBytes++;
789         } else if (origBytes[nBytes] >= 0xb8 && origBytes[nBytes] <= 0xbf) {
790           // mov r32, imm32
791           nBytes += 5;
792         } else {
793           return;
794         }
795       } else if (origBytes[nBytes] == 0x45) {
796         // REX.R & REX.B
797         nBytes++;
798 
799         if (origBytes[nBytes] == 0x33) {
800           // xor r32, r32
801           nBytes += 2;
802         } else {
803           return;
804         }
805       } else if ((origBytes[nBytes] & 0xfb) == 0x48) {
806         // REX.W | REX.WR
807         nBytes++;
808 
809         if (origBytes[nBytes] == 0x81 &&
810             (origBytes[nBytes + 1] & 0xf8) == 0xe8) {
811           // sub r, dword
812           nBytes += 6;
813         } else if (origBytes[nBytes] == 0x83 &&
814                    (origBytes[nBytes + 1] & 0xf8) == 0xe8) {
815           // sub r, byte
816           nBytes += 3;
817         } else if (origBytes[nBytes] == 0x83 &&
818                    (origBytes[nBytes + 1] & 0xf8) == 0x60) {
819           // and [r+d], imm8
820           nBytes += 5;
821         } else if (origBytes[nBytes] == 0x85) {
822           // 85 /r => TEST r/m32, r32
823           if ((origBytes[nBytes + 1] & 0xc0) == 0xc0) {
824             nBytes += 2;
825           } else {
826             return;
827           }
828         } else if ((origBytes[nBytes] & 0xfd) == 0x89) {
829           ++nBytes;
830           // MOV r/m64, r64 | MOV r64, r/m64
831           int len = CountModRmSib(origBytes + nBytes);
832           if (len < 0) {
833             return;
834           }
835           nBytes += len;
836         } else if (origBytes[nBytes] == 0xc7) {
837           // MOV r/m64, imm32
838           if (origBytes[nBytes + 1] == 0x44) {
839             // MOV [r64+disp8], imm32
840             // ModR/W + SIB + disp8 + imm32
841             nBytes += 8;
842           } else {
843             return;
844           }
845         } else if (origBytes[nBytes] == 0xff) {
846           // JMP /4
847           if ((origBytes[nBytes + 1] & 0xc0) == 0x0 &&
848               (origBytes[nBytes + 1] & 0x07) == 0x5) {
849             // [rip+disp32]
850             // convert JMP 32bit offset to JMP 64bit direct
851             jump.AddJumpPatch(nBytes - 1,
852                               *reinterpret_cast<intptr_t*>(
853                                 origBytes + nBytes + 6 +
854                               *reinterpret_cast<int32_t*>(origBytes + nBytes +
855                                                           2)));
856             nBytes += 6;
857           } else {
858             // not support yet!
859             return;
860           }
861         } else {
862           // not support yet!
863           return;
864         }
865       } else if (origBytes[nBytes] == 0x66) {
866         // operand override prefix
867         nBytes += 1;
868         // This is the same as the x86 version
869         if (origBytes[nBytes] >= 0x88 && origBytes[nBytes] <= 0x8B) {
870           // various MOVs
871           unsigned char b = origBytes[nBytes + 1];
872           if (((b & 0xc0) == 0xc0) ||
873               (((b & 0xc0) == 0x00) &&
874                ((b & 0x07) != 0x04) && ((b & 0x07) != 0x05))) {
875             // REG=r, R/M=r or REG=r, R/M=[r]
876             nBytes += 2;
877           } else if ((b & 0xc0) == 0x40) {
878             if ((b & 0x07) == 0x04) {
879               // REG=r, R/M=[SIB + disp8]
880               nBytes += 4;
881             } else {
882               // REG=r, R/M=[r + disp8]
883               nBytes += 3;
884             }
885           } else {
886             // complex MOV, bail
887             return;
888           }
889         }
890       } else if ((origBytes[nBytes] & 0xf0) == 0x50) {
891         // 1-byte push/pop
892         nBytes++;
893       } else if (origBytes[nBytes] == 0x65) {
894         // GS prefix
895         //
896         // The entry of GetKeyState on Windows 10 has the following code.
897         // 65 48 8b 04 25 30 00 00 00    mov   rax,qword ptr gs:[30h]
898         // (GS prefix + REX + MOV (0x8b) ...)
899         if (origBytes[nBytes + 1] == 0x48 &&
900             (origBytes[nBytes + 2] >= 0x88 && origBytes[nBytes + 2] <= 0x8b)) {
901           nBytes += 3;
902           int len = CountModRmSib(origBytes + nBytes);
903           if (len < 0) {
904             // no way to support this yet.
905             return;
906           }
907           nBytes += len;
908         } else {
909           return;
910         }
911       } else if (origBytes[nBytes] == 0x90) {
912         // nop
913         nBytes++;
914       } else if (origBytes[nBytes] == 0xb8) {
915         // MOV 0xB8: http://ref.x86asm.net/coder32.html#xB8
916         nBytes += 5;
917       } else if (origBytes[nBytes] == 0x33) {
918         // xor r32, r/m32
919         nBytes += 2;
920       } else if (origBytes[nBytes] == 0xf6) {
921         // test r/m8, imm8 (used by ntdll on Windows 10 x64)
922         // (no flags are affected by near jmp since there is no task switch,
923         // so it is ok for a jmp to be written immediately after a test)
924         BYTE subOpcode = 0;
925         int nModRmSibBytes = CountModRmSib(&origBytes[nBytes + 1], &subOpcode);
926         if (nModRmSibBytes < 0 || subOpcode != 0) {
927           // Unsupported
928           return;
929         }
930         nBytes += 2 + nModRmSibBytes;
931       } else if (origBytes[nBytes] == 0xc3) {
932         // ret
933         nBytes++;
934       } else if (origBytes[nBytes] == 0xcc) {
935         // int 3
936         nBytes++;
937       } else if (origBytes[nBytes] == 0xe9) {
938         // jmp 32bit offset
939         jump.AddJumpPatch(nBytes,
940                           // convert JMP 32bit offset to JMP 64bit direct
941                           (intptr_t)
942                             origBytes + nBytes + 5 +
943                           *(reinterpret_cast<int32_t*>(origBytes + nBytes + 1)));
944         nBytes += 5;
945       } else if (origBytes[nBytes] == 0xff) {
946         nBytes++;
947         if ((origBytes[nBytes] & 0xf8) == 0xf0) {
948           // push r64
949           nBytes++;
950         } else {
951           return;
952         }
953       } else {
954         return;
955       }
956     }
957 #else
958 #error "Unknown processor type"
959 #endif
960 
961     if (nBytes > 100) {
962       //printf ("Too big!");
963       return;
964     }
965 
966     // We keep the address of the original function in the first bytes of
967     // the trampoline buffer
968     *((void**)tramp) = EncodePointer(aOrigFunction);
969     tramp += sizeof(void*);
970 
971     memcpy(tramp, aOrigFunction, nBytes);
972 
973     // OrigFunction+N, the target of the trampoline
974     byteptr_t trampDest = origBytes + nBytes;
975 
976 #if defined(_M_IX86)
977     if (pJmp32 >= 0) {
978       // Jump directly to the original target of the jump instead of jumping to the
979       // original function.
980       // Adjust jump target displacement to jump location in the trampoline.
981       *((intptr_t*)(tramp + pJmp32 + 1)) += origBytes - tramp;
982     } else {
983       tramp[nBytes] = 0xE9; // jmp
984       *((intptr_t*)(tramp + nBytes + 1)) =
985         (intptr_t)trampDest - (intptr_t)(tramp + nBytes + 5); // target displacement
986     }
987 #elif defined(_M_X64)
988     // If JMP/JE opcode found, we don't insert to trampoline jump
989     if (jump.HasJumpPatch()) {
990       size_t offset = jump.GenerateJump(tramp);
991       if (jump.mType != JumpType::Jmp) {
992         JumpPatch patch(offset, reinterpret_cast<intptr_t>(trampDest));
993         patch.GenerateJump(tramp);
994       }
995     } else {
996       JumpPatch patch(nBytes, reinterpret_cast<intptr_t>(trampDest));
997       patch.GenerateJump(tramp);
998     }
999 #endif
1000 
1001     // The trampoline is now valid.
1002     *aOutTramp = tramp;
1003 
1004     // ensure we can modify the original code
1005     AutoVirtualProtect protect(aOrigFunction, nBytes, PAGE_EXECUTE_READWRITE);
1006     if (!protect.Protect()) {
1007       return;
1008     }
1009 
1010 #if defined(_M_IX86)
1011     // now modify the original bytes
1012     origBytes[0] = 0xE9; // jmp
1013     *((intptr_t*)(origBytes + 1)) =
1014       aDest - (intptr_t)(origBytes + 5); // target displacement
1015 #elif defined(_M_X64)
1016     // mov r11, address
1017     origBytes[0] = 0x49;
1018     origBytes[1] = 0xbb;
1019 
1020     *((intptr_t*)(origBytes + 2)) = aDest;
1021 
1022     // jmp r11
1023     origBytes[10] = 0x41;
1024     origBytes[11] = 0xff;
1025     origBytes[12] = 0xe3;
1026 #endif
1027   }
1028 
FindTrampolineSpace()1029   byteptr_t FindTrampolineSpace()
1030   {
1031     if (mCurHooks >= mMaxHooks) {
1032       return 0;
1033     }
1034 
1035     byteptr_t p = mHookPage + mCurHooks * kHookSize;
1036 
1037     mCurHooks++;
1038 
1039     return p;
1040   }
1041 
ResolveRedirectedAddress(const byteptr_t aOriginalFunction)1042   static void* ResolveRedirectedAddress(const byteptr_t aOriginalFunction)
1043   {
1044 #if defined(_M_IX86)
1045     // If function entry is jmp [disp32] such as used by kernel32,
1046     // we resolve redirected address from import table.
1047     if (aOriginalFunction[0] == 0xff && aOriginalFunction[1] == 0x25) {
1048       return (void*)(**((uint32_t**) (aOriginalFunction + 2)));
1049     }
1050 #elif defined(_M_X64)
1051     if (aOriginalFunction[0] == 0xe9) {
1052       // require for TestDllInterceptor with --disable-optimize
1053       int32_t offset = *((int32_t*)(aOriginalFunction + 1));
1054       return aOriginalFunction + 5 + offset;
1055     }
1056 #endif
1057 
1058     return aOriginalFunction;
1059   }
1060 };
1061 
1062 } // namespace internal
1063 
1064 class WindowsDllInterceptor
1065 {
1066   internal::WindowsDllNopSpacePatcher mNopSpacePatcher;
1067   internal::WindowsDllDetourPatcher mDetourPatcher;
1068 
1069   const char* mModuleName;
1070   int mNHooks;
1071 
1072 public:
WindowsDllInterceptor()1073   WindowsDllInterceptor()
1074     : mModuleName(nullptr)
1075     , mNHooks(0)
1076   {}
1077 
1078   void Init(const char* aModuleName, int aNumHooks = 0)
1079   {
1080     if (mModuleName) {
1081       return;
1082     }
1083 
1084     mModuleName = aModuleName;
1085     mNHooks = aNumHooks;
1086     mNopSpacePatcher.Init(aModuleName);
1087 
1088     // Lazily initialize mDetourPatcher, since it allocates memory and we might
1089     // not need it.
1090   }
1091 
AddHook(const char * aName,intptr_t aHookDest,void ** aOrigFunc)1092   bool AddHook(const char* aName, intptr_t aHookDest, void** aOrigFunc)
1093   {
1094     // Use a nop space patch if possible, otherwise fall back to a detour.
1095     // This should be the preferred method for adding hooks.
1096 
1097     if (!mModuleName) {
1098       return false;
1099     }
1100 
1101     if (mNopSpacePatcher.AddHook(aName, aHookDest, aOrigFunc)) {
1102       return true;
1103     }
1104 
1105     return AddDetour(aName, aHookDest, aOrigFunc);
1106   }
1107 
AddDetour(const char * aName,intptr_t aHookDest,void ** aOrigFunc)1108   bool AddDetour(const char* aName, intptr_t aHookDest, void** aOrigFunc)
1109   {
1110     // Generally, code should not call this method directly. Use AddHook unless
1111     // there is a specific need to avoid nop space patches.
1112 
1113     if (!mModuleName) {
1114       return false;
1115     }
1116 
1117     if (!mDetourPatcher.Initialized()) {
1118       mDetourPatcher.Init(mModuleName, mNHooks);
1119     }
1120 
1121     return mDetourPatcher.AddHook(aName, aHookDest, aOrigFunc);
1122   }
1123 };
1124 
1125 } // namespace mozilla
1126 
1127 #endif /* NS_WINDOWS_DLL_INTERCEPTOR_H_ */
1128