1 /*=========================================================================*\
2 * Socket compatibilization module for Unix
3 * LuaSocket toolkit
4 *
* The code is interrupt-safe.
* The penalty of calling select/poll to avoid busy-waiting is only paid
* when the I/O call fails in the first place.
8 \*=========================================================================*/
#include <limits.h>
#include <signal.h>
#include <string.h>

#include "socket.h"
#include "pierror.h"
14 
15 /*-------------------------------------------------------------------------*\
16 * Wait for readable/writable/connected socket with timeout
17 \*-------------------------------------------------------------------------*/
18 #ifndef SOCKET_SELECT
19 #include <sys/poll.h>
20 
21 #define WAITFD_R        POLLIN
22 #define WAITFD_W        POLLOUT
23 #define WAITFD_C        (POLLIN|POLLOUT)
socket_waitfd(p_socket ps,int sw,p_timeout tm)24 int socket_waitfd(p_socket ps, int sw, p_timeout tm) {
25     int ret;
26     struct pollfd pfd;
27     pfd.fd = *ps;
28     pfd.events = sw;
29     pfd.revents = 0;
30     if (timeout_iszero(tm)) return IO_TIMEOUT;  /* optimize timeout == 0 case */
31     do {
32         int t = (int)(timeout_getretry(tm)*1e3);
33         ret = poll(&pfd, 1, t >= 0? t: -1);
34     } while (ret == -1 && errno == EINTR);
35     if (ret == -1) return errno;
36     if (ret == 0) return IO_TIMEOUT;
37     if (sw == WAITFD_C && (pfd.revents & (POLLIN|POLLERR))) return IO_CLOSED;
38     return IO_DONE;
39 }
40 #else
41 
42 #define WAITFD_R        1
43 #define WAITFD_W        2
44 #define WAITFD_C        (WAITFD_R|WAITFD_W)
45 
socket_waitfd(p_socket ps,int sw,p_timeout tm)46 int socket_waitfd(p_socket ps, int sw, p_timeout tm) {
47     int ret;
48     fd_set rfds, wfds, *rp, *wp;
49     struct timeval tv, *tp;
50     double t;
51     if (*ps >= FD_SETSIZE) return EINVAL;
52     if (timeout_iszero(tm)) return IO_TIMEOUT;  /* optimize timeout == 0 case */
53     do {
54         /* must set bits within loop, because select may have modifed them */
55         rp = wp = NULL;
56         if (sw & WAITFD_R) { FD_ZERO(&rfds); FD_SET(*ps, &rfds); rp = &rfds; }
57         if (sw & WAITFD_W) { FD_ZERO(&wfds); FD_SET(*ps, &wfds); wp = &wfds; }
58         t = timeout_getretry(tm);
59         tp = NULL;
60         if (t >= 0.0) {
61             tv.tv_sec = (int)t;
62             tv.tv_usec = (int)((t-tv.tv_sec)*1.0e6);
63             tp = &tv;
64         }
65         ret = select(*ps+1, rp, wp, NULL, tp);
66     } while (ret == -1 && errno == EINTR);
67     if (ret == -1) return errno;
68     if (ret == 0) return IO_TIMEOUT;
69     if (sw == WAITFD_C && FD_ISSET(*ps, &rfds)) return IO_CLOSED;
70     return IO_DONE;
71 }
72 #endif
73 
74 
75 /*-------------------------------------------------------------------------*\
76 * Initializes module
77 \*-------------------------------------------------------------------------*/
/* Module start-up for Unix: SIGPIPE is set to be ignored so that writing
 * to a connection the peer has closed fails with EPIPE (which the send and
 * write helpers map to IO_CLOSED) rather than terminating the process.
 * Always reports success by returning 1. */
int socket_open(void) {
    (void) signal(SIGPIPE, SIG_IGN);
    return 1;
}
83 
84 /*-------------------------------------------------------------------------*\
85 * Close module
86 \*-------------------------------------------------------------------------*/
/* Module shutdown for Unix: there is nothing to release, so this simply
 * reports success (1). */
int socket_close(void) {
    const int ok = 1;
    return ok;
}
90 
91 /*-------------------------------------------------------------------------*\
* Closes the socket and marks the handle as invalid
93 \*-------------------------------------------------------------------------*/
/* Closes *ps if it refers to a socket and then stores SOCKET_INVALID in it,
 * so repeated calls are harmless. */
