1 // Copyright (C)2004 Landmark Graphics Corporation
2 // Copyright (C)2005 Sun Microsystems, Inc.
3 // Copyright (C)2014, 2016, 2018-2019 D. R. Commander
4 //
5 // This library is free software and may be redistributed and/or modified under
6 // the terms of the wxWindows Library License, Version 3.1 or (at your option)
7 // any later version.  The full license is in the LICENSE.txt file included
8 // with this distribution.
9 //
10 // This library is distributed in the hope that it will be useful,
11 // but WITHOUT ANY WARRANTY; without even the implied warranty of
12 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13 // wxWindows Library License for more details.
14 
15 #ifndef __SOCKET_H__
16 #define __SOCKET_H__
17 
18 #ifdef _WIN32
19 	#include <winsock2.h>
20 	#include <ws2ipdef.h>
21 #else
22 	#include <netinet/in.h>
23 #endif
24 #ifdef USESSL
25 	#define OPENSSL_NO_KRB5
26 	#include <openssl/ssl.h>
27 	#include <openssl/err.h>
28 	#if !defined(HAVE_DEVURANDOM) && !defined(_WIN32)
29 		#include <openssl/rand.h>
30 	#endif
31 #endif
32 
33 #include "Error.h"
34 #include "Mutex.h"
35 
36 
37 namespace vglutil
38 {
39 	class SockError : public Error
40 	{
41 		public:
42 
43 			#ifdef _WIN32
44 
SockError(const char * method_,int line)45 			SockError(const char *method_, int line) :
46 				Error(method_, (char *)NULL, line)
47 			{
48 				if(!FormatMessage(FORMAT_MESSAGE_FROM_SYSTEM, NULL, WSAGetLastError(),
49 					MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), message, MLEN, NULL))
50 					strncpy(message, "Error in FormatMessage()", MLEN);
51 			}
52 
53 			#else
54 
55 			SockError(const char *method_, int line) :
56 				Error(method_, strerror(errno), line) {}
57 
58 			#endif
59 	};
60 }
61 
// Throw a SockError tagged with the name of the enclosing function and the
// current source line.  SockError is referenced without qualification, so the
// name must be visible (e.g. from within namespace vglutil) wherever this macro
// is expanded.
#define THROW_SOCK()  throw(SockError(__FUNCTION__, __LINE__))
63 
64 
65 #ifdef USESSL
66 
67 namespace vglutil
68 {
69 	class SSLError : public Error
70 	{
71 		public:
72 
SSLError(const char * method_,int line)73 			SSLError(const char *method_, int line) :
74 				Error(method_, (char *)NULL, line)
75 			{
76 				ERR_error_string_n(ERR_get_error(), &message[strlen(message)],
77 					MLEN - strlen(message));
78 			}
79 
SSLError(const char * method_,SSL * ssl,int ret)80 			SSLError(const char *method_, SSL *ssl, int ret) :
81 				Error(method_, (char *)NULL)
82 			{
83 				const char *errorString = NULL;
84 
85 				switch(SSL_get_error(ssl, ret))
86 				{
87 					case SSL_ERROR_NONE:
88 						errorString = "SSL_ERROR_NONE";  break;
89 					case SSL_ERROR_ZERO_RETURN:
90 						errorString = "SSL_ERROR_ZERO_RETURN";  break;
91 					case SSL_ERROR_WANT_READ:
92 						errorString = "SSL_ERROR_WANT_READ";  break;
93 					case SSL_ERROR_WANT_WRITE:
94 						errorString = "SSL_ERROR_WANT_WRITE";  break;
95 					case SSL_ERROR_WANT_CONNECT:
96 						errorString = "SSL_ERROR_WANT_CONNECT";  break;
97 					#ifdef SSL_ERROR_WANT_ACCEPT
98 					case SSL_ERROR_WANT_ACCEPT:
99 						errorString = "SSL_ERROR_WANT_ACCEPT";  break;
100 					#endif
101 					case SSL_ERROR_WANT_X509_LOOKUP:
102 						errorString = "SSL_ERROR_WANT_X509_LOOKUP";  break;
103 					case SSL_ERROR_SYSCALL:
104 						#ifdef _WIN32
105 						if(ret == -1)
106 						{
107 							if(!FormatMessage(FORMAT_MESSAGE_FROM_SYSTEM, NULL,
108 								WSAGetLastError(), MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT),
109 								message, MLEN, NULL))
110 								strncpy(message, "Error in FormatMessage()", MLEN);
111 							return;
112 						}
113 						#else
114 						if(ret == -1) errorString = strerror(errno);
115 						#endif
116 						else if(ret == 0)
117 							errorString = "SSL_ERROR_SYSCALL (abnormal termination)";
118 						else errorString = "SSL_ERROR_SYSCALL";
119 						break;
120 					case SSL_ERROR_SSL:
121 						ERR_error_string_n(ERR_get_error(), message, MLEN);  return;
122 				}
123 				strncpy(message, errorString, MLEN);
124 			}
125 	};
126 }
127 
// Throw an SSLError tagged with the name of the enclosing function and the
// current source line.  Only available when USESSL is defined.  SSLError is
// referenced without qualification, so the name must be visible wherever this
// macro is expanded.
#define THROW_SSL()  throw(SSLError(__FUNCTION__, __LINE__))
129 
130 #endif  // USESSL
131 
132 
// Non-Windows builds do not include <winsock2.h>, so define SOCKET as a plain
// int there.  This lets socket descriptors be declared with the same type name
// on every platform.
#ifndef _WIN32
typedef int SOCKET;
#endif
136 
137 namespace vglutil
138 {
139 	class Socket
140 	{
141 		public:
142 
143 			Socket(bool doSSL, bool ipv6);
144 			#ifdef USESSL
145 			Socket(SOCKET sd, SSL *ssl);
146 			#else
147 			Socket(SOCKET sd);
148 			#endif
149 			~Socket(void);
150 			void close(void);
151 			void connect(char *serverName, unsigned short port);
152 			unsigned short findPort(void);
153 			unsigned short listen(unsigned short port, bool reuseAddr = false);
154 			Socket *accept(void);
155 			void send(char *buf, int len);
156 			void recv(char *buf, int len);
157 			const char *remoteName(void);
158 
159 		private:
160 
161 			unsigned short setupListener(unsigned short port, bool reuseAddr);
162 
163 			#ifdef USESSL
164 
165 			#if OPENSSL_VERSION_NUMBER < 0x10100000L
lockingCallback(int mode,int type,const char * file,int line)166 			static void lockingCallback(int mode, int type, const char *file,
167 				int line)
168 			{
169 				if(mode & CRYPTO_LOCK) cryptoLock[type].lock();
170 				else cryptoLock[type].unlock();
171 			}
172 			#endif
173 
174 			static bool sslInit;
175 			#if OPENSSL_VERSION_NUMBER < 0x10100000L
176 			static CriticalSection cryptoLock[CRYPTO_NUM_LOCKS];
177 			#endif
178 			bool doSSL;  SSL_CTX *sslctx;  SSL *ssl;
179 
180 			#endif
181 
182 			static const int MAXCONN = 1024;
183 			static int instanceCount;
184 			static CriticalSection mutex;
185 			SOCKET sd;
186 			char remoteNameBuf[INET6_ADDRSTRLEN];
187 			bool ipv6;
188 	};
189 }
190 
191 #endif  // __SOCKET_H__
192