1 /*!
2  *  Copyright (c) 2014-2019 by Contributors
3  * \file socket.h
4  * \brief this file aims to provide a wrapper of sockets
5  * \author Tianqi Chen
6  */
7 #ifndef RABIT_INTERNAL_SOCKET_H_
8 #define RABIT_INTERNAL_SOCKET_H_
9 #if defined(_WIN32)
10 #include <winsock2.h>
11 #include <ws2tcpip.h>
12 
13 #ifdef _MSC_VER
14 #pragma comment(lib, "Ws2_32.lib")
15 #endif  // _MSC_VER
16 
17 #else
18 
19 #include <fcntl.h>
20 #include <netdb.h>
21 #include <cerrno>
22 #include <unistd.h>
23 #include <arpa/inet.h>
24 #include <netinet/in.h>
25 #include <sys/socket.h>
26 #include <sys/ioctl.h>
27 
28 #if defined(__sun) || defined(sun)
29 #include <sys/sockio.h>
30 #endif  // defined(__sun) || defined(sun)
31 
32 #endif  // defined(_WIN32)
33 
#include <chrono>
#include <cstring>
#include <limits>
#include <string>
#include <unordered_map>
#include <vector>
#include "utils.h"
40 
#if defined(_WIN32) && !defined(__MINGW32__)
// MSVC does not provide ssize_t; Winsock send/recv return int anyway.
typedef int ssize_t;
#endif  // defined(_WIN32) && !defined(__MINGW32__)

#if defined(_WIN32)
// Winsock send/recv take the buffer length as an int.
using sock_size_t = int;

#else

