xref: /reactos/base/applications/mstsc/tcp.c (revision 84ccccab)
1 /* -*- c-basic-offset: 8 -*-
2    rdesktop: A Remote Desktop Protocol client.
3    Protocol services - TCP layer
4    Copyright (C) Matthew Chapman <matthewc.unsw.edu.au> 1999-2008
5    Copyright 2005-2011 Peter Astrand <astrand@cendio.se> for Cendio AB
6    Copyright 2012-2013 Henrik Andersson <hean01@cendio.se> for Cendio AB
7 
8    This program is free software: you can redistribute it and/or modify
9    it under the terms of the GNU General Public License as published by
10    the Free Software Foundation, either version 3 of the License, or
11    (at your option) any later version.
12 
13    This program is distributed in the hope that it will be useful,
14    but WITHOUT ANY WARRANTY; without even the implied warranty of
15    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16    GNU General Public License for more details.
17 
18    You should have received a copy of the GNU General Public License
19    along with this program.  If not, see <http://www.gnu.org/licenses/>.
20 */
21 
22 #include "precomp.h"
23 #ifdef WITH_SSL
24 #include <sspi.h>
25 #include <schannel.h>
26 #endif
27 
/* Portability shims for the socket calls whose spelling differs between
   Winsock and BSD sockets. */
#ifdef _WIN32
#define socklen_t int
#define TCP_CLOSE(_sck) closesocket(_sck)
/* On Windows no errno-based text is produced; a fixed placeholder is used. */
#define TCP_STRERROR "tcp error"
/* NOTE: the unit of _n differs per platform (Sleep: ms, sleep: s). */
#define TCP_SLEEP(_n) Sleep(_n)
/* True when the last failing socket call only reported "would block". */
#define TCP_BLOCKS (WSAGetLastError() == WSAEWOULDBLOCK)
#else /* _WIN32 */
#define TCP_CLOSE(_sck) close(_sck)
#define TCP_STRERROR strerror(errno)
#define TCP_SLEEP(_n) sleep(_n)
#define TCP_BLOCKS (errno == EWOULDBLOCK)
#endif /* _WIN32 */

/* Fallback for headers lacking INADDR_NONE, the failure value of
   inet_addr() that tcp_connect() tests against. */
#ifndef INADDR_NONE
#define INADDR_NONE ((unsigned long) -1)
#endif

/* Number of outgoing buffers tcp_init() rotates through.  With smart-card
   support several are kept, presumably so a buffer handed out to one
   thread is not immediately reused by another -- TODO confirm. */
#ifdef WITH_SCARD
#define STREAM_COUNT 8
#else
#define STREAM_COUNT 1
#endif

#ifdef WITH_SSL
/* Per-connection TLS (Schannel) state. */
typedef struct
{
	/* Security context produced by the handshake; used by Encrypt/DecryptMessage. */
	CtxtHandle ssl_ctx;
	/* Record header / max message / trailer sizes queried after the handshake. */
	SecPkgContext_StreamSizes ssl_sizes;
	/* Scratch buffer sized cbHeader + cbMaximumMessage + cbTrailer (one record). */
	char *ssl_buf;
	/* Ciphertext received but not yet decrypted; consumed first by read_ssl_chunk(). */
	char *extra_buf;
	size_t extra_len;
	/* Read cursor into decrypted plaintext that did not fit the caller's buffer. */
	char *peek_msg;
	/* Start of that plaintext allocation (the pointer that gets freed). */
	char *peek_msg_mem;
	/* Bytes still available at peek_msg. */
	size_t peek_len;
	/* Not used by the code in this file. */
	DWORD security_flags;
} netconn_t;
/* Host name passed to Schannel as the TLS target name. */
static char * g_ssl_server = NULL;
/* Set once g_ssl has been pointed at g_ssl1 and its context invalidated. */
static RD_BOOL g_ssl_initialized = False;
/* Credential handles: cred_handle uses the provider defaults, compat_cred_handle
   excludes TLS 1.1+ (see ensure_cred_handle()). */
static RD_BOOL cred_handle_initialized = False;
static RD_BOOL have_compat_cred_handle = False;
static SecHandle cred_handle, compat_cred_handle;
/* Storage for the single TLS connection; g_ssl points here while TLS is in
   use and is NULL for plain TCP. */
