1 /*
2 
3 The MIT License (MIT)
4 
5 Copyright (c) 2016 Wrymouth Innovation Ltd
6 
7 Permission is hereby granted, free of charge, to any person obtaining a
8 copy of this software and associated documentation files (the "Software"),
9 to deal in the Software without restriction, including without limitation
10 the rights to use, copy, modify, merge, publish, distribute, sublicense,
11 and/or sell copies of the Software, and to permit persons to whom the
12 Software is furnished to do so, subject to the following conditions:
13 
14 The above copyright notice and this permission notice shall be included
15 in all copies or substantial portions of the Software.
16 
17 THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
20 THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR
21 OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
22 ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
23 OTHER DEALINGS IN THE SOFTWARE.
24 
25 */
26 
#include "config.h"

#include <errno.h>
#include <getopt.h>
#include <netdb.h>
#include <signal.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <sys/wait.h>
#include <unistd.h>

#include "crypto-gnutls.h"
41 
/* Command-line configuration, filled in by processoptions ().  */

/* String options (heap copies of the arguments).  */
static char *connectaddr = NULL;	/* -c: address to connect to */
static char *listenaddr = NULL;		/* -l: address to listen on */
static char *keyfile = NULL;		/* -K: private key file */
static char *certfile = NULL;		/* -C: certificate file (falls back to -K) */
static char *cacertfile = NULL;		/* -A: CA certificate file */
static char *hostname = NULL;		/* -H, else derived from -c in client mode */

/* Flag options.  */
static int server = 0;		/* -s: the listen side is the encrypted one */
static int insecure = 0;	/* -i: do not validate certificates */
static int debug = 0;		/* -d: bumped once per occurrence */
static int nofork = 0;		/* -n: bumped once per occurrence */

/* Port used when an address carries no ":port" suffix.  */
static const char *defaultport = "12345";

/* Written by the SIGINT/SIGTERM handler, polled by the accept loop and by
   the TLS quit callback.  */
static volatile sig_atomic_t rxsigquit = 0;
56 
/*
 * Split ADDRPORT into host and port, working on a copy held in BUF
 * (BUFSIZE bytes).
 *
 * Accepted forms: "host:port", "host", "[literal]:port" and "[literal]";
 * the brackets let IPv6 literals, which contain colons themselves, carry a
 * port.  Without brackets the split is made at the LAST colon.  If no port
 * is present *PORT is set to DEFPORT.  On success *HOST (and *PORT, unless
 * defaulted) point into BUF.
 *
 * Returns 0 on success, or -1 after printing a diagnostic when ADDRPORT is
 * NULL, does not fit in BUF (rather than silently truncating it), or has a
 * malformed bracketed host.
 */
static int
splitaddrport (const char *addrport, char *buf, size_t bufsize,
	       const char *defport, char **host, const char **port)
{
  int n;

  if (!addrport)
    {
      fprintf (stderr, "No address specified\n");
      return -1;
    }

  n = snprintf (buf, bufsize, "%s", addrport);
  if (n < 0 || (size_t) n >= bufsize)
    {
      fprintf (stderr, "Address too long: %s\n", addrport);
      return -1;
    }

  *host = buf;
  *port = defport;

  if (buf[0] == '[')
    {
      char *rbracket = strchr (buf, ']');

      /* After "]" only end-of-string or ":port" may follow.  */
      if (!rbracket || (rbracket[1] != '\0' && rbracket[1] != ':'))
	{
	  fprintf (stderr, "Malformed address %s\n", addrport);
	  return -1;
	}
      *rbracket = '\0';
      *host = buf + 1;
      if (rbracket[1] == ':')
	*port = rbracket + 2;
    }
  else
    {
      char *colon = strrchr (buf, ':');

      if (colon)
	{
	  *colon = '\0';
	  *port = colon + 1;
	}
    }
  return 0;
}

/*
 * Create a TCP socket bound to ADDRPORT (see splitaddrport for the accepted
 * forms; the port defaults to defaultport) and put it into listening state.
 * An empty host, e.g. ":12345", binds the wildcard address.  Each resolved
 * address is tried in turn with SO_REUSEADDR set.
 *
 * Returns the listening descriptor (backlog 5), or -1 after printing a
 * diagnostic.
 */
