1 //! SSL/TLS encryption support using Secure Transport.
2 //!
3 //! # Examples
4 //!
5 //! To connect as a client to a server with a certificate trusted by the system:
6 //!
7 //! ```rust
8 //! use std::io::prelude::*;
9 //! use std::net::TcpStream;
10 //! use security_framework::secure_transport::ClientBuilder;
11 //!
12 //! let stream = TcpStream::connect("google.com:443").unwrap();
13 //! let mut stream = ClientBuilder::new().handshake("google.com", stream).unwrap();
14 //!
15 //! stream.write_all(b"GET / HTTP/1.0\r\n\r\n").unwrap();
16 //! let mut page = vec![];
17 //! stream.read_to_end(&mut page).unwrap();
18 //! println!("{}", String::from_utf8_lossy(&page));
19 //! ```
20 //!
21 //! To connect to a server with a certificate that's *not* trusted by the
22 //! system, specify the root certificates for the server's chain to the
23 //! `ClientBuilder`:
24 //!
25 //! ```rust,no_run
26 //! use std::io::prelude::*;
27 //! use std::net::TcpStream;
28 //! use security_framework::secure_transport::ClientBuilder;
29 //!
30 //! # let root_cert = unsafe { std::mem::zeroed() };
31 //! let stream = TcpStream::connect("my_server.com:443").unwrap();
32 //! let mut stream = ClientBuilder::new()
33 //!                      .anchor_certificates(&[root_cert])
34 //!                      .handshake("my_server.com", stream)
35 //!                      .unwrap();
36 //!
37 //! stream.write_all(b"GET / HTTP/1.0\r\n\r\n").unwrap();
38 //! let mut page = vec![];
39 //! stream.read_to_end(&mut page).unwrap();
40 //! println!("{}", String::from_utf8_lossy(&page));
41 //! ```
42 //!
43 //! For more advanced configuration, the `SslContext` type can be used directly.
44 //!
45 //! To run a server:
46 //!
47 //! ```rust,no_run
48 //! use std::net::TcpListener;
49 //! use std::thread;
50 //! use security_framework::secure_transport::{SslContext, SslProtocolSide, SslConnectionType};
51 //!
52 //! // Create a TCP listener and start accepting on it.
53 //! let mut listener = TcpListener::bind("0.0.0.0:443").unwrap();
54 //!
55 //! for stream in listener.incoming() {
56 //!     let stream = stream.unwrap();
57 //!     thread::spawn(move || {
58 //!         // Create a new context configured to operate on the server side of
59 //!         // a traditional SSL/TLS session.
60 //!         let mut ctx = SslContext::new(SslProtocolSide::SERVER, SslConnectionType::STREAM)
61 //!                           .unwrap();
62 //!
63 //!         // Install the certificate chain that we will be using.
64 //!         # let identity = unsafe { std::mem::zeroed() };
65 //!         # let intermediate_cert = unsafe { std::mem::zeroed() };
66 //!         # let root_cert = unsafe { std::mem::zeroed() };
67 //!         ctx.set_certificate(identity, &[intermediate_cert, root_cert]).unwrap();
68 //!
69 //!         // Perform the SSL/TLS handshake and get our stream.
70 //!         let mut stream = ctx.handshake(stream).unwrap();
71 //!     });
72 //! }
73 //!
74 //! ```
75 #[allow(unused_imports)]
76 use core_foundation::array::{CFArray, CFArrayRef};
77 
78 use core_foundation::base::{Boolean, TCFType};
79 #[cfg(feature = "alpn")]
80 use core_foundation::string::CFString;
81 use core_foundation_sys::base::{kCFAllocatorDefault, OSStatus};
82 use std::os::raw::c_void;
83 
84 #[allow(unused_imports)]
85 use security_framework_sys::base::{
86     errSecBadReq, errSecIO, errSecNotTrusted, errSecSuccess, errSecTrustSettingDeny,
87     errSecUnimplemented,
88 };
89 
90 use security_framework_sys::secure_transport::*;
91 use std::any::Any;
92 use std::cmp;
93 use std::fmt;
94 use std::io;
95 use std::io::prelude::*;
96 use std::marker::PhantomData;
97 use std::mem;
98 use std::panic::{self, AssertUnwindSafe};
99 use std::ptr;
100 use std::result;
101 use std::slice;
102 
103 use crate::base::{Error, Result};
104 use crate::certificate::SecCertificate;
105 use crate::cipher_suite::CipherSuite;
106 use crate::identity::SecIdentity;
107 use crate::policy::SecPolicy;
108 use crate::trust::{SecTrust, TrustResult};
109 use crate::{cvt, AsInner};
110 
111 /// Specifies a side of a TLS session.
112 #[derive(Debug, Copy, Clone, PartialEq, Eq)]
113 pub struct SslProtocolSide(SSLProtocolSide);
114 
115 impl SslProtocolSide {
116     /// The server side of the session.
117     pub const SERVER: Self = Self(kSSLServerSide);
118 
119     /// The client side of the session.
120     pub const CLIENT: Self = Self(kSSLClientSide);
121 }
122 
123 /// Specifies the type of TLS session.
124 #[derive(Debug, Copy, Clone)]
125 pub struct SslConnectionType(SSLConnectionType);
126 
127 impl SslConnectionType {
128     /// A traditional TLS stream.
129     pub const STREAM: Self = Self(kSSLStreamType);
130 
131     /// A DTLS session.
132     pub const DATAGRAM: Self = Self(kSSLDatagramType);
133 }
134 
135 /// An error or intermediate state after a TLS handshake attempt.
136 #[derive(Debug)]
137 pub enum HandshakeError<S> {
138     /// The handshake failed.
139     Failure(Error),
140     /// The handshake was interrupted midway through.
141     Interrupted(MidHandshakeSslStream<S>),
142 }
143 
144 impl<S> From<Error> for HandshakeError<S> {
from(err: Error) -> Self145     fn from(err: Error) -> Self {
146         Self::Failure(err)
147     }
148 }
149 
150 /// An error or intermediate state after a TLS handshake attempt.
151 #[derive(Debug)]
152 pub enum ClientHandshakeError<S> {
153     /// The handshake failed.
154     Failure(Error),
155     /// The handshake was interrupted midway through.
156     Interrupted(MidHandshakeClientBuilder<S>),
157 }
158 
159 impl<S> From<Error> for ClientHandshakeError<S> {
from(err: Error) -> Self160     fn from(err: Error) -> Self {
161         Self::Failure(err)
162     }
163 }
164 
165 /// An SSL stream midway through the handshake process.
166 #[derive(Debug)]
167 pub struct MidHandshakeSslStream<S> {
168     stream: SslStream<S>,
169     error: Error,
170 }
171 
172 impl<S> MidHandshakeSslStream<S> {
173     /// Returns a shared reference to the inner stream.
get_ref(&self) -> &S174     pub fn get_ref(&self) -> &S {
175         self.stream.get_ref()
176     }
177 
178     /// Returns a mutable reference to the inner stream.
get_mut(&mut self) -> &mut S179     pub fn get_mut(&mut self) -> &mut S {
180         self.stream.get_mut()
181     }
182 
183     /// Returns a shared reference to the `SslContext` of the stream.
context(&self) -> &SslContext184     pub fn context(&self) -> &SslContext {
185         self.stream.context()
186     }
187 
188     /// Returns a mutable reference to the `SslContext` of the stream.
context_mut(&mut self) -> &mut SslContext189     pub fn context_mut(&mut self) -> &mut SslContext {
190         self.stream.context_mut()
191     }
192 
193     /// Returns `true` iff `break_on_server_auth` was set and the handshake has
194     /// progressed to that point.
server_auth_completed(&self) -> bool195     pub fn server_auth_completed(&self) -> bool {
196         self.error.code() == errSSLPeerAuthCompleted
197     }
198 
199     /// Returns `true` iff `break_on_cert_requested` was set and the handshake
200     /// has progressed to that point.
client_cert_requested(&self) -> bool201     pub fn client_cert_requested(&self) -> bool {
202         self.error.code() == errSSLClientCertRequested
203     }
204 
205     /// Returns `true` iff the underlying stream returned an error with the
206     /// `WouldBlock` kind.
would_block(&self) -> bool207     pub fn would_block(&self) -> bool {
208         self.error.code() == errSSLWouldBlock
209     }
210 
211     /// Returns the error which caused the handshake interruption.
error(&self) -> &Error212     pub fn error(&self) -> &Error {
213         &self.error
214     }
215 
216     /// Restarts the handshake process.
handshake(self) -> result::Result<SslStream<S>, HandshakeError<S>>217     pub fn handshake(self) -> result::Result<SslStream<S>, HandshakeError<S>> {
218         self.stream.handshake()
219     }
220 }
221 
222 /// An SSL stream midway through the handshake process.
223 #[derive(Debug)]
224 pub struct MidHandshakeClientBuilder<S> {
225     stream: MidHandshakeSslStream<S>,
226     domain: Option<String>,
227     certs: Vec<SecCertificate>,
228     trust_certs_only: bool,
229     danger_accept_invalid_certs: bool,
230 }
231 
232 impl<S> MidHandshakeClientBuilder<S> {
233     /// Returns a shared reference to the inner stream.
get_ref(&self) -> &S234     pub fn get_ref(&self) -> &S {
235         self.stream.get_ref()
236     }
237 
238     /// Returns a mutable reference to the inner stream.
get_mut(&mut self) -> &mut S239     pub fn get_mut(&mut self) -> &mut S {
240         self.stream.get_mut()
241     }
242 
243     /// Returns the error which caused the handshake interruption.
error(&self) -> &Error244     pub fn error(&self) -> &Error {
245         self.stream.error()
246     }
247 
248     /// Restarts the handshake process.
handshake(self) -> result::Result<SslStream<S>, ClientHandshakeError<S>>249     pub fn handshake(self) -> result::Result<SslStream<S>, ClientHandshakeError<S>> {
250         let MidHandshakeClientBuilder {
251             stream,
252             domain,
253             certs,
254             trust_certs_only,
255             danger_accept_invalid_certs,
256         } = self;
257 
258         let mut result = stream.handshake();
259         loop {
260             let stream = match result {
261                 Ok(stream) => return Ok(stream),
262                 Err(HandshakeError::Interrupted(stream)) => stream,
263                 Err(HandshakeError::Failure(err)) => {
264                     return Err(ClientHandshakeError::Failure(err))
265                 }
266             };
267 
268             if stream.would_block() {
269                 let ret = MidHandshakeClientBuilder {
270                     stream,
271                     domain,
272                     certs,
273                     trust_certs_only,
274                     danger_accept_invalid_certs,
275                 };
276                 return Err(ClientHandshakeError::Interrupted(ret));
277             }
278 
279             if stream.server_auth_completed() {
280                 if danger_accept_invalid_certs {
281                     result = stream.handshake();
282                     continue;
283                 }
284                 let mut trust = match stream.context().peer_trust2()? {
285                     Some(trust) => trust,
286                     None => {
287                         result = stream.handshake();
288                         continue;
289                     }
290                 };
291                 trust.set_anchor_certificates(&certs)?;
292                 trust.set_trust_anchor_certificates_only(self.trust_certs_only)?;
293                 let policy =
294                     SecPolicy::create_ssl(SslProtocolSide::SERVER, domain.as_ref().map(|s| &**s));
295                 trust.set_policy(&policy)?;
296                 let trusted = trust.evaluate()?;
297                 match trusted {
298                     TrustResult::PROCEED | TrustResult::UNSPECIFIED => {
299                         result = stream.handshake();
300                         continue;
301                     }
302                     TrustResult::DENY => {
303                         let err = Error::from_code(errSecTrustSettingDeny);
304                         return Err(ClientHandshakeError::Failure(err));
305                     }
306                     TrustResult::RECOVERABLE_TRUST_FAILURE | TrustResult::FATAL_TRUST_FAILURE => {
307                         let err = Error::from_code(errSecNotTrusted);
308                         return Err(ClientHandshakeError::Failure(err));
309                     }
310                     TrustResult::INVALID | TrustResult::OTHER_ERROR | _ => {
311                         let err = Error::from_code(errSecBadReq);
312                         return Err(ClientHandshakeError::Failure(err));
313                     }
314                 }
315             }
316 
317             let err = Error::from_code(stream.error().code());
318             return Err(ClientHandshakeError::Failure(err));
319         }
320     }
321 }
322 
323 /// Specifies the state of a TLS session.
324 #[derive(Debug, PartialEq, Eq)]
325 pub struct SessionState(SSLSessionState);
326 
327 impl SessionState {
328     /// The session has not yet started.
329     pub const IDLE: Self = Self(kSSLIdle);
330 
331     /// The session is in the handshake process.
332     pub const HANDSHAKE: Self = Self(kSSLHandshake);
333 
334     /// The session is connected.
335     pub const CONNECTED: Self = Self(kSSLConnected);
336 
337     /// The session has been terminated.
338     pub const CLOSED: Self = Self(kSSLClosed);
339 
340     /// The session has been aborted due to an error.
341     pub const ABORTED: Self = Self(kSSLAborted);
342 }
343 
344 /// Specifies a server's requirement for client certificates.
345 #[derive(Debug, Copy, Clone, PartialEq, Eq)]
346 pub struct SslAuthenticate(SSLAuthenticate);
347 
348 impl SslAuthenticate {
349     /// Do not request a client certificate.
350     pub const NEVER: Self = Self(kNeverAuthenticate);
351 
352     /// Require a client certificate.
353     pub const ALWAYS: Self = Self(kAlwaysAuthenticate);
354 
355     /// Request but do not require a client certificate.
356     pub const TRY: Self = Self(kTryAuthenticate);
357 }
358 
359 /// Specifies the state of client certificate processing.
360 #[derive(Debug, Copy, Clone, PartialEq, Eq)]
361 pub struct SslClientCertificateState(SSLClientCertificateState);
362 
363 impl SslClientCertificateState {
364     /// A client certificate has not been requested or sent.
365     pub const NONE: Self = Self(kSSLClientCertNone);
366 
367     /// A client certificate has been requested but not recieved.
368     pub const REQUESTED: Self = Self(kSSLClientCertRequested);
369     /// A client certificate has been received and successfully validated.
370     pub const SENT: Self = Self(kSSLClientCertSent);
371 
372     /// A client certificate has been received but has failed to validate.
373     pub const REJECTED: Self = Self(kSSLClientCertRejected); }
374 
375 /// Specifies protocol versions.
376 #[derive(Debug, Copy, Clone, PartialEq, Eq)]
377 pub struct SslProtocol(SSLProtocol);
378 
379 impl SslProtocol {
380     /// No protocol has been or should be negotiated or specified; use the default.
381     pub const UNKNOWN: Self = Self(kSSLProtocolUnknown);
382 
383     /// The SSL 3.0 protocol is preferred, though SSL 2.0 may be used if the peer does not support
384     /// SSL 3.0.
385     pub const SSL3: Self = Self(kSSLProtocol3);
386 
387     /// The TLS 1.0 protocol is preferred, though lower versions may be used
388     /// if the peer does not support TLS 1.0.
389     pub const TLS1: Self = Self(kTLSProtocol1);
390 
391     /// The TLS 1.1 protocol is preferred, though lower versions may be used
392     /// if the peer does not support TLS 1.1.
393     pub const TLS11: Self = Self(kTLSProtocol11);
394 
395     /// The TLS 1.2 protocol is preferred, though lower versions may be used
396     /// if the peer does not support TLS 1.2.
397     pub const TLS12: Self = Self(kTLSProtocol12);
398 
399     /// The TLS 1.3 protocol is preferred, though lower versions may be used
400     /// if the peer does not support TLS 1.3.
401     pub const TLS13: Self = Self(kTLSProtocol13);
402 
403     /// Only the SSL 2.0 protocol is accepted.
404     pub const SSL2: Self = Self(kSSLProtocol2);
405 
406     /// The DTLSv1 protocol is preferred.
407     pub const DTLS1: Self = Self(kDTLSProtocol1);
408 
409     /// Only the SSL 3.0 protocol is accepted.
410     pub const SSL3_ONLY: Self = Self(kSSLProtocol3Only);
411 
412     /// Only the TLS 1.0 protocol is accepted.
413     pub const TLS1_ONLY: Self = Self(kTLSProtocol1Only);
414 
415     /// All supported TLS/SSL versions are accepted.
416     pub const ALL: Self = Self(kSSLProtocolAll);
417 }
418 
// Declares the `SslContext` wrapper type around the raw `SSLContextRef`.
declare_TCFType! {
    /// A Secure Transport SSL/TLS context object.
    SslContext, SSLContextRef
}

