1 /* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
2 /* vim: set ts=8 sts=2 et sw=2 tw=80: */
3 /* This Source Code Form is subject to the terms of the Mozilla Public
4  * License, v. 2.0. If a copy of the MPL was not distributed with this
5  * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
6 
7 #include <shlobj.h>
8 #include <stdio.h>
9 #include <commdlg.h>
10 #define SECURITY_WIN32
11 #include <security.h>
12 #include <wininet.h>
13 #include <schnlsp.h>
14 #include <winternl.h>
15 #include <processthreadsapi.h>
16 
17 #include "AssemblyPayloads.h"
18 #include "mozilla/DynamicallyLinkedFunctionPtr.h"
19 #include "mozilla/UniquePtr.h"
20 #include "mozilla/WindowsVersion.h"
21 #include "nsWindowsDllInterceptor.h"
22 #include "nsWindowsHelpers.h"
23 
// Prototypes for native NT / loader entry points whose addresses the hook
// tests take via decltype(&Func). NOTE(review): presumably declared here
// because the SDK headers included above do not expose them; confirm before
// removing any of these.
NTSTATUS NTAPI NtFlushBuffersFile(HANDLE, PIO_STATUS_BLOCK);
NTSTATUS NTAPI NtReadFile(HANDLE, HANDLE, PIO_APC_ROUTINE, PVOID,
                          PIO_STATUS_BLOCK, PVOID, ULONG, PLARGE_INTEGER,
                          PULONG);
NTSTATUS NTAPI NtReadFileScatter(HANDLE, HANDLE, PIO_APC_ROUTINE, PVOID,
                                 PIO_STATUS_BLOCK, PFILE_SEGMENT_ELEMENT, ULONG,
                                 PLARGE_INTEGER, PULONG);
NTSTATUS NTAPI NtWriteFile(HANDLE, HANDLE, PIO_APC_ROUTINE, PVOID,
                           PIO_STATUS_BLOCK, PVOID, ULONG, PLARGE_INTEGER,
                           PULONG);
NTSTATUS NTAPI NtWriteFileGather(HANDLE, HANDLE, PIO_APC_ROUTINE, PVOID,
                                 PIO_STATUS_BLOCK, PFILE_SEGMENT_ELEMENT, ULONG,
                                 PLARGE_INTEGER, PULONG);
NTSTATUS NTAPI NtQueryFullAttributesFile(POBJECT_ATTRIBUTES, PVOID);
NTSTATUS NTAPI LdrLoadDll(PWCHAR filePath, PULONG flags,
                          PUNICODE_STRING moduleFileName, PHANDLE handle);
NTSTATUS NTAPI LdrUnloadDll(HMODULE);

// Used by TestShortDetour, which resolves the real ntdll export with
// GetProcAddress and hooks it through ShortInterceptor (64-bit only).
NTSTATUS NTAPI NtMapViewOfSection(
    HANDLE aSection, HANDLE aProcess, PVOID* aBaseAddress, ULONG_PTR aZeroBits,
    SIZE_T aCommitSize, PLARGE_INTEGER aSectionOffset, PSIZE_T aViewSize,
    SECTION_INHERIT aInheritDisposition, ULONG aAllocationType,
    ULONG aProtectionFlags);

// The pointer parameters below are disguised as PVOID to avoid pulling in
// obscure headers.
PVOID NTAPI LdrResolveDelayLoadedAPI(PVOID, PVOID, PVOID, PVOID, PVOID, ULONG);
void CALLBACK ProcessCaretEvents(HWINEVENTHOOK, DWORD, HWND, LONG, LONG, DWORD,
                                 DWORD);
void __fastcall BaseThreadInitThunk(BOOL aIsInitialThread, void* aStartAddress,
                                    void* aThreadParam);

// Its type is used by HasApiSetQueryApiSetPresence to resolve the export
// dynamically.
BOOL WINAPI ApiSetQueryApiSetPresence(PCUNICODE_STRING, PBOOLEAN);

// 0x0602 is the Windows 8 target version. NOTE(review): presumably the SDK only
// declares SetProcessMitigationPolicy for Windows 8+ targets, hence this
// fallback; TestDynamicCodePolicy needs decltype(&SetProcessMitigationPolicy).
#if (_WIN32_WINNT < 0x0602)
BOOL WINAPI
SetProcessMitigationPolicy(PROCESS_MITIGATION_POLICY aMitigationPolicy,
                           PVOID aBuffer, SIZE_T aBufferLen);
#endif  // (_WIN32_WINNT < 0x0602)
62 
63 using namespace mozilla;
64 
65 struct payload {
66   UINT64 a;
67   UINT64 b;
68   UINT64 c;
69 
operator ==payload70   bool operator==(const payload& other) const {
71     return (a == other.a && b == other.b && c == other.c);
72   }
73 };
74 
// Exported, never-inlined hook target. It rotates the fields left by one
// (a <- b, b <- c, c <- old a) and returns the result by value.
// NOTE: the interceptor patches the machine code the compiler emits for this
// body. Even behavior-preserving source edits here can change what the hook
// tests exercise.
extern "C" __declspec(dllexport) __declspec(noinline) payload
    rotatePayload(payload p) {
  UINT64 tmp = p.a;
  p.a = p.b;
  p.b = p.c;
  p.c = tmp;
  return p;
}
83 
// payloadNotHooked is a target function for tests that expect a negative
// result. For example, TestDynamicCodePolicy expects hooking it to fail. We
// cannot use rotatePayload for that purpose because the interceptor cannot
// hook a function that has already been detoured. Please keep this function
// always unhooked.
extern "C" __declspec(dllexport) __declspec(noinline) payload
    payloadNotHooked(payload p) {
  // Do something different from rotatePayload to avoid ICF (identical code
  // folding), which could otherwise merge both functions into one address.
  p.a ^= p.b;
  p.b ^= p.c;
  p.c ^= p.a;
  return p;
}
95 
// Set to true by every InterceptorFunction thunk and by patched_rotatePayload.
// TestFunction clears it before calling a hooked function and checks it
// afterwards to confirm that the hook actually ran.
static bool patched_func_called = false;

// Hook object for rotatePayload. Calling it forwards to the original
// implementation (see patched_rotatePayload).
static WindowsDllInterceptor::FuncHookType<decltype(&rotatePayload)>
    orig_rotatePayload;

// Hook object for payloadNotHooked. TestDynamicCodePolicy expects installing
// it to fail.
static WindowsDllInterceptor::FuncHookType<decltype(&payloadNotHooked)>
    orig_payloadNotHooked;
103 
patched_rotatePayload(payload p)104 static payload patched_rotatePayload(payload p) {
105   patched_func_called = true;
106   return orig_rotatePayload(p);
107 }
108 
109 // Invoke aFunc by taking aArg's contents and using them as aFunc's arguments
110 template <typename OrigFuncT, typename... Args,
111           typename ArgTuple = Tuple<Args...>, size_t... Indices>
Apply(OrigFuncT & aFunc,ArgTuple && aArgs,std::index_sequence<Indices...>)112 decltype(auto) Apply(OrigFuncT& aFunc, ArgTuple&& aArgs,
113                      std::index_sequence<Indices...>) {
114   return aFunc(Get<Indices>(std::forward<ArgTuple>(aArgs))...);
115 }
116 
// Defines, for one calling convention, the TestFunction overloads that
// CheckHook dispatches to. Each overload does the following:
//   1. Builds a Tuple of the hooked function's parameter types from aArgs.
//      NOTE(review): with no aArgs this default-constructs the tuple, which
//      the TEST_HOOK documentation relies on to mean all-zero arguments.
//   2. Clears patched_func_called.
//   3. Calls aFunc through Apply.
//   4. Reports success only if the interceptor thunk ran:
//        - non-void R: aPred(return value) && patched_func_called
//        - void R:     patched_func_called (the predicate is never evaluated)
// Comments inside the macro body must use /* */, because a // comment would
// swallow the line continuation and comment out the following line.
#define DEFINE_TEST_FUNCTION(calling_convention)                               \
  template <typename R, typename... Args, typename... TestArgs>                \
  bool TestFunction(R(calling_convention* aFunc)(Args...), bool (*aPred)(R),   \
                    TestArgs&&... aArgs) {                                     \
    using ArgTuple = Tuple<Args...>;                                           \
    using Indices = std::index_sequence_for<Args...>;                          \
    ArgTuple fakeArgs{std::forward<TestArgs>(aArgs)...};                       \
    patched_func_called = false;                                               \
    return aPred(Apply(aFunc, std::forward<ArgTuple>(fakeArgs), Indices())) && \
           patched_func_called;                                                \
  }                                                                            \
                                                                               \
  /* Overload for void-returning functions; the predicate is not evaluated */ \
  template <typename PredT, typename... Args, typename... TestArgs>            \
  bool TestFunction(void(calling_convention * aFunc)(Args...), PredT,          \
                    TestArgs&&... aArgs) {                                     \
    using ArgTuple = Tuple<Args...>;                                           \
    using Indices = std::index_sequence_for<Args...>;                          \
    ArgTuple fakeArgs{std::forward<TestArgs>(aArgs)...};                       \
    patched_func_called = false;                                               \
    Apply(aFunc, std::forward<ArgTuple>(fakeArgs), Indices());                 \
    return patched_func_called;                                                \
  }
