xref: /reactos/dll/win32/msafd/misc/helpers.c (revision 8a978a17)
1 /*
2  * COPYRIGHT:   See COPYING in the top level directory
3  * PROJECT:     ReactOS Ancillary Function Driver DLL
4  * FILE:        dll/win32/msafd/misc/helpers.c
5  * PURPOSE:     Helper DLL management
6  * PROGRAMMERS: Casper S. Hornstrup (chorns@users.sourceforge.net)
7  *				Alex Ionescu (alex@relsoft.net)
8  * REVISIONS:
9  *   CSH 01/09-2000 Created
10  *	 Alex 16/07/2004 - Complete Rewrite
11  */
12 
13 #include <msafd.h>
14 
15 #include <winreg.h>
16 
/*
 * File-scope helper DLL "database" globals.
 * NOTE(review): neither of these is referenced anywhere in this file. The
 * helper list that is actually searched and extended here is
 * SockHelpersListHead (declared elsewhere), and it is accessed without taking
 * this critical section. Confirm whether these globals are still used by
 * other msafd files, or whether they are dead leftovers.
 */
17 CRITICAL_SECTION HelperDLLDatabaseLock;
/* NOTE(review): presumably the intended helper list head; unused in this file. */
18 LIST_ENTRY HelperDLLDatabaseListHead;
19 
20 
21 INT
22 SockGetTdiName(
23     PINT AddressFamily,
24     PINT SocketType,
25     PINT Protocol,
26     GROUP Group,
27     DWORD Flags,
28     PUNICODE_STRING TransportName,
29     PVOID *HelperDllContext,
30     PHELPER_DATA *HelperDllData,
31     PDWORD Events)
32 {
33     PHELPER_DATA        HelperData;
34     PWSTR               Transports;
35     PWSTR               Transport;
36     PWINSOCK_MAPPING	Mapping;
37     PLIST_ENTRY	        Helpers;
38     INT                 Status;
39 
40     TRACE("AddressFamily %p, SocketType %p, Protocol %p, Group %u, Flags %lx, TransportName %p, HelperDllContext %p, HelperDllData %p, Events %p\n",
41         AddressFamily, SocketType, Protocol, Group, Flags, TransportName, HelperDllContext, HelperDllData, Events);
42 
43     /* Check in our Current Loaded Helpers */
44     for (Helpers = SockHelpersListHead.Flink;
45          Helpers != &SockHelpersListHead;
46          Helpers = Helpers->Flink ) {
47 
48         HelperData = CONTAINING_RECORD(Helpers, HELPER_DATA, Helpers);
49 
50         /* See if this Mapping works for us */
51         if (SockIsTripleInMapping (HelperData->Mapping,
52                                    *AddressFamily,
53                                    *SocketType,
54                                    *Protocol)) {
55 
56             /* Call the Helper Dll function get the Transport Name */
57             if (HelperData->WSHOpenSocket2 == NULL ) {
58 
59                 /* DLL Doesn't support WSHOpenSocket2, call the old one */
60                 HelperData->WSHOpenSocket(AddressFamily,
61                                           SocketType,
62                                           Protocol,
63                                           TransportName,
64                                           HelperDllContext,
65                                           Events
66                                           );
67             } else {
68                 HelperData->WSHOpenSocket2(AddressFamily,
69                                            SocketType,
70                                            Protocol,
71                                            Group,
72                                            Flags,
73                                            TransportName,
74                                            HelperDllContext,
75                                            Events
76                                            );
77             }
78 
79             /* Return the Helper Pointers */
80             *HelperDllData = HelperData;
81             return NO_ERROR;
82         }
83     }
84 
85     /* Get the Transports available */
86     Status = SockLoadTransportList(&Transports);
87 
88     /* Check for error */
89     if (Status) {
90         WARN("Can't get transport list\n");
91         return Status;
92     }
93 
94     /* Loop through each transport until we find one that can satisfy us */
95     for (Transport = Transports;
96          *Transports != 0;
97          Transport += wcslen(Transport) + 1) {
98         TRACE("Transport: %S\n", Transports);
99 
100         /* See what mapping this Transport supports */
101         Status = SockLoadTransportMapping(Transport, &Mapping);
102 
103         /* Check for error */
104         if (Status) {
105             ERR("Can't get mapping for %S\n", Transports);
106             HeapFree(GlobalHeap, 0, Transports);
107             return Status;
108         }
109 
110         /* See if this Mapping works for us */
111         if (SockIsTripleInMapping(Mapping, *AddressFamily, *SocketType, *Protocol)) {
112 
113             /* It does, so load the DLL associated with it */
114             Status = SockLoadHelperDll(Transport, Mapping, &HelperData);
115 
116             /* Check for error */
117             if (Status) {
118                 ERR("Can't load helper DLL for Transport %S.\n", Transport);
119                 HeapFree(GlobalHeap, 0, Transports);
120                 HeapFree(GlobalHeap, 0, Mapping);
121                 return Status;
122             }
123 
124             /* Call the Helper Dll function get the Transport Name */
125             if (HelperData->WSHOpenSocket2 == NULL) {
126                 /* DLL Doesn't support WSHOpenSocket2, call the old one */
127                 HelperData->WSHOpenSocket(AddressFamily,
128                                           SocketType,
129                                           Protocol,
130                                           TransportName,
131                                           HelperDllContext,
132                                           Events
133                                           );
134             } else {
135                 HelperData->WSHOpenSocket2(AddressFamily,
136                                            SocketType,
137                                            Protocol,
138                                            Group,
139                                            Flags,
140                                            TransportName,
141                                            HelperDllContext,
142                                            Events
143                                            );
144             }
145 
146             /* Return the Helper Pointers */
147             *HelperDllData = HelperData;
148             /* We actually cache these ... the can't be freed yet */
149             /*HeapFree(GlobalHeap, 0, Transports);*/
150             /*HeapFree(GlobalHeap, 0, Mapping);*/
151             return NO_ERROR;
152         }
153 
154         HeapFree(GlobalHeap, 0, Mapping);
155     }
156     HeapFree(GlobalHeap, 0, Transports);
157     return WSAEINVAL;
158 }
159 
160 INT
161 SockLoadTransportMapping(
162     PWSTR TransportName,
163     PWINSOCK_MAPPING *Mapping)
164 {
165     PWSTR               TransportKey;
166     HKEY                KeyHandle;
167     ULONG               MappingSize;
168     LONG                Status;
169 
170     TRACE("TransportName %ws\n", TransportName);
171 
172     /* Allocate a Buffer */
173     TransportKey = HeapAlloc(GlobalHeap, 0, (54 + wcslen(TransportName)) * sizeof(WCHAR));
174 
175     /* Check for error */
176     if (TransportKey == NULL) {
177         ERR("Buffer allocation failed\n");
178         return WSAEINVAL;
179     }
180 
181     /* Generate the right key name */
182     wcscpy(TransportKey, L"System\\CurrentControlSet\\Services\\");
183     wcscat(TransportKey, TransportName);
184     wcscat(TransportKey, L"\\Parameters\\Winsock");
185 
186     /* Open the Key */
187     Status = RegOpenKeyExW(HKEY_LOCAL_MACHINE, TransportKey, 0, KEY_READ, &KeyHandle);
188 
189     /* We don't need the Transport Key anymore */
190     HeapFree(GlobalHeap, 0, TransportKey);
191 
192     /* Check for error */
193     if (Status) {
194         ERR("Error reading transport mapping registry\n");
195         return WSAEINVAL;
196     }
197 
198     /* Find out how much space we need for the Mapping */
199     Status = RegQueryValueExW(KeyHandle, L"Mapping", NULL, NULL, NULL, &MappingSize);
200 
201     /* Check for error */
202     if (Status) {
203         ERR("Error reading transport mapping registry\n");
204         return WSAEINVAL;
205     }
206 
207     /* Allocate Memory for the Mapping */
208     *Mapping = HeapAlloc(GlobalHeap, 0, MappingSize);
209 
210     /* Check for error */
211     if (*Mapping == NULL) {
212         ERR("Buffer allocation failed\n");
213         return WSAEINVAL;
214     }
215 
216     /* Read the Mapping */
217     Status = RegQueryValueExW(KeyHandle, L"Mapping", NULL, NULL, (LPBYTE)*Mapping, &MappingSize);
218 
219     /* Check for error */
220     if (Status) {
221         ERR("Error reading transport mapping registry\n");
222         HeapFree(GlobalHeap, 0, *Mapping);
223         return WSAEINVAL;
224     }
225 
226     /* Close key and return */
227     RegCloseKey(KeyHandle);
228     return 0;
229 }
230 
231 INT
232 SockLoadTransportList(
233     PWSTR *TransportList)
234 {
235     ULONG	TransportListSize;
236     HKEY	KeyHandle;
237     LONG	Status;
238 
239     TRACE("Called\n");
240 
241     /* Open the Transports Key */
242     Status = RegOpenKeyExW (HKEY_LOCAL_MACHINE,
243                             L"SYSTEM\\CurrentControlSet\\Services\\Winsock\\Parameters",
244                             0,
245                             KEY_READ,
246                             &KeyHandle);
247 
248     /* Check for error */
249     if (Status) {
250         ERR("Error reading transport list registry\n");
251         return WSAEINVAL;
252     }
253 
254     /* Get the Transport List Size */
255     Status = RegQueryValueExW(KeyHandle,
256                               L"Transports",
257                               NULL,
258                               NULL,
259                               NULL,
260                               &TransportListSize);
261 
262     /* Check for error */
263     if (Status) {
264         ERR("Error reading transport list registry\n");
265         return WSAEINVAL;
266     }
267 
268     /* Allocate Memory for the Transport List */
269     *TransportList = HeapAlloc(GlobalHeap, 0, TransportListSize);
270 
271     /* Check for error */
272     if (*TransportList == NULL) {
273         ERR("Buffer allocation failed\n");
274         return WSAEINVAL;
275     }
276 
277     /* Get the Transports */
278     Status = RegQueryValueExW (KeyHandle,
279                                L"Transports",
280                                NULL,
281                                NULL,
282                                (LPBYTE)*TransportList,
283                                &TransportListSize);
284 
285     /* Check for error */
286     if (Status) {
287         ERR("Error reading transport list registry\n");
288         HeapFree(GlobalHeap, 0, *TransportList);
289         return WSAEINVAL;
290     }
291 
292     /* Close key and return */
293     RegCloseKey(KeyHandle);
294     return 0;
295 }
296 
297 INT
298 SockLoadHelperDll(
299     PWSTR TransportName,
300     PWINSOCK_MAPPING Mapping,
301     PHELPER_DATA *HelperDllData)
302 {
303     PHELPER_DATA        HelperData;
304     PWSTR               HelperDllName;
305     PWSTR               FullHelperDllName;
306     PWSTR               HelperKey;
307     HKEY                KeyHandle;
308     ULONG               DataSize;
309     LONG                Status;
310 
311     /* Allocate space for the Helper Structure and TransportName */
312     HelperData = HeapAlloc(GlobalHeap, 0, sizeof(*HelperData) + (wcslen(TransportName) + 1) * sizeof(WCHAR));
313 
314     /* Check for error */
315     if (HelperData == NULL) {
316         ERR("Buffer allocation failed\n");
317         return WSAEINVAL;
318     }
319 
320     /* Allocate Space for the Helper DLL Key */
321     HelperKey = HeapAlloc(GlobalHeap, 0, (54 + wcslen(TransportName)) * sizeof(WCHAR));
322 
323     /* Check for error */
324     if (HelperKey == NULL) {
325         ERR("Buffer allocation failed\n");
326         HeapFree(GlobalHeap, 0, HelperData);
327         return WSAEINVAL;
328     }
329 
330     /* Generate the right key name */
331     wcscpy(HelperKey, L"System\\CurrentControlSet\\Services\\");
332     wcscat(HelperKey, TransportName);
333     wcscat(HelperKey, L"\\Parameters\\Winsock");
334 
335     /* Open the Key */
336     Status = RegOpenKeyExW(HKEY_LOCAL_MACHINE, HelperKey, 0, KEY_READ, &KeyHandle);
337 
338     HeapFree(GlobalHeap, 0, HelperKey);
339 
340     /* Check for error */
341     if (Status) {
342         ERR("Error reading helper DLL parameters\n");
343         HeapFree(GlobalHeap, 0, HelperData);
344         return WSAEINVAL;
345     }
346 
347     /* Read Size of SockAddr Structures */
348     DataSize = sizeof(HelperData->MinWSAddressLength);
349     HelperData->MinWSAddressLength = 16;
350     RegQueryValueExW (KeyHandle,
351                       L"MinSockaddrLength",
352                       NULL,
353                       NULL,
354                       (LPBYTE)&HelperData->MinWSAddressLength,
355                       &DataSize);
356     DataSize = sizeof(HelperData->MinWSAddressLength);
357     HelperData->MaxWSAddressLength = 16;
358     RegQueryValueExW (KeyHandle,
359                       L"MaxSockaddrLength",
360                       NULL,
361                       NULL,
362                       (LPBYTE)&HelperData->MaxWSAddressLength,
363                       &DataSize);
364 
365     /* Size of TDI Structures */
366     HelperData->MinTDIAddressLength = HelperData->MinWSAddressLength + 6;
367     HelperData->MaxTDIAddressLength = HelperData->MaxWSAddressLength + 6;
368 
369     /* Read Delayed Acceptance Setting */
370     DataSize = sizeof(DWORD);
371     HelperData->UseDelayedAcceptance = FALSE;
372     RegQueryValueExW (KeyHandle,
373                       L"UseDelayedAcceptance",
374                       NULL,
375                       NULL,
376                       (LPBYTE)&HelperData->UseDelayedAcceptance,
377                       &DataSize);
378 
379     /* Allocate Space for the Helper DLL Names */
380     HelperDllName = HeapAlloc(GlobalHeap, 0, 512);
381 
382     /* Check for error */
383     if (HelperDllName == NULL) {
384         ERR("Buffer allocation failed\n");
385         HeapFree(GlobalHeap, 0, HelperData);
386         return WSAEINVAL;
387     }
388 
389     FullHelperDllName = HeapAlloc(GlobalHeap, 0, 512);
390 
391     /* Check for error */
392     if (FullHelperDllName == NULL) {
393         ERR("Buffer allocation failed\n");
394         HeapFree(GlobalHeap, 0, HelperDllName);
395         HeapFree(GlobalHeap, 0, HelperData);
396         return WSAEINVAL;
397     }
398 
399     /* Get the name of the Helper DLL*/
400     DataSize = 512;
401     Status = RegQueryValueExW (KeyHandle,
402                                L"HelperDllName",
403                                NULL,
404                                NULL,
405                                (LPBYTE)HelperDllName,
406                                &DataSize);
407 
408     /* Check for error */
409     if (Status) {
410         ERR("Error reading helper DLL parameters\n");
411         HeapFree(GlobalHeap, 0, FullHelperDllName);
412         HeapFree(GlobalHeap, 0, HelperDllName);
413         HeapFree(GlobalHeap, 0, HelperData);
414         return WSAEINVAL;
415     }
416 
417     /* Get the Full name, expanding Environment Strings */
418     ExpandEnvironmentStringsW (HelperDllName,
419                                FullHelperDllName,
420                                256);
421 
422     /* Load the DLL */
423     HelperData->hInstance = LoadLibraryW(FullHelperDllName);
424 
425     HeapFree(GlobalHeap, 0, HelperDllName);
426     HeapFree(GlobalHeap, 0, FullHelperDllName);
427 
428     if (HelperData->hInstance == NULL) {
429         ERR("Error loading helper DLL\n");
430         HeapFree(GlobalHeap, 0, HelperData);
431         return WSAEINVAL;
432     }
433 
434     /* Close Key */
435     RegCloseKey(KeyHandle);
436 
437     /* Get the Pointers to the Helper Routines */
438     HelperData->WSHOpenSocket =	(PWSH_OPEN_SOCKET)
439 									GetProcAddress(HelperData->hInstance,
440 									"WSHOpenSocket");
441     HelperData->WSHOpenSocket2 = (PWSH_OPEN_SOCKET2)
442 									GetProcAddress(HelperData->hInstance,
443 									"WSHOpenSocket2");
444     HelperData->WSHJoinLeaf = (PWSH_JOIN_LEAF)
445 								GetProcAddress(HelperData->hInstance,
446 								"WSHJoinLeaf");
447     HelperData->WSHNotify = (PWSH_NOTIFY)
448 								GetProcAddress(HelperData->hInstance, "WSHNotify");
449     HelperData->WSHGetSocketInformation = (PWSH_GET_SOCKET_INFORMATION)
450 											GetProcAddress(HelperData->hInstance,
451 											"WSHGetSocketInformation");
452     HelperData->WSHSetSocketInformation = (PWSH_SET_SOCKET_INFORMATION)
453 											GetProcAddress(HelperData->hInstance,
454 											"WSHSetSocketInformation");
455     HelperData->WSHGetSockaddrType = (PWSH_GET_SOCKADDR_TYPE)
456 										GetProcAddress(HelperData->hInstance,
457 										"WSHGetSockaddrType");
458     HelperData->WSHGetWildcardSockaddr = (PWSH_GET_WILDCARD_SOCKADDR)
459 											GetProcAddress(HelperData->hInstance,
460 											"WSHGetWildcardSockaddr");
461     HelperData->WSHGetBroadcastSockaddr = (PWSH_GET_BROADCAST_SOCKADDR)
462 											GetProcAddress(HelperData->hInstance,
463 											"WSHGetBroadcastSockaddr");
464     HelperData->WSHAddressToString = (PWSH_ADDRESS_TO_STRING)
465 										GetProcAddress(HelperData->hInstance,
466 										"WSHAddressToString");
467     HelperData->WSHStringToAddress = (PWSH_STRING_TO_ADDRESS)
468 										GetProcAddress(HelperData->hInstance,
469 										"WSHStringToAddress");
470     HelperData->WSHIoctl = (PWSH_IOCTL)
471 							GetProcAddress(HelperData->hInstance,
472 							"WSHIoctl");
473 
474     /* Save the Mapping Structure and transport name */
475     HelperData->Mapping = Mapping;
476     wcscpy(HelperData->TransportName, TransportName);
477 
478     /* Increment Reference Count */
479     HelperData->RefCount = 1;
480 
481     /* Add it to our list */
482     InsertHeadList(&SockHelpersListHead, &HelperData->Helpers);
483 
484     /* Return Pointers */
485     *HelperDllData = HelperData;
486     return 0;
487 }
488 
489 BOOL
490 SockIsTripleInMapping(
491     PWINSOCK_MAPPING Mapping,
492     INT AddressFamily,
493     INT SocketType,
494     INT Protocol)
495 {
496     /* The Windows version returns more detailed information on which of the 3 parameters failed...we should do this later */
497     ULONG    Row;
498 
499     TRACE("Called, Mapping rows = %d\n", Mapping->Rows);
500 
501     /* Loop through Mapping to Find a matching one */
502     for (Row = 0; Row < Mapping->Rows; Row++) {
503         TRACE("Examining: row %d: AF %d type %d proto %d\n",
504 				Row,
505 				(INT)Mapping->Mapping[Row].AddressFamily,
506 				(INT)Mapping->Mapping[Row].SocketType,
507 				(INT)Mapping->Mapping[Row].Protocol);
508 
509         /* Check of all three values Match */
510         if (((INT)Mapping->Mapping[Row].AddressFamily == AddressFamily) &&
511             ((INT)Mapping->Mapping[Row].SocketType == SocketType) &&
512             ((INT)Mapping->Mapping[Row].Protocol == Protocol)) {
513             TRACE("Found\n");
514             return TRUE;
515         }
516     }
517     WARN("Not found\n");
518     return FALSE;
519 }
520 
521 /* EOF */
522