1 // httpconnection.c -- Manage state machine for HTTP connections
2 // Copyright (C) 2008-2010 Markus Gutschke <markus@shellinabox.com>
3 //
4 // This program is free software; you can redistribute it and/or modify
5 // it under the terms of the GNU General Public License version 2 as
6 // published by the Free Software Foundation.
7 //
8 // This program is distributed in the hope that it will be useful,
9 // but WITHOUT ANY WARRANTY; without even the implied warranty of
10 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
11 // GNU General Public License for more details.
12 //
13 // You should have received a copy of the GNU General Public License along
14 // with this program; if not, write to the Free Software Foundation, Inc.,
15 // 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
16 //
17 // In addition to these license terms, the author grants the following
18 // additional rights:
19 //
20 // If you modify this program, or any covered work, by linking or
21 // combining it with the OpenSSL project's OpenSSL library (or a
22 // modified version of that library), containing parts covered by the
23 // terms of the OpenSSL or SSLeay licenses, the author
24 // grants you additional permission to convey the resulting work.
25 // Corresponding Source for a non-source form of such a combination
26 // shall include the source code for the parts of OpenSSL used as well
27 // as that of the covered work.
28 //
29 // You may at your option choose to remove this additional permission from
30 // the work, or from any part of it.
31 //
32 // It is possible to build this program in a way that it loads OpenSSL
33 // libraries at run-time. If doing so, the following notices are required
34 // by the OpenSSL and SSLeay licenses:
35 //
36 // This product includes software developed by the OpenSSL Project
37 // for use in the OpenSSL Toolkit. (http://www.openssl.org/)
38 //
39 // This product includes cryptographic software written by Eric Young
40 // (eay@cryptsoft.com)
41 //
42 //
43 // The most up-to-date version of this program is always available from
44 // http://shellinabox.com
45 
46 #define _GNU_SOURCE
47 #include "config.h"
48 
49 #include <errno.h>
50 #include <arpa/inet.h>
51 #include <math.h>
52 #include <netdb.h>
53 #include <netinet/in.h>
54 #include <stdio.h>
55 #include <stdlib.h>
56 #include <string.h>
57 #include <sys/poll.h>
58 #include <sys/socket.h>
59 #include <sys/types.h>
60 #include <unistd.h>
61 
62 #ifdef HAVE_ZLIB
63 #include <zlib.h>
64 #endif
65 
// Compatibility shims and small helper macros. They rely on GNU C statement
// expressions; __typeof__ is used instead of typeof because the latter is
// not a keyword when compiling in a strict ISO mode (e.g. -std=c11).

#ifdef HAVE_STRLCAT
// Redirect strncat() to strlcat(), passing c+1 as the size argument and
// yielding the destination pointer, just like strncat() would.
#define strncat(a,b,c) ({ char *_a = (a); strlcat(_a, (b), (c)+1); _a; })
#endif
#ifndef HAVE_ISNAN
// Fallback NaN test: NaN is the only value that compares unequal to itself.
#define isnan(x) ({ __typeof__(x) _x = (x); _x != _x; })
#endif
// Larger of two values; each argument is evaluated exactly once.
#define max(a, b) ({ __typeof__(a) _a = (a); __typeof__(b) _b = (b);          \
                     _a > _b ? _a : _b; })
