1 /*
2  * PROJECT:         ReactOS kernel-mode tests
3  * LICENSE:         GPLv2+ - See COPYING in the top level directory
4  * PURPOSE:         Kernel-Mode Test Suite Event test
5  * PROGRAMMER:      Thomas Faber <thomas.faber@reactos.org>
6  */
7 
8 #include <kmt_test.h>
9 
/*
 * CheckEvent - assert the complete observable state of a KEVENT.
 *
 * Verifies that:
 *  - Header.Type equals ExpectedType;
 *  - Header.Hand equals sizeof *(Event) expressed in ULONGs;
 *  - the 0xFF00FF00 bits of Header.Lock read 0x55005500, i.e. those bytes
 *    still hold the 0x55 fill pattern (TestEventFunctional memsets the event
 *    to 0x55 before KeInitializeEvent; the check asserts initialization does
 *    not write these bytes);
 *  - Header.SignalState and KeReadStateEvent(Event) both equal State;
 *  - the wait list contains exactly ThreadCount entries whose owning KTHREADs
 *    (found through WaitBlock[0].WaitListEntry) equal ThreadList[0] ..
 *    ThreadList[ThreadCount - 1] in order, and each visited entry (plus the
 *    list head) satisfies Flink->Blink == entry;
 *  - Thread->WaitNext equals ExpectedWaitNext;
 *  - the current IRQL equals Irql.
 *
 * NOTE: `Thread` is NOT a macro parameter - the macro reads a variable named
 * Thread from the caller's scope (both callers initialize it with
 * KeGetCurrentThread()). Event, State, ThreadList and ThreadCount may be
 * evaluated several times, so pass side-effect-free expressions.
 * Only use block comments inside the macro body: a line comment would swallow
 * the trailing backslash continuation.
 */
#define CheckEvent(Event, ExpectedType, State, ExpectedWaitNext,                \
                            Irql, ThreadList, ThreadCount) do                   \
{                                                                               \
    INT TheIndex;                                                               \
    PLIST_ENTRY TheEntry;                                                       \
    PKTHREAD TheThread;                                                         \
    ok_eq_uint((Event)->Header.Type, ExpectedType);                             \
    ok_eq_uint((Event)->Header.Hand, sizeof *(Event) / sizeof(ULONG));          \
    ok_eq_hex((Event)->Header.Lock & 0xFF00FF00L, 0x55005500L);                 \
    ok_eq_long((Event)->Header.SignalState, State);                             \
    TheEntry = (Event)->Header.WaitListHead.Flink;                              \
    for (TheIndex = 0; TheIndex < (ThreadCount); ++TheIndex)                    \
    {                                                                           \
        TheThread = CONTAINING_RECORD(TheEntry, KTHREAD,                        \
                                      WaitBlock[0].WaitListEntry);              \
        ok_eq_pointer(TheThread, (ThreadList)[TheIndex]);                       \
        ok_eq_pointer(TheEntry->Flink->Blink, TheEntry);                        \
        TheEntry = TheEntry->Flink;                                             \
    }                                                                           \
    ok_eq_pointer(TheEntry, &(Event)->Header.WaitListHead);                     \
    ok_eq_pointer(TheEntry->Flink->Blink, TheEntry);                            \
    ok_eq_long(KeReadStateEvent(Event), State);                                 \
    ok_eq_bool(Thread->WaitNext, ExpectedWaitNext);                             \
    ok_irql(Irql);                                                              \
} while (0)
35 
/*
 * Single-threaded functional test of a KEVENT of the given Type, run at the
 * caller's current IRQL (OriginalIrql).
 *
 * Exercises KeInitializeEvent, KePulseEvent, KeClearEvent, KeSetEvent and
 * KeResetEvent, checking both the returned previous SignalState and the
 * resulting object state (via CheckEvent, always with an empty wait list).
 * The local must be named `Thread`: CheckEvent reads it implicitly.
 */
