1 //! Schannel TLS streams.
2 use std::any::Any;
3 use std::cmp;
4 use std::error::Error;
5 use std::fmt;
6 use std::io::{self, Read, BufRead, Write, Cursor};
7 use std::mem;
8 use std::ptr;
9 use std::slice;
10 use std::sync::Arc;
11 use winapi::shared::minwindef as winapi;
12 use winapi::shared::{ntdef, sspi, winerror};
13 use winapi::um::{self, schannel, wincrypt};
14 
15 use crate::{INIT_REQUESTS, ACCEPT_REQUESTS, Inner, secbuf, secbuf_desc};
16 use crate::alpn_list::AlpnList;
17 use crate::cert_chain::{CertChain, CertChainContext};
18 use crate::cert_store::{CertAdd, CertStore};
19 use crate::cert_context::CertContext;
20 use crate::security_context::SecurityContext;
21 use crate::context_buffer::ContextBuffer;
22 use crate::schannel_cred::SchannelCred;
23 
24 lazy_static! {
25     static ref szOID_PKIX_KP_SERVER_AUTH: Vec<u8> =
26         wincrypt::szOID_PKIX_KP_SERVER_AUTH.bytes().chain(Some(0)).collect();
27     static ref szOID_SERVER_GATED_CRYPTO: Vec<u8> =
28         wincrypt::szOID_SERVER_GATED_CRYPTO.bytes().chain(Some(0)).collect();
29     static ref szOID_SGC_NETSCAPE: Vec<u8> =
30         wincrypt::szOID_SGC_NETSCAPE.bytes().chain(Some(0)).collect();
31 }
32 
/// A builder type for `TlsStream`s.
///
/// The options set here are copied into every stream created with
/// `connect` or `accept`.
pub struct Builder {
    /// NUL-terminated UTF-16 domain name, if one was set.
    domain: Option<Vec<u16>>,
    /// Whether `domain` is passed to Schannel for Server Name Indication.
    use_sni: bool,
    /// When `true`, the domain is not supplied to chain-policy verification,
    /// i.e. the hostname is not checked.
    accept_invalid_hostnames: bool,
    /// Optional user hook whose return value replaces the built-in
    /// certificate-validation result.
    verify_callback: Option<Arc<dyn Fn(CertValidationResult) -> io::Result<()> + Sync + Send>>,
    /// Extra certificates used (and trusted) when validating the server chain.
    cert_store: Option<CertStore>,
    /// ALPN protocol identifiers offered during the handshake.
    requested_application_protocols: Option<Vec<Vec<u8>>>,
}
42 
43 impl Default for Builder {
44     fn default() -> Builder {
45         Builder {
46             domain: None,
47             use_sni: true,
48             accept_invalid_hostnames: false,
49             verify_callback: None,
50             cert_store: None,
51             requested_application_protocols: None,
52         }
53     }
54 }
55 
56 impl Builder {
57     /// Returns a new `Builder`.
58     pub fn new() -> Builder {
59         Builder::default()
60     }
61 
62     /// Sets the domain associated with connections created with this `Builder`.
63     ///
64     /// The domain will be used for Server Name Indication as well as
65     /// certificate validation.
66     pub fn domain(&mut self, domain: &str) -> &mut Builder {
67         self.domain = Some(domain.encode_utf16().chain(Some(0)).collect());
68         self
69     }
70 
71     /// Determines if Server Name Indication (SNI) will be used.
72     ///
73     /// Defaults to `true`.
74     pub fn use_sni(&mut self, use_sni: bool) -> &mut Builder {
75         self.use_sni = use_sni;
76         self
77     }
78 
79     /// Determines if the server's hostname will be checked during certificate verification.
80     ///
81     /// Defaults to `false`.
82     pub fn accept_invalid_hostnames(&mut self, accept_invalid_hostnames: bool) -> &mut Builder {
83         self.accept_invalid_hostnames = accept_invalid_hostnames;
84         self
85     }
86 
87     /// Set a verification callback to be used for connections created with this `Builder`.
88     ///
89     /// The callback is provided with an io::Result indicating if the (pre)validation was
90     /// successful. The Ok() variant indicates a successful validation while the Err() variant
91     /// contains the errorcode returned from the internal verification process.
92     /// The validated certificate, is accessible through the second argument of the closure.
93     pub fn verify_callback<F>(&mut self, callback: F) -> &mut Builder
94         where F: Fn(CertValidationResult) -> io::Result<()> + 'static + Sync + Send
95     {
96         self.verify_callback = Some(Arc::new(callback));
97         self
98     }
99 
100     /// Specifies a custom certificate store which is later used when validating
101     /// a server's certificate.
102     ///
103     /// This option is only used for client connections and is used to construct
104     /// the certificate chain which the server's certificate is validated
105     /// against.
106     ///
107     /// Note that adding certificates here means that they are
108     /// implicitly trusted.
109     pub fn cert_store(&mut self, cert_store: CertStore) -> &mut Builder {
110         self.cert_store = Some(cert_store);
111         self
112     }
113 
114     /// Requests one of a set of application protocols using alpn
115     pub fn request_application_protocols(&mut self, alpns: &[&[u8]]) -> &mut Builder {
116         self.requested_application_protocols =
117             Some(alpns.iter().map(|bytes| bytes.to_vec()).collect::<Vec<_>>());
118         self
119     }
120 
121     /// Initialize a new TLS session where the stream provided will be
122     /// connecting to a remote TLS server.
123     ///
124     /// If the stream provided is a blocking stream then the entire handshake
125     /// will be performed if possible, but if the stream is in nonblocking mode
126     /// then a `HandshakeError::Interrupted` variant may be returned. This
127     /// type can then be extracted to later call
128     /// `MidHandshakeTlsStream::handshake` when data becomes available.
129     pub fn connect<S>(&mut self,
130                       cred: SchannelCred,
131                       stream: S)
132                       -> Result<TlsStream<S>, HandshakeError<S>>
133         where S: Read + Write
134     {
135         self.initialize(cred, false, stream)
136     }
137 
138     /// Initialize a new TLS session where the stream provided will be
139     /// accepting a connection.
140     ///
141     /// This method will tweak the protocol for "who talks first" and also
142     /// currently disables validation of the client that's connecting to us.
143     ///
144     /// If the stream provided is a blocking stream then the entire handshake
145     /// will be performed if possible, but if the stream is in nonblocking mode
146     /// then a `HandshakeError::Interrupted` variant may be returned. This
147     /// type can then be extracted to later call
148     /// `MidHandshakeTlsStream::handshake` when data becomes available.
149     pub fn accept<S>(&mut self,
150                      cred: SchannelCred,
151                      stream: S)
152                      -> Result<TlsStream<S>, HandshakeError<S>>
153         where S: Read + Write
154     {
155         self.initialize(cred, true, stream)
156     }
157 
func_unsetnull158     fn initialize<S>(&mut self,
159                      mut cred: SchannelCred,
160                      server: bool,
161                      stream: S)
162                          -> Result<TlsStream<S>, HandshakeError<S>>
163         where S: Read + Write
164     {
165         let domain = match self.domain {
166             Some(ref domain) if self.use_sni => Some(&domain[..]),
167             _ => None,
168         };
169         let (ctxt, buf) = match SecurityContext::initialize(&mut cred,
170                                                             server,
171                                                             domain,
172                                                             &self.requested_application_protocols) {
173             Ok(pair) => pair,
174             Err(e) => return Err(HandshakeError::Failure(e)),
175         };
176 
177         let stream = TlsStream {
178             cred: cred,
179             context: ctxt,
180             cert_store: self.cert_store.clone(),
181             domain: self.domain.clone(),
182             use_sni: self.use_sni,
183             accept_invalid_hostnames: self.accept_invalid_hostnames,
184             verify_callback: self.verify_callback.clone(),
185             stream: stream,
186             server: server,
187             accept_first: true,
188             state: State::Initializing {
189                 needs_flush: false,
190                 more_calls: true,
191                 shutting_down: false,
192                 validated: false,
193             },
194             needs_read: 1,
195             dec_in: Cursor::new(Vec::new()),
196             enc_in: Cursor::new(Vec::new()),
197             out_buf: Cursor::new(buf.map(|b| b.to_owned()).unwrap_or(Vec::new())),
198             last_write_len: 0,
199             requested_application_protocols: self.requested_application_protocols.clone(),
200         };
201 
202         MidHandshakeTlsStream {
203             inner: stream,
204         }.handshake()
205     }
206 }
207 
/// Lifecycle of a `TlsStream`.
enum State {
    /// A handshake, or a shutdown exchange when `shutting_down` is set, is
    /// still being driven by `TlsStream::initialize`.
    Initializing {
        /// Bytes were written to the wrapped stream but not yet flushed.
        needs_flush: bool,
        /// Schannel still expects further handshake calls.
        more_calls: bool,
        /// The exchange in progress was started by `shutdown`.
        shutting_down: bool,
        /// The peer certificate has already been validated.
        validated: bool,
    },
    /// Handshake complete; holds the stream sizes queried from the context.
    Streaming { sizes: sspi::SecPkgContext_StreamSizes, },
    /// The session has been shut down.
    Shutdown,
}
218 
/// A Schannel TLS stream.
// NOTE(review): fields are dropped in declaration order; keep that in mind
// before reordering them.
pub struct TlsStream<S> {
    /// Credential handle passed to `AcceptSecurityContext` /
    /// `InitializeSecurityContextW`.
    cred: SchannelCred,
    /// The Schannel security context for this session.
    context: SecurityContext,
    /// Extra certificates used (and trusted) when validating the server chain.
    cert_store: Option<CertStore>,
    /// NUL-terminated UTF-16 domain, used for SNI and hostname verification.
    domain: Option<Vec<u16>>,
    /// Whether `domain` is sent as SNI.
    use_sni: bool,
    /// When `true`, the domain is not passed to chain-policy verification.
    accept_invalid_hostnames: bool,
    /// Optional user hook whose result replaces the built-in validation result.
    verify_callback: Option<Arc<dyn Fn(CertValidationResult) -> io::Result<()> + Sync + Send>>,
    /// The wrapped transport stream.
    stream: S,
    /// Current lifecycle state.
    state: State,
    /// `true` for the accepting side.
    server: bool,
    /// Server side: a null input context is passed to `AcceptSecurityContext`
    /// until a call has returned `SEC_I_CONTINUE_NEEDED`.
    accept_first: bool,
    /// Minimum number of additional encrypted bytes to read before the next
    /// Schannel call.
    needs_read: usize,
    // Decrypted plaintext; valid from position() to len()
    dec_in: Cursor<Vec<u8>>,
    // Encrypted bytes received from the peer; valid from 0 to position()
    enc_in: Cursor<Vec<u8>>,
    // Bytes queued for writing to the peer; valid from position() to len()
    out_buf: Cursor<Vec<u8>>,
    /// the (unencrypted) length of the last write call used to track writes
    last_write_len: usize,
    /// ALPN protocol identifiers offered during the handshake.
    requested_application_protocols: Option<Vec<Vec<u8>>>,
}
243 
func_check_prog_sednull244 /// ensures that a TlsStream is always Sync/Send
245 fn _is_sync() {
246     fn sync<T: Sync + Send>() {}
247     sync::<TlsStream<()>>();
248 }
249 
250 /// A failure which can happen during the `Builder::initialize` phase, either an
251 /// I/O error or an intermediate stream which has not completed its handshake.
252 #[derive(Debug)]
253 pub enum HandshakeError<S> {
254     /// A fatal I/O error occurred
255     Failure(io::Error),
256     /// The stream connection is in progress, but the handshake is not completed
257     /// yet.
258     Interrupted(MidHandshakeTlsStream<S>),
259 }
260 
/// A struct used to wrap various cert chain validation results for callback processing.
///
/// Instances are passed to the callback registered with `Builder::verify_callback`.
pub struct CertValidationResult {
    /// Chain context built for the peer certificate.
    chain: CertChainContext,
    /// Error code from chain-policy verification (`ERROR_SUCCESS` on success).
    res: i32,
    /// `lChainIndex` from the chain-policy status.
    chain_index: i32,
    /// `lElementIndex` from the chain-policy status.
    element_index: i32,
}
268 
269 impl CertValidationResult {
270     /// Returns the certificate that failed validation if applicable
271     pub fn failed_certificate(&self) -> Option<CertContext> {
272         if let Some(cert_chain) = self.chain.get_chain(self.chain_index as usize) {
273             return cert_chain.get(self.element_index as usize);
274         }
275         None
276     }
277 
278     /// Returns the final certificate chain in the certificate context if applicable
func_check_prog_grepnull279     pub fn chain(&self) -> Option<CertChain> {
280         self.chain.final_chain()
281     }
282 
283     /// Returns the result of the built-in certificate verification process.
284     pub fn result(&self) -> io::Result<()> {
285         if self.res as u32 != winerror::ERROR_SUCCESS {
286                 Err(io::Error::from_raw_os_error(self.res))
287         } else {
288                 Ok(())
289         }
290     }
291 }
292 
293 impl<S: fmt::Debug + Any> Error for HandshakeError<S> {
294     fn source(&self) -> Option<&(dyn Error + 'static)> {
295         match *self {
296             HandshakeError::Failure(ref e) => Some(e),
297             HandshakeError::Interrupted(_) => None,
298         }
299     }
300 }
301 
302 impl<S: fmt::Debug + Any> fmt::Display for HandshakeError<S> {
303     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
304         let desc = match *self {
305             HandshakeError::Failure(_) => "failed to perform handshake",
306             HandshakeError::Interrupted(_) => "interrupted performing handshake",
307         };
308         write!(f, "{}", desc)?;
309         if let Some(e) = self.source() {
310            write!(f, ": {}", e)?;
311         }
312         Ok(())
313     }
314 }
315 
/// A stream which has not yet completed its handshake.
///
/// Carried by `HandshakeError::Interrupted`; wraps the partially initialized
/// `TlsStream`.
#[derive(Debug)]
pub struct MidHandshakeTlsStream<S> {
    inner: TlsStream<S>,
}
321 
322 impl<S> fmt::Debug for TlsStream<S>
323     where S: fmt::Debug
324 {
325     fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
326         fmt.debug_struct("TlsStream")
327             .field("stream", &self.stream)
328             .finish()
329     }
330 }
331 
332 impl<S> TlsStream<S> {
333     /// Returns a reference to the wrapped stream.
334     pub fn get_ref(&self) -> &S {
335         &self.stream
336     }
337 
338     /// Returns a mutable reference to the wrapped stream.
339     pub fn get_mut(&mut self) -> &mut S {
340         &mut self.stream
341     }
342 
343     /// Indicates if this stream is the server- or client-side of a TLS session.
344     pub fn is_server(&self) -> bool {
345         self.server
346     }
347 }
348 
349 impl<S> TlsStream<S>
350     where S: Read + Write
351 {
352     /// Returns the certificate used to identify this side of the TLS session.
353     ///
354     /// Its associated cert store contains any intermediate certificates sent
355     /// along with the leaf.
356     pub fn certificate(&self) -> io::Result<CertContext> {
357         self.context.local_cert()
358     }
359 
360     /// Returns the peer's certificate, if available.
361     ///
362     /// Its associated cert store contains any intermediate certificates sent
363     /// by the server.
364     pub fn peer_certificate(&self) -> io::Result<CertContext> {
365         self.context.remote_cert()
366     }
367 
368     /// Returns the negotiated application protocol for this tls stream, if one exists
369     pub fn negotiated_application_protocol(&self) -> io::Result<Option<Vec<u8>>> {
370         let client_proto = self.context.application_protocol()?;
371         if client_proto.ProtoNegoStatus != sspi::SecApplicationProtocolNegotiationStatus_Success
372             || client_proto.ProtoNegoExt != sspi::SecApplicationProtocolNegotiationExt_ALPN
373         {
374             return Ok(None);
375         }
376         Ok(Some(client_proto.ProtocolId[..client_proto.ProtocolIdSize as usize].to_vec()))
377     }
378 
379     /// Returns whether or not the session was resumed.
380     pub fn session_resumed(&self) -> io::Result<bool> {
381         let session_info = self.context.session_info()?;
382         Ok(session_info.dwFlags & schannel::SSL_SESSION_RECONNECT > 0)
383     }
384 
385     /// Returns a reference to the buffer of pending data.
386     ///
387     /// Like `BufRead::fill_buf` except that it will return an empty slice
388     /// rather than reading from the wrapped stream if there is no buffered
389     /// data.
390     pub fn get_buf(&self) -> &[u8] {
391         &self.dec_in.get_ref()[self.dec_in.position() as usize..]
392     }
393 
    /// Shuts the TLS session down.
    ///
    /// Applies the `SCHANNEL_SHUTDOWN` control token to the security context
    /// and then drives the resulting exchange through `initialize`. If the
    /// session is already shut down, this returns `Ok(())` immediately. If a
    /// shutdown is already in progress, it is resumed without reapplying the
    /// token.
    ///
    /// # Errors
    ///
    /// Returns the status from `ApplyControlToken` as a raw OS error if it is
    /// not `SEC_E_OK`. Also propagates any error raised by `initialize` while
    /// sending the shutdown messages.
    pub fn shutdown(&mut self) -> io::Result<()> {
        match self.state {
            State::Shutdown => return Ok(()),
            // A previous shutdown was interrupted; just keep driving it.
            State::Initializing { shutting_down: true, .. } => {}
            _ => {
                unsafe {
                    // View the shutdown control token as a byte slice so it
                    // can be passed in a SECBUFFER_TOKEN buffer.
                    let mut token = um::schannel::SCHANNEL_SHUTDOWN;
                    let ptr = &mut token as *mut _ as *mut u8;
                    let size = mem::size_of_val(&token);
                    let token = slice::from_raw_parts_mut(ptr, size);
                    let mut buf = [secbuf(sspi::SECBUFFER_TOKEN, Some(token))];
                    let mut desc = secbuf_desc(&mut buf);

                    match sspi::ApplyControlToken(self.context.get_mut(), &mut desc) {
                        winerror::SEC_E_OK => {}
                        err => return Err(io::Error::from_raw_os_error(err as i32)),
                    }
                }

                // Re-enter the handshake loop in shutdown mode. With
                // `needs_read == 0`, the next step runs without first reading
                // from the peer.
                self.state = State::Initializing {
                    needs_flush: false,
                    more_calls: true,
                    shutting_down: true,
                    validated: false,
                };
                self.needs_read = 0;
            }
        }

        self.initialize().map(|_| ())
    }
426 
427     fn step_initialize(&mut self) -> io::Result<()> {
428         unsafe {
429             let pos = self.enc_in.position() as usize;
430             let mut inbufs = vec![secbuf(sspi::SECBUFFER_TOKEN,
431                                          Some(&mut self.enc_in.get_mut()[..pos])),
432                                   secbuf(sspi::SECBUFFER_EMPTY, None)];
433             // Make sure `AlpnList` is kept alive for the duration of this function.
434             let mut alpns = self.requested_application_protocols.as_ref().map(|alpn| AlpnList::new(&alpn));
435             if let Some(ref mut alpns) = alpns {
436                 inbufs.push(secbuf(sspi::SECBUFFER_APPLICATION_PROTOCOLS,
437                                    Some(&mut alpns[..])));
438             };
439             let mut inbuf_desc = secbuf_desc(&mut inbufs[..]);
440 
441             let mut outbufs = [secbuf(sspi::SECBUFFER_TOKEN, None),
442                                secbuf(sspi::SECBUFFER_ALERT, None),
443                                secbuf(sspi::SECBUFFER_EMPTY, None)];
444             let mut outbuf_desc = secbuf_desc(&mut outbufs);
445 
446             let mut attributes = 0;
447 
448             let status = if self.server {
449                 let ptr = if self.accept_first {
450                     ptr::null_mut()
451                 } else {
452                     self.context.get_mut()
453                 };
454                 sspi::AcceptSecurityContext(&mut self.cred.as_inner(),
455                                             ptr,
456                                             &mut inbuf_desc,
457                                             ACCEPT_REQUESTS,
458                                             0,
459                                             self.context.get_mut(),
460                                             &mut outbuf_desc,
461                                             &mut attributes,
462                                             ptr::null_mut())
463             } else {
464                 let domain = match self.domain {
465                     Some(ref domain) if self.use_sni => domain.as_ptr() as *mut u16,
466                     _ => ptr::null_mut(),
467                 };
468 
469                 sspi::InitializeSecurityContextW(&mut self.cred.as_inner(),
470                                                  self.context.get_mut(),
471                                                  domain,
472                                                  INIT_REQUESTS,
473                                                  0,
474                                                  0,
475                                                  &mut inbuf_desc,
476                                                  0,
477                                                  ptr::null_mut(),
478                                                  &mut outbuf_desc,
479                                                  &mut attributes,
480                                                  ptr::null_mut())
481             };
482 
483             for buf in &outbufs[1..] {
484                 if !buf.pvBuffer.is_null() {
485                     sspi::FreeContextBuffer(buf.pvBuffer);
486                 }
487             }
488 
489             match status {
490                 winerror::SEC_I_CONTINUE_NEEDED => {
func_require_term_colorsnull491                     // Windows apparently doesn't like AcceptSecurityContext
492                     // being called as if it were the second time unless the
493                     // first call to AcceptSecurityContext succeeded with
494                     // CONTINUE_NEEDED.
495                     //
496                     // In other words, if we were to set `accept_first` to
497                     // `false` after the literal first call to
498                     // `AcceptSecurityContext` while the call returned
499                     // INCOMPLETE_MESSAGE, the next call would return an error.
500                     //
501                     // For that reason we only set `accept_first` to false here
502                     // once we've actually successfully received the full
503                     // "token" from the client.
504                     self.accept_first = false;
505                     let nread = if inbufs[1].BufferType == sspi::SECBUFFER_EXTRA {
506                         self.enc_in.position() as usize - inbufs[1].cbBuffer as usize
507                     } else {
508                         self.enc_in.position() as usize
509                     };
510                     let to_write = ContextBuffer(outbufs[0]);
511 
512                     self.consume_enc_in(nread);
513                     self.needs_read = (self.enc_in.position() == 0) as usize;
514                     self.out_buf.get_mut().extend_from_slice(&to_write);
515                 }
516                 winerror::SEC_E_INCOMPLETE_MESSAGE => {
517                     self.needs_read = if inbufs[1].BufferType == sspi::SECBUFFER_MISSING {
518                         inbufs[1].cbBuffer as usize
519                     } else {
520                         1
521                     };
522                 }
523                 winerror::SEC_E_OK => {
524                     let nread = if inbufs[1].BufferType == sspi::SECBUFFER_EXTRA {
525                         self.enc_in.position() as usize - inbufs[1].cbBuffer as usize
526                     } else {
527                         self.enc_in.position() as usize
528                     };
529                     let to_write = if outbufs[0].pvBuffer.is_null() {
530                         None
531                     } else {
532                         Some(ContextBuffer(outbufs[0]))
533                     };
534 
535                     self.consume_enc_in(nread);
536                     self.needs_read = (self.enc_in.position() == 0) as usize;
537                     if let Some(to_write) = to_write {
538                         self.out_buf.get_mut().extend_from_slice(&to_write);
539                     }
540                     if self.enc_in.position() != 0 {
541                         self.decrypt()?;
542                     }
543                     if let State::Initializing { ref mut more_calls, .. } = self.state {
544                         *more_calls = false;
545                     }
546                 }
547                 _ => {
548                     return Err(io::Error::from_raw_os_error(status as i32))
549                 }
550             }
551             Ok(())
552         }
553     }
554 
555     fn initialize(&mut self) -> io::Result<Option<sspi::SecPkgContext_StreamSizes>> {
556         loop {
557             match self.state {
558                 State::Initializing { mut needs_flush, more_calls, shutting_down, validated } => {
559                     if self.write_out()? > 0 {
560                         needs_flush = true;
561                         if let State::Initializing { ref mut needs_flush, .. } = self.state {
562                             *needs_flush = true;
563                         }
564                     }
565 
566                     if needs_flush {
567                         self.stream.flush()?;
568                         if let State::Initializing { ref mut needs_flush, .. } = self.state {
569                             *needs_flush = false;
570                         }
571                     }
572 
func_appendnull573                     if !shutting_down && !validated {
574                         // on the last call, we require a valid certificate
575                         if self.validate(!more_calls)? {
576                             if let State::Initializing { ref mut validated, .. } = self.state {
577                                 *validated = true;
578                             }
579                         }
580                     }
581 
582                     if !more_calls {
583                         self.state = if shutting_down {
584                             State::Shutdown
585                         } else {
586                             State::Streaming { sizes: self.context.stream_sizes()? }
587                         };
588                         continue;
589                     }
590 
591                     if self.needs_read > 0 {
592                         if self.read_in()? == 0 {
593                             return Err(io::Error::new(io::ErrorKind::UnexpectedEof,
594                                                       "unexpected EOF during handshake"));
func_append_quotednull595                         }
596                     }
597 
598                     self.step_initialize()?;
599                 }
600                 State::Streaming { sizes } => return Ok(Some(sizes)),
601                 State::Shutdown => return Ok(None),
602             }
603         }
604     }
605 
606     /// Returns true when the certificate was succesfully verified
607     /// Returns false, when a verification isn't necessary (yet)
608     /// Returns an error when the verification failed
609     fn validate(&mut self, require_cert: bool) -> io::Result<bool> {
610         // If we're accepting connections then we don't perform any validation
611         // for the remote certificate, that's what they're doing!
612         if self.server {
613             return Ok(false);
614         }
func_append_uniqnull615 
616         let cert_context = match self.context.remote_cert() {
617             Err(_) if !require_cert => return Ok(false),
618             ret => ret?
619         };
620 
621         let cert_chain = unsafe {
622             let cert_store = match (cert_context.cert_store(), &self.cert_store) {
623                 (Some(ref mut chain_certs), &Some(ref extra_certs)) => {
624                     for extra_cert in extra_certs.certs() {
625                         chain_certs.add_cert(&extra_cert, CertAdd::ReplaceExisting)?;
626                     }
627                     chain_certs.as_inner()
628                 },
629                 (Some(chain_certs), &None) => chain_certs.as_inner(),
630                 (None, &Some(ref extra_certs)) => extra_certs.as_inner(),
631                 (None, &None) => ptr::null_mut()
632             };
633 
634             let flags = wincrypt::CERT_CHAIN_CACHE_END_CERT |
635                         wincrypt::CERT_CHAIN_REVOCATION_CHECK_CACHE_ONLY |
636                         wincrypt::CERT_CHAIN_REVOCATION_CHECK_CHAIN_EXCLUDE_ROOT;
637 
638             let mut para: wincrypt::CERT_CHAIN_PARA = mem::zeroed();
639             para.cbSize = mem::size_of_val(&para) as winapi::DWORD;
640             para.RequestedUsage.dwType = wincrypt::USAGE_MATCH_TYPE_OR;
641 
642             let mut identifiers = [szOID_PKIX_KP_SERVER_AUTH.as_ptr() as ntdef::LPSTR,
643                                    szOID_SERVER_GATED_CRYPTO.as_ptr() as ntdef::LPSTR,
func_arithnull644                                    szOID_SGC_NETSCAPE.as_ptr() as ntdef::LPSTR];
645             para.RequestedUsage.Usage.cUsageIdentifier = identifiers.len() as winapi::DWORD;
646             para.RequestedUsage.Usage.rgpszUsageIdentifier = identifiers.as_mut_ptr();
647 
648             let mut cert_chain = mem::zeroed();
649 
650             let res = wincrypt::CertGetCertificateChain(ptr::null_mut(),
651                                                         cert_context.as_inner(),
652                                                         ptr::null_mut(),
653                                                         cert_store,
654                                                         &mut para,
655                                                         flags,
656                                                         ptr::null_mut(),
657                                                         &mut cert_chain);
658 
659             if res == winapi::TRUE {
660                 CertChainContext(cert_chain as *mut _)
661             } else {
662                 return Err(io::Error::last_os_error())
663             }
664         };
665 
666         unsafe {
667             // check if we trust the root-CA explicitly
668             let mut para_flags = wincrypt::CERT_CHAIN_POLICY_IGNORE_ALL_REV_UNKNOWN_FLAGS;
669             if let Some(ref mut store) = self.cert_store {
670                 if let Some(chain) = cert_chain.final_chain() {
671                     // check if any cert of the chain is in the passed store (and therefore trusted)
672                     if chain.certificates().any(|cert| store.certs().any(|root_cert| root_cert == cert)) {
673                         para_flags |= wincrypt::CERT_CHAIN_POLICY_ALLOW_UNKNOWN_CA_FLAG;
674                     }
675                 }
676             }
677 
678             let mut extra_para: wincrypt::SSL_EXTRA_CERT_CHAIN_POLICY_PARA = mem::zeroed();
679             *extra_para.u.cbSize_mut() = mem::size_of_val(&extra_para) as winapi::DWORD;
680             extra_para.dwAuthType = wincrypt::AUTHTYPE_SERVER;
681             match self.domain {
682                 Some(ref mut domain) if !self.accept_invalid_hostnames => {
683                     extra_para.pwszServerName = domain.as_mut_ptr();
684                 }
685                 _ => {}
686             }
687 
688             let mut para: wincrypt::CERT_CHAIN_POLICY_PARA = mem::zeroed();
689             para.cbSize = mem::size_of_val(&para) as winapi::DWORD;
690             para.dwFlags = para_flags;
691             para.pvExtraPolicyPara = &mut extra_para as *mut _ as *mut _;
692 
693             let mut status: wincrypt::CERT_CHAIN_POLICY_STATUS = mem::zeroed();
694             status.cbSize = mem::size_of_val(&status) as winapi::DWORD;
695 
696             let verify_chain_policy_structure = wincrypt::CERT_CHAIN_POLICY_SSL as ntdef::LPCSTR;
697             let res = wincrypt::CertVerifyCertificateChainPolicy(verify_chain_policy_structure,
698                                                                 cert_chain.0,
699                                                                 &mut para,
700                                                                 &mut status);
701             if res == winapi::FALSE {
702                 return Err(io::Error::last_os_error())
703             }
704 
705             let mut verify_result = if status.dwError != winerror::ERROR_SUCCESS {
706                 Err(io::Error::from_raw_os_error(status.dwError as i32))
707             } else {
708                 Ok(())
709             };
710 
711             // check if there's a user-specified verify callback
712             if let Some(ref callback) = self.verify_callback {
713                 verify_result = callback(CertValidationResult{
714                     chain: cert_chain,
715                     res: status.dwError as i32,
716                     chain_index: status.lChainIndex,
717                     element_index: status.lElementIndex});
718             }
719             verify_result?;
720         }
721         Ok(true)
func_echonull722     }
723 
724     fn write_out(&mut self) -> io::Result<usize> {
725         let mut out = 0;
726         while self.out_buf.position() as usize != self.out_buf.get_ref().len() {
727             let position = self.out_buf.position() as usize;
728             let nwritten = self.stream.write(&self.out_buf.get_ref()[position..])?;
729             out += nwritten;
730             self.out_buf.set_position((position + nwritten) as u64);
731         }
732 
733         Ok(out)
734     }
735 
736     fn read_in(&mut self) -> io::Result<usize> {
737         let mut sum_nread = 0;
738 
739         while self.needs_read > 0 {
740             let existing_len = self.enc_in.position() as usize;
func_echo_allnull741             let min_len = cmp::max(cmp::max(1024, 2 * existing_len), self.needs_read);
742             if self.enc_in.get_ref().len() < min_len {
743                 self.enc_in.get_mut().resize(min_len, 0);
744             }
745             let nread = {
746                 let buf = &mut self.enc_in.get_mut()[existing_len..];
747                 self.stream.read(buf)?
748             };
749             self.enc_in.set_position((existing_len + nread) as u64);
750             self.needs_read = self.needs_read.saturating_sub(nread);
func_echo_infix_1null751             if nread == 0 {
752                 break;
753             }
754             sum_nread += nread;
755         }
756 
757         Ok(sum_nread)
758     }
759 
760     fn consume_enc_in(&mut self, nread: usize) {
761         let size = self.enc_in.position() as usize;
762         assert!(size >= nread);
763         let count = size - nread;
764 
765         if count > 0 {
766             self.enc_in.get_mut().drain(..nread);
767         }
768 
769         self.enc_in.set_position(count as u64);
770     }
771 
    /// Decrypts the buffered ciphertext in `enc_in` with `DecryptMessage`.
    ///
    /// The valid ciphertext (`enc_in[..position]`) is passed to Schannel as one
    /// `SECBUFFER_DATA` buffer, followed by three empty buffers for its outputs.
    ///
    /// Outcomes:
    /// * `SEC_E_OK`:
    ///   - The decrypted bytes are copied into `dec_in`, whose cursor is reset to 0.
    ///   - The consumed ciphertext is dropped from `enc_in`; any trailing bytes
    ///     reported as `SECBUFFER_EXTRA` are kept.
    ///   - `needs_read` becomes 1 if no ciphertext is left, otherwise 0.
    ///   - Returns `Ok(false)`.
    /// * `SEC_E_INCOMPLETE_MESSAGE`:
    ///   - Nothing is consumed.
    ///   - `needs_read` is set to the byte count of a `SECBUFFER_MISSING` buffer at
    ///     index 1, or to 1 if no such buffer is reported.
    ///   - Returns `Ok(false)`.
    /// * `SEC_I_CONTEXT_EXPIRED`: returns `Ok(true)`, which `fill_buf` uses to stop
    ///   reading.
    /// * `SEC_I_RENEGOTIATE`:
    ///   - The stream state is reset to `State::Initializing`.
    ///   - The consumed ciphertext is dropped, keeping any `SECBUFFER_EXTRA` tail.
    ///   - `needs_read` is cleared.
    ///   - Returns `Ok(false)`.
    /// * Any other status is returned as an OS error.
    fn decrypt(&mut self) -> io::Result<bool> {
        unsafe {
            let position = self.enc_in.position() as usize;
            let mut bufs = [secbuf(sspi::SECBUFFER_DATA,
                                   Some(&mut self.enc_in.get_mut()[..position])),
                            secbuf(sspi::SECBUFFER_EMPTY, None),
                            secbuf(sspi::SECBUFFER_EMPTY, None),
                            secbuf(sspi::SECBUFFER_EMPTY, None)];
            let mut bufdesc = secbuf_desc(&mut bufs);

            match sspi::DecryptMessage(self.context.get_mut(),
                                          &mut bufdesc,
                                          0,
                                          ptr::null_mut()) {
                winerror::SEC_E_OK => {
                    // The offset computation relies on bufs[1] (assumed to be the
                    // decrypted data buffer) pointing into enc_in's allocation,
                    // i.e. on the data having been decrypted in place.
                    let start = bufs[1].pvBuffer as usize - self.enc_in.get_ref().as_ptr() as usize;
                    let end = start + bufs[1].cbBuffer as usize;
                    self.dec_in.get_mut().clear();
                    self.dec_in
                        .get_mut()
                        .extend_from_slice(&self.enc_in.get_ref()[start..end]);
                    self.dec_in.set_position(0);

                    // A SECBUFFER_EXTRA in slot 3 covers unprocessed ciphertext at
                    // the tail of the input. Everything before it was consumed.
                    let nread = if bufs[3].BufferType == sspi::SECBUFFER_EXTRA {
                        self.enc_in.position() as usize - bufs[3].cbBuffer as usize
                    } else {
                        self.enc_in.position() as usize
                    };
                    self.consume_enc_in(nread);
                    // Ask for more input only if no ciphertext is left over.
                    self.needs_read = (self.enc_in.position() == 0) as usize;
                    Ok(false)
                }
                winerror::SEC_E_INCOMPLETE_MESSAGE => {
                    // Use Schannel's count of missing bytes when it provides one.
                    self.needs_read = if bufs[1].BufferType == sspi::SECBUFFER_MISSING {
                        bufs[1].cbBuffer as usize
                    } else {
                        1
                    };
                    Ok(false)
                }
                winerror::SEC_I_CONTEXT_EXPIRED => Ok(true),
                winerror::SEC_I_RENEGOTIATE => {
                    // Re-enter the handshake state machine.
                    self.state = State::Initializing {
                        needs_flush: false,
                        more_calls: true,
                        shutting_down: false,
                        validated: false,
                    };

                    // Same SECBUFFER_EXTRA handling as the SEC_E_OK arm.
                    let nread = if bufs[3].BufferType == sspi::SECBUFFER_EXTRA {
                        self.enc_in.position() as usize - bufs[3].cbBuffer as usize
                    } else {
                        self.enc_in.position() as usize
                    };
                    self.consume_enc_in(nread);
                    self.needs_read = 0;
                    Ok(false)
                }
                e => Err(io::Error::from_raw_os_error(e as i32)),
            }
        }
    }
834 
835     fn encrypt(&mut self, buf: &[u8], sizes: &sspi::SecPkgContext_StreamSizes) -> io::Result<()> {
func_lennull836         assert!(buf.len() <= sizes.cbMaximumMessage as usize);
837 
838         unsafe {
839             let len = sizes.cbHeader as usize + buf.len() + sizes.cbTrailer as usize;
840 
841             if self.out_buf.get_ref().len() < len {
842                 self.out_buf.get_mut().resize(len, 0);
843             }
844 
845             let message_start = sizes.cbHeader as usize;
846             self.out_buf
847                 .get_mut()[message_start..message_start + buf.len()]
func_mkdir_pnull848                 .clone_from_slice(buf);
849 
850             let mut bufs = {
851                 let out_buf = self.out_buf.get_mut();
852                 let size = sizes.cbHeader as usize;
853 
854                 let header = secbuf(sspi::SECBUFFER_STREAM_HEADER,
855                                     Some(&mut out_buf[..size]));
856                 let data = secbuf(sspi::SECBUFFER_DATA,
857                                   Some(&mut out_buf[size..size + buf.len()]));
858                 let trailer = secbuf(sspi::SECBUFFER_STREAM_TRAILER,
859                                      Some(&mut out_buf[size + buf.len()..]));
860                 let empty = secbuf(sspi::SECBUFFER_EMPTY, None);
861                 [header, data, trailer, empty]
862             };
863             let mut bufdesc = secbuf_desc(&mut bufs);
864 
865             match sspi::EncryptMessage(self.context.get_mut(), 0, &mut bufdesc, 0) {
866                 winerror::SEC_E_OK => {
867                     let len = bufs[0].cbBuffer + bufs[1].cbBuffer + bufs[2].cbBuffer;
868                     self.out_buf.get_mut().truncate(len as usize);
869                     self.out_buf.set_position(0);
870                     Ok(())
871                 }
872                 err => Err(io::Error::from_raw_os_error(err as i32)),
873             }
874         }
875     }
876 }
877 
878 impl<S> MidHandshakeTlsStream<S> {
879     /// Returns a shared reference to the inner stream.
880     pub fn get_ref(&self) -> &S {
881         self.inner.get_ref()
882     }
883 
884     /// Returns a mutable reference to the inner stream.
885     pub fn get_mut(&mut self) -> &mut S {
886         self.inner.get_mut()
887     }
888 }
889 
890 impl<S> MidHandshakeTlsStream<S>
891     where S: Read + Write,
892 {
893     /// Restarts the handshake process.
894     pub fn handshake(mut self) -> Result<TlsStream<S>, HandshakeError<S>> {
895         match self.inner.initialize() {
896             Ok(_) => Ok(self.inner),
897             Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
func_mktempdirnull898                 Err(HandshakeError::Interrupted(self))
899             }
900             Err(e) => Err(HandshakeError::Failure(e)),
901         }
902     }
903 }
904 
905 impl<S> Write for TlsStream<S>
906     where S: Read + Write
907 {
908     /// In the case of a WouldBlock error, we expect another call
909     /// starting with the same input data
910     /// This is similar to the use of ACCEPT_MOVING_WRITE_BUFFER in openssl
911     fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
912         let sizes = match self.initialize()? {
913             Some(sizes) => sizes,
914             None => return Err(io::Error::from_raw_os_error(winerror::SEC_E_CONTEXT_EXPIRED as i32)),
915         };
916 
917         // if we have pending output data, it must have been because a previous
918         // attempt to send this part of the data ran into an error.
919         if self.out_buf.position() == self.out_buf.get_ref().len() as u64 {
920             let len = cmp::min(buf.len(), sizes.cbMaximumMessage as usize);
921             self.encrypt(&buf[..len], &sizes)?;
922             self.last_write_len = len;
923         }
924         self.write_out()?;
925 
926         Ok(self.last_write_len)
927     }
928 
929     fn flush(&mut self) -> io::Result<()> {
930         // Make sure the write buffer is emptied
931         self.write_out()?;
932         self.stream.flush()
933     }
934 }
935 
func_normal_abspathnull936 impl<S> Read for TlsStream<S>
937     where S: Read + Write
938 {
939     fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
940         let nread = {
941             let read_buf = self.fill_buf()?;
942             let nread = cmp::min(buf.len(), read_buf.len());
943             buf[..nread].copy_from_slice(&read_buf[..nread]);
944             nread
945         };
946         self.consume(nread);
947         Ok(nread)
948     }
949 }
950 
951 impl<S> BufRead for TlsStream<S>
952     where S: Read + Write
953 {
954     fn fill_buf(&mut self) -> io::Result<&[u8]> {
955         while self.get_buf().is_empty() {
956             if let None = self.initialize()? {
957                 break;
958             }
959 
960             if self.needs_read > 0 {
961                 if self.read_in()? == 0 {
962                     break;
963                 }
964                 self.needs_read = 0;
965             }
966 
967             let eof = self.decrypt()?;
968             if eof {
969                 break;
970             }
971         }
972 
973         Ok(self.get_buf())
974     }
975 
976     fn consume(&mut self, amt: usize) {
977         let pos = self.dec_in.position() + amt as u64;
978         assert!(pos <= self.dec_in.get_ref().len() as u64);
979         self.dec_in.set_position(pos);
980     }
981 }
982