1 //! SSL/TLS encryption support using Secure Transport.
2 //!
3 //! # Examples
4 //!
5 //! To connect as a client to a server with a certificate trusted by the system:
6 //!
//! ```rust,no_run
8 //! use std::io::prelude::*;
9 //! use std::net::TcpStream;
10 //! use security_framework::secure_transport::ClientBuilder;
11 //!
12 //! let stream = TcpStream::connect("google.com:443").unwrap();
13 //! let mut stream = ClientBuilder::new().handshake("google.com", stream).unwrap();
14 //!
15 //! stream.write_all(b"GET / HTTP/1.0\r\n\r\n").unwrap();
16 //! let mut page = vec![];
17 //! stream.read_to_end(&mut page).unwrap();
18 //! println!("{}", String::from_utf8_lossy(&page));
19 //! ```
20 //!
21 //! To connect to a server with a certificate that's *not* trusted by the
22 //! system, specify the root certificates for the server's chain to the
23 //! `ClientBuilder`:
24 //!
25 //! ```rust,no_run
26 //! use std::io::prelude::*;
27 //! use std::net::TcpStream;
28 //! use security_framework::secure_transport::ClientBuilder;
29 //!
30 //! # let root_cert = unsafe { std::mem::zeroed() };
31 //! let stream = TcpStream::connect("my_server.com:443").unwrap();
32 //! let mut stream = ClientBuilder::new()
33 //! .anchor_certificates(&[root_cert])
34 //! .handshake("my_server.com", stream)
35 //! .unwrap();
36 //!
37 //! stream.write_all(b"GET / HTTP/1.0\r\n\r\n").unwrap();
38 //! let mut page = vec![];
39 //! stream.read_to_end(&mut page).unwrap();
40 //! println!("{}", String::from_utf8_lossy(&page));
41 //! ```
42 //!
43 //! For more advanced configuration, the `SslContext` type can be used directly.
44 //!
45 //! To run a server:
46 //!
47 //! ```rust,no_run
48 //! use std::net::TcpListener;
49 //! use std::thread;
50 //! use security_framework::secure_transport::{SslContext, SslProtocolSide, SslConnectionType};
51 //!
52 //! // Create a TCP listener and start accepting on it.
//! let listener = TcpListener::bind("0.0.0.0:443").unwrap();
54 //!
55 //! for stream in listener.incoming() {
56 //! let stream = stream.unwrap();
57 //! thread::spawn(move || {
58 //! // Create a new context configured to operate on the server side of
59 //! // a traditional SSL/TLS session.
60 //! let mut ctx = SslContext::new(SslProtocolSide::SERVER, SslConnectionType::STREAM)
61 //! .unwrap();
62 //!
63 //! // Install the certificate chain that we will be using.
64 //! # let identity = unsafe { std::mem::zeroed() };
65 //! # let intermediate_cert = unsafe { std::mem::zeroed() };
66 //! # let root_cert = unsafe { std::mem::zeroed() };
//! ctx.set_certificate(&identity, &[intermediate_cert, root_cert]).unwrap();
68 //!
69 //! // Perform the SSL/TLS handshake and get our stream.
70 //! let mut stream = ctx.handshake(stream).unwrap();
71 //! });
72 //! }
73 //!
74 //! ```
75 #[allow(unused_imports)]
76 use core_foundation::array::{CFArray, CFArrayRef};
77
78 use core_foundation::base::{Boolean, TCFType};
79 #[cfg(feature = "alpn")]
80 use core_foundation::string::CFString;
81 use core_foundation_sys::base::{kCFAllocatorDefault, OSStatus};
82 use std::os::raw::c_void;
83
84 #[allow(unused_imports)]
85 use security_framework_sys::base::{
86 errSecBadReq, errSecIO, errSecNotTrusted, errSecSuccess, errSecTrustSettingDeny,
87 errSecUnimplemented,
88 };
89
90 use security_framework_sys::secure_transport::*;
91 use std::any::Any;
92 use std::cmp;
93 use std::fmt;
94 use std::io;
95 use std::io::prelude::*;
96 use std::marker::PhantomData;
97 use std::mem;
98 use std::panic::{self, AssertUnwindSafe};
99 use std::ptr;
100 use std::result;
101 use std::slice;
102
103 use crate::base::{Error, Result};
104 use crate::certificate::SecCertificate;
105 use crate::cipher_suite::CipherSuite;
106 use crate::identity::SecIdentity;
107 use crate::policy::SecPolicy;
108 use crate::trust::{SecTrust, TrustResult};
109 use crate::{cvt, AsInner};
110
111 /// Specifies a side of a TLS session.
112 #[derive(Debug, Copy, Clone, PartialEq, Eq)]
113 pub struct SslProtocolSide(SSLProtocolSide);
114
115 impl SslProtocolSide {
116 /// The server side of the session.
117 pub const SERVER: Self = Self(kSSLServerSide);
118
119 /// The client side of the session.
120 pub const CLIENT: Self = Self(kSSLClientSide);
121 }
122
123 /// Specifies the type of TLS session.
124 #[derive(Debug, Copy, Clone)]
125 pub struct SslConnectionType(SSLConnectionType);
126
127 impl SslConnectionType {
128 /// A traditional TLS stream.
129 pub const STREAM: Self = Self(kSSLStreamType);
130
131 /// A DTLS session.
132 pub const DATAGRAM: Self = Self(kSSLDatagramType);
133 }
134
135 /// An error or intermediate state after a TLS handshake attempt.
136 #[derive(Debug)]
137 pub enum HandshakeError<S> {
138 /// The handshake failed.
139 Failure(Error),
140 /// The handshake was interrupted midway through.
141 Interrupted(MidHandshakeSslStream<S>),
142 }
143
144 impl<S> From<Error> for HandshakeError<S> {
from(err: Error) -> Self145 fn from(err: Error) -> Self {
146 Self::Failure(err)
147 }
148 }
149
150 /// An error or intermediate state after a TLS handshake attempt.
151 #[derive(Debug)]
152 pub enum ClientHandshakeError<S> {
153 /// The handshake failed.
154 Failure(Error),
155 /// The handshake was interrupted midway through.
156 Interrupted(MidHandshakeClientBuilder<S>),
157 }
158
159 impl<S> From<Error> for ClientHandshakeError<S> {
from(err: Error) -> Self160 fn from(err: Error) -> Self {
161 Self::Failure(err)
162 }
163 }
164
165 /// An SSL stream midway through the handshake process.
166 #[derive(Debug)]
167 pub struct MidHandshakeSslStream<S> {
168 stream: SslStream<S>,
169 error: Error,
170 }
171
172 impl<S> MidHandshakeSslStream<S> {
173 /// Returns a shared reference to the inner stream.
get_ref(&self) -> &S174 pub fn get_ref(&self) -> &S {
175 self.stream.get_ref()
176 }
177
178 /// Returns a mutable reference to the inner stream.
get_mut(&mut self) -> &mut S179 pub fn get_mut(&mut self) -> &mut S {
180 self.stream.get_mut()
181 }
182
183 /// Returns a shared reference to the `SslContext` of the stream.
context(&self) -> &SslContext184 pub fn context(&self) -> &SslContext {
185 self.stream.context()
186 }
187
188 /// Returns a mutable reference to the `SslContext` of the stream.
context_mut(&mut self) -> &mut SslContext189 pub fn context_mut(&mut self) -> &mut SslContext {
190 self.stream.context_mut()
191 }
192
193 /// Returns `true` iff `break_on_server_auth` was set and the handshake has
194 /// progressed to that point.
server_auth_completed(&self) -> bool195 pub fn server_auth_completed(&self) -> bool {
196 self.error.code() == errSSLPeerAuthCompleted
197 }
198
199 /// Returns `true` iff `break_on_cert_requested` was set and the handshake
200 /// has progressed to that point.
client_cert_requested(&self) -> bool201 pub fn client_cert_requested(&self) -> bool {
202 self.error.code() == errSSLClientCertRequested
203 }
204
205 /// Returns `true` iff the underlying stream returned an error with the
206 /// `WouldBlock` kind.
would_block(&self) -> bool207 pub fn would_block(&self) -> bool {
208 self.error.code() == errSSLWouldBlock
209 }
210
211 /// Deprecated
reason(&self) -> OSStatus212 pub fn reason(&self) -> OSStatus {
213 self.error.code()
214 }
215
216 /// Returns the error which caused the handshake interruption.
error(&self) -> &Error217 pub fn error(&self) -> &Error {
218 &self.error
219 }
220
221 /// Restarts the handshake process.
handshake(self) -> result::Result<SslStream<S>, HandshakeError<S>>222 pub fn handshake(self) -> result::Result<SslStream<S>, HandshakeError<S>> {
223 self.stream.handshake()
224 }
225 }
226
227 /// An SSL stream midway through the handshake process.
228 #[derive(Debug)]
229 pub struct MidHandshakeClientBuilder<S> {
230 stream: MidHandshakeSslStream<S>,
231 domain: Option<String>,
232 certs: Vec<SecCertificate>,
233 trust_certs_only: bool,
234 danger_accept_invalid_certs: bool,
235 }
236
237 impl<S> MidHandshakeClientBuilder<S> {
238 /// Returns a shared reference to the inner stream.
get_ref(&self) -> &S239 pub fn get_ref(&self) -> &S {
240 self.stream.get_ref()
241 }
242
243 /// Returns a mutable reference to the inner stream.
get_mut(&mut self) -> &mut S244 pub fn get_mut(&mut self) -> &mut S {
245 self.stream.get_mut()
246 }
247
248 /// Returns the error which caused the handshake interruption.
error(&self) -> &Error249 pub fn error(&self) -> &Error {
250 self.stream.error()
251 }
252
253 /// Restarts the handshake process.
handshake(self) -> result::Result<SslStream<S>, ClientHandshakeError<S>>254 pub fn handshake(self) -> result::Result<SslStream<S>, ClientHandshakeError<S>> {
255 let MidHandshakeClientBuilder {
256 stream,
257 domain,
258 certs,
259 trust_certs_only,
260 danger_accept_invalid_certs,
261 } = self;
262
263 let mut result = stream.handshake();
264 loop {
265 let stream = match result {
266 Ok(stream) => return Ok(stream),
267 Err(HandshakeError::Interrupted(stream)) => stream,
268 Err(HandshakeError::Failure(err)) => {
269 return Err(ClientHandshakeError::Failure(err))
270 }
271 };
272
273 if stream.would_block() {
274 let ret = MidHandshakeClientBuilder {
275 stream,
276 domain,
277 certs,
278 trust_certs_only,
279 danger_accept_invalid_certs,
280 };
281 return Err(ClientHandshakeError::Interrupted(ret));
282 }
283
284 if stream.server_auth_completed() {
285 if danger_accept_invalid_certs {
286 result = stream.handshake();
287 continue;
288 }
289 let mut trust = match stream.context().peer_trust2()? {
290 Some(trust) => trust,
291 None => {
292 result = stream.handshake();
293 continue;
294 }
295 };
296 trust.set_anchor_certificates(&certs)?;
297 trust.set_trust_anchor_certificates_only(self.trust_certs_only)?;
298 let policy =
299 SecPolicy::create_ssl(SslProtocolSide::SERVER, domain.as_ref().map(|s| &**s));
300 trust.set_policy(&policy)?;
301 let trusted = trust.evaluate()?;
302 match trusted {
303 TrustResult::PROCEED | TrustResult::UNSPECIFIED => {
304 result = stream.handshake();
305 continue;
306 }
307 TrustResult::DENY => {
308 let err = Error::from_code(errSecTrustSettingDeny);
309 return Err(ClientHandshakeError::Failure(err));
310 }
311 TrustResult::RECOVERABLE_TRUST_FAILURE | TrustResult::FATAL_TRUST_FAILURE => {
312 let err = Error::from_code(errSecNotTrusted);
313 return Err(ClientHandshakeError::Failure(err));
314 }
315 TrustResult::INVALID | TrustResult::OTHER_ERROR | _ => {
316 let err = Error::from_code(errSecBadReq);
317 return Err(ClientHandshakeError::Failure(err));
318 }
319 }
320 }
321
322 let err = Error::from_code(stream.reason());
323 return Err(ClientHandshakeError::Failure(err));
324 }
325 }
326 }
327
328 /// Specifies the state of a TLS session.
329 #[derive(Debug, PartialEq, Eq)]
330 pub struct SessionState(SSLSessionState);
331
332 impl SessionState {
333 /// The session has not yet started.
334 pub const IDLE: Self = Self(kSSLIdle);
335
336 /// The session is in the handshake process.
337 pub const HANDSHAKE: Self = Self(kSSLHandshake);
338
339 /// The session is connected.
340 pub const CONNECTED: Self = Self(kSSLConnected);
341
342 /// The session has been terminated.
343 pub const CLOSED: Self = Self(kSSLClosed);
344
345 /// The session has been aborted due to an error.
346 pub const ABORTED: Self = Self(kSSLAborted);
347 }
348
349 /// Specifies a server's requirement for client certificates.
350 #[derive(Debug, Copy, Clone, PartialEq, Eq)]
351 pub struct SslAuthenticate(SSLAuthenticate);
352
353 impl SslAuthenticate {
354 /// Do not request a client certificate.
355 pub const NEVER: Self = Self(kNeverAuthenticate);
356
357 /// Require a client certificate.
358 pub const ALWAYS: Self = Self(kAlwaysAuthenticate);
359
360 /// Request but do not require a client certificate.
361 pub const TRY: Self = Self(kTryAuthenticate);
362 }
363
364 /// Specifies the state of client certificate processing.
365 #[derive(Debug, Copy, Clone, PartialEq, Eq)]
366 pub struct SslClientCertificateState(SSLClientCertificateState);
367
368 impl SslClientCertificateState {
369 /// A client certificate has not been requested or sent.
370 pub const NONE: Self = Self(kSSLClientCertNone);
371
372 /// A client certificate has been requested but not recieved.
373 pub const REQUESTED: Self = Self(kSSLClientCertRequested);
374 /// A client certificate has been received and successfully validated.
375 pub const SENT: Self = Self(kSSLClientCertSent);
376
377 /// A client certificate has been received but has failed to validate.
378 pub const REJECTED: Self = Self(kSSLClientCertRejected); }
379
380 /// Specifies protocol versions.
381 #[derive(Debug, Copy, Clone, PartialEq, Eq)]
382 pub struct SslProtocol(SSLProtocol);
383
384 impl SslProtocol {
385 /// No protocol has been or should be negotiated or specified; use the default.
386 pub const UNKNOWN: Self = Self(kSSLProtocolUnknown);
387
388 /// The SSL 3.0 protocol is preferred, though SSL 2.0 may be used if the peer does not support
389 /// SSL 3.0.
390 pub const SSL3: Self = Self(kSSLProtocol3);
391
392 /// The TLS 1.0 protocol is preferred, though lower versions may be used
393 /// if the peer does not support TLS 1.0.
394 pub const TLS1: Self = Self(kTLSProtocol1);
395
396 /// The TLS 1.1 protocol is preferred, though lower versions may be used
397 /// if the peer does not support TLS 1.1.
398 pub const TLS11: Self = Self(kTLSProtocol11);
399
400 /// The TLS 1.2 protocol is preferred, though lower versions may be used
401 /// if the peer does not support TLS 1.2.
402 pub const TLS12: Self = Self(kTLSProtocol12);
403
404 /// The TLS 1.3 protocol is preferred, though lower versions may be used
405 /// if the peer does not support TLS 1.3.
406 pub const TLS13: Self = Self(kTLSProtocol13);
407
408 /// Only the SSL 2.0 protocol is accepted.
409 pub const SSL2: Self = Self(kSSLProtocol2);
410
411 /// The DTLSv1 protocol is preferred.
412 pub const DTLS1: Self = Self(kDTLSProtocol1);
413
414 /// Only the SSL 3.0 protocol is accepted.
415 pub const SSL3_ONLY: Self = Self(kSSLProtocol3Only);
416
417 /// Only the TLS 1.0 protocol is accepted.
418 pub const TLS1_ONLY: Self = Self(kTLSProtocol1Only);
419
420 /// All supported TLS/SSL versions are accepted.
421 pub const ALL: Self = Self(kSSLProtocolAll);
422 }
423
424 declare_TCFType! {
425 /// A Secure Transport SSL/TLS context object.
426 SslContext, SSLContextRef
427 }
428
429 impl_TCFType!(SslContext, SSLContextRef, SSLContextGetTypeID);
430
431 impl fmt::Debug for SslContext {
fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result432 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
433 let mut builder = fmt.debug_struct("SslContext");
434 if let Ok(state) = self.state() {
435 builder.field("state", &state);
436 }
437 builder.finish()
438 }
439 }
440
441 unsafe impl Sync for SslContext {}
442 unsafe impl Send for SslContext {}
443
444 impl AsInner for SslContext {
445 type Inner = SSLContextRef;
446
as_inner(&self) -> SSLContextRef447 fn as_inner(&self) -> SSLContextRef {
448 self.0
449 }
450 }
451
/// Generates a setter/getter pair for each listed Secure Transport session
/// option.
///
/// Each entry `const OPTION: getter & setter,` expands to
/// `setter(&mut self, bool) -> Result<()>` (via `SSLSetSessionOption`) and
/// `getter(&self) -> Result<bool>` (via `SSLGetSessionOption`). The entry's
/// attributes (doc comments, `cfg`s) are copied onto both methods.
macro_rules! impl_options {
    ($($(#[$attr:meta])* const $option:ident: $getter:ident & $setter:ident,)*) => {
        $(
            $(#[$attr])*
            pub fn $setter(&mut self, value: bool) -> Result<()> {
                let raw = value as Boolean;
                unsafe { cvt(SSLSetSessionOption(self.0, $option, raw)) }
            }

            $(#[$attr])*
            pub fn $getter(&self) -> Result<bool> {
                let mut raw: Boolean = 0;
                unsafe {
                    cvt(SSLGetSessionOption(self.0, $option, &mut raw))?;
                }
                Ok(raw != 0)
            }
        )*
    }
}
469
470 impl SslContext {
471 /// Creates a new `SslContext` for the specified side and type of SSL
472 /// connection.
new(side: SslProtocolSide, type_: SslConnectionType) -> Result<Self>473 pub fn new(side: SslProtocolSide, type_: SslConnectionType) -> Result<Self> {
474 unsafe {
475 let ctx = SSLCreateContext(kCFAllocatorDefault, side.0, type_.0);
476 Ok(Self(ctx))
477 }
478 }
479
480 /// Sets the fully qualified domain name of the peer.
481 ///
482 /// This will be used on the client side of a session to validate the
483 /// common name field of the server's certificate. It has no effect if
484 /// called on a server-side `SslContext`.
485 ///
486 /// It is *highly* recommended to call this method before starting the
487 /// handshake process.
set_peer_domain_name(&mut self, peer_name: &str) -> Result<()>488 pub fn set_peer_domain_name(&mut self, peer_name: &str) -> Result<()> {
489 unsafe {
490 // SSLSetPeerDomainName doesn't need a null terminated string
491 cvt(SSLSetPeerDomainName(
492 self.0,
493 peer_name.as_ptr() as *const _,
494 peer_name.len(),
495 ))
496 }
497 }
498
499 /// Returns the peer domain name set by `set_peer_domain_name`.
peer_domain_name(&self) -> Result<String>500 pub fn peer_domain_name(&self) -> Result<String> {
501 unsafe {
502 let mut len = 0;
503 cvt(SSLGetPeerDomainNameLength(self.0, &mut len))?;
504 let mut buf = vec![0; len];
505 cvt(SSLGetPeerDomainName(
506 self.0,
507 buf.as_mut_ptr() as *mut _,
508 &mut len,
509 ))?;
510 Ok(String::from_utf8(buf).unwrap())
511 }
512 }
513
514 /// Sets the certificate to be used by this side of the SSL session.
515 ///
516 /// This must be called before the handshake for server-side connections,
517 /// and can be used on the client-side to specify a client certificate.
518 ///
519 /// The `identity` corresponds to the leaf certificate and private
520 /// key, and the `certs` correspond to extra certificates in the chain.
set_certificate( &mut self, identity: &SecIdentity, certs: &[SecCertificate], ) -> Result<()>521 pub fn set_certificate(
522 &mut self,
523 identity: &SecIdentity,
524 certs: &[SecCertificate],
525 ) -> Result<()> {
526 let mut arr = vec![identity.as_CFType()];
527 arr.extend(certs.iter().map(|c| c.as_CFType()));
528 let certs = CFArray::from_CFTypes(&arr);
529
530 unsafe { cvt(SSLSetCertificate(self.0, certs.as_concrete_TypeRef())) }
531 }
532
533 /// Sets the peer ID of this session.
534 ///
535 /// A peer ID is an opaque sequence of bytes that will be used by Secure
536 /// Transport to identify the peer of an SSL session. If the peer ID of
537 /// this session matches that of a previously terminated session, the
538 /// previous session can be resumed without requiring a full handshake.
set_peer_id(&mut self, peer_id: &[u8]) -> Result<()>539 pub fn set_peer_id(&mut self, peer_id: &[u8]) -> Result<()> {
540 unsafe {
541 cvt(SSLSetPeerID(
542 self.0,
543 peer_id.as_ptr() as *const _,
544 peer_id.len(),
545 ))
546 }
547 }
548
549 /// Returns the peer ID of this session.
peer_id(&self) -> Result<Option<&[u8]>>550 pub fn peer_id(&self) -> Result<Option<&[u8]>> {
551 unsafe {
552 let mut ptr = ptr::null();
553 let mut len = 0;
554 cvt(SSLGetPeerID(self.0, &mut ptr, &mut len))?;
555 if ptr.is_null() {
556 Ok(None)
557 } else {
558 Ok(Some(slice::from_raw_parts(ptr as *const _, len)))
559 }
560 }
561 }
562
563 /// Returns the list of ciphers that are supported by Secure Transport.
supported_ciphers(&self) -> Result<Vec<CipherSuite>>564 pub fn supported_ciphers(&self) -> Result<Vec<CipherSuite>> {
565 unsafe {
566 let mut num_ciphers = 0;
567 cvt(SSLGetNumberSupportedCiphers(self.0, &mut num_ciphers))?;
568 let mut ciphers = vec![0; num_ciphers];
569 cvt(SSLGetSupportedCiphers(
570 self.0,
571 ciphers.as_mut_ptr(),
572 &mut num_ciphers,
573 ))?;
574 Ok(ciphers.iter().map(|c| CipherSuite::from_raw(*c)).collect())
575 }
576 }
577
578 /// Returns the list of ciphers that are eligible to be used for
579 /// negotiation.
enabled_ciphers(&self) -> Result<Vec<CipherSuite>>580 pub fn enabled_ciphers(&self) -> Result<Vec<CipherSuite>> {
581 unsafe {
582 let mut num_ciphers = 0;
583 cvt(SSLGetNumberEnabledCiphers(self.0, &mut num_ciphers))?;
584 let mut ciphers = vec![0; num_ciphers];
585 cvt(SSLGetEnabledCiphers(
586 self.0,
587 ciphers.as_mut_ptr(),
588 &mut num_ciphers,
589 ))?;
590 Ok(ciphers.iter().map(|c| CipherSuite::from_raw(*c)).collect())
591 }
592 }
593
594 /// Sets the list of ciphers that are eligible to be used for negotiation.
set_enabled_ciphers(&mut self, ciphers: &[CipherSuite]) -> Result<()>595 pub fn set_enabled_ciphers(&mut self, ciphers: &[CipherSuite]) -> Result<()> {
596 let ciphers = ciphers.iter().map(|c| c.to_raw()).collect::<Vec<_>>();
597 unsafe {
598 cvt(SSLSetEnabledCiphers(
599 self.0,
600 ciphers.as_ptr(),
601 ciphers.len(),
602 ))
603 }
604 }
605
606 /// Returns the cipher being used by the session.
negotiated_cipher(&self) -> Result<CipherSuite>607 pub fn negotiated_cipher(&self) -> Result<CipherSuite> {
608 unsafe {
609 let mut cipher = 0;
610 cvt(SSLGetNegotiatedCipher(self.0, &mut cipher))?;
611 Ok(CipherSuite::from_raw(cipher))
612 }
613 }
614
615 /// Sets the requirements for client certificates.
616 ///
617 /// Should only be called on server-side sessions.
set_client_side_authenticate(&mut self, auth: SslAuthenticate) -> Result<()>618 pub fn set_client_side_authenticate(&mut self, auth: SslAuthenticate) -> Result<()> {
619 unsafe { cvt(SSLSetClientSideAuthenticate(self.0, auth.0)) }
620 }
621
622 /// Returns the state of client certificate processing.
client_certificate_state(&self) -> Result<SslClientCertificateState>623 pub fn client_certificate_state(&self) -> Result<SslClientCertificateState> {
624 let mut state = 0;
625
626 unsafe {
627 cvt(SSLGetClientCertificateState(self.0, &mut state))?;
628 }
629 Ok(SslClientCertificateState(state))
630 }
631
632 /// Returns the `SecTrust` object corresponding to the peer.
633 ///
634 /// This can be used in conjunction with `set_break_on_server_auth` to
635 /// validate certificates which do not have roots in the default set.
peer_trust2(&self) -> Result<Option<SecTrust>>636 pub fn peer_trust2(&self) -> Result<Option<SecTrust>> {
637 // Calling SSLCopyPeerTrust on an idle connection does not seem to be well defined,
638 // so explicitly check for that
639 if self.state()? == SessionState::IDLE {
640 return Err(Error::from_code(errSecBadReq));
641 }
642
643 unsafe {
644 let mut trust = ptr::null_mut();
645 cvt(SSLCopyPeerTrust(self.0, &mut trust))?;
646 if trust.is_null() {
647 Ok(None)
648 } else {
649 Ok(Some(SecTrust::wrap_under_create_rule(trust)))
650 }
651 }
652 }
653
654 #[allow(missing_docs)]
655 #[deprecated(since = "0.2.1", note = "use peer_trust2 instead")]
peer_trust(&self) -> Result<SecTrust>656 pub fn peer_trust(&self) -> Result<SecTrust> {
657 match self.peer_trust2() {
658 Ok(Some(trust)) => Ok(trust),
659 Ok(None) => panic!("no trust available"),
660 Err(e) => Err(e),
661 }
662 }
663
664 /// Returns the state of the session.
state(&self) -> Result<SessionState>665 pub fn state(&self) -> Result<SessionState> {
666 unsafe {
667 let mut state = 0;
668 cvt(SSLGetSessionState(self.0, &mut state))?;
669 Ok(SessionState(state))
670 }
671 }
672
673 /// Returns the protocol version being used by the session.
negotiated_protocol_version(&self) -> Result<SslProtocol>674 pub fn negotiated_protocol_version(&self) -> Result<SslProtocol> {
675 unsafe {
676 let mut version = 0;
677 cvt(SSLGetNegotiatedProtocolVersion(self.0, &mut version))?;
678 Ok(SslProtocol(version))
679 }
680 }
681
682 /// Returns the maximum protocol version allowed by the session.
protocol_version_max(&self) -> Result<SslProtocol>683 pub fn protocol_version_max(&self) -> Result<SslProtocol> {
684 unsafe {
685 let mut version = 0;
686 cvt(SSLGetProtocolVersionMax(self.0, &mut version))?;
687 Ok(SslProtocol(version))
688 }
689 }
690
691 /// Sets the maximum protocol version allowed by the session.
set_protocol_version_max(&mut self, max_version: SslProtocol) -> Result<()>692 pub fn set_protocol_version_max(&mut self, max_version: SslProtocol) -> Result<()> {
693 unsafe { cvt(SSLSetProtocolVersionMax(self.0, max_version.0)) }
694 }
695
696 /// Returns the minimum protocol version allowed by the session.
protocol_version_min(&self) -> Result<SslProtocol>697 pub fn protocol_version_min(&self) -> Result<SslProtocol> {
698 unsafe {
699 let mut version = 0;
700 cvt(SSLGetProtocolVersionMin(self.0, &mut version))?;
701 Ok(SslProtocol(version))
702 }
703 }
704
705 /// Sets the minimum protocol version allowed by the session.
set_protocol_version_min(&mut self, min_version: SslProtocol) -> Result<()>706 pub fn set_protocol_version_min(&mut self, min_version: SslProtocol) -> Result<()> {
707 unsafe { cvt(SSLSetProtocolVersionMin(self.0, min_version.0)) }
708 }
709
710 /// Returns the set of protocols selected via ALPN if it succeeded.
711 #[cfg(feature = "alpn")]
alpn_protocols(&self) -> Result<Vec<String>>712 pub fn alpn_protocols(&self) -> Result<Vec<String>> {
713 let mut array = ptr::null();
714 unsafe {
715 #[cfg(feature = "OSX_10_13")]
716 {
717 cvt(SSLCopyALPNProtocols(self.0, &mut array))?;
718 }
719
720 #[cfg(not(feature = "OSX_10_13"))]
721 {
722 dlsym! { fn SSLCopyALPNProtocols(SSLContextRef, *mut CFArrayRef) -> OSStatus }
723 if let Some(f) = SSLCopyALPNProtocols.get() {
724 cvt(f(self.0, &mut array))?;
725 } else {
726 return Err(Error::from_code(errSecUnimplemented));
727 }
728 }
729
730 if array.is_null() {
731 return Ok(vec![]);
732 }
733
734 let array = CFArray::<CFString>::wrap_under_create_rule(array);
735 Ok(array.into_iter().map(|p| p.to_string()).collect())
736 }
737 }
738
739 /// Configures the set of protocols use for ALPN.
740 ///
741 /// This is only used for client-side connections.
742 #[cfg(feature = "alpn")]
set_alpn_protocols(&mut self, protocols: &[&str]) -> Result<()>743 pub fn set_alpn_protocols(&mut self, protocols: &[&str]) -> Result<()> {
744 // When CFMutableArray is added to core-foundation and IntoIterator trait
745 // is implemented for CFMutableArray, the code below should directly collect
746 // into a CFMutableArray.
747 let protocols = CFArray::from_CFTypes(
748 &protocols
749 .iter()
750 .map(|proto| CFString::new(proto))
751 .collect::<Vec<_>>(),
752 );
753
754 #[cfg(feature = "OSX_10_13")]
755 {
756 unsafe { cvt(SSLSetALPNProtocols(self.0, protocols.as_concrete_TypeRef())) }
757 }
758 #[cfg(not(feature = "OSX_10_13"))]
759 {
760 dlsym! { fn SSLSetALPNProtocols(SSLContextRef, CFArrayRef) -> OSStatus }
761 if let Some(f) = SSLSetALPNProtocols.get() {
762 unsafe { cvt(f(self.0, protocols.as_concrete_TypeRef())) }
763 } else {
764 Err(Error::from_code(errSecUnimplemented))
765 }
766 }
767 }
768
769 /// Sets whether a protocol is enabled or not.
770 ///
771 /// # Note
772 ///
773 /// On OSX this is a deprecated API in favor of `set_protocol_version_max` and
774 /// `set_protocol_version_min`, although if you're working with OSX 10.8 or before you may have
775 /// to use this API instead.
776 #[cfg(target_os = "macos")]
set_protocol_version_enabled( &mut self, protocol: SslProtocol, enabled: bool, ) -> Result<()>777 pub fn set_protocol_version_enabled(
778 &mut self,
779 protocol: SslProtocol,
780 enabled: bool,
781 ) -> Result<()> {
782 unsafe {
783 cvt(SSLSetProtocolVersionEnabled(
784 self.0,
785 protocol.0,
786 enabled as Boolean,
787 ))
788 }
789 }
790
791 /// Returns the number of bytes which can be read without triggering a
792 /// `read` call in the underlying stream.
buffered_read_size(&self) -> Result<usize>793 pub fn buffered_read_size(&self) -> Result<usize> {
794 unsafe {
795 let mut size = 0;
796 cvt(SSLGetBufferedReadSize(self.0, &mut size))?;
797 Ok(size)
798 }
799 }
800
801 impl_options! {
802 /// If enabled, the handshake process will pause and return instead of
803 /// automatically validating a server's certificate.
804 const kSSLSessionOptionBreakOnServerAuth: break_on_server_auth & set_break_on_server_auth,
805 /// If enabled, the handshake process will pause and return after
806 /// the server requests a certificate from the client.
807 const kSSLSessionOptionBreakOnCertRequested: break_on_cert_requested & set_break_on_cert_requested,
808 /// If enabled, the handshake process will pause and return instead of
809 /// automatically validating a client's certificate.
810 const kSSLSessionOptionBreakOnClientAuth: break_on_client_auth & set_break_on_client_auth,
811 /// If enabled, TLS false start will be performed if an appropriate
812 /// cipher suite is negotiated.
813 ///
814 /// Requires the `OSX_10_9` (or greater) feature.
815 #[cfg(feature = "OSX_10_9")]
816 const kSSLSessionOptionFalseStart: false_start & set_false_start,
817 /// If enabled, 1/n-1 record splitting will be enabled for TLS 1.0
818 /// connections using block ciphers to mitigate the BEAST attack.
819 ///
820 /// Requires the `OSX_10_9` (or greater) feature.
821 #[cfg(feature = "OSX_10_9")]
822 const kSSLSessionOptionSendOneByteRecord: send_one_byte_record & set_send_one_byte_record,
823 }
824
into_stream<S>(self, stream: S) -> Result<SslStream<S>> where S: Read + Write,825 fn into_stream<S>(self, stream: S) -> Result<SslStream<S>>
826 where
827 S: Read + Write,
828 {
829 unsafe {
830 let ret = SSLSetIOFuncs(self.0, read_func::<S>, write_func::<S>);
831 if ret != errSecSuccess {
832 return Err(Error::from_code(ret));
833 }
834
835 let stream = Connection {
836 stream,
837 err: None,
838 panic: None,
839 };
840 let stream = Box::into_raw(Box::new(stream));
841 let ret = SSLSetConnection(self.0, stream as *mut _);
842 if ret != errSecSuccess {
843 let _conn = Box::from_raw(stream);
844 return Err(Error::from_code(ret));
845 }
846
847 Ok(SslStream {
848 ctx: self,
849 _m: PhantomData,
850 })
851 }
852 }
853
854 /// Performs the SSL/TLS handshake.
handshake<S>(self, stream: S) -> result::Result<SslStream<S>, HandshakeError<S>> where S: Read + Write,855 pub fn handshake<S>(self, stream: S) -> result::Result<SslStream<S>, HandshakeError<S>>
856 where
857 S: Read + Write,
858 {
859 self.into_stream(stream)
860 .map_err(HandshakeError::Failure)
861 .and_then(SslStream::handshake)
862 }
863 }
864
/// Per-stream state registered with Secure Transport as the connection
/// reference and reached from the read/write callbacks.
struct Connection<S> {
    /// The wrapped transport stream.
    stream: S,
    /// The most recent I/O error returned by `stream`, if one was recorded.
    err: Option<io::Error>,
    /// The payload of a panic caught while calling into `stream`, if any.
    panic: Option<Box<dyn Any + Send>>,
}
870
871 // the logic here is based off of libcurl's
872
translate_err(e: &io::Error) -> OSStatus873 fn translate_err(e: &io::Error) -> OSStatus {
874 match e.kind() {
875 io::ErrorKind::NotFound => errSSLClosedGraceful,
876 io::ErrorKind::ConnectionReset => errSSLClosedAbort,
877 io::ErrorKind::WouldBlock |
878 io::ErrorKind::NotConnected => errSSLWouldBlock,
879 _ => errSecIO,
880 }
881 }
882
read_func<S>( connection: SSLConnectionRef, data: *mut c_void, data_length: *mut usize, ) -> OSStatus where S: Read,883 unsafe extern "C" fn read_func<S>(
884 connection: SSLConnectionRef,
885 data: *mut c_void,
886 data_length: *mut usize,
887 ) -> OSStatus
888 where
889 S: Read,
890 {
891 let conn: &mut Connection<S> = &mut *(connection as *mut _);
892 let data = slice::from_raw_parts_mut(data as *mut u8, *data_length);
893 let mut start = 0;
894 let mut ret = errSecSuccess;
895
896 while start < data.len() {
897 match panic::catch_unwind(AssertUnwindSafe(|| conn.stream.read(&mut data[start..]))) {
898 Ok(Ok(0)) => {
899 ret = errSSLClosedNoNotify;
900 break;
901 }
902 Ok(Ok(len)) => start += len,
903 Ok(Err(e)) => {
904 ret = translate_err(&e);
905 conn.err = Some(e);
906 break;
907 }
908 Err(e) => {
909 ret = errSecIO;
910 conn.panic = Some(e);
911 break;
912 }
913 }
914 }
915
916 *data_length = start;
917 ret
918 }
919
write_func<S>( connection: SSLConnectionRef, data: *const c_void, data_length: *mut usize, ) -> OSStatus where S: Write,920 unsafe extern "C" fn write_func<S>(
921 connection: SSLConnectionRef,
922 data: *const c_void,
923 data_length: *mut usize,
924 ) -> OSStatus
925 where
926 S: Write,
927 {
928 let conn: &mut Connection<S> = &mut *(connection as *mut _);
929 let data = slice::from_raw_parts(data as *mut u8, *data_length);
930 let mut start = 0;
931 let mut ret = errSecSuccess;
932
933 while start < data.len() {
934 match panic::catch_unwind(AssertUnwindSafe(|| conn.stream.write(&data[start..]))) {
935 Ok(Ok(0)) => {
936 ret = errSSLClosedNoNotify;
937 break;
938 }
939 Ok(Ok(len)) => start += len,
940 Ok(Err(e)) => {
941 ret = translate_err(&e);
942 conn.err = Some(e);
943 break;
944 }
945 Err(e) => {
946 ret = errSecIO;
947 conn.panic = Some(e);
948 break;
949 }
950 }
951 }
952
953 *data_length = start;
954 ret
955 }
956
/// A type implementing SSL/TLS encryption over an underlying stream.
///
/// The wrapped stream is not stored in this struct directly. It lives inside
/// the `Connection<S>` registered as the context's connection; `get_ref` and
/// `get_mut` reach it through that pointer. `S` therefore appears here only
/// through `PhantomData`.
pub struct SslStream<S> {
    /// The Secure Transport context driving the session.
    ctx: SslContext,
    /// Marks this type as logically owning an `S` (held via the context's connection).
    _m: PhantomData<S>,
}
962
963 impl<S: fmt::Debug> fmt::Debug for SslStream<S> {
fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result964 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
965 fmt.debug_struct("SslStream")
966 .field("context", &self.ctx)
967 .field("stream", self.get_ref())
968 .finish()
969 }
970 }
971
972 impl<S> Drop for SslStream<S> {
drop(&mut self)973 fn drop(&mut self) {
974 unsafe {
975 let mut conn = ptr::null();
976 let ret = SSLGetConnection(self.ctx.0, &mut conn);
977 assert!(ret == errSecSuccess);
978 Box::<Connection<S>>::from_raw(conn as *mut _);
979 }
980 }
981 }
982
983 impl<S> SslStream<S> {
handshake(mut self) -> result::Result<Self, HandshakeError<S>>984 fn handshake(mut self) -> result::Result<Self, HandshakeError<S>> {
985 match unsafe { SSLHandshake(self.ctx.0) } {
986 errSecSuccess => Ok(self),
987 reason @ errSSLPeerAuthCompleted
988 | reason @ errSSLClientCertRequested
989 | reason @ errSSLWouldBlock
990 | reason @ errSSLClientHelloReceived => {
991 Err(HandshakeError::Interrupted(MidHandshakeSslStream {
992 stream: self,
993 error: Error::from_code(reason),
994 }))
995 }
996 err => {
997 self.check_panic();
998 Err(HandshakeError::Failure(Error::from_code(err)))
999 }
1000 }
1001 }
1002
1003 /// Returns a shared reference to the inner stream.
get_ref(&self) -> &S1004 pub fn get_ref(&self) -> &S {
1005 &self.connection().stream
1006 }
1007
1008 /// Returns a mutable reference to the underlying stream.
get_mut(&mut self) -> &mut S1009 pub fn get_mut(&mut self) -> &mut S {
1010 &mut self.connection_mut().stream
1011 }
1012
1013 /// Returns a shared reference to the `SslContext` of the stream.
context(&self) -> &SslContext1014 pub fn context(&self) -> &SslContext {
1015 &self.ctx
1016 }
1017
1018 /// Returns a mutable reference to the `SslContext` of the stream.
context_mut(&mut self) -> &mut SslContext1019 pub fn context_mut(&mut self) -> &mut SslContext {
1020 &mut self.ctx
1021 }
1022
1023 /// Shuts down the connection.
close(&mut self) -> result::Result<(), io::Error>1024 pub fn close(&mut self) -> result::Result<(), io::Error> {
1025 unsafe {
1026 let ret = SSLClose(self.ctx.0);
1027 if ret == errSecSuccess {
1028 Ok(())
1029 } else {
1030 Err(self.get_error(ret))
1031 }
1032 }
1033 }
1034
connection(&self) -> &Connection<S>1035 fn connection(&self) -> &Connection<S> {
1036 unsafe {
1037 let mut conn = ptr::null();
1038 let ret = SSLGetConnection(self.ctx.0, &mut conn);
1039 assert!(ret == errSecSuccess);
1040
1041 mem::transmute(conn)
1042 }
1043 }
1044
connection_mut(&mut self) -> &mut Connection<S>1045 fn connection_mut(&mut self) -> &mut Connection<S> {
1046 unsafe {
1047 let mut conn = ptr::null();
1048 let ret = SSLGetConnection(self.ctx.0, &mut conn);
1049 assert!(ret == errSecSuccess);
1050
1051 mem::transmute(conn)
1052 }
1053 }
1054
check_panic(&mut self)1055 fn check_panic(&mut self) {
1056 let conn = self.connection_mut();
1057 if let Some(err) = conn.panic.take() {
1058 panic::resume_unwind(err);
1059 }
1060 }
1061
get_error(&mut self, ret: OSStatus) -> io::Error1062 fn get_error(&mut self, ret: OSStatus) -> io::Error {
1063 self.check_panic();
1064
1065 if let Some(err) = self.connection_mut().err.take() {
1066 err
1067 } else {
1068 io::Error::new(io::ErrorKind::Other, Error::from_code(ret))
1069 }
1070 }
1071 }
1072
1073 impl<S: Read + Write> Read for SslStream<S> {
read(&mut self, buf: &mut [u8]) -> io::Result<usize>1074 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
1075 // Below we base our return value off the amount of data read, so a
1076 // zero-length buffer might cause us to erroneously interpret this
1077 // request as an error. Instead short-circuit that logic and return
1078 // `Ok(0)` instead.
1079 if buf.is_empty() {
1080 return Ok(0);
1081 }
1082
1083 // If some data was buffered but not enough to fill `buf`, SSLRead
1084 // will try to read a new packet. This is bad because there may be
1085 // no more data but the socket is remaining open (e.g HTTPS with
1086 // Connection: keep-alive).
1087 let buffered = self.context().buffered_read_size().unwrap_or(0);
1088 let to_read = if buffered > 0 {
1089 cmp::min(buffered, buf.len())
1090 } else {
1091 buf.len()
1092 };
1093
1094 unsafe {
1095 let mut nread = 0;
1096 let ret = SSLRead(self.ctx.0, buf.as_mut_ptr() as *mut _, to_read, &mut nread);
1097 // SSLRead can return an error at the same time it returns the last
1098 // chunk of data (!)
1099 if nread > 0 {
1100 return Ok(nread as usize);
1101 }
1102
1103 match ret {
1104 errSSLClosedGraceful | errSSLClosedAbort | errSSLClosedNoNotify => Ok(0),
1105 _ => Err(self.get_error(ret)),
1106 }
1107 }
1108 }
1109 }
1110
1111 impl<S: Read + Write> Write for SslStream<S> {
write(&mut self, buf: &[u8]) -> io::Result<usize>1112 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
1113 // Like above in read, short circuit a 0-length write
1114 if buf.is_empty() {
1115 return Ok(0);
1116 }
1117 unsafe {
1118 let mut nwritten = 0;
1119 let ret = SSLWrite(
1120 self.ctx.0,
1121 buf.as_ptr() as *const _,
1122 buf.len(),
1123 &mut nwritten,
1124 );
1125 // just to be safe, base success off of nwritten rather than ret
1126 // for the same reason as in read
1127 if nwritten > 0 {
1128 Ok(nwritten as usize)
1129 } else {
1130 Err(self.get_error(ret))
1131 }
1132 }
1133 }
1134
flush(&mut self) -> io::Result<()>1135 fn flush(&mut self) -> io::Result<()> {
1136 self.connection_mut().stream.flush()
1137 }
1138 }
1139
1140 /// A builder type to simplify the creation of client side `SslStream`s.
1141 #[derive(Debug)]
1142 pub struct ClientBuilder {
1143 identity: Option<SecIdentity>,
1144 certs: Vec<SecCertificate>,
1145 chain: Vec<SecCertificate>,
1146 protocol_min: Option<SslProtocol>,
1147 protocol_max: Option<SslProtocol>,
1148 trust_certs_only: bool,
1149 use_sni: bool,
1150 danger_accept_invalid_certs: bool,
1151 danger_accept_invalid_hostnames: bool,
1152 whitelisted_ciphers: Vec<CipherSuite>,
1153 blacklisted_ciphers: Vec<CipherSuite>,
1154 alpn: Option<Vec<String>>,
1155 }
1156
1157 impl Default for ClientBuilder {
default() -> Self1158 fn default() -> Self {
1159 Self::new()
1160 }
1161 }
1162
1163 impl ClientBuilder {
1164 /// Creates a new builder with default options.
new() -> Self1165 pub fn new() -> Self {
1166 Self {
1167 identity: None,
1168 certs: Vec::new(),
1169 chain: Vec::new(),
1170 protocol_min: None,
1171 protocol_max: None,
1172 trust_certs_only: false,
1173 use_sni: true,
1174 danger_accept_invalid_certs: false,
1175 danger_accept_invalid_hostnames: false,
1176 whitelisted_ciphers: Vec::new(),
1177 blacklisted_ciphers: Vec::new(),
1178 alpn: None,
1179 }
1180 }
1181
1182 /// Specifies the set of root certificates to trust when
1183 /// verifying the server's certificate.
anchor_certificates(&mut self, certs: &[SecCertificate]) -> &mut Self1184 pub fn anchor_certificates(&mut self, certs: &[SecCertificate]) -> &mut Self {
1185 self.certs = certs.to_owned();
1186 self
1187 }
1188
1189 /// Specifies whether to trust the built-in certificates in addition
1190 /// to specified anchor certificates.
trust_anchor_certificates_only(&mut self, only: bool) -> &mut Self1191 pub fn trust_anchor_certificates_only(&mut self, only: bool) -> &mut Self {
1192 self.trust_certs_only = only;
1193 self
1194 }
1195
1196 /// Specifies whether to trust invalid certificates.
1197 ///
1198 /// # Warning
1199 ///
1200 /// You should think very carefully before using this method. If invalid
1201 /// certificates are trusted, *any* certificate for *any* site will be
1202 /// trusted for use. This includes expired certificates. This introduces
1203 /// significant vulnerabilities, and should only be used as a last resort.
danger_accept_invalid_certs(&mut self, noverify: bool) -> &mut Self1204 pub fn danger_accept_invalid_certs(&mut self, noverify: bool) -> &mut Self {
1205 self.danger_accept_invalid_certs = noverify;
1206 self
1207 }
1208
1209 /// Specifies whether to use Server Name Indication (SNI).
use_sni(&mut self, use_sni: bool) -> &mut Self1210 pub fn use_sni(&mut self, use_sni: bool) -> &mut Self {
1211 self.use_sni = use_sni;
1212 self
1213 }
1214
1215 /// Specifies whether to verify that the server's hostname matches its certificate.
1216 ///
1217 /// # Warning
1218 ///
1219 /// You should think very carefully before using this method. If hostnames are not verified,
1220 /// *any* valid certificate for *any* site will be trusted for use. This introduces significant
1221 /// vulnerabilities, and should only be used as a last resort.
danger_accept_invalid_hostnames( &mut self, danger_accept_invalid_hostnames: bool, ) -> &mut Self1222 pub fn danger_accept_invalid_hostnames(
1223 &mut self,
1224 danger_accept_invalid_hostnames: bool,
1225 ) -> &mut Self {
1226 self.danger_accept_invalid_hostnames = danger_accept_invalid_hostnames;
1227 self
1228 }
1229
1230 /// Set a whitelist of enabled ciphers. Any ciphers not whitelisted will be disabled.
whitelist_ciphers(&mut self, whitelisted_ciphers: &[CipherSuite]) -> &mut Self1231 pub fn whitelist_ciphers(&mut self, whitelisted_ciphers: &[CipherSuite]) -> &mut Self {
1232 self.whitelisted_ciphers = whitelisted_ciphers.to_owned();
1233 self
1234 }
1235
1236 /// Set a blacklist of disabled ciphers. Blacklisted ciphers will be disabled.
blacklist_ciphers(&mut self, blacklisted_ciphers: &[CipherSuite]) -> &mut Self1237 pub fn blacklist_ciphers(&mut self, blacklisted_ciphers: &[CipherSuite]) -> &mut Self {
1238 self.blacklisted_ciphers = blacklisted_ciphers.to_owned();
1239 self
1240 }
1241
1242 /// Use the specified identity as a SSL/TLS client certificate.
identity(&mut self, identity: &SecIdentity, chain: &[SecCertificate]) -> &mut Self1243 pub fn identity(&mut self, identity: &SecIdentity, chain: &[SecCertificate]) -> &mut Self {
1244 self.identity = Some(identity.clone());
1245 self.chain = chain.to_owned();
1246 self
1247 }
1248
1249 /// Configure the minimum protocol that this client will support.
protocol_min(&mut self, min: SslProtocol) -> &mut Self1250 pub fn protocol_min(&mut self, min: SslProtocol) -> &mut Self {
1251 self.protocol_min = Some(min);
1252 self
1253 }
1254
1255 /// Configure the minimum protocol that this client will support.
protocol_max(&mut self, max: SslProtocol) -> &mut Self1256 pub fn protocol_max(&mut self, max: SslProtocol) -> &mut Self {
1257 self.protocol_max = Some(max);
1258 self
1259 }
1260
1261 /// Configures the set of protocols used for ALPN.
1262 #[cfg(feature = "alpn")]
alpn_protocols(&mut self, protocols: &[&str]) -> &mut Self1263 pub fn alpn_protocols(&mut self, protocols: &[&str]) -> &mut Self {
1264 self.alpn = Some(protocols.iter().map(|s| s.to_string()).collect());
1265 self
1266 }
1267
1268 /// Initiates a new SSL/TLS session over a stream connected to the specified domain.
1269 ///
1270 /// If both SNI and hostname verification are disabled, the value of `domain` will be ignored.
handshake<S>( &self, domain: &str, stream: S, ) -> result::Result<SslStream<S>, ClientHandshakeError<S>> where S: Read + Write,1271 pub fn handshake<S>(
1272 &self,
1273 domain: &str,
1274 stream: S,
1275 ) -> result::Result<SslStream<S>, ClientHandshakeError<S>>
1276 where
1277 S: Read + Write,
1278 {
1279 // the logic for trust validation is in MidHandshakeClientBuilder::connect, so run all
1280 // of the handshake logic through that.
1281 let stream = MidHandshakeSslStream {
1282 stream: self.ctx_into_stream(domain, stream)?,
1283 error: Error::from(errSecSuccess),
1284 };
1285
1286 let certs = self.certs.clone();
1287 let stream = MidHandshakeClientBuilder {
1288 stream,
1289 domain: if self.danger_accept_invalid_hostnames {
1290 None
1291 } else {
1292 Some(domain.to_string())
1293 },
1294 certs,
1295 trust_certs_only: self.trust_certs_only,
1296 danger_accept_invalid_certs: self.danger_accept_invalid_certs,
1297 };
1298 stream.handshake()
1299 }
1300
ctx_into_stream<S>(&self, domain: &str, stream: S) -> Result<SslStream<S>> where S: Read + Write,1301 fn ctx_into_stream<S>(&self, domain: &str, stream: S) -> Result<SslStream<S>>
1302 where
1303 S: Read + Write,
1304 {
1305 let mut ctx = SslContext::new(SslProtocolSide::CLIENT, SslConnectionType::STREAM)?;
1306
1307 if self.use_sni {
1308 ctx.set_peer_domain_name(domain)?;
1309 }
1310 if let Some(ref identity) = self.identity {
1311 ctx.set_certificate(identity, &self.chain)?;
1312 }
1313 #[cfg(feature = "alpn")]
1314 {
1315 if let Some(ref alpn) = self.alpn {
1316 ctx.set_alpn_protocols(&alpn.iter().map(|s| &**s).collect::<Vec<_>>())?;
1317 }
1318 }
1319 ctx.set_break_on_server_auth(true)?;
1320 self.configure_protocols(&mut ctx)?;
1321 self.configure_ciphers(&mut ctx)?;
1322
1323 ctx.into_stream(stream)
1324 }
1325
configure_protocols(&self, ctx: &mut SslContext) -> Result<()>1326 fn configure_protocols(&self, ctx: &mut SslContext) -> Result<()> {
1327 if let Some(min) = self.protocol_min {
1328 ctx.set_protocol_version_min(min)?;
1329 }
1330 if let Some(max) = self.protocol_max {
1331 ctx.set_protocol_version_max(max)?;
1332 }
1333 Ok(())
1334 }
1335
configure_ciphers(&self, ctx: &mut SslContext) -> Result<()>1336 fn configure_ciphers(&self, ctx: &mut SslContext) -> Result<()> {
1337 let mut ciphers = if self.whitelisted_ciphers.is_empty() {
1338 ctx.enabled_ciphers()?
1339 } else {
1340 self.whitelisted_ciphers.clone()
1341 };
1342
1343 if !self.blacklisted_ciphers.is_empty() {
1344 ciphers.retain(|cipher| !self.blacklisted_ciphers.contains(cipher));
1345 }
1346
1347 ctx.set_enabled_ciphers(&ciphers)?;
1348 Ok(())
1349 }
1350 }
1351
/// A builder type to simplify the creation of server-side `SslStream`s.
#[derive(Debug)]
pub struct ServerBuilder {
    /// Identity passed to `SslContext::set_certificate` for every handshake.
    identity: SecIdentity,
    /// Certificate chain passed alongside `identity` to `SslContext::set_certificate`.
    certs: Vec<SecCertificate>,
}
1358
1359 impl ServerBuilder {
1360 /// Creates a new `ServerBuilder` which will use the specified identity
1361 /// and certificate chain for handshakes.
new(identity: &SecIdentity, certs: &[SecCertificate]) -> Self1362 pub fn new(identity: &SecIdentity, certs: &[SecCertificate]) -> Self {
1363 Self {
1364 identity: identity.clone(),
1365 certs: certs.to_owned(),
1366 }
1367 }
1368
1369 /// Initiates a new SSL/TLS session over a stream.
handshake<S>(&self, stream: S) -> Result<SslStream<S>> where S: Read + Write,1370 pub fn handshake<S>(&self, stream: S) -> Result<SslStream<S>>
1371 where
1372 S: Read + Write,
1373 {
1374 let mut ctx = SslContext::new(SslProtocolSide::SERVER, SslConnectionType::STREAM)?;
1375 ctx.set_certificate(&self.identity, &self.certs)?;
1376 match ctx.handshake(stream) {
1377 Ok(stream) => Ok(stream),
1378 Err(HandshakeError::Interrupted(stream)) => Err(Error::from_code(stream.reason())),
1379 Err(HandshakeError::Failure(err)) => Err(err),
1380 }
1381 }
1382 }
1383
#[cfg(test)]
mod test {
    use std::io;
    use std::io::prelude::*;
    use std::net::TcpStream;

    use super::*;

    // NOTE: most tests below open real TCP/TLS connections (google.com,
    // expired.badssl.com) and therefore need network access.

    #[test]
    fn connect() {
        let mut ctx = p!(SslContext::new(
            SslProtocolSide::CLIENT,
            SslConnectionType::STREAM
        ));
        p!(ctx.set_peer_domain_name("google.com"));
        let stream = p!(TcpStream::connect("google.com:443"));
        p!(ctx.handshake(stream));
    }

    #[test]
    fn connect_bad_domain() {
        let mut ctx = p!(SslContext::new(
            SslProtocolSide::CLIENT,
            SslConnectionType::STREAM
        ));
        // The peer name deliberately differs from the host we connect to.
        p!(ctx.set_peer_domain_name("foobar.com"));
        let stream = p!(TcpStream::connect("google.com:443"));
        assert!(ctx.handshake(stream).is_err(), "expected failure");
    }

    #[test]
    fn load_page() {
        let mut ctx = p!(SslContext::new(
            SslProtocolSide::CLIENT,
            SslConnectionType::STREAM
        ));
        p!(ctx.set_peer_domain_name("google.com"));
        let stream = p!(TcpStream::connect("google.com:443"));
        let mut stream = p!(ctx.handshake(stream));
        p!(stream.write_all(b"GET / HTTP/1.0\r\n\r\n"));
        p!(stream.flush());
        let mut buf = vec![];
        p!(stream.read_to_end(&mut buf));
        println!("{}", String::from_utf8_lossy(&buf));
    }

    #[test]
    #[cfg(feature = "alpn")]
    fn client_alpn_accept() {
        let mut ctx = p!(SslContext::new(
            SslProtocolSide::CLIENT,
            SslConnectionType::STREAM
        ));
        p!(ctx.set_peer_domain_name("google.com"));
        p!(ctx.set_alpn_protocols(&["h2"]));
        let stream = p!(TcpStream::connect("google.com:443"));
        let stream = ctx.handshake(stream).unwrap();
        assert_eq!(vec!["h2"], stream.context().alpn_protocols().unwrap());
    }

    #[test]
    #[cfg(feature = "alpn")]
    fn client_alpn_reject() {
        let mut ctx = p!(SslContext::new(
            SslProtocolSide::CLIENT,
            SslConnectionType::STREAM
        ));
        p!(ctx.set_peer_domain_name("google.com"));
        p!(ctx.set_alpn_protocols(&["h2c"]));
        let stream = p!(TcpStream::connect("google.com:443"));
        let stream = ctx.handshake(stream).unwrap();
        assert!(stream.context().alpn_protocols().is_err());
    }

    #[test]
    fn client_no_anchor_certs() {
        let stream = p!(TcpStream::connect("google.com:443"));
        assert!(ClientBuilder::new()
            .trust_anchor_certificates_only(true)
            .handshake("google.com", stream)
            .is_err());
    }

    #[test]
    fn client_bad_domain() {
        let stream = p!(TcpStream::connect("google.com:443"));
        assert!(ClientBuilder::new()
            .handshake("foobar.com", stream)
            .is_err());
    }

    #[test]
    fn client_bad_domain_ignored() {
        let stream = p!(TcpStream::connect("google.com:443"));
        ClientBuilder::new()
            .danger_accept_invalid_hostnames(true)
            .handshake("foobar.com", stream)
            .unwrap();
    }

    #[test]
    fn connect_no_verify_ssl() {
        let stream = p!(TcpStream::connect("expired.badssl.com:443"));
        let mut builder = ClientBuilder::new();
        builder.danger_accept_invalid_certs(true);
        builder.handshake("expired.badssl.com", stream).unwrap();
    }

    #[test]
    fn load_page_client() {
        let stream = p!(TcpStream::connect("google.com:443"));
        let mut stream = p!(ClientBuilder::new().handshake("google.com", stream));
        p!(stream.write_all(b"GET / HTTP/1.0\r\n\r\n"));
        p!(stream.flush());
        let mut buf = vec![];
        p!(stream.read_to_end(&mut buf));
        println!("{}", String::from_utf8_lossy(&buf));
    }

    #[test]
    #[cfg_attr(target_os = "ios", ignore)] // FIXME what's going on with ios?
    fn cipher_configuration() {
        let mut ctx = p!(SslContext::new(
            SslProtocolSide::SERVER,
            SslConnectionType::STREAM
        ));
        // Keep every other enabled cipher and check the context round-trips it.
        let ciphers = p!(ctx.enabled_ciphers())
            .iter()
            .step_by(2)
            .copied()
            .collect::<Vec<_>>();
        p!(ctx.set_enabled_ciphers(&ciphers));
        assert_eq!(ciphers, p!(ctx.enabled_ciphers()));
    }

    #[test]
    fn test_builder_whitelist_ciphers() {
        let stream = p!(TcpStream::connect("google.com:443"));

        let ctx = p!(SslContext::new(
            SslProtocolSide::CLIENT,
            SslConnectionType::STREAM
        ));
        let ciphers = p!(ctx.enabled_ciphers());
        assert!(ciphers.len() > 1);

        let cipher = ciphers[0];
        let stream = p!(ClientBuilder::new()
            .whitelist_ciphers(&[cipher])
            .ctx_into_stream("google.com", stream));

        assert_eq!(1, p!(stream.context().enabled_ciphers()).len());
    }

    #[test]
    #[cfg_attr(target_os = "ios", ignore)] // FIXME same issue as cipher_configuration
    fn test_builder_blacklist_ciphers() {
        let stream = p!(TcpStream::connect("google.com:443"));

        let ctx = p!(SslContext::new(
            SslProtocolSide::CLIENT,
            SslConnectionType::STREAM
        ));
        let ciphers = p!(ctx.enabled_ciphers());
        let num = ciphers.len();
        assert!(num > 1);

        let cipher = ciphers[0];
        let stream = p!(ClientBuilder::new()
            .blacklist_ciphers(&[cipher])
            .ctx_into_stream("google.com", stream));

        assert_eq!(num - 1, p!(stream.context().enabled_ciphers()).len());
    }

    #[test]
    fn idle_context_peer_trust() {
        let ctx = p!(SslContext::new(
            SslProtocolSide::SERVER,
            SslConnectionType::STREAM
        ));
        assert!(ctx.peer_trust2().is_err());
    }

    #[test]
    fn peer_id() {
        let mut ctx = p!(SslContext::new(
            SslProtocolSide::SERVER,
            SslConnectionType::STREAM
        ));
        assert!(p!(ctx.peer_id()).is_none());
        p!(ctx.set_peer_id(b"foobar"));
        assert_eq!(p!(ctx.peer_id()), Some(&b"foobar"[..]));
    }

    #[test]
    fn peer_domain_name() {
        let mut ctx = p!(SslContext::new(
            SslProtocolSide::CLIENT,
            SslConnectionType::STREAM
        ));
        assert_eq!("", p!(ctx.peer_domain_name()));
        p!(ctx.set_peer_domain_name("foobar.com"));
        assert_eq!("foobar.com", p!(ctx.peer_domain_name()));
    }

    #[test]
    #[should_panic(expected = "blammo")]
    fn write_panic() {
        // A panic inside the stream's `write` must surface on the Rust side
        // after crossing the Secure Transport callback.
        struct ExplodingStream(TcpStream);

        impl Read for ExplodingStream {
            fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
                self.0.read(buf)
            }
        }

        impl Write for ExplodingStream {
            fn write(&mut self, _: &[u8]) -> io::Result<usize> {
                panic!("blammo");
            }

            fn flush(&mut self) -> io::Result<()> {
                self.0.flush()
            }
        }

        let mut ctx = p!(SslContext::new(
            SslProtocolSide::CLIENT,
            SslConnectionType::STREAM
        ));
        p!(ctx.set_peer_domain_name("google.com"));
        let stream = p!(TcpStream::connect("google.com:443"));
        let _ = ctx.handshake(ExplodingStream(stream));
    }

    #[test]
    #[should_panic(expected = "blammo")]
    fn read_panic() {
        // A panic inside the stream's `read` must surface on the Rust side
        // after crossing the Secure Transport callback.
        struct ExplodingStream(TcpStream);

        impl Read for ExplodingStream {
            fn read(&mut self, _: &mut [u8]) -> io::Result<usize> {
                panic!("blammo");
            }
        }

        impl Write for ExplodingStream {
            fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
                self.0.write(buf)
            }

            fn flush(&mut self) -> io::Result<()> {
                self.0.flush()
            }
        }

        let mut ctx = p!(SslContext::new(
            SslProtocolSide::CLIENT,
            SslConnectionType::STREAM
        ));
        p!(ctx.set_peer_domain_name("google.com"));
        let stream = p!(TcpStream::connect("google.com:443"));
        let _ = ctx.handshake(ExplodingStream(stream));
    }

    #[test]
    fn zero_length_buffers() {
        let mut ctx = p!(SslContext::new(
            SslProtocolSide::CLIENT,
            SslConnectionType::STREAM
        ));
        p!(ctx.set_peer_domain_name("google.com"));
        let stream = p!(TcpStream::connect("google.com:443"));
        let mut stream = ctx.handshake(stream).unwrap();
        assert_eq!(stream.write(b"").unwrap(), 0);
        assert_eq!(stream.read(&mut []).unwrap(), 0);
    }
}
1667