1 /*
2  * PROJECT:         ReactOS kernel-mode tests
3  * LICENSE:         GPLv2+ - See COPYING in the top level directory
4  * PURPOSE:         Kernel-Mode Test Suite Fast Mutex test
5  * PROGRAMMER:      Thomas Faber <thomas.faber@reactos.org>
6  */
7 
8 #include <kmt_test.h>
9 
10 //#define NDEBUG
11 #include <debug.h>
12 
13 static
14 VOID
15 (FASTCALL
16 *pExEnterCriticalRegionAndAcquireFastMutexUnsafe)(
17     _Inout_ PFAST_MUTEX FastMutex
18 );
19 
20 static
21 VOID
22 (FASTCALL
23 *pExReleaseFastMutexUnsafeAndLeaveCriticalRegion)(
24     _Inout_ PFAST_MUTEX FastMutex
25 );
26 
27 static VOID    (FASTCALL *pExiAcquireFastMutex)(IN OUT PFAST_MUTEX FastMutex);
28 static VOID    (FASTCALL *pExiReleaseFastMutex)(IN OUT PFAST_MUTEX FastMutex);
29 static BOOLEAN (FASTCALL *pExiTryToAcquireFastMutex)(IN OUT PFAST_MUTEX FastMutex);
30 
/*
 * Verify every visible FAST_MUTEX field, plus the calling thread's state:
 * APCs must not be disabled (KeAreApcsDisabled() == FALSE) and the current
 * IRQL must equal ExpectedIrql.
 * Every argument is parenthesized, so expression arguments (for example a
 * computed expected OldIrql) are evaluated before the ULONG cast.
 */
