xref: /reactos/dll/win32/wininet/netconnection.c (revision c2c66aff)
1 /*
2  * Wininet - networking layer
3  *
4  * Copyright 2002 TransGaming Technologies Inc.
5  * Copyright 2013 Jacek Caban for CodeWeavers
6  *
7  * David Hammerton
8  *
9  * This library is free software; you can redistribute it and/or
10  * modify it under the terms of the GNU Lesser General Public
11  * License as published by the Free Software Foundation; either
12  * version 2.1 of the License, or (at your option) any later version.
13  *
14  * This library is distributed in the hope that it will be useful,
15  * but WITHOUT ANY WARRANTY; without even the implied warranty of
16  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
17  * Lesser General Public License for more details.
18  *
19  * You should have received a copy of the GNU Lesser General Public
20  * License along with this library; if not, write to the Free Software
21  * Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301, USA
22  */
23 
24 #include "internet.h"
25 
26 #include <sys/types.h>
27 #ifdef HAVE_POLL_H
28 #include <poll.h>
29 #endif
30 #ifdef HAVE_SYS_FILIO_H
31 # include <sys/filio.h>
32 #endif
33 #ifdef HAVE_NETINET_TCP_H
34 # include <netinet/tcp.h>
35 #endif
36 
37 #include <errno.h>
38 
/*
 * Builds and validates the certificate chain presented by the server and maps
 * any CryptoAPI trust problems to WinINet error codes.
 *
 * conn  - connection being verified. conn->security_flags decides which
 *         problems are ignored (SECURITY_FLAG_IGNORE_*). When conn->mask_errors
 *         is set, problems are recorded as _SECURITY_FLAG_* bits in
 *         conn->security_flags and checking continues instead of stopping at
 *         the first error.
 * cert  - leaf certificate received from the server.
 * store - extra certificate store handed to CertGetCertificateChain.
 *
 * Returns ERROR_SUCCESS if the certificate is acceptable, the last error of
 * CertGetCertificateChain if the chain cannot be built, or a WinINet
 * ERROR_INTERNET_SEC_* / ERROR_INTERNET_INVALID_CA code. With mask_errors set,
 * a second problem turns the result into ERROR_INTERNET_SEC_CERT_ERRORS.
 *
 * On failure the new chain and any chain cached in conn->server->cert_chain are
 * freed; with mask_errors the error bits are also copied to
 * conn->server->security_flags. On success the chain is cached in
 * conn->server->cert_chain unless one is already cached there.
 */
