1 #pragma once
2 
3 #include <windows.h>
4 #include <PathCch.h>
5 
6 #include "catch.hpp"
7 
8 #include <wil/filesystem.h>
9 #include <wil/result.h>
10 
// REPORTS_ERROR(expr) evaluates `expr` inside a lambda and returns true if it reported an error. When `expr` has type
// HRESULT, "error" means a failure HRESULT. Otherwise it means `expr` threw a C++ exception (only detected when
// WIL_ENABLE_EXCEPTIONS is defined) or raised a structured exception / WIL fail fast (see witest::ReportsError).
#define REPORTS_ERROR(expr) witest::ReportsError(wistd::is_same<HRESULT, decltype(expr)>{}, [&]() { return expr; })
#define REQUIRE_ERROR(expr) REQUIRE(REPORTS_ERROR(expr))
#define REQUIRE_NOERROR(expr) REQUIRE_FALSE(REPORTS_ERROR(expr))

// CRASHES(expr) returns true if evaluating `expr` raised a structured (non-C++) exception, including a WIL fail fast,
// which witest::DoesCodeCrash redirects into a catchable exception for the duration of the call.
#define CRASHES(expr) witest::DoesCodeCrash([&]() { return expr; })
#define REQUIRE_CRASH(expr) REQUIRE(CRASHES(expr))
#define REQUIRE_NOCRASH(expr) REQUIRE_FALSE(CRASHES(expr))

