1 /*
2 	This file is part of Warzone 2100.
3 	Copyright (C) 1999-2004  Eidos Interactive
4 	Copyright (C) 2005-2020  Warzone 2100 Project
5 
6 	Warzone 2100 is free software; you can redistribute it and/or modify
7 	it under the terms of the GNU General Public License as published by
8 	the Free Software Foundation; either version 2 of the License, or
9 	(at your option) any later version.
10 
11 	Warzone 2100 is distributed in the hope that it will be useful,
12 	but WITHOUT ANY WARRANTY; without even the implied warranty of
13 	MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 	GNU General Public License for more details.
15 
16 	You should have received a copy of the GNU General Public License
17 	along with Warzone 2100; if not, write to the Free Software
18 	Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
19 */
20 /**
21  * @file netsocket.cpp
22  *
23  * Basic raw socket handling code.
24  */
25 
26 #include "lib/framework/frame.h"
27 #include "lib/framework/wzapp.h"
28 #include "netsocket.h"
29 
30 #include <vector>
31 #include <algorithm>
32 #include <map>
33 
34 #if !defined(ZLIB_CONST)
35 #  define ZLIB_CONST
36 #endif
37 #include <zlib.h>
38 
/// Indices into the Socket::fd handle array.
enum
{
	/// Slot used by a connected (non-listening) socket.
	SOCK_CONNECTION,
	/// Slot used for the IPv4 listening handle; deliberately the same index as SOCK_CONNECTION.
	SOCK_IPV4_LISTEN = SOCK_CONNECTION,
	/// Slot used for the IPv6 listening handle.
	SOCK_IPV6_LISTEN,
	/// Number of slots; used as the size of Socket::fd.
	SOCK_COUNT,
};
46 
47 struct Socket
48 {
49 	/* Multiple socket handles only for listening sockets. This allows us
50 	 * to listen on multiple protocols and address families (e.g. IPv4 and
51 	 * IPv6).
52 	 *
53 	 * All non-listening sockets will only use the first socket handle.
54 	 */
SocketSocket55 	Socket() : ready(false), writeError(false), deleteLater(false), isCompressed(false), readDisconnected(false), zDeflateInSize(0)
56 	{
57 		memset(&zDeflate, 0, sizeof(zDeflate));
58 		memset(&zInflate, 0, sizeof(zInflate));
59 	}
60 	~Socket();
61 
62 	SOCKET fd[SOCK_COUNT];
63 	bool ready;
64 	bool writeError;
65 	bool deleteLater;
66 	char textAddress[40];
67 
68 	bool isCompressed;
69 	bool readDisconnected;  ///< True iff a call to recv() returned 0.
70 	z_stream zDeflate;
71 	z_stream zInflate;
72 	unsigned zDeflateInSize;
73 	bool zInflateNeedInput;
74 	std::vector<uint8_t> zDeflateOutBuf;
75 	std::vector<uint8_t> zInflateInBuf;
76 };
77 
/// A group of sockets that are polled together by checkSockets().
/// Only pointers are stored; this struct has no destructor and deletes nothing.
struct SocketSet
{
	std::vector<Socket *> fds;
};
82 
83 
// Mutex held by the code in this file while it touches socketThreadWrites.
static WZ_MUTEX *socketThreadMutex;
// Posted by writers when they add data to an empty socketThreadWrites; the
// socket thread waits on it when it has nothing left to send.
static WZ_SEMAPHORE *socketThreadSemaphore;
// Handle of the background writer thread (presumably running
// socketThreadFunction; its creation is not in this part of the file).
static WZ_THREAD *socketThread = nullptr;
// socketThreadFunction leaves its main loop once this becomes true.
static bool socketThreadQuit;
// Bytes waiting to be sent, keyed by the socket they belong to.
typedef std::map<Socket *, std::vector<uint8_t>> SocketThreadWriteMap;
static SocketThreadWriteMap socketThreadWrites;


// Defined later in the file; used by the socket thread for sockets flagged deleteLater.
static void socketCloseNow(Socket *sock);
93 
94 
socketReadReady(Socket const * sock)95 bool socketReadReady(Socket const *sock)
96 {
97 	return sock->ready;
98 }
99 
/// Returns the most recent socket error code: @c errno on Unix,
/// WSAGetLastError() on Windows.
int getSockErr(void)
{
#if   defined(WZ_OS_UNIX)
	const int lastError = errno;
	return lastError;
#elif defined(WZ_OS_WIN)
	const int lastError = WSAGetLastError();
	return lastError;
#endif
}
108 
/// Stores @c error as the current socket error code: into @c errno on Unix,
/// via WSASetLastError() on Windows.
void setSockErr(int error)
{
#if   defined(WZ_OS_UNIX)
	const int newError = error;
	errno = newError;
#elif defined(WZ_OS_WIN)
	const int newError = error;
	WSASetLastError(newError);
#endif
}
117 
118 #if defined(WZ_OS_WIN)
// Signatures of the name-resolution entry points that are looked up in the
// winsock2 DLL at run time.
typedef int (WINAPI *GETADDRINFO_DLL_FUNC)(const char *node, const char *service,
        const struct addrinfo *hints,
        struct addrinfo **res);
// NOTE(review): declared as returning int although the wrapper below is void
// and ignores the result - confirm against the Winsock prototype.
typedef int (WINAPI *FREEADDRINFO_DLL_FUNC)(struct addrinfo *res);

// Handle of the winsock2 DLL; stays null until it is loaded (the loading code
// is not in this part of the file).
static HMODULE winsock2_dll = nullptr;

// Function pointers resolved from winsock2_dll; null until resolved.
static GETADDRINFO_DLL_FUNC getaddrinfo_dll_func = nullptr;
static FREEADDRINFO_DLL_FUNC freeaddrinfo_dll_func = nullptr;

// From here on, every use of getaddrinfo/freeaddrinfo in this file refers to
// the static dispatcher functions defined below.
# define getaddrinfo  getaddrinfo_dll_dispatcher
# define freeaddrinfo freeaddrinfo_dll_dispatcher
131 
132 # include <ntverp.h>				// Windows SDK - include for access to VER_PRODUCTBUILD
133 # if VER_PRODUCTBUILD >= 9600
134 	// 9600 is the Windows SDK 8.1
135 	# include <VersionHelpers.h>	// For IsWindowsVistaOrGreater()
136 # else
137 	// Earlier SDKs may not have VersionHelpers.h - use simple fallback
IsWindowsVistaOrGreater()138 	inline bool IsWindowsVistaOrGreater()
139 	{
140 		DWORD dwMajorVersion = (DWORD)(LOBYTE(LOWORD(GetVersion())));
141 		return dwMajorVersion >= 6;
142 	}
143 # endif
144 
/**
 * Forwards a getaddrinfo() request to the function resolved from the winsock2
 * DLL.
 *
 * The caller's hints are copied, and on systems older than Windows Vista the
 * AI_V4MAPPED and AI_ADDRCONFIG flags are removed from the copy.
 *
 * @return the DLL function's result, or EAI_FAIL when the DLL or the function
 *         pointer is not available.
 */
static int getaddrinfo(const char *node, const char *service,
                       const struct addrinfo *hints,
                       struct addrinfo **res)
{
	// Work on a private copy so the caller's structure is never modified.
	struct addrinfo adjustedHints;
	struct addrinfo *hintsToPass = nullptr;
	if (hints != nullptr)
	{
		adjustedHints = *hints;
		if (!IsWindowsVistaOrGreater())
		{
			// Windows 2000, XP and Server 2003:
			// these flags are only supported from Windows Vista+
			adjustedHints.ai_flags &= ~(AI_V4MAPPED | AI_ADDRCONFIG);
		}
		hintsToPass = &adjustedHints;
	}

	if (winsock2_dll == nullptr)
	{
		debug(LOG_ERROR, "Failed to load winsock2 DLL. Required for name resolution.");
		return EAI_FAIL;
	}

	if (getaddrinfo_dll_func == nullptr)
	{
		debug(LOG_ERROR, "Failed to retrieve \"getaddrinfo\" function from winsock2 DLL. Required for name resolution.");
		return EAI_FAIL;
	}

