1 /*
2 This file is part of Warzone 2100.
3 Copyright (C) 1999-2004 Eidos Interactive
4 Copyright (C) 2005-2020 Warzone 2100 Project
5
6 Warzone 2100 is free software; you can redistribute it and/or modify
7 it under the terms of the GNU General Public License as published by
8 the Free Software Foundation; either version 2 of the License, or
9 (at your option) any later version.
10
11 Warzone 2100 is distributed in the hope that it will be useful,
12 but WITHOUT ANY WARRANTY; without even the implied warranty of
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 GNU General Public License for more details.
15
16 You should have received a copy of the GNU General Public License
17 along with Warzone 2100; if not, write to the Free Software
18 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
19 */
20 /**
21 * @file netsocket.cpp
22 *
23 * Basic raw socket handling code.
24 */
25
26 #include "lib/framework/frame.h"
27 #include "lib/framework/wzapp.h"
28 #include "netsocket.h"
29
30 #include <vector>
31 #include <algorithm>
32 #include <map>
33
34 #if !defined(ZLIB_CONST)
35 # define ZLIB_CONST
36 #endif
37 #include <zlib.h>
38
/// Indices into Socket::fd[].
enum
{
	SOCK_CONNECTION = 0,                ///< The only handle used by a connected (non-listening) socket.
	SOCK_IPV4_LISTEN = SOCK_CONNECTION, ///< Listening sockets keep their IPv4 handle in the same slot.
	SOCK_IPV6_LISTEN,                   ///< Listening sockets keep their IPv6 handle here.
	SOCK_COUNT,                         ///< Number of handle slots.
};
46
47 struct Socket
48 {
49 /* Multiple socket handles only for listening sockets. This allows us
50 * to listen on multiple protocols and address families (e.g. IPv4 and
51 * IPv6).
52 *
53 * All non-listening sockets will only use the first socket handle.
54 */
SocketSocket55 Socket() : ready(false), writeError(false), deleteLater(false), isCompressed(false), readDisconnected(false), zDeflateInSize(0)
56 {
57 memset(&zDeflate, 0, sizeof(zDeflate));
58 memset(&zInflate, 0, sizeof(zInflate));
59 }
60 ~Socket();
61
62 SOCKET fd[SOCK_COUNT];
63 bool ready;
64 bool writeError;
65 bool deleteLater;
66 char textAddress[40];
67
68 bool isCompressed;
69 bool readDisconnected; ///< True iff a call to recv() returned 0.
70 z_stream zDeflate;
71 z_stream zInflate;
72 unsigned zDeflateInSize;
73 bool zInflateNeedInput;
74 std::vector<uint8_t> zDeflateOutBuf;
75 std::vector<uint8_t> zInflateInBuf;
76 };
77
78 struct SocketSet
79 {
80 std::vector<Socket *> fds;
81 };
82
83
84 static WZ_MUTEX *socketThreadMutex;
85 static WZ_SEMAPHORE *socketThreadSemaphore;
86 static WZ_THREAD *socketThread = nullptr;
87 static bool socketThreadQuit;
88 typedef std::map<Socket *, std::vector<uint8_t>> SocketThreadWriteMap;
89 static SocketThreadWriteMap socketThreadWrites;
90
91
92 static void socketCloseNow(Socket *sock);
93
94
socketReadReady(Socket const * sock)95 bool socketReadReady(Socket const *sock)
96 {
97 return sock->ready;
98 }
99
/// Fetch the last socket error code: errno on Unix, WSAGetLastError() on Windows.
int getSockErr()
{
#if defined(WZ_OS_UNIX)
	const int lastError = errno;
	return lastError;
#elif defined(WZ_OS_WIN)
	const int lastError = WSAGetLastError();
	return lastError;
#endif
}
108
/// Store @p error as the code that getSockErr() will report next
/// (errno on Unix, WSASetLastError() on Windows).
void setSockErr(int error)
{
#if defined(WZ_OS_UNIX)
	errno = error;
#elif defined(WZ_OS_WIN)
	WSASetLastError(error);
#endif
	return;
}
117
118 #if defined(WZ_OS_WIN)
119 typedef int (WINAPI *GETADDRINFO_DLL_FUNC)(const char *node, const char *service,
120 const struct addrinfo *hints,
121 struct addrinfo **res);
122 typedef int (WINAPI *FREEADDRINFO_DLL_FUNC)(struct addrinfo *res);
123
124 static HMODULE winsock2_dll = nullptr;
125
126 static GETADDRINFO_DLL_FUNC getaddrinfo_dll_func = nullptr;
127 static FREEADDRINFO_DLL_FUNC freeaddrinfo_dll_func = nullptr;
128
129 # define getaddrinfo getaddrinfo_dll_dispatcher
130 # define freeaddrinfo freeaddrinfo_dll_dispatcher
131
132 # include <ntverp.h> // Windows SDK - include for access to VER_PRODUCTBUILD
133 # if VER_PRODUCTBUILD >= 9600
134 // 9600 is the Windows SDK 8.1
135 # include <VersionHelpers.h> // For IsWindowsVistaOrGreater()
136 # else
137 // Earlier SDKs may not have VersionHelpers.h - use simple fallback
IsWindowsVistaOrGreater()138 inline bool IsWindowsVistaOrGreater()
139 {
140 DWORD dwMajorVersion = (DWORD)(LOBYTE(LOWORD(GetVersion())));
141 return dwMajorVersion >= 6;
142 }
143 # endif
144
/**
 * Dispatcher for getaddrinfo(): forwards to the implementation exported by the
 * winsock2 DLL.
 *
 * Before Windows Vista the AI_V4MAPPED and AI_ADDRCONFIG flags are unsupported,
 * so on older systems they are stripped from a copy of @p hints.
 *
 * @return The DLL function's result, or EAI_FAIL (after logging) if the DLL or
 *         its getaddrinfo export is unavailable.
 */
static int getaddrinfo(const char *node, const char *service,
                       const struct addrinfo *hints,
                       struct addrinfo **res)
{
	if (!winsock2_dll)
	{
		debug(LOG_ERROR, "Failed to load winsock2 DLL. Required for name resolution.");
		return EAI_FAIL;
	}

	if (!getaddrinfo_dll_func)
	{
		debug(LOG_ERROR, "Failed to retrieve \"getaddrinfo\" function from winsock2 DLL. Required for name resolution.");
		return EAI_FAIL;
	}

	if (hints == nullptr)
	{
		return getaddrinfo_dll_func(node, service, nullptr, res);
	}

	// Work on a copy: the caller's hints are const and must not be modified.
	struct addrinfo hint = *hints;
	if (!IsWindowsVistaOrGreater())
	{
		// Windows 2000, XP and Server 2003: these flags are only supported from Windows Vista+
		hint.ai_flags &= ~(AI_V4MAPPED | AI_ADDRCONFIG);
	}
	return getaddrinfo_dll_func(node, service, &hint, res);
}
183
freeaddrinfo(struct addrinfo * res)184 static void freeaddrinfo(struct addrinfo *res)
185 {
186
187 // // Windows 95, 98 and ME
188 // debug(LOG_ERROR, "Name resolution isn't supported on this version of Windows");
189 // return;
190
191 if (!winsock2_dll)
192 {
193 debug(LOG_ERROR, "Failed to load winsock2 DLL. Required for name resolution.");
194 return;
195 }
196
197 if (!freeaddrinfo_dll_func)
198 {
199 debug(LOG_ERROR, "Failed to retrieve \"freeaddrinfo\" function from winsock2 DLL. Required for name resolution.");
200 return;
201 }
202
203 freeaddrinfo_dll_func(res);
204 }
205 #endif
206
addressToText(const struct sockaddr * addr,char * buf,size_t size)207 static int addressToText(const struct sockaddr *addr, char *buf, size_t size)
208 {
209 auto handleIpv4 = [&](uint32_t addr) {
210 uint32_t val = ntohl(addr);
211 return snprintf(buf, size, "%u.%u.%u.%u", (val>>24)&0xFF, (val>>16)&0xFF, (val>>8)&0xFF, val&0xFF);
212 };
213
214 switch (addr->sa_family)
215 {
216 case AF_INET:
217 {
218 #if defined(__GNUC__) && !defined(__INTEL_COMPILER) && !defined(__clang__)
219 # pragma GCC diagnostic push
220 # pragma GCC diagnostic ignored "-Wcast-align"
221 #endif
222 return handleIpv4((reinterpret_cast<const sockaddr_in *>(addr))->sin_addr.s_addr);
223 #if defined(__GNUC__) && !defined(__INTEL_COMPILER) && !defined(__clang__)
224 # pragma GCC diagnostic pop
225 #endif
226 }
227 case AF_INET6:
228 {
229
230 // Check to see if this is really a IPv6 address
231 #if defined(__GNUC__) && !defined(__INTEL_COMPILER) && !defined(__clang__)
232 # pragma GCC diagnostic push
233 # pragma GCC diagnostic ignored "-Wcast-align"
234 #endif
235 const struct sockaddr_in6 *mappedIP = reinterpret_cast<const sockaddr_in6 *>(addr);
236 #if defined(__GNUC__) && !defined(__INTEL_COMPILER) && !defined(__clang__)
237 # pragma GCC diagnostic pop
238 #endif
239 if (IN6_IS_ADDR_V4MAPPED(&mappedIP->sin6_addr))
240 {
241 // looks like it is ::ffff:(...) so lets set up a IPv4 socket address structure
242 // slightly overkill for our needs, but it shows exactly what needs to be done.
243 // At this time, we only care about the address, nothing else.
244 struct sockaddr_in addr4;
245 memcpy(&addr4.sin_addr.s_addr, mappedIP->sin6_addr.s6_addr + 12, sizeof(addr4.sin_addr.s_addr));
246 return handleIpv4(addr4.sin_addr.s_addr);
247 }
248 else
249 {
250 static_assert(sizeof(in6_addr::s6_addr) == 16, "Standard expects in6_addr structure that contains member s6_addr[16], a 16-element array of uint8_t");
251 #if defined(__GNUC__) && !defined(__INTEL_COMPILER) && !defined(__clang__)
252 # pragma GCC diagnostic push
253 # pragma GCC diagnostic ignored "-Wcast-align"
254 #endif
255 const uint8_t *address_u8 = &((reinterpret_cast<const sockaddr_in6 *>(addr))->sin6_addr.s6_addr[0]);
256 #if defined(__GNUC__) && !defined(__INTEL_COMPILER) && !defined(__clang__)
257 # pragma GCC diagnostic pop
258 #endif
259 uint16_t address[8] = {0};
260 memcpy(&address, address_u8, sizeof(uint8_t) * 16);
261 return snprintf(buf, size,
262 "%hx:%hx:%hx:%hx:%hx:%hx:%hx:%hx",
263 ntohs(address[0]),
264 ntohs(address[1]),
265 ntohs(address[2]),
266 ntohs(address[3]),
267 ntohs(address[4]),
268 ntohs(address[5]),
269 ntohs(address[6]),
270 ntohs(address[7]));
271 }
272 }
273 default:
274 ASSERT(!"Unknown address family", "Got non IPv4 or IPv6 address!");
275 return -1;
276 }
277 }
278
/**
 * Return a human-readable description of socket error code @p error.
 * On Windows a built-in table of Winsock codes is used ("Unknown error" for
 * anything not listed); on Unix this is strerror().
 */
