1 //! Schannel TLS streams. 2 use std::any::Any; 3 use std::cmp; 4 use std::error::Error; 5 use std::fmt; 6 use std::io::{self, Read, BufRead, Write, Cursor}; 7 use std::mem; 8 use std::ptr; 9 use std::slice; 10 use std::sync::Arc; 11 use winapi::shared::minwindef as winapi; 12 use winapi::shared::{ntdef, sspi, winerror}; 13 use winapi::um::{self, wincrypt}; 14 15 use crate::{INIT_REQUESTS, ACCEPT_REQUESTS, Inner, secbuf, secbuf_desc}; 16 use crate::cert_chain::{CertChain, CertChainContext}; 17 use crate::cert_store::{CertAdd, CertStore}; 18 use crate::cert_context::CertContext; 19 use crate::security_context::SecurityContext; 20 use crate::context_buffer::ContextBuffer; 21 use crate::schannel_cred::SchannelCred; 22 23 lazy_static! { 24 static ref szOID_PKIX_KP_SERVER_AUTH: Vec<u8> = 25 wincrypt::szOID_PKIX_KP_SERVER_AUTH.bytes().chain(Some(0)).collect(); 26 static ref szOID_SERVER_GATED_CRYPTO: Vec<u8> = 27 wincrypt::szOID_SERVER_GATED_CRYPTO.bytes().chain(Some(0)).collect(); 28 static ref szOID_SGC_NETSCAPE: Vec<u8> = 29 wincrypt::szOID_SGC_NETSCAPE.bytes().chain(Some(0)).collect(); 30 } 31 32 /// A builder type for `TlsStream`s. 33 pub struct Builder { 34 domain: Option<Vec<u16>>, 35 use_sni: bool, 36 accept_invalid_hostnames: bool, 37 verify_callback: Option<Arc<dyn Fn(CertValidationResult) -> io::Result<()> + Sync + Send>>, 38 cert_store: Option<CertStore>, 39 } 40 41 impl Default for Builder { 42 fn default() -> Builder { 43 Builder { 44 domain: None, 45 use_sni: true, 46 accept_invalid_hostnames: false, 47 verify_callback: None, 48 cert_store: None, 49 } 50 } 51 } 52 53 impl Builder { 54 /// Returns a new `Builder`. 55 pub fn new() -> Builder { 56 Builder::default() 57 } 58 59 /// Sets the domain associated with connections created with this `Builder`. 60 /// 61 /// The domain will be used for Server Name Indication as well as 62 /// certificate validation. 63 pub fn domain(&mut self, domain: &str) -> &mut Builder { 64 self.domain = Some(domain.encode_utf16().chain(Some(0)).collect()); 65 self 66 } 67 68 /// Determines if Server Name Indication (SNI) will be used. 69 /// 70 /// Defaults to `true`. 71 pub fn use_sni(&mut self, use_sni: bool) -> &mut Builder { 72 self.use_sni = use_sni; 73 self 74 } 75 76 /// Determines if the server's hostname will be checked during certificate verification. 77 /// 78 /// Defaults to `false`. 79 pub fn accept_invalid_hostnames(&mut self, accept_invalid_hostnames: bool) -> &mut Builder { 80 self.accept_invalid_hostnames = accept_invalid_hostnames; 81 self 82 } 83 84 /// Set a verification callback to be used for connections created with this `Builder`. 85 /// 86 /// The callback is provided with an io::Result indicating if the (pre)validation was 87 /// successful. The Ok() variant indicates a successful validation while the Err() variant 88 /// contains the errorcode returned from the internal verification process. 89 /// The validated certificate, is accessible through the second argument of the closure. 90 pub fn verify_callback<F>(&mut self, callback: F) -> &mut Builder 91 where F: Fn(CertValidationResult) -> io::Result<()> + 'static + Sync + Send 92 { 93 self.verify_callback = Some(Arc::new(callback)); 94 self 95 } 96 97 /// Specifies a custom certificate store which is later used when validating 98 /// a server's certificate. 99 /// 100 /// This option is only used for client connections and is used to construct 101 /// the certificate chain which the server's certificate is validated 102 /// against. 103 /// 104 /// Note that adding certificates here means that they are 105 /// implicitly trusted. 106 pub fn cert_store(&mut self, cert_store: CertStore) -> &mut Builder { 107 self.cert_store = Some(cert_store); 108 self 109 } 110 111 /// Initialize a new TLS session where the stream provided will be 112 /// connecting to a remote TLS server. 113 /// 114 /// If the stream provided is a blocking stream then the entire handshake 115 /// will be performed if possible, but if the stream is in nonblocking mode 116 /// then a `HandshakeError::Interrupted` variant may be returned. This 117 /// type can then be extracted to later call 118 /// `MidHandshakeTlsStream::handshake` when data becomes available. 119 pub fn connect<S>(&mut self, 120 cred: SchannelCred, 121 stream: S) 122 -> Result<TlsStream<S>, HandshakeError<S>> 123 where S: Read + Write 124 { 125 self.initialize(cred, false, stream) 126 } 127 128 /// Initialize a new TLS session where the stream provided will be 129 /// accepting a connection. 130 /// 131 /// This method will tweak the protocol for "who talks first" and also 132 /// currently disables validation of the client that's connecting to us. 133 /// 134 /// If the stream provided is a blocking stream then the entire handshake 135 /// will be performed if possible, but if the stream is in nonblocking mode 136 /// then a `HandshakeError::Interrupted` variant may be returned. This 137 /// type can then be extracted to later call 138 /// `MidHandshakeTlsStream::handshake` when data becomes available. 139 pub fn accept<S>(&mut self, 140 cred: SchannelCred, 141 stream: S) 142 -> Result<TlsStream<S>, HandshakeError<S>> 143 where S: Read + Write 144 { 145 self.initialize(cred, true, stream) 146 } 147 148 fn initialize<S>(&mut self, 149 mut cred: SchannelCred, 150 server: bool, 151 stream: S) 152 -> Result<TlsStream<S>, HandshakeError<S>> 153 where S: Read + Write 154 { 155 let domain = match self.domain { 156 Some(ref domain) if self.use_sni => Some(&domain[..]), 157 _ => None, 158 }; 159 let (ctxt, buf) = match SecurityContext::initialize(&mut cred, 160 server, 161 domain) { 162 Ok(pair) => pair, 163 Err(e) => return Err(HandshakeError::Failure(e)), 164 }; 165 166 let stream = TlsStream { 167 cred: cred, 168 context: ctxt, 169 cert_store: self.cert_store.clone(), 170 domain: self.domain.clone(), 171 use_sni: self.use_sni, 172 accept_invalid_hostnames: self.accept_invalid_hostnames, 173 verify_callback: self.verify_callback.clone(), 174 stream: stream, 175 server: server, 176 accept_first: true, 177 state: State::Initializing { 178 needs_flush: false, 179 more_calls: true, 180 shutting_down: false, 181 validated: false, 182 }, 183 needs_read: 1, 184 dec_in: Cursor::new(Vec::new()), 185 enc_in: Cursor::new(Vec::new()), 186 out_buf: Cursor::new(buf.map(|b| b.to_owned()).unwrap_or(Vec::new())), 187 last_write_len: 0, 188 }; 189 190 MidHandshakeTlsStream { 191 inner: stream, 192 }.handshake() 193 } 194 } 195 196 enum State { 197 Initializing { 198 needs_flush: bool, 199 more_calls: bool, 200 shutting_down: bool, 201 validated: bool, 202 }, 203 Streaming { sizes: sspi::SecPkgContext_StreamSizes, }, 204 Shutdown, 205 } 206 207 /// An Schannel TLS stream. 208 pub struct TlsStream<S> { 209 cred: SchannelCred, 210 context: SecurityContext, 211 cert_store: Option<CertStore>, 212 domain: Option<Vec<u16>>, 213 use_sni: bool, 214 accept_invalid_hostnames: bool, 215 verify_callback: Option<Arc<dyn Fn(CertValidationResult) -> io::Result<()> + Sync + Send>>, 216 stream: S, 217 state: State, 218 server: bool, 219 accept_first: bool, 220 needs_read: usize, 221 // valid from position() to len() 222 dec_in: Cursor<Vec<u8>>, 223 // valid from 0 to position() 224 enc_in: Cursor<Vec<u8>>, 225 // valid from position() to len() 226 out_buf: Cursor<Vec<u8>>, 227 /// the (unencrypted) length of the last write call used to track writes 228 last_write_len: usize, 229 } 230 231 /// ensures that a TlsStream is always Sync/Send 232 fn _is_sync() { 233 fn sync<T: Sync + Send>() {} 234 sync::<TlsStream<()>>(); 235 } 236 237 /// A failure which can happen during the `Builder::initialize` phase, either an 238 /// I/O error or an intermediate stream which has not completed its handshake. 239 #[derive(Debug)] 240 pub enum HandshakeError<S> { 241 /// A fatal I/O error occurred 242 Failure(io::Error), 243 /// The stream connection is in progress, but the handshake is not completed 244 /// yet. 245 Interrupted(MidHandshakeTlsStream<S>), 246 } 247 248 /// A struct used to wrap various cert chain validation results for callback processing. 249 pub struct CertValidationResult { 250 chain: CertChainContext, 251 res: i32, 252 chain_index: i32, 253 element_index: i32, 254 } 255 256 impl CertValidationResult { 257 /// Returns the certificate that failed validation if applicable 258 pub fn failed_certificate(&self) -> Option<CertContext> { 259 if let Some(cert_chain) = self.chain.get_chain(self.chain_index as usize) { 260 return cert_chain.get(self.element_index as usize); 261 } 262 None 263 } 264 265 /// Returns the final certificate chain in the certificate context if applicable 266 pub fn chain(&self) -> Option<CertChain> { 267 self.chain.final_chain() 268 } 269 270 /// Returns the result of the built-in certificate verification process. 271 pub fn result(&self) -> io::Result<()> { 272 if self.res as u32 != winerror::ERROR_SUCCESS { 273 Err(io::Error::from_raw_os_error(self.res)) 274 } else { 275 Ok(()) 276 } 277 } 278 } 279 280 impl<S: fmt::Debug + Any> Error for HandshakeError<S> { 281 fn description(&self) -> &str { 282 match *self { 283 HandshakeError::Failure(_) => "failed to perform handshake", 284 HandshakeError::Interrupted(_) => "interrupted performing handshake", 285 } 286 } 287 288 fn cause(&self) -> Option<&dyn Error> { 289 match *self { 290 HandshakeError::Failure(ref e) => Some(e), 291 HandshakeError::Interrupted(_) => None, 292 } 293 } 294 } 295 296 impl<S: fmt::Debug + Any> fmt::Display for HandshakeError<S> { 297 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { 298 f.write_str(self.description())?; 299 if let Some(e) = self.source() { 300 write!(f, ": {}", e)?; 301 } 302 Ok(()) 303 } 304 } 305 306 /// A stream which has not yet completed its handshake. 307 #[derive(Debug)] 308 pub struct MidHandshakeTlsStream<S> { 309 inner: TlsStream<S>, 310 } 311 312 impl<S> fmt::Debug for TlsStream<S> 313 where S: fmt::Debug 314 { 315 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { 316 fmt.debug_struct("TlsStream") 317 .field("stream", &self.stream) 318 .finish() 319 } 320 } 321 322 impl<S> TlsStream<S> { 323 /// Returns a reference to the wrapped stream. 324 pub fn get_ref(&self) -> &S { 325 &self.stream 326 } 327 328 /// Returns a mutable reference to the wrapped stream. 329 pub fn get_mut(&mut self) -> &mut S { 330 &mut self.stream 331 } 332 333 /// Indicates if this stream is the server- or client-side of a TLS session. 334 pub fn is_server(&self) -> bool { 335 self.server 336 } 337 } 338 339 impl<S> TlsStream<S> 340 where S: Read + Write 341 { 342 /// Returns the certificate used to identify this side of the TLS session. 343 /// 344 /// Its associated cert store contains any intermediate certificates sent 345 /// along with the leaf. 346 pub fn certificate(&self) -> io::Result<CertContext> { 347 self.context.local_cert() 348 } 349 350 /// Returns the peer's certificate, if available. 351 /// 352 /// Its associated cert store contains any intermediate certificates sent 353 /// by the server. 354 pub fn peer_certificate(&self) -> io::Result<CertContext> { 355 self.context.remote_cert() 356 } 357 358 /// Returns a reference to the buffer of pending data. 359 /// 360 /// Like `BufRead::fill_buf` except that it will return an empty slice 361 /// rather than reading from the wrapped stream if there is no buffered 362 /// data. 363 pub fn get_buf(&self) -> &[u8] { 364 &self.dec_in.get_ref()[self.dec_in.position() as usize..] 365 } 366 367 /// Shuts the TLS session down. 368 pub fn shutdown(&mut self) -> io::Result<()> { 369 match self.state { 370 State::Shutdown => return Ok(()), 371 State::Initializing { shutting_down: true, .. } => {} 372 _ => { 373 unsafe { 374 let mut token = um::schannel::SCHANNEL_SHUTDOWN; 375 let ptr = &mut token as *mut _ as *mut u8; 376 let size = mem::size_of_val(&token); 377 let token = slice::from_raw_parts_mut(ptr, size); 378 let mut buf = [secbuf(sspi::SECBUFFER_TOKEN, Some(token))]; 379 let mut desc = secbuf_desc(&mut buf); 380 381 match sspi::ApplyControlToken(self.context.get_mut(), &mut desc) { 382 winerror::SEC_E_OK => {} 383 err => return Err(io::Error::from_raw_os_error(err as i32)), 384 } 385 } 386 387 self.state = State::Initializing { 388 needs_flush: false, 389 more_calls: true, 390 shutting_down: true, 391 validated: false, 392 }; 393 self.needs_read = 0; 394 } 395 } 396 397 self.initialize().map(|_| ()) 398 } 399 400 fn step_initialize(&mut self) -> io::Result<()> { 401 unsafe { 402 let pos = self.enc_in.position() as usize; 403 let mut inbufs = [secbuf(sspi::SECBUFFER_TOKEN, 404 Some(&mut self.enc_in.get_mut()[..pos])), 405 secbuf(sspi::SECBUFFER_EMPTY, None)]; 406 let mut inbuf_desc = secbuf_desc(&mut inbufs); 407 408 let mut outbufs = [secbuf(sspi::SECBUFFER_TOKEN, None), 409 secbuf(sspi::SECBUFFER_ALERT, None), 410 secbuf(sspi::SECBUFFER_EMPTY, None)]; 411 let mut outbuf_desc = secbuf_desc(&mut outbufs); 412 413 let mut attributes = 0; 414 415 let status = if self.server { 416 let ptr = if self.accept_first { 417 ptr::null_mut() 418 } else { 419 self.context.get_mut() 420 }; 421 sspi::AcceptSecurityContext(self.cred.get_mut(), 422 ptr, 423 &mut inbuf_desc, 424 ACCEPT_REQUESTS, 425 0, 426 self.context.get_mut(), 427 &mut outbuf_desc, 428 &mut attributes, 429 ptr::null_mut()) 430 } else { 431 let domain = match self.domain { 432 Some(ref domain) if self.use_sni => domain.as_ptr() as *mut u16, 433 _ => ptr::null_mut(), 434 }; 435 436 sspi::InitializeSecurityContextW(self.cred.get_mut(), 437 self.context.get_mut(), 438 domain, 439 INIT_REQUESTS, 440 0, 441 0, 442 &mut inbuf_desc, 443 0, 444 ptr::null_mut(), 445 &mut outbuf_desc, 446 &mut attributes, 447 ptr::null_mut()) 448 }; 449 450 for buf in &outbufs[1..] { 451 if !buf.pvBuffer.is_null() { 452 sspi::FreeContextBuffer(buf.pvBuffer); 453 } 454 } 455 456 match status { 457 winerror::SEC_I_CONTINUE_NEEDED => { 458 // Windows apparently doesn't like AcceptSecurityContext 459 // being called as if it were the second time unless the 460 // first call to AcceptSecurityContext succeeded with 461 // CONTINUE_NEEDED. 462 // 463 // In other words, if we were to set `accept_first` to 464 // `false` after the literal first call to 465 // `AcceptSecurityContext` while the call returned 466 // INCOMPLETE_MESSAGE, the next call would return an error. 467 // 468 // For that reason we only set `accept_first` to false here 469 // once we've actually successfully received the full 470 // "token" from the client. 471 self.accept_first = false; 472 let nread = if inbufs[1].BufferType == sspi::SECBUFFER_EXTRA { 473 self.enc_in.position() as usize - inbufs[1].cbBuffer as usize 474 } else { 475 self.enc_in.position() as usize 476 }; 477 let to_write = ContextBuffer(outbufs[0]); 478 479 self.consume_enc_in(nread); 480 self.needs_read = (self.enc_in.position() == 0) as usize; 481 self.out_buf.get_mut().extend_from_slice(&to_write); 482 } 483 winerror::SEC_E_INCOMPLETE_MESSAGE => { 484 self.needs_read = if inbufs[1].BufferType == sspi::SECBUFFER_MISSING { 485 inbufs[1].cbBuffer as usize 486 } else { 487 1 488 }; 489 } 490 winerror::SEC_E_OK => { 491 let nread = if inbufs[1].BufferType == sspi::SECBUFFER_EXTRA { 492 self.enc_in.position() as usize - inbufs[1].cbBuffer as usize 493 } else { 494 self.enc_in.position() as usize 495 }; 496 let to_write = if outbufs[0].pvBuffer.is_null() { 497 None 498 } else { 499 Some(ContextBuffer(outbufs[0])) 500 }; 501 502 self.consume_enc_in(nread); 503 self.needs_read = (self.enc_in.position() == 0) as usize; 504 if let Some(to_write) = to_write { 505 self.out_buf.get_mut().extend_from_slice(&to_write); 506 } 507 if self.enc_in.position() != 0 { 508 self.decrypt()?; 509 } 510 if let State::Initializing { ref mut more_calls, .. } = self.state { 511 *more_calls = false; 512 } 513 } 514 _ => { 515 return Err(io::Error::from_raw_os_error(status as i32)) 516 } 517 } 518 Ok(()) 519 } 520 } 521 522 fn initialize(&mut self) -> io::Result<Option<sspi::SecPkgContext_StreamSizes>> { 523 loop { 524 match self.state { 525 State::Initializing { mut needs_flush, more_calls, shutting_down, validated } => { 526 if self.write_out()? > 0 { 527 needs_flush = true; 528 if let State::Initializing { ref mut needs_flush, .. } = self.state { 529 *needs_flush = true; 530 } 531 } 532 533 if needs_flush { 534 self.stream.flush()?; 535 if let State::Initializing { ref mut needs_flush, .. } = self.state { 536 *needs_flush = false; 537 } 538 } 539 540 if !shutting_down && !validated { 541 // on the last call, we require a valid certificate 542 if self.validate(!more_calls)? { 543 if let State::Initializing { ref mut validated, .. } = self.state { 544 *validated = true; 545 } 546 } 547 } 548 549 if !more_calls { 550 self.state = if shutting_down { 551 State::Shutdown 552 } else { 553 State::Streaming { sizes: self.context.stream_sizes()? } 554 }; 555 continue; 556 } 557 558 if self.needs_read > 0 { 559 if self.read_in()? == 0 { 560 return Err(io::Error::new(io::ErrorKind::UnexpectedEof, 561 "unexpected EOF during handshake")); 562 } 563 } 564 565 self.step_initialize()?; 566 } 567 State::Streaming { sizes } => return Ok(Some(sizes)), 568 State::Shutdown => return Ok(None), 569 } 570 } 571 } 572 573 /// Returns true when the certificate was succesfully verified 574 /// Returns false, when a verification isn't necessary (yet) 575 /// Returns an error when the verification failed 576 fn validate(&mut self, require_cert: bool) -> io::Result<bool> { 577 // If we're accepting connections then we don't perform any validation 578 // for the remote certificate, that's what they're doing! 579 if self.server { 580 return Ok(false); 581 } 582 583 let cert_context = match self.context.remote_cert() { 584 Err(_) if !require_cert => return Ok(false), 585 ret => ret? 586 }; 587 588 let cert_chain = unsafe { 589 let cert_store = match (cert_context.cert_store(), &self.cert_store) { 590 (Some(ref mut chain_certs), &Some(ref extra_certs)) => { 591 for extra_cert in extra_certs.certs() { 592 chain_certs.add_cert(&extra_cert, CertAdd::ReplaceExisting)?; 593 } 594 chain_certs.as_inner() 595 }, 596 (Some(chain_certs), &None) => chain_certs.as_inner(), 597 (None, &Some(ref extra_certs)) => extra_certs.as_inner(), 598 (None, &None) => ptr::null_mut() 599 }; 600 601 let flags = wincrypt::CERT_CHAIN_CACHE_END_CERT | 602 wincrypt::CERT_CHAIN_REVOCATION_CHECK_CACHE_ONLY | 603 wincrypt::CERT_CHAIN_REVOCATION_CHECK_CHAIN_EXCLUDE_ROOT; 604 605 let mut para: wincrypt::CERT_CHAIN_PARA = mem::zeroed(); 606 para.cbSize = mem::size_of_val(¶) as winapi::DWORD; 607 para.RequestedUsage.dwType = wincrypt::USAGE_MATCH_TYPE_OR; 608 609 let mut identifiers = [szOID_PKIX_KP_SERVER_AUTH.as_ptr() as ntdef::LPSTR, 610 szOID_SERVER_GATED_CRYPTO.as_ptr() as ntdef::LPSTR, 611 szOID_SGC_NETSCAPE.as_ptr() as ntdef::LPSTR]; 612 para.RequestedUsage.Usage.cUsageIdentifier = identifiers.len() as winapi::DWORD; 613 para.RequestedUsage.Usage.rgpszUsageIdentifier = identifiers.as_mut_ptr(); 614 615 let mut cert_chain = mem::zeroed(); 616 617 let res = wincrypt::CertGetCertificateChain(ptr::null_mut(), 618 cert_context.as_inner(), 619 ptr::null_mut(), 620 cert_store, 621 &mut para, 622 flags, 623 ptr::null_mut(), 624 &mut cert_chain); 625 626 if res == winapi::TRUE { 627 CertChainContext(cert_chain as *mut _) 628 } else { 629 return Err(io::Error::last_os_error()) 630 } 631 }; 632 633 unsafe { 634 // check if we trust the root-CA explicitly 635 let mut para_flags = wincrypt::CERT_CHAIN_POLICY_IGNORE_ALL_REV_UNKNOWN_FLAGS; 636 if let Some(ref mut store) = self.cert_store { 637 if let Some(chain) = cert_chain.final_chain() { 638 // check if any cert of the chain is in the passed store (and therefore trusted) 639 if chain.certificates().any(|cert| store.certs().any(|root_cert| root_cert == cert)) { 640 para_flags |= wincrypt::CERT_CHAIN_POLICY_ALLOW_UNKNOWN_CA_FLAG; 641 } 642 } 643 } 644 645 let mut extra_para: wincrypt::SSL_EXTRA_CERT_CHAIN_POLICY_PARA = mem::zeroed(); 646 *extra_para.u.cbSize_mut() = mem::size_of_val(&extra_para) as winapi::DWORD; 647 extra_para.dwAuthType = wincrypt::AUTHTYPE_SERVER; 648 match self.domain { 649 Some(ref mut domain) if !self.accept_invalid_hostnames => { 650 extra_para.pwszServerName = domain.as_mut_ptr(); 651 } 652 _ => {} 653 } 654 655 let mut para: wincrypt::CERT_CHAIN_POLICY_PARA = mem::zeroed(); 656 para.cbSize = mem::size_of_val(¶) as winapi::DWORD; 657 para.dwFlags = para_flags; 658 para.pvExtraPolicyPara = &mut extra_para as *mut _ as *mut _; 659 660 let mut status: wincrypt::CERT_CHAIN_POLICY_STATUS = mem::zeroed(); 661 status.cbSize = mem::size_of_val(&status) as winapi::DWORD; 662 663 let verify_chain_policy_structure = wincrypt::CERT_CHAIN_POLICY_SSL as ntdef::LPCSTR; 664 let res = wincrypt::CertVerifyCertificateChainPolicy(verify_chain_policy_structure, 665 cert_chain.0, 666 &mut para, 667 &mut status); 668 if res == winapi::FALSE { 669 return Err(io::Error::last_os_error()) 670 } 671 672 let mut verify_result = if status.dwError != winerror::ERROR_SUCCESS { 673 Err(io::Error::from_raw_os_error(status.dwError as i32)) 674 } else { 675 Ok(()) 676 }; 677 678 // check if there's a user-specified verify callback 679 if let Some(ref callback) = self.verify_callback { 680 verify_result = callback(CertValidationResult{ 681 chain: cert_chain, 682 res: status.dwError as i32, 683 chain_index: status.lChainIndex, 684 element_index: status.lElementIndex}); 685 } 686 verify_result?; 687 } 688 Ok(true) 689 } 690 691 fn write_out(&mut self) -> io::Result<usize> { 692 let mut out = 0; 693 while self.out_buf.position() as usize != self.out_buf.get_ref().len() { 694 let position = self.out_buf.position() as usize; 695 let nwritten = self.stream.write(&self.out_buf.get_ref()[position..])?; 696 out += nwritten; 697 self.out_buf.set_position((position + nwritten) as u64); 698 } 699 700 Ok(out) 701 } 702 703 fn read_in(&mut self) -> io::Result<usize> { 704 let mut sum_nread = 0; 705 706 while self.needs_read > 0 { 707 let existing_len = self.enc_in.position() as usize; 708 let min_len = cmp::max(cmp::max(1024, 2 * existing_len), self.needs_read); 709 if self.enc_in.get_ref().len() < min_len { 710 self.enc_in.get_mut().resize(min_len, 0); 711 } 712 let nread = { 713 let buf = &mut self.enc_in.get_mut()[existing_len..]; 714 self.stream.read(buf)? 715 }; 716 self.enc_in.set_position((existing_len + nread) as u64); 717 self.needs_read = self.needs_read.saturating_sub(nread); 718 if nread == 0 { 719 break; 720 } 721 sum_nread += nread; 722 } 723 724 Ok(sum_nread) 725 } 726 727 fn consume_enc_in(&mut self, nread: usize) { 728 let size = self.enc_in.position() as usize; 729 assert!(size >= nread); 730 let count = size - nread; 731 732 if count > 0 { 733 self.enc_in.get_mut().drain(..nread); 734 } 735 736 self.enc_in.set_position(count as u64); 737 } 738 739 fn decrypt(&mut self) -> io::Result<bool> { 740 unsafe { 741 let position = self.enc_in.position() as usize; 742 let mut bufs = [secbuf(sspi::SECBUFFER_DATA, 743 Some(&mut self.enc_in.get_mut()[..position])), 744 secbuf(sspi::SECBUFFER_EMPTY, None), 745 secbuf(sspi::SECBUFFER_EMPTY, None), 746 secbuf(sspi::SECBUFFER_EMPTY, None)]; 747 let mut bufdesc = secbuf_desc(&mut bufs); 748 749 match sspi::DecryptMessage(self.context.get_mut(), 750 &mut bufdesc, 751 0, 752 ptr::null_mut()) { 753 winerror::SEC_E_OK => { 754 let start = bufs[1].pvBuffer as usize - self.enc_in.get_ref().as_ptr() as usize; 755 let end = start + bufs[1].cbBuffer as usize; 756 self.dec_in.get_mut().clear(); 757 self.dec_in 758 .get_mut() 759 .extend_from_slice(&self.enc_in.get_ref()[start..end]); 760 self.dec_in.set_position(0); 761 762 let nread = if bufs[3].BufferType == sspi::SECBUFFER_EXTRA { 763 self.enc_in.position() as usize - bufs[3].cbBuffer as usize 764 } else { 765 self.enc_in.position() as usize 766 }; 767 self.consume_enc_in(nread); 768 self.needs_read = (self.enc_in.position() == 0) as usize; 769 Ok(false) 770 } 771 winerror::SEC_E_INCOMPLETE_MESSAGE => { 772 self.needs_read = if bufs[1].BufferType == sspi::SECBUFFER_MISSING { 773 bufs[1].cbBuffer as usize 774 } else { 775 1 776 }; 777 Ok(false) 778 } 779 winerror::SEC_I_CONTEXT_EXPIRED => Ok(true), 780 winerror::SEC_I_RENEGOTIATE => { 781 self.state = State::Initializing { 782 needs_flush: false, 783 more_calls: true, 784 shutting_down: false, 785 validated: false, 786 }; 787 788 let nread = if bufs[3].BufferType == sspi::SECBUFFER_EXTRA { 789 self.enc_in.position() as usize - bufs[3].cbBuffer as usize 790 } else { 791 self.enc_in.position() as usize 792 }; 793 self.consume_enc_in(nread); 794 self.needs_read = 0; 795 Ok(false) 796 } 797 e => Err(io::Error::from_raw_os_error(e as i32)), 798 } 799 } 800 } 801 802 fn encrypt(&mut self, buf: &[u8], sizes: &sspi::SecPkgContext_StreamSizes) -> io::Result<()> { 803 assert!(buf.len() <= sizes.cbMaximumMessage as usize); 804 805 unsafe { 806 let len = sizes.cbHeader as usize + buf.len() + sizes.cbTrailer as usize; 807 808 if self.out_buf.get_ref().len() < len { 809 self.out_buf.get_mut().resize(len, 0); 810 } 811 812 let message_start = sizes.cbHeader as usize; 813 self.out_buf 814 .get_mut()[message_start..message_start + buf.len()] 815 .clone_from_slice(buf); 816 817 let mut bufs = { 818 let out_buf = self.out_buf.get_mut(); 819 let size = sizes.cbHeader as usize; 820 821 let header = secbuf(sspi::SECBUFFER_STREAM_HEADER, 822 Some(&mut out_buf[..size])); 823 let data = secbuf(sspi::SECBUFFER_DATA, 824 Some(&mut out_buf[size..size + buf.len()])); 825 let trailer = secbuf(sspi::SECBUFFER_STREAM_TRAILER, 826 Some(&mut out_buf[size + buf.len()..])); 827 let empty = secbuf(sspi::SECBUFFER_EMPTY, None); 828 [header, data, trailer, empty] 829 }; 830 let mut bufdesc = secbuf_desc(&mut bufs); 831 832 match sspi::EncryptMessage(self.context.get_mut(), 0, &mut bufdesc, 0) { 833 winerror::SEC_E_OK => { 834 let len = bufs[0].cbBuffer + bufs[1].cbBuffer + bufs[2].cbBuffer; 835 self.out_buf.get_mut().truncate(len as usize); 836 self.out_buf.set_position(0); 837 Ok(()) 838 } 839 err => Err(io::Error::from_raw_os_error(err as i32)), 840 } 841 } 842 } 843 } 844 845 impl<S> MidHandshakeTlsStream<S> { 846 /// Returns a shared reference to the inner stream. 847 pub fn get_ref(&self) -> &S { 848 self.inner.get_ref() 849 } 850 851 /// Returns a mutable reference to the inner stream. 852 pub fn get_mut(&mut self) -> &mut S { 853 self.inner.get_mut() 854 } 855 } 856 857 impl<S> MidHandshakeTlsStream<S> 858 where S: Read + Write, 859 { 860 /// Restarts the handshake process. 861 pub fn handshake(mut self) -> Result<TlsStream<S>, HandshakeError<S>> { 862 match self.inner.initialize() { 863 Ok(_) => Ok(self.inner), 864 Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { 865 Err(HandshakeError::Interrupted(self)) 866 } 867 Err(e) => Err(HandshakeError::Failure(e)), 868 } 869 } 870 } 871 872 impl<S> Write for TlsStream<S> 873 where S: Read + Write 874 { 875 /// In the case of a WouldBlock error, we expect another call 876 /// starting with the same input data 877 /// This is similar to the use of ACCEPT_MOVING_WRITE_BUFFER in openssl 878 fn write(&mut self, buf: &[u8]) -> io::Result<usize> { 879 let sizes = match self.initialize()? { 880 Some(sizes) => sizes, 881 None => return Err(io::Error::from_raw_os_error(winerror::SEC_E_CONTEXT_EXPIRED as i32)), 882 }; 883 884 // if we have pending output data, it must have been because a previous 885 // attempt to send this part of the data ran into an error. 886 if self.out_buf.position() == self.out_buf.get_ref().len() as u64 { 887 let len = cmp::min(buf.len(), sizes.cbMaximumMessage as usize); 888 self.encrypt(&buf[..len], &sizes)?; 889 self.last_write_len = len; 890 } 891 self.write_out()?; 892 893 Ok(self.last_write_len) 894 } 895 896 fn flush(&mut self) -> io::Result<()> { 897 // Make sure the write buffer is emptied 898 self.write_out()?; 899 self.stream.flush() 900 } 901 } 902 903 impl<S> Read for TlsStream<S> 904 where S: Read + Write 905 { 906 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> { 907 let nread = { 908 let read_buf = self.fill_buf()?; 909 let nread = cmp::min(buf.len(), read_buf.len()); 910 buf[..nread].copy_from_slice(&read_buf[..nread]); 911 nread 912 }; 913 self.consume(nread); 914 Ok(nread) 915 } 916 } 917 918 impl<S> BufRead for TlsStream<S> 919 where S: Read + Write 920 { 921 fn fill_buf(&mut self) -> io::Result<&[u8]> { 922 while self.get_buf().is_empty() { 923 if let None = self.initialize()? { 924 break; 925 } 926 927 if self.needs_read > 0 { 928 if self.read_in()? == 0 { 929 break; 930 } 931 self.needs_read = 0; 932 } 933 934 let eof = self.decrypt()?; 935 if eof { 936 break; 937 } 938 } 939 940 Ok(self.get_buf()) 941 } 942 943 fn consume(&mut self, amt: usize) { 944 let pos = self.dec_in.position() + amt as u64; 945 assert!(pos <= self.dec_in.get_ref().len() as u64); 946 self.dec_in.set_position(pos); 947 } 948 } 949