1 /* Unit test suite for Ntdll Port API functions
2  *
3  * Copyright 2006 James Hawkins
4  *
5  * This library is free software; you can redistribute it and/or
6  * modify it under the terms of the GNU Lesser General Public
7  * License as published by the Free Software Foundation; either
8  * version 2.1 of the License, or (at your option) any later version.
9  *
10  * This library is distributed in the hope that it will be useful,
11  * but WITHOUT ANY WARRANTY; without even the implied warranty of
12  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
13  * Lesser General Public License for more details.
14  *
15  * You should have received a copy of the GNU Lesser General Public
16  * License along with this library; if not, write to the Free Software
17  * Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301, USA
18  */
19 
20 #include <stdio.h>
21 #include <stdarg.h>
22 
23 #include "ntstatus.h"
24 #define WIN32_NO_STATUS
25 #include "windef.h"
26 #include "winbase.h"
27 #include "winuser.h"
28 #include "winreg.h"
29 #include "winnls.h"
30 #include "wine/test.h"
31 #include "winternl.h"
32 
/* Fallback declarations of the LPC structures, used only when the
 * winternl.h included above is not Wine's (__WINE_WINTERNL_H undefined).
 * NOTE(review): presumably for building against other SDK headers that
 * lack these types — confirm.  The layouts must match what ntdll expects,
 * so do not reorder or resize any field. */
#ifndef __WINE_WINTERNL_H

typedef struct _CLIENT_ID
{
   HANDLE UniqueProcess;
   HANDLE UniqueThread;
} CLIENT_ID, *PCLIENT_ID;

typedef struct _LPC_SECTION_WRITE
{
  ULONG Length;
  HANDLE SectionHandle;
  ULONG SectionOffset;
  ULONG ViewSize;
  PVOID ViewBase;
  PVOID TargetViewBase;
} LPC_SECTION_WRITE, *PLPC_SECTION_WRITE;

typedef struct _LPC_SECTION_READ
{
  ULONG Length;
  ULONG ViewSize;
  PVOID ViewBase;
} LPC_SECTION_READ, *PLPC_SECTION_READ;

/* Native-layout LPC message: a fixed header followed by a variable-length
 * payload in Data. */
typedef struct _LPC_MESSAGE
{
  USHORT DataSize;            /* number of payload bytes stored in Data */
  USHORT MessageSize;         /* header size plus DataSize */
  USHORT MessageType;         /* one of the LPC_* message type values */
  USHORT VirtualRangesOffset;
  CLIENT_ID ClientId;
  ULONG_PTR MessageId;
  ULONG_PTR SectionSize;
  UCHAR Data[ANYSIZE_ARRAY];  /* payload; real length given by DataSize */
} LPC_MESSAGE, *PLPC_MESSAGE;

#endif  /* __WINE_WINTERNL_H */
71 
/* On Wow64 we have to use the 64-bit layout.
 *
 * This variant starts with the same four USHORT fields as LPC_MESSAGE, but
 * ClientId, MessageId and SectionSize are always 64 bits wide here.  As a
 * result Data starts at a larger offset than in the native 32-bit layout. */
typedef struct
{
  USHORT DataSize;
  USHORT MessageSize;
  USHORT MessageType;
  USHORT VirtualRangesOffset;
  ULONGLONG ClientId[2];
  ULONGLONG MessageId;
  ULONGLONG SectionSize;
  UCHAR Data[ANYSIZE_ARRAY];
} LPC_MESSAGE64;

/* One message buffer viewed through either layout; the test code selects
 * msg64 when is_wow64 is set and msg otherwise. */
union lpc_message
{
    LPC_MESSAGE   msg;
    LPC_MESSAGE64 msg64;
};
90 
/* Types of LPC messages (values found in the MessageType header field) */
#define UNUSED_MSG_TYPE                 0
#define LPC_REQUEST                     1
#define LPC_REPLY                       2
#define LPC_DATAGRAM                    3
#define LPC_LOST_REPLY                  4
#define LPC_PORT_CLOSED                 5
#define LPC_CLIENT_DIED                 6
#define LPC_EXCEPTION                   7
#define LPC_DEBUG_EVENT                 8
#define LPC_ERROR_EVENT                 9
#define LPC_CONNECTION_REQUEST         10

/* Name of the port created by the server (START_TEST) and connected to by
 * the client thread. */
static const WCHAR PORTNAME[] = {'\\','M','y','P','o','r','t',0};

