1 #include "Core.h"
2 
3 #ifdef PLATFORM_WIN32
4 #include <winsock2.h>
5 #include <ws2tcpip.h>
6 #endif
7 
8 #ifdef PLATFORM_POSIX
9 #include <arpa/inet.h>
10 #endif
11 
12 namespace Upp {
13 
14 #ifdef PLATFORM_WIN32
15 #pragma comment(lib, "ws2_32.lib")
16 #endif
17 
// Trace hook for this file: expands to nothing; the trailing comment shows the
// LOG form it stands in for.
#define LLOG(x)  // LOG("TCP " << x)

// Storage for the static pool of resolver entries scanned by IpAddrInfo::Start().
IpAddrInfo::Entry IpAddrInfo::pool[IpAddrInfo::COUNT];

// Mutex entered and left by IpAddrInfo::EnterPool() / IpAddrInfo::LeavePool().
Mutex IpAddrInfoPoolMutex;
23 
// Enters IpAddrInfoPoolMutex; callers pair this with LeavePool().
void IpAddrInfo::EnterPool()
{
	IpAddrInfoPoolMutex.Enter();
}
28 
// Leaves IpAddrInfoPoolMutex; counterpart of EnterPool().
void IpAddrInfo::LeavePool()
{
	IpAddrInfoPoolMutex.Leave();
}
33 
// Resolves host/port into a TCP stream address list through getaddrinfo().
//   host   - name or numeric address; NULL or empty yields EAI_NONAME without a lookup
//   port   - service string handed to getaddrinfo() unchanged
//   family - index into { AF_UNSPEC, AF_INET, AF_INET6 }; any value outside 1..2
//            selects AF_UNSPEC
//   result - receives the list produced by getaddrinfo()
// Returns EAI_NONAME for a missing host, otherwise getaddrinfo()'s return code.
int sGetAddrInfo(const char *host, const char *port, int family, addrinfo **result)
{
	const bool has_host = host && *host;
	if(!has_host)
		return EAI_NONAME;

	const static int FamilyToAF[] = { AF_UNSPEC, AF_INET, AF_INET6 };
	int index = 0;
	if(family > 0 && family < __countof(FamilyToAF))
		index = family;

	addrinfo hints;
	memset(&hints, 0, sizeof(addrinfo));
	hints.ai_protocol = IPPROTO_TCP;
	hints.ai_socktype = SOCK_STREAM;
	hints.ai_family = FamilyToAF[index];
	return getaddrinfo(host, port, &hints, result);
}
47 
Thread(void * ptr)48 auxthread_t auxthread__ IpAddrInfo::Thread(void *ptr)
49 {
50 	Entry *entry = (Entry *)ptr;
51 	EnterPool();
52 	if(entry->status == WORKING) {
53 		char host[1025];
54 		char port[257];
55 		int family = entry->family;
56 		strcpy(host, entry->host);
57 		strcpy(port, entry->port);
58 		LeavePool();
59 		addrinfo *result;
60 		if(sGetAddrInfo(host, port, family, &result) == 0 && result) {
61 			EnterPool();
62 			if(entry->status == WORKING) {
63 				entry->addr = result;
64 				entry->status = RESOLVED;
65 			}
66 			else {
67 				freeaddrinfo(result);
68 				entry->status = EMPTY;
69 			}
70 		}
71 		else {
72 			EnterPool();
73 			if(entry->status == CANCELED)
74 				entry->status = EMPTY;
75 			else
76 				entry->status = FAILED;
77 		}
78 	}
79 	LeavePool();
80 	return 0;
81 }
82 
Execute(const String & host,int port,int family)83 bool IpAddrInfo::Execute(const String& host, int port, int family)
84 {
85 	Clear();
86 	entry = exe;
87 	addrinfo *result;
88 	entry->addr = sGetAddrInfo(~host, ~AsString(port), family, &result) == 0 ? result : NULL;
89 	entry->status = entry->addr ? RESOLVED : FAILED;
90 	return entry->addr;
91 }
92 
Start()93 void IpAddrInfo::Start()
94 {
95 	if(entry)
96 		return;
97 	EnterPool();
98 	for(int i = 0; i < COUNT; i++) {
99 		Entry *e = pool + i;
100 		if(e->status == EMPTY) {
101 			entry = e;
102 			e->addr = NULL;
103 			if(host.GetCount() > 1024 || port.GetCount() > 256)
104 				e->status = FAILED;
105 			else {
106 				e->status = WORKING;
107 				e->host = host;
108 				e->port = port;
109 				e->family = family;
110 				StartAuxThread(&IpAddrInfo::Thread, e);
111 			}
112 			break;
113 		}
114 	}
115 	LeavePool();
116 }
117 
Start(const String & host_,int port_,int family_)118 void IpAddrInfo::Start(const String& host_, int port_, int family_)
119 {
120 	Clear();
121 	port = AsString(port_);
122 	host = host_;
123 	family = family_;
124 	Start();
125 }
126 
InProgress()127 bool IpAddrInfo::InProgress()
128 {
129 	if(!entry) {
130 		Start();
131 		return true;
132 	}
133 	EnterPool();
134 	int s = entry->status;
135 	LeavePool();
136 	return s == WORKING;
137 }
138 
GetResult() const139 addrinfo *IpAddrInfo::GetResult() const
140 {
141 	EnterPool();
142 	addrinfo *ai = entry ? entry->addr : NULL;
143 	LeavePool();
144 	return ai;
145 }
146 
Clear()147 void IpAddrInfo::Clear()
148 {
149 	EnterPool();
150 	if(entry) {
151 		if(entry->status == RESOLVED && entry->addr)
152 			freeaddrinfo(entry->addr);
153 		if(entry->status == WORKING)
154 			entry->status = CANCELED;
155 		else
156 			entry->status = EMPTY;
157 		entry = NULL;
158 	}
159 	LeavePool();
160 }
161 
IpAddrInfo()162 IpAddrInfo::IpAddrInfo()
163 {
164 	TcpSocket::Init();
165 	entry = NULL;
166 }
167 
168 #ifdef PLATFORM_POSIX
169 
// On POSIX the error constant is used as written (e.g. SOCKERR(EINTR) -> EINTR).
#define SOCKERR(x) x
171 
// POSIX variant: returns the strerror() text for the given error code.
const char *TcpSocketErrorDesc(int code)
{
	const char *text = strerror(code);
	return text;
}
176 
// POSIX variant: returns the current value of errno.
int TcpSocket::GetErrorCode()
{
	return errno;
}
181 
182 #else
183 
// Non-POSIX branch: pastes the WSA prefix onto the constant (e.g. SOCKERR(EINTR) -> WSAEINTR).
#define SOCKERR(x) WSA##x
185 
TcpSocketErrorDesc(int code)186 const char *TcpSocketErrorDesc(int code)
187 {
188 	static Tuple<int, const char *> err[] = {
189 		{ WSAEINTR,                 "Interrupted function call." },
190 		{ WSAEACCES,                "Permission denied." },
191 		{ WSAEFAULT,                "Bad address." },
192 		{ WSAEINVAL,                "Invalid argument." },
193 		{ WSAEMFILE,                "Too many open files." },
194 		{ WSAEWOULDBLOCK,           "Resource temporarily unavailable." },
195 		{ WSAEINPROGRESS,           "Operation now in progress." },
196 		{ WSAEALREADY,              "Operation already in progress." },
197 		{ WSAENOTSOCK,              "TcpSocket operation on nonsocket." },
198 		{ WSAEDESTADDRREQ,          "Destination address required." },
199 		{ WSAEMSGSIZE,              "Message too long." },
200 		{ WSAEPROTOTYPE,            "Protocol wrong type for socket." },
201 		{ WSAENOPROTOOPT,           "Bad protocol option." },
202 		{ WSAEPROTONOSUPPORT,       "Protocol not supported." },
203 		{ WSAESOCKTNOSUPPORT,       "TcpSocket type not supported." },
204 		{ WSAEOPNOTSUPP,            "Operation not supported." },
205 		{ WSAEPFNOSUPPORT,          "Protocol family not supported." },
206 		{ WSAEAFNOSUPPORT,          "Address family not supported by protocol family." },
207 		{ WSAEADDRINUSE,            "Address already in use." },
208 		{ WSAEADDRNOTAVAIL,         "Cannot assign requested address." },
209 		{ WSAENETDOWN,              "Network is down." },
210 		{ WSAENETUNREACH,           "Network is unreachable." },
211 		{ WSAENETRESET,             "Network dropped connection on reset." },
212 		{ WSAECONNABORTED,          "Software caused connection abort." },
213 		{ WSAECONNRESET,            "Connection reset by peer." },
214 		{ WSAENOBUFS,               "No buffer space available." },
215 		{ WSAEISCONN,               "TcpSocket is already connected." },
216 		{ WSAENOTCONN,              "TcpSocket is not connected." },
217 		{ WSAESHUTDOWN,             "Cannot send after socket shutdown." },
218 		{ WSAETIMEDOUT,             "Connection timed out." },
219 		{ WSAECONNREFUSED,          "Connection refused." },
220 		{ WSAEHOSTDOWN,             "Host is down." },
221 		{ WSAEHOSTUNREACH,          "No route to host." },
222 		{ WSAEPROCLIM,              "Too many processes." },
223 		{ WSASYSNOTREADY,           "Network subsystem is unavailable." },
224 		{ WSAVERNOTSUPPORTED,       "Winsock.dll version out of range." },
225 		{ WSANOTINITIALISED,        "Successful WSAStartup not yet performed." },
226 		{ WSAEDISCON,               "Graceful shutdown in progress." },
227 		{ WSATYPE_NOT_FOUND,        "Class type not found." },
228 		{ WSAHOST_NOT_FOUND,        "Host not found." },
229 		{ WSATRY_AGAIN,             "Nonauthoritative host not found." },
230 		{ WSANO_RECOVERY,           "This is a nonrecoverable error." },
231 		{ WSANO_DATA,               "Valid name, no data record of requested type." },
232 		{ WSASYSCALLFAILURE,        "System call failure." },
233 	};
234 	const Tuple<int, const char *> *x = FindTuple(err, __countof(err), code);
235 	return x ? x->b : "Unknown error code.";
236 }
237 
// Winsock variant: returns WSAGetLastError().
int TcpSocket::GetErrorCode()
{
	return WSAGetLastError();
}
242 
243 #endif
244 
// Per-platform socket-layer setup.
// POSIX: SIGPIPE is set to be ignored (done on every call).
// Win32: WSAStartup for Winsock 2.2 runs inside ONCELOCK; its result is not checked.
void TcpSocketInit()
{
#if defined(PLATFORM_POSIX)
	signal(SIGPIPE, SIG_IGN);
#endif
#if defined(PLATFORM_WIN32)
	ONCELOCK {
		WSADATA winsock_info;
		WSAStartup(MAKEWORD(2, 2), &winsock_info);
	}
#endif
}
257 
// Class-scoped alias for TcpSocketInit().
void TcpSocket::Init()
{
	TcpSocketInit();
}
262 
Reset()263 void TcpSocket::Reset()
264 {
265 	LLOG("Reset");
266 	is_eof = false;
267 	socket = INVALID_SOCKET;
268 	ipv6 = false;
269 	ptr = end = buffer;
270 	is_error = false;
271 	is_abort = false;
272 	is_timeout = false;
273 	mode = NONE;
274 	ssl.Clear();
275 	sslinfo.Clear();
276 	cert = pkey = Null;
277 #if defined(PLATFORM_WIN32) || defined(PLATFORM_BSD)
278 	connection_start = Null;
279 #endif
280 	ssl_start = Null;
281 }
282 
TcpSocket()283 TcpSocket::TcpSocket()
284 {
285 	ClearError();
286 	Reset();
287 	timeout = global_timeout = start_time = Null;
288 	waitstep = 10;
289 	asn1 = false;
290 }
291 
SetupSocket()292 bool TcpSocket::SetupSocket()
293 {
294 #ifdef PLATFORM_WIN32
295 	connection_start = msecs();
296 	u_long arg = 1;
297 	if(ioctlsocket(socket, FIONBIO, &arg)) {
298 		SetSockError("ioctlsocket(FIO[N]BIO)");
299 		return false;
300 	}
301 #else
302 	#ifdef PLATFORM_BSD
303 		connection_start = msecs();
304 	#endif
305 	if(fcntl(socket, F_SETFL, (fcntl(socket, F_GETFL, 0) | O_NONBLOCK))) {
306 		SetSockError("fcntl(O_[NON]BLOCK)");
307 		return false;
308 	}
309 #endif
310 	return true;
311 }
312 
Open(int family,int type,int protocol)313 bool TcpSocket::Open(int family, int type, int protocol)
314 {
315 	Init();
316 	Close();
317 	ClearError();
318 	if((socket = ::socket(family, type, protocol)) == INVALID_SOCKET) {
319 		SetSockError("open");
320 		return false;
321 	}
322 	LLOG("TcpSocket::Data::Open() -> " << (int)socket);
323 	return SetupSocket();
324 }
325 
Listen(int port,int listen_count,bool ipv6_,bool reuse,void * addr)326 bool TcpSocket::Listen(int port, int listen_count, bool ipv6_, bool reuse, void *addr)
327 {
328 	Close();
329 	Init();
330 	Reset();
331 
332 	ipv6 = ipv6_;
333 	if(!Open(ipv6 ? AF_INET6 : AF_INET, SOCK_STREAM, 0))
334 		return false;
335 	sockaddr_in sin;
336 #ifdef PLATFORM_WIN32
337 	SOCKADDR_IN6 sin6;
338 	if(ipv6 && IsWinVista())
339 #else
340 	sockaddr_in6 sin6;
341 	if(ipv6)
342 #endif
343 	{
344 		Zero(sin6);
345 		sin6.sin6_family = AF_INET6;
346 		sin6.sin6_port = htons(port);
347 		sin6.sin6_addr = addr?(*(in6_addr*)addr):in6addr_any;
348 	}
349 	else {
350 		Zero(sin);
351 		sin.sin_family = AF_INET;
352 		sin.sin_port = htons(port);
353 		sin.sin_addr.s_addr = addr?(*(uint32*)addr):htonl(INADDR_ANY);
354 	}
355 	if(reuse) {
356 		int optval = 1;
357 		setsockopt(socket, SOL_SOCKET, SO_REUSEADDR, (const char *)&optval, sizeof(optval));
358 	}
359 	if(bind(socket, ipv6 ? (const sockaddr *)&sin6 : (const sockaddr *)&sin,
360 	        ipv6 ? sizeof(sin6) : sizeof(sin))) {
361 		SetSockError(Format("bind(port=%d)", port));
362 		return false;
363 	}
364 	if(listen(socket, listen_count)) {
365 		SetSockError(Format("listen(port=%d, count=%d)", port, listen_count));
366 		return false;
367 	}
368 	return true;
369 }
370 
Listen(const IpAddrInfo & addr,int port,int listen_count,bool ipv6,bool reuse)371 bool TcpSocket::Listen(const IpAddrInfo& addr, int port, int listen_count, bool ipv6, bool reuse)
372 {
373 	addrinfo *a = addr.GetResult();
374 	return Listen(port, listen_count, ipv6, reuse, &(((sockaddr_in*)a->ai_addr)->sin_addr.s_addr));
375 }
376 
Accept(TcpSocket & ls)377 bool TcpSocket::Accept(TcpSocket& ls)
378 {
379 	Close();
380 	Init();
381 	Reset();
382 	ASSERT(ls.IsOpen());
383 	int et = GetEndTime();
384 	for(;;) {
385 		int h = ls.GetTimeout();
386 		bool b = ls.Timeout(timeout).Wait(WAIT_READ, et);
387 		ls.Timeout(h);
388 		if(!b) // timeout
389 			return false;
390 		socket = accept(ls.GetSOCKET(), NULL, NULL);
391 		if(socket != INVALID_SOCKET)
392 			break;
393 		if(!WouldBlock() && GetErrorCode() != SOCKERR(EINTR)) { // In prefork condition, Wait is not enough, as other process can accept
394 			SetSockError("accept");
395 			return false;
396 		}
397 	}
398 	mode = ACCEPT;
399 	return SetupSocket();
400 }
401 
GetPeerAddr() const402 String TcpSocket::GetPeerAddr() const
403 {
404 	if(!IsOpen())
405 		return Null;
406 	sockaddr_in addr;
407 	socklen_t l = sizeof(addr);
408 	if(getpeername(socket, (sockaddr *)&addr, &l) != 0)
409 		return Null;
410 	if(l > sizeof(addr))
411 		return Null;
412 #ifdef PLATFORM_WIN32
413 	return inet_ntoa(addr.sin_addr);
414 #else
415 	char h[200];
416 	return inet_ntop(AF_INET, &addr.sin_addr, h, 200);
417 #endif
418 }
419 
NoDelay()420 void TcpSocket::NoDelay()
421 {
422 	ASSERT(IsOpen());
423 	int __true = 1;
424 	LLOG("NoDelay(" << (int)socket << ")");
425 	if(setsockopt(socket, IPPROTO_TCP, TCP_NODELAY, (const char *)&__true, sizeof(__true)))
426 		SetSockError("setsockopt(TCP_NODELAY)");
427 }
428 
Linger(int msecs)429 void TcpSocket::Linger(int msecs)
430 {
431 	ASSERT(IsOpen());
432 	linger ls;
433 	ls.l_onoff = !IsNull(msecs) ? 1 : 0;
434 	ls.l_linger = !IsNull(msecs) ? (msecs + 999) / 1000 : 0;
435 	if(setsockopt(socket, SOL_SOCKET, SO_LINGER, reinterpret_cast<const char *>(&ls), sizeof(ls)))
436 		SetSockError("setsockopt(SO_LINGER)");
437 }
438 
// Closes the current connection (if any) and adopts the externally created
// handle `s`. No other state (mode, blocking mode, flags) is touched here.
void TcpSocket::Attach(SOCKET s)
{
	Close();
	socket = s;
}
444 
RawConnect(addrinfo * arp)445 bool TcpSocket::RawConnect(addrinfo *arp)
446 {
447 	if(!arp) {
448 		SetSockError("connect", -1, "not found");
449 		return false;
450 	}
451 	String err;
452 	for(int pass = 0; pass < 2; pass++) {
453 		addrinfo *rp = arp;
454 		while(rp) {
455 			if(rp->ai_family == AF_INET == !pass && // Try to connect IPv4 in the first pass
456 			   Open(rp->ai_family, rp->ai_socktype, rp->ai_protocol)) {
457 				if(connect(socket, rp->ai_addr, (int)rp->ai_addrlen) == 0 ||
458 				   GetErrorCode() == SOCKERR(EINPROGRESS) || GetErrorCode() == SOCKERR(EWOULDBLOCK)
459 				) {
460 					mode = CONNECT;
461 					return true;
462 				}
463 				if(err.GetCount())
464 					err << '\n';
465 				err << TcpSocketErrorDesc(GetErrorCode());
466 				Close();
467 			}
468 			rp = rp->ai_next;
469 		}
470     }
471 	SetSockError("connect", -1, Nvl(err, "failed"));
472 	return false;
473 }
474 
475 
Connect(IpAddrInfo & info)476 bool TcpSocket::Connect(IpAddrInfo& info)
477 {
478 	LLOG("Connect addrinfo");
479 	Init();
480 	Reset();
481 	addrinfo *result = info.GetResult();
482 	return RawConnect(result);
483 }
484 
// Resolves host:port synchronously and starts connecting to the result.
// A failed lookup is reported as a "getaddrinfo(<host>) failed" socket error.
bool TcpSocket::Connect(const char *host, int port)
{
	LLOG("Connect(" << host << ':' << port << ')');
	Close();
	Init();
	Reset();
	IpAddrInfo info;
	if(info.Execute(host, port))
		return Connect(info);
	SetSockError(Format("getaddrinfo(%s) failed", host));
	return false;
}
498 
WaitConnect()499 bool TcpSocket::WaitConnect()
500 {
501 	if(WaitWrite()) {
502 		int optval = 0;
503 		socklen_t optlen = sizeof(optval);
504 		if (getsockopt(GetSOCKET(), SOL_SOCKET, SO_ERROR, (char*)&optval, &optlen) == 0) {
505 			if (optval == 0)
506 				return true;
507 			else {
508 				SetSockError("wait connect", -1, Nvl(String(TcpSocketErrorDesc(optval)), "failed"));
509 				return false;
510 			}
511 		}
512 	}
513 	return false;
514 }
515 
RawClose()516 void TcpSocket::RawClose()
517 {
518 	LLOG("close " << (int)socket);
519 	if(socket != INVALID_SOCKET) {
520 		int res;
521 #if defined(PLATFORM_WIN32)
522 		res = closesocket(socket);
523 #elif defined(PLATFORM_POSIX)
524 		res = close(socket);
525 #else
526 	#error Unsupported platform
527 #endif
528 		if(res && !IsError())
529 			SetSockError("close");
530 		socket = INVALID_SOCKET;
531 	}
532 }
533 
Close()534 void TcpSocket::Close()
535 {
536 	if(ssl)
537 		ssl->Close();
538 	else
539 		RawClose();
540 	ssl.Clear();
541 }
542 
WouldBlock()543 bool TcpSocket::WouldBlock()
544 {
545 	int c = GetErrorCode();
546 #ifdef PLATFORM_POSIX
547 #ifdef PLATFORM_BSD
548 		if(c == SOCKERR(ENOTCONN) && !IsNull(connection_start) && msecs(connection_start) < 20000)
549 			return true;
550 #endif
551 	return c == SOCKERR(EWOULDBLOCK) || c == SOCKERR(EAGAIN);
552 #endif
553 #ifdef PLATFORM_WIN32
554 	if(c == SOCKERR(ENOTCONN) && !IsNull(connection_start) && msecs(connection_start) < 20000) {
555 		LLOG("ENOTCONN issue");
556 		return true;
557 	}
558 	return c == SOCKERR(EWOULDBLOCK);
559 #endif
560 }
561 
RawSend(const void * buf,int amount)562 int TcpSocket::RawSend(const void *buf, int amount)
563 {
564 	int res = send(socket, (const char *)buf, amount, 0);
565 	if(res < 0 && WouldBlock())
566 		res = 0;
567 	else
568 	if(res == 0 || res < 0)
569 		SetSockError("send");
570 	return res;
571 }
572 
Send(const void * buf,int amount)573 int TcpSocket::Send(const void *buf, int amount)
574 {
575 	if(SSLHandshake())
576 		return 0;
577 	return ssl ? ssl->Send(buf, amount) : RawSend(buf, amount);
578 }
579 
Shutdown()580 void TcpSocket::Shutdown()
581 {
582 	ASSERT(IsOpen());
583 	if(shutdown(socket, SD_SEND))
584 		SetSockError("shutdown(SD_SEND)");
585 }
586 
GetHostName()587 String TcpSocket::GetHostName()
588 {
589 	Init();
590 	char buffer[256];
591 	gethostname(buffer, __countof(buffer));
592 	return buffer;
593 }
594 
IsGlobalTimeout()595 bool TcpSocket::IsGlobalTimeout()
596 {
597 	if(!IsNull(global_timeout) && msecs() - start_time > global_timeout) {
598 		SetSockError("wait", ERROR_GLOBAL_TIMEOUT, "Timeout");
599 		return true;
600 	}
601 	return false;
602 }
603 
// Waits with select() until the socket is ready for the requested events.
//   flags    - WAIT_READ and/or WAIT_WRITE; the exception set is always watched
//   end_time - deadline on the msecs() clock; INT_MAX means no deadline
// Returns true when select() reports activity, or immediately when WAIT_READ is
// requested and the read buffer still holds data. Returns false on invalid
// socket, pending error/abort, select() failure, or timeout (is_timeout is set
// for the timeout cases).
bool TcpSocket::RawWait(dword flags, int end_time)
{ // wait till end_time
	LLOG("RawWait end_time: " << end_time << ", current time " << msecs() << ", to wait: " << end_time - msecs());
	is_timeout = false;
	// Buffered input counts as "readable" without touching the OS socket.
	if((flags & WAIT_READ) && ptr != end)
		return true;
	if(socket == INVALID_SOCKET)
		return false;
	for(;;) {
		if(IsError() || IsAbort())
			return false;
		int to = end_time - msecs();
		// With a WhenWait hook installed, select() sleeps only one waitstep at a
		// time so the hook can run between slices.
		if(WhenWait)
			to = waitstep;
		timeval *tvalp = NULL; // NULL timeout pointer: select() waits without a time limit
		timeval tval;
		if(end_time != INT_MAX || WhenWait) {
			to = max(to, 0);
			tval.tv_sec = to / 1000;
			tval.tv_usec = 1000 * (to % 1000);
			tvalp = &tval;
			LLOG("RawWait timeout: " << to);
		}
		fd_set fdsetr[1], fdsetw[1], fdsetx[1];;
		FD_ZERO(fdsetr);
		if(flags & WAIT_READ)
			FD_SET(socket, fdsetr);
		FD_ZERO(fdsetw);
		if(flags & WAIT_WRITE)
			FD_SET(socket, fdsetw);
		FD_ZERO(fdsetx);
		FD_SET(socket, fdsetx);
		int avail = select((int)socket + 1, fdsetr, fdsetw, fdsetx, tvalp);
		LLOG("Wait select avail: " << avail);
		// An interrupted select() (EINTR) is not an error; the loop simply retries.
		if(avail < 0 && GetErrorCode() != SOCKERR(EINTR)) {
			SetSockError("wait");
			return false;
		}
		if(avail > 0) {
		#if defined(PLATFORM_WIN32) || defined(PLATFORM_BSD)
			// Activity seen: stop applying the ENOTCONN grace period.
			connection_start = Null;
		#endif
			return true;
		}
		// Nothing ready: give up on global timeout, or when this slice's timeout
		// was exhausted and the per-operation timeout is non-zero.
		if(IsGlobalTimeout() || to <= 0 && timeout) {
			is_timeout = true;
			return false;
		}
		WhenWait();
		// timeout == 0: a single poll only.
		if(timeout == 0) {
			is_timeout = true;
			return false;
		}
	}
}
659 
GlobalTimeout(int ms)660 TcpSocket& TcpSocket::GlobalTimeout(int ms)
661 {
662 	start_time = msecs();
663 	global_timeout = ms;
664 	return *this;
665 }
666 
Wait(dword flags,int end_time)667 bool TcpSocket::Wait(dword flags, int end_time)
668 {
669 	return ssl ? ssl->Wait(flags, end_time) : RawWait(flags, end_time);
670 }
671 
GetEndTime() const672 int  TcpSocket::GetEndTime() const
673 { // Compute time limit for operation, based on global timeout and per-operation timeout settings
674 	int o = min(IsNull(global_timeout) ? INT_MAX : start_time + global_timeout,
675 	            IsNull(timeout) ? INT_MAX : msecs() + timeout);
676 #if defined(PLATFORM_WIN32) || defined(PLATFORM_BSD)
677 	if(GetErrorCode() == SOCKERR(ENOTCONN) && !IsNull(connection_start))
678 		if(msecs(connection_start) < 20000)
679 			o = connection_start + 20000;
680 #endif
681 	return o;
682 }
683 
// Waits for the given events using the deadline computed by GetEndTime().
bool TcpSocket::Wait(dword flags)
{
	return Wait(flags, GetEndTime());
}
688 
// Writes up to `length` bytes from `s_`, retrying until everything is sent, an
// error occurs, or waiting for writability fails.
//   s_     - data to send; NULL sends nothing
//   length - byte count; a negative value means "use strlen(s_)"
// Returns the number of bytes actually sent (also kept in the `done` member).
int TcpSocket::Put(const void *s_, int length)
{
	LLOG("Put " << socket << ": " << length);
	ASSERT(IsOpen());
	const char *s = (const char *)s_;
	if(length < 0 && s)
		length = (int)strlen(s);
	if(!s || length <= 0 || IsError() || IsAbort())
		return 0;
	done = 0;
	// `peek` is set after a Send() that made no progress; the next iteration then
	// waits for writability before sending again.
	bool peek = false;
	int end_time = GetEndTime();
	while(done < length) {
		if(peek && !Wait(WAIT_WRITE, end_time))
			return done;
		peek = false;
		int count = Send(s + done, length - done);
		// NOTE(review): `peek` was reset to false just above, so the
		// `timeout == 0 && count == 0 && peek` term cannot be true here; only
		// IsError() can trigger this return.
		if(IsError() || timeout == 0 && count == 0 && peek)
			return done;
		if(count > 0)
			done += count;
		else
			peek = true;
	}
	LLOG("//Put() -> " << done);
	return done;
}
716 
PutAll(const void * s,int len)717 bool TcpSocket::PutAll(const void *s, int len)
718 {
719 	if(Put(s, len) != len) {
720 		if(!IsError())
721 			SetSockError("GePutAll", -1, "timeout");
722 		return false;
723 	}
724 	return true;
725 }
726 
PutAll(const String & s)727 bool TcpSocket::PutAll(const String& s)
728 {
729 	if(Put(s) != s.GetCount()) {
730 		if(!IsError())
731 			SetSockError("GePutAll", -1, "timeout");
732 		return false;
733 	}
734 	return true;
735 }
736 
RawRecv(void * buf,int amount)737 int TcpSocket::RawRecv(void *buf, int amount)
738 {
739 	int res = recv(socket, (char *)buf, amount, 0);
740 	if(res == 0)
741 		is_eof = true;
742 	else
743 	if(res < 0 && WouldBlock())
744 		res = 0;
745 	else
746 	if(res < 0)
747 		SetSockError("recv");
748 	LLOG("recv(" << socket << "): " << res << " bytes: "
749 	     << AsCString((char *)buf, (char *)buf + min(res, 16))
750 	     << (res ? "" : IsEof() ? ", EOF" : ", WOULDBLOCK"));
751 	return res;
752 }
753 
Recv(void * buffer,int maxlen)754 int TcpSocket::Recv(void *buffer, int maxlen)
755 {
756 	if(SSLHandshake())
757 		return 0;
758 	return ssl ? ssl->Recv(buffer, maxlen) : RawRecv(buffer, maxlen);
759 }
760 
ReadBuffer(int end_time)761 void TcpSocket::ReadBuffer(int end_time)
762 {
763 	ptr = end = buffer;
764 	if(Wait(WAIT_READ, end_time))
765 		end = buffer + Recv(buffer, BUFFERSIZE);
766 }
767 
IsEof() const768 bool TcpSocket::IsEof() const
769 {
770 	return is_eof && ptr == end || IsAbort() || !IsOpen() || IsError();
771 }
772 
Get_()773 int TcpSocket::Get_()
774 {
775 	if(!IsOpen() || IsError() || IsEof() || IsAbort())
776 		return -1;
777 	ReadBuffer(GetEndTime());
778 	return ptr < end ? (byte)*ptr++ : -1;
779 }
780 
Peek_(int end_time)781 int TcpSocket::Peek_(int end_time)
782 {
783 	if(!IsOpen() || IsError() || IsEof() || IsAbort())
784 		return -1;
785 	ReadBuffer(end_time);
786 	return ptr < end ? (byte)*ptr : -1;
787 }
788 
// Peek_(end_time) with the deadline computed by GetEndTime().
int TcpSocket::Peek_()
{
	return Peek_(GetEndTime());
}
793 
Get(void * buffer,int count)794 int TcpSocket::Get(void *buffer, int count)
795 {
796 	LLOG("Get " << count);
797 
798 	if(!IsOpen() || IsError() || IsEof() || IsAbort())
799 		return 0;
800 
801 	int l = (int)(end - ptr);
802 	done = 0;
803 	if(l > 0) {
804 		if(l < count) {
805 			memcpy(buffer, ptr, l);
806 			done += l;
807 			ptr = end;
808 		}
809 		else {
810 			memcpy(buffer, ptr, count);
811 			ptr += count;
812 			return count;
813 		}
814 	}
815 	int end_time = GetEndTime();
816 	while(done < count && !IsError() && !IsEof()) {
817 		if(!Wait(WAIT_READ, end_time))
818 			break;
819 		int part = Recv((char *)buffer + done, count - done);
820 		if(part > 0)
821 			done += part;
822 		if(timeout == 0)
823 			break;
824 	}
825 	return done;
826 }
827 
Get(int count)828 String TcpSocket::Get(int count)
829 {
830 	if(count == 0)
831 		return Null;
832 	StringBuffer out(count);
833 	int done = Get(out, count);
834 	if(!done && IsEof())
835 		return String::GetVoid();
836 	out.SetLength(done);
837 	return String(out);
838 }
839 
GetAll(void * buffer,int len)840 bool  TcpSocket::GetAll(void *buffer, int len)
841 {
842 	if(Get(buffer, len) == len)
843 		return true;
844 	if(!IsError())
845 		SetSockError("GetAll", -1, "timeout");
846 	return false;
847 }
848 
GetAll(int len)849 String TcpSocket::GetAll(int len)
850 {
851 	String s = Get(len);
852 	if(s.GetCount() != len) {
853 		if(!IsError())
854 			SetSockError("GetAll", -1, "timeout");
855 		return String::GetVoid();
856 	}
857 	return s;
858 }
859 
// Reads one line terminated by '\n'.
//   maxlen - maximum number of characters accepted before the terminator
// '\r' characters are dropped and the '\n' is not included in the result.
// Returns a void String on EOF, on error, on timeout ("GetLine" timeout error)
// or when the line is too long ("maximal length exceeded" error).
String TcpSocket::GetLine(int maxlen)
{
	LLOG("GetLine " << maxlen << ", iseof " << IsEof());
	String ln;
	int end_time = GetEndTime();
	for(;;) {
		if(IsEof())
			return String::GetVoid();
		int c = Peek(end_time);
		if(c < 0) {
			// No byte available: report a timeout once the deadline has passed,
			// otherwise try again (unless an error is already pending).
			if(!IsError()) {
				if(msecs() > end_time)
					SetSockError("GetLine", -1, "timeout");
				else
					continue;
			}
			return String::GetVoid();
		}
		Get(); // consume the byte that was just peeked
		if(c == '\n')
			return ln;
		if(ln.GetCount() >= maxlen) {
			if(!IsError())
				SetSockError("GetLine", -1, "maximal length exceeded");
			return String::GetVoid();
		}
		if(c != '\r')
			ln.Cat(c);
	}
}
890 
SetSockError(const char * context,int code,const char * errdesc)891 void TcpSocket::SetSockError(const char *context, int code, const char *errdesc)
892 {
893 	errorcode = code;
894 	errordesc.Clear();
895 	if(socket != INVALID_SOCKET)
896 		errordesc << "socket(" << (int)socket << ") / ";
897 	errordesc << context << ": " << errdesc;
898 	is_error = true;
899 	LLOG("ERROR " << errordesc);
900 }
901 
// Records an error with the given description, using GetErrorCode() as the code.
void TcpSocket::SetSockError(const char *context, const char *errdesc)
{
	SetSockError(context, GetErrorCode(), errdesc);
}
906 
// Records an error for `context`, with the description looked up from the
// current GetErrorCode() via TcpSocketErrorDesc().
void TcpSocket::SetSockError(const char *context)
{
	SetSockError(context, TcpSocketErrorDesc(GetErrorCode()));
}
911 
// Factory hook for the SSL layer. It is not assigned in this file; StartSSL()
// reports "Missing SSL support (Core/SSL)" while it is unset.
TcpSocket::SSL *(*TcpSocket::CreateSSL)(TcpSocket& socket);
913 
StartSSL()914 bool TcpSocket::StartSSL()
915 {
916 	ASSERT(IsOpen());
917 	if(!CreateSSL) {
918 		SetSockError("StartSSL", -1, "Missing SSL support (Core/SSL)");
919 		return false;
920 	}
921 	if(!IsOpen()) {
922 		SetSockError("StartSSL", -1, "Socket is not open");
923 		return false;
924 	}
925 	if(mode != CONNECT && mode != ACCEPT) {
926 		SetSockError("StartSSL", -1, "Socket is not connected");
927 		return false;
928 	}
929 	ssl = (*CreateSSL)(*this);
930 	if(!ssl->Start()) {
931 		ssl.Clear();
932 		return false;
933 	}
934 	ssl_start = msecs();
935 	SSLHandshake();
936 	return true;
937 }
938 
SSLHandshake()939 dword TcpSocket::SSLHandshake()
940 {
941 	if(ssl && (mode == CONNECT || mode == ACCEPT)) {
942 		dword w = ssl->Handshake();
943 		if(w) {
944 			if(msecs(ssl_start) > 20000) {
945 				SetSockError("ssl handshake", ERROR_SSLHANDSHAKE_TIMEOUT, "Timeout");
946 				return false;
947 			}
948 			if(IsGlobalTimeout())
949 				return false;
950 			Wait(w);
951 			return ssl->Handshake();
952 		}
953 	}
954 	return 0;
955 }
956 
SSLCertificate(const String & cert_,const String & pkey_,bool asn1_)957 void TcpSocket::SSLCertificate(const String& cert_, const String& pkey_, bool asn1_)
958 {
959 	cert = cert_;
960 	pkey = pkey_;
961 	asn1 = asn1_;
962 }
963 
// Stores `name` in the `sni` member; nothing else happens here.
void TcpSocket::SSLServerNameIndication(const String& name)
{
	sni = name;
}
968 
// Clears the error state, closes the connection if it is open, and resets the
// object to its unconnected state.
void TcpSocket::Clear()
{
	ClearError();
	if(IsOpen())
		Close();
	Reset();
}
976 
Wait(int timeout)977 int SocketWaitEvent::Wait(int timeout)
978 {
979 	FD_ZERO(read);
980 	FD_ZERO(write);
981 	FD_ZERO(exception);
982 	int maxindex = -1;
983 	for(int i = 0; i < socket.GetCount(); i++) {
984 		const Tuple<int, dword>& s = socket[i];
985 		if(s.a >= 0) {
986 			const Tuple<int, dword>& s = socket[i];
987 			if(s.b & WAIT_READ)
988 				FD_SET(s.a, read);
989 			if(s.b & WAIT_WRITE)
990 				FD_SET(s.a, write);
991 			FD_SET(s.a, exception);
992 			maxindex = max(s.a, maxindex);
993 		}
994 	}
995 	timeval *tvalp = NULL;
996 	timeval tval;
997 	if(!IsNull(timeout)) {
998 		tval.tv_sec = timeout / 1000;
999 		tval.tv_usec = 1000 * (timeout % 1000);
1000 		tvalp = &tval;
1001 	}
1002 	return select(maxindex + 1, read, write, exception, tvalp);
1003 }
1004 
Get(int i) const1005 dword SocketWaitEvent::Get(int i) const
1006 {
1007 	int s = socket[i].a;
1008 	if(s < 0)
1009 		return 0;
1010 	dword events = 0;
1011 	if(FD_ISSET(s, read))
1012 		events |= WAIT_READ;
1013 	if(FD_ISSET(s, write))
1014 		events |= WAIT_WRITE;
1015 	if(FD_ISSET(s, exception))
1016 		events |= WAIT_IS_EXCEPTION;
1017 	return events;
1018 }
1019 
// Starts with all three descriptor sets empty, so Get() reports no events
// before the first Wait().
SocketWaitEvent::SocketWaitEvent()
{
	FD_ZERO(read);
	FD_ZERO(write);
	FD_ZERO(exception);
}
1026 
1027 }
1028