1 /* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
2 /* vim: set ts=8 sts=2 et sw=2 tw=80: */
3 /* This Source Code Form is subject to the terms of the Mozilla Public
4  * License, v. 2.0. If a copy of the MPL was not distributed with this
5  * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
6 
7 #ifndef NS_WINDOWS_DLL_INTERCEPTOR_H_
8 #define NS_WINDOWS_DLL_INTERCEPTOR_H_
9 
10 #include <wchar.h>
11 #include <windows.h>
12 #include <winternl.h>
13 
14 #include <utility>
15 
16 #include "mozilla/ArrayUtils.h"
17 #include "mozilla/Assertions.h"
18 #include "mozilla/Atomics.h"
19 #include "mozilla/Attributes.h"
20 #include "mozilla/CheckedInt.h"
21 #include "mozilla/DebugOnly.h"
22 #include "mozilla/NativeNt.h"
23 #include "mozilla/Tuple.h"
24 #include "mozilla/Types.h"
25 #include "mozilla/UniquePtr.h"
26 #include "mozilla/Vector.h"
27 #include "mozilla/interceptor/MMPolicies.h"
28 #include "mozilla/interceptor/PatcherDetour.h"
29 #include "mozilla/interceptor/PatcherNopSpace.h"
30 #include "mozilla/interceptor/VMSharingPolicies.h"
31 #include "nsWindowsHelpers.h"
32 
33 /*
34  * Simple function interception.
35  *
36  * We have two separate mechanisms for intercepting a function: We can use the
37  * built-in nop space, if it exists, or we can create a detour.
38  *
39  * Using the built-in nop space works as follows: On x86-32, DLL functions
 * begin with a two-byte nop (mov edi, edi) and are preceded by five bytes of
41  * NOP instructions.
42  *
43  * When we detect a function with this prelude, we do the following:
44  *
45  * 1. Write a long jump to our interceptor function into the five bytes of NOPs
46  *    before the function.
47  *
48  * 2. Write a short jump -5 into the two-byte nop at the beginning of the
49  *    function.
50  *
51  * This mechanism is nice because it's thread-safe.  It's even safe to do if
52  * another thread is currently running the function we're modifying!
53  *
54  * When the WindowsDllNopSpacePatcher is destroyed, we overwrite the short jump
55  * but not the long jump, so re-intercepting the same function won't work,
56  * because its prelude won't match.
57  *
58  *
59  * Unfortunately nop space patching doesn't work on functions which don't have
60  * this magic prelude (and in particular, x86-64 never has the prelude).  So
61  * when we can't use the built-in nop space, we fall back to using a detour,
62  * which works as follows:
63  *
 * 1. Save the first N bytes of OrigFunction to the trampoline, where N >= 5
 *    and those N bytes end on an instruction boundary.
66  *
67  * 2. Replace first 5 bytes of OrigFunction with a jump to the Hook
68  *    function.
69  *
70  * 3. After N bytes of the trampoline, add a jump to OrigFunction+N to
71  *    continue original program flow.
72  *
73  * 4. Hook function needs to call the trampoline during its execution,
74  *    to invoke the original function (so address of trampoline is
75  *    returned).
76  *
 * When the WindowsDllDetourPatcher object is destroyed, OrigFunction is
78  * patched again to jump directly to the trampoline instead of going through
79  * the hook function. As such, re-intercepting the same function won't work, as
80  * jump instructions are not supported.
81  *
82  * Note that this is not thread-safe.  Sad day.
83  *
84  */
85 
// INTERCEPTOR_DISABLE_CFGUARD exempts a function from Control Flow Guard
// checks where the compiler supports doing so, and expands to nothing
// elsewhere.
//
// On x86, nop-space patches return to the second instruction of their target.
// This is a deliberate violation of Control Flow Guard, so disable the check.
//
// __has_declspec_attribute is a clang extension and must only be evaluated in
// a nested #if: preprocessors that do not define it turn the expression into
// `0(guard)`, which is a syntax error even in a short-circuited operand of &&.
#if defined(_M_IX86) && defined(__clang__)
#  if __has_declspec_attribute(guard)
#    define INTERCEPTOR_DISABLE_CFGUARD __declspec(guard(nocf))
#  else
#    define INTERCEPTOR_DISABLE_CFGUARD /* nothing */
#  endif
#else
#  define INTERCEPTOR_DISABLE_CFGUARD /* nothing */
#endif
93 
94 namespace mozilla {
95 namespace interceptor {
96 
// Extracts the return type of a function pointer type as `ReturnType`.
// Only the function-pointer forms specialized below are supported; the primary
// template is intentionally left undefined so that any other type fails to
// compile.
template <typename FnPtr>
struct OriginalFunctionPtrTraits;

// Default calling convention.
template <typename Ret, typename... Params>
struct OriginalFunctionPtrTraits<Ret (*)(Params...)> {
  using ReturnType = Ret;
};

#if defined(_M_IX86)
// On 32-bit x86, __stdcall and __fastcall function pointers are distinct types
// from the default-convention pointer above, so each needs its own
// specialization.
template <typename Ret, typename... Params>
struct OriginalFunctionPtrTraits<Ret(__stdcall*)(Params...)> {
  using ReturnType = Ret;
};

template <typename Ret, typename... Params>
struct OriginalFunctionPtrTraits<Ret(__fastcall*)(Params...)> {
  using ReturnType = Ret;
};
#endif  // defined(_M_IX86)
116 
117 template <typename InterceptorT, typename FuncPtrT>
118 class FuncHook final {
119  public:
120   using ThisType = FuncHook<InterceptorT, FuncPtrT>;
121   using ReturnType = typename OriginalFunctionPtrTraits<FuncPtrT>::ReturnType;
122 
123   constexpr FuncHook() : mOrigFunc(nullptr), mInitOnce(INIT_ONCE_STATIC_INIT) {}
124 
125   ~FuncHook() = default;
126 
127   bool Set(InterceptorT& aInterceptor, const char* aName, FuncPtrT aHookDest) {
128     LPVOID addHookOk = nullptr;
129     InitOnceContext ctx(this, &aInterceptor, aName, aHookDest, false);
130 
131     return ::InitOnceExecuteOnce(&mInitOnce, &InitOnceCallback, &ctx,
132                                  &addHookOk) &&
133            addHookOk;
134   }
135 
136   bool SetDetour(InterceptorT& aInterceptor, const char* aName,
137                  FuncPtrT aHookDest) {
138     LPVOID addHookOk = nullptr;
139     InitOnceContext ctx(this, &aInterceptor, aName, aHookDest, true);
140 
141     return ::InitOnceExecuteOnce(&mInitOnce, &InitOnceCallback, &ctx,
142                                  &addHookOk) &&
143            addHookOk;
144   }
145 
146   explicit operator bool() const { return !!mOrigFunc; }
147 
148   template <typename... ArgsType>
149   INTERCEPTOR_DISABLE_CFGUARD ReturnType operator()(ArgsType&&... aArgs) const {
150     return mOrigFunc(std::forward<ArgsType>(aArgs)...);
151   }
152 
153   FuncPtrT GetStub() const { return mOrigFunc; }
154 
155   // One-time init stuff cannot be moved or copied
156   FuncHook(const FuncHook&) = delete;
157   FuncHook(FuncHook&&) = delete;
158   FuncHook& operator=(const FuncHook&) = delete;
159   FuncHook& operator=(FuncHook&& aOther) = delete;
160 
161  private:
162   struct MOZ_RAII InitOnceContext final {
163     InitOnceContext(ThisType* aHook, InterceptorT* aInterceptor,
164                     const char* aName, FuncPtrT aHookDest, bool aForceDetour)
165         : mHook(aHook),
166           mInterceptor(aInterceptor),
167           mName(aName),
168           mHookDest(reinterpret_cast<void*>(aHookDest)),
169           mForceDetour(aForceDetour) {}
170 
171     ThisType* mHook;
172     InterceptorT* mInterceptor;
173     const char* mName;
174     void* mHookDest;
175     bool mForceDetour;
176   };
177 
178  private:
179   bool Apply(InterceptorT* aInterceptor, const char* aName, void* aHookDest) {
180     return aInterceptor->AddHook(aName, reinterpret_cast<intptr_t>(aHookDest),
181                                  reinterpret_cast<void**>(&mOrigFunc));
182   }
183 
184   bool ApplyDetour(InterceptorT* aInterceptor, const char* aName,
185                    void* aHookDest) {
186     return aInterceptor->AddDetour(aName, reinterpret_cast<intptr_t>(aHookDest),
187                                    reinterpret_cast<void**>(&mOrigFunc));
188   }
189 
190   static BOOL CALLBACK InitOnceCallback(PINIT_ONCE aInitOnce, PVOID aParam,
191                                         PVOID* aOutContext) {
192     MOZ_ASSERT(aOutContext);
193 
194     bool result;
195     auto ctx = reinterpret_cast<InitOnceContext*>(aParam);
196     if (ctx->mForceDetour) {
197       result = ctx->mHook->ApplyDetour(ctx->mInterceptor, ctx->mName,
198                                        ctx->mHookDest);
199     } else {
200       result = ctx->mHook->Apply(ctx->mInterceptor, ctx->mName, ctx->mHookDest);
201     }
202 
203     *aOutContext =
204         result ? reinterpret_cast<PVOID>(1U << INIT_ONCE_CTX_RESERVED_BITS)
205                : nullptr;
206     return TRUE;
207   }
208 
209  private:
210   FuncPtrT mOrigFunc;
211   INIT_ONCE mInitOnce;
212 };
213 
214 template <typename InterceptorT, typename FuncPtrT>
215 class MOZ_ONLY_USED_TO_AVOID_STATIC_CONSTRUCTORS FuncHookCrossProcess final {
216  public:
217   using ThisType = FuncHookCrossProcess<InterceptorT, FuncPtrT>;
218   using ReturnType = typename OriginalFunctionPtrTraits<FuncPtrT>::ReturnType;
219 
220 #if defined(DEBUG)
221   FuncHookCrossProcess() {}
222 #endif  // defined(DEBUG)
223 
224   bool Set(HANDLE aProcess, InterceptorT& aInterceptor, const char* aName,
225            FuncPtrT aHookDest) {
226     FuncPtrT origFunc;
227     if (!aInterceptor.AddHook(aName, reinterpret_cast<intptr_t>(aHookDest),
228                               reinterpret_cast<void**>(&origFunc))) {
229       return false;
230     }
231 
232     return CopyStubToChildProcess(origFunc, aProcess);
233   }
234 
235   bool SetDetour(HANDLE aProcess, InterceptorT& aInterceptor, const char* aName,
236                  FuncPtrT aHookDest) {
237     FuncPtrT origFunc;
238     if (!aInterceptor.AddDetour(aName, reinterpret_cast<intptr_t>(aHookDest),
239                                 reinterpret_cast<void**>(&origFunc))) {
240       return false;
241     }
242 
243     return CopyStubToChildProcess(origFunc, aProcess);
244   }
245 
246   explicit operator bool() const { return !!mOrigFunc; }
247 
248   /**
249    * NB: This operator is only meaningful when invoked in the target process!
250    */
251   template <typename... ArgsType>
252   ReturnType operator()(ArgsType&&... aArgs) const {
253     return mOrigFunc(std::forward<ArgsType>(aArgs)...);
254   }
255 
256 #if defined(DEBUG)
257   FuncHookCrossProcess(const FuncHookCrossProcess&) = delete;
258   FuncHookCrossProcess(FuncHookCrossProcess&&) = delete;
259   FuncHookCrossProcess& operator=(const FuncHookCrossProcess&) = delete;
260   FuncHookCrossProcess& operator=(FuncHookCrossProcess&& aOther) = delete;
261 #endif  // defined(DEBUG)
262 
263  private:
264   bool CopyStubToChildProcess(FuncPtrT aStub, HANDLE aProcess) {
265     SIZE_T bytesWritten;
266     return ::WriteProcessMemory(aProcess, &mOrigFunc, &aStub, sizeof(FuncPtrT),
267                                 &bytesWritten) &&
268            bytesWritten == sizeof(FuncPtrT);
269   }
270 
271  private:
272   FuncPtrT mOrigFunc;
273 };
274 
275 template <typename MMPolicyT, typename InterceptorT>
276 struct TypeResolver;
277 
278 template <typename InterceptorT>
279 struct TypeResolver<mozilla::interceptor::MMPolicyInProcess, InterceptorT> {
280   template <typename FuncPtrT>
281   using FuncHookType = FuncHook<InterceptorT, FuncPtrT>;
282 };
283 
284 template <typename InterceptorT>
285 struct TypeResolver<mozilla::interceptor::MMPolicyOutOfProcess, InterceptorT> {
286   template <typename FuncPtrT>
287   using FuncHookType = FuncHookCrossProcess<InterceptorT, FuncPtrT>;
288 };
289 
290 template <typename VMPolicy = mozilla::interceptor::VMSharingPolicyShared<
291               mozilla::interceptor::MMPolicyInProcess, true>>
292 class WindowsDllInterceptor final
293     : public TypeResolver<typename VMPolicy::MMPolicyT,
294                           WindowsDllInterceptor<VMPolicy>> {
295   typedef WindowsDllInterceptor<VMPolicy> ThisType;
296 
297   interceptor::WindowsDllDetourPatcher<VMPolicy> mDetourPatcher;
298 #if defined(_M_IX86)
299   interceptor::WindowsDllNopSpacePatcher<typename VMPolicy::MMPolicyT>
300       mNopSpacePatcher;
301 #endif  // defined(_M_IX86)
302 
303   HMODULE mModule;
304 
305  public:
306   template <typename... Args>
307   explicit WindowsDllInterceptor(Args&&... aArgs)
308       : mDetourPatcher(std::forward<Args>(aArgs)...)
309 #if defined(_M_IX86)
310         ,
311         mNopSpacePatcher(std::forward<Args>(aArgs)...)
312 #endif  // defined(_M_IX86)
313         ,
314         mModule(nullptr) {
315   }
316 
317   WindowsDllInterceptor(const WindowsDllInterceptor&) = delete;
318   WindowsDllInterceptor(WindowsDllInterceptor&&) = delete;
319   WindowsDllInterceptor& operator=(const WindowsDllInterceptor&) = delete;
320   WindowsDllInterceptor& operator=(WindowsDllInterceptor&&) = delete;
321 
322   ~WindowsDllInterceptor() { Clear(); }
323 
324   template <size_t N>
325   void Init(const char (&aModuleName)[N]) {
326     wchar_t moduleName[N];
327 
328     for (size_t i = 0; i < N; ++i) {
329       MOZ_ASSERT(!(aModuleName[i] & 0x80),
330                  "Use wide-character overload for non-ASCII module names");
331       moduleName[i] = aModuleName[i];
332     }
333 
334     Init(moduleName);
335   }
336 
337   void Init(const wchar_t* aModuleName) {
338     if (mModule) {
339       return;
340     }
341 
342     mModule = ::LoadLibraryW(aModuleName);
343   }
344 
345   /** Force a specific configuration for testing purposes. NOT to be used in
346       production code! **/
347   void TestOnlyDetourInit(const wchar_t* aModuleName, DetourFlags aFlags) {
348     Init(aModuleName);
349     mDetourPatcher.Init(aFlags);
350   }
351 
352   void Clear() {
353     if (!mModule) {
354       return;
355     }
356 
357 #if defined(_M_IX86)
358     mNopSpacePatcher.Clear();
359 #endif  // defined(_M_IX86)
360     mDetourPatcher.Clear();
361 
362     // NB: We intentionally leak mModule
363   }
364 
365   constexpr static uint32_t GetWorstCaseRequiredBytesToPatch() {
366     return WindowsDllDetourPatcherPrimitive<
367         typename VMPolicy::MMPolicyT>::GetWorstCaseRequiredBytesToPatch();
368   }
369 
370  private:
371   /**
372    * Hook/detour the method aName from the DLL we set in Init so that it calls
373    * aHookDest instead.  Returns the original method pointer in aOrigFunc
374    * and returns true if successful.
375    *
376    * IMPORTANT: If you use this method, please add your case to the
377    * TestDllInterceptor in order to detect future failures.  Even if this
378    * succeeds now, updates to the hooked DLL could cause it to fail in
379    * the future.
380    */
381   bool AddHook(const char* aName, intptr_t aHookDest, void** aOrigFunc) {
382     // Use a nop space patch if possible, otherwise fall back to a detour.
383     // This should be the preferred method for adding hooks.
384     if (!mModule) {
385       return false;
386     }
387 
388     FARPROC proc = mDetourPatcher.GetProcAddress(mModule, aName);
389     if (!proc) {
390       return false;
391     }
392 
393 #if defined(_M_IX86)
394     if (mNopSpacePatcher.AddHook(proc, aHookDest, aOrigFunc)) {
395       return true;
396     }
397 #endif  // defined(_M_IX86)
398 
399     return AddDetour(proc, aHookDest, aOrigFunc);
400   }
401 
402   /**
403    * Detour the method aName from the DLL we set in Init so that it calls
404    * aHookDest instead.  Returns the original method pointer in aOrigFunc
405    * and returns true if successful.
406    *
407    * IMPORTANT: If you use this method, please add your case to the
408    * TestDllInterceptor in order to detect future failures.  Even if this
409    * succeeds now, updates to the detoured DLL could cause it to fail in
410    * the future.
411    */
412   bool AddDetour(const char* aName, intptr_t aHookDest, void** aOrigFunc) {
413     // Generally, code should not call this method directly. Use AddHook unless
414     // there is a specific need to avoid nop space patches.
415     if (!mModule) {
416       return false;
417     }
418 
419     FARPROC proc = mDetourPatcher.GetProcAddress(mModule, aName);
420     if (!proc) {
421       return false;
422     }
423 
424     return AddDetour(proc, aHookDest, aOrigFunc);
425   }
426 
427   bool AddDetour(FARPROC aProc, intptr_t aHookDest, void** aOrigFunc) {
428     MOZ_ASSERT(mModule && aProc);
429 
430     if (!mDetourPatcher.Initialized()) {
431       DetourFlags flags = DetourFlags::eDefault;
432 #if defined(_M_X64)
433       // NTDLL hooks should attempt to use a 10-byte patch because some
434       // injected DLLs do the same and interfere with our stuff.
435       bool needs10BytePatch = (mModule == ::GetModuleHandleW(L"ntdll.dll"));
436 
437       bool isWin8Or81 = IsWin8OrLater() && (!IsWin10OrLater());
438       bool isWin8 = IsWin8OrLater() && (!IsWin8Point1OrLater());
439 
440       bool isKernel32Dll = (mModule == ::GetModuleHandleW(L"kernel32.dll"));
441 
442       // CloseHandle on Windows 8/8.1 only accomodates 10-byte patches.
443       needs10BytePatch |= isWin8Or81 && isKernel32Dll &&
444                           (reinterpret_cast<void*>(aProc) ==
445                            reinterpret_cast<void*>(&CloseHandle));
446 
447       // CreateFileA and DuplicateHandle on Windows 8 require 10-byte patches.
448       needs10BytePatch |= isWin8 && isKernel32Dll &&
449                           ((reinterpret_cast<void*>(aProc) ==
450                             reinterpret_cast<void*>(&::CreateFileA)) ||
451                            (reinterpret_cast<void*>(aProc) ==
452                             reinterpret_cast<void*>(&::DuplicateHandle)));
453 
454       if (needs10BytePatch) {
455         flags |= DetourFlags::eEnable10BytePatch;
456       }
457 #endif  // defined(_M_X64)
458 
459       mDetourPatcher.Init(flags);
460     }
461 
462     return mDetourPatcher.AddHook(aProc, aHookDest, aOrigFunc);
463   }
464 
465  private:
466   template <typename InterceptorT, typename FuncPtrT>
467   friend class FuncHook;
468 
469   template <typename InterceptorT, typename FuncPtrT>
470   friend class FuncHookCrossProcess;
471 };
472 
473 /**
474  * IAT patching is intended for use when we only want to intercept a function
475  * call originating from a specific module.
476  */
477 class WindowsIATPatcher final {
478  public:
479   template <typename FuncPtrT>
480   using FuncHookType = FuncHook<WindowsIATPatcher, FuncPtrT>;
481 
482  private:
483   static bool CheckASCII(const char* aInStr) {
484     while (*aInStr) {
485       if (*aInStr & 0x80) {
486         return false;
487       }
488       ++aInStr;
489     }
490     return true;
491   }
492 
493   static bool AddHook(HMODULE aFromModule, const char* aToModuleName,
494                       const char* aTargetFnName, void* aHookDest,
495                       Atomic<void*>* aOutOrigFunc) {
496     if (!aFromModule || !aToModuleName || !aTargetFnName || !aOutOrigFunc) {
497       return false;
498     }
499 
500     // PE Spec requires ASCII names for imported module names
501     const bool isModuleNameAscii = CheckASCII(aToModuleName);
502     MOZ_ASSERT(isModuleNameAscii);
503     if (!isModuleNameAscii) {
504       return false;
505     }
506 
507     // PE Spec requires ASCII names for imported function names
508     const bool isTargetFnNameAscii = CheckASCII(aTargetFnName);
509     MOZ_ASSERT(isTargetFnNameAscii);
510     if (!isTargetFnNameAscii) {
511       return false;
512     }
513 
514     nt::PEHeaders headers(aFromModule);
515     if (!headers) {
516       return false;
517     }
518 
519     PIMAGE_IMPORT_DESCRIPTOR impDesc =
520         headers.GetImportDescriptor(aToModuleName);
521     if (!nt::PEHeaders::IsValid(impDesc)) {
522       // Either aFromModule does not import aToModuleName at load-time, or
523       // aToModuleName is a (currently unsupported) delay-load import.
524       return false;
525     }
526 
527     // Resolve the import name table (INT).
528     auto firstINTThunk = headers.template RVAToPtr<PIMAGE_THUNK_DATA>(
529         impDesc->OriginalFirstThunk);
530     if (!nt::PEHeaders::IsValid(firstINTThunk)) {
531       return false;
532     }
533 
534     Maybe<ptrdiff_t> thunkIndex;
535 
536     // Scan the INT for the location of the thunk for the function named
537     // 'aTargetFnName'.
538     for (PIMAGE_THUNK_DATA curINTThunk = firstINTThunk;
539          nt::PEHeaders::IsValid(curINTThunk); ++curINTThunk) {
540       if (IMAGE_SNAP_BY_ORDINAL(curINTThunk->u1.Ordinal)) {
541         // Currently not supporting import by ordinal; this isn't hard to add,
542         // but we won't bother unless necessary.
543         continue;
544       }
545 
546       PIMAGE_IMPORT_BY_NAME curThunkFnName =
547           headers.template RVAToPtr<PIMAGE_IMPORT_BY_NAME>(
548               curINTThunk->u1.AddressOfData);
549       MOZ_ASSERT(curThunkFnName);
550       if (!curThunkFnName) {
551         // Looks like we have a bad name descriptor. Try to continue.
552         continue;
553       }
554 
555       // Function name checks MUST be case-sensitive!
556       if (!strcmp(aTargetFnName, curThunkFnName->Name)) {
557         // We found the thunk. Save the index of this thunk, as the IAT thunk
558         // is located at the same index in that table as in the INT.
559         thunkIndex = Some(curINTThunk - firstINTThunk);
560         break;
561       }
562     }
563 
564     if (thunkIndex.isNothing()) {
565       // We never found a thunk for that function. Perhaps it's not imported?
566       return false;
567     }
568 
569     if (thunkIndex.value() < 0) {
570       // That's just wrong.
571       return false;
572     }
573 
574     auto firstIATThunk =
575         headers.template RVAToPtr<PIMAGE_THUNK_DATA>(impDesc->FirstThunk);
576     if (!nt::PEHeaders::IsValid(firstIATThunk)) {
577       return false;
578     }
579 
580     // Resolve the IAT thunk for the function we want
581     PIMAGE_THUNK_DATA targetThunk = &firstIATThunk[thunkIndex.value()];
582     if (!nt::PEHeaders::IsValid(targetThunk)) {
583       return false;
584     }
585 
586     auto fnPtr = reinterpret_cast<Atomic<void*>*>(&targetThunk->u1.Function);
587 
588     // Now we can just change out its pointer with our hook function.
589     AutoVirtualProtect prot(fnPtr, sizeof(void*), PAGE_EXECUTE_READWRITE);
590     if (!prot) {
591       return false;
592     }
593 
594     // We do the exchange this way to ensure that *aOutOrigFunc is always valid
595     // once the atomic exchange has taken place.
596     void* tmp;
597 
598     do {
599       tmp = *fnPtr;
600       *aOutOrigFunc = tmp;
601     } while (!fnPtr->compareExchange(tmp, aHookDest));
602 
603     return true;
604   }
605 
606   template <typename InterceptorT, typename FuncPtrT>
607   friend class FuncHook;
608 };
609 
610 template <typename FuncPtrT>
611 class MOZ_ONLY_USED_TO_AVOID_STATIC_CONSTRUCTORS
612     FuncHook<WindowsIATPatcher, FuncPtrT>
613         final {
614  public:
615   using ThisType = FuncHook<WindowsIATPatcher, FuncPtrT>;
616   using ReturnType = typename OriginalFunctionPtrTraits<FuncPtrT>::ReturnType;
617 
618   constexpr FuncHook()
619       : mInitOnce(INIT_ONCE_STATIC_INIT),
620         mFromModule(nullptr),
621         mOrigFunc(nullptr) {}
622 
623 #if defined(DEBUG)
624   ~FuncHook() = default;
625 #endif  // defined(DEBUG)
626 
627   bool Set(const wchar_t* aFromModuleName, const char* aToModuleName,
628            const char* aFnName, FuncPtrT aHookDest) {
629     nsModuleHandle fromModule(::LoadLibraryW(aFromModuleName));
630     if (!fromModule) {
631       return false;
632     }
633 
634     return Set(fromModule, aToModuleName, aFnName, aHookDest);
635   }
636 
637   // We offer this overload in case the client wants finer-grained control over
638   // loading aFromModule.
639   bool Set(nsModuleHandle& aFromModule, const char* aToModuleName,
640            const char* aFnName, FuncPtrT aHookDest) {
641     LPVOID addHookOk = nullptr;
642     InitOnceContext ctx(this, aFromModule, aToModuleName, aFnName, aHookDest);
643 
644     bool result = ::InitOnceExecuteOnce(&mInitOnce, &InitOnceCallback, &ctx,
645                                         &addHookOk) &&
646                   addHookOk;
647     if (!result) {
648       return result;
649     }
650 
651     // If we successfully set the hook then we must retain a strong reference
652     // to the module that we modified.
653     mFromModule = aFromModule.disown();
654     return result;
655   }
656 
657   explicit operator bool() const { return !!mOrigFunc; }
658 
659   template <typename... ArgsType>
660   ReturnType operator()(ArgsType&&... aArgs) const {
661     return mOrigFunc(std::forward<ArgsType>(aArgs)...);
662   }
663 
664   FuncPtrT GetStub() const { return mOrigFunc; }
665 
666 #if defined(DEBUG)
667   // One-time init stuff cannot be moved or copied
668   FuncHook(const FuncHook&) = delete;
669   FuncHook(FuncHook&&) = delete;
670   FuncHook& operator=(const FuncHook&) = delete;
671   FuncHook& operator=(FuncHook&& aOther) = delete;
672 #endif  // defined(DEBUG)
673 
674  private:
675   struct MOZ_RAII InitOnceContext final {
676     InitOnceContext(ThisType* aHook, const nsModuleHandle& aFromModule,
677                     const char* aToModuleName, const char* aFnName,
678                     FuncPtrT aHookDest)
679         : mHook(aHook),
680           mFromModule(aFromModule),
681           mToModuleName(aToModuleName),
682           mFnName(aFnName),
683           mHookDest(reinterpret_cast<void*>(aHookDest)) {}
684 
685     ThisType* mHook;
686     const nsModuleHandle& mFromModule;
687     const char* mToModuleName;
688     const char* mFnName;
689     void* mHookDest;
690   };
691 
692  private:
693   bool Apply(const nsModuleHandle& aFromModule, const char* aToModuleName,
694              const char* aFnName, void* aHookDest) {
695     return WindowsIATPatcher::AddHook(
696         aFromModule, aToModuleName, aFnName, aHookDest,
697         reinterpret_cast<Atomic<void*>*>(&mOrigFunc));
698   }
699 
700   static BOOL CALLBACK InitOnceCallback(PINIT_ONCE aInitOnce, PVOID aParam,
701                                         PVOID* aOutContext) {
702     MOZ_ASSERT(aOutContext);
703 
704     auto ctx = reinterpret_cast<InitOnceContext*>(aParam);
705     bool result = ctx->mHook->Apply(ctx->mFromModule, ctx->mToModuleName,
706                                     ctx->mFnName, ctx->mHookDest);
707 
708     *aOutContext =
709         result ? reinterpret_cast<PVOID>(1U << INIT_ONCE_CTX_RESERVED_BITS)
710                : nullptr;
711     return TRUE;
712   }
713 
714  private:
715   INIT_ONCE mInitOnce;
716   HMODULE mFromModule;  // never freed
717   FuncPtrT mOrigFunc;
718 };
719 
720 /**
721  * This class applies an irreversible patch to jump to a target function
722  * without backing up the original function.
723  */
724 class WindowsDllEntryPointInterceptor final {
725   using DllMainFn = BOOL(WINAPI*)(HINSTANCE, DWORD, LPVOID);
726   using MMPolicyT = MMPolicyInProcessEarlyStage;
727 
728   MMPolicyT mMMPolicy;
729 
730  public:
731   explicit WindowsDllEntryPointInterceptor(
732       const MMPolicyT::Kernel32Exports& aK32Exports)
733       : mMMPolicy(aK32Exports) {}
734 
735   bool Set(const nt::PEHeaders& aHeaders, DllMainFn aDestination) {
736     if (!aHeaders) {
737       return false;
738     }
739 
740     WindowsDllDetourPatcherPrimitive<MMPolicyT> patcher;
741     return patcher.AddIrreversibleHook(
742         mMMPolicy, aHeaders.GetEntryPoint(),
743         reinterpret_cast<uintptr_t>(aDestination));
744   }
745 };
746 
747 }  // namespace interceptor
748 
749 using WindowsDllInterceptor = interceptor::WindowsDllInterceptor<>;
750 
751 using CrossProcessDllInterceptor = interceptor::WindowsDllInterceptor<
752     mozilla::interceptor::VMSharingPolicyUnique<
753         mozilla::interceptor::MMPolicyOutOfProcess>>;
754 
755 using WindowsIATPatcher = interceptor::WindowsIATPatcher;
756 
757 }  // namespace mozilla
758 
759 #endif /* NS_WINDOWS_DLL_INTERCEPTOR_H_ */
760