static DWORD netconn_verify_cert(netconn_t *conn, PCCERT_CONTEXT cert, HCERTSTORE store)
{
    BOOL ret;
    CERT_CHAIN_PARA chainPara = { sizeof(chainPara), { 0 } };
    PCCERT_CHAIN_CONTEXT chain;
    char oid_server_auth[] = szOID_PKIX_KP_SERVER_AUTH;
    char *server_auth[] = { oid_server_auth };
    DWORD err = ERROR_SUCCESS, errors;

    /* Chain errors that get a dedicated mapping below; any other bit in
     * dwErrorStatus is reported as ERROR_INTERNET_SEC_INVALID_CERT. */
    static const DWORD supportedErrors =
        CERT_TRUST_IS_NOT_TIME_VALID |
        CERT_TRUST_IS_UNTRUSTED_ROOT |
        CERT_TRUST_IS_PARTIAL_CHAIN |
        CERT_TRUST_IS_NOT_SIGNATURE_VALID |
        CERT_TRUST_IS_NOT_VALID_FOR_USAGE;

    TRACE("verifying %s\n", debugstr_w(conn->server->name));

    /* Request a chain valid for the server-authentication usage. */
    chainPara.RequestedUsage.Usage.cUsageIdentifier = 1;
    chainPara.RequestedUsage.Usage.rgpszUsageIdentifier = server_auth;
    if (!(ret = CertGetCertificateChain(NULL, cert, NULL, store, &chainPara, 0, NULL, &chain))) {
        TRACE("failed\n");
        return GetLastError();
    }

    /* Working copy; each handled bit is cleared once processed. */
    errors = chain->TrustStatus.dwErrorStatus;

    /* do { } while(0) so every check can 'break' out as soon as an error is
     * final (not ignored via security_flags and not being masked). */
    do {
        /* This seems strange, but that's what tests show */
        if(errors & CERT_TRUST_IS_PARTIAL_CHAIN) {
            WARN("ERROR_INTERNET_SEC_CERT_REV_FAILED\n");
            err = ERROR_INTERNET_SEC_CERT_REV_FAILED;
            if(conn->mask_errors)
                conn->security_flags |= _SECURITY_FLAG_CERT_REV_FAILED;
            if(!(conn->security_flags & SECURITY_FLAG_IGNORE_REVOCATION))
                break;
        }

        if (chain->TrustStatus.dwErrorStatus & ~supportedErrors) {
            WARN("error status %x\n", chain->TrustStatus.dwErrorStatus & ~supportedErrors);
            err = conn->mask_errors && err ? ERROR_INTERNET_SEC_CERT_ERRORS : ERROR_INTERNET_SEC_INVALID_CERT;
            errors &= supportedErrors;
            if(!conn->mask_errors)
                break;
            WARN("unknown error flags\n");
        }

        if(errors & CERT_TRUST_IS_NOT_TIME_VALID) {
            WARN("CERT_TRUST_IS_NOT_TIME_VALID\n");
            if(!(conn->security_flags & SECURITY_FLAG_IGNORE_CERT_DATE_INVALID)) {
                err = conn->mask_errors && err ? ERROR_INTERNET_SEC_CERT_ERRORS : ERROR_INTERNET_SEC_CERT_DATE_INVALID;
                if(!conn->mask_errors)
                    break;
                conn->security_flags |= _SECURITY_FLAG_CERT_INVALID_DATE;
            }
            errors &= ~CERT_TRUST_IS_NOT_TIME_VALID;
        }

        if(errors & CERT_TRUST_IS_UNTRUSTED_ROOT) {
            WARN("CERT_TRUST_IS_UNTRUSTED_ROOT\n");
            if(!(conn->security_flags & SECURITY_FLAG_IGNORE_UNKNOWN_CA)) {
                err = conn->mask_errors && err ? ERROR_INTERNET_SEC_CERT_ERRORS : ERROR_INTERNET_INVALID_CA;
                if(!conn->mask_errors)
                    break;
                conn->security_flags |= _SECURITY_FLAG_CERT_INVALID_CA;
            }
            errors &= ~CERT_TRUST_IS_UNTRUSTED_ROOT;
        }

        if(errors & CERT_TRUST_IS_PARTIAL_CHAIN) {
            WARN("CERT_TRUST_IS_PARTIAL_CHAIN\n");
            if(!(conn->security_flags & SECURITY_FLAG_IGNORE_UNKNOWN_CA)) {
                err = conn->mask_errors && err ? ERROR_INTERNET_SEC_CERT_ERRORS : ERROR_INTERNET_INVALID_CA;
                if(!conn->mask_errors)
                    break;
                conn->security_flags |= _SECURITY_FLAG_CERT_INVALID_CA;
            }
            errors &= ~CERT_TRUST_IS_PARTIAL_CHAIN;
        }

        if(errors & CERT_TRUST_IS_NOT_SIGNATURE_VALID) {
            WARN("CERT_TRUST_IS_NOT_SIGNATURE_VALID\n");
            if(!(conn->security_flags & SECURITY_FLAG_IGNORE_UNKNOWN_CA)) {
                err = conn->mask_errors && err ? ERROR_INTERNET_SEC_CERT_ERRORS : ERROR_INTERNET_INVALID_CA;
                if(!conn->mask_errors)
                    break;
                conn->security_flags |= _SECURITY_FLAG_CERT_INVALID_CA;
            }
            errors &= ~CERT_TRUST_IS_NOT_SIGNATURE_VALID;
        }

        if(errors & CERT_TRUST_IS_NOT_VALID_FOR_USAGE) {
            WARN("CERT_TRUST_IS_NOT_VALID_FOR_USAGE\n");
            if(!(conn->security_flags & SECURITY_FLAG_IGNORE_WRONG_USAGE)) {
                err = conn->mask_errors && err ? ERROR_INTERNET_SEC_CERT_ERRORS : ERROR_INTERNET_SEC_INVALID_CERT;
                if(!conn->mask_errors)
                    break;
                WARN("CERT_TRUST_IS_NOT_VALID_FOR_USAGE, unknown error flags\n");
            }
            errors &= ~CERT_TRUST_IS_NOT_VALID_FOR_USAGE;
        }

        /* Reaching here with only the revocation error means it was ignored
         * via SECURITY_FLAG_IGNORE_REVOCATION (otherwise we broke out above). */
        if(err == ERROR_INTERNET_SEC_CERT_REV_FAILED) {
            assert(conn->security_flags & SECURITY_FLAG_IGNORE_REVOCATION);
            err = ERROR_SUCCESS;
        }
    }while(0);

    /* Additionally run the SSL chain policy (this is where CERT_E_CN_NO_MATCH,
     * i.e. a server-name mismatch, is detected). Skipped only when an error is
     * already final and errors are not being masked. */
    if(!err || conn->mask_errors) {
        CERT_CHAIN_POLICY_PARA policyPara;
        SSL_EXTRA_CERT_CHAIN_POLICY_PARA sslExtraPolicyPara;
        CERT_CHAIN_POLICY_STATUS policyStatus;
        CERT_CHAIN_CONTEXT chainCopy;

        /* Clear chain->TrustStatus.dwErrorStatus so
         * CertVerifyCertificateChainPolicy will verify additional checks
         * rather than stopping with an existing, ignored error.
         */
        memcpy(&chainCopy, chain, sizeof(chainCopy));
        chainCopy.TrustStatus.dwErrorStatus = 0;
        sslExtraPolicyPara.u.cbSize = sizeof(sslExtraPolicyPara);
        sslExtraPolicyPara.dwAuthType = AUTHTYPE_SERVER;
        sslExtraPolicyPara.pwszServerName = conn->server->name;
        sslExtraPolicyPara.fdwChecks = conn->security_flags;
        policyPara.cbSize = sizeof(policyPara);
        policyPara.dwFlags = 0;
        policyPara.pvExtraPolicyPara = &sslExtraPolicyPara;
        ret = CertVerifyCertificateChainPolicy(CERT_CHAIN_POLICY_SSL,
                &chainCopy, &policyPara, &policyStatus);
        /* Any error in the policy status indicates that the
         * policy couldn't be verified.
         */
        if(ret) {
            if(policyStatus.dwError == CERT_E_CN_NO_MATCH) {
                WARN("CERT_E_CN_NO_MATCH\n");
                if(conn->mask_errors)
                    conn->security_flags |= _SECURITY_FLAG_CERT_INVALID_CN;
                err = conn->mask_errors && err ? ERROR_INTERNET_SEC_CERT_ERRORS : ERROR_INTERNET_SEC_CERT_CN_INVALID;
            }else if(policyStatus.dwError) {
                WARN("policyStatus.dwError %x\n", policyStatus.dwError);
                if(conn->mask_errors)
                    WARN("unknown error flags for policy status %x\n", policyStatus.dwError);
                err = conn->mask_errors && err ? ERROR_INTERNET_SEC_CERT_ERRORS : ERROR_INTERNET_SEC_INVALID_CERT;
            }
        }else {
            err = GetLastError();
        }
    }

    /* Failure: drop the new chain and any cached one. With mask_errors, the
     * error bits are stored on the server; create_netconn() ORs
     * server->security_flags into new connections. */
    if(err) {
        WARN("failed %u\n", err);
        CertFreeCertificateChain(chain);
        if(conn->server->cert_chain) {
            CertFreeCertificateChain(conn->server->cert_chain);
            conn->server->cert_chain = NULL;
        }
        if(conn->mask_errors)
            conn->server->security_flags |= conn->security_flags & _SECURITY_ERROR_FLAGS_MASK;
        return err;
    }

    /* FIXME: Reuse cached chain */
    if(conn->server->cert_chain)
        CertFreeCertificateChain(chain);
    else
        conn->server->cert_chain = chain;
    return ERROR_SUCCESS;
}
207 
/* Schannel credential handles shared by all connections, acquired lazily by
 * ensure_cred_handle(). compat_cred_handle is acquired with the TLS 1.1+
 * client protocols masked out and is used for the compat-mode handshake. */
