1 extern crate libc;
2 extern crate security_framework;
3 extern crate security_framework_sys;
4 extern crate tempfile;
5 
6 use self::security_framework::base;
7 use self::security_framework::certificate::SecCertificate;
8 use self::security_framework::identity::SecIdentity;
9 use self::security_framework::import_export::{ImportedIdentity, Pkcs12ImportOptions};
10 use self::security_framework::secure_transport::{
11     self, ClientBuilder, SslConnectionType, SslContext, SslProtocol, SslProtocolSide,
12 };
13 use self::security_framework_sys::base::errSecIO;
14 use self::tempfile::TempDir;
15 use std::error;
16 use std::fmt;
17 use std::io;
18 use std::sync::Mutex;
19 use std::sync::Once;
20 
21 #[cfg(not(target_os = "ios"))]
22 use self::security_framework::os::macos::certificate::{PropertyType, SecCertificateExt};
23 #[cfg(not(target_os = "ios"))]
24 use self::security_framework::os::macos::certificate_oids::CertificateOid;
25 #[cfg(not(target_os = "ios"))]
26 use self::security_framework::os::macos::import_export::{ImportOptions, SecItems};
27 #[cfg(not(target_os = "ios"))]
28 use self::security_framework::os::macos::keychain::{self, KeychainSettings, SecKeychain};
29 #[cfg(not(target_os = "ios"))]
30 use self::security_framework_sys::base::errSecParam;
31 
32 use {Protocol, TlsAcceptorBuilder, TlsConnectorBuilder};
33 
/// Guards the one-time registration (in the macOS `Identity::import_options`) of the
/// process-exit hook that clears `TEMP_KEYCHAIN`.
///
/// NOTE(review): only referenced by the non-iOS `import_options`, so it is unused on iOS.
static SET_AT_EXIT: Once = Once::new();

