1 /* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
2 /* vim: set ts=8 sts=2 et sw=2 tw=80: */
3 /* This Source Code Form is subject to the terms of the Mozilla Public
4  * License, v. 2.0. If a copy of the MPL was not distributed with this
5  * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
6 
7 #ifndef NS_WINDOWS_DLL_INTERCEPTOR_H_
8 #define NS_WINDOWS_DLL_INTERCEPTOR_H_
9 
10 #include "mozilla/Assertions.h"
11 #include "mozilla/ArrayUtils.h"
12 #include "mozilla/UniquePtr.h"
13 #include "nsWindowsHelpers.h"
14 
15 #include <wchar.h>
16 #include <windows.h>
17 #include <winternl.h>
18 
19 /*
20  * Simple function interception.
21  *
22  * We have two separate mechanisms for intercepting a function: We can use the
23  * built-in nop space, if it exists, or we can create a detour.
24  *
25  * Using the built-in nop space works as follows: On x86-32, DLL functions
 * begin with a two-byte nop (mov edi, edi) and are preceded by five bytes of
27  * NOP instructions.
28  *
29  * When we detect a function with this prelude, we do the following:
30  *
31  * 1. Write a long jump to our interceptor function into the five bytes of NOPs
32  *    before the function.
33  *
 * 2. Write a short jump -5 into the two-byte nop at the beginning of the
 *    function.
36  *
37  * This mechanism is nice because it's thread-safe.  It's even safe to do if
38  * another thread is currently running the function we're modifying!
39  *
40  * When the WindowsDllNopSpacePatcher is destroyed, we overwrite the short jump
41  * but not the long jump, so re-intercepting the same function won't work,
42  * because its prelude won't match.
43  *
44  *
45  * Unfortunately nop space patching doesn't work on functions which don't have
46  * this magic prelude (and in particular, x86-64 never has the prelude).  So
47  * when we can't use the built-in nop space, we fall back to using a detour,
48  * which works as follows:
49  *
50  * 1. Save first N bytes of OrigFunction to trampoline, where N is a
51  *    number of bytes >= 5 that are instruction aligned.
52  *
53  * 2. Replace first 5 bytes of OrigFunction with a jump to the Hook
54  *    function.
55  *
56  * 3. After N bytes of the trampoline, add a jump to OrigFunction+N to
57  *    continue original program flow.
58  *
59  * 4. Hook function needs to call the trampoline during its execution,
60  *    to invoke the original function (so address of trampoline is
61  *    returned).
62  *
63  * When the WindowsDllDetourPatcher object is destructed, OrigFunction is
64  * patched again to jump directly to the trampoline instead of going through
65  * the hook function. As such, re-intercepting the same function won't work, as
66  * jump instructions are not supported.
67  *
68  * Note that this is not thread-safe.  Sad day.
69  *
70  */
71 
72 #include <stdint.h>
73 
// Copies NBYTES bytes of the original function, starting at
// origBytes[nOrigBytes], into the trampoline at tramp[nTrampBytes], then
// advances both cursors by NBYTES. Expects locals named `tramp`, `origBytes`,
// `nOrigBytes` and `nTrampBytes` to be in scope at the expansion site.
// NBYTES is evaluated three times, so pass a side-effect-free expression.
#define COPY_CODES(NBYTES)                                       \
  do {                                                           \
    memcpy(&tramp[nTrampBytes], &origBytes[nOrigBytes], NBYTES); \
    nOrigBytes += NBYTES;                                        \
    nTrampBytes += NBYTES;                                       \
  } while (0)
