1 extern crate libc;
2 extern crate security_framework;
3 extern crate security_framework_sys;
4 extern crate tempfile;
5 
6 use self::security_framework::base;
7 use self::security_framework::certificate::SecCertificate;
8 use self::security_framework::identity::SecIdentity;
9 use self::security_framework::import_export::{ImportedIdentity, Pkcs12ImportOptions};
10 use self::security_framework::secure_transport::{
11     self, ClientBuilder, SslConnectionType, SslContext, SslProtocol, SslProtocolSide,
12 };
13 use self::security_framework_sys::base::{errSecIO, errSecParam};
14 use self::tempfile::TempDir;
15 use std::error;
16 use std::fmt;
17 use std::io;
18 use std::str;
19 use std::sync::Mutex;
20 use std::sync::Once;
21 
22 #[cfg(not(target_os = "ios"))]
23 use self::security_framework::os::macos::certificate::{PropertyType, SecCertificateExt};
24 #[cfg(not(target_os = "ios"))]
25 use self::security_framework::os::macos::certificate_oids::CertificateOid;
26 #[cfg(not(target_os = "ios"))]
27 use self::security_framework::os::macos::import_export::{
28     ImportOptions, Pkcs12ImportOptionsExt, SecItems,
29 };
30 #[cfg(not(target_os = "ios"))]
31 use self::security_framework::os::macos::keychain::{self, KeychainSettings, SecKeychain};
32 
33 use {Protocol, TlsAcceptorBuilder, TlsConnectorBuilder};
34 
// Guards the one-time registration of the process-exit hook; it is driven through
// `call_once` in `Identity::import_options`.
static SET_AT_EXIT: Once = Once::new();

