1 /*!
2 * Copyright (c) 2014-2019 by Contributors
3 * \file socket.h
4 * \brief this file aims to provide a wrapper of sockets
5 * \author Tianqi Chen
6 */
7 #ifndef RABIT_INTERNAL_SOCKET_H_
8 #define RABIT_INTERNAL_SOCKET_H_
9 #if defined(_WIN32)
10 #include <winsock2.h>
11 #include <ws2tcpip.h>
12
13 #ifdef _MSC_VER
14 #pragma comment(lib, "Ws2_32.lib")
15 #endif // _MSC_VER
16
17 #else
18
19 #include <fcntl.h>
20 #include <netdb.h>
21 #include <cerrno>
22 #include <unistd.h>
23 #include <arpa/inet.h>
24 #include <netinet/in.h>
25 #include <sys/socket.h>
26 #include <sys/ioctl.h>
27
28 #if defined(__sun) || defined(sun)
29 #include <sys/sockio.h>
30 #endif // defined(__sun) || defined(sun)
31
32 #endif // defined(_WIN32)
33
#include <chrono>
#include <cstring>
#include <limits>
#include <string>
#include <unordered_map>
#include <vector>
#include "utils.h"
40
// MSVC does not provide POSIX `ssize_t`; MinGW is excluded here, presumably
// because its headers already declare it.
#if defined(_WIN32) && !defined(__MINGW32__)
typedef int ssize_t;
#endif  // defined(_WIN32) && !defined(__MINGW32__)

// sock_size_t is the length type passed to send()/recv() on each platform;
// SOCKET is the descriptor type (already provided by winsock2.h on Windows).
#if defined(_WIN32)
using sock_size_t = int;

#else

// <poll.h> is the POSIX header; <sys/poll.h> is only a compatibility alias
// (musl, for example, warns when it is included).
#include <poll.h>
using SOCKET = int;
using sock_size_t = size_t;  // NOLINT
#endif  // defined(_WIN32)
54
// IS_MINGW() expands to a plain 0/1 literal, so `#if IS_MINGW()` is well
// defined. Generating the `defined` operator through macro expansion inside
// an #if is undefined behavior in C and C++ (GCC: -Wexpansion-to-defined,
// MSVC: C5105), even though many preprocessors tolerate it.
#if defined(__MINGW32__)
#define IS_MINGW() 1
#else
#define IS_MINGW() 0
#endif  // defined(__MINGW32__)
56
57 #if IS_MINGW()
MingWError()58 inline void MingWError() {
59 throw dmlc::Error("Distributed training on mingw is not supported.");
60 }
61 #endif // IS_MINGW()
62
63 #if IS_MINGW() && !defined(POLLRDNORM) && !defined(POLLRDBAND)
64 /*
65 * On later mingw versions poll should be supported (with bugs). See:
66 * https://stackoverflow.com/a/60623080
67 *
 * However, the MinGW toolchain distributed with R 3.6 does not support
 * it, so we just emit a warning and provide a dummy implementation to
 * let compilation pass. Otherwise we would have to provide a stub for
 * RABIT.
72 *
73 * Even on mingw version that has these structures and flags defined,
74 * functions like `send` and `listen` might have unresolved linkage to
75 * their implementation. So supporting mingw is quite difficult at
76 * the time of writing.
77 */
78 #pragma message("Distributed training on mingw is not supported.")
// Stand-in for the Winsock WSAPOLLFD structure and POLL* flags, compiled only
// on MinGW toolchains that do not define POLLRDNORM/POLLRDBAND. The member
// order and flag values are intended to mirror winsock2.h.
// NOTE(review): confirm against winsock2.h before relying on this shim for
// real WSAPoll calls.
typedef struct pollfd {
  SOCKET fd;      // descriptor being watched
  short events;   // requested events (bitmask of the POLL* flags below)
  short revents;  // events reported back after polling
} WSAPOLLFD, *PWSAPOLLFD, *LPWSAPOLLFD;

