1 /*
2  * PROJECT:         ReactOS kernel-mode tests
3  * LICENSE:         GPLv2+ - See COPYING in the top level directory
4  * PURPOSE:         Kernel-Mode Test Suite MmMapLockedPagesSpecifyCache test user-mode part
5  * PROGRAMMER:      Pierre Schweitzer <pierre@reactos.org>
6  */
7 
8 #include <kmt_test.h>
9 #include <ndk/exfuncs.h>
10 
11 #include "MmMapLockedPagesSpecifyCache.h"
12 
/*
 * Round `size` down to a multiple of `align` by masking off the low bits.
 * The mask trick is only correct when `align` is a power of two.
 * NOTE(review): keep this definition token-for-token as is. If a system header
 * already defines ALIGN_DOWN_BY, only an identical redefinition is accepted
 * silently. Verify against the SDK headers before changing it.
 */
#define ALIGN_DOWN_BY(size, align) \
        ((ULONG_PTR)(size) & ~((ULONG_PTR)(align) - 1))
15 
/*
 * Assign a test buffer length to Var.
 * C_ASSERT is a compile-time check, so Length must be a constant expression.
 * Any length that is not a whole number of ULONGs is rejected at build time;
 * the read-buffer fill loop writes the buffer one ULONG at a time.
 * Wrapped in do/while (0) so the macro behaves as a single statement.
 */
#define SET_BUFFER_LENGTH(Var, Length)                 \
    do                                                 \
    {                                                  \
        C_ASSERT(((Length) % sizeof(ULONG)) == 0);     \
        (Var) = (Length);                              \
    } while (0)
21 
/*
 * Initialize a QUERY_BUFFER request for IOCTL_QUERY_BUFFER:
 * - length BufferLength,
 * - no requested address (Buffer = NULL),
 * - cache mode UseCache,
 * - Status reset to STATUS_SUCCESS.
 * Callers may override Buffer/Status afterwards for specific test cases.
 * Wrapped in do/while (0) so the macro behaves as a single statement.
 */
#define FILL_QUERY_BUFFER(QueryBuffer, BufferLength, UseCache) \
    do                                                         \
    {                                                          \
        (QueryBuffer).Length = (BufferLength);                 \
        (QueryBuffer).Buffer = NULL;                           \
        (QueryBuffer).Cached = (UseCache);                     \
        (QueryBuffer).Status = STATUS_SUCCESS;                 \
    } while (0)
29 
/*
 * Prepare ReadBuffer for IOCTL_READ_BUFFER from the result of a previous
 * IOCTL_QUERY_BUFFER held in QueryBuffer.
 * Buffer, Pattern and Length are always copied, so the request never carries
 * uninitialized or stale fields from an earlier test case, even when the
 * mapping failed.
 * If a buffer was actually returned, every ULONG of it is filled with
 * WRITE_PATTERN. Otherwise the fill is reported as skipped.
 * NOTE: the fill loop uses the loop variable `i`, which must be declared in
 * the caller's scope.
 */
#define FILL_READ_BUFFER(QueryBuffer, ReadBuffer)                         \
    do                                                                    \
    {                                                                     \
        PULONG FillPtr;                                                   \
        (ReadBuffer).Buffer = (QueryBuffer).Buffer;                       \
        (ReadBuffer).Pattern = WRITE_PATTERN;                             \
        (ReadBuffer).Length = (QueryBuffer).Length;                       \
        if (!skip((QueryBuffer).Buffer != NULL, "Buffer is NULL\n"))      \
        {                                                                 \
            FillPtr = (QueryBuffer).Buffer;                               \
            for (i = 0; i < (ReadBuffer).Length / sizeof(ULONG); ++i)     \
            {                                                             \
                FillPtr[i] = (ReadBuffer).Pattern;                        \
            }                                                             \
        }                                                                 \
    } while (0)
45 
/*
 * Check that MappedBuffer cannot be manipulated through the regular user-mode
 * virtual-memory APIs. MappedBuffer is presumably the user-mode mapping
 * returned by the driver; verify against the kernel-mode part.
 * Expected results:
 * - reserving BufferLength bytes at MappedBuffer: STATUS_CONFLICTING_ADDRESSES;
 * - NtFreeVirtualMemory with MEM_DECOMMIT: STATUS_UNABLE_TO_DELETE_SECTION;
 * - NtFreeVirtualMemory with MEM_RELEASE: STATUS_UNABLE_TO_DELETE_SECTION;
 * - NtUnmapViewOfSection: STATUS_NOT_MAPPED_VIEW.
 * The local status variable is deliberately not named `Status`, so it does not
 * shadow the caller's variable. MappedBuffer is evaluated once.
 */