/* Payloads exchanged by the test.  The client sends REQUEST1 with
 * NtRequestPort (arrives as LPC_DATAGRAM) and REQUEST2 with
 * NtRequestWaitReplyPort (arrives as LPC_REQUEST).  The server answers
 * REQUEST2 with REPLY. */
#define REQUEST1    "Request1"
#define REQUEST2    "Request2"
#define REPLY       "Reply"

/* Number of payload bytes reserved after the message header when
 * allocating message buffers. */
#define MAX_MESSAGE_LEN    30

/* PORTNAME as a UNICODE_STRING.  It is filled in by START_TEST and read by
 * the client thread when connecting. */
static UNICODE_STRING port;
113 
/* Function pointers for ntdll calls.  They are resolved at run time by
 * init_function_ptrs() so the test can skip instead of failing to load when
 * an export is missing. */
static HMODULE hntdll = 0;
static NTSTATUS (WINAPI *pNtCompleteConnectPort)(HANDLE);
static NTSTATUS (WINAPI *pNtAcceptConnectPort)(PHANDLE,ULONG,PLPC_MESSAGE,ULONG,
                                               PLPC_SECTION_WRITE,PLPC_SECTION_READ);
static NTSTATUS (WINAPI *pNtReplyPort)(HANDLE,PLPC_MESSAGE);
static NTSTATUS (WINAPI *pNtReplyWaitReceivePort)(PHANDLE,PULONG,PLPC_MESSAGE,
                                                  PLPC_MESSAGE);
static NTSTATUS (WINAPI *pNtCreatePort)(PHANDLE,POBJECT_ATTRIBUTES,ULONG,ULONG,ULONG);
static NTSTATUS (WINAPI *pNtRequestWaitReplyPort)(HANDLE,PLPC_MESSAGE,PLPC_MESSAGE);
static NTSTATUS (WINAPI *pNtRequestPort)(HANDLE,PLPC_MESSAGE);
static NTSTATUS (WINAPI *pNtRegisterThreadTerminatePort)(HANDLE);
static NTSTATUS (WINAPI *pNtConnectPort)(PHANDLE,PUNICODE_STRING,
                                         PSECURITY_QUALITY_OF_SERVICE,
                                         PLPC_SECTION_WRITE,PLPC_SECTION_READ,
                                         PVOID,PVOID,PULONG);
static NTSTATUS (WINAPI *pRtlInitUnicodeString)(PUNICODE_STRING,LPCWSTR);
/* Looked up in kernel32; init_function_ptrs() tolerates it being NULL. */
static BOOL     (WINAPI *pIsWow64Process)(HANDLE, PBOOL);

/* TRUE when running as a WoW64 process.  Message buffers then use the
 * LPC_MESSAGE64 layout instead of LPC_MESSAGE. */