static int
bindtoaddress (char *addrport)
{
  struct addrinfo hints;
  struct addrinfo *result, *rp;
  int fd = -1, s, lasterr = 0;
  char addr[512];
  char *host;
  const char *port;

  if (splitaddrport (addrport, addr, sizeof (addr), defaultport,
		     &host, &port) < 0)
    return -1;

  memset (&hints, 0, sizeof (struct addrinfo));
  hints.ai_flags = AI_PASSIVE;	/* wildcard address for an empty host */
  hints.ai_family = AF_UNSPEC;	/* Allow IPv4 or IPv6 */
  hints.ai_socktype = SOCK_STREAM;	/* Stream socket */
  hints.ai_protocol = 0;	/* any protocol */

  /* AI_PASSIVE only yields the wildcard address when the node is NULL, so
     map an empty host to NULL.  */
  s = getaddrinfo (*host ? host : NULL, port, &hints, &result);
  if (s != 0)
    {
      fprintf (stderr, "Error in address %s: %s\n", host, gai_strerror (s));
      return -1;
    }

  /* attempt to bind to each address */
  for (rp = result; rp != NULL; rp = rp->ai_next)
    {
      int one = 1;

      fd = socket (rp->ai_family, rp->ai_socktype, rp->ai_protocol);
      if (fd < 0)
	{
	  lasterr = errno;
	  continue;
	}
      if (setsockopt (fd, SOL_SOCKET, SO_REUSEADDR, &one, sizeof (one)) == 0
	  && bind (fd, rp->ai_addr, rp->ai_addrlen) == 0)
	break;
      /* Record errno before close () has a chance to change it.  */
      lasterr = errno;
      close (fd);
      fd = -1;
    }

  /* Released on success and failure alike.  */
  freeaddrinfo (result);

  if (fd < 0)
    {
      fprintf (stderr, "Error binding to %s:%s: %s\n", host, port,
	       strerror (lasterr));
      return -1;
    }

  if (listen (fd, 5) < 0)
    {
      fprintf (stderr, "Error listening on %s:%s: %s\n", host, port,
	       strerror (errno));
      close (fd);
      return -1;
    }

  return fd;
}

/*
 * Open a TCP connection to ADDRPORT (see splitaddrport for the accepted
 * forms; the port defaults to defaultport), trying each resolved address
 * in turn.
 *
 * Side effect: in client mode (!server) with no -H hostname, hostname is
 * set to a copy of the host part; per usage () it is then used to validate
 * the peer's certificate.
 *
 * Returns the connected descriptor, or -1 after printing a diagnostic.
 */
