1 /*
2  * Copyright (c) 2011-2014, Dustin Lundquist <dustin@null-ptr.net>
3  * All rights reserved.
4  *
5  * Redistribution and use in source and binary forms, with or without
6  * modification, are permitted provided that the following conditions are met:
7  *
8  * 1. Redistributions of source code must retain the above copyright notice,
9  *    this list of conditions and the following disclaimer.
10  * 2. Redistributions in binary form must reproduce the above copyright
11  *    notice, this list of conditions and the following disclaimer in the
12  *    documentation and/or other materials provided with the distribution.
13  *
14  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
15  * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
16  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
17  * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
18  * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
19  * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
20  * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
21  * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
22  * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
23  * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
24  * POSSIBILITY OF SUCH DAMAGE.
25  */
26 #include <stdio.h>
27 #include <stdlib.h>
28 #include <stdint.h>
29 #include <unistd.h>
30 #include <string.h>
31 #include <errno.h>
32 #include <sys/queue.h>
33 #include <sys/types.h>
34 #include <sys/socket.h>
35 #include <netinet/in.h>
36 #include <netdb.h> /* getaddrinfo */
37 #include <unistd.h> /* close */
38 #include <fcntl.h>
39 #include <arpa/inet.h>
40 #include <ev.h>
41 #include <assert.h>
42 #include "connection.h"
43 #include "resolv.h"
44 #include "address.h"
45 #include "protocol.h"
46 #include "logger.h"
47 
48 
/*
 * True when an errno value describes a transient socket condition (the call
 * would block or was interrupted) after which the operation can simply be
 * retried later, rather than a reason to close the connection.
 *
 * The argument is parenthesized so any expression may be passed; note that it
 * is evaluated up to three times.
 */
#define IS_TEMPORARY_SOCKERR(err_no) ((err_no) == EAGAIN || \
                                      (err_no) == EWOULDBLOCK || \
                                      (err_no) == EINTR)
52 
53 
/*
 * Context attached to a pending DNS query: filled in by
 * resolve_server_address(), read by resolv_cb() and released by
 * free_resolv_cb_data().
 */
struct resolv_cb_data {
    struct Connection *connection;  /* connection waiting for the lookup */
    const struct Address *address;  /* hostname address being resolved; resolv_cb() also takes the port from it */
    struct ev_loop *loop;           /* event loop used to resume the connection */
    int cb_free_addr;               /* non-zero: free_resolv_cb_data() frees address */
};
60 
61 
/*
 * List of all connections tracked by this file. New connections are inserted
 * at the head and reactivate_watchers() moves a connection back to the head
 * each time it runs, so inactive connections drift toward the tail.
 */
static TAILQ_HEAD(ConnectionHead, Connection) connections;
63 
64 
/* Socket state predicates */
static inline int client_socket_open(const struct Connection *con);
static inline int server_socket_open(const struct Connection *con);

/* Event loop callbacks and watcher management */
static void connection_cb(struct ev_loop *loop, struct ev_io *w, int revents);
static void resolv_cb(struct Address *result, void *data);
static void reactivate_watcher(struct ev_loop *loop, struct ev_io *w,
        const struct Buffer *input_buffer, const struct Buffer *output_buffer);
static void reactivate_watchers(struct Connection *con, struct ev_loop *loop);

/* Connection state machine steps */
static void insert_proxy_v1_header(struct Connection *con);
static void parse_client_request(struct Connection *con);
static void resolve_server_address(struct Connection *con,
        struct ev_loop *loop);
static void initiate_server_connect(struct Connection *con,
        struct ev_loop *loop);
static void abort_connection(struct Connection *con);

/* Teardown */
static void close_connection(struct Connection *con, struct ev_loop *loop);
static void close_client_socket(struct Connection *con, struct ev_loop *loop);
static void close_server_socket(struct Connection *con, struct ev_loop *loop);
static void free_connection(struct Connection *con);
static void free_resolv_cb_data(struct resolv_cb_data *cb_data);

/* Allocation, logging and debugging */
static struct Connection *new_connection(struct ev_loop *loop);
static void log_connection(struct Connection *con);
static void log_bad_request(struct Connection *con, const char *req,
        size_t req_len, int parse_result);