#define CheckMutex(Mutex, ExpectedCount, ExpectedOwner,                 \
                   ExpectedContention, ExpectedOldIrql,                 \
                   ExpectedIrql) do                                     \
{                                                                       \
    ok_eq_long((Mutex)->Count, (ExpectedCount));                        \
    ok_eq_pointer((Mutex)->Owner, (ExpectedOwner));                     \
    ok_eq_ulong((Mutex)->Contention, (ExpectedContention));             \
    ok_eq_ulong((Mutex)->OldIrql, (ULONG)(ExpectedOldIrql));            \
    ok_bool_false(KeAreApcsDisabled(), "KeAreApcsDisabled returned");   \
    ok_irql((ExpectedIrql));                                            \
} while (0)
42 
43 static
44 VOID
45 TestFastMutex(
46     PFAST_MUTEX Mutex,
47     KIRQL OriginalIrql)
48 {
49     PKTHREAD Thread = KeGetCurrentThread();
50 
51     ok_irql(OriginalIrql);
52 
53     /* acquire/release normally */
54     ExAcquireFastMutex(Mutex);
55     CheckMutex(Mutex, 0L, Thread, 0LU, OriginalIrql, APC_LEVEL);
56     ok_bool_false(ExTryToAcquireFastMutex(Mutex), "ExTryToAcquireFastMutex returned");
57     CheckMutex(Mutex, 0L, Thread, 0LU, OriginalIrql, APC_LEVEL);
58     ExReleaseFastMutex(Mutex);
59     CheckMutex(Mutex, 1L, NULL, 0LU, OriginalIrql, OriginalIrql);
60 
61     /* ntoskrnl's fastcall version */
62     if (!skip(pExiAcquireFastMutex &&
63               pExiReleaseFastMutex &&
64               pExiTryToAcquireFastMutex, "No fastcall fast mutex functions\n"))
65     {
66         pExiAcquireFastMutex(Mutex);
67         CheckMutex(Mutex, 0L, Thread, 0LU, OriginalIrql, APC_LEVEL);
68         ok_bool_false(pExiTryToAcquireFastMutex(Mutex), "ExiTryToAcquireFastMutex returned");
69         CheckMutex(Mutex, 0L, Thread, 0LU, OriginalIrql, APC_LEVEL);
70         pExiReleaseFastMutex(Mutex);
71         CheckMutex(Mutex, 1L, NULL, 0LU, OriginalIrql, OriginalIrql);
72     }
73 
74     /* try to acquire */
75     ok_bool_true(ExTryToAcquireFastMutex(Mutex), "ExTryToAcquireFastMutex returned");
76     CheckMutex(Mutex, 0L, Thread, 0LU, OriginalIrql, APC_LEVEL);
77     ExReleaseFastMutex(Mutex);
78     CheckMutex(Mutex, 1L, NULL, 0LU, OriginalIrql, OriginalIrql);
79 
80     /* shortcut functions with critical region */
81     if (!skip(pExEnterCriticalRegionAndAcquireFastMutexUnsafe &&
82               pExReleaseFastMutexUnsafeAndLeaveCriticalRegion,
83               "Shortcut functions not available"))
84     {
85         pExEnterCriticalRegionAndAcquireFastMutexUnsafe(Mutex);
86         ok_bool_true(KeAreApcsDisabled(), "KeAreApcsDisabled returned");
87         pExReleaseFastMutexUnsafeAndLeaveCriticalRegion(Mutex);
88     }
89 
90     /* acquire/release unsafe */
91     if (!KmtIsCheckedBuild || OriginalIrql == APC_LEVEL)
92     {
93         ExAcquireFastMutexUnsafe(Mutex);
94         CheckMutex(Mutex, 0L, Thread, 0LU, OriginalIrql, OriginalIrql);
95         ExReleaseFastMutexUnsafe(Mutex);
96         CheckMutex(Mutex, 1L, NULL, 0LU, OriginalIrql, OriginalIrql);
97 
98         /* mismatched acquire/release */
99         ExAcquireFastMutex(Mutex);
100         CheckMutex(Mutex, 0L, Thread, 0LU, OriginalIrql, APC_LEVEL);
101         ExReleaseFastMutexUnsafe(Mutex);
102         CheckMutex(Mutex, 1L, NULL, 0LU, OriginalIrql, APC_LEVEL);
103         KmtSetIrql(OriginalIrql);
104         CheckMutex(Mutex, 1L, NULL, 0LU, OriginalIrql, OriginalIrql);
105 
106         Mutex->OldIrql = 0x55555555LU;
107         ExAcquireFastMutexUnsafe(Mutex);
108         CheckMutex(Mutex, 0L, Thread, 0LU, 0x55555555LU, OriginalIrql);
109         Mutex->OldIrql = PASSIVE_LEVEL;
110         ExReleaseFastMutex(Mutex);
111         CheckMutex(Mutex, 1L, NULL, 0LU, PASSIVE_LEVEL, PASSIVE_LEVEL);
112         KmtSetIrql(OriginalIrql);
113         CheckMutex(Mutex, 1L, NULL, 0LU, PASSIVE_LEVEL, OriginalIrql);
114     }
115 
116     if (!KmtIsCheckedBuild)
117     {
118         /* release without acquire */
119         ExReleaseFastMutexUnsafe(Mutex);
120         CheckMutex(Mutex, 2L, NULL, 0LU, PASSIVE_LEVEL, OriginalIrql);
121         --Mutex->Count;
122         Mutex->OldIrql = OriginalIrql;
123         ExReleaseFastMutex(Mutex);
124         CheckMutex(Mutex, 2L, NULL, 0LU, OriginalIrql, OriginalIrql);
125         ExReleaseFastMutex(Mutex);
126         CheckMutex(Mutex, 3L, NULL, 0LU, OriginalIrql, OriginalIrql);
127         Mutex->Count -= 2;
128     }
129 
130     /* make sure we survive this in case of error */
131     ok_eq_long(Mutex->Count, 1L);
132     Mutex->Count = 1;
133     ok_irql(OriginalIrql);
134     KmtSetIrql(OriginalIrql);
135 }
136 
137 typedef VOID (FASTCALL *PMUTEX_FUNCTION)(PFAST_MUTEX);
138 typedef BOOLEAN (FASTCALL *PMUTEX_TRY_FUNCTION)(PFAST_MUTEX);
139 
140 typedef struct
141 {
142     HANDLE Handle;
143     PKTHREAD Thread;
144     KIRQL Irql;
145     PFAST_MUTEX Mutex;
146     PMUTEX_FUNCTION Acquire;
147     PMUTEX_TRY_FUNCTION TryAcquire;
148     PMUTEX_FUNCTION Release;
149     BOOLEAN Try;
150     BOOLEAN RetExpected;
151     KEVENT InEvent;
152     KEVENT OutEvent;
153 } THREAD_DATA, *PTHREAD_DATA;
154 
155 static
156 VOID
157 NTAPI
158 AcquireMutexThread(
159     PVOID Parameter)
160 {
161     PTHREAD_DATA ThreadData = Parameter;
162     KIRQL Irql;
163     BOOLEAN Ret = FALSE;
164     NTSTATUS Status;
165 
166     KeRaiseIrql(ThreadData->Irql, &Irql);
167 
168     if (ThreadData->Try)
169     {
170         Ret = ThreadData->TryAcquire(ThreadData->Mutex);
171         ok_eq_bool(Ret, ThreadData->RetExpected);
172     }
173     else
174         ThreadData->Acquire(ThreadData->Mutex);
175 
176     ok_bool_false(KeSetEvent(&ThreadData->OutEvent, 0, TRUE), "KeSetEvent returned");
177     Status = KeWaitForSingleObject(&ThreadData->InEvent, Executive, KernelMode, FALSE, NULL);
178     ok_eq_hex(Status, STATUS_SUCCESS);
179 
180     if (!ThreadData->Try || Ret)
181         ThreadData->Release(ThreadData->Mutex);
182 
183     KeLowerIrql(Irql);
184 }
185 
186 static
187 VOID
188 InitThreadData(
189     PTHREAD_DATA ThreadData,
190     PFAST_MUTEX Mutex,
191     PMUTEX_FUNCTION Acquire,
192     PMUTEX_TRY_FUNCTION TryAcquire,
193     PMUTEX_FUNCTION Release)
194 {
195     ThreadData->Mutex = Mutex;
196     KeInitializeEvent(&ThreadData->InEvent, NotificationEvent, FALSE);
197     KeInitializeEvent(&ThreadData->OutEvent, NotificationEvent, FALSE);
198     ThreadData->Acquire = Acquire;
199     ThreadData->TryAcquire = TryAcquire;
200     ThreadData->Release = Release;
201 }
202 
203 static
204 NTSTATUS
205 StartThread(
206     PTHREAD_DATA ThreadData,
207     PLARGE_INTEGER Timeout,
208     KIRQL Irql,
209     BOOLEAN Try,
210     BOOLEAN RetExpected)
211 {
212     NTSTATUS Status = STATUS_SUCCESS;
213     OBJECT_ATTRIBUTES Attributes;
214 
215     ThreadData->Try = Try;
216     ThreadData->Irql = Irql;
217     ThreadData->RetExpected = RetExpected;
218     InitializeObjectAttributes(&Attributes, NULL, OBJ_KERNEL_HANDLE, NULL, NULL);
219     Status = PsCreateSystemThread(&ThreadData->Handle, GENERIC_ALL, &Attributes, NULL, NULL, AcquireMutexThread, ThreadData);
220     ok_eq_hex(Status, STATUS_SUCCESS);
221     Status = ObReferenceObjectByHandle(ThreadData->Handle, SYNCHRONIZE, *PsThreadType, KernelMode, (PVOID *)&ThreadData->Thread, NULL);
222     ok_eq_hex(Status, STATUS_SUCCESS);
223 
224     return KeWaitForSingleObject(&ThreadData->OutEvent, Executive, KernelMode, FALSE, Timeout);
225 }
226 
227 static
228 VOID
229 FinishThread(
230     PTHREAD_DATA ThreadData)
231 {
232     NTSTATUS Status = STATUS_SUCCESS;
233 
234     KeSetEvent(&ThreadData->InEvent, 0, TRUE);
235     Status = KeWaitForSingleObject(ThreadData->Thread, Executive, KernelMode, FALSE, NULL);
236     ok_eq_hex(Status, STATUS_SUCCESS);
237 
238     ObDereferenceObject(ThreadData->Thread);
239     Status = ZwClose(ThreadData->Handle);
240     ok_eq_hex(Status, STATUS_SUCCESS);
241     KeClearEvent(&ThreadData->InEvent);
242     KeClearEvent(&ThreadData->OutEvent);
243 }
244 
/*
 * Exercise Mutex from several worker threads (via StartThread and
 * FinishThread) and check the mutex fields after every hand-off.
 *
 * The calling thread never owns Mutex here, and every check expects it to
 * be at PASSIVE_LEVEL. Mutex is expected to be unowned, with
 * Contention == 0, on entry.
 *
 * Behaviour checked by the expectations below:
 *  - Count is 1 when free, 0 when owned, and one lower per blocked waiter.
 *  - Contention grows by one for each thread that had to block, and is
 *    never reset.
 *  - OldIrql matches the IRQL the owning worker ran at, and is left as-is
 *    on release.
 */