	return getaddrinfo_dll_func(node, service, hintsToPass, res);
}
183 
freeaddrinfo(struct addrinfo * res)184 static void freeaddrinfo(struct addrinfo *res)
185 {
186 
187 //	// Windows 95, 98 and ME
188 //		debug(LOG_ERROR, "Name resolution isn't supported on this version of Windows");
189 //		return;
190 
191 	if (!winsock2_dll)
192 	{
193 		debug(LOG_ERROR, "Failed to load winsock2 DLL. Required for name resolution.");
194 		return;
195 	}
196 
197 	if (!freeaddrinfo_dll_func)
198 	{
199 		debug(LOG_ERROR, "Failed to retrieve \"freeaddrinfo\" function from winsock2 DLL. Required for name resolution.");
200 		return;
201 	}
202 
203 	freeaddrinfo_dll_func(res);
204 }
205 #endif
206 
addressToText(const struct sockaddr * addr,char * buf,size_t size)207 static int addressToText(const struct sockaddr *addr, char *buf, size_t size)
208 {
209 	auto handleIpv4 = [&](uint32_t addr) {
210 		uint32_t val = ntohl(addr);
211 		return snprintf(buf, size, "%u.%u.%u.%u", (val>>24)&0xFF, (val>>16)&0xFF, (val>>8)&0xFF, val&0xFF);
212 	};
213 
214 	switch (addr->sa_family)
215 	{
216 	case AF_INET:
217 		{
218 #if defined(__GNUC__) && !defined(__INTEL_COMPILER) && !defined(__clang__)
219 # pragma GCC diagnostic push
220 # pragma GCC diagnostic ignored "-Wcast-align"
221 #endif
222 			return handleIpv4((reinterpret_cast<const sockaddr_in *>(addr))->sin_addr.s_addr);
223 #if defined(__GNUC__) && !defined(__INTEL_COMPILER) && !defined(__clang__)
224 # pragma GCC diagnostic pop
225 #endif
226 		}
227 	case AF_INET6:
228 		{
229 
230 			// Check to see if this is really a IPv6 address
231 #if defined(__GNUC__) && !defined(__INTEL_COMPILER) && !defined(__clang__)
232 # pragma GCC diagnostic push
233 # pragma GCC diagnostic ignored "-Wcast-align"
234 #endif
235 			const struct sockaddr_in6 *mappedIP = reinterpret_cast<const sockaddr_in6 *>(addr);
236 #if defined(__GNUC__) && !defined(__INTEL_COMPILER) && !defined(__clang__)
237 # pragma GCC diagnostic pop
238 #endif
239 			if (IN6_IS_ADDR_V4MAPPED(&mappedIP->sin6_addr))
240 			{
241 				// looks like it is ::ffff:(...) so lets set up a IPv4 socket address structure
242 				// slightly overkill for our needs, but it shows exactly what needs to be done.
243 				// At this time, we only care about the address, nothing else.
244 				struct sockaddr_in addr4;
245 				memcpy(&addr4.sin_addr.s_addr, mappedIP->sin6_addr.s6_addr + 12, sizeof(addr4.sin_addr.s_addr));
246 				return handleIpv4(addr4.sin_addr.s_addr);
247 			}
248 			else
249 			{
250 				static_assert(sizeof(in6_addr::s6_addr) == 16, "Standard expects in6_addr structure that contains member s6_addr[16], a 16-element array of uint8_t");
251 #if defined(__GNUC__) && !defined(__INTEL_COMPILER) && !defined(__clang__)
252 # pragma GCC diagnostic push
253 # pragma GCC diagnostic ignored "-Wcast-align"
254 #endif
255 				const uint8_t *address_u8 = &((reinterpret_cast<const sockaddr_in6 *>(addr))->sin6_addr.s6_addr[0]);
256 #if defined(__GNUC__) && !defined(__INTEL_COMPILER) && !defined(__clang__)
257 # pragma GCC diagnostic pop
258 #endif
259 				uint16_t address[8] = {0};
260 				memcpy(&address, address_u8, sizeof(uint8_t) * 16);
261 				return snprintf(buf, size,
262 								"%hx:%hx:%hx:%hx:%hx:%hx:%hx:%hx",
263 								ntohs(address[0]),
264 								ntohs(address[1]),
265 								ntohs(address[2]),
266 								ntohs(address[3]),
267 								ntohs(address[4]),
268 								ntohs(address[5]),
269 								ntohs(address[6]),
270 								ntohs(address[7]));
271 			}
272 		}
273 	default:
274 		ASSERT(!"Unknown address family", "Got non IPv4 or IPv6 address!");
275 		return -1;
276 	}
277 }
278 
/**
 * Returns a human-readable description of a socket error code.
 *
 * On Windows the text comes from a fixed table of Winsock codes, with
 * "Unknown error" for anything not listed; on Unix it is strerror(error).
 */
const char *strSockError(int error)
{
#if   defined(WZ_OS_WIN)
	struct ErrorText
	{
		int code;
		const char *text;
	};
	static const ErrorText knownErrors[] =
	{
		{0,                  "No error"},
		{WSAEINTR,           "Interrupted system call"},
		{WSAEBADF,           "Bad file number"},
		{WSAEACCES,          "Permission denied"},
		{WSAEFAULT,          "Bad address"},
		{WSAEINVAL,          "Invalid argument"},
		{WSAEMFILE,          "Too many open sockets"},
		{WSAEWOULDBLOCK,     "Operation would block"},
		{WSAEINPROGRESS,     "Operation now in progress"},
		{WSAEALREADY,        "Operation already in progress"},
		{WSAENOTSOCK,        "Socket operation on non-socket"},
		{WSAEDESTADDRREQ,    "Destination address required"},
		{WSAEMSGSIZE,        "Message too long"},
		{WSAEPROTOTYPE,      "Protocol wrong type for socket"},
		{WSAENOPROTOOPT,     "Bad protocol option"},
		{WSAEPROTONOSUPPORT, "Protocol not supported"},
		{WSAESOCKTNOSUPPORT, "Socket type not supported"},
		{WSAEOPNOTSUPP,      "Operation not supported on socket"},
		{WSAEPFNOSUPPORT,    "Protocol family not supported"},
		{WSAEAFNOSUPPORT,    "Address family not supported"},
		{WSAEADDRINUSE,      "Address already in use"},
		{WSAEADDRNOTAVAIL,   "Can't assign requested address"},
		{WSAENETDOWN,        "Network is down"},
		{WSAENETUNREACH,     "Network is unreachable"},
		{WSAENETRESET,       "Net connection reset"},
		{WSAECONNABORTED,    "Software caused connection abort"},
		{WSAECONNRESET,      "Connection reset by peer"},
		{WSAENOBUFS,         "No buffer space available"},
		{WSAEISCONN,         "Socket is already connected"},
		{WSAENOTCONN,        "Socket is not connected"},
		{WSAESHUTDOWN,       "Can't send after socket shutdown"},
		{WSAETOOMANYREFS,    "Too many references, can't splice"},
		{WSAETIMEDOUT,       "Connection timed out"},
		{WSAECONNREFUSED,    "Connection refused"},
		{WSAELOOP,           "Too many levels of symbolic links"},
		{WSAENAMETOOLONG,    "File name too long"},
		{WSAEHOSTDOWN,       "Host is down"},
		{WSAEHOSTUNREACH,    "No route to host"},
		{WSAENOTEMPTY,       "Directory not empty"},
		{WSAEPROCLIM,        "Too many processes"},
		{WSAEUSERS,          "Too many users"},
		{WSAEDQUOT,          "Disc quota exceeded"},
		{WSAESTALE,          "Stale NFS file handle"},
		{WSAEREMOTE,         "Too many levels of remote in path"},
		{WSASYSNOTREADY,     "Network system is unavailable"},
		{WSAVERNOTSUPPORTED, "Winsock version out of range"},
		{WSANOTINITIALISED,  "WSAStartup not yet called"},
		{WSAEDISCON,         "Graceful shutdown in progress"},
		{WSAHOST_NOT_FOUND,  "Host not found"},
		{WSANO_DATA,         "No host data of that type was found"},
	};

	for (const ErrorText &entry : knownErrors)
	{
		if (entry.code == error)
		{
			return entry.text;
		}
	}
	return "Unknown error";
#elif defined(WZ_OS_UNIX)
	const char *const description = strerror(error);
	return description;
#endif
}
340 
341 /**
342  * Test whether the given socket still has an open connection.
343  *
344  * @return true when the connection is open, false when it's closed or in an
345  *         error state, check getSockErr() to find out which.
346  */
connectionIsOpen(Socket * sock)347 static bool connectionIsOpen(Socket *sock)
348 {
349 	const SocketSet set = {std::vector<Socket *>(1, sock)};
350 
351 	ASSERT_OR_RETURN((setSockErr(EBADF), false),
352 	                 sock && sock->fd[SOCK_CONNECTION] != INVALID_SOCKET, "Invalid socket");
353 
354 	// Check whether the socket is still connected
355 	int ret = checkSockets(&set, 0);
356 	if (ret == SOCKET_ERROR)
357 	{
358 		return false;
359 	}
360 	else if (ret == (int)set.fds.size() && sock->ready)
361 	{
362 		/* The next recv(2) call won't block, but we're writing. So
363 		 * check the read queue to see if the connection is closed.
364 		 * If there's no data in the queue that means the connection
365 		 * is closed.
366 		 */
367 #if defined(WZ_OS_WIN)
368 		unsigned long readQueue;
369 		ret = ioctlsocket(sock->fd[SOCK_CONNECTION], FIONREAD, &readQueue);
370 #else
371 		int readQueue;
372 		ret = ioctl(sock->fd[SOCK_CONNECTION], FIONREAD, &readQueue);
373 #endif
374 		if (ret == SOCKET_ERROR)
375 		{
376 			debug(LOG_NET, "socket error");
377 			return false;
378 		}
379 		else if (readQueue == 0)
380 		{
381 			// Disconnected
382 			setSockErr(ECONNRESET);
383 			debug(LOG_NET, "Read queue empty - failing (ECONNRESET)");
384 			return false;
385 		}
386 	}
387 
388 	return true;
389 }
390 
/**
 * Main loop of the background writer thread.
 *
 * Repeatedly select()s on every socket that has queued data in
 * socketThreadWrites, send()s as much as each writable socket accepts, and
 * drops a socket's queue once it is empty or the socket is found broken.
 * socketThreadMutex is held throughout, except around select() and while
 * waiting on socketThreadSemaphore. The loop ends when socketThreadQuit is set.
 *
 * @return an arbitrary, unused value.
 */