#define CHECK_ALLOC(MappedBuffer, BufferLength)                       \
    do                                                                \
    {                                                                 \
        NTSTATUS AllocStatus;                                         \
        PVOID Mapped = (MappedBuffer);                                \
        PVOID BaseAddress;                                            \
        SIZE_T Size;                                                  \
        BaseAddress = Mapped;                                         \
        Size = (BufferLength);                                        \
        AllocStatus = NtAllocateVirtualMemory(NtCurrentProcess(),     \
                                              &BaseAddress,           \
                                              0,                      \
                                              &Size,                  \
                                              MEM_RESERVE,            \
                                              PAGE_READWRITE);        \
        ok_eq_hex(AllocStatus, STATUS_CONFLICTING_ADDRESSES);         \
        BaseAddress = Mapped;                                         \
        Size = 0;                                                     \
        AllocStatus = NtFreeVirtualMemory(NtCurrentProcess(),         \
                                          &BaseAddress,               \
                                          &Size,                      \
                                          MEM_DECOMMIT);              \
        ok_eq_hex(AllocStatus, STATUS_UNABLE_TO_DELETE_SECTION);      \
        BaseAddress = Mapped;                                         \
        Size = 0;                                                     \
        AllocStatus = NtFreeVirtualMemory(NtCurrentProcess(),         \
                                          &BaseAddress,               \
                                          &Size,                      \
                                          MEM_RELEASE);               \
        ok_eq_hex(AllocStatus, STATUS_UNABLE_TO_DELETE_SECTION);      \
        AllocStatus = NtUnmapViewOfSection(NtCurrentProcess(),        \
                                           Mapped);                   \
        ok_eq_hex(AllocStatus, STATUS_NOT_MAPPED_VIEW);               \
    } while (0)
