1 //! Schannel TLS streams.
2 use std::any::Any;
3 use std::cmp;
4 use std::error::Error;
5 use std::fmt;
6 use std::io::{self, Read, BufRead, Write, Cursor};
7 use std::mem;
8 use std::ptr;
9 use std::slice;
10 use std::sync::Arc;
11 use winapi::shared::minwindef as winapi;
12 use winapi::shared::{ntdef, sspi, winerror};
13 use winapi::um::{self, schannel, wincrypt};
14
15 use crate::{INIT_REQUESTS, ACCEPT_REQUESTS, Inner, secbuf, secbuf_desc};
16 use crate::alpn_list::AlpnList;
17 use crate::cert_chain::{CertChain, CertChainContext};
18 use crate::cert_store::{CertAdd, CertStore};
19 use crate::cert_context::CertContext;
20 use crate::security_context::SecurityContext;
21 use crate::context_buffer::ContextBuffer;
22 use crate::schannel_cred::SchannelCred;
23
24 lazy_static! {
25 static ref szOID_PKIX_KP_SERVER_AUTH: Vec<u8> =
26 wincrypt::szOID_PKIX_KP_SERVER_AUTH.bytes().chain(Some(0)).collect();
27 static ref szOID_SERVER_GATED_CRYPTO: Vec<u8> =
28 wincrypt::szOID_SERVER_GATED_CRYPTO.bytes().chain(Some(0)).collect();
29 static ref szOID_SGC_NETSCAPE: Vec<u8> =
30 wincrypt::szOID_SGC_NETSCAPE.bytes().chain(Some(0)).collect();
31 }
32
/// A builder type for `TlsStream`s.
///
/// The configuration stored here is copied into every stream created by
/// `Builder::connect` or `Builder::accept`.
pub struct Builder {
    // UTF-16 encoded, NUL-terminated domain name set via `Builder::domain`.
    domain: Option<Vec<u16>>,
    // Whether `domain` is handed to Schannel for Server Name Indication.
    use_sni: bool,
    // When `true`, the server name is not passed to the certificate policy check.
    accept_invalid_hostnames: bool,
    // Optional user hook whose result replaces the built-in certificate verdict.
    verify_callback: Option<Arc<dyn Fn(CertValidationResult) -> io::Result<()> + Sync + Send>>,
    // Additional certificates used (and implicitly trusted) when validating the server.
    cert_store: Option<CertStore>,
    // ALPN protocol identifiers offered during the handshake.
    requested_application_protocols: Option<Vec<Vec<u8>>>,
}
42
43 impl Default for Builder {
default() -> Builder44 fn default() -> Builder {
45 Builder {
46 domain: None,
47 use_sni: true,
48 accept_invalid_hostnames: false,
49 verify_callback: None,
50 cert_store: None,
51 requested_application_protocols: None,
52 }
53 }
54 }
55
56 impl Builder {
57 /// Returns a new `Builder`.
new() -> Builder58 pub fn new() -> Builder {
59 Builder::default()
60 }
61
62 /// Sets the domain associated with connections created with this `Builder`.
63 ///
64 /// The domain will be used for Server Name Indication as well as
65 /// certificate validation.
domain(&mut self, domain: &str) -> &mut Builder66 pub fn domain(&mut self, domain: &str) -> &mut Builder {
67 self.domain = Some(domain.encode_utf16().chain(Some(0)).collect());
68 self
69 }
70
71 /// Determines if Server Name Indication (SNI) will be used.
72 ///
73 /// Defaults to `true`.
use_sni(&mut self, use_sni: bool) -> &mut Builder74 pub fn use_sni(&mut self, use_sni: bool) -> &mut Builder {
75 self.use_sni = use_sni;
76 self
77 }
78
79 /// Determines if the server's hostname will be checked during certificate verification.
80 ///
81 /// Defaults to `false`.
accept_invalid_hostnames(&mut self, accept_invalid_hostnames: bool) -> &mut Builder82 pub fn accept_invalid_hostnames(&mut self, accept_invalid_hostnames: bool) -> &mut Builder {
83 self.accept_invalid_hostnames = accept_invalid_hostnames;
84 self
85 }
86
87 /// Set a verification callback to be used for connections created with this `Builder`.
88 ///
89 /// The callback is provided with an io::Result indicating if the (pre)validation was
90 /// successful. The Ok() variant indicates a successful validation while the Err() variant
91 /// contains the errorcode returned from the internal verification process.
92 /// The validated certificate, is accessible through the second argument of the closure.
verify_callback<F>(&mut self, callback: F) -> &mut Builder where F: Fn(CertValidationResult) -> io::Result<()> + 'static + Sync + Send93 pub fn verify_callback<F>(&mut self, callback: F) -> &mut Builder
94 where F: Fn(CertValidationResult) -> io::Result<()> + 'static + Sync + Send
95 {
96 self.verify_callback = Some(Arc::new(callback));
97 self
98 }
99
100 /// Specifies a custom certificate store which is later used when validating
101 /// a server's certificate.
102 ///
103 /// This option is only used for client connections and is used to construct
104 /// the certificate chain which the server's certificate is validated
105 /// against.
106 ///
107 /// Note that adding certificates here means that they are
108 /// implicitly trusted.
cert_store(&mut self, cert_store: CertStore) -> &mut Builder109 pub fn cert_store(&mut self, cert_store: CertStore) -> &mut Builder {
110 self.cert_store = Some(cert_store);
111 self
112 }
113
114 /// Requests one of a set of application protocols using alpn
request_application_protocols(&mut self, alpns: &[&[u8]]) -> &mut Builder115 pub fn request_application_protocols(&mut self, alpns: &[&[u8]]) -> &mut Builder {
116 self.requested_application_protocols =
117 Some(alpns.iter().map(|bytes| bytes.to_vec()).collect::<Vec<_>>());
118 self
119 }
120
121 /// Initialize a new TLS session where the stream provided will be
122 /// connecting to a remote TLS server.
123 ///
124 /// If the stream provided is a blocking stream then the entire handshake
125 /// will be performed if possible, but if the stream is in nonblocking mode
126 /// then a `HandshakeError::Interrupted` variant may be returned. This
127 /// type can then be extracted to later call
128 /// `MidHandshakeTlsStream::handshake` when data becomes available.
connect<S>(&mut self, cred: SchannelCred, stream: S) -> Result<TlsStream<S>, HandshakeError<S>> where S: Read + Write129 pub fn connect<S>(&mut self,
130 cred: SchannelCred,
131 stream: S)
132 -> Result<TlsStream<S>, HandshakeError<S>>
133 where S: Read + Write
134 {
135 self.initialize(cred, false, stream)
136 }
137
138 /// Initialize a new TLS session where the stream provided will be
139 /// accepting a connection.
140 ///
141 /// This method will tweak the protocol for "who talks first" and also
142 /// currently disables validation of the client that's connecting to us.
143 ///
144 /// If the stream provided is a blocking stream then the entire handshake
145 /// will be performed if possible, but if the stream is in nonblocking mode
146 /// then a `HandshakeError::Interrupted` variant may be returned. This
147 /// type can then be extracted to later call
148 /// `MidHandshakeTlsStream::handshake` when data becomes available.
accept<S>(&mut self, cred: SchannelCred, stream: S) -> Result<TlsStream<S>, HandshakeError<S>> where S: Read + Write149 pub fn accept<S>(&mut self,
150 cred: SchannelCred,
151 stream: S)
152 -> Result<TlsStream<S>, HandshakeError<S>>
153 where S: Read + Write
154 {
155 self.initialize(cred, true, stream)
156 }
157
initialize<S>(&mut self, mut cred: SchannelCred, server: bool, stream: S) -> Result<TlsStream<S>, HandshakeError<S>> where S: Read + Write158 fn initialize<S>(&mut self,
159 mut cred: SchannelCred,
160 server: bool,
161 stream: S)
162 -> Result<TlsStream<S>, HandshakeError<S>>
163 where S: Read + Write
164 {
165 let domain = match self.domain {
166 Some(ref domain) if self.use_sni => Some(&domain[..]),
167 _ => None,
168 };
169 let (ctxt, buf) = match SecurityContext::initialize(&mut cred,
170 server,
171 domain,
172 &self.requested_application_protocols) {
173 Ok(pair) => pair,
174 Err(e) => return Err(HandshakeError::Failure(e)),
175 };
176
177 let stream = TlsStream {
178 cred: cred,
179 context: ctxt,
180 cert_store: self.cert_store.clone(),
181 domain: self.domain.clone(),
182 use_sni: self.use_sni,
183 accept_invalid_hostnames: self.accept_invalid_hostnames,
184 verify_callback: self.verify_callback.clone(),
185 stream: stream,
186 server: server,
187 accept_first: true,
188 state: State::Initializing {
189 needs_flush: false,
190 more_calls: true,
191 shutting_down: false,
192 validated: false,
193 },
194 needs_read: 1,
195 dec_in: Cursor::new(Vec::new()),
196 enc_in: Cursor::new(Vec::new()),
197 out_buf: Cursor::new(buf.map(|b| b.to_owned()).unwrap_or(Vec::new())),
198 last_write_len: 0,
199 requested_application_protocols: self.requested_application_protocols.clone(),
200 };
201
202 MidHandshakeTlsStream {
203 inner: stream,
204 }.handshake()
205 }
206 }
207
/// Internal state machine of a `TlsStream`.
enum State {
    /// A handshake is in progress: the initial one, one restarted by a
    /// renegotiation request from `decrypt`, or the exchange started by `shutdown`.
    Initializing {
        /// Handshake output was written but the stream has not been flushed yet.
        needs_flush: bool,
        /// More SSPI handshake calls are required before the handshake completes.
        more_calls: bool,
        /// This handshake was started by `TlsStream::shutdown`.
        shutting_down: bool,
        /// The peer certificate has already been validated successfully.
        validated: bool,
    },
    /// Handshake complete; application data uses these record sizes.
    Streaming { sizes: sspi::SecPkgContext_StreamSizes, },
    /// The session has been shut down.
    Shutdown,
}
218
/// An Schannel TLS stream.
pub struct TlsStream<S> {
    // Credentials the security context was created with.
    cred: SchannelCred,
    context: SecurityContext,
    // Extra (trusted) certificates used when validating the server certificate.
    cert_store: Option<CertStore>,
    // UTF-16, NUL-terminated domain used for SNI and hostname checking.
    domain: Option<Vec<u16>>,
    use_sni: bool,
    accept_invalid_hostnames: bool,
    verify_callback: Option<Arc<dyn Fn(CertValidationResult) -> io::Result<()> + Sync + Send>>,
    // The wrapped transport stream.
    stream: S,
    state: State,
    // `true` for the accepting (server) side.
    server: bool,
    // Server only: `AcceptSecurityContext` has not yet returned CONTINUE_NEEDED,
    // so the next call must not pass the existing context handle.
    accept_first: bool,
    // Minimum number of additional encrypted bytes to read from `stream`
    // before the next SSPI call (0 = no read required).
    needs_read: usize,
    // Decrypted plaintext not yet handed to the reader; valid from position() to len()
    dec_in: Cursor<Vec<u8>>,
    // Encrypted bytes read from `stream` but not yet processed; valid from 0 to position()
    enc_in: Cursor<Vec<u8>>,
    // Bytes waiting to be written to `stream`; valid from position() to len()
    out_buf: Cursor<Vec<u8>>,
    /// the (unencrypted) length of the last write call used to track writes
    last_write_len: usize,
    requested_application_protocols: Option<Vec<Vec<u8>>>,
}
243
244 /// ensures that a TlsStream is always Sync/Send
_is_sync()245 fn _is_sync() {
246 fn sync<T: Sync + Send>() {}
247 sync::<TlsStream<()>>();
248 }
249
/// A failure which can happen while establishing a TLS session with
/// `Builder::connect`, `Builder::accept` or `MidHandshakeTlsStream::handshake`:
/// either an I/O error or an intermediate stream which has not completed its
/// handshake.
#[derive(Debug)]
pub enum HandshakeError<S> {
    /// A fatal I/O error occurred
    Failure(io::Error),
    /// The stream connection is in progress, but the handshake is not completed
    /// yet (the underlying stream reported `WouldBlock`). Call
    /// `MidHandshakeTlsStream::handshake` to resume.
    Interrupted(MidHandshakeTlsStream<S>),
}
260
/// A struct used to wrap various cert chain validation results for callback processing.
pub struct CertValidationResult {
    // Chain context built for the peer certificate.
    chain: CertChainContext,
    // `dwError` reported by `CertVerifyCertificateChainPolicy`.
    res: i32,
    // `lChainIndex` reported by the policy check.
    chain_index: i32,
    // `lElementIndex` reported by the policy check.
    element_index: i32,
}
268
269 impl CertValidationResult {
270 /// Returns the certificate that failed validation if applicable
failed_certificate(&self) -> Option<CertContext>271 pub fn failed_certificate(&self) -> Option<CertContext> {
272 if let Some(cert_chain) = self.chain.get_chain(self.chain_index as usize) {
273 return cert_chain.get(self.element_index as usize);
274 }
275 None
276 }
277
278 /// Returns the final certificate chain in the certificate context if applicable
chain(&self) -> Option<CertChain>279 pub fn chain(&self) -> Option<CertChain> {
280 self.chain.final_chain()
281 }
282
283 /// Returns the result of the built-in certificate verification process.
result(&self) -> io::Result<()>284 pub fn result(&self) -> io::Result<()> {
285 if self.res as u32 != winerror::ERROR_SUCCESS {
286 Err(io::Error::from_raw_os_error(self.res))
287 } else {
288 Ok(())
289 }
290 }
291 }
292
293 impl<S: fmt::Debug + Any> Error for HandshakeError<S> {
source(&self) -> Option<&(dyn Error + 'static)>294 fn source(&self) -> Option<&(dyn Error + 'static)> {
295 match *self {
296 HandshakeError::Failure(ref e) => Some(e),
297 HandshakeError::Interrupted(_) => None,
298 }
299 }
300 }
301
302 impl<S: fmt::Debug + Any> fmt::Display for HandshakeError<S> {
fmt(&self, f: &mut fmt::Formatter) -> fmt::Result303 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
304 let desc = match *self {
305 HandshakeError::Failure(_) => "failed to perform handshake",
306 HandshakeError::Interrupted(_) => "interrupted performing handshake",
307 };
308 write!(f, "{}", desc)?;
309 if let Some(e) = self.source() {
310 write!(f, ": {}", e)?;
311 }
312 Ok(())
313 }
314 }
315
/// A stream which has not yet completed its handshake.
///
/// Returned inside `HandshakeError::Interrupted`; call `handshake` again
/// once the underlying stream is ready.
#[derive(Debug)]
pub struct MidHandshakeTlsStream<S> {
    // The partially set up stream; returned from `handshake` once it completes.
    inner: TlsStream<S>,
}
321
322 impl<S> fmt::Debug for TlsStream<S>
323 where S: fmt::Debug
324 {
fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result325 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
326 fmt.debug_struct("TlsStream")
327 .field("stream", &self.stream)
328 .finish()
329 }
330 }
331
332 impl<S> TlsStream<S> {
333 /// Returns a reference to the wrapped stream.
get_ref(&self) -> &S334 pub fn get_ref(&self) -> &S {
335 &self.stream
336 }
337
338 /// Returns a mutable reference to the wrapped stream.
get_mut(&mut self) -> &mut S339 pub fn get_mut(&mut self) -> &mut S {
340 &mut self.stream
341 }
342
343 /// Indicates if this stream is the server- or client-side of a TLS session.
is_server(&self) -> bool344 pub fn is_server(&self) -> bool {
345 self.server
346 }
347 }
348
349 impl<S> TlsStream<S>
350 where S: Read + Write
351 {
352 /// Returns the certificate used to identify this side of the TLS session.
353 ///
354 /// Its associated cert store contains any intermediate certificates sent
355 /// along with the leaf.
certificate(&self) -> io::Result<CertContext>356 pub fn certificate(&self) -> io::Result<CertContext> {
357 self.context.local_cert()
358 }
359
360 /// Returns the peer's certificate, if available.
361 ///
362 /// Its associated cert store contains any intermediate certificates sent
363 /// by the server.
peer_certificate(&self) -> io::Result<CertContext>364 pub fn peer_certificate(&self) -> io::Result<CertContext> {
365 self.context.remote_cert()
366 }
367
368 /// Returns the negotiated application protocol for this tls stream, if one exists
negotiated_application_protocol(&self) -> io::Result<Option<Vec<u8>>>369 pub fn negotiated_application_protocol(&self) -> io::Result<Option<Vec<u8>>> {
370 let client_proto = self.context.application_protocol()?;
371 if client_proto.ProtoNegoStatus != sspi::SecApplicationProtocolNegotiationStatus_Success
372 || client_proto.ProtoNegoExt != sspi::SecApplicationProtocolNegotiationExt_ALPN
373 {
374 return Ok(None);
375 }
376 Ok(Some(client_proto.ProtocolId[..client_proto.ProtocolIdSize as usize].to_vec()))
377 }
378
379 /// Returns whether or not the session was resumed.
session_resumed(&self) -> io::Result<bool>380 pub fn session_resumed(&self) -> io::Result<bool> {
381 let session_info = self.context.session_info()?;
382 Ok(session_info.dwFlags & schannel::SSL_SESSION_RECONNECT > 0)
383 }
384
385 /// Returns a reference to the buffer of pending data.
386 ///
387 /// Like `BufRead::fill_buf` except that it will return an empty slice
388 /// rather than reading from the wrapped stream if there is no buffered
389 /// data.
get_buf(&self) -> &[u8]390 pub fn get_buf(&self) -> &[u8] {
391 &self.dec_in.get_ref()[self.dec_in.position() as usize..]
392 }
393
394 /// Shuts the TLS session down.
shutdown(&mut self) -> io::Result<()>395 pub fn shutdown(&mut self) -> io::Result<()> {
396 match self.state {
397 State::Shutdown => return Ok(()),
398 State::Initializing { shutting_down: true, .. } => {}
399 _ => {
400 unsafe {
401 let mut token = um::schannel::SCHANNEL_SHUTDOWN;
402 let ptr = &mut token as *mut _ as *mut u8;
403 let size = mem::size_of_val(&token);
404 let token = slice::from_raw_parts_mut(ptr, size);
405 let mut buf = [secbuf(sspi::SECBUFFER_TOKEN, Some(token))];
406 let mut desc = secbuf_desc(&mut buf);
407
408 match sspi::ApplyControlToken(self.context.get_mut(), &mut desc) {
409 winerror::SEC_E_OK => {}
410 err => return Err(io::Error::from_raw_os_error(err as i32)),
411 }
412 }
413
414 self.state = State::Initializing {
415 needs_flush: false,
416 more_calls: true,
417 shutting_down: true,
418 validated: false,
419 };
420 self.needs_read = 0;
421 }
422 }
423
424 self.initialize().map(|_| ())
425 }
426
step_initialize(&mut self) -> io::Result<()>427 fn step_initialize(&mut self) -> io::Result<()> {
428 unsafe {
429 let pos = self.enc_in.position() as usize;
430 let mut inbufs = vec![secbuf(sspi::SECBUFFER_TOKEN,
431 Some(&mut self.enc_in.get_mut()[..pos])),
432 secbuf(sspi::SECBUFFER_EMPTY, None)];
433 // Make sure `AlpnList` is kept alive for the duration of this function.
434 let mut alpns = self.requested_application_protocols.as_ref().map(|alpn| AlpnList::new(&alpn));
435 if let Some(ref mut alpns) = alpns {
436 inbufs.push(secbuf(sspi::SECBUFFER_APPLICATION_PROTOCOLS,
437 Some(&mut alpns[..])));
438 };
439 let mut inbuf_desc = secbuf_desc(&mut inbufs[..]);
440
441 let mut outbufs = [secbuf(sspi::SECBUFFER_TOKEN, None),
442 secbuf(sspi::SECBUFFER_ALERT, None),
443 secbuf(sspi::SECBUFFER_EMPTY, None)];
444 let mut outbuf_desc = secbuf_desc(&mut outbufs);
445
446 let mut attributes = 0;
447
448 let status = if self.server {
449 let ptr = if self.accept_first {
450 ptr::null_mut()
451 } else {
452 self.context.get_mut()
453 };
454 sspi::AcceptSecurityContext(&mut self.cred.as_inner(),
455 ptr,
456 &mut inbuf_desc,
457 ACCEPT_REQUESTS,
458 0,
459 self.context.get_mut(),
460 &mut outbuf_desc,
461 &mut attributes,
462 ptr::null_mut())
463 } else {
464 let domain = match self.domain {
465 Some(ref domain) if self.use_sni => domain.as_ptr() as *mut u16,
466 _ => ptr::null_mut(),
467 };
468
469 sspi::InitializeSecurityContextW(&mut self.cred.as_inner(),
470 self.context.get_mut(),
471 domain,
472 INIT_REQUESTS,
473 0,
474 0,
475 &mut inbuf_desc,
476 0,
477 ptr::null_mut(),
478 &mut outbuf_desc,
479 &mut attributes,
480 ptr::null_mut())
481 };
482
483 for buf in &outbufs[1..] {
484 if !buf.pvBuffer.is_null() {
485 sspi::FreeContextBuffer(buf.pvBuffer);
486 }
487 }
488
489 match status {
490 winerror::SEC_I_CONTINUE_NEEDED => {
491 // Windows apparently doesn't like AcceptSecurityContext
492 // being called as if it were the second time unless the
493 // first call to AcceptSecurityContext succeeded with
494 // CONTINUE_NEEDED.
495 //
496 // In other words, if we were to set `accept_first` to
497 // `false` after the literal first call to
498 // `AcceptSecurityContext` while the call returned
499 // INCOMPLETE_MESSAGE, the next call would return an error.
500 //
501 // For that reason we only set `accept_first` to false here
502 // once we've actually successfully received the full
503 // "token" from the client.
504 self.accept_first = false;
505 let nread = if inbufs[1].BufferType == sspi::SECBUFFER_EXTRA {
506 self.enc_in.position() as usize - inbufs[1].cbBuffer as usize
507 } else {
508 self.enc_in.position() as usize
509 };
510 let to_write = ContextBuffer(outbufs[0]);
511
512 self.consume_enc_in(nread);
513 self.needs_read = (self.enc_in.position() == 0) as usize;
514 self.out_buf.get_mut().extend_from_slice(&to_write);
515 }
516 winerror::SEC_E_INCOMPLETE_MESSAGE => {
517 self.needs_read = if inbufs[1].BufferType == sspi::SECBUFFER_MISSING {
518 inbufs[1].cbBuffer as usize
519 } else {
520 1
521 };
522 }
523 winerror::SEC_E_OK => {
524 let nread = if inbufs[1].BufferType == sspi::SECBUFFER_EXTRA {
525 self.enc_in.position() as usize - inbufs[1].cbBuffer as usize
526 } else {
527 self.enc_in.position() as usize
528 };
529 let to_write = if outbufs[0].pvBuffer.is_null() {
530 None
531 } else {
532 Some(ContextBuffer(outbufs[0]))
533 };
534
535 self.consume_enc_in(nread);
536 self.needs_read = (self.enc_in.position() == 0) as usize;
537 if let Some(to_write) = to_write {
538 self.out_buf.get_mut().extend_from_slice(&to_write);
539 }
540 if self.enc_in.position() != 0 {
541 self.decrypt()?;
542 }
543 if let State::Initializing { ref mut more_calls, .. } = self.state {
544 *more_calls = false;
545 }
546 }
547 _ => {
548 return Err(io::Error::from_raw_os_error(status as i32))
549 }
550 }
551 Ok(())
552 }
553 }
554
initialize(&mut self) -> io::Result<Option<sspi::SecPkgContext_StreamSizes>>555 fn initialize(&mut self) -> io::Result<Option<sspi::SecPkgContext_StreamSizes>> {
556 loop {
557 match self.state {
558 State::Initializing { mut needs_flush, more_calls, shutting_down, validated } => {
559 if self.write_out()? > 0 {
560 needs_flush = true;
561 if let State::Initializing { ref mut needs_flush, .. } = self.state {
562 *needs_flush = true;
563 }
564 }
565
566 if needs_flush {
567 self.stream.flush()?;
568 if let State::Initializing { ref mut needs_flush, .. } = self.state {
569 *needs_flush = false;
570 }
571 }
572
573 if !shutting_down && !validated {
574 // on the last call, we require a valid certificate
575 if self.validate(!more_calls)? {
576 if let State::Initializing { ref mut validated, .. } = self.state {
577 *validated = true;
578 }
579 }
580 }
581
582 if !more_calls {
583 self.state = if shutting_down {
584 State::Shutdown
585 } else {
586 State::Streaming { sizes: self.context.stream_sizes()? }
587 };
588 continue;
589 }
590
591 if self.needs_read > 0 {
592 if self.read_in()? == 0 {
593 return Err(io::Error::new(io::ErrorKind::UnexpectedEof,
594 "unexpected EOF during handshake"));
595 }
596 }
597
598 self.step_initialize()?;
599 }
600 State::Streaming { sizes } => return Ok(Some(sizes)),
601 State::Shutdown => return Ok(None),
602 }
603 }
604 }
605
606 /// Returns true when the certificate was succesfully verified
607 /// Returns false, when a verification isn't necessary (yet)
608 /// Returns an error when the verification failed
validate(&mut self, require_cert: bool) -> io::Result<bool>609 fn validate(&mut self, require_cert: bool) -> io::Result<bool> {
610 // If we're accepting connections then we don't perform any validation
611 // for the remote certificate, that's what they're doing!
612 if self.server {
613 return Ok(false);
614 }
615
616 let cert_context = match self.context.remote_cert() {
617 Err(_) if !require_cert => return Ok(false),
618 ret => ret?
619 };
620
621 let cert_chain = unsafe {
622 let cert_store = match (cert_context.cert_store(), &self.cert_store) {
623 (Some(ref mut chain_certs), &Some(ref extra_certs)) => {
624 for extra_cert in extra_certs.certs() {
625 chain_certs.add_cert(&extra_cert, CertAdd::ReplaceExisting)?;
626 }
627 chain_certs.as_inner()
628 },
629 (Some(chain_certs), &None) => chain_certs.as_inner(),
630 (None, &Some(ref extra_certs)) => extra_certs.as_inner(),
631 (None, &None) => ptr::null_mut()
632 };
633
634 let flags = wincrypt::CERT_CHAIN_CACHE_END_CERT |
635 wincrypt::CERT_CHAIN_REVOCATION_CHECK_CACHE_ONLY |
636 wincrypt::CERT_CHAIN_REVOCATION_CHECK_CHAIN_EXCLUDE_ROOT;
637
638 let mut para: wincrypt::CERT_CHAIN_PARA = mem::zeroed();
639 para.cbSize = mem::size_of_val(¶) as winapi::DWORD;
640 para.RequestedUsage.dwType = wincrypt::USAGE_MATCH_TYPE_OR;
641
642 let mut identifiers = [szOID_PKIX_KP_SERVER_AUTH.as_ptr() as ntdef::LPSTR,
643 szOID_SERVER_GATED_CRYPTO.as_ptr() as ntdef::LPSTR,
644 szOID_SGC_NETSCAPE.as_ptr() as ntdef::LPSTR];
645 para.RequestedUsage.Usage.cUsageIdentifier = identifiers.len() as winapi::DWORD;
646 para.RequestedUsage.Usage.rgpszUsageIdentifier = identifiers.as_mut_ptr();
647
648 let mut cert_chain = mem::zeroed();
649
650 let res = wincrypt::CertGetCertificateChain(ptr::null_mut(),
651 cert_context.as_inner(),
652 ptr::null_mut(),
653 cert_store,
654 &mut para,
655 flags,
656 ptr::null_mut(),
657 &mut cert_chain);
658
659 if res == winapi::TRUE {
660 CertChainContext(cert_chain as *mut _)
661 } else {
662 return Err(io::Error::last_os_error())
663 }
664 };
665
666 unsafe {
667 // check if we trust the root-CA explicitly
668 let mut para_flags = wincrypt::CERT_CHAIN_POLICY_IGNORE_ALL_REV_UNKNOWN_FLAGS;
669 if let Some(ref mut store) = self.cert_store {
670 if let Some(chain) = cert_chain.final_chain() {
671 // check if any cert of the chain is in the passed store (and therefore trusted)
672 if chain.certificates().any(|cert| store.certs().any(|root_cert| root_cert == cert)) {
673 para_flags |= wincrypt::CERT_CHAIN_POLICY_ALLOW_UNKNOWN_CA_FLAG;
674 }
675 }
676 }
677
678 let mut extra_para: wincrypt::SSL_EXTRA_CERT_CHAIN_POLICY_PARA = mem::zeroed();
679 *extra_para.u.cbSize_mut() = mem::size_of_val(&extra_para) as winapi::DWORD;
680 extra_para.dwAuthType = wincrypt::AUTHTYPE_SERVER;
681 match self.domain {
682 Some(ref mut domain) if !self.accept_invalid_hostnames => {
683 extra_para.pwszServerName = domain.as_mut_ptr();
684 }
685 _ => {}
686 }
687
688 let mut para: wincrypt::CERT_CHAIN_POLICY_PARA = mem::zeroed();
689 para.cbSize = mem::size_of_val(¶) as winapi::DWORD;
690 para.dwFlags = para_flags;
691 para.pvExtraPolicyPara = &mut extra_para as *mut _ as *mut _;
692
693 let mut status: wincrypt::CERT_CHAIN_POLICY_STATUS = mem::zeroed();
694 status.cbSize = mem::size_of_val(&status) as winapi::DWORD;
695
696 let verify_chain_policy_structure = wincrypt::CERT_CHAIN_POLICY_SSL as ntdef::LPCSTR;
697 let res = wincrypt::CertVerifyCertificateChainPolicy(verify_chain_policy_structure,
698 cert_chain.0,
699 &mut para,
700 &mut status);
701 if res == winapi::FALSE {
702 return Err(io::Error::last_os_error())
703 }
704
705 let mut verify_result = if status.dwError != winerror::ERROR_SUCCESS {
706 Err(io::Error::from_raw_os_error(status.dwError as i32))
707 } else {
708 Ok(())
709 };
710
711 // check if there's a user-specified verify callback
712 if let Some(ref callback) = self.verify_callback {
713 verify_result = callback(CertValidationResult{
714 chain: cert_chain,
715 res: status.dwError as i32,
716 chain_index: status.lChainIndex,
717 element_index: status.lElementIndex});
718 }
719 verify_result?;
720 }
721 Ok(true)
722 }
723
write_out(&mut self) -> io::Result<usize>724 fn write_out(&mut self) -> io::Result<usize> {
725 let mut out = 0;
726 while self.out_buf.position() as usize != self.out_buf.get_ref().len() {
727 let position = self.out_buf.position() as usize;
728 let nwritten = self.stream.write(&self.out_buf.get_ref()[position..])?;
729 out += nwritten;
730 self.out_buf.set_position((position + nwritten) as u64);
731 }
732
733 Ok(out)
734 }
735
read_in(&mut self) -> io::Result<usize>736 fn read_in(&mut self) -> io::Result<usize> {
737 let mut sum_nread = 0;
738
739 while self.needs_read > 0 {
740 let existing_len = self.enc_in.position() as usize;
741 let min_len = cmp::max(cmp::max(1024, 2 * existing_len), self.needs_read);
742 if self.enc_in.get_ref().len() < min_len {
743 self.enc_in.get_mut().resize(min_len, 0);
744 }
745 let nread = {
746 let buf = &mut self.enc_in.get_mut()[existing_len..];
747 self.stream.read(buf)?
748 };
749 self.enc_in.set_position((existing_len + nread) as u64);
750 self.needs_read = self.needs_read.saturating_sub(nread);
751 if nread == 0 {
752 break;
753 }
754 sum_nread += nread;
755 }
756
757 Ok(sum_nread)
758 }
759
consume_enc_in(&mut self, nread: usize)760 fn consume_enc_in(&mut self, nread: usize) {
761 let size = self.enc_in.position() as usize;
762 assert!(size >= nread);
763 let count = size - nread;
764
765 if count > 0 {
766 self.enc_in.get_mut().drain(..nread);
767 }
768
769 self.enc_in.set_position(count as u64);
770 }
771
decrypt(&mut self) -> io::Result<bool>772 fn decrypt(&mut self) -> io::Result<bool> {
773 unsafe {
774 let position = self.enc_in.position() as usize;
775 let mut bufs = [secbuf(sspi::SECBUFFER_DATA,
776 Some(&mut self.enc_in.get_mut()[..position])),
777 secbuf(sspi::SECBUFFER_EMPTY, None),
778 secbuf(sspi::SECBUFFER_EMPTY, None),
779 secbuf(sspi::SECBUFFER_EMPTY, None)];
780 let mut bufdesc = secbuf_desc(&mut bufs);
781
782 match sspi::DecryptMessage(self.context.get_mut(),
783 &mut bufdesc,
784 0,
785 ptr::null_mut()) {
786 winerror::SEC_E_OK => {
787 let start = bufs[1].pvBuffer as usize - self.enc_in.get_ref().as_ptr() as usize;
788 let end = start + bufs[1].cbBuffer as usize;
789 self.dec_in.get_mut().clear();
790 self.dec_in
791 .get_mut()
792 .extend_from_slice(&self.enc_in.get_ref()[start..end]);
793 self.dec_in.set_position(0);
794
795 let nread = if bufs[3].BufferType == sspi::SECBUFFER_EXTRA {
796 self.enc_in.position() as usize - bufs[3].cbBuffer as usize
797 } else {
798 self.enc_in.position() as usize
799 };
800 self.consume_enc_in(nread);
801 self.needs_read = (self.enc_in.position() == 0) as usize;
802 Ok(false)
803 }
804 winerror::SEC_E_INCOMPLETE_MESSAGE => {
805 self.needs_read = if bufs[1].BufferType == sspi::SECBUFFER_MISSING {
806 bufs[1].cbBuffer as usize
807 } else {
808 1
809 };
810 Ok(false)
811 }
812 winerror::SEC_I_CONTEXT_EXPIRED => Ok(true),
813 winerror::SEC_I_RENEGOTIATE => {
814 self.state = State::Initializing {
815 needs_flush: false,
816 more_calls: true,
817 shutting_down: false,
818 validated: false,
819 };
820
821 let nread = if bufs[3].BufferType == sspi::SECBUFFER_EXTRA {
822 self.enc_in.position() as usize - bufs[3].cbBuffer as usize
823 } else {
824 self.enc_in.position() as usize
825 };
826 self.consume_enc_in(nread);
827 self.needs_read = 0;
828 Ok(false)
829 }
830 e => Err(io::Error::from_raw_os_error(e as i32)),
831 }
832 }
833 }
834
encrypt(&mut self, buf: &[u8], sizes: &sspi::SecPkgContext_StreamSizes) -> io::Result<()>835 fn encrypt(&mut self, buf: &[u8], sizes: &sspi::SecPkgContext_StreamSizes) -> io::Result<()> {
836 assert!(buf.len() <= sizes.cbMaximumMessage as usize);
837
838 unsafe {
839 let len = sizes.cbHeader as usize + buf.len() + sizes.cbTrailer as usize;
840
841 if self.out_buf.get_ref().len() < len {
842 self.out_buf.get_mut().resize(len, 0);
843 }
844
845 let message_start = sizes.cbHeader as usize;
846 self.out_buf
847 .get_mut()[message_start..message_start + buf.len()]
848 .clone_from_slice(buf);
849
850 let mut bufs = {
851 let out_buf = self.out_buf.get_mut();
852 let size = sizes.cbHeader as usize;
853
854 let header = secbuf(sspi::SECBUFFER_STREAM_HEADER,
855 Some(&mut out_buf[..size]));
856 let data = secbuf(sspi::SECBUFFER_DATA,
857 Some(&mut out_buf[size..size + buf.len()]));
858 let trailer = secbuf(sspi::SECBUFFER_STREAM_TRAILER,
859 Some(&mut out_buf[size + buf.len()..]));
860 let empty = secbuf(sspi::SECBUFFER_EMPTY, None);
861 [header, data, trailer, empty]
862 };
863 let mut bufdesc = secbuf_desc(&mut bufs);
864
865 match sspi::EncryptMessage(self.context.get_mut(), 0, &mut bufdesc, 0) {
866 winerror::SEC_E_OK => {
867 let len = bufs[0].cbBuffer + bufs[1].cbBuffer + bufs[2].cbBuffer;
868 self.out_buf.get_mut().truncate(len as usize);
869 self.out_buf.set_position(0);
870 Ok(())
871 }
872 err => Err(io::Error::from_raw_os_error(err as i32)),
873 }
874 }
875 }
876 }
877
878 impl<S> MidHandshakeTlsStream<S> {
879 /// Returns a shared reference to the inner stream.
get_ref(&self) -> &S880 pub fn get_ref(&self) -> &S {
881 self.inner.get_ref()
882 }
883
884 /// Returns a mutable reference to the inner stream.
get_mut(&mut self) -> &mut S885 pub fn get_mut(&mut self) -> &mut S {
886 self.inner.get_mut()
887 }
888 }
889
890 impl<S> MidHandshakeTlsStream<S>
891 where S: Read + Write,
892 {
893 /// Restarts the handshake process.
handshake(mut self) -> Result<TlsStream<S>, HandshakeError<S>>894 pub fn handshake(mut self) -> Result<TlsStream<S>, HandshakeError<S>> {
895 match self.inner.initialize() {
896 Ok(_) => Ok(self.inner),
897 Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
898 Err(HandshakeError::Interrupted(self))
899 }
900 Err(e) => Err(HandshakeError::Failure(e)),
901 }
902 }
903 }
904
905 impl<S> Write for TlsStream<S>
906 where S: Read + Write
907 {
908 /// In the case of a WouldBlock error, we expect another call
909 /// starting with the same input data
910 /// This is similar to the use of ACCEPT_MOVING_WRITE_BUFFER in openssl
write(&mut self, buf: &[u8]) -> io::Result<usize>911 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
912 let sizes = match self.initialize()? {
913 Some(sizes) => sizes,
914 None => return Err(io::Error::from_raw_os_error(winerror::SEC_E_CONTEXT_EXPIRED as i32)),
915 };
916
917 // if we have pending output data, it must have been because a previous
918 // attempt to send this part of the data ran into an error.
919 if self.out_buf.position() == self.out_buf.get_ref().len() as u64 {
920 let len = cmp::min(buf.len(), sizes.cbMaximumMessage as usize);
921 self.encrypt(&buf[..len], &sizes)?;
922 self.last_write_len = len;
923 }
924 self.write_out()?;
925
926 Ok(self.last_write_len)
927 }
928
flush(&mut self) -> io::Result<()>929 fn flush(&mut self) -> io::Result<()> {
930 // Make sure the write buffer is emptied
931 self.write_out()?;
932 self.stream.flush()
933 }
934 }
935
936 impl<S> Read for TlsStream<S>
937 where S: Read + Write
938 {
read(&mut self, buf: &mut [u8]) -> io::Result<usize>939 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
940 let nread = {
941 let read_buf = self.fill_buf()?;
942 let nread = cmp::min(buf.len(), read_buf.len());
943 buf[..nread].copy_from_slice(&read_buf[..nread]);
944 nread
945 };
946 self.consume(nread);
947 Ok(nread)
948 }
949 }
950
951 impl<S> BufRead for TlsStream<S>
952 where S: Read + Write
953 {
fill_buf(&mut self) -> io::Result<&[u8]>954 fn fill_buf(&mut self) -> io::Result<&[u8]> {
955 while self.get_buf().is_empty() {
956 if let None = self.initialize()? {
957 break;
958 }
959
960 if self.needs_read > 0 {
961 if self.read_in()? == 0 {
962 break;
963 }
964 self.needs_read = 0;
965 }
966
967 let eof = self.decrypt()?;
968 if eof {
969 break;
970 }
971 }
972
973 Ok(self.get_buf())
974 }
975
consume(&mut self, amt: usize)976 fn consume(&mut self, amt: usize) {
977 let pos = self.dec_in.position() + amt as u64;
978 assert!(pos <= self.dec_in.get_ref().len() as u64);
979 self.dec_in.set_position(pos);
980 }
981 }
982