1 /*
2  * PROJECT:         ReactOS kernel-mode tests
3  * LICENSE:         GPLv2+ - See COPYING in the top level directory
4  * PURPOSE:         Test driver for MmMapLockedPagesSpecifyCache function
5  * PROGRAMMER:      Pierre Schweitzer <pierre@reactos.org>
6  */
7 
8 #include <kmt_test.h>
9 
10 #define NDEBUG
11 #include <debug.h>
12 
13 #include "MmMapLockedPagesSpecifyCache.h"
14 
15 static KMT_IRP_HANDLER TestIrpHandler;
16 static KMT_MESSAGE_HANDLER TestMessageHandler;
17 
18 static PVOID CurrentBuffer;
19 static PMDL CurrentMdl;
20 static PVOID CurrentUser;
21 static SIZE_T NonCachedLength;
22 
23 NTSTATUS
TestEntry(IN PDRIVER_OBJECT DriverObject,IN PCUNICODE_STRING RegistryPath,OUT PCWSTR * DeviceName,IN OUT INT * Flags)24 TestEntry(
25     IN PDRIVER_OBJECT DriverObject,
26     IN PCUNICODE_STRING RegistryPath,
27     OUT PCWSTR *DeviceName,
28     IN OUT INT *Flags)
29 {
30     NTSTATUS Status = STATUS_SUCCESS;
31 
32     PAGED_CODE();
33 
34     UNREFERENCED_PARAMETER(RegistryPath);
35     UNREFERENCED_PARAMETER(Flags);
36 
37     *DeviceName = L"MmMapLockedPagesSpecifyCache";
38 
39     KmtRegisterIrpHandler(IRP_MJ_CLEANUP, NULL, TestIrpHandler);
40     KmtRegisterMessageHandler(0, NULL, TestMessageHandler);
41 
42     return Status;
43 }
44 
45 VOID
TestUnload(IN PDRIVER_OBJECT DriverObject)46 TestUnload(
47     IN PDRIVER_OBJECT DriverObject)
48 {
49     PAGED_CODE();
50 }
51 
52 VOID
TestCleanEverything(VOID)53 TestCleanEverything(VOID)
54 {
55     NTSTATUS SehStatus;
56 
57     if (CurrentMdl == NULL)
58     {
59         return;
60     }
61 
62     if (CurrentUser != NULL)
63     {
64         SehStatus = STATUS_SUCCESS;
65         _SEH2_TRY
66         {
67             MmUnmapLockedPages(CurrentUser, CurrentMdl);
68         }
69         _SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER)
70         {
71             SehStatus = _SEH2_GetExceptionCode();
72         }
73         _SEH2_END;
74         ok_eq_hex(SehStatus, STATUS_SUCCESS);
75         CurrentUser = NULL;
76     }
77 
78     SehStatus = STATUS_SUCCESS;
79     _SEH2_TRY
80     {
81         MmUnlockPages(CurrentMdl);
82     }
83     _SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER)
84     {
85         SehStatus = _SEH2_GetExceptionCode();
86     }
87     _SEH2_END;
88     ok_eq_hex(SehStatus, STATUS_SUCCESS);
89     IoFreeMdl(CurrentMdl);
90     if (NonCachedLength)
91     {
92         MmFreeNonCachedMemory(CurrentBuffer, NonCachedLength);
93     }
94     else
95     {
96         ExFreePoolWithTag(CurrentBuffer, 'MLPC');
97     }
98     CurrentMdl = NULL;
99 }
100 
101 static
102 NTSTATUS
TestMessageHandler(IN PDEVICE_OBJECT DeviceObject,IN ULONG ControlCode,IN PVOID Buffer OPTIONAL,IN SIZE_T InLength,IN OUT PSIZE_T OutLength)103 TestMessageHandler(
104     IN PDEVICE_OBJECT DeviceObject,
105     IN ULONG ControlCode,
106     IN PVOID Buffer OPTIONAL,
107     IN SIZE_T InLength,
108     IN OUT PSIZE_T OutLength)
109 {
110     NTSTATUS Status = STATUS_SUCCESS;
111     NTSTATUS SehStatus;
112 
113     switch (ControlCode)
114     {
115         case IOCTL_QUERY_BUFFER:
116         {
117             ok(Buffer != NULL, "Buffer is NULL\n");
118             ok_eq_size(InLength, sizeof(QUERY_BUFFER));
119             ok_eq_size(*OutLength, sizeof(QUERY_BUFFER));
120             ok_eq_pointer(CurrentMdl, NULL);
121 
122             TestCleanEverything();
123 
124             ok(ExGetPreviousMode() == UserMode, "Not coming from umode!\n");
125             if (!skip(Buffer && InLength >= sizeof(QUERY_BUFFER) && *OutLength >= sizeof(QUERY_BUFFER), "Cannot read/write from/to buffer!\n"))
126             {
127                 PQUERY_BUFFER QueryBuffer;
128                 USHORT Length;
129                 MEMORY_CACHING_TYPE CacheType;
130 
131                 QueryBuffer = Buffer;
132                 CacheType = (QueryBuffer->Cached ? MmCached : MmNonCached);
133                 Length = QueryBuffer->Length;
134                 CurrentUser = NULL;
135                 ok(Length > 0, "Null size!\n");
136 
137                 if (!skip(Length > 0, "Null size!\n"))
138                 {
139                     if (QueryBuffer->Cached)
140                     {
141                         CurrentBuffer = ExAllocatePoolWithTag(NonPagedPool, Length, 'MLPC');
142                         ok(CurrentBuffer != NULL, "ExAllocatePool failed!\n");
143                         NonCachedLength = 0;
144                     }
145                     else
146                     {
147                         CurrentBuffer = MmAllocateNonCachedMemory(Length);
148                         ok(CurrentBuffer != NULL, "MmAllocateNonCachedMemory failed!\n");
149                         if (CurrentBuffer)
150                         {
151                             RtlZeroMemory(CurrentBuffer, Length);
152                             NonCachedLength = Length;
153                         }
154                     }
155                     if (!skip(CurrentBuffer != NULL, "ExAllocatePool failed!\n"))
156                     {
157                         CurrentMdl = IoAllocateMdl(CurrentBuffer, Length, FALSE, FALSE, NULL);
158                         ok(CurrentMdl != NULL, "IoAllocateMdl failed!\n");
159                         if (!skip(CurrentMdl != NULL, "IoAllocateMdl failed!\n"))
160                         {
161                             KIRQL Irql;
162 
163                             SehStatus = STATUS_SUCCESS;
164                             _SEH2_TRY
165                             {
166                                 MmProbeAndLockPages(CurrentMdl, KernelMode, IoWriteAccess);
167                             }
168                             _SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER)
169                             {
170                                 SehStatus = _SEH2_GetExceptionCode();
171                             }
172                             _SEH2_END;
173                             ok_eq_hex(SehStatus, STATUS_SUCCESS);
174 
175                             Irql = KeGetCurrentIrql();
176                             ok(Irql <= APC_LEVEL, "IRQL > APC_LEVEL: %d\n", Irql);
177 
178                             SehStatus = STATUS_SUCCESS;
179                             _SEH2_TRY
180                             {
181                                 CurrentUser = MmMapLockedPagesSpecifyCache(CurrentMdl, UserMode, CacheType, QueryBuffer->Buffer, FALSE, NormalPagePriority);
182                             }
183                             _SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER)
184                             {
185                                 SehStatus = _SEH2_GetExceptionCode();
186                             }
187                             _SEH2_END;
188                             if (QueryBuffer->Status != -1)
189                             {
190                                 ok_eq_hex(SehStatus, QueryBuffer->Status);
191                                 if (NT_SUCCESS(QueryBuffer->Status))
192                                 {
193                                     ok(CurrentUser != NULL, "MmMapLockedPagesSpecifyCache failed!\n");
194                                 }
195                                 else
196                                 {
197                                     ok(CurrentUser == NULL, "MmMapLockedPagesSpecifyCache succeeded!\n");
198                                 }
199                             }
200                             QueryBuffer->Status = SehStatus;
201                         }
202                         else
203                         {
204                             ExFreePoolWithTag(CurrentBuffer, 'MLPC');
205                         }
206                     }
207                 }
208 
209                 QueryBuffer->Buffer = CurrentUser;
210                 *OutLength = sizeof(QUERY_BUFFER);
211             }
212 
213             break;
214         }
215         case IOCTL_READ_BUFFER:
216         {
217             ok(Buffer != NULL, "Buffer is NULL\n");
218             ok_eq_size(InLength, sizeof(READ_BUFFER));
219             ok_eq_size(*OutLength, 0);
220             ok(CurrentMdl != NULL, "MDL is not in use!\n");
221 
222             if (!skip(Buffer && InLength >= sizeof(READ_BUFFER), "Cannot read from buffer!\n"))
223             {
224                 PREAD_BUFFER ReadBuffer;
225 
226                 ReadBuffer = Buffer;
227                 if (!skip(ReadBuffer && ReadBuffer->Buffer == CurrentUser, "Cannot find matching MDL\n"))
228                 {
229                     if (ReadBuffer->Buffer != NULL)
230                     {
231                         USHORT i;
232                         PULONG KBuffer = MmGetSystemAddressForMdlSafe(CurrentMdl, NormalPagePriority);
233                         ok(KBuffer != NULL, "Failed to get kmode ptr\n");
234                         ok(ReadBuffer->Length % sizeof(ULONG) == 0, "Invalid size: %d\n", ReadBuffer->Length);
235 
236                         if (!skip(Buffer != NULL, "Failed to get kmode ptr\n"))
237                         {
238                             for (i = 0; i < ReadBuffer->Length / sizeof(ULONG); ++i)
239                             {
240                                 ok_eq_ulong(KBuffer[i], ReadBuffer->Pattern);
241                             }
242                         }
243                     }
244                 }
245 
246                 TestCleanEverything();
247             }
248 
249             break;
250         }
251         case IOCTL_CLEAN:
252         {
253             TestCleanEverything();
254             break;
255         }
256         default:
257             ok(0, "Got an unknown message! DeviceObject=%p, ControlCode=%lu, Buffer=%p, In=%lu, Out=%lu bytes\n",
258                     DeviceObject, ControlCode, Buffer, InLength, *OutLength);
259             break;
260     }
261 
262     return Status;
263 }
264 
265 static
266 NTSTATUS
TestIrpHandler(_In_ PDEVICE_OBJECT DeviceObject,_In_ PIRP Irp,_In_ PIO_STACK_LOCATION IoStack)267 TestIrpHandler(
268     _In_ PDEVICE_OBJECT DeviceObject,
269     _In_ PIRP Irp,
270     _In_ PIO_STACK_LOCATION IoStack)
271 {
272     NTSTATUS Status;
273 
274     PAGED_CODE();
275 
276     DPRINT("IRP %x/%x\n", IoStack->MajorFunction, IoStack->MinorFunction);
277     ASSERT(IoStack->MajorFunction == IRP_MJ_CLEANUP);
278 
279     Status = STATUS_NOT_SUPPORTED;
280     Irp->IoStatus.Information = 0;
281 
282     if (IoStack->MajorFunction == IRP_MJ_CLEANUP)
283     {
284         TestCleanEverything();
285         Status = STATUS_SUCCESS;
286     }
287 
288     if (Status == STATUS_PENDING)
289     {
290         IoMarkIrpPending(Irp);
291         IoCompleteRequest(Irp, IO_NO_INCREMENT);
292         Status = STATUS_PENDING;
293     }
294     else
295     {
296         Irp->IoStatus.Status = Status;
297         IoCompleteRequest(Irp, IO_NO_INCREMENT);
298     }
299 
300     return Status;
301 }
302