1 /* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
2 /* vim: set ts=8 sts=2 et sw=2 tw=80: */
3 /* This Source Code Form is subject to the terms of the Mozilla Public
4  * License, v. 2.0. If a copy of the MPL was not distributed with this
5  * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
6 
7 #ifndef NS_WINDOWS_DLL_INTERCEPTOR_H_
8 #define NS_WINDOWS_DLL_INTERCEPTOR_H_
9 
10 #include <wchar.h>
11 #include <windows.h>
12 #include <winternl.h>
13 
14 #include <utility>
15 
16 #include "mozilla/ArrayUtils.h"
17 #include "mozilla/Assertions.h"
18 #include "mozilla/Atomics.h"
19 #include "mozilla/Attributes.h"
20 #include "mozilla/CheckedInt.h"
21 #include "mozilla/DebugOnly.h"
22 #include "mozilla/NativeNt.h"
23 #include "mozilla/Tuple.h"
24 #include "mozilla/Types.h"
25 #include "mozilla/UniquePtr.h"
26 #include "mozilla/Vector.h"
27 #include "mozilla/interceptor/MMPolicies.h"
28 #include "mozilla/interceptor/PatcherDetour.h"
29 #include "mozilla/interceptor/PatcherNopSpace.h"
30 #include "mozilla/interceptor/VMSharingPolicies.h"
31 #include "nsWindowsHelpers.h"
32 
33 /*
34  * Simple function interception.
35  *
36  * We have two separate mechanisms for intercepting a function: We can use the
37  * built-in nop space, if it exists, or we can create a detour.
38  *
 * Using the built-in nop space works as follows: On x86-32, DLL functions
 * begin with a two-byte nop (mov edi, edi) and are preceded by five bytes of
 * NOP instructions.
42  *
43  * When we detect a function with this prelude, we do the following:
44  *
45  * 1. Write a long jump to our interceptor function into the five bytes of NOPs
46  *    before the function.
47  *
48  * 2. Write a short jump -5 into the two-byte nop at the beginning of the
49  *    function.
50  *
51  * This mechanism is nice because it's thread-safe.  It's even safe to do if
52  * another thread is currently running the function we're modifying!
53  *
54  * When the WindowsDllNopSpacePatcher is destroyed, we overwrite the short jump
55  * but not the long jump, so re-intercepting the same function won't work,
56  * because its prelude won't match.
57  *
58  *
59  * Unfortunately nop space patching doesn't work on functions which don't have
60  * this magic prelude (and in particular, x86-64 never has the prelude).  So
61  * when we can't use the built-in nop space, we fall back to using a detour,
62  * which works as follows:
63  *
64  * 1. Save first N bytes of OrigFunction to trampoline, where N is a
65  *    number of bytes >= 5 that are instruction aligned.
66  *
67  * 2. Replace first 5 bytes of OrigFunction with a jump to the Hook
68  *    function.
69  *
70  * 3. After N bytes of the trampoline, add a jump to OrigFunction+N to
71  *    continue original program flow.
72  *
73  * 4. Hook function needs to call the trampoline during its execution,
74  *    to invoke the original function (so address of trampoline is
75  *    returned).
76  *
77  * When the WindowsDllDetourPatcher object is destructed, OrigFunction is
78  * patched again to jump directly to the trampoline instead of going through
79  * the hook function. As such, re-intercepting the same function won't work, as
80  * jump instructions are not supported.
81  *
82  * Note that this is not thread-safe.  Sad day.
83  *
84  */
85 
// INTERCEPTOR_DISABLE_CFGUARD: on x86, nop-space patches return to the second
// instruction of their target. This is a deliberate violation of Control Flow
// Guard, so functions that call through such a stub disable the check.
//
// The __has_declspec_attribute probe is kept in its own nested #if. A
// preprocessor that does not provide that builtin still has to parse the whole
// controlling expression of an #if, and "__has_declspec_attribute(guard)"
// would be rejected as malformed there even though the operands to its left
// already evaluate to false.
#if defined(_M_IX86) && defined(__clang__) && \
    defined(__has_declspec_attribute)