static int socketThreadFunction(void *)
{
	wzMutexLock(socketThreadMutex);
	while (!socketThreadQuit)
	{
		// Build the set of descriptors that have something queued.
#if   defined(WZ_OS_UNIX)
		SOCKET maxfd = INT_MIN;
#elif defined(WZ_OS_WIN)
		SOCKET maxfd = 0;
#endif
		fd_set fds;
		FD_ZERO(&fds);
		for (SocketThreadWriteMap::iterator i = socketThreadWrites.begin(); i != socketThreadWrites.end(); ++i)
		{
			if (!i->second.empty())
			{
				SOCKET fd = i->first->fd[SOCK_CONNECTION];
				maxfd = std::max(maxfd, fd);
				ASSERT(!FD_ISSET(fd, &fds), "Duplicate file descriptor!");  // Shouldn't be possible, but blocking in send, after select says it won't block, shouldn't be possible either.
				FD_SET(fd, &fds);
			}
		}
		// Wait at most 50 ms so the quit flag and new queue entries are noticed.
		struct timeval tv = {0, 50 * 1000};

		// Check if we can write to any sockets.
		// The mutex is released for the duration of select() so that other
		// threads can queue data meanwhile.
		wzMutexUnlock(socketThreadMutex);
		int ret = select(maxfd + 1, nullptr, &fds, nullptr, &tv);
		wzMutexLock(socketThreadMutex);

		// We can write to some sockets. (Ignore errors from select, we may have deleted the socket after unlocking the mutex, and before calling select.)
		if (ret > 0)
		{
			for (SocketThreadWriteMap::iterator i = socketThreadWrites.begin(); i != socketThreadWrites.end();)
			{
				// Advance first: the current entry (w) may be erased below.
				SocketThreadWriteMap::iterator w = i;
				++i;

				Socket *sock = w->first;
				std::vector<uint8_t> &writeQueue = w->second;
				ASSERT(!writeQueue.empty(), "writeQueue[sock] must not be empty.");

				if (!FD_ISSET(sock->fd[SOCK_CONNECTION], &fds))
				{
					continue;  // This socket is not ready for writing, or we don't have anything to write.
				}

				// Write data.
				// FIXME SOMEHOW AAARGH This send() call can't block, but unless the socket is not set to blocking (setting the socket to nonblocking had better work, or else), does anyway (at least sometimes, when someone quits). Not reproducible except in public releases.
				ssize_t retSent = send(sock->fd[SOCK_CONNECTION], reinterpret_cast<char *>(&writeQueue[0]), writeQueue.size(), MSG_NOSIGNAL);
				if (retSent != SOCKET_ERROR)
				{
					// Erase as much data as written.
					writeQueue.erase(writeQueue.begin(), writeQueue.begin() + retSent);
					if (writeQueue.empty())
					{
						socketThreadWrites.erase(w);  // Nothing left to write, delete from pending list.
						if (sock->deleteLater)
						{
							// The socket was flagged for closing once its queue is done.
							socketCloseNow(sock);
						}
					}
				}
				else
				{
					switch (getSockErr())
					{
					case EAGAIN:
#if defined(EWOULDBLOCK) && EAGAIN != EWOULDBLOCK
					case EWOULDBLOCK:
#endif
						// send() would have blocked: only give up on the socket
						// if the connection turns out to be closed.
						if (!connectionIsOpen(sock))
						{
							debug(LOG_NET, "Socket error");
							sock->writeError = true;
							socketThreadWrites.erase(w);  // Socket broken, don't try writing to it again.
							if (sock->deleteLater)
							{
								socketCloseNow(sock);
							}
							break;
						}
						// Connection still open: falls through to EINTR, i.e. keep the
						// queue and try again on a later pass.
					case EINTR:
						break;
#if defined(EPIPE)
					case EPIPE:
#endif
					default:
						// Any other error is treated as fatal for this socket.
						sock->writeError = true;
						socketThreadWrites.erase(w);  // Socket broken, don't try writing to it again.
						if (sock->deleteLater)
						{
							socketCloseNow(sock);
						}
						break;
					}
				}
			}
		}

		if (socketThreadWrites.empty())
		{
			// Nothing to do, expect to wait.
			// Writers post the semaphore when they add to an empty map.
			wzMutexUnlock(socketThreadMutex);
			wzSemaphoreWait(socketThreadSemaphore);
			wzMutexLock(socketThreadMutex);
		}
	}
	wzMutexUnlock(socketThreadMutex);

	return 42;  // Return value arbitrary and unused.
}
502 
503 /**
504  * Similar to read(2) with the exception that this function won't be
505  * interrupted by signals (EINTR).
506  */
readNoInt(Socket * sock,void * buf,size_t max_size,size_t * rawByteCount)507 ssize_t readNoInt(Socket *sock, void *buf, size_t max_size, size_t *rawByteCount)
508 {
509 	size_t ignored;
510 	size_t &rawBytes = rawByteCount != nullptr ? *rawByteCount : ignored;
511 	rawBytes = 0;
512 
513 	if (sock->fd[SOCK_CONNECTION] == INVALID_SOCKET)
514 	{
515 		debug(LOG_ERROR, "Invalid socket");
516 		setSockErr(EBADF);
517 		return SOCKET_ERROR;
518 	}
519 
520 	if (sock->isCompressed)
521 	{
522 		if (sock->zInflateNeedInput)
523 		{
524 			// No input data, read some.
525 
526 			sock->zInflateInBuf.resize(max_size + 1000);
527 
528 			ssize_t received;
529 			do
530 			{
531 				//                                                  v----- This weird cast is because recv() takes a char * on windows instead of a void *...
532 				received = recv(sock->fd[SOCK_CONNECTION], (char *)&sock->zInflateInBuf[0], sock->zInflateInBuf.size(), 0);
533 			}
534 			while (received == SOCKET_ERROR && getSockErr() == EINTR);
535 			if (received < 0)
536 			{
537 				return received;
538 			}
539 
540 			sock->zInflate.next_in = &sock->zInflateInBuf[0];
541 			sock->zInflate.avail_in = received;
542 			rawBytes = received;
543 
544 			if (received == 0)
545 			{
546 				sock->readDisconnected = true;
547 			}
548 			else
549 			{
550 				sock->zInflateNeedInput = false;
551 			}
552 		}
553 
554 		sock->zInflate.next_out = (Bytef *)buf;
555 		sock->zInflate.avail_out = max_size;
556 		int ret = inflate(&sock->zInflate, Z_NO_FLUSH);
557 		ASSERT(ret != Z_STREAM_ERROR, "zlib inflate not working!");
558 		char const *err = nullptr;
559 		switch (ret)
560 		{
561 		case Z_NEED_DICT:  err = "Z_NEED_DICT";  break;
562 		case Z_DATA_ERROR: err = "Z_DATA_ERROR"; break;
563 		case Z_MEM_ERROR:  err = "Z_MEM_ERROR";  break;
564 		}
565 		if (err != nullptr)
566 		{
567 			debug(LOG_ERROR, "Couldn't decompress data from socket. zlib error %s", err);
568 			return -1;  // Bad data!
569 		}
570 
571 		if (sock->zInflate.avail_out != 0)
572 		{
573 			sock->zInflateNeedInput = true;
574 			ASSERT(sock->zInflate.avail_in == 0, "zlib not consuming all input!");
575 		}
576 
577 		return max_size - sock->zInflate.avail_out;  // Got some data, return how much.
578 	}
579 
580 	ssize_t received;
581 	do
582 	{
583 		received = recv(sock->fd[SOCK_CONNECTION], (char *)buf, max_size, 0);
584 		if (received == 0)
585 		{
586 			sock->readDisconnected = true;
587 		}
588 	}
589 	while (received == SOCKET_ERROR && getSockErr() == EINTR);
590 
591 	sock->ready = false;
592 
593 	rawBytes = received;
594 	return received;
595 }
596 
socketReadDisconnected(Socket * sock)597 bool socketReadDisconnected(Socket *sock)
598 {
599 	return sock->readDisconnected;
600 }
601 
602 /**
603  * Similar to write(2) with the exception that this function will block until
604  * <em>all</em> data has been written or an error occurs.
605  *
606  * @return @c size when successful or @c SOCKET_ERROR if an error occurred.
607  */
writeAll(Socket * sock,const void * buf,size_t size,size_t * rawByteCount)608 ssize_t writeAll(Socket *sock, const void *buf, size_t size, size_t *rawByteCount)
609 {
610 	size_t ignored;
611 	size_t &rawBytes = rawByteCount != nullptr ? *rawByteCount : ignored;
612 	rawBytes = 0;
613 
614 	if (sock->fd[SOCK_CONNECTION] == INVALID_SOCKET)
615 	{
616 		debug(LOG_ERROR, "Invalid socket (EBADF)");
617 		setSockErr(EBADF);
618 		return SOCKET_ERROR;
619 	}
620 
621 	if (sock->writeError)
622 	{
623 		return SOCKET_ERROR;
624 	}
625 
626 	if (size > 0)
627 	{
628 		if (!sock->isCompressed)
629 		{
630 			wzMutexLock(socketThreadMutex);
631 			if (socketThreadWrites.empty())
632 			{
633 				wzSemaphorePost(socketThreadSemaphore);
634 			}
635 			std::vector<uint8_t> &writeQueue = socketThreadWrites[sock];
636 			writeQueue.insert(writeQueue.end(), static_cast<char const *>(buf), static_cast<char const *>(buf) + size);
637 			wzMutexUnlock(socketThreadMutex);
638 			rawBytes = size;
639 		}
640 		else
641 		{
642 		#if ZLIB_VERNUM < 0x1252
643 			// zlib < 1.2.5.2 does not support `#define ZLIB_CONST`
644 			// Unfortunately, some OSes (ex. OpenBSD) ship with zlib < 1.2.5.2
645 			// Workaround: cast away the const of the input, and disable the resulting -Wcast-qual warning
646 			#if defined(__clang__)
647 			#  pragma clang diagnostic push
648 			#  pragma clang diagnostic ignored "-Wcast-qual"
649 			#elif defined(__GNUC__)
650 			#  pragma GCC diagnostic push
651 			#  pragma GCC diagnostic ignored "-Wcast-qual"
652 			#endif
653 
654 			// cast away the const for earlier zlib versions
655 			sock->zDeflate.next_in = (Bytef *)buf; // -Wcast-qual
656 
657 			#if defined(__clang__)
658 			#  pragma clang diagnostic pop
659 			#elif defined(__GNUC__)
660 			#  pragma GCC diagnostic pop
661 			#endif
662 		#else
663 			// zlib >= 1.2.5.2 supports ZLIB_CONST
664 			sock->zDeflate.next_in = (const Bytef *)buf;
665 		#endif
666 
667 			sock->zDeflate.avail_in = size;
668 			sock->zDeflateInSize += sock->zDeflate.avail_in;
669 			do
670 			{
671 				size_t alreadyHave = sock->zDeflateOutBuf.size();
672 				sock->zDeflateOutBuf.resize(alreadyHave + size + 20);  // A bit more than size should be enough to always do everything in one go.
673 				sock->zDeflate.next_out = (Bytef *)&sock->zDeflateOutBuf[alreadyHave];
674 				sock->zDeflate.avail_out = sock->zDeflateOutBuf.size() - alreadyHave;
675 
676 				int ret = deflate(&sock->zDeflate, Z_NO_FLUSH);
677 				ASSERT(ret != Z_STREAM_ERROR, "zlib compression failed!");
678 
679 				// Remove unused part of buffer.
680 				sock->zDeflateOutBuf.resize(sock->zDeflateOutBuf.size() - sock->zDeflate.avail_out);
681 			}
682 			while (sock->zDeflate.avail_out == 0);
683 
684 			ASSERT(sock->zDeflate.avail_in == 0, "zlib didn't compress everything!");
685 		}
686 	}
687 
688 	return size;
689 }
690 
socketFlush(Socket * sock,size_t * rawByteCount)691 void socketFlush(Socket *sock, size_t *rawByteCount)
692 {
693 	size_t ignored;
694 	size_t &rawBytes = rawByteCount != nullptr ? *rawByteCount : ignored;
695 	rawBytes = 0;
696 
697 	if (!sock->isCompressed)
698 	{
699 		return;  // Not compressed, so don't mess with zlib.
700 	}
701 
702 	// Flush data out of zlib compression state.
703 	do
704 	{
705 		sock->zDeflate.next_in = (Bytef *)nullptr;
706 		sock->zDeflate.avail_in = 0;
707 		size_t alreadyHave = sock->zDeflateOutBuf.size();
708 		sock->zDeflateOutBuf.resize(alreadyHave + 1000);  // 100 bytes would probably be enough to flush the rest in one go.
709 		sock->zDeflate.next_out = (Bytef *)&sock->zDeflateOutBuf[alreadyHave];
710 		sock->zDeflate.avail_out = sock->zDeflateOutBuf.size() - alreadyHave;
711 
712 		int ret = deflate(&sock->zDeflate, Z_PARTIAL_FLUSH);
713 		ASSERT(ret != Z_STREAM_ERROR, "zlib compression failed!");
714 
715 		// Remove unused part of buffer.
716 		sock->zDeflateOutBuf.resize(sock->zDeflateOutBuf.size() - sock->zDeflate.avail_out);
717 	}
718 	while (sock->zDeflate.avail_out == 0);
719 
720 	if (sock->zDeflateOutBuf.empty())
721 	{
722 		return;  // No data to flush out.
723 	}
724 
725 	wzMutexLock(socketThreadMutex);
726 	if (socketThreadWrites.empty())
727 	{
728 		wzSemaphorePost(socketThreadSemaphore);
729 	}
730 	std::vector<uint8_t> &writeQueue = socketThreadWrites[sock];
731 	writeQueue.insert(writeQueue.end(), sock->zDeflateOutBuf.begin(), sock->zDeflateOutBuf.end());
732 	wzMutexUnlock(socketThreadMutex);
733 
734 	// Primitive network logging, uncomment to use.
735 	//printf("Size %3u ->%3zu, buf =", sock->zDeflateInSize, sock->zDeflateOutBuf.size());
736 	//for (unsigned n = 0; n < std::min<unsigned>(sock->zDeflateOutBuf.size(), 40); ++n) printf(" %02X", sock->zDeflateOutBuf[n]);
737 	//printf("\n");
738 
739 	// Data sent, don't send again.
740 	rawBytes = sock->zDeflateOutBuf.size();
741 	sock->zDeflateInSize = 0;
742 	sock->zDeflateOutBuf.clear();
743 }
744 
socketBeginCompression(Socket * sock)745 void socketBeginCompression(Socket *sock)
746 {
747 	if (sock->isCompressed)
748 	{
749 		return;  // Nothing to do.
750 	}
751 
752 	wzMutexLock(socketThreadMutex);
753 
754 	// Init deflate.
755 	sock->zDeflate.zalloc = Z_NULL;
756 	sock->zDeflate.zfree = Z_NULL;
757 	sock->zDeflate.opaque = Z_NULL;
758 	int ret = deflateInit(&sock->zDeflate, 6);
759 	ASSERT(ret == Z_OK, "deflateInit failed! Sockets won't work.");
760 
761 	sock->zInflate.zalloc = Z_NULL;
762 	sock->zInflate.zfree = Z_NULL;
763 	sock->zInflate.opaque = Z_NULL;
764 	sock->zInflate.avail_in = 0;
765 	sock->zInflate.next_in = Z_NULL;
766 	ret = inflateInit(&sock->zInflate);
767 	ASSERT(ret == Z_OK, "deflateInit failed! Sockets won't work.");
768 
769 	sock->zInflateNeedInput = true;
770 
771 	sock->isCompressed = true;
772 	wzMutexUnlock(socketThreadMutex);
773 }
774 
~Socket()775 Socket::~Socket()
776 {
777 	if (isCompressed)
778 	{
779 		deflateEnd(&zDeflate);
780 		deflateEnd(&zInflate);
781 	}
782 }
783 
allocSocketSet()784 SocketSet *allocSocketSet()
785 {
786 	return new SocketSet;
787 }
788 
deleteSocketSet(SocketSet * set)789 void deleteSocketSet(SocketSet *set)
790 {
791 	delete set;
792 }
793 
794 /**
795  * Add the given socket to the given socket set.
796  *
797  * @return true if @c socket is successfully added to @set.
798  */
SocketSet_AddSocket(SocketSet * set,Socket * socket)799 void SocketSet_AddSocket(SocketSet *set, Socket *socket)
800 {
801 	/* Check whether this socket is already present in this set (i.e. it
802 	 * shouldn't be added again).
803 	 */
804 	size_t i = std::find(set->fds.begin(), set->fds.end(), socket) - set->fds.begin();
805 	if (i != set->fds.size())
806 	{
807 		debug(LOG_NET, "Already found, socket: (set->fds[%lu]) %p", (unsigned long)i, static_cast<void *>(socket));
808 		return;
809 	}
810 
811 	set->fds.push_back(socket);
812 	debug(LOG_NET, "Socket added: set->fds[%lu] = %p", (unsigned long)i, static_cast<void *>(socket));
813 }
814 
815 /**
816  * Remove the given socket from the given socket set.
817  */
SocketSet_DelSocket(SocketSet * set,Socket * socket)818 void SocketSet_DelSocket(SocketSet *set, Socket *socket)
819 {
820 	size_t i = std::find(set->fds.begin(), set->fds.end(), socket) - set->fds.begin();
821 	if (i != set->fds.size())
822 	{
823 		debug(LOG_NET, "Socket %p erased (set->fds[%lu])", static_cast<void *>(socket), (unsigned long)i);
824 		set->fds.erase(set->fds.begin() + i);
825 	}
826 }
827 
828 #if !defined(SOCK_CLOEXEC)
setSocketInheritable(SOCKET fd,bool inheritable)829 static bool setSocketInheritable(SOCKET fd, bool inheritable)
830 {
831 #if   defined(WZ_OS_UNIX)
832 	int sockopts = fcntl(fd, F_SETFD);
833 	if (sockopts == SOCKET_ERROR)
834 	{
835 		debug(LOG_NET, "Failed to retrieve current socket options: %s", strSockError(getSockErr()));
836 		return false;
837 	}
838 
839 	// Set or clear FD_CLOEXEC flag
840 	if (inheritable)
841 	{
842 		sockopts &= ~FD_CLOEXEC;
843 	}
844 	else
845 	{
846 		sockopts |= FD_CLOEXEC;
847 	}
848 
849 	if (fcntl(fd, F_SETFD, sockopts) == SOCKET_ERROR)
850 	{
851 		debug(LOG_NET, "Failed to set socket %sinheritable: %s", (inheritable ? "" : "non-"), strSockError(getSockErr()));
852 		return false;
853 	}
854 #elif defined(WZ_OS_WIN)
855 	DWORD dwFlags = (inheritable) ? HANDLE_FLAG_INHERIT : 0;
856 	if (::SetHandleInformation((HANDLE)fd, HANDLE_FLAG_INHERIT, dwFlags) == 0)
857 	{
858 		DWORD dwErr = GetLastError();
859 		debug(LOG_NET, "Failed to set socket %sinheritable: %s", (inheritable ? "" : "non-"), std::to_string(dwErr).c_str());
860 		return false;
861 	}
862 #endif
863 
864 	debug(LOG_NET, "Socket is set to %sinheritable.", (inheritable ? "" : "non-"));
865 	return true;
866 }
867 #endif // !defined(SOCK_CLOEXEC)
868 
setSocketBlocking(const SOCKET fd,bool blocking)869 static bool setSocketBlocking(const SOCKET fd, bool blocking)
870 {
871 #if   defined(WZ_OS_UNIX)
872 	int sockopts = fcntl(fd, F_GETFL);
873 	if (sockopts == SOCKET_ERROR)
874 	{
875 		debug(LOG_NET, "Failed to retrieve current socket options: %s", strSockError(getSockErr()));
876 		return false;
877 	}
878 
879 	// Set or clear O_NONBLOCK flag
880 	if (blocking)
881 	{
882 		sockopts &= ~O_NONBLOCK;
883 	}
884 	else
885 	{
886 		sockopts |= O_NONBLOCK;
887 	}
888 
889 	if (fcntl(fd, F_SETFL, sockopts) == SOCKET_ERROR)
890 #elif defined(WZ_OS_WIN)
891 	unsigned long nonblocking = !blocking;
892 	if (ioctlsocket(fd, FIONBIO, &nonblocking) == SOCKET_ERROR)
893 #endif
894 	{
895 		debug(LOG_NET, "Failed to set socket %sblocking: %s", (blocking ? "" : "non-"), strSockError(getSockErr()));
896 		return false;
897 	}
898 
899 	debug(LOG_NET, "Socket is set to %sblocking.", (blocking ? "" : "non-"));
900 	return true;
901 }
902 
/**
 * Sets the SO_NOSIGPIPE option on the socket where that option exists; on
 * platforms without SO_NOSIGPIPE this function does nothing.
 *
 * @param fd            socket handle to change.
 * @param block_sigpipe true to set SO_NOSIGPIPE to 1, false to set it to 0.
 */
