1 //! SSL/TLS encryption support using Secure Transport.
2 //!
3 //! # Examples
4 //!
5 //! To connect as a client to a server with a certificate trusted by the system:
6 //!
7 //! ```rust
8 //! use std::io::prelude::*;
9 //! use std::net::TcpStream;
10 //! use security_framework::secure_transport::ClientBuilder;
11 //!
12 //! let stream = TcpStream::connect("google.com:443").unwrap();
13 //! let mut stream = ClientBuilder::new().handshake("google.com", stream).unwrap();
14 //!
15 //! stream.write_all(b"GET / HTTP/1.0\r\n\r\n").unwrap();
16 //! let mut page = vec![];
17 //! stream.read_to_end(&mut page).unwrap();
18 //! println!("{}", String::from_utf8_lossy(&page));
19 //! ```
20 //!
21 //! To connect to a server with a certificate that's *not* trusted by the
22 //! system, specify the root certificates for the server's chain to the
23 //! `ClientBuilder`:
24 //!
25 //! ```rust,no_run
26 //! use std::io::prelude::*;
27 //! use std::net::TcpStream;
28 //! use security_framework::secure_transport::ClientBuilder;
29 //!
30 //! # let root_cert = unsafe { std::mem::zeroed() };
31 //! let stream = TcpStream::connect("my_server.com:443").unwrap();
32 //! let mut stream = ClientBuilder::new()
33 //! .anchor_certificates(&[root_cert])
34 //! .handshake("my_server.com", stream)
35 //! .unwrap();
36 //!
37 //! stream.write_all(b"GET / HTTP/1.0\r\n\r\n").unwrap();
38 //! let mut page = vec![];
39 //! stream.read_to_end(&mut page).unwrap();
40 //! println!("{}", String::from_utf8_lossy(&page));
41 //! ```
42 //!
43 //! For more advanced configuration, the `SslContext` type can be used directly.
44 //!
45 //! To run a server:
46 //!
47 //! ```rust,no_run
48 //! use std::net::TcpListener;
49 //! use std::thread;
50 //! use security_framework::secure_transport::{SslContext, SslProtocolSide, SslConnectionType};
51 //!
52 //! // Create a TCP listener and start accepting on it.
53 //! let mut listener = TcpListener::bind("0.0.0.0:443").unwrap();
54 //!
55 //! for stream in listener.incoming() {
56 //! let stream = stream.unwrap();
57 //! thread::spawn(move || {
58 //! // Create a new context configured to operate on the server side of
59 //! // a traditional SSL/TLS session.
60 //! let mut ctx = SslContext::new(SslProtocolSide::SERVER, SslConnectionType::STREAM)
61 //! .unwrap();
62 //!
63 //! // Install the certificate chain that we will be using.
64 //! # let identity = unsafe { std::mem::zeroed() };
65 //! # let intermediate_cert = unsafe { std::mem::zeroed() };
66 //! # let root_cert = unsafe { std::mem::zeroed() };
67 //! ctx.set_certificate(identity, &[intermediate_cert, root_cert]).unwrap();
68 //!
69 //! // Perform the SSL/TLS handshake and get our stream.
70 //! let mut stream = ctx.handshake(stream).unwrap();
71 //! });
72 //! }
73 //!
74 //! ```
75 #[allow(unused_imports)]
76 use core_foundation::array::{CFArray, CFArrayRef};
77
78 use core_foundation::base::{Boolean, TCFType};
79 #[cfg(feature = "alpn")]
80 use core_foundation::string::CFString;
81 use core_foundation_sys::base::{kCFAllocatorDefault, OSStatus};
82 use std::os::raw::c_void;
83
84 #[allow(unused_imports)]
85 use security_framework_sys::base::{
86 errSecBadReq, errSecIO, errSecNotTrusted, errSecSuccess, errSecTrustSettingDeny,
87 errSecUnimplemented,
88 };
89
90 use security_framework_sys::secure_transport::*;
91 use std::any::Any;
92 use std::cmp;
93 use std::fmt;
94 use std::io;
95 use std::io::prelude::*;
96 use std::marker::PhantomData;
97 use std::mem;
98 use std::panic::{self, AssertUnwindSafe};
99 use std::ptr;
100 use std::result;
101 use std::slice;
102
103 use crate::base::{Error, Result};
104 use crate::certificate::SecCertificate;
105 use crate::cipher_suite::CipherSuite;
106 use crate::identity::SecIdentity;
107 use crate::policy::SecPolicy;
108 use crate::trust::{SecTrust, TrustResult};
109 use crate::{cvt, AsInner};
110
111 /// Specifies a side of a TLS session.
112 #[derive(Debug, Copy, Clone, PartialEq, Eq)]
113 pub struct SslProtocolSide(SSLProtocolSide);
114
115 impl SslProtocolSide {
116 /// The server side of the session.
117 pub const SERVER: Self = Self(kSSLServerSide);
118
119 /// The client side of the session.
120 pub const CLIENT: Self = Self(kSSLClientSide);
121 }
122
123 /// Specifies the type of TLS session.
124 #[derive(Debug, Copy, Clone)]
125 pub struct SslConnectionType(SSLConnectionType);
126
127 impl SslConnectionType {
128 /// A traditional TLS stream.
129 pub const STREAM: Self = Self(kSSLStreamType);
130
131 /// A DTLS session.
132 pub const DATAGRAM: Self = Self(kSSLDatagramType);
133 }
134
135 /// An error or intermediate state after a TLS handshake attempt.
136 #[derive(Debug)]
137 pub enum HandshakeError<S> {
138 /// The handshake failed.
139 Failure(Error),
140 /// The handshake was interrupted midway through.
141 Interrupted(MidHandshakeSslStream<S>),
142 }
143
144 impl<S> From<Error> for HandshakeError<S> {
from(err: Error) -> Self145 fn from(err: Error) -> Self {
146 Self::Failure(err)
147 }
148 }
149
150 /// An error or intermediate state after a TLS handshake attempt.
151 #[derive(Debug)]
152 pub enum ClientHandshakeError<S> {
153 /// The handshake failed.
154 Failure(Error),
155 /// The handshake was interrupted midway through.
156 Interrupted(MidHandshakeClientBuilder<S>),
157 }
158
159 impl<S> From<Error> for ClientHandshakeError<S> {
from(err: Error) -> Self160 fn from(err: Error) -> Self {
161 Self::Failure(err)
162 }
163 }
164
165 /// An SSL stream midway through the handshake process.
166 #[derive(Debug)]
167 pub struct MidHandshakeSslStream<S> {
168 stream: SslStream<S>,
169 error: Error,
170 }
171
172 impl<S> MidHandshakeSslStream<S> {
173 /// Returns a shared reference to the inner stream.
get_ref(&self) -> &S174 pub fn get_ref(&self) -> &S {
175 self.stream.get_ref()
176 }
177
178 /// Returns a mutable reference to the inner stream.
get_mut(&mut self) -> &mut S179 pub fn get_mut(&mut self) -> &mut S {
180 self.stream.get_mut()
181 }
182
183 /// Returns a shared reference to the `SslContext` of the stream.
context(&self) -> &SslContext184 pub fn context(&self) -> &SslContext {
185 self.stream.context()
186 }
187
188 /// Returns a mutable reference to the `SslContext` of the stream.
context_mut(&mut self) -> &mut SslContext189 pub fn context_mut(&mut self) -> &mut SslContext {
190 self.stream.context_mut()
191 }
192
193 /// Returns `true` iff `break_on_server_auth` was set and the handshake has
194 /// progressed to that point.
server_auth_completed(&self) -> bool195 pub fn server_auth_completed(&self) -> bool {
196 self.error.code() == errSSLPeerAuthCompleted
197 }
198
199 /// Returns `true` iff `break_on_cert_requested` was set and the handshake
200 /// has progressed to that point.
client_cert_requested(&self) -> bool201 pub fn client_cert_requested(&self) -> bool {
202 self.error.code() == errSSLClientCertRequested
203 }
204
205 /// Returns `true` iff the underlying stream returned an error with the
206 /// `WouldBlock` kind.
would_block(&self) -> bool207 pub fn would_block(&self) -> bool {
208 self.error.code() == errSSLWouldBlock
209 }
210
211 /// Deprecated
reason(&self) -> OSStatus212 pub fn reason(&self) -> OSStatus {
213 self.error.code()
214 }
215
216 /// Returns the error which caused the handshake interruption.
error(&self) -> &Error217 pub fn error(&self) -> &Error {
218 &self.error
219 }
220
221 /// Restarts the handshake process.
handshake(self) -> result::Result<SslStream<S>, HandshakeError<S>>222 pub fn handshake(self) -> result::Result<SslStream<S>, HandshakeError<S>> {
223 self.stream.handshake()
224 }
225 }
226
227 /// An SSL stream midway through the handshake process.
228 #[derive(Debug)]
229 pub struct MidHandshakeClientBuilder<S> {
230 stream: MidHandshakeSslStream<S>,
231 domain: Option<String>,
232 certs: Vec<SecCertificate>,
233 trust_certs_only: bool,
234 danger_accept_invalid_certs: bool,
235 }
236
237 impl<S> MidHandshakeClientBuilder<S> {
238 /// Returns a shared reference to the inner stream.
get_ref(&self) -> &S239 pub fn get_ref(&self) -> &S {
240 self.stream.get_ref()
241 }
242
243 /// Returns a mutable reference to the inner stream.
get_mut(&mut self) -> &mut S244 pub fn get_mut(&mut self) -> &mut S {
245 self.stream.get_mut()
246 }
247
248 /// Returns the error which caused the handshake interruption.
error(&self) -> &Error249 pub fn error(&self) -> &Error {
250 self.stream.error()
251 }
252
253 /// Restarts the handshake process.
handshake(self) -> result::Result<SslStream<S>, ClientHandshakeError<S>>254 pub fn handshake(self) -> result::Result<SslStream<S>, ClientHandshakeError<S>> {
255 let MidHandshakeClientBuilder {
256 stream,
257 domain,
258 certs,
259 trust_certs_only,
260 danger_accept_invalid_certs,
261 } = self;
262
263 let mut result = stream.handshake();
264 loop {
265 let stream = match result {
266 Ok(stream) => return Ok(stream),
267 Err(HandshakeError::Interrupted(stream)) => stream,
268 Err(HandshakeError::Failure(err)) => {
269 return Err(ClientHandshakeError::Failure(err))
270 }
271 };
272
273 if stream.would_block() {
274 let ret = MidHandshakeClientBuilder {
275 stream,
276 domain,
277 certs,
278 trust_certs_only,
279 danger_accept_invalid_certs,
280 };
281 return Err(ClientHandshakeError::Interrupted(ret));
282 }
283
284 if stream.server_auth_completed() {
285 if danger_accept_invalid_certs {
286 result = stream.handshake();
287 continue;
288 }
289 let mut trust = match stream.context().peer_trust2()? {
290 Some(trust) => trust,
291 None => {
292 result = stream.handshake();
293 continue;
294 }
295 };
296 trust.set_anchor_certificates(&certs)?;
297 trust.set_trust_anchor_certificates_only(self.trust_certs_only)?;
298 let policy =
299 SecPolicy::create_ssl(SslProtocolSide::SERVER, domain.as_ref().map(|s| &**s));
300 trust.set_policy(&policy)?;
301 let trusted = trust.evaluate()?;
302 match trusted {
303 TrustResult::PROCEED | TrustResult::UNSPECIFIED => {
304 result = stream.handshake();
305 continue;
306 }
307 TrustResult::DENY => {
308 let err = Error::from_code(errSecTrustSettingDeny);
309 return Err(ClientHandshakeError::Failure(err));
310 }
311 TrustResult::RECOVERABLE_TRUST_FAILURE | TrustResult::FATAL_TRUST_FAILURE => {
312 let err = Error::from_code(errSecNotTrusted);
313 return Err(ClientHandshakeError::Failure(err));
314 }
315 TrustResult::INVALID | TrustResult::OTHER_ERROR | _ => {
316 let err = Error::from_code(errSecBadReq);
317 return Err(ClientHandshakeError::Failure(err));
318 }
319 }
320 }
321
322 let err = Error::from_code(stream.reason());
323 return Err(ClientHandshakeError::Failure(err));
324 }
325 }
326 }
327
328 /// Specifies the state of a TLS session.
329 #[derive(Debug, PartialEq, Eq)]
330 pub struct SessionState(SSLSessionState);
331
332 impl SessionState {
333 /// The session has not yet started.
334 pub const IDLE: Self = Self(kSSLIdle);
335
336 /// The session is in the handshake process.
337 pub const HANDSHAKE: Self = Self(kSSLHandshake);
338
339 /// The session is connected.
340 pub const CONNECTED: Self = Self(kSSLConnected);
341
342 /// The session has been terminated.
343 pub const CLOSED: Self = Self(kSSLClosed);
344
345 /// The session has been aborted due to an error.
346 pub const ABORTED: Self = Self(kSSLAborted);
347 }
348
349 /// Specifies a server's requirement for client certificates.
350 #[derive(Debug, Copy, Clone, PartialEq, Eq)]
351 pub struct SslAuthenticate(SSLAuthenticate);
352
353 impl SslAuthenticate {
354 /// Do not request a client certificate.
355 pub const NEVER: Self = Self(kNeverAuthenticate);
356
357 /// Require a client certificate.
358 pub const ALWAYS: Self = Self(kAlwaysAuthenticate);
359
360 /// Request but do not require a client certificate.
361 pub const TRY: Self = Self(kTryAuthenticate);
362 }
363
364 /// Specifies the state of client certificate processing.
365 #[derive(Debug, Copy, Clone, PartialEq, Eq)]
366 pub struct SslClientCertificateState(SSLClientCertificateState);
367
368 impl SslClientCertificateState {
369 /// A client certificate has not been requested or sent.
370 pub const NONE: Self = Self(kSSLClientCertNone);
371
372 /// A client certificate has been requested but not recieved.
373 pub const REQUESTED: Self = Self(kSSLClientCertRequested);
374 /// A client certificate has been received and successfully validated.
375 pub const SENT: Self = Self(kSSLClientCertSent);
376
377 /// A client certificate has been received but has failed to validate.
378 pub const REJECTED: Self = Self(kSSLClientCertRejected); }
379
380 /// Specifies protocol versions.
381 #[derive(Debug, Copy, Clone, PartialEq, Eq)]
382 pub struct SslProtocol(SSLProtocol);
383
384 impl SslProtocol {
385 /// No protocol has been or should be negotiated or specified; use the default.
386 pub const UNKNOWN: Self = Self(kSSLProtocolUnknown);
387
388 /// The SSL 3.0 protocol is preferred, though SSL 2.0 may be used if the peer does not support
389 /// SSL 3.0.
390 pub const SSL3: Self = Self(kSSLProtocol3);
391
392 /// The TLS 1.0 protocol is preferred, though lower versions may be used
393 /// if the peer does not support TLS 1.0.
394 pub const TLS1: Self = Self(kTLSProtocol1);
395
396 /// The TLS 1.1 protocol is preferred, though lower versions may be used
397 /// if the peer does not support TLS 1.1.
398 pub const TLS11: Self = Self(kTLSProtocol11);
399
400 /// The TLS 1.2 protocol is preferred, though lower versions may be used
401 /// if the peer does not support TLS 1.2.
402 pub const TLS12: Self = Self(kTLSProtocol12);
403
404 /// The TLS 1.3 protocol is preferred, though lower versions may be used
405 /// if the peer does not support TLS 1.3.
406 pub const TLS13: Self = Self(kTLSProtocol13);
407
408 /// Only the SSL 2.0 protocol is accepted.
409 pub const SSL2: Self = Self(kSSLProtocol2);
410
411 /// The DTLSv1 protocol is preferred.
412 pub const DTLS1: Self = Self(kDTLSProtocol1);
413
414 /// Only the SSL 3.0 protocol is accepted.
415 pub const SSL3_ONLY: Self = Self(kSSLProtocol3Only);
416
417 /// Only the TLS 1.0 protocol is accepted.
418 pub const TLS1_ONLY: Self = Self(kTLSProtocol1Only);
419
420 /// All supported TLS/SSL versions are accepted.
421 pub const ALL: Self = Self(kSSLProtocolAll);
422 }
423
424 declare_TCFType! {
425 /// A Secure Transport SSL/TLS context object.
426 SslContext, SSLContextRef
427 }
428
429 impl_TCFType!(SslContext, SSLContextRef, SSLContextGetTypeID);
430
431 impl fmt::Debug for SslContext {
fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result432 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
433 let mut builder = fmt.debug_struct("SslContext");
434 if let Ok(state) = self.state() {
435 builder.field("state", &state);
436 }
437 builder.finish()
438 }
439 }
440
441 unsafe impl Sync for SslContext {}
442 unsafe impl Send for SslContext {}
443
444 impl AsInner for SslContext {
445 type Inner = SSLContextRef;
446
as_inner(&self) -> SSLContextRef447 fn as_inner(&self) -> SSLContextRef {
448 self.0
449 }
450 }
451
// Generates a setter/getter pair for each boolean session option listed.
//
// Each entry has the form `const <option constant>: <getter> & <setter>,` and
// may be preceded by attributes (doc comments, `cfg`), which are applied to
// both generated methods.
macro_rules! impl_options {
    ($($(#[$a:meta])* const $opt:ident: $get:ident & $set:ident,)*) => {
        $(
            $(#[$a])*
            pub fn $set(&mut self, value: bool) -> Result<()> {
                let status = unsafe { SSLSetSessionOption(self.0, $opt, value as Boolean) };
                cvt(status)
            }

            $(#[$a])*
            pub fn $get(&self) -> Result<bool> {
                let mut raw = 0;
                let status = unsafe { SSLGetSessionOption(self.0, $opt, &mut raw) };
                cvt(status)?;
                Ok(raw != 0)
            }
        )*
    }
}
469
470 impl SslContext {
471 /// Creates a new `SslContext` for the specified side and type of SSL
472 /// connection.
new(side: SslProtocolSide, type_: SslConnectionType) -> Result<Self>473 pub fn new(side: SslProtocolSide, type_: SslConnectionType) -> Result<Self> {
474 unsafe {
475 let ctx = SSLCreateContext(kCFAllocatorDefault, side.0, type_.0);
476 Ok(Self(ctx))
477 }
478 }
479
480 /// Sets the fully qualified domain name of the peer.
481 ///
482 /// This will be used on the client side of a session to validate the
483 /// common name field of the server's certificate. It has no effect if
484 /// called on a server-side `SslContext`.
485 ///
486 /// It is *highly* recommended to call this method before starting the
487 /// handshake process.
set_peer_domain_name(&mut self, peer_name: &str) -> Result<()>488 pub fn set_peer_domain_name(&mut self, peer_name: &str) -> Result<()> {
489 unsafe {
490 // SSLSetPeerDomainName doesn't need a null terminated string
491 cvt(SSLSetPeerDomainName(
492 self.0,
493 peer_name.as_ptr() as *const _,
494 peer_name.len(),
495 ))
496 }
497 }
498
499 /// Returns the peer domain name set by `set_peer_domain_name`.
peer_domain_name(&self) -> Result<String>500 pub fn peer_domain_name(&self) -> Result<String> {
501 unsafe {
502 let mut len = 0;
503 cvt(SSLGetPeerDomainNameLength(self.0, &mut len))?;
504 let mut buf = vec![0; len];
505 cvt(SSLGetPeerDomainName(
506 self.0,
507 buf.as_mut_ptr() as *mut _,
508 &mut len,
509 ))?;
510 Ok(String::from_utf8(buf).unwrap())
511 }
512 }
513
514 /// Sets the certificate to be used by this side of the SSL session.
515 ///
516 /// This must be called before the handshake for server-side connections,
517 /// and can be used on the client-side to specify a client certificate.
518 ///
519 /// The `identity` corresponds to the leaf certificate and private
520 /// key, and the `certs` correspond to extra certificates in the chain.
set_certificate( &mut self, identity: &SecIdentity, certs: &[SecCertificate], ) -> Result<()>521 pub fn set_certificate(
522 &mut self,
523 identity: &SecIdentity,
524 certs: &[SecCertificate],
525 ) -> Result<()> {
526 let mut arr = vec![identity.as_CFType()];
527 arr.extend(certs.iter().map(|c| c.as_CFType()));
528 let certs = CFArray::from_CFTypes(&arr);
529
530 unsafe { cvt(SSLSetCertificate(self.0, certs.as_concrete_TypeRef())) }
531 }
532
533 /// Sets the peer ID of this session.
534 ///
535 /// A peer ID is an opaque sequence of bytes that will be used by Secure
536 /// Transport to identify the peer of an SSL session. If the peer ID of
537 /// this session matches that of a previously terminated session, the
538 /// previous session can be resumed without requiring a full handshake.
set_peer_id(&mut self, peer_id: &[u8]) -> Result<()>539 pub fn set_peer_id(&mut self, peer_id: &[u8]) -> Result<()> {
540 unsafe {
541 cvt(SSLSetPeerID(
542 self.0,
543 peer_id.as_ptr() as *const _,
544 peer_id.len(),
545 ))
546 }
547 }
548
549 /// Returns the peer ID of this session.
peer_id(&self) -> Result<Option<&[u8]>>550 pub fn peer_id(&self) -> Result<Option<&[u8]>> {
551 unsafe {
552 let mut ptr = ptr::null();
553 let mut len = 0;
554 cvt(SSLGetPeerID(self.0, &mut ptr, &mut len))?;
555 if ptr.is_null() {
556 Ok(None)
557 } else {
558 Ok(Some(slice::from_raw_parts(ptr as *const _, len)))
559 }
560 }
561 }
562
563 /// Returns the list of ciphers that are supported by Secure Transport.
supported_ciphers(&self) -> Result<Vec<CipherSuite>>564 pub fn supported_ciphers(&self) -> Result<Vec<CipherSuite>> {
565 unsafe {
566 let mut num_ciphers = 0;
567 cvt(SSLGetNumberSupportedCiphers(self.0, &mut num_ciphers))?;
568 let mut ciphers = vec![0; num_ciphers];
569 cvt(SSLGetSupportedCiphers(
570 self.0,
571 ciphers.as_mut_ptr(),
572 &mut num_ciphers,
573 ))?;
574 Ok(ciphers.iter().map(|c| CipherSuite::from_raw(*c)).collect())
575 }
576 }
577
578 /// Returns the list of ciphers that are eligible to be used for
579 /// negotiation.
enabled_ciphers(&self) -> Result<Vec<CipherSuite>>580 pub fn enabled_ciphers(&self) -> Result<Vec<CipherSuite>> {
581 unsafe {
582 let mut num_ciphers = 0;
583 cvt(SSLGetNumberEnabledCiphers(self.0, &mut num_ciphers))?;
584 let mut ciphers = vec![0; num_ciphers];
585 cvt(SSLGetEnabledCiphers(
586 self.0,
587 ciphers.as_mut_ptr(),
588 &mut num_ciphers,
589 ))?;
590 Ok(ciphers.iter().map(|c| CipherSuite::from_raw(*c)).collect())
591 }
592 }
593
594 /// Sets the list of ciphers that are eligible to be used for negotiation.
set_enabled_ciphers(&mut self, ciphers: &[CipherSuite]) -> Result<()>595 pub fn set_enabled_ciphers(&mut self, ciphers: &[CipherSuite]) -> Result<()> {
596 let ciphers = ciphers.iter().map(|c| c.to_raw()).collect::<Vec<_>>();
597 unsafe {
598 cvt(SSLSetEnabledCiphers(
599 self.0,
600 ciphers.as_ptr(),
601 ciphers.len(),
602 ))
603 }
604 }
605
606 /// Returns the cipher being used by the session.
negotiated_cipher(&self) -> Result<CipherSuite>607 pub fn negotiated_cipher(&self) -> Result<CipherSuite> {
608 unsafe {
609 let mut cipher = 0;
610 cvt(SSLGetNegotiatedCipher(self.0, &mut cipher))?;
611 Ok(CipherSuite::from_raw(cipher))
612 }
613 }
614
615 /// Sets the requirements for client certificates.
616 ///
617 /// Should only be called on server-side sessions.
set_client_side_authenticate(&mut self, auth: SslAuthenticate) -> Result<()>618 pub fn set_client_side_authenticate(&mut self, auth: SslAuthenticate) -> Result<()> {
619 unsafe { cvt(SSLSetClientSideAuthenticate(self.0, auth.0)) }
620 }
621
622 /// Returns the state of client certificate processing.
client_certificate_state(&self) -> Result<SslClientCertificateState>623 pub fn client_certificate_state(&self) -> Result<SslClientCertificateState> {
624 let mut state = 0;
625
626 unsafe {
627 cvt(SSLGetClientCertificateState(self.0, &mut state))?;
628 }
629 Ok(SslClientCertificateState(state))
630 }
631
632 /// Returns the `SecTrust` object corresponding to the peer.
633 ///
634 /// This can be used in conjunction with `set_break_on_server_auth` to
635 /// validate certificates which do not have roots in the default set.
peer_trust2(&self) -> Result<Option<SecTrust>>636 pub fn peer_trust2(&self) -> Result<Option<SecTrust>> {
637 // Calling SSLCopyPeerTrust on an idle connection does not seem to be well defined,
638 // so explicitly check for that
639 if self.state()? == SessionState::IDLE {
640 return Err(Error::from_code(errSecBadReq));
641 }
642
643 unsafe {
644 let mut trust = ptr::null_mut();
645 cvt(SSLCopyPeerTrust(self.0, &mut trust))?;
646 if trust.is_null() {
647 Ok(None)
648 } else {
649 Ok(Some(SecTrust::wrap_under_create_rule(trust)))
650 }
651 }
652 }
653
654 #[allow(missing_docs)]
655 #[deprecated(since = "0.2.1", note = "use peer_trust2 instead")]
peer_trust(&self) -> Result<SecTrust>656 pub fn peer_trust(&self) -> Result<SecTrust> {
657 match self.peer_trust2() {
658 Ok(Some(trust)) => Ok(trust),
659 Ok(None) => panic!("no trust available"),
660 Err(e) => Err(e),
661 }
662 }
663
664 /// Returns the state of the session.
state(&self) -> Result<SessionState>665 pub fn state(&self) -> Result<SessionState> {
666 unsafe {
667 let mut state = 0;
668 cvt(SSLGetSessionState(self.0, &mut state))?;
669 Ok(SessionState(state))
670 }
671 }
672
673 /// Returns the protocol version being used by the session.
negotiated_protocol_version(&self) -> Result<SslProtocol>674 pub fn negotiated_protocol_version(&self) -> Result<SslProtocol> {
675 unsafe {
676 let mut version = 0;
677 cvt(SSLGetNegotiatedProtocolVersion(self.0, &mut version))?;
678 Ok(SslProtocol(version))
679 }
680 }
681
682 /// Returns the maximum protocol version allowed by the session.
protocol_version_max(&self) -> Result<SslProtocol>683 pub fn protocol_version_max(&self) -> Result<SslProtocol> {
684 unsafe {
685 let mut version = 0;
686 cvt(SSLGetProtocolVersionMax(self.0, &mut version))?;
687 Ok(SslProtocol(version))
688 }
689 }
690
691 /// Sets the maximum protocol version allowed by the session.
set_protocol_version_max(&mut self, max_version: SslProtocol) -> Result<()>692 pub fn set_protocol_version_max(&mut self, max_version: SslProtocol) -> Result<()> {
693 unsafe { cvt(SSLSetProtocolVersionMax(self.0, max_version.0)) }
694 }
695
696 /// Returns the minimum protocol version allowed by the session.
protocol_version_min(&self) -> Result<SslProtocol>697 pub fn protocol_version_min(&self) -> Result<SslProtocol> {
698 unsafe {
699 let mut version = 0;
700 cvt(SSLGetProtocolVersionMin(self.0, &mut version))?;
701 Ok(SslProtocol(version))
702 }
703 }
704
705 /// Sets the minimum protocol version allowed by the session.
set_protocol_version_min(&mut self, min_version: SslProtocol) -> Result<()>706 pub fn set_protocol_version_min(&mut self, min_version: SslProtocol) -> Result<()> {
707 unsafe { cvt(SSLSetProtocolVersionMin(self.0, min_version.0)) }
708 }
709
710 /// Returns the set of protocols selected via ALPN if it succeeded.
711 #[cfg(feature = "alpn")]
alpn_protocols(&self) -> Result<Vec<String>>712 pub fn alpn_protocols(&self) -> Result<Vec<String>> {
713 let mut array: CFArrayRef = ptr::null();
714 unsafe {
715 #[cfg(feature = "OSX_10_13")]
716 {
717 cvt(SSLCopyALPNProtocols(self.0, &mut array))?;
718 }
719
720 #[cfg(not(feature = "OSX_10_13"))]
721 {
722 dlsym! { fn SSLCopyALPNProtocols(SSLContextRef, *mut CFArrayRef) -> OSStatus }
723 if let Some(f) = SSLCopyALPNProtocols.get() {
724 cvt(f(self.0, &mut array))?;
725 } else {
726 return Err(Error::from_code(errSecUnimplemented));
727 }
728 }
729
730 if array.is_null() {
731 return Ok(vec![]);
732 }
733
734 let array = CFArray::<CFString>::wrap_under_create_rule(array);
735 Ok(array.into_iter().map(|p| p.to_string()).collect())
736 }
737 }
738
739 /// Configures the set of protocols use for ALPN.
740 ///
741 /// This is only used for client-side connections.
742 #[cfg(feature = "alpn")]
set_alpn_protocols(&mut self, protocols: &[&str]) -> Result<()>743 pub fn set_alpn_protocols(&mut self, protocols: &[&str]) -> Result<()> {
744 // When CFMutableArray is added to core-foundation and IntoIterator trait
745 // is implemented for CFMutableArray, the code below should directly collect
746 // into a CFMutableArray.
747 let protocols = CFArray::from_CFTypes(
748 &protocols
749 .iter()
750 .map(|proto| CFString::new(proto))
751 .collect::<Vec<_>>(),
752 );
753
754 #[cfg(feature = "OSX_10_13")]
755 {
756 unsafe { cvt(SSLSetALPNProtocols(self.0, protocols.as_concrete_TypeRef())) }
757 }
758 #[cfg(not(feature = "OSX_10_13"))]
759 {
760 dlsym! { fn SSLSetALPNProtocols(SSLContextRef, CFArrayRef) -> OSStatus }
761 if let Some(f) = SSLSetALPNProtocols.get() {
762 unsafe { cvt(f(self.0, protocols.as_concrete_TypeRef())) }
763 } else {
764 Err(Error::from_code(errSecUnimplemented))
765 }
766 }
767 }
768
769 /// Sets whether the client sends the `SessionTicket` extension in its `ClientHello`.
770 ///
771 /// On its own, this will just cause the client to send an empty `SessionTicket` extension on
772 /// every connection. [`SslContext::set_peer_id`] must also be used to key the session
773 /// ticket returned by the server.
774 ///
775 /// [`SslContext::set_peer_id`]: #method.set_peer_id
776 #[cfg(feature = "session-tickets")]
set_session_tickets_enabled(&mut self, enabled: bool) -> Result<()>777 pub fn set_session_tickets_enabled(&mut self, enabled: bool) -> Result<()> {
778 #[cfg(feature = "OSX_10_13")]
779 {
780 unsafe { cvt(SSLSetSessionTicketsEnabled(self.0, enabled as Boolean)) }
781 }
782 #[cfg(not(feature = "OSX_10_13"))]
783 {
784 dlsym! { fn SSLSetSessionTicketsEnabled(SSLContextRef, Boolean) -> OSStatus }
785 if let Some(f) = SSLSetSessionTicketsEnabled.get() {
786 unsafe { cvt(f(self.0, enabled as Boolean)) }
787 } else {
788 Err(Error::from_code(errSecUnimplemented))
789 }
790 }
791 }
792
793 /// Sets whether a protocol is enabled or not.
794 ///
795 /// # Note
796 ///
797 /// On OSX this is a deprecated API in favor of `set_protocol_version_max` and
798 /// `set_protocol_version_min`, although if you're working with OSX 10.8 or before you may have
799 /// to use this API instead.
800 #[cfg(target_os = "macos")]
set_protocol_version_enabled( &mut self, protocol: SslProtocol, enabled: bool, ) -> Result<()>801 pub fn set_protocol_version_enabled(
802 &mut self,
803 protocol: SslProtocol,
804 enabled: bool,
805 ) -> Result<()> {
806 unsafe {
807 cvt(SSLSetProtocolVersionEnabled(
808 self.0,
809 protocol.0,
810 enabled as Boolean,
811 ))
812 }
813 }
814
815 /// Returns the number of bytes which can be read without triggering a
816 /// `read` call in the underlying stream.
buffered_read_size(&self) -> Result<usize>817 pub fn buffered_read_size(&self) -> Result<usize> {
818 unsafe {
819 let mut size = 0;
820 cvt(SSLGetBufferedReadSize(self.0, &mut size))?;
821 Ok(size)
822 }
823 }
824
825 impl_options! {
826 /// If enabled, the handshake process will pause and return instead of
827 /// automatically validating a server's certificate.
828 const kSSLSessionOptionBreakOnServerAuth: break_on_server_auth & set_break_on_server_auth,
829 /// If enabled, the handshake process will pause and return after
830 /// the server requests a certificate from the client.
831 const kSSLSessionOptionBreakOnCertRequested: break_on_cert_requested & set_break_on_cert_requested,
832 /// If enabled, the handshake process will pause and return instead of
833 /// automatically validating a client's certificate.
834 const kSSLSessionOptionBreakOnClientAuth: break_on_client_auth & set_break_on_client_auth,
835 /// If enabled, TLS false start will be performed if an appropriate
836 /// cipher suite is negotiated.
837 ///
838 /// Requires the `OSX_10_9` (or greater) feature.
839 #[cfg(feature = "OSX_10_9")]
840 const kSSLSessionOptionFalseStart: false_start & set_false_start,
841 /// If enabled, 1/n-1 record splitting will be enabled for TLS 1.0
842 /// connections using block ciphers to mitigate the BEAST attack.
843 ///
844 /// Requires the `OSX_10_9` (or greater) feature.
845 #[cfg(feature = "OSX_10_9")]
846 const kSSLSessionOptionSendOneByteRecord: send_one_byte_record & set_send_one_byte_record,
847 }
848
into_stream<S>(self, stream: S) -> Result<SslStream<S>> where S: Read + Write,849 fn into_stream<S>(self, stream: S) -> Result<SslStream<S>>
850 where
851 S: Read + Write,
852 {
853 unsafe {
854 let ret = SSLSetIOFuncs(self.0, read_func::<S>, write_func::<S>);
855 if ret != errSecSuccess {
856 return Err(Error::from_code(ret));
857 }
858
859 let stream = Connection {
860 stream,
861 err: None,
862 panic: None,
863 };
864 let stream = Box::into_raw(Box::new(stream));
865 let ret = SSLSetConnection(self.0, stream as *mut _);
866 if ret != errSecSuccess {
867 let _conn = Box::from_raw(stream);
868 return Err(Error::from_code(ret));
869 }
870
871 Ok(SslStream {
872 ctx: self,
873 _m: PhantomData,
874 })
875 }
876 }
877
878 /// Performs the SSL/TLS handshake.
handshake<S>(self, stream: S) -> result::Result<SslStream<S>, HandshakeError<S>> where S: Read + Write,879 pub fn handshake<S>(self, stream: S) -> result::Result<SslStream<S>, HandshakeError<S>>
880 where
881 S: Read + Write,
882 {
883 self.into_stream(stream)
884 .map_err(HandshakeError::Failure)
885 .and_then(SslStream::handshake)
886 }
887 }
888
/// State shared with the Secure Transport I/O callbacks.
struct Connection<S> {
    /// The wrapped stream the callbacks read from and write to.
    stream: S,
    /// The most recent I/O error recorded by a callback, if any.
    err: Option<io::Error>,
    /// The payload of a panic caught inside a callback, if any.
    panic: Option<Box<dyn Any + Send>>,
}
894
895 // the logic here is based off of libcurl's
896
translate_err(e: &io::Error) -> OSStatus897 fn translate_err(e: &io::Error) -> OSStatus {
898 match e.kind() {
899 io::ErrorKind::NotFound => errSSLClosedGraceful,
900 io::ErrorKind::ConnectionReset => errSSLClosedAbort,
901 io::ErrorKind::WouldBlock |
902 io::ErrorKind::NotConnected => errSSLWouldBlock,
903 _ => errSecIO,
904 }
905 }
906
read_func<S>( connection: SSLConnectionRef, data: *mut c_void, data_length: *mut usize, ) -> OSStatus where S: Read,907 unsafe extern "C" fn read_func<S>(
908 connection: SSLConnectionRef,
909 data: *mut c_void,
910 data_length: *mut usize,
911 ) -> OSStatus
912 where
913 S: Read,
914 {
915 let conn: &mut Connection<S> = &mut *(connection as *mut _);
916 let data = slice::from_raw_parts_mut(data as *mut u8, *data_length);
917 let mut start = 0;
918 let mut ret = errSecSuccess;
919
920 while start < data.len() {
921 match panic::catch_unwind(AssertUnwindSafe(|| conn.stream.read(&mut data[start..]))) {
922 Ok(Ok(0)) => {
923 ret = errSSLClosedNoNotify;
924 break;
925 }
926 Ok(Ok(len)) => start += len,
927 Ok(Err(e)) => {
928 ret = translate_err(&e);
929 conn.err = Some(e);
930 break;
931 }
932 Err(e) => {
933 ret = errSecIO;
934 conn.panic = Some(e);
935 break;
936 }
937 }
938 }
939
940 *data_length = start;
941 ret
942 }
943
write_func<S>( connection: SSLConnectionRef, data: *const c_void, data_length: *mut usize, ) -> OSStatus where S: Write,944 unsafe extern "C" fn write_func<S>(
945 connection: SSLConnectionRef,
946 data: *const c_void,
947 data_length: *mut usize,
948 ) -> OSStatus
949 where
950 S: Write,
951 {
952 let conn: &mut Connection<S> = &mut *(connection as *mut _);
953 let data = slice::from_raw_parts(data as *mut u8, *data_length);
954 let mut start = 0;
955 let mut ret = errSecSuccess;
956
957 while start < data.len() {
958 match panic::catch_unwind(AssertUnwindSafe(|| conn.stream.write(&data[start..]))) {
959 Ok(Ok(0)) => {
960 ret = errSSLClosedNoNotify;
961 break;
962 }
963 Ok(Ok(len)) => start += len,
964 Ok(Err(e)) => {
965 ret = translate_err(&e);
966 conn.err = Some(e);
967 break;
968 }
969 Err(e) => {
970 ret = errSecIO;
971 conn.panic = Some(e);
972 break;
973 }
974 }
975 }
976
977 *data_length = start;
978 ret
979 }
980
/// A type implementing SSL/TLS encryption over an underlying stream.
pub struct SslStream<S> {
    // The Secure Transport context driving this session.
    ctx: SslContext,
    // The stream itself is not stored here; it is reached through the
    // connection pointer obtained from `SSLGetConnection` on `ctx`. This
    // marker only records the stream type `S`.
    _m: PhantomData<S>,
}
986
987 impl<S: fmt::Debug> fmt::Debug for SslStream<S> {
fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result988 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
989 fmt.debug_struct("SslStream")
990 .field("context", &self.ctx)
991 .field("stream", self.get_ref())
992 .finish()
993 }
994 }
995
996 impl<S> Drop for SslStream<S> {
drop(&mut self)997 fn drop(&mut self) {
998 unsafe {
999 let mut conn = ptr::null();
1000 let ret = SSLGetConnection(self.ctx.0, &mut conn);
1001 assert!(ret == errSecSuccess);
1002 Box::<Connection<S>>::from_raw(conn as *mut _);
1003 }
1004 }
1005 }
1006
1007 impl<S> SslStream<S> {
handshake(mut self) -> result::Result<Self, HandshakeError<S>>1008 fn handshake(mut self) -> result::Result<Self, HandshakeError<S>> {
1009 match unsafe { SSLHandshake(self.ctx.0) } {
1010 errSecSuccess => Ok(self),
1011 reason @ errSSLPeerAuthCompleted
1012 | reason @ errSSLClientCertRequested
1013 | reason @ errSSLWouldBlock
1014 | reason @ errSSLClientHelloReceived => {
1015 Err(HandshakeError::Interrupted(MidHandshakeSslStream {
1016 stream: self,
1017 error: Error::from_code(reason),
1018 }))
1019 }
1020 err => {
1021 self.check_panic();
1022 Err(HandshakeError::Failure(Error::from_code(err)))
1023 }
1024 }
1025 }
1026
1027 /// Returns a shared reference to the inner stream.
get_ref(&self) -> &S1028 pub fn get_ref(&self) -> &S {
1029 &self.connection().stream
1030 }
1031
1032 /// Returns a mutable reference to the underlying stream.
get_mut(&mut self) -> &mut S1033 pub fn get_mut(&mut self) -> &mut S {
1034 &mut self.connection_mut().stream
1035 }
1036
1037 /// Returns a shared reference to the `SslContext` of the stream.
context(&self) -> &SslContext1038 pub fn context(&self) -> &SslContext {
1039 &self.ctx
1040 }
1041
1042 /// Returns a mutable reference to the `SslContext` of the stream.
context_mut(&mut self) -> &mut SslContext1043 pub fn context_mut(&mut self) -> &mut SslContext {
1044 &mut self.ctx
1045 }
1046
1047 /// Shuts down the connection.
close(&mut self) -> result::Result<(), io::Error>1048 pub fn close(&mut self) -> result::Result<(), io::Error> {
1049 unsafe {
1050 let ret = SSLClose(self.ctx.0);
1051 if ret == errSecSuccess {
1052 Ok(())
1053 } else {
1054 Err(self.get_error(ret))
1055 }
1056 }
1057 }
1058
connection(&self) -> &Connection<S>1059 fn connection(&self) -> &Connection<S> {
1060 unsafe {
1061 let mut conn = ptr::null();
1062 let ret = SSLGetConnection(self.ctx.0, &mut conn);
1063 assert!(ret == errSecSuccess);
1064
1065 mem::transmute(conn)
1066 }
1067 }
1068
connection_mut(&mut self) -> &mut Connection<S>1069 fn connection_mut(&mut self) -> &mut Connection<S> {
1070 unsafe {
1071 let mut conn = ptr::null();
1072 let ret = SSLGetConnection(self.ctx.0, &mut conn);
1073 assert!(ret == errSecSuccess);
1074
1075 mem::transmute(conn)
1076 }
1077 }
1078
check_panic(&mut self)1079 fn check_panic(&mut self) {
1080 let conn = self.connection_mut();
1081 if let Some(err) = conn.panic.take() {
1082 panic::resume_unwind(err);
1083 }
1084 }
1085
get_error(&mut self, ret: OSStatus) -> io::Error1086 fn get_error(&mut self, ret: OSStatus) -> io::Error {
1087 self.check_panic();
1088
1089 if let Some(err) = self.connection_mut().err.take() {
1090 err
1091 } else {
1092 io::Error::new(io::ErrorKind::Other, Error::from_code(ret))
1093 }
1094 }
1095 }
1096
1097 impl<S: Read + Write> Read for SslStream<S> {
read(&mut self, buf: &mut [u8]) -> io::Result<usize>1098 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
1099 // Below we base our return value off the amount of data read, so a
1100 // zero-length buffer might cause us to erroneously interpret this
1101 // request as an error. Instead short-circuit that logic and return
1102 // `Ok(0)` instead.
1103 if buf.is_empty() {
1104 return Ok(0);
1105 }
1106
1107 // If some data was buffered but not enough to fill `buf`, SSLRead
1108 // will try to read a new packet. This is bad because there may be
1109 // no more data but the socket is remaining open (e.g HTTPS with
1110 // Connection: keep-alive).
1111 let buffered = self.context().buffered_read_size().unwrap_or(0);
1112 let to_read = if buffered > 0 {
1113 cmp::min(buffered, buf.len())
1114 } else {
1115 buf.len()
1116 };
1117
1118 unsafe {
1119 let mut nread = 0;
1120 let ret = SSLRead(self.ctx.0, buf.as_mut_ptr() as *mut _, to_read, &mut nread);
1121 // SSLRead can return an error at the same time it returns the last
1122 // chunk of data (!)
1123 if nread > 0 {
1124 return Ok(nread as usize);
1125 }
1126
1127 match ret {
1128 errSSLClosedGraceful | errSSLClosedAbort | errSSLClosedNoNotify => Ok(0),
1129 _ => Err(self.get_error(ret)),
1130 }
1131 }
1132 }
1133 }
1134
1135 impl<S: Read + Write> Write for SslStream<S> {
write(&mut self, buf: &[u8]) -> io::Result<usize>1136 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
1137 // Like above in read, short circuit a 0-length write
1138 if buf.is_empty() {
1139 return Ok(0);
1140 }
1141 unsafe {
1142 let mut nwritten = 0;
1143 let ret = SSLWrite(
1144 self.ctx.0,
1145 buf.as_ptr() as *const _,
1146 buf.len(),
1147 &mut nwritten,
1148 );
1149 // just to be safe, base success off of nwritten rather than ret
1150 // for the same reason as in read
1151 if nwritten > 0 {
1152 Ok(nwritten as usize)
1153 } else {
1154 Err(self.get_error(ret))
1155 }
1156 }
1157 }
1158
flush(&mut self) -> io::Result<()>1159 fn flush(&mut self) -> io::Result<()> {
1160 self.connection_mut().stream.flush()
1161 }
1162 }
1163
/// A builder type to simplify the creation of client side `SslStream`s.
#[derive(Debug)]
pub struct ClientBuilder {
    // Client identity to present, set together with `chain` by `identity()`.
    identity: Option<SecIdentity>,
    // Anchor (root) certificates set by `anchor_certificates()`.
    certs: Vec<SecCertificate>,
    // Certificate chain accompanying `identity`.
    chain: Vec<SecCertificate>,
    // Optional lower bound on the protocol version.
    protocol_min: Option<SslProtocol>,
    // Optional upper bound on the protocol version.
    protocol_max: Option<SslProtocol>,
    // Set by `trust_anchor_certificates_only()`; forwarded to the
    // mid-handshake builder.
    trust_certs_only: bool,
    // Whether the peer domain name is set on the context (defaults to `true`).
    use_sni: bool,
    // Forwarded to the mid-handshake builder, which holds the trust logic.
    danger_accept_invalid_certs: bool,
    // When set, no domain is passed on to the mid-handshake builder.
    danger_accept_invalid_hostnames: bool,
    // If non-empty, the starting set of enabled ciphers.
    whitelisted_ciphers: Vec<CipherSuite>,
    // Ciphers removed from the enabled set.
    blacklisted_ciphers: Vec<CipherSuite>,
    // ALPN protocols; only settable with the `alpn` feature.
    alpn: Option<Vec<String>>,
    // Session ticket support; only settable with the `session-tickets` feature.
    enable_session_tickets: bool,
}
1181
1182 impl Default for ClientBuilder {
default() -> Self1183 fn default() -> Self {
1184 Self::new()
1185 }
1186 }
1187
1188 impl ClientBuilder {
1189 /// Creates a new builder with default options.
new() -> Self1190 pub fn new() -> Self {
1191 Self {
1192 identity: None,
1193 certs: Vec::new(),
1194 chain: Vec::new(),
1195 protocol_min: None,
1196 protocol_max: None,
1197 trust_certs_only: false,
1198 use_sni: true,
1199 danger_accept_invalid_certs: false,
1200 danger_accept_invalid_hostnames: false,
1201 whitelisted_ciphers: Vec::new(),
1202 blacklisted_ciphers: Vec::new(),
1203 alpn: None,
1204 enable_session_tickets: false,
1205 }
1206 }
1207
1208 /// Specifies the set of root certificates to trust when
1209 /// verifying the server's certificate.
anchor_certificates(&mut self, certs: &[SecCertificate]) -> &mut Self1210 pub fn anchor_certificates(&mut self, certs: &[SecCertificate]) -> &mut Self {
1211 self.certs = certs.to_owned();
1212 self
1213 }
1214
1215 /// Specifies whether to trust the built-in certificates in addition
1216 /// to specified anchor certificates.
trust_anchor_certificates_only(&mut self, only: bool) -> &mut Self1217 pub fn trust_anchor_certificates_only(&mut self, only: bool) -> &mut Self {
1218 self.trust_certs_only = only;
1219 self
1220 }
1221
1222 /// Specifies whether to trust invalid certificates.
1223 ///
1224 /// # Warning
1225 ///
1226 /// You should think very carefully before using this method. If invalid
1227 /// certificates are trusted, *any* certificate for *any* site will be
1228 /// trusted for use. This includes expired certificates. This introduces
1229 /// significant vulnerabilities, and should only be used as a last resort.
danger_accept_invalid_certs(&mut self, noverify: bool) -> &mut Self1230 pub fn danger_accept_invalid_certs(&mut self, noverify: bool) -> &mut Self {
1231 self.danger_accept_invalid_certs = noverify;
1232 self
1233 }
1234
1235 /// Specifies whether to use Server Name Indication (SNI).
use_sni(&mut self, use_sni: bool) -> &mut Self1236 pub fn use_sni(&mut self, use_sni: bool) -> &mut Self {
1237 self.use_sni = use_sni;
1238 self
1239 }
1240
1241 /// Specifies whether to verify that the server's hostname matches its certificate.
1242 ///
1243 /// # Warning
1244 ///
1245 /// You should think very carefully before using this method. If hostnames are not verified,
1246 /// *any* valid certificate for *any* site will be trusted for use. This introduces significant
1247 /// vulnerabilities, and should only be used as a last resort.
danger_accept_invalid_hostnames( &mut self, danger_accept_invalid_hostnames: bool, ) -> &mut Self1248 pub fn danger_accept_invalid_hostnames(
1249 &mut self,
1250 danger_accept_invalid_hostnames: bool,
1251 ) -> &mut Self {
1252 self.danger_accept_invalid_hostnames = danger_accept_invalid_hostnames;
1253 self
1254 }
1255
1256 /// Set a whitelist of enabled ciphers. Any ciphers not whitelisted will be disabled.
whitelist_ciphers(&mut self, whitelisted_ciphers: &[CipherSuite]) -> &mut Self1257 pub fn whitelist_ciphers(&mut self, whitelisted_ciphers: &[CipherSuite]) -> &mut Self {
1258 self.whitelisted_ciphers = whitelisted_ciphers.to_owned();
1259 self
1260 }
1261
1262 /// Set a blacklist of disabled ciphers. Blacklisted ciphers will be disabled.
blacklist_ciphers(&mut self, blacklisted_ciphers: &[CipherSuite]) -> &mut Self1263 pub fn blacklist_ciphers(&mut self, blacklisted_ciphers: &[CipherSuite]) -> &mut Self {
1264 self.blacklisted_ciphers = blacklisted_ciphers.to_owned();
1265 self
1266 }
1267
1268 /// Use the specified identity as a SSL/TLS client certificate.
identity(&mut self, identity: &SecIdentity, chain: &[SecCertificate]) -> &mut Self1269 pub fn identity(&mut self, identity: &SecIdentity, chain: &[SecCertificate]) -> &mut Self {
1270 self.identity = Some(identity.clone());
1271 self.chain = chain.to_owned();
1272 self
1273 }
1274
1275 /// Configure the minimum protocol that this client will support.
protocol_min(&mut self, min: SslProtocol) -> &mut Self1276 pub fn protocol_min(&mut self, min: SslProtocol) -> &mut Self {
1277 self.protocol_min = Some(min);
1278 self
1279 }
1280
1281 /// Configure the minimum protocol that this client will support.
protocol_max(&mut self, max: SslProtocol) -> &mut Self1282 pub fn protocol_max(&mut self, max: SslProtocol) -> &mut Self {
1283 self.protocol_max = Some(max);
1284 self
1285 }
1286
1287 /// Configures the set of protocols used for ALPN.
1288 #[cfg(feature = "alpn")]
alpn_protocols(&mut self, protocols: &[&str]) -> &mut Self1289 pub fn alpn_protocols(&mut self, protocols: &[&str]) -> &mut Self {
1290 self.alpn = Some(protocols.iter().map(|s| s.to_string()).collect());
1291 self
1292 }
1293
1294 /// Configures the use of the RFC 5077 `SessionTicket` extension.
1295 ///
1296 /// Defaults to `false`.
1297 #[cfg(feature = "session-tickets")]
enable_session_tickets(&mut self, enable: bool) -> &mut Self1298 pub fn enable_session_tickets(&mut self, enable: bool) -> &mut Self {
1299 self.enable_session_tickets = enable;
1300 self
1301 }
1302
1303 /// Initiates a new SSL/TLS session over a stream connected to the specified domain.
1304 ///
1305 /// If both SNI and hostname verification are disabled, the value of `domain` will be ignored.
handshake<S>( &self, domain: &str, stream: S, ) -> result::Result<SslStream<S>, ClientHandshakeError<S>> where S: Read + Write,1306 pub fn handshake<S>(
1307 &self,
1308 domain: &str,
1309 stream: S,
1310 ) -> result::Result<SslStream<S>, ClientHandshakeError<S>>
1311 where
1312 S: Read + Write,
1313 {
1314 // the logic for trust validation is in MidHandshakeClientBuilder::connect, so run all
1315 // of the handshake logic through that.
1316 let stream = MidHandshakeSslStream {
1317 stream: self.ctx_into_stream(domain, stream)?,
1318 error: Error::from(errSecSuccess),
1319 };
1320
1321 let certs = self.certs.clone();
1322 let stream = MidHandshakeClientBuilder {
1323 stream,
1324 domain: if self.danger_accept_invalid_hostnames {
1325 None
1326 } else {
1327 Some(domain.to_string())
1328 },
1329 certs,
1330 trust_certs_only: self.trust_certs_only,
1331 danger_accept_invalid_certs: self.danger_accept_invalid_certs,
1332 };
1333 stream.handshake()
1334 }
1335
ctx_into_stream<S>(&self, domain: &str, stream: S) -> Result<SslStream<S>> where S: Read + Write,1336 fn ctx_into_stream<S>(&self, domain: &str, stream: S) -> Result<SslStream<S>>
1337 where
1338 S: Read + Write,
1339 {
1340 let mut ctx = SslContext::new(SslProtocolSide::CLIENT, SslConnectionType::STREAM)?;
1341
1342 if self.use_sni {
1343 ctx.set_peer_domain_name(domain)?;
1344 }
1345 if let Some(ref identity) = self.identity {
1346 ctx.set_certificate(identity, &self.chain)?;
1347 }
1348 #[cfg(feature = "alpn")]
1349 {
1350 if let Some(ref alpn) = self.alpn {
1351 ctx.set_alpn_protocols(&alpn.iter().map(|s| &**s).collect::<Vec<_>>())?;
1352 }
1353 }
1354 #[cfg(feature = "session-tickets")]
1355 {
1356 if self.enable_session_tickets {
1357 // We must use the domain here to ensure that we go through certificate validation
1358 // again rather than resuming the session if the domain changes.
1359 ctx.set_peer_id(domain.as_bytes())?;
1360 ctx.set_session_tickets_enabled(true)?;
1361 }
1362 }
1363 ctx.set_break_on_server_auth(true)?;
1364 self.configure_protocols(&mut ctx)?;
1365 self.configure_ciphers(&mut ctx)?;
1366
1367 ctx.into_stream(stream)
1368 }
1369
configure_protocols(&self, ctx: &mut SslContext) -> Result<()>1370 fn configure_protocols(&self, ctx: &mut SslContext) -> Result<()> {
1371 if let Some(min) = self.protocol_min {
1372 ctx.set_protocol_version_min(min)?;
1373 }
1374 if let Some(max) = self.protocol_max {
1375 ctx.set_protocol_version_max(max)?;
1376 }
1377 Ok(())
1378 }
1379
configure_ciphers(&self, ctx: &mut SslContext) -> Result<()>1380 fn configure_ciphers(&self, ctx: &mut SslContext) -> Result<()> {
1381 let mut ciphers = if self.whitelisted_ciphers.is_empty() {
1382 ctx.enabled_ciphers()?
1383 } else {
1384 self.whitelisted_ciphers.clone()
1385 };
1386
1387 if !self.blacklisted_ciphers.is_empty() {
1388 ciphers.retain(|cipher| !self.blacklisted_ciphers.contains(cipher));
1389 }
1390
1391 ctx.set_enabled_ciphers(&ciphers)?;
1392 Ok(())
1393 }
1394 }
1395
/// A builder type to simplify the creation of server-side `SslStream`s.
#[derive(Debug)]
pub struct ServerBuilder {
    // Identity set on the context for each handshake.
    identity: SecIdentity,
    // Certificate chain passed along with `identity`.
    certs: Vec<SecCertificate>,
}
1402
1403 impl ServerBuilder {
1404 /// Creates a new `ServerBuilder` which will use the specified identity
1405 /// and certificate chain for handshakes.
new(identity: &SecIdentity, certs: &[SecCertificate]) -> Self1406 pub fn new(identity: &SecIdentity, certs: &[SecCertificate]) -> Self {
1407 Self {
1408 identity: identity.clone(),
1409 certs: certs.to_owned(),
1410 }
1411 }
1412
1413 /// Initiates a new SSL/TLS session over a stream.
handshake<S>(&self, stream: S) -> Result<SslStream<S>> where S: Read + Write,1414 pub fn handshake<S>(&self, stream: S) -> Result<SslStream<S>>
1415 where
1416 S: Read + Write,
1417 {
1418 let mut ctx = SslContext::new(SslProtocolSide::SERVER, SslConnectionType::STREAM)?;
1419 ctx.set_certificate(&self.identity, &self.certs)?;
1420 match ctx.handshake(stream) {
1421 Ok(stream) => Ok(stream),
1422 Err(HandshakeError::Interrupted(stream)) => Err(Error::from_code(stream.reason())),
1423 Err(HandshakeError::Failure(err)) => Err(err),
1424 }
1425 }
1426 }
1427
// Integration-style tests. Most of them open a real TCP connection to a public
// host, so they need network access. `p!` is a macro defined outside this
// module; presumably it unwraps a `Result` and panics on error. Verify against
// its definition.
#[cfg(test)]
mod test {
    use std::io;
    use std::io::prelude::*;
    use std::net::TcpStream;

    use super::*;

    // A bare `SslContext` can complete a handshake with google.com.
    #[test]
    fn connect() {
        let mut ctx = p!(SslContext::new(
            SslProtocolSide::CLIENT,
            SslConnectionType::STREAM
        ));
        p!(ctx.set_peer_domain_name("google.com"));
        let stream = p!(TcpStream::connect("google.com:443"));
        p!(ctx.handshake(stream));
    }

    // The handshake must fail when the peer domain name does not match the
    // host actually connected to.
    #[test]
    fn connect_bad_domain() {
        let mut ctx = p!(SslContext::new(
            SslProtocolSide::CLIENT,
            SslConnectionType::STREAM
        ));
        p!(ctx.set_peer_domain_name("foobar.com"));
        let stream = p!(TcpStream::connect("google.com:443"));
        match ctx.handshake(stream) {
            Ok(_) => panic!("expected failure"),
            Err(_) => {}
        }
    }

    // Full round trip through a bare context: handshake, send an HTTP
    // request, read the response to the end.
    #[test]
    fn load_page() {
        let mut ctx = p!(SslContext::new(
            SslProtocolSide::CLIENT,
            SslConnectionType::STREAM
        ));
        p!(ctx.set_peer_domain_name("google.com"));
        let stream = p!(TcpStream::connect("google.com:443"));
        let mut stream = p!(ctx.handshake(stream));
        p!(stream.write_all(b"GET / HTTP/1.0\r\n\r\n"));
        p!(stream.flush());
        let mut buf = vec![];
        p!(stream.read_to_end(&mut buf));
        println!("{}", String::from_utf8_lossy(&buf));
    }

    // Without session tickets, both connections must break on server auth.
    #[test]
    fn client_no_session_ticket_resumption() {
        for _ in 0..2 {
            let stream = p!(TcpStream::connect("google.com:443"));

            // Manually handshake here.
            let stream = MidHandshakeSslStream {
                stream: ClientBuilder::new()
                    .ctx_into_stream("google.com", stream)
                    .unwrap(),
                error: Error::from(errSecSuccess),
            };

            let mut result = stream.handshake();

            if let Err(HandshakeError::Interrupted(stream)) = result {
                assert!(stream.server_auth_completed());
                result = stream.handshake();
            } else {
                panic!("Unexpectedly skipped server auth");
            }

            assert!(result.is_ok());
        }
    }

    // With session tickets enabled, the second connection is expected to
    // resume the session and skip the server-auth break.
    #[test]
    #[cfg(feature = "session-tickets")]
    fn client_session_ticket_resumption() {
        // The first time through this loop, we should do a full handshake. The second time, we
        // should immediately finish the handshake without breaking on server auth.
        for i in 0..2 {
            let stream = p!(TcpStream::connect("google.com:443"));
            let mut builder = ClientBuilder::new();
            builder.enable_session_tickets(true);

            // Manually handshake here.
            let stream = MidHandshakeSslStream {
                stream: builder.ctx_into_stream("google.com", stream).unwrap(),
                error: Error::from(errSecSuccess),
            };

            let mut result = stream.handshake();

            if let Err(HandshakeError::Interrupted(stream)) = result {
                assert!(stream.server_auth_completed());
                assert_eq!(
                    i, 0,
                    "Session ticket resumption did not work, server auth was not skipped"
                );
                result = stream.handshake();
            } else {
                assert_eq!(i, 1, "Unexpectedly skipped server auth");
            }

            assert!(result.is_ok());
        }
    }

    // Offering "h2" is expected to be accepted by the server and reported
    // back as the negotiated protocol.
    #[test]
    #[cfg(feature = "alpn")]
    fn client_alpn_accept() {
        let mut ctx = p!(SslContext::new(
            SslProtocolSide::CLIENT,
            SslConnectionType::STREAM
        ));
        p!(ctx.set_peer_domain_name("google.com"));
        p!(ctx.set_alpn_protocols(&vec!["h2"]));
        let stream = p!(TcpStream::connect("google.com:443"));
        let stream = ctx.handshake(stream).unwrap();
        assert_eq!(vec!["h2"], stream.context().alpn_protocols().unwrap());
    }

    // Offering only "h2c" still completes the handshake, but querying the
    // negotiated protocols afterwards must fail.
    #[test]
    #[cfg(feature = "alpn")]
    fn client_alpn_reject() {
        let mut ctx = p!(SslContext::new(
            SslProtocolSide::CLIENT,
            SslConnectionType::STREAM
        ));
        p!(ctx.set_peer_domain_name("google.com"));
        p!(ctx.set_alpn_protocols(&vec!["h2c"]));
        let stream = p!(TcpStream::connect("google.com:443"));
        let stream = ctx.handshake(stream).unwrap();
        assert!(stream.context().alpn_protocols().is_err());
    }

    // Trusting only the (empty) set of anchor certificates must make the
    // handshake fail.
    #[test]
    fn client_no_anchor_certs() {
        let stream = p!(TcpStream::connect("google.com:443"));
        assert!(ClientBuilder::new()
            .trust_anchor_certificates_only(true)
            .handshake("google.com", stream)
            .is_err());
    }

    // `ClientBuilder` must reject a certificate that does not match the
    // requested domain.
    #[test]
    fn client_bad_domain() {
        let stream = p!(TcpStream::connect("google.com:443"));
        assert!(ClientBuilder::new()
            .handshake("foobar.com", stream)
            .is_err());
    }

    // With hostname verification disabled, the mismatching domain is accepted.
    #[test]
    fn client_bad_domain_ignored() {
        let stream = p!(TcpStream::connect("google.com:443"));
        ClientBuilder::new()
            .danger_accept_invalid_hostnames(true)
            .handshake("foobar.com", stream)
            .unwrap();
    }

    // With certificate verification disabled, the handshake with
    // expired.badssl.com succeeds.
    #[test]
    fn connect_no_verify_ssl() {
        let stream = p!(TcpStream::connect("expired.badssl.com:443"));
        let mut builder = ClientBuilder::new();
        builder.danger_accept_invalid_certs(true);
        builder.handshake("expired.badssl.com", stream).unwrap();
    }

    // Full round trip through `ClientBuilder`: handshake, send an HTTP
    // request, read the response to the end.
    #[test]
    fn load_page_client() {
        let stream = p!(TcpStream::connect("google.com:443"));
        let mut stream = p!(ClientBuilder::new().handshake("google.com", stream));
        p!(stream.write_all(b"GET / HTTP/1.0\r\n\r\n"));
        p!(stream.flush());
        let mut buf = vec![];
        p!(stream.read_to_end(&mut buf));
        println!("{}", String::from_utf8_lossy(&buf));
    }

    // Enabling every other cipher must be reflected by `enabled_ciphers`.
    #[test]
    #[cfg_attr(target_os = "ios", ignore)] // FIXME what's going on with ios?
    fn cipher_configuration() {
        let mut ctx = p!(SslContext::new(
            SslProtocolSide::SERVER,
            SslConnectionType::STREAM
        ));
        let ciphers = p!(ctx.enabled_ciphers());
        let ciphers = ciphers
            .iter()
            .enumerate()
            .filter_map(|(i, c)| if i % 2 == 0 { Some(*c) } else { None })
            .collect::<Vec<_>>();
        p!(ctx.set_enabled_ciphers(&ciphers));
        assert_eq!(ciphers, p!(ctx.enabled_ciphers()));
    }

    // Whitelisting a single cipher leaves exactly one cipher enabled.
    #[test]
    fn test_builder_whitelist_ciphers() {
        let stream = p!(TcpStream::connect("google.com:443"));

        let ctx = p!(SslContext::new(
            SslProtocolSide::CLIENT,
            SslConnectionType::STREAM
        ));
        assert!(p!(ctx.enabled_ciphers()).len() > 1);

        let ciphers = p!(ctx.enabled_ciphers());
        let cipher = ciphers.first().unwrap();
        let stream = p!(ClientBuilder::new()
            .whitelist_ciphers(&[*cipher])
            .ctx_into_stream("google.com", stream));

        assert_eq!(1, p!(stream.context().enabled_ciphers()).len());
    }

    // Blacklisting a single cipher removes exactly one from the enabled set.
    #[test]
    #[cfg_attr(target_os = "ios", ignore)] // FIXME same issue as cipher_configuration
    fn test_builder_blacklist_ciphers() {
        let stream = p!(TcpStream::connect("google.com:443"));

        let ctx = p!(SslContext::new(
            SslProtocolSide::CLIENT,
            SslConnectionType::STREAM
        ));
        let num = p!(ctx.enabled_ciphers()).len();
        assert!(num > 1);

        let ciphers = p!(ctx.enabled_ciphers());
        let cipher = ciphers.first().unwrap();
        let stream = p!(ClientBuilder::new()
            .blacklist_ciphers(&[*cipher])
            .ctx_into_stream("google.com", stream));

        assert_eq!(num - 1, p!(stream.context().enabled_ciphers()).len());
    }

    // Asking a context that has never handshaked for its peer trust is an error.
    #[test]
    fn idle_context_peer_trust() {
        let ctx = p!(SslContext::new(
            SslProtocolSide::SERVER,
            SslConnectionType::STREAM
        ));
        assert!(ctx.peer_trust2().is_err());
    }

    // The peer id starts out unset and round-trips once set.
    #[test]
    fn peer_id() {
        let mut ctx = p!(SslContext::new(
            SslProtocolSide::SERVER,
            SslConnectionType::STREAM
        ));
        assert!(p!(ctx.peer_id()).is_none());
        p!(ctx.set_peer_id(b"foobar"));
        assert_eq!(p!(ctx.peer_id()), Some(&b"foobar"[..]));
    }

    // The peer domain name starts out empty and round-trips once set.
    #[test]
    fn peer_domain_name() {
        let mut ctx = p!(SslContext::new(
            SslProtocolSide::CLIENT,
            SslConnectionType::STREAM
        ));
        assert_eq!("", p!(ctx.peer_domain_name()));
        p!(ctx.set_peer_domain_name("foobar.com"));
        assert_eq!("foobar.com", p!(ctx.peer_domain_name()));
    }

    // A panic raised by the stream's `write` during the handshake must
    // propagate to the caller with its original message.
    #[test]
    #[should_panic(expected = "blammo")]
    fn write_panic() {
        struct ExplodingStream(TcpStream);

        impl Read for ExplodingStream {
            fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
                self.0.read(buf)
            }
        }

        impl Write for ExplodingStream {
            fn write(&mut self, _: &[u8]) -> io::Result<usize> {
                panic!("blammo");
            }

            fn flush(&mut self) -> io::Result<()> {
                self.0.flush()
            }
        }

        let mut ctx = p!(SslContext::new(
            SslProtocolSide::CLIENT,
            SslConnectionType::STREAM
        ));
        p!(ctx.set_peer_domain_name("google.com"));
        let stream = p!(TcpStream::connect("google.com:443"));
        let _ = ctx.handshake(ExplodingStream(stream));
    }

    // A panic raised by the stream's `read` during the handshake must
    // propagate to the caller with its original message.
    #[test]
    #[should_panic(expected = "blammo")]
    fn read_panic() {
        struct ExplodingStream(TcpStream);

        impl Read for ExplodingStream {
            fn read(&mut self, _: &mut [u8]) -> io::Result<usize> {
                panic!("blammo");
            }
        }

        impl Write for ExplodingStream {
            fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
                self.0.write(buf)
            }

            fn flush(&mut self) -> io::Result<()> {
                self.0.flush()
            }
        }

        let mut ctx = p!(SslContext::new(
            SslProtocolSide::CLIENT,
            SslConnectionType::STREAM
        ));
        p!(ctx.set_peer_domain_name("google.com"));
        let stream = p!(TcpStream::connect("google.com:443"));
        let _ = ctx.handshake(ExplodingStream(stream));
    }

    // Zero-length reads and writes succeed with a count of zero.
    #[test]
    fn zero_length_buffers() {
        let mut ctx = p!(SslContext::new(
            SslProtocolSide::CLIENT,
            SslConnectionType::STREAM
        ));
        p!(ctx.set_peer_domain_name("google.com"));
        let stream = p!(TcpStream::connect("google.com:443"));
        let mut stream = ctx.handshake(stream).unwrap();
        assert_eq!(stream.write(b"").unwrap(), 0);
        assert_eq!(stream.read(&mut []).unwrap(), 0);
    }
}
1770