void socket_destroy(p_socket ps) {
    if (*ps == SOCKET_INVALID) return;
    close(*ps);
    *ps = SOCKET_INVALID;
}
100 
101 /*-------------------------------------------------------------------------*\
102 * Select with timeout control
103 \*-------------------------------------------------------------------------*/
/* select() wrapper that restarts after EINTR and takes its timeout from tm.
 * A negative retry time blocks indefinitely (NULL timeout); zero polls
 * without waiting. Returns select()'s result.
 *
 * The timeout is clamped to 31 days: POSIX only guarantees select()
 * supports that much, and it keeps the double -> integer conversion in
 * range. Because select() rewrites the fd sets when it returns, a clamped
 * wait cannot be transparently restarted, so a timeout longer than the cap
 * ends early, after the cap. */
int socket_select(t_socket n, fd_set *rfds, fd_set *wfds, fd_set *efds,
        p_timeout tm) {
    static const double max_wait = 31.0 * 24.0 * 60.0 * 60.0;
    int ret;
    do {
        struct timeval tv;
        struct timeval *tp = NULL;
        double t = timeout_getretry(tm);
        if (t >= 0.0) {
            if (t > max_wait) t = max_wait;
            tv.tv_sec = (int) t;
            tv.tv_usec = (int) ((t - tv.tv_sec) * 1.0e6);
            tp = &tv;
        }
        ret = select(n, rfds, wfds, efds, tp);
    } while (ret < 0 && errno == EINTR);
    return ret;
}
117 
118 /*-------------------------------------------------------------------------*\
119 * Creates and sets up a socket
120 \*-------------------------------------------------------------------------*/
/* Creates a socket with socket(2), storing the descriptor in *ps.
 * Returns IO_DONE on success or errno when socket() fails. */
int socket_create(p_socket ps, int domain, int type, int protocol) {
    *ps = socket(domain, type, protocol);
    return (*ps == SOCKET_INVALID) ? errno : IO_DONE;
}
126 
127 /*-------------------------------------------------------------------------*\
128 * Binds or returns error message
129 \*-------------------------------------------------------------------------*/
/* Binds *ps to addr. The socket is switched to blocking mode for the
 * bind() call and left in non-blocking mode afterwards. Returns IO_DONE
 * or the errno captured right after a failing bind(). */
int socket_bind(p_socket ps, SA *addr, socklen_t len) {
    int result;
    socket_setblocking(ps);
    result = (bind(*ps, addr, len) < 0) ? errno : IO_DONE;
    socket_setnonblocking(ps);
    return result;
}
137 
138 /*-------------------------------------------------------------------------*\
* Puts the socket into listening mode or returns error code
140 \*-------------------------------------------------------------------------*/
/* Marks *ps as a passive socket with the given backlog.
 * Returns IO_DONE on success or errno when listen() fails. */
int socket_listen(p_socket ps, int backlog) {
    if (listen(*ps, backlog) != 0) return errno;
    return IO_DONE;
}
146 
147 /*-------------------------------------------------------------------------*\
* Shuts down part or all of the connection (see shutdown(2) for 'how')
149 \*-------------------------------------------------------------------------*/
/* Calls shutdown(2) on *ps; the result is deliberately not reported. */
void socket_shutdown(p_socket ps, int how) {
    (void) shutdown(*ps, how);
}
153 
154 /*-------------------------------------------------------------------------*\
155 * Connects or returns error message
156 \*-------------------------------------------------------------------------*/
/* Connects *ps to addr, waiting at most tm for the attempt to complete.
 * Returns IO_DONE once connected, IO_CLOSED for an invalid socket,
 * IO_TIMEOUT if the wait runs out (immediately when tm is zero), or the
 * errno describing why the connection failed. */
int socket_connect(p_socket ps, SA *addr, socklen_t len, p_timeout tm) {
    int err;
    /* avoid calling on closed sockets */
    if (*ps == SOCKET_INVALID) return IO_CLOSED;
    if (connect(*ps, addr, len) == 0) return IO_DONE;
    err = errno;
    /* POSIX: when connect() is interrupted by a signal the connection
     * attempt is not aborted but continues asynchronously, and calling
     * connect() again would fail with EALREADY. So EINTR is handled exactly
     * like EINPROGRESS: wait for the outcome instead of retrying. */
    if (err != EINPROGRESS && err != EAGAIN && err != EINTR) return err;
    /* zero timeout case optimization */
    if (timeout_iszero(tm)) return IO_TIMEOUT;
    /* wait until we have the result of the connection attempt or timeout */
    err = socket_waitfd(ps, WAITFD_C, tm);
    if (err == IO_CLOSED) {
        /* readable or error: a zero-byte recv succeeds on a connected
         * socket and otherwise fails with the pending connection error */
        if (recv(*ps, (char *) &err, 0, 0) == 0) return IO_DONE;
        return errno;
    }
    return err;
}
175 
176 /*-------------------------------------------------------------------------*\
177 * Accept with timeout
178 \*-------------------------------------------------------------------------*/
/* Accepts a connection on *ps into *pa, waiting for readability (bounded
 * by tm) whenever accept() reports EAGAIN or ECONNABORTED and retrying on
 * EINTR. Returns IO_DONE, IO_CLOSED for an invalid listening socket, the
 * socket_waitfd result if waiting fails, or errno for other errors. */