// <poll.h> is the POSIX location of poll()/struct pollfd; <sys/poll.h> is a
// non-standard alias that some C libraries (e.g. musl) warn about.
#include <poll.h>
// On POSIX systems a socket is an ordinary file descriptor.
using SOCKET = int;
using sock_size_t = size_t;  // NOLINT
#endif  // defined(_WIN32)
54 
// IS_MINGW() evaluates to 1 when building with MinGW and to 0 otherwise.
// It is defined as a plain constant rather than `defined(__MINGW32__)`:
// a macro whose expansion produces the `defined` operator inside #if is
// undefined behaviour in C++ (GCC/Clang flag it with -Wexpansion-to-defined)
// and is not handled consistently across preprocessors.
#if defined(__MINGW32__)
#define IS_MINGW() 1
#else
#define IS_MINGW() 0
#endif  // defined(__MINGW32__)
56 
57 #if IS_MINGW()
MingWError()58 inline void MingWError() {
59   throw dmlc::Error("Distributed training on mingw is not supported.");
60 }
61 #endif  // IS_MINGW()
62 
63 #if IS_MINGW() && !defined(POLLRDNORM) && !defined(POLLRDBAND)
64 /*
65  * On later mingw versions poll should be supported (with bugs).  See:
66  * https://stackoverflow.com/a/60623080
67  *
68  * But right now the mingw distributed with R 3.6 doesn't support it.
69  * So we just give a warning and provide dummy implementation to get
70  * compilation passed.  Otherwise we will have to provide a stub for
71  * RABIT.
72  *
73  * Even on mingw version that has these structures and flags defined,
74  * functions like `send` and `listen` might have unresolved linkage to
75  * their implementation.  So supporting mingw is quite difficult at
76  * the time of writing.
77  */
78 #pragma message("Distributed training on mingw is not supported.")
79 typedef struct pollfd {
80   SOCKET fd;
81   short  events;
82   short  revents;
83 } WSAPOLLFD, *PWSAPOLLFD, *LPWSAPOLLFD;
84 
85 // POLLRDNORM | POLLRDBAND
86 #define POLLIN    (0x0100 | 0x0200)
87 #define POLLPRI    0x0400
88 // POLLWRNORM
89 #define POLLOUT    0x0010
90 
inet_ntop(int,const void *,char *,size_t)91 inline const char *inet_ntop(int, const void *, char *, size_t) {
92   MingWError();
93   return nullptr;
94 }
95 #endif  // IS_MINGW() && !defined(POLLRDNORM) && !defined(POLLRDBAND)
96 
97 namespace rabit {
98 namespace utils {
99 
// Sentinel stored in a socket handle that is not open.  (A namespace-scope
// constexpr variable already has internal linkage, so no `static` is needed.)
constexpr int kInvalidSocket{-1};
101 
/*!
 * \brief portable wrapper around poll() / WSAPoll()
 * \tparam PollFD platform poll descriptor type (pollfd / WSAPOLLFD)
 * \param pfd array of nfds descriptors to wait on; revents is filled in
 * \param nfds number of entries in pfd
 * \param timeout maximum time to wait.  A negative value waits until an
 *        event occurs; durations too large for an int millisecond count
 *        are clamped to the largest representable timeout.
 * \return result of the underlying call: number of ready descriptors,
 *         0 on timeout, negative on error.  On MinGW, MingWError() is
 *         raised instead.
 */
template <typename PollFD>
int PollImpl(PollFD *pfd, int nfds, std::chrono::seconds timeout) {
  // poll/WSAPoll take an int millisecond timeout.  Converting through
  // std::chrono::milliseconds overflows the 64-bit rep for huge durations,
  // and narrowing to int could wrap to an arbitrary (even negative, i.e.
  // infinite) value, so saturate explicitly.
  constexpr int kMaxTimeoutMs = std::numeric_limits<int>::max();
  const auto secs = timeout.count();
  int timeout_ms = -1;  // negative: block until an event occurs
  if (secs >= 0) {
    timeout_ms = secs > kMaxTimeoutMs / 1000
                     ? kMaxTimeoutMs
                     : static_cast<int>(secs * 1000);
  }
#if defined(_WIN32)

#if IS_MINGW()
  (void)timeout_ms;
  MingWError();
  return -1;
#else
  return WSAPoll(pfd, nfds, timeout_ms);
#endif  // IS_MINGW()

#else
  return poll(pfd, nfds, timeout_ms);
#endif  // defined(_WIN32)
}
117 
118 /*! \brief data structure for network address */
119 struct SockAddr {
120   sockaddr_in addr;
121   // constructor
122   SockAddr() = default;
SockAddrSockAddr123   SockAddr(const char *url, int port) {
124     this->Set(url, port);
125   }
GetHostNameSockAddr126   inline static std::string GetHostName() {
127     std::string buf; buf.resize(256);
128 #if !IS_MINGW()
129     utils::Check(gethostname(&buf[0], 256) != -1, "fail to get host name");
130 #endif  // IS_MINGW()
131     return std::string(buf.c_str());
132   }
133   /*!
134    * \brief set the address
135    * \param url the url of the address
136    * \param port the port of address
137    */
SetSockAddr138   inline void Set(const char *host, int port) {
139 #if !IS_MINGW()
140     addrinfo hints;
141     memset(&hints, 0, sizeof(hints));
142     hints.ai_family = AF_INET;
143     hints.ai_protocol = SOCK_STREAM;
144     addrinfo *res = nullptr;
145     int sig = getaddrinfo(host, nullptr, &hints, &res);
146     Check(sig == 0 && res != nullptr, "cannot obtain address of %s", host);
147     Check(res->ai_family == AF_INET, "Does not support IPv6");
148     memcpy(&addr, res->ai_addr, res->ai_addrlen);
149     addr.sin_port = htons(port);
150     freeaddrinfo(res);
151 #endif  // !IS_MINGW()
152   }
153   /*! \brief return port of the address*/
PortSockAddr154   inline int Port() const {
155     return ntohs(addr.sin_port);
156   }
157   /*! \return a string representation of the address */
AddrStrSockAddr158   inline std::string AddrStr() const {
159     std::string buf; buf.resize(256);
160 #ifdef _WIN32
161     const char *s = inet_ntop(AF_INET, (PVOID)&addr.sin_addr,
162                     &buf[0], buf.length());
163 #else
164     const char *s = inet_ntop(AF_INET, &addr.sin_addr,
165                               &buf[0], buf.length());
166 #endif  // _WIN32
167     Assert(s != nullptr, "cannot decode address");
168     return std::string(s);
169   }
170 };
171 
172 /*!
173  * \brief base class containing common operations of TCP and UDP sockets
174  */
175 class Socket {
176  public:
177   /*! \brief the file descriptor of socket */
178   SOCKET sockfd;
179   // default conversion to int
SOCKET()180   operator SOCKET() const {  // NOLINT
181     return sockfd;
182   }
183   /*!
184    * \return last error of socket operation
185    */
GetLastError()186   inline static int GetLastError() {
187 #ifdef _WIN32
188 
189 #if IS_MINGW()
190     MingWError();
191     return -1;
192 #else
193     return WSAGetLastError();
194 #endif  // IS_MINGW()
195 
196 #else
197     return errno;
198 #endif  // _WIN32
199   }
200   /*! \return whether last error was would block */
LastErrorWouldBlock()201   inline static bool LastErrorWouldBlock() {
202     int errsv = GetLastError();
203 #ifdef _WIN32
204     return errsv == WSAEWOULDBLOCK;
205 #else
206     return errsv == EAGAIN || errsv == EWOULDBLOCK;
207 #endif  // _WIN32
208   }
209   /*!
210    * \brief start up the socket module
211    *   call this before using the sockets
212    */
Startup()213   inline static void Startup() {
214 #ifdef _WIN32
215 #if !IS_MINGW()
216     WSADATA wsa_data;
217     if (WSAStartup(MAKEWORD(2, 2), &wsa_data) == -1) {
218       Socket::Error("Startup");
219     }
220     if (LOBYTE(wsa_data.wVersion) != 2 || HIBYTE(wsa_data.wVersion) != 2) {
221     WSACleanup();
222     utils::Error("Could not find a usable version of Winsock.dll\n");
223     }
224 #endif  // !IS_MINGW()
225 #endif  // _WIN32
226   }
227   /*!
228    * \brief shutdown the socket module after use, all sockets need to be closed
229    */
Finalize()230   inline static void Finalize() {
231 #ifdef _WIN32
232 #if !IS_MINGW()
233     WSACleanup();
234 #endif  // !IS_MINGW()
235 #endif  // _WIN32
236   }
237   /*!
238    * \brief set this socket to use non-blocking mode
239    * \param non_block whether set it to be non-block, if it is false
240    *        it will set it back to block mode
241    */
SetNonBlock(bool non_block)242   inline void SetNonBlock(bool non_block) {
243 #ifdef _WIN32
244 #if !IS_MINGW()
245     u_long mode = non_block ? 1 : 0;
246     if (ioctlsocket(sockfd, FIONBIO, &mode) != NO_ERROR) {
247       Socket::Error("SetNonBlock");
248     }
249 #endif  // !IS_MINGW()
250 #else
251     int flag = fcntl(sockfd, F_GETFL, 0);
252     if (flag == -1) {
253       Socket::Error("SetNonBlock-1");
254     }
255     if (non_block) {
256       flag |= O_NONBLOCK;
257     } else {
258       flag &= ~O_NONBLOCK;
259     }
260     if (fcntl(sockfd, F_SETFL, flag) == -1) {
261       Socket::Error("SetNonBlock-2");
262     }
263 #endif  // _WIN32
264   }
265   /*!
266    * \brief bind the socket to an address
267    * \param addr
268    */
Bind(const SockAddr & addr)269   inline void Bind(const SockAddr &addr) {
270 #if !IS_MINGW()
271     if (bind(sockfd, reinterpret_cast<const sockaddr*>(&addr.addr),
272              sizeof(addr.addr)) == -1) {
273       Socket::Error("Bind");
274     }
275 #endif  // !IS_MINGW()
276   }
277   /*!
278    * \brief try bind the socket to host, from start_port to end_port
279    * \param start_port starting port number to try
280    * \param end_port ending port number to try
281    * \return the port successfully bind to, return -1 if failed to bind any port
282    */
TryBindHost(int start_port,int end_port)283   inline int TryBindHost(int start_port, int end_port) {
284     // TODO(tqchen) add prefix check
285 #if !IS_MINGW()
286     for (int port = start_port; port < end_port; ++port) {
287       SockAddr addr("0.0.0.0", port);
288       if (bind(sockfd, reinterpret_cast<sockaddr*>(&addr.addr),
289                sizeof(addr.addr)) == 0) {
290         return port;
291       }
292 #if defined(_WIN32)
293       if (WSAGetLastError() != WSAEADDRINUSE) {
294         Socket::Error("TryBindHost");
295       }
296 #else
297       if (errno != EADDRINUSE) {
298         Socket::Error("TryBindHost");
299       }
300 #endif  // defined(_WIN32)
301     }
302 #endif  // !IS_MINGW()
303     return -1;
304   }
305   /*! \brief get last error code if any */
GetSockError()306   inline int GetSockError() const {
307     int error = 0;
308     socklen_t len = sizeof(error);
309 #if !IS_MINGW()
310     if (getsockopt(sockfd, SOL_SOCKET, SO_ERROR,
311                    reinterpret_cast<char *>(&error), &len) != 0) {
312       Error("GetSockError");
313     }
314 #else
315     // undefined reference to `_imp__getsockopt@20'
316     MingWError();
317 #endif  // !IS_MINGW()
318     return error;
319   }
320   /*! \brief check if anything bad happens */
BadSocket()321   inline bool BadSocket() const {
322     if (IsClosed()) return true;
323     int err = GetSockError();
324     if (err == EBADF || err == EINTR) return true;
325     return false;
326   }
327   /*! \brief check if socket is already closed */
IsClosed()328   inline bool IsClosed() const {
329     return sockfd == kInvalidSocket;
330   }
331   /*! \brief close the socket */
Close()332   inline void Close() {
333     if (sockfd != kInvalidSocket) {
334 #ifdef _WIN32
335 #if !IS_MINGW()
336       closesocket(sockfd);
337 #endif  // !IS_MINGW()
338 #else
339       close(sockfd);
340 #endif
341       sockfd = kInvalidSocket;
342     } else {
343       Error("Socket::Close double close the socket or close without create");
344     }
345   }
346   // report an socket error
Error(const char * msg)347   inline static void Error(const char *msg) {
348     int errsv = GetLastError();
349 #ifdef _WIN32
350     utils::Error("Socket %s Error:WSAError-code=%d", msg, errsv);
351 #else
352     utils::Error("Socket %s Error:%s", msg, strerror(errsv));
353 #endif
354   }
355 
356  protected:
Socket(SOCKET sockfd)357   explicit Socket(SOCKET sockfd) : sockfd(sockfd) {
358   }
359 };
360 
361 /*!
362  * \brief a wrapper of TCP socket that hopefully be cross platform
363  */
364 class TCPSocket : public Socket{
365  public:
366   // constructor
TCPSocket()367   TCPSocket() : Socket(kInvalidSocket) {
368   }
TCPSocket(SOCKET sockfd)369   explicit TCPSocket(SOCKET sockfd) : Socket(sockfd) {
370   }
371   /*!
372    * \brief enable/disable TCP keepalive
373    * \param keepalive whether to set the keep alive option on
374    */
SetKeepAlive(bool keepalive)375   void SetKeepAlive(bool keepalive) {
376 #if !IS_MINGW()
377     int opt = static_cast<int>(keepalive);
378     if (setsockopt(sockfd, SOL_SOCKET, SO_KEEPALIVE,
379                    reinterpret_cast<char*>(&opt), sizeof(opt)) < 0) {
380       Socket::Error("SetKeepAlive");
381     }
382 #endif  // !IS_MINGW()
383   }
384   inline void SetLinger(int timeout = 0) {
385 #if !IS_MINGW()
386     struct linger sl;
387     sl.l_onoff = 1;    /* non-zero value enables linger option in kernel */
388     sl.l_linger = timeout;    /* timeout interval in seconds */
389     if (setsockopt(sockfd, SOL_SOCKET, SO_LINGER, reinterpret_cast<char*>(&sl), sizeof(sl)) == -1) {
390       Socket::Error("SO_LINGER");
391     }
392 #endif  // !IS_MINGW()
393   }
394   /*!
395    * \brief create the socket, call this before using socket
396    * \param af domain
397    */
398   inline void Create(int af = PF_INET) {
399 #if !IS_MINGW()
400     sockfd = socket(PF_INET, SOCK_STREAM, 0);
401     if (sockfd == kInvalidSocket) {
402       Socket::Error("Create");
403     }
404 #endif  // !IS_MINGW()
405   }
406   /*!
407    * \brief perform listen of the socket
408    * \param backlog backlog parameter
409    */
410   inline void Listen(int backlog = 16) {
411 #if !IS_MINGW()
412     listen(sockfd, backlog);
413 #endif  // !IS_MINGW()
414   }
415   /*! \brief get a new connection */
Accept()416   TCPSocket Accept() {
417 #if !IS_MINGW()
418     SOCKET newfd = accept(sockfd, nullptr, nullptr);
419     if (newfd == kInvalidSocket) {
420       Socket::Error("Accept");
421     }
422     return TCPSocket(newfd);
423 #else
424     return TCPSocket();
425 #endif // !IS_MINGW()
426   }
427   /*!
428    * \brief decide whether the socket is at OOB mark
429    * \return 1 if at mark, 0 if not, -1 if an error occured
430    */
AtMark()431   inline int AtMark() const {
432 #if !IS_MINGW()
433 
434 #ifdef _WIN32
435     unsigned long atmark;  // NOLINT(*)
436     if (ioctlsocket(sockfd, SIOCATMARK, &atmark) != NO_ERROR) return -1;
437 #else
438     int atmark;
439     if (ioctl(sockfd, SIOCATMARK, &atmark) == -1) return -1;
440 #endif  // _WIN32
441 
442     return static_cast<int>(atmark);
443 
444 #else
445     return -1;
446 #endif  // !IS_MINGW()
447   }
448   /*!
449    * \brief connect to an address
450    * \param addr the address to connect to
451    * \return whether connect is successful
452    */
Connect(const SockAddr & addr)453   inline bool Connect(const SockAddr &addr) {
454 #if !IS_MINGW()
455     return connect(sockfd, reinterpret_cast<const sockaddr*>(&addr.addr),
456                    sizeof(addr.addr)) == 0;
457 #else
458     return false;
459 #endif  // !IS_MINGW()
460   }
461   /*!
462    * \brief send data using the socket
463    * \param buf the pointer to the buffer
464    * \param len the size of the buffer
465    * \param flags extra flags
466    * \return size of data actually sent
467    *         return -1 if error occurs
468    */
469   inline ssize_t Send(const void *buf_, size_t len, int flag = 0) {
470     const char *buf = reinterpret_cast<const char*>(buf_);
471 #if !IS_MINGW()
472     return send(sockfd, buf, static_cast<sock_size_t>(len), flag);
473 #else
474     return 0;
475 #endif  // !IS_MINGW()
476   }
477   /*!
478    * \brief receive data using the socket
479    * \param buf_ the pointer to the buffer
480    * \param len the size of the buffer
481    * \param flags extra flags
482    * \return size of data actually received
483    *         return -1 if error occurs
484    */
485   inline ssize_t Recv(void *buf_, size_t len, int flags = 0) {
486     char *buf = reinterpret_cast<char*>(buf_);
487 #if !IS_MINGW()
488     return recv(sockfd, buf, static_cast<sock_size_t>(len), flags);
489 #else
490     return 0;
491 #endif  // !IS_MINGW()
492   }
493   /*!
494    * \brief peform block write that will attempt to send all data out
495    *    can still return smaller than request when error occurs
496    * \param buf the pointer to the buffer
497    * \param len the size of the buffer
498    * \return size of data actually sent
499    */
SendAll(const void * buf_,size_t len)500   inline size_t SendAll(const void *buf_, size_t len) {
501     const char *buf = reinterpret_cast<const char*>(buf_);
502     size_t ndone = 0;
503 #if !IS_MINGW()
504     while (ndone <  len) {
505       ssize_t ret = send(sockfd, buf, static_cast<ssize_t>(len - ndone), 0);
506       if (ret == -1) {
507         if (LastErrorWouldBlock()) return ndone;
508         Socket::Error("SendAll");
509       }
510       buf += ret;
511       ndone += ret;
512     }
513 #endif  // !IS_MINGW()
514     return ndone;
515   }
516   /*!
517    * \brief peforma block read that will attempt to read all data
518    *    can still return smaller than request when error occurs
519    * \param buf_ the buffer pointer
520    * \param len length of data to recv
521    * \return size of data actually sent
522    */
RecvAll(void * buf_,size_t len)523   inline size_t RecvAll(void *buf_, size_t len) {
524     char *buf = reinterpret_cast<char*>(buf_);
525     size_t ndone = 0;
526 #if !IS_MINGW()
527     while (ndone <  len) {
528       ssize_t ret = recv(sockfd, buf,
529                          static_cast<sock_size_t>(len - ndone), MSG_WAITALL);
530       if (ret == -1) {
531         if (LastErrorWouldBlock()) return ndone;
532         Socket::Error("RecvAll");
533       }
534       if (ret == 0) return ndone;
535       buf += ret;
536       ndone += ret;
537     }
538 #endif  // !IS_MINGW()
539     return ndone;
540   }
541   /*!
542    * \brief send a string over network
543    * \param str the string to be sent
544    */
SendStr(const std::string & str)545   inline void SendStr(const std::string &str) {
546     int len = static_cast<int>(str.length());
547     utils::Assert(this->SendAll(&len, sizeof(len)) == sizeof(len),
548                   "error during send SendStr");
549     if (len != 0) {
550       utils::Assert(this->SendAll(str.c_str(), str.length()) == str.length(),
551                     "error during send SendStr");
552     }
553   }
554   /*!
555    * \brief recv a string from network
556    * \param out_str the string to receive
557    */
RecvStr(std::string * out_str)558   inline void RecvStr(std::string *out_str) {
559     int len;
560     utils::Assert(this->RecvAll(&len, sizeof(len)) == sizeof(len),
561                   "error during send RecvStr");
562     out_str->resize(len);
563     if (len != 0) {
564       utils::Assert(this->RecvAll(&(*out_str)[0], len) == out_str->length(),
565                     "error during send SendStr");
566     }
567   }
568 };
569 
570 /*! \brief helper data structure to perform poll */
571 struct PollHelper {
572  public:
573   /*!
574    * \brief add file descriptor to watch for read
575    * \param fd file descriptor to be watched
576    */
WatchReadPollHelper577   inline void WatchRead(SOCKET fd) {
578     auto& pfd = fds[fd];
579     pfd.fd = fd;
580     pfd.events |= POLLIN;
581   }
582   /*!
583    * \brief add file descriptor to watch for write
584    * \param fd file descriptor to be watched
585    */
WatchWritePollHelper586   inline void WatchWrite(SOCKET fd) {
587     auto& pfd = fds[fd];
588     pfd.fd = fd;
589     pfd.events |= POLLOUT;
590   }
591   /*!
592    * \brief add file descriptor to watch for exception
593    * \param fd file descriptor to be watched
594    */
WatchExceptionPollHelper595   inline void WatchException(SOCKET fd) {
596     auto& pfd = fds[fd];
597     pfd.fd = fd;
598     pfd.events |= POLLPRI;
599   }
600   /*!
601    * \brief Check if the descriptor is ready for read
602    * \param fd file descriptor to check status
603    */
CheckReadPollHelper604   inline bool CheckRead(SOCKET fd) const {
605     const auto& pfd = fds.find(fd);
606     return pfd != fds.end() && ((pfd->second.events & POLLIN) != 0);
607   }
608   /*!
609    * \brief Check if the descriptor is ready for write
610    * \param fd file descriptor to check status
611    */
CheckWritePollHelper612   inline bool CheckWrite(SOCKET fd) const {
613     const auto& pfd = fds.find(fd);
614     return pfd != fds.end() && ((pfd->second.events & POLLOUT) != 0);
615   }
616 
617   /*!
618    * \brief perform poll on the set defined, read, write, exception
619    * \param timeout specify timeout in milliseconds(ms) if negative, means poll will block
620    * \return
621    */
PollPollHelper622   inline void Poll(std::chrono::seconds timeout) {  // NOLINT(*)
623     std::vector<pollfd> fdset;
624     fdset.reserve(fds.size());
625     for (auto kv : fds) {
626       fdset.push_back(kv.second);
627     }
628     int ret = PollImpl(fdset.data(), fdset.size(), timeout);
629     if (ret == 0) {
630       LOG(FATAL) << "Poll timeout";
631     } else if (ret < 0) {
632       Socket::Error("Poll");
633     } else {
634       for (auto& pfd : fdset) {
635         auto revents = pfd.revents & pfd.events;
636         if (!revents) {
637           fds.erase(pfd.fd);
638         } else {
639           fds[pfd.fd].events = revents;
640         }
641       }
642     }
643   }
644 
645   std::unordered_map<SOCKET, pollfd> fds;
646 };
647 }  // namespace utils
648 }  // namespace rabit
649 
650 #if IS_MINGW() && !defined(POLLRDNORM) && !defined(POLLRDBAND)
651 #undef POLLIN
652 #undef POLLPRI
653 #undef POLLOUT
654 #endif  // IS_MINGW()
655 
656 #endif  // RABIT_INTERNAL_SOCKET_H_
657