1 /*
2  * Copyright (c) 2002, Stefan Farfeleder <e0026813@stud3.tuwien.ac.at>
3  * $Id: network.cc,v 1.5 2002/09/10 22:29:38 stefan Exp $
4  */
5 #include <algorithm>
6 #include <cassert>
7 #include <cerrno>
8 #include <csignal>
9 #include <cstdlib>
10 #include <cstddef>
11 #include <cstring>
12 #include <new>
13 #include <string>
14 
15 #include "config.h"
16 
17 #define _POSIX_C_SOURCE 199506L
18 #define _XOPEN_SOURCE_EXTENDED
19 #if defined(HAVE_SYS_TYPES_H)
20 #include <sys/types.h>
21 #endif
22 #if defined(HAVE_SYS_SOCKET_H)
23 #include <sys/socket.h>
24 #endif
25 #if defined(HAVE_SYS_SELECT_H)
26 #include <sys/select.h>
27 #endif
28 #if defined(HAVE_SYS_TIME_H)
29 #include <sys/time.h>
30 #endif
31 #if defined(HAVE_ARPA_INET_H)
32 #include <arpa/inet.h>
33 #endif
34 #if defined(HAVE_FCNTL_H)
35 #include <fcntl.h>
36 #endif
37 #if defined(HAVE_NETDB_H)
38 #include <netdb.h>
39 #endif
40 #if defined(HAVE_NETINET_IN_H)
41 #include <netinet/in.h>
42 #endif
43 #if defined(HAVE_NETINET_TCP_H)
44 #include <netinet/tcp.h>
45 #endif
46 #if defined(HAVE_POLL_H)
47 #include <poll.h>
48 #endif
49 #if defined(HAVE_UNISTD_H)
50 #include <unistd.h>
51 #endif
52 
/*
 * Winsock compatibility shims: map the POSIX spellings used in this file onto
 * their Winsock counterparts.  SET_ERRNO copies WSAGetLastError() into errno
 * on Windows and is a no-op statement everywhere else.
 */
#ifdef _WIN32
#include <winsock2.h>
#define     close           closesocket
#define     EWOULDBLOCK     WSAEWOULDBLOCK
#define     ECONNABORTED    WSAECONNABORTED
#define     SET_ERRNO       do { errno = WSAGetLastError(); } while (0)
#else
#define     SET_ERRNO       ((void)0)
#endif
62 
63 #include "addrinfo.h"
64 #include "exception.h"
65 #include "missing.h"
66 #include "network.h"
67 
68 using std::string;
69 using std::vector;
70 using namespace JFK;
71 
72 const int JFK::TO_ALL = -1;
73 
74 #if defined(INET6)
75 const int MAXSOCKADDR = sizeof(sockaddr_in6);
76 #else
77 const int MAXSOCKADDR = sizeof(sockaddr_in);
78 #endif
79 
80 /*
81  * We really need non-blocking sockets, so bomb out if we can't have them.
82  */
83 static void
set_socket_nonblocking(int sock)84 set_socket_nonblocking(int sock)
85 {
86     int             val;
87 
88 #if defined(HAVE_FCNTL)
89     val = fcntl(sock, F_GETFL);
90     if (val == -1) throw JFK::exception_e("fcntl");
91 
92     if (fcntl(sock, F_SETFL, val | O_NONBLOCK) == -1)
93         throw exception_e("fcntl");
94 #elif defined(_WIN32)
95     unsigned long   ul_on = 1;
96 
97     val = ioctlsocket(sock, FIONBIO, &ul_on);
98     if (val == -1) throw JFK::exception("ioctlsocket");
99 #else
100     #error "can't set the socket to non-blocking"
101 #endif
102 }
103 
/*
 * Set the integer-valued socket option 'optname' at protocol level 'level'
 * on 'sock' to 'value'.
 *
 * Returns the result of setsockopt() (0 on success), or -1 when the platform
 * provides no setsockopt() at all.
 */