int socket_accept(p_socket ps, p_socket pa, SA *addr, socklen_t *len, p_timeout tm) {
    int status;
    if (*ps == SOCKET_INVALID) return IO_CLOSED;
    while (1) {
        *pa = accept(*ps, addr, len);
        if (*pa != SOCKET_INVALID) return IO_DONE;
        status = errno;
        if (status == EINTR) continue;
        if (status == EAGAIN || status == ECONNABORTED) {
            status = socket_waitfd(ps, WAITFD_R, tm);
            if (status != IO_DONE) return status;
            continue;
        }
        return status;
    }
    /* not reached */
    return IO_UNKNOWN;
}
192 
193 /*-------------------------------------------------------------------------*\
194 * Send with timeout
195 \*-------------------------------------------------------------------------*/
/* Sends up to count bytes from data, storing the amount written in *sent.
 * Retries on EINTR and on EPROTOTYPE (seen transiently while a connection
 * is being torn down on OS X Yosemite), waits for writability (bounded by
 * tm) on EAGAIN, and maps EPIPE to IO_CLOSED. Returns IO_DONE, IO_CLOSED,
 * the socket_waitfd result, or errno. */
int socket_send(p_socket ps, const char *data, size_t count,
        size_t *sent, p_timeout tm)
{
    *sent = 0;
    /* avoid making system calls on closed sockets */
    if (*ps == SOCKET_INVALID) return IO_CLOSED;
    for ( ;; ) {
        int status;
        long written = (long) send(*ps, data, count, 0);
        if (written >= 0) {
            *sent = (size_t) written;
            return IO_DONE;
        }
        status = errno;
        switch (status) {
            case EPIPE:
                return IO_CLOSED;
            case EPROTOTYPE:
            case EINTR:
                continue;       /* retry the send */
            case EAGAIN:
                break;          /* wait below */
            default:
                return status;
        }
        status = socket_waitfd(ps, WAITFD_W, tm);
        if (status != IO_DONE) return status;
    }
    /* not reached */
    return IO_UNKNOWN;
}
226 
227 /*-------------------------------------------------------------------------*\
228 * Sendto with timeout
229 \*-------------------------------------------------------------------------*/
/* sendto() counterpart of socket_send: transmits up to count bytes to addr,
 * storing the amount written in *sent, with the same retry/wait/error
 * mapping (EPIPE -> IO_CLOSED, EPROTOTYPE/EINTR retried, EAGAIN waits). */
int socket_sendto(p_socket ps, const char *data, size_t count, size_t *sent,
        SA *addr, socklen_t len, p_timeout tm)
{
    *sent = 0;
    if (*ps == SOCKET_INVALID) return IO_CLOSED;
    for ( ;; ) {
        int status;
        long written = (long) sendto(*ps, data, count, 0, addr, len);
        if (written >= 0) {
            *sent = (size_t) written;
            return IO_DONE;
        }
        status = errno;
        if (status == EPIPE) return IO_CLOSED;
        if (status == EPROTOTYPE || status == EINTR) continue;
        if (status != EAGAIN) return status;
        status = socket_waitfd(ps, WAITFD_W, tm);
        if (status != IO_DONE) return status;
    }
    return IO_UNKNOWN;
}
251 
252 /*-------------------------------------------------------------------------*\
253 * Receive with timeout
254 \*-------------------------------------------------------------------------*/
/* Receives up to count bytes into data, storing the amount read in *got.
 * A zero-byte result is reported as IO_CLOSED; EINTR is retried and EAGAIN
 * waits for readability (bounded by tm). Returns IO_DONE, IO_CLOSED, the
 * socket_waitfd result, or errno. */
