1 //! SSL/TLS encryption support using Secure Transport.
2 //!
3 //! # Examples
4 //!
5 //! To connect as a client to a server with a certificate trusted by the system:
6 //!
7 //! ```rust
8 //! use std::io::prelude::*;
9 //! use std::net::TcpStream;
10 //! use security_framework::secure_transport::ClientBuilder;
11 //!
12 //! let stream = TcpStream::connect("google.com:443").unwrap();
13 //! let mut stream = ClientBuilder::new().handshake("google.com", stream).unwrap();
14 //!
15 //! stream.write_all(b"GET / HTTP/1.0\r\n\r\n").unwrap();
16 //! let mut page = vec![];
17 //! stream.read_to_end(&mut page).unwrap();
18 //! println!("{}", String::from_utf8_lossy(&page));
19 //! ```
20 //!
21 //! To connect to a server with a certificate that's *not* trusted by the
22 //! system, specify the root certificates for the server's chain to the
23 //! `ClientBuilder`:
24 //!
25 //! ```rust,no_run
26 //! use std::io::prelude::*;
27 //! use std::net::TcpStream;
28 //! use security_framework::secure_transport::ClientBuilder;
29 //!
30 //! # let root_cert = unsafe { std::mem::zeroed() };
31 //! let stream = TcpStream::connect("my_server.com:443").unwrap();
32 //! let mut stream = ClientBuilder::new()
33 //! .anchor_certificates(&[root_cert])
34 //! .handshake("my_server.com", stream)
35 //! .unwrap();
36 //!
37 //! stream.write_all(b"GET / HTTP/1.0\r\n\r\n").unwrap();
38 //! let mut page = vec![];
39 //! stream.read_to_end(&mut page).unwrap();
40 //! println!("{}", String::from_utf8_lossy(&page));
41 //! ```
42 //!
43 //! For more advanced configuration, the `SslContext` type can be used directly.
44 //!
45 //! To run a server:
46 //!
47 //! ```rust,no_run
48 //! use std::net::TcpListener;
49 //! use std::thread;
50 //! use security_framework::secure_transport::{SslContext, SslProtocolSide, SslConnectionType};
51 //!
52 //! // Create a TCP listener and start accepting on it.
53 //! let mut listener = TcpListener::bind("0.0.0.0:443").unwrap();
54 //!
55 //! for stream in listener.incoming() {
56 //! let stream = stream.unwrap();
57 //! thread::spawn(move || {
58 //! // Create a new context configured to operate on the server side of
59 //! // a traditional SSL/TLS session.
60 //! let mut ctx = SslContext::new(SslProtocolSide::SERVER, SslConnectionType::STREAM)
61 //! .unwrap();
62 //!
63 //! // Install the certificate chain that we will be using.
64 //! # let identity = unsafe { std::mem::zeroed() };
65 //! # let intermediate_cert = unsafe { std::mem::zeroed() };
66 //! # let root_cert = unsafe { std::mem::zeroed() };
67 //! ctx.set_certificate(identity, &[intermediate_cert, root_cert]).unwrap();
68 //!
69 //! // Perform the SSL/TLS handshake and get our stream.
70 //! let mut stream = ctx.handshake(stream).unwrap();
71 //! });
72 //! }
73 //!
74 //! ```
75 #[allow(unused_imports)]
76 use core_foundation::array::{CFArray, CFArrayRef};
77
78 use core_foundation::base::{Boolean, TCFType};
79 #[cfg(feature = "alpn")]
80 use core_foundation::string::CFString;
81 use core_foundation_sys::base::{kCFAllocatorDefault, OSStatus};
82 use std::os::raw::c_void;
83
84 #[allow(unused_imports)]
85 use security_framework_sys::base::{
86 errSecBadReq, errSecIO, errSecNotTrusted, errSecSuccess, errSecTrustSettingDeny,
87 errSecUnimplemented,
88 };
89
90 use security_framework_sys::secure_transport::*;
91 use std::any::Any;
92 use std::cmp;
93 use std::fmt;
94 use std::io;
95 use std::io::prelude::*;
96 use std::marker::PhantomData;
97 use std::mem;
98 use std::panic::{self, AssertUnwindSafe};
99 use std::ptr;
100 use std::result;
101 use std::slice;
102
103 use crate::base::{Error, Result};
104 use crate::certificate::SecCertificate;
105 use crate::cipher_suite::CipherSuite;
106 use crate::identity::SecIdentity;
107 use crate::policy::SecPolicy;
108 use crate::trust::{SecTrust, TrustResult};
109 use crate::{cvt, AsInner};
110
111 /// Specifies a side of a TLS session.
112 #[derive(Debug, Copy, Clone, PartialEq, Eq)]
113 pub struct SslProtocolSide(SSLProtocolSide);
114
115 impl SslProtocolSide {
116 /// The server side of the session.
117 pub const SERVER: Self = Self(kSSLServerSide);
118
119 /// The client side of the session.
120 pub const CLIENT: Self = Self(kSSLClientSide);
121 }
122
123 /// Specifies the type of TLS session.
124 #[derive(Debug, Copy, Clone)]
125 pub struct SslConnectionType(SSLConnectionType);
126
127 impl SslConnectionType {
128 /// A traditional TLS stream.
129 pub const STREAM: Self = Self(kSSLStreamType);
130
131 /// A DTLS session.
132 pub const DATAGRAM: Self = Self(kSSLDatagramType);
133 }
134
135 /// An error or intermediate state after a TLS handshake attempt.
136 #[derive(Debug)]
137 pub enum HandshakeError<S> {
138 /// The handshake failed.
139 Failure(Error),
140 /// The handshake was interrupted midway through.
141 Interrupted(MidHandshakeSslStream<S>),
142 }
143
144 impl<S> From<Error> for HandshakeError<S> {
from(err: Error) -> Self145 fn from(err: Error) -> Self {
146 Self::Failure(err)
147 }
148 }
149
150 /// An error or intermediate state after a TLS handshake attempt.
151 #[derive(Debug)]
152 pub enum ClientHandshakeError<S> {
153 /// The handshake failed.
154 Failure(Error),
155 /// The handshake was interrupted midway through.
156 Interrupted(MidHandshakeClientBuilder<S>),
157 }
158
159 impl<S> From<Error> for ClientHandshakeError<S> {
from(err: Error) -> Self160 fn from(err: Error) -> Self {
161 Self::Failure(err)
162 }
163 }
164
165 /// An SSL stream midway through the handshake process.
166 #[derive(Debug)]
167 pub struct MidHandshakeSslStream<S> {
168 stream: SslStream<S>,
169 error: Error,
170 }
171
172 impl<S> MidHandshakeSslStream<S> {
173 /// Returns a shared reference to the inner stream.
get_ref(&self) -> &S174 pub fn get_ref(&self) -> &S {
175 self.stream.get_ref()
176 }
177
178 /// Returns a mutable reference to the inner stream.
get_mut(&mut self) -> &mut S179 pub fn get_mut(&mut self) -> &mut S {
180 self.stream.get_mut()
181 }
182
183 /// Returns a shared reference to the `SslContext` of the stream.
context(&self) -> &SslContext184 pub fn context(&self) -> &SslContext {
185 self.stream.context()
186 }
187
188 /// Returns a mutable reference to the `SslContext` of the stream.
context_mut(&mut self) -> &mut SslContext189 pub fn context_mut(&mut self) -> &mut SslContext {
190 self.stream.context_mut()
191 }
192
193 /// Returns `true` iff `break_on_server_auth` was set and the handshake has
194 /// progressed to that point.
server_auth_completed(&self) -> bool195 pub fn server_auth_completed(&self) -> bool {
196 self.error.code() == errSSLPeerAuthCompleted
197 }
198
199 /// Returns `true` iff `break_on_cert_requested` was set and the handshake
200 /// has progressed to that point.
client_cert_requested(&self) -> bool201 pub fn client_cert_requested(&self) -> bool {
202 self.error.code() == errSSLClientCertRequested
203 }
204
205 /// Returns `true` iff the underlying stream returned an error with the
206 /// `WouldBlock` kind.
would_block(&self) -> bool207 pub fn would_block(&self) -> bool {
208 self.error.code() == errSSLWouldBlock
209 }
210
211 /// Returns the error which caused the handshake interruption.
error(&self) -> &Error212 pub fn error(&self) -> &Error {
213 &self.error
214 }
215
216 /// Restarts the handshake process.
handshake(self) -> result::Result<SslStream<S>, HandshakeError<S>>217 pub fn handshake(self) -> result::Result<SslStream<S>, HandshakeError<S>> {
218 self.stream.handshake()
219 }
220 }
221
222 /// An SSL stream midway through the handshake process.
223 #[derive(Debug)]
224 pub struct MidHandshakeClientBuilder<S> {
225 stream: MidHandshakeSslStream<S>,
226 domain: Option<String>,
227 certs: Vec<SecCertificate>,
228 trust_certs_only: bool,
229 danger_accept_invalid_certs: bool,
230 }
231
232 impl<S> MidHandshakeClientBuilder<S> {
233 /// Returns a shared reference to the inner stream.
get_ref(&self) -> &S234 pub fn get_ref(&self) -> &S {
235 self.stream.get_ref()
236 }
237
238 /// Returns a mutable reference to the inner stream.
get_mut(&mut self) -> &mut S239 pub fn get_mut(&mut self) -> &mut S {
240 self.stream.get_mut()
241 }
242
243 /// Returns the error which caused the handshake interruption.
error(&self) -> &Error244 pub fn error(&self) -> &Error {
245 self.stream.error()
246 }
247
248 /// Restarts the handshake process.
handshake(self) -> result::Result<SslStream<S>, ClientHandshakeError<S>>249 pub fn handshake(self) -> result::Result<SslStream<S>, ClientHandshakeError<S>> {
250 let MidHandshakeClientBuilder {
251 stream,
252 domain,
253 certs,
254 trust_certs_only,
255 danger_accept_invalid_certs,
256 } = self;
257
258 let mut result = stream.handshake();
259 loop {
260 let stream = match result {
261 Ok(stream) => return Ok(stream),
262 Err(HandshakeError::Interrupted(stream)) => stream,
263 Err(HandshakeError::Failure(err)) => {
264 return Err(ClientHandshakeError::Failure(err))
265 }
266 };
267
268 if stream.would_block() {
269 let ret = MidHandshakeClientBuilder {
270 stream,
271 domain,
272 certs,
273 trust_certs_only,
274 danger_accept_invalid_certs,
275 };
276 return Err(ClientHandshakeError::Interrupted(ret));
277 }
278
279 if stream.server_auth_completed() {
280 if danger_accept_invalid_certs {
281 result = stream.handshake();
282 continue;
283 }
284 let mut trust = match stream.context().peer_trust2()? {
285 Some(trust) => trust,
286 None => {
287 result = stream.handshake();
288 continue;
289 }
290 };
291 trust.set_anchor_certificates(&certs)?;
292 trust.set_trust_anchor_certificates_only(self.trust_certs_only)?;
293 let policy =
294 SecPolicy::create_ssl(SslProtocolSide::SERVER, domain.as_ref().map(|s| &**s));
295 trust.set_policy(&policy)?;
296 let trusted = trust.evaluate()?;
297 match trusted {
298 TrustResult::PROCEED | TrustResult::UNSPECIFIED => {
299 result = stream.handshake();
300 continue;
301 }
302 TrustResult::DENY => {
303 let err = Error::from_code(errSecTrustSettingDeny);
304 return Err(ClientHandshakeError::Failure(err));
305 }
306 TrustResult::RECOVERABLE_TRUST_FAILURE | TrustResult::FATAL_TRUST_FAILURE => {
307 let err = Error::from_code(errSecNotTrusted);
308 return Err(ClientHandshakeError::Failure(err));
309 }
310 TrustResult::INVALID | TrustResult::OTHER_ERROR | _ => {
311 let err = Error::from_code(errSecBadReq);
312 return Err(ClientHandshakeError::Failure(err));
313 }
314 }
315 }
316
317 let err = Error::from_code(stream.error().code());
318 return Err(ClientHandshakeError::Failure(err));
319 }
320 }
321 }
322
323 /// Specifies the state of a TLS session.
324 #[derive(Debug, PartialEq, Eq)]
325 pub struct SessionState(SSLSessionState);
326
327 impl SessionState {
328 /// The session has not yet started.
329 pub const IDLE: Self = Self(kSSLIdle);
330
331 /// The session is in the handshake process.
332 pub const HANDSHAKE: Self = Self(kSSLHandshake);
333
334 /// The session is connected.
335 pub const CONNECTED: Self = Self(kSSLConnected);
336
337 /// The session has been terminated.
338 pub const CLOSED: Self = Self(kSSLClosed);
339
340 /// The session has been aborted due to an error.
341 pub const ABORTED: Self = Self(kSSLAborted);
342 }
343
344 /// Specifies a server's requirement for client certificates.
345 #[derive(Debug, Copy, Clone, PartialEq, Eq)]
346 pub struct SslAuthenticate(SSLAuthenticate);
347
348 impl SslAuthenticate {
349 /// Do not request a client certificate.
350 pub const NEVER: Self = Self(kNeverAuthenticate);
351
352 /// Require a client certificate.
353 pub const ALWAYS: Self = Self(kAlwaysAuthenticate);
354
355 /// Request but do not require a client certificate.
356 pub const TRY: Self = Self(kTryAuthenticate);
357 }
358
359 /// Specifies the state of client certificate processing.
360 #[derive(Debug, Copy, Clone, PartialEq, Eq)]
361 pub struct SslClientCertificateState(SSLClientCertificateState);
362
363 impl SslClientCertificateState {
364 /// A client certificate has not been requested or sent.
365 pub const NONE: Self = Self(kSSLClientCertNone);
366
367 /// A client certificate has been requested but not recieved.
368 pub const REQUESTED: Self = Self(kSSLClientCertRequested);
369 /// A client certificate has been received and successfully validated.
370 pub const SENT: Self = Self(kSSLClientCertSent);
371
372 /// A client certificate has been received but has failed to validate.
373 pub const REJECTED: Self = Self(kSSLClientCertRejected); }
374
375 /// Specifies protocol versions.
376 #[derive(Debug, Copy, Clone, PartialEq, Eq)]
377 pub struct SslProtocol(SSLProtocol);
378
379 impl SslProtocol {
380 /// No protocol has been or should be negotiated or specified; use the default.
381 pub const UNKNOWN: Self = Self(kSSLProtocolUnknown);
382
383 /// The SSL 3.0 protocol is preferred, though SSL 2.0 may be used if the peer does not support
384 /// SSL 3.0.
385 pub const SSL3: Self = Self(kSSLProtocol3);
386
387 /// The TLS 1.0 protocol is preferred, though lower versions may be used
388 /// if the peer does not support TLS 1.0.
389 pub const TLS1: Self = Self(kTLSProtocol1);
390
391 /// The TLS 1.1 protocol is preferred, though lower versions may be used
392 /// if the peer does not support TLS 1.1.
393 pub const TLS11: Self = Self(kTLSProtocol11);
394
395 /// The TLS 1.2 protocol is preferred, though lower versions may be used
396 /// if the peer does not support TLS 1.2.
397 pub const TLS12: Self = Self(kTLSProtocol12);
398
399 /// The TLS 1.3 protocol is preferred, though lower versions may be used
400 /// if the peer does not support TLS 1.3.
401 pub const TLS13: Self = Self(kTLSProtocol13);
402
403 /// Only the SSL 2.0 protocol is accepted.
404 pub const SSL2: Self = Self(kSSLProtocol2);
405
406 /// The DTLSv1 protocol is preferred.
407 pub const DTLS1: Self = Self(kDTLSProtocol1);
408
409 /// Only the SSL 3.0 protocol is accepted.
410 pub const SSL3_ONLY: Self = Self(kSSLProtocol3Only);
411
412 /// Only the TLS 1.0 protocol is accepted.
413 pub const TLS1_ONLY: Self = Self(kTLSProtocol1Only);
414
415 /// All supported TLS/SSL versions are accepted.
416 pub const ALL: Self = Self(kSSLProtocolAll);
417 }
418
// The core-foundation macros generate the `SslContext` tuple struct wrapping
// a raw `SSLContextRef` (accessed as `self.0` throughout this file) and its
// Core Foundation type plumbing, with `SSLContextGetTypeID` supplying the
// type ID.
declare_TCFType! {
    /// A Secure Transport SSL/TLS context object.
    SslContext, SSLContextRef
}

