1 /* -*- Mode: C++; tab-width: 2; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
2 /* This Source Code Form is subject to the terms of the Mozilla Public
3 * License, v. 2.0. If a copy of the MPL was not distributed with this
4 * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
5
6 /*
7 * WARNING: DO NOT USE THIS CODE IN PRODUCTION SYSTEMS. It is highly likely to
8 * be plagued with the usual problems endemic to C (buffer overflows
9 * and the like). We don't especially care here (but would accept
10 * patches!) because this is only intended for use in our test
11 * harnesses in controlled situations where input is guaranteed not to
12 * be malicious.
13 */
14
15 #include "ScopedNSSTypes.h"
16 #include <assert.h>
17 #include <stdio.h>
18 #include <string>
19 #include <vector>
20 #include <algorithm>
21 #include <stdarg.h>
22 #include "prinit.h"
23 #include "prerror.h"
24 #include "prenv.h"
25 #include "prnetdb.h"
26 #include "prtpool.h"
27 #include "nsAlgorithm.h"
28 #include "nss.h"
29 #include "keyhi.h"
30 #include "ssl.h"
31 #include "sslproto.h"
32 #include "plhash.h"
33 #include "mozilla/Sprintf.h"
34 #include "mozilla/Unused.h"
35
36 using namespace mozilla;
37 using namespace mozilla::psm;
38 using std::string;
39 using std::vector;
40
// Delimiter lookup table used by strtok2(): one bit per byte value, stored in
// DELIM_TABLE_SIZE chars (32 * 8 = 256 bits, assuming 8-bit chars). Byte
// value c lives in bit (c & 7) of element (c >> 3).
#define DELIM_TABLE_SIZE 32
#define SET_DELIM(m, c) ((m)[(c) >> 3] |= (1 << ((c)&7)))
#define IS_DELIM(m, c) ((m)[(c) >> 3] & (1 << ((c)&7)))
44
// Logging verbosity can be set with the environment variable
// SSLTUNNEL_LOG_LEVEL=n, where n is 0 (most verbose) through 3 (silent).
// The default is 1, INFO level logging.
enum LogLevel {
  LEVEL_DEBUG = 0,
  LEVEL_INFO = 1,
  LEVEL_ERROR = 2,
  LEVEL_SILENT = 3
};

// Threshold: messages whose level is below it are not printed.
LogLevel gLogLevel;
// Level of the most recently printed message; used by LOG_END_BLOCK().
LogLevel gLastLogLevel;
54
// Core logging primitive. Emits |params| (a parenthesized printf-style
// argument list, e.g. ("x=%d\n", x)) through |func| when |level| is at or
// above gLogLevel, and records |level| in gLastLogLevel so LOG_END_BLOCK()
// can tell whether, and where, a closing newline is needed.
// |level| is parenthesized so that any expression can be passed safely.
#define _LOG_OUTPUT(level, func, params) \
  PR_BEGIN_MACRO                         \
  if ((level) >= gLogLevel) {            \
    gLastLogLevel = (level);             \
    func params;                         \
  }                                      \
  PR_END_MACRO

// The most verbose output
#define LOG_DEBUG(params) _LOG_OUTPUT(LEVEL_DEBUG, printf, params)

// Top level informative messages
#define LOG_INFO(params) _LOG_OUTPUT(LEVEL_INFO, printf, params)

// Serious errors, written to stderr; suppressed only when gLogLevel is
// LEVEL_SILENT.
#define LOG_ERROR(params) _LOG_OUTPUT(LEVEL_ERROR, eprintf, params)

// Same as LOG_ERROR, but when logging is set to LEVEL_DEBUG the message is
// written to stdout instead of stderr, so it stays in sequence with the
// surrounding LOG_DEBUG output.
#define LOG_ERRORD(params)                     \
  PR_BEGIN_MACRO                               \
  if (gLogLevel == LEVEL_DEBUG)                \
    _LOG_OUTPUT(LEVEL_ERROR, printf, params);  \
  else                                         \
    _LOG_OUTPUT(LEVEL_ERROR, eprintf, params); \
  PR_END_MACRO

// If anything is logged between LOG_BEGIN_BLOCK() and LOG_END_BLOCK(), a
// newline is written at the end of the block: to stderr (via LOG_ERROR) when
// the last printed message was at LEVEL_ERROR, otherwise to stdout.
// LOG_BEGIN_BLOCK() expands to a plain assignment statement that already ends
// in ';'; it is kept that way so existing call sites keep compiling.
#define LOG_BEGIN_BLOCK() gLastLogLevel = LEVEL_SILENT;

#define LOG_END_BLOCK()                                                        \
  PR_BEGIN_MACRO                                                               \
  if (gLastLogLevel == LEVEL_ERROR) LOG_ERROR(("\n"));                         \
  if (gLastLogLevel < LEVEL_ERROR) _LOG_OUTPUT(gLastLogLevel, printf, ("\n")); \
  PR_END_MACRO
92
// printf() counterpart that writes to stderr. Returns whatever vfprintf()
// returns: the number of characters written, or a negative value on error.
int eprintf(const char* str, ...) {
  va_list args;
  va_start(args, str);
  const int written = vfprintf(stderr, str, args);
  va_end(args);
  return written;
}
100
101 // Copied from nsCRT
strtok2(char * string,const char * delims,char ** newStr)102 char* strtok2(char* string, const char* delims, char** newStr) {
103 PR_ASSERT(string);
104
105 char delimTable[DELIM_TABLE_SIZE];
106 uint32_t i;
107 char* result;
108 char* str = string;
109
110 for (i = 0; i < DELIM_TABLE_SIZE; i++) delimTable[i] = '\0';
111
112 for (i = 0; delims[i]; i++) {
113 SET_DELIM(delimTable, static_cast<uint8_t>(delims[i]));
114 }
115
116 // skip to beginning
117 while (*str && IS_DELIM(delimTable, static_cast<uint8_t>(*str))) {
118 str++;
119 }
120 result = str;
121
122 // fix up the end of the token
123 while (*str) {
124 if (IS_DELIM(delimTable, static_cast<uint8_t>(*str))) {
125 *str++ = '\0';
126 break;
127 }
128 str++;
129 }
130 *newStr = str;
131
132 return str == result ? nullptr : result;
133 }
134
// Client-certificate policy for a host (see ConfigureSSLServerSocket()).
enum client_auth_option {
  caNone = 0,     // no client certificate is requested
  caRequire = 1,  // a client certificate is requested and required
  caRequest = 2   // a client certificate is requested but optional
};
136
137 // Structs for passing data into jobs on the thread pool
138 struct server_info_t {
139 int32_t listen_port;
140 string cert_nickname;
141 PLHashTable* host_cert_table;
142 PLHashTable* host_clientauth_table;
143 PLHashTable* host_redir_table;
144 PLHashTable* host_ssl3_table;
145 PLHashTable* host_tls1_table;
146 PLHashTable* host_tls11_table;
147 PLHashTable* host_tls12_table;
148 PLHashTable* host_tls13_table;
149 PLHashTable* host_rc4_table;
150 PLHashTable* host_failhandshake_table;
151 };
152
153 struct connection_info_t {
154 PRFileDesc* client_sock;
155 PRNetAddr client_addr;
156 server_info_t* server_info;
157 // the original host in the Host: header for this connection is
158 // stored here, for proxied connections
159 string original_host;
160 // true if no SSL should be used for this connection
161 bool http_proxy_only;
162 // true if this connection is for a WebSocket
163 bool iswebsocket;
164 };
165
166 struct server_match_t {
167 string fullHost;
168 bool matched;
169 };
170
const int32_t BUF_SIZE = 16384;
const int32_t BUF_MARGIN = 1024;
const int32_t BUF_TOTAL = BUF_SIZE + BUF_MARGIN;

/**
 * Byte buffer used to relay data from one socket of a connection to the other.
 *
 * Pending data lives in [bufferhead, buffertail). Socket reads may fill the
 * buffer up to bufferend (BUF_SIZE bytes into the allocation). The extra
 * BUF_MARGIN bytes after bufferend are reserved for in-place request
 * rewriting, so buffertail may legitimately move past bufferend.
 *
 * The object owns its allocation, so copying is disabled: a copy would make
 * two objects delete[] the same memory.
 */
struct relayBuffer {
  char *buffer, *bufferhead, *buffertail, *bufferend;

  relayBuffer() {
    // Leave BUF_MARGIN (1024) bytes more for request line manipulations
    bufferhead = buffertail = buffer = new char[BUF_TOTAL];
    bufferend = buffer + BUF_SIZE;
  }

  ~relayBuffer() { delete[] buffer; }

  relayBuffer(const relayBuffer&) = delete;
  relayBuffer& operator=(const relayBuffer&) = delete;

  // Rewind to the start of the allocation once all data has been consumed.
  void compact() {
    if (buffertail == bufferhead) buffertail = bufferhead = buffer;
  }

  bool empty() const { return bufferhead == buffertail; }

