1 /*
2  * PROJECT:         ReactOS API tests
3  * LICENSE:         LGPLv2.1+ - See COPYING.LIB in the top level directory
4  * PURPOSE:         Test for NtAcceptConnectPort
5  * PROGRAMMERS:     Thomas Faber <thomas.faber@reactos.org>
6  */
7 
8 #include "precomp.h"
9 
10 #include <process.h>
11 
/* Signatures carried in TEST_CONNECTION_INFO: the client sends SIGNATURE1
   with its first connection attempt (which the server refuses) and
   SIGNATURE2 with its second attempt (which the server accepts) */
#define TEST_CONNECTION_INFO_SIGNATURE1 0xaabb0123
#define TEST_CONNECTION_INFO_SIGNATURE2 0xaabb0124
/* Connection information passed to NtConnectPort; its size is also the
   maximum connection info length given to NtCreatePort */
typedef struct _TEST_CONNECTION_INFO
{
    ULONG Signature;
} TEST_CONNECTION_INFO, *PTEST_CONNECTION_INFO;

/* Payload of the datagram the client sends once it is connected */
#define TEST_MESSAGE_MESSAGE 0x4455cdef
/* LPC message as used by this test: a PORT_MESSAGE header followed by one
   ULONG. For connection requests received via NtListenPort, Message holds
   the client's TEST_CONNECTION_INFO signature; for the datagram it holds
   TEST_MESSAGE_MESSAGE */
typedef struct _TEST_MESSAGE
{
    PORT_MESSAGE Header;
    ULONG Message;
} TEST_MESSAGE, *PTEST_MESSAGE;

/* Name of the connection port created by the test and connected to by the client */
static UNICODE_STRING PortName = RTL_CONSTANT_STRING(L"\\NtdllApitestNtAcceptConnectPortTestPort");
/* Thread IDs reported by _beginthreadex; the server compares the
   ClientId.UniqueThread of received messages against ClientThreadId */
