1 /*
2  * Portions created by SGI are Copyright (C) 2000 Silicon Graphics, Inc.
3  * All Rights Reserved.
4  *
5  * Redistribution and use in source and binary forms, with or without
6  * modification, are permitted provided that the following conditions
7  * are met:
8  *
9  * 1. Redistributions of source code must retain the above copyright
10  *    notice, this list of conditions and the following disclaimer.
11  * 2. Redistributions in binary form must reproduce the above copyright
12  *    notice, this list of conditions and the following disclaimer in the
13  *    documentation and/or other materials provided with the distribution.
14  * 3. Neither the name of Silicon Graphics, Inc. nor the names of its
15  *    contributors may be used to endorse or promote products derived from
16  *    this software without specific prior written permission.
17  *
18  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
19  * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
20  * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
21  * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
22  * HOLDERS AND CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
23  * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED
24  * TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
25  * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
26  * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
27  * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
28  * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29  */
30 
31 #include <stdio.h>
32 #include <stdlib.h>
33 #include <string.h>
34 #include <signal.h>
35 #include <unistd.h>
36 #include <fcntl.h>
37 #include <sys/types.h>
38 #include <sys/stat.h>
39 #include <sys/socket.h>
40 #include <netinet/in.h>
41 #include <arpa/inet.h>
42 #include <netdb.h>
43 #include "st.h"
44 
/* Size of the buffer that pass() moves per call */
#define IOBUFSIZE (16*1024)

/* The DEBUG pass() splits the buffer into IOV_COUNT vectors of IOV_LEN bytes */
#define IOV_LEN   256
#define IOV_COUNT (IOBUFSIZE / IOV_LEN)

/* Fallback for platforms whose headers do not provide INADDR_NONE */
#ifndef INADDR_NONE
#define INADDR_NONE 0xffffffff
#endif

static char *prog;                     /* Program name   */
static struct sockaddr_in rmt_addr;    /* Remote address */

/* Bit mask set by the '-t' option; consulted by the DEBUG pass() */
static unsigned long testing;
#define TESTING_VERBOSE		0x1   /* print what is read and written */
#define TESTING_READV		0x2   /* read through an iovec array    */
#define	TESTING_READ_RESID	0x4   /* use the *_resid read variant   */
#define TESTING_WRITEV		0x8   /* write through an iovec array   */
#define TESTING_WRITE_RESID	0x10  /* use the *_resid write variant  */