static void socketBlockSIGPIPE(const SOCKET fd, bool block_sigpipe)
{
#if !defined(SO_NOSIGPIPE)
	// Prevent warnings
	(void)fd;
	(void)block_sigpipe;
#else
	const int optionValue = block_sigpipe ? 1 : 0;
	const int optionResult = setsockopt(fd, SOL_SOCKET, SO_NOSIGPIPE, &optionValue, sizeof(optionValue));

	if (optionResult == SOCKET_ERROR)
	{
		debug(LOG_INFO, "Failed to set SO_NOSIGPIPE on socket, SIGPIPE might be raised when connections gets broken. Error: %s", strSockError(getSockErr()));
	}
	// this is only for unix, windows don't have SIGPIPE
	debug(LOG_NET, "Socket fd %x sets SIGPIPE to %sblocked.", fd, (block_sigpipe ? "" : "non-"));
#endif
}
920 
checkSockets(const SocketSet * set,unsigned int timeout)921 int checkSockets(const SocketSet *set, unsigned int timeout)
922 {
923 	if (set->fds.empty())
924 	{
925 		return 0;
926 	}
927 
928 #if   defined(WZ_OS_UNIX)
929 	SOCKET maxfd = INT_MIN;
930 #elif defined(WZ_OS_WIN)
931 	SOCKET maxfd = 0;
932 #endif
933 
934 	bool compressedReady = false;
935 	for (size_t i = 0; i < set->fds.size(); ++i)
936 	{
937 		ASSERT(set->fds[i]->fd[SOCK_CONNECTION] != INVALID_SOCKET, "Invalid file descriptor!");
938 
939 		if (set->fds[i]->isCompressed && !set->fds[i]->zInflateNeedInput)
940 		{
941 			compressedReady = true;
942 			break;
943 		}
944 
945 		maxfd = std::max(maxfd, set->fds[i]->fd[SOCK_CONNECTION]);
946 	}
947 
948 	if (compressedReady)
949 	{
950 		// A socket already has some data ready. Don't really poll the sockets.
951 
952 		int ret = 0;
953 		for (size_t i = 0; i < set->fds.size(); ++i)
954 		{
955 			set->fds[i]->ready = set->fds[i]->isCompressed && !set->fds[i]->zInflateNeedInput;
956 			++ret;
957 		}
958 		return ret;
959 	}
960 
961 	int ret;
962 	fd_set fds;
963 	do
964 	{
965 		struct timeval tv = {(int)(timeout / 1000), (int)(timeout % 1000) * 1000};  // Cast to int to avoid narrowing needed for C++11.
966 
967 		FD_ZERO(&fds);
968 		for (size_t i = 0; i < set->fds.size(); ++i)
969 		{
970 			const SOCKET fd = set->fds[i]->fd[SOCK_CONNECTION];
971 
972 			FD_SET(fd, &fds);
973 		}
974 
975 		ret = select(maxfd + 1, &fds, nullptr, nullptr, &tv);
976 	}
977 	while (ret == SOCKET_ERROR && getSockErr() == EINTR);
978 
979 	if (ret == SOCKET_ERROR)
980 	{
981 		debug(LOG_ERROR, "select failed: %s", strSockError(getSockErr()));
982 		return SOCKET_ERROR;
983 	}
984 
985 	for (size_t i = 0; i < set->fds.size(); ++i)
986 	{
987 		set->fds[i]->ready = FD_ISSET(set->fds[i]->fd[SOCK_CONNECTION], &fds);
988 	}
989 
990 	return ret;
991 }
992 
993 /**
994  * Similar to read(2) with the exception that this function won't be
995  * interrupted by signals (EINTR) and will only return when <em>exactly</em>
996  * @c size bytes have been received. I.e. this function blocks until all data
997  * has been received or a timeout occurred.
998  *
999  * @param timeout When non-zero this function times out after @c timeout
1000  *                milliseconds. When zero this function blocks until success or
1001  *                an error occurs.
1002  *
1003  * @c return @c size when successful, less than @c size but at least zero (0)
1004  * when the other end disconnected or a timeout occurred. Or @c SOCKET_ERROR if
1005  * an error occurred.
1006  */
readAll(Socket * sock,void * buf,size_t size,unsigned int timeout)1007 ssize_t readAll(Socket *sock, void *buf, size_t size, unsigned int timeout)
1008 {
1009 	ASSERT(!sock->isCompressed, "readAll on compressed sockets not implemented.");
1010 
1011 	const SocketSet set = {std::vector<Socket *>(1, sock)};
1012 
1013 	size_t received = 0;
1014 
1015 	if (sock->fd[SOCK_CONNECTION] == INVALID_SOCKET)
1016 	{
1017 		debug(LOG_ERROR, "Invalid socket (%p), sock->fd[SOCK_CONNECTION]=%x  (error: EBADF)", static_cast<void *>(sock), sock->fd[SOCK_CONNECTION]);
1018 		setSockErr(EBADF);
1019 		return SOCKET_ERROR;
1020 	}
1021 
1022 	while (received < size)
1023 	{
1024 		ssize_t ret;
1025 
1026 		// If a timeout is set, wait for that amount of time for data to arrive (or abort)
1027 		if (timeout)
1028 		{
1029 			ret = checkSockets(&set, timeout);
1030 			if (ret < (ssize_t)set.fds.size()
1031 			    || !sock->ready)
1032 			{
1033 				if (ret == 0)
1034 				{
1035 					debug(LOG_NET, "socket (%p) has timed out.", static_cast<void *>(sock));
1036 					setSockErr(ETIMEDOUT);
1037 				}
1038 				debug(LOG_NET, "socket (%p) error.", static_cast<void *>(sock));
1039 				return SOCKET_ERROR;
1040 			}
1041 		}
1042 
1043 		ret = recv(sock->fd[SOCK_CONNECTION], &((char *)buf)[received], size - received, 0);
1044 		sock->ready = false;
1045 		if (ret == 0)
1046 		{
1047 			debug(LOG_NET, "Socket %x disconnected.", sock->fd[SOCK_CONNECTION]);
1048 			sock->readDisconnected = true;
1049 			setSockErr(ECONNRESET);
1050 			return received;
1051 		}
1052 
1053 		if (ret == SOCKET_ERROR)
1054 		{
1055 			switch (getSockErr())
1056 			{
1057 			case EAGAIN:
1058 #if defined(EWOULDBLOCK) && EAGAIN != EWOULDBLOCK
1059 			case EWOULDBLOCK:
1060 #endif
1061 			case EINTR:
1062 				continue;
1063 
1064 			default:
1065 				return SOCKET_ERROR;
1066 			}
1067 		}
1068 
1069 		received += ret;
1070 	}
1071 
1072 	return received;
1073 }
1074 
socketCloseNow(Socket * sock)1075 static void socketCloseNow(Socket *sock)
1076 {
1077 	for (unsigned i = 0; i < ARRAY_SIZE(sock->fd); ++i)
1078 	{
1079 		if (sock->fd[i] != INVALID_SOCKET)
1080 		{
1081 #if defined(WZ_OS_WIN)
1082 			int err = closesocket(sock->fd[i]);
1083 #else
1084 			int err = close(sock->fd[i]);
1085 #endif
1086 			if (err)
1087 			{
1088 				debug(LOG_ERROR, "Failed to close socket %p: %s", static_cast<void *>(sock), strSockError(getSockErr()));
1089 			}
1090 
1091 			/* Make sure that dangling pointers to this
1092 			 * structure don't think they've got their
1093 			 * hands on a valid socket.
1094 			 */
1095 			sock->fd[i] = INVALID_SOCKET;
1096 		}
1097 	}
1098 	delete sock;
1099 }
1100 
socketClose(Socket * sock)1101 void socketClose(Socket *sock)
1102 {
1103 	wzMutexLock(socketThreadMutex);
1104 	//Instead of socketThreadWrites.erase(sock);, try sending the data before actually deleting.
1105 	if (socketThreadWrites.find(sock) != socketThreadWrites.end())
1106 	{
1107 		// Wait until the data is written, then delete the socket.
1108 		sock->deleteLater = true;
1109 	}
1110 	else
1111 	{
1112 		// Delete the socket.
1113 		socketCloseNow(sock);
1114 	}
1115 	wzMutexUnlock(socketThreadMutex);
1116 }
1117 
socketAccept(Socket * sock)1118 Socket *socketAccept(Socket *sock)
1119 {
1120 	unsigned int i;
1121 
1122 	/* Search for a socket that has a pending connection on it and accept
1123 	 * the first one.
1124 	 */
1125 	for (i = 0; i < ARRAY_SIZE(sock->fd); ++i)
1126 	{
1127 		if (sock->fd[i] != INVALID_SOCKET)
1128 		{
1129 			struct sockaddr_storage addr;
1130 			socklen_t addr_len = sizeof(addr);
1131 			Socket *conn;
1132 			unsigned int j;
1133 
1134 #if defined(SOCK_CLOEXEC)
1135 			const SOCKET newConn = accept4(sock->fd[i], (struct sockaddr *)&addr, &addr_len, SOCK_CLOEXEC);
1136 #else
1137 			const SOCKET newConn = accept(sock->fd[i], (struct sockaddr *)&addr, &addr_len);
1138 #endif
1139 			if (newConn == INVALID_SOCKET)
1140 			{
1141 				// Ignore the case where no connection is pending
1142 				if (getSockErr() != EAGAIN
1143 				    && getSockErr() != EWOULDBLOCK)
1144 				{
1145 					debug(LOG_ERROR, "accept failed for socket %p: %s", static_cast<void *>(sock), strSockError(getSockErr()));
1146 				}
1147 
1148 				continue;
1149 			}
1150 
1151 			conn = new Socket;
1152 			if (conn == nullptr)
1153 			{
1154 				debug(LOG_ERROR, "Out of memory!");
1155 				abort();
1156 				return nullptr;
1157 			}
1158 
1159 #if !defined(SOCK_CLOEXEC)
1160 			if (!setSocketInheritable(newConn, false))
1161 			{
1162 				debug(LOG_NET, "Couldn't set socket (%p) inheritable status (false). Ignoring...", static_cast<void *>(conn));
1163 				// ignore and continue
1164 			}
1165 #endif
1166 
1167 			debug(LOG_NET, "setting socket (%p) blocking status (false).", static_cast<void *>(conn));
1168 			if (!setSocketBlocking(newConn, false))
1169 			{
1170 				debug(LOG_NET, "Couldn't set socket (%p) blocking status (false).  Closing.", static_cast<void *>(conn));
1171 				socketClose(conn);
1172 				return nullptr;
1173 			}
1174 
1175 			socketBlockSIGPIPE(newConn, true);
1176 
1177 			// Mark all unused socket handles as invalid
1178 			for (j = 0; j < ARRAY_SIZE(conn->fd); ++j)
1179 			{
1180 				conn->fd[j] = INVALID_SOCKET;
1181 			}
1182 
1183 			conn->fd[SOCK_CONNECTION] = newConn;
1184 
1185 			sock->ready = false;
1186 
1187 			addressToText((const struct sockaddr *)&addr, conn->textAddress, sizeof(conn->textAddress));
1188 			debug(LOG_NET, "Incoming connection from [%s]:/*%%d*/ (FIXME: gives strict-aliasing error)", conn->textAddress/*, (unsigned int)ntohs(((const struct sockaddr_in*)&addr)->sin_port)*/);
1189 			debug(LOG_NET, "Using socket %p", static_cast<void *>(conn));
1190 			return conn;
1191 		}
1192 	}
1193 
1194 	return nullptr;
1195 }
1196 
socketOpen(const SocketAddress * addr,unsigned timeout)1197 Socket *socketOpen(const SocketAddress *addr, unsigned timeout)
1198 {
1199 	unsigned int i;
1200 	int ret;
1201 
1202 	Socket *const conn = new Socket;
1203 	if (conn == nullptr)
1204 	{
1205 		debug(LOG_ERROR, "Out of memory!");
1206 		abort();
1207 		return nullptr;
1208 	}
1209 
1210 	ASSERT(addr != nullptr, "NULL Socket provided");
1211 
1212 	addressToText(addr->ai_addr, conn->textAddress, sizeof(conn->textAddress));
1213 #if defined(__GNUC__) && !defined(__INTEL_COMPILER) && !defined(__clang__)
1214 # pragma GCC diagnostic push
1215 # pragma GCC diagnostic ignored "-Wcast-align"
1216 #endif
1217 	debug(LOG_NET, "Connecting to [%s]:%d", conn->textAddress, (int)ntohs((reinterpret_cast<const sockaddr_in *>(addr->ai_addr))->sin_port));
1218 #if defined(__GNUC__) && !defined(__INTEL_COMPILER) && !defined(__clang__)
1219 # pragma GCC diagnostic pop
1220 #endif
1221 
1222 	// Mark all unused socket handles as invalid
1223 	for (i = 0; i < ARRAY_SIZE(conn->fd); ++i)
1224 	{
1225 		conn->fd[i] = INVALID_SOCKET;
1226 	}
1227 
1228 	int socket_type = addr->ai_socktype;
1229 #if defined(SOCK_CLOEXEC)
1230 	socket_type |= SOCK_CLOEXEC;
1231 #endif
1232 	conn->fd[SOCK_CONNECTION] = socket(addr->ai_family, socket_type, addr->ai_protocol);
1233 
1234 	if (conn->fd[SOCK_CONNECTION] == INVALID_SOCKET)
1235 	{
1236 		debug(LOG_ERROR, "Failed to create a socket (%p): %s", static_cast<void *>(conn), strSockError(getSockErr()));
1237 		socketClose(conn);
1238 		return nullptr;
1239 	}
1240 
1241 #if !defined(SOCK_CLOEXEC)
1242 	if (!setSocketInheritable(conn->fd[SOCK_CONNECTION], false))
1243 	{
1244 		debug(LOG_NET, "Couldn't set socket (%p) inheritable status (false). Ignoring...", static_cast<void *>(conn));
1245 		// ignore and continue
1246 	}
1247 #endif
1248 
1249 	debug(LOG_NET, "setting socket (%p) blocking status (false).", static_cast<void *>(conn));
1250 	if (!setSocketBlocking(conn->fd[SOCK_CONNECTION], false))
1251 	{
1252 		debug(LOG_NET, "Couldn't set socket (%p) blocking status (false).  Closing.", static_cast<void *>(conn));
1253 		socketClose(conn);
1254 		return nullptr;
1255 	}
1256 
1257 	socketBlockSIGPIPE(conn->fd[SOCK_CONNECTION], true);
1258 
1259 	ret = connect(conn->fd[SOCK_CONNECTION], addr->ai_addr, addr->ai_addrlen);
1260 	if (ret == SOCKET_ERROR)
1261 	{
1262 		fd_set conReady;
1263 #if   defined(WZ_OS_WIN)
1264 		fd_set conFailed;
1265 #endif
1266 
1267 		if ((getSockErr() != EINPROGRESS
1268 		     && getSockErr() != EAGAIN
1269 		     && getSockErr() != EWOULDBLOCK)
1270 #if   defined(WZ_OS_UNIX)
1271 		    || conn->fd[SOCK_CONNECTION] >= FD_SETSIZE
1272 #endif
1273 		    || timeout == 0)
1274 		{
1275 			debug(LOG_NET, "Failed to start connecting: %s, using socket %p", strSockError(getSockErr()), static_cast<void *>(conn));
1276 			socketClose(conn);
1277 			return nullptr;
1278 		}
1279 
1280 		do
1281 		{
1282 			struct timeval tv = {(int)(timeout / 1000), (int)(timeout % 1000) * 1000};  // Cast to int to avoid narrowing needed for C++11.
1283 
1284 			FD_ZERO(&conReady);
1285 			FD_SET(conn->fd[SOCK_CONNECTION], &conReady);
1286 #if   defined(WZ_OS_WIN)
1287 			FD_ZERO(&conFailed);
1288 			FD_SET(conn->fd[SOCK_CONNECTION], &conFailed);
1289 #endif
1290 
1291 #if   defined(WZ_OS_WIN)
1292 			ret = select(conn->fd[SOCK_CONNECTION] + 1, NULL, &conReady, &conFailed, &tv);
1293 #else
1294 			ret = select(conn->fd[SOCK_CONNECTION] + 1, nullptr, &conReady, nullptr, &tv);
1295 #endif
1296 		}
1297 		while (ret == SOCKET_ERROR && getSockErr() == EINTR);
1298 
1299 		if (ret == SOCKET_ERROR)
1300 		{
1301 			debug(LOG_NET, "Failed to wait for connection: %s, socket %p.  Closing.", strSockError(getSockErr()), static_cast<void *>(conn));
1302 			socketClose(conn);
1303 			return nullptr;
1304 		}
1305 
1306 		if (ret == 0)
1307 		{
1308 			setSockErr(ETIMEDOUT);
1309 			debug(LOG_NET, "Timed out while waiting for connection to be established: %s, using socket %p.  Closing.", strSockError(getSockErr()), static_cast<void *>(conn));
1310 			socketClose(conn);
1311 			return nullptr;
1312 		}
1313 
1314 #if   defined(WZ_OS_WIN)
1315 		ASSERT(FD_ISSET(conn->fd[SOCK_CONNECTION], &conReady) || FD_ISSET(conn->fd[SOCK_CONNECTION], &conFailed), "\"sock\" is the only file descriptor in set, it should be the one that is set.");
1316 #else
1317 		ASSERT(FD_ISSET(conn->fd[SOCK_CONNECTION], &conReady), "\"sock\" is the only file descriptor in set, it should be the one that is set.");
1318 #endif
1319 
1320 #if   defined(WZ_OS_WIN)
1321 		if (FD_ISSET(conn->fd[SOCK_CONNECTION], &conFailed))
1322 #elif defined(WZ_OS_UNIX)
1323 		if (connect(conn->fd[SOCK_CONNECTION], addr->ai_addr, addr->ai_addrlen) == SOCKET_ERROR
1324 		    && getSockErr() != EISCONN)
1325 #endif
1326 		{
1327 			debug(LOG_NET, "Failed to connect: %s, with socket %p.  Closing.", strSockError(getSockErr()), static_cast<void *>(conn));
1328 			socketClose(conn);
1329 			return nullptr;
1330 		}
1331 	}
1332 
1333 	return conn;
1334 }
1335 
socketListen(unsigned int port)1336 Socket *socketListen(unsigned int port)
1337 {
1338 	/* Enable the V4 to V6 mapping, but only when available, because it
1339 	 * isn't available on all platforms.
1340 	 */
1341 #if defined(IPV6_V6ONLY)
1342 	static const int ipv6_v6only = 0;
1343 #endif
1344 	static const int so_reuseaddr = 1;
1345 
1346 	struct sockaddr_in addr4;
1347 	struct sockaddr_in6 addr6;
1348 	unsigned int i;
1349 
1350 	Socket *const conn = new Socket;
1351 	if (conn == nullptr)
1352 	{
1353 		debug(LOG_ERROR, "Out of memory!");
1354 		abort();
1355 		return nullptr;
1356 	}
1357 
1358 	// Mark all unused socket handles as invalid
1359 	for (i = 0; i < ARRAY_SIZE(conn->fd); ++i)
1360 	{
1361 		conn->fd[i] = INVALID_SOCKET;
1362 	}
1363 
1364 	strncpy(conn->textAddress, "LISTENING SOCKET", sizeof(conn->textAddress));
1365 
1366 	// Listen on all local IPv4 and IPv6 addresses for the given port
1367 	addr4.sin_family      = AF_INET;
1368 	addr4.sin_port        = htons(port);
1369 	addr4.sin_addr.s_addr = INADDR_ANY;
1370 
1371 	addr6.sin6_family   = AF_INET6;
1372 	addr6.sin6_port     = htons(port);
1373 	addr6.sin6_addr     = in6addr_any;
1374 	addr6.sin6_flowinfo = 0;
1375 	addr6.sin6_scope_id = 0;
1376 
1377 	int socket_type = SOCK_STREAM;
1378 #if defined(SOCK_CLOEXEC)
1379 	socket_type |= SOCK_CLOEXEC;
1380 #endif
1381 	conn->fd[SOCK_IPV4_LISTEN] = socket(addr4.sin_family, socket_type, 0);
1382 	conn->fd[SOCK_IPV6_LISTEN] = socket(addr6.sin6_family, socket_type, 0);
1383 
1384 	if (conn->fd[SOCK_IPV4_LISTEN] == INVALID_SOCKET
1385 	    && conn->fd[SOCK_IPV6_LISTEN] == INVALID_SOCKET)
1386 	{
1387 		debug(LOG_ERROR, "Failed to create an IPv4 and IPv6 (only supported address families) socket (%p): %s.  Closing.", static_cast<void *>(conn), strSockError(getSockErr()));
1388 		socketClose(conn);
1389 		return nullptr;
1390 	}
1391 
1392 	if (conn->fd[SOCK_IPV4_LISTEN] != INVALID_SOCKET)
1393 	{
1394 		debug(LOG_NET, "Successfully created an IPv4 socket (%p)", static_cast<void *>(conn));
1395 	}
1396 
1397 	if (conn->fd[SOCK_IPV6_LISTEN] != INVALID_SOCKET)
1398 	{
1399 		debug(LOG_NET, "Successfully created an IPv6 socket (%p)", static_cast<void *>(conn));
1400 	}
1401 
1402 #if defined(IPV6_V6ONLY)
1403 	if (conn->fd[SOCK_IPV6_LISTEN] != INVALID_SOCKET)
1404 	{
1405 		if (setsockopt(conn->fd[SOCK_IPV6_LISTEN], IPPROTO_IPV6, IPV6_V6ONLY, (const char *)&ipv6_v6only, sizeof(ipv6_v6only)) == SOCKET_ERROR)
1406 		{
1407 			debug(LOG_INFO, "Failed to set IPv6 socket to perform IPv4 to IPv6 mapping. Falling back to using two sockets. Error: %s", strSockError(getSockErr()));
1408 		}
1409 		else
1410 		{
1411 			debug(LOG_NET, "Successfully enabled IPv4 to IPv6 mapping. Cleaning up IPv4 socket.");
1412 #if   defined(WZ_OS_WIN)
1413 			closesocket(conn->fd[SOCK_IPV4_LISTEN]);
1414 #else
1415 			close(conn->fd[SOCK_IPV4_LISTEN]);
1416 #endif
1417 			conn->fd[SOCK_IPV4_LISTEN] = INVALID_SOCKET;
1418 		}
1419 	}
1420 #endif
1421 
1422 	if (conn->fd[SOCK_IPV4_LISTEN] != INVALID_SOCKET)
1423 	{
1424 #if !defined(SOCK_CLOEXEC)
1425 		if (!setSocketInheritable(conn->fd[SOCK_IPV4_LISTEN], false))
1426 		{
1427 			debug(LOG_NET, "Couldn't set socket (%p) inheritable status (false). Ignoring...", static_cast<void *>(conn));
1428 			// ignore and continue
1429 		}
1430 #endif
1431 
1432 		if (setsockopt(conn->fd[SOCK_IPV4_LISTEN], SOL_SOCKET, SO_REUSEADDR, (const char *)&so_reuseaddr, sizeof(so_reuseaddr)) == SOCKET_ERROR)
1433 		{
1434 			debug(LOG_WARNING, "Failed to set SO_REUSEADDR on IPv4 socket. Error: %s", strSockError(getSockErr()));
1435 		}
1436 
1437 		debug(LOG_NET, "setting socket (%p) blocking status (false, IPv4).", static_cast<void *>(conn));
1438 		if (bind(conn->fd[SOCK_IPV4_LISTEN], (const struct sockaddr *)&addr4, sizeof(addr4)) == SOCKET_ERROR
1439 		    || listen(conn->fd[SOCK_IPV4_LISTEN], 5) == SOCKET_ERROR
1440 		    || !setSocketBlocking(conn->fd[SOCK_IPV4_LISTEN], false))
1441 		{
1442 			debug(LOG_ERROR, "Failed to set up IPv4 socket for listening on port %u: %s", port, strSockError(getSockErr()));
1443 #if   defined(WZ_OS_WIN)
1444 			closesocket(conn->fd[SOCK_IPV4_LISTEN]);
1445 #else
1446 			close(conn->fd[SOCK_IPV4_LISTEN]);
1447 #endif
1448 			conn->fd[SOCK_IPV4_LISTEN] = INVALID_SOCKET;
1449 		}
1450 	}
1451 
1452 	if (conn->fd[SOCK_IPV6_LISTEN] != INVALID_SOCKET)
1453 	{
1454 #if !defined(SOCK_CLOEXEC)
1455 		if (!setSocketInheritable(conn->fd[SOCK_IPV6_LISTEN], false))
1456 		{
1457 			debug(LOG_NET, "Couldn't set socket (%p) inheritable status (false). Ignoring...", static_cast<void *>(conn));
1458 			// ignore and continue
1459 		}
1460 #endif
1461 
1462 		if (setsockopt(conn->fd[SOCK_IPV6_LISTEN], SOL_SOCKET, SO_REUSEADDR, (const char *)&so_reuseaddr, sizeof(so_reuseaddr)) == SOCKET_ERROR)
1463 		{
1464 			debug(LOG_INFO, "Failed to set SO_REUSEADDR on IPv6 socket. Error: %s", strSockError(getSockErr()));
1465 		}
1466 
1467 		debug(LOG_NET, "setting socket (%p) blocking status (false, IPv6).", static_cast<void *>(conn));
1468 		if (bind(conn->fd[SOCK_IPV6_LISTEN], (const struct sockaddr *)&addr6, sizeof(addr6)) == SOCKET_ERROR
1469 		    || listen(conn->fd[SOCK_IPV6_LISTEN], 5) == SOCKET_ERROR
1470 		    || !setSocketBlocking(conn->fd[SOCK_IPV6_LISTEN], false))
1471 		{
1472 			debug(LOG_ERROR, "Failed to set up IPv6 socket for listening on port %u: %s", port, strSockError(getSockErr()));
1473 #if   defined(WZ_OS_WIN)
1474 			closesocket(conn->fd[SOCK_IPV6_LISTEN]);
1475 #else
1476 			close(conn->fd[SOCK_IPV6_LISTEN]);
1477 #endif
1478 			conn->fd[SOCK_IPV6_LISTEN] = INVALID_SOCKET;
1479 		}
1480 	}
1481 
1482 	// Check whether we still have at least a single (operating) socket.
1483 	if (conn->fd[SOCK_IPV4_LISTEN] == INVALID_SOCKET
1484 	    && conn->fd[SOCK_IPV6_LISTEN] == INVALID_SOCKET)
1485 	{
1486 		debug(LOG_NET, "No IPv4 or IPv6 sockets created.");
1487 		socketClose(conn);
1488 		return nullptr;
1489 	}
1490 
1491 	return conn;
1492 }
1493 
socketOpenAny(const SocketAddress * addr,unsigned timeout)1494 Socket *socketOpenAny(const SocketAddress *addr, unsigned timeout)
1495 {
1496 	Socket *ret = nullptr;
1497 	while (addr != nullptr && ret == nullptr)
1498 	{
1499 		ret = socketOpen(addr, timeout);
1500 
1501 		addr = addr->ai_next;
1502 	}
1503 
1504 	return ret;
1505 }
1506 
socketArrayOpen(Socket ** sockets,size_t maxSockets,const SocketAddress * addr,unsigned timeout)1507 size_t socketArrayOpen(Socket **sockets, size_t maxSockets, const SocketAddress *addr, unsigned timeout)
1508 {
1509 	size_t i = 0;
1510 	while (i < maxSockets && addr != nullptr)
1511 	{
1512 		if (addr->ai_family == AF_INET || addr->ai_family == AF_INET6)
1513 		{
1514 			sockets[i] = socketOpen(addr, timeout);
1515 			i += sockets[i] != nullptr;
1516 		}
1517 
1518 		addr = addr->ai_next;
1519 	}
1520 	std::fill(sockets + i, sockets + maxSockets, (Socket *)nullptr);
1521 	return i;
1522 }
1523 
socketArrayClose(Socket ** sockets,size_t maxSockets)1524 void socketArrayClose(Socket **sockets, size_t maxSockets)
1525 {
1526 	std::for_each(sockets, sockets + maxSockets, socketClose);     // Close any open sockets.
1527 	std::fill(sockets, sockets + maxSockets, (Socket *)nullptr);      // Set the pointers to NULL.
1528 }
1529 
/**
 * Tells whether a listening socket accepts IPv4 connections.
 *
 * That is the case when it has a dedicated IPv4 handle, or (where
 * IPV6_V6ONLY is available) when its IPv6 handle has the IPV6_V6ONLY
 * option turned off.
 *
 * @param sock the listening socket to inspect.
 *
 * @return true when IPv4 connections can be accepted, false otherwise
 *         (including when the option cannot be queried).
 */