const char *strSockError(int error)
{
#if defined(WZ_OS_WIN)
	struct ErrorText
	{
		int code;
		const char *text;
	};
	static const ErrorText errorTexts[] =
	{
		{0, "No error"},
		{WSAEINTR, "Interrupted system call"},
		{WSAEBADF, "Bad file number"},
		{WSAEACCES, "Permission denied"},
		{WSAEFAULT, "Bad address"},
		{WSAEINVAL, "Invalid argument"},
		{WSAEMFILE, "Too many open sockets"},
		{WSAEWOULDBLOCK, "Operation would block"},
		{WSAEINPROGRESS, "Operation now in progress"},
		{WSAEALREADY, "Operation already in progress"},
		{WSAENOTSOCK, "Socket operation on non-socket"},
		{WSAEDESTADDRREQ, "Destination address required"},
		{WSAEMSGSIZE, "Message too long"},
		{WSAEPROTOTYPE, "Protocol wrong type for socket"},
		{WSAENOPROTOOPT, "Bad protocol option"},
		{WSAEPROTONOSUPPORT, "Protocol not supported"},
		{WSAESOCKTNOSUPPORT, "Socket type not supported"},
		{WSAEOPNOTSUPP, "Operation not supported on socket"},
		{WSAEPFNOSUPPORT, "Protocol family not supported"},
		{WSAEAFNOSUPPORT, "Address family not supported"},
		{WSAEADDRINUSE, "Address already in use"},
		{WSAEADDRNOTAVAIL, "Can't assign requested address"},
		{WSAENETDOWN, "Network is down"},
		{WSAENETUNREACH, "Network is unreachable"},
		{WSAENETRESET, "Net connection reset"},
		{WSAECONNABORTED, "Software caused connection abort"},
		{WSAECONNRESET, "Connection reset by peer"},
		{WSAENOBUFS, "No buffer space available"},
		{WSAEISCONN, "Socket is already connected"},
		{WSAENOTCONN, "Socket is not connected"},
		{WSAESHUTDOWN, "Can't send after socket shutdown"},
		{WSAETOOMANYREFS, "Too many references, can't splice"},
		{WSAETIMEDOUT, "Connection timed out"},
		{WSAECONNREFUSED, "Connection refused"},
		{WSAELOOP, "Too many levels of symbolic links"},
		{WSAENAMETOOLONG, "File name too long"},
		{WSAEHOSTDOWN, "Host is down"},
		{WSAEHOSTUNREACH, "No route to host"},
		{WSAENOTEMPTY, "Directory not empty"},
		{WSAEPROCLIM, "Too many processes"},
		{WSAEUSERS, "Too many users"},
		{WSAEDQUOT, "Disc quota exceeded"},
		{WSAESTALE, "Stale NFS file handle"},
		{WSAEREMOTE, "Too many levels of remote in path"},
		{WSASYSNOTREADY, "Network system is unavailable"},
		{WSAVERNOTSUPPORTED, "Winsock version out of range"},
		{WSANOTINITIALISED, "WSAStartup not yet called"},
		{WSAEDISCON, "Graceful shutdown in progress"},
		{WSAHOST_NOT_FOUND, "Host not found"},
		{WSANO_DATA, "No host data of that type was found"},
	};

	for (const ErrorText &entry : errorTexts)
	{
		if (entry.code == error)
		{
			return entry.text;
		}
	}
	return "Unknown error";
#elif defined(WZ_OS_UNIX)
	return strerror(error);
#endif
}
340
341 /**
342 * Test whether the given socket still has an open connection.
343 *
344 * @return true when the connection is open, false when it's closed or in an
345 * error state, check getSockErr() to find out which.
346 */
connectionIsOpen(Socket * sock)347 static bool connectionIsOpen(Socket *sock)
348 {
349 const SocketSet set = {std::vector<Socket *>(1, sock)};
350
351 ASSERT_OR_RETURN((setSockErr(EBADF), false),
352 sock && sock->fd[SOCK_CONNECTION] != INVALID_SOCKET, "Invalid socket");
353
354 // Check whether the socket is still connected
355 int ret = checkSockets(&set, 0);
356 if (ret == SOCKET_ERROR)
357 {
358 return false;
359 }
360 else if (ret == (int)set.fds.size() && sock->ready)
361 {
362 /* The next recv(2) call won't block, but we're writing. So
363 * check the read queue to see if the connection is closed.
364 * If there's no data in the queue that means the connection
365 * is closed.
366 */
367 #if defined(WZ_OS_WIN)
368 unsigned long readQueue;
369 ret = ioctlsocket(sock->fd[SOCK_CONNECTION], FIONREAD, &readQueue);
370 #else
371 int readQueue;
372 ret = ioctl(sock->fd[SOCK_CONNECTION], FIONREAD, &readQueue);
373 #endif
374 if (ret == SOCKET_ERROR)
375 {
376 debug(LOG_NET, "socket error");
377 return false;
378 }
379 else if (readQueue == 0)
380 {
381 // Disconnected
382 setSockErr(ECONNRESET);
383 debug(LOG_NET, "Read queue empty - failing (ECONNRESET)");
384 return false;
385 }
386 }
387
388 return true;
389 }
390
/**
 * Body of the background writer thread.
 *
 * Sends the data that writeAll() / socketFlush() queue in socketThreadWrites.
 * Each pass select()s for writability (up to 50 ms, with socketThreadMutex
 * released) on the sockets that have pending data. It then send()s as much as
 * each writable socket accepts and erases the sent prefix from that socket's queue.
 * A socket whose queue drains, or whose write fails, is removed from
 * socketThreadWrites and, if marked deleteLater, closed with socketCloseNow().
 * When nothing is queued the thread sleeps on socketThreadSemaphore.
 * The loop ends once socketThreadQuit is seen set.
 *
 * @return 42; the value is arbitrary and unused.
 */
