1 /* 2 * PROJECT: ReactOS API tests 3 * LICENSE: LGPLv2.1+ - See COPYING.LIB in the top level directory 4 * PURPOSE: Test for NtAcceptConnectPort 5 * PROGRAMMERS: Thomas Faber <thomas.faber@reactos.org> 6 */ 7 8 #include "precomp.h" 9 10 #include <process.h> 11 12 #define TEST_CONNECTION_INFO_SIGNATURE1 0xaabb0123 13 #define TEST_CONNECTION_INFO_SIGNATURE2 0xaabb0124 14 typedef struct _TEST_CONNECTION_INFO 15 { 16 ULONG Signature; 17 } TEST_CONNECTION_INFO, *PTEST_CONNECTION_INFO; 18 19 #define TEST_MESSAGE_MESSAGE 0x4455cdef 20 typedef struct _TEST_MESSAGE 21 { 22 PORT_MESSAGE Header; 23 ULONG Message; 24 } TEST_MESSAGE, *PTEST_MESSAGE; 25 26 static UNICODE_STRING PortName = RTL_CONSTANT_STRING(L"\\NtdllApitestNtAcceptConnectPortTestPort"); 27 static UINT ServerThreadId; 28 static UINT ClientThreadId; 29 static UCHAR Context; 30 31 UINT 32 CALLBACK 33 ServerThread( 34 _Inout_ PVOID Parameter) 35 { 36 NTSTATUS Status; 37 TEST_MESSAGE Message; 38 HANDLE PortHandle; 39 HANDLE ServerPortHandle = Parameter; 40 41 /* Listen, but refuse the connection */ 42 RtlZeroMemory(&Message, sizeof(Message)); 43 Status = NtListenPort(ServerPortHandle, 44 &Message.Header); 45 ok_hex(Status, STATUS_SUCCESS); 46 47 ok(Message.Header.u1.s1.TotalLength == RTL_SIZEOF_THROUGH_FIELD(TEST_MESSAGE, Message), 48 "TotalLength = %u, expected %lu\n", 49 Message.Header.u1.s1.TotalLength, RTL_SIZEOF_THROUGH_FIELD(TEST_MESSAGE, Message)); 50 ok(Message.Header.u1.s1.DataLength == sizeof(TEST_CONNECTION_INFO), 51 "DataLength = %u\n", Message.Header.u1.s1.DataLength); 52 ok(Message.Header.u2.s2.Type == LPC_CONNECTION_REQUEST, 53 "Type = %x\n", Message.Header.u2.s2.Type); 54 ok(Message.Header.ClientId.UniqueProcess == (HANDLE)GetCurrentProcessId(), 55 "UniqueProcess = %p, expected %lx\n", 56 Message.Header.ClientId.UniqueProcess, GetCurrentProcessId()); 57 ok(Message.Header.ClientId.UniqueThread == (HANDLE)ClientThreadId, 58 "UniqueThread = %p, expected %x\n", 59 Message.Header.ClientId.UniqueThread, ClientThreadId); 60 ok(Message.Message == TEST_CONNECTION_INFO_SIGNATURE1, "Message = %lx\n", Message.Message); 61 62 PortHandle = (PVOID)(ULONG_PTR)0x55555555; 63 Status = NtAcceptConnectPort(&PortHandle, 64 &Context, 65 &Message.Header, 66 FALSE, 67 NULL, 68 NULL); 69 ok_hex(Status, STATUS_SUCCESS); 70 ok(PortHandle == (PVOID)(ULONG_PTR)0x55555555, "PortHandle = %p\n", PortHandle); 71 72 /* Listen a second time, then accept */ 73 RtlZeroMemory(&Message, sizeof(Message)); 74 Status = NtListenPort(ServerPortHandle, 75 &Message.Header); 76 ok_hex(Status, STATUS_SUCCESS); 77 78 ok(Message.Header.u1.s1.TotalLength == RTL_SIZEOF_THROUGH_FIELD(TEST_MESSAGE, Message), 79 "TotalLength = %u, expected %lu\n", 80 Message.Header.u1.s1.TotalLength, RTL_SIZEOF_THROUGH_FIELD(TEST_MESSAGE, Message)); 81 ok(Message.Header.u1.s1.DataLength == sizeof(TEST_CONNECTION_INFO), 82 "DataLength = %u\n", Message.Header.u1.s1.DataLength); 83 ok(Message.Header.u2.s2.Type == LPC_CONNECTION_REQUEST, 84 "Type = %x\n", Message.Header.u2.s2.Type); 85 ok(Message.Header.ClientId.UniqueProcess == (HANDLE)GetCurrentProcessId(), 86 "UniqueProcess = %p, expected %lx\n", 87 Message.Header.ClientId.UniqueProcess, GetCurrentProcessId()); 88 ok(Message.Header.ClientId.UniqueThread == (HANDLE)ClientThreadId, 89 "UniqueThread = %p, expected %x\n", 90 Message.Header.ClientId.UniqueThread, ClientThreadId); 91 ok(Message.Message == TEST_CONNECTION_INFO_SIGNATURE2, "Message = %lx\n", Message.Message); 92 93 Status = NtAcceptConnectPort(&PortHandle, 94 &Context, 95 &Message.Header, 96 TRUE, 97 NULL, 98 NULL); 99 ok_hex(Status, STATUS_SUCCESS); 100 101 Status = NtCompleteConnectPort(PortHandle); 102 ok_hex(Status, STATUS_SUCCESS); 103 104 RtlZeroMemory(&Message, sizeof(Message)); 105 Status = NtReplyWaitReceivePort(PortHandle, 106 NULL, 107 NULL, 108 &Message.Header); 109 ok_hex(Status, STATUS_SUCCESS); 110 111 ok(Message.Header.u1.s1.TotalLength == sizeof(Message), 112 "TotalLength = %u, expected %Iu\n", 113 Message.Header.u1.s1.TotalLength, sizeof(Message)); 114 ok(Message.Header.u1.s1.DataLength == sizeof(Message.Message), 115 "DataLength = %u\n", Message.Header.u1.s1.DataLength); 116 ok(Message.Header.u2.s2.Type == LPC_DATAGRAM, 117 "Type = %x\n", Message.Header.u2.s2.Type); 118 ok(Message.Header.ClientId.UniqueProcess == (HANDLE)GetCurrentProcessId(), 119 "UniqueProcess = %p, expected %lx\n", 120 Message.Header.ClientId.UniqueProcess, GetCurrentProcessId()); 121 ok(Message.Header.ClientId.UniqueThread == (HANDLE)ClientThreadId, 122 "UniqueThread = %p, expected %x\n", 123 Message.Header.ClientId.UniqueThread, ClientThreadId); 124 ok(Message.Message == TEST_MESSAGE_MESSAGE, "Message = %lx\n", Message.Message); 125 126 Status = NtClose(PortHandle); 127 ok_hex(Status, STATUS_SUCCESS); 128 129 return 0; 130 } 131 132 UINT 133 CALLBACK 134 ClientThread( 135 _Inout_ PVOID Parameter) 136 { 137 NTSTATUS Status; 138 HANDLE PortHandle; 139 TEST_CONNECTION_INFO ConnectInfo; 140 ULONG ConnectInfoLength; 141 SECURITY_QUALITY_OF_SERVICE SecurityQos; 142 TEST_MESSAGE Message; 143 144 SecurityQos.Length = sizeof(SecurityQos); 145 SecurityQos.ImpersonationLevel = SecurityIdentification; 146 SecurityQos.EffectiveOnly = TRUE; 147 SecurityQos.ContextTrackingMode = SECURITY_STATIC_TRACKING; 148 149 /* Attempt to connect -- will be rejected */ 150 ConnectInfo.Signature = TEST_CONNECTION_INFO_SIGNATURE1; 151 ConnectInfoLength = sizeof(ConnectInfo); 152 PortHandle = (PVOID)(ULONG_PTR)0x55555555; 153 Status = NtConnectPort(&PortHandle, 154 &PortName, 155 &SecurityQos, 156 NULL, 157 NULL, 158 NULL, 159 &ConnectInfo, 160 &ConnectInfoLength); 161 ok_hex(Status, STATUS_PORT_CONNECTION_REFUSED); 162 ok(PortHandle == (PVOID)(ULONG_PTR)0x55555555, "PortHandle = %p\n", PortHandle); 163 164 /* Try again, this time it will be accepted */ 165 ConnectInfo.Signature = TEST_CONNECTION_INFO_SIGNATURE2; 166 ConnectInfoLength = sizeof(ConnectInfo); 167 Status = NtConnectPort(&PortHandle, 168 &PortName, 169 &SecurityQos, 170 NULL, 171 NULL, 172 NULL, 173 &ConnectInfo, 174 &ConnectInfoLength); 175 ok_hex(Status, STATUS_SUCCESS); 176 if (!NT_SUCCESS(Status)) 177 { 178 skip("Failed to connect\n"); 179 return 0; 180 } 181 182 RtlZeroMemory(&Message, sizeof(Message)); 183 Message.Header.u1.s1.TotalLength = sizeof(Message); 184 Message.Header.u1.s1.DataLength = sizeof(Message.Message); 185 Message.Message = TEST_MESSAGE_MESSAGE; 186 Status = NtRequestPort(PortHandle, 187 &Message.Header); 188 ok_hex(Status, STATUS_SUCCESS); 189 190 Status = NtClose(PortHandle); 191 ok_hex(Status, STATUS_SUCCESS); 192 193 return 0; 194 } 195 196 START_TEST(NtAcceptConnectPort) 197 { 198 NTSTATUS Status; 199 OBJECT_ATTRIBUTES ObjectAttributes; 200 HANDLE PortHandle; 201 HANDLE ThreadHandles[2]; 202 203 InitializeObjectAttributes(&ObjectAttributes, 204 &PortName, 205 OBJ_CASE_INSENSITIVE, 206 NULL, 207 NULL); 208 Status = NtCreatePort(&PortHandle, 209 &ObjectAttributes, 210 sizeof(TEST_CONNECTION_INFO), 211 sizeof(TEST_MESSAGE), 212 2 * sizeof(TEST_MESSAGE)); 213 ok_hex(Status, STATUS_SUCCESS); 214 if (!NT_SUCCESS(Status)) 215 { 216 skip("Failed to create port\n"); 217 return; 218 } 219 220 ThreadHandles[0] = (HANDLE)_beginthreadex(NULL, 221 0, 222 ServerThread, 223 PortHandle, 224 0, 225 &ServerThreadId); 226 ok(ThreadHandles[0] != NULL, "_beginthreadex failed\n"); 227 228 ThreadHandles[1] = (HANDLE)_beginthreadex(NULL, 229 0, 230 ClientThread, 231 PortHandle, 232 0, 233 &ClientThreadId); 234 ok(ThreadHandles[1] != NULL, "_beginthreadex failed\n"); 235 236 Status = NtWaitForMultipleObjects(RTL_NUMBER_OF(ThreadHandles), 237 ThreadHandles, 238 WaitAll, 239 FALSE, 240 NULL); 241 ok_hex(Status, STATUS_SUCCESS); 242 243 Status = NtClose(ThreadHandles[0]); 244 ok_hex(Status, STATUS_SUCCESS); 245 Status = NtClose(ThreadHandles[1]); 246 ok_hex(Status, STATUS_SUCCESS); 247 248 Status = NtClose(PortHandle); 249 ok_hex(Status, STATUS_SUCCESS); 250 } 251