// Implements `TCFType` for `SslContext`, using `SSLContextGetTypeID` as its type-ID function.
impl_TCFType!(SslContext, SSLContextRef, SSLContextGetTypeID);
425 
426 impl fmt::Debug for SslContext {
fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result427     fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
428         let mut builder = fmt.debug_struct("SslContext");
429         if let Ok(state) = self.state() {
430             builder.field("state", &state);
431         }
432         builder.finish()
433     }
434 }
435 
// Marks `SslContext` as shareable between threads and transferable across them.
// NOTE(review): nothing visible here justifies these impls; soundness presumably
// relies on Secure Transport's own thread-safety guarantees — confirm.
unsafe impl Sync for SslContext {}
unsafe impl Send for SslContext {}
438 
439 impl AsInner for SslContext {
440     type Inner = SSLContextRef;
441 
as_inner(&self) -> SSLContextRef442     fn as_inner(&self) -> SSLContextRef {
443         self.0
444     }
445 }
446 
// Generates a setter/getter pair for each boolean session option.
//
// Every entry has the form `const <option>: <getter> & <setter>,` and may be
// preceded by attributes (doc comments, `cfg`s), which are copied onto both
// generated methods. The setter forwards to `SSLSetSessionOption`, the getter
// to `SSLGetSessionOption`.
macro_rules! impl_options {
    ($($(#[$attr:meta])* const $option:ident: $getter:ident & $setter:ident,)*) => {
        $(
            $(#[$attr])*
            pub fn $setter(&mut self, value: bool) -> Result<()> {
                let status = unsafe { SSLSetSessionOption(self.0, $option, value as Boolean) };
                cvt(status)
            }

            $(#[$attr])*
            pub fn $getter(&self) -> Result<bool> {
                let mut raw = 0;
                let status = unsafe { SSLGetSessionOption(self.0, $option, &mut raw) };
                cvt(status)?;
                // Any non-zero value counts as enabled.
                Ok(raw != 0)
            }
        )*
    }
}
464 
465 impl SslContext {
466     /// Creates a new `SslContext` for the specified side and type of SSL
467     /// connection.
new(side: SslProtocolSide, type_: SslConnectionType) -> Result<Self>468     pub fn new(side: SslProtocolSide, type_: SslConnectionType) -> Result<Self> {
469         unsafe {
470             let ctx = SSLCreateContext(kCFAllocatorDefault, side.0, type_.0);
471             Ok(Self(ctx))
472         }
473     }
474 
475     /// Sets the fully qualified domain name of the peer.
476     ///
477     /// This will be used on the client side of a session to validate the
478     /// common name field of the server's certificate. It has no effect if
479     /// called on a server-side `SslContext`.
480     ///
481     /// It is *highly* recommended to call this method before starting the
482     /// handshake process.
set_peer_domain_name(&mut self, peer_name: &str) -> Result<()>483     pub fn set_peer_domain_name(&mut self, peer_name: &str) -> Result<()> {
484         unsafe {
485             // SSLSetPeerDomainName doesn't need a null terminated string
486             cvt(SSLSetPeerDomainName(
487                 self.0,
488                 peer_name.as_ptr() as *const _,
489                 peer_name.len(),
490             ))
491         }
492     }
493 
494     /// Returns the peer domain name set by `set_peer_domain_name`.
peer_domain_name(&self) -> Result<String>495     pub fn peer_domain_name(&self) -> Result<String> {
496         unsafe {
497             let mut len = 0;
498             cvt(SSLGetPeerDomainNameLength(self.0, &mut len))?;
499             let mut buf = vec![0; len];
500             cvt(SSLGetPeerDomainName(
501                 self.0,
502                 buf.as_mut_ptr() as *mut _,
503                 &mut len,
504             ))?;
505             Ok(String::from_utf8(buf).unwrap())
506         }
507     }
508 
509     /// Sets the certificate to be used by this side of the SSL session.
510     ///
511     /// This must be called before the handshake for server-side connections,
512     /// and can be used on the client-side to specify a client certificate.
513     ///
514     /// The `identity` corresponds to the leaf certificate and private
515     /// key, and the `certs` correspond to extra certificates in the chain.
set_certificate( &mut self, identity: &SecIdentity, certs: &[SecCertificate], ) -> Result<()>516     pub fn set_certificate(
517         &mut self,
518         identity: &SecIdentity,
519         certs: &[SecCertificate],
520     ) -> Result<()> {
521         let mut arr = vec![identity.as_CFType()];
522         arr.extend(certs.iter().map(|c| c.as_CFType()));
523         let certs = CFArray::from_CFTypes(&arr);
524 
525         unsafe { cvt(SSLSetCertificate(self.0, certs.as_concrete_TypeRef())) }
526     }
527 
528     /// Sets the peer ID of this session.
529     ///
530     /// A peer ID is an opaque sequence of bytes that will be used by Secure
531     /// Transport to identify the peer of an SSL session. If the peer ID of
532     /// this session matches that of a previously terminated session, the
533     /// previous session can be resumed without requiring a full handshake.
set_peer_id(&mut self, peer_id: &[u8]) -> Result<()>534     pub fn set_peer_id(&mut self, peer_id: &[u8]) -> Result<()> {
535         unsafe {
536             cvt(SSLSetPeerID(
537                 self.0,
538                 peer_id.as_ptr() as *const _,
539                 peer_id.len(),
540             ))
541         }
542     }
543 
544     /// Returns the peer ID of this session.
peer_id(&self) -> Result<Option<&[u8]>>545     pub fn peer_id(&self) -> Result<Option<&[u8]>> {
546         unsafe {
547             let mut ptr = ptr::null();
548             let mut len = 0;
549             cvt(SSLGetPeerID(self.0, &mut ptr, &mut len))?;
550             if ptr.is_null() {
551                 Ok(None)
552             } else {
553                 Ok(Some(slice::from_raw_parts(ptr as *const _, len)))
554             }
555         }
556     }
557 
558     /// Returns the list of ciphers that are supported by Secure Transport.
supported_ciphers(&self) -> Result<Vec<CipherSuite>>559     pub fn supported_ciphers(&self) -> Result<Vec<CipherSuite>> {
560         unsafe {
561             let mut num_ciphers = 0;
562             cvt(SSLGetNumberSupportedCiphers(self.0, &mut num_ciphers))?;
563             let mut ciphers = vec![0; num_ciphers];
564             cvt(SSLGetSupportedCiphers(
565                 self.0,
566                 ciphers.as_mut_ptr(),
567                 &mut num_ciphers,
568             ))?;
569             Ok(ciphers.iter().map(|c| CipherSuite::from_raw(*c)).collect())
570         }
571     }
572 
573     /// Returns the list of ciphers that are eligible to be used for
574     /// negotiation.
enabled_ciphers(&self) -> Result<Vec<CipherSuite>>575     pub fn enabled_ciphers(&self) -> Result<Vec<CipherSuite>> {
576         unsafe {
577             let mut num_ciphers = 0;
578             cvt(SSLGetNumberEnabledCiphers(self.0, &mut num_ciphers))?;
579             let mut ciphers = vec![0; num_ciphers];
580             cvt(SSLGetEnabledCiphers(
581                 self.0,
582                 ciphers.as_mut_ptr(),
583                 &mut num_ciphers,
584             ))?;
585             Ok(ciphers.iter().map(|c| CipherSuite::from_raw(*c)).collect())
586         }
587     }
588 
589     /// Sets the list of ciphers that are eligible to be used for negotiation.
set_enabled_ciphers(&mut self, ciphers: &[CipherSuite]) -> Result<()>590     pub fn set_enabled_ciphers(&mut self, ciphers: &[CipherSuite]) -> Result<()> {
591         let ciphers = ciphers.iter().map(|c| c.to_raw()).collect::<Vec<_>>();
592         unsafe {
593             cvt(SSLSetEnabledCiphers(
594                 self.0,
595                 ciphers.as_ptr(),
596                 ciphers.len(),
597             ))
598         }
599     }
600 
601     /// Returns the cipher being used by the session.
negotiated_cipher(&self) -> Result<CipherSuite>602     pub fn negotiated_cipher(&self) -> Result<CipherSuite> {
603         unsafe {
604             let mut cipher = 0;
605             cvt(SSLGetNegotiatedCipher(self.0, &mut cipher))?;
606             Ok(CipherSuite::from_raw(cipher))
607         }
608     }
609 
610     /// Sets the requirements for client certificates.
611     ///
612     /// Should only be called on server-side sessions.
set_client_side_authenticate(&mut self, auth: SslAuthenticate) -> Result<()>613     pub fn set_client_side_authenticate(&mut self, auth: SslAuthenticate) -> Result<()> {
614         unsafe { cvt(SSLSetClientSideAuthenticate(self.0, auth.0)) }
615     }
616 
617     /// Returns the state of client certificate processing.
client_certificate_state(&self) -> Result<SslClientCertificateState>618     pub fn client_certificate_state(&self) -> Result<SslClientCertificateState> {
619         let mut state = 0;
620 
621         unsafe {
622             cvt(SSLGetClientCertificateState(self.0, &mut state))?;
623         }
624         Ok(SslClientCertificateState(state))
625     }
626 
627     /// Returns the `SecTrust` object corresponding to the peer.
628     ///
629     /// This can be used in conjunction with `set_break_on_server_auth` to
630     /// validate certificates which do not have roots in the default set.
peer_trust2(&self) -> Result<Option<SecTrust>>631     pub fn peer_trust2(&self) -> Result<Option<SecTrust>> {
632         // Calling SSLCopyPeerTrust on an idle connection does not seem to be well defined,
633         // so explicitly check for that
634         if self.state()? == SessionState::IDLE {
635             return Err(Error::from_code(errSecBadReq));
636         }
637 
638         unsafe {
639             let mut trust = ptr::null_mut();
640             cvt(SSLCopyPeerTrust(self.0, &mut trust))?;
641             if trust.is_null() {
642                 Ok(None)
643             } else {
644                 Ok(Some(SecTrust::wrap_under_create_rule(trust)))
645             }
646         }
647     }
648 
649     /// Returns the state of the session.
state(&self) -> Result<SessionState>650     pub fn state(&self) -> Result<SessionState> {
651         unsafe {
652             let mut state = 0;
653             cvt(SSLGetSessionState(self.0, &mut state))?;
654             Ok(SessionState(state))
655         }
656     }
657 
658     /// Returns the protocol version being used by the session.
negotiated_protocol_version(&self) -> Result<SslProtocol>659     pub fn negotiated_protocol_version(&self) -> Result<SslProtocol> {
660         unsafe {
661             let mut version = 0;
662             cvt(SSLGetNegotiatedProtocolVersion(self.0, &mut version))?;
663             Ok(SslProtocol(version))
664         }
665     }
666 
667     /// Returns the maximum protocol version allowed by the session.
protocol_version_max(&self) -> Result<SslProtocol>668     pub fn protocol_version_max(&self) -> Result<SslProtocol> {
669         unsafe {
670             let mut version = 0;
671             cvt(SSLGetProtocolVersionMax(self.0, &mut version))?;
672             Ok(SslProtocol(version))
673         }
674     }
675 
676     /// Sets the maximum protocol version allowed by the session.
set_protocol_version_max(&mut self, max_version: SslProtocol) -> Result<()>677     pub fn set_protocol_version_max(&mut self, max_version: SslProtocol) -> Result<()> {
678         unsafe { cvt(SSLSetProtocolVersionMax(self.0, max_version.0)) }
679     }
680 
681     /// Returns the minimum protocol version allowed by the session.
protocol_version_min(&self) -> Result<SslProtocol>682     pub fn protocol_version_min(&self) -> Result<SslProtocol> {
683         unsafe {
684             let mut version = 0;
685             cvt(SSLGetProtocolVersionMin(self.0, &mut version))?;
686             Ok(SslProtocol(version))
687         }
688     }
689 
690     /// Sets the minimum protocol version allowed by the session.
set_protocol_version_min(&mut self, min_version: SslProtocol) -> Result<()>691     pub fn set_protocol_version_min(&mut self, min_version: SslProtocol) -> Result<()> {
692         unsafe { cvt(SSLSetProtocolVersionMin(self.0, min_version.0)) }
693     }
694 
695     /// Returns the set of protocols selected via ALPN if it succeeded.
696     #[cfg(feature = "alpn")]
alpn_protocols(&self) -> Result<Vec<String>>697     pub fn alpn_protocols(&self) -> Result<Vec<String>> {
698         let mut array: CFArrayRef = ptr::null();
699         unsafe {
700             #[cfg(feature = "OSX_10_13")]
701             {
702                 cvt(SSLCopyALPNProtocols(self.0, &mut array))?;
703             }
704 
705             #[cfg(not(feature = "OSX_10_13"))]
706             {
707                 dlsym! { fn SSLCopyALPNProtocols(SSLContextRef, *mut CFArrayRef) -> OSStatus }
708                 if let Some(f) = SSLCopyALPNProtocols.get() {
709                     cvt(f(self.0, &mut array))?;
710                 } else {
711                     return Err(Error::from_code(errSecUnimplemented));
712                 }
713             }
714 
715             if array.is_null() {
716                 return Ok(vec![]);
717             }
718 
719             let array = CFArray::<CFString>::wrap_under_create_rule(array);
720             Ok(array.into_iter().map(|p| p.to_string()).collect())
721         }
722     }
723 
724     /// Configures the set of protocols use for ALPN.
725     ///
726     /// This is only used for client-side connections.
727     #[cfg(feature = "alpn")]
set_alpn_protocols(&mut self, protocols: &[&str]) -> Result<()>728     pub fn set_alpn_protocols(&mut self, protocols: &[&str]) -> Result<()> {
729         // When CFMutableArray is added to core-foundation and IntoIterator trait
730         // is implemented for CFMutableArray, the code below should directly collect
731         // into a CFMutableArray.
732         let protocols = CFArray::from_CFTypes(
733             &protocols
734                 .iter()
735                 .map(|proto| CFString::new(proto))
736                 .collect::<Vec<_>>(),
737         );
738 
739         #[cfg(feature = "OSX_10_13")]
740         {
741             unsafe { cvt(SSLSetALPNProtocols(self.0, protocols.as_concrete_TypeRef())) }
742         }
743         #[cfg(not(feature = "OSX_10_13"))]
744         {
745             dlsym! { fn SSLSetALPNProtocols(SSLContextRef, CFArrayRef) -> OSStatus }
746             if let Some(f) = SSLSetALPNProtocols.get() {
747                 unsafe { cvt(f(self.0, protocols.as_concrete_TypeRef())) }
748             } else {
749                 Err(Error::from_code(errSecUnimplemented))
750             }
751         }
752     }
753 
754     /// Sets whether the client sends the `SessionTicket` extension in its `ClientHello`.
755     ///
756     /// On its own, this will just cause the client to send an empty `SessionTicket` extension on
757     /// every connection. [`SslContext::set_peer_id`] must also be used to key the session
758     /// ticket returned by the server.
759     ///
760     /// [`SslContext::set_peer_id`]: #method.set_peer_id
761     #[cfg(feature = "session-tickets")]
set_session_tickets_enabled(&mut self, enabled: bool) -> Result<()>762     pub fn set_session_tickets_enabled(&mut self, enabled: bool) -> Result<()> {
763         #[cfg(feature = "OSX_10_13")]
764         {
765             unsafe { cvt(SSLSetSessionTicketsEnabled(self.0, enabled as Boolean)) }
766         }
767         #[cfg(not(feature = "OSX_10_13"))]
768         {
769             dlsym! { fn SSLSetSessionTicketsEnabled(SSLContextRef, Boolean) -> OSStatus }
770             if let Some(f) = SSLSetSessionTicketsEnabled.get() {
771                 unsafe { cvt(f(self.0, enabled as Boolean)) }
772             } else {
773                 Err(Error::from_code(errSecUnimplemented))
774             }
775         }
776     }
777 
778     /// Sets whether a protocol is enabled or not.
779     ///
780     /// # Note
781     ///
782     /// On OSX this is a deprecated API in favor of `set_protocol_version_max` and
783     /// `set_protocol_version_min`, although if you're working with OSX 10.8 or before you may have
784     /// to use this API instead.
785     #[cfg(target_os = "macos")]
786     #[deprecated(note = "use `set_protocol_version_max`")]
set_protocol_version_enabled( &mut self, protocol: SslProtocol, enabled: bool, ) -> Result<()>787     pub fn set_protocol_version_enabled(
788         &mut self,
789         protocol: SslProtocol,
790         enabled: bool,
791     ) -> Result<()> {
792         unsafe {
793             cvt(SSLSetProtocolVersionEnabled(
794                 self.0,
795                 protocol.0,
796                 enabled as Boolean,
797             ))
798         }
799     }
800 
801     /// Returns the number of bytes which can be read without triggering a
802     /// `read` call in the underlying stream.
buffered_read_size(&self) -> Result<usize>803     pub fn buffered_read_size(&self) -> Result<usize> {
804         unsafe {
805             let mut size = 0;
806             cvt(SSLGetBufferedReadSize(self.0, &mut size))?;
807             Ok(size)
808         }
809     }
810 
    // Each entry names a `kSSLSessionOption*` constant followed by `getter & setter`
    // method names; the methods themselves are presumably generated by `impl_options!`
    // (defined outside this section). `set_break_on_server_auth` is the one used by
    // `ClientBuilder` in this file.
    impl_options! {
        /// If enabled, the handshake process will pause and return instead of
        /// automatically validating a server's certificate.
        const kSSLSessionOptionBreakOnServerAuth: break_on_server_auth & set_break_on_server_auth,
        /// If enabled, the handshake process will pause and return after
        /// the server requests a certificate from the client.
        const kSSLSessionOptionBreakOnCertRequested: break_on_cert_requested & set_break_on_cert_requested,
        /// If enabled, the handshake process will pause and return instead of
        /// automatically validating a client's certificate.
        const kSSLSessionOptionBreakOnClientAuth: break_on_client_auth & set_break_on_client_auth,
        /// If enabled, TLS false start will be performed if an appropriate
        /// cipher suite is negotiated.
        ///
        /// Requires the `OSX_10_9` (or greater) feature.
        #[cfg(feature = "OSX_10_9")]
        const kSSLSessionOptionFalseStart: false_start & set_false_start,
        /// If enabled, 1/n-1 record splitting will be enabled for TLS 1.0
        /// connections using block ciphers to mitigate the BEAST attack.
        ///
        /// Requires the `OSX_10_9` (or greater) feature.
        #[cfg(feature = "OSX_10_9")]
        const kSSLSessionOptionSendOneByteRecord: send_one_byte_record & set_send_one_byte_record,
    }
834 
into_stream<S>(self, stream: S) -> Result<SslStream<S>> where S: Read + Write,835     fn into_stream<S>(self, stream: S) -> Result<SslStream<S>>
836     where
837         S: Read + Write,
838     {
839         unsafe {
840             let ret = SSLSetIOFuncs(self.0, read_func::<S>, write_func::<S>);
841             if ret != errSecSuccess {
842                 return Err(Error::from_code(ret));
843             }
844 
845             let stream = Connection {
846                 stream,
847                 err: None,
848                 panic: None,
849             };
850             let stream = Box::into_raw(Box::new(stream));
851             let ret = SSLSetConnection(self.0, stream as *mut _);
852             if ret != errSecSuccess {
853                 let _conn = Box::from_raw(stream);
854                 return Err(Error::from_code(ret));
855             }
856 
857             Ok(SslStream {
858                 ctx: self,
859                 _m: PhantomData,
860             })
861         }
862     }
863 
864     /// Performs the SSL/TLS handshake.
handshake<S>(self, stream: S) -> result::Result<SslStream<S>, HandshakeError<S>> where S: Read + Write,865     pub fn handshake<S>(self, stream: S) -> result::Result<SslStream<S>, HandshakeError<S>>
866     where
867         S: Read + Write,
868     {
869         self.into_stream(stream)
870             .map_err(HandshakeError::Failure)
871             .and_then(SslStream::handshake)
872     }
873 }
874 
/// State shared with the Secure Transport I/O callbacks (`read_func` / `write_func`).
///
/// A boxed `Connection` is registered on the context by `SslContext::into_stream` and
/// looked up again through `SSLGetConnection` by `SslStream`.
struct Connection<S> {
    /// The underlying stream the callbacks read from and write to.
    stream: S,
    /// I/O error returned by the stream inside a callback; taken by `SslStream::get_error`.
    err: Option<io::Error>,
    /// Payload of a panic caught inside a callback; re-raised by `SslStream::check_panic`.
    panic: Option<Box<dyn Any + Send>>,
}
880 
881 // the logic here is based off of libcurl's
882 
translate_err(e: &io::Error) -> OSStatus883 fn translate_err(e: &io::Error) -> OSStatus {
884     match e.kind() {
885         io::ErrorKind::NotFound => errSSLClosedGraceful,
886         io::ErrorKind::ConnectionReset => errSSLClosedAbort,
887         io::ErrorKind::WouldBlock |
888         io::ErrorKind::NotConnected => errSSLWouldBlock,
889         _ => errSecIO,
890     }
891 }
892 
read_func<S>( connection: SSLConnectionRef, data: *mut c_void, data_length: *mut usize, ) -> OSStatus where S: Read,893 unsafe extern "C" fn read_func<S>(
894     connection: SSLConnectionRef,
895     data: *mut c_void,
896     data_length: *mut usize,
897 ) -> OSStatus
898 where
899     S: Read,
900 {
901     let conn: &mut Connection<S> = &mut *(connection as *mut _);
902     let data = slice::from_raw_parts_mut(data as *mut u8, *data_length);
903     let mut start = 0;
904     let mut ret = errSecSuccess;
905 
906     while start < data.len() {
907         match panic::catch_unwind(AssertUnwindSafe(|| conn.stream.read(&mut data[start..]))) {
908             Ok(Ok(0)) => {
909                 ret = errSSLClosedNoNotify;
910                 break;
911             }
912             Ok(Ok(len)) => start += len,
913             Ok(Err(e)) => {
914                 ret = translate_err(&e);
915                 conn.err = Some(e);
916                 break;
917             }
918             Err(e) => {
919                 ret = errSecIO;
920                 conn.panic = Some(e);
921                 break;
922             }
923         }
924     }
925 
926     *data_length = start;
927     ret
928 }
929 
write_func<S>( connection: SSLConnectionRef, data: *const c_void, data_length: *mut usize, ) -> OSStatus where S: Write,930 unsafe extern "C" fn write_func<S>(
931     connection: SSLConnectionRef,
932     data: *const c_void,
933     data_length: *mut usize,
934 ) -> OSStatus
935 where
936     S: Write,
937 {
938     let conn: &mut Connection<S> = &mut *(connection as *mut _);
939     let data = slice::from_raw_parts(data as *mut u8, *data_length);
940     let mut start = 0;
941     let mut ret = errSecSuccess;
942 
943     while start < data.len() {
944         match panic::catch_unwind(AssertUnwindSafe(|| conn.stream.write(&data[start..]))) {
945             Ok(Ok(0)) => {
946                 ret = errSSLClosedNoNotify;
947                 break;
948             }
949             Ok(Ok(len)) => start += len,
950             Ok(Err(e)) => {
951                 ret = translate_err(&e);
952                 conn.err = Some(e);
953                 break;
954             }
955             Err(e) => {
956                 ret = errSecIO;
957                 conn.panic = Some(e);
958                 break;
959             }
960         }
961     }
962 
963     *data_length = start;
964     ret
965 }
966 
/// A type implementing SSL/TLS encryption over an underlying stream.
///
/// The stream is not stored inline: it lives in a heap-allocated `Connection<S>` that is
/// registered with the context and freed again when the `SslStream` is dropped.
pub struct SslStream<S> {
    // The Secure Transport context carrying the session and the connection pointer.
    ctx: SslContext,
    // Records the stream type, which is needed to recover the `Connection<S>` from the
    // untyped pointer returned by `SSLGetConnection`.
    _m: PhantomData<S>,
}
972 
973 impl<S: fmt::Debug> fmt::Debug for SslStream<S> {
fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result974     fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
975         fmt.debug_struct("SslStream")
976             .field("context", &self.ctx)
977             .field("stream", self.get_ref())
978             .finish()
979     }
980 }
981 
982 impl<S> Drop for SslStream<S> {
drop(&mut self)983     fn drop(&mut self) {
984         unsafe {
985             let mut conn = ptr::null();
986             let ret = SSLGetConnection(self.ctx.0, &mut conn);
987             assert!(ret == errSecSuccess);
988             Box::<Connection<S>>::from_raw(conn as *mut _);
989         }
990     }
991 }
992 
993 impl<S> SslStream<S> {
handshake(mut self) -> result::Result<Self, HandshakeError<S>>994     fn handshake(mut self) -> result::Result<Self, HandshakeError<S>> {
995         match unsafe { SSLHandshake(self.ctx.0) } {
996             errSecSuccess => Ok(self),
997             reason @ errSSLPeerAuthCompleted
998             | reason @ errSSLClientCertRequested
999             | reason @ errSSLWouldBlock
1000             | reason @ errSSLClientHelloReceived => {
1001                 Err(HandshakeError::Interrupted(MidHandshakeSslStream {
1002                     stream: self,
1003                     error: Error::from_code(reason),
1004                 }))
1005             }
1006             err => {
1007                 self.check_panic();
1008                 Err(HandshakeError::Failure(Error::from_code(err)))
1009             }
1010         }
1011     }
1012 
1013     /// Returns a shared reference to the inner stream.
get_ref(&self) -> &S1014     pub fn get_ref(&self) -> &S {
1015         &self.connection().stream
1016     }
1017 
1018     /// Returns a mutable reference to the underlying stream.
get_mut(&mut self) -> &mut S1019     pub fn get_mut(&mut self) -> &mut S {
1020         &mut self.connection_mut().stream
1021     }
1022 
1023     /// Returns a shared reference to the `SslContext` of the stream.
context(&self) -> &SslContext1024     pub fn context(&self) -> &SslContext {
1025         &self.ctx
1026     }
1027 
1028     /// Returns a mutable reference to the `SslContext` of the stream.
context_mut(&mut self) -> &mut SslContext1029     pub fn context_mut(&mut self) -> &mut SslContext {
1030         &mut self.ctx
1031     }
1032 
1033     /// Shuts down the connection.
close(&mut self) -> result::Result<(), io::Error>1034     pub fn close(&mut self) -> result::Result<(), io::Error> {
1035         unsafe {
1036             let ret = SSLClose(self.ctx.0);
1037             if ret == errSecSuccess {
1038                 Ok(())
1039             } else {
1040                 Err(self.get_error(ret))
1041             }
1042         }
1043     }
1044 
connection(&self) -> &Connection<S>1045     fn connection(&self) -> &Connection<S> {
1046         unsafe {
1047             let mut conn = ptr::null();
1048             let ret = SSLGetConnection(self.ctx.0, &mut conn);
1049             assert!(ret == errSecSuccess);
1050 
1051             mem::transmute(conn)
1052         }
1053     }
1054 
connection_mut(&mut self) -> &mut Connection<S>1055     fn connection_mut(&mut self) -> &mut Connection<S> {
1056         unsafe {
1057             let mut conn = ptr::null();
1058             let ret = SSLGetConnection(self.ctx.0, &mut conn);
1059             assert!(ret == errSecSuccess);
1060 
1061             mem::transmute(conn)
1062         }
1063     }
1064 
check_panic(&mut self)1065     fn check_panic(&mut self) {
1066         let conn = self.connection_mut();
1067         if let Some(err) = conn.panic.take() {
1068             panic::resume_unwind(err);
1069         }
1070     }
1071 
get_error(&mut self, ret: OSStatus) -> io::Error1072     fn get_error(&mut self, ret: OSStatus) -> io::Error {
1073         self.check_panic();
1074 
1075         if let Some(err) = self.connection_mut().err.take() {
1076             err
1077         } else {
1078             io::Error::new(io::ErrorKind::Other, Error::from_code(ret))
1079         }
1080     }
1081 }
1082 
1083 impl<S: Read + Write> Read for SslStream<S> {
read(&mut self, buf: &mut [u8]) -> io::Result<usize>1084     fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
1085         // Below we base our return value off the amount of data read, so a
1086         // zero-length buffer might cause us to erroneously interpret this
1087         // request as an error. Instead short-circuit that logic and return
1088         // `Ok(0)` instead.
1089         if buf.is_empty() {
1090             return Ok(0);
1091         }
1092 
1093         // If some data was buffered but not enough to fill `buf`, SSLRead
1094         // will try to read a new packet. This is bad because there may be
1095         // no more data but the socket is remaining open (e.g HTTPS with
1096         // Connection: keep-alive).
1097         let buffered = self.context().buffered_read_size().unwrap_or(0);
1098         let to_read = if buffered > 0 {
1099             cmp::min(buffered, buf.len())
1100         } else {
1101             buf.len()
1102         };
1103 
1104         unsafe {
1105             let mut nread = 0;
1106             let ret = SSLRead(self.ctx.0, buf.as_mut_ptr() as *mut _, to_read, &mut nread);
1107             // SSLRead can return an error at the same time it returns the last
1108             // chunk of data (!)
1109             if nread > 0 {
1110                 return Ok(nread as usize);
1111             }
1112 
1113             match ret {
1114                 errSSLClosedGraceful | errSSLClosedAbort | errSSLClosedNoNotify => Ok(0),
1115                 _ => Err(self.get_error(ret)),
1116             }
1117         }
1118     }
1119 }
1120 
1121 impl<S: Read + Write> Write for SslStream<S> {
write(&mut self, buf: &[u8]) -> io::Result<usize>1122     fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
1123         // Like above in read, short circuit a 0-length write
1124         if buf.is_empty() {
1125             return Ok(0);
1126         }
1127         unsafe {
1128             let mut nwritten = 0;
1129             let ret = SSLWrite(
1130                 self.ctx.0,
1131                 buf.as_ptr() as *const _,
1132                 buf.len(),
1133                 &mut nwritten,
1134             );
1135             // just to be safe, base success off of nwritten rather than ret
1136             // for the same reason as in read
1137             if nwritten > 0 {
1138                 Ok(nwritten as usize)
1139             } else {
1140                 Err(self.get_error(ret))
1141             }
1142         }
1143     }
1144 
flush(&mut self) -> io::Result<()>1145     fn flush(&mut self) -> io::Result<()> {
1146         self.connection_mut().stream.flush()
1147     }
1148 }
1149 
/// A builder type to simplify the creation of client side `SslStream`s.
#[derive(Debug)]
pub struct ClientBuilder {
    /// Client identity presented to the server, set through `identity`.
    identity: Option<SecIdentity>,
    /// Anchor (root) certificates, set through `anchor_certificates`.
    certs: Vec<SecCertificate>,
    /// Certificate chain accompanying `identity`.
    chain: Vec<SecCertificate>,
    /// Lowest protocol version to allow, if configured.
    protocol_min: Option<SslProtocol>,
    /// Highest protocol version to allow, if configured.
    protocol_max: Option<SslProtocol>,
    /// Set through `trust_anchor_certificates_only`.
    trust_certs_only: bool,
    /// Whether the peer domain name is set on the context (defaults to `true`).
    use_sni: bool,
    /// Set through `danger_accept_invalid_certs`.
    danger_accept_invalid_certs: bool,
    /// When `true`, no domain is passed on for hostname verification.
    danger_accept_invalid_hostnames: bool,
    /// When non-empty, used as the starting cipher list instead of the context's own.
    whitelisted_ciphers: Vec<CipherSuite>,
    /// Ciphers removed from the list before it is applied to the context.
    blacklisted_ciphers: Vec<CipherSuite>,
    /// ALPN protocols; only applied when the `alpn` feature is enabled.
    alpn: Option<Vec<String>>,
    /// Only honoured when the `session-tickets` feature is enabled.
    enable_session_tickets: bool,
}
1167 
1168 impl Default for ClientBuilder {
default() -> Self1169     fn default() -> Self {
1170         Self::new()
1171     }
1172 }
1173 
1174 impl ClientBuilder {
1175     /// Creates a new builder with default options.
new() -> Self1176     pub fn new() -> Self {
1177         Self {
1178             identity: None,
1179             certs: Vec::new(),
1180             chain: Vec::new(),
1181             protocol_min: None,
1182             protocol_max: None,
1183             trust_certs_only: false,
1184             use_sni: true,
1185             danger_accept_invalid_certs: false,
1186             danger_accept_invalid_hostnames: false,
1187             whitelisted_ciphers: Vec::new(),
1188             blacklisted_ciphers: Vec::new(),
1189             alpn: None,
1190             enable_session_tickets: false,
1191         }
1192     }
1193 
1194     /// Specifies the set of root certificates to trust when
1195     /// verifying the server's certificate.
anchor_certificates(&mut self, certs: &[SecCertificate]) -> &mut Self1196     pub fn anchor_certificates(&mut self, certs: &[SecCertificate]) -> &mut Self {
1197         self.certs = certs.to_owned();
1198         self
1199     }
1200 
1201     /// Specifies whether to trust the built-in certificates in addition
1202     /// to specified anchor certificates.
trust_anchor_certificates_only(&mut self, only: bool) -> &mut Self1203     pub fn trust_anchor_certificates_only(&mut self, only: bool) -> &mut Self {
1204         self.trust_certs_only = only;
1205         self
1206     }
1207 
1208     /// Specifies whether to trust invalid certificates.
1209     ///
1210     /// # Warning
1211     ///
1212     /// You should think very carefully before using this method. If invalid
1213     /// certificates are trusted, *any* certificate for *any* site will be
1214     /// trusted for use. This includes expired certificates. This introduces
1215     /// significant vulnerabilities, and should only be used as a last resort.
danger_accept_invalid_certs(&mut self, noverify: bool) -> &mut Self1216     pub fn danger_accept_invalid_certs(&mut self, noverify: bool) -> &mut Self {
1217         self.danger_accept_invalid_certs = noverify;
1218         self
1219     }
1220 
1221     /// Specifies whether to use Server Name Indication (SNI).
use_sni(&mut self, use_sni: bool) -> &mut Self1222     pub fn use_sni(&mut self, use_sni: bool) -> &mut Self {
1223         self.use_sni = use_sni;
1224         self
1225     }
1226 
1227     /// Specifies whether to verify that the server's hostname matches its certificate.
1228     ///
1229     /// # Warning
1230     ///
1231     /// You should think very carefully before using this method. If hostnames are not verified,
1232     /// *any* valid certificate for *any* site will be trusted for use. This introduces significant
1233     /// vulnerabilities, and should only be used as a last resort.
danger_accept_invalid_hostnames( &mut self, danger_accept_invalid_hostnames: bool, ) -> &mut Self1234     pub fn danger_accept_invalid_hostnames(
1235         &mut self,
1236         danger_accept_invalid_hostnames: bool,
1237     ) -> &mut Self {
1238         self.danger_accept_invalid_hostnames = danger_accept_invalid_hostnames;
1239         self
1240     }
1241 
1242     /// Set a whitelist of enabled ciphers. Any ciphers not whitelisted will be disabled.
whitelist_ciphers(&mut self, whitelisted_ciphers: &[CipherSuite]) -> &mut Self1243     pub fn whitelist_ciphers(&mut self, whitelisted_ciphers: &[CipherSuite]) -> &mut Self {
1244         self.whitelisted_ciphers = whitelisted_ciphers.to_owned();
1245         self
1246     }
1247 
1248     /// Set a blacklist of disabled ciphers. Blacklisted ciphers will be disabled.
blacklist_ciphers(&mut self, blacklisted_ciphers: &[CipherSuite]) -> &mut Self1249     pub fn blacklist_ciphers(&mut self, blacklisted_ciphers: &[CipherSuite]) -> &mut Self {
1250         self.blacklisted_ciphers = blacklisted_ciphers.to_owned();
1251         self
1252     }
1253 
1254     /// Use the specified identity as a SSL/TLS client certificate.
identity(&mut self, identity: &SecIdentity, chain: &[SecCertificate]) -> &mut Self1255     pub fn identity(&mut self, identity: &SecIdentity, chain: &[SecCertificate]) -> &mut Self {
1256         self.identity = Some(identity.clone());
1257         self.chain = chain.to_owned();
1258         self
1259     }
1260 
1261     /// Configure the minimum protocol that this client will support.
protocol_min(&mut self, min: SslProtocol) -> &mut Self1262     pub fn protocol_min(&mut self, min: SslProtocol) -> &mut Self {
1263         self.protocol_min = Some(min);
1264         self
1265     }
1266 
1267     /// Configure the minimum protocol that this client will support.
protocol_max(&mut self, max: SslProtocol) -> &mut Self1268     pub fn protocol_max(&mut self, max: SslProtocol) -> &mut Self {
1269         self.protocol_max = Some(max);
1270         self
1271     }
1272 
1273     /// Configures the set of protocols used for ALPN.
1274     #[cfg(feature = "alpn")]
alpn_protocols(&mut self, protocols: &[&str]) -> &mut Self1275     pub fn alpn_protocols(&mut self, protocols: &[&str]) -> &mut Self {
1276         self.alpn = Some(protocols.iter().map(|s| s.to_string()).collect());
1277         self
1278     }
1279 
1280     /// Configures the use of the RFC 5077 `SessionTicket` extension.
1281     ///
1282     /// Defaults to `false`.
1283     #[cfg(feature = "session-tickets")]
enable_session_tickets(&mut self, enable: bool) -> &mut Self1284     pub fn enable_session_tickets(&mut self, enable: bool) -> &mut Self {
1285         self.enable_session_tickets = enable;
1286         self
1287     }
1288 
1289     /// Initiates a new SSL/TLS session over a stream connected to the specified domain.
1290     ///
1291     /// If both SNI and hostname verification are disabled, the value of `domain` will be ignored.
handshake<S>( &self, domain: &str, stream: S, ) -> result::Result<SslStream<S>, ClientHandshakeError<S>> where S: Read + Write,1292     pub fn handshake<S>(
1293         &self,
1294         domain: &str,
1295         stream: S,
1296     ) -> result::Result<SslStream<S>, ClientHandshakeError<S>>
1297     where
1298         S: Read + Write,
1299     {
1300         // the logic for trust validation is in MidHandshakeClientBuilder::connect, so run all
1301         // of the handshake logic through that.
1302         let stream = MidHandshakeSslStream {
1303             stream: self.ctx_into_stream(domain, stream)?,
1304             error: Error::from(errSecSuccess),
1305         };
1306 
1307         let certs = self.certs.clone();
1308         let stream = MidHandshakeClientBuilder {
1309             stream,
1310             domain: if self.danger_accept_invalid_hostnames {
1311                 None
1312             } else {
1313                 Some(domain.to_string())
1314             },
1315             certs,
1316             trust_certs_only: self.trust_certs_only,
1317             danger_accept_invalid_certs: self.danger_accept_invalid_certs,
1318         };
1319         stream.handshake()
1320     }
1321 
ctx_into_stream<S>(&self, domain: &str, stream: S) -> Result<SslStream<S>> where S: Read + Write,1322     fn ctx_into_stream<S>(&self, domain: &str, stream: S) -> Result<SslStream<S>>
1323     where
1324         S: Read + Write,
1325     {
1326         let mut ctx = SslContext::new(SslProtocolSide::CLIENT, SslConnectionType::STREAM)?;
1327 
1328         if self.use_sni {
1329             ctx.set_peer_domain_name(domain)?;
1330         }
1331         if let Some(ref identity) = self.identity {
1332             ctx.set_certificate(identity, &self.chain)?;
1333         }
1334         #[cfg(feature = "alpn")]
1335         {
1336             if let Some(ref alpn) = self.alpn {
1337                 ctx.set_alpn_protocols(&alpn.iter().map(|s| &**s).collect::<Vec<_>>())?;
1338             }
1339         }
1340         #[cfg(feature = "session-tickets")]
1341         {
1342             if self.enable_session_tickets {
1343                 // We must use the domain here to ensure that we go through certificate validation
1344                 // again rather than resuming the session if the domain changes.
1345                 ctx.set_peer_id(domain.as_bytes())?;
1346                 ctx.set_session_tickets_enabled(true)?;
1347             }
1348         }
1349         ctx.set_break_on_server_auth(true)?;
1350         self.configure_protocols(&mut ctx)?;
1351         self.configure_ciphers(&mut ctx)?;
1352 
1353         ctx.into_stream(stream)
1354     }
1355 
configure_protocols(&self, ctx: &mut SslContext) -> Result<()>1356     fn configure_protocols(&self, ctx: &mut SslContext) -> Result<()> {
1357         if let Some(min) = self.protocol_min {
1358             ctx.set_protocol_version_min(min)?;
1359         }
1360         if let Some(max) = self.protocol_max {
1361             ctx.set_protocol_version_max(max)?;
1362         }
1363         Ok(())
1364     }
1365 
configure_ciphers(&self, ctx: &mut SslContext) -> Result<()>1366     fn configure_ciphers(&self, ctx: &mut SslContext) -> Result<()> {
1367         let mut ciphers = if self.whitelisted_ciphers.is_empty() {
1368             ctx.enabled_ciphers()?
1369         } else {
1370             self.whitelisted_ciphers.clone()
1371         };
1372 
1373         if !self.blacklisted_ciphers.is_empty() {
1374             ciphers.retain(|cipher| !self.blacklisted_ciphers.contains(cipher));
1375         }
1376 
1377         ctx.set_enabled_ciphers(&ciphers)?;
1378         Ok(())
1379     }
1380 }
1381 
/// A builder type to simplify the creation of server-side `SslStream`s.
#[derive(Debug)]
pub struct ServerBuilder {
    /// Identity the server presents during the handshake.
    identity: SecIdentity,
    /// Certificate chain passed along with `identity` to `SslContext::set_certificate`.
    certs: Vec<SecCertificate>,
}
1388 
1389 impl ServerBuilder {
1390     /// Creates a new `ServerBuilder` which will use the specified identity
1391     /// and certificate chain for handshakes.
new(identity: &SecIdentity, certs: &[SecCertificate]) -> Self1392     pub fn new(identity: &SecIdentity, certs: &[SecCertificate]) -> Self {
1393         Self {
1394             identity: identity.clone(),
1395             certs: certs.to_owned(),
1396         }
1397     }
1398 
1399     /// Initiates a new SSL/TLS session over a stream.
handshake<S>(&self, stream: S) -> Result<SslStream<S>> where S: Read + Write,1400     pub fn handshake<S>(&self, stream: S) -> Result<SslStream<S>>
1401     where
1402         S: Read + Write,
1403     {
1404         let mut ctx = SslContext::new(SslProtocolSide::SERVER, SslConnectionType::STREAM)?;
1405         ctx.set_certificate(&self.identity, &self.certs)?;
1406         match ctx.handshake(stream) {
1407             Ok(stream) => Ok(stream),
1408             Err(HandshakeError::Interrupted(stream)) => Err(stream.error().clone()),
1409             Err(HandshakeError::Failure(err)) => Err(err),
1410         }
1411     }
1412 }
1413 
1414 #[cfg(test)]
1415 mod test {
1416     use std::io;
1417     use std::io::prelude::*;
1418     use std::net::TcpStream;
1419 
1420     use super::*;
1421 
1422     #[test]
connect()1423     fn connect() {
1424         let mut ctx = p!(SslContext::new(
1425             SslProtocolSide::CLIENT,
1426             SslConnectionType::STREAM
1427         ));
1428         p!(ctx.set_peer_domain_name("google.com"));
1429         let stream = p!(TcpStream::connect("google.com:443"));
1430         p!(ctx.handshake(stream));
1431     }
1432 
1433     #[test]
connect_bad_domain()1434     fn connect_bad_domain() {
1435         let mut ctx = p!(SslContext::new(
1436             SslProtocolSide::CLIENT,
1437             SslConnectionType::STREAM
1438         ));
1439         p!(ctx.set_peer_domain_name("foobar.com"));
1440         let stream = p!(TcpStream::connect("google.com:443"));
1441         match ctx.handshake(stream) {
1442             Ok(_) => panic!("expected failure"),
1443             Err(_) => {}
1444         }
1445     }
1446 
1447     #[test]
load_page()1448     fn load_page() {
1449         let mut ctx = p!(SslContext::new(
1450             SslProtocolSide::CLIENT,
1451             SslConnectionType::STREAM
1452         ));
1453         p!(ctx.set_peer_domain_name("google.com"));
1454         let stream = p!(TcpStream::connect("google.com:443"));
1455         let mut stream = p!(ctx.handshake(stream));
1456         p!(stream.write_all(b"GET / HTTP/1.0\r\n\r\n"));
1457         p!(stream.flush());
1458         let mut buf = vec![];
1459         p!(stream.read_to_end(&mut buf));
1460         println!("{}", String::from_utf8_lossy(&buf));
1461     }
1462 
1463     #[test]
client_no_session_ticket_resumption()1464     fn client_no_session_ticket_resumption() {
1465         for _ in 0..2 {
1466             let stream = p!(TcpStream::connect("google.com:443"));
1467 
1468             // Manually handshake here.
1469             let stream = MidHandshakeSslStream {
1470                 stream: ClientBuilder::new()
1471                     .ctx_into_stream("google.com", stream)
1472                     .unwrap(),
1473                 error: Error::from(errSecSuccess),
1474             };
1475 
1476             let mut result = stream.handshake();
1477 
1478             if let Err(HandshakeError::Interrupted(stream)) = result {
1479                 assert!(stream.server_auth_completed());
1480                 result = stream.handshake();
1481             } else {
1482                 panic!("Unexpectedly skipped server auth");
1483             }
1484 
1485             assert!(result.is_ok());
1486         }
1487     }
1488 
1489     #[test]
1490     #[cfg(feature = "session-tickets")]
client_session_ticket_resumption()1491     fn client_session_ticket_resumption() {
1492         // The first time through this loop, we should do a full handshake. The second time, we
1493         // should immediately finish the handshake without breaking on server auth.
1494         for i in 0..2 {
1495             let stream = p!(TcpStream::connect("google.com:443"));
1496             let mut builder = ClientBuilder::new();
1497             builder.enable_session_tickets(true);
1498 
1499             // Manually handshake here.
1500             let stream = MidHandshakeSslStream {
1501                 stream: builder.ctx_into_stream("google.com", stream).unwrap(),
1502                 error: Error::from(errSecSuccess),
1503             };
1504 
1505             let mut result = stream.handshake();
1506 
1507             if let Err(HandshakeError::Interrupted(stream)) = result {
1508                 assert!(stream.server_auth_completed());
1509                 assert_eq!(
1510                     i, 0,
1511                     "Session ticket resumption did not work, server auth was not skipped"
1512                 );
1513                 result = stream.handshake();
1514             } else {
1515                 assert_eq!(i, 1, "Unexpectedly skipped server auth");
1516             }
1517 
1518             assert!(result.is_ok());
1519         }
1520     }
1521 
1522     #[test]
1523     #[cfg(feature = "alpn")]
client_alpn_accept()1524     fn client_alpn_accept() {
1525         let mut ctx = p!(SslContext::new(
1526             SslProtocolSide::CLIENT,
1527             SslConnectionType::STREAM
1528         ));
1529         p!(ctx.set_peer_domain_name("google.com"));
1530         p!(ctx.set_alpn_protocols(&vec!["h2"]));
1531         let stream = p!(TcpStream::connect("google.com:443"));
1532         let stream = ctx.handshake(stream).unwrap();
1533         assert_eq!(vec!["h2"], stream.context().alpn_protocols().unwrap());
1534     }
1535 
1536     #[test]
1537     #[cfg(feature = "alpn")]
client_alpn_reject()1538     fn client_alpn_reject() {
1539         let mut ctx = p!(SslContext::new(
1540             SslProtocolSide::CLIENT,
1541             SslConnectionType::STREAM
1542         ));
1543         p!(ctx.set_peer_domain_name("google.com"));
1544         p!(ctx.set_alpn_protocols(&vec!["h2c"]));
1545         let stream = p!(TcpStream::connect("google.com:443"));
1546         let stream = ctx.handshake(stream).unwrap();
1547         assert!(stream.context().alpn_protocols().is_err());
1548     }
1549 
1550     #[test]
client_no_anchor_certs()1551     fn client_no_anchor_certs() {
1552         let stream = p!(TcpStream::connect("google.com:443"));
1553         assert!(ClientBuilder::new()
1554             .trust_anchor_certificates_only(true)
1555             .handshake("google.com", stream)
1556             .is_err());
1557     }
1558 
1559     #[test]
client_bad_domain()1560     fn client_bad_domain() {
1561         let stream = p!(TcpStream::connect("google.com:443"));
1562         assert!(ClientBuilder::new()
1563             .handshake("foobar.com", stream)
1564             .is_err());
1565     }
1566 
1567     #[test]
client_bad_domain_ignored()1568     fn client_bad_domain_ignored() {
1569         let stream = p!(TcpStream::connect("google.com:443"));
1570         ClientBuilder::new()
1571             .danger_accept_invalid_hostnames(true)
1572             .handshake("foobar.com", stream)
1573             .unwrap();
1574     }
1575 
1576     #[test]
connect_no_verify_ssl()1577     fn connect_no_verify_ssl() {
1578         let stream = p!(TcpStream::connect("expired.badssl.com:443"));
1579         let mut builder = ClientBuilder::new();
1580         builder.danger_accept_invalid_certs(true);
1581         builder.handshake("expired.badssl.com", stream).unwrap();
1582     }
1583 
1584     #[test]
load_page_client()1585     fn load_page_client() {
1586         let stream = p!(TcpStream::connect("google.com:443"));
1587         let mut stream = p!(ClientBuilder::new().handshake("google.com", stream));
1588         p!(stream.write_all(b"GET / HTTP/1.0\r\n\r\n"));
1589         p!(stream.flush());
1590         let mut buf = vec![];
1591         p!(stream.read_to_end(&mut buf));
1592         println!("{}", String::from_utf8_lossy(&buf));
1593     }
1594 
1595     #[test]
1596     #[cfg_attr(target_os = "ios", ignore)] // FIXME what's going on with ios?
cipher_configuration()1597     fn cipher_configuration() {
1598         let mut ctx = p!(SslContext::new(
1599             SslProtocolSide::SERVER,
1600             SslConnectionType::STREAM
1601         ));
1602         let ciphers = p!(ctx.enabled_ciphers());
1603         let ciphers = ciphers
1604             .iter()
1605             .enumerate()
1606             .filter_map(|(i, c)| if i % 2 == 0 { Some(*c) } else { None })
1607             .collect::<Vec<_>>();
1608         p!(ctx.set_enabled_ciphers(&ciphers));
1609         assert_eq!(ciphers, p!(ctx.enabled_ciphers()));
1610     }
1611 
1612     #[test]
test_builder_whitelist_ciphers()1613     fn test_builder_whitelist_ciphers() {
1614         let stream = p!(TcpStream::connect("google.com:443"));
1615 
1616         let ctx = p!(SslContext::new(
1617             SslProtocolSide::CLIENT,
1618             SslConnectionType::STREAM
1619         ));
1620         assert!(p!(ctx.enabled_ciphers()).len() > 1);
1621 
1622         let ciphers = p!(ctx.enabled_ciphers());
1623         let cipher = ciphers.first().unwrap();
1624         let stream = p!(ClientBuilder::new()
1625             .whitelist_ciphers(&[*cipher])
1626             .ctx_into_stream("google.com", stream));
1627 
1628         assert_eq!(1, p!(stream.context().enabled_ciphers()).len());
1629     }
1630 
1631     #[test]
1632     #[cfg_attr(target_os = "ios", ignore)] // FIXME same issue as cipher_configuration
test_builder_blacklist_ciphers()1633     fn test_builder_blacklist_ciphers() {
1634         let stream = p!(TcpStream::connect("google.com:443"));
1635 
1636         let ctx = p!(SslContext::new(
1637             SslProtocolSide::CLIENT,
1638             SslConnectionType::STREAM
1639         ));
1640         let num = p!(ctx.enabled_ciphers()).len();
1641         assert!(num > 1);
1642 
1643         let ciphers = p!(ctx.enabled_ciphers());
1644         let cipher = ciphers.first().unwrap();
1645         let stream = p!(ClientBuilder::new()
1646             .blacklist_ciphers(&[*cipher])
1647             .ctx_into_stream("google.com", stream));
1648 
1649         assert_eq!(num - 1, p!(stream.context().enabled_ciphers()).len());
1650     }
1651 
1652     #[test]
idle_context_peer_trust()1653     fn idle_context_peer_trust() {
1654         let ctx = p!(SslContext::new(
1655             SslProtocolSide::SERVER,
1656             SslConnectionType::STREAM
1657         ));
1658         assert!(ctx.peer_trust2().is_err());
1659     }
1660 
1661     #[test]
peer_id()1662     fn peer_id() {
1663         let mut ctx = p!(SslContext::new(
1664             SslProtocolSide::SERVER,
1665             SslConnectionType::STREAM
1666         ));
1667         assert!(p!(ctx.peer_id()).is_none());
1668         p!(ctx.set_peer_id(b"foobar"));
1669         assert_eq!(p!(ctx.peer_id()), Some(&b"foobar"[..]));
1670     }
1671 
1672     #[test]
peer_domain_name()1673     fn peer_domain_name() {
1674         let mut ctx = p!(SslContext::new(
1675             SslProtocolSide::CLIENT,
1676             SslConnectionType::STREAM
1677         ));
1678         assert_eq!("", p!(ctx.peer_domain_name()));
1679         p!(ctx.set_peer_domain_name("foobar.com"));
1680         assert_eq!("foobar.com", p!(ctx.peer_domain_name()));
1681     }
1682 
1683     #[test]
1684     #[should_panic(expected = "blammo")]
write_panic()1685     fn write_panic() {
1686         struct ExplodingStream(TcpStream);
1687 
1688         impl Read for ExplodingStream {
1689             fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
1690                 self.0.read(buf)
1691             }
1692         }
1693 
1694         impl Write for ExplodingStream {
1695             fn write(&mut self, _: &[u8]) -> io::Result<usize> {
1696                 panic!("blammo");
1697             }
1698 
1699             fn flush(&mut self) -> io::Result<()> {
1700                 self.0.flush()
1701             }
1702         }
1703 
1704         let mut ctx = p!(SslContext::new(
1705             SslProtocolSide::CLIENT,
1706             SslConnectionType::STREAM
1707         ));
1708         p!(ctx.set_peer_domain_name("google.com"));
1709         let stream = p!(TcpStream::connect("google.com:443"));
1710         let _ = ctx.handshake(ExplodingStream(stream));
1711     }
1712 
1713     #[test]
1714     #[should_panic(expected = "blammo")]
read_panic()1715     fn read_panic() {
1716         struct ExplodingStream(TcpStream);
1717 
1718         impl Read for ExplodingStream {
1719             fn read(&mut self, _: &mut [u8]) -> io::Result<usize> {
1720                 panic!("blammo");
1721             }
1722         }
1723 
1724         impl Write for ExplodingStream {
1725             fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
1726                 self.0.write(buf)
1727             }
1728 
1729             fn flush(&mut self) -> io::Result<()> {
1730                 self.0.flush()
1731             }
1732         }
1733 
1734         let mut ctx = p!(SslContext::new(
1735             SslProtocolSide::CLIENT,
1736             SslConnectionType::STREAM
1737         ));
1738         p!(ctx.set_peer_domain_name("google.com"));
1739         let stream = p!(TcpStream::connect("google.com:443"));
1740         let _ = ctx.handshake(ExplodingStream(stream));
1741     }
1742 
1743     #[test]
zero_length_buffers()1744     fn zero_length_buffers() {
1745         let mut ctx = p!(SslContext::new(
1746             SslProtocolSide::CLIENT,
1747             SslConnectionType::STREAM
1748         ));
1749         p!(ctx.set_peer_domain_name("google.com"));
1750         let stream = p!(TcpStream::connect("google.com:443"));
1751         let mut stream = ctx.handshake(stream).unwrap();
1752         assert_eq!(stream.write(b"").unwrap(), 0);
1753         assert_eq!(stream.read(&mut []).unwrap(), 0);
1754     }
1755 }
1756