1 /*
2  * PROJECT:     ReactOS API Tests
3  * LICENSE:     LGPL-2.1+ (https://spdx.org/licenses/LGPL-2.1+)
4  * PURPOSE:     Utility function definitions for calling AFD
5  * COPYRIGHT:   Copyright 2015-2018 Thomas Faber (thomas.faber@reactos.org)
6  *              Copyright 2019 Pierre Schweitzer (pierre@reactos.org)
7  */
8 
9 #include "precomp.h"
10 
/* NT device path of the UDP transport. AfdCreateSocket uses it as the
 * transport name for every socket type other than SOCK_STREAM. */
#define DD_UDP_DEVICE_NAME L"\\Device\\Udp"

/*
 * AFD create packet layout used by AfdCreateSocket when GetVersion() reports
 * a major version of 6 or higher. It is passed as the value of the AfdCommand
 * extended attribute when opening \Device\Afd\Endpoint. Unlike the pre-NT6
 * AFD_CREATE_PACKET as filled in by AfdCreateSocket, it also carries the
 * address family, socket type and protocol of the new socket.
 */
typedef struct _AFD_CREATE_PACKET_NT6 {
    DWORD                               EndpointFlags;       /* AFD_ENDPOINT_* flags */
    DWORD                               GroupID;             /* AfdCreateSocket always passes 0 */
    DWORD                               AddressFamily;
    DWORD                               SocketType;
    DWORD                               Protocol;
    DWORD                               SizeOfTransportName; /* in bytes, without the terminating NUL */
    WCHAR                               TransportName[1];    /* variable-length, NUL-terminated; size via FIELD_OFFSET */
} AFD_CREATE_PACKET_NT6, *PAFD_CREATE_PACKET_NT6;
22 
23 NTSTATUS
24 AfdCreateSocket(
25     _Out_ PHANDLE SocketHandle,
26     _In_ int AddressFamily,
27     _In_ int SocketType,
28     _In_ int Protocol)
29 {
30     NTSTATUS Status;
31     OBJECT_ATTRIBUTES ObjectAttributes;
32     IO_STATUS_BLOCK IoStatus;
33     PFILE_FULL_EA_INFORMATION EaBuffer = NULL;
34     ULONG EaLength;
35     PAFD_CREATE_PACKET AfdPacket;
36     PAFD_CREATE_PACKET_NT6 AfdPacket6;
37     ULONG SizeOfPacket;
38     ANSI_STRING EaName = RTL_CONSTANT_STRING(AfdCommand);
39     UNICODE_STRING TcpTransportName = RTL_CONSTANT_STRING(DD_TCP_DEVICE_NAME);
40     UNICODE_STRING UdpTransportName = RTL_CONSTANT_STRING(DD_UDP_DEVICE_NAME);
41     UNICODE_STRING TransportName = SocketType == SOCK_STREAM ? TcpTransportName : UdpTransportName;
42     UNICODE_STRING DeviceName = RTL_CONSTANT_STRING(L"\\Device\\Afd\\Endpoint");
43 
44     *SocketHandle = NULL;
45 
46     if (LOBYTE(LOWORD(GetVersion())) >= 6)
47     {
48         SizeOfPacket = FIELD_OFFSET(AFD_CREATE_PACKET_NT6, TransportName) + TransportName.Length + sizeof(UNICODE_NULL);
49     }
50     else
51     {
52         SizeOfPacket = FIELD_OFFSET(AFD_CREATE_PACKET, TransportName) + TransportName.Length + sizeof(UNICODE_NULL);
53     }
54     EaLength = SizeOfPacket + FIELD_OFFSET(FILE_FULL_EA_INFORMATION, EaName) + EaName.Length + sizeof(ANSI_NULL);
55 
56     /* Set up EA Buffer */
57     EaBuffer = RtlAllocateHeap(RtlGetProcessHeap(), HEAP_ZERO_MEMORY, EaLength);
58     if (!EaBuffer)
59     {
60         return STATUS_INSUFFICIENT_RESOURCES;
61     }
62 
63     EaBuffer->NextEntryOffset = 0;
64     EaBuffer->Flags = 0;
65     EaBuffer->EaNameLength = EaName.Length;
66     RtlCopyMemory(EaBuffer->EaName,
67                   EaName.Buffer,
68                   EaName.Length + sizeof(ANSI_NULL));
69     EaBuffer->EaValueLength = SizeOfPacket;
70 
71     if (LOBYTE(LOWORD(GetVersion())) >= 6)
72     {
73         AfdPacket6 = (PAFD_CREATE_PACKET_NT6)(EaBuffer->EaName + EaBuffer->EaNameLength + sizeof(ANSI_NULL));
74         AfdPacket6->GroupID = 0;
75         if (SocketType == SOCK_DGRAM)
76         {
77             AfdPacket6->EndpointFlags = AFD_ENDPOINT_CONNECTIONLESS;
78         }
79         else if (SocketType == SOCK_STREAM)
80         {
81             AfdPacket6->EndpointFlags = AFD_ENDPOINT_MESSAGE_ORIENTED;
82         }
83         AfdPacket6->AddressFamily = AddressFamily;
84         AfdPacket6->SocketType = SocketType;
85         AfdPacket6->Protocol = Protocol;
86         AfdPacket6->SizeOfTransportName = TransportName.Length;
87         RtlCopyMemory(AfdPacket6->TransportName,
88                       TransportName.Buffer,
89                       TransportName.Length + sizeof(UNICODE_NULL));
90     }
91     else
92     {
93         AfdPacket = (PAFD_CREATE_PACKET)(EaBuffer->EaName + EaBuffer->EaNameLength + sizeof(ANSI_NULL));
94         AfdPacket->GroupID = 0;
95         if (SocketType == SOCK_DGRAM)
96         {
97             AfdPacket->EndpointFlags = AFD_ENDPOINT_CONNECTIONLESS;
98         }
99         else if (SocketType == SOCK_STREAM)
100         {
101             AfdPacket->EndpointFlags = AFD_ENDPOINT_MESSAGE_ORIENTED;
102         }
103         AfdPacket->SizeOfTransportName = TransportName.Length;
104         RtlCopyMemory(AfdPacket->TransportName,
105                       TransportName.Buffer,
106                       TransportName.Length + sizeof(UNICODE_NULL));
107     }
108 
109     InitializeObjectAttributes(&ObjectAttributes,
110                                &DeviceName,
111                                OBJ_CASE_INSENSITIVE | OBJ_INHERIT,
112                                0,
113                                0);
114 
115     Status = NtCreateFile(SocketHandle,
116                           GENERIC_READ | GENERIC_WRITE | SYNCHRONIZE,
117                           &ObjectAttributes,
118                           &IoStatus,
119                           NULL,
120                           0,
121                           FILE_SHARE_READ | FILE_SHARE_WRITE,
122                           FILE_OPEN_IF,
123                           0,
124                           EaBuffer,
125                           EaLength);
126 
127     RtlFreeHeap(RtlGetProcessHeap(), 0, EaBuffer);
128 
129     return Status;
130 }
131 
132 
133 NTSTATUS
134 AfdBind(
135     _In_ HANDLE SocketHandle,
136     _In_ const struct sockaddr *Address,
137     _In_ ULONG AddressLength)
138 {
139     NTSTATUS Status;
140     IO_STATUS_BLOCK IoStatus;
141     PAFD_BIND_DATA BindInfo;
142     ULONG BindInfoLength;
143     HANDLE Event;
144 
145     Status = NtCreateEvent(&Event,
146                            EVENT_ALL_ACCESS,
147                            NULL,
148                            NotificationEvent,
149                            FALSE);
150     if (!NT_SUCCESS(Status))
151     {
152         return Status;
153     }
154 
155     BindInfoLength = FIELD_OFFSET(AFD_BIND_DATA, Address.Address[0].Address) +
156                      AddressLength - FIELD_OFFSET(struct sockaddr, sa_data);
157     BindInfo = RtlAllocateHeap(RtlGetProcessHeap(),
158                                0,
159                                BindInfoLength);
160     if (!BindInfo)
161     {
162         NtClose(Event);
163         return STATUS_INSUFFICIENT_RESOURCES;
164     }
165 
166     BindInfo->ShareType = AFD_SHARE_UNIQUE;
167     BindInfo->Address.TAAddressCount = 1;
168     BindInfo->Address.Address[0].AddressType = Address->sa_family;
169     BindInfo->Address.Address[0].AddressLength = AddressLength - FIELD_OFFSET(struct sockaddr, sa_data);
170     RtlCopyMemory(&BindInfo->Address.Address[0].Address,
171                   Address->sa_data,
172                   BindInfo->Address.Address[0].AddressLength);
173 
174     Status = NtDeviceIoControlFile(SocketHandle,
175                                    Event,
176                                    NULL,
177                                    NULL,
178                                    &IoStatus,
179                                    IOCTL_AFD_BIND,
180                                    BindInfo,
181                                    BindInfoLength,
182                                    BindInfo,
183                                    BindInfoLength);
184     if (Status == STATUS_PENDING)
185     {
186         NtWaitForSingleObject(Event, FALSE, NULL);
187         Status = IoStatus.Status;
188     }
189 
190     RtlFreeHeap(RtlGetProcessHeap(), 0, BindInfo);
191     NtClose(Event);
192 
193     return Status;
194 }
195 
196 NTSTATUS
197 AfdConnect(
198     _In_ HANDLE SocketHandle,
199     _In_ const struct sockaddr *Address,
200     _In_ ULONG AddressLength)
201 {
202     NTSTATUS Status;
203     IO_STATUS_BLOCK IoStatus;
204     PAFD_CONNECT_INFO ConnectInfo;
205     ULONG ConnectInfoLength;
206     HANDLE Event;
207 
208     Status = NtCreateEvent(&Event,
209                            EVENT_ALL_ACCESS,
210                            NULL,
211                            NotificationEvent,
212                            FALSE);
213     if (!NT_SUCCESS(Status))
214     {
215         return Status;
216     }
217 
218     ASSERT(FIELD_OFFSET(AFD_CONNECT_INFO, RemoteAddress.Address[0].Address) == 20);
219     ConnectInfoLength = FIELD_OFFSET(AFD_CONNECT_INFO, RemoteAddress.Address[0].Address) +
220                         AddressLength - FIELD_OFFSET(struct sockaddr, sa_data);
221     ConnectInfo = RtlAllocateHeap(RtlGetProcessHeap(),
222                                   0,
223                                   ConnectInfoLength);
224     if (!ConnectInfo)
225     {
226         NtClose(Event);
227         return STATUS_INSUFFICIENT_RESOURCES;
228     }
229 
230 
231     ConnectInfo->UseSAN = FALSE;
232     ConnectInfo->Root = 0;
233     ConnectInfo->Unknown = 0;
234     ConnectInfo->RemoteAddress.TAAddressCount = 1;
235     ConnectInfo->RemoteAddress.Address[0].AddressType = Address->sa_family;
236     ConnectInfo->RemoteAddress.Address[0].AddressLength = AddressLength - FIELD_OFFSET(struct sockaddr, sa_data);
237     RtlCopyMemory(&ConnectInfo->RemoteAddress.Address[0].Address,
238                   Address->sa_data,
239                   ConnectInfo->RemoteAddress.Address[0].AddressLength);
240 
241     Status = NtDeviceIoControlFile(SocketHandle,
242                                    Event,
243                                    NULL,
244                                    NULL,
245                                    &IoStatus,
246                                    IOCTL_AFD_CONNECT,
247                                    ConnectInfo,
248                                    ConnectInfoLength,
249                                    NULL,
250                                    0);
251     if (Status == STATUS_PENDING)
252     {
253         NtWaitForSingleObject(Event, FALSE, NULL);
254         Status = IoStatus.Status;
255     }
256 
257     RtlFreeHeap(RtlGetProcessHeap(), 0, ConnectInfo);
258     NtClose(Event);
259 
260     return Status;
261 }
262 
263 NTSTATUS
264 AfdSend(
265     _In_ HANDLE SocketHandle,
266     _In_ const void *Buffer,
267     _In_ ULONG BufferLength)
268 {
269     NTSTATUS Status;
270     IO_STATUS_BLOCK IoStatus;
271     AFD_SEND_INFO SendInfo;
272     HANDLE Event;
273     AFD_WSABUF AfdBuffer;
274 
275     Status = NtCreateEvent(&Event,
276                            EVENT_ALL_ACCESS,
277                            NULL,
278                            NotificationEvent,
279                            FALSE);
280     if (!NT_SUCCESS(Status))
281     {
282         return Status;
283     }
284 
285     AfdBuffer.buf = (PVOID)Buffer;
286     AfdBuffer.len = BufferLength;
287     SendInfo.BufferArray = &AfdBuffer;
288     SendInfo.BufferCount = 1;
289     SendInfo.TdiFlags = 0;
290     SendInfo.AfdFlags = 0;
291 
292     Status = NtDeviceIoControlFile(SocketHandle,
293                                    Event,
294                                    NULL,
295                                    NULL,
296                                    &IoStatus,
297                                    IOCTL_AFD_SEND,
298                                    &SendInfo,
299                                    sizeof(SendInfo),
300                                    NULL,
301                                    0);
302     if (Status == STATUS_PENDING)
303     {
304         NtWaitForSingleObject(Event, FALSE, NULL);
305         Status = IoStatus.Status;
306     }
307 
308     NtClose(Event);
309 
310     return Status;
311 }
312 
313 NTSTATUS
314 AfdSendTo(
315     _In_ HANDLE SocketHandle,
316     _In_ const void *Buffer,
317     _In_ ULONG BufferLength,
318     _In_ const struct sockaddr *Address,
319     _In_ ULONG AddressLength)
320 {
321     NTSTATUS Status;
322     IO_STATUS_BLOCK IoStatus;
323     AFD_SEND_INFO_UDP SendInfo;
324     HANDLE Event;
325     AFD_WSABUF AfdBuffer;
326     PTRANSPORT_ADDRESS TransportAddress;
327     ULONG TransportAddressLength;
328 
329     Status = NtCreateEvent(&Event,
330                            EVENT_ALL_ACCESS,
331                            NULL,
332                            NotificationEvent,
333                            FALSE);
334     if (!NT_SUCCESS(Status))
335     {
336         return Status;
337     }
338 
339     TransportAddressLength = FIELD_OFFSET(TRANSPORT_ADDRESS, Address[0].Address) +
340                              AddressLength - FIELD_OFFSET(struct sockaddr, sa_data);
341     TransportAddress = RtlAllocateHeap(RtlGetProcessHeap(),
342                                        0,
343                                        TransportAddressLength);
344     if (!TransportAddress)
345     {
346         NtClose(Event);
347         return STATUS_INSUFFICIENT_RESOURCES;
348     }
349     TransportAddress->TAAddressCount = 1;
350     TransportAddress->Address[0].AddressType = Address->sa_family;
351     TransportAddress->Address[0].AddressLength = AddressLength - FIELD_OFFSET(struct sockaddr, sa_data);
352     RtlCopyMemory(&TransportAddress->Address[0].Address,
353                   Address->sa_data,
354                   TransportAddress->Address[0].AddressLength);
355 
356     AfdBuffer.buf = (PVOID)Buffer;
357     AfdBuffer.len = BufferLength;
358     RtlZeroMemory(&SendInfo, sizeof(SendInfo));
359     SendInfo.BufferArray = &AfdBuffer;
360     SendInfo.BufferCount = 1;
361     SendInfo.AfdFlags = 0;
362     SendInfo.TdiConnection.RemoteAddress = TransportAddress;
363     SendInfo.TdiConnection.RemoteAddressLength = TransportAddressLength;
364 
365     Status = NtDeviceIoControlFile(SocketHandle,
366                                    Event,
367                                    NULL,
368                                    NULL,
369                                    &IoStatus,
370                                    IOCTL_AFD_SEND_DATAGRAM,
371                                    &SendInfo,
372                                    sizeof(SendInfo),
373                                    NULL,
374                                    0);
375     if (Status == STATUS_PENDING)
376     {
377         NtWaitForSingleObject(Event, FALSE, NULL);
378         Status = IoStatus.Status;
379     }
380 
381     RtlFreeHeap(RtlGetProcessHeap(), 0, TransportAddress);
382     NtClose(Event);
383 
384     return Status;
385 }
386 
387 NTSTATUS
388 AfdSetInformation(
389     _In_ HANDLE SocketHandle,
390     _In_ ULONG InformationClass,
391     _In_opt_ PBOOLEAN Boolean,
392     _In_opt_ PULONG Ulong,
393     _In_opt_ PLARGE_INTEGER LargeInteger)
394 {
395     NTSTATUS Status;
396     IO_STATUS_BLOCK IoStatus;
397     AFD_INFO InfoData;
398     HANDLE Event;
399 
400     Status = NtCreateEvent(&Event,
401                            EVENT_ALL_ACCESS,
402                            NULL,
403                            NotificationEvent,
404                            FALSE);
405     if (!NT_SUCCESS(Status))
406     {
407         return Status;
408     }
409 
410     InfoData.InformationClass = InformationClass;
411 
412     if (Ulong != NULL)
413     {
414         InfoData.Information.Ulong = *Ulong;
415     }
416     if (LargeInteger != NULL)
417     {
418         InfoData.Information.LargeInteger = *LargeInteger;
419     }
420     if (Boolean != NULL)
421     {
422         InfoData.Information.Boolean = *Boolean;
423     }
424 
425     Status = NtDeviceIoControlFile(SocketHandle,
426                                    Event,
427                                    NULL,
428                                    NULL,
429                                    &IoStatus,
430                                    IOCTL_AFD_SET_INFO,
431                                    &InfoData,
432                                    sizeof(InfoData),
433                                    NULL,
434                                    0);
435     if (Status == STATUS_PENDING)
436     {
437         NtWaitForSingleObject(Event, FALSE, NULL);
438         Status = IoStatus.Status;
439     }
440 
441     NtClose(Event);
442 
443     return Status;
444 }
445 
446 NTSTATUS
447 AfdGetInformation(
448     _In_ HANDLE SocketHandle,
449     _In_ ULONG InformationClass,
450     _In_opt_ PBOOLEAN Boolean,
451     _In_opt_ PULONG Ulong,
452     _In_opt_ PLARGE_INTEGER LargeInteger)
453 {
454     NTSTATUS Status;
455     IO_STATUS_BLOCK IoStatus;
456     AFD_INFO InfoData;
457     HANDLE Event;
458 
459     Status = NtCreateEvent(&Event,
460                            EVENT_ALL_ACCESS,
461                            NULL,
462                            NotificationEvent,
463                            FALSE);
464     if (!NT_SUCCESS(Status))
465     {
466         return Status;
467     }
468 
469     InfoData.InformationClass = InformationClass;
470 
471     Status = NtDeviceIoControlFile(SocketHandle,
472                                    Event,
473                                    NULL,
474                                    NULL,
475                                    &IoStatus,
476                                    IOCTL_AFD_GET_INFO,
477                                    &InfoData,
478                                    sizeof(InfoData),
479                                    &InfoData,
480                                    sizeof(InfoData));
481     if (Status == STATUS_PENDING)
482     {
483         NtWaitForSingleObject(Event, FALSE, NULL);
484         Status = IoStatus.Status;
485     }
486 
487     NtClose(Event);
488 
489     if (Status != STATUS_SUCCESS)
490     {
491         return Status;
492     }
493 
494     if (Ulong != NULL)
495     {
496         *Ulong = InfoData.Information.Ulong;
497     }
498     if (LargeInteger != NULL)
499     {
500         *LargeInteger = InfoData.Information.LargeInteger;
501     }
502     if (Boolean != NULL)
503     {
504         *Boolean = InfoData.Information.Boolean;
505     }
506 
507     return Status;
508 }
509