1 /* Unit test suite for Ntdll Port API functions
2  *
3  * Copyright 2006 James Hawkins
4  *
5  * This library is free software; you can redistribute it and/or
6  * modify it under the terms of the GNU Lesser General Public
7  * License as published by the Free Software Foundation; either
8  * version 2.1 of the License, or (at your option) any later version.
9  *
10  * This library is distributed in the hope that it will be useful,
11  * but WITHOUT ANY WARRANTY; without even the implied warranty of
12  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
13  * Lesser General Public License for more details.
14  *
15  * You should have received a copy of the GNU Lesser General Public
16  * License along with this library; if not, write to the Free Software
17  * Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301, USA
18  */
19 
20 #include "ntdll_test.h"
21 
22 #ifndef __WINE_WINTERNL_H
23 
24 typedef struct _CLIENT_ID
25 {
26    HANDLE UniqueProcess;
27    HANDLE UniqueThread;
28 } CLIENT_ID, *PCLIENT_ID;
29 
30 typedef struct _LPC_SECTION_WRITE
31 {
32   ULONG Length;
33   HANDLE SectionHandle;
34   ULONG SectionOffset;
35   ULONG ViewSize;
36   PVOID ViewBase;
37   PVOID TargetViewBase;
38 } LPC_SECTION_WRITE, *PLPC_SECTION_WRITE;
39 
40 typedef struct _LPC_SECTION_READ
41 {
42   ULONG Length;
43   ULONG ViewSize;
44   PVOID ViewBase;
45 } LPC_SECTION_READ, *PLPC_SECTION_READ;
46 
47 typedef struct _LPC_MESSAGE
48 {
49   USHORT DataSize;
50   USHORT MessageSize;
51   USHORT MessageType;
52   USHORT VirtualRangesOffset;
53   CLIENT_ID ClientId;
54   ULONG_PTR MessageId;
55   ULONG_PTR SectionSize;
56   UCHAR Data[ANYSIZE_ARRAY];
57 } LPC_MESSAGE, *PLPC_MESSAGE;
58 
59 #endif
60 
61 /* on Wow64 we have to use the 64-bit layout */
62 typedef struct
63 {
64   USHORT DataSize;
65   USHORT MessageSize;
66   USHORT MessageType;
67   USHORT VirtualRangesOffset;
68   ULONGLONG ClientId[2];
69   ULONGLONG MessageId;
70   ULONGLONG SectionSize;
71   UCHAR Data[ANYSIZE_ARRAY];
72 } LPC_MESSAGE64;
73 
74 union lpc_message
75 {
76     LPC_MESSAGE   msg;
77     LPC_MESSAGE64 msg64;
78 };
79 
/* Values found in the MessageType field of a received LPC message. */
#define UNUSED_MSG_TYPE                 0
#define LPC_REQUEST                     1   /* sent with NtRequestWaitReplyPort */
#define LPC_REPLY                       2   /* type of the reply the client receives */
#define LPC_DATAGRAM                    3   /* sent with NtRequestPort */
#define LPC_LOST_REPLY                  4
#define LPC_PORT_CLOSED                 5
#define LPC_CLIENT_DIED                 6   /* delivered via the thread terminate port */
#define LPC_EXCEPTION                   7
#define LPC_DEBUG_EVENT                 8
#define LPC_ERROR_EVENT                 9
#define LPC_CONNECTION_REQUEST         10   /* sent by NtConnectPort */
92 
93 static const WCHAR PORTNAME[] = {'\\','M','y','P','o','r','t',0};
94 
95 #define REQUEST1    "Request1"
96 #define REQUEST2    "Request2"
97 #define REPLY       "Reply"
98 
99 #define MAX_MESSAGE_LEN    30
100 
101 static UNICODE_STRING port;
102 
103 /* Function pointers for ntdll calls */
104 static HMODULE hntdll = 0;
105 static NTSTATUS (WINAPI *pNtCompleteConnectPort)(HANDLE);
106 static NTSTATUS (WINAPI *pNtAcceptConnectPort)(PHANDLE,ULONG,PLPC_MESSAGE,ULONG,
107                                                PLPC_SECTION_WRITE,PLPC_SECTION_READ);
108 static NTSTATUS (WINAPI *pNtReplyPort)(HANDLE,PLPC_MESSAGE);
109 static NTSTATUS (WINAPI *pNtReplyWaitReceivePort)(PHANDLE,PULONG,PLPC_MESSAGE,
110                                                   PLPC_MESSAGE);
111 static NTSTATUS (WINAPI *pNtCreatePort)(PHANDLE,POBJECT_ATTRIBUTES,ULONG,ULONG,ULONG);
112 static NTSTATUS (WINAPI *pNtRequestWaitReplyPort)(HANDLE,PLPC_MESSAGE,PLPC_MESSAGE);
113 static NTSTATUS (WINAPI *pNtRequestPort)(HANDLE,PLPC_MESSAGE);
114 static NTSTATUS (WINAPI *pNtRegisterThreadTerminatePort)(HANDLE);
115 static NTSTATUS (WINAPI *pNtConnectPort)(PHANDLE,PUNICODE_STRING,
116                                          PSECURITY_QUALITY_OF_SERVICE,
117                                          PLPC_SECTION_WRITE,PLPC_SECTION_READ,
118                                          PVOID,PVOID,PULONG);
119 static NTSTATUS (WINAPI *pRtlInitUnicodeString)(PUNICODE_STRING,LPCWSTR);
120 static BOOL     (WINAPI *pIsWow64Process)(HANDLE, PBOOL);
121 
122 static BOOL is_wow64;
123 
124 static BOOL init_function_ptrs(void)
125 {
126     hntdll = LoadLibraryA("ntdll.dll");
127 
128     if (!hntdll)
129         return FALSE;
130 
131     pNtCompleteConnectPort = (void *)GetProcAddress(hntdll, "NtCompleteConnectPort");
132     pNtAcceptConnectPort = (void *)GetProcAddress(hntdll, "NtAcceptConnectPort");
133     pNtReplyPort = (void *)GetProcAddress(hntdll, "NtReplyPort");
134     pNtReplyWaitReceivePort = (void *)GetProcAddress(hntdll, "NtReplyWaitReceivePort");
135     pNtCreatePort = (void *)GetProcAddress(hntdll, "NtCreatePort");
136     pNtRequestWaitReplyPort = (void *)GetProcAddress(hntdll, "NtRequestWaitReplyPort");
137     pNtRequestPort = (void *)GetProcAddress(hntdll, "NtRequestPort");
138     pNtRegisterThreadTerminatePort = (void *)GetProcAddress(hntdll, "NtRegisterThreadTerminatePort");
139     pNtConnectPort = (void *)GetProcAddress(hntdll, "NtConnectPort");
140     pRtlInitUnicodeString = (void *)GetProcAddress(hntdll, "RtlInitUnicodeString");
141 
142     if (!pNtCompleteConnectPort || !pNtAcceptConnectPort ||
143         !pNtReplyWaitReceivePort || !pNtCreatePort || !pNtRequestWaitReplyPort ||
144         !pNtRequestPort || !pNtRegisterThreadTerminatePort ||
145         !pNtConnectPort || !pRtlInitUnicodeString)
146     {
147         win_skip("Needed port functions are not available\n");
148         FreeLibrary(hntdll);
149         return FALSE;
150     }
151 
152     pIsWow64Process = (void *)GetProcAddress(GetModuleHandleA("kernel32.dll"), "IsWow64Process");
153     if (!pIsWow64Process || !pIsWow64Process( GetCurrentProcess(), &is_wow64 )) is_wow64 = FALSE;
154     return TRUE;
155 }
156 
157 static void ProcessConnectionRequest(union lpc_message *LpcMessage, PHANDLE pAcceptPortHandle)
158 {
159     NTSTATUS status;
160 
161     if (is_wow64)
162     {
163         ok(LpcMessage->msg64.MessageType == LPC_CONNECTION_REQUEST,
164            "Expected LPC_CONNECTION_REQUEST, got %d\n", LpcMessage->msg64.MessageType);
165         ok(!*LpcMessage->msg64.Data, "Expected empty string!\n");
166     }
167     else
168     {
169         ok(LpcMessage->msg.MessageType == LPC_CONNECTION_REQUEST,
170            "Expected LPC_CONNECTION_REQUEST, got %d\n", LpcMessage->msg.MessageType);
171         ok(!*LpcMessage->msg.Data, "Expected empty string!\n");
172     }
173 
174     status = pNtAcceptConnectPort(pAcceptPortHandle, 0, &LpcMessage->msg, 1, NULL, NULL);
175     ok(status == STATUS_SUCCESS, "Expected STATUS_SUCCESS, got %x\n", status);
176 
177     status = pNtCompleteConnectPort(*pAcceptPortHandle);
178     ok(status == STATUS_SUCCESS, "Expected STATUS_SUCCESS, got %x\n", status);
179 }
180 
181 static void ProcessLpcRequest(HANDLE PortHandle, union lpc_message *LpcMessage)
182 {
183     NTSTATUS status;
184 
185     if (is_wow64)
186     {
187         ok(LpcMessage->msg64.MessageType == LPC_REQUEST,
188            "Expected LPC_REQUEST, got %d\n", LpcMessage->msg64.MessageType);
189         ok(!strcmp((LPSTR)LpcMessage->msg64.Data, REQUEST2),
190            "Expected %s, got %s\n", REQUEST2, LpcMessage->msg64.Data);
191         strcpy((LPSTR)LpcMessage->msg64.Data, REPLY);
192 
193         status = pNtReplyPort(PortHandle, &LpcMessage->msg);
194         ok(status == STATUS_SUCCESS, "Expected STATUS_SUCCESS, got %x\n", status);
195         ok(LpcMessage->msg64.MessageType == LPC_REQUEST,
196            "Expected LPC_REQUEST, got %d\n", LpcMessage->msg64.MessageType);
197         ok(!strcmp((LPSTR)LpcMessage->msg64.Data, REPLY),
198            "Expected %s, got %s\n", REPLY, LpcMessage->msg64.Data);
199     }
200     else
201     {
202         ok(LpcMessage->msg.MessageType == LPC_REQUEST,
203            "Expected LPC_REQUEST, got %d\n", LpcMessage->msg.MessageType);
204         ok(!strcmp((LPSTR)LpcMessage->msg.Data, REQUEST2),
205            "Expected %s, got %s\n", REQUEST2, LpcMessage->msg.Data);
206         strcpy((LPSTR)LpcMessage->msg.Data, REPLY);
207 
208         status = pNtReplyPort(PortHandle, &LpcMessage->msg);
209         ok(status == STATUS_SUCCESS, "Expected STATUS_SUCCESS, got %x\n", status);
210         ok(LpcMessage->msg.MessageType == LPC_REQUEST,
211            "Expected LPC_REQUEST, got %d\n", LpcMessage->msg.MessageType);
212         ok(!strcmp((LPSTR)LpcMessage->msg.Data, REPLY),
213            "Expected %s, got %s\n", REPLY, LpcMessage->msg.Data);
214     }
215 }
216 
217 static DWORD WINAPI test_ports_client(LPVOID arg)
218 {
219     SECURITY_QUALITY_OF_SERVICE sqos;
220     union lpc_message *LpcMessage, *out;
221     HANDLE PortHandle;
222     ULONG len, size;
223     NTSTATUS status;
224 
225     sqos.Length = sizeof(SECURITY_QUALITY_OF_SERVICE);
226     sqos.ImpersonationLevel = SecurityImpersonation;
227     sqos.ContextTrackingMode = SECURITY_STATIC_TRACKING;
228     sqos.EffectiveOnly = TRUE;
229 
230     status = pNtConnectPort(&PortHandle, &port, &sqos, 0, 0, &len, NULL, NULL);
231     todo_wine ok(status == STATUS_SUCCESS, "Expected STATUS_SUCCESS, got %x\n", status);
232     if (status != STATUS_SUCCESS) return 1;
233 
234     status = pNtRegisterThreadTerminatePort(PortHandle);
235     ok(status == STATUS_SUCCESS, "Expected STATUS_SUCCESS, got %x\n", status);
236 
237     if (is_wow64)
238     {
239         size = FIELD_OFFSET(LPC_MESSAGE64, Data[MAX_MESSAGE_LEN]);
240         LpcMessage = HeapAlloc(GetProcessHeap(), HEAP_ZERO_MEMORY, size);
241         out = HeapAlloc(GetProcessHeap(), 0, size);
242 
243         LpcMessage->msg64.DataSize = strlen(REQUEST1) + 1;
244         LpcMessage->msg64.MessageSize = FIELD_OFFSET(LPC_MESSAGE64, Data[LpcMessage->msg64.DataSize]);
245         strcpy((LPSTR)LpcMessage->msg64.Data, REQUEST1);
246 
247         status = pNtRequestPort(PortHandle, &LpcMessage->msg);
248         ok(status == STATUS_SUCCESS, "Expected STATUS_SUCCESS, got %x\n", status);
249         ok(LpcMessage->msg64.MessageType == 0, "Expected 0, got %d\n", LpcMessage->msg64.MessageType);
250         ok(!strcmp((LPSTR)LpcMessage->msg64.Data, REQUEST1),
251            "Expected %s, got %s\n", REQUEST1, LpcMessage->msg64.Data);
252 
253         /* Fill in the message */
254         memset(LpcMessage, 0, size);
255         LpcMessage->msg64.DataSize = strlen(REQUEST2) + 1;
256         LpcMessage->msg64.MessageSize = FIELD_OFFSET(LPC_MESSAGE64, Data[LpcMessage->msg64.DataSize]);
257         strcpy((LPSTR)LpcMessage->msg64.Data, REQUEST2);
258 
259         /* Send the message and wait for the reply */
260         status = pNtRequestWaitReplyPort(PortHandle, &LpcMessage->msg, &out->msg);
261         ok(status == STATUS_SUCCESS, "Expected STATUS_SUCCESS, got %x\n", status);
262         ok(!strcmp((LPSTR)out->msg64.Data, REPLY), "Expected %s, got %s\n", REPLY, out->msg64.Data);
263         ok(out->msg64.MessageType == LPC_REPLY, "Expected LPC_REPLY, got %d\n", out->msg64.MessageType);
264     }
265     else
266     {
267         size = FIELD_OFFSET(LPC_MESSAGE, Data[MAX_MESSAGE_LEN]);
268         LpcMessage = HeapAlloc(GetProcessHeap(), HEAP_ZERO_MEMORY, size);
269         out = HeapAlloc(GetProcessHeap(), 0, size);
270 
271         LpcMessage->msg.DataSize = strlen(REQUEST1) + 1;
272         LpcMessage->msg.MessageSize = FIELD_OFFSET(LPC_MESSAGE, Data[LpcMessage->msg.DataSize]);
273         strcpy((LPSTR)LpcMessage->msg.Data, REQUEST1);
274 
275         status = pNtRequestPort(PortHandle, &LpcMessage->msg);
276         ok(status == STATUS_SUCCESS, "Expected STATUS_SUCCESS, got %x\n", status);
277         ok(LpcMessage->msg.MessageType == 0, "Expected 0, got %d\n", LpcMessage->msg.MessageType);
278         ok(!strcmp((LPSTR)LpcMessage->msg.Data, REQUEST1),
279            "Expected %s, got %s\n", REQUEST1, LpcMessage->msg.Data);
280 
281         /* Fill in the message */
282         memset(LpcMessage, 0, size);
283         LpcMessage->msg.DataSize = strlen(REQUEST2) + 1;
284         LpcMessage->msg.MessageSize = FIELD_OFFSET(LPC_MESSAGE, Data[LpcMessage->msg.DataSize]);
285         strcpy((LPSTR)LpcMessage->msg.Data, REQUEST2);
286 
287         /* Send the message and wait for the reply */
288         status = pNtRequestWaitReplyPort(PortHandle, &LpcMessage->msg, &out->msg);
289         ok(status == STATUS_SUCCESS, "Expected STATUS_SUCCESS, got %x\n", status);
290         ok(!strcmp((LPSTR)out->msg.Data, REPLY), "Expected %s, got %s\n", REPLY, out->msg.Data);
291         ok(out->msg.MessageType == LPC_REPLY, "Expected LPC_REPLY, got %d\n", out->msg.MessageType);
292     }
293 
294     HeapFree(GetProcessHeap(), 0, out);
295     HeapFree(GetProcessHeap(), 0, LpcMessage);
296 
297     return 0;
298 }
299 
300 static void test_ports_server( HANDLE PortHandle )
301 {
302     HANDLE AcceptPortHandle;
303     union lpc_message *LpcMessage;
304     ULONG size;
305     NTSTATUS status;
306     BOOL done = FALSE;
307 
308     size = FIELD_OFFSET(LPC_MESSAGE, Data) + MAX_MESSAGE_LEN;
309     LpcMessage = HeapAlloc(GetProcessHeap(), HEAP_ZERO_MEMORY, size);
310 
311     while (TRUE)
312     {
313         status = pNtReplyWaitReceivePort(PortHandle, NULL, NULL, &LpcMessage->msg);
314         todo_wine
315         {
316             ok(status == STATUS_SUCCESS, "Expected STATUS_SUCCESS, got %d(%x)\n", status, status);
317         }
318         /* STATUS_INVALID_HANDLE: win2k without admin rights will perform an
319          *                        endless loop here
320          */
321         if ((status == STATUS_NOT_IMPLEMENTED) ||
322             (status == STATUS_INVALID_HANDLE)) return;
323 
324         switch (is_wow64 ? LpcMessage->msg64.MessageType : LpcMessage->msg.MessageType)
325         {
326             case LPC_CONNECTION_REQUEST:
327                 ProcessConnectionRequest(LpcMessage, &AcceptPortHandle);
328                 break;
329 
330             case LPC_REQUEST:
331                 ProcessLpcRequest(PortHandle, LpcMessage);
332                 done = TRUE;
333                 break;
334 
335             case LPC_DATAGRAM:
336                 if (is_wow64)
337                     ok(!strcmp((LPSTR)LpcMessage->msg64.Data, REQUEST1),
338                        "Expected %s, got %s\n", REQUEST1, LpcMessage->msg64.Data);
339                 else
340                     ok(!strcmp((LPSTR)LpcMessage->msg.Data, REQUEST1),
341                        "Expected %s, got %s\n", REQUEST1, LpcMessage->msg.Data);
342                 break;
343 
344             case LPC_CLIENT_DIED:
345                 ok(done, "Expected LPC request to be completed!\n");
346                 HeapFree(GetProcessHeap(), 0, LpcMessage);
347                 return;
348 
349             default:
350                 ok(FALSE, "Unexpected message: %d\n",
351                    is_wow64 ? LpcMessage->msg64.MessageType : LpcMessage->msg.MessageType);
352                 break;
353         }
354     }
355 
356     HeapFree(GetProcessHeap(), 0, LpcMessage);
357 }
358 
359 START_TEST(port)
360 {
361     OBJECT_ATTRIBUTES obj;
362     HANDLE port_handle;
363     NTSTATUS status;
364 
365     if (!init_function_ptrs())
366         return;
367 
368     pRtlInitUnicodeString(&port, PORTNAME);
369 
370     memset(&obj, 0, sizeof(OBJECT_ATTRIBUTES));
371     obj.Length = sizeof(OBJECT_ATTRIBUTES);
372     obj.ObjectName = &port;
373 
374     status = pNtCreatePort(&port_handle, &obj, 100, 100, 0);
375     if (status == STATUS_ACCESS_DENIED) skip("Not enough rights\n");
376     else todo_wine ok(status == STATUS_SUCCESS, "Expected STATUS_SUCCESS, got %d\n", status);
377 
378     if (status == STATUS_SUCCESS)
379     {
380         DWORD id;
381         HANDLE thread = CreateThread(NULL, 0, test_ports_client, NULL, 0, &id);
382         ok(thread != NULL, "Expected non-NULL thread handle!\n");
383 
384         test_ports_server( port_handle );
385         ok( WaitForSingleObject( thread, 10000 ) == 0, "thread didn't exit\n" );
386         CloseHandle(thread);
387     }
388     FreeLibrary(hntdll);
389 }
390