static
VOID
TestFastMutexConcurrent(
    PFAST_MUTEX Mutex)
{
    NTSTATUS Status;
    THREAD_DATA ThreadData;
    THREAD_DATA ThreadData2;
    THREAD_DATA ThreadDataUnsafe;
    THREAD_DATA ThreadDataTry;
    LARGE_INTEGER Timeout;
    /* negative = relative, in 100 ns units: 100000 * 100 ns = 10 ms */
    Timeout.QuadPart = -10 * 1000 * 10; /* 10 ms */

    InitThreadData(&ThreadData, Mutex, ExAcquireFastMutex, NULL, ExReleaseFastMutex);
    InitThreadData(&ThreadData2, Mutex, ExAcquireFastMutex, NULL, ExReleaseFastMutex);
    InitThreadData(&ThreadDataUnsafe, Mutex, ExAcquireFastMutexUnsafe, NULL, ExReleaseFastMutexUnsafe);
    InitThreadData(&ThreadDataTry, Mutex, NULL, ExTryToAcquireFastMutex, ExReleaseFastMutex);

    /* have a thread acquire the mutex */
    Status = StartThread(&ThreadData, NULL, PASSIVE_LEVEL, FALSE, FALSE);
    ok_eq_hex(Status, STATUS_SUCCESS);
    CheckMutex(Mutex, 0L, ThreadData.Thread, 0LU, PASSIVE_LEVEL, PASSIVE_LEVEL);
    /* have a second thread try to acquire it -- should fail */
    Status = StartThread(&ThreadDataTry, NULL, PASSIVE_LEVEL, TRUE, FALSE);
    ok_eq_hex(Status, STATUS_SUCCESS);
    /* a failed try-acquire must not count as contention */
    CheckMutex(Mutex, 0L, ThreadData.Thread, 0LU, PASSIVE_LEVEL, PASSIVE_LEVEL);
    FinishThread(&ThreadDataTry);

    /* have another thread acquire it -- should block */
    /* (the OutEvent wait times out because the worker is stuck in Acquire) */
    Status = StartThread(&ThreadData2, &Timeout, APC_LEVEL, FALSE, FALSE);
    ok_eq_hex(Status, STATUS_TIMEOUT);
    CheckMutex(Mutex, -1L, ThreadData.Thread, 1LU, PASSIVE_LEVEL, PASSIVE_LEVEL);

    /* finish the first thread -- now the second should become available */
    FinishThread(&ThreadData);
    Status = KeWaitForSingleObject(&ThreadData2.OutEvent, Executive, KernelMode, FALSE, NULL);
    ok_eq_hex(Status, STATUS_SUCCESS);
    CheckMutex(Mutex, 0L, ThreadData2.Thread, 1LU, APC_LEVEL, PASSIVE_LEVEL);

    /* block two more threads */
    Status = StartThread(&ThreadDataUnsafe, &Timeout, APC_LEVEL, FALSE, FALSE);
    ok_eq_hex(Status, STATUS_TIMEOUT);
    CheckMutex(Mutex, -1L, ThreadData2.Thread, 2LU, APC_LEVEL, PASSIVE_LEVEL);

    /* ThreadData is reused; FinishThread cleared its events above */
    Status = StartThread(&ThreadData, &Timeout, PASSIVE_LEVEL, FALSE, FALSE);
    ok_eq_hex(Status, STATUS_TIMEOUT);
    CheckMutex(Mutex, -2L, ThreadData2.Thread, 3LU, APC_LEVEL, PASSIVE_LEVEL);

    /* finish 1 */
    FinishThread(&ThreadData2);
    Status = KeWaitForSingleObject(&ThreadDataUnsafe.OutEvent, Executive, KernelMode, FALSE, NULL);
    ok_eq_hex(Status, STATUS_SUCCESS);
    CheckMutex(Mutex, -1L, ThreadDataUnsafe.Thread, 3LU, APC_LEVEL, PASSIVE_LEVEL);

    /* finish 2 */
    FinishThread(&ThreadDataUnsafe);
    Status = KeWaitForSingleObject(&ThreadData.OutEvent, Executive, KernelMode, FALSE, NULL);
    ok_eq_hex(Status, STATUS_SUCCESS);
    CheckMutex(Mutex, 0L, ThreadData.Thread, 3LU, PASSIVE_LEVEL, PASSIVE_LEVEL);

    /* finish 3 */
    FinishThread(&ThreadData);

    /* free again; Contention keeps the total number of blocked acquirers */
    CheckMutex(Mutex, 1L, NULL, 3LU, PASSIVE_LEVEL, PASSIVE_LEVEL);
}
310 
311 START_TEST(ExFastMutex)
312 {
313     FAST_MUTEX Mutex;
314     KIRQL Irql;
315 
316     pExEnterCriticalRegionAndAcquireFastMutexUnsafe = KmtGetSystemRoutineAddress(L"ExEnterCriticalRegionAndAcquireFastMutexUnsafe");
317     pExReleaseFastMutexUnsafeAndLeaveCriticalRegion = KmtGetSystemRoutineAddress(L"ExReleaseFastMutexUnsafeAndLeaveCriticalRegion");
318 
319     pExiAcquireFastMutex = KmtGetSystemRoutineAddress(L"ExiAcquireFastMutex");
320     pExiReleaseFastMutex = KmtGetSystemRoutineAddress(L"ExiReleaseFastMutex");
321     pExiTryToAcquireFastMutex = KmtGetSystemRoutineAddress(L"ExiTryToAcquireFastMutex");
322 
323     memset(&Mutex, 0x55, sizeof Mutex);
324     ExInitializeFastMutex(&Mutex);
325     CheckMutex(&Mutex, 1L, NULL, 0LU, 0x55555555LU, PASSIVE_LEVEL);
326 
327     TestFastMutex(&Mutex, PASSIVE_LEVEL);
328     KeRaiseIrql(APC_LEVEL, &Irql);
329     TestFastMutex(&Mutex, APC_LEVEL);
330     if (!KmtIsCheckedBuild)
331     {
332         KeRaiseIrql(DISPATCH_LEVEL, &Irql);
333         TestFastMutex(&Mutex, DISPATCH_LEVEL);
334         KeRaiseIrql(HIGH_LEVEL, &Irql);
335         TestFastMutex(&Mutex, HIGH_LEVEL);
336     }
337     KeLowerIrql(PASSIVE_LEVEL);
338 
339     TestFastMutexConcurrent(&Mutex);
340 }
341