1 /* <!-- copyright */
2 /*
3  * aria2 - The high speed download utility
4  *
5  * Copyright (C) 2006 Tatsuhiro Tsujikawa
6  *
7  * This program is free software; you can redistribute it and/or modify
8  * it under the terms of the GNU General Public License as published by
9  * the Free Software Foundation; either version 2 of the License, or
10  * (at your option) any later version.
11  *
12  * This program is distributed in the hope that it will be useful,
13  * but WITHOUT ANY WARRANTY; without even the implied warranty of
14  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15  * GNU General Public License for more details.
16  *
17  * You should have received a copy of the GNU General Public License
18  * along with this program; if not, write to the Free Software
19  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
20  *
21  * In addition, as a special exception, the copyright holders give
22  * permission to link the code of portions of this program with the
23  * OpenSSL library under certain conditions as described in each
24  * individual source file, and distribute linked combinations
25  * including the two.
26  * You must obey the GNU General Public License in all respects
27  * for all of the code used other than OpenSSL.  If you modify
28  * file(s) with this exception, you may extend this exception to your
29  * version of the file(s), but you are not obligated to do so.  If you
30  * do not wish to do so, delete this exception statement from your
31  * version.  If you delete this exception statement from all source
32  * files in the program, then also delete it here.
33  */
34 /* copyright --> */
35 #include "SocketCore.h"
36 
37 #ifdef HAVE_IPHLPAPI_H
38 #  include <iphlpapi.h>
39 #endif // HAVE_IPHLPAPI_H
40 
41 #include <unistd.h>
42 #ifdef HAVE_IFADDRS_H
43 #  include <ifaddrs.h>
44 #endif // HAVE_IFADDRS_H
45 
46 #include <cerrno>
47 #include <cstring>
48 #include <cassert>
49 #include <sstream>
50 #include <array>
51 
52 #include "message.h"
53 #include "DlRetryEx.h"
54 #include "DlAbortEx.h"
55 #include "fmt.h"
56 #include "util.h"
57 #include "TimeA2.h"
58 #include "a2functional.h"
59 #include "LogFactory.h"
60 #include "A2STR.h"
61 #ifdef ENABLE_SSL
62 #  include "TLSContext.h"
63 #  include "TLSSession.h"
64 #endif // ENABLE_SSL
65 #ifdef HAVE_LIBSSH2
66 #  include "SSHSession.h"
67 #endif // HAVE_LIBSSH2
68 
69 namespace aria2 {
70 
// SOCKET_ERRNO yields the error code of the last failed socket call:
// errno on POSIX, WSAGetLastError() on Windows (Winsock does not set errno).
#ifndef __MINGW32__
#  define SOCKET_ERRNO (errno)
#else
#  define SOCKET_ERRNO (WSAGetLastError())
#endif // __MINGW32__

// Portable error codes compared against SOCKET_ERRNO.  On Windows,
// A2_EINPROGRESS maps to WSAEWOULDBLOCK (it is compared against the result
// of a non-blocking connect()).  A2_WOULDBLOCK(e) is true when e means
// "operation would block"; on POSIX it also accepts EAGAIN when that value
// differs from EWOULDBLOCK.  The macro argument is parenthesized so that
// expression arguments expand with the intended precedence.
#ifdef __MINGW32__
#  define A2_EINPROGRESS WSAEWOULDBLOCK
#  define A2_EWOULDBLOCK WSAEWOULDBLOCK
#  define A2_EINTR WSAEINTR
#  define A2_WOULDBLOCK(e) ((e) == WSAEWOULDBLOCK)
#else // !__MINGW32__
#  define A2_EINPROGRESS EINPROGRESS
#  ifndef EWOULDBLOCK
#    define EWOULDBLOCK EAGAIN
#  endif // EWOULDBLOCK
#  define A2_EWOULDBLOCK EWOULDBLOCK
#  define A2_EINTR EINTR
#  if EWOULDBLOCK == EAGAIN
#    define A2_WOULDBLOCK(e) ((e) == EWOULDBLOCK)
#  else // EWOULDBLOCK != EAGAIN
#    define A2_WOULDBLOCK(e) ((e) == EWOULDBLOCK || (e) == EAGAIN)
#  endif // EWOULDBLOCK != EAGAIN
#endif   // !__MINGW32__

// Close a socket descriptor; Winsock sockets must be closed with
// closesocket() rather than close().
#ifdef __MINGW32__
#  define CLOSE(X) ::closesocket(X)
#else
#  define CLOSE(X) close(X)
#endif // __MINGW32__
101 
102 namespace {
errorMsg(int errNum)103 std::string errorMsg(int errNum)
104 {
105 #ifndef __MINGW32__
106   return util::safeStrerror(errNum);
107 #else
108   auto msg = util::formatLastError(errNum);
109   if (msg.empty()) {
110     char buf[256];
111     snprintf(buf, sizeof(buf), EX_SOCKET_UNKNOWN_ERROR, errNum, errNum);
112     return buf;
113   }
114   return msg;
115 #endif // __MINGW32__
116 }
117 } // namespace
118 
namespace {
// Lifecycle of the TLS layer on a SocketCore, stored in secure_.
// A2_TLS_NONE is 0, so `!secure_` (as used by writeData/readData) selects
// the plain, non-TLS code paths.
enum TlsState {
  // TLS object is not initialized.
  A2_TLS_NONE = 0,
  // TLS object is now handshaking.
  A2_TLS_HANDSHAKING = 2,
  // TLS object is now connected.
  A2_TLS_CONNECTED = 3
};
} // namespace
129 
// Class-wide settings shared by every SocketCore instance.

// Address family passed to getaddrinfo() by establishConnection() and used
// by bind(uint16_t, int); AF_UNSPEC allows both IPv4 and IPv6.
int SocketCore::protocolFamily_ = AF_UNSPEC;
// DSCP value applied by applyIpDscp(); 0 means "leave unset".
int SocketCore::ipDscp_ = 0;

// Local addresses establishConnection() tries to bind() before connect().
std::vector<SockAddr> SocketCore::bindAddrs_;
// Groups of local addresses used by passive bind() and rotated through
// bindAddrs_ by establishConnection().
std::vector<std::vector<SockAddr>> SocketCore::bindAddrsList_;
// Current position of that rotation within bindAddrsList_.
std::vector<std::vector<SockAddr>>::iterator SocketCore::bindAddrsListIt_;

// SO_RCVBUF size for new sockets; 0 keeps the OS default (presumably what
// getSocketRecvBufferSize() returns — confirm against the header).
int SocketCore::socketRecvBufferSize_ = 0;
138 
#ifdef ENABLE_SSL
// Contexts used for client-side (tlsConnect) and server-side (tlsAccept)
// TLS handshakes.
std::shared_ptr<TLSContext> SocketCore::clTlsContext_;
std::shared_ptr<TLSContext> SocketCore::svTlsContext_;

// Install the context that tlsConnect() hands to tlsHandshake().
void SocketCore::setClientTLSContext(
    const std::shared_ptr<TLSContext>& tlsContext)
{
  SocketCore::clTlsContext_ = tlsContext;
}

// Install the context that tlsAccept() hands to tlsHandshake().
void SocketCore::setServerTLSContext(
    const std::shared_ptr<TLSContext>& tlsContext)
{
  SocketCore::svTlsContext_ = tlsContext;
}
#endif // ENABLE_SSL
155 
SocketCore(int sockType)156 SocketCore::SocketCore(int sockType) : sockType_(sockType), sockfd_(-1)
157 {
158   init();
159 }
160 
SocketCore(sock_t sockfd,int sockType)161 SocketCore::SocketCore(sock_t sockfd, int sockType)
162     : sockType_(sockType), sockfd_(sockfd)
163 {
164   init();
165 }
166 
init()167 void SocketCore::init()
168 {
169   blocking_ = true;
170   secure_ = A2_TLS_NONE;
171 
172   wantRead_ = false;
173   wantWrite_ = false;
174 }
175 
~SocketCore()176 SocketCore::~SocketCore() { closeConnection(); }
177 
178 namespace {
applySocketBufferSize(sock_t fd)179 void applySocketBufferSize(sock_t fd)
180 {
181   auto recvBufSize = SocketCore::getSocketRecvBufferSize();
182   if (recvBufSize == 0) {
183     return;
184   }
185 
186   if (setsockopt(fd, SOL_SOCKET, SO_RCVBUF, (a2_sockopt_t)&recvBufSize,
187                  sizeof(recvBufSize)) < 0) {
188     auto errNum = SOCKET_ERRNO;
189     A2_LOG_WARN(fmt("Failed to set socket buffer size. Cause: %s",
190                     errorMsg(errNum).c_str()));
191   }
192 }
193 } // namespace
194 
create(int family,int protocol)195 void SocketCore::create(int family, int protocol)
196 {
197   int errNum;
198   closeConnection();
199   sock_t fd = socket(family, sockType_, protocol);
200   errNum = SOCKET_ERRNO;
201   if (fd == (sock_t)-1) {
202     throw DL_ABORT_EX(
203         fmt("Failed to create socket. Cause:%s", errorMsg(errNum).c_str()));
204   }
205   util::make_fd_cloexec(fd);
206   int sockopt = 1;
207   if (setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, (a2_sockopt_t)&sockopt,
208                  sizeof(sockopt)) < 0) {
209     errNum = SOCKET_ERRNO;
210     CLOSE(fd);
211     throw DL_ABORT_EX(
212         fmt("Failed to create socket. Cause:%s", errorMsg(errNum).c_str()));
213   }
214 
215   applySocketBufferSize(fd);
216 
217   sockfd_ = fd;
218 }
219 
bindInternal(int family,int socktype,int protocol,const struct sockaddr * addr,socklen_t addrlen,std::string & error)220 static sock_t bindInternal(int family, int socktype, int protocol,
221                            const struct sockaddr* addr, socklen_t addrlen,
222                            std::string& error)
223 {
224   int errNum;
225   sock_t fd = socket(family, socktype, protocol);
226   errNum = SOCKET_ERRNO;
227   if (fd == (sock_t)-1) {
228     error = errorMsg(errNum);
229     return -1;
230   }
231   util::make_fd_cloexec(fd);
232   int sockopt = 1;
233   if (setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, (a2_sockopt_t)&sockopt,
234                  sizeof(sockopt)) < 0) {
235     errNum = SOCKET_ERRNO;
236     error = errorMsg(errNum);
237     CLOSE(fd);
238     return -1;
239   }
240 #ifdef IPV6_V6ONLY
241   if (family == AF_INET6) {
242     int sockopt = 1;
243     if (setsockopt(fd, IPPROTO_IPV6, IPV6_V6ONLY, (a2_sockopt_t)&sockopt,
244                    sizeof(sockopt)) < 0) {
245       errNum = SOCKET_ERRNO;
246       error = errorMsg(errNum);
247       CLOSE(fd);
248       return -1;
249     }
250   }
251 #endif // IPV6_V6ONLY
252 
253   applySocketBufferSize(fd);
254 
255   if (::bind(fd, addr, addrlen) == -1) {
256     errNum = SOCKET_ERRNO;
257     error = errorMsg(errNum);
258     CLOSE(fd);
259     return -1;
260   }
261   return fd;
262 }
263 
bindTo(const char * host,uint16_t port,int family,int sockType,int getaddrinfoFlags,std::string & error)264 static sock_t bindTo(const char* host, uint16_t port, int family, int sockType,
265                      int getaddrinfoFlags, std::string& error)
266 {
267   struct addrinfo* res;
268   int s = callGetaddrinfo(&res, host, util::uitos(port).c_str(), family,
269                           sockType, getaddrinfoFlags, 0);
270   if (s) {
271     error = gai_strerror(s);
272     return -1;
273   }
274   std::unique_ptr<addrinfo, decltype(&freeaddrinfo)> resDeleter(res,
275                                                                 freeaddrinfo);
276   struct addrinfo* rp;
277   for (rp = res; rp; rp = rp->ai_next) {
278     sock_t fd = bindInternal(rp->ai_family, rp->ai_socktype, rp->ai_protocol,
279                              rp->ai_addr, rp->ai_addrlen, error);
280     if (fd != (sock_t)-1) {
281       return fd;
282     }
283   }
284   return -1;
285 }
286 
bindWithFamily(uint16_t port,int family,int flags)287 void SocketCore::bindWithFamily(uint16_t port, int family, int flags)
288 {
289   closeConnection();
290   std::string error;
291   sock_t fd = bindTo(nullptr, port, family, sockType_, flags, error);
292   if (fd == (sock_t)-1) {
293     throw DL_ABORT_EX(fmt(EX_SOCKET_BIND, error.c_str()));
294   }
295   sockfd_ = fd;
296 }
297 
bind(const char * addr,uint16_t port,int family,int flags)298 void SocketCore::bind(const char* addr, uint16_t port, int family, int flags)
299 {
300   closeConnection();
301   std::string error;
302   const char* addrp;
303   if (addr && addr[0]) {
304     addrp = addr;
305   }
306   else {
307     addrp = nullptr;
308   }
309   if (addrp || !(flags & AI_PASSIVE) || bindAddrsList_.empty()) {
310     sock_t fd = bindTo(addrp, port, family, sockType_, flags, error);
311     if (fd == (sock_t)-1) {
312       throw DL_ABORT_EX(fmt(EX_SOCKET_BIND, error.c_str()));
313     }
314     sockfd_ = fd;
315     return;
316   }
317 
318   std::array<char, NI_MAXHOST> host;
319   for (const auto& bindAddrs : bindAddrsList_) {
320     for (const auto& a : bindAddrs) {
321       if (family != AF_UNSPEC && family != a.su.storage.ss_family) {
322         continue;
323       }
324       auto s = getnameinfo(&a.su.sa, a.suLength, host.data(), NI_MAXHOST,
325                            nullptr, 0, NI_NUMERICHOST);
326       if (s) {
327         error = gai_strerror(s);
328         continue;
329       }
330       if (addrp && strcmp(host.data(), addrp) != 0) {
331         error = "Given address and resolved address do not match.";
332         continue;
333       }
334       auto fd = bindTo(host.data(), port, family, sockType_, flags, error);
335       if (fd != (sock_t)-1) {
336         sockfd_ = fd;
337         return;
338       }
339     }
340   }
341 
342   if (sockfd_ == (sock_t)-1) {
343     throw DL_ABORT_EX(fmt(EX_SOCKET_BIND, error.c_str()));
344   }
345 }
346 
bind(uint16_t port,int flags)347 void SocketCore::bind(uint16_t port, int flags)
348 {
349   bind(nullptr, port, protocolFamily_, flags);
350 }
351 
bind(const struct sockaddr * addr,socklen_t addrlen)352 void SocketCore::bind(const struct sockaddr* addr, socklen_t addrlen)
353 {
354   closeConnection();
355   std::string error;
356   sock_t fd = bindInternal(addr->sa_family, sockType_, 0, addr, addrlen, error);
357   if (fd == (sock_t)-1) {
358     throw DL_ABORT_EX(fmt(EX_SOCKET_BIND, error.c_str()));
359   }
360   sockfd_ = fd;
361 }
362 
beginListen()363 void SocketCore::beginListen()
364 {
365   if (listen(sockfd_, 1024) == -1) {
366     int errNum = SOCKET_ERRNO;
367     throw DL_ABORT_EX(fmt(EX_SOCKET_LISTEN, errorMsg(errNum).c_str()));
368   }
369   setNonBlockingMode();
370 }
371 
acceptConnection() const372 std::shared_ptr<SocketCore> SocketCore::acceptConnection() const
373 {
374   sockaddr_union sockaddr;
375   socklen_t len = sizeof(sockaddr);
376   sock_t fd;
377   while ((fd = accept(sockfd_, &sockaddr.sa, &len)) == (sock_t)-1 &&
378          SOCKET_ERRNO == A2_EINTR)
379     ;
380   int errNum = SOCKET_ERRNO;
381   if (fd == (sock_t)-1) {
382     throw DL_ABORT_EX(fmt(EX_SOCKET_ACCEPT, errorMsg(errNum).c_str()));
383   }
384 
385   applySocketBufferSize(fd);
386 
387   auto sock = std::make_shared<SocketCore>(fd, sockType_);
388   sock->setNonBlockingMode();
389   return sock;
390 }
391 
getAddrInfo() const392 Endpoint SocketCore::getAddrInfo() const
393 {
394   sockaddr_union sockaddr;
395   socklen_t len = sizeof(sockaddr);
396   getAddrInfo(sockaddr, len);
397   return util::getNumericNameInfo(&sockaddr.sa, len);
398 }
399 
getAddrInfo(sockaddr_union & sockaddr,socklen_t & len) const400 void SocketCore::getAddrInfo(sockaddr_union& sockaddr, socklen_t& len) const
401 {
402   if (getsockname(sockfd_, &sockaddr.sa, &len) == -1) {
403     int errNum = SOCKET_ERRNO;
404     throw DL_ABORT_EX(fmt(EX_SOCKET_GET_NAME, errorMsg(errNum).c_str()));
405   }
406 }
407 
getAddressFamily() const408 int SocketCore::getAddressFamily() const
409 {
410   sockaddr_union sockaddr;
411   socklen_t len = sizeof(sockaddr);
412   getAddrInfo(sockaddr, len);
413   return sockaddr.storage.ss_family;
414 }
415 
getPeerInfo() const416 Endpoint SocketCore::getPeerInfo() const
417 {
418   sockaddr_union sockaddr;
419   socklen_t len = sizeof(sockaddr);
420   if (getpeername(sockfd_, &sockaddr.sa, &len) == -1) {
421     int errNum = SOCKET_ERRNO;
422     throw DL_ABORT_EX(fmt(EX_SOCKET_GET_NAME, errorMsg(errNum).c_str()));
423   }
424   return util::getNumericNameInfo(&sockaddr.sa, len);
425 }
426 
establishConnection(const std::string & host,uint16_t port,bool tcpNodelay)427 void SocketCore::establishConnection(const std::string& host, uint16_t port,
428                                      bool tcpNodelay)
429 {
430   closeConnection();
431   std::string error;
432   struct addrinfo* res;
433   int s;
434   s = callGetaddrinfo(&res, host.c_str(), util::uitos(port).c_str(),
435                       protocolFamily_, sockType_, 0, 0);
436   if (s) {
437     throw DL_ABORT_EX(fmt(EX_RESOLVE_HOSTNAME, host.c_str(), gai_strerror(s)));
438   }
439   std::unique_ptr<addrinfo, decltype(&freeaddrinfo)> resDeleter(res,
440                                                                 freeaddrinfo);
441   struct addrinfo* rp;
442   int errNum;
443   for (rp = res; rp; rp = rp->ai_next) {
444     sock_t fd = socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol);
445     errNum = SOCKET_ERRNO;
446     if (fd == (sock_t)-1) {
447       error = errorMsg(errNum);
448       continue;
449     }
450     util::make_fd_cloexec(fd);
451     int sockopt = 1;
452     if (setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, (a2_sockopt_t)&sockopt,
453                    sizeof(sockopt)) < 0) {
454       errNum = SOCKET_ERRNO;
455       error = errorMsg(errNum);
456       CLOSE(fd);
457       continue;
458     }
459 
460     applySocketBufferSize(fd);
461 
462     if (!bindAddrs_.empty()) {
463       bool bindSuccess = false;
464       for (const auto& soaddr : bindAddrs_) {
465         if (::bind(fd, &soaddr.su.sa, soaddr.suLength) == -1) {
466           errNum = SOCKET_ERRNO;
467           error = errorMsg(errNum);
468           A2_LOG_DEBUG(fmt(EX_SOCKET_BIND, error.c_str()));
469         }
470         else {
471           bindSuccess = true;
472           break;
473         }
474       }
475       if (!bindSuccess) {
476         CLOSE(fd);
477         continue;
478       }
479     }
480     if (!bindAddrsList_.empty()) {
481       ++bindAddrsListIt_;
482       if (bindAddrsListIt_ == bindAddrsList_.end()) {
483         bindAddrsListIt_ = bindAddrsList_.begin();
484       }
485       bindAddrs_ = *bindAddrsListIt_;
486     }
487 
488     sockfd_ = fd;
489     // make socket non-blocking mode
490     setNonBlockingMode();
491     if (tcpNodelay) {
492       setTcpNodelay(true);
493     }
494     if (connect(fd, rp->ai_addr, rp->ai_addrlen) == -1 &&
495         SOCKET_ERRNO != A2_EINPROGRESS) {
496       errNum = SOCKET_ERRNO;
497       error = errorMsg(errNum);
498       CLOSE(sockfd_);
499       sockfd_ = (sock_t)-1;
500       continue;
501     }
502     // TODO at this point, connection may not be established and it may fail
503     // later. In such case, next ai_addr should be tried.
504     break;
505   }
506   if (sockfd_ == (sock_t)-1) {
507     throw DL_ABORT_EX(fmt(EX_SOCKET_CONNECT, host.c_str(), error.c_str()));
508   }
509 }
510 
setSockOpt(int level,int optname,void * optval,socklen_t optlen)511 void SocketCore::setSockOpt(int level, int optname, void* optval,
512                             socklen_t optlen)
513 {
514   if (setsockopt(sockfd_, level, optname, (a2_sockopt_t)optval, optlen) < 0) {
515     int errNum = SOCKET_ERRNO;
516     throw DL_ABORT_EX(fmt(EX_SOCKET_SET_OPT, errorMsg(errNum).c_str()));
517   }
518 }
519 
setMulticastInterface(const std::string & localAddr)520 void SocketCore::setMulticastInterface(const std::string& localAddr)
521 {
522   in_addr addr;
523   if (localAddr.empty()) {
524     addr.s_addr = htonl(INADDR_ANY);
525   }
526   else if (inetPton(AF_INET, localAddr.c_str(), &addr) != 0) {
527     throw DL_ABORT_EX(
528         fmt("%s is not valid IPv4 numeric address", localAddr.c_str()));
529   }
530   setSockOpt(IPPROTO_IP, IP_MULTICAST_IF, &addr, sizeof(addr));
531 }
532 
setMulticastTtl(unsigned char ttl)533 void SocketCore::setMulticastTtl(unsigned char ttl)
534 {
535   setSockOpt(IPPROTO_IP, IP_MULTICAST_TTL, &ttl, sizeof(ttl));
536 }
537 
setMulticastLoop(unsigned char loop)538 void SocketCore::setMulticastLoop(unsigned char loop)
539 {
540   setSockOpt(IPPROTO_IP, IP_MULTICAST_LOOP, &loop, sizeof(loop));
541 }
542 
joinMulticastGroup(const std::string & multicastAddr,uint16_t multicastPort,const std::string & localAddr)543 void SocketCore::joinMulticastGroup(const std::string& multicastAddr,
544                                     uint16_t multicastPort,
545                                     const std::string& localAddr)
546 {
547   in_addr multiAddr;
548   if (inetPton(AF_INET, multicastAddr.c_str(), &multiAddr) != 0) {
549     throw DL_ABORT_EX(
550         fmt("%s is not valid IPv4 numeric address", multicastAddr.c_str()));
551   }
552   in_addr ifAddr;
553   if (localAddr.empty()) {
554     ifAddr.s_addr = htonl(INADDR_ANY);
555   }
556   else if (inetPton(AF_INET, localAddr.c_str(), &ifAddr) != 0) {
557     throw DL_ABORT_EX(
558         fmt("%s is not valid IPv4 numeric address", localAddr.c_str()));
559   }
560   struct ip_mreq mreq;
561   memset(&mreq, 0, sizeof(mreq));
562   mreq.imr_multiaddr = multiAddr;
563   mreq.imr_interface = ifAddr;
564   setSockOpt(IPPROTO_IP, IP_ADD_MEMBERSHIP, &mreq, sizeof(mreq));
565 }
566 
setTcpNodelay(bool f)567 void SocketCore::setTcpNodelay(bool f)
568 {
569   int val = f;
570   setSockOpt(IPPROTO_TCP, TCP_NODELAY, &val, sizeof(val));
571 }
572 
applyIpDscp()573 void SocketCore::applyIpDscp()
574 {
575   if (ipDscp_ == 0) {
576     return;
577   }
578 
579   try {
580     int family = getAddressFamily();
581     if (family == AF_INET) {
582       setSockOpt(IPPROTO_IP, IP_TOS, &ipDscp_, sizeof(ipDscp_));
583     }
584 #if defined(IPV6_TCLASS) || defined(__linux__) || defined(__FreeBSD__) ||      \
585     defined(__NetBSD__) || defined(__OpenBSD__) || defined(__DragonFly__)
586     else if (family == AF_INET6) {
587       setSockOpt(IPPROTO_IPV6, IPV6_TCLASS, &ipDscp_, sizeof(ipDscp_));
588     }
589 #endif
590   }
591   catch (RecoverableException& e) {
592     A2_LOG_INFO_EX("Applying DSCP value failed", e);
593   }
594 }
595 
setNonBlockingMode()596 void SocketCore::setNonBlockingMode()
597 {
598 #ifdef __MINGW32__
599   static u_long flag = 1;
600   if (::ioctlsocket(sockfd_, FIONBIO, &flag) == -1) {
601     int errNum = SOCKET_ERRNO;
602     throw DL_ABORT_EX(fmt(EX_SOCKET_NONBLOCKING, errorMsg(errNum).c_str()));
603   }
604 #else
605   int flags;
606   while ((flags = fcntl(sockfd_, F_GETFL, 0)) == -1 && errno == EINTR)
607     ;
608   // TODO add error handling
609   while (fcntl(sockfd_, F_SETFL, flags | O_NONBLOCK) == -1 && errno == EINTR)
610     ;
611 #endif // __MINGW32__
612   blocking_ = false;
613 }
614 
setBlockingMode()615 void SocketCore::setBlockingMode()
616 {
617 #ifdef __MINGW32__
618   static u_long flag = 0;
619   if (::ioctlsocket(sockfd_, FIONBIO, &flag) == -1) {
620     int errNum = SOCKET_ERRNO;
621     throw DL_ABORT_EX(fmt(EX_SOCKET_BLOCKING, errorMsg(errNum).c_str()));
622   }
623 #else
624   int flags;
625   while ((flags = fcntl(sockfd_, F_GETFL, 0)) == -1 && errno == EINTR)
626     ;
627   // TODO add error handling
628   while (fcntl(sockfd_, F_SETFL, flags & (~O_NONBLOCK)) == -1 && errno == EINTR)
629     ;
630 #endif // __MINGW32__
631   blocking_ = true;
632 }
633 
closeConnection()634 void SocketCore::closeConnection()
635 {
636 #ifdef ENABLE_SSL
637   if (tlsSession_) {
638     tlsSession_->closeConnection();
639     tlsSession_.reset();
640   }
641 #endif // ENABLE_SSL
642 
643 #ifdef HAVE_LIBSSH2
644   if (sshSession_) {
645     sshSession_->closeConnection();
646     sshSession_.reset();
647   }
648 #endif // HAVE_LIBSSH2
649 
650   if (sockfd_ != (sock_t)-1) {
651     shutdown(sockfd_, SHUT_WR);
652     CLOSE(sockfd_);
653     sockfd_ = -1;
654   }
655 }
656 
#ifndef __MINGW32__
// Guard for the select()-based fallbacks in isWritable()/isReadable(): a
// descriptor outside [0, FD_SETSIZE) cannot be stored in an fd_set, so warn
// and make the enclosing function return false.  Logs through A2_LOG_WARN
// (as the rest of this file does); this file defines no `logger_`.
#  define CHECK_FD(fd)                                                         \
    do {                                                                       \
      if ((fd) < 0 || FD_SETSIZE <= (fd)) {                                    \
        A2_LOG_WARN("Detected file descriptor >= FD_SETSIZE or < 0. "          \
                    "Download may slow down or fail.");                        \
        return false;                                                          \
      }                                                                        \
    } while (0)
