1 /* This Source Code Form is subject to the terms of the Mozilla Public
2 * License, v. 2.0. If a copy of the MPL was not distributed with this
3 * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
4
5 #include "TLSServer.h"
6
7 #include <stdio.h>
8 #include <string>
9 #include <vector>
10
11 #include "base64.h"
12 #include "mozilla/Move.h"
13 #include "mozilla/Sprintf.h"
14 #include "nspr.h"
15 #include "nss.h"
16 #include "plarenas.h"
17 #include "prenv.h"
18 #include "prerror.h"
19 #include "prnetdb.h"
20 #include "prtime.h"
21 #include "ssl.h"
22
23 namespace mozilla {
24 namespace test {
25
26 static const uint16_t LISTEN_PORT = 8443;
27
28 DebugLevel gDebugLevel = DEBUG_ERRORS;
29 uint16_t gCallbackPort = 0;
30
31 const char DEFAULT_CERT_NICKNAME[] = "default-ee";
32
33 struct Connection {
34 PRFileDesc* mSocket;
35 char mByte;
36
37 explicit Connection(PRFileDesc* aSocket);
38 ~Connection();
39 };
40
Connection(PRFileDesc * aSocket)41 Connection::Connection(PRFileDesc* aSocket) : mSocket(aSocket), mByte(0) {}
42
~Connection()43 Connection::~Connection() {
44 if (mSocket) {
45 PR_Close(mSocket);
46 }
47 }
48
PrintPRError(const char * aPrefix)49 void PrintPRError(const char* aPrefix) {
50 const char* err = PR_ErrorToName(PR_GetError());
51 if (err) {
52 if (gDebugLevel >= DEBUG_ERRORS) {
53 fprintf(stderr, "%s: %s\n", aPrefix, err);
54 }
55 } else {
56 if (gDebugLevel >= DEBUG_ERRORS) {
57 fprintf(stderr, "%s\n", aPrefix);
58 }
59 }
60 }
61
62 template <size_t N>
ReadFileToBuffer(const char * basePath,const char * filename,char (& buf)[N])63 SECStatus ReadFileToBuffer(const char* basePath, const char* filename,
64 char (&buf)[N]) {
65 static_assert(N > 0, "input buffer too small for ReadFileToBuffer");
66 if (snprintf(buf, N - 1, "%s/%s", basePath, filename) == 0) {
67 PrintPRError("snprintf failed");
68 return SECFailure;
69 }
70 UniquePRFileDesc fd(PR_OpenFile(buf, PR_RDONLY, 0));
71 if (!fd) {
72 PrintPRError("PR_Open failed");
73 return SECFailure;
74 }
75 int32_t fileSize = PR_Available(fd.get());
76 if (fileSize < 0) {
77 PrintPRError("PR_Available failed");
78 return SECFailure;
79 }
80 if (static_cast<size_t>(fileSize) > N - 1) {
81 PR_fprintf(PR_STDERR, "file too large - not reading\n");
82 return SECFailure;
83 }
84 int32_t bytesRead = PR_Read(fd.get(), buf, fileSize);
85 if (bytesRead != fileSize) {
86 PrintPRError("PR_Read failed");
87 return SECFailure;
88 }
89 buf[bytesRead] = 0;
90 return SECSuccess;
91 }
92
AddKeyFromFile(const char * basePath,const char * filename)93 SECStatus AddKeyFromFile(const char* basePath, const char* filename) {
94 const char* PRIVATE_KEY_HEADER = "-----BEGIN PRIVATE KEY-----";
95 const char* PRIVATE_KEY_FOOTER = "-----END PRIVATE KEY-----";
96
97 char buf[16384] = {0};
98 SECStatus rv = ReadFileToBuffer(basePath, filename, buf);
99 if (rv != SECSuccess) {
100 return rv;
101 }
102 if (strncmp(buf, PRIVATE_KEY_HEADER, strlen(PRIVATE_KEY_HEADER)) != 0) {
103 PR_fprintf(PR_STDERR, "invalid key - not importing\n");
104 return SECFailure;
105 }
106 const char* bufPtr = buf + strlen(PRIVATE_KEY_HEADER);
107 size_t bufLen = strlen(buf);
108 char base64[16384] = {0};
109 char* base64Ptr = base64;
110 while (bufPtr < buf + bufLen) {
111 if (strncmp(bufPtr, PRIVATE_KEY_FOOTER, strlen(PRIVATE_KEY_FOOTER)) == 0) {
112 break;
113 }
114 if (*bufPtr != '\r' && *bufPtr != '\n') {
115 *base64Ptr = *bufPtr;
116 base64Ptr++;
117 }
118 bufPtr++;
119 }
120
121 unsigned int binLength;
122 UniquePORTString bin(
123 BitwiseCast<char*, unsigned char*>(ATOB_AsciiToData(base64, &binLength)));
124 if (!bin || binLength == 0) {
125 PrintPRError("ATOB_AsciiToData failed");
126 return SECFailure;
127 }
128 UniqueSECItem secitem(::SECITEM_AllocItem(nullptr, nullptr, binLength));
129 if (!secitem) {
130 PrintPRError("SECITEM_AllocItem failed");
131 return SECFailure;
132 }
133 PORT_Memcpy(secitem->data, bin.get(), binLength);
134 UniquePK11SlotInfo slot(PK11_GetInternalKeySlot());
135 if (!slot) {
136 PrintPRError("PK11_GetInternalKeySlot failed");
137 return SECFailure;
138 }
139 if (PK11_NeedUserInit(slot.get())) {
140 if (PK11_InitPin(slot.get(), nullptr, nullptr) != SECSuccess) {
141 PrintPRError("PK11_InitPin failed");
142 return SECFailure;
143 }
144 }
145 SECKEYPrivateKey* privateKey;
146 if (PK11_ImportDERPrivateKeyInfoAndReturnKey(
147 slot.get(), secitem.get(), nullptr, nullptr, true, false, KU_ALL,
148 &privateKey, nullptr) != SECSuccess) {
149 PrintPRError("PK11_ImportDERPrivateKeyInfoAndReturnKey failed");
150 return SECFailure;
151 }
152 SECKEY_DestroyPrivateKey(privateKey);
153 return SECSuccess;
154 }
155
DecodeCertCallback(void * arg,SECItem ** certs,int numcerts)156 SECStatus DecodeCertCallback(void* arg, SECItem** certs, int numcerts) {
157 if (numcerts != 1) {
158 PR_SetError(SEC_ERROR_LIBRARY_FAILURE, 0);
159 return SECFailure;
160 }
161
162 SECItem* certDEROut = static_cast<SECItem*>(arg);
163 return SECITEM_CopyItem(nullptr, certDEROut, *certs);
164 }
165
AddCertificateFromFile(const char * basePath,const char * filename)166 SECStatus AddCertificateFromFile(const char* basePath, const char* filename) {
167 char buf[16384] = {0};
168 SECStatus rv = ReadFileToBuffer(basePath, filename, buf);
169 if (rv != SECSuccess) {
170 return rv;
171 }
172 ScopedAutoSECItem certDER;
173 rv = CERT_DecodeCertPackage(buf, strlen(buf), DecodeCertCallback, &certDER);
174 if (rv != SECSuccess) {
175 PrintPRError("CERT_DecodeCertPackage failed");
176 return rv;
177 }
178 UniqueCERTCertificate cert(CERT_NewTempCertificate(
179 CERT_GetDefaultCertDB(), &certDER, nullptr, false, true));
180 if (!cert) {
181 PrintPRError("CERT_NewTempCertificate failed");
182 return SECFailure;
183 }
184 UniquePK11SlotInfo slot(PK11_GetInternalKeySlot());
185 if (!slot) {
186 PrintPRError("PK11_GetInternalKeySlot failed");
187 return SECFailure;
188 }
189 // The nickname is the filename without '.pem'.
190 std::string nickname(filename, strlen(filename) - 4);
191 rv = PK11_ImportCert(slot.get(), cert.get(), CK_INVALID_HANDLE,
192 nickname.c_str(), false);
193 if (rv != SECSuccess) {
194 PrintPRError("PK11_ImportCert failed");
195 return rv;
196 }
197 return SECSuccess;
198 }
199
LoadCertificatesAndKeys(const char * basePath)200 SECStatus LoadCertificatesAndKeys(const char* basePath) {
201 // The NSS cert DB path could have been specified as "sql:path". Trim off
202 // the leading "sql:" if so.
203 if (strncmp(basePath, "sql:", 4) == 0) {
204 basePath = basePath + 4;
205 }
206
207 UniquePRDir fdDir(PR_OpenDir(basePath));
208 if (!fdDir) {
209 PrintPRError("PR_OpenDir failed");
210 return SECFailure;
211 }
212 // On the B2G ICS emulator, operations taken in AddCertificateFromFile
213 // appear to interact poorly with readdir (more specifically, something is
214 // causing readdir to never return null - it indefinitely loops through every
215 // file in the directory, which causes timeouts). Rather than waste more time
216 // chasing this down, loading certificates and keys happens in two phases:
217 // filename collection and then loading. (This is probably a good
218 // idea anyway because readdir isn't reentrant. Something could change later
219 // such that it gets called as a result of calling AddCertificateFromFile or
220 // AddKeyFromFile.)
221 std::vector<std::string> certificates;
222 std::vector<std::string> keys;
223 for (PRDirEntry* dirEntry = PR_ReadDir(fdDir.get(), PR_SKIP_BOTH); dirEntry;
224 dirEntry = PR_ReadDir(fdDir.get(), PR_SKIP_BOTH)) {
225 size_t nameLength = strlen(dirEntry->name);
226 if (nameLength > 4) {
227 if (strncmp(dirEntry->name + nameLength - 4, ".pem", 4) == 0) {
228 certificates.push_back(dirEntry->name);
229 } else if (strncmp(dirEntry->name + nameLength - 4, ".key", 4) == 0) {
230 keys.push_back(dirEntry->name);
231 }
232 }
233 }
234 SECStatus rv;
235 for (std::string& certificate : certificates) {
236 rv = AddCertificateFromFile(basePath, certificate.c_str());
237 if (rv != SECSuccess) {
238 return rv;
239 }
240 }
241 for (std::string& key : keys) {
242 rv = AddKeyFromFile(basePath, key.c_str());
243 if (rv != SECSuccess) {
244 return rv;
245 }
246 }
247 return SECSuccess;
248 }
249
InitializeNSS(const char * nssCertDBDir)250 SECStatus InitializeNSS(const char* nssCertDBDir) {
251 // Try initializing an existing DB.
252 if (NSS_Init(nssCertDBDir) == SECSuccess) {
253 return SECSuccess;
254 }
255
256 // Create a new DB if there is none...
257 SECStatus rv = NSS_Initialize(nssCertDBDir, nullptr, nullptr, nullptr, 0);
258 if (rv != SECSuccess) {
259 return rv;
260 }
261
262 // ...and load all certificates into it.
263 return LoadCertificatesAndKeys(nssCertDBDir);
264 }
265
SendAll(PRFileDesc * aSocket,const char * aData,size_t aDataLen)266 nsresult SendAll(PRFileDesc* aSocket, const char* aData, size_t aDataLen) {
267 if (gDebugLevel >= DEBUG_VERBOSE) {
268 fprintf(stderr, "sending '%s'\n", aData);
269 }
270
271 while (aDataLen > 0) {
272 int32_t bytesSent =
273 PR_Send(aSocket, aData, aDataLen, 0, PR_INTERVAL_NO_TIMEOUT);
274 if (bytesSent == -1) {
275 PrintPRError("PR_Send failed");
276 return NS_ERROR_FAILURE;
277 }
278
279 aDataLen -= bytesSent;
280 aData += bytesSent;
281 }
282
283 return NS_OK;
284 }
285
ReplyToRequest(Connection * aConn)286 nsresult ReplyToRequest(Connection* aConn) {
287 // For debugging purposes, SendAll can print out what it's sending.
288 // So, any strings we give to it to send need to be null-terminated.
289 char buf[2] = {aConn->mByte, 0};
290 return SendAll(aConn->mSocket, buf, 1);
291 }
292
SetupTLS(Connection * aConn,PRFileDesc * aModelSocket)293 nsresult SetupTLS(Connection* aConn, PRFileDesc* aModelSocket) {
294 PRFileDesc* sslSocket = SSL_ImportFD(aModelSocket, aConn->mSocket);
295 if (!sslSocket) {
296 PrintPRError("SSL_ImportFD failed");
297 return NS_ERROR_FAILURE;
298 }
299 aConn->mSocket = sslSocket;
300
301 SSL_OptionSet(sslSocket, SSL_SECURITY, true);
302 SSL_OptionSet(sslSocket, SSL_HANDSHAKE_AS_CLIENT, false);
303 SSL_OptionSet(sslSocket, SSL_HANDSHAKE_AS_SERVER, true);
304
305 SSL_ResetHandshake(sslSocket, /* asServer */ 1);
306
307 return NS_OK;
308 }
309
ReadRequest(Connection * aConn)310 nsresult ReadRequest(Connection* aConn) {
311 int32_t bytesRead =
312 PR_Recv(aConn->mSocket, &aConn->mByte, 1, 0, PR_INTERVAL_NO_TIMEOUT);
313 if (bytesRead < 0) {
314 PrintPRError("PR_Recv failed");
315 return NS_ERROR_FAILURE;
316 } else if (bytesRead == 0) {
317 PR_SetError(PR_IO_ERROR, 0);
318 PrintPRError("PR_Recv EOF in ReadRequest");
319 return NS_ERROR_FAILURE;
320 } else {
321 if (gDebugLevel >= DEBUG_VERBOSE) {
322 fprintf(stderr, "read '0x%hhx'\n", aConn->mByte);
323 }
324 }
325 return NS_OK;
326 }
327
HandleConnection(PRFileDesc * aSocket,const UniquePRFileDesc & aModelSocket)328 void HandleConnection(PRFileDesc* aSocket,
329 const UniquePRFileDesc& aModelSocket) {
330 Connection conn(aSocket);
331 nsresult rv = SetupTLS(&conn, aModelSocket.get());
332 if (NS_FAILED(rv)) {
333 PR_SetError(PR_INVALID_STATE_ERROR, 0);
334 PrintPRError("PR_Recv failed");
335 exit(1);
336 }
337
338 // TODO: On tests that are expected to fail (e.g. due to a revoked
339 // certificate), the client will close the connection wtihout sending us the
340 // request byte. In those cases, we should keep going. But, in the cases
341 // where the connection is supposed to suceed, we should verify that we
342 // successfully receive the request and send the response.
343 rv = ReadRequest(&conn);
344 if (NS_SUCCEEDED(rv)) {
345 rv = ReplyToRequest(&conn);
346 }
347 }
348
349 // returns 0 on success, non-zero on error
DoCallback()350 int DoCallback() {
351 UniquePRFileDesc socket(PR_NewTCPSocket());
352 if (!socket) {
353 PrintPRError("PR_NewTCPSocket failed");
354 return 1;
355 }
356
357 PRNetAddr addr;
358 PR_InitializeNetAddr(PR_IpAddrLoopback, gCallbackPort, &addr);
359 if (PR_Connect(socket.get(), &addr, PR_INTERVAL_NO_TIMEOUT) != PR_SUCCESS) {
360 PrintPRError("PR_Connect failed");
361 return 1;
362 }
363
364 const char* request = "GET / HTTP/1.0\r\n\r\n";
365 SendAll(socket.get(), request, strlen(request));
366 char buf[4096];
367 memset(buf, 0, sizeof(buf));
368 int32_t bytesRead =
369 PR_Recv(socket.get(), buf, sizeof(buf) - 1, 0, PR_INTERVAL_NO_TIMEOUT);
370 if (bytesRead < 0) {
371 PrintPRError("PR_Recv failed 1");
372 return 1;
373 }
374 if (bytesRead == 0) {
375 fprintf(stderr, "PR_Recv eof 1\n");
376 return 1;
377 }
378 fprintf(stderr, "%s\n", buf);
379 return 0;
380 }
381
ConfigSecureServerWithNamedCert(PRFileDesc * fd,const char * certName,UniqueCERTCertificate * certOut,SSLKEAType * keaOut)382 SECStatus ConfigSecureServerWithNamedCert(
383 PRFileDesc* fd, const char* certName,
384 /*optional*/ UniqueCERTCertificate* certOut,
385 /*optional*/ SSLKEAType* keaOut) {
386 UniqueCERTCertificate cert(PK11_FindCertFromNickname(certName, nullptr));
387 if (!cert) {
388 PrintPRError("PK11_FindCertFromNickname failed");
389 return SECFailure;
390 }
391 // If an intermediate certificate issued the server certificate (rather than
392 // directly by a trust anchor), we want to send it along in the handshake so
393 // we don't encounter unknown issuer errors when that's not what we're
394 // testing.
395 UniqueCERTCertificateList certList;
396 UniqueCERTCertificate issuerCert(
397 CERT_FindCertByName(CERT_GetDefaultCertDB(), &cert->derIssuer));
398 // If we can't find the issuer cert, continue without it.
399 if (issuerCert) {
400 // Sadly, CERTCertificateList does not have a CERT_NewCertificateList
401 // utility function, so we must create it ourselves. This consists
402 // of creating an arena, allocating space for the CERTCertificateList,
403 // and then transferring ownership of the arena to that list.
404 UniquePLArenaPool arena(PORT_NewArena(DER_DEFAULT_CHUNKSIZE));
405 if (!arena) {
406 PrintPRError("PORT_NewArena failed");
407 return SECFailure;
408 }
409 certList.reset(static_cast<CERTCertificateList*>(
410 PORT_ArenaAlloc(arena.get(), sizeof(CERTCertificateList))));
411 if (!certList) {
412 PrintPRError("PORT_ArenaAlloc failed");
413 return SECFailure;
414 }
415 certList->arena = arena.release();
416 // We also have to manually copy the certificates we care about to the
417 // list, because there aren't any utility functions for that either.
418 certList->certs = static_cast<SECItem*>(
419 PORT_ArenaAlloc(certList->arena, 2 * sizeof(SECItem)));
420 if (SECITEM_CopyItem(certList->arena, certList->certs, &cert->derCert) !=
421 SECSuccess) {
422 PrintPRError("SECITEM_CopyItem failed");
423 return SECFailure;
424 }
425 if (SECITEM_CopyItem(certList->arena, certList->certs + 1,
426 &issuerCert->derCert) != SECSuccess) {
427 PrintPRError("SECITEM_CopyItem failed");
428 return SECFailure;
429 }
430 certList->len = 2;
431 }
432
433 UniquePK11SlotInfo slot(PK11_GetInternalKeySlot());
434 if (!slot) {
435 PrintPRError("PK11_GetInternalKeySlot failed");
436 return SECFailure;
437 }
438 UniqueSECKEYPrivateKey key(
439 PK11_FindKeyByDERCert(slot.get(), cert.get(), nullptr));
440 if (!key) {
441 PrintPRError("PK11_FindKeyByDERCert failed");
442 return SECFailure;
443 }
444
445 SSLKEAType certKEA = NSS_FindCertKEAType(cert.get());
446
447 if (SSL_ConfigSecureServerWithCertChain(fd, cert.get(), certList.get(),
448 key.get(), certKEA) != SECSuccess) {
449 PrintPRError("SSL_ConfigSecureServer failed");
450 return SECFailure;
451 }
452
453 if (certOut) {
454 *certOut = Move(cert);
455 }
456
457 if (keaOut) {
458 *keaOut = certKEA;
459 }
460
461 SSL_OptionSet(fd, SSL_NO_CACHE, false);
462 SSL_OptionSet(fd, SSL_ENABLE_SESSION_TICKETS, true);
463
464 return SECSuccess;
465 }
466
StartServer(const char * nssCertDBDir,SSLSNISocketConfig sniSocketConfig,void * sniSocketConfigArg)467 int StartServer(const char* nssCertDBDir, SSLSNISocketConfig sniSocketConfig,
468 void* sniSocketConfigArg) {
469 const char* debugLevel = PR_GetEnv("MOZ_TLS_SERVER_DEBUG_LEVEL");
470 if (debugLevel) {
471 int level = atoi(debugLevel);
472 switch (level) {
473 case DEBUG_ERRORS:
474 gDebugLevel = DEBUG_ERRORS;
475 break;
476 case DEBUG_WARNINGS:
477 gDebugLevel = DEBUG_WARNINGS;
478 break;
479 case DEBUG_VERBOSE:
480 gDebugLevel = DEBUG_VERBOSE;
481 break;
482 default:
483 PrintPRError("invalid MOZ_TLS_SERVER_DEBUG_LEVEL");
484 return 1;
485 }
486 }
487
488 const char* callbackPort = PR_GetEnv("MOZ_TLS_SERVER_CALLBACK_PORT");
489 if (callbackPort) {
490 gCallbackPort = atoi(callbackPort);
491 }
492
493 if (InitializeNSS(nssCertDBDir) != SECSuccess) {
494 PR_fprintf(PR_STDERR, "InitializeNSS failed");
495 return 1;
496 }
497
498 if (NSS_SetDomesticPolicy() != SECSuccess) {
499 PrintPRError("NSS_SetDomesticPolicy failed");
500 return 1;
501 }
502
503 if (SSL_ConfigServerSessionIDCache(0, 0, 0, nullptr) != SECSuccess) {
504 PrintPRError("SSL_ConfigServerSessionIDCache failed");
505 return 1;
506 }
507
508 UniquePRFileDesc serverSocket(PR_NewTCPSocket());
509 if (!serverSocket) {
510 PrintPRError("PR_NewTCPSocket failed");
511 return 1;
512 }
513
514 PRSocketOptionData socketOption;
515 socketOption.option = PR_SockOpt_Reuseaddr;
516 socketOption.value.reuse_addr = true;
517 PR_SetSocketOption(serverSocket.get(), &socketOption);
518
519 PRNetAddr serverAddr;
520 PR_InitializeNetAddr(PR_IpAddrLoopback, LISTEN_PORT, &serverAddr);
521 if (PR_Bind(serverSocket.get(), &serverAddr) != PR_SUCCESS) {
522 PrintPRError("PR_Bind failed");
523 return 1;
524 }
525
526 if (PR_Listen(serverSocket.get(), 1) != PR_SUCCESS) {
527 PrintPRError("PR_Listen failed");
528 return 1;
529 }
530
531 UniquePRFileDesc rawModelSocket(PR_NewTCPSocket());
532 if (!rawModelSocket) {
533 PrintPRError("PR_NewTCPSocket failed for rawModelSocket");
534 return 1;
535 }
536
537 UniquePRFileDesc modelSocket(SSL_ImportFD(nullptr, rawModelSocket.release()));
538 if (!modelSocket) {
539 PrintPRError("SSL_ImportFD of rawModelSocket failed");
540 return 1;
541 }
542
543 if (SSL_SNISocketConfigHook(modelSocket.get(), sniSocketConfig,
544 sniSocketConfigArg) != SECSuccess) {
545 PrintPRError("SSL_SNISocketConfigHook failed");
546 return 1;
547 }
548
549 // We have to configure the server with a certificate, but it's not one
550 // we're actually going to end up using. In the SNI callback, we pick
551 // the right certificate for the connection.
552 if (ConfigSecureServerWithNamedCert(modelSocket.get(), DEFAULT_CERT_NICKNAME,
553 nullptr, nullptr) != SECSuccess) {
554 return 1;
555 }
556
557 if (gCallbackPort != 0) {
558 if (DoCallback()) {
559 return 1;
560 }
561 }
562
563 while (true) {
564 PRNetAddr clientAddr;
565 PRFileDesc* clientSocket =
566 PR_Accept(serverSocket.get(), &clientAddr, PR_INTERVAL_NO_TIMEOUT);
567 HandleConnection(clientSocket, modelSocket);
568 }
569
570 return 0;
571 }
572
573 } // namespace test
574 } // namespace mozilla
575