1 /* <!-- copyright */
2 /*
3 * aria2 - The high speed download utility
4 *
5 * Copyright (C) 2006 Tatsuhiro Tsujikawa
6 *
7 * This program is free software; you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation; either version 2 of the License, or
10 * (at your option) any later version.
11 *
12 * This program is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with this program; if not, write to the Free Software
19 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
20 *
21 * In addition, as a special exception, the copyright holders give
22 * permission to link the code of portions of this program with the
23 * OpenSSL library under certain conditions as described in each
24 * individual source file, and distribute linked combinations
25 * including the two.
26 * You must obey the GNU General Public License in all respects
27 * for all of the code used other than OpenSSL. If you modify
28 * file(s) with this exception, you may extend this exception to your
29 * version of the file(s), but you are not obligated to do so. If you
30 * do not wish to do so, delete this exception statement from your
31 * version. If you delete this exception statement from all source
32 * files in the program, then also delete it here.
33 */
34 /* copyright --> */
35 #include "SocketCore.h"
36
37 #ifdef HAVE_IPHLPAPI_H
38 # include <iphlpapi.h>
39 #endif // HAVE_IPHLPAPI_H
40
41 #include <unistd.h>
42 #ifdef HAVE_IFADDRS_H
43 # include <ifaddrs.h>
44 #endif // HAVE_IFADDRS_H
45
46 #include <cerrno>
47 #include <cstring>
48 #include <cassert>
49 #include <sstream>
50 #include <array>
51
52 #include "message.h"
53 #include "DlRetryEx.h"
54 #include "DlAbortEx.h"
55 #include "fmt.h"
56 #include "util.h"
57 #include "TimeA2.h"
58 #include "a2functional.h"
59 #include "LogFactory.h"
60 #include "A2STR.h"
61 #ifdef ENABLE_SSL
62 # include "TLSContext.h"
63 # include "TLSSession.h"
64 #endif // ENABLE_SSL
65 #ifdef HAVE_LIBSSH2
66 # include "SSHSession.h"
67 #endif // HAVE_LIBSSH2
68
69 namespace aria2 {
70
// SOCKET_ERRNO: error code of the most recent failed socket call —
// errno on POSIX, WSAGetLastError() on MinGW/Windows.
#ifndef __MINGW32__
# define SOCKET_ERRNO (errno)
#else
# define SOCKET_ERRNO (WSAGetLastError())
#endif // __MINGW32__

// Platform-neutral names for the socket error codes inspected in this file.
// A2_WOULDBLOCK(e) is true when e means "operation would block"; where
// EWOULDBLOCK and EAGAIN differ, both are accepted. Macro arguments are
// parenthesized so an expression argument is compared as a whole.
#ifdef __MINGW32__
# define A2_EINPROGRESS WSAEWOULDBLOCK
# define A2_EWOULDBLOCK WSAEWOULDBLOCK
# define A2_EINTR WSAEINTR
# define A2_WOULDBLOCK(e) ((e) == WSAEWOULDBLOCK)
#else // !__MINGW32__
# define A2_EINPROGRESS EINPROGRESS
# ifndef EWOULDBLOCK
#  define EWOULDBLOCK EAGAIN
# endif // EWOULDBLOCK
# define A2_EWOULDBLOCK EWOULDBLOCK
# define A2_EINTR EINTR
# if EWOULDBLOCK == EAGAIN
#  define A2_WOULDBLOCK(e) ((e) == EWOULDBLOCK)
# else // EWOULDBLOCK != EAGAIN
#  define A2_WOULDBLOCK(e) ((e) == EWOULDBLOCK || (e) == EAGAIN)
# endif // EWOULDBLOCK != EAGAIN
#endif // !__MINGW32__

// CLOSE(X): closes a socket descriptor — closesocket() on Windows,
// close() elsewhere.
#ifdef __MINGW32__
# define CLOSE(X) ::closesocket(X)
#else
# define CLOSE(X) close(X)
#endif // __MINGW32__
101
102 namespace {
errorMsg(int errNum)103 std::string errorMsg(int errNum)
104 {
105 #ifndef __MINGW32__
106 return util::safeStrerror(errNum);
107 #else
108 auto msg = util::formatLastError(errNum);
109 if (msg.empty()) {
110 char buf[256];
111 snprintf(buf, sizeof(buf), EX_SOCKET_UNKNOWN_ERROR, errNum, errNum);
112 return buf;
113 }
114 return msg;
115 #endif // __MINGW32__
116 }
117 } // namespace
118
namespace {
// Progress of the TLS layer of a SocketCore, kept in secure_.
// A2_TLS_NONE is 0, so "!secure_" reads as "plain socket, no TLS".
enum TlsState {
  A2_TLS_NONE = 0,        // no TLS session has been created
  A2_TLS_HANDSHAKING = 2, // session created, handshake not finished yet
  A2_TLS_CONNECTED = 3    // handshake completed
};
} // namespace
129
130 int SocketCore::protocolFamily_ = AF_UNSPEC;
131 int SocketCore::ipDscp_ = 0;
132
133 std::vector<SockAddr> SocketCore::bindAddrs_;
134 std::vector<std::vector<SockAddr>> SocketCore::bindAddrsList_;
135 std::vector<std::vector<SockAddr>>::iterator SocketCore::bindAddrsListIt_;
136
137 int SocketCore::socketRecvBufferSize_ = 0;
138
#ifdef ENABLE_SSL
// Shared TLS contexts: clTlsContext_ drives client handshakes (tlsConnect),
// svTlsContext_ drives server handshakes (tlsAccept).
std::shared_ptr<TLSContext> SocketCore::clTlsContext_;
std::shared_ptr<TLSContext> SocketCore::svTlsContext_;

// Installs the context that tlsConnect() hands to tlsHandshake().
void SocketCore::setClientTLSContext(
    const std::shared_ptr<TLSContext>& tlsContext)
{
  clTlsContext_ = tlsContext;
}

// Installs the context that tlsAccept() hands to tlsHandshake().
void SocketCore::setServerTLSContext(
    const std::shared_ptr<TLSContext>& tlsContext)
{
  svTlsContext_ = tlsContext;
}
#endif // ENABLE_SSL
155
SocketCore(int sockType)156 SocketCore::SocketCore(int sockType) : sockType_(sockType), sockfd_(-1)
157 {
158 init();
159 }
160
SocketCore(sock_t sockfd,int sockType)161 SocketCore::SocketCore(sock_t sockfd, int sockType)
162 : sockType_(sockType), sockfd_(sockfd)
163 {
164 init();
165 }
166
init()167 void SocketCore::init()
168 {
169 blocking_ = true;
170 secure_ = A2_TLS_NONE;
171
172 wantRead_ = false;
173 wantWrite_ = false;
174 }
175
~SocketCore()176 SocketCore::~SocketCore() { closeConnection(); }
177
178 namespace {
applySocketBufferSize(sock_t fd)179 void applySocketBufferSize(sock_t fd)
180 {
181 auto recvBufSize = SocketCore::getSocketRecvBufferSize();
182 if (recvBufSize == 0) {
183 return;
184 }
185
186 if (setsockopt(fd, SOL_SOCKET, SO_RCVBUF, (a2_sockopt_t)&recvBufSize,
187 sizeof(recvBufSize)) < 0) {
188 auto errNum = SOCKET_ERRNO;
189 A2_LOG_WARN(fmt("Failed to set socket buffer size. Cause: %s",
190 errorMsg(errNum).c_str()));
191 }
192 }
193 } // namespace
194
create(int family,int protocol)195 void SocketCore::create(int family, int protocol)
196 {
197 int errNum;
198 closeConnection();
199 sock_t fd = socket(family, sockType_, protocol);
200 errNum = SOCKET_ERRNO;
201 if (fd == (sock_t)-1) {
202 throw DL_ABORT_EX(
203 fmt("Failed to create socket. Cause:%s", errorMsg(errNum).c_str()));
204 }
205 util::make_fd_cloexec(fd);
206 int sockopt = 1;
207 if (setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, (a2_sockopt_t)&sockopt,
208 sizeof(sockopt)) < 0) {
209 errNum = SOCKET_ERRNO;
210 CLOSE(fd);
211 throw DL_ABORT_EX(
212 fmt("Failed to create socket. Cause:%s", errorMsg(errNum).c_str()));
213 }
214
215 applySocketBufferSize(fd);
216
217 sockfd_ = fd;
218 }
219
bindInternal(int family,int socktype,int protocol,const struct sockaddr * addr,socklen_t addrlen,std::string & error)220 static sock_t bindInternal(int family, int socktype, int protocol,
221 const struct sockaddr* addr, socklen_t addrlen,
222 std::string& error)
223 {
224 int errNum;
225 sock_t fd = socket(family, socktype, protocol);
226 errNum = SOCKET_ERRNO;
227 if (fd == (sock_t)-1) {
228 error = errorMsg(errNum);
229 return -1;
230 }
231 util::make_fd_cloexec(fd);
232 int sockopt = 1;
233 if (setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, (a2_sockopt_t)&sockopt,
234 sizeof(sockopt)) < 0) {
235 errNum = SOCKET_ERRNO;
236 error = errorMsg(errNum);
237 CLOSE(fd);
238 return -1;
239 }
240 #ifdef IPV6_V6ONLY
241 if (family == AF_INET6) {
242 int sockopt = 1;
243 if (setsockopt(fd, IPPROTO_IPV6, IPV6_V6ONLY, (a2_sockopt_t)&sockopt,
244 sizeof(sockopt)) < 0) {
245 errNum = SOCKET_ERRNO;
246 error = errorMsg(errNum);
247 CLOSE(fd);
248 return -1;
249 }
250 }
251 #endif // IPV6_V6ONLY
252
253 applySocketBufferSize(fd);
254
255 if (::bind(fd, addr, addrlen) == -1) {
256 errNum = SOCKET_ERRNO;
257 error = errorMsg(errNum);
258 CLOSE(fd);
259 return -1;
260 }
261 return fd;
262 }
263
bindTo(const char * host,uint16_t port,int family,int sockType,int getaddrinfoFlags,std::string & error)264 static sock_t bindTo(const char* host, uint16_t port, int family, int sockType,
265 int getaddrinfoFlags, std::string& error)
266 {
267 struct addrinfo* res;
268 int s = callGetaddrinfo(&res, host, util::uitos(port).c_str(), family,
269 sockType, getaddrinfoFlags, 0);
270 if (s) {
271 error = gai_strerror(s);
272 return -1;
273 }
274 std::unique_ptr<addrinfo, decltype(&freeaddrinfo)> resDeleter(res,
275 freeaddrinfo);
276 struct addrinfo* rp;
277 for (rp = res; rp; rp = rp->ai_next) {
278 sock_t fd = bindInternal(rp->ai_family, rp->ai_socktype, rp->ai_protocol,
279 rp->ai_addr, rp->ai_addrlen, error);
280 if (fd != (sock_t)-1) {
281 return fd;
282 }
283 }
284 return -1;
285 }
286
bindWithFamily(uint16_t port,int family,int flags)287 void SocketCore::bindWithFamily(uint16_t port, int family, int flags)
288 {
289 closeConnection();
290 std::string error;
291 sock_t fd = bindTo(nullptr, port, family, sockType_, flags, error);
292 if (fd == (sock_t)-1) {
293 throw DL_ABORT_EX(fmt(EX_SOCKET_BIND, error.c_str()));
294 }
295 sockfd_ = fd;
296 }
297
bind(const char * addr,uint16_t port,int family,int flags)298 void SocketCore::bind(const char* addr, uint16_t port, int family, int flags)
299 {
300 closeConnection();
301 std::string error;
302 const char* addrp;
303 if (addr && addr[0]) {
304 addrp = addr;
305 }
306 else {
307 addrp = nullptr;
308 }
309 if (addrp || !(flags & AI_PASSIVE) || bindAddrsList_.empty()) {
310 sock_t fd = bindTo(addrp, port, family, sockType_, flags, error);
311 if (fd == (sock_t)-1) {
312 throw DL_ABORT_EX(fmt(EX_SOCKET_BIND, error.c_str()));
313 }
314 sockfd_ = fd;
315 return;
316 }
317
318 std::array<char, NI_MAXHOST> host;
319 for (const auto& bindAddrs : bindAddrsList_) {
320 for (const auto& a : bindAddrs) {
321 if (family != AF_UNSPEC && family != a.su.storage.ss_family) {
322 continue;
323 }
324 auto s = getnameinfo(&a.su.sa, a.suLength, host.data(), NI_MAXHOST,
325 nullptr, 0, NI_NUMERICHOST);
326 if (s) {
327 error = gai_strerror(s);
328 continue;
329 }
330 if (addrp && strcmp(host.data(), addrp) != 0) {
331 error = "Given address and resolved address do not match.";
332 continue;
333 }
334 auto fd = bindTo(host.data(), port, family, sockType_, flags, error);
335 if (fd != (sock_t)-1) {
336 sockfd_ = fd;
337 return;
338 }
339 }
340 }
341
342 if (sockfd_ == (sock_t)-1) {
343 throw DL_ABORT_EX(fmt(EX_SOCKET_BIND, error.c_str()));
344 }
345 }
346
bind(uint16_t port,int flags)347 void SocketCore::bind(uint16_t port, int flags)
348 {
349 bind(nullptr, port, protocolFamily_, flags);
350 }
351
bind(const struct sockaddr * addr,socklen_t addrlen)352 void SocketCore::bind(const struct sockaddr* addr, socklen_t addrlen)
353 {
354 closeConnection();
355 std::string error;
356 sock_t fd = bindInternal(addr->sa_family, sockType_, 0, addr, addrlen, error);
357 if (fd == (sock_t)-1) {
358 throw DL_ABORT_EX(fmt(EX_SOCKET_BIND, error.c_str()));
359 }
360 sockfd_ = fd;
361 }
362
beginListen()363 void SocketCore::beginListen()
364 {
365 if (listen(sockfd_, 1024) == -1) {
366 int errNum = SOCKET_ERRNO;
367 throw DL_ABORT_EX(fmt(EX_SOCKET_LISTEN, errorMsg(errNum).c_str()));
368 }
369 setNonBlockingMode();
370 }
371
acceptConnection() const372 std::shared_ptr<SocketCore> SocketCore::acceptConnection() const
373 {
374 sockaddr_union sockaddr;
375 socklen_t len = sizeof(sockaddr);
376 sock_t fd;
377 while ((fd = accept(sockfd_, &sockaddr.sa, &len)) == (sock_t)-1 &&
378 SOCKET_ERRNO == A2_EINTR)
379 ;
380 int errNum = SOCKET_ERRNO;
381 if (fd == (sock_t)-1) {
382 throw DL_ABORT_EX(fmt(EX_SOCKET_ACCEPT, errorMsg(errNum).c_str()));
383 }
384
385 applySocketBufferSize(fd);
386
387 auto sock = std::make_shared<SocketCore>(fd, sockType_);
388 sock->setNonBlockingMode();
389 return sock;
390 }
391
getAddrInfo() const392 Endpoint SocketCore::getAddrInfo() const
393 {
394 sockaddr_union sockaddr;
395 socklen_t len = sizeof(sockaddr);
396 getAddrInfo(sockaddr, len);
397 return util::getNumericNameInfo(&sockaddr.sa, len);
398 }
399
getAddrInfo(sockaddr_union & sockaddr,socklen_t & len) const400 void SocketCore::getAddrInfo(sockaddr_union& sockaddr, socklen_t& len) const
401 {
402 if (getsockname(sockfd_, &sockaddr.sa, &len) == -1) {
403 int errNum = SOCKET_ERRNO;
404 throw DL_ABORT_EX(fmt(EX_SOCKET_GET_NAME, errorMsg(errNum).c_str()));
405 }
406 }
407
getAddressFamily() const408 int SocketCore::getAddressFamily() const
409 {
410 sockaddr_union sockaddr;
411 socklen_t len = sizeof(sockaddr);
412 getAddrInfo(sockaddr, len);
413 return sockaddr.storage.ss_family;
414 }
415
getPeerInfo() const416 Endpoint SocketCore::getPeerInfo() const
417 {
418 sockaddr_union sockaddr;
419 socklen_t len = sizeof(sockaddr);
420 if (getpeername(sockfd_, &sockaddr.sa, &len) == -1) {
421 int errNum = SOCKET_ERRNO;
422 throw DL_ABORT_EX(fmt(EX_SOCKET_GET_NAME, errorMsg(errNum).c_str()));
423 }
424 return util::getNumericNameInfo(&sockaddr.sa, len);
425 }
426
// Resolves host:port and starts a non-blocking connect() to it.
//
// Resolution uses callGetaddrinfo() with the process-wide protocolFamily_
// and this object's sockType_. Each resulting address is tried in order:
// a socket is created (close-on-exec, SO_REUSEADDR, configured receive
// buffer size), bound to one of bindAddrs_ when any are configured, put in
// non-blocking mode (plus TCP_NODELAY if tcpNodelay), and connect() is
// started. The first address whose connect() succeeds or reports
// A2_EINPROGRESS is kept in sockfd_.
//
// On return the connection may still be in progress (see the TODO below).
// Throws DL_ABORT_EX if resolution fails or no address could be used; the
// message carries the last per-address error. setNonBlockingMode() and
// setTcpNodelay() may also throw, after sockfd_ has taken ownership of fd.
void SocketCore::establishConnection(const std::string& host, uint16_t port,
                                     bool tcpNodelay)
{
  closeConnection();
  // Reason of the most recent per-address failure, reported if all fail.
  std::string error;
  struct addrinfo* res;
  int s;
  s = callGetaddrinfo(&res, host.c_str(), util::uitos(port).c_str(),
                      protocolFamily_, sockType_, 0, 0);
  if (s) {
    throw DL_ABORT_EX(fmt(EX_RESOLVE_HOSTNAME, host.c_str(), gai_strerror(s)));
  }
  // Frees the getaddrinfo() result on every exit path.
  std::unique_ptr<addrinfo, decltype(&freeaddrinfo)> resDeleter(res,
                                                                freeaddrinfo);
  struct addrinfo* rp;
  int errNum;
  for (rp = res; rp; rp = rp->ai_next) {
    sock_t fd = socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol);
    errNum = SOCKET_ERRNO;
    if (fd == (sock_t)-1) {
      error = errorMsg(errNum);
      continue;
    }
    util::make_fd_cloexec(fd);
    int sockopt = 1;
    if (setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, (a2_sockopt_t)&sockopt,
                   sizeof(sockopt)) < 0) {
      errNum = SOCKET_ERRNO;
      error = errorMsg(errNum);
      CLOSE(fd);
      continue;
    }

    applySocketBufferSize(fd);

    // With local bind addresses configured, bind to the first one that
    // works; if none can be bound, drop this socket and try the next
    // remote address.
    if (!bindAddrs_.empty()) {
      bool bindSuccess = false;
      for (const auto& soaddr : bindAddrs_) {
        if (::bind(fd, &soaddr.su.sa, soaddr.suLength) == -1) {
          errNum = SOCKET_ERRNO;
          error = errorMsg(errNum);
          A2_LOG_DEBUG(fmt(EX_SOCKET_BIND, error.c_str()));
        }
        else {
          bindSuccess = true;
          break;
        }
      }
      if (!bindSuccess) {
        CLOSE(fd);
        continue;
      }
    }
    // Advance (round-robin, wrapping at the end) to the next set of local
    // bind addresses for subsequent connections.
    // NOTE(review): this happens before connect(), so the cursor also moves
    // for an address whose connect() fails immediately below.
    if (!bindAddrsList_.empty()) {
      ++bindAddrsListIt_;
      if (bindAddrsListIt_ == bindAddrsList_.end()) {
        bindAddrsListIt_ = bindAddrsList_.begin();
      }
      bindAddrs_ = *bindAddrsListIt_;
    }

    sockfd_ = fd;
    // make socket non-blocking mode
    setNonBlockingMode();
    if (tcpNodelay) {
      setTcpNodelay(true);
    }
    // A2_EINPROGRESS means the non-blocking connect has started; any other
    // error discards this socket and moves on to the next address.
    if (connect(fd, rp->ai_addr, rp->ai_addrlen) == -1 &&
        SOCKET_ERRNO != A2_EINPROGRESS) {
      errNum = SOCKET_ERRNO;
      error = errorMsg(errNum);
      CLOSE(sockfd_);
      sockfd_ = (sock_t)-1;
      continue;
    }
    // TODO at this point, connection may not be established and it may fail
    // later. In such case, next ai_addr should be tried.
    break;
  }
  if (sockfd_ == (sock_t)-1) {
    throw DL_ABORT_EX(fmt(EX_SOCKET_CONNECT, host.c_str(), error.c_str()));
  }
}
510
setSockOpt(int level,int optname,void * optval,socklen_t optlen)511 void SocketCore::setSockOpt(int level, int optname, void* optval,
512 socklen_t optlen)
513 {
514 if (setsockopt(sockfd_, level, optname, (a2_sockopt_t)optval, optlen) < 0) {
515 int errNum = SOCKET_ERRNO;
516 throw DL_ABORT_EX(fmt(EX_SOCKET_SET_OPT, errorMsg(errNum).c_str()));
517 }
518 }
519
setMulticastInterface(const std::string & localAddr)520 void SocketCore::setMulticastInterface(const std::string& localAddr)
521 {
522 in_addr addr;
523 if (localAddr.empty()) {
524 addr.s_addr = htonl(INADDR_ANY);
525 }
526 else if (inetPton(AF_INET, localAddr.c_str(), &addr) != 0) {
527 throw DL_ABORT_EX(
528 fmt("%s is not valid IPv4 numeric address", localAddr.c_str()));
529 }
530 setSockOpt(IPPROTO_IP, IP_MULTICAST_IF, &addr, sizeof(addr));
531 }
532
setMulticastTtl(unsigned char ttl)533 void SocketCore::setMulticastTtl(unsigned char ttl)
534 {
535 setSockOpt(IPPROTO_IP, IP_MULTICAST_TTL, &ttl, sizeof(ttl));
536 }
537
setMulticastLoop(unsigned char loop)538 void SocketCore::setMulticastLoop(unsigned char loop)
539 {
540 setSockOpt(IPPROTO_IP, IP_MULTICAST_LOOP, &loop, sizeof(loop));
541 }
542
joinMulticastGroup(const std::string & multicastAddr,uint16_t multicastPort,const std::string & localAddr)543 void SocketCore::joinMulticastGroup(const std::string& multicastAddr,
544 uint16_t multicastPort,
545 const std::string& localAddr)
546 {
547 in_addr multiAddr;
548 if (inetPton(AF_INET, multicastAddr.c_str(), &multiAddr) != 0) {
549 throw DL_ABORT_EX(
550 fmt("%s is not valid IPv4 numeric address", multicastAddr.c_str()));
551 }
552 in_addr ifAddr;
553 if (localAddr.empty()) {
554 ifAddr.s_addr = htonl(INADDR_ANY);
555 }
556 else if (inetPton(AF_INET, localAddr.c_str(), &ifAddr) != 0) {
557 throw DL_ABORT_EX(
558 fmt("%s is not valid IPv4 numeric address", localAddr.c_str()));
559 }
560 struct ip_mreq mreq;
561 memset(&mreq, 0, sizeof(mreq));
562 mreq.imr_multiaddr = multiAddr;
563 mreq.imr_interface = ifAddr;
564 setSockOpt(IPPROTO_IP, IP_ADD_MEMBERSHIP, &mreq, sizeof(mreq));
565 }
566
setTcpNodelay(bool f)567 void SocketCore::setTcpNodelay(bool f)
568 {
569 int val = f;
570 setSockOpt(IPPROTO_TCP, TCP_NODELAY, &val, sizeof(val));
571 }
572
applyIpDscp()573 void SocketCore::applyIpDscp()
574 {
575 if (ipDscp_ == 0) {
576 return;
577 }
578
579 try {
580 int family = getAddressFamily();
581 if (family == AF_INET) {
582 setSockOpt(IPPROTO_IP, IP_TOS, &ipDscp_, sizeof(ipDscp_));
583 }
584 #if defined(IPV6_TCLASS) || defined(__linux__) || defined(__FreeBSD__) || \
585 defined(__NetBSD__) || defined(__OpenBSD__) || defined(__DragonFly__)
586 else if (family == AF_INET6) {
587 setSockOpt(IPPROTO_IPV6, IPV6_TCLASS, &ipDscp_, sizeof(ipDscp_));
588 }
589 #endif
590 }
591 catch (RecoverableException& e) {
592 A2_LOG_INFO_EX("Applying DSCP value failed", e);
593 }
594 }
595
setNonBlockingMode()596 void SocketCore::setNonBlockingMode()
597 {
598 #ifdef __MINGW32__
599 static u_long flag = 1;
600 if (::ioctlsocket(sockfd_, FIONBIO, &flag) == -1) {
601 int errNum = SOCKET_ERRNO;
602 throw DL_ABORT_EX(fmt(EX_SOCKET_NONBLOCKING, errorMsg(errNum).c_str()));
603 }
604 #else
605 int flags;
606 while ((flags = fcntl(sockfd_, F_GETFL, 0)) == -1 && errno == EINTR)
607 ;
608 // TODO add error handling
609 while (fcntl(sockfd_, F_SETFL, flags | O_NONBLOCK) == -1 && errno == EINTR)
610 ;
611 #endif // __MINGW32__
612 blocking_ = false;
613 }
614
setBlockingMode()615 void SocketCore::setBlockingMode()
616 {
617 #ifdef __MINGW32__
618 static u_long flag = 0;
619 if (::ioctlsocket(sockfd_, FIONBIO, &flag) == -1) {
620 int errNum = SOCKET_ERRNO;
621 throw DL_ABORT_EX(fmt(EX_SOCKET_BLOCKING, errorMsg(errNum).c_str()));
622 }
623 #else
624 int flags;
625 while ((flags = fcntl(sockfd_, F_GETFL, 0)) == -1 && errno == EINTR)
626 ;
627 // TODO add error handling
628 while (fcntl(sockfd_, F_SETFL, flags & (~O_NONBLOCK)) == -1 && errno == EINTR)
629 ;
630 #endif // __MINGW32__
631 blocking_ = true;
632 }
633
closeConnection()634 void SocketCore::closeConnection()
635 {
636 #ifdef ENABLE_SSL
637 if (tlsSession_) {
638 tlsSession_->closeConnection();
639 tlsSession_.reset();
640 }
641 #endif // ENABLE_SSL
642
643 #ifdef HAVE_LIBSSH2
644 if (sshSession_) {
645 sshSession_->closeConnection();
646 sshSession_.reset();
647 }
648 #endif // HAVE_LIBSSH2
649
650 if (sockfd_ != (sock_t)-1) {
651 shutdown(sockfd_, SHUT_WR);
652 CLOSE(sockfd_);
653 sockfd_ = -1;
654 }
655 }
656
#ifndef __MINGW32__
// Guard for the select()-based fallbacks of isReadable()/isWritable():
// FD_SET() needs a descriptor in [0, FD_SETSIZE). Any other descriptor
// logs a warning and makes the enclosing function return false. Logs
// through A2_LOG_WARN, the logging facility used throughout this file.
// Wrapped in do/while(0) with parenthesized arguments so the macro acts as
// a single statement.
# define CHECK_FD(fd) \
  do { \
    if ((fd) < 0 || FD_SETSIZE <= (fd)) { \
      A2_LOG_WARN("Detected file descriptor >= FD_SETSIZE or < 0. " \
                  "Download may slow down or fail."); \
      return false; \
    } \
  } while (0)