140 
// Instantiate TestFunction for the default calling convention. C++11 allows
// empty macro arguments, and clang accepts DEFINE_TEST_FUNCTION(). MSVC also
// does the right thing but emits warning C4003, so it gets an explicit __cdecl.
#if defined(_MSC_VER) && !defined(__clang__)
DEFINE_TEST_FUNCTION(__cdecl)
#else
DEFINE_TEST_FUNCTION()
#endif

// On 32-bit x86, __stdcall and __fastcall function pointers are distinct
// types, so they need their own overloads.
#ifdef _M_IX86
DEFINE_TEST_FUNCTION(__stdcall)
DEFINE_TEST_FUNCTION(__fastcall)
#endif  // _M_IX86
153 
// Test the hooked function against the supplied predicate.
//
// Calls aOrigFunc (the hooked export) through TestFunction with aArgs and
// prints one TEST-PASS / TEST-FAILED line naming aFuncName and aDllName.
// Returns true only if the interceptor thunk ran and aPred accepted the
// result. For void functions the predicate is ignored.
template <typename OrigFuncT, typename PredicateT, typename... Args>
bool CheckHook(OrigFuncT& aOrigFunc, const char* aDllName,
               const char* aFuncName, PredicateT&& aPred, Args&&... aArgs) {
  const bool passed = TestFunction(aOrigFunc, std::forward<PredicateT>(aPred),
                                   std::forward<Args>(aArgs)...);
  if (passed) {
    printf(
        "TEST-PASS | WindowsDllInterceptor | "
        "Executed hooked function %s from %s\n",
        aFuncName, aDllName);
  } else {
    printf(
        "TEST-FAILED | WindowsDllInterceptor | "
        "Failed to execute hooked function %s from %s\n",
        aFuncName, aDllName);
  }
  // Flush on both paths, as everywhere else in this file, so the result line
  // is not lost if a later test crashes the process.
  fflush(stdout);
  return passed;
}
173 
174 struct InterceptorFunction {
175   static const size_t EXEC_MEMBLOCK_SIZE = 64 * 1024;  // 64K
176 
CreateInterceptorFunction177   static InterceptorFunction& Create() {
178     // Make sure the executable memory is allocated
179     if (!sBlock) {
180       Init();
181     }
182     MOZ_ASSERT(sBlock);
183 
184     // Make sure we aren't making more functions than we allocated room for
185     MOZ_RELEASE_ASSERT((sNumInstances + 1) * sizeof(InterceptorFunction) <=
186                        EXEC_MEMBLOCK_SIZE);
187 
188     // Grab the next InterceptorFunction from executable memory
189     InterceptorFunction& ret = *reinterpret_cast<InterceptorFunction*>(
190         sBlock + (sNumInstances++ * sizeof(InterceptorFunction)));
191 
192     // Set the InterceptorFunction to the code template.
193     auto funcCode = &ret[0];
194     memcpy(funcCode, sInterceptorTemplate, TemplateLength);
195 
196     // Fill in the patched_func_called pointer in the template.
197     auto pfPtr = reinterpret_cast<bool**>(&ret[PatchedFuncCalledIndex]);
198     *pfPtr = &patched_func_called;
199     return ret;
200   }
201 
operator []InterceptorFunction202   uint8_t& operator[](size_t i) { return mFuncCode[i]; }
203 
GetFunctionInterceptorFunction204   uint8_t* GetFunction() { return mFuncCode; }
205 
SetStubInterceptorFunction206   void SetStub(uintptr_t aStub) {
207     auto pfPtr = reinterpret_cast<uintptr_t*>(&mFuncCode[StubFuncIndex]);
208     *pfPtr = aStub;
209   }
210 
211  private:
212   // We intercept functions with short machine-code functions that set a boolean
213   // and run the stub that launches the original function.  Each entry in the
214   // array is the code for one of those interceptor functions.  We cannot
215   // free this memory until the test shuts down.
216   // The templates have spots for the address of patched_func_called
217   // and for the address of the stub function.  Their indices in the byte
218   // array are given as constants below and they appear as blocks of
219   // 0xff bytes in the templates.
220 #if defined(_M_X64)
221   //  0: 48 b8 ff ff ff ff ff ff ff ff    movabs rax, &patched_func_called
222   //  a: c6 00 01                         mov    BYTE PTR [rax],0x1
223   //  d: 48 b8 ff ff ff ff ff ff ff ff    movabs rax, &stub_func_ptr
224   // 17: ff e0                            jmp    rax
225   static constexpr uint8_t sInterceptorTemplate[] = {
226       0x48, 0xB8, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
227       0xFF, 0xC6, 0x00, 0x01, 0x48, 0xB8, 0xFF, 0xFF, 0xFF,
228       0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xE0};
229   static const size_t PatchedFuncCalledIndex = 0x2;
230   static const size_t StubFuncIndex = 0xf;
231 #elif defined(_M_IX86)
232   // 0: c6 05 ff ff ff ff 01     mov    BYTE PTR &patched_func_called, 0x1
233   // 7: 68 ff ff ff ff           push   &stub_func_ptr
234   // c: c3                       ret
235   static constexpr uint8_t sInterceptorTemplate[] = {
236       0xC6, 0x05, 0xFF, 0xFF, 0xFF, 0xFF, 0x01,
237       0x68, 0xFF, 0xFF, 0xFF, 0xFF, 0xC3};
238   static const size_t PatchedFuncCalledIndex = 0x2;
239   static const size_t StubFuncIndex = 0x8;
240 #elif defined(_M_ARM64)
241   //  0: 31 00 80 52    movz w17, #0x1
242   //  4: 90 00 00 58    ldr  x16, #16
243   //  8: 11 02 00 39    strb w17, [x16]
244   //  c: 90 00 00 58    ldr  x16, #16
245   // 10: 00 02 1F D6    br   x16
246   // 14: &patched_func_called
247   // 1c: &stub_func_ptr
248   static constexpr uint8_t sInterceptorTemplate[] = {
249       0x31, 0x00, 0x80, 0x52, 0x90, 0x00, 0x00, 0x58, 0x11, 0x02, 0x00, 0x39,
250       0x90, 0x00, 0x00, 0x58, 0x00, 0x02, 0x1F, 0xD6, 0xFF, 0xFF, 0xFF, 0xFF,
251       0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF};
252   static const size_t PatchedFuncCalledIndex = 0x14;
253   static const size_t StubFuncIndex = 0x1c;
254 #else
255 #  error "Missing template for architecture"
256 #endif
257 
258   static const size_t TemplateLength = sizeof(sInterceptorTemplate);
259   uint8_t mFuncCode[TemplateLength];
260 
261   InterceptorFunction() = delete;
262   InterceptorFunction(const InterceptorFunction&) = delete;
263   InterceptorFunction& operator=(const InterceptorFunction&) = delete;
264 
InitInterceptorFunction265   static void Init() {
266     MOZ_ASSERT(!sBlock);
267     sBlock = reinterpret_cast<uint8_t*>(
268         ::VirtualAlloc(nullptr, EXEC_MEMBLOCK_SIZE, MEM_RESERVE | MEM_COMMIT,
269                        PAGE_EXECUTE_READWRITE));
270   }
271 
272   static uint8_t* sBlock;
273   static size_t sNumInstances;
274 };
275 
// Out-of-line definitions of InterceptorFunction's static state: the lazily
// allocated executable block, and the number of thunks handed out from it.
uint8_t* InterceptorFunction::sBlock = nullptr;
size_t InterceptorFunction::sNumInstances = 0;

