1 /*
2  * Copyright (c) 2017, 2019-2020 Paul Mattes.
3  *
4  * Redistribution and use in source and binary forms, with or without
5  * modification, are permitted provided that the following conditions
6  * are met:
7  *     * Redistributions of source code must retain the above copyright
8  *       notice, this list of conditions and the following disclaimer.
9  *     * Redistributions in binary form must reproduce the above copyright
10  *       notice, this list of conditions and the following disclaimer in the
11  *       documentation and/or other materials provided with the distribution.
12  *     * Neither the names of Paul Mattes, Don Russell, Jeff Sparkes, GTRC
13  *       nor their contributors may be used to endorse or promote products
14  *       derived from this software without specific prior written permission.
15  *
16  * THIS SOFTWARE IS PROVIDED BY PAUL MATTES, "AS IS" AND ANY EXPRESS OR
17  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
18  * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
19  * IN NO EVENT SHALL PAUL MATTES BE LIABLE FOR ANY DIRECT, INDIRECT,
20  * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
21  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF
22  * USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
23  * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
24  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
25  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26  *
27  * This module borrows freely from "TLS with Schannel" posted on
28  * http://www.coastrd.com/tls-with-schannel.
29  */
30 
31 /*
32  *	sio_schannel.c
33  *		Secure I/O via the Windows schannel facility.
34  */
35 
36 #include "globals.h"
37 
38 #define SECURITY_WIN32
39 #include <wincrypt.h>
40 #include <wintrust.h>
41 #include <schannel.h>
42 #include <security.h>
43 #include <sspi.h>
44 
45 #include "tls_config.h"
46 
47 #include "indent_s.h"
48 #include "sio.h"
49 #include "sioc.h"
50 #include "tls_passwd_gui.h"
51 #include "trace.h"
52 #include "utils.h"
53 #include "varbuf.h"
54 #include "w3misc.h"
55 #include "winvers.h"
56 
/*
 * Fallback definitions of the TLS 1.1 and TLS 1.2 client protocol bits, for
 * SDK headers that do not define them.
 */
#if !defined(SP_PROT_TLS1_1_CLIENT)
# define SP_PROT_TLS1_1_CLIENT 0x200
#endif

#if !defined(SP_PROT_TLS1_2_CLIENT)
# define SP_PROT_TLS1_2_CLIENT 0x800
#endif

/*
 * TLS protocols to negotiate. Passed to schannel explicitly (via
 * grbitEnabledProtocols) on Windows versions before 10.
 */
#define TLS_PROTOCOLS	\
    (SP_PROT_TLS1_CLIENT | SP_PROT_TLS1_1_CLIENT | SP_PROT_TLS1_2_CLIENT)

/* #define VERBOSE		1 */	/* dump protocol packets in hex */

#define MIN_READ	50		/* small amount to read from the
					   socket at a time, so we are in
					   no danger of reading more than one
					   record */
#define INBUF		(16 * 1024)	/* preliminary input buffer size */

/* Globals */

/* Statics */

/* Per-connection schannel session state. */
typedef struct {
    socket_t sock;			/* socket */
    const char *hostname;		/* server name */
    bool negotiate_pending;		/* true if negotiate pending */
    bool secure_unverified;		/* true if server cert not verified */
    bool negotiated;			/* true if session is negotiated */

    CredHandle client_creds;		/* client credentials */
    bool client_creds_set;		/* true if client_creds is valid */
    bool manual;			/* true if schannel's own server cert
					   validation was turned off
					   (SCH_CRED_MANUAL_CRED_VALIDATION) */

    CtxtHandle context;			/* security context */
    bool context_set;			/* true if context is valid */

    SecPkgContext_StreamSizes sizes;	/* stream sizes */

    char *session_info;			/* session information */
    char *server_cert_info;		/* server cert information */

    char *rcvbuf;			/* receive buffer */
    size_t rcvbuf_len;			/* number of bytes of data currently
					   held in rcvbuf */

    char *prbuf;			/* pending record buffer */
    size_t prbuf_len;			/* pending record buffer size */

    char *sendbuf;			/* send buffer */
} schannel_sio_t;

/*
 * TLS configuration; create_credentials() reads verify_host_cert from it.
 * NOTE(review): assigned outside this view — confirm it is set before any
 * credentials are created.
 */
static tls_config_t *config;

/* The "MY" system certificate store, opened on first use by
   create_credentials() and left open. */