// Process-wide slot for the temporary keychain used by PKCS#12 imports, stored
// together with the temporary directory that contains the keychain file. Starts
// out as `None`, is filled on the first import and is reset to `None` by the
// exit hook. Not compiled on iOS.
#[cfg(not(target_os = "ios"))]
lazy_static! {
    static ref TEMP_KEYCHAIN: Mutex<Option<(SecKeychain, TempDir)>> = Mutex::new(None);
}
41 
convert_protocol(protocol: Protocol) -> SslProtocol42 fn convert_protocol(protocol: Protocol) -> SslProtocol {
43     match protocol {
44         Protocol::Sslv3 => SslProtocol::SSL3,
45         Protocol::Tlsv10 => SslProtocol::TLS1,
46         Protocol::Tlsv11 => SslProtocol::TLS11,
47         Protocol::Tlsv12 => SslProtocol::TLS12,
48         Protocol::__NonExhaustive => unreachable!(),
49     }
50 }
51 
52 pub struct Error(base::Error);
53 
54 impl error::Error for Error {
source(&self) -> Option<&(dyn error::Error + 'static)>55     fn source(&self) -> Option<&(dyn error::Error + 'static)> {
56         error::Error::source(&self.0)
57     }
58 }
59 
60 impl fmt::Display for Error {
fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result61     fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
62         fmt::Display::fmt(&self.0, fmt)
63     }
64 }
65 
66 impl fmt::Debug for Error {
fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result67     fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
68         fmt::Debug::fmt(&self.0, fmt)
69     }
70 }
71 
72 impl From<base::Error> for Error {
from(error: base::Error) -> Error73     fn from(error: base::Error) -> Error {
74         Error(error)
75     }
76 }
77 
78 #[derive(Clone, Debug)]
79 pub struct Identity {
80     identity: SecIdentity,
81     chain: Vec<SecCertificate>,
82 }
83 
84 impl Identity {
from_pkcs12(buf: &[u8], pass: &str) -> Result<Identity, Error>85     pub fn from_pkcs12(buf: &[u8], pass: &str) -> Result<Identity, Error> {
86         let mut imports = Identity::import_options(buf, pass)?;
87         let import = imports.pop().unwrap();
88 
89         let identity = import
90             .identity
91             .expect("Pkcs12 files must include an identity");
92 
93         // FIXME: Compare the certificates for equality using CFEqual
94         let identity_cert = identity.certificate()?.to_der();
95 
96         Ok(Identity {
97             identity,
98             chain: import
99                 .cert_chain
100                 .unwrap_or(vec![])
101                 .into_iter()
102                 .filter(|c| c.to_der() != identity_cert)
103                 .collect(),
104         })
105     }
106 
107     #[cfg(not(target_os = "ios"))]
import_options(buf: &[u8], pass: &str) -> Result<Vec<ImportedIdentity>, Error>108     fn import_options(buf: &[u8], pass: &str) -> Result<Vec<ImportedIdentity>, Error> {
109         SET_AT_EXIT.call_once(|| {
110             extern "C" fn atexit() {
111                 *TEMP_KEYCHAIN.lock().unwrap() = None;
112             }
113             unsafe {
114                 libc::atexit(atexit);
115             }
116         });
117 
118         let keychain = match *TEMP_KEYCHAIN.lock().unwrap() {
119             Some((ref keychain, _)) => keychain.clone(),
120             ref mut lock @ None => {
121                 let dir = TempDir::new().map_err(|_| Error(base::Error::from(errSecIO)))?;
122 
123                 let mut keychain = keychain::CreateOptions::new()
124                     .password(pass)
125                     .create(dir.path().join("tmp.keychain"))?;
126                 keychain.set_settings(&KeychainSettings::new())?;
127 
128                 *lock = Some((keychain.clone(), dir));
129                 keychain
130             }
131         };
132         let mut import_opts = Pkcs12ImportOptions::new();
133         // Method shadowed by deprecated method.
134         <Pkcs12ImportOptions as Pkcs12ImportOptionsExt>::keychain(&mut import_opts, keychain);
135         let imports = import_opts.passphrase(pass).import(buf)?;
136         Ok(imports)
137     }
138 
139     #[cfg(target_os = "ios")]
import_options(buf: &[u8], pass: &str) -> Result<Vec<ImportedIdentity>, Error>140     fn import_options(buf: &[u8], pass: &str) -> Result<Vec<ImportedIdentity>, Error> {
141         let imports = Pkcs12ImportOptions::new().passphrase(pass).import(buf)?;
142         Ok(imports)
143     }
144 }
145 
146 #[derive(Clone)]
147 pub struct Certificate(SecCertificate);
148 
149 impl Certificate {
from_der(buf: &[u8]) -> Result<Certificate, Error>150     pub fn from_der(buf: &[u8]) -> Result<Certificate, Error> {
151         let cert = SecCertificate::from_der(buf)?;
152         Ok(Certificate(cert))
153     }
154 
155     #[cfg(not(target_os = "ios"))]
from_pem(buf: &[u8]) -> Result<Certificate, Error>156     pub fn from_pem(buf: &[u8]) -> Result<Certificate, Error> {
157         let mut items = SecItems::default();
158         ImportOptions::new().items(&mut items).import(buf)?;
159         if items.certificates.len() == 1 && items.identities.is_empty() && items.keys.is_empty() {
160             Ok(Certificate(items.certificates.pop().unwrap()))
161         } else {
162             Err(Error(base::Error::from(errSecParam)))
163         }
164     }
165 
166     #[cfg(target_os = "ios")]
from_pem(buf: &[u8]) -> Result<Certificate, Error>167     pub fn from_pem(buf: &[u8]) -> Result<Certificate, Error> {
168         panic!("Not implemented on iOS");
169     }
170 
to_der(&self) -> Result<Vec<u8>, Error>171     pub fn to_der(&self) -> Result<Vec<u8>, Error> {
172         Ok(self.0.to_der())
173     }
174 }
175 
176 pub enum HandshakeError<S> {
177     WouldBlock(MidHandshakeTlsStream<S>),
178     Failure(Error),
179 }
180 
181 impl<S> From<secure_transport::ClientHandshakeError<S>> for HandshakeError<S> {
from(e: secure_transport::ClientHandshakeError<S>) -> HandshakeError<S>182     fn from(e: secure_transport::ClientHandshakeError<S>) -> HandshakeError<S> {
183         match e {
184             secure_transport::ClientHandshakeError::Failure(e) => HandshakeError::Failure(e.into()),
185             secure_transport::ClientHandshakeError::Interrupted(s) => {
186                 HandshakeError::WouldBlock(MidHandshakeTlsStream::Client(s))
187             }
188         }
189     }
190 }
191 
192 impl<S> From<base::Error> for HandshakeError<S> {
from(e: base::Error) -> HandshakeError<S>193     fn from(e: base::Error) -> HandshakeError<S> {
194         HandshakeError::Failure(e.into())
195     }
196 }
197 
198 pub enum MidHandshakeTlsStream<S> {
199     Server(
200         secure_transport::MidHandshakeSslStream<S>,
201         Option<SecCertificate>,
202     ),
203     Client(secure_transport::MidHandshakeClientBuilder<S>),
204 }
205 
206 impl<S> fmt::Debug for MidHandshakeTlsStream<S>
207 where
208     S: fmt::Debug,
209 {
fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result210     fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
211         match *self {
212             MidHandshakeTlsStream::Server(ref s, _) => s.fmt(fmt),
213             MidHandshakeTlsStream::Client(ref s) => s.fmt(fmt),
214         }
215     }
216 }
217 
218 impl<S> MidHandshakeTlsStream<S> {
get_ref(&self) -> &S219     pub fn get_ref(&self) -> &S {
220         match *self {
221             MidHandshakeTlsStream::Server(ref s, _) => s.get_ref(),
222             MidHandshakeTlsStream::Client(ref s) => s.get_ref(),
223         }
224     }
225 
get_mut(&mut self) -> &mut S226     pub fn get_mut(&mut self) -> &mut S {
227         match *self {
228             MidHandshakeTlsStream::Server(ref mut s, _) => s.get_mut(),
229             MidHandshakeTlsStream::Client(ref mut s) => s.get_mut(),
230         }
231     }
232 }
233 
234 impl<S> MidHandshakeTlsStream<S>
235 where
236     S: io::Read + io::Write,
237 {
handshake(self) -> Result<TlsStream<S>, HandshakeError<S>>238     pub fn handshake(self) -> Result<TlsStream<S>, HandshakeError<S>> {
239         match self {
240             MidHandshakeTlsStream::Server(s, cert) => match s.handshake() {
241                 Ok(stream) => Ok(TlsStream { stream, cert }),
242                 Err(secure_transport::HandshakeError::Failure(e)) => {
243                     Err(HandshakeError::Failure(Error(e)))
244                 }
245                 Err(secure_transport::HandshakeError::Interrupted(s)) => Err(
246                     HandshakeError::WouldBlock(MidHandshakeTlsStream::Server(s, cert)),
247                 ),
248             },
249             MidHandshakeTlsStream::Client(s) => match s.handshake() {
250                 Ok(stream) => Ok(TlsStream { stream, cert: None }),
251                 Err(e) => Err(e.into()),
252             },
253         }
254     }
255 }
256 
257 #[derive(Clone, Debug)]
258 pub struct TlsConnector {
259     identity: Option<Identity>,
260     min_protocol: Option<Protocol>,
261     max_protocol: Option<Protocol>,
262     roots: Vec<SecCertificate>,
263     use_sni: bool,
264     danger_accept_invalid_hostnames: bool,
265     danger_accept_invalid_certs: bool,
266     disable_built_in_roots: bool,
267     #[cfg(feature = "alpn")]
268     alpn: Vec<String>,
269 }
270 
271 impl TlsConnector {
new(builder: &TlsConnectorBuilder) -> Result<TlsConnector, Error>272     pub fn new(builder: &TlsConnectorBuilder) -> Result<TlsConnector, Error> {
273         Ok(TlsConnector {
274             identity: builder.identity.as_ref().map(|i| i.0.clone()),
275             min_protocol: builder.min_protocol,
276             max_protocol: builder.max_protocol,
277             roots: builder
278                 .root_certificates
279                 .iter()
280                 .map(|c| (c.0).0.clone())
281                 .collect(),
282             use_sni: builder.use_sni,
283             danger_accept_invalid_hostnames: builder.accept_invalid_hostnames,
284             danger_accept_invalid_certs: builder.accept_invalid_certs,
285             disable_built_in_roots: builder.disable_built_in_roots,
286             #[cfg(feature = "alpn")]
287             alpn: builder.alpn.clone(),
288         })
289     }
290 
connect<S>(&self, domain: &str, stream: S) -> Result<TlsStream<S>, HandshakeError<S>> where S: io::Read + io::Write,291     pub fn connect<S>(&self, domain: &str, stream: S) -> Result<TlsStream<S>, HandshakeError<S>>
292     where
293         S: io::Read + io::Write,
294     {
295         let mut builder = ClientBuilder::new();
296         if let Some(min) = self.min_protocol {
297             builder.protocol_min(convert_protocol(min));
298         }
299         if let Some(max) = self.max_protocol {
300             builder.protocol_max(convert_protocol(max));
301         }
302         if let Some(identity) = self.identity.as_ref() {
303             builder.identity(&identity.identity, &identity.chain);
304         }
305         builder.anchor_certificates(&self.roots);
306         builder.use_sni(self.use_sni);
307         builder.danger_accept_invalid_hostnames(self.danger_accept_invalid_hostnames);
308         builder.danger_accept_invalid_certs(self.danger_accept_invalid_certs);
309         builder.trust_anchor_certificates_only(self.disable_built_in_roots);
310 
311         #[cfg(feature = "alpn")]
312         {
313             if !self.alpn.is_empty() {
314                 builder.alpn_protocols(&self.alpn.iter().map(String::as_str).collect::<Vec<_>>());
315             }
316         }
317 
318         match builder.handshake(domain, stream) {
319             Ok(stream) => Ok(TlsStream { stream, cert: None }),
320             Err(e) => Err(e.into()),
321         }
322     }
323 }
324 
325 #[derive(Clone)]
326 pub struct TlsAcceptor {
327     identity: Identity,
328     min_protocol: Option<Protocol>,
329     max_protocol: Option<Protocol>,
330 }
331 
332 impl TlsAcceptor {
new(builder: &TlsAcceptorBuilder) -> Result<TlsAcceptor, Error>333     pub fn new(builder: &TlsAcceptorBuilder) -> Result<TlsAcceptor, Error> {
334         Ok(TlsAcceptor {
335             identity: builder.identity.0.clone(),
336             min_protocol: builder.min_protocol,
337             max_protocol: builder.max_protocol,
338         })
339     }
340 
accept<S>(&self, stream: S) -> Result<TlsStream<S>, HandshakeError<S>> where S: io::Read + io::Write,341     pub fn accept<S>(&self, stream: S) -> Result<TlsStream<S>, HandshakeError<S>>
342     where
343         S: io::Read + io::Write,
344     {
345         let mut ctx = SslContext::new(SslProtocolSide::SERVER, SslConnectionType::STREAM)?;
346 
347         if let Some(min) = self.min_protocol {
348             ctx.set_protocol_version_min(convert_protocol(min))?;
349         }
350         if let Some(max) = self.max_protocol {
351             ctx.set_protocol_version_max(convert_protocol(max))?;
352         }
353         ctx.set_certificate(&self.identity.identity, &self.identity.chain)?;
354         let cert = Some(self.identity.identity.certificate()?);
355         match ctx.handshake(stream) {
356             Ok(stream) => Ok(TlsStream { stream, cert }),
357             Err(secure_transport::HandshakeError::Failure(e)) => {
358                 Err(HandshakeError::Failure(Error(e)))
359             }
360             Err(secure_transport::HandshakeError::Interrupted(s)) => Err(
361                 HandshakeError::WouldBlock(MidHandshakeTlsStream::Server(s, cert)),
362             ),
363         }
364     }
365 }
366 
367 pub struct TlsStream<S> {
368     stream: secure_transport::SslStream<S>,
369     cert: Option<SecCertificate>,
370 }
371 
372 impl<S: fmt::Debug> fmt::Debug for TlsStream<S> {
fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result373     fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
374         fmt::Debug::fmt(&self.stream, fmt)
375     }
376 }
377 
378 impl<S> TlsStream<S> {
get_ref(&self) -> &S379     pub fn get_ref(&self) -> &S {
380         self.stream.get_ref()
381     }
382 
get_mut(&mut self) -> &mut S383     pub fn get_mut(&mut self) -> &mut S {
384         self.stream.get_mut()
385     }
386 }
387 
388 impl<S: io::Read + io::Write> TlsStream<S> {
buffered_read_size(&self) -> Result<usize, Error>389     pub fn buffered_read_size(&self) -> Result<usize, Error> {
390         Ok(self.stream.context().buffered_read_size()?)
391     }
392 
peer_certificate(&self) -> Result<Option<Certificate>, Error>393     pub fn peer_certificate(&self) -> Result<Option<Certificate>, Error> {
394         let trust = match self.stream.context().peer_trust2()? {
395             Some(trust) => trust,
396             None => return Ok(None),
397         };
398         trust.evaluate()?;
399 
400         Ok(trust.certificate_at_index(0).map(Certificate))
401     }
402 
403     #[cfg(feature = "alpn")]
negotiated_alpn(&self) -> Result<Option<Vec<u8>>, Error>404     pub fn negotiated_alpn(&self) -> Result<Option<Vec<u8>>, Error> {
405         match self.stream.context().alpn_protocols() {
406             Ok(protocols) => {
407                 // Per RFC7301, "ProtocolNameList" MUST contain exactly one "ProtocolName".
408                 assert!(protocols.len() < 2);
409 
410                 if protocols.is_empty() {
411                     // Not sure this is actually possible.
412                     Ok(None)
413                 } else {
414                     Ok(Some(protocols.into_iter().next().unwrap().into_bytes()))
415                 }
416             }
417             // The macOS API appears to return `errSecParam` whenever no ALPN was negotiated, both
418             // when it isn't attempted and when it isn't successful.
419             Err(e) if e.code() == errSecParam => Ok(None),
420             Err(other) => Err(Error::from(other)),
421         }
422     }
423 
424     #[cfg(target_os = "ios")]
tls_server_end_point(&self) -> Result<Option<Vec<u8>>, Error>425     pub fn tls_server_end_point(&self) -> Result<Option<Vec<u8>>, Error> {
426         Ok(None)
427     }
428 
429     #[cfg(not(target_os = "ios"))]
tls_server_end_point(&self) -> Result<Option<Vec<u8>>, Error>430     pub fn tls_server_end_point(&self) -> Result<Option<Vec<u8>>, Error> {
431         let cert = match self.cert {
432             Some(ref cert) => cert.clone(),
433             None => match self.peer_certificate()? {
434                 Some(cert) => cert.0,
435                 None => return Ok(None),
436             },
437         };
438 
439         let property = match cert
440             .properties(Some(&[CertificateOid::x509_v1_signature_algorithm()]))
441             .ok()
442             .and_then(|p| p.get(CertificateOid::x509_v1_signature_algorithm()))
443         {
444             Some(property) => property,
445             None => return Ok(None),
446         };
447 
448         let section = match property.get() {
449             PropertyType::Section(section) => section,
450             _ => return Ok(None),
451         };
452 
453         let algorithm = match section
454             .iter()
455             .filter(|p| p.label().to_string() == "Algorithm")
456             .next()
457         {
458             Some(property) => property,
459             None => return Ok(None),
460         };
461 
462         let algorithm = match algorithm.get() {
463             PropertyType::String(algorithm) => algorithm,
464             _ => return Ok(None),
465         };
466 
467         let digest = match &*algorithm.to_string() {
468             // MD5
469             "1.2.840.113549.2.5" | "1.2.840.113549.1.1.4" | "1.3.14.3.2.3" => Digest::Sha256,
470             // SHA-1
471             "1.3.14.3.2.26"
472             | "1.3.14.3.2.15"
473             | "1.2.840.113549.1.1.5"
474             | "1.3.14.3.2.29"
475             | "1.2.840.10040.4.3"
476             | "1.3.14.3.2.13"
477             | "1.2.840.10045.4.1" => Digest::Sha256,
478             // SHA-224
479             "2.16.840.1.101.3.4.2.4"
480             | "1.2.840.113549.1.1.14"
481             | "2.16.840.1.101.3.4.3.1"
482             | "1.2.840.10045.4.3.1" => Digest::Sha224,
483             // SHA-256
484             "2.16.840.1.101.3.4.2.1" | "1.2.840.113549.1.1.11" | "1.2.840.10045.4.3.2" => {
485                 Digest::Sha256
486             }
487             // SHA-384
488             "2.16.840.1.101.3.4.2.2" | "1.2.840.113549.1.1.12" | "1.2.840.10045.4.3.3" => {
489                 Digest::Sha384
490             }
491             // SHA-512
492             "2.16.840.1.101.3.4.2.3" | "1.2.840.113549.1.1.13" | "1.2.840.10045.4.3.4" => {
493                 Digest::Sha512
494             }
495             _ => return Ok(None),
496         };
497 
498         let der = cert.to_der();
499         Ok(Some(digest.hash(&der)))
500     }
501 
shutdown(&mut self) -> io::Result<()>502     pub fn shutdown(&mut self) -> io::Result<()> {
503         self.stream.close()?;
504         Ok(())
505     }
506 }
507 
508 impl<S: io::Read + io::Write> io::Read for TlsStream<S> {
read(&mut self, buf: &mut [u8]) -> io::Result<usize>509     fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
510         self.stream.read(buf)
511     }
512 }
513 
514 impl<S: io::Read + io::Write> io::Write for TlsStream<S> {
write(&mut self, buf: &[u8]) -> io::Result<usize>515     fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
516         self.stream.write(buf)
517     }
518 
flush(&mut self) -> io::Result<()>519     fn flush(&mut self) -> io::Result<()> {
520         self.stream.flush()
521     }
522 }
523 
524 enum Digest {
525     Sha224,
526     Sha256,
527     Sha384,
528     Sha512,
529 }
530 
531 impl Digest {
hash(&self, data: &[u8]) -> Vec<u8>532     fn hash(&self, data: &[u8]) -> Vec<u8> {
533         unsafe {
534             assert!(data.len() <= CC_LONG::max_value() as usize);
535             match *self {
536                 Digest::Sha224 => {
537                     let mut buf = [0; CC_SHA224_DIGEST_LENGTH];
538                     CC_SHA224(data.as_ptr(), data.len() as CC_LONG, buf.as_mut_ptr());
539                     buf.to_vec()
540                 }
541                 Digest::Sha256 => {
542                     let mut buf = [0; CC_SHA256_DIGEST_LENGTH];
543                     CC_SHA256(data.as_ptr(), data.len() as CC_LONG, buf.as_mut_ptr());
544                     buf.to_vec()
545                 }
546                 Digest::Sha384 => {
547                     let mut buf = [0; CC_SHA384_DIGEST_LENGTH];
548                     CC_SHA384(data.as_ptr(), data.len() as CC_LONG, buf.as_mut_ptr());
549                     buf.to_vec()
550                 }
551                 Digest::Sha512 => {
552                     let mut buf = [0; CC_SHA512_DIGEST_LENGTH];
553                     CC_SHA512(data.as_ptr(), data.len() as CC_LONG, buf.as_mut_ptr());
554                     buf.to_vec()
555                 }
556             }
557         }
558     }
559 }
560 
// FIXME ideally we'd pull these in from elsewhere
// Digest lengths in bytes; `Digest::hash` uses them to size the output buffer
// passed to the matching `CC_SHA*` function.
const CC_SHA224_DIGEST_LENGTH: usize = 28;
const CC_SHA256_DIGEST_LENGTH: usize = 32;
const CC_SHA384_DIGEST_LENGTH: usize = 48;
const CC_SHA512_DIGEST_LENGTH: usize = 64;
// Length type taken by the `CC_SHA*` functions below.
#[allow(non_camel_case_types)]
type CC_LONG = u32;

// Foreign hash functions, declared by hand and without a `#[link]` attribute.
// `Digest::hash` passes the input pointer as `data`, its length as `len`, and a
// buffer of the matching digest length as `md`; the returned pointer is ignored.
// NOTE(review): presumably Apple's CommonCrypto one-shot digests — confirm the
// signatures against the system headers.
extern "C" {
    fn CC_SHA224(data: *const u8, len: CC_LONG, md: *mut u8) -> *mut u8;
    fn CC_SHA256(data: *const u8, len: CC_LONG, md: *mut u8) -> *mut u8;
    fn CC_SHA384(data: *const u8, len: CC_LONG, md: *mut u8) -> *mut u8;
    fn CC_SHA512(data: *const u8, len: CC_LONG, md: *mut u8) -> *mut u8;
}
575