1 //! SSL/TLS encryption support using Secure Transport.
2 //!
3 //! # Examples
4 //!
5 //! To connect as a client to a server with a certificate trusted by the system:
6 //!
7 //! ```rust
8 //! use std::io::prelude::*;
9 //! use std::net::TcpStream;
10 //! use security_framework::secure_transport::ClientBuilder;
11 //!
12 //! let stream = TcpStream::connect("google.com:443").unwrap();
13 //! let mut stream = ClientBuilder::new().handshake("google.com", stream).unwrap();
14 //!
15 //! stream.write_all(b"GET / HTTP/1.0\r\n\r\n").unwrap();
16 //! let mut page = vec![];
17 //! stream.read_to_end(&mut page).unwrap();
18 //! println!("{}", String::from_utf8_lossy(&page));
19 //! ```
20 //!
21 //! To connect to a server with a certificate that's *not* trusted by the
22 //! system, specify the root certificates for the server's chain to the
23 //! `ClientBuilder`:
24 //!
25 //! ```rust,no_run
26 //! use std::io::prelude::*;
27 //! use std::net::TcpStream;
28 //! use security_framework::secure_transport::ClientBuilder;
29 //!
30 //! # let root_cert = unsafe { std::mem::zeroed() };
31 //! let stream = TcpStream::connect("my_server.com:443").unwrap();
32 //! let mut stream = ClientBuilder::new()
33 //!                      .anchor_certificates(&[root_cert])
34 //!                      .handshake("my_server.com", stream)
35 //!                      .unwrap();
36 //!
37 //! stream.write_all(b"GET / HTTP/1.0\r\n\r\n").unwrap();
38 //! let mut page = vec![];
39 //! stream.read_to_end(&mut page).unwrap();
40 //! println!("{}", String::from_utf8_lossy(&page));
41 //! ```
42 //!
43 //! For more advanced configuration, the `SslContext` type can be used directly.
44 //!
45 //! To run a server:
46 //!
47 //! ```rust,no_run
48 //! use std::net::TcpListener;
49 //! use std::thread;
50 //! use security_framework::secure_transport::{SslContext, SslProtocolSide, SslConnectionType};
51 //!
52 //! // Create a TCP listener and start accepting on it.
//! let listener = TcpListener::bind("0.0.0.0:443").unwrap();
54 //!
55 //! for stream in listener.incoming() {
56 //!     let stream = stream.unwrap();
57 //!     thread::spawn(move || {
58 //!         // Create a new context configured to operate on the server side of
59 //!         // a traditional SSL/TLS session.
60 //!         let mut ctx = SslContext::new(SslProtocolSide::SERVER, SslConnectionType::STREAM)
61 //!                           .unwrap();
62 //!
63 //!         // Install the certificate chain that we will be using.
64 //!         # let identity = unsafe { std::mem::zeroed() };
65 //!         # let intermediate_cert = unsafe { std::mem::zeroed() };
66 //!         # let root_cert = unsafe { std::mem::zeroed() };
67 //!         ctx.set_certificate(identity, &[intermediate_cert, root_cert]).unwrap();
68 //!
69 //!         // Perform the SSL/TLS handshake and get our stream.
70 //!         let mut stream = ctx.handshake(stream).unwrap();
71 //!     });
72 //! }
73 //!
74 //! ```
75 #[allow(unused_imports)]
76 use core_foundation::array::{CFArray, CFArrayRef};
77 
78 use core_foundation::base::{Boolean, TCFType};
79 #[cfg(feature = "alpn")]
80 use core_foundation::string::CFString;
81 use core_foundation_sys::base::{kCFAllocatorDefault, OSStatus};
82 use std::os::raw::c_void;
83 
84 #[allow(unused_imports)]
85 use security_framework_sys::base::{
86     errSecBadReq, errSecIO, errSecNotTrusted, errSecSuccess, errSecTrustSettingDeny,
87     errSecUnimplemented,
88 };
89 
90 use security_framework_sys::secure_transport::*;
91 use std::any::Any;
92 use std::cmp;
93 use std::fmt;
94 use std::io;
95 use std::io::prelude::*;
96 use std::marker::PhantomData;
97 use std::mem;
98 use std::panic::{self, AssertUnwindSafe};
99 use std::ptr;
100 use std::result;
101 use std::slice;
102 
103 use crate::base::{Error, Result};
104 use crate::certificate::SecCertificate;
105 use crate::cipher_suite::CipherSuite;
106 use crate::identity::SecIdentity;
107 use crate::policy::SecPolicy;
108 use crate::trust::{SecTrust, TrustResult};
109 use crate::{cvt, AsInner};
110 
111 /// Specifies a side of a TLS session.
112 #[derive(Debug, Copy, Clone, PartialEq, Eq)]
113 pub struct SslProtocolSide(SSLProtocolSide);
114 
115 impl SslProtocolSide {
116     /// The server side of the session.
117     pub const SERVER: SslProtocolSide = SslProtocolSide(kSSLServerSide);
118 
119     /// The client side of the session.
120     pub const CLIENT: SslProtocolSide = SslProtocolSide(kSSLClientSide);
121 }
122 
123 /// Specifies the type of TLS session.
124 #[derive(Debug, Copy, Clone)]
125 pub struct SslConnectionType(SSLConnectionType);
126 
127 impl SslConnectionType {
128     /// A traditional TLS stream.
129     pub const STREAM: SslConnectionType = SslConnectionType(kSSLStreamType);
130 
131     /// A DTLS session.
132     pub const DATAGRAM: SslConnectionType = SslConnectionType(kSSLDatagramType);
133 }
134 
135 /// An error or intermediate state after a TLS handshake attempt.
136 #[derive(Debug)]
137 pub enum HandshakeError<S> {
138     /// The handshake failed.
139     Failure(Error),
140     /// The handshake was interrupted midway through.
141     Interrupted(MidHandshakeSslStream<S>),
142 }
143 
144 impl<S> From<Error> for HandshakeError<S> {
145     fn from(err: Error) -> HandshakeError<S> {
146         HandshakeError::Failure(err)
147     }
148 }
149 
150 /// An error or intermediate state after a TLS handshake attempt.
151 #[derive(Debug)]
152 pub enum ClientHandshakeError<S> {
153     /// The handshake failed.
154     Failure(Error),
155     /// The handshake was interrupted midway through.
156     Interrupted(MidHandshakeClientBuilder<S>),
157 }
158 
159 impl<S> From<Error> for ClientHandshakeError<S> {
160     fn from(err: Error) -> ClientHandshakeError<S> {
161         ClientHandshakeError::Failure(err)
162     }
163 }
164 
165 /// An SSL stream midway through the handshake process.
166 #[derive(Debug)]
167 pub struct MidHandshakeSslStream<S> {
168     stream: SslStream<S>,
169     error: Error,
170 }
171 
172 impl<S> MidHandshakeSslStream<S> {
173     /// Returns a shared reference to the inner stream.
174     pub fn get_ref(&self) -> &S {
175         self.stream.get_ref()
176     }
177 
178     /// Returns a mutable reference to the inner stream.
179     pub fn get_mut(&mut self) -> &mut S {
180         self.stream.get_mut()
181     }
182 
183     /// Returns a shared reference to the `SslContext` of the stream.
184     pub fn context(&self) -> &SslContext {
185         self.stream.context()
186     }
187 
188     /// Returns a mutable reference to the `SslContext` of the stream.
189     pub fn context_mut(&mut self) -> &mut SslContext {
190         self.stream.context_mut()
191     }
192 
193     /// Returns `true` iff `break_on_server_auth` was set and the handshake has
194     /// progressed to that point.
195     pub fn server_auth_completed(&self) -> bool {
196         self.error.code() == errSSLPeerAuthCompleted
197     }
198 
199     /// Returns `true` iff `break_on_cert_requested` was set and the handshake
200     /// has progressed to that point.
201     pub fn client_cert_requested(&self) -> bool {
202         self.error.code() == errSSLClientCertRequested
203     }
204 
205     /// Returns `true` iff the underlying stream returned an error with the
206     /// `WouldBlock` kind.
207     pub fn would_block(&self) -> bool {
208         self.error.code() == errSSLWouldBlock
209     }
210 
211     /// Deprecated
212     pub fn reason(&self) -> OSStatus {
213         self.error.code()
214     }
215 
216     /// Returns the error which caused the handshake interruption.
217     pub fn error(&self) -> &Error {
218         &self.error
219     }
220 
221     /// Restarts the handshake process.
222     pub fn handshake(self) -> result::Result<SslStream<S>, HandshakeError<S>> {
223         self.stream.handshake()
224     }
225 }
226 
227 /// An SSL stream midway through the handshake process.
228 #[derive(Debug)]
229 pub struct MidHandshakeClientBuilder<S> {
230     stream: MidHandshakeSslStream<S>,
231     domain: Option<String>,
232     certs: Vec<SecCertificate>,
233     trust_certs_only: bool,
234     danger_accept_invalid_certs: bool,
235 }
236 
237 impl<S> MidHandshakeClientBuilder<S> {
238     /// Returns a shared reference to the inner stream.
239     pub fn get_ref(&self) -> &S {
240         self.stream.get_ref()
241     }
242 
243     /// Returns a mutable reference to the inner stream.
244     pub fn get_mut(&mut self) -> &mut S {
245         self.stream.get_mut()
246     }
247 
248     /// Returns the error which caused the handshake interruption.
249     pub fn error(&self) -> &Error {
250         self.stream.error()
251     }
252 
253     /// Restarts the handshake process.
254     pub fn handshake(self) -> result::Result<SslStream<S>, ClientHandshakeError<S>> {
255         let MidHandshakeClientBuilder {
256             stream,
257             domain,
258             certs,
259             trust_certs_only,
260             danger_accept_invalid_certs,
261         } = self;
262 
263         let mut result = stream.handshake();
264         loop {
265             let stream = match result {
266                 Ok(stream) => return Ok(stream),
267                 Err(HandshakeError::Interrupted(stream)) => stream,
268                 Err(HandshakeError::Failure(err)) => {
269                     return Err(ClientHandshakeError::Failure(err))
270                 }
271             };
272 
273             if stream.would_block() {
274                 let ret = MidHandshakeClientBuilder {
275                     stream,
276                     domain,
277                     certs,
278                     trust_certs_only,
279                     danger_accept_invalid_certs,
280                 };
281                 return Err(ClientHandshakeError::Interrupted(ret));
282             }
283 
284             if stream.server_auth_completed() {
285                 if danger_accept_invalid_certs {
286                     result = stream.handshake();
287                     continue;
288                 }
289                 let mut trust = match stream.context().peer_trust2()? {
290                     Some(trust) => trust,
291                     None => {
292                         result = stream.handshake();
293                         continue;
294                     }
295                 };
296                 trust.set_anchor_certificates(&certs)?;
297                 trust.set_trust_anchor_certificates_only(self.trust_certs_only)?;
298                 let policy =
299                     SecPolicy::create_ssl(SslProtocolSide::SERVER, domain.as_ref().map(|s| &**s));
300                 trust.set_policy(&policy)?;
301                 let trusted = trust.evaluate()?;
302                 match trusted {
303                     TrustResult::PROCEED | TrustResult::UNSPECIFIED => {
304                         result = stream.handshake();
305                         continue;
306                     }
307                     TrustResult::DENY => {
308                         let err = Error::from_code(errSecTrustSettingDeny);
309                         return Err(ClientHandshakeError::Failure(err));
310                     }
311                     TrustResult::RECOVERABLE_TRUST_FAILURE | TrustResult::FATAL_TRUST_FAILURE => {
312                         let err = Error::from_code(errSecNotTrusted);
313                         return Err(ClientHandshakeError::Failure(err));
314                     }
315                     TrustResult::INVALID | TrustResult::OTHER_ERROR | _ => {
316                         let err = Error::from_code(errSecBadReq);
317                         return Err(ClientHandshakeError::Failure(err));
318                     }
319                 }
320             }
321 
322             let err = Error::from_code(stream.reason());
323             return Err(ClientHandshakeError::Failure(err));
324         }
325     }
326 }
327 
328 /// Specifies the state of a TLS session.
329 #[derive(Debug, PartialEq, Eq)]
330 pub struct SessionState(SSLSessionState);
331 
332 impl SessionState {
333     /// The session has not yet started.
334     pub const IDLE: SessionState = SessionState(kSSLIdle);
335 
336     /// The session is in the handshake process.
337     pub const HANDSHAKE: SessionState = SessionState(kSSLHandshake);
338 
339     /// The session is connected.
340     pub const CONNECTED: SessionState = SessionState(kSSLConnected);
341 
342     /// The session has been terminated.
343     pub const CLOSED: SessionState = SessionState(kSSLClosed);
344 
345     /// The session has been aborted due to an error.
346     pub const ABORTED: SessionState = SessionState(kSSLAborted);
347 }
348 
349 /// Specifies a server's requirement for client certificates.
350 #[derive(Debug, Copy, Clone, PartialEq, Eq)]
351 pub struct SslAuthenticate(SSLAuthenticate);
352 
353 impl SslAuthenticate {
354     /// Do not request a client certificate.
355     pub const NEVER: SslAuthenticate = SslAuthenticate(kNeverAuthenticate);
356 
357     /// Require a client certificate.
358     pub const ALWAYS: SslAuthenticate = SslAuthenticate(kAlwaysAuthenticate);
359 
360     /// Request but do not require a client certificate.
361     pub const TRY: SslAuthenticate = SslAuthenticate(kTryAuthenticate);
362 }
363 
364 /// Specifies the state of client certificate processing.
365 #[derive(Debug, Copy, Clone, PartialEq, Eq)]
366 pub struct SslClientCertificateState(SSLClientCertificateState);
367 
368 impl SslClientCertificateState {
369     /// A client certificate has not been requested or sent.
370     pub const NONE: SslClientCertificateState = SslClientCertificateState(kSSLClientCertNone);
371 
372     /// A client certificate has been requested but not recieved.
373     pub const REQUESTED: SslClientCertificateState =
374         SslClientCertificateState(kSSLClientCertRequested);
375 
376     /// A client certificate has been received and successfully validated.
377     pub const SENT: SslClientCertificateState = SslClientCertificateState(kSSLClientCertSent);
378 
379     /// A client certificate has been received but has failed to validate.
380     pub const REJECTED: SslClientCertificateState =
381         SslClientCertificateState(kSSLClientCertRejected);
382 }
383 
384 /// Specifies protocol versions.
385 #[derive(Debug, Copy, Clone, PartialEq, Eq)]
386 pub struct SslProtocol(SSLProtocol);
387 
388 impl SslProtocol {
389     /// No protocol has been or should be negotiated or specified; use the default.
390     pub const UNKNOWN: SslProtocol = SslProtocol(kSSLProtocolUnknown);
391 
392     /// The SSL 3.0 protocol is preferred, though SSL 2.0 may be used if the peer does not support
393     /// SSL 3.0.
394     pub const SSL3: SslProtocol = SslProtocol(kSSLProtocol3);
395 
396     /// The TLS 1.0 protocol is preferred, though lower versions may be used
397     /// if the peer does not support TLS 1.0.
398     pub const TLS1: SslProtocol = SslProtocol(kTLSProtocol1);
399 
400     /// The TLS 1.1 protocol is preferred, though lower versions may be used
401     /// if the peer does not support TLS 1.1.
402     pub const TLS11: SslProtocol = SslProtocol(kTLSProtocol11);
403 
404     /// The TLS 1.2 protocol is preferred, though lower versions may be used
405     /// if the peer does not support TLS 1.2.
406     pub const TLS12: SslProtocol = SslProtocol(kTLSProtocol12);
407 
408     /// Only the SSL 2.0 protocol is accepted.
409     pub const SSL2: SslProtocol = SslProtocol(kSSLProtocol2);
410 
411     /// The DTLSv1 protocol is preferred.
412     pub const DTLS1: SslProtocol = SslProtocol(kDTLSProtocol1);
413 
414     /// Only the SSL 3.0 protocol is accepted.
415     pub const SSL3_ONLY: SslProtocol = SslProtocol(kSSLProtocol3Only);
416 
417     /// Only the TLS 1.0 protocol is accepted.
418     pub const TLS1_ONLY: SslProtocol = SslProtocol(kTLSProtocol1Only);
419 
420     /// All supported TLS/SSL versions are accepted.
421     pub const ALL: SslProtocol = SslProtocol(kSSLProtocolAll);
422 }
423 
// `SslContext` is a wrapper around the raw `SSLContextRef` handle; code in
// this module reaches the handle as `self.0` and constructs the wrapper as
// `SslContext(ctx)`. `impl_TCFType!` hooks the wrapper into core-foundation's
// `TCFType` trait, using `SSLContextGetTypeID` as its type ID.
declare_TCFType! {
    /// A Secure Transport SSL/TLS context object.
    ///
    /// Create one with `SslContext::new`, configure it, then call
    /// `SslContext::handshake` to wrap an I/O stream in an `SslStream`.
    SslContext, SSLContextRef
}