// ATTR_UNUSED annotates a parameter as intentionally unused, and UNUSED(x)
// is the statement form used inside function bodies. When the compiler
// attribute is available, UNUSED(x) does not need to touch its argument.
#ifdef HAVE_UNUSED
#define ATTR_UNUSED __attribute__((unused))
#define UNUSED(x)   do { } while (0)
#else
#define ATTR_UNUSED
#define UNUSED(x)    do { (void)(x); } while (0)
#endif
81 
82 #include "libhttp/httpconnection.h"
83 #include "logging/logging.h"
84 
// 64KiB. Presumably the cap on buffered request headers; the code that
// enforces it is not in this part of the file -- TODO confirm.
#define MAX_HEADER_LENGTH   (64 * 1024)
// Value handed to serverSetTimeout() whenever the connection shows activity.
// Presumably expressed in seconds (i.e. ten minutes) -- TODO confirm against
// the server implementation.
#define CONNECTION_TIMEOUT  (10 * 60)
87 
httpPromoteToSSL(struct HttpConnection * http,const char * buf,int len)88 static int httpPromoteToSSL(struct HttpConnection *http, const char *buf,
89                             int len) {
90   if (http->ssl->enabled && !http->sslHndl) {
91     debug("[ssl] Switching to SSL (replaying %d+%d bytes)...",
92           http->partialLength, len);
93     if (http->partial && len > 0) {
94       check(http->partial  = realloc(http->partial,
95                                      http->partialLength + len));
96       memcpy(http->partial + http->partialLength, buf, len);
97       http->partialLength += len;
98     }
99     int rc                 = sslPromoteToSSL(
100                                     http->ssl, &http->sslHndl, http->fd,
101                                     http->partial ? http->partial : buf,
102                                     http->partial ? http->partialLength : len);
103     if (http->sslHndl) {
104       check(!rc);
105       // Reset renegotiations count for connections promoted to SSL.
106       http->ssl->renegotiationCount = 0;
107       SSL_set_app_data(http->sslHndl, http);
108     }
109     free(http->partial);
110     http->partialLength    = 0;
111     return rc;
112   } else {
113     errno                  = EINVAL;
114     return -1;
115   }
116 }
117 
httpRead(struct HttpConnection * http,char * buf,ssize_t len)118 static ssize_t httpRead(struct HttpConnection *http, char *buf, ssize_t len) {
119   sslBlockSigPipe();
120   int rc;
121   if (http->sslHndl) {
122     dcheck(!ERR_peek_error());
123     rc                        = SSL_read(http->sslHndl, buf, len);
124     switch (rc) {
125     case 0:
126     case -1:
127       switch (http->lastError = SSL_get_error(http->sslHndl, rc)) {
128       case SSL_ERROR_WANT_READ:
129       case SSL_ERROR_WANT_WRITE:
130         errno                 = EAGAIN;
131         rc                    = -1;
132         break;
133       default:
134         errno                 = EINVAL;
135         break;
136       }
137       ERR_clear_error();
138       break;
139     default:
140       break;
141     }
142     dcheck(!ERR_peek_error());
143 
144     // Shutdown SSL connection, if client initiated renegotiation.
145     if (http->ssl->renegotiationCount > 1) {
146       debug("[ssl] Connection shutdown due to client initiated renegotiation!");
147       rc                     = 0;
148       errno                  = EINVAL;
149     }
150   } else {
151     rc = NOINTR(read(http->fd, buf, len));
152   }
153   sslUnblockSigPipe();
154   if (rc > 0) {
155     serverSetTimeout(httpGetServerConnection(http), CONNECTION_TIMEOUT);
156   }
157   return rc;
158 }
159 
httpWrite(struct HttpConnection * http,const char * buf,ssize_t len)160 static ssize_t httpWrite(struct HttpConnection *http, const char *buf,
161                          ssize_t len) {
162   sslBlockSigPipe();
163   int rc;
164   if (http->sslHndl) {
165     dcheck(!ERR_peek_error());
166     rc                        = SSL_write(http->sslHndl, buf, len);
167     switch (rc) {
168     case 0:
169     case -1:
170       switch (http->lastError = SSL_get_error(http->sslHndl, rc)) {
171       case SSL_ERROR_WANT_READ:
172       case SSL_ERROR_WANT_WRITE:
173         errno                 = EAGAIN;
174         rc                    = -1;
175         break;
176       default:
177         errno                 = EINVAL;
178         break;
179       }
180       ERR_clear_error();
181       break;
182     default:
183       break;
184     }
185     dcheck(!ERR_peek_error());
186   } else {
187     rc = NOINTR(write(http->fd, buf, len));
188   }
189   sslUnblockSigPipe();
190   return rc;
191 }
192 
httpShutdown(struct HttpConnection * http,int how)193 static int httpShutdown(struct HttpConnection *http, int how) {
194   if (http->sslHndl) {
195     if (how != SHUT_RD) {
196       dcheck(!ERR_peek_error());
197       for (int i = 0; i < 10; i++) {
198         sslBlockSigPipe();
199         int rc    = SSL_shutdown(http->sslHndl);
200         int sPipe = sslUnblockSigPipe();
201         if (rc < 1) {
202           // Retry a few times in order to prefer a clean bidirectional
203           // shutdown. But don't bother if the other side already closed
204           // the connection.
205           if (sPipe) {
206             break;
207           }
208         }
209       }
210       sslFreeHndl(&http->sslHndl);
211     }
212   }
213   return shutdown(http->fd, how);
214 }
215 
httpCloseRead(struct HttpConnection * http)216 static void httpCloseRead(struct HttpConnection *http) {
217   if (!http->closed) {
218     httpShutdown(http, SHUT_RD);
219     http->closed = 1;
220   }
221 }
222 
#ifndef HAVE_STRCASESTR
// Fallback for platforms without strcasestr(): returns a pointer to the first
// case-insensitive occurrence of "needle" in "haystack", or NULL if there is
// none. An empty needle matches at the start of the haystack.
static char *strcasestr(const char *haystack, const char *needle) {
  // This algorithm is O(len(haystack)*len(needle)). Much better algorithms
  // are available, but this code is much simpler and performance is not
  // critical for our workloads.
  size_t len = strlen(needle);
  do {
    if (!strncasecmp(haystack, needle, len)) {
      // As with the library function, the result is not const-qualified even
      // though the argument is; the cast makes that explicit.
      return (char *)haystack;
    }
  } while (*haystack++);
  return NULL;
}
#endif
237 
httpFinishCommand(struct HttpConnection * http)238 static int httpFinishCommand(struct HttpConnection *http) {
239   int rc            = HTTP_DONE;
240   if ((http->callback || http->websocketHandler) && !http->done) {
241     rc              = http->callback ? http->callback(http, http->arg, NULL, 0)
242        : http->websocketHandler(http, http->arg, WS_CONNECTION_CLOSED, NULL,0);
243     check(rc != HTTP_SUSPEND);
244     check(rc != HTTP_PARTIAL_REPLY);
245     http->callback  = NULL;
246     http->arg       = NULL;
247     if (rc == HTTP_ERROR) {
248       httpCloseRead(http);
249     }
250   }
251   if (!http->closed) {
252     const char *con = getFromHashMap(&http->header, "connection");
253     if ((con && strcasestr(con, "close")) ||
254         !http->version || strcmp(http->version, "HTTP/1.1") < 0) {
255       httpCloseRead(http);
256     }
257   }
258   if (logIsInfo()) {
259     check(http->method);
260     check(http->path);
261     check(http->version);
262     if (http->peerName) {
263       time_t t      = currentTime;
264       struct tm *ltime;
265       check (ltime  = localtime(&t));
266       char timeBuf[80];
267       char lengthBuf[40];
268       check(strftime(timeBuf, sizeof(timeBuf),
269                      "[%d/%b/%Y:%H:%M:%S %z]", ltime));
270       if (http->totalWritten > 0) {
271         snprintf(lengthBuf, sizeof(lengthBuf), "%d", http->totalWritten);
272       } else {
273         *lengthBuf  = '\000';
274         strncat(lengthBuf, "-", sizeof(lengthBuf)-1);
275       }
276       info("[http] %s - - %s \"%s %s %s\" %d %s",
277            http->peerName, timeBuf, http->method, http->path, http->version,
278            http->code, lengthBuf);
279     }
280   }
281   return rc;
282 }
283 
// Destructor registered with the header hash map (see initHashMap() calls in
// this file): releases the storage of one key/value pair. "arg" is not used.
static void httpDestroyHeaders(void *arg ATTR_UNUSED, char *key, char *value) {
  UNUSED(arg);
  free(value);
  free(key);
}
289 
getPeerName(int fd,int * port,int numericHosts)290 static char *getPeerName(int fd, int *port, int numericHosts) {
291   struct sockaddr peerAddr;
292   socklen_t sockLen = sizeof(peerAddr);
293   if (getpeername(fd, &peerAddr, &sockLen)) {
294     if (port) {
295       *port         = -1;
296     }
297     return NULL;
298   }
299   char *ret;
300   if (peerAddr.sa_family == AF_UNIX) {
301     if (port) {
302       *port         = 0;
303     }
304     check(ret       = strdup("localhost"));
305     return ret;
306   }
307   char host[256];
308   if (numericHosts ||
309       getnameinfo(&peerAddr, sockLen, host, sizeof(host), NULL, 0, NI_NOFQDN)){
310     check(inet_ntop(peerAddr.sa_family,
311                     &((struct sockaddr_in *)&peerAddr)->sin_addr,
312                     host, sizeof(host)));
313   }
314   if (port) {
315     *port           = ntohs(((struct sockaddr_in *)&peerAddr)->sin_port);
316   }
317   check(ret         = strdup(host));
318   return ret;
319 }
320 
httpSetState(struct HttpConnection * http,int state)321 static void httpSetState(struct HttpConnection *http, int state) {
322   if (state == (int)http->state) {
323     return;
324   }
325 
326   if (state == COMMAND) {
327     if (http->state != SNIFFING_SSL) {
328       int rc                 = httpFinishCommand(http);
329       check(rc != HTTP_SUSPEND);
330       check(rc != HTTP_PARTIAL_REPLY);
331     }
332     check(!http->private);
333     free(http->url);
334     free(http->method);
335     free(http->path);
336     free(http->matchedPath);
337     free(http->pathInfo);
338     free(http->query);
339     free(http->version);
340     http->done               = 0;
341     http->url                = NULL;
342     http->method             = NULL;
343     http->path               = NULL;
344     http->matchedPath        = NULL;
345     http->pathInfo           = NULL;
346     http->query              = NULL;
347     http->version            = NULL;
348     destroyHashMap(&http->header);
349     initHashMap(&http->header, httpDestroyHeaders, NULL);
350     http->headerLength       = 0;
351     http->callback           = NULL;
352     http->arg                = NULL;
353     http->totalWritten       = 0;
354     http->code               = 200;
355   }
356   http->state                = state;
357 }
358 
newHttpConnection(struct Server * server,int fd,int port,struct SSLSupport * ssl,int numericHosts)359 struct HttpConnection *newHttpConnection(struct Server *server, int fd,
360                                          int port, struct SSLSupport *ssl,
361                                          int numericHosts) {
362   struct HttpConnection *http;
363   check(http = malloc(sizeof(struct HttpConnection)));
364   initHttpConnection(http, server, fd, port, ssl, numericHosts);
365   return http;
366 }
367 
initHttpConnection(struct HttpConnection * http,struct Server * server,int fd,int port,struct SSLSupport * ssl,int numericHosts)368 void initHttpConnection(struct HttpConnection *http, struct Server *server,
369                         int fd, int port, struct SSLSupport *ssl,
370                         int numericHosts) {
371   http->server             = server;
372   http->connection         = NULL;
373   http->fd                 = fd;
374   http->port               = port;
375   http->closed             = 0;
376   http->isSuspended        = 0;
377   http->isPartialReply     = 0;
378   http->done               = 0;
379   http->state              = ssl ? SNIFFING_SSL : COMMAND;
380   http->peerName           = getPeerName(fd, &http->peerPort, numericHosts);
381   http->url                = NULL;
382   http->method             = NULL;
383   http->path               = NULL;
384   http->matchedPath        = NULL;
385   http->pathInfo           = NULL;
386   http->query              = NULL;
387   http->version            = NULL;
388   initHashMap(&http->header, httpDestroyHeaders, NULL);
389   http->headerLength       = 0;
390   http->key                = NULL;
391   http->partial            = NULL;
392   http->partialLength      = 0;
393   http->msg                = NULL;
394   http->msgLength          = 0;
395   http->msgOffset          = 0;
396   http->totalWritten       = 0;
397   http->expecting          = 0;
398   http->websocketType      = WS_UNDEFINED;
399   http->callback           = NULL;
400   http->websocketHandler   = NULL;
401   http->arg                = NULL;
402   http->private            = NULL;
403   http->code               = 200;
404   http->ssl                = ssl;
405   http->sslHndl            = NULL;
406   http->lastError          = 0;
407   if (logIsInfo()) {
408     debug("[http] Accepted connection from %s:%d",
409           http->peerName ? http->peerName : "???", http->peerPort);
410   }
411 }
412 
destroyHttpConnection(struct HttpConnection * http)413 void destroyHttpConnection(struct HttpConnection *http) {
414   if (http) {
415     if (http->isSuspended || http->isPartialReply) {
416       if (!http->done) {
417         if (http->callback) {
418           http->callback(http, http->arg, NULL, 0);
419         } else if (http->websocketHandler) {
420           http->websocketHandler(http, http->arg, WS_CONNECTION_CLOSED,NULL,0);
421         }
422       }
423       http->callback       = NULL;
424       http->isSuspended    = 0;
425       http->isPartialReply = 0;
426     }
427     httpSetState(http, COMMAND);
428     if (logIsInfo()) {
429       debug("[http] Closing connection to %s:%d",
430             http->peerName ? http->peerName : "???", http->peerPort);
431     }
432     httpShutdown(http, http->closed ? SHUT_WR : SHUT_RDWR);
433     dcheck(!close(http->fd) || errno != EBADF);
434     free(http->peerName);
435     free(http->url);
436     free(http->method);
437     free(http->path);
438     free(http->matchedPath);
439     free(http->pathInfo);
440     free(http->query);
441     free(http->version);
442     destroyHashMap(&http->header);
443     free(http->partial);
444     free(http->msg);
445   }
446 }
447 
// Counterpart to newHttpConnection(): tears the connection down with
// destroyHttpConnection() and frees the object. Passing NULL does nothing.
void deleteHttpConnection(struct HttpConnection *http) {
  if (!http) {
    return;
  }
  destroyHttpConnection(http);
  free(http);
}
452 
#ifdef HAVE_ZLIB
// Returns non-zero for the whitespace characters that are skipped while
// parsing the "accept-encoding" header.
static int httpIsHeaderSpace(char ch) {
  return ch == ' ' || ch == '\t' || ch == '\r' || ch == '\n';
}