static int socketThreadFunction(void *)
{
	wzMutexLock(socketThreadMutex);
	while (!socketThreadQuit)
	{
		// Highest descriptor in the write set, needed for select()'s first argument.
#if defined(WZ_OS_UNIX)
		SOCKET maxfd = INT_MIN;
#elif defined(WZ_OS_WIN)
		SOCKET maxfd = 0;
#endif
		fd_set fds;
		FD_ZERO(&fds);
		// Collect every socket that has bytes waiting to be sent.
		for (SocketThreadWriteMap::iterator i = socketThreadWrites.begin(); i != socketThreadWrites.end(); ++i)
		{
			if (!i->second.empty())
			{
				SOCKET fd = i->first->fd[SOCK_CONNECTION];
				maxfd = std::max(maxfd, fd);
				ASSERT(!FD_ISSET(fd, &fds), "Duplicate file descriptor!"); // Shouldn't be possible, but blocking in send, after select says it won't block, shouldn't be possible either.
				FD_SET(fd, &fds);
			}
		}
		struct timeval tv = {0, 50 * 1000}; // 50 ms

		// Check if we can write to any sockets.
		// The mutex is released while waiting so other threads can keep queueing data.
		wzMutexUnlock(socketThreadMutex);
		int ret = select(maxfd + 1, nullptr, &fds, nullptr, &tv);
		wzMutexLock(socketThreadMutex);

		// We can write to some sockets. (Ignore errors from select, we may have deleted the socket after unlocking the mutex, and before calling select.)
		if (ret > 0)
		{
			for (SocketThreadWriteMap::iterator i = socketThreadWrites.begin(); i != socketThreadWrites.end();)
			{
				// Advance before processing, because `w` may be erased below.
				SocketThreadWriteMap::iterator w = i;
				++i;

				Socket *sock = w->first;
				std::vector<uint8_t> &writeQueue = w->second;
				ASSERT(!writeQueue.empty(), "writeQueue[sock] must not be empty.");

				if (!FD_ISSET(sock->fd[SOCK_CONNECTION], &fds))
				{
					continue; // This socket is not ready for writing, or we don't have anything to write.
				}

				// Write data.
				// FIXME SOMEHOW AAARGH This send() call can't block, but unless the socket is not set to blocking (setting the socket to nonblocking had better work, or else), does anyway (at least sometimes, when someone quits). Not reproducible except in public releases.
				ssize_t retSent = send(sock->fd[SOCK_CONNECTION], reinterpret_cast<char *>(&writeQueue[0]), writeQueue.size(), MSG_NOSIGNAL);
				if (retSent != SOCKET_ERROR)
				{
					// Erase as much data as written.
					writeQueue.erase(writeQueue.begin(), writeQueue.begin() + retSent);
					if (writeQueue.empty())
					{
						socketThreadWrites.erase(w); // Nothing left to write, delete from pending list.
						if (sock->deleteLater)
						{
							socketCloseNow(sock);
						}
					}
				}
				else
				{
					switch (getSockErr())
					{
					case EAGAIN:
#if defined(EWOULDBLOCK) && EAGAIN != EWOULDBLOCK
					case EWOULDBLOCK:
#endif
						// send() would block even though select() said it wouldn't; see whether the peer went away.
						if (!connectionIsOpen(sock))
						{
							debug(LOG_NET, "Socket error");
							sock->writeError = true;
							socketThreadWrites.erase(w); // Socket broken, don't try writing to it again.
							if (sock->deleteLater)
							{
								socketCloseNow(sock);
							}
							break;
						}
					// fall through - the connection is still open, so just retry on a later pass.
					case EINTR:
						break;
#if defined(EPIPE)
					case EPIPE:
#endif
					default:
						sock->writeError = true;
						socketThreadWrites.erase(w); // Socket broken, don't try writing to it again.
						if (sock->deleteLater)
						{
							socketCloseNow(sock);
						}
						break;
					}
				}
			}
		}

		if (socketThreadWrites.empty())
		{
			// Nothing to do, expect to wait.
			// writeAll()/socketFlush() post the semaphore when they queue data into an empty map.
			wzMutexUnlock(socketThreadMutex);
			wzSemaphoreWait(socketThreadSemaphore);
			wzMutexLock(socketThreadMutex);
		}
	}
	wzMutexUnlock(socketThreadMutex);

	return 42; // Return value arbitrary and unused.
}
502
503 /**
504 * Similar to read(2) with the exception that this function won't be
505 * interrupted by signals (EINTR).
506 */
readNoInt(Socket * sock,void * buf,size_t max_size,size_t * rawByteCount)507 ssize_t readNoInt(Socket *sock, void *buf, size_t max_size, size_t *rawByteCount)
508 {
509 size_t ignored;
510 size_t &rawBytes = rawByteCount != nullptr ? *rawByteCount : ignored;
511 rawBytes = 0;
512
513 if (sock->fd[SOCK_CONNECTION] == INVALID_SOCKET)
514 {
515 debug(LOG_ERROR, "Invalid socket");
516 setSockErr(EBADF);
517 return SOCKET_ERROR;
518 }
519
520 if (sock->isCompressed)
521 {
522 if (sock->zInflateNeedInput)
523 {
524 // No input data, read some.
525
526 sock->zInflateInBuf.resize(max_size + 1000);
527
528 ssize_t received;
529 do
530 {
531 // v----- This weird cast is because recv() takes a char * on windows instead of a void *...
532 received = recv(sock->fd[SOCK_CONNECTION], (char *)&sock->zInflateInBuf[0], sock->zInflateInBuf.size(), 0);
533 }
534 while (received == SOCKET_ERROR && getSockErr() == EINTR);
535 if (received < 0)
536 {
537 return received;
538 }
539
540 sock->zInflate.next_in = &sock->zInflateInBuf[0];
541 sock->zInflate.avail_in = received;
542 rawBytes = received;
543
544 if (received == 0)
545 {
546 sock->readDisconnected = true;
547 }
548 else
549 {
550 sock->zInflateNeedInput = false;
551 }
552 }
553
554 sock->zInflate.next_out = (Bytef *)buf;
555 sock->zInflate.avail_out = max_size;
556 int ret = inflate(&sock->zInflate, Z_NO_FLUSH);
557 ASSERT(ret != Z_STREAM_ERROR, "zlib inflate not working!");
558 char const *err = nullptr;
559 switch (ret)
560 {
561 case Z_NEED_DICT: err = "Z_NEED_DICT"; break;
562 case Z_DATA_ERROR: err = "Z_DATA_ERROR"; break;
563 case Z_MEM_ERROR: err = "Z_MEM_ERROR"; break;
564 }
565 if (err != nullptr)
566 {
567 debug(LOG_ERROR, "Couldn't decompress data from socket. zlib error %s", err);
568 return -1; // Bad data!
569 }
570
571 if (sock->zInflate.avail_out != 0)
572 {
573 sock->zInflateNeedInput = true;
574 ASSERT(sock->zInflate.avail_in == 0, "zlib not consuming all input!");
575 }
576
577 return max_size - sock->zInflate.avail_out; // Got some data, return how much.
578 }
579
580 ssize_t received;
581 do
582 {
583 received = recv(sock->fd[SOCK_CONNECTION], (char *)buf, max_size, 0);
584 if (received == 0)
585 {
586 sock->readDisconnected = true;
587 }
588 }
589 while (received == SOCKET_ERROR && getSockErr() == EINTR);
590
591 sock->ready = false;
592
593 rawBytes = received;
594 return received;
595 }
596
socketReadDisconnected(Socket * sock)597 bool socketReadDisconnected(Socket *sock)
598 {
599 return sock->readDisconnected;
600 }
601
602 /**
603 * Similar to write(2) with the exception that this function will block until
604 * <em>all</em> data has been written or an error occurs.
605 *
606 * @return @c size when successful or @c SOCKET_ERROR if an error occurred.
607 */
writeAll(Socket * sock,const void * buf,size_t size,size_t * rawByteCount)608 ssize_t writeAll(Socket *sock, const void *buf, size_t size, size_t *rawByteCount)
609 {
610 size_t ignored;
611 size_t &rawBytes = rawByteCount != nullptr ? *rawByteCount : ignored;
612 rawBytes = 0;
613
614 if (sock->fd[SOCK_CONNECTION] == INVALID_SOCKET)
615 {
616 debug(LOG_ERROR, "Invalid socket (EBADF)");
617 setSockErr(EBADF);
618 return SOCKET_ERROR;
619 }
620
621 if (sock->writeError)
622 {
623 return SOCKET_ERROR;
624 }
625
626 if (size > 0)
627 {
628 if (!sock->isCompressed)
629 {
630 wzMutexLock(socketThreadMutex);
631 if (socketThreadWrites.empty())
632 {
633 wzSemaphorePost(socketThreadSemaphore);
634 }
635 std::vector<uint8_t> &writeQueue = socketThreadWrites[sock];
636 writeQueue.insert(writeQueue.end(), static_cast<char const *>(buf), static_cast<char const *>(buf) + size);
637 wzMutexUnlock(socketThreadMutex);
638 rawBytes = size;
639 }
640 else
641 {
642 #if ZLIB_VERNUM < 0x1252
643 // zlib < 1.2.5.2 does not support `#define ZLIB_CONST`
644 // Unfortunately, some OSes (ex. OpenBSD) ship with zlib < 1.2.5.2
645 // Workaround: cast away the const of the input, and disable the resulting -Wcast-qual warning
646 #if defined(__clang__)
647 # pragma clang diagnostic push
648 # pragma clang diagnostic ignored "-Wcast-qual"
649 #elif defined(__GNUC__)
650 # pragma GCC diagnostic push
651 # pragma GCC diagnostic ignored "-Wcast-qual"
652 #endif
653
654 // cast away the const for earlier zlib versions
655 sock->zDeflate.next_in = (Bytef *)buf; // -Wcast-qual
656
657 #if defined(__clang__)
658 # pragma clang diagnostic pop
659 #elif defined(__GNUC__)
660 # pragma GCC diagnostic pop
661 #endif
662 #else
663 // zlib >= 1.2.5.2 supports ZLIB_CONST
664 sock->zDeflate.next_in = (const Bytef *)buf;
665 #endif
666
667 sock->zDeflate.avail_in = size;
668 sock->zDeflateInSize += sock->zDeflate.avail_in;
669 do
670 {
671 size_t alreadyHave = sock->zDeflateOutBuf.size();
672 sock->zDeflateOutBuf.resize(alreadyHave + size + 20); // A bit more than size should be enough to always do everything in one go.
673 sock->zDeflate.next_out = (Bytef *)&sock->zDeflateOutBuf[alreadyHave];
674 sock->zDeflate.avail_out = sock->zDeflateOutBuf.size() - alreadyHave;
675
676 int ret = deflate(&sock->zDeflate, Z_NO_FLUSH);
677 ASSERT(ret != Z_STREAM_ERROR, "zlib compression failed!");
678
679 // Remove unused part of buffer.
680 sock->zDeflateOutBuf.resize(sock->zDeflateOutBuf.size() - sock->zDeflate.avail_out);
681 }
682 while (sock->zDeflate.avail_out == 0);
683
684 ASSERT(sock->zDeflate.avail_in == 0, "zlib didn't compress everything!");
685 }
686 }
687
688 return size;
689 }
690
socketFlush(Socket * sock,size_t * rawByteCount)691 void socketFlush(Socket *sock, size_t *rawByteCount)
692 {
693 size_t ignored;
694 size_t &rawBytes = rawByteCount != nullptr ? *rawByteCount : ignored;
695 rawBytes = 0;
696
697 if (!sock->isCompressed)
698 {
699 return; // Not compressed, so don't mess with zlib.
700 }
701
702 // Flush data out of zlib compression state.
703 do
704 {
705 sock->zDeflate.next_in = (Bytef *)nullptr;
706 sock->zDeflate.avail_in = 0;
707 size_t alreadyHave = sock->zDeflateOutBuf.size();
708 sock->zDeflateOutBuf.resize(alreadyHave + 1000); // 100 bytes would probably be enough to flush the rest in one go.
709 sock->zDeflate.next_out = (Bytef *)&sock->zDeflateOutBuf[alreadyHave];
710 sock->zDeflate.avail_out = sock->zDeflateOutBuf.size() - alreadyHave;
711
712 int ret = deflate(&sock->zDeflate, Z_PARTIAL_FLUSH);
713 ASSERT(ret != Z_STREAM_ERROR, "zlib compression failed!");
714
715 // Remove unused part of buffer.
716 sock->zDeflateOutBuf.resize(sock->zDeflateOutBuf.size() - sock->zDeflate.avail_out);
717 }
718 while (sock->zDeflate.avail_out == 0);
719
720 if (sock->zDeflateOutBuf.empty())
721 {
722 return; // No data to flush out.
723 }
724
725 wzMutexLock(socketThreadMutex);
726 if (socketThreadWrites.empty())
727 {
728 wzSemaphorePost(socketThreadSemaphore);
729 }
730 std::vector<uint8_t> &writeQueue = socketThreadWrites[sock];
731 writeQueue.insert(writeQueue.end(), sock->zDeflateOutBuf.begin(), sock->zDeflateOutBuf.end());
732 wzMutexUnlock(socketThreadMutex);
733
734 // Primitive network logging, uncomment to use.
735 //printf("Size %3u ->%3zu, buf =", sock->zDeflateInSize, sock->zDeflateOutBuf.size());
736 //for (unsigned n = 0; n < std::min<unsigned>(sock->zDeflateOutBuf.size(), 40); ++n) printf(" %02X", sock->zDeflateOutBuf[n]);
737 //printf("\n");
738
739 // Data sent, don't send again.
740 rawBytes = sock->zDeflateOutBuf.size();
741 sock->zDeflateInSize = 0;
742 sock->zDeflateOutBuf.clear();
743 }
744
socketBeginCompression(Socket * sock)745 void socketBeginCompression(Socket *sock)
746 {
747 if (sock->isCompressed)
748 {
749 return; // Nothing to do.
750 }
751
752 wzMutexLock(socketThreadMutex);
753
754 // Init deflate.
755 sock->zDeflate.zalloc = Z_NULL;
756 sock->zDeflate.zfree = Z_NULL;
757 sock->zDeflate.opaque = Z_NULL;
758 int ret = deflateInit(&sock->zDeflate, 6);
759 ASSERT(ret == Z_OK, "deflateInit failed! Sockets won't work.");
760
761 sock->zInflate.zalloc = Z_NULL;
762 sock->zInflate.zfree = Z_NULL;
763 sock->zInflate.opaque = Z_NULL;
764 sock->zInflate.avail_in = 0;
765 sock->zInflate.next_in = Z_NULL;
766 ret = inflateInit(&sock->zInflate);
767 ASSERT(ret == Z_OK, "deflateInit failed! Sockets won't work.");
768
769 sock->zInflateNeedInput = true;
770
771 sock->isCompressed = true;
772 wzMutexUnlock(socketThreadMutex);
773 }
774
~Socket()775 Socket::~Socket()
776 {
777 if (isCompressed)
778 {
779 deflateEnd(&zDeflate);
780 deflateEnd(&zInflate);
781 }
782 }
783
allocSocketSet()784 SocketSet *allocSocketSet()
785 {
786 return new SocketSet;
787 }
788
deleteSocketSet(SocketSet * set)789 void deleteSocketSet(SocketSet *set)
790 {
791 delete set;
792 }
793
794 /**
795 * Add the given socket to the given socket set.
796 *
797 * @return true if @c socket is successfully added to @set.
798 */
SocketSet_AddSocket(SocketSet * set,Socket * socket)799 void SocketSet_AddSocket(SocketSet *set, Socket *socket)
800 {
801 /* Check whether this socket is already present in this set (i.e. it
802 * shouldn't be added again).
803 */
804 size_t i = std::find(set->fds.begin(), set->fds.end(), socket) - set->fds.begin();
805 if (i != set->fds.size())
806 {
807 debug(LOG_NET, "Already found, socket: (set->fds[%lu]) %p", (unsigned long)i, static_cast<void *>(socket));
808 return;
809 }
810
811 set->fds.push_back(socket);
812 debug(LOG_NET, "Socket added: set->fds[%lu] = %p", (unsigned long)i, static_cast<void *>(socket));
813 }
814
815 /**
816 * Remove the given socket from the given socket set.
817 */
SocketSet_DelSocket(SocketSet * set,Socket * socket)818 void SocketSet_DelSocket(SocketSet *set, Socket *socket)
819 {
820 size_t i = std::find(set->fds.begin(), set->fds.end(), socket) - set->fds.begin();
821 if (i != set->fds.size())
822 {
823 debug(LOG_NET, "Socket %p erased (set->fds[%lu])", static_cast<void *>(socket), (unsigned long)i);
824 set->fds.erase(set->fds.begin() + i);
825 }
826 }
827
828 #if !defined(SOCK_CLOEXEC)
setSocketInheritable(SOCKET fd,bool inheritable)829 static bool setSocketInheritable(SOCKET fd, bool inheritable)
830 {
831 #if defined(WZ_OS_UNIX)
832 int sockopts = fcntl(fd, F_SETFD);
833 if (sockopts == SOCKET_ERROR)
834 {
835 debug(LOG_NET, "Failed to retrieve current socket options: %s", strSockError(getSockErr()));
836 return false;
837 }
838
839 // Set or clear FD_CLOEXEC flag
840 if (inheritable)
841 {
842 sockopts &= ~FD_CLOEXEC;
843 }
844 else
845 {
846 sockopts |= FD_CLOEXEC;
847 }
848
849 if (fcntl(fd, F_SETFD, sockopts) == SOCKET_ERROR)
850 {
851 debug(LOG_NET, "Failed to set socket %sinheritable: %s", (inheritable ? "" : "non-"), strSockError(getSockErr()));
852 return false;
853 }
854 #elif defined(WZ_OS_WIN)
855 DWORD dwFlags = (inheritable) ? HANDLE_FLAG_INHERIT : 0;
856 if (::SetHandleInformation((HANDLE)fd, HANDLE_FLAG_INHERIT, dwFlags) == 0)
857 {
858 DWORD dwErr = GetLastError();
859 debug(LOG_NET, "Failed to set socket %sinheritable: %s", (inheritable ? "" : "non-"), std::to_string(dwErr).c_str());
860 return false;
861 }
862 #endif
863
864 debug(LOG_NET, "Socket is set to %sinheritable.", (inheritable ? "" : "non-"));
865 return true;
866 }
867 #endif // !defined(SOCK_CLOEXEC)
868
setSocketBlocking(const SOCKET fd,bool blocking)869 static bool setSocketBlocking(const SOCKET fd, bool blocking)
870 {
871 #if defined(WZ_OS_UNIX)
872 int sockopts = fcntl(fd, F_GETFL);
873 if (sockopts == SOCKET_ERROR)
874 {
875 debug(LOG_NET, "Failed to retrieve current socket options: %s", strSockError(getSockErr()));
876 return false;
877 }
878
879 // Set or clear O_NONBLOCK flag
880 if (blocking)
881 {
882 sockopts &= ~O_NONBLOCK;
883 }
884 else
885 {
886 sockopts |= O_NONBLOCK;
887 }
888
889 if (fcntl(fd, F_SETFL, sockopts) == SOCKET_ERROR)
890 #elif defined(WZ_OS_WIN)
891 unsigned long nonblocking = !blocking;
892 if (ioctlsocket(fd, FIONBIO, &nonblocking) == SOCKET_ERROR)
893 #endif
894 {
895 debug(LOG_NET, "Failed to set socket %sblocking: %s", (blocking ? "" : "non-"), strSockError(getSockErr()));
896 return false;
897 }
898
899 debug(LOG_NET, "Socket is set to %sblocking.", (blocking ? "" : "non-"));
900 return true;
901 }
902
/**
 * Where SO_NOSIGPIPE exists, enable or disable it on @p fd so that a broken
 * connection makes send() fail instead of raising SIGPIPE. A failure is only
 * logged. On platforms without SO_NOSIGPIPE (e.g. Windows, which has no
 * SIGPIPE) this does nothing.
 */