int socket_recv(p_socket ps, char *data, size_t count, size_t *got, p_timeout tm) {
    *got = 0;
    if (*ps == SOCKET_INVALID) return IO_CLOSED;
    for ( ;; ) {
        int status;
        long received = (long) recv(*ps, data, count, 0);
        if (received > 0) {
            *got = (size_t) received;
            return IO_DONE;
        }
        if (received == 0) return IO_CLOSED;
        status = errno;
        if (status == EINTR) continue;
        if (status != EAGAIN) return status;
        status = socket_waitfd(ps, WAITFD_R, tm);
        if (status != IO_DONE) return status;
    }
    return IO_UNKNOWN;
}
273 
274 /*-------------------------------------------------------------------------*\
275 * Recvfrom with timeout
276 \*-------------------------------------------------------------------------*/
/* recvfrom() counterpart of socket_recv: also fills addr/len with the
 * sender's address. Same result mapping as socket_recv. */
int socket_recvfrom(p_socket ps, char *data, size_t count, size_t *got,
        SA *addr, socklen_t *len, p_timeout tm) {
    *got = 0;
    if (*ps == SOCKET_INVALID) return IO_CLOSED;
    while (1) {
        int status;
        long received = (long) recvfrom(*ps, data, count, 0, addr, len);
        if (received > 0) {
            *got = (size_t) received;
            return IO_DONE;
        }
        if (received == 0) return IO_CLOSED;
        status = errno;
        if (status == EINTR) continue;
        if (status != EAGAIN) return status;
        status = socket_waitfd(ps, WAITFD_R, tm);
        if (status != IO_DONE) return status;
    }
    return IO_UNKNOWN;
}
296 
297 
298 /*-------------------------------------------------------------------------*\
299 * Write with timeout
300 *
* socket_write and socket_read are copies of socket_send and socket_recv,
* respectively, with send/recv replaced by write/read. The socket versions
* cannot simply use write/read, because their behaviour differs when the
* size is zero.
304 \*-------------------------------------------------------------------------*/
/* write() counterpart of socket_send: writes up to count bytes from data,
 * storing the amount written in *sent. EPIPE maps to IO_CLOSED, EPROTOTYPE
 * (OS X Yosemite, while closing) and EINTR are retried, and EAGAIN waits
 * for writability bounded by tm. */
int socket_write(p_socket ps, const char *data, size_t count,
        size_t *sent, p_timeout tm)
{
    *sent = 0;
    /* avoid making system calls on closed sockets */
    if (*ps == SOCKET_INVALID) return IO_CLOSED;
    for ( ;; ) {
        int status;
        long written = (long) write(*ps, data, count);
        if (written >= 0) {
            *sent = (size_t) written;
            return IO_DONE;
        }
        status = errno;
        switch (status) {
            case EPIPE:
                return IO_CLOSED;
            case EPROTOTYPE:
            case EINTR:
                continue;       /* retry the write */
            case EAGAIN:
                break;          /* wait below */
            default:
                return status;
        }
        status = socket_waitfd(ps, WAITFD_W, tm);
        if (status != IO_DONE) return status;
    }
    /* not reached */
    return IO_UNKNOWN;
}
335 
336 /*-------------------------------------------------------------------------*\
337 * Read with timeout
338 * See note for socket_write
339 \*-------------------------------------------------------------------------*/
/* read() counterpart of socket_recv: reads up to count bytes into data and
 * stores the amount read in *got. A zero-byte read is reported as
 * IO_CLOSED; EINTR is retried and EAGAIN waits for readability (bounded
 * by tm). */
int socket_read(p_socket ps, char *data, size_t count, size_t *got, p_timeout tm) {
    *got = 0;
    if (*ps == SOCKET_INVALID) return IO_CLOSED;
    while (1) {
        int status;
        long received = (long) read(*ps, data, count);
        if (received > 0) {
            *got = (size_t) received;
            return IO_DONE;
        }
        if (received == 0) return IO_CLOSED;
        status = errno;
        if (status == EINTR) continue;
        if (status != EAGAIN) return status;
        status = socket_waitfd(ps, WAITFD_R, tm);
        if (status != IO_DONE) return status;
    }
    return IO_UNKNOWN;
}
358 
359 /*-------------------------------------------------------------------------*\
360 * Put socket into blocking mode
361 \*-------------------------------------------------------------------------*/
/* Clears O_NONBLOCK on *ps, leaving its other file status flags intact.
 * If the current flags cannot be read (F_GETFL returns -1), the socket is
 * left untouched instead of writing a bogus all-ones flag set back. */