static HCERTSTORE my_cert_store;
110 
111 /* Display the certificate chain. */
112 static void
display_cert_chain(varbuf_t * v,PCCERT_CONTEXT cert)113 display_cert_chain(varbuf_t *v, PCCERT_CONTEXT cert)
114 {
115     CHAR name[1024];
116     PCCERT_CONTEXT current_cert, issuer_cert;
117     PCERT_EXTENSION ext;
118     WCHAR *wcbuf = NULL;
119     DWORD wcsize;
120     DWORD mbsize;
121     char *mbbuf = NULL;
122     int i;
123 
124     /* Display leaf name. */
125     if (!CertNameToStr(cert->dwCertEncodingType,
126 		&cert->pCertInfo->Subject,
127 		CERT_X500_NAME_STR | CERT_NAME_STR_NO_PLUS_FLAG,
128 		name, sizeof(name))) {
129 	int err = GetLastError();
130 	vtrace("CertNameToStr(subject): error 0x%x (%s)\n", err,
131 		win32_strerror(err));
132     } else {
133 	vb_appendf(v, "Subject: %s\n", name);
134     }
135 
136     if (!CertNameToStr(cert->dwCertEncodingType,
137 		&cert->pCertInfo->Issuer,
138 		CERT_X500_NAME_STR | CERT_NAME_STR_NO_PLUS_FLAG,
139 		name, sizeof(name))) {
140 	int err = GetLastError();
141 	vtrace("CertNameToStr(issuer): error 0x%x (%s)\n", err,
142 		win32_strerror(err));
143     } else {
144 	vb_appendf(v, "Issuer: %s\n", name);
145     }
146 
147     /* Display the alternate name. */
148     do {
149 	ext = CertFindExtension(szOID_SUBJECT_ALT_NAME2,
150 		cert->pCertInfo->cExtension,
151 		cert->pCertInfo->rgExtension);
152 	if (ext == NULL) {
153 	    break;
154 	}
155 	if (!CryptFormatObject(X509_ASN_ENCODING, 0, 0, NULL,
156 		    szOID_SUBJECT_ALT_NAME2, ext->Value.pbData,
157 		    ext->Value.cbData, NULL, &wcsize)) {
158 	    break;
159 	}
160 	wcsize *= 4;
161 	wcbuf = (WCHAR *)Malloc(wcsize);
162 	if (!CryptFormatObject(X509_ASN_ENCODING, 0, 0, NULL,
163 		    szOID_SUBJECT_ALT_NAME2, ext->Value.pbData,
164 		    ext->Value.cbData, wcbuf, &wcsize)) {
165 	    break;
166 	}
167 	mbsize = WideCharToMultiByte(CP_ACP, 0, wcbuf, -1, NULL, 0, NULL, NULL);
168 	mbbuf = Malloc(mbsize);
169 	if (WideCharToMultiByte(CP_ACP, 0, wcbuf, -1, mbbuf, mbsize, NULL,
170 		    NULL) != mbsize) {
171 	    break;
172 	}
173 	vb_appendf(v, "Alternate names: %s\n", mbbuf);
174     } while (false);
175     if (wcbuf != NULL) {
176 	Free(wcbuf);
177     }
178     if (mbbuf != NULL) {
179 	Free(mbbuf);
180     }
181 
182     /* Display certificate chain. */
183     current_cert = cert;
184     i = 0;
185     while (current_cert != NULL) {
186 	DWORD verification_flags = 0;
187 
188 	i++;
189 	issuer_cert = CertGetIssuerCertificateFromStore(cert->hCertStore,
190 		current_cert, NULL, &verification_flags);
191 	if (issuer_cert == NULL) {
192             if (current_cert != cert) {
193 		CertFreeCertificateContext(current_cert);
194 	    }
195 	    break;
196 	}
197 
198 	if (!CertNameToStr(issuer_cert->dwCertEncodingType,
199 		    &issuer_cert->pCertInfo->Subject,
200 		    CERT_X500_NAME_STR | CERT_NAME_STR_NO_PLUS_FLAG,
201 		    name, sizeof(name))) {
202 	    int err = GetLastError();
203 	    vtrace("CertNameToStr(subject): error 0x%x (%s)\n", err,
204 		    win32_strerror(err));
205 	} else {
206 	    vb_appendf(v, "CA %d Subject: %s\n", i, name);
207 	}
208 
209 	if (!CertNameToStr(issuer_cert->dwCertEncodingType,
210 		    &issuer_cert->pCertInfo->Issuer,
211 		    CERT_X500_NAME_STR | CERT_NAME_STR_NO_PLUS_FLAG,
212 		    name, sizeof(name))) {
213 	    int err = GetLastError();
214 	    vtrace("CertNameToStr(issuer): error 0x%x (%s)\n", err,
215 		    win32_strerror(err));
216 	} else {
217 	    vb_appendf(v, "CA %d Issuer: %s\n", i, name);
218 	}
219 
220 	if (current_cert != cert) {
221 	    CertFreeCertificateContext(current_cert);
222 	}
223 	current_cert = issuer_cert;
224 	issuer_cert = NULL;
225     }
226 }
227 
228 /* Create security credentials. */
229 static SECURITY_STATUS
create_credentials(LPSTR friendly_name,PCredHandle creds,bool * manual)230 create_credentials(LPSTR friendly_name, PCredHandle creds, bool *manual)
231 {
232     TimeStamp ts_expiry;
233     SECURITY_STATUS status;
234     PCCERT_CONTEXT cert_context = NULL;
235     SCHANNEL_CRED schannel_cred;
236     varbuf_t v;
237     char *s, *t;
238 
239     *manual = false;
240 
241     /* Open the "MY" certificate store, where IE stores client certificates. */
242     if (my_cert_store == NULL) {
243 	my_cert_store = CertOpenSystemStore(0, "MY");
244 	if (my_cert_store == NULL) {
245 	    int err = GetLastError();
246 	    sioc_set_error("CertOpenSystemStore: error 0x%x (%s)\n", err,
247 		    win32_strerror(err));
248 	    return err;
249 	}
250     }
251 
252     /*
253      * If a friendly name name is specified, then attempt to find a client
254      * certificate. Otherwise, just create a NULL credential.
255      */
256     if (friendly_name != NULL) {
257 	for (;;) {
258 	    DWORD nbytes;
259 	    LPTSTR cert_friendly_name;
260 
261 	    /* Find a client certificate with the given friendly name. */
262 	    cert_context = CertFindCertificateInStore(
263 		    my_cert_store,	/* hCertStore */
264 		    X509_ASN_ENCODING,	/* dwCertEncodingType */
265 		    0,			/* dwFindFlags */
266 		    CERT_FIND_ANY,	/* dwFindType */
267 		    NULL,		/* *pvFindPara */
268 		    cert_context);	/* pPrevCertContext */
269 
270 	    if (cert_context == NULL) {
271 		int err = GetLastError();
272 		sioc_set_error("CertFindCertificateInStore: error 0x%x (%s)\n", err,
273 			win32_strerror(err));
274 		return err;
275 	    }
276 
277 	    nbytes = CertGetNameString(cert_context,
278 		    CERT_NAME_FRIENDLY_DISPLAY_TYPE,
279 		    0,
280 		    NULL,
281 		    NULL,
282 		    0);
283 	    cert_friendly_name = Malloc(nbytes);
284 	    nbytes = CertGetNameString(cert_context,
285 		    CERT_NAME_FRIENDLY_DISPLAY_TYPE,
286 		    0,
287 		    NULL,
288 		    cert_friendly_name,
289 		    nbytes);
290 	    if (!strcasecmp(friendly_name, cert_friendly_name)) {
291 		Free(cert_friendly_name);
292 		break;
293 	    }
294 
295 	    Free(cert_friendly_name);
296 	}
297 
298 	/* Display it. */
299 	vtrace("Client certificate:\n");
300 	vb_init(&v);
301 	display_cert_chain(&v, cert_context);
302 	s = vb_consume(&v);
303 	t = indent_s(s);
304 	vtrace("%s", t);
305 	Free(t);
306 	Free(s);
307     }
308 
309     /* Build Schannel credential structure. */
310     memset(&schannel_cred, 0, sizeof(schannel_cred));
311     schannel_cred.dwVersion  = SCHANNEL_CRED_VERSION;
312     if (cert_context != NULL) {
313 	schannel_cred.cCreds = 1;
314 	schannel_cred.paCred = &cert_context;
315     }
316 
317     /* Before Windows 10, you need to specify the protocols explicitly. */
318     if (!IsWindowsVersionOrGreater(10, 0, 0)) {
319 	schannel_cred.grbitEnabledProtocols = TLS_PROTOCOLS;
320     }
321 
322     schannel_cred.dwFlags |= SCH_CRED_NO_DEFAULT_CREDS;
323 
324     /*
325      * If they don't want the host certificate checked, specify manual
326      * validation here and then don't validate.
327      */
328     if (!config->verify_host_cert || is_wine()) {
329 	schannel_cred.dwFlags |= SCH_CRED_MANUAL_CRED_VALIDATION;
330 	*manual = true;
331     } else {
332 	schannel_cred.dwFlags |= SCH_CRED_AUTO_CRED_VALIDATION;
333     }
334 
335     /* Create an SSPI credential. */
336     status = AcquireCredentialsHandle(
337 	    NULL,			/* Name of principal */
338 	    UNISP_NAME,			/* Name of package */
339 	    SECPKG_CRED_OUTBOUND,	/* Flags indicating use */
340 	    NULL,			/* Pointer to logon ID */
341 	    &schannel_cred,		/* Package specific data */
342 	    NULL,			/* Pointer to GetKey() func */
343 	    NULL,			/* Value to pass to GetKey() */
344 	    creds,			/* (out) Cred Handle */
345 	    &ts_expiry);		/* (out) Lifetime (optional) */
346 
347     if (status != SEC_E_OK) {
348 	sioc_set_error("AcquireCredentialsHandle: error 0x%x (%s)\n", status,
349 		win32_strerror(status));
350     }
351 
352     /* Free the certificate context. Schannel has already made its own copy. */
353     if (cert_context != NULL) {
354 	CertFreeCertificateContext(cert_context);
355     }
356 
357     return status;
358 }
359 
360 /* Get new client credentials. */
361 static void
get_new_client_credentials(CredHandle * creds,CtxtHandle * context)362 get_new_client_credentials(CredHandle *creds, CtxtHandle *context)
363 {
364     CredHandle                        new_creds;
365     SecPkgContext_IssuerListInfoEx    issuer_list_info;
366     PCCERT_CHAIN_CONTEXT              chain_context;
367     CERT_CHAIN_FIND_BY_ISSUER_PARA    find_by_issuer_params;
368     PCCERT_CONTEXT                    cert_context;
369     TimeStamp                         expiry;
370     SECURITY_STATUS                   status;
371     SCHANNEL_CRED                     schannel_cred;
372 
373     /* Read the list of trusted issuers from schannel. */
374     status = QueryContextAttributes(context, SECPKG_ATTR_ISSUER_LIST_EX,
375 	    (PVOID)&issuer_list_info);
376     if (status != SEC_E_OK) {
377 	vtrace("QueryContextAttributes: error 0x%x (%s)\n",
378 		(unsigned)status, win32_strerror(status));
379 	return;
380     }
381 
382     /* Enumerate the client certificates. */
383     memset(&find_by_issuer_params, 0, sizeof(find_by_issuer_params));
384 
385     find_by_issuer_params.cbSize = sizeof(find_by_issuer_params);
386     find_by_issuer_params.pszUsageIdentifier = szOID_PKIX_KP_CLIENT_AUTH;
387     find_by_issuer_params.dwKeySpec = 0;
388     find_by_issuer_params.cIssuer   = issuer_list_info.cIssuers;
389     find_by_issuer_params.rgIssuer  = issuer_list_info.aIssuers;
390 
391     chain_context = NULL;
392 
393     while (true) {
394 	/* Find a certificate chain. */
395         chain_context = CertFindChainInStore(
396 		my_cert_store,
397 		X509_ASN_ENCODING,
398 		0,
399 		CERT_CHAIN_FIND_BY_ISSUER,
400 		&find_by_issuer_params,
401 		chain_context);
402 	if (chain_context == NULL) {
403 	    vtrace("CertFindChainInStore: error 0x%x (%s)\n",
404 		    (unsigned)GetLastError(),
405 		    win32_strerror(GetLastError()));
406 	    break;
407 	}
408 
409 	/* Get pointer to leaf certificate context. */
410 	cert_context = chain_context->rgpChain[0]->rgpElement[0]->pCertContext;
411 
412 	/* Create schannel credential. */
413 	schannel_cred.dwVersion = SCHANNEL_CRED_VERSION;
414 	schannel_cred.cCreds = 1;
415 	schannel_cred.paCred = &cert_context;
416 
417 	status = AcquireCredentialsHandle(
418 		NULL,                   /* Name of principal */
419 		UNISP_NAME_A,           /* Name of package */
420 		SECPKG_CRED_OUTBOUND,   /* Flags indicating use */
421 		NULL,                   /* Pointer to logon ID */
422 		&schannel_cred,         /* Package specific data */
423 		NULL,                   /* Pointer to GetKey() func */
424 		NULL,                   /* Value to pass to GetKey() */
425 		&new_creds,             /* (out) Cred Handle */
426 		&expiry);               /* (out) Lifetime (optional) */
427 	if (status != SEC_E_OK) {
428 	    vtrace("AcquireCredentialsHandle: error 0x%x (%s)\n",
429 		    (unsigned)status, win32_strerror(status));
430 	    continue;
431 	}
432 
433 	/* Destroy the old credentials. */
434 	FreeCredentialsHandle(creds);
435 	*creds = new_creds;
436     }
437 }
438 
#if defined(VERBOSE) /*[*/
/*
 * Trace a hex dump of 'length' bytes starting at 'buffer', 16 bytes per
 * line. Each traced line is 'prefix' followed by the offset in hex, the
 * bytes in hex (a ':' follows the eighth byte), and the bytes as text, with
 * non-printing characters and '%' shown as '.'.
 */