WZ_DECL_NONNULL(1) bool socketHasIPv4(Socket *sock)
{
	if (sock->fd[SOCK_IPV4_LISTEN] != INVALID_SOCKET)
	{
		return true;
	}

#if defined(IPV6_V6ONLY)
	const SOCKET ipv6Handle = sock->fd[SOCK_IPV6_LISTEN];
	if (ipv6Handle != INVALID_SOCKET)
	{
		int v6OnlyFlag = 1;
		socklen_t flagLength = sizeof(v6OnlyFlag);
		const int queryResult = getsockopt(ipv6Handle, IPPROTO_IPV6, IPV6_V6ONLY, (char *)&v6OnlyFlag, &flagLength);
		if (queryResult == 0)
		{
			// IPv4 is reachable through the IPv6 handle only when the
			// handle is not restricted to IPv6.
			return v6OnlyFlag == 0;
		}
	}
#endif

	return false;
}
1552 
/**
 * Tells whether a listening socket has an open IPv6 handle.
 *
 * @param sock the listening socket to inspect.
 *
 * @return true when the IPv6 listening handle is valid.
 */
WZ_DECL_NONNULL(1) bool socketHasIPv6(Socket *sock)
{
	const SOCKET ipv6Handle = sock->fd[SOCK_IPV6_LISTEN];
	return ipv6Handle != INVALID_SOCKET;
}
1557 
getSocketTextAddress(Socket const * sock)1558 char const *getSocketTextAddress(Socket const *sock)
1559 {
1560 	return sock->textAddress;
1561 }
1562 
/**
 * Converts a textual IPv4 address into its binary network representation.
 *
 * @param ipv4Address the address in text form, as understood by inet_pton().
 *
 * @return the @c sizeof(struct in_addr) bytes of the address, or an empty
 *         vector when the text could not be converted.
 */