static void print_connection(FILE *file, const struct Connection *con);
88 
89 
90 void
init_connections()91 init_connections() {
92     TAILQ_INIT(&connections);
93 }
94 
95 /**
96  * Accept a new incoming connection
97  *
98  * Returns 1 on success or 0 on error;
99  */
100 int
accept_connection(struct Listener * listener,struct ev_loop * loop)101 accept_connection(struct Listener *listener, struct ev_loop *loop) {
102     struct Connection *con = new_connection(loop);
103     if (con == NULL) {
104         err("new_connection failed");
105         return 0;
106     }
107     con->listener = listener_ref_get(listener);
108 
109 #ifdef HAVE_ACCEPT4
110     int sockfd = accept4(listener->watcher.fd,
111                     (struct sockaddr *)&con->client.addr,
112                     &con->client.addr_len,
113                     SOCK_NONBLOCK);
114 #else
115     int sockfd = accept(listener->watcher.fd,
116                     (struct sockaddr *)&con->client.addr,
117                     &con->client.addr_len);
118 #endif
119     if (sockfd < 0) {
120         int saved_errno = errno;
121 
122         warn("accept failed: %s", strerror(errno));
123         free_connection(con);
124 
125         errno = saved_errno;
126         return 0;
127     }
128 
129 #ifndef HAVE_ACCEPT4
130     int flags = fcntl(sockfd, F_GETFL, 0);
131     fcntl(sockfd, F_SETFL, flags | O_NONBLOCK);
132 #endif
133 
134     if (getsockname(sockfd, (struct sockaddr *)&con->client.local_addr,
135                 &con->client.local_addr_len) != 0) {
136         int saved_errno = errno;
137 
138         warn("getsockname failed: %s", strerror(errno));
139         free_connection(con);
140 
141         errno = saved_errno;
142         return 0;
143     }
144 
145     /* Avoiding type-punned pointer warning */
146     struct ev_io *client_watcher = &con->client.watcher;
147     ev_io_init(client_watcher, connection_cb, sockfd, EV_READ);
148     con->client.watcher.data = con;
149     con->state = ACCEPTED;
150     con->established_timestamp = ev_now(loop);
151 
152     TAILQ_INSERT_HEAD(&connections, con, entries);
153 
154     ev_io_start(loop, client_watcher);
155 
156     if (con->listener->table->use_proxy_header ||
157             con->listener->fallback_use_proxy_header)
158         insert_proxy_v1_header(con);
159 
160     return 1;
161 }
162 
163 /*
164  * Close and free all connections
165  */
166 void
free_connections(struct ev_loop * loop)167 free_connections(struct ev_loop *loop) {
168     struct Connection *iter;
169     while ((iter = TAILQ_FIRST(&connections)) != NULL) {
170         TAILQ_REMOVE(&connections, iter, entries);
171         close_connection(iter, loop);
172         free_connection(iter);
173     }
174 }
175 
176 /* dumps a list of all connections for debugging */
177 void
print_connections()178 print_connections() {
179     char filename[] = "/tmp/sniproxy-connections-XXXXXX";
180 
181     int fd = mkstemp(filename);
182     if (fd < 0) {
183         warn("mkstemp failed: %s", strerror(errno));
184         return;
185     }
186 
187     FILE *temp = fdopen(fd, "w");
188     if (temp == NULL) {
189         warn("fdopen failed: %s", strerror(errno));
190         return;
191     }
192 
193     fprintf(temp, "Running connections:\n");
194     struct Connection *iter;
195     TAILQ_FOREACH(iter, &connections, entries)
196         print_connection(temp, iter);
197 
198     if (fclose(temp) < 0)
199         warn("fclose failed: %s", strerror(errno));
200 
201     notice("Dumped connections to %s", filename);
202 }
203 
204 /*
205  * Test is client socket is open
206  *
207  * Returns true iff the client socket is opened based on connection state.
208  */
209 static inline int
client_socket_open(const struct Connection * con)210 client_socket_open(const struct Connection *con) {
211     return con->state == ACCEPTED ||
212         con->state == PARSED ||
213         con->state == RESOLVING ||
214         con->state == RESOLVED ||
215         con->state == CONNECTED ||
216         con->state == SERVER_CLOSED;
217 }
218 
219 /*
220  * Test is server socket is open
221  *
222  * Returns true iff the server socket is opened based on connection state.
223  */
224 static inline int
server_socket_open(const struct Connection * con)225 server_socket_open(const struct Connection *con) {
226     return con->state == CONNECTED ||
227         con->state == CLIENT_CLOSED;
228 }
229 
230 /*
231  * Main client callback: this is used by both the client and server watchers
232  *
233  * The logic is almost the same except for:
234  *  + input buffer
235  *  + output buffer
236  *  + how to close the socket
237  *
238  */
239 static void
connection_cb(struct ev_loop * loop,struct ev_io * w,int revents)240 connection_cb(struct ev_loop *loop, struct ev_io *w, int revents) {
241     struct Connection *con = (struct Connection *)w->data;
242     int is_client = &con->client.watcher == w;
243     const char *socket_name =
244         is_client ? "client" : "server";
245     struct Buffer *input_buffer =
246         is_client ? con->client.buffer : con->server.buffer;
247     struct Buffer *output_buffer =
248         is_client ? con->server.buffer : con->client.buffer;
249     void (*close_socket)(struct Connection *, struct ev_loop *) =
250         is_client ? close_client_socket : close_server_socket;
251 
252     /* Receive first in case the socket was closed */
253     if (revents & EV_READ && buffer_room(input_buffer)) {
254         ssize_t bytes_received = buffer_recv(input_buffer, w->fd, 0, loop);
255         if (bytes_received < 0 && !IS_TEMPORARY_SOCKERR(errno)) {
256             warn("recv(%s): %s, closing connection",
257                     socket_name,
258                     strerror(errno));
259 
260             close_socket(con, loop);
261             revents = 0; /* Clear revents so we don't try to send */
262         } else if (bytes_received == 0) { /* peer closed socket */
263             close_socket(con, loop);
264             revents = 0;
265         }
266     }
267 
268     /* Transmit */
269     if (revents & EV_WRITE && buffer_len(output_buffer)) {
270         ssize_t bytes_transmitted = buffer_send(output_buffer, w->fd, 0, loop);
271         if (bytes_transmitted < 0 && !IS_TEMPORARY_SOCKERR(errno)) {
272             warn("send(%s): %s, closing connection",
273                     socket_name,
274                     strerror(errno));
275 
276             close_socket(con, loop);
277         }
278     }
279 
280     /* Handle any state specific logic */
281     if (is_client && con->state == ACCEPTED)
282         parse_client_request(con);
283     if (is_client && con->state == PARSED)
284         resolve_server_address(con, loop);
285     if (is_client && con->state == RESOLVED)
286         initiate_server_connect(con, loop);
287 
288     /* Close other socket if we have flushed corresponding buffer */
289     if (con->state == SERVER_CLOSED && buffer_len(con->server.buffer) == 0)
290         close_client_socket(con, loop);
291     if (con->state == CLIENT_CLOSED && buffer_len(con->client.buffer) == 0)
292         close_server_socket(con, loop);
293 
294     if (con->state == CLOSED) {
295         TAILQ_REMOVE(&connections, con, entries);
296 
297         if (con->listener->access_log)
298             log_connection(con);
299 
300         free_connection(con);
301         return;
302     }
303 
304     reactivate_watchers(con, loop);
305 }
306 
307 static void
reactivate_watchers(struct Connection * con,struct ev_loop * loop)308 reactivate_watchers(struct Connection *con, struct ev_loop *loop) {
309     struct ev_io *client_watcher = &con->client.watcher;
310     struct ev_io *server_watcher = &con->server.watcher;
311 
312     /* Reactivate watchers */
313     if (client_socket_open(con))
314         reactivate_watcher(loop, client_watcher,
315                 con->client.buffer, con->server.buffer);
316 
317     if (server_socket_open(con))
318         reactivate_watcher(loop, server_watcher,
319                 con->server.buffer, con->client.buffer);
320 
321     /* Neither watcher is active when the corresponding socket is closed */
322     assert(client_socket_open(con) || !ev_is_active(client_watcher));
323     assert(server_socket_open(con) || !ev_is_active(server_watcher));
324 
325     /* At least one watcher is still active for this connection,
326      * or DNS callback active */
327     assert((ev_is_active(client_watcher) && con->client.watcher.events) ||
328            (ev_is_active(server_watcher) && con->server.watcher.events) ||
329            con->state == RESOLVING);
330 
331     /* Move to head of queue, so we can find inactive connections */
332     TAILQ_REMOVE(&connections, con, entries);
333     TAILQ_INSERT_HEAD(&connections, con, entries);
334 }
335 
336 static void
reactivate_watcher(struct ev_loop * loop,struct ev_io * w,const struct Buffer * input_buffer,const struct Buffer * output_buffer)337 reactivate_watcher(struct ev_loop *loop, struct ev_io *w,
338         const struct Buffer *input_buffer,
339         const struct Buffer *output_buffer) {
340     int events = 0;
341 
342     if (buffer_room(input_buffer))
343         events |= EV_READ;
344 
345     if (buffer_len(output_buffer))
346         events |= EV_WRITE;
347 
348     if (ev_is_active(w)) {
349         if (events == 0)
350             ev_io_stop(loop, w);
351         else if (events != w->events) {
352             ev_io_stop(loop, w);
353             ev_io_set(w, w->fd, events);
354             ev_io_start(loop, w);
355         }
356     } else if (events != 0) {
357         ev_io_set(w, w->fd, events);
358         ev_io_start(loop, w);
359     }
360 }
361 
362 static void
insert_proxy_v1_header(struct Connection * con)363 insert_proxy_v1_header(struct Connection *con) {
364     char buf[INET6_ADDRSTRLEN] = { '\0' };
365     size_t buf_len;
366 
367     con->header_len += buffer_push(con->client.buffer, "PROXY ", 6);
368 
369     switch (con->client.addr.ss_family) {
370         case AF_INET:
371             con->header_len += buffer_push(con->client.buffer, "TCP4 ", 5);
372 
373             inet_ntop(AF_INET,
374                       &((const struct sockaddr_in *)&con->client.addr)->
375                       sin_addr, buf, sizeof(buf));
376             buf_len = strlen(buf);
377             con->header_len += buffer_push(con->client.buffer, buf, buf_len);
378 
379             con->header_len += buffer_push(con->client.buffer, " ", 1);
380 
381             inet_ntop(AF_INET,
382                       &((const struct sockaddr_in *)&con->client.local_addr)->
383                       sin_addr, buf, sizeof(buf));
384             buf_len = strlen(buf);
385             con->header_len += buffer_push(con->client.buffer, buf, buf_len);
386 
387             buf_len = snprintf(buf, sizeof(buf), " %" PRIu16,
388                               ntohs(((const struct sockaddr_in *)&con->
389                               client.addr)->sin_port));
390             con->header_len += buffer_push(con->client.buffer, buf, buf_len);
391 
392             buf_len = snprintf(buf, sizeof(buf), " %" PRIu16,
393                               ntohs(((const struct sockaddr_in *)&con->
394                               client.local_addr)->sin_port));
395             con->header_len += buffer_push(con->client.buffer, buf, buf_len);
396 
397             break;
398         case AF_INET6:
399             con->header_len += buffer_push(con->client.buffer, "TCP6 ", 5);
400             inet_ntop(AF_INET6,
401                     &((const struct sockaddr_in6 *)&con->client.addr)->
402                     sin6_addr, buf, sizeof(buf));
403             buf_len = strlen(buf);
404             con->header_len += buffer_push(con->client.buffer, buf, buf_len);
405 
406             con->header_len += buffer_push(con->client.buffer, " ", 1);
407 
408             inet_ntop(AF_INET6,
409                       &((const struct sockaddr_in6 *)&con->
410                       client.local_addr)->sin6_addr, buf, sizeof(buf));
411             buf_len = strlen(buf);
412             con->header_len += buffer_push(con->client.buffer, buf, buf_len);
413 
414             buf_len = snprintf(buf, sizeof(buf), " %" PRIu16,
415                               ntohs(((const struct sockaddr_in6 *)&con->
416                               client.addr)->sin6_port));
417             con->header_len += buffer_push(con->client.buffer, buf, buf_len);
418 
419             buf_len = snprintf(buf, sizeof(buf), " %" PRIu16,
420                               ntohs(((const struct sockaddr_in6 *)&con->
421                               client.local_addr)->sin6_port));
422             con->header_len += buffer_push(con->client.buffer, buf, buf_len);
423 
424             break;
425         default:
426             con->header_len += buffer_push(con->client.buffer, "UNKNOWN", 7);
427     }
428     con->header_len += buffer_push(con->client.buffer, "\r\n", 2);
429 }
430 
431 static void
parse_client_request(struct Connection * con)432 parse_client_request(struct Connection *con) {
433     const char *payload;
434     size_t payload_len = buffer_coalesce(con->client.buffer, (const void **)&payload);
435     char *hostname = NULL;
436 
437     /* Avoid payload_len underflow and empty request */
438     if (payload_len <= con->header_len)
439         return;
440 
441     payload += con->header_len;
442     payload_len -= con->header_len;
443 
444     int result = con->listener->protocol->parse_packet(payload, payload_len, &hostname);
445     if (result < 0) {
446         char client[INET6_ADDRSTRLEN + 8];
447 
448         if (result == -1) { /* incomplete request */
449             if (buffer_room(con->client.buffer) > 0)
450                 return; /* give client a chance to send more data */
451 
452             warn("Request from %s exceeded %zu byte buffer size",
453                     display_sockaddr(&con->client.addr, client, sizeof(client)),
454                     buffer_size(con->client.buffer));
455         } else if (result == -2) {
456             warn("Request from %s did not include a hostname",
457                     display_sockaddr(&con->client.addr, client, sizeof(client)));
458         } else {
459             warn("Unable to parse request from %s: parse_packet returned %d",
460                     display_sockaddr(&con->client.addr, client, sizeof(client)),
461                     result);
462 
463             if (con->listener->log_bad_requests)
464                 log_bad_request(con, payload, payload_len, result);
465         }
466 
467         if (con->listener->fallback_address == NULL) {
468             abort_connection(con);
469             return;
470         }
471     }
472 
473     con->hostname = hostname;
474     con->hostname_len = (size_t)result;
475     con->state = PARSED;
476 }
477 
478 static void
abort_connection(struct Connection * con)479 abort_connection(struct Connection *con) {
480     assert(client_socket_open(con));
481 
482     buffer_push(con->server.buffer,
483             con->listener->protocol->abort_message,
484             con->listener->protocol->abort_message_len);
485 
486     con->state = SERVER_CLOSED;
487 }
488 
489 static void
resolve_server_address(struct Connection * con,struct ev_loop * loop)490 resolve_server_address(struct Connection *con, struct ev_loop *loop) {
491     struct LookupResult result =
492         listener_lookup_server_address(con->listener, con->hostname, con->hostname_len);
493 
494     if (result.address == NULL) {
495         abort_connection(con);
496         return;
497     } else if (address_is_hostname(result.address)) {
498 #ifndef HAVE_LIBUDNS
499         warn("DNS lookups not supported unless sniproxy compiled with libudns");
500 
501         if (result.caller_free_address)
502             free((void *)result.address);
503 
504         abort_connection(con);
505         return;
506 #else
507         struct resolv_cb_data *cb_data = malloc(sizeof(struct resolv_cb_data));
508         if (cb_data == NULL) {
509             err("%s: malloc", __func__);
510 
511             if (result.caller_free_address)
512                 free((void *)result.address);
513 
514             abort_connection(con);
515             return;
516         }
517         cb_data->connection = con;
518         cb_data->address = result.address;
519         cb_data->cb_free_addr = result.caller_free_address;
520         cb_data->loop = loop;
521         con->use_proxy_header = result.use_proxy_header;
522 
523         int resolv_mode = RESOLV_MODE_DEFAULT;
524         if (con->listener->transparent_proxy) {
525             char listener_address[ADDRESS_BUFFER_SIZE];
526             switch (con->client.addr.ss_family) {
527                 case AF_INET:
528                     resolv_mode = RESOLV_MODE_IPV4_ONLY;
529                     break;
530                 case AF_INET6:
531                     resolv_mode = RESOLV_MODE_IPV6_ONLY;
532                     break;
533                 default:
534                     warn("attempt to use transparent proxy with hostname %s "
535                             "on non-IP listener %s, falling back to "
536                             "non-transparent mode",
537                             address_hostname(result.address),
538                             display_sockaddr(con->listener->address,
539                                     listener_address, sizeof(listener_address))
540                             );
541             }
542         }
543 
544         con->query_handle = resolv_query(address_hostname(result.address),
545                 resolv_mode, resolv_cb,
546                 (void (*)(void *))free_resolv_cb_data, cb_data);
547 
548         con->state = RESOLVING;
549 #endif
550     } else if (address_is_sockaddr(result.address)) {
551         con->server.addr_len = address_sa_len(result.address);
552         assert(con->server.addr_len <= sizeof(con->server.addr));
553         memcpy(&con->server.addr, address_sa(result.address),
554             con->server.addr_len);
555         con->use_proxy_header = result.use_proxy_header;
556 
557         if (result.caller_free_address)
558             free((void *)result.address);
559 
560         con->state = RESOLVED;
561     } else {
562         /* invalid address type */
563         assert(0);
564     }
565 }
566 
567 static void
resolv_cb(struct Address * result,void * data)568 resolv_cb(struct Address *result, void *data) {
569     struct resolv_cb_data *cb_data = (struct resolv_cb_data *)data;
570     struct Connection *con = cb_data->connection;
571     struct ev_loop *loop = cb_data->loop;
572 
573     if (con->state != RESOLVING) {
574         info("resolv_cb() called for connection not in RESOLVING state");
575         return;
576     }
577 
578     if (result == NULL) {
579         notice("unable to resolve %s, closing connection",
580                 address_hostname(cb_data->address));
581         abort_connection(con);
582     } else {
583         assert(address_is_sockaddr(result));
584 
585         /* copy port from server_address */
586         address_set_port(result, address_port(cb_data->address));
587 
588         con->server.addr_len = address_sa_len(result);
589         assert(con->server.addr_len <= sizeof(con->server.addr));
590         memcpy(&con->server.addr, address_sa(result), con->server.addr_len);
591 
592         con->state = RESOLVED;
593 
594         initiate_server_connect(con, loop);
595     }
596 
597     con->query_handle = NULL;
598     reactivate_watchers(con, loop);
599 }
600 
601 static void
free_resolv_cb_data(struct resolv_cb_data * cb_data)602 free_resolv_cb_data(struct resolv_cb_data *cb_data) {
603     if (cb_data->cb_free_addr)
604         free((void *)cb_data->address);
605     free(cb_data);
606 }
607 
608 static void
initiate_server_connect(struct Connection * con,struct ev_loop * loop)609 initiate_server_connect(struct Connection *con, struct ev_loop *loop) {
610 #ifdef HAVE_ACCEPT4
611     int sockfd = socket(con->server.addr.ss_family, SOCK_STREAM | SOCK_NONBLOCK, 0);
612 #else
613     int sockfd = socket(con->server.addr.ss_family, SOCK_STREAM, 0);
614 #endif
615     if (sockfd < 0) {
616         char client[INET6_ADDRSTRLEN + 8];
617         warn("socket failed: %s, closing connection from %s",
618                 strerror(errno),
619                 display_sockaddr(&con->client.addr, client, sizeof(client)));
620         abort_connection(con);
621         return;
622     }
623 
624 #ifndef HAVE_ACCEPT4
625     int flags = fcntl(sockfd, F_GETFL, 0);
626     fcntl(sockfd, F_SETFL, flags | O_NONBLOCK);
627 #endif
628 
629     if (con->listener->transparent_proxy &&
630             con->client.addr.ss_family == con->server.addr.ss_family) {
631         int on = 1;
632 #ifdef IP_TRANSPARENT
633         int result = setsockopt(sockfd, SOL_IP, IP_TRANSPARENT, &on, sizeof(on));
634 #else
635         int result = -EPERM;
636         /* XXX error: not implemented would be better, but this shouldn't be
637          * reached since it is prohibited in the configuration parser. */
638 #endif
639         if (result < 0) {
640             err("setsockopt IP_TRANSPARENT failed: %s", strerror(errno));
641             close(sockfd);
642             abort_connection(con);
643             return;
644         }
645 
646         result = bind(sockfd, (struct sockaddr *)&con->client.addr,
647                 con->client.addr_len);
648         if (result < 0) {
649             err("bind failed: %s", strerror(errno));
650             close(sockfd);
651             abort_connection(con);
652             return;
653         }
654     } else if (con->listener->source_address) {
655         int on = 1;
656         int result = setsockopt(sockfd, SOL_SOCKET, SO_REUSEADDR, &on, sizeof(on));
657         if (result < 0) {
658             err("setsockopt SO_REUSEADDR failed: %s", strerror(errno));
659             close(sockfd);
660             abort_connection(con);
661             return;
662         }
663 
664         int tries = 5;
665         do {
666             result = bind(sockfd,
667                     address_sa(con->listener->source_address),
668                     address_sa_len(con->listener->source_address));
669         } while (tries-- > 0
670                 && result < 0
671                 && errno == EADDRINUSE
672                 && address_port(con->listener->source_address) == 0);
673         if (result < 0) {
674             err("bind failed: %s", strerror(errno));
675             close(sockfd);
676             abort_connection(con);
677             return;
678         }
679     }
680 
681     int result = connect(sockfd,
682             (struct sockaddr *)&con->server.addr,
683             con->server.addr_len);
684     /* TODO retry connect in EADDRNOTAVAIL case */
685     if (result < 0 && errno != EINPROGRESS) {
686         close(sockfd);
687         char server[INET6_ADDRSTRLEN + 8];
688         warn("Failed to open connection to %s: %s",
689                 display_sockaddr(&con->server.addr, server, sizeof(server)),
690                 strerror(errno));
691         abort_connection(con);
692         return;
693     }
694 
695     if (getsockname(sockfd, (struct sockaddr *)&con->server.local_addr,
696                 &con->server.local_addr_len) != 0) {
697         close(sockfd);
698         warn("getsockname failed: %s", strerror(errno));
699 
700         abort_connection(con);
701         return;
702     }
703 
704     if (con->header_len && !con->use_proxy_header) {
705         /* If we prepended the PROXY header and this backend isn't configured
706          * to receive it, consume it now */
707         buffer_pop(con->client.buffer, NULL, con->header_len);
708     }
709 
710     struct ev_io *server_watcher = &con->server.watcher;
711     ev_io_init(server_watcher, connection_cb, sockfd, EV_WRITE);
712     con->server.watcher.data = con;
713     con->state = CONNECTED;
714 
715     ev_io_start(loop, server_watcher);
716 }
717 
718 /* Close client socket.
719  * Caller must ensure that it has not been closed before.
720  */
721 static void
close_client_socket(struct Connection * con,struct ev_loop * loop)722 close_client_socket(struct Connection *con, struct ev_loop *loop) {
723     assert(con->state != CLOSED
724             && con->state != CLIENT_CLOSED);
725 
726     ev_io_stop(loop, &con->client.watcher);
727 
728     if (close(con->client.watcher.fd) < 0)
729         warn("close failed: %s", strerror(errno));
730 
731     if (con->state == RESOLVING) {
732         resolv_cancel(con->query_handle);
733         con->state = PARSED;
734     }
735 
736     /* next state depends on previous state */
737     if (con->state == SERVER_CLOSED
738             || con->state == ACCEPTED
739             || con->state == PARSED
740             || con->state == RESOLVING
741             || con->state == RESOLVED)
742         con->state = CLOSED;
743     else
744         con->state = CLIENT_CLOSED;
745 }
746 
747 /* Close server socket.
748  * Caller must ensure that it has not been closed before.
749  */
750 static void
close_server_socket(struct Connection * con,struct ev_loop * loop)751 close_server_socket(struct Connection *con, struct ev_loop *loop) {
752     assert(con->state != CLOSED
753             && con->state != SERVER_CLOSED);
754 
755     ev_io_stop(loop, &con->server.watcher);
756 
757     if (close(con->server.watcher.fd) < 0)
758         warn("close failed: %s", strerror(errno));
759 
760     /* next state depends on previous state */
761     if (con->state == CLIENT_CLOSED)
762         con->state = CLOSED;
763     else
764         con->state = SERVER_CLOSED;
765 }
766 
767 static void
close_connection(struct Connection * con,struct ev_loop * loop)768 close_connection(struct Connection *con, struct ev_loop *loop) {
769     assert(con->state != NEW); /* only used during initialization */
770 
771     if (con->state == CONNECTED
772             || con->state == CLIENT_CLOSED)
773         close_server_socket(con, loop);
774 
775     assert(con->state == ACCEPTED
776             || con->state == PARSED
777             || con->state == RESOLVING
778             || con->state == RESOLVED
779             || con->state == SERVER_CLOSED
780             || con->state == CLOSED);
781 
782     if (con->state == ACCEPTED
783             || con->state == PARSED
784             || con->state == RESOLVING
785             || con->state == RESOLVED
786             || con->state == SERVER_CLOSED)
787         close_client_socket(con, loop);
788 
789     assert(con->state == CLOSED);
790 }
791 
792 /*
793  * Allocate and initialize a new connection
794  */
795 static struct Connection *
new_connection(struct ev_loop * loop)796 new_connection(struct ev_loop *loop) {
797     struct Connection *con = calloc(1, sizeof(struct Connection));
798     if (con == NULL)
799         return NULL;
800 
801     con->state = NEW;
802     con->client.addr_len = sizeof(con->client.addr);
803     con->client.local_addr = (struct sockaddr_storage){.ss_family = AF_UNSPEC};
804     con->client.local_addr_len = sizeof(con->client.local_addr);
805     con->server.addr_len = sizeof(con->server.addr);
806     con->server.local_addr = (struct sockaddr_storage){.ss_family = AF_UNSPEC};
807     con->server.local_addr_len = sizeof(con->server.local_addr);
808     con->hostname = NULL;
809     con->hostname_len = 0;
810     con->header_len = 0;
811     con->query_handle = NULL;
812     con->use_proxy_header = 0;
813 
814     con->client.buffer = new_buffer(4096, loop);
815     if (con->client.buffer == NULL) {
816         free_connection(con);
817         return NULL;
818     }
819 
820     con->server.buffer = new_buffer(4096, loop);
821     if (con->server.buffer == NULL) {
822         free_connection(con);
823         return NULL;
824     }
825 
826     return con;
827 }
828 
829 static void
log_connection(struct Connection * con)830 log_connection(struct Connection *con) {
831     ev_tstamp duration;
832     char client_address[ADDRESS_BUFFER_SIZE];
833     char listener_address[ADDRESS_BUFFER_SIZE];
834     char server_address[ADDRESS_BUFFER_SIZE];
835 
836     if (con->client.buffer->last_recv > con->server.buffer->last_recv)
837         duration = con->client.buffer->last_recv - con->established_timestamp;
838     else
839         duration = con->server.buffer->last_recv - con->established_timestamp;
840 
841     display_sockaddr(&con->client.addr, client_address, sizeof(client_address));
842     display_sockaddr(&con->client.local_addr, listener_address, sizeof(listener_address));
843     display_sockaddr(&con->server.addr, server_address, sizeof(server_address));
844 
845     log_msg(con->listener->access_log,
846            LOG_NOTICE,
847            "%s -> %s -> %s [%.*s] %zu/%zu bytes tx %zu/%zu bytes rx %1.3f seconds",
848            client_address,
849            listener_address,
850            server_address,
851            (int)con->hostname_len,
852            con->hostname,
853            con->server.buffer->tx_bytes,
854            con->server.buffer->rx_bytes,
855            con->client.buffer->tx_bytes,
856            con->client.buffer->rx_bytes,
857            duration);
858 }
859 
/*
 * Emit a debug message showing the request bytes and the parse result,
 * formatted like a call to parse_packet, for example:
 *
 *   parse_packet({0x16, 0x03}, 2, ...) = -2
 *
 * con          - unused
 * req          - request bytes to dump
 * req_len      - number of bytes in req; may be zero, in which case an
 *                empty list "{}" is printed
 * parse_result - value printed after the '='
 *
 * If the message buffer cannot be allocated an error is reported through
 * err() and nothing else is logged.
 */