#endif // !__MINGW32__
665 
isWritable(time_t timeout)666 bool SocketCore::isWritable(time_t timeout)
667 {
668 #ifdef HAVE_POLL
669   struct pollfd p;
670   p.fd = sockfd_;
671   p.events = POLLOUT;
672   int r;
673   while ((r = poll(&p, 1, timeout * 1000)) == -1 && errno == EINTR)
674     ;
675   int errNum = SOCKET_ERRNO;
676   if (r > 0) {
677     return p.revents & (POLLOUT | POLLHUP | POLLERR);
678   }
679   if (r == 0) {
680     return false;
681   }
682   throw DL_RETRY_EX(fmt(EX_SOCKET_CHECK_WRITABLE, errorMsg(errNum).c_str()));
683 #else // !HAVE_POLL
684 #  ifndef __MINGW32__
685   CHECK_FD(sockfd_);
686 #  endif // !__MINGW32__
687   fd_set fds;
688   FD_ZERO(&fds);
689   FD_SET(sockfd_, &fds);
690 
691   struct timeval tv;
692   tv.tv_sec = timeout;
693   tv.tv_usec = 0;
694 
695   int r = select(sockfd_ + 1, nullptr, &fds, nullptr, &tv);
696   int errNum = SOCKET_ERRNO;
697   if (r == 1) {
698     return true;
699   }
700   if (r == 0) {
701     // time out
702     return false;
703   }
704   if (errNum == A2_EINPROGRESS || errNum == A2_EINTR) {
705     return false;
706   }
707   throw DL_RETRY_EX(fmt(EX_SOCKET_CHECK_WRITABLE, errorMsg(errNum).c_str()));
708 #endif   // !HAVE_POLL
709 }
710 
isReadable(time_t timeout)711 bool SocketCore::isReadable(time_t timeout)
712 {
713 #ifdef HAVE_POLL
714   struct pollfd p;
715   p.fd = sockfd_;
716   p.events = POLLIN;
717   int r;
718   while ((r = poll(&p, 1, timeout * 1000)) == -1 && errno == EINTR)
719     ;
720   int errNum = SOCKET_ERRNO;
721   if (r > 0) {
722     return p.revents & (POLLIN | POLLHUP | POLLERR);
723   }
724   if (r == 0) {
725     return false;
726   }
727   throw DL_RETRY_EX(fmt(EX_SOCKET_CHECK_READABLE, errorMsg(errNum).c_str()));
728 #else // !HAVE_POLL
729 #  ifndef __MINGW32__
730   CHECK_FD(sockfd_);
731 #  endif // !__MINGW32__
732   fd_set fds;
733   FD_ZERO(&fds);
734   FD_SET(sockfd_, &fds);
735 
736   struct timeval tv;
737   tv.tv_sec = timeout;
738   tv.tv_usec = 0;
739 
740   int r = select(sockfd_ + 1, &fds, nullptr, nullptr, &tv);
741   int errNum = SOCKET_ERRNO;
742   if (r == 1) {
743     return true;
744   }
745   if (r == 0) {
746     // time out
747     return false;
748   }
749   if (errNum == A2_EINPROGRESS || errNum == A2_EINTR) {
750     return false;
751   }
752   throw DL_RETRY_EX(fmt(EX_SOCKET_CHECK_READABLE, errorMsg(errNum).c_str()));
753 #endif   // !HAVE_POLL
754 }
755 
writeVector(a2iovec * iov,size_t iovcnt)756 ssize_t SocketCore::writeVector(a2iovec* iov, size_t iovcnt)
757 {
758   ssize_t ret = 0;
759   wantRead_ = false;
760   wantWrite_ = false;
761   if (!secure_) {
762 #ifdef __MINGW32__
763     DWORD nsent;
764     int rv = WSASend(sockfd_, iov, iovcnt, &nsent, 0, 0, 0);
765     if (rv == 0) {
766       ret = nsent;
767     }
768     else {
769       ret = -1;
770     }
771 #else  // !__MINGW32__
772     while ((ret = writev(sockfd_, iov, iovcnt)) == -1 &&
773            SOCKET_ERRNO == A2_EINTR)
774       ;
775 #endif // !__MINGW32__
776     int errNum = SOCKET_ERRNO;
777     if (ret == -1) {
778       if (!A2_WOULDBLOCK(errNum)) {
779         throw DL_RETRY_EX(fmt(EX_SOCKET_SEND, errorMsg(errNum).c_str()));
780       }
781       wantWrite_ = true;
782       ret = 0;
783     }
784   }
785   else {
786     // For SSL/TLS, we could not use writev, so just iterate vector
787     // and write the data in normal way.
788     for (size_t i = 0; i < iovcnt; ++i) {
789       ssize_t rv = writeData(iov[i].A2IOVEC_BASE, iov[i].A2IOVEC_LEN);
790       if (rv == 0) {
791         break;
792       }
793       ret += rv;
794     }
795   }
796   return ret;
797 }
798 
writeData(const void * data,size_t len)799 ssize_t SocketCore::writeData(const void* data, size_t len)
800 {
801   ssize_t ret = 0;
802   wantRead_ = false;
803   wantWrite_ = false;
804 
805   if (!secure_) {
806     // Cast for Windows send()
807     while ((ret = send(sockfd_, reinterpret_cast<const char*>(data), len, 0)) ==
808                -1 &&
809            SOCKET_ERRNO == A2_EINTR)
810       ;
811     int errNum = SOCKET_ERRNO;
812     if (ret == -1) {
813       if (!A2_WOULDBLOCK(errNum)) {
814         throw DL_RETRY_EX(fmt(EX_SOCKET_SEND, errorMsg(errNum).c_str()));
815       }
816       wantWrite_ = true;
817       ret = 0;
818     }
819   }
820   else {
821 #ifdef ENABLE_SSL
822     ret = tlsSession_->writeData(data, len);
823     if (ret < 0) {
824       if (ret != TLS_ERR_WOULDBLOCK) {
825         throw DL_RETRY_EX(
826             fmt(EX_SOCKET_SEND, tlsSession_->getLastErrorString().c_str()));
827       }
828       if (tlsSession_->checkDirection() == TLS_WANT_READ) {
829         wantRead_ = true;
830       }
831       else {
832         wantWrite_ = true;
833       }
834       ret = 0;
835     }
836 #endif // ENABLE_SSL
837   }
838   return ret;
839 }
840 
readData(void * data,size_t & len)841 void SocketCore::readData(void* data, size_t& len)
842 {
843   ssize_t ret = 0;
844   wantRead_ = false;
845   wantWrite_ = false;
846 
847 #ifdef HAVE_LIBSSH2
848   if (sshSession_) {
849     ret = sshSession_->readData(data, len);
850     if (ret < 0) {
851       if (ret != SSH_ERR_WOULDBLOCK) {
852         throw DL_RETRY_EX(
853             fmt(EX_SOCKET_RECV, sshSession_->getLastErrorString().c_str()));
854       }
855       if (sshSession_->checkDirection() == SSH_WANT_READ) {
856         wantRead_ = true;
857       }
858       else {
859         wantWrite_ = true;
860       }
861       ret = 0;
862     }
863   }
864   else
865 #endif // HAVE_LIBSSH2
866       if (!secure_) {
867     // Cast for Windows recv()
868     while ((ret = recv(sockfd_, reinterpret_cast<char*>(data), len, 0)) == -1 &&
869            SOCKET_ERRNO == A2_EINTR)
870       ;
871     int errNum = SOCKET_ERRNO;
872     if (ret == -1) {
873       if (!A2_WOULDBLOCK(errNum)) {
874         throw DL_RETRY_EX(fmt(EX_SOCKET_RECV, errorMsg(errNum).c_str()));
875       }
876       wantRead_ = true;
877       ret = 0;
878     }
879   }
880   else {
881 #ifdef ENABLE_SSL
882     ret = tlsSession_->readData(data, len);
883     if (ret < 0) {
884       if (ret != TLS_ERR_WOULDBLOCK) {
885         throw DL_RETRY_EX(
886             fmt(EX_SOCKET_RECV, tlsSession_->getLastErrorString().c_str()));
887       }
888       if (tlsSession_->checkDirection() == TLS_WANT_READ) {
889         wantRead_ = true;
890       }
891       else {
892         wantWrite_ = true;
893       }
894       ret = 0;
895     }
896 #endif // ENABLE_SSL
897   }
898 
899   len = ret;
900 }
901 
902 #ifdef ENABLE_SSL
903 
tlsAccept()904 bool SocketCore::tlsAccept()
905 {
906   return tlsHandshake(svTlsContext_.get(), A2STR::NIL);
907 }
908 
tlsConnect(const std::string & hostname)909 bool SocketCore::tlsConnect(const std::string& hostname)
910 {
911   return tlsHandshake(clTlsContext_.get(), hostname);
912 }
913 
tlsHandshake(TLSContext * tlsctx,const std::string & hostname)914 bool SocketCore::tlsHandshake(TLSContext* tlsctx, const std::string& hostname)
915 {
916   wantRead_ = false;
917   wantWrite_ = false;
918 
919   if (secure_ == A2_TLS_CONNECTED) {
920     // Already connected!
921     return true;
922   }
923 
924   if (secure_ == A2_TLS_NONE) {
925     // Do some initial setup
926     A2_LOG_DEBUG("Creating TLS session");
927     tlsSession_.reset(TLSSession::make(tlsctx));
928     auto rv = tlsSession_->init(sockfd_);
929     if (rv != TLS_ERR_OK) {
930       std::string error = tlsSession_->getLastErrorString();
931       tlsSession_.reset();
932       throw DL_ABORT_EX(fmt(EX_SSL_INIT_FAILURE, error.c_str()));
933     }
934     // Check hostname is not numeric and it includes ".". Setting
935     // "localhost" will produce TLS alert with GNUTLS.
936     if (tlsctx->getSide() == TLS_CLIENT && !util::isNumericHost(hostname) &&
937         hostname.find(".") != std::string::npos) {
938       rv = tlsSession_->setSNIHostname(hostname);
939       if (rv != TLS_ERR_OK) {
940         throw DL_ABORT_EX(fmt(EX_SSL_INIT_FAILURE,
941                               tlsSession_->getLastErrorString().c_str()));
942       }
943     }
944     // Done with the setup, now let handshaking begin immediately.
945     secure_ = A2_TLS_HANDSHAKING;
946     A2_LOG_DEBUG("TLS Handshaking");
947   }
948 
949   if (secure_ == A2_TLS_HANDSHAKING) {
950     // Starting handshake after initial setup or still handshaking.
951     TLSVersion ver = TLS_PROTO_NONE;
952     int rv = 0;
953     std::string handshakeError;
954 
955     if (tlsctx->getSide() == TLS_CLIENT) {
956       rv = tlsSession_->tlsConnect(hostname, ver, handshakeError);
957     }
958     else {
959       rv = tlsSession_->tlsAccept(ver);
960     }
961 
962     if (rv == TLS_ERR_OK) {
963       // We're good, more or less.
964       // 1. Construct peerinfo
965       std::stringstream ss;
966       if (!hostname.empty()) {
967         ss << hostname << " (";
968       }
969       auto peerEndpoint = getPeerInfo();
970       ss << peerEndpoint.addr << ":" << peerEndpoint.port;
971       if (!hostname.empty()) {
972         ss << ")";
973       }
974 
975       std::string tlsVersion;
976       switch (ver) {
977       case TLS_PROTO_TLS11:
978         tlsVersion = A2_V_TLS11;
979         break;
980       case TLS_PROTO_TLS12:
981         tlsVersion = A2_V_TLS12;
982         break;
983       case TLS_PROTO_TLS13:
984         tlsVersion = A2_V_TLS13;
985         break;
986       default:
987         assert(0);
988         abort();
989       }
990 
991       auto peerInfo = ss.str();
992 
993       A2_LOG_DEBUG(fmt("Securely connected to %s with %s", peerInfo.c_str(),
994                        tlsVersion.c_str()));
995 
996       // 2. We're connected now!
997       secure_ = A2_TLS_CONNECTED;
998       return true;
999     }
1000 
1001     if (rv == TLS_ERR_WOULDBLOCK) {
1002       // We're not done yet...
1003       if (tlsSession_->checkDirection() == TLS_WANT_READ) {
1004         // ... but read buffers are empty.
1005         wantRead_ = true;
1006       }
1007       else {
1008         // ... but write buffers are full.
1009         wantWrite_ = true;
1010       }
1011       // Returning false (instead of true==success or throwing) will cause this
1012       // function to be called again once buffering is dealt with
1013       return false;
1014     }
1015 
1016     if (rv == TLS_ERR_ERROR) {
1017       // Damn those error.
1018       throw DL_ABORT_EX(fmt("SSL/TLS handshake failure: %s",
1019                             handshakeError.empty()
1020                                 ? tlsSession_->getLastErrorString().c_str()
1021                                 : handshakeError.c_str()));
1022     }
1023 
1024     // Some implementation passed back an invalid result.
1025     throw DL_ABORT_EX(fmt(EX_SSL_INIT_FAILURE,
1026                           "Invalid connect state (this is a bug in the TLS "
1027                           "backend!)"));
1028   }
1029 
1030   // We should never get here, i.e. all possible states should have been handled
1031   // and returned from a branch before! Getting here is a bug, of course!
1032   throw DL_ABORT_EX(fmt(EX_SSL_INIT_FAILURE, "Invalid state (this is a bug!)"));
1033 }
1034 
1035 #endif // ENABLE_SSL
1036 
1037 #ifdef HAVE_LIBSSH2
1038 
sshHandshake(const std::string & hashType,const std::string & digest)1039 bool SocketCore::sshHandshake(const std::string& hashType,
1040                               const std::string& digest)
1041 {
1042   wantRead_ = false;
1043   wantWrite_ = false;
1044 
1045   if (!sshSession_) {
1046     sshSession_ = make_unique<SSHSession>();
1047     if (sshSession_->init(sockfd_) == SSH_ERR_ERROR) {
1048       throw DL_ABORT_EX("Could not create SSH session");
1049     }
1050   }
1051   auto rv = sshSession_->handshake();
1052   if (rv == SSH_ERR_WOULDBLOCK) {
1053     sshCheckDirection();
1054     return false;
1055   }
1056   if (rv == SSH_ERR_ERROR) {
1057     throw DL_ABORT_EX(fmt("SSH handshake failure: %s",
1058                           sshSession_->getLastErrorString().c_str()));
1059   }
1060   if (!hashType.empty()) {
1061     auto actualDigest = sshSession_->hostkeyMessageDigest(hashType);
1062     if (actualDigest.empty()) {
1063       throw DL_ABORT_EX(fmt("Empty host key fingerprint from SSH layer: "
1064                             "perhaps hash type %s is not supported?",
1065                             hashType.c_str()));
1066     }
1067     if (digest != actualDigest) {
1068       throw DL_ABORT_EX(fmt("Unexpected SSH host key: expected %s, actual %s",
1069                             util::toHex(digest).c_str(),
1070                             util::toHex(actualDigest).c_str()));
1071     }
1072   }
1073   return true;
1074 }
1075 
sshAuthPassword(const std::string & user,const std::string & password)1076 bool SocketCore::sshAuthPassword(const std::string& user,
1077                                  const std::string& password)
1078 {
1079   assert(sshSession_);
1080 
1081   wantRead_ = false;
1082   wantWrite_ = false;
1083 
1084   auto rv = sshSession_->authPassword(user, password);
1085   if (rv == SSH_ERR_WOULDBLOCK) {
1086     sshCheckDirection();
1087     return false;
1088   }
1089   if (rv == SSH_ERR_ERROR) {
1090     throw DL_ABORT_EX(fmt("SSH authentication failure: %s",
1091                           sshSession_->getLastErrorString().c_str()));
1092   }
1093   return true;
1094 }
1095 
sshSFTPOpen(const std::string & path)1096 bool SocketCore::sshSFTPOpen(const std::string& path)
1097 {
1098   assert(sshSession_);
1099 
1100   wantRead_ = false;
1101   wantWrite_ = false;
1102 
1103   auto rv = sshSession_->sftpOpen(path);
1104   if (rv == SSH_ERR_WOULDBLOCK) {
1105     sshCheckDirection();
1106     return false;
1107   }
1108   if (rv == SSH_ERR_ERROR) {
1109     throw DL_ABORT_EX(fmt("SSH opening SFTP path %s failed: %s", path.c_str(),
1110                           sshSession_->getLastErrorString().c_str()));
1111   }
1112   return true;
1113 }
1114 
sshSFTPClose()1115 bool SocketCore::sshSFTPClose()
1116 {
1117   assert(sshSession_);
1118 
1119   wantRead_ = false;
1120   wantWrite_ = false;
1121 
1122   auto rv = sshSession_->sftpClose();
1123   if (rv == SSH_ERR_WOULDBLOCK) {
1124     sshCheckDirection();
1125     return false;
1126   }
1127   if (rv == SSH_ERR_ERROR) {
1128     throw DL_ABORT_EX(fmt("SSH closing SFTP failed: %s",
1129                           sshSession_->getLastErrorString().c_str()));
1130   }
1131   return true;
1132 }
1133 
sshSFTPStat(int64_t & totalLength,time_t & mtime,const std::string & path)1134 bool SocketCore::sshSFTPStat(int64_t& totalLength, time_t& mtime,
1135                              const std::string& path)
1136 {
1137   assert(sshSession_);
1138 
1139   wantRead_ = false;
1140   wantWrite_ = false;
1141 
1142   auto rv = sshSession_->sftpStat(totalLength, mtime);
1143   if (rv == SSH_ERR_WOULDBLOCK) {
1144     sshCheckDirection();
1145     return false;
1146   }
1147   if (rv == SSH_ERR_ERROR) {
1148     throw DL_ABORT_EX(fmt("SSH stat SFTP path %s filed: %s", path.c_str(),
1149                           sshSession_->getLastErrorString().c_str()));
1150   }
1151   return true;
1152 }
1153 
sshSFTPSeek(int64_t pos)1154 void SocketCore::sshSFTPSeek(int64_t pos)
1155 {
1156   assert(sshSession_);
1157 
1158   sshSession_->sftpSeek(pos);
1159 }
1160 
sshGracefulShutdown()1161 bool SocketCore::sshGracefulShutdown()
1162 {
1163   assert(sshSession_);
1164   auto rv = sshSession_->gracefulShutdown();
1165   if (rv == SSH_ERR_WOULDBLOCK) {
1166     sshCheckDirection();
1167     return false;
1168   }
1169   if (rv == SSH_ERR_ERROR) {
1170     throw DL_ABORT_EX(fmt("SSH graceful shutdown failed: %s",
1171                           sshSession_->getLastErrorString().c_str()));
1172   }
1173   return true;
1174 }
1175 
sshCheckDirection()1176 void SocketCore::sshCheckDirection()
1177 {
1178   if (sshSession_->checkDirection() == SSH_WANT_READ) {
1179     wantRead_ = true;
1180   }
1181   else {
1182     wantWrite_ = true;
1183   }
1184 }
1185 
1186 #endif // HAVE_LIBSSH2
1187 
writeData(const void * data,size_t len,const std::string & host,uint16_t port)1188 ssize_t SocketCore::writeData(const void* data, size_t len,
1189                               const std::string& host, uint16_t port)
1190 {
1191   wantRead_ = false;
1192   wantWrite_ = false;
1193 
1194   struct addrinfo* res;
1195   int s;
1196   s = callGetaddrinfo(&res, host.c_str(), util::uitos(port).c_str(),
1197                       protocolFamily_, sockType_, 0, 0);
1198   if (s) {
1199     throw DL_ABORT_EX(fmt(EX_SOCKET_SEND, gai_strerror(s)));
1200   }
1201   std::unique_ptr<addrinfo, decltype(&freeaddrinfo)> resDeleter(res,
1202                                                                 freeaddrinfo);
1203   struct addrinfo* rp;
1204   ssize_t r = -1;
1205   int errNum = 0;
1206   for (rp = res; rp; rp = rp->ai_next) {
1207     // Cast for Windows sendto()
1208     while ((r = sendto(sockfd_, reinterpret_cast<const char*>(data), len, 0,
1209                        rp->ai_addr, rp->ai_addrlen)) == -1 &&
1210            A2_EINTR == SOCKET_ERRNO)
1211       ;
1212     errNum = SOCKET_ERRNO;
1213     if (r == static_cast<ssize_t>(len)) {
1214       break;
1215     }
1216     if (r == -1 && A2_WOULDBLOCK(errNum)) {
1217       wantWrite_ = true;
1218       r = 0;
1219       break;
1220     }
1221   }
1222   if (r == -1) {
1223     throw DL_ABORT_EX(fmt(EX_SOCKET_SEND, errorMsg(errNum).c_str()));
1224   }
1225   return r;
1226 }
1227 
readDataFrom(void * data,size_t len,Endpoint & sender)1228 ssize_t SocketCore::readDataFrom(void* data, size_t len, Endpoint& sender)
1229 {
1230   wantRead_ = false;
1231   wantWrite_ = false;
1232   sockaddr_union sockaddr;
1233   socklen_t sockaddrlen = sizeof(sockaddr);
1234   ssize_t r;
1235   // Cast for Windows recvfrom()
1236   while ((r = recvfrom(sockfd_, reinterpret_cast<char*>(data), len, 0,
1237                        &sockaddr.sa, &sockaddrlen)) == -1 &&
1238          A2_EINTR == SOCKET_ERRNO)
1239     ;
1240   int errNum = SOCKET_ERRNO;
1241   if (r == -1) {
1242     if (!A2_WOULDBLOCK(errNum)) {
1243       throw DL_RETRY_EX(fmt(EX_SOCKET_RECV, errorMsg(errNum).c_str()));
1244     }
1245     wantRead_ = true;
1246     r = 0;
1247   }
1248   else {
1249     sender = util::getNumericNameInfo(&sockaddr.sa, sockaddrlen);
1250   }
1251 
1252   return r;
1253 }
1254 
getSocketError() const1255 std::string SocketCore::getSocketError() const
1256 {
1257   int error;
1258   socklen_t optlen = sizeof(error);
1259 
1260   if (getsockopt(sockfd_, SOL_SOCKET, SO_ERROR, (a2_sockopt_t)&error,
1261                  &optlen) == -1) {
1262     int errNum = SOCKET_ERRNO;
1263     throw DL_ABORT_EX(
1264         fmt("Failed to get socket error: %s", errorMsg(errNum).c_str()));
1265   }
1266   if (error != 0) {
1267     return errorMsg(error);
1268   }
1269   return "";
1270 }
1271 
wantRead() const1272 bool SocketCore::wantRead() const { return wantRead_; }
1273 
wantWrite() const1274 bool SocketCore::wantWrite() const { return wantWrite_; }
1275 
bindAddress(const std::string & iface)1276 void SocketCore::bindAddress(const std::string& iface)
1277 {
1278   auto bindAddrs = getInterfaceAddress(iface, protocolFamily_);
1279   if (bindAddrs.empty()) {
1280     throw DL_ABORT_EX(
1281         fmt(MSG_INTERFACE_NOT_FOUND, iface.c_str(), "not available"));
1282   }
1283   bindAddrs_.swap(bindAddrs);
1284   for (const auto& a : bindAddrs_) {
1285     char host[NI_MAXHOST];
1286     int s;
1287     s = getnameinfo(&a.su.sa, a.suLength, host, NI_MAXHOST, nullptr, 0,
1288                     NI_NUMERICHOST);
1289     if (s == 0) {
1290       A2_LOG_DEBUG(fmt("Sockets will bind to %s", host));
1291     }
1292   }
1293   bindAddrsList_.push_back(bindAddrs_);
1294   bindAddrsListIt_ = std::begin(bindAddrsList_);
1295 }
1296 
bindAllAddress(const std::string & ifaces)1297 void SocketCore::bindAllAddress(const std::string& ifaces)
1298 {
1299   std::vector<std::vector<SockAddr>> bindAddrsList;
1300   std::vector<std::string> ifaceList;
1301   util::split(ifaces.begin(), ifaces.end(), std::back_inserter(ifaceList), ',',
1302               true);
1303   if (ifaceList.empty()) {
1304     throw DL_ABORT_EX(
1305         "List of interfaces is empty, one or more interfaces is required");
1306   }
1307   for (auto& iface : ifaceList) {
1308     auto bindAddrs = getInterfaceAddress(iface, protocolFamily_);
1309     if (bindAddrs.empty()) {
1310       throw DL_ABORT_EX(
1311           fmt(MSG_INTERFACE_NOT_FOUND, iface.c_str(), "not available"));
1312     }
1313     bindAddrsList.push_back(bindAddrs);
1314     for (const auto& a : bindAddrs) {
1315       char host[NI_MAXHOST];
1316       int s;
1317       s = getnameinfo(&a.su.sa, a.suLength, host, NI_MAXHOST, nullptr, 0,
1318                       NI_NUMERICHOST);
1319       if (s == 0) {
1320         A2_LOG_DEBUG(fmt("Sockets will bind to %s", host));
1321       }
1322     }
1323   }
1324   bindAddrsList_.swap(bindAddrsList);
1325   bindAddrsListIt_ = bindAddrsList_.begin();
1326   bindAddrs_ = *bindAddrsListIt_;
1327 }
1328 
setSocketRecvBufferSize(int size)1329 void SocketCore::setSocketRecvBufferSize(int size)
1330 {
1331   socketRecvBufferSize_ = size;
1332 }
1333 
getSocketRecvBufferSize()1334 int SocketCore::getSocketRecvBufferSize() { return socketRecvBufferSize_; }
1335 
getRecvBufferedLength() const1336 size_t SocketCore::getRecvBufferedLength() const
1337 {
1338 #ifdef ENABLE_SSL
1339   if (!tlsSession_) {
1340     return 0;
1341   }
1342 
1343   return tlsSession_->getRecvBufferedLength();
1344 #else  // !ENABLE_SSL
1345   return 0;
1346 #endif // !ENABLE_SSL
1347 }
1348 
getInterfaceAddress(const std::string & iface,int family,int aiFlags)1349 std::vector<SockAddr> SocketCore::getInterfaceAddress(const std::string& iface,
1350                                                       int family, int aiFlags)
1351 {
1352   A2_LOG_DEBUG(fmt("Finding interface %s", iface.c_str()));
1353   std::vector<SockAddr> ifAddrs;
1354 #ifdef HAVE_GETIFADDRS
1355   // First find interface in interface addresses
1356   struct ifaddrs* ifaddr = nullptr;
1357   if (getifaddrs(&ifaddr) == -1) {
1358     int errNum = SOCKET_ERRNO;
1359     A2_LOG_INFO(
1360         fmt(MSG_INTERFACE_NOT_FOUND, iface.c_str(), errorMsg(errNum).c_str()));
1361   }
1362   else {
1363     std::unique_ptr<ifaddrs, decltype(&freeifaddrs)> ifaddrDeleter(ifaddr,
1364                                                                    freeifaddrs);
1365     for (ifaddrs* ifa = ifaddr; ifa; ifa = ifa->ifa_next) {
1366       if (!ifa->ifa_addr) {
1367         continue;
1368       }
1369       int iffamily = ifa->ifa_addr->sa_family;
1370       if (family == AF_UNSPEC) {
1371         if (iffamily != AF_INET && iffamily != AF_INET6) {
1372           continue;
1373         }
1374       }
1375       else if (family == AF_INET) {
1376         if (iffamily != AF_INET) {
1377           continue;
1378         }
1379       }
1380       else if (family == AF_INET6) {
1381         if (iffamily != AF_INET6) {
1382           continue;
1383         }
1384       }
1385       else {
1386         continue;
1387       }
1388       if (strcmp(iface.c_str(), ifa->ifa_name) == 0) {
1389         SockAddr soaddr;
1390         soaddr.suLength =
1391             iffamily == AF_INET ? sizeof(sockaddr_in) : sizeof(sockaddr_in6);
1392         memcpy(&soaddr.su, ifa->ifa_addr, soaddr.suLength);
1393         ifAddrs.push_back(soaddr);
1394       }
1395     }
1396   }
1397 #endif // HAVE_GETIFADDRS
1398   if (ifAddrs.empty()) {
1399     addrinfo* res;
1400     int s;
1401     s = callGetaddrinfo(&res, iface.c_str(), nullptr, family, SOCK_STREAM,
1402                         aiFlags, 0);
1403     if (s) {
1404       A2_LOG_INFO(fmt(MSG_INTERFACE_NOT_FOUND, iface.c_str(), gai_strerror(s)));
1405     }
1406     else {
1407       std::unique_ptr<addrinfo, decltype(&freeaddrinfo)> resDeleter(
1408           res, freeaddrinfo);
1409       addrinfo* rp;
1410       for (rp = res; rp; rp = rp->ai_next) {
1411         // Try to bind socket with this address. If it fails, the
1412         // address is not for this machine.
1413         try {
1414           SocketCore socket;
1415           socket.bind(rp->ai_addr, rp->ai_addrlen);
1416           SockAddr soaddr;
1417           memcpy(&soaddr.su, rp->ai_addr, rp->ai_addrlen);
1418           soaddr.suLength = rp->ai_addrlen;
1419           ifAddrs.push_back(soaddr);
1420         }
1421         catch (RecoverableException& e) {
1422           continue;
1423         }
1424       }
1425     }
1426   }
1427 
1428   return ifAddrs;
1429 }
1430 
1431 namespace {
1432 
1433 int defaultAIFlags = DEFAULT_AI_FLAGS;
1434 
getDefaultAIFlags()1435 int getDefaultAIFlags() { return defaultAIFlags; }
1436 
1437 } // namespace
1438 
setDefaultAIFlags(int flags)1439 void setDefaultAIFlags(int flags) { defaultAIFlags = flags; }
1440 
callGetaddrinfo(struct addrinfo ** resPtr,const char * host,const char * service,int family,int sockType,int flags,int protocol)1441 int callGetaddrinfo(struct addrinfo** resPtr, const char* host,
1442                     const char* service, int family, int sockType, int flags,
1443                     int protocol)
1444 {
1445   struct addrinfo hints;
1446   memset(&hints, 0, sizeof(hints));
1447   hints.ai_family = family;
1448   hints.ai_socktype = sockType;
1449   hints.ai_flags = getDefaultAIFlags();
1450   hints.ai_flags |= flags;
1451   hints.ai_protocol = protocol;
1452   return getaddrinfo(host, service, &hints, resPtr);
1453 }
1454 
inetNtop(int af,const void * src,char * dst,socklen_t size)1455 int inetNtop(int af, const void* src, char* dst, socklen_t size)
1456 {
1457   sockaddr_union su;
1458   memset(&su, 0, sizeof(su));
1459   if (af == AF_INET) {
1460     su.in.sin_family = AF_INET;
1461 #ifdef HAVE_SOCKADDR_IN_SIN_LEN
1462     su.in.sin_len = sizeof(su.in);
1463 #endif // HAVE_SOCKADDR_IN_SIN_LEN
1464     memcpy(&su.in.sin_addr, src, sizeof(su.in.sin_addr));
1465     return getnameinfo(&su.sa, sizeof(su.in), dst, size, nullptr, 0,
1466                        NI_NUMERICHOST);
1467   }
1468   if (af == AF_INET6) {
1469     su.in6.sin6_family = AF_INET6;
1470 #ifdef HAVE_SOCKADDR_IN6_SIN6_LEN
1471     su.in6.sin6_len = sizeof(su.in6);
1472 #endif // HAVE_SOCKADDR_IN6_SIN6_LEN
1473     memcpy(&su.in6.sin6_addr, src, sizeof(su.in6.sin6_addr));
1474     return getnameinfo(&su.sa, sizeof(su.in6), dst, size, nullptr, 0,
1475                        NI_NUMERICHOST);
1476   }
1477   return EAI_FAMILY;
1478 }
1479 
inetPton(int af,const char * src,void * dst)1480 int inetPton(int af, const char* src, void* dst)
1481 {
1482   union {
1483     uint32_t ipv4_addr;
1484     unsigned char ipv6_addr[16];
1485   } binaddr;
1486   size_t len = net::getBinAddr(binaddr.ipv6_addr, src);
1487   if (af == AF_INET) {
1488     if (len != 4) {
1489       return -1;
1490     }
1491     in_addr* addr = reinterpret_cast<in_addr*>(dst);
1492     addr->s_addr = binaddr.ipv4_addr;
1493     return 0;
1494   }
1495   if (af == AF_INET6) {
1496     if (len != 16) {
1497       return -1;
1498     }
1499     in6_addr* addr = reinterpret_cast<in6_addr*>(dst);
1500     memcpy(addr->s6_addr, binaddr.ipv6_addr, sizeof(addr->s6_addr));
1501     return 0;
1502   }
1503   return -1;
1504 }
1505 
1506 namespace net {
1507 
getBinAddr(void * dest,const std::string & ip)1508 size_t getBinAddr(void* dest, const std::string& ip)
1509 {
1510   size_t len = 0;
1511   addrinfo* res;
1512   if (callGetaddrinfo(&res, ip.c_str(), nullptr, AF_UNSPEC, 0, AI_NUMERICHOST,
1513                       0) != 0) {
1514     return len;
1515   }
1516   std::unique_ptr<addrinfo, decltype(&freeaddrinfo)> resDeleter(res,
1517                                                                 freeaddrinfo);
1518   for (addrinfo* rp = res; rp; rp = rp->ai_next) {
1519     sockaddr_union su;
1520     memcpy(&su, rp->ai_addr, rp->ai_addrlen);
1521     if (rp->ai_family == AF_INET) {
1522       len = sizeof(in_addr);
1523       memcpy(dest, &(su.in.sin_addr), len);
1524       break;
1525     }
1526     else if (rp->ai_family == AF_INET6) {
1527       len = sizeof(in6_addr);
1528       memcpy(dest, &(su.in6.sin6_addr), len);
1529       break;
1530     }
1531   }
1532   return len;
1533 }
1534 
verifyHostname(const std::string & hostname,const std::vector<std::string> & dnsNames,const std::vector<std::string> & ipAddrs,const std::string & commonName)1535 bool verifyHostname(const std::string& hostname,
1536                     const std::vector<std::string>& dnsNames,
1537                     const std::vector<std::string>& ipAddrs,
1538                     const std::string& commonName)
1539 {
1540   if (util::isNumericHost(hostname)) {
1541     if (ipAddrs.empty()) {
1542       return commonName == hostname;
1543     }
1544     // We need max 16 bytes to store IPv6 address.
1545     unsigned char binAddr[16];
1546     size_t addrLen = getBinAddr(binAddr, hostname);
1547     if (addrLen == 0) {
1548       return false;
1549     }
1550     for (auto& ipAddr : ipAddrs) {
1551       if (addrLen == ipAddr.size() &&
1552           memcmp(binAddr, ipAddr.c_str(), addrLen) == 0) {
1553         return true;
1554       }
1555     }
1556     return false;
1557   }
1558 
1559   if (dnsNames.empty()) {
1560     return util::tlsHostnameMatch(commonName, hostname);
1561   }
1562   for (auto& dnsName : dnsNames) {
1563     if (util::tlsHostnameMatch(dnsName, hostname)) {
1564       return true;
1565     }
1566   }
1567   return false;
1568 }
1569 
namespace {
// Results of the last checkAddrconfig(); until it runs (or when it cannot
// inspect the host), both address families are assumed to be configured.
bool ipv4AddrConfigured{true};
bool ipv6AddrConfigured{true};
} // namespace
1574 
#ifdef __MINGW32__
namespace {
// IPv4 link-local (APIPA) bounds in host byte order; checkAddrconfig()
// compares ntohl()'d addresses against them.
constexpr uint32_t APIPA_IPV4_BEGIN =
    (169u << 24) | (254u << 16) | 1u; // 169.254.0.1
constexpr uint32_t APIPA_IPV4_END =
    (169u << 24) | (254u << 16) | 0xffffu; // 169.254.255.255
} // namespace
#endif // __MINGW32__
1581 
checkAddrconfig()1582 void checkAddrconfig()
1583 {
1584 #ifdef HAVE_IPHLPAPI_H
1585   A2_LOG_INFO("Checking configured addresses");
1586   ULONG bufsize = 15_k;
1587   ULONG retval = 0;
1588   IP_ADAPTER_ADDRESSES* buf = 0;
1589   int numTry = 0;
1590   const int MAX_TRY = 3;
1591   do {
1592     buf = reinterpret_cast<IP_ADAPTER_ADDRESSES*>(malloc(bufsize));
1593     retval = GetAdaptersAddresses(AF_UNSPEC, 0, 0, buf, &bufsize);
1594     if (retval != ERROR_BUFFER_OVERFLOW) {
1595       break;
1596     }
1597     free(buf);
1598     buf = 0;
1599   } while (retval == ERROR_BUFFER_OVERFLOW && numTry < MAX_TRY);
1600   if (retval != NO_ERROR) {
1601     A2_LOG_INFO("GetAdaptersAddresses failed. Assume both IPv4 and IPv6 "
1602                 " addresses are configured.");
1603     return;
1604   }
1605   ipv4AddrConfigured = false;
1606   ipv6AddrConfigured = false;
1607   char host[NI_MAXHOST];
1608   sockaddr_union ad;
1609   int rv;
1610   for (IP_ADAPTER_ADDRESSES* p = buf; p; p = p->Next) {
1611     if (p->IfType == IF_TYPE_TUNNEL) {
1612       // Skip tunnel interface because Windows7 automatically setup
1613       // this for IPv6.
1614       continue;
1615     }
1616     PIP_ADAPTER_UNICAST_ADDRESS ucaddr = p->FirstUnicastAddress;
1617     if (!ucaddr) {
1618       continue;
1619     }
1620     for (PIP_ADAPTER_UNICAST_ADDRESS i = ucaddr; i; i = i->Next) {
1621       bool found = false;
1622       switch (i->Address.iSockaddrLength) {
1623       case sizeof(sockaddr_in): {
1624         memcpy(&ad.storage, i->Address.lpSockaddr, i->Address.iSockaddrLength);
1625         uint32_t haddr = ntohl(ad.in.sin_addr.s_addr);
1626         if (haddr != INADDR_LOOPBACK &&
1627             (haddr < APIPA_IPV4_BEGIN || APIPA_IPV4_END <= haddr)) {
1628           ipv4AddrConfigured = true;
1629           found = true;
1630         }
1631         break;
1632       }
1633       case sizeof(sockaddr_in6):
1634         memcpy(&ad.storage, i->Address.lpSockaddr, i->Address.iSockaddrLength);
1635         if (!IN6_IS_ADDR_LOOPBACK(&ad.in6.sin6_addr) &&
1636             !IN6_IS_ADDR_LINKLOCAL(&ad.in6.sin6_addr)) {
1637           ipv6AddrConfigured = true;
1638           found = true;
1639         }
1640         break;
1641       }
1642       rv = getnameinfo(i->Address.lpSockaddr, i->Address.iSockaddrLength, host,
1643                        NI_MAXHOST, 0, 0, NI_NUMERICHOST);
1644       if (rv == 0) {
1645         if (found) {
1646           A2_LOG_INFO(fmt("Found configured address: %s", host));
1647         }
1648         else {
1649           A2_LOG_INFO(fmt("Not considered: %s", host));
1650         }
1651       }
1652     }
1653   }
1654   free(buf);
1655 
1656   A2_LOG_INFO(fmt("IPv4 configured=%d, IPv6 configured=%d", ipv4AddrConfigured,
1657                   ipv6AddrConfigured));
1658 #elif defined(HAVE_GETIFADDRS)
1659   A2_LOG_INFO("Checking configured addresses");
1660   ipv4AddrConfigured = false;
1661   ipv6AddrConfigured = false;
1662   ifaddrs* ifaddr = nullptr;
1663   int rv;
1664   rv = getifaddrs(&ifaddr);
1665   if (rv == -1) {
1666     int errNum = SOCKET_ERRNO;
1667     A2_LOG_INFO(fmt("getifaddrs failed. Cause: %s", errorMsg(errNum).c_str()));
1668     return;
1669   }
1670   std::unique_ptr<ifaddrs, decltype(&freeifaddrs)> ifaddrDeleter(ifaddr,
1671                                                                  freeifaddrs);
1672   char host[NI_MAXHOST];
1673   sockaddr_union ad;
1674   for (ifaddrs* ifa = ifaddr; ifa; ifa = ifa->ifa_next) {
1675     if (!ifa->ifa_addr) {
1676       continue;
1677     }
1678     bool found = false;
1679     size_t addrlen = 0;
1680     switch (ifa->ifa_addr->sa_family) {
1681     case AF_INET: {
1682       addrlen = sizeof(sockaddr_in);
1683       memcpy(&ad.storage, ifa->ifa_addr, addrlen);
1684       if (ad.in.sin_addr.s_addr != htonl(INADDR_LOOPBACK)) {
1685         ipv4AddrConfigured = true;
1686         found = true;
1687       }
1688       break;
1689     }
1690     case AF_INET6: {
1691       addrlen = sizeof(sockaddr_in6);
1692       memcpy(&ad.storage, ifa->ifa_addr, addrlen);
1693       if (!IN6_IS_ADDR_LOOPBACK(&ad.in6.sin6_addr) &&
1694           !IN6_IS_ADDR_LINKLOCAL(&ad.in6.sin6_addr)) {
1695         ipv6AddrConfigured = true;
1696         found = true;
1697       }
1698       break;
1699     }
1700     default:
1701       continue;
1702     }
1703     rv = getnameinfo(ifa->ifa_addr, addrlen, host, NI_MAXHOST, nullptr, 0,
1704                      NI_NUMERICHOST);
1705     if (rv == 0) {
1706       if (found) {
1707         A2_LOG_INFO(fmt("Found configured address: %s", host));
1708       }
1709       else {
1710         A2_LOG_INFO(fmt("Not considered: %s", host));
1711       }
1712     }
1713   }
1714   A2_LOG_INFO(fmt("IPv4 configured=%d, IPv6 configured=%d", ipv4AddrConfigured,
1715                   ipv6AddrConfigured));
1716 #else  // !HAVE_GETIFADDRS
1717   A2_LOG_INFO("getifaddrs is not available. Assume IPv4 and IPv6 addresses"
1718               " are configured.");
1719 #endif // !HAVE_GETIFADDRS
1720 }
1721 
getIPv4AddrConfigured()1722 bool getIPv4AddrConfigured() { return ipv4AddrConfigured; }
1723 
getIPv6AddrConfigured()1724 bool getIPv6AddrConfigured() { return ipv6AddrConfigured; }
1725 
1726 } // namespace net
1727 
1728 } // namespace aria2
1729