std::vector<unsigned char> ipv4_AddressString_To_NetBinary(const std::string& ipv4Address)
{
	std::vector<unsigned char> packed(sizeof(struct in_addr), 0);

	const bool converted = inet_pton(AF_INET, ipv4Address.c_str(), packed.data()) > 0;
	if (!converted)
	{
		// inet_pton failed
		packed.clear();
	}

	return packed;
}
1573 
1574 #ifndef INET_ADDRSTRLEN
1575 # define INET_ADDRSTRLEN 16
1576 #endif
1577 
/**
 * Converts a binary (network representation) IPv4 address into text.
 *
 * @param ip4NetBinaryForm the address bytes; must be exactly
 *                         @c sizeof(struct in_addr) bytes long.
 *
 * @return the address in text form, or an empty string when the input has
 *         the wrong size or inet_ntop() fails.
 */
std::string ipv4_NetBinary_To_AddressString(const std::vector<unsigned char>& ip4NetBinaryForm)
{
	if (ip4NetBinaryForm.size() == sizeof(struct in_addr))
	{
		char text[INET_ADDRSTRLEN] = {0};
		if (inet_ntop(AF_INET, ip4NetBinaryForm.data(), text, sizeof(text)) != nullptr)
		{
			return std::string(text);
		}
	}

	return std::string();
}
1593 
/**
 * Converts a textual IPv6 address into its binary network representation.
 *
 * @param ipv6Address the address in text form, as understood by inet_pton().
 *
 * @return the @c sizeof(struct in6_addr) bytes of the address, or an empty
 *         vector when the text could not be converted.
 */