static netconn_t g_ssl1;
static netconn_t *g_ssl = NULL;
#endif /* WITH_SSL */
/* Connected socket (set to -1 by tcp_disconnect()). */
static int g_sock;
/* Receive buffer used by tcp_recv() when called with s == NULL. */
static struct stream g_in;
/* Ring of send buffers handed out by tcp_init(). */
static struct stream g_out[STREAM_COUNT];
int g_tcp_port_rdp = TCP_PORT_RDP;
extern RD_BOOL g_user_quit;
extern RD_BOOL g_network_error;
extern RD_BOOL g_reconnect_loop;
79 
80 /* Initialise TCP transport data packet */
81 STREAM
82 tcp_init(uint32 maxlen)
83 {
84 	static int cur_stream_id = 0;
85 	STREAM result = NULL;
86 
87 #ifdef WITH_SCARD
88 	scard_lock(SCARD_LOCK_TCP);
89 #endif
90 	result = &g_out[cur_stream_id];
91 	cur_stream_id = (cur_stream_id + 1) % STREAM_COUNT;
92 
93 	if (maxlen > result->size)
94 	{
95 		result->data = (uint8 *) xrealloc(result->data, maxlen);
96 		result->size = maxlen;
97 	}
98 
99 	result->p = result->data;
100 	result->end = result->data + result->size;
101 #ifdef WITH_SCARD
102 	scard_unlock(SCARD_LOCK_TCP);
103 #endif
104 	return result;
105 }
106 
107 #ifdef WITH_SSL
108 
109 RD_BOOL send_ssl_chunk(const void *msg, size_t size)
110 {
111 	SecBuffer bufs[4] = {
112 		{g_ssl->ssl_sizes.cbHeader, SECBUFFER_STREAM_HEADER, g_ssl->ssl_buf},
113 		{size,  SECBUFFER_DATA, g_ssl->ssl_buf+g_ssl->ssl_sizes.cbHeader},
114 		{g_ssl->ssl_sizes.cbTrailer, SECBUFFER_STREAM_TRAILER, g_ssl->ssl_buf+g_ssl->ssl_sizes.cbHeader+size},
115 		{0, SECBUFFER_EMPTY, NULL}
116 	};
117 	SecBufferDesc buf_desc = {SECBUFFER_VERSION, sizeof(bufs)/sizeof(*bufs), bufs};
118 	SECURITY_STATUS res;
119 	int tcp_res;
120 
121 	memcpy(bufs[1].pvBuffer, msg, size);
122 	res = EncryptMessage(&g_ssl->ssl_ctx, 0, &buf_desc, 0);
123 	if (res != SEC_E_OK)
124 	{
125 		error("EncryptMessage failed: %d\n", res);
126 		return False;
127 	}
128 
129 	tcp_res = send(g_sock, g_ssl->ssl_buf, bufs[0].cbBuffer+bufs[1].cbBuffer+bufs[2].cbBuffer, 0);
130 	if (tcp_res < 1)
131 	{
132 		error("send failed: %d (%s)\n", tcp_res, TCP_STRERROR);
133 		return False;
134 	}
135 
136 	return True;
137 }
138 
139 DWORD read_ssl_chunk(void *buf, SIZE_T buf_size, BOOL blocking, SIZE_T *ret_size, BOOL *eof)
140 {
141 	const SIZE_T ssl_buf_size = g_ssl->ssl_sizes.cbHeader+g_ssl->ssl_sizes.cbMaximumMessage+g_ssl->ssl_sizes.cbTrailer;
142 	SecBuffer bufs[4];
143 	SecBufferDesc buf_desc = {SECBUFFER_VERSION, sizeof(bufs)/sizeof(*bufs), bufs};
144 	SSIZE_T size, buf_len = 0;
145 	int i;
146 	SECURITY_STATUS res;
147 
148 	//assert(conn->extra_len < ssl_buf_size);
149 
150 	if (g_ssl->extra_len)
151 	{
152 		memcpy(g_ssl->ssl_buf, g_ssl->extra_buf, g_ssl->extra_len);
153 		buf_len = g_ssl->extra_len;
154 		g_ssl->extra_len = 0;
155 		xfree(g_ssl->extra_buf);
156 		g_ssl->extra_buf = NULL;
157 	}
158 
159 	size = recv(g_sock, g_ssl->ssl_buf+buf_len, ssl_buf_size-buf_len, 0);
160 	if (size < 0)
161 	{
162 		if (!buf_len)
163 		{
164 			if (size == -1 && TCP_BLOCKS)
165 			{
166 				return WSAEWOULDBLOCK;
167 			}
168 			error("recv failed: %d (%s)\n", size, TCP_STRERROR);
169 			return -1;//ERROR_INTERNET_CONNECTION_ABORTED;
170 		}
171 	}
172 	else
173 	{
174 		buf_len += size;
175 	}
176 
177 	if (!buf_len)
178 	{
179 		*eof = TRUE;
180 		*ret_size = 0;
181 		return ERROR_SUCCESS;
182 	}
183 
184 	*eof = FALSE;
185 
186 	do
187 	{
188 		memset(bufs, 0, sizeof(bufs));
189 		bufs[0].BufferType = SECBUFFER_DATA;
190 		bufs[0].cbBuffer = buf_len;
191 		bufs[0].pvBuffer = g_ssl->ssl_buf;
192 
193 		res = DecryptMessage(&g_ssl->ssl_ctx, &buf_desc, 0, NULL);
194 		switch (res)
195 		{
196 		case SEC_E_OK:
197 			break;
198 		case SEC_I_CONTEXT_EXPIRED:
199 			*eof = TRUE;
200 			return ERROR_SUCCESS;
201 		case SEC_E_INCOMPLETE_MESSAGE:
202 			//assert(buf_len < ssl_buf_size);
203 
204 			size = recv(g_sock, g_ssl->ssl_buf+buf_len, ssl_buf_size-buf_len, 0);
205 			if (size < 1)
206 			{
207 				if (size == -1 && TCP_BLOCKS)
208 				{
209 					/* FIXME: Optimize extra_buf usage. */
210 					g_ssl->extra_buf = xmalloc(buf_len);
211 					if (!g_ssl->extra_buf)
212 						return ERROR_NOT_ENOUGH_MEMORY;
213 
214 					g_ssl->extra_len = buf_len;
215 					memcpy(g_ssl->extra_buf, g_ssl->ssl_buf, g_ssl->extra_len);
216 					return WSAEWOULDBLOCK;
217 				}
218 
219 				error("recv failed: %d (%s)\n", size, TCP_STRERROR);
220 				return -1;//ERROR_INTERNET_CONNECTION_ABORTED;
221 			}
222 
223 			buf_len += size;
224 			continue;
225 		default:
226 			error("DecryptMessage failed: %d\n", res);
227 			return -1;//ERROR_INTERNET_CONNECTION_ABORTED;
228 		}
229 	}
230 	while (res != SEC_E_OK);
231 
232 	for (i=0; i < sizeof(bufs)/sizeof(*bufs); i++)
233 	{
234 		if (bufs[i].BufferType == SECBUFFER_DATA)
235 		{
236 			size = min(buf_size, bufs[i].cbBuffer);
237 			memcpy(buf, bufs[i].pvBuffer, size);
238 			if (size < bufs[i].cbBuffer)
239 			{
240 				//assert(!conn->peek_len);
241 				g_ssl->peek_msg_mem = g_ssl->peek_msg = xmalloc(bufs[i].cbBuffer - size);
242 				if (!g_ssl->peek_msg)
243 					return ERROR_NOT_ENOUGH_MEMORY;
244 				g_ssl->peek_len = bufs[i].cbBuffer-size;
245 				memcpy(g_ssl->peek_msg, (char*)bufs[i].pvBuffer+size, g_ssl->peek_len);
246 			}
247 
248 			*ret_size = size;
249 		}
250 	}
251 
252 	for (i=0; i < sizeof(bufs)/sizeof(*bufs); i++)
253 	{
254 		if (bufs[i].BufferType == SECBUFFER_EXTRA)
255 		{
256 			g_ssl->extra_buf = xmalloc(bufs[i].cbBuffer);
257 			if (!g_ssl->extra_buf)
258 				return ERROR_NOT_ENOUGH_MEMORY;
259 
260 			g_ssl->extra_len = bufs[i].cbBuffer;
261 			memcpy(g_ssl->extra_buf, bufs[i].pvBuffer, g_ssl->extra_len);
262 		}
263 	}
264 
265 	return ERROR_SUCCESS;
266 }
267 #endif /* WITH_SSL */
268 /* Send TCP transport data packet */
269 void
270 tcp_send(STREAM s)
271 {
272 	int length = s->end - s->data;
273 	int sent, total = 0;
274 
275 	if (g_network_error == True)
276 		return;
277 
278 #ifdef WITH_SCARD
279 	scard_lock(SCARD_LOCK_TCP);
280 #endif
281 	while (total < length)
282 	{
283 #ifdef WITH_SSL
284 		if (g_ssl)
285 		{
286 			const BYTE *ptr = s->data + total;
287 			size_t chunk_size;
288 
289 			sent = 0;
290 
291 			while (length - total)
292 			{
293 				chunk_size = min(length - total, g_ssl->ssl_sizes.cbMaximumMessage);
294 				if (!send_ssl_chunk(ptr, chunk_size))
295 				{
296 #ifdef WITH_SCARD
297 					scard_unlock(SCARD_LOCK_TCP);
298 #endif
299 
300 					//error("send_ssl_chunk: %d (%s)\n", sent, TCP_STRERROR);
301 					g_network_error = True;
302 					return;
303 				}
304 
305 				sent += chunk_size;
306 				ptr += chunk_size;
307 				length -= chunk_size;
308 			}
309 		}
310 		else
311 		{
312 #endif /* WITH_SSL */
313 			sent = send(g_sock, (const char *)s->data + total, length - total, 0);
314 			if (sent <= 0)
315 			{
316 				if (sent == -1 && TCP_BLOCKS)
317 				{
318 					TCP_SLEEP(0);
319 					sent = 0;
320 				}
321 				else
322 				{
323 #ifdef WITH_SCARD
324 					scard_unlock(SCARD_LOCK_TCP);
325 #endif
326 
327 					error("send: %d (%s)\n", sent, TCP_STRERROR);
328 					g_network_error = True;
329 					return;
330 				}
331 			}
332 #ifdef WITH_SSL
333 		}
334 #endif /* WITH_SSL */
335 		total += sent;
336 	}
337 #ifdef WITH_SCARD
338 	scard_unlock(SCARD_LOCK_TCP);
339 #endif
340 }
341 
342 /* Receive a message on the TCP layer */
343 STREAM
344 tcp_recv(STREAM s, uint32 length)
345 {
346 	uint32 new_length, end_offset, p_offset;
347 	int rcvd = 0;
348 
349 	if (g_network_error == True)
350 		return NULL;
351 
352 	if (s == NULL)
353 	{
354 		/* read into "new" stream */
355 		if (length > g_in.size)
356 		{
357 			g_in.data = (uint8 *) xrealloc(g_in.data, length);
358 			g_in.size = length;
359 		}
360 		g_in.end = g_in.p = g_in.data;
361 		s = &g_in;
362 	}
363 	else
364 	{
365 		/* append to existing stream */
366 		new_length = (s->end - s->data) + length;
367 		if (new_length > s->size)
368 		{
369 			p_offset = s->p - s->data;
370 			end_offset = s->end - s->data;
371 			s->data = (uint8 *) xrealloc(s->data, new_length);
372 			s->size = new_length;
373 			s->p = s->data + p_offset;
374 			s->end = s->data + end_offset;
375 		}
376 	}
377 
378 	while (length > 0)
379 	{
380 #ifdef WITH_SSL
381 		if (!g_ssl)
382 #endif /* WITH_SSL */
383 		{
384 			if (!ui_select(g_sock))
385 			{
386 				/* User quit */
387 				g_user_quit = True;
388 				return NULL;
389 			}
390 		}
391 
392 #ifdef WITH_SSL
393 		if (g_ssl)
394 		{
395 			SIZE_T size = 0;
396 			BOOL eof;
397 			DWORD res;
398 
399 			if (g_ssl->peek_msg)
400 			{
401 				size = min(length, g_ssl->peek_len);
402 				memcpy(s->end, g_ssl->peek_msg, size);
403 				g_ssl->peek_len -= size;
404 				g_ssl->peek_msg += size;
405 				s->end += size;
406 
407 				if (!g_ssl->peek_len)
408 				{
409 					xfree(g_ssl->peek_msg_mem);
410 					g_ssl->peek_msg_mem = g_ssl->peek_msg = NULL;
411 				}
412 
413 				return s;
414 			}
415 
416 			do
417 			{
418 				res = read_ssl_chunk((BYTE*)s->end, length, TRUE, &size, &eof);
419 				if (res != ERROR_SUCCESS)
420 				{
421 					if (res == WSAEWOULDBLOCK)
422 					{
423 						if (size)
424 						{
425  							res = ERROR_SUCCESS;
426 						}
427 					}
428 					else
429 					{
430 						error("read_ssl_chunk: %d (%s)\n", res, TCP_STRERROR);
431 						g_network_error = True;
432 						return NULL;
433 					}
434 					break;
435 				}
436 			}
437 			while (!size && !eof);
438 			rcvd = size;
439 		}
440 		else
441 		{
442 #endif /* WITH_SSL */
443 			rcvd = recv(g_sock, (char *)s->end, length, 0);
444 			if (rcvd < 0)
445 			{
446 				if (rcvd == -1 && TCP_BLOCKS)
447 				{
448 					rcvd = 0;
449 				}
450 				else
451 				{
452 					error("recv: %d (%s)\n", rcvd, TCP_STRERROR);
453 					g_network_error = True;
454 					return NULL;
455 				}
456 			}
457 			else if (rcvd == 0)
458 			{
459 				error("Connection closed\n");
460 				return NULL;
461 			}
462 #ifdef WITH_SSL
463 		}
464 #endif /* WITH_SSL */
465 
466 		s->end += rcvd;
467 		length -= rcvd;
468 	}
469 
470 	return s;
471 }
472 
473 #ifdef WITH_SSL
474 
475 RD_BOOL
476 ensure_cred_handle(void)
477 {
478 	SECURITY_STATUS res = SEC_E_OK;
479 
480 	if (!cred_handle_initialized)
481 	{
482 		SCHANNEL_CRED cred = {SCHANNEL_CRED_VERSION};
483 		SecPkgCred_SupportedProtocols prots;
484 
485 		res = AcquireCredentialsHandleW(NULL, (WCHAR*)UNISP_NAME_W, SECPKG_CRED_OUTBOUND, NULL, &cred,
486 				NULL, NULL, &cred_handle, NULL);
487 		if (res == SEC_E_OK)
488 		{
489 			res = QueryCredentialsAttributesA(&cred_handle, SECPKG_ATTR_SUPPORTED_PROTOCOLS, &prots);
490 			if (res != SEC_E_OK || (prots.grbitProtocol & SP_PROT_TLS1_1PLUS_CLIENT))
491 			{
492 				cred.grbitEnabledProtocols = prots.grbitProtocol & ~SP_PROT_TLS1_1PLUS_CLIENT;
493 				res = AcquireCredentialsHandleW(NULL, (WCHAR*)UNISP_NAME_W, SECPKG_CRED_OUTBOUND, NULL, &cred,
494 						NULL, NULL, &compat_cred_handle, NULL);
495 				have_compat_cred_handle = res == SEC_E_OK;
496 			}
497 		}
498 
499 		cred_handle_initialized = res == SEC_E_OK;
500 	}
501 
502 	if (res != SEC_E_OK)
503 	{
504 		error("ensure_cred_handle failed: %ld\n", res);
505 		return False;
506 	}
507 
508 	return True;
509 }
510 
511 DWORD
512 ssl_handshake(RD_BOOL compat_mode)
513 {
514 	SecBuffer out_buf = {0, SECBUFFER_TOKEN, NULL}, in_bufs[2] = {{0, SECBUFFER_TOKEN}, {0, SECBUFFER_EMPTY}};
515 	SecBufferDesc out_desc = {SECBUFFER_VERSION, 1, &out_buf}, in_desc = {SECBUFFER_VERSION, 2, in_bufs};
516 	SecHandle *cred = &cred_handle;
517 	BYTE *read_buf;
518 	SIZE_T read_buf_size = 2048;
519 	ULONG attrs = 0;
520 	CtxtHandle ctx;
521 	SSIZE_T size;
522 	SECURITY_STATUS status;
523 	DWORD res = ERROR_SUCCESS;
524 
525 	const DWORD isc_req_flags = ISC_REQ_ALLOCATE_MEMORY|ISC_REQ_USE_SESSION_KEY|ISC_REQ_CONFIDENTIALITY
526 		|ISC_REQ_SEQUENCE_DETECT|ISC_REQ_REPLAY_DETECT|ISC_REQ_MANUAL_CRED_VALIDATION;
527 
528 	if (!ensure_cred_handle())
529 		return -1;
530 
531 	if (compat_mode) {
532 		if (!have_compat_cred_handle)
533 			return -1;
534 		cred = &compat_cred_handle;
535 	}
536 
537 	read_buf = xmalloc(read_buf_size);
538 	if (!read_buf)
539 		return ERROR_OUTOFMEMORY;
540 
541 	if (!g_ssl_server)
542 		return -1;
543 
544 	status = InitializeSecurityContextA(cred, NULL, g_ssl_server, isc_req_flags, 0, 0, NULL, 0,
545 			&ctx, &out_desc, &attrs, NULL);
546 
547 	//assert(status != SEC_E_OK);
548 
549 	while (status == SEC_I_CONTINUE_NEEDED || status == SEC_E_INCOMPLETE_MESSAGE)
550 	{
551 		if (out_buf.cbBuffer)
552 		{
553 			//assert(status == SEC_I_CONTINUE_NEEDED);
554 
555 			size = send(g_sock, out_buf.pvBuffer, out_buf.cbBuffer, 0);
556 			if (size != out_buf.cbBuffer)
557 			{
558 				error("send failed: %d (%s)\n", size, TCP_STRERROR);
559 				status = -1;
560 				break;
561 			}
562 
563 			FreeContextBuffer(out_buf.pvBuffer);
564 			out_buf.pvBuffer = NULL;
565 			out_buf.cbBuffer = 0;
566 		}
567 
568 		if (status == SEC_I_CONTINUE_NEEDED)
569 		{
570 			//assert(in_bufs[1].cbBuffer < read_buf_size);
571 
572 			memmove(read_buf, (BYTE*)in_bufs[0].pvBuffer+in_bufs[0].cbBuffer-in_bufs[1].cbBuffer, in_bufs[1].cbBuffer);
573 			in_bufs[0].cbBuffer = in_bufs[1].cbBuffer;
574 
575 			in_bufs[1].BufferType = SECBUFFER_EMPTY;
576 			in_bufs[1].cbBuffer = 0;
577 			in_bufs[1].pvBuffer = NULL;
578 		}
579 
580 		//assert(in_bufs[0].BufferType == SECBUFFER_TOKEN);
581 		//assert(in_bufs[1].BufferType == SECBUFFER_EMPTY);
582 
583 		if (in_bufs[0].cbBuffer + 1024 > read_buf_size)
584 		{
585 			BYTE *new_read_buf = xrealloc(read_buf, read_buf_size + 1024);
586 			if (!new_read_buf)
587 			{
588 				status = E_OUTOFMEMORY;
589 				break;
590 			}
591 
592 			in_bufs[0].pvBuffer = read_buf = new_read_buf;
593 			read_buf_size += 1024;
594 		}
595 
596 		size = recv(g_sock, (char *)read_buf + in_bufs[0].cbBuffer, read_buf_size - in_bufs[0].cbBuffer, 0);
597 		if (size < 1)
598 		{
599 			error("recv failed: %d (%s)\n", size, TCP_STRERROR);
600 			res = -1;
601 			break;
602 		}
603 
604 		in_bufs[0].cbBuffer += size;
605 		in_bufs[0].pvBuffer = read_buf;
606 		status = InitializeSecurityContextA(cred, &ctx, g_ssl_server,  isc_req_flags, 0, 0, &in_desc,
607 				0, NULL, &out_desc, &attrs, NULL);
608 
609 		if (status == SEC_E_OK) {
610 			if (SecIsValidHandle(&g_ssl->ssl_ctx))
611 				DeleteSecurityContext(&g_ssl->ssl_ctx);
612 			g_ssl->ssl_ctx = ctx;
613 
614 			if (in_bufs[1].BufferType == SECBUFFER_EXTRA)
615 			{
616 				//FIXME("SECBUFFER_EXTRA not supported\n");
617 			}
618 
619 			status = QueryContextAttributesW(&ctx, SECPKG_ATTR_STREAM_SIZES, &g_ssl->ssl_sizes);
620 			if (status != SEC_E_OK)
621 			{
622 				//error("Can't determine ssl buffer sizes: %ld\n", status);
623 				break;
624 			}
625 
626 			g_ssl->ssl_buf = xmalloc(g_ssl->ssl_sizes.cbHeader + g_ssl->ssl_sizes.cbMaximumMessage
627 					+ g_ssl->ssl_sizes.cbTrailer);
628 			if (!g_ssl->ssl_buf)
629 			{
630 				res = GetLastError();
631 				break;
632 			}
633 		}
634 	}
635 
636 	xfree(read_buf);
637 
638 	if (status != SEC_E_OK || res != ERROR_SUCCESS)
639 	{
640 		error("Failed to establish SSL connection: %08x (%u)\n", status, res);
641 		xfree(g_ssl->ssl_buf);
642 		g_ssl->ssl_buf = NULL;
643 		return res ? res : -1;
644 	}
645 
646 	return ERROR_SUCCESS;
647 }
648 
649 /* Establish a SSL/TLS 1.0 connection */
650 RD_BOOL
651 tcp_tls_connect(void)
652 {
653 	int err;
654 	char tcp_port_rdp_s[10];
655 
656 	if (!g_ssl_initialized)
657 	{
658 		g_ssl = &g_ssl1;
659 		SecInvalidateHandle(&g_ssl->ssl_ctx);
660 
661 		g_ssl_initialized = True;
662 	}
663 
664 	snprintf(tcp_port_rdp_s, 10, "%d", g_tcp_port_rdp);
665 	if ((err = ssl_handshake(FALSE)) != 0)
666 	{
667 		goto fail;
668 	}
669 
670 	return True;
671 
672       fail:
673 	g_ssl = NULL;
674 	return False;
675 }
676 
677 /* Get public key from server of TLS 1.0 connection */
678 RD_BOOL
679 tcp_tls_get_server_pubkey(STREAM s)
680 {
681 	const CERT_CONTEXT *cert = NULL;
682 	SECURITY_STATUS status;
683 
684 	s->data = s->p = NULL;
685 	s->size = 0;
686 
687 	if (g_ssl == NULL)
688 		goto out;
689 	status = QueryContextAttributesW(&g_ssl->ssl_ctx, SECPKG_ATTR_REMOTE_CERT_CONTEXT, (void*)&cert);
690 	if (status != SEC_E_OK)
691 	{
692 		error("tcp_tls_get_server_pubkey: QueryContextAttributesW() failed %ld\n", status);
693 		goto out;
694 	}
695 
696 	s->size = cert->cbCertEncoded;
697 	if (s->size < 1)
698 	{
699 		error("tcp_tls_get_server_pubkey: cert->cbCertEncoded = %ld\n", cert->cbCertEncoded);
700 		goto out;
701 	}
702 
703 	s->data = s->p = (unsigned char *)xmalloc(s->size);
704 	memcpy(cert->pbCertEncoded, &s->p, s->size);
705 	s->p = s->data;
706 	s->end = s->p + s->size;
707 
708       out:
709 	if (cert)
710 		CertFreeCertificateContext(cert);
711 	return (s->size != 0);
712 }
713 #endif /* WITH_SSL */
714 
715 /* Establish a connection on the TCP layer */
716 RD_BOOL
717 tcp_connect(char *server)
718 {
719 	socklen_t option_len;
720 	uint32 option_value;
721 	int i;
722 
723 #ifdef IPv6
724 
725 	int n;
726 	struct addrinfo hints, *res, *ressave;
727 	char tcp_port_rdp_s[10];
728 
729 	snprintf(tcp_port_rdp_s, 10, "%d", g_tcp_port_rdp);
730 
731 	memset(&hints, 0, sizeof(struct addrinfo));
732 	hints.ai_family = AF_UNSPEC;
733 	hints.ai_socktype = SOCK_STREAM;
734 
735 	if ((n = getaddrinfo(server, tcp_port_rdp_s, &hints, &res)))
736 	{
737 		error("getaddrinfo: %s\n", gai_strerror(n));
738 		return False;
739 	}
740 
741 	ressave = res;
742 	g_sock = -1;
743 	while (res)
744 	{
745 		g_sock = socket(res->ai_family, res->ai_socktype, res->ai_protocol);
746 		if (!(g_sock < 0))
747 		{
748 			if (connect(g_sock, res->ai_addr, res->ai_addrlen) == 0)
749 				break;
750 			TCP_CLOSE(g_sock);
751 			g_sock = -1;
752 		}
753 		res = res->ai_next;
754 	}
755 	freeaddrinfo(ressave);
756 
757 	if (g_sock == -1)
758 	{
759 		error("%s: unable to connect\n", server);
760 		return False;
761 	}
762 
763 #else /* no IPv6 support */
764 
765 	struct hostent *nslookup;
766 	struct sockaddr_in servaddr;
767 
768 	if ((nslookup = gethostbyname(server)) != NULL)
769 	{
770 		memcpy(&servaddr.sin_addr, nslookup->h_addr, sizeof(servaddr.sin_addr));
771 	}
772 	else if ((servaddr.sin_addr.s_addr = inet_addr(server)) == INADDR_NONE)
773 	{
774 		error("%s: unable to resolve host\n", server);
775 		return False;
776 	}
777 
778 	if ((g_sock = socket(AF_INET, SOCK_STREAM, 0)) < 0)
779 	{
780 		error("socket: %s\n", TCP_STRERROR);
781 		return False;
782 	}
783 
784 	servaddr.sin_family = AF_INET;
785 	servaddr.sin_port = htons((uint16) g_tcp_port_rdp);
786 
787 	if (connect(g_sock, (struct sockaddr *) &servaddr, sizeof(struct sockaddr)) < 0)
788 	{
789 		if (!g_reconnect_loop)
790 			error("connect: %s\n", TCP_STRERROR);
791 
792 		TCP_CLOSE(g_sock);
793 		g_sock = -1;
794 		return False;
795 	}
796 
797 #endif /* IPv6 */
798 
799 	option_value = 1;
800 	option_len = sizeof(option_value);
801 	setsockopt(g_sock, IPPROTO_TCP, TCP_NODELAY, (void *) &option_value, option_len);
802 	/* receive buffer must be a least 16 K */
803 	if (getsockopt(g_sock, SOL_SOCKET, SO_RCVBUF, (void *) &option_value, &option_len) == 0)
804 	{
805 		if (option_value < (1024 * 16))
806 		{
807 			option_value = 1024 * 16;
808 			option_len = sizeof(option_value);
809 			setsockopt(g_sock, SOL_SOCKET, SO_RCVBUF, (void *) &option_value,
810 				   option_len);
811 		}
812 	}
813 
814 	g_in.size = 4096;
815 	g_in.data = (uint8 *) xmalloc(g_in.size);
816 
817 	for (i = 0; i < STREAM_COUNT; i++)
818 	{
819 		g_out[i].size = 4096;
820 		g_out[i].data = (uint8 *) xmalloc(g_out[i].size);
821 	}
822 
823 #ifdef WITH_SSL
824 	g_ssl_server = xmalloc(strlen(server)+1);
825 #endif /* WITH_SSL */
826 
827 	return True;
828 }
829 
830 /* Disconnect on the TCP layer */
831 void
832 tcp_disconnect(void)
833 {
834 #ifdef WITH_SSL
835 	if (g_ssl)
836 	{
837 		xfree(g_ssl->peek_msg_mem);
838 		g_ssl->peek_msg_mem = NULL;
839 		g_ssl->peek_msg = NULL;
840 		g_ssl->peek_len = 0;
841 		xfree(g_ssl->ssl_buf);
842 		g_ssl->ssl_buf = NULL;
843 		xfree(g_ssl->extra_buf);
844 		g_ssl->extra_buf = NULL;
845 		g_ssl->extra_len = 0;
846 		if (SecIsValidHandle(&g_ssl->ssl_ctx))
847 			DeleteSecurityContext(&g_ssl->ssl_ctx);
848 		if (cred_handle_initialized)
849 			FreeCredentialsHandle(&cred_handle);
850 		if (have_compat_cred_handle)
851 			FreeCredentialsHandle(&compat_cred_handle);
852 		if (g_ssl_server)
853 		{
854 			xfree(g_ssl_server);
855 			g_ssl_server = NULL;
856 		}
857 		g_ssl = NULL;
858 		g_ssl_initialized = False;
859 	}
860 #endif /* WITH_SSL */
861 	TCP_CLOSE(g_sock);
862 	g_sock = -1;
863 }
864 
865 char *
866 tcp_get_address()
867 {
868 	static char ipaddr[32];
869 	struct sockaddr_in sockaddr;
870 	socklen_t len = sizeof(sockaddr);
871 	if (getsockname(g_sock, (struct sockaddr *) &sockaddr, &len) == 0)
872 	{
873 		uint8 *ip = (uint8 *) & sockaddr.sin_addr;
874 		sprintf(ipaddr, "%d.%d.%d.%d", ip[0], ip[1], ip[2], ip[3]);
875 	}
876 	else
877 		strcpy(ipaddr, "127.0.0.1");
878 	return ipaddr;
879 }
880 
881 RD_BOOL
882 tcp_is_connected()
883 {
884 	struct sockaddr_in sockaddr;
885 	socklen_t len = sizeof(sockaddr);
886 	if (getpeername(g_sock, (struct sockaddr *) &sockaddr, &len))
887 		return True;
888 	return False;
889 }
890 
891 /* reset the state of the tcp layer */
892 /* Support for Session Directory */
893 void
894 tcp_reset_state(void)
895 {
896 	int i;
897 
898 	/* Clear the incoming stream */
899 	if (g_in.data != NULL)
900 		xfree(g_in.data);
901 	g_in.p = NULL;
902 	g_in.end = NULL;
903 	g_in.data = NULL;
904 	g_in.size = 0;
905 	g_in.iso_hdr = NULL;
906 	g_in.mcs_hdr = NULL;
907 	g_in.sec_hdr = NULL;
908 	g_in.rdp_hdr = NULL;
909 	g_in.channel_hdr = NULL;
910 
911 	/* Clear the outgoing stream(s) */
912 	for (i = 0; i < STREAM_COUNT; i++)
913 	{
914 		if (g_out[i].data != NULL)
915 			xfree(g_out[i].data);
916 		g_out[i].p = NULL;
917 		g_out[i].end = NULL;
918 		g_out[i].data = NULL;
919 		g_out[i].size = 0;
920 		g_out[i].iso_hdr = NULL;
921 		g_out[i].mcs_hdr = NULL;
922 		g_out[i].sec_hdr = NULL;
923 		g_out[i].rdp_hdr = NULL;
924 		g_out[i].channel_hdr = NULL;
925 	}
926 }
927