1 extern crate schannel;
2
3 use self::schannel::cert_context::{CertContext, HashAlgorithm};
4 use self::schannel::cert_store::{CertAdd, CertStore, Memory, PfxImportOptions};
5 use self::schannel::schannel_cred::{Direction, Protocol, SchannelCred};
6 use self::schannel::tls_stream;
7 use std::error;
8 use std::fmt;
9 use std::io;
10 use std::str;
11
12 use {TlsAcceptorBuilder, TlsConnectorBuilder};
13
/// Raw SSPI status code `SEC_E_NO_CREDENTIALS`.
///
/// `TlsStream` compares `io::Error::raw_os_error()` (as `i32`) against this
/// value and reports a match as `Ok(None)` rather than as an error.
const SEC_E_NO_CREDENTIALS: u32 = 0x8009_030E;
15
16 static PROTOCOLS: &'static [Protocol] = &[
17 Protocol::Ssl3,
18 Protocol::Tls10,
19 Protocol::Tls11,
20 Protocol::Tls12,
21 ];
22
convert_protocols(min: Option<::Protocol>, max: Option<::Protocol>) -> &'static [Protocol]23 fn convert_protocols(min: Option<::Protocol>, max: Option<::Protocol>) -> &'static [Protocol] {
24 let mut protocols = PROTOCOLS;
25 if let Some(p) = max.and_then(|max| protocols.get(..max as usize)) {
26 protocols = p;
27 }
28 if let Some(p) = min.and_then(|min| protocols.get(min as usize..)) {
29 protocols = p;
30 }
31 protocols
32 }
33
/// Error type of the schannel backend: a newtype around `io::Error`.
///
/// Formatting (`Display` and `Debug`) is forwarded verbatim to the wrapped
/// error, including any formatter flags.
pub struct Error(io::Error);

impl error::Error for Error {
    /// Reports the wrapped `io::Error`'s own source. The `io::Error` itself
    /// is not reported as a source.
    fn source(&self) -> Option<&(dyn error::Error + 'static)> {
        <io::Error as error::Error>::source(&self.0)
    }
}

impl fmt::Display for Error {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        <io::Error as fmt::Display>::fmt(&self.0, f)
    }
}

impl fmt::Debug for Error {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        <io::Error as fmt::Debug>::fmt(&self.0, f)
    }
}