static void socketBlockSIGPIPE(const SOCKET fd, bool block_sigpipe)
{
#if !defined(SO_NOSIGPIPE)
	// Nothing to do here; just silence unused-parameter warnings.
	(void)fd;
	(void)block_sigpipe;
#else
	const int optionValue = block_sigpipe ? 1 : 0;
	const int status = setsockopt(fd, SOL_SOCKET, SO_NOSIGPIPE, &optionValue, sizeof(optionValue));
	if (status == SOCKET_ERROR)
	{
		debug(LOG_INFO, "Failed to set SO_NOSIGPIPE on socket, SIGPIPE might be raised when connections gets broken. Error: %s", strSockError(getSockErr()));
	}
	debug(LOG_NET, "Socket fd %x sets SIGPIPE to %sblocked.", fd, (block_sigpipe ? "" : "non-"));
#endif
}
920
checkSockets(const SocketSet * set,unsigned int timeout)921 int checkSockets(const SocketSet *set, unsigned int timeout)
922 {
923 if (set->fds.empty())
924 {
925 return 0;
926 }
927
928 #if defined(WZ_OS_UNIX)
929 SOCKET maxfd = INT_MIN;
930 #elif defined(WZ_OS_WIN)
931 SOCKET maxfd = 0;
932 #endif
933
934 bool compressedReady = false;
935 for (size_t i = 0; i < set->fds.size(); ++i)
936 {
937 ASSERT(set->fds[i]->fd[SOCK_CONNECTION] != INVALID_SOCKET, "Invalid file descriptor!");
938
939 if (set->fds[i]->isCompressed && !set->fds[i]->zInflateNeedInput)
940 {
941 compressedReady = true;
942 break;
943 }
944
945 maxfd = std::max(maxfd, set->fds[i]->fd[SOCK_CONNECTION]);
946 }
947
948 if (compressedReady)
949 {
950 // A socket already has some data ready. Don't really poll the sockets.
951
952 int ret = 0;
953 for (size_t i = 0; i < set->fds.size(); ++i)
954 {
955 set->fds[i]->ready = set->fds[i]->isCompressed && !set->fds[i]->zInflateNeedInput;
956 ++ret;
957 }
958 return ret;
959 }
960
961 int ret;
962 fd_set fds;
963 do
964 {
965 struct timeval tv = {(int)(timeout / 1000), (int)(timeout % 1000) * 1000}; // Cast to int to avoid narrowing needed for C++11.
966
967 FD_ZERO(&fds);
968 for (size_t i = 0; i < set->fds.size(); ++i)
969 {
970 const SOCKET fd = set->fds[i]->fd[SOCK_CONNECTION];
971
972 FD_SET(fd, &fds);
973 }
974
975 ret = select(maxfd + 1, &fds, nullptr, nullptr, &tv);
976 }
977 while (ret == SOCKET_ERROR && getSockErr() == EINTR);
978
979 if (ret == SOCKET_ERROR)
980 {
981 debug(LOG_ERROR, "select failed: %s", strSockError(getSockErr()));
982 return SOCKET_ERROR;
983 }
984
985 for (size_t i = 0; i < set->fds.size(); ++i)
986 {
987 set->fds[i]->ready = FD_ISSET(set->fds[i]->fd[SOCK_CONNECTION], &fds);
988 }
989
990 return ret;
991 }
992
993 /**
994 * Similar to read(2) with the exception that this function won't be
995 * interrupted by signals (EINTR) and will only return when <em>exactly</em>
996 * @c size bytes have been received. I.e. this function blocks until all data
997 * has been received or a timeout occurred.
998 *
999 * @param timeout When non-zero this function times out after @c timeout
1000 * milliseconds. When zero this function blocks until success or
1001 * an error occurs.
1002 *
1003 * @c return @c size when successful, less than @c size but at least zero (0)
1004 * when the other end disconnected or a timeout occurred. Or @c SOCKET_ERROR if
1005 * an error occurred.
1006 */
readAll(Socket * sock,void * buf,size_t size,unsigned int timeout)1007 ssize_t readAll(Socket *sock, void *buf, size_t size, unsigned int timeout)
1008 {
1009 ASSERT(!sock->isCompressed, "readAll on compressed sockets not implemented.");
1010
1011 const SocketSet set = {std::vector<Socket *>(1, sock)};
1012
1013 size_t received = 0;
1014
1015 if (sock->fd[SOCK_CONNECTION] == INVALID_SOCKET)
1016 {
1017 debug(LOG_ERROR, "Invalid socket (%p), sock->fd[SOCK_CONNECTION]=%x (error: EBADF)", static_cast<void *>(sock), sock->fd[SOCK_CONNECTION]);
1018 setSockErr(EBADF);
1019 return SOCKET_ERROR;
1020 }
1021
1022 while (received < size)
1023 {
1024 ssize_t ret;
1025
1026 // If a timeout is set, wait for that amount of time for data to arrive (or abort)
1027 if (timeout)
1028 {
1029 ret = checkSockets(&set, timeout);
1030 if (ret < (ssize_t)set.fds.size()
1031 || !sock->ready)
1032 {
1033 if (ret == 0)
1034 {
1035 debug(LOG_NET, "socket (%p) has timed out.", static_cast<void *>(sock));
1036 setSockErr(ETIMEDOUT);
1037 }
1038 debug(LOG_NET, "socket (%p) error.", static_cast<void *>(sock));
1039 return SOCKET_ERROR;
1040 }
1041 }
1042
1043 ret = recv(sock->fd[SOCK_CONNECTION], &((char *)buf)[received], size - received, 0);
1044 sock->ready = false;
1045 if (ret == 0)
1046 {
1047 debug(LOG_NET, "Socket %x disconnected.", sock->fd[SOCK_CONNECTION]);
1048 sock->readDisconnected = true;
1049 setSockErr(ECONNRESET);
1050 return received;
1051 }
1052
1053 if (ret == SOCKET_ERROR)
1054 {
1055 switch (getSockErr())
1056 {
1057 case EAGAIN:
1058 #if defined(EWOULDBLOCK) && EAGAIN != EWOULDBLOCK
1059 case EWOULDBLOCK:
1060 #endif
1061 case EINTR:
1062 continue;
1063
1064 default:
1065 return SOCKET_ERROR;
1066 }
1067 }
1068
1069 received += ret;
1070 }
1071
1072 return received;
1073 }
1074
socketCloseNow(Socket * sock)1075 static void socketCloseNow(Socket *sock)
1076 {
1077 for (unsigned i = 0; i < ARRAY_SIZE(sock->fd); ++i)
1078 {
1079 if (sock->fd[i] != INVALID_SOCKET)
1080 {
1081 #if defined(WZ_OS_WIN)
1082 int err = closesocket(sock->fd[i]);
1083 #else
1084 int err = close(sock->fd[i]);
1085 #endif
1086 if (err)
1087 {
1088 debug(LOG_ERROR, "Failed to close socket %p: %s", static_cast<void *>(sock), strSockError(getSockErr()));
1089 }
1090
1091 /* Make sure that dangling pointers to this
1092 * structure don't think they've got their
1093 * hands on a valid socket.
1094 */
1095 sock->fd[i] = INVALID_SOCKET;
1096 }
1097 }
1098 delete sock;
1099 }
1100
socketClose(Socket * sock)1101 void socketClose(Socket *sock)
1102 {
1103 wzMutexLock(socketThreadMutex);
1104 //Instead of socketThreadWrites.erase(sock);, try sending the data before actually deleting.
1105 if (socketThreadWrites.find(sock) != socketThreadWrites.end())
1106 {
1107 // Wait until the data is written, then delete the socket.
1108 sock->deleteLater = true;
1109 }
1110 else
1111 {
1112 // Delete the socket.
1113 socketCloseNow(sock);
1114 }
1115 wzMutexUnlock(socketThreadMutex);
1116 }
1117
socketAccept(Socket * sock)1118 Socket *socketAccept(Socket *sock)
1119 {
1120 unsigned int i;
1121
1122 /* Search for a socket that has a pending connection on it and accept
1123 * the first one.
1124 */
1125 for (i = 0; i < ARRAY_SIZE(sock->fd); ++i)
1126 {
1127 if (sock->fd[i] != INVALID_SOCKET)
1128 {
1129 struct sockaddr_storage addr;
1130 socklen_t addr_len = sizeof(addr);
1131 Socket *conn;
1132 unsigned int j;
1133
1134 #if defined(SOCK_CLOEXEC)
1135 const SOCKET newConn = accept4(sock->fd[i], (struct sockaddr *)&addr, &addr_len, SOCK_CLOEXEC);
1136 #else
1137 const SOCKET newConn = accept(sock->fd[i], (struct sockaddr *)&addr, &addr_len);
1138 #endif
1139 if (newConn == INVALID_SOCKET)
1140 {
1141 // Ignore the case where no connection is pending
1142 if (getSockErr() != EAGAIN
1143 && getSockErr() != EWOULDBLOCK)
1144 {
1145 debug(LOG_ERROR, "accept failed for socket %p: %s", static_cast<void *>(sock), strSockError(getSockErr()));
1146 }
1147
1148 continue;
1149 }
1150
1151 conn = new Socket;
1152 if (conn == nullptr)
1153 {
1154 debug(LOG_ERROR, "Out of memory!");
1155 abort();
1156 return nullptr;
1157 }
1158
1159 #if !defined(SOCK_CLOEXEC)
1160 if (!setSocketInheritable(newConn, false))
1161 {
1162 debug(LOG_NET, "Couldn't set socket (%p) inheritable status (false). Ignoring...", static_cast<void *>(conn));
1163 // ignore and continue
1164 }
1165 #endif
1166
1167 debug(LOG_NET, "setting socket (%p) blocking status (false).", static_cast<void *>(conn));
1168 if (!setSocketBlocking(newConn, false))
1169 {
1170 debug(LOG_NET, "Couldn't set socket (%p) blocking status (false). Closing.", static_cast<void *>(conn));
1171 socketClose(conn);
1172 return nullptr;
1173 }
1174
1175 socketBlockSIGPIPE(newConn, true);
1176
1177 // Mark all unused socket handles as invalid
1178 for (j = 0; j < ARRAY_SIZE(conn->fd); ++j)
1179 {
1180 conn->fd[j] = INVALID_SOCKET;
1181 }
1182
1183 conn->fd[SOCK_CONNECTION] = newConn;
1184
1185 sock->ready = false;
1186
1187 addressToText((const struct sockaddr *)&addr, conn->textAddress, sizeof(conn->textAddress));
1188 debug(LOG_NET, "Incoming connection from [%s]:/*%%d*/ (FIXME: gives strict-aliasing error)", conn->textAddress/*, (unsigned int)ntohs(((const struct sockaddr_in*)&addr)->sin_port)*/);
1189 debug(LOG_NET, "Using socket %p", static_cast<void *>(conn));
1190 return conn;
1191 }
1192 }
1193
1194 return nullptr;
1195 }
1196
socketOpen(const SocketAddress * addr,unsigned timeout)1197 Socket *socketOpen(const SocketAddress *addr, unsigned timeout)
1198 {
1199 unsigned int i;
1200 int ret;
1201
1202 Socket *const conn = new Socket;
1203 if (conn == nullptr)
1204 {
1205 debug(LOG_ERROR, "Out of memory!");
1206 abort();
1207 return nullptr;
1208 }
1209
1210 ASSERT(addr != nullptr, "NULL Socket provided");
1211
1212 addressToText(addr->ai_addr, conn->textAddress, sizeof(conn->textAddress));
1213 #if defined(__GNUC__) && !defined(__INTEL_COMPILER) && !defined(__clang__)
1214 # pragma GCC diagnostic push
1215 # pragma GCC diagnostic ignored "-Wcast-align"
1216 #endif
1217 debug(LOG_NET, "Connecting to [%s]:%d", conn->textAddress, (int)ntohs((reinterpret_cast<const sockaddr_in *>(addr->ai_addr))->sin_port));
1218 #if defined(__GNUC__) && !defined(__INTEL_COMPILER) && !defined(__clang__)
1219 # pragma GCC diagnostic pop
1220 #endif
1221
1222 // Mark all unused socket handles as invalid
1223 for (i = 0; i < ARRAY_SIZE(conn->fd); ++i)
1224 {
1225 conn->fd[i] = INVALID_SOCKET;
1226 }
1227
1228 int socket_type = addr->ai_socktype;
1229 #if defined(SOCK_CLOEXEC)
1230 socket_type |= SOCK_CLOEXEC;
1231 #endif
1232 conn->fd[SOCK_CONNECTION] = socket(addr->ai_family, socket_type, addr->ai_protocol);
1233
1234 if (conn->fd[SOCK_CONNECTION] == INVALID_SOCKET)
1235 {
1236 debug(LOG_ERROR, "Failed to create a socket (%p): %s", static_cast<void *>(conn), strSockError(getSockErr()));
1237 socketClose(conn);
1238 return nullptr;
1239 }
1240
1241 #if !defined(SOCK_CLOEXEC)
1242 if (!setSocketInheritable(conn->fd[SOCK_CONNECTION], false))
1243 {
1244 debug(LOG_NET, "Couldn't set socket (%p) inheritable status (false). Ignoring...", static_cast<void *>(conn));
1245 // ignore and continue
1246 }
1247 #endif
1248
1249 debug(LOG_NET, "setting socket (%p) blocking status (false).", static_cast<void *>(conn));
1250 if (!setSocketBlocking(conn->fd[SOCK_CONNECTION], false))
1251 {
1252 debug(LOG_NET, "Couldn't set socket (%p) blocking status (false). Closing.", static_cast<void *>(conn));
1253 socketClose(conn);
1254 return nullptr;
1255 }
1256
1257 socketBlockSIGPIPE(conn->fd[SOCK_CONNECTION], true);
1258
1259 ret = connect(conn->fd[SOCK_CONNECTION], addr->ai_addr, addr->ai_addrlen);
1260 if (ret == SOCKET_ERROR)
1261 {
1262 fd_set conReady;
1263 #if defined(WZ_OS_WIN)
1264 fd_set conFailed;
1265 #endif
1266
1267 if ((getSockErr() != EINPROGRESS
1268 && getSockErr() != EAGAIN
1269 && getSockErr() != EWOULDBLOCK)
1270 #if defined(WZ_OS_UNIX)
1271 || conn->fd[SOCK_CONNECTION] >= FD_SETSIZE
1272 #endif
1273 || timeout == 0)
1274 {
1275 debug(LOG_NET, "Failed to start connecting: %s, using socket %p", strSockError(getSockErr()), static_cast<void *>(conn));
1276 socketClose(conn);
1277 return nullptr;
1278 }
1279
1280 do
1281 {
1282 struct timeval tv = {(int)(timeout / 1000), (int)(timeout % 1000) * 1000}; // Cast to int to avoid narrowing needed for C++11.
1283
1284 FD_ZERO(&conReady);
1285 FD_SET(conn->fd[SOCK_CONNECTION], &conReady);
1286 #if defined(WZ_OS_WIN)
1287 FD_ZERO(&conFailed);
1288 FD_SET(conn->fd[SOCK_CONNECTION], &conFailed);
1289 #endif
1290
1291 #if defined(WZ_OS_WIN)
1292 ret = select(conn->fd[SOCK_CONNECTION] + 1, NULL, &conReady, &conFailed, &tv);
1293 #else
1294 ret = select(conn->fd[SOCK_CONNECTION] + 1, nullptr, &conReady, nullptr, &tv);
1295 #endif
1296 }
1297 while (ret == SOCKET_ERROR && getSockErr() == EINTR);
1298
1299 if (ret == SOCKET_ERROR)
1300 {
1301 debug(LOG_NET, "Failed to wait for connection: %s, socket %p. Closing.", strSockError(getSockErr()), static_cast<void *>(conn));
1302 socketClose(conn);
1303 return nullptr;
1304 }
1305
1306 if (ret == 0)
1307 {
1308 setSockErr(ETIMEDOUT);
1309 debug(LOG_NET, "Timed out while waiting for connection to be established: %s, using socket %p. Closing.", strSockError(getSockErr()), static_cast<void *>(conn));
1310 socketClose(conn);
1311 return nullptr;
1312 }
1313
1314 #if defined(WZ_OS_WIN)
1315 ASSERT(FD_ISSET(conn->fd[SOCK_CONNECTION], &conReady) || FD_ISSET(conn->fd[SOCK_CONNECTION], &conFailed), "\"sock\" is the only file descriptor in set, it should be the one that is set.");
1316 #else
1317 ASSERT(FD_ISSET(conn->fd[SOCK_CONNECTION], &conReady), "\"sock\" is the only file descriptor in set, it should be the one that is set.");
1318 #endif
1319
1320 #if defined(WZ_OS_WIN)
1321 if (FD_ISSET(conn->fd[SOCK_CONNECTION], &conFailed))
1322 #elif defined(WZ_OS_UNIX)
1323 if (connect(conn->fd[SOCK_CONNECTION], addr->ai_addr, addr->ai_addrlen) == SOCKET_ERROR
1324 && getSockErr() != EISCONN)
1325 #endif
1326 {
1327 debug(LOG_NET, "Failed to connect: %s, with socket %p. Closing.", strSockError(getSockErr()), static_cast<void *>(conn));
1328 socketClose(conn);
1329 return nullptr;
1330 }
1331 }
1332
1333 return conn;
1334 }
1335
socketListen(unsigned int port)1336 Socket *socketListen(unsigned int port)
1337 {
1338 /* Enable the V4 to V6 mapping, but only when available, because it
1339 * isn't available on all platforms.
1340 */
1341 #if defined(IPV6_V6ONLY)
1342 static const int ipv6_v6only = 0;
1343 #endif
1344 static const int so_reuseaddr = 1;
1345
1346 struct sockaddr_in addr4;
1347 struct sockaddr_in6 addr6;
1348 unsigned int i;
1349
1350 Socket *const conn = new Socket;
1351 if (conn == nullptr)
1352 {
1353 debug(LOG_ERROR, "Out of memory!");
1354 abort();
1355 return nullptr;
1356 }
1357
1358 // Mark all unused socket handles as invalid
1359 for (i = 0; i < ARRAY_SIZE(conn->fd); ++i)
1360 {
1361 conn->fd[i] = INVALID_SOCKET;
1362 }
1363
1364 strncpy(conn->textAddress, "LISTENING SOCKET", sizeof(conn->textAddress));
1365
1366 // Listen on all local IPv4 and IPv6 addresses for the given port
1367 addr4.sin_family = AF_INET;
1368 addr4.sin_port = htons(port);
1369 addr4.sin_addr.s_addr = INADDR_ANY;
1370
1371 addr6.sin6_family = AF_INET6;
1372 addr6.sin6_port = htons(port);
1373 addr6.sin6_addr = in6addr_any;
1374 addr6.sin6_flowinfo = 0;
1375 addr6.sin6_scope_id = 0;
1376
1377 int socket_type = SOCK_STREAM;
1378 #if defined(SOCK_CLOEXEC)
1379 socket_type |= SOCK_CLOEXEC;
1380 #endif
1381 conn->fd[SOCK_IPV4_LISTEN] = socket(addr4.sin_family, socket_type, 0);
1382 conn->fd[SOCK_IPV6_LISTEN] = socket(addr6.sin6_family, socket_type, 0);
1383
1384 if (conn->fd[SOCK_IPV4_LISTEN] == INVALID_SOCKET
1385 && conn->fd[SOCK_IPV6_LISTEN] == INVALID_SOCKET)
1386 {
1387 debug(LOG_ERROR, "Failed to create an IPv4 and IPv6 (only supported address families) socket (%p): %s. Closing.", static_cast<void *>(conn), strSockError(getSockErr()));
1388 socketClose(conn);
1389 return nullptr;
1390 }
1391
1392 if (conn->fd[SOCK_IPV4_LISTEN] != INVALID_SOCKET)
1393 {
1394 debug(LOG_NET, "Successfully created an IPv4 socket (%p)", static_cast<void *>(conn));
1395 }
1396
1397 if (conn->fd[SOCK_IPV6_LISTEN] != INVALID_SOCKET)
1398 {
1399 debug(LOG_NET, "Successfully created an IPv6 socket (%p)", static_cast<void *>(conn));
1400 }
1401
1402 #if defined(IPV6_V6ONLY)
1403 if (conn->fd[SOCK_IPV6_LISTEN] != INVALID_SOCKET)
1404 {
1405 if (setsockopt(conn->fd[SOCK_IPV6_LISTEN], IPPROTO_IPV6, IPV6_V6ONLY, (const char *)&ipv6_v6only, sizeof(ipv6_v6only)) == SOCKET_ERROR)
1406 {
1407 debug(LOG_INFO, "Failed to set IPv6 socket to perform IPv4 to IPv6 mapping. Falling back to using two sockets. Error: %s", strSockError(getSockErr()));
1408 }
1409 else
1410 {
1411 debug(LOG_NET, "Successfully enabled IPv4 to IPv6 mapping. Cleaning up IPv4 socket.");
1412 #if defined(WZ_OS_WIN)
1413 closesocket(conn->fd[SOCK_IPV4_LISTEN]);
1414 #else
1415 close(conn->fd[SOCK_IPV4_LISTEN]);
1416 #endif
1417 conn->fd[SOCK_IPV4_LISTEN] = INVALID_SOCKET;
1418 }
1419 }
1420 #endif
1421
1422 if (conn->fd[SOCK_IPV4_LISTEN] != INVALID_SOCKET)
1423 {
1424 #if !defined(SOCK_CLOEXEC)
1425 if (!setSocketInheritable(conn->fd[SOCK_IPV4_LISTEN], false))
1426 {
1427 debug(LOG_NET, "Couldn't set socket (%p) inheritable status (false). Ignoring...", static_cast<void *>(conn));
1428 // ignore and continue
1429 }
1430 #endif
1431
1432 if (setsockopt(conn->fd[SOCK_IPV4_LISTEN], SOL_SOCKET, SO_REUSEADDR, (const char *)&so_reuseaddr, sizeof(so_reuseaddr)) == SOCKET_ERROR)
1433 {
1434 debug(LOG_WARNING, "Failed to set SO_REUSEADDR on IPv4 socket. Error: %s", strSockError(getSockErr()));
1435 }
1436
1437 debug(LOG_NET, "setting socket (%p) blocking status (false, IPv4).", static_cast<void *>(conn));
1438 if (bind(conn->fd[SOCK_IPV4_LISTEN], (const struct sockaddr *)&addr4, sizeof(addr4)) == SOCKET_ERROR
1439 || listen(conn->fd[SOCK_IPV4_LISTEN], 5) == SOCKET_ERROR
1440 || !setSocketBlocking(conn->fd[SOCK_IPV4_LISTEN], false))
1441 {
1442 debug(LOG_ERROR, "Failed to set up IPv4 socket for listening on port %u: %s", port, strSockError(getSockErr()));
1443 #if defined(WZ_OS_WIN)
1444 closesocket(conn->fd[SOCK_IPV4_LISTEN]);
1445 #else
1446 close(conn->fd[SOCK_IPV4_LISTEN]);
1447 #endif
1448 conn->fd[SOCK_IPV4_LISTEN] = INVALID_SOCKET;
1449 }
1450 }
1451
1452 if (conn->fd[SOCK_IPV6_LISTEN] != INVALID_SOCKET)
1453 {
1454 #if !defined(SOCK_CLOEXEC)
1455 if (!setSocketInheritable(conn->fd[SOCK_IPV6_LISTEN], false))
1456 {
1457 debug(LOG_NET, "Couldn't set socket (%p) inheritable status (false). Ignoring...", static_cast<void *>(conn));
1458 // ignore and continue
1459 }
1460 #endif
1461
1462 if (setsockopt(conn->fd[SOCK_IPV6_LISTEN], SOL_SOCKET, SO_REUSEADDR, (const char *)&so_reuseaddr, sizeof(so_reuseaddr)) == SOCKET_ERROR)
1463 {
1464 debug(LOG_INFO, "Failed to set SO_REUSEADDR on IPv6 socket. Error: %s", strSockError(getSockErr()));
1465 }
1466
1467 debug(LOG_NET, "setting socket (%p) blocking status (false, IPv6).", static_cast<void *>(conn));
1468 if (bind(conn->fd[SOCK_IPV6_LISTEN], (const struct sockaddr *)&addr6, sizeof(addr6)) == SOCKET_ERROR
1469 || listen(conn->fd[SOCK_IPV6_LISTEN], 5) == SOCKET_ERROR
1470 || !setSocketBlocking(conn->fd[SOCK_IPV6_LISTEN], false))
1471 {
1472 debug(LOG_ERROR, "Failed to set up IPv6 socket for listening on port %u: %s", port, strSockError(getSockErr()));
1473 #if defined(WZ_OS_WIN)
1474 closesocket(conn->fd[SOCK_IPV6_LISTEN]);
1475 #else
1476 close(conn->fd[SOCK_IPV6_LISTEN]);
1477 #endif
1478 conn->fd[SOCK_IPV6_LISTEN] = INVALID_SOCKET;
1479 }
1480 }
1481
1482 // Check whether we still have at least a single (operating) socket.
1483 if (conn->fd[SOCK_IPV4_LISTEN] == INVALID_SOCKET
1484 && conn->fd[SOCK_IPV6_LISTEN] == INVALID_SOCKET)
1485 {
1486 debug(LOG_NET, "No IPv4 or IPv6 sockets created.");
1487 socketClose(conn);
1488 return nullptr;
1489 }
1490
1491 return conn;
1492 }
1493
socketOpenAny(const SocketAddress * addr,unsigned timeout)1494 Socket *socketOpenAny(const SocketAddress *addr, unsigned timeout)
1495 {
1496 Socket *ret = nullptr;
1497 while (addr != nullptr && ret == nullptr)
1498 {
1499 ret = socketOpen(addr, timeout);
1500
1501 addr = addr->ai_next;
1502 }
1503
1504 return ret;
1505 }
1506
socketArrayOpen(Socket ** sockets,size_t maxSockets,const SocketAddress * addr,unsigned timeout)1507 size_t socketArrayOpen(Socket **sockets, size_t maxSockets, const SocketAddress *addr, unsigned timeout)
1508 {
1509 size_t i = 0;
1510 while (i < maxSockets && addr != nullptr)
1511 {
1512 if (addr->ai_family == AF_INET || addr->ai_family == AF_INET6)
1513 {
1514 sockets[i] = socketOpen(addr, timeout);
1515 i += sockets[i] != nullptr;
1516 }
1517
1518 addr = addr->ai_next;
1519 }
1520 std::fill(sockets + i, sockets + maxSockets, (Socket *)nullptr);
1521 return i;
1522 }
1523
socketArrayClose(Socket ** sockets,size_t maxSockets)1524 void socketArrayClose(Socket **sockets, size_t maxSockets)
1525 {
1526 std::for_each(sockets, sockets + maxSockets, socketClose); // Close any open sockets.
1527 std::fill(sockets, sockets + maxSockets, (Socket *)nullptr); // Set the pointers to NULL.
1528 }
1529
socketHasIPv4(Socket * sock)1530 WZ_DECL_NONNULL(1) bool socketHasIPv4(Socket *sock)
1531 {
1532 if (sock->fd[SOCK_IPV4_LISTEN] != INVALID_SOCKET)
1533 {
1534 return true;
1535 }
1536 else
1537 {
1538 #if defined(IPV6_V6ONLY)
1539 if (sock->fd[SOCK_IPV6_LISTEN] != INVALID_SOCKET)
1540 {
1541 int ipv6_v6only = 1;
1542 socklen_t len = sizeof(ipv6_v6only);
1543 if (getsockopt(sock->fd[SOCK_IPV6_LISTEN], IPPROTO_IPV6, IPV6_V6ONLY, (char *)&ipv6_v6only, &len) == 0)
1544 {
1545 return ipv6_v6only == 0;
1546 }
1547 }
1548 #endif
1549 return false;
1550 }
1551 }
1552
socketHasIPv6(Socket * sock)1553 WZ_DECL_NONNULL(1) bool socketHasIPv6(Socket *sock)
1554 {
1555 return sock->fd[SOCK_IPV6_LISTEN] != INVALID_SOCKET;
1556 }
1557
getSocketTextAddress(Socket const * sock)1558 char const *getSocketTextAddress(Socket const *sock)
1559 {
1560 return sock->textAddress;
1561 }
1562
ipv4_AddressString_To_NetBinary(const std::string & ipv4Address)1563 std::vector<unsigned char> ipv4_AddressString_To_NetBinary(const std::string& ipv4Address)
1564 {
1565 std::vector<unsigned char> binaryForm(sizeof(struct in_addr), 0);
1566 if (inet_pton(AF_INET, ipv4Address.c_str(), binaryForm.data()) <= 0)
1567 {
1568 // inet_pton failed
1569 binaryForm.clear();
1570 }
1571 return binaryForm;
1572 }
1573
#ifndef INET_ADDRSTRLEN
# define INET_ADDRSTRLEN 16
#endif

