1 /* -*- Mode: C++; tab-width: 2; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
2 /* This Source Code Form is subject to the terms of the Mozilla Public
3  * License, v. 2.0. If a copy of the MPL was not distributed with this
4  * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
5 
6 /*
7  * WARNING: DO NOT USE THIS CODE IN PRODUCTION SYSTEMS.  It is highly likely to
8  *          be plagued with the usual problems endemic to C (buffer overflows
9  *          and the like).  We don't especially care here (but would accept
10  *          patches!) because this is only intended for use in our test
11  *          harnesses in controlled situations where input is guaranteed not to
12  *          be malicious.
13  */
14 
15 #include "ScopedNSSTypes.h"
16 #include <assert.h>
17 #include <stdio.h>
18 #include <string>
19 #include <vector>
20 #include <algorithm>
21 #include <stdarg.h>
22 #include "prinit.h"
23 #include "prerror.h"
24 #include "prenv.h"
25 #include "prnetdb.h"
26 #include "prtpool.h"
27 #include "nsAlgorithm.h"
28 #include "nss.h"
29 #include "key.h"
30 #include "ssl.h"
31 #include "sslproto.h"
32 #include "plhash.h"
33 #include "mozilla/Sprintf.h"
34 
35 using namespace mozilla;
36 using namespace mozilla::psm;
37 using std::string;
38 using std::vector;
39 
// Bit-set helpers for strtok2(): a delimiter table is DELIM_TABLE_SIZE (32)
// bytes, i.e. 256 bits, one bit per possible byte value.  IS_DELIM(m, c)
// tests and SET_DELIM(m, c) sets the bit for byte value c in table m.  c is
// expected to be in 0..255; strtok2() casts chars to uint8_t before use.
#define IS_DELIM(m, c) ((m)[(c) >> 3] & (1 << ((c)&7)))
#define SET_DELIM(m, c) ((m)[(c) >> 3] |= (1 << ((c)&7)))
#define DELIM_TABLE_SIZE 32
43 
// You can set the level of logging by env var SSLTUNNEL_LOG_LEVEL=n, where n
// is 0 through 3.  The default is 1, INFO level logging.
//
// gLogLevel is the active threshold: a message is printed only when its level
// is >= gLogLevel (see _LOG_OUTPUT), so a higher value means less output.
// gLastLogLevel holds the level of the most recently printed message and is
// consulted by LOG_END_BLOCK().
enum LogLevel {
  LEVEL_DEBUG = 0,
  LEVEL_INFO = 1,
  LEVEL_ERROR = 2,
  LEVEL_SILENT = 3
} gLogLevel,
    gLastLogLevel;
53 
// Core logging primitive: when `level` is at or above the threshold
// gLogLevel, record it in gLastLogLevel and invoke `func` with the
// parenthesized argument list `params`, e.g. _LOG_OUTPUT(LEVEL_INFO, printf,
// ("x=%d", x)).
#define _LOG_OUTPUT(level, func, params) \
  PR_BEGIN_MACRO                         \
  if (level >= gLogLevel) {              \
    gLastLogLevel = level;               \
    func params;                         \
  }                                      \
  PR_END_MACRO

// The most verbose output
#define LOG_DEBUG(params) _LOG_OUTPUT(LEVEL_DEBUG, printf, params)

// Top level informative messages
#define LOG_INFO(params) _LOG_OUTPUT(LEVEL_INFO, printf, params)

// Serious errors, written to stderr via eprintf.  They are printed at every
// log level except LEVEL_SILENT (i.e. until logging is completely gagged).
#define LOG_ERROR(params) _LOG_OUTPUT(LEVEL_ERROR, eprintf, params)

// Same as LOG_ERROR, but when logging is set to LEVEL_DEBUG, the message
// will be put to the stdout instead of stderr to keep continuity with other
// LOG_DEBUG message output
#define LOG_ERRORD(params)                     \
  PR_BEGIN_MACRO                               \
  if (gLogLevel == LEVEL_DEBUG)                \
    _LOG_OUTPUT(LEVEL_ERROR, printf, params);  \
  else                                         \
    _LOG_OUTPUT(LEVEL_ERROR, eprintf, params); \
  PR_END_MACRO

// If there is any output written between LOG_BEGIN_BLOCK() and
// LOG_END_BLOCK() then a new line will be put to the proper output (out/err)
//
// NOTE(review): LOG_BEGIN_BLOCK() expands to a full statement including its
// own trailing ';', so `LOG_BEGIN_BLOCK();` produces an extra empty statement
// and the macro is not safe as the sole body of an unbraced if/else.
#define LOG_BEGIN_BLOCK() gLastLogLevel = LEVEL_SILENT;

// Terminates a log block: emits "\n" through LOG_ERROR when the last printed
// message was an error, through printf at the same level when it was a
// debug/info message, and nothing when nothing was printed since
// LOG_BEGIN_BLOCK() (gLastLogLevel still LEVEL_SILENT).
#define LOG_END_BLOCK()                                                        \
  PR_BEGIN_MACRO                                                               \
  if (gLastLogLevel == LEVEL_ERROR) LOG_ERROR(("\n"));                         \
  if (gLastLogLevel < LEVEL_ERROR) _LOG_OUTPUT(gLastLogLevel, printf, ("\n")); \
  PR_END_MACRO
91 
// printf-style formatting to stderr.
// Returns the number of characters written, or a negative value on error
// (whatever vfprintf reports).
int eprintf(const char* str, ...) {
  va_list args;
  va_start(args, str);
  const int written = vfprintf(stderr, str, args);
  va_end(args);
  return written;
}
99 
100 // Copied from nsCRT
strtok2(char * string,const char * delims,char ** newStr)101 char* strtok2(char* string, const char* delims, char** newStr) {
102   PR_ASSERT(string);
103 
104   char delimTable[DELIM_TABLE_SIZE];
105   uint32_t i;
106   char* result;
107   char* str = string;
108 
109   for (i = 0; i < DELIM_TABLE_SIZE; i++) delimTable[i] = '\0';
110 
111   for (i = 0; delims[i]; i++) {
112     SET_DELIM(delimTable, static_cast<uint8_t>(delims[i]));
113   }
114 
115   // skip to beginning
116   while (*str && IS_DELIM(delimTable, static_cast<uint8_t>(*str))) {
117     str++;
118   }
119   result = str;
120 
121   // fix up the end of the token
122   while (*str) {
123     if (IS_DELIM(delimTable, static_cast<uint8_t>(*str))) {
124       *str++ = '\0';
125       break;
126     }
127     str++;
128   }
129   *newStr = str;
130 
131   return str == result ? nullptr : result;
132 }
133 
134 enum client_auth_option { caNone = 0, caRequire = 1, caRequest = 2 };
135 
136 // Structs for passing data into jobs on the thread pool
137 typedef struct {
138   int32_t listen_port;
139   string cert_nickname;
140   PLHashTable* host_cert_table;
141   PLHashTable* host_clientauth_table;
142   PLHashTable* host_redir_table;
143   PLHashTable* host_ssl3_table;
144   PLHashTable* host_tls1_table;
145   PLHashTable* host_rc4_table;
146   PLHashTable* host_failhandshake_table;
147 } server_info_t;
148 
149 typedef struct {
150   PRFileDesc* client_sock;
151   PRNetAddr client_addr;
152   server_info_t* server_info;
153   // the original host in the Host: header for this connection is
154   // stored here, for proxied connections
155   string original_host;
156   // true if no SSL should be used for this connection
157   bool http_proxy_only;
158   // true if this connection is for a WebSocket
159   bool iswebsocket;
160 } connection_info_t;
161 
162 typedef struct {
163   string fullHost;
164   bool matched;
165 } server_match_t;
166 
const int32_t BUF_SIZE = 16384;
const int32_t BUF_MARGIN = 1024;
const int32_t BUF_TOTAL = BUF_SIZE + BUF_MARGIN;

// One direction of the relay.  Valid data lives in [bufferhead, buffertail).
// Socket reads append at buffertail up to bufferend (BUF_SIZE bytes).  The
// extra BUF_MARGIN bytes past bufferend are headroom for in-place request
// rewriting (e.g. AdjustRequestURI) and for a NUL terminator after the data.
struct relayBuffer {
  char *buffer, *bufferhead, *buffertail, *bufferend;

