1 //! SSL/TLS encryption support using Secure Transport.
2 //!
3 //! # Examples
4 //!
5 //! To connect as a client to a server with a certificate trusted by the system:
6 //!
7 //! ```rust
8 //! use std::io::prelude::*;
9 //! use std::net::TcpStream;
10 //! use security_framework::secure_transport::ClientBuilder;
11 //!
12 //! let stream = TcpStream::connect("google.com:443").unwrap();
13 //! let mut stream = ClientBuilder::new().handshake("google.com", stream).unwrap();
14 //!
15 //! stream.write_all(b"GET / HTTP/1.0\r\n\r\n").unwrap();
16 //! let mut page = vec![];
17 //! stream.read_to_end(&mut page).unwrap();
18 //! println!("{}", String::from_utf8_lossy(&page));
19 //! ```
20 //!
21 //! To connect to a server with a certificate that's *not* trusted by the
22 //! system, specify the root certificates for the server's chain to the
23 //! `ClientBuilder`:
24 //!
25 //! ```rust,no_run
26 //! use std::io::prelude::*;
27 //! use std::net::TcpStream;
28 //! use security_framework::secure_transport::ClientBuilder;
29 //!
30 //! # let root_cert = unsafe { std::mem::zeroed() };
31 //! let stream = TcpStream::connect("my_server.com:443").unwrap();
32 //! let mut stream = ClientBuilder::new()
33 //!                      .anchor_certificates(&[root_cert])
34 //!                      .handshake("my_server.com", stream)
35 //!                      .unwrap();
36 //!
37 //! stream.write_all(b"GET / HTTP/1.0\r\n\r\n").unwrap();
38 //! let mut page = vec![];
39 //! stream.read_to_end(&mut page).unwrap();
40 //! println!("{}", String::from_utf8_lossy(&page));
41 //! ```
42 //!
43 //! For more advanced configuration, the `SslContext` type can be used directly.
44 //!
45 //! To run a server:
46 //!
47 //! ```rust,no_run
48 //! use std::net::TcpListener;
49 //! use std::thread;
50 //! use security_framework::secure_transport::{SslContext, SslProtocolSide, SslConnectionType};
51 //!
52 //! // Create a TCP listener and start accepting on it.
53 //! let mut listener = TcpListener::bind("0.0.0.0:443").unwrap();
54 //!
55 //! for stream in listener.incoming() {
56 //!     let stream = stream.unwrap();
57 //!     thread::spawn(move || {
58 //!         // Create a new context configured to operate on the server side of
59 //!         // a traditional SSL/TLS session.
60 //!         let mut ctx = SslContext::new(SslProtocolSide::SERVER, SslConnectionType::STREAM)
61 //!                           .unwrap();
62 //!
63 //!         // Install the certificate chain that we will be using.
64 //!         # let identity = unsafe { std::mem::zeroed() };
65 //!         # let intermediate_cert = unsafe { std::mem::zeroed() };
66 //!         # let root_cert = unsafe { std::mem::zeroed() };
67 //!         ctx.set_certificate(identity, &[intermediate_cert, root_cert]).unwrap();
68 //!
69 //!         // Perform the SSL/TLS handshake and get our stream.
70 //!         let mut stream = ctx.handshake(stream).unwrap();
71 //!     });
72 //! }
73 //!
74 //! ```
75 #[allow(unused_imports)]
76 use core_foundation::array::{CFArray, CFArrayRef};
77 
78 use core_foundation::base::{Boolean, TCFType};
79 #[cfg(feature = "alpn")]
80 use core_foundation::string::CFString;
81 use core_foundation_sys::base::{kCFAllocatorDefault, OSStatus};
82 use std::os::raw::c_void;
83 
84 #[allow(unused_imports)]
85 use security_framework_sys::base::{
86     errSecBadReq, errSecIO, errSecNotTrusted, errSecSuccess, errSecTrustSettingDeny,
87     errSecUnimplemented,
88 };
89 
90 use security_framework_sys::secure_transport::*;
91 use std::any::Any;
92 use std::cmp;
93 use std::fmt;
94 use std::io;
95 use std::io::prelude::*;
96 use std::marker::PhantomData;
97 use std::panic::{self, AssertUnwindSafe};
98 use std::ptr;
99 use std::result;
100 use std::slice;
101 
102 use crate::base::{Error, Result};
103 use crate::certificate::SecCertificate;
104 use crate::cipher_suite::CipherSuite;
105 use crate::identity::SecIdentity;
106 use crate::import_export::Pkcs12ImportOptions;
107 use crate::policy::SecPolicy;
108 use crate::trust::SecTrust;
109 use crate::{cvt, AsInner};
110 use security_framework_sys::base::errSecParam;
111 
112 /// Specifies a side of a TLS session.
113 #[derive(Debug, Copy, Clone, PartialEq, Eq)]
114 pub struct SslProtocolSide(SSLProtocolSide);
115 
116 impl SslProtocolSide {
117     /// The server side of the session.
118     pub const SERVER: Self = Self(kSSLServerSide);
119 
120     /// The client side of the session.
121     pub const CLIENT: Self = Self(kSSLClientSide);
122 }
123 
124 /// Specifies the type of TLS session.
125 #[derive(Debug, Copy, Clone)]
126 pub struct SslConnectionType(SSLConnectionType);
127 
128 impl SslConnectionType {
129     /// A traditional TLS stream.
130     pub const STREAM: Self = Self(kSSLStreamType);
131 
132     /// A DTLS session.
133     pub const DATAGRAM: Self = Self(kSSLDatagramType);
134 }
135 
136 /// An error or intermediate state after a TLS handshake attempt.
137 #[derive(Debug)]
138 pub enum HandshakeError<S> {
139     /// The handshake failed.
140     Failure(Error),
141     /// The handshake was interrupted midway through.
142     Interrupted(MidHandshakeSslStream<S>),
143 }
144 
145 impl<S> From<Error> for HandshakeError<S> {
146     #[inline(always)]
from(err: Error) -> Self147     fn from(err: Error) -> Self {
148         Self::Failure(err)
149     }
150 }
151 
152 /// An error or intermediate state after a TLS handshake attempt.
153 #[derive(Debug)]
154 pub enum ClientHandshakeError<S> {
155     /// The handshake failed.
156     Failure(Error),
157     /// The handshake was interrupted midway through.
158     Interrupted(MidHandshakeClientBuilder<S>),
159 }
160 
161 impl<S> From<Error> for ClientHandshakeError<S> {
162     #[inline(always)]
from(err: Error) -> Self163     fn from(err: Error) -> Self {
164         Self::Failure(err)
165     }
166 }
167 
168 /// An SSL stream midway through the handshake process.
169 #[derive(Debug)]
170 pub struct MidHandshakeSslStream<S> {
171     stream: SslStream<S>,
172     error: Error,
173 }
174 
175 impl<S> MidHandshakeSslStream<S> {
176     /// Returns a shared reference to the inner stream.
177     #[inline(always)]
get_ref(&self) -> &S178     pub fn get_ref(&self) -> &S {
179         self.stream.get_ref()
180     }
181 
182     /// Returns a mutable reference to the inner stream.
183     #[inline(always)]
get_mut(&mut self) -> &mut S184     pub fn get_mut(&mut self) -> &mut S {
185         self.stream.get_mut()
186     }
187 
188     /// Returns a shared reference to the `SslContext` of the stream.
189     #[inline(always)]
context(&self) -> &SslContext190     pub fn context(&self) -> &SslContext {
191         self.stream.context()
192     }
193 
194     /// Returns a mutable reference to the `SslContext` of the stream.
195     #[inline(always)]
context_mut(&mut self) -> &mut SslContext196     pub fn context_mut(&mut self) -> &mut SslContext {
197         self.stream.context_mut()
198     }
199 
200     /// Returns `true` iff `break_on_server_auth` was set and the handshake has
201     /// progressed to that point.
202     #[inline(always)]
server_auth_completed(&self) -> bool203     pub fn server_auth_completed(&self) -> bool {
204         self.error.code() == errSSLPeerAuthCompleted
205     }
206 
207     /// Returns `true` iff `break_on_cert_requested` was set and the handshake
208     /// has progressed to that point.
209     #[inline(always)]
client_cert_requested(&self) -> bool210     pub fn client_cert_requested(&self) -> bool {
211         self.error.code() == errSSLClientCertRequested
212     }
213 
214     /// Returns `true` iff the underlying stream returned an error with the
215     /// `WouldBlock` kind.
216     #[inline(always)]
would_block(&self) -> bool217     pub fn would_block(&self) -> bool {
218         self.error.code() == errSSLWouldBlock
219     }
220 
221     /// Returns the error which caused the handshake interruption.
222     #[inline(always)]
error(&self) -> &Error223     pub fn error(&self) -> &Error {
224         &self.error
225     }
226 
227     /// Restarts the handshake process.
228     #[inline(always)]
handshake(self) -> result::Result<SslStream<S>, HandshakeError<S>>229     pub fn handshake(self) -> result::Result<SslStream<S>, HandshakeError<S>> {
230         self.stream.handshake()
231     }
232 }
233 
234 /// An SSL stream midway through the handshake process.
235 #[derive(Debug)]
236 pub struct MidHandshakeClientBuilder<S> {
237     stream: MidHandshakeSslStream<S>,
238     domain: Option<String>,
239     certs: Vec<SecCertificate>,
240     trust_certs_only: bool,
241     danger_accept_invalid_certs: bool,
242 }
243 
244 impl<S> MidHandshakeClientBuilder<S> {
245     /// Returns a shared reference to the inner stream.
246     #[inline(always)]
get_ref(&self) -> &S247     pub fn get_ref(&self) -> &S {
248         self.stream.get_ref()
249     }
250 
251     /// Returns a mutable reference to the inner stream.
252     #[inline(always)]
get_mut(&mut self) -> &mut S253     pub fn get_mut(&mut self) -> &mut S {
254         self.stream.get_mut()
255     }
256 
257     /// Returns the error which caused the handshake interruption.
258     #[inline(always)]
error(&self) -> &Error259     pub fn error(&self) -> &Error {
260         self.stream.error()
261     }
262 
263     /// Restarts the handshake process.
handshake(self) -> result::Result<SslStream<S>, ClientHandshakeError<S>>264     pub fn handshake(self) -> result::Result<SslStream<S>, ClientHandshakeError<S>> {
265         let MidHandshakeClientBuilder {
266             stream,
267             domain,
268             certs,
269             trust_certs_only,
270             danger_accept_invalid_certs,
271         } = self;
272 
273         let mut result = stream.handshake();
274         loop {
275             let stream = match result {
276                 Ok(stream) => return Ok(stream),
277                 Err(HandshakeError::Interrupted(stream)) => stream,
278                 Err(HandshakeError::Failure(err)) => {
279                     return Err(ClientHandshakeError::Failure(err))
280                 }
281             };
282 
283             if stream.would_block() {
284                 let ret = MidHandshakeClientBuilder {
285                     stream,
286                     domain,
287                     certs,
288                     trust_certs_only,
289                     danger_accept_invalid_certs,
290                 };
291                 return Err(ClientHandshakeError::Interrupted(ret));
292             }
293 
294             if stream.server_auth_completed() {
295                 if danger_accept_invalid_certs {
296                     result = stream.handshake();
297                     continue;
298                 }
299                 let mut trust = match stream.context().peer_trust2()? {
300                     Some(trust) => trust,
301                     None => {
302                         result = stream.handshake();
303                         continue;
304                     }
305                 };
306                 trust.set_anchor_certificates(&certs)?;
307                 trust.set_trust_anchor_certificates_only(self.trust_certs_only)?;
308                 let policy =
309                     SecPolicy::create_ssl(SslProtocolSide::SERVER, domain.as_deref());
310                 trust.set_policy(&policy)?;
311                 trust.evaluate_with_error().map_err(|error| {
312                     #[cfg(feature = "log")]
313                     log::warn!("SecTrustEvaluateWithError: {}", error.to_string());
314                     Error::from_code(error.code() as _)
315                 })?;
316                 result = stream.handshake();
317                 continue;
318             }
319 
320             let err = Error::from_code(stream.error().code());
321             return Err(ClientHandshakeError::Failure(err));
322         }
323     }
324 }
325 
326 /// Specifies the state of a TLS session.
327 #[derive(Debug, PartialEq, Eq)]
328 pub struct SessionState(SSLSessionState);
329 
330 impl SessionState {
331     /// The session has not yet started.
332     pub const IDLE: Self = Self(kSSLIdle);
333 
334     /// The session is in the handshake process.
335     pub const HANDSHAKE: Self = Self(kSSLHandshake);
336 
337     /// The session is connected.
338     pub const CONNECTED: Self = Self(kSSLConnected);
339 
340     /// The session has been terminated.
341     pub const CLOSED: Self = Self(kSSLClosed);
342 
343     /// The session has been aborted due to an error.
344     pub const ABORTED: Self = Self(kSSLAborted);
345 }
346 
347 /// Specifies a server's requirement for client certificates.
348 #[derive(Debug, Copy, Clone, PartialEq, Eq)]
349 pub struct SslAuthenticate(SSLAuthenticate);
350 
351 impl SslAuthenticate {
352     /// Do not request a client certificate.
353     pub const NEVER: Self = Self(kNeverAuthenticate);
354 
355     /// Require a client certificate.
356     pub const ALWAYS: Self = Self(kAlwaysAuthenticate);
357 
358     /// Request but do not require a client certificate.
359     pub const TRY: Self = Self(kTryAuthenticate);
360 }
361 
362 /// Specifies the state of client certificate processing.
363 #[derive(Debug, Copy, Clone, PartialEq, Eq)]
364 pub struct SslClientCertificateState(SSLClientCertificateState);
365 
366 impl SslClientCertificateState {
367     /// A client certificate has not been requested or sent.
368     pub const NONE: Self = Self(kSSLClientCertNone);
369 
370     /// A client certificate has been requested but not recieved.
371     pub const REQUESTED: Self = Self(kSSLClientCertRequested);
372     /// A client certificate has been received and successfully validated.
373     pub const SENT: Self = Self(kSSLClientCertSent);
374 
375     /// A client certificate has been received but has failed to validate.
376     pub const REJECTED: Self = Self(kSSLClientCertRejected); }
377 
378 /// Specifies protocol versions.
379 #[derive(Debug, Copy, Clone, PartialEq, Eq)]
380 pub struct SslProtocol(SSLProtocol);
381 
382 impl SslProtocol {
383     /// No protocol has been or should be negotiated or specified; use the default.
384     pub const UNKNOWN: Self = Self(kSSLProtocolUnknown);
385 
386     /// The SSL 3.0 protocol is preferred, though SSL 2.0 may be used if the peer does not support
387     /// SSL 3.0.
388     pub const SSL3: Self = Self(kSSLProtocol3);
389 
390     /// The TLS 1.0 protocol is preferred, though lower versions may be used
391     /// if the peer does not support TLS 1.0.
392     pub const TLS1: Self = Self(kTLSProtocol1);
393 
394     /// The TLS 1.1 protocol is preferred, though lower versions may be used
395     /// if the peer does not support TLS 1.1.
396     pub const TLS11: Self = Self(kTLSProtocol11);
397 
398     /// The TLS 1.2 protocol is preferred, though lower versions may be used
399     /// if the peer does not support TLS 1.2.
400     pub const TLS12: Self = Self(kTLSProtocol12);
401 
402     /// The TLS 1.3 protocol is preferred, though lower versions may be used
403     /// if the peer does not support TLS 1.3.
404     pub const TLS13: Self = Self(kTLSProtocol13);
405 
406     /// Only the SSL 2.0 protocol is accepted.
407     pub const SSL2: Self = Self(kSSLProtocol2);
408 
409     /// The DTLSv1 protocol is preferred.
410     pub const DTLS1: Self = Self(kDTLSProtocol1);
411 
412     /// Only the SSL 3.0 protocol is accepted.
413     pub const SSL3_ONLY: Self = Self(kSSLProtocol3Only);
414 
415     /// Only the TLS 1.0 protocol is accepted.
416     pub const TLS1_ONLY: Self = Self(kTLSProtocol1Only);
417 
418     /// All supported TLS/SSL versions are accepted.
419     pub const ALL: Self = Self(kSSLProtocolAll);
420 }
421 
declare_TCFType! {
    /// A Secure Transport SSL/TLS context object.
    ///
    /// Wraps an `SSLContextRef`; configuration and inspection of a TLS
    /// session go through the methods on this type.
    SslContext, SSLContextRef
}

