1 /* -*- Mode: C++; tab-width: 2; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
2 /* This Source Code Form is subject to the terms of the Mozilla Public
3  * License, v. 2.0. If a copy of the MPL was not distributed with this
4  * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
5 
6 /*
7  * WARNING: DO NOT USE THIS CODE IN PRODUCTION SYSTEMS.  It is highly likely to
8  *          be plagued with the usual problems endemic to C (buffer overflows
9  *          and the like).  We don't especially care here (but would accept
10  *          patches!) because this is only intended for use in our test
11  *          harnesses in controlled situations where input is guaranteed not to
12  *          be malicious.
13  */
14 
15 #include "ScopedNSSTypes.h"
16 #include <assert.h>
17 #include <stdio.h>
18 #include <string>
19 #include <vector>
20 #include <algorithm>
21 #include <stdarg.h>
22 #include "prinit.h"
23 #include "prerror.h"
24 #include "prenv.h"
25 #include "prnetdb.h"
26 #include "prtpool.h"
27 #include "nsAlgorithm.h"
28 #include "nss.h"
29 #include "keyhi.h"
30 #include "ssl.h"
31 #include "sslproto.h"
32 #include "plhash.h"
33 #include "mozilla/Sprintf.h"
34 #include "mozilla/Unused.h"
35 
36 using namespace mozilla;
37 using namespace mozilla::psm;
38 using std::string;
39 using std::vector;
40 
// Bitmap helpers used by strtok2(). A delimiter table is an array of
// DELIM_TABLE_SIZE bytes holding one bit per possible byte value
// (32 * 8 = 256 bits). IS_DELIM tests and SET_DELIM sets the bit for byte `c`.
// `c` must be a non-negative byte value (strtok2 casts to uint8_t) so that
// `c >> 3` indexes inside the table.
#define IS_DELIM(m, c) ((m)[(c) >> 3] & (1 << ((c)&7)))
#define SET_DELIM(m, c) ((m)[(c) >> 3] |= (1 << ((c)&7)))
#define DELIM_TABLE_SIZE 32
44 
// You can set the level of logging by env var SSLTUNNEL_LOG_LEVEL=n, where n
// is 0 through 3.  The default is 1, INFO level logging.
// A message is printed only when its level is >= gLogLevel, so higher values
// are quieter.  gLastLogLevel records the level of the most recently printed
// message (see LOG_BEGIN_BLOCK / LOG_END_BLOCK).
// NOTE(review): as namespace-scope globals both variables start
// zero-initialized (LEVEL_DEBUG); the INFO default is presumably applied at
// startup from the environment variable -- confirm in main().
enum LogLevel {
  LEVEL_DEBUG = 0,
  LEVEL_INFO = 1,
  LEVEL_ERROR = 2,
  LEVEL_SILENT = 3
} gLogLevel,
    gLastLogLevel;

// Core logging macro: when `level` passes the gLogLevel filter, remember it in
// gLastLogLevel and call `func` with the parenthesized argument list `params`,
// e.g. _LOG_OUTPUT(LEVEL_INFO, printf, ("%d\n", x)).
#define _LOG_OUTPUT(level, func, params) \
  PR_BEGIN_MACRO                         \
  if (level >= gLogLevel) {              \
    gLastLogLevel = level;               \
    func params;                         \
  }                                      \
  PR_END_MACRO

// The most verbose output (stdout)
#define LOG_DEBUG(params) _LOG_OUTPUT(LEVEL_DEBUG, printf, params)

// Top level informative messages (stdout)
#define LOG_INFO(params) _LOG_OUTPUT(LEVEL_INFO, printf, params)

// Serious errors, written to stderr; printed at every level except
// LEVEL_SILENT.
#define LOG_ERROR(params) _LOG_OUTPUT(LEVEL_ERROR, eprintf, params)

// Same as LOG_ERROR, but when logging is set to LEVEL_DEBUG, the message
// will be put to the stdout instead of stderr to keep continuity with other
// LOG_DEBUG message output
#define LOG_ERRORD(params)                     \
  PR_BEGIN_MACRO                               \
  if (gLogLevel == LEVEL_DEBUG)                \
    _LOG_OUTPUT(LEVEL_ERROR, printf, params);  \
  else                                         \
    _LOG_OUTPUT(LEVEL_ERROR, eprintf, params); \
  PR_END_MACRO

// If there is any output written between LOG_BEGIN_BLOCK() and
// LOG_END_BLOCK() then a new line will be put to the proper output (out/err)
// LOG_BEGIN_BLOCK resets gLastLogLevel to LEVEL_SILENT so LOG_END_BLOCK can
// tell whether anything was printed in between.  Its expansion already ends
// with ';'.
#define LOG_BEGIN_BLOCK() gLastLogLevel = LEVEL_SILENT;

// Emits the terminating newline: on stderr (via LOG_ERROR) if the last printed
// message was an error, otherwise on stdout at the last message's level.
// NOTE(review): after LOG_ERRORD at LEVEL_DEBUG the message went to stdout but
// gLastLogLevel is LEVEL_ERROR, so this newline goes to stderr.
#define LOG_END_BLOCK()                                                        \
  PR_BEGIN_MACRO                                                               \
  if (gLastLogLevel == LEVEL_ERROR) LOG_ERROR(("\n"));                         \
  if (gLastLogLevel < LEVEL_ERROR) _LOG_OUTPUT(gLastLogLevel, printf, ("\n")); \
  PR_END_MACRO
92 
/*
 * printf() work-alike that writes to stderr instead of stdout.
 * Returns vfprintf()'s result: the number of characters written, or a
 * negative value on an output error.
 */
int eprintf(const char* str, ...) {
  va_list args;
  va_start(args, str);
  const int written = vfprintf(stderr, str, args);
  va_end(args);
  return written;
}
100 
101 // Copied from nsCRT
strtok2(char * string,const char * delims,char ** newStr)102 char* strtok2(char* string, const char* delims, char** newStr) {
103   PR_ASSERT(string);
104 
105   char delimTable[DELIM_TABLE_SIZE];
106   uint32_t i;
107   char* result;
108   char* str = string;
109 
110   for (i = 0; i < DELIM_TABLE_SIZE; i++) delimTable[i] = '\0';
111 
112   for (i = 0; delims[i]; i++) {
113     SET_DELIM(delimTable, static_cast<uint8_t>(delims[i]));
114   }
115 
116   // skip to beginning
117   while (*str && IS_DELIM(delimTable, static_cast<uint8_t>(*str))) {
118     str++;
119   }
120   result = str;
121 
122   // fix up the end of the token
123   while (*str) {
124     if (IS_DELIM(delimTable, static_cast<uint8_t>(*str))) {
125       *str++ = '\0';
126       break;
127     }
128     str++;
129   }
130   *newStr = str;
131 
132   return str == result ? nullptr : result;
133 }
134 
// Client-certificate policy for a host, applied by ConfigureSSLServerSocket():
// caNone never asks for a client certificate, caRequest asks for one
// (SSL_REQUEST_CERTIFICATE), caRequire additionally sets
// SSL_REQUIRE_CERTIFICATE.
enum client_auth_option { caNone = 0, caRequire = 1, caRequest = 2 };
136 
// Structs for passing data into jobs on the thread pool
// Per-server configuration.  Every host_*_table is looked up with the
// "host:port" target of a proxy CONNECT request (see ReadConnectRequest()).
struct server_info_t {
  // presumably the port this server listens on -- confirm in main()
  int32_t listen_port;
  // Certificate nickname used when no per-host certificate is configured.
  string cert_nickname;
  // Values are certificate nicknames (char*).
  PLHashTable* host_cert_table;
  // Values point to a client_auth_option.
  PLHashTable* host_clientauth_table;
  // Values are redirect locations (char*); a hit makes CONNECT answer 302.
  PLHashTable* host_redir_table;
  // For the remaining tables only the presence of the host matters; each one
  // maps to a connection flag (USE_SSL3, USE_TLS1, ..., USE_RC4,
  // FAIL_HANDSHAKE).
  PLHashTable* host_ssl3_table;
  PLHashTable* host_tls1_table;
  PLHashTable* host_tls11_table;
  PLHashTable* host_tls12_table;
  PLHashTable* host_tls13_table;
  PLHashTable* host_rc4_table;
  PLHashTable* host_failhandshake_table;
};
152 
153 struct connection_info_t {
154   PRFileDesc* client_sock;
155   PRNetAddr client_addr;
156   server_info_t* server_info;
157   // the original host in the Host: header for this connection is
158   // stored here, for proxied connections
159   string original_host;
160   // true if no SSL should be used for this connection
161   bool http_proxy_only;
162   // true if this connection is for a WebSocket
163   bool iswebsocket;
164 };
165 
// Accumulator passed to match_hostname() through PL_HashTableEnumerateEntries:
// fullHost is the "https://host:port" string being checked; matched is set to
// true once any enumerated table key occurs in it as a substring.
struct server_match_t {
  string fullHost;
  bool matched;
};
170 
// Relay buffer geometry. BUF_SIZE bytes are available for data read from a
// socket. A further BUF_MARGIN bytes past that are reserved so the request
// rewriting helpers can grow the buffered data in place.
const int32_t BUF_SIZE = 16384;
const int32_t BUF_MARGIN = 1024;
const int32_t BUF_TOTAL = BUF_SIZE + BUF_MARGIN;

