1 //! Schannel TLS streams.
2 use std::any::Any;
3 use std::cmp;
4 use std::error::Error;
5 use std::fmt;
6 use std::io::{self, Read, BufRead, Write, Cursor};
7 use std::mem;
8 use std::ptr;
9 use std::slice;
10 use std::sync::Arc;
11 use winapi::shared::minwindef as winapi;
12 use winapi::shared::{ntdef, sspi, winerror};
13 use winapi::um::{self, wincrypt};
14
15 use crate::{INIT_REQUESTS, ACCEPT_REQUESTS, Inner, secbuf, secbuf_desc};
16 use crate::cert_chain::{CertChain, CertChainContext};
17 use crate::cert_store::{CertAdd, CertStore};
18 use crate::cert_context::CertContext;
19 use crate::security_context::SecurityContext;
20 use crate::context_buffer::ContextBuffer;
21 use crate::schannel_cred::SchannelCred;
22
23 lazy_static! {
24 static ref szOID_PKIX_KP_SERVER_AUTH: Vec<u8> =
25 wincrypt::szOID_PKIX_KP_SERVER_AUTH.bytes().chain(Some(0)).collect();
26 static ref szOID_SERVER_GATED_CRYPTO: Vec<u8> =
27 wincrypt::szOID_SERVER_GATED_CRYPTO.bytes().chain(Some(0)).collect();
28 static ref szOID_SGC_NETSCAPE: Vec<u8> =
29 wincrypt::szOID_SGC_NETSCAPE.bytes().chain(Some(0)).collect();
30 }
31
32 /// A builder type for `TlsStream`s.
33 pub struct Builder {
34 domain: Option<Vec<u16>>,
35 use_sni: bool,
36 accept_invalid_hostnames: bool,
37 verify_callback: Option<Arc<dyn Fn(CertValidationResult) -> io::Result<()> + Sync + Send>>,
38 cert_store: Option<CertStore>,
39 }
40
41 impl Default for Builder {
default() -> Builder42 fn default() -> Builder {
43 Builder {
44 domain: None,
45 use_sni: true,
46 accept_invalid_hostnames: false,
47 verify_callback: None,
48 cert_store: None,
49 }
50 }
51 }
52
53 impl Builder {
54 /// Returns a new `Builder`.
new() -> Builder55 pub fn new() -> Builder {
56 Builder::default()
57 }
58
59 /// Sets the domain associated with connections created with this `Builder`.
60 ///
61 /// The domain will be used for Server Name Indication as well as
62 /// certificate validation.
domain(&mut self, domain: &str) -> &mut Builder63 pub fn domain(&mut self, domain: &str) -> &mut Builder {
64 self.domain = Some(domain.encode_utf16().chain(Some(0)).collect());
65 self
66 }
67
68 /// Determines if Server Name Indication (SNI) will be used.
69 ///
70 /// Defaults to `true`.
use_sni(&mut self, use_sni: bool) -> &mut Builder71 pub fn use_sni(&mut self, use_sni: bool) -> &mut Builder {
72 self.use_sni = use_sni;
73 self
74 }
75
76 /// Determines if the server's hostname will be checked during certificate verification.
77 ///
78 /// Defaults to `false`.
accept_invalid_hostnames(&mut self, accept_invalid_hostnames: bool) -> &mut Builder79 pub fn accept_invalid_hostnames(&mut self, accept_invalid_hostnames: bool) -> &mut Builder {
80 self.accept_invalid_hostnames = accept_invalid_hostnames;
81 self
82 }
83
84 /// Set a verification callback to be used for connections created with this `Builder`.
85 ///
86 /// The callback is provided with an io::Result indicating if the (pre)validation was
87 /// successful. The Ok() variant indicates a successful validation while the Err() variant
88 /// contains the errorcode returned from the internal verification process.
89 /// The validated certificate, is accessible through the second argument of the closure.
verify_callback<F>(&mut self, callback: F) -> &mut Builder where F: Fn(CertValidationResult) -> io::Result<()> + 'static + Sync + Send90 pub fn verify_callback<F>(&mut self, callback: F) -> &mut Builder
91 where F: Fn(CertValidationResult) -> io::Result<()> + 'static + Sync + Send
92 {
93 self.verify_callback = Some(Arc::new(callback));
94 self
95 }
96
97 /// Specifies a custom certificate store which is later used when validating
98 /// a server's certificate.
99 ///
100 /// This option is only used for client connections and is used to construct
101 /// the certificate chain which the server's certificate is validated
102 /// against.
103 ///
104 /// Note that adding certificates here means that they are
105 /// implicitly trusted.
cert_store(&mut self, cert_store: CertStore) -> &mut Builder106 pub fn cert_store(&mut self, cert_store: CertStore) -> &mut Builder {
107 self.cert_store = Some(cert_store);
108 self
109 }
110
111 /// Initialize a new TLS session where the stream provided will be
112 /// connecting to a remote TLS server.
113 ///
114 /// If the stream provided is a blocking stream then the entire handshake
115 /// will be performed if possible, but if the stream is in nonblocking mode
116 /// then a `HandshakeError::Interrupted` variant may be returned. This
117 /// type can then be extracted to later call
118 /// `MidHandshakeTlsStream::handshake` when data becomes available.
connect<S>(&mut self, cred: SchannelCred, stream: S) -> Result<TlsStream<S>, HandshakeError<S>> where S: Read + Write119 pub fn connect<S>(&mut self,
120 cred: SchannelCred,
121 stream: S)
122 -> Result<TlsStream<S>, HandshakeError<S>>
123 where S: Read + Write
124 {
125 self.initialize(cred, false, stream)
126 }
127
128 /// Initialize a new TLS session where the stream provided will be
129 /// accepting a connection.
130 ///
131 /// This method will tweak the protocol for "who talks first" and also
132 /// currently disables validation of the client that's connecting to us.
133 ///
134 /// If the stream provided is a blocking stream then the entire handshake
135 /// will be performed if possible, but if the stream is in nonblocking mode
136 /// then a `HandshakeError::Interrupted` variant may be returned. This
137 /// type can then be extracted to later call
138 /// `MidHandshakeTlsStream::handshake` when data becomes available.
accept<S>(&mut self, cred: SchannelCred, stream: S) -> Result<TlsStream<S>, HandshakeError<S>> where S: Read + Write139 pub fn accept<S>(&mut self,
140 cred: SchannelCred,
141 stream: S)
142 -> Result<TlsStream<S>, HandshakeError<S>>
143 where S: Read + Write
144 {
145 self.initialize(cred, true, stream)
146 }
147
initialize<S>(&mut self, mut cred: SchannelCred, server: bool, stream: S) -> Result<TlsStream<S>, HandshakeError<S>> where S: Read + Write148 fn initialize<S>(&mut self,
149 mut cred: SchannelCred,
150 server: bool,
151 stream: S)
152 -> Result<TlsStream<S>, HandshakeError<S>>
153 where S: Read + Write
154 {
155 let domain = match self.domain {
156 Some(ref domain) if self.use_sni => Some(&domain[..]),
157 _ => None,
158 };
159 let (ctxt, buf) = match SecurityContext::initialize(&mut cred,
160 server,
161 domain) {
162 Ok(pair) => pair,
163 Err(e) => return Err(HandshakeError::Failure(e)),
164 };
165
166 let stream = TlsStream {
167 cred: cred,
168 context: ctxt,
169 cert_store: self.cert_store.clone(),
170 domain: self.domain.clone(),
171 use_sni: self.use_sni,
172 accept_invalid_hostnames: self.accept_invalid_hostnames,
173 verify_callback: self.verify_callback.clone(),
174 stream: stream,
175 server: server,
176 accept_first: true,
177 state: State::Initializing {
178 needs_flush: false,
179 more_calls: true,
180 shutting_down: false,
181 validated: false,
182 },
183 needs_read: 1,
184 dec_in: Cursor::new(Vec::new()),
185 enc_in: Cursor::new(Vec::new()),
186 out_buf: Cursor::new(buf.map(|b| b.to_owned()).unwrap_or(Vec::new())),
187 last_write_len: 0,
188 };
189
190 MidHandshakeTlsStream {
191 inner: stream,
192 }.handshake()
193 }
194 }
195
196 enum State {
197 Initializing {
198 needs_flush: bool,
199 more_calls: bool,
200 shutting_down: bool,
201 validated: bool,
202 },
203 Streaming { sizes: sspi::SecPkgContext_StreamSizes, },
204 Shutdown,
205 }
206
207 /// An Schannel TLS stream.
208 pub struct TlsStream<S> {
209 cred: SchannelCred,
210 context: SecurityContext,
211 cert_store: Option<CertStore>,
212 domain: Option<Vec<u16>>,
213 use_sni: bool,
214 accept_invalid_hostnames: bool,
215 verify_callback: Option<Arc<dyn Fn(CertValidationResult) -> io::Result<()> + Sync + Send>>,
216 stream: S,
217 state: State,
218 server: bool,
219 accept_first: bool,
220 needs_read: usize,
221 // valid from position() to len()
222 dec_in: Cursor<Vec<u8>>,
223 // valid from 0 to position()
224 enc_in: Cursor<Vec<u8>>,
225 // valid from position() to len()
226 out_buf: Cursor<Vec<u8>>,
227 /// the (unencrypted) length of the last write call used to track writes
228 last_write_len: usize,
229 }
230
231 /// ensures that a TlsStream is always Sync/Send
_is_sync()232 fn _is_sync() {
233 fn sync<T: Sync + Send>() {}
234 sync::<TlsStream<()>>();
235 }
236
237 /// A failure which can happen during the `Builder::initialize` phase, either an
238 /// I/O error or an intermediate stream which has not completed its handshake.
239 #[derive(Debug)]
240 pub enum HandshakeError<S> {
241 /// A fatal I/O error occurred
242 Failure(io::Error),
243 /// The stream connection is in progress, but the handshake is not completed
244 /// yet.
245 Interrupted(MidHandshakeTlsStream<S>),
246 }
247
248 /// A struct used to wrap various cert chain validation results for callback processing.
249 pub struct CertValidationResult {
250 chain: CertChainContext,
251 res: i32,
252 chain_index: i32,
253 element_index: i32,
254 }
255
256 impl CertValidationResult {
257 /// Returns the certificate that failed validation if applicable
failed_certificate(&self) -> Option<CertContext>258 pub fn failed_certificate(&self) -> Option<CertContext> {
259 if let Some(cert_chain) = self.chain.get_chain(self.chain_index as usize) {
260 return cert_chain.get(self.element_index as usize);
261 }
262 None
263 }
264
265 /// Returns the final certificate chain in the certificate context if applicable
chain(&self) -> Option<CertChain>266 pub fn chain(&self) -> Option<CertChain> {
267 self.chain.final_chain()
268 }
269
270 /// Returns the result of the built-in certificate verification process.
result(&self) -> io::Result<()>271 pub fn result(&self) -> io::Result<()> {
272 if self.res as u32 != winerror::ERROR_SUCCESS {
273 Err(io::Error::from_raw_os_error(self.res))
274 } else {
275 Ok(())
276 }
277 }
278 }
279
280 impl<S: fmt::Debug + Any> Error for HandshakeError<S> {
description(&self) -> &str281 fn description(&self) -> &str {
282 match *self {
283 HandshakeError::Failure(_) => "failed to perform handshake",
284 HandshakeError::Interrupted(_) => "interrupted performing handshake",
285 }
286 }
287
cause(&self) -> Option<&dyn Error>288 fn cause(&self) -> Option<&dyn Error> {
289 match *self {
290 HandshakeError::Failure(ref e) => Some(e),
291 HandshakeError::Interrupted(_) => None,
292 }
293 }
294 }
295
296 impl<S: fmt::Debug + Any> fmt::Display for HandshakeError<S> {
fmt(&self, f: &mut fmt::Formatter) -> fmt::Result297 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
298 f.write_str(self.description())?;
299 if let Some(e) = self.source() {
300 write!(f, ": {}", e)?;
301 }
302 Ok(())
303 }
304 }
305
306 /// A stream which has not yet completed its handshake.
307 #[derive(Debug)]
308 pub struct MidHandshakeTlsStream<S> {
309 inner: TlsStream<S>,
310 }
311
312 impl<S> fmt::Debug for TlsStream<S>
313 where S: fmt::Debug
314 {
fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result315 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
316 fmt.debug_struct("TlsStream")
317 .field("stream", &self.stream)
318 .finish()
319 }
320 }
321
322 impl<S> TlsStream<S> {
323 /// Returns a reference to the wrapped stream.
get_ref(&self) -> &S324 pub fn get_ref(&self) -> &S {
325 &self.stream
326 }
327
328 /// Returns a mutable reference to the wrapped stream.
get_mut(&mut self) -> &mut S329 pub fn get_mut(&mut self) -> &mut S {
330 &mut self.stream
331 }
332
333 /// Indicates if this stream is the server- or client-side of a TLS session.
is_server(&self) -> bool334 pub fn is_server(&self) -> bool {
335 self.server
336 }
337 }
338
339 impl<S> TlsStream<S>
340 where S: Read + Write
341 {
342 /// Returns the certificate used to identify this side of the TLS session.
343 ///
344 /// Its associated cert store contains any intermediate certificates sent
345 /// along with the leaf.
certificate(&self) -> io::Result<CertContext>346 pub fn certificate(&self) -> io::Result<CertContext> {
347 self.context.local_cert()
348 }
349
350 /// Returns the peer's certificate, if available.
351 ///
352 /// Its associated cert store contains any intermediate certificates sent
353 /// by the server.
peer_certificate(&self) -> io::Result<CertContext>354 pub fn peer_certificate(&self) -> io::Result<CertContext> {
355 self.context.remote_cert()
356 }
357
358 /// Returns a reference to the buffer of pending data.
359 ///
360 /// Like `BufRead::fill_buf` except that it will return an empty slice
361 /// rather than reading from the wrapped stream if there is no buffered
362 /// data.
get_buf(&self) -> &[u8]363 pub fn get_buf(&self) -> &[u8] {
364 &self.dec_in.get_ref()[self.dec_in.position() as usize..]
365 }
366
367 /// Shuts the TLS session down.
shutdown(&mut self) -> io::Result<()>368 pub fn shutdown(&mut self) -> io::Result<()> {
369 match self.state {
370 State::Shutdown => return Ok(()),
371 State::Initializing { shutting_down: true, .. } => {}
372 _ => {
373 unsafe {
374 let mut token = um::schannel::SCHANNEL_SHUTDOWN;
375 let ptr = &mut token as *mut _ as *mut u8;
376 let size = mem::size_of_val(&token);
377 let token = slice::from_raw_parts_mut(ptr, size);
378 let mut buf = [secbuf(sspi::SECBUFFER_TOKEN, Some(token))];
379 let mut desc = secbuf_desc(&mut buf);
380
381 match sspi::ApplyControlToken(self.context.get_mut(), &mut desc) {
382 winerror::SEC_E_OK => {}
383 err => return Err(io::Error::from_raw_os_error(err as i32)),
384 }
385 }
386
387 self.state = State::Initializing {
388 needs_flush: false,
389 more_calls: true,
390 shutting_down: true,
391 validated: false,
392 };
393 self.needs_read = 0;
394 }
395 }
396
397 self.initialize().map(|_| ())
398 }
399
step_initialize(&mut self) -> io::Result<()>400 fn step_initialize(&mut self) -> io::Result<()> {
401 unsafe {
402 let pos = self.enc_in.position() as usize;
403 let mut inbufs = [secbuf(sspi::SECBUFFER_TOKEN,
404 Some(&mut self.enc_in.get_mut()[..pos])),
405 secbuf(sspi::SECBUFFER_EMPTY, None)];
406 let mut inbuf_desc = secbuf_desc(&mut inbufs);
407
408 let mut outbufs = [secbuf(sspi::SECBUFFER_TOKEN, None),
409 secbuf(sspi::SECBUFFER_ALERT, None),
410 secbuf(sspi::SECBUFFER_EMPTY, None)];
411 let mut outbuf_desc = secbuf_desc(&mut outbufs);
412
413 let mut attributes = 0;
414
415 let status = if self.server {
416 let ptr = if self.accept_first {
417 ptr::null_mut()
418 } else {
419 self.context.get_mut()
420 };
421 sspi::AcceptSecurityContext(self.cred.get_mut(),
422 ptr,
423 &mut inbuf_desc,
424 ACCEPT_REQUESTS,
425 0,
426 self.context.get_mut(),
427 &mut outbuf_desc,
428 &mut attributes,
429 ptr::null_mut())
430 } else {
431 let domain = match self.domain {
432 Some(ref domain) if self.use_sni => domain.as_ptr() as *mut u16,
433 _ => ptr::null_mut(),
434 };
435
436 sspi::InitializeSecurityContextW(self.cred.get_mut(),
437 self.context.get_mut(),
438 domain,
439 INIT_REQUESTS,
440 0,
441 0,
442 &mut inbuf_desc,
443 0,
444 ptr::null_mut(),
445 &mut outbuf_desc,
446 &mut attributes,
447 ptr::null_mut())
448 };
449
450 for buf in &outbufs[1..] {
451 if !buf.pvBuffer.is_null() {
452 sspi::FreeContextBuffer(buf.pvBuffer);
453 }
454 }
455
456 match status {
457 winerror::SEC_I_CONTINUE_NEEDED => {
458 // Windows apparently doesn't like AcceptSecurityContext
459 // being called as if it were the second time unless the
460 // first call to AcceptSecurityContext succeeded with
461 // CONTINUE_NEEDED.
462 //
463 // In other words, if we were to set `accept_first` to
464 // `false` after the literal first call to
465 // `AcceptSecurityContext` while the call returned
466 // INCOMPLETE_MESSAGE, the next call would return an error.
467 //
468 // For that reason we only set `accept_first` to false here
469 // once we've actually successfully received the full
470 // "token" from the client.
471 self.accept_first = false;
472 let nread = if inbufs[1].BufferType == sspi::SECBUFFER_EXTRA {
473 self.enc_in.position() as usize - inbufs[1].cbBuffer as usize
474 } else {
475 self.enc_in.position() as usize
476 };
477 let to_write = ContextBuffer(outbufs[0]);
478
479 self.consume_enc_in(nread);
480 self.needs_read = (self.enc_in.position() == 0) as usize;
481 self.out_buf.get_mut().extend_from_slice(&to_write);
482 }
483 winerror::SEC_E_INCOMPLETE_MESSAGE => {
484 self.needs_read = if inbufs[1].BufferType == sspi::SECBUFFER_MISSING {
485 inbufs[1].cbBuffer as usize
486 } else {
487 1
488 };
489 }
490 winerror::SEC_E_OK => {
491 let nread = if inbufs[1].BufferType == sspi::SECBUFFER_EXTRA {
492 self.enc_in.position() as usize - inbufs[1].cbBuffer as usize
493 } else {
494 self.enc_in.position() as usize
495 };
496 let to_write = if outbufs[0].pvBuffer.is_null() {
497 None
498 } else {
499 Some(ContextBuffer(outbufs[0]))
500 };
501
502 self.consume_enc_in(nread);
503 self.needs_read = (self.enc_in.position() == 0) as usize;
504 if let Some(to_write) = to_write {
505 self.out_buf.get_mut().extend_from_slice(&to_write);
506 }
507 if self.enc_in.position() != 0 {
508 self.decrypt()?;
509 }
510 if let State::Initializing { ref mut more_calls, .. } = self.state {
511 *more_calls = false;
512 }
513 }
514 _ => {
515 return Err(io::Error::from_raw_os_error(status as i32))
516 }
517 }
518 Ok(())
519 }
520 }
521
initialize(&mut self) -> io::Result<Option<sspi::SecPkgContext_StreamSizes>>522 fn initialize(&mut self) -> io::Result<Option<sspi::SecPkgContext_StreamSizes>> {
523 loop {
524 match self.state {
525 State::Initializing { mut needs_flush, more_calls, shutting_down, validated } => {
526 if self.write_out()? > 0 {
527 needs_flush = true;
528 if let State::Initializing { ref mut needs_flush, .. } = self.state {
529 *needs_flush = true;
530 }
531 }
532
533 if needs_flush {
534 self.stream.flush()?;
535 if let State::Initializing { ref mut needs_flush, .. } = self.state {
536 *needs_flush = false;
537 }
538 }
539
540 if !shutting_down && !validated {
541 // on the last call, we require a valid certificate
542 if self.validate(!more_calls)? {
543 if let State::Initializing { ref mut validated, .. } = self.state {
544 *validated = true;
545 }
546 }
547 }
548
549 if !more_calls {
550 self.state = if shutting_down {
551 State::Shutdown
552 } else {
553 State::Streaming { sizes: self.context.stream_sizes()? }
554 };
555 continue;
556 }
557
558 if self.needs_read > 0 {
559 if self.read_in()? == 0 {
560 return Err(io::Error::new(io::ErrorKind::UnexpectedEof,
561 "unexpected EOF during handshake"));
562 }
563 }
564
565 self.step_initialize()?;
566 }
567 State::Streaming { sizes } => return Ok(Some(sizes)),
568 State::Shutdown => return Ok(None),
569 }
570 }
571 }
572
573 /// Returns true when the certificate was succesfully verified
574 /// Returns false, when a verification isn't necessary (yet)
575 /// Returns an error when the verification failed
validate(&mut self, require_cert: bool) -> io::Result<bool>576 fn validate(&mut self, require_cert: bool) -> io::Result<bool> {
577 // If we're accepting connections then we don't perform any validation
578 // for the remote certificate, that's what they're doing!
579 if self.server {
580 return Ok(false);
581 }
582
583 let cert_context = match self.context.remote_cert() {
584 Err(_) if !require_cert => return Ok(false),
585 ret => ret?
586 };
587
588 let cert_chain = unsafe {
589 let cert_store = match (cert_context.cert_store(), &self.cert_store) {
590 (Some(ref mut chain_certs), &Some(ref extra_certs)) => {
591 for extra_cert in extra_certs.certs() {
592 chain_certs.add_cert(&extra_cert, CertAdd::ReplaceExisting)?;
593 }
594 chain_certs.as_inner()
595 },
596 (Some(chain_certs), &None) => chain_certs.as_inner(),
597 (None, &Some(ref extra_certs)) => extra_certs.as_inner(),
598 (None, &None) => ptr::null_mut()
599 };
600
601 let flags = wincrypt::CERT_CHAIN_CACHE_END_CERT |
602 wincrypt::CERT_CHAIN_REVOCATION_CHECK_CACHE_ONLY |
603 wincrypt::CERT_CHAIN_REVOCATION_CHECK_CHAIN_EXCLUDE_ROOT;
604
605 let mut para: wincrypt::CERT_CHAIN_PARA = mem::zeroed();
606 para.cbSize = mem::size_of_val(¶) as winapi::DWORD;
607 para.RequestedUsage.dwType = wincrypt::USAGE_MATCH_TYPE_OR;
608
609 let mut identifiers = [szOID_PKIX_KP_SERVER_AUTH.as_ptr() as ntdef::LPSTR,
610 szOID_SERVER_GATED_CRYPTO.as_ptr() as ntdef::LPSTR,
611 szOID_SGC_NETSCAPE.as_ptr() as ntdef::LPSTR];
612 para.RequestedUsage.Usage.cUsageIdentifier = identifiers.len() as winapi::DWORD;
613 para.RequestedUsage.Usage.rgpszUsageIdentifier = identifiers.as_mut_ptr();
614
615 let mut cert_chain = mem::zeroed();
616
617 let res = wincrypt::CertGetCertificateChain(ptr::null_mut(),
618 cert_context.as_inner(),
619 ptr::null_mut(),
620 cert_store,
621 &mut para,
622 flags,
623 ptr::null_mut(),
624 &mut cert_chain);
625
626 if res == winapi::TRUE {
627 CertChainContext(cert_chain as *mut _)
628 } else {
629 return Err(io::Error::last_os_error())
630 }
631 };
632
633 unsafe {
634 // check if we trust the root-CA explicitly
635 let mut para_flags = wincrypt::CERT_CHAIN_POLICY_IGNORE_ALL_REV_UNKNOWN_FLAGS;
636 if let Some(ref mut store) = self.cert_store {
637 if let Some(chain) = cert_chain.final_chain() {
638 // check if any cert of the chain is in the passed store (and therefore trusted)
639 if chain.certificates().any(|cert| store.certs().any(|root_cert| root_cert == cert)) {
640 para_flags |= wincrypt::CERT_CHAIN_POLICY_ALLOW_UNKNOWN_CA_FLAG;
641 }
642 }
643 }
644
645 let mut extra_para: wincrypt::SSL_EXTRA_CERT_CHAIN_POLICY_PARA = mem::zeroed();
646 *extra_para.u.cbSize_mut() = mem::size_of_val(&extra_para) as winapi::DWORD;
647 extra_para.dwAuthType = wincrypt::AUTHTYPE_SERVER;
648 match self.domain {
649 Some(ref mut domain) if !self.accept_invalid_hostnames => {
650 extra_para.pwszServerName = domain.as_mut_ptr();
651 }
652 _ => {}
653 }
654
655 let mut para: wincrypt::CERT_CHAIN_POLICY_PARA = mem::zeroed();
656 para.cbSize = mem::size_of_val(¶) as winapi::DWORD;
657 para.dwFlags = para_flags;
658 para.pvExtraPolicyPara = &mut extra_para as *mut _ as *mut _;
659
660 let mut status: wincrypt::CERT_CHAIN_POLICY_STATUS = mem::zeroed();
661 status.cbSize = mem::size_of_val(&status) as winapi::DWORD;
662
663 let verify_chain_policy_structure = wincrypt::CERT_CHAIN_POLICY_SSL as ntdef::LPCSTR;
664 let res = wincrypt::CertVerifyCertificateChainPolicy(verify_chain_policy_structure,
665 cert_chain.0,
666 &mut para,
667 &mut status);
668 if res == winapi::FALSE {
669 return Err(io::Error::last_os_error())
670 }
671
672 let mut verify_result = if status.dwError != winerror::ERROR_SUCCESS {
673 Err(io::Error::from_raw_os_error(status.dwError as i32))
674 } else {
675 Ok(())
676 };
677
678 // check if there's a user-specified verify callback
679 if let Some(ref callback) = self.verify_callback {
680 verify_result = callback(CertValidationResult{
681 chain: cert_chain,
682 res: status.dwError as i32,
683 chain_index: status.lChainIndex,
684 element_index: status.lElementIndex});
685 }
686 verify_result?;
687 }
688 Ok(true)
689 }
690
write_out(&mut self) -> io::Result<usize>691 fn write_out(&mut self) -> io::Result<usize> {
692 let mut out = 0;
693 while self.out_buf.position() as usize != self.out_buf.get_ref().len() {
694 let position = self.out_buf.position() as usize;
695 let nwritten = self.stream.write(&self.out_buf.get_ref()[position..])?;
696 out += nwritten;
697 self.out_buf.set_position((position + nwritten) as u64);
698 }
699
700 Ok(out)
701 }
702
read_in(&mut self) -> io::Result<usize>703 fn read_in(&mut self) -> io::Result<usize> {
704 let mut sum_nread = 0;
705
706 while self.needs_read > 0 {
707 let existing_len = self.enc_in.position() as usize;
708 let min_len = cmp::max(cmp::max(1024, 2 * existing_len), self.needs_read);
709 if self.enc_in.get_ref().len() < min_len {
710 self.enc_in.get_mut().resize(min_len, 0);
711 }
712 let nread = {
713 let buf = &mut self.enc_in.get_mut()[existing_len..];
714 self.stream.read(buf)?
715 };
716 self.enc_in.set_position((existing_len + nread) as u64);
717 self.needs_read = self.needs_read.saturating_sub(nread);
718 if nread == 0 {
719 break;
720 }
721 sum_nread += nread;
722 }
723
724 Ok(sum_nread)
725 }
726
consume_enc_in(&mut self, nread: usize)727 fn consume_enc_in(&mut self, nread: usize) {
728 let size = self.enc_in.position() as usize;
729 assert!(size >= nread);
730 let count = size - nread;
731
732 if count > 0 {
733 self.enc_in.get_mut().drain(..nread);
734 }
735
736 self.enc_in.set_position(count as u64);
737 }
738
decrypt(&mut self) -> io::Result<bool>739 fn decrypt(&mut self) -> io::Result<bool> {
740 unsafe {
741 let position = self.enc_in.position() as usize;
742 let mut bufs = [secbuf(sspi::SECBUFFER_DATA,
743 Some(&mut self.enc_in.get_mut()[..position])),
744 secbuf(sspi::SECBUFFER_EMPTY, None),
745 secbuf(sspi::SECBUFFER_EMPTY, None),
746 secbuf(sspi::SECBUFFER_EMPTY, None)];
747 let mut bufdesc = secbuf_desc(&mut bufs);
748
749 match sspi::DecryptMessage(self.context.get_mut(),
750 &mut bufdesc,
751 0,
752 ptr::null_mut()) {
753 winerror::SEC_E_OK => {
754 let start = bufs[1].pvBuffer as usize - self.enc_in.get_ref().as_ptr() as usize;
755 let end = start + bufs[1].cbBuffer as usize;
756 self.dec_in.get_mut().clear();
757 self.dec_in
758 .get_mut()
759 .extend_from_slice(&self.enc_in.get_ref()[start..end]);
760 self.dec_in.set_position(0);
761
762 let nread = if bufs[3].BufferType == sspi::SECBUFFER_EXTRA {
763 self.enc_in.position() as usize - bufs[3].cbBuffer as usize
764 } else {
765 self.enc_in.position() as usize
766 };
767 self.consume_enc_in(nread);
768 self.needs_read = (self.enc_in.position() == 0) as usize;
769 Ok(false)
770 }
771 winerror::SEC_E_INCOMPLETE_MESSAGE => {
772 self.needs_read = if bufs[1].BufferType == sspi::SECBUFFER_MISSING {
773 bufs[1].cbBuffer as usize
774 } else {
775 1
776 };
777 Ok(false)
778 }
779 winerror::SEC_I_CONTEXT_EXPIRED => Ok(true),
780 winerror::SEC_I_RENEGOTIATE => {
781 self.state = State::Initializing {
782 needs_flush: false,
783 more_calls: true,
784 shutting_down: false,
785 validated: false,
786 };
787
788 let nread = if bufs[3].BufferType == sspi::SECBUFFER_EXTRA {
789 self.enc_in.position() as usize - bufs[3].cbBuffer as usize
790 } else {
791 self.enc_in.position() as usize
792 };
793 self.consume_enc_in(nread);
794 self.needs_read = 0;
795 Ok(false)
796 }
797 e => Err(io::Error::from_raw_os_error(e as i32)),
798 }
799 }
800 }
801
encrypt(&mut self, buf: &[u8], sizes: &sspi::SecPkgContext_StreamSizes) -> io::Result<()>802 fn encrypt(&mut self, buf: &[u8], sizes: &sspi::SecPkgContext_StreamSizes) -> io::Result<()> {
803 assert!(buf.len() <= sizes.cbMaximumMessage as usize);
804
805 unsafe {
806 let len = sizes.cbHeader as usize + buf.len() + sizes.cbTrailer as usize;
807
808 if self.out_buf.get_ref().len() < len {
809 self.out_buf.get_mut().resize(len, 0);
810 }
811
812 let message_start = sizes.cbHeader as usize;
813 self.out_buf
814 .get_mut()[message_start..message_start + buf.len()]
815 .clone_from_slice(buf);
816
817 let mut bufs = {
818 let out_buf = self.out_buf.get_mut();
819 let size = sizes.cbHeader as usize;
820
821 let header = secbuf(sspi::SECBUFFER_STREAM_HEADER,
822 Some(&mut out_buf[..size]));
823 let data = secbuf(sspi::SECBUFFER_DATA,
824 Some(&mut out_buf[size..size + buf.len()]));
825 let trailer = secbuf(sspi::SECBUFFER_STREAM_TRAILER,
826 Some(&mut out_buf[size + buf.len()..]));
827 let empty = secbuf(sspi::SECBUFFER_EMPTY, None);
828 [header, data, trailer, empty]
829 };
830 let mut bufdesc = secbuf_desc(&mut bufs);
831
832 match sspi::EncryptMessage(self.context.get_mut(), 0, &mut bufdesc, 0) {
833 winerror::SEC_E_OK => {
834 let len = bufs[0].cbBuffer + bufs[1].cbBuffer + bufs[2].cbBuffer;
835 self.out_buf.get_mut().truncate(len as usize);
836 self.out_buf.set_position(0);
837 Ok(())
838 }
839 err => Err(io::Error::from_raw_os_error(err as i32)),
840 }
841 }
842 }
843 }
844
845 impl<S> MidHandshakeTlsStream<S> {
846 /// Returns a shared reference to the inner stream.
get_ref(&self) -> &S847 pub fn get_ref(&self) -> &S {
848 self.inner.get_ref()
849 }
850
851 /// Returns a mutable reference to the inner stream.
get_mut(&mut self) -> &mut S852 pub fn get_mut(&mut self) -> &mut S {
853 self.inner.get_mut()
854 }
855 }
856
857 impl<S> MidHandshakeTlsStream<S>
858 where S: Read + Write,
859 {
860 /// Restarts the handshake process.
handshake(mut self) -> Result<TlsStream<S>, HandshakeError<S>>861 pub fn handshake(mut self) -> Result<TlsStream<S>, HandshakeError<S>> {
862 match self.inner.initialize() {
863 Ok(_) => Ok(self.inner),
864 Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
865 Err(HandshakeError::Interrupted(self))
866 }
867 Err(e) => Err(HandshakeError::Failure(e)),
868 }
869 }
870 }
871
872 impl<S> Write for TlsStream<S>
873 where S: Read + Write
874 {
875 /// In the case of a WouldBlock error, we expect another call
876 /// starting with the same input data
877 /// This is similar to the use of ACCEPT_MOVING_WRITE_BUFFER in openssl
write(&mut self, buf: &[u8]) -> io::Result<usize>878 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
879 let sizes = match self.initialize()? {
880 Some(sizes) => sizes,
881 None => return Err(io::Error::from_raw_os_error(winerror::SEC_E_CONTEXT_EXPIRED as i32)),
882 };
883
884 // if we have pending output data, it must have been because a previous
885 // attempt to send this part of the data ran into an error.
886 if self.out_buf.position() == self.out_buf.get_ref().len() as u64 {
887 let len = cmp::min(buf.len(), sizes.cbMaximumMessage as usize);
888 self.encrypt(&buf[..len], &sizes)?;
889 self.last_write_len = len;
890 }
891 self.write_out()?;
892
893 Ok(self.last_write_len)
894 }
895
flush(&mut self) -> io::Result<()>896 fn flush(&mut self) -> io::Result<()> {
897 // Make sure the write buffer is emptied
898 self.write_out()?;
899 self.stream.flush()
900 }
901 }
902
903 impl<S> Read for TlsStream<S>
904 where S: Read + Write
905 {
read(&mut self, buf: &mut [u8]) -> io::Result<usize>906 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
907 let nread = {
908 let read_buf = self.fill_buf()?;
909 let nread = cmp::min(buf.len(), read_buf.len());
910 buf[..nread].copy_from_slice(&read_buf[..nread]);
911 nread
912 };
913 self.consume(nread);
914 Ok(nread)
915 }
916 }
917
918 impl<S> BufRead for TlsStream<S>
919 where S: Read + Write
920 {
fill_buf(&mut self) -> io::Result<&[u8]>921 fn fill_buf(&mut self) -> io::Result<&[u8]> {
922 while self.get_buf().is_empty() {
923 if let None = self.initialize()? {
924 break;
925 }
926
927 if self.needs_read > 0 {
928 if self.read_in()? == 0 {
929 break;
930 }
931 self.needs_read = 0;
932 }
933
934 let eof = self.decrypt()?;
935 if eof {
936 break;
937 }
938 }
939
940 Ok(self.get_buf())
941 }
942
consume(&mut self, amt: usize)943 fn consume(&mut self, amt: usize) {
944 let pos = self.dec_in.position() + amt as u64;
945 assert!(pos <= self.dec_in.get_ref().len() as u64);
946 self.dec_in.set_position(pos);
947 }
948 }
949