1 //! SSL/TLS encryption support using Secure Transport.
2 //!
3 //! # Examples
4 //!
5 //! To connect as a client to a server with a certificate trusted by the system:
6 //!
7 //! ```rust
8 //! use std::io::prelude::*;
9 //! use std::net::TcpStream;
10 //! use security_framework::secure_transport::ClientBuilder;
11 //!
12 //! let stream = TcpStream::connect("google.com:443").unwrap();
13 //! let mut stream = ClientBuilder::new().handshake("google.com", stream).unwrap();
14 //!
15 //! stream.write_all(b"GET / HTTP/1.0\r\n\r\n").unwrap();
16 //! let mut page = vec![];
17 //! stream.read_to_end(&mut page).unwrap();
18 //! println!("{}", String::from_utf8_lossy(&page));
19 //! ```
20 //!
21 //! To connect to a server with a certificate that's *not* trusted by the
22 //! system, specify the root certificates for the server's chain to the
23 //! `ClientBuilder`:
24 //!
25 //! ```rust,no_run
26 //! use std::io::prelude::*;
27 //! use std::net::TcpStream;
28 //! use security_framework::secure_transport::ClientBuilder;
29 //!
30 //! # let root_cert = unsafe { std::mem::zeroed() };
31 //! let stream = TcpStream::connect("my_server.com:443").unwrap();
32 //! let mut stream = ClientBuilder::new()
33 //!                      .anchor_certificates(&[root_cert])
34 //!                      .handshake("my_server.com", stream)
35 //!                      .unwrap();
36 //!
37 //! stream.write_all(b"GET / HTTP/1.0\r\n\r\n").unwrap();
38 //! let mut page = vec![];
39 //! stream.read_to_end(&mut page).unwrap();
40 //! println!("{}", String::from_utf8_lossy(&page));
41 //! ```
42 //!
43 //! For more advanced configuration, the `SslContext` type can be used directly.
44 //!
45 //! To run a server:
46 //!
47 //! ```rust,no_run
48 //! use std::net::TcpListener;
49 //! use std::thread;
50 //! use security_framework::secure_transport::{SslContext, SslProtocolSide, SslConnectionType};
51 //!
52 //! // Create a TCP listener and start accepting on it.
//! let listener = TcpListener::bind("0.0.0.0:443").unwrap();
54 //!
55 //! for stream in listener.incoming() {
56 //!     let stream = stream.unwrap();
57 //!     thread::spawn(move || {
58 //!         // Create a new context configured to operate on the server side of
59 //!         // a traditional SSL/TLS session.
60 //!         let mut ctx = SslContext::new(SslProtocolSide::SERVER, SslConnectionType::STREAM)
61 //!                           .unwrap();
62 //!
63 //!         // Install the certificate chain that we will be using.
64 //!         # let identity = unsafe { std::mem::zeroed() };
65 //!         # let intermediate_cert = unsafe { std::mem::zeroed() };
66 //!         # let root_cert = unsafe { std::mem::zeroed() };
67 //!         ctx.set_certificate(identity, &[intermediate_cert, root_cert]).unwrap();
68 //!
69 //!         // Perform the SSL/TLS handshake and get our stream.
70 //!         let mut stream = ctx.handshake(stream).unwrap();
71 //!     });
72 //! }
73 //!
74 //! ```
75 #[allow(unused_imports)]
76 use core_foundation::array::{CFArray, CFArrayRef};
77 
78 use core_foundation::base::{Boolean, TCFType};
79 #[cfg(feature = "alpn")]
80 use core_foundation::string::CFString;
81 use core_foundation_sys::base::{kCFAllocatorDefault, OSStatus};
82 use std::os::raw::c_void;
83 
84 #[allow(unused_imports)]
85 use security_framework_sys::base::{
86     errSecBadReq, errSecIO, errSecNotTrusted, errSecSuccess, errSecTrustSettingDeny,
87     errSecUnimplemented,
88 };
89 
90 use security_framework_sys::secure_transport::*;
91 use std::any::Any;
92 use std::cmp;
93 use std::fmt;
94 use std::io;
95 use std::io::prelude::*;
96 use std::marker::PhantomData;
97 use std::mem;
98 use std::panic::{self, AssertUnwindSafe};
99 use std::ptr;
100 use std::result;
101 use std::slice;
102 
103 use crate::base::{Error, Result};
104 use crate::certificate::SecCertificate;
105 use crate::cipher_suite::CipherSuite;
106 use crate::identity::SecIdentity;
107 use crate::policy::SecPolicy;
108 use crate::trust::{SecTrust, TrustResult};
109 use crate::{cvt, AsInner};
110 
111 /// Specifies a side of a TLS session.
112 #[derive(Debug, Copy, Clone, PartialEq, Eq)]
113 pub struct SslProtocolSide(SSLProtocolSide);
114 
115 impl SslProtocolSide {
116     /// The server side of the session.
117     pub const SERVER: Self = Self(kSSLServerSide);
118 
119     /// The client side of the session.
120     pub const CLIENT: Self = Self(kSSLClientSide);
121 }
122 
123 /// Specifies the type of TLS session.
124 #[derive(Debug, Copy, Clone)]
125 pub struct SslConnectionType(SSLConnectionType);
126 
127 impl SslConnectionType {
128     /// A traditional TLS stream.
129     pub const STREAM: Self = Self(kSSLStreamType);
130 
131     /// A DTLS session.
132     pub const DATAGRAM: Self = Self(kSSLDatagramType);
133 }
134 
135 /// An error or intermediate state after a TLS handshake attempt.
136 #[derive(Debug)]
137 pub enum HandshakeError<S> {
138     /// The handshake failed.
139     Failure(Error),
140     /// The handshake was interrupted midway through.
141     Interrupted(MidHandshakeSslStream<S>),
142 }
143 
144 impl<S> From<Error> for HandshakeError<S> {
from(err: Error) -> Self145     fn from(err: Error) -> Self {
146         Self::Failure(err)
147     }
148 }
149 
150 /// An error or intermediate state after a TLS handshake attempt.
151 #[derive(Debug)]
152 pub enum ClientHandshakeError<S> {
153     /// The handshake failed.
154     Failure(Error),
155     /// The handshake was interrupted midway through.
156     Interrupted(MidHandshakeClientBuilder<S>),
157 }
158 
159 impl<S> From<Error> for ClientHandshakeError<S> {
from(err: Error) -> Self160     fn from(err: Error) -> Self {
161         Self::Failure(err)
162     }
163 }
164 
165 /// An SSL stream midway through the handshake process.
166 #[derive(Debug)]
167 pub struct MidHandshakeSslStream<S> {
168     stream: SslStream<S>,
169     error: Error,
170 }
171 
172 impl<S> MidHandshakeSslStream<S> {
173     /// Returns a shared reference to the inner stream.
get_ref(&self) -> &S174     pub fn get_ref(&self) -> &S {
175         self.stream.get_ref()
176     }
177 
178     /// Returns a mutable reference to the inner stream.
get_mut(&mut self) -> &mut S179     pub fn get_mut(&mut self) -> &mut S {
180         self.stream.get_mut()
181     }
182 
183     /// Returns a shared reference to the `SslContext` of the stream.
context(&self) -> &SslContext184     pub fn context(&self) -> &SslContext {
185         self.stream.context()
186     }
187 
188     /// Returns a mutable reference to the `SslContext` of the stream.
context_mut(&mut self) -> &mut SslContext189     pub fn context_mut(&mut self) -> &mut SslContext {
190         self.stream.context_mut()
191     }
192 
193     /// Returns `true` iff `break_on_server_auth` was set and the handshake has
194     /// progressed to that point.
server_auth_completed(&self) -> bool195     pub fn server_auth_completed(&self) -> bool {
196         self.error.code() == errSSLPeerAuthCompleted
197     }
198 
199     /// Returns `true` iff `break_on_cert_requested` was set and the handshake
200     /// has progressed to that point.
client_cert_requested(&self) -> bool201     pub fn client_cert_requested(&self) -> bool {
202         self.error.code() == errSSLClientCertRequested
203     }
204 
205     /// Returns `true` iff the underlying stream returned an error with the
206     /// `WouldBlock` kind.
would_block(&self) -> bool207     pub fn would_block(&self) -> bool {
208         self.error.code() == errSSLWouldBlock
209     }
210 
211     /// Deprecated
reason(&self) -> OSStatus212     pub fn reason(&self) -> OSStatus {
213         self.error.code()
214     }
215 
216     /// Returns the error which caused the handshake interruption.
error(&self) -> &Error217     pub fn error(&self) -> &Error {
218         &self.error
219     }
220 
221     /// Restarts the handshake process.
handshake(self) -> result::Result<SslStream<S>, HandshakeError<S>>222     pub fn handshake(self) -> result::Result<SslStream<S>, HandshakeError<S>> {
223         self.stream.handshake()
224     }
225 }
226 
227 /// An SSL stream midway through the handshake process.
228 #[derive(Debug)]
229 pub struct MidHandshakeClientBuilder<S> {
230     stream: MidHandshakeSslStream<S>,
231     domain: Option<String>,
232     certs: Vec<SecCertificate>,
233     trust_certs_only: bool,
234     danger_accept_invalid_certs: bool,
235 }
236 
237 impl<S> MidHandshakeClientBuilder<S> {
238     /// Returns a shared reference to the inner stream.
get_ref(&self) -> &S239     pub fn get_ref(&self) -> &S {
240         self.stream.get_ref()
241     }
242 
243     /// Returns a mutable reference to the inner stream.
get_mut(&mut self) -> &mut S244     pub fn get_mut(&mut self) -> &mut S {
245         self.stream.get_mut()
246     }
247 
248     /// Returns the error which caused the handshake interruption.
error(&self) -> &Error249     pub fn error(&self) -> &Error {
250         self.stream.error()
251     }
252 
253     /// Restarts the handshake process.
handshake(self) -> result::Result<SslStream<S>, ClientHandshakeError<S>>254     pub fn handshake(self) -> result::Result<SslStream<S>, ClientHandshakeError<S>> {
255         let MidHandshakeClientBuilder {
256             stream,
257             domain,
258             certs,
259             trust_certs_only,
260             danger_accept_invalid_certs,
261         } = self;
262 
263         let mut result = stream.handshake();
264         loop {
265             let stream = match result {
266                 Ok(stream) => return Ok(stream),
267                 Err(HandshakeError::Interrupted(stream)) => stream,
268                 Err(HandshakeError::Failure(err)) => {
269                     return Err(ClientHandshakeError::Failure(err))
270                 }
271             };
272 
273             if stream.would_block() {
274                 let ret = MidHandshakeClientBuilder {
275                     stream,
276                     domain,
277                     certs,
278                     trust_certs_only,
279                     danger_accept_invalid_certs,
280                 };
281                 return Err(ClientHandshakeError::Interrupted(ret));
282             }
283 
284             if stream.server_auth_completed() {
285                 if danger_accept_invalid_certs {
286                     result = stream.handshake();
287                     continue;
288                 }
289                 let mut trust = match stream.context().peer_trust2()? {
290                     Some(trust) => trust,
291                     None => {
292                         result = stream.handshake();
293                         continue;
294                     }
295                 };
296                 trust.set_anchor_certificates(&certs)?;
297                 trust.set_trust_anchor_certificates_only(self.trust_certs_only)?;
298                 let policy =
299                     SecPolicy::create_ssl(SslProtocolSide::SERVER, domain.as_ref().map(|s| &**s));
300                 trust.set_policy(&policy)?;
301                 let trusted = trust.evaluate()?;
302                 match trusted {
303                     TrustResult::PROCEED | TrustResult::UNSPECIFIED => {
304                         result = stream.handshake();
305                         continue;
306                     }
307                     TrustResult::DENY => {
308                         let err = Error::from_code(errSecTrustSettingDeny);
309                         return Err(ClientHandshakeError::Failure(err));
310                     }
311                     TrustResult::RECOVERABLE_TRUST_FAILURE | TrustResult::FATAL_TRUST_FAILURE => {
312                         let err = Error::from_code(errSecNotTrusted);
313                         return Err(ClientHandshakeError::Failure(err));
314                     }
315                     TrustResult::INVALID | TrustResult::OTHER_ERROR | _ => {
316                         let err = Error::from_code(errSecBadReq);
317                         return Err(ClientHandshakeError::Failure(err));
318                     }
319                 }
320             }
321 
322             let err = Error::from_code(stream.reason());
323             return Err(ClientHandshakeError::Failure(err));
324         }
325     }
326 }
327 
328 /// Specifies the state of a TLS session.
329 #[derive(Debug, PartialEq, Eq)]
330 pub struct SessionState(SSLSessionState);
331 
332 impl SessionState {
333     /// The session has not yet started.
334     pub const IDLE: Self = Self(kSSLIdle);
335 
336     /// The session is in the handshake process.
337     pub const HANDSHAKE: Self = Self(kSSLHandshake);
338 
339     /// The session is connected.
340     pub const CONNECTED: Self = Self(kSSLConnected);
341 
342     /// The session has been terminated.
343     pub const CLOSED: Self = Self(kSSLClosed);
344 
345     /// The session has been aborted due to an error.
346     pub const ABORTED: Self = Self(kSSLAborted);
347 }
348 
349 /// Specifies a server's requirement for client certificates.
350 #[derive(Debug, Copy, Clone, PartialEq, Eq)]
351 pub struct SslAuthenticate(SSLAuthenticate);
352 
353 impl SslAuthenticate {
354     /// Do not request a client certificate.
355     pub const NEVER: Self = Self(kNeverAuthenticate);
356 
357     /// Require a client certificate.
358     pub const ALWAYS: Self = Self(kAlwaysAuthenticate);
359 
360     /// Request but do not require a client certificate.
361     pub const TRY: Self = Self(kTryAuthenticate);
362 }
363 
364 /// Specifies the state of client certificate processing.
365 #[derive(Debug, Copy, Clone, PartialEq, Eq)]
366 pub struct SslClientCertificateState(SSLClientCertificateState);
367 
368 impl SslClientCertificateState {
369     /// A client certificate has not been requested or sent.
370     pub const NONE: Self = Self(kSSLClientCertNone);
371 
372     /// A client certificate has been requested but not recieved.
373     pub const REQUESTED: Self = Self(kSSLClientCertRequested);
374     /// A client certificate has been received and successfully validated.
375     pub const SENT: Self = Self(kSSLClientCertSent);
376 
377     /// A client certificate has been received but has failed to validate.
378     pub const REJECTED: Self = Self(kSSLClientCertRejected); }
379 
380 /// Specifies protocol versions.
381 #[derive(Debug, Copy, Clone, PartialEq, Eq)]
382 pub struct SslProtocol(SSLProtocol);
383 
384 impl SslProtocol {
385     /// No protocol has been or should be negotiated or specified; use the default.
386     pub const UNKNOWN: Self = Self(kSSLProtocolUnknown);
387 
388     /// The SSL 3.0 protocol is preferred, though SSL 2.0 may be used if the peer does not support
389     /// SSL 3.0.
390     pub const SSL3: Self = Self(kSSLProtocol3);
391 
392     /// The TLS 1.0 protocol is preferred, though lower versions may be used
393     /// if the peer does not support TLS 1.0.
394     pub const TLS1: Self = Self(kTLSProtocol1);
395 
396     /// The TLS 1.1 protocol is preferred, though lower versions may be used
397     /// if the peer does not support TLS 1.1.
398     pub const TLS11: Self = Self(kTLSProtocol11);
399 
400     /// The TLS 1.2 protocol is preferred, though lower versions may be used
401     /// if the peer does not support TLS 1.2.
402     pub const TLS12: Self = Self(kTLSProtocol12);
403 
404     /// The TLS 1.3 protocol is preferred, though lower versions may be used
405     /// if the peer does not support TLS 1.3.
406     pub const TLS13: Self = Self(kTLSProtocol13);
407 
408     /// Only the SSL 2.0 protocol is accepted.
409     pub const SSL2: Self = Self(kSSLProtocol2);
410 
411     /// The DTLSv1 protocol is preferred.
412     pub const DTLS1: Self = Self(kDTLSProtocol1);
413 
414     /// Only the SSL 3.0 protocol is accepted.
415     pub const SSL3_ONLY: Self = Self(kSSLProtocol3Only);
416 
417     /// Only the TLS 1.0 protocol is accepted.
418     pub const TLS1_ONLY: Self = Self(kTLSProtocol1Only);
419 
420     /// All supported TLS/SSL versions are accepted.
421     pub const ALL: Self = Self(kSSLProtocolAll);
422 }
423 
// `declare_TCFType!` declares the `SslContext` tuple struct wrapping a raw
// `SSLContextRef` (accessed as `self.0` throughout this file).
// `impl_TCFType!` then implements Core Foundation's `TCFType` for it, using
// `SSLContextGetTypeID` as the runtime type identifier.
declare_TCFType! {
    /// A Secure Transport SSL/TLS context object.
    SslContext, SSLContextRef
}