  relayBuffer() {
    // Leave BUF_MARGIN bytes more for request line manipulations
    bufferhead = buffertail = buffer = new char[BUF_TOTAL];
    bufferend = buffer + BUF_SIZE;
  }

  ~relayBuffer() { delete[] buffer; }

  // The object owns a raw new[] allocation; a copy would delete it twice.
  relayBuffer(const relayBuffer&) = delete;
  relayBuffer& operator=(const relayBuffer&) = delete;

  // Rewind to the start of the allocation once all data has been consumed.
  void compact() {
    if (buffertail == bufferhead) buffertail = bufferhead = buffer;
  }

  bool empty() const { return bufferhead == buffertail; }

  // Bytes that may still be read from a socket into the buffer.  Request
  // rewriting can push buffertail past bufferend into the margin; report 0
  // then instead of letting the negative difference wrap to a huge size_t
  // (which would let a subsequent read overflow the allocation).
  size_t areafree() const {
    return buffertail < bufferend ? size_t(bufferend - buffertail) : 0;
  }

  // Bytes that may be written past buffertail, margin included.
  size_t margin() const {
    const char* const limit = buffer + BUF_TOTAL;
    return buffertail < limit ? size_t(limit - buffertail) : 0;
  }

  // Bytes of data currently held.
  size_t present() const { return buffertail - bufferhead; }
};
191 
// These numbers are multiplied by the number of listening ports (actual
// servers running).  According to the thread pool implementation there is no
// need to limit the number of threads initially: threads are allocated
// dynamically and stored in a linked list.  The initial number of 2 is chosen
// to allocate a thread for socket accept and preallocate one for the first
// connection, which is expected to arrive with high probability.
constexpr uint32_t INITIAL_THREADS = 2;
constexpr uint32_t MAX_THREADS = 100;
constexpr uint32_t DEFAULT_STACKSIZE = 512 * 1024;
201 
202 // global data
203 string nssconfigdir;
204 vector<server_info_t> servers;
205 PRNetAddr remote_addr;
206 PRNetAddr websocket_server;
207 PRThreadPool* threads = nullptr;
208 PRLock* shutdown_lock = nullptr;
209 PRCondVar* shutdown_condvar = nullptr;
210 // Not really used, unless something fails to start
211 bool shutdown_server = false;
212 bool do_http_proxy = false;
213 bool any_host_spec_config = false;
214 
ClientAuthValueComparator(const void * v1,const void * v2)215 int ClientAuthValueComparator(const void* v1, const void* v2) {
216   int a = *static_cast<const client_auth_option*>(v1) -
217           *static_cast<const client_auth_option*>(v2);
218   if (a == 0) return 0;
219   if (a > 0) return 1;
220   // (a < 0)
221   return -1;
222 }
223 
match_hostname(PLHashEntry * he,int index,void * arg)224 static int match_hostname(PLHashEntry* he, int index, void* arg) {
225   server_match_t* match = (server_match_t*)arg;
226   if (match->fullHost.find((char*)he->key) != string::npos)
227     match->matched = true;
228   return HT_ENUMERATE_NEXT;
229 }
230 
231 /*
232  * Signal the main thread that the application should shut down.
233  */
SignalShutdown()234 void SignalShutdown() {
235   PR_Lock(shutdown_lock);
236   PR_NotifyCondVar(shutdown_condvar);
237   PR_Unlock(shutdown_lock);
238 }
239 
// Per-host behavior flags.  ReadConnectRequest ORs them into its `flags`
// output and ConfigureSSLServerSocket acts on them.  Bit 0x08 is unused.
enum {
  USE_SSL3 = 0x01,
  USE_RC4 = 0x02,
  FAIL_HANDSHAKE = 0x04,
  USE_TLS1 = 0x10
};
247 
ReadConnectRequest(server_info_t * server_info,relayBuffer & buffer,int32_t * result,string & certificate,client_auth_option * clientauth,string & host,string & location,int32_t * flags)248 bool ReadConnectRequest(server_info_t* server_info, relayBuffer& buffer,
249                         int32_t* result, string& certificate,
250                         client_auth_option* clientauth, string& host,
251                         string& location, int32_t* flags) {
252   if (buffer.present() < 4) {
253     LOG_DEBUG(
254         (" !! only %d bytes present in the buffer", (int)buffer.present()));
255     return false;
256   }
257   if (strncmp(buffer.buffertail - 4, "\r\n\r\n", 4)) {
258     LOG_ERRORD((" !! request is not tailed with CRLFCRLF but with %x %x %x %x",
259                 *(buffer.buffertail - 4), *(buffer.buffertail - 3),
260                 *(buffer.buffertail - 2), *(buffer.buffertail - 1)));
261     return false;
262   }
263 
264   LOG_DEBUG((" parsing initial connect request, dump:\n%.*s\n",
265              (int)buffer.present(), buffer.bufferhead));
266 
267   *result = 400;
268 
269   char* token;
270   char* _caret;
271   token = strtok2(buffer.bufferhead, " ", &_caret);
272   if (!token) {
273     LOG_ERRORD((" no space found"));
274     return true;
275   }
276   if (strcmp(token, "CONNECT")) {
277     LOG_ERRORD((" not CONNECT request but %s", token));
278     return true;
279   }
280 
281   token = strtok2(_caret, " ", &_caret);
282   void* c = PL_HashTableLookup(server_info->host_cert_table, token);
283   if (c) certificate = static_cast<char*>(c);
284 
285   host = "https://";
286   host += token;
287 
288   c = PL_HashTableLookup(server_info->host_clientauth_table, token);
289   if (c)
290     *clientauth = *static_cast<client_auth_option*>(c);
291   else
292     *clientauth = caNone;
293 
294   void* redir = PL_HashTableLookup(server_info->host_redir_table, token);
295   if (redir) location = static_cast<char*>(redir);
296 
297   if (PL_HashTableLookup(server_info->host_ssl3_table, token)) {
298     *flags |= USE_SSL3;
299   }
300 
301   if (PL_HashTableLookup(server_info->host_rc4_table, token)) {
302     *flags |= USE_RC4;
303   }
304 
305   if (PL_HashTableLookup(server_info->host_tls1_table, token)) {
306     *flags |= USE_TLS1;
307   }
308 
309   if (PL_HashTableLookup(server_info->host_failhandshake_table, token)) {
310     *flags |= FAIL_HANDSHAKE;
311   }
312 
313   token = strtok2(_caret, "/", &_caret);
314   if (strcmp(token, "HTTP")) {
315     LOG_ERRORD((" not tailed with HTTP but with %s", token));
316     return true;
317   }
318 
319   *result = (redir) ? 302 : 200;
320   return true;
321 }
322 
ConfigureSSLServerSocket(PRFileDesc * socket,server_info_t * si,const string & certificate,const client_auth_option clientAuth,int32_t flags)323 bool ConfigureSSLServerSocket(PRFileDesc* socket, server_info_t* si,
324                               const string& certificate,
325                               const client_auth_option clientAuth,
326                               int32_t flags) {
327   const char* certnick =
328       certificate.empty() ? si->cert_nickname.c_str() : certificate.c_str();
329 
330   UniqueCERTCertificate cert(PK11_FindCertFromNickname(certnick, nullptr));
331   if (!cert) {
332     LOG_ERROR(("Failed to find cert %s\n", certnick));
333     return false;
334   }
335 
336   UniqueSECKEYPrivateKey privKey(PK11_FindKeyByAnyCert(cert.get(), nullptr));
337   if (!privKey) {
338     LOG_ERROR(("Failed to find private key\n"));
339     return false;
340   }
341 
342   PRFileDesc* ssl_socket = SSL_ImportFD(nullptr, socket);
343   if (!ssl_socket) {
344     LOG_ERROR(("Error importing SSL socket\n"));
345     return false;
346   }
347 
348   if (flags & FAIL_HANDSHAKE) {
349     // deliberately cause handshake to fail by sending the client a client hello
350     SSL_ResetHandshake(ssl_socket, false);
351     return true;
352   }
353 
354   SSLKEAType certKEA = NSS_FindCertKEAType(cert.get());
355   if (SSL_ConfigSecureServer(ssl_socket, cert.get(), privKey.get(), certKEA) !=
356       SECSuccess) {
357     LOG_ERROR(("Error configuring SSL server socket\n"));
358     return false;
359   }
360 
361   SSL_OptionSet(ssl_socket, SSL_SECURITY, true);
362   SSL_OptionSet(ssl_socket, SSL_HANDSHAKE_AS_CLIENT, false);
363   SSL_OptionSet(ssl_socket, SSL_HANDSHAKE_AS_SERVER, true);
364 
365   if (clientAuth != caNone) {
366     SSL_OptionSet(ssl_socket, SSL_REQUEST_CERTIFICATE, true);
367     SSL_OptionSet(ssl_socket, SSL_REQUIRE_CERTIFICATE, clientAuth == caRequire);
368   }
369 
370   if (flags & USE_SSL3) {
371     SSLVersionRange range = {SSL_LIBRARY_VERSION_3_0, SSL_LIBRARY_VERSION_3_0};
372     SSL_VersionRangeSet(ssl_socket, &range);
373   }
374 
375   if (flags & USE_TLS1) {
376     SSLVersionRange range = {SSL_LIBRARY_VERSION_TLS_1_0,
377                              SSL_LIBRARY_VERSION_TLS_1_0};
378     SSL_VersionRangeSet(ssl_socket, &range);
379   }
380 
381   if (flags & USE_RC4) {
382     for (uint16_t i = 0; i < SSL_NumImplementedCiphers; ++i) {
383       uint16_t cipher_id = SSL_ImplementedCiphers[i];
384       switch (cipher_id) {
385         case TLS_ECDHE_ECDSA_WITH_RC4_128_SHA:
386         case TLS_ECDHE_RSA_WITH_RC4_128_SHA:
387         case TLS_RSA_WITH_RC4_128_SHA:
388         case TLS_RSA_WITH_RC4_128_MD5:
389           SSL_CipherPrefSet(ssl_socket, cipher_id, true);
390           break;
391 
392         default:
393           SSL_CipherPrefSet(ssl_socket, cipher_id, false);
394           break;
395       }
396     }
397   }
398 
399   SSL_ResetHandshake(ssl_socket, true);
400 
401   return true;
402 }
403 
404 /**
405  * This function examines the buffer for a Sec-WebSocket-Location: field,
406  * and if it's present, it replaces the hostname in that field with the
407  * value in the server's original_host field.  This function works
408  * in the reverse direction as AdjustWebSocketHost(), replacing the real
409  * hostname of a response with the potentially fake hostname that is expected
410  * by the browser (e.g., mochi.test).
411  *
412  * @return true if the header was adjusted successfully, or not found, false
413  * if the header is present but the url is not, which should indicate
414  * that more data needs to be read from the socket
415  */
AdjustWebSocketLocation(relayBuffer & buffer,connection_info_t * ci)416 bool AdjustWebSocketLocation(relayBuffer& buffer, connection_info_t* ci) {
417   assert(buffer.margin());
418   buffer.buffertail[1] = '\0';
419 
420   char* wsloc = strstr(buffer.bufferhead, "Sec-WebSocket-Location:");
421   if (!wsloc) return true;
422   // advance pointer to the start of the hostname
423   wsloc = strstr(wsloc, "ws://");
424   if (!wsloc) return false;
425   wsloc += 5;
426   // find the end of the hostname
427   char* wslocend = strchr(wsloc + 1, '/');
428   if (!wslocend) return false;
429   char* crlf = strstr(wsloc, "\r\n");
430   if (!crlf) return false;
431   if (ci->original_host.empty()) return true;
432 
433   int diff = ci->original_host.length() - (wslocend - wsloc);
434   if (diff > 0) assert(size_t(diff) <= buffer.margin());
435   memmove(wslocend + diff, wslocend, buffer.buffertail - wsloc - diff);
436   buffer.buffertail += diff;
437 
438   memcpy(wsloc, ci->original_host.c_str(), ci->original_host.length());
439   return true;
440 }
441 
442 /**
443  * This function examines the buffer for a Host: field, and if it's present,
444  * it replaces the hostname in that field with the hostname in the server's
445  * remote_addr field.  This is needed because proxy requests may be coming
446  * from mochitest with fake hosts, like mochi.test, and these need to be
447  * replaced with the host that the destination server is actually running
448  * on.
449  */
AdjustWebSocketHost(relayBuffer & buffer,connection_info_t * ci)450 bool AdjustWebSocketHost(relayBuffer& buffer, connection_info_t* ci) {
451   const char HEADER_UPGRADE[] = "Upgrade:";
452   const char HEADER_HOST[] = "Host:";
453 
454   PRNetAddr inet_addr =
455       (websocket_server.inet.port ? websocket_server : remote_addr);
456 
457   assert(buffer.margin());
458 
459   // Cannot use strnchr so add a null char at the end. There is always some
460   // space left because we preserve a margin.
461   buffer.buffertail[1] = '\0';
462 
463   // Verify this is a WebSocket header.
464   char* h1 = strstr(buffer.bufferhead, HEADER_UPGRADE);
465   if (!h1) return false;
466   h1 += strlen(HEADER_UPGRADE);
467   h1 += strspn(h1, " \t");
468   char* h2 = strstr(h1, "WebSocket\r\n");
469   if (!h2) h2 = strstr(h1, "websocket\r\n");
470   if (!h2) h2 = strstr(h1, "Websocket\r\n");
471   if (!h2) return false;
472 
473   char* host = strstr(buffer.bufferhead, HEADER_HOST);
474   if (!host) return false;
475   // advance pointer to beginning of hostname
476   host += strlen(HEADER_HOST);
477   host += strspn(host, " \t");
478 
479   char* endhost = strstr(host, "\r\n");
480   if (!endhost) return false;
481 
482   // Save the original host, so we can use it later on responses from the
483   // server.
484   ci->original_host.assign(host, endhost - host);
485 
486   char newhost[40];
487   PR_NetAddrToString(&inet_addr, newhost, sizeof(newhost));
488   assert(strlen(newhost) < sizeof(newhost) - 7);
489   SprintfLiteral(newhost, "%s:%d", newhost, PR_ntohs(inet_addr.inet.port));
490 
491   int diff = strlen(newhost) - (endhost - host);
492   if (diff > 0) assert(size_t(diff) <= buffer.margin());
493   memmove(endhost + diff, endhost, buffer.buffertail - host - diff);
494   buffer.buffertail += diff;
495 
496   memcpy(host, newhost, strlen(newhost));
497   return true;
498 }
499 
500 /**
501  * This function prefixes Request-URI path with a full scheme-host-port
502  * string.
503  */
AdjustRequestURI(relayBuffer & buffer,string * host)504 bool AdjustRequestURI(relayBuffer& buffer, string* host) {
505   assert(buffer.margin());
506 
507   // Cannot use strnchr so add a null char at the end. There is always some
508   // space left because we preserve a margin.
509   buffer.buffertail[1] = '\0';
510   LOG_DEBUG((" incoming request to adjust:\n%s\n", buffer.bufferhead));
511 
512   char *token, *path;
513   path = strchr(buffer.bufferhead, ' ') + 1;
514   if (!path) return false;
515 
516   // If the path doesn't start with a slash don't change it, it is probably '*'
517   // or a full path already. Return true, we are done with this request
518   // adjustment.
519   if (*path != '/') return true;
520 
521   token = strchr(path, ' ') + 1;
522   if (!token) return false;
523 
524   if (strncmp(token, "HTTP/", 5)) return false;
525 
526   size_t hostlength = host->length();
527   assert(hostlength <= buffer.margin());
528 
529   memmove(path + hostlength, path, buffer.buffertail - path);
530   memcpy(path, host->c_str(), hostlength);
531   buffer.buffertail += hostlength;
532 
533   return true;
534 }
535 
ConnectSocket(UniquePRFileDesc & fd,const PRNetAddr * addr,PRIntervalTime timeout)536 bool ConnectSocket(UniquePRFileDesc& fd, const PRNetAddr* addr,
537                    PRIntervalTime timeout) {
538   PRStatus stat = PR_Connect(fd.get(), addr, timeout);
539   if (stat != PR_SUCCESS) return false;
540 
541   PRSocketOptionData option;
542   option.option = PR_SockOpt_Nonblocking;
543   option.value.non_blocking = true;
544   PR_SetSocketOption(fd.get(), &option);
545 
546   return true;
547 }
548 
549 /*
550  * Handle an incoming client connection. The server thread has already
551  * accepted the connection, so we just need to connect to the remote
552  * port and then proxy data back and forth.
553  * The data parameter is a connection_info_t*, and must be deleted
554  * by this function.
555  */
HandleConnection(void * data)556 void HandleConnection(void* data) {
557   connection_info_t* ci = static_cast<connection_info_t*>(data);
558   PRIntervalTime connect_timeout = PR_SecondsToInterval(30);
559 
560   UniquePRFileDesc other_sock(PR_NewTCPSocket());
561   bool client_done = false;
562   bool client_error = false;
563   bool connect_accepted = !do_http_proxy;
564   bool ssl_updated = !do_http_proxy;
565   bool expect_request_start = do_http_proxy;
566   string certificateToUse;
567   string locationHeader;
568   client_auth_option clientAuth;
569   string fullHost;
570   int32_t flags = 0;
571 
572   LOG_DEBUG(("SSLTUNNEL(%p)): incoming connection csock(0)=%p, ssock(1)=%p\n",
573              static_cast<void*>(data), static_cast<void*>(ci->client_sock),
574              static_cast<void*>(other_sock.get())));
575   if (other_sock) {
576     int32_t numberOfSockets = 1;
577 
578     relayBuffer buffers[2];
579 
580     if (!do_http_proxy) {
581       if (!ConfigureSSLServerSocket(ci->client_sock, ci->server_info,
582                                     certificateToUse, caNone, flags))
583         client_error = true;
584       else if (!ConnectSocket(other_sock, &remote_addr, connect_timeout))
585         client_error = true;
586       else
587         numberOfSockets = 2;
588     }
589 
590     PRPollDesc sockets[2] = {{ci->client_sock, PR_POLL_READ, 0},
591                              {other_sock.get(), PR_POLL_READ, 0}};
592     bool socketErrorState[2] = {false, false};
593 
594     while (!((client_error || client_done) && buffers[0].empty() &&
595              buffers[1].empty())) {
596       sockets[0].in_flags |= PR_POLL_EXCEPT;
597       sockets[1].in_flags |= PR_POLL_EXCEPT;
598       LOG_DEBUG(("SSLTUNNEL(%p)): polling flags csock(0)=%c%c, ssock(1)=%c%c\n",
599                  static_cast<void*>(data),
600                  sockets[0].in_flags & PR_POLL_READ ? 'R' : '-',
601                  sockets[0].in_flags & PR_POLL_WRITE ? 'W' : '-',
602                  sockets[1].in_flags & PR_POLL_READ ? 'R' : '-',
603                  sockets[1].in_flags & PR_POLL_WRITE ? 'W' : '-'));
604       int32_t pollStatus =
605           PR_Poll(sockets, numberOfSockets, PR_MillisecondsToInterval(1000));
606       if (pollStatus < 0) {
607         LOG_DEBUG(("SSLTUNNEL(%p)): pollStatus=%d, exiting\n",
608                    static_cast<void*>(data), pollStatus));
609         client_error = true;
610         break;
611       }
612 
613       if (pollStatus == 0) {
614         // timeout
615         LOG_DEBUG(("SSLTUNNEL(%p)): poll timeout, looping\n",
616                    static_cast<void*>(data)));
617         continue;
618       }
619 
620       for (int32_t s = 0; s < numberOfSockets; ++s) {
621         int32_t s2 = s == 1 ? 0 : 1;
622         int16_t out_flags = sockets[s].out_flags;
623         int16_t& in_flags = sockets[s].in_flags;
624         int16_t& in_flags2 = sockets[s2].in_flags;
625         sockets[s].out_flags = 0;
626 
627         LOG_BEGIN_BLOCK();
628         LOG_DEBUG(("SSLTUNNEL(%p)): %csock(%d)=%p out_flags=%d",
629                    static_cast<void*>(data), s == 0 ? 'c' : 's', s,
630                    static_cast<void*>(sockets[s].fd), out_flags));
631         if (out_flags & (PR_POLL_EXCEPT | PR_POLL_ERR | PR_POLL_HUP)) {
632           LOG_DEBUG((" :exception\n"));
633           client_error = true;
634           socketErrorState[s] = true;
635           // We got a fatal error state on the socket. Clear the output buffer
636           // for this socket to break the main loop, we will never more be able
637           // to send those data anyway.
638           buffers[s2].bufferhead = buffers[s2].buffertail = buffers[s2].buffer;
639           continue;
640         }  // PR_POLL_EXCEPT, PR_POLL_ERR, PR_POLL_HUP handling
641 
642         if (out_flags & PR_POLL_READ && !buffers[s].areafree()) {
643           LOG_DEBUG(
644               (" no place in read buffer but got read flag, dropping it now!"));
645           in_flags &= ~PR_POLL_READ;
646         }
647 
648         if (out_flags & PR_POLL_READ && buffers[s].areafree()) {
649           LOG_DEBUG((" :reading"));
650           int32_t bytesRead =
651               PR_Recv(sockets[s].fd, buffers[s].buffertail,
652                       buffers[s].areafree(), 0, PR_INTERVAL_NO_TIMEOUT);
653 
654           if (bytesRead == 0) {
655             LOG_DEBUG((" socket gracefully closed"));
656             client_done = true;
657             in_flags &= ~PR_POLL_READ;
658           } else if (bytesRead < 0) {
659             if (PR_GetError() != PR_WOULD_BLOCK_ERROR) {
660               LOG_DEBUG((" error=%d", PR_GetError()));
661               // We are in error state, indicate that the connection was
662               // not closed gracefully
663               client_error = true;
664               socketErrorState[s] = true;
665               // Wipe out our send buffer, we cannot send it anyway.
666               buffers[s2].bufferhead = buffers[s2].buffertail =
667                   buffers[s2].buffer;
668             } else
669               LOG_DEBUG((" would block"));
670           } else {
671             // If the other socket is in error state (unable to send/receive)
672             // throw this data away and continue loop
673             if (socketErrorState[s2]) {
674               LOG_DEBUG((" have read but other socket is in error state\n"));
675               continue;
676             }
677 
678             buffers[s].buffertail += bytesRead;
679             LOG_DEBUG((", read %d bytes", bytesRead));
680 
681             // We have to accept and handle the initial CONNECT request here
682             int32_t response;
683             if (!connect_accepted &&
684                 ReadConnectRequest(ci->server_info, buffers[s], &response,
685                                    certificateToUse, &clientAuth, fullHost,
686                                    locationHeader, &flags)) {
687               // Mark this as a proxy-only connection (no SSL) if the CONNECT
688               // request didn't come for port 443 or from any of the server's
689               // cert or clientauth hostnames.
690               if (fullHost.find(":443") == string::npos) {
691                 server_match_t match;
692                 match.fullHost = fullHost;
693                 match.matched = false;
694                 PL_HashTableEnumerateEntries(ci->server_info->host_cert_table,
695                                              match_hostname, &match);
696                 PL_HashTableEnumerateEntries(
697                     ci->server_info->host_clientauth_table, match_hostname,
698                     &match);
699                 PL_HashTableEnumerateEntries(ci->server_info->host_ssl3_table,
700                                              match_hostname, &match);
701                 PL_HashTableEnumerateEntries(ci->server_info->host_tls1_table,
702                                              match_hostname, &match);
703                 PL_HashTableEnumerateEntries(ci->server_info->host_rc4_table,
704                                              match_hostname, &match);
705                 PL_HashTableEnumerateEntries(
706                     ci->server_info->host_failhandshake_table, match_hostname,
707                     &match);
708                 ci->http_proxy_only = !match.matched;
709               } else {
710                 ci->http_proxy_only = false;
711               }
712 
713               // Clean the request as it would be read
714               buffers[s].bufferhead = buffers[s].buffertail = buffers[s].buffer;
715               in_flags |= PR_POLL_WRITE;
716               connect_accepted = true;
717 
718               // Store response to the oposite buffer
719               if (response == 200) {
720                 LOG_DEBUG(
721                     (" accepted CONNECT request, connected to the server, "
722                      "sending OK to the client\n"));
723                 strcpy(
724                     buffers[s2].buffer,
725                     "HTTP/1.1 200 Connected\r\nConnection: keep-alive\r\n\r\n");
726               } else if (response == 302) {
727                 LOG_DEBUG(
728                     (" accepted CONNECT request with redirection, "
729                      "sending location and 302 to the client\n"));
730                 client_done = true;
731                 snprintf(buffers[s2].buffer,
732                          buffers[s2].bufferend - buffers[s2].buffer,
733                          "HTTP/1.1 302 Moved\r\n"
734                          "Location: https://%s/\r\n"
735                          "Connection: close\r\n\r\n",
736                          locationHeader.c_str());
737               } else {
738                 LOG_ERRORD(
739                     (" could not read the connect request, closing connection "
740                      "with %d",
741                      response));
742                 client_done = true;
743                 snprintf(buffers[s2].buffer,
744                          buffers[s2].bufferend - buffers[s2].buffer,
745                          "HTTP/1.1 %d ERROR\r\nConnection: close\r\n\r\n",
746                          response);
747 
748                 break;
749               }
750 
751               buffers[s2].buffertail =
752                   buffers[s2].buffer + strlen(buffers[s2].buffer);
753 
754               // Send the response to the client socket
755               break;
756             }  // end of CONNECT handling
757 
758             if (!buffers[s].areafree()) {
759               // Do not poll for read when the buffer is full
760               LOG_DEBUG((" no place in our read buffer, stop reading"));
761               in_flags &= ~PR_POLL_READ;
762             }
763 
764             if (ssl_updated) {
765               if (s == 0 && expect_request_start) {
766                 if (!strstr(buffers[s].bufferhead, "\r\n\r\n")) {
767                   // We haven't received the complete header yet, so wait.
768                   continue;
769                 }
770                 ci->iswebsocket = AdjustWebSocketHost(buffers[s], ci);
771                 expect_request_start = !(
772                     ci->iswebsocket || AdjustRequestURI(buffers[s], &fullHost));
773                 PRNetAddr* addr = &remote_addr;
774                 if (ci->iswebsocket && websocket_server.inet.port)
775                   addr = &websocket_server;
776                 if (!ConnectSocket(other_sock, addr, connect_timeout)) {
777                   LOG_ERRORD(
778                       (" could not open connection to the real server\n"));
779                   client_error = true;
780                   break;
781                 }
782                 LOG_DEBUG(("\n connected to remote server\n"));
783                 numberOfSockets = 2;
784               } else if (s == 1 && ci->iswebsocket) {
785                 if (!AdjustWebSocketLocation(buffers[s], ci)) continue;
786               }
787 
788               in_flags2 |= PR_POLL_WRITE;
789               LOG_DEBUG((" telling the other socket to write"));
790             } else
791               LOG_DEBUG(
792                   (" we have something for the other socket to write, but ssl "
793                    "has not been administered on it"));
794           }
795         }  // PR_POLL_READ handling
796 
797         if (out_flags & PR_POLL_WRITE) {
798           LOG_DEBUG((" :writing"));
799           int32_t bytesWrite =
800               PR_Send(sockets[s].fd, buffers[s2].bufferhead,
801                       buffers[s2].present(), 0, PR_INTERVAL_NO_TIMEOUT);
802 
803           if (bytesWrite < 0) {
804             if (PR_GetError() != PR_WOULD_BLOCK_ERROR) {
805               LOG_DEBUG((" error=%d", PR_GetError()));
806               client_error = true;
807               socketErrorState[s] = true;
808               // We got a fatal error while writting the buffer. Clear it to
809               // break the main loop, we will never more be able to send it.
810               buffers[s2].bufferhead = buffers[s2].buffertail =
811                   buffers[s2].buffer;
812             } else
813               LOG_DEBUG((" would block"));
814           } else {
815             LOG_DEBUG((", written %d bytes", bytesWrite));
816             buffers[s2].buffertail[1] = '\0';
817             LOG_DEBUG((" dump:\n%.*s\n", bytesWrite, buffers[s2].bufferhead));
818 
819             buffers[s2].bufferhead += bytesWrite;
820             if (buffers[s2].present()) {
821               LOG_DEBUG((" still have to write %d bytes",
822                          (int)buffers[s2].present()));
823               in_flags |= PR_POLL_WRITE;
824             } else {
825               if (!ssl_updated) {
826                 LOG_DEBUG((" proxy response sent to the client"));
827                 // Proxy response has just been writen, update to ssl
828                 ssl_updated = true;
829                 if (ci->http_proxy_only) {
830                   LOG_DEBUG(
831                       (" not updating to SSL based on http_proxy_only for this "
832                        "socket"));
833                 } else if (!ConfigureSSLServerSocket(
834                                ci->client_sock, ci->server_info,
835                                certificateToUse, clientAuth, flags)) {
836                   LOG_ERRORD((" failed to config server socket\n"));
837                   client_error = true;
838                   break;
839                 } else {
840                   LOG_DEBUG((" client socket updated to SSL"));
841                 }
842               }  // sslUpdate
843 
844               LOG_DEBUG(
845                   (" dropping our write flag and setting other socket read "
846                    "flag"));
847               in_flags &= ~PR_POLL_WRITE;
848               in_flags2 |= PR_POLL_READ;
849               buffers[s2].compact();
850             }
851           }
852         }                 // PR_POLL_WRITE handling
853         LOG_END_BLOCK();  // end the log
854       }                   // for...
855     }                     // while, poll
856   } else
857     client_error = true;
858 
859   LOG_DEBUG(("SSLTUNNEL(%p)): exiting root function for csock=%p, ssock=%p\n",
860              static_cast<void*>(data), static_cast<void*>(ci->client_sock),
861              static_cast<void*>(other_sock.get())));
862   if (!client_error) PR_Shutdown(ci->client_sock, PR_SHUTDOWN_SEND);
863   PR_Close(ci->client_sock);
864 
865   delete ci;
866 }
867 
868 /*
869  * Start listening for SSL connections on a specified port, handing
870  * them off to client threads after accepting the connection.
871  * The data parameter is a server_info_t*, owned by the calling
872  * function.
873  */
StartServer(void * data)874 void StartServer(void* data) {
875   server_info_t* si = static_cast<server_info_t*>(data);
876 
877   // TODO: select ciphers?
878   UniquePRFileDesc listen_socket(PR_NewTCPSocket());
879   if (!listen_socket) {
880     LOG_ERROR(("failed to create socket\n"));
881     SignalShutdown();
882     return;
883   }
884 
885   // In case the socket is still open in the TIME_WAIT state from a previous
886   // instance of ssltunnel we ask to reuse the port.
887   PRSocketOptionData socket_option;
888   socket_option.option = PR_SockOpt_Reuseaddr;
889   socket_option.value.reuse_addr = true;
890   PR_SetSocketOption(listen_socket.get(), &socket_option);
891 
892   PRNetAddr server_addr;
893   PR_InitializeNetAddr(PR_IpAddrAny, si->listen_port, &server_addr);
894   if (PR_Bind(listen_socket.get(), &server_addr) != PR_SUCCESS) {
895     LOG_ERROR(("failed to bind socket on port %d: error %d\n", si->listen_port,
896                PR_GetError()));
897     SignalShutdown();
898     return;
899   }
900 
901   if (PR_Listen(listen_socket.get(), 1) != PR_SUCCESS) {
902     LOG_ERROR(("failed to listen on socket\n"));
903     SignalShutdown();
904     return;
905   }
906 
907   LOG_INFO(("Server listening on port %d with cert %s\n", si->listen_port,
908             si->cert_nickname.c_str()));
909 
910   while (!shutdown_server) {
911     connection_info_t* ci = new connection_info_t();
912     ci->server_info = si;
913     ci->http_proxy_only = do_http_proxy;
914     // block waiting for connections
915     ci->client_sock = PR_Accept(listen_socket.get(), &ci->client_addr,
916                                 PR_INTERVAL_NO_TIMEOUT);
917 
918     PRSocketOptionData option;
919     option.option = PR_SockOpt_Nonblocking;
920     option.value.non_blocking = true;
921     PR_SetSocketOption(ci->client_sock, &option);
922 
923     if (ci->client_sock)
924       // Not actually using this PRJob*...
925       // PRJob* job =
926       PR_QueueJob(threads, HandleConnection, ci, true);
927     else
928       delete ci;
929   }
930 }
931 
932 // bogus password func, just don't use passwords. :-P
password_func(PK11SlotInfo * slot,PRBool retry,void * arg)933 char* password_func(PK11SlotInfo* slot, PRBool retry, void* arg) {
934   if (retry) return nullptr;
935 
936   return PL_strdup("");
937 }
938 
findServerInfo(int portnumber)939 server_info_t* findServerInfo(int portnumber) {
940   for (auto& server : servers) {
941     if (server.listen_port == portnumber) return &server;
942   }
943 
944   return nullptr;
945 }
946 
get_ssl3_table(server_info_t * server)947 PLHashTable* get_ssl3_table(server_info_t* server) {
948   return server->host_ssl3_table;
949 }
950 
get_tls1_table(server_info_t * server)951 PLHashTable* get_tls1_table(server_info_t* server) {
952   return server->host_tls1_table;
953 }
954 
get_rc4_table(server_info_t * server)955 PLHashTable* get_rc4_table(server_info_t* server) {
956   return server->host_rc4_table;
957 }
958 
get_failhandshake_table(server_info_t * server)959 PLHashTable* get_failhandshake_table(server_info_t* server) {
960   return server->host_failhandshake_table;
961 }
962 
parseWeakCryptoConfig(char * const & keyword,char * & _caret,PLHashTable * (* get_table)(server_info_t *))963 int parseWeakCryptoConfig(char* const& keyword, char*& _caret,
964                           PLHashTable* (*get_table)(server_info_t*)) {
965   char* hostname = strtok2(_caret, ":", &_caret);
966   char* hostportstring = strtok2(_caret, ":", &_caret);
967   char* serverportstring = strtok2(_caret, "\n", &_caret);
968 
969   int port = atoi(serverportstring);
970   if (port <= 0) {
971     LOG_ERROR(("Invalid port specified: %s\n", serverportstring));
972     return 1;
973   }
974 
975   if (server_info_t* existingServer = findServerInfo(port)) {
976     any_host_spec_config = true;
977 
978     char* hostname_copy =
979         new char[strlen(hostname) + strlen(hostportstring) + 2];
980     if (!hostname_copy) {
981       LOG_ERROR(("Out of memory"));
982       return 1;
983     }
984 
985     strcpy(hostname_copy, hostname);
986     strcat(hostname_copy, ":");
987     strcat(hostname_copy, hostportstring);
988 
989     PLHashEntry* entry =
990         PL_HashTableAdd(get_table(existingServer), hostname_copy, keyword);
991     if (!entry) {
992       LOG_ERROR(("Out of memory"));
993       return 1;
994     }
995   } else {
996     LOG_ERROR(
997         ("Server on port %d for redirhost option is not defined, use 'listen' "
998          "option first",
999          port));
1000     return 1;
1001   }
1002 
1003   return 0;
1004 }
1005 
processConfigLine(char * configLine)1006 int processConfigLine(char* configLine) {
1007   if (*configLine == 0 || *configLine == '#') return 0;
1008 
1009   char* _caret;
1010   char* keyword = strtok2(configLine, ":", &_caret);
1011 
1012   // Configure usage of http/ssl tunneling proxy behavior
1013   if (!strcmp(keyword, "httpproxy")) {
1014     char* value = strtok2(_caret, ":", &_caret);
1015     if (!strcmp(value, "1")) do_http_proxy = true;
1016 
1017     return 0;
1018   }
1019 
1020   if (!strcmp(keyword, "websocketserver")) {
1021     char* ipstring = strtok2(_caret, ":", &_caret);
1022     if (PR_StringToNetAddr(ipstring, &websocket_server) != PR_SUCCESS) {
1023       LOG_ERROR(("Invalid IP address in proxy config: %s\n", ipstring));
1024       return 1;
1025     }
1026     char* remoteport = strtok2(_caret, ":", &_caret);
1027     int port = atoi(remoteport);
1028     if (port <= 0) {
1029       LOG_ERROR(("Invalid remote port in proxy config: %s\n", remoteport));
1030       return 1;
1031     }
1032     websocket_server.inet.port = PR_htons(port);
1033     return 0;
1034   }
1035 
1036   // Configure the forward address of the target server
1037   if (!strcmp(keyword, "forward")) {
1038     char* ipstring = strtok2(_caret, ":", &_caret);
1039     if (PR_StringToNetAddr(ipstring, &remote_addr) != PR_SUCCESS) {
1040       LOG_ERROR(("Invalid remote IP address: %s\n", ipstring));
1041       return 1;
1042     }
1043     char* serverportstring = strtok2(_caret, ":", &_caret);
1044     int port = atoi(serverportstring);
1045     if (port <= 0) {
1046       LOG_ERROR(("Invalid remote port: %s\n", serverportstring));
1047       return 1;
1048     }
1049     remote_addr.inet.port = PR_htons(port);
1050 
1051     return 0;
1052   }
1053 
1054   // Configure all listen sockets and port+certificate bindings
1055   if (!strcmp(keyword, "listen")) {
1056     char* hostname = strtok2(_caret, ":", &_caret);
1057     char* hostportstring = nullptr;
1058     if (strcmp(hostname, "*")) {
1059       any_host_spec_config = true;
1060       hostportstring = strtok2(_caret, ":", &_caret);
1061     }
1062 
1063     char* serverportstring = strtok2(_caret, ":", &_caret);
1064     char* certnick = strtok2(_caret, ":", &_caret);
1065 
1066     int port = atoi(serverportstring);
1067     if (port <= 0) {
1068       LOG_ERROR(("Invalid port specified: %s\n", serverportstring));
1069       return 1;
1070     }
1071 
1072     if (server_info_t* existingServer = findServerInfo(port)) {
1073       char* certnick_copy = new char[strlen(certnick) + 1];
1074       char* hostname_copy =
1075           new char[strlen(hostname) + strlen(hostportstring) + 2];
1076 
1077       strcpy(hostname_copy, hostname);
1078       strcat(hostname_copy, ":");
1079       strcat(hostname_copy, hostportstring);
1080       strcpy(certnick_copy, certnick);
1081 
1082       PLHashEntry* entry = PL_HashTableAdd(existingServer->host_cert_table,
1083                                            hostname_copy, certnick_copy);
1084       if (!entry) {
1085         LOG_ERROR(("Out of memory"));
1086         return 1;
1087       }
1088     } else {
1089       server_info_t server;
1090       server.cert_nickname = certnick;
1091       server.listen_port = port;
1092       server.host_cert_table =
1093           PL_NewHashTable(0, PL_HashString, PL_CompareStrings,
1094                           PL_CompareStrings, nullptr, nullptr);
1095       if (!server.host_cert_table) {
1096         LOG_ERROR(("Internal, could not create hash table\n"));
1097         return 1;
1098       }
1099       server.host_clientauth_table =
1100           PL_NewHashTable(0, PL_HashString, PL_CompareStrings,
1101                           ClientAuthValueComparator, nullptr, nullptr);
1102       if (!server.host_clientauth_table) {
1103         LOG_ERROR(("Internal, could not create hash table\n"));
1104         return 1;
1105       }
1106       server.host_redir_table =
1107           PL_NewHashTable(0, PL_HashString, PL_CompareStrings,
1108                           PL_CompareStrings, nullptr, nullptr);
1109       if (!server.host_redir_table) {
1110         LOG_ERROR(("Internal, could not create hash table\n"));
1111         return 1;
1112       }
1113 
1114       server.host_ssl3_table =
1115           PL_NewHashTable(0, PL_HashString, PL_CompareStrings,
1116                           PL_CompareStrings, nullptr, nullptr);
1117       ;
1118       if (!server.host_ssl3_table) {
1119         LOG_ERROR(("Internal, could not create hash table\n"));
1120         return 1;
1121       }
1122 
1123       server.host_tls1_table =
1124           PL_NewHashTable(0, PL_HashString, PL_CompareStrings,
1125                           PL_CompareStrings, nullptr, nullptr);
1126       ;
1127       if (!server.host_tls1_table) {
1128         LOG_ERROR(("Internal, could not create hash table\n"));
1129         return 1;
1130       }
1131 
1132       server.host_rc4_table =
1133           PL_NewHashTable(0, PL_HashString, PL_CompareStrings,
1134                           PL_CompareStrings, nullptr, nullptr);
1135       ;
1136       if (!server.host_rc4_table) {
1137         LOG_ERROR(("Internal, could not create hash table\n"));
1138         return 1;
1139       }
1140 
1141       server.host_failhandshake_table =
1142           PL_NewHashTable(0, PL_HashString, PL_CompareStrings,
1143                           PL_CompareStrings, nullptr, nullptr);
1144       ;
1145       if (!server.host_failhandshake_table) {
1146         LOG_ERROR(("Internal, could not create hash table\n"));
1147         return 1;
1148       }
1149 
1150       servers.push_back(server);
1151     }
1152 
1153     return 0;
1154   }
1155 
1156   if (!strcmp(keyword, "clientauth")) {
1157     char* hostname = strtok2(_caret, ":", &_caret);
1158     char* hostportstring = strtok2(_caret, ":", &_caret);
1159     char* serverportstring = strtok2(_caret, ":", &_caret);
1160 
1161     int port = atoi(serverportstring);
1162     if (port <= 0) {
1163       LOG_ERROR(("Invalid port specified: %s\n", serverportstring));
1164       return 1;
1165     }
1166 
1167     if (server_info_t* existingServer = findServerInfo(port)) {
1168       char* authoptionstring = strtok2(_caret, ":", &_caret);
1169       client_auth_option* authoption = new client_auth_option;
1170       if (!authoption) {
1171         LOG_ERROR(("Out of memory"));
1172         return 1;
1173       }
1174 
1175       if (!strcmp(authoptionstring, "require"))
1176         *authoption = caRequire;
1177       else if (!strcmp(authoptionstring, "request"))
1178         *authoption = caRequest;
1179       else if (!strcmp(authoptionstring, "none"))
1180         *authoption = caNone;
1181       else {
1182         LOG_ERROR(
1183             ("Incorrect client auth option modifier for host '%s'", hostname));
1184         delete authoption;
1185         return 1;
1186       }
1187 
1188       any_host_spec_config = true;
1189 
1190       char* hostname_copy =
1191           new char[strlen(hostname) + strlen(hostportstring) + 2];
1192       if (!hostname_copy) {
1193         LOG_ERROR(("Out of memory"));
1194         delete authoption;
1195         return 1;
1196       }
1197 
1198       strcpy(hostname_copy, hostname);
1199       strcat(hostname_copy, ":");
1200       strcat(hostname_copy, hostportstring);
1201 
1202       PLHashEntry* entry = PL_HashTableAdd(
1203           existingServer->host_clientauth_table, hostname_copy, authoption);
1204       if (!entry) {
1205         LOG_ERROR(("Out of memory"));
1206         delete authoption;
1207         return 1;
1208       }
1209     } else {
1210       LOG_ERROR(
1211           ("Server on port %d for client authentication option is not defined, "
1212            "use 'listen' option first",
1213            port));
1214       return 1;
1215     }
1216 
1217     return 0;
1218   }
1219 
1220   if (!strcmp(keyword, "redirhost")) {
1221     char* hostname = strtok2(_caret, ":", &_caret);
1222     char* hostportstring = strtok2(_caret, ":", &_caret);
1223     char* serverportstring = strtok2(_caret, ":", &_caret);
1224 
1225     int port = atoi(serverportstring);
1226     if (port <= 0) {
1227       LOG_ERROR(("Invalid port specified: %s\n", serverportstring));
1228       return 1;
1229     }
1230 
1231     if (server_info_t* existingServer = findServerInfo(port)) {
1232       char* redirhoststring = strtok2(_caret, ":", &_caret);
1233 
1234       any_host_spec_config = true;
1235 
1236       char* hostname_copy =
1237           new char[strlen(hostname) + strlen(hostportstring) + 2];
1238       if (!hostname_copy) {
1239         LOG_ERROR(("Out of memory"));
1240         return 1;
1241       }
1242 
1243       strcpy(hostname_copy, hostname);
1244       strcat(hostname_copy, ":");
1245       strcat(hostname_copy, hostportstring);
1246 
1247       char* redir_copy = new char[strlen(redirhoststring) + 1];
1248       strcpy(redir_copy, redirhoststring);
1249       PLHashEntry* entry = PL_HashTableAdd(existingServer->host_redir_table,
1250                                            hostname_copy, redir_copy);
1251       if (!entry) {
1252         LOG_ERROR(("Out of memory"));
1253         delete[] hostname_copy;
1254         delete[] redir_copy;
1255         return 1;
1256       }
1257     } else {
1258       LOG_ERROR(
1259           ("Server on port %d for redirhost option is not defined, use "
1260            "'listen' option first",
1261            port));
1262       return 1;
1263     }
1264 
1265     return 0;
1266   }
1267 
1268   if (!strcmp(keyword, "ssl3")) {
1269     return parseWeakCryptoConfig(keyword, _caret, get_ssl3_table);
1270   }
1271   if (!strcmp(keyword, "tls1")) {
1272     return parseWeakCryptoConfig(keyword, _caret, get_tls1_table);
1273   }
1274 
1275   if (!strcmp(keyword, "rc4")) {
1276     return parseWeakCryptoConfig(keyword, _caret, get_rc4_table);
1277   }
1278 
1279   if (!strcmp(keyword, "failHandshake")) {
1280     return parseWeakCryptoConfig(keyword, _caret, get_failhandshake_table);
1281   }
1282 
1283   // Configure the NSS certificate database directory
1284   if (!strcmp(keyword, "certdbdir")) {
1285     nssconfigdir = strtok2(_caret, "\n", &_caret);
1286     return 0;
1287   }
1288 
1289   LOG_ERROR(("Error: keyword \"%s\" unexpected\n", keyword));
1290   return 1;
1291 }
1292 
parseConfigFile(const char * filePath)1293 int parseConfigFile(const char* filePath) {
1294   FILE* f = fopen(filePath, "r");
1295   if (!f) return 1;
1296 
1297   char buffer[1024], *b = buffer;
1298   while (!feof(f)) {
1299     char c;
1300 
1301     if (fscanf(f, "%c", &c) != 1) {
1302       break;
1303     }
1304 
1305     switch (c) {
1306       case '\n':
1307         *b++ = 0;
1308         if (processConfigLine(buffer)) {
1309           fclose(f);
1310           return 1;
1311         }
1312         b = buffer;
1313         continue;
1314 
1315       case '\r':
1316         continue;
1317 
1318       default:
1319         *b++ = c;
1320     }
1321   }
1322 
1323   fclose(f);
1324 
1325   // Check mandatory items
1326   if (nssconfigdir.empty()) {
1327     LOG_ERROR(
1328         ("Error: missing path to NSS certification database\n,use "
1329          "certdbdir:<path> in the config file\n"));
1330     return 1;
1331   }
1332 
1333   if (any_host_spec_config && !do_http_proxy) {
1334     LOG_ERROR(
1335         ("Warning: any host-specific configurations are ignored, add "
1336          "httpproxy:1 to allow them\n"));
1337   }
1338 
1339   return 0;
1340 }
1341 
freeHostCertHashItems(PLHashEntry * he,int i,void * arg)1342 int freeHostCertHashItems(PLHashEntry* he, int i, void* arg) {
1343   delete[](char*) he->key;
1344   delete[](char*) he->value;
1345   return HT_ENUMERATE_REMOVE;
1346 }
1347 
freeHostRedirHashItems(PLHashEntry * he,int i,void * arg)1348 int freeHostRedirHashItems(PLHashEntry* he, int i, void* arg) {
1349   delete[](char*) he->key;
1350   delete[](char*) he->value;
1351   return HT_ENUMERATE_REMOVE;
1352 }
1353 
freeClientAuthHashItems(PLHashEntry * he,int i,void * arg)1354 int freeClientAuthHashItems(PLHashEntry* he, int i, void* arg) {
1355   delete[](char*) he->key;
1356   delete (client_auth_option*)he->value;
1357   return HT_ENUMERATE_REMOVE;
1358 }
1359 
freeSSL3HashItems(PLHashEntry * he,int i,void * arg)1360 int freeSSL3HashItems(PLHashEntry* he, int i, void* arg) {
1361   delete[](char*) he->key;
1362   return HT_ENUMERATE_REMOVE;
1363 }
1364 
freeTLS1HashItems(PLHashEntry * he,int i,void * arg)1365 int freeTLS1HashItems(PLHashEntry* he, int i, void* arg) {
1366   delete[](char*) he->key;
1367   return HT_ENUMERATE_REMOVE;
1368 }
1369 
freeRC4HashItems(PLHashEntry * he,int i,void * arg)1370 int freeRC4HashItems(PLHashEntry* he, int i, void* arg) {
1371   delete[](char*) he->key;
1372   return HT_ENUMERATE_REMOVE;
1373 }
1374 
main(int argc,char ** argv)1375 int main(int argc, char** argv) {
1376   const char* configFilePath;
1377 
1378   const char* logLevelEnv = PR_GetEnv("SSLTUNNEL_LOG_LEVEL");
1379   gLogLevel = logLevelEnv ? (LogLevel)atoi(logLevelEnv) : LEVEL_INFO;
1380 
1381   if (argc == 1)
1382     configFilePath = "ssltunnel.cfg";
1383   else
1384     configFilePath = argv[1];
1385 
1386   memset(&websocket_server, 0, sizeof(PRNetAddr));
1387 
1388   if (parseConfigFile(configFilePath)) {
1389     LOG_ERROR((
1390         "Error: config file \"%s\" missing or formating incorrect\n"
1391         "Specify path to the config file as parameter to ssltunnel or \n"
1392         "create ssltunnel.cfg in the working directory.\n\n"
1393         "Example format of the config file:\n\n"
1394         "       # Enable http/ssl tunneling proxy-like behavior.\n"
1395         "       # If not specified ssltunnel simply does direct forward.\n"
1396         "       httpproxy:1\n\n"
1397         "       # Specify path to the certification database used.\n"
1398         "       certdbdir:/path/to/certdb\n\n"
1399         "       # Forward/proxy all requests in raw to 127.0.0.1:8888.\n"
1400         "       forward:127.0.0.1:8888\n\n"
1401         "       # Accept connections on port 4443 or 5678 resp. and "
1402         "authenticate\n"
1403         "       # to any host ('*') using the 'server cert' or 'server cert 2' "
1404         "resp.\n"
1405         "       listen:*:4443:server cert\n"
1406         "       listen:*:5678:server cert 2\n\n"
1407         "       # Accept connections on port 4443 and authenticate using\n"
1408         "       # 'a different cert' when target host is 'my.host.name:443'.\n"
1409         "       # This only works in httpproxy mode and has higher priority\n"
1410         "       # than the previous option.\n"
1411         "       listen:my.host.name:443:4443:a different cert\n\n"
1412         "       # To make a specific host require or just request a client "
1413         "certificate\n"
1414         "       # to authenticate use the following options. This can only be "
1415         "used\n"
1416         "       # in httpproxy mode and only after the 'listen' option has "
1417         "been\n"
1418         "       # specified. You also have to specify the tunnel listen port.\n"
1419         "       clientauth:requesting-client-cert.host.com:443:4443:request\n"
1420         "       clientauth:requiring-client-cert.host.com:443:4443:require\n"
1421         "       # Proxy WebSocket traffic to the server at 127.0.0.1:9999,\n"
1422         "       # instead of the server specified in the 'forward' option.\n"
1423         "       websocketserver:127.0.0.1:9999\n",
1424         configFilePath));
1425     return 1;
1426   }
1427 
1428   // create a thread pool to handle connections
1429   threads =
1430       PR_CreateThreadPool(INITIAL_THREADS * servers.size(),
1431                           MAX_THREADS * servers.size(), DEFAULT_STACKSIZE);
1432   if (!threads) {
1433     LOG_ERROR(("Failed to create thread pool\n"));
1434     return 1;
1435   }
1436 
1437   shutdown_lock = PR_NewLock();
1438   if (!shutdown_lock) {
1439     LOG_ERROR(("Failed to create lock\n"));
1440     PR_ShutdownThreadPool(threads);
1441     return 1;
1442   }
1443   shutdown_condvar = PR_NewCondVar(shutdown_lock);
1444   if (!shutdown_condvar) {
1445     LOG_ERROR(("Failed to create condvar\n"));
1446     PR_ShutdownThreadPool(threads);
1447     PR_DestroyLock(shutdown_lock);
1448     return 1;
1449   }
1450 
1451   PK11_SetPasswordFunc(password_func);
1452 
1453   // Initialize NSS
1454   if (NSS_Init(nssconfigdir.c_str()) != SECSuccess) {
1455     int32_t errorlen = PR_GetErrorTextLength();
1456     if (errorlen) {
1457       auto err = mozilla::MakeUnique<char[]>(errorlen + 1);
1458       PR_GetErrorText(err.get());
1459       LOG_ERROR(("Failed to init NSS: %s", err.get()));
1460     } else {
1461       LOG_ERROR(("Failed to init NSS: Cannot get error from NSPR."));
1462     }
1463     PR_ShutdownThreadPool(threads);
1464     PR_DestroyCondVar(shutdown_condvar);
1465     PR_DestroyLock(shutdown_lock);
1466     return 1;
1467   }
1468 
1469   if (NSS_SetDomesticPolicy() != SECSuccess) {
1470     LOG_ERROR(("NSS_SetDomesticPolicy failed\n"));
1471     PR_ShutdownThreadPool(threads);
1472     PR_DestroyCondVar(shutdown_condvar);
1473     PR_DestroyLock(shutdown_lock);
1474     NSS_Shutdown();
1475     return 1;
1476   }
1477 
1478   // these values should make NSS use the defaults
1479   if (SSL_ConfigServerSessionIDCache(0, 0, 0, nullptr) != SECSuccess) {
1480     LOG_ERROR(("SSL_ConfigServerSessionIDCache failed\n"));
1481     PR_ShutdownThreadPool(threads);
1482     PR_DestroyCondVar(shutdown_condvar);
1483     PR_DestroyLock(shutdown_lock);
1484     NSS_Shutdown();
1485     return 1;
1486   }
1487 
1488   for (auto& server : servers) {
1489     // Not actually using this PRJob*...
1490     // PRJob* server_job =
1491     PR_QueueJob(threads, StartServer, &server, true);
1492   }
1493   // now wait for someone to tell us to quit
1494   PR_Lock(shutdown_lock);
1495   PR_WaitCondVar(shutdown_condvar, PR_INTERVAL_NO_TIMEOUT);
1496   PR_Unlock(shutdown_lock);
1497   shutdown_server = true;
1498   LOG_INFO(("Shutting down...\n"));
1499   // cleanup
1500   PR_ShutdownThreadPool(threads);
1501   PR_JoinThreadPool(threads);
1502   PR_DestroyCondVar(shutdown_condvar);
1503   PR_DestroyLock(shutdown_lock);
1504   if (NSS_Shutdown() == SECFailure) {
1505     LOG_DEBUG(("Leaked NSS objects!\n"));
1506   }
1507 
1508   for (auto& server : servers) {
1509     PL_HashTableEnumerateEntries(server.host_cert_table, freeHostCertHashItems,
1510                                  nullptr);
1511     PL_HashTableEnumerateEntries(server.host_clientauth_table,
1512                                  freeClientAuthHashItems, nullptr);
1513     PL_HashTableEnumerateEntries(server.host_redir_table,
1514                                  freeHostRedirHashItems, nullptr);
1515     PL_HashTableEnumerateEntries(server.host_ssl3_table, freeSSL3HashItems,
1516                                  nullptr);
1517     PL_HashTableEnumerateEntries(server.host_tls1_table, freeTLS1HashItems,
1518                                  nullptr);
1519     PL_HashTableEnumerateEntries(server.host_rc4_table, freeRC4HashItems,
1520                                  nullptr);
1521     PL_HashTableEnumerateEntries(server.host_failhandshake_table,
1522                                  freeRC4HashItems, nullptr);
1523     PL_HashTableDestroy(server.host_cert_table);
1524     PL_HashTableDestroy(server.host_clientauth_table);
1525     PL_HashTableDestroy(server.host_redir_table);
1526     PL_HashTableDestroy(server.host_ssl3_table);
1527     PL_HashTableDestroy(server.host_tls1_table);
1528     PL_HashTableDestroy(server.host_rc4_table);
1529     PL_HashTableDestroy(server.host_failhandshake_table);
1530   }
1531 
1532   PR_Cleanup();
1533   return 0;
1534 }
1535