1 //! SSL/TLS encryption support using Secure Transport.
2 //!
3 //! # Examples
4 //!
5 //! To connect as a client to a server with a certificate trusted by the system:
6 //!
7 //! ```rust
8 //! use std::io::prelude::*;
9 //! use std::net::TcpStream;
10 //! use security_framework::secure_transport::ClientBuilder;
11 //!
12 //! let stream = TcpStream::connect("google.com:443").unwrap();
13 //! let mut stream = ClientBuilder::new().handshake("google.com", stream).unwrap();
14 //!
15 //! stream.write_all(b"GET / HTTP/1.0\r\n\r\n").unwrap();
16 //! let mut page = vec![];
17 //! stream.read_to_end(&mut page).unwrap();
18 //! println!("{}", String::from_utf8_lossy(&page));
19 //! ```
20 //!
21 //! To connect to a server with a certificate that's *not* trusted by the
22 //! system, specify the root certificates for the server's chain to the
23 //! `ClientBuilder`:
24 //!
25 //! ```rust,no_run
26 //! use std::io::prelude::*;
27 //! use std::net::TcpStream;
28 //! use security_framework::secure_transport::ClientBuilder;
29 //!
30 //! # let root_cert = unsafe { std::mem::zeroed() };
31 //! let stream = TcpStream::connect("my_server.com:443").unwrap();
32 //! let mut stream = ClientBuilder::new()
33 //! .anchor_certificates(&[root_cert])
34 //! .handshake("my_server.com", stream)
35 //! .unwrap();
36 //!
37 //! stream.write_all(b"GET / HTTP/1.0\r\n\r\n").unwrap();
38 //! let mut page = vec![];
39 //! stream.read_to_end(&mut page).unwrap();
40 //! println!("{}", String::from_utf8_lossy(&page));
41 //! ```
42 //!
43 //! For more advanced configuration, the `SslContext` type can be used directly.
44 //!
45 //! To run a server:
46 //!
47 //! ```rust,no_run
48 //! use std::net::TcpListener;
49 //! use std::thread;
50 //! use security_framework::secure_transport::{SslContext, SslProtocolSide, SslConnectionType};
51 //!
52 //! // Create a TCP listener and start accepting on it.
//! let listener = TcpListener::bind("0.0.0.0:443").unwrap();
54 //!
55 //! for stream in listener.incoming() {
56 //! let stream = stream.unwrap();
57 //! thread::spawn(move || {
58 //! // Create a new context configured to operate on the server side of
59 //! // a traditional SSL/TLS session.
60 //! let mut ctx = SslContext::new(SslProtocolSide::SERVER, SslConnectionType::STREAM)
61 //! .unwrap();
62 //!
63 //! // Install the certificate chain that we will be using.
64 //! # let identity = unsafe { std::mem::zeroed() };
65 //! # let intermediate_cert = unsafe { std::mem::zeroed() };
66 //! # let root_cert = unsafe { std::mem::zeroed() };
67 //! ctx.set_certificate(identity, &[intermediate_cert, root_cert]).unwrap();
68 //!
69 //! // Perform the SSL/TLS handshake and get our stream.
70 //! let mut stream = ctx.handshake(stream).unwrap();
71 //! });
72 //! }
73 //!
74 //! ```
75 #[allow(unused_imports)]
76 use core_foundation::array::{CFArray, CFArrayRef};
77
78 use core_foundation::base::{Boolean, TCFType};
79 #[cfg(feature = "alpn")]
80 use core_foundation::string::CFString;
81 use core_foundation_sys::base::{kCFAllocatorDefault, OSStatus};
82 use std::os::raw::c_void;
83
84 #[allow(unused_imports)]
85 use security_framework_sys::base::{
86 errSecBadReq, errSecIO, errSecNotTrusted, errSecSuccess, errSecTrustSettingDeny,
87 errSecUnimplemented,
88 };
89
90 use security_framework_sys::secure_transport::*;
91 use std::any::Any;
92 use std::cmp;
93 use std::fmt;
94 use std::io;
95 use std::io::prelude::*;
96 use std::marker::PhantomData;
97 use std::panic::{self, AssertUnwindSafe};
98 use std::ptr;
99 use std::result;
100 use std::slice;
101
102 use crate::base::{Error, Result};
103 use crate::certificate::SecCertificate;
104 use crate::cipher_suite::CipherSuite;
105 use crate::identity::SecIdentity;
106 use crate::import_export::Pkcs12ImportOptions;
107 use crate::policy::SecPolicy;
108 use crate::trust::SecTrust;
109 use crate::{cvt, AsInner};
110 use security_framework_sys::base::errSecParam;
111
112 /// Specifies a side of a TLS session.
113 #[derive(Debug, Copy, Clone, PartialEq, Eq)]
114 pub struct SslProtocolSide(SSLProtocolSide);
115
116 impl SslProtocolSide {
117 /// The server side of the session.
118 pub const SERVER: Self = Self(kSSLServerSide);
119
120 /// The client side of the session.
121 pub const CLIENT: Self = Self(kSSLClientSide);
122 }
123
124 /// Specifies the type of TLS session.
125 #[derive(Debug, Copy, Clone)]
126 pub struct SslConnectionType(SSLConnectionType);
127
128 impl SslConnectionType {
129 /// A traditional TLS stream.
130 pub const STREAM: Self = Self(kSSLStreamType);
131
132 /// A DTLS session.
133 pub const DATAGRAM: Self = Self(kSSLDatagramType);
134 }
135
136 /// An error or intermediate state after a TLS handshake attempt.
137 #[derive(Debug)]
138 pub enum HandshakeError<S> {
139 /// The handshake failed.
140 Failure(Error),
141 /// The handshake was interrupted midway through.
142 Interrupted(MidHandshakeSslStream<S>),
143 }
144
145 impl<S> From<Error> for HandshakeError<S> {
146 #[inline(always)]
from(err: Error) -> Self147 fn from(err: Error) -> Self {
148 Self::Failure(err)
149 }
150 }
151
152 /// An error or intermediate state after a TLS handshake attempt.
153 #[derive(Debug)]
154 pub enum ClientHandshakeError<S> {
155 /// The handshake failed.
156 Failure(Error),
157 /// The handshake was interrupted midway through.
158 Interrupted(MidHandshakeClientBuilder<S>),
159 }
160
161 impl<S> From<Error> for ClientHandshakeError<S> {
162 #[inline(always)]
from(err: Error) -> Self163 fn from(err: Error) -> Self {
164 Self::Failure(err)
165 }
166 }
167
168 /// An SSL stream midway through the handshake process.
169 #[derive(Debug)]
170 pub struct MidHandshakeSslStream<S> {
171 stream: SslStream<S>,
172 error: Error,
173 }
174
175 impl<S> MidHandshakeSslStream<S> {
176 /// Returns a shared reference to the inner stream.
177 #[inline(always)]
get_ref(&self) -> &S178 pub fn get_ref(&self) -> &S {
179 self.stream.get_ref()
180 }
181
182 /// Returns a mutable reference to the inner stream.
183 #[inline(always)]
get_mut(&mut self) -> &mut S184 pub fn get_mut(&mut self) -> &mut S {
185 self.stream.get_mut()
186 }
187
188 /// Returns a shared reference to the `SslContext` of the stream.
189 #[inline(always)]
context(&self) -> &SslContext190 pub fn context(&self) -> &SslContext {
191 self.stream.context()
192 }
193
194 /// Returns a mutable reference to the `SslContext` of the stream.
195 #[inline(always)]
context_mut(&mut self) -> &mut SslContext196 pub fn context_mut(&mut self) -> &mut SslContext {
197 self.stream.context_mut()
198 }
199
200 /// Returns `true` iff `break_on_server_auth` was set and the handshake has
201 /// progressed to that point.
202 #[inline(always)]
server_auth_completed(&self) -> bool203 pub fn server_auth_completed(&self) -> bool {
204 self.error.code() == errSSLPeerAuthCompleted
205 }
206
207 /// Returns `true` iff `break_on_cert_requested` was set and the handshake
208 /// has progressed to that point.
209 #[inline(always)]
client_cert_requested(&self) -> bool210 pub fn client_cert_requested(&self) -> bool {
211 self.error.code() == errSSLClientCertRequested
212 }
213
214 /// Returns `true` iff the underlying stream returned an error with the
215 /// `WouldBlock` kind.
216 #[inline(always)]
would_block(&self) -> bool217 pub fn would_block(&self) -> bool {
218 self.error.code() == errSSLWouldBlock
219 }
220
221 /// Returns the error which caused the handshake interruption.
222 #[inline(always)]
error(&self) -> &Error223 pub fn error(&self) -> &Error {
224 &self.error
225 }
226
227 /// Restarts the handshake process.
228 #[inline(always)]
handshake(self) -> result::Result<SslStream<S>, HandshakeError<S>>229 pub fn handshake(self) -> result::Result<SslStream<S>, HandshakeError<S>> {
230 self.stream.handshake()
231 }
232 }
233
234 /// An SSL stream midway through the handshake process.
235 #[derive(Debug)]
236 pub struct MidHandshakeClientBuilder<S> {
237 stream: MidHandshakeSslStream<S>,
238 domain: Option<String>,
239 certs: Vec<SecCertificate>,
240 trust_certs_only: bool,
241 danger_accept_invalid_certs: bool,
242 }
243
244 impl<S> MidHandshakeClientBuilder<S> {
245 /// Returns a shared reference to the inner stream.
246 #[inline(always)]
get_ref(&self) -> &S247 pub fn get_ref(&self) -> &S {
248 self.stream.get_ref()
249 }
250
251 /// Returns a mutable reference to the inner stream.
252 #[inline(always)]
get_mut(&mut self) -> &mut S253 pub fn get_mut(&mut self) -> &mut S {
254 self.stream.get_mut()
255 }
256
257 /// Returns the error which caused the handshake interruption.
258 #[inline(always)]
error(&self) -> &Error259 pub fn error(&self) -> &Error {
260 self.stream.error()
261 }
262
263 /// Restarts the handshake process.
handshake(self) -> result::Result<SslStream<S>, ClientHandshakeError<S>>264 pub fn handshake(self) -> result::Result<SslStream<S>, ClientHandshakeError<S>> {
265 let MidHandshakeClientBuilder {
266 stream,
267 domain,
268 certs,
269 trust_certs_only,
270 danger_accept_invalid_certs,
271 } = self;
272
273 let mut result = stream.handshake();
274 loop {
275 let stream = match result {
276 Ok(stream) => return Ok(stream),
277 Err(HandshakeError::Interrupted(stream)) => stream,
278 Err(HandshakeError::Failure(err)) => {
279 return Err(ClientHandshakeError::Failure(err))
280 }
281 };
282
283 if stream.would_block() {
284 let ret = MidHandshakeClientBuilder {
285 stream,
286 domain,
287 certs,
288 trust_certs_only,
289 danger_accept_invalid_certs,
290 };
291 return Err(ClientHandshakeError::Interrupted(ret));
292 }
293
294 if stream.server_auth_completed() {
295 if danger_accept_invalid_certs {
296 result = stream.handshake();
297 continue;
298 }
299 let mut trust = match stream.context().peer_trust2()? {
300 Some(trust) => trust,
301 None => {
302 result = stream.handshake();
303 continue;
304 }
305 };
306 trust.set_anchor_certificates(&certs)?;
307 trust.set_trust_anchor_certificates_only(self.trust_certs_only)?;
308 let policy =
309 SecPolicy::create_ssl(SslProtocolSide::SERVER, domain.as_deref());
310 trust.set_policy(&policy)?;
311 trust.evaluate_with_error().map_err(|error| {
312 #[cfg(feature = "log")]
313 log::warn!("SecTrustEvaluateWithError: {}", error.to_string());
314 Error::from_code(error.code() as _)
315 })?;
316 result = stream.handshake();
317 continue;
318 }
319
320 let err = Error::from_code(stream.error().code());
321 return Err(ClientHandshakeError::Failure(err));
322 }
323 }
324 }
325
326 /// Specifies the state of a TLS session.
327 #[derive(Debug, PartialEq, Eq)]
328 pub struct SessionState(SSLSessionState);
329
330 impl SessionState {
331 /// The session has not yet started.
332 pub const IDLE: Self = Self(kSSLIdle);
333
334 /// The session is in the handshake process.
335 pub const HANDSHAKE: Self = Self(kSSLHandshake);
336
337 /// The session is connected.
338 pub const CONNECTED: Self = Self(kSSLConnected);
339
340 /// The session has been terminated.
341 pub const CLOSED: Self = Self(kSSLClosed);
342
343 /// The session has been aborted due to an error.
344 pub const ABORTED: Self = Self(kSSLAborted);
345 }
346
347 /// Specifies a server's requirement for client certificates.
348 #[derive(Debug, Copy, Clone, PartialEq, Eq)]
349 pub struct SslAuthenticate(SSLAuthenticate);
350
351 impl SslAuthenticate {
352 /// Do not request a client certificate.
353 pub const NEVER: Self = Self(kNeverAuthenticate);
354
355 /// Require a client certificate.
356 pub const ALWAYS: Self = Self(kAlwaysAuthenticate);
357
358 /// Request but do not require a client certificate.
359 pub const TRY: Self = Self(kTryAuthenticate);
360 }
361
362 /// Specifies the state of client certificate processing.
363 #[derive(Debug, Copy, Clone, PartialEq, Eq)]
364 pub struct SslClientCertificateState(SSLClientCertificateState);
365
366 impl SslClientCertificateState {
367 /// A client certificate has not been requested or sent.
368 pub const NONE: Self = Self(kSSLClientCertNone);
369
370 /// A client certificate has been requested but not recieved.
371 pub const REQUESTED: Self = Self(kSSLClientCertRequested);
372 /// A client certificate has been received and successfully validated.
373 pub const SENT: Self = Self(kSSLClientCertSent);
374
375 /// A client certificate has been received but has failed to validate.
376 pub const REJECTED: Self = Self(kSSLClientCertRejected); }
377
378 /// Specifies protocol versions.
379 #[derive(Debug, Copy, Clone, PartialEq, Eq)]
380 pub struct SslProtocol(SSLProtocol);
381
382 impl SslProtocol {
383 /// No protocol has been or should be negotiated or specified; use the default.
384 pub const UNKNOWN: Self = Self(kSSLProtocolUnknown);
385
386 /// The SSL 3.0 protocol is preferred, though SSL 2.0 may be used if the peer does not support
387 /// SSL 3.0.
388 pub const SSL3: Self = Self(kSSLProtocol3);
389
390 /// The TLS 1.0 protocol is preferred, though lower versions may be used
391 /// if the peer does not support TLS 1.0.
392 pub const TLS1: Self = Self(kTLSProtocol1);
393
394 /// The TLS 1.1 protocol is preferred, though lower versions may be used
395 /// if the peer does not support TLS 1.1.
396 pub const TLS11: Self = Self(kTLSProtocol11);
397
398 /// The TLS 1.2 protocol is preferred, though lower versions may be used
399 /// if the peer does not support TLS 1.2.
400 pub const TLS12: Self = Self(kTLSProtocol12);
401
402 /// The TLS 1.3 protocol is preferred, though lower versions may be used
403 /// if the peer does not support TLS 1.3.
404 pub const TLS13: Self = Self(kTLSProtocol13);
405
406 /// Only the SSL 2.0 protocol is accepted.
407 pub const SSL2: Self = Self(kSSLProtocol2);
408
409 /// The DTLSv1 protocol is preferred.
410 pub const DTLS1: Self = Self(kDTLSProtocol1);
411
412 /// Only the SSL 3.0 protocol is accepted.
413 pub const SSL3_ONLY: Self = Self(kSSLProtocol3Only);
414
415 /// Only the TLS 1.0 protocol is accepted.
416 pub const TLS1_ONLY: Self = Self(kTLSProtocol1Only);
417
418 /// All supported TLS/SSL versions are accepted.
419 pub const ALL: Self = Self(kSSLProtocolAll);
420 }
421
// Declares the `SslContext` wrapper type around a raw `SSLContextRef`.
declare_TCFType! {
    /// A Secure Transport SSL/TLS context object.
    SslContext, SSLContextRef
}