// Provide the `TCFType` plumbing for `SslContext`, using
// `SSLContextGetTypeID` as its CoreFoundation type ID.
impl_TCFType!(SslContext, SSLContextRef, SSLContextGetTypeID);
428 
429 impl fmt::Debug for SslContext {
430     #[cold]
fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result431     fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
432         let mut builder = fmt.debug_struct("SslContext");
433         if let Ok(state) = self.state() {
434             builder.field("state", &state);
435         }
436         builder.finish()
437     }
438 }
439 
// NOTE(review): these impls assert that an `SslContext` (an `SSLContextRef`)
// may be moved to and shared between threads. In this file, methods that change
// context configuration take `&mut self`, while `&self` methods only query it.
// Confirm this against Secure Transport's thread-safety guarantees.
unsafe impl Sync for SslContext {}
unsafe impl Send for SslContext {}
442 
443 impl AsInner for SslContext {
444     type Inner = SSLContextRef;
445 
446     #[inline(always)]
as_inner(&self) -> SSLContextRef447     fn as_inner(&self) -> SSLContextRef {
448         self.0
449     }
450 }
451 
// Generates a setter/getter pair for each listed Secure Transport session
// option. An entry `const OPTION: getter & setter,` expands to
// `setter(&mut self, bool)`, which calls `SSLSetSessionOption`, and
// `getter(&self) -> Result<bool>`, which calls `SSLGetSessionOption`.
// Attributes written on an entry are copied onto both generated methods.
macro_rules! impl_options {
    ($($(#[$attr:meta])* const $opt:ident: $get:ident & $set:ident,)*) => {
        $(
            $(#[$attr])*
            #[inline(always)]
            pub fn $set(&mut self, value: bool) -> Result<()> {
                let flag = value as Boolean;
                let status = unsafe { SSLSetSessionOption(self.0, $opt, flag) };
                cvt(status)
            }

            $(#[$attr])*
            #[inline]
            pub fn $get(&self) -> Result<bool> {
                let mut flag = 0;
                let status = unsafe { SSLGetSessionOption(self.0, $opt, &mut flag) };
                cvt(status)?;
                Ok(flag != 0)
            }
        )*
    };
}
471 
472 impl SslContext {
473     /// Creates a new `SslContext` for the specified side and type of SSL
474     /// connection.
475     #[inline]
new(side: SslProtocolSide, type_: SslConnectionType) -> Result<Self>476     pub fn new(side: SslProtocolSide, type_: SslConnectionType) -> Result<Self> {
477         unsafe {
478             let ctx = SSLCreateContext(kCFAllocatorDefault, side.0, type_.0);
479             Ok(Self(ctx))
480         }
481     }
482 
483     /// Sets the fully qualified domain name of the peer.
484     ///
485     /// This will be used on the client side of a session to validate the
486     /// common name field of the server's certificate. It has no effect if
487     /// called on a server-side `SslContext`.
488     ///
489     /// It is *highly* recommended to call this method before starting the
490     /// handshake process.
491     #[inline]
set_peer_domain_name(&mut self, peer_name: &str) -> Result<()>492     pub fn set_peer_domain_name(&mut self, peer_name: &str) -> Result<()> {
493         unsafe {
494             // SSLSetPeerDomainName doesn't need a null terminated string
495             cvt(SSLSetPeerDomainName(
496                 self.0,
497                 peer_name.as_ptr() as *const _,
498                 peer_name.len(),
499             ))
500         }
501     }
502 
503     /// Returns the peer domain name set by `set_peer_domain_name`.
peer_domain_name(&self) -> Result<String>504     pub fn peer_domain_name(&self) -> Result<String> {
505         unsafe {
506             let mut len = 0;
507             cvt(SSLGetPeerDomainNameLength(self.0, &mut len))?;
508             let mut buf = vec![0; len];
509             cvt(SSLGetPeerDomainName(
510                 self.0,
511                 buf.as_mut_ptr() as *mut _,
512                 &mut len,
513             ))?;
514             Ok(String::from_utf8(buf).unwrap())
515         }
516     }
517 
518     /// Sets the certificate to be used by this side of the SSL session.
519     ///
520     /// This must be called before the handshake for server-side connections,
521     /// and can be used on the client-side to specify a client certificate.
522     ///
523     /// The `identity` corresponds to the leaf certificate and private
524     /// key, and the `certs` correspond to extra certificates in the chain.
set_certificate( &mut self, identity: &SecIdentity, certs: &[SecCertificate], ) -> Result<()>525     pub fn set_certificate(
526         &mut self,
527         identity: &SecIdentity,
528         certs: &[SecCertificate],
529     ) -> Result<()> {
530         let mut arr = vec![identity.as_CFType()];
531         arr.extend(certs.iter().map(|c| c.as_CFType()));
532         let certs = CFArray::from_CFTypes(&arr);
533 
534         unsafe { cvt(SSLSetCertificate(self.0, certs.as_concrete_TypeRef())) }
535     }
536 
537     /// Sets the peer ID of this session.
538     ///
539     /// A peer ID is an opaque sequence of bytes that will be used by Secure
540     /// Transport to identify the peer of an SSL session. If the peer ID of
541     /// this session matches that of a previously terminated session, the
542     /// previous session can be resumed without requiring a full handshake.
543     #[inline]
set_peer_id(&mut self, peer_id: &[u8]) -> Result<()>544     pub fn set_peer_id(&mut self, peer_id: &[u8]) -> Result<()> {
545         unsafe {
546             cvt(SSLSetPeerID(
547                 self.0,
548                 peer_id.as_ptr() as *const _,
549                 peer_id.len(),
550             ))
551         }
552     }
553 
554     /// Returns the peer ID of this session.
peer_id(&self) -> Result<Option<&[u8]>>555     pub fn peer_id(&self) -> Result<Option<&[u8]>> {
556         unsafe {
557             let mut ptr = ptr::null();
558             let mut len = 0;
559             cvt(SSLGetPeerID(self.0, &mut ptr, &mut len))?;
560             if ptr.is_null() {
561                 Ok(None)
562             } else {
563                 Ok(Some(slice::from_raw_parts(ptr as *const _, len)))
564             }
565         }
566     }
567 
568     /// Returns the list of ciphers that are supported by Secure Transport.
supported_ciphers(&self) -> Result<Vec<CipherSuite>>569     pub fn supported_ciphers(&self) -> Result<Vec<CipherSuite>> {
570         unsafe {
571             let mut num_ciphers = 0;
572             cvt(SSLGetNumberSupportedCiphers(self.0, &mut num_ciphers))?;
573             let mut ciphers = vec![0; num_ciphers];
574             cvt(SSLGetSupportedCiphers(
575                 self.0,
576                 ciphers.as_mut_ptr(),
577                 &mut num_ciphers,
578             ))?;
579             Ok(ciphers.iter().map(|c| CipherSuite::from_raw(*c)).collect())
580         }
581     }
582 
583     /// Returns the list of ciphers that are eligible to be used for
584     /// negotiation.
enabled_ciphers(&self) -> Result<Vec<CipherSuite>>585     pub fn enabled_ciphers(&self) -> Result<Vec<CipherSuite>> {
586         unsafe {
587             let mut num_ciphers = 0;
588             cvt(SSLGetNumberEnabledCiphers(self.0, &mut num_ciphers))?;
589             let mut ciphers = vec![0; num_ciphers];
590             cvt(SSLGetEnabledCiphers(
591                 self.0,
592                 ciphers.as_mut_ptr(),
593                 &mut num_ciphers,
594             ))?;
595             Ok(ciphers.iter().map(|c| CipherSuite::from_raw(*c)).collect())
596         }
597     }
598 
599     /// Sets the list of ciphers that are eligible to be used for negotiation.
set_enabled_ciphers(&mut self, ciphers: &[CipherSuite]) -> Result<()>600     pub fn set_enabled_ciphers(&mut self, ciphers: &[CipherSuite]) -> Result<()> {
601         let ciphers = ciphers.iter().map(|c| c.to_raw()).collect::<Vec<_>>();
602         unsafe {
603             cvt(SSLSetEnabledCiphers(
604                 self.0,
605                 ciphers.as_ptr(),
606                 ciphers.len(),
607             ))
608         }
609     }
610 
611     /// Returns the cipher being used by the session.
612     #[inline]
negotiated_cipher(&self) -> Result<CipherSuite>613     pub fn negotiated_cipher(&self) -> Result<CipherSuite> {
614         unsafe {
615             let mut cipher = 0;
616             cvt(SSLGetNegotiatedCipher(self.0, &mut cipher))?;
617             Ok(CipherSuite::from_raw(cipher))
618         }
619     }
620 
621     /// Sets the requirements for client certificates.
622     ///
623     /// Should only be called on server-side sessions.
624     #[inline]
set_client_side_authenticate(&mut self, auth: SslAuthenticate) -> Result<()>625     pub fn set_client_side_authenticate(&mut self, auth: SslAuthenticate) -> Result<()> {
626         unsafe { cvt(SSLSetClientSideAuthenticate(self.0, auth.0)) }
627     }
628 
629     /// Returns the state of client certificate processing.
630     #[inline]
client_certificate_state(&self) -> Result<SslClientCertificateState>631     pub fn client_certificate_state(&self) -> Result<SslClientCertificateState> {
632         let mut state = 0;
633 
634         unsafe {
635             cvt(SSLGetClientCertificateState(self.0, &mut state))?;
636         }
637         Ok(SslClientCertificateState(state))
638     }
639 
640     /// Returns the `SecTrust` object corresponding to the peer.
641     ///
642     /// This can be used in conjunction with `set_break_on_server_auth` to
643     /// validate certificates which do not have roots in the default set.
peer_trust2(&self) -> Result<Option<SecTrust>>644     pub fn peer_trust2(&self) -> Result<Option<SecTrust>> {
645         // Calling SSLCopyPeerTrust on an idle connection does not seem to be well defined,
646         // so explicitly check for that
647         if self.state()? == SessionState::IDLE {
648             return Err(Error::from_code(errSecBadReq));
649         }
650 
651         unsafe {
652             let mut trust = ptr::null_mut();
653             cvt(SSLCopyPeerTrust(self.0, &mut trust))?;
654             if trust.is_null() {
655                 Ok(None)
656             } else {
657                 Ok(Some(SecTrust::wrap_under_create_rule(trust)))
658             }
659         }
660     }
661 
662     /// Returns the state of the session.
663     #[inline]
state(&self) -> Result<SessionState>664     pub fn state(&self) -> Result<SessionState> {
665         unsafe {
666             let mut state = 0;
667             cvt(SSLGetSessionState(self.0, &mut state))?;
668             Ok(SessionState(state))
669         }
670     }
671 
672     /// Returns the protocol version being used by the session.
673     #[inline]
negotiated_protocol_version(&self) -> Result<SslProtocol>674     pub fn negotiated_protocol_version(&self) -> Result<SslProtocol> {
675         unsafe {
676             let mut version = 0;
677             cvt(SSLGetNegotiatedProtocolVersion(self.0, &mut version))?;
678             Ok(SslProtocol(version))
679         }
680     }
681 
682     /// Returns the maximum protocol version allowed by the session.
683     #[inline]
protocol_version_max(&self) -> Result<SslProtocol>684     pub fn protocol_version_max(&self) -> Result<SslProtocol> {
685         unsafe {
686             let mut version = 0;
687             cvt(SSLGetProtocolVersionMax(self.0, &mut version))?;
688             Ok(SslProtocol(version))
689         }
690     }
691 
692     /// Sets the maximum protocol version allowed by the session.
693     #[inline]
set_protocol_version_max(&mut self, max_version: SslProtocol) -> Result<()>694     pub fn set_protocol_version_max(&mut self, max_version: SslProtocol) -> Result<()> {
695         unsafe { cvt(SSLSetProtocolVersionMax(self.0, max_version.0)) }
696     }
697 
698     /// Returns the minimum protocol version allowed by the session.
699     #[inline]
protocol_version_min(&self) -> Result<SslProtocol>700     pub fn protocol_version_min(&self) -> Result<SslProtocol> {
701         unsafe {
702             let mut version = 0;
703             cvt(SSLGetProtocolVersionMin(self.0, &mut version))?;
704             Ok(SslProtocol(version))
705         }
706     }
707 
708     /// Sets the minimum protocol version allowed by the session.
709     #[inline]
set_protocol_version_min(&mut self, min_version: SslProtocol) -> Result<()>710     pub fn set_protocol_version_min(&mut self, min_version: SslProtocol) -> Result<()> {
711         unsafe { cvt(SSLSetProtocolVersionMin(self.0, min_version.0)) }
712     }
713 
714     /// Returns the set of protocols selected via ALPN if it succeeded.
715     #[cfg(feature = "alpn")]
alpn_protocols(&self) -> Result<Vec<String>>716     pub fn alpn_protocols(&self) -> Result<Vec<String>> {
717         let mut array: CFArrayRef = ptr::null();
718         unsafe {
719             #[cfg(feature = "OSX_10_13")]
720             {
721                 cvt(SSLCopyALPNProtocols(self.0, &mut array))?;
722             }
723 
724             #[cfg(not(feature = "OSX_10_13"))]
725             {
726                 dlsym! { fn SSLCopyALPNProtocols(SSLContextRef, *mut CFArrayRef) -> OSStatus }
727                 if let Some(f) = SSLCopyALPNProtocols.get() {
728                     cvt(f(self.0, &mut array))?;
729                 } else {
730                     return Err(Error::from_code(errSecUnimplemented));
731                 }
732             }
733 
734             if array.is_null() {
735                 return Ok(vec![]);
736             }
737 
738             let array = CFArray::<CFString>::wrap_under_create_rule(array);
739             Ok(array.into_iter().map(|p| p.to_string()).collect())
740         }
741     }
742 
743     /// Configures the set of protocols use for ALPN.
744     ///
745     /// This is only used for client-side connections.
746     #[cfg(feature = "alpn")]
set_alpn_protocols(&mut self, protocols: &[&str]) -> Result<()>747     pub fn set_alpn_protocols(&mut self, protocols: &[&str]) -> Result<()> {
748         // When CFMutableArray is added to core-foundation and IntoIterator trait
749         // is implemented for CFMutableArray, the code below should directly collect
750         // into a CFMutableArray.
751         let protocols = CFArray::from_CFTypes(
752             &protocols
753                 .iter()
754                 .map(|proto| CFString::new(proto))
755                 .collect::<Vec<_>>(),
756         );
757 
758         #[cfg(feature = "OSX_10_13")]
759         {
760             unsafe { cvt(SSLSetALPNProtocols(self.0, protocols.as_concrete_TypeRef())) }
761         }
762         #[cfg(not(feature = "OSX_10_13"))]
763         {
764             dlsym! { fn SSLSetALPNProtocols(SSLContextRef, CFArrayRef) -> OSStatus }
765             if let Some(f) = SSLSetALPNProtocols.get() {
766                 unsafe { cvt(f(self.0, protocols.as_concrete_TypeRef())) }
767             } else {
768                 Err(Error::from_code(errSecUnimplemented))
769             }
770         }
771     }
772 
773     /// Sets whether the client sends the `SessionTicket` extension in its `ClientHello`.
774     ///
775     /// On its own, this will just cause the client to send an empty `SessionTicket` extension on
776     /// every connection. [`SslContext::set_peer_id`] must also be used to key the session
777     /// ticket returned by the server.
778     ///
779     /// [`SslContext::set_peer_id`]: #method.set_peer_id
780     #[cfg(feature = "session-tickets")]
set_session_tickets_enabled(&mut self, enabled: bool) -> Result<()>781     pub fn set_session_tickets_enabled(&mut self, enabled: bool) -> Result<()> {
782         #[cfg(feature = "OSX_10_13")]
783         {
784             unsafe { cvt(SSLSetSessionTicketsEnabled(self.0, enabled as Boolean)) }
785         }
786         #[cfg(not(feature = "OSX_10_13"))]
787         {
788             dlsym! { fn SSLSetSessionTicketsEnabled(SSLContextRef, Boolean) -> OSStatus }
789             if let Some(f) = SSLSetSessionTicketsEnabled.get() {
790                 unsafe { cvt(f(self.0, enabled as Boolean)) }
791             } else {
792                 Err(Error::from_code(errSecUnimplemented))
793             }
794         }
795     }
796 
797     /// Sets whether a protocol is enabled or not.
798     ///
799     /// # Note
800     ///
801     /// On OSX this is a deprecated API in favor of `set_protocol_version_max` and
802     /// `set_protocol_version_min`, although if you're working with OSX 10.8 or before you may have
803     /// to use this API instead.
804     #[cfg(target_os = "macos")]
805     #[deprecated(note = "use `set_protocol_version_max`")]
set_protocol_version_enabled( &mut self, protocol: SslProtocol, enabled: bool, ) -> Result<()>806     pub fn set_protocol_version_enabled(
807         &mut self,
808         protocol: SslProtocol,
809         enabled: bool,
810     ) -> Result<()> {
811         unsafe {
812             cvt(SSLSetProtocolVersionEnabled(
813                 self.0,
814                 protocol.0,
815                 enabled as Boolean,
816             ))
817         }
818     }
819 
820     /// Returns the number of bytes which can be read without triggering a
821     /// `read` call in the underlying stream.
822     #[inline]
buffered_read_size(&self) -> Result<usize>823     pub fn buffered_read_size(&self) -> Result<usize> {
824         unsafe {
825             let mut size = 0;
826             cvt(SSLGetBufferedReadSize(self.0, &mut size))?;
827             Ok(size)
828         }
829     }
830 
    // Session-option accessors. Each entry names a `kSSLSessionOption*`
    // constant followed by a getter and a setter name (`getter & setter`).
    // The `impl_options!` macro is defined outside this view, so check the
    // exact generated signatures there. The setters are called as
    // `ctx.set_break_on_server_auth(true)?`, i.e. they take a `bool` and
    // return a `Result`.
    impl_options! {
        // `ClientBuilder` enables this so that trust validation can be done by
        // `MidHandshakeClientBuilder` rather than automatically.
        /// If enabled, the handshake process will pause and return instead of
        /// automatically validating a server's certificate.
        const kSSLSessionOptionBreakOnServerAuth: break_on_server_auth & set_break_on_server_auth,
        /// If enabled, the handshake process will pause and return after
        /// the server requests a certificate from the client.
        const kSSLSessionOptionBreakOnCertRequested: break_on_cert_requested & set_break_on_cert_requested,
        /// If enabled, the handshake process will pause and return instead of
        /// automatically validating a client's certificate.
        const kSSLSessionOptionBreakOnClientAuth: break_on_client_auth & set_break_on_client_auth,
        /// If enabled, TLS false start will be performed if an appropriate
        /// cipher suite is negotiated.
        ///
        /// Requires the `OSX_10_9` (or greater) feature.
        #[cfg(feature = "OSX_10_9")]
        const kSSLSessionOptionFalseStart: false_start & set_false_start,
        /// If enabled, 1/n-1 record splitting will be enabled for TLS 1.0
        /// connections using block ciphers to mitigate the BEAST attack.
        ///
        /// Requires the `OSX_10_9` (or greater) feature.
        #[cfg(feature = "OSX_10_9")]
        const kSSLSessionOptionSendOneByteRecord: send_one_byte_record & set_send_one_byte_record,
    }
854 
into_stream<S>(self, stream: S) -> Result<SslStream<S>> where S: Read + Write,855     fn into_stream<S>(self, stream: S) -> Result<SslStream<S>>
856     where
857         S: Read + Write,
858     {
859         unsafe {
860             let ret = SSLSetIOFuncs(self.0, read_func::<S>, write_func::<S>);
861             if ret != errSecSuccess {
862                 return Err(Error::from_code(ret));
863             }
864 
865             let stream = Connection {
866                 stream,
867                 err: None,
868                 panic: None,
869             };
870             let stream = Box::into_raw(Box::new(stream));
871             let ret = SSLSetConnection(self.0, stream as *mut _);
872             if ret != errSecSuccess {
873                 let _conn = Box::from_raw(stream);
874                 return Err(Error::from_code(ret));
875             }
876 
877             Ok(SslStream {
878                 ctx: self,
879                 _m: PhantomData,
880             })
881         }
882     }
883 
884     /// Performs the SSL/TLS handshake.
handshake<S>(self, stream: S) -> result::Result<SslStream<S>, HandshakeError<S>> where S: Read + Write,885     pub fn handshake<S>(self, stream: S) -> result::Result<SslStream<S>, HandshakeError<S>>
886     where
887         S: Read + Write,
888     {
889         self.into_stream(stream)
890             .map_err(HandshakeError::Failure)
891             .and_then(SslStream::handshake)
892     }
893 }
894 
/// Per-stream state registered with Secure Transport as the connection pointer
/// (see `SslContext::into_stream`) and handed back to `read_func` / `write_func`.
struct Connection<S> {
    /// The underlying transport the TLS records are read from and written to.
    stream: S,
    /// Last I/O error returned by `stream`. It is stashed here so that
    /// `SslStream::get_error` can surface it instead of the generic status code.
    err: Option<io::Error>,
    /// Panic payload caught while calling into `stream`; re-raised by
    /// `SslStream::check_panic` once control is back in Rust.
    panic: Option<Box<dyn Any + Send>>,
}
900 
901 // the logic here is based off of libcurl's
902 #[cold]
translate_err(e: &io::Error) -> OSStatus903 fn translate_err(e: &io::Error) -> OSStatus {
904     match e.kind() {
905         io::ErrorKind::NotFound => errSSLClosedGraceful,
906         io::ErrorKind::ConnectionReset => errSSLClosedAbort,
907         io::ErrorKind::WouldBlock |
908         io::ErrorKind::NotConnected => errSSLWouldBlock,
909         _ => errSecIO,
910     }
911 }
912 
read_func<S>( connection: SSLConnectionRef, data: *mut c_void, data_length: *mut usize, ) -> OSStatus where S: Read,913 unsafe extern "C" fn read_func<S>(
914     connection: SSLConnectionRef,
915     data: *mut c_void,
916     data_length: *mut usize,
917 ) -> OSStatus
918 where
919     S: Read,
920 {
921     let conn: &mut Connection<S> = &mut *(connection as *mut _);
922     let data = slice::from_raw_parts_mut(data as *mut u8, *data_length);
923     let mut start = 0;
924     let mut ret = errSecSuccess;
925 
926     while start < data.len() {
927         match panic::catch_unwind(AssertUnwindSafe(|| conn.stream.read(&mut data[start..]))) {
928             Ok(Ok(0)) => {
929                 ret = errSSLClosedNoNotify;
930                 break;
931             }
932             Ok(Ok(len)) => start += len,
933             Ok(Err(e)) => {
934                 ret = translate_err(&e);
935                 conn.err = Some(e);
936                 break;
937             }
938             Err(e) => {
939                 ret = errSecIO;
940                 conn.panic = Some(e);
941                 break;
942             }
943         }
944     }
945 
946     *data_length = start;
947     ret
948 }
949 
write_func<S>( connection: SSLConnectionRef, data: *const c_void, data_length: *mut usize, ) -> OSStatus where S: Write,950 unsafe extern "C" fn write_func<S>(
951     connection: SSLConnectionRef,
952     data: *const c_void,
953     data_length: *mut usize,
954 ) -> OSStatus
955 where
956     S: Write,
957 {
958     let conn: &mut Connection<S> = &mut *(connection as *mut _);
959     let data = slice::from_raw_parts(data as *mut u8, *data_length);
960     let mut start = 0;
961     let mut ret = errSecSuccess;
962 
963     while start < data.len() {
964         match panic::catch_unwind(AssertUnwindSafe(|| conn.stream.write(&data[start..]))) {
965             Ok(Ok(0)) => {
966                 ret = errSSLClosedNoNotify;
967                 break;
968             }
969             Ok(Ok(len)) => start += len,
970             Ok(Err(e)) => {
971                 ret = translate_err(&e);
972                 conn.err = Some(e);
973                 break;
974             }
975             Err(e) => {
976                 ret = errSecIO;
977                 conn.panic = Some(e);
978                 break;
979             }
980         }
981     }
982 
983     *data_length = start;
984     ret
985 }
986 
/// A type implementing SSL/TLS encryption over an underlying stream.
///
/// The stream is not stored inline. It lives in a boxed `Connection<S>` that
/// `SslContext::into_stream` registers with the context via `SSLSetConnection`,
/// and it is freed again in this type's `Drop` impl.
pub struct SslStream<S> {
    /// Secure Transport context whose connection pointer owns the stream.
    ctx: SslContext,
    /// Records the stream type `S`, which is only reachable through `ctx`.
    _m: PhantomData<S>,
}
992 
993 impl<S: fmt::Debug> fmt::Debug for SslStream<S> {
994     #[cold]
fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result995     fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
996         fmt.debug_struct("SslStream")
997             .field("context", &self.ctx)
998             .field("stream", self.get_ref())
999             .finish()
1000     }
1001 }
1002 
1003 impl<S> Drop for SslStream<S> {
drop(&mut self)1004     fn drop(&mut self) {
1005         unsafe {
1006             let mut conn = ptr::null();
1007             let ret = SSLGetConnection(self.ctx.0, &mut conn);
1008             assert!(ret == errSecSuccess);
1009             Box::<Connection<S>>::from_raw(conn as *mut _);
1010         }
1011     }
1012 }
1013 
1014 impl<S> SslStream<S> {
handshake(mut self) -> result::Result<Self, HandshakeError<S>>1015     fn handshake(mut self) -> result::Result<Self, HandshakeError<S>> {
1016         match unsafe { SSLHandshake(self.ctx.0) } {
1017             errSecSuccess => Ok(self),
1018             reason @ errSSLPeerAuthCompleted
1019             | reason @ errSSLClientCertRequested
1020             | reason @ errSSLWouldBlock
1021             | reason @ errSSLClientHelloReceived => {
1022                 Err(HandshakeError::Interrupted(MidHandshakeSslStream {
1023                     stream: self,
1024                     error: Error::from_code(reason),
1025                 }))
1026             }
1027             err => {
1028                 self.check_panic();
1029                 Err(HandshakeError::Failure(Error::from_code(err)))
1030             }
1031         }
1032     }
1033 
1034     /// Returns a shared reference to the inner stream.
1035     #[inline(always)]
get_ref(&self) -> &S1036     pub fn get_ref(&self) -> &S {
1037         &self.connection().stream
1038     }
1039 
1040     /// Returns a mutable reference to the underlying stream.
1041     #[inline(always)]
get_mut(&mut self) -> &mut S1042     pub fn get_mut(&mut self) -> &mut S {
1043         &mut self.connection_mut().stream
1044     }
1045 
1046     /// Returns a shared reference to the `SslContext` of the stream.
1047     #[inline(always)]
context(&self) -> &SslContext1048     pub fn context(&self) -> &SslContext {
1049         &self.ctx
1050     }
1051 
1052     /// Returns a mutable reference to the `SslContext` of the stream.
1053     #[inline(always)]
context_mut(&mut self) -> &mut SslContext1054     pub fn context_mut(&mut self) -> &mut SslContext {
1055         &mut self.ctx
1056     }
1057 
1058     /// Shuts down the connection.
close(&mut self) -> result::Result<(), io::Error>1059     pub fn close(&mut self) -> result::Result<(), io::Error> {
1060         unsafe {
1061             let ret = SSLClose(self.ctx.0);
1062             if ret == errSecSuccess {
1063                 Ok(())
1064             } else {
1065                 Err(self.get_error(ret))
1066             }
1067         }
1068     }
1069 
connection(&self) -> &Connection<S>1070     fn connection(&self) -> &Connection<S> {
1071         unsafe {
1072             let mut conn = ptr::null();
1073             let ret = SSLGetConnection(self.ctx.0, &mut conn);
1074             assert!(ret == errSecSuccess);
1075 
1076             &mut *(conn as *mut Connection<S>)
1077         }
1078     }
1079 
connection_mut(&mut self) -> &mut Connection<S>1080     fn connection_mut(&mut self) -> &mut Connection<S> {
1081         unsafe {
1082             let mut conn = ptr::null();
1083             let ret = SSLGetConnection(self.ctx.0, &mut conn);
1084             assert!(ret == errSecSuccess);
1085 
1086             &mut *(conn as *mut Connection<S>)
1087         }
1088     }
1089 
1090     #[cold]
check_panic(&mut self)1091     fn check_panic(&mut self) {
1092         let conn = self.connection_mut();
1093         if let Some(err) = conn.panic.take() {
1094             panic::resume_unwind(err);
1095         }
1096     }
1097 
1098     #[cold]
get_error(&mut self, ret: OSStatus) -> io::Error1099     fn get_error(&mut self, ret: OSStatus) -> io::Error {
1100         self.check_panic();
1101 
1102         if let Some(err) = self.connection_mut().err.take() {
1103             err
1104         } else {
1105             io::Error::new(io::ErrorKind::Other, Error::from_code(ret))
1106         }
1107     }
1108 }
1109 
1110 impl<S: Read + Write> Read for SslStream<S> {
read(&mut self, buf: &mut [u8]) -> io::Result<usize>1111     fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
1112         // Below we base our return value off the amount of data read, so a
1113         // zero-length buffer might cause us to erroneously interpret this
1114         // request as an error. Instead short-circuit that logic and return
1115         // `Ok(0)` instead.
1116         if buf.is_empty() {
1117             return Ok(0);
1118         }
1119 
1120         // If some data was buffered but not enough to fill `buf`, SSLRead
1121         // will try to read a new packet. This is bad because there may be
1122         // no more data but the socket is remaining open (e.g HTTPS with
1123         // Connection: keep-alive).
1124         let buffered = self.context().buffered_read_size().unwrap_or(0);
1125         let to_read = if buffered > 0 {
1126             cmp::min(buffered, buf.len())
1127         } else {
1128             buf.len()
1129         };
1130 
1131         unsafe {
1132             let mut nread = 0;
1133             let ret = SSLRead(self.ctx.0, buf.as_mut_ptr() as *mut _, to_read, &mut nread);
1134             // SSLRead can return an error at the same time it returns the last
1135             // chunk of data (!)
1136             if nread > 0 {
1137                 return Ok(nread as usize);
1138             }
1139 
1140             match ret {
1141                 errSSLClosedGraceful | errSSLClosedAbort | errSSLClosedNoNotify => Ok(0),
1142                 _ => Err(self.get_error(ret)),
1143             }
1144         }
1145     }
1146 }
1147 
1148 impl<S: Read + Write> Write for SslStream<S> {
write(&mut self, buf: &[u8]) -> io::Result<usize>1149     fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
1150         // Like above in read, short circuit a 0-length write
1151         if buf.is_empty() {
1152             return Ok(0);
1153         }
1154         unsafe {
1155             let mut nwritten = 0;
1156             let ret = SSLWrite(
1157                 self.ctx.0,
1158                 buf.as_ptr() as *const _,
1159                 buf.len(),
1160                 &mut nwritten,
1161             );
1162             // just to be safe, base success off of nwritten rather than ret
1163             // for the same reason as in read
1164             if nwritten > 0 {
1165                 Ok(nwritten as usize)
1166             } else {
1167                 Err(self.get_error(ret))
1168             }
1169         }
1170     }
1171 
flush(&mut self) -> io::Result<()>1172     fn flush(&mut self) -> io::Result<()> {
1173         self.connection_mut().stream.flush()
1174     }
1175 }
1176 
/// A builder type to simplify the creation of client side `SslStream`s.
#[derive(Debug)]
pub struct ClientBuilder {
    /// Client certificate identity, set by `identity`.
    identity: Option<SecIdentity>,
    /// Anchor (root) certificates used when validating the server.
    certs: Vec<SecCertificate>,
    /// Certificate chain sent along with `identity`.
    chain: Vec<SecCertificate>,
    /// Lower protocol bound, applied only when set.
    protocol_min: Option<SslProtocol>,
    /// Upper protocol bound, applied only when set.
    protocol_max: Option<SslProtocol>,
    /// Whether only `certs` (and not the built-in roots) are trusted.
    trust_certs_only: bool,
    /// Whether the domain is sent via SNI.
    use_sni: bool,
    /// Whether certificate validation failures are ignored.
    danger_accept_invalid_certs: bool,
    /// Whether the hostname check against the certificate is skipped.
    danger_accept_invalid_hostnames: bool,
    /// If non-empty, the exact set of ciphers to enable.
    whitelisted_ciphers: Vec<CipherSuite>,
    /// Ciphers removed from the enabled set.
    blacklisted_ciphers: Vec<CipherSuite>,
    /// ALPN protocol names, applied with the `alpn` feature.
    alpn: Option<Vec<String>>,
    /// Whether RFC 5077 session tickets are used (`session-tickets` feature).
    enable_session_tickets: bool,
}
1194 
1195 impl Default for ClientBuilder {
1196     #[inline(always)]
default() -> Self1197     fn default() -> Self {
1198         Self::new()
1199     }
1200 }
1201 
1202 impl ClientBuilder {
1203     /// Creates a new builder with default options.
1204     #[inline]
new() -> Self1205     pub fn new() -> Self {
1206         Self {
1207             identity: None,
1208             certs: Vec::new(),
1209             chain: Vec::new(),
1210             protocol_min: None,
1211             protocol_max: None,
1212             trust_certs_only: false,
1213             use_sni: true,
1214             danger_accept_invalid_certs: false,
1215             danger_accept_invalid_hostnames: false,
1216             whitelisted_ciphers: Vec::new(),
1217             blacklisted_ciphers: Vec::new(),
1218             alpn: None,
1219             enable_session_tickets: false,
1220         }
1221     }
1222 
1223     /// Specifies the set of root certificates to trust when
1224     /// verifying the server's certificate.
1225     #[inline]
anchor_certificates(&mut self, certs: &[SecCertificate]) -> &mut Self1226     pub fn anchor_certificates(&mut self, certs: &[SecCertificate]) -> &mut Self {
1227         self.certs = certs.to_owned();
1228         self
1229     }
1230 
1231     /// Add the certificate the set of root certificates to trust
1232     /// when verifying the server's certificate.
1233     #[inline]
add_anchor_certificate(&mut self, certs: &SecCertificate) -> &mut Self1234     pub fn add_anchor_certificate(&mut self, certs: &SecCertificate) -> &mut Self {
1235         self.certs.push(certs.to_owned());
1236         self
1237     }
1238 
1239     /// Specifies whether to trust the built-in certificates in addition
1240     /// to specified anchor certificates.
1241     #[inline(always)]
trust_anchor_certificates_only(&mut self, only: bool) -> &mut Self1242     pub fn trust_anchor_certificates_only(&mut self, only: bool) -> &mut Self {
1243         self.trust_certs_only = only;
1244         self
1245     }
1246 
1247     /// Specifies whether to trust invalid certificates.
1248     ///
1249     /// # Warning
1250     ///
1251     /// You should think very carefully before using this method. If invalid
1252     /// certificates are trusted, *any* certificate for *any* site will be
1253     /// trusted for use. This includes expired certificates. This introduces
1254     /// significant vulnerabilities, and should only be used as a last resort.
1255     #[inline(always)]
danger_accept_invalid_certs(&mut self, noverify: bool) -> &mut Self1256     pub fn danger_accept_invalid_certs(&mut self, noverify: bool) -> &mut Self {
1257         self.danger_accept_invalid_certs = noverify;
1258         self
1259     }
1260 
1261     /// Specifies whether to use Server Name Indication (SNI).
1262     #[inline(always)]
use_sni(&mut self, use_sni: bool) -> &mut Self1263     pub fn use_sni(&mut self, use_sni: bool) -> &mut Self {
1264         self.use_sni = use_sni;
1265         self
1266     }
1267 
1268     /// Specifies whether to verify that the server's hostname matches its certificate.
1269     ///
1270     /// # Warning
1271     ///
1272     /// You should think very carefully before using this method. If hostnames are not verified,
1273     /// *any* valid certificate for *any* site will be trusted for use. This introduces significant
1274     /// vulnerabilities, and should only be used as a last resort.
1275     #[inline(always)]
danger_accept_invalid_hostnames( &mut self, danger_accept_invalid_hostnames: bool, ) -> &mut Self1276     pub fn danger_accept_invalid_hostnames(
1277         &mut self,
1278         danger_accept_invalid_hostnames: bool,
1279     ) -> &mut Self {
1280         self.danger_accept_invalid_hostnames = danger_accept_invalid_hostnames;
1281         self
1282     }
1283 
1284     /// Set a whitelist of enabled ciphers. Any ciphers not whitelisted will be disabled.
whitelist_ciphers(&mut self, whitelisted_ciphers: &[CipherSuite]) -> &mut Self1285     pub fn whitelist_ciphers(&mut self, whitelisted_ciphers: &[CipherSuite]) -> &mut Self {
1286         self.whitelisted_ciphers = whitelisted_ciphers.to_owned();
1287         self
1288     }
1289 
1290     /// Set a blacklist of disabled ciphers. Blacklisted ciphers will be disabled.
blacklist_ciphers(&mut self, blacklisted_ciphers: &[CipherSuite]) -> &mut Self1291     pub fn blacklist_ciphers(&mut self, blacklisted_ciphers: &[CipherSuite]) -> &mut Self {
1292         self.blacklisted_ciphers = blacklisted_ciphers.to_owned();
1293         self
1294     }
1295 
1296     /// Use the specified identity as a SSL/TLS client certificate.
identity(&mut self, identity: &SecIdentity, chain: &[SecCertificate]) -> &mut Self1297     pub fn identity(&mut self, identity: &SecIdentity, chain: &[SecCertificate]) -> &mut Self {
1298         self.identity = Some(identity.clone());
1299         self.chain = chain.to_owned();
1300         self
1301     }
1302 
1303     /// Configure the minimum protocol that this client will support.
1304     #[inline(always)]
protocol_min(&mut self, min: SslProtocol) -> &mut Self1305     pub fn protocol_min(&mut self, min: SslProtocol) -> &mut Self {
1306         self.protocol_min = Some(min);
1307         self
1308     }
1309 
1310     /// Configure the minimum protocol that this client will support.
1311     #[inline(always)]
protocol_max(&mut self, max: SslProtocol) -> &mut Self1312     pub fn protocol_max(&mut self, max: SslProtocol) -> &mut Self {
1313         self.protocol_max = Some(max);
1314         self
1315     }
1316 
1317     /// Configures the set of protocols used for ALPN.
1318     #[cfg(feature = "alpn")]
alpn_protocols(&mut self, protocols: &[&str]) -> &mut Self1319     pub fn alpn_protocols(&mut self, protocols: &[&str]) -> &mut Self {
1320         self.alpn = Some(protocols.iter().map(|s| s.to_string()).collect());
1321         self
1322     }
1323 
1324     /// Configures the use of the RFC 5077 `SessionTicket` extension.
1325     ///
1326     /// Defaults to `false`.
1327     #[cfg(feature = "session-tickets")]
1328     #[inline(always)]
enable_session_tickets(&mut self, enable: bool) -> &mut Self1329     pub fn enable_session_tickets(&mut self, enable: bool) -> &mut Self {
1330         self.enable_session_tickets = enable;
1331         self
1332     }
1333 
1334     /// Initiates a new SSL/TLS session over a stream connected to the specified domain.
1335     ///
1336     /// If both SNI and hostname verification are disabled, the value of `domain` will be ignored.
handshake<S>( &self, domain: &str, stream: S, ) -> result::Result<SslStream<S>, ClientHandshakeError<S>> where S: Read + Write,1337     pub fn handshake<S>(
1338         &self,
1339         domain: &str,
1340         stream: S,
1341     ) -> result::Result<SslStream<S>, ClientHandshakeError<S>>
1342     where
1343         S: Read + Write,
1344     {
1345         // the logic for trust validation is in MidHandshakeClientBuilder::connect, so run all
1346         // of the handshake logic through that.
1347         let stream = MidHandshakeSslStream {
1348             stream: self.ctx_into_stream(domain, stream)?,
1349             error: Error::from(errSecSuccess),
1350         };
1351 
1352         let certs = self.certs.clone();
1353         let stream = MidHandshakeClientBuilder {
1354             stream,
1355             domain: if self.danger_accept_invalid_hostnames {
1356                 None
1357             } else {
1358                 Some(domain.to_string())
1359             },
1360             certs,
1361             trust_certs_only: self.trust_certs_only,
1362             danger_accept_invalid_certs: self.danger_accept_invalid_certs,
1363         };
1364         stream.handshake()
1365     }
1366 
ctx_into_stream<S>(&self, domain: &str, stream: S) -> Result<SslStream<S>> where S: Read + Write,1367     fn ctx_into_stream<S>(&self, domain: &str, stream: S) -> Result<SslStream<S>>
1368     where
1369         S: Read + Write,
1370     {
1371         let mut ctx = SslContext::new(SslProtocolSide::CLIENT, SslConnectionType::STREAM)?;
1372 
1373         if self.use_sni {
1374             ctx.set_peer_domain_name(domain)?;
1375         }
1376         if let Some(ref identity) = self.identity {
1377             ctx.set_certificate(identity, &self.chain)?;
1378         }
1379         #[cfg(feature = "alpn")]
1380         {
1381             if let Some(ref alpn) = self.alpn {
1382                 ctx.set_alpn_protocols(&alpn.iter().map(|s| &**s).collect::<Vec<_>>())?;
1383             }
1384         }
1385         #[cfg(feature = "session-tickets")]
1386         {
1387             if self.enable_session_tickets {
1388                 // We must use the domain here to ensure that we go through certificate validation
1389                 // again rather than resuming the session if the domain changes.
1390                 ctx.set_peer_id(domain.as_bytes())?;
1391                 ctx.set_session_tickets_enabled(true)?;
1392             }
1393         }
1394         ctx.set_break_on_server_auth(true)?;
1395         self.configure_protocols(&mut ctx)?;
1396         self.configure_ciphers(&mut ctx)?;
1397 
1398         ctx.into_stream(stream)
1399     }
1400 
configure_protocols(&self, ctx: &mut SslContext) -> Result<()>1401     fn configure_protocols(&self, ctx: &mut SslContext) -> Result<()> {
1402         if let Some(min) = self.protocol_min {
1403             ctx.set_protocol_version_min(min)?;
1404         }
1405         if let Some(max) = self.protocol_max {
1406             ctx.set_protocol_version_max(max)?;
1407         }
1408         Ok(())
1409     }
1410 
configure_ciphers(&self, ctx: &mut SslContext) -> Result<()>1411     fn configure_ciphers(&self, ctx: &mut SslContext) -> Result<()> {
1412         let mut ciphers = if self.whitelisted_ciphers.is_empty() {
1413             ctx.enabled_ciphers()?
1414         } else {
1415             self.whitelisted_ciphers.clone()
1416         };
1417 
1418         if !self.blacklisted_ciphers.is_empty() {
1419             ciphers.retain(|cipher| !self.blacklisted_ciphers.contains(cipher));
1420         }
1421 
1422         ctx.set_enabled_ciphers(&ciphers)?;
1423         Ok(())
1424     }
1425 }
1426 
/// A builder type to simplify the creation of server-side `SslStream`s.
#[derive(Debug)]
pub struct ServerBuilder {
    /// Server identity presented during the handshake.
    identity: SecIdentity,
    /// Certificate chain sent along with `identity`.
    certs: Vec<SecCertificate>,
}
1433 
1434 impl ServerBuilder {
1435     /// Creates a new `ServerBuilder` which will use the specified identity
1436     /// and certificate chain for handshakes.
new(identity: &SecIdentity, certs: &[SecCertificate]) -> Self1437     pub fn new(identity: &SecIdentity, certs: &[SecCertificate]) -> Self {
1438         Self {
1439             identity: identity.clone(),
1440             certs: certs.to_owned(),
1441         }
1442     }
1443 
1444     /// Creates a new `ServerBuilder` which will use the identity
1445     /// from the given PKCS #12 data.
1446     ///
1447     /// This operation fails if PKCS #12 file contains zero or more than one identity.
1448     ///
1449     /// This is a shortcut for the most common operation.
from_pkcs12(pkcs12_der: &[u8], passphrase: &str) -> Result<Self>1450     pub fn from_pkcs12(pkcs12_der: &[u8], passphrase: &str) -> Result<Self> {
1451         let mut identities: Vec<(SecIdentity, Vec<SecCertificate>)> = Pkcs12ImportOptions::new()
1452             .passphrase(passphrase)
1453             .import(pkcs12_der)?
1454             .into_iter()
1455             .flat_map(|idendity| {
1456                 let certs = idendity.cert_chain.unwrap_or_default();
1457                 idendity.identity.map(|identity| (identity, certs))
1458             })
1459             .collect();
1460         if identities.len() == 1 {
1461             let (identity, certs) = identities.pop().unwrap();
1462             Ok(ServerBuilder::new(&identity, &certs))
1463         } else {
1464             // This error code is not really helpful
1465             Err(Error::from_code(errSecParam))
1466         }
1467     }
1468 
1469     /// Create a SSL context for lower-level stream initialization.
new_ssl_context(&self) -> Result<SslContext>1470     pub fn new_ssl_context(&self) -> Result<SslContext> {
1471         let mut ctx = SslContext::new(SslProtocolSide::SERVER, SslConnectionType::STREAM)?;
1472         ctx.set_certificate(&self.identity, &self.certs)?;
1473         Ok(ctx)
1474     }
1475 
1476     /// Initiates a new SSL/TLS session over a stream.
handshake<S>(&self, stream: S) -> Result<SslStream<S>> where S: Read + Write,1477     pub fn handshake<S>(&self, stream: S) -> Result<SslStream<S>>
1478     where
1479         S: Read + Write,
1480     {
1481         match self.new_ssl_context()?.handshake(stream) {
1482             Ok(stream) => Ok(stream),
1483             Err(HandshakeError::Interrupted(stream)) => Err(*stream.error()),
1484             Err(HandshakeError::Failure(err)) => Err(err),
1485         }
1486     }
1487 }
1488 
1489 #[cfg(test)]
1490 mod test {
1491     use std::io;
1492     use std::io::prelude::*;
1493     use std::net::TcpStream;
1494 
1495     use super::*;
1496 
1497     #[test]
server_builder_from_pkcs12()1498     fn server_builder_from_pkcs12() {
1499         let pkcs12_der = include_bytes!("../test/server.p12");
1500         ServerBuilder::from_pkcs12(pkcs12_der, "password123").unwrap();
1501     }
1502 
1503     #[test]
connect()1504     fn connect() {
1505         let mut ctx = p!(SslContext::new(
1506             SslProtocolSide::CLIENT,
1507             SslConnectionType::STREAM
1508         ));
1509         p!(ctx.set_peer_domain_name("google.com"));
1510         let stream = p!(TcpStream::connect("google.com:443"));
1511         p!(ctx.handshake(stream));
1512     }
1513 
1514     #[test]
connect_bad_domain()1515     fn connect_bad_domain() {
1516         let mut ctx = p!(SslContext::new(
1517             SslProtocolSide::CLIENT,
1518             SslConnectionType::STREAM
1519         ));
1520         p!(ctx.set_peer_domain_name("foobar.com"));
1521         let stream = p!(TcpStream::connect("google.com:443"));
1522         match ctx.handshake(stream) {
1523             Ok(_) => panic!("expected failure"),
1524             Err(_) => {}
1525         }
1526     }
1527 
1528     #[test]
load_page()1529     fn load_page() {
1530         let mut ctx = p!(SslContext::new(
1531             SslProtocolSide::CLIENT,
1532             SslConnectionType::STREAM
1533         ));
1534         p!(ctx.set_peer_domain_name("google.com"));
1535         let stream = p!(TcpStream::connect("google.com:443"));
1536         let mut stream = p!(ctx.handshake(stream));
1537         p!(stream.write_all(b"GET / HTTP/1.0\r\n\r\n"));
1538         p!(stream.flush());
1539         let mut buf = vec![];
1540         p!(stream.read_to_end(&mut buf));
1541         println!("{}", String::from_utf8_lossy(&buf));
1542     }
1543 
1544     #[test]
client_no_session_ticket_resumption()1545     fn client_no_session_ticket_resumption() {
1546         for _ in 0..2 {
1547             let stream = p!(TcpStream::connect("google.com:443"));
1548 
1549             // Manually handshake here.
1550             let stream = MidHandshakeSslStream {
1551                 stream: ClientBuilder::new()
1552                     .ctx_into_stream("google.com", stream)
1553                     .unwrap(),
1554                 error: Error::from(errSecSuccess),
1555             };
1556 
1557             let mut result = stream.handshake();
1558 
1559             if let Err(HandshakeError::Interrupted(stream)) = result {
1560                 assert!(stream.server_auth_completed());
1561                 result = stream.handshake();
1562             } else {
1563                 panic!("Unexpectedly skipped server auth");
1564             }
1565 
1566             assert!(result.is_ok());
1567         }
1568     }
1569 
1570     #[test]
1571     #[cfg(feature = "session-tickets")]
client_session_ticket_resumption()1572     fn client_session_ticket_resumption() {
1573         // The first time through this loop, we should do a full handshake. The second time, we
1574         // should immediately finish the handshake without breaking on server auth.
1575         for i in 0..2 {
1576             let stream = p!(TcpStream::connect("google.com:443"));
1577             let mut builder = ClientBuilder::new();
1578             builder.enable_session_tickets(true);
1579 
1580             // Manually handshake here.
1581             let stream = MidHandshakeSslStream {
1582                 stream: builder.ctx_into_stream("google.com", stream).unwrap(),
1583                 error: Error::from(errSecSuccess),
1584             };
1585 
1586             let mut result = stream.handshake();
1587 
1588             if let Err(HandshakeError::Interrupted(stream)) = result {
1589                 assert!(stream.server_auth_completed());
1590                 assert_eq!(
1591                     i, 0,
1592                     "Session ticket resumption did not work, server auth was not skipped"
1593                 );
1594                 result = stream.handshake();
1595             } else {
1596                 assert_eq!(i, 1, "Unexpectedly skipped server auth");
1597             }
1598 
1599             assert!(result.is_ok());
1600         }
1601     }
1602 
1603     #[test]
1604     #[cfg(feature = "alpn")]
client_alpn_accept()1605     fn client_alpn_accept() {
1606         let mut ctx = p!(SslContext::new(
1607             SslProtocolSide::CLIENT,
1608             SslConnectionType::STREAM
1609         ));
1610         p!(ctx.set_peer_domain_name("google.com"));
1611         p!(ctx.set_alpn_protocols(&vec!["h2"]));
1612         let stream = p!(TcpStream::connect("google.com:443"));
1613         let stream = ctx.handshake(stream).unwrap();
1614         assert_eq!(vec!["h2"], stream.context().alpn_protocols().unwrap());
1615     }
1616 
1617     #[test]
1618     #[cfg(feature = "alpn")]
client_alpn_reject()1619     fn client_alpn_reject() {
1620         let mut ctx = p!(SslContext::new(
1621             SslProtocolSide::CLIENT,
1622             SslConnectionType::STREAM
1623         ));
1624         p!(ctx.set_peer_domain_name("google.com"));
1625         p!(ctx.set_alpn_protocols(&vec!["h2c"]));
1626         let stream = p!(TcpStream::connect("google.com:443"));
1627         let stream = ctx.handshake(stream).unwrap();
1628         assert!(stream.context().alpn_protocols().is_err());
1629     }
1630 
1631     #[test]
client_no_anchor_certs()1632     fn client_no_anchor_certs() {
1633         let stream = p!(TcpStream::connect("google.com:443"));
1634         assert!(ClientBuilder::new()
1635             .trust_anchor_certificates_only(true)
1636             .handshake("google.com", stream)
1637             .is_err());
1638     }
1639 
1640     #[test]
client_bad_domain()1641     fn client_bad_domain() {
1642         let stream = p!(TcpStream::connect("google.com:443"));
1643         assert!(ClientBuilder::new()
1644             .handshake("foobar.com", stream)
1645             .is_err());
1646     }
1647 
1648     #[test]
client_bad_domain_ignored()1649     fn client_bad_domain_ignored() {
1650         let stream = p!(TcpStream::connect("google.com:443"));
1651         ClientBuilder::new()
1652             .danger_accept_invalid_hostnames(true)
1653             .handshake("foobar.com", stream)
1654             .unwrap();
1655     }
1656 
1657     #[test]
connect_no_verify_ssl()1658     fn connect_no_verify_ssl() {
1659         let stream = p!(TcpStream::connect("expired.badssl.com:443"));
1660         let mut builder = ClientBuilder::new();
1661         builder.danger_accept_invalid_certs(true);
1662         builder.handshake("expired.badssl.com", stream).unwrap();
1663     }
1664 
1665     #[test]
load_page_client()1666     fn load_page_client() {
1667         let stream = p!(TcpStream::connect("google.com:443"));
1668         let mut stream = p!(ClientBuilder::new().handshake("google.com", stream));
1669         p!(stream.write_all(b"GET / HTTP/1.0\r\n\r\n"));
1670         p!(stream.flush());
1671         let mut buf = vec![];
1672         p!(stream.read_to_end(&mut buf));
1673         println!("{}", String::from_utf8_lossy(&buf));
1674     }
1675 
1676     #[test]
1677     #[cfg_attr(target_os = "ios", ignore)] // FIXME what's going on with ios?
cipher_configuration()1678     fn cipher_configuration() {
1679         let mut ctx = p!(SslContext::new(
1680             SslProtocolSide::SERVER,
1681             SslConnectionType::STREAM
1682         ));
1683         let ciphers = p!(ctx.enabled_ciphers());
1684         let ciphers = ciphers
1685             .iter()
1686             .enumerate()
1687             .filter_map(|(i, c)| if i % 2 == 0 { Some(*c) } else { None })
1688             .collect::<Vec<_>>();
1689         p!(ctx.set_enabled_ciphers(&ciphers));
1690         assert_eq!(ciphers, p!(ctx.enabled_ciphers()));
1691     }
1692 
1693     #[test]
test_builder_whitelist_ciphers()1694     fn test_builder_whitelist_ciphers() {
1695         let stream = p!(TcpStream::connect("google.com:443"));
1696 
1697         let ctx = p!(SslContext::new(
1698             SslProtocolSide::CLIENT,
1699             SslConnectionType::STREAM
1700         ));
1701         assert!(p!(ctx.enabled_ciphers()).len() > 1);
1702 
1703         let ciphers = p!(ctx.enabled_ciphers());
1704         let cipher = ciphers.first().unwrap();
1705         let stream = p!(ClientBuilder::new()
1706             .whitelist_ciphers(&[*cipher])
1707             .ctx_into_stream("google.com", stream));
1708 
1709         assert_eq!(1, p!(stream.context().enabled_ciphers()).len());
1710     }
1711 
1712     #[test]
1713     #[cfg_attr(target_os = "ios", ignore)] // FIXME same issue as cipher_configuration
test_builder_blacklist_ciphers()1714     fn test_builder_blacklist_ciphers() {
1715         let stream = p!(TcpStream::connect("google.com:443"));
1716 
1717         let ctx = p!(SslContext::new(
1718             SslProtocolSide::CLIENT,
1719             SslConnectionType::STREAM
1720         ));
1721         let num = p!(ctx.enabled_ciphers()).len();
1722         assert!(num > 1);
1723 
1724         let ciphers = p!(ctx.enabled_ciphers());
1725         let cipher = ciphers.first().unwrap();
1726         let stream = p!(ClientBuilder::new()
1727             .blacklist_ciphers(&[*cipher])
1728             .ctx_into_stream("google.com", stream));
1729 
1730         assert_eq!(num - 1, p!(stream.context().enabled_ciphers()).len());
1731     }
1732 
1733     #[test]
idle_context_peer_trust()1734     fn idle_context_peer_trust() {
1735         let ctx = p!(SslContext::new(
1736             SslProtocolSide::SERVER,
1737             SslConnectionType::STREAM
1738         ));
1739         assert!(ctx.peer_trust2().is_err());
1740     }
1741 
1742     #[test]
peer_id()1743     fn peer_id() {
1744         let mut ctx = p!(SslContext::new(
1745             SslProtocolSide::SERVER,
1746             SslConnectionType::STREAM
1747         ));
1748         assert!(p!(ctx.peer_id()).is_none());
1749         p!(ctx.set_peer_id(b"foobar"));
1750         assert_eq!(p!(ctx.peer_id()), Some(&b"foobar"[..]));
1751     }
1752 
1753     #[test]
peer_domain_name()1754     fn peer_domain_name() {
1755         let mut ctx = p!(SslContext::new(
1756             SslProtocolSide::CLIENT,
1757             SslConnectionType::STREAM
1758         ));
1759         assert_eq!("", p!(ctx.peer_domain_name()));
1760         p!(ctx.set_peer_domain_name("foobar.com"));
1761         assert_eq!("foobar.com", p!(ctx.peer_domain_name()));
1762     }
1763 
1764     #[test]
1765     #[should_panic(expected = "blammo")]
write_panic()1766     fn write_panic() {
1767         struct ExplodingStream(TcpStream);
1768 
1769         impl Read for ExplodingStream {
1770             fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
1771                 self.0.read(buf)
1772             }
1773         }
1774 
1775         impl Write for ExplodingStream {
1776             fn write(&mut self, _: &[u8]) -> io::Result<usize> {
1777                 panic!("blammo");
1778             }
1779 
1780             fn flush(&mut self) -> io::Result<()> {
1781                 self.0.flush()
1782             }
1783         }
1784 
1785         let mut ctx = p!(SslContext::new(
1786             SslProtocolSide::CLIENT,
1787             SslConnectionType::STREAM
1788         ));
1789         p!(ctx.set_peer_domain_name("google.com"));
1790         let stream = p!(TcpStream::connect("google.com:443"));
1791         let _ = ctx.handshake(ExplodingStream(stream));
1792     }
1793 
1794     #[test]
1795     #[should_panic(expected = "blammo")]
read_panic()1796     fn read_panic() {
1797         struct ExplodingStream(TcpStream);
1798 
1799         impl Read for ExplodingStream {
1800             fn read(&mut self, _: &mut [u8]) -> io::Result<usize> {
1801                 panic!("blammo");
1802             }
1803         }
1804 
1805         impl Write for ExplodingStream {
1806             fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
1807                 self.0.write(buf)
1808             }
1809 
1810             fn flush(&mut self) -> io::Result<()> {
1811                 self.0.flush()
1812             }
1813         }
1814 
1815         let mut ctx = p!(SslContext::new(
1816             SslProtocolSide::CLIENT,
1817             SslConnectionType::STREAM
1818         ));
1819         p!(ctx.set_peer_domain_name("google.com"));
1820         let stream = p!(TcpStream::connect("google.com:443"));
1821         let _ = ctx.handshake(ExplodingStream(stream));
1822     }
1823 
1824     #[test]
zero_length_buffers()1825     fn zero_length_buffers() {
1826         let mut ctx = p!(SslContext::new(
1827             SslProtocolSide::CLIENT,
1828             SslConnectionType::STREAM
1829         ));
1830         p!(ctx.set_peer_domain_name("google.com"));
1831         let stream = p!(TcpStream::connect("google.com:443"));
1832         let mut stream = ctx.handshake(stream).unwrap();
1833         assert_eq!(stream.write(b"").unwrap(), 0);
1834         assert_eq!(stream.read(&mut []).unwrap(), 0);
1835     }
1836 }
1837