1 /*
2  * PROJECT:     ReactOS kernel-mode tests
3  * LICENSE:     LGPL-2.1+ (https://spdx.org/licenses/LGPL-2.1+)
4  * PURPOSE:     Kernel-Mode Test Suite loader application
5  * COPYRIGHT:   Copyright 2011-2018 Thomas Faber <thomas.faber@reactos.org>
6  */
7 
8 #define KMT_DEFINE_TEST_FUNCTIONS
9 #include <kmt_test.h>
10 
11 #include "kmtest.h"
12 #include <kmt_public.h>
13 
14 #include <assert.h>
15 #include <stdio.h>
16 #include <stdlib.h>
17 
// Service name, driver image file name and description of the kmtest driver
// service; main passes these to the Kmt*Service helpers.
#define SERVICE_NAME        L"Kmtest"
#define SERVICE_PATH        L"kmtest_drv.sys"
#define SERVICE_DESCRIPTION L"ReactOS Kernel-Mode Test Suite Driver"

// Size in bytes of the result buffer RunTest allocates and registers with the driver
#define RESULTBUFFER_SIZE   (1024 * 1024)
23 
/* Action selected by main from its command line.
 * KMT_DO_NOTHING must remain zero: main tests "if (Operation)" to decide
 * whether the driver has to be started and opened at all. */
typedef enum
{
    KMT_DO_NOTHING = 0,     /* a service command (create/delete/start/stop) was handled */
    KMT_LIST_TESTS = 1,     /* "--list", or no arguments given */
    KMT_LIST_ALL_TESTS = 2, /* "--list-all": also show hidden tests */
    KMT_RUN_TEST = 3        /* any other argument names a test to run */
} KMT_OPERATION;
31 
// Handle to the kmtest driver device (KMTEST_DEVICE_PATH). Opened in main,
// used for the IOCTLs in ListTests and RunTest, closed in main's cleanup.
// Non-static: presumably also referenced from other files of the loader
// (e.g. via kmtest.h) — TODO confirm before changing its linkage.
HANDLE KmtestHandle;
// Service handle the Kmt*Service helpers operate on for the kmtest service
SC_HANDLE KmtestServiceHandle;
// Location text OutputError prefixes to its messages. "No error" is what is
// printed if nothing updated it; presumably set by the error()/error_goto()
// macros — TODO confirm.
PCSTR ErrorFileAndLine = "No error";