  // Bytes that may still be read into the buffer. This is 0 (rather than a
  // wrapped-around huge value) once a rewrite has pushed buffertail past
  // bufferend, so callers never pass an out-of-bounds length to PR_Recv().
  size_t areafree() const {
    return buffertail < bufferend ? size_t(bufferend - buffertail) : 0;
  }

  // Bytes left before the end of the whole allocation, margin included.
  // This is the same value as the old (bufferend - buffertail) + BUF_MARGIN,
  // but it no longer depends on areafree() wrapping around.
  size_t margin() const { return size_t(buffer + BUF_TOTAL - buffertail); }

  size_t present() const { return buffertail - bufferhead; }
};
195
// Thread pool sizing. These numbers are multiplied by the number of listening
// ports (i.e. the number of servers actually running). According to the
// thread pool implementation there is no need to limit the number of threads
// initially: threads are allocated dynamically and stored in a linked list.
// The initial value of 2 allocates one thread for accepting sockets and
// preallocates one for the first connection, which is very likely to arrive.
constexpr uint32_t INITIAL_THREADS = 2;
constexpr uint32_t MAX_THREADS = 100;
constexpr uint32_t DEFAULT_STACKSIZE = 512 * 1024;
205
206 // global data
207 string nssconfigdir;
208 vector<server_info_t> servers;
209 PRNetAddr remote_addr;
210 PRNetAddr websocket_server;
211 PRThreadPool* threads = nullptr;
212 PRLock* shutdown_lock = nullptr;
213 PRCondVar* shutdown_condvar = nullptr;
214 // Not really used, unless something fails to start
215 bool shutdown_server = false;
216 bool do_http_proxy = false;
217 bool any_host_spec_config = false;
218 bool listen_public = false;
219
ClientAuthValueComparator(const void * v1,const void * v2)220 int ClientAuthValueComparator(const void* v1, const void* v2) {
221 int a = *static_cast<const client_auth_option*>(v1) -
222 *static_cast<const client_auth_option*>(v2);
223 if (a == 0) return 0;
224 if (a > 0) return 1;
225 // (a < 0)
226 return -1;
227 }
228
match_hostname(PLHashEntry * he,int index,void * arg)229 static int match_hostname(PLHashEntry* he, int index, void* arg) {
230 server_match_t* match = (server_match_t*)arg;
231 if (match->fullHost.find((char*)he->key) != string::npos)
232 match->matched = true;
233 return HT_ENUMERATE_NEXT;
234 }
235
236 /*
237 * Signal the main thread that the application should shut down.
238 */
SignalShutdown()239 void SignalShutdown() {
240 PR_Lock(shutdown_lock);
241 PR_NotifyCondVar(shutdown_condvar);
242 PR_Unlock(shutdown_lock);
243 }
244
// Per-host behavior flags. ReadConnectRequest() ORs them together and
// ConfigureSSLServerSocket() acts on them.
enum {
  USE_SSL3 = 0x01,
  USE_RC4 = 0x02,
  FAIL_HANDSHAKE = 0x04,
  USE_TLS1 = 0x08,
  USE_TLS1_1 = 0x10,
  USE_TLS1_2 = 0x20,
  USE_TLS1_3 = 0x40
};
255
ReadConnectRequest(server_info_t * server_info,relayBuffer & buffer,int32_t * result,string & certificate,client_auth_option * clientauth,string & host,string & location,int32_t * flags)256 bool ReadConnectRequest(server_info_t* server_info, relayBuffer& buffer,
257 int32_t* result, string& certificate,
258 client_auth_option* clientauth, string& host,
259 string& location, int32_t* flags) {
260 if (buffer.present() < 4) {
261 LOG_DEBUG(
262 (" !! only %d bytes present in the buffer", (int)buffer.present()));
263 return false;
264 }
265 if (strncmp(buffer.buffertail - 4, "\r\n\r\n", 4)) {
266 LOG_ERRORD((" !! request is not tailed with CRLFCRLF but with %x %x %x %x",
267 *(buffer.buffertail - 4), *(buffer.buffertail - 3),
268 *(buffer.buffertail - 2), *(buffer.buffertail - 1)));
269 return false;
270 }
271
272 LOG_DEBUG((" parsing initial connect request, dump:\n%.*s\n",
273 (int)buffer.present(), buffer.bufferhead));
274
275 *result = 400;
276
277 char* token;
278 char* _caret;
279 token = strtok2(buffer.bufferhead, " ", &_caret);
280 if (!token) {
281 LOG_ERRORD((" no space found"));
282 return true;
283 }
284 if (strcmp(token, "CONNECT")) {
285 LOG_ERRORD((" not CONNECT request but %s", token));
286 return true;
287 }
288
289 token = strtok2(_caret, " ", &_caret);
290 void* c = PL_HashTableLookup(server_info->host_cert_table, token);
291 if (c) certificate = static_cast<char*>(c);
292
293 host = "https://";
294 host += token;
295
296 c = PL_HashTableLookup(server_info->host_clientauth_table, token);
297 if (c)
298 *clientauth = *static_cast<client_auth_option*>(c);
299 else
300 *clientauth = caNone;
301
302 void* redir = PL_HashTableLookup(server_info->host_redir_table, token);
303 if (redir) location = static_cast<char*>(redir);
304
305 if (PL_HashTableLookup(server_info->host_ssl3_table, token)) {
306 *flags |= USE_SSL3;
307 }
308
309 if (PL_HashTableLookup(server_info->host_rc4_table, token)) {
310 *flags |= USE_RC4;
311 }
312
313 if (PL_HashTableLookup(server_info->host_tls1_table, token)) {
314 *flags |= USE_TLS1;
315 }
316
317 if (PL_HashTableLookup(server_info->host_tls11_table, token)) {
318 *flags |= USE_TLS1_1;
319 }
320
321 if (PL_HashTableLookup(server_info->host_tls12_table, token)) {
322 *flags |= USE_TLS1_2;
323 }
324
325 if (PL_HashTableLookup(server_info->host_tls13_table, token)) {
326 *flags |= USE_TLS1_3;
327 }
328
329 if (PL_HashTableLookup(server_info->host_failhandshake_table, token)) {
330 *flags |= FAIL_HANDSHAKE;
331 }
332
333 token = strtok2(_caret, "/", &_caret);
334 if (strcmp(token, "HTTP")) {
335 LOG_ERRORD((" not tailed with HTTP but with %s", token));
336 return true;
337 }
338
339 *result = (redir) ? 302 : 200;
340 return true;
341 }
342
ConfigureSSLServerSocket(PRFileDesc * socket,server_info_t * si,const string & certificate,const client_auth_option clientAuth,int32_t flags)343 bool ConfigureSSLServerSocket(PRFileDesc* socket, server_info_t* si,
344 const string& certificate,
345 const client_auth_option clientAuth,
346 int32_t flags) {
347 const char* certnick =
348 certificate.empty() ? si->cert_nickname.c_str() : certificate.c_str();
349
350 UniqueCERTCertificate cert(PK11_FindCertFromNickname(certnick, nullptr));
351 if (!cert) {
352 LOG_ERROR(("Failed to find cert %s\n", certnick));
353 return false;
354 }
355
356 UniqueSECKEYPrivateKey privKey(PK11_FindKeyByAnyCert(cert.get(), nullptr));
357 if (!privKey) {
358 LOG_ERROR(("Failed to find private key\n"));
359 return false;
360 }
361
362 PRFileDesc* ssl_socket = SSL_ImportFD(nullptr, socket);
363 if (!ssl_socket) {
364 LOG_ERROR(("Error importing SSL socket\n"));
365 return false;
366 }
367
368 if (flags & FAIL_HANDSHAKE) {
369 // deliberately cause handshake to fail by sending the client a client hello
370 SSL_ResetHandshake(ssl_socket, false);
371 return true;
372 }
373
374 SSLKEAType certKEA = NSS_FindCertKEAType(cert.get());
375 if (SSL_ConfigSecureServer(ssl_socket, cert.get(), privKey.get(), certKEA) !=
376 SECSuccess) {
377 LOG_ERROR(("Error configuring SSL server socket\n"));
378 return false;
379 }
380
381 SSL_OptionSet(ssl_socket, SSL_SECURITY, true);
382 SSL_OptionSet(ssl_socket, SSL_HANDSHAKE_AS_CLIENT, false);
383 SSL_OptionSet(ssl_socket, SSL_HANDSHAKE_AS_SERVER, true);
384 SSL_OptionSet(ssl_socket, SSL_ENABLE_SESSION_TICKETS, true);
385
386 if (clientAuth != caNone) {
387 // If we're requesting or requiring a client certificate, we should
388 // configure NSS to include the "certificate_authorities" field in the
389 // certificate request message. That way we can test that gecko properly
390 // takes note of it.
391 UniqueCERTCertificate issuer(
392 CERT_FindCertIssuer(cert.get(), PR_Now(), certUsageAnyCA));
393 if (!issuer) {
394 LOG_DEBUG(("Failed to find issuer for %s\n", certnick));
395 return false;
396 }
397 UniqueCERTCertList issuerList(CERT_NewCertList());
398 if (!issuerList) {
399 LOG_ERROR(("Failed to allocate new CERTCertList\n"));
400 return false;
401 }
402 if (CERT_AddCertToListTail(issuerList.get(), issuer.get()) != SECSuccess) {
403 LOG_ERROR(("Failed to add issuer to issuerList\n"));
404 return false;
405 }
406 Unused << issuer.release(); // Ownership transferred to issuerList.
407 if (SSL_SetTrustAnchors(ssl_socket, issuerList.get()) != SECSuccess) {
408 LOG_ERROR(
409 ("Failed to set certificate_authorities list for client "
410 "authentication\n"));
411 return false;
412 }
413 SSL_OptionSet(ssl_socket, SSL_REQUEST_CERTIFICATE, true);
414 SSL_OptionSet(ssl_socket, SSL_REQUIRE_CERTIFICATE, clientAuth == caRequire);
415 }
416
417 SSLVersionRange range = {SSL_LIBRARY_VERSION_TLS_1_3,
418 SSL_LIBRARY_VERSION_3_0};
419 if (flags & USE_SSL3) {
420 range.min = PR_MIN(range.min, SSL_LIBRARY_VERSION_3_0);
421 range.max = PR_MAX(range.max, SSL_LIBRARY_VERSION_3_0);
422 }
423 if (flags & USE_TLS1) {
424 range.min = PR_MIN(range.min, SSL_LIBRARY_VERSION_TLS_1_0);
425 range.max = PR_MAX(range.max, SSL_LIBRARY_VERSION_TLS_1_0);
426 }
427 if (flags & USE_TLS1_1) {
428 range.min = PR_MIN(range.min, SSL_LIBRARY_VERSION_TLS_1_1);
429 range.max = PR_MAX(range.max, SSL_LIBRARY_VERSION_TLS_1_1);
430 }
431 if (flags & USE_TLS1_2) {
432 range.min = PR_MIN(range.min, SSL_LIBRARY_VERSION_TLS_1_2);
433 range.max = PR_MAX(range.max, SSL_LIBRARY_VERSION_TLS_1_2);
434 }
435 if (flags & USE_TLS1_3) {
436 range.min = PR_MIN(range.min, SSL_LIBRARY_VERSION_TLS_1_3);
437 range.max = PR_MAX(range.max, SSL_LIBRARY_VERSION_TLS_1_3);
438 }
439 // Set the valid range, if any were specified (if not, skip
440 // when the default range is invalid, i.e. max > min)
441 if (range.min <= range.max &&
442 SSL_VersionRangeSet(ssl_socket, &range) != SECSuccess) {
443 LOG_ERROR(("Error configuring SSL socket version range\n"));
444 return false;
445 }
446
447 if (flags & USE_RC4) {
448 for (uint16_t i = 0; i < SSL_NumImplementedCiphers; ++i) {
449 uint16_t cipher_id = SSL_ImplementedCiphers[i];
450 switch (cipher_id) {
451 case TLS_ECDHE_ECDSA_WITH_RC4_128_SHA:
452 case TLS_ECDHE_RSA_WITH_RC4_128_SHA:
453 case TLS_RSA_WITH_RC4_128_SHA:
454 case TLS_RSA_WITH_RC4_128_MD5:
455 SSL_CipherPrefSet(ssl_socket, cipher_id, true);
456 break;
457
458 default:
459 SSL_CipherPrefSet(ssl_socket, cipher_id, false);
460 break;
461 }
462 }
463 }
464
465 SSL_ResetHandshake(ssl_socket, true);
466
467 return true;
468 }
469
470 /**
471 * This function examines the buffer for a Sec-WebSocket-Location: field,
472 * and if it's present, it replaces the hostname in that field with the
473 * value in the server's original_host field. This function works
474 * in the reverse direction as AdjustWebSocketHost(), replacing the real
475 * hostname of a response with the potentially fake hostname that is expected
476 * by the browser (e.g., mochi.test).
477 *
478 * @return true if the header was adjusted successfully, or not found, false
479 * if the header is present but the url is not, which should indicate
480 * that more data needs to be read from the socket
481 */
AdjustWebSocketLocation(relayBuffer & buffer,connection_info_t * ci)482 bool AdjustWebSocketLocation(relayBuffer& buffer, connection_info_t* ci) {
483 assert(buffer.margin());
484 buffer.buffertail[1] = '\0';
485
486 char* wsloc = strstr(buffer.bufferhead, "Sec-WebSocket-Location:");
487 if (!wsloc) return true;
488 // advance pointer to the start of the hostname
489 wsloc = strstr(wsloc, "ws://");
490 if (!wsloc) return false;
491 wsloc += 5;
492 // find the end of the hostname
493 char* wslocend = strchr(wsloc + 1, '/');
494 if (!wslocend) return false;
495 char* crlf = strstr(wsloc, "\r\n");
496 if (!crlf) return false;
497 if (ci->original_host.empty()) return true;
498
499 int diff = ci->original_host.length() - (wslocend - wsloc);
500 if (diff > 0) assert(size_t(diff) <= buffer.margin());
501 memmove(wslocend + diff, wslocend, buffer.buffertail - wsloc - diff);
502 buffer.buffertail += diff;
503
504 memcpy(wsloc, ci->original_host.c_str(), ci->original_host.length());
505 return true;
506 }
507
508 /**
509 * This function examines the buffer for a Host: field, and if it's present,
510 * it replaces the hostname in that field with the hostname in the server's
511 * remote_addr field. This is needed because proxy requests may be coming
512 * from mochitest with fake hosts, like mochi.test, and these need to be
513 * replaced with the host that the destination server is actually running
514 * on.
515 */
AdjustWebSocketHost(relayBuffer & buffer,connection_info_t * ci)516 bool AdjustWebSocketHost(relayBuffer& buffer, connection_info_t* ci) {
517 const char HEADER_UPGRADE[] = "Upgrade:";
518 const char HEADER_HOST[] = "Host:";
519
520 PRNetAddr inet_addr =
521 (websocket_server.inet.port ? websocket_server : remote_addr);
522
523 assert(buffer.margin());
524
525 // Cannot use strnchr so add a null char at the end. There is always some
526 // space left because we preserve a margin.
527 buffer.buffertail[1] = '\0';
528
529 // Verify this is a WebSocket header.
530 char* h1 = strstr(buffer.bufferhead, HEADER_UPGRADE);
531 if (!h1) return false;
532 h1 += strlen(HEADER_UPGRADE);
533 h1 += strspn(h1, " \t");
534 char* h2 = strstr(h1, "WebSocket\r\n");
535 if (!h2) h2 = strstr(h1, "websocket\r\n");
536 if (!h2) h2 = strstr(h1, "Websocket\r\n");
537 if (!h2) return false;
538
539 char* host = strstr(buffer.bufferhead, HEADER_HOST);
540 if (!host) return false;
541 // advance pointer to beginning of hostname
542 host += strlen(HEADER_HOST);
543 host += strspn(host, " \t");
544
545 char* endhost = strstr(host, "\r\n");
546 if (!endhost) return false;
547
548 // Save the original host, so we can use it later on responses from the
549 // server.
550 ci->original_host.assign(host, endhost - host);
551
552 char newhost[40];
553 PR_NetAddrToString(&inet_addr, newhost, sizeof(newhost));
554 assert(strlen(newhost) < sizeof(newhost) - 7);
555 SprintfLiteral(newhost, "%s:%d", newhost, PR_ntohs(inet_addr.inet.port));
556
557 int diff = strlen(newhost) - (endhost - host);
558 if (diff > 0) assert(size_t(diff) <= buffer.margin());
559 memmove(endhost + diff, endhost, buffer.buffertail - host - diff);
560 buffer.buffertail += diff;
561
562 memcpy(host, newhost, strlen(newhost));
563 return true;
564 }
565
566 /**
567 * This function prefixes Request-URI path with a full scheme-host-port
568 * string.
569 */
AdjustRequestURI(relayBuffer & buffer,string * host)570 bool AdjustRequestURI(relayBuffer& buffer, string* host) {
571 assert(buffer.margin());
572
573 // Cannot use strnchr so add a null char at the end. There is always some
574 // space left because we preserve a margin.
575 buffer.buffertail[1] = '\0';
576 LOG_DEBUG((" incoming request to adjust:\n%s\n", buffer.bufferhead));
577
578 char *token, *path;
579 path = strchr(buffer.bufferhead, ' ') + 1;
580 if (!path) return false;
581
582 // If the path doesn't start with a slash don't change it, it is probably '*'
583 // or a full path already. Return true, we are done with this request
584 // adjustment.
585 if (*path != '/') return true;
586
587 token = strchr(path, ' ') + 1;
588 if (!token) return false;
589
590 if (strncmp(token, "HTTP/", 5)) return false;
591
592 size_t hostlength = host->length();
593 assert(hostlength <= buffer.margin());
594
595 memmove(path + hostlength, path, buffer.buffertail - path);
596 memcpy(path, host->c_str(), hostlength);
597 buffer.buffertail += hostlength;
598
599 return true;
600 }
601
ConnectSocket(UniquePRFileDesc & fd,const PRNetAddr * addr,PRIntervalTime timeout)602 bool ConnectSocket(UniquePRFileDesc& fd, const PRNetAddr* addr,
603 PRIntervalTime timeout) {
604 PRStatus stat = PR_Connect(fd.get(), addr, timeout);
605 if (stat != PR_SUCCESS) return false;
606
607 PRSocketOptionData option;
608 option.option = PR_SockOpt_Nonblocking;
609 option.value.non_blocking = true;
610 PR_SetSocketOption(fd.get(), &option);
611
612 return true;
613 }
614
615 /*
616 * Handle an incoming client connection. The server thread has already
617 * accepted the connection, so we just need to connect to the remote
618 * port and then proxy data back and forth.
619 * The data parameter is a connection_info_t*, and must be deleted
620 * by this function.
621 */
HandleConnection(void * data)622 void HandleConnection(void* data) {
623 connection_info_t* ci = static_cast<connection_info_t*>(data);
624 PRIntervalTime connect_timeout = PR_SecondsToInterval(30);
625
626 UniquePRFileDesc other_sock(PR_NewTCPSocket());
627 bool client_done = false;
628 bool client_error = false;
629 bool connect_accepted = !do_http_proxy;
630 bool ssl_updated = !do_http_proxy;
631 bool expect_request_start = do_http_proxy;
632 string certificateToUse;
633 string locationHeader;
634 client_auth_option clientAuth;
635 string fullHost;
636 int32_t flags = 0;
637
638 LOG_DEBUG(("SSLTUNNEL(%p)): incoming connection csock(0)=%p, ssock(1)=%p\n",
639 static_cast<void*>(data), static_cast<void*>(ci->client_sock),
640 static_cast<void*>(other_sock.get())));
641 if (other_sock) {
642 int32_t numberOfSockets = 1;
643
644 relayBuffer buffers[2];
645
646 if (!do_http_proxy) {
647 if (!ConfigureSSLServerSocket(ci->client_sock, ci->server_info,
648 certificateToUse, caNone, flags))
649 client_error = true;
650 else if (!ConnectSocket(other_sock, &remote_addr, connect_timeout))
651 client_error = true;
652 else
653 numberOfSockets = 2;
654 }
655
656 PRPollDesc sockets[2] = {{ci->client_sock, PR_POLL_READ, 0},
657 {other_sock.get(), PR_POLL_READ, 0}};
658 bool socketErrorState[2] = {false, false};
659
660 while (!((client_error || client_done) && buffers[0].empty() &&
661 buffers[1].empty())) {
662 sockets[0].in_flags |= PR_POLL_EXCEPT;
663 sockets[1].in_flags |= PR_POLL_EXCEPT;
664 LOG_DEBUG(("SSLTUNNEL(%p)): polling flags csock(0)=%c%c, ssock(1)=%c%c\n",
665 static_cast<void*>(data),
666 sockets[0].in_flags & PR_POLL_READ ? 'R' : '-',
667 sockets[0].in_flags & PR_POLL_WRITE ? 'W' : '-',
668 sockets[1].in_flags & PR_POLL_READ ? 'R' : '-',
669 sockets[1].in_flags & PR_POLL_WRITE ? 'W' : '-'));
670 int32_t pollStatus =
671 PR_Poll(sockets, numberOfSockets, PR_MillisecondsToInterval(1000));
672 if (pollStatus < 0) {
673 LOG_DEBUG(("SSLTUNNEL(%p)): pollStatus=%d, exiting\n",
674 static_cast<void*>(data), pollStatus));
675 client_error = true;
676 break;
677 }
678
679 if (pollStatus == 0) {
680 // timeout
681 LOG_DEBUG(("SSLTUNNEL(%p)): poll timeout, looping\n",
682 static_cast<void*>(data)));
683 continue;
684 }
685
686 for (int32_t s = 0; s < numberOfSockets; ++s) {
687 int32_t s2 = s == 1 ? 0 : 1;
688 int16_t out_flags = sockets[s].out_flags;
689 int16_t& in_flags = sockets[s].in_flags;
690 int16_t& in_flags2 = sockets[s2].in_flags;
691 sockets[s].out_flags = 0;
692
693 LOG_BEGIN_BLOCK();
694 LOG_DEBUG(("SSLTUNNEL(%p)): %csock(%d)=%p out_flags=%d",
695 static_cast<void*>(data), s == 0 ? 'c' : 's', s,
696 static_cast<void*>(sockets[s].fd), out_flags));
697 if (out_flags & (PR_POLL_EXCEPT | PR_POLL_ERR | PR_POLL_HUP)) {
698 LOG_DEBUG((" :exception\n"));
699 client_error = true;
700 socketErrorState[s] = true;
701 // We got a fatal error state on the socket. Clear the output buffer
702 // for this socket to break the main loop, we will never more be able
703 // to send those data anyway.
704 buffers[s2].bufferhead = buffers[s2].buffertail = buffers[s2].buffer;
705 continue;
706 } // PR_POLL_EXCEPT, PR_POLL_ERR, PR_POLL_HUP handling
707
708 if (out_flags & PR_POLL_READ && !buffers[s].areafree()) {
709 LOG_DEBUG(
710 (" no place in read buffer but got read flag, dropping it now!"));
711 in_flags &= ~PR_POLL_READ;
712 }
713
714 if (out_flags & PR_POLL_READ && buffers[s].areafree()) {
715 LOG_DEBUG((" :reading"));
716 int32_t bytesRead =
717 PR_Recv(sockets[s].fd, buffers[s].buffertail,
718 buffers[s].areafree(), 0, PR_INTERVAL_NO_TIMEOUT);
719
720 if (bytesRead == 0) {
721 LOG_DEBUG((" socket gracefully closed"));
722 client_done = true;
723 in_flags &= ~PR_POLL_READ;
724 } else if (bytesRead < 0) {
725 if (PR_GetError() != PR_WOULD_BLOCK_ERROR) {
726 LOG_DEBUG((" error=%d", PR_GetError()));
727 // We are in error state, indicate that the connection was
728 // not closed gracefully
729 client_error = true;
730 socketErrorState[s] = true;
731 // Wipe out our send buffer, we cannot send it anyway.
732 buffers[s2].bufferhead = buffers[s2].buffertail =
733 buffers[s2].buffer;
734 } else
735 LOG_DEBUG((" would block"));
736 } else {
737 // If the other socket is in error state (unable to send/receive)
738 // throw this data away and continue loop
739 if (socketErrorState[s2]) {
740 LOG_DEBUG((" have read but other socket is in error state\n"));
741 continue;
742 }
743
744 buffers[s].buffertail += bytesRead;
745 LOG_DEBUG((", read %d bytes", bytesRead));
746
747 // We have to accept and handle the initial CONNECT request here
748 int32_t response;
749 if (!connect_accepted &&
750 ReadConnectRequest(ci->server_info, buffers[s], &response,
751 certificateToUse, &clientAuth, fullHost,
752 locationHeader, &flags)) {
753 // Mark this as a proxy-only connection (no SSL) if the CONNECT
754 // request didn't come for port 443 or from any of the server's
755 // cert or clientauth hostnames.
756 if (fullHost.find(":443") == string::npos) {
757 server_match_t match;
758 match.fullHost = fullHost;
759 match.matched = false;
760 PL_HashTableEnumerateEntries(ci->server_info->host_cert_table,
761 match_hostname, &match);
762 PL_HashTableEnumerateEntries(
763 ci->server_info->host_clientauth_table, match_hostname,
764 &match);
765 PL_HashTableEnumerateEntries(ci->server_info->host_ssl3_table,
766 match_hostname, &match);
767 PL_HashTableEnumerateEntries(ci->server_info->host_tls1_table,
768 match_hostname, &match);
769 PL_HashTableEnumerateEntries(ci->server_info->host_tls11_table,
770 match_hostname, &match);
771 PL_HashTableEnumerateEntries(ci->server_info->host_tls12_table,
772 match_hostname, &match);
773 PL_HashTableEnumerateEntries(ci->server_info->host_tls13_table,
774 match_hostname, &match);
775 PL_HashTableEnumerateEntries(ci->server_info->host_rc4_table,
776 match_hostname, &match);
777 PL_HashTableEnumerateEntries(
778 ci->server_info->host_failhandshake_table, match_hostname,
779 &match);
780 ci->http_proxy_only = !match.matched;
781 } else {
782 ci->http_proxy_only = false;
783 }
784
785 // Clean the request as it would be read
786 buffers[s].bufferhead = buffers[s].buffertail = buffers[s].buffer;
787 in_flags |= PR_POLL_WRITE;
788 connect_accepted = true;
789
790 // Store response to the oposite buffer
791 if (response == 200) {
792 LOG_DEBUG(
793 (" accepted CONNECT request, connected to the server, "
794 "sending OK to the client\n"));
795 strcpy(
796 buffers[s2].buffer,
797 "HTTP/1.1 200 Connected\r\nConnection: keep-alive\r\n\r\n");
798 } else if (response == 302) {
799 LOG_DEBUG(
800 (" accepted CONNECT request with redirection, "
801 "sending location and 302 to the client\n"));
802 client_done = true;
803 snprintf(buffers[s2].buffer,
804 buffers[s2].bufferend - buffers[s2].buffer,
805 "HTTP/1.1 302 Moved\r\n"
806 "Location: https://%s/\r\n"
807 "Connection: close\r\n\r\n",
808 locationHeader.c_str());
809 } else {
810 LOG_ERRORD(
811 (" could not read the connect request, closing connection "
812 "with %d",
813 response));
814 client_done = true;
815 snprintf(buffers[s2].buffer,
816 buffers[s2].bufferend - buffers[s2].buffer,
817 "HTTP/1.1 %d ERROR\r\nConnection: close\r\n\r\n",
818 response);
819
820 break;
821 }
822
823 buffers[s2].buffertail =
824 buffers[s2].buffer + strlen(buffers[s2].buffer);
825
826 // Send the response to the client socket
827 break;
828 } // end of CONNECT handling
829
830 if (!buffers[s].areafree()) {
831 // Do not poll for read when the buffer is full
832 LOG_DEBUG((" no place in our read buffer, stop reading"));
833 in_flags &= ~PR_POLL_READ;
834 }
835
836 if (ssl_updated) {
837 if (s == 0 && expect_request_start) {
838 if (!strstr(buffers[s].bufferhead, "\r\n\r\n")) {
839 // We haven't received the complete header yet, so wait.
840 continue;
841 }
842 ci->iswebsocket = AdjustWebSocketHost(buffers[s], ci);
843 expect_request_start = !(
844 ci->iswebsocket || AdjustRequestURI(buffers[s], &fullHost));
845 PRNetAddr* addr = &remote_addr;
846 if (ci->iswebsocket && websocket_server.inet.port)
847 addr = &websocket_server;
848 if (!ConnectSocket(other_sock, addr, connect_timeout)) {
849 LOG_ERRORD(
850 (" could not open connection to the real server\n"));
851 client_error = true;
852 break;
853 }
854 LOG_DEBUG(("\n connected to remote server\n"));
855 numberOfSockets = 2;
856 } else if (s == 1 && ci->iswebsocket) {
857 if (!AdjustWebSocketLocation(buffers[s], ci)) continue;
858 }
859
860 in_flags2 |= PR_POLL_WRITE;
861 LOG_DEBUG((" telling the other socket to write"));
862 } else
863 LOG_DEBUG(
864 (" we have something for the other socket to write, but ssl "
865 "has not been administered on it"));
866 }
867 } // PR_POLL_READ handling
868
869 if (out_flags & PR_POLL_WRITE) {
870 LOG_DEBUG((" :writing"));
871 int32_t bytesWrite =
872 PR_Send(sockets[s].fd, buffers[s2].bufferhead,
873 buffers[s2].present(), 0, PR_INTERVAL_NO_TIMEOUT);
874
875 if (bytesWrite < 0) {
876 if (PR_GetError() != PR_WOULD_BLOCK_ERROR) {
877 LOG_DEBUG((" error=%d", PR_GetError()));
878 client_error = true;
879 socketErrorState[s] = true;
880 // We got a fatal error while writting the buffer. Clear it to
881 // break the main loop, we will never more be able to send it.
882 buffers[s2].bufferhead = buffers[s2].buffertail =
883 buffers[s2].buffer;
884 } else
885 LOG_DEBUG((" would block"));
886 } else {
887 LOG_DEBUG((", written %d bytes", bytesWrite));
888 buffers[s2].buffertail[1] = '\0';
889 LOG_DEBUG((" dump:\n%.*s\n", bytesWrite, buffers[s2].bufferhead));
890
891 buffers[s2].bufferhead += bytesWrite;
892 if (buffers[s2].present()) {
893 LOG_DEBUG((" still have to write %d bytes",
894 (int)buffers[s2].present()));
895 in_flags |= PR_POLL_WRITE;
896 } else {
897 if (!ssl_updated) {
898 LOG_DEBUG((" proxy response sent to the client"));
899 // Proxy response has just been writen, update to ssl
900 ssl_updated = true;
901 if (ci->http_proxy_only) {
902 LOG_DEBUG(
903 (" not updating to SSL based on http_proxy_only for this "
904 "socket"));
905 } else if (!ConfigureSSLServerSocket(
906 ci->client_sock, ci->server_info,
907 certificateToUse, clientAuth, flags)) {
908 LOG_ERRORD((" failed to config server socket\n"));
909 client_error = true;
910 break;
911 } else {
912 LOG_DEBUG((" client socket updated to SSL"));
913 }
914 } // sslUpdate
915
916 LOG_DEBUG(
917 (" dropping our write flag and setting other socket read "
918 "flag"));
919 in_flags &= ~PR_POLL_WRITE;
920 in_flags2 |= PR_POLL_READ;
921 buffers[s2].compact();
922 }
923 }
924 } // PR_POLL_WRITE handling
925 LOG_END_BLOCK(); // end the log
926 } // for...
927 } // while, poll
928 } else
929 client_error = true;
930
931 LOG_DEBUG(("SSLTUNNEL(%p)): exiting root function for csock=%p, ssock=%p\n",
932 static_cast<void*>(data), static_cast<void*>(ci->client_sock),
933 static_cast<void*>(other_sock.get())));
934 if (!client_error) PR_Shutdown(ci->client_sock, PR_SHUTDOWN_SEND);
935 PR_Close(ci->client_sock);
936
937 delete ci;
938 }
939
940 /*
941 * Start listening for SSL connections on a specified port, handing
942 * them off to client threads after accepting the connection.
943 * The data parameter is a server_info_t*, owned by the calling
944 * function.
945 */
StartServer(void * data)946 void StartServer(void* data) {
947 server_info_t* si = static_cast<server_info_t*>(data);
948
949 // TODO: select ciphers?
950 UniquePRFileDesc listen_socket(PR_NewTCPSocket());
951 if (!listen_socket) {
952 LOG_ERROR(("failed to create socket\n"));
953 SignalShutdown();
954 return;
955 }
956
957 // In case the socket is still open in the TIME_WAIT state from a previous
958 // instance of ssltunnel we ask to reuse the port.
959 PRSocketOptionData socket_option;
960 socket_option.option = PR_SockOpt_Reuseaddr;
961 socket_option.value.reuse_addr = true;
962 PR_SetSocketOption(listen_socket.get(), &socket_option);
963
964 PRNetAddr server_addr;
965 PRNetAddrValue listen_addr;
966 if (listen_public) {
967 listen_addr = PR_IpAddrAny;
968 } else {
969 listen_addr = PR_IpAddrLoopback;
970 }
971 PR_InitializeNetAddr(listen_addr, si->listen_port, &server_addr);
972
973 if (PR_Bind(listen_socket.get(), &server_addr) != PR_SUCCESS) {
974 LOG_ERROR(("failed to bind socket on port %d: error %d\n", si->listen_port,
975 PR_GetError()));
976 SignalShutdown();
977 return;
978 }
979
980 if (PR_Listen(listen_socket.get(), 1) != PR_SUCCESS) {
981 LOG_ERROR(("failed to listen on socket\n"));
982 SignalShutdown();
983 return;
984 }
985
986 LOG_INFO(("Server listening on port %d with cert %s\n", si->listen_port,
987 si->cert_nickname.c_str()));
988
989 while (!shutdown_server) {
990 connection_info_t* ci = new connection_info_t();
991 ci->server_info = si;
992 ci->http_proxy_only = do_http_proxy;
993 // block waiting for connections
994 ci->client_sock = PR_Accept(listen_socket.get(), &ci->client_addr,
995 PR_INTERVAL_NO_TIMEOUT);
996
997 PRSocketOptionData option;
998 option.option = PR_SockOpt_Nonblocking;
999 option.value.non_blocking = true;
1000 PR_SetSocketOption(ci->client_sock, &option);
1001
1002 if (ci->client_sock)
1003 // Not actually using this PRJob*...
1004 // PRJob* job =
1005 PR_QueueJob(threads, HandleConnection, ci, true);
1006 else
1007 delete ci;
1008 }
1009 }
1010
1011 // bogus password func, just don't use passwords. :-P
password_func(PK11SlotInfo * slot,PRBool retry,void * arg)1012 char* password_func(PK11SlotInfo* slot, PRBool retry, void* arg) {
1013 if (retry) return nullptr;
1014
1015 return PL_strdup("");
1016 }
1017
findServerInfo(int portnumber)1018 server_info_t* findServerInfo(int portnumber) {
1019 for (auto& server : servers) {
1020 if (server.listen_port == portnumber) return &server;
1021 }
1022
1023 return nullptr;
1024 }
1025
get_ssl3_table(server_info_t * server)1026 PLHashTable* get_ssl3_table(server_info_t* server) {
1027 return server->host_ssl3_table;
1028 }
1029
get_tls1_table(server_info_t * server)1030 PLHashTable* get_tls1_table(server_info_t* server) {
1031 return server->host_tls1_table;
1032 }
1033
get_tls11_table(server_info_t * server)1034 PLHashTable* get_tls11_table(server_info_t* server) {
1035 return server->host_tls11_table;
1036 }
1037
get_tls12_table(server_info_t * server)1038 PLHashTable* get_tls12_table(server_info_t* server) {
1039 return server->host_tls12_table;
1040 }
1041
get_tls13_table(server_info_t * server)1042 PLHashTable* get_tls13_table(server_info_t* server) {
1043 return server->host_tls13_table;
1044 }
1045
get_rc4_table(server_info_t * server)1046 PLHashTable* get_rc4_table(server_info_t* server) {
1047 return server->host_rc4_table;
1048 }
1049
get_failhandshake_table(server_info_t * server)1050 PLHashTable* get_failhandshake_table(server_info_t* server) {
1051 return server->host_failhandshake_table;
1052 }
1053
parseWeakCryptoConfig(char * const & keyword,char * & _caret,PLHashTable * (* get_table)(server_info_t *))1054 int parseWeakCryptoConfig(char* const& keyword, char*& _caret,
1055 PLHashTable* (*get_table)(server_info_t*)) {
1056 char* hostname = strtok2(_caret, ":", &_caret);
1057 char* hostportstring = strtok2(_caret, ":", &_caret);
1058 char* serverportstring = strtok2(_caret, "\n", &_caret);
1059
1060 int port = atoi(serverportstring);
1061 if (port <= 0) {
1062 LOG_ERROR(("Invalid port specified: %s\n", serverportstring));
1063 return 1;
1064 }
1065
1066 if (server_info_t* existingServer = findServerInfo(port)) {
1067 any_host_spec_config = true;
1068
1069 char* hostname_copy =
1070 new char[strlen(hostname) + strlen(hostportstring) + 2];
1071 if (!hostname_copy) {
1072 LOG_ERROR(("Out of memory"));
1073 return 1;
1074 }
1075
1076 strcpy(hostname_copy, hostname);
1077 strcat(hostname_copy, ":");
1078 strcat(hostname_copy, hostportstring);
1079
1080 PLHashEntry* entry =
1081 PL_HashTableAdd(get_table(existingServer), hostname_copy, keyword);
1082 if (!entry) {
1083 LOG_ERROR(("Out of memory"));
1084 return 1;
1085 }
1086 } else {
1087 LOG_ERROR(
1088 ("Server on port %d for redirhost option is not defined, use 'listen' "
1089 "option first",
1090 port));
1091 return 1;
1092 }
1093
1094 return 0;
1095 }
1096
processConfigLine(char * configLine)1097 int processConfigLine(char* configLine) {
1098 if (*configLine == 0 || *configLine == '#') return 0;
1099
1100 char* _caret;
1101 char* keyword = strtok2(configLine, ":", &_caret);
1102 // Configure usage of http/ssl tunneling proxy behavior
1103 if (!strcmp(keyword, "httpproxy")) {
1104 char* value = strtok2(_caret, ":", &_caret);
1105 if (!strcmp(value, "1")) do_http_proxy = true;
1106
1107 return 0;
1108 }
1109
1110 if (!strcmp(keyword, "websocketserver")) {
1111 char* ipstring = strtok2(_caret, ":", &_caret);
1112 if (PR_StringToNetAddr(ipstring, &websocket_server) != PR_SUCCESS) {
1113 LOG_ERROR(("Invalid IP address in proxy config: %s\n", ipstring));
1114 return 1;
1115 }
1116 char* remoteport = strtok2(_caret, ":", &_caret);
1117 int port = atoi(remoteport);
1118 if (port <= 0) {
1119 LOG_ERROR(("Invalid remote port in proxy config: %s\n", remoteport));
1120 return 1;
1121 }
1122 websocket_server.inet.port = PR_htons(port);
1123 return 0;
1124 }
1125
1126 // Configure the forward address of the target server
1127 if (!strcmp(keyword, "forward")) {
1128 char* ipstring = strtok2(_caret, ":", &_caret);
1129 if (PR_StringToNetAddr(ipstring, &remote_addr) != PR_SUCCESS) {
1130 LOG_ERROR(("Invalid remote IP address: %s\n", ipstring));
1131 return 1;
1132 }
1133 char* serverportstring = strtok2(_caret, ":", &_caret);
1134 int port = atoi(serverportstring);
1135 if (port <= 0) {
1136 LOG_ERROR(("Invalid remote port: %s\n", serverportstring));
1137 return 1;
1138 }
1139 remote_addr.inet.port = PR_htons(port);
1140
1141 return 0;
1142 }
1143
1144 // Configure all listen sockets and port+certificate bindings.
1145 // Listen on the public address if "*" was specified as the listen
1146 // address or listen on the loopback address if "127.0.0.1" was
1147 // specified. Using loopback will prevent users getting errors from
1148 // their firewalls about ssltunnel needing permission. A public
1149 // address is required when proxying ssl traffic from a physical or
1150 // emulated Android device since it has a different ip address from
1151 // the host.
1152 if (!strcmp(keyword, "listen")) {
1153 char* hostname = strtok2(_caret, ":", &_caret);
1154 char* hostportstring = nullptr;
1155 if (!strcmp(hostname, "*")) {
1156 listen_public = true;
1157 } else if (strcmp(hostname, "127.0.0.1")) {
1158 any_host_spec_config = true;
1159 hostportstring = strtok2(_caret, ":", &_caret);
1160 }
1161
1162 char* serverportstring = strtok2(_caret, ":", &_caret);
1163 char* certnick = strtok2(_caret, ":", &_caret);
1164
1165 int port = atoi(serverportstring);
1166 if (port <= 0) {
1167 LOG_ERROR(("Invalid port specified: %s\n", serverportstring));
1168 return 1;
1169 }
1170
1171 if (server_info_t* existingServer = findServerInfo(port)) {
1172 if (!hostportstring) {
1173 LOG_ERROR(
1174 ("Null hostportstring specified for hostname %s\n", hostname));
1175 return 1;
1176 }
1177 char* certnick_copy = new char[strlen(certnick) + 1];
1178 char* hostname_copy =
1179 new char[strlen(hostname) + strlen(hostportstring) + 2];
1180
1181 strcpy(hostname_copy, hostname);
1182 strcat(hostname_copy, ":");
1183 strcat(hostname_copy, hostportstring);
1184 strcpy(certnick_copy, certnick);
1185
1186 PLHashEntry* entry = PL_HashTableAdd(existingServer->host_cert_table,
1187 hostname_copy, certnick_copy);
1188 if (!entry) {
1189 LOG_ERROR(("Out of memory"));
1190 return 1;
1191 }
1192 } else {
1193 server_info_t server;
1194 server.cert_nickname = certnick;
1195 server.listen_port = port;
1196 server.host_cert_table =
1197 PL_NewHashTable(0, PL_HashString, PL_CompareStrings,
1198 PL_CompareStrings, nullptr, nullptr);
1199 if (!server.host_cert_table) {
1200 LOG_ERROR(("Internal, could not create hash table\n"));
1201 return 1;
1202 }
1203 server.host_clientauth_table =
1204 PL_NewHashTable(0, PL_HashString, PL_CompareStrings,
1205 ClientAuthValueComparator, nullptr, nullptr);
1206 if (!server.host_clientauth_table) {
1207 LOG_ERROR(("Internal, could not create hash table\n"));
1208 return 1;
1209 }
1210 server.host_redir_table =
1211 PL_NewHashTable(0, PL_HashString, PL_CompareStrings,
1212 PL_CompareStrings, nullptr, nullptr);
1213 if (!server.host_redir_table) {
1214 LOG_ERROR(("Internal, could not create hash table\n"));
1215 return 1;
1216 }
1217
1218 server.host_ssl3_table =
1219 PL_NewHashTable(0, PL_HashString, PL_CompareStrings,
1220 PL_CompareStrings, nullptr, nullptr);
1221
1222 if (!server.host_ssl3_table) {
1223 LOG_ERROR(("Internal, could not create hash table\n"));
1224 return 1;
1225 }
1226
1227 server.host_tls1_table =
1228 PL_NewHashTable(0, PL_HashString, PL_CompareStrings,
1229 PL_CompareStrings, nullptr, nullptr);
1230
1231 if (!server.host_tls1_table) {
1232 LOG_ERROR(("Internal, could not create hash table\n"));
1233 return 1;
1234 }
1235
1236 server.host_tls11_table =
1237 PL_NewHashTable(0, PL_HashString, PL_CompareStrings,
1238 PL_CompareStrings, nullptr, nullptr);
1239
1240 if (!server.host_tls11_table) {
1241 LOG_ERROR(("Internal, could not create hash table\n"));
1242 return 1;
1243 }
1244
1245 server.host_tls12_table =
1246 PL_NewHashTable(0, PL_HashString, PL_CompareStrings,
1247 PL_CompareStrings, nullptr, nullptr);
1248
1249 if (!server.host_tls12_table) {
1250 LOG_ERROR(("Internal, could not create hash table\n"));
1251 return 1;
1252 }
1253
1254 server.host_tls13_table =
1255 PL_NewHashTable(0, PL_HashString, PL_CompareStrings,
1256 PL_CompareStrings, nullptr, nullptr);
1257
1258 if (!server.host_tls13_table) {
1259 LOG_ERROR(("Internal, could not create hash table\n"));
1260 return 1;
1261 }
1262
1263 server.host_rc4_table =
1264 PL_NewHashTable(0, PL_HashString, PL_CompareStrings,
1265 PL_CompareStrings, nullptr, nullptr);
1266 ;
1267 if (!server.host_rc4_table) {
1268 LOG_ERROR(("Internal, could not create hash table\n"));
1269 return 1;
1270 }
1271
1272 server.host_failhandshake_table =
1273 PL_NewHashTable(0, PL_HashString, PL_CompareStrings,
1274 PL_CompareStrings, nullptr, nullptr);
1275 ;
1276 if (!server.host_failhandshake_table) {
1277 LOG_ERROR(("Internal, could not create hash table\n"));
1278 return 1;
1279 }
1280
1281 servers.push_back(server);
1282 }
1283
1284 return 0;
1285 }
1286
1287 if (!strcmp(keyword, "clientauth")) {
1288 char* hostname = strtok2(_caret, ":", &_caret);
1289 char* hostportstring = strtok2(_caret, ":", &_caret);
1290 char* serverportstring = strtok2(_caret, ":", &_caret);
1291
1292 int port = atoi(serverportstring);
1293 if (port <= 0) {
1294 LOG_ERROR(("Invalid port specified: %s\n", serverportstring));
1295 return 1;
1296 }
1297
1298 if (server_info_t* existingServer = findServerInfo(port)) {
1299 char* authoptionstring = strtok2(_caret, ":", &_caret);
1300 client_auth_option* authoption = new client_auth_option;
1301 if (!authoption) {
1302 LOG_ERROR(("Out of memory"));
1303 return 1;
1304 }
1305
1306 if (!strcmp(authoptionstring, "require"))
1307 *authoption = caRequire;
1308 else if (!strcmp(authoptionstring, "request"))
1309 *authoption = caRequest;
1310 else if (!strcmp(authoptionstring, "none"))
1311 *authoption = caNone;
1312 else {
1313 LOG_ERROR(
1314 ("Incorrect client auth option modifier for host '%s'", hostname));
1315 delete authoption;
1316 return 1;
1317 }
1318
1319 any_host_spec_config = true;
1320
1321 char* hostname_copy =
1322 new char[strlen(hostname) + strlen(hostportstring) + 2];
1323 if (!hostname_copy) {
1324 LOG_ERROR(("Out of memory"));
1325 delete authoption;
1326 return 1;
1327 }
1328
1329 strcpy(hostname_copy, hostname);
1330 strcat(hostname_copy, ":");
1331 strcat(hostname_copy, hostportstring);
1332
1333 PLHashEntry* entry = PL_HashTableAdd(
1334 existingServer->host_clientauth_table, hostname_copy, authoption);
1335 if (!entry) {
1336 LOG_ERROR(("Out of memory"));
1337 delete authoption;
1338 return 1;
1339 }
1340 } else {
1341 LOG_ERROR(
1342 ("Server on port %d for client authentication option is not defined, "
1343 "use 'listen' option first",
1344 port));
1345 return 1;
1346 }
1347
1348 return 0;
1349 }
1350
1351 if (!strcmp(keyword, "redirhost")) {
1352 char* hostname = strtok2(_caret, ":", &_caret);
1353 char* hostportstring = strtok2(_caret, ":", &_caret);
1354 char* serverportstring = strtok2(_caret, ":", &_caret);
1355
1356 int port = atoi(serverportstring);
1357 if (port <= 0) {
1358 LOG_ERROR(("Invalid port specified: %s\n", serverportstring));
1359 return 1;
1360 }
1361
1362 if (server_info_t* existingServer = findServerInfo(port)) {
1363 char* redirhoststring = strtok2(_caret, ":", &_caret);
1364
1365 any_host_spec_config = true;
1366
1367 char* hostname_copy =
1368 new char[strlen(hostname) + strlen(hostportstring) + 2];
1369 if (!hostname_copy) {
1370 LOG_ERROR(("Out of memory"));
1371 return 1;
1372 }
1373
1374 strcpy(hostname_copy, hostname);
1375 strcat(hostname_copy, ":");
1376 strcat(hostname_copy, hostportstring);
1377
1378 char* redir_copy = new char[strlen(redirhoststring) + 1];
1379 strcpy(redir_copy, redirhoststring);
1380 PLHashEntry* entry = PL_HashTableAdd(existingServer->host_redir_table,
1381 hostname_copy, redir_copy);
1382 if (!entry) {
1383 LOG_ERROR(("Out of memory"));
1384 delete[] hostname_copy;
1385 delete[] redir_copy;
1386 return 1;
1387 }
1388 } else {
1389 LOG_ERROR(
1390 ("Server on port %d for redirhost option is not defined, use "
1391 "'listen' option first",
1392 port));
1393 return 1;
1394 }
1395
1396 return 0;
1397 }
1398
1399 if (!strcmp(keyword, "ssl3")) {
1400 return parseWeakCryptoConfig(keyword, _caret, get_ssl3_table);
1401 }
1402 if (!strcmp(keyword, "tls1")) {
1403 return parseWeakCryptoConfig(keyword, _caret, get_tls1_table);
1404 }
1405 if (!strcmp(keyword, "tls1_1")) {
1406 return parseWeakCryptoConfig(keyword, _caret, get_tls11_table);
1407 }
1408 if (!strcmp(keyword, "tls1_2")) {
1409 return parseWeakCryptoConfig(keyword, _caret, get_tls12_table);
1410 }
1411 if (!strcmp(keyword, "tls1_3")) {
1412 return parseWeakCryptoConfig(keyword, _caret, get_tls13_table);
1413 }
1414
1415 if (!strcmp(keyword, "rc4")) {
1416 return parseWeakCryptoConfig(keyword, _caret, get_rc4_table);
1417 }
1418
1419 if (!strcmp(keyword, "failHandshake")) {
1420 return parseWeakCryptoConfig(keyword, _caret, get_failhandshake_table);
1421 }
1422
1423 // Configure the NSS certificate database directory
1424 if (!strcmp(keyword, "certdbdir")) {
1425 nssconfigdir = strtok2(_caret, "\n", &_caret);
1426 return 0;
1427 }
1428
1429 LOG_ERROR(("Error: keyword \"%s\" unexpected\n", keyword));
1430 return 1;
1431 }
1432
parseConfigFile(const char * filePath)1433 int parseConfigFile(const char* filePath) {
1434 FILE* f = fopen(filePath, "r");
1435 if (!f) return 1;
1436
1437 char buffer[1024], *b = buffer;
1438 while (!feof(f)) {
1439 char c;
1440
1441 if (fscanf(f, "%c", &c) != 1) {
1442 break;
1443 }
1444
1445 switch (c) {
1446 case '\n':
1447 *b++ = 0;
1448 if (processConfigLine(buffer)) {
1449 fclose(f);
1450 return 1;
1451 }
1452 b = buffer;
1453 continue;
1454
1455 case '\r':
1456 continue;
1457
1458 default:
1459 *b++ = c;
1460 }
1461 }
1462
1463 fclose(f);
1464
1465 // Check mandatory items
1466 if (nssconfigdir.empty()) {
1467 LOG_ERROR(
1468 ("Error: missing path to NSS certification database\n,use "
1469 "certdbdir:<path> in the config file\n"));
1470 return 1;
1471 }
1472
1473 if (any_host_spec_config && !do_http_proxy) {
1474 LOG_ERROR(
1475 ("Warning: any host-specific configurations are ignored, add "
1476 "httpproxy:1 to allow them\n"));
1477 }
1478
1479 return 0;
1480 }
1481
freeHostCertHashItems(PLHashEntry * he,int i,void * arg)1482 int freeHostCertHashItems(PLHashEntry* he, int i, void* arg) {
1483 delete[](char*) he->key;
1484 delete[](char*) he->value;
1485 return HT_ENUMERATE_REMOVE;
1486 }
1487
freeHostRedirHashItems(PLHashEntry * he,int i,void * arg)1488 int freeHostRedirHashItems(PLHashEntry* he, int i, void* arg) {
1489 delete[](char*) he->key;
1490 delete[](char*) he->value;
1491 return HT_ENUMERATE_REMOVE;
1492 }
1493
freeClientAuthHashItems(PLHashEntry * he,int i,void * arg)1494 int freeClientAuthHashItems(PLHashEntry* he, int i, void* arg) {
1495 delete[](char*) he->key;
1496 delete (client_auth_option*)he->value;
1497 return HT_ENUMERATE_REMOVE;
1498 }
1499
freeSSL3HashItems(PLHashEntry * he,int i,void * arg)1500 int freeSSL3HashItems(PLHashEntry* he, int i, void* arg) {
1501 delete[](char*) he->key;
1502 return HT_ENUMERATE_REMOVE;
1503 }
1504
freeTLSHashItems(PLHashEntry * he,int i,void * arg)1505 int freeTLSHashItems(PLHashEntry* he, int i, void* arg) {
1506 delete[](char*) he->key;
1507 return HT_ENUMERATE_REMOVE;
1508 }
1509
freeRC4HashItems(PLHashEntry * he,int i,void * arg)1510 int freeRC4HashItems(PLHashEntry* he, int i, void* arg) {
1511 delete[](char*) he->key;
1512 return HT_ENUMERATE_REMOVE;
1513 }
1514
main(int argc,char ** argv)1515 int main(int argc, char** argv) {
1516 const char* configFilePath;
1517
1518 const char* logLevelEnv = PR_GetEnv("SSLTUNNEL_LOG_LEVEL");
1519 gLogLevel = logLevelEnv ? (LogLevel)atoi(logLevelEnv) : LEVEL_INFO;
1520
1521 if (argc == 1)
1522 configFilePath = "ssltunnel.cfg";
1523 else
1524 configFilePath = argv[1];
1525
1526 memset(&websocket_server, 0, sizeof(PRNetAddr));
1527
1528 if (parseConfigFile(configFilePath)) {
1529 LOG_ERROR((
1530 "Error: config file \"%s\" missing or formating incorrect\n"
1531 "Specify path to the config file as parameter to ssltunnel or \n"
1532 "create ssltunnel.cfg in the working directory.\n\n"
1533 "Example format of the config file:\n\n"
1534 " # Enable http/ssl tunneling proxy-like behavior.\n"
1535 " # If not specified ssltunnel simply does direct forward.\n"
1536 " httpproxy:1\n\n"
1537 " # Specify path to the certification database used.\n"
1538 " certdbdir:/path/to/certdb\n\n"
1539 " # Forward/proxy all requests in raw to 127.0.0.1:8888.\n"
1540 " forward:127.0.0.1:8888\n\n"
1541 " # Accept connections on port 4443 or 5678 resp. and "
1542 "authenticate\n"
1543 " # to any host ('*') using the 'server cert' or 'server cert 2' "
1544 "resp.\n"
1545 " listen:*:4443:server cert\n"
1546 " listen:*:5678:server cert 2\n\n"
1547 " # Accept connections on port 4443 and authenticate using\n"
1548 " # 'a different cert' when target host is 'my.host.name:443'.\n"
1549 " # This only works in httpproxy mode and has higher priority\n"
1550 " # than the previous option.\n"
1551 " listen:my.host.name:443:4443:a different cert\n\n"
1552 " # To make a specific host require or just request a client "
1553 "certificate\n"
1554 " # to authenticate use the following options. This can only be "
1555 "used\n"
1556 " # in httpproxy mode and only after the 'listen' option has "
1557 "been\n"
1558 " # specified. You also have to specify the tunnel listen port.\n"
1559 " clientauth:requesting-client-cert.host.com:443:4443:request\n"
1560 " clientauth:requiring-client-cert.host.com:443:4443:require\n"
1561 " # Proxy WebSocket traffic to the server at 127.0.0.1:9999,\n"
1562 " # instead of the server specified in the 'forward' option.\n"
1563 " websocketserver:127.0.0.1:9999\n",
1564 configFilePath));
1565 return 1;
1566 }
1567
1568 // create a thread pool to handle connections
1569 threads =
1570 PR_CreateThreadPool(INITIAL_THREADS * servers.size(),
1571 MAX_THREADS * servers.size(), DEFAULT_STACKSIZE);
1572 if (!threads) {
1573 LOG_ERROR(("Failed to create thread pool\n"));
1574 return 1;
1575 }
1576
1577 shutdown_lock = PR_NewLock();
1578 if (!shutdown_lock) {
1579 LOG_ERROR(("Failed to create lock\n"));
1580 PR_ShutdownThreadPool(threads);
1581 return 1;
1582 }
1583 shutdown_condvar = PR_NewCondVar(shutdown_lock);
1584 if (!shutdown_condvar) {
1585 LOG_ERROR(("Failed to create condvar\n"));
1586 PR_ShutdownThreadPool(threads);
1587 PR_DestroyLock(shutdown_lock);
1588 return 1;
1589 }
1590
1591 PK11_SetPasswordFunc(password_func);
1592
1593 // Initialize NSS
1594 if (NSS_Init(nssconfigdir.c_str()) != SECSuccess) {
1595 int32_t errorlen = PR_GetErrorTextLength();
1596 if (errorlen) {
1597 auto err = mozilla::MakeUnique<char[]>(errorlen + 1);
1598 PR_GetErrorText(err.get());
1599 LOG_ERROR(("Failed to init NSS: %s", err.get()));
1600 } else {
1601 LOG_ERROR(("Failed to init NSS: Cannot get error from NSPR."));
1602 }
1603 PR_ShutdownThreadPool(threads);
1604 PR_DestroyCondVar(shutdown_condvar);
1605 PR_DestroyLock(shutdown_lock);
1606 return 1;
1607 }
1608
1609 if (NSS_SetDomesticPolicy() != SECSuccess) {
1610 LOG_ERROR(("NSS_SetDomesticPolicy failed\n"));
1611 PR_ShutdownThreadPool(threads);
1612 PR_DestroyCondVar(shutdown_condvar);
1613 PR_DestroyLock(shutdown_lock);
1614 NSS_Shutdown();
1615 return 1;
1616 }
1617
1618 // these values should make NSS use the defaults
1619 if (SSL_ConfigServerSessionIDCache(0, 0, 0, nullptr) != SECSuccess) {
1620 LOG_ERROR(("SSL_ConfigServerSessionIDCache failed\n"));
1621 PR_ShutdownThreadPool(threads);
1622 PR_DestroyCondVar(shutdown_condvar);
1623 PR_DestroyLock(shutdown_lock);
1624 NSS_Shutdown();
1625 return 1;
1626 }
1627
1628 for (auto& server : servers) {
1629 // Not actually using this PRJob*...
1630 // PRJob* server_job =
1631 PR_QueueJob(threads, StartServer, &server, true);
1632 }
1633 // now wait for someone to tell us to quit
1634 PR_Lock(shutdown_lock);
1635 PR_WaitCondVar(shutdown_condvar, PR_INTERVAL_NO_TIMEOUT);
1636 PR_Unlock(shutdown_lock);
1637 shutdown_server = true;
1638 LOG_INFO(("Shutting down...\n"));
1639 // cleanup
1640 PR_ShutdownThreadPool(threads);
1641 PR_JoinThreadPool(threads);
1642 PR_DestroyCondVar(shutdown_condvar);
1643 PR_DestroyLock(shutdown_lock);
1644 if (NSS_Shutdown() == SECFailure) {
1645 LOG_DEBUG(("Leaked NSS objects!\n"));
1646 }
1647
1648 for (auto& server : servers) {
1649 PL_HashTableEnumerateEntries(server.host_cert_table, freeHostCertHashItems,
1650 nullptr);
1651 PL_HashTableEnumerateEntries(server.host_clientauth_table,
1652 freeClientAuthHashItems, nullptr);
1653 PL_HashTableEnumerateEntries(server.host_redir_table,
1654 freeHostRedirHashItems, nullptr);
1655 PL_HashTableEnumerateEntries(server.host_ssl3_table, freeSSL3HashItems,
1656 nullptr);
1657 PL_HashTableEnumerateEntries(server.host_tls1_table, freeTLSHashItems,
1658 nullptr);
1659 PL_HashTableEnumerateEntries(server.host_tls11_table, freeTLSHashItems,
1660 nullptr);
1661 PL_HashTableEnumerateEntries(server.host_tls12_table, freeTLSHashItems,
1662 nullptr);
1663 PL_HashTableEnumerateEntries(server.host_tls13_table, freeTLSHashItems,
1664 nullptr);
1665 PL_HashTableEnumerateEntries(server.host_rc4_table, freeRC4HashItems,
1666 nullptr);
1667 PL_HashTableEnumerateEntries(server.host_failhandshake_table,
1668 freeRC4HashItems, nullptr);
1669 PL_HashTableDestroy(server.host_cert_table);
1670 PL_HashTableDestroy(server.host_clientauth_table);
1671 PL_HashTableDestroy(server.host_redir_table);
1672 PL_HashTableDestroy(server.host_ssl3_table);
1673 PL_HashTableDestroy(server.host_tls1_table);
1674 PL_HashTableDestroy(server.host_tls11_table);
1675 PL_HashTableDestroy(server.host_tls12_table);
1676 PL_HashTableDestroy(server.host_tls13_table);
1677 PL_HashTableDestroy(server.host_rc4_table);
1678 PL_HashTableDestroy(server.host_failhandshake_table);
1679 }
1680
1681 PR_Cleanup();
1682 return 0;
1683 }
1684