1 /* This Source Code Form is subject to the terms of the Mozilla Public
2  * License, v. 2.0. If a copy of the MPL was not distributed with this
3  * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
4 
/****************************************************************************
 *  SSL client program that tests a server for proper operation of SSL2,    *
 *  SSL3, and TLS. Tests proper certificate installation.                   *
 *                                                                          *
 *  This code was modified from the SSLSample code also kept in the NSS     *
 *  directory.                                                              *
 ****************************************************************************/
12 
13 #include <stdio.h>
14 #include <string.h>
15 
16 #if defined(XP_UNIX)
17 #include <unistd.h>
18 #endif
19 
20 #include "prerror.h"
21 
22 #include "pk11func.h"
23 #include "secmod.h"
24 #include "secitem.h"
25 
26 #include <stdlib.h>
27 #include <errno.h>
28 #include <fcntl.h>
29 #include <stdarg.h>
30 
31 #include "nspr.h"
32 #include "plgetopt.h"
33 #include "prio.h"
34 #include "prnetdb.h"
35 #include "nss.h"
36 #include "secutil.h"
37 #include "ocsp.h"
38 
39 #include "vfyserv.h"
40 
41 #define RD_BUF_SIZE (60 * 1024)
42 
43 extern int ssl3CipherSuites[];
44 extern int numSSL3CipherSuites;
45 
46 GlobalThreadMgr threadMGR;
47 char *certNickname = NULL;
48 char *hostName = NULL;
49 secuPWData pwdata = { PW_NONE, 0 };
50 unsigned short port = 0;
51 PRBool dumpChain;
52 
53 static void
Usage(const char * progName)54 Usage(const char *progName)
55 {
56     PRFileDesc *pr_stderr;
57 
58     pr_stderr = PR_STDERR;
59 
60     PR_fprintf(pr_stderr, "Usage:\n"
61                           "   %s  [-c ] [-o] [-p port] [-d dbdir] [-w password] [-f pwfile]\n"
62                           "   \t\t[-C cipher(s)]  [-l <url> -t <nickname> ] hostname",
63                progName);
64     PR_fprintf(pr_stderr, "\nWhere:\n");
65     PR_fprintf(pr_stderr,
66                "  %-13s dump server cert chain into files\n",
67                "-c");
68     PR_fprintf(pr_stderr,
69                "  %-13s perform server cert OCSP check\n",
70                "-o");
71     PR_fprintf(pr_stderr,
72                "  %-13s server port to be used\n",
73                "-p");
74     PR_fprintf(pr_stderr,
75                "  %-13s use security databases in \"dbdir\"\n",
76                "-d dbdir");
77     PR_fprintf(pr_stderr,
78                "  %-13s key database password\n",
79                "-w password");
80     PR_fprintf(pr_stderr,
81                "  %-13s token password file\n",
82                "-f pwfile");
83     PR_fprintf(pr_stderr,
84                "  %-13s communication cipher list\n",
85                "-C cipher(s)");
86     PR_fprintf(pr_stderr,
87                "  %-13s OCSP responder location. This location is used to\n"
88                "  %-13s check  status  of a server  certificate.  If  not \n"
89                "  %-13s specified, location  will  be taken  from the AIA\n"
90                "  %-13s server certificate extension.\n",
91                "-l url", "", "", "");
92     PR_fprintf(pr_stderr,
93                "  %-13s OCSP Trusted Responder Cert nickname\n\n",
94                "-t nickname");
95 
96     exit(1);
97 }
98 
99 PRFileDesc *
setupSSLSocket(PRNetAddr * addr)100 setupSSLSocket(PRNetAddr *addr)
101 {
102     PRFileDesc *tcpSocket;
103     PRFileDesc *sslSocket;
104     PRSocketOptionData socketOption;
105     PRStatus prStatus;
106     SECStatus secStatus;
107 
108     tcpSocket = PR_NewTCPSocket();
109     if (tcpSocket == NULL) {
110         errWarn("PR_NewTCPSocket");
111     }
112 
113     /* Make the socket blocking. */
114     socketOption.option = PR_SockOpt_Nonblocking;
115     socketOption.value.non_blocking = PR_FALSE;
116 
117     prStatus = PR_SetSocketOption(tcpSocket, &socketOption);
118     if (prStatus != PR_SUCCESS) {
119         errWarn("PR_SetSocketOption");
120         goto loser;
121     }
122 
123     /* Import the socket into the SSL layer. */
124     sslSocket = SSL_ImportFD(NULL, tcpSocket);
125     if (!sslSocket) {
126         errWarn("SSL_ImportFD");
127         goto loser;
128     }
129 
130     /* Set configuration options. */
131     secStatus = SSL_OptionSet(sslSocket, SSL_SECURITY, PR_TRUE);
132     if (secStatus != SECSuccess) {
133         errWarn("SSL_OptionSet:SSL_SECURITY");
134         goto loser;
135     }
136 
137     secStatus = SSL_OptionSet(sslSocket, SSL_HANDSHAKE_AS_CLIENT, PR_TRUE);
138     if (secStatus != SECSuccess) {
139         errWarn("SSL_OptionSet:SSL_HANDSHAKE_AS_CLIENT");
140         goto loser;
141     }
142 
143     /* Set SSL callback routines. */
144     secStatus = SSL_GetClientAuthDataHook(sslSocket,
145                                           (SSLGetClientAuthData)myGetClientAuthData,
146                                           (void *)certNickname);
147     if (secStatus != SECSuccess) {
148         errWarn("SSL_GetClientAuthDataHook");
149         goto loser;
150     }
151 
152     secStatus = SSL_AuthCertificateHook(sslSocket,
153                                         (SSLAuthCertificate)myAuthCertificate,
154                                         (void *)CERT_GetDefaultCertDB());
155     if (secStatus != SECSuccess) {
156         errWarn("SSL_AuthCertificateHook");
157         goto loser;
158     }
159 
160     secStatus = SSL_BadCertHook(sslSocket,
161                                 (SSLBadCertHandler)myBadCertHandler, NULL);
162     if (secStatus != SECSuccess) {
163         errWarn("SSL_BadCertHook");
164         goto loser;
165     }
166 
167     secStatus = SSL_HandshakeCallback(sslSocket,
168                                       myHandshakeCallback,
169                                       NULL);
170     if (secStatus != SECSuccess) {
171         errWarn("SSL_HandshakeCallback");
172         goto loser;
173     }
174 
175     return sslSocket;
176 
177 loser:
178 
179     PR_Close(tcpSocket);
180     return NULL;
181 }
182 
183 const char requestString[] = { "GET /testfile HTTP/1.0\r\n\r\n" };
184 
185 SECStatus
handle_connection(PRFileDesc * sslSocket,int connection)186 handle_connection(PRFileDesc *sslSocket, int connection)
187 {
188     int countRead = 0;
189     PRInt32 numBytes;
190     char *readBuffer;
191 
192     readBuffer = PORT_Alloc(RD_BUF_SIZE);
193     if (!readBuffer) {
194         exitErr("PORT_Alloc");
195     }
196 
197     /* compose the http request here. */
198 
199     numBytes = PR_Write(sslSocket, requestString, strlen(requestString));
200     if (numBytes <= 0) {
201         errWarn("PR_Write");
202         PR_Free(readBuffer);
203         readBuffer = NULL;
204         return SECFailure;
205     }
206 
207     /* read until EOF */
208     while (PR_TRUE) {
209         numBytes = PR_Read(sslSocket, readBuffer, RD_BUF_SIZE);
210         if (numBytes == 0) {
211             break; /* EOF */
212         }
213         if (numBytes < 0) {
214             errWarn("PR_Read");
215             break;
216         }
217         countRead += numBytes;
218     }
219 
220     printSecurityInfo(stderr, sslSocket);
221 
222     PR_Free(readBuffer);
223     readBuffer = NULL;
224 
225     /* Caller closes the socket. */
226 
227     fprintf(stderr,
228             "***** Connection %d read %d bytes total.\n",
229             connection, countRead);
230 
231     return SECSuccess; /* success */
232 }
233 
234 #define BYTE(n, i) (((i) >> ((n)*8)) & 0xff)
235 
236 /* one copy of this function is launched in a separate thread for each
237 ** connection to be made.
238 */
239 SECStatus
do_connects(void * a,int connection)240 do_connects(void *a, int connection)
241 {
242     PRNetAddr *addr = (PRNetAddr *)a;
243     PRFileDesc *sslSocket;
244     PRHostEnt hostEntry;
245     char buffer[PR_NETDB_BUF_SIZE];
246     PRStatus prStatus;
247     PRIntn hostenum;
248     PRInt32 ip;
249     SECStatus secStatus;
250 
251     /* Set up SSL secure socket. */
252     sslSocket = setupSSLSocket(addr);
253     if (sslSocket == NULL) {
254         errWarn("setupSSLSocket");
255         return SECFailure;
256     }
257 
258     secStatus = SSL_SetPKCS11PinArg(sslSocket, &pwdata);
259     if (secStatus != SECSuccess) {
260         errWarn("SSL_SetPKCS11PinArg");
261         return secStatus;
262     }
263 
264     secStatus = SSL_SetURL(sslSocket, hostName);
265     if (secStatus != SECSuccess) {
266         errWarn("SSL_SetURL");
267         return secStatus;
268     }
269 
270     /* Prepare and setup network connection. */
271     prStatus = PR_GetHostByName(hostName, buffer, sizeof(buffer), &hostEntry);
272     if (prStatus != PR_SUCCESS) {
273         errWarn("PR_GetHostByName");
274         return SECFailure;
275     }
276 
277     hostenum = PR_EnumerateHostEnt(0, &hostEntry, port, addr);
278     if (hostenum == -1) {
279         errWarn("PR_EnumerateHostEnt");
280         return SECFailure;
281     }
282 
283     ip = PR_ntohl(addr->inet.ip);
284     fprintf(stderr,
285             "Connecting to host %s (addr %d.%d.%d.%d) on port %d\n",
286             hostName, BYTE(3, ip), BYTE(2, ip), BYTE(1, ip),
287             BYTE(0, ip), PR_ntohs(addr->inet.port));
288 
289     prStatus = PR_Connect(sslSocket, addr, PR_INTERVAL_NO_TIMEOUT);
290     if (prStatus != PR_SUCCESS) {
291         errWarn("PR_Connect");
292         return SECFailure;
293     }
294 
295 /* Established SSL connection, ready to send data. */
296 #if 0
297     secStatus = SSL_ForceHandshake(sslSocket);
298     if (secStatus != SECSuccess) {
299         errWarn("SSL_ForceHandshake");
300         return secStatus;
301     }
302 #endif
303 
304     secStatus = SSL_ResetHandshake(sslSocket, /* asServer */ PR_FALSE);
305     if (secStatus != SECSuccess) {
306         errWarn("SSL_ResetHandshake");
307         prStatus = PR_Close(sslSocket);
308         if (prStatus != PR_SUCCESS) {
309             errWarn("PR_Close");
310         }
311         return secStatus;
312     }
313 
314     secStatus = handle_connection(sslSocket, connection);
315     if (secStatus != SECSuccess) {
316         /* error already printed out in handle_connection */
317         /* errWarn("handle_connection"); */
318         prStatus = PR_Close(sslSocket);
319         if (prStatus != PR_SUCCESS) {
320             errWarn("PR_Close");
321         }
322         return secStatus;
323     }
324 
325     PR_Close(sslSocket);
326     return SECSuccess;
327 }
328 
329 void
client_main(unsigned short port,int connections,const char * hostName)330 client_main(unsigned short port,
331             int connections,
332             const char *hostName)
333 {
334     int i;
335     SECStatus secStatus;
336     PRStatus prStatus;
337     PRInt32 rv;
338     PRNetAddr addr;
339     PRHostEnt hostEntry;
340     char buffer[PR_NETDB_BUF_SIZE];
341 
342     /* Setup network connection. */
343     prStatus = PR_GetHostByName(hostName, buffer, sizeof(buffer), &hostEntry);
344     if (prStatus != PR_SUCCESS) {
345         exitErr("PR_GetHostByName");
346     }
347 
348     rv = PR_EnumerateHostEnt(0, &hostEntry, port, &addr);
349     if (rv < 0) {
350         exitErr("PR_EnumerateHostEnt");
351     }
352 
353     secStatus = launch_thread(&threadMGR, do_connects, &addr, 1);
354     if (secStatus != SECSuccess) {
355         exitErr("launch_thread");
356     }
357 
358     if (connections > 1) {
359         /* wait for the first connection to terminate, then launch the rest. */
360         reap_threads(&threadMGR);
361         /* Start up the connections */
362         for (i = 2; i <= connections; ++i) {
363             secStatus = launch_thread(&threadMGR, do_connects, &addr, i);
364             if (secStatus != SECSuccess) {
365                 errWarn("launch_thread");
366             }
367         }
368     }
369 
370     reap_threads(&threadMGR);
371     destroy_thread_data(&threadMGR);
372 }
373 
374 #define HEXCHAR_TO_INT(c, i)                   \
375     if (((c) >= '0') && ((c) <= '9')) {        \
376         i = (c) - '0';                         \
377     } else if (((c) >= 'a') && ((c) <= 'f')) { \
378         i = (c) - 'a' + 10;                    \
379     } else if (((c) >= 'A') && ((c) <= 'F')) { \
380         i = (c) - 'A' + 10;                    \
381     } else {                                   \
382         Usage(progName);                       \
383     }
384 
385 int
main(int argc,char ** argv)386 main(int argc, char **argv)
387 {
388     char *certDir = NULL;
389     char *progName = NULL;
390     int connections = 1;
391     char *cipherString = NULL;
392     char *respUrl = NULL;
393     char *respCertName = NULL;
394     SECStatus secStatus;
395     PLOptState *optstate;
396     PLOptStatus status;
397     PRBool doOcspCheck = PR_FALSE;
398 
399     /* Call the NSPR initialization routines */
400     PR_Init(PR_SYSTEM_THREAD, PR_PRIORITY_NORMAL, 1);
401 
402     progName = PORT_Strdup(argv[0]);
403 
404     hostName = NULL;
405     optstate = PL_CreateOptState(argc, argv, "C:cd:f:l:n:p:ot:w:");
406     while ((status = PL_GetNextOpt(optstate)) == PL_OPT_OK) {
407         switch (optstate->option) {
408             case 'C':
409                 cipherString = PL_strdup(optstate->value);
410                 break;
411             case 'c':
412                 dumpChain = PR_TRUE;
413                 break;
414             case 'd':
415                 certDir = PL_strdup(optstate->value);
416                 break;
417             case 'l':
418                 respUrl = PL_strdup(optstate->value);
419                 break;
420             case 'p':
421                 port = PORT_Atoi(optstate->value);
422                 break;
423             case 'o':
424                 doOcspCheck = PR_TRUE;
425                 break;
426             case 't':
427                 respCertName = PL_strdup(optstate->value);
428                 break;
429             case 'w':
430                 pwdata.source = PW_PLAINTEXT;
431                 pwdata.data = PORT_Strdup(optstate->value);
432                 break;
433 
434             case 'f':
435                 pwdata.source = PW_FROMFILE;
436                 pwdata.data = PORT_Strdup(optstate->value);
437                 break;
438             case '\0':
439                 hostName = PL_strdup(optstate->value);
440                 break;
441             default:
442                 Usage(progName);
443         }
444     }
445 
446     if (port == 0) {
447         port = 443;
448     }
449 
450     if (port == 0 || hostName == NULL)
451         Usage(progName);
452 
453     if (doOcspCheck &&
454         ((respCertName != NULL && respUrl == NULL) ||
455          (respUrl != NULL && respCertName == NULL))) {
456         SECU_PrintError(progName, "options -l <url> and -t "
457                                   "<responder> must be used together");
458         Usage(progName);
459     }
460 
461     PK11_SetPasswordFunc(SECU_GetModulePassword);
462 
463     /* Initialize the NSS libraries. */
464     if (certDir) {
465         secStatus = NSS_Init(certDir);
466     } else {
467         secStatus = NSS_NoDB_Init(NULL);
468 
469         /* load the builtins */
470         SECMOD_AddNewModule("Builtins",
471                             DLL_PREFIX "nssckbi." DLL_SUFFIX, 0, 0);
472     }
473     if (secStatus != SECSuccess) {
474         exitErr("NSS_Init");
475     }
476     SECU_RegisterDynamicOids();
477 
478     if (doOcspCheck == PR_TRUE) {
479         SECStatus rv;
480         CERTCertDBHandle *handle = CERT_GetDefaultCertDB();
481         if (handle == NULL) {
482             SECU_PrintError(progName, "problem getting certdb handle");
483             goto cleanup;
484         }
485 
486         rv = CERT_EnableOCSPChecking(handle);
487         if (rv != SECSuccess) {
488             SECU_PrintError(progName, "error enabling OCSP checking");
489             goto cleanup;
490         }
491 
492         if (respUrl != NULL) {
493             rv = CERT_SetOCSPDefaultResponder(handle, respUrl,
494                                               respCertName);
495             if (rv != SECSuccess) {
496                 SECU_PrintError(progName,
497                                 "error setting default responder");
498                 goto cleanup;
499             }
500 
501             rv = CERT_EnableOCSPDefaultResponder(handle);
502             if (rv != SECSuccess) {
503                 SECU_PrintError(progName,
504                                 "error enabling default responder");
505                 goto cleanup;
506             }
507         }
508     }
509 
510     /* All cipher suites except RSA_NULL_MD5 are enabled by
511      * Domestic Policy. */
512     NSS_SetDomesticPolicy();
513     SSL_CipherPrefSetDefault(TLS_RSA_WITH_NULL_MD5, PR_TRUE);
514 
515     /* all the SSL2 and SSL3 cipher suites are enabled by default. */
516     if (cipherString) {
517         int ndx;
518 
519         /* disable all the ciphers, then enable the ones we want. */
520         disableAllSSLCiphers();
521 
522         while (0 != (ndx = *cipherString++)) {
523             int cipher = 0;
524 
525             if (ndx == ':') {
526                 int ctmp = 0;
527 
528                 HEXCHAR_TO_INT(*cipherString, ctmp)
529                 cipher |= (ctmp << 12);
530                 cipherString++;
531                 HEXCHAR_TO_INT(*cipherString, ctmp)
532                 cipher |= (ctmp << 8);
533                 cipherString++;
534                 HEXCHAR_TO_INT(*cipherString, ctmp)
535                 cipher |= (ctmp << 4);
536                 cipherString++;
537                 HEXCHAR_TO_INT(*cipherString, ctmp)
538                 cipher |= ctmp;
539                 cipherString++;
540             } else {
541                 if (!isalpha(ndx))
542                     Usage(progName);
543                 ndx = tolower(ndx) - 'a';
544                 if (ndx < numSSL3CipherSuites) {
545                     cipher = ssl3CipherSuites[ndx];
546                 }
547             }
548             if (cipher > 0) {
549                 SSL_CipherPrefSetDefault(cipher, PR_TRUE);
550             } else {
551                 Usage(progName);
552             }
553         }
554     }
555 
556     client_main(port, connections, hostName);
557 
558 cleanup:
559     if (doOcspCheck) {
560         CERTCertDBHandle *handle = CERT_GetDefaultCertDB();
561         CERT_DisableOCSPDefaultResponder(handle);
562         CERT_DisableOCSPChecking(handle);
563     }
564 
565     if (NSS_Shutdown() != SECSuccess) {
566         exit(1);
567     }
568 
569     PR_Cleanup();
570     PORT_Free(progName);
571     return 0;
572 }
573