1 extern crate schannel;
2 
3 use self::schannel::cert_context::{CertContext, HashAlgorithm};
4 use self::schannel::cert_store::{CertAdd, CertStore, Memory, PfxImportOptions};
5 use self::schannel::schannel_cred::{Direction, Protocol, SchannelCred};
6 use self::schannel::tls_stream;
7 use std::error;
8 use std::fmt;
9 use std::io;
10 use std::str;
11 
12 use {TlsAcceptorBuilder, TlsConnectorBuilder};
13 
/// Raw status code named after the SSPI constant `SEC_E_NO_CREDENTIALS`.
///
/// `TlsStream` compares `io::Error::raw_os_error()` against this value (cast to
/// `i32`) and reports a match as "no certificate available" instead of an error.
const SEC_E_NO_CREDENTIALS: u32 = 0x8009_030E;
15 
16 static PROTOCOLS: &'static [Protocol] = &[
17     Protocol::Ssl3,
18     Protocol::Tls10,
19     Protocol::Tls11,
20     Protocol::Tls12,
21 ];
22 
convert_protocols(min: Option<::Protocol>, max: Option<::Protocol>) -> &'static [Protocol]23 fn convert_protocols(min: Option<::Protocol>, max: Option<::Protocol>) -> &'static [Protocol] {
24     let mut protocols = PROTOCOLS;
25     if let Some(p) = max.and_then(|max| protocols.get(..=max as usize)) {
26         protocols = p;
27     }
28     if let Some(p) = min.and_then(|min| protocols.get(min as usize..)) {
29         protocols = p;
30     }
31     protocols
32 }
33 
/// Error type for this backend: a thin newtype around `io::Error`.
pub struct Error(io::Error);

impl error::Error for Error {
    /// Forwards to the wrapped `io::Error`'s own `source`, so the `io::Error`
    /// itself is not reported as a separate link in the cause chain.
    fn source(&self) -> Option<&(dyn error::Error + 'static)> {
        let inner: &io::Error = &self.0;
        <io::Error as error::Error>::source(inner)
    }
}

impl fmt::Display for Error {
    /// Displays exactly as the wrapped `io::Error` does.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let Error(ref inner) = *self;
        fmt::Display::fmt(inner, f)
    }
}

impl fmt::Debug for Error {
    /// Debug-formats exactly as the wrapped `io::Error` does (no wrapper name).
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        <io::Error as fmt::Debug>::fmt(&self.0, f)
    }
}