std::vector<unsigned char> ipv6_AddressString_To_NetBinary(const std::string& ipv6Address)
{
	std::vector<unsigned char> packed(sizeof(struct in6_addr), 0);

	const bool converted = inet_pton(AF_INET6, ipv6Address.c_str(), packed.data()) > 0;
	if (!converted)
	{
		// inet_pton failed
		packed.clear();
	}

	return packed;
}
1604 
1605 #ifndef INET6_ADDRSTRLEN
1606 # define INET6_ADDRSTRLEN 46
1607 #endif
1608 
/**
 * Converts a binary (network representation) IPv6 address into text.
 *
 * @param ip6NetBinaryForm the address bytes; must be exactly
 *                         @c sizeof(struct in6_addr) bytes long.
 *
 * @return the address in text form, or an empty string when the input has
 *         the wrong size or inet_ntop() fails.
 */
std::string ipv6_NetBinary_To_AddressString(const std::vector<unsigned char>& ip6NetBinaryForm)
{
	if (ip6NetBinaryForm.size() == sizeof(struct in6_addr))
	{
		char text[INET6_ADDRSTRLEN] = {0};
		if (inet_ntop(AF_INET6, ip6NetBinaryForm.data(), text, sizeof(text)) != nullptr)
		{
			return std::string(text);
		}
	}

	return std::string();
}
1624 
resolveHost(const char * host,unsigned int port)1625 SocketAddress *resolveHost(const char *host, unsigned int port)
1626 {
1627 	struct addrinfo *results;
1628 	std::string service;
1629 	struct addrinfo hint;
1630 	int error, flags = 0;
1631 
1632 	hint.ai_family    = AF_UNSPEC;
1633 	hint.ai_socktype  = SOCK_STREAM;
1634 	hint.ai_protocol  = 0;
1635 #ifdef AI_V4MAPPED
1636 	flags             |= AI_V4MAPPED;
1637 #endif
1638 #ifdef AI_ADDRCONFIG
1639 	flags             |= AI_ADDRCONFIG;
1640 #endif
1641 	hint.ai_flags     = flags;
1642 	hint.ai_addrlen   = 0;
1643 	hint.ai_addr      = nullptr;
1644 	hint.ai_canonname = nullptr;
1645 	hint.ai_next      = nullptr;
1646 
1647 	service = astringf("%u", port);
1648 
1649 	error = getaddrinfo(host, service.c_str(), &hint, &results);
1650 	if (error != 0)
1651 	{
1652 		debug(LOG_NET, "getaddrinfo failed for %s:%s: %s", host, service.c_str(), gai_strerror(error));
1653 		return nullptr;
1654 	}
1655 
1656 	return results;
1657 }
1658 
deleteSocketAddress(SocketAddress * addr)1659 void deleteSocketAddress(SocketAddress *addr)
1660 {
1661 	freeaddrinfo(addr);
1662 }
1663 
1664 // ////////////////////////////////////////////////////////////////////////
1665 // setup stuff
/**
 * Initialises the socket layer.
 *
 * On Windows the first call starts up Winsock and looks up getaddrinfo /
 * freeaddrinfo in ws2_32.dll. On all platforms the socket thread (together
 * with its mutex and semaphore) is created and started if it doesn't exist
 * yet.
 */
