1 //! SSL/TLS encryption support using Secure Transport. 2 //! 3 //! # Examples 4 //! 5 //! To connect as a client to a server with a certificate trusted by the system: 6 //! 7 //! ```rust 8 //! use std::io::prelude::*; 9 //! use std::net::TcpStream; 10 //! use security_framework::secure_transport::ClientBuilder; 11 //! 12 //! let stream = TcpStream::connect("google.com:443").unwrap(); 13 //! let mut stream = ClientBuilder::new().handshake("google.com", stream).unwrap(); 14 //! 15 //! stream.write_all(b"GET / HTTP/1.0\r\n\r\n").unwrap(); 16 //! let mut page = vec![]; 17 //! stream.read_to_end(&mut page).unwrap(); 18 //! println!("{}", String::from_utf8_lossy(&page)); 19 //! ``` 20 //! 21 //! To connect to a server with a certificate that's *not* trusted by the 22 //! system, specify the root certificates for the server's chain to the 23 //! `ClientBuilder`: 24 //! 25 //! ```rust,no_run 26 //! use std::io::prelude::*; 27 //! use std::net::TcpStream; 28 //! use security_framework::secure_transport::ClientBuilder; 29 //! 30 //! # let root_cert = unsafe { std::mem::zeroed() }; 31 //! let stream = TcpStream::connect("my_server.com:443").unwrap(); 32 //! let mut stream = ClientBuilder::new() 33 //! .anchor_certificates(&[root_cert]) 34 //! .handshake("my_server.com", stream) 35 //! .unwrap(); 36 //! 37 //! stream.write_all(b"GET / HTTP/1.0\r\n\r\n").unwrap(); 38 //! let mut page = vec![]; 39 //! stream.read_to_end(&mut page).unwrap(); 40 //! println!("{}", String::from_utf8_lossy(&page)); 41 //! ``` 42 //! 43 //! For more advanced configuration, the `SslContext` type can be used directly. 44 //! 45 //! To run a server: 46 //! 47 //! ```rust,no_run 48 //! use std::net::TcpListener; 49 //! use std::thread; 50 //! use security_framework::secure_transport::{SslContext, SslProtocolSide, SslConnectionType}; 51 //! 52 //! // Create a TCP listener and start accepting on it. 53 //! let mut listener = TcpListener::bind("0.0.0.0:443").unwrap(); 54 //! 55 //! for stream in listener.incoming() { 56 //! let stream = stream.unwrap(); 57 //! thread::spawn(move || { 58 //! // Create a new context configured to operate on the server side of 59 //! // a traditional SSL/TLS session. 60 //! let mut ctx = SslContext::new(SslProtocolSide::SERVER, SslConnectionType::STREAM) 61 //! .unwrap(); 62 //! 63 //! // Install the certificate chain that we will be using. 64 //! # let identity = unsafe { std::mem::zeroed() }; 65 //! # let intermediate_cert = unsafe { std::mem::zeroed() }; 66 //! # let root_cert = unsafe { std::mem::zeroed() }; 67 //! ctx.set_certificate(identity, &[intermediate_cert, root_cert]).unwrap(); 68 //! 69 //! // Perform the SSL/TLS handshake and get our stream. 70 //! let mut stream = ctx.handshake(stream).unwrap(); 71 //! }); 72 //! } 73 //! 74 //! ``` 75 #[allow(unused_imports)] 76 use core_foundation::array::{CFArray, CFArrayRef}; 77 78 use core_foundation::base::{Boolean, TCFType}; 79 #[cfg(feature = "alpn")] 80 use core_foundation::string::CFString; 81 use core_foundation_sys::base::{kCFAllocatorDefault, OSStatus}; 82 use std::os::raw::c_void; 83 84 #[allow(unused_imports)] 85 use security_framework_sys::base::{ 86 errSecBadReq, errSecIO, errSecNotTrusted, errSecSuccess, errSecTrustSettingDeny, 87 errSecUnimplemented, 88 }; 89 90 use security_framework_sys::secure_transport::*; 91 use std::any::Any; 92 use std::cmp; 93 use std::fmt; 94 use std::io; 95 use std::io::prelude::*; 96 use std::marker::PhantomData; 97 use std::mem; 98 use std::panic::{self, AssertUnwindSafe}; 99 use std::ptr; 100 use std::result; 101 use std::slice; 102 103 use crate::base::{Error, Result}; 104 use crate::certificate::SecCertificate; 105 use crate::cipher_suite::CipherSuite; 106 use crate::identity::SecIdentity; 107 use crate::policy::SecPolicy; 108 use crate::trust::{SecTrust, TrustResult}; 109 use crate::{cvt, AsInner}; 110 111 /// Specifies a side of a TLS session. 112 #[derive(Debug, Copy, Clone, PartialEq, Eq)] 113 pub struct SslProtocolSide(SSLProtocolSide); 114 115 impl SslProtocolSide { 116 /// The server side of the session. 117 pub const SERVER: SslProtocolSide = SslProtocolSide(kSSLServerSide); 118 119 /// The client side of the session. 120 pub const CLIENT: SslProtocolSide = SslProtocolSide(kSSLClientSide); 121 } 122 123 /// Specifies the type of TLS session. 124 #[derive(Debug, Copy, Clone)] 125 pub struct SslConnectionType(SSLConnectionType); 126 127 impl SslConnectionType { 128 /// A traditional TLS stream. 129 pub const STREAM: SslConnectionType = SslConnectionType(kSSLStreamType); 130 131 /// A DTLS session. 132 pub const DATAGRAM: SslConnectionType = SslConnectionType(kSSLDatagramType); 133 } 134 135 /// An error or intermediate state after a TLS handshake attempt. 136 #[derive(Debug)] 137 pub enum HandshakeError<S> { 138 /// The handshake failed. 139 Failure(Error), 140 /// The handshake was interrupted midway through. 141 Interrupted(MidHandshakeSslStream<S>), 142 } 143 144 impl<S> From<Error> for HandshakeError<S> { 145 fn from(err: Error) -> HandshakeError<S> { 146 HandshakeError::Failure(err) 147 } 148 } 149 150 /// An error or intermediate state after a TLS handshake attempt. 151 #[derive(Debug)] 152 pub enum ClientHandshakeError<S> { 153 /// The handshake failed. 154 Failure(Error), 155 /// The handshake was interrupted midway through. 156 Interrupted(MidHandshakeClientBuilder<S>), 157 } 158 159 impl<S> From<Error> for ClientHandshakeError<S> { 160 fn from(err: Error) -> ClientHandshakeError<S> { 161 ClientHandshakeError::Failure(err) 162 } 163 } 164 165 /// An SSL stream midway through the handshake process. 166 #[derive(Debug)] 167 pub struct MidHandshakeSslStream<S> { 168 stream: SslStream<S>, 169 error: Error, 170 } 171 172 impl<S> MidHandshakeSslStream<S> { 173 /// Returns a shared reference to the inner stream. 174 pub fn get_ref(&self) -> &S { 175 self.stream.get_ref() 176 } 177 178 /// Returns a mutable reference to the inner stream. 179 pub fn get_mut(&mut self) -> &mut S { 180 self.stream.get_mut() 181 } 182 183 /// Returns a shared reference to the `SslContext` of the stream. 184 pub fn context(&self) -> &SslContext { 185 self.stream.context() 186 } 187 188 /// Returns a mutable reference to the `SslContext` of the stream. 189 pub fn context_mut(&mut self) -> &mut SslContext { 190 self.stream.context_mut() 191 } 192 193 /// Returns `true` iff `break_on_server_auth` was set and the handshake has 194 /// progressed to that point. 195 pub fn server_auth_completed(&self) -> bool { 196 self.error.code() == errSSLPeerAuthCompleted 197 } 198 199 /// Returns `true` iff `break_on_cert_requested` was set and the handshake 200 /// has progressed to that point. 201 pub fn client_cert_requested(&self) -> bool { 202 self.error.code() == errSSLClientCertRequested 203 } 204 205 /// Returns `true` iff the underlying stream returned an error with the 206 /// `WouldBlock` kind. 207 pub fn would_block(&self) -> bool { 208 self.error.code() == errSSLWouldBlock 209 } 210 211 /// Deprecated 212 pub fn reason(&self) -> OSStatus { 213 self.error.code() 214 } 215 216 /// Returns the error which caused the handshake interruption. 217 pub fn error(&self) -> &Error { 218 &self.error 219 } 220 221 /// Restarts the handshake process. 222 pub fn handshake(self) -> result::Result<SslStream<S>, HandshakeError<S>> { 223 self.stream.handshake() 224 } 225 } 226 227 /// An SSL stream midway through the handshake process. 228 #[derive(Debug)] 229 pub struct MidHandshakeClientBuilder<S> { 230 stream: MidHandshakeSslStream<S>, 231 domain: Option<String>, 232 certs: Vec<SecCertificate>, 233 trust_certs_only: bool, 234 danger_accept_invalid_certs: bool, 235 } 236 237 impl<S> MidHandshakeClientBuilder<S> { 238 /// Returns a shared reference to the inner stream. 239 pub fn get_ref(&self) -> &S { 240 self.stream.get_ref() 241 } 242 243 /// Returns a mutable reference to the inner stream. 244 pub fn get_mut(&mut self) -> &mut S { 245 self.stream.get_mut() 246 } 247 248 /// Returns the error which caused the handshake interruption. 249 pub fn error(&self) -> &Error { 250 self.stream.error() 251 } 252 253 /// Restarts the handshake process. 254 pub fn handshake(self) -> result::Result<SslStream<S>, ClientHandshakeError<S>> { 255 let MidHandshakeClientBuilder { 256 stream, 257 domain, 258 certs, 259 trust_certs_only, 260 danger_accept_invalid_certs, 261 } = self; 262 263 let mut result = stream.handshake(); 264 loop { 265 let stream = match result { 266 Ok(stream) => return Ok(stream), 267 Err(HandshakeError::Interrupted(stream)) => stream, 268 Err(HandshakeError::Failure(err)) => { 269 return Err(ClientHandshakeError::Failure(err)) 270 } 271 }; 272 273 if stream.would_block() { 274 let ret = MidHandshakeClientBuilder { 275 stream, 276 domain, 277 certs, 278 trust_certs_only, 279 danger_accept_invalid_certs, 280 }; 281 return Err(ClientHandshakeError::Interrupted(ret)); 282 } 283 284 if stream.server_auth_completed() { 285 if danger_accept_invalid_certs { 286 result = stream.handshake(); 287 continue; 288 } 289 let mut trust = match stream.context().peer_trust2()? { 290 Some(trust) => trust, 291 None => { 292 result = stream.handshake(); 293 continue; 294 } 295 }; 296 trust.set_anchor_certificates(&certs)?; 297 trust.set_trust_anchor_certificates_only(self.trust_certs_only)?; 298 let policy = 299 SecPolicy::create_ssl(SslProtocolSide::SERVER, domain.as_ref().map(|s| &**s)); 300 trust.set_policy(&policy)?; 301 let trusted = trust.evaluate()?; 302 match trusted { 303 TrustResult::PROCEED | TrustResult::UNSPECIFIED => { 304 result = stream.handshake(); 305 continue; 306 } 307 TrustResult::DENY => { 308 let err = Error::from_code(errSecTrustSettingDeny); 309 return Err(ClientHandshakeError::Failure(err)); 310 } 311 TrustResult::RECOVERABLE_TRUST_FAILURE | TrustResult::FATAL_TRUST_FAILURE => { 312 let err = Error::from_code(errSecNotTrusted); 313 return Err(ClientHandshakeError::Failure(err)); 314 } 315 TrustResult::INVALID | TrustResult::OTHER_ERROR | _ => { 316 let err = Error::from_code(errSecBadReq); 317 return Err(ClientHandshakeError::Failure(err)); 318 } 319 } 320 } 321 322 let err = Error::from_code(stream.reason()); 323 return Err(ClientHandshakeError::Failure(err)); 324 } 325 } 326 } 327 328 /// Specifies the state of a TLS session. 329 #[derive(Debug, PartialEq, Eq)] 330 pub struct SessionState(SSLSessionState); 331 332 impl SessionState { 333 /// The session has not yet started. 334 pub const IDLE: SessionState = SessionState(kSSLIdle); 335 336 /// The session is in the handshake process. 337 pub const HANDSHAKE: SessionState = SessionState(kSSLHandshake); 338 339 /// The session is connected. 340 pub const CONNECTED: SessionState = SessionState(kSSLConnected); 341 342 /// The session has been terminated. 343 pub const CLOSED: SessionState = SessionState(kSSLClosed); 344 345 /// The session has been aborted due to an error. 346 pub const ABORTED: SessionState = SessionState(kSSLAborted); 347 } 348 349 /// Specifies a server's requirement for client certificates. 350 #[derive(Debug, Copy, Clone, PartialEq, Eq)] 351 pub struct SslAuthenticate(SSLAuthenticate); 352 353 impl SslAuthenticate { 354 /// Do not request a client certificate. 355 pub const NEVER: SslAuthenticate = SslAuthenticate(kNeverAuthenticate); 356 357 /// Require a client certificate. 358 pub const ALWAYS: SslAuthenticate = SslAuthenticate(kAlwaysAuthenticate); 359 360 /// Request but do not require a client certificate. 361 pub const TRY: SslAuthenticate = SslAuthenticate(kTryAuthenticate); 362 } 363 364 /// Specifies the state of client certificate processing. 365 #[derive(Debug, Copy, Clone, PartialEq, Eq)] 366 pub struct SslClientCertificateState(SSLClientCertificateState); 367 368 impl SslClientCertificateState { 369 /// A client certificate has not been requested or sent. 370 pub const NONE: SslClientCertificateState = SslClientCertificateState(kSSLClientCertNone); 371 372 /// A client certificate has been requested but not recieved. 373 pub const REQUESTED: SslClientCertificateState = 374 SslClientCertificateState(kSSLClientCertRequested); 375 376 /// A client certificate has been received and successfully validated. 377 pub const SENT: SslClientCertificateState = SslClientCertificateState(kSSLClientCertSent); 378 379 /// A client certificate has been received but has failed to validate. 380 pub const REJECTED: SslClientCertificateState = 381 SslClientCertificateState(kSSLClientCertRejected); 382 } 383 384 /// Specifies protocol versions. 385 #[derive(Debug, Copy, Clone, PartialEq, Eq)] 386 pub struct SslProtocol(SSLProtocol); 387 388 impl SslProtocol { 389 /// No protocol has been or should be negotiated or specified; use the default. 390 pub const UNKNOWN: SslProtocol = SslProtocol(kSSLProtocolUnknown); 391 392 /// The SSL 3.0 protocol is preferred, though SSL 2.0 may be used if the peer does not support 393 /// SSL 3.0. 394 pub const SSL3: SslProtocol = SslProtocol(kSSLProtocol3); 395 396 /// The TLS 1.0 protocol is preferred, though lower versions may be used 397 /// if the peer does not support TLS 1.0. 398 pub const TLS1: SslProtocol = SslProtocol(kTLSProtocol1); 399 400 /// The TLS 1.1 protocol is preferred, though lower versions may be used 401 /// if the peer does not support TLS 1.1. 402 pub const TLS11: SslProtocol = SslProtocol(kTLSProtocol11); 403 404 /// The TLS 1.2 protocol is preferred, though lower versions may be used 405 /// if the peer does not support TLS 1.2. 406 pub const TLS12: SslProtocol = SslProtocol(kTLSProtocol12); 407 408 /// Only the SSL 2.0 protocol is accepted. 409 pub const SSL2: SslProtocol = SslProtocol(kSSLProtocol2); 410 411 /// The DTLSv1 protocol is preferred. 412 pub const DTLS1: SslProtocol = SslProtocol(kDTLSProtocol1); 413 414 /// Only the SSL 3.0 protocol is accepted. 415 pub const SSL3_ONLY: SslProtocol = SslProtocol(kSSLProtocol3Only); 416 417 /// Only the TLS 1.0 protocol is accepted. 418 pub const TLS1_ONLY: SslProtocol = SslProtocol(kTLSProtocol1Only); 419 420 /// All supported TLS/SSL versions are accepted. 421 pub const ALL: SslProtocol = SslProtocol(kSSLProtocolAll); 422 } 423 424 declare_TCFType! { 425 /// A Secure Transport SSL/TLS context object. 426 SslContext, SSLContextRef 427 } 428 429 impl_TCFType!(SslContext, SSLContextRef, SSLContextGetTypeID); 430 431 impl fmt::Debug for SslContext { 432 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { 433 let mut builder = fmt.debug_struct("SslContext"); 434 if let Ok(state) = self.state() { 435 builder.field("state", &state); 436 } 437 builder.finish() 438 } 439 } 440 441 unsafe impl Sync for SslContext {} 442 unsafe impl Send for SslContext {} 443 444 impl AsInner for SslContext { 445 type Inner = SSLContextRef; 446 447 fn as_inner(&self) -> SSLContextRef { 448 self.0 449 } 450 } 451 452 macro_rules! impl_options { 453 ($($(#[$a:meta])* const $opt:ident: $get:ident & $set:ident,)*) => { 454 $( 455 $(#[$a])* 456 pub fn $set(&mut self, value: bool) -> Result<()> { 457 unsafe { cvt(SSLSetSessionOption(self.0, $opt, value as Boolean)) } 458 } 459 460 $(#[$a])* 461 pub fn $get(&self) -> Result<bool> { 462 let mut value = 0; 463 unsafe { cvt(SSLGetSessionOption(self.0, $opt, &mut value))?; } 464 Ok(value != 0) 465 } 466 )* 467 } 468 } 469 470 impl SslContext { 471 /// Creates a new `SslContext` for the specified side and type of SSL 472 /// connection. 473 pub fn new(side: SslProtocolSide, type_: SslConnectionType) -> Result<SslContext> { 474 unsafe { 475 let ctx = SSLCreateContext(kCFAllocatorDefault, side.0, type_.0); 476 Ok(SslContext(ctx)) 477 } 478 } 479 480 /// Sets the fully qualified domain name of the peer. 481 /// 482 /// This will be used on the client side of a session to validate the 483 /// common name field of the server's certificate. It has no effect if 484 /// called on a server-side `SslContext`. 485 /// 486 /// It is *highly* recommended to call this method before starting the 487 /// handshake process. 488 pub fn set_peer_domain_name(&mut self, peer_name: &str) -> Result<()> { 489 unsafe { 490 // SSLSetPeerDomainName doesn't need a null terminated string 491 cvt(SSLSetPeerDomainName( 492 self.0, 493 peer_name.as_ptr() as *const _, 494 peer_name.len(), 495 )) 496 } 497 } 498 499 /// Returns the peer domain name set by `set_peer_domain_name`. 500 pub fn peer_domain_name(&self) -> Result<String> { 501 unsafe { 502 let mut len = 0; 503 cvt(SSLGetPeerDomainNameLength(self.0, &mut len))?; 504 let mut buf = vec![0; len]; 505 cvt(SSLGetPeerDomainName( 506 self.0, 507 buf.as_mut_ptr() as *mut _, 508 &mut len, 509 ))?; 510 Ok(String::from_utf8(buf).unwrap()) 511 } 512 } 513 514 /// Sets the certificate to be used by this side of the SSL session. 515 /// 516 /// This must be called before the handshake for server-side connections, 517 /// and can be used on the client-side to specify a client certificate. 518 /// 519 /// The `identity` corresponds to the leaf certificate and private 520 /// key, and the `certs` correspond to extra certificates in the chain. 521 pub fn set_certificate( 522 &mut self, 523 identity: &SecIdentity, 524 certs: &[SecCertificate], 525 ) -> Result<()> { 526 let mut arr = vec![identity.as_CFType()]; 527 arr.extend(certs.iter().map(|c| c.as_CFType())); 528 let certs = CFArray::from_CFTypes(&arr); 529 530 unsafe { cvt(SSLSetCertificate(self.0, certs.as_concrete_TypeRef())) } 531 } 532 533 /// Sets the peer ID of this session. 534 /// 535 /// A peer ID is an opaque sequence of bytes that will be used by Secure 536 /// Transport to identify the peer of an SSL session. If the peer ID of 537 /// this session matches that of a previously terminated session, the 538 /// previous session can be resumed without requiring a full handshake. 539 pub fn set_peer_id(&mut self, peer_id: &[u8]) -> Result<()> { 540 unsafe { 541 cvt(SSLSetPeerID( 542 self.0, 543 peer_id.as_ptr() as *const _, 544 peer_id.len(), 545 )) 546 } 547 } 548 549 /// Returns the peer ID of this session. 550 pub fn peer_id(&self) -> Result<Option<&[u8]>> { 551 unsafe { 552 let mut ptr = ptr::null(); 553 let mut len = 0; 554 cvt(SSLGetPeerID(self.0, &mut ptr, &mut len))?; 555 if ptr.is_null() { 556 Ok(None) 557 } else { 558 Ok(Some(slice::from_raw_parts(ptr as *const _, len))) 559 } 560 } 561 } 562 563 /// Returns the list of ciphers that are supported by Secure Transport. 564 pub fn supported_ciphers(&self) -> Result<Vec<CipherSuite>> { 565 unsafe { 566 let mut num_ciphers = 0; 567 cvt(SSLGetNumberSupportedCiphers(self.0, &mut num_ciphers))?; 568 let mut ciphers = vec![0; num_ciphers]; 569 cvt(SSLGetSupportedCiphers( 570 self.0, 571 ciphers.as_mut_ptr(), 572 &mut num_ciphers, 573 ))?; 574 Ok(ciphers.iter().map(|c| CipherSuite::from_raw(*c)).collect()) 575 } 576 } 577 578 /// Returns the list of ciphers that are eligible to be used for 579 /// negotiation. 580 pub fn enabled_ciphers(&self) -> Result<Vec<CipherSuite>> { 581 unsafe { 582 let mut num_ciphers = 0; 583 cvt(SSLGetNumberEnabledCiphers(self.0, &mut num_ciphers))?; 584 let mut ciphers = vec![0; num_ciphers]; 585 cvt(SSLGetEnabledCiphers( 586 self.0, 587 ciphers.as_mut_ptr(), 588 &mut num_ciphers, 589 ))?; 590 Ok(ciphers.iter().map(|c| CipherSuite::from_raw(*c)).collect()) 591 } 592 } 593 594 /// Sets the list of ciphers that are eligible to be used for negotiation. 595 pub fn set_enabled_ciphers(&mut self, ciphers: &[CipherSuite]) -> Result<()> { 596 let ciphers = ciphers.iter().map(|c| c.to_raw()).collect::<Vec<_>>(); 597 unsafe { 598 cvt(SSLSetEnabledCiphers( 599 self.0, 600 ciphers.as_ptr(), 601 ciphers.len(), 602 )) 603 } 604 } 605 606 /// Returns the cipher being used by the session. 607 pub fn negotiated_cipher(&self) -> Result<CipherSuite> { 608 unsafe { 609 let mut cipher = 0; 610 cvt(SSLGetNegotiatedCipher(self.0, &mut cipher))?; 611 Ok(CipherSuite::from_raw(cipher)) 612 } 613 } 614 615 /// Sets the requirements for client certificates. 616 /// 617 /// Should only be called on server-side sessions. 618 pub fn set_client_side_authenticate(&mut self, auth: SslAuthenticate) -> Result<()> { 619 unsafe { cvt(SSLSetClientSideAuthenticate(self.0, auth.0)) } 620 } 621 622 /// Returns the state of client certificate processing. 623 pub fn client_certificate_state(&self) -> Result<SslClientCertificateState> { 624 let mut state = 0; 625 626 unsafe { 627 cvt(SSLGetClientCertificateState(self.0, &mut state))?; 628 } 629 Ok(SslClientCertificateState(state)) 630 } 631 632 /// Returns the `SecTrust` object corresponding to the peer. 633 /// 634 /// This can be used in conjunction with `set_break_on_server_auth` to 635 /// validate certificates which do not have roots in the default set. 636 pub fn peer_trust2(&self) -> Result<Option<SecTrust>> { 637 // Calling SSLCopyPeerTrust on an idle connection does not seem to be well defined, 638 // so explicitly check for that 639 if self.state()? == SessionState::IDLE { 640 return Err(Error::from_code(errSecBadReq)); 641 } 642 643 unsafe { 644 let mut trust = ptr::null_mut(); 645 cvt(SSLCopyPeerTrust(self.0, &mut trust))?; 646 if trust.is_null() { 647 Ok(None) 648 } else { 649 Ok(Some(SecTrust::wrap_under_create_rule(trust))) 650 } 651 } 652 } 653 654 #[allow(missing_docs)] 655 #[deprecated(since = "0.2.1", note = "use peer_trust2 instead")] 656 pub fn peer_trust(&self) -> Result<SecTrust> { 657 match self.peer_trust2() { 658 Ok(Some(trust)) => Ok(trust), 659 Ok(None) => panic!("no trust available"), 660 Err(e) => Err(e), 661 } 662 } 663 664 /// Returns the state of the session. 665 pub fn state(&self) -> Result<SessionState> { 666 unsafe { 667 let mut state = 0; 668 cvt(SSLGetSessionState(self.0, &mut state))?; 669 Ok(SessionState(state)) 670 } 671 } 672 673 /// Returns the protocol version being used by the session. 674 pub fn negotiated_protocol_version(&self) -> Result<SslProtocol> { 675 unsafe { 676 let mut version = 0; 677 cvt(SSLGetNegotiatedProtocolVersion(self.0, &mut version))?; 678 Ok(SslProtocol(version)) 679 } 680 } 681 682 /// Returns the maximum protocol version allowed by the session. 683 pub fn protocol_version_max(&self) -> Result<SslProtocol> { 684 unsafe { 685 let mut version = 0; 686 cvt(SSLGetProtocolVersionMax(self.0, &mut version))?; 687 Ok(SslProtocol(version)) 688 } 689 } 690 691 /// Sets the maximum protocol version allowed by the session. 692 pub fn set_protocol_version_max(&mut self, max_version: SslProtocol) -> Result<()> { 693 unsafe { cvt(SSLSetProtocolVersionMax(self.0, max_version.0)) } 694 } 695 696 /// Returns the minimum protocol version allowed by the session. 697 pub fn protocol_version_min(&self) -> Result<SslProtocol> { 698 unsafe { 699 let mut version = 0; 700 cvt(SSLGetProtocolVersionMin(self.0, &mut version))?; 701 Ok(SslProtocol(version)) 702 } 703 } 704 705 /// Sets the minimum protocol version allowed by the session. 706 pub fn set_protocol_version_min(&mut self, min_version: SslProtocol) -> Result<()> { 707 unsafe { cvt(SSLSetProtocolVersionMin(self.0, min_version.0)) } 708 } 709 710 /// Returns the set of protocols selected via ALPN if it succeeded. 711 #[cfg(feature = "alpn")] 712 pub fn alpn_protocols(&self) -> Result<Vec<String>> { 713 let mut array = ptr::null(); 714 unsafe { 715 #[cfg(feature = "OSX_10_13")] 716 { 717 cvt(SSLCopyALPNProtocols(self.0, &mut array))?; 718 } 719 720 #[cfg(not(feature = "OSX_10_13"))] 721 { 722 dlsym! { fn SSLCopyALPNProtocols(SSLContextRef, *mut CFArrayRef) -> OSStatus } 723 if let Some(f) = SSLCopyALPNProtocols.get() { 724 cvt(f(self.0, &mut array))?; 725 } else { 726 return Err(Error::from_code(errSecUnimplemented)); 727 } 728 } 729 730 if array.is_null() { 731 return Ok(vec![]); 732 } 733 734 let array = CFArray::<CFString>::wrap_under_create_rule(array); 735 Ok(array.into_iter().map(|p| p.to_string()).collect()) 736 } 737 } 738 739 /// Configures the set of protocols use for ALPN. 740 /// 741 /// This is only used for client-side connections. 742 #[cfg(feature = "alpn")] 743 pub fn set_alpn_protocols(&mut self, protocols: &[&str]) -> Result<()> { 744 // When CFMutableArray is added to core-foundation and IntoIterator trait 745 // is implemented for CFMutableArray, the code below should directly collect 746 // into a CFMutableArray. 747 let protocols = CFArray::from_CFTypes( 748 &protocols 749 .iter() 750 .map(|proto| CFString::new(proto)) 751 .collect::<Vec<_>>(), 752 ); 753 754 #[cfg(feature = "OSX_10_13")] 755 { 756 unsafe { cvt(SSLSetALPNProtocols(self.0, protocols.as_concrete_TypeRef())) } 757 } 758 #[cfg(not(feature = "OSX_10_13"))] 759 { 760 dlsym! { fn SSLSetALPNProtocols(SSLContextRef, CFArrayRef) -> OSStatus } 761 if let Some(f) = SSLSetALPNProtocols.get() { 762 unsafe { cvt(f(self.0, protocols.as_concrete_TypeRef())) } 763 } else { 764 Err(Error::from_code(errSecUnimplemented)) 765 } 766 } 767 } 768 769 /// Sets whether a protocol is enabled or not. 770 /// 771 /// # Note 772 /// 773 /// On OSX this is a deprecated API in favor of `set_protocol_version_max` and 774 /// `set_protocol_version_min`, although if you're working with OSX 10.8 or before you may have 775 /// to use this API instead. 776 #[cfg(target_os = "macos")] 777 pub fn set_protocol_version_enabled( 778 &mut self, 779 protocol: SslProtocol, 780 enabled: bool, 781 ) -> Result<()> { 782 unsafe { 783 cvt(SSLSetProtocolVersionEnabled( 784 self.0, 785 protocol.0, 786 enabled as Boolean, 787 )) 788 } 789 } 790 791 /// Returns the number of bytes which can be read without triggering a 792 /// `read` call in the underlying stream. 793 pub fn buffered_read_size(&self) -> Result<usize> { 794 unsafe { 795 let mut size = 0; 796 cvt(SSLGetBufferedReadSize(self.0, &mut size))?; 797 Ok(size) 798 } 799 } 800 801 impl_options! { 802 /// If enabled, the handshake process will pause and return instead of 803 /// automatically validating a server's certificate. 804 const kSSLSessionOptionBreakOnServerAuth: break_on_server_auth & set_break_on_server_auth, 805 /// If enabled, the handshake process will pause and return after 806 /// the server requests a certificate from the client. 807 const kSSLSessionOptionBreakOnCertRequested: break_on_cert_requested & set_break_on_cert_requested, 808 /// If enabled, the handshake process will pause and return instead of 809 /// automatically validating a client's certificate. 810 const kSSLSessionOptionBreakOnClientAuth: break_on_client_auth & set_break_on_client_auth, 811 /// If enabled, TLS false start will be performed if an appropriate 812 /// cipher suite is negotiated. 813 /// 814 /// Requires the `OSX_10_9` (or greater) feature. 815 #[cfg(feature = "OSX_10_9")] 816 const kSSLSessionOptionFalseStart: false_start & set_false_start, 817 /// If enabled, 1/n-1 record splitting will be enabled for TLS 1.0 818 /// connections using block ciphers to mitigate the BEAST attack. 819 /// 820 /// Requires the `OSX_10_9` (or greater) feature. 821 #[cfg(feature = "OSX_10_9")] 822 const kSSLSessionOptionSendOneByteRecord: send_one_byte_record & set_send_one_byte_record, 823 } 824 825 fn into_stream<S>(self, stream: S) -> Result<SslStream<S>> 826 where 827 S: Read + Write, 828 { 829 unsafe { 830 let ret = SSLSetIOFuncs(self.0, read_func::<S>, write_func::<S>); 831 if ret != errSecSuccess { 832 return Err(Error::from_code(ret)); 833 } 834 835 let stream = Connection { 836 stream, 837 err: None, 838 panic: None, 839 }; 840 let stream = Box::into_raw(Box::new(stream)); 841 let ret = SSLSetConnection(self.0, stream as *mut _); 842 if ret != errSecSuccess { 843 let _conn = Box::from_raw(stream); 844 return Err(Error::from_code(ret)); 845 } 846 847 Ok(SslStream { 848 ctx: self, 849 _m: PhantomData, 850 }) 851 } 852 } 853 854 /// Performs the SSL/TLS handshake. 855 pub fn handshake<S>(self, stream: S) -> result::Result<SslStream<S>, HandshakeError<S>> 856 where 857 S: Read + Write, 858 { 859 self.into_stream(stream) 860 .map_err(HandshakeError::Failure) 861 .and_then(SslStream::handshake) 862 } 863 } 864 865 struct Connection<S> { 866 stream: S, 867 err: Option<io::Error>, 868 panic: Option<Box<dyn Any + Send>>, 869 } 870 871 // the logic here is based off of libcurl's 872 873 fn translate_err(e: &io::Error) -> OSStatus { 874 match e.kind() { 875 io::ErrorKind::NotFound => errSSLClosedGraceful, 876 io::ErrorKind::ConnectionReset => errSSLClosedAbort, 877 io::ErrorKind::WouldBlock => errSSLWouldBlock, 878 io::ErrorKind::NotConnected => errSSLWouldBlock, 879 _ => errSecIO, 880 } 881 } 882 883 unsafe extern "C" fn read_func<S>( 884 connection: SSLConnectionRef, 885 data: *mut c_void, 886 data_length: *mut usize, 887 ) -> OSStatus 888 where 889 S: Read, 890 { 891 let conn: &mut Connection<S> = &mut *(connection as *mut _); 892 let data = slice::from_raw_parts_mut(data as *mut u8, *data_length); 893 let mut start = 0; 894 let mut ret = errSecSuccess; 895 896 while start < data.len() { 897 match panic::catch_unwind(AssertUnwindSafe(|| conn.stream.read(&mut data[start..]))) { 898 Ok(Ok(0)) => { 899 ret = errSSLClosedNoNotify; 900 break; 901 } 902 Ok(Ok(len)) => start += len, 903 Ok(Err(e)) => { 904 ret = translate_err(&e); 905 conn.err = Some(e); 906 break; 907 } 908 Err(e) => { 909 ret = errSecIO; 910 conn.panic = Some(e); 911 break; 912 } 913 } 914 } 915 916 *data_length = start; 917 ret 918 } 919 920 unsafe extern "C" fn write_func<S>( 921 connection: SSLConnectionRef, 922 data: *const c_void, 923 data_length: *mut usize, 924 ) -> OSStatus 925 where 926 S: Write, 927 { 928 let conn: &mut Connection<S> = &mut *(connection as *mut _); 929 let data = slice::from_raw_parts(data as *mut u8, *data_length); 930 let mut start = 0; 931 let mut ret = errSecSuccess; 932 933 while start < data.len() { 934 match panic::catch_unwind(AssertUnwindSafe(|| conn.stream.write(&data[start..]))) { 935 Ok(Ok(0)) => { 936 ret = errSSLClosedNoNotify; 937 break; 938 } 939 Ok(Ok(len)) => start += len, 940 Ok(Err(e)) => { 941 ret = translate_err(&e); 942 conn.err = Some(e); 943 break; 944 } 945 Err(e) => { 946 ret = errSecIO; 947 conn.panic = Some(e); 948 break; 949 } 950 } 951 } 952 953 *data_length = start; 954 ret 955 } 956 957 /// A type implementing SSL/TLS encryption over an underlying stream. 958 pub struct SslStream<S> { 959 ctx: SslContext, 960 _m: PhantomData<S>, 961 } 962 963 impl<S: fmt::Debug> fmt::Debug for SslStream<S> { 964 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { 965 fmt.debug_struct("SslStream") 966 .field("context", &self.ctx) 967 .field("stream", self.get_ref()) 968 .finish() 969 } 970 } 971 972 impl<S> Drop for SslStream<S> { 973 fn drop(&mut self) { 974 unsafe { 975 let mut conn = ptr::null(); 976 let ret = SSLGetConnection(self.ctx.0, &mut conn); 977 assert!(ret == errSecSuccess); 978 Box::<Connection<S>>::from_raw(conn as *mut _); 979 } 980 } 981 } 982 983 impl<S> SslStream<S> { 984 fn handshake(mut self) -> result::Result<SslStream<S>, HandshakeError<S>> { 985 match unsafe { SSLHandshake(self.ctx.0) } { 986 errSecSuccess => Ok(self), 987 reason @ errSSLPeerAuthCompleted 988 | reason @ errSSLClientCertRequested 989 | reason @ errSSLWouldBlock 990 | reason @ errSSLClientHelloReceived => { 991 Err(HandshakeError::Interrupted(MidHandshakeSslStream { 992 stream: self, 993 error: Error::from_code(reason), 994 })) 995 } 996 err => { 997 self.check_panic(); 998 Err(HandshakeError::Failure(Error::from_code(err))) 999 } 1000 } 1001 } 1002 1003 /// Returns a shared reference to the inner stream. 1004 pub fn get_ref(&self) -> &S { 1005 &self.connection().stream 1006 } 1007 1008 /// Returns a mutable reference to the underlying stream. 1009 pub fn get_mut(&mut self) -> &mut S { 1010 &mut self.connection_mut().stream 1011 } 1012 1013 /// Returns a shared reference to the `SslContext` of the stream. 1014 pub fn context(&self) -> &SslContext { 1015 &self.ctx 1016 } 1017 1018 /// Returns a mutable reference to the `SslContext` of the stream. 1019 pub fn context_mut(&mut self) -> &mut SslContext { 1020 &mut self.ctx 1021 } 1022 1023 /// Shuts down the connection. 1024 pub fn close(&mut self) -> result::Result<(), io::Error> { 1025 unsafe { 1026 let ret = SSLClose(self.ctx.0); 1027 if ret == errSecSuccess { 1028 Ok(()) 1029 } else { 1030 Err(self.get_error(ret)) 1031 } 1032 } 1033 } 1034 1035 fn connection(&self) -> &Connection<S> { 1036 unsafe { 1037 let mut conn = ptr::null(); 1038 let ret = SSLGetConnection(self.ctx.0, &mut conn); 1039 assert!(ret == errSecSuccess); 1040 1041 mem::transmute(conn) 1042 } 1043 } 1044 1045 fn connection_mut(&mut self) -> &mut Connection<S> { 1046 unsafe { 1047 let mut conn = ptr::null(); 1048 let ret = SSLGetConnection(self.ctx.0, &mut conn); 1049 assert!(ret == errSecSuccess); 1050 1051 mem::transmute(conn) 1052 } 1053 } 1054 1055 fn check_panic(&mut self) { 1056 let conn = self.connection_mut(); 1057 if let Some(err) = conn.panic.take() { 1058 panic::resume_unwind(err); 1059 } 1060 } 1061 1062 fn get_error(&mut self, ret: OSStatus) -> io::Error { 1063 self.check_panic(); 1064 1065 if let Some(err) = self.connection_mut().err.take() { 1066 err 1067 } else { 1068 io::Error::new(io::ErrorKind::Other, Error::from_code(ret)) 1069 } 1070 } 1071 } 1072 1073 impl<S: Read + Write> Read for SslStream<S> { 1074 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> { 1075 // Below we base our return value off the amount of data read, so a 1076 // zero-length buffer might cause us to erroneously interpret this 1077 // request as an error. Instead short-circuit that logic and return 1078 // `Ok(0)` instead. 1079 if buf.is_empty() { 1080 return Ok(0); 1081 } 1082 1083 // If some data was buffered but not enough to fill `buf`, SSLRead 1084 // will try to read a new packet. This is bad because there may be 1085 // no more data but the socket is remaining open (e.g HTTPS with 1086 // Connection: keep-alive). 1087 let buffered = self.context().buffered_read_size().unwrap_or(0); 1088 let mut to_read = buf.len(); 1089 if buffered > 0 { 1090 to_read = cmp::min(buffered, buf.len()); 1091 } 1092 1093 unsafe { 1094 let mut nread = 0; 1095 let ret = SSLRead(self.ctx.0, buf.as_mut_ptr() as *mut _, to_read, &mut nread); 1096 // SSLRead can return an error at the same time it returns the last 1097 // chunk of data (!) 1098 if nread > 0 { 1099 return Ok(nread as usize); 1100 } 1101 1102 match ret { 1103 errSSLClosedGraceful | errSSLClosedAbort | errSSLClosedNoNotify => Ok(0), 1104 _ => Err(self.get_error(ret)), 1105 } 1106 } 1107 } 1108 } 1109 1110 impl<S: Read + Write> Write for SslStream<S> { 1111 fn write(&mut self, buf: &[u8]) -> io::Result<usize> { 1112 // Like above in read, short circuit a 0-length write 1113 if buf.is_empty() { 1114 return Ok(0); 1115 } 1116 unsafe { 1117 let mut nwritten = 0; 1118 let ret = SSLWrite( 1119 self.ctx.0, 1120 buf.as_ptr() as *const _, 1121 buf.len(), 1122 &mut nwritten, 1123 ); 1124 // just to be safe, base success off of nwritten rather than ret 1125 // for the same reason as in read 1126 if nwritten > 0 { 1127 Ok(nwritten as usize) 1128 } else { 1129 Err(self.get_error(ret)) 1130 } 1131 } 1132 } 1133 1134 fn flush(&mut self) -> io::Result<()> { 1135 self.connection_mut().stream.flush() 1136 } 1137 } 1138 1139 /// A builder type to simplify the creation of client side `SslStream`s. 1140 #[derive(Debug)] 1141 pub struct ClientBuilder { 1142 identity: Option<SecIdentity>, 1143 certs: Vec<SecCertificate>, 1144 chain: Vec<SecCertificate>, 1145 protocol_min: Option<SslProtocol>, 1146 protocol_max: Option<SslProtocol>, 1147 trust_certs_only: bool, 1148 use_sni: bool, 1149 danger_accept_invalid_certs: bool, 1150 danger_accept_invalid_hostnames: bool, 1151 whitelisted_ciphers: Vec<CipherSuite>, 1152 blacklisted_ciphers: Vec<CipherSuite>, 1153 alpn: Option<Vec<String>>, 1154 } 1155 1156 impl Default for ClientBuilder { 1157 fn default() -> ClientBuilder { 1158 ClientBuilder::new() 1159 } 1160 } 1161 1162 impl ClientBuilder { 1163 /// Creates a new builder with default options. 1164 pub fn new() -> Self { 1165 ClientBuilder { 1166 identity: None, 1167 certs: Vec::new(), 1168 chain: Vec::new(), 1169 protocol_min: None, 1170 protocol_max: None, 1171 trust_certs_only: false, 1172 use_sni: true, 1173 danger_accept_invalid_certs: false, 1174 danger_accept_invalid_hostnames: false, 1175 whitelisted_ciphers: Vec::new(), 1176 blacklisted_ciphers: Vec::new(), 1177 alpn: None, 1178 } 1179 } 1180 1181 /// Specifies the set of root certificates to trust when 1182 /// verifying the server's certificate. 1183 pub fn anchor_certificates(&mut self, certs: &[SecCertificate]) -> &mut Self { 1184 self.certs = certs.to_owned(); 1185 self 1186 } 1187 1188 /// Specifies whether to trust the built-in certificates in addition 1189 /// to specified anchor certificates. 1190 pub fn trust_anchor_certificates_only(&mut self, only: bool) -> &mut Self { 1191 self.trust_certs_only = only; 1192 self 1193 } 1194 1195 /// Specifies whether to trust invalid certificates. 1196 /// 1197 /// # Warning 1198 /// 1199 /// You should think very carefully before using this method. If invalid 1200 /// certificates are trusted, *any* certificate for *any* site will be 1201 /// trusted for use. This includes expired certificates. This introduces 1202 /// significant vulnerabilities, and should only be used as a last resort. 1203 pub fn danger_accept_invalid_certs(&mut self, noverify: bool) -> &mut Self { 1204 self.danger_accept_invalid_certs = noverify; 1205 self 1206 } 1207 1208 /// Specifies whether to use Server Name Indication (SNI). 1209 pub fn use_sni(&mut self, use_sni: bool) -> &mut Self { 1210 self.use_sni = use_sni; 1211 self 1212 } 1213 1214 /// Specifies whether to verify that the server's hostname matches its certificate. 1215 /// 1216 /// # Warning 1217 /// 1218 /// You should think very carefully before using this method. If hostnames are not verified, 1219 /// *any* valid certificate for *any* site will be trusted for use. This introduces significant 1220 /// vulnerabilities, and should only be used as a last resort. 1221 pub fn danger_accept_invalid_hostnames( 1222 &mut self, 1223 danger_accept_invalid_hostnames: bool, 1224 ) -> &mut Self { 1225 self.danger_accept_invalid_hostnames = danger_accept_invalid_hostnames; 1226 self 1227 } 1228 1229 /// Set a whitelist of enabled ciphers. Any ciphers not whitelisted will be disabled. 1230 pub fn whitelist_ciphers(&mut self, whitelisted_ciphers: &[CipherSuite]) -> &mut Self { 1231 self.whitelisted_ciphers = whitelisted_ciphers.to_owned(); 1232 self 1233 } 1234 1235 /// Set a blacklist of disabled ciphers. Blacklisted ciphers will be disabled. 1236 pub fn blacklist_ciphers(&mut self, blacklisted_ciphers: &[CipherSuite]) -> &mut Self { 1237 self.blacklisted_ciphers = blacklisted_ciphers.to_owned(); 1238 self 1239 } 1240 1241 /// Use the specified identity as a SSL/TLS client certificate. 1242 pub fn identity(&mut self, identity: &SecIdentity, chain: &[SecCertificate]) -> &mut Self { 1243 self.identity = Some(identity.clone()); 1244 self.chain = chain.to_owned(); 1245 self 1246 } 1247 1248 /// Configure the minimum protocol that this client will support. 1249 pub fn protocol_min(&mut self, min: SslProtocol) -> &mut Self { 1250 self.protocol_min = Some(min); 1251 self 1252 } 1253 1254 /// Configure the minimum protocol that this client will support. 1255 pub fn protocol_max(&mut self, max: SslProtocol) -> &mut Self { 1256 self.protocol_max = Some(max); 1257 self 1258 } 1259 1260 /// Configures the set of protocols used for ALPN. 1261 #[cfg(feature = "alpn")] 1262 pub fn alpn_protocols(&mut self, protocols: &[&str]) -> &mut Self { 1263 self.alpn = Some(protocols.iter().map(|s| s.to_string()).collect()); 1264 self 1265 } 1266 1267 /// Initiates a new SSL/TLS session over a stream connected to the specified domain. 1268 /// 1269 /// If both SNI and hostname verification are disabled, the value of `domain` will be ignored. 1270 pub fn handshake<S>( 1271 &self, 1272 domain: &str, 1273 stream: S, 1274 ) -> result::Result<SslStream<S>, ClientHandshakeError<S>> 1275 where 1276 S: Read + Write, 1277 { 1278 // the logic for trust validation is in MidHandshakeClientBuilder::connect, so run all 1279 // of the handshake logic through that. 1280 let stream = MidHandshakeSslStream { 1281 stream: self.ctx_into_stream(domain, stream)?, 1282 error: Error::from(errSecSuccess), 1283 }; 1284 1285 let certs = self.certs.clone(); 1286 let stream = MidHandshakeClientBuilder { 1287 stream, 1288 domain: if self.danger_accept_invalid_hostnames { 1289 None 1290 } else { 1291 Some(domain.to_string()) 1292 }, 1293 certs, 1294 trust_certs_only: self.trust_certs_only, 1295 danger_accept_invalid_certs: self.danger_accept_invalid_certs, 1296 }; 1297 stream.handshake() 1298 } 1299 1300 fn ctx_into_stream<S>(&self, domain: &str, stream: S) -> Result<SslStream<S>> 1301 where 1302 S: Read + Write, 1303 { 1304 let mut ctx = SslContext::new(SslProtocolSide::CLIENT, SslConnectionType::STREAM)?; 1305 1306 if self.use_sni { 1307 ctx.set_peer_domain_name(domain)?; 1308 } 1309 if let Some(ref identity) = self.identity { 1310 ctx.set_certificate(identity, &self.chain)?; 1311 } 1312 #[cfg(feature = "alpn")] 1313 { 1314 if let Some(ref alpn) = self.alpn { 1315 ctx.set_alpn_protocols(&alpn.iter().map(|s| &**s).collect::<Vec<_>>())?; 1316 } 1317 } 1318 ctx.set_break_on_server_auth(true)?; 1319 self.configure_protocols(&mut ctx)?; 1320 self.configure_ciphers(&mut ctx)?; 1321 1322 ctx.into_stream(stream) 1323 } 1324 1325 fn configure_protocols(&self, ctx: &mut SslContext) -> Result<()> { 1326 if let Some(min) = self.protocol_min { 1327 ctx.set_protocol_version_min(min)?; 1328 } 1329 if let Some(max) = self.protocol_max { 1330 ctx.set_protocol_version_max(max)?; 1331 } 1332 Ok(()) 1333 } 1334 1335 fn configure_ciphers(&self, ctx: &mut SslContext) -> Result<()> { 1336 let mut ciphers = if self.whitelisted_ciphers.is_empty() { 1337 ctx.enabled_ciphers()? 1338 } else { 1339 self.whitelisted_ciphers.clone() 1340 }; 1341 1342 if !self.blacklisted_ciphers.is_empty() { 1343 ciphers.retain(|cipher| !self.blacklisted_ciphers.contains(cipher)); 1344 } 1345 1346 ctx.set_enabled_ciphers(&ciphers)?; 1347 Ok(()) 1348 } 1349 } 1350 1351 /// A builder type to simplify the creation of server-side `SslStream`s. 1352 #[derive(Debug)] 1353 pub struct ServerBuilder { 1354 identity: SecIdentity, 1355 certs: Vec<SecCertificate>, 1356 } 1357 1358 impl ServerBuilder { 1359 /// Creates a new `ServerBuilder` which will use the specified identity 1360 /// and certificate chain for handshakes. 1361 pub fn new(identity: &SecIdentity, certs: &[SecCertificate]) -> ServerBuilder { 1362 ServerBuilder { 1363 identity: identity.clone(), 1364 certs: certs.to_owned(), 1365 } 1366 } 1367 1368 /// Initiates a new SSL/TLS session over a stream. 1369 pub fn handshake<S>(&self, stream: S) -> Result<SslStream<S>> 1370 where 1371 S: Read + Write, 1372 { 1373 let mut ctx = SslContext::new(SslProtocolSide::SERVER, SslConnectionType::STREAM)?; 1374 ctx.set_certificate(&self.identity, &self.certs)?; 1375 match ctx.handshake(stream) { 1376 Ok(stream) => Ok(stream), 1377 Err(HandshakeError::Interrupted(stream)) => Err(Error::from_code(stream.reason())), 1378 Err(HandshakeError::Failure(err)) => Err(err), 1379 } 1380 } 1381 } 1382 1383 #[cfg(test)] 1384 mod test { 1385 use std::io; 1386 use std::io::prelude::*; 1387 use std::net::TcpStream; 1388 1389 use super::*; 1390 1391 #[test] 1392 fn connect() { 1393 let mut ctx = p!(SslContext::new( 1394 SslProtocolSide::CLIENT, 1395 SslConnectionType::STREAM 1396 )); 1397 p!(ctx.set_peer_domain_name("google.com")); 1398 let stream = p!(TcpStream::connect("google.com:443")); 1399 p!(ctx.handshake(stream)); 1400 } 1401 1402 #[test] 1403 fn connect_bad_domain() { 1404 let mut ctx = p!(SslContext::new( 1405 SslProtocolSide::CLIENT, 1406 SslConnectionType::STREAM 1407 )); 1408 p!(ctx.set_peer_domain_name("foobar.com")); 1409 let stream = p!(TcpStream::connect("google.com:443")); 1410 match ctx.handshake(stream) { 1411 Ok(_) => panic!("expected failure"), 1412 Err(_) => {} 1413 } 1414 } 1415 1416 #[test] 1417 fn load_page() { 1418 let mut ctx = p!(SslContext::new( 1419 SslProtocolSide::CLIENT, 1420 SslConnectionType::STREAM 1421 )); 1422 p!(ctx.set_peer_domain_name("google.com")); 1423 let stream = p!(TcpStream::connect("google.com:443")); 1424 let mut stream = p!(ctx.handshake(stream)); 1425 p!(stream.write_all(b"GET / HTTP/1.0\r\n\r\n")); 1426 p!(stream.flush()); 1427 let mut buf = vec![]; 1428 p!(stream.read_to_end(&mut buf)); 1429 println!("{}", String::from_utf8_lossy(&buf)); 1430 } 1431 1432 #[test] 1433 #[cfg(feature = "alpn")] 1434 fn client_alpn_accept() { 1435 let mut ctx = p!(SslContext::new( 1436 SslProtocolSide::CLIENT, 1437 SslConnectionType::STREAM 1438 )); 1439 p!(ctx.set_peer_domain_name("google.com")); 1440 p!(ctx.set_alpn_protocols(&vec!["h2"])); 1441 let stream = p!(TcpStream::connect("google.com:443")); 1442 let stream = ctx.handshake(stream).unwrap(); 1443 assert_eq!(vec!["h2"], stream.context().alpn_protocols().unwrap()); 1444 } 1445 1446 #[test] 1447 #[cfg(feature = "alpn")] 1448 fn client_alpn_reject() { 1449 let mut ctx = p!(SslContext::new( 1450 SslProtocolSide::CLIENT, 1451 SslConnectionType::STREAM 1452 )); 1453 p!(ctx.set_peer_domain_name("google.com")); 1454 p!(ctx.set_alpn_protocols(&vec!["h2c"])); 1455 let stream = p!(TcpStream::connect("google.com:443")); 1456 let stream = ctx.handshake(stream).unwrap(); 1457 assert!(stream.context().alpn_protocols().is_err()); 1458 } 1459 1460 #[test] 1461 fn client_no_anchor_certs() { 1462 let stream = p!(TcpStream::connect("google.com:443")); 1463 assert!(ClientBuilder::new() 1464 .trust_anchor_certificates_only(true) 1465 .handshake("google.com", stream) 1466 .is_err()); 1467 } 1468 1469 #[test] 1470 fn client_bad_domain() { 1471 let stream = p!(TcpStream::connect("google.com:443")); 1472 assert!(ClientBuilder::new() 1473 .handshake("foobar.com", stream) 1474 .is_err()); 1475 } 1476 1477 #[test] 1478 fn client_bad_domain_ignored() { 1479 let stream = p!(TcpStream::connect("google.com:443")); 1480 ClientBuilder::new() 1481 .danger_accept_invalid_hostnames(true) 1482 .handshake("foobar.com", stream) 1483 .unwrap(); 1484 } 1485 1486 #[test] 1487 fn connect_no_verify_ssl() { 1488 let stream = p!(TcpStream::connect("expired.badssl.com:443")); 1489 let mut builder = ClientBuilder::new(); 1490 builder.danger_accept_invalid_certs(true); 1491 builder.handshake("expired.badssl.com", stream).unwrap(); 1492 } 1493 1494 #[test] 1495 fn load_page_client() { 1496 let stream = p!(TcpStream::connect("google.com:443")); 1497 let mut stream = p!(ClientBuilder::new().handshake("google.com", stream)); 1498 p!(stream.write_all(b"GET / HTTP/1.0\r\n\r\n")); 1499 p!(stream.flush()); 1500 let mut buf = vec![]; 1501 p!(stream.read_to_end(&mut buf)); 1502 println!("{}", String::from_utf8_lossy(&buf)); 1503 } 1504 1505 #[test] 1506 #[cfg_attr(target_os = "ios", ignore)] // FIXME what's going on with ios? 1507 fn cipher_configuration() { 1508 let mut ctx = p!(SslContext::new( 1509 SslProtocolSide::SERVER, 1510 SslConnectionType::STREAM 1511 )); 1512 let ciphers = p!(ctx.enabled_ciphers()); 1513 let ciphers = ciphers 1514 .iter() 1515 .enumerate() 1516 .filter_map(|(i, c)| if i % 2 == 0 { Some(*c) } else { None }) 1517 .collect::<Vec<_>>(); 1518 p!(ctx.set_enabled_ciphers(&ciphers)); 1519 assert_eq!(ciphers, p!(ctx.enabled_ciphers())); 1520 } 1521 1522 #[test] 1523 fn test_builder_whitelist_ciphers() { 1524 let stream = p!(TcpStream::connect("google.com:443")); 1525 1526 let ctx = p!(SslContext::new( 1527 SslProtocolSide::CLIENT, 1528 SslConnectionType::STREAM 1529 )); 1530 assert!(p!(ctx.enabled_ciphers()).len() > 1); 1531 1532 let ciphers = p!(ctx.enabled_ciphers()); 1533 let cipher = ciphers.first().unwrap(); 1534 let stream = p!(ClientBuilder::new() 1535 .whitelist_ciphers(&[*cipher]) 1536 .ctx_into_stream("google.com", stream)); 1537 1538 assert_eq!(1, p!(stream.context().enabled_ciphers()).len()); 1539 } 1540 1541 #[test] 1542 #[cfg_attr(target_os = "ios", ignore)] // FIXME same issue as cipher_configuration 1543 fn test_builder_blacklist_ciphers() { 1544 let stream = p!(TcpStream::connect("google.com:443")); 1545 1546 let ctx = p!(SslContext::new( 1547 SslProtocolSide::CLIENT, 1548 SslConnectionType::STREAM 1549 )); 1550 let num = p!(ctx.enabled_ciphers()).len(); 1551 assert!(num > 1); 1552 1553 let ciphers = p!(ctx.enabled_ciphers()); 1554 let cipher = ciphers.first().unwrap(); 1555 let stream = p!(ClientBuilder::new() 1556 .blacklist_ciphers(&[*cipher]) 1557 .ctx_into_stream("google.com", stream)); 1558 1559 assert_eq!(num - 1, p!(stream.context().enabled_ciphers()).len()); 1560 } 1561 1562 #[test] 1563 fn idle_context_peer_trust() { 1564 let ctx = p!(SslContext::new( 1565 SslProtocolSide::SERVER, 1566 SslConnectionType::STREAM 1567 )); 1568 assert!(ctx.peer_trust2().is_err()); 1569 } 1570 1571 #[test] 1572 fn peer_id() { 1573 let mut ctx = p!(SslContext::new( 1574 SslProtocolSide::SERVER, 1575 SslConnectionType::STREAM 1576 )); 1577 assert!(p!(ctx.peer_id()).is_none()); 1578 p!(ctx.set_peer_id(b"foobar")); 1579 assert_eq!(p!(ctx.peer_id()), Some(&b"foobar"[..])); 1580 } 1581 1582 #[test] 1583 fn peer_domain_name() { 1584 let mut ctx = p!(SslContext::new( 1585 SslProtocolSide::CLIENT, 1586 SslConnectionType::STREAM 1587 )); 1588 assert_eq!("", p!(ctx.peer_domain_name())); 1589 p!(ctx.set_peer_domain_name("foobar.com")); 1590 assert_eq!("foobar.com", p!(ctx.peer_domain_name())); 1591 } 1592 1593 #[test] 1594 #[should_panic(expected = "blammo")] 1595 fn write_panic() { 1596 struct ExplodingStream(TcpStream); 1597 1598 impl Read for ExplodingStream { 1599 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> { 1600 self.0.read(buf) 1601 } 1602 } 1603 1604 impl Write for ExplodingStream { 1605 fn write(&mut self, _: &[u8]) -> io::Result<usize> { 1606 panic!("blammo"); 1607 } 1608 1609 fn flush(&mut self) -> io::Result<()> { 1610 self.0.flush() 1611 } 1612 } 1613 1614 let mut ctx = p!(SslContext::new( 1615 SslProtocolSide::CLIENT, 1616 SslConnectionType::STREAM 1617 )); 1618 p!(ctx.set_peer_domain_name("google.com")); 1619 let stream = p!(TcpStream::connect("google.com:443")); 1620 let _ = ctx.handshake(ExplodingStream(stream)); 1621 } 1622 1623 #[test] 1624 #[should_panic(expected = "blammo")] 1625 fn read_panic() { 1626 struct ExplodingStream(TcpStream); 1627 1628 impl Read for ExplodingStream { 1629 fn read(&mut self, _: &mut [u8]) -> io::Result<usize> { 1630 panic!("blammo"); 1631 } 1632 } 1633 1634 impl Write for ExplodingStream { 1635 fn write(&mut self, buf: &[u8]) -> io::Result<usize> { 1636 self.0.write(buf) 1637 } 1638 1639 fn flush(&mut self) -> io::Result<()> { 1640 self.0.flush() 1641 } 1642 } 1643 1644 let mut ctx = p!(SslContext::new( 1645 SslProtocolSide::CLIENT, 1646 SslConnectionType::STREAM 1647 )); 1648 p!(ctx.set_peer_domain_name("google.com")); 1649 let stream = p!(TcpStream::connect("google.com:443")); 1650 let _ = ctx.handshake(ExplodingStream(stream)); 1651 } 1652 1653 #[test] 1654 fn zero_length_buffers() { 1655 let mut ctx = p!(SslContext::new( 1656 SslProtocolSide::CLIENT, 1657 SslConnectionType::STREAM 1658 )); 1659 p!(ctx.set_peer_domain_name("google.com")); 1660 let stream = p!(TcpStream::connect("google.com:443")); 1661 let mut stream = ctx.handshake(stream).unwrap(); 1662 assert_eq!(stream.write(b"").unwrap(), 0); 1663 assert_eq!(stream.read(&mut []).unwrap(), 0); 1664 } 1665 } 1666