// NOTE: SUCCEEDED/FAILED macros not used here since Catch2 can give us better diagnostics if it knows the HRESULT value
#define REQUIRE_SUCCEEDED(expr) REQUIRE((HRESULT)(expr) >= 0)
#define REQUIRE_FAILED(expr) REQUIRE((HRESULT)(expr) < 0)
22 
// MACRO double evaluation check.
// The following macro illustrates a common problem with writing macros:
//      #define MY_MAX(a, b) (((a) > (b)) ? (a) : (b))
// The issue is that whichever argument is selected is evaluated twice: once in the comparison and again as the result.
// This isn't harmful for constant numerics, but consider this example:
//      MY_MAX(4, InterlockedIncrement(&cCount))
// This evaluates the (b) parameter twice and results in incrementing the counter twice.
// We use MDEC in unit tests to verify that this kind of pattern is not present. A test
// of this kind:
//      MY_MAX(MDEC(4), MDEC(InterlockedIncrement(&cCount)))
// will verify that the parameters are not evaluated more than once.
//
// MDEC(PARAM) calls witest::details::MacroDoubleEvaluationCheck with the current __LINE__ and the stringized PARAM,
// then yields PARAM itself. The check remembers the most recent 15 (line, text) pairs and fails the test on a repeat.
// As a result, the same MDEC() legitimately evaluated twice (e.g. in a loop, or inside a helper that is called twice)
// is also reported.
#define MDEC(PARAM) (witest::details::MacroDoubleEvaluationCheck(__LINE__, #PARAM), PARAM)
35 
// Some functionality needed for testing is not declared for the app partition. Since those tests are primarily
// compilation tests, declare what's needed here.
extern "C" {

#if !WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_DESKTOP | WINAPI_PARTITION_SYSTEM | WINAPI_PARTITION_GAMES)
// Vectored exception handler APIs used by witest::DoesCodeCrash. Outside the desktop/system/games partitions the
// SDK headers don't provide these declarations, so they are declared manually.
WINBASEAPI _Ret_maybenull_
PVOID WINAPI AddVectoredExceptionHandler(_In_ ULONG First, _In_ PVECTORED_EXCEPTION_HANDLER Handler);

WINBASEAPI
ULONG WINAPI RemoveVectoredExceptionHandler(_In_ PVOID Handle);
#endif

}
49 
50 #pragma warning(push)
51 #pragma warning(disable: 4702) // Unreachable code
52 
53 namespace witest
54 {
55     namespace details
56     {
MacroDoubleEvaluationCheck(size_t uLine,_In_ const char * pszCode)57         inline void MacroDoubleEvaluationCheck(size_t uLine, _In_ const char* pszCode)
58         {
59             struct SEval
60             {
61                 size_t uLine;
62                 const char* pszCode;
63             };
64 
65             static SEval rgEval[15] = {};
66             static size_t nOffset = 0;
67 
68             for (auto& eval : rgEval)
69             {
70                 if ((eval.uLine == uLine) && (eval.pszCode != nullptr) && (0 == strcmp(pszCode, eval.pszCode)))
71                 {
72                     // This verification indicates that macro-double-evaluation check is firing for a particular usage of MDEC().
73                     FAIL("Expression '" << pszCode << "' double evaluated in macro on line " << uLine);
74                 }
75             }
76 
77             rgEval[nOffset].uLine = uLine;
78             rgEval[nOffset].pszCode = pszCode;
79             nOffset = (nOffset + 1) % ARRAYSIZE(rgEval);
80         }
81 
82         template <typename T>
83         class AssignTemporaryValueCleanup
84         {
85         public:
86             AssignTemporaryValueCleanup(_In_ AssignTemporaryValueCleanup const &) = delete;
87             AssignTemporaryValueCleanup & operator=(_In_ AssignTemporaryValueCleanup const &) = delete;
88 
AssignTemporaryValueCleanup(_Inout_ T * pVal,T val)89             explicit AssignTemporaryValueCleanup(_Inout_ T *pVal, T val) WI_NOEXCEPT :
90                 m_pVal(pVal),
91                 m_valOld(*pVal)
92             {
93                 *pVal = val;
94             }
95 
AssignTemporaryValueCleanup(_Inout_ AssignTemporaryValueCleanup && other)96             AssignTemporaryValueCleanup(_Inout_ AssignTemporaryValueCleanup && other) WI_NOEXCEPT :
97                 m_pVal(other.m_pVal),
98                 m_valOld(other.m_valOld)
99             {
100                 other.m_pVal = nullptr;
101             }
102 
~AssignTemporaryValueCleanup()103             ~AssignTemporaryValueCleanup() WI_NOEXCEPT
104             {
105                 operator()();
106             }
107 
operator()108             void operator()() WI_NOEXCEPT
109             {
110                 if (m_pVal != nullptr)
111                 {
112                     *m_pVal = m_valOld;
113                     m_pVal = nullptr;
114                 }
115             }
116 
Dismiss()117             void Dismiss() WI_NOEXCEPT
118             {
119                 m_pVal = nullptr;
120             }
121 
122         private:
123             T *m_pVal;
124             T m_valOld;
125         };
126     }
127 
128     // Use the following routine to allow for a variable to be swapped with another and automatically revert the
129     // assignment at the end of the scope.
130     // Example:
131     //      int nFoo = 10
132     //      {
133     //          auto revert = witest::AssignTemporaryValue(&nFoo, 12);
134     //          // nFoo will now be 12 within this scope...
135     //      }
136     //      // and nFoo is back to 10 within the outer scope
137     template <typename T>
AssignTemporaryValue(_Inout_ T * pVal,T val)138     inline witest::details::AssignTemporaryValueCleanup<T> AssignTemporaryValue(_Inout_ T *pVal, T val) WI_NOEXCEPT
139     {
140         return witest::details::AssignTemporaryValueCleanup<T>(pVal, val);
141     }
142 
143     //! Global class which tracks objects that derive from @ref AllocatedObject.
144     //! Use `witest::g_objectCount.Leaked()` to determine if an object deriving from `AllocatedObject` has been leaked.
145     class GlobalCount
146     {
147     public:
148         int m_count = 0;
149 
150         //! Returns `true` if there are any objects that derive from @ref AllocatedObject still in memory.
Leaked()151         bool Leaked() const
152         {
153             return (m_count != 0);
154         }
155 
~GlobalCount()156         ~GlobalCount()
157         {
158             if (Leaked())
159             {
160                 // NOTE: This runs when no test is active, but will still cause an assert failure to notify
161                 FAIL("GlobalCount is non-zero; there is a leak somewhere");
162             }
163         }
164     };
165     __declspec(selectany) GlobalCount g_objectCount;
166 
167     //! Derive an allocated test object from witest::AllocatedObject to ensure that those objects aren't leaked in the test.
168     //! Note that you can call g_objectCount.Leaked() at any point to determine if a leak has already occurred (assuming that
169     //! all objects should have been destroyed at that point.
170     class AllocatedObject
171     {
172     public:
AllocatedObject()173         AllocatedObject()   { g_objectCount.m_count++; }
~AllocatedObject()174         ~AllocatedObject()  { g_objectCount.m_count--; }
175     };
176 
177     template <typename Lambda>
DoesCodeThrow(Lambda && callOp)178     bool DoesCodeThrow(Lambda&& callOp)
179     {
180 #ifdef WIL_ENABLE_EXCEPTIONS
181         try
182 #endif
183         {
184             callOp();
185         }
186 #ifdef WIL_ENABLE_EXCEPTIONS
187         catch (...)
188         {
189             return true;
190         }
191 #endif
192 
193         return false;
194     }
195 
196     [[noreturn]]
TranslateFailFastException(PEXCEPTION_RECORD rec,PCONTEXT,DWORD)197     inline void __stdcall TranslateFailFastException(PEXCEPTION_RECORD rec, PCONTEXT, DWORD)
198     {
199         // RaiseFailFastException cannot be continued or handled. By instead calling RaiseException, it allows us to
200         // handle exceptions
201         ::RaiseException(rec->ExceptionCode, rec->ExceptionFlags, rec->NumberParameters, rec->ExceptionInformation);
202 #ifdef __clang__
203         __builtin_unreachable();
204 #endif
205     }
206 
    // SEH exception code used by the MSVC runtime for a C++ `throw` (0xE0 followed by the ASCII bytes "msc"). It lets
    // C++ exceptions pass through untouched while other exception codes are treated as crashes.
    constexpr DWORD msvc_exception_code = 0xE06D7363;
208 
209     // This is a MAJOR hack. Catch2 registers a vectored exception handler - which gets run before our handler below -
210     // that interprets a set of exception codes as fatal. We don't want this behavior since we may be expecting such
211     // crashes, so instead translate all exception codes to something not fatal
TranslateExceptionCodeHandler(PEXCEPTION_POINTERS info)212     inline LONG WINAPI TranslateExceptionCodeHandler(PEXCEPTION_POINTERS info)
213     {
214         if (info->ExceptionRecord->ExceptionCode != witest::msvc_exception_code)
215         {
216             info->ExceptionRecord->ExceptionCode = STATUS_STACK_BUFFER_OVERRUN;
217         }
218 
219         return EXCEPTION_CONTINUE_SEARCH;
220     }
221 
222     namespace details
223     {
DoesCodeCrash(wistd::function<void ()> & callOp)224         inline bool DoesCodeCrash(wistd::function<void()>& callOp)
225         {
226             bool result = false;
227             __try
228             {
229                 callOp();
230             }
231             // Let C++ exceptions pass through
232             __except ((::GetExceptionCode() != msvc_exception_code) ? EXCEPTION_EXECUTE_HANDLER : EXCEPTION_CONTINUE_SEARCH)
233             {
234                 result = true;
235             }
236             return result;
237         }
238     }
239 
DoesCodeCrash(wistd::function<void ()> callOp)240     inline bool DoesCodeCrash(wistd::function<void()> callOp)
241     {
242         // See above; we don't want to actually fail fast, so make sure we raise a different exception instead
243         auto restoreHandler = AssignTemporaryValue(&wil::details::g_pfnRaiseFailFastException, TranslateFailFastException);
244 
245         auto handler = AddVectoredExceptionHandler(1, TranslateExceptionCodeHandler);
246         auto removeVectoredHandler = wil::scope_exit([&] { RemoveVectoredExceptionHandler(handler); });
247 
248         return details::DoesCodeCrash(callOp);
249     }
250 
251     template <typename Lambda>
ReportsError(wistd::false_type,Lambda && callOp)252     bool ReportsError(wistd::false_type, Lambda&& callOp)
253     {
254         bool doesThrow = false;
255         bool doesCrash = DoesCodeCrash([&]()
256         {
257             doesThrow = DoesCodeThrow(callOp);
258         });
259 
260         return doesThrow || doesCrash;
261     }
262 
263     template <typename Lambda>
ReportsError(wistd::true_type,Lambda && callOp)264     bool ReportsError(wistd::true_type, Lambda&& callOp)
265     {
266         return FAILED(callOp());
267     }
268 
269 #ifdef WIL_ENABLE_EXCEPTIONS
270     class TestFailureCache final :
271         public wil::details::IFailureCallback
272     {
273     public:
TestFailureCache()274         TestFailureCache() :
275             m_callbackHolder(this)
276         {
277         }
278 
clear()279         void clear()
280         {
281             m_failures.clear();
282         }
283 
size()284         size_t size() const
285         {
286             return m_failures.size();
287         }
288 
empty()289         bool empty() const
290         {
291             return m_failures.empty();
292         }
293 
294         const wil::FailureInfo& operator[](size_t pos) const
295         {
296             return m_failures.at(pos).GetFailureInfo();
297         }
298 
299         // IFailureCallback
NotifyFailure(wil::FailureInfo const & failure)300         bool NotifyFailure(wil::FailureInfo const & failure) WI_NOEXCEPT override
301         {
302             m_failures.emplace_back(failure);
303             return false;
304         }
305 
306     private:
307         std::vector<wil::StoredFailureInfo> m_failures;
308         wil::details::ThreadFailureCallbackHolder m_callbackHolder;
309     };
310 #endif
311 
GetTempFileName(wchar_t (& result)[MAX_PATH])312     inline HRESULT GetTempFileName(wchar_t (&result)[MAX_PATH])
313     {
314         wchar_t dir[MAX_PATH];
315         RETURN_LAST_ERROR_IF(::GetTempPathW(MAX_PATH, dir) == 0);
316         RETURN_LAST_ERROR_IF(::GetTempFileNameW(dir, L"wil", 0, result) == 0);
317         return S_OK;
318     }
319 
320     inline HRESULT CreateUniqueFolderPath(wchar_t (&buffer)[MAX_PATH], PCWSTR root = nullptr)
321     {
322         if (root)
323         {
324             RETURN_LAST_ERROR_IF(::GetTempFileNameW(root, L"wil", 0, buffer) == 0);
325         }
326         else
327         {
328             wchar_t tempPath[MAX_PATH];
329             RETURN_LAST_ERROR_IF(::GetTempPathW(ARRAYSIZE(tempPath), tempPath) == 0);
330             RETURN_LAST_ERROR_IF(::GetLongPathNameW(tempPath, tempPath, ARRAYSIZE(tempPath)) == 0);
331             RETURN_LAST_ERROR_IF(::GetTempFileNameW(tempPath, L"wil", 0, buffer) == 0);
332         }
333         RETURN_IF_WIN32_BOOL_FALSE(DeleteFileW(buffer));
334         PathCchRemoveExtension(buffer, ARRAYSIZE(buffer));
335         return S_OK;
336     }
337 
338 #if WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_DESKTOP)
339 
340     struct TestFolder
341     {
TestFolderTestFolder342         TestFolder()
343         {
344             if (SUCCEEDED(CreateUniqueFolderPath(m_path)) && SUCCEEDED(wil::CreateDirectoryDeepNoThrow(m_path)))
345             {
346                 m_valid = true;
347             }
348         }
349 
TestFolderTestFolder350         TestFolder(PCWSTR path)
351         {
352             if (SUCCEEDED(StringCchCopyW(m_path, ARRAYSIZE(m_path), path)) && SUCCEEDED(wil::CreateDirectoryDeepNoThrow(m_path)))
353             {
354                 m_valid = true;
355             }
356         }
357 
358         TestFolder(const TestFolder&) = delete;
359         TestFolder& operator=(const TestFolder&) = delete;
360 
TestFolderTestFolder361         TestFolder(TestFolder&& other)
362         {
363             if (other.m_valid)
364             {
365                 m_valid = true;
366                 other.m_valid = false;
367                 wcscpy_s(m_path, other.m_path);
368             }
369         }
370 
~TestFolderTestFolder371         ~TestFolder()
372         {
373             if (m_valid)
374             {
375                 wil::RemoveDirectoryRecursiveNoThrow(m_path);
376             }
377         }
378 
379         operator bool() const
380         {
381             return m_valid;
382         }
383 
PCWSTRTestFolder384         operator PCWSTR() const
385         {
386             return m_path;
387         }
388 
PathTestFolder389         PCWSTR Path() const
390         {
391             return m_path;
392         }
393 
394     private:
395 
396         bool m_valid = false;
397         wchar_t m_path[MAX_PATH] = L"";
398     };
399 
400     struct TestFile
401     {
TestFileTestFile402         TestFile(PCWSTR path)
403         {
404             if (SUCCEEDED(StringCchCopyW(m_path, ARRAYSIZE(m_path), path)))
405             {
406                 Create();
407             }
408         }
409 
TestFileTestFile410         TestFile(PCWSTR dirPath, PCWSTR fileName)
411         {
412             if (SUCCEEDED(StringCchCopyW(m_path, ARRAYSIZE(m_path), dirPath)) && SUCCEEDED(PathCchAppend(m_path, ARRAYSIZE(m_path), fileName)))
413             {
414                 Create();
415             }
416         }
417 
418         TestFile(const TestFile&) = delete;
419         TestFile& operator=(const TestFile&) = delete;
420 
TestFileTestFile421         TestFile(TestFile&& other)
422         {
423             if (other.m_valid)
424             {
425                 m_valid = true;
426                 m_deleteDir = other.m_deleteDir;
427                 other.m_valid = other.m_deleteDir = false;
428                 wcscpy_s(m_path, other.m_path);
429             }
430         }
431 
~TestFileTestFile432         ~TestFile()
433         {
434             // Best effort on all of these
435             if (m_valid)
436             {
437                 ::DeleteFileW(m_path);
438             }
439             if (m_deleteDir)
440             {
441                 size_t parentLength;
442                 if (wil::try_get_parent_path_range(m_path, &parentLength))
443                 {
444                     m_path[parentLength] = L'\0';
445                     ::RemoveDirectoryW(m_path);
446                     m_path[parentLength] = L'\\';
447                 }
448             }
449         }
450 
451         operator bool() const
452         {
453             return m_valid;
454         }
455 
PCWSTRTestFile456         operator PCWSTR() const
457         {
458             return m_path;
459         }
460 
PathTestFile461         PCWSTR Path() const
462         {
463             return m_path;
464         }
465 
466     private:
467 
CreateTestFile468         HRESULT Create()
469         {
470             WI_ASSERT(!m_valid && !m_deleteDir);
471             wil::unique_hfile fileHandle(::CreateFileW(m_path,
472                 FILE_WRITE_ATTRIBUTES,
473                 FILE_SHARE_READ | FILE_SHARE_WRITE | FILE_SHARE_DELETE, nullptr,
474                 CREATE_ALWAYS, FILE_ATTRIBUTE_NORMAL, nullptr));
475             if (!fileHandle)
476             {
477                 auto err = ::GetLastError();
478                 size_t parentLength;
479                 if ((err == ERROR_PATH_NOT_FOUND) && wil::try_get_parent_path_range(m_path, &parentLength))
480                 {
481                     m_path[parentLength] = L'\0';
482                     RETURN_IF_FAILED(wil::CreateDirectoryDeepNoThrow(m_path));
483                     m_deleteDir = true;
484 
485                     m_path[parentLength] = L'\\';
486                     fileHandle.reset(::CreateFileW(m_path,
487                         FILE_WRITE_ATTRIBUTES,
488                         FILE_SHARE_READ | FILE_SHARE_WRITE | FILE_SHARE_DELETE, nullptr,
489                         CREATE_ALWAYS, FILE_ATTRIBUTE_NORMAL, nullptr));
490                     RETURN_LAST_ERROR_IF(!fileHandle);
491                 }
492                 else
493                 {
494                     RETURN_WIN32(err);
495                 }
496             }
497 
498             m_valid = true;
499             return S_OK;
500         }
501 
502         bool m_valid = false;
503         bool m_deleteDir = false;
504         wchar_t m_path[MAX_PATH] = L"";
505     };
506 
507 #endif
508 }
509 
510 #pragma warning(pop)
511