1 //! SSL/TLS encryption support using Secure Transport.
2 //!
3 //! # Examples
4 //!
5 //! To connect as a client to a server with a certificate trusted by the system:
6 //!
7 //! ```rust
8 //! use std::io::prelude::*;
9 //! use std::net::TcpStream;
10 //! use security_framework::secure_transport::ClientBuilder;
11 //!
12 //! let stream = TcpStream::connect("google.com:443").unwrap();
13 //! let mut stream = ClientBuilder::new().handshake("google.com", stream).unwrap();
14 //!
15 //! stream.write_all(b"GET / HTTP/1.0\r\n\r\n").unwrap();
16 //! let mut page = vec![];
17 //! stream.read_to_end(&mut page).unwrap();
18 //! println!("{}", String::from_utf8_lossy(&page));
19 //! ```
20 //!
21 //! To connect to a server with a certificate that's *not* trusted by the
22 //! system, specify the root certificates for the server's chain to the
23 //! `ClientBuilder`:
24 //!
25 //! ```rust,no_run
26 //! use std::io::prelude::*;
27 //! use std::net::TcpStream;
28 //! use security_framework::secure_transport::ClientBuilder;
29 //!
30 //! # let root_cert = unsafe { std::mem::zeroed() };
31 //! let stream = TcpStream::connect("my_server.com:443").unwrap();
32 //! let mut stream = ClientBuilder::new()
33 //!                      .anchor_certificates(&[root_cert])
34 //!                      .handshake("my_server.com", stream)
35 //!                      .unwrap();
36 //!
37 //! stream.write_all(b"GET / HTTP/1.0\r\n\r\n").unwrap();
38 //! let mut page = vec![];
39 //! stream.read_to_end(&mut page).unwrap();
40 //! println!("{}", String::from_utf8_lossy(&page));
41 //! ```
42 //!
43 //! For more advanced configuration, the `SslContext` type can be used directly.
44 //!
45 //! To run a server:
46 //!
47 //! ```rust,no_run
48 //! use std::net::TcpListener;
49 //! use std::thread;
50 //! use security_framework::secure_transport::{SslContext, SslProtocolSide, SslConnectionType};
51 //!
52 //! // Create a TCP listener and start accepting on it.
53 //! let mut listener = TcpListener::bind("0.0.0.0:443").unwrap();
54 //!
55 //! for stream in listener.incoming() {
56 //!     let stream = stream.unwrap();
57 //!     thread::spawn(move || {
58 //!         // Create a new context configured to operate on the server side of
59 //!         // a traditional SSL/TLS session.
60 //!         let mut ctx = SslContext::new(SslProtocolSide::SERVER, SslConnectionType::STREAM)
61 //!                           .unwrap();
62 //!
63 //!         // Install the certificate chain that we will be using.
64 //!         # let identity = unsafe { std::mem::zeroed() };
65 //!         # let intermediate_cert = unsafe { std::mem::zeroed() };
66 //!         # let root_cert = unsafe { std::mem::zeroed() };
67 //!         ctx.set_certificate(identity, &[intermediate_cert, root_cert]).unwrap();
68 //!
69 //!         // Perform the SSL/TLS handshake and get our stream.
70 //!         let mut stream = ctx.handshake(stream).unwrap();
71 //!     });
72 //! }
73 //!
74 //! ```
75 #[allow(unused_imports)]
76 use core_foundation::array::{CFArray, CFArrayRef};
77 
78 use core_foundation::base::{Boolean, TCFType};
79 #[cfg(feature = "alpn")]
80 use core_foundation::string::CFString;
81 use core_foundation_sys::base::{kCFAllocatorDefault, OSStatus};
82 use std::os::raw::c_void;
83 
84 #[allow(unused_imports)]
85 use security_framework_sys::base::{
86     errSecBadReq, errSecIO, errSecNotTrusted, errSecSuccess, errSecTrustSettingDeny,
87     errSecUnimplemented,
88 };
89 
90 use security_framework_sys::secure_transport::*;
91 use std::any::Any;
92 use std::cmp;
93 use std::fmt;
94 use std::io;
95 use std::io::prelude::*;
96 use std::marker::PhantomData;
97 use std::mem;
98 use std::panic::{self, AssertUnwindSafe};
99 use std::ptr;
100 use std::result;
101 use std::slice;
102 
103 use crate::base::{Error, Result};
104 use crate::certificate::SecCertificate;
105 use crate::cipher_suite::CipherSuite;
106 use crate::identity::SecIdentity;
107 use crate::policy::SecPolicy;
108 use crate::trust::{SecTrust, TrustResult};
109 use crate::{cvt, AsInner};
110 
111 /// Specifies a side of a TLS session.
112 #[derive(Debug, Copy, Clone, PartialEq, Eq)]
113 pub struct SslProtocolSide(SSLProtocolSide);
114 
115 impl SslProtocolSide {
116     /// The server side of the session.
117     pub const SERVER: Self = Self(kSSLServerSide);
118 
119     /// The client side of the session.
120     pub const CLIENT: Self = Self(kSSLClientSide);
121 }
122 
123 /// Specifies the type of TLS session.
124 #[derive(Debug, Copy, Clone)]
125 pub struct SslConnectionType(SSLConnectionType);
126 
127 impl SslConnectionType {
128     /// A traditional TLS stream.
129     pub const STREAM: Self = Self(kSSLStreamType);
130 
131     /// A DTLS session.
132     pub const DATAGRAM: Self = Self(kSSLDatagramType);
133 }
134 
135 /// An error or intermediate state after a TLS handshake attempt.
136 #[derive(Debug)]
137 pub enum HandshakeError<S> {
138     /// The handshake failed.
139     Failure(Error),
140     /// The handshake was interrupted midway through.
141     Interrupted(MidHandshakeSslStream<S>),
142 }
143 
144 impl<S> From<Error> for HandshakeError<S> {
from(err: Error) -> Self145     fn from(err: Error) -> Self {
146         Self::Failure(err)
147     }
148 }
149 
150 /// An error or intermediate state after a TLS handshake attempt.
151 #[derive(Debug)]
152 pub enum ClientHandshakeError<S> {
153     /// The handshake failed.
154     Failure(Error),
155     /// The handshake was interrupted midway through.
156     Interrupted(MidHandshakeClientBuilder<S>),
157 }
158 
159 impl<S> From<Error> for ClientHandshakeError<S> {
from(err: Error) -> Self160     fn from(err: Error) -> Self {
161         Self::Failure(err)
162     }
163 }
164 
165 /// An SSL stream midway through the handshake process.
166 #[derive(Debug)]
167 pub struct MidHandshakeSslStream<S> {
168     stream: SslStream<S>,
169     error: Error,
170 }
171 
172 impl<S> MidHandshakeSslStream<S> {
173     /// Returns a shared reference to the inner stream.
get_ref(&self) -> &S174     pub fn get_ref(&self) -> &S {
175         self.stream.get_ref()
176     }
177 
178     /// Returns a mutable reference to the inner stream.
get_mut(&mut self) -> &mut S179     pub fn get_mut(&mut self) -> &mut S {
180         self.stream.get_mut()
181     }
182 
183     /// Returns a shared reference to the `SslContext` of the stream.
context(&self) -> &SslContext184     pub fn context(&self) -> &SslContext {
185         self.stream.context()
186     }
187 
188     /// Returns a mutable reference to the `SslContext` of the stream.
context_mut(&mut self) -> &mut SslContext189     pub fn context_mut(&mut self) -> &mut SslContext {
190         self.stream.context_mut()
191     }
192 
193     /// Returns `true` iff `break_on_server_auth` was set and the handshake has
194     /// progressed to that point.
server_auth_completed(&self) -> bool195     pub fn server_auth_completed(&self) -> bool {
196         self.error.code() == errSSLPeerAuthCompleted
197     }
198 
199     /// Returns `true` iff `break_on_cert_requested` was set and the handshake
200     /// has progressed to that point.
client_cert_requested(&self) -> bool201     pub fn client_cert_requested(&self) -> bool {
202         self.error.code() == errSSLClientCertRequested
203     }
204 
205     /// Returns `true` iff the underlying stream returned an error with the
206     /// `WouldBlock` kind.
would_block(&self) -> bool207     pub fn would_block(&self) -> bool {
208         self.error.code() == errSSLWouldBlock
209     }
210 
211     /// Deprecated
reason(&self) -> OSStatus212     pub fn reason(&self) -> OSStatus {
213         self.error.code()
214     }
215 
216     /// Returns the error which caused the handshake interruption.
error(&self) -> &Error217     pub fn error(&self) -> &Error {
218         &self.error
219     }
220 
221     /// Restarts the handshake process.
handshake(self) -> result::Result<SslStream<S>, HandshakeError<S>>222     pub fn handshake(self) -> result::Result<SslStream<S>, HandshakeError<S>> {
223         self.stream.handshake()
224     }
225 }
226 
227 /// An SSL stream midway through the handshake process.
228 #[derive(Debug)]
229 pub struct MidHandshakeClientBuilder<S> {
230     stream: MidHandshakeSslStream<S>,
231     domain: Option<String>,
232     certs: Vec<SecCertificate>,
233     trust_certs_only: bool,
234     danger_accept_invalid_certs: bool,
235 }
236 
237 impl<S> MidHandshakeClientBuilder<S> {
238     /// Returns a shared reference to the inner stream.
get_ref(&self) -> &S239     pub fn get_ref(&self) -> &S {
240         self.stream.get_ref()
241     }
242 
243     /// Returns a mutable reference to the inner stream.
get_mut(&mut self) -> &mut S244     pub fn get_mut(&mut self) -> &mut S {
245         self.stream.get_mut()
246     }
247 
248     /// Returns the error which caused the handshake interruption.
error(&self) -> &Error249     pub fn error(&self) -> &Error {
250         self.stream.error()
251     }
252 
253     /// Restarts the handshake process.
handshake(self) -> result::Result<SslStream<S>, ClientHandshakeError<S>>254     pub fn handshake(self) -> result::Result<SslStream<S>, ClientHandshakeError<S>> {
255         let MidHandshakeClientBuilder {
256             stream,
257             domain,
258             certs,
259             trust_certs_only,
260             danger_accept_invalid_certs,
261         } = self;
262 
263         let mut result = stream.handshake();
264         loop {
265             let stream = match result {
266                 Ok(stream) => return Ok(stream),
267                 Err(HandshakeError::Interrupted(stream)) => stream,
268                 Err(HandshakeError::Failure(err)) => {
269                     return Err(ClientHandshakeError::Failure(err))
270                 }
271             };
272 
273             if stream.would_block() {
274                 let ret = MidHandshakeClientBuilder {
275                     stream,
276                     domain,
277                     certs,
278                     trust_certs_only,
279                     danger_accept_invalid_certs,
280                 };
281                 return Err(ClientHandshakeError::Interrupted(ret));
282             }
283 
284             if stream.server_auth_completed() {
285                 if danger_accept_invalid_certs {
286                     result = stream.handshake();
287                     continue;
288                 }
289                 let mut trust = match stream.context().peer_trust2()? {
290                     Some(trust) => trust,
291                     None => {
292                         result = stream.handshake();
293                         continue;
294                     }
295                 };
296                 trust.set_anchor_certificates(&certs)?;
297                 trust.set_trust_anchor_certificates_only(self.trust_certs_only)?;
298                 let policy =
299                     SecPolicy::create_ssl(SslProtocolSide::SERVER, domain.as_ref().map(|s| &**s));
300                 trust.set_policy(&policy)?;
301                 let trusted = trust.evaluate()?;
302                 match trusted {
303                     TrustResult::PROCEED | TrustResult::UNSPECIFIED => {
304                         result = stream.handshake();
305                         continue;
306                     }
307                     TrustResult::DENY => {
308                         let err = Error::from_code(errSecTrustSettingDeny);
309                         return Err(ClientHandshakeError::Failure(err));
310                     }
311                     TrustResult::RECOVERABLE_TRUST_FAILURE | TrustResult::FATAL_TRUST_FAILURE => {
312                         let err = Error::from_code(errSecNotTrusted);
313                         return Err(ClientHandshakeError::Failure(err));
314                     }
315                     TrustResult::INVALID | TrustResult::OTHER_ERROR | _ => {
316                         let err = Error::from_code(errSecBadReq);
317                         return Err(ClientHandshakeError::Failure(err));
318                     }
319                 }
320             }
321 
322             let err = Error::from_code(stream.reason());
323             return Err(ClientHandshakeError::Failure(err));
324         }
325     }
326 }
327 
328 /// Specifies the state of a TLS session.
329 #[derive(Debug, PartialEq, Eq)]
330 pub struct SessionState(SSLSessionState);
331 
332 impl SessionState {
333     /// The session has not yet started.
334     pub const IDLE: Self = Self(kSSLIdle);
335 
336     /// The session is in the handshake process.
337     pub const HANDSHAKE: Self = Self(kSSLHandshake);
338 
339     /// The session is connected.
340     pub const CONNECTED: Self = Self(kSSLConnected);
341 
342     /// The session has been terminated.
343     pub const CLOSED: Self = Self(kSSLClosed);
344 
345     /// The session has been aborted due to an error.
346     pub const ABORTED: Self = Self(kSSLAborted);
347 }
348 
349 /// Specifies a server's requirement for client certificates.
350 #[derive(Debug, Copy, Clone, PartialEq, Eq)]
351 pub struct SslAuthenticate(SSLAuthenticate);
352 
353 impl SslAuthenticate {
354     /// Do not request a client certificate.
355     pub const NEVER: Self = Self(kNeverAuthenticate);
356 
357     /// Require a client certificate.
358     pub const ALWAYS: Self = Self(kAlwaysAuthenticate);
359 
360     /// Request but do not require a client certificate.
361     pub const TRY: Self = Self(kTryAuthenticate);
362 }
363 
364 /// Specifies the state of client certificate processing.
365 #[derive(Debug, Copy, Clone, PartialEq, Eq)]
366 pub struct SslClientCertificateState(SSLClientCertificateState);
367 
368 impl SslClientCertificateState {
369     /// A client certificate has not been requested or sent.
370     pub const NONE: Self = Self(kSSLClientCertNone);
371 
372     /// A client certificate has been requested but not recieved.
373     pub const REQUESTED: Self = Self(kSSLClientCertRequested);
374     /// A client certificate has been received and successfully validated.
375     pub const SENT: Self = Self(kSSLClientCertSent);
376 
377     /// A client certificate has been received but has failed to validate.
378     pub const REJECTED: Self = Self(kSSLClientCertRejected); }
379 
380 /// Specifies protocol versions.
381 #[derive(Debug, Copy, Clone, PartialEq, Eq)]
382 pub struct SslProtocol(SSLProtocol);
383 
384 impl SslProtocol {
385     /// No protocol has been or should be negotiated or specified; use the default.
386     pub const UNKNOWN: Self = Self(kSSLProtocolUnknown);
387 
388     /// The SSL 3.0 protocol is preferred, though SSL 2.0 may be used if the peer does not support
389     /// SSL 3.0.
390     pub const SSL3: Self = Self(kSSLProtocol3);
391 
392     /// The TLS 1.0 protocol is preferred, though lower versions may be used
393     /// if the peer does not support TLS 1.0.
394     pub const TLS1: Self = Self(kTLSProtocol1);
395 
396     /// The TLS 1.1 protocol is preferred, though lower versions may be used
397     /// if the peer does not support TLS 1.1.
398     pub const TLS11: Self = Self(kTLSProtocol11);
399 
400     /// The TLS 1.2 protocol is preferred, though lower versions may be used
401     /// if the peer does not support TLS 1.2.
402     pub const TLS12: Self = Self(kTLSProtocol12);
403 
404     /// The TLS 1.3 protocol is preferred, though lower versions may be used
405     /// if the peer does not support TLS 1.3.
406     pub const TLS13: Self = Self(kTLSProtocol13);
407 
408     /// Only the SSL 2.0 protocol is accepted.
409     pub const SSL2: Self = Self(kSSLProtocol2);
410 
411     /// The DTLSv1 protocol is preferred.
412     pub const DTLS1: Self = Self(kDTLSProtocol1);
413 
414     /// Only the SSL 3.0 protocol is accepted.
415     pub const SSL3_ONLY: Self = Self(kSSLProtocol3Only);
416 
417     /// Only the TLS 1.0 protocol is accepted.
418     pub const TLS1_ONLY: Self = Self(kTLSProtocol1Only);
419 
420     /// All supported TLS/SSL versions are accepted.
421     pub const ALL: Self = Self(kSSLProtocolAll);
422 }
423 
declare_TCFType! {
    /// A Secure Transport SSL/TLS context object.
    ///
    /// Wraps a raw `SSLContextRef`; the methods in this module pass that
    /// raw reference to the Secure Transport functions.
    SslContext, SSLContextRef
}