// Forward declarations for the static helpers and the entry point below
static void OutputError(IN DWORD Error);
static DWORD ListTests(IN BOOLEAN IncludeHidden);
static PKMT_TESTFUNC FindTest(IN PCSTR TestName);
static DWORD OutputResult(IN PCSTR TestName);
static DWORD RunTest(IN PCSTR TestName);
int __cdecl main(int ArgCount, char **Arguments);
42 
43 /**
44  * @name OutputError
45  *
46  * Output an error message to the console.
47  *
48  * @param Error
49  *        Win32 error code
50  */
51 static
52 void
53 OutputError(
54     IN DWORD Error)
55 {
56     PSTR Message;
57     if (!FormatMessageA(FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS | FORMAT_MESSAGE_ALLOCATE_BUFFER,
58                    NULL, Error, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), (LPSTR)&Message, 0, NULL))
59     {
60         fprintf(stderr, "%s: Could not retrieve error message (error 0x%08lx). Original error: 0x%08lx\n",
61             ErrorFileAndLine, GetLastError(), Error);
62         return;
63     }
64 
65     fprintf(stderr, "%s: error 0x%08lx: %s\n", ErrorFileAndLine, Error, Message);
66 
67     LocalFree(Message);
68 }
69 
70 /**
71  * @name CompareTestNames
72  *
73  * strcmp that skips a leading '-' on either string if present
74  *
75  * @param Str1
76  * @param Str2
77  * @return see strcmp
78  */
79 static
80 INT
81 CompareTestNames(
82     IN PCSTR Str1,
83     IN PCSTR Str2)
84 {
85     if (*Str1 == '-')
86         ++Str1;
87     if (*Str2 == '-')
88         ++Str2;
89     while (*Str1 && *Str1 == *Str2)
90     {
91         ++Str1;
92         ++Str2;
93     }
94     return *Str1 - *Str2;
95 }
96 
97 /**
98  * @name ListTests
99  *
100  * Output the list of tests to the console.
101  * The list will comprise tests as listed by the driver
102  * in addition to user-mode tests in TestList.
103  *
104  * @param IncludeHidden
105  *        TRUE to include "hidden" tests prefixed with a '-'
106  *
107  * @return Win32 error code
108  */
109 static
110 DWORD
111 ListTests(
112     IN BOOLEAN IncludeHidden)
113 {
114     DWORD Error = ERROR_SUCCESS;
115     CHAR Buffer[1024];
116     DWORD BytesRead;
117     PCSTR TestName = Buffer;
118     PCKMT_TEST TestEntry = TestList;
119     PCSTR NextTestName;
120 
121     puts("Valid test names:");
122 
123     // get test list from driver
124     if (!DeviceIoControl(KmtestHandle, IOCTL_KMTEST_GET_TESTS, NULL, 0, Buffer, sizeof Buffer, &BytesRead, NULL))
125         error_goto(Error, cleanup);
126 
127     // output test list plus user-mode tests
128     while (TestEntry->TestName || *TestName)
129     {
130         if (!TestEntry->TestName)
131         {
132             NextTestName = TestName;
133             TestName += strlen(TestName) + 1;
134         }
135         else if (!*TestName)
136         {
137             NextTestName = TestEntry->TestName;
138             ++TestEntry;
139         }
140         else
141         {
142             INT Result = CompareTestNames(TestEntry->TestName, TestName);
143 
144             if (Result == 0)
145             {
146                 NextTestName = TestEntry->TestName;
147                 TestName += strlen(TestName) + 1;
148                 ++TestEntry;
149             }
150             else if (Result < 0)
151             {
152                 NextTestName = TestEntry->TestName;
153                 ++TestEntry;
154             }
155             else
156             {
157                 NextTestName = TestName;
158                 TestName += strlen(TestName) + 1;
159             }
160         }
161 
162         if (IncludeHidden && NextTestName[0] == '-')
163             ++NextTestName;
164 
165         if (NextTestName[0] != '-')
166             printf("    %s\n", NextTestName);
167     }
168 
169 cleanup:
170     return Error;
171 }
172 
173 /**
174  * @name FindTest
175  *
176  * Find a test in TestList by name.
177  *
178  * @param TestName
179  *        Name of the test to look for. Case sensitive
180  *
181  * @return pointer to test function, or NULL if not found
182  */
183 static
184 PKMT_TESTFUNC
185 FindTest(
186     IN PCSTR TestName)
187 {
188     PCKMT_TEST TestEntry = TestList;
189 
190     for (TestEntry = TestList; TestEntry->TestName; ++TestEntry)
191     {
192         PCSTR TestEntryName = TestEntry->TestName;
193 
194         // skip leading '-' if present
195         if (*TestEntryName == '-')
196             ++TestEntryName;
197 
198         if (!lstrcmpA(TestEntryName, TestName))
199             break;
200     }
201 
202     return TestEntry->TestFunction;
203 }
204 
205 /**
206  * @name OutputResult
207  *
208  * Output the test results in ResultBuffer to the console.
209  *
210  * @param TestName
211  *        Name of the test whose result is to be printed
212  *
213  * @return Win32 error code
214  */
215 static
216 DWORD
217 OutputResult(
218     IN PCSTR TestName)
219 {
220     DWORD Error = ERROR_SUCCESS;
221     DWORD BytesWritten;
222     DWORD LogBufferLength;
223     DWORD Offset = 0;
224     /* A console window can't handle a single
225      * huge block of data, so split it up */
226     const DWORD BlockSize = 8 * 1024;
227 
228     KmtFinishTest(TestName);
229 
230     LogBufferLength = ResultBuffer->LogBufferLength;
231     for (Offset = 0; Offset < LogBufferLength; Offset += BlockSize)
232     {
233         DWORD Length = min(LogBufferLength - Offset, BlockSize);
234         if (!WriteFile(GetStdHandle(STD_OUTPUT_HANDLE), ResultBuffer->LogBuffer + Offset, Length, &BytesWritten, NULL))
235             error(Error);
236     }
237 
238     return Error;
239 }
240 
241 /**
242  * @name RunTest
243  *
244  * Run the named test and output its results.
245  *
246  * @param TestName
247  *        Name of the test to run. Case sensitive
248  *
249  * @return Win32 error code
250  */
251 static
252 DWORD
253 RunTest(
254     IN PCSTR TestName)
255 {
256     DWORD Error = ERROR_SUCCESS;
257     PKMT_TESTFUNC TestFunction;
258     DWORD BytesRead;
259 
260     assert(TestName != NULL);
261 
262     if (!ResultBuffer)
263     {
264         ResultBuffer = KmtAllocateResultBuffer(RESULTBUFFER_SIZE);
265         if (!ResultBuffer)
266             error_goto(Error, cleanup);
267         if (!DeviceIoControl(KmtestHandle, IOCTL_KMTEST_SET_RESULTBUFFER, ResultBuffer, RESULTBUFFER_SIZE, NULL, 0, &BytesRead, NULL))
268             error_goto(Error, cleanup);
269     }
270 
271     // check test list
272     TestFunction = FindTest(TestName);
273 
274     if (TestFunction)
275     {
276         TestFunction();
277         goto cleanup;
278     }
279 
280     // not found in user-mode test list, call driver
281     Error = KmtRunKernelTest(TestName);
282 
283 cleanup:
284     if (!Error)
285         Error = OutputResult(TestName);
286 
287     return Error;
288 }
289 
290 /**
291  * @name main
292  *
293  * Program entry point
294  *
295  * @param ArgCount
296  * @param Arguments
297  *
298  * @return EXIT_SUCCESS on success, EXIT_FAILURE on failure
299  */
300 int
301 main(
302     int ArgCount,
303     char **Arguments)
304 {
305     INT Status = EXIT_SUCCESS;
306     DWORD Error = ERROR_SUCCESS;
307     PCSTR AppName = "kmtest.exe";
308     PCSTR TestName = NULL;
309     KMT_OPERATION Operation = KMT_DO_NOTHING;
310     BOOLEAN ShowHidden = FALSE;
311 
312     Error = KmtServiceInit();
313     if (Error)
314         goto cleanup;
315 
316     if (ArgCount >= 1)
317         AppName = Arguments[0];
318 
319     if (ArgCount <= 1)
320     {
321         printf("Usage: %s <test_name>                 - run the specified test (creates/starts the driver(s) as appropriate)\n", AppName);
322         printf("       %s --list                      - list available tests\n", AppName);
323         printf("       %s --list-all                  - list available tests, including hidden\n", AppName);
324         printf("       %s <create|delete|start|stop>  - manage the kmtest driver\n\n", AppName);
325         Operation = KMT_LIST_TESTS;
326     }
327     else
328     {
329         TestName = Arguments[1];
330         if (!lstrcmpA(TestName, "create"))
331             Error = KmtCreateService(SERVICE_NAME, SERVICE_PATH, SERVICE_DESCRIPTION, &KmtestServiceHandle);
332         else if (!lstrcmpA(TestName, "delete"))
333             Error = KmtDeleteService(SERVICE_NAME, &KmtestServiceHandle);
334         else if (!lstrcmpA(TestName, "start"))
335             Error = KmtStartService(SERVICE_NAME, &KmtestServiceHandle);
336         else if (!lstrcmpA(TestName, "stop"))
337             Error = KmtStopService(SERVICE_NAME, &KmtestServiceHandle);
338 
339         else if (!lstrcmpA(TestName, "--list"))
340             Operation = KMT_LIST_TESTS;
341         else if (!lstrcmpA(TestName, "--list-all"))
342             Operation = KMT_LIST_ALL_TESTS;
343         else
344             Operation = KMT_RUN_TEST;
345     }
346 
347     if (Operation)
348     {
349         Error = KmtCreateAndStartService(SERVICE_NAME, SERVICE_PATH, SERVICE_DESCRIPTION, &KmtestServiceHandle, FALSE);
350         if (Error)
351             goto cleanup;
352 
353         KmtestHandle = CreateFile(KMTEST_DEVICE_PATH, GENERIC_READ | GENERIC_WRITE, 0, NULL, OPEN_EXISTING, 0, NULL);
354         if (KmtestHandle == INVALID_HANDLE_VALUE)
355             error_goto(Error, cleanup);
356 
357         switch (Operation)
358         {
359             case KMT_LIST_ALL_TESTS:
360                 ShowHidden = TRUE;
361                 /* fall through */
362             case KMT_LIST_TESTS:
363                 Error = ListTests(ShowHidden);
364                 break;
365             case KMT_RUN_TEST:
366                 Error = RunTest(TestName);
367                 break;
368             default:
369                 assert(FALSE);
370         }
371     }
372 
373 cleanup:
374     if (KmtestHandle)
375         CloseHandle(KmtestHandle);
376 
377     if (ResultBuffer)
378         KmtFreeResultBuffer(ResultBuffer);
379 
380     KmtCloseService(&KmtestServiceHandle);
381 
382     if (Error)
383         KmtServiceCleanup(TRUE);
384     else
385         Error = KmtServiceCleanup(FALSE);
386 
387     if (Error)
388     {
389         OutputError(Error);
390 
391         Status = EXIT_FAILURE;
392     }
393 
394     return Status;
395 }
396