1 /*
2  * PROJECT:     ReactOS API Tests
3  * LICENSE:     LGPL-2.1+ (https://spdx.org/licenses/LGPL-2.1+)
4  * PURPOSE:     Utility function definitions for calling AFD
5  * COPYRIGHT:   Copyright 2015-2018 Thomas Faber (thomas.faber@reactos.org)
6  */
7 
8 #include "precomp.h"
9 
/* NT device name of the UDP transport. AfdCreateSocket uses it as the
 * transport for every socket type other than SOCK_STREAM. */
#define DD_UDP_DEVICE_NAME L"\\Device\\Udp"

/*
 * AFD open packet layout used by AfdCreateSocket when GetVersion() reports
 * a major version of 6 or higher. The pre-NT6 path in AfdCreateSocket
 * (AFD_CREATE_PACKET) only fills EndpointFlags, GroupID and the transport
 * name. This layout additionally carries the address family, socket type
 * and protocol.
 */
typedef struct _AFD_CREATE_PACKET_NT6 {
    DWORD                               EndpointFlags;       /* AFD_ENDPOINT_* flags */
    DWORD                               GroupID;
    DWORD                               AddressFamily;
    DWORD                               SocketType;
    DWORD                               Protocol;
    DWORD                               SizeOfTransportName; /* in bytes, excluding the terminating NUL */
    /* Variable-length, NUL-terminated transport device name. AfdCreateSocket
     * sizes the packet with FIELD_OFFSET(..., TransportName), not sizeof. */
    WCHAR                               TransportName[1];
} AFD_CREATE_PACKET_NT6, *PAFD_CREATE_PACKET_NT6;
21 
22 NTSTATUS
23 AfdCreateSocket(
24     _Out_ PHANDLE SocketHandle,
25     _In_ int AddressFamily,
26     _In_ int SocketType,
27     _In_ int Protocol)
28 {
29     NTSTATUS Status;
30     OBJECT_ATTRIBUTES ObjectAttributes;
31     IO_STATUS_BLOCK IoStatus;
32     PFILE_FULL_EA_INFORMATION EaBuffer = NULL;
33     ULONG EaLength;
34     PAFD_CREATE_PACKET AfdPacket;
35     PAFD_CREATE_PACKET_NT6 AfdPacket6;
36     ULONG SizeOfPacket;
37     ANSI_STRING EaName = RTL_CONSTANT_STRING(AfdCommand);
38     UNICODE_STRING TcpTransportName = RTL_CONSTANT_STRING(DD_TCP_DEVICE_NAME);
39     UNICODE_STRING UdpTransportName = RTL_CONSTANT_STRING(DD_UDP_DEVICE_NAME);
40     UNICODE_STRING TransportName = SocketType == SOCK_STREAM ? TcpTransportName : UdpTransportName;
41     UNICODE_STRING DeviceName = RTL_CONSTANT_STRING(L"\\Device\\Afd\\Endpoint");
42 
43     *SocketHandle = NULL;
44 
45     if (LOBYTE(LOWORD(GetVersion())) >= 6)
46     {
47         SizeOfPacket = FIELD_OFFSET(AFD_CREATE_PACKET_NT6, TransportName) + TransportName.Length + sizeof(UNICODE_NULL);
48     }
49     else
50     {
51         SizeOfPacket = FIELD_OFFSET(AFD_CREATE_PACKET, TransportName) + TransportName.Length + sizeof(UNICODE_NULL);
52     }
53     EaLength = SizeOfPacket + FIELD_OFFSET(FILE_FULL_EA_INFORMATION, EaName) + EaName.Length + sizeof(ANSI_NULL);
54 
55     /* Set up EA Buffer */
56     EaBuffer = RtlAllocateHeap(RtlGetProcessHeap(), HEAP_ZERO_MEMORY, EaLength);
57     if (!EaBuffer)
58     {
59         return STATUS_INSUFFICIENT_RESOURCES;
60     }
61 
62     EaBuffer->NextEntryOffset = 0;
63     EaBuffer->Flags = 0;
64     EaBuffer->EaNameLength = EaName.Length;
65     RtlCopyMemory(EaBuffer->EaName,
66                   EaName.Buffer,
67                   EaName.Length + sizeof(ANSI_NULL));
68     EaBuffer->EaValueLength = SizeOfPacket;
69 
70     if (LOBYTE(LOWORD(GetVersion())) >= 6)
71     {
72         AfdPacket6 = (PAFD_CREATE_PACKET_NT6)(EaBuffer->EaName + EaBuffer->EaNameLength + sizeof(ANSI_NULL));
73         AfdPacket6->GroupID = 0;
74         if (SocketType == SOCK_DGRAM)
75         {
76             AfdPacket6->EndpointFlags = AFD_ENDPOINT_CONNECTIONLESS;
77         }
78         else if (SocketType == SOCK_STREAM)
79         {
80             AfdPacket6->EndpointFlags = AFD_ENDPOINT_MESSAGE_ORIENTED;
81         }
82         AfdPacket6->AddressFamily = AddressFamily;
83         AfdPacket6->SocketType = SocketType;
84         AfdPacket6->Protocol = Protocol;
85         AfdPacket6->SizeOfTransportName = TransportName.Length;
86         RtlCopyMemory(AfdPacket6->TransportName,
87                       TransportName.Buffer,
88                       TransportName.Length + sizeof(UNICODE_NULL));
89     }
90     else
91     {
92         AfdPacket = (PAFD_CREATE_PACKET)(EaBuffer->EaName + EaBuffer->EaNameLength + sizeof(ANSI_NULL));
93         AfdPacket->GroupID = 0;
94         if (SocketType == SOCK_DGRAM)
95         {
96             AfdPacket->EndpointFlags = AFD_ENDPOINT_CONNECTIONLESS;
97         }
98         else if (SocketType == SOCK_STREAM)
99         {
100             AfdPacket->EndpointFlags = AFD_ENDPOINT_MESSAGE_ORIENTED;
101         }
102         AfdPacket->SizeOfTransportName = TransportName.Length;
103         RtlCopyMemory(AfdPacket->TransportName,
104                       TransportName.Buffer,
105                       TransportName.Length + sizeof(UNICODE_NULL));
106     }
107 
108     InitializeObjectAttributes(&ObjectAttributes,
109                                &DeviceName,
110                                OBJ_CASE_INSENSITIVE | OBJ_INHERIT,
111                                0,
112                                0);
113 
114     Status = NtCreateFile(SocketHandle,
115                           GENERIC_READ | GENERIC_WRITE | SYNCHRONIZE,
116                           &ObjectAttributes,
117                           &IoStatus,
118                           NULL,
119                           0,
120                           FILE_SHARE_READ | FILE_SHARE_WRITE,
121                           FILE_OPEN_IF,
122                           0,
123                           EaBuffer,
124                           EaLength);
125 
126     RtlFreeHeap(RtlGetProcessHeap(), 0, EaBuffer);
127 
128     return Status;
129 }
130 
131 
132 NTSTATUS
133 AfdBind(
134     _In_ HANDLE SocketHandle,
135     _In_ const struct sockaddr *Address,
136     _In_ ULONG AddressLength)
137 {
138     NTSTATUS Status;
139     IO_STATUS_BLOCK IoStatus;
140     PAFD_BIND_DATA BindInfo;
141     ULONG BindInfoLength;
142     HANDLE Event;
143 
144     Status = NtCreateEvent(&Event,
145                            EVENT_ALL_ACCESS,
146                            NULL,
147                            NotificationEvent,
148                            FALSE);
149     if (!NT_SUCCESS(Status))
150     {
151         return Status;
152     }
153 
154     BindInfoLength = FIELD_OFFSET(AFD_BIND_DATA, Address.Address[0].Address) +
155                      AddressLength - FIELD_OFFSET(struct sockaddr, sa_data);
156     BindInfo = RtlAllocateHeap(RtlGetProcessHeap(),
157                                0,
158                                BindInfoLength);
159     if (!BindInfo)
160     {
161         NtClose(Event);
162         return STATUS_INSUFFICIENT_RESOURCES;
163     }
164 
165     BindInfo->ShareType = AFD_SHARE_UNIQUE;
166     BindInfo->Address.TAAddressCount = 1;
167     BindInfo->Address.Address[0].AddressType = Address->sa_family;
168     BindInfo->Address.Address[0].AddressLength = AddressLength - FIELD_OFFSET(struct sockaddr, sa_data);
169     RtlCopyMemory(&BindInfo->Address.Address[0].Address,
170                   Address->sa_data,
171                   BindInfo->Address.Address[0].AddressLength);
172 
173     Status = NtDeviceIoControlFile(SocketHandle,
174                                    Event,
175                                    NULL,
176                                    NULL,
177                                    &IoStatus,
178                                    IOCTL_AFD_BIND,
179                                    BindInfo,
180                                    BindInfoLength,
181                                    BindInfo,
182                                    BindInfoLength);
183     if (Status == STATUS_PENDING)
184     {
185         NtWaitForSingleObject(Event, FALSE, NULL);
186         Status = IoStatus.Status;
187     }
188 
189     RtlFreeHeap(RtlGetProcessHeap(), 0, BindInfo);
190     NtClose(Event);
191 
192     return Status;
193 }
194 
195 NTSTATUS
196 AfdConnect(
197     _In_ HANDLE SocketHandle,
198     _In_ const struct sockaddr *Address,
199     _In_ ULONG AddressLength)
200 {
201     NTSTATUS Status;
202     IO_STATUS_BLOCK IoStatus;
203     PAFD_CONNECT_INFO ConnectInfo;
204     ULONG ConnectInfoLength;
205     HANDLE Event;
206 
207     Status = NtCreateEvent(&Event,
208                            EVENT_ALL_ACCESS,
209                            NULL,
210                            NotificationEvent,
211                            FALSE);
212     if (!NT_SUCCESS(Status))
213     {
214         return Status;
215     }
216 
217     ASSERT(FIELD_OFFSET(AFD_CONNECT_INFO, RemoteAddress.Address[0].Address) == 20);
218     ConnectInfoLength = FIELD_OFFSET(AFD_CONNECT_INFO, RemoteAddress.Address[0].Address) +
219                         AddressLength - FIELD_OFFSET(struct sockaddr, sa_data);
220     ConnectInfo = RtlAllocateHeap(RtlGetProcessHeap(),
221                                   0,
222                                   ConnectInfoLength);
223     if (!ConnectInfo)
224     {
225         NtClose(Event);
226         return STATUS_INSUFFICIENT_RESOURCES;
227     }
228 
229 
230     ConnectInfo->UseSAN = FALSE;
231     ConnectInfo->Root = 0;
232     ConnectInfo->Unknown = 0;
233     ConnectInfo->RemoteAddress.TAAddressCount = 1;
234     ConnectInfo->RemoteAddress.Address[0].AddressType = Address->sa_family;
235     ConnectInfo->RemoteAddress.Address[0].AddressLength = AddressLength - FIELD_OFFSET(struct sockaddr, sa_data);
236     RtlCopyMemory(&ConnectInfo->RemoteAddress.Address[0].Address,
237                   Address->sa_data,
238                   ConnectInfo->RemoteAddress.Address[0].AddressLength);
239 
240     Status = NtDeviceIoControlFile(SocketHandle,
241                                    Event,
242                                    NULL,
243                                    NULL,
244                                    &IoStatus,
245                                    IOCTL_AFD_CONNECT,
246                                    ConnectInfo,
247                                    ConnectInfoLength,
248                                    NULL,
249                                    0);
250     if (Status == STATUS_PENDING)
251     {
252         NtWaitForSingleObject(Event, FALSE, NULL);
253         Status = IoStatus.Status;
254     }
255 
256     RtlFreeHeap(RtlGetProcessHeap(), 0, ConnectInfo);
257     NtClose(Event);
258 
259     return Status;
260 }
261 
262 NTSTATUS
263 AfdSend(
264     _In_ HANDLE SocketHandle,
265     _In_ const void *Buffer,
266     _In_ ULONG BufferLength)
267 {
268     NTSTATUS Status;
269     IO_STATUS_BLOCK IoStatus;
270     AFD_SEND_INFO SendInfo;
271     HANDLE Event;
272     AFD_WSABUF AfdBuffer;
273 
274     Status = NtCreateEvent(&Event,
275                            EVENT_ALL_ACCESS,
276                            NULL,
277                            NotificationEvent,
278                            FALSE);
279     if (!NT_SUCCESS(Status))
280     {
281         return Status;
282     }
283 
284     AfdBuffer.buf = (PVOID)Buffer;
285     AfdBuffer.len = BufferLength;
286     SendInfo.BufferArray = &AfdBuffer;
287     SendInfo.BufferCount = 1;
288     SendInfo.TdiFlags = 0;
289     SendInfo.AfdFlags = 0;
290 
291     Status = NtDeviceIoControlFile(SocketHandle,
292                                    Event,
293                                    NULL,
294                                    NULL,
295                                    &IoStatus,
296                                    IOCTL_AFD_SEND,
297                                    &SendInfo,
298                                    sizeof(SendInfo),
299                                    NULL,
300                                    0);
301     if (Status == STATUS_PENDING)
302     {
303         NtWaitForSingleObject(Event, FALSE, NULL);
304         Status = IoStatus.Status;
305     }
306 
307     NtClose(Event);
308 
309     return Status;
310 }
311 
312 NTSTATUS
313 AfdSendTo(
314     _In_ HANDLE SocketHandle,
315     _In_ const void *Buffer,
316     _In_ ULONG BufferLength,
317     _In_ const struct sockaddr *Address,
318     _In_ ULONG AddressLength)
319 {
320     NTSTATUS Status;
321     IO_STATUS_BLOCK IoStatus;
322     AFD_SEND_INFO_UDP SendInfo;
323     HANDLE Event;
324     AFD_WSABUF AfdBuffer;
325     PTRANSPORT_ADDRESS TransportAddress;
326     ULONG TransportAddressLength;
327 
328     Status = NtCreateEvent(&Event,
329                            EVENT_ALL_ACCESS,
330                            NULL,
331                            NotificationEvent,
332                            FALSE);
333     if (!NT_SUCCESS(Status))
334     {
335         return Status;
336     }
337 
338     TransportAddressLength = FIELD_OFFSET(TRANSPORT_ADDRESS, Address[0].Address) +
339                              AddressLength - FIELD_OFFSET(struct sockaddr, sa_data);
340     TransportAddress = RtlAllocateHeap(RtlGetProcessHeap(),
341                                        0,
342                                        TransportAddressLength);
343     if (!TransportAddress)
344     {
345         NtClose(Event);
346         return STATUS_INSUFFICIENT_RESOURCES;
347     }
348     TransportAddress->TAAddressCount = 1;
349     TransportAddress->Address[0].AddressType = Address->sa_family;
350     TransportAddress->Address[0].AddressLength = AddressLength - FIELD_OFFSET(struct sockaddr, sa_data);
351     RtlCopyMemory(&TransportAddress->Address[0].Address,
352                   Address->sa_data,
353                   TransportAddress->Address[0].AddressLength);
354 
355     AfdBuffer.buf = (PVOID)Buffer;
356     AfdBuffer.len = BufferLength;
357     RtlZeroMemory(&SendInfo, sizeof(SendInfo));
358     SendInfo.BufferArray = &AfdBuffer;
359     SendInfo.BufferCount = 1;
360     SendInfo.AfdFlags = 0;
361     SendInfo.TdiConnection.RemoteAddress = TransportAddress;
362     SendInfo.TdiConnection.RemoteAddressLength = TransportAddressLength;
363 
364     Status = NtDeviceIoControlFile(SocketHandle,
365                                    Event,
366                                    NULL,
367                                    NULL,
368                                    &IoStatus,
369                                    IOCTL_AFD_SEND_DATAGRAM,
370                                    &SendInfo,
371                                    sizeof(SendInfo),
372                                    NULL,
373                                    0);
374     if (Status == STATUS_PENDING)
375     {
376         NtWaitForSingleObject(Event, FALSE, NULL);
377         Status = IoStatus.Status;
378     }
379 
380     RtlFreeHeap(RtlGetProcessHeap(), 0, TransportAddress);
381     NtClose(Event);
382 
383     return Status;
384 }
385