/**
 * Convert a 4-byte network-order IPv4 address to dotted-quad text using
 * inet_ntop().
 *
 * @return the textual address, or an empty string if the input is not exactly
 *         sizeof(struct in_addr) bytes or the conversion fails.
 */
std::string ipv4_NetBinary_To_AddressString(const std::vector<unsigned char>& ip4NetBinaryForm)
{
	std::string text;
	if (ip4NetBinaryForm.size() == sizeof(struct in_addr))
	{
		char textBuffer[INET_ADDRSTRLEN] = {};
		if (inet_ntop(AF_INET, ip4NetBinaryForm.data(), textBuffer, sizeof(textBuffer)) != nullptr)
		{
			text.assign(textBuffer);
		}
	}
	return text;
}
1593
ipv6_AddressString_To_NetBinary(const std::string & ipv6Address)1594 std::vector<unsigned char> ipv6_AddressString_To_NetBinary(const std::string& ipv6Address)
1595 {
1596 std::vector<unsigned char> binaryForm(sizeof(struct in6_addr), 0);
1597 if (inet_pton(AF_INET6, ipv6Address.c_str(), binaryForm.data()) <= 0)
1598 {
1599 // inet_pton failed
1600 binaryForm.clear();
1601 }
1602 return binaryForm;
1603 }
1604
#ifndef INET6_ADDRSTRLEN
# define INET6_ADDRSTRLEN 46
#endif