// POLLRDNORM | POLLRDBAND
#define POLLIN (0x0100 | 0x0200)
#define POLLPRI 0x0400
// POLLWRNORM
#define POLLOUT 0x0010
90
inet_ntop(int,const void *,char *,size_t)91 inline const char *inet_ntop(int, const void *, char *, size_t) {
92 MingWError();
93 return nullptr;
94 }
95 #endif // IS_MINGW() && !defined(POLLRDNORM) && !defined(POLLRDBAND)
96
97 namespace rabit {
98 namespace utils {
99
/*! \brief sentinel descriptor value meaning "no socket" */
constexpr int kInvalidSocket{-1};
101
/*!
 * \brief convert a poll timeout in seconds into the millisecond `int` taken
 *        by poll()/WSAPoll()
 * \param timeout timeout in seconds; a negative value means "wait forever"
 * \return -1 for a negative timeout; otherwise the timeout in milliseconds,
 *         saturated at INT_MAX instead of overflowing or wrapping negative
 *         (a wrapped negative value would silently turn into an infinite wait)
 */
inline int PollTimeoutMs(std::chrono::seconds timeout) {
  if (timeout.count() < 0) {
    return -1;
  }
  constexpr int kMaxMs = std::numeric_limits<int>::max();
  constexpr int kMaxSec = kMaxMs / 1000;
  if (timeout.count() > kMaxSec) {
    return kMaxMs;
  }
  return static_cast<int>(
      std::chrono::duration_cast<std::chrono::milliseconds>(timeout).count());
}

/*!
 * \brief platform dispatch for poll(): WSAPoll on Windows, poll elsewhere
 * \param pfd array of poll descriptors
 * \param nfds number of entries in pfd
 * \param timeout timeout in seconds; negative blocks indefinitely
 * \return the platform poll result (>0 ready count, 0 timeout, <0 error);
 *         -1 after MingWError() on MinGW
 */
template <typename PollFD>
int PollImpl(PollFD *pfd, int nfds, std::chrono::seconds timeout) {
#if defined(_WIN32)

#if IS_MINGW()
  MingWError();
  return -1;
#else
  return WSAPoll(pfd, nfds, PollTimeoutMs(timeout));
#endif  // IS_MINGW()

#else
  return poll(pfd, nfds, PollTimeoutMs(timeout));
#endif  // defined(_WIN32)
}
117
118 /*! \brief data structure for network address */
119 struct SockAddr {
120 sockaddr_in addr;
121 // constructor
122 SockAddr() = default;
SockAddrSockAddr123 SockAddr(const char *url, int port) {
124 this->Set(url, port);
125 }
GetHostNameSockAddr126 inline static std::string GetHostName() {
127 std::string buf; buf.resize(256);
128 #if !IS_MINGW()
129 utils::Check(gethostname(&buf[0], 256) != -1, "fail to get host name");
130 #endif // IS_MINGW()
131 return std::string(buf.c_str());
132 }
133 /*!
134 * \brief set the address
135 * \param url the url of the address
136 * \param port the port of address
137 */
SetSockAddr138 inline void Set(const char *host, int port) {
139 #if !IS_MINGW()
140 addrinfo hints;
141 memset(&hints, 0, sizeof(hints));
142 hints.ai_family = AF_INET;
143 hints.ai_protocol = SOCK_STREAM;
144 addrinfo *res = nullptr;
145 int sig = getaddrinfo(host, nullptr, &hints, &res);
146 Check(sig == 0 && res != nullptr, "cannot obtain address of %s", host);
147 Check(res->ai_family == AF_INET, "Does not support IPv6");
148 memcpy(&addr, res->ai_addr, res->ai_addrlen);
149 addr.sin_port = htons(port);
150 freeaddrinfo(res);
151 #endif // !IS_MINGW()
152 }
153 /*! \brief return port of the address*/
PortSockAddr154 inline int Port() const {
155 return ntohs(addr.sin_port);
156 }
157 /*! \return a string representation of the address */
AddrStrSockAddr158 inline std::string AddrStr() const {
159 std::string buf; buf.resize(256);
160 #ifdef _WIN32
161 const char *s = inet_ntop(AF_INET, (PVOID)&addr.sin_addr,
162 &buf[0], buf.length());
163 #else
164 const char *s = inet_ntop(AF_INET, &addr.sin_addr,
165 &buf[0], buf.length());
166 #endif // _WIN32
167 Assert(s != nullptr, "cannot decode address");
168 return std::string(s);
169 }
170 };
171
172 /*!
173 * \brief base class containing common operations of TCP and UDP sockets
174 */
175 class Socket {
176 public:
177 /*! \brief the file descriptor of socket */
178 SOCKET sockfd;
179 // default conversion to int
SOCKET()180 operator SOCKET() const { // NOLINT
181 return sockfd;
182 }
183 /*!
184 * \return last error of socket operation
185 */
GetLastError()186 inline static int GetLastError() {
187 #ifdef _WIN32
188
189 #if IS_MINGW()
190 MingWError();
191 return -1;
192 #else
193 return WSAGetLastError();
194 #endif // IS_MINGW()
195
196 #else
197 return errno;
198 #endif // _WIN32
199 }
200 /*! \return whether last error was would block */
LastErrorWouldBlock()201 inline static bool LastErrorWouldBlock() {
202 int errsv = GetLastError();
203 #ifdef _WIN32
204 return errsv == WSAEWOULDBLOCK;
205 #else
206 return errsv == EAGAIN || errsv == EWOULDBLOCK;
207 #endif // _WIN32
208 }
209 /*!
210 * \brief start up the socket module
211 * call this before using the sockets
212 */
Startup()213 inline static void Startup() {
214 #ifdef _WIN32
215 #if !IS_MINGW()
216 WSADATA wsa_data;
217 if (WSAStartup(MAKEWORD(2, 2), &wsa_data) == -1) {
218 Socket::Error("Startup");
219 }
220 if (LOBYTE(wsa_data.wVersion) != 2 || HIBYTE(wsa_data.wVersion) != 2) {
221 WSACleanup();
222 utils::Error("Could not find a usable version of Winsock.dll\n");
223 }
224 #endif // !IS_MINGW()
225 #endif // _WIN32
226 }
227 /*!
228 * \brief shutdown the socket module after use, all sockets need to be closed
229 */
Finalize()230 inline static void Finalize() {
231 #ifdef _WIN32
232 #if !IS_MINGW()
233 WSACleanup();
234 #endif // !IS_MINGW()
235 #endif // _WIN32
236 }
237 /*!
238 * \brief set this socket to use non-blocking mode
239 * \param non_block whether set it to be non-block, if it is false
240 * it will set it back to block mode
241 */
SetNonBlock(bool non_block)242 inline void SetNonBlock(bool non_block) {
243 #ifdef _WIN32
244 #if !IS_MINGW()
245 u_long mode = non_block ? 1 : 0;
246 if (ioctlsocket(sockfd, FIONBIO, &mode) != NO_ERROR) {
247 Socket::Error("SetNonBlock");
248 }
249 #endif // !IS_MINGW()
250 #else
251 int flag = fcntl(sockfd, F_GETFL, 0);
252 if (flag == -1) {
253 Socket::Error("SetNonBlock-1");
254 }
255 if (non_block) {
256 flag |= O_NONBLOCK;
257 } else {
258 flag &= ~O_NONBLOCK;
259 }
260 if (fcntl(sockfd, F_SETFL, flag) == -1) {
261 Socket::Error("SetNonBlock-2");
262 }
263 #endif // _WIN32
264 }
265 /*!
266 * \brief bind the socket to an address
267 * \param addr
268 */
Bind(const SockAddr & addr)269 inline void Bind(const SockAddr &addr) {
270 #if !IS_MINGW()
271 if (bind(sockfd, reinterpret_cast<const sockaddr*>(&addr.addr),
272 sizeof(addr.addr)) == -1) {
273 Socket::Error("Bind");
274 }
275 #endif // !IS_MINGW()
276 }
277 /*!
278 * \brief try bind the socket to host, from start_port to end_port
279 * \param start_port starting port number to try
280 * \param end_port ending port number to try
281 * \return the port successfully bind to, return -1 if failed to bind any port
282 */
TryBindHost(int start_port,int end_port)283 inline int TryBindHost(int start_port, int end_port) {
284 // TODO(tqchen) add prefix check
285 #if !IS_MINGW()
286 for (int port = start_port; port < end_port; ++port) {
287 SockAddr addr("0.0.0.0", port);
288 if (bind(sockfd, reinterpret_cast<sockaddr*>(&addr.addr),
289 sizeof(addr.addr)) == 0) {
290 return port;
291 }
292 #if defined(_WIN32)
293 if (WSAGetLastError() != WSAEADDRINUSE) {
294 Socket::Error("TryBindHost");
295 }
296 #else
297 if (errno != EADDRINUSE) {
298 Socket::Error("TryBindHost");
299 }
300 #endif // defined(_WIN32)
301 }
302 #endif // !IS_MINGW()
303 return -1;
304 }
305 /*! \brief get last error code if any */
GetSockError()306 inline int GetSockError() const {
307 int error = 0;
308 socklen_t len = sizeof(error);
309 #if !IS_MINGW()
310 if (getsockopt(sockfd, SOL_SOCKET, SO_ERROR,
311 reinterpret_cast<char *>(&error), &len) != 0) {
312 Error("GetSockError");
313 }
314 #else
315 // undefined reference to `_imp__getsockopt@20'
316 MingWError();
317 #endif // !IS_MINGW()
318 return error;
319 }
320 /*! \brief check if anything bad happens */
BadSocket()321 inline bool BadSocket() const {
322 if (IsClosed()) return true;
323 int err = GetSockError();
324 if (err == EBADF || err == EINTR) return true;
325 return false;
326 }
327 /*! \brief check if socket is already closed */
IsClosed()328 inline bool IsClosed() const {
329 return sockfd == kInvalidSocket;
330 }
331 /*! \brief close the socket */
Close()332 inline void Close() {
333 if (sockfd != kInvalidSocket) {
334 #ifdef _WIN32
335 #if !IS_MINGW()
336 closesocket(sockfd);
337 #endif // !IS_MINGW()
338 #else
339 close(sockfd);
340 #endif
341 sockfd = kInvalidSocket;
342 } else {
343 Error("Socket::Close double close the socket or close without create");
344 }
345 }
346 // report an socket error
Error(const char * msg)347 inline static void Error(const char *msg) {
348 int errsv = GetLastError();
349 #ifdef _WIN32
350 utils::Error("Socket %s Error:WSAError-code=%d", msg, errsv);
351 #else
352 utils::Error("Socket %s Error:%s", msg, strerror(errsv));
353 #endif
354 }
355
356 protected:
Socket(SOCKET sockfd)357 explicit Socket(SOCKET sockfd) : sockfd(sockfd) {
358 }
359 };
360
361 /*!
362 * \brief a wrapper of TCP socket that hopefully be cross platform
363 */
364 class TCPSocket : public Socket{
365 public:
366 // constructor
TCPSocket()367 TCPSocket() : Socket(kInvalidSocket) {
368 }
TCPSocket(SOCKET sockfd)369 explicit TCPSocket(SOCKET sockfd) : Socket(sockfd) {
370 }
371 /*!
372 * \brief enable/disable TCP keepalive
373 * \param keepalive whether to set the keep alive option on
374 */
SetKeepAlive(bool keepalive)375 void SetKeepAlive(bool keepalive) {
376 #if !IS_MINGW()
377 int opt = static_cast<int>(keepalive);
378 if (setsockopt(sockfd, SOL_SOCKET, SO_KEEPALIVE,
379 reinterpret_cast<char*>(&opt), sizeof(opt)) < 0) {
380 Socket::Error("SetKeepAlive");
381 }
382 #endif // !IS_MINGW()
383 }
384 inline void SetLinger(int timeout = 0) {
385 #if !IS_MINGW()
386 struct linger sl;
387 sl.l_onoff = 1; /* non-zero value enables linger option in kernel */
388 sl.l_linger = timeout; /* timeout interval in seconds */
389 if (setsockopt(sockfd, SOL_SOCKET, SO_LINGER, reinterpret_cast<char*>(&sl), sizeof(sl)) == -1) {
390 Socket::Error("SO_LINGER");
391 }
392 #endif // !IS_MINGW()
393 }
394 /*!
395 * \brief create the socket, call this before using socket
396 * \param af domain
397 */
398 inline void Create(int af = PF_INET) {
399 #if !IS_MINGW()
400 sockfd = socket(PF_INET, SOCK_STREAM, 0);
401 if (sockfd == kInvalidSocket) {
402 Socket::Error("Create");
403 }
404 #endif // !IS_MINGW()
405 }
406 /*!
407 * \brief perform listen of the socket
408 * \param backlog backlog parameter
409 */
410 inline void Listen(int backlog = 16) {
411 #if !IS_MINGW()
412 listen(sockfd, backlog);
413 #endif // !IS_MINGW()
414 }
415 /*! \brief get a new connection */
Accept()416 TCPSocket Accept() {
417 #if !IS_MINGW()
418 SOCKET newfd = accept(sockfd, nullptr, nullptr);
419 if (newfd == kInvalidSocket) {
420 Socket::Error("Accept");
421 }
422 return TCPSocket(newfd);
423 #else
424 return TCPSocket();
425 #endif // !IS_MINGW()
426 }
427 /*!
428 * \brief decide whether the socket is at OOB mark
429 * \return 1 if at mark, 0 if not, -1 if an error occured
430 */
AtMark()431 inline int AtMark() const {
432 #if !IS_MINGW()
433
434 #ifdef _WIN32
435 unsigned long atmark; // NOLINT(*)
436 if (ioctlsocket(sockfd, SIOCATMARK, &atmark) != NO_ERROR) return -1;
437 #else
438 int atmark;
439 if (ioctl(sockfd, SIOCATMARK, &atmark) == -1) return -1;
440 #endif // _WIN32
441
442 return static_cast<int>(atmark);
443
444 #else
445 return -1;
446 #endif // !IS_MINGW()
447 }
448 /*!
449 * \brief connect to an address
450 * \param addr the address to connect to
451 * \return whether connect is successful
452 */
Connect(const SockAddr & addr)453 inline bool Connect(const SockAddr &addr) {
454 #if !IS_MINGW()
455 return connect(sockfd, reinterpret_cast<const sockaddr*>(&addr.addr),
456 sizeof(addr.addr)) == 0;
457 #else
458 return false;
459 #endif // !IS_MINGW()
460 }
461 /*!
462 * \brief send data using the socket
463 * \param buf the pointer to the buffer
464 * \param len the size of the buffer
465 * \param flags extra flags
466 * \return size of data actually sent
467 * return -1 if error occurs
468 */
469 inline ssize_t Send(const void *buf_, size_t len, int flag = 0) {
470 const char *buf = reinterpret_cast<const char*>(buf_);
471 #if !IS_MINGW()
472 return send(sockfd, buf, static_cast<sock_size_t>(len), flag);
473 #else
474 return 0;
475 #endif // !IS_MINGW()
476 }
477 /*!
478 * \brief receive data using the socket
479 * \param buf_ the pointer to the buffer
480 * \param len the size of the buffer
481 * \param flags extra flags
482 * \return size of data actually received
483 * return -1 if error occurs
484 */
485 inline ssize_t Recv(void *buf_, size_t len, int flags = 0) {
486 char *buf = reinterpret_cast<char*>(buf_);
487 #if !IS_MINGW()
488 return recv(sockfd, buf, static_cast<sock_size_t>(len), flags);
489 #else
490 return 0;
491 #endif // !IS_MINGW()
492 }
493 /*!
494 * \brief peform block write that will attempt to send all data out
495 * can still return smaller than request when error occurs
496 * \param buf the pointer to the buffer
497 * \param len the size of the buffer
498 * \return size of data actually sent
499 */
SendAll(const void * buf_,size_t len)500 inline size_t SendAll(const void *buf_, size_t len) {
501 const char *buf = reinterpret_cast<const char*>(buf_);
502 size_t ndone = 0;
503 #if !IS_MINGW()
504 while (ndone < len) {
505 ssize_t ret = send(sockfd, buf, static_cast<ssize_t>(len - ndone), 0);
506 if (ret == -1) {
507 if (LastErrorWouldBlock()) return ndone;
508 Socket::Error("SendAll");
509 }
510 buf += ret;
511 ndone += ret;
512 }
513 #endif // !IS_MINGW()
514 return ndone;
515 }
516 /*!
517 * \brief peforma block read that will attempt to read all data
518 * can still return smaller than request when error occurs
519 * \param buf_ the buffer pointer
520 * \param len length of data to recv
521 * \return size of data actually sent
522 */
RecvAll(void * buf_,size_t len)523 inline size_t RecvAll(void *buf_, size_t len) {
524 char *buf = reinterpret_cast<char*>(buf_);
525 size_t ndone = 0;
526 #if !IS_MINGW()
527 while (ndone < len) {
528 ssize_t ret = recv(sockfd, buf,
529 static_cast<sock_size_t>(len - ndone), MSG_WAITALL);
530 if (ret == -1) {
531 if (LastErrorWouldBlock()) return ndone;
532 Socket::Error("RecvAll");
533 }
534 if (ret == 0) return ndone;
535 buf += ret;
536 ndone += ret;
537 }
538 #endif // !IS_MINGW()
539 return ndone;
540 }
541 /*!
542 * \brief send a string over network
543 * \param str the string to be sent
544 */
SendStr(const std::string & str)545 inline void SendStr(const std::string &str) {
546 int len = static_cast<int>(str.length());
547 utils::Assert(this->SendAll(&len, sizeof(len)) == sizeof(len),
548 "error during send SendStr");
549 if (len != 0) {
550 utils::Assert(this->SendAll(str.c_str(), str.length()) == str.length(),
551 "error during send SendStr");
552 }
553 }
554 /*!
555 * \brief recv a string from network
556 * \param out_str the string to receive
557 */
RecvStr(std::string * out_str)558 inline void RecvStr(std::string *out_str) {
559 int len;
560 utils::Assert(this->RecvAll(&len, sizeof(len)) == sizeof(len),
561 "error during send RecvStr");
562 out_str->resize(len);
563 if (len != 0) {
564 utils::Assert(this->RecvAll(&(*out_str)[0], len) == out_str->length(),
565 "error during send SendStr");
566 }
567 }
568 };
569
570 /*! \brief helper data structure to perform poll */
571 struct PollHelper {
572 public:
573 /*!
574 * \brief add file descriptor to watch for read
575 * \param fd file descriptor to be watched
576 */
WatchReadPollHelper577 inline void WatchRead(SOCKET fd) {
578 auto& pfd = fds[fd];
579 pfd.fd = fd;
580 pfd.events |= POLLIN;
581 }
582 /*!
583 * \brief add file descriptor to watch for write
584 * \param fd file descriptor to be watched
585 */
WatchWritePollHelper586 inline void WatchWrite(SOCKET fd) {
587 auto& pfd = fds[fd];
588 pfd.fd = fd;
589 pfd.events |= POLLOUT;
590 }
591 /*!
592 * \brief add file descriptor to watch for exception
593 * \param fd file descriptor to be watched
594 */
WatchExceptionPollHelper595 inline void WatchException(SOCKET fd) {
596 auto& pfd = fds[fd];
597 pfd.fd = fd;
598 pfd.events |= POLLPRI;
599 }
600 /*!
601 * \brief Check if the descriptor is ready for read
602 * \param fd file descriptor to check status
603 */
CheckReadPollHelper604 inline bool CheckRead(SOCKET fd) const {
605 const auto& pfd = fds.find(fd);
606 return pfd != fds.end() && ((pfd->second.events & POLLIN) != 0);
607 }
608 /*!
609 * \brief Check if the descriptor is ready for write
610 * \param fd file descriptor to check status
611 */
CheckWritePollHelper612 inline bool CheckWrite(SOCKET fd) const {
613 const auto& pfd = fds.find(fd);
614 return pfd != fds.end() && ((pfd->second.events & POLLOUT) != 0);
615 }
616
617 /*!
618 * \brief perform poll on the set defined, read, write, exception
619 * \param timeout specify timeout in milliseconds(ms) if negative, means poll will block
620 * \return
621 */
PollPollHelper622 inline void Poll(std::chrono::seconds timeout) { // NOLINT(*)
623 std::vector<pollfd> fdset;
624 fdset.reserve(fds.size());
625 for (auto kv : fds) {
626 fdset.push_back(kv.second);
627 }
628 int ret = PollImpl(fdset.data(), fdset.size(), timeout);
629 if (ret == 0) {
630 LOG(FATAL) << "Poll timeout";
631 } else if (ret < 0) {
632 Socket::Error("Poll");
633 } else {
634 for (auto& pfd : fdset) {
635 auto revents = pfd.revents & pfd.events;
636 if (!revents) {
637 fds.erase(pfd.fd);
638 } else {
639 fds[pfd.fd].events = revents;
640 }
641 }
642 }
643 }
644
645 std::unordered_map<SOCKET, pollfd> fds;
646 };
647 } // namespace utils
648 } // namespace rabit
649
// Remove the stand-in POLL* macros that were defined for MinGW earlier in
// this header, presumably so they do not leak into code that includes it.
// The condition must match the one guarding those definitions.
#if IS_MINGW() && !defined(POLLRDNORM) && !defined(POLLRDBAND)
#undef POLLIN
#undef POLLPRI
#undef POLLOUT
#endif  // IS_MINGW() && !defined(POLLRDNORM) && !defined(POLLRDBAND)
655
656 #endif // RABIT_INTERNAL_SOCKET_H_
657