#[cfg(not(target_os = "ios"))]
lazy_static! {
    /// Process-wide temporary keychain used for PKCS #12 imports on macOS, together with the
    /// temporary directory that holds its file. `None` until the first import creates it;
    /// reset to `None` by the exit hook registered through `SET_AT_EXIT`.
    static ref TEMP_KEYCHAIN: Mutex<Option<(SecKeychain, TempDir)>> = Mutex::new(None);
}
40 
convert_protocol(protocol: Protocol) -> SslProtocol41 fn convert_protocol(protocol: Protocol) -> SslProtocol {
42     match protocol {
43         Protocol::Sslv3 => SslProtocol::SSL3,
44         Protocol::Tlsv10 => SslProtocol::TLS1,
45         Protocol::Tlsv11 => SslProtocol::TLS11,
46         Protocol::Tlsv12 => SslProtocol::TLS12,
47         Protocol::__NonExhaustive => unreachable!(),
48     }
49 }
50 
51 pub struct Error(base::Error);
52 
53 impl error::Error for Error {
source(&self) -> Option<&(dyn error::Error + 'static)>54     fn source(&self) -> Option<&(dyn error::Error + 'static)> {
55         error::Error::source(&self.0)
56     }
57 }
58 
59 impl fmt::Display for Error {
fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result60     fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
61         fmt::Display::fmt(&self.0, fmt)
62     }
63 }
64 
65 impl fmt::Debug for Error {
fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result66     fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
67         fmt::Debug::fmt(&self.0, fmt)
68     }
69 }
70 
71 impl From<base::Error> for Error {
from(error: base::Error) -> Error72     fn from(error: base::Error) -> Error {
73         Error(error)
74     }
75 }
76 
77 #[derive(Clone)]
78 pub struct Identity {
79     identity: SecIdentity,
80     chain: Vec<SecCertificate>,
81 }
82 
83 impl Identity {
from_pkcs12(buf: &[u8], pass: &str) -> Result<Identity, Error>84     pub fn from_pkcs12(buf: &[u8], pass: &str) -> Result<Identity, Error> {
85         let mut imports = Identity::import_options(buf, pass)?;
86         let import = imports.pop().unwrap();
87 
88         let identity = import
89             .identity
90             .expect("Pkcs12 files must include an identity");
91 
92         // FIXME: Compare the certificates for equality using CFEqual
93         let identity_cert = identity.certificate()?.to_der();
94 
95         Ok(Identity {
96             identity,
97             chain: import
98                 .cert_chain
99                 .unwrap_or(vec![])
100                 .into_iter()
101                 .filter(|c| c.to_der() != identity_cert)
102                 .collect(),
103         })
104     }
105 
106     #[cfg(not(target_os = "ios"))]
import_options(buf: &[u8], pass: &str) -> Result<Vec<ImportedIdentity>, Error>107     fn import_options(buf: &[u8], pass: &str) -> Result<Vec<ImportedIdentity>, Error> {
108         SET_AT_EXIT.call_once(|| {
109             extern "C" fn atexit() {
110                 *TEMP_KEYCHAIN.lock().unwrap() = None;
111             }
112             unsafe {
113                 libc::atexit(atexit);
114             }
115         });
116 
117         let keychain = match *TEMP_KEYCHAIN.lock().unwrap() {
118             Some((ref keychain, _)) => keychain.clone(),
119             ref mut lock @ None => {
120                 let dir = TempDir::new().map_err(|_| Error(base::Error::from(errSecIO)))?;
121 
122                 let mut keychain = keychain::CreateOptions::new()
123                     .password(pass)
124                     .create(dir.path().join("tmp.keychain"))?;
125                 keychain.set_settings(&KeychainSettings::new())?;
126 
127                 *lock = Some((keychain.clone(), dir));
128                 keychain
129             }
130         };
131         let imports = Pkcs12ImportOptions::new()
132             .passphrase(pass)
133             .keychain(keychain)
134             .import(buf)?;
135         Ok(imports)
136     }
137 
138     #[cfg(target_os = "ios")]
import_options(buf: &[u8], pass: &str) -> Result<Vec<ImportedIdentity>, Error>139     fn import_options(buf: &[u8], pass: &str) -> Result<Vec<ImportedIdentity>, Error> {
140         let imports = Pkcs12ImportOptions::new().passphrase(pass).import(buf)?;
141         Ok(imports)
142     }
143 }
144 
145 #[derive(Clone)]
146 pub struct Certificate(SecCertificate);
147 
148 impl Certificate {
from_der(buf: &[u8]) -> Result<Certificate, Error>149     pub fn from_der(buf: &[u8]) -> Result<Certificate, Error> {
150         let cert = SecCertificate::from_der(buf)?;
151         Ok(Certificate(cert))
152     }
153 
154     #[cfg(not(target_os = "ios"))]
from_pem(buf: &[u8]) -> Result<Certificate, Error>155     pub fn from_pem(buf: &[u8]) -> Result<Certificate, Error> {
156         let mut items = SecItems::default();
157         ImportOptions::new().items(&mut items).import(buf)?;
158         if items.certificates.len() == 1 && items.identities.is_empty() && items.keys.is_empty() {
159             Ok(Certificate(items.certificates.pop().unwrap()))
160         } else {
161             Err(Error(base::Error::from(errSecParam)))
162         }
163     }
164 
165     #[cfg(target_os = "ios")]
from_pem(buf: &[u8]) -> Result<Certificate, Error>166     pub fn from_pem(buf: &[u8]) -> Result<Certificate, Error> {
167         panic!("Not implemented on iOS");
168     }
169 
to_der(&self) -> Result<Vec<u8>, Error>170     pub fn to_der(&self) -> Result<Vec<u8>, Error> {
171         Ok(self.0.to_der())
172     }
173 }
174 
175 pub enum HandshakeError<S> {
176     WouldBlock(MidHandshakeTlsStream<S>),
177     Failure(Error),
178 }
179 
180 impl<S> From<secure_transport::ClientHandshakeError<S>> for HandshakeError<S> {
from(e: secure_transport::ClientHandshakeError<S>) -> HandshakeError<S>181     fn from(e: secure_transport::ClientHandshakeError<S>) -> HandshakeError<S> {
182         match e {
183             secure_transport::ClientHandshakeError::Failure(e) => HandshakeError::Failure(e.into()),
184             secure_transport::ClientHandshakeError::Interrupted(s) => {
185                 HandshakeError::WouldBlock(MidHandshakeTlsStream::Client(s))
186             }
187         }
188     }
189 }
190 
191 impl<S> From<base::Error> for HandshakeError<S> {
from(e: base::Error) -> HandshakeError<S>192     fn from(e: base::Error) -> HandshakeError<S> {
193         HandshakeError::Failure(e.into())
194     }
195 }
196 
197 pub enum MidHandshakeTlsStream<S> {
198     Server(
199         secure_transport::MidHandshakeSslStream<S>,
200         Option<SecCertificate>,
201     ),
202     Client(secure_transport::MidHandshakeClientBuilder<S>),
203 }
204 
205 impl<S> fmt::Debug for MidHandshakeTlsStream<S>
206 where
207     S: fmt::Debug,
208 {
fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result209     fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
210         match *self {
211             MidHandshakeTlsStream::Server(ref s, _) => s.fmt(fmt),
212             MidHandshakeTlsStream::Client(ref s) => s.fmt(fmt),
213         }
214     }
215 }
216 
217 impl<S> MidHandshakeTlsStream<S> {
get_ref(&self) -> &S218     pub fn get_ref(&self) -> &S {
219         match *self {
220             MidHandshakeTlsStream::Server(ref s, _) => s.get_ref(),
221             MidHandshakeTlsStream::Client(ref s) => s.get_ref(),
222         }
223     }
224 
get_mut(&mut self) -> &mut S225     pub fn get_mut(&mut self) -> &mut S {
226         match *self {
227             MidHandshakeTlsStream::Server(ref mut s, _) => s.get_mut(),
228             MidHandshakeTlsStream::Client(ref mut s) => s.get_mut(),
229         }
230     }
231 }
232 
233 impl<S> MidHandshakeTlsStream<S>
234 where
235     S: io::Read + io::Write,
236 {
handshake(self) -> Result<TlsStream<S>, HandshakeError<S>>237     pub fn handshake(self) -> Result<TlsStream<S>, HandshakeError<S>> {
238         match self {
239             MidHandshakeTlsStream::Server(s, cert) => match s.handshake() {
240                 Ok(stream) => Ok(TlsStream { stream, cert }),
241                 Err(secure_transport::HandshakeError::Failure(e)) => {
242                     Err(HandshakeError::Failure(Error(e)))
243                 }
244                 Err(secure_transport::HandshakeError::Interrupted(s)) => Err(
245                     HandshakeError::WouldBlock(MidHandshakeTlsStream::Server(s, cert)),
246                 ),
247             },
248             MidHandshakeTlsStream::Client(s) => match s.handshake() {
249                 Ok(stream) => Ok(TlsStream { stream, cert: None }),
250                 Err(e) => Err(e.into()),
251             },
252         }
253     }
254 }
255 
256 #[derive(Clone)]
257 pub struct TlsConnector {
258     identity: Option<Identity>,
259     min_protocol: Option<Protocol>,
260     max_protocol: Option<Protocol>,
261     roots: Vec<SecCertificate>,
262     use_sni: bool,
263     danger_accept_invalid_hostnames: bool,
264     danger_accept_invalid_certs: bool,
265 }
266 
267 impl TlsConnector {
new(builder: &TlsConnectorBuilder) -> Result<TlsConnector, Error>268     pub fn new(builder: &TlsConnectorBuilder) -> Result<TlsConnector, Error> {
269         Ok(TlsConnector {
270             identity: builder.identity.as_ref().map(|i| i.0.clone()),
271             min_protocol: builder.min_protocol,
272             max_protocol: builder.max_protocol,
273             roots: builder
274                 .root_certificates
275                 .iter()
276                 .map(|c| (c.0).0.clone())
277                 .collect(),
278             use_sni: builder.use_sni,
279             danger_accept_invalid_hostnames: builder.accept_invalid_hostnames,
280             danger_accept_invalid_certs: builder.accept_invalid_certs,
281         })
282     }
283 
connect<S>(&self, domain: &str, stream: S) -> Result<TlsStream<S>, HandshakeError<S>> where S: io::Read + io::Write,284     pub fn connect<S>(&self, domain: &str, stream: S) -> Result<TlsStream<S>, HandshakeError<S>>
285     where
286         S: io::Read + io::Write,
287     {
288         let mut builder = ClientBuilder::new();
289         if let Some(min) = self.min_protocol {
290             builder.protocol_min(convert_protocol(min));
291         }
292         if let Some(max) = self.max_protocol {
293             builder.protocol_max(convert_protocol(max));
294         }
295         if let Some(identity) = self.identity.as_ref() {
296             builder.identity(&identity.identity, &identity.chain);
297         }
298         builder.anchor_certificates(&self.roots);
299         builder.use_sni(self.use_sni);
300         builder.danger_accept_invalid_hostnames(self.danger_accept_invalid_hostnames);
301         builder.danger_accept_invalid_certs(self.danger_accept_invalid_certs);
302 
303         match builder.handshake(domain, stream) {
304             Ok(stream) => Ok(TlsStream { stream, cert: None }),
305             Err(e) => Err(e.into()),
306         }
307     }
308 }
309 
310 #[derive(Clone)]
311 pub struct TlsAcceptor {
312     identity: Identity,
313     min_protocol: Option<Protocol>,
314     max_protocol: Option<Protocol>,
315 }
316 
317 impl TlsAcceptor {
new(builder: &TlsAcceptorBuilder) -> Result<TlsAcceptor, Error>318     pub fn new(builder: &TlsAcceptorBuilder) -> Result<TlsAcceptor, Error> {
319         Ok(TlsAcceptor {
320             identity: builder.identity.0.clone(),
321             min_protocol: builder.min_protocol,
322             max_protocol: builder.max_protocol,
323         })
324     }
325 
accept<S>(&self, stream: S) -> Result<TlsStream<S>, HandshakeError<S>> where S: io::Read + io::Write,326     pub fn accept<S>(&self, stream: S) -> Result<TlsStream<S>, HandshakeError<S>>
327     where
328         S: io::Read + io::Write,
329     {
330         let mut ctx = SslContext::new(SslProtocolSide::SERVER, SslConnectionType::STREAM)?;
331 
332         if let Some(min) = self.min_protocol {
333             ctx.set_protocol_version_min(convert_protocol(min))?;
334         }
335         if let Some(max) = self.max_protocol {
336             ctx.set_protocol_version_max(convert_protocol(max))?;
337         }
338         ctx.set_certificate(&self.identity.identity, &self.identity.chain)?;
339         let cert = Some(self.identity.identity.certificate()?);
340         match ctx.handshake(stream) {
341             Ok(stream) => Ok(TlsStream { stream, cert }),
342             Err(secure_transport::HandshakeError::Failure(e)) => {
343                 Err(HandshakeError::Failure(Error(e)))
344             }
345             Err(secure_transport::HandshakeError::Interrupted(s)) => Err(
346                 HandshakeError::WouldBlock(MidHandshakeTlsStream::Server(s, cert)),
347             ),
348         }
349     }
350 }
351 
352 pub struct TlsStream<S> {
353     stream: secure_transport::SslStream<S>,
354     cert: Option<SecCertificate>,
355 }
356 
357 impl<S: fmt::Debug> fmt::Debug for TlsStream<S> {
fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result358     fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
359         fmt::Debug::fmt(&self.stream, fmt)
360     }
361 }
362 
363 impl<S> TlsStream<S> {
get_ref(&self) -> &S364     pub fn get_ref(&self) -> &S {
365         self.stream.get_ref()
366     }
367 
get_mut(&mut self) -> &mut S368     pub fn get_mut(&mut self) -> &mut S {
369         self.stream.get_mut()
370     }
371 }
372 
373 impl<S: io::Read + io::Write> TlsStream<S> {
buffered_read_size(&self) -> Result<usize, Error>374     pub fn buffered_read_size(&self) -> Result<usize, Error> {
375         Ok(self.stream.context().buffered_read_size()?)
376     }
377 
peer_certificate(&self) -> Result<Option<Certificate>, Error>378     pub fn peer_certificate(&self) -> Result<Option<Certificate>, Error> {
379         let trust = match self.stream.context().peer_trust2()? {
380             Some(trust) => trust,
381             None => return Ok(None),
382         };
383         trust.evaluate()?;
384 
385         Ok(trust.certificate_at_index(0).map(Certificate))
386     }
387 
388     #[cfg(target_os = "ios")]
tls_server_end_point(&self) -> Result<Option<Vec<u8>>, Error>389     pub fn tls_server_end_point(&self) -> Result<Option<Vec<u8>>, Error> {
390         Ok(None)
391     }
392 
393     #[cfg(not(target_os = "ios"))]
tls_server_end_point(&self) -> Result<Option<Vec<u8>>, Error>394     pub fn tls_server_end_point(&self) -> Result<Option<Vec<u8>>, Error> {
395         let cert = match self.cert {
396             Some(ref cert) => cert.clone(),
397             None => match self.peer_certificate()? {
398                 Some(cert) => cert.0,
399                 None => return Ok(None),
400             },
401         };
402 
403         let property = match cert
404             .properties(Some(&[CertificateOid::x509_v1_signature_algorithm()]))
405             .ok()
406             .and_then(|p| p.get(CertificateOid::x509_v1_signature_algorithm()))
407         {
408             Some(property) => property,
409             None => return Ok(None),
410         };
411 
412         let section = match property.get() {
413             PropertyType::Section(section) => section,
414             _ => return Ok(None),
415         };
416 
417         let algorithm = match section
418             .iter()
419             .filter(|p| p.label().to_string() == "Algorithm")
420             .next()
421         {
422             Some(property) => property,
423             None => return Ok(None),
424         };
425 
426         let algorithm = match algorithm.get() {
427             PropertyType::String(algorithm) => algorithm,
428             _ => return Ok(None),
429         };
430 
431         let digest = match &*algorithm.to_string() {
432             // MD5
433             "1.2.840.113549.2.5" | "1.2.840.113549.1.1.4" | "1.3.14.3.2.3" => Digest::Sha256,
434             // SHA-1
435             "1.3.14.3.2.26"
436             | "1.3.14.3.2.15"
437             | "1.2.840.113549.1.1.5"
438             | "1.3.14.3.2.29"
439             | "1.2.840.10040.4.3"
440             | "1.3.14.3.2.13"
441             | "1.2.840.10045.4.1" => Digest::Sha256,
442             // SHA-224
443             "2.16.840.1.101.3.4.2.4"
444             | "1.2.840.113549.1.1.14"
445             | "2.16.840.1.101.3.4.3.1"
446             | "1.2.840.10045.4.3.1" => Digest::Sha224,
447             // SHA-256
448             "2.16.840.1.101.3.4.2.1" | "1.2.840.113549.1.1.11" | "1.2.840.10045.4.3.2" => {
449                 Digest::Sha256
450             }
451             // SHA-384
452             "2.16.840.1.101.3.4.2.2" | "1.2.840.113549.1.1.12" | "1.2.840.10045.4.3.3" => {
453                 Digest::Sha384
454             }
455             // SHA-512
456             "2.16.840.1.101.3.4.2.3" | "1.2.840.113549.1.1.13" | "1.2.840.10045.4.3.4" => {
457                 Digest::Sha512
458             }
459             _ => return Ok(None),
460         };
461 
462         let der = cert.to_der();
463         Ok(Some(digest.hash(&der)))
464     }
465 
shutdown(&mut self) -> io::Result<()>466     pub fn shutdown(&mut self) -> io::Result<()> {
467         self.stream.close()?;
468         Ok(())
469     }
470 }
471 
472 impl<S: io::Read + io::Write> io::Read for TlsStream<S> {
read(&mut self, buf: &mut [u8]) -> io::Result<usize>473     fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
474         self.stream.read(buf)
475     }
476 }
477 
478 impl<S: io::Read + io::Write> io::Write for TlsStream<S> {
write(&mut self, buf: &[u8]) -> io::Result<usize>479     fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
480         self.stream.write(buf)
481     }
482 
flush(&mut self) -> io::Result<()>483     fn flush(&mut self) -> io::Result<()> {
484         self.stream.flush()
485     }
486 }
487 
/// Hash algorithms selectable for the `tls-server-end-point` channel binding.
enum Digest {
    Sha224,
    Sha256,
    Sha384,
    Sha512,
}