// Provides core-foundation's `TCFType` implementation for `SslContext`,
// with `SSLContextGetTypeID` as the type-ID function.
impl_TCFType!(SslContext, SSLContextRef, SSLContextGetTypeID);
428
429 impl fmt::Debug for SslContext {
430 #[cold]
fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result431 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
432 let mut builder = fmt.debug_struct("SslContext");
433 if let Ok(state) = self.state() {
434 builder.field("state", &state);
435 }
436 builder.finish()
437 }
438 }
439
// NOTE(review): these impls assert that an `SslContext` may be moved to and
// shared between threads. Nothing in this file synchronizes access to the
// underlying `SSLContextRef`, so the soundness of `Sync` in particular relies on
// Secure Transport tolerating concurrent calls made through `&SslContext`
// (e.g. `state`, `peer_id`). Confirm against Apple's documentation.
unsafe impl Sync for SslContext {}
unsafe impl Send for SslContext {}
442
443 impl AsInner for SslContext {
444 type Inner = SSLContextRef;
445
446 #[inline(always)]
as_inner(&self) -> SSLContextRef447 fn as_inner(&self) -> SSLContextRef {
448 self.0
449 }
450 }
451
// Generates a setter/getter pair for each listed Secure Transport session option.
//
// Each entry has the form `const $opt: $get & $set,` with optional leading
// attributes (doc comments, `#[cfg]`), which are copied onto BOTH generated methods.
// `$set` forwards a `bool` to `SSLSetSessionOption`; `$get` reads the value back
// with `SSLGetSessionOption` and reports any non-zero `Boolean` as `true`.
macro_rules! impl_options {
    ($($(#[$a:meta])* const $opt:ident: $get:ident & $set:ident,)*) => {
        $(
            $(#[$a])*
            #[inline(always)]
            pub fn $set(&mut self, value: bool) -> Result<()> {
                unsafe { cvt(SSLSetSessionOption(self.0, $opt, value as Boolean)) }
            }

            $(#[$a])*
            #[inline]
            pub fn $get(&self) -> Result<bool> {
                let mut value = 0;
                unsafe { cvt(SSLGetSessionOption(self.0, $opt, &mut value))?; }
                // C `Boolean`: treat any non-zero value as true.
                Ok(value != 0)
            }
        )*
    }
}
471
472 impl SslContext {
473 /// Creates a new `SslContext` for the specified side and type of SSL
474 /// connection.
475 #[inline]
new(side: SslProtocolSide, type_: SslConnectionType) -> Result<Self>476 pub fn new(side: SslProtocolSide, type_: SslConnectionType) -> Result<Self> {
477 unsafe {
478 let ctx = SSLCreateContext(kCFAllocatorDefault, side.0, type_.0);
479 Ok(Self(ctx))
480 }
481 }
482
483 /// Sets the fully qualified domain name of the peer.
484 ///
485 /// This will be used on the client side of a session to validate the
486 /// common name field of the server's certificate. It has no effect if
487 /// called on a server-side `SslContext`.
488 ///
489 /// It is *highly* recommended to call this method before starting the
490 /// handshake process.
491 #[inline]
set_peer_domain_name(&mut self, peer_name: &str) -> Result<()>492 pub fn set_peer_domain_name(&mut self, peer_name: &str) -> Result<()> {
493 unsafe {
494 // SSLSetPeerDomainName doesn't need a null terminated string
495 cvt(SSLSetPeerDomainName(
496 self.0,
497 peer_name.as_ptr() as *const _,
498 peer_name.len(),
499 ))
500 }
501 }
502
503 /// Returns the peer domain name set by `set_peer_domain_name`.
peer_domain_name(&self) -> Result<String>504 pub fn peer_domain_name(&self) -> Result<String> {
505 unsafe {
506 let mut len = 0;
507 cvt(SSLGetPeerDomainNameLength(self.0, &mut len))?;
508 let mut buf = vec![0; len];
509 cvt(SSLGetPeerDomainName(
510 self.0,
511 buf.as_mut_ptr() as *mut _,
512 &mut len,
513 ))?;
514 Ok(String::from_utf8(buf).unwrap())
515 }
516 }
517
518 /// Sets the certificate to be used by this side of the SSL session.
519 ///
520 /// This must be called before the handshake for server-side connections,
521 /// and can be used on the client-side to specify a client certificate.
522 ///
523 /// The `identity` corresponds to the leaf certificate and private
524 /// key, and the `certs` correspond to extra certificates in the chain.
set_certificate( &mut self, identity: &SecIdentity, certs: &[SecCertificate], ) -> Result<()>525 pub fn set_certificate(
526 &mut self,
527 identity: &SecIdentity,
528 certs: &[SecCertificate],
529 ) -> Result<()> {
530 let mut arr = vec![identity.as_CFType()];
531 arr.extend(certs.iter().map(|c| c.as_CFType()));
532 let certs = CFArray::from_CFTypes(&arr);
533
534 unsafe { cvt(SSLSetCertificate(self.0, certs.as_concrete_TypeRef())) }
535 }
536
537 /// Sets the peer ID of this session.
538 ///
539 /// A peer ID is an opaque sequence of bytes that will be used by Secure
540 /// Transport to identify the peer of an SSL session. If the peer ID of
541 /// this session matches that of a previously terminated session, the
542 /// previous session can be resumed without requiring a full handshake.
543 #[inline]
set_peer_id(&mut self, peer_id: &[u8]) -> Result<()>544 pub fn set_peer_id(&mut self, peer_id: &[u8]) -> Result<()> {
545 unsafe {
546 cvt(SSLSetPeerID(
547 self.0,
548 peer_id.as_ptr() as *const _,
549 peer_id.len(),
550 ))
551 }
552 }
553
554 /// Returns the peer ID of this session.
peer_id(&self) -> Result<Option<&[u8]>>555 pub fn peer_id(&self) -> Result<Option<&[u8]>> {
556 unsafe {
557 let mut ptr = ptr::null();
558 let mut len = 0;
559 cvt(SSLGetPeerID(self.0, &mut ptr, &mut len))?;
560 if ptr.is_null() {
561 Ok(None)
562 } else {
563 Ok(Some(slice::from_raw_parts(ptr as *const _, len)))
564 }
565 }
566 }
567
568 /// Returns the list of ciphers that are supported by Secure Transport.
supported_ciphers(&self) -> Result<Vec<CipherSuite>>569 pub fn supported_ciphers(&self) -> Result<Vec<CipherSuite>> {
570 unsafe {
571 let mut num_ciphers = 0;
572 cvt(SSLGetNumberSupportedCiphers(self.0, &mut num_ciphers))?;
573 let mut ciphers = vec![0; num_ciphers];
574 cvt(SSLGetSupportedCiphers(
575 self.0,
576 ciphers.as_mut_ptr(),
577 &mut num_ciphers,
578 ))?;
579 Ok(ciphers.iter().map(|c| CipherSuite::from_raw(*c)).collect())
580 }
581 }
582
583 /// Returns the list of ciphers that are eligible to be used for
584 /// negotiation.
enabled_ciphers(&self) -> Result<Vec<CipherSuite>>585 pub fn enabled_ciphers(&self) -> Result<Vec<CipherSuite>> {
586 unsafe {
587 let mut num_ciphers = 0;
588 cvt(SSLGetNumberEnabledCiphers(self.0, &mut num_ciphers))?;
589 let mut ciphers = vec![0; num_ciphers];
590 cvt(SSLGetEnabledCiphers(
591 self.0,
592 ciphers.as_mut_ptr(),
593 &mut num_ciphers,
594 ))?;
595 Ok(ciphers.iter().map(|c| CipherSuite::from_raw(*c)).collect())
596 }
597 }
598
599 /// Sets the list of ciphers that are eligible to be used for negotiation.
set_enabled_ciphers(&mut self, ciphers: &[CipherSuite]) -> Result<()>600 pub fn set_enabled_ciphers(&mut self, ciphers: &[CipherSuite]) -> Result<()> {
601 let ciphers = ciphers.iter().map(|c| c.to_raw()).collect::<Vec<_>>();
602 unsafe {
603 cvt(SSLSetEnabledCiphers(
604 self.0,
605 ciphers.as_ptr(),
606 ciphers.len(),
607 ))
608 }
609 }
610
611 /// Returns the cipher being used by the session.
612 #[inline]
negotiated_cipher(&self) -> Result<CipherSuite>613 pub fn negotiated_cipher(&self) -> Result<CipherSuite> {
614 unsafe {
615 let mut cipher = 0;
616 cvt(SSLGetNegotiatedCipher(self.0, &mut cipher))?;
617 Ok(CipherSuite::from_raw(cipher))
618 }
619 }
620
621 /// Sets the requirements for client certificates.
622 ///
623 /// Should only be called on server-side sessions.
624 #[inline]
set_client_side_authenticate(&mut self, auth: SslAuthenticate) -> Result<()>625 pub fn set_client_side_authenticate(&mut self, auth: SslAuthenticate) -> Result<()> {
626 unsafe { cvt(SSLSetClientSideAuthenticate(self.0, auth.0)) }
627 }
628
629 /// Returns the state of client certificate processing.
630 #[inline]
client_certificate_state(&self) -> Result<SslClientCertificateState>631 pub fn client_certificate_state(&self) -> Result<SslClientCertificateState> {
632 let mut state = 0;
633
634 unsafe {
635 cvt(SSLGetClientCertificateState(self.0, &mut state))?;
636 }
637 Ok(SslClientCertificateState(state))
638 }
639
640 /// Returns the `SecTrust` object corresponding to the peer.
641 ///
642 /// This can be used in conjunction with `set_break_on_server_auth` to
643 /// validate certificates which do not have roots in the default set.
peer_trust2(&self) -> Result<Option<SecTrust>>644 pub fn peer_trust2(&self) -> Result<Option<SecTrust>> {
645 // Calling SSLCopyPeerTrust on an idle connection does not seem to be well defined,
646 // so explicitly check for that
647 if self.state()? == SessionState::IDLE {
648 return Err(Error::from_code(errSecBadReq));
649 }
650
651 unsafe {
652 let mut trust = ptr::null_mut();
653 cvt(SSLCopyPeerTrust(self.0, &mut trust))?;
654 if trust.is_null() {
655 Ok(None)
656 } else {
657 Ok(Some(SecTrust::wrap_under_create_rule(trust)))
658 }
659 }
660 }
661
662 /// Returns the state of the session.
663 #[inline]
state(&self) -> Result<SessionState>664 pub fn state(&self) -> Result<SessionState> {
665 unsafe {
666 let mut state = 0;
667 cvt(SSLGetSessionState(self.0, &mut state))?;
668 Ok(SessionState(state))
669 }
670 }
671
672 /// Returns the protocol version being used by the session.
673 #[inline]
negotiated_protocol_version(&self) -> Result<SslProtocol>674 pub fn negotiated_protocol_version(&self) -> Result<SslProtocol> {
675 unsafe {
676 let mut version = 0;
677 cvt(SSLGetNegotiatedProtocolVersion(self.0, &mut version))?;
678 Ok(SslProtocol(version))
679 }
680 }
681
682 /// Returns the maximum protocol version allowed by the session.
683 #[inline]
protocol_version_max(&self) -> Result<SslProtocol>684 pub fn protocol_version_max(&self) -> Result<SslProtocol> {
685 unsafe {
686 let mut version = 0;
687 cvt(SSLGetProtocolVersionMax(self.0, &mut version))?;
688 Ok(SslProtocol(version))
689 }
690 }
691
692 /// Sets the maximum protocol version allowed by the session.
693 #[inline]
set_protocol_version_max(&mut self, max_version: SslProtocol) -> Result<()>694 pub fn set_protocol_version_max(&mut self, max_version: SslProtocol) -> Result<()> {
695 unsafe { cvt(SSLSetProtocolVersionMax(self.0, max_version.0)) }
696 }
697
698 /// Returns the minimum protocol version allowed by the session.
699 #[inline]
protocol_version_min(&self) -> Result<SslProtocol>700 pub fn protocol_version_min(&self) -> Result<SslProtocol> {
701 unsafe {
702 let mut version = 0;
703 cvt(SSLGetProtocolVersionMin(self.0, &mut version))?;
704 Ok(SslProtocol(version))
705 }
706 }
707
708 /// Sets the minimum protocol version allowed by the session.
709 #[inline]
set_protocol_version_min(&mut self, min_version: SslProtocol) -> Result<()>710 pub fn set_protocol_version_min(&mut self, min_version: SslProtocol) -> Result<()> {
711 unsafe { cvt(SSLSetProtocolVersionMin(self.0, min_version.0)) }
712 }
713
714 /// Returns the set of protocols selected via ALPN if it succeeded.
715 #[cfg(feature = "alpn")]
alpn_protocols(&self) -> Result<Vec<String>>716 pub fn alpn_protocols(&self) -> Result<Vec<String>> {
717 let mut array: CFArrayRef = ptr::null();
718 unsafe {
719 #[cfg(feature = "OSX_10_13")]
720 {
721 cvt(SSLCopyALPNProtocols(self.0, &mut array))?;
722 }
723
724 #[cfg(not(feature = "OSX_10_13"))]
725 {
726 dlsym! { fn SSLCopyALPNProtocols(SSLContextRef, *mut CFArrayRef) -> OSStatus }
727 if let Some(f) = SSLCopyALPNProtocols.get() {
728 cvt(f(self.0, &mut array))?;
729 } else {
730 return Err(Error::from_code(errSecUnimplemented));
731 }
732 }
733
734 if array.is_null() {
735 return Ok(vec![]);
736 }
737
738 let array = CFArray::<CFString>::wrap_under_create_rule(array);
739 Ok(array.into_iter().map(|p| p.to_string()).collect())
740 }
741 }
742
743 /// Configures the set of protocols use for ALPN.
744 ///
745 /// This is only used for client-side connections.
746 #[cfg(feature = "alpn")]
set_alpn_protocols(&mut self, protocols: &[&str]) -> Result<()>747 pub fn set_alpn_protocols(&mut self, protocols: &[&str]) -> Result<()> {
748 // When CFMutableArray is added to core-foundation and IntoIterator trait
749 // is implemented for CFMutableArray, the code below should directly collect
750 // into a CFMutableArray.
751 let protocols = CFArray::from_CFTypes(
752 &protocols
753 .iter()
754 .map(|proto| CFString::new(proto))
755 .collect::<Vec<_>>(),
756 );
757
758 #[cfg(feature = "OSX_10_13")]
759 {
760 unsafe { cvt(SSLSetALPNProtocols(self.0, protocols.as_concrete_TypeRef())) }
761 }
762 #[cfg(not(feature = "OSX_10_13"))]
763 {
764 dlsym! { fn SSLSetALPNProtocols(SSLContextRef, CFArrayRef) -> OSStatus }
765 if let Some(f) = SSLSetALPNProtocols.get() {
766 unsafe { cvt(f(self.0, protocols.as_concrete_TypeRef())) }
767 } else {
768 Err(Error::from_code(errSecUnimplemented))
769 }
770 }
771 }
772
773 /// Sets whether the client sends the `SessionTicket` extension in its `ClientHello`.
774 ///
775 /// On its own, this will just cause the client to send an empty `SessionTicket` extension on
776 /// every connection. [`SslContext::set_peer_id`] must also be used to key the session
777 /// ticket returned by the server.
778 ///
779 /// [`SslContext::set_peer_id`]: #method.set_peer_id
780 #[cfg(feature = "session-tickets")]
set_session_tickets_enabled(&mut self, enabled: bool) -> Result<()>781 pub fn set_session_tickets_enabled(&mut self, enabled: bool) -> Result<()> {
782 #[cfg(feature = "OSX_10_13")]
783 {
784 unsafe { cvt(SSLSetSessionTicketsEnabled(self.0, enabled as Boolean)) }
785 }
786 #[cfg(not(feature = "OSX_10_13"))]
787 {
788 dlsym! { fn SSLSetSessionTicketsEnabled(SSLContextRef, Boolean) -> OSStatus }
789 if let Some(f) = SSLSetSessionTicketsEnabled.get() {
790 unsafe { cvt(f(self.0, enabled as Boolean)) }
791 } else {
792 Err(Error::from_code(errSecUnimplemented))
793 }
794 }
795 }
796
797 /// Sets whether a protocol is enabled or not.
798 ///
799 /// # Note
800 ///
801 /// On OSX this is a deprecated API in favor of `set_protocol_version_max` and
802 /// `set_protocol_version_min`, although if you're working with OSX 10.8 or before you may have
803 /// to use this API instead.
804 #[cfg(target_os = "macos")]
805 #[deprecated(note = "use `set_protocol_version_max`")]
set_protocol_version_enabled( &mut self, protocol: SslProtocol, enabled: bool, ) -> Result<()>806 pub fn set_protocol_version_enabled(
807 &mut self,
808 protocol: SslProtocol,
809 enabled: bool,
810 ) -> Result<()> {
811 unsafe {
812 cvt(SSLSetProtocolVersionEnabled(
813 self.0,
814 protocol.0,
815 enabled as Boolean,
816 ))
817 }
818 }
819
820 /// Returns the number of bytes which can be read without triggering a
821 /// `read` call in the underlying stream.
822 #[inline]
buffered_read_size(&self) -> Result<usize>823 pub fn buffered_read_size(&self) -> Result<usize> {
824 unsafe {
825 let mut size = 0;
826 cvt(SSLGetBufferedReadSize(self.0, &mut size))?;
827 Ok(size)
828 }
829 }
830
831 impl_options! {
832 /// If enabled, the handshake process will pause and return instead of
833 /// automatically validating a server's certificate.
834 const kSSLSessionOptionBreakOnServerAuth: break_on_server_auth & set_break_on_server_auth,
835 /// If enabled, the handshake process will pause and return after
836 /// the server requests a certificate from the client.
837 const kSSLSessionOptionBreakOnCertRequested: break_on_cert_requested & set_break_on_cert_requested,
838 /// If enabled, the handshake process will pause and return instead of
839 /// automatically validating a client's certificate.
840 const kSSLSessionOptionBreakOnClientAuth: break_on_client_auth & set_break_on_client_auth,
841 /// If enabled, TLS false start will be performed if an appropriate
842 /// cipher suite is negotiated.
843 ///
844 /// Requires the `OSX_10_9` (or greater) feature.
845 #[cfg(feature = "OSX_10_9")]
846 const kSSLSessionOptionFalseStart: false_start & set_false_start,
847 /// If enabled, 1/n-1 record splitting will be enabled for TLS 1.0
848 /// connections using block ciphers to mitigate the BEAST attack.
849 ///
850 /// Requires the `OSX_10_9` (or greater) feature.
851 #[cfg(feature = "OSX_10_9")]
852 const kSSLSessionOptionSendOneByteRecord: send_one_byte_record & set_send_one_byte_record,
853 }
854
into_stream<S>(self, stream: S) -> Result<SslStream<S>> where S: Read + Write,855 fn into_stream<S>(self, stream: S) -> Result<SslStream<S>>
856 where
857 S: Read + Write,
858 {
859 unsafe {
860 let ret = SSLSetIOFuncs(self.0, read_func::<S>, write_func::<S>);
861 if ret != errSecSuccess {
862 return Err(Error::from_code(ret));
863 }
864
865 let stream = Connection {
866 stream,
867 err: None,
868 panic: None,
869 };
870 let stream = Box::into_raw(Box::new(stream));
871 let ret = SSLSetConnection(self.0, stream as *mut _);
872 if ret != errSecSuccess {
873 let _conn = Box::from_raw(stream);
874 return Err(Error::from_code(ret));
875 }
876
877 Ok(SslStream {
878 ctx: self,
879 _m: PhantomData,
880 })
881 }
882 }
883
884 /// Performs the SSL/TLS handshake.
handshake<S>(self, stream: S) -> result::Result<SslStream<S>, HandshakeError<S>> where S: Read + Write,885 pub fn handshake<S>(self, stream: S) -> result::Result<SslStream<S>, HandshakeError<S>>
886 where
887 S: Read + Write,
888 {
889 self.into_stream(stream)
890 .map_err(HandshakeError::Failure)
891 .and_then(SslStream::handshake)
892 }
893 }
894
/// State handed to the Secure Transport I/O callbacks (`read_func` / `write_func`).
///
/// `SslContext::into_stream` boxes one of these and registers it with the context
/// via `SSLSetConnection`; `SslStream`'s `Drop` reclaims the box.
struct Connection<S> {
    /// The wrapped transport stream.
    stream: S,
    /// The most recent I/O error returned by `stream` inside a callback, if any.
    err: Option<io::Error>,
    /// The payload of a panic caught while calling into `stream` inside a callback, if any.
    panic: Option<Box<dyn Any + Send>>,
}
900
901 // the logic here is based off of libcurl's
902 #[cold]
translate_err(e: &io::Error) -> OSStatus903 fn translate_err(e: &io::Error) -> OSStatus {
904 match e.kind() {
905 io::ErrorKind::NotFound => errSSLClosedGraceful,
906 io::ErrorKind::ConnectionReset => errSSLClosedAbort,
907 io::ErrorKind::WouldBlock |
908 io::ErrorKind::NotConnected => errSSLWouldBlock,
909 _ => errSecIO,
910 }
911 }
912
read_func<S>( connection: SSLConnectionRef, data: *mut c_void, data_length: *mut usize, ) -> OSStatus where S: Read,913 unsafe extern "C" fn read_func<S>(
914 connection: SSLConnectionRef,
915 data: *mut c_void,
916 data_length: *mut usize,
917 ) -> OSStatus
918 where
919 S: Read,
920 {
921 let conn: &mut Connection<S> = &mut *(connection as *mut _);
922 let data = slice::from_raw_parts_mut(data as *mut u8, *data_length);
923 let mut start = 0;
924 let mut ret = errSecSuccess;
925
926 while start < data.len() {
927 match panic::catch_unwind(AssertUnwindSafe(|| conn.stream.read(&mut data[start..]))) {
928 Ok(Ok(0)) => {
929 ret = errSSLClosedNoNotify;
930 break;
931 }
932 Ok(Ok(len)) => start += len,
933 Ok(Err(e)) => {
934 ret = translate_err(&e);
935 conn.err = Some(e);
936 break;
937 }
938 Err(e) => {
939 ret = errSecIO;
940 conn.panic = Some(e);
941 break;
942 }
943 }
944 }
945
946 *data_length = start;
947 ret
948 }
949
write_func<S>( connection: SSLConnectionRef, data: *const c_void, data_length: *mut usize, ) -> OSStatus where S: Write,950 unsafe extern "C" fn write_func<S>(
951 connection: SSLConnectionRef,
952 data: *const c_void,
953 data_length: *mut usize,
954 ) -> OSStatus
955 where
956 S: Write,
957 {
958 let conn: &mut Connection<S> = &mut *(connection as *mut _);
959 let data = slice::from_raw_parts(data as *mut u8, *data_length);
960 let mut start = 0;
961 let mut ret = errSecSuccess;
962
963 while start < data.len() {
964 match panic::catch_unwind(AssertUnwindSafe(|| conn.stream.write(&data[start..]))) {
965 Ok(Ok(0)) => {
966 ret = errSSLClosedNoNotify;
967 break;
968 }
969 Ok(Ok(len)) => start += len,
970 Ok(Err(e)) => {
971 ret = translate_err(&e);
972 conn.err = Some(e);
973 break;
974 }
975 Err(e) => {
976 ret = errSecIO;
977 conn.panic = Some(e);
978 break;
979 }
980 }
981 }
982
983 *data_length = start;
984 ret
985 }
986
/// A type implementing SSL/TLS encryption over an underlying stream.
///
/// The wrapped stream is not stored in this struct directly. It lives in a
/// heap-allocated `Connection<S>` whose pointer is attached to `ctx` and
/// recovered with `SSLGetConnection`; the `Drop` impl frees it as a
/// `Box<Connection<S>>`.
pub struct SslStream<S> {
    // The Secure Transport context; it also holds the connection pointer.
    ctx: SslContext,
    // Records the stream type `S`, which is otherwise reachable only through
    // `ctx`'s untyped connection pointer.
    _m: PhantomData<S>,
}
992
993 impl<S: fmt::Debug> fmt::Debug for SslStream<S> {
994 #[cold]
fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result995 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
996 fmt.debug_struct("SslStream")
997 .field("context", &self.ctx)
998 .field("stream", self.get_ref())
999 .finish()
1000 }
1001 }
1002
1003 impl<S> Drop for SslStream<S> {
drop(&mut self)1004 fn drop(&mut self) {
1005 unsafe {
1006 let mut conn = ptr::null();
1007 let ret = SSLGetConnection(self.ctx.0, &mut conn);
1008 assert!(ret == errSecSuccess);
1009 Box::<Connection<S>>::from_raw(conn as *mut _);
1010 }
1011 }
1012 }
1013
1014 impl<S> SslStream<S> {
handshake(mut self) -> result::Result<Self, HandshakeError<S>>1015 fn handshake(mut self) -> result::Result<Self, HandshakeError<S>> {
1016 match unsafe { SSLHandshake(self.ctx.0) } {
1017 errSecSuccess => Ok(self),
1018 reason @ errSSLPeerAuthCompleted
1019 | reason @ errSSLClientCertRequested
1020 | reason @ errSSLWouldBlock
1021 | reason @ errSSLClientHelloReceived => {
1022 Err(HandshakeError::Interrupted(MidHandshakeSslStream {
1023 stream: self,
1024 error: Error::from_code(reason),
1025 }))
1026 }
1027 err => {
1028 self.check_panic();
1029 Err(HandshakeError::Failure(Error::from_code(err)))
1030 }
1031 }
1032 }
1033
1034 /// Returns a shared reference to the inner stream.
1035 #[inline(always)]
get_ref(&self) -> &S1036 pub fn get_ref(&self) -> &S {
1037 &self.connection().stream
1038 }
1039
1040 /// Returns a mutable reference to the underlying stream.
1041 #[inline(always)]
get_mut(&mut self) -> &mut S1042 pub fn get_mut(&mut self) -> &mut S {
1043 &mut self.connection_mut().stream
1044 }
1045
1046 /// Returns a shared reference to the `SslContext` of the stream.
1047 #[inline(always)]
context(&self) -> &SslContext1048 pub fn context(&self) -> &SslContext {
1049 &self.ctx
1050 }
1051
1052 /// Returns a mutable reference to the `SslContext` of the stream.
1053 #[inline(always)]
context_mut(&mut self) -> &mut SslContext1054 pub fn context_mut(&mut self) -> &mut SslContext {
1055 &mut self.ctx
1056 }
1057
1058 /// Shuts down the connection.
close(&mut self) -> result::Result<(), io::Error>1059 pub fn close(&mut self) -> result::Result<(), io::Error> {
1060 unsafe {
1061 let ret = SSLClose(self.ctx.0);
1062 if ret == errSecSuccess {
1063 Ok(())
1064 } else {
1065 Err(self.get_error(ret))
1066 }
1067 }
1068 }
1069
connection(&self) -> &Connection<S>1070 fn connection(&self) -> &Connection<S> {
1071 unsafe {
1072 let mut conn = ptr::null();
1073 let ret = SSLGetConnection(self.ctx.0, &mut conn);
1074 assert!(ret == errSecSuccess);
1075
1076 &mut *(conn as *mut Connection<S>)
1077 }
1078 }
1079
connection_mut(&mut self) -> &mut Connection<S>1080 fn connection_mut(&mut self) -> &mut Connection<S> {
1081 unsafe {
1082 let mut conn = ptr::null();
1083 let ret = SSLGetConnection(self.ctx.0, &mut conn);
1084 assert!(ret == errSecSuccess);
1085
1086 &mut *(conn as *mut Connection<S>)
1087 }
1088 }
1089
1090 #[cold]
check_panic(&mut self)1091 fn check_panic(&mut self) {
1092 let conn = self.connection_mut();
1093 if let Some(err) = conn.panic.take() {
1094 panic::resume_unwind(err);
1095 }
1096 }
1097
1098 #[cold]
get_error(&mut self, ret: OSStatus) -> io::Error1099 fn get_error(&mut self, ret: OSStatus) -> io::Error {
1100 self.check_panic();
1101
1102 if let Some(err) = self.connection_mut().err.take() {
1103 err
1104 } else {
1105 io::Error::new(io::ErrorKind::Other, Error::from_code(ret))
1106 }
1107 }
1108 }
1109
1110 impl<S: Read + Write> Read for SslStream<S> {
read(&mut self, buf: &mut [u8]) -> io::Result<usize>1111 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
1112 // Below we base our return value off the amount of data read, so a
1113 // zero-length buffer might cause us to erroneously interpret this
1114 // request as an error. Instead short-circuit that logic and return
1115 // `Ok(0)` instead.
1116 if buf.is_empty() {
1117 return Ok(0);
1118 }
1119
1120 // If some data was buffered but not enough to fill `buf`, SSLRead
1121 // will try to read a new packet. This is bad because there may be
1122 // no more data but the socket is remaining open (e.g HTTPS with
1123 // Connection: keep-alive).
1124 let buffered = self.context().buffered_read_size().unwrap_or(0);
1125 let to_read = if buffered > 0 {
1126 cmp::min(buffered, buf.len())
1127 } else {
1128 buf.len()
1129 };
1130
1131 unsafe {
1132 let mut nread = 0;
1133 let ret = SSLRead(self.ctx.0, buf.as_mut_ptr() as *mut _, to_read, &mut nread);
1134 // SSLRead can return an error at the same time it returns the last
1135 // chunk of data (!)
1136 if nread > 0 {
1137 return Ok(nread as usize);
1138 }
1139
1140 match ret {
1141 errSSLClosedGraceful | errSSLClosedAbort | errSSLClosedNoNotify => Ok(0),
1142 _ => Err(self.get_error(ret)),
1143 }
1144 }
1145 }
1146 }
1147
1148 impl<S: Read + Write> Write for SslStream<S> {
write(&mut self, buf: &[u8]) -> io::Result<usize>1149 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
1150 // Like above in read, short circuit a 0-length write
1151 if buf.is_empty() {
1152 return Ok(0);
1153 }
1154 unsafe {
1155 let mut nwritten = 0;
1156 let ret = SSLWrite(
1157 self.ctx.0,
1158 buf.as_ptr() as *const _,
1159 buf.len(),
1160 &mut nwritten,
1161 );
1162 // just to be safe, base success off of nwritten rather than ret
1163 // for the same reason as in read
1164 if nwritten > 0 {
1165 Ok(nwritten as usize)
1166 } else {
1167 Err(self.get_error(ret))
1168 }
1169 }
1170 }
1171
flush(&mut self) -> io::Result<()>1172 fn flush(&mut self) -> io::Result<()> {
1173 self.connection_mut().stream.flush()
1174 }
1175 }
1176
/// A builder type to simplify the creation of client side `SslStream`s.
#[derive(Debug)]
pub struct ClientBuilder {
    // Identity used as the SSL/TLS client certificate, if any.
    identity: Option<SecIdentity>,
    // Anchor (root) certificates trusted when verifying the server's certificate.
    certs: Vec<SecCertificate>,
    // Extra certificates passed to `set_certificate` alongside `identity`.
    chain: Vec<SecCertificate>,
    // Optional lower bound on the negotiated protocol version.
    protocol_min: Option<SslProtocol>,
    // Optional upper bound on the negotiated protocol version.
    protocol_max: Option<SslProtocol>,
    // When true, trust only `certs`, not the built-in certificates as well.
    trust_certs_only: bool,
    // Whether the domain is sent via Server Name Indication.
    use_sni: bool,
    // Skip certificate validity checks (dangerous).
    danger_accept_invalid_certs: bool,
    // Skip hostname verification (dangerous).
    danger_accept_invalid_hostnames: bool,
    // Cipher suites to enable; empty means "start from the context's defaults".
    whitelisted_ciphers: Vec<CipherSuite>,
    // Cipher suites removed from the enabled set.
    blacklisted_ciphers: Vec<CipherSuite>,
    // ALPN protocol names to offer (applied only with the `alpn` feature).
    alpn: Option<Vec<String>>,
    // RFC 5077 session tickets (applied only with the `session-tickets` feature).
    enable_session_tickets: bool,
}
1194
1195 impl Default for ClientBuilder {
1196 #[inline(always)]
default() -> Self1197 fn default() -> Self {
1198 Self::new()
1199 }
1200 }
1201
1202 impl ClientBuilder {
1203 /// Creates a new builder with default options.
1204 #[inline]
new() -> Self1205 pub fn new() -> Self {
1206 Self {
1207 identity: None,
1208 certs: Vec::new(),
1209 chain: Vec::new(),
1210 protocol_min: None,
1211 protocol_max: None,
1212 trust_certs_only: false,
1213 use_sni: true,
1214 danger_accept_invalid_certs: false,
1215 danger_accept_invalid_hostnames: false,
1216 whitelisted_ciphers: Vec::new(),
1217 blacklisted_ciphers: Vec::new(),
1218 alpn: None,
1219 enable_session_tickets: false,
1220 }
1221 }
1222
1223 /// Specifies the set of root certificates to trust when
1224 /// verifying the server's certificate.
1225 #[inline]
anchor_certificates(&mut self, certs: &[SecCertificate]) -> &mut Self1226 pub fn anchor_certificates(&mut self, certs: &[SecCertificate]) -> &mut Self {
1227 self.certs = certs.to_owned();
1228 self
1229 }
1230
1231 /// Add the certificate the set of root certificates to trust
1232 /// when verifying the server's certificate.
1233 #[inline]
add_anchor_certificate(&mut self, certs: &SecCertificate) -> &mut Self1234 pub fn add_anchor_certificate(&mut self, certs: &SecCertificate) -> &mut Self {
1235 self.certs.push(certs.to_owned());
1236 self
1237 }
1238
1239 /// Specifies whether to trust the built-in certificates in addition
1240 /// to specified anchor certificates.
1241 #[inline(always)]
trust_anchor_certificates_only(&mut self, only: bool) -> &mut Self1242 pub fn trust_anchor_certificates_only(&mut self, only: bool) -> &mut Self {
1243 self.trust_certs_only = only;
1244 self
1245 }
1246
1247 /// Specifies whether to trust invalid certificates.
1248 ///
1249 /// # Warning
1250 ///
1251 /// You should think very carefully before using this method. If invalid
1252 /// certificates are trusted, *any* certificate for *any* site will be
1253 /// trusted for use. This includes expired certificates. This introduces
1254 /// significant vulnerabilities, and should only be used as a last resort.
1255 #[inline(always)]
danger_accept_invalid_certs(&mut self, noverify: bool) -> &mut Self1256 pub fn danger_accept_invalid_certs(&mut self, noverify: bool) -> &mut Self {
1257 self.danger_accept_invalid_certs = noverify;
1258 self
1259 }
1260
1261 /// Specifies whether to use Server Name Indication (SNI).
1262 #[inline(always)]
use_sni(&mut self, use_sni: bool) -> &mut Self1263 pub fn use_sni(&mut self, use_sni: bool) -> &mut Self {
1264 self.use_sni = use_sni;
1265 self
1266 }
1267
1268 /// Specifies whether to verify that the server's hostname matches its certificate.
1269 ///
1270 /// # Warning
1271 ///
1272 /// You should think very carefully before using this method. If hostnames are not verified,
1273 /// *any* valid certificate for *any* site will be trusted for use. This introduces significant
1274 /// vulnerabilities, and should only be used as a last resort.
1275 #[inline(always)]
danger_accept_invalid_hostnames( &mut self, danger_accept_invalid_hostnames: bool, ) -> &mut Self1276 pub fn danger_accept_invalid_hostnames(
1277 &mut self,
1278 danger_accept_invalid_hostnames: bool,
1279 ) -> &mut Self {
1280 self.danger_accept_invalid_hostnames = danger_accept_invalid_hostnames;
1281 self
1282 }
1283
1284 /// Set a whitelist of enabled ciphers. Any ciphers not whitelisted will be disabled.
whitelist_ciphers(&mut self, whitelisted_ciphers: &[CipherSuite]) -> &mut Self1285 pub fn whitelist_ciphers(&mut self, whitelisted_ciphers: &[CipherSuite]) -> &mut Self {
1286 self.whitelisted_ciphers = whitelisted_ciphers.to_owned();
1287 self
1288 }
1289
1290 /// Set a blacklist of disabled ciphers. Blacklisted ciphers will be disabled.
blacklist_ciphers(&mut self, blacklisted_ciphers: &[CipherSuite]) -> &mut Self1291 pub fn blacklist_ciphers(&mut self, blacklisted_ciphers: &[CipherSuite]) -> &mut Self {
1292 self.blacklisted_ciphers = blacklisted_ciphers.to_owned();
1293 self
1294 }
1295
1296 /// Use the specified identity as a SSL/TLS client certificate.
identity(&mut self, identity: &SecIdentity, chain: &[SecCertificate]) -> &mut Self1297 pub fn identity(&mut self, identity: &SecIdentity, chain: &[SecCertificate]) -> &mut Self {
1298 self.identity = Some(identity.clone());
1299 self.chain = chain.to_owned();
1300 self
1301 }
1302
1303 /// Configure the minimum protocol that this client will support.
1304 #[inline(always)]
protocol_min(&mut self, min: SslProtocol) -> &mut Self1305 pub fn protocol_min(&mut self, min: SslProtocol) -> &mut Self {
1306 self.protocol_min = Some(min);
1307 self
1308 }
1309
1310 /// Configure the minimum protocol that this client will support.
1311 #[inline(always)]
protocol_max(&mut self, max: SslProtocol) -> &mut Self1312 pub fn protocol_max(&mut self, max: SslProtocol) -> &mut Self {
1313 self.protocol_max = Some(max);
1314 self
1315 }
1316
1317 /// Configures the set of protocols used for ALPN.
1318 #[cfg(feature = "alpn")]
alpn_protocols(&mut self, protocols: &[&str]) -> &mut Self1319 pub fn alpn_protocols(&mut self, protocols: &[&str]) -> &mut Self {
1320 self.alpn = Some(protocols.iter().map(|s| s.to_string()).collect());
1321 self
1322 }
1323
1324 /// Configures the use of the RFC 5077 `SessionTicket` extension.
1325 ///
1326 /// Defaults to `false`.
1327 #[cfg(feature = "session-tickets")]
1328 #[inline(always)]
enable_session_tickets(&mut self, enable: bool) -> &mut Self1329 pub fn enable_session_tickets(&mut self, enable: bool) -> &mut Self {
1330 self.enable_session_tickets = enable;
1331 self
1332 }
1333
1334 /// Initiates a new SSL/TLS session over a stream connected to the specified domain.
1335 ///
1336 /// If both SNI and hostname verification are disabled, the value of `domain` will be ignored.
handshake<S>( &self, domain: &str, stream: S, ) -> result::Result<SslStream<S>, ClientHandshakeError<S>> where S: Read + Write,1337 pub fn handshake<S>(
1338 &self,
1339 domain: &str,
1340 stream: S,
1341 ) -> result::Result<SslStream<S>, ClientHandshakeError<S>>
1342 where
1343 S: Read + Write,
1344 {
1345 // the logic for trust validation is in MidHandshakeClientBuilder::connect, so run all
1346 // of the handshake logic through that.
1347 let stream = MidHandshakeSslStream {
1348 stream: self.ctx_into_stream(domain, stream)?,
1349 error: Error::from(errSecSuccess),
1350 };
1351
1352 let certs = self.certs.clone();
1353 let stream = MidHandshakeClientBuilder {
1354 stream,
1355 domain: if self.danger_accept_invalid_hostnames {
1356 None
1357 } else {
1358 Some(domain.to_string())
1359 },
1360 certs,
1361 trust_certs_only: self.trust_certs_only,
1362 danger_accept_invalid_certs: self.danger_accept_invalid_certs,
1363 };
1364 stream.handshake()
1365 }
1366
ctx_into_stream<S>(&self, domain: &str, stream: S) -> Result<SslStream<S>> where S: Read + Write,1367 fn ctx_into_stream<S>(&self, domain: &str, stream: S) -> Result<SslStream<S>>
1368 where
1369 S: Read + Write,
1370 {
1371 let mut ctx = SslContext::new(SslProtocolSide::CLIENT, SslConnectionType::STREAM)?;
1372
1373 if self.use_sni {
1374 ctx.set_peer_domain_name(domain)?;
1375 }
1376 if let Some(ref identity) = self.identity {
1377 ctx.set_certificate(identity, &self.chain)?;
1378 }
1379 #[cfg(feature = "alpn")]
1380 {
1381 if let Some(ref alpn) = self.alpn {
1382 ctx.set_alpn_protocols(&alpn.iter().map(|s| &**s).collect::<Vec<_>>())?;
1383 }
1384 }
1385 #[cfg(feature = "session-tickets")]
1386 {
1387 if self.enable_session_tickets {
1388 // We must use the domain here to ensure that we go through certificate validation
1389 // again rather than resuming the session if the domain changes.
1390 ctx.set_peer_id(domain.as_bytes())?;
1391 ctx.set_session_tickets_enabled(true)?;
1392 }
1393 }
1394 ctx.set_break_on_server_auth(true)?;
1395 self.configure_protocols(&mut ctx)?;
1396 self.configure_ciphers(&mut ctx)?;
1397
1398 ctx.into_stream(stream)
1399 }
1400
configure_protocols(&self, ctx: &mut SslContext) -> Result<()>1401 fn configure_protocols(&self, ctx: &mut SslContext) -> Result<()> {
1402 if let Some(min) = self.protocol_min {
1403 ctx.set_protocol_version_min(min)?;
1404 }
1405 if let Some(max) = self.protocol_max {
1406 ctx.set_protocol_version_max(max)?;
1407 }
1408 Ok(())
1409 }
1410
configure_ciphers(&self, ctx: &mut SslContext) -> Result<()>1411 fn configure_ciphers(&self, ctx: &mut SslContext) -> Result<()> {
1412 let mut ciphers = if self.whitelisted_ciphers.is_empty() {
1413 ctx.enabled_ciphers()?
1414 } else {
1415 self.whitelisted_ciphers.clone()
1416 };
1417
1418 if !self.blacklisted_ciphers.is_empty() {
1419 ciphers.retain(|cipher| !self.blacklisted_ciphers.contains(cipher));
1420 }
1421
1422 ctx.set_enabled_ciphers(&ciphers)?;
1423 Ok(())
1424 }
1425 }
1426
/// A builder type to simplify the creation of server-side `SslStream`s.
#[derive(Debug)]
pub struct ServerBuilder {
    // Identity passed to `SslContext::set_certificate` for every handshake.
    identity: SecIdentity,
    // Certificate chain passed alongside `identity` to `set_certificate`.
    certs: Vec<SecCertificate>,
}
1433
1434 impl ServerBuilder {
1435 /// Creates a new `ServerBuilder` which will use the specified identity
1436 /// and certificate chain for handshakes.
new(identity: &SecIdentity, certs: &[SecCertificate]) -> Self1437 pub fn new(identity: &SecIdentity, certs: &[SecCertificate]) -> Self {
1438 Self {
1439 identity: identity.clone(),
1440 certs: certs.to_owned(),
1441 }
1442 }
1443
1444 /// Creates a new `ServerBuilder` which will use the identity
1445 /// from the given PKCS #12 data.
1446 ///
1447 /// This operation fails if PKCS #12 file contains zero or more than one identity.
1448 ///
1449 /// This is a shortcut for the most common operation.
from_pkcs12(pkcs12_der: &[u8], passphrase: &str) -> Result<Self>1450 pub fn from_pkcs12(pkcs12_der: &[u8], passphrase: &str) -> Result<Self> {
1451 let mut identities: Vec<(SecIdentity, Vec<SecCertificate>)> = Pkcs12ImportOptions::new()
1452 .passphrase(passphrase)
1453 .import(pkcs12_der)?
1454 .into_iter()
1455 .flat_map(|idendity| {
1456 let certs = idendity.cert_chain.unwrap_or_default();
1457 idendity.identity.map(|identity| (identity, certs))
1458 })
1459 .collect();
1460 if identities.len() == 1 {
1461 let (identity, certs) = identities.pop().unwrap();
1462 Ok(ServerBuilder::new(&identity, &certs))
1463 } else {
1464 // This error code is not really helpful
1465 Err(Error::from_code(errSecParam))
1466 }
1467 }
1468
1469 /// Create a SSL context for lower-level stream initialization.
new_ssl_context(&self) -> Result<SslContext>1470 pub fn new_ssl_context(&self) -> Result<SslContext> {
1471 let mut ctx = SslContext::new(SslProtocolSide::SERVER, SslConnectionType::STREAM)?;
1472 ctx.set_certificate(&self.identity, &self.certs)?;
1473 Ok(ctx)
1474 }
1475
1476 /// Initiates a new SSL/TLS session over a stream.
handshake<S>(&self, stream: S) -> Result<SslStream<S>> where S: Read + Write,1477 pub fn handshake<S>(&self, stream: S) -> Result<SslStream<S>>
1478 where
1479 S: Read + Write,
1480 {
1481 match self.new_ssl_context()?.handshake(stream) {
1482 Ok(stream) => Ok(stream),
1483 Err(HandshakeError::Interrupted(stream)) => Err(*stream.error()),
1484 Err(HandshakeError::Failure(err)) => Err(err),
1485 }
1486 }
1487 }
1488
#[cfg(test)]
mod test {
    //! Tests for the Secure Transport wrappers. Most of them open live
    //! connections to public hosts (google.com, expired.badssl.com) and
    //! therefore need network access.

    use std::io;
    use std::io::prelude::*;
    use std::net::TcpStream;

    use super::*;

    #[test]
    fn server_builder_from_pkcs12() {
        let pkcs12_der = include_bytes!("../test/server.p12");
        ServerBuilder::from_pkcs12(pkcs12_der, "password123").unwrap();
    }

    #[test]
    fn connect() {
        let mut ctx = p!(SslContext::new(
            SslProtocolSide::CLIENT,
            SslConnectionType::STREAM
        ));
        p!(ctx.set_peer_domain_name("google.com"));
        let stream = p!(TcpStream::connect("google.com:443"));
        p!(ctx.handshake(stream));
    }

    #[test]
    fn connect_bad_domain() {
        let mut ctx = p!(SslContext::new(
            SslProtocolSide::CLIENT,
            SslConnectionType::STREAM
        ));
        p!(ctx.set_peer_domain_name("foobar.com"));
        let stream = p!(TcpStream::connect("google.com:443"));
        assert!(ctx.handshake(stream).is_err(), "expected failure");
    }

    #[test]
    fn load_page() {
        let mut ctx = p!(SslContext::new(
            SslProtocolSide::CLIENT,
            SslConnectionType::STREAM
        ));
        p!(ctx.set_peer_domain_name("google.com"));
        let stream = p!(TcpStream::connect("google.com:443"));
        let mut stream = p!(ctx.handshake(stream));
        p!(stream.write_all(b"GET / HTTP/1.0\r\n\r\n"));
        p!(stream.flush());
        let mut buf = vec![];
        p!(stream.read_to_end(&mut buf));
        println!("{}", String::from_utf8_lossy(&buf));
    }

    #[test]
    fn client_no_session_ticket_resumption() {
        for _ in 0..2 {
            let stream = p!(TcpStream::connect("google.com:443"));

            // Manually handshake here.
            let stream = MidHandshakeSslStream {
                stream: ClientBuilder::new()
                    .ctx_into_stream("google.com", stream)
                    .unwrap(),
                error: Error::from(errSecSuccess),
            };

            let mut result = stream.handshake();

            if let Err(HandshakeError::Interrupted(stream)) = result {
                assert!(stream.server_auth_completed());
                result = stream.handshake();
            } else {
                panic!("Unexpectedly skipped server auth");
            }

            assert!(result.is_ok());
        }
    }

    #[test]
    #[cfg(feature = "session-tickets")]
    fn client_session_ticket_resumption() {
        // The first time through this loop, we should do a full handshake. The second time, we
        // should immediately finish the handshake without breaking on server auth.
        for i in 0..2 {
            let stream = p!(TcpStream::connect("google.com:443"));
            let mut builder = ClientBuilder::new();
            builder.enable_session_tickets(true);

            // Manually handshake here.
            let stream = MidHandshakeSslStream {
                stream: builder.ctx_into_stream("google.com", stream).unwrap(),
                error: Error::from(errSecSuccess),
            };

            let mut result = stream.handshake();

            if let Err(HandshakeError::Interrupted(stream)) = result {
                assert!(stream.server_auth_completed());
                assert_eq!(
                    i, 0,
                    "Session ticket resumption did not work, server auth was not skipped"
                );
                result = stream.handshake();
            } else {
                assert_eq!(i, 1, "Unexpectedly skipped server auth");
            }

            assert!(result.is_ok());
        }
    }

    #[test]
    #[cfg(feature = "alpn")]
    fn client_alpn_accept() {
        let mut ctx = p!(SslContext::new(
            SslProtocolSide::CLIENT,
            SslConnectionType::STREAM
        ));
        p!(ctx.set_peer_domain_name("google.com"));
        p!(ctx.set_alpn_protocols(&["h2"]));
        let stream = p!(TcpStream::connect("google.com:443"));
        let stream = ctx.handshake(stream).unwrap();
        assert_eq!(vec!["h2"], stream.context().alpn_protocols().unwrap());
    }

    #[test]
    #[cfg(feature = "alpn")]
    fn client_alpn_reject() {
        let mut ctx = p!(SslContext::new(
            SslProtocolSide::CLIENT,
            SslConnectionType::STREAM
        ));
        p!(ctx.set_peer_domain_name("google.com"));
        p!(ctx.set_alpn_protocols(&["h2c"]));
        let stream = p!(TcpStream::connect("google.com:443"));
        let stream = ctx.handshake(stream).unwrap();
        assert!(stream.context().alpn_protocols().is_err());
    }

    #[test]
    fn client_no_anchor_certs() {
        let stream = p!(TcpStream::connect("google.com:443"));
        assert!(ClientBuilder::new()
            .trust_anchor_certificates_only(true)
            .handshake("google.com", stream)
            .is_err());
    }

    #[test]
    fn client_bad_domain() {
        let stream = p!(TcpStream::connect("google.com:443"));
        assert!(ClientBuilder::new()
            .handshake("foobar.com", stream)
            .is_err());
    }

    #[test]
    fn client_bad_domain_ignored() {
        let stream = p!(TcpStream::connect("google.com:443"));
        ClientBuilder::new()
            .danger_accept_invalid_hostnames(true)
            .handshake("foobar.com", stream)
            .unwrap();
    }

    #[test]
    fn connect_no_verify_ssl() {
        let stream = p!(TcpStream::connect("expired.badssl.com:443"));
        let mut builder = ClientBuilder::new();
        builder.danger_accept_invalid_certs(true);
        builder.handshake("expired.badssl.com", stream).unwrap();
    }

    #[test]
    fn load_page_client() {
        let stream = p!(TcpStream::connect("google.com:443"));
        let mut stream = p!(ClientBuilder::new().handshake("google.com", stream));
        p!(stream.write_all(b"GET / HTTP/1.0\r\n\r\n"));
        p!(stream.flush());
        let mut buf = vec![];
        p!(stream.read_to_end(&mut buf));
        println!("{}", String::from_utf8_lossy(&buf));
    }

    #[test]
    #[cfg_attr(target_os = "ios", ignore)] // FIXME what's going on with ios?
    fn cipher_configuration() {
        let mut ctx = p!(SslContext::new(
            SslProtocolSide::SERVER,
            SslConnectionType::STREAM
        ));
        let ciphers = p!(ctx.enabled_ciphers());
        // Keep every other cipher (indices 0, 2, 4, ...).
        let ciphers = ciphers.iter().step_by(2).copied().collect::<Vec<_>>();
        p!(ctx.set_enabled_ciphers(&ciphers));
        assert_eq!(ciphers, p!(ctx.enabled_ciphers()));
    }

    #[test]
    fn test_builder_whitelist_ciphers() {
        let stream = p!(TcpStream::connect("google.com:443"));

        let ctx = p!(SslContext::new(
            SslProtocolSide::CLIENT,
            SslConnectionType::STREAM
        ));
        assert!(p!(ctx.enabled_ciphers()).len() > 1);

        let ciphers = p!(ctx.enabled_ciphers());
        let cipher = ciphers.first().unwrap();
        let stream = p!(ClientBuilder::new()
            .whitelist_ciphers(&[*cipher])
            .ctx_into_stream("google.com", stream));

        assert_eq!(1, p!(stream.context().enabled_ciphers()).len());
    }

    #[test]
    #[cfg_attr(target_os = "ios", ignore)] // FIXME same issue as cipher_configuration
    fn test_builder_blacklist_ciphers() {
        let stream = p!(TcpStream::connect("google.com:443"));

        let ctx = p!(SslContext::new(
            SslProtocolSide::CLIENT,
            SslConnectionType::STREAM
        ));
        let num = p!(ctx.enabled_ciphers()).len();
        assert!(num > 1);

        let ciphers = p!(ctx.enabled_ciphers());
        let cipher = ciphers.first().unwrap();
        let stream = p!(ClientBuilder::new()
            .blacklist_ciphers(&[*cipher])
            .ctx_into_stream("google.com", stream));

        assert_eq!(num - 1, p!(stream.context().enabled_ciphers()).len());
    }

    #[test]
    fn idle_context_peer_trust() {
        let ctx = p!(SslContext::new(
            SslProtocolSide::SERVER,
            SslConnectionType::STREAM
        ));
        assert!(ctx.peer_trust2().is_err());
    }

    #[test]
    fn peer_id() {
        let mut ctx = p!(SslContext::new(
            SslProtocolSide::SERVER,
            SslConnectionType::STREAM
        ));
        assert!(p!(ctx.peer_id()).is_none());
        p!(ctx.set_peer_id(b"foobar"));
        assert_eq!(p!(ctx.peer_id()), Some(&b"foobar"[..]));
    }

    #[test]
    fn peer_domain_name() {
        let mut ctx = p!(SslContext::new(
            SslProtocolSide::CLIENT,
            SslConnectionType::STREAM
        ));
        assert_eq!("", p!(ctx.peer_domain_name()));
        p!(ctx.set_peer_domain_name("foobar.com"));
        assert_eq!("foobar.com", p!(ctx.peer_domain_name()));
    }

    #[test]
    #[should_panic(expected = "blammo")]
    fn write_panic() {
        struct ExplodingStream(TcpStream);

        impl Read for ExplodingStream {
            fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
                self.0.read(buf)
            }
        }

        impl Write for ExplodingStream {
            fn write(&mut self, _: &[u8]) -> io::Result<usize> {
                panic!("blammo");
            }

            fn flush(&mut self) -> io::Result<()> {
                self.0.flush()
            }
        }

        let mut ctx = p!(SslContext::new(
            SslProtocolSide::CLIENT,
            SslConnectionType::STREAM
        ));
        p!(ctx.set_peer_domain_name("google.com"));
        let stream = p!(TcpStream::connect("google.com:443"));
        let _ = ctx.handshake(ExplodingStream(stream));
    }

    #[test]
    #[should_panic(expected = "blammo")]
    fn read_panic() {
        struct ExplodingStream(TcpStream);

        impl Read for ExplodingStream {
            fn read(&mut self, _: &mut [u8]) -> io::Result<usize> {
                panic!("blammo");
            }
        }

        impl Write for ExplodingStream {
            fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
                self.0.write(buf)
            }

            fn flush(&mut self) -> io::Result<()> {
                self.0.flush()
            }
        }

        let mut ctx = p!(SslContext::new(
            SslProtocolSide::CLIENT,
            SslConnectionType::STREAM
        ));
        p!(ctx.set_peer_domain_name("google.com"));
        let stream = p!(TcpStream::connect("google.com:443"));
        let _ = ctx.handshake(ExplodingStream(stream));
    }

    #[test]
    fn zero_length_buffers() {
        let mut ctx = p!(SslContext::new(
            SslProtocolSide::CLIENT,
            SslConnectionType::STREAM
        ));
        p!(ctx.set_peer_domain_name("google.com"));
        let stream = p!(TcpStream::connect("google.com:443"));
        let mut stream = ctx.handshake(stream).unwrap();
        assert_eq!(stream.write(b"").unwrap(), 0);
        assert_eq!(stream.read(&mut []).unwrap(), 0);
    }
}
1837