80 
81 namespace mozilla {
82 namespace internal {
83 
84 class AutoVirtualProtect {
85  public:
AutoVirtualProtect(void * aFunc,size_t aSize,DWORD aProtect)86   AutoVirtualProtect(void* aFunc, size_t aSize, DWORD aProtect)
87       : mFunc(aFunc),
88         mSize(aSize),
89         mNewProtect(aProtect),
90         mOldProtect(0),
91         mSuccess(false) {}
92 
~AutoVirtualProtect()93   ~AutoVirtualProtect() {
94     if (mSuccess) {
95       VirtualProtectEx(GetCurrentProcess(), mFunc, mSize, mOldProtect,
96                        &mOldProtect);
97     }
98   }
99 
Protect()100   bool Protect() {
101     mSuccess = !!VirtualProtectEx(GetCurrentProcess(), mFunc, mSize,
102                                   mNewProtect, &mOldProtect);
103     if (!mSuccess) {
104       // printf("VirtualProtectEx failed! %d\n", GetLastError());
105     }
106     return mSuccess;
107   }
108 
109  private:
110   void* const mFunc;
111   size_t const mSize;
112   DWORD const mNewProtect;
113   DWORD mOldProtect;
114   bool mSuccess;
115 };
116 
117 class WindowsDllNopSpacePatcher {
118   typedef uint8_t* byteptr_t;
119   HMODULE mModule;
120 
121   // Dumb array for remembering the addresses of functions we've patched.
122   // (This should be nsTArray, but non-XPCOM code uses this class.)
123   static const size_t maxPatchedFns = 16;
124   byteptr_t mPatchedFns[maxPatchedFns];
125   size_t mPatchedFnsLen;
126 
127  public:
WindowsDllNopSpacePatcher()128   WindowsDllNopSpacePatcher() : mModule(0), mPatchedFnsLen(0) {}
129 
130 #if defined(_M_IX86)
~WindowsDllNopSpacePatcher()131   ~WindowsDllNopSpacePatcher() {
132     // Restore the mov edi, edi to the beginning of each function we patched.
133 
134     for (size_t i = 0; i < mPatchedFnsLen; i++) {
135       byteptr_t fn = mPatchedFns[i];
136 
137       // Ensure we can write to the code.
138       AutoVirtualProtect protect(fn, 2, PAGE_EXECUTE_READWRITE);
139       if (!protect.Protect()) {
140         continue;
141       }
142 
143       // mov edi, edi
144       *((uint16_t*)fn) = 0xff8b;
145 
146       // I don't think this is actually necessary, but it can't hurt.
147       FlushInstructionCache(GetCurrentProcess(),
148                             /* ignored */ nullptr,
149                             /* ignored */ 0);
150     }
151   }
152 
Init(const char * aModuleName)153   void Init(const char* aModuleName) {
154     if (!IsCompatible()) {
155 #if defined(MOZILLA_INTERNAL_API)
156       NS_WARNING("NOP space patching is unavailable for compatibility reasons");
157 #endif
158       return;
159     }
160 
161     mModule = LoadLibraryExA(aModuleName, nullptr, 0);
162     if (!mModule) {
163       // printf("LoadLibraryEx for '%s' failed\n", aModuleName);
164       return;
165     }
166   }
167 
168   /**
169    * NVIDIA Optimus drivers utilize Microsoft Detours 2.x to patch functions
170    * in our address space. There is a bug in Detours 2.x that causes it to
171    * patch at the wrong address when attempting to detour code that is already
172    * NOP space patched. This function is an effort to detect the presence of
173    * this NVIDIA code in our address space and disable NOP space patching if it
174    * is. We also check AppInit_DLLs since this is the mechanism that the Optimus
175    * drivers use to inject into our process.
176    */
IsCompatible()177   static bool IsCompatible() {
178     // These DLLs are known to have bad interactions with this style of patching
179     const wchar_t* kIncompatibleDLLs[] = {L"detoured.dll", L"_etoured.dll",
180                                           L"nvd3d9wrap.dll", L"nvdxgiwrap.dll"};
181     // See if the infringing DLLs are already loaded
182     for (unsigned int i = 0; i < mozilla::ArrayLength(kIncompatibleDLLs); ++i) {
183       if (GetModuleHandleW(kIncompatibleDLLs[i])) {
184         return false;
185       }
186     }
187     if (GetModuleHandleW(L"user32.dll")) {
188       // user32 is loaded but the infringing DLLs are not, assume we're safe to
189       // proceed.
190       return true;
191     }
192     // If user32 has not loaded yet, check AppInit_DLLs to ensure that Optimus
193     // won't be loaded once user32 is initialized.
194     HKEY hkey = NULL;
195     if (!RegOpenKeyExW(
196             HKEY_LOCAL_MACHINE,
197             L"SOFTWARE\\Microsoft\\Windows NT\\CurrentVersion\\Windows", 0,
198             KEY_QUERY_VALUE, &hkey)) {
199       nsAutoRegKey key(hkey);
200       DWORD numBytes = 0;
201       const wchar_t kAppInitDLLs[] = L"AppInit_DLLs";
202       // Query for required buffer size
203       LONG status = RegQueryValueExW(hkey, kAppInitDLLs, nullptr, nullptr,
204                                      nullptr, &numBytes);
205       mozilla::UniquePtr<wchar_t[]> data;
206       if (!status) {
207         // Allocate the buffer and query for the actual data
208         data = mozilla::MakeUnique<wchar_t[]>(numBytes / sizeof(wchar_t));
209         status = RegQueryValueExW(hkey, kAppInitDLLs, nullptr, nullptr,
210                                   (LPBYTE)data.get(), &numBytes);
211       }
212       if (!status) {
213         // For each token, split up the filename components and then check the
214         // name of the file.
215         const wchar_t kDelimiters[] = L", ";
216         wchar_t* tokenContext = nullptr;
217         wchar_t* token = wcstok_s(data.get(), kDelimiters, &tokenContext);
218         while (token) {
219           wchar_t fname[_MAX_FNAME] = {0};
220           if (!_wsplitpath_s(token, nullptr, 0, nullptr, 0, fname,
221                              mozilla::ArrayLength(fname), nullptr, 0)) {
222             // nvinit.dll is responsible for bootstrapping the DLL injection, so
223             // that is the library that we check for here
224             const wchar_t kNvInitName[] = L"nvinit";
225             if (!_wcsnicmp(fname, kNvInitName,
226                            mozilla::ArrayLength(kNvInitName))) {
227               return false;
228             }
229           }
230           token = wcstok_s(nullptr, kDelimiters, &tokenContext);
231         }
232       }
233     }
234     return true;
235   }
236 
AddHook(const char * aName,intptr_t aHookDest,void ** aOrigFunc)237   bool AddHook(const char* aName, intptr_t aHookDest, void** aOrigFunc) {
238     if (!mModule) {
239       return false;
240     }
241 
242     if (!IsCompatible()) {
243 #if defined(MOZILLA_INTERNAL_API)
244       NS_WARNING("NOP space patching is unavailable for compatibility reasons");
245 #endif
246       return false;
247     }
248 
249     MOZ_RELEASE_ASSERT(mPatchedFnsLen < maxPatchedFns, "No room for the hook");
250 
251     byteptr_t fn = reinterpret_cast<byteptr_t>(GetProcAddress(mModule, aName));
252     if (!fn) {
253       // printf ("GetProcAddress failed\n");
254       return false;
255     }
256 
257     fn = ResolveRedirectedAddress(fn);
258 
259     // Ensure we can read and write starting at fn - 5 (for the long jmp we're
260     // going to write) and ending at fn + 2 (for the short jmp up to the long
261     // jmp). These bytes may span two pages with different protection.
262     AutoVirtualProtect protectBefore(fn - 5, 5, PAGE_EXECUTE_READWRITE);
263     AutoVirtualProtect protectAfter(fn, 2, PAGE_EXECUTE_READWRITE);
264     if (!protectBefore.Protect() || !protectAfter.Protect()) {
265       return false;
266     }
267 
268     bool rv = WriteHook(fn, aHookDest, aOrigFunc);
269 
270     if (rv) {
271       mPatchedFns[mPatchedFnsLen] = fn;
272       mPatchedFnsLen++;
273     }
274 
275     return rv;
276   }
277 
WriteHook(byteptr_t aFn,intptr_t aHookDest,void ** aOrigFunc)278   bool WriteHook(byteptr_t aFn, intptr_t aHookDest, void** aOrigFunc) {
279     // Check that the 5 bytes before aFn are NOP's or INT 3's,
280     // and that the 2 bytes after aFn are mov(edi, edi).
281     //
282     // It's safe to read aFn[-5] because we set it to PAGE_EXECUTE_READWRITE
283     // before calling WriteHook.
284 
285     for (int i = -5; i <= -1; i++) {
286       if (aFn[i] != 0x90 && aFn[i] != 0xcc) {  // nop or int 3
287         return false;
288       }
289     }
290 
291     // mov edi, edi.  Yes, there are two ways to encode the same thing:
292     //
293     //   0x89ff == mov r/m, r
294     //   0x8bff == mov r, r/m
295     //
296     // where "r" is register and "r/m" is register or memory.  Windows seems to
297     // use 8bff; I include 89ff out of paranoia.
298     if ((aFn[0] != 0x8b && aFn[0] != 0x89) || aFn[1] != 0xff) {
299       return false;
300     }
301 
302     // Write a long jump into the space above the function.
303     aFn[-5] = 0xe9;  // jmp
304     *((intptr_t*)(aFn - 4)) =
305         aHookDest - (uintptr_t)(aFn);  // target displacement
306 
307     // Set aOrigFunc here, because after this point, aHookDest might be called,
308     // and aHookDest might use the aOrigFunc pointer.
309     *aOrigFunc = aFn + 2;
310 
311     // Short jump up into our long jump.
312     *((uint16_t*)(aFn)) = 0xf9eb;  // jmp $-5
313 
314     // I think this routine is safe without this, but it can't hurt.
315     FlushInstructionCache(GetCurrentProcess(),
316                           /* ignored */ nullptr,
317                           /* ignored */ 0);
318 
319     return true;
320   }
321 
322  private:
ResolveRedirectedAddress(const byteptr_t aOriginalFunction)323   static byteptr_t ResolveRedirectedAddress(const byteptr_t aOriginalFunction) {
324     // If function entry is jmp rel8 stub to the internal implementation, we
325     // resolve redirected address from the jump target.
326     if (aOriginalFunction[0] == 0xeb) {
327       int8_t offset = (int8_t)(aOriginalFunction[1]);
328       if (offset <= 0) {
329         // Bail out for negative offset: probably already patched by some
330         // third-party code.
331         return aOriginalFunction;
332       }
333 
334       for (int8_t i = 0; i < offset; i++) {
335         if (aOriginalFunction[2 + i] != 0x90) {
336           // Bail out on insufficient nop space.
337           return aOriginalFunction;
338         }
339       }
340 
341       return aOriginalFunction + 2 + offset;
342     }
343 
344     // If function entry is jmp [disp32] such as used by kernel32,
345     // we resolve redirected address from import table.
346     if (aOriginalFunction[0] == 0xff && aOriginalFunction[1] == 0x25) {
347       return (byteptr_t)(**((uint32_t**)(aOriginalFunction + 2)));
348     }
349 
350     return aOriginalFunction;
351   }
352 #else
Init(const char * aModuleName)353   void Init(const char* aModuleName) {
354     // Not implemented except on x86-32.
355   }
356 
AddHook(const char * aName,intptr_t aHookDest,void ** aOrigFunc)357   bool AddHook(const char* aName, intptr_t aHookDest, void** aOrigFunc) {
358     // Not implemented except on x86-32.
359     return false;
360   }
361 #endif
362 };
363 
364 class WindowsDllDetourPatcher {
365   typedef unsigned char* byteptr_t;
366 
367  public:
WindowsDllDetourPatcher()368   WindowsDllDetourPatcher()
369       : mModule(0), mHookPage(0), mMaxHooks(0), mCurHooks(0) {}
370 
~WindowsDllDetourPatcher()371   ~WindowsDllDetourPatcher() {
372     int i;
373     byteptr_t p;
374     for (i = 0, p = mHookPage; i < mCurHooks; i++, p += kHookSize) {
375 #if defined(_M_IX86)
376       size_t nBytes = 1 + sizeof(intptr_t);
377 #elif defined(_M_X64)
378       size_t nBytes = 2 + sizeof(intptr_t);
379 #else
380 #error "Unknown processor type"
381 #endif
382       byteptr_t origBytes = (byteptr_t)DecodePointer(*((byteptr_t*)p));
383 
384       // ensure we can modify the original code
385       AutoVirtualProtect protect(origBytes, nBytes, PAGE_EXECUTE_READWRITE);
386       if (!protect.Protect()) {
387         continue;
388       }
389 
390       // Remove the hook by making the original function jump directly
391       // in the trampoline.
392       intptr_t dest = (intptr_t)(p + sizeof(void*));
393 #if defined(_M_IX86)
394       // Ensure the JMP from CreateTrampoline is where we expect it to be.
395       if (origBytes[0] != 0xE9) continue;
396       *((intptr_t*)(origBytes + 1)) =
397           dest - (intptr_t)(origBytes + 5);  // target displacement
398 #elif defined(_M_X64)
399       // Ensure the MOV R11 from CreateTrampoline is where we expect it to be.
400       if (origBytes[0] != 0x49 || origBytes[1] != 0xBB) continue;
401       *((intptr_t*)(origBytes + 2)) = dest;
402 #else
403 #error "Unknown processor type"
404 #endif
405     }
406   }
407 
408   void Init(const char* aModuleName, int aNumHooks = 0) {
409     if (mModule) {
410       return;
411     }
412 
413     mModule = LoadLibraryExA(aModuleName, nullptr, 0);
414     if (!mModule) {
415       // printf("LoadLibraryEx for '%s' failed\n", aModuleName);
416       return;
417     }
418 
419     int hooksPerPage = 4096 / kHookSize;
420     if (aNumHooks == 0) {
421       aNumHooks = hooksPerPage;
422     }
423 
424     mMaxHooks = aNumHooks + (hooksPerPage % aNumHooks);
425 
426     mHookPage = (byteptr_t)VirtualAllocEx(
427         GetCurrentProcess(), nullptr, mMaxHooks * kHookSize,
428         MEM_COMMIT | MEM_RESERVE, PAGE_EXECUTE_READ);
429     if (!mHookPage) {
430       mModule = 0;
431       return;
432     }
433   }
434 
Initialized()435   bool Initialized() { return !!mModule; }
436 
AddHook(const char * aName,intptr_t aHookDest,void ** aOrigFunc)437   bool AddHook(const char* aName, intptr_t aHookDest, void** aOrigFunc) {
438     if (!mModule) {
439       return false;
440     }
441 
442     void* pAddr = (void*)GetProcAddress(mModule, aName);
443     if (!pAddr) {
444       // printf ("GetProcAddress failed\n");
445       return false;
446     }
447 
448     pAddr = ResolveRedirectedAddress((byteptr_t)pAddr);
449 
450     CreateTrampoline(pAddr, aHookDest, aOrigFunc);
451     if (!*aOrigFunc) {
452       // printf ("CreateTrampoline failed\n");
453       return false;
454     }
455 
456     return true;
457   }
458 
459  protected:
  // Page size assumed when sizing the hook allocation (a fixed constant, not
  // queried from the system).
  const static int kPageSize = 4096;
  // Bytes reserved in mHookPage per hook. Each slot starts with the
  // EncodePointer()-encoded address of the original function, followed by the
  // trampoline code.
  const static int kHookSize = 128;

  // Module whose exports are hooked; loaded by Init(). Null = uninitialized.
  HMODULE mModule;
  // Executable allocation holding mMaxHooks slots of kHookSize bytes each.
  byteptr_t mHookPage;
  // Capacity of mHookPage, in hook slots.
  int mMaxHooks;
  // Number of slots in use; the destructor unhooks this many. Presumably
  // advanced when trampoline space is handed out (not visible here).
  int mCurHooks;

  // rex bits
  // The REX prefix (x86-64) is 0100WRXB: the high nibble is 0x4 and the low
  // nibble carries the W/R/X/B extension bits masked below.
  static const BYTE kMaskHighNibble = 0xF0;
  static const BYTE kRexOpcode = 0x40;
  static const BYTE kMaskRexW = 0x08;
  static const BYTE kMaskRexR = 0x04;
  static const BYTE kMaskRexX = 0x02;
  static const BYTE kMaskRexB = 0x01;

  // mod r/m bits
  // ModR/M layout: mod in bits 7-6, reg/opcode in bits 5-3, r/m in bits 2-0.
  // The kMod* values are mod settings already shifted into bits 7-6.
  static const BYTE kRegFieldShift = 3;
  static const BYTE kMaskMod = 0xC0;
  static const BYTE kMaskReg = 0x38;
  static const BYTE kMaskRm = 0x07;
  static const BYTE kRmNeedSib = 0x04;
  static const BYTE kModReg = 0xC0;
  static const BYTE kModDisp32 = 0x80;
  static const BYTE kModDisp8 = 0x40;
  static const BYTE kModNoRegDisp = 0x00;
  static const BYTE kRmNoRegDispDisp32 = 0x05;

  // sib bits
  // SIB layout: scale in bits 7-6, index in bits 5-3, base in bits 2-0.
  static const BYTE kMaskSibScale = 0xC0;
  static const BYTE kMaskSibIndex = 0x38;
  static const BYTE kMaskSibBase = 0x07;
  static const BYTE kSibBaseEbp = 0x05;

  // Register bit IDs.
  // 3-bit register numbers as they appear in the ModR/M reg and r/m fields.
  static const BYTE kRegAx = 0x0;
  static const BYTE kRegCx = 0x1;
  static const BYTE kRegDx = 0x2;
  static const BYTE kRegBx = 0x3;
  static const BYTE kRegSp = 0x4;
  static const BYTE kRegBp = 0x5;
  static const BYTE kRegSi = 0x6;
  static const BYTE kRegDi = 0x7;

  // Special ModR/M codes.  These indicate operands that cannot be simply
  // memcpy-ed.
  // Operand is a 64-bit RIP-relative address.
  static const int kModOperand64 = -2;
  // Operand is not yet handled by our trampoline.
  static const int kModUnknown = -1;
510 
511   /**
512    * Returns the number of bytes taken by the ModR/M byte, SIB (if present)
513    * and the instruction's operand.  In special cases, the special MODRM codes
514    * above are returned.
515    * aModRm points to the ModR/M byte of the instruction.
516    * On return, aSubOpcode (if present) is filled with the subopcode/register
517    * code found in the ModR/M byte.
518    */
519   int CountModRmSib(const BYTE* aModRm, BYTE* aSubOpcode = nullptr) {
520     if (!aModRm) {
521       MOZ_ASSERT(aModRm, "Missing ModRM byte");
522       return kModUnknown;
523     }
524     int numBytes = 1;  // Start with 1 for mod r/m byte itself
525     switch (*aModRm & kMaskMod) {
526       case kModReg:
527         return numBytes;
528       case kModDisp8:
529         numBytes += 1;
530         break;
531       case kModDisp32:
532         numBytes += 4;
533         break;
534       case kModNoRegDisp:
535         if ((*aModRm & kMaskRm) == kRmNoRegDispDisp32) {
536 #if defined(_M_X64)
537           if (aSubOpcode) {
538             *aSubOpcode = (*aModRm & kMaskReg) >> kRegFieldShift;
539           }
540           return kModOperand64;
541 #else
542           // On IA-32, all ModR/M instruction modes address memory relative to 0
543           numBytes += 4;
544 #endif
545         } else if (((*aModRm & kMaskRm) == kRmNeedSib &&
546                     (*(aModRm + 1) & kMaskSibBase) == kSibBaseEbp)) {
547           numBytes += 4;
548         }
549         break;
550       default:
551         // This should not be reachable
552         MOZ_ASSERT_UNREACHABLE("Impossible value for modr/m byte mod bits");
553         return kModUnknown;
554     }
555     if ((*aModRm & kMaskRm) == kRmNeedSib) {
556       // SIB byte
557       numBytes += 1;
558     }
559     if (aSubOpcode) {
560       *aSubOpcode = (*aModRm & kMaskReg) >> kRegFieldShift;
561     }
562     return numBytes;
563   }
564 
565 #if defined(_M_X64)
566   // To patch for JMP and JE
567 
568   enum JumpType { Je, Jne, Jmp, Call };
569 
570   struct JumpPatch {
JumpPatchJumpPatch571     JumpPatch() : mHookOffset(0), mJumpAddress(0), mType(JumpType::Jmp) {}
572 
573     JumpPatch(size_t aOffset, intptr_t aAddress, JumpType aType = JumpType::Jmp)
mHookOffsetJumpPatch574         : mHookOffset(aOffset), mJumpAddress(aAddress), mType(aType) {}
575 
GenerateJumpJumpPatch576     size_t GenerateJump(uint8_t* aCode) {
577       size_t offset = mHookOffset;
578       if (mType == JumpType::Je) {
579         // JNE RIP+14
580         aCode[offset] = 0x75;
581         aCode[offset + 1] = 14;
582         offset += 2;
583       } else if (mType == JumpType::Jne) {
584         // JE RIP+14
585         aCode[offset] = 0x74;
586         aCode[offset + 1] = 14;
587         offset += 2;
588       }
589 
590       // Near call/jmp, absolute indirect, address given in r/m32
591       if (mType == JumpType::Call) {
592         // CALL [RIP+0]
593         aCode[offset] = 0xff;
594         aCode[offset + 1] = 0x15;
595         // The offset to jump destination -- ie it is placed 2 bytes after the
596         // offset.
597         *reinterpret_cast<int32_t*>(aCode + offset + 2) = 2;
598         aCode[offset + 2 + 4] = 0xeb;  // JMP +8 (jump over mJumpAddress)
599         aCode[offset + 2 + 4 + 1] = 8;
600         *reinterpret_cast<int64_t*>(aCode + offset + 2 + 4 + 2) = mJumpAddress;
601         return offset + 2 + 4 + 2 + 8;
602       } else {
603         // JMP [RIP+0]
604         aCode[offset] = 0xff;
605         aCode[offset + 1] = 0x25;
606         // The offset to jump destination is 0
607         *reinterpret_cast<int32_t*>(aCode + offset + 2) = 0;
608         *reinterpret_cast<int64_t*>(aCode + offset + 2 + 4) = mJumpAddress;
609         return offset + 2 + 4 + 8;
610       }
611     }
612 
613     size_t mHookOffset;
614     intptr_t mJumpAddress;
615     JumpType mType;
616   };
617 
618 #endif
619 
620   enum ePrefixGroupBits {
621     eNoPrefixes = 0,
622     ePrefixGroup1 = (1 << 0),
623     ePrefixGroup2 = (1 << 1),
624     ePrefixGroup3 = (1 << 2),
625     ePrefixGroup4 = (1 << 3)
626   };
627 
CountPrefixBytes(byteptr_t aBytes,const int aBytesIndex,unsigned char * aOutGroupBits)628   int CountPrefixBytes(byteptr_t aBytes, const int aBytesIndex,
629                        unsigned char* aOutGroupBits) {
630     unsigned char& groupBits = *aOutGroupBits;
631     groupBits = eNoPrefixes;
632     int index = aBytesIndex;
633     while (true) {
634       switch (aBytes[index]) {
635         // Group 1
636         case 0xF0:  // LOCK
637         case 0xF2:  // REPNZ
638         case 0xF3:  // REP / REPZ
639           if (groupBits & ePrefixGroup1) {
640             return -1;
641           }
642           groupBits |= ePrefixGroup1;
643           ++index;
644           break;
645 
646         // Group 2
647         case 0x2E:  // CS override / branch not taken
648         case 0x36:  // SS override
649         case 0x3E:  // DS override / branch taken
650         case 0x64:  // FS override
651         case 0x65:  // GS override
652           if (groupBits & ePrefixGroup2) {
653             return -1;
654           }
655           groupBits |= ePrefixGroup2;
656           ++index;
657           break;
658 
659         // Group 3
660         case 0x66:  // operand size override
661           if (groupBits & ePrefixGroup3) {
662             return -1;
663           }
664           groupBits |= ePrefixGroup3;
665           ++index;
666           break;
667 
668         // Group 4
669         case 0x67:  // Address size override
670           if (groupBits & ePrefixGroup4) {
671             return -1;
672           }
673           groupBits |= ePrefixGroup4;
674           ++index;
675           break;
676 
677         default:
678           return index - aBytesIndex;
679       }
680     }
681   }
682 
683   // Return a ModR/M byte made from the 2 Mod bits, the register used for the
684   // reg bits and the register used for the R/M bits.
BuildModRmByte(BYTE aModBits,BYTE aReg,BYTE aRm)685   BYTE BuildModRmByte(BYTE aModBits, BYTE aReg, BYTE aRm) {
686     MOZ_ASSERT((aRm & kMaskRm) == aRm);
687     MOZ_ASSERT((aModBits & kMaskMod) == aModBits);
688     MOZ_ASSERT(((aReg << kRegFieldShift) & kMaskReg) ==
689                (aReg << kRegFieldShift));
690     return aModBits | (aReg << kRegFieldShift) | aRm;
691   }
692 
CreateTrampoline(void * aOrigFunction,intptr_t aDest,void ** aOutTramp)693   void CreateTrampoline(void* aOrigFunction, intptr_t aDest, void** aOutTramp) {
694     *aOutTramp = nullptr;
695 
696     AutoVirtualProtect protectHookPage(mHookPage, mMaxHooks * kHookSize,
697                                        PAGE_EXECUTE_READWRITE);
698     if (!protectHookPage.Protect()) {
699       return;
700     }
701 
702     byteptr_t tramp = FindTrampolineSpace();
703     if (!tramp) {
704       return;
705     }
706 
707     // We keep the address of the original function in the first bytes of
708     // the trampoline buffer
709     *((void**)tramp) = EncodePointer(aOrigFunction);
710     tramp += sizeof(void*);
711 
712     byteptr_t origBytes = (byteptr_t)aOrigFunction;
713 
714     // # of bytes of the original function that we can overwrite.
715     int nOrigBytes = 0;
716 
717 #if defined(_M_IX86)
718     int pJmp32 = -1;
719     while (nOrigBytes < 5) {
720       // Understand some simple instructions that might be found in a
721       // prologue; we might need to extend this as necessary.
722       //
723       // Note!  If we ever need to understand jump instructions, we'll
724       // need to rewrite the displacement argument.
725       unsigned char prefixGroups;
726       int numPrefixBytes =
727           CountPrefixBytes(origBytes, nOrigBytes, &prefixGroups);
728       if (numPrefixBytes < 0 ||
729           (prefixGroups & (ePrefixGroup3 | ePrefixGroup4))) {
730         // Either the prefix sequence was bad, or there are prefixes that
731         // we don't currently support (groups 3 and 4)
732         MOZ_ASSERT_UNREACHABLE("Unrecognized opcode sequence");
733         return;
734       }
735       nOrigBytes += numPrefixBytes;
736       if (origBytes[nOrigBytes] >= 0x88 && origBytes[nOrigBytes] <= 0x8B) {
737         // various MOVs
738         ++nOrigBytes;
739         int len = CountModRmSib(origBytes + nOrigBytes);
740         if (len < 0) {
741           MOZ_ASSERT_UNREACHABLE("Unrecognized MOV opcode sequence");
742           return;
743         }
744         nOrigBytes += len;
745       } else if (origBytes[nOrigBytes] == 0xA1) {
746         // MOV eax, [seg:offset]
747         nOrigBytes += 5;
748       } else if (origBytes[nOrigBytes] == 0xB8) {
749         // MOV 0xB8: http://ref.x86asm.net/coder32.html#xB8
750         nOrigBytes += 5;
751       } else if (origBytes[nOrigBytes] == 0x33 &&
752                  (origBytes[nOrigBytes + 1] & kMaskMod) == kModReg) {
753         // XOR r32, r32
754         nOrigBytes += 2;
755       } else if ((origBytes[nOrigBytes] & 0xf8) == 0x40) {
756         // INC r32
757         nOrigBytes += 1;
758       } else if (origBytes[nOrigBytes] == 0x83) {
759         // ADD|ODR|ADC|SBB|AND|SUB|XOR|CMP r/m, imm8
760         unsigned char b = origBytes[nOrigBytes + 1];
761         if ((b & 0xc0) == 0xc0) {
762           // ADD|ODR|ADC|SBB|AND|SUB|XOR|CMP r, imm8
763           nOrigBytes += 3;
764         } else {
765           // bail
766           MOZ_ASSERT_UNREACHABLE("Unrecognized bit opcode sequence");
767           return;
768         }
769       } else if (origBytes[nOrigBytes] == 0x68) {
770         // PUSH with 4-byte operand
771         nOrigBytes += 5;
772       } else if ((origBytes[nOrigBytes] & 0xf0) == 0x50) {
773         // 1-byte PUSH/POP
774         nOrigBytes++;
775       } else if (origBytes[nOrigBytes] == 0x6A) {
776         // PUSH imm8
777         nOrigBytes += 2;
778       } else if (origBytes[nOrigBytes] == 0xe9) {
779         pJmp32 = nOrigBytes;
780         // jmp 32bit offset
781         nOrigBytes += 5;
782       } else if (origBytes[nOrigBytes] == 0xff &&
783                  origBytes[nOrigBytes + 1] == 0x25) {
784         // jmp [disp32]
785         nOrigBytes += 6;
786       } else if (origBytes[nOrigBytes] == 0xc2) {
787       // ret imm16.  We can't handle this but it happens.  We don't ASSERT but
788       // we do fail to hook.
789 #if defined(MOZILLA_INTERNAL_API)
790         NS_WARNING("Cannot hook method -- RET opcode found");
791 #endif
792         return;
793       } else {
794         // printf ("Unknown x86 instruction byte 0x%02x, aborting trampoline\n",
795         // origBytes[nOrigBytes]);
796         MOZ_ASSERT_UNREACHABLE("Unrecognized opcode sequence");
797         return;
798       }
799     }
800 
801     // The trampoline is a copy of the instructions that we just traced,
802     // followed by a jump that we add below.
803     memcpy(tramp, aOrigFunction, nOrigBytes);
804 #elif defined(_M_X64)
805     // The number of bytes used by the trampoline.
806     int nTrampBytes = 0;
807     bool foundJmp = false;
808 
809     while (nOrigBytes < 13) {
810       // If we found JMP 32bit offset, we require that the next bytes must
811       // be NOP or INT3.  There is no reason to copy them.
812       // TODO: This used to trigger for Je as well.  Now that I allow
813       // instructions after CALL and JE, I don't think I need that.
814       // The only real value of this condition is that if code follows a JMP
815       // then its _probably_ the target of a JMP somewhere else and we
816       // will be overwriting it, which would be tragic.  This seems
817       // highly unlikely.
818       if (foundJmp) {
819         if (origBytes[nOrigBytes] == 0x90 || origBytes[nOrigBytes] == 0xcc) {
820           nOrigBytes++;
821           continue;
822         }
823         MOZ_ASSERT_UNREACHABLE("Opcode sequence includes commands after JMP");
824         return;
825       }
826       if (origBytes[nOrigBytes] == 0x0f) {
827         COPY_CODES(1);
828         if (origBytes[nOrigBytes] == 0x1f) {
829           // nop (multibyte)
830           COPY_CODES(1);
831           if ((origBytes[nOrigBytes] & 0xc0) == 0x40 &&
832               (origBytes[nOrigBytes] & 0x7) == 0x04) {
833             COPY_CODES(3);
834           } else {
835             MOZ_ASSERT_UNREACHABLE("Unrecognized opcode sequence");
836             return;
837           }
838         } else if (origBytes[nOrigBytes] == 0x05) {
839           // syscall
840           COPY_CODES(1);
841         } else if (origBytes[nOrigBytes] == 0x10 ||
842                    origBytes[nOrigBytes] == 0x11) {
843           // SSE: movups xmm, xmm/m128
844           //      movups xmm/m128, xmm
845           COPY_CODES(1);
846           int nModRmSibBytes = CountModRmSib(&origBytes[nOrigBytes]);
847           if (nModRmSibBytes < 0) {
848             MOZ_ASSERT_UNREACHABLE("Unrecognized opcode sequence");
849             return;
850           } else {
851             COPY_CODES(nModRmSibBytes);
852           }
853         } else if (origBytes[nOrigBytes] == 0x84) {
854           // je rel32
855           JumpPatch jump(nTrampBytes - 1,  // overwrite the 0x0f we copied above
856                          (intptr_t)(origBytes + nOrigBytes + 5 +
857                                     *(reinterpret_cast<int32_t*>(
858                                         origBytes + nOrigBytes + 1))),
859                          JumpType::Je);
860           nTrampBytes = jump.GenerateJump(tramp);
861           nOrigBytes += 5;
862         } else {
863           MOZ_ASSERT_UNREACHABLE("Unrecognized opcode sequence");
864           return;
865         }
866       } else if (origBytes[nOrigBytes] == 0x40 ||
867                  origBytes[nOrigBytes] == 0x41) {
868         // Plain REX or REX.B
869         COPY_CODES(1);
870         if ((origBytes[nOrigBytes] & 0xf0) == 0x50) {
871           // push/pop with Rx register
872           COPY_CODES(1);
873         } else if (origBytes[nOrigBytes] >= 0xb8 &&
874                    origBytes[nOrigBytes] <= 0xbf) {
875           // mov r32, imm32
876           COPY_CODES(5);
877         } else {
878           MOZ_ASSERT_UNREACHABLE("Unrecognized opcode sequence");
879           return;
880         }
881       } else if (origBytes[nOrigBytes] == 0x44) {
882         // REX.R
883         COPY_CODES(1);
884 
885         // TODO: Combine with the "0x89" case below in the REX.W section
886         if (origBytes[nOrigBytes] == 0x89) {
887           // mov r/m32, r32
888           COPY_CODES(1);
889           int len = CountModRmSib(origBytes + nOrigBytes);
890           if (len < 0) {
891             MOZ_ASSERT_UNREACHABLE("Unrecognized opcode sequence");
892             return;
893           }
894           COPY_CODES(len);
895         } else {
896           MOZ_ASSERT_UNREACHABLE("Unrecognized opcode sequence");
897           return;
898         }
899       } else if (origBytes[nOrigBytes] == 0x45) {
900         // REX.R & REX.B
901         COPY_CODES(1);
902 
903         if (origBytes[nOrigBytes] == 0x33) {
904           // xor r32, r32
905           COPY_CODES(2);
906         } else {
907           MOZ_ASSERT_UNREACHABLE("Unrecognized opcode sequence");
908           return;
909         }
910       } else if ((origBytes[nOrigBytes] & 0xfa) == 0x48) {
911         // REX.W | REX.WR | REX.WRB | REX.WB
912         COPY_CODES(1);
913 
914         if (origBytes[nOrigBytes] == 0x81 &&
915             (origBytes[nOrigBytes + 1] & 0xf8) == 0xe8) {
916           // sub r, dword
917           COPY_CODES(6);
918         } else if (origBytes[nOrigBytes] == 0x83 &&
919                    (origBytes[nOrigBytes + 1] & 0xf8) == 0xe8) {
920           // sub r, byte
921           COPY_CODES(3);
922         } else if (origBytes[nOrigBytes] == 0x83 &&
923                    (origBytes[nOrigBytes + 1] & (kMaskMod | kMaskReg)) ==
924                        kModReg) {
925           // add r, byte
926           COPY_CODES(3);
927         } else if (origBytes[nOrigBytes] == 0x83 &&
928                    (origBytes[nOrigBytes + 1] & 0xf8) == 0x60) {
929           // and [r+d], imm8
930           COPY_CODES(5);
931         } else if (origBytes[nOrigBytes] == 0x2b &&
932                    (origBytes[nOrigBytes + 1] & kMaskMod) == kModReg) {
933           // sub r64, r64
934           COPY_CODES(2);
935         } else if (origBytes[nOrigBytes] == 0x85) {
936           // 85 /r => TEST r/m32, r32
937           if ((origBytes[nOrigBytes + 1] & 0xc0) == 0xc0) {
938             COPY_CODES(2);
939           } else {
940             MOZ_ASSERT_UNREACHABLE("Unrecognized opcode sequence");
941             return;
942           }
943         } else if ((origBytes[nOrigBytes] & 0xfd) == 0x89) {
944           // MOV r/m64, r64 | MOV r64, r/m64
945           BYTE reg;
946           int len = CountModRmSib(origBytes + nOrigBytes + 1, &reg);
947           if (len < 0) {
948             MOZ_ASSERT(len == kModOperand64);
949             if (len != kModOperand64) {
950               return;
951             }
952             nOrigBytes += 2;  // skip the MOV and MOD R/M bytes
953 
954             // The instruction MOVs 64-bit data from a RIP-relative memory
955             // address (determined with a 32-bit offset from RIP) into a
956             // 64-bit register.
957             int64_t* absAddr = reinterpret_cast<int64_t*>(
958                 origBytes + nOrigBytes + 4 +
959                 *reinterpret_cast<int32_t*>(origBytes + nOrigBytes));
960             nOrigBytes += 4;
961 
962             if (reg == kRegAx) {
963               // Destination is RAX.  Encode instruction as MOVABS with a
964               // 64-bit absolute address as its immediate operand.
965               tramp[nTrampBytes] = 0xa1;
966               ++nTrampBytes;
967               int64_t** trampOperandPtr =
968                   reinterpret_cast<int64_t**>(tramp + nTrampBytes);
969               *trampOperandPtr = absAddr;
970               nTrampBytes += 8;
971             } else {
972               // The MOV must be done in two steps.  First, we MOVABS the
973               // absolute 64-bit address into our target register.
974               // Then, we MOV from that address into the register
975               // using register-indirect addressing.
976               tramp[nTrampBytes] = 0xb8 + reg;
977               ++nTrampBytes;
978               int64_t** trampOperandPtr =
979                   reinterpret_cast<int64_t**>(tramp + nTrampBytes);
980               *trampOperandPtr = absAddr;
981               nTrampBytes += 8;
982               tramp[nTrampBytes] = 0x48;
983               tramp[nTrampBytes + 1] = 0x8b;
984               tramp[nTrampBytes + 2] = BuildModRmByte(kModNoRegDisp, reg, reg);
985               nTrampBytes += 3;
986             }
987           } else {
988             COPY_CODES(len + 1);
989           }
990         } else if (origBytes[nOrigBytes] == 0xc7) {
991           // MOV r/m64, imm32
992           if (origBytes[nOrigBytes + 1] == 0x44) {
993             // MOV [r64+disp8], imm32
994             // ModR/W + SIB + disp8 + imm32
995             COPY_CODES(8);
996           } else {
997             MOZ_ASSERT_UNREACHABLE("Unrecognized opcode sequence");
998             return;
999           }
1000         } else if (origBytes[nOrigBytes] == 0xff) {
1001           // JMP /4
1002           if ((origBytes[nOrigBytes + 1] & 0xc0) == 0x0 &&
1003               (origBytes[nOrigBytes + 1] & 0x07) == 0x5) {
1004             // [rip+disp32]
1005             // convert JMP 32bit offset to JMP 64bit direct
1006             JumpPatch jump(
1007                 nTrampBytes - 1,  // overwrite the REX.W/REX.WR we copied above
1008                 *reinterpret_cast<intptr_t*>(
1009                     origBytes + nOrigBytes + 6 +
1010                     *reinterpret_cast<int32_t*>(origBytes + nOrigBytes + 2)),
1011                 JumpType::Jmp);
1012             nTrampBytes = jump.GenerateJump(tramp);
1013             nOrigBytes += 6;
1014             foundJmp = true;
1015           } else {
1016             // not support yet!
1017             MOZ_ASSERT_UNREACHABLE("Unrecognized opcode sequence");
1018             return;
1019           }
1020         } else if (origBytes[nOrigBytes] == 0x8d) {
1021           // LEA reg, addr
1022           if ((origBytes[nOrigBytes + 1] & kMaskMod) == 0x0 &&
1023               (origBytes[nOrigBytes + 1] & kMaskRm) == 0x5) {
1024             // [rip+disp32]
1025             // convert 32bit offset to 64bit direct and convert instruction
1026             // to a simple 64-bit mov
1027             BYTE reg = (origBytes[nOrigBytes + 1] & kMaskReg) >> kRegFieldShift;
1028             intptr_t absAddr = reinterpret_cast<intptr_t>(
1029                 origBytes + nOrigBytes + 6 +
1030                 *reinterpret_cast<int32_t*>(origBytes + nOrigBytes + 2));
1031             nOrigBytes += 6;
1032             tramp[nTrampBytes] = 0xb8 + reg;  // mov
1033             ++nTrampBytes;
1034             intptr_t* trampOperandPtr =
1035                 reinterpret_cast<intptr_t*>(tramp + nTrampBytes);
1036             *trampOperandPtr = absAddr;
1037             nTrampBytes += 8;
1038           } else {
1039             // Above we dealt with RIP-relative instructions.  Any other
1040             // operand form can simply be copied.
1041             int len = CountModRmSib(origBytes + nOrigBytes + 1);
1042             // We handled the kModOperand64 -- ie RIP-relative -- case above
1043             MOZ_ASSERT(len > 0);
1044             COPY_CODES(len + 1);
1045           }
1046         } else if (origBytes[nOrigBytes] == 0x63 &&
1047                    (origBytes[nOrigBytes + 1] & kMaskMod) == kModReg) {
1048           // movsxd r64, r32 (move + sign extend)
1049           COPY_CODES(2);
1050         } else {
1051           // not support yet!
1052           MOZ_ASSERT_UNREACHABLE("Unrecognized opcode sequence");
1053           return;
1054         }
1055       } else if (origBytes[nOrigBytes] == 0x66) {
1056         // operand override prefix
1057         COPY_CODES(1);
1058         // This is the same as the x86 version
1059         if (origBytes[nOrigBytes] >= 0x88 && origBytes[nOrigBytes] <= 0x8B) {
1060           // various MOVs
1061           unsigned char b = origBytes[nOrigBytes + 1];
1062           if (((b & 0xc0) == 0xc0) ||
1063               (((b & 0xc0) == 0x00) && ((b & 0x07) != 0x04) &&
1064                ((b & 0x07) != 0x05))) {
1065             // REG=r, R/M=r or REG=r, R/M=[r]
1066             COPY_CODES(2);
1067           } else if ((b & 0xc0) == 0x40) {
1068             if ((b & 0x07) == 0x04) {
1069               // REG=r, R/M=[SIB + disp8]
1070               COPY_CODES(4);
1071             } else {
1072               // REG=r, R/M=[r + disp8]
1073               COPY_CODES(3);
1074             }
1075           } else {
1076             // complex MOV, bail
1077             MOZ_ASSERT_UNREACHABLE("Unrecognized MOV opcode sequence");
1078             return;
1079           }
1080         } else if (origBytes[nOrigBytes] == 0x44 &&
1081                    origBytes[nOrigBytes + 1] == 0x89) {
1082           // mov word ptr [reg+disp8], reg
1083           COPY_CODES(2);
1084           int len = CountModRmSib(origBytes + nOrigBytes);
1085           if (len < 0) {
1086             // no way to support this yet.
1087             MOZ_ASSERT_UNREACHABLE("Unrecognized opcode sequence");
1088             return;
1089           }
1090           COPY_CODES(len);
1091         }
1092       } else if ((origBytes[nOrigBytes] & 0xf0) == 0x50) {
1093         // 1-byte push/pop
1094         COPY_CODES(1);
1095       } else if (origBytes[nOrigBytes] == 0x65) {
1096         // GS prefix
1097         //
1098         // The entry of GetKeyState on Windows 10 has the following code.
1099         // 65 48 8b 04 25 30 00 00 00    mov   rax,qword ptr gs:[30h]
1100         // (GS prefix + REX + MOV (0x8b) ...)
1101         if (origBytes[nOrigBytes + 1] == 0x48 &&
1102             (origBytes[nOrigBytes + 2] >= 0x88 &&
1103              origBytes[nOrigBytes + 2] <= 0x8b)) {
1104           COPY_CODES(3);
1105           int len = CountModRmSib(origBytes + nOrigBytes);
1106           if (len < 0) {
1107             // no way to support this yet.
1108             MOZ_ASSERT_UNREACHABLE("Unrecognized opcode sequence");
1109             return;
1110           }
1111           COPY_CODES(len);
1112         } else {
1113           MOZ_ASSERT_UNREACHABLE("Unrecognized opcode sequence");
1114           return;
1115         }
1116       } else if (origBytes[nOrigBytes] == 0x80 &&
1117                  origBytes[nOrigBytes + 1] == 0x3d) {
1118         // cmp byte ptr [rip-relative address], imm8
1119         // We'll compute the absolute address and do the cmp in r11
1120 
1121         // push r11 (to save the old value)
1122         tramp[nTrampBytes] = 0x49;
1123         ++nTrampBytes;
1124         tramp[nTrampBytes] = 0x53;
1125         ++nTrampBytes;
1126 
1127         byteptr_t absAddr = reinterpret_cast<byteptr_t>(
1128             origBytes + nOrigBytes + 7 +
1129             *reinterpret_cast<int32_t*>(origBytes + nOrigBytes + 2));
1130         nOrigBytes += 6;
1131 
1132         // mov r11, absolute address
1133         tramp[nTrampBytes] = 0x49;
1134         ++nTrampBytes;
1135         tramp[nTrampBytes] = 0xbb;
1136         ++nTrampBytes;
1137 
1138         *reinterpret_cast<byteptr_t*>(tramp + nTrampBytes) = absAddr;
1139         nTrampBytes += 8;
1140 
1141         // cmp byte ptr [r11],...
1142         tramp[nTrampBytes] = 0x41;
1143         ++nTrampBytes;
1144         tramp[nTrampBytes] = 0x80;
1145         ++nTrampBytes;
1146         tramp[nTrampBytes] = 0x3b;
1147         ++nTrampBytes;
1148 
1149         // ...imm8
1150         COPY_CODES(1);
1151 
1152         // pop r11 (doesn't affect the flags from the cmp)
1153         tramp[nTrampBytes] = 0x49;
1154         ++nTrampBytes;
1155         tramp[nTrampBytes] = 0x5b;
1156         ++nTrampBytes;
1157       } else if (origBytes[nOrigBytes] == 0x90) {
1158         // nop
1159         COPY_CODES(1);
1160       } else if ((origBytes[nOrigBytes] & 0xf8) == 0xb8) {
1161         // MOV r32, imm32
1162         COPY_CODES(5);
1163       } else if (origBytes[nOrigBytes] == 0x33) {
1164         // xor r32, r/m32
1165         COPY_CODES(2);
1166       } else if (origBytes[nOrigBytes] == 0xf6) {
1167         // test r/m8, imm8 (used by ntdll on Windows 10 x64)
1168         // (no flags are affected by near jmp since there is no task switch,
1169         // so it is ok for a jmp to be written immediately after a test)
1170         BYTE subOpcode = 0;
1171         int nModRmSibBytes =
1172             CountModRmSib(&origBytes[nOrigBytes + 1], &subOpcode);
1173         if (nModRmSibBytes < 0 || subOpcode != 0) {
1174           // Unsupported
1175           MOZ_ASSERT_UNREACHABLE("Unrecognized opcode sequence");
1176           return;
1177         }
1178         COPY_CODES(2 + nModRmSibBytes);
1179       } else if (origBytes[nOrigBytes] == 0x85) {
1180         // test r/m32, r32
1181         int nModRmSibBytes = CountModRmSib(&origBytes[nOrigBytes + 1]);
1182         if (nModRmSibBytes < 0) {
1183           MOZ_ASSERT_UNREACHABLE("Unrecognized opcode sequence");
1184           return;
1185         }
1186         COPY_CODES(1 + nModRmSibBytes);
1187       } else if (origBytes[nOrigBytes] == 0xd1 &&
1188                  (origBytes[nOrigBytes + 1] & kMaskMod) == kModReg) {
1189         // bit shifts/rotates : (SA|SH|RO|RC)(R|L) r32
1190         // (e.g. 0xd1 0xe0 is SAL, 0xd1 0xc8 is ROR)
1191         COPY_CODES(2);
1192       } else if (origBytes[nOrigBytes] == 0xc3) {
1193         // ret
1194         COPY_CODES(1);
1195       } else if (origBytes[nOrigBytes] == 0xcc) {
1196         // int 3
1197         COPY_CODES(1);
1198       } else if (origBytes[nOrigBytes] == 0xe8 ||
1199                  origBytes[nOrigBytes] == 0xe9) {
1200         // CALL (0xe8) or JMP (0xe9) 32bit offset
1201         foundJmp = origBytes[nOrigBytes] == 0xe9;
1202         JumpPatch jump(
1203             nTrampBytes,
1204             (intptr_t)(
1205                 origBytes + nOrigBytes + 5 +
1206                 *(reinterpret_cast<int32_t*>(origBytes + nOrigBytes + 1))),
1207             origBytes[nOrigBytes] == 0xe8 ? JumpType::Call : JumpType::Jmp);
1208         nTrampBytes = jump.GenerateJump(tramp);
1209         nOrigBytes += 5;
1210       } else if (origBytes[nOrigBytes] == 0x74 ||  // je rel8 (0x74)
1211                  origBytes[nOrigBytes] == 0x75) {  // jne rel8 (0x75)
1212         char offset = origBytes[nOrigBytes + 1];
1213         auto jumpType = JumpType::Je;
1214         if (origBytes[nOrigBytes] == 0x75) jumpType = JumpType::Jne;
1215         JumpPatch jump(nTrampBytes,
1216                        (intptr_t)(origBytes + nOrigBytes + 2 + offset),
1217                        jumpType);
1218         nTrampBytes = jump.GenerateJump(tramp);
1219         nOrigBytes += 2;
1220       } else if (origBytes[nOrigBytes] == 0xff) {
1221         if ((origBytes[nOrigBytes + 1] & (kMaskMod | kMaskReg)) == 0xf0) {
1222           // push r64
1223           COPY_CODES(2);
1224         } else if (origBytes[nOrigBytes + 1] == 0x25) {
1225           // jmp absolute indirect m32
1226           foundJmp = true;
1227           int32_t offset =
1228               *(reinterpret_cast<int32_t*>(origBytes + nOrigBytes + 2));
1229           int64_t* ptrToJmpDest =
1230               reinterpret_cast<int64_t*>(origBytes + nOrigBytes + 6 + offset);
1231           intptr_t jmpDest = static_cast<intptr_t>(*ptrToJmpDest);
1232           JumpPatch jump(nTrampBytes, jmpDest, JumpType::Jmp);
1233           nTrampBytes = jump.GenerateJump(tramp);
1234           nOrigBytes += 6;
1235         } else if ((origBytes[nOrigBytes + 1] & (kMaskMod | kMaskReg)) ==
1236                    BuildModRmByte(kModReg, 2, 0)) {
1237           // CALL reg (ff nn)
1238           COPY_CODES(2);
1239         } else {
1240           MOZ_ASSERT_UNREACHABLE("Unrecognized opcode sequence");
1241           return;
1242         }
1243       } else if (origBytes[nOrigBytes] == 0x83 &&
1244                  (origBytes[nOrigBytes + 1] & 0xf8) == 0x60) {
1245         // and [r+d], imm8
1246         COPY_CODES(5);
1247       } else if (origBytes[nOrigBytes] == 0xc6) {
1248         // mov [r+d], imm8
1249         int len = CountModRmSib(&origBytes[nOrigBytes + 1]);
1250         if (len < 0) {
1251           // RIP-relative not yet supported
1252           MOZ_ASSERT_UNREACHABLE("Unrecognized opcode sequence");
1253           return;
1254         }
1255         COPY_CODES(len + 1);
1256       } else {
1257         MOZ_ASSERT_UNREACHABLE("Unrecognized opcode sequence");
1258         return;
1259       }
1260     }
1261 #else
1262 #error "Unknown processor type"
1263 #endif
1264 
1265     if (nOrigBytes > 100) {
1266       // printf ("Too big!");
1267       return;
1268     }
1269 
1270     // target address of the final jmp instruction in the trampoline
1271     byteptr_t trampDest = origBytes + nOrigBytes;
1272 
1273 #if defined(_M_IX86)
1274     if (pJmp32 >= 0) {
1275       // Jump directly to the original target of the jump instead of jumping to
1276       // the original function. Adjust jump target displacement to jump location
1277       // in the trampoline.
1278       *((intptr_t*)(tramp + pJmp32 + 1)) += origBytes - tramp;
1279     } else {
1280       tramp[nOrigBytes] = 0xE9;  // jmp
1281       *((intptr_t*)(tramp + nOrigBytes + 1)) =
1282           (intptr_t)trampDest -
1283           (intptr_t)(tramp + nOrigBytes + 5);  // target displacement
1284     }
1285 #elif defined(_M_X64)
1286     // If the we found a Jmp, we don't need to add another instruction. However,
1287     // if we found a _conditional_ jump or a CALL (or no control operations
1288     // at all) then we still need to run the rest of aOriginalFunction.
1289     if (!foundJmp) {
1290       JumpPatch patch(nTrampBytes, reinterpret_cast<intptr_t>(trampDest));
1291       patch.GenerateJump(tramp);
1292     }
1293 #endif
1294 
1295     // The trampoline is now valid.
1296     *aOutTramp = tramp;
1297 
1298     // ensure we can modify the original code
1299     AutoVirtualProtect protect(aOrigFunction, nOrigBytes,
1300                                PAGE_EXECUTE_READWRITE);
1301     if (!protect.Protect()) {
1302       return;
1303     }
1304 
1305 #if defined(_M_IX86)
1306     // now modify the original bytes
1307     origBytes[0] = 0xE9;  // jmp
1308     *((intptr_t*)(origBytes + 1)) =
1309         aDest - (intptr_t)(origBytes + 5);  // target displacement
1310 #elif defined(_M_X64)
1311     // mov r11, address
1312     origBytes[0] = 0x49;
1313     origBytes[1] = 0xbb;
1314 
1315     *((intptr_t*)(origBytes + 2)) = aDest;
1316 
1317     // jmp r11
1318     origBytes[10] = 0x41;
1319     origBytes[11] = 0xff;
1320     origBytes[12] = 0xe3;
1321 #endif
1322   }
1323 
FindTrampolineSpace()1324   byteptr_t FindTrampolineSpace() {
1325     if (mCurHooks >= mMaxHooks) {
1326       return 0;
1327     }
1328 
1329     byteptr_t p = mHookPage + mCurHooks * kHookSize;
1330 
1331     mCurHooks++;
1332 
1333     return p;
1334   }
1335 
ResolveRedirectedAddress(const byteptr_t aOriginalFunction)1336   static void* ResolveRedirectedAddress(const byteptr_t aOriginalFunction) {
1337     // If function entry is jmp rel8 stub to the internal implementation, we
1338     // resolve redirected address from the jump target.
1339     if (aOriginalFunction[0] == 0xeb) {
1340       int8_t offset = (int8_t)(aOriginalFunction[1]);
1341       if (offset <= 0) {
1342         // Bail out for negative offset: probably already patched by some
1343         // third-party code.
1344         return aOriginalFunction;
1345       }
1346 
1347       for (int8_t i = 0; i < offset; i++) {
1348         if (aOriginalFunction[2 + i] != 0x90) {
1349           // Bail out on insufficient nop space.
1350           return aOriginalFunction;
1351         }
1352       }
1353 
1354       return aOriginalFunction + 2 + offset;
1355     }
1356 
1357 #if defined(_M_IX86)
1358     // If function entry is jmp [disp32] such as used by kernel32,
1359     // we resolve redirected address from import table.
1360     if (aOriginalFunction[0] == 0xff && aOriginalFunction[1] == 0x25) {
1361       return (void*)(**((uint32_t**)(aOriginalFunction + 2)));
1362     }
1363 #elif defined(_M_X64)
1364     if (aOriginalFunction[0] == 0xe9) {
1365       // require for TestDllInterceptor with --disable-optimize
1366       int32_t offset = *((int32_t*)(aOriginalFunction + 1));
1367       return aOriginalFunction + 5 + offset;
1368     }
1369 #endif
1370 
1371     return aOriginalFunction;
1372   }
1373 };
1374 
1375 }  // namespace internal
1376 
1377 class WindowsDllInterceptor {
1378   internal::WindowsDllNopSpacePatcher mNopSpacePatcher;
1379   internal::WindowsDllDetourPatcher mDetourPatcher;
1380 
1381   const char* mModuleName;
1382   int mNHooks;
1383 
1384  public:
1385   explicit WindowsDllInterceptor(const char* aModuleName = nullptr,
1386                                  int aNumHooks = 0)
mModuleName(nullptr)1387       : mModuleName(nullptr), mNHooks(0) {
1388     if (aModuleName) {
1389       Init(aModuleName, aNumHooks);
1390     }
1391   }
1392 
1393   void Init(const char* aModuleName, int aNumHooks = 0) {
1394     if (mModuleName) {
1395       return;
1396     }
1397 
1398     mModuleName = aModuleName;
1399     mNHooks = aNumHooks;
1400     mNopSpacePatcher.Init(aModuleName);
1401 
1402     // Lazily initialize mDetourPatcher, since it allocates memory and we might
1403     // not need it.
1404   }
1405 
1406   /**
1407    * Hook/detour the method aName from the DLL we set in Init so that it calls
1408    * aHookDest instead.  Returns the original method pointer in aOrigFunc
1409    * and returns true if successful.
1410    *
1411    * IMPORTANT: If you use this method, please add your case to the
1412    * TestDllInterceptor in order to detect future failures.  Even if this
1413    * succeeds now, updates to the hooked DLL could cause it to fail in
1414    * the future.
1415    */
AddHook(const char * aName,intptr_t aHookDest,void ** aOrigFunc)1416   bool AddHook(const char* aName, intptr_t aHookDest, void** aOrigFunc) {
1417     // Use a nop space patch if possible, otherwise fall back to a detour.
1418     // This should be the preferred method for adding hooks.
1419 
1420     if (!mModuleName) {
1421       return false;
1422     }
1423 
1424     if (mNopSpacePatcher.AddHook(aName, aHookDest, aOrigFunc)) {
1425       return true;
1426     }
1427 
1428     return AddDetour(aName, aHookDest, aOrigFunc);
1429   }
1430 
1431   /**
1432    * Detour the method aName from the DLL we set in Init so that it calls
1433    * aHookDest instead.  Returns the original method pointer in aOrigFunc
1434    * and returns true if successful.
1435    *
1436    * IMPORTANT: If you use this method, please add your case to the
1437    * TestDllInterceptor in order to detect future failures.  Even if this
1438    * succeeds now, updates to the detoured DLL could cause it to fail in
1439    * the future.
1440    */
AddDetour(const char * aName,intptr_t aHookDest,void ** aOrigFunc)1441   bool AddDetour(const char* aName, intptr_t aHookDest, void** aOrigFunc) {
1442     // Generally, code should not call this method directly. Use AddHook unless
1443     // there is a specific need to avoid nop space patches.
1444 
1445     if (!mModuleName) {
1446       return false;
1447     }
1448 
1449     if (!mDetourPatcher.Initialized()) {
1450       mDetourPatcher.Init(mModuleName, mNHooks);
1451     }
1452 
1453     return mDetourPatcher.AddHook(aName, aHookDest, aOrigFunc);
1454   }
1455 };
1456 
1457 }  // namespace mozilla
1458 
1459 #endif /* NS_WINDOWS_DLL_INTERCEPTOR_H_ */
1460