static BOOL is_wow64;
134 
135 static BOOL init_function_ptrs(void)
136 {
137     hntdll = LoadLibraryA("ntdll.dll");
138 
139     if (!hntdll)
140         return FALSE;
141 
142     pNtCompleteConnectPort = (void *)GetProcAddress(hntdll, "NtCompleteConnectPort");
143     pNtAcceptConnectPort = (void *)GetProcAddress(hntdll, "NtAcceptConnectPort");
144     pNtReplyPort = (void *)GetProcAddress(hntdll, "NtReplyPort");
145     pNtReplyWaitReceivePort = (void *)GetProcAddress(hntdll, "NtReplyWaitReceivePort");
146     pNtCreatePort = (void *)GetProcAddress(hntdll, "NtCreatePort");
147     pNtRequestWaitReplyPort = (void *)GetProcAddress(hntdll, "NtRequestWaitReplyPort");
148     pNtRequestPort = (void *)GetProcAddress(hntdll, "NtRequestPort");
149     pNtRegisterThreadTerminatePort = (void *)GetProcAddress(hntdll, "NtRegisterThreadTerminatePort");
150     pNtConnectPort = (void *)GetProcAddress(hntdll, "NtConnectPort");
151     pRtlInitUnicodeString = (void *)GetProcAddress(hntdll, "RtlInitUnicodeString");
152 
153     if (!pNtCompleteConnectPort || !pNtAcceptConnectPort ||
154         !pNtReplyWaitReceivePort || !pNtCreatePort || !pNtRequestWaitReplyPort ||
155         !pNtRequestPort || !pNtRegisterThreadTerminatePort ||
156         !pNtConnectPort || !pRtlInitUnicodeString)
157     {
158         win_skip("Needed port functions are not available\n");
159         FreeLibrary(hntdll);
160         return FALSE;
161     }
162 
163     pIsWow64Process = (void *)GetProcAddress(GetModuleHandleA("kernel32.dll"), "IsWow64Process");
164     if (!pIsWow64Process || !pIsWow64Process( GetCurrentProcess(), &is_wow64 )) is_wow64 = FALSE;
165     return TRUE;
166 }
167 
168 static void ProcessConnectionRequest(union lpc_message *LpcMessage, PHANDLE pAcceptPortHandle)
169 {
170     NTSTATUS status;
171 
172     if (is_wow64)
173     {
174         ok(LpcMessage->msg64.MessageType == LPC_CONNECTION_REQUEST,
175            "Expected LPC_CONNECTION_REQUEST, got %d\n", LpcMessage->msg64.MessageType);
176         ok(!*LpcMessage->msg64.Data, "Expected empty string!\n");
177     }
178     else
179     {
180         ok(LpcMessage->msg.MessageType == LPC_CONNECTION_REQUEST,
181            "Expected LPC_CONNECTION_REQUEST, got %d\n", LpcMessage->msg.MessageType);
182         ok(!*LpcMessage->msg.Data, "Expected empty string!\n");
183     }
184 
185     status = pNtAcceptConnectPort(pAcceptPortHandle, 0, &LpcMessage->msg, 1, NULL, NULL);
186     ok(status == STATUS_SUCCESS, "Expected STATUS_SUCCESS, got %x\n", status);
187 
188     status = pNtCompleteConnectPort(*pAcceptPortHandle);
189     ok(status == STATUS_SUCCESS, "Expected STATUS_SUCCESS, got %x\n", status);
190 }
191 
192 static void ProcessLpcRequest(HANDLE PortHandle, union lpc_message *LpcMessage)
193 {
194     NTSTATUS status;
195 
196     if (is_wow64)
197     {
198         ok(LpcMessage->msg64.MessageType == LPC_REQUEST,
199            "Expected LPC_REQUEST, got %d\n", LpcMessage->msg64.MessageType);
200         ok(!strcmp((LPSTR)LpcMessage->msg64.Data, REQUEST2),
201            "Expected %s, got %s\n", REQUEST2, LpcMessage->msg64.Data);
202         strcpy((LPSTR)LpcMessage->msg64.Data, REPLY);
203 
204         status = pNtReplyPort(PortHandle, &LpcMessage->msg);
205         ok(status == STATUS_SUCCESS, "Expected STATUS_SUCCESS, got %x\n", status);
206         ok(LpcMessage->msg64.MessageType == LPC_REQUEST,
207            "Expected LPC_REQUEST, got %d\n", LpcMessage->msg64.MessageType);
208         ok(!strcmp((LPSTR)LpcMessage->msg64.Data, REPLY),
209            "Expected %s, got %s\n", REPLY, LpcMessage->msg64.Data);
210     }
211     else
212     {
213         ok(LpcMessage->msg.MessageType == LPC_REQUEST,
214            "Expected LPC_REQUEST, got %d\n", LpcMessage->msg.MessageType);
215         ok(!strcmp((LPSTR)LpcMessage->msg.Data, REQUEST2),
216            "Expected %s, got %s\n", REQUEST2, LpcMessage->msg.Data);
217         strcpy((LPSTR)LpcMessage->msg.Data, REPLY);
218 
219         status = pNtReplyPort(PortHandle, &LpcMessage->msg);
220         ok(status == STATUS_SUCCESS, "Expected STATUS_SUCCESS, got %x\n", status);
221         ok(LpcMessage->msg.MessageType == LPC_REQUEST,
222            "Expected LPC_REQUEST, got %d\n", LpcMessage->msg.MessageType);
223         ok(!strcmp((LPSTR)LpcMessage->msg.Data, REPLY),
224            "Expected %s, got %s\n", REPLY, LpcMessage->msg.Data);
225     }
226 }
227 
228 static DWORD WINAPI test_ports_client(LPVOID arg)
229 {
230     SECURITY_QUALITY_OF_SERVICE sqos;
231     union lpc_message *LpcMessage, *out;
232     HANDLE PortHandle;
233     ULONG len, size;
234     NTSTATUS status;
235 
236     sqos.Length = sizeof(SECURITY_QUALITY_OF_SERVICE);
237     sqos.ImpersonationLevel = SecurityImpersonation;
238     sqos.ContextTrackingMode = SECURITY_STATIC_TRACKING;
239     sqos.EffectiveOnly = TRUE;
240 
241     status = pNtConnectPort(&PortHandle, &port, &sqos, 0, 0, &len, NULL, NULL);
242     todo_wine ok(status == STATUS_SUCCESS, "Expected STATUS_SUCCESS, got %x\n", status);
243     if (status != STATUS_SUCCESS) return 1;
244 
245     status = pNtRegisterThreadTerminatePort(PortHandle);
246     ok(status == STATUS_SUCCESS, "Expected STATUS_SUCCESS, got %x\n", status);
247 
248     if (is_wow64)
249     {
250         size = FIELD_OFFSET(LPC_MESSAGE64, Data[MAX_MESSAGE_LEN]);
251         LpcMessage = HeapAlloc(GetProcessHeap(), HEAP_ZERO_MEMORY, size);
252         out = HeapAlloc(GetProcessHeap(), 0, size);
253 
254         LpcMessage->msg64.DataSize = strlen(REQUEST1) + 1;
255         LpcMessage->msg64.MessageSize = FIELD_OFFSET(LPC_MESSAGE64, Data[LpcMessage->msg64.DataSize]);
256         strcpy((LPSTR)LpcMessage->msg64.Data, REQUEST1);
257 
258         status = pNtRequestPort(PortHandle, &LpcMessage->msg);
259         ok(status == STATUS_SUCCESS, "Expected STATUS_SUCCESS, got %x\n", status);
260         ok(LpcMessage->msg64.MessageType == 0, "Expected 0, got %d\n", LpcMessage->msg64.MessageType);
261         ok(!strcmp((LPSTR)LpcMessage->msg64.Data, REQUEST1),
262            "Expected %s, got %s\n", REQUEST1, LpcMessage->msg64.Data);
263 
264         /* Fill in the message */
265         memset(LpcMessage, 0, size);
266         LpcMessage->msg64.DataSize = strlen(REQUEST2) + 1;
267         LpcMessage->msg64.MessageSize = FIELD_OFFSET(LPC_MESSAGE64, Data[LpcMessage->msg64.DataSize]);
268         strcpy((LPSTR)LpcMessage->msg64.Data, REQUEST2);
269 
270         /* Send the message and wait for the reply */
271         status = pNtRequestWaitReplyPort(PortHandle, &LpcMessage->msg, &out->msg);
272         ok(status == STATUS_SUCCESS, "Expected STATUS_SUCCESS, got %x\n", status);
273         ok(!strcmp((LPSTR)out->msg64.Data, REPLY), "Expected %s, got %s\n", REPLY, out->msg64.Data);
274         ok(out->msg64.MessageType == LPC_REPLY, "Expected LPC_REPLY, got %d\n", out->msg64.MessageType);
275     }
276     else
277     {
278         size = FIELD_OFFSET(LPC_MESSAGE, Data[MAX_MESSAGE_LEN]);
279         LpcMessage = HeapAlloc(GetProcessHeap(), HEAP_ZERO_MEMORY, size);
280         out = HeapAlloc(GetProcessHeap(), 0, size);
281 
282         LpcMessage->msg.DataSize = strlen(REQUEST1) + 1;
283         LpcMessage->msg.MessageSize = FIELD_OFFSET(LPC_MESSAGE, Data[LpcMessage->msg.DataSize]);
284         strcpy((LPSTR)LpcMessage->msg.Data, REQUEST1);
285 
286         status = pNtRequestPort(PortHandle, &LpcMessage->msg);
287         ok(status == STATUS_SUCCESS, "Expected STATUS_SUCCESS, got %x\n", status);
288         ok(LpcMessage->msg.MessageType == 0, "Expected 0, got %d\n", LpcMessage->msg.MessageType);
289         ok(!strcmp((LPSTR)LpcMessage->msg.Data, REQUEST1),
290            "Expected %s, got %s\n", REQUEST1, LpcMessage->msg.Data);
291 
292         /* Fill in the message */
293         memset(LpcMessage, 0, size);
294         LpcMessage->msg.DataSize = strlen(REQUEST2) + 1;
295         LpcMessage->msg.MessageSize = FIELD_OFFSET(LPC_MESSAGE, Data[LpcMessage->msg.DataSize]);
296         strcpy((LPSTR)LpcMessage->msg.Data, REQUEST2);
297 
298         /* Send the message and wait for the reply */
299         status = pNtRequestWaitReplyPort(PortHandle, &LpcMessage->msg, &out->msg);
300         ok(status == STATUS_SUCCESS, "Expected STATUS_SUCCESS, got %x\n", status);
301         ok(!strcmp((LPSTR)out->msg.Data, REPLY), "Expected %s, got %s\n", REPLY, out->msg.Data);
302         ok(out->msg.MessageType == LPC_REPLY, "Expected LPC_REPLY, got %d\n", out->msg.MessageType);
303     }
304 
305     HeapFree(GetProcessHeap(), 0, out);
306     HeapFree(GetProcessHeap(), 0, LpcMessage);
307 
308     return 0;
309 }
310 
311 static void test_ports_server( HANDLE PortHandle )
312 {
313     HANDLE AcceptPortHandle;
314     union lpc_message *LpcMessage;
315     ULONG size;
316     NTSTATUS status;
317     BOOL done = FALSE;
318 
319     size = FIELD_OFFSET(LPC_MESSAGE, Data) + MAX_MESSAGE_LEN;
320     LpcMessage = HeapAlloc(GetProcessHeap(), HEAP_ZERO_MEMORY, size);
321 
322     while (TRUE)
323     {
324         status = pNtReplyWaitReceivePort(PortHandle, NULL, NULL, &LpcMessage->msg);
325         todo_wine
326         {
327             ok(status == STATUS_SUCCESS, "Expected STATUS_SUCCESS, got %d(%x)\n", status, status);
328         }
329         /* STATUS_INVALID_HANDLE: win2k without admin rights will perform an
330          *                        endless loop here
331          */
332         if ((status == STATUS_NOT_IMPLEMENTED) ||
333             (status == STATUS_INVALID_HANDLE)) return;
334 
335         switch (is_wow64 ? LpcMessage->msg64.MessageType : LpcMessage->msg.MessageType)
336         {
337             case LPC_CONNECTION_REQUEST:
338                 ProcessConnectionRequest(LpcMessage, &AcceptPortHandle);
339                 break;
340 
341             case LPC_REQUEST:
342                 ProcessLpcRequest(PortHandle, LpcMessage);
343                 done = TRUE;
344                 break;
345 
346             case LPC_DATAGRAM:
347                 if (is_wow64)
348                     ok(!strcmp((LPSTR)LpcMessage->msg64.Data, REQUEST1),
349                        "Expected %s, got %s\n", REQUEST1, LpcMessage->msg64.Data);
350                 else
351                     ok(!strcmp((LPSTR)LpcMessage->msg.Data, REQUEST1),
352                        "Expected %s, got %s\n", REQUEST1, LpcMessage->msg.Data);
353                 break;
354 
355             case LPC_CLIENT_DIED:
356                 ok(done, "Expected LPC request to be completed!\n");
357                 HeapFree(GetProcessHeap(), 0, LpcMessage);
358                 return;
359 
360             default:
361                 ok(FALSE, "Unexpected message: %d\n",
362                    is_wow64 ? LpcMessage->msg64.MessageType : LpcMessage->msg.MessageType);
363                 break;
364         }
365     }
366 
367     HeapFree(GetProcessHeap(), 0, LpcMessage);
368 }
369 
370 START_TEST(port)
371 {
372     OBJECT_ATTRIBUTES obj;
373     HANDLE port_handle;
374     NTSTATUS status;
375 
376     if (!init_function_ptrs())
377         return;
378 
379     pRtlInitUnicodeString(&port, PORTNAME);
380 
381     memset(&obj, 0, sizeof(OBJECT_ATTRIBUTES));
382     obj.Length = sizeof(OBJECT_ATTRIBUTES);
383     obj.ObjectName = &port;
384 
385     status = pNtCreatePort(&port_handle, &obj, 100, 100, 0);
386     if (status == STATUS_ACCESS_DENIED) skip("Not enough rights\n");
387     else todo_wine ok(status == STATUS_SUCCESS, "Expected STATUS_SUCCESS, got %d\n", status);
388 
389     if (status == STATUS_SUCCESS)
390     {
391         DWORD id;
392         HANDLE thread = CreateThread(NULL, 0, test_ports_client, NULL, 0, &id);
393         ok(thread != NULL, "Expected non-NULL thread handle!\n");
394 
395         test_ports_server( port_handle );
396         ok( WaitForSingleObject( thread, 10000 ) == 0, "thread didn't exit\n" );
397         CloseHandle(thread);
398     }
399     FreeLibrary(hntdll);
400 }
401