/*
 * Fixed-size byte queue used to relay data from one socket to another.
 *
 * Within the BUF_TOTAL-byte allocation, buffer <= bufferhead <= buffertail,
 * and bufferend == buffer + BUF_SIZE.  [bufferhead, buffertail) is the data
 * still to be forwarded.  Reads append at buffertail and stop at bufferend.
 * In-place rewriting (e.g. AdjustRequestURI) may push buffertail into the
 * margin beyond bufferend.
 *
 * The object owns its allocation and is therefore non-copyable.
 */
struct relayBuffer {
  char *buffer, *bufferhead, *buffertail, *bufferend;

  relayBuffer() {
    // Leave BUF_MARGIN bytes more for request line manipulations
    bufferhead = buffertail = buffer = new char[BUF_TOTAL];
    bufferend = buffer + BUF_SIZE;
  }

  ~relayBuffer() { delete[] buffer; }

  // A copy would share `buffer` and free it twice.
  relayBuffer(const relayBuffer&) = delete;
  relayBuffer& operator=(const relayBuffer&) = delete;

  // Rewind to the start of the allocation once all data has been forwarded.
  void compact() {
    if (buffertail == bufferhead) buffertail = bufferhead = buffer;
  }

  // True when no data is waiting to be forwarded.
  bool empty() { return bufferhead == buffertail; }

  // Bytes that may still be read in before bufferend.  Returns 0 (rather than
  // a wrapped-around huge size_t) when rewriting has moved buffertail into
  // the margin past bufferend.
  size_t areafree() {
    return buffertail < bufferend ? size_t(bufferend - buffertail) : 0;
  }

  // Bytes remaining in the whole allocation after buffertail, margin
  // included.
  size_t margin() { return size_t(buffer + BUF_TOTAL - buffertail); }

  // Bytes currently buffered.
  size_t present() { return buffertail - bufferhead; }
};
195 
// Thread-pool sizing.  These values are multiplied by the number of listening
// ports, i.e. the number of servers actually running.  According to the
// thread pool implementation, threads are allocated on demand and kept in a
// linked list, so the initial count need not be a limit.  Two initial
// threads give one thread for accepting sockets plus one preallocated for the
// first connection, which is very likely to arrive.
constexpr uint32_t INITIAL_THREADS = 2;
constexpr uint32_t MAX_THREADS = 100;
// Per-thread stack size: 512 KiB.
constexpr uint32_t DEFAULT_STACKSIZE = 512 * 1024;
205 
// global data
// NOTE(review): nssconfigdir, servers, threads, any_host_spec_config and
// listen_public are not read in the visible part of this file; the notes on
// them below are inferred from their names -- confirm in main().

// presumably the NSS certificate database directory
string nssconfigdir;
// presumably one entry per listening server
vector<server_info_t> servers;
// Destination HandleConnection() connects to when not acting as an HTTP
// proxy.  Also the Host: rewrite target of AdjustWebSocketHost() unless
// websocket_server has a nonzero port.
PRNetAddr remote_addr;
// Host: rewrite target for WebSocket requests when its port is nonzero.
PRNetAddr websocket_server;
// presumably the pool that runs the per-connection jobs
PRThreadPool* threads = nullptr;
// Lock / condition variable pair notified by SignalShutdown().
PRLock* shutdown_lock = nullptr;
PRCondVar* shutdown_condvar = nullptr;
// Not really used, unless something fails to start
bool shutdown_server = false;
// When true, every client connection begins with a plain-text CONNECT
// request that HandleConnection() parses before setting up any SSL.
bool do_http_proxy = false;
bool any_host_spec_config = false;
bool listen_public = false;
219 
ClientAuthValueComparator(const void * v1,const void * v2)220 int ClientAuthValueComparator(const void* v1, const void* v2) {
221   int a = *static_cast<const client_auth_option*>(v1) -
222           *static_cast<const client_auth_option*>(v2);
223   if (a == 0) return 0;
224   if (a > 0) return 1;
225   // (a < 0)
226   return -1;
227 }
228 
match_hostname(PLHashEntry * he,int index,void * arg)229 static int match_hostname(PLHashEntry* he, int index, void* arg) {
230   server_match_t* match = (server_match_t*)arg;
231   if (match->fullHost.find((char*)he->key) != string::npos)
232     match->matched = true;
233   return HT_ENUMERATE_NEXT;
234 }
235 
/*
 * Signal the main thread that the application should shut down.
 *
 * Notifies shutdown_condvar while holding shutdown_lock.  No state flag is
 * set here, so a notification sent while nobody is waiting is not remembered.
 * NOTE(review): confirm the waiting side tolerates that.
 */
