1 /*
2  * PROJECT:         ReactOS kernel-mode tests
3  * LICENSE:         GPLv2+ - See COPYING in the top level directory
4  * PURPOSE:         Kernel-Mode Test Suite MmMapLockedPagesSpecifyCache test user-mode part
5  * PROGRAMMER:      Pierre Schweitzer <pierre@reactos.org>
6  */
7 
8 #include <kmt_test.h>
9 #include <ndk/exfuncs.h>
10 
11 #include "MmMapLockedPagesSpecifyCache.h"
12 
/*
 * Round an address or size down to a multiple of align.
 * align must be a power of two (the mask ~(align - 1) relies on it).
 * System headers may already provide an equivalent definition, so only
 * define it here when it is missing to avoid a redefinition conflict.
 */
#ifndef ALIGN_DOWN_BY
#define ALIGN_DOWN_BY(size, align) \
        ((ULONG_PTR)(size) & ~((ULONG_PTR)(align) - 1))
#endif
15 
/*
 * Assign a compile-time constant buffer length to Var.
 * C_ASSERT rejects, at build time, any length that is not a whole number
 * of ULONGs, because FILL_READ_BUFFER fills the buffer one ULONG at a time.
 * Wrapped in do { } while (0) so the macro expands to a single statement
 * (safe inside an unbraced if/else).
 */
#define SET_BUFFER_LENGTH(Var, Length)         \
do                                             \
{                                              \
    C_ASSERT(((Length) % sizeof(ULONG)) == 0); \
    (Var) = (Length);                          \
} while (0)
21 
/*
 * Reset a QUERY_BUFFER before an IOCTL_QUERY_BUFFER request: request
 * BufferLength bytes with the given cache setting, with Buffer cleared to
 * NULL and Status preset to STATUS_SUCCESS. Callers overwrite Buffer and/or
 * Status afterwards when a test asks for a specific address or outcome.
 * Wrapped in do { } while (0) so the macro expands to a single statement.
 */
#define FILL_QUERY_BUFFER(QueryBuffer, BufferLength, UseCache) \
do                                                             \
{                                                              \
    (QueryBuffer).Length = (BufferLength);                     \
    (QueryBuffer).Buffer = NULL;                               \
    (QueryBuffer).Cached = (UseCache);                         \
    (QueryBuffer).Status = STATUS_SUCCESS;                     \
} while (0)
29 
/*
 * Prepare a READ_BUFFER for IOCTL_READ_BUFFER from the result of a previous
 * IOCTL_QUERY_BUFFER, and fill the user-mode mapping with WRITE_PATTERN.
 *
 * Buffer, Pattern and Length are always initialized, even when the mapping
 * is NULL, so the request sent to the driver never carries stale or
 * uninitialized fields. The mapping itself is written only when it is
 * non-NULL; otherwise the step is reported as skipped.
 *
 * NOTE: the fill loop uses the caller's variable `i`, which must be declared
 * in the enclosing scope and be wide enough for Length / sizeof(ULONG).
 */
#define FILL_READ_BUFFER(QueryBuffer, ReadBuffer)                 \
do                                                                \
{                                                                 \
    PULONG FillPtr;                                               \
    (ReadBuffer).Buffer = (QueryBuffer).Buffer;                   \
    (ReadBuffer).Pattern = WRITE_PATTERN;                         \
    (ReadBuffer).Length = (QueryBuffer).Length;                   \
    if (!skip((QueryBuffer).Buffer != NULL, "Buffer is NULL\n"))  \
    {                                                             \
        FillPtr = (QueryBuffer).Buffer;                           \
        for (i = 0; i < (ReadBuffer).Length / sizeof(ULONG); ++i) \
        {                                                         \
            FillPtr[i] = (ReadBuffer).Pattern;                    \
        }                                                         \
    }                                                             \
} while (0)
45 
/*
 * Probe, from user mode, a buffer the driver mapped into this process.
 * The expected results are:
 *  - reserving the same range fails with STATUS_CONFLICTING_ADDRESSES;
 *  - NtFreeVirtualMemory with MEM_DECOMMIT or MEM_RELEASE fails with
 *    STATUS_UNABLE_TO_DELETE_SECTION;
 *  - NtUnmapViewOfSection fails with STATUS_NOT_MAPPED_VIEW.
 * A NULL MappedBuffer is skipped. Passing NULL as the base address to
 * NtAllocateVirtualMemory would reserve a fresh region at a system-chosen
 * address that the checks below would never free.
 * Wrapped in do { } while (0) so the macro expands to a single statement.
 */