static void
log_bad_request(struct Connection *con __attribute__((unused)), const char *req, size_t req_len, int parse_result) {
    /* Each byte is rendered as "0xNN, " (6 chars); 64 covers the fixed text */
    size_t message_len = 64 + 6 * req_len;
    char *message = malloc(message_len);
    if (message == NULL) {
        err("log_bad_request: unable to allocate message buffer");
        return;
    }
    char *message_pos = message;
    char *message_end = message + message_len;

    message_pos += snprintf(message_pos, (size_t)(message_end - message_pos),
                            "parse_packet({");

    for (size_t i = 0; i < req_len; i++)
        message_pos += snprintf(message_pos, (size_t)(message_end - message_pos),
                                "0x%02hhx, ", (unsigned char)req[i]);

    /*
     * Delete the trailing ', ' -- but only if at least one byte was
     * written, otherwise this would eat the "({" of the prefix.
     */
    if (req_len > 0)
        message_pos -= 2;

    snprintf(message_pos, (size_t)(message_end - message_pos), "}, %zu, ...) = %d",
             req_len, parse_result);
    debug("%s", message);

    free(message);
}
885 
886 /*
887  * Free a connection and associated data
888  *
889  * Requires that no watchers remain active
890  */
891 static void
free_connection(struct Connection * con)892 free_connection(struct Connection *con) {
893     if (con == NULL)
894         return;
895 
896     listener_ref_put(con->listener);
897     free_buffer(con->client.buffer);
898     free_buffer(con->server.buffer);
899     free((void *)con->hostname); /* cast away const'ness */
900     free(con);
901 }
902 
903 static void
print_connection(FILE * file,const struct Connection * con)904 print_connection(FILE *file, const struct Connection *con) {
905     char client[INET6_ADDRSTRLEN + 8];
906     char server[INET6_ADDRSTRLEN + 8];
907 
908     switch (con->state) {
909         case NEW:
910             fprintf(file, "NEW           -\t-\n");
911             break;
912         case ACCEPTED:
913             fprintf(file, "ACCEPTED      %s %zu/%zu\t-\n",
914                     display_sockaddr(&con->client.addr, client, sizeof(client)),
915                     buffer_len(con->client.buffer), buffer_size(con->client.buffer));
916             break;
917         case PARSED:
918             fprintf(file, "PARSED        %s %zu/%zu\t-\n",
919                     display_sockaddr(&con->client.addr, client, sizeof(client)),
920                     buffer_len(con->client.buffer), buffer_size(con->client.buffer));
921             break;
922         case RESOLVING:
923             fprintf(file, "RESOLVING      %s %zu/%zu\t-\n",
924                     display_sockaddr(&con->client.addr, client, sizeof(client)),
925                     buffer_len(con->client.buffer), buffer_size(con->client.buffer));
926             break;
927         case RESOLVED:
928             fprintf(file, "RESOLVED      %s %zu/%zu\t-\n",
929                     display_sockaddr(&con->client.addr, client, sizeof(client)),
930                     buffer_len(con->client.buffer), buffer_size(con->client.buffer));
931             break;
932         case CONNECTED:
933             fprintf(file, "CONNECTED     %s %zu/%zu\t%s %zu/%zu\n",
934                     display_sockaddr(&con->client.addr, client, sizeof(client)),
935                     buffer_len(con->client.buffer), buffer_size(con->client.buffer),
936                     display_sockaddr(&con->server.addr, server, sizeof(server)),
937                     buffer_len(con->server.buffer), buffer_size(con->server.buffer));
938             break;
939         case SERVER_CLOSED:
940             fprintf(file, "SERVER_CLOSED %s %zu/%zu\t-\n",
941                     display_sockaddr(&con->client.addr, client, sizeof(client)),
942                     buffer_len(con->client.buffer), buffer_size(con->client.buffer));
943             break;
944         case CLIENT_CLOSED:
945             fprintf(file, "CLIENT_CLOSED -\t%s %zu/%zu\n",
946                     display_sockaddr(&con->server.addr, server, sizeof(server)),
947                     buffer_len(con->server.buffer), buffer_size(con->server.buffer));
948             break;
949         case CLOSED:
950             fprintf(file, "CLOSED        -\t-\n");
951             break;
952     }
953 }
954