1 /*
2  * PROJECT:         ReactOS kernel-mode tests
3  * LICENSE:         GPLv2+ - See COPYING in the top level directory
4  * PURPOSE:         File system mini-filter support routines
5  * PROGRAMMER:      Ged Murphy <gedmurphy@reactos.org>
6  */
7 
8 #include <kmt_test.h>
9 
10 #define KMT_FLT_USER_MODE
11 #include "kmtest.h"
12 #include <kmt_public.h>
13 
14 #include <ndk/setypes.h>
15 #include <assert.h>
16 #include <debug.h>
17 
18 /*
19  * We need to call the internal function in the service.c file
20  */
21 DWORD
22 KmtpCreateService(
23     IN PCWSTR ServiceName,
24     IN PCWSTR ServicePath,
25     IN PCWSTR DisplayName OPTIONAL,
26     IN DWORD ServiceType,
27     OUT SC_HANDLE *ServiceHandle);
28 
29 // move to a shared location
30 typedef struct _KMTFLT_MESSAGE_HEADER
31 {
32     ULONG Message;
33     PVOID Buffer;
34     ULONG BufferSize;
35 
36 } KMTFLT_MESSAGE_HEADER, *PKMTFLT_MESSAGE_HEADER;
37 
38 extern HANDLE KmtestHandle;
39 static WCHAR TestServiceName[MAX_PATH];
40 
41 
42 
43 /**
44  * @name KmtFltCreateService
45  *
46  * Create the specified driver service and return a handle to it
47  *
48  * @param ServiceName
49  *        Name of the service to create
50  * @param ServicePath
51  *        File name of the driver, relative to the current directory
52  * @param DisplayName
53  *        Service display name
54  * @param ServiceHandle
55  *        Pointer to a variable to receive the handle to the service
56  *
57  * @return Win32 error code
58  */
59 DWORD
60 KmtFltCreateService(
61     _In_z_ PCWSTR ServiceName,
62     _In_z_ PCWSTR DisplayName,
63     _Out_ SC_HANDLE *ServiceHandle)
64 {
65     WCHAR ServicePath[MAX_PATH];
66 
67     StringCbCopyW(ServicePath, sizeof(ServicePath), ServiceName);
68     StringCbCatW(ServicePath, sizeof(ServicePath), L"_drv.sys");
69 
70     StringCbCopyW(TestServiceName, sizeof(TestServiceName), L"Kmtest-");
71     StringCbCatW(TestServiceName, sizeof(TestServiceName), ServiceName);
72 
73     return KmtpCreateService(TestServiceName,
74                              ServicePath,
75                              DisplayName,
76                              SERVICE_FILE_SYSTEM_DRIVER,
77                              ServiceHandle);
78 }
79 
80 /**
81  * @name KmtFltDeleteService
82  *
83  * Delete the specified filter driver
84  *
85  * @param ServiceName
86  *        If *ServiceHandle is NULL, name of the service to delete
87  * @param ServiceHandle
88  *        Pointer to a variable containing the service handle.
89  *        Will be set to NULL on success
90  *
91  * @return Win32 error code
92  */
93 DWORD
94 KmtFltDeleteService(
95     _In_opt_z_ PCWSTR ServiceName,
96     _Inout_ SC_HANDLE *ServiceHandle)
97 {
98     return KmtDeleteService(ServiceName, ServiceHandle);
99 }
100 
101 /**
102  * @name KmtFltLoadDriver
103  *
104  * Delete the specified filter driver
105  *
106  * @return Win32 error code
107  */
108 DWORD
109 KmtFltLoadDriver(
110     _In_ BOOLEAN EnableDriverLoadPrivilege,
111     _In_ BOOLEAN RestartIfRunning,
112     _In_ BOOLEAN ConnectComms,
113     _Out_ HANDLE *hPort
114 )
115 {
116     DWORD Error;
117 
118     if (EnableDriverLoadPrivilege)
119     {
120         BOOLEAN WasEnabled;
121         Error = RtlNtStatusToDosError(RtlAdjustPrivilege(
122                     SE_LOAD_DRIVER_PRIVILEGE,
123                     TRUE,
124                     FALSE, // Enable in current process.
125                     &WasEnabled));
126         if (Error)
127             return Error;
128     }
129 
130     Error = KmtFltLoad(TestServiceName);
131     if ((Error == ERROR_SERVICE_ALREADY_RUNNING) && RestartIfRunning)
132     {
133         Error = KmtFltUnload(TestServiceName);
134         if (Error)
135         {
136             // TODO
137             __debugbreak();
138         }
139 
140         Error = KmtFltLoad(TestServiceName);
141     }
142 
143     if (Error)
144         return Error;
145 
146     if (ConnectComms)
147         Error = KmtFltConnectComms(hPort);
148 
149     return Error;
150 }
151 
152 /**
153  * @name KmtFltUnloadDriver
154  *
155  * Unload the specified filter driver
156  *
157  * @param hPort
158  *        Handle to the filter's comms port
159  * @param DisonnectComms
160  *        TRUE to disconnect the comms connection before unloading
161  *
162  * @return Win32 error code
163  */
164 DWORD
165 KmtFltUnloadDriver(
166     _In_ HANDLE *hPort,
167     _In_ BOOLEAN DisonnectComms)
168 {
169     DWORD Error = ERROR_SUCCESS;
170 
171     if (DisonnectComms)
172     {
173         Error = KmtFltDisconnect(hPort);
174 
175         if (Error)
176         {
177             return Error;
178         }
179     }
180 
181     Error = KmtFltUnload(TestServiceName);
182 
183     if (Error)
184     {
185         // TODO
186         __debugbreak();
187     }
188 
189     return Error;
190 }
191 
192 /**
193  * @name KmtFltConnectComms
194  *
195  * Create a comms connection to the specified filter
196  *
197  * @param hPort
198  *        Handle to the filter's comms port
199  *
200  * @return Win32 error code
201  */
202 DWORD
203 KmtFltConnectComms(
204     _Out_ HANDLE *hPort)
205 {
206     return KmtFltConnect(TestServiceName, hPort);
207 }
208 
209 /**
210  * @name KmtFltDisconnectComms
211  *
212  * Disconenct from the comms port
213  *
214  * @param hPort
215  *        Handle to the filter's comms port
216  *
217  * @return Win32 error code
218  */
219 DWORD
220 KmtFltDisconnectComms(
221     _In_ HANDLE hPort)
222 {
223     return KmtFltDisconnect(hPort);
224 }
225 
226 
227 /**
228 * @name KmtFltCloseService
229 *
230 * Close the specified driver service handle
231 *
232 * @param ServiceHandle
233 *        Pointer to a variable containing the service handle.
234 *        Will be set to NULL on success
235 *
236 * @return Win32 error code
237 */
238 DWORD KmtFltCloseService(
239     _Inout_ SC_HANDLE *ServiceHandle)
240 {
241     return KmtCloseService(ServiceHandle);
242 }
243 
244 /**
245 * @name KmtFltRunKernelTest
246 *
247 * Run the specified filter test part
248 *
249 * @param hPort
250 *        Handle to the filter's comms port
251 * @param TestName
252 *        Name of the test to run
253 *
254 * @return Win32 error code
255 */
256 DWORD
257 KmtFltRunKernelTest(
258     _In_ HANDLE hPort,
259     _In_z_ PCSTR TestName)
260 {
261     return KmtFltSendStringToDriver(hPort, KMTFLT_RUN_TEST, TestName);
262 }
263 
264 /**
265 * @name KmtFltSendToDriver
266 *
267 * Send an I/O control message with no arguments to the driver opened with KmtOpenDriver
268 *
269 * @param hPort
270 *        Handle to the filter's comms port
271 * @param Message
272 *        The message to send to the filter
273 *
274 * @return Win32 error code
275 */
276 DWORD
277 KmtFltSendToDriver(
278     _In_ HANDLE hPort,
279     _In_ DWORD Message)
280 {
281     assert(hPort);
282     return KmtFltSendBufferToDriver(hPort, Message, NULL, 0, NULL, 0, NULL);
283 }
284 
285 /**
286  * @name KmtFltSendStringToDriver
287  *
288  * Send an I/O control message with a string argument to the driver opened with KmtOpenDriver
289  *
290  *
291  * @param hPort
292  *        Handle to the filter's comms port
293  * @param Message
294  *        The message associated with the string
295  * @param String
296  *        An ANSI string to send to the filter
297  *
298  * @return Win32 error code
299  */
300 DWORD
301 KmtFltSendStringToDriver(
302     _In_ HANDLE hPort,
303     _In_ DWORD Message,
304     _In_ PCSTR String)
305 {
306     assert(hPort);
307     assert(String);
308     return KmtFltSendBufferToDriver(hPort, Message, (PVOID)String, (DWORD)strlen(String), NULL, 0, NULL);
309 }
310 
311 /**
312  * @name KmtFltSendWStringToDriver
313  *
314  * Send an I/O control message with a wide string argument to the driver opened with KmtOpenDriver
315  *
316  * @param hPort
317  *        Handle to the filter's comms port
318  * @param Message
319  *        The message associated with the string
320  * @param String
321  *        An wide string to send to the filter
322  *
323  * @return Win32 error code
324  */
325 DWORD
326 KmtFltSendWStringToDriver(
327     _In_ HANDLE hPort,
328     _In_ DWORD Message,
329     _In_ PCWSTR String)
330 {
331     return KmtFltSendBufferToDriver(hPort, Message, (PVOID)String, (DWORD)wcslen(String) * sizeof(WCHAR), NULL, 0, NULL);
332 }
333 
334 /**
335  * @name KmtFltSendUlongToDriver
336  *
337  * Send an I/O control message with an integer argument to the driver opened with KmtOpenDriver
338  *
339  * @param hPort
340  *        Handle to the filter's comms port
341  * @param Message
342  *        The message associated with the value
343  * @param Value
344  *        An 32bit valueng to send to the filter
345  *
346  * @return Win32 error code
347  */
348 DWORD
349 KmtFltSendUlongToDriver(
350     _In_ HANDLE hPort,
351     _In_ DWORD Message,
352     _In_ DWORD Value)
353 {
354     return KmtFltSendBufferToDriver(hPort, Message, &Value, sizeof(Value), NULL, 0, NULL);
355 }
356 
357 /**
358  * @name KmtSendBufferToDriver
359  *
360  * Send an I/O control message with the specified arguments to the driver opened with KmtOpenDriver
361  *
362  * @param hPort
363  *        Handle to the filter's comms port
364  * @param Message
365  *        The message associated with the value
366  * @param InBuffer
367  *        Pointer to a buffer to send to the filter
368  * @param BufferSize
369  *        Size of the buffer pointed to by InBuffer
370  * @param OutBuffer
371  *        Pointer to a buffer to receive a response from the filter
372  * @param OutBufferSize
373  *        Size of the buffer pointed to by OutBuffer
374  * @param BytesReturned
375  *        Number of bytes written in the reply buffer
376  *
377  * @return Win32 error code
378  */
379 DWORD
380 KmtFltSendBufferToDriver(
381     _In_ HANDLE hPort,
382     _In_ DWORD Message,
383     _In_reads_bytes_(BufferSize) LPVOID InBuffer,
384     _In_ DWORD BufferSize,
385     _Out_writes_bytes_to_opt_(OutBufferSize, *BytesReturned) LPVOID OutBuffer,
386     _In_ DWORD OutBufferSize,
387     _Out_opt_ LPDWORD BytesReturned)
388 {
389     PKMTFLT_MESSAGE_HEADER Ptr;
390     KMTFLT_MESSAGE_HEADER Header;
391     BOOLEAN FreeMemory = FALSE;
392     DWORD InBufferSize;
393     DWORD Error;
394 
395     assert(hPort);
396 
397     if (BufferSize)
398     {
399         assert(InBuffer);
400 
401         InBufferSize = sizeof(KMTFLT_MESSAGE_HEADER) + BufferSize;
402         Ptr = HeapAlloc(GetProcessHeap(), HEAP_ZERO_MEMORY, InBufferSize);
403         if (!Ptr)
404         {
405             return ERROR_NOT_ENOUGH_MEMORY;
406         }
407         FreeMemory = TRUE;
408     }
409     else
410     {
411         InBufferSize = sizeof(KMTFLT_MESSAGE_HEADER);
412         Ptr = &Header;
413     }
414 
415     Ptr->Message = Message;
416     if (BufferSize)
417     {
418         Ptr->Buffer = (Ptr + 1);
419         StringCbCopy(Ptr->Buffer, BufferSize, InBuffer);
420         Ptr->BufferSize = BufferSize;
421     }
422 
423     Error = KmtFltSendMessage(hPort, Ptr, InBufferSize, OutBuffer, OutBufferSize, BytesReturned);
424 
425     if (FreeMemory)
426     {
427         HeapFree(GetProcessHeap(), 0, Ptr);
428     }
429 
430     return Error;
431 }
432 
433 /**
434 * @name KmtFltAddAltitude
435 *
436 * Sets up the mini-filter altitude data in the registry
437 *
438 * @param hPort
439 *        The altitude string to set
440 *
441 * @return Win32 error code
442 */
443 DWORD
444 KmtFltAddAltitude(
445     _In_z_ LPWSTR Altitude)
446 {
447     WCHAR DefaultInstance[128];
448     WCHAR KeyPath[256];
449     HKEY hKey = NULL;
450     HKEY hSubKey = NULL;
451     DWORD Zero = 0;
452     LONG Error;
453 
454     StringCbCopyW(KeyPath, sizeof(KeyPath), L"SYSTEM\\CurrentControlSet\\Services\\");
455     StringCbCatW(KeyPath, sizeof(KeyPath), TestServiceName);
456     StringCbCatW(KeyPath, sizeof(KeyPath), L"\\Instances\\");
457 
458     Error = RegCreateKeyEx(HKEY_LOCAL_MACHINE,
459                            KeyPath,
460                            0,
461                            NULL,
462                            REG_OPTION_NON_VOLATILE,
463                            KEY_CREATE_SUB_KEY | KEY_SET_VALUE,
464                            NULL,
465                            &hKey,
466                            NULL);
467     if (Error != ERROR_SUCCESS)
468     {
469         return Error;
470     }
471 
472     StringCbCopyW(DefaultInstance, sizeof(DefaultInstance), TestServiceName);
473     StringCbCatW(DefaultInstance, sizeof(DefaultInstance), L" Instance");
474 
475     Error = RegSetValueExW(hKey,
476                            L"DefaultInstance",
477                            0,
478                            REG_SZ,
479                            (LPBYTE)DefaultInstance,
480                            (lstrlenW(DefaultInstance) + 1) * sizeof(WCHAR));
481     if (Error != ERROR_SUCCESS)
482     {
483         goto Quit;
484     }
485 
486     Error = RegCreateKeyW(hKey, DefaultInstance, &hSubKey);
487     if (Error != ERROR_SUCCESS)
488     {
489         goto Quit;
490     }
491 
492     Error = RegSetValueExW(hSubKey,
493                            L"Altitude",
494                            0,
495                            REG_SZ,
496                            (LPBYTE)Altitude,
497                            (lstrlenW(Altitude) + 1) * sizeof(WCHAR));
498     if (Error != ERROR_SUCCESS)
499     {
500         goto Quit;
501     }
502 
503     Error = RegSetValueExW(hSubKey,
504                            L"Flags",
505                            0,
506                            REG_DWORD,
507                            (LPBYTE)&Zero,
508                            sizeof(DWORD));
509 
510 Quit:
511     if (hSubKey)
512     {
513         RegCloseKey(hSubKey);
514     }
515     if (hKey)
516     {
517         RegCloseKey(hKey);
518     }
519 
520     return Error;
521 
522 }
523