#define CHECK_ALLOC(MappedBuffer, BufferLength)                      \
do                                                                   \
{                                                                    \
    NTSTATUS Status;                                                 \
    PVOID BaseAddress;                                               \
    SIZE_T Size;                                                     \
    if (!skip((MappedBuffer) != NULL, "Mapped buffer is NULL\n"))    \
    {                                                                \
        /* Reserving over the mapping must collide with it */        \
        BaseAddress = (MappedBuffer);                                \
        Size = (BufferLength);                                       \
        Status = NtAllocateVirtualMemory(NtCurrentProcess(),         \
                                         &BaseAddress,               \
                                         0,                          \
                                         &Size,                      \
                                         MEM_RESERVE,                \
                                         PAGE_READWRITE);            \
        ok_eq_hex(Status, STATUS_CONFLICTING_ADDRESSES);             \
        /* Neither decommit nor release may remove the mapping */    \
        BaseAddress = (MappedBuffer);                                \
        Size = 0;                                                    \
        Status = NtFreeVirtualMemory(NtCurrentProcess(),             \
                                     &BaseAddress,                   \
                                     &Size,                          \
                                     MEM_DECOMMIT);                  \
        ok_eq_hex(Status, STATUS_UNABLE_TO_DELETE_SECTION);          \
        BaseAddress = (MappedBuffer);                                \
        Size = 0;                                                    \
        Status = NtFreeVirtualMemory(NtCurrentProcess(),             \
                                     &BaseAddress,                   \
                                     &Size,                          \
                                     MEM_RELEASE);                   \
        ok_eq_hex(Status, STATUS_UNABLE_TO_DELETE_SECTION);          \
        /* The mapping must not be unmappable as a section view */   \
        Status = NtUnmapViewOfSection(NtCurrentProcess(),            \
                                      (MappedBuffer));               \
        ok_eq_hex(Status, STATUS_NOT_MAPPED_VIEW);                   \
    }                                                                \
} while (0)
78 
79 START_TEST(MmMapLockedPagesSpecifyCache)
80 {
81     QUERY_BUFFER QueryBuffer;
82     READ_BUFFER ReadBuffer;
83     DWORD Length;
84     USHORT i;
85     USHORT BufferLength;
86     SYSTEM_BASIC_INFORMATION BasicInfo;
87     NTSTATUS Status;
88     ULONG_PTR HighestAddress;
89     DWORD Error;
90 
91     Error = KmtLoadAndOpenDriver(L"MmMapLockedPagesSpecifyCache", FALSE);
92     ok_eq_int(Error, ERROR_SUCCESS);
93     if (Error)
94         return;
95 
96     // Less than a page
97     SET_BUFFER_LENGTH(BufferLength, 2048);
98     Length = sizeof(QUERY_BUFFER);
99     FILL_QUERY_BUFFER(QueryBuffer, BufferLength, FALSE);
100     ok(KmtSendBufferToDriver(IOCTL_QUERY_BUFFER, &QueryBuffer, sizeof(QUERY_BUFFER), &Length) == ERROR_SUCCESS, "\n");
101     ok_eq_int(QueryBuffer.Length, BufferLength);
102     ok(QueryBuffer.Buffer != NULL, "Buffer is NULL\n");
103     CHECK_ALLOC(QueryBuffer.Buffer, BufferLength);
104 
105     Length = 0;
106     FILL_READ_BUFFER(QueryBuffer, ReadBuffer);
107     ok(KmtSendBufferToDriver(IOCTL_READ_BUFFER, &ReadBuffer, sizeof(READ_BUFFER), &Length) == ERROR_SUCCESS, "\n");
108 
109     Length = sizeof(QUERY_BUFFER);
110     FILL_QUERY_BUFFER(QueryBuffer, BufferLength, TRUE);
111     ok(KmtSendBufferToDriver(IOCTL_QUERY_BUFFER, &QueryBuffer, sizeof(QUERY_BUFFER), &Length) == ERROR_SUCCESS, "\n");
112     ok_eq_int(QueryBuffer.Length, BufferLength);
113     ok(QueryBuffer.Buffer != NULL, "Buffer is NULL\n");
114     CHECK_ALLOC(QueryBuffer.Buffer, BufferLength);
115 
116     Length = 0;
117     FILL_READ_BUFFER(QueryBuffer, ReadBuffer);
118     ok(KmtSendBufferToDriver(IOCTL_READ_BUFFER, &ReadBuffer, sizeof(READ_BUFFER), &Length) == ERROR_SUCCESS, "\n");
119 
120     // 1 page
121     SET_BUFFER_LENGTH(BufferLength, 4096);
122     Length = sizeof(QUERY_BUFFER);
123     FILL_QUERY_BUFFER(QueryBuffer, BufferLength, FALSE);
124     ok(KmtSendBufferToDriver(IOCTL_QUERY_BUFFER, &QueryBuffer, sizeof(QUERY_BUFFER), &Length) == ERROR_SUCCESS, "\n");
125     ok_eq_int(QueryBuffer.Length, BufferLength);
126     ok(QueryBuffer.Buffer != NULL, "Buffer is NULL\n");
127     CHECK_ALLOC(QueryBuffer.Buffer, BufferLength);
128 
129     Length = 0;
130     FILL_READ_BUFFER(QueryBuffer, ReadBuffer);
131     ok(KmtSendBufferToDriver(IOCTL_READ_BUFFER, &ReadBuffer, sizeof(READ_BUFFER), &Length) == ERROR_SUCCESS, "\n");
132 
133     Length = sizeof(QUERY_BUFFER);
134     FILL_QUERY_BUFFER(QueryBuffer, BufferLength, TRUE);
135     ok(KmtSendBufferToDriver(IOCTL_QUERY_BUFFER, &QueryBuffer, sizeof(QUERY_BUFFER), &Length) == ERROR_SUCCESS, "\n");
136     ok_eq_int(QueryBuffer.Length, BufferLength);
137     ok(QueryBuffer.Buffer != NULL, "Buffer is NULL\n");
138     CHECK_ALLOC(QueryBuffer.Buffer, BufferLength);
139 
140     Length = 0;
141     FILL_READ_BUFFER(QueryBuffer, ReadBuffer);
142     ok(KmtSendBufferToDriver(IOCTL_READ_BUFFER, &ReadBuffer, sizeof(READ_BUFFER), &Length) == ERROR_SUCCESS, "\n");
143 
144     // more than 1 page
145     SET_BUFFER_LENGTH(BufferLength, 4096 + 2048);
146     Length = sizeof(QUERY_BUFFER);
147     FILL_QUERY_BUFFER(QueryBuffer, BufferLength, FALSE);
148     ok(KmtSendBufferToDriver(IOCTL_QUERY_BUFFER, &QueryBuffer, sizeof(QUERY_BUFFER), &Length) == ERROR_SUCCESS, "\n");
149     ok_eq_int(QueryBuffer.Length, BufferLength);
150     ok(QueryBuffer.Buffer != NULL, "Buffer is NULL\n");
151     CHECK_ALLOC(QueryBuffer.Buffer, BufferLength);
152 
153     Length = 0;
154     FILL_READ_BUFFER(QueryBuffer, ReadBuffer);
155     ok(KmtSendBufferToDriver(IOCTL_READ_BUFFER, &ReadBuffer, sizeof(READ_BUFFER), &Length) == ERROR_SUCCESS, "\n");
156 
157     Length = sizeof(QUERY_BUFFER);
158     FILL_QUERY_BUFFER(QueryBuffer, BufferLength, TRUE);
159     ok(KmtSendBufferToDriver(IOCTL_QUERY_BUFFER, &QueryBuffer, sizeof(QUERY_BUFFER), &Length) == ERROR_SUCCESS, "\n");
160     ok_eq_int(QueryBuffer.Length, BufferLength);
161     ok(QueryBuffer.Buffer != NULL, "Buffer is NULL\n");
162     CHECK_ALLOC(QueryBuffer.Buffer, BufferLength);
163 
164     Length = 0;
165     FILL_READ_BUFFER(QueryBuffer, ReadBuffer);
166     ok(KmtSendBufferToDriver(IOCTL_READ_BUFFER, &ReadBuffer, sizeof(READ_BUFFER), &Length) == ERROR_SUCCESS, "\n");
167 
168     // 2 pages
169     SET_BUFFER_LENGTH(BufferLength, 2 * 4096);
170     Length = sizeof(QUERY_BUFFER);
171     FILL_QUERY_BUFFER(QueryBuffer, BufferLength, FALSE);
172     ok(KmtSendBufferToDriver(IOCTL_QUERY_BUFFER, &QueryBuffer, sizeof(QUERY_BUFFER), &Length) == ERROR_SUCCESS, "\n");
173     ok_eq_int(QueryBuffer.Length, BufferLength);
174     ok(QueryBuffer.Buffer != NULL, "Buffer is NULL\n");
175     CHECK_ALLOC(QueryBuffer.Buffer, BufferLength);
176 
177     Length = 0;
178     FILL_READ_BUFFER(QueryBuffer, ReadBuffer);
179     ok(KmtSendBufferToDriver(IOCTL_READ_BUFFER, &ReadBuffer, sizeof(READ_BUFFER), &Length) == ERROR_SUCCESS, "\n");
180 
181     Length = sizeof(QUERY_BUFFER);
182     FILL_QUERY_BUFFER(QueryBuffer, BufferLength, TRUE);
183     ok(KmtSendBufferToDriver(IOCTL_QUERY_BUFFER, &QueryBuffer, sizeof(QUERY_BUFFER), &Length) == ERROR_SUCCESS, "\n");
184     ok_eq_int(QueryBuffer.Length, BufferLength);
185     ok(QueryBuffer.Buffer != NULL, "Buffer is NULL\n");
186     CHECK_ALLOC(QueryBuffer.Buffer, BufferLength);
187 
188     Length = 0;
189     FILL_READ_BUFFER(QueryBuffer, ReadBuffer);
190     ok(KmtSendBufferToDriver(IOCTL_READ_BUFFER, &ReadBuffer, sizeof(READ_BUFFER), &Length) == ERROR_SUCCESS, "\n");
191 
192     // more than 2 pages
193     SET_BUFFER_LENGTH(BufferLength, 2 * 4096 + 2048);
194     FILL_QUERY_BUFFER(QueryBuffer, BufferLength, FALSE);
195     Length = sizeof(QUERY_BUFFER);
196     ok(KmtSendBufferToDriver(IOCTL_QUERY_BUFFER, &QueryBuffer, sizeof(QUERY_BUFFER), &Length) == ERROR_SUCCESS, "\n");
197     ok_eq_int(QueryBuffer.Length, BufferLength);
198     ok(QueryBuffer.Buffer != NULL, "Buffer is NULL\n");
199     CHECK_ALLOC(QueryBuffer.Buffer, BufferLength);
200 
201     Length = 0;
202     FILL_READ_BUFFER(QueryBuffer, ReadBuffer);
203     ok(KmtSendBufferToDriver(IOCTL_READ_BUFFER, &ReadBuffer, sizeof(READ_BUFFER), &Length) == ERROR_SUCCESS, "\n");
204 
205     FILL_QUERY_BUFFER(QueryBuffer, BufferLength, TRUE);
206     Length = sizeof(QUERY_BUFFER);
207     ok(KmtSendBufferToDriver(IOCTL_QUERY_BUFFER, &QueryBuffer, sizeof(QUERY_BUFFER), &Length) == ERROR_SUCCESS, "\n");
208     ok_eq_int(QueryBuffer.Length, BufferLength);
209     ok(QueryBuffer.Buffer != NULL, "Buffer is NULL\n");
210     CHECK_ALLOC(QueryBuffer.Buffer, BufferLength);
211 
212     Length = 0;
213     FILL_READ_BUFFER(QueryBuffer, ReadBuffer);
214     ok(KmtSendBufferToDriver(IOCTL_READ_BUFFER, &ReadBuffer, sizeof(READ_BUFFER), &Length) == ERROR_SUCCESS, "\n");
215 
216     // ask for a specific address (we know that ReadBuffer.Buffer is free)
217     SET_BUFFER_LENGTH(BufferLength, 4096);
218     FILL_QUERY_BUFFER(QueryBuffer, BufferLength, FALSE);
219     QueryBuffer.Buffer = ReadBuffer.Buffer;
220     Length = sizeof(QUERY_BUFFER);
221     ok(KmtSendBufferToDriver(IOCTL_QUERY_BUFFER, &QueryBuffer, sizeof(QUERY_BUFFER), &Length) == ERROR_SUCCESS, "\n");
222     ok_eq_int(QueryBuffer.Length, BufferLength);
223     ok(QueryBuffer.Buffer == ReadBuffer.Buffer, "Buffer is NULL\n");
224     CHECK_ALLOC(QueryBuffer.Buffer, BufferLength);
225 
226     Length = 0;
227     FILL_READ_BUFFER(QueryBuffer, ReadBuffer);
228     ok(KmtSendBufferToDriver(IOCTL_READ_BUFFER, &ReadBuffer, sizeof(READ_BUFFER), &Length) == ERROR_SUCCESS, "\n");
229 
230     // ask for an unaligned address
231     SET_BUFFER_LENGTH(BufferLength, 4096);
232     FILL_QUERY_BUFFER(QueryBuffer, BufferLength, FALSE);
233     QueryBuffer.Buffer = (PVOID)((ULONG_PTR)ReadBuffer.Buffer + 2048);
234     QueryBuffer.Status = STATUS_INVALID_ADDRESS;
235     Length = sizeof(QUERY_BUFFER);
236     ok(KmtSendBufferToDriver(IOCTL_QUERY_BUFFER, &QueryBuffer, sizeof(QUERY_BUFFER), &Length) == ERROR_SUCCESS, "\n");
237     ok_eq_int(QueryBuffer.Length, BufferLength);
238     ok(QueryBuffer.Buffer == NULL, "Buffer is %p\n", QueryBuffer.Buffer);
239 
240     Length = 0;
241     ok(KmtSendBufferToDriver(IOCTL_CLEAN, NULL, 0, &Length) == ERROR_SUCCESS, "\n");
242 
243     // get system info for MmHighestUserAddress
244     Status = NtQuerySystemInformation(SystemBasicInformation,
245                                       &BasicInfo,
246                                       sizeof(BasicInfo),
247                                       NULL);
248     ok_eq_hex(Status, STATUS_SUCCESS);
249     trace("MaximumUserModeAddress: %lx\n", BasicInfo.MaximumUserModeAddress);
250     HighestAddress = ALIGN_DOWN_BY(BasicInfo.MaximumUserModeAddress, PAGE_SIZE);
251 
252     // near MmHighestUserAddress
253     SET_BUFFER_LENGTH(BufferLength, 4096);
254     FILL_QUERY_BUFFER(QueryBuffer, BufferLength, FALSE);
255     QueryBuffer.Buffer = (PVOID)(HighestAddress - 15 * PAGE_SIZE); // 7ffe0000
256     QueryBuffer.Status = STATUS_INVALID_ADDRESS;
257     trace("QueryBuffer.Buffer %p\n", QueryBuffer.Buffer);
258     Length = sizeof(QUERY_BUFFER);
259     ok(KmtSendBufferToDriver(IOCTL_QUERY_BUFFER, &QueryBuffer, sizeof(QUERY_BUFFER), &Length) == ERROR_SUCCESS, "\n");
260     ok_eq_int(QueryBuffer.Length, BufferLength);
261     ok(QueryBuffer.Buffer == NULL, "Buffer is %p\n", QueryBuffer.Buffer);
262 
263     Length = 0;
264     ok(KmtSendBufferToDriver(IOCTL_CLEAN, NULL, 0, &Length) == ERROR_SUCCESS, "\n");
265 
266     // far enough away from MmHighestUserAddress
267     SET_BUFFER_LENGTH(BufferLength, 4096);
268     FILL_QUERY_BUFFER(QueryBuffer, BufferLength, FALSE);
269     QueryBuffer.Buffer = (PVOID)(HighestAddress - 16 * PAGE_SIZE); // 7ffdf000
270     QueryBuffer.Status = -1;
271     trace("QueryBuffer.Buffer %p\n", QueryBuffer.Buffer);
272     Length = sizeof(QUERY_BUFFER);
273     ok(KmtSendBufferToDriver(IOCTL_QUERY_BUFFER, &QueryBuffer, sizeof(QUERY_BUFFER), &Length) == ERROR_SUCCESS, "\n");
274     ok_eq_int(QueryBuffer.Length, BufferLength);
275     ok(QueryBuffer.Status == STATUS_SUCCESS ||
276        QueryBuffer.Status == STATUS_CONFLICTING_ADDRESSES, "Status = %lx\n", QueryBuffer.Status);
277 
278     Length = 0;
279     ok(KmtSendBufferToDriver(IOCTL_CLEAN, NULL, 0, &Length) == ERROR_SUCCESS, "\n");
280 
281     KmtCloseDriver();
282     KmtUnloadDriver();
283 }
284