1 /*
2  * PROJECT:         ReactOS kernel-mode tests
3  * LICENSE:         GPLv2+ - See COPYING in the top level directory
4  * PURPOSE:         File system mini-filter support routines
5  * PROGRAMMER:      Ged Murphy <gedmurphy@reactos.org>
6  */
7 
8 #include <kmt_test.h>
9 
10 #define KMT_FLT_USER_MODE
11 #include "kmtest.h"
12 #include <kmt_public.h>
13 
14 #include <assert.h>
15 #include <debug.h>
16 
/*
 * Prototypes for helpers this file calls before (or without) seeing a
 * definition:
 *  - KmtpCreateService is the internal service-creation function in the
 *    service.c file; KmtFltCreateService forwards to it with
 *    SERVICE_FILE_SYSTEM_DRIVER as the service type.
 *  - EnablePrivilegeInCurrentProcess is defined further down in this file
 *    and is used by KmtFltLoadDriver.
 */
DWORD
KmtpCreateService(
    IN PCWSTR ServiceName,
    IN PCWSTR ServicePath,
    IN PCWSTR DisplayName OPTIONAL,
    IN DWORD ServiceType,
    OUT SC_HANDLE *ServiceHandle);

DWORD EnablePrivilegeInCurrentProcess(
    _In_z_ LPWSTR lpPrivName,
    _In_ BOOL bEnable);
31 
/*
 * Header placed at the start of every message that
 * KmtFltSendBufferToDriver passes to KmtFltSendMessage.
 */
// TODO: move to a shared location
typedef struct _KMTFLT_MESSAGE_HEADER
{
    ULONG Message;      /* Message code supplied by the caller */
    PVOID Buffer;       /* When a payload is sent: address directly after this header */
    ULONG BufferSize;   /* When a payload is sent: payload size in bytes */

} KMTFLT_MESSAGE_HEADER, *PKMTFLT_MESSAGE_HEADER;
40 
/* Handle defined outside this file */
extern HANDLE KmtestHandle;
/*
 * Service name of the filter under test ("Kmtest-" followed by the test's
 * service name). Written by KmtFltCreateService and read by the load, unload,
 * connect and altitude helpers in this file.
 */
