1 /*
2  * redir.c - Provide a transparent TCP proxy through remote shadowsocks
3  *           server
4  *
5  * Copyright (C) 2013 - 2019, Max Lv <max.c.lv@gmail.com>
6  *
7  * This file is part of the shadowsocks-libev.
8  *
9  * shadowsocks-libev is free software; you can redistribute it and/or modify
10  * it under the terms of the GNU General Public License as published by
11  * the Free Software Foundation; either version 3 of the License, or
12  * (at your option) any later version.
13  *
14  * shadowsocks-libev is distributed in the hope that it will be useful,
15  * but WITHOUT ANY WARRANTY; without even the implied warranty of
16  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17  * GNU General Public License for more details.
18  *
19  * You should have received a copy of the GNU General Public License
20  * along with shadowsocks-libev; see the file COPYING. If not, see
21  * <http://www.gnu.org/licenses/>.
22  */
23 
24 #include <sys/stat.h>
25 #include <sys/types.h>
26 #include <arpa/inet.h>
27 #include <errno.h>
28 #include <fcntl.h>
29 #include <locale.h>
30 #include <netdb.h>
31 #include <netinet/in.h>
32 #include <pthread.h>
33 #include <signal.h>
34 #include <string.h>
35 #include <strings.h>
36 #include <time.h>
37 #include <unistd.h>
38 #include <getopt.h>
39 #include <limits.h>
40 #include <linux/if.h>
41 #include <linux/netfilter_ipv4.h>
42 #include <linux/netfilter_ipv6/ip6_tables.h>
43 
44 #include <libcork/core.h>
45 
46 #ifdef HAVE_CONFIG_H
47 #include "config.h"
48 #endif
49 
50 #include "plugin.h"
51 #include "netutils.h"
52 #include "utils.h"
53 #include "common.h"
54 #include "redir.h"
55 
/*
 * errno portability: some platforms define only one of EAGAIN/EWOULDBLOCK.
 * Alias the missing one to the other so "would block" checks can test both.
 */
#if !defined(EAGAIN)
#define EAGAIN EWOULDBLOCK
#endif
#if !defined(EWOULDBLOCK)
#define EWOULDBLOCK EAGAIN
#endif

/*
 * Linux socket option numbers used by this file.  Fall back to the raw
 * kernel values when the toolchain headers are too old to define them.
 */
#if !defined(IP6T_SO_ORIGINAL_DST)
#define IP6T_SO_ORIGINAL_DST 80 /* netfilter: original IPv6 destination */
#endif
#if !defined(IP_TRANSPARENT)
#define IP_TRANSPARENT       19 /* IPv4 transparent proxy (tproxy) bind */
#endif
#if !defined(IPV6_TRANSPARENT)
#define IPV6_TRANSPARENT     75 /* IPv6 transparent proxy (tproxy) bind */
#endif
75 
76 static void accept_cb(EV_P_ ev_io *w, int revents);
77 static void server_recv_cb(EV_P_ ev_io *w, int revents);
78 static void server_send_cb(EV_P_ ev_io *w, int revents);
79 static void remote_recv_cb(EV_P_ ev_io *w, int revents);
80 static void remote_send_cb(EV_P_ ev_io *w, int revents);
81 
82 static remote_t *new_remote(int fd, int timeout);
83 static server_t *new_server(int fd);
84 
85 static void free_remote(remote_t *remote);
86 static void close_and_free_remote(EV_P_ remote_t *remote);
87 static void free_server(server_t *server);
88 static void close_and_free_server(EV_P_ server_t *server);
89 
90 int verbose    = 0;
91 int reuse_port = 0;
92 
93 static crypto_t *crypto;
94 
95 static int ipv6first = 0;
96 static int mode      = TCP_ONLY;
97 #ifdef HAVE_SETRLIMIT
98 static int nofile = 0;
99 #endif
100 int fast_open       = 0;
101 static int no_delay = 0;
102 static int ret_val  = 0;
103 
104 static struct ev_signal sigint_watcher;
105 static struct ev_signal sigterm_watcher;
106 static struct ev_signal sigchld_watcher;
107 
108 static int tcp_tproxy = 0; /* use tproxy instead of redirect (for tcp) */
109 
110 static int
getdestaddr(int fd,struct sockaddr_storage * destaddr)111 getdestaddr(int fd, struct sockaddr_storage *destaddr)
112 {
113     socklen_t socklen = sizeof(*destaddr);
114     int error         = 0;
115 
116     if (tcp_tproxy) {
117         error = getsockname(fd, (void *)destaddr, &socklen);
118     } else {
119         error = getsockopt(fd, SOL_IPV6, IP6T_SO_ORIGINAL_DST, destaddr, &socklen);
120         if (error) { // Didn't find a proper way to detect IP version.
121             error = getsockopt(fd, SOL_IP, SO_ORIGINAL_DST, destaddr, &socklen);
122         }
123     }
124 
125     if (error) {
126         return -1;
127     }
128     return 0;
129 }
130 
/*
 * Put fd into non-blocking mode, keeping its other file status flags.
 * If the current flags cannot be read, O_NONBLOCK alone is applied.
 * Returns the result of fcntl(F_SETFL): 0 on success, -1 on error.
 */
int
setnonblocking(int fd)
{
    int current = fcntl(fd, F_GETFL, 0);
    int base    = (current == -1) ? 0 : current;

    return fcntl(fd, F_SETFL, base | O_NONBLOCK);
}
140 
/*
 * Resolve addr/port for TCP and return a socket bound to the first usable
 * result, or -1 if nothing could be bound.  The socket is bound but not yet
 * listening.
 *
 * Each candidate socket gets SO_REUSEADDR (plus SO_NOSIGPIPE where it
 * exists).  SO_REUSEPORT is attempted when reuse_port is set.  In tproxy
 * mode IP_TRANSPARENT / IPV6_TRANSPARENT is mandatory: failure to set it
 * terminates the process.
 */