#endif // !__MINGW32__
665
isWritable(time_t timeout)666 bool SocketCore::isWritable(time_t timeout)
667 {
668 #ifdef HAVE_POLL
669 struct pollfd p;
670 p.fd = sockfd_;
671 p.events = POLLOUT;
672 int r;
673 while ((r = poll(&p, 1, timeout * 1000)) == -1 && errno == EINTR)
674 ;
675 int errNum = SOCKET_ERRNO;
676 if (r > 0) {
677 return p.revents & (POLLOUT | POLLHUP | POLLERR);
678 }
679 if (r == 0) {
680 return false;
681 }
682 throw DL_RETRY_EX(fmt(EX_SOCKET_CHECK_WRITABLE, errorMsg(errNum).c_str()));
683 #else // !HAVE_POLL
684 # ifndef __MINGW32__
685 CHECK_FD(sockfd_);
686 # endif // !__MINGW32__
687 fd_set fds;
688 FD_ZERO(&fds);
689 FD_SET(sockfd_, &fds);
690
691 struct timeval tv;
692 tv.tv_sec = timeout;
693 tv.tv_usec = 0;
694
695 int r = select(sockfd_ + 1, nullptr, &fds, nullptr, &tv);
696 int errNum = SOCKET_ERRNO;
697 if (r == 1) {
698 return true;
699 }
700 if (r == 0) {
701 // time out
702 return false;
703 }
704 if (errNum == A2_EINPROGRESS || errNum == A2_EINTR) {
705 return false;
706 }
707 throw DL_RETRY_EX(fmt(EX_SOCKET_CHECK_WRITABLE, errorMsg(errNum).c_str()));
708 #endif // !HAVE_POLL
709 }
710
isReadable(time_t timeout)711 bool SocketCore::isReadable(time_t timeout)
712 {
713 #ifdef HAVE_POLL
714 struct pollfd p;
715 p.fd = sockfd_;
716 p.events = POLLIN;
717 int r;
718 while ((r = poll(&p, 1, timeout * 1000)) == -1 && errno == EINTR)
719 ;
720 int errNum = SOCKET_ERRNO;
721 if (r > 0) {
722 return p.revents & (POLLIN | POLLHUP | POLLERR);
723 }
724 if (r == 0) {
725 return false;
726 }
727 throw DL_RETRY_EX(fmt(EX_SOCKET_CHECK_READABLE, errorMsg(errNum).c_str()));
728 #else // !HAVE_POLL
729 # ifndef __MINGW32__
730 CHECK_FD(sockfd_);
731 # endif // !__MINGW32__
732 fd_set fds;
733 FD_ZERO(&fds);
734 FD_SET(sockfd_, &fds);
735
736 struct timeval tv;
737 tv.tv_sec = timeout;
738 tv.tv_usec = 0;
739
740 int r = select(sockfd_ + 1, &fds, nullptr, nullptr, &tv);
741 int errNum = SOCKET_ERRNO;
742 if (r == 1) {
743 return true;
744 }
745 if (r == 0) {
746 // time out
747 return false;
748 }
749 if (errNum == A2_EINPROGRESS || errNum == A2_EINTR) {
750 return false;
751 }
752 throw DL_RETRY_EX(fmt(EX_SOCKET_CHECK_READABLE, errorMsg(errNum).c_str()));
753 #endif // !HAVE_POLL
754 }
755
writeVector(a2iovec * iov,size_t iovcnt)756 ssize_t SocketCore::writeVector(a2iovec* iov, size_t iovcnt)
757 {
758 ssize_t ret = 0;
759 wantRead_ = false;
760 wantWrite_ = false;
761 if (!secure_) {
762 #ifdef __MINGW32__
763 DWORD nsent;
764 int rv = WSASend(sockfd_, iov, iovcnt, &nsent, 0, 0, 0);
765 if (rv == 0) {
766 ret = nsent;
767 }
768 else {
769 ret = -1;
770 }
771 #else // !__MINGW32__
772 while ((ret = writev(sockfd_, iov, iovcnt)) == -1 &&
773 SOCKET_ERRNO == A2_EINTR)
774 ;
775 #endif // !__MINGW32__
776 int errNum = SOCKET_ERRNO;
777 if (ret == -1) {
778 if (!A2_WOULDBLOCK(errNum)) {
779 throw DL_RETRY_EX(fmt(EX_SOCKET_SEND, errorMsg(errNum).c_str()));
780 }
781 wantWrite_ = true;
782 ret = 0;
783 }
784 }
785 else {
786 // For SSL/TLS, we could not use writev, so just iterate vector
787 // and write the data in normal way.
788 for (size_t i = 0; i < iovcnt; ++i) {
789 ssize_t rv = writeData(iov[i].A2IOVEC_BASE, iov[i].A2IOVEC_LEN);
790 if (rv == 0) {
791 break;
792 }
793 ret += rv;
794 }
795 }
796 return ret;
797 }
798
writeData(const void * data,size_t len)799 ssize_t SocketCore::writeData(const void* data, size_t len)
800 {
801 ssize_t ret = 0;
802 wantRead_ = false;
803 wantWrite_ = false;
804
805 if (!secure_) {
806 // Cast for Windows send()
807 while ((ret = send(sockfd_, reinterpret_cast<const char*>(data), len, 0)) ==
808 -1 &&
809 SOCKET_ERRNO == A2_EINTR)
810 ;
811 int errNum = SOCKET_ERRNO;
812 if (ret == -1) {
813 if (!A2_WOULDBLOCK(errNum)) {
814 throw DL_RETRY_EX(fmt(EX_SOCKET_SEND, errorMsg(errNum).c_str()));
815 }
816 wantWrite_ = true;
817 ret = 0;
818 }
819 }
820 else {
821 #ifdef ENABLE_SSL
822 ret = tlsSession_->writeData(data, len);
823 if (ret < 0) {
824 if (ret != TLS_ERR_WOULDBLOCK) {
825 throw DL_RETRY_EX(
826 fmt(EX_SOCKET_SEND, tlsSession_->getLastErrorString().c_str()));
827 }
828 if (tlsSession_->checkDirection() == TLS_WANT_READ) {
829 wantRead_ = true;
830 }
831 else {
832 wantWrite_ = true;
833 }
834 ret = 0;
835 }
836 #endif // ENABLE_SSL
837 }
838 return ret;
839 }
840
// Reads up to `len` bytes into `data` and writes the number of bytes
// actually read back into `len`.
//
// Transport selection, in order: the SSH session if one exists (only in
// HAVE_LIBSSH2 builds), plain recv() when no TLS is active (!secure_),
// otherwise the TLS session. A would-block condition is not an error:
// `len` becomes 0 and wantRead_ or wantWrite_ says which readiness to wait
// for. Other errors throw DL_RETRY_EX. On a plain socket, recv() returning
// 0 (orderly shutdown by the peer) also yields len == 0, but with both
// flags false. If secure_ is set in a build without ENABLE_SSL, len is 0.
void SocketCore::readData(void* data, size_t& len)
{
  ssize_t ret = 0;
  wantRead_ = false;
  wantWrite_ = false;

#ifdef HAVE_LIBSSH2
  if (sshSession_) {
    ret = sshSession_->readData(data, len);
    if (ret < 0) {
      if (ret != SSH_ERR_WOULDBLOCK) {
        throw DL_RETRY_EX(
            fmt(EX_SOCKET_RECV, sshSession_->getLastErrorString().c_str()));
      }
      if (sshSession_->checkDirection() == SSH_WANT_READ) {
        wantRead_ = true;
      }
      else {
        wantWrite_ = true;
      }
      ret = 0;
    }
  }
  // This `else` chains into the `if (!secure_)` after the #endif, so with
  // libssh2 the plain/TLS paths run only when there is no SSH session.
  else
#endif // HAVE_LIBSSH2
  if (!secure_) {
    // Cast for Windows recv()
    while ((ret = recv(sockfd_, reinterpret_cast<char*>(data), len, 0)) == -1 &&
           SOCKET_ERRNO == A2_EINTR)
      ;
    int errNum = SOCKET_ERRNO;
    if (ret == -1) {
      if (!A2_WOULDBLOCK(errNum)) {
        throw DL_RETRY_EX(fmt(EX_SOCKET_RECV, errorMsg(errNum).c_str()));
      }
      wantRead_ = true;
      ret = 0;
    }
  }
  else {
#ifdef ENABLE_SSL
    ret = tlsSession_->readData(data, len);
    if (ret < 0) {
      if (ret != TLS_ERR_WOULDBLOCK) {
        throw DL_RETRY_EX(
            fmt(EX_SOCKET_RECV, tlsSession_->getLastErrorString().c_str()));
      }
      if (tlsSession_->checkDirection() == TLS_WANT_READ) {
        wantRead_ = true;
      }
      else {
        wantWrite_ = true;
      }
      ret = 0;
    }
#endif // ENABLE_SSL
  }

  len = ret;
}
901
902 #ifdef ENABLE_SSL
903
tlsAccept()904 bool SocketCore::tlsAccept()
905 {
906 return tlsHandshake(svTlsContext_.get(), A2STR::NIL);
907 }
908
tlsConnect(const std::string & hostname)909 bool SocketCore::tlsConnect(const std::string& hostname)
910 {
911 return tlsHandshake(clTlsContext_.get(), hostname);
912 }
913
tlsHandshake(TLSContext * tlsctx,const std::string & hostname)914 bool SocketCore::tlsHandshake(TLSContext* tlsctx, const std::string& hostname)
915 {
916 wantRead_ = false;
917 wantWrite_ = false;
918
919 if (secure_ == A2_TLS_CONNECTED) {
920 // Already connected!
921 return true;
922 }
923
924 if (secure_ == A2_TLS_NONE) {
925 // Do some initial setup
926 A2_LOG_DEBUG("Creating TLS session");
927 tlsSession_.reset(TLSSession::make(tlsctx));
928 auto rv = tlsSession_->init(sockfd_);
929 if (rv != TLS_ERR_OK) {
930 std::string error = tlsSession_->getLastErrorString();
931 tlsSession_.reset();
932 throw DL_ABORT_EX(fmt(EX_SSL_INIT_FAILURE, error.c_str()));
933 }
934 // Check hostname is not numeric and it includes ".". Setting
935 // "localhost" will produce TLS alert with GNUTLS.
936 if (tlsctx->getSide() == TLS_CLIENT && !util::isNumericHost(hostname) &&
937 hostname.find(".") != std::string::npos) {
938 rv = tlsSession_->setSNIHostname(hostname);
939 if (rv != TLS_ERR_OK) {
940 throw DL_ABORT_EX(fmt(EX_SSL_INIT_FAILURE,
941 tlsSession_->getLastErrorString().c_str()));
942 }
943 }
944 // Done with the setup, now let handshaking begin immediately.
945 secure_ = A2_TLS_HANDSHAKING;
946 A2_LOG_DEBUG("TLS Handshaking");
947 }
948
949 if (secure_ == A2_TLS_HANDSHAKING) {
950 // Starting handshake after initial setup or still handshaking.
951 TLSVersion ver = TLS_PROTO_NONE;
952 int rv = 0;
953 std::string handshakeError;
954
955 if (tlsctx->getSide() == TLS_CLIENT) {
956 rv = tlsSession_->tlsConnect(hostname, ver, handshakeError);
957 }
958 else {
959 rv = tlsSession_->tlsAccept(ver);
960 }
961
962 if (rv == TLS_ERR_OK) {
963 // We're good, more or less.
964 // 1. Construct peerinfo
965 std::stringstream ss;
966 if (!hostname.empty()) {
967 ss << hostname << " (";
968 }
969 auto peerEndpoint = getPeerInfo();
970 ss << peerEndpoint.addr << ":" << peerEndpoint.port;
971 if (!hostname.empty()) {
972 ss << ")";
973 }
974
975 std::string tlsVersion;
976 switch (ver) {
977 case TLS_PROTO_TLS11:
978 tlsVersion = A2_V_TLS11;
979 break;
980 case TLS_PROTO_TLS12:
981 tlsVersion = A2_V_TLS12;
982 break;
983 case TLS_PROTO_TLS13:
984 tlsVersion = A2_V_TLS13;
985 break;
986 default:
987 assert(0);
988 abort();
989 }
990
991 auto peerInfo = ss.str();
992
993 A2_LOG_DEBUG(fmt("Securely connected to %s with %s", peerInfo.c_str(),
994 tlsVersion.c_str()));
995
996 // 2. We're connected now!
997 secure_ = A2_TLS_CONNECTED;
998 return true;
999 }
1000
1001 if (rv == TLS_ERR_WOULDBLOCK) {
1002 // We're not done yet...
1003 if (tlsSession_->checkDirection() == TLS_WANT_READ) {
1004 // ... but read buffers are empty.
1005 wantRead_ = true;
1006 }
1007 else {
1008 // ... but write buffers are full.
1009 wantWrite_ = true;
1010 }
1011 // Returning false (instead of true==success or throwing) will cause this
1012 // function to be called again once buffering is dealt with
1013 return false;
1014 }
1015
1016 if (rv == TLS_ERR_ERROR) {
1017 // Damn those error.
1018 throw DL_ABORT_EX(fmt("SSL/TLS handshake failure: %s",
1019 handshakeError.empty()
1020 ? tlsSession_->getLastErrorString().c_str()
1021 : handshakeError.c_str()));
1022 }
1023
1024 // Some implementation passed back an invalid result.
1025 throw DL_ABORT_EX(fmt(EX_SSL_INIT_FAILURE,
1026 "Invalid connect state (this is a bug in the TLS "
1027 "backend!)"));
1028 }
1029
1030 // We should never get here, i.e. all possible states should have been handled
1031 // and returned from a branch before! Getting here is a bug, of course!
1032 throw DL_ABORT_EX(fmt(EX_SSL_INIT_FAILURE, "Invalid state (this is a bug!)"));
1033 }
1034
1035 #endif // ENABLE_SSL
1036
1037 #ifdef HAVE_LIBSSH2
1038
sshHandshake(const std::string & hashType,const std::string & digest)1039 bool SocketCore::sshHandshake(const std::string& hashType,
1040 const std::string& digest)
1041 {
1042 wantRead_ = false;
1043 wantWrite_ = false;
1044
1045 if (!sshSession_) {
1046 sshSession_ = make_unique<SSHSession>();
1047 if (sshSession_->init(sockfd_) == SSH_ERR_ERROR) {
1048 throw DL_ABORT_EX("Could not create SSH session");
1049 }
1050 }
1051 auto rv = sshSession_->handshake();
1052 if (rv == SSH_ERR_WOULDBLOCK) {
1053 sshCheckDirection();
1054 return false;
1055 }
1056 if (rv == SSH_ERR_ERROR) {
1057 throw DL_ABORT_EX(fmt("SSH handshake failure: %s",
1058 sshSession_->getLastErrorString().c_str()));
1059 }
1060 if (!hashType.empty()) {
1061 auto actualDigest = sshSession_->hostkeyMessageDigest(hashType);
1062 if (actualDigest.empty()) {
1063 throw DL_ABORT_EX(fmt("Empty host key fingerprint from SSH layer: "
1064 "perhaps hash type %s is not supported?",
1065 hashType.c_str()));
1066 }
1067 if (digest != actualDigest) {
1068 throw DL_ABORT_EX(fmt("Unexpected SSH host key: expected %s, actual %s",
1069 util::toHex(digest).c_str(),
1070 util::toHex(actualDigest).c_str()));
1071 }
1072 }
1073 return true;
1074 }
1075
sshAuthPassword(const std::string & user,const std::string & password)1076 bool SocketCore::sshAuthPassword(const std::string& user,
1077 const std::string& password)
1078 {
1079 assert(sshSession_);
1080
1081 wantRead_ = false;
1082 wantWrite_ = false;
1083
1084 auto rv = sshSession_->authPassword(user, password);
1085 if (rv == SSH_ERR_WOULDBLOCK) {
1086 sshCheckDirection();
1087 return false;
1088 }
1089 if (rv == SSH_ERR_ERROR) {
1090 throw DL_ABORT_EX(fmt("SSH authentication failure: %s",
1091 sshSession_->getLastErrorString().c_str()));
1092 }
1093 return true;
1094 }
1095
sshSFTPOpen(const std::string & path)1096 bool SocketCore::sshSFTPOpen(const std::string& path)
1097 {
1098 assert(sshSession_);
1099
1100 wantRead_ = false;
1101 wantWrite_ = false;
1102
1103 auto rv = sshSession_->sftpOpen(path);
1104 if (rv == SSH_ERR_WOULDBLOCK) {
1105 sshCheckDirection();
1106 return false;
1107 }
1108 if (rv == SSH_ERR_ERROR) {
1109 throw DL_ABORT_EX(fmt("SSH opening SFTP path %s failed: %s", path.c_str(),
1110 sshSession_->getLastErrorString().c_str()));
1111 }
1112 return true;
1113 }
1114
sshSFTPClose()1115 bool SocketCore::sshSFTPClose()
1116 {
1117 assert(sshSession_);
1118
1119 wantRead_ = false;
1120 wantWrite_ = false;
1121
1122 auto rv = sshSession_->sftpClose();
1123 if (rv == SSH_ERR_WOULDBLOCK) {
1124 sshCheckDirection();
1125 return false;
1126 }
1127 if (rv == SSH_ERR_ERROR) {
1128 throw DL_ABORT_EX(fmt("SSH closing SFTP failed: %s",
1129 sshSession_->getLastErrorString().c_str()));
1130 }
1131 return true;
1132 }
1133
sshSFTPStat(int64_t & totalLength,time_t & mtime,const std::string & path)1134 bool SocketCore::sshSFTPStat(int64_t& totalLength, time_t& mtime,
1135 const std::string& path)
1136 {
1137 assert(sshSession_);
1138
1139 wantRead_ = false;
1140 wantWrite_ = false;
1141
1142 auto rv = sshSession_->sftpStat(totalLength, mtime);
1143 if (rv == SSH_ERR_WOULDBLOCK) {
1144 sshCheckDirection();
1145 return false;
1146 }
1147 if (rv == SSH_ERR_ERROR) {
1148 throw DL_ABORT_EX(fmt("SSH stat SFTP path %s filed: %s", path.c_str(),
1149 sshSession_->getLastErrorString().c_str()));
1150 }
1151 return true;
1152 }
1153
sshSFTPSeek(int64_t pos)1154 void SocketCore::sshSFTPSeek(int64_t pos)
1155 {
1156 assert(sshSession_);
1157
1158 sshSession_->sftpSeek(pos);
1159 }
1160
sshGracefulShutdown()1161 bool SocketCore::sshGracefulShutdown()
1162 {
1163 assert(sshSession_);
1164 auto rv = sshSession_->gracefulShutdown();
1165 if (rv == SSH_ERR_WOULDBLOCK) {
1166 sshCheckDirection();
1167 return false;
1168 }
1169 if (rv == SSH_ERR_ERROR) {
1170 throw DL_ABORT_EX(fmt("SSH graceful shutdown failed: %s",
1171 sshSession_->getLastErrorString().c_str()));
1172 }
1173 return true;
1174 }
1175
sshCheckDirection()1176 void SocketCore::sshCheckDirection()
1177 {
1178 if (sshSession_->checkDirection() == SSH_WANT_READ) {
1179 wantRead_ = true;
1180 }
1181 else {
1182 wantWrite_ = true;
1183 }
1184 }
1185
1186 #endif // HAVE_LIBSSH2
1187
/**
 * Sends `len` bytes from `data` to host:port with sendto(), retrying on
 * EINTR. The addresses getaddrinfo() yields for the destination are tried
 * in order until one accepts the whole buffer or reports would-block.
 *
 * @return the byte count returned by the last sendto() attempt; 0 when
 *         sending would block (wantWrite() is then true).
 * @throws (via DL_ABORT_EX) when name resolution fails, or when the last
 *         attempt failed with an error other than would-block.
 */