static WCHAR TestServiceName[MAX_PATH];
43 
44 
45 
46 /**
47  * @name KmtFltCreateService
48  *
49  * Create the specified driver service and return a handle to it
50  *
51  * @param ServiceName
52  *        Name of the service to create
53  * @param ServicePath
54  *        File name of the driver, relative to the current directory
55  * @param DisplayName
56  *        Service display name
57  * @param ServiceHandle
58  *        Pointer to a variable to receive the handle to the service
59  *
60  * @return Win32 error code
61  */
62 DWORD
KmtFltCreateService(_In_z_ PCWSTR ServiceName,_In_z_ PCWSTR DisplayName,_Out_ SC_HANDLE * ServiceHandle)63 KmtFltCreateService(
64     _In_z_ PCWSTR ServiceName,
65     _In_z_ PCWSTR DisplayName,
66     _Out_ SC_HANDLE *ServiceHandle)
67 {
68     WCHAR ServicePath[MAX_PATH];
69 
70     StringCbCopyW(ServicePath, sizeof(ServicePath), ServiceName);
71     StringCbCatW(ServicePath, sizeof(ServicePath), L"_drv.sys");
72 
73     StringCbCopyW(TestServiceName, sizeof(TestServiceName), L"Kmtest-");
74     StringCbCatW(TestServiceName, sizeof(TestServiceName), ServiceName);
75 
76     return KmtpCreateService(TestServiceName,
77                              ServicePath,
78                              DisplayName,
79                              SERVICE_FILE_SYSTEM_DRIVER,
80                              ServiceHandle);
81 }
82 
83 /**
84  * @name KmtFltDeleteService
85  *
86  * Delete the specified filter driver
87  *
88  * @param ServiceName
89  *        If *ServiceHandle is NULL, name of the service to delete
90  * @param ServiceHandle
91  *        Pointer to a variable containing the service handle.
92  *        Will be set to NULL on success
93  *
94  * @return Win32 error code
95  */
96 DWORD
KmtFltDeleteService(_In_opt_z_ PCWSTR ServiceName,_Inout_ SC_HANDLE * ServiceHandle)97 KmtFltDeleteService(
98     _In_opt_z_ PCWSTR ServiceName,
99     _Inout_ SC_HANDLE *ServiceHandle)
100 {
101     return KmtDeleteService(ServiceName, ServiceHandle);
102 }
103 
104 /**
105  * @name KmtFltLoadDriver
106  *
107  * Delete the specified filter driver
108  *
109  * @return Win32 error code
110  */
111 DWORD
KmtFltLoadDriver(_In_ BOOLEAN EnableDriverLoadPrivlege,_In_ BOOLEAN RestartIfRunning,_In_ BOOLEAN ConnectComms,_Out_ HANDLE * hPort)112 KmtFltLoadDriver(
113     _In_ BOOLEAN EnableDriverLoadPrivlege,
114     _In_ BOOLEAN RestartIfRunning,
115     _In_ BOOLEAN ConnectComms,
116     _Out_ HANDLE *hPort
117 )
118 {
119     DWORD Error;
120 
121     if (EnableDriverLoadPrivlege)
122     {
123         Error = EnablePrivilegeInCurrentProcess(SE_LOAD_DRIVER_NAME , TRUE);
124         if (Error)
125         {
126             return Error;
127         }
128     }
129 
130     Error = KmtFltLoad(TestServiceName);
131     if ((Error == ERROR_SERVICE_ALREADY_RUNNING) && RestartIfRunning)
132     {
133         Error = KmtFltUnload(TestServiceName);
134         if (Error)
135         {
136             // TODO
137             __debugbreak();
138         }
139 
140         Error = KmtFltLoad(TestServiceName);
141     }
142 
143     if (Error)
144     {
145         return Error;
146     }
147 
148     if (ConnectComms)
149     {
150         Error = KmtFltConnectComms(hPort);
151     }
152 
153     return Error;
154 }
155 
156 /**
157  * @name KmtFltUnloadDriver
158  *
159  * Unload the specified filter driver
160  *
161  * @param hPort
162  *        Handle to the filter's comms port
163  * @param DisonnectComms
164  *        TRUE to disconnect the comms connection before unloading
165  *
166  * @return Win32 error code
167  */
168 DWORD
KmtFltUnloadDriver(_In_ HANDLE * hPort,_In_ BOOLEAN DisonnectComms)169 KmtFltUnloadDriver(
170     _In_ HANDLE *hPort,
171     _In_ BOOLEAN DisonnectComms)
172 {
173     DWORD Error = ERROR_SUCCESS;
174 
175     if (DisonnectComms)
176     {
177         Error = KmtFltDisconnect(hPort);
178 
179         if (Error)
180         {
181             return Error;
182         }
183     }
184 
185     Error = KmtFltUnload(TestServiceName);
186 
187     if (Error)
188     {
189         // TODO
190         __debugbreak();
191     }
192 
193     return Error;
194 }
195 
196 /**
197  * @name KmtFltConnectComms
198  *
199  * Create a comms connection to the specified filter
200  *
201  * @param hPort
202  *        Handle to the filter's comms port
203  *
204  * @return Win32 error code
205  */
206 DWORD
KmtFltConnectComms(_Out_ HANDLE * hPort)207 KmtFltConnectComms(
208     _Out_ HANDLE *hPort)
209 {
210     return KmtFltConnect(TestServiceName, hPort);
211 }
212 
213 /**
214  * @name KmtFltDisconnectComms
215  *
216  * Disconenct from the comms port
217  *
218  * @param hPort
219  *        Handle to the filter's comms port
220  *
221  * @return Win32 error code
222  */
223 DWORD
KmtFltDisconnectComms(_In_ HANDLE hPort)224 KmtFltDisconnectComms(
225     _In_ HANDLE hPort)
226 {
227     return KmtFltDisconnect(hPort);
228 }
229 
230 
231 /**
232 * @name KmtFltCloseService
233 *
234 * Close the specified driver service handle
235 *
236 * @param ServiceHandle
237 *        Pointer to a variable containing the service handle.
238 *        Will be set to NULL on success
239 *
240 * @return Win32 error code
241 */
KmtFltCloseService(_Inout_ SC_HANDLE * ServiceHandle)242 DWORD KmtFltCloseService(
243     _Inout_ SC_HANDLE *ServiceHandle)
244 {
245     return KmtCloseService(ServiceHandle);
246 }
247 
248 /**
249 * @name KmtFltRunKernelTest
250 *
251 * Run the specified filter test part
252 *
253 * @param hPort
254 *        Handle to the filter's comms port
255 * @param TestName
256 *        Name of the test to run
257 *
258 * @return Win32 error code
259 */
260 DWORD
KmtFltRunKernelTest(_In_ HANDLE hPort,_In_z_ PCSTR TestName)261 KmtFltRunKernelTest(
262     _In_ HANDLE hPort,
263     _In_z_ PCSTR TestName)
264 {
265     return KmtFltSendStringToDriver(hPort, KMTFLT_RUN_TEST, TestName);
266 }
267 
268 /**
269 * @name KmtFltSendToDriver
270 *
271 * Send an I/O control message with no arguments to the driver opened with KmtOpenDriver
272 *
273 * @param hPort
274 *        Handle to the filter's comms port
275 * @param Message
276 *        The message to send to the filter
277 *
278 * @return Win32 error code
279 */
280 DWORD
KmtFltSendToDriver(_In_ HANDLE hPort,_In_ DWORD Message)281 KmtFltSendToDriver(
282     _In_ HANDLE hPort,
283     _In_ DWORD Message)
284 {
285     assert(hPort);
286     return KmtFltSendBufferToDriver(hPort, Message, NULL, 0, NULL, 0, NULL);
287 }
288 
289 /**
290  * @name KmtFltSendStringToDriver
291  *
292  * Send an I/O control message with a string argument to the driver opened with KmtOpenDriver
293  *
294  *
295  * @param hPort
296  *        Handle to the filter's comms port
297  * @param Message
298  *        The message associated with the string
299  * @param String
300  *        An ANSI string to send to the filter
301  *
302  * @return Win32 error code
303  */
304 DWORD
KmtFltSendStringToDriver(_In_ HANDLE hPort,_In_ DWORD Message,_In_ PCSTR String)305 KmtFltSendStringToDriver(
306     _In_ HANDLE hPort,
307     _In_ DWORD Message,
308     _In_ PCSTR String)
309 {
310     assert(hPort);
311     assert(String);
312     return KmtFltSendBufferToDriver(hPort, Message, (PVOID)String, (DWORD)strlen(String), NULL, 0, NULL);
313 }
314 
315 /**
316  * @name KmtFltSendWStringToDriver
317  *
318  * Send an I/O control message with a wide string argument to the driver opened with KmtOpenDriver
319  *
320  * @param hPort
321  *        Handle to the filter's comms port
322  * @param Message
323  *        The message associated with the string
324  * @param String
325  *        An wide string to send to the filter
326  *
327  * @return Win32 error code
328  */
329 DWORD
KmtFltSendWStringToDriver(_In_ HANDLE hPort,_In_ DWORD Message,_In_ PCWSTR String)330 KmtFltSendWStringToDriver(
331     _In_ HANDLE hPort,
332     _In_ DWORD Message,
333     _In_ PCWSTR String)
334 {
335     return KmtFltSendBufferToDriver(hPort, Message, (PVOID)String, (DWORD)wcslen(String) * sizeof(WCHAR), NULL, 0, NULL);
336 }
337 
338 /**
339  * @name KmtFltSendUlongToDriver
340  *
341  * Send an I/O control message with an integer argument to the driver opened with KmtOpenDriver
342  *
343  * @param hPort
344  *        Handle to the filter's comms port
345  * @param Message
346  *        The message associated with the value
347  * @param Value
348  *        An 32bit valueng to send to the filter
349  *
350  * @return Win32 error code
351  */
352 DWORD
KmtFltSendUlongToDriver(_In_ HANDLE hPort,_In_ DWORD Message,_In_ DWORD Value)353 KmtFltSendUlongToDriver(
354     _In_ HANDLE hPort,
355     _In_ DWORD Message,
356     _In_ DWORD Value)
357 {
358     return KmtFltSendBufferToDriver(hPort, Message, &Value, sizeof(Value), NULL, 0, NULL);
359 }
360 
361 /**
362  * @name KmtSendBufferToDriver
363  *
364  * Send an I/O control message with the specified arguments to the driver opened with KmtOpenDriver
365  *
366  * @param hPort
367  *        Handle to the filter's comms port
368  * @param Message
369  *        The message associated with the value
370  * @param InBuffer
371  *        Pointer to a buffer to send to the filter
372  * @param BufferSize
373  *        Size of the buffer pointed to by InBuffer
374  * @param OutBuffer
375  *        Pointer to a buffer to receive a response from the filter
376  * @param OutBufferSize
377  *        Size of the buffer pointed to by OutBuffer
378  * @param BytesReturned
379  *        Number of bytes written in the reply buffer
380  *
381  * @return Win32 error code
382  */
383 DWORD
KmtFltSendBufferToDriver(_In_ HANDLE hPort,_In_ DWORD Message,_In_reads_bytes_ (BufferSize)LPVOID InBuffer,_In_ DWORD BufferSize,_Out_writes_bytes_to_opt_ (OutBufferSize,* BytesReturned)LPVOID OutBuffer,_In_ DWORD OutBufferSize,_Out_opt_ LPDWORD BytesReturned)384 KmtFltSendBufferToDriver(
385     _In_ HANDLE hPort,
386     _In_ DWORD Message,
387     _In_reads_bytes_(BufferSize) LPVOID InBuffer,
388     _In_ DWORD BufferSize,
389     _Out_writes_bytes_to_opt_(OutBufferSize, *BytesReturned) LPVOID OutBuffer,
390     _In_ DWORD OutBufferSize,
391     _Out_opt_ LPDWORD BytesReturned)
392 {
393     PKMTFLT_MESSAGE_HEADER Ptr;
394     KMTFLT_MESSAGE_HEADER Header;
395     BOOLEAN FreeMemory = FALSE;
396     DWORD InBufferSize;
397     DWORD Error;
398 
399     assert(hPort);
400 
401     if (BufferSize)
402     {
403         assert(InBuffer);
404 
405         InBufferSize = sizeof(KMTFLT_MESSAGE_HEADER) + BufferSize;
406         Ptr = HeapAlloc(GetProcessHeap(), HEAP_ZERO_MEMORY, InBufferSize);
407         if (!Ptr)
408         {
409             return ERROR_NOT_ENOUGH_MEMORY;
410         }
411         FreeMemory = TRUE;
412     }
413     else
414     {
415         InBufferSize = sizeof(KMTFLT_MESSAGE_HEADER);
416         Ptr = &Header;
417     }
418 
419     Ptr->Message = Message;
420     if (BufferSize)
421     {
422         Ptr->Buffer = (Ptr + 1);
423         StringCbCopy(Ptr->Buffer, BufferSize, InBuffer);
424         Ptr->BufferSize = BufferSize;
425     }
426 
427     Error = KmtFltSendMessage(hPort, Ptr, InBufferSize, OutBuffer, OutBufferSize, BytesReturned);
428 
429     if (FreeMemory)
430     {
431         HeapFree(GetProcessHeap(), 0, Ptr);
432     }
433 
434     return Error;
435 }
436 
437 /**
438 * @name KmtFltAddAltitude
439 *
440 * Sets up the mini-filter altitude data in the registry
441 *
442 * @param hPort
443 *        The altitude string to set
444 *
445 * @return Win32 error code
446 */
447 DWORD
KmtFltAddAltitude(_In_z_ LPWSTR Altitude)448 KmtFltAddAltitude(
449     _In_z_ LPWSTR Altitude)
450 {
451     WCHAR DefaultInstance[128];
452     WCHAR KeyPath[256];
453     HKEY hKey = NULL;
454     HKEY hSubKey = NULL;
455     DWORD Zero = 0;
456     LONG Error;
457 
458     StringCbCopyW(KeyPath, sizeof(KeyPath), L"SYSTEM\\CurrentControlSet\\Services\\");
459     StringCbCatW(KeyPath, sizeof(KeyPath), TestServiceName);
460     StringCbCatW(KeyPath, sizeof(KeyPath), L"\\Instances\\");
461 
462     Error = RegCreateKeyEx(HKEY_LOCAL_MACHINE,
463                            KeyPath,
464                            0,
465                            NULL,
466                            REG_OPTION_NON_VOLATILE,
467                            KEY_CREATE_SUB_KEY | KEY_SET_VALUE,
468                            NULL,
469                            &hKey,
470                            NULL);
471     if (Error != ERROR_SUCCESS)
472     {
473         return Error;
474     }
475 
476     StringCbCopyW(DefaultInstance, sizeof(DefaultInstance), TestServiceName);
477     StringCbCatW(DefaultInstance, sizeof(DefaultInstance), L" Instance");
478 
479     Error = RegSetValueExW(hKey,
480                            L"DefaultInstance",
481                            0,
482                            REG_SZ,
483                            (LPBYTE)DefaultInstance,
484                            (lstrlenW(DefaultInstance) + 1) * sizeof(WCHAR));
485     if (Error != ERROR_SUCCESS)
486     {
487         goto Quit;
488     }
489 
490     Error = RegCreateKeyW(hKey, DefaultInstance, &hSubKey);
491     if (Error != ERROR_SUCCESS)
492     {
493         goto Quit;
494     }
495 
496     Error = RegSetValueExW(hSubKey,
497                            L"Altitude",
498                            0,
499                            REG_SZ,
500                            (LPBYTE)Altitude,
501                            (lstrlenW(Altitude) + 1) * sizeof(WCHAR));
502     if (Error != ERROR_SUCCESS)
503     {
504         goto Quit;
505     }
506 
507     Error = RegSetValueExW(hSubKey,
508                            L"Flags",
509                            0,
510                            REG_DWORD,
511                            (LPBYTE)&Zero,
512                            sizeof(DWORD));
513 
514 Quit:
515     if (hSubKey)
516     {
517         RegCloseKey(hSubKey);
518     }
519     if (hKey)
520     {
521         RegCloseKey(hKey);
522     }
523 
524     return Error;
525 
526 }
527 
528 /*
529 * Private functions, not meant for use in kmtests
530 */
531 
EnablePrivilege(_In_ HANDLE hToken,_In_z_ LPWSTR lpPrivName,_In_ BOOL bEnable)532 DWORD EnablePrivilege(
533     _In_ HANDLE hToken,
534     _In_z_ LPWSTR lpPrivName,
535     _In_ BOOL bEnable)
536 {
537     TOKEN_PRIVILEGES TokenPrivileges;
538     LUID luid;
539     BOOL bSuccess;
540     DWORD dwError = ERROR_SUCCESS;
541 
542     /* Get the luid for this privilege */
543     if (!LookupPrivilegeValueW(NULL, lpPrivName, &luid))
544         return GetLastError();
545 
546     /* Setup the struct with the priv info */
547     TokenPrivileges.PrivilegeCount = 1;
548     TokenPrivileges.Privileges[0].Luid = luid;
549     TokenPrivileges.Privileges[0].Attributes = bEnable ? SE_PRIVILEGE_ENABLED : 0;
550 
551     /* Enable the privilege info in the token */
552     bSuccess = AdjustTokenPrivileges(hToken,
553                                      FALSE,
554                                      &TokenPrivileges,
555                                      sizeof(TOKEN_PRIVILEGES),
556                                      NULL,
557                                      NULL);
558     if (bSuccess == FALSE) dwError = GetLastError();
559 
560     /* return status */
561     return dwError;
562 }
563 
EnablePrivilegeInCurrentProcess(_In_z_ LPWSTR lpPrivName,_In_ BOOL bEnable)564 DWORD EnablePrivilegeInCurrentProcess(
565     _In_z_ LPWSTR lpPrivName,
566     _In_ BOOL bEnable)
567 {
568     HANDLE hToken;
569     BOOL bSuccess;
570     DWORD dwError = ERROR_SUCCESS;
571 
572     /* Get a handle to our token */
573     bSuccess = OpenProcessToken(GetCurrentProcess(),
574                                 TOKEN_ADJUST_PRIVILEGES | TOKEN_QUERY,
575                                 &hToken);
576     if (bSuccess == FALSE) return GetLastError();
577 
578     /* Enable the privilege in the agent token */
579     dwError = EnablePrivilege(hToken, lpPrivName, bEnable);
580 
581     /* We're done with this now */
582     CloseHandle(hToken);
583 
584     /* return status */
585     return dwError;
586 }
587