/* Forward declarations of the file-local helpers */
static void read_address(const char *str, struct sockaddr_in *sin);
static void start_daemon(void);
static int  cpu_count(void);
static void set_concurrency(int nproc);
static void *handle_request(void *arg);
static void print_sys_error(const char *msg);
70 
71 
72 /*
73  * This program acts as a generic gateway. It listens for connections
74  * to a local address ('-l' option). Upon accepting a client connection,
75  * it connects to the specified remote address ('-r' option) and then
76  * just pumps the data through without any modification.
77  */
main(int argc,char * argv[])78 int main(int argc, char *argv[])
79 {
80   extern char *optarg;
81   int opt, sock, n;
82   int laddr, raddr, num_procs, alt_ev, one_process;
83   int serialize_accept = 0;
84   struct sockaddr_in lcl_addr, cli_addr;
85   st_netfd_t cli_nfd, srv_nfd;
86 
87   prog = argv[0];
88   num_procs = laddr = raddr = alt_ev = one_process = 0;
89 
90   /* Parse arguments */
91   while((opt = getopt(argc, argv, "l:r:p:Saht:X")) != EOF) {
92     switch (opt) {
93     case 'a':
94       alt_ev = 1;
95       break;
96     case 'l':
97       read_address(optarg, &lcl_addr);
98       laddr = 1;
99       break;
100     case 'r':
101       read_address(optarg, &rmt_addr);
102       if (rmt_addr.sin_addr.s_addr == INADDR_ANY) {
103 	fprintf(stderr, "%s: invalid remote address: %s\n", prog, optarg);
104 	exit(1);
105       }
106       raddr = 1;
107       break;
108     case 'p':
109       num_procs = atoi(optarg);
110       if (num_procs < 1) {
111 	fprintf(stderr, "%s: invalid number of processes: %s\n", prog, optarg);
112 	exit(1);
113       }
114       break;
115     case 'S':
116       /*
117        * Serialization decision is tricky on some platforms. For example,
118        * Solaris 2.6 and above has kernel sockets implementation, so supposedly
119        * there is no need for serialization. The ST library may be compiled
120        * on one OS version, but used on another, so the need for serialization
121        * should be determined at run time by the application. Since it's just
122        * an example, the serialization decision is left up to user.
123        * Only on platforms where the serialization is never needed on any OS
124        * version st_netfd_serialize_accept() is a no-op.
125        */
126       serialize_accept = 1;
127       break;
128     case 't':
129       testing = strtoul(optarg, NULL, 0);
130       break;
131     case 'X':
132       one_process = 1;
133       break;
134     case 'h':
135     case '?':
136       fprintf(stderr, "Usage: %s [options] -l <[host]:port> -r <host:port>\n",
137        prog);
138       fprintf(stderr, "options are:\n");
139       fprintf(stderr, "  -p <num_processes>	number of parallel processes\n");
140       fprintf(stderr, "  -S			serialize accepts\n");
141       fprintf(stderr, "  -a			use alternate event system\n");
142 #ifdef DEBUG
143       fprintf(stderr, "  -t mask		testing/debugging mode\n");
144       fprintf(stderr, "  -X			one process, don't daemonize\n");
145 #endif
146       exit(1);
147     }
148   }
149   if (!laddr) {
150     fprintf(stderr, "%s: local address required\n", prog);
151     exit(1);
152   }
153   if (!raddr) {
154     fprintf(stderr, "%s: remote address required\n", prog);
155     exit(1);
156   }
157   if (num_procs == 0)
158     num_procs = cpu_count();
159 
160   fprintf(stderr, "%s: starting proxy daemon on %s:%d\n", prog,
161 	  inet_ntoa(lcl_addr.sin_addr), ntohs(lcl_addr.sin_port));
162 
163   /* Start the daemon */
164   if (one_process)
165     num_procs = 1;
166   else
167     start_daemon();
168 
169   if (alt_ev)
170     st_set_eventsys(ST_EVENTSYS_ALT);
171 
172   /* Initialize the ST library */
173   if (st_init() < 0) {
174     print_sys_error("st_init");
175     exit(1);
176   }
177 
178   /* Create and bind listening socket */
179   if ((sock = socket(PF_INET, SOCK_STREAM, 0)) < 0) {
180     print_sys_error("socket");
181     exit(1);
182   }
183   n = 1;
184   if (setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, (char *)&n, sizeof(n)) < 0) {
185     print_sys_error("setsockopt");
186     exit(1);
187   }
188   if (bind(sock, (struct sockaddr *)&lcl_addr, sizeof(lcl_addr)) < 0) {
189     print_sys_error("bind");
190     exit(1);
191   }
192   listen(sock, 128);
193   if ((srv_nfd = st_netfd_open_socket(sock)) == NULL) {
194     print_sys_error("st_netfd_open");
195     exit(1);
196   }
197   /* See the comment regarding serialization decision above */
198   if (num_procs > 1 && serialize_accept && st_netfd_serialize_accept(srv_nfd)
199       < 0) {
200     print_sys_error("st_netfd_serialize_accept");
201     exit(1);
202   }
203 
204   /* Start server processes */
205   if (!one_process)
206     set_concurrency(num_procs);
207 
208   for ( ; ; ) {
209     n = sizeof(cli_addr);
210     cli_nfd = st_accept(srv_nfd, (struct sockaddr *)&cli_addr, &n,
211      ST_UTIME_NO_TIMEOUT);
212     if (cli_nfd == NULL) {
213       print_sys_error("st_accept");
214       exit(1);
215     }
216     if (st_thread_create(handle_request, cli_nfd, 0, 0) == NULL) {
217       print_sys_error("st_thread_create");
218       exit(1);
219     }
220   }
221 
222   /* NOTREACHED */
223   return 1;
224 }
225 
226 
read_address(const char * str,struct sockaddr_in * sin)227 static void read_address(const char *str, struct sockaddr_in *sin)
228 {
229   char host[128], *p;
230   struct hostent *hp;
231   unsigned short port;
232 
233   strcpy(host, str);
234   if ((p = strchr(host, ':')) == NULL) {
235     fprintf(stderr, "%s: invalid address: %s\n", prog, host);
236     exit(1);
237   }
238   *p++ = '\0';
239   port = (unsigned short) atoi(p);
240   if (port < 1) {
241     fprintf(stderr, "%s: invalid port: %s\n", prog, p);
242     exit(1);
243   }
244 
245   memset(sin, 0, sizeof(struct sockaddr_in));
246   sin->sin_family = AF_INET;
247   sin->sin_port = htons(port);
248   if (host[0] == '\0') {
249     sin->sin_addr.s_addr = INADDR_ANY;
250     return;
251   }
252   sin->sin_addr.s_addr = inet_addr(host);
253   if (sin->sin_addr.s_addr == INADDR_NONE) {
254     /* not dotted-decimal */
255     if ((hp = gethostbyname(host)) == NULL) {
256       fprintf(stderr, "%s: can't resolve address: %s\n", prog, host);
257       exit(1);
258     }
259     memcpy(&sin->sin_addr, hp->h_addr, hp->h_length);
260   }
261 }
262 
263 #ifdef DEBUG
/*
 * Debug helper: print every entry of an iovec array to stdout (base
 * pointer and length in hex and decimal), followed by the sum of all
 * lengths.
 *
 *   iov  - array of vectors to dump
 *   niov - number of entries in iov; nothing but the header and a zero
 *          total is printed when it is <= 0
 */