ssize_t SocketCore::writeData(const void* data, size_t len,
                              const std::string& host, uint16_t port)
{
  wantRead_ = false;
  wantWrite_ = false;

  addrinfo* head;
  const int gaiErr =
      callGetaddrinfo(&head, host.c_str(), util::uitos(port).c_str(),
                      protocolFamily_, sockType_, 0, 0);
  if (gaiErr) {
    throw DL_ABORT_EX(fmt(EX_SOCKET_SEND, gai_strerror(gaiErr)));
  }
  std::unique_ptr<addrinfo, decltype(&freeaddrinfo)> headGuard(head,
                                                               freeaddrinfo);

  ssize_t sent = -1;
  int errNum = 0;
  for (addrinfo* cand = head; cand; cand = cand->ai_next) {
    do {
      // Cast for Windows sendto()
      sent = sendto(sockfd_, reinterpret_cast<const char*>(data), len, 0,
                    cand->ai_addr, cand->ai_addrlen);
    } while (sent == -1 && A2_EINTR == SOCKET_ERRNO);
    errNum = SOCKET_ERRNO;

    if (sent == static_cast<ssize_t>(len)) {
      break;
    }
    if (sent == -1 && A2_WOULDBLOCK(errNum)) {
      wantWrite_ = true;
      sent = 0;
      break;
    }
  }

  if (sent == -1) {
    throw DL_ABORT_EX(fmt(EX_SOCKET_SEND, errorMsg(errNum).c_str()));
  }
  return sent;
}
1227
readDataFrom(void * data,size_t len,Endpoint & sender)1228 ssize_t SocketCore::readDataFrom(void* data, size_t len, Endpoint& sender)
1229 {
1230 wantRead_ = false;
1231 wantWrite_ = false;
1232 sockaddr_union sockaddr;
1233 socklen_t sockaddrlen = sizeof(sockaddr);
1234 ssize_t r;
1235 // Cast for Windows recvfrom()
1236 while ((r = recvfrom(sockfd_, reinterpret_cast<char*>(data), len, 0,
1237 &sockaddr.sa, &sockaddrlen)) == -1 &&
1238 A2_EINTR == SOCKET_ERRNO)
1239 ;
1240 int errNum = SOCKET_ERRNO;
1241 if (r == -1) {
1242 if (!A2_WOULDBLOCK(errNum)) {
1243 throw DL_RETRY_EX(fmt(EX_SOCKET_RECV, errorMsg(errNum).c_str()));
1244 }
1245 wantRead_ = true;
1246 r = 0;
1247 }
1248 else {
1249 sender = util::getNumericNameInfo(&sockaddr.sa, sockaddrlen);
1250 }
1251
1252 return r;
1253 }
1254
getSocketError() const1255 std::string SocketCore::getSocketError() const
1256 {
1257 int error;
1258 socklen_t optlen = sizeof(error);
1259
1260 if (getsockopt(sockfd_, SOL_SOCKET, SO_ERROR, (a2_sockopt_t)&error,
1261 &optlen) == -1) {
1262 int errNum = SOCKET_ERRNO;
1263 throw DL_ABORT_EX(
1264 fmt("Failed to get socket error: %s", errorMsg(errNum).c_str()));
1265 }
1266 if (error != 0) {
1267 return errorMsg(error);
1268 }
1269 return "";
1270 }
1271
wantRead() const1272 bool SocketCore::wantRead() const { return wantRead_; }
1273
wantWrite() const1274 bool SocketCore::wantWrite() const { return wantWrite_; }
1275
bindAddress(const std::string & iface)1276 void SocketCore::bindAddress(const std::string& iface)
1277 {
1278 auto bindAddrs = getInterfaceAddress(iface, protocolFamily_);
1279 if (bindAddrs.empty()) {
1280 throw DL_ABORT_EX(
1281 fmt(MSG_INTERFACE_NOT_FOUND, iface.c_str(), "not available"));
1282 }
1283 bindAddrs_.swap(bindAddrs);
1284 for (const auto& a : bindAddrs_) {
1285 char host[NI_MAXHOST];
1286 int s;
1287 s = getnameinfo(&a.su.sa, a.suLength, host, NI_MAXHOST, nullptr, 0,
1288 NI_NUMERICHOST);
1289 if (s == 0) {
1290 A2_LOG_DEBUG(fmt("Sockets will bind to %s", host));
1291 }
1292 }
1293 bindAddrsList_.push_back(bindAddrs_);
1294 bindAddrsListIt_ = std::begin(bindAddrsList_);
1295 }
1296
bindAllAddress(const std::string & ifaces)1297 void SocketCore::bindAllAddress(const std::string& ifaces)
1298 {
1299 std::vector<std::vector<SockAddr>> bindAddrsList;
1300 std::vector<std::string> ifaceList;
1301 util::split(ifaces.begin(), ifaces.end(), std::back_inserter(ifaceList), ',',
1302 true);
1303 if (ifaceList.empty()) {
1304 throw DL_ABORT_EX(
1305 "List of interfaces is empty, one or more interfaces is required");
1306 }
1307 for (auto& iface : ifaceList) {
1308 auto bindAddrs = getInterfaceAddress(iface, protocolFamily_);
1309 if (bindAddrs.empty()) {
1310 throw DL_ABORT_EX(
1311 fmt(MSG_INTERFACE_NOT_FOUND, iface.c_str(), "not available"));
1312 }
1313 bindAddrsList.push_back(bindAddrs);
1314 for (const auto& a : bindAddrs) {
1315 char host[NI_MAXHOST];
1316 int s;
1317 s = getnameinfo(&a.su.sa, a.suLength, host, NI_MAXHOST, nullptr, 0,
1318 NI_NUMERICHOST);
1319 if (s == 0) {
1320 A2_LOG_DEBUG(fmt("Sockets will bind to %s", host));
1321 }
1322 }
1323 }
1324 bindAddrsList_.swap(bindAddrsList);
1325 bindAddrsListIt_ = bindAddrsList_.begin();
1326 bindAddrs_ = *bindAddrsListIt_;
1327 }
1328
setSocketRecvBufferSize(int size)1329 void SocketCore::setSocketRecvBufferSize(int size)
1330 {
1331 socketRecvBufferSize_ = size;
1332 }
1333
getSocketRecvBufferSize()1334 int SocketCore::getSocketRecvBufferSize() { return socketRecvBufferSize_; }
1335
getRecvBufferedLength() const1336 size_t SocketCore::getRecvBufferedLength() const
1337 {
1338 #ifdef ENABLE_SSL
1339 if (!tlsSession_) {
1340 return 0;
1341 }
1342
1343 return tlsSession_->getRecvBufferedLength();
1344 #else // !ENABLE_SSL
1345 return 0;
1346 #endif // !ENABLE_SSL
1347 }
1348
/**
 * Collects the socket addresses that belong to `iface`, which may be an
 * interface name or a host name/address.
 *
 * When getifaddrs() is available, the addresses of the interface named
 * exactly `iface` are taken first. They are filtered by `family`: AF_UNSPEC
 * accepts AF_INET and AF_INET6, AF_INET/AF_INET6 accept only themselves,
 * and any other value accepts nothing. If that yields nothing, `iface` is
 * resolved with getaddrinfo (SOCK_STREAM, `aiFlags`), and each result is
 * kept only if a throwaway socket can bind to it.
 *
 * Lookup failures are logged at INFO level; the result may be empty.
 */