// Namespace-scope definition of the constexpr template array. It is needed
// before C++17 (where static constexpr members are implicitly inline) because
// Create() odr-uses the array by passing it to memcpy.
constexpr uint8_t InterceptorFunction::sInterceptorTemplate[];
280 
281 // Hook the function and optionally attempt calling it
282 template <typename OrigFuncT, size_t N, typename PredicateT, typename... Args>
TestHook(const char (& dll)[N],const char * func,PredicateT && aPred,Args &&...aArgs)283 bool TestHook(const char (&dll)[N], const char* func, PredicateT&& aPred,
284               Args&&... aArgs) {
285   auto orig_func(
286       mozilla::MakeUnique<WindowsDllInterceptor::FuncHookType<OrigFuncT>>());
287 
288   bool successful = false;
289   WindowsDllInterceptor TestIntercept;
290   TestIntercept.Init(dll);
291 
292   InterceptorFunction& interceptorFunc = InterceptorFunction::Create();
293   successful = orig_func->Set(
294       TestIntercept, func,
295       reinterpret_cast<OrigFuncT>(interceptorFunc.GetFunction()));
296 
297   if (successful) {
298     interceptorFunc.SetStub(reinterpret_cast<uintptr_t>(orig_func->GetStub()));
299     printf("TEST-PASS | WindowsDllInterceptor | Could hook %s from %s\n", func,
300            dll);
301     fflush(stdout);
302     if (!aPred) {
303       printf(
304           "TEST-SKIPPED | WindowsDllInterceptor | "
305           "Will not attempt to execute patched %s.\n",
306           func);
307       fflush(stdout);
308       return true;
309     }
310 
311     // Test the DLL function we just hooked.
312     HMODULE module = ::LoadLibrary(dll);
313     FARPROC funcAddr = ::GetProcAddress(module, func);
314     if (!funcAddr) {
315       return false;
316     }
317 
318     return CheckHook(reinterpret_cast<OrigFuncT&>(funcAddr), dll, func,
319                      std::forward<PredicateT>(aPred),
320                      std::forward<Args>(aArgs)...);
321   } else {
322     printf(
323         "TEST-UNEXPECTED-FAIL | WindowsDllInterceptor | Failed to hook %s from "
324         "%s\n",
325         func, dll);
326     fflush(stdout);
327 
328     // Print out the function's bytes so that we can easily analyze the error.
329     nsModuleHandle mod(::LoadLibrary(dll));
330     FARPROC funcAddr = ::GetProcAddress(mod, func);
331     if (funcAddr) {
332       const uint32_t kNumBytesToDump =
333           WindowsDllInterceptor::GetWorstCaseRequiredBytesToPatch();
334 
335       printf("\tFirst %u bytes of function:\n\t", kNumBytesToDump);
336 
337       auto code = reinterpret_cast<const uint8_t*>(funcAddr);
338       for (uint32_t i = 0; i < kNumBytesToDump; ++i) {
339         char suffix = (i < (kNumBytesToDump - 1)) ? ' ' : '\n';
340         printf("%02hhX%c", code[i], suffix);
341       }
342 
343       fflush(stdout);
344     }
345     return false;
346   }
347 }
348 
349 // Detour the function and optionally attempt calling it
350 template <typename OrigFuncT, size_t N, typename PredicateT>
TestDetour(const char (& dll)[N],const char * func,PredicateT && aPred)351 bool TestDetour(const char (&dll)[N], const char* func, PredicateT&& aPred) {
352   auto orig_func(
353       mozilla::MakeUnique<WindowsDllInterceptor::FuncHookType<OrigFuncT>>());
354 
355   bool successful = false;
356   WindowsDllInterceptor TestIntercept;
357   TestIntercept.Init(dll);
358 
359   InterceptorFunction& interceptorFunc = InterceptorFunction::Create();
360   successful = orig_func->Set(
361       TestIntercept, func,
362       reinterpret_cast<OrigFuncT>(interceptorFunc.GetFunction()));
363 
364   if (successful) {
365     interceptorFunc.SetStub(reinterpret_cast<uintptr_t>(orig_func->GetStub()));
366     printf("TEST-PASS | WindowsDllInterceptor | Could detour %s from %s\n",
367            func, dll);
368     fflush(stdout);
369     if (!aPred) {
370       printf(
371           "TEST-SKIPPED | WindowsDllInterceptor | "
372           "Will not attempt to execute patched %s.\n",
373           func);
374       fflush(stdout);
375       return true;
376     }
377 
378     // Test the DLL function we just hooked.
379     HMODULE module = ::LoadLibrary(dll);
380     FARPROC funcAddr = ::GetProcAddress(module, func);
381     if (!funcAddr) {
382       return false;
383     }
384 
385     return CheckHook(reinterpret_cast<OrigFuncT&>(funcAddr), dll, func,
386                      std::forward<PredicateT>(aPred));
387   } else {
388     printf(
389         "TEST-UNEXPECTED-FAIL | WindowsDllInterceptor | Failed to detour %s "
390         "from %s\n",
391         func, dll);
392     fflush(stdout);
393     return false;
394   }
395 }
396 
// SubstituteForVoidPtr<FuncT>::Type is FuncT itself, except for one case: a
// function-pointer type returning void* becomes the same signature (same
// parameters and calling convention) returning uintptr_t. This is for the
// purposes of predicates.
template <typename FuncT>
struct SubstituteForVoidPtr {
  using Type = FuncT;
};