int
create_and_bind(const char *addr, const char *port)
{
    struct addrinfo hints;
    struct addrinfo *candidates = NULL;
    int fd = -1;

    memset(&hints, 0, sizeof(hints));
    hints.ai_family   = AF_UNSPEC;   /* Return IPv4 and IPv6 choices */
    hints.ai_socktype = SOCK_STREAM; /* We want a TCP socket */

    int gai_rc = getaddrinfo(addr, port, &hints, &candidates);
    if (gai_rc != 0) {
        LOGI("getaddrinfo: %s", gai_strerror(gai_rc));
        return -1;
    }

    if (candidates == NULL) {
        LOGE("Could not bind");
        return -1;
    }

    for (struct addrinfo *ai = candidates; ai != NULL; ai = ai->ai_next) {
        fd = socket(ai->ai_family, ai->ai_socktype, ai->ai_protocol);
        if (fd == -1) {
            continue;
        }

        int one = 1;
        setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &one, sizeof(one));
#ifdef SO_NOSIGPIPE
        setsockopt(fd, SOL_SOCKET, SO_NOSIGPIPE, &one, sizeof(one));
#endif
        if (reuse_port && set_reuseport(fd) == 0) {
            LOGI("tcp port reuse enabled");
        }

        if (tcp_tproxy) {
            int is_v4   = (ai->ai_family == AF_INET);
            int level   = is_v4 ? IPPROTO_IP : IPPROTO_IPV6;
            int optname = is_v4 ? IP_TRANSPARENT : IPV6_TRANSPARENT;

            if (setsockopt(fd, level, optname, &one, sizeof(one)) != 0) {
                ERROR("setsockopt IP_TRANSPARENT");
                exit(EXIT_FAILURE);
            }
            LOGI("tcp tproxy mode enabled");
        }

        if (bind(fd, ai->ai_addr, ai->ai_addrlen) == 0) {
            break; /* We managed to bind successfully! */
        }

        /* Report bind()'s errno before close() can overwrite it. */
        ERROR("bind");
        close(fd);
        fd = -1;
    }

    freeaddrinfo(candidates);

    return fd;
}
216 
217 static void
server_recv_cb(EV_P_ ev_io * w,int revents)218 server_recv_cb(EV_P_ ev_io *w, int revents)
219 {
220     server_ctx_t *server_recv_ctx = (server_ctx_t *)w;
221     server_t *server              = server_recv_ctx->server;
222     remote_t *remote              = server->remote;
223 
224     ev_timer_stop(EV_A_ & server->delayed_connect_watcher);
225 
226     ssize_t r = recv(server->fd, remote->buf->data + remote->buf->len,
227                      SOCKET_BUF_SIZE - remote->buf->len, 0);
228 
229     if (r == 0) {
230         // connection closed
231         close_and_free_remote(EV_A_ remote);
232         close_and_free_server(EV_A_ server);
233         return;
234     } else if (r == -1) {
235         if (errno == EAGAIN || errno == EWOULDBLOCK) {
236             // no data
237             // continue to wait for recv
238             return;
239         } else {
240             ERROR("server recv");
241             close_and_free_remote(EV_A_ remote);
242             close_and_free_server(EV_A_ server);
243             return;
244         }
245     }
246 
247     remote->buf->len += r;
248 
249     if (verbose) {
250         uint16_t port = 0;
251         char ipstr[INET6_ADDRSTRLEN];
252         memset(&ipstr, 0, INET6_ADDRSTRLEN);
253 
254         if (AF_INET == server->destaddr.ss_family) {
255             struct sockaddr_in *sa = (struct sockaddr_in *)&(server->destaddr);
256             inet_ntop(AF_INET, &(sa->sin_addr), ipstr, INET_ADDRSTRLEN);
257             port = ntohs(sa->sin_port);
258         } else {
259             struct sockaddr_in6 *sa = (struct sockaddr_in6 *)&(server->destaddr);
260             inet_ntop(AF_INET6, &(sa->sin6_addr), ipstr, INET6_ADDRSTRLEN);
261             port = ntohs(sa->sin6_port);
262         }
263 
264         LOGI("redir to %s:%d, len=%zu, recv=%zd", ipstr, port, remote->buf->len, r);
265     }
266 
267     if (!remote->send_ctx->connected) {
268         ev_io_stop(EV_A_ & server_recv_ctx->io);
269         ev_io_start(EV_A_ & remote->send_ctx->io);
270         return;
271     }
272 
273     int err = crypto->encrypt(remote->buf, server->e_ctx, SOCKET_BUF_SIZE);
274 
275     if (err) {
276         LOGE("invalid password or cipher");
277         close_and_free_remote(EV_A_ remote);
278         close_and_free_server(EV_A_ server);
279         return;
280     }
281 
282     int s = send(remote->fd, remote->buf->data, remote->buf->len, 0);
283 
284     if (s == -1) {
285         if (errno == EAGAIN || errno == EWOULDBLOCK) {
286             // no data, wait for send
287             remote->buf->idx = 0;
288             ev_io_stop(EV_A_ & server_recv_ctx->io);
289             ev_io_start(EV_A_ & remote->send_ctx->io);
290             return;
291         } else {
292             ERROR("send");
293             close_and_free_remote(EV_A_ remote);
294             close_and_free_server(EV_A_ server);
295             return;
296         }
297     } else if (s < remote->buf->len) {
298         remote->buf->len -= s;
299         remote->buf->idx  = s;
300         ev_io_stop(EV_A_ & server_recv_ctx->io);
301         ev_io_start(EV_A_ & remote->send_ctx->io);
302         return;
303     } else {
304         remote->buf->idx = 0;
305         remote->buf->len = 0;
306     }
307 }
308 
309 static void
server_send_cb(EV_P_ ev_io * w,int revents)310 server_send_cb(EV_P_ ev_io *w, int revents)
311 {
312     server_ctx_t *server_send_ctx = (server_ctx_t *)w;
313     server_t *server              = server_send_ctx->server;
314     remote_t *remote              = server->remote;
315     if (server->buf->len == 0) {
316         // close and free
317         close_and_free_remote(EV_A_ remote);
318         close_and_free_server(EV_A_ server);
319         return;
320     } else {
321         // has data to send
322         ssize_t s = send(server->fd, server->buf->data + server->buf->idx,
323                          server->buf->len, 0);
324         if (s == -1) {
325             if (errno != EAGAIN && errno != EWOULDBLOCK) {
326                 ERROR("send");
327                 close_and_free_remote(EV_A_ remote);
328                 close_and_free_server(EV_A_ server);
329             }
330             return;
331         } else if (s < server->buf->len) {
332             // partly sent, move memory, wait for the next time to send
333             server->buf->len -= s;
334             server->buf->idx += s;
335             return;
336         } else {
337             // all sent out, wait for reading
338             server->buf->len = 0;
339             server->buf->idx = 0;
340             ev_io_stop(EV_A_ & server_send_ctx->io);
341             ev_io_start(EV_A_ & remote->recv_ctx->io);
342         }
343     }
344 }
345 
346 static void
delayed_connect_cb(EV_P_ ev_timer * watcher,int revents)347 delayed_connect_cb(EV_P_ ev_timer *watcher, int revents)
348 {
349     server_t *server = cork_container_of(watcher, server_t,
350                                          delayed_connect_watcher);
351     remote_t *remote = server->remote;
352 
353     int r = connect(remote->fd, remote->addr,
354                     get_sockaddr_len(remote->addr));
355 
356     remote->addr = NULL;
357 
358     if (r == -1 && errno != CONNECT_IN_PROGRESS) {
359         ERROR("connect");
360         close_and_free_remote(EV_A_ remote);
361         close_and_free_server(EV_A_ server);
362         return;
363     } else {
364         // listen to remote connected event
365         ev_io_start(EV_A_ & remote->send_ctx->io);
366         ev_timer_start(EV_A_ & remote->send_ctx->watcher);
367     }
368 }
369 
370 static void
remote_timeout_cb(EV_P_ ev_timer * watcher,int revents)371 remote_timeout_cb(EV_P_ ev_timer *watcher, int revents)
372 {
373     remote_ctx_t *remote_ctx
374         = cork_container_of(watcher, remote_ctx_t, watcher);
375 
376     remote_t *remote = remote_ctx->remote;
377     server_t *server = remote->server;
378 
379     ev_timer_stop(EV_A_ watcher);
380 
381     close_and_free_remote(EV_A_ remote);
382     close_and_free_server(EV_A_ server);
383 }
384 
385 static void
remote_recv_cb(EV_P_ ev_io * w,int revents)386 remote_recv_cb(EV_P_ ev_io *w, int revents)
387 {
388     remote_ctx_t *remote_recv_ctx = (remote_ctx_t *)w;
389     remote_t *remote              = remote_recv_ctx->remote;
390     server_t *server              = remote->server;
391 
392     ssize_t r = recv(remote->fd, server->buf->data, SOCKET_BUF_SIZE, 0);
393 
394     if (r == 0) {
395         // connection closed
396         close_and_free_remote(EV_A_ remote);
397         close_and_free_server(EV_A_ server);
398         return;
399     } else if (r == -1) {
400         if (errno == EAGAIN || errno == EWOULDBLOCK) {
401             // no data
402             // continue to wait for recv
403             return;
404         } else {
405             ERROR("remote recv");
406             close_and_free_remote(EV_A_ remote);
407             close_and_free_server(EV_A_ server);
408             return;
409         }
410     }
411 
412     server->buf->len = r;
413 
414     int err = crypto->decrypt(server->buf, server->d_ctx, SOCKET_BUF_SIZE);
415     if (err == CRYPTO_ERROR) {
416         LOGE("invalid password or cipher");
417         close_and_free_remote(EV_A_ remote);
418         close_and_free_server(EV_A_ server);
419         return;
420     } else if (err == CRYPTO_NEED_MORE) {
421         return; // Wait for more
422     }
423 
424     int s = send(server->fd, server->buf->data, server->buf->len, 0);
425 
426     if (s == -1) {
427         if (errno == EAGAIN || errno == EWOULDBLOCK) {
428             // no data, wait for send
429             server->buf->idx = 0;
430             ev_io_stop(EV_A_ & remote_recv_ctx->io);
431             ev_io_start(EV_A_ & server->send_ctx->io);
432         } else {
433             ERROR("send");
434             close_and_free_remote(EV_A_ remote);
435             close_and_free_server(EV_A_ server);
436             return;
437         }
438     } else if (s < server->buf->len) {
439         server->buf->len -= s;
440         server->buf->idx  = s;
441         ev_io_stop(EV_A_ & remote_recv_ctx->io);
442         ev_io_start(EV_A_ & server->send_ctx->io);
443     }
444 
445     // Disable TCP_NODELAY after the first response are sent
446     if (!remote->recv_ctx->connected && !no_delay) {
447         int opt = 0;
448         setsockopt(server->fd, SOL_TCP, TCP_NODELAY, &opt, sizeof(opt));
449         setsockopt(remote->fd, SOL_TCP, TCP_NODELAY, &opt, sizeof(opt));
450     }
451     remote->recv_ctx->connected = 1;
452 }
453 
454 static void
remote_send_cb(EV_P_ ev_io * w,int revents)455 remote_send_cb(EV_P_ ev_io *w, int revents)
456 {
457     remote_ctx_t *remote_send_ctx = (remote_ctx_t *)w;
458     remote_t *remote              = remote_send_ctx->remote;
459     server_t *server              = remote->server;
460 
461     ev_timer_stop(EV_A_ & remote_send_ctx->watcher);
462 
463     if (!remote_send_ctx->connected) {
464         int r = 0;
465         if (remote->addr == NULL) {
466             struct sockaddr_storage addr;
467             memset(&addr, 0, sizeof(struct sockaddr_storage));
468             socklen_t len = sizeof addr;
469             r = getpeername(remote->fd, (struct sockaddr *)&addr, &len);
470         }
471         if (r == 0) {
472             remote_send_ctx->connected = 1;
473 
474             ev_io_stop(EV_A_ & remote_send_ctx->io);
475             ev_io_stop(EV_A_ & server->recv_ctx->io);
476             ev_io_start(EV_A_ & remote->recv_ctx->io);
477 
478             // send destaddr
479             buffer_t ss_addr_to_send;
480             buffer_t *abuf = &ss_addr_to_send;
481             balloc(abuf, SOCKET_BUF_SIZE);
482 
483             if (AF_INET6 == server->destaddr.ss_family) { // IPv6
484                 abuf->data[abuf->len++] = 4;          // Type 4 is IPv6 address
485 
486                 size_t in6_addr_len = sizeof(struct in6_addr);
487                 memcpy(abuf->data + abuf->len,
488                        &(((struct sockaddr_in6 *)&(server->destaddr))->sin6_addr),
489                        in6_addr_len);
490                 abuf->len += in6_addr_len;
491                 memcpy(abuf->data + abuf->len,
492                        &(((struct sockaddr_in6 *)&(server->destaddr))->sin6_port),
493                        2);
494             } else {                             // IPv4
495                 abuf->data[abuf->len++] = 1; // Type 1 is IPv4 address
496 
497                 size_t in_addr_len = sizeof(struct in_addr);
498                 memcpy(abuf->data + abuf->len,
499                        &((struct sockaddr_in *)&(server->destaddr))->sin_addr, in_addr_len);
500                 abuf->len += in_addr_len;
501                 memcpy(abuf->data + abuf->len,
502                        &((struct sockaddr_in *)&(server->destaddr))->sin_port, 2);
503             }
504 
505             abuf->len += 2;
506 
507             int err = crypto->encrypt(abuf, server->e_ctx, SOCKET_BUF_SIZE);
508             if (err) {
509                 LOGE("invalid password or cipher");
510                 bfree(abuf);
511                 close_and_free_remote(EV_A_ remote);
512                 close_and_free_server(EV_A_ server);
513                 return;
514             }
515 
516             err = crypto->encrypt(remote->buf, server->e_ctx, SOCKET_BUF_SIZE);
517             if (err) {
518                 LOGE("invalid password or cipher");
519                 bfree(abuf);
520                 close_and_free_remote(EV_A_ remote);
521                 close_and_free_server(EV_A_ server);
522                 return;
523             }
524 
525             bprepend(remote->buf, abuf, SOCKET_BUF_SIZE);
526             bfree(abuf);
527         } else {
528             ERROR("getpeername");
529             // not connected
530             close_and_free_remote(EV_A_ remote);
531             close_and_free_server(EV_A_ server);
532             return;
533         }
534     }
535 
536     if (remote->buf->len == 0) {
537         // close and free
538         close_and_free_remote(EV_A_ remote);
539         close_and_free_server(EV_A_ server);
540         return;
541     } else {
542         // has data to send
543         int s = -1;
544 
545         if (remote->addr != NULL) {
546 #if defined(TCP_FASTOPEN_CONNECT)
547             int optval = 1;
548             if (setsockopt(remote->fd, IPPROTO_TCP, TCP_FASTOPEN_CONNECT,
549                            (void *)&optval, sizeof(optval)) < 0)
550                 FATAL("failed to set TCP_FASTOPEN_CONNECT");
551             s = connect(remote->fd, remote->addr, get_sockaddr_len(remote->addr));
552             if (s == 0)
553                 s = send(remote->fd, remote->buf->data, remote->buf->len, 0);
554 #elif defined(MSG_FASTOPEN)
555             s = sendto(remote->fd, remote->buf->data + remote->buf->idx,
556                        remote->buf->len, MSG_FASTOPEN, remote->addr,
557                        get_sockaddr_len(remote->addr));
558 #else
559             FATAL("tcp fast open is not supported on this platform");
560 #endif
561 
562             remote->addr = NULL;
563 
564             if (s == -1) {
565                 if (errno == CONNECT_IN_PROGRESS) {
566                     ev_io_start(EV_A_ & remote_send_ctx->io);
567                     ev_timer_start(EV_A_ & remote_send_ctx->watcher);
568                 } else {
569                     if (errno == EOPNOTSUPP || errno == EPROTONOSUPPORT ||
570                         errno == ENOPROTOOPT) {
571                         fast_open = 0;
572                         LOGE("fast open is not supported on this platform");
573                     } else {
574                         ERROR("fast_open_connect");
575                     }
576                     close_and_free_remote(EV_A_ remote);
577                     close_and_free_server(EV_A_ server);
578                 }
579                 return;
580             }
581         } else {
582             s = send(remote->fd, remote->buf->data + remote->buf->idx,
583                      remote->buf->len, 0);
584         }
585 
586         if (s == -1) {
587             if (errno != EAGAIN && errno != EWOULDBLOCK) {
588                 ERROR("send");
589                 // close and free
590                 close_and_free_remote(EV_A_ remote);
591                 close_and_free_server(EV_A_ server);
592             }
593             return;
594         } else if (s < remote->buf->len) {
595             // partly sent, move memory, wait for the next time to send
596             remote->buf->len -= s;
597             remote->buf->idx += s;
598             ev_io_start(EV_A_ & remote_send_ctx->io);
599             return;
600         } else {
601             // all sent out, wait for reading
602             remote->buf->len = 0;
603             remote->buf->idx = 0;
604             ev_io_stop(EV_A_ & remote_send_ctx->io);
605             ev_io_start(EV_A_ & server->recv_ctx->io);
606         }
607     }
608 }
609 
610 static remote_t *
new_remote(int fd,int timeout)611 new_remote(int fd, int timeout)
612 {
613     remote_t *remote = ss_malloc(sizeof(remote_t));
614     memset(remote, 0, sizeof(remote_t));
615 
616     remote->recv_ctx = ss_malloc(sizeof(remote_ctx_t));
617     remote->send_ctx = ss_malloc(sizeof(remote_ctx_t));
618     remote->buf      = ss_malloc(sizeof(buffer_t));
619     balloc(remote->buf, SOCKET_BUF_SIZE);
620     memset(remote->recv_ctx, 0, sizeof(remote_ctx_t));
621     memset(remote->send_ctx, 0, sizeof(remote_ctx_t));
622     remote->fd                  = fd;
623     remote->recv_ctx->remote    = remote;
624     remote->recv_ctx->connected = 0;
625     remote->send_ctx->remote    = remote;
626     remote->send_ctx->connected = 0;
627 
628     ev_io_init(&remote->recv_ctx->io, remote_recv_cb, fd, EV_READ);
629     ev_io_init(&remote->send_ctx->io, remote_send_cb, fd, EV_WRITE);
630     ev_timer_init(&remote->send_ctx->watcher, remote_timeout_cb,
631                   min(MAX_CONNECT_TIMEOUT, timeout), 0);
632 
633     return remote;
634 }
635 
636 static void
free_remote(remote_t * remote)637 free_remote(remote_t *remote)
638 {
639     if (remote->server != NULL) {
640         remote->server->remote = NULL;
641     }
642     if (remote->buf != NULL) {
643         bfree(remote->buf);
644         ss_free(remote->buf);
645     }
646     ss_free(remote->recv_ctx);
647     ss_free(remote->send_ctx);
648     ss_free(remote);
649 }
650 
651 static void
close_and_free_remote(EV_P_ remote_t * remote)652 close_and_free_remote(EV_P_ remote_t *remote)
653 {
654     if (remote != NULL) {
655         ev_timer_stop(EV_A_ & remote->send_ctx->watcher);
656         ev_io_stop(EV_A_ & remote->send_ctx->io);
657         ev_io_stop(EV_A_ & remote->recv_ctx->io);
658         close(remote->fd);
659         free_remote(remote);
660     }
661 }
662 
663 static server_t *
new_server(int fd)664 new_server(int fd)
665 {
666     server_t *server = ss_malloc(sizeof(server_t));
667     memset(server, 0, sizeof(server_t));
668 
669     server->recv_ctx = ss_malloc(sizeof(server_ctx_t));
670     server->send_ctx = ss_malloc(sizeof(server_ctx_t));
671     server->buf      = ss_malloc(sizeof(buffer_t));
672     balloc(server->buf, SOCKET_BUF_SIZE);
673     memset(server->recv_ctx, 0, sizeof(server_ctx_t));
674     memset(server->send_ctx, 0, sizeof(server_ctx_t));
675     server->fd                  = fd;
676     server->recv_ctx->server    = server;
677     server->recv_ctx->connected = 0;
678     server->send_ctx->server    = server;
679     server->send_ctx->connected = 0;
680 
681     server->e_ctx = ss_malloc(sizeof(cipher_ctx_t));
682     server->d_ctx = ss_malloc(sizeof(cipher_ctx_t));
683     crypto->ctx_init(crypto->cipher, server->e_ctx, 1);
684     crypto->ctx_init(crypto->cipher, server->d_ctx, 0);
685 
686     ev_io_init(&server->recv_ctx->io, server_recv_cb, fd, EV_READ);
687     ev_io_init(&server->send_ctx->io, server_send_cb, fd, EV_WRITE);
688 
689     ev_timer_init(&server->delayed_connect_watcher, delayed_connect_cb, 0.05,
690                   0);
691 
692     return server;
693 }
694 
695 static void
free_server(server_t * server)696 free_server(server_t *server)
697 {
698     if (server->remote != NULL) {
699         server->remote->server = NULL;
700     }
701     if (server->e_ctx != NULL) {
702         crypto->ctx_release(server->e_ctx);
703         ss_free(server->e_ctx);
704     }
705     if (server->d_ctx != NULL) {
706         crypto->ctx_release(server->d_ctx);
707         ss_free(server->d_ctx);
708     }
709     if (server->buf != NULL) {
710         bfree(server->buf);
711         ss_free(server->buf);
712     }
713     ss_free(server->recv_ctx);
714     ss_free(server->send_ctx);
715     ss_free(server);
716 }
717 
718 static void
close_and_free_server(EV_P_ server_t * server)719 close_and_free_server(EV_P_ server_t *server)
720 {
721     if (server != NULL) {
722         ev_io_stop(EV_A_ & server->send_ctx->io);
723         ev_io_stop(EV_A_ & server->recv_ctx->io);
724         ev_timer_stop(EV_A_ & server->delayed_connect_watcher);
725         close(server->fd);
726         free_server(server);
727     }
728 }
729 
730 static void
accept_cb(EV_P_ ev_io * w,int revents)731 accept_cb(EV_P_ ev_io *w, int revents)
732 {
733     listen_ctx_t *listener = (listen_ctx_t *)w;
734     struct sockaddr_storage destaddr;
735     memset(&destaddr, 0, sizeof(struct sockaddr_storage));
736 
737     int err;
738 
739     int serverfd = accept(listener->fd, NULL, NULL);
740     if (serverfd == -1) {
741         ERROR("accept");
742         return;
743     }
744 
745     err = getdestaddr(serverfd, &destaddr);
746     if (err) {
747         ERROR("getdestaddr");
748         return;
749     }
750 
751     setnonblocking(serverfd);
752     int opt = 1;
753     setsockopt(serverfd, SOL_TCP, TCP_NODELAY, &opt, sizeof(opt));
754 #ifdef SO_NOSIGPIPE
755     setsockopt(serverfd, SOL_SOCKET, SO_NOSIGPIPE, &opt, sizeof(opt));
756 #endif
757 
758     int index                    = rand() % listener->remote_num;
759     struct sockaddr *remote_addr = listener->remote_addr[index];
760 
761     int remotefd = socket(remote_addr->sa_family, SOCK_STREAM, IPPROTO_TCP);
762     if (remotefd == -1) {
763         ERROR("socket");
764         return;
765     }
766 
767     // Set flags
768     setsockopt(remotefd, SOL_TCP, TCP_NODELAY, &opt, sizeof(opt));
769 #ifdef SO_NOSIGPIPE
770     setsockopt(remotefd, SOL_SOCKET, SO_NOSIGPIPE, &opt, sizeof(opt));
771 #endif
772 
773     // Enable TCP keepalive feature
774     int keepAlive    = 1;
775     int keepIdle     = 40;
776     int keepInterval = 20;
777     int keepCount    = 5;
778     setsockopt(remotefd, SOL_SOCKET, SO_KEEPALIVE, (void *)&keepAlive, sizeof(keepAlive));
779     setsockopt(remotefd, SOL_TCP, TCP_KEEPIDLE, (void *)&keepIdle, sizeof(keepIdle));
780     setsockopt(remotefd, SOL_TCP, TCP_KEEPINTVL, (void *)&keepInterval, sizeof(keepInterval));
781     setsockopt(remotefd, SOL_TCP, TCP_KEEPCNT, (void *)&keepCount, sizeof(keepCount));
782 
783     // Set non blocking
784     setnonblocking(remotefd);
785 
786     if (listener->tos >= 0) {
787         int rc = setsockopt(remotefd, IPPROTO_IP, IP_TOS, &listener->tos, sizeof(listener->tos));
788         if (rc < 0 && errno != ENOPROTOOPT) {
789             LOGE("setting ipv4 dscp failed: %d", errno);
790         }
791         rc = setsockopt(remotefd, IPPROTO_IPV6, IPV6_TCLASS, &listener->tos, sizeof(listener->tos));
792         if (rc < 0 && errno != ENOPROTOOPT) {
793             LOGE("setting ipv6 dscp failed: %d", errno);
794         }
795     }
796 
797     // Enable MPTCP
798     if (listener->mptcp > 1) {
799         int err = setsockopt(remotefd, SOL_TCP, listener->mptcp, &opt, sizeof(opt));
800         if (err == -1) {
801             ERROR("failed to enable multipath TCP");
802         }
803     } else if (listener->mptcp == 1) {
804         int i = 0;
805         while ((listener->mptcp = mptcp_enabled_values[i]) > 0) {
806             int err = setsockopt(remotefd, SOL_TCP, listener->mptcp, &opt, sizeof(opt));
807             if (err != -1) {
808                 break;
809             }
810             i++;
811         }
812         if (listener->mptcp == 0) {
813             ERROR("failed to enable multipath TCP");
814         }
815     }
816 
817     server_t *server = new_server(serverfd);
818     remote_t *remote = new_remote(remotefd, listener->timeout);
819     server->remote   = remote;
820     remote->server   = server;
821     server->destaddr = destaddr;
822 
823     if (fast_open) {
824         // save remote addr for fast open
825         remote->addr = remote_addr;
826         ev_timer_start(EV_A_ & server->delayed_connect_watcher);
827     } else {
828         int r = connect(remotefd, remote_addr, get_sockaddr_len(remote_addr));
829 
830         if (r == -1 && errno != CONNECT_IN_PROGRESS) {
831             ERROR("connect");
832             close_and_free_remote(EV_A_ remote);
833             close_and_free_server(EV_A_ server);
834             return;
835         }
836         // listen to remote connected event
837         ev_io_start(EV_A_ & remote->send_ctx->io);
838         ev_timer_start(EV_A_ & remote->send_ctx->watcher);
839     }
840     ev_io_start(EV_A_ & server->recv_ctx->io);
841 }
842 
843 static void
signal_cb(EV_P_ ev_signal * w,int revents)844 signal_cb(EV_P_ ev_signal *w, int revents)
845 {
846     if (revents & EV_SIGNAL) {
847         switch (w->signum) {
848         case SIGCHLD:
849             if (!is_plugin_running()) {
850                 LOGE("plugin service exit unexpectedly");
851                 ret_val = -1;
852             } else
853                 return;
854         case SIGINT:
855         case SIGTERM:
856             ev_signal_stop(EV_DEFAULT, &sigint_watcher);
857             ev_signal_stop(EV_DEFAULT, &sigterm_watcher);
858             ev_signal_stop(EV_DEFAULT, &sigchld_watcher);
859 
860             ev_unloop(EV_A_ EVUNLOOP_ALL);
861         }
862     }
863 }
864 
865 int
main(int argc,char ** argv)866 main(int argc, char **argv)
867 {
868     srand(time(NULL));
869 
870     int i, c;
871     int pid_flags    = 0;
872     int mptcp        = 0;
873     int mtu          = 0;
874     char *user       = NULL;
875     char *local_port = NULL;
876     char *local_addr = NULL;
877     char *password   = NULL;
878     char *key        = NULL;
879     char *timeout    = NULL;
880     char *method     = NULL;
881     char *pid_path   = NULL;
882     char *conf_path  = NULL;
883 
884     char *plugin      = NULL;
885     char *plugin_opts = NULL;
886     char *plugin_host = NULL;
887     char *plugin_port = NULL;
888     char tmp_port[8];
889 
890     int dscp_num    = 0;
891     ss_dscp_t *dscp = NULL;
892 
893     int remote_num    = 0;
894     char *remote_port = NULL;
895     ss_addr_t remote_addr[MAX_REMOTE_NUM];
896 
897     memset(remote_addr, 0, sizeof(ss_addr_t) * MAX_REMOTE_NUM);
898 
899     static struct option long_options[] = {
900         { "fast-open",   no_argument,       NULL, GETOPT_VAL_FAST_OPEN   },
901         { "mtu",         required_argument, NULL, GETOPT_VAL_MTU         },
902         { "mptcp",       no_argument,       NULL, GETOPT_VAL_MPTCP       },
903         { "plugin",      required_argument, NULL, GETOPT_VAL_PLUGIN      },
904         { "plugin-opts", required_argument, NULL, GETOPT_VAL_PLUGIN_OPTS },
905         { "reuse-port",  no_argument,       NULL, GETOPT_VAL_REUSE_PORT  },
906         { "no-delay",    no_argument,       NULL, GETOPT_VAL_NODELAY     },
907         { "password",    required_argument, NULL, GETOPT_VAL_PASSWORD    },
908         { "key",         required_argument, NULL, GETOPT_VAL_KEY         },
909         { "help",        no_argument,       NULL, GETOPT_VAL_HELP        },
910         { NULL,          0,                 NULL, 0                      }
911     };
912 
913     opterr = 0;
914 
915     USE_TTY();
916 
917     while ((c = getopt_long(argc, argv, "f:s:p:l:k:t:m:c:b:a:n:huUTv6A",
918                             long_options, NULL)) != -1) {
919         switch (c) {
920         case GETOPT_VAL_FAST_OPEN:
921             fast_open = 1;
922             break;
923         case GETOPT_VAL_MTU:
924             mtu = atoi(optarg);
925             LOGI("set MTU to %d", mtu);
926             break;
927         case GETOPT_VAL_MPTCP:
928             mptcp = 1;
929             LOGI("enable multipath TCP");
930             break;
931         case GETOPT_VAL_NODELAY:
932             no_delay = 1;
933             LOGI("enable TCP no-delay");
934             break;
935         case GETOPT_VAL_PLUGIN:
936             plugin = optarg;
937             break;
938         case GETOPT_VAL_PLUGIN_OPTS:
939             plugin_opts = optarg;
940             break;
941         case GETOPT_VAL_KEY:
942             key = optarg;
943             break;
944         case GETOPT_VAL_REUSE_PORT:
945             reuse_port = 1;
946             break;
947         case 's':
948             if (remote_num < MAX_REMOTE_NUM) {
949                 parse_addr(optarg, &remote_addr[remote_num++]);
950             }
951             break;
952         case 'p':
953             remote_port = optarg;
954             break;
955         case 'l':
956             local_port = optarg;
957             break;
958         case GETOPT_VAL_PASSWORD:
959         case 'k':
960             password = optarg;
961             break;
962         case 'f':
963             pid_flags = 1;
964             pid_path  = optarg;
965             break;
966         case 't':
967             timeout = optarg;
968             break;
969         case 'm':
970             method = optarg;
971             break;
972         case 'c':
973             conf_path = optarg;
974             break;
975         case 'b':
976             local_addr = optarg;
977             break;
978         case 'a':
979             user = optarg;
980             break;
981 #ifdef HAVE_SETRLIMIT
982         case 'n':
983             nofile = atoi(optarg);
984             break;
985 #endif
986         case 'u':
987             mode = TCP_AND_UDP;
988             break;
989         case 'U':
990             mode = UDP_ONLY;
991             break;
992         case 'T':
993             tcp_tproxy = 1;
994             break;
995         case 'v':
996             verbose = 1;
997             break;
998         case GETOPT_VAL_HELP:
999         case 'h':
1000             usage();
1001             exit(EXIT_SUCCESS);
1002         case '6':
1003             ipv6first = 1;
1004             break;
1005         case 'A':
1006             FATAL("One time auth has been deprecated. Try AEAD ciphers instead.");
1007             break;
1008         case '?':
1009             // The option character is not recognized.
1010             LOGE("Unrecognized option: %s", optarg);
1011             opterr = 1;
1012             break;
1013         }
1014     }
1015 
1016     if (opterr) {
1017         usage();
1018         exit(EXIT_FAILURE);
1019     }
1020 
1021     if (argc == 1) {
1022         if (conf_path == NULL) {
1023             conf_path = get_default_conf();
1024         }
1025     }
1026 
1027     if (conf_path != NULL) {
1028         jconf_t *conf = read_jconf(conf_path);
1029         if (remote_num == 0) {
1030             remote_num = conf->remote_num;
1031             for (i = 0; i < remote_num; i++)
1032                 remote_addr[i] = conf->remote_addr[i];
1033         }
1034         if (remote_port == NULL) {
1035             remote_port = conf->remote_port;
1036         }
1037         if (local_addr == NULL) {
1038             local_addr = conf->local_addr;
1039         }
1040         if (local_port == NULL) {
1041             local_port = conf->local_port;
1042         }
1043         if (password == NULL) {
1044             password = conf->password;
1045         }
1046         if (key == NULL) {
1047             key = conf->key;
1048         }
1049         if (method == NULL) {
1050             method = conf->method;
1051         }
1052         if (timeout == NULL) {
1053             timeout = conf->timeout;
1054         }
1055         if (user == NULL) {
1056             user = conf->user;
1057         }
1058         if (plugin == NULL) {
1059             plugin = conf->plugin;
1060         }
1061         if (plugin_opts == NULL) {
1062             plugin_opts = conf->plugin_opts;
1063         }
1064         if (mode == TCP_ONLY) {
1065             mode = conf->mode;
1066         }
1067         if (tcp_tproxy == 0) {
1068             tcp_tproxy = conf->tcp_tproxy;
1069         }
1070         if (mtu == 0) {
1071             mtu = conf->mtu;
1072         }
1073         if (mptcp == 0) {
1074             mptcp = conf->mptcp;
1075         }
1076         if (no_delay == 0) {
1077             no_delay = conf->no_delay;
1078         }
1079         if (reuse_port == 0) {
1080             reuse_port = conf->reuse_port;
1081         }
1082         if (fast_open == 0) {
1083             fast_open = conf->fast_open;
1084         }
1085 #ifdef HAVE_SETRLIMIT
1086         if (nofile == 0) {
1087             nofile = conf->nofile;
1088         }
1089 #endif
1090         if (ipv6first == 0) {
1091             ipv6first = conf->ipv6_first;
1092         }
1093         dscp_num = conf->dscp_num;
1094         dscp     = conf->dscp;
1095     }
1096 
1097     if (remote_num == 0 || remote_port == NULL || local_port == NULL
1098         || (password == NULL && key == NULL)) {
1099         usage();
1100         exit(EXIT_FAILURE);
1101     }
1102 
1103     if (plugin != NULL) {
1104         uint16_t port = get_local_port();
1105         if (port == 0) {
1106             FATAL("failed to find a free port");
1107         }
1108         snprintf(tmp_port, 8, "%d", port);
1109         if (is_ipv6only(remote_addr, remote_num, ipv6first)) {
1110             plugin_host = "::1";
1111         } else {
1112             plugin_host = "127.0.0.1";
1113         }
1114         plugin_port = tmp_port;
1115 
1116         LOGI("plugin \"%s\" enabled", plugin);
1117     }
1118 
1119     if (method == NULL) {
1120         method = "chacha20-ietf-poly1305";
1121     }
1122 
1123     if (timeout == NULL) {
1124         timeout = "600";
1125     }
1126 
1127 #ifdef HAVE_SETRLIMIT
1128     /*
1129      * no need to check the return value here since we will show
1130      * the user an error message if setrlimit(2) fails
1131      */
1132     if (nofile > 1024) {
1133         if (verbose) {
1134             LOGI("setting NOFILE to %d", nofile);
1135         }
1136         set_nofile(nofile);
1137     }
1138 #endif
1139 
1140     if (local_addr == NULL) {
1141         if (is_ipv6only(remote_addr, remote_num, ipv6first)) {
1142             local_addr = "::1";
1143         } else {
1144             local_addr = "127.0.0.1";
1145         }
1146     }
1147 
1148     if (fast_open == 1) {
1149 #ifdef TCP_FASTOPEN
1150         LOGI("using tcp fast open");
1151 #else
1152         LOGE("tcp fast open is not supported by this environment");
1153         fast_open = 0;
1154 #endif
1155     }
1156 
1157     USE_SYSLOG(argv[0], pid_flags);
1158     if (pid_flags) {
1159         daemonize(pid_path);
1160     }
1161 
1162     if (no_delay) {
1163         LOGI("enable TCP no-delay");
1164     }
1165 
1166     if (ipv6first) {
1167         LOGI("resolving hostname to IPv6 address first");
1168     }
1169 
1170     if (plugin != NULL) {
1171         int len          = 0;
1172         size_t buf_size  = 256 * remote_num;
1173         char *remote_str = ss_malloc(buf_size);
1174 
1175         snprintf(remote_str, buf_size, "%s", remote_addr[0].host);
1176         for (int i = 1; i < remote_num; i++) {
1177             snprintf(remote_str + len, buf_size - len, "|%s", remote_addr[i].host);
1178             len = strlen(remote_str);
1179         }
1180         int err = start_plugin(plugin, plugin_opts, remote_str,
1181                                remote_port, plugin_host, plugin_port, MODE_CLIENT);
1182         if (err) {
1183             FATAL("failed to start the plugin");
1184         }
1185     }
1186 
1187     // ignore SIGPIPE
1188     signal(SIGPIPE, SIG_IGN);
1189     signal(SIGABRT, SIG_IGN);
1190 
1191     ev_signal_init(&sigint_watcher, signal_cb, SIGINT);
1192     ev_signal_init(&sigterm_watcher, signal_cb, SIGTERM);
1193     ev_signal_init(&sigchld_watcher, signal_cb, SIGCHLD);
1194     ev_signal_start(EV_DEFAULT, &sigint_watcher);
1195     ev_signal_start(EV_DEFAULT, &sigterm_watcher);
1196     ev_signal_start(EV_DEFAULT, &sigchld_watcher);
1197 
1198     // Setup keys
1199     LOGI("initializing ciphers... %s", method);
1200     crypto = crypto_init(password, key, method);
1201     if (crypto == NULL)
1202         FATAL("failed to initialize ciphers");
1203 
1204     // Setup proxy context
1205     struct listen_ctx listen_ctx;
1206     memset(&listen_ctx, 0, sizeof(struct listen_ctx));
1207     listen_ctx.remote_num  = remote_num;
1208     listen_ctx.remote_addr = ss_malloc(sizeof(struct sockaddr *) * remote_num);
1209     memset(listen_ctx.remote_addr, 0, sizeof(struct sockaddr *) * remote_num);
1210     for (i = 0; i < remote_num; i++) {
1211         char *host = remote_addr[i].host;
1212         char *port = remote_addr[i].port == NULL ? remote_port :
1213                      remote_addr[i].port;
1214         if (plugin != NULL) {
1215             host = plugin_host;
1216             port = plugin_port;
1217         }
1218         struct sockaddr_storage *storage = ss_malloc(sizeof(struct sockaddr_storage));
1219         memset(storage, 0, sizeof(struct sockaddr_storage));
1220         if (get_sockaddr(host, port, storage, 1, ipv6first) == -1) {
1221             FATAL("failed to resolve the provided hostname");
1222         }
1223         listen_ctx.remote_addr[i] = (struct sockaddr *)storage;
1224 
1225         if (plugin != NULL)
1226             break;
1227     }
1228     listen_ctx.timeout = atoi(timeout);
1229     listen_ctx.mptcp   = mptcp;
1230 
1231     struct ev_loop *loop = EV_DEFAULT;
1232 
1233     listen_ctx_t *listen_ctx_current = &listen_ctx;
1234     do {
1235         if (listen_ctx_current->tos) {
1236             LOGI("listening at %s:%s (TOS 0x%x)", local_addr, local_port, listen_ctx_current->tos);
1237         } else {
1238             LOGI("listening at %s:%s", local_addr, local_port);
1239         }
1240 
1241         if (mode != UDP_ONLY) {
1242             // Setup socket
1243             int listenfd;
1244             listenfd = create_and_bind(local_addr, local_port);
1245             if (listenfd == -1) {
1246                 FATAL("bind() error");
1247             }
1248             if (listen(listenfd, SOMAXCONN) == -1) {
1249                 FATAL("listen() error");
1250             }
1251             setnonblocking(listenfd);
1252 
1253             listen_ctx_current->fd = listenfd;
1254 
1255             ev_io_init(&listen_ctx_current->io, accept_cb, listenfd, EV_READ);
1256             ev_io_start(loop, &listen_ctx_current->io);
1257         }
1258 
1259         // Setup UDP
1260         if (mode != TCP_ONLY) {
1261             LOGI("UDP relay enabled");
1262             char *host                       = remote_addr[0].host;
1263             char *port                       = remote_addr[0].port == NULL ? remote_port : remote_addr[0].port;
1264             struct sockaddr_storage *storage = ss_malloc(sizeof(struct sockaddr_storage));
1265             memset(storage, 0, sizeof(struct sockaddr_storage));
1266             if (get_sockaddr(host, port, storage, 1, ipv6first) == -1) {
1267                 FATAL("failed to resolve the provided hostname");
1268             }
1269             struct sockaddr *addr = (struct sockaddr *)storage;
1270             init_udprelay(local_addr, local_port, addr,
1271                           get_sockaddr_len(addr), mtu, crypto, listen_ctx_current->timeout, NULL);
1272         }
1273 
1274         if (mode == UDP_ONLY) {
1275             LOGI("TCP relay disabled");
1276         }
1277 
1278         // Handle additionals TOS/DSCP listening ports
1279         if (dscp_num > 0) {
1280             listen_ctx_current      = (listen_ctx_t *)ss_malloc(sizeof(listen_ctx_t));
1281             listen_ctx_current      = memcpy(listen_ctx_current, &listen_ctx, sizeof(listen_ctx_t));
1282             local_port              = dscp[dscp_num - 1].port;
1283             listen_ctx_current->tos = dscp[dscp_num - 1].dscp << 2;
1284         }
1285     } while (dscp_num-- > 0);
1286 
1287     // setuid
1288     if (user != NULL && !run_as(user)) {
1289         FATAL("failed to switch user");
1290     }
1291 
1292     if (geteuid() == 0) {
1293         LOGI("running from root user");
1294     }
1295 
1296     ev_run(loop, 0);
1297 
1298     if (plugin != NULL) {
1299         stop_plugin();
1300     }
1301 
1302     return ret_val;
1303 }
1304