static void
print_hex_dump(const char *prefix, int length, unsigned char *buffer)
{
    static char hexdigit[] = "0123456789abcdef";
    int offset = 0;

    while (length != 0) {
	int chunk = (length > 16)? 16: length;
	char line[100];
	char *p;
	int j;

	/* Offset column; the hex column always starts at position 6. */
	sprintf(line, "%4.4x  ", offset);
	p = line + 6;

	/* Hex column, padded with blanks out to 16 bytes' worth. */
	for (j = 0; j < 16; j++) {
	    if (j < chunk) {
		*p++ = hexdigit[buffer[j] >> 4];
		*p++ = hexdigit[buffer[j] & 0x0f];
		*p++ = (j == 7)? ':': ' ';
	    } else {
		*p++ = ' ';
		*p++ = ' ';
		*p++ = ' ';
	    }
	}
	*p++ = ' ';

	/* Text column. */
	for (j = 0; j < chunk; j++) {
	    unsigned char c = buffer[j];

	    *p++ = (c < 32 || c > 126 || c == '%')? '.': (char)c;
	}
	*p = '\0';

	vtrace("%s %s\n", prefix, line);

	length -= chunk;
	buffer += chunk;
	offset += chunk;
    }
}
#endif /*]*/
482 
483 /* Client handshake, second phase. */
484 static SECURITY_STATUS
client_handshake_loop(schannel_sio_t * s,bool do_initial_read)485 client_handshake_loop(
486     schannel_sio_t *s,			/* in, out */
487     bool            do_initial_read)	/* in */
488 {
489     SecBufferDesc   out_buffer, in_buffer;
490     SecBuffer       in_buffers[2], out_buffers[1];
491     DWORD           ssp_i_flags, ssp_o_flags;
492     int             nrw;
493     TimeStamp       expiry;
494     SECURITY_STATUS ret;
495     bool            do_read;
496     int             n2read = MIN_READ;
497 
498     ssp_i_flags =
499 	ISC_REQ_SEQUENCE_DETECT   |
500 	ISC_REQ_REPLAY_DETECT     |
501 	ISC_REQ_CONFIDENTIALITY   |
502 	ISC_RET_EXTENDED_ERROR    |
503 	ISC_REQ_ALLOCATE_MEMORY   |
504 	ISC_REQ_STREAM;
505 
506     do_read = do_initial_read;
507 
508     /* Loop until the handshake is finished or an error occurs. */
509     ret = SEC_I_CONTINUE_NEEDED;
510 
511     while (ret == SEC_I_CONTINUE_NEEDED        ||
512 	   ret == SEC_E_INCOMPLETE_MESSAGE     ||
513 	   ret == SEC_I_INCOMPLETE_CREDENTIALS) {
514 	if (s->rcvbuf_len == 0 || ret == SEC_E_INCOMPLETE_MESSAGE) {
515 	    /* Read data from server. */
516             if (do_read) {
517 
518 		/* Read it. */
519 		nrw = recv(s->sock, s->rcvbuf + s->rcvbuf_len, n2read, 0);
520 		vtrace("TLS: %d/%d bytes of handshake data received\n", nrw,
521 			n2read);
522 		if (nrw == SOCKET_ERROR) {
523 		    ret = WSAGetLastError();
524 		    if (ret != WSAEWOULDBLOCK) {
525 			sioc_set_error("recv: error %d (%s)\n", (int)ret,
526 				win32_strerror(ret));
527 		    }
528 		    break;
529 		} else if (nrw == 0) {
530 		    sioc_set_error("server disconnected during TLS "
531 			    "negotiation");
532 		    ret = WSAECONNABORTED; /* XXX: synthetic error */
533 		    break;
534 	    }
535 #if defined(VERBOSE) /*[*/
536 		print_hex_dump("<enc", nrw,
537 			(unsigned char *)s->rcvbuf + s->rcvbuf_len);
538 #endif /*]*/
539 		s->rcvbuf_len += nrw;
540 	    } else {
541 	      do_read = true;
542 	    }
543 	}
544 
545 	/*
546 	 * Set up the input buffers. Buffer 0 is used to pass in data
547 	 * received from the server. Schannel will consume some or all
548 	 * of this. Leftover data (if any) will be placed in buffer 1 and
549 	 * given a buffer type of SECBUFFER_EXTRA.
550 	 */
551 	in_buffers[0].pvBuffer   = s->rcvbuf;
552 	in_buffers[0].cbBuffer   = (DWORD)s->rcvbuf_len;
553 	in_buffers[0].BufferType = SECBUFFER_TOKEN;
554 
555 	in_buffers[1].pvBuffer   = NULL;
556 	in_buffers[1].cbBuffer   = 0;
557 	in_buffers[1].BufferType = SECBUFFER_EMPTY;
558 
559 	in_buffer.cBuffers       = 2;
560 	in_buffer.pBuffers       = in_buffers;
561 	in_buffer.ulVersion      = SECBUFFER_VERSION;
562 
563 	/*
564 	 * Set up the output buffers. These are initialized to NULL
565 	 * so as to make it less likely we'll attempt to free random
566 	 * garbage later.
567 	 */
568 	out_buffers[0].pvBuffer  = NULL;
569 	out_buffers[0].BufferType= SECBUFFER_TOKEN;
570 	out_buffers[0].cbBuffer  = 0;
571 
572 	out_buffer.cBuffers      = 1;
573 	out_buffer.pBuffers      = out_buffers;
574 	out_buffer.ulVersion     = SECBUFFER_VERSION;
575 
576 	/* Call InitializeSecurityContext. */
577 	ret = InitializeSecurityContext(
578 		&s->client_creds,
579 		&s->context,
580 		NULL,
581 		ssp_i_flags,
582 		0,
583 		0,
584 		&in_buffer,
585 		0,
586 		NULL,
587 		&out_buffer,
588 		&ssp_o_flags,
589 		&expiry);
590 
591 	vtrace("TLS: InitializeSecurityContext -> 0x%x (%s)\n", (unsigned)ret,
592 		win32_strerror(ret));
593 
594 	/*
595 	 * If InitializeSecurityContext was successful (or if the error was
596 	 * one of the special extended ones), send the contends of the output
597 	 * buffer to the server.
598 	 */
599 	if (ret == SEC_E_OK                ||
600 	    ret == SEC_I_CONTINUE_NEEDED   ||
601 	    (FAILED(ret) && (ssp_o_flags & ISC_RET_EXTENDED_ERROR))) {
602 	    if (out_buffers[0].cbBuffer != 0 &&
603 		    out_buffers[0].pvBuffer != NULL) {
604 		nrw = send(s->sock, out_buffers[0].pvBuffer,
605 			out_buffers[0].cbBuffer, 0);
606 		if (nrw == SOCKET_ERROR) {
607 		    ret = WSAGetLastError();
608 		    sioc_set_error("send: error %d (%s)\n", (int)ret,
609 			    win32_strerror(ret));
610 		    FreeContextBuffer(out_buffers[0].pvBuffer);
611 		    break;
612 		}
613 		vtrace("TLS: %d bytes of handshake data sent\n", nrw);
614 #if defined(VERBOSE) /*[*/
615 		print_hex_dump(">enc", nrw, out_buffers[0].pvBuffer);
616 #endif /*]*/
617 
618 		/* Free output buffer. */
619 		FreeContextBuffer(out_buffers[0].pvBuffer);
620 		out_buffers[0].pvBuffer = NULL;
621 	    }
622 	}
623 
624 	/*
625 	 * If InitializeSecurityContext returned SEC_E_INCOMPLETE_MESSAGE,
626 	 * then we need to read more data from the server and try again.
627 	 */
628 	if (ret == SEC_E_INCOMPLETE_MESSAGE) {
629 	    if (in_buffers[1].BufferType == SECBUFFER_MISSING) {
630 		n2read = in_buffers[1].cbBuffer;
631 	    } else {
632 		n2read = MIN_READ;
633 	    }
634 	    continue;
635 	} else {
636 	    n2read = MIN_READ;
637 	}
638 
639 	/*
640 	 * If InitializeSecurityContext returned SEC_E_OK, then the
641 	 * handshake completed successfully.
642 	 */
643 	if (ret == SEC_E_OK) {
644 	    /*
645 	     * If the "extra" buffer contains data, this is encrypted
646 	     * application protocol layer stuff. It needs to be saved. The
647 	     * application layer will later decrypt it with DecryptMessage.
648 	     */
649 	    vtrace("TLS: Handshake was successful\n");
650 
651 	    if (in_buffers[1].BufferType == SECBUFFER_EXTRA) {
652 		/* Interestingly, in_buffers[1].pvBuffer is NULL here. */
653 		vtrace("TLS: %d bytes of encrypted data saved\n",
654 			(int)in_buffers[1].cbBuffer);
655 		memmove(s->rcvbuf,
656 			s->rcvbuf + s->rcvbuf_len - in_buffers[1].cbBuffer,
657 			in_buffers[1].cbBuffer);
658 		s->rcvbuf_len = in_buffers[1].cbBuffer;
659 	    } else {
660 		s->rcvbuf_len = 0;
661 	    }
662 	    break;
663 	}
664 
665 	if (ret == SEC_E_UNSUPPORTED_FUNCTION) {
666 	    vtrace("TLS: SEC_E_UNSUPPORTED_FUNCTION from InitializeSecurityContext -- usually means requested TLS version not supported by server\n");
667 	}
668 
669 	if (ret == SEC_E_WRONG_PRINCIPAL) {
670 	    vtrace("TLS: SEC_E_WRONG_PRINCIPAL from InitializeSecurityContext -- bad server certificate\n");
671 	}
672 
673 	/* Check for fatal error. */
674 	if (FAILED(ret)) {
675 	    sioc_set_error("InitializeSecurityContext: error 0x%x (%s)\n", ret,
676 		    win32_strerror(ret));
677 	    break;
678 	}
679 
680 	/*
681 	 * If InitializeSecurityContext returned SEC_I_INCOMPLETE_CREDENTIALS,
682 	 * then the server just requested client authentication.
683 	 */
684 	if (ret == SEC_I_INCOMPLETE_CREDENTIALS) {
685 	    /*
686 	     * Busted. The server has requested client authentication and
687 	     * the credential we supplied didn't contain a client certificate.
688 	     * This function will read the list of trusted certificate
689 	     * authorities ("issuers") that was received from the server
690 	     * and attempt to find a suitable client certificate that
691 	     * was issued by one of these. If this function is successful,
692 	     * then we will connect using the new certificate. Otherwise,
693 	     * we will attempt to connect anonymously (using our current
694 	     * credentials).
695 	     */
696 	    get_new_client_credentials(&s->client_creds, &s->context);
697 
698 	    /* Go around again. */
699 	    do_read = false;
700 	    ret = SEC_I_CONTINUE_NEEDED;
701 	    continue;
702 	}
703 
704 	if (in_buffers[1].BufferType == SECBUFFER_EXTRA) {
705 	    /*
706 	     * Copy any leftover data from the "extra" buffer, and go around
707 	     * again.
708 	     */
709 	    vtrace("TLS: %lu bytes of extra data copied\n",
710 		    in_buffers[1].cbBuffer);
711 	    memmove(s->rcvbuf,
712 		    s->rcvbuf + s->rcvbuf_len - in_buffers[1].cbBuffer,
713 		    in_buffers[1].cbBuffer);
714 	    s->rcvbuf_len = in_buffers[1].cbBuffer;
715 	} else {
716 	    s->rcvbuf_len = 0;
717 	}
718     }
719 
720     /* Delete the security context in the case of a fatal error. */
721     if (ret != SEC_E_OK && ret != WSAEWOULDBLOCK) {
722 	DeleteSecurityContext(&s->context);
723     } else {
724 	s->context_set = true;
725     }
726 
727     return ret;
728 }
729 
730 /* Client handshake, first phase. */
731 static SECURITY_STATUS
perform_client_handshake(schannel_sio_t * s,LPSTR server_name)732 perform_client_handshake(
733 	schannel_sio_t *s,		/* in, out */
734 	LPSTR		server_name)	/* in */
735 {
736     SecBufferDesc   out_buffer;
737     SecBuffer       out_buffers[1];
738     DWORD           ssp_i_flags, ssp_o_flags;
739     int             data;
740     TimeStamp       expiry;
741     SECURITY_STATUS scRet;
742 
743     ssp_i_flags =
744 	ISC_REQ_SEQUENCE_DETECT   |
745 	ISC_REQ_REPLAY_DETECT     |
746 	ISC_REQ_CONFIDENTIALITY   |
747 	ISC_RET_EXTENDED_ERROR    |
748 	ISC_REQ_ALLOCATE_MEMORY   |
749 	ISC_REQ_STREAM;
750 
751     /* Initiate a ClientHello message and generate a token. */
752     out_buffers[0].pvBuffer   = NULL;
753     out_buffers[0].BufferType = SECBUFFER_TOKEN;
754     out_buffers[0].cbBuffer   = 0;
755 
756     out_buffer.cBuffers  = 1;
757     out_buffer.pBuffers  = out_buffers;
758     out_buffer.ulVersion = SECBUFFER_VERSION;
759 
760     scRet = InitializeSecurityContext(
761 	    &s->client_creds,
762 	    NULL,
763 	    server_name,
764 	    ssp_i_flags,
765 	    0,
766 	    0,
767 	    NULL,
768 	    0,
769 	    &s->context,
770 	    &out_buffer,
771 	    &ssp_o_flags,
772 	    &expiry);
773 
774     if (scRet != SEC_I_CONTINUE_NEEDED) {
775 	sioc_set_error("InitializeSecurityContext: error %d (%s)\n", scRet,
776 		win32_strerror(scRet));
777 	return scRet;
778     }
779 
780     /* Send response to server, if there is one. */
781     if (out_buffers[0].cbBuffer != 0 && out_buffers[0].pvBuffer != NULL) {
782 	data = send(s->sock, out_buffers[0].pvBuffer, out_buffers[0].cbBuffer,
783 		0);
784 	if (data == SOCKET_ERROR) {
785 	    int err = WSAGetLastError();
786 	    sioc_set_error("send: error %d (%s)\n", err, win32_strerror(err));
787 	    FreeContextBuffer(out_buffers[0].pvBuffer);
788 	    DeleteSecurityContext(&s->context);
789 	    return err;
790 	}
791 	vtrace("TLS: %d bytes of handshake data sent\n", data);
792 	FreeContextBuffer(out_buffers[0].pvBuffer);
793 	out_buffers[0].pvBuffer = NULL;
794     }
795 
796     return client_handshake_loop(s, true);
797 }
798 
799 /* Manually verify a server certificate. */
800 static DWORD
verify_server_certificate(PCCERT_CONTEXT server_cert,PSTR server_name,DWORD cert_flags)801 verify_server_certificate(
802 	PCCERT_CONTEXT server_cert,
803 	PSTR server_name,
804 	DWORD cert_flags)
805 {
806     HTTPSPolicyCallbackData  policy_https;
807     CERT_CHAIN_POLICY_PARA   policy_params;
808     CERT_CHAIN_POLICY_STATUS policy_status;
809     CERT_CHAIN_PARA          chain_params;
810     PCCERT_CHAIN_CONTEXT     chain_context = NULL;
811     DWORD                    server_name_size, status;
812     LPSTR rgszUsages[]     = { szOID_PKIX_KP_SERVER_AUTH,
813                                szOID_SERVER_GATED_CRYPTO,
814                                szOID_SGC_NETSCAPE };
815     DWORD usages_count     = sizeof(rgszUsages) / sizeof(LPSTR);
816     PWSTR server_name_wide = NULL;
817 
818     vtrace("TLS: Verifying server certificate manually\n");
819 
820     /* Convert server name to Unicode. */
821     server_name_size = MultiByteToWideChar(CP_ACP, 0, server_name, -1, NULL, 0);
822     server_name_wide = Malloc(server_name_size * sizeof(WCHAR));
823     MultiByteToWideChar(CP_ACP, 0, server_name, -1, server_name_wide,
824 	    server_name_size);
825 
826     /* Build certificate chain. */
827     memset(&chain_params, 0, sizeof(chain_params));
828     chain_params.cbSize = sizeof(chain_params);
829     chain_params.RequestedUsage.dwType = USAGE_MATCH_TYPE_OR;
830     chain_params.RequestedUsage.Usage.cUsageIdentifier = usages_count;
831     chain_params.RequestedUsage.Usage.rgpszUsageIdentifier = rgszUsages;
832 
833     if (!CertGetCertificateChain(
834 		NULL,
835 		server_cert,
836 		NULL,
837 		server_cert->hCertStore,
838 		&chain_params,
839 		0,
840 		NULL,
841 		&chain_context)) {
842 	status = GetLastError();
843 	sioc_set_error("CertGetCertificateChain: error 0x%x (%s)\n", status,
844 		win32_strerror(status));
845 	goto done;
846     }
847 
848     /* Validate certificate chain. */
849     ZeroMemory(&policy_https, sizeof(HTTPSPolicyCallbackData));
850     policy_https.cbStruct       = sizeof(HTTPSPolicyCallbackData);
851     policy_https.dwAuthType     = AUTHTYPE_SERVER;
852     policy_https.fdwChecks      = cert_flags;
853     policy_https.pwszServerName = server_name_wide;
854 
855     memset(&policy_params, 0, sizeof(policy_params));
856     policy_params.cbSize = sizeof(policy_params);
857     policy_params.pvExtraPolicyPara = &policy_https;
858 
859     memset(&policy_status, 0, sizeof(policy_status));
860     policy_status.cbSize = sizeof(policy_status);
861 
862     if (!CertVerifyCertificateChainPolicy(CERT_CHAIN_POLICY_SSL, chain_context,
863 		&policy_params, &policy_status)) {
864 	status = GetLastError();
865 	sioc_set_error("CertVerifyCertificateChainPolicy: error 0x%x (%s)\n",
866 		status, win32_strerror(status));
867 	goto done;
868     }
869 
870     if (policy_status.dwError) {
871 	status = policy_status.dwError;
872 	sioc_set_error("CertVerifyCertificateChainPolicy: error 0x%x (%s)\n",
873 		status, win32_strerror(status));
874 	goto done;
875     }
876 
877     status = SEC_E_OK;
878 
879 done:
880     if (chain_context != NULL) {
881 	CertFreeCertificateChain(chain_context);
882     }
883     if (server_name_wide != NULL) {
884 	Free(server_name_wide);
885     }
886 
887     return status;
888 }
889 
890 /* Display a connection. */
891 static void
display_connection_info(varbuf_t * v,CtxtHandle * context)892 display_connection_info(varbuf_t *v, CtxtHandle *context)
893 {
894     SECURITY_STATUS status;
895     SecPkgContext_ConnectionInfo connection_info;
896 
897     status = QueryContextAttributes(context, SECPKG_ATTR_CONNECTION_INFO,
898 	    (PVOID)&connection_info);
899     if (status != SEC_E_OK) {
900 	vtrace("QueryContextAttributes: error 0x%x (%s)\n", (unsigned)status,
901 		win32_strerror(status));
902 	return;
903     }
904 
905     vb_appendf(v, "Protocol: ");
906     switch (connection_info.dwProtocol) {
907     case SP_PROT_TLS1_CLIENT:
908 	vb_appendf(v, "TLS 1.0\n");
909 	break;
910     case SP_PROT_TLS1_1_CLIENT:
911 	vb_appendf(v, "TLS 1.1\n");
912 	break;
913     case SP_PROT_TLS1_2_CLIENT:
914 	vb_appendf(v, "TLS 1.2\n");
915 	break;
916     case SP_PROT_SSL3_CLIENT:
917 	vb_appendf(v, "SSL 3.0\n");
918 	break;
919     case SP_PROT_SSL2_CLIENT:
920 	vb_appendf(v, "SSL 2.0\n");
921 	break;
922     default:
923 	vb_appendf(v, "0x%x\n", (unsigned)connection_info.dwProtocol);
924 	break;
925     }
926 
927     vb_appendf(v, "Cipher: ");
928     switch (connection_info.aiCipher) {
929     case CALG_3DES:
930 	vb_appendf(v, "Triple DES\n");
931 	break;
932     case CALG_AES:
933 	vb_appendf(v, "AES\n");
934 	break;
935     case CALG_AES_128:
936 	vb_appendf(v, "AES 128\n");
937 	break;
938     case CALG_AES_256:
939 	vb_appendf(v, "AES 256\n");
940 	break;
941     case CALG_DES:
942 	vb_appendf(v, "DES\n");
943 	break;
944     case CALG_RC2:
945 	vb_appendf(v, "RC2\n");
946 	break;
947     case CALG_RC4:
948 	vb_appendf(v, "RC4\n");
949 	break;
950     default:
951 	vb_appendf(v, "0x%x\n", connection_info.aiCipher);
952 	break;
953     }
954 
955     vb_appendf(v, "Cipher strength: %d\n",
956 	    (int)connection_info.dwCipherStrength);
957 
958     vb_appendf(v, "Hash: ");
959     switch (connection_info.aiHash) {
960     case CALG_MD5:
961 	vb_appendf(v, "MD5\n");
962 	break;
963     case CALG_SHA:
964 	vb_appendf(v, "SHA\n");
965 	break;
966     default:
967 	vb_appendf(v, "0x%x\n", connection_info.aiHash);
968 	break;
969     }
970 
971     vb_appendf(v, "Hash strength: %d\n", (int)connection_info.dwHashStrength);
972 
973     vb_appendf(v, "Key exchange: ");
974     switch (connection_info.aiExch) {
975     case CALG_RSA_KEYX:
976     case CALG_RSA_SIGN:
977 	vb_appendf(v, "RSA\n");
978 	break;
979     case CALG_KEA_KEYX:
980 	vb_appendf(v, "KEA\n");
981 	break;
982     case CALG_DH_EPHEM:
983 	vb_appendf(v, "DH Ephemeral\n");
984 	break;
985     default:
986 	vb_appendf(v, "0x%x\n", connection_info.aiExch);
987 	break;
988     }
989 
990     vb_appendf(v, "Key exchange strength: %d\n",
991 	    (int)connection_info.dwExchStrength);
992 }
993 
994 /* Free an sio context. */
995 static void
sio_free(schannel_sio_t * s)996 sio_free(schannel_sio_t *s)
997 {
998     s->sock = INVALID_SOCKET;
999 
1000     /* Free the SSPI context handle. */
1001     if (s->context_set) {
1002         DeleteSecurityContext(&s->context);
1003 	memset(&s->context, 0, sizeof(s->context));
1004         s->context_set = false;
1005     }
1006 
1007     /* Free the SSPI credentials handle. */
1008     if (s->client_creds_set) {
1009         FreeCredentialsHandle(&s->client_creds);
1010 	memset(&s->client_creds, 0, sizeof(s->client_creds));
1011         s->client_creds_set = false;
1012     }
1013 
1014     /* Free the receive buffer. */
1015     if (s->rcvbuf != NULL) {
1016 	Free(s->rcvbuf);
1017 	s->rcvbuf = NULL;
1018     }
1019 
1020     /* Free the record buffer. */
1021     if (s->prbuf != NULL) {
1022 	Free(s->prbuf);
1023 	s->prbuf = NULL;
1024     }
1025 
1026     /* Free the send buffer. */
1027     if (s->sendbuf != NULL) {
1028 	Free(s->sendbuf);
1029 	s->sendbuf = NULL;
1030     }
1031 
1032     /* Free the session info. */
1033     if (s->session_info != NULL) {
1034 	Free(s->session_info);
1035 	s->session_info = NULL;
1036     }
1037 
1038     /* Free the server cert info. */
1039     if (s->server_cert_info != NULL) {
1040 	Free(s->server_cert_info);
1041 	s->server_cert_info = NULL;
1042     }
1043 
1044     Free(s);
1045 }
1046 
/*
 * Report whether this build can do secure (TLS) I/O.
 * The schannel back end is always available, so the answer is constant.
 */