// Checks the request's "accept-encoding" header to find out whether the
// client accepts "encoding" (compared case-insensitively).
//
// The header is read as a comma-separated list of entries, each of which can
// carry a ";q=<value>" parameter. The value defaults to 1.0 and is clamped to
// the range 0..1. An entry naming the encoding takes precedence over a "*"
// entry; if an encoding is listed more than once, the last entry wins.
//
// Returns non-zero if the applicable value is greater than zero. Returns 0 if
// the header is missing or lists neither the encoding nor "*".
static int httpAcceptsEncoding(struct HttpConnection *http,
                               const char *encoding) {
  int encodingLength   = strlen(encoding);
  const char *cursor   = getFromHashMap(&http->header, "accept-encoding");
  if (!cursor) {
    return 0;
  }

  // A negative value means "not listed".
  double wildcardQ     = -1.0;
  double encodingQ     = -1.0;
  while (*cursor) {
    while (httpIsHeaderSpace(*cursor)) {
      cursor++;
    }

    // Find the end of the encoding's name.
    const char *scan   = cursor;
    while (*scan && *scan != ',' && *scan != ';' &&
           !httpIsHeaderSpace(*scan)) {
      scan++;
    }
    int isWildcard     = scan - cursor == 1 && *cursor == '*';
    int isEncoding     = scan - cursor == encodingLength &&
                         !strncasecmp(cursor, encoding, encodingLength);

    // Skip ahead to the parameter, or to the end of this entry.
    while (*scan && *scan != ';' && *scan != ',') {
      scan++;
    }
    double q           = 1.0;
    if (*scan == ';') {
      scan++;
      while (httpIsHeaderSpace(*scan)) {
        scan++;
      }
      // Accept both "q" and "Q".
      if ((*scan | 0x20) == 'q') {
        scan++;
        while (httpIsHeaderSpace(*scan)) {
          scan++;
        }
        if (*scan == '=') {
          q            = strtod(scan + 1, (char **)&scan);
        }
      }
    }
    if (isnan(q) || q == -HUGE_VAL || q < 0) {
      q                = 0;
    } else if (q == HUGE_VAL || q > 1.0) {
      q                = 1.0;
    }

    if (isWildcard) {
      wildcardQ        = q;
    } else if (isEncoding) {
      encodingQ        = q;
    }

    // Continue with the entry following the next comma(s).
    while (*scan && *scan != ',') {
      scan++;
    }
    while (*scan == ',') {
      scan++;
    }
    cursor             = scan;
  }
  return (encodingQ >= 0.0 ? encodingQ : wildcardQ) > 0.0;
}
#endif
521 
// Removes every header line that starts with "id" (compared
// case-insensitively) from the buffer "header", together with the line's
// continuation lines. The remaining data is moved up in place and
// "*headerLength" is reduced accordingly.
//
// "header" holds "*headerLength" bytes and does not have to be
// NUL-terminated. "id" must include the colon, e.g. "content-length:".
static void removeHeader(char *header, int *headerLength, const char *id) {
  check(header);
  check(headerLength);
  check(*headerLength >= 0);
  check(id);
  check(strchr(id, ':'));
  int idLength       = strlen(id);
  if (idLength <= 0) {
    return;
  }
  for (char *ptr = header; header + *headerLength - ptr >= idLength; ) {
    // Find the end of the logical header line starting at "ptr". Lines that
    // begin with a space or a tab are continuations of the previous line;
    // httpTransfer() recognizes both characters, so both are honored here.
    char *end        = ptr;
    do {
      end            = memchr(end, '\n', header + *headerLength - end);
      if (end == NULL) {
        end          = header + *headerLength;
      } else {
        ++end;
      }
    } while (end < header + *headerLength && (*end == ' ' || *end == '\t'));
    if (!strncasecmp(ptr, id, idLength)) {
      // Drop the line. "ptr" stays put, as it now points to the next line.
      memmove(ptr, end, header + *headerLength - end);
      *headerLength -= end - ptr;
    } else {
      ptr            = end;
    }
  }
}
550 
// Appends a printf()-style formatted line to the heap-allocated header block
// "*header" of "*headerLength" bytes. If the block currently ends in CR LF,
// that CR LF is dropped first; after the new text, a CR LF is written again.
// The block is reallocated as needed and "*headerLength" is updated.
//
// "fmt" must contain a CR LF sequence.
static void addHeader(char **header, int *headerLength, const char *fmt, ...) {
  check(header);
  check(headerLength);
  check(*headerLength >= 0);
  check(strstr(fmt, "\r\n"));

  va_list args;
  va_start(args, fmt);
  char *line       = vStringPrintf(NULL, fmt, args);
  va_end(args);
  int lineLength   = strlen(line);

  // Number of bytes to keep from the existing block.
  int kept         = *headerLength;
  if (kept >= 2 && !memcmp(*header + kept - 2, "\r\n", 2)) {
    kept          -= 2;
  }
  *headerLength    = kept;
  check(*header    = realloc(*header, kept + lineLength + 2));

  char *dst        = *header + kept;
  memcpy(dst, line, lineLength);
  memcpy(dst + lineLength, "\r\n", 2);
  *headerLength    = kept + lineLength + 2;
  free(line);
}
573 
httpTransfer(struct HttpConnection * http,char * msg,int len)574 void httpTransfer(struct HttpConnection *http, char *msg, int len) {
575   check(msg);
576   check(len >= 0);
577 
578   char *header              = NULL;
579   int headerLength          = 0;
580   int bodyOffset            = 0;
581 
582   int compress              = 0;
583   if (!http->totalWritten) {
584     // Perform some basic sanity checks. This does not necessarily catch all
585     // possible problems, though.
586     int l                   = len;
587     char *line              = msg;
588     for (char *eol, *lastLine = NULL;
589          l > 0 && (eol = memchr(line, '\n', l)) != NULL; ) {
590       // All lines end in CR LF
591       check(eol[-1] == '\r');
592       if (!lastLine) {
593         // The first line looks like "HTTP/1.x STATUS\r\n"
594         check(eol - line > 11);
595         check(!memcmp(line, "HTTP/1.", 7));
596         check(line[7] >= '0' && line[7] <= '9' &&
597               (line[8] == ' ' || line[8] == '\t'));
598         int i               = eol - line - 9;
599         for (char *ptr = line + 9; i-- > 0; ) {
600           char ch           = *ptr++;
601           if (ch < '0' || ch > '9') {
602             check(ptr > line + 10);
603             check(ch == ' ' || ch == '\t');
604             break;
605           }
606         }
607         check(i > 1);
608       } else if (line + 1 == eol) {
609         // Found the end of the headers.
610 
611         // Check that we don't send any data with HEAD requests
612         int isHead          = http->method && !strcmp(http->method, "HEAD");
613         check(l == 2 || !isHead);
614 
615         #ifdef HAVE_ZLIB
616         // Compress replies that might exceed the size of a single IP packet
617         compress            = !isHead &&
618                               !http->isPartialReply &&
619                               len > 1400 &&
620                               httpAcceptsEncoding(http, "gzip");
621         #endif
622         break;
623       } else {
624         // Header lines either contain a colon, or they are continuation
625         // lines
626         if (*line != ' ' && *line != '\t') {
627           check(memchr(line, ':', eol - line));
628         }
629       }
630       lastLine              = line;
631       l                    -= eol - line + 1;
632       line                  = eol + 1;
633     }
634 
635     if (compress) {
636       if (l >= 2 && !memcmp(line, "\r\n", 2)) {
637         line               += 2;
638         l                  -= 2;
639       }
640       headerLength          = line - msg;
641       bodyOffset            = headerLength;
642       check(header          = malloc(headerLength));
643       memcpy(header, msg, headerLength);
644     }
645 
646     if (compress) {
647       #ifdef HAVE_ZLIB
648       // Compress the message
649       char *compressed;
650       check(compressed      = malloc(len));
651       check(len >= bodyOffset + 2);
652       z_stream strm         = { .zalloc    = Z_NULL,
653                                 .zfree     = Z_NULL,
654                                 .opaque    = Z_NULL,
655                                 .avail_in  = l,
656                                 .next_in   = (unsigned char *)line,
657                                 .avail_out = len,
658                                 .next_out  = (unsigned char *)compressed
659                               };
660       if (deflateInit2(&strm, Z_DEFAULT_COMPRESSION, Z_DEFLATED,
661                        31, 8, Z_DEFAULT_STRATEGY) == Z_OK) {
662         if (deflate(&strm, Z_FINISH) == Z_STREAM_END) {
663           // Compression was successful and resulted in reduction in size
664           debug("[http] Compressed response from %d to %d", len, len-strm.avail_out);
665           free(msg);
666           msg               = compressed;
667           len              -= strm.avail_out;
668           bodyOffset        = 0;
669           removeHeader(header, &headerLength, "content-length:");
670           removeHeader(header, &headerLength, "content-encoding:");
671           addHeader(&header, &headerLength, "Content-Length: %d\r\n", len);
672           addHeader(&header, &headerLength, "Content-Encoding: gzip\r\n");
673         } else {
674           free(compressed);
675         }
676         deflateEnd(&strm);
677       } else {
678         free(compressed);
679       }
680       #endif
681     }
682   }
683 
684   http->totalWritten       += headerLength + (len - bodyOffset);
685   if (!headerLength) {
686     free(header);
687   } else if (http->msg) {
688     check(http->msg         = realloc(http->msg,
689                                       http->msgLength - http->msgOffset +
690                                       max(http->msgOffset, headerLength)));
691     if (http->msgOffset) {
692       memmove(http->msg, http->msg + http->msgOffset,
693               http->msgLength - http->msgOffset);
694       http->msgLength      -= http->msgOffset;
695       http->msgOffset       = 0;
696     }
697     memcpy(http->msg + http->msgLength, header, headerLength);
698     http->msgLength        += headerLength;
699     free(header);
700   } else {
701     check(!http->msgOffset);
702     http->msg               = header;
703     http->msgLength         = headerLength;
704   }
705 
706   if (len <= bodyOffset) {
707     free(msg);
708   } else if (http->msg) {
709     check(http->msg         = realloc(http->msg,
710                                       http->msgLength - http->msgOffset +
711                                       max(http->msgOffset, len - bodyOffset)));
712     if (http->msgOffset) {
713       memmove(http->msg, http->msg + http->msgOffset,
714               http->msgLength - http->msgOffset);
715       http->msgLength      -= http->msgOffset;
716       http->msgOffset       = 0;
717     }
718     memcpy(http->msg + http->msgLength, msg + bodyOffset, len - bodyOffset);
719     http->msgLength        += len - bodyOffset;
720     free(msg);
721   } else {
722     check(!http->msgOffset);
723     if (bodyOffset) {
724       memmove(msg, msg + bodyOffset, len - bodyOffset);
725     }
726     http->msg               = msg;
727     http->msgLength         = len - bodyOffset;
728   }
729 
730   // The caller can suspend the connection, so that it can send an
731   // asynchronous reply. Once the reply has been sent, the connection
732   // gets reactivated. Normally, this means it would go back to listening
733   // for commands.
734   // Similarly, the caller can indicate that this is a partial message and
735   // return additional data in subsequent calls to the callback handler.
736   if (http->isSuspended || http->isPartialReply) {
737     if (http->msg && http->msgLength > 0) {
738       int wrote             = httpWrite(http, http->msg, http->msgLength);
739       if (wrote < 0 && errno != EAGAIN) {
740         httpCloseRead(http);
741         free(http->msg);
742         http->msgLength     = 0;
743         http->msg           = NULL;
744       } else if (wrote > 0) {
745         if (wrote == http->msgLength) {
746           free(http->msg);
747           http->msgLength   = 0;
748           http->msg         = NULL;
749         } else {
750           memmove(http->msg, http->msg + wrote, http->msgLength - wrote);
751           http->msgLength  -= wrote;
752         }
753       }
754     }
755 
756     check(http->state == PAYLOAD || http->state == DISCARD_PAYLOAD);
757     if (!http->isPartialReply) {
758       if (http->expecting < 0) {
759         // If we do not know the length of the content, close the connection.
760         debug("[http] Closing previously suspended connection!");
761         httpCloseRead(http);
762         httpSetState(http, DISCARD_PAYLOAD);
763       } else if (http->expecting == 0) {
764         httpSetState(http, COMMAND);
765         http->isSuspended  = 0;
766         struct ServerConnection *connection = httpGetServerConnection(http);
767         if (!serverGetTimeout(connection)) {
768           serverSetTimeout(connection, CONNECTION_TIMEOUT);
769         }
770         serverConnectionSetEvents(http->server, connection, http->fd,
771                                   http->msgLength ? POLLIN|POLLOUT : POLLIN);
772       }
773     }
774   }
775 }
776 
httpTransferPartialReply(struct HttpConnection * http,char * msg,int len)777 void httpTransferPartialReply(struct HttpConnection *http, char *msg, int len){
778   check(!http->isSuspended);
779   http->isPartialReply = 1;
780   if (http->state != PAYLOAD && http->state != DISCARD_PAYLOAD) {
781     check(http->state == HEADERS);
782     httpSetState(http, PAYLOAD);
783   }
784   httpTransfer(http, msg, len);
785 }
786 
httpHandleCommand(struct HttpConnection * http,const struct Trie * handlers)787 static int httpHandleCommand(struct HttpConnection *http,
788                              const struct Trie *handlers) {
789   debug("[http] Handling \"%s\" \"%s\"", http->method, http->path);
790   const char *contentLength                  = getFromHashMap(&http->header,
791                                                              "content-length");
792   if (contentLength != NULL && *contentLength) {
793     char *endptr;
794     http->expecting                          = strtol(contentLength,
795                                                       &endptr, 10);
796     if (*endptr) {
797       // Invalid length. Read until end of stream and then close
798       // connection.
799       http->expecting                        = -1;
800     }
801   } else {
802       // Unknown length. Read until end of stream and then close
803       // connection.
804     http->expecting                          = -1;
805   }
806   if (!strcmp(http->method, "OPTIONS")) {
807     char *response                           = stringPrintf(NULL,
808                                                 "HTTP/1.1 200 OK\r\n"
809                                                 "Content-Length: 0\r\n"
810                                                 "Allow: GET, POST, OPTIONS\r\n"
811                                                 "\r\n");
812     httpTransfer(http, response, strlen(response));
813     if (http->expecting < 0) {
814       http->expecting                        = 0;
815     }
816     return HTTP_READ_MORE;
817   } else if (!strcmp(http->method, "GET")) {
818     if (http->expecting < 0) {
819       http->expecting                        = 0;
820     }
821   } else if (!strcmp(http->method, "POST")) {
822   } else if (!strcmp(http->method, "HEAD")) {
823     if (http->expecting < 0) {
824       http->expecting                        = 0;
825     }
826   } else if (!strcmp(http->method, "PUT")    ||
827              !strcmp(http->method, "DELETE") ||
828              !strcmp(http->method, "TRACE")  ||
829              !strcmp(http->method, "CONNECT")) {
830     httpSendReply(http, 405, "Method Not Allowed", NO_MSG);
831     return HTTP_DONE;
832   } else {
833     httpSendReply(http, 501, "Method Not Implemented", NO_MSG);
834     return HTTP_DONE;
835   }
836   const char *host                           = getFromHashMap(&http->header,
837                                                               "host");
838   if (host) {
839     for (char ch, *ptr = (char *)host; (ch = *ptr) != '\000'; ptr++) {
840       if (ch == ':') {
841         *ptr                                 = '\000';
842         break;
843       }
844       if (ch != '-' && ch != '.' &&
845           (ch < '0' ||(ch > '9' && ch < 'A') ||
846           (ch > 'Z' && ch < 'a')||(ch > 'z' && ch <= 0x7E))) {
847         httpSendReply(http, 400, "Bad Request", NO_MSG);
848         return HTTP_DONE;
849       }
850     }
851   }
852 
853   char *diff;
854   struct HttpHandler *h = (struct HttpHandler *)getFromTrie(handlers,
855                                                             http->path, &diff);
856 
857   if (h) {
858     if (h->websocketHandler) {
859       // Check for WebSocket handshake
860       const char *upgrade                    = getFromHashMap(&http->header,
861                                                               "upgrade");
862       if (upgrade && !strcmp(upgrade, "WebSocket")) {
863         const char *connection               = getFromHashMap(&http->header,
864                                                               "connection");
865         if (connection && !strcmp(connection, "Upgrade")) {
866           const char *origin                 = getFromHashMap(&http->header,
867                                                               "origin");
868           if (origin) {
869             for (const char *ptr = origin; *ptr; ptr++) {
870               if ((unsigned char)*ptr < ' ') {
871                 goto bad_ws_upgrade;
872               }
873             }
874 
875             const char *protocol             = getFromHashMap(&http->header,
876                                                          "websocket-protocol");
877             if (protocol) {
878               for (const char *ptr = protocol; *ptr; ptr++) {
879                 if ((unsigned char)*ptr < ' ') {
880                   goto bad_ws_upgrade;
881                 }
882               }
883             }
884             char *port                       = NULL;
885             if (http->port != (http->sslHndl ? 443 : 80)) {
886               port                           = stringPrintf(NULL,
887                                                             ":%d", http->port);
888             }
889             char *response                   = stringPrintf(NULL,
890               "HTTP/1.1 101 Web Socket Protocol Handshake\r\n"
891               "Upgrade: WebSocket\r\n"
892               "Connection: Upgrade\r\n"
893               "WebSocket-Origin: %s\r\n"
894               "WebSocket-Location: %s://%s%s%s\r\n"
895               "%s%s%s"
896               "\r\n",
897               origin,
898               http->sslHndl ? "wss" : "ws", host && *host ? host : "localhost",
899               port ? port : "", http->path,
900               protocol ? "WebSocket-Protocol: " : "",
901               protocol ? protocol : "",
902               protocol ? "\r\n" : "");
903             free(port);
904             debug("[http] Switching to WebSockets");
905             httpTransfer(http, response, strlen(response));
906             if (http->expecting < 0) {
907               http->expecting                = 0;
908             }
909             http->websocketHandler           = h->websocketHandler;
910             httpSetState(http, WEBSOCKET);
911             return HTTP_READ_MORE;
912           }
913         }
914       }
915     }
916   bad_ws_upgrade:;
917 
918     if (h->handler) {
919       check(diff);
920       while (diff > http->path && diff[-1] == '/') {
921         diff--;
922       }
923       if (!*diff || *diff == '/' || *diff == '?' || *diff == '#') {
924         check(!http->matchedPath);
925         check(!http->pathInfo);
926         check(!http->query);
927 
928         check(http->matchedPath              = malloc(diff - http->path + 1));
929         memcpy(http->matchedPath, http->path, diff - http->path);
930         http->matchedPath[diff - http->path] = '\000';
931 
932         const char *query = strchr(diff, '?');
933         if (*diff && *diff != '?') {
934           const char *endOfInfo              = query
935                                                ? query : strrchr(diff, '\000');
936           check(http->pathInfo               = malloc(endOfInfo - diff + 1));
937           memcpy(http->pathInfo, diff, endOfInfo - diff);
938           http->pathInfo[endOfInfo - diff]   = '\000';
939         }
940 
941         if (query) {
942           check(http->query                  = strdup(query + 1));
943         }
944         return h->handler(http, h->arg);
945       }
946     }
947   }
948   httpSendReply(http, 404, "File Not Found", NO_MSG);
949   return HTTP_DONE;
950 }
951 
httpGetChar(struct HttpConnection * http,const char * buf,int size,int * offset)952 static int httpGetChar(struct HttpConnection *http, const char *buf,
953                        int size, int *offset) {
954   if (*offset < 0) {
955     return (unsigned char)http->partial[http->partialLength + (*offset)++];
956   } else if (*offset < size) {
957     return (unsigned char)buf[(*offset)++];
958   } else {
959     return -1;
960   }
961 }
962 
httpParseCommand(struct HttpConnection * http,int offset,const char * buf,int bytes,int firstSpace,int lastSpace,int lineLength)963 static int httpParseCommand(struct HttpConnection *http, int offset,
964                             const char *buf, int bytes, int firstSpace,
965                             int lastSpace, int lineLength) {
966   if (firstSpace < 1 || lastSpace < 0) {
967   bad_request:
968     if (!http->method) {
969       check(http->method  = strdup(""));
970     }
971     if (!http->path) {
972       check(http->path    = strdup(""));
973     }
974     if (!http->version) {
975       check(http->version = strdup(""));
976     }
977     httpSendReply(http, 400, "Bad Request", NO_MSG);
978     httpSetState(http, COMMAND);
979     return 0;
980   }
981   check(!http->method);
982   check(http->method      = malloc(firstSpace + 1));
983   int i                   = offset;
984   int j                   = 0;
985   for (; j < firstSpace; j++) {
986     int ch                = httpGetChar(http, buf, bytes, &i);
987     if (ch >= 'a' && ch <= 'z') {
988       ch                 &= ~0x20;
989     }
990     http->method[j]       = ch;
991   }
992   http->method[j]         = '\000';
993   check(!http->path);
994   check(http->path        = malloc(lastSpace - firstSpace));
995   j                       = 0;
996   while (i < offset + lastSpace) {
997     int ch                = httpGetChar(http, buf, bytes, &i);
998     if ((ch != ' ' && ch != '\t') || j) {
999       http->path[j++]     = ch;
1000     }
1001   }
1002   http->path[j]           = '\000';
1003   if (*http->path != '/' &&
1004       (strcmp(http->method, "OPTIONS") || strcmp(http->path, "*"))) {
1005     goto bad_request;
1006   }
1007   check(!http->version);
1008   check(http->version     = malloc(lineLength - lastSpace + 1));
1009   j                       = 0;
1010   while (i < offset + lineLength) {
1011     int ch                = httpGetChar(http, buf, bytes, &i);
1012     if (ch == '\r') {
1013       break;
1014     }
1015     if (ch >= 'a' && ch <= 'z') {
1016       ch                 &= ~0x20;
1017     }
1018     if ((ch != ' ' && ch != '\t') || j) {
1019       http->version[j]    = ch;
1020       j++;
1021     }
1022   }
1023   http->version[j]        = '\000';
1024   if (memcmp(http->version, "HTTP/", 5) ||
1025       (http->version[5] < '1' || http->version[5] > '9')) {
1026     goto bad_request;
1027   }
1028   httpSetState(http, HEADERS);
1029   return 1;
1030 }
1031 
// Processes one line of the header section that begins at "offset".
// Characters are fetched through httpGetChar(). "colon" is the position of
// the first ':' in the line (or negative if there is none) and "lineLength"
// the number of characters in the line.
//
// Depending on the first character, the line is treated as
//  - a continuation line (starts with space or tab): its text is appended,
//    after a single space, to the value stored under http->key;
//  - the end of the headers ("\r\n", "\n" or no more data): the request is
//    dispatched with httpHandleCommand() and the connection is updated
//    according to the returned HTTP_xxx code;
//  - a "key: value" line: the key is lower-cased and the value is stored in
//    http->header, unless that key is already present.
//
// Returns 0 after replying with "400 Bad Request" to a line that is none of
// the above and has no colon past its first character; returns 1 otherwise.
static int httpParseHeaders(struct HttpConnection *http,
                            const struct Trie *handlers, int offset,
                            const char *buf, int bytes, int colon,
                            int lineLength) {
  int i                    = offset;
  int ch                   = httpGetChar(http, buf, bytes, &i);
  if (ch == ' ' || ch == '\t') {
    // Continuation line. It is silently ignored when there is no current
    // key (http->key is reset to NULL when a duplicate header is dropped).
    if (http->key) {
      char **oldValue      = getRefFromHashMap(&http->header, http->key);
      check(oldValue);
      int oldLength        = strlen(*oldValue);
      check(*oldValue      = realloc(*oldValue,
                                    oldLength + lineLength + 1));
      int j                = oldLength;
      int end              = oldLength + lineLength;
      // The whitespace character that was already read above is replaced by
      // a single space joining the old value and the continuation.
      (*oldValue)[j++]     = ' ';
      for (; j < end; j++) {
        ch                 = httpGetChar(http, buf, bytes, &i);
        if (ch == ' ' || ch == '\t') {
          // Drop the character: undo the upcoming j++ and shrink the target
          // length by one.
          // NOTE(review): this drops every space/tab in the continuation
          // line, not only leading ones -- confirm that this is intended.
          end--;
          j--;
          continue;
        } else if (ch == '\r' && j == end - 1) {
          // A '\r' in the very last position is not copied.
          break;
        }
        (*oldValue)[j]     = ch;
      }
      (*oldValue)[j]       = '\000';
    }
  } else if ((ch == '\r' &&
              httpGetChar(http, buf, bytes, &i) == '\n') ||
             ch == '\n' || ch == -1) {
    // End of headers: hand the request to the matching handler.
    check(!http->expecting);
    http->callback         = NULL;
    http->arg              = NULL;
    int rc                 = httpHandleCommand(http, handlers);
  retry:;
    // Re-entered from the HTTP_READ_MORE case with the callback's return
    // code when that code is something other than HTTP_READ_MORE.
    struct ServerConnection *connection = httpGetServerConnection(http);
    switch (rc) {
    case HTTP_DONE:
    case HTTP_ERROR: {
      // The request has been dealt with. Stop reading on errors or when the
      // payload length is unknown; otherwise skip any remaining payload.
      if (http->expecting < 0 || rc == HTTP_ERROR) {
        httpCloseRead(http);
      }
      http->done           = 1;
      http->isSuspended    = 0;
      http->isPartialReply = 0;
      if (!serverGetTimeout(connection)) {
        serverSetTimeout(connection, CONNECTION_TIMEOUT);
      }
      httpSetState(http, http->expecting ? DISCARD_PAYLOAD : COMMAND);
      break; }
    case HTTP_READ_MORE:
      http->isSuspended    = 0;
      http->isPartialReply = 0;
      if (!serverGetTimeout(connection)) {
        serverSetTimeout(connection, CONNECTION_TIMEOUT);
      }
      check(!http->done);
      if (!http->expecting) {
        // No payload is going to arrive. Invoke the callback once with an
        // empty buffer, or announce the newly opened WebSocket connection.
        if (http->callback) {
          rc                 = http->callback(http, http->arg, "", 0);
          if (rc != HTTP_READ_MORE) {
            goto retry;
          }
        } else if (http->websocketHandler) {
          http->websocketHandler(http, http->arg, WS_CONNECTION_OPENED,
                                 NULL, 0);
        }
      }
      if (http->state != WEBSOCKET) {
        httpSetState(http, http->expecting ? PAYLOAD : COMMAND);
      }
      break;
    case HTTP_SUSPEND:
      // Mark the connection as suspended and set its timeout to 0.
      http->isSuspended    = 1;
      http->isPartialReply = 0;
      serverSetTimeout(connection, 0);
      if (http->state != PAYLOAD && http->state != DISCARD_PAYLOAD) {
        check(http->state == HEADERS);
        httpSetState(http, PAYLOAD);
      }
      break;
    case HTTP_PARTIAL_REPLY:
      // Only part of the reply has been produced so far; flag the
      // connection accordingly.
      http->isSuspended    = 0;
      http->isPartialReply = 1;
      if (http->state != PAYLOAD && http->state != DISCARD_PAYLOAD) {
        check(http->state == HEADERS);
        httpSetState(http, PAYLOAD);
      }
      break;
    default:
      check(0);
    }
    if (ch == -1) {
      // The input ended instead of a proper blank line.
      httpCloseRead(http);
    }
  } else {
    // Regular "key: value" line.
    if (colon <= 0) {
      httpSendReply(http, 400, "Bad Request", NO_MSG);
      return 0;
    }
    check(colon < lineLength);
    check(http->key        = malloc(colon + 1));
    // Start over at the beginning of the line (this "i" shadows the outer
    // one, which has already moved past the first character).
    int i                  = offset;
    for (int j = 0; j < colon; j++) {
      ch                   = httpGetChar(http, buf, bytes, &i);
      if (ch >= 'A' && ch <= 'Z') {
        // Keys are stored in lower case.
        ch                |= 0x20;
      }
      http->key[j]         = ch;
    }
    http->key[colon]       = '\000';
    char *value;
    check(value            = malloc(lineLength - colon));
    // Skip the colon itself.
    i++;
    int j                  = 0;
    // k counts the characters read after the colon, j the characters kept.
    for (int k = 0; k < lineLength - colon - 1; j++, k++) {
      int ch           = httpGetChar(http, buf, bytes, &i);
      if ((ch == ' ' || ch == '\t') && j == 0) {
        // Leading whitespace is skipped; cancel out the upcoming j++.
        j--;
      } else if (ch == '\r' && k == lineLength - colon - 2) {
        // A '\r' as the last character of the line is not part of the value.
        break;
      } else {
        value[j]           = ch;
      }
    }
    value[j]               = '\000';
    if (getRefFromHashMap(&http->header, http->key)) {
      // The first occurrence of a header wins. With http->key reset,
      // continuation lines of the dropped header are ignored as well.
      debug("[http] Dropping duplicate header \"%s\"", http->key);
      free(http->key);
      free(value);
      http->key            = NULL;
    } else {
      addToHashMap(&http->header, http->key, value);
    }
  }
  return 1;
}
1171 
httpConsumePayload(struct HttpConnection * http,const char * buf,int len)1172 static int httpConsumePayload(struct HttpConnection *http, const char *buf,
1173                               int len) {
1174   if (http->expecting >= 0) {
1175     // If positive, we know the expected length of payload and
1176     // can keep the connection open.
1177     // If negative, allow unlimited payload, but close connection
1178     // when done.
1179     if (len > http->expecting) {
1180       len                  = http->expecting;
1181     }
1182     http->expecting       -= len;
1183   }
1184   if (http->callback) {
1185     check(!http->done);
1186     int rc                 = http->callback(http, http->arg, buf, len);
1187     struct ServerConnection *connection = httpGetServerConnection(http);
1188     switch (rc) {
1189     case HTTP_DONE:
1190     case HTTP_ERROR:
1191       if (http->expecting < 0 || rc == HTTP_ERROR) {
1192         httpCloseRead(http);
1193       }
1194       http->done           = 1;
1195       http->isSuspended    = 0;
1196       http->isPartialReply = 0;
1197       if (!serverGetTimeout(connection)) {
1198         serverSetTimeout(connection, CONNECTION_TIMEOUT);
1199       }
1200       httpSetState(http, http->expecting ? DISCARD_PAYLOAD : COMMAND);
1201       break;
1202     case HTTP_READ_MORE:
1203       http->isSuspended    = 0;
1204       http->isPartialReply = 0;
1205       if (!serverGetTimeout(connection)) {
1206         serverSetTimeout(connection, CONNECTION_TIMEOUT);
1207       }
1208       if (!http->expecting) {
1209         httpSetState(http, COMMAND);
1210       }
1211       break;
1212     case HTTP_SUSPEND:
1213       http->isSuspended    = 1;
1214       http->isPartialReply = 0;
1215       serverSetTimeout(connection, 0);
1216       if (http->state != PAYLOAD && http->state != DISCARD_PAYLOAD) {
1217         check(http->state == HEADERS);
1218         httpSetState(http, PAYLOAD);
1219       }
1220       break;
1221     case HTTP_PARTIAL_REPLY:
1222       http->isSuspended    = 0;
1223       http->isPartialReply = 1;
1224       if (http->state != PAYLOAD && http->state != DISCARD_PAYLOAD) {
1225         check(http->state == HEADERS);
1226         httpSetState(http, PAYLOAD);
1227       }
1228       break;
1229     default:
1230       check(0);
1231     }
1232   } else {
1233     // If we do not have a callback for handling the payload, and we also do
1234     // not know how long the payload is (because there was not Content-Length),
1235     // we now close the connection.
1236     if (http->expecting < 0) {
1237       http->expecting      = 0;
1238       httpCloseRead(http);
1239       httpSetState(http, COMMAND);
1240     }
1241   }
1242   return len;
1243 }
1244 
httpParsePayload(struct HttpConnection * http,int offset,const char * buf,int bytes)1245 static int httpParsePayload(struct HttpConnection *http, int offset,
1246                             const char *buf, int bytes) {
1247   int consumed               = 0;
1248   if (offset < 0) {
1249     check(-offset <= http->partialLength);
1250     if (http->expecting) {
1251       consumed               = httpConsumePayload(http,
1252                                   http->partial + http->partialLength + offset,
1253                                   -offset);
1254       if (consumed == http->partialLength) {
1255         free(http->partial);
1256         http->partial        = NULL;
1257         http->partialLength  = 0;
1258       } else {
1259         memmove(http->partial, http->partial + consumed,
1260                 http->partialLength - consumed);
1261         http->partialLength -= consumed;
1262       }
1263       offset                += consumed;
1264     }
1265   }
1266   if (http->expecting && bytes - offset > 0) {
1267     check(offset >= 0);
1268     consumed                += httpConsumePayload(http, buf + offset,
1269                                                   bytes - offset);
1270   }
1271   return consumed;
1272 }
1273 
// Splits incoming WebSocket data into frames and passes the payload to
// http->websocketHandler. Input is read from "offset" up to "bytes"; a
// negative "offset" addresses data held in http->partial (see
// httpGetChar()).
//
// http->websocketType keeps the parser state between calls: its low eight
// bits hold the first byte of the current frame, WS_UNDEFINED is set while
// a frame header (first byte or length bytes) still has to be read, and
// WS_START_OF_FRAME / WS_END_OF_FRAME are handed to the handler together
// with the data. A first byte with bit 0x80 set introduces a frame whose
// length follows in 7-bit groups (bit 0x80 set on all but the last length
// byte); any other frame is text that ends with a 0xFF byte.
//
// Returns 1 after all input has been processed. Returns 0 if the frame
// length is already larger than 0xFFFFFF when another length byte arrives,
// or if the handler returns anything other than HTTP_DONE.
static int httpHandleWebSocket(struct HttpConnection *http, int offset,
                               const char *buf, int bytes) {
  check(http->websocketHandler);
  // Last byte read. The text-mode code below relies on it to tell the shared
  // payload code whether the terminating 0xFF has been seen.
  int ch                          = 0x00;
  while (bytes > offset) {
    if (http->websocketType & WS_UNDEFINED) {
      ch                          = httpGetChar(http, buf, bytes, &offset);
      check(ch >= 0);
      if (http->websocketType & 0xFF) {
        // Reading another byte of length information.
        if (http->expecting > 0xFFFFFF) {
          return 0;
        }
        http->expecting           = (128 * http->expecting) + (ch & 0x7F);
        if ((ch & 0x80) == 0) {
          // Done reading length information.
          http->websocketType    &= ~WS_UNDEFINED;

          // ch is used to detect when we read the terminating byte in text
          // mode. In binary mode, it must be set to something other than 0xFF.
          ch                      = 0x00;
        }
      } else {
        // Reading first byte of frame.
        http->websocketType       = (ch & 0xFF) | WS_START_OF_FRAME;
        if (ch & 0x80) {
          // For binary data, we have to read the length before we can start
          // processing payload.
          http->websocketType    |= WS_UNDEFINED;
          http->expecting         = 0;
        }
      }
    } else if (http->websocketType & 0x80) {
      // Binary data
      if (http->expecting) {
        if (offset < 0) {
        // Payload that is still held in http->partial. The text-mode code
        // below jumps here as well, after setting http->expecting to the
        // number of payload bytes it has found.
        handle_partial:
          check(-offset <= http->partialLength);
          int len                 = -offset;
          if (len >= http->expecting) {
            len                   = http->expecting;
            http->websocketType  |= WS_END_OF_FRAME;
          }
          if (len &&
              http->websocketHandler(http, http->arg, http->websocketType,
                                  http->partial + http->partialLength + offset,
                                  len) != HTTP_DONE) {
            return 0;
          }

          if (ch == 0xFF) {
            // In text mode, we jump to handle_partial, when we find the
            // terminating 0xFF byte. If so, we should try to consume it now.
            if (len < http->partialLength) {
              len++;
              http->websocketType = WS_UNDEFINED;
            }
          }

          // Remove "len" bytes from the front of the partial buffer.
          if (len == http->partialLength) {
            free(http->partial);
            http->partial         = NULL;
            http->partialLength   = 0;
          } else {
            memmove(http->partial, http->partial + len,
                    http->partialLength - len);
            http->partialLength  -= len;
          }
          offset                 += len;
          http->expecting        -= len;
        } else {
        // Payload located in "buf"; also the jump target for text data that
        // was found in "buf".
        handle_buffered:;
          int len                 = bytes - offset;
          if (len >= http->expecting) {
            len                   = http->expecting;
            http->websocketType  |= WS_END_OF_FRAME;
          }
          if (len &&
              http->websocketHandler(http, http->arg, http->websocketType,
                                     buf + offset, len) != HTTP_DONE) {
            return 0;
          }

          if (ch == 0xFF) {
            // In text mode, we jump to handle_buffered, when we find the
            // terminating 0xFF byte. If so, we should consume it now.
            check(offset + len < bytes);
            len++;
            http->websocketType   = WS_UNDEFINED;
          }
          offset                 += len;
          http->expecting        -= len;
        }
        // Both flags only apply to the chunk that has just been delivered.
        http->websocketType      &= ~(WS_START_OF_FRAME | WS_END_OF_FRAME);
      } else {
        // Read all data. Go back to looking for a new frame header.
        http->websocketType       = WS_UNDEFINED;
      }
    } else {
      // Process text data until we find a 0xFF bytes.
      int i                       = offset;

      // If we have partial data, process that first.
      while (i < 0) {
        ch                        = httpGetChar(http, buf, bytes, &i);
        check(ch != -1);

        // Terminate when we either find the 0xFF, or we have reached the end
        // of partial data.
        if (ch == 0xFF || !i) {
          // Set WS_END_OF_FRAME, iff we have found the 0xFF marker.
          // NOTE(review): if no 0xFF was found, expecting equals -offset, so
          // the "len >= expecting" test at handle_partial sets
          // WS_END_OF_FRAME in that case too -- confirm against the intent
          // stated above.
          http->expecting         = i - offset - (ch == 0xFF);
          goto handle_partial;
        }
      }

      // Read all remaining buffered bytes (i.e. positive offset).
      while (bytes > i) {
        ch                        = httpGetChar(http, buf, bytes, &i);
        check(ch != -1);

        // Terminate when we either find the 0xFF, or we have reached the end
        // of buffered data.
        if (ch == 0xFF || bytes == i) {
          // Set WS_END_OF_FRAME, iff we have found the 0xFF marker.
          // NOTE(review): if no 0xFF was found, expecting equals
          // bytes - offset, so the "len >= expecting" test at
          // handle_buffered sets WS_END_OF_FRAME in that case too -- confirm
          // against the intent stated above.
          http->expecting         = i - offset - (ch == 0xFF);
          goto handle_buffered;
        }
      }
    }
  }
  return 1;
}
1407 
httpHandleConnection(struct ServerConnection * connection,void * http_,short * events,short revents)1408 int httpHandleConnection(struct ServerConnection *connection, void *http_,
1409                          short *events, short revents) {
1410   struct HttpConnection *http        = (struct HttpConnection *)http_;
1411   struct Trie *handlers              = serverGetHttpHandlers(http->server);
1412   http->connection                   = connection;
1413   int  bytes;
1414   do {
1415     bytes                            = 0;
1416     *events                          = 0;
1417     char buf[4096];
1418     int  eof                         = http->closed;
1419     if ((revents & POLLIN) && !http->closed) {
1420       bytes                          = httpRead(http, buf, sizeof(buf));
1421       if (bytes > 0) {
1422         http->headerLength          += bytes;
1423         if (http->headerLength > MAX_HEADER_LENGTH) {
1424           debug("[http] Connection closed due to exceeded header size!");
1425           httpSendReply(http, 413, "Header too big", NO_MSG);
1426           bytes                      = 0;
1427           eof                        = 1;
1428         }
1429       } else {
1430         if (bytes == 0 || errno != EAGAIN) {
1431           httpCloseRead(http);
1432           eof                        = 1;
1433         } else {
1434           if (http->sslHndl && http->lastError == SSL_ERROR_WANT_WRITE) {
1435             *events                 |= POLLOUT;
1436           }
1437         }
1438         bytes                        = 0;
1439       }
1440     }
1441 
1442     if (bytes > 0 && http->state == SNIFFING_SSL) {
1443       // Assume that all legitimate HTTP commands start with a sequence of
1444       // letters followed by a space character. If we don't see this pattern,
1445       // or if the method does not match one of the known methods, we try
1446       // switching to SSL, instead.
1447       int isSSL                      = 0;
1448       char method[12]                = { 0 };
1449       for (int i = -http->partialLength, j = 0, ch;
1450            (ch = httpGetChar(http, buf, bytes, &i)) != -1;
1451            j++) {
1452         if ((j > 0 && (ch == ' ' || ch == '\t')) ||
1453             ch == '\r' || ch == '\n') {
1454           isSSL                      = strcmp(method, "OPTIONS") &&
1455                                        strcmp(method, "GET") &&
1456                                        strcmp(method, "HEAD") &&
1457                                        strcmp(method, "POST") &&
1458                                        strcmp(method, "PUT") &&
1459                                        strcmp(method, "DELETE") &&
1460                                        strcmp(method, "TRACE") &&
1461                                        strcmp(method, "CONNECT");
1462           http->state                = COMMAND;
1463           break;
1464         } else if (j >= (int)sizeof(method)-1 ||
1465                    ch < 'A' || (ch > 'Z' && ch < 'a') || ch > 'z') {
1466           isSSL                      = 1;
1467           http->state                = COMMAND;
1468           break;
1469         } else {
1470           method[j]                  = ch & ~0x20;
1471         }
1472       }
1473       if (isSSL) {
1474         if (httpPromoteToSSL(http, buf, bytes) < 0) {
1475           httpCloseRead(http);
1476           bytes                      = 0;
1477           eof                        = 1;
1478         } else {
1479           http->headerLength         = 0;
1480           *events                   |= POLLIN;
1481           continue;
1482         }
1483       } else {
1484         if (http->ssl && http->ssl->enabled && http->ssl->force) {
1485           debug("[http] Non-SSL connections not allowed!");
1486           httpCloseRead(http);
1487           bytes                      = 0;
1488           eof                        = 1;
1489         }
1490       }
1491     }
1492 
1493     if (bytes > 0 || (eof && http->partial)) {
1494       check(!!http->partial == !!http->partialLength);
1495       int  offset                    = -http->partialLength;
1496       int  eob                       = 0;
1497       do {
1498         int pushBack                 = 0;
1499         int consumed                 = 0;
1500         if (http->state == SNIFFING_SSL || http->state == COMMAND ||
1501             http->state == HEADERS) {
1502           check(!http->expecting);
1503           int  lineLength            = 0;
1504           int  colon                 = -1;
1505           int  firstSpace            = -1;
1506           int  lastSpace             = -1;
1507           int  fullLine              = 1;
1508           for (int i = offset; ; lineLength++) {
1509             int ch                   = httpGetChar(http, buf, bytes, &i);
1510             if (ch == ':') {
1511               if (colon < 0) {
1512                 colon                = lineLength;
1513               }
1514             } else if (ch == ' ' || ch == '\t') {
1515               if (firstSpace < 0) {
1516                 firstSpace           = lineLength;
1517               } else {
1518                 lastSpace            = lineLength;
1519               }
1520             } else if (ch == '\n') {
1521               break;
1522             } else if (ch == -1) {
1523               fullLine               = 0;
1524               eob                    = 1;
1525               break;
1526             }
1527           }
1528           if (fullLine || eof) {
1529             consumed                 = lineLength + 1;
1530             if (lineLength) {
1531               if (http->state == SNIFFING_SSL || http->state == COMMAND) {
1532                 if (!httpParseCommand(http, offset, buf, bytes, firstSpace,
1533                                       lastSpace, lineLength)) {
1534                   break;
1535                 }
1536               } else {
1537                 check(http->state == HEADERS);
1538                 if (!httpParseHeaders(http, handlers, offset, buf, bytes,
1539                                       colon, lineLength)) {
1540                   break;
1541                 }
1542               }
1543             }
1544           } else {
1545             pushBack                 = lineLength;
1546           }
1547         } else if (http->state == PAYLOAD ||
1548                    http->state == DISCARD_PAYLOAD) {
1549           if (http->expecting) {
1550             int len                  = bytes - offset;
1551             if (http->expecting > 0 &&
1552                 len > http->expecting) {
1553               len                    = http->expecting;
1554             }
1555             if (http->state == PAYLOAD) {
1556               len                    = httpParsePayload(http, offset, buf,
1557                                                         len + offset);
1558             }
1559             consumed                 = len;
1560             pushBack                 = bytes - offset - len;
1561           }
1562         } else if (http->state == WEBSOCKET) {
1563           if (!httpHandleWebSocket(http, offset, buf, bytes)) {
1564             httpCloseRead(http);
1565             break;
1566           }
1567           consumed                  += bytes - offset;
1568         } else {
1569           check(0);
1570         }
1571 
1572         offset                      += consumed;
1573         if (pushBack) {
1574           check(offset + pushBack == bytes);
1575           if (offset >= 0) {
1576             check(http->partial      = realloc(http->partial, pushBack));
1577             memcpy(http->partial, buf + offset, pushBack);
1578           } else if (pushBack != http->partialLength) {
1579             char *partial;
1580             check(partial            = malloc(pushBack));
1581             for (int i = offset, j = 0; j < pushBack; j++) {
1582               partial[j]             = httpGetChar(http, buf, bytes, &i);
1583             }
1584             free(http->partial);
1585             http->partial            = partial;
1586           }
1587           http->partialLength        = pushBack;
1588           offset                     = -pushBack;
1589           break;
1590         } else {
1591           eob                       |= offset >= bytes;
1592         }
1593       } while (!eob && !http->closed);
1594       if (http->closed || offset >= 0) {
1595         free(http->partial);
1596         http->partial                = NULL;
1597         http->partialLength          = 0;
1598       } else if (-offset != http->partialLength) {
1599         check(-offset < http->partialLength);
1600         memmove(http->partial, http->partial + http->partialLength + offset,
1601                 -offset);
1602         http->partialLength          = -offset;
1603       }
1604     }
1605 
1606     // If the peer closed the connection, clean up now.
1607     if (eof) {
1608       check(!http->partial);
1609       switch (http->state) {
1610       case SNIFFING_SSL:
1611       case COMMAND:
1612         break;
1613       case HEADERS:
1614         check(!http->expecting);
1615         http->callback               = NULL;
1616         http->arg                    = NULL;
1617         httpHandleCommand(http, handlers);
1618         httpCloseRead(http);
1619         httpSetState(http, COMMAND);
1620         break;
1621       case PAYLOAD:
1622       case DISCARD_PAYLOAD:
1623       case WEBSOCKET:
1624         http->expecting              = 0;
1625         httpCloseRead(http);
1626         httpSetState(http, COMMAND);
1627         break;
1628       }
1629     }
1630 
1631     for (;;) {
1632       // Try to write any pending outgoing data
1633       if (http->msg && http->msgLength > 0) {
1634         int wrote                    = httpWrite(http, http->msg,
1635                                                  http->msgLength);
1636         if (wrote < 0 && errno != EAGAIN) {
1637           httpCloseRead(http);
1638           free(http->msg);
1639           http->msgLength            = 0;
1640           http->msg                  = NULL;
1641           break;
1642         } else if (wrote > 0) {
1643           if (wrote == http->msgLength) {
1644             free(http->msg);
1645             http->msgLength          = 0;
1646             http->msg                = NULL;
1647           } else {
1648             memmove(http->msg, http->msg + wrote, http->msgLength - wrote);
1649             http->msgLength         -= wrote;
1650           }
1651         }
1652         // SSL might require reading in order to write
1653         else if (wrote < 0 && errno == EAGAIN && http->sslHndl) {
1654           if (http->lastError == SSL_ERROR_WANT_READ && !http->closed) {
1655             *events                 |= POLLIN;
1656           }
1657         }
1658       }
1659 
1660       // If the callback only provided partial data, refill the outgoing
1661       // buffer whenever it runs low.
1662       if (http->isPartialReply && (!http->msg || http->msgLength <= 0)) {
1663         httpConsumePayload(http, "", 0);
1664       } else {
1665         break;
1666       }
1667     }
1668 
1669     *events                         |=
1670       (*events & ~(POLLIN|POLLOUT)) |
1671       (!http->closed && ((http->state != PAYLOAD &&
1672                           http->state != DISCARD_PAYLOAD) ||
1673                          http->expecting) ? POLLIN : 0) |
1674       (http->msg || http->isPartialReply ? POLLOUT : 0);
1675 
1676     connection                       = httpGetServerConnection(http);
1677     int timedOut                     = serverGetTimeout(connection) < 0;
1678     if (timedOut) {
1679       free(http->partial);
1680       http->partial                  = NULL;
1681       http->partialLength            = 0;
1682       free(http->msg);
1683       http->msg                      = NULL;
1684       http->msgLength                = 0;
1685     }
1686 
1687     if ((!(*events || http->isSuspended) || timedOut) && http->sslHndl) {
1688       *events                        = 0;
1689       serverSetTimeout(connection, 1);
1690       int wasAlreadyClosed           = http->closed;
1691       httpCloseRead(http);
1692       dcheck(!ERR_peek_error());
1693       sslBlockSigPipe();
1694       int rc                         = SSL_shutdown(http->sslHndl);
1695       switch (rc) {
1696       case 1:
1697         sslFreeHndl(&http->sslHndl);
1698         break;
1699       case 0:
1700         if (!wasAlreadyClosed) {
1701           *events                   |= POLLIN;
1702         }
1703         break;
1704       case -1:
1705         switch (SSL_get_error(http->sslHndl, rc)) {
1706         case SSL_ERROR_WANT_READ:
1707           if (!wasAlreadyClosed) {
1708             *events                 |= POLLIN;
1709           }
1710           break;
1711         case SSL_ERROR_WANT_WRITE:
1712           *events                   |= POLLOUT;
1713           break;
1714         }
1715         break;
1716       }
1717       ERR_clear_error();
1718       dcheck(!ERR_peek_error());
1719       if (sslUnblockSigPipe()) {
1720         *events                      = 0;
1721         sslFreeHndl(&http->sslHndl);
1722       }
1723     } else if (!http->sslHndl && timedOut) {
1724       *events                        = 0;
1725       serverSetTimeout(connection, 0);
1726       httpCloseRead(http);
1727     }
1728     revents                          = POLLIN | POLLOUT;
1729   } while (bytes > 0 && *events & POLLIN && !http->closed);
1730   return (*events & (POLLIN|POLLOUT)) ||
1731          (!http->closed && http->isSuspended);
1732 }
1733 
httpSetCallback(struct HttpConnection * http,int (* callback)(struct HttpConnection *,void *,const char *,int),void * arg)1734 void httpSetCallback(struct HttpConnection *http,
1735                      int (*callback)(struct HttpConnection *, void *,
1736                                      const char *, int), void *arg) {
1737   http->callback = callback;
1738   http->arg      = arg;
1739 }
1740 
httpGetPrivate(struct HttpConnection * http)1741 void *httpGetPrivate(struct HttpConnection *http) {
1742   return http->private;
1743 }
1744 
httpSetPrivate(struct HttpConnection * http,void * private)1745 void *httpSetPrivate(struct HttpConnection *http, void *private) {
1746   void *old     = http->private;
1747   http->private = private;
1748   return old;
1749 }
1750 
httpSendReply(struct HttpConnection * http,int code,const char * msg,const char * fmt,...)1751 void httpSendReply(struct HttpConnection *http, int code,
1752                    const char *msg, const char *fmt, ...) {
1753   http->code     = code;
1754   char *body;
1755   char *title    = code != 200 ? stringPrintf(NULL, "%d %s", code, msg) : NULL;
1756   char *details  = NULL;
1757   if (fmt != NULL && strcmp(fmt, NO_MSG)) {
1758     va_list ap;
1759     va_start(ap, fmt);
1760     details      = vStringPrintf(NULL, fmt, ap);
1761     va_end(ap);
1762   }
1763   body           = stringPrintf(NULL,
1764      "<?xml version=\"1.0\" encoding=\"utf-8\"?>\n"
1765      "<!DOCTYPE html PUBLIC "
1766                "\"-//W3C//DTD XHTML 1.0 Transitional//EN\" "
1767                "\"http://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional.dtd\">\n"
1768      "<html xmlns=\"http://www.w3.org/1999/xhtml\" "
1769      "xmlns:v=\"urn:schemas-microsoft-com:vml\" "
1770      "xml:lang=\"en\" lang=\"en\">\n"
1771      "<head>\n"
1772      "<title>%s</title>\n"
1773      "</head>\n"
1774      "<body>\n"
1775      "%s\n"
1776      "</body>\n"
1777      "</html>\n",
1778      title ? title : msg, fmt && strcmp(fmt, NO_MSG) ? details : msg);
1779   free(details);
1780   free(title);
1781   char *response = NULL;
1782   if (code) {
1783     response     = stringPrintf(NULL,
1784                                 "HTTP/1.1 %d %s\r\n"
1785                                 "%s"
1786                                 "Content-Type: text/html; charset=utf-8\r\n"
1787                                 "Content-Length: %ld\r\n"
1788                                 "\r\n",
1789                                 code, msg,
1790                                 code != 200 ? "Connection: close\r\n" : "",
1791                                 (long)strlen(body));
1792   }
1793   int isHead     = http->method && !strcmp(http->method, "HEAD");
1794   if (!isHead) {
1795     response     = stringPrintf(response, "%s", body);
1796   }
1797   free(body);
1798   check(response);
1799   httpTransfer(http, response, strlen(response));
1800   if (code != 200 || isHead) {
1801     httpCloseRead(http);
1802   }
1803 }
1804 
httpSendWebSocketTextMsg(struct HttpConnection * http,int type,const char * fmt,...)1805 void httpSendWebSocketTextMsg(struct HttpConnection *http, int type,
1806                               const char *fmt, ...) {
1807   check(type >= 0 && type <= 0x7F);
1808   va_list ap;
1809   va_start(ap, fmt);
1810   char *buf;
1811   int len;
1812   if (strcmp(fmt, BINARY_MSG)) {
1813     // Send a printf() style text message
1814     buf              = vStringPrintf(NULL, fmt, ap);
1815     len              = strlen(buf);
1816   } else {
1817     // Send a binary message
1818     len              = va_arg(ap, int);
1819     buf              = va_arg(ap, char *);
1820   }
1821   va_end(ap);
1822   check(len >= 0 && len < 0x60000000);
1823 
1824   // We assume that all input data is directly mapped in the range 0..255
1825   // (e.g. ISO-8859-1). In order to transparently send it over a web socket,
1826   // we have to encode it in UTF-8.
1827   int utf8Len        = len + 2;
1828   for (int i = 0; i < len; ++i) {
1829     if (buf[i] & 0x80) {
1830       ++utf8Len;
1831     }
1832   }
1833   char *utf8;
1834   check(utf8         = malloc(utf8Len));
1835   utf8[0]            = type;
1836   for (int i = 0, j = 1; i < len; ++i) {
1837     unsigned char ch = buf[i];
1838     if (ch & 0x80) {
1839       utf8[j++]      = 0xC0 + (ch >> 6);
1840       utf8[j++]      = 0x80 + (ch & 0x3F);
1841     } else {
1842       utf8[j++]      = ch;
1843     }
1844     check(j < utf8Len);
1845   }
1846   utf8[utf8Len-1]    = '\xFF';
1847 
1848   // Free our temporary buffer, if we actually did allocate one.
1849   if (strcmp(fmt, BINARY_MSG)) {
1850     free(buf);
1851   }
1852 
1853   // Send to browser.
1854   httpTransfer(http, utf8, utf8Len);
1855 }
1856 
// Sends a length-prefixed binary frame: one "type" byte, the payload length
// as big-endian base-128 digits, then the payload itself. The finished
// frame is passed to httpTransfer().
//
//   http  connection the frame is sent on
//   type  frame type; must be in the range 0x80..0xFF
//   buf   payload bytes
//   len   number of payload bytes; must be in the range 1..0x7FFFFFEF
//
// In the length field every byte except the final one carries the
// continuation bit (0x80); the final byte has it cleared so that the
// receiver knows where the length ends.
void httpSendWebSocketBinaryMsg(struct HttpConnection *http, int type,
                                const void *buf, int len) {
  check(type >= 0x80 && type <= 0xFF);
  check(len > 0 && len < 0x7FFFFFF0);

  // Allocate buffer for header and payload: one type byte, at most five
  // base-128 digits for a 31 bit length, and the payload.
  char *data;
  check(data  = malloc(len + 6));
  data[0]     = type;

  // Convert length to base-128. Digits are produced least-significant
  // first, each with the continuation bit set.
  int i       = 0;
  int l       = len;
  do {
    data[++i] = 0x80 + (l & 0x7F);
    l        /= 128;
  } while (l);

  // The least-significant digit (currently at data[1]) will be the last
  // length byte on the wire once the digits are reversed, so it is the one
  // that must have its continuation bit cleared.
  data[1]    &= 0x7F;

  // Reverse digits, so that they are big-endian.
  for (int j = 0; j < i/2; ++j) {
    char ch   = data[1+j];
    data[1+j] = data[i-j];
    data[i-j] = ch;
  }

  // Transmit header and payload.
  memmove(data + i + 1, buf, len);
  httpTransfer(http, data, len + i + 1);
}
1887 
httpExitLoop(struct HttpConnection * http,int exitAll)1888 void httpExitLoop(struct HttpConnection *http, int exitAll) {
1889   serverExitLoop(http->server, exitAll);
1890 }
1891 
httpGetServer(const struct HttpConnection * http)1892 struct Server *httpGetServer(const struct HttpConnection *http) {
1893   return http->server;
1894 }
1895 
httpGetServerConnection(const struct HttpConnection * http)1896 struct ServerConnection *httpGetServerConnection(const struct HttpConnection *
1897                                                  http) {
1898   struct HttpConnection *httpW = (struct HttpConnection *)http;
1899   httpW->connection = serverGetConnection(http->server, http->connection,
1900                                           http->fd);
1901   return http->connection;
1902 }
1903 
httpGetFd(const HttpConnection * http)1904 int httpGetFd(const HttpConnection *http) {
1905   return http->fd;
1906 }
1907 
httpGetPeerName(const struct HttpConnection * http)1908 const char *httpGetPeerName(const struct HttpConnection *http) {
1909   return http->peerName;
1910 }
1911 
httpGetRealIP(const struct HttpConnection * http)1912 const char *httpGetRealIP(const struct HttpConnection *http) {
1913   return getFromHashMap(&http->header, "x-real-ip");
1914 }
1915 
httpGetMethod(const struct HttpConnection * http)1916 const char *httpGetMethod(const struct HttpConnection *http) {
1917   return http->method;
1918 }
1919 
httpGetProtocol(const struct HttpConnection * http)1920 const char *httpGetProtocol(const struct HttpConnection *http) {
1921   return http->sslHndl ? "https" : "http";
1922 }
1923 
httpGetHost(const struct HttpConnection * http)1924 const char *httpGetHost(const struct HttpConnection *http) {
1925   const char *host = getFromHashMap(&http->header, "host");
1926   if (!host || !*host) {
1927     host           = "localhost";
1928   }
1929   return host;
1930 }
1931 
httpGetPort(const struct HttpConnection * http)1932 int httpGetPort(const struct HttpConnection *http) {
1933   return http->port;
1934 }
1935 
httpGetPath(const struct HttpConnection * http)1936 const char *httpGetPath(const struct HttpConnection *http) {
1937   return http->matchedPath;
1938 }
1939 
httpGetPathInfo(const struct HttpConnection * http)1940 const char *httpGetPathInfo(const struct HttpConnection *http) {
1941   return http->pathInfo ? http->pathInfo : "";
1942 }
1943 
httpGetQuery(const struct HttpConnection * http)1944 const char *httpGetQuery(const struct HttpConnection *http) {
1945   return http->query ? http->query : "";
1946 }
1947 
httpGetURL(const struct HttpConnection * http)1948 const char *httpGetURL(const struct HttpConnection *http) {
1949   if (!http->url) {
1950     const char *host           = httpGetHost(http);
1951     int s_size                 = 8 + strlen(host) + 25 + strlen(http->path);
1952     check(*(char **)&http->url = malloc(s_size + 1));
1953     *http->url                 = '\000';
1954     strncat(http->url, http->sslHndl ? "https://" : "http://", s_size);
1955     strncat(http->url, host, s_size);
1956     if (http->port != (http->sslHndl ? 443 : 80)) {
1957       snprintf(strrchr(http->url, '\000'), 25, ":%d", http->port);
1958     }
1959     strncat(http->url, http->path, s_size);
1960   }
1961   return http->url;
1962 }
1963 
httpGetVersion(const struct HttpConnection * http)1964 const char *httpGetVersion(const struct HttpConnection *http) {
1965   return http->version;
1966 }
1967 
httpGetHeaders(const struct HttpConnection * http)1968 const struct HashMap *httpGetHeaders(const struct HttpConnection *http) {
1969   return &http->header;
1970 }
1971