std::vector<SockAddr> SocketCore::getInterfaceAddress(const std::string& iface,
                                                      int family, int aiFlags)
{
  A2_LOG_DEBUG(fmt("Finding interface %s", iface.c_str()));
  std::vector<SockAddr> found;
#ifdef HAVE_GETIFADDRS
  // Whether an interface address of family `iffamily` fits the request.
  auto familyAccepted = [family](int iffamily) {
    if (family == AF_UNSPEC) {
      return iffamily == AF_INET || iffamily == AF_INET6;
    }
    if (family == AF_INET || family == AF_INET6) {
      return iffamily == family;
    }
    return false;
  };

  struct ifaddrs* ifaddr = nullptr;
  if (getifaddrs(&ifaddr) == -1) {
    int errNum = SOCKET_ERRNO;
    A2_LOG_INFO(
        fmt(MSG_INTERFACE_NOT_FOUND, iface.c_str(), errorMsg(errNum).c_str()));
  }
  else {
    std::unique_ptr<ifaddrs, decltype(&freeifaddrs)> ifaddrDeleter(ifaddr,
                                                                   freeifaddrs);
    for (ifaddrs* ifa = ifaddr; ifa; ifa = ifa->ifa_next) {
      if (!ifa->ifa_addr) {
        continue;
      }
      const int iffamily = ifa->ifa_addr->sa_family;
      if (!familyAccepted(iffamily) ||
          strcmp(iface.c_str(), ifa->ifa_name) != 0) {
        continue;
      }
      SockAddr soaddr;
      soaddr.suLength =
          iffamily == AF_INET ? sizeof(sockaddr_in) : sizeof(sockaddr_in6);
      memcpy(&soaddr.su, ifa->ifa_addr, soaddr.suLength);
      found.push_back(soaddr);
    }
  }
#endif // HAVE_GETIFADDRS
  if (!found.empty()) {
    return found;
  }

  addrinfo* res;
  const int gaiErr = callGetaddrinfo(&res, iface.c_str(), nullptr, family,
                                     SOCK_STREAM, aiFlags, 0);
  if (gaiErr) {
    A2_LOG_INFO(
        fmt(MSG_INTERFACE_NOT_FOUND, iface.c_str(), gai_strerror(gaiErr)));
    return found;
  }
  std::unique_ptr<addrinfo, decltype(&freeaddrinfo)> resDeleter(res,
                                                                 freeaddrinfo);
  for (addrinfo* rp = res; rp; rp = rp->ai_next) {
    // Binding only succeeds for addresses that belong to this machine.
    try {
      SocketCore probe;
      probe.bind(rp->ai_addr, rp->ai_addrlen);
      SockAddr soaddr;
      memcpy(&soaddr.su, rp->ai_addr, rp->ai_addrlen);
      soaddr.suLength = rp->ai_addrlen;
      found.push_back(soaddr);
    }
    catch (RecoverableException&) {
      continue;
    }
  }
  return found;
}
1430
1431 namespace {
1432
1433 int defaultAIFlags = DEFAULT_AI_FLAGS;
1434
getDefaultAIFlags()1435 int getDefaultAIFlags() { return defaultAIFlags; }
1436
1437 } // namespace
1438
setDefaultAIFlags(int flags)1439 void setDefaultAIFlags(int flags) { defaultAIFlags = flags; }
1440
callGetaddrinfo(struct addrinfo ** resPtr,const char * host,const char * service,int family,int sockType,int flags,int protocol)1441 int callGetaddrinfo(struct addrinfo** resPtr, const char* host,
1442 const char* service, int family, int sockType, int flags,
1443 int protocol)
1444 {
1445 struct addrinfo hints;
1446 memset(&hints, 0, sizeof(hints));
1447 hints.ai_family = family;
1448 hints.ai_socktype = sockType;
1449 hints.ai_flags = getDefaultAIFlags();
1450 hints.ai_flags |= flags;
1451 hints.ai_protocol = protocol;
1452 return getaddrinfo(host, service, &hints, resPtr);
1453 }
1454
inetNtop(int af,const void * src,char * dst,socklen_t size)1455 int inetNtop(int af, const void* src, char* dst, socklen_t size)
1456 {
1457 sockaddr_union su;
1458 memset(&su, 0, sizeof(su));
1459 if (af == AF_INET) {
1460 su.in.sin_family = AF_INET;
1461 #ifdef HAVE_SOCKADDR_IN_SIN_LEN
1462 su.in.sin_len = sizeof(su.in);
1463 #endif // HAVE_SOCKADDR_IN_SIN_LEN
1464 memcpy(&su.in.sin_addr, src, sizeof(su.in.sin_addr));
1465 return getnameinfo(&su.sa, sizeof(su.in), dst, size, nullptr, 0,
1466 NI_NUMERICHOST);
1467 }
1468 if (af == AF_INET6) {
1469 su.in6.sin6_family = AF_INET6;
1470 #ifdef HAVE_SOCKADDR_IN6_SIN6_LEN
1471 su.in6.sin6_len = sizeof(su.in6);
1472 #endif // HAVE_SOCKADDR_IN6_SIN6_LEN
1473 memcpy(&su.in6.sin6_addr, src, sizeof(su.in6.sin6_addr));
1474 return getnameinfo(&su.sa, sizeof(su.in6), dst, size, nullptr, 0,
1475 NI_NUMERICHOST);
1476 }
1477 return EAI_FAMILY;
1478 }
1479
inetPton(int af,const char * src,void * dst)1480 int inetPton(int af, const char* src, void* dst)
1481 {
1482 union {
1483 uint32_t ipv4_addr;
1484 unsigned char ipv6_addr[16];
1485 } binaddr;
1486 size_t len = net::getBinAddr(binaddr.ipv6_addr, src);
1487 if (af == AF_INET) {
1488 if (len != 4) {
1489 return -1;
1490 }
1491 in_addr* addr = reinterpret_cast<in_addr*>(dst);
1492 addr->s_addr = binaddr.ipv4_addr;
1493 return 0;
1494 }
1495 if (af == AF_INET6) {
1496 if (len != 16) {
1497 return -1;
1498 }
1499 in6_addr* addr = reinterpret_cast<in6_addr*>(dst);
1500 memcpy(addr->s6_addr, binaddr.ipv6_addr, sizeof(addr->s6_addr));
1501 return 0;
1502 }
1503 return -1;
1504 }
1505
1506 namespace net {
1507
/**
 * Converts the numeric IPv4/IPv6 address string `ip` into its binary form.
 * The string is resolved with AI_NUMERICHOST, so no name lookup happens.
 * The first IPv4 or IPv6 result is copied into `dest` as an in_addr or
 * in6_addr.
 *
 * @param dest receives the address; must hold at least sizeof(in6_addr)
 *        (16) bytes.
 * @return the number of bytes written (sizeof(in_addr) or
 *         sizeof(in6_addr)), or 0 if `ip` could not be converted.
 */
size_t getBinAddr(void* dest, const std::string& ip)
{
  addrinfo* head;
  if (callGetaddrinfo(&head, ip.c_str(), nullptr, AF_UNSPEC, 0, AI_NUMERICHOST,
                      0) != 0) {
    return 0;
  }
  std::unique_ptr<addrinfo, decltype(&freeaddrinfo)> headGuard(head,
                                                               freeaddrinfo);
  for (addrinfo* cand = head; cand; cand = cand->ai_next) {
    sockaddr_union su;
    memcpy(&su, cand->ai_addr, cand->ai_addrlen);
    const void* raw;
    size_t rawLen;
    if (cand->ai_family == AF_INET) {
      raw = &su.in.sin_addr;
      rawLen = sizeof(in_addr);
    }
    else if (cand->ai_family == AF_INET6) {
      raw = &su.in6.sin6_addr;
      rawLen = sizeof(in6_addr);
    }
    else {
      continue;
    }
    memcpy(dest, raw, rawLen);
    return rawLen;
  }
  return 0;
}
1534
verifyHostname(const std::string & hostname,const std::vector<std::string> & dnsNames,const std::vector<std::string> & ipAddrs,const std::string & commonName)1535 bool verifyHostname(const std::string& hostname,
1536 const std::vector<std::string>& dnsNames,
1537 const std::vector<std::string>& ipAddrs,
1538 const std::string& commonName)
1539 {
1540 if (util::isNumericHost(hostname)) {
1541 if (ipAddrs.empty()) {
1542 return commonName == hostname;
1543 }
1544 // We need max 16 bytes to store IPv6 address.
1545 unsigned char binAddr[16];
1546 size_t addrLen = getBinAddr(binAddr, hostname);
1547 if (addrLen == 0) {
1548 return false;
1549 }
1550 for (auto& ipAddr : ipAddrs) {
1551 if (addrLen == ipAddr.size() &&
1552 memcmp(binAddr, ipAddr.c_str(), addrLen) == 0) {
1553 return true;
1554 }
1555 }
1556 return false;
1557 }
1558
1559 if (dnsNames.empty()) {
1560 return util::tlsHostnameMatch(commonName, hostname);
1561 }
1562 for (auto& dnsName : dnsNames) {
1563 if (util::tlsHostnameMatch(dnsName, hostname)) {
1564 return true;
1565 }
1566 }
1567 return false;
1568 }
1569
namespace {
// Results of checkAddrconfig(). Both start out true, i.e. both address
// families are assumed usable until checkAddrconfig() decides otherwise.
bool ipv4AddrConfigured{true};
bool ipv6AddrConfigured{true};
} // namespace

#ifdef __MINGW32__
namespace {
// Bounds used by checkAddrconfig() to ignore APIPA IPv4 addresses, as
// host-order integers. Addresses in [APIPA_IPV4_BEGIN, APIPA_IPV4_END) are
// not counted.
constexpr uint32_t APIPA_IPV4_BEGIN =
    (169u << 24) | (254u << 16) | 0x0001u; // 169.254.0.1
constexpr uint32_t APIPA_IPV4_END =
    (169u << 24) | (254u << 16) | 0xffffu; // 169.254.255.255
} // namespace
#endif // __MINGW32__
1581
checkAddrconfig()1582 void checkAddrconfig()
1583 {
1584 #ifdef HAVE_IPHLPAPI_H
1585 A2_LOG_INFO("Checking configured addresses");
1586 ULONG bufsize = 15_k;
1587 ULONG retval = 0;
1588 IP_ADAPTER_ADDRESSES* buf = 0;
1589 int numTry = 0;
1590 const int MAX_TRY = 3;
1591 do {
1592 buf = reinterpret_cast<IP_ADAPTER_ADDRESSES*>(malloc(bufsize));
1593 retval = GetAdaptersAddresses(AF_UNSPEC, 0, 0, buf, &bufsize);
1594 if (retval != ERROR_BUFFER_OVERFLOW) {
1595 break;
1596 }
1597 free(buf);
1598 buf = 0;
1599 } while (retval == ERROR_BUFFER_OVERFLOW && numTry < MAX_TRY);
1600 if (retval != NO_ERROR) {
1601 A2_LOG_INFO("GetAdaptersAddresses failed. Assume both IPv4 and IPv6 "
1602 " addresses are configured.");
1603 return;
1604 }
1605 ipv4AddrConfigured = false;
1606 ipv6AddrConfigured = false;
1607 char host[NI_MAXHOST];
1608 sockaddr_union ad;
1609 int rv;
1610 for (IP_ADAPTER_ADDRESSES* p = buf; p; p = p->Next) {
1611 if (p->IfType == IF_TYPE_TUNNEL) {
1612 // Skip tunnel interface because Windows7 automatically setup
1613 // this for IPv6.
1614 continue;
1615 }
1616 PIP_ADAPTER_UNICAST_ADDRESS ucaddr = p->FirstUnicastAddress;
1617 if (!ucaddr) {
1618 continue;
1619 }
1620 for (PIP_ADAPTER_UNICAST_ADDRESS i = ucaddr; i; i = i->Next) {
1621 bool found = false;
1622 switch (i->Address.iSockaddrLength) {
1623 case sizeof(sockaddr_in): {
1624 memcpy(&ad.storage, i->Address.lpSockaddr, i->Address.iSockaddrLength);
1625 uint32_t haddr = ntohl(ad.in.sin_addr.s_addr);
1626 if (haddr != INADDR_LOOPBACK &&
1627 (haddr < APIPA_IPV4_BEGIN || APIPA_IPV4_END <= haddr)) {
1628 ipv4AddrConfigured = true;
1629 found = true;
1630 }
1631 break;
1632 }
1633 case sizeof(sockaddr_in6):
1634 memcpy(&ad.storage, i->Address.lpSockaddr, i->Address.iSockaddrLength);
1635 if (!IN6_IS_ADDR_LOOPBACK(&ad.in6.sin6_addr) &&
1636 !IN6_IS_ADDR_LINKLOCAL(&ad.in6.sin6_addr)) {
1637 ipv6AddrConfigured = true;
1638 found = true;
1639 }
1640 break;
1641 }
1642 rv = getnameinfo(i->Address.lpSockaddr, i->Address.iSockaddrLength, host,
1643 NI_MAXHOST, 0, 0, NI_NUMERICHOST);
1644 if (rv == 0) {
1645 if (found) {
1646 A2_LOG_INFO(fmt("Found configured address: %s", host));
1647 }
1648 else {
1649 A2_LOG_INFO(fmt("Not considered: %s", host));
1650 }
1651 }
1652 }
1653 }
1654 free(buf);
1655
1656 A2_LOG_INFO(fmt("IPv4 configured=%d, IPv6 configured=%d", ipv4AddrConfigured,
1657 ipv6AddrConfigured));
1658 #elif defined(HAVE_GETIFADDRS)
1659 A2_LOG_INFO("Checking configured addresses");
1660 ipv4AddrConfigured = false;
1661 ipv6AddrConfigured = false;
1662 ifaddrs* ifaddr = nullptr;
1663 int rv;
1664 rv = getifaddrs(&ifaddr);
1665 if (rv == -1) {
1666 int errNum = SOCKET_ERRNO;
1667 A2_LOG_INFO(fmt("getifaddrs failed. Cause: %s", errorMsg(errNum).c_str()));
1668 return;
1669 }
1670 std::unique_ptr<ifaddrs, decltype(&freeifaddrs)> ifaddrDeleter(ifaddr,
1671 freeifaddrs);
1672 char host[NI_MAXHOST];
1673 sockaddr_union ad;
1674 for (ifaddrs* ifa = ifaddr; ifa; ifa = ifa->ifa_next) {
1675 if (!ifa->ifa_addr) {
1676 continue;
1677 }
1678 bool found = false;
1679 size_t addrlen = 0;
1680 switch (ifa->ifa_addr->sa_family) {
1681 case AF_INET: {
1682 addrlen = sizeof(sockaddr_in);
1683 memcpy(&ad.storage, ifa->ifa_addr, addrlen);
1684 if (ad.in.sin_addr.s_addr != htonl(INADDR_LOOPBACK)) {
1685 ipv4AddrConfigured = true;
1686 found = true;
1687 }
1688 break;
1689 }
1690 case AF_INET6: {
1691 addrlen = sizeof(sockaddr_in6);
1692 memcpy(&ad.storage, ifa->ifa_addr, addrlen);
1693 if (!IN6_IS_ADDR_LOOPBACK(&ad.in6.sin6_addr) &&
1694 !IN6_IS_ADDR_LINKLOCAL(&ad.in6.sin6_addr)) {
1695 ipv6AddrConfigured = true;
1696 found = true;
1697 }
1698 break;
1699 }
1700 default:
1701 continue;
1702 }
1703 rv = getnameinfo(ifa->ifa_addr, addrlen, host, NI_MAXHOST, nullptr, 0,
1704 NI_NUMERICHOST);
1705 if (rv == 0) {
1706 if (found) {
1707 A2_LOG_INFO(fmt("Found configured address: %s", host));
1708 }
1709 else {
1710 A2_LOG_INFO(fmt("Not considered: %s", host));
1711 }
1712 }
1713 }
1714 A2_LOG_INFO(fmt("IPv4 configured=%d, IPv6 configured=%d", ipv4AddrConfigured,
1715 ipv6AddrConfigured));
1716 #else // !HAVE_GETIFADDRS
1717 A2_LOG_INFO("getifaddrs is not available. Assume IPv4 and IPv6 addresses"
1718 " are configured.");
1719 #endif // !HAVE_GETIFADDRS
1720 }
1721
getIPv4AddrConfigured()1722 bool getIPv4AddrConfigured() { return ipv4AddrConfigured; }
1723
getIPv6AddrConfigured()1724 bool getIPv6AddrConfigured() { return ipv6AddrConfigured; }
1725
1726 } // namespace net
1727
1728 } // namespace aria2
1729