static
VOID
TestEventFunctional(
    IN PKEVENT Event,
    IN EVENT_TYPE Type,
    IN KIRQL OriginalIrql)
{
    LONG State;
    PKTHREAD Thread = KeGetCurrentThread();

    /* Pre-fill with 0x55 so CheckEvent can verify which header bytes
     * KeInitializeEvent leaves untouched */
    memset(Event, 0x55, sizeof *Event);
    KeInitializeEvent(Event, Type, FALSE);
    CheckEvent(Event, Type, 0L, FALSE, OriginalIrql, (PVOID *)NULL, 0);

    memset(Event, 0x55, sizeof *Event);
    KeInitializeEvent(Event, Type, TRUE);
    CheckEvent(Event, Type, 1L, FALSE, OriginalIrql, (PVOID *)NULL, 0);

    /* An arbitrary SignalState value is reported back unchanged */
    Event->Header.SignalState = 0x12345678L;
    CheckEvent(Event, Type, 0x12345678L, FALSE, OriginalIrql, (PVOID *)NULL, 0);

    /* KePulseEvent returns the previous state and leaves the event at 0 */
    State = KePulseEvent(Event, 0, FALSE);
    CheckEvent(Event, Type, 0L, FALSE, OriginalIrql, (PVOID *)NULL, 0);
    ok_eq_long(State, 0x12345678L);

    Event->Header.SignalState = 0x12345678L;
    KeClearEvent(Event);
    CheckEvent(Event, Type, 0L, FALSE, OriginalIrql, (PVOID *)NULL, 0);

    State = KeSetEvent(Event, 0, FALSE);
    CheckEvent(Event, Type, 1L, FALSE, OriginalIrql, (PVOID *)NULL, 0);
    ok_eq_long(State, 0L);

    State = KeResetEvent(Event);
    CheckEvent(Event, Type, 0L, FALSE, OriginalIrql, (PVOID *)NULL, 0);
    ok_eq_long(State, 1L);

    /* Set/Reset also return whatever raw value SignalState held before */
    Event->Header.SignalState = 0x23456789L;
    State = KeSetEvent(Event, 0, FALSE);
    CheckEvent(Event, Type, 1L, FALSE, OriginalIrql, (PVOID *)NULL, 0);
    ok_eq_long(State, 0x23456789L);

    Event->Header.SignalState = 0x3456789AL;
    State = KeResetEvent(Event);
    CheckEvent(Event, Type, 0L, FALSE, OriginalIrql, (PVOID *)NULL, 0);
    ok_eq_long(State, 0x3456789AL);

    /* With Wait=TRUE the call returns at DISPATCH_LEVEL with WaitNext set.
     * Doing that from above DISPATCH_LEVEL kills a checked build, and a
     * spinlock is acquired and never released here, which kills an MP build,
     * so skip those configurations. */
    if ((OriginalIrql <= DISPATCH_LEVEL || !KmtIsCheckedBuild) &&
        !KmtIsMultiProcessorBuild)
    {
        Event->Header.SignalState = 0x456789ABL;
        State = KeSetEvent(Event, 0, TRUE);
        CheckEvent(Event, Type, 1L, TRUE, DISPATCH_LEVEL, (PVOID *)NULL, 0);
        ok_eq_long(State, 0x456789ABL);
        /* the IRQL to return to is remembered in the thread */
        ok_eq_uint(Thread->WaitIrql, OriginalIrql);
        /* repair the "damage": clear WaitNext and drop back to the caller's IRQL
         * since no wait call follows */
        Thread->WaitNext = FALSE;
        KmtSetIrql(OriginalIrql);

        Event->Header.SignalState = 0x56789ABCL;
        State = KePulseEvent(Event, 0, TRUE);
        CheckEvent(Event, Type, 0L, TRUE, DISPATCH_LEVEL, (PVOID *)NULL, 0);
        ok_eq_long(State, 0x56789ABCL);
        ok_eq_uint(Thread->WaitIrql, OriginalIrql);
        /* repair the "damage" */
        Thread->WaitNext = FALSE;
        KmtSetIrql(OriginalIrql);
    }

    /* leave at the same IRQL we were entered with */
    ok_irql(OriginalIrql);
    KmtSetIrql(OriginalIrql);
}
110 
111 typedef struct
112 {
113     HANDLE Handle;
114     PKTHREAD Thread;
115     PKEVENT Event;
116     volatile BOOLEAN Signal;
117 } THREAD_DATA, *PTHREAD_DATA;
118 
119 static
120 VOID
121 NTAPI
122 WaitForEventThread(
123     IN OUT PVOID Context)
124 {
125     NTSTATUS Status;
126     PTHREAD_DATA ThreadData = Context;
127 
128     ok_irql(PASSIVE_LEVEL);
129     ThreadData->Signal = TRUE;
130     Status = KeWaitForSingleObject(ThreadData->Event, Executive, KernelMode, FALSE, NULL);
131     ok_eq_hex(Status, STATUS_SUCCESS);
132     ok_irql(PASSIVE_LEVEL);
133 }
134 
135 typedef LONG (NTAPI *PSET_EVENT_FUNCTION)(PRKEVENT, KPRIORITY, BOOLEAN);
136 
137 static
138 VOID
139 TestEventConcurrent(
140     IN PKEVENT Event,
141     IN EVENT_TYPE Type,
142     IN KIRQL OriginalIrql,
143     PSET_EVENT_FUNCTION SetEvent,
144     KPRIORITY PriorityIncrement,
145     LONG ExpectedState,
146     BOOLEAN SatisfiesAll)
147 {
148     NTSTATUS Status;
149     THREAD_DATA Threads[5];
150     const INT ThreadCount = sizeof Threads / sizeof Threads[0];
151     KPRIORITY Priority;
152     LARGE_INTEGER LongTimeout, ShortTimeout;
153     INT i;
154     KWAIT_BLOCK WaitBlock[RTL_NUMBER_OF(Threads)];
155     PVOID ThreadObjects[RTL_NUMBER_OF(Threads)];
156     LONG State;
157     PKTHREAD Thread = KeGetCurrentThread();
158     OBJECT_ATTRIBUTES ObjectAttributes;
159 
160     LongTimeout.QuadPart = -100 * MILLISECOND;
161     ShortTimeout.QuadPart = -1 * MILLISECOND;
162 
163     KeInitializeEvent(Event, Type, FALSE);
164 
165     for (i = 0; i < ThreadCount; ++i)
166     {
167         Threads[i].Event = Event;
168         Threads[i].Signal = FALSE;
169         InitializeObjectAttributes(&ObjectAttributes,
170                                    NULL,
171                                    OBJ_KERNEL_HANDLE,
172                                    NULL,
173                                    NULL);
174         Status = PsCreateSystemThread(&Threads[i].Handle, GENERIC_ALL, &ObjectAttributes, NULL, NULL, WaitForEventThread, &Threads[i]);
175         ok_eq_hex(Status, STATUS_SUCCESS);
176         Status = ObReferenceObjectByHandle(Threads[i].Handle, SYNCHRONIZE, *PsThreadType, KernelMode, (PVOID *)&Threads[i].Thread, NULL);
177         ok_eq_hex(Status, STATUS_SUCCESS);
178         ThreadObjects[i] = Threads[i].Thread;
179         Priority = KeQueryPriorityThread(Threads[i].Thread);
180         ok_eq_long(Priority, 8L);
181         while (!Threads[i].Signal)
182         {
183             Status = KeDelayExecutionThread(KernelMode, FALSE, &ShortTimeout);
184             if (Status != STATUS_SUCCESS)
185             {
186                 ok_eq_hex(Status, STATUS_SUCCESS);
187             }
188         }
189         CheckEvent(Event, Type, 0L, FALSE, OriginalIrql, ThreadObjects, i + 1);
190     }
191 
192     /* the threads shouldn't wake up on their own */
193     Status = KeDelayExecutionThread(KernelMode, FALSE, &ShortTimeout);
194     ok_eq_hex(Status, STATUS_SUCCESS);
195 
196     for (i = 0; i < ThreadCount; ++i)
197     {
198         CheckEvent(Event, Type, 0L, FALSE, OriginalIrql, ThreadObjects + i, ThreadCount - i);
199         State = SetEvent(Event, PriorityIncrement + i, FALSE);
200 
201         ok_eq_long(State, 0L);
202         CheckEvent(Event, Type, ExpectedState, FALSE, OriginalIrql, ThreadObjects + i + 1, SatisfiesAll ? 0 : ThreadCount - i - 1);
203         Status = KeWaitForMultipleObjects(ThreadCount, ThreadObjects, SatisfiesAll ? WaitAll : WaitAny, Executive, KernelMode, FALSE, &LongTimeout, WaitBlock);
204         ok_eq_hex(Status, STATUS_WAIT_0 + i);
205         if (SatisfiesAll)
206         {
207             for (; i < ThreadCount; ++i)
208             {
209                 Priority = KeQueryPriorityThread(Threads[i].Thread);
210                 ok_eq_long(Priority, max(min(8L + PriorityIncrement, 15L), 8L));
211             }
212             break;
213         }
214         Priority = KeQueryPriorityThread(Threads[i].Thread);
215         ok_eq_long(Priority, max(min(8L + PriorityIncrement + i, 15L), 8L));
216         /* replace the thread with the current thread - which will never signal */
217         if (!skip((Status & 0x3F) < ThreadCount, "Index out of bounds\n"))
218             ThreadObjects[Status & 0x3F] = Thread;
219         Status = KeWaitForMultipleObjects(ThreadCount, ThreadObjects, WaitAny, Executive, KernelMode, FALSE, &ShortTimeout, WaitBlock);
220         ok_eq_hex(Status, STATUS_TIMEOUT);
221     }
222 
223     for (i = 0; i < ThreadCount; ++i)
224     {
225         ObDereferenceObject(Threads[i].Thread);
226         Status = ZwClose(Threads[i].Handle);
227         ok_eq_hex(Status, STATUS_SUCCESS);
228     }
229 }
230 
231 #define NUM_SCHED_TESTS 1000
232 
233 typedef struct
234 {
235     KEVENT Event;
236     KEVENT WaitEvent;
237     ULONG Counter;
238     KPRIORITY PriorityIncrement;
239     ULONG CounterValues[NUM_SCHED_TESTS];
240 } COUNT_THREAD_DATA, *PCOUNT_THREAD_DATA;
241 
242 static
243 VOID
244 NTAPI
245 CountThread(
246     IN OUT PVOID Context)
247 {
248     PCOUNT_THREAD_DATA ThreadData = Context;
249     PKEVENT Event = &ThreadData->Event;
250     volatile ULONG *Counter = &ThreadData->Counter;
251     ULONG *CounterValue = ThreadData->CounterValues;
252     KPRIORITY Priority;
253 
254     Priority = KeQueryPriorityThread(KeGetCurrentThread());
255     ok_eq_long(Priority, 8L);
256 
257     while (CounterValue < &ThreadData->CounterValues[NUM_SCHED_TESTS])
258     {
259         KeSetEvent(&ThreadData->WaitEvent, IO_NO_INCREMENT, TRUE);
260         KeWaitForSingleObject(Event, Executive, KernelMode, FALSE, NULL);
261         *CounterValue++ = *Counter;
262     }
263 
264     Priority = KeQueryPriorityThread(KeGetCurrentThread());
265     ok_eq_long(Priority, 8L + min(ThreadData->PriorityIncrement, 7));
266 }
267 
268 static
269 VOID
270 NTAPI
271 TestEventScheduling(
272     _In_ PVOID Context)
273 {
274     PCOUNT_THREAD_DATA ThreadData;
275     PKTHREAD Thread;
276     NTSTATUS Status;
277     LONG PreviousState;
278     ULONG i;
279     volatile ULONG *Counter;
280     KPRIORITY PriorityIncrement;
281     KPRIORITY Priority;
282 
283     UNREFERENCED_PARAMETER(Context);
284 
285     ThreadData = ExAllocatePoolWithTag(PagedPool, sizeof(*ThreadData), 'CEmK');
286     if (skip(ThreadData != NULL, "Out of memory\n"))
287     {
288         return;
289     }
290     KeInitializeEvent(&ThreadData->Event, SynchronizationEvent, FALSE);
291     KeInitializeEvent(&ThreadData->WaitEvent, SynchronizationEvent, FALSE);
292     Counter = &ThreadData->Counter;
293 
294     for (PriorityIncrement = 0; PriorityIncrement <= 8; PriorityIncrement++)
295     {
296         ThreadData->PriorityIncrement = PriorityIncrement;
297         ThreadData->Counter = 0;
298         RtlFillMemory(ThreadData->CounterValues,
299                       sizeof(ThreadData->CounterValues),
300                       0xFE);
301         Thread = KmtStartThread(CountThread, ThreadData);
302         Priority = KeQueryPriorityThread(KeGetCurrentThread());
303         ok(Priority == 8, "[%lu] Priority = %lu\n", PriorityIncrement, Priority);
304         for (i = 1; i <= NUM_SCHED_TESTS; i++)
305         {
306             Status = KeWaitForSingleObject(&ThreadData->WaitEvent, Executive, KernelMode, FALSE, NULL);
307             ok_eq_hex(Status, STATUS_SUCCESS);
308             PreviousState = KeSetEvent(&ThreadData->Event, PriorityIncrement, FALSE);
309             *Counter = i;
310             ok_eq_long(PreviousState, 0L);
311         }
312         Priority = KeQueryPriorityThread(KeGetCurrentThread());
313         ok(Priority == 8, "[%lu] Priority = %lu\n", PriorityIncrement, Priority);
314         KmtFinishThread(Thread, NULL);
315 
316         if (PriorityIncrement == 0)
317         {
318             /* Both threads have the same priority, so either can win the race */
319             ok(ThreadData->CounterValues[0] == 0 || ThreadData->CounterValues[0] == 1,
320                "[%lu] Counter 0 = %lu\n",
321                PriorityIncrement, ThreadData->CounterValues[0]);
322         }
323         else
324         {
325             /* CountThread has the higher priority, it will always win */
326             ok(ThreadData->CounterValues[0] == 0,
327                "[%lu] Counter 0 = %lu\n",
328                PriorityIncrement, ThreadData->CounterValues[0]);
329         }
330         for (i = 1; i < NUM_SCHED_TESTS; i++)
331         {
332             if (PriorityIncrement == 0)
333             {
334                 ok(ThreadData->CounterValues[i] == i ||
335                    ThreadData->CounterValues[i] == i + 1,
336                    "[%lu] Counter %lu = %lu, expected %lu or %lu\n",
337                    PriorityIncrement, i,
338                    ThreadData->CounterValues[i], i, i + 1);
339             }
340             else
341             {
342                 ok(ThreadData->CounterValues[i] == ThreadData->CounterValues[i - 1] + 1,
343                    "[%lu] Counter %lu = %lu, expected %lu\n",
344                    PriorityIncrement, i,
345                    ThreadData->CounterValues[i], ThreadData->CounterValues[i - 1] + 1);
346             }
347         }
348     }
349 
350     ExFreePoolWithTag(ThreadData, 'CEmK');
351 }
352 
353 START_TEST(KeEvent)
354 {
355     PKTHREAD Thread;
356     KEVENT Event;
357     KIRQL Irql;
358     KIRQL Irqls[] = { PASSIVE_LEVEL, APC_LEVEL, DISPATCH_LEVEL };
359     ULONG i;
360     KPRIORITY PriorityIncrement;
361 
362     for (i = 0; i < RTL_NUMBER_OF(Irqls); ++i)
363     {
364         KeRaiseIrql(Irqls[i], &Irql);
365         TestEventFunctional(&Event, NotificationEvent, Irqls[i]);
366         TestEventFunctional(&Event, SynchronizationEvent, Irqls[i]);
367         KeLowerIrql(Irql);
368     }
369 
370     for (i = 0; i < RTL_NUMBER_OF(Irqls); ++i)
371     {
372         /* creating threads above DISPATCH_LEVEL... nope */
373         if (Irqls[i] >= DISPATCH_LEVEL)
374             continue;
375         KeRaiseIrql(Irqls[i], &Irql);
376         trace("IRQL: %u\n", Irqls[i]);
377         for (PriorityIncrement = -1; PriorityIncrement <= 8; ++PriorityIncrement)
378         {
379             if (PriorityIncrement < 0 && KmtIsCheckedBuild)
380                 continue;
381             trace("PriorityIncrement: %ld\n", PriorityIncrement);
382             trace("-> Checking KeSetEvent, NotificationEvent\n");
383             TestEventConcurrent(&Event, NotificationEvent, Irqls[i], KeSetEvent, PriorityIncrement, 1, TRUE);
384             trace("-> Checking KeSetEvent, SynchronizationEvent\n");
385             TestEventConcurrent(&Event, SynchronizationEvent, Irqls[i], KeSetEvent, PriorityIncrement, 0, FALSE);
386             trace("-> Checking KePulseEvent, NotificationEvent\n");
387             TestEventConcurrent(&Event, NotificationEvent, Irqls[i], KePulseEvent, PriorityIncrement, 0, TRUE);
388             trace("-> Checking KePulseEvent, SynchronizationEvent\n");
389             TestEventConcurrent(&Event, SynchronizationEvent, Irqls[i], KePulseEvent, PriorityIncrement, 0, FALSE);
390         }
391         KeLowerIrql(Irql);
392     }
393 
394     ok_irql(PASSIVE_LEVEL);
395     KmtSetIrql(PASSIVE_LEVEL);
396 
397     Thread = KmtStartThread(TestEventScheduling, NULL);
398     KmtFinishThread(Thread, NULL);
399 }
400