1 /*
2  * PROJECT:         ReactOS kernel-mode tests
3  * LICENSE:         GPLv2+ - See COPYING in the top level directory
4  * PURPOSE:         Kernel-Mode Test Suite Executive Callback test
5  * PROGRAMMER:      Thomas Faber <thomas.faber@reactos.org>
6  */
7 
8 #include <kmt_test.h>
9 
10 static
11 PEX_CALLBACK_ROUTINE_BLOCK
12 (NTAPI
13 *ExAllocateCallBack)(
14     IN PEX_CALLBACK_FUNCTION Function,
15     IN PVOID Context
16 )
17 //= (PVOID)0x809af1f4 // 2003 sp1 x86
18 //= (PVOID)0x80a7f04a // 2003 sp1 x86 checked
19 ;
20 
21 static
22 VOID
23 (NTAPI
24 *ExFreeCallBack)(
25     IN PEX_CALLBACK_ROUTINE_BLOCK CallbackBlock
26 )
27 //= (PVOID)0x80918bb5 // 2003 sp1 x86
28 //= (PVOID)0x80a355f0 // 2003 sp1 x86 checked
29 ;
30 
31 static INT CallbackArgument1;
32 static INT CallbackArgument2;
33 
34 static
35 NTSTATUS
36 NTAPI
37 ExCallbackFunction(
38     IN PVOID CallbackContext,
39     IN PVOID Argument1 OPTIONAL,
40     IN PVOID Argument2 OPTIONAL)
41 {
42     ok(0, "Callback function unexpectedly called\n");
43     return STATUS_SUCCESS;
44 }
45 
46 static
47 VOID
48 TestPrivateFunctions(VOID)
49 {
50     UNICODE_STRING ExAllocateCallBackName = RTL_CONSTANT_STRING(L"ExAllocateCallBack");
51     UNICODE_STRING ExFreeCallBackName = RTL_CONSTANT_STRING(L"ExFreeCallBack");
52     PEX_CALLBACK_ROUTINE_BLOCK CallbackBlock;
53     INT CallbackContext;
54 
55     if (!ExAllocateCallBack)
56         ExAllocateCallBack = MmGetSystemRoutineAddress(&ExAllocateCallBackName);
57     if (!ExFreeCallBack)
58         ExFreeCallBack = MmGetSystemRoutineAddress(&ExFreeCallBackName);
59 
60     if (skip(ExAllocateCallBack && ExFreeCallBack,
61              "ExAllocateCallBack and/or ExFreeCallBack unavailable\n"))
62         return;
63 
64     CallbackBlock = ExAllocateCallBack(ExCallbackFunction, &CallbackContext);
65     ok(CallbackBlock != NULL, "CallbackBlock = NULL\n");
66 
67     if (skip(CallbackBlock != NULL, "Allocating callback failed\n"))
68         return;
69 
70     ok_eq_pointer(CallbackBlock->Function, ExCallbackFunction);
71     ok_eq_pointer(CallbackBlock->Context, &CallbackContext);
72     ok_eq_hex(KmtGetPoolTag(CallbackBlock), 'brbC');
73 
74     ExFreeCallBack(CallbackBlock);
75 }
76 
77 static
78 VOID
79 NTAPI
80 CallbackFunction(
81     IN PVOID CallbackContext,
82     IN PVOID Argument1,
83     IN PVOID Argument2)
84 {
85     INT *InvocationCount = CallbackContext;
86 
87     ok_irql(PASSIVE_LEVEL);
88 
89     (*InvocationCount)++;
90     ok_eq_pointer(Argument1, &CallbackArgument1);
91     ok_eq_pointer(Argument2, &CallbackArgument2);
92 }
93 
94 START_TEST(ExCallback)
95 {
96     NTSTATUS Status;
97     PCALLBACK_OBJECT CallbackObject;
98     OBJECT_ATTRIBUTES ObjectAttributes;
99     UNICODE_STRING CallbackName = RTL_CONSTANT_STRING(L"\\Callback\\KmtestExCallbackTestCallback");
100     PVOID CallbackRegistration;
101     INT InvocationCount = 0;
102 
103     TestPrivateFunctions();
104 
105     /* TODO: Parameter tests */
106     /* TODO: Test the three predefined callbacks */
107     /* TODO: Test opening an existing callback */
108     /* TODO: Test AllowMultipleCallbacks */
109     /* TODO: Test calling multiple callbacks */
110     /* TODO: Test registering the same function twice */
111     /* TODO: Test callback object fields */
112     /* TODO: Test callback registration fields */
113     InitializeObjectAttributes(&ObjectAttributes,
114                                &CallbackName,
115                                OBJ_CASE_INSENSITIVE,
116                                NULL,
117                                NULL);
118 
119     CallbackObject = KmtInvalidPointer;
120     Status = ExCreateCallback(&CallbackObject,
121                               &ObjectAttributes,
122                               TRUE,
123                               TRUE);
124     ok_eq_hex(Status, STATUS_SUCCESS);
125     ok(CallbackObject != NULL && CallbackObject != KmtInvalidPointer,
126         "CallbackObject = %p", CallbackObject);
127 
128     if (skip(NT_SUCCESS(Status), "Creating callback failed\n"))
129         return;
130 
131     CallbackRegistration = ExRegisterCallback(CallbackObject,
132                                               CallbackFunction,
133                                               &InvocationCount);
134     ok(CallbackRegistration != NULL, "CallbackRegistration = NULL\n");
135 
136     if (!skip(CallbackRegistration != NULL, "Registering callback failed\n"))
137     {
138         ok_eq_hex(KmtGetPoolTag(CallbackRegistration), 'eRBC');
139         ok_eq_int(InvocationCount, 0);
140         ExNotifyCallback(CallbackObject,
141                          &CallbackArgument1,
142                          &CallbackArgument2);
143         ok_eq_int(InvocationCount, 1);
144         ExNotifyCallback(CallbackObject,
145                          &CallbackArgument1,
146                          &CallbackArgument2);
147         ok_eq_int(InvocationCount, 2);
148 
149         ExUnregisterCallback(CallbackRegistration);
150     }
151 
152     ObDereferenceObject(CallbackObject);
153 }
154