1 /*
2  * PROJECT:     ReactOS kernel-mode tests
3  * LICENSE:     LGPL-2.1+ (https://spdx.org/licenses/LGPL-2.1+)
4  * PURPOSE:     Kernel-Mode Test Suite Example test driver
5  * COPYRIGHT:   Copyright 2011-2018 Thomas Faber <thomas.faber@reactos.org>
6  */
7 
8 #include <kmt_test.h>
9 
10 //#define NDEBUG
11 #include <debug.h>
12 
13 #include "Example.h"
14 
15 /* prototypes */
16 static KMT_MESSAGE_HANDLER TestMessageHandler;
17 static KMT_IRP_HANDLER TestIrpHandler;
18 
19 /* globals */
20 static PDRIVER_OBJECT TestDriverObject;
21 
22 /**
23  * @name TestEntry
24  *
25  * Test entry point.
26  * This is called by DriverEntry as early as possible, but with ResultBuffer
27  * initialized, so that test macros work correctly
28  *
29  * @param DriverObject
30  *        Driver Object.
31  *        This is guaranteed not to have been touched by DriverEntry before
32  *        the call to TestEntry
33  * @param RegistryPath
34  *        Driver Registry Path
35  *        This is guaranteed not to have been touched by DriverEntry before
36  *        the call to TestEntry
37  * @param DeviceName
38  *        Pointer to receive a test-specific name for the device to create
39  * @param Flags
40  *        Pointer to a flags variable instructing DriverEntry how to proceed.
41  *        See the KMT_TESTENTRY_FLAGS enumeration for possible values
42  *        Initialized to zero on entry
43  *
44  * @return Status.
45  *         DriverEntry will fail if this is a failure status
46  */
47 NTSTATUS
48 TestEntry(
49     IN PDRIVER_OBJECT DriverObject,
50     IN PCUNICODE_STRING RegistryPath,
51     OUT PCWSTR *DeviceName,
52     IN OUT INT *Flags)
53 {
54     NTSTATUS Status = STATUS_SUCCESS;
55 
56     PAGED_CODE();
57 
58     UNREFERENCED_PARAMETER(RegistryPath);
59     UNREFERENCED_PARAMETER(Flags);
60 
61     DPRINT("Entry!\n");
62 
63     ok_irql(PASSIVE_LEVEL);
64     TestDriverObject = DriverObject;
65 
66     *DeviceName = L"Example";
67 
68     trace("Hi, this is the example driver\n");
69 
70     KmtRegisterIrpHandler(IRP_MJ_CREATE, NULL, TestIrpHandler);
71     KmtRegisterIrpHandler(IRP_MJ_CLOSE, NULL, TestIrpHandler);
72     KmtRegisterMessageHandler(0, NULL, TestMessageHandler);
73 
74     return Status;
75 }
76 
77 /**
78  * @name TestUnload
79  *
80  * Test unload routine.
81  * This is called by the driver's Unload routine as early as possible, with
82  * ResultBuffer and the test device object still valid, so that test macros
83  * work correctly
84  *
85  * @param DriverObject
86  *        Driver Object.
87  *        This is guaranteed not to have been touched by Unload before the call
88  *        to TestEntry
89  *
90  * @return Status
91  */
92 VOID
93 TestUnload(
94     IN PDRIVER_OBJECT DriverObject)
95 {
96     PAGED_CODE();
97 
98     DPRINT("Unload!\n");
99 
100     ok_irql(PASSIVE_LEVEL);
101     ok_eq_pointer(DriverObject, TestDriverObject);
102 
103     trace("Unloading example driver\n");
104 }
105 
106 /**
107  * @name TestMessageHandler
108  *
109  * Test message handler routine
110  *
111  * @param DeviceObject
112  *        Device Object.
113  *        This is guaranteed not to have been touched by the dispatch function
114  *        before the call to the IRP handler
115  * @param Irp
116  *        Device Object.
117  *        This is guaranteed not to have been touched by the dispatch function
118  *        before the call to the IRP handler, except for passing it to
119  *        IoGetCurrentStackLocation
120  * @param IoStackLocation
121  *        Device Object.
122  *        This is guaranteed not to have been touched by the dispatch function
123  *        before the call to the IRP handler
124  *
125  * @return Status
126  */
127 static
128 NTSTATUS
129 TestMessageHandler(
130     IN PDEVICE_OBJECT DeviceObject,
131     IN ULONG ControlCode,
132     IN PVOID Buffer OPTIONAL,
133     IN SIZE_T InLength,
134     IN OUT PSIZE_T OutLength)
135 {
136     NTSTATUS Status = STATUS_SUCCESS;
137 
138     switch (ControlCode)
139     {
140         case IOCTL_NOTIFY:
141         {
142             static int TimesReceived = 0;
143 
144             ++TimesReceived;
145             ok(TimesReceived == 1, "Received control code 1 %d times\n", TimesReceived);
146             ok_eq_pointer(Buffer, NULL);
147             ok_eq_ulong((ULONG)InLength, 0LU);
148             ok_eq_ulong((ULONG)*OutLength, 0LU);
149             break;
150         }
151         case IOCTL_SEND_STRING:
152         {
153             static int TimesReceived = 0;
154             ANSI_STRING ExpectedString = RTL_CONSTANT_STRING("yay");
155             ANSI_STRING ReceivedString;
156 
157             ++TimesReceived;
158             ok(TimesReceived == 1, "Received control code 2 %d times\n", TimesReceived);
159             ok(Buffer != NULL, "Buffer is NULL\n");
160             ok_eq_ulong((ULONG)InLength, (ULONG)ExpectedString.Length);
161             ok_eq_ulong((ULONG)*OutLength, 0LU);
162             ReceivedString.MaximumLength = ReceivedString.Length = (USHORT)InLength;
163             ReceivedString.Buffer = Buffer;
164             ok(RtlCompareString(&ExpectedString, &ReceivedString, FALSE) == 0, "Received string: %Z\n", &ReceivedString);
165             break;
166         }
167         case IOCTL_SEND_MYSTRUCT:
168         {
169             static int TimesReceived = 0;
170             MY_STRUCT ExpectedStruct = { 123, ":D" };
171             MY_STRUCT ResultStruct = { 456, "!!!" };
172 
173             ++TimesReceived;
174             ok(TimesReceived == 1, "Received control code 3 %d times\n", TimesReceived);
175             ok(Buffer != NULL, "Buffer is NULL\n");
176             ok_eq_ulong((ULONG)InLength, (ULONG)sizeof ExpectedStruct);
177             ok_eq_ulong((ULONG)*OutLength, 2LU * sizeof ExpectedStruct);
178             if (!skip(Buffer && InLength >= sizeof ExpectedStruct, "Cannot read from buffer!\n"))
179                 ok(RtlCompareMemory(&ExpectedStruct, Buffer, sizeof ExpectedStruct) == sizeof ExpectedStruct, "Buffer does not contain expected values\n");
180 
181             if (!skip(Buffer && *OutLength >= 2 * sizeof ExpectedStruct, "Cannot write to buffer!\n"))
182             {
183                 RtlCopyMemory((PCHAR)Buffer + sizeof ExpectedStruct, &ResultStruct, sizeof ResultStruct);
184                 *OutLength = 2 * sizeof ExpectedStruct;
185             }
186             break;
187         }
188         default:
189             ok(0, "Got an unknown message! DeviceObject=%p, ControlCode=%lu, Buffer=%p, In=%lu, Out=%lu bytes\n",
190                     DeviceObject, ControlCode, Buffer, InLength, *OutLength);
191             break;
192     }
193 
194     return Status;
195 }
196 
197 /**
198  * @name TestIrpHandler
199  *
200  * Test IRP handler routine
201  *
202  * @param DeviceObject
203  *        Device Object.
204  *        This is guaranteed not to have been touched by the dispatch function
205  *        before the call to the IRP handler
206  * @param Irp
207  *        Device Object.
208  *        This is guaranteed not to have been touched by the dispatch function
209  *        before the call to the IRP handler, except for passing it to
210  *        IoGetCurrentStackLocation
211  * @param IoStackLocation
212  *        Device Object.
213  *        This is guaranteed not to have been touched by the dispatch function
214  *        before the call to the IRP handler
215  *
216  * @return Status
217  */
218 static
219 NTSTATUS
220 TestIrpHandler(
221     IN PDEVICE_OBJECT DeviceObject,
222     IN PIRP Irp,
223     IN PIO_STACK_LOCATION IoStackLocation)
224 {
225     NTSTATUS Status = STATUS_SUCCESS;
226 
227     DPRINT("IRP!\n");
228 
229     ok_irql(PASSIVE_LEVEL);
230     ok_eq_pointer(DeviceObject->DriverObject, TestDriverObject);
231 
232     if (IoStackLocation->MajorFunction == IRP_MJ_CREATE)
233         trace("Got IRP_MJ_CREATE!\n");
234     else if (IoStackLocation->MajorFunction == IRP_MJ_CLOSE)
235         trace("Got IRP_MJ_CLOSE!\n");
236     else
237         trace("Got an IRP!\n");
238 
239     Irp->IoStatus.Status = Status;
240     Irp->IoStatus.Information = 0;
241 
242     IoCompleteRequest(Irp, IO_NO_INCREMENT);
243 
244     return Status;
245 }
246