1 /*
2  * PROJECT:         ReactOS kernel-mode tests
3  * LICENSE:         GPLv2+ - See COPYING in the top level directory
4  * PURPOSE:         Kernel-Mode Test Suite Example Test Driver
5  * PROGRAMMER:      Thomas Faber <thomas.faber@reactos.org>
6  */
7 
8 #include <ntddk.h>
9 #include <ntifs.h>
10 #include <ndk/ketypes.h>
11 
12 #define KMT_DEFINE_TEST_FUNCTIONS
13 #include <kmt_test.h>
14 
15 #define NDEBUG
16 #include <debug.h>
17 
18 #include <kmt_public.h>
19 
20 /* types */
21 typedef struct
22 {
23     UCHAR MajorFunction;
24     PDEVICE_OBJECT DeviceObject;
25     PKMT_IRP_HANDLER IrpHandler;
26 } KMT_IRP_HANDLER_ENTRY, *PKMT_IRP_HANDLER_ENTRY;
27 
28 typedef struct
29 {
30     ULONG ControlCode;
31     PDEVICE_OBJECT DeviceObject;
32     PKMT_MESSAGE_HANDLER MessageHandler;
33 } KMT_MESSAGE_HANDLER_ENTRY, *PKMT_MESSAGE_HANDLER_ENTRY;
34 
35 /* Prototypes */
36 DRIVER_INITIALIZE DriverEntry;
37 static DRIVER_UNLOAD DriverUnload;
38 static DRIVER_DISPATCH DriverDispatch;
39 static KMT_IRP_HANDLER DeviceControlHandler;
40 
41 /* Globals */
42 static PDEVICE_OBJECT TestDeviceObject;
43 static PDEVICE_OBJECT KmtestDeviceObject;
44 
45 #define KMT_MAX_IRP_HANDLERS 256
46 static KMT_IRP_HANDLER_ENTRY IrpHandlers[KMT_MAX_IRP_HANDLERS] = { { 0 } };
47 #define KMT_MAX_MESSAGE_HANDLERS 256
48 static KMT_MESSAGE_HANDLER_ENTRY MessageHandlers[KMT_MAX_MESSAGE_HANDLERS] = { { 0 } };
49 
50 /**
51  * @name DriverEntry
52  *
53  * Driver entry point.
54  *
55  * @param DriverObject
56  *        Driver Object
57  * @param RegistryPath
58  *        Driver Registry Path
59  *
60  * @return Status
61  */
62 NTSTATUS
63 NTAPI
64 DriverEntry(
65     IN PDRIVER_OBJECT DriverObject,
66     IN PUNICODE_STRING RegistryPath)
67 {
68     NTSTATUS Status = STATUS_SUCCESS;
69     WCHAR DeviceNameBuffer[128] = L"\\Device\\Kmtest-";
70     UNICODE_STRING KmtestDeviceName;
71     PFILE_OBJECT KmtestFileObject;
72     PKMT_DEVICE_EXTENSION KmtestDeviceExtension;
73     UNICODE_STRING DeviceName;
74     PCWSTR DeviceNameSuffix;
75     INT Flags = 0;
76     int i;
77     PKPRCB Prcb;
78 
79     PAGED_CODE();
80 
81     DPRINT("DriverEntry\n");
82 
83     Prcb = KeGetCurrentPrcb();
84     KmtIsCheckedBuild = (Prcb->BuildType & PRCB_BUILD_DEBUG) != 0;
85     KmtIsMultiProcessorBuild = (Prcb->BuildType & PRCB_BUILD_UNIPROCESSOR) == 0;
86 
87     /* get the Kmtest device, so that we get a ResultBuffer pointer */
88     RtlInitUnicodeString(&KmtestDeviceName, KMTEST_DEVICE_DRIVER_PATH);
89     Status = IoGetDeviceObjectPointer(&KmtestDeviceName, FILE_ALL_ACCESS, &KmtestFileObject, &KmtestDeviceObject);
90 
91     if (!NT_SUCCESS(Status))
92     {
93         DPRINT1("Failed to get Kmtest device object pointer\n");
94         goto cleanup;
95     }
96 
97     Status = ObReferenceObjectByPointer(KmtestDeviceObject, FILE_ALL_ACCESS, NULL, KernelMode);
98 
99     if (!NT_SUCCESS(Status))
100     {
101         DPRINT1("Failed to reference Kmtest device object\n");
102         goto cleanup;
103     }
104 
105     ObDereferenceObject(KmtestFileObject);
106     KmtestFileObject = NULL;
107     KmtestDeviceExtension = KmtestDeviceObject->DeviceExtension;
108     ResultBuffer = KmtestDeviceExtension->ResultBuffer;
109     DPRINT("KmtestDeviceObject: %p\n", (PVOID)KmtestDeviceObject);
110     DPRINT("KmtestDeviceExtension: %p\n", (PVOID)KmtestDeviceExtension);
111     DPRINT("Setting ResultBuffer: %p\n", (PVOID)ResultBuffer);
112 
113     /* call TestEntry */
114     RtlInitUnicodeString(&DeviceName, DeviceNameBuffer);
115     DeviceName.MaximumLength = sizeof DeviceNameBuffer;
116     TestEntry(DriverObject, RegistryPath, &DeviceNameSuffix, &Flags);
117 
118     /* create test device */
119     if (!(Flags & TESTENTRY_NO_CREATE_DEVICE))
120     {
121         RtlAppendUnicodeToString(&DeviceName, DeviceNameSuffix);
122         Status = IoCreateDevice(DriverObject, 0, &DeviceName,
123                                 FILE_DEVICE_UNKNOWN,
124                                 FILE_DEVICE_SECURE_OPEN |
125                                     (Flags & TESTENTRY_NO_READONLY_DEVICE ? 0 : FILE_READ_ONLY_DEVICE),
126                                 Flags & TESTENTRY_NO_EXCLUSIVE_DEVICE ? FALSE : TRUE,
127                                 &TestDeviceObject);
128 
129         if (!NT_SUCCESS(Status))
130         {
131             DPRINT1("Could not create device object %wZ\n", &DeviceName);
132             goto cleanup;
133         }
134 
135         if (Flags & TESTENTRY_BUFFERED_IO_DEVICE)
136             TestDeviceObject->Flags |= DO_BUFFERED_IO;
137 
138         DPRINT("DriverEntry. Created DeviceObject %p\n",
139                  TestDeviceObject);
140     }
141 
142     /* initialize dispatch functions */
143     if (!(Flags & TESTENTRY_NO_REGISTER_UNLOAD))
144         DriverObject->DriverUnload = DriverUnload;
145     if (!(Flags & TESTENTRY_NO_REGISTER_DISPATCH))
146         for (i = 0; i <= IRP_MJ_MAXIMUM_FUNCTION; ++i)
147             DriverObject->MajorFunction[i] = DriverDispatch;
148 
149 cleanup:
150     if (TestDeviceObject && !NT_SUCCESS(Status))
151     {
152         IoDeleteDevice(TestDeviceObject);
153         TestDeviceObject = NULL;
154     }
155 
156     if (KmtestDeviceObject && !NT_SUCCESS(Status))
157     {
158         ObDereferenceObject(KmtestDeviceObject);
159         KmtestDeviceObject = NULL;
160         if (KmtestFileObject)
161             ObDereferenceObject(KmtestFileObject);
162     }
163 
164     return Status;
165 }
166 
167 /**
168  * @name DriverUnload
169  *
170  * Driver cleanup funtion.
171  *
172  * @param DriverObject
173  *        Driver Object
174  */
175 static
176 VOID
177 NTAPI
178 DriverUnload(
179     IN PDRIVER_OBJECT DriverObject)
180 {
181     PAGED_CODE();
182 
183     UNREFERENCED_PARAMETER(DriverObject);
184 
185     DPRINT("DriverUnload\n");
186 
187     TestUnload(DriverObject);
188 
189     if (TestDeviceObject)
190         IoDeleteDevice(TestDeviceObject);
191 
192     if (KmtestDeviceObject)
193         ObDereferenceObject(KmtestDeviceObject);
194 }
195 
196 /**
197  * @name KmtRegisterIrpHandler
198  *
199  * Register a handler with the IRP Dispatcher.
200  * If multiple registered handlers match an IRP, it is unspecified which of
201  * them is called on IRP reception
202  *
203  * @param MajorFunction
204  *        IRP major function code to be handled
205  * @param DeviceObject
206  *        Device Object to handle IRPs for.
207  *        Can be NULL to indicate any device object
208  * @param IrpHandler
209  *        Handler function to register.
210  *
211  * @return Status
212  */
213 NTSTATUS
214 KmtRegisterIrpHandler(
215     IN UCHAR MajorFunction,
216     IN PDEVICE_OBJECT DeviceObject OPTIONAL,
217     IN PKMT_IRP_HANDLER IrpHandler)
218 {
219     NTSTATUS Status = STATUS_SUCCESS;
220     int i;
221 
222     if (MajorFunction > IRP_MJ_MAXIMUM_FUNCTION)
223     {
224         Status = STATUS_INVALID_PARAMETER_1;
225         goto cleanup;
226     }
227 
228     if (IrpHandler == NULL)
229     {
230         Status = STATUS_INVALID_PARAMETER_3;
231         goto cleanup;
232     }
233 
234     for (i = 0; i < sizeof IrpHandlers / sizeof IrpHandlers[0]; ++i)
235         if (IrpHandlers[i].IrpHandler == NULL)
236         {
237             IrpHandlers[i].MajorFunction = MajorFunction;
238             IrpHandlers[i].DeviceObject = DeviceObject;
239             IrpHandlers[i].IrpHandler = IrpHandler;
240             goto cleanup;
241         }
242 
243     Status = STATUS_ALLOTTED_SPACE_EXCEEDED;
244 
245 cleanup:
246     return Status;
247 }
248 
249 /**
250  * @name KmtUnregisterIrpHandler
251  *
252  * Unregister a handler with the IRP Dispatcher.
253  * Parameters must be specified exactly as in the call to
254  * KmtRegisterIrpHandler. Only the first matching entry will be removed
255  * if multiple exist
256  *
257  * @param MajorFunction
258  *        IRP major function code of the handler to be removed
259  * @param DeviceObject
260  *        Device Object to of the handler to be removed
261  * @param IrpHandler
262  *        Handler function of the handler to be removed
263  *
264  * @return Status
265  */
266 NTSTATUS
267 KmtUnregisterIrpHandler(
268     IN UCHAR MajorFunction,
269     IN PDEVICE_OBJECT DeviceObject OPTIONAL,
270     IN PKMT_IRP_HANDLER IrpHandler)
271 {
272     NTSTATUS Status = STATUS_SUCCESS;
273     int i;
274 
275     for (i = 0; i < sizeof IrpHandlers / sizeof IrpHandlers[0]; ++i)
276         if (IrpHandlers[i].MajorFunction == MajorFunction &&
277                 IrpHandlers[i].DeviceObject == DeviceObject &&
278                 IrpHandlers[i].IrpHandler == IrpHandler)
279         {
280             IrpHandlers[i].IrpHandler = NULL;
281             goto cleanup;
282         }
283 
284     Status = STATUS_NOT_FOUND;
285 
286 cleanup:
287     return Status;
288 }
289 
290 /**
291  * @name DriverDispatch
292  *
293  * Driver Dispatch function
294  *
295  * @param DeviceObject
296  *        Device Object
297  * @param Irp
298  *        I/O request packet
299  *
300  * @return Status
301  */
302 static
303 NTSTATUS
304 NTAPI
305 DriverDispatch(
306     IN PDEVICE_OBJECT DeviceObject,
307     IN PIRP Irp)
308 {
309     NTSTATUS Status = STATUS_INVALID_DEVICE_REQUEST;
310     PIO_STACK_LOCATION IoStackLocation;
311     int i;
312 
313     IoStackLocation = IoGetCurrentIrpStackLocation(Irp);
314 
315     DPRINT("DriverDispatch: Function=%s, Device=%p\n",
316             KmtMajorFunctionNames[IoStackLocation->MajorFunction],
317             DeviceObject);
318 
319     for (i = 0; i < sizeof IrpHandlers / sizeof IrpHandlers[0]; ++i)
320     {
321         if (IrpHandlers[i].MajorFunction == IoStackLocation->MajorFunction &&
322                 (IrpHandlers[i].DeviceObject == NULL || IrpHandlers[i].DeviceObject == DeviceObject) &&
323                 IrpHandlers[i].IrpHandler != NULL)
324             return IrpHandlers[i].IrpHandler(DeviceObject, Irp, IoStackLocation);
325     }
326 
327     /* default handler for DeviceControl */
328     if (IoStackLocation->MajorFunction == IRP_MJ_DEVICE_CONTROL ||
329             IoStackLocation->MajorFunction == IRP_MJ_INTERNAL_DEVICE_CONTROL)
330         return DeviceControlHandler(DeviceObject, Irp, IoStackLocation);
331 
332     /* Return success for create, close, and cleanup */
333     if (IoStackLocation->MajorFunction == IRP_MJ_CREATE ||
334             IoStackLocation->MajorFunction == IRP_MJ_CLOSE ||
335             IoStackLocation->MajorFunction == IRP_MJ_CLEANUP)
336         Status = STATUS_SUCCESS;
337 
338     /* default handler */
339     Irp->IoStatus.Status = Status;
340     Irp->IoStatus.Information = 0;
341 
342     IoCompleteRequest(Irp, IO_NO_INCREMENT);
343 
344     return Status;
345 }
346 
347 /**
348  * @name KmtRegisterMessageHandler
349  *
350  * Register a handler with the DeviceControl Dispatcher.
351  * If multiple registered handlers match a message, it is unspecified which of
352  * them is called on message reception.
353  * NOTE: message handlers registered with this function will not be called
354  *       if a custom IRP handler matching the corresponding IRP is installed!
355  *
356  * @param ControlCode
357  *        Control code to be handled, as passed by the application.
358  *        Can be 0 to indicate any control code
359  * @param DeviceObject
360  *        Device Object to handle IRPs for.
361  *        Can be NULL to indicate any device object
362  * @param MessageHandler
363  *        Handler function to register.
364  *
365  * @return Status
366  */
367 NTSTATUS
368 KmtRegisterMessageHandler(
369     IN ULONG ControlCode OPTIONAL,
370     IN PDEVICE_OBJECT DeviceObject OPTIONAL,
371     IN PKMT_MESSAGE_HANDLER MessageHandler)
372 {
373     NTSTATUS Status = STATUS_SUCCESS;
374     int i;
375 
376     if (ControlCode >= 0x400)
377     {
378         Status = STATUS_INVALID_PARAMETER_1;
379         goto cleanup;
380     }
381 
382     if (MessageHandler == NULL)
383     {
384         Status = STATUS_INVALID_PARAMETER_2;
385         goto cleanup;
386     }
387 
388     for (i = 0; i < sizeof MessageHandlers / sizeof MessageHandlers[0]; ++i)
389         if (MessageHandlers[i].MessageHandler == NULL)
390         {
391             MessageHandlers[i].ControlCode = ControlCode;
392             MessageHandlers[i].DeviceObject = DeviceObject;
393             MessageHandlers[i].MessageHandler = MessageHandler;
394             goto cleanup;
395         }
396 
397     Status = STATUS_ALLOTTED_SPACE_EXCEEDED;
398 
399 cleanup:
400     return Status;
401 }
402 
403 /**
404  * @name KmtUnregisterMessageHandler
405  *
406  * Unregister a handler with the DeviceControl Dispatcher.
407  * Parameters must be specified exactly as in the call to
408  * KmtRegisterMessageHandler. Only the first matching entry will be removed
409  * if multiple exist
410  *
411  * @param ControlCode
412  *        Control code of the handler to be removed
413  * @param DeviceObject
414  *        Device Object to of the handler to be removed
415  * @param MessageHandler
416  *        Handler function of the handler to be removed
417  *
418  * @return Status
419  */
420 NTSTATUS
421 KmtUnregisterMessageHandler(
422     IN ULONG ControlCode OPTIONAL,
423     IN PDEVICE_OBJECT DeviceObject OPTIONAL,
424     IN PKMT_MESSAGE_HANDLER MessageHandler)
425 {
426     NTSTATUS Status = STATUS_SUCCESS;
427     int i;
428 
429     for (i = 0; i < sizeof MessageHandlers / sizeof MessageHandlers[0]; ++i)
430         if (MessageHandlers[i].ControlCode == ControlCode &&
431                 MessageHandlers[i].DeviceObject == DeviceObject &&
432                 MessageHandlers[i].MessageHandler == MessageHandler)
433         {
434             MessageHandlers[i].MessageHandler = NULL;
435             goto cleanup;
436         }
437 
438     Status = STATUS_NOT_FOUND;
439 
440 cleanup:
441     return Status;
442 }
443 
444 /**
445  * @name DeviceControlHandler
446  *
447  * Default IRP_MJ_DEVICE_CONTROL/IRP_MJ_INTERNAL_DEVICE_CONTROL handler
448  *
449  * @param DeviceObject
450  *        Device Object.
451  *        This is guaranteed not to have been touched by the dispatch function
452  *        before the call to the IRP handler
453  * @param Irp
454  *        Device Object.
455  *        This is guaranteed not to have been touched by the dispatch function
456  *        before the call to the IRP handler, except for passing it to
457  *        IoGetCurrentStackLocation
458  * @param IoStackLocation
459  *        Device Object.
460  *        This is guaranteed not to have been touched by the dispatch function
461  *        before the call to the IRP handler
462  *
463  * @return Status
464  */
465 static
466 NTSTATUS
467 DeviceControlHandler(
468     IN PDEVICE_OBJECT DeviceObject,
469     IN PIRP Irp,
470     IN PIO_STACK_LOCATION IoStackLocation)
471 {
472     NTSTATUS Status = STATUS_SUCCESS;
473     ULONG ControlCode = (IoStackLocation->Parameters.DeviceIoControl.IoControlCode & 0x00000FFC) >> 2;
474     SIZE_T OutLength = IoStackLocation->Parameters.DeviceIoControl.OutputBufferLength;
475     int i;
476 
477     for (i = 0; i < sizeof MessageHandlers / sizeof MessageHandlers[0]; ++i)
478     {
479         if ((MessageHandlers[i].ControlCode == 0 ||
480                 MessageHandlers[i].ControlCode == ControlCode) &&
481                 (MessageHandlers[i].DeviceObject == NULL || MessageHandlers[i].DeviceObject == DeviceObject) &&
482                 MessageHandlers[i].MessageHandler != NULL)
483         {
484             Status = MessageHandlers[i].MessageHandler(DeviceObject, ControlCode, Irp->AssociatedIrp.SystemBuffer,
485                                                         IoStackLocation->Parameters.DeviceIoControl.InputBufferLength,
486                                                         &OutLength);
487             break;
488         }
489     }
490 
491     Irp->IoStatus.Status = Status;
492     Irp->IoStatus.Information = OutLength;
493 
494     IoCompleteRequest(Irp, IO_NO_INCREMENT);
495 
496     return Status;
497 }
498