1 //! Schannel TLS streams.
2 use std::any::Any;
3 use std::cmp;
4 use std::error::Error;
5 use std::fmt;
6 use std::io::{self, Read, BufRead, Write, Cursor};
7 use std::mem;
8 use std::ptr;
9 use std::slice;
10 use std::sync::Arc;
11 use winapi::shared::minwindef as winapi;
12 use winapi::shared::{ntdef, sspi, winerror};
13 use winapi::um::{self, wincrypt};
14 
15 use crate::{INIT_REQUESTS, ACCEPT_REQUESTS, Inner, secbuf, secbuf_desc};
16 use crate::cert_chain::{CertChain, CertChainContext};
17 use crate::cert_store::{CertAdd, CertStore};
18 use crate::cert_context::CertContext;
19 use crate::security_context::SecurityContext;
20 use crate::context_buffer::ContextBuffer;
21 use crate::schannel_cred::SchannelCred;
22 
23 lazy_static! {
24     static ref szOID_PKIX_KP_SERVER_AUTH: Vec<u8> =
25         wincrypt::szOID_PKIX_KP_SERVER_AUTH.bytes().chain(Some(0)).collect();
26     static ref szOID_SERVER_GATED_CRYPTO: Vec<u8> =
27         wincrypt::szOID_SERVER_GATED_CRYPTO.bytes().chain(Some(0)).collect();
28     static ref szOID_SGC_NETSCAPE: Vec<u8> =
29         wincrypt::szOID_SGC_NETSCAPE.bytes().chain(Some(0)).collect();
30 }
31 
32 /// A builder type for `TlsStream`s.
33 pub struct Builder {
34     domain: Option<Vec<u16>>,
35     use_sni: bool,
36     accept_invalid_hostnames: bool,
37     verify_callback: Option<Arc<dyn Fn(CertValidationResult) -> io::Result<()> + Sync + Send>>,
38     cert_store: Option<CertStore>,
39 }
40 
41 impl Default for Builder {
default() -> Builder42     fn default() -> Builder {
43         Builder {
44             domain: None,
45             use_sni: true,
46             accept_invalid_hostnames: false,
47             verify_callback: None,
48             cert_store: None,
49         }
50     }
51 }
52 
53 impl Builder {
54     /// Returns a new `Builder`.
new() -> Builder55     pub fn new() -> Builder {
56         Builder::default()
57     }
58 
59     /// Sets the domain associated with connections created with this `Builder`.
60     ///
61     /// The domain will be used for Server Name Indication as well as
62     /// certificate validation.
domain(&mut self, domain: &str) -> &mut Builder63     pub fn domain(&mut self, domain: &str) -> &mut Builder {
64         self.domain = Some(domain.encode_utf16().chain(Some(0)).collect());
65         self
66     }
67 
68     /// Determines if Server Name Indication (SNI) will be used.
69     ///
70     /// Defaults to `true`.
use_sni(&mut self, use_sni: bool) -> &mut Builder71     pub fn use_sni(&mut self, use_sni: bool) -> &mut Builder {
72         self.use_sni = use_sni;
73         self
74     }
75 
76     /// Determines if the server's hostname will be checked during certificate verification.
77     ///
78     /// Defaults to `false`.
accept_invalid_hostnames(&mut self, accept_invalid_hostnames: bool) -> &mut Builder79     pub fn accept_invalid_hostnames(&mut self, accept_invalid_hostnames: bool) -> &mut Builder {
80         self.accept_invalid_hostnames = accept_invalid_hostnames;
81         self
82     }
83 
84     /// Set a verification callback to be used for connections created with this `Builder`.
85     ///
86     /// The callback is provided with an io::Result indicating if the (pre)validation was
87     /// successful. The Ok() variant indicates a successful validation while the Err() variant
88     /// contains the errorcode returned from the internal verification process.
89     /// The validated certificate, is accessible through the second argument of the closure.
verify_callback<F>(&mut self, callback: F) -> &mut Builder where F: Fn(CertValidationResult) -> io::Result<()> + 'static + Sync + Send90     pub fn verify_callback<F>(&mut self, callback: F) -> &mut Builder
91         where F: Fn(CertValidationResult) -> io::Result<()> + 'static + Sync + Send
92     {
93         self.verify_callback = Some(Arc::new(callback));
94         self
95     }
96 
97     /// Specifies a custom certificate store which is later used when validating
98     /// a server's certificate.
99     ///
100     /// This option is only used for client connections and is used to construct
101     /// the certificate chain which the server's certificate is validated
102     /// against.
103     ///
104     /// Note that adding certificates here means that they are
105     /// implicitly trusted.
cert_store(&mut self, cert_store: CertStore) -> &mut Builder106     pub fn cert_store(&mut self, cert_store: CertStore) -> &mut Builder {
107         self.cert_store = Some(cert_store);
108         self
109     }
110 
111     /// Initialize a new TLS session where the stream provided will be
112     /// connecting to a remote TLS server.
113     ///
114     /// If the stream provided is a blocking stream then the entire handshake
115     /// will be performed if possible, but if the stream is in nonblocking mode
116     /// then a `HandshakeError::Interrupted` variant may be returned. This
117     /// type can then be extracted to later call
118     /// `MidHandshakeTlsStream::handshake` when data becomes available.
connect<S>(&mut self, cred: SchannelCred, stream: S) -> Result<TlsStream<S>, HandshakeError<S>> where S: Read + Write119     pub fn connect<S>(&mut self,
120                       cred: SchannelCred,
121                       stream: S)
122                       -> Result<TlsStream<S>, HandshakeError<S>>
123         where S: Read + Write
124     {
125         self.initialize(cred, false, stream)
126     }
127 
128     /// Initialize a new TLS session where the stream provided will be
129     /// accepting a connection.
130     ///
131     /// This method will tweak the protocol for "who talks first" and also
132     /// currently disables validation of the client that's connecting to us.
133     ///
134     /// If the stream provided is a blocking stream then the entire handshake
135     /// will be performed if possible, but if the stream is in nonblocking mode
136     /// then a `HandshakeError::Interrupted` variant may be returned. This
137     /// type can then be extracted to later call
138     /// `MidHandshakeTlsStream::handshake` when data becomes available.
accept<S>(&mut self, cred: SchannelCred, stream: S) -> Result<TlsStream<S>, HandshakeError<S>> where S: Read + Write139     pub fn accept<S>(&mut self,
140                      cred: SchannelCred,
141                      stream: S)
142                      -> Result<TlsStream<S>, HandshakeError<S>>
143         where S: Read + Write
144     {
145         self.initialize(cred, true, stream)
146     }
147 
initialize<S>(&mut self, mut cred: SchannelCred, server: bool, stream: S) -> Result<TlsStream<S>, HandshakeError<S>> where S: Read + Write148     fn initialize<S>(&mut self,
149                      mut cred: SchannelCred,
150                      server: bool,
151                      stream: S)
152                          -> Result<TlsStream<S>, HandshakeError<S>>
153         where S: Read + Write
154     {
155         let domain = match self.domain {
156             Some(ref domain) if self.use_sni => Some(&domain[..]),
157             _ => None,
158         };
159         let (ctxt, buf) = match SecurityContext::initialize(&mut cred,
160                                                             server,
161                                                             domain) {
162             Ok(pair) => pair,
163             Err(e) => return Err(HandshakeError::Failure(e)),
164         };
165 
166         let stream = TlsStream {
167             cred: cred,
168             context: ctxt,
169             cert_store: self.cert_store.clone(),
170             domain: self.domain.clone(),
171             use_sni: self.use_sni,
172             accept_invalid_hostnames: self.accept_invalid_hostnames,
173             verify_callback: self.verify_callback.clone(),
174             stream: stream,
175             server: server,
176             accept_first: true,
177             state: State::Initializing {
178                 needs_flush: false,
179                 more_calls: true,
180                 shutting_down: false,
181                 validated: false,
182             },
183             needs_read: 1,
184             dec_in: Cursor::new(Vec::new()),
185             enc_in: Cursor::new(Vec::new()),
186             out_buf: Cursor::new(buf.map(|b| b.to_owned()).unwrap_or(Vec::new())),
187             last_write_len: 0,
188         };
189 
190         MidHandshakeTlsStream {
191             inner: stream,
192         }.handshake()
193     }
194 }
195 
196 enum State {
197     Initializing {
198         needs_flush: bool,
199         more_calls: bool,
200         shutting_down: bool,
201         validated: bool,
202     },
203     Streaming { sizes: sspi::SecPkgContext_StreamSizes, },
204     Shutdown,
205 }
206 
207 /// An Schannel TLS stream.
208 pub struct TlsStream<S> {
209     cred: SchannelCred,
210     context: SecurityContext,
211     cert_store: Option<CertStore>,
212     domain: Option<Vec<u16>>,
213     use_sni: bool,
214     accept_invalid_hostnames: bool,
215     verify_callback: Option<Arc<dyn Fn(CertValidationResult) -> io::Result<()> + Sync + Send>>,
216     stream: S,
217     state: State,
218     server: bool,
219     accept_first: bool,
220     needs_read: usize,
221     // valid from position() to len()
222     dec_in: Cursor<Vec<u8>>,
223     // valid from 0 to position()
224     enc_in: Cursor<Vec<u8>>,
225     // valid from position() to len()
226     out_buf: Cursor<Vec<u8>>,
227     /// the (unencrypted) length of the last write call used to track writes
228     last_write_len: usize,
229 }
230 
231 /// ensures that a TlsStream is always Sync/Send
_is_sync()232 fn _is_sync() {
233     fn sync<T: Sync + Send>() {}
234     sync::<TlsStream<()>>();
235 }
236 
237 /// A failure which can happen during the `Builder::initialize` phase, either an
238 /// I/O error or an intermediate stream which has not completed its handshake.
239 #[derive(Debug)]
240 pub enum HandshakeError<S> {
241     /// A fatal I/O error occurred
242     Failure(io::Error),
243     /// The stream connection is in progress, but the handshake is not completed
244     /// yet.
245     Interrupted(MidHandshakeTlsStream<S>),
246 }
247 
248 /// A struct used to wrap various cert chain validation results for callback processing.
249 pub struct CertValidationResult {
250     chain: CertChainContext,
251     res: i32,
252     chain_index: i32,
253     element_index: i32,
254 }
255 
256 impl CertValidationResult {
257     /// Returns the certificate that failed validation if applicable
failed_certificate(&self) -> Option<CertContext>258     pub fn failed_certificate(&self) -> Option<CertContext> {
259         if let Some(cert_chain) = self.chain.get_chain(self.chain_index as usize) {
260             return cert_chain.get(self.element_index as usize);
261         }
262         None
263     }
264 
265     /// Returns the final certificate chain in the certificate context if applicable
chain(&self) -> Option<CertChain>266     pub fn chain(&self) -> Option<CertChain> {
267         self.chain.final_chain()
268     }
269 
270     /// Returns the result of the built-in certificate verification process.
result(&self) -> io::Result<()>271     pub fn result(&self) -> io::Result<()> {
272         if self.res as u32 != winerror::ERROR_SUCCESS {
273                 Err(io::Error::from_raw_os_error(self.res))
274         } else {
275                 Ok(())
276         }
277     }
278 }
279 
280 impl<S: fmt::Debug + Any> Error for HandshakeError<S> {
description(&self) -> &str281     fn description(&self) -> &str {
282         match *self {
283             HandshakeError::Failure(_) => "failed to perform handshake",
284             HandshakeError::Interrupted(_) => "interrupted performing handshake",
285         }
286     }
287 
cause(&self) -> Option<&dyn Error>288     fn cause(&self) -> Option<&dyn Error> {
289         match *self {
290             HandshakeError::Failure(ref e) => Some(e),
291             HandshakeError::Interrupted(_) => None,
292         }
293     }
294 }
295 
296 impl<S: fmt::Debug + Any> fmt::Display for HandshakeError<S> {
fmt(&self, f: &mut fmt::Formatter) -> fmt::Result297     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
298         f.write_str(self.description())?;
299         if let Some(e) = self.source() {
300            write!(f, ": {}", e)?;
301         }
302         Ok(())
303     }
304 }
305 
306 /// A stream which has not yet completed its handshake.
307 #[derive(Debug)]
308 pub struct MidHandshakeTlsStream<S> {
309     inner: TlsStream<S>,
310 }
311 
312 impl<S> fmt::Debug for TlsStream<S>
313     where S: fmt::Debug
314 {
fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result315     fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
316         fmt.debug_struct("TlsStream")
317             .field("stream", &self.stream)
318             .finish()
319     }
320 }
321 
322 impl<S> TlsStream<S> {
323     /// Returns a reference to the wrapped stream.
get_ref(&self) -> &S324     pub fn get_ref(&self) -> &S {
325         &self.stream
326     }
327 
328     /// Returns a mutable reference to the wrapped stream.
get_mut(&mut self) -> &mut S329     pub fn get_mut(&mut self) -> &mut S {
330         &mut self.stream
331     }
332 
333     /// Indicates if this stream is the server- or client-side of a TLS session.
is_server(&self) -> bool334     pub fn is_server(&self) -> bool {
335         self.server
336     }
337 }
338 
339 impl<S> TlsStream<S>
340     where S: Read + Write
341 {
342     /// Returns the certificate used to identify this side of the TLS session.
343     ///
344     /// Its associated cert store contains any intermediate certificates sent
345     /// along with the leaf.
certificate(&self) -> io::Result<CertContext>346     pub fn certificate(&self) -> io::Result<CertContext> {
347         self.context.local_cert()
348     }
349 
350     /// Returns the peer's certificate, if available.
351     ///
352     /// Its associated cert store contains any intermediate certificates sent
353     /// by the server.
peer_certificate(&self) -> io::Result<CertContext>354     pub fn peer_certificate(&self) -> io::Result<CertContext> {
355         self.context.remote_cert()
356     }
357 
358     /// Returns a reference to the buffer of pending data.
359     ///
360     /// Like `BufRead::fill_buf` except that it will return an empty slice
361     /// rather than reading from the wrapped stream if there is no buffered
362     /// data.
get_buf(&self) -> &[u8]363     pub fn get_buf(&self) -> &[u8] {
364         &self.dec_in.get_ref()[self.dec_in.position() as usize..]
365     }
366 
367     /// Shuts the TLS session down.
shutdown(&mut self) -> io::Result<()>368     pub fn shutdown(&mut self) -> io::Result<()> {
369         match self.state {
370             State::Shutdown => return Ok(()),
371             State::Initializing { shutting_down: true, .. } => {}
372             _ => {
373                 unsafe {
374                     let mut token = um::schannel::SCHANNEL_SHUTDOWN;
375                     let ptr = &mut token as *mut _ as *mut u8;
376                     let size = mem::size_of_val(&token);
377                     let token = slice::from_raw_parts_mut(ptr, size);
378                     let mut buf = [secbuf(sspi::SECBUFFER_TOKEN, Some(token))];
379                     let mut desc = secbuf_desc(&mut buf);
380 
381                     match sspi::ApplyControlToken(self.context.get_mut(), &mut desc) {
382                         winerror::SEC_E_OK => {}
383                         err => return Err(io::Error::from_raw_os_error(err as i32)),
384                     }
385                 }
386 
387                 self.state = State::Initializing {
388                     needs_flush: false,
389                     more_calls: true,
390                     shutting_down: true,
391                     validated: false,
392                 };
393                 self.needs_read = 0;
394             }
395         }
396 
397         self.initialize().map(|_| ())
398     }
399 
    /// Performs one handshake call over the encrypted bytes buffered in
    /// `enc_in[..position]`. Servers call `AcceptSecurityContext`; clients call
    /// `InitializeSecurityContextW`.
    ///
    /// Handling of the returned status:
    /// - `SEC_I_CONTINUE_NEEDED`: drops the consumed input, queues the produced
    ///   token in `out_buf`, and requests more input only if nothing is left
    ///   buffered.
    /// - `SEC_E_INCOMPLETE_MESSAGE`: stores in `needs_read` how many more bytes
    ///   to read (the `SECBUFFER_MISSING` count if reported, otherwise 1).
    /// - `SEC_E_OK`: same bookkeeping as the continue case. It then passes any
    ///   bytes still left in `enc_in` to `decrypt` and clears `more_calls`.
    /// - Any other status is returned as an OS error.
    fn step_initialize(&mut self) -> io::Result<()> {
        // NOTE(review): SAFETY relies on the SSPI buffers pointing into
        // `enc_in` staying valid for the duration of the call (nothing resizes
        // `enc_in` before the calls below return).
        unsafe {
            let pos = self.enc_in.position() as usize;
            // Input: the buffered bytes, plus an empty slot in which Schannel
            // reports SECBUFFER_EXTRA (unconsumed tail) or SECBUFFER_MISSING
            // (bytes still needed).
            let mut inbufs = [secbuf(sspi::SECBUFFER_TOKEN,
                                     Some(&mut self.enc_in.get_mut()[..pos])),
                              secbuf(sspi::SECBUFFER_EMPTY, None)];
            let mut inbuf_desc = secbuf_desc(&mut inbufs);

            // Output buffers start empty. Presumably INIT_REQUESTS /
            // ACCEPT_REQUESTS ask Schannel to allocate them; they are released
            // with FreeContextBuffer / ContextBuffer below. TODO confirm.
            let mut outbufs = [secbuf(sspi::SECBUFFER_TOKEN, None),
                               secbuf(sspi::SECBUFFER_ALERT, None),
                               secbuf(sspi::SECBUFFER_EMPTY, None)];
            let mut outbuf_desc = secbuf_desc(&mut outbufs);

            let mut attributes = 0;

            let status = if self.server {
                // Pass a null existing context until a call has returned
                // CONTINUE_NEEDED (see the comment in that match arm).
                let ptr = if self.accept_first {
                    ptr::null_mut()
                } else {
                    self.context.get_mut()
                };
                sspi::AcceptSecurityContext(self.cred.get_mut(),
                                            ptr,
                                            &mut inbuf_desc,
                                            ACCEPT_REQUESTS,
                                            0,
                                            self.context.get_mut(),
                                            &mut outbuf_desc,
                                            &mut attributes,
                                            ptr::null_mut())
            } else {
                // Target name is only supplied when SNI is enabled.
                let domain = match self.domain {
                    Some(ref domain) if self.use_sni => domain.as_ptr() as *mut u16,
                    _ => ptr::null_mut(),
                };

                sspi::InitializeSecurityContextW(self.cred.get_mut(),
                                                 self.context.get_mut(),
                                                 domain,
                                                 INIT_REQUESTS,
                                                 0,
                                                 0,
                                                 &mut inbuf_desc,
                                                 0,
                                                 ptr::null_mut(),
                                                 &mut outbuf_desc,
                                                 &mut attributes,
                                                 ptr::null_mut())
            };

            // The alert and trailing output buffers are not used; release
            // whatever was placed in them.
            // NOTE(review): outbufs[0] is only released (via ContextBuffer) in
            // the CONTINUE_NEEDED and OK arms. If Schannel fills it for other
            // statuses, it is not freed here — verify.
            for buf in &outbufs[1..] {
                if !buf.pvBuffer.is_null() {
                    sspi::FreeContextBuffer(buf.pvBuffer);
                }
            }

            match status {
                winerror::SEC_I_CONTINUE_NEEDED => {
                    // Windows apparently doesn't like AcceptSecurityContext
                    // being called as if it were the second time unless the
                    // first call to AcceptSecurityContext succeeded with
                    // CONTINUE_NEEDED.
                    //
                    // In other words, if we were to set `accept_first` to
                    // `false` after the literal first call to
                    // `AcceptSecurityContext` while the call returned
                    // INCOMPLETE_MESSAGE, the next call would return an error.
                    //
                    // For that reason we only set `accept_first` to false here
                    // once we've actually successfully received the full
                    // "token" from the client.
                    self.accept_first = false;
                    // Consumed bytes: everything buffered, minus any
                    // SECBUFFER_EXTRA tail Schannel left unprocessed.
                    let nread = if inbufs[1].BufferType == sspi::SECBUFFER_EXTRA {
                        self.enc_in.position() as usize - inbufs[1].cbBuffer as usize
                    } else {
                        self.enc_in.position() as usize
                    };
                    let to_write = ContextBuffer(outbufs[0]);

                    self.consume_enc_in(nread);
                    // Leftover bytes are retried first; only read from the
                    // stream when nothing remains buffered.
                    self.needs_read = (self.enc_in.position() == 0) as usize;
                    self.out_buf.get_mut().extend_from_slice(&to_write);
                }
                winerror::SEC_E_INCOMPLETE_MESSAGE => {
                    self.needs_read = if inbufs[1].BufferType == sspi::SECBUFFER_MISSING {
                        inbufs[1].cbBuffer as usize
                    } else {
                        1
                    };
                }
                winerror::SEC_E_OK => {
                    let nread = if inbufs[1].BufferType == sspi::SECBUFFER_EXTRA {
                        self.enc_in.position() as usize - inbufs[1].cbBuffer as usize
                    } else {
                        self.enc_in.position() as usize
                    };
                    // The final call may or may not produce a token to send.
                    let to_write = if outbufs[0].pvBuffer.is_null() {
                        None
                    } else {
                        Some(ContextBuffer(outbufs[0]))
                    };

                    self.consume_enc_in(nread);
                    self.needs_read = (self.enc_in.position() == 0) as usize;
                    if let Some(to_write) = to_write {
                        self.out_buf.get_mut().extend_from_slice(&to_write);
                    }
                    // Bytes that arrived after the handshake records are
                    // decrypted right away.
                    if self.enc_in.position() != 0 {
                        self.decrypt()?;
                    }
                    if let State::Initializing { ref mut more_calls, .. } = self.state {
                        *more_calls = false;
                    }
                }
                _ => {
                    return Err(io::Error::from_raw_os_error(status as i32))
                }
            }
            Ok(())
        }
    }
521 
    /// Drives the handshake, or the shutdown exchange started by `shutdown`,
    /// until it finishes.
    ///
    /// While in `State::Initializing`, each loop pass does the following:
    /// 1. writes and flushes pending output;
    /// 2. validates the peer certificate (see `validate`);
    /// 3. switches to the final state once Schannel needs no more calls;
    /// 4. otherwise reads input if required and performs one `step_initialize`.
    ///
    /// Returns `Some(sizes)` once in `State::Streaming` and `None` once in
    /// `State::Shutdown`. Any I/O error (presumably including `WouldBlock` on
    /// a nonblocking stream) is returned immediately. Progress already
    /// recorded in `self.state` / `needs_read` is kept, so a later call can
    /// resume.
    fn initialize(&mut self) -> io::Result<Option<sspi::SecPkgContext_StreamSizes>> {
        loop {
            match self.state {
                // The fields are copies; changes that must survive an early
                // error return are also written back into `self.state`.
                State::Initializing { mut needs_flush, more_calls, shutting_down, validated } => {
                    if self.write_out()? > 0 {
                        needs_flush = true;
                        if let State::Initializing { ref mut needs_flush, .. } = self.state {
                            *needs_flush = true;
                        }
                    }

                    if needs_flush {
                        self.stream.flush()?;
                        if let State::Initializing { ref mut needs_flush, .. } = self.state {
                            *needs_flush = false;
                        }
                    }

                    if !shutting_down && !validated {
                        // on the last call, we require a valid certificate
                        if self.validate(!more_calls)? {
                            if let State::Initializing { ref mut validated, .. } = self.state {
                                *validated = true;
                            }
                        }
                    }

                    // Schannel reported completion: enter the final state and
                    // loop once more to return it.
                    if !more_calls {
                        self.state = if shutting_down {
                            State::Shutdown
                        } else {
                            State::Streaming { sizes: self.context.stream_sizes()? }
                        };
                        continue;
                    }

                    if self.needs_read > 0 {
                        // Zero bytes read means the peer closed the connection
                        // before the exchange finished.
                        if self.read_in()? == 0 {
                            return Err(io::Error::new(io::ErrorKind::UnexpectedEof,
                                                      "unexpected EOF during handshake"));
                        }
                    }

                    self.step_initialize()?;
                }
                State::Streaming { sizes } => return Ok(Some(sizes)),
                State::Shutdown => return Ok(None),
            }
        }
    }
572 
    /// Verifies the peer's certificate chain (client side only).
    ///
    /// Returns `Ok(true)` when the certificate was successfully verified.
    /// Returns `Ok(false)` when no verification is performed (yet): on the
    /// server side, or when the peer certificate is not available and
    /// `require_cert` is `false`.
    /// Returns an error when verification failed. If a verify callback is set,
    /// it returns an error when the callback does.
    fn validate(&mut self, require_cert: bool) -> io::Result<bool> {
        // If we're accepting connections then we don't perform any validation
        // for the remote certificate, that's what they're doing!
        if self.server {
            return Ok(false);
        }

        // A missing peer certificate only becomes an error once it is required.
        let cert_context = match self.context.remote_cert() {
            Err(_) if !require_cert => return Ok(false),
            ret => ret?
        };

        // Step 1: build a certificate chain for the peer certificate.
        let cert_chain = unsafe {
            // Additional store searched while building the chain: the store
            // attached to the peer certificate, with the user's certificates
            // added into it, or whichever of the two exists.
            // NOTE(review): the raw handle comes from `CertStore` values that
            // are dropped when this `match` statement ends. If `CertStore`'s
            // `Drop` closes its handle, this relies on something else (e.g.
            // `cert_context`) keeping the store open — confirm.
            let cert_store = match (cert_context.cert_store(), &self.cert_store) {
                (Some(ref mut chain_certs), &Some(ref extra_certs)) => {
                    for extra_cert in extra_certs.certs() {
                        chain_certs.add_cert(&extra_cert, CertAdd::ReplaceExisting)?;
                    }
                    chain_certs.as_inner()
                },
                (Some(chain_certs), &None) => chain_certs.as_inner(),
                (None, &Some(ref extra_certs)) => extra_certs.as_inner(),
                (None, &None) => ptr::null_mut()
            };

            let flags = wincrypt::CERT_CHAIN_CACHE_END_CERT |
                        wincrypt::CERT_CHAIN_REVOCATION_CHECK_CACHE_ONLY |
                        wincrypt::CERT_CHAIN_REVOCATION_CHECK_CHAIN_EXCLUDE_ROOT;

            let mut para: wincrypt::CERT_CHAIN_PARA = mem::zeroed();
            para.cbSize = mem::size_of_val(&para) as winapi::DWORD;
            // Request that the chain match at least one (OR) of the
            // server-authentication usages below.
            para.RequestedUsage.dwType = wincrypt::USAGE_MATCH_TYPE_OR;

            let mut identifiers = [szOID_PKIX_KP_SERVER_AUTH.as_ptr() as ntdef::LPSTR,
                                   szOID_SERVER_GATED_CRYPTO.as_ptr() as ntdef::LPSTR,
                                   szOID_SGC_NETSCAPE.as_ptr() as ntdef::LPSTR];
            para.RequestedUsage.Usage.cUsageIdentifier = identifiers.len() as winapi::DWORD;
            para.RequestedUsage.Usage.rgpszUsageIdentifier = identifiers.as_mut_ptr();

            let mut cert_chain = mem::zeroed();

            let res = wincrypt::CertGetCertificateChain(ptr::null_mut(),
                                                        cert_context.as_inner(),
                                                        ptr::null_mut(),
                                                        cert_store,
                                                        &mut para,
                                                        flags,
                                                        ptr::null_mut(),
                                                        &mut cert_chain);

            if res == winapi::TRUE {
                CertChainContext(cert_chain as *mut _)
            } else {
                return Err(io::Error::last_os_error())
            }
        };

        // Step 2: check the chain against the SSL chain policy.
        unsafe {
            // check if we trust the root-CA explicitly
            let mut para_flags = wincrypt::CERT_CHAIN_POLICY_IGNORE_ALL_REV_UNKNOWN_FLAGS;
            if let Some(ref mut store) = self.cert_store {
                if let Some(chain) = cert_chain.final_chain() {
                    // check if any cert of the chain is in the passed store (and therefore trusted)
                    if chain.certificates().any(|cert| store.certs().any(|root_cert| root_cert == cert)) {
                        para_flags |= wincrypt::CERT_CHAIN_POLICY_ALLOW_UNKNOWN_CA_FLAG;
                    }
                }
            }

            let mut extra_para: wincrypt::SSL_EXTRA_CERT_CHAIN_POLICY_PARA = mem::zeroed();
            *extra_para.u.cbSize_mut() = mem::size_of_val(&extra_para) as winapi::DWORD;
            extra_para.dwAuthType = wincrypt::AUTHTYPE_SERVER;
            // Supply the expected server name (enabling the hostname check)
            // unless the builder opted out via `accept_invalid_hostnames`.
            match self.domain {
                Some(ref mut domain) if !self.accept_invalid_hostnames => {
                    extra_para.pwszServerName = domain.as_mut_ptr();
                }
                _ => {}
            }

            let mut para: wincrypt::CERT_CHAIN_POLICY_PARA = mem::zeroed();
            para.cbSize = mem::size_of_val(&para) as winapi::DWORD;
            para.dwFlags = para_flags;
            para.pvExtraPolicyPara = &mut extra_para as *mut _ as *mut _;

            let mut status: wincrypt::CERT_CHAIN_POLICY_STATUS = mem::zeroed();
            status.cbSize = mem::size_of_val(&status) as winapi::DWORD;

            let verify_chain_policy_structure = wincrypt::CERT_CHAIN_POLICY_SSL as ntdef::LPCSTR;
            let res = wincrypt::CertVerifyCertificateChainPolicy(verify_chain_policy_structure,
                                                                cert_chain.0,
                                                                &mut para,
                                                                &mut status);
            // A FALSE return is treated as a failure of the call itself; the
            // verdict on the chain is read from `status.dwError` below.
            if res == winapi::FALSE {
                return Err(io::Error::last_os_error())
            }

            let mut verify_result = if status.dwError != winerror::ERROR_SUCCESS {
                Err(io::Error::from_raw_os_error(status.dwError as i32))
            } else {
                Ok(())
            };

            // check if there's a user-specified verify callback
            // (its return value replaces the built-in verdict)
            if let Some(ref callback) = self.verify_callback {
                verify_result = callback(CertValidationResult{
                    chain: cert_chain,
                    res: status.dwError as i32,
                    chain_index: status.lChainIndex,
                    element_index: status.lElementIndex});
            }
            verify_result?;
        }
        Ok(true)
    }
690 
write_out(&mut self) -> io::Result<usize>691     fn write_out(&mut self) -> io::Result<usize> {
692         let mut out = 0;
693         while self.out_buf.position() as usize != self.out_buf.get_ref().len() {
694             let position = self.out_buf.position() as usize;
695             let nwritten = self.stream.write(&self.out_buf.get_ref()[position..])?;
696             out += nwritten;
697             self.out_buf.set_position((position + nwritten) as u64);
698         }
699 
700         Ok(out)
701     }
702 
read_in(&mut self) -> io::Result<usize>703     fn read_in(&mut self) -> io::Result<usize> {
704         let mut sum_nread = 0;
705 
706         while self.needs_read > 0 {
707             let existing_len = self.enc_in.position() as usize;
708             let min_len = cmp::max(cmp::max(1024, 2 * existing_len), self.needs_read);
709             if self.enc_in.get_ref().len() < min_len {
710                 self.enc_in.get_mut().resize(min_len, 0);
711             }
712             let nread = {
713                 let buf = &mut self.enc_in.get_mut()[existing_len..];
714                 self.stream.read(buf)?
715             };
716             self.enc_in.set_position((existing_len + nread) as u64);
717             self.needs_read = self.needs_read.saturating_sub(nread);
718             if nread == 0 {
719                 break;
720             }
721             sum_nread += nread;
722         }
723 
724         Ok(sum_nread)
725     }
726 
consume_enc_in(&mut self, nread: usize)727     fn consume_enc_in(&mut self, nread: usize) {
728         let size = self.enc_in.position() as usize;
729         assert!(size >= nread);
730         let count = size - nread;
731 
732         if count > 0 {
733             self.enc_in.get_mut().drain(..nread);
734         }
735 
736         self.enc_in.set_position(count as u64);
737     }
738 
decrypt(&mut self) -> io::Result<bool>739     fn decrypt(&mut self) -> io::Result<bool> {
740         unsafe {
741             let position = self.enc_in.position() as usize;
742             let mut bufs = [secbuf(sspi::SECBUFFER_DATA,
743                                    Some(&mut self.enc_in.get_mut()[..position])),
744                             secbuf(sspi::SECBUFFER_EMPTY, None),
745                             secbuf(sspi::SECBUFFER_EMPTY, None),
746                             secbuf(sspi::SECBUFFER_EMPTY, None)];
747             let mut bufdesc = secbuf_desc(&mut bufs);
748 
749             match sspi::DecryptMessage(self.context.get_mut(),
750                                           &mut bufdesc,
751                                           0,
752                                           ptr::null_mut()) {
753                 winerror::SEC_E_OK => {
754                     let start = bufs[1].pvBuffer as usize - self.enc_in.get_ref().as_ptr() as usize;
755                     let end = start + bufs[1].cbBuffer as usize;
756                     self.dec_in.get_mut().clear();
757                     self.dec_in
758                         .get_mut()
759                         .extend_from_slice(&self.enc_in.get_ref()[start..end]);
760                     self.dec_in.set_position(0);
761 
762                     let nread = if bufs[3].BufferType == sspi::SECBUFFER_EXTRA {
763                         self.enc_in.position() as usize - bufs[3].cbBuffer as usize
764                     } else {
765                         self.enc_in.position() as usize
766                     };
767                     self.consume_enc_in(nread);
768                     self.needs_read = (self.enc_in.position() == 0) as usize;
769                     Ok(false)
770                 }
771                 winerror::SEC_E_INCOMPLETE_MESSAGE => {
772                     self.needs_read = if bufs[1].BufferType == sspi::SECBUFFER_MISSING {
773                         bufs[1].cbBuffer as usize
774                     } else {
775                         1
776                     };
777                     Ok(false)
778                 }
779                 winerror::SEC_I_CONTEXT_EXPIRED => Ok(true),
780                 winerror::SEC_I_RENEGOTIATE => {
781                     self.state = State::Initializing {
782                         needs_flush: false,
783                         more_calls: true,
784                         shutting_down: false,
785                         validated: false,
786                     };
787 
788                     let nread = if bufs[3].BufferType == sspi::SECBUFFER_EXTRA {
789                         self.enc_in.position() as usize - bufs[3].cbBuffer as usize
790                     } else {
791                         self.enc_in.position() as usize
792                     };
793                     self.consume_enc_in(nread);
794                     self.needs_read = 0;
795                     Ok(false)
796                 }
797                 e => Err(io::Error::from_raw_os_error(e as i32)),
798             }
799         }
800     }
801 
encrypt(&mut self, buf: &[u8], sizes: &sspi::SecPkgContext_StreamSizes) -> io::Result<()>802     fn encrypt(&mut self, buf: &[u8], sizes: &sspi::SecPkgContext_StreamSizes) -> io::Result<()> {
803         assert!(buf.len() <= sizes.cbMaximumMessage as usize);
804 
805         unsafe {
806             let len = sizes.cbHeader as usize + buf.len() + sizes.cbTrailer as usize;
807 
808             if self.out_buf.get_ref().len() < len {
809                 self.out_buf.get_mut().resize(len, 0);
810             }
811 
812             let message_start = sizes.cbHeader as usize;
813             self.out_buf
814                 .get_mut()[message_start..message_start + buf.len()]
815                 .clone_from_slice(buf);
816 
817             let mut bufs = {
818                 let out_buf = self.out_buf.get_mut();
819                 let size = sizes.cbHeader as usize;
820 
821                 let header = secbuf(sspi::SECBUFFER_STREAM_HEADER,
822                                     Some(&mut out_buf[..size]));
823                 let data = secbuf(sspi::SECBUFFER_DATA,
824                                   Some(&mut out_buf[size..size + buf.len()]));
825                 let trailer = secbuf(sspi::SECBUFFER_STREAM_TRAILER,
826                                      Some(&mut out_buf[size + buf.len()..]));
827                 let empty = secbuf(sspi::SECBUFFER_EMPTY, None);
828                 [header, data, trailer, empty]
829             };
830             let mut bufdesc = secbuf_desc(&mut bufs);
831 
832             match sspi::EncryptMessage(self.context.get_mut(), 0, &mut bufdesc, 0) {
833                 winerror::SEC_E_OK => {
834                     let len = bufs[0].cbBuffer + bufs[1].cbBuffer + bufs[2].cbBuffer;
835                     self.out_buf.get_mut().truncate(len as usize);
836                     self.out_buf.set_position(0);
837                     Ok(())
838                 }
839                 err => Err(io::Error::from_raw_os_error(err as i32)),
840             }
841         }
842     }
843 }
844 
845 impl<S> MidHandshakeTlsStream<S> {
846     /// Returns a shared reference to the inner stream.
get_ref(&self) -> &S847     pub fn get_ref(&self) -> &S {
848         self.inner.get_ref()
849     }
850 
851     /// Returns a mutable reference to the inner stream.
get_mut(&mut self) -> &mut S852     pub fn get_mut(&mut self) -> &mut S {
853         self.inner.get_mut()
854     }
855 }
856 
857 impl<S> MidHandshakeTlsStream<S>
858     where S: Read + Write,
859 {
860     /// Restarts the handshake process.
handshake(mut self) -> Result<TlsStream<S>, HandshakeError<S>>861     pub fn handshake(mut self) -> Result<TlsStream<S>, HandshakeError<S>> {
862         match self.inner.initialize() {
863             Ok(_) => Ok(self.inner),
864             Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
865                 Err(HandshakeError::Interrupted(self))
866             }
867             Err(e) => Err(HandshakeError::Failure(e)),
868         }
869     }
870 }
871 
872 impl<S> Write for TlsStream<S>
873     where S: Read + Write
874 {
875     /// In the case of a WouldBlock error, we expect another call
876     /// starting with the same input data
877     /// This is similar to the use of ACCEPT_MOVING_WRITE_BUFFER in openssl
write(&mut self, buf: &[u8]) -> io::Result<usize>878     fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
879         let sizes = match self.initialize()? {
880             Some(sizes) => sizes,
881             None => return Err(io::Error::from_raw_os_error(winerror::SEC_E_CONTEXT_EXPIRED as i32)),
882         };
883 
884         // if we have pending output data, it must have been because a previous
885         // attempt to send this part of the data ran into an error.
886         if self.out_buf.position() == self.out_buf.get_ref().len() as u64 {
887             let len = cmp::min(buf.len(), sizes.cbMaximumMessage as usize);
888             self.encrypt(&buf[..len], &sizes)?;
889             self.last_write_len = len;
890         }
891         self.write_out()?;
892 
893         Ok(self.last_write_len)
894     }
895 
flush(&mut self) -> io::Result<()>896     fn flush(&mut self) -> io::Result<()> {
897         // Make sure the write buffer is emptied
898         self.write_out()?;
899         self.stream.flush()
900     }
901 }
902 
903 impl<S> Read for TlsStream<S>
904     where S: Read + Write
905 {
read(&mut self, buf: &mut [u8]) -> io::Result<usize>906     fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
907         let nread = {
908             let read_buf = self.fill_buf()?;
909             let nread = cmp::min(buf.len(), read_buf.len());
910             buf[..nread].copy_from_slice(&read_buf[..nread]);
911             nread
912         };
913         self.consume(nread);
914         Ok(nread)
915     }
916 }
917 
918 impl<S> BufRead for TlsStream<S>
919     where S: Read + Write
920 {
fill_buf(&mut self) -> io::Result<&[u8]>921     fn fill_buf(&mut self) -> io::Result<&[u8]> {
922         while self.get_buf().is_empty() {
923             if let None = self.initialize()? {
924                 break;
925             }
926 
927             if self.needs_read > 0 {
928                 if self.read_in()? == 0 {
929                     break;
930                 }
931                 self.needs_read = 0;
932             }
933 
934             let eof = self.decrypt()?;
935             if eof {
936                 break;
937             }
938         }
939 
940         Ok(self.get_buf())
941     }
942 
consume(&mut self, amt: usize)943     fn consume(&mut self, amt: usize) {
944         let pos = self.dec_in.position() + amt as u64;
945         assert!(pos <= self.dec_in.get_ref().len() as u64);
946         self.dec_in.set_position(pos);
947     }
948 }
949