void SignalShutdown() {
  PR_Lock(shutdown_lock);
  PR_NotifyCondVar(shutdown_condvar);
  PR_Unlock(shutdown_lock);
}
244 
// Per-host connection flags: OR-ed together by ReadConnectRequest() and
// interpreted by ConfigureSSLServerSocket().  Each flag owns one bit.
enum {
  USE_SSL3 = 0x01,        // include SSL 3.0 in the enabled version range
  USE_RC4 = 0x02,         // enable only the RC4 cipher suites
  FAIL_HANDSHAKE = 0x04,  // deliberately break the handshake
  USE_TLS1 = 0x08,        // include TLS 1.0 in the enabled version range
  USE_TLS1_1 = 0x10,      // include TLS 1.1 in the enabled version range
  USE_TLS1_2 = 0x20,      // include TLS 1.2 in the enabled version range
  USE_TLS1_3 = 0x40       // include TLS 1.3 in the enabled version range
};
255 
ReadConnectRequest(server_info_t * server_info,relayBuffer & buffer,int32_t * result,string & certificate,client_auth_option * clientauth,string & host,string & location,int32_t * flags)256 bool ReadConnectRequest(server_info_t* server_info, relayBuffer& buffer,
257                         int32_t* result, string& certificate,
258                         client_auth_option* clientauth, string& host,
259                         string& location, int32_t* flags) {
260   if (buffer.present() < 4) {
261     LOG_DEBUG(
262         (" !! only %d bytes present in the buffer", (int)buffer.present()));
263     return false;
264   }
265   if (strncmp(buffer.buffertail - 4, "\r\n\r\n", 4)) {
266     LOG_ERRORD((" !! request is not tailed with CRLFCRLF but with %x %x %x %x",
267                 *(buffer.buffertail - 4), *(buffer.buffertail - 3),
268                 *(buffer.buffertail - 2), *(buffer.buffertail - 1)));
269     return false;
270   }
271 
272   LOG_DEBUG((" parsing initial connect request, dump:\n%.*s\n",
273              (int)buffer.present(), buffer.bufferhead));
274 
275   *result = 400;
276 
277   char* token;
278   char* _caret;
279   token = strtok2(buffer.bufferhead, " ", &_caret);
280   if (!token) {
281     LOG_ERRORD((" no space found"));
282     return true;
283   }
284   if (strcmp(token, "CONNECT")) {
285     LOG_ERRORD((" not CONNECT request but %s", token));
286     return true;
287   }
288 
289   token = strtok2(_caret, " ", &_caret);
290   void* c = PL_HashTableLookup(server_info->host_cert_table, token);
291   if (c) certificate = static_cast<char*>(c);
292 
293   host = "https://";
294   host += token;
295 
296   c = PL_HashTableLookup(server_info->host_clientauth_table, token);
297   if (c)
298     *clientauth = *static_cast<client_auth_option*>(c);
299   else
300     *clientauth = caNone;
301 
302   void* redir = PL_HashTableLookup(server_info->host_redir_table, token);
303   if (redir) location = static_cast<char*>(redir);
304 
305   if (PL_HashTableLookup(server_info->host_ssl3_table, token)) {
306     *flags |= USE_SSL3;
307   }
308 
309   if (PL_HashTableLookup(server_info->host_rc4_table, token)) {
310     *flags |= USE_RC4;
311   }
312 
313   if (PL_HashTableLookup(server_info->host_tls1_table, token)) {
314     *flags |= USE_TLS1;
315   }
316 
317   if (PL_HashTableLookup(server_info->host_tls11_table, token)) {
318     *flags |= USE_TLS1_1;
319   }
320 
321   if (PL_HashTableLookup(server_info->host_tls12_table, token)) {
322     *flags |= USE_TLS1_2;
323   }
324 
325   if (PL_HashTableLookup(server_info->host_tls13_table, token)) {
326     *flags |= USE_TLS1_3;
327   }
328 
329   if (PL_HashTableLookup(server_info->host_failhandshake_table, token)) {
330     *flags |= FAIL_HANDSHAKE;
331   }
332 
333   token = strtok2(_caret, "/", &_caret);
334   if (strcmp(token, "HTTP")) {
335     LOG_ERRORD((" not tailed with HTTP but with %s", token));
336     return true;
337   }
338 
339   *result = (redir) ? 302 : 200;
340   return true;
341 }
342 
ConfigureSSLServerSocket(PRFileDesc * socket,server_info_t * si,const string & certificate,const client_auth_option clientAuth,int32_t flags)343 bool ConfigureSSLServerSocket(PRFileDesc* socket, server_info_t* si,
344                               const string& certificate,
345                               const client_auth_option clientAuth,
346                               int32_t flags) {
347   const char* certnick =
348       certificate.empty() ? si->cert_nickname.c_str() : certificate.c_str();
349 
350   UniqueCERTCertificate cert(PK11_FindCertFromNickname(certnick, nullptr));
351   if (!cert) {
352     LOG_ERROR(("Failed to find cert %s\n", certnick));
353     return false;
354   }
355 
356   UniqueSECKEYPrivateKey privKey(PK11_FindKeyByAnyCert(cert.get(), nullptr));
357   if (!privKey) {
358     LOG_ERROR(("Failed to find private key\n"));
359     return false;
360   }
361 
362   PRFileDesc* ssl_socket = SSL_ImportFD(nullptr, socket);
363   if (!ssl_socket) {
364     LOG_ERROR(("Error importing SSL socket\n"));
365     return false;
366   }
367 
368   if (flags & FAIL_HANDSHAKE) {
369     // deliberately cause handshake to fail by sending the client a client hello
370     SSL_ResetHandshake(ssl_socket, false);
371     return true;
372   }
373 
374   SSLKEAType certKEA = NSS_FindCertKEAType(cert.get());
375   if (SSL_ConfigSecureServer(ssl_socket, cert.get(), privKey.get(), certKEA) !=
376       SECSuccess) {
377     LOG_ERROR(("Error configuring SSL server socket\n"));
378     return false;
379   }
380 
381   SSL_OptionSet(ssl_socket, SSL_SECURITY, true);
382   SSL_OptionSet(ssl_socket, SSL_HANDSHAKE_AS_CLIENT, false);
383   SSL_OptionSet(ssl_socket, SSL_HANDSHAKE_AS_SERVER, true);
384   SSL_OptionSet(ssl_socket, SSL_ENABLE_SESSION_TICKETS, true);
385 
386   if (clientAuth != caNone) {
387     // If we're requesting or requiring a client certificate, we should
388     // configure NSS to include the "certificate_authorities" field in the
389     // certificate request message. That way we can test that gecko properly
390     // takes note of it.
391     UniqueCERTCertificate issuer(
392         CERT_FindCertIssuer(cert.get(), PR_Now(), certUsageAnyCA));
393     if (!issuer) {
394       LOG_DEBUG(("Failed to find issuer for %s\n", certnick));
395       return false;
396     }
397     UniqueCERTCertList issuerList(CERT_NewCertList());
398     if (!issuerList) {
399       LOG_ERROR(("Failed to allocate new CERTCertList\n"));
400       return false;
401     }
402     if (CERT_AddCertToListTail(issuerList.get(), issuer.get()) != SECSuccess) {
403       LOG_ERROR(("Failed to add issuer to issuerList\n"));
404       return false;
405     }
406     Unused << issuer.release();  // Ownership transferred to issuerList.
407     if (SSL_SetTrustAnchors(ssl_socket, issuerList.get()) != SECSuccess) {
408       LOG_ERROR(
409           ("Failed to set certificate_authorities list for client "
410            "authentication\n"));
411       return false;
412     }
413     SSL_OptionSet(ssl_socket, SSL_REQUEST_CERTIFICATE, true);
414     SSL_OptionSet(ssl_socket, SSL_REQUIRE_CERTIFICATE, clientAuth == caRequire);
415   }
416 
417   SSLVersionRange range = {SSL_LIBRARY_VERSION_TLS_1_3,
418                            SSL_LIBRARY_VERSION_3_0};
419   if (flags & USE_SSL3) {
420     range.min = PR_MIN(range.min, SSL_LIBRARY_VERSION_3_0);
421     range.max = PR_MAX(range.max, SSL_LIBRARY_VERSION_3_0);
422   }
423   if (flags & USE_TLS1) {
424     range.min = PR_MIN(range.min, SSL_LIBRARY_VERSION_TLS_1_0);
425     range.max = PR_MAX(range.max, SSL_LIBRARY_VERSION_TLS_1_0);
426   }
427   if (flags & USE_TLS1_1) {
428     range.min = PR_MIN(range.min, SSL_LIBRARY_VERSION_TLS_1_1);
429     range.max = PR_MAX(range.max, SSL_LIBRARY_VERSION_TLS_1_1);
430   }
431   if (flags & USE_TLS1_2) {
432     range.min = PR_MIN(range.min, SSL_LIBRARY_VERSION_TLS_1_2);
433     range.max = PR_MAX(range.max, SSL_LIBRARY_VERSION_TLS_1_2);
434   }
435   if (flags & USE_TLS1_3) {
436     range.min = PR_MIN(range.min, SSL_LIBRARY_VERSION_TLS_1_3);
437     range.max = PR_MAX(range.max, SSL_LIBRARY_VERSION_TLS_1_3);
438   }
439   // Set the valid range, if any were specified (if not, skip
440   // when the default range is invalid, i.e. max > min)
441   if (range.min <= range.max &&
442       SSL_VersionRangeSet(ssl_socket, &range) != SECSuccess) {
443     LOG_ERROR(("Error configuring SSL socket version range\n"));
444     return false;
445   }
446 
447   if (flags & USE_RC4) {
448     for (uint16_t i = 0; i < SSL_NumImplementedCiphers; ++i) {
449       uint16_t cipher_id = SSL_ImplementedCiphers[i];
450       switch (cipher_id) {
451         case TLS_ECDHE_ECDSA_WITH_RC4_128_SHA:
452         case TLS_ECDHE_RSA_WITH_RC4_128_SHA:
453         case TLS_RSA_WITH_RC4_128_SHA:
454         case TLS_RSA_WITH_RC4_128_MD5:
455           SSL_CipherPrefSet(ssl_socket, cipher_id, true);
456           break;
457 
458         default:
459           SSL_CipherPrefSet(ssl_socket, cipher_id, false);
460           break;
461       }
462     }
463   }
464 
465   SSL_ResetHandshake(ssl_socket, true);
466 
467   return true;
468 }
469 
470 /**
471  * This function examines the buffer for a Sec-WebSocket-Location: field,
472  * and if it's present, it replaces the hostname in that field with the
473  * value in the server's original_host field.  This function works
474  * in the reverse direction as AdjustWebSocketHost(), replacing the real
475  * hostname of a response with the potentially fake hostname that is expected
476  * by the browser (e.g., mochi.test).
477  *
478  * @return true if the header was adjusted successfully, or not found, false
479  * if the header is present but the url is not, which should indicate
480  * that more data needs to be read from the socket
481  */
AdjustWebSocketLocation(relayBuffer & buffer,connection_info_t * ci)482 bool AdjustWebSocketLocation(relayBuffer& buffer, connection_info_t* ci) {
483   assert(buffer.margin());
484   buffer.buffertail[1] = '\0';
485 
486   char* wsloc = strstr(buffer.bufferhead, "Sec-WebSocket-Location:");
487   if (!wsloc) return true;
488   // advance pointer to the start of the hostname
489   wsloc = strstr(wsloc, "ws://");
490   if (!wsloc) return false;
491   wsloc += 5;
492   // find the end of the hostname
493   char* wslocend = strchr(wsloc + 1, '/');
494   if (!wslocend) return false;
495   char* crlf = strstr(wsloc, "\r\n");
496   if (!crlf) return false;
497   if (ci->original_host.empty()) return true;
498 
499   int diff = ci->original_host.length() - (wslocend - wsloc);
500   if (diff > 0) assert(size_t(diff) <= buffer.margin());
501   memmove(wslocend + diff, wslocend, buffer.buffertail - wsloc - diff);
502   buffer.buffertail += diff;
503 
504   memcpy(wsloc, ci->original_host.c_str(), ci->original_host.length());
505   return true;
506 }
507 
508 /**
509  * This function examines the buffer for a Host: field, and if it's present,
510  * it replaces the hostname in that field with the hostname in the server's
511  * remote_addr field.  This is needed because proxy requests may be coming
512  * from mochitest with fake hosts, like mochi.test, and these need to be
513  * replaced with the host that the destination server is actually running
514  * on.
515  */
AdjustWebSocketHost(relayBuffer & buffer,connection_info_t * ci)516 bool AdjustWebSocketHost(relayBuffer& buffer, connection_info_t* ci) {
517   const char HEADER_UPGRADE[] = "Upgrade:";
518   const char HEADER_HOST[] = "Host:";
519 
520   PRNetAddr inet_addr =
521       (websocket_server.inet.port ? websocket_server : remote_addr);
522 
523   assert(buffer.margin());
524 
525   // Cannot use strnchr so add a null char at the end. There is always some
526   // space left because we preserve a margin.
527   buffer.buffertail[1] = '\0';
528 
529   // Verify this is a WebSocket header.
530   char* h1 = strstr(buffer.bufferhead, HEADER_UPGRADE);
531   if (!h1) return false;
532   h1 += strlen(HEADER_UPGRADE);
533   h1 += strspn(h1, " \t");
534   char* h2 = strstr(h1, "WebSocket\r\n");
535   if (!h2) h2 = strstr(h1, "websocket\r\n");
536   if (!h2) h2 = strstr(h1, "Websocket\r\n");
537   if (!h2) return false;
538 
539   char* host = strstr(buffer.bufferhead, HEADER_HOST);
540   if (!host) return false;
541   // advance pointer to beginning of hostname
542   host += strlen(HEADER_HOST);
543   host += strspn(host, " \t");
544 
545   char* endhost = strstr(host, "\r\n");
546   if (!endhost) return false;
547 
548   // Save the original host, so we can use it later on responses from the
549   // server.
550   ci->original_host.assign(host, endhost - host);
551 
552   char newhost[40];
553   PR_NetAddrToString(&inet_addr, newhost, sizeof(newhost));
554   assert(strlen(newhost) < sizeof(newhost) - 7);
555   SprintfLiteral(newhost, "%s:%d", newhost, PR_ntohs(inet_addr.inet.port));
556 
557   int diff = strlen(newhost) - (endhost - host);
558   if (diff > 0) assert(size_t(diff) <= buffer.margin());
559   memmove(endhost + diff, endhost, buffer.buffertail - host - diff);
560   buffer.buffertail += diff;
561 
562   memcpy(host, newhost, strlen(newhost));
563   return true;
564 }
565 
566 /**
567  * This function prefixes Request-URI path with a full scheme-host-port
568  * string.
569  */
AdjustRequestURI(relayBuffer & buffer,string * host)570 bool AdjustRequestURI(relayBuffer& buffer, string* host) {
571   assert(buffer.margin());
572 
573   // Cannot use strnchr so add a null char at the end. There is always some
574   // space left because we preserve a margin.
575   buffer.buffertail[1] = '\0';
576   LOG_DEBUG((" incoming request to adjust:\n%s\n", buffer.bufferhead));
577 
578   char *token, *path;
579   path = strchr(buffer.bufferhead, ' ') + 1;
580   if (!path) return false;
581 
582   // If the path doesn't start with a slash don't change it, it is probably '*'
583   // or a full path already. Return true, we are done with this request
584   // adjustment.
585   if (*path != '/') return true;
586 
587   token = strchr(path, ' ') + 1;
588   if (!token) return false;
589 
590   if (strncmp(token, "HTTP/", 5)) return false;
591 
592   size_t hostlength = host->length();
593   assert(hostlength <= buffer.margin());
594 
595   memmove(path + hostlength, path, buffer.buffertail - path);
596   memcpy(path, host->c_str(), hostlength);
597   buffer.buffertail += hostlength;
598 
599   return true;
600 }
601 
ConnectSocket(UniquePRFileDesc & fd,const PRNetAddr * addr,PRIntervalTime timeout)602 bool ConnectSocket(UniquePRFileDesc& fd, const PRNetAddr* addr,
603                    PRIntervalTime timeout) {
604   PRStatus stat = PR_Connect(fd.get(), addr, timeout);
605   if (stat != PR_SUCCESS) return false;
606 
607   PRSocketOptionData option;
608   option.option = PR_SockOpt_Nonblocking;
609   option.value.non_blocking = true;
610   PR_SetSocketOption(fd.get(), &option);
611 
612   return true;
613 }
614 
615 /*
616  * Handle an incoming client connection. The server thread has already
617  * accepted the connection, so we just need to connect to the remote
618  * port and then proxy data back and forth.
619  * The data parameter is a connection_info_t*, and must be deleted
620  * by this function.
621  */
HandleConnection(void * data)622 void HandleConnection(void* data) {
623   connection_info_t* ci = static_cast<connection_info_t*>(data);
624   PRIntervalTime connect_timeout = PR_SecondsToInterval(30);
625 
626   UniquePRFileDesc other_sock(PR_NewTCPSocket());
627   bool client_done = false;
628   bool client_error = false;
629   bool connect_accepted = !do_http_proxy;
630   bool ssl_updated = !do_http_proxy;
631   bool expect_request_start = do_http_proxy;
632   string certificateToUse;
633   string locationHeader;
634   client_auth_option clientAuth;
635   string fullHost;
636   int32_t flags = 0;
637 
638   LOG_DEBUG(("SSLTUNNEL(%p)): incoming connection csock(0)=%p, ssock(1)=%p\n",
639              static_cast<void*>(data), static_cast<void*>(ci->client_sock),
640              static_cast<void*>(other_sock.get())));
641   if (other_sock) {
642     int32_t numberOfSockets = 1;
643 
644     relayBuffer buffers[2];
645 
646     if (!do_http_proxy) {
647       if (!ConfigureSSLServerSocket(ci->client_sock, ci->server_info,
648                                     certificateToUse, caNone, flags))
649         client_error = true;
650       else if (!ConnectSocket(other_sock, &remote_addr, connect_timeout))
651         client_error = true;
652       else
653         numberOfSockets = 2;
654     }
655 
656     PRPollDesc sockets[2] = {{ci->client_sock, PR_POLL_READ, 0},
657                              {other_sock.get(), PR_POLL_READ, 0}};
658     bool socketErrorState[2] = {false, false};
659 
660     while (!((client_error || client_done) && buffers[0].empty() &&
661              buffers[1].empty())) {
662       sockets[0].in_flags |= PR_POLL_EXCEPT;
663       sockets[1].in_flags |= PR_POLL_EXCEPT;
664       LOG_DEBUG(("SSLTUNNEL(%p)): polling flags csock(0)=%c%c, ssock(1)=%c%c\n",
665                  static_cast<void*>(data),
666                  sockets[0].in_flags & PR_POLL_READ ? 'R' : '-',
667                  sockets[0].in_flags & PR_POLL_WRITE ? 'W' : '-',
668                  sockets[1].in_flags & PR_POLL_READ ? 'R' : '-',
669                  sockets[1].in_flags & PR_POLL_WRITE ? 'W' : '-'));
670       int32_t pollStatus =
671           PR_Poll(sockets, numberOfSockets, PR_MillisecondsToInterval(1000));
672       if (pollStatus < 0) {
673         LOG_DEBUG(("SSLTUNNEL(%p)): pollStatus=%d, exiting\n",
674                    static_cast<void*>(data), pollStatus));
675         client_error = true;
676         break;
677       }
678 
679       if (pollStatus == 0) {
680         // timeout
681         LOG_DEBUG(("SSLTUNNEL(%p)): poll timeout, looping\n",
682                    static_cast<void*>(data)));
683         continue;
684       }
685 
686       for (int32_t s = 0; s < numberOfSockets; ++s) {
687         int32_t s2 = s == 1 ? 0 : 1;
688         int16_t out_flags = sockets[s].out_flags;
689         int16_t& in_flags = sockets[s].in_flags;
690         int16_t& in_flags2 = sockets[s2].in_flags;
691         sockets[s].out_flags = 0;
692 
693         LOG_BEGIN_BLOCK();
694         LOG_DEBUG(("SSLTUNNEL(%p)): %csock(%d)=%p out_flags=%d",
695                    static_cast<void*>(data), s == 0 ? 'c' : 's', s,
696                    static_cast<void*>(sockets[s].fd), out_flags));
697         if (out_flags & (PR_POLL_EXCEPT | PR_POLL_ERR | PR_POLL_HUP)) {
698           LOG_DEBUG((" :exception\n"));
699           client_error = true;
700           socketErrorState[s] = true;
701           // We got a fatal error state on the socket. Clear the output buffer
702           // for this socket to break the main loop, we will never more be able
703           // to send those data anyway.
704           buffers[s2].bufferhead = buffers[s2].buffertail = buffers[s2].buffer;
705           continue;
706         }  // PR_POLL_EXCEPT, PR_POLL_ERR, PR_POLL_HUP handling
707 
708         if (out_flags & PR_POLL_READ && !buffers[s].areafree()) {
709           LOG_DEBUG(
710               (" no place in read buffer but got read flag, dropping it now!"));
711           in_flags &= ~PR_POLL_READ;
712         }
713 
714         if (out_flags & PR_POLL_READ && buffers[s].areafree()) {
715           LOG_DEBUG((" :reading"));
716           int32_t bytesRead =
717               PR_Recv(sockets[s].fd, buffers[s].buffertail,
718                       buffers[s].areafree(), 0, PR_INTERVAL_NO_TIMEOUT);
719 
720           if (bytesRead == 0) {
721             LOG_DEBUG((" socket gracefully closed"));
722             client_done = true;
723             in_flags &= ~PR_POLL_READ;
724           } else if (bytesRead < 0) {
725             if (PR_GetError() != PR_WOULD_BLOCK_ERROR) {
726               LOG_DEBUG((" error=%d", PR_GetError()));
727               // We are in error state, indicate that the connection was
728               // not closed gracefully
729               client_error = true;
730               socketErrorState[s] = true;
731               // Wipe out our send buffer, we cannot send it anyway.
732               buffers[s2].bufferhead = buffers[s2].buffertail =
733                   buffers[s2].buffer;
734             } else
735               LOG_DEBUG((" would block"));
736           } else {
737             // If the other socket is in error state (unable to send/receive)
738             // throw this data away and continue loop
739             if (socketErrorState[s2]) {
740               LOG_DEBUG((" have read but other socket is in error state\n"));
741               continue;
742             }
743 
744             buffers[s].buffertail += bytesRead;
745             LOG_DEBUG((", read %d bytes", bytesRead));
746 
747             // We have to accept and handle the initial CONNECT request here
748             int32_t response;
749             if (!connect_accepted &&
750                 ReadConnectRequest(ci->server_info, buffers[s], &response,
751                                    certificateToUse, &clientAuth, fullHost,
752                                    locationHeader, &flags)) {
753               // Mark this as a proxy-only connection (no SSL) if the CONNECT
754               // request didn't come for port 443 or from any of the server's
755               // cert or clientauth hostnames.
756               if (fullHost.find(":443") == string::npos) {
757                 server_match_t match;
758                 match.fullHost = fullHost;
759                 match.matched = false;
760                 PL_HashTableEnumerateEntries(ci->server_info->host_cert_table,
761                                              match_hostname, &match);
762                 PL_HashTableEnumerateEntries(
763                     ci->server_info->host_clientauth_table, match_hostname,
764                     &match);
765                 PL_HashTableEnumerateEntries(ci->server_info->host_ssl3_table,
766                                              match_hostname, &match);
767                 PL_HashTableEnumerateEntries(ci->server_info->host_tls1_table,
768                                              match_hostname, &match);
769                 PL_HashTableEnumerateEntries(ci->server_info->host_tls11_table,
770                                              match_hostname, &match);
771                 PL_HashTableEnumerateEntries(ci->server_info->host_tls12_table,
772                                              match_hostname, &match);
773                 PL_HashTableEnumerateEntries(ci->server_info->host_tls13_table,
774                                              match_hostname, &match);
775                 PL_HashTableEnumerateEntries(ci->server_info->host_rc4_table,
776                                              match_hostname, &match);
777                 PL_HashTableEnumerateEntries(
778                     ci->server_info->host_failhandshake_table, match_hostname,
779                     &match);
780                 ci->http_proxy_only = !match.matched;
781               } else {
782                 ci->http_proxy_only = false;
783               }
784 
785               // Clean the request as it would be read
786               buffers[s].bufferhead = buffers[s].buffertail = buffers[s].buffer;
787               in_flags |= PR_POLL_WRITE;
788               connect_accepted = true;
789 
790               // Store response to the oposite buffer
791               if (response == 200) {
792                 LOG_DEBUG(
793                     (" accepted CONNECT request, connected to the server, "
794                      "sending OK to the client\n"));
795                 strcpy(
796                     buffers[s2].buffer,
797                     "HTTP/1.1 200 Connected\r\nConnection: keep-alive\r\n\r\n");
798               } else if (response == 302) {
799                 LOG_DEBUG(
800                     (" accepted CONNECT request with redirection, "
801                      "sending location and 302 to the client\n"));
802                 client_done = true;
803                 snprintf(buffers[s2].buffer,
804                          buffers[s2].bufferend - buffers[s2].buffer,
805                          "HTTP/1.1 302 Moved\r\n"
806                          "Location: https://%s/\r\n"
807                          "Connection: close\r\n\r\n",
808                          locationHeader.c_str());
809               } else {
810                 LOG_ERRORD(
811                     (" could not read the connect request, closing connection "
812                      "with %d",
813                      response));
814                 client_done = true;
815                 snprintf(buffers[s2].buffer,
816                          buffers[s2].bufferend - buffers[s2].buffer,
817                          "HTTP/1.1 %d ERROR\r\nConnection: close\r\n\r\n",
818                          response);
819 
820                 break;
821               }
822 
823               buffers[s2].buffertail =
824                   buffers[s2].buffer + strlen(buffers[s2].buffer);
825 
826               // Send the response to the client socket
827               break;
828             }  // end of CONNECT handling
829 
830             if (!buffers[s].areafree()) {
831               // Do not poll for read when the buffer is full
832               LOG_DEBUG((" no place in our read buffer, stop reading"));
833               in_flags &= ~PR_POLL_READ;
834             }
835 
836             if (ssl_updated) {
837               if (s == 0 && expect_request_start) {
838                 if (!strstr(buffers[s].bufferhead, "\r\n\r\n")) {
839                   // We haven't received the complete header yet, so wait.
840                   continue;
841                 }
842                 ci->iswebsocket = AdjustWebSocketHost(buffers[s], ci);
843                 expect_request_start = !(
844                     ci->iswebsocket || AdjustRequestURI(buffers[s], &fullHost));
845                 PRNetAddr* addr = &remote_addr;
846                 if (ci->iswebsocket && websocket_server.inet.port)
847                   addr = &websocket_server;
848                 if (!ConnectSocket(other_sock, addr, connect_timeout)) {
849                   LOG_ERRORD(
850                       (" could not open connection to the real server\n"));
851                   client_error = true;
852                   break;
853                 }
854                 LOG_DEBUG(("\n connected to remote server\n"));
855                 numberOfSockets = 2;
856               } else if (s == 1 && ci->iswebsocket) {
857                 if (!AdjustWebSocketLocation(buffers[s], ci)) continue;
858               }
859 
860               in_flags2 |= PR_POLL_WRITE;
861               LOG_DEBUG((" telling the other socket to write"));
862             } else
863               LOG_DEBUG(
864                   (" we have something for the other socket to write, but ssl "
865                    "has not been administered on it"));
866           }
867         }  // PR_POLL_READ handling
868 
869         if (out_flags & PR_POLL_WRITE) {
870           LOG_DEBUG((" :writing"));
871           int32_t bytesWrite =
872               PR_Send(sockets[s].fd, buffers[s2].bufferhead,
873                       buffers[s2].present(), 0, PR_INTERVAL_NO_TIMEOUT);
874 
875           if (bytesWrite < 0) {
876             if (PR_GetError() != PR_WOULD_BLOCK_ERROR) {
877               LOG_DEBUG((" error=%d", PR_GetError()));
878               client_error = true;
879               socketErrorState[s] = true;
880               // We got a fatal error while writting the buffer. Clear it to
881               // break the main loop, we will never more be able to send it.
882               buffers[s2].bufferhead = buffers[s2].buffertail =
883                   buffers[s2].buffer;
884             } else
885               LOG_DEBUG((" would block"));
886           } else {
887             LOG_DEBUG((", written %d bytes", bytesWrite));
888             buffers[s2].buffertail[1] = '\0';
889             LOG_DEBUG((" dump:\n%.*s\n", bytesWrite, buffers[s2].bufferhead));
890 
891             buffers[s2].bufferhead += bytesWrite;
892             if (buffers[s2].present()) {
893               LOG_DEBUG((" still have to write %d bytes",
894                          (int)buffers[s2].present()));
895               in_flags |= PR_POLL_WRITE;
896             } else {
897               if (!ssl_updated) {
898                 LOG_DEBUG((" proxy response sent to the client"));
899                 // Proxy response has just been writen, update to ssl
900                 ssl_updated = true;
901                 if (ci->http_proxy_only) {
902                   LOG_DEBUG(
903                       (" not updating to SSL based on http_proxy_only for this "
904                        "socket"));
905                 } else if (!ConfigureSSLServerSocket(
906                                ci->client_sock, ci->server_info,
907                                certificateToUse, clientAuth, flags)) {
908                   LOG_ERRORD((" failed to config server socket\n"));
909                   client_error = true;
910                   break;
911                 } else {
912                   LOG_DEBUG((" client socket updated to SSL"));
913                 }
914               }  // sslUpdate
915 
916               LOG_DEBUG(
917                   (" dropping our write flag and setting other socket read "
918                    "flag"));
919               in_flags &= ~PR_POLL_WRITE;
920               in_flags2 |= PR_POLL_READ;
921               buffers[s2].compact();
922             }
923           }
924         }                 // PR_POLL_WRITE handling
925         LOG_END_BLOCK();  // end the log
926       }                   // for...
927     }                     // while, poll
928   } else
929     client_error = true;
930 
931   LOG_DEBUG(("SSLTUNNEL(%p)): exiting root function for csock=%p, ssock=%p\n",
932              static_cast<void*>(data), static_cast<void*>(ci->client_sock),
933              static_cast<void*>(other_sock.get())));
934   if (!client_error) PR_Shutdown(ci->client_sock, PR_SHUTDOWN_SEND);
935   PR_Close(ci->client_sock);
936 
937   delete ci;
938 }
939 
940 /*
941  * Start listening for SSL connections on a specified port, handing
942  * them off to client threads after accepting the connection.
943  * The data parameter is a server_info_t*, owned by the calling
944  * function.
945  */
StartServer(void * data)946 void StartServer(void* data) {
947   server_info_t* si = static_cast<server_info_t*>(data);
948 
949   // TODO: select ciphers?
950   UniquePRFileDesc listen_socket(PR_NewTCPSocket());
951   if (!listen_socket) {
952     LOG_ERROR(("failed to create socket\n"));
953     SignalShutdown();
954     return;
955   }
956 
957   // In case the socket is still open in the TIME_WAIT state from a previous
958   // instance of ssltunnel we ask to reuse the port.
959   PRSocketOptionData socket_option;
960   socket_option.option = PR_SockOpt_Reuseaddr;
961   socket_option.value.reuse_addr = true;
962   PR_SetSocketOption(listen_socket.get(), &socket_option);
963 
964   PRNetAddr server_addr;
965   PRNetAddrValue listen_addr;
966   if (listen_public) {
967     listen_addr = PR_IpAddrAny;
968   } else {
969     listen_addr = PR_IpAddrLoopback;
970   }
971   PR_InitializeNetAddr(listen_addr, si->listen_port, &server_addr);
972 
973   if (PR_Bind(listen_socket.get(), &server_addr) != PR_SUCCESS) {
974     LOG_ERROR(("failed to bind socket on port %d: error %d\n", si->listen_port,
975                PR_GetError()));
976     SignalShutdown();
977     return;
978   }
979 
980   if (PR_Listen(listen_socket.get(), 1) != PR_SUCCESS) {
981     LOG_ERROR(("failed to listen on socket\n"));
982     SignalShutdown();
983     return;
984   }
985 
986   LOG_INFO(("Server listening on port %d with cert %s\n", si->listen_port,
987             si->cert_nickname.c_str()));
988 
989   while (!shutdown_server) {
990     connection_info_t* ci = new connection_info_t();
991     ci->server_info = si;
992     ci->http_proxy_only = do_http_proxy;
993     // block waiting for connections
994     ci->client_sock = PR_Accept(listen_socket.get(), &ci->client_addr,
995                                 PR_INTERVAL_NO_TIMEOUT);
996 
997     PRSocketOptionData option;
998     option.option = PR_SockOpt_Nonblocking;
999     option.value.non_blocking = true;
1000     PR_SetSocketOption(ci->client_sock, &option);
1001 
1002     if (ci->client_sock)
1003       // Not actually using this PRJob*...
1004       // PRJob* job =
1005       PR_QueueJob(threads, HandleConnection, ci, true);
1006     else
1007       delete ci;
1008   }
1009 }
1010 
1011 // bogus password func, just don't use passwords. :-P
password_func(PK11SlotInfo * slot,PRBool retry,void * arg)1012 char* password_func(PK11SlotInfo* slot, PRBool retry, void* arg) {
1013   if (retry) return nullptr;
1014 
1015   return PL_strdup("");
1016 }
1017 
findServerInfo(int portnumber)1018 server_info_t* findServerInfo(int portnumber) {
1019   for (auto& server : servers) {
1020     if (server.listen_port == portnumber) return &server;
1021   }
1022 
1023   return nullptr;
1024 }
1025 
get_ssl3_table(server_info_t * server)1026 PLHashTable* get_ssl3_table(server_info_t* server) {
1027   return server->host_ssl3_table;
1028 }
1029 
get_tls1_table(server_info_t * server)1030 PLHashTable* get_tls1_table(server_info_t* server) {
1031   return server->host_tls1_table;
1032 }
1033 
get_tls11_table(server_info_t * server)1034 PLHashTable* get_tls11_table(server_info_t* server) {
1035   return server->host_tls11_table;
1036 }
1037 
get_tls12_table(server_info_t * server)1038 PLHashTable* get_tls12_table(server_info_t* server) {
1039   return server->host_tls12_table;
1040 }
1041 
get_tls13_table(server_info_t * server)1042 PLHashTable* get_tls13_table(server_info_t* server) {
1043   return server->host_tls13_table;
1044 }
1045 
get_rc4_table(server_info_t * server)1046 PLHashTable* get_rc4_table(server_info_t* server) {
1047   return server->host_rc4_table;
1048 }
1049 
get_failhandshake_table(server_info_t * server)1050 PLHashTable* get_failhandshake_table(server_info_t* server) {
1051   return server->host_failhandshake_table;
1052 }
1053 
parseWeakCryptoConfig(char * const & keyword,char * & _caret,PLHashTable * (* get_table)(server_info_t *))1054 int parseWeakCryptoConfig(char* const& keyword, char*& _caret,
1055                           PLHashTable* (*get_table)(server_info_t*)) {
1056   char* hostname = strtok2(_caret, ":", &_caret);
1057   char* hostportstring = strtok2(_caret, ":", &_caret);
1058   char* serverportstring = strtok2(_caret, "\n", &_caret);
1059 
1060   int port = atoi(serverportstring);
1061   if (port <= 0) {
1062     LOG_ERROR(("Invalid port specified: %s\n", serverportstring));
1063     return 1;
1064   }
1065 
1066   if (server_info_t* existingServer = findServerInfo(port)) {
1067     any_host_spec_config = true;
1068 
1069     char* hostname_copy =
1070         new char[strlen(hostname) + strlen(hostportstring) + 2];
1071     if (!hostname_copy) {
1072       LOG_ERROR(("Out of memory"));
1073       return 1;
1074     }
1075 
1076     strcpy(hostname_copy, hostname);
1077     strcat(hostname_copy, ":");
1078     strcat(hostname_copy, hostportstring);
1079 
1080     PLHashEntry* entry =
1081         PL_HashTableAdd(get_table(existingServer), hostname_copy, keyword);
1082     if (!entry) {
1083       LOG_ERROR(("Out of memory"));
1084       return 1;
1085     }
1086   } else {
1087     LOG_ERROR(
1088         ("Server on port %d for redirhost option is not defined, use 'listen' "
1089          "option first",
1090          port));
1091     return 1;
1092   }
1093 
1094   return 0;
1095 }
1096 
processConfigLine(char * configLine)1097 int processConfigLine(char* configLine) {
1098   if (*configLine == 0 || *configLine == '#') return 0;
1099 
1100   char* _caret;
1101   char* keyword = strtok2(configLine, ":", &_caret);
1102   // Configure usage of http/ssl tunneling proxy behavior
1103   if (!strcmp(keyword, "httpproxy")) {
1104     char* value = strtok2(_caret, ":", &_caret);
1105     if (!strcmp(value, "1")) do_http_proxy = true;
1106 
1107     return 0;
1108   }
1109 
1110   if (!strcmp(keyword, "websocketserver")) {
1111     char* ipstring = strtok2(_caret, ":", &_caret);
1112     if (PR_StringToNetAddr(ipstring, &websocket_server) != PR_SUCCESS) {
1113       LOG_ERROR(("Invalid IP address in proxy config: %s\n", ipstring));
1114       return 1;
1115     }
1116     char* remoteport = strtok2(_caret, ":", &_caret);
1117     int port = atoi(remoteport);
1118     if (port <= 0) {
1119       LOG_ERROR(("Invalid remote port in proxy config: %s\n", remoteport));
1120       return 1;
1121     }
1122     websocket_server.inet.port = PR_htons(port);
1123     return 0;
1124   }
1125 
1126   // Configure the forward address of the target server
1127   if (!strcmp(keyword, "forward")) {
1128     char* ipstring = strtok2(_caret, ":", &_caret);
1129     if (PR_StringToNetAddr(ipstring, &remote_addr) != PR_SUCCESS) {
1130       LOG_ERROR(("Invalid remote IP address: %s\n", ipstring));
1131       return 1;
1132     }
1133     char* serverportstring = strtok2(_caret, ":", &_caret);
1134     int port = atoi(serverportstring);
1135     if (port <= 0) {
1136       LOG_ERROR(("Invalid remote port: %s\n", serverportstring));
1137       return 1;
1138     }
1139     remote_addr.inet.port = PR_htons(port);
1140 
1141     return 0;
1142   }
1143 
1144   // Configure all listen sockets and port+certificate bindings.
1145   // Listen on the public address if "*" was specified as the listen
1146   // address or listen on the loopback address if "127.0.0.1" was
1147   // specified. Using loopback will prevent users getting errors from
1148   // their firewalls about ssltunnel needing permission. A public
1149   // address is required when proxying ssl traffic from a physical or
1150   // emulated Android device since it has a different ip address from
1151   // the host.
1152   if (!strcmp(keyword, "listen")) {
1153     char* hostname = strtok2(_caret, ":", &_caret);
1154     char* hostportstring = nullptr;
1155     if (!strcmp(hostname, "*")) {
1156       listen_public = true;
1157     } else if (strcmp(hostname, "127.0.0.1")) {
1158       any_host_spec_config = true;
1159       hostportstring = strtok2(_caret, ":", &_caret);
1160     }
1161 
1162     char* serverportstring = strtok2(_caret, ":", &_caret);
1163     char* certnick = strtok2(_caret, ":", &_caret);
1164 
1165     int port = atoi(serverportstring);
1166     if (port <= 0) {
1167       LOG_ERROR(("Invalid port specified: %s\n", serverportstring));
1168       return 1;
1169     }
1170 
1171     if (server_info_t* existingServer = findServerInfo(port)) {
1172       if (!hostportstring) {
1173         LOG_ERROR(
1174             ("Null hostportstring specified for hostname %s\n", hostname));
1175         return 1;
1176       }
1177       char* certnick_copy = new char[strlen(certnick) + 1];
1178       char* hostname_copy =
1179           new char[strlen(hostname) + strlen(hostportstring) + 2];
1180 
1181       strcpy(hostname_copy, hostname);
1182       strcat(hostname_copy, ":");
1183       strcat(hostname_copy, hostportstring);
1184       strcpy(certnick_copy, certnick);
1185 
1186       PLHashEntry* entry = PL_HashTableAdd(existingServer->host_cert_table,
1187                                            hostname_copy, certnick_copy);
1188       if (!entry) {
1189         LOG_ERROR(("Out of memory"));
1190         return 1;
1191       }
1192     } else {
1193       server_info_t server;
1194       server.cert_nickname = certnick;
1195       server.listen_port = port;
1196       server.host_cert_table =
1197           PL_NewHashTable(0, PL_HashString, PL_CompareStrings,
1198                           PL_CompareStrings, nullptr, nullptr);
1199       if (!server.host_cert_table) {
1200         LOG_ERROR(("Internal, could not create hash table\n"));
1201         return 1;
1202       }
1203       server.host_clientauth_table =
1204           PL_NewHashTable(0, PL_HashString, PL_CompareStrings,
1205                           ClientAuthValueComparator, nullptr, nullptr);
1206       if (!server.host_clientauth_table) {
1207         LOG_ERROR(("Internal, could not create hash table\n"));
1208         return 1;
1209       }
1210       server.host_redir_table =
1211           PL_NewHashTable(0, PL_HashString, PL_CompareStrings,
1212                           PL_CompareStrings, nullptr, nullptr);
1213       if (!server.host_redir_table) {
1214         LOG_ERROR(("Internal, could not create hash table\n"));
1215         return 1;
1216       }
1217 
1218       server.host_ssl3_table =
1219           PL_NewHashTable(0, PL_HashString, PL_CompareStrings,
1220                           PL_CompareStrings, nullptr, nullptr);
1221 
1222       if (!server.host_ssl3_table) {
1223         LOG_ERROR(("Internal, could not create hash table\n"));
1224         return 1;
1225       }
1226 
1227       server.host_tls1_table =
1228           PL_NewHashTable(0, PL_HashString, PL_CompareStrings,
1229                           PL_CompareStrings, nullptr, nullptr);
1230 
1231       if (!server.host_tls1_table) {
1232         LOG_ERROR(("Internal, could not create hash table\n"));
1233         return 1;
1234       }
1235 
1236       server.host_tls11_table =
1237           PL_NewHashTable(0, PL_HashString, PL_CompareStrings,
1238                           PL_CompareStrings, nullptr, nullptr);
1239 
1240       if (!server.host_tls11_table) {
1241         LOG_ERROR(("Internal, could not create hash table\n"));
1242         return 1;
1243       }
1244 
1245       server.host_tls12_table =
1246           PL_NewHashTable(0, PL_HashString, PL_CompareStrings,
1247                           PL_CompareStrings, nullptr, nullptr);
1248 
1249       if (!server.host_tls12_table) {
1250         LOG_ERROR(("Internal, could not create hash table\n"));
1251         return 1;
1252       }
1253 
1254       server.host_tls13_table =
1255           PL_NewHashTable(0, PL_HashString, PL_CompareStrings,
1256                           PL_CompareStrings, nullptr, nullptr);
1257 
1258       if (!server.host_tls13_table) {
1259         LOG_ERROR(("Internal, could not create hash table\n"));
1260         return 1;
1261       }
1262 
1263       server.host_rc4_table =
1264           PL_NewHashTable(0, PL_HashString, PL_CompareStrings,
1265                           PL_CompareStrings, nullptr, nullptr);
1266       ;
1267       if (!server.host_rc4_table) {
1268         LOG_ERROR(("Internal, could not create hash table\n"));
1269         return 1;
1270       }
1271 
1272       server.host_failhandshake_table =
1273           PL_NewHashTable(0, PL_HashString, PL_CompareStrings,
1274                           PL_CompareStrings, nullptr, nullptr);
1275       ;
1276       if (!server.host_failhandshake_table) {
1277         LOG_ERROR(("Internal, could not create hash table\n"));
1278         return 1;
1279       }
1280 
1281       servers.push_back(server);
1282     }
1283 
1284     return 0;
1285   }
1286 
1287   if (!strcmp(keyword, "clientauth")) {
1288     char* hostname = strtok2(_caret, ":", &_caret);
1289     char* hostportstring = strtok2(_caret, ":", &_caret);
1290     char* serverportstring = strtok2(_caret, ":", &_caret);
1291 
1292     int port = atoi(serverportstring);
1293     if (port <= 0) {
1294       LOG_ERROR(("Invalid port specified: %s\n", serverportstring));
1295       return 1;
1296     }
1297 
1298     if (server_info_t* existingServer = findServerInfo(port)) {
1299       char* authoptionstring = strtok2(_caret, ":", &_caret);
1300       client_auth_option* authoption = new client_auth_option;
1301       if (!authoption) {
1302         LOG_ERROR(("Out of memory"));
1303         return 1;
1304       }
1305 
1306       if (!strcmp(authoptionstring, "require"))
1307         *authoption = caRequire;
1308       else if (!strcmp(authoptionstring, "request"))
1309         *authoption = caRequest;
1310       else if (!strcmp(authoptionstring, "none"))
1311         *authoption = caNone;
1312       else {
1313         LOG_ERROR(
1314             ("Incorrect client auth option modifier for host '%s'", hostname));
1315         delete authoption;
1316         return 1;
1317       }
1318 
1319       any_host_spec_config = true;
1320 
1321       char* hostname_copy =
1322           new char[strlen(hostname) + strlen(hostportstring) + 2];
1323       if (!hostname_copy) {
1324         LOG_ERROR(("Out of memory"));
1325         delete authoption;
1326         return 1;
1327       }
1328 
1329       strcpy(hostname_copy, hostname);
1330       strcat(hostname_copy, ":");
1331       strcat(hostname_copy, hostportstring);
1332 
1333       PLHashEntry* entry = PL_HashTableAdd(
1334           existingServer->host_clientauth_table, hostname_copy, authoption);
1335       if (!entry) {
1336         LOG_ERROR(("Out of memory"));
1337         delete authoption;
1338         return 1;
1339       }
1340     } else {
1341       LOG_ERROR(
1342           ("Server on port %d for client authentication option is not defined, "
1343            "use 'listen' option first",
1344            port));
1345       return 1;
1346     }
1347 
1348     return 0;
1349   }
1350 
1351   if (!strcmp(keyword, "redirhost")) {
1352     char* hostname = strtok2(_caret, ":", &_caret);
1353     char* hostportstring = strtok2(_caret, ":", &_caret);
1354     char* serverportstring = strtok2(_caret, ":", &_caret);
1355 
1356     int port = atoi(serverportstring);
1357     if (port <= 0) {
1358       LOG_ERROR(("Invalid port specified: %s\n", serverportstring));
1359       return 1;
1360     }
1361 
1362     if (server_info_t* existingServer = findServerInfo(port)) {
1363       char* redirhoststring = strtok2(_caret, ":", &_caret);
1364 
1365       any_host_spec_config = true;
1366 
1367       char* hostname_copy =
1368           new char[strlen(hostname) + strlen(hostportstring) + 2];
1369       if (!hostname_copy) {
1370         LOG_ERROR(("Out of memory"));
1371         return 1;
1372       }
1373 
1374       strcpy(hostname_copy, hostname);
1375       strcat(hostname_copy, ":");
1376       strcat(hostname_copy, hostportstring);
1377 
1378       char* redir_copy = new char[strlen(redirhoststring) + 1];
1379       strcpy(redir_copy, redirhoststring);
1380       PLHashEntry* entry = PL_HashTableAdd(existingServer->host_redir_table,
1381                                            hostname_copy, redir_copy);
1382       if (!entry) {
1383         LOG_ERROR(("Out of memory"));
1384         delete[] hostname_copy;
1385         delete[] redir_copy;
1386         return 1;
1387       }
1388     } else {
1389       LOG_ERROR(
1390           ("Server on port %d for redirhost option is not defined, use "
1391            "'listen' option first",
1392            port));
1393       return 1;
1394     }
1395 
1396     return 0;
1397   }
1398 
1399   if (!strcmp(keyword, "ssl3")) {
1400     return parseWeakCryptoConfig(keyword, _caret, get_ssl3_table);
1401   }
1402   if (!strcmp(keyword, "tls1")) {
1403     return parseWeakCryptoConfig(keyword, _caret, get_tls1_table);
1404   }
1405   if (!strcmp(keyword, "tls1_1")) {
1406     return parseWeakCryptoConfig(keyword, _caret, get_tls11_table);
1407   }
1408   if (!strcmp(keyword, "tls1_2")) {
1409     return parseWeakCryptoConfig(keyword, _caret, get_tls12_table);
1410   }
1411   if (!strcmp(keyword, "tls1_3")) {
1412     return parseWeakCryptoConfig(keyword, _caret, get_tls13_table);
1413   }
1414 
1415   if (!strcmp(keyword, "rc4")) {
1416     return parseWeakCryptoConfig(keyword, _caret, get_rc4_table);
1417   }
1418 
1419   if (!strcmp(keyword, "failHandshake")) {
1420     return parseWeakCryptoConfig(keyword, _caret, get_failhandshake_table);
1421   }
1422 
1423   // Configure the NSS certificate database directory
1424   if (!strcmp(keyword, "certdbdir")) {
1425     nssconfigdir = strtok2(_caret, "\n", &_caret);
1426     return 0;
1427   }
1428 
1429   LOG_ERROR(("Error: keyword \"%s\" unexpected\n", keyword));
1430   return 1;
1431 }
1432 
parseConfigFile(const char * filePath)1433 int parseConfigFile(const char* filePath) {
1434   FILE* f = fopen(filePath, "r");
1435   if (!f) return 1;
1436 
1437   char buffer[1024], *b = buffer;
1438   while (!feof(f)) {
1439     char c;
1440 
1441     if (fscanf(f, "%c", &c) != 1) {
1442       break;
1443     }
1444 
1445     switch (c) {
1446       case '\n':
1447         *b++ = 0;
1448         if (processConfigLine(buffer)) {
1449           fclose(f);
1450           return 1;
1451         }
1452         b = buffer;
1453         continue;
1454 
1455       case '\r':
1456         continue;
1457 
1458       default:
1459         *b++ = c;
1460     }
1461   }
1462 
1463   fclose(f);
1464 
1465   // Check mandatory items
1466   if (nssconfigdir.empty()) {
1467     LOG_ERROR(
1468         ("Error: missing path to NSS certification database\n,use "
1469          "certdbdir:<path> in the config file\n"));
1470     return 1;
1471   }
1472 
1473   if (any_host_spec_config && !do_http_proxy) {
1474     LOG_ERROR(
1475         ("Warning: any host-specific configurations are ignored, add "
1476          "httpproxy:1 to allow them\n"));
1477   }
1478 
1479   return 0;
1480 }
1481 
freeHostCertHashItems(PLHashEntry * he,int i,void * arg)1482 int freeHostCertHashItems(PLHashEntry* he, int i, void* arg) {
1483   delete[](char*) he->key;
1484   delete[](char*) he->value;
1485   return HT_ENUMERATE_REMOVE;
1486 }
1487 
freeHostRedirHashItems(PLHashEntry * he,int i,void * arg)1488 int freeHostRedirHashItems(PLHashEntry* he, int i, void* arg) {
1489   delete[](char*) he->key;
1490   delete[](char*) he->value;
1491   return HT_ENUMERATE_REMOVE;
1492 }
1493 
freeClientAuthHashItems(PLHashEntry * he,int i,void * arg)1494 int freeClientAuthHashItems(PLHashEntry* he, int i, void* arg) {
1495   delete[](char*) he->key;
1496   delete (client_auth_option*)he->value;
1497   return HT_ENUMERATE_REMOVE;
1498 }
1499 
freeSSL3HashItems(PLHashEntry * he,int i,void * arg)1500 int freeSSL3HashItems(PLHashEntry* he, int i, void* arg) {
1501   delete[](char*) he->key;
1502   return HT_ENUMERATE_REMOVE;
1503 }
1504 
freeTLSHashItems(PLHashEntry * he,int i,void * arg)1505 int freeTLSHashItems(PLHashEntry* he, int i, void* arg) {
1506   delete[](char*) he->key;
1507   return HT_ENUMERATE_REMOVE;
1508 }
1509 
freeRC4HashItems(PLHashEntry * he,int i,void * arg)1510 int freeRC4HashItems(PLHashEntry* he, int i, void* arg) {
1511   delete[](char*) he->key;
1512   return HT_ENUMERATE_REMOVE;
1513 }
1514 
main(int argc,char ** argv)1515 int main(int argc, char** argv) {
1516   const char* configFilePath;
1517 
1518   const char* logLevelEnv = PR_GetEnv("SSLTUNNEL_LOG_LEVEL");
1519   gLogLevel = logLevelEnv ? (LogLevel)atoi(logLevelEnv) : LEVEL_INFO;
1520 
1521   if (argc == 1)
1522     configFilePath = "ssltunnel.cfg";
1523   else
1524     configFilePath = argv[1];
1525 
1526   memset(&websocket_server, 0, sizeof(PRNetAddr));
1527 
1528   if (parseConfigFile(configFilePath)) {
1529     LOG_ERROR((
1530         "Error: config file \"%s\" missing or formating incorrect\n"
1531         "Specify path to the config file as parameter to ssltunnel or \n"
1532         "create ssltunnel.cfg in the working directory.\n\n"
1533         "Example format of the config file:\n\n"
1534         "       # Enable http/ssl tunneling proxy-like behavior.\n"
1535         "       # If not specified ssltunnel simply does direct forward.\n"
1536         "       httpproxy:1\n\n"
1537         "       # Specify path to the certification database used.\n"
1538         "       certdbdir:/path/to/certdb\n\n"
1539         "       # Forward/proxy all requests in raw to 127.0.0.1:8888.\n"
1540         "       forward:127.0.0.1:8888\n\n"
1541         "       # Accept connections on port 4443 or 5678 resp. and "
1542         "authenticate\n"
1543         "       # to any host ('*') using the 'server cert' or 'server cert 2' "
1544         "resp.\n"
1545         "       listen:*:4443:server cert\n"
1546         "       listen:*:5678:server cert 2\n\n"
1547         "       # Accept connections on port 4443 and authenticate using\n"
1548         "       # 'a different cert' when target host is 'my.host.name:443'.\n"
1549         "       # This only works in httpproxy mode and has higher priority\n"
1550         "       # than the previous option.\n"
1551         "       listen:my.host.name:443:4443:a different cert\n\n"
1552         "       # To make a specific host require or just request a client "
1553         "certificate\n"
1554         "       # to authenticate use the following options. This can only be "
1555         "used\n"
1556         "       # in httpproxy mode and only after the 'listen' option has "
1557         "been\n"
1558         "       # specified. You also have to specify the tunnel listen port.\n"
1559         "       clientauth:requesting-client-cert.host.com:443:4443:request\n"
1560         "       clientauth:requiring-client-cert.host.com:443:4443:require\n"
1561         "       # Proxy WebSocket traffic to the server at 127.0.0.1:9999,\n"
1562         "       # instead of the server specified in the 'forward' option.\n"
1563         "       websocketserver:127.0.0.1:9999\n",
1564         configFilePath));
1565     return 1;
1566   }
1567 
1568   // create a thread pool to handle connections
1569   threads =
1570       PR_CreateThreadPool(INITIAL_THREADS * servers.size(),
1571                           MAX_THREADS * servers.size(), DEFAULT_STACKSIZE);
1572   if (!threads) {
1573     LOG_ERROR(("Failed to create thread pool\n"));
1574     return 1;
1575   }
1576 
1577   shutdown_lock = PR_NewLock();
1578   if (!shutdown_lock) {
1579     LOG_ERROR(("Failed to create lock\n"));
1580     PR_ShutdownThreadPool(threads);
1581     return 1;
1582   }
1583   shutdown_condvar = PR_NewCondVar(shutdown_lock);
1584   if (!shutdown_condvar) {
1585     LOG_ERROR(("Failed to create condvar\n"));
1586     PR_ShutdownThreadPool(threads);
1587     PR_DestroyLock(shutdown_lock);
1588     return 1;
1589   }
1590 
1591   PK11_SetPasswordFunc(password_func);
1592 
1593   // Initialize NSS
1594   if (NSS_Init(nssconfigdir.c_str()) != SECSuccess) {
1595     int32_t errorlen = PR_GetErrorTextLength();
1596     if (errorlen) {
1597       auto err = mozilla::MakeUnique<char[]>(errorlen + 1);
1598       PR_GetErrorText(err.get());
1599       LOG_ERROR(("Failed to init NSS: %s", err.get()));
1600     } else {
1601       LOG_ERROR(("Failed to init NSS: Cannot get error from NSPR."));
1602     }
1603     PR_ShutdownThreadPool(threads);
1604     PR_DestroyCondVar(shutdown_condvar);
1605     PR_DestroyLock(shutdown_lock);
1606     return 1;
1607   }
1608 
1609   if (NSS_SetDomesticPolicy() != SECSuccess) {
1610     LOG_ERROR(("NSS_SetDomesticPolicy failed\n"));
1611     PR_ShutdownThreadPool(threads);
1612     PR_DestroyCondVar(shutdown_condvar);
1613     PR_DestroyLock(shutdown_lock);
1614     NSS_Shutdown();
1615     return 1;
1616   }
1617 
1618   // these values should make NSS use the defaults
1619   if (SSL_ConfigServerSessionIDCache(0, 0, 0, nullptr) != SECSuccess) {
1620     LOG_ERROR(("SSL_ConfigServerSessionIDCache failed\n"));
1621     PR_ShutdownThreadPool(threads);
1622     PR_DestroyCondVar(shutdown_condvar);
1623     PR_DestroyLock(shutdown_lock);
1624     NSS_Shutdown();
1625     return 1;
1626   }
1627 
1628   for (auto& server : servers) {
1629     // Not actually using this PRJob*...
1630     // PRJob* server_job =
1631     PR_QueueJob(threads, StartServer, &server, true);
1632   }
1633   // now wait for someone to tell us to quit
1634   PR_Lock(shutdown_lock);
1635   PR_WaitCondVar(shutdown_condvar, PR_INTERVAL_NO_TIMEOUT);
1636   PR_Unlock(shutdown_lock);
1637   shutdown_server = true;
1638   LOG_INFO(("Shutting down...\n"));
1639   // cleanup
1640   PR_ShutdownThreadPool(threads);
1641   PR_JoinThreadPool(threads);
1642   PR_DestroyCondVar(shutdown_condvar);
1643   PR_DestroyLock(shutdown_lock);
1644   if (NSS_Shutdown() == SECFailure) {
1645     LOG_DEBUG(("Leaked NSS objects!\n"));
1646   }
1647 
1648   for (auto& server : servers) {
1649     PL_HashTableEnumerateEntries(server.host_cert_table, freeHostCertHashItems,
1650                                  nullptr);
1651     PL_HashTableEnumerateEntries(server.host_clientauth_table,
1652                                  freeClientAuthHashItems, nullptr);
1653     PL_HashTableEnumerateEntries(server.host_redir_table,
1654                                  freeHostRedirHashItems, nullptr);
1655     PL_HashTableEnumerateEntries(server.host_ssl3_table, freeSSL3HashItems,
1656                                  nullptr);
1657     PL_HashTableEnumerateEntries(server.host_tls1_table, freeTLSHashItems,
1658                                  nullptr);
1659     PL_HashTableEnumerateEntries(server.host_tls11_table, freeTLSHashItems,
1660                                  nullptr);
1661     PL_HashTableEnumerateEntries(server.host_tls12_table, freeTLSHashItems,
1662                                  nullptr);
1663     PL_HashTableEnumerateEntries(server.host_tls13_table, freeTLSHashItems,
1664                                  nullptr);
1665     PL_HashTableEnumerateEntries(server.host_rc4_table, freeRC4HashItems,
1666                                  nullptr);
1667     PL_HashTableEnumerateEntries(server.host_failhandshake_table,
1668                                  freeRC4HashItems, nullptr);
1669     PL_HashTableDestroy(server.host_cert_table);
1670     PL_HashTableDestroy(server.host_clientauth_table);
1671     PL_HashTableDestroy(server.host_redir_table);
1672     PL_HashTableDestroy(server.host_ssl3_table);
1673     PL_HashTableDestroy(server.host_tls1_table);
1674     PL_HashTableDestroy(server.host_tls11_table);
1675     PL_HashTableDestroy(server.host_tls12_table);
1676     PL_HashTableDestroy(server.host_tls13_table);
1677     PL_HashTableDestroy(server.host_rc4_table);
1678     PL_HashTableDestroy(server.host_failhandshake_table);
1679   }
1680 
1681   PR_Cleanup();
1682   return 0;
1683 }
1684