static SecHandle cred_handle, compat_cred_handle;
static BOOL cred_handle_initialized, have_compat_cred_handle;

/* Serializes the lazy initialization in ensure_cred_handle(). The critical
 * section and its debug info are initialized statically right here. */
static CRITICAL_SECTION init_sechandle_cs;
static CRITICAL_SECTION_DEBUG init_sechandle_cs_debug = {
    0, 0, &init_sechandle_cs,
    { &init_sechandle_cs_debug.ProcessLocksList,
      &init_sechandle_cs_debug.ProcessLocksList },
    0, 0, { (DWORD_PTR)(__FILE__ ": init_sechandle_cs") }
};
static CRITICAL_SECTION init_sechandle_cs = { &init_sechandle_cs_debug, -1, 0, 0, 0, 0 };
219 
220 static BOOL ensure_cred_handle(void)
221 {
222     SECURITY_STATUS res = SEC_E_OK;
223 
224     EnterCriticalSection(&init_sechandle_cs);
225 
226     if(!cred_handle_initialized) {
227         SCHANNEL_CRED cred = {SCHANNEL_CRED_VERSION};
228         SecPkgCred_SupportedProtocols prots;
229 
230         res = AcquireCredentialsHandleW(NULL, (WCHAR*)UNISP_NAME_W, SECPKG_CRED_OUTBOUND, NULL, &cred,
231                 NULL, NULL, &cred_handle, NULL);
232         if(res == SEC_E_OK) {
233             res = QueryCredentialsAttributesA(&cred_handle, SECPKG_ATTR_SUPPORTED_PROTOCOLS, &prots);
234             if(res != SEC_E_OK || (prots.grbitProtocol & SP_PROT_TLS1_1PLUS_CLIENT)) {
235                 cred.grbitEnabledProtocols = prots.grbitProtocol & ~SP_PROT_TLS1_1PLUS_CLIENT;
236                 res = AcquireCredentialsHandleW(NULL, (WCHAR*)UNISP_NAME_W, SECPKG_CRED_OUTBOUND, NULL, &cred,
237                        NULL, NULL, &compat_cred_handle, NULL);
238                 have_compat_cred_handle = res == SEC_E_OK;
239             }
240         }
241 
242         cred_handle_initialized = res == SEC_E_OK;
243     }
244 
245     LeaveCriticalSection(&init_sechandle_cs);
246 
247     if(res != SEC_E_OK) {
248         WARN("Failed: %08x\n", res);
249         return FALSE;
250     }
251 
252     return TRUE;
253 }
254 
255 static BOOL winsock_loaded = FALSE;
256 
257 static BOOL WINAPI winsock_startup(INIT_ONCE *once, void *param, void **context)
258 {
259     WSADATA wsa_data;
260     DWORD res;
261 
262     res = WSAStartup(MAKEWORD(2,2), &wsa_data);
263     if(res == ERROR_SUCCESS)
264         winsock_loaded = TRUE;
265     else
266         ERR("WSAStartup failed: %u\n", res);
267     return TRUE;
268 }
269 
270 void init_winsock(void)
271 {
272     static INIT_ONCE init_once = INIT_ONCE_STATIC_INIT;
273     InitOnceExecuteOnce(&init_once, winsock_startup, NULL, NULL);
274 }
275 
276 static void set_socket_blocking(netconn_t *conn, BOOL is_blocking)
277 {
278     if(conn->is_blocking != is_blocking) {
279         ULONG arg = !is_blocking;
280         ioctlsocket(conn->socket, FIONBIO, &arg);
281     }
282     conn->is_blocking = is_blocking;
283 }
284 
285 static DWORD create_netconn_socket(server_t *server, netconn_t *netconn, DWORD timeout)
286 {
287     int result;
288     ULONG flag;
289     DWORD res;
290 
291     init_winsock();
292 
293     assert(server->addr_len);
294     result = netconn->socket = socket(server->addr.ss_family, SOCK_STREAM, 0);
295     if(result != -1) {
296         set_socket_blocking(netconn, FALSE);
297         result = connect(netconn->socket, (struct sockaddr*)&server->addr, server->addr_len);
298         if(result == -1)
299         {
300             res = WSAGetLastError();
301             if (res == WSAEINPROGRESS || res == WSAEWOULDBLOCK) {
302                 FD_SET set;
303                 int res;
304                 socklen_t len = sizeof(res);
305                 TIMEVAL timeout_timeval = {0, timeout*1000};
306 
307                 FD_ZERO(&set);
308                 FD_SET(netconn->socket, &set);
309                 res = select(netconn->socket+1, NULL, &set, NULL, &timeout_timeval);
310                 if(!res || res == SOCKET_ERROR) {
311                     closesocket(netconn->socket);
312                     netconn->socket = -1;
313                     return ERROR_INTERNET_CANNOT_CONNECT;
314                 }
315                 if (!getsockopt(netconn->socket, SOL_SOCKET, SO_ERROR, (void *)&res, &len) && !res)
316                     result = 0;
317             }
318         }
319         if(result == -1)
320         {
321             closesocket(netconn->socket);
322             netconn->socket = -1;
323         }
324     }
325     if(result == -1)
326         return ERROR_INTERNET_CANNOT_CONNECT;
327 
328     flag = 1;
329     result = setsockopt(netconn->socket, IPPROTO_TCP, TCP_NODELAY, (void*)&flag, sizeof(flag));
330     if(result < 0)
331         WARN("setsockopt(TCP_NODELAY) failed\n");
332 
333     return ERROR_SUCCESS;
334 }
335 
336 DWORD create_netconn(BOOL useSSL, server_t *server, DWORD security_flags, BOOL mask_errors, DWORD timeout, netconn_t **ret)
337 {
338     netconn_t *netconn;
339     int result;
340 
341     netconn = heap_alloc_zero(sizeof(*netconn));
342     if(!netconn)
343         return ERROR_OUTOFMEMORY;
344 
345     netconn->socket = -1;
346     netconn->security_flags = security_flags | server->security_flags;
347     netconn->mask_errors = mask_errors;
348     list_init(&netconn->pool_entry);
349     SecInvalidateHandle(&netconn->ssl_ctx);
350 
351     result = create_netconn_socket(server, netconn, timeout);
352     if (result != ERROR_SUCCESS) {
353         heap_free(netconn);
354         return result;
355     }
356 
357     server_addref(server);
358     netconn->server = server;
359     *ret = netconn;
360     return result;
361 }
362 
363 BOOL is_valid_netconn(netconn_t *netconn)
364 {
365     return netconn && netconn->socket != -1;
366 }
367 
368 void close_netconn(netconn_t *netconn)
369 {
370     closesocket(netconn->socket);
371     netconn->socket = -1;
372 }
373 
374 void free_netconn(netconn_t *netconn)
375 {
376     server_release(netconn->server);
377 
378     if (netconn->secure) {
379         heap_free(netconn->peek_msg_mem);
380         netconn->peek_msg_mem = NULL;
381         netconn->peek_msg = NULL;
382         netconn->peek_len = 0;
383         heap_free(netconn->ssl_buf);
384         netconn->ssl_buf = NULL;
385         heap_free(netconn->extra_buf);
386         netconn->extra_buf = NULL;
387         netconn->extra_len = 0;
388         if (SecIsValidHandle(&netconn->ssl_ctx))
389             DeleteSecurityContext(&netconn->ssl_ctx);
390     }
391 
392     close_netconn(netconn);
393     heap_free(netconn);
394 }
395 
396 void NETCON_unload(void)
397 {
398     if(cred_handle_initialized)
399         FreeCredentialsHandle(&cred_handle);
400     if(have_compat_cred_handle)
401         FreeCredentialsHandle(&compat_cred_handle);
402     DeleteCriticalSection(&init_sechandle_cs);
403     if(winsock_loaded)
404         WSACleanup();
405 }
406 
407 int sock_send(int fd, const void *msg, size_t len, int flags)
408 {
409     int ret;
410     do
411     {
412         ret = send(fd, msg, len, flags);
413     }
414     while(ret == -1 && WSAGetLastError() == WSAEINTR);
415     return ret;
416 }
417 
418 int sock_recv(int fd, void *msg, size_t len, int flags)
419 {
420     int ret;
421     do
422     {
423         ret = recv(fd, msg, len, flags);
424     }
425     while(ret == -1 && WSAGetLastError() == WSAEINTR);
426     return ret;
427 }
428 
429 static DWORD netcon_secure_connect_setup(netconn_t *connection, BOOL compat_mode)
430 {
431     SecBuffer out_buf = {0, SECBUFFER_TOKEN, NULL}, in_bufs[2] = {{0, SECBUFFER_TOKEN}, {0, SECBUFFER_EMPTY}};
432     SecBufferDesc out_desc = {SECBUFFER_VERSION, 1, &out_buf}, in_desc = {SECBUFFER_VERSION, 2, in_bufs};
433     SecHandle *cred = &cred_handle;
434     BYTE *read_buf;
435     SIZE_T read_buf_size = 2048;
436     ULONG attrs = 0;
437     CtxtHandle ctx;
438     SSIZE_T size;
439     int bits;
440     const CERT_CONTEXT *cert;
441     SECURITY_STATUS status;
442     DWORD res = ERROR_SUCCESS;
443 
444     const DWORD isc_req_flags = ISC_REQ_ALLOCATE_MEMORY|ISC_REQ_USE_SESSION_KEY|ISC_REQ_CONFIDENTIALITY
445         |ISC_REQ_SEQUENCE_DETECT|ISC_REQ_REPLAY_DETECT|ISC_REQ_MANUAL_CRED_VALIDATION;
446 
447     if(!ensure_cred_handle())
448         return ERROR_INTERNET_SECURITY_CHANNEL_ERROR;
449 
450     if(compat_mode) {
451         if(!have_compat_cred_handle)
452             return ERROR_INTERNET_SECURITY_CHANNEL_ERROR;
453         cred = &compat_cred_handle;
454     }
455 
456     read_buf = heap_alloc(read_buf_size);
457     if(!read_buf)
458         return ERROR_OUTOFMEMORY;
459 
460     status = InitializeSecurityContextW(cred, NULL, connection->server->name, isc_req_flags, 0, 0, NULL, 0,
461             &ctx, &out_desc, &attrs, NULL);
462 
463     assert(status != SEC_E_OK);
464 
465     set_socket_blocking(connection, TRUE);
466 
467     while(status == SEC_I_CONTINUE_NEEDED || status == SEC_E_INCOMPLETE_MESSAGE) {
468         if(out_buf.cbBuffer) {
469             assert(status == SEC_I_CONTINUE_NEEDED);
470 
471             TRACE("sending %u bytes\n", out_buf.cbBuffer);
472 
473             size = sock_send(connection->socket, out_buf.pvBuffer, out_buf.cbBuffer, 0);
474             if(size != out_buf.cbBuffer) {
475                 ERR("send failed\n");
476                 status = ERROR_INTERNET_SECURITY_CHANNEL_ERROR;
477                 break;
478             }
479 
480             FreeContextBuffer(out_buf.pvBuffer);
481             out_buf.pvBuffer = NULL;
482             out_buf.cbBuffer = 0;
483         }
484 
485         if(status == SEC_I_CONTINUE_NEEDED) {
486             assert(in_bufs[1].cbBuffer < read_buf_size);
487 
488             memmove(read_buf, (BYTE*)in_bufs[0].pvBuffer+in_bufs[0].cbBuffer-in_bufs[1].cbBuffer, in_bufs[1].cbBuffer);
489             in_bufs[0].cbBuffer = in_bufs[1].cbBuffer;
490 
491             in_bufs[1].BufferType = SECBUFFER_EMPTY;
492             in_bufs[1].cbBuffer = 0;
493             in_bufs[1].pvBuffer = NULL;
494         }
495 
496         assert(in_bufs[0].BufferType == SECBUFFER_TOKEN);
497         assert(in_bufs[1].BufferType == SECBUFFER_EMPTY);
498 
499         if(in_bufs[0].cbBuffer + 1024 > read_buf_size) {
500             BYTE *new_read_buf;
501 
502             new_read_buf = heap_realloc(read_buf, read_buf_size + 1024);
503             if(!new_read_buf) {
504                 status = E_OUTOFMEMORY;
505                 break;
506             }
507 
508             in_bufs[0].pvBuffer = read_buf = new_read_buf;
509             read_buf_size += 1024;
510         }
511 
512         size = sock_recv(connection->socket, read_buf+in_bufs[0].cbBuffer, read_buf_size-in_bufs[0].cbBuffer, 0);
513         if(size < 1) {
514             WARN("recv error\n");
515             res = ERROR_INTERNET_SECURITY_CHANNEL_ERROR;
516             break;
517         }
518 
519         TRACE("recv %lu bytes\n", size);
520 
521         in_bufs[0].cbBuffer += size;
522         in_bufs[0].pvBuffer = read_buf;
523         status = InitializeSecurityContextW(cred, &ctx, connection->server->name,  isc_req_flags, 0, 0, &in_desc,
524                 0, NULL, &out_desc, &attrs, NULL);
525         TRACE("InitializeSecurityContext ret %08x\n", status);
526 
527         if(status == SEC_E_OK) {
528             if(SecIsValidHandle(&connection->ssl_ctx))
529                 DeleteSecurityContext(&connection->ssl_ctx);
530             connection->ssl_ctx = ctx;
531 
532             if(in_bufs[1].BufferType == SECBUFFER_EXTRA)
533                 FIXME("SECBUFFER_EXTRA not supported\n");
534 
535             status = QueryContextAttributesW(&ctx, SECPKG_ATTR_STREAM_SIZES, &connection->ssl_sizes);
536             if(status != SEC_E_OK) {
537                 WARN("Could not get sizes\n");
538                 break;
539             }
540 
541             status = QueryContextAttributesW(&ctx, SECPKG_ATTR_REMOTE_CERT_CONTEXT, (void*)&cert);
542             if(status == SEC_E_OK) {
543                 res = netconn_verify_cert(connection, cert, cert->hCertStore);
544                 CertFreeCertificateContext(cert);
545                 if(res != ERROR_SUCCESS) {
546                     WARN("cert verify failed: %u\n", res);
547                     break;
548                 }
549             }else {
550                 WARN("Could not get cert\n");
551                 break;
552             }
553 
554             connection->ssl_buf = heap_alloc(connection->ssl_sizes.cbHeader + connection->ssl_sizes.cbMaximumMessage
555                     + connection->ssl_sizes.cbTrailer);
556             if(!connection->ssl_buf) {
557                 res = GetLastError();
558                 break;
559             }
560         }
561     }
562 
563     heap_free(read_buf);
564 
565     if(status != SEC_E_OK || res != ERROR_SUCCESS) {
566         WARN("Failed to establish SSL connection: %08x (%u)\n", status, res);
567         heap_free(connection->ssl_buf);
568         connection->ssl_buf = NULL;
569         return res ? res : ERROR_INTERNET_SECURITY_CHANNEL_ERROR;
570     }
571 
572     TRACE("established SSL connection\n");
573     connection->secure = TRUE;
574     connection->security_flags |= SECURITY_FLAG_SECURE;
575 
576     bits = NETCON_GetCipherStrength(connection);
577     if (bits >= 128)
578         connection->security_flags |= SECURITY_FLAG_STRENGTH_STRONG;
579     else if (bits >= 56)
580         connection->security_flags |= SECURITY_FLAG_STRENGTH_MEDIUM;
581     else
582         connection->security_flags |= SECURITY_FLAG_STRENGTH_WEAK;
583 
584     if(connection->mask_errors)
585         connection->server->security_flags = connection->security_flags;
586     return ERROR_SUCCESS;
587 }
588 
589 /******************************************************************************
590  * NETCON_secure_connect
591  * Initiates a secure connection over an existing plaintext connection.
592  */
593 DWORD NETCON_secure_connect(netconn_t *connection, server_t *server)
594 {
595     DWORD res;
596 
597     /* can't connect if we are already connected */
598     if(connection->secure) {
599         ERR("already connected\n");
600         return ERROR_INTERNET_CANNOT_CONNECT;
601     }
602 
603     if(server != connection->server) {
604         server_release(connection->server);
605         server_addref(server);
606         connection->server = server;
607     }
608 
609     /* connect with given TLS options */
610     res = netcon_secure_connect_setup(connection, FALSE);
611     if (res == ERROR_SUCCESS)
612         return res;
613 
614     /* FIXME: when got version alert and FIN from server */
615     /* fallback to connect without TLSv1.1/TLSv1.2        */
616     if (res == ERROR_INTERNET_SECURITY_CHANNEL_ERROR && have_compat_cred_handle)
617     {
618         closesocket(connection->socket);
619         res = create_netconn_socket(connection->server, connection, 500);
620         if (res != ERROR_SUCCESS)
621             return res;
622         res = netcon_secure_connect_setup(connection, TRUE);
623     }
624     return res;
625 }
626 
627 static BOOL send_ssl_chunk(netconn_t *conn, const void *msg, size_t size)
628 {
629     SecBuffer bufs[4] = {
630         {conn->ssl_sizes.cbHeader, SECBUFFER_STREAM_HEADER, conn->ssl_buf},
631         {size,  SECBUFFER_DATA, conn->ssl_buf+conn->ssl_sizes.cbHeader},
632         {conn->ssl_sizes.cbTrailer, SECBUFFER_STREAM_TRAILER, conn->ssl_buf+conn->ssl_sizes.cbHeader+size},
633         {0, SECBUFFER_EMPTY, NULL}
634     };
635     SecBufferDesc buf_desc = {SECBUFFER_VERSION, sizeof(bufs)/sizeof(*bufs), bufs};
636     SECURITY_STATUS res;
637 
638     memcpy(bufs[1].pvBuffer, msg, size);
639     res = EncryptMessage(&conn->ssl_ctx, 0, &buf_desc, 0);
640     if(res != SEC_E_OK) {
641         WARN("EncryptMessage failed\n");
642         return FALSE;
643     }
644 
645     if(sock_send(conn->socket, conn->ssl_buf, bufs[0].cbBuffer+bufs[1].cbBuffer+bufs[2].cbBuffer, 0) < 1) {
646         WARN("send failed\n");
647         return FALSE;
648     }
649 
650     return TRUE;
651 }
652 
653 /******************************************************************************
654  * NETCON_send
655  * Basically calls 'send()' unless we should use SSL
656  * number of chars send is put in *sent
657  */
658 DWORD NETCON_send(netconn_t *connection, const void *msg, size_t len, int flags,
659 		int *sent /* out */)
660 {
661     /* send is always blocking. */
662     set_socket_blocking(connection, TRUE);
663 
664     if(!connection->secure)
665     {
666 	*sent = sock_send(connection->socket, msg, len, flags);
667         return *sent == -1 ? WSAGetLastError() : ERROR_SUCCESS;
668     }
669     else
670     {
671         const BYTE *ptr = msg;
672         size_t chunk_size;
673 
674         *sent = 0;
675 
676         while(len) {
677             chunk_size = min(len, connection->ssl_sizes.cbMaximumMessage);
678             if(!send_ssl_chunk(connection, ptr, chunk_size))
679                 return ERROR_INTERNET_SECURITY_CHANNEL_ERROR;
680 
681             *sent += chunk_size;
682             ptr += chunk_size;
683             len -= chunk_size;
684         }
685 
686         return ERROR_SUCCESS;
687     }
688 }
689 
690 static BOOL read_ssl_chunk(netconn_t *conn, void *buf, SIZE_T buf_size, BOOL blocking, SIZE_T *ret_size, BOOL *eof)
691 {
692     const SIZE_T ssl_buf_size = conn->ssl_sizes.cbHeader+conn->ssl_sizes.cbMaximumMessage+conn->ssl_sizes.cbTrailer;
693     SecBuffer bufs[4];
694     SecBufferDesc buf_desc = {SECBUFFER_VERSION, sizeof(bufs)/sizeof(*bufs), bufs};
695     SSIZE_T size, buf_len = 0;
696     int i;
697     SECURITY_STATUS res;
698 
699     assert(conn->extra_len < ssl_buf_size);
700 
701     if(conn->extra_len) {
702         memcpy(conn->ssl_buf, conn->extra_buf, conn->extra_len);
703         buf_len = conn->extra_len;
704         conn->extra_len = 0;
705         heap_free(conn->extra_buf);
706         conn->extra_buf = NULL;
707     }
708 
709     set_socket_blocking(conn, blocking && !buf_len);
710     size = sock_recv(conn->socket, conn->ssl_buf+buf_len, ssl_buf_size-buf_len, 0);
711     if(size < 0) {
712         if(!buf_len) {
713             if(WSAGetLastError() == WSAEWOULDBLOCK) {
714                 TRACE("would block\n");
715                 return WSAEWOULDBLOCK;
716             }
717             WARN("recv failed\n");
718             return ERROR_INTERNET_CONNECTION_ABORTED;
719         }
720     }else {
721         buf_len += size;
722     }
723 
724     if(!buf_len) {
725         TRACE("EOF\n");
726         *eof = TRUE;
727         *ret_size = 0;
728         return ERROR_SUCCESS;
729     }
730 
731     *eof = FALSE;
732 
733     do {
734         memset(bufs, 0, sizeof(bufs));
735         bufs[0].BufferType = SECBUFFER_DATA;
736         bufs[0].cbBuffer = buf_len;
737         bufs[0].pvBuffer = conn->ssl_buf;
738 
739         res = DecryptMessage(&conn->ssl_ctx, &buf_desc, 0, NULL);
740         switch(res) {
741         case SEC_E_OK:
742             break;
743         case SEC_I_CONTEXT_EXPIRED:
744             TRACE("context expired\n");
745             *eof = TRUE;
746             return ERROR_SUCCESS;
747         case SEC_E_INCOMPLETE_MESSAGE:
748             assert(buf_len < ssl_buf_size);
749 
750             set_socket_blocking(conn, blocking);
751             size = sock_recv(conn->socket, conn->ssl_buf+buf_len, ssl_buf_size-buf_len, 0);
752             if(size < 1) {
753                 if(size < 0 && WSAGetLastError() == WSAEWOULDBLOCK) {
754                     TRACE("would block\n");
755 
756                     /* FIXME: Optimize extra_buf usage. */
757                     conn->extra_buf = heap_alloc(buf_len);
758                     if(!conn->extra_buf)
759                         return ERROR_NOT_ENOUGH_MEMORY;
760 
761                     conn->extra_len = buf_len;
762                     memcpy(conn->extra_buf, conn->ssl_buf, conn->extra_len);
763                     return WSAEWOULDBLOCK;
764                 }
765 
766                 return ERROR_INTERNET_CONNECTION_ABORTED;
767             }
768 
769             buf_len += size;
770             continue;
771         default:
772             WARN("failed: %08x\n", res);
773             return ERROR_INTERNET_CONNECTION_ABORTED;
774         }
775     } while(res != SEC_E_OK);
776 
777     for(i=0; i < sizeof(bufs)/sizeof(*bufs); i++) {
778         if(bufs[i].BufferType == SECBUFFER_DATA) {
779             size = min(buf_size, bufs[i].cbBuffer);
780             memcpy(buf, bufs[i].pvBuffer, size);
781             if(size < bufs[i].cbBuffer) {
782                 assert(!conn->peek_len);
783                 conn->peek_msg_mem = conn->peek_msg = heap_alloc(bufs[i].cbBuffer - size);
784                 if(!conn->peek_msg)
785                     return ERROR_NOT_ENOUGH_MEMORY;
786                 conn->peek_len = bufs[i].cbBuffer-size;
787                 memcpy(conn->peek_msg, (char*)bufs[i].pvBuffer+size, conn->peek_len);
788             }
789 
790             *ret_size = size;
791         }
792     }
793 
794     for(i=0; i < sizeof(bufs)/sizeof(*bufs); i++) {
795         if(bufs[i].BufferType == SECBUFFER_EXTRA) {
796             conn->extra_buf = heap_alloc(bufs[i].cbBuffer);
797             if(!conn->extra_buf)
798                 return ERROR_NOT_ENOUGH_MEMORY;
799 
800             conn->extra_len = bufs[i].cbBuffer;
801             memcpy(conn->extra_buf, bufs[i].pvBuffer, conn->extra_len);
802         }
803     }
804 
805     return ERROR_SUCCESS;
806 }
807 
808 /******************************************************************************
809  * NETCON_recv
810  * Basically calls 'recv()' unless we should use SSL
811  * number of chars received is put in *recvd
812  */
813 DWORD NETCON_recv(netconn_t *connection, void *buf, size_t len, BOOL blocking, int *recvd)
814 {
815     *recvd = 0;
816     if (!len)
817         return ERROR_SUCCESS;
818 
819     if (!connection->secure)
820     {
821         set_socket_blocking(connection, blocking);
822         *recvd = sock_recv(connection->socket, buf, len, 0);
823         return *recvd == -1 ? WSAGetLastError() :  ERROR_SUCCESS;
824     }
825     else
826     {
827         SIZE_T size = 0;
828         BOOL eof;
829         DWORD res;
830 
831         if(connection->peek_msg) {
832             size = min(len, connection->peek_len);
833             memcpy(buf, connection->peek_msg, size);
834             connection->peek_len -= size;
835             connection->peek_msg += size;
836 
837             if(!connection->peek_len) {
838                 heap_free(connection->peek_msg_mem);
839                 connection->peek_msg_mem = connection->peek_msg = NULL;
840             }
841 
842             *recvd = size;
843             return ERROR_SUCCESS;
844         }
845 
846         do {
847             res = read_ssl_chunk(connection, (BYTE*)buf, len, blocking, &size, &eof);
848             if(res != ERROR_SUCCESS) {
849                 if(res == WSAEWOULDBLOCK) {
850                     if(size)
851                         res = ERROR_SUCCESS;
852                 }else {
853                     WARN("read_ssl_chunk failed\n");
854                 }
855                 break;
856             }
857         }while(!size && !eof);
858 
859         TRACE("received %ld bytes\n", size);
860         *recvd = size;
861         return res;
862     }
863 }
864 
865 BOOL NETCON_is_alive(netconn_t *netconn)
866 {
867     int len;
868     char b;
869 
870     set_socket_blocking(netconn, FALSE);
871     len = sock_recv(netconn->socket, &b, 1, MSG_PEEK);
872 
873     return len == 1 || (len == -1 && WSAGetLastError() == WSAEWOULDBLOCK);
874 }
875 
876 LPCVOID NETCON_GetCert(netconn_t *connection)
877 {
878     const CERT_CONTEXT *ret;
879     SECURITY_STATUS res;
880 
881     res = QueryContextAttributesW(&connection->ssl_ctx, SECPKG_ATTR_REMOTE_CERT_CONTEXT, (void*)&ret);
882     return res == SEC_E_OK ? ret : NULL;
883 }
884 
885 int NETCON_GetCipherStrength(netconn_t *connection)
886 {
887     SecPkgContext_ConnectionInfo conn_info;
888     SECURITY_STATUS res;
889 
890     if (!connection->secure)
891         return 0;
892 
893     res = QueryContextAttributesW(&connection->ssl_ctx, SECPKG_ATTR_CONNECTION_INFO, (void*)&conn_info);
894     if(res != SEC_E_OK)
895         WARN("QueryContextAttributesW failed: %08x\n", res);
896     return res == SEC_E_OK ? conn_info.dwCipherStrength : 0;
897 }
898 
899 DWORD NETCON_set_timeout(netconn_t *connection, BOOL send, DWORD value)
900 {
901     int result;
902 
903     result = setsockopt(connection->socket, SOL_SOCKET,
904                         send ? SO_SNDTIMEO : SO_RCVTIMEO, (void*)&value,
905                         sizeof(value));
906     if (result == -1)
907     {
908         WARN("setsockopt failed\n");
909         return WSAGetLastError();
910     }
911     return ERROR_SUCCESS;
912 }
913