impl From<io::Error> for Error {
    /// Wraps an `io::Error` without altering it.
    fn from(err: io::Error) -> Error {
        Error(err)
    }
}
59
60 #[derive(Clone)]
61 pub struct Identity {
62 cert: CertContext,
63 }
64
65 impl Identity {
from_pkcs12(buf: &[u8], pass: &str) -> Result<Identity, Error>66 pub fn from_pkcs12(buf: &[u8], pass: &str) -> Result<Identity, Error> {
67 let store = PfxImportOptions::new().password(pass).import(buf)?;
68 let mut identity = None;
69
70 for cert in store.certs() {
71 if cert
72 .private_key()
73 .silent(true)
74 .compare_key(true)
75 .acquire()
76 .is_ok()
77 {
78 identity = Some(cert);
79 break;
80 }
81 }
82
83 let identity = match identity {
84 Some(identity) => identity,
85 None => {
86 return Err(io::Error::new(
87 io::ErrorKind::InvalidInput,
88 "No identity found in PKCS #12 archive",
89 )
90 .into());
91 }
92 };
93
94 Ok(Identity { cert: identity })
95 }
96 }
97
98 #[derive(Clone)]
99 pub struct Certificate(CertContext);
100
101 impl Certificate {
from_der(buf: &[u8]) -> Result<Certificate, Error>102 pub fn from_der(buf: &[u8]) -> Result<Certificate, Error> {
103 let cert = CertContext::new(buf)?;
104 Ok(Certificate(cert))
105 }
106
from_pem(buf: &[u8]) -> Result<Certificate, Error>107 pub fn from_pem(buf: &[u8]) -> Result<Certificate, Error> {
108 match str::from_utf8(buf) {
109 Ok(s) => {
110 let cert = CertContext::from_pem(s)?;
111 Ok(Certificate(cert))
112 }
113 Err(_) => Err(io::Error::new(
114 io::ErrorKind::InvalidInput,
115 "PEM representation contains non-UTF-8 bytes",
116 )
117 .into()),
118 }
119 }
120
to_der(&self) -> Result<Vec<u8>, Error>121 pub fn to_der(&self) -> Result<Vec<u8>, Error> {
122 Ok(self.0.to_der().to_vec())
123 }
124 }
125
126 pub struct MidHandshakeTlsStream<S>(tls_stream::MidHandshakeTlsStream<S>);
127
128 impl<S> fmt::Debug for MidHandshakeTlsStream<S>
129 where
130 S: fmt::Debug,
131 {
fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result132 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
133 fmt::Debug::fmt(&self.0, fmt)
134 }
135 }
136
137 impl<S> MidHandshakeTlsStream<S> {
get_ref(&self) -> &S138 pub fn get_ref(&self) -> &S {
139 self.0.get_ref()
140 }
141
get_mut(&mut self) -> &mut S142 pub fn get_mut(&mut self) -> &mut S {
143 self.0.get_mut()
144 }
145 }
146
147 impl<S> MidHandshakeTlsStream<S>
148 where
149 S: io::Read + io::Write,
150 {
handshake(self) -> Result<TlsStream<S>, HandshakeError<S>>151 pub fn handshake(self) -> Result<TlsStream<S>, HandshakeError<S>> {
152 match self.0.handshake() {
153 Ok(s) => Ok(TlsStream(s)),
154 Err(e) => Err(e.into()),
155 }
156 }
157 }
158
159 pub enum HandshakeError<S> {
160 Failure(Error),
161 WouldBlock(MidHandshakeTlsStream<S>),
162 }
163
164 impl<S> From<tls_stream::HandshakeError<S>> for HandshakeError<S> {
from(e: tls_stream::HandshakeError<S>) -> HandshakeError<S>165 fn from(e: tls_stream::HandshakeError<S>) -> HandshakeError<S> {
166 match e {
167 tls_stream::HandshakeError::Failure(e) => HandshakeError::Failure(e.into()),
168 tls_stream::HandshakeError::Interrupted(s) => {
169 HandshakeError::WouldBlock(MidHandshakeTlsStream(s))
170 }
171 }
172 }
173 }
174
175 impl<S> From<io::Error> for HandshakeError<S> {
from(e: io::Error) -> HandshakeError<S>176 fn from(e: io::Error) -> HandshakeError<S> {
177 HandshakeError::Failure(e.into())
178 }
179 }
180
181 #[derive(Clone, Debug)]
182 pub struct TlsConnector {
183 cert: Option<CertContext>,
184 roots: CertStore,
185 min_protocol: Option<::Protocol>,
186 max_protocol: Option<::Protocol>,
187 use_sni: bool,
188 accept_invalid_hostnames: bool,
189 accept_invalid_certs: bool,
190 disable_built_in_roots: bool,
191 #[cfg(feature = "alpn")]
192 alpn: Vec<String>,
193 }
194
195 impl TlsConnector {
new(builder: &TlsConnectorBuilder) -> Result<TlsConnector, Error>196 pub fn new(builder: &TlsConnectorBuilder) -> Result<TlsConnector, Error> {
197 let cert = builder.identity.as_ref().map(|i| i.0.cert.clone());
198 let mut roots = Memory::new()?.into_store();
199 for cert in &builder.root_certificates {
200 roots.add_cert(&(cert.0).0, CertAdd::ReplaceExisting)?;
201 }
202
203 Ok(TlsConnector {
204 cert,
205 roots,
206 min_protocol: builder.min_protocol,
207 max_protocol: builder.max_protocol,
208 use_sni: builder.use_sni,
209 accept_invalid_hostnames: builder.accept_invalid_hostnames,
210 accept_invalid_certs: builder.accept_invalid_certs,
211 disable_built_in_roots: builder.disable_built_in_roots,
212 #[cfg(feature = "alpn")]
213 alpn: builder.alpn.clone(),
214 })
215 }
216
connect<S>(&self, domain: &str, stream: S) -> Result<TlsStream<S>, HandshakeError<S>> where S: io::Read + io::Write,217 pub fn connect<S>(&self, domain: &str, stream: S) -> Result<TlsStream<S>, HandshakeError<S>>
218 where
219 S: io::Read + io::Write,
220 {
221 let mut builder = SchannelCred::builder();
222 builder.enabled_protocols(convert_protocols(self.min_protocol, self.max_protocol));
223 if let Some(cert) = self.cert.as_ref() {
224 builder.cert(cert.clone());
225 }
226 let cred = builder.acquire(Direction::Outbound)?;
227 let mut builder = tls_stream::Builder::new();
228 builder
229 .cert_store(self.roots.clone())
230 .domain(domain)
231 .use_sni(self.use_sni)
232 .accept_invalid_hostnames(self.accept_invalid_hostnames);
233 if self.accept_invalid_certs {
234 builder.verify_callback(|_| Ok(()));
235 } else if self.disable_built_in_roots {
236 let roots_copy = self.roots.clone();
237 builder.verify_callback(move |res| {
238 if let Err(err) = res.result() {
239 // Propagate previous error encountered during normal cert validation.
240 return Err(err);
241 }
242
243 if let Some(chain) = res.chain() {
244 if chain
245 .certificates()
246 .any(|cert| roots_copy.certs().any(|root_cert| root_cert == cert))
247 {
248 return Ok(());
249 }
250 }
251
252 Err(io::Error::new(
253 io::ErrorKind::Other,
254 "unable to find any user-specified roots in the final cert chain",
255 ))
256 });
257 }
258 #[cfg(feature = "alpn")]
259 {
260 if !self.alpn.is_empty() {
261 builder.request_application_protocols(
262 &self.alpn.iter().map(|s| s.as_bytes()).collect::<Vec<_>>(),
263 );
264 }
265 }
266 match builder.connect(cred, stream) {
267 Ok(s) => Ok(TlsStream(s)),
268 Err(e) => Err(e.into()),
269 }
270 }
271 }
272
273 #[derive(Clone)]
274 pub struct TlsAcceptor {
275 cert: CertContext,
276 min_protocol: Option<::Protocol>,
277 max_protocol: Option<::Protocol>,
278 }
279
280 impl TlsAcceptor {
new(builder: &TlsAcceptorBuilder) -> Result<TlsAcceptor, Error>281 pub fn new(builder: &TlsAcceptorBuilder) -> Result<TlsAcceptor, Error> {
282 Ok(TlsAcceptor {
283 cert: builder.identity.0.cert.clone(),
284 min_protocol: builder.min_protocol,
285 max_protocol: builder.max_protocol,
286 })
287 }
288
accept<S>(&self, stream: S) -> Result<TlsStream<S>, HandshakeError<S>> where S: io::Read + io::Write,289 pub fn accept<S>(&self, stream: S) -> Result<TlsStream<S>, HandshakeError<S>>
290 where
291 S: io::Read + io::Write,
292 {
293 let mut builder = SchannelCred::builder();
294 builder.enabled_protocols(convert_protocols(self.min_protocol, self.max_protocol));
295 builder.cert(self.cert.clone());
296 // FIXME we're probably missing the certificate chain?
297 let cred = builder.acquire(Direction::Inbound)?;
298 match tls_stream::Builder::new().accept(cred, stream) {
299 Ok(s) => Ok(TlsStream(s)),
300 Err(e) => Err(e.into()),
301 }
302 }
303 }
304
305 pub struct TlsStream<S>(tls_stream::TlsStream<S>);
306
307 impl<S: fmt::Debug> fmt::Debug for TlsStream<S> {
fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result308 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
309 fmt::Debug::fmt(&self.0, fmt)
310 }
311 }
312
313 impl<S> TlsStream<S> {
get_ref(&self) -> &S314 pub fn get_ref(&self) -> &S {
315 self.0.get_ref()
316 }
317
get_mut(&mut self) -> &mut S318 pub fn get_mut(&mut self) -> &mut S {
319 self.0.get_mut()
320 }
321 }
322
323 impl<S: io::Read + io::Write> TlsStream<S> {
buffered_read_size(&self) -> Result<usize, Error>324 pub fn buffered_read_size(&self) -> Result<usize, Error> {
325 Ok(self.0.get_buf().len())
326 }
327
peer_certificate(&self) -> Result<Option<Certificate>, Error>328 pub fn peer_certificate(&self) -> Result<Option<Certificate>, Error> {
329 match self.0.peer_certificate() {
330 Ok(cert) => Ok(Some(Certificate(cert))),
331 Err(ref e) if e.raw_os_error() == Some(SEC_E_NO_CREDENTIALS as i32) => Ok(None),
332 Err(e) => Err(Error(e)),
333 }
334 }
335
336 #[cfg(feature = "alpn")]
negotiated_alpn(&self) -> Result<Option<Vec<u8>>, Error>337 pub fn negotiated_alpn(&self) -> Result<Option<Vec<u8>>, Error> {
338 Ok(self.0.negotiated_application_protocol()?)
339 }
340
tls_server_end_point(&self) -> Result<Option<Vec<u8>>, Error>341 pub fn tls_server_end_point(&self) -> Result<Option<Vec<u8>>, Error> {
342 let cert = if self.0.is_server() {
343 self.0.certificate()
344 } else {
345 self.0.peer_certificate()
346 };
347
348 let cert = match cert {
349 Ok(cert) => cert,
350 Err(ref e) if e.raw_os_error() == Some(SEC_E_NO_CREDENTIALS as i32) => return Ok(None),
351 Err(e) => return Err(Error(e)),
352 };
353
354 let signature_algorithms = cert.sign_hash_algorithms()?;
355 let hash = match signature_algorithms.rsplit('/').next().unwrap() {
356 "MD5" | "SHA1" | "SHA256" => HashAlgorithm::sha256(),
357 "SHA384" => HashAlgorithm::sha384(),
358 "SHA512" => HashAlgorithm::sha512(),
359 _ => return Ok(None),
360 };
361
362 let digest = cert.fingerprint(hash)?;
363 Ok(Some(digest))
364 }
365
shutdown(&mut self) -> io::Result<()>366 pub fn shutdown(&mut self) -> io::Result<()> {
367 self.0.shutdown()?;
368 Ok(())
369 }
370 }
371
372 impl<S: io::Read + io::Write> io::Read for TlsStream<S> {
read(&mut self, buf: &mut [u8]) -> io::Result<usize>373 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
374 self.0.read(buf)
375 }
376 }
377
378 impl<S: io::Read + io::Write> io::Write for TlsStream<S> {
write(&mut self, buf: &[u8]) -> io::Result<usize>379 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
380 self.0.write(buf)
381 }
382
flush(&mut self) -> io::Result<()>383 fn flush(&mut self) -> io::Result<()> {
384 self.0.flush()
385 }
386 }
387