impl_TCFType!(SslContext, SSLContextRef, SSLContextGetTypeID);
430 
431 impl fmt::Debug for SslContext {
432     fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
433         let mut builder = fmt.debug_struct("SslContext");
434         if let Ok(state) = self.state() {
435             builder.field("state", &state);
436         }
437         builder.finish()
438     }
439 }
440 
// SAFETY: NOTE(review): these impls assert that an `SSLContextRef` may be
// moved to and shared between threads. Nothing in this file establishes that
// guarantee; it relies on Secure Transport itself (confirm against Apple's
// documentation). Within this module, every method that mutates the context
// takes `&mut self`, so safe code cannot mutate one context concurrently
// through this API.
unsafe impl Sync for SslContext {}
unsafe impl Send for SslContext {}
443 
444 impl AsInner for SslContext {
445     type Inner = SSLContextRef;
446 
447     fn as_inner(&self) -> SSLContextRef {
448         self.0
449     }
450 }
451 
// Generates a setter/getter method pair for each listed Secure Transport
// session option.
//
// Each entry has the form `const OPTION: getter & setter,` and expands to:
// * `setter(&mut self, value: bool) -> Result<()>`, which calls
//   `SSLSetSessionOption`;
// * `getter(&self) -> Result<bool>`, which calls `SSLGetSessionOption` and
//   reports any nonzero value as `true`.
//
// Any attributes before an entry (doc comments, `#[cfg(..)]`) are copied onto
// both generated methods.
macro_rules! impl_options {
    ($($(#[$a:meta])* const $opt:ident: $get:ident & $set:ident,)*) => {
        $(
            $(#[$a])*
            pub fn $set(&mut self, value: bool) -> Result<()> {
                // `bool as Boolean` yields 1 for true and 0 for false.
                unsafe { cvt(SSLSetSessionOption(self.0, $opt, value as Boolean)) }
            }

            $(#[$a])*
            pub fn $get(&self) -> Result<bool> {
                let mut value = 0;
                unsafe { cvt(SSLGetSessionOption(self.0, $opt, &mut value))?; }
                Ok(value != 0)
            }
        )*
    }
}
469 
470 impl SslContext {
471     /// Creates a new `SslContext` for the specified side and type of SSL
472     /// connection.
473     pub fn new(side: SslProtocolSide, type_: SslConnectionType) -> Result<SslContext> {
474         unsafe {
475             let ctx = SSLCreateContext(kCFAllocatorDefault, side.0, type_.0);
476             Ok(SslContext(ctx))
477         }
478     }
479 
480     /// Sets the fully qualified domain name of the peer.
481     ///
482     /// This will be used on the client side of a session to validate the
483     /// common name field of the server's certificate. It has no effect if
484     /// called on a server-side `SslContext`.
485     ///
486     /// It is *highly* recommended to call this method before starting the
487     /// handshake process.
488     pub fn set_peer_domain_name(&mut self, peer_name: &str) -> Result<()> {
489         unsafe {
490             // SSLSetPeerDomainName doesn't need a null terminated string
491             cvt(SSLSetPeerDomainName(
492                 self.0,
493                 peer_name.as_ptr() as *const _,
494                 peer_name.len(),
495             ))
496         }
497     }
498 
499     /// Returns the peer domain name set by `set_peer_domain_name`.
500     pub fn peer_domain_name(&self) -> Result<String> {
501         unsafe {
502             let mut len = 0;
503             cvt(SSLGetPeerDomainNameLength(self.0, &mut len))?;
504             let mut buf = vec![0; len];
505             cvt(SSLGetPeerDomainName(
506                 self.0,
507                 buf.as_mut_ptr() as *mut _,
508                 &mut len,
509             ))?;
510             Ok(String::from_utf8(buf).unwrap())
511         }
512     }
513 
514     /// Sets the certificate to be used by this side of the SSL session.
515     ///
516     /// This must be called before the handshake for server-side connections,
517     /// and can be used on the client-side to specify a client certificate.
518     ///
519     /// The `identity` corresponds to the leaf certificate and private
520     /// key, and the `certs` correspond to extra certificates in the chain.
521     pub fn set_certificate(
522         &mut self,
523         identity: &SecIdentity,
524         certs: &[SecCertificate],
525     ) -> Result<()> {
526         let mut arr = vec![identity.as_CFType()];
527         arr.extend(certs.iter().map(|c| c.as_CFType()));
528         let certs = CFArray::from_CFTypes(&arr);
529 
530         unsafe { cvt(SSLSetCertificate(self.0, certs.as_concrete_TypeRef())) }
531     }
532 
533     /// Sets the peer ID of this session.
534     ///
535     /// A peer ID is an opaque sequence of bytes that will be used by Secure
536     /// Transport to identify the peer of an SSL session. If the peer ID of
537     /// this session matches that of a previously terminated session, the
538     /// previous session can be resumed without requiring a full handshake.
539     pub fn set_peer_id(&mut self, peer_id: &[u8]) -> Result<()> {
540         unsafe {
541             cvt(SSLSetPeerID(
542                 self.0,
543                 peer_id.as_ptr() as *const _,
544                 peer_id.len(),
545             ))
546         }
547     }
548 
549     /// Returns the peer ID of this session.
550     pub fn peer_id(&self) -> Result<Option<&[u8]>> {
551         unsafe {
552             let mut ptr = ptr::null();
553             let mut len = 0;
554             cvt(SSLGetPeerID(self.0, &mut ptr, &mut len))?;
555             if ptr.is_null() {
556                 Ok(None)
557             } else {
558                 Ok(Some(slice::from_raw_parts(ptr as *const _, len)))
559             }
560         }
561     }
562 
563     /// Returns the list of ciphers that are supported by Secure Transport.
564     pub fn supported_ciphers(&self) -> Result<Vec<CipherSuite>> {
565         unsafe {
566             let mut num_ciphers = 0;
567             cvt(SSLGetNumberSupportedCiphers(self.0, &mut num_ciphers))?;
568             let mut ciphers = vec![0; num_ciphers];
569             cvt(SSLGetSupportedCiphers(
570                 self.0,
571                 ciphers.as_mut_ptr(),
572                 &mut num_ciphers,
573             ))?;
574             Ok(ciphers.iter().map(|c| CipherSuite::from_raw(*c)).collect())
575         }
576     }
577 
578     /// Returns the list of ciphers that are eligible to be used for
579     /// negotiation.
580     pub fn enabled_ciphers(&self) -> Result<Vec<CipherSuite>> {
581         unsafe {
582             let mut num_ciphers = 0;
583             cvt(SSLGetNumberEnabledCiphers(self.0, &mut num_ciphers))?;
584             let mut ciphers = vec![0; num_ciphers];
585             cvt(SSLGetEnabledCiphers(
586                 self.0,
587                 ciphers.as_mut_ptr(),
588                 &mut num_ciphers,
589             ))?;
590             Ok(ciphers.iter().map(|c| CipherSuite::from_raw(*c)).collect())
591         }
592     }
593 
594     /// Sets the list of ciphers that are eligible to be used for negotiation.
595     pub fn set_enabled_ciphers(&mut self, ciphers: &[CipherSuite]) -> Result<()> {
596         let ciphers = ciphers.iter().map(|c| c.to_raw()).collect::<Vec<_>>();
597         unsafe {
598             cvt(SSLSetEnabledCiphers(
599                 self.0,
600                 ciphers.as_ptr(),
601                 ciphers.len(),
602             ))
603         }
604     }
605 
606     /// Returns the cipher being used by the session.
607     pub fn negotiated_cipher(&self) -> Result<CipherSuite> {
608         unsafe {
609             let mut cipher = 0;
610             cvt(SSLGetNegotiatedCipher(self.0, &mut cipher))?;
611             Ok(CipherSuite::from_raw(cipher))
612         }
613     }
614 
615     /// Sets the requirements for client certificates.
616     ///
617     /// Should only be called on server-side sessions.
618     pub fn set_client_side_authenticate(&mut self, auth: SslAuthenticate) -> Result<()> {
619         unsafe { cvt(SSLSetClientSideAuthenticate(self.0, auth.0)) }
620     }
621 
622     /// Returns the state of client certificate processing.
623     pub fn client_certificate_state(&self) -> Result<SslClientCertificateState> {
624         let mut state = 0;
625 
626         unsafe {
627             cvt(SSLGetClientCertificateState(self.0, &mut state))?;
628         }
629         Ok(SslClientCertificateState(state))
630     }
631 
632     /// Returns the `SecTrust` object corresponding to the peer.
633     ///
634     /// This can be used in conjunction with `set_break_on_server_auth` to
635     /// validate certificates which do not have roots in the default set.
636     pub fn peer_trust2(&self) -> Result<Option<SecTrust>> {
637         // Calling SSLCopyPeerTrust on an idle connection does not seem to be well defined,
638         // so explicitly check for that
639         if self.state()? == SessionState::IDLE {
640             return Err(Error::from_code(errSecBadReq));
641         }
642 
643         unsafe {
644             let mut trust = ptr::null_mut();
645             cvt(SSLCopyPeerTrust(self.0, &mut trust))?;
646             if trust.is_null() {
647                 Ok(None)
648             } else {
649                 Ok(Some(SecTrust::wrap_under_create_rule(trust)))
650             }
651         }
652     }
653 
654     #[allow(missing_docs)]
655     #[deprecated(since = "0.2.1", note = "use peer_trust2 instead")]
656     pub fn peer_trust(&self) -> Result<SecTrust> {
657         match self.peer_trust2() {
658             Ok(Some(trust)) => Ok(trust),
659             Ok(None) => panic!("no trust available"),
660             Err(e) => Err(e),
661         }
662     }
663 
664     /// Returns the state of the session.
665     pub fn state(&self) -> Result<SessionState> {
666         unsafe {
667             let mut state = 0;
668             cvt(SSLGetSessionState(self.0, &mut state))?;
669             Ok(SessionState(state))
670         }
671     }
672 
673     /// Returns the protocol version being used by the session.
674     pub fn negotiated_protocol_version(&self) -> Result<SslProtocol> {
675         unsafe {
676             let mut version = 0;
677             cvt(SSLGetNegotiatedProtocolVersion(self.0, &mut version))?;
678             Ok(SslProtocol(version))
679         }
680     }
681 
682     /// Returns the maximum protocol version allowed by the session.
683     pub fn protocol_version_max(&self) -> Result<SslProtocol> {
684         unsafe {
685             let mut version = 0;
686             cvt(SSLGetProtocolVersionMax(self.0, &mut version))?;
687             Ok(SslProtocol(version))
688         }
689     }
690 
691     /// Sets the maximum protocol version allowed by the session.
692     pub fn set_protocol_version_max(&mut self, max_version: SslProtocol) -> Result<()> {
693         unsafe { cvt(SSLSetProtocolVersionMax(self.0, max_version.0)) }
694     }
695 
696     /// Returns the minimum protocol version allowed by the session.
697     pub fn protocol_version_min(&self) -> Result<SslProtocol> {
698         unsafe {
699             let mut version = 0;
700             cvt(SSLGetProtocolVersionMin(self.0, &mut version))?;
701             Ok(SslProtocol(version))
702         }
703     }
704 
705     /// Sets the minimum protocol version allowed by the session.
706     pub fn set_protocol_version_min(&mut self, min_version: SslProtocol) -> Result<()> {
707         unsafe { cvt(SSLSetProtocolVersionMin(self.0, min_version.0)) }
708     }
709 
710     /// Returns the set of protocols selected via ALPN if it succeeded.
711     #[cfg(feature = "alpn")]
712     pub fn alpn_protocols(&self) -> Result<Vec<String>> {
713         let mut array = ptr::null();
714         unsafe {
715             #[cfg(feature = "OSX_10_13")]
716             {
717                 cvt(SSLCopyALPNProtocols(self.0, &mut array))?;
718             }
719 
720             #[cfg(not(feature = "OSX_10_13"))]
721             {
722                 dlsym! { fn SSLCopyALPNProtocols(SSLContextRef, *mut CFArrayRef) -> OSStatus }
723                 if let Some(f) = SSLCopyALPNProtocols.get() {
724                     cvt(f(self.0, &mut array))?;
725                 } else {
726                     return Err(Error::from_code(errSecUnimplemented));
727                 }
728             }
729 
730             if array.is_null() {
731                 return Ok(vec![]);
732             }
733 
734             let array = CFArray::<CFString>::wrap_under_create_rule(array);
735             Ok(array.into_iter().map(|p| p.to_string()).collect())
736         }
737     }
738 
739     /// Configures the set of protocols use for ALPN.
740     ///
741     /// This is only used for client-side connections.
742     #[cfg(feature = "alpn")]
743     pub fn set_alpn_protocols(&mut self, protocols: &[&str]) -> Result<()> {
744         // When CFMutableArray is added to core-foundation and IntoIterator trait
745         // is implemented for CFMutableArray, the code below should directly collect
746         // into a CFMutableArray.
747         let protocols = CFArray::from_CFTypes(
748             &protocols
749                 .iter()
750                 .map(|proto| CFString::new(proto))
751                 .collect::<Vec<_>>(),
752         );
753 
754         #[cfg(feature = "OSX_10_13")]
755         {
756             unsafe { cvt(SSLSetALPNProtocols(self.0, protocols.as_concrete_TypeRef())) }
757         }
758         #[cfg(not(feature = "OSX_10_13"))]
759         {
760             dlsym! { fn SSLSetALPNProtocols(SSLContextRef, CFArrayRef) -> OSStatus }
761             if let Some(f) = SSLSetALPNProtocols.get() {
762                 unsafe { cvt(f(self.0, protocols.as_concrete_TypeRef())) }
763             } else {
764                 Err(Error::from_code(errSecUnimplemented))
765             }
766         }
767     }
768 
769     /// Sets whether a protocol is enabled or not.
770     ///
771     /// # Note
772     ///
773     /// On OSX this is a deprecated API in favor of `set_protocol_version_max` and
774     /// `set_protocol_version_min`, although if you're working with OSX 10.8 or before you may have
775     /// to use this API instead.
776     #[cfg(target_os = "macos")]
777     pub fn set_protocol_version_enabled(
778         &mut self,
779         protocol: SslProtocol,
780         enabled: bool,
781     ) -> Result<()> {
782         unsafe {
783             cvt(SSLSetProtocolVersionEnabled(
784                 self.0,
785                 protocol.0,
786                 enabled as Boolean,
787             ))
788         }
789     }
790 
791     /// Returns the number of bytes which can be read without triggering a
792     /// `read` call in the underlying stream.
793     pub fn buffered_read_size(&self) -> Result<usize> {
794         unsafe {
795             let mut size = 0;
796             cvt(SSLGetBufferedReadSize(self.0, &mut size))?;
797             Ok(size)
798         }
799     }
800 
801     impl_options! {
802         /// If enabled, the handshake process will pause and return instead of
803         /// automatically validating a server's certificate.
804         const kSSLSessionOptionBreakOnServerAuth: break_on_server_auth & set_break_on_server_auth,
805         /// If enabled, the handshake process will pause and return after
806         /// the server requests a certificate from the client.
807         const kSSLSessionOptionBreakOnCertRequested: break_on_cert_requested & set_break_on_cert_requested,
808         /// If enabled, the handshake process will pause and return instead of
809         /// automatically validating a client's certificate.
810         const kSSLSessionOptionBreakOnClientAuth: break_on_client_auth & set_break_on_client_auth,
811         /// If enabled, TLS false start will be performed if an appropriate
812         /// cipher suite is negotiated.
813         ///
814         /// Requires the `OSX_10_9` (or greater) feature.
815         #[cfg(feature = "OSX_10_9")]
816         const kSSLSessionOptionFalseStart: false_start & set_false_start,
817         /// If enabled, 1/n-1 record splitting will be enabled for TLS 1.0
818         /// connections using block ciphers to mitigate the BEAST attack.
819         ///
820         /// Requires the `OSX_10_9` (or greater) feature.
821         #[cfg(feature = "OSX_10_9")]
822         const kSSLSessionOptionSendOneByteRecord: send_one_byte_record & set_send_one_byte_record,
823     }
824 
825     fn into_stream<S>(self, stream: S) -> Result<SslStream<S>>
826     where
827         S: Read + Write,
828     {
829         unsafe {
830             let ret = SSLSetIOFuncs(self.0, read_func::<S>, write_func::<S>);
831             if ret != errSecSuccess {
832                 return Err(Error::from_code(ret));
833             }
834 
835             let stream = Connection {
836                 stream,
837                 err: None,
838                 panic: None,
839             };
840             let stream = Box::into_raw(Box::new(stream));
841             let ret = SSLSetConnection(self.0, stream as *mut _);
842             if ret != errSecSuccess {
843                 let _conn = Box::from_raw(stream);
844                 return Err(Error::from_code(ret));
845             }
846 
847             Ok(SslStream {
848                 ctx: self,
849                 _m: PhantomData,
850             })
851         }
852     }
853 
854     /// Performs the SSL/TLS handshake.
855     pub fn handshake<S>(self, stream: S) -> result::Result<SslStream<S>, HandshakeError<S>>
856     where
857         S: Read + Write,
858     {
859         self.into_stream(stream)
860             .map_err(HandshakeError::Failure)
861             .and_then(SslStream::handshake)
862     }
863 }
864 
/// State registered with Secure Transport as the context's connection reference.
///
/// A boxed `Connection` is handed over via `SSLSetConnection` in
/// `SslContext::into_stream`, accessed from the `read_func`/`write_func`
/// callbacks, and freed when the owning `SslStream` is dropped.
struct Connection<S> {
    // The wrapped transport that the I/O callbacks read from and write to.
    stream: S,
    // The last I/O error returned by `stream` inside a callback; it is handed
    // back to the caller instead of the coarser `OSStatus` it was mapped to.
    err: Option<io::Error>,
    // A panic payload caught inside a callback (which must not unwind across
    // the `extern "C"` boundary); it is re-raised by `SslStream::check_panic`.
    panic: Option<Box<dyn Any + Send>>,
}
870 
871 // the logic here is based off of libcurl's
872 
873 fn translate_err(e: &io::Error) -> OSStatus {
874     match e.kind() {
875         io::ErrorKind::NotFound => errSSLClosedGraceful,
876         io::ErrorKind::ConnectionReset => errSSLClosedAbort,
877         io::ErrorKind::WouldBlock => errSSLWouldBlock,
878         io::ErrorKind::NotConnected => errSSLWouldBlock,
879         _ => errSecIO,
880     }
881 }
882 
883 unsafe extern "C" fn read_func<S>(
884     connection: SSLConnectionRef,
885     data: *mut c_void,
886     data_length: *mut usize,
887 ) -> OSStatus
888 where
889     S: Read,
890 {
891     let conn: &mut Connection<S> = &mut *(connection as *mut _);
892     let data = slice::from_raw_parts_mut(data as *mut u8, *data_length);
893     let mut start = 0;
894     let mut ret = errSecSuccess;
895 
896     while start < data.len() {
897         match panic::catch_unwind(AssertUnwindSafe(|| conn.stream.read(&mut data[start..]))) {
898             Ok(Ok(0)) => {
899                 ret = errSSLClosedNoNotify;
900                 break;
901             }
902             Ok(Ok(len)) => start += len,
903             Ok(Err(e)) => {
904                 ret = translate_err(&e);
905                 conn.err = Some(e);
906                 break;
907             }
908             Err(e) => {
909                 ret = errSecIO;
910                 conn.panic = Some(e);
911                 break;
912             }
913         }
914     }
915 
916     *data_length = start;
917     ret
918 }
919 
920 unsafe extern "C" fn write_func<S>(
921     connection: SSLConnectionRef,
922     data: *const c_void,
923     data_length: *mut usize,
924 ) -> OSStatus
925 where
926     S: Write,
927 {
928     let conn: &mut Connection<S> = &mut *(connection as *mut _);
929     let data = slice::from_raw_parts(data as *mut u8, *data_length);
930     let mut start = 0;
931     let mut ret = errSecSuccess;
932 
933     while start < data.len() {
934         match panic::catch_unwind(AssertUnwindSafe(|| conn.stream.write(&data[start..]))) {
935             Ok(Ok(0)) => {
936                 ret = errSSLClosedNoNotify;
937                 break;
938             }
939             Ok(Ok(len)) => start += len,
940             Ok(Err(e)) => {
941                 ret = translate_err(&e);
942                 conn.err = Some(e);
943                 break;
944             }
945             Err(e) => {
946                 ret = errSecIO;
947                 conn.panic = Some(e);
948                 break;
949             }
950         }
951     }
952 
953     *data_length = start;
954     ret
955 }
956 
957 /// A type implementing SSL/TLS encryption over an underlying stream.
958 pub struct SslStream<S> {
959     ctx: SslContext,
960     _m: PhantomData<S>,
961 }
962 
963 impl<S: fmt::Debug> fmt::Debug for SslStream<S> {
964     fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
965         fmt.debug_struct("SslStream")
966             .field("context", &self.ctx)
967             .field("stream", self.get_ref())
968             .finish()
969     }
970 }
971 
972 impl<S> Drop for SslStream<S> {
973     fn drop(&mut self) {
974         unsafe {
975             let mut conn = ptr::null();
976             let ret = SSLGetConnection(self.ctx.0, &mut conn);
977             assert!(ret == errSecSuccess);
978             Box::<Connection<S>>::from_raw(conn as *mut _);
979         }
980     }
981 }
982 
983 impl<S> SslStream<S> {
984     fn handshake(mut self) -> result::Result<SslStream<S>, HandshakeError<S>> {
985         match unsafe { SSLHandshake(self.ctx.0) } {
986             errSecSuccess => Ok(self),
987             reason @ errSSLPeerAuthCompleted
988             | reason @ errSSLClientCertRequested
989             | reason @ errSSLWouldBlock
990             | reason @ errSSLClientHelloReceived => {
991                 Err(HandshakeError::Interrupted(MidHandshakeSslStream {
992                     stream: self,
993                     error: Error::from_code(reason),
994                 }))
995             }
996             err => {
997                 self.check_panic();
998                 Err(HandshakeError::Failure(Error::from_code(err)))
999             }
1000         }
1001     }
1002 
1003     /// Returns a shared reference to the inner stream.
1004     pub fn get_ref(&self) -> &S {
1005         &self.connection().stream
1006     }
1007 
1008     /// Returns a mutable reference to the underlying stream.
1009     pub fn get_mut(&mut self) -> &mut S {
1010         &mut self.connection_mut().stream
1011     }
1012 
1013     /// Returns a shared reference to the `SslContext` of the stream.
1014     pub fn context(&self) -> &SslContext {
1015         &self.ctx
1016     }
1017 
1018     /// Returns a mutable reference to the `SslContext` of the stream.
1019     pub fn context_mut(&mut self) -> &mut SslContext {
1020         &mut self.ctx
1021     }
1022 
1023     /// Shuts down the connection.
1024     pub fn close(&mut self) -> result::Result<(), io::Error> {
1025         unsafe {
1026             let ret = SSLClose(self.ctx.0);
1027             if ret == errSecSuccess {
1028                 Ok(())
1029             } else {
1030                 Err(self.get_error(ret))
1031             }
1032         }
1033     }
1034 
1035     fn connection(&self) -> &Connection<S> {
1036         unsafe {
1037             let mut conn = ptr::null();
1038             let ret = SSLGetConnection(self.ctx.0, &mut conn);
1039             assert!(ret == errSecSuccess);
1040 
1041             mem::transmute(conn)
1042         }
1043     }
1044 
1045     fn connection_mut(&mut self) -> &mut Connection<S> {
1046         unsafe {
1047             let mut conn = ptr::null();
1048             let ret = SSLGetConnection(self.ctx.0, &mut conn);
1049             assert!(ret == errSecSuccess);
1050 
1051             mem::transmute(conn)
1052         }
1053     }
1054 
1055     fn check_panic(&mut self) {
1056         let conn = self.connection_mut();
1057         if let Some(err) = conn.panic.take() {
1058             panic::resume_unwind(err);
1059         }
1060     }
1061 
1062     fn get_error(&mut self, ret: OSStatus) -> io::Error {
1063         self.check_panic();
1064 
1065         if let Some(err) = self.connection_mut().err.take() {
1066             err
1067         } else {
1068             io::Error::new(io::ErrorKind::Other, Error::from_code(ret))
1069         }
1070     }
1071 }
1072 
1073 impl<S: Read + Write> Read for SslStream<S> {
1074     fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
1075         // Below we base our return value off the amount of data read, so a
1076         // zero-length buffer might cause us to erroneously interpret this
1077         // request as an error. Instead short-circuit that logic and return
1078         // `Ok(0)` instead.
1079         if buf.is_empty() {
1080             return Ok(0);
1081         }
1082 
1083         // If some data was buffered but not enough to fill `buf`, SSLRead
1084         // will try to read a new packet. This is bad because there may be
1085         // no more data but the socket is remaining open (e.g HTTPS with
1086         // Connection: keep-alive).
1087         let buffered = self.context().buffered_read_size().unwrap_or(0);
1088         let mut to_read = buf.len();
1089         if buffered > 0 {
1090             to_read = cmp::min(buffered, buf.len());
1091         }
1092 
1093         unsafe {
1094             let mut nread = 0;
1095             let ret = SSLRead(self.ctx.0, buf.as_mut_ptr() as *mut _, to_read, &mut nread);
1096             // SSLRead can return an error at the same time it returns the last
1097             // chunk of data (!)
1098             if nread > 0 {
1099                 return Ok(nread as usize);
1100             }
1101 
1102             match ret {
1103                 errSSLClosedGraceful | errSSLClosedAbort | errSSLClosedNoNotify => Ok(0),
1104                 _ => Err(self.get_error(ret)),
1105             }
1106         }
1107     }
1108 }
1109 
1110 impl<S: Read + Write> Write for SslStream<S> {
1111     fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
1112         // Like above in read, short circuit a 0-length write
1113         if buf.is_empty() {
1114             return Ok(0);
1115         }
1116         unsafe {
1117             let mut nwritten = 0;
1118             let ret = SSLWrite(
1119                 self.ctx.0,
1120                 buf.as_ptr() as *const _,
1121                 buf.len(),
1122                 &mut nwritten,
1123             );
1124             // just to be safe, base success off of nwritten rather than ret
1125             // for the same reason as in read
1126             if nwritten > 0 {
1127                 Ok(nwritten as usize)
1128             } else {
1129                 Err(self.get_error(ret))
1130             }
1131         }
1132     }
1133 
1134     fn flush(&mut self) -> io::Result<()> {
1135         self.connection_mut().stream.flush()
1136     }
1137 }
1138 
/// A builder type to simplify the creation of client side `SslStream`s.
///
/// Configure it with the chained setters, then call `handshake` to connect.
#[derive(Debug)]
pub struct ClientBuilder {
    /// Client identity set by `identity()`; passed to `set_certificate` when `Some`.
    identity: Option<SecIdentity>,
    /// Anchor (root) certificates set by `anchor_certificates()`.
    certs: Vec<SecCertificate>,
    /// Certificate chain sent together with `identity`.
    chain: Vec<SecCertificate>,
    /// Optional lower bound passed to `set_protocol_version_min`.
    protocol_min: Option<SslProtocol>,
    /// Optional upper bound passed to `set_protocol_version_max`.
    protocol_max: Option<SslProtocol>,
    /// Set by `trust_anchor_certificates_only()`.
    trust_certs_only: bool,
    /// Whether the peer domain name (SNI) is set on the context; defaults to `true`.
    use_sni: bool,
    /// Set by `danger_accept_invalid_certs()`.
    danger_accept_invalid_certs: bool,
    /// When `true`, no domain is handed to trust evaluation during `handshake`.
    danger_accept_invalid_hostnames: bool,
    /// Ciphers to enable; empty means "start from the context's enabled set".
    whitelisted_ciphers: Vec<CipherSuite>,
    /// Ciphers removed from the enabled set.
    blacklisted_ciphers: Vec<CipherSuite>,
    /// ALPN protocols to advertise (settable only with the `alpn` feature).
    alpn: Option<Vec<String>>,
}
1155 
1156 impl Default for ClientBuilder {
1157     fn default() -> ClientBuilder {
1158         ClientBuilder::new()
1159     }
1160 }
1161 
1162 impl ClientBuilder {
1163     /// Creates a new builder with default options.
1164     pub fn new() -> Self {
1165         ClientBuilder {
1166             identity: None,
1167             certs: Vec::new(),
1168             chain: Vec::new(),
1169             protocol_min: None,
1170             protocol_max: None,
1171             trust_certs_only: false,
1172             use_sni: true,
1173             danger_accept_invalid_certs: false,
1174             danger_accept_invalid_hostnames: false,
1175             whitelisted_ciphers: Vec::new(),
1176             blacklisted_ciphers: Vec::new(),
1177             alpn: None,
1178         }
1179     }
1180 
1181     /// Specifies the set of root certificates to trust when
1182     /// verifying the server's certificate.
1183     pub fn anchor_certificates(&mut self, certs: &[SecCertificate]) -> &mut Self {
1184         self.certs = certs.to_owned();
1185         self
1186     }
1187 
1188     /// Specifies whether to trust the built-in certificates in addition
1189     /// to specified anchor certificates.
1190     pub fn trust_anchor_certificates_only(&mut self, only: bool) -> &mut Self {
1191         self.trust_certs_only = only;
1192         self
1193     }
1194 
1195     /// Specifies whether to trust invalid certificates.
1196     ///
1197     /// # Warning
1198     ///
1199     /// You should think very carefully before using this method. If invalid
1200     /// certificates are trusted, *any* certificate for *any* site will be
1201     /// trusted for use. This includes expired certificates. This introduces
1202     /// significant vulnerabilities, and should only be used as a last resort.
1203     pub fn danger_accept_invalid_certs(&mut self, noverify: bool) -> &mut Self {
1204         self.danger_accept_invalid_certs = noverify;
1205         self
1206     }
1207 
1208     /// Specifies whether to use Server Name Indication (SNI).
1209     pub fn use_sni(&mut self, use_sni: bool) -> &mut Self {
1210         self.use_sni = use_sni;
1211         self
1212     }
1213 
1214     /// Specifies whether to verify that the server's hostname matches its certificate.
1215     ///
1216     /// # Warning
1217     ///
1218     /// You should think very carefully before using this method. If hostnames are not verified,
1219     /// *any* valid certificate for *any* site will be trusted for use. This introduces significant
1220     /// vulnerabilities, and should only be used as a last resort.
1221     pub fn danger_accept_invalid_hostnames(
1222         &mut self,
1223         danger_accept_invalid_hostnames: bool,
1224     ) -> &mut Self {
1225         self.danger_accept_invalid_hostnames = danger_accept_invalid_hostnames;
1226         self
1227     }
1228 
1229     /// Set a whitelist of enabled ciphers. Any ciphers not whitelisted will be disabled.
1230     pub fn whitelist_ciphers(&mut self, whitelisted_ciphers: &[CipherSuite]) -> &mut Self {
1231         self.whitelisted_ciphers = whitelisted_ciphers.to_owned();
1232         self
1233     }
1234 
1235     /// Set a blacklist of disabled ciphers. Blacklisted ciphers will be disabled.
1236     pub fn blacklist_ciphers(&mut self, blacklisted_ciphers: &[CipherSuite]) -> &mut Self {
1237         self.blacklisted_ciphers = blacklisted_ciphers.to_owned();
1238         self
1239     }
1240 
1241     /// Use the specified identity as a SSL/TLS client certificate.
1242     pub fn identity(&mut self, identity: &SecIdentity, chain: &[SecCertificate]) -> &mut Self {
1243         self.identity = Some(identity.clone());
1244         self.chain = chain.to_owned();
1245         self
1246     }
1247 
1248     /// Configure the minimum protocol that this client will support.
1249     pub fn protocol_min(&mut self, min: SslProtocol) -> &mut Self {
1250         self.protocol_min = Some(min);
1251         self
1252     }
1253 
1254     /// Configure the minimum protocol that this client will support.
1255     pub fn protocol_max(&mut self, max: SslProtocol) -> &mut Self {
1256         self.protocol_max = Some(max);
1257         self
1258     }
1259 
1260     /// Configures the set of protocols used for ALPN.
1261     #[cfg(feature = "alpn")]
1262     pub fn alpn_protocols(&mut self, protocols: &[&str]) -> &mut Self {
1263         self.alpn = Some(protocols.iter().map(|s| s.to_string()).collect());
1264         self
1265     }
1266 
1267     /// Initiates a new SSL/TLS session over a stream connected to the specified domain.
1268     ///
1269     /// If both SNI and hostname verification are disabled, the value of `domain` will be ignored.
1270     pub fn handshake<S>(
1271         &self,
1272         domain: &str,
1273         stream: S,
1274     ) -> result::Result<SslStream<S>, ClientHandshakeError<S>>
1275     where
1276         S: Read + Write,
1277     {
1278         // the logic for trust validation is in MidHandshakeClientBuilder::connect, so run all
1279         // of the handshake logic through that.
1280         let stream = MidHandshakeSslStream {
1281             stream: self.ctx_into_stream(domain, stream)?,
1282             error: Error::from(errSecSuccess),
1283         };
1284 
1285         let certs = self.certs.clone();
1286         let stream = MidHandshakeClientBuilder {
1287             stream,
1288             domain: if self.danger_accept_invalid_hostnames {
1289                 None
1290             } else {
1291                 Some(domain.to_string())
1292             },
1293             certs,
1294             trust_certs_only: self.trust_certs_only,
1295             danger_accept_invalid_certs: self.danger_accept_invalid_certs,
1296         };
1297         stream.handshake()
1298     }
1299 
1300     fn ctx_into_stream<S>(&self, domain: &str, stream: S) -> Result<SslStream<S>>
1301     where
1302         S: Read + Write,
1303     {
1304         let mut ctx = SslContext::new(SslProtocolSide::CLIENT, SslConnectionType::STREAM)?;
1305 
1306         if self.use_sni {
1307             ctx.set_peer_domain_name(domain)?;
1308         }
1309         if let Some(ref identity) = self.identity {
1310             ctx.set_certificate(identity, &self.chain)?;
1311         }
1312         #[cfg(feature = "alpn")]
1313         {
1314             if let Some(ref alpn) = self.alpn {
1315                 ctx.set_alpn_protocols(&alpn.iter().map(|s| &**s).collect::<Vec<_>>())?;
1316             }
1317         }
1318         ctx.set_break_on_server_auth(true)?;
1319         self.configure_protocols(&mut ctx)?;
1320         self.configure_ciphers(&mut ctx)?;
1321 
1322         ctx.into_stream(stream)
1323     }
1324 
1325     fn configure_protocols(&self, ctx: &mut SslContext) -> Result<()> {
1326         if let Some(min) = self.protocol_min {
1327             ctx.set_protocol_version_min(min)?;
1328         }
1329         if let Some(max) = self.protocol_max {
1330             ctx.set_protocol_version_max(max)?;
1331         }
1332         Ok(())
1333     }
1334 
1335     fn configure_ciphers(&self, ctx: &mut SslContext) -> Result<()> {
1336         let mut ciphers = if self.whitelisted_ciphers.is_empty() {
1337             ctx.enabled_ciphers()?
1338         } else {
1339             self.whitelisted_ciphers.clone()
1340         };
1341 
1342         if !self.blacklisted_ciphers.is_empty() {
1343             ciphers.retain(|cipher| !self.blacklisted_ciphers.contains(cipher));
1344         }
1345 
1346         ctx.set_enabled_ciphers(&ciphers)?;
1347         Ok(())
1348     }
1349 }
1350 
/// A builder type to simplify the creation of server-side `SslStream`s.
///
/// Every handshake started through it presents the same identity and
/// certificate chain.
#[derive(Debug)]
pub struct ServerBuilder {
    /// Identity passed to `set_certificate` for each handshake.
    identity: SecIdentity,
    /// Certificate chain passed to `set_certificate` together with `identity`.
    certs: Vec<SecCertificate>,
}
1357 
1358 impl ServerBuilder {
1359     /// Creates a new `ServerBuilder` which will use the specified identity
1360     /// and certificate chain for handshakes.
1361     pub fn new(identity: &SecIdentity, certs: &[SecCertificate]) -> ServerBuilder {
1362         ServerBuilder {
1363             identity: identity.clone(),
1364             certs: certs.to_owned(),
1365         }
1366     }
1367 
1368     /// Initiates a new SSL/TLS session over a stream.
1369     pub fn handshake<S>(&self, stream: S) -> Result<SslStream<S>>
1370     where
1371         S: Read + Write,
1372     {
1373         let mut ctx = SslContext::new(SslProtocolSide::SERVER, SslConnectionType::STREAM)?;
1374         ctx.set_certificate(&self.identity, &self.certs)?;
1375         match ctx.handshake(stream) {
1376             Ok(stream) => Ok(stream),
1377             Err(HandshakeError::Interrupted(stream)) => Err(Error::from_code(stream.reason())),
1378             Err(HandshakeError::Failure(err)) => Err(err),
1379         }
1380     }
1381 }
1382 
#[cfg(test)]
mod test {
    // Most of these tests connect to live hosts (google.com:443,
    // expired.badssl.com:443) and therefore need network access.
    use std::io;
    use std::io::prelude::*;
    use std::net::TcpStream;

    use super::*;

    #[test]
    fn connect() {
        let mut ctx = p!(SslContext::new(
            SslProtocolSide::CLIENT,
            SslConnectionType::STREAM
        ));
        p!(ctx.set_peer_domain_name("google.com"));
        let stream = p!(TcpStream::connect("google.com:443"));
        p!(ctx.handshake(stream));
    }

    #[test]
    fn connect_bad_domain() {
        let mut ctx = p!(SslContext::new(
            SslProtocolSide::CLIENT,
            SslConnectionType::STREAM
        ));
        p!(ctx.set_peer_domain_name("foobar.com"));
        let stream = p!(TcpStream::connect("google.com:443"));
        assert!(ctx.handshake(stream).is_err(), "expected failure");
    }

    #[test]
    fn load_page() {
        let mut ctx = p!(SslContext::new(
            SslProtocolSide::CLIENT,
            SslConnectionType::STREAM
        ));
        p!(ctx.set_peer_domain_name("google.com"));
        let stream = p!(TcpStream::connect("google.com:443"));
        let mut stream = p!(ctx.handshake(stream));
        p!(stream.write_all(b"GET / HTTP/1.0\r\n\r\n"));
        p!(stream.flush());
        let mut buf = vec![];
        p!(stream.read_to_end(&mut buf));
        println!("{}", String::from_utf8_lossy(&buf));
    }

    #[test]
    #[cfg(feature = "alpn")]
    fn client_alpn_accept() {
        let mut ctx = p!(SslContext::new(
            SslProtocolSide::CLIENT,
            SslConnectionType::STREAM
        ));
        p!(ctx.set_peer_domain_name("google.com"));
        p!(ctx.set_alpn_protocols(&["h2"]));
        let stream = p!(TcpStream::connect("google.com:443"));
        let stream = ctx.handshake(stream).unwrap();
        assert_eq!(vec!["h2"], stream.context().alpn_protocols().unwrap());
    }

    #[test]
    #[cfg(feature = "alpn")]
    fn client_alpn_reject() {
        let mut ctx = p!(SslContext::new(
            SslProtocolSide::CLIENT,
            SslConnectionType::STREAM
        ));
        p!(ctx.set_peer_domain_name("google.com"));
        p!(ctx.set_alpn_protocols(&["h2c"]));
        let stream = p!(TcpStream::connect("google.com:443"));
        let stream = ctx.handshake(stream).unwrap();
        assert!(stream.context().alpn_protocols().is_err());
    }

    #[test]
    fn client_no_anchor_certs() {
        let stream = p!(TcpStream::connect("google.com:443"));
        assert!(ClientBuilder::new()
            .trust_anchor_certificates_only(true)
            .handshake("google.com", stream)
            .is_err());
    }

    #[test]
    fn client_bad_domain() {
        let stream = p!(TcpStream::connect("google.com:443"));
        assert!(ClientBuilder::new()
            .handshake("foobar.com", stream)
            .is_err());
    }

    #[test]
    fn client_bad_domain_ignored() {
        let stream = p!(TcpStream::connect("google.com:443"));
        ClientBuilder::new()
            .danger_accept_invalid_hostnames(true)
            .handshake("foobar.com", stream)
            .unwrap();
    }

    #[test]
    fn connect_no_verify_ssl() {
        let stream = p!(TcpStream::connect("expired.badssl.com:443"));
        let mut builder = ClientBuilder::new();
        builder.danger_accept_invalid_certs(true);
        builder.handshake("expired.badssl.com", stream).unwrap();
    }

    #[test]
    fn load_page_client() {
        let stream = p!(TcpStream::connect("google.com:443"));
        let mut stream = p!(ClientBuilder::new().handshake("google.com", stream));
        p!(stream.write_all(b"GET / HTTP/1.0\r\n\r\n"));
        p!(stream.flush());
        let mut buf = vec![];
        p!(stream.read_to_end(&mut buf));
        println!("{}", String::from_utf8_lossy(&buf));
    }

    #[test]
    #[cfg_attr(target_os = "ios", ignore)] // FIXME what's going on with ios?
    fn cipher_configuration() {
        let mut ctx = p!(SslContext::new(
            SslProtocolSide::SERVER,
            SslConnectionType::STREAM
        ));
        let ciphers = p!(ctx.enabled_ciphers());
        // Keep every other cipher so the enabled set observably changes.
        let ciphers = ciphers
            .iter()
            .enumerate()
            .filter_map(|(i, c)| if i % 2 == 0 { Some(*c) } else { None })
            .collect::<Vec<_>>();
        p!(ctx.set_enabled_ciphers(&ciphers));
        assert_eq!(ciphers, p!(ctx.enabled_ciphers()));
    }

    #[test]
    fn test_builder_whitelist_ciphers() {
        let stream = p!(TcpStream::connect("google.com:443"));

        let ctx = p!(SslContext::new(
            SslProtocolSide::CLIENT,
            SslConnectionType::STREAM
        ));
        assert!(p!(ctx.enabled_ciphers()).len() > 1);

        let ciphers = p!(ctx.enabled_ciphers());
        let cipher = ciphers.first().unwrap();
        let stream = p!(ClientBuilder::new()
            .whitelist_ciphers(&[*cipher])
            .ctx_into_stream("google.com", stream));

        assert_eq!(1, p!(stream.context().enabled_ciphers()).len());
    }

    #[test]
    #[cfg_attr(target_os = "ios", ignore)] // FIXME same issue as cipher_configuration
    fn test_builder_blacklist_ciphers() {
        let stream = p!(TcpStream::connect("google.com:443"));

        let ctx = p!(SslContext::new(
            SslProtocolSide::CLIENT,
            SslConnectionType::STREAM
        ));
        let num = p!(ctx.enabled_ciphers()).len();
        assert!(num > 1);

        let ciphers = p!(ctx.enabled_ciphers());
        let cipher = ciphers.first().unwrap();
        let stream = p!(ClientBuilder::new()
            .blacklist_ciphers(&[*cipher])
            .ctx_into_stream("google.com", stream));

        assert_eq!(num - 1, p!(stream.context().enabled_ciphers()).len());
    }

    #[test]
    fn idle_context_peer_trust() {
        let ctx = p!(SslContext::new(
            SslProtocolSide::SERVER,
            SslConnectionType::STREAM
        ));
        assert!(ctx.peer_trust2().is_err());
    }

    #[test]
    fn peer_id() {
        let mut ctx = p!(SslContext::new(
            SslProtocolSide::SERVER,
            SslConnectionType::STREAM
        ));
        assert!(p!(ctx.peer_id()).is_none());
        p!(ctx.set_peer_id(b"foobar"));
        assert_eq!(p!(ctx.peer_id()), Some(&b"foobar"[..]));
    }

    #[test]
    fn peer_domain_name() {
        let mut ctx = p!(SslContext::new(
            SslProtocolSide::CLIENT,
            SslConnectionType::STREAM
        ));
        assert_eq!("", p!(ctx.peer_domain_name()));
        p!(ctx.set_peer_domain_name("foobar.com"));
        assert_eq!("foobar.com", p!(ctx.peer_domain_name()));
    }

    // A panic inside the wrapped stream's `write` must be caught in the C
    // callback and re-raised on the Rust side.
    #[test]
    #[should_panic(expected = "blammo")]
    fn write_panic() {
        struct ExplodingStream(TcpStream);

        impl Read for ExplodingStream {
            fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
                self.0.read(buf)
            }
        }

        impl Write for ExplodingStream {
            fn write(&mut self, _: &[u8]) -> io::Result<usize> {
                panic!("blammo");
            }

            fn flush(&mut self) -> io::Result<()> {
                self.0.flush()
            }
        }

        let mut ctx = p!(SslContext::new(
            SslProtocolSide::CLIENT,
            SslConnectionType::STREAM
        ));
        p!(ctx.set_peer_domain_name("google.com"));
        let stream = p!(TcpStream::connect("google.com:443"));
        let _ = ctx.handshake(ExplodingStream(stream));
    }

    // Same as `write_panic`, but for the read callback.
    #[test]
    #[should_panic(expected = "blammo")]
    fn read_panic() {
        struct ExplodingStream(TcpStream);

        impl Read for ExplodingStream {
            fn read(&mut self, _: &mut [u8]) -> io::Result<usize> {
                panic!("blammo");
            }
        }

        impl Write for ExplodingStream {
            fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
                self.0.write(buf)
            }

            fn flush(&mut self) -> io::Result<()> {
                self.0.flush()
            }
        }

        let mut ctx = p!(SslContext::new(
            SslProtocolSide::CLIENT,
            SslConnectionType::STREAM
        ));
        p!(ctx.set_peer_domain_name("google.com"));
        let stream = p!(TcpStream::connect("google.com:443"));
        let _ = ctx.handshake(ExplodingStream(stream));
    }

    #[test]
    fn zero_length_buffers() {
        let mut ctx = p!(SslContext::new(
            SslProtocolSide::CLIENT,
            SslConnectionType::STREAM
        ));
        p!(ctx.set_peer_domain_name("google.com"));
        let stream = p!(TcpStream::connect("google.com:443"));
        let mut stream = ctx.handshake(stream).unwrap();
        assert_eq!(stream.write(b"").unwrap(), 0);
        assert_eq!(stream.read(&mut []).unwrap(), 0);
    }
}
1666