/**
 * Convert a 16-byte network-order IPv6 address to text using inet_ntop().
 *
 * @return the textual address, or an empty string if the input is not exactly
 *         sizeof(struct in6_addr) bytes or the conversion fails.
 */
std::string ipv6_NetBinary_To_AddressString(const std::vector<unsigned char>& ip6NetBinaryForm)
{
	std::string text;
	if (ip6NetBinaryForm.size() == sizeof(struct in6_addr))
	{
		char textBuffer[INET6_ADDRSTRLEN] = {};
		if (inet_ntop(AF_INET6, ip6NetBinaryForm.data(), textBuffer, sizeof(textBuffer)) != nullptr)
		{
			text.assign(textBuffer);
		}
	}
	return text;
}
1624
resolveHost(const char * host,unsigned int port)1625 SocketAddress *resolveHost(const char *host, unsigned int port)
1626 {
1627 struct addrinfo *results;
1628 std::string service;
1629 struct addrinfo hint;
1630 int error, flags = 0;
1631
1632 hint.ai_family = AF_UNSPEC;
1633 hint.ai_socktype = SOCK_STREAM;
1634 hint.ai_protocol = 0;
1635 #ifdef AI_V4MAPPED
1636 flags |= AI_V4MAPPED;
1637 #endif
1638 #ifdef AI_ADDRCONFIG
1639 flags |= AI_ADDRCONFIG;
1640 #endif
1641 hint.ai_flags = flags;
1642 hint.ai_addrlen = 0;
1643 hint.ai_addr = nullptr;
1644 hint.ai_canonname = nullptr;
1645 hint.ai_next = nullptr;
1646
1647 service = astringf("%u", port);
1648
1649 error = getaddrinfo(host, service.c_str(), &hint, &results);
1650 if (error != 0)
1651 {
1652 debug(LOG_NET, "getaddrinfo failed for %s:%s: %s", host, service.c_str(), gai_strerror(error));
1653 return nullptr;
1654 }
1655
1656 return results;
1657 }
1658
deleteSocketAddress(SocketAddress * addr)1659 void deleteSocketAddress(SocketAddress *addr)
1660 {
1661 freeaddrinfo(addr);
1662 }
1663
1664 // ////////////////////////////////////////////////////////////////////////
1665 // setup stuff
SOCKETinit()1666 void SOCKETinit()
1667 {
1668 #if defined(WZ_OS_WIN)
1669 static bool firstCall = true;
1670 if (firstCall)
1671 {
1672 firstCall = false;
1673
1674 static WSADATA stuff;
1675 WORD ver_required = (2 << 8) + 2;
1676 if (WSAStartup(ver_required, &stuff) != 0)
1677 {
1678 debug(LOG_ERROR, "Failed to initialize Winsock: %s", strSockError(getSockErr()));
1679 return;
1680 }
1681
1682 winsock2_dll = LoadLibraryA("ws2_32.dll");
1683 if (winsock2_dll)
1684 {
1685 getaddrinfo_dll_func = reinterpret_cast<GETADDRINFO_DLL_FUNC>(reinterpret_cast<void*>(GetProcAddress(winsock2_dll, "getaddrinfo")));
1686 freeaddrinfo_dll_func = reinterpret_cast<FREEADDRINFO_DLL_FUNC>(reinterpret_cast<void*>(GetProcAddress(winsock2_dll, "freeaddrinfo")));
1687 }
1688 }
1689 #endif
1690
1691 if (socketThread == nullptr)
1692 {
1693 socketThreadQuit = false;
1694 socketThreadMutex = wzMutexCreate();
1695 socketThreadSemaphore = wzSemaphoreCreate(0);
1696 socketThread = wzThreadCreate(socketThreadFunction, nullptr);
1697 wzThreadStart(socketThread);
1698 }
1699 }
1700
SOCKETshutdown()1701 void SOCKETshutdown()
1702 {
1703 if (socketThread != nullptr)
1704 {
1705 wzMutexLock(socketThreadMutex);
1706 socketThreadQuit = true;
1707 socketThreadWrites.clear();
1708 wzMutexUnlock(socketThreadMutex);
1709 wzSemaphorePost(socketThreadSemaphore); // Wake up the thread, so it can quit.
1710 wzThreadJoin(socketThread);
1711 wzMutexDestroy(socketThreadMutex);
1712 wzSemaphoreDestroy(socketThreadSemaphore);
1713 socketThread = nullptr;
1714 }
1715
1716 #if defined(WZ_OS_WIN)
1717 WSACleanup();
1718
1719 if (winsock2_dll)
1720 {
1721 FreeLibrary(winsock2_dll);
1722 winsock2_dll = NULL;
1723 getaddrinfo_dll_func = NULL;
1724 freeaddrinfo_dll_func = NULL;
1725 }
1726 #endif
1727 }
1728