void SOCKETinit()
{
#if defined(WZ_OS_WIN)
	// The Winsock setup is attempted only once, even if it fails.
	static bool firstCall = true;
	if (firstCall)
	{
		firstCall = false;

		static WSADATA stuff;
		// (2 << 8) + 2 == 0x0202, i.e. request Winsock version 2.2
		WORD ver_required = (2 << 8) + 2;
		if (WSAStartup(ver_required, &stuff) != 0)
		{
			debug(LOG_ERROR, "Failed to initialize Winsock: %s", strSockError(getSockErr()));
			// Note: this also skips the creation of the socket thread below.
			return;
		}

		// Look the address resolution functions up at run time; the
		// pointers are left untouched when the library can't be loaded.
		winsock2_dll = LoadLibraryA("ws2_32.dll");
		if (winsock2_dll)
		{
			getaddrinfo_dll_func = reinterpret_cast<GETADDRINFO_DLL_FUNC>(reinterpret_cast<void*>(GetProcAddress(winsock2_dll, "getaddrinfo")));
			freeaddrinfo_dll_func = reinterpret_cast<FREEADDRINFO_DLL_FUNC>(reinterpret_cast<void*>(GetProcAddress(winsock2_dll, "freeaddrinfo")));
		}
	}
#endif

	// Create and start the socket thread, unless it already exists.
	if (socketThread == nullptr)
	{
		socketThreadQuit = false;
		socketThreadMutex = wzMutexCreate();
		socketThreadSemaphore = wzSemaphoreCreate(0);
		socketThread = wzThreadCreate(socketThreadFunction, nullptr);
		wzThreadStart(socketThread);
	}
}
1700 
/**
 * Shuts the socket layer down again.
 *
 * Asks the socket thread (if there is one) to quit, discards all queued
 * writes, waits for the thread to finish and destroys its mutex and
 * semaphore. On Windows, Winsock is cleaned up and ws2_32.dll is released
 * afterwards.
 */
void SOCKETshutdown()
{
	if (socketThread != nullptr)
	{
		// Set the quit flag and drop the pending writes under the mutex.
		wzMutexLock(socketThreadMutex);
		socketThreadQuit = true;
		socketThreadWrites.clear();
		wzMutexUnlock(socketThreadMutex);
		wzSemaphorePost(socketThreadSemaphore);  // Wake up the thread, so it can quit.
		// Only destroy the mutex and semaphore once the thread has ended.
		wzThreadJoin(socketThread);
		wzMutexDestroy(socketThreadMutex);
		wzSemaphoreDestroy(socketThreadSemaphore);
		socketThread = nullptr;
	}

#if defined(WZ_OS_WIN)
	WSACleanup();

	if (winsock2_dll)
	{
		// Reset the function pointers too: they point into the library
		// that is being released.
		FreeLibrary(winsock2_dll);
		winsock2_dll = NULL;
		getaddrinfo_dll_func = NULL;
		freeaddrinfo_dll_func = NULL;
	}
#endif
}
1728