impl_TCFType!(SslContext, SSLContextRef, SSLContextGetTypeID);
430 
431 impl fmt::Debug for SslContext {
fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result432     fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
433         let mut builder = fmt.debug_struct("SslContext");
434         if let Ok(state) = self.state() {
435             builder.field("state", &state);
436         }
437         builder.finish()
438     }
439 }
440 
// SAFETY: NOTE(review): these impls allow an `SslContext` to be moved to, and
// shared between, threads. Nothing in this file establishes that Secure
// Transport contexts tolerate that. The Rust-side guarantee is only that the
// setters visible here take `&mut self`, so unsynchronized shared access is
// limited to `&self` methods. Confirm against Apple's Secure Transport
// thread-safety guidance.
unsafe impl Sync for SslContext {}
unsafe impl Send for SslContext {}
443 
444 impl AsInner for SslContext {
445     type Inner = SSLContextRef;
446 
as_inner(&self) -> SSLContextRef447     fn as_inner(&self) -> SSLContextRef {
448         self.0
449     }
450 }
451 
// Generates a setter/getter method pair for each listed Secure Transport
// session option. Each entry `const OPTION: getter & setter,` expands to:
//   * `setter(&mut self, value: bool) -> Result<()>`, which calls
//     `SSLSetSessionOption` with `value` cast to `Boolean`;
//   * `getter(&self) -> Result<bool>`, which calls `SSLGetSessionOption` and
//     reports `true` for any non-zero value.
// Attributes (including doc comments) written before an entry are copied onto
// both generated methods. The expansion refers to `self.0`, so it must be
// invoked inside an `impl` of a tuple struct wrapping an `SSLContextRef`.
macro_rules! impl_options {
    ($($(#[$a:meta])* const $opt:ident: $get:ident & $set:ident,)*) => {
        $(
            $(#[$a])*
            pub fn $set(&mut self, value: bool) -> Result<()> {
                unsafe { cvt(SSLSetSessionOption(self.0, $opt, value as Boolean)) }
            }

            $(#[$a])*
            pub fn $get(&self) -> Result<bool> {
                let mut value = 0;
                unsafe { cvt(SSLGetSessionOption(self.0, $opt, &mut value))?; }
                Ok(value != 0)
            }
        )*
    }
}
469 
470 impl SslContext {
471     /// Creates a new `SslContext` for the specified side and type of SSL
472     /// connection.
new(side: SslProtocolSide, type_: SslConnectionType) -> Result<Self>473     pub fn new(side: SslProtocolSide, type_: SslConnectionType) -> Result<Self> {
474         unsafe {
475             let ctx = SSLCreateContext(kCFAllocatorDefault, side.0, type_.0);
476             Ok(Self(ctx))
477         }
478     }
479 
480     /// Sets the fully qualified domain name of the peer.
481     ///
482     /// This will be used on the client side of a session to validate the
483     /// common name field of the server's certificate. It has no effect if
484     /// called on a server-side `SslContext`.
485     ///
486     /// It is *highly* recommended to call this method before starting the
487     /// handshake process.
set_peer_domain_name(&mut self, peer_name: &str) -> Result<()>488     pub fn set_peer_domain_name(&mut self, peer_name: &str) -> Result<()> {
489         unsafe {
490             // SSLSetPeerDomainName doesn't need a null terminated string
491             cvt(SSLSetPeerDomainName(
492                 self.0,
493                 peer_name.as_ptr() as *const _,
494                 peer_name.len(),
495             ))
496         }
497     }
498 
499     /// Returns the peer domain name set by `set_peer_domain_name`.
peer_domain_name(&self) -> Result<String>500     pub fn peer_domain_name(&self) -> Result<String> {
501         unsafe {
502             let mut len = 0;
503             cvt(SSLGetPeerDomainNameLength(self.0, &mut len))?;
504             let mut buf = vec![0; len];
505             cvt(SSLGetPeerDomainName(
506                 self.0,
507                 buf.as_mut_ptr() as *mut _,
508                 &mut len,
509             ))?;
510             Ok(String::from_utf8(buf).unwrap())
511         }
512     }
513 
514     /// Sets the certificate to be used by this side of the SSL session.
515     ///
516     /// This must be called before the handshake for server-side connections,
517     /// and can be used on the client-side to specify a client certificate.
518     ///
519     /// The `identity` corresponds to the leaf certificate and private
520     /// key, and the `certs` correspond to extra certificates in the chain.
set_certificate( &mut self, identity: &SecIdentity, certs: &[SecCertificate], ) -> Result<()>521     pub fn set_certificate(
522         &mut self,
523         identity: &SecIdentity,
524         certs: &[SecCertificate],
525     ) -> Result<()> {
526         let mut arr = vec![identity.as_CFType()];
527         arr.extend(certs.iter().map(|c| c.as_CFType()));
528         let certs = CFArray::from_CFTypes(&arr);
529 
530         unsafe { cvt(SSLSetCertificate(self.0, certs.as_concrete_TypeRef())) }
531     }
532 
533     /// Sets the peer ID of this session.
534     ///
535     /// A peer ID is an opaque sequence of bytes that will be used by Secure
536     /// Transport to identify the peer of an SSL session. If the peer ID of
537     /// this session matches that of a previously terminated session, the
538     /// previous session can be resumed without requiring a full handshake.
set_peer_id(&mut self, peer_id: &[u8]) -> Result<()>539     pub fn set_peer_id(&mut self, peer_id: &[u8]) -> Result<()> {
540         unsafe {
541             cvt(SSLSetPeerID(
542                 self.0,
543                 peer_id.as_ptr() as *const _,
544                 peer_id.len(),
545             ))
546         }
547     }
548 
549     /// Returns the peer ID of this session.
peer_id(&self) -> Result<Option<&[u8]>>550     pub fn peer_id(&self) -> Result<Option<&[u8]>> {
551         unsafe {
552             let mut ptr = ptr::null();
553             let mut len = 0;
554             cvt(SSLGetPeerID(self.0, &mut ptr, &mut len))?;
555             if ptr.is_null() {
556                 Ok(None)
557             } else {
558                 Ok(Some(slice::from_raw_parts(ptr as *const _, len)))
559             }
560         }
561     }
562 
563     /// Returns the list of ciphers that are supported by Secure Transport.
supported_ciphers(&self) -> Result<Vec<CipherSuite>>564     pub fn supported_ciphers(&self) -> Result<Vec<CipherSuite>> {
565         unsafe {
566             let mut num_ciphers = 0;
567             cvt(SSLGetNumberSupportedCiphers(self.0, &mut num_ciphers))?;
568             let mut ciphers = vec![0; num_ciphers];
569             cvt(SSLGetSupportedCiphers(
570                 self.0,
571                 ciphers.as_mut_ptr(),
572                 &mut num_ciphers,
573             ))?;
574             Ok(ciphers.iter().map(|c| CipherSuite::from_raw(*c)).collect())
575         }
576     }
577 
578     /// Returns the list of ciphers that are eligible to be used for
579     /// negotiation.
enabled_ciphers(&self) -> Result<Vec<CipherSuite>>580     pub fn enabled_ciphers(&self) -> Result<Vec<CipherSuite>> {
581         unsafe {
582             let mut num_ciphers = 0;
583             cvt(SSLGetNumberEnabledCiphers(self.0, &mut num_ciphers))?;
584             let mut ciphers = vec![0; num_ciphers];
585             cvt(SSLGetEnabledCiphers(
586                 self.0,
587                 ciphers.as_mut_ptr(),
588                 &mut num_ciphers,
589             ))?;
590             Ok(ciphers.iter().map(|c| CipherSuite::from_raw(*c)).collect())
591         }
592     }
593 
594     /// Sets the list of ciphers that are eligible to be used for negotiation.
set_enabled_ciphers(&mut self, ciphers: &[CipherSuite]) -> Result<()>595     pub fn set_enabled_ciphers(&mut self, ciphers: &[CipherSuite]) -> Result<()> {
596         let ciphers = ciphers.iter().map(|c| c.to_raw()).collect::<Vec<_>>();
597         unsafe {
598             cvt(SSLSetEnabledCiphers(
599                 self.0,
600                 ciphers.as_ptr(),
601                 ciphers.len(),
602             ))
603         }
604     }
605 
606     /// Returns the cipher being used by the session.
negotiated_cipher(&self) -> Result<CipherSuite>607     pub fn negotiated_cipher(&self) -> Result<CipherSuite> {
608         unsafe {
609             let mut cipher = 0;
610             cvt(SSLGetNegotiatedCipher(self.0, &mut cipher))?;
611             Ok(CipherSuite::from_raw(cipher))
612         }
613     }
614 
615     /// Sets the requirements for client certificates.
616     ///
617     /// Should only be called on server-side sessions.
set_client_side_authenticate(&mut self, auth: SslAuthenticate) -> Result<()>618     pub fn set_client_side_authenticate(&mut self, auth: SslAuthenticate) -> Result<()> {
619         unsafe { cvt(SSLSetClientSideAuthenticate(self.0, auth.0)) }
620     }
621 
622     /// Returns the state of client certificate processing.
client_certificate_state(&self) -> Result<SslClientCertificateState>623     pub fn client_certificate_state(&self) -> Result<SslClientCertificateState> {
624         let mut state = 0;
625 
626         unsafe {
627             cvt(SSLGetClientCertificateState(self.0, &mut state))?;
628         }
629         Ok(SslClientCertificateState(state))
630     }
631 
632     /// Returns the `SecTrust` object corresponding to the peer.
633     ///
634     /// This can be used in conjunction with `set_break_on_server_auth` to
635     /// validate certificates which do not have roots in the default set.
peer_trust2(&self) -> Result<Option<SecTrust>>636     pub fn peer_trust2(&self) -> Result<Option<SecTrust>> {
637         // Calling SSLCopyPeerTrust on an idle connection does not seem to be well defined,
638         // so explicitly check for that
639         if self.state()? == SessionState::IDLE {
640             return Err(Error::from_code(errSecBadReq));
641         }
642 
643         unsafe {
644             let mut trust = ptr::null_mut();
645             cvt(SSLCopyPeerTrust(self.0, &mut trust))?;
646             if trust.is_null() {
647                 Ok(None)
648             } else {
649                 Ok(Some(SecTrust::wrap_under_create_rule(trust)))
650             }
651         }
652     }
653 
654     #[allow(missing_docs)]
655     #[deprecated(since = "0.2.1", note = "use peer_trust2 instead")]
peer_trust(&self) -> Result<SecTrust>656     pub fn peer_trust(&self) -> Result<SecTrust> {
657         match self.peer_trust2() {
658             Ok(Some(trust)) => Ok(trust),
659             Ok(None) => panic!("no trust available"),
660             Err(e) => Err(e),
661         }
662     }
663 
664     /// Returns the state of the session.
state(&self) -> Result<SessionState>665     pub fn state(&self) -> Result<SessionState> {
666         unsafe {
667             let mut state = 0;
668             cvt(SSLGetSessionState(self.0, &mut state))?;
669             Ok(SessionState(state))
670         }
671     }
672 
673     /// Returns the protocol version being used by the session.
negotiated_protocol_version(&self) -> Result<SslProtocol>674     pub fn negotiated_protocol_version(&self) -> Result<SslProtocol> {
675         unsafe {
676             let mut version = 0;
677             cvt(SSLGetNegotiatedProtocolVersion(self.0, &mut version))?;
678             Ok(SslProtocol(version))
679         }
680     }
681 
682     /// Returns the maximum protocol version allowed by the session.
protocol_version_max(&self) -> Result<SslProtocol>683     pub fn protocol_version_max(&self) -> Result<SslProtocol> {
684         unsafe {
685             let mut version = 0;
686             cvt(SSLGetProtocolVersionMax(self.0, &mut version))?;
687             Ok(SslProtocol(version))
688         }
689     }
690 
691     /// Sets the maximum protocol version allowed by the session.
set_protocol_version_max(&mut self, max_version: SslProtocol) -> Result<()>692     pub fn set_protocol_version_max(&mut self, max_version: SslProtocol) -> Result<()> {
693         unsafe { cvt(SSLSetProtocolVersionMax(self.0, max_version.0)) }
694     }
695 
696     /// Returns the minimum protocol version allowed by the session.
protocol_version_min(&self) -> Result<SslProtocol>697     pub fn protocol_version_min(&self) -> Result<SslProtocol> {
698         unsafe {
699             let mut version = 0;
700             cvt(SSLGetProtocolVersionMin(self.0, &mut version))?;
701             Ok(SslProtocol(version))
702         }
703     }
704 
705     /// Sets the minimum protocol version allowed by the session.
set_protocol_version_min(&mut self, min_version: SslProtocol) -> Result<()>706     pub fn set_protocol_version_min(&mut self, min_version: SslProtocol) -> Result<()> {
707         unsafe { cvt(SSLSetProtocolVersionMin(self.0, min_version.0)) }
708     }
709 
    /// Returns the set of protocols selected via ALPN if it succeeded.
    ///
    /// An empty `Vec` is returned when Secure Transport hands back a NULL
    /// array.
    ///
    /// # Errors
    ///
    /// Propagates any status from `SSLCopyALPNProtocols`. Without the
    /// `OSX_10_13` feature, the function is looked up via `dlsym!`; if the
    /// lookup yields nothing, `errSecUnimplemented` is returned.
    #[cfg(feature = "alpn")]
    pub fn alpn_protocols(&self) -> Result<Vec<String>> {
        let mut array: CFArrayRef = ptr::null();
        unsafe {
            // Exactly one of the two cfg blocks below is compiled: a direct
            // link when `OSX_10_13` is enabled, a dynamic lookup otherwise.
            #[cfg(feature = "OSX_10_13")]
            {
                cvt(SSLCopyALPNProtocols(self.0, &mut array))?;
            }

            #[cfg(not(feature = "OSX_10_13"))]
            {
                dlsym! { fn SSLCopyALPNProtocols(SSLContextRef, *mut CFArrayRef) -> OSStatus }
                if let Some(f) = SSLCopyALPNProtocols.get() {
                    cvt(f(self.0, &mut array))?;
                } else {
                    return Err(Error::from_code(errSecUnimplemented));
                }
            }

            if array.is_null() {
                return Ok(vec![]);
            }

            // "Copy" semantics: the array reference is owned and released by the wrapper.
            let array = CFArray::<CFString>::wrap_under_create_rule(array);
            Ok(array.into_iter().map(|p| p.to_string()).collect())
        }
    }
738 
    /// Configures the set of protocols used for ALPN.
    ///
    /// This is only used for client-side connections. Each entry of
    /// `protocols` is converted to a `CFString`, and the resulting array is
    /// passed to `SSLSetALPNProtocols`.
    ///
    /// # Errors
    ///
    /// Propagates any status from `SSLSetALPNProtocols`. Without the
    /// `OSX_10_13` feature, the function is looked up via `dlsym!`; if the
    /// lookup yields nothing, `errSecUnimplemented` is returned.
    #[cfg(feature = "alpn")]
    pub fn set_alpn_protocols(&mut self, protocols: &[&str]) -> Result<()> {
        // When CFMutableArray is added to core-foundation and IntoIterator trait
        // is implemented for CFMutableArray, the code below should directly collect
        // into a CFMutableArray.
        let protocols = CFArray::from_CFTypes(
            &protocols
                .iter()
                .map(|proto| CFString::new(proto))
                .collect::<Vec<_>>(),
        );

        // Exactly one of the two cfg blocks below is compiled, and it forms the
        // function's tail expression.
        #[cfg(feature = "OSX_10_13")]
        {
            unsafe { cvt(SSLSetALPNProtocols(self.0, protocols.as_concrete_TypeRef())) }
        }
        #[cfg(not(feature = "OSX_10_13"))]
        {
            dlsym! { fn SSLSetALPNProtocols(SSLContextRef, CFArrayRef) -> OSStatus }
            if let Some(f) = SSLSetALPNProtocols.get() {
                unsafe { cvt(f(self.0, protocols.as_concrete_TypeRef())) }
            } else {
                Err(Error::from_code(errSecUnimplemented))
            }
        }
    }
768 
    /// Sets whether the client sends the `SessionTicket` extension in its `ClientHello`.
    ///
    /// On its own, this will just cause the client to send an empty `SessionTicket` extension on
    /// every connection. [`SslContext::set_peer_id`] must also be used to key the session
    /// ticket returned by the server.
    ///
    /// # Errors
    ///
    /// Propagates any status from `SSLSetSessionTicketsEnabled`. Without the
    /// `OSX_10_13` feature, the function is looked up via `dlsym!`; if the
    /// lookup yields nothing, `errSecUnimplemented` is returned.
    ///
    /// [`SslContext::set_peer_id`]: #method.set_peer_id
    #[cfg(feature = "session-tickets")]
    pub fn set_session_tickets_enabled(&mut self, enabled: bool) -> Result<()> {
        // Exactly one of the two cfg blocks below is compiled, and it forms the
        // function's tail expression.
        #[cfg(feature = "OSX_10_13")]
        {
            unsafe { cvt(SSLSetSessionTicketsEnabled(self.0, enabled as Boolean)) }
        }
        #[cfg(not(feature = "OSX_10_13"))]
        {
            dlsym! { fn SSLSetSessionTicketsEnabled(SSLContextRef, Boolean) -> OSStatus }
            if let Some(f) = SSLSetSessionTicketsEnabled.get() {
                unsafe { cvt(f(self.0, enabled as Boolean)) }
            } else {
                Err(Error::from_code(errSecUnimplemented))
            }
        }
    }
792 
793     /// Sets whether a protocol is enabled or not.
794     ///
795     /// # Note
796     ///
797     /// On OSX this is a deprecated API in favor of `set_protocol_version_max` and
798     /// `set_protocol_version_min`, although if you're working with OSX 10.8 or before you may have
799     /// to use this API instead.
800     #[cfg(target_os = "macos")]
set_protocol_version_enabled( &mut self, protocol: SslProtocol, enabled: bool, ) -> Result<()>801     pub fn set_protocol_version_enabled(
802         &mut self,
803         protocol: SslProtocol,
804         enabled: bool,
805     ) -> Result<()> {
806         unsafe {
807             cvt(SSLSetProtocolVersionEnabled(
808                 self.0,
809                 protocol.0,
810                 enabled as Boolean,
811             ))
812         }
813     }
814 
815     /// Returns the number of bytes which can be read without triggering a
816     /// `read` call in the underlying stream.
buffered_read_size(&self) -> Result<usize>817     pub fn buffered_read_size(&self) -> Result<usize> {
818         unsafe {
819             let mut size = 0;
820             cvt(SSLGetBufferedReadSize(self.0, &mut size))?;
821             Ok(size)
822         }
823     }
824 
    // Session options exposed as accessor pairs. Each entry names a
    // `kSSLSessionOption*` constant followed by `getter & setter` method names.
    // Presumably `impl_options!` (defined elsewhere) generates those methods;
    // e.g. `set_break_on_server_auth(true)` is called by `ClientBuilder`.
    // The `///` lines become the docs of the generated accessors.
    impl_options! {
        /// If enabled, the handshake process will pause and return instead of
        /// automatically validating a server's certificate.
        const kSSLSessionOptionBreakOnServerAuth: break_on_server_auth & set_break_on_server_auth,
        /// If enabled, the handshake process will pause and return after
        /// the server requests a certificate from the client.
        const kSSLSessionOptionBreakOnCertRequested: break_on_cert_requested & set_break_on_cert_requested,
        /// If enabled, the handshake process will pause and return instead of
        /// automatically validating a client's certificate.
        const kSSLSessionOptionBreakOnClientAuth: break_on_client_auth & set_break_on_client_auth,
        /// If enabled, TLS false start will be performed if an appropriate
        /// cipher suite is negotiated.
        ///
        /// Requires the `OSX_10_9` (or greater) feature.
        #[cfg(feature = "OSX_10_9")]
        const kSSLSessionOptionFalseStart: false_start & set_false_start,
        /// If enabled, 1/n-1 record splitting will be enabled for TLS 1.0
        /// connections using block ciphers to mitigate the BEAST attack.
        ///
        /// Requires the `OSX_10_9` (or greater) feature.
        #[cfg(feature = "OSX_10_9")]
        const kSSLSessionOptionSendOneByteRecord: send_one_byte_record & set_send_one_byte_record,
    }
848 
into_stream<S>(self, stream: S) -> Result<SslStream<S>> where S: Read + Write,849     fn into_stream<S>(self, stream: S) -> Result<SslStream<S>>
850     where
851         S: Read + Write,
852     {
853         unsafe {
854             let ret = SSLSetIOFuncs(self.0, read_func::<S>, write_func::<S>);
855             if ret != errSecSuccess {
856                 return Err(Error::from_code(ret));
857             }
858 
859             let stream = Connection {
860                 stream,
861                 err: None,
862                 panic: None,
863             };
864             let stream = Box::into_raw(Box::new(stream));
865             let ret = SSLSetConnection(self.0, stream as *mut _);
866             if ret != errSecSuccess {
867                 let _conn = Box::from_raw(stream);
868                 return Err(Error::from_code(ret));
869             }
870 
871             Ok(SslStream {
872                 ctx: self,
873                 _m: PhantomData,
874             })
875         }
876     }
877 
878     /// Performs the SSL/TLS handshake.
handshake<S>(self, stream: S) -> result::Result<SslStream<S>, HandshakeError<S>> where S: Read + Write,879     pub fn handshake<S>(self, stream: S) -> result::Result<SslStream<S>, HandshakeError<S>>
880     where
881         S: Read + Write,
882     {
883         self.into_stream(stream)
884             .map_err(HandshakeError::Failure)
885             .and_then(SslStream::handshake)
886     }
887 }
888 
/// Per-stream state that the Secure Transport I/O callbacks reach through the
/// context's connection pointer.
struct Connection<S> {
    /// The wrapped transport.
    stream: S,
    /// Last I/O error reported by `stream`, kept so it can be handed back to the caller.
    err: Option<io::Error>,
    /// Panic payload caught inside a callback, kept so it can be resumed outside FFI.
    panic: Option<Box<dyn Any + Send + 'static>>,
}
894 
895 // the logic here is based off of libcurl's
896 
translate_err(e: &io::Error) -> OSStatus897 fn translate_err(e: &io::Error) -> OSStatus {
898     match e.kind() {
899         io::ErrorKind::NotFound => errSSLClosedGraceful,
900         io::ErrorKind::ConnectionReset => errSSLClosedAbort,
901         io::ErrorKind::WouldBlock |
902         io::ErrorKind::NotConnected => errSSLWouldBlock,
903         _ => errSecIO,
904     }
905 }
906 
read_func<S>( connection: SSLConnectionRef, data: *mut c_void, data_length: *mut usize, ) -> OSStatus where S: Read,907 unsafe extern "C" fn read_func<S>(
908     connection: SSLConnectionRef,
909     data: *mut c_void,
910     data_length: *mut usize,
911 ) -> OSStatus
912 where
913     S: Read,
914 {
915     let conn: &mut Connection<S> = &mut *(connection as *mut _);
916     let data = slice::from_raw_parts_mut(data as *mut u8, *data_length);
917     let mut start = 0;
918     let mut ret = errSecSuccess;
919 
920     while start < data.len() {
921         match panic::catch_unwind(AssertUnwindSafe(|| conn.stream.read(&mut data[start..]))) {
922             Ok(Ok(0)) => {
923                 ret = errSSLClosedNoNotify;
924                 break;
925             }
926             Ok(Ok(len)) => start += len,
927             Ok(Err(e)) => {
928                 ret = translate_err(&e);
929                 conn.err = Some(e);
930                 break;
931             }
932             Err(e) => {
933                 ret = errSecIO;
934                 conn.panic = Some(e);
935                 break;
936             }
937         }
938     }
939 
940     *data_length = start;
941     ret
942 }
943 
write_func<S>( connection: SSLConnectionRef, data: *const c_void, data_length: *mut usize, ) -> OSStatus where S: Write,944 unsafe extern "C" fn write_func<S>(
945     connection: SSLConnectionRef,
946     data: *const c_void,
947     data_length: *mut usize,
948 ) -> OSStatus
949 where
950     S: Write,
951 {
952     let conn: &mut Connection<S> = &mut *(connection as *mut _);
953     let data = slice::from_raw_parts(data as *mut u8, *data_length);
954     let mut start = 0;
955     let mut ret = errSecSuccess;
956 
957     while start < data.len() {
958         match panic::catch_unwind(AssertUnwindSafe(|| conn.stream.write(&data[start..]))) {
959             Ok(Ok(0)) => {
960                 ret = errSSLClosedNoNotify;
961                 break;
962             }
963             Ok(Ok(len)) => start += len,
964             Ok(Err(e)) => {
965                 ret = translate_err(&e);
966                 conn.err = Some(e);
967                 break;
968             }
969             Err(e) => {
970                 ret = errSecIO;
971                 conn.panic = Some(e);
972                 break;
973             }
974         }
975     }
976 
977     *data_length = start;
978     ret
979 }
980 
981 /// A type implementing SSL/TLS encryption over an underlying stream.
982 pub struct SslStream<S> {
983     ctx: SslContext,
984     _m: PhantomData<S>,
985 }
986 
987 impl<S: fmt::Debug> fmt::Debug for SslStream<S> {
fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result988     fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
989         fmt.debug_struct("SslStream")
990             .field("context", &self.ctx)
991             .field("stream", self.get_ref())
992             .finish()
993     }
994 }
995 
996 impl<S> Drop for SslStream<S> {
drop(&mut self)997     fn drop(&mut self) {
998         unsafe {
999             let mut conn = ptr::null();
1000             let ret = SSLGetConnection(self.ctx.0, &mut conn);
1001             assert!(ret == errSecSuccess);
1002             Box::<Connection<S>>::from_raw(conn as *mut _);
1003         }
1004     }
1005 }
1006 
1007 impl<S> SslStream<S> {
handshake(mut self) -> result::Result<Self, HandshakeError<S>>1008     fn handshake(mut self) -> result::Result<Self, HandshakeError<S>> {
1009         match unsafe { SSLHandshake(self.ctx.0) } {
1010             errSecSuccess => Ok(self),
1011             reason @ errSSLPeerAuthCompleted
1012             | reason @ errSSLClientCertRequested
1013             | reason @ errSSLWouldBlock
1014             | reason @ errSSLClientHelloReceived => {
1015                 Err(HandshakeError::Interrupted(MidHandshakeSslStream {
1016                     stream: self,
1017                     error: Error::from_code(reason),
1018                 }))
1019             }
1020             err => {
1021                 self.check_panic();
1022                 Err(HandshakeError::Failure(Error::from_code(err)))
1023             }
1024         }
1025     }
1026 
1027     /// Returns a shared reference to the inner stream.
get_ref(&self) -> &S1028     pub fn get_ref(&self) -> &S {
1029         &self.connection().stream
1030     }
1031 
1032     /// Returns a mutable reference to the underlying stream.
get_mut(&mut self) -> &mut S1033     pub fn get_mut(&mut self) -> &mut S {
1034         &mut self.connection_mut().stream
1035     }
1036 
1037     /// Returns a shared reference to the `SslContext` of the stream.
context(&self) -> &SslContext1038     pub fn context(&self) -> &SslContext {
1039         &self.ctx
1040     }
1041 
1042     /// Returns a mutable reference to the `SslContext` of the stream.
context_mut(&mut self) -> &mut SslContext1043     pub fn context_mut(&mut self) -> &mut SslContext {
1044         &mut self.ctx
1045     }
1046 
1047     /// Shuts down the connection.
close(&mut self) -> result::Result<(), io::Error>1048     pub fn close(&mut self) -> result::Result<(), io::Error> {
1049         unsafe {
1050             let ret = SSLClose(self.ctx.0);
1051             if ret == errSecSuccess {
1052                 Ok(())
1053             } else {
1054                 Err(self.get_error(ret))
1055             }
1056         }
1057     }
1058 
connection(&self) -> &Connection<S>1059     fn connection(&self) -> &Connection<S> {
1060         unsafe {
1061             let mut conn = ptr::null();
1062             let ret = SSLGetConnection(self.ctx.0, &mut conn);
1063             assert!(ret == errSecSuccess);
1064 
1065             mem::transmute(conn)
1066         }
1067     }
1068 
connection_mut(&mut self) -> &mut Connection<S>1069     fn connection_mut(&mut self) -> &mut Connection<S> {
1070         unsafe {
1071             let mut conn = ptr::null();
1072             let ret = SSLGetConnection(self.ctx.0, &mut conn);
1073             assert!(ret == errSecSuccess);
1074 
1075             mem::transmute(conn)
1076         }
1077     }
1078 
check_panic(&mut self)1079     fn check_panic(&mut self) {
1080         let conn = self.connection_mut();
1081         if let Some(err) = conn.panic.take() {
1082             panic::resume_unwind(err);
1083         }
1084     }
1085 
get_error(&mut self, ret: OSStatus) -> io::Error1086     fn get_error(&mut self, ret: OSStatus) -> io::Error {
1087         self.check_panic();
1088 
1089         if let Some(err) = self.connection_mut().err.take() {
1090             err
1091         } else {
1092             io::Error::new(io::ErrorKind::Other, Error::from_code(ret))
1093         }
1094     }
1095 }
1096 
1097 impl<S: Read + Write> Read for SslStream<S> {
read(&mut self, buf: &mut [u8]) -> io::Result<usize>1098     fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
1099         // Below we base our return value off the amount of data read, so a
1100         // zero-length buffer might cause us to erroneously interpret this
1101         // request as an error. Instead short-circuit that logic and return
1102         // `Ok(0)` instead.
1103         if buf.is_empty() {
1104             return Ok(0);
1105         }
1106 
1107         // If some data was buffered but not enough to fill `buf`, SSLRead
1108         // will try to read a new packet. This is bad because there may be
1109         // no more data but the socket is remaining open (e.g HTTPS with
1110         // Connection: keep-alive).
1111         let buffered = self.context().buffered_read_size().unwrap_or(0);
1112         let to_read = if buffered > 0 {
1113             cmp::min(buffered, buf.len())
1114         } else {
1115             buf.len()
1116         };
1117 
1118         unsafe {
1119             let mut nread = 0;
1120             let ret = SSLRead(self.ctx.0, buf.as_mut_ptr() as *mut _, to_read, &mut nread);
1121             // SSLRead can return an error at the same time it returns the last
1122             // chunk of data (!)
1123             if nread > 0 {
1124                 return Ok(nread as usize);
1125             }
1126 
1127             match ret {
1128                 errSSLClosedGraceful | errSSLClosedAbort | errSSLClosedNoNotify => Ok(0),
1129                 _ => Err(self.get_error(ret)),
1130             }
1131         }
1132     }
1133 }
1134 
1135 impl<S: Read + Write> Write for SslStream<S> {
write(&mut self, buf: &[u8]) -> io::Result<usize>1136     fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
1137         // Like above in read, short circuit a 0-length write
1138         if buf.is_empty() {
1139             return Ok(0);
1140         }
1141         unsafe {
1142             let mut nwritten = 0;
1143             let ret = SSLWrite(
1144                 self.ctx.0,
1145                 buf.as_ptr() as *const _,
1146                 buf.len(),
1147                 &mut nwritten,
1148             );
1149             // just to be safe, base success off of nwritten rather than ret
1150             // for the same reason as in read
1151             if nwritten > 0 {
1152                 Ok(nwritten as usize)
1153             } else {
1154                 Err(self.get_error(ret))
1155             }
1156         }
1157     }
1158 
flush(&mut self) -> io::Result<()>1159     fn flush(&mut self) -> io::Result<()> {
1160         self.connection_mut().stream.flush()
1161     }
1162 }
1163 
1164 /// A builder type to simplify the creation of client side `SslStream`s.
1165 #[derive(Debug)]
1166 pub struct ClientBuilder {
1167     identity: Option<SecIdentity>,
1168     certs: Vec<SecCertificate>,
1169     chain: Vec<SecCertificate>,
1170     protocol_min: Option<SslProtocol>,
1171     protocol_max: Option<SslProtocol>,
1172     trust_certs_only: bool,
1173     use_sni: bool,
1174     danger_accept_invalid_certs: bool,
1175     danger_accept_invalid_hostnames: bool,
1176     whitelisted_ciphers: Vec<CipherSuite>,
1177     blacklisted_ciphers: Vec<CipherSuite>,
1178     alpn: Option<Vec<String>>,
1179     enable_session_tickets: bool,
1180 }
1181 
1182 impl Default for ClientBuilder {
default() -> Self1183     fn default() -> Self {
1184         Self::new()
1185     }
1186 }
1187 
1188 impl ClientBuilder {
1189     /// Creates a new builder with default options.
new() -> Self1190     pub fn new() -> Self {
1191         Self {
1192             identity: None,
1193             certs: Vec::new(),
1194             chain: Vec::new(),
1195             protocol_min: None,
1196             protocol_max: None,
1197             trust_certs_only: false,
1198             use_sni: true,
1199             danger_accept_invalid_certs: false,
1200             danger_accept_invalid_hostnames: false,
1201             whitelisted_ciphers: Vec::new(),
1202             blacklisted_ciphers: Vec::new(),
1203             alpn: None,
1204             enable_session_tickets: false,
1205         }
1206     }
1207 
1208     /// Specifies the set of root certificates to trust when
1209     /// verifying the server's certificate.
anchor_certificates(&mut self, certs: &[SecCertificate]) -> &mut Self1210     pub fn anchor_certificates(&mut self, certs: &[SecCertificate]) -> &mut Self {
1211         self.certs = certs.to_owned();
1212         self
1213     }
1214 
1215     /// Specifies whether to trust the built-in certificates in addition
1216     /// to specified anchor certificates.
trust_anchor_certificates_only(&mut self, only: bool) -> &mut Self1217     pub fn trust_anchor_certificates_only(&mut self, only: bool) -> &mut Self {
1218         self.trust_certs_only = only;
1219         self
1220     }
1221 
1222     /// Specifies whether to trust invalid certificates.
1223     ///
1224     /// # Warning
1225     ///
1226     /// You should think very carefully before using this method. If invalid
1227     /// certificates are trusted, *any* certificate for *any* site will be
1228     /// trusted for use. This includes expired certificates. This introduces
1229     /// significant vulnerabilities, and should only be used as a last resort.
danger_accept_invalid_certs(&mut self, noverify: bool) -> &mut Self1230     pub fn danger_accept_invalid_certs(&mut self, noverify: bool) -> &mut Self {
1231         self.danger_accept_invalid_certs = noverify;
1232         self
1233     }
1234 
1235     /// Specifies whether to use Server Name Indication (SNI).
use_sni(&mut self, use_sni: bool) -> &mut Self1236     pub fn use_sni(&mut self, use_sni: bool) -> &mut Self {
1237         self.use_sni = use_sni;
1238         self
1239     }
1240 
1241     /// Specifies whether to verify that the server's hostname matches its certificate.
1242     ///
1243     /// # Warning
1244     ///
1245     /// You should think very carefully before using this method. If hostnames are not verified,
1246     /// *any* valid certificate for *any* site will be trusted for use. This introduces significant
1247     /// vulnerabilities, and should only be used as a last resort.
danger_accept_invalid_hostnames( &mut self, danger_accept_invalid_hostnames: bool, ) -> &mut Self1248     pub fn danger_accept_invalid_hostnames(
1249         &mut self,
1250         danger_accept_invalid_hostnames: bool,
1251     ) -> &mut Self {
1252         self.danger_accept_invalid_hostnames = danger_accept_invalid_hostnames;
1253         self
1254     }
1255 
1256     /// Set a whitelist of enabled ciphers. Any ciphers not whitelisted will be disabled.
whitelist_ciphers(&mut self, whitelisted_ciphers: &[CipherSuite]) -> &mut Self1257     pub fn whitelist_ciphers(&mut self, whitelisted_ciphers: &[CipherSuite]) -> &mut Self {
1258         self.whitelisted_ciphers = whitelisted_ciphers.to_owned();
1259         self
1260     }
1261 
1262     /// Set a blacklist of disabled ciphers. Blacklisted ciphers will be disabled.
blacklist_ciphers(&mut self, blacklisted_ciphers: &[CipherSuite]) -> &mut Self1263     pub fn blacklist_ciphers(&mut self, blacklisted_ciphers: &[CipherSuite]) -> &mut Self {
1264         self.blacklisted_ciphers = blacklisted_ciphers.to_owned();
1265         self
1266     }
1267 
1268     /// Use the specified identity as a SSL/TLS client certificate.
identity(&mut self, identity: &SecIdentity, chain: &[SecCertificate]) -> &mut Self1269     pub fn identity(&mut self, identity: &SecIdentity, chain: &[SecCertificate]) -> &mut Self {
1270         self.identity = Some(identity.clone());
1271         self.chain = chain.to_owned();
1272         self
1273     }
1274 
1275     /// Configure the minimum protocol that this client will support.
protocol_min(&mut self, min: SslProtocol) -> &mut Self1276     pub fn protocol_min(&mut self, min: SslProtocol) -> &mut Self {
1277         self.protocol_min = Some(min);
1278         self
1279     }
1280 
1281     /// Configure the minimum protocol that this client will support.
protocol_max(&mut self, max: SslProtocol) -> &mut Self1282     pub fn protocol_max(&mut self, max: SslProtocol) -> &mut Self {
1283         self.protocol_max = Some(max);
1284         self
1285     }
1286 
1287     /// Configures the set of protocols used for ALPN.
1288     #[cfg(feature = "alpn")]
alpn_protocols(&mut self, protocols: &[&str]) -> &mut Self1289     pub fn alpn_protocols(&mut self, protocols: &[&str]) -> &mut Self {
1290         self.alpn = Some(protocols.iter().map(|s| s.to_string()).collect());
1291         self
1292     }
1293 
1294     /// Configures the use of the RFC 5077 `SessionTicket` extension.
1295     ///
1296     /// Defaults to `false`.
1297     #[cfg(feature = "session-tickets")]
enable_session_tickets(&mut self, enable: bool) -> &mut Self1298     pub fn enable_session_tickets(&mut self, enable: bool) -> &mut Self {
1299         self.enable_session_tickets = enable;
1300         self
1301     }
1302 
1303     /// Initiates a new SSL/TLS session over a stream connected to the specified domain.
1304     ///
1305     /// If both SNI and hostname verification are disabled, the value of `domain` will be ignored.
handshake<S>( &self, domain: &str, stream: S, ) -> result::Result<SslStream<S>, ClientHandshakeError<S>> where S: Read + Write,1306     pub fn handshake<S>(
1307         &self,
1308         domain: &str,
1309         stream: S,
1310     ) -> result::Result<SslStream<S>, ClientHandshakeError<S>>
1311     where
1312         S: Read + Write,
1313     {
1314         // the logic for trust validation is in MidHandshakeClientBuilder::connect, so run all
1315         // of the handshake logic through that.
1316         let stream = MidHandshakeSslStream {
1317             stream: self.ctx_into_stream(domain, stream)?,
1318             error: Error::from(errSecSuccess),
1319         };
1320 
1321         let certs = self.certs.clone();
1322         let stream = MidHandshakeClientBuilder {
1323             stream,
1324             domain: if self.danger_accept_invalid_hostnames {
1325                 None
1326             } else {
1327                 Some(domain.to_string())
1328             },
1329             certs,
1330             trust_certs_only: self.trust_certs_only,
1331             danger_accept_invalid_certs: self.danger_accept_invalid_certs,
1332         };
1333         stream.handshake()
1334     }
1335 
ctx_into_stream<S>(&self, domain: &str, stream: S) -> Result<SslStream<S>> where S: Read + Write,1336     fn ctx_into_stream<S>(&self, domain: &str, stream: S) -> Result<SslStream<S>>
1337     where
1338         S: Read + Write,
1339     {
1340         let mut ctx = SslContext::new(SslProtocolSide::CLIENT, SslConnectionType::STREAM)?;
1341 
1342         if self.use_sni {
1343             ctx.set_peer_domain_name(domain)?;
1344         }
1345         if let Some(ref identity) = self.identity {
1346             ctx.set_certificate(identity, &self.chain)?;
1347         }
1348         #[cfg(feature = "alpn")]
1349         {
1350             if let Some(ref alpn) = self.alpn {
1351                 ctx.set_alpn_protocols(&alpn.iter().map(|s| &**s).collect::<Vec<_>>())?;
1352             }
1353         }
1354         #[cfg(feature = "session-tickets")]
1355         {
1356             if self.enable_session_tickets {
1357                 // We must use the domain here to ensure that we go through certificate validation
1358                 // again rather than resuming the session if the domain changes.
1359                 ctx.set_peer_id(domain.as_bytes())?;
1360                 ctx.set_session_tickets_enabled(true)?;
1361             }
1362         }
1363         ctx.set_break_on_server_auth(true)?;
1364         self.configure_protocols(&mut ctx)?;
1365         self.configure_ciphers(&mut ctx)?;
1366 
1367         ctx.into_stream(stream)
1368     }
1369 
configure_protocols(&self, ctx: &mut SslContext) -> Result<()>1370     fn configure_protocols(&self, ctx: &mut SslContext) -> Result<()> {
1371         if let Some(min) = self.protocol_min {
1372             ctx.set_protocol_version_min(min)?;
1373         }
1374         if let Some(max) = self.protocol_max {
1375             ctx.set_protocol_version_max(max)?;
1376         }
1377         Ok(())
1378     }
1379 
configure_ciphers(&self, ctx: &mut SslContext) -> Result<()>1380     fn configure_ciphers(&self, ctx: &mut SslContext) -> Result<()> {
1381         let mut ciphers = if self.whitelisted_ciphers.is_empty() {
1382             ctx.enabled_ciphers()?
1383         } else {
1384             self.whitelisted_ciphers.clone()
1385         };
1386 
1387         if !self.blacklisted_ciphers.is_empty() {
1388             ciphers.retain(|cipher| !self.blacklisted_ciphers.contains(cipher));
1389         }
1390 
1391         ctx.set_enabled_ciphers(&ciphers)?;
1392         Ok(())
1393     }
1394 }
1395 
1396 /// A builder type to simplify the creation of server-side `SslStream`s.
1397 #[derive(Debug)]
1398 pub struct ServerBuilder {
1399     identity: SecIdentity,
1400     certs: Vec<SecCertificate>,
1401 }
1402 
1403 impl ServerBuilder {
1404     /// Creates a new `ServerBuilder` which will use the specified identity
1405     /// and certificate chain for handshakes.
new(identity: &SecIdentity, certs: &[SecCertificate]) -> Self1406     pub fn new(identity: &SecIdentity, certs: &[SecCertificate]) -> Self {
1407         Self {
1408             identity: identity.clone(),
1409             certs: certs.to_owned(),
1410         }
1411     }
1412 
1413     /// Initiates a new SSL/TLS session over a stream.
handshake<S>(&self, stream: S) -> Result<SslStream<S>> where S: Read + Write,1414     pub fn handshake<S>(&self, stream: S) -> Result<SslStream<S>>
1415     where
1416         S: Read + Write,
1417     {
1418         let mut ctx = SslContext::new(SslProtocolSide::SERVER, SslConnectionType::STREAM)?;
1419         ctx.set_certificate(&self.identity, &self.certs)?;
1420         match ctx.handshake(stream) {
1421             Ok(stream) => Ok(stream),
1422             Err(HandshakeError::Interrupted(stream)) => Err(Error::from_code(stream.reason())),
1423             Err(HandshakeError::Failure(err)) => Err(err),
1424         }
1425     }
1426 }
1427 
1428 #[cfg(test)]
1429 mod test {
1430     use std::io;
1431     use std::io::prelude::*;
1432     use std::net::TcpStream;
1433 
1434     use super::*;
1435 
1436     #[test]
connect()1437     fn connect() {
1438         let mut ctx = p!(SslContext::new(
1439             SslProtocolSide::CLIENT,
1440             SslConnectionType::STREAM
1441         ));
1442         p!(ctx.set_peer_domain_name("google.com"));
1443         let stream = p!(TcpStream::connect("google.com:443"));
1444         p!(ctx.handshake(stream));
1445     }
1446 
1447     #[test]
connect_bad_domain()1448     fn connect_bad_domain() {
1449         let mut ctx = p!(SslContext::new(
1450             SslProtocolSide::CLIENT,
1451             SslConnectionType::STREAM
1452         ));
1453         p!(ctx.set_peer_domain_name("foobar.com"));
1454         let stream = p!(TcpStream::connect("google.com:443"));
1455         match ctx.handshake(stream) {
1456             Ok(_) => panic!("expected failure"),
1457             Err(_) => {}
1458         }
1459     }
1460 
1461     #[test]
load_page()1462     fn load_page() {
1463         let mut ctx = p!(SslContext::new(
1464             SslProtocolSide::CLIENT,
1465             SslConnectionType::STREAM
1466         ));
1467         p!(ctx.set_peer_domain_name("google.com"));
1468         let stream = p!(TcpStream::connect("google.com:443"));
1469         let mut stream = p!(ctx.handshake(stream));
1470         p!(stream.write_all(b"GET / HTTP/1.0\r\n\r\n"));
1471         p!(stream.flush());
1472         let mut buf = vec![];
1473         p!(stream.read_to_end(&mut buf));
1474         println!("{}", String::from_utf8_lossy(&buf));
1475     }
1476 
1477     #[test]
client_no_session_ticket_resumption()1478     fn client_no_session_ticket_resumption() {
1479         for _ in 0..2 {
1480             let stream = p!(TcpStream::connect("google.com:443"));
1481 
1482             // Manually handshake here.
1483             let stream = MidHandshakeSslStream {
1484                 stream: ClientBuilder::new()
1485                     .ctx_into_stream("google.com", stream)
1486                     .unwrap(),
1487                 error: Error::from(errSecSuccess),
1488             };
1489 
1490             let mut result = stream.handshake();
1491 
1492             if let Err(HandshakeError::Interrupted(stream)) = result {
1493                 assert!(stream.server_auth_completed());
1494                 result = stream.handshake();
1495             } else {
1496                 panic!("Unexpectedly skipped server auth");
1497             }
1498 
1499             assert!(result.is_ok());
1500         }
1501     }
1502 
1503     #[test]
1504     #[cfg(feature = "session-tickets")]
client_session_ticket_resumption()1505     fn client_session_ticket_resumption() {
1506         // The first time through this loop, we should do a full handshake. The second time, we
1507         // should immediately finish the handshake without breaking on server auth.
1508         for i in 0..2 {
1509             let stream = p!(TcpStream::connect("google.com:443"));
1510             let mut builder = ClientBuilder::new();
1511             builder.enable_session_tickets(true);
1512 
1513             // Manually handshake here.
1514             let stream = MidHandshakeSslStream {
1515                 stream: builder.ctx_into_stream("google.com", stream).unwrap(),
1516                 error: Error::from(errSecSuccess),
1517             };
1518 
1519             let mut result = stream.handshake();
1520 
1521             if let Err(HandshakeError::Interrupted(stream)) = result {
1522                 assert!(stream.server_auth_completed());
1523                 assert_eq!(
1524                     i, 0,
1525                     "Session ticket resumption did not work, server auth was not skipped"
1526                 );
1527                 result = stream.handshake();
1528             } else {
1529                 assert_eq!(i, 1, "Unexpectedly skipped server auth");
1530             }
1531 
1532             assert!(result.is_ok());
1533         }
1534     }
1535 
1536     #[test]
1537     #[cfg(feature = "alpn")]
client_alpn_accept()1538     fn client_alpn_accept() {
1539         let mut ctx = p!(SslContext::new(
1540             SslProtocolSide::CLIENT,
1541             SslConnectionType::STREAM
1542         ));
1543         p!(ctx.set_peer_domain_name("google.com"));
1544         p!(ctx.set_alpn_protocols(&vec!["h2"]));
1545         let stream = p!(TcpStream::connect("google.com:443"));
1546         let stream = ctx.handshake(stream).unwrap();
1547         assert_eq!(vec!["h2"], stream.context().alpn_protocols().unwrap());
1548     }
1549 
1550     #[test]
1551     #[cfg(feature = "alpn")]
client_alpn_reject()1552     fn client_alpn_reject() {
1553         let mut ctx = p!(SslContext::new(
1554             SslProtocolSide::CLIENT,
1555             SslConnectionType::STREAM
1556         ));
1557         p!(ctx.set_peer_domain_name("google.com"));
1558         p!(ctx.set_alpn_protocols(&vec!["h2c"]));
1559         let stream = p!(TcpStream::connect("google.com:443"));
1560         let stream = ctx.handshake(stream).unwrap();
1561         assert!(stream.context().alpn_protocols().is_err());
1562     }
1563 
1564     #[test]
client_no_anchor_certs()1565     fn client_no_anchor_certs() {
1566         let stream = p!(TcpStream::connect("google.com:443"));
1567         assert!(ClientBuilder::new()
1568             .trust_anchor_certificates_only(true)
1569             .handshake("google.com", stream)
1570             .is_err());
1571     }
1572 
1573     #[test]
client_bad_domain()1574     fn client_bad_domain() {
1575         let stream = p!(TcpStream::connect("google.com:443"));
1576         assert!(ClientBuilder::new()
1577             .handshake("foobar.com", stream)
1578             .is_err());
1579     }
1580 
1581     #[test]
client_bad_domain_ignored()1582     fn client_bad_domain_ignored() {
1583         let stream = p!(TcpStream::connect("google.com:443"));
1584         ClientBuilder::new()
1585             .danger_accept_invalid_hostnames(true)
1586             .handshake("foobar.com", stream)
1587             .unwrap();
1588     }
1589 
1590     #[test]
connect_no_verify_ssl()1591     fn connect_no_verify_ssl() {
1592         let stream = p!(TcpStream::connect("expired.badssl.com:443"));
1593         let mut builder = ClientBuilder::new();
1594         builder.danger_accept_invalid_certs(true);
1595         builder.handshake("expired.badssl.com", stream).unwrap();
1596     }
1597 
1598     #[test]
load_page_client()1599     fn load_page_client() {
1600         let stream = p!(TcpStream::connect("google.com:443"));
1601         let mut stream = p!(ClientBuilder::new().handshake("google.com", stream));
1602         p!(stream.write_all(b"GET / HTTP/1.0\r\n\r\n"));
1603         p!(stream.flush());
1604         let mut buf = vec![];
1605         p!(stream.read_to_end(&mut buf));
1606         println!("{}", String::from_utf8_lossy(&buf));
1607     }
1608 
1609     #[test]
1610     #[cfg_attr(target_os = "ios", ignore)] // FIXME what's going on with ios?
cipher_configuration()1611     fn cipher_configuration() {
1612         let mut ctx = p!(SslContext::new(
1613             SslProtocolSide::SERVER,
1614             SslConnectionType::STREAM
1615         ));
1616         let ciphers = p!(ctx.enabled_ciphers());
1617         let ciphers = ciphers
1618             .iter()
1619             .enumerate()
1620             .filter_map(|(i, c)| if i % 2 == 0 { Some(*c) } else { None })
1621             .collect::<Vec<_>>();
1622         p!(ctx.set_enabled_ciphers(&ciphers));
1623         assert_eq!(ciphers, p!(ctx.enabled_ciphers()));
1624     }
1625 
1626     #[test]
test_builder_whitelist_ciphers()1627     fn test_builder_whitelist_ciphers() {
1628         let stream = p!(TcpStream::connect("google.com:443"));
1629 
1630         let ctx = p!(SslContext::new(
1631             SslProtocolSide::CLIENT,
1632             SslConnectionType::STREAM
1633         ));
1634         assert!(p!(ctx.enabled_ciphers()).len() > 1);
1635 
1636         let ciphers = p!(ctx.enabled_ciphers());
1637         let cipher = ciphers.first().unwrap();
1638         let stream = p!(ClientBuilder::new()
1639             .whitelist_ciphers(&[*cipher])
1640             .ctx_into_stream("google.com", stream));
1641 
1642         assert_eq!(1, p!(stream.context().enabled_ciphers()).len());
1643     }
1644 
1645     #[test]
1646     #[cfg_attr(target_os = "ios", ignore)] // FIXME same issue as cipher_configuration
test_builder_blacklist_ciphers()1647     fn test_builder_blacklist_ciphers() {
1648         let stream = p!(TcpStream::connect("google.com:443"));
1649 
1650         let ctx = p!(SslContext::new(
1651             SslProtocolSide::CLIENT,
1652             SslConnectionType::STREAM
1653         ));
1654         let num = p!(ctx.enabled_ciphers()).len();
1655         assert!(num > 1);
1656 
1657         let ciphers = p!(ctx.enabled_ciphers());
1658         let cipher = ciphers.first().unwrap();
1659         let stream = p!(ClientBuilder::new()
1660             .blacklist_ciphers(&[*cipher])
1661             .ctx_into_stream("google.com", stream));
1662 
1663         assert_eq!(num - 1, p!(stream.context().enabled_ciphers()).len());
1664     }
1665 
1666     #[test]
idle_context_peer_trust()1667     fn idle_context_peer_trust() {
1668         let ctx = p!(SslContext::new(
1669             SslProtocolSide::SERVER,
1670             SslConnectionType::STREAM
1671         ));
1672         assert!(ctx.peer_trust2().is_err());
1673     }
1674 
1675     #[test]
peer_id()1676     fn peer_id() {
1677         let mut ctx = p!(SslContext::new(
1678             SslProtocolSide::SERVER,
1679             SslConnectionType::STREAM
1680         ));
1681         assert!(p!(ctx.peer_id()).is_none());
1682         p!(ctx.set_peer_id(b"foobar"));
1683         assert_eq!(p!(ctx.peer_id()), Some(&b"foobar"[..]));
1684     }
1685 
1686     #[test]
peer_domain_name()1687     fn peer_domain_name() {
1688         let mut ctx = p!(SslContext::new(
1689             SslProtocolSide::CLIENT,
1690             SslConnectionType::STREAM
1691         ));
1692         assert_eq!("", p!(ctx.peer_domain_name()));
1693         p!(ctx.set_peer_domain_name("foobar.com"));
1694         assert_eq!("foobar.com", p!(ctx.peer_domain_name()));
1695     }
1696 
1697     #[test]
1698     #[should_panic(expected = "blammo")]
write_panic()1699     fn write_panic() {
1700         struct ExplodingStream(TcpStream);
1701 
1702         impl Read for ExplodingStream {
1703             fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
1704                 self.0.read(buf)
1705             }
1706         }
1707 
1708         impl Write for ExplodingStream {
1709             fn write(&mut self, _: &[u8]) -> io::Result<usize> {
1710                 panic!("blammo");
1711             }
1712 
1713             fn flush(&mut self) -> io::Result<()> {
1714                 self.0.flush()
1715             }
1716         }
1717 
1718         let mut ctx = p!(SslContext::new(
1719             SslProtocolSide::CLIENT,
1720             SslConnectionType::STREAM
1721         ));
1722         p!(ctx.set_peer_domain_name("google.com"));
1723         let stream = p!(TcpStream::connect("google.com:443"));
1724         let _ = ctx.handshake(ExplodingStream(stream));
1725     }
1726 
1727     #[test]
1728     #[should_panic(expected = "blammo")]
read_panic()1729     fn read_panic() {
1730         struct ExplodingStream(TcpStream);
1731 
1732         impl Read for ExplodingStream {
1733             fn read(&mut self, _: &mut [u8]) -> io::Result<usize> {
1734                 panic!("blammo");
1735             }
1736         }
1737 
1738         impl Write for ExplodingStream {
1739             fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
1740                 self.0.write(buf)
1741             }
1742 
1743             fn flush(&mut self) -> io::Result<()> {
1744                 self.0.flush()
1745             }
1746         }
1747 
1748         let mut ctx = p!(SslContext::new(
1749             SslProtocolSide::CLIENT,
1750             SslConnectionType::STREAM
1751         ));
1752         p!(ctx.set_peer_domain_name("google.com"));
1753         let stream = p!(TcpStream::connect("google.com:443"));
1754         let _ = ctx.handshake(ExplodingStream(stream));
1755     }
1756 
1757     #[test]
zero_length_buffers()1758     fn zero_length_buffers() {
1759         let mut ctx = p!(SslContext::new(
1760             SslProtocolSide::CLIENT,
1761             SslConnectionType::STREAM
1762         ));
1763         p!(ctx.set_peer_domain_name("google.com"));
1764         let stream = p!(TcpStream::connect("google.com:443"));
1765         let mut stream = ctx.handshake(stream).unwrap();
1766         assert_eq!(stream.write(b"").unwrap(), 0);
1767         assert_eq!(stream.read(&mut []).unwrap(), 0);
1768     }
1769 }
1770