// Ties `SslContext` to its raw reference type and to `SSLContextGetTypeID`.
// NOTE(review): presumably this generates the `TCFType` impl (wrap/retain/
// release plumbing) — confirm against the core-foundation macro docs.
impl_TCFType!(SslContext, SSLContextRef, SSLContextGetTypeID);
430 
431 impl fmt::Debug for SslContext {
fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result432     fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
433         let mut builder = fmt.debug_struct("SslContext");
434         if let Ok(state) = self.state() {
435             builder.field("state", &state);
436         }
437         builder.finish()
438     }
439 }
440 
// NOTE(review): these impls assert that an `SslContext` (a wrapper around a
// raw `SSLContextRef`) may be shared between and moved across threads.
// Nothing in this file demonstrates that guarantee — confirm against the
// Secure Transport documentation.
unsafe impl Sync for SslContext {}
unsafe impl Send for SslContext {}
443 
444 impl AsInner for SslContext {
445     type Inner = SSLContextRef;
446 
as_inner(&self) -> SSLContextRef447     fn as_inner(&self) -> SSLContextRef {
448         self.0
449     }
450 }
451 
// Generates a setter/getter pair for each boolean session option.
//
// Each entry has the form
// `#[attrs] const <option constant>: <getter name> & <setter name>,`
// and the attributes (including doc comments) are applied to both generated
// methods. The methods expect to be expanded inside an `impl` whose `self.0`
// is the raw context reference.
macro_rules! impl_options {
    ($($(#[$attr:meta])* const $option:ident: $getter:ident & $setter:ident,)*) => {
        $(
            $(#[$attr])*
            pub fn $setter(&mut self, value: bool) -> Result<()> {
                let status = unsafe { SSLSetSessionOption(self.0, $option, value as Boolean) };
                cvt(status)
            }

            $(#[$attr])*
            pub fn $getter(&self) -> Result<bool> {
                let mut raw = 0;
                let status = unsafe { SSLGetSessionOption(self.0, $option, &mut raw) };
                cvt(status)?;
                Ok(raw != 0)
            }
        )*
    }
}
469 
470 impl SslContext {
471     /// Creates a new `SslContext` for the specified side and type of SSL
472     /// connection.
new(side: SslProtocolSide, type_: SslConnectionType) -> Result<Self>473     pub fn new(side: SslProtocolSide, type_: SslConnectionType) -> Result<Self> {
474         unsafe {
475             let ctx = SSLCreateContext(kCFAllocatorDefault, side.0, type_.0);
476             Ok(Self(ctx))
477         }
478     }
479 
480     /// Sets the fully qualified domain name of the peer.
481     ///
482     /// This will be used on the client side of a session to validate the
483     /// common name field of the server's certificate. It has no effect if
484     /// called on a server-side `SslContext`.
485     ///
486     /// It is *highly* recommended to call this method before starting the
487     /// handshake process.
set_peer_domain_name(&mut self, peer_name: &str) -> Result<()>488     pub fn set_peer_domain_name(&mut self, peer_name: &str) -> Result<()> {
489         unsafe {
490             // SSLSetPeerDomainName doesn't need a null terminated string
491             cvt(SSLSetPeerDomainName(
492                 self.0,
493                 peer_name.as_ptr() as *const _,
494                 peer_name.len(),
495             ))
496         }
497     }
498 
499     /// Returns the peer domain name set by `set_peer_domain_name`.
peer_domain_name(&self) -> Result<String>500     pub fn peer_domain_name(&self) -> Result<String> {
501         unsafe {
502             let mut len = 0;
503             cvt(SSLGetPeerDomainNameLength(self.0, &mut len))?;
504             let mut buf = vec![0; len];
505             cvt(SSLGetPeerDomainName(
506                 self.0,
507                 buf.as_mut_ptr() as *mut _,
508                 &mut len,
509             ))?;
510             Ok(String::from_utf8(buf).unwrap())
511         }
512     }
513 
514     /// Sets the certificate to be used by this side of the SSL session.
515     ///
516     /// This must be called before the handshake for server-side connections,
517     /// and can be used on the client-side to specify a client certificate.
518     ///
519     /// The `identity` corresponds to the leaf certificate and private
520     /// key, and the `certs` correspond to extra certificates in the chain.
set_certificate( &mut self, identity: &SecIdentity, certs: &[SecCertificate], ) -> Result<()>521     pub fn set_certificate(
522         &mut self,
523         identity: &SecIdentity,
524         certs: &[SecCertificate],
525     ) -> Result<()> {
526         let mut arr = vec![identity.as_CFType()];
527         arr.extend(certs.iter().map(|c| c.as_CFType()));
528         let certs = CFArray::from_CFTypes(&arr);
529 
530         unsafe { cvt(SSLSetCertificate(self.0, certs.as_concrete_TypeRef())) }
531     }
532 
533     /// Sets the peer ID of this session.
534     ///
535     /// A peer ID is an opaque sequence of bytes that will be used by Secure
536     /// Transport to identify the peer of an SSL session. If the peer ID of
537     /// this session matches that of a previously terminated session, the
538     /// previous session can be resumed without requiring a full handshake.
set_peer_id(&mut self, peer_id: &[u8]) -> Result<()>539     pub fn set_peer_id(&mut self, peer_id: &[u8]) -> Result<()> {
540         unsafe {
541             cvt(SSLSetPeerID(
542                 self.0,
543                 peer_id.as_ptr() as *const _,
544                 peer_id.len(),
545             ))
546         }
547     }
548 
549     /// Returns the peer ID of this session.
peer_id(&self) -> Result<Option<&[u8]>>550     pub fn peer_id(&self) -> Result<Option<&[u8]>> {
551         unsafe {
552             let mut ptr = ptr::null();
553             let mut len = 0;
554             cvt(SSLGetPeerID(self.0, &mut ptr, &mut len))?;
555             if ptr.is_null() {
556                 Ok(None)
557             } else {
558                 Ok(Some(slice::from_raw_parts(ptr as *const _, len)))
559             }
560         }
561     }
562 
563     /// Returns the list of ciphers that are supported by Secure Transport.
supported_ciphers(&self) -> Result<Vec<CipherSuite>>564     pub fn supported_ciphers(&self) -> Result<Vec<CipherSuite>> {
565         unsafe {
566             let mut num_ciphers = 0;
567             cvt(SSLGetNumberSupportedCiphers(self.0, &mut num_ciphers))?;
568             let mut ciphers = vec![0; num_ciphers];
569             cvt(SSLGetSupportedCiphers(
570                 self.0,
571                 ciphers.as_mut_ptr(),
572                 &mut num_ciphers,
573             ))?;
574             Ok(ciphers.iter().map(|c| CipherSuite::from_raw(*c)).collect())
575         }
576     }
577 
578     /// Returns the list of ciphers that are eligible to be used for
579     /// negotiation.
enabled_ciphers(&self) -> Result<Vec<CipherSuite>>580     pub fn enabled_ciphers(&self) -> Result<Vec<CipherSuite>> {
581         unsafe {
582             let mut num_ciphers = 0;
583             cvt(SSLGetNumberEnabledCiphers(self.0, &mut num_ciphers))?;
584             let mut ciphers = vec![0; num_ciphers];
585             cvt(SSLGetEnabledCiphers(
586                 self.0,
587                 ciphers.as_mut_ptr(),
588                 &mut num_ciphers,
589             ))?;
590             Ok(ciphers.iter().map(|c| CipherSuite::from_raw(*c)).collect())
591         }
592     }
593 
594     /// Sets the list of ciphers that are eligible to be used for negotiation.
set_enabled_ciphers(&mut self, ciphers: &[CipherSuite]) -> Result<()>595     pub fn set_enabled_ciphers(&mut self, ciphers: &[CipherSuite]) -> Result<()> {
596         let ciphers = ciphers.iter().map(|c| c.to_raw()).collect::<Vec<_>>();
597         unsafe {
598             cvt(SSLSetEnabledCiphers(
599                 self.0,
600                 ciphers.as_ptr(),
601                 ciphers.len(),
602             ))
603         }
604     }
605 
606     /// Returns the cipher being used by the session.
negotiated_cipher(&self) -> Result<CipherSuite>607     pub fn negotiated_cipher(&self) -> Result<CipherSuite> {
608         unsafe {
609             let mut cipher = 0;
610             cvt(SSLGetNegotiatedCipher(self.0, &mut cipher))?;
611             Ok(CipherSuite::from_raw(cipher))
612         }
613     }
614 
615     /// Sets the requirements for client certificates.
616     ///
617     /// Should only be called on server-side sessions.
set_client_side_authenticate(&mut self, auth: SslAuthenticate) -> Result<()>618     pub fn set_client_side_authenticate(&mut self, auth: SslAuthenticate) -> Result<()> {
619         unsafe { cvt(SSLSetClientSideAuthenticate(self.0, auth.0)) }
620     }
621 
622     /// Returns the state of client certificate processing.
client_certificate_state(&self) -> Result<SslClientCertificateState>623     pub fn client_certificate_state(&self) -> Result<SslClientCertificateState> {
624         let mut state = 0;
625 
626         unsafe {
627             cvt(SSLGetClientCertificateState(self.0, &mut state))?;
628         }
629         Ok(SslClientCertificateState(state))
630     }
631 
632     /// Returns the `SecTrust` object corresponding to the peer.
633     ///
634     /// This can be used in conjunction with `set_break_on_server_auth` to
635     /// validate certificates which do not have roots in the default set.
peer_trust2(&self) -> Result<Option<SecTrust>>636     pub fn peer_trust2(&self) -> Result<Option<SecTrust>> {
637         // Calling SSLCopyPeerTrust on an idle connection does not seem to be well defined,
638         // so explicitly check for that
639         if self.state()? == SessionState::IDLE {
640             return Err(Error::from_code(errSecBadReq));
641         }
642 
643         unsafe {
644             let mut trust = ptr::null_mut();
645             cvt(SSLCopyPeerTrust(self.0, &mut trust))?;
646             if trust.is_null() {
647                 Ok(None)
648             } else {
649                 Ok(Some(SecTrust::wrap_under_create_rule(trust)))
650             }
651         }
652     }
653 
654     #[allow(missing_docs)]
655     #[deprecated(since = "0.2.1", note = "use peer_trust2 instead")]
peer_trust(&self) -> Result<SecTrust>656     pub fn peer_trust(&self) -> Result<SecTrust> {
657         match self.peer_trust2() {
658             Ok(Some(trust)) => Ok(trust),
659             Ok(None) => panic!("no trust available"),
660             Err(e) => Err(e),
661         }
662     }
663 
664     /// Returns the state of the session.
state(&self) -> Result<SessionState>665     pub fn state(&self) -> Result<SessionState> {
666         unsafe {
667             let mut state = 0;
668             cvt(SSLGetSessionState(self.0, &mut state))?;
669             Ok(SessionState(state))
670         }
671     }
672 
673     /// Returns the protocol version being used by the session.
negotiated_protocol_version(&self) -> Result<SslProtocol>674     pub fn negotiated_protocol_version(&self) -> Result<SslProtocol> {
675         unsafe {
676             let mut version = 0;
677             cvt(SSLGetNegotiatedProtocolVersion(self.0, &mut version))?;
678             Ok(SslProtocol(version))
679         }
680     }
681 
682     /// Returns the maximum protocol version allowed by the session.
protocol_version_max(&self) -> Result<SslProtocol>683     pub fn protocol_version_max(&self) -> Result<SslProtocol> {
684         unsafe {
685             let mut version = 0;
686             cvt(SSLGetProtocolVersionMax(self.0, &mut version))?;
687             Ok(SslProtocol(version))
688         }
689     }
690 
691     /// Sets the maximum protocol version allowed by the session.
set_protocol_version_max(&mut self, max_version: SslProtocol) -> Result<()>692     pub fn set_protocol_version_max(&mut self, max_version: SslProtocol) -> Result<()> {
693         unsafe { cvt(SSLSetProtocolVersionMax(self.0, max_version.0)) }
694     }
695 
696     /// Returns the minimum protocol version allowed by the session.
protocol_version_min(&self) -> Result<SslProtocol>697     pub fn protocol_version_min(&self) -> Result<SslProtocol> {
698         unsafe {
699             let mut version = 0;
700             cvt(SSLGetProtocolVersionMin(self.0, &mut version))?;
701             Ok(SslProtocol(version))
702         }
703     }
704 
705     /// Sets the minimum protocol version allowed by the session.
set_protocol_version_min(&mut self, min_version: SslProtocol) -> Result<()>706     pub fn set_protocol_version_min(&mut self, min_version: SslProtocol) -> Result<()> {
707         unsafe { cvt(SSLSetProtocolVersionMin(self.0, min_version.0)) }
708     }
709 
710     /// Returns the set of protocols selected via ALPN if it succeeded.
711     #[cfg(feature = "alpn")]
alpn_protocols(&self) -> Result<Vec<String>>712     pub fn alpn_protocols(&self) -> Result<Vec<String>> {
713         let mut array = ptr::null();
714         unsafe {
715             #[cfg(feature = "OSX_10_13")]
716             {
717                 cvt(SSLCopyALPNProtocols(self.0, &mut array))?;
718             }
719 
720             #[cfg(not(feature = "OSX_10_13"))]
721             {
722                 dlsym! { fn SSLCopyALPNProtocols(SSLContextRef, *mut CFArrayRef) -> OSStatus }
723                 if let Some(f) = SSLCopyALPNProtocols.get() {
724                     cvt(f(self.0, &mut array))?;
725                 } else {
726                     return Err(Error::from_code(errSecUnimplemented));
727                 }
728             }
729 
730             if array.is_null() {
731                 return Ok(vec![]);
732             }
733 
734             let array = CFArray::<CFString>::wrap_under_create_rule(array);
735             Ok(array.into_iter().map(|p| p.to_string()).collect())
736         }
737     }
738 
739     /// Configures the set of protocols use for ALPN.
740     ///
741     /// This is only used for client-side connections.
742     #[cfg(feature = "alpn")]
set_alpn_protocols(&mut self, protocols: &[&str]) -> Result<()>743     pub fn set_alpn_protocols(&mut self, protocols: &[&str]) -> Result<()> {
744         // When CFMutableArray is added to core-foundation and IntoIterator trait
745         // is implemented for CFMutableArray, the code below should directly collect
746         // into a CFMutableArray.
747         let protocols = CFArray::from_CFTypes(
748             &protocols
749                 .iter()
750                 .map(|proto| CFString::new(proto))
751                 .collect::<Vec<_>>(),
752         );
753 
754         #[cfg(feature = "OSX_10_13")]
755         {
756             unsafe { cvt(SSLSetALPNProtocols(self.0, protocols.as_concrete_TypeRef())) }
757         }
758         #[cfg(not(feature = "OSX_10_13"))]
759         {
760             dlsym! { fn SSLSetALPNProtocols(SSLContextRef, CFArrayRef) -> OSStatus }
761             if let Some(f) = SSLSetALPNProtocols.get() {
762                 unsafe { cvt(f(self.0, protocols.as_concrete_TypeRef())) }
763             } else {
764                 Err(Error::from_code(errSecUnimplemented))
765             }
766         }
767     }
768 
769     /// Sets whether a protocol is enabled or not.
770     ///
771     /// # Note
772     ///
773     /// On OSX this is a deprecated API in favor of `set_protocol_version_max` and
774     /// `set_protocol_version_min`, although if you're working with OSX 10.8 or before you may have
775     /// to use this API instead.
776     #[cfg(target_os = "macos")]
set_protocol_version_enabled( &mut self, protocol: SslProtocol, enabled: bool, ) -> Result<()>777     pub fn set_protocol_version_enabled(
778         &mut self,
779         protocol: SslProtocol,
780         enabled: bool,
781     ) -> Result<()> {
782         unsafe {
783             cvt(SSLSetProtocolVersionEnabled(
784                 self.0,
785                 protocol.0,
786                 enabled as Boolean,
787             ))
788         }
789     }
790 
791     /// Returns the number of bytes which can be read without triggering a
792     /// `read` call in the underlying stream.
buffered_read_size(&self) -> Result<usize>793     pub fn buffered_read_size(&self) -> Result<usize> {
794         unsafe {
795             let mut size = 0;
796             cvt(SSLGetBufferedReadSize(self.0, &mut size))?;
797             Ok(size)
798         }
799     }
800 
    // Each entry below pairs a `kSSLSessionOption*` constant with two method
    // names written as `name & set_name`.
    // NOTE(review): the `impl_options!` definition is not visible from here;
    // presumably it expands every entry into a getter and a setter for that
    // option -- confirm the generated signatures at the macro definition.
    impl_options! {
        /// If enabled, the handshake process will pause and return instead of
        /// automatically validating a server's certificate.
        const kSSLSessionOptionBreakOnServerAuth: break_on_server_auth & set_break_on_server_auth,
        /// If enabled, the handshake process will pause and return after
        /// the server requests a certificate from the client.
        const kSSLSessionOptionBreakOnCertRequested: break_on_cert_requested & set_break_on_cert_requested,
        /// If enabled, the handshake process will pause and return instead of
        /// automatically validating a client's certificate.
        const kSSLSessionOptionBreakOnClientAuth: break_on_client_auth & set_break_on_client_auth,
        /// If enabled, TLS false start will be performed if an appropriate
        /// cipher suite is negotiated.
        ///
        /// Requires the `OSX_10_9` (or greater) feature.
        #[cfg(feature = "OSX_10_9")]
        const kSSLSessionOptionFalseStart: false_start & set_false_start,
        /// If enabled, 1/n-1 record splitting will be enabled for TLS 1.0
        /// connections using block ciphers to mitigate the BEAST attack.
        ///
        /// Requires the `OSX_10_9` (or greater) feature.
        #[cfg(feature = "OSX_10_9")]
        const kSSLSessionOptionSendOneByteRecord: send_one_byte_record & set_send_one_byte_record,
    }
824 
into_stream<S>(self, stream: S) -> Result<SslStream<S>> where S: Read + Write,825     fn into_stream<S>(self, stream: S) -> Result<SslStream<S>>
826     where
827         S: Read + Write,
828     {
829         unsafe {
830             let ret = SSLSetIOFuncs(self.0, read_func::<S>, write_func::<S>);
831             if ret != errSecSuccess {
832                 return Err(Error::from_code(ret));
833             }
834 
835             let stream = Connection {
836                 stream,
837                 err: None,
838                 panic: None,
839             };
840             let stream = Box::into_raw(Box::new(stream));
841             let ret = SSLSetConnection(self.0, stream as *mut _);
842             if ret != errSecSuccess {
843                 let _conn = Box::from_raw(stream);
844                 return Err(Error::from_code(ret));
845             }
846 
847             Ok(SslStream {
848                 ctx: self,
849                 _m: PhantomData,
850             })
851         }
852     }
853 
854     /// Performs the SSL/TLS handshake.
handshake<S>(self, stream: S) -> result::Result<SslStream<S>, HandshakeError<S>> where S: Read + Write,855     pub fn handshake<S>(self, stream: S) -> result::Result<SslStream<S>, HandshakeError<S>>
856     where
857         S: Read + Write,
858     {
859         self.into_stream(stream)
860             .map_err(HandshakeError::Failure)
861             .and_then(SslStream::handshake)
862     }
863 }
864 
/// State shared with the Secure Transport I/O callbacks of one `SslStream`.
///
/// A boxed `Connection` is registered on the context with `SSLSetConnection`
/// and is handed back to `read_func`/`write_func` as the opaque connection
/// pointer.
struct Connection<S> {
    /// The underlying transport that the callbacks read from and write to.
    stream: S,
    /// The I/O error most recently returned by `stream` inside a callback.
    /// `SslStream::get_error` takes it so callers see the original error.
    err: Option<io::Error>,
    /// Payload of a panic caught inside a callback; it is re-raised later by
    /// `SslStream::check_panic`.
    panic: Option<Box<dyn Any + Send>>,
}
870 
871 // the logic here is based off of libcurl's
872 
translate_err(e: &io::Error) -> OSStatus873 fn translate_err(e: &io::Error) -> OSStatus {
874     match e.kind() {
875         io::ErrorKind::NotFound => errSSLClosedGraceful,
876         io::ErrorKind::ConnectionReset => errSSLClosedAbort,
877         io::ErrorKind::WouldBlock |
878         io::ErrorKind::NotConnected => errSSLWouldBlock,
879         _ => errSecIO,
880     }
881 }
882 
read_func<S>( connection: SSLConnectionRef, data: *mut c_void, data_length: *mut usize, ) -> OSStatus where S: Read,883 unsafe extern "C" fn read_func<S>(
884     connection: SSLConnectionRef,
885     data: *mut c_void,
886     data_length: *mut usize,
887 ) -> OSStatus
888 where
889     S: Read,
890 {
891     let conn: &mut Connection<S> = &mut *(connection as *mut _);
892     let data = slice::from_raw_parts_mut(data as *mut u8, *data_length);
893     let mut start = 0;
894     let mut ret = errSecSuccess;
895 
896     while start < data.len() {
897         match panic::catch_unwind(AssertUnwindSafe(|| conn.stream.read(&mut data[start..]))) {
898             Ok(Ok(0)) => {
899                 ret = errSSLClosedNoNotify;
900                 break;
901             }
902             Ok(Ok(len)) => start += len,
903             Ok(Err(e)) => {
904                 ret = translate_err(&e);
905                 conn.err = Some(e);
906                 break;
907             }
908             Err(e) => {
909                 ret = errSecIO;
910                 conn.panic = Some(e);
911                 break;
912             }
913         }
914     }
915 
916     *data_length = start;
917     ret
918 }
919 
write_func<S>( connection: SSLConnectionRef, data: *const c_void, data_length: *mut usize, ) -> OSStatus where S: Write,920 unsafe extern "C" fn write_func<S>(
921     connection: SSLConnectionRef,
922     data: *const c_void,
923     data_length: *mut usize,
924 ) -> OSStatus
925 where
926     S: Write,
927 {
928     let conn: &mut Connection<S> = &mut *(connection as *mut _);
929     let data = slice::from_raw_parts(data as *mut u8, *data_length);
930     let mut start = 0;
931     let mut ret = errSecSuccess;
932 
933     while start < data.len() {
934         match panic::catch_unwind(AssertUnwindSafe(|| conn.stream.write(&data[start..]))) {
935             Ok(Ok(0)) => {
936                 ret = errSSLClosedNoNotify;
937                 break;
938             }
939             Ok(Ok(len)) => start += len,
940             Ok(Err(e)) => {
941                 ret = translate_err(&e);
942                 conn.err = Some(e);
943                 break;
944             }
945             Err(e) => {
946                 ret = errSecIO;
947                 conn.panic = Some(e);
948                 break;
949             }
950         }
951     }
952 
953     *data_length = start;
954     ret
955 }
956 
/// A type implementing SSL/TLS encryption over an underlying stream.
///
/// The underlying stream is not stored as a field: it lives in a heap
/// allocated `Connection<S>` that is registered on the context and looked up
/// again through `SSLGetConnection` whenever it is needed.
pub struct SslStream<S> {
    /// The Secure Transport context that owns the session state.
    ctx: SslContext,
    /// Records the stream type `S`, which otherwise appears in no field.
    _m: PhantomData<S>,
}
962 
963 impl<S: fmt::Debug> fmt::Debug for SslStream<S> {
fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result964     fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
965         fmt.debug_struct("SslStream")
966             .field("context", &self.ctx)
967             .field("stream", self.get_ref())
968             .finish()
969     }
970 }
971 
972 impl<S> Drop for SslStream<S> {
drop(&mut self)973     fn drop(&mut self) {
974         unsafe {
975             let mut conn = ptr::null();
976             let ret = SSLGetConnection(self.ctx.0, &mut conn);
977             assert!(ret == errSecSuccess);
978             Box::<Connection<S>>::from_raw(conn as *mut _);
979         }
980     }
981 }
982 
983 impl<S> SslStream<S> {
handshake(mut self) -> result::Result<Self, HandshakeError<S>>984     fn handshake(mut self) -> result::Result<Self, HandshakeError<S>> {
985         match unsafe { SSLHandshake(self.ctx.0) } {
986             errSecSuccess => Ok(self),
987             reason @ errSSLPeerAuthCompleted
988             | reason @ errSSLClientCertRequested
989             | reason @ errSSLWouldBlock
990             | reason @ errSSLClientHelloReceived => {
991                 Err(HandshakeError::Interrupted(MidHandshakeSslStream {
992                     stream: self,
993                     error: Error::from_code(reason),
994                 }))
995             }
996             err => {
997                 self.check_panic();
998                 Err(HandshakeError::Failure(Error::from_code(err)))
999             }
1000         }
1001     }
1002 
1003     /// Returns a shared reference to the inner stream.
get_ref(&self) -> &S1004     pub fn get_ref(&self) -> &S {
1005         &self.connection().stream
1006     }
1007 
1008     /// Returns a mutable reference to the underlying stream.
get_mut(&mut self) -> &mut S1009     pub fn get_mut(&mut self) -> &mut S {
1010         &mut self.connection_mut().stream
1011     }
1012 
1013     /// Returns a shared reference to the `SslContext` of the stream.
context(&self) -> &SslContext1014     pub fn context(&self) -> &SslContext {
1015         &self.ctx
1016     }
1017 
1018     /// Returns a mutable reference to the `SslContext` of the stream.
context_mut(&mut self) -> &mut SslContext1019     pub fn context_mut(&mut self) -> &mut SslContext {
1020         &mut self.ctx
1021     }
1022 
1023     /// Shuts down the connection.
close(&mut self) -> result::Result<(), io::Error>1024     pub fn close(&mut self) -> result::Result<(), io::Error> {
1025         unsafe {
1026             let ret = SSLClose(self.ctx.0);
1027             if ret == errSecSuccess {
1028                 Ok(())
1029             } else {
1030                 Err(self.get_error(ret))
1031             }
1032         }
1033     }
1034 
connection(&self) -> &Connection<S>1035     fn connection(&self) -> &Connection<S> {
1036         unsafe {
1037             let mut conn = ptr::null();
1038             let ret = SSLGetConnection(self.ctx.0, &mut conn);
1039             assert!(ret == errSecSuccess);
1040 
1041             mem::transmute(conn)
1042         }
1043     }
1044 
connection_mut(&mut self) -> &mut Connection<S>1045     fn connection_mut(&mut self) -> &mut Connection<S> {
1046         unsafe {
1047             let mut conn = ptr::null();
1048             let ret = SSLGetConnection(self.ctx.0, &mut conn);
1049             assert!(ret == errSecSuccess);
1050 
1051             mem::transmute(conn)
1052         }
1053     }
1054 
check_panic(&mut self)1055     fn check_panic(&mut self) {
1056         let conn = self.connection_mut();
1057         if let Some(err) = conn.panic.take() {
1058             panic::resume_unwind(err);
1059         }
1060     }
1061 
get_error(&mut self, ret: OSStatus) -> io::Error1062     fn get_error(&mut self, ret: OSStatus) -> io::Error {
1063         self.check_panic();
1064 
1065         if let Some(err) = self.connection_mut().err.take() {
1066             err
1067         } else {
1068             io::Error::new(io::ErrorKind::Other, Error::from_code(ret))
1069         }
1070     }
1071 }
1072 
1073 impl<S: Read + Write> Read for SslStream<S> {
read(&mut self, buf: &mut [u8]) -> io::Result<usize>1074     fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
1075         // Below we base our return value off the amount of data read, so a
1076         // zero-length buffer might cause us to erroneously interpret this
1077         // request as an error. Instead short-circuit that logic and return
1078         // `Ok(0)` instead.
1079         if buf.is_empty() {
1080             return Ok(0);
1081         }
1082 
1083         // If some data was buffered but not enough to fill `buf`, SSLRead
1084         // will try to read a new packet. This is bad because there may be
1085         // no more data but the socket is remaining open (e.g HTTPS with
1086         // Connection: keep-alive).
1087         let buffered = self.context().buffered_read_size().unwrap_or(0);
1088         let to_read = if buffered > 0 {
1089             cmp::min(buffered, buf.len())
1090         } else {
1091             buf.len()
1092         };
1093 
1094         unsafe {
1095             let mut nread = 0;
1096             let ret = SSLRead(self.ctx.0, buf.as_mut_ptr() as *mut _, to_read, &mut nread);
1097             // SSLRead can return an error at the same time it returns the last
1098             // chunk of data (!)
1099             if nread > 0 {
1100                 return Ok(nread as usize);
1101             }
1102 
1103             match ret {
1104                 errSSLClosedGraceful | errSSLClosedAbort | errSSLClosedNoNotify => Ok(0),
1105                 _ => Err(self.get_error(ret)),
1106             }
1107         }
1108     }
1109 }
1110 
1111 impl<S: Read + Write> Write for SslStream<S> {
write(&mut self, buf: &[u8]) -> io::Result<usize>1112     fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
1113         // Like above in read, short circuit a 0-length write
1114         if buf.is_empty() {
1115             return Ok(0);
1116         }
1117         unsafe {
1118             let mut nwritten = 0;
1119             let ret = SSLWrite(
1120                 self.ctx.0,
1121                 buf.as_ptr() as *const _,
1122                 buf.len(),
1123                 &mut nwritten,
1124             );
1125             // just to be safe, base success off of nwritten rather than ret
1126             // for the same reason as in read
1127             if nwritten > 0 {
1128                 Ok(nwritten as usize)
1129             } else {
1130                 Err(self.get_error(ret))
1131             }
1132         }
1133     }
1134 
flush(&mut self) -> io::Result<()>1135     fn flush(&mut self) -> io::Result<()> {
1136         self.connection_mut().stream.flush()
1137     }
1138 }
1139 
/// A builder type to simplify the creation of client side `SslStream`s.
#[derive(Debug)]
pub struct ClientBuilder {
    /// Client identity set through `identity`, if any.
    identity: Option<SecIdentity>,
    /// Anchor (root) certificates set through `anchor_certificates`.
    certs: Vec<SecCertificate>,
    /// Certificate chain that accompanies `identity`.
    chain: Vec<SecCertificate>,
    /// Lower protocol bound applied to the context when set.
    protocol_min: Option<SslProtocol>,
    /// Upper protocol bound applied to the context when set.
    protocol_max: Option<SslProtocol>,
    /// Set through `trust_anchor_certificates_only`; forwarded to the
    /// mid-handshake builder that performs trust evaluation.
    trust_certs_only: bool,
    /// Whether the peer domain name is set on the context (SNI).
    use_sni: bool,
    /// Set through `danger_accept_invalid_certs`; forwarded to the
    /// mid-handshake builder.
    danger_accept_invalid_certs: bool,
    /// When true, no domain is passed on for hostname verification.
    danger_accept_invalid_hostnames: bool,
    /// If non-empty, the starting set of enabled ciphers.
    whitelisted_ciphers: Vec<CipherSuite>,
    /// Ciphers removed from the enabled set.
    blacklisted_ciphers: Vec<CipherSuite>,
    /// Protocols to offer via ALPN, if configured.
    alpn: Option<Vec<String>>,
}
1156 
1157 impl Default for ClientBuilder {
default() -> Self1158     fn default() -> Self {
1159         Self::new()
1160     }
1161 }
1162 
1163 impl ClientBuilder {
1164     /// Creates a new builder with default options.
new() -> Self1165     pub fn new() -> Self {
1166         Self {
1167             identity: None,
1168             certs: Vec::new(),
1169             chain: Vec::new(),
1170             protocol_min: None,
1171             protocol_max: None,
1172             trust_certs_only: false,
1173             use_sni: true,
1174             danger_accept_invalid_certs: false,
1175             danger_accept_invalid_hostnames: false,
1176             whitelisted_ciphers: Vec::new(),
1177             blacklisted_ciphers: Vec::new(),
1178             alpn: None,
1179         }
1180     }
1181 
1182     /// Specifies the set of root certificates to trust when
1183     /// verifying the server's certificate.
anchor_certificates(&mut self, certs: &[SecCertificate]) -> &mut Self1184     pub fn anchor_certificates(&mut self, certs: &[SecCertificate]) -> &mut Self {
1185         self.certs = certs.to_owned();
1186         self
1187     }
1188 
1189     /// Specifies whether to trust the built-in certificates in addition
1190     /// to specified anchor certificates.
trust_anchor_certificates_only(&mut self, only: bool) -> &mut Self1191     pub fn trust_anchor_certificates_only(&mut self, only: bool) -> &mut Self {
1192         self.trust_certs_only = only;
1193         self
1194     }
1195 
1196     /// Specifies whether to trust invalid certificates.
1197     ///
1198     /// # Warning
1199     ///
1200     /// You should think very carefully before using this method. If invalid
1201     /// certificates are trusted, *any* certificate for *any* site will be
1202     /// trusted for use. This includes expired certificates. This introduces
1203     /// significant vulnerabilities, and should only be used as a last resort.
danger_accept_invalid_certs(&mut self, noverify: bool) -> &mut Self1204     pub fn danger_accept_invalid_certs(&mut self, noverify: bool) -> &mut Self {
1205         self.danger_accept_invalid_certs = noverify;
1206         self
1207     }
1208 
1209     /// Specifies whether to use Server Name Indication (SNI).
use_sni(&mut self, use_sni: bool) -> &mut Self1210     pub fn use_sni(&mut self, use_sni: bool) -> &mut Self {
1211         self.use_sni = use_sni;
1212         self
1213     }
1214 
1215     /// Specifies whether to verify that the server's hostname matches its certificate.
1216     ///
1217     /// # Warning
1218     ///
1219     /// You should think very carefully before using this method. If hostnames are not verified,
1220     /// *any* valid certificate for *any* site will be trusted for use. This introduces significant
1221     /// vulnerabilities, and should only be used as a last resort.
danger_accept_invalid_hostnames( &mut self, danger_accept_invalid_hostnames: bool, ) -> &mut Self1222     pub fn danger_accept_invalid_hostnames(
1223         &mut self,
1224         danger_accept_invalid_hostnames: bool,
1225     ) -> &mut Self {
1226         self.danger_accept_invalid_hostnames = danger_accept_invalid_hostnames;
1227         self
1228     }
1229 
1230     /// Set a whitelist of enabled ciphers. Any ciphers not whitelisted will be disabled.
whitelist_ciphers(&mut self, whitelisted_ciphers: &[CipherSuite]) -> &mut Self1231     pub fn whitelist_ciphers(&mut self, whitelisted_ciphers: &[CipherSuite]) -> &mut Self {
1232         self.whitelisted_ciphers = whitelisted_ciphers.to_owned();
1233         self
1234     }
1235 
1236     /// Set a blacklist of disabled ciphers. Blacklisted ciphers will be disabled.
blacklist_ciphers(&mut self, blacklisted_ciphers: &[CipherSuite]) -> &mut Self1237     pub fn blacklist_ciphers(&mut self, blacklisted_ciphers: &[CipherSuite]) -> &mut Self {
1238         self.blacklisted_ciphers = blacklisted_ciphers.to_owned();
1239         self
1240     }
1241 
1242     /// Use the specified identity as a SSL/TLS client certificate.
identity(&mut self, identity: &SecIdentity, chain: &[SecCertificate]) -> &mut Self1243     pub fn identity(&mut self, identity: &SecIdentity, chain: &[SecCertificate]) -> &mut Self {
1244         self.identity = Some(identity.clone());
1245         self.chain = chain.to_owned();
1246         self
1247     }
1248 
1249     /// Configure the minimum protocol that this client will support.
protocol_min(&mut self, min: SslProtocol) -> &mut Self1250     pub fn protocol_min(&mut self, min: SslProtocol) -> &mut Self {
1251         self.protocol_min = Some(min);
1252         self
1253     }
1254 
1255     /// Configure the minimum protocol that this client will support.
protocol_max(&mut self, max: SslProtocol) -> &mut Self1256     pub fn protocol_max(&mut self, max: SslProtocol) -> &mut Self {
1257         self.protocol_max = Some(max);
1258         self
1259     }
1260 
1261     /// Configures the set of protocols used for ALPN.
1262     #[cfg(feature = "alpn")]
alpn_protocols(&mut self, protocols: &[&str]) -> &mut Self1263     pub fn alpn_protocols(&mut self, protocols: &[&str]) -> &mut Self {
1264         self.alpn = Some(protocols.iter().map(|s| s.to_string()).collect());
1265         self
1266     }
1267 
1268     /// Initiates a new SSL/TLS session over a stream connected to the specified domain.
1269     ///
1270     /// If both SNI and hostname verification are disabled, the value of `domain` will be ignored.
handshake<S>( &self, domain: &str, stream: S, ) -> result::Result<SslStream<S>, ClientHandshakeError<S>> where S: Read + Write,1271     pub fn handshake<S>(
1272         &self,
1273         domain: &str,
1274         stream: S,
1275     ) -> result::Result<SslStream<S>, ClientHandshakeError<S>>
1276     where
1277         S: Read + Write,
1278     {
1279         // the logic for trust validation is in MidHandshakeClientBuilder::connect, so run all
1280         // of the handshake logic through that.
1281         let stream = MidHandshakeSslStream {
1282             stream: self.ctx_into_stream(domain, stream)?,
1283             error: Error::from(errSecSuccess),
1284         };
1285 
1286         let certs = self.certs.clone();
1287         let stream = MidHandshakeClientBuilder {
1288             stream,
1289             domain: if self.danger_accept_invalid_hostnames {
1290                 None
1291             } else {
1292                 Some(domain.to_string())
1293             },
1294             certs,
1295             trust_certs_only: self.trust_certs_only,
1296             danger_accept_invalid_certs: self.danger_accept_invalid_certs,
1297         };
1298         stream.handshake()
1299     }
1300 
ctx_into_stream<S>(&self, domain: &str, stream: S) -> Result<SslStream<S>> where S: Read + Write,1301     fn ctx_into_stream<S>(&self, domain: &str, stream: S) -> Result<SslStream<S>>
1302     where
1303         S: Read + Write,
1304     {
1305         let mut ctx = SslContext::new(SslProtocolSide::CLIENT, SslConnectionType::STREAM)?;
1306 
1307         if self.use_sni {
1308             ctx.set_peer_domain_name(domain)?;
1309         }
1310         if let Some(ref identity) = self.identity {
1311             ctx.set_certificate(identity, &self.chain)?;
1312         }
1313         #[cfg(feature = "alpn")]
1314         {
1315             if let Some(ref alpn) = self.alpn {
1316                 ctx.set_alpn_protocols(&alpn.iter().map(|s| &**s).collect::<Vec<_>>())?;
1317             }
1318         }
1319         ctx.set_break_on_server_auth(true)?;
1320         self.configure_protocols(&mut ctx)?;
1321         self.configure_ciphers(&mut ctx)?;
1322 
1323         ctx.into_stream(stream)
1324     }
1325 
configure_protocols(&self, ctx: &mut SslContext) -> Result<()>1326     fn configure_protocols(&self, ctx: &mut SslContext) -> Result<()> {
1327         if let Some(min) = self.protocol_min {
1328             ctx.set_protocol_version_min(min)?;
1329         }
1330         if let Some(max) = self.protocol_max {
1331             ctx.set_protocol_version_max(max)?;
1332         }
1333         Ok(())
1334     }
1335 
configure_ciphers(&self, ctx: &mut SslContext) -> Result<()>1336     fn configure_ciphers(&self, ctx: &mut SslContext) -> Result<()> {
1337         let mut ciphers = if self.whitelisted_ciphers.is_empty() {
1338             ctx.enabled_ciphers()?
1339         } else {
1340             self.whitelisted_ciphers.clone()
1341         };
1342 
1343         if !self.blacklisted_ciphers.is_empty() {
1344             ciphers.retain(|cipher| !self.blacklisted_ciphers.contains(cipher));
1345         }
1346 
1347         ctx.set_enabled_ciphers(&ciphers)?;
1348         Ok(())
1349     }
1350 }
1351 
/// A builder type to simplify the creation of server-side `SslStream`s.
#[derive(Debug)]
pub struct ServerBuilder {
    /// Identity set as the context's certificate for each handshake.
    identity: SecIdentity,
    /// Certificate chain passed along with `identity`.
    certs: Vec<SecCertificate>,
}
1358 
1359 impl ServerBuilder {
1360     /// Creates a new `ServerBuilder` which will use the specified identity
1361     /// and certificate chain for handshakes.
new(identity: &SecIdentity, certs: &[SecCertificate]) -> Self1362     pub fn new(identity: &SecIdentity, certs: &[SecCertificate]) -> Self {
1363         Self {
1364             identity: identity.clone(),
1365             certs: certs.to_owned(),
1366         }
1367     }
1368 
1369     /// Initiates a new SSL/TLS session over a stream.
handshake<S>(&self, stream: S) -> Result<SslStream<S>> where S: Read + Write,1370     pub fn handshake<S>(&self, stream: S) -> Result<SslStream<S>>
1371     where
1372         S: Read + Write,
1373     {
1374         let mut ctx = SslContext::new(SslProtocolSide::SERVER, SslConnectionType::STREAM)?;
1375         ctx.set_certificate(&self.identity, &self.certs)?;
1376         match ctx.handshake(stream) {
1377             Ok(stream) => Ok(stream),
1378             Err(HandshakeError::Interrupted(stream)) => Err(Error::from_code(stream.reason())),
1379             Err(HandshakeError::Failure(err)) => Err(err),
1380         }
1381     }
1382 }
1383 
1384 #[cfg(test)]
1385 mod test {
1386     use std::io;
1387     use std::io::prelude::*;
1388     use std::net::TcpStream;
1389 
1390     use super::*;
1391 
1392     #[test]
connect()1393     fn connect() {
1394         let mut ctx = p!(SslContext::new(
1395             SslProtocolSide::CLIENT,
1396             SslConnectionType::STREAM
1397         ));
1398         p!(ctx.set_peer_domain_name("google.com"));
1399         let stream = p!(TcpStream::connect("google.com:443"));
1400         p!(ctx.handshake(stream));
1401     }
1402 
1403     #[test]
connect_bad_domain()1404     fn connect_bad_domain() {
1405         let mut ctx = p!(SslContext::new(
1406             SslProtocolSide::CLIENT,
1407             SslConnectionType::STREAM
1408         ));
1409         p!(ctx.set_peer_domain_name("foobar.com"));
1410         let stream = p!(TcpStream::connect("google.com:443"));
1411         match ctx.handshake(stream) {
1412             Ok(_) => panic!("expected failure"),
1413             Err(_) => {}
1414         }
1415     }
1416 
1417     #[test]
load_page()1418     fn load_page() {
1419         let mut ctx = p!(SslContext::new(
1420             SslProtocolSide::CLIENT,
1421             SslConnectionType::STREAM
1422         ));
1423         p!(ctx.set_peer_domain_name("google.com"));
1424         let stream = p!(TcpStream::connect("google.com:443"));
1425         let mut stream = p!(ctx.handshake(stream));
1426         p!(stream.write_all(b"GET / HTTP/1.0\r\n\r\n"));
1427         p!(stream.flush());
1428         let mut buf = vec![];
1429         p!(stream.read_to_end(&mut buf));
1430         println!("{}", String::from_utf8_lossy(&buf));
1431     }
1432 
1433     #[test]
1434     #[cfg(feature = "alpn")]
client_alpn_accept()1435     fn client_alpn_accept() {
1436         let mut ctx = p!(SslContext::new(
1437             SslProtocolSide::CLIENT,
1438             SslConnectionType::STREAM
1439         ));
1440         p!(ctx.set_peer_domain_name("google.com"));
1441         p!(ctx.set_alpn_protocols(&vec!["h2"]));
1442         let stream = p!(TcpStream::connect("google.com:443"));
1443         let stream = ctx.handshake(stream).unwrap();
1444         assert_eq!(vec!["h2"], stream.context().alpn_protocols().unwrap());
1445     }
1446 
1447     #[test]
1448     #[cfg(feature = "alpn")]
client_alpn_reject()1449     fn client_alpn_reject() {
1450         let mut ctx = p!(SslContext::new(
1451             SslProtocolSide::CLIENT,
1452             SslConnectionType::STREAM
1453         ));
1454         p!(ctx.set_peer_domain_name("google.com"));
1455         p!(ctx.set_alpn_protocols(&vec!["h2c"]));
1456         let stream = p!(TcpStream::connect("google.com:443"));
1457         let stream = ctx.handshake(stream).unwrap();
1458         assert!(stream.context().alpn_protocols().is_err());
1459     }
1460 
1461     #[test]
client_no_anchor_certs()1462     fn client_no_anchor_certs() {
1463         let stream = p!(TcpStream::connect("google.com:443"));
1464         assert!(ClientBuilder::new()
1465             .trust_anchor_certificates_only(true)
1466             .handshake("google.com", stream)
1467             .is_err());
1468     }
1469 
1470     #[test]
client_bad_domain()1471     fn client_bad_domain() {
1472         let stream = p!(TcpStream::connect("google.com:443"));
1473         assert!(ClientBuilder::new()
1474             .handshake("foobar.com", stream)
1475             .is_err());
1476     }
1477 
1478     #[test]
client_bad_domain_ignored()1479     fn client_bad_domain_ignored() {
1480         let stream = p!(TcpStream::connect("google.com:443"));
1481         ClientBuilder::new()
1482             .danger_accept_invalid_hostnames(true)
1483             .handshake("foobar.com", stream)
1484             .unwrap();
1485     }
1486 
1487     #[test]
connect_no_verify_ssl()1488     fn connect_no_verify_ssl() {
1489         let stream = p!(TcpStream::connect("expired.badssl.com:443"));
1490         let mut builder = ClientBuilder::new();
1491         builder.danger_accept_invalid_certs(true);
1492         builder.handshake("expired.badssl.com", stream).unwrap();
1493     }
1494 
1495     #[test]
load_page_client()1496     fn load_page_client() {
1497         let stream = p!(TcpStream::connect("google.com:443"));
1498         let mut stream = p!(ClientBuilder::new().handshake("google.com", stream));
1499         p!(stream.write_all(b"GET / HTTP/1.0\r\n\r\n"));
1500         p!(stream.flush());
1501         let mut buf = vec![];
1502         p!(stream.read_to_end(&mut buf));
1503         println!("{}", String::from_utf8_lossy(&buf));
1504     }
1505 
1506     #[test]
1507     #[cfg_attr(target_os = "ios", ignore)] // FIXME what's going on with ios?
cipher_configuration()1508     fn cipher_configuration() {
1509         let mut ctx = p!(SslContext::new(
1510             SslProtocolSide::SERVER,
1511             SslConnectionType::STREAM
1512         ));
1513         let ciphers = p!(ctx.enabled_ciphers());
1514         let ciphers = ciphers
1515             .iter()
1516             .enumerate()
1517             .filter_map(|(i, c)| if i % 2 == 0 { Some(*c) } else { None })
1518             .collect::<Vec<_>>();
1519         p!(ctx.set_enabled_ciphers(&ciphers));
1520         assert_eq!(ciphers, p!(ctx.enabled_ciphers()));
1521     }
1522 
1523     #[test]
test_builder_whitelist_ciphers()1524     fn test_builder_whitelist_ciphers() {
1525         let stream = p!(TcpStream::connect("google.com:443"));
1526 
1527         let ctx = p!(SslContext::new(
1528             SslProtocolSide::CLIENT,
1529             SslConnectionType::STREAM
1530         ));
1531         assert!(p!(ctx.enabled_ciphers()).len() > 1);
1532 
1533         let ciphers = p!(ctx.enabled_ciphers());
1534         let cipher = ciphers.first().unwrap();
1535         let stream = p!(ClientBuilder::new()
1536             .whitelist_ciphers(&[*cipher])
1537             .ctx_into_stream("google.com", stream));
1538 
1539         assert_eq!(1, p!(stream.context().enabled_ciphers()).len());
1540     }
1541 
1542     #[test]
1543     #[cfg_attr(target_os = "ios", ignore)] // FIXME same issue as cipher_configuration
test_builder_blacklist_ciphers()1544     fn test_builder_blacklist_ciphers() {
1545         let stream = p!(TcpStream::connect("google.com:443"));
1546 
1547         let ctx = p!(SslContext::new(
1548             SslProtocolSide::CLIENT,
1549             SslConnectionType::STREAM
1550         ));
1551         let num = p!(ctx.enabled_ciphers()).len();
1552         assert!(num > 1);
1553 
1554         let ciphers = p!(ctx.enabled_ciphers());
1555         let cipher = ciphers.first().unwrap();
1556         let stream = p!(ClientBuilder::new()
1557             .blacklist_ciphers(&[*cipher])
1558             .ctx_into_stream("google.com", stream));
1559 
1560         assert_eq!(num - 1, p!(stream.context().enabled_ciphers()).len());
1561     }
1562 
1563     #[test]
idle_context_peer_trust()1564     fn idle_context_peer_trust() {
1565         let ctx = p!(SslContext::new(
1566             SslProtocolSide::SERVER,
1567             SslConnectionType::STREAM
1568         ));
1569         assert!(ctx.peer_trust2().is_err());
1570     }
1571 
1572     #[test]
peer_id()1573     fn peer_id() {
1574         let mut ctx = p!(SslContext::new(
1575             SslProtocolSide::SERVER,
1576             SslConnectionType::STREAM
1577         ));
1578         assert!(p!(ctx.peer_id()).is_none());
1579         p!(ctx.set_peer_id(b"foobar"));
1580         assert_eq!(p!(ctx.peer_id()), Some(&b"foobar"[..]));
1581     }
1582 
1583     #[test]
peer_domain_name()1584     fn peer_domain_name() {
1585         let mut ctx = p!(SslContext::new(
1586             SslProtocolSide::CLIENT,
1587             SslConnectionType::STREAM
1588         ));
1589         assert_eq!("", p!(ctx.peer_domain_name()));
1590         p!(ctx.set_peer_domain_name("foobar.com"));
1591         assert_eq!("foobar.com", p!(ctx.peer_domain_name()));
1592     }
1593 
1594     #[test]
1595     #[should_panic(expected = "blammo")]
write_panic()1596     fn write_panic() {
1597         struct ExplodingStream(TcpStream);
1598 
1599         impl Read for ExplodingStream {
1600             fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
1601                 self.0.read(buf)
1602             }
1603         }
1604 
1605         impl Write for ExplodingStream {
1606             fn write(&mut self, _: &[u8]) -> io::Result<usize> {
1607                 panic!("blammo");
1608             }
1609 
1610             fn flush(&mut self) -> io::Result<()> {
1611                 self.0.flush()
1612             }
1613         }
1614 
1615         let mut ctx = p!(SslContext::new(
1616             SslProtocolSide::CLIENT,
1617             SslConnectionType::STREAM
1618         ));
1619         p!(ctx.set_peer_domain_name("google.com"));
1620         let stream = p!(TcpStream::connect("google.com:443"));
1621         let _ = ctx.handshake(ExplodingStream(stream));
1622     }
1623 
1624     #[test]
1625     #[should_panic(expected = "blammo")]
read_panic()1626     fn read_panic() {
1627         struct ExplodingStream(TcpStream);
1628 
1629         impl Read for ExplodingStream {
1630             fn read(&mut self, _: &mut [u8]) -> io::Result<usize> {
1631                 panic!("blammo");
1632             }
1633         }
1634 
1635         impl Write for ExplodingStream {
1636             fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
1637                 self.0.write(buf)
1638             }
1639 
1640             fn flush(&mut self) -> io::Result<()> {
1641                 self.0.flush()
1642             }
1643         }
1644 
1645         let mut ctx = p!(SslContext::new(
1646             SslProtocolSide::CLIENT,
1647             SslConnectionType::STREAM
1648         ));
1649         p!(ctx.set_peer_domain_name("google.com"));
1650         let stream = p!(TcpStream::connect("google.com:443"));
1651         let _ = ctx.handshake(ExplodingStream(stream));
1652     }
1653 
1654     #[test]
zero_length_buffers()1655     fn zero_length_buffers() {
1656         let mut ctx = p!(SslContext::new(
1657             SslProtocolSide::CLIENT,
1658             SslConnectionType::STREAM
1659         ));
1660         p!(ctx.set_peer_domain_name("google.com"));
1661         let stream = p!(TcpStream::connect("google.com:443"));
1662         let mut stream = ctx.handshake(stream).unwrap();
1663         assert_eq!(stream.write(b"").unwrap(), 0);
1664         assert_eq!(stream.read(&mut []).unwrap(), 0);
1665     }
1666 }
1667