static void show_iov(const struct iovec *iov, int niov)
{
  int i;
  size_t total;

  /* %p requires a pointer to void, hence the explicit casts */
  printf("iov %p has %d entries:\n", (const void *) iov, niov);
  total = 0;
  for (i = 0; i < niov; i++) {
    printf("iov[%3d] iov_base=%p iov_len=0x%lx(%lu)\n",
           i, (void *) iov[i].iov_base, (unsigned long) iov[i].iov_len,
           (unsigned long) iov[i].iov_len);
    total += iov[i].iov_len;
  }
  /* total is unsigned, so print it with %lu rather than %ld */
  printf("total 0x%lx(%lu)\n", (unsigned long) total, (unsigned long) total);
}
279 
280 /*
281  * This version is tricked out to test all the
282  * st_(read|write)v?(_resid)? variants.  Use the non-DEBUG version for
283  * anything serious.  st_(read|write) are all this function really
284  * needs.
285  */
pass(st_netfd_t in,st_netfd_t out)286 static int pass(st_netfd_t in, st_netfd_t out)
287 {
288   char buf[IOBUFSIZE];
289   struct iovec iov[IOV_COUNT];
290   int ioviter, nw, nr;
291 
292   if (testing & TESTING_READV) {
293     for (ioviter = 0; ioviter < IOV_COUNT; ioviter++) {
294       iov[ioviter].iov_base = &buf[ioviter * IOV_LEN];
295       iov[ioviter].iov_len = IOV_LEN;
296     }
297     if (testing & TESTING_VERBOSE) {
298       printf("readv(%p)...\n", in);
299       show_iov(iov, IOV_COUNT);
300     }
301     if (testing & TESTING_READ_RESID) {
302       struct iovec *riov = iov;
303       int riov_cnt = IOV_COUNT;
304       if (st_readv_resid(in, &riov, &riov_cnt, ST_UTIME_NO_TIMEOUT) == 0) {
305 	if (testing & TESTING_VERBOSE) {
306 	  printf("resid\n");
307 	  show_iov(riov, riov_cnt);
308 	  printf("full\n");
309 	  show_iov(iov, IOV_COUNT);
310 	}
311 	nr = 0;
312 	for (ioviter = 0; ioviter < IOV_COUNT; ioviter++)
313 	  nr += iov[ioviter].iov_len;
314 	nr = IOBUFSIZE - nr;
315       } else
316 	nr = -1;
317     } else
318       nr = (int) st_readv(in, iov, IOV_COUNT, ST_UTIME_NO_TIMEOUT);
319   } else {
320     if (testing & TESTING_READ_RESID) {
321       size_t resid = IOBUFSIZE;
322       if (st_read_resid(in, buf, &resid, ST_UTIME_NO_TIMEOUT) == 0)
323 	nr = IOBUFSIZE - resid;
324       else
325 	nr = -1;
326     } else
327       nr = (int) st_read(in, buf, IOBUFSIZE, ST_UTIME_NO_TIMEOUT);
328   }
329   if (testing & TESTING_VERBOSE)
330     printf("got 0x%x(%d) E=%d\n", nr, nr, errno);
331 
332   if (nr <= 0)
333     return 0;
334 
335   if (testing & TESTING_WRITEV) {
336     for (nw = 0, ioviter = 0; nw < nr;
337      nw += iov[ioviter].iov_len, ioviter++) {
338       iov[ioviter].iov_base = &buf[nw];
339       iov[ioviter].iov_len = nr - nw;
340       if (iov[ioviter].iov_len > IOV_LEN)
341 	iov[ioviter].iov_len = IOV_LEN;
342     }
343     if (testing & TESTING_VERBOSE) {
344       printf("writev(%p)...\n", out);
345       show_iov(iov, ioviter);
346     }
347     if (testing & TESTING_WRITE_RESID) {
348       struct iovec *riov = iov;
349       int riov_cnt = ioviter;
350       if (st_writev_resid(out, &riov, &riov_cnt, ST_UTIME_NO_TIMEOUT) == 0) {
351 	if (testing & TESTING_VERBOSE) {
352 	  printf("resid\n");
353 	  show_iov(riov, riov_cnt);
354 	  printf("full\n");
355 	  show_iov(iov, ioviter);
356 	}
357 	nw = 0;
358 	while (--ioviter >= 0)
359 	  nw += iov[ioviter].iov_len;
360 	nw = nr - nw;
361       } else
362 	nw = -1;
363     } else
364       nw = st_writev(out, iov, ioviter, ST_UTIME_NO_TIMEOUT);
365   } else {
366     if (testing & TESTING_WRITE_RESID) {
367       size_t resid = nr;
368       if (st_write_resid(out, buf, &resid, ST_UTIME_NO_TIMEOUT) == 0)
369 	nw = nr - resid;
370       else
371 	nw = -1;
372     } else
373       nw = st_write(out, buf, nr, ST_UTIME_NO_TIMEOUT);
374   }
375   if (testing & TESTING_VERBOSE)
376     printf("put 0x%x(%d) E=%d\n", nw, nw, errno);
377 
378   if (nw != nr)
379     return 0;
380 
381   return 1;
382 }
383 #else /* DEBUG */
384 /*
385  * This version is the simple one suitable for serious use.
386  */
pass(st_netfd_t in,st_netfd_t out)387 static int pass(st_netfd_t in, st_netfd_t out)
388 {
389   char buf[IOBUFSIZE];
390   int nw, nr;
391 
392   nr = (int) st_read(in, buf, IOBUFSIZE, ST_UTIME_NO_TIMEOUT);
393   if (nr <= 0)
394     return 0;
395 
396   nw = st_write(out, buf, nr, ST_UTIME_NO_TIMEOUT);
397   if (nw != nr)
398     return 0;
399 
400   return 1;
401 }
402 #endif
403 
handle_request(void * arg)404 static void *handle_request(void *arg)
405 {
406   struct pollfd pds[2];
407   st_netfd_t cli_nfd, rmt_nfd;
408   int sock;
409 
410   cli_nfd = (st_netfd_t) arg;
411   pds[0].fd = st_netfd_fileno(cli_nfd);
412   pds[0].events = POLLIN;
413 
414   /* Connect to remote host */
415   if ((sock = socket(PF_INET, SOCK_STREAM, 0)) < 0) {
416     print_sys_error("socket");
417     goto done;
418   }
419   if ((rmt_nfd = st_netfd_open_socket(sock)) == NULL) {
420     print_sys_error("st_netfd_open_socket");
421     close(sock);
422     goto done;
423   }
424   if (st_connect(rmt_nfd, (struct sockaddr *)&rmt_addr,
425 		 sizeof(rmt_addr), ST_UTIME_NO_TIMEOUT) < 0) {
426     print_sys_error("st_connect");
427     st_netfd_close(rmt_nfd);
428     goto done;
429   }
430   pds[1].fd = sock;
431   pds[1].events = POLLIN;
432 
433   /*
434    * Now just pump the data through.
435    * XXX This should use one thread for each direction for true full-duplex.
436    */
437   for ( ; ; ) {
438     pds[0].revents = 0;
439     pds[1].revents = 0;
440 
441     if (st_poll(pds, 2, ST_UTIME_NO_TIMEOUT) <= 0) {
442       print_sys_error("st_poll");
443       break;
444     }
445 
446     if (pds[0].revents & POLLIN) {
447       if (!pass(cli_nfd, rmt_nfd))
448 	break;
449     }
450 
451     if (pds[1].revents & POLLIN) {
452       if (!pass(rmt_nfd, cli_nfd))
453 	break;
454     }
455   }
456   st_netfd_close(rmt_nfd);
457 
458 done:
459 
460   st_netfd_close(cli_nfd);
461 
462   return NULL;
463 }
464 
/*
 * Detach from the invoking terminal with two successive forks.
 * After each fork the parent exits with status 0 and the child carries
 * on; the first child calls setsid() before forking again.  The
 * surviving grandchild changes directory to "/" and sets the umask to
 * 022.  A failed fork is reported and ends the process with status 1.
 */