/// Common signature of the CommonCrypto one-shot digest functions declared below.
type CcDigestFn = unsafe extern "C" fn(*const u8, CC_LONG, *mut u8) -> *mut u8;

impl Digest {
    /// Returns the CommonCrypto function implementing this digest and its output length.
    fn primitive(&self) -> (CcDigestFn, usize) {
        match *self {
            Digest::Sha224 => (CC_SHA224 as CcDigestFn, CC_SHA224_DIGEST_LENGTH),
            Digest::Sha256 => (CC_SHA256 as CcDigestFn, CC_SHA256_DIGEST_LENGTH),
            Digest::Sha384 => (CC_SHA384 as CcDigestFn, CC_SHA384_DIGEST_LENGTH),
            Digest::Sha512 => (CC_SHA512 as CcDigestFn, CC_SHA512_DIGEST_LENGTH),
        }
    }

    /// Hashes `data` with this algorithm and returns the digest bytes.
    ///
    /// # Panics
    ///
    /// Panics if `data.len()` does not fit in `CC_LONG`, rather than silently truncating it.
    fn hash(&self, data: &[u8]) -> Vec<u8> {
        assert!(data.len() <= CC_LONG::max_value() as usize);
        let (digest_fn, digest_len) = self.primitive();
        let mut out = vec![0; digest_len];
        // SAFETY:
        // - `data` is valid for reads of `data.len()` bytes.
        // - The assertion above guarantees the length cast to `CC_LONG` is lossless.
        // - `out` is a writable buffer of exactly the digest length this function writes.
        unsafe {
            digest_fn(data.as_ptr(), data.len() as CC_LONG, out.as_mut_ptr());
        }
        out
    }
}

// FIXME ideally we'd pull these in from elsewhere
const CC_SHA224_DIGEST_LENGTH: usize = 28;
const CC_SHA256_DIGEST_LENGTH: usize = 32;
const CC_SHA384_DIGEST_LENGTH: usize = 48;
const CC_SHA512_DIGEST_LENGTH: usize = 64;
// Mirrors CommonCrypto's C typedef name, hence the non-Rust casing.
#[allow(non_camel_case_types)]
type CC_LONG = u32;

extern "C" {
    fn CC_SHA224(data: *const u8, len: CC_LONG, md: *mut u8) -> *mut u8;
    fn CC_SHA256(data: *const u8, len: CC_LONG, md: *mut u8) -> *mut u8;
    fn CC_SHA384(data: *const u8, len: CC_LONG, md: *mut u8) -> *mut u8;
    fn CC_SHA512(data: *const u8, len: CC_LONG, md: *mut u8) -> *mut u8;
}
539