78 
79 START_TEST(MmMapLockedPagesSpecifyCache)
80 {
81     QUERY_BUFFER QueryBuffer;
82     READ_BUFFER ReadBuffer;
83     DWORD Length;
84     USHORT i;
85     USHORT BufferLength;
86     SYSTEM_BASIC_INFORMATION BasicInfo;
87     NTSTATUS Status;
88     ULONG_PTR HighestAddress;
89 
90     KmtLoadDriver(L"MmMapLockedPagesSpecifyCache", FALSE);
91     KmtOpenDriver();
92 
93     // Less than a page
94     SET_BUFFER_LENGTH(BufferLength, 2048);
95     Length = sizeof(QUERY_BUFFER);
96     FILL_QUERY_BUFFER(QueryBuffer, BufferLength, FALSE);
97     ok(KmtSendBufferToDriver(IOCTL_QUERY_BUFFER, &QueryBuffer, sizeof(QUERY_BUFFER), &Length) == ERROR_SUCCESS, "\n");
98     ok_eq_int(QueryBuffer.Length, BufferLength);
99     ok(QueryBuffer.Buffer != NULL, "Buffer is NULL\n");
100     CHECK_ALLOC(QueryBuffer.Buffer, BufferLength);
101 
102     Length = 0;
103     FILL_READ_BUFFER(QueryBuffer, ReadBuffer);
104     ok(KmtSendBufferToDriver(IOCTL_READ_BUFFER, &ReadBuffer, sizeof(READ_BUFFER), &Length) == ERROR_SUCCESS, "\n");
105 
106     Length = sizeof(QUERY_BUFFER);
107     FILL_QUERY_BUFFER(QueryBuffer, BufferLength, TRUE);
108     ok(KmtSendBufferToDriver(IOCTL_QUERY_BUFFER, &QueryBuffer, sizeof(QUERY_BUFFER), &Length) == ERROR_SUCCESS, "\n");
109     ok_eq_int(QueryBuffer.Length, BufferLength);
110     ok(QueryBuffer.Buffer != NULL, "Buffer is NULL\n");
111     CHECK_ALLOC(QueryBuffer.Buffer, BufferLength);
112 
113     Length = 0;
114     FILL_READ_BUFFER(QueryBuffer, ReadBuffer);
115     ok(KmtSendBufferToDriver(IOCTL_READ_BUFFER, &ReadBuffer, sizeof(READ_BUFFER), &Length) == ERROR_SUCCESS, "\n");
116 
117     // 1 page
118     SET_BUFFER_LENGTH(BufferLength, 4096);
119     Length = sizeof(QUERY_BUFFER);
120     FILL_QUERY_BUFFER(QueryBuffer, BufferLength, FALSE);
121     ok(KmtSendBufferToDriver(IOCTL_QUERY_BUFFER, &QueryBuffer, sizeof(QUERY_BUFFER), &Length) == ERROR_SUCCESS, "\n");
122     ok_eq_int(QueryBuffer.Length, BufferLength);
123     ok(QueryBuffer.Buffer != NULL, "Buffer is NULL\n");
124     CHECK_ALLOC(QueryBuffer.Buffer, BufferLength);
125 
126     Length = 0;
127     FILL_READ_BUFFER(QueryBuffer, ReadBuffer);
128     ok(KmtSendBufferToDriver(IOCTL_READ_BUFFER, &ReadBuffer, sizeof(READ_BUFFER), &Length) == ERROR_SUCCESS, "\n");
129 
130     Length = sizeof(QUERY_BUFFER);
131     FILL_QUERY_BUFFER(QueryBuffer, BufferLength, TRUE);
132     ok(KmtSendBufferToDriver(IOCTL_QUERY_BUFFER, &QueryBuffer, sizeof(QUERY_BUFFER), &Length) == ERROR_SUCCESS, "\n");
133     ok_eq_int(QueryBuffer.Length, BufferLength);
134     ok(QueryBuffer.Buffer != NULL, "Buffer is NULL\n");
135     CHECK_ALLOC(QueryBuffer.Buffer, BufferLength);
136 
137     Length = 0;
138     FILL_READ_BUFFER(QueryBuffer, ReadBuffer);
139     ok(KmtSendBufferToDriver(IOCTL_READ_BUFFER, &ReadBuffer, sizeof(READ_BUFFER), &Length) == ERROR_SUCCESS, "\n");
140 
141     // more than 1 page
142     SET_BUFFER_LENGTH(BufferLength, 4096 + 2048);
143     Length = sizeof(QUERY_BUFFER);
144     FILL_QUERY_BUFFER(QueryBuffer, BufferLength, FALSE);
145     ok(KmtSendBufferToDriver(IOCTL_QUERY_BUFFER, &QueryBuffer, sizeof(QUERY_BUFFER), &Length) == ERROR_SUCCESS, "\n");
146     ok_eq_int(QueryBuffer.Length, BufferLength);
147     ok(QueryBuffer.Buffer != NULL, "Buffer is NULL\n");
148     CHECK_ALLOC(QueryBuffer.Buffer, BufferLength);
149 
150     Length = 0;
151     FILL_READ_BUFFER(QueryBuffer, ReadBuffer);
152     ok(KmtSendBufferToDriver(IOCTL_READ_BUFFER, &ReadBuffer, sizeof(READ_BUFFER), &Length) == ERROR_SUCCESS, "\n");
153 
154     Length = sizeof(QUERY_BUFFER);
155     FILL_QUERY_BUFFER(QueryBuffer, BufferLength, TRUE);
156     ok(KmtSendBufferToDriver(IOCTL_QUERY_BUFFER, &QueryBuffer, sizeof(QUERY_BUFFER), &Length) == ERROR_SUCCESS, "\n");
157     ok_eq_int(QueryBuffer.Length, BufferLength);
158     ok(QueryBuffer.Buffer != NULL, "Buffer is NULL\n");
159     CHECK_ALLOC(QueryBuffer.Buffer, BufferLength);
160 
161     Length = 0;
162     FILL_READ_BUFFER(QueryBuffer, ReadBuffer);
163     ok(KmtSendBufferToDriver(IOCTL_READ_BUFFER, &ReadBuffer, sizeof(READ_BUFFER), &Length) == ERROR_SUCCESS, "\n");
164 
165     // 2 pages
166     SET_BUFFER_LENGTH(BufferLength, 2 * 4096);
167     Length = sizeof(QUERY_BUFFER);
168     FILL_QUERY_BUFFER(QueryBuffer, BufferLength, FALSE);
169     ok(KmtSendBufferToDriver(IOCTL_QUERY_BUFFER, &QueryBuffer, sizeof(QUERY_BUFFER), &Length) == ERROR_SUCCESS, "\n");
170     ok_eq_int(QueryBuffer.Length, BufferLength);
171     ok(QueryBuffer.Buffer != NULL, "Buffer is NULL\n");
172     CHECK_ALLOC(QueryBuffer.Buffer, BufferLength);
173 
174     Length = 0;
175     FILL_READ_BUFFER(QueryBuffer, ReadBuffer);
176     ok(KmtSendBufferToDriver(IOCTL_READ_BUFFER, &ReadBuffer, sizeof(READ_BUFFER), &Length) == ERROR_SUCCESS, "\n");
177 
178     Length = sizeof(QUERY_BUFFER);
179     FILL_QUERY_BUFFER(QueryBuffer, BufferLength, TRUE);
180     ok(KmtSendBufferToDriver(IOCTL_QUERY_BUFFER, &QueryBuffer, sizeof(QUERY_BUFFER), &Length) == ERROR_SUCCESS, "\n");
181     ok_eq_int(QueryBuffer.Length, BufferLength);
182     ok(QueryBuffer.Buffer != NULL, "Buffer is NULL\n");
183     CHECK_ALLOC(QueryBuffer.Buffer, BufferLength);
184 
185     Length = 0;
186     FILL_READ_BUFFER(QueryBuffer, ReadBuffer);
187     ok(KmtSendBufferToDriver(IOCTL_READ_BUFFER, &ReadBuffer, sizeof(READ_BUFFER), &Length) == ERROR_SUCCESS, "\n");
188 
189     // more than 2 pages
190     SET_BUFFER_LENGTH(BufferLength, 2 * 4096 + 2048);
191     FILL_QUERY_BUFFER(QueryBuffer, BufferLength, FALSE);
192     Length = sizeof(QUERY_BUFFER);
193     ok(KmtSendBufferToDriver(IOCTL_QUERY_BUFFER, &QueryBuffer, sizeof(QUERY_BUFFER), &Length) == ERROR_SUCCESS, "\n");
194     ok_eq_int(QueryBuffer.Length, BufferLength);
195     ok(QueryBuffer.Buffer != NULL, "Buffer is NULL\n");
196     CHECK_ALLOC(QueryBuffer.Buffer, BufferLength);
197 
198     Length = 0;
199     FILL_READ_BUFFER(QueryBuffer, ReadBuffer);
200     ok(KmtSendBufferToDriver(IOCTL_READ_BUFFER, &ReadBuffer, sizeof(READ_BUFFER), &Length) == ERROR_SUCCESS, "\n");
201 
202     FILL_QUERY_BUFFER(QueryBuffer, BufferLength, TRUE);
203     Length = sizeof(QUERY_BUFFER);
204     ok(KmtSendBufferToDriver(IOCTL_QUERY_BUFFER, &QueryBuffer, sizeof(QUERY_BUFFER), &Length) == ERROR_SUCCESS, "\n");
205     ok_eq_int(QueryBuffer.Length, BufferLength);
206     ok(QueryBuffer.Buffer != NULL, "Buffer is NULL\n");
207     CHECK_ALLOC(QueryBuffer.Buffer, BufferLength);
208 
209     Length = 0;
210     FILL_READ_BUFFER(QueryBuffer, ReadBuffer);
211     ok(KmtSendBufferToDriver(IOCTL_READ_BUFFER, &ReadBuffer, sizeof(READ_BUFFER), &Length) == ERROR_SUCCESS, "\n");
212 
213     // ask for a specific address (we know that ReadBuffer.Buffer is free)
214     SET_BUFFER_LENGTH(BufferLength, 4096);
215     FILL_QUERY_BUFFER(QueryBuffer, BufferLength, FALSE);
216     QueryBuffer.Buffer = ReadBuffer.Buffer;
217     Length = sizeof(QUERY_BUFFER);
218     ok(KmtSendBufferToDriver(IOCTL_QUERY_BUFFER, &QueryBuffer, sizeof(QUERY_BUFFER), &Length) == ERROR_SUCCESS, "\n");
219     ok_eq_int(QueryBuffer.Length, BufferLength);
220     ok(QueryBuffer.Buffer == ReadBuffer.Buffer, "Buffer is NULL\n");
221     CHECK_ALLOC(QueryBuffer.Buffer, BufferLength);
222 
223     Length = 0;
224     FILL_READ_BUFFER(QueryBuffer, ReadBuffer);
225     ok(KmtSendBufferToDriver(IOCTL_READ_BUFFER, &ReadBuffer, sizeof(READ_BUFFER), &Length) == ERROR_SUCCESS, "\n");
226 
227     // ask for an unaligned address
228     SET_BUFFER_LENGTH(BufferLength, 4096);
229     FILL_QUERY_BUFFER(QueryBuffer, BufferLength, FALSE);
230     QueryBuffer.Buffer = (PVOID)((ULONG_PTR)ReadBuffer.Buffer + 2048);
231     QueryBuffer.Status = STATUS_INVALID_ADDRESS;
232     Length = sizeof(QUERY_BUFFER);
233     ok(KmtSendBufferToDriver(IOCTL_QUERY_BUFFER, &QueryBuffer, sizeof(QUERY_BUFFER), &Length) == ERROR_SUCCESS, "\n");
234     ok_eq_int(QueryBuffer.Length, BufferLength);
235     ok(QueryBuffer.Buffer == NULL, "Buffer is %p\n", QueryBuffer.Buffer);
236 
237     Length = 0;
238     ok(KmtSendBufferToDriver(IOCTL_CLEAN, NULL, 0, &Length) == ERROR_SUCCESS, "\n");
239 
240     // get system info for MmHighestUserAddress
241     Status = NtQuerySystemInformation(SystemBasicInformation,
242                                       &BasicInfo,
243                                       sizeof(BasicInfo),
244                                       NULL);
245     ok_eq_hex(Status, STATUS_SUCCESS);
246     trace("MaximumUserModeAddress: %lx\n", BasicInfo.MaximumUserModeAddress);
247     HighestAddress = ALIGN_DOWN_BY(BasicInfo.MaximumUserModeAddress, PAGE_SIZE);
248 
249     // near MmHighestUserAddress
250     SET_BUFFER_LENGTH(BufferLength, 4096);
251     FILL_QUERY_BUFFER(QueryBuffer, BufferLength, FALSE);
252     QueryBuffer.Buffer = (PVOID)(HighestAddress - 15 * PAGE_SIZE); // 7ffe0000
253     QueryBuffer.Status = STATUS_INVALID_ADDRESS;
254     trace("QueryBuffer.Buffer %p\n", QueryBuffer.Buffer);
255     Length = sizeof(QUERY_BUFFER);
256     ok(KmtSendBufferToDriver(IOCTL_QUERY_BUFFER, &QueryBuffer, sizeof(QUERY_BUFFER), &Length) == ERROR_SUCCESS, "\n");
257     ok_eq_int(QueryBuffer.Length, BufferLength);
258     ok(QueryBuffer.Buffer == NULL, "Buffer is %p\n", QueryBuffer.Buffer);
259 
260     Length = 0;
261     ok(KmtSendBufferToDriver(IOCTL_CLEAN, NULL, 0, &Length) == ERROR_SUCCESS, "\n");
262 
263     // far enough away from MmHighestUserAddress
264     SET_BUFFER_LENGTH(BufferLength, 4096);
265     FILL_QUERY_BUFFER(QueryBuffer, BufferLength, FALSE);
266     QueryBuffer.Buffer = (PVOID)(HighestAddress - 16 * PAGE_SIZE); // 7ffdf000
267     QueryBuffer.Status = -1;
268     trace("QueryBuffer.Buffer %p\n", QueryBuffer.Buffer);
269     Length = sizeof(QUERY_BUFFER);
270     ok(KmtSendBufferToDriver(IOCTL_QUERY_BUFFER, &QueryBuffer, sizeof(QUERY_BUFFER), &Length) == ERROR_SUCCESS, "\n");
271     ok_eq_int(QueryBuffer.Length, BufferLength);
272     ok(QueryBuffer.Status == STATUS_SUCCESS ||
273        QueryBuffer.Status == STATUS_CONFLICTING_ADDRESSES, "Status = %lx\n", QueryBuffer.Status);
274 
275     Length = 0;
276     ok(KmtSendBufferToDriver(IOCTL_CLEAN, NULL, 0, &Length) == ERROR_SUCCESS, "\n");
277 
278     KmtCloseDriver();
279     KmtUnloadDriver();
280 }
281