static int
connecttoaddress (char *addrport)
{
  struct addrinfo hints;
  struct addrinfo *result, *rp;
  int fd = -1, s, lasterr = 0;
  char addr[512];
  char *host;
  const char *port;

  if (splitaddrport (addrport, addr, sizeof (addr), defaultport,
		     &host, &port) < 0)
    return -1;

  /* AI_PASSIVE is deliberately not set: it asks for addresses suitable for
     bind (), not connect ().  */
  memset (&hints, 0, sizeof (struct addrinfo));
  hints.ai_family = AF_UNSPEC;	/* Allow IPv4 or IPv6 */
  hints.ai_socktype = SOCK_STREAM;	/* Stream socket */
  hints.ai_protocol = 0;	/* any protocol */

  if (!hostname && !server)
    {
      hostname = strdup (host);
      if (!hostname)
	{
	  fprintf (stderr, "Out of memory\n");
	  return -1;
	}
    }

  s = getaddrinfo (host, port, &hints, &result);
  if (s != 0)
    {
      fprintf (stderr, "Error in address %s: %s\n", host, gai_strerror (s));
      return -1;
    }

  /* attempt to connect to each address */
  for (rp = result; rp != NULL; rp = rp->ai_next)
    {
      fd = socket (rp->ai_family, rp->ai_socktype, rp->ai_protocol);
      if (fd < 0)
	{
	  lasterr = errno;
	  continue;
	}
      if (connect (fd, rp->ai_addr, rp->ai_addrlen) == 0)
	break;
      /* Record errno before close () has a chance to change it.  */
      lasterr = errno;
      close (fd);
      fd = -1;
    }

  /* Released on success and failure alike.  */
  freeaddrinfo (result);

  if (fd < 0)
    {
      fprintf (stderr, "Error connecting to %s:%s: %s\n", host, port,
	       strerror (lasterr));
      return -1;
    }

  return fd;
}
182 
183 static int
quitfn(void * opaque)184 quitfn (void *opaque)
185 {
186   return rxsigquit;
187 }
188 
189 static int
runproxy(int acceptfd)190 runproxy (int acceptfd)
191 {
192   int connectfd;
193   if ((connectfd = connecttoaddress (connectaddr)) < 0)
194     {
195       fprintf (stderr, "Could not connect\n");
196       close (acceptfd);
197       return -1;
198     }
199 
200   tlssession_t *session =
201     tlssession_new (server, keyfile, certfile, cacertfile, hostname, insecure,
202 		    debug, quitfn, NULL, NULL);
203   if (!session)
204     {
205       fprintf (stderr, "Could create TLS session\n");
206       close (connectfd);
207       close (acceptfd);
208       return -1;
209     }
210 
211   int ret;
212   if (server)
213     ret = tlssession_mainloop (acceptfd, connectfd, session);
214   else
215     ret = tlssession_mainloop (connectfd, acceptfd, session);
216 
217   tlssession_close (session);
218   close (connectfd);
219   close (acceptfd);
220 
221   if (ret < 0)
222     {
223       fprintf (stderr, "TLS proxy exited with an error\n");
224       return -1;
225     }
226   return 0;
227 }
228 
229 static int
runlistener(void)230 runlistener (void)
231 {
232   int listenfd;
233   if ((listenfd = bindtoaddress (listenaddr)) < 0)
234     {
235       fprintf (stderr, "Could not bind listener\n");
236       return -1;
237     }
238 
239   /*
240      if (!nofork)
241      daemon (FALSE, FALSE);
242    */
243 
244   int fd;
245   while (!rxsigquit)
246     {
247       do
248 	{
249 	  if ((fd = accept (listenfd, NULL, NULL)) < 0)
250 	    {
251 	      if (errno != EINTR)
252 		{
253 		  fprintf (stderr, "Accept failed\n");
254 		  return -1;
255 		}
256 	    }
257 	}
258       while (fd < 0 && !rxsigquit);
259       if (rxsigquit)
260 	break;
261       if (nofork < 2)
262 	{
263 	  int ret = runproxy (fd);
264 	  if (ret < 0)
265 	    return -1;
266 	}
267       else
268 	{
269 	  int cpid = fork ();
270 	  if (cpid == 0)
271 	    {
272 	      /* we're the child */
273 	      runproxy (fd);
274 	      exit (0);
275 	    }
276 	  else
277 	    close (fd);
278 	}
279     }
280   return 0;
281 }
282 
/*
 * Write the command-line help text to OUT.
 */
static void
print_usage (FILE *out)
{
  fputs ("tlsproxy\n"
	 "\n"
	 "Usage:\n"
	 "     tlsproxy [OPTIONS]\n"
	 "\n"
	 "A TLS client or server proxy\n"
	 "\n"
	 "Options:\n"
	 "     -c, --connect ADDRESS     Connect to ADDRESS\n"
	 "     -l, --listen ADDRESS      Listen on ADDRESS\n"
	 "     -K, --key FILE            Use FILE as private key\n"
	 "     -C, --cert FILE           Use FILE as certificate (default: key FILE)\n"
	 "     -A, --cacert FILE         Use FILE as public CA cert file\n"
	 "     -H, --hostname HOSTNAME   Use HOSTNAME to validate the CN of the peer\n"
	 "                               rather than the host part of the -c address\n"
	 "     -s, --server              Run the listen port encrypted rather than the\n"
	 "                               connect port\n"
	 "     -i, --insecure            Do not validate certificates\n"
	 "     -n, --nofork              Do not fork off (aids debugging); specify twice\n"
	 "                               to stop forking on accept as well\n"
	 "     -d, --debug               Turn on debugging\n"
	 "     -h, --help                Show this usage message\n"
	 "\n"
	 "\n", out);
}