// Each void*-returning case reuses the primary template instantiated for the
// matching uintptr_t-returning signature.
template <typename... Args>
struct SubstituteForVoidPtr<void* (*)(Args...)>
    : SubstituteForVoidPtr<uintptr_t (*)(Args...)> {};

#ifdef _M_IX86
template <typename... Args>
struct SubstituteForVoidPtr<void*(__stdcall*)(Args...)>
    : SubstituteForVoidPtr<uintptr_t(__stdcall*)(Args...)> {};

template <typename... Args>
struct SubstituteForVoidPtr<void*(__fastcall*)(Args...)>
    : SubstituteForVoidPtr<uintptr_t(__fastcall*)(Args...)> {};
#endif  // _M_IX86
420 
// ReturnType<FuncT>::Type names the return type of the function-pointer type
// FuncT. Only function pointers are supported: the primary template is
// deliberately left undefined.
template <typename FuncT>
struct ReturnType;

template <typename R, typename... Args>
struct ReturnType<R (*)(Args...)> {
  typedef R Type;
};

#ifdef _M_IX86
// The x86-only calling conventions report their return type exactly like the
// default-convention case above.
template <typename R, typename... Args>
struct ReturnType<R(__stdcall*)(Args...)> : ReturnType<R (*)(Args...)> {};

template <typename R, typename... Args>
struct ReturnType<R(__fastcall*)(Args...)> : ReturnType<R (*)(Args...)> {};
#endif  // _M_IX86
441 
442 // Predicates that may be supplied during tests
443 template <typename FuncT>
444 struct Predicates {
445   using ArgType = typename ReturnType<FuncT>::Type;
446 
447   template <ArgType CompVal>
EqualsPredicates448   static bool Equals(ArgType aValue) {
449     return CompVal == aValue;
450   }
451 
452   template <ArgType CompVal>
NotEqualsPredicates453   static bool NotEquals(ArgType aValue) {
454     return CompVal != aValue;
455   }
456 
457   template <ArgType CompVal>
IgnorePredicates458   static bool Ignore(ArgType aValue) {
459     return true;
460   }
461 };
462 
463 // Functions that return void should be ignored, so we specialize the
464 // Ignore predicate for that case. Use nullptr as the value to compare against.
465 template <typename... Args>
466 struct Predicates<void (*)(Args...)> {
467   template <nullptr_t DummyVal>
IgnorePredicates468   static bool Ignore() {
469     return true;
470   }
471 };
472 
473 #ifdef _M_IX86
474 template <typename... Args>
475 struct Predicates<void(__stdcall*)(Args...)> {
476   template <nullptr_t DummyVal>
IgnorePredicates477   static bool Ignore() {
478     return true;
479   }
480 };
481 
482 template <typename... Args>
483 struct Predicates<void(__fastcall*)(Args...)> {
484   template <nullptr_t DummyVal>
IgnorePredicates485   static bool Ignore() {
486     return true;
487   }
488 };
489 #endif  // _M_IX86
490 
// The standard test. Hook |func|, and then try executing it with all zero
// arguments, using |pred| and |comp| to determine whether the call
// successfully executed. In general, you want to set pred and comp such that
// they return true when the function returns whatever value is expected with
// all-zero arguments.
//
// Note: When |func| returns void, you must supply |Ignore| and |nullptr| as the
// |pred| and |comp| arguments, respectively.
#define TEST_HOOK(dll, func, pred, comp) \
  TestHook<decltype(&func)>(dll, #func,  \
                            &Predicates<decltype(&func)>::pred<comp>)

// We need to special-case functions that return INVALID_HANDLE_VALUE
// (e.g., CreateFile). Our template machinery for comparing values doesn't work
// with integer constants passed as pointers (well, it works on MSVC, but not
// clang, because that is not standard-compliant). Such functions are therefore
// tested through a signature whose void* return is replaced by uintptr_t (see
// SubstituteForVoidPtr), and the result is compared against uintptr_t(-1).
#define TEST_HOOK_FOR_INVALID_HANDLE_VALUE(dll, func)                   \
  TestHook<SubstituteForVoidPtr<decltype(&func)>::Type>(                \
      dll, #func,                                                       \
      &Predicates<SubstituteForVoidPtr<decltype(&func)>::Type>::Equals< \
          uintptr_t(-1)>)

