1 //! Schannel TLS streams. 2 use std::any::Any; 3 use std::cmp; 4 use std::error::Error; 5 use std::fmt; 6 use std::io::{self, Read, BufRead, Write, Cursor}; 7 use std::mem; 8 use std::ptr; 9 use std::slice; 10 use std::sync::Arc; 11 use winapi::shared::minwindef as winapi; 12 use winapi::shared::{ntdef, sspi, winerror}; 13 use winapi::um::{self, schannel, wincrypt}; 14 15 use crate::{INIT_REQUESTS, ACCEPT_REQUESTS, Inner, secbuf, secbuf_desc}; 16 use crate::alpn_list::AlpnList; 17 use crate::cert_chain::{CertChain, CertChainContext}; 18 use crate::cert_store::{CertAdd, CertStore}; 19 use crate::cert_context::CertContext; 20 use crate::security_context::SecurityContext; 21 use crate::context_buffer::ContextBuffer; 22 use crate::schannel_cred::SchannelCred; 23 24 lazy_static! { 25 static ref szOID_PKIX_KP_SERVER_AUTH: Vec<u8> = 26 wincrypt::szOID_PKIX_KP_SERVER_AUTH.bytes().chain(Some(0)).collect(); 27 static ref szOID_SERVER_GATED_CRYPTO: Vec<u8> = 28 wincrypt::szOID_SERVER_GATED_CRYPTO.bytes().chain(Some(0)).collect(); 29 static ref szOID_SGC_NETSCAPE: Vec<u8> = 30 wincrypt::szOID_SGC_NETSCAPE.bytes().chain(Some(0)).collect(); 31 } 32 33 /// A builder type for `TlsStream`s. 34 pub struct Builder { 35 domain: Option<Vec<u16>>, 36 use_sni: bool, 37 accept_invalid_hostnames: bool, 38 verify_callback: Option<Arc<dyn Fn(CertValidationResult) -> io::Result<()> + Sync + Send>>, 39 cert_store: Option<CertStore>, 40 requested_application_protocols: Option<Vec<Vec<u8>>>, 41 } 42 43 impl Default for Builder { 44 fn default() -> Builder { 45 Builder { 46 domain: None, 47 use_sni: true, 48 accept_invalid_hostnames: false, 49 verify_callback: None, 50 cert_store: None, 51 requested_application_protocols: None, 52 } 53 } 54 } 55 56 impl Builder { 57 /// Returns a new `Builder`. 58 pub fn new() -> Builder { 59 Builder::default() 60 } 61 62 /// Sets the domain associated with connections created with this `Builder`. 63 /// 64 /// The domain will be used for Server Name Indication as well as 65 /// certificate validation. 66 pub fn domain(&mut self, domain: &str) -> &mut Builder { 67 self.domain = Some(domain.encode_utf16().chain(Some(0)).collect()); 68 self 69 } 70 71 /// Determines if Server Name Indication (SNI) will be used. 72 /// 73 /// Defaults to `true`. 74 pub fn use_sni(&mut self, use_sni: bool) -> &mut Builder { 75 self.use_sni = use_sni; 76 self 77 } 78 79 /// Determines if the server's hostname will be checked during certificate verification. 80 /// 81 /// Defaults to `false`. 82 pub fn accept_invalid_hostnames(&mut self, accept_invalid_hostnames: bool) -> &mut Builder { 83 self.accept_invalid_hostnames = accept_invalid_hostnames; 84 self 85 } 86 87 /// Set a verification callback to be used for connections created with this `Builder`. 88 /// 89 /// The callback is provided with an io::Result indicating if the (pre)validation was 90 /// successful. The Ok() variant indicates a successful validation while the Err() variant 91 /// contains the errorcode returned from the internal verification process. 92 /// The validated certificate, is accessible through the second argument of the closure. 93 pub fn verify_callback<F>(&mut self, callback: F) -> &mut Builder 94 where F: Fn(CertValidationResult) -> io::Result<()> + 'static + Sync + Send 95 { 96 self.verify_callback = Some(Arc::new(callback)); 97 self 98 } 99 100 /// Specifies a custom certificate store which is later used when validating 101 /// a server's certificate. 102 /// 103 /// This option is only used for client connections and is used to construct 104 /// the certificate chain which the server's certificate is validated 105 /// against. 106 /// 107 /// Note that adding certificates here means that they are 108 /// implicitly trusted. 109 pub fn cert_store(&mut self, cert_store: CertStore) -> &mut Builder { 110 self.cert_store = Some(cert_store); 111 self 112 } 113 114 /// Requests one of a set of application protocols using alpn 115 pub fn request_application_protocols(&mut self, alpns: &[&[u8]]) -> &mut Builder { 116 self.requested_application_protocols = 117 Some(alpns.iter().map(|bytes| bytes.to_vec()).collect::<Vec<_>>()); 118 self 119 } 120 121 /// Initialize a new TLS session where the stream provided will be 122 /// connecting to a remote TLS server. 123 /// 124 /// If the stream provided is a blocking stream then the entire handshake 125 /// will be performed if possible, but if the stream is in nonblocking mode 126 /// then a `HandshakeError::Interrupted` variant may be returned. This 127 /// type can then be extracted to later call 128 /// `MidHandshakeTlsStream::handshake` when data becomes available. 129 pub fn connect<S>(&mut self, 130 cred: SchannelCred, 131 stream: S) 132 -> Result<TlsStream<S>, HandshakeError<S>> 133 where S: Read + Write 134 { 135 self.initialize(cred, false, stream) 136 } 137 138 /// Initialize a new TLS session where the stream provided will be 139 /// accepting a connection. 140 /// 141 /// This method will tweak the protocol for "who talks first" and also 142 /// currently disables validation of the client that's connecting to us. 143 /// 144 /// If the stream provided is a blocking stream then the entire handshake 145 /// will be performed if possible, but if the stream is in nonblocking mode 146 /// then a `HandshakeError::Interrupted` variant may be returned. This 147 /// type can then be extracted to later call 148 /// `MidHandshakeTlsStream::handshake` when data becomes available. 149 pub fn accept<S>(&mut self, 150 cred: SchannelCred, 151 stream: S) 152 -> Result<TlsStream<S>, HandshakeError<S>> 153 where S: Read + Write 154 { 155 self.initialize(cred, true, stream) 156 } 157 158 fn initialize<S>(&mut self, 159 mut cred: SchannelCred, 160 server: bool, 161 stream: S) 162 -> Result<TlsStream<S>, HandshakeError<S>> 163 where S: Read + Write 164 { 165 let domain = match self.domain { 166 Some(ref domain) if self.use_sni => Some(&domain[..]), 167 _ => None, 168 }; 169 let (ctxt, buf) = match SecurityContext::initialize(&mut cred, 170 server, 171 domain, 172 &self.requested_application_protocols) { 173 Ok(pair) => pair, 174 Err(e) => return Err(HandshakeError::Failure(e)), 175 }; 176 177 let stream = TlsStream { 178 cred: cred, 179 context: ctxt, 180 cert_store: self.cert_store.clone(), 181 domain: self.domain.clone(), 182 use_sni: self.use_sni, 183 accept_invalid_hostnames: self.accept_invalid_hostnames, 184 verify_callback: self.verify_callback.clone(), 185 stream: stream, 186 server: server, 187 accept_first: true, 188 state: State::Initializing { 189 needs_flush: false, 190 more_calls: true, 191 shutting_down: false, 192 validated: false, 193 }, 194 needs_read: 1, 195 dec_in: Cursor::new(Vec::new()), 196 enc_in: Cursor::new(Vec::new()), 197 out_buf: Cursor::new(buf.map(|b| b.to_owned()).unwrap_or(Vec::new())), 198 last_write_len: 0, 199 requested_application_protocols: self.requested_application_protocols.clone(), 200 }; 201 202 MidHandshakeTlsStream { 203 inner: stream, 204 }.handshake() 205 } 206 } 207 208 enum State { 209 Initializing { 210 needs_flush: bool, 211 more_calls: bool, 212 shutting_down: bool, 213 validated: bool, 214 }, 215 Streaming { sizes: sspi::SecPkgContext_StreamSizes, }, 216 Shutdown, 217 } 218 219 /// An Schannel TLS stream. 220 pub struct TlsStream<S> { 221 cred: SchannelCred, 222 context: SecurityContext, 223 cert_store: Option<CertStore>, 224 domain: Option<Vec<u16>>, 225 use_sni: bool, 226 accept_invalid_hostnames: bool, 227 verify_callback: Option<Arc<dyn Fn(CertValidationResult) -> io::Result<()> + Sync + Send>>, 228 stream: S, 229 state: State, 230 server: bool, 231 accept_first: bool, 232 needs_read: usize, 233 // valid from position() to len() 234 dec_in: Cursor<Vec<u8>>, 235 // valid from 0 to position() 236 enc_in: Cursor<Vec<u8>>, 237 // valid from position() to len() 238 out_buf: Cursor<Vec<u8>>, 239 /// the (unencrypted) length of the last write call used to track writes 240 last_write_len: usize, 241 requested_application_protocols: Option<Vec<Vec<u8>>>, 242 } 243 244 /// ensures that a TlsStream is always Sync/Send 245 fn _is_sync() { 246 fn sync<T: Sync + Send>() {} 247 sync::<TlsStream<()>>(); 248 } 249 250 /// A failure which can happen during the `Builder::initialize` phase, either an 251 /// I/O error or an intermediate stream which has not completed its handshake. 252 #[derive(Debug)] 253 pub enum HandshakeError<S> { 254 /// A fatal I/O error occurred 255 Failure(io::Error), 256 /// The stream connection is in progress, but the handshake is not completed 257 /// yet. 258 Interrupted(MidHandshakeTlsStream<S>), 259 } 260 261 /// A struct used to wrap various cert chain validation results for callback processing. 262 pub struct CertValidationResult { 263 chain: CertChainContext, 264 res: i32, 265 chain_index: i32, 266 element_index: i32, 267 } 268 269 impl CertValidationResult { 270 /// Returns the certificate that failed validation if applicable 271 pub fn failed_certificate(&self) -> Option<CertContext> { 272 if let Some(cert_chain) = self.chain.get_chain(self.chain_index as usize) { 273 return cert_chain.get(self.element_index as usize); 274 } 275 None 276 } 277 278 /// Returns the final certificate chain in the certificate context if applicable 279 pub fn chain(&self) -> Option<CertChain> { 280 self.chain.final_chain() 281 } 282 283 /// Returns the result of the built-in certificate verification process. 284 pub fn result(&self) -> io::Result<()> { 285 if self.res as u32 != winerror::ERROR_SUCCESS { 286 Err(io::Error::from_raw_os_error(self.res)) 287 } else { 288 Ok(()) 289 } 290 } 291 } 292 293 impl<S: fmt::Debug + Any> Error for HandshakeError<S> { 294 fn source(&self) -> Option<&(dyn Error + 'static)> { 295 match *self { 296 HandshakeError::Failure(ref e) => Some(e), 297 HandshakeError::Interrupted(_) => None, 298 } 299 } 300 } 301 302 impl<S: fmt::Debug + Any> fmt::Display for HandshakeError<S> { 303 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { 304 let desc = match *self { 305 HandshakeError::Failure(_) => "failed to perform handshake", 306 HandshakeError::Interrupted(_) => "interrupted performing handshake", 307 }; 308 write!(f, "{}", desc)?; 309 if let Some(e) = self.source() { 310 write!(f, ": {}", e)?; 311 } 312 Ok(()) 313 } 314 } 315 316 /// A stream which has not yet completed its handshake. 317 #[derive(Debug)] 318 pub struct MidHandshakeTlsStream<S> { 319 inner: TlsStream<S>, 320 } 321 322 impl<S> fmt::Debug for TlsStream<S> 323 where S: fmt::Debug 324 { 325 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { 326 fmt.debug_struct("TlsStream") 327 .field("stream", &self.stream) 328 .finish() 329 } 330 } 331 332 impl<S> TlsStream<S> { 333 /// Returns a reference to the wrapped stream. 334 pub fn get_ref(&self) -> &S { 335 &self.stream 336 } 337 338 /// Returns a mutable reference to the wrapped stream. 339 pub fn get_mut(&mut self) -> &mut S { 340 &mut self.stream 341 } 342 343 /// Indicates if this stream is the server- or client-side of a TLS session. 344 pub fn is_server(&self) -> bool { 345 self.server 346 } 347 } 348 349 impl<S> TlsStream<S> 350 where S: Read + Write 351 { 352 /// Returns the certificate used to identify this side of the TLS session. 353 /// 354 /// Its associated cert store contains any intermediate certificates sent 355 /// along with the leaf. 356 pub fn certificate(&self) -> io::Result<CertContext> { 357 self.context.local_cert() 358 } 359 360 /// Returns the peer's certificate, if available. 361 /// 362 /// Its associated cert store contains any intermediate certificates sent 363 /// by the server. 364 pub fn peer_certificate(&self) -> io::Result<CertContext> { 365 self.context.remote_cert() 366 } 367 368 /// Returns the negotiated application protocol for this tls stream, if one exists 369 pub fn negotiated_application_protocol(&self) -> io::Result<Option<Vec<u8>>> { 370 let client_proto = self.context.application_protocol()?; 371 if client_proto.ProtoNegoStatus != sspi::SecApplicationProtocolNegotiationStatus_Success 372 || client_proto.ProtoNegoExt != sspi::SecApplicationProtocolNegotiationExt_ALPN 373 { 374 return Ok(None); 375 } 376 Ok(Some(client_proto.ProtocolId[..client_proto.ProtocolIdSize as usize].to_vec())) 377 } 378 379 /// Returns whether or not the session was resumed. 380 pub fn session_resumed(&self) -> io::Result<bool> { 381 let session_info = self.context.session_info()?; 382 Ok(session_info.dwFlags & schannel::SSL_SESSION_RECONNECT > 0) 383 } 384 385 /// Returns a reference to the buffer of pending data. 386 /// 387 /// Like `BufRead::fill_buf` except that it will return an empty slice 388 /// rather than reading from the wrapped stream if there is no buffered 389 /// data. 390 pub fn get_buf(&self) -> &[u8] { 391 &self.dec_in.get_ref()[self.dec_in.position() as usize..] 392 } 393 394 /// Shuts the TLS session down. 395 pub fn shutdown(&mut self) -> io::Result<()> { 396 match self.state { 397 State::Shutdown => return Ok(()), 398 State::Initializing { shutting_down: true, .. } => {} 399 _ => { 400 unsafe { 401 let mut token = um::schannel::SCHANNEL_SHUTDOWN; 402 let ptr = &mut token as *mut _ as *mut u8; 403 let size = mem::size_of_val(&token); 404 let token = slice::from_raw_parts_mut(ptr, size); 405 let mut buf = [secbuf(sspi::SECBUFFER_TOKEN, Some(token))]; 406 let mut desc = secbuf_desc(&mut buf); 407 408 match sspi::ApplyControlToken(self.context.get_mut(), &mut desc) { 409 winerror::SEC_E_OK => {} 410 err => return Err(io::Error::from_raw_os_error(err as i32)), 411 } 412 } 413 414 self.state = State::Initializing { 415 needs_flush: false, 416 more_calls: true, 417 shutting_down: true, 418 validated: false, 419 }; 420 self.needs_read = 0; 421 } 422 } 423 424 self.initialize().map(|_| ()) 425 } 426 427 fn step_initialize(&mut self) -> io::Result<()> { 428 unsafe { 429 let pos = self.enc_in.position() as usize; 430 let mut inbufs = vec![secbuf(sspi::SECBUFFER_TOKEN, 431 Some(&mut self.enc_in.get_mut()[..pos])), 432 secbuf(sspi::SECBUFFER_EMPTY, None)]; 433 // Make sure `AlpnList` is kept alive for the duration of this function. 434 let mut alpns = self.requested_application_protocols.as_ref().map(|alpn| AlpnList::new(&alpn)); 435 if let Some(ref mut alpns) = alpns { 436 inbufs.push(secbuf(sspi::SECBUFFER_APPLICATION_PROTOCOLS, 437 Some(&mut alpns[..]))); 438 }; 439 let mut inbuf_desc = secbuf_desc(&mut inbufs[..]); 440 441 let mut outbufs = [secbuf(sspi::SECBUFFER_TOKEN, None), 442 secbuf(sspi::SECBUFFER_ALERT, None), 443 secbuf(sspi::SECBUFFER_EMPTY, None)]; 444 let mut outbuf_desc = secbuf_desc(&mut outbufs); 445 446 let mut attributes = 0; 447 448 let status = if self.server { 449 let ptr = if self.accept_first { 450 ptr::null_mut() 451 } else { 452 self.context.get_mut() 453 }; 454 sspi::AcceptSecurityContext(&mut self.cred.as_inner(), 455 ptr, 456 &mut inbuf_desc, 457 ACCEPT_REQUESTS, 458 0, 459 self.context.get_mut(), 460 &mut outbuf_desc, 461 &mut attributes, 462 ptr::null_mut()) 463 } else { 464 let domain = match self.domain { 465 Some(ref domain) if self.use_sni => domain.as_ptr() as *mut u16, 466 _ => ptr::null_mut(), 467 }; 468 469 sspi::InitializeSecurityContextW(&mut self.cred.as_inner(), 470 self.context.get_mut(), 471 domain, 472 INIT_REQUESTS, 473 0, 474 0, 475 &mut inbuf_desc, 476 0, 477 ptr::null_mut(), 478 &mut outbuf_desc, 479 &mut attributes, 480 ptr::null_mut()) 481 }; 482 483 for buf in &outbufs[1..] { 484 if !buf.pvBuffer.is_null() { 485 sspi::FreeContextBuffer(buf.pvBuffer); 486 } 487 } 488 489 match status { 490 winerror::SEC_I_CONTINUE_NEEDED => { 491 // Windows apparently doesn't like AcceptSecurityContext 492 // being called as if it were the second time unless the 493 // first call to AcceptSecurityContext succeeded with 494 // CONTINUE_NEEDED. 495 // 496 // In other words, if we were to set `accept_first` to 497 // `false` after the literal first call to 498 // `AcceptSecurityContext` while the call returned 499 // INCOMPLETE_MESSAGE, the next call would return an error. 500 // 501 // For that reason we only set `accept_first` to false here 502 // once we've actually successfully received the full 503 // "token" from the client. 504 self.accept_first = false; 505 let nread = if inbufs[1].BufferType == sspi::SECBUFFER_EXTRA { 506 self.enc_in.position() as usize - inbufs[1].cbBuffer as usize 507 } else { 508 self.enc_in.position() as usize 509 }; 510 let to_write = ContextBuffer(outbufs[0]); 511 512 self.consume_enc_in(nread); 513 self.needs_read = (self.enc_in.position() == 0) as usize; 514 self.out_buf.get_mut().extend_from_slice(&to_write); 515 } 516 winerror::SEC_E_INCOMPLETE_MESSAGE => { 517 self.needs_read = if inbufs[1].BufferType == sspi::SECBUFFER_MISSING { 518 inbufs[1].cbBuffer as usize 519 } else { 520 1 521 }; 522 } 523 winerror::SEC_E_OK => { 524 let nread = if inbufs[1].BufferType == sspi::SECBUFFER_EXTRA { 525 self.enc_in.position() as usize - inbufs[1].cbBuffer as usize 526 } else { 527 self.enc_in.position() as usize 528 }; 529 let to_write = if outbufs[0].pvBuffer.is_null() { 530 None 531 } else { 532 Some(ContextBuffer(outbufs[0])) 533 }; 534 535 self.consume_enc_in(nread); 536 self.needs_read = (self.enc_in.position() == 0) as usize; 537 if let Some(to_write) = to_write { 538 self.out_buf.get_mut().extend_from_slice(&to_write); 539 } 540 if self.enc_in.position() != 0 { 541 self.decrypt()?; 542 } 543 if let State::Initializing { ref mut more_calls, .. } = self.state { 544 *more_calls = false; 545 } 546 } 547 _ => { 548 return Err(io::Error::from_raw_os_error(status as i32)) 549 } 550 } 551 Ok(()) 552 } 553 } 554 555 fn initialize(&mut self) -> io::Result<Option<sspi::SecPkgContext_StreamSizes>> { 556 loop { 557 match self.state { 558 State::Initializing { mut needs_flush, more_calls, shutting_down, validated } => { 559 if self.write_out()? > 0 { 560 needs_flush = true; 561 if let State::Initializing { ref mut needs_flush, .. } = self.state { 562 *needs_flush = true; 563 } 564 } 565 566 if needs_flush { 567 self.stream.flush()?; 568 if let State::Initializing { ref mut needs_flush, .. } = self.state { 569 *needs_flush = false; 570 } 571 } 572 573 if !shutting_down && !validated { 574 // on the last call, we require a valid certificate 575 if self.validate(!more_calls)? { 576 if let State::Initializing { ref mut validated, .. } = self.state { 577 *validated = true; 578 } 579 } 580 } 581 582 if !more_calls { 583 self.state = if shutting_down { 584 State::Shutdown 585 } else { 586 State::Streaming { sizes: self.context.stream_sizes()? } 587 }; 588 continue; 589 } 590 591 if self.needs_read > 0 { 592 if self.read_in()? == 0 { 593 return Err(io::Error::new(io::ErrorKind::UnexpectedEof, 594 "unexpected EOF during handshake")); 595 } 596 } 597 598 self.step_initialize()?; 599 } 600 State::Streaming { sizes } => return Ok(Some(sizes)), 601 State::Shutdown => return Ok(None), 602 } 603 } 604 } 605 606 /// Returns true when the certificate was succesfully verified 607 /// Returns false, when a verification isn't necessary (yet) 608 /// Returns an error when the verification failed 609 fn validate(&mut self, require_cert: bool) -> io::Result<bool> { 610 // If we're accepting connections then we don't perform any validation 611 // for the remote certificate, that's what they're doing! 612 if self.server { 613 return Ok(false); 614 } 615 616 let cert_context = match self.context.remote_cert() { 617 Err(_) if !require_cert => return Ok(false), 618 ret => ret? 619 }; 620 621 let cert_chain = unsafe { 622 let cert_store = match (cert_context.cert_store(), &self.cert_store) { 623 (Some(ref mut chain_certs), &Some(ref extra_certs)) => { 624 for extra_cert in extra_certs.certs() { 625 chain_certs.add_cert(&extra_cert, CertAdd::ReplaceExisting)?; 626 } 627 chain_certs.as_inner() 628 }, 629 (Some(chain_certs), &None) => chain_certs.as_inner(), 630 (None, &Some(ref extra_certs)) => extra_certs.as_inner(), 631 (None, &None) => ptr::null_mut() 632 }; 633 634 let flags = wincrypt::CERT_CHAIN_CACHE_END_CERT | 635 wincrypt::CERT_CHAIN_REVOCATION_CHECK_CACHE_ONLY | 636 wincrypt::CERT_CHAIN_REVOCATION_CHECK_CHAIN_EXCLUDE_ROOT; 637 638 let mut para: wincrypt::CERT_CHAIN_PARA = mem::zeroed(); 639 para.cbSize = mem::size_of_val(¶) as winapi::DWORD; 640 para.RequestedUsage.dwType = wincrypt::USAGE_MATCH_TYPE_OR; 641 642 let mut identifiers = [szOID_PKIX_KP_SERVER_AUTH.as_ptr() as ntdef::LPSTR, 643 szOID_SERVER_GATED_CRYPTO.as_ptr() as ntdef::LPSTR, 644 szOID_SGC_NETSCAPE.as_ptr() as ntdef::LPSTR]; 645 para.RequestedUsage.Usage.cUsageIdentifier = identifiers.len() as winapi::DWORD; 646 para.RequestedUsage.Usage.rgpszUsageIdentifier = identifiers.as_mut_ptr(); 647 648 let mut cert_chain = mem::zeroed(); 649 650 let res = wincrypt::CertGetCertificateChain(ptr::null_mut(), 651 cert_context.as_inner(), 652 ptr::null_mut(), 653 cert_store, 654 &mut para, 655 flags, 656 ptr::null_mut(), 657 &mut cert_chain); 658 659 if res == winapi::TRUE { 660 CertChainContext(cert_chain as *mut _) 661 } else { 662 return Err(io::Error::last_os_error()) 663 } 664 }; 665 666 unsafe { 667 // check if we trust the root-CA explicitly 668 let mut para_flags = wincrypt::CERT_CHAIN_POLICY_IGNORE_ALL_REV_UNKNOWN_FLAGS; 669 if let Some(ref mut store) = self.cert_store { 670 if let Some(chain) = cert_chain.final_chain() { 671 // check if any cert of the chain is in the passed store (and therefore trusted) 672 if chain.certificates().any(|cert| store.certs().any(|root_cert| root_cert == cert)) { 673 para_flags |= wincrypt::CERT_CHAIN_POLICY_ALLOW_UNKNOWN_CA_FLAG; 674 } 675 } 676 } 677 678 let mut extra_para: wincrypt::SSL_EXTRA_CERT_CHAIN_POLICY_PARA = mem::zeroed(); 679 *extra_para.u.cbSize_mut() = mem::size_of_val(&extra_para) as winapi::DWORD; 680 extra_para.dwAuthType = wincrypt::AUTHTYPE_SERVER; 681 match self.domain { 682 Some(ref mut domain) if !self.accept_invalid_hostnames => { 683 extra_para.pwszServerName = domain.as_mut_ptr(); 684 } 685 _ => {} 686 } 687 688 let mut para: wincrypt::CERT_CHAIN_POLICY_PARA = mem::zeroed(); 689 para.cbSize = mem::size_of_val(¶) as winapi::DWORD; 690 para.dwFlags = para_flags; 691 para.pvExtraPolicyPara = &mut extra_para as *mut _ as *mut _; 692 693 let mut status: wincrypt::CERT_CHAIN_POLICY_STATUS = mem::zeroed(); 694 status.cbSize = mem::size_of_val(&status) as winapi::DWORD; 695 696 let verify_chain_policy_structure = wincrypt::CERT_CHAIN_POLICY_SSL as ntdef::LPCSTR; 697 let res = wincrypt::CertVerifyCertificateChainPolicy(verify_chain_policy_structure, 698 cert_chain.0, 699 &mut para, 700 &mut status); 701 if res == winapi::FALSE { 702 return Err(io::Error::last_os_error()) 703 } 704 705 let mut verify_result = if status.dwError != winerror::ERROR_SUCCESS { 706 Err(io::Error::from_raw_os_error(status.dwError as i32)) 707 } else { 708 Ok(()) 709 }; 710 711 // check if there's a user-specified verify callback 712 if let Some(ref callback) = self.verify_callback { 713 verify_result = callback(CertValidationResult{ 714 chain: cert_chain, 715 res: status.dwError as i32, 716 chain_index: status.lChainIndex, 717 element_index: status.lElementIndex}); 718 } 719 verify_result?; 720 } 721 Ok(true) 722 } 723 724 fn write_out(&mut self) -> io::Result<usize> { 725 let mut out = 0; 726 while self.out_buf.position() as usize != self.out_buf.get_ref().len() { 727 let position = self.out_buf.position() as usize; 728 let nwritten = self.stream.write(&self.out_buf.get_ref()[position..])?; 729 out += nwritten; 730 self.out_buf.set_position((position + nwritten) as u64); 731 } 732 733 Ok(out) 734 } 735 736 fn read_in(&mut self) -> io::Result<usize> { 737 let mut sum_nread = 0; 738 739 while self.needs_read > 0 { 740 let existing_len = self.enc_in.position() as usize; 741 let min_len = cmp::max(cmp::max(1024, 2 * existing_len), self.needs_read); 742 if self.enc_in.get_ref().len() < min_len { 743 self.enc_in.get_mut().resize(min_len, 0); 744 } 745 let nread = { 746 let buf = &mut self.enc_in.get_mut()[existing_len..]; 747 self.stream.read(buf)? 748 }; 749 self.enc_in.set_position((existing_len + nread) as u64); 750 self.needs_read = self.needs_read.saturating_sub(nread); 751 if nread == 0 { 752 break; 753 } 754 sum_nread += nread; 755 } 756 757 Ok(sum_nread) 758 } 759 760 fn consume_enc_in(&mut self, nread: usize) { 761 let size = self.enc_in.position() as usize; 762 assert!(size >= nread); 763 let count = size - nread; 764 765 if count > 0 { 766 self.enc_in.get_mut().drain(..nread); 767 } 768 769 self.enc_in.set_position(count as u64); 770 } 771 772 fn decrypt(&mut self) -> io::Result<bool> { 773 unsafe { 774 let position = self.enc_in.position() as usize; 775 let mut bufs = [secbuf(sspi::SECBUFFER_DATA, 776 Some(&mut self.enc_in.get_mut()[..position])), 777 secbuf(sspi::SECBUFFER_EMPTY, None), 778 secbuf(sspi::SECBUFFER_EMPTY, None), 779 secbuf(sspi::SECBUFFER_EMPTY, None)]; 780 let mut bufdesc = secbuf_desc(&mut bufs); 781 782 match sspi::DecryptMessage(self.context.get_mut(), 783 &mut bufdesc, 784 0, 785 ptr::null_mut()) { 786 winerror::SEC_E_OK => { 787 let start = bufs[1].pvBuffer as usize - self.enc_in.get_ref().as_ptr() as usize; 788 let end = start + bufs[1].cbBuffer as usize; 789 self.dec_in.get_mut().clear(); 790 self.dec_in 791 .get_mut() 792 .extend_from_slice(&self.enc_in.get_ref()[start..end]); 793 self.dec_in.set_position(0); 794 795 let nread = if bufs[3].BufferType == sspi::SECBUFFER_EXTRA { 796 self.enc_in.position() as usize - bufs[3].cbBuffer as usize 797 } else { 798 self.enc_in.position() as usize 799 }; 800 self.consume_enc_in(nread); 801 self.needs_read = (self.enc_in.position() == 0) as usize; 802 Ok(false) 803 } 804 winerror::SEC_E_INCOMPLETE_MESSAGE => { 805 self.needs_read = if bufs[1].BufferType == sspi::SECBUFFER_MISSING { 806 bufs[1].cbBuffer as usize 807 } else { 808 1 809 }; 810 Ok(false) 811 } 812 winerror::SEC_I_CONTEXT_EXPIRED => Ok(true), 813 winerror::SEC_I_RENEGOTIATE => { 814 self.state = State::Initializing { 815 needs_flush: false, 816 more_calls: true, 817 shutting_down: false, 818 validated: false, 819 }; 820 821 let nread = if bufs[3].BufferType == sspi::SECBUFFER_EXTRA { 822 self.enc_in.position() as usize - bufs[3].cbBuffer as usize 823 } else { 824 self.enc_in.position() as usize 825 }; 826 self.consume_enc_in(nread); 827 self.needs_read = 0; 828 Ok(false) 829 } 830 e => Err(io::Error::from_raw_os_error(e as i32)), 831 } 832 } 833 } 834 835 fn encrypt(&mut self, buf: &[u8], sizes: &sspi::SecPkgContext_StreamSizes) -> io::Result<()> { 836 assert!(buf.len() <= sizes.cbMaximumMessage as usize); 837 838 unsafe { 839 let len = sizes.cbHeader as usize + buf.len() + sizes.cbTrailer as usize; 840 841 if self.out_buf.get_ref().len() < len { 842 self.out_buf.get_mut().resize(len, 0); 843 } 844 845 let message_start = sizes.cbHeader as usize; 846 self.out_buf 847 .get_mut()[message_start..message_start + buf.len()] 848 .clone_from_slice(buf); 849 850 let mut bufs = { 851 let out_buf = self.out_buf.get_mut(); 852 let size = sizes.cbHeader as usize; 853 854 let header = secbuf(sspi::SECBUFFER_STREAM_HEADER, 855 Some(&mut out_buf[..size])); 856 let data = secbuf(sspi::SECBUFFER_DATA, 857 Some(&mut out_buf[size..size + buf.len()])); 858 let trailer = secbuf(sspi::SECBUFFER_STREAM_TRAILER, 859 Some(&mut out_buf[size + buf.len()..])); 860 let empty = secbuf(sspi::SECBUFFER_EMPTY, None); 861 [header, data, trailer, empty] 862 }; 863 let mut bufdesc = secbuf_desc(&mut bufs); 864 865 match sspi::EncryptMessage(self.context.get_mut(), 0, &mut bufdesc, 0) { 866 winerror::SEC_E_OK => { 867 let len = bufs[0].cbBuffer + bufs[1].cbBuffer + bufs[2].cbBuffer; 868 self.out_buf.get_mut().truncate(len as usize); 869 self.out_buf.set_position(0); 870 Ok(()) 871 } 872 err => Err(io::Error::from_raw_os_error(err as i32)), 873 } 874 } 875 } 876 } 877 878 impl<S> MidHandshakeTlsStream<S> { 879 /// Returns a shared reference to the inner stream. 880 pub fn get_ref(&self) -> &S { 881 self.inner.get_ref() 882 } 883 884 /// Returns a mutable reference to the inner stream. 885 pub fn get_mut(&mut self) -> &mut S { 886 self.inner.get_mut() 887 } 888 } 889 890 impl<S> MidHandshakeTlsStream<S> 891 where S: Read + Write, 892 { 893 /// Restarts the handshake process. 894 pub fn handshake(mut self) -> Result<TlsStream<S>, HandshakeError<S>> { 895 match self.inner.initialize() { 896 Ok(_) => Ok(self.inner), 897 Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { 898 Err(HandshakeError::Interrupted(self)) 899 } 900 Err(e) => Err(HandshakeError::Failure(e)), 901 } 902 } 903 } 904 905 impl<S> Write for TlsStream<S> 906 where S: Read + Write 907 { 908 /// In the case of a WouldBlock error, we expect another call 909 /// starting with the same input data 910 /// This is similar to the use of ACCEPT_MOVING_WRITE_BUFFER in openssl 911 fn write(&mut self, buf: &[u8]) -> io::Result<usize> { 912 let sizes = match self.initialize()? { 913 Some(sizes) => sizes, 914 None => return Err(io::Error::from_raw_os_error(winerror::SEC_E_CONTEXT_EXPIRED as i32)), 915 }; 916 917 // if we have pending output data, it must have been because a previous 918 // attempt to send this part of the data ran into an error. 919 if self.out_buf.position() == self.out_buf.get_ref().len() as u64 { 920 let len = cmp::min(buf.len(), sizes.cbMaximumMessage as usize); 921 self.encrypt(&buf[..len], &sizes)?; 922 self.last_write_len = len; 923 } 924 self.write_out()?; 925 926 Ok(self.last_write_len) 927 } 928 929 fn flush(&mut self) -> io::Result<()> { 930 // Make sure the write buffer is emptied 931 self.write_out()?; 932 self.stream.flush() 933 } 934 } 935 936 impl<S> Read for TlsStream<S> 937 where S: Read + Write 938 { 939 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> { 940 let nread = { 941 let read_buf = self.fill_buf()?; 942 let nread = cmp::min(buf.len(), read_buf.len()); 943 buf[..nread].copy_from_slice(&read_buf[..nread]); 944 nread 945 }; 946 self.consume(nread); 947 Ok(nread) 948 } 949 } 950 951 impl<S> BufRead for TlsStream<S> 952 where S: Read + Write 953 { 954 fn fill_buf(&mut self) -> io::Result<&[u8]> { 955 while self.get_buf().is_empty() { 956 if let None = self.initialize()? { 957 break; 958 } 959 960 if self.needs_read > 0 { 961 if self.read_in()? == 0 { 962 break; 963 } 964 self.needs_read = 0; 965 } 966 967 let eof = self.decrypt()?; 968 if eof { 969 break; 970 } 971 } 972 973 Ok(self.get_buf()) 974 } 975 976 fn consume(&mut self, amt: usize) { 977 let pos = self.dec_in.position() + amt as u64; 978 assert!(pos <= self.dec_in.get_ref().len() as u64); 979 self.dec_in.set_position(pos); 980 } 981 } 982