static void start_daemon(void)
{
  int stage;

  for (stage = 0; stage < 2; stage++) {
    pid_t child = fork();

    if (child < 0) {
      print_sys_error("fork");
      exit(1);
    }
    if (child > 0)
      exit(0);                      /* parent of this stage is done */

    /* Only the first child becomes a session leader */
    if (stage == 0)
      setsid();
  }

  chdir("/");
  umask(022);
}
490 
/*
 * Create separate processes ("virtual processors"). Since it's just an
 * example, there is no watchdog - the parent just exits leaving children
 * on their own.
 *
 * nproc is the number of children to fork; values below 1 are treated
 * as 1.  Each child returns to the caller, while the parent never
 * returns: it exits with status 0 once all children are forked, or
 * with status 1 if a fork fails.
 */
static void set_concurrency(int nproc)
{
  int remaining = (nproc < 1) ? 1 : nproc;

  while (remaining-- > 0) {
    pid_t child = fork();

    if (child < 0) {
      print_sys_error("fork");
      exit(1);
    }
    if (child == 0)
      return;                       /* child goes on to serve requests */
  }

  /* Parent just exits */
  exit(0);
}
517 
/*
 * Report the number of online processors using whichever interface the
 * platform offers at compile time (sysconf or HP-UX mpctl).
 * Returns that call's result unchanged; on platforms with none of the
 * known interfaces, sets errno to ENOSYS and returns -1.
 */
static int cpu_count(void)
{
#if defined (_SC_NPROCESSORS_ONLN)
  return (int) sysconf(_SC_NPROCESSORS_ONLN);
#elif defined (_SC_NPROC_ONLN)
  return (int) sysconf(_SC_NPROC_ONLN);
#elif defined (HPUX)
#include <sys/mpctl.h>
  return mpctl(MPC_GETNUMSPUS, 0, 0);
#else
  errno = ENOSYS;
  return -1;
#endif
}
536 
print_sys_error(const char * msg)537 static void print_sys_error(const char *msg)
538 {
539   fprintf(stderr, "%s: %s: %s\n", prog, msg, strerror(errno));
540 }
541 
542