static UINT ServerThreadId;
static UINT ClientThreadId;
/* Only its address is used, as the port context passed to NtAcceptConnectPort */
static UCHAR Context;
30 
31 UINT
32 CALLBACK
33 ServerThread(
34     _Inout_ PVOID Parameter)
35 {
36     NTSTATUS Status;
37     TEST_MESSAGE Message;
38     HANDLE PortHandle;
39     HANDLE ServerPortHandle = Parameter;
40 
41     /* Listen, but refuse the connection */
42     RtlZeroMemory(&Message, sizeof(Message));
43     Status = NtListenPort(ServerPortHandle,
44                           &Message.Header);
45     ok_hex(Status, STATUS_SUCCESS);
46 
47     ok(Message.Header.u1.s1.TotalLength == RTL_SIZEOF_THROUGH_FIELD(TEST_MESSAGE, Message),
48        "TotalLength = %u, expected %lu\n",
49        Message.Header.u1.s1.TotalLength, RTL_SIZEOF_THROUGH_FIELD(TEST_MESSAGE, Message));
50     ok(Message.Header.u1.s1.DataLength == sizeof(TEST_CONNECTION_INFO),
51        "DataLength = %u\n", Message.Header.u1.s1.DataLength);
52     ok(Message.Header.u2.s2.Type == LPC_CONNECTION_REQUEST,
53        "Type = %x\n", Message.Header.u2.s2.Type);
54     ok(Message.Header.ClientId.UniqueProcess == (HANDLE)GetCurrentProcessId(),
55        "UniqueProcess = %p, expected %lx\n",
56        Message.Header.ClientId.UniqueProcess, GetCurrentProcessId());
57     ok(Message.Header.ClientId.UniqueThread == (HANDLE)ClientThreadId,
58        "UniqueThread = %p, expected %x\n",
59        Message.Header.ClientId.UniqueThread, ClientThreadId);
60     ok(Message.Message == TEST_CONNECTION_INFO_SIGNATURE1, "Message = %lx\n", Message.Message);
61 
62     PortHandle = (PVOID)(ULONG_PTR)0x55555555;
63     Status = NtAcceptConnectPort(&PortHandle,
64                                  &Context,
65                                  &Message.Header,
66                                  FALSE,
67                                  NULL,
68                                  NULL);
69     ok_hex(Status, STATUS_SUCCESS);
70     ok(PortHandle == (PVOID)(ULONG_PTR)0x55555555, "PortHandle = %p\n", PortHandle);
71 
72     /* Listen a second time, then accept */
73     RtlZeroMemory(&Message, sizeof(Message));
74     Status = NtListenPort(ServerPortHandle,
75                           &Message.Header);
76     ok_hex(Status, STATUS_SUCCESS);
77 
78     ok(Message.Header.u1.s1.TotalLength == RTL_SIZEOF_THROUGH_FIELD(TEST_MESSAGE, Message),
79        "TotalLength = %u, expected %lu\n",
80        Message.Header.u1.s1.TotalLength, RTL_SIZEOF_THROUGH_FIELD(TEST_MESSAGE, Message));
81     ok(Message.Header.u1.s1.DataLength == sizeof(TEST_CONNECTION_INFO),
82        "DataLength = %u\n", Message.Header.u1.s1.DataLength);
83     ok(Message.Header.u2.s2.Type == LPC_CONNECTION_REQUEST,
84        "Type = %x\n", Message.Header.u2.s2.Type);
85     ok(Message.Header.ClientId.UniqueProcess == (HANDLE)GetCurrentProcessId(),
86        "UniqueProcess = %p, expected %lx\n",
87        Message.Header.ClientId.UniqueProcess, GetCurrentProcessId());
88     ok(Message.Header.ClientId.UniqueThread == (HANDLE)ClientThreadId,
89        "UniqueThread = %p, expected %x\n",
90        Message.Header.ClientId.UniqueThread, ClientThreadId);
91     ok(Message.Message == TEST_CONNECTION_INFO_SIGNATURE2, "Message = %lx\n", Message.Message);
92 
93     Status = NtAcceptConnectPort(&PortHandle,
94                                  &Context,
95                                  &Message.Header,
96                                  TRUE,
97                                  NULL,
98                                  NULL);
99     ok_hex(Status, STATUS_SUCCESS);
100 
101     Status = NtCompleteConnectPort(PortHandle);
102     ok_hex(Status, STATUS_SUCCESS);
103 
104     RtlZeroMemory(&Message, sizeof(Message));
105     Status = NtReplyWaitReceivePort(PortHandle,
106                                     NULL,
107                                     NULL,
108                                     &Message.Header);
109     ok_hex(Status, STATUS_SUCCESS);
110 
111     ok(Message.Header.u1.s1.TotalLength == sizeof(Message),
112        "TotalLength = %u, expected %Iu\n",
113        Message.Header.u1.s1.TotalLength, sizeof(Message));
114     ok(Message.Header.u1.s1.DataLength == sizeof(Message.Message),
115        "DataLength = %u\n", Message.Header.u1.s1.DataLength);
116     ok(Message.Header.u2.s2.Type == LPC_DATAGRAM,
117        "Type = %x\n", Message.Header.u2.s2.Type);
118     ok(Message.Header.ClientId.UniqueProcess == (HANDLE)GetCurrentProcessId(),
119        "UniqueProcess = %p, expected %lx\n",
120        Message.Header.ClientId.UniqueProcess, GetCurrentProcessId());
121     ok(Message.Header.ClientId.UniqueThread == (HANDLE)ClientThreadId,
122        "UniqueThread = %p, expected %x\n",
123        Message.Header.ClientId.UniqueThread, ClientThreadId);
124     ok(Message.Message == TEST_MESSAGE_MESSAGE, "Message = %lx\n", Message.Message);
125 
126     Status = NtClose(PortHandle);
127     ok_hex(Status, STATUS_SUCCESS);
128 
129     return 0;
130 }
131 
132 UINT
133 CALLBACK
134 ClientThread(
135     _Inout_ PVOID Parameter)
136 {
137     NTSTATUS Status;
138     HANDLE PortHandle;
139     TEST_CONNECTION_INFO ConnectInfo;
140     ULONG ConnectInfoLength;
141     SECURITY_QUALITY_OF_SERVICE SecurityQos;
142     TEST_MESSAGE Message;
143 
144     SecurityQos.Length = sizeof(SecurityQos);
145     SecurityQos.ImpersonationLevel = SecurityIdentification;
146     SecurityQos.EffectiveOnly = TRUE;
147     SecurityQos.ContextTrackingMode = SECURITY_STATIC_TRACKING;
148 
149     /* Attempt to connect -- will be rejected */
150     ConnectInfo.Signature = TEST_CONNECTION_INFO_SIGNATURE1;
151     ConnectInfoLength = sizeof(ConnectInfo);
152     PortHandle = (PVOID)(ULONG_PTR)0x55555555;
153     Status = NtConnectPort(&PortHandle,
154                            &PortName,
155                            &SecurityQos,
156                            NULL,
157                            NULL,
158                            NULL,
159                            &ConnectInfo,
160                            &ConnectInfoLength);
161     ok_hex(Status, STATUS_PORT_CONNECTION_REFUSED);
162     ok(PortHandle == (PVOID)(ULONG_PTR)0x55555555, "PortHandle = %p\n", PortHandle);
163 
164     /* Try again, this time it will be accepted */
165     ConnectInfo.Signature = TEST_CONNECTION_INFO_SIGNATURE2;
166     ConnectInfoLength = sizeof(ConnectInfo);
167     Status = NtConnectPort(&PortHandle,
168                            &PortName,
169                            &SecurityQos,
170                            NULL,
171                            NULL,
172                            NULL,
173                            &ConnectInfo,
174                            &ConnectInfoLength);
175     ok_hex(Status, STATUS_SUCCESS);
176     if (!NT_SUCCESS(Status))
177     {
178         skip("Failed to connect\n");
179         return 0;
180     }
181 
182     RtlZeroMemory(&Message, sizeof(Message));
183     Message.Header.u1.s1.TotalLength = sizeof(Message);
184     Message.Header.u1.s1.DataLength = sizeof(Message.Message);
185     Message.Message = TEST_MESSAGE_MESSAGE;
186     Status = NtRequestPort(PortHandle,
187                            &Message.Header);
188     ok_hex(Status, STATUS_SUCCESS);
189 
190     Status = NtClose(PortHandle);
191     ok_hex(Status, STATUS_SUCCESS);
192 
193     return 0;
194 }
195 
196 START_TEST(NtAcceptConnectPort)
197 {
198     NTSTATUS Status;
199     OBJECT_ATTRIBUTES ObjectAttributes;
200     HANDLE PortHandle;
201     HANDLE ThreadHandles[2];
202 
203     InitializeObjectAttributes(&ObjectAttributes,
204                                &PortName,
205                                OBJ_CASE_INSENSITIVE,
206                                NULL,
207                                NULL);
208     Status = NtCreatePort(&PortHandle,
209                           &ObjectAttributes,
210                           sizeof(TEST_CONNECTION_INFO),
211                           sizeof(TEST_MESSAGE),
212                           2 * sizeof(TEST_MESSAGE));
213     ok_hex(Status, STATUS_SUCCESS);
214     if (!NT_SUCCESS(Status))
215     {
216         skip("Failed to create port\n");
217         return;
218     }
219 
220     ThreadHandles[0] = (HANDLE)_beginthreadex(NULL,
221                                               0,
222                                               ServerThread,
223                                               PortHandle,
224                                               0,
225                                               &ServerThreadId);
226     ok(ThreadHandles[0] != NULL, "_beginthreadex failed\n");
227 
228     ThreadHandles[1] = (HANDLE)_beginthreadex(NULL,
229                                               0,
230                                               ClientThread,
231                                               PortHandle,
232                                               0,
233                                               &ClientThreadId);
234     ok(ThreadHandles[1] != NULL, "_beginthreadex failed\n");
235 
236     Status = NtWaitForMultipleObjects(RTL_NUMBER_OF(ThreadHandles),
237                                       ThreadHandles,
238                                       WaitAll,
239                                       FALSE,
240                                       NULL);
241     ok_hex(Status, STATUS_SUCCESS);
242 
243     Status = NtClose(ThreadHandles[0]);
244     ok_hex(Status, STATUS_SUCCESS);
245     Status = NtClose(ThreadHandles[1]);
246     ok_hex(Status, STATUS_SUCCESS);
247 
248     Status = NtClose(PortHandle);
249     ok_hex(Status, STATUS_SUCCESS);
250 }
251