impl_TCFType!(SslContext, SSLContextRef, SSLContextGetTypeID);
425
426 impl fmt::Debug for SslContext {
fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result427 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
428 let mut builder = fmt.debug_struct("SslContext");
429 if let Ok(state) = self.state() {
430 builder.field("state", &state);
431 }
432 builder.finish()
433 }
434 }
435
// NOTE(review): these impls assert that an `SSLContextRef` may be moved to
// and shared between threads. Nothing in this file establishes Secure
// Transport's thread-safety guarantees for a context; confirm against
// Apple's documentation. The configuration setters on `SslContext` take
// `&mut self`, so safe Rust cannot mutate one context from two threads at once.
unsafe impl Sync for SslContext {}
unsafe impl Send for SslContext {}
438
439 impl AsInner for SslContext {
440 type Inner = SSLContextRef;
441
as_inner(&self) -> SSLContextRef442 fn as_inner(&self) -> SSLContextRef {
443 self.0
444 }
445 }
446
// Generates a setter/getter pair for each listed Secure Transport session
// option. An entry `const $opt: $get & $set,` expands to:
//   * `$set(&mut self, value: bool) -> Result<()>` calling `SSLSetSessionOption`
//   * `$get(&self) -> Result<bool>` calling `SSLGetSessionOption`
// Every attribute on the entry (doc comments, `#[cfg]`) is copied onto both
// generated methods.
macro_rules! impl_options {
    ($($(#[$a:meta])* const $opt:ident: $get:ident & $set:ident,)*) => {
        $(
            $(#[$a])*
            pub fn $set(&mut self, value: bool) -> Result<()> {
                unsafe { cvt(SSLSetSessionOption(self.0, $opt, value as Boolean)) }
            }

            $(#[$a])*
            pub fn $get(&self) -> Result<bool> {
                // The option comes back as a C `Boolean`; any non-zero value
                // means "enabled".
                let mut value = 0;
                unsafe { cvt(SSLGetSessionOption(self.0, $opt, &mut value))?; }
                Ok(value != 0)
            }
        )*
    }
}
464
465 impl SslContext {
466 /// Creates a new `SslContext` for the specified side and type of SSL
467 /// connection.
new(side: SslProtocolSide, type_: SslConnectionType) -> Result<Self>468 pub fn new(side: SslProtocolSide, type_: SslConnectionType) -> Result<Self> {
469 unsafe {
470 let ctx = SSLCreateContext(kCFAllocatorDefault, side.0, type_.0);
471 Ok(Self(ctx))
472 }
473 }
474
475 /// Sets the fully qualified domain name of the peer.
476 ///
477 /// This will be used on the client side of a session to validate the
478 /// common name field of the server's certificate. It has no effect if
479 /// called on a server-side `SslContext`.
480 ///
481 /// It is *highly* recommended to call this method before starting the
482 /// handshake process.
set_peer_domain_name(&mut self, peer_name: &str) -> Result<()>483 pub fn set_peer_domain_name(&mut self, peer_name: &str) -> Result<()> {
484 unsafe {
485 // SSLSetPeerDomainName doesn't need a null terminated string
486 cvt(SSLSetPeerDomainName(
487 self.0,
488 peer_name.as_ptr() as *const _,
489 peer_name.len(),
490 ))
491 }
492 }
493
494 /// Returns the peer domain name set by `set_peer_domain_name`.
peer_domain_name(&self) -> Result<String>495 pub fn peer_domain_name(&self) -> Result<String> {
496 unsafe {
497 let mut len = 0;
498 cvt(SSLGetPeerDomainNameLength(self.0, &mut len))?;
499 let mut buf = vec![0; len];
500 cvt(SSLGetPeerDomainName(
501 self.0,
502 buf.as_mut_ptr() as *mut _,
503 &mut len,
504 ))?;
505 Ok(String::from_utf8(buf).unwrap())
506 }
507 }
508
509 /// Sets the certificate to be used by this side of the SSL session.
510 ///
511 /// This must be called before the handshake for server-side connections,
512 /// and can be used on the client-side to specify a client certificate.
513 ///
514 /// The `identity` corresponds to the leaf certificate and private
515 /// key, and the `certs` correspond to extra certificates in the chain.
set_certificate( &mut self, identity: &SecIdentity, certs: &[SecCertificate], ) -> Result<()>516 pub fn set_certificate(
517 &mut self,
518 identity: &SecIdentity,
519 certs: &[SecCertificate],
520 ) -> Result<()> {
521 let mut arr = vec![identity.as_CFType()];
522 arr.extend(certs.iter().map(|c| c.as_CFType()));
523 let certs = CFArray::from_CFTypes(&arr);
524
525 unsafe { cvt(SSLSetCertificate(self.0, certs.as_concrete_TypeRef())) }
526 }
527
528 /// Sets the peer ID of this session.
529 ///
530 /// A peer ID is an opaque sequence of bytes that will be used by Secure
531 /// Transport to identify the peer of an SSL session. If the peer ID of
532 /// this session matches that of a previously terminated session, the
533 /// previous session can be resumed without requiring a full handshake.
set_peer_id(&mut self, peer_id: &[u8]) -> Result<()>534 pub fn set_peer_id(&mut self, peer_id: &[u8]) -> Result<()> {
535 unsafe {
536 cvt(SSLSetPeerID(
537 self.0,
538 peer_id.as_ptr() as *const _,
539 peer_id.len(),
540 ))
541 }
542 }
543
544 /// Returns the peer ID of this session.
peer_id(&self) -> Result<Option<&[u8]>>545 pub fn peer_id(&self) -> Result<Option<&[u8]>> {
546 unsafe {
547 let mut ptr = ptr::null();
548 let mut len = 0;
549 cvt(SSLGetPeerID(self.0, &mut ptr, &mut len))?;
550 if ptr.is_null() {
551 Ok(None)
552 } else {
553 Ok(Some(slice::from_raw_parts(ptr as *const _, len)))
554 }
555 }
556 }
557
558 /// Returns the list of ciphers that are supported by Secure Transport.
supported_ciphers(&self) -> Result<Vec<CipherSuite>>559 pub fn supported_ciphers(&self) -> Result<Vec<CipherSuite>> {
560 unsafe {
561 let mut num_ciphers = 0;
562 cvt(SSLGetNumberSupportedCiphers(self.0, &mut num_ciphers))?;
563 let mut ciphers = vec![0; num_ciphers];
564 cvt(SSLGetSupportedCiphers(
565 self.0,
566 ciphers.as_mut_ptr(),
567 &mut num_ciphers,
568 ))?;
569 Ok(ciphers.iter().map(|c| CipherSuite::from_raw(*c)).collect())
570 }
571 }
572
573 /// Returns the list of ciphers that are eligible to be used for
574 /// negotiation.
enabled_ciphers(&self) -> Result<Vec<CipherSuite>>575 pub fn enabled_ciphers(&self) -> Result<Vec<CipherSuite>> {
576 unsafe {
577 let mut num_ciphers = 0;
578 cvt(SSLGetNumberEnabledCiphers(self.0, &mut num_ciphers))?;
579 let mut ciphers = vec![0; num_ciphers];
580 cvt(SSLGetEnabledCiphers(
581 self.0,
582 ciphers.as_mut_ptr(),
583 &mut num_ciphers,
584 ))?;
585 Ok(ciphers.iter().map(|c| CipherSuite::from_raw(*c)).collect())
586 }
587 }
588
589 /// Sets the list of ciphers that are eligible to be used for negotiation.
set_enabled_ciphers(&mut self, ciphers: &[CipherSuite]) -> Result<()>590 pub fn set_enabled_ciphers(&mut self, ciphers: &[CipherSuite]) -> Result<()> {
591 let ciphers = ciphers.iter().map(|c| c.to_raw()).collect::<Vec<_>>();
592 unsafe {
593 cvt(SSLSetEnabledCiphers(
594 self.0,
595 ciphers.as_ptr(),
596 ciphers.len(),
597 ))
598 }
599 }
600
601 /// Returns the cipher being used by the session.
negotiated_cipher(&self) -> Result<CipherSuite>602 pub fn negotiated_cipher(&self) -> Result<CipherSuite> {
603 unsafe {
604 let mut cipher = 0;
605 cvt(SSLGetNegotiatedCipher(self.0, &mut cipher))?;
606 Ok(CipherSuite::from_raw(cipher))
607 }
608 }
609
610 /// Sets the requirements for client certificates.
611 ///
612 /// Should only be called on server-side sessions.
set_client_side_authenticate(&mut self, auth: SslAuthenticate) -> Result<()>613 pub fn set_client_side_authenticate(&mut self, auth: SslAuthenticate) -> Result<()> {
614 unsafe { cvt(SSLSetClientSideAuthenticate(self.0, auth.0)) }
615 }
616
617 /// Returns the state of client certificate processing.
client_certificate_state(&self) -> Result<SslClientCertificateState>618 pub fn client_certificate_state(&self) -> Result<SslClientCertificateState> {
619 let mut state = 0;
620
621 unsafe {
622 cvt(SSLGetClientCertificateState(self.0, &mut state))?;
623 }
624 Ok(SslClientCertificateState(state))
625 }
626
627 /// Returns the `SecTrust` object corresponding to the peer.
628 ///
629 /// This can be used in conjunction with `set_break_on_server_auth` to
630 /// validate certificates which do not have roots in the default set.
peer_trust2(&self) -> Result<Option<SecTrust>>631 pub fn peer_trust2(&self) -> Result<Option<SecTrust>> {
632 // Calling SSLCopyPeerTrust on an idle connection does not seem to be well defined,
633 // so explicitly check for that
634 if self.state()? == SessionState::IDLE {
635 return Err(Error::from_code(errSecBadReq));
636 }
637
638 unsafe {
639 let mut trust = ptr::null_mut();
640 cvt(SSLCopyPeerTrust(self.0, &mut trust))?;
641 if trust.is_null() {
642 Ok(None)
643 } else {
644 Ok(Some(SecTrust::wrap_under_create_rule(trust)))
645 }
646 }
647 }
648
649 /// Returns the state of the session.
state(&self) -> Result<SessionState>650 pub fn state(&self) -> Result<SessionState> {
651 unsafe {
652 let mut state = 0;
653 cvt(SSLGetSessionState(self.0, &mut state))?;
654 Ok(SessionState(state))
655 }
656 }
657
658 /// Returns the protocol version being used by the session.
negotiated_protocol_version(&self) -> Result<SslProtocol>659 pub fn negotiated_protocol_version(&self) -> Result<SslProtocol> {
660 unsafe {
661 let mut version = 0;
662 cvt(SSLGetNegotiatedProtocolVersion(self.0, &mut version))?;
663 Ok(SslProtocol(version))
664 }
665 }
666
667 /// Returns the maximum protocol version allowed by the session.
protocol_version_max(&self) -> Result<SslProtocol>668 pub fn protocol_version_max(&self) -> Result<SslProtocol> {
669 unsafe {
670 let mut version = 0;
671 cvt(SSLGetProtocolVersionMax(self.0, &mut version))?;
672 Ok(SslProtocol(version))
673 }
674 }
675
676 /// Sets the maximum protocol version allowed by the session.
set_protocol_version_max(&mut self, max_version: SslProtocol) -> Result<()>677 pub fn set_protocol_version_max(&mut self, max_version: SslProtocol) -> Result<()> {
678 unsafe { cvt(SSLSetProtocolVersionMax(self.0, max_version.0)) }
679 }
680
681 /// Returns the minimum protocol version allowed by the session.
protocol_version_min(&self) -> Result<SslProtocol>682 pub fn protocol_version_min(&self) -> Result<SslProtocol> {
683 unsafe {
684 let mut version = 0;
685 cvt(SSLGetProtocolVersionMin(self.0, &mut version))?;
686 Ok(SslProtocol(version))
687 }
688 }
689
690 /// Sets the minimum protocol version allowed by the session.
set_protocol_version_min(&mut self, min_version: SslProtocol) -> Result<()>691 pub fn set_protocol_version_min(&mut self, min_version: SslProtocol) -> Result<()> {
692 unsafe { cvt(SSLSetProtocolVersionMin(self.0, min_version.0)) }
693 }
694
695 /// Returns the set of protocols selected via ALPN if it succeeded.
696 #[cfg(feature = "alpn")]
alpn_protocols(&self) -> Result<Vec<String>>697 pub fn alpn_protocols(&self) -> Result<Vec<String>> {
698 let mut array: CFArrayRef = ptr::null();
699 unsafe {
700 #[cfg(feature = "OSX_10_13")]
701 {
702 cvt(SSLCopyALPNProtocols(self.0, &mut array))?;
703 }
704
705 #[cfg(not(feature = "OSX_10_13"))]
706 {
707 dlsym! { fn SSLCopyALPNProtocols(SSLContextRef, *mut CFArrayRef) -> OSStatus }
708 if let Some(f) = SSLCopyALPNProtocols.get() {
709 cvt(f(self.0, &mut array))?;
710 } else {
711 return Err(Error::from_code(errSecUnimplemented));
712 }
713 }
714
715 if array.is_null() {
716 return Ok(vec![]);
717 }
718
719 let array = CFArray::<CFString>::wrap_under_create_rule(array);
720 Ok(array.into_iter().map(|p| p.to_string()).collect())
721 }
722 }
723
724 /// Configures the set of protocols use for ALPN.
725 ///
726 /// This is only used for client-side connections.
727 #[cfg(feature = "alpn")]
set_alpn_protocols(&mut self, protocols: &[&str]) -> Result<()>728 pub fn set_alpn_protocols(&mut self, protocols: &[&str]) -> Result<()> {
729 // When CFMutableArray is added to core-foundation and IntoIterator trait
730 // is implemented for CFMutableArray, the code below should directly collect
731 // into a CFMutableArray.
732 let protocols = CFArray::from_CFTypes(
733 &protocols
734 .iter()
735 .map(|proto| CFString::new(proto))
736 .collect::<Vec<_>>(),
737 );
738
739 #[cfg(feature = "OSX_10_13")]
740 {
741 unsafe { cvt(SSLSetALPNProtocols(self.0, protocols.as_concrete_TypeRef())) }
742 }
743 #[cfg(not(feature = "OSX_10_13"))]
744 {
745 dlsym! { fn SSLSetALPNProtocols(SSLContextRef, CFArrayRef) -> OSStatus }
746 if let Some(f) = SSLSetALPNProtocols.get() {
747 unsafe { cvt(f(self.0, protocols.as_concrete_TypeRef())) }
748 } else {
749 Err(Error::from_code(errSecUnimplemented))
750 }
751 }
752 }
753
754 /// Sets whether the client sends the `SessionTicket` extension in its `ClientHello`.
755 ///
756 /// On its own, this will just cause the client to send an empty `SessionTicket` extension on
757 /// every connection. [`SslContext::set_peer_id`] must also be used to key the session
758 /// ticket returned by the server.
759 ///
760 /// [`SslContext::set_peer_id`]: #method.set_peer_id
761 #[cfg(feature = "session-tickets")]
set_session_tickets_enabled(&mut self, enabled: bool) -> Result<()>762 pub fn set_session_tickets_enabled(&mut self, enabled: bool) -> Result<()> {
763 #[cfg(feature = "OSX_10_13")]
764 {
765 unsafe { cvt(SSLSetSessionTicketsEnabled(self.0, enabled as Boolean)) }
766 }
767 #[cfg(not(feature = "OSX_10_13"))]
768 {
769 dlsym! { fn SSLSetSessionTicketsEnabled(SSLContextRef, Boolean) -> OSStatus }
770 if let Some(f) = SSLSetSessionTicketsEnabled.get() {
771 unsafe { cvt(f(self.0, enabled as Boolean)) }
772 } else {
773 Err(Error::from_code(errSecUnimplemented))
774 }
775 }
776 }
777
778 /// Sets whether a protocol is enabled or not.
779 ///
780 /// # Note
781 ///
782 /// On OSX this is a deprecated API in favor of `set_protocol_version_max` and
783 /// `set_protocol_version_min`, although if you're working with OSX 10.8 or before you may have
784 /// to use this API instead.
785 #[cfg(target_os = "macos")]
786 #[deprecated(note = "use `set_protocol_version_max`")]
set_protocol_version_enabled( &mut self, protocol: SslProtocol, enabled: bool, ) -> Result<()>787 pub fn set_protocol_version_enabled(
788 &mut self,
789 protocol: SslProtocol,
790 enabled: bool,
791 ) -> Result<()> {
792 unsafe {
793 cvt(SSLSetProtocolVersionEnabled(
794 self.0,
795 protocol.0,
796 enabled as Boolean,
797 ))
798 }
799 }
800
801 /// Returns the number of bytes which can be read without triggering a
802 /// `read` call in the underlying stream.
buffered_read_size(&self) -> Result<usize>803 pub fn buffered_read_size(&self) -> Result<usize> {
804 unsafe {
805 let mut size = 0;
806 cvt(SSLGetBufferedReadSize(self.0, &mut size))?;
807 Ok(size)
808 }
809 }
810
811 impl_options! {
812 /// If enabled, the handshake process will pause and return instead of
813 /// automatically validating a server's certificate.
814 const kSSLSessionOptionBreakOnServerAuth: break_on_server_auth & set_break_on_server_auth,
815 /// If enabled, the handshake process will pause and return after
816 /// the server requests a certificate from the client.
817 const kSSLSessionOptionBreakOnCertRequested: break_on_cert_requested & set_break_on_cert_requested,
818 /// If enabled, the handshake process will pause and return instead of
819 /// automatically validating a client's certificate.
820 const kSSLSessionOptionBreakOnClientAuth: break_on_client_auth & set_break_on_client_auth,
821 /// If enabled, TLS false start will be performed if an appropriate
822 /// cipher suite is negotiated.
823 ///
824 /// Requires the `OSX_10_9` (or greater) feature.
825 #[cfg(feature = "OSX_10_9")]
826 const kSSLSessionOptionFalseStart: false_start & set_false_start,
827 /// If enabled, 1/n-1 record splitting will be enabled for TLS 1.0
828 /// connections using block ciphers to mitigate the BEAST attack.
829 ///
830 /// Requires the `OSX_10_9` (or greater) feature.
831 #[cfg(feature = "OSX_10_9")]
832 const kSSLSessionOptionSendOneByteRecord: send_one_byte_record & set_send_one_byte_record,
833 }
834
into_stream<S>(self, stream: S) -> Result<SslStream<S>> where S: Read + Write,835 fn into_stream<S>(self, stream: S) -> Result<SslStream<S>>
836 where
837 S: Read + Write,
838 {
839 unsafe {
840 let ret = SSLSetIOFuncs(self.0, read_func::<S>, write_func::<S>);
841 if ret != errSecSuccess {
842 return Err(Error::from_code(ret));
843 }
844
845 let stream = Connection {
846 stream,
847 err: None,
848 panic: None,
849 };
850 let stream = Box::into_raw(Box::new(stream));
851 let ret = SSLSetConnection(self.0, stream as *mut _);
852 if ret != errSecSuccess {
853 let _conn = Box::from_raw(stream);
854 return Err(Error::from_code(ret));
855 }
856
857 Ok(SslStream {
858 ctx: self,
859 _m: PhantomData,
860 })
861 }
862 }
863
864 /// Performs the SSL/TLS handshake.
handshake<S>(self, stream: S) -> result::Result<SslStream<S>, HandshakeError<S>> where S: Read + Write,865 pub fn handshake<S>(self, stream: S) -> result::Result<SslStream<S>, HandshakeError<S>>
866 where
867 S: Read + Write,
868 {
869 self.into_stream(stream)
870 .map_err(HandshakeError::Failure)
871 .and_then(SslStream::handshake)
872 }
873 }
874
/// State registered with Secure Transport as the connection ref.
///
/// `SslContext::into_stream` boxes one of these and passes it to
/// `SSLSetConnection`. `read_func`/`write_func` borrow it back inside the I/O
/// callbacks, and `SslStream`'s `Drop` frees the box.
struct Connection<S> {
    /// The wrapped transport stream.
    stream: S,
    /// Last I/O error returned by `stream` inside a callback.
    err: Option<io::Error>,
    /// Panic payload caught while calling into `stream` inside a callback.
    panic: Option<Box<dyn Any + Send>>,
}
880
881 // the logic here is based off of libcurl's
882
translate_err(e: &io::Error) -> OSStatus883 fn translate_err(e: &io::Error) -> OSStatus {
884 match e.kind() {
885 io::ErrorKind::NotFound => errSSLClosedGraceful,
886 io::ErrorKind::ConnectionReset => errSSLClosedAbort,
887 io::ErrorKind::WouldBlock |
888 io::ErrorKind::NotConnected => errSSLWouldBlock,
889 _ => errSecIO,
890 }
891 }
892
read_func<S>( connection: SSLConnectionRef, data: *mut c_void, data_length: *mut usize, ) -> OSStatus where S: Read,893 unsafe extern "C" fn read_func<S>(
894 connection: SSLConnectionRef,
895 data: *mut c_void,
896 data_length: *mut usize,
897 ) -> OSStatus
898 where
899 S: Read,
900 {
901 let conn: &mut Connection<S> = &mut *(connection as *mut _);
902 let data = slice::from_raw_parts_mut(data as *mut u8, *data_length);
903 let mut start = 0;
904 let mut ret = errSecSuccess;
905
906 while start < data.len() {
907 match panic::catch_unwind(AssertUnwindSafe(|| conn.stream.read(&mut data[start..]))) {
908 Ok(Ok(0)) => {
909 ret = errSSLClosedNoNotify;
910 break;
911 }
912 Ok(Ok(len)) => start += len,
913 Ok(Err(e)) => {
914 ret = translate_err(&e);
915 conn.err = Some(e);
916 break;
917 }
918 Err(e) => {
919 ret = errSecIO;
920 conn.panic = Some(e);
921 break;
922 }
923 }
924 }
925
926 *data_length = start;
927 ret
928 }
929
write_func<S>( connection: SSLConnectionRef, data: *const c_void, data_length: *mut usize, ) -> OSStatus where S: Write,930 unsafe extern "C" fn write_func<S>(
931 connection: SSLConnectionRef,
932 data: *const c_void,
933 data_length: *mut usize,
934 ) -> OSStatus
935 where
936 S: Write,
937 {
938 let conn: &mut Connection<S> = &mut *(connection as *mut _);
939 let data = slice::from_raw_parts(data as *mut u8, *data_length);
940 let mut start = 0;
941 let mut ret = errSecSuccess;
942
943 while start < data.len() {
944 match panic::catch_unwind(AssertUnwindSafe(|| conn.stream.write(&data[start..]))) {
945 Ok(Ok(0)) => {
946 ret = errSSLClosedNoNotify;
947 break;
948 }
949 Ok(Ok(len)) => start += len,
950 Ok(Err(e)) => {
951 ret = translate_err(&e);
952 conn.err = Some(e);
953 break;
954 }
955 Err(e) => {
956 ret = errSecIO;
957 conn.panic = Some(e);
958 break;
959 }
960 }
961 }
962
963 *data_length = start;
964 ret
965 }
966
967 /// A type implementing SSL/TLS encryption over an underlying stream.
968 pub struct SslStream<S> {
969 ctx: SslContext,
970 _m: PhantomData<S>,
971 }
972
973 impl<S: fmt::Debug> fmt::Debug for SslStream<S> {
fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result974 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
975 fmt.debug_struct("SslStream")
976 .field("context", &self.ctx)
977 .field("stream", self.get_ref())
978 .finish()
979 }
980 }
981
982 impl<S> Drop for SslStream<S> {
drop(&mut self)983 fn drop(&mut self) {
984 unsafe {
985 let mut conn = ptr::null();
986 let ret = SSLGetConnection(self.ctx.0, &mut conn);
987 assert!(ret == errSecSuccess);
988 Box::<Connection<S>>::from_raw(conn as *mut _);
989 }
990 }
991 }
992
993 impl<S> SslStream<S> {
handshake(mut self) -> result::Result<Self, HandshakeError<S>>994 fn handshake(mut self) -> result::Result<Self, HandshakeError<S>> {
995 match unsafe { SSLHandshake(self.ctx.0) } {
996 errSecSuccess => Ok(self),
997 reason @ errSSLPeerAuthCompleted
998 | reason @ errSSLClientCertRequested
999 | reason @ errSSLWouldBlock
1000 | reason @ errSSLClientHelloReceived => {
1001 Err(HandshakeError::Interrupted(MidHandshakeSslStream {
1002 stream: self,
1003 error: Error::from_code(reason),
1004 }))
1005 }
1006 err => {
1007 self.check_panic();
1008 Err(HandshakeError::Failure(Error::from_code(err)))
1009 }
1010 }
1011 }
1012
1013 /// Returns a shared reference to the inner stream.
get_ref(&self) -> &S1014 pub fn get_ref(&self) -> &S {
1015 &self.connection().stream
1016 }
1017
1018 /// Returns a mutable reference to the underlying stream.
get_mut(&mut self) -> &mut S1019 pub fn get_mut(&mut self) -> &mut S {
1020 &mut self.connection_mut().stream
1021 }
1022
1023 /// Returns a shared reference to the `SslContext` of the stream.
context(&self) -> &SslContext1024 pub fn context(&self) -> &SslContext {
1025 &self.ctx
1026 }
1027
1028 /// Returns a mutable reference to the `SslContext` of the stream.
context_mut(&mut self) -> &mut SslContext1029 pub fn context_mut(&mut self) -> &mut SslContext {
1030 &mut self.ctx
1031 }
1032
1033 /// Shuts down the connection.
close(&mut self) -> result::Result<(), io::Error>1034 pub fn close(&mut self) -> result::Result<(), io::Error> {
1035 unsafe {
1036 let ret = SSLClose(self.ctx.0);
1037 if ret == errSecSuccess {
1038 Ok(())
1039 } else {
1040 Err(self.get_error(ret))
1041 }
1042 }
1043 }
1044
connection(&self) -> &Connection<S>1045 fn connection(&self) -> &Connection<S> {
1046 unsafe {
1047 let mut conn = ptr::null();
1048 let ret = SSLGetConnection(self.ctx.0, &mut conn);
1049 assert!(ret == errSecSuccess);
1050
1051 mem::transmute(conn)
1052 }
1053 }
1054
connection_mut(&mut self) -> &mut Connection<S>1055 fn connection_mut(&mut self) -> &mut Connection<S> {
1056 unsafe {
1057 let mut conn = ptr::null();
1058 let ret = SSLGetConnection(self.ctx.0, &mut conn);
1059 assert!(ret == errSecSuccess);
1060
1061 mem::transmute(conn)
1062 }
1063 }
1064
check_panic(&mut self)1065 fn check_panic(&mut self) {
1066 let conn = self.connection_mut();
1067 if let Some(err) = conn.panic.take() {
1068 panic::resume_unwind(err);
1069 }
1070 }
1071
get_error(&mut self, ret: OSStatus) -> io::Error1072 fn get_error(&mut self, ret: OSStatus) -> io::Error {
1073 self.check_panic();
1074
1075 if let Some(err) = self.connection_mut().err.take() {
1076 err
1077 } else {
1078 io::Error::new(io::ErrorKind::Other, Error::from_code(ret))
1079 }
1080 }
1081 }
1082
1083 impl<S: Read + Write> Read for SslStream<S> {
read(&mut self, buf: &mut [u8]) -> io::Result<usize>1084 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
1085 // Below we base our return value off the amount of data read, so a
1086 // zero-length buffer might cause us to erroneously interpret this
1087 // request as an error. Instead short-circuit that logic and return
1088 // `Ok(0)` instead.
1089 if buf.is_empty() {
1090 return Ok(0);
1091 }
1092
1093 // If some data was buffered but not enough to fill `buf`, SSLRead
1094 // will try to read a new packet. This is bad because there may be
1095 // no more data but the socket is remaining open (e.g HTTPS with
1096 // Connection: keep-alive).
1097 let buffered = self.context().buffered_read_size().unwrap_or(0);
1098 let to_read = if buffered > 0 {
1099 cmp::min(buffered, buf.len())
1100 } else {
1101 buf.len()
1102 };
1103
1104 unsafe {
1105 let mut nread = 0;
1106 let ret = SSLRead(self.ctx.0, buf.as_mut_ptr() as *mut _, to_read, &mut nread);
1107 // SSLRead can return an error at the same time it returns the last
1108 // chunk of data (!)
1109 if nread > 0 {
1110 return Ok(nread as usize);
1111 }
1112
1113 match ret {
1114 errSSLClosedGraceful | errSSLClosedAbort | errSSLClosedNoNotify => Ok(0),
1115 _ => Err(self.get_error(ret)),
1116 }
1117 }
1118 }
1119 }
1120
1121 impl<S: Read + Write> Write for SslStream<S> {
write(&mut self, buf: &[u8]) -> io::Result<usize>1122 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
1123 // Like above in read, short circuit a 0-length write
1124 if buf.is_empty() {
1125 return Ok(0);
1126 }
1127 unsafe {
1128 let mut nwritten = 0;
1129 let ret = SSLWrite(
1130 self.ctx.0,
1131 buf.as_ptr() as *const _,
1132 buf.len(),
1133 &mut nwritten,
1134 );
1135 // just to be safe, base success off of nwritten rather than ret
1136 // for the same reason as in read
1137 if nwritten > 0 {
1138 Ok(nwritten as usize)
1139 } else {
1140 Err(self.get_error(ret))
1141 }
1142 }
1143 }
1144
flush(&mut self) -> io::Result<()>1145 fn flush(&mut self) -> io::Result<()> {
1146 self.connection_mut().stream.flush()
1147 }
1148 }
1149
1150 /// A builder type to simplify the creation of client side `SslStream`s.
1151 #[derive(Debug)]
1152 pub struct ClientBuilder {
1153 identity: Option<SecIdentity>,
1154 certs: Vec<SecCertificate>,
1155 chain: Vec<SecCertificate>,
1156 protocol_min: Option<SslProtocol>,
1157 protocol_max: Option<SslProtocol>,
1158 trust_certs_only: bool,
1159 use_sni: bool,
1160 danger_accept_invalid_certs: bool,
1161 danger_accept_invalid_hostnames: bool,
1162 whitelisted_ciphers: Vec<CipherSuite>,
1163 blacklisted_ciphers: Vec<CipherSuite>,
1164 alpn: Option<Vec<String>>,
1165 enable_session_tickets: bool,
1166 }
1167
1168 impl Default for ClientBuilder {
default() -> Self1169 fn default() -> Self {
1170 Self::new()
1171 }
1172 }
1173
1174 impl ClientBuilder {
1175 /// Creates a new builder with default options.
new() -> Self1176 pub fn new() -> Self {
1177 Self {
1178 identity: None,
1179 certs: Vec::new(),
1180 chain: Vec::new(),
1181 protocol_min: None,
1182 protocol_max: None,
1183 trust_certs_only: false,
1184 use_sni: true,
1185 danger_accept_invalid_certs: false,
1186 danger_accept_invalid_hostnames: false,
1187 whitelisted_ciphers: Vec::new(),
1188 blacklisted_ciphers: Vec::new(),
1189 alpn: None,
1190 enable_session_tickets: false,
1191 }
1192 }
1193
1194 /// Specifies the set of root certificates to trust when
1195 /// verifying the server's certificate.
anchor_certificates(&mut self, certs: &[SecCertificate]) -> &mut Self1196 pub fn anchor_certificates(&mut self, certs: &[SecCertificate]) -> &mut Self {
1197 self.certs = certs.to_owned();
1198 self
1199 }
1200
1201 /// Specifies whether to trust the built-in certificates in addition
1202 /// to specified anchor certificates.
trust_anchor_certificates_only(&mut self, only: bool) -> &mut Self1203 pub fn trust_anchor_certificates_only(&mut self, only: bool) -> &mut Self {
1204 self.trust_certs_only = only;
1205 self
1206 }
1207
1208 /// Specifies whether to trust invalid certificates.
1209 ///
1210 /// # Warning
1211 ///
1212 /// You should think very carefully before using this method. If invalid
1213 /// certificates are trusted, *any* certificate for *any* site will be
1214 /// trusted for use. This includes expired certificates. This introduces
1215 /// significant vulnerabilities, and should only be used as a last resort.
danger_accept_invalid_certs(&mut self, noverify: bool) -> &mut Self1216 pub fn danger_accept_invalid_certs(&mut self, noverify: bool) -> &mut Self {
1217 self.danger_accept_invalid_certs = noverify;
1218 self
1219 }
1220
1221 /// Specifies whether to use Server Name Indication (SNI).
use_sni(&mut self, use_sni: bool) -> &mut Self1222 pub fn use_sni(&mut self, use_sni: bool) -> &mut Self {
1223 self.use_sni = use_sni;
1224 self
1225 }
1226
1227 /// Specifies whether to verify that the server's hostname matches its certificate.
1228 ///
1229 /// # Warning
1230 ///
1231 /// You should think very carefully before using this method. If hostnames are not verified,
1232 /// *any* valid certificate for *any* site will be trusted for use. This introduces significant
1233 /// vulnerabilities, and should only be used as a last resort.
danger_accept_invalid_hostnames( &mut self, danger_accept_invalid_hostnames: bool, ) -> &mut Self1234 pub fn danger_accept_invalid_hostnames(
1235 &mut self,
1236 danger_accept_invalid_hostnames: bool,
1237 ) -> &mut Self {
1238 self.danger_accept_invalid_hostnames = danger_accept_invalid_hostnames;
1239 self
1240 }
1241
1242 /// Set a whitelist of enabled ciphers. Any ciphers not whitelisted will be disabled.
whitelist_ciphers(&mut self, whitelisted_ciphers: &[CipherSuite]) -> &mut Self1243 pub fn whitelist_ciphers(&mut self, whitelisted_ciphers: &[CipherSuite]) -> &mut Self {
1244 self.whitelisted_ciphers = whitelisted_ciphers.to_owned();
1245 self
1246 }
1247
1248 /// Set a blacklist of disabled ciphers. Blacklisted ciphers will be disabled.
blacklist_ciphers(&mut self, blacklisted_ciphers: &[CipherSuite]) -> &mut Self1249 pub fn blacklist_ciphers(&mut self, blacklisted_ciphers: &[CipherSuite]) -> &mut Self {
1250 self.blacklisted_ciphers = blacklisted_ciphers.to_owned();
1251 self
1252 }
1253
1254 /// Use the specified identity as a SSL/TLS client certificate.
identity(&mut self, identity: &SecIdentity, chain: &[SecCertificate]) -> &mut Self1255 pub fn identity(&mut self, identity: &SecIdentity, chain: &[SecCertificate]) -> &mut Self {
1256 self.identity = Some(identity.clone());
1257 self.chain = chain.to_owned();
1258 self
1259 }
1260
1261 /// Configure the minimum protocol that this client will support.
protocol_min(&mut self, min: SslProtocol) -> &mut Self1262 pub fn protocol_min(&mut self, min: SslProtocol) -> &mut Self {
1263 self.protocol_min = Some(min);
1264 self
1265 }
1266
1267 /// Configure the minimum protocol that this client will support.
protocol_max(&mut self, max: SslProtocol) -> &mut Self1268 pub fn protocol_max(&mut self, max: SslProtocol) -> &mut Self {
1269 self.protocol_max = Some(max);
1270 self
1271 }
1272
1273 /// Configures the set of protocols used for ALPN.
1274 #[cfg(feature = "alpn")]
alpn_protocols(&mut self, protocols: &[&str]) -> &mut Self1275 pub fn alpn_protocols(&mut self, protocols: &[&str]) -> &mut Self {
1276 self.alpn = Some(protocols.iter().map(|s| s.to_string()).collect());
1277 self
1278 }
1279
1280 /// Configures the use of the RFC 5077 `SessionTicket` extension.
1281 ///
1282 /// Defaults to `false`.
1283 #[cfg(feature = "session-tickets")]
enable_session_tickets(&mut self, enable: bool) -> &mut Self1284 pub fn enable_session_tickets(&mut self, enable: bool) -> &mut Self {
1285 self.enable_session_tickets = enable;
1286 self
1287 }
1288
1289 /// Initiates a new SSL/TLS session over a stream connected to the specified domain.
1290 ///
1291 /// If both SNI and hostname verification are disabled, the value of `domain` will be ignored.
handshake<S>( &self, domain: &str, stream: S, ) -> result::Result<SslStream<S>, ClientHandshakeError<S>> where S: Read + Write,1292 pub fn handshake<S>(
1293 &self,
1294 domain: &str,
1295 stream: S,
1296 ) -> result::Result<SslStream<S>, ClientHandshakeError<S>>
1297 where
1298 S: Read + Write,
1299 {
1300 // the logic for trust validation is in MidHandshakeClientBuilder::connect, so run all
1301 // of the handshake logic through that.
1302 let stream = MidHandshakeSslStream {
1303 stream: self.ctx_into_stream(domain, stream)?,
1304 error: Error::from(errSecSuccess),
1305 };
1306
1307 let certs = self.certs.clone();
1308 let stream = MidHandshakeClientBuilder {
1309 stream,
1310 domain: if self.danger_accept_invalid_hostnames {
1311 None
1312 } else {
1313 Some(domain.to_string())
1314 },
1315 certs,
1316 trust_certs_only: self.trust_certs_only,
1317 danger_accept_invalid_certs: self.danger_accept_invalid_certs,
1318 };
1319 stream.handshake()
1320 }
1321
ctx_into_stream<S>(&self, domain: &str, stream: S) -> Result<SslStream<S>> where S: Read + Write,1322 fn ctx_into_stream<S>(&self, domain: &str, stream: S) -> Result<SslStream<S>>
1323 where
1324 S: Read + Write,
1325 {
1326 let mut ctx = SslContext::new(SslProtocolSide::CLIENT, SslConnectionType::STREAM)?;
1327
1328 if self.use_sni {
1329 ctx.set_peer_domain_name(domain)?;
1330 }
1331 if let Some(ref identity) = self.identity {
1332 ctx.set_certificate(identity, &self.chain)?;
1333 }
1334 #[cfg(feature = "alpn")]
1335 {
1336 if let Some(ref alpn) = self.alpn {
1337 ctx.set_alpn_protocols(&alpn.iter().map(|s| &**s).collect::<Vec<_>>())?;
1338 }
1339 }
1340 #[cfg(feature = "session-tickets")]
1341 {
1342 if self.enable_session_tickets {
1343 // We must use the domain here to ensure that we go through certificate validation
1344 // again rather than resuming the session if the domain changes.
1345 ctx.set_peer_id(domain.as_bytes())?;
1346 ctx.set_session_tickets_enabled(true)?;
1347 }
1348 }
1349 ctx.set_break_on_server_auth(true)?;
1350 self.configure_protocols(&mut ctx)?;
1351 self.configure_ciphers(&mut ctx)?;
1352
1353 ctx.into_stream(stream)
1354 }
1355
configure_protocols(&self, ctx: &mut SslContext) -> Result<()>1356 fn configure_protocols(&self, ctx: &mut SslContext) -> Result<()> {
1357 if let Some(min) = self.protocol_min {
1358 ctx.set_protocol_version_min(min)?;
1359 }
1360 if let Some(max) = self.protocol_max {
1361 ctx.set_protocol_version_max(max)?;
1362 }
1363 Ok(())
1364 }
1365
configure_ciphers(&self, ctx: &mut SslContext) -> Result<()>1366 fn configure_ciphers(&self, ctx: &mut SslContext) -> Result<()> {
1367 let mut ciphers = if self.whitelisted_ciphers.is_empty() {
1368 ctx.enabled_ciphers()?
1369 } else {
1370 self.whitelisted_ciphers.clone()
1371 };
1372
1373 if !self.blacklisted_ciphers.is_empty() {
1374 ciphers.retain(|cipher| !self.blacklisted_ciphers.contains(cipher));
1375 }
1376
1377 ctx.set_enabled_ciphers(&ciphers)?;
1378 Ok(())
1379 }
1380 }
1381
1382 /// A builder type to simplify the creation of server-side `SslStream`s.
1383 #[derive(Debug)]
1384 pub struct ServerBuilder {
1385 identity: SecIdentity,
1386 certs: Vec<SecCertificate>,
1387 }
1388
1389 impl ServerBuilder {
1390 /// Creates a new `ServerBuilder` which will use the specified identity
1391 /// and certificate chain for handshakes.
new(identity: &SecIdentity, certs: &[SecCertificate]) -> Self1392 pub fn new(identity: &SecIdentity, certs: &[SecCertificate]) -> Self {
1393 Self {
1394 identity: identity.clone(),
1395 certs: certs.to_owned(),
1396 }
1397 }
1398
1399 /// Initiates a new SSL/TLS session over a stream.
handshake<S>(&self, stream: S) -> Result<SslStream<S>> where S: Read + Write,1400 pub fn handshake<S>(&self, stream: S) -> Result<SslStream<S>>
1401 where
1402 S: Read + Write,
1403 {
1404 let mut ctx = SslContext::new(SslProtocolSide::SERVER, SslConnectionType::STREAM)?;
1405 ctx.set_certificate(&self.identity, &self.certs)?;
1406 match ctx.handshake(stream) {
1407 Ok(stream) => Ok(stream),
1408 Err(HandshakeError::Interrupted(stream)) => Err(stream.error().clone()),
1409 Err(HandshakeError::Failure(err)) => Err(err),
1410 }
1411 }
1412 }
1413
#[cfg(test)]
mod test {
    use std::io;
    use std::io::prelude::*;
    use std::net::TcpStream;

    use super::*;

    #[test]
    fn connect() {
        let mut ctx = p!(SslContext::new(
            SslProtocolSide::CLIENT,
            SslConnectionType::STREAM
        ));
        p!(ctx.set_peer_domain_name("google.com"));
        let stream = p!(TcpStream::connect("google.com:443"));
        p!(ctx.handshake(stream));
    }

    #[test]
    fn connect_bad_domain() {
        let mut ctx = p!(SslContext::new(
            SslProtocolSide::CLIENT,
            SslConnectionType::STREAM
        ));
        p!(ctx.set_peer_domain_name("foobar.com"));
        let stream = p!(TcpStream::connect("google.com:443"));
        // google.com's certificate must not validate for foobar.com.
        assert!(ctx.handshake(stream).is_err(), "expected failure");
    }

    #[test]
    fn load_page() {
        let mut ctx = p!(SslContext::new(
            SslProtocolSide::CLIENT,
            SslConnectionType::STREAM
        ));
        p!(ctx.set_peer_domain_name("google.com"));
        let stream = p!(TcpStream::connect("google.com:443"));
        let mut stream = p!(ctx.handshake(stream));
        p!(stream.write_all(b"GET / HTTP/1.0\r\n\r\n"));
        p!(stream.flush());
        let mut buf = vec![];
        p!(stream.read_to_end(&mut buf));
        println!("{}", String::from_utf8_lossy(&buf));
    }

    #[test]
    fn client_no_session_ticket_resumption() {
        for _ in 0..2 {
            let stream = p!(TcpStream::connect("google.com:443"));

            // Manually handshake here.
            let stream = MidHandshakeSslStream {
                stream: ClientBuilder::new()
                    .ctx_into_stream("google.com", stream)
                    .unwrap(),
                error: Error::from(errSecSuccess),
            };

            let mut result = stream.handshake();

            if let Err(HandshakeError::Interrupted(stream)) = result {
                assert!(stream.server_auth_completed());
                result = stream.handshake();
            } else {
                panic!("Unexpectedly skipped server auth");
            }

            assert!(result.is_ok());
        }
    }

    #[test]
    #[cfg(feature = "session-tickets")]
    fn client_session_ticket_resumption() {
        // The first time through this loop, we should do a full handshake. The second time, we
        // should immediately finish the handshake without breaking on server auth.
        for i in 0..2 {
            let stream = p!(TcpStream::connect("google.com:443"));
            let mut builder = ClientBuilder::new();
            builder.enable_session_tickets(true);

            // Manually handshake here.
            let stream = MidHandshakeSslStream {
                stream: builder.ctx_into_stream("google.com", stream).unwrap(),
                error: Error::from(errSecSuccess),
            };

            let mut result = stream.handshake();

            if let Err(HandshakeError::Interrupted(stream)) = result {
                assert!(stream.server_auth_completed());
                assert_eq!(
                    i, 0,
                    "Session ticket resumption did not work, server auth was not skipped"
                );
                result = stream.handshake();
            } else {
                assert_eq!(i, 1, "Unexpectedly skipped server auth");
            }

            assert!(result.is_ok());
        }
    }

    #[test]
    #[cfg(feature = "alpn")]
    fn client_alpn_accept() {
        let mut ctx = p!(SslContext::new(
            SslProtocolSide::CLIENT,
            SslConnectionType::STREAM
        ));
        p!(ctx.set_peer_domain_name("google.com"));
        p!(ctx.set_alpn_protocols(&["h2"]));
        let stream = p!(TcpStream::connect("google.com:443"));
        let stream = ctx.handshake(stream).unwrap();
        assert_eq!(vec!["h2"], stream.context().alpn_protocols().unwrap());
    }

    #[test]
    #[cfg(feature = "alpn")]
    fn client_alpn_reject() {
        let mut ctx = p!(SslContext::new(
            SslProtocolSide::CLIENT,
            SslConnectionType::STREAM
        ));
        p!(ctx.set_peer_domain_name("google.com"));
        p!(ctx.set_alpn_protocols(&["h2c"]));
        let stream = p!(TcpStream::connect("google.com:443"));
        let stream = ctx.handshake(stream).unwrap();
        assert!(stream.context().alpn_protocols().is_err());
    }

    #[test]
    fn client_no_anchor_certs() {
        let stream = p!(TcpStream::connect("google.com:443"));
        assert!(ClientBuilder::new()
            .trust_anchor_certificates_only(true)
            .handshake("google.com", stream)
            .is_err());
    }

    #[test]
    fn client_bad_domain() {
        let stream = p!(TcpStream::connect("google.com:443"));
        assert!(ClientBuilder::new()
            .handshake("foobar.com", stream)
            .is_err());
    }

    #[test]
    fn client_bad_domain_ignored() {
        let stream = p!(TcpStream::connect("google.com:443"));
        ClientBuilder::new()
            .danger_accept_invalid_hostnames(true)
            .handshake("foobar.com", stream)
            .unwrap();
    }

    #[test]
    fn connect_no_verify_ssl() {
        let stream = p!(TcpStream::connect("expired.badssl.com:443"));
        let mut builder = ClientBuilder::new();
        builder.danger_accept_invalid_certs(true);
        builder.handshake("expired.badssl.com", stream).unwrap();
    }

    #[test]
    fn load_page_client() {
        let stream = p!(TcpStream::connect("google.com:443"));
        let mut stream = p!(ClientBuilder::new().handshake("google.com", stream));
        p!(stream.write_all(b"GET / HTTP/1.0\r\n\r\n"));
        p!(stream.flush());
        let mut buf = vec![];
        p!(stream.read_to_end(&mut buf));
        println!("{}", String::from_utf8_lossy(&buf));
    }

    #[test]
    #[cfg_attr(target_os = "ios", ignore)] // FIXME what's going on with ios?
    fn cipher_configuration() {
        let mut ctx = p!(SslContext::new(
            SslProtocolSide::SERVER,
            SslConnectionType::STREAM
        ));
        let ciphers = p!(ctx.enabled_ciphers());
        // Keep every other cipher (indices 0, 2, 4, ...).
        let ciphers = ciphers.iter().step_by(2).copied().collect::<Vec<_>>();
        p!(ctx.set_enabled_ciphers(&ciphers));
        assert_eq!(ciphers, p!(ctx.enabled_ciphers()));
    }

    #[test]
    fn test_builder_whitelist_ciphers() {
        let stream = p!(TcpStream::connect("google.com:443"));

        let ctx = p!(SslContext::new(
            SslProtocolSide::CLIENT,
            SslConnectionType::STREAM
        ));
        assert!(p!(ctx.enabled_ciphers()).len() > 1);

        let ciphers = p!(ctx.enabled_ciphers());
        let cipher = ciphers.first().unwrap();
        let stream = p!(ClientBuilder::new()
            .whitelist_ciphers(&[*cipher])
            .ctx_into_stream("google.com", stream));

        assert_eq!(1, p!(stream.context().enabled_ciphers()).len());
    }

    #[test]
    #[cfg_attr(target_os = "ios", ignore)] // FIXME same issue as cipher_configuration
    fn test_builder_blacklist_ciphers() {
        let stream = p!(TcpStream::connect("google.com:443"));

        let ctx = p!(SslContext::new(
            SslProtocolSide::CLIENT,
            SslConnectionType::STREAM
        ));
        let num = p!(ctx.enabled_ciphers()).len();
        assert!(num > 1);

        let ciphers = p!(ctx.enabled_ciphers());
        let cipher = ciphers.first().unwrap();
        let stream = p!(ClientBuilder::new()
            .blacklist_ciphers(&[*cipher])
            .ctx_into_stream("google.com", stream));

        assert_eq!(num - 1, p!(stream.context().enabled_ciphers()).len());
    }

    #[test]
    fn idle_context_peer_trust() {
        let ctx = p!(SslContext::new(
            SslProtocolSide::SERVER,
            SslConnectionType::STREAM
        ));
        assert!(ctx.peer_trust2().is_err());
    }

    #[test]
    fn peer_id() {
        let mut ctx = p!(SslContext::new(
            SslProtocolSide::SERVER,
            SslConnectionType::STREAM
        ));
        assert!(p!(ctx.peer_id()).is_none());
        p!(ctx.set_peer_id(b"foobar"));
        assert_eq!(p!(ctx.peer_id()), Some(&b"foobar"[..]));
    }

    #[test]
    fn peer_domain_name() {
        let mut ctx = p!(SslContext::new(
            SslProtocolSide::CLIENT,
            SslConnectionType::STREAM
        ));
        assert_eq!("", p!(ctx.peer_domain_name()));
        p!(ctx.set_peer_domain_name("foobar.com"));
        assert_eq!("foobar.com", p!(ctx.peer_domain_name()));
    }

    #[test]
    #[should_panic(expected = "blammo")]
    fn write_panic() {
        struct ExplodingStream(TcpStream);

        impl Read for ExplodingStream {
            fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
                self.0.read(buf)
            }
        }

        impl Write for ExplodingStream {
            fn write(&mut self, _: &[u8]) -> io::Result<usize> {
                panic!("blammo");
            }

            fn flush(&mut self) -> io::Result<()> {
                self.0.flush()
            }
        }

        let mut ctx = p!(SslContext::new(
            SslProtocolSide::CLIENT,
            SslConnectionType::STREAM
        ));
        p!(ctx.set_peer_domain_name("google.com"));
        let stream = p!(TcpStream::connect("google.com:443"));
        let _ = ctx.handshake(ExplodingStream(stream));
    }

    #[test]
    #[should_panic(expected = "blammo")]
    fn read_panic() {
        struct ExplodingStream(TcpStream);

        impl Read for ExplodingStream {
            fn read(&mut self, _: &mut [u8]) -> io::Result<usize> {
                panic!("blammo");
            }
        }

        impl Write for ExplodingStream {
            fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
                self.0.write(buf)
            }

            fn flush(&mut self) -> io::Result<()> {
                self.0.flush()
            }
        }

        let mut ctx = p!(SslContext::new(
            SslProtocolSide::CLIENT,
            SslConnectionType::STREAM
        ));
        p!(ctx.set_peer_domain_name("google.com"));
        let stream = p!(TcpStream::connect("google.com:443"));
        let _ = ctx.handshake(ExplodingStream(stream));
    }

    #[test]
    fn zero_length_buffers() {
        let mut ctx = p!(SslContext::new(
            SslProtocolSide::CLIENT,
            SslConnectionType::STREAM
        ));
        p!(ctx.set_peer_domain_name("google.com"));
        let stream = p!(TcpStream::connect("google.com:443"));
        let mut stream = ctx.handshake(stream).unwrap();
        assert_eq!(stream.write(b"").unwrap(), 0);
        assert_eq!(stream.read(&mut []).unwrap(), 0);
    }
}
1756