// This variant allows you to explicitly supply arguments to the hooked function
// during testing. You want to provide arguments that make the function return
// a value that is accepted by your predicate.
#define TEST_HOOK_PARAMS(dll, func, pred, comp, ...) \
  TestHook<decltype(&func)>(                         \
      dll, #func, &Predicates<decltype(&func)>::pred<comp>, __VA_ARGS__)

// This is for cases when we want to hook |func|, but it is unsafe to attempt
// to execute the function in the context of a test. It passes a null
// predicate, which makes TestHook stop right after installing the hook.
#define TEST_HOOK_SKIP_EXEC(dll, func)                                        \
  TestHook<decltype(&func)>(                                                  \
      dll, #func,                                                             \
      reinterpret_cast<bool (*)(typename ReturnType<decltype(&func)>::Type)>( \
          NULL))

// The following three variants are identical to the previous macros, except
// that they forcibly use a Detour on 32-bit Windows. On 64-bit Windows, these
// macros are identical to their TEST_HOOK variants.
#define TEST_DETOUR(dll, func, pred, comp) \
  TestDetour<decltype(&func)>(dll, #func,  \
                              &Predicates<decltype(&func)>::pred<comp>)

// NOTE: forwards __VA_ARGS__ to TestDetour, so TestDetour must accept trailing
// arguments for the hooked function.
#define TEST_DETOUR_PARAMS(dll, func, pred, comp, ...) \
  TestDetour<decltype(&func)>(                         \
      dll, #func, &Predicates<decltype(&func)>::pred<comp>, __VA_ARGS__)

#define TEST_DETOUR_SKIP_EXEC(dll, func)                                      \
  TestDetour<decltype(&func)>(                                                \
      dll, #func,                                                             \
      reinterpret_cast<bool (*)(typename ReturnType<decltype(&func)>::Type)>( \
          NULL))
545 
546 template <typename OrigFuncT, size_t N, typename PredicateT, typename... Args>
MaybeTestHook(const bool cond,const char (& dll)[N],const char * func,PredicateT && aPred,Args &&...aArgs)547 bool MaybeTestHook(const bool cond, const char (&dll)[N], const char* func,
548                    PredicateT&& aPred, Args&&... aArgs) {
549   if (!cond) {
550     printf(
551         "TEST-SKIPPED | WindowsDllInterceptor | Skipped hook test for %s from "
552         "%s\n",
553         func, dll);
554     fflush(stdout);
555     return true;
556   }
557 
558   return TestHook<OrigFuncT>(dll, func, std::forward<PredicateT>(aPred),
559                              std::forward<Args>(aArgs)...);
560 }
561 
// Like TEST_HOOK, but the test is only executed when cond is true.
#define MAYBE_TEST_HOOK(cond, dll, func, pred, comp) \
  MaybeTestHook<decltype(&func)>(cond, dll, #func,   \
                                 &Predicates<decltype(&func)>::pred<comp>)

// Like TEST_HOOK_PARAMS, but the test is only executed when cond is true.
#define MAYBE_TEST_HOOK_PARAMS(cond, dll, func, pred, comp, ...) \
  MaybeTestHook<decltype(&func)>(                                \
      cond, dll, #func, &Predicates<decltype(&func)>::pred<comp>, __VA_ARGS__)

// Like TEST_HOOK_SKIP_EXEC, but the hook is only attempted when cond is true.
#define MAYBE_TEST_HOOK_SKIP_EXEC(cond, dll, func)                            \
  MaybeTestHook<decltype(&func)>(                                             \
      cond, dll, #func,                                                       \
      reinterpret_cast<bool (*)(typename ReturnType<decltype(&func)>::Type)>( \
          NULL))
576 
ShouldTestTipTsf()577 bool ShouldTestTipTsf() {
578   if (!IsWin8OrLater()) {
579     return false;
580   }
581 
582   mozilla::DynamicallyLinkedFunctionPtr<decltype(&SHGetKnownFolderPath)>
583       pSHGetKnownFolderPath(L"shell32.dll", "SHGetKnownFolderPath");
584   if (!pSHGetKnownFolderPath) {
585     return false;
586   }
587 
588   PWSTR commonFilesPath = nullptr;
589   if (FAILED(pSHGetKnownFolderPath(FOLDERID_ProgramFilesCommon, 0, nullptr,
590                                    &commonFilesPath))) {
591     return false;
592   }
593 
594   wchar_t fullPath[MAX_PATH + 1] = {};
595   wcscpy(fullPath, commonFilesPath);
596   wcscat(fullPath, L"\\Microsoft Shared\\Ink\\tiptsf.dll");
597   CoTaskMemFree(commonFilesPath);
598 
599   if (!LoadLibraryW(fullPath)) {
600     return false;
601   }
602 
603   // Leak the module so that it's loaded for the interceptor test
604   return true;
605 }
606 
// Backing storage for an empty UNICODE_STRING. gEmptyUnicodeString is
// initialized from gEmptyUnicodeStringLiteral by HasApiSetQueryApiSetPresence.
// NOTE(review): gEmptyUnicodeString and gIsPresent are presumably the arguments
// passed to ApiSetQueryApiSetPresence by the hook test in wmain; confirm there.
static const wchar_t gEmptyUnicodeStringLiteral[] = L"";
static UNICODE_STRING gEmptyUnicodeString;
static BOOLEAN gIsPresent;
610 
HasApiSetQueryApiSetPresence()611 bool HasApiSetQueryApiSetPresence() {
612   mozilla::DynamicallyLinkedFunctionPtr<decltype(&ApiSetQueryApiSetPresence)>
613       func(L"Api-ms-win-core-apiquery-l1-1-0.dll", "ApiSetQueryApiSetPresence");
614   if (!func) {
615     return false;
616   }
617 
618   // Prepare gEmptyUnicodeString for the test
619   ::RtlInitUnicodeString(&gEmptyUnicodeString, gEmptyUnicodeStringLiteral);
620 
621   return true;
622 }
623 
// Set this to true to test function unhooking. When enabled, TestShortDetour
// also calls NtMapViewOfSection with null arguments after its ShortInterceptor
// has gone out of scope, and fails if that call succeeds.
const bool ShouldTestUnhookFunction = false;
626 
#if defined(_M_X64) || defined(_M_ARM64)

// Use VMSharingPolicyUnique for the ShortInterceptor, as it needs to
// reserve its trampoline memory in a special location.
using ShortInterceptor = mozilla::interceptor::WindowsDllInterceptor<
    mozilla::interceptor::VMSharingPolicyUnique<
        mozilla::interceptor::MMPolicyInProcess>>;

// Hook object for ntdll!NtMapViewOfSection, installed by TestShortDetour.
static ShortInterceptor::FuncHookType<decltype(&::NtMapViewOfSection)>
    orig_NtMapViewOfSection;

#endif  // defined(_M_X64) || defined(_M_ARM64)
639 
// On x64/ARM64, checks that ntdll!NtMapViewOfSection can be detoured when the
// interceptor is forced to use its short (10-byte) patch. The function is
// routed through an InterceptorFunction thunk and called via CheckHook.
// When ShouldTestUnhookFunction is set, the function is called again after the
// interceptor has been destroyed. Always returns true on other architectures.
bool TestShortDetour() {
#if defined(_M_X64) || defined(_M_ARM64)
  auto pNtMapViewOfSection = reinterpret_cast<decltype(&::NtMapViewOfSection)>(
      ::GetProcAddress(::GetModuleHandleW(L"ntdll.dll"), "NtMapViewOfSection"));
  if (!pNtMapViewOfSection) {
    printf(
        "TEST-FAILED | WindowsDllInterceptor | "
        "Failed to resolve ntdll!NtMapViewOfSection\n");
    fflush(stdout);
    return false;
  }

  {  // Scope for shortInterceptor
    // The interceptor lives in its own scope so that it is destroyed before
    // the unhook check below runs.
    ShortInterceptor shortInterceptor;
    shortInterceptor.TestOnlyDetourInit(
        L"ntdll.dll",
        mozilla::interceptor::DetourFlags::eTestOnlyForceShortPatch);

    InterceptorFunction& interceptorFunc = InterceptorFunction::Create();
    if (!orig_NtMapViewOfSection.SetDetour(
            shortInterceptor, "NtMapViewOfSection",
            reinterpret_cast<decltype(&::NtMapViewOfSection)>(
                interceptorFunc.GetFunction()))) {
      printf(
          "TEST-FAILED | WindowsDllInterceptor | "
          "Failed to hook ntdll!NtMapViewOfSection via 10-byte patch\n");
      fflush(stdout);
      return false;
    }

    // Point the thunk at the trampoline back to the original function.
    interceptorFunc.SetStub(
        reinterpret_cast<uintptr_t>(orig_NtMapViewOfSection.GetStub()));

    // The NTSTATUS of a call with all-zero arguments is irrelevant here; only
    // whether the hook fired matters.
    auto pred =
        &Predicates<decltype(&::NtMapViewOfSection)>::Ignore<((NTSTATUS)0)>;

    if (!CheckHook(pNtMapViewOfSection, "ntdll.dll", "NtMapViewOfSection",
                   pred)) {
      // CheckHook has already printed the error message for us
      return false;
    }
  }

  // Now ensure that our hook cleanup worked
  if (ShouldTestUnhookFunction) {
    NTSTATUS status =
        pNtMapViewOfSection(nullptr, nullptr, nullptr, 0, 0, nullptr, nullptr,
                            ((SECTION_INHERIT)0), 0, 0);
    if (NT_SUCCESS(status)) {
      printf(
          "TEST-FAILED | WindowsDllInterceptor | "
          "Unexpected successful call to ntdll!NtMapViewOfSection after "
          "removing short-patched hook\n");
      fflush(stdout);
      return false;
    }

    printf(
        "TEST-PASS | WindowsDllInterceptor | "
        "Successfully unhooked ntdll!NtMapViewOfSection via short patch\n");
    fflush(stdout);
  }

  return true;
#else
  return true;
#endif
}
708 
709 template <typename InterceptorType>
TestAssemblyFunctions()710 bool TestAssemblyFunctions() {
711   constexpr uintptr_t NoStubAddressCheck = 0;
712   struct TestCase {
713     const char* functionName;
714     uintptr_t expectedStub;
715     explicit TestCase(const char* aFunctionName, uintptr_t aExpectedStub)
716         : functionName(aFunctionName), expectedStub(aExpectedStub) {}
717   } testCases[] = {
718 #if defined(__clang__)
719 // We disable these testcases because the code coverage instrumentation injects
720 // code in a way that WindowsDllInterceptor doesn't understand.
721 #  ifndef MOZ_CODE_COVERAGE
722 #    if defined(_M_X64)
723     // Since we have PatchIfTargetIsRecognizedTrampoline for x64, we expect the
724     // original jump destination is returned as a stub.
725     TestCase("MovPushRet", JumpDestination),
726     TestCase("MovRaxJump", JumpDestination),
727 #    elif defined(_M_IX86)
728     // Skip the stub address check as we always generate a trampoline for x86.
729     TestCase("PushRet", NoStubAddressCheck),
730     TestCase("MovEaxJump", NoStubAddressCheck),
731     TestCase("Opcode83", NoStubAddressCheck),
732     TestCase("LockPrefix", NoStubAddressCheck),
733     TestCase("LooksLikeLockPrefix", NoStubAddressCheck),
734 #    endif
735 #  endif  // MOZ_CODE_COVERAGE
736 #endif    // defined(__clang__)
737   };
738 
739   static const auto patchedFunction = []() { patched_func_called = true; };
740 
741   InterceptorType interceptor;
742   interceptor.Init("TestDllInterceptor.exe");
743 
744   for (const auto& testCase : testCases) {
745     typename InterceptorType::template FuncHookType<void (*)()> hook;
746     bool result = hook.Set(interceptor, testCase.functionName, patchedFunction);
747     if (!result) {
748       printf(
749           "TEST-FAILED | WindowsDllInterceptor | "
750           "Failed to detour %s.\n",
751           testCase.functionName);
752       return false;
753     }
754 
755     const auto actualStub = reinterpret_cast<uintptr_t>(hook.GetStub());
756     if (testCase.expectedStub != NoStubAddressCheck &&
757         actualStub != testCase.expectedStub) {
758       printf(
759           "TEST-FAILED | WindowsDllInterceptor | "
760           "Wrong stub was backed up for %s: %zx\n",
761           testCase.functionName, actualStub);
762       return false;
763     }
764 
765     patched_func_called = false;
766 
767     auto originalFunction = reinterpret_cast<void (*)()>(
768         GetProcAddress(GetModuleHandle(nullptr), testCase.functionName));
769     originalFunction();
770 
771     if (!patched_func_called) {
772       printf(
773           "TEST-FAILED | WindowsDllInterceptor | "
774           "Hook from %s was not called\n",
775           testCase.functionName);
776       return false;
777     }
778 
779     printf("TEST-PASS | WindowsDllInterceptor | %s\n", testCase.functionName);
780   }
781 
782   return true;
783 }
784 
TestDynamicCodePolicy()785 bool TestDynamicCodePolicy() {
786   if (!IsWin8Point1OrLater()) {
787     // Skip if a platform does not support this policy.
788     return true;
789   }
790 
791   PROCESS_MITIGATION_DYNAMIC_CODE_POLICY policy = {};
792   policy.ProhibitDynamicCode = true;
793 
794   mozilla::DynamicallyLinkedFunctionPtr<decltype(&SetProcessMitigationPolicy)>
795       pSetProcessMitigationPolicy(L"kernel32.dll",
796                                   "SetProcessMitigationPolicy");
797   if (!pSetProcessMitigationPolicy) {
798     printf(
799         "TEST-UNEXPECTED-FAIL | WindowsDllInterceptor | "
800         "SetProcessMitigationPolicy does not exist.\n");
801     fflush(stdout);
802     return false;
803   }
804 
805   if (!pSetProcessMitigationPolicy(ProcessDynamicCodePolicy, &policy,
806                                    sizeof(policy))) {
807     printf(
808         "TEST-UNEXPECTED-FAIL | WindowsDllInterceptor | "
809         "Fail to enable ProcessDynamicCodePolicy.\n");
810     fflush(stdout);
811     return false;
812   }
813 
814   WindowsDllInterceptor ExeIntercept;
815   ExeIntercept.Init("TestDllInterceptor.exe");
816 
817   // Make sure we fail to hook a function if ProcessDynamicCodePolicy is on
818   // because we cannot create an executable trampoline region.
819   if (orig_payloadNotHooked.Set(ExeIntercept, "payloadNotHooked",
820                                 &patched_rotatePayload)) {
821     printf(
822         "TEST-UNEXPECTED-FAIL | WindowsDllInterceptor | "
823         "ProcessDynamicCodePolicy is not working.\n");
824     fflush(stdout);
825     return false;
826   }
827 
828   printf(
829       "TEST-PASS | WindowsDllInterceptor | "
830       "Successfully passed TestDynamicCodePolicy.\n");
831   fflush(stdout);
832   return true;
833 }
834 
wmain(int argc,wchar_t * argv[])835 extern "C" int wmain(int argc, wchar_t* argv[]) {
836   LARGE_INTEGER start;
837   QueryPerformanceCounter(&start);
838 
839   // We disable this part of the test because the code coverage instrumentation
840   // injects code in rotatePayload in a way that WindowsDllInterceptor doesn't
841   // understand.
842 #ifndef MOZ_CODE_COVERAGE
843   payload initial = {0x12345678, 0xfc4e9d31, 0x87654321};
844   payload p0, p1;
845   ZeroMemory(&p0, sizeof(p0));
846   ZeroMemory(&p1, sizeof(p1));
847 
848   p0 = rotatePayload(initial);
849 
850   {
851     WindowsDllInterceptor ExeIntercept;
852     ExeIntercept.Init("TestDllInterceptor.exe");
853     if (orig_rotatePayload.Set(ExeIntercept, "rotatePayload",
854                                &patched_rotatePayload)) {
855       printf("TEST-PASS | WindowsDllInterceptor | Hook added\n");
856       fflush(stdout);
857     } else {
858       printf(
859           "TEST-UNEXPECTED-FAIL | WindowsDllInterceptor | Failed to add "
860           "hook\n");
861       fflush(stdout);
862       return 1;
863     }
864 
865     p1 = rotatePayload(initial);
866 
867     if (patched_func_called) {
868       printf("TEST-PASS | WindowsDllInterceptor | Hook called\n");
869       fflush(stdout);
870     } else {
871       printf(
872           "TEST-UNEXPECTED-FAIL | WindowsDllInterceptor | Hook was not "
873           "called\n");
874       fflush(stdout);
875       return 1;
876     }
877 
878     if (p0 == p1) {
879       printf("TEST-PASS | WindowsDllInterceptor | Hook works properly\n");
880       fflush(stdout);
881     } else {
882       printf(
883           "TEST-UNEXPECTED-FAIL | WindowsDllInterceptor | Hook didn't return "
884           "the right information\n");
885       fflush(stdout);
886       return 1;
887     }
888   }
889 
890   patched_func_called = false;
891   ZeroMemory(&p1, sizeof(p1));
892 
893   p1 = rotatePayload(initial);
894 
895   if (ShouldTestUnhookFunction != patched_func_called) {
896     printf(
897         "TEST-PASS | WindowsDllInterceptor | Hook was %scalled after "
898         "unregistration\n",
899         ShouldTestUnhookFunction ? "not " : "");
900     fflush(stdout);
901   } else {
902     printf(
903         "TEST-UNEXPECTED-FAIL | WindowsDllInterceptor | Hook was %scalled "
904         "after unregistration\n",
905         ShouldTestUnhookFunction ? "" : "not ");
906     fflush(stdout);
907     return 1;
908   }
909 
910   if (p0 == p1) {
911     printf(
912         "TEST-PASS | WindowsDllInterceptor | Original function worked "
913         "properly\n");
914     fflush(stdout);
915   } else {
916     printf(
917         "TEST-UNEXPECTED-FAIL | WindowsDllInterceptor | Original function "
918         "didn't return the right information\n");
919     fflush(stdout);
920     return 1;
921   }
922 #endif
923 
924   CredHandle credHandle;
925   memset(&credHandle, 0, sizeof(CredHandle));
926   OBJECT_ATTRIBUTES attributes = {};
927 
928   // NB: These tests should be ordered such that lower-level APIs are tested
929   // before higher-level APIs.
930   if (TestShortDetour() &&
931   // Run <ShortInterceptor> first because <WindowsDllInterceptor>
932   // does not clean up hooks.
933 #if defined(_M_X64)
934       TestAssemblyFunctions<ShortInterceptor>() &&
935 #endif
936       TestAssemblyFunctions<WindowsDllInterceptor>() &&
937 #ifdef _M_IX86
938       // We keep this test to hook complex code on x86. (Bug 850957)
939       TEST_HOOK("ntdll.dll", NtFlushBuffersFile, NotEquals, 0) &&
940 #endif
941       TEST_HOOK("ntdll.dll", NtCreateFile, NotEquals, 0) &&
942       TEST_HOOK("ntdll.dll", NtReadFile, NotEquals, 0) &&
943       TEST_HOOK("ntdll.dll", NtReadFileScatter, NotEquals, 0) &&
944       TEST_HOOK("ntdll.dll", NtWriteFile, NotEquals, 0) &&
945       TEST_HOOK("ntdll.dll", NtWriteFileGather, NotEquals, 0) &&
946       TEST_HOOK_PARAMS("ntdll.dll", NtQueryFullAttributesFile, NotEquals, 0,
947                        &attributes, nullptr) &&
948       TEST_DETOUR_SKIP_EXEC("ntdll.dll", LdrLoadDll) &&
949       TEST_HOOK("ntdll.dll", LdrUnloadDll, NotEquals, 0) &&
950       MAYBE_TEST_HOOK_SKIP_EXEC(IsWin8OrLater(), "ntdll.dll",
951                                 LdrResolveDelayLoadedAPI) &&
952       MAYBE_TEST_HOOK_PARAMS(HasApiSetQueryApiSetPresence(),
953                              "Api-ms-win-core-apiquery-l1-1-0.dll",
954                              ApiSetQueryApiSetPresence, Equals, FALSE,
955                              &gEmptyUnicodeString, &gIsPresent) &&
956       TEST_HOOK("kernelbase.dll", QueryDosDeviceW, Equals, 0) &&
957 #if !defined(_M_ARM64)
958 #  ifndef MOZ_ASAN
959       // Bug 733892: toolkit/crashreporter/nsExceptionHandler.cpp
960       // This fails on ASan because the ASan runtime already hooked this
961       // function
962       TEST_HOOK("kernel32.dll", SetUnhandledExceptionFilter, Ignore, nullptr) &&
963 #  endif
964 #endif  // !defined(_M_ARM64)
965 #ifdef _M_IX86
966       TEST_HOOK_FOR_INVALID_HANDLE_VALUE("kernel32.dll", CreateFileW) &&
967 #endif
968 #if !defined(_M_ARM64)
969       TEST_HOOK_FOR_INVALID_HANDLE_VALUE("kernel32.dll", CreateFileA) &&
970 #endif  // !defined(_M_ARM64)
971 #if !defined(_M_ARM64)
972       TEST_HOOK("kernel32.dll", TlsAlloc, NotEquals, TLS_OUT_OF_INDEXES) &&
973       TEST_HOOK_PARAMS("kernel32.dll", TlsFree, Equals, FALSE,
974                        TLS_OUT_OF_INDEXES) &&
975       TEST_HOOK("kernel32.dll", CloseHandle, Equals, FALSE) &&
976       TEST_HOOK("kernel32.dll", DuplicateHandle, Equals, FALSE) &&
977 #endif  // !defined(_M_ARM64)
978       TEST_DETOUR_SKIP_EXEC("kernel32.dll", BaseThreadInitThunk) &&
979 #if defined(_M_X64) || defined(_M_ARM64)
980       MAYBE_TEST_HOOK(!IsWin8OrLater(), "kernel32.dll",
981                       RtlInstallFunctionTableCallback, Equals, FALSE) &&
982       TEST_HOOK("user32.dll", GetKeyState, Ignore, 0) &&  // see Bug 1316415
983 #endif
984       TEST_HOOK("user32.dll", GetWindowInfo, Equals, FALSE) &&
985       TEST_HOOK("user32.dll", TrackPopupMenu, Equals, FALSE) &&
986       TEST_DETOUR("user32.dll", CreateWindowExW, Equals, nullptr) &&
987       TEST_HOOK("user32.dll", InSendMessageEx, Equals, ISMEX_NOSEND) &&
988       TEST_HOOK("user32.dll", SendMessageTimeoutW, Equals, 0) &&
989       TEST_HOOK("user32.dll", SetCursorPos, NotEquals, FALSE) &&
990 #if !defined(_M_ARM64)
991       TEST_HOOK("imm32.dll", ImmGetContext, Equals, nullptr) &&
992 #endif  // !defined(_M_ARM64)
993       TEST_HOOK("imm32.dll", ImmGetCompositionStringW, Ignore, 0) &&
994       TEST_HOOK_SKIP_EXEC("imm32.dll", ImmSetCandidateWindow) &&
995       TEST_HOOK("imm32.dll", ImmNotifyIME, Equals, 0) &&
996       TEST_HOOK("comdlg32.dll", GetSaveFileNameW, Ignore, FALSE) &&
997       TEST_HOOK("comdlg32.dll", GetOpenFileNameW, Ignore, FALSE) &&
998 #if defined(_M_X64)
999       TEST_HOOK("comdlg32.dll", PrintDlgW, Ignore, 0) &&
1000 #endif
1001       MAYBE_TEST_HOOK(ShouldTestTipTsf(), "tiptsf.dll", ProcessCaretEvents,
1002                       Ignore, nullptr) &&
1003       TEST_HOOK("wininet.dll", InternetOpenA, NotEquals, nullptr) &&
1004       TEST_HOOK("wininet.dll", InternetCloseHandle, Equals, FALSE) &&
1005       TEST_HOOK("wininet.dll", InternetConnectA, Equals, nullptr) &&
1006       TEST_HOOK("wininet.dll", InternetQueryDataAvailable, Equals, FALSE) &&
1007       TEST_HOOK("wininet.dll", InternetReadFile, Equals, FALSE) &&
1008       TEST_HOOK("wininet.dll", InternetWriteFile, Equals, FALSE) &&
1009       TEST_HOOK("wininet.dll", InternetSetOptionA, Equals, FALSE) &&
1010       TEST_HOOK("wininet.dll", HttpAddRequestHeadersA, Equals, FALSE) &&
1011       TEST_HOOK("wininet.dll", HttpOpenRequestA, Equals, nullptr) &&
1012       TEST_HOOK("wininet.dll", HttpQueryInfoA, Equals, FALSE) &&
1013       TEST_HOOK("wininet.dll", HttpSendRequestA, Equals, FALSE) &&
1014       TEST_HOOK("wininet.dll", HttpSendRequestExA, Equals, FALSE) &&
1015       TEST_HOOK("wininet.dll", HttpEndRequestA, Equals, FALSE) &&
1016       TEST_HOOK("wininet.dll", InternetQueryOptionA, Equals, FALSE) &&
1017       TEST_HOOK("sspicli.dll", AcquireCredentialsHandleA, NotEquals,
1018                 SEC_E_OK) &&
1019       TEST_HOOK_PARAMS("sspicli.dll", QueryCredentialsAttributesA, Equals,
1020                        SEC_E_INVALID_HANDLE, &credHandle, 0, nullptr) &&
1021       TEST_HOOK_PARAMS("sspicli.dll", FreeCredentialsHandle, Equals,
1022                        SEC_E_INVALID_HANDLE, &credHandle) &&
1023       // Run TestDynamicCodePolicy() at the end because the policy is
1024       // irreversible.
1025       TestDynamicCodePolicy()) {
1026     printf("TEST-PASS | WindowsDllInterceptor | all checks passed\n");
1027 
1028     LARGE_INTEGER end, freq;
1029     QueryPerformanceCounter(&end);
1030 
1031     QueryPerformanceFrequency(&freq);
1032 
1033     LARGE_INTEGER result;
1034     result.QuadPart = end.QuadPart - start.QuadPart;
1035     result.QuadPart *= 1000000;
1036     result.QuadPart /= freq.QuadPart;
1037 
1038     printf("Elapsed time: %lld microseconds\n", result.QuadPart);
1039 
1040     return 0;
1041   }
1042 
1043   return 1;
1044 }
1045