bool
sio_supported(void)
{
    const bool available = true;

    return available;
}
1053 
1054 /*
1055  * Create a new context.
1056  */
1057 sio_init_ret_t
sio_init(tls_config_t * c,const char * password,sio_t * sio_ret)1058 sio_init(tls_config_t *c, const char *password, sio_t *sio_ret)
1059 {
1060     schannel_sio_t *s;
1061 
1062     sioc_error_reset();
1063 
1064     config = c;
1065 
1066     s = (schannel_sio_t *)Malloc(sizeof(schannel_sio_t));
1067     memset(s, 0, sizeof(*s));
1068     s->sock = INVALID_SOCKET;
1069 
1070     /* Create credentials. */
1071     if (create_credentials(config->client_cert, &s->client_creds, &s->manual)) {
1072 	vtrace("TLS: Error creating credentials\n");
1073 	goto fail;
1074     }
1075     s->client_creds_set = true;
1076 
1077     *sio_ret = (sio_t)s;
1078     return SI_SUCCESS;
1079 
1080 fail:
1081     sio_free(s);
1082     return SI_FAILURE;
1083 }
1084 
1085 /*
1086  * Negotiate a TLS connection.
1087  * Returns true for success, false for failure.
1088  * If it returns false, the socket should be disconnected.
1089  *
1090  * Returns 'data' true if there is already protocol data pending.
1091  */
1092 sio_negotiate_ret_t
sio_negotiate(sio_t sio,socket_t sock,const char * hostname,bool * data)1093 sio_negotiate(sio_t sio, socket_t sock, const char *hostname, bool *data)
1094 {
1095     schannel_sio_t *s;
1096     const char *accept_hostname = hostname;
1097     SECURITY_STATUS status;
1098     PCCERT_CONTEXT remote_cert_context = NULL;
1099     size_t recsz;
1100     varbuf_t v;
1101     char *cert_desc = NULL;
1102     size_t sl;
1103 
1104     sioc_error_reset();
1105 
1106     *data = false;
1107     if (sio == NULL) {
1108 	sioc_set_error("NULL sio");
1109 	return SIG_FAILURE;
1110     }
1111     s = (schannel_sio_t *)sio;
1112     if (s->negotiate_pending) {
1113 	if (s->sock == INVALID_SOCKET) {
1114 	    sioc_set_error("Invalid sio (missing socket)");
1115 	    return SIG_FAILURE;
1116 	}
1117 
1118 	/* Continue handshake. */
1119 	status = client_handshake_loop(s, true);
1120     } else {
1121 	if (s->sock != INVALID_SOCKET) {
1122 	    sioc_set_error("Invalid sio (already negotiated)");
1123 	    return SIG_FAILURE;
1124 	}
1125 	s->sock = sock;
1126 	s->hostname = hostname;
1127 
1128 	/*
1129 	 * Allocate the initial receive buffer.
1130 	 * This is temporary, because we can't learn the receive stream sizes
1131 	 * until we have finished negotiating, but we need a receive buffer to
1132 	 * negotiate in the first place.
1133 	 */
1134 	s->rcvbuf = Malloc(INBUF);
1135 
1136 	if (config->accept_hostname != NULL) {
1137 	    if (!strncasecmp(accept_hostname, "DNS:", 4)) {
1138 		accept_hostname = config->accept_hostname + 4;
1139 		sioc_set_error("Empty acceptHostname");
1140 		goto fail;
1141 	    } else if (!strncasecmp(config->accept_hostname, "IP:", 3)) {
1142 		sioc_set_error("Cannot use 'IP:' acceptHostname");
1143 		goto fail;
1144 	    } else if (!strcasecmp(config->accept_hostname, "any")) {
1145 		sioc_set_error("Cannot use 'any' acceptHostname");
1146 		goto fail;
1147 	    } else {
1148 		accept_hostname = config->accept_hostname;
1149 	    }
1150 	}
1151 
1152 	/* Perform handshake. */
1153 	status = perform_client_handshake(s, (LPSTR)accept_hostname);
1154     }
1155 
1156     if (status == WSAEWOULDBLOCK) {
1157 	s->negotiate_pending = true;
1158 	return SIG_WANTMORE;
1159     } else if (status != 0) {
1160 	vtrace("TLS: Error performing handshake\n");
1161 	goto fail;
1162     }
1163 
1164     /* Get the server's certificate. */
1165     status = QueryContextAttributes(&s->context,
1166 	    SECPKG_ATTR_REMOTE_CERT_CONTEXT, (PVOID)&remote_cert_context);
1167     if (status != SEC_E_OK) {
1168 	sioc_set_error("QueryContextAttributes: error 0x%x (%s)",
1169 		(unsigned)status, win32_strerror(status));
1170 	goto fail;
1171     }
1172 
1173     /*
1174      * Get the description of the server certificate chain.
1175      */
1176     vb_init(&v);
1177     display_cert_chain(&v, remote_cert_context);
1178     cert_desc = vb_consume(&v);
1179 
1180     /* Attempt to validate the server certificate. */
1181     if (s->manual && config->verify_host_cert) {
1182 	status = verify_server_certificate(remote_cert_context,
1183 		(LPSTR)accept_hostname, 0);
1184 	if (status) {
1185 	    vtrace("TLS: Error 0x%x authenticating server credentials\n",
1186 		    (unsigned)status);
1187 	    goto fail;
1188 	}
1189     }
1190 
1191     /* Free the server certificate context. */
1192     CertFreeCertificateContext(remote_cert_context);
1193     remote_cert_context = NULL;
1194 
1195     /* Read stream encryption properties. */
1196     status = QueryContextAttributes(&s->context, SECPKG_ATTR_STREAM_SIZES,
1197 	    &s->sizes);
1198     if (status != SEC_E_OK) {
1199 	sioc_set_error("QueryContextAttributes: error 0x%x (%s)",
1200 		(unsigned)status, win32_strerror(status));
1201 	goto fail;
1202     }
1203 
1204     /* Display connection info. */
1205     vb_init(&v);
1206     display_connection_info(&v, &s->context);
1207     s->session_info = vb_consume(&v);
1208     sl = strlen(s->session_info);
1209     if (sl > 0 && s->session_info[sl - 1] == '\n') {
1210 	s->session_info[sl - 1] = '\0';
1211     }
1212 
1213     /* Display server_cert info. */
1214     s->server_cert_info = cert_desc;
1215     cert_desc = NULL;
1216     sl = strlen(s->server_cert_info);
1217     if (sl > 0 && s->server_cert_info[sl - 1] == '\n') {
1218 	s->server_cert_info[sl - 1] = '\0';
1219     }
1220 
1221     /* Account for any extra data. */
1222     if (s->rcvbuf_len > 0) {
1223 	*data = true;
1224     }
1225 
1226     /* Reallocate the receive buffer. */
1227     vtrace("TLS: Sizes: header %d, trailer %d, max message %d\n",
1228 	    (int)s->sizes.cbHeader, (int)s->sizes.cbTrailer,
1229 	    (int)s->sizes.cbMaximumMessage);
1230     recsz = s->sizes.cbHeader + s->sizes.cbTrailer + s->sizes.cbMaximumMessage;
1231     if (recsz > INBUF) {
1232 	s->rcvbuf = Realloc(s->rcvbuf, recsz);
1233     }
1234     s->prbuf = Malloc(s->sizes.cbMaximumMessage);
1235     s->sendbuf = Malloc(s->sizes.cbMaximumMessage);
1236 
1237     /* Success. */
1238     s->secure_unverified = !config->verify_host_cert;
1239     s->negotiated = true;
1240     return SIG_SUCCESS;
1241 
1242 fail:
1243     /* Free the server certificate context. */
1244     if (remote_cert_context != NULL) {
1245         CertFreeCertificateContext(remote_cert_context);
1246         remote_cert_context = NULL;
1247     }
1248 
1249     /* Free the SSPI context handle. */
1250     if (s->context_set) {
1251         DeleteSecurityContext(&s->context);
1252 	memset(&s->context, 0, sizeof(s->context));
1253         s->context_set = false;
1254     }
1255 
1256     /* Free the SSPI credentials handle. */
1257     if (s->client_creds_set) {
1258         FreeCredentialsHandle(&s->client_creds);
1259 	memset(&s->client_creds, 0, sizeof(s->client_creds));
1260         s->client_creds_set = false;
1261     }
1262 
1263     if (cert_desc != NULL) {
1264 	Free(cert_desc);
1265     }
1266 
1267     return SIG_FAILURE;
1268 }
1269 
1270 /*
1271  * Read and decrypt data.
1272  */
1273 static SECURITY_STATUS
read_decrypt(schannel_sio_t * s,CtxtHandle * context)1274 read_decrypt(
1275 	schannel_sio_t *s,	/* in */
1276 	CtxtHandle *context)	/* in */
1277 {
1278     SecBuffer          *data_buffer_ptr, *extra_buffer_ptr;
1279 
1280     SECURITY_STATUS    ret;
1281     SecBufferDesc      message;
1282     SecBuffer          buffers[4];
1283 
1284     int                nr;
1285     int                i;
1286     int                n2read = s->sizes.cbHeader;
1287 
1288     /* Read data from server until done. */
1289     ret = SEC_E_OK;
1290     while (true) {
1291 	data_buffer_ptr = NULL;
1292 	extra_buffer_ptr = NULL;
1293 
1294 	/* Read some data. */
1295 	if (s->rcvbuf_len == 0 || ret == SEC_E_INCOMPLETE_MESSAGE) {
1296 	    /* Get the data */
1297             nr = recv(s->sock, s->rcvbuf + s->rcvbuf_len, n2read, 0);
1298 	    vtrace("TLS: %d/%d bytes of encrypted application data received\n",
1299 		    nr, n2read);
1300             if (nr == SOCKET_ERROR) {
1301 		ret = WSAGetLastError();
1302 		sioc_set_error("recv: error %d (%s)", (int)ret,
1303 			win32_strerror(ret));
1304 		break;
1305             } else if (nr == 0) {
1306 		/* Server disconnected. */
1307 		vtrace("TLS: Server disconnected.\n");
1308 		s->negotiated = false;
1309 		ret = SEC_E_OK;
1310 		break;
1311             } else {
1312 		/* Success. */
1313 #if defined(VERBOSE) /*[*/
1314 		print_hex_dump("<enc", nr,
1315 			(unsigned char *)s->rcvbuf + s->rcvbuf_len);
1316 #endif /*]*/
1317 		s->rcvbuf_len += nr;
1318             }
1319         }
1320 
1321         /* Try to decrypt it. */
1322 	buffers[0].pvBuffer     = s->rcvbuf;
1323 	buffers[0].cbBuffer     = (DWORD)s->rcvbuf_len;
1324 	buffers[0].BufferType   = SECBUFFER_DATA;
1325 	buffers[1].BufferType   = SECBUFFER_EMPTY;
1326 	buffers[2].BufferType   = SECBUFFER_EMPTY;
1327 	buffers[3].BufferType   = SECBUFFER_EMPTY;
1328 
1329 	message.ulVersion       = SECBUFFER_VERSION;
1330 	message.cBuffers        = 4;
1331 	message.pBuffers        = buffers;
1332 	ret = DecryptMessage(context, &message, 0, NULL);
1333 	if (ret == SEC_I_CONTEXT_EXPIRED) {
1334 	    /* Server signalled end of session. Treat it like EOF. */
1335 	    vtrace("TLS: Server signaled end of session.\n");
1336 	    s->negotiated = false;
1337 	    ret = SEC_E_OK;
1338 	    break;
1339 	}
1340         if (ret != SEC_E_OK &&
1341 	    ret != SEC_I_RENEGOTIATE &&
1342 	    ret != SEC_I_CONTEXT_EXPIRED &&
1343 	    ret != SEC_E_INCOMPLETE_MESSAGE) {
1344 	    sioc_set_error("DecryptMessage: error 0x%x (%s)\n", (unsigned)ret,
1345 		    win32_strerror(ret));
1346 	    return ret;
1347 	}
1348 
1349 	if (ret == SEC_E_INCOMPLETE_MESSAGE) {
1350 	    /* Nibble some more. */
1351 	    if (buffers[0].BufferType == SECBUFFER_MISSING) {
1352 		n2read = buffers[0].cbBuffer;
1353 	    } else {
1354 		n2read = s->sizes.cbHeader;
1355 	    }
1356 	    continue;
1357 	} else {
1358 	    n2read = s->sizes.cbHeader;
1359 	}
1360 
1361 	/* Locate data and (optional) extra buffers. */
1362 	data_buffer_ptr  = NULL;
1363 	extra_buffer_ptr = NULL;
1364 	for (i = 1; i < 4; i++) {
1365 	    if (data_buffer_ptr == NULL &&
1366 		    buffers[i].BufferType == SECBUFFER_DATA) {
1367 		data_buffer_ptr  = &buffers[i];
1368 	    }
1369 	    if (extra_buffer_ptr == NULL &&
1370 		    buffers[i].BufferType == SECBUFFER_EXTRA) {
1371 		extra_buffer_ptr = &buffers[i];
1372 	    }
1373 	}
1374 
1375 	/* Check for completion. */
1376         if (data_buffer_ptr != NULL && data_buffer_ptr->cbBuffer) {
1377 	    /* Copy decrypted data to the record buffer. */
1378 	    memcpy(s->prbuf, data_buffer_ptr->pvBuffer,
1379 		    data_buffer_ptr->cbBuffer);
1380 	    s->prbuf_len = data_buffer_ptr->cbBuffer;
1381 	    s->rcvbuf_len = 0;
1382 	    vtrace("TLS: Got %lu decrypted bytes\n", data_buffer_ptr->cbBuffer);
1383 	}
1384 
1385 	/* Move any "extra" data to the receive buffer for next time. */
1386 	if (extra_buffer_ptr != NULL) {
1387 	    vtrace("TLS: %d bytes extra after decryption\n",
1388 		    (int)extra_buffer_ptr->cbBuffer);
1389 	    memmove(s->rcvbuf, extra_buffer_ptr->pvBuffer,
1390 		    extra_buffer_ptr->cbBuffer);
1391 	    s->rcvbuf_len = extra_buffer_ptr->cbBuffer;
1392 	}
1393 
1394 	/*
1395 	 * Check for renegotiation.
1396 	 * It's not clear to me if we can get data back *and* this return code,
1397 	 * of if it's one or the other.
1398 	 */
1399 	if (ret == SEC_I_RENEGOTIATE) {
1400 	    /* The server wants to perform another handshake sequence. */
1401 	    vtrace("TLS: Server requested renegotiate\n");
1402 	    ret = client_handshake_loop(s, false);
1403 	    if (ret != SEC_E_OK) {
1404 		s->negotiated = false;
1405 		return ret;
1406 	    }
1407 	    /* XXX: And if it succeeds? */
1408 	}
1409 
1410 	if (ret == SEC_E_OK) {
1411 	    break;
1412 	}
1413     }
1414 
1415     return ret;
1416 }
1417 
1418 /* Send an encrypted message. */
1419 static SECURITY_STATUS
encrypt_send(schannel_sio_t * s,const char * buf,size_t len)1420 encrypt_send(
1421 	schannel_sio_t *s,
1422 	const char *buf,
1423 	size_t len)
1424 {
1425     SECURITY_STATUS    ret;
1426     SecBufferDesc      message;
1427     SecBuffer          buffers[4];
1428     int                nw;
1429 
1430     /* Copy the data. */
1431     memcpy(s->sendbuf + s->sizes.cbHeader, buf, len);
1432 
1433     /* Encrypt the data. */
1434     buffers[0].pvBuffer     = s->sendbuf;
1435     buffers[0].cbBuffer     = s->sizes.cbHeader;
1436     buffers[0].BufferType   = SECBUFFER_STREAM_HEADER;
1437 
1438     buffers[1].pvBuffer     = s->sendbuf + s->sizes.cbHeader;
1439     buffers[1].cbBuffer     = (DWORD)len;
1440     buffers[1].BufferType   = SECBUFFER_DATA;
1441 
1442     buffers[2].pvBuffer     = s->sendbuf + s->sizes.cbHeader + len;
1443     buffers[2].cbBuffer     = s->sizes.cbTrailer;
1444     buffers[2].BufferType   = SECBUFFER_STREAM_TRAILER;
1445 
1446     buffers[3].pvBuffer     = SECBUFFER_EMPTY;
1447     buffers[3].cbBuffer     = SECBUFFER_EMPTY;
1448     buffers[3].BufferType   = SECBUFFER_EMPTY;
1449 
1450     message.ulVersion       = SECBUFFER_VERSION;
1451     message.cBuffers        = 4;
1452     message.pBuffers        = buffers;
1453     ret = EncryptMessage(&s->context, 0, &message, 0);
1454     if (FAILED(ret)) {
1455 	sioc_set_error("EncryptMessage: error 0x%x (%s)", (unsigned)ret,
1456 		win32_strerror(ret));
1457 	return ret;
1458     }
1459 
1460     /* Send the encrypted data to the server. */
1461     nw = send(s->sock, s->sendbuf,
1462 	    buffers[0].cbBuffer + buffers[1].cbBuffer + buffers[2].cbBuffer, 0);
1463 	vtrace("TLS: %d bytes of encrypted data sent\n", nw);
1464     if (nw < 0) {
1465 	ret = WSAGetLastError();
1466 	sioc_set_error("send: error %d (%s)", (int)ret, win32_strerror(ret));
1467     } else {
1468 #if defined(VERBOSE) /*[*/
1469 	print_hex_dump(">enc", nw, (PBYTE)s->sendbuf);
1470 #endif /*]*/
1471     }
1472 
1473     return ret;
1474 }
1475 
1476 /* Disconnect from the server. */
1477 static SECURITY_STATUS
disconnect_from_server(schannel_sio_t * s)1478 disconnect_from_server(schannel_sio_t *s)
1479 {
1480     PBYTE         outbuf;
1481     DWORD         type, flags, out_flags;
1482     int		  n2w;
1483     int           nw;
1484     SECURITY_STATUS status;
1485     SecBufferDesc out_buffer;
1486     SecBuffer     out_buffers[1];
1487     TimeStamp     expiry;
1488 
1489     /* Notify schannel that we are about to close the connection. */
1490     type = SCHANNEL_SHUTDOWN;
1491 
1492     out_buffers[0].pvBuffer   = &type;
1493     out_buffers[0].BufferType = SECBUFFER_TOKEN;
1494     out_buffers[0].cbBuffer   = sizeof(type);
1495 
1496     out_buffer.cBuffers  = 1;
1497     out_buffer.pBuffers  = out_buffers;
1498     out_buffer.ulVersion = SECBUFFER_VERSION;
1499 
1500     status = ApplyControlToken(&s->context, &out_buffer);
1501     if (FAILED(status)) {
1502 	vtrace("TLS: ApplyControlToken: error 0x%x (%s)\n", (unsigned)status,
1503 		win32_strerror(status));
1504 	return status;
1505     }
1506 
1507     /* Build a TLS close notify message. */
1508     flags = ISC_REQ_SEQUENCE_DETECT   |
1509 		  ISC_REQ_REPLAY_DETECT     |
1510 		  ISC_REQ_CONFIDENTIALITY   |
1511 		  ISC_RET_EXTENDED_ERROR    |
1512 		  ISC_REQ_ALLOCATE_MEMORY   |
1513 		  ISC_REQ_STREAM;
1514 
1515     out_buffers[0].pvBuffer   = NULL;
1516     out_buffers[0].BufferType = SECBUFFER_TOKEN;
1517     out_buffers[0].cbBuffer   = 0;
1518 
1519     out_buffer.cBuffers  = 1;
1520     out_buffer.pBuffers  = out_buffers;
1521     out_buffer.ulVersion = SECBUFFER_VERSION;
1522 
1523     status = InitializeSecurityContext(&s->client_creds,
1524 	    &s->context,
1525 	    NULL,
1526 	    flags,
1527 	    0,
1528 	    0,
1529 	    NULL,
1530 	    0,
1531 	    &s->context,
1532 	    &out_buffer,
1533 	    &out_flags,
1534 	    &expiry);
1535 
1536     if (FAILED(status)) {
1537 	vtrace("TLS: InitializeSecurityContext: error 0x%x (%s)\n",
1538 		(unsigned)status, win32_strerror(status));
1539 	return status;
1540     }
1541 
1542     outbuf = out_buffers[0].pvBuffer;
1543     n2w = out_buffers[0].cbBuffer;
1544 
1545     /* Send the close notify message to the server. */
1546     if (outbuf != NULL && n2w != 0) {
1547 	nw = send(s->sock, (char *)outbuf, n2w, 0);
1548 	if (nw == SOCKET_ERROR) {
1549 	    status = WSAGetLastError();
1550 	    vtrace("TLS: send: error %d (%s)\n", (int)status,
1551 		    win32_strerror(status));
1552 	} else {
1553 	    vtrace("TLS: %d bytes of handshake data sent\n", nw);
1554 #if defined(VERBOSE) /*[*/
1555 	    print_hex_dump(">enc", nw, outbuf);
1556 #endif /*]*/
1557 	}
1558 	FreeContextBuffer(outbuf);
1559     }
1560     vtrace("TLS: Sent TLS disconnect\n");
1561 
1562     return status;
1563 }
1564 
1565 /*
1566  * Read encrypted data from a socket.
1567  * Returns the data length, SIO_EOF for EOF, SIO_FATAL_ERROR for a fatal error,
1568  * SIO_EWOULDBLOCK for incomplete input.
1569  */
1570 int
sio_read(sio_t sio,char * buf,size_t buflen)1571 sio_read(sio_t sio, char *buf, size_t buflen)
1572 {
1573     schannel_sio_t *s;
1574     SECURITY_STATUS ret;
1575 
1576     sioc_error_reset();
1577 
1578     if (sio == NULL) {
1579 	sioc_set_error("NULL sio");
1580 	return SIO_FATAL_ERROR;
1581     }
1582     s = (schannel_sio_t *)sio;
1583     if (s->sock == INVALID_SOCKET) {
1584 	sioc_set_error("Invalid sio (not negotiated)");
1585 	return SIO_FATAL_ERROR;
1586     }
1587 
1588     if (!s->negotiated) {
1589 	return SIO_EOF;
1590     }
1591 
1592     if (s->prbuf_len > 0) {
1593 	size_t copy_len = s->prbuf_len;
1594 
1595 	/* Record already buffered. */
1596 	if (copy_len > buflen) {
1597 	    copy_len = buflen;
1598 	}
1599 	memcpy(buf, s->prbuf, copy_len);
1600 	s->prbuf_len -= copy_len;
1601 	return (int)copy_len;
1602     }
1603 
1604     ret = read_decrypt(s, &s->context);
1605     if (ret != SEC_E_OK) {
1606 	if (ret == WSAEWOULDBLOCK) {
1607 	    return SIO_EWOULDBLOCK;
1608 	}
1609 	s->negotiated = false;
1610 	vtrace("TLS: sio_read: fatal error, ret = 0x%x\n", (unsigned)ret);
1611 	return SIO_FATAL_ERROR;
1612     }
1613 
1614     if (s->prbuf_len == 0) {
1615 	/* End of file. */
1616 	s->negotiated = false;
1617 	return SIO_EOF;
1618     }
1619 
1620     /* Got a complete record. */
1621     return sio_read(sio, buf, buflen);
1622 }
1623 
1624 /*
1625  * Write encrypted data on the socket.
1626  * Returns the data length or SIO_FATAL_ERROR.
1627  */
1628 int
sio_write(sio_t sio,const char * buf,size_t buflen)1629 sio_write(sio_t sio, const char *buf, size_t buflen)
1630 {
1631     schannel_sio_t *s;
1632     size_t len_left = buflen;
1633 
1634     sioc_error_reset();
1635 
1636     if (sio == NULL) {
1637 	sioc_set_error("NULL sio");
1638 	return SIO_FATAL_ERROR;
1639     }
1640     s = (schannel_sio_t *)sio;
1641     if (s->sock == INVALID_SOCKET) {
1642 	sioc_set_error("Invalid sio (not negotiated)");
1643 	return SIO_FATAL_ERROR;
1644     }
1645 
1646     do {
1647 	size_t n2w = len_left;
1648 	SECURITY_STATUS ret;
1649 
1650 	if (n2w > s->sizes.cbMaximumMessage) {
1651 	    n2w = s->sizes.cbMaximumMessage;
1652 	}
1653 	ret = encrypt_send(s, buf, n2w);
1654 	if (ret != SEC_E_OK) {
1655 	    s->negotiated = false;
1656 	    return SIO_FATAL_ERROR;
1657 	}
1658 	len_left -= n2w;
1659 	buf += n2w;
1660     } while (len_left > 0);
1661 
1662     return (int)buflen;
1663 }
1664 
1665 /* Closes the TLS connection. */
1666 void
sio_close(sio_t sio)1667 sio_close(sio_t sio)
1668 {
1669     schannel_sio_t *s;
1670 
1671     if (sio == NULL) {
1672 	return;
1673     }
1674     s = (schannel_sio_t *)sio;
1675     if (s->sock == INVALID_SOCKET) {
1676 	return;
1677     }
1678 
1679     if (s->negotiated) {
1680 	disconnect_from_server(s);
1681     }
1682     sio_free(s);
1683 }
1684 
1685 /*
1686  * Returns true if the current connection is unverified.
1687  */
1688 bool
sio_secure_unverified(sio_t sio)1689 sio_secure_unverified(sio_t sio)
1690 {
1691     schannel_sio_t *s = (schannel_sio_t *)sio;
1692     return s? s->secure_unverified: false;
1693 }
1694 
1695 /*
1696  * Returns a bitmap of the supported options.
1697  */
1698 unsigned
sio_options_supported(void)1699 sio_options_supported(void)
1700 {
1701     return TLS_OPT_CLIENT_CERT;
1702 }
1703 
1704 /*
1705  * Returns session information.
1706  */
1707 const char *
sio_session_info(sio_t sio)1708 sio_session_info(sio_t sio)
1709 {
1710     schannel_sio_t *s = (schannel_sio_t *)sio;
1711     return s? s->session_info: NULL;
1712 }
1713 
1714 /*
1715  * Returns server cert information.
1716  */
1717 const char *
sio_server_cert_info(sio_t sio)1718 sio_server_cert_info(sio_t sio)
1719 {
1720     schannel_sio_t *s = (schannel_sio_t *)sio;
1721     return s? s->server_cert_info: NULL;
1722 }
1723 
/*
 * Name the TLS implementation behind this back end, for display purposes.
 */
const char *
sio_provider(void)
{
    static const char provider_name[] = "Windows Schannel";

    return provider_name;
}
1732