void socket_setblocking(p_socket ps) {
    int flags = fcntl(*ps, F_GETFL, 0);
    if (flags == -1) return;
    fcntl(*ps, F_SETFL, flags & ~O_NONBLOCK);
}
367 
368 /*-------------------------------------------------------------------------*\
369 * Put socket into non-blocking mode
370 \*-------------------------------------------------------------------------*/
/* Sets O_NONBLOCK on *ps, leaving its other file status flags intact.
 * If the current flags cannot be read (F_GETFL returns -1), the socket is
 * left untouched instead of writing a bogus all-ones flag set back. */
void socket_setnonblocking(p_socket ps) {
    int flags = fcntl(*ps, F_GETFL, 0);
    if (flags == -1) return;
    fcntl(*ps, F_SETFL, flags | O_NONBLOCK);
}
376 
377 /*-------------------------------------------------------------------------*\
378 * DNS helpers
379 \*-------------------------------------------------------------------------*/
socket_gethostbyaddr(const char * addr,socklen_t len,struct hostent ** hp)380 int socket_gethostbyaddr(const char *addr, socklen_t len, struct hostent **hp) {
381     *hp = gethostbyaddr(addr, len, AF_INET);
382     if (*hp) return IO_DONE;
383     else if (h_errno) return h_errno;
384     else if (errno) return errno;
385     else return IO_UNKNOWN;
386 }
387 
socket_gethostbyname(const char * addr,struct hostent ** hp)388 int socket_gethostbyname(const char *addr, struct hostent **hp) {
389     *hp = gethostbyname(addr);
390     if (*hp) return IO_DONE;
391     else if (h_errno) return h_errno;
392     else if (errno) return errno;
393     else return IO_UNKNOWN;
394 }
395 
396 /*-------------------------------------------------------------------------*\
397 * Error translation functions
398 * Make sure important error messages are standard
399 \*-------------------------------------------------------------------------*/
socket_hoststrerror(int err)400 const char *socket_hoststrerror(int err) {
401     if (err <= 0) return io_strerror(err);
402     switch (err) {
403         case HOST_NOT_FOUND: return PIE_HOST_NOT_FOUND;
404         default: return hstrerror(err);
405     }
406 }
407 
/* Message for a socket error code: non-positive codes are IO_* values
 * handled by io_strerror; a few common errno values get standardised PIE_*
 * texts; everything else falls back to strerror. */
const char *socket_strerror(int err) {
    if (err <= 0) return io_strerror(err);
    if (err == EADDRINUSE) return PIE_ADDRINUSE;
    if (err == EISCONN) return PIE_ISCONN;
    if (err == EACCES) return PIE_ACCESS;
    if (err == ECONNREFUSED) return PIE_CONNREFUSED;
    if (err == ECONNABORTED) return PIE_CONNABORTED;
    if (err == ECONNRESET) return PIE_CONNRESET;
    if (err == ETIMEDOUT) return PIE_TIMEDOUT;
    return strerror(err);
}
423 
/* Message for an I/O error on a socket. The socket itself is not consulted
 * here; the text comes solely from socket_strerror. */
const char *socket_ioerror(p_socket ps, int err) {
    const char *text;
    (void) ps;
    text = socket_strerror(err);
    return text;
}
428 
/* Message for a getaddrinfo/getnameinfo error code. Returns NULL for 0
 * (success), standardised PIE_* texts for the common EAI_* codes,
 * strerror(errno) for EAI_SYSTEM, and gai_strerror otherwise. Codes that
 * not every platform defines are guarded with #ifdef. */
const char *socket_gaistrerror(int err) {
    const char *msg;
    if (!err) return NULL;
    switch (err) {
        case EAI_AGAIN:     msg = PIE_AGAIN; break;
        case EAI_BADFLAGS:  msg = PIE_BADFLAGS; break;
#ifdef EAI_BADHINTS
        case EAI_BADHINTS:  msg = PIE_BADHINTS; break;
#endif
        case EAI_FAIL:      msg = PIE_FAIL; break;
        case EAI_FAMILY:    msg = PIE_FAMILY; break;
        case EAI_MEMORY:    msg = PIE_MEMORY; break;
        case EAI_NONAME:    msg = PIE_NONAME; break;
        case EAI_OVERFLOW:  msg = PIE_OVERFLOW; break;
#ifdef EAI_PROTOCOL
        case EAI_PROTOCOL:  msg = PIE_PROTOCOL; break;
#endif
        case EAI_SERVICE:   msg = PIE_SERVICE; break;
        case EAI_SOCKTYPE:  msg = PIE_SOCKTYPE; break;
        case EAI_SYSTEM:    msg = strerror(errno); break;
        default:            msg = gai_strerror(err); break;
    }
    return msg;
}
451 
452