/* Print the command-line help text to stderr.  */
static void
usage (void)
{
  print_usage (stderr);
}
310 
311 static void
processoptions(int argc,char ** argv)312 processoptions (int argc, char **argv)
313 {
314   while (1)
315     {
316       static const struct option longopts[] = {
317 	{"connect", required_argument, 0, 'c'},
318 	{"listen", required_argument, 0, 'l'},
319 	{"key", required_argument, 0, 'K'},
320 	{"cert", required_argument, 0, 'C'},
321 	{"cacert", required_argument, 0, 'A'},
322 	{"hostname", required_argument, 0, 'H'},
323 	{"server", no_argument, 0, 's'},
324 	{"insecure", no_argument, 0, 'i'},
325 	{"nofork", no_argument, 0, 'n'},
326 	{"debug", no_argument, 0, 'd'},
327 	{"help", no_argument, 0, 'h'},
328 	{0, 0, 0, 0}
329       };
330 
331       int optidx = 0;
332 
333       int c =
334 	getopt_long (argc, argv, "c:l:K:C:A:H:sindh", longopts, &optidx);
335       if (c == -1)
336 	break;
337 
338       switch (c)
339 	{
340 	case 0:		/* set a flag, nothing else to do */
341 	  break;
342 
343 	case 'c':
344 	  connectaddr = strdup (optarg);
345 	  break;
346 
347 	case 'l':
348 	  listenaddr = strdup (optarg);
349 	  break;
350 
351 	case 'K':
352 	  keyfile = strdup (optarg);
353 	  break;
354 
355 	case 'C':
356 	  certfile = strdup (optarg);
357 	  break;
358 
359 	case 'A':
360 	  cacertfile = strdup (optarg);
361 	  break;
362 
363 	case 'H':
364 	  hostname = strdup (optarg);
365 	  break;
366 
367 	case 's':
368 	  server = 1;
369 	  break;
370 
371 	case 'i':
372 	  insecure = 1;
373 	  break;
374 
375 	case 'n':
376 	  nofork++;
377 	  break;
378 
379 	case 'd':
380 	  debug++;
381 	  break;
382 
383 	case 'h':
384 	  usage ();
385 	  exit (0);
386 	  break;
387 
388 	default:
389 	  usage ();
390 	  exit (1);
391 	}
392     }
393 
394   if (optind != argc || !connectaddr || !listenaddr)
395     {
396       usage ();
397       exit (1);
398     }
399 
400   if (!certfile && keyfile)
401     certfile = strdup (keyfile);
402 }
403 
404 static void
handlesignal(int sig)405 handlesignal (int sig)
406 {
407   switch (sig)
408     {
409     case SIGINT:
410     case SIGTERM:
411       rxsigquit++;
412       break;
413     default:
414       break;
415     }
416 }
417 
418 static void
setsignalmasks(void)419 setsignalmasks (void)
420 {
421   struct sigaction sa;
422   /* Set up the structure to specify the new action. */
423   memset (&sa, 0, sizeof (struct sigaction));
424   sa.sa_handler = handlesignal;
425   sigemptyset (&sa.sa_mask);
426   sa.sa_flags = 0;
427   sigaction (SIGINT, &sa, NULL);
428   sigaction (SIGTERM, &sa, NULL);
429 
430   memset (&sa, 0, sizeof (struct sigaction));
431   sa.sa_handler = SIG_IGN;
432   sa.sa_flags = SA_RESTART;
433   sigaction (SIGPIPE, &sa, NULL);
434 }
435 
436 int
main(int argc,char ** argv)437 main (int argc, char **argv)
438 {
439   processoptions (argc, argv);
440 
441   setsignalmasks ();
442 
443   if (tlssession_init ())
444     exit (1);
445 
446   runlistener ();
447 
448   free (connectaddr);
449   free (listenaddr);
450   free (keyfile);
451   free (certfile);
452   free (cacertfile);
453   free (hostname);
454 
455   exit (0);
456 }
457