static int
setsockopt_wrapper(int sock, int level, int optname, int value)
{
#if defined(HAVE_SETSOCKOPT)
    const int       v = value;
    return setsockopt(sock, level, optname, &v, sizeof v);
#elif defined(_WIN32)
    /*
     * Winsock declares the option buffer as 'const char*', but the options
     * set through this wrapper are int-sized (BOOL).  Hand over a real int
     * so that optlen is sizeof(int) rather than a single byte.
     */
    const int       v = value;
    return setsockopt(sock, level, optname,
                      reinterpret_cast<const char*>(&v), sizeof v);
#else
    (void)sock;
    (void)level;
    (void)optname;
    (void)value;
    return -1;
#endif
}
117 
/*
 * Ask for TCP_NODELAY on 'sock' so small messages are sent without waiting
 * to be coalesced.  This is purely a latency optimisation: the result is
 * ignored, and the function does nothing where the option is not defined.
 */
static void
disable_nagle_algorithm(int sock)
{
#if defined(TCP_NODELAY) && defined(IPPROTO_TCP)
    const int       enable = 1;

    /* Best effort only; a failure is deliberately not reported. */
    (void)setsockopt_wrapper(sock, IPPROTO_TCP, TCP_NODELAY, enable);
#else
    (void)sock;
#endif
}
129 
130 static void
network_init()131 network_init()
132 {
133 #if defined(_WIN32)
134     unsigned long   version = MAKEWORD(2, 2);
135     WSADATA         wsadata;
136 
137     if (WSAStartup(version, &wsadata) != 0)
138         throw JFK::exception("WSAStartup");
139 #endif
140 #ifdef SIGPIPE
141     /* ignore SIGPIPE */
142     if (std::signal(SIGPIPE, SIG_IGN) == SIG_ERR)
143         throw JFK::exception_e("signal");
144 #endif
145 }
146 
147 /*
148  * Connect to the server with the specified service, then set the socket to
149  * non-blocking mode.
150  */
network_client(const string & servername,const string & service)151 network_client::network_client(const string& servername,
152                                const string& service)
153 {
154     addrinfo        hints;
155     addrinfo*       res;
156     addrinfo*       lres;
157     int             val;
158 
159     network_init();
160 
161     std::memset(&hints, 0, sizeof hints);
162     hints.ai_family = AF_UNSPEC;
163     hints.ai_socktype = SOCK_STREAM;
164 
165     if ((val = getaddrinfo(servername.c_str(), service.c_str(),
166                            &hints, &res)) != 0)
167         throw exception("getaddrinfo: " + string(gai_strerror(val)));
168 
169     /* loop through each entry and try to connect */
170     for (lres = res; lres != NULL; lres = lres->ai_next)
171     {
172         server.fd = socket(lres->ai_family, lres->ai_socktype,
173                            lres->ai_protocol);
174         if (server.fd < 0)
175             continue;
176 
177         if (connect(server.fd, lres->ai_addr, lres->ai_addrlen) == 0)
178             break;  /* found */
179 
180         if (close(server.fd) != 0)
181             throw exception_e("close");
182     }
183 
184     freeaddrinfo(res);
185 
186     if (lres == NULL)
187         /* errno could be wrong, so probably this should only be
188          * exception(...) */
189         throw exception_e("cannot connect to " + servername + '.' + service);
190 
191     set_socket_nonblocking(server.fd);
192     disable_nagle_algorithm(server.fd);
193 }
194 
~network_client()195 network_client::~network_client()
196 {
197 #if defined(_WIN32)
198     WSACleanup();
199 #endif
200 }
201 
202 /*
203  * Send 'msg' to the server.
204  */
205 void
send(const string & msg)206 network_client::send(const string& msg)
207 {
208     server.writeq.push(msg + '\n');
209 }
210 
211 /*
212  * Get a line from the server and return true; if no line is ready false is
213  * returned.
214  */
215 bool
receive(string * msg)216 network_client::receive(string* msg)
217 {
218     if (server.readq.empty())
219         return false;
220 
221     *msg = server.readq.front();
222     server.readq.pop();
223     return true;
224 }
225 
226 /*
227  * Try to write lines stored in writeq to the server and to read lines from
228  * the server into readq.
229  */
230 void
dispatch()231 network_client::dispatch()
232 {
233     server.flush_writeq();
234 
235     server.fill_readq();
236 
237     if (server.lost)
238         throw exception("lost server");
239 }
240 
241 
242 /*
243  * Create a server listening on service 'service' and accepting connections.
244  */
network_server(const string & service)245 network_server::network_server(const string& service)
246 #if defined(HAVE_POLL)
247     : fds(NULL), fd_alloc(0)
248 #endif
249 {
250     addrinfo        hints;
251     addrinfo*       res;
252     addrinfo*       lres;
253     int             val;
254 
255     network_init();
256 
257     std::memset(&hints, 0, sizeof hints);
258     hints.ai_family = AF_UNSPEC;
259     hints.ai_flags = AI_PASSIVE;
260     hints.ai_socktype = SOCK_STREAM;
261 
262     if ((val = getaddrinfo(NULL, service.c_str(), &hints, &res)) != 0)
263         throw exception("getaddrinfo: " + string(gai_strerror(val)));
264 
265     /* loop through each entry and try to bind */
266     for (lres = res; lres != NULL; lres = lres->ai_next)
267     {
268         listenfd = socket(lres->ai_family, lres->ai_socktype,
269                           lres->ai_protocol);
270         if (listenfd < 0)
271             continue;
272 
273 #if defined(SOL_SOCKET) && defined(SO_REUSEADDR)
274         (void)setsockopt_wrapper(listenfd, SOL_SOCKET, SO_REUSEADDR, 1);
275 #endif
276 
277         if (bind(listenfd, lres->ai_addr, lres->ai_addrlen) == 0)
278             break;  /* found */
279 
280         if (close(listenfd) != 0)
281             throw exception_e("close");
282     }
283 
284     freeaddrinfo(res);
285 
286     if (lres == NULL)
287         throw exception_e("cannot listen on " + service);
288 
289     if (listen(listenfd, SOMAXCONN) != 0)
290         throw exception_e("listen");
291 
292     set_socket_nonblocking(listenfd);
293 
294     sa = (sockaddr*)malloc(MAXSOCKADDR);
295     if (sa == NULL)
296         throw std::bad_alloc();
297 }
298 
299 /*
300  * Free allocated memory.
301  */
~network_server()302 network_server::~network_server()
303 {
304 #if defined(HAVE_POLL)
305     std::free(fds);
306 #endif
307     std::free(sa);
308 #if defined(_WIN32)
309     WSACleanup();
310 #endif
311 }
312 
313 /*
314  * Return true if a new client has connected to the server. The ip number is
315  * stored into 'host'. An id number for the client is assigned with which it
316  * can be referred to in further requests.
317  */
318 bool
new_client(string * hostname,int * id)319 network_server::new_client(string* hostname, int* id)
320 {
321     char            buf[NI_MAXHOST];
322     socklen_t       len = MAXSOCKADDR;
323     int             clientfd;
324     int             offs;
325 
326     errno = 0;
327 
328 #if !defined(HAVE_POLL)
329     /* only FD_SETSIZE fds in select(), can't accept() more */
330     if (client.size() == FD_SETSIZE) return false;
331 #endif
332 
333     do
334     {
335         clientfd = accept(listenfd, sa, &len);
336         SET_ERRNO;
337     }
338     while (clientfd == -1 && errno == EINTR);
339 
340     if (clientfd == -1)
341     {
342         if (errno == EAGAIN || errno == EWOULDBLOCK || errno == ECONNABORTED)
343         {
344             return false; /* no new client yet */
345         }
346         throw exception_e("accept");
347     }
348 
349     if (sa->sa_family == AF_INET)
350         offs = offsetof(sockaddr_in, sin_addr);
351 #if defined(INET6)
352     else if (sa->sa_family == AF_INET6)
353         offs = offsetof(sockaddr_in6, sin6_addr);
354 #endif
355     else
356         throw exception("unknown sa_family");
357 
358     if (inet_ntop(sa->sa_family, (char*)sa + offs, buf, sizeof buf) == NULL)
359         throw exception_e("inet_ntop");
360 
361     set_socket_nonblocking(clientfd);
362     disable_nagle_algorithm(clientfd);
363 
364     *hostname = string(buf);
365     *id = clientfd;
366 
367     client.push_back(clientfd);
368 
369 #if defined(HAVE_POLL)
370     if (client.size() > fd_alloc)
371     {
372         const size_t    MIN_POLLFD = 10;
373         /* allocate more memory for the pollfd array */
374         fd_alloc = std::max(MIN_POLLFD, 3 * fd_alloc / 2);
375         pollfd* tmp = (pollfd*)realloc(fds, fd_alloc * sizeof *fds);
376         if (tmp == NULL)
377             throw std::bad_alloc();
378 
379         fds = tmp;
380     }
381 #endif
382 
383     return true;
384 }
385 
386 /*
387  * Return true if a client has disconnected and store its id into 'id'.
388  */
389 bool
lost_client(int * id)390 network_server::lost_client(int* id)
391 {
392     for (size_t i = 0; i < client.size(); i++)
393     {
394         if (client[i].lost)
395         {
396             *id = client[i].fd;
397             remove_client(*id);
398             return true;
399         }
400     }
401     return false;
402 }
403 
404 /*
405  * Cut the connection to the client.
406  */
407 void
remove_client(int id)408 network_server::remove_client(int id)
409 {
410     client_iter     i = find_id(client, id);
411 
412     if (i == client.end())
413         throw exception("id nonexistent");
414 
415     if (close(i->fd) != 0)
416         throw exception_e("close");
417 
418     *i = client.back();
419     client.resize(client.size() - 1);
420 }
421 
422 void
enable_client(int id)423 network_server::enable_client(int id)
424 {
425     client_iter     i = find_id(client, id);
426 
427     if (i == client.end())
428         throw exception("id nonexistent");
429 
430     i->enabled = true;
431 }
432 
433 void
disable_client(int id)434 network_server::disable_client(int id)
435 {
436     client_iter     i = find_id(client, id);
437 
438     if (i == client.end())
439         throw exception("id nonexistent");
440 
441     i->enabled = false;
442 }
443 
444 /*
445  * Push 'msg' onto the stack of the client with the id 'to' (or to all's if
446  * to == TO_ALL).
447  */
448 void
send(const string & msg,int to)449 network_server::send(const string& msg, int to)
450 {
451     for (size_t i = 0; i < client.size(); i++)
452     {
453         if (client[i].fd == to || (to == TO_ALL && client[i].enabled))
454         {
455             client[i].writeq.push(msg + '\n');
456 
457             if (to != TO_ALL) return;
458         }
459     }
460 
461     if (to != TO_ALL)
462         throw exception("id nonexistent");
463 }
464 
465 /*
466  * If a message arrived it is stored to 'msg' and true is returned. The
467  * sender-id is stored into 'from'.
468  */
469 bool
receive(string * msg,int from)470 network_server::receive(string* msg, int from)
471 {
472     client_iter     i = find_id(client, from);
473 
474     if (i == client.end())
475         throw exception("id nonexistent");
476 
477     if (i->readq.empty())
478         return false;
479     else
480     {
481         *msg = i->readq.front();
482         i->readq.pop();
483         return true;
484     }
485 }
486 
487 /*
488  * Try to send each writeq to its client and fill the readq with incoming
489  * data.
490  */
491 void
dispatch()492 network_server::dispatch()
493 {
494     int     nfds = 0;       /* number of elements in fds */
495     int     nused;          /* number of fds read-/writable */
496 
497 #if !defined(HAVE_POLL) && !defined(HAVE_SELECT) && !defined(_WIN32)
498     #error "neither poll() nor select() available!"
499 #endif
500 
501 #if !defined(HAVE_POLL)
502     fd_set  readset;
503     fd_set  writeset;
504 
505     FD_ZERO(&readset);
506     FD_ZERO(&writeset);
507 #endif
508 
509     /* fill the fds array/fd_set */
510     for (size_t i = 0; i < client.size(); i++)
511     {
512         host&   c = client[i];
513 
514         if (c.lost) continue;
515 
516 #if defined(HAVE_POLL)
517         fds[nfds].fd = c.fd;
518         fds[nfds].events = POLLIN;
519 
520         if (!c.writeq.empty())
521             fds[nfds].events |= POLLOUT;
522 
523         nfds++;
524 #else
525         FD_SET(c.fd, &readset);
526 
527         if (!c.writeq.empty())
528             FD_SET(c.fd, &writeset);
529 
530         if (c.fd > nfds)
531             nfds = c.fd;
532 #endif
533     }
534 
535 #if defined(HAVE_POLL)
536     nused = poll(fds, nfds, 0);
537     if (nused == -1) throw exception_e("poll");
538 
539     /* i is the index in the fds array, j in client */
540     for (size_t i = 0, j = 0; nused > 0; i++)
541     {
542         /* skip over clients not in the fds array */
543         while (j < client.size() && client[j].fd != fds[i].fd)
544             j++;
545 
546         assert(j < client.size());
547 
548         if (fds[i].revents == 0) continue;
549 
550         nused--;
551 
552         host&   c = client[j];
553 
554         if (fds[i].revents & POLLERR || fds[i].revents & POLLHUP)
555         {
556             c.lost = true;
557             continue;
558         }
559 
560         if (fds[i].revents & POLLOUT)
561         {
562             c.flush_writeq();
563         }
564 
565         if (fds[i].revents & POLLIN)
566         {
567             c.fill_readq();
568         }
569     }
570 #else
571     timeval     tv;
572     tv.tv_sec = tv.tv_usec = 0;
573 
574     nused = 0;
575 
576     if (nfds > 0)
577     {
578         nused = select(nfds + 1, &readset, &writeset, NULL, &tv);
579         SET_ERRNO;
580     }
581 
582     if (nused == -1) throw exception_e("select");
583 
584     for (int i = 0; i <= nfds && nused > 0; i++)
585     {
586         client_iter     it;
587 
588         if (FD_ISSET(i, &readset))
589         {
590             it = find_id(client, i);
591             assert(it != client.end());
592             it->fill_readq();
593             nused--;
594         }
595         if (FD_ISSET(i, &writeset))
596         {
597             it = find_id(client, i);
598             assert(it != client.end());
599             it->flush_writeq();
600             nused--;
601         }
602     }
603 
604     assert(nused == 0);
605 #endif
606 }
607 
608 /*
609  * Search for host with id 'id'. Returns vh.end() if not found.
610  */
611 network_server::client_iter
find_id(vector<host> & vh,int id)612 network_server::find_id(vector<host>& vh, int id)
613 {
614     client_iter     ret;
615 
616     for (ret = vh.begin(); ret != vh.end(); ret++)
617         if (ret->fd == id)
618             break;
619 
620     return ret;
621 }
622 
623 /*
624  * Send as much data as possible from writeq.
625  */
626 void
flush_writeq()627 host::flush_writeq()
628 {
629     errno = 0;
630 
631     while (!writeq.empty())
632     {
633         ssize_t n = ::send(fd, writeq.front().data() + written_chars,
634                            writeq.front().length() - written_chars, 0);
635         SET_ERRNO;
636 
637         if (n <= 0)
638         {
639             if (errno == EINTR) continue; /* try again */
640 
641             if (errno != EWOULDBLOCK && errno != EAGAIN)
642             {
643                 lost = true;
644             }
645             break;
646         }
647 
648         written_chars += n;
649         if (written_chars == writeq.front().length())
650         {
651             /* All of the line has been written, remove it from writeq. */
652             writeq.pop();
653             written_chars = 0;
654         }
655     }
656 }
657 
658 /*
659  * Read as much data as possible into readq.
660  */
661 void
fill_readq()662 host::fill_readq()
663 {
664     char    buf[2048];
665 
666     errno = 0;
667 
668     for (;;)
669     {
670         ssize_t n = recv(fd, buf, sizeof buf, 0);
671         SET_ERRNO;
672 
673         if (n <= 0)
674         {
675             if (errno == EINTR) continue;
676 
677             if (errno != EWOULDBLOCK && errno != EAGAIN)
678             {
679                 lost = true;
680             }
681             break;
682         }
683 
684         char*   p; /* Points to the next \n. */
685         char*   r; /* Points to the beginning of a line. */
686 
687         for (r = buf;
688              r < buf + n &&
689                 (p = (char*)std::memchr(r, '\n', buf + n - r)) != NULL;
690              r = p + 1)
691         {
692             if (incomplete.length() > 0)
693             {
694                 /* incomplete + everything to the next \n go into readq. */
695                 readq.push(incomplete);
696                 /* Dont store the '\n'. */
697                 readq.back().append(r, p - r);
698                 incomplete = "";
699             }
700             else
701             {
702                 /* Don't store the '\n'. */
703                 readq.push(string(r, p - r));
704             }
705         }
706 
707         /* Everything after the last \n goes into incomplete. */
708         incomplete.append(r, buf + n - r);
709     }
710 }
711