impl From<io::Error> for Error {
    fn from(inner: io::Error) -> Error {
        Error(inner)
    }
}
59 
60 #[derive(Clone)]
61 pub struct Identity {
62     cert: CertContext,
63 }
64 
65 impl Identity {
from_pkcs12(buf: &[u8], pass: &str) -> Result<Identity, Error>66     pub fn from_pkcs12(buf: &[u8], pass: &str) -> Result<Identity, Error> {
67         let store = PfxImportOptions::new().password(pass).import(buf)?;
68         let mut identity = None;
69 
70         for cert in store.certs() {
71             if cert
72                 .private_key()
73                 .silent(true)
74                 .compare_key(true)
75                 .acquire()
76                 .is_ok()
77             {
78                 identity = Some(cert);
79                 break;
80             }
81         }
82 
83         let identity = match identity {
84             Some(identity) => identity,
85             None => {
86                 return Err(io::Error::new(
87                     io::ErrorKind::InvalidInput,
88                     "No identity found in PKCS #12 archive",
89                 )
90                 .into());
91             }
92         };
93 
94         Ok(Identity { cert: identity })
95     }
96 }
97 
98 #[derive(Clone)]
99 pub struct Certificate(CertContext);
100 
101 impl Certificate {
from_der(buf: &[u8]) -> Result<Certificate, Error>102     pub fn from_der(buf: &[u8]) -> Result<Certificate, Error> {
103         let cert = CertContext::new(buf)?;
104         Ok(Certificate(cert))
105     }
106 
from_pem(buf: &[u8]) -> Result<Certificate, Error>107     pub fn from_pem(buf: &[u8]) -> Result<Certificate, Error> {
108         match str::from_utf8(buf) {
109             Ok(s) => {
110                 let cert = CertContext::from_pem(s)?;
111                 Ok(Certificate(cert))
112             }
113             Err(_) => Err(io::Error::new(
114                 io::ErrorKind::InvalidInput,
115                 "PEM representation contains non-UTF-8 bytes",
116             )
117             .into()),
118         }
119     }
120 
to_der(&self) -> Result<Vec<u8>, Error>121     pub fn to_der(&self) -> Result<Vec<u8>, Error> {
122         Ok(self.0.to_der().to_vec())
123     }
124 }
125 
126 pub struct MidHandshakeTlsStream<S>(tls_stream::MidHandshakeTlsStream<S>);
127 
128 impl<S> fmt::Debug for MidHandshakeTlsStream<S>
129 where
130     S: fmt::Debug,
131 {
fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result132     fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
133         fmt::Debug::fmt(&self.0, fmt)
134     }
135 }
136 
137 impl<S> MidHandshakeTlsStream<S> {
get_ref(&self) -> &S138     pub fn get_ref(&self) -> &S {
139         self.0.get_ref()
140     }
141 
get_mut(&mut self) -> &mut S142     pub fn get_mut(&mut self) -> &mut S {
143         self.0.get_mut()
144     }
145 }
146 
147 impl<S> MidHandshakeTlsStream<S>
148 where
149     S: io::Read + io::Write,
150 {
handshake(self) -> Result<TlsStream<S>, HandshakeError<S>>151     pub fn handshake(self) -> Result<TlsStream<S>, HandshakeError<S>> {
152         match self.0.handshake() {
153             Ok(s) => Ok(TlsStream(s)),
154             Err(e) => Err(e.into()),
155         }
156     }
157 }
158 
159 pub enum HandshakeError<S> {
160     Failure(Error),
161     WouldBlock(MidHandshakeTlsStream<S>),
162 }
163 
164 impl<S> From<tls_stream::HandshakeError<S>> for HandshakeError<S> {
from(e: tls_stream::HandshakeError<S>) -> HandshakeError<S>165     fn from(e: tls_stream::HandshakeError<S>) -> HandshakeError<S> {
166         match e {
167             tls_stream::HandshakeError::Failure(e) => HandshakeError::Failure(e.into()),
168             tls_stream::HandshakeError::Interrupted(s) => {
169                 HandshakeError::WouldBlock(MidHandshakeTlsStream(s))
170             }
171         }
172     }
173 }
174 
175 impl<S> From<io::Error> for HandshakeError<S> {
from(e: io::Error) -> HandshakeError<S>176     fn from(e: io::Error) -> HandshakeError<S> {
177         HandshakeError::Failure(e.into())
178     }
179 }
180 
/// Client-side TLS configuration.
///
/// `connect` acquires fresh schannel credentials from these settings on every call.
#[derive(Clone, Debug)]
pub struct TlsConnector {
    // Certificate from the builder's identity, handed to the credential builder
    // when present.
    cert: Option<CertContext>,
    // In-memory store holding the caller-supplied root certificates; installed
    // as the handshake's cert store.
    roots: CertStore,
    // Optional protocol bounds, narrowed through `convert_protocols`.
    min_protocol: Option<::Protocol>,
    max_protocol: Option<::Protocol>,
    use_sni: bool,
    accept_invalid_hostnames: bool,
    // When set, `connect` installs a verify callback that accepts every chain.
    accept_invalid_certs: bool,
    // When set (and `accept_invalid_certs` is not), a chain is accepted only if
    // it contains one of the certificates in `roots`.
    disable_built_in_roots: bool,
    // Protocols requested via ALPN; nothing is requested when empty.
    #[cfg(feature = "alpn")]
    alpn: Vec<String>,
}
194 
195 impl TlsConnector {
new(builder: &TlsConnectorBuilder) -> Result<TlsConnector, Error>196     pub fn new(builder: &TlsConnectorBuilder) -> Result<TlsConnector, Error> {
197         let cert = builder.identity.as_ref().map(|i| i.0.cert.clone());
198         let mut roots = Memory::new()?.into_store();
199         for cert in &builder.root_certificates {
200             roots.add_cert(&(cert.0).0, CertAdd::ReplaceExisting)?;
201         }
202 
203         Ok(TlsConnector {
204             cert,
205             roots,
206             min_protocol: builder.min_protocol,
207             max_protocol: builder.max_protocol,
208             use_sni: builder.use_sni,
209             accept_invalid_hostnames: builder.accept_invalid_hostnames,
210             accept_invalid_certs: builder.accept_invalid_certs,
211             disable_built_in_roots: builder.disable_built_in_roots,
212             #[cfg(feature = "alpn")]
213             alpn: builder.alpn.clone(),
214         })
215     }
216 
connect<S>(&self, domain: &str, stream: S) -> Result<TlsStream<S>, HandshakeError<S>> where S: io::Read + io::Write,217     pub fn connect<S>(&self, domain: &str, stream: S) -> Result<TlsStream<S>, HandshakeError<S>>
218     where
219         S: io::Read + io::Write,
220     {
221         let mut builder = SchannelCred::builder();
222         builder.enabled_protocols(convert_protocols(self.min_protocol, self.max_protocol));
223         if let Some(cert) = self.cert.as_ref() {
224             builder.cert(cert.clone());
225         }
226         let cred = builder.acquire(Direction::Outbound)?;
227         let mut builder = tls_stream::Builder::new();
228         builder
229             .cert_store(self.roots.clone())
230             .domain(domain)
231             .use_sni(self.use_sni)
232             .accept_invalid_hostnames(self.accept_invalid_hostnames);
233         if self.accept_invalid_certs {
234             builder.verify_callback(|_| Ok(()));
235         } else if self.disable_built_in_roots {
236             let roots_copy = self.roots.clone();
237             builder.verify_callback(move |res| {
238                 if let Err(err) = res.result() {
239                     // Propagate previous error encountered during normal cert validation.
240                     return Err(err);
241                 }
242 
243                 if let Some(chain) = res.chain() {
244                     if chain
245                         .certificates()
246                         .any(|cert| roots_copy.certs().any(|root_cert| root_cert == cert))
247                     {
248                         return Ok(());
249                     }
250                 }
251 
252                 Err(io::Error::new(
253                     io::ErrorKind::Other,
254                     "unable to find any user-specified roots in the final cert chain",
255                 ))
256             });
257         }
258         #[cfg(feature = "alpn")]
259         {
260             if !self.alpn.is_empty() {
261                 builder.request_application_protocols(
262                     &self.alpn.iter().map(|s| s.as_bytes()).collect::<Vec<_>>(),
263                 );
264             }
265         }
266         match builder.connect(cred, stream) {
267             Ok(s) => Ok(TlsStream(s)),
268             Err(e) => Err(e.into()),
269         }
270     }
271 }
272 
273 #[derive(Clone)]
274 pub struct TlsAcceptor {
275     cert: CertContext,
276     min_protocol: Option<::Protocol>,
277     max_protocol: Option<::Protocol>,
278 }
279 
280 impl TlsAcceptor {
new(builder: &TlsAcceptorBuilder) -> Result<TlsAcceptor, Error>281     pub fn new(builder: &TlsAcceptorBuilder) -> Result<TlsAcceptor, Error> {
282         Ok(TlsAcceptor {
283             cert: builder.identity.0.cert.clone(),
284             min_protocol: builder.min_protocol,
285             max_protocol: builder.max_protocol,
286         })
287     }
288 
accept<S>(&self, stream: S) -> Result<TlsStream<S>, HandshakeError<S>> where S: io::Read + io::Write,289     pub fn accept<S>(&self, stream: S) -> Result<TlsStream<S>, HandshakeError<S>>
290     where
291         S: io::Read + io::Write,
292     {
293         let mut builder = SchannelCred::builder();
294         builder.enabled_protocols(convert_protocols(self.min_protocol, self.max_protocol));
295         builder.cert(self.cert.clone());
296         // FIXME we're probably missing the certificate chain?
297         let cred = builder.acquire(Direction::Inbound)?;
298         match tls_stream::Builder::new().accept(cred, stream) {
299             Ok(s) => Ok(TlsStream(s)),
300             Err(e) => Err(e.into()),
301         }
302     }
303 }
304 
305 pub struct TlsStream<S>(tls_stream::TlsStream<S>);
306 
307 impl<S: fmt::Debug> fmt::Debug for TlsStream<S> {
fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result308     fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
309         fmt::Debug::fmt(&self.0, fmt)
310     }
311 }
312 
313 impl<S> TlsStream<S> {
get_ref(&self) -> &S314     pub fn get_ref(&self) -> &S {
315         self.0.get_ref()
316     }
317 
get_mut(&mut self) -> &mut S318     pub fn get_mut(&mut self) -> &mut S {
319         self.0.get_mut()
320     }
321 }
322 
323 impl<S: io::Read + io::Write> TlsStream<S> {
buffered_read_size(&self) -> Result<usize, Error>324     pub fn buffered_read_size(&self) -> Result<usize, Error> {
325         Ok(self.0.get_buf().len())
326     }
327 
peer_certificate(&self) -> Result<Option<Certificate>, Error>328     pub fn peer_certificate(&self) -> Result<Option<Certificate>, Error> {
329         match self.0.peer_certificate() {
330             Ok(cert) => Ok(Some(Certificate(cert))),
331             Err(ref e) if e.raw_os_error() == Some(SEC_E_NO_CREDENTIALS as i32) => Ok(None),
332             Err(e) => Err(Error(e)),
333         }
334     }
335 
336     #[cfg(feature = "alpn")]
negotiated_alpn(&self) -> Result<Option<Vec<u8>>, Error>337     pub fn negotiated_alpn(&self) -> Result<Option<Vec<u8>>, Error> {
338         Ok(self.0.negotiated_application_protocol()?)
339     }
340 
tls_server_end_point(&self) -> Result<Option<Vec<u8>>, Error>341     pub fn tls_server_end_point(&self) -> Result<Option<Vec<u8>>, Error> {
342         let cert = if self.0.is_server() {
343             self.0.certificate()
344         } else {
345             self.0.peer_certificate()
346         };
347 
348         let cert = match cert {
349             Ok(cert) => cert,
350             Err(ref e) if e.raw_os_error() == Some(SEC_E_NO_CREDENTIALS as i32) => return Ok(None),
351             Err(e) => return Err(Error(e)),
352         };
353 
354         let signature_algorithms = cert.sign_hash_algorithms()?;
355         let hash = match signature_algorithms.rsplit('/').next().unwrap() {
356             "MD5" | "SHA1" | "SHA256" => HashAlgorithm::sha256(),
357             "SHA384" => HashAlgorithm::sha384(),
358             "SHA512" => HashAlgorithm::sha512(),
359             _ => return Ok(None),
360         };
361 
362         let digest = cert.fingerprint(hash)?;
363         Ok(Some(digest))
364     }
365 
shutdown(&mut self) -> io::Result<()>366     pub fn shutdown(&mut self) -> io::Result<()> {
367         self.0.shutdown()?;
368         Ok(())
369     }
370 }
371 
372 impl<S: io::Read + io::Write> io::Read for TlsStream<S> {
read(&mut self, buf: &mut [u8]) -> io::Result<usize>373     fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
374         self.0.read(buf)
375     }
376 }
377 
378 impl<S: io::Read + io::Write> io::Write for TlsStream<S> {
write(&mut self, buf: &[u8]) -> io::Result<usize>379     fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
380         self.0.write(buf)
381     }
382 
flush(&mut self) -> io::Result<()>383     fn flush(&mut self) -> io::Result<()> {
384         self.0.flush()
385     }
386 }
387