#  if __has_declspec_attribute(guard)
#    define INTERCEPTOR_DISABLE_CFGUARD __declspec(guard(nocf))
#  endif
#endif
#if !defined(INTERCEPTOR_DISABLE_CFGUARD)
#  define INTERCEPTOR_DISABLE_CFGUARD /* nothing */
#endif
93 
94 namespace mozilla {
95 namespace interceptor {
96 
/**
 * Compile-time helper that exposes the return type of a function pointer type
 * as the nested alias |ReturnType|. Only the specializations below are
 * defined; any other type leaves the trait incomplete.
 */
template <typename FuncPtrT>
struct OriginalFunctionPtrTraits;

// Function pointers using the default calling convention.
template <typename RetT, typename... ParamTs>
struct OriginalFunctionPtrTraits<RetT (*)(ParamTs...)> {
  using ReturnType = RetT;
};

#if defined(_M_IX86)
// On 32-bit x86 the calling convention is part of the pointer type, so
// __stdcall and __fastcall pointers need specializations of their own.
template <typename RetT, typename... ParamTs>
struct OriginalFunctionPtrTraits<RetT(__stdcall*)(ParamTs...)> {
  using ReturnType = RetT;
};

template <typename RetT, typename... ParamTs>
struct OriginalFunctionPtrTraits<RetT(__fastcall*)(ParamTs...)> {
  using ReturnType = RetT;
};
#endif  // defined(_M_IX86)
116 
117 template <typename InterceptorT, typename FuncPtrT>
118 class FuncHook final {
119  public:
120   using ThisType = FuncHook<InterceptorT, FuncPtrT>;
121   using ReturnType = typename OriginalFunctionPtrTraits<FuncPtrT>::ReturnType;
122 
123   constexpr FuncHook() : mOrigFunc(nullptr), mInitOnce(INIT_ONCE_STATIC_INIT) {}
124 
125   ~FuncHook() = default;
126 
127   bool Set(InterceptorT& aInterceptor, const char* aName, FuncPtrT aHookDest) {
128     LPVOID addHookOk = nullptr;
129     InitOnceContext ctx(this, &aInterceptor, aName, aHookDest, false);
130 
131     return ::InitOnceExecuteOnce(&mInitOnce, &InitOnceCallback, &ctx,
132                                  &addHookOk) &&
133            addHookOk;
134   }
135 
136   bool SetDetour(InterceptorT& aInterceptor, const char* aName,
137                  FuncPtrT aHookDest) {
138     LPVOID addHookOk = nullptr;
139     InitOnceContext ctx(this, &aInterceptor, aName, aHookDest, true);
140 
141     return ::InitOnceExecuteOnce(&mInitOnce, &InitOnceCallback, &ctx,
142                                  &addHookOk) &&
143            addHookOk;
144   }
145 
146   explicit operator bool() const { return !!mOrigFunc; }
147 
148   template <typename... ArgsType>
149   INTERCEPTOR_DISABLE_CFGUARD ReturnType operator()(ArgsType&&... aArgs) const {
150     return mOrigFunc(std::forward<ArgsType>(aArgs)...);
151   }
152 
153   FuncPtrT GetStub() const { return mOrigFunc; }
154 
155   // One-time init stuff cannot be moved or copied
156   FuncHook(const FuncHook&) = delete;
157   FuncHook(FuncHook&&) = delete;
158   FuncHook& operator=(const FuncHook&) = delete;
159   FuncHook& operator=(FuncHook&& aOther) = delete;
160 
161  private:
162   struct MOZ_RAII InitOnceContext final {
163     InitOnceContext(ThisType* aHook, InterceptorT* aInterceptor,
164                     const char* aName, FuncPtrT aHookDest, bool aForceDetour)
165         : mHook(aHook),
166           mInterceptor(aInterceptor),
167           mName(aName),
168           mHookDest(reinterpret_cast<void*>(aHookDest)),
169           mForceDetour(aForceDetour) {}
170 
171     ThisType* mHook;
172     InterceptorT* mInterceptor;
173     const char* mName;
174     void* mHookDest;
175     bool mForceDetour;
176   };
177 
178  private:
179   bool Apply(InterceptorT* aInterceptor, const char* aName, void* aHookDest) {
180     return aInterceptor->AddHook(aName, reinterpret_cast<intptr_t>(aHookDest),
181                                  reinterpret_cast<void**>(&mOrigFunc));
182   }
183 
184   bool ApplyDetour(InterceptorT* aInterceptor, const char* aName,
185                    void* aHookDest) {
186     return aInterceptor->AddDetour(aName, reinterpret_cast<intptr_t>(aHookDest),
187                                    reinterpret_cast<void**>(&mOrigFunc));
188   }
189 
190   static BOOL CALLBACK InitOnceCallback(PINIT_ONCE aInitOnce, PVOID aParam,
191                                         PVOID* aOutContext) {
192     MOZ_ASSERT(aOutContext);
193 
194     bool result;
195     auto ctx = reinterpret_cast<InitOnceContext*>(aParam);
196     if (ctx->mForceDetour) {
197       result = ctx->mHook->ApplyDetour(ctx->mInterceptor, ctx->mName,
198                                        ctx->mHookDest);
199     } else {
200       result = ctx->mHook->Apply(ctx->mInterceptor, ctx->mName, ctx->mHookDest);
201     }
202 
203     *aOutContext =
204         result ? reinterpret_cast<PVOID>(1U << INIT_ONCE_CTX_RESERVED_BITS)
205                : nullptr;
206     return TRUE;
207   }
208 
209  private:
210   FuncPtrT mOrigFunc;
211   INIT_ONCE mInitOnce;
212 };
213 
214 template <typename InterceptorT, typename FuncPtrT>
215 class MOZ_ONLY_USED_TO_AVOID_STATIC_CONSTRUCTORS FuncHookCrossProcess final {
216  public:
217   using ThisType = FuncHookCrossProcess<InterceptorT, FuncPtrT>;
218   using ReturnType = typename OriginalFunctionPtrTraits<FuncPtrT>::ReturnType;
219 
220 #if defined(DEBUG)
221   FuncHookCrossProcess() {}
222 #endif  // defined(DEBUG)
223 
224   bool Set(nt::CrossExecTransferManager& aTransferMgr,
225            InterceptorT& aInterceptor, const char* aName, FuncPtrT aHookDest) {
226     FuncPtrT origFunc;
227     if (!aInterceptor.AddHook(aName, reinterpret_cast<intptr_t>(aHookDest),
228                               reinterpret_cast<void**>(&origFunc))) {
229       return false;
230     }
231 
232     return CopyStubToChildProcess(aTransferMgr, aInterceptor, origFunc);
233   }
234 
235   bool SetDetour(nt::CrossExecTransferManager& aTransferMgr,
236                  InterceptorT& aInterceptor, const char* aName,
237                  FuncPtrT aHookDest) {
238     FuncPtrT origFunc;
239     if (!aInterceptor.AddDetour(aName, reinterpret_cast<intptr_t>(aHookDest),
240                                 reinterpret_cast<void**>(&origFunc))) {
241       return false;
242     }
243 
244     return CopyStubToChildProcess(aTransferMgr, aInterceptor, origFunc);
245   }
246 
247   explicit operator bool() const { return !!mOrigFunc; }
248 
249   /**
250    * NB: This operator is only meaningful when invoked in the target process!
251    */
252   template <typename... ArgsType>
253   ReturnType operator()(ArgsType&&... aArgs) const {
254     return mOrigFunc(std::forward<ArgsType>(aArgs)...);
255   }
256 
257 #if defined(DEBUG)
258   FuncHookCrossProcess(const FuncHookCrossProcess&) = delete;
259   FuncHookCrossProcess(FuncHookCrossProcess&&) = delete;
260   FuncHookCrossProcess& operator=(const FuncHookCrossProcess&) = delete;
261   FuncHookCrossProcess& operator=(FuncHookCrossProcess&& aOther) = delete;
262 #endif  // defined(DEBUG)
263 
264  private:
265   bool CopyStubToChildProcess(nt::CrossExecTransferManager& aTransferMgr,
266                               InterceptorT& aInterceptor, FuncPtrT aStub) {
267     LauncherVoidResult writeResult =
268         aTransferMgr.Transfer(&mOrigFunc, &aStub, sizeof(FuncPtrT));
269     if (writeResult.isErr()) {
270 #ifdef MOZ_USE_LAUNCHER_ERROR
271       const mozilla::WindowsError& err = writeResult.inspectErr().mError;
272 #else
273       const mozilla::WindowsError& err = writeResult.inspectErr();
274 #endif
275       aInterceptor.SetLastDetourError(FUNCHOOKCROSSPROCESS_COPYSTUB_ERROR,
276                                       err.AsHResult());
277       return false;
278     }
279     return true;
280   }
281 
282  private:
283   FuncPtrT mOrigFunc;
284 };
285 
286 template <typename MMPolicyT, typename InterceptorT>
287 struct TypeResolver;
288 
289 template <typename InterceptorT>
290 struct TypeResolver<mozilla::interceptor::MMPolicyInProcess, InterceptorT> {
291   template <typename FuncPtrT>
292   using FuncHookType = FuncHook<InterceptorT, FuncPtrT>;
293 };
294 
295 template <typename InterceptorT>
296 struct TypeResolver<mozilla::interceptor::MMPolicyOutOfProcess, InterceptorT> {
297   template <typename FuncPtrT>
298   using FuncHookType = FuncHookCrossProcess<InterceptorT, FuncPtrT>;
299 };
300 
301 template <typename VMPolicy = mozilla::interceptor::VMSharingPolicyShared>
302 class WindowsDllInterceptor final
303     : public TypeResolver<typename VMPolicy::MMPolicyT,
304                           WindowsDllInterceptor<VMPolicy>> {
305   typedef WindowsDllInterceptor<VMPolicy> ThisType;
306 
307   interceptor::WindowsDllDetourPatcher<VMPolicy> mDetourPatcher;
308 #if defined(_M_IX86)
309   interceptor::WindowsDllNopSpacePatcher<typename VMPolicy::MMPolicyT>
310       mNopSpacePatcher;
311 #endif  // defined(_M_IX86)
312 
313   HMODULE mModule;
314 
315  public:
316   template <typename... Args>
317   explicit WindowsDllInterceptor(Args&&... aArgs)
318       : mDetourPatcher(std::forward<Args>(aArgs)...)
319 #if defined(_M_IX86)
320         ,
321         mNopSpacePatcher(std::forward<Args>(aArgs)...)
322 #endif  // defined(_M_IX86)
323         ,
324         mModule(nullptr) {
325   }
326 
327   WindowsDllInterceptor(const WindowsDllInterceptor&) = delete;
328   WindowsDllInterceptor(WindowsDllInterceptor&&) = delete;
329   WindowsDllInterceptor& operator=(const WindowsDllInterceptor&) = delete;
330   WindowsDllInterceptor& operator=(WindowsDllInterceptor&&) = delete;
331 
332   ~WindowsDllInterceptor() { Clear(); }
333 
334   template <size_t N>
335   void Init(const char (&aModuleName)[N]) {
336     wchar_t moduleName[N];
337 
338     for (size_t i = 0; i < N; ++i) {
339       MOZ_ASSERT(!(aModuleName[i] & 0x80),
340                  "Use wide-character overload for non-ASCII module names");
341       moduleName[i] = aModuleName[i];
342     }
343 
344     Init(moduleName);
345   }
346 
347   void Init(const wchar_t* aModuleName) {
348     if (mModule) {
349       return;
350     }
351 
352     mModule = ::LoadLibraryW(aModuleName);
353   }
354 
355   /** Force a specific configuration for testing purposes. NOT to be used in
356       production code! **/
357   void TestOnlyDetourInit(const wchar_t* aModuleName, DetourFlags aFlags) {
358     Init(aModuleName);
359     mDetourPatcher.Init(aFlags);
360   }
361 
362   void Clear() {
363     if (!mModule) {
364       return;
365     }
366 
367 #if defined(_M_IX86)
368     mNopSpacePatcher.Clear();
369 #endif  // defined(_M_IX86)
370     mDetourPatcher.Clear();
371 
372     // NB: We intentionally leak mModule
373   }
374 
375 #if defined(NIGHTLY_BUILD)
376   const Maybe<DetourError>& GetLastDetourError() const {
377     return mDetourPatcher.GetLastDetourError();
378   }
379 #endif  // defined(NIGHTLY_BUILD)
380   template <typename... Args>
381   void SetLastDetourError(Args&&... aArgs) {
382     return mDetourPatcher.SetLastDetourError(std::forward<Args>(aArgs)...);
383   }
384 
385   constexpr static uint32_t GetWorstCaseRequiredBytesToPatch() {
386     return WindowsDllDetourPatcherPrimitive<
387         typename VMPolicy::MMPolicyT>::GetWorstCaseRequiredBytesToPatch();
388   }
389 
390  private:
391   /**
392    * Hook/detour the method aName from the DLL we set in Init so that it calls
393    * aHookDest instead.  Returns the original method pointer in aOrigFunc
394    * and returns true if successful.
395    *
396    * IMPORTANT: If you use this method, please add your case to the
397    * TestDllInterceptor in order to detect future failures.  Even if this
398    * succeeds now, updates to the hooked DLL could cause it to fail in
399    * the future.
400    */
401   bool AddHook(const char* aName, intptr_t aHookDest, void** aOrigFunc) {
402     // Use a nop space patch if possible, otherwise fall back to a detour.
403     // This should be the preferred method for adding hooks.
404     if (!mModule) {
405       mDetourPatcher.SetLastDetourError(DetourResultCode::INTERCEPTOR_MOD_NULL);
406       return false;
407     }
408 
409     if (!mDetourPatcher.IsPageAccessible(
410             nt::PEHeaders::HModuleToBaseAddr<uintptr_t>(mModule))) {
411       mDetourPatcher.SetLastDetourError(
412           DetourResultCode::INTERCEPTOR_MOD_INACCESSIBLE);
413       return false;
414     }
415 
416     FARPROC proc = mDetourPatcher.GetProcAddress(mModule, aName);
417     if (!proc) {
418       mDetourPatcher.SetLastDetourError(
419           DetourResultCode::INTERCEPTOR_PROC_NULL);
420       return false;
421     }
422 
423     if (!mDetourPatcher.IsPageAccessible(reinterpret_cast<uintptr_t>(proc))) {
424       mDetourPatcher.SetLastDetourError(
425           DetourResultCode::INTERCEPTOR_PROC_INACCESSIBLE);
426       return false;
427     }
428 
429 #if defined(_M_IX86)
430     if (mNopSpacePatcher.AddHook(proc, aHookDest, aOrigFunc)) {
431       return true;
432     }
433 #endif  // defined(_M_IX86)
434 
435     return AddDetour(proc, aHookDest, aOrigFunc);
436   }
437 
438   /**
439    * Detour the method aName from the DLL we set in Init so that it calls
440    * aHookDest instead.  Returns the original method pointer in aOrigFunc
441    * and returns true if successful.
442    *
443    * IMPORTANT: If you use this method, please add your case to the
444    * TestDllInterceptor in order to detect future failures.  Even if this
445    * succeeds now, updates to the detoured DLL could cause it to fail in
446    * the future.
447    */
448   bool AddDetour(const char* aName, intptr_t aHookDest, void** aOrigFunc) {
449     // Generally, code should not call this method directly. Use AddHook unless
450     // there is a specific need to avoid nop space patches.
451     if (!mModule) {
452       mDetourPatcher.SetLastDetourError(DetourResultCode::INTERCEPTOR_MOD_NULL);
453       return false;
454     }
455 
456     if (!mDetourPatcher.IsPageAccessible(
457             nt::PEHeaders::HModuleToBaseAddr<uintptr_t>(mModule))) {
458       mDetourPatcher.SetLastDetourError(
459           DetourResultCode::INTERCEPTOR_MOD_INACCESSIBLE);
460       return false;
461     }
462 
463     FARPROC proc = mDetourPatcher.GetProcAddress(mModule, aName);
464     if (!proc) {
465       mDetourPatcher.SetLastDetourError(
466           DetourResultCode::INTERCEPTOR_PROC_NULL);
467       return false;
468     }
469 
470     if (!mDetourPatcher.IsPageAccessible(reinterpret_cast<uintptr_t>(proc))) {
471       mDetourPatcher.SetLastDetourError(
472           DetourResultCode::INTERCEPTOR_PROC_INACCESSIBLE);
473       return false;
474     }
475 
476     return AddDetour(proc, aHookDest, aOrigFunc);
477   }
478 
479   bool AddDetour(FARPROC aProc, intptr_t aHookDest, void** aOrigFunc) {
480     MOZ_ASSERT(mModule && aProc);
481 
482     if (!mDetourPatcher.Initialized()) {
483       DetourFlags flags = DetourFlags::eDefault;
484 #if defined(_M_X64)
485       // NTDLL hooks should attempt to use a 10-byte patch because some
486       // injected DLLs do the same and interfere with our stuff.
487       bool needs10BytePatch = (mModule == ::GetModuleHandleW(L"ntdll.dll"));
488 
489       bool isWin8Or81 = IsWin8OrLater() && (!IsWin10OrLater());
490       bool isWin8 = IsWin8OrLater() && (!IsWin8Point1OrLater());
491 
492       bool isKernel32Dll = (mModule == ::GetModuleHandleW(L"kernel32.dll"));
493 
494       bool isDuplicateHandle = (reinterpret_cast<void*>(aProc) ==
495                                 reinterpret_cast<void*>(&::DuplicateHandle));
496 
497       // CloseHandle on Windows 8/8.1 only accomodates 10-byte patches.
498       needs10BytePatch |= isWin8Or81 && isKernel32Dll &&
499                           (reinterpret_cast<void*>(aProc) ==
500                            reinterpret_cast<void*>(&CloseHandle));
501 
502       // CreateFileA and DuplicateHandle on Windows 8 require 10-byte patches.
503       needs10BytePatch |= isWin8 && isKernel32Dll &&
504                           ((reinterpret_cast<void*>(aProc) ==
505                             reinterpret_cast<void*>(&::CreateFileA)) ||
506                            isDuplicateHandle);
507 
508       if (needs10BytePatch) {
509         flags |= DetourFlags::eEnable10BytePatch;
510       }
511 
512       if (isWin8 && isDuplicateHandle) {
513         // Because we can't detour Win8's KERNELBASE!DuplicateHandle,
514         // we detour kernel32!DuplicateHandle (See bug 1659398).
515         flags |= DetourFlags::eDontResolveRedirection;
516       }
517 #endif  // defined(_M_X64)
518 
519       mDetourPatcher.Init(flags);
520     }
521 
522     return mDetourPatcher.AddHook(aProc, aHookDest, aOrigFunc);
523   }
524 
525  private:
526   template <typename InterceptorT, typename FuncPtrT>
527   friend class FuncHook;
528 
529   template <typename InterceptorT, typename FuncPtrT>
530   friend class FuncHookCrossProcess;
531 };
532 
533 /**
534  * IAT patching is intended for use when we only want to intercept a function
535  * call originating from a specific module.
536  */
537 class WindowsIATPatcher final {
538  public:
539   template <typename FuncPtrT>
540   using FuncHookType = FuncHook<WindowsIATPatcher, FuncPtrT>;
541 
542  private:
543   static bool CheckASCII(const char* aInStr) {
544     while (*aInStr) {
545       if (*aInStr & 0x80) {
546         return false;
547       }
548       ++aInStr;
549     }
550     return true;
551   }
552 
553   static bool AddHook(HMODULE aFromModule, const char* aToModuleName,
554                       const char* aTargetFnName, void* aHookDest,
555                       Atomic<void*>* aOutOrigFunc) {
556     if (!aFromModule || !aToModuleName || !aTargetFnName || !aOutOrigFunc) {
557       return false;
558     }
559 
560     // PE Spec requires ASCII names for imported module names
561     const bool isModuleNameAscii = CheckASCII(aToModuleName);
562     MOZ_ASSERT(isModuleNameAscii);
563     if (!isModuleNameAscii) {
564       return false;
565     }
566 
567     // PE Spec requires ASCII names for imported function names
568     const bool isTargetFnNameAscii = CheckASCII(aTargetFnName);
569     MOZ_ASSERT(isTargetFnNameAscii);
570     if (!isTargetFnNameAscii) {
571       return false;
572     }
573 
574     nt::PEHeaders headers(aFromModule);
575     if (!headers) {
576       return false;
577     }
578 
579     PIMAGE_IMPORT_DESCRIPTOR impDesc =
580         headers.GetImportDescriptor(aToModuleName);
581     if (!nt::PEHeaders::IsValid(impDesc)) {
582       // Either aFromModule does not import aToModuleName at load-time, or
583       // aToModuleName is a (currently unsupported) delay-load import.
584       return false;
585     }
586 
587     // Resolve the import name table (INT).
588     auto firstINTThunk = headers.template RVAToPtr<PIMAGE_THUNK_DATA>(
589         impDesc->OriginalFirstThunk);
590     if (!nt::PEHeaders::IsValid(firstINTThunk)) {
591       return false;
592     }
593 
594     Maybe<ptrdiff_t> thunkIndex;
595 
596     // Scan the INT for the location of the thunk for the function named
597     // 'aTargetFnName'.
598     for (PIMAGE_THUNK_DATA curINTThunk = firstINTThunk;
599          nt::PEHeaders::IsValid(curINTThunk); ++curINTThunk) {
600       if (IMAGE_SNAP_BY_ORDINAL(curINTThunk->u1.Ordinal)) {
601         // Currently not supporting import by ordinal; this isn't hard to add,
602         // but we won't bother unless necessary.
603         continue;
604       }
605 
606       PIMAGE_IMPORT_BY_NAME curThunkFnName =
607           headers.template RVAToPtr<PIMAGE_IMPORT_BY_NAME>(
608               curINTThunk->u1.AddressOfData);
609       MOZ_ASSERT(curThunkFnName);
610       if (!curThunkFnName) {
611         // Looks like we have a bad name descriptor. Try to continue.
612         continue;
613       }
614 
615       // Function name checks MUST be case-sensitive!
616       if (!strcmp(aTargetFnName, curThunkFnName->Name)) {
617         // We found the thunk. Save the index of this thunk, as the IAT thunk
618         // is located at the same index in that table as in the INT.
619         thunkIndex = Some(curINTThunk - firstINTThunk);
620         break;
621       }
622     }
623 
624     if (thunkIndex.isNothing()) {
625       // We never found a thunk for that function. Perhaps it's not imported?
626       return false;
627     }
628 
629     if (thunkIndex.value() < 0) {
630       // That's just wrong.
631       return false;
632     }
633 
634     auto firstIATThunk =
635         headers.template RVAToPtr<PIMAGE_THUNK_DATA>(impDesc->FirstThunk);
636     if (!nt::PEHeaders::IsValid(firstIATThunk)) {
637       return false;
638     }
639 
640     // Resolve the IAT thunk for the function we want
641     PIMAGE_THUNK_DATA targetThunk = &firstIATThunk[thunkIndex.value()];
642     if (!nt::PEHeaders::IsValid(targetThunk)) {
643       return false;
644     }
645 
646     auto fnPtr = reinterpret_cast<Atomic<void*>*>(&targetThunk->u1.Function);
647 
648     // Now we can just change out its pointer with our hook function.
649     AutoVirtualProtect prot(fnPtr, sizeof(void*), PAGE_EXECUTE_READWRITE);
650     if (!prot) {
651       return false;
652     }
653 
654     // We do the exchange this way to ensure that *aOutOrigFunc is always valid
655     // once the atomic exchange has taken place.
656     void* tmp;
657 
658     do {
659       tmp = *fnPtr;
660       *aOutOrigFunc = tmp;
661     } while (!fnPtr->compareExchange(tmp, aHookDest));
662 
663     return true;
664   }
665 
666   template <typename InterceptorT, typename FuncPtrT>
667   friend class FuncHook;
668 };
669 
670 template <typename FuncPtrT>
671 class MOZ_ONLY_USED_TO_AVOID_STATIC_CONSTRUCTORS
672     FuncHook<WindowsIATPatcher, FuncPtrT>
673         final {
674  public:
675   using ThisType = FuncHook<WindowsIATPatcher, FuncPtrT>;
676   using ReturnType = typename OriginalFunctionPtrTraits<FuncPtrT>::ReturnType;
677 
678   constexpr FuncHook()
679       : mInitOnce(INIT_ONCE_STATIC_INIT),
680         mFromModule(nullptr),
681         mOrigFunc(nullptr) {}
682 
683 #if defined(DEBUG)
684   ~FuncHook() = default;
685 #endif  // defined(DEBUG)
686 
687   bool Set(const wchar_t* aFromModuleName, const char* aToModuleName,
688            const char* aFnName, FuncPtrT aHookDest) {
689     nsModuleHandle fromModule(::LoadLibraryW(aFromModuleName));
690     if (!fromModule) {
691       return false;
692     }
693 
694     return Set(fromModule, aToModuleName, aFnName, aHookDest);
695   }
696 
697   // We offer this overload in case the client wants finer-grained control over
698   // loading aFromModule.
699   bool Set(nsModuleHandle& aFromModule, const char* aToModuleName,
700            const char* aFnName, FuncPtrT aHookDest) {
701     LPVOID addHookOk = nullptr;
702     InitOnceContext ctx(this, aFromModule, aToModuleName, aFnName, aHookDest);
703 
704     bool result = ::InitOnceExecuteOnce(&mInitOnce, &InitOnceCallback, &ctx,
705                                         &addHookOk) &&
706                   addHookOk;
707     if (!result) {
708       return result;
709     }
710 
711     // If we successfully set the hook then we must retain a strong reference
712     // to the module that we modified.
713     mFromModule = aFromModule.disown();
714     return result;
715   }
716 
717   explicit operator bool() const { return !!mOrigFunc; }
718 
719   template <typename... ArgsType>
720   ReturnType operator()(ArgsType&&... aArgs) const {
721     return mOrigFunc(std::forward<ArgsType>(aArgs)...);
722   }
723 
724   FuncPtrT GetStub() const { return mOrigFunc; }
725 
726 #if defined(DEBUG)
727   // One-time init stuff cannot be moved or copied
728   FuncHook(const FuncHook&) = delete;
729   FuncHook(FuncHook&&) = delete;
730   FuncHook& operator=(const FuncHook&) = delete;
731   FuncHook& operator=(FuncHook&& aOther) = delete;
732 #endif  // defined(DEBUG)
733 
734  private:
735   struct MOZ_RAII InitOnceContext final {
736     InitOnceContext(ThisType* aHook, const nsModuleHandle& aFromModule,
737                     const char* aToModuleName, const char* aFnName,
738                     FuncPtrT aHookDest)
739         : mHook(aHook),
740           mFromModule(aFromModule),
741           mToModuleName(aToModuleName),
742           mFnName(aFnName),
743           mHookDest(reinterpret_cast<void*>(aHookDest)) {}
744 
745     ThisType* mHook;
746     const nsModuleHandle& mFromModule;
747     const char* mToModuleName;
748     const char* mFnName;
749     void* mHookDest;
750   };
751 
752  private:
753   bool Apply(const nsModuleHandle& aFromModule, const char* aToModuleName,
754              const char* aFnName, void* aHookDest) {
755     return WindowsIATPatcher::AddHook(
756         aFromModule, aToModuleName, aFnName, aHookDest,
757         reinterpret_cast<Atomic<void*>*>(&mOrigFunc));
758   }
759 
760   static BOOL CALLBACK InitOnceCallback(PINIT_ONCE aInitOnce, PVOID aParam,
761                                         PVOID* aOutContext) {
762     MOZ_ASSERT(aOutContext);
763 
764     auto ctx = reinterpret_cast<InitOnceContext*>(aParam);
765     bool result = ctx->mHook->Apply(ctx->mFromModule, ctx->mToModuleName,
766                                     ctx->mFnName, ctx->mHookDest);
767 
768     *aOutContext =
769         result ? reinterpret_cast<PVOID>(1U << INIT_ONCE_CTX_RESERVED_BITS)
770                : nullptr;
771     return TRUE;
772   }
773 
774  private:
775   INIT_ONCE mInitOnce;
776   HMODULE mFromModule;  // never freed
777   FuncPtrT mOrigFunc;
778 };
779 
780 /**
781  * This class applies an irreversible patch to jump to a target function
782  * without backing up the original function.
783  */
784 class WindowsDllEntryPointInterceptor final {
785   using DllMainFn = BOOL(WINAPI*)(HINSTANCE, DWORD, LPVOID);
786   using MMPolicyT = MMPolicyInProcessEarlyStage;
787 
788   MMPolicyT mMMPolicy;
789 
790  public:
791   explicit WindowsDllEntryPointInterceptor(
792       const MMPolicyT::Kernel32Exports& aK32Exports)
793       : mMMPolicy(aK32Exports) {}
794 
795   bool Set(const nt::PEHeaders& aHeaders, DllMainFn aDestination) {
796     if (!aHeaders) {
797       return false;
798     }
799 
800     WindowsDllDetourPatcherPrimitive<MMPolicyT> patcher;
801     return patcher.AddIrreversibleHook(
802         mMMPolicy, aHeaders.GetEntryPoint(),
803         reinterpret_cast<uintptr_t>(aDestination));
804   }
805 };
806 
807 }  // namespace interceptor
808 
809 using WindowsDllInterceptor = interceptor::WindowsDllInterceptor<>;
810 
811 using CrossProcessDllInterceptor = interceptor::WindowsDllInterceptor<
812     mozilla::interceptor::VMSharingPolicyUnique<
813         mozilla::interceptor::MMPolicyOutOfProcess>>;
814 
815 using WindowsIATPatcher = interceptor::WindowsIATPatcher;
816 
817 }  // namespace mozilla
818 
819 #endif /* NS_WINDOWS_DLL_INTERCEPTOR_H_ */
820