1 #![warn(rust_2018_idioms)]
2 
3 use cfg_if::cfg_if;
4 use env_logger;
5 use futures::join;
6 use native_tls;
7 use native_tls::{Identity, TlsAcceptor, TlsConnector};
8 use std::io::Write;
9 use std::marker::Unpin;
10 use std::process::Command;
11 use std::ptr;
12 use tokio::io::{AsyncReadExt, AsyncWrite, AsyncWriteExt, Error, ErrorKind};
13 use tokio::net::{TcpListener, TcpStream};
14 use tokio::stream::StreamExt;
15 use tokio_tls;
16 
/// Unwraps a `Result`, or panics with the source text of the expression and
/// the error's `Debug` representation.
macro_rules! t {
    ($e:expr) => {
        match $e {
            Err(err) => panic!("{} failed with {:?}", stringify!($e), err),
            Ok(value) => value,
        }
    };
}
25 
26 #[allow(dead_code)]
27 struct Keys {
28     cert_der: Vec<u8>,
29     pkey_der: Vec<u8>,
30     pkcs12_der: Vec<u8>,
31 }
32 
33 #[allow(dead_code)]
openssl_keys() -> &'static Keys34 fn openssl_keys() -> &'static Keys {
35     static INIT: Once = Once::new();
36     static mut KEYS: *mut Keys = ptr::null_mut();
37 
38     INIT.call_once(|| {
39         let path = t!(env::current_exe());
40         let path = path.parent().unwrap();
41         let keyfile = path.join("test.key");
42         let certfile = path.join("test.crt");
43         let config = path.join("openssl.config");
44 
45         File::create(&config)
46             .unwrap()
47             .write_all(
48                 b"\
49             [req]\n\
50             distinguished_name=dn\n\
51             [ dn ]\n\
52             CN=localhost\n\
53             [ ext ]\n\
54             basicConstraints=CA:FALSE,pathlen:0\n\
55             subjectAltName = @alt_names
56             extendedKeyUsage=serverAuth,clientAuth
57             [alt_names]
58             DNS.1 = localhost
59         ",
60             )
61             .unwrap();
62 
63         let subj = "/C=US/ST=Denial/L=Sprintfield/O=Dis/CN=localhost";
64         let output = t!(Command::new("openssl")
65             .arg("req")
66             .arg("-nodes")
67             .arg("-x509")
68             .arg("-newkey")
69             .arg("rsa:2048")
70             .arg("-config")
71             .arg(&config)
72             .arg("-extensions")
73             .arg("ext")
74             .arg("-subj")
75             .arg(subj)
76             .arg("-keyout")
77             .arg(&keyfile)
78             .arg("-out")
79             .arg(&certfile)
80             .arg("-days")
81             .arg("1")
82             .output());
83         assert!(output.status.success());
84 
85         let crtout = t!(Command::new("openssl")
86             .arg("x509")
87             .arg("-outform")
88             .arg("der")
89             .arg("-in")
90             .arg(&certfile)
91             .output());
92         assert!(crtout.status.success());
93         let keyout = t!(Command::new("openssl")
94             .arg("rsa")
95             .arg("-outform")
96             .arg("der")
97             .arg("-in")
98             .arg(&keyfile)
99             .output());
100         assert!(keyout.status.success());
101 
102         let pkcs12out = t!(Command::new("openssl")
103             .arg("pkcs12")
104             .arg("-export")
105             .arg("-nodes")
106             .arg("-inkey")
107             .arg(&keyfile)
108             .arg("-in")
109             .arg(&certfile)
110             .arg("-password")
111             .arg("pass:foobar")
112             .output());
113         assert!(pkcs12out.status.success());
114 
115         let keys = Box::new(Keys {
116             cert_der: crtout.stdout,
117             pkey_der: keyout.stdout,
118             pkcs12_der: pkcs12out.stdout,
119         });
120         unsafe {
121             KEYS = Box::into_raw(keys);
122         }
123     });
124     unsafe { &*KEYS }
125 }
126 
127 cfg_if! {
128     if #[cfg(feature = "rustls")] {
129         use webpki;
130         use untrusted;
131         use std::env;
132         use std::fs::File;
133         use std::process::Command;
134         use std::sync::Once;
135 
136         use untrusted::Input;
137         use webpki::trust_anchor_util;
138 
139         fn server_cx() -> io::Result<ServerContext> {
140             let mut cx = ServerContext::new();
141 
142             let (cert, key) = keys();
143             cx.config_mut()
144               .set_single_cert(vec![cert.to_vec()], key.to_vec());
145 
146             Ok(cx)
147         }
148 
149         fn configure_client(cx: &mut ClientContext) {
150             let (cert, _key) = keys();
151             let cert = Input::from(cert);
152             let anchor = trust_anchor_util::cert_der_as_trust_anchor(cert).unwrap();
153             cx.config_mut().root_store.add_trust_anchors(&[anchor]);
154         }
155 
156         // Like OpenSSL we generate certificates on the fly, but for OSX we
157         // also have to put them into a specific keychain. We put both the
158         // certificates and the keychain next to our binary.
159         //
160         // Right now I don't know of a way to programmatically create a
161         // self-signed certificate, so we just fork out to the `openssl` binary.
162         fn keys() -> (&'static [u8], &'static [u8]) {
163             static INIT: Once = Once::new();
164             static mut KEYS: *mut (Vec<u8>, Vec<u8>) = ptr::null_mut();
165 
166             INIT.call_once(|| {
167                 let (key, cert) = openssl_keys();
168                 let path = t!(env::current_exe());
169                 let path = path.parent().unwrap();
170                 let keyfile = path.join("test.key");
171                 let certfile = path.join("test.crt");
172                 let config = path.join("openssl.config");
173 
174                 File::create(&config).unwrap().write_all(b"\
175                     [req]\n\
176                     distinguished_name=dn\n\
177                     [ dn ]\n\
178                     CN=localhost\n\
179                     [ ext ]\n\
180                     basicConstraints=CA:FALSE,pathlen:0\n\
181                     subjectAltName = @alt_names
182                     [alt_names]
183                     DNS.1 = localhost
184                 ").unwrap();
185 
186                 let subj = "/C=US/ST=Denial/L=Sprintfield/O=Dis/CN=localhost";
187                 let output = t!(Command::new("openssl")
188                                         .arg("req")
189                                         .arg("-nodes")
190                                         .arg("-x509")
191                                         .arg("-newkey").arg("rsa:2048")
192                                         .arg("-config").arg(&config)
193                                         .arg("-extensions").arg("ext")
194                                         .arg("-subj").arg(subj)
195                                         .arg("-keyout").arg(&keyfile)
196                                         .arg("-out").arg(&certfile)
197                                         .arg("-days").arg("1")
198                                         .output());
199                 assert!(output.status.success());
200 
201                 let crtout = t!(Command::new("openssl")
202                                         .arg("x509")
203                                         .arg("-outform").arg("der")
204                                         .arg("-in").arg(&certfile)
205                                         .output());
206                 assert!(crtout.status.success());
207                 let keyout = t!(Command::new("openssl")
208                                         .arg("rsa")
209                                         .arg("-outform").arg("der")
210                                         .arg("-in").arg(&keyfile)
211                                         .output());
212                 assert!(keyout.status.success());
213 
214                 let cert = crtout.stdout;
215                 let key = keyout.stdout;
216                 unsafe {
217                     KEYS = Box::into_raw(Box::new((cert, key)));
218                 }
219             });
220             unsafe {
221                 (&(*KEYS).0, &(*KEYS).1)
222             }
223         }
224     } else if #[cfg(any(feature = "force-openssl",
225                         all(not(target_os = "macos"),
226                             not(target_os = "windows"),
227                             not(target_os = "ios"))))] {
228         use std::fs::File;
229         use std::env;
230         use std::sync::Once;
231 
232         fn contexts() -> (tokio_tls::TlsAcceptor, tokio_tls::TlsConnector) {
233             let keys = openssl_keys();
234 
235             let pkcs12 = t!(Identity::from_pkcs12(&keys.pkcs12_der, "foobar"));
236             let srv = TlsAcceptor::builder(pkcs12);
237 
238             let cert = t!(native_tls::Certificate::from_der(&keys.cert_der));
239 
240             let mut client = TlsConnector::builder();
241             t!(client.add_root_certificate(cert).build());
242 
243             (t!(srv.build()).into(), t!(client.build()).into())
244         }
245     } else if #[cfg(any(target_os = "macos", target_os = "ios"))] {
246         use std::env;
247         use std::fs::File;
248         use std::sync::Once;
249 
250         fn contexts() -> (tokio_tls::TlsAcceptor, tokio_tls::TlsConnector) {
251             let keys = openssl_keys();
252 
253             let pkcs12 = t!(Identity::from_pkcs12(&keys.pkcs12_der, "foobar"));
254             let srv = TlsAcceptor::builder(pkcs12);
255 
256             let cert = native_tls::Certificate::from_der(&keys.cert_der).unwrap();
257             let mut client = TlsConnector::builder();
258             client.add_root_certificate(cert);
259 
260             (t!(srv.build()).into(), t!(client.build()).into())
261         }
262     } else {
263         use schannel;
264         use winapi;
265 
266         use std::env;
267         use std::fs::File;
268         use std::io;
269         use std::mem;
270         use std::sync::Once;
271 
272         use schannel::cert_context::CertContext;
273         use schannel::cert_store::{CertStore, CertAdd, Memory};
274         use winapi::shared::basetsd::*;
275         use winapi::shared::lmcons::*;
276         use winapi::shared::minwindef::*;
277         use winapi::shared::ntdef::WCHAR;
278         use winapi::um::minwinbase::*;
279         use winapi::um::sysinfoapi::*;
280         use winapi::um::timezoneapi::*;
281         use winapi::um::wincrypt::*;
282 
283         const FRIENDLY_NAME: &str = "tokio-tls localhost testing cert";
284 
285         fn contexts() -> (tokio_tls::TlsAcceptor, tokio_tls::TlsConnector) {
286             let cert = localhost_cert();
287             let mut store = t!(Memory::new()).into_store();
288             t!(store.add_cert(&cert, CertAdd::Always));
289             let pkcs12_der = t!(store.export_pkcs12("foobar"));
290             let pkcs12 = t!(Identity::from_pkcs12(&pkcs12_der, "foobar"));
291 
292             let srv = TlsAcceptor::builder(pkcs12);
293             let client = TlsConnector::builder();
294             (t!(srv.build()).into(), t!(client.build()).into())
295         }
296 
297         // ====================================================================
298         // Magic!
299         //
300         // Lots of magic is happening here to wrangle certificates for running
        // these tests on Windows. For more information, see the test suite
        // in the schannel-rs crate, which this code is copied from.
303         //
304         // The general gist of this though is that the only way to add custom
305         // trusted certificates is to add it to the system store of trust. To
306         // do that we go through the whole rigamarole here to generate a new
307         // self-signed certificate and then insert that into the system store.
308         //
309         // This generates some dialogs, so we print what we're doing sometimes,
310         // and otherwise we just manage the ephemeral certificates. Because
311         // they're in the system store we always ensure that they're only valid
312         // for a small period of time (e.g. 1 day).
313 
314         fn localhost_cert() -> CertContext {
315             static INIT: Once = Once::new();
316             INIT.call_once(|| {
317                 for cert in local_root_store().certs() {
318                     let name = match cert.friendly_name() {
319                         Ok(name) => name,
320                         Err(_) => continue,
321                     };
322                     if name != FRIENDLY_NAME {
323                         continue
324                     }
325                     if !cert.is_time_valid().unwrap() {
326                         io::stdout().write_all(br#"
327 
328 The tokio-tls test suite is about to delete an old copy of one of its
329 certificates from your root trust store. This certificate was only valid for one
330 day and it is no longer needed. The host should be "localhost" and the
331 description should mention "tokio-tls".
332 
333         "#).unwrap();
334                         cert.delete().unwrap();
335                     } else {
336                         return
337                     }
338                 }
339 
340                 install_certificate().unwrap();
341             });
342 
343             for cert in local_root_store().certs() {
344                 let name = match cert.friendly_name() {
345                     Ok(name) => name,
346                     Err(_) => continue,
347                 };
348                 if name == FRIENDLY_NAME {
349                     return cert
350                 }
351             }
352 
353             panic!("couldn't find a cert");
354         }
355 
356         fn local_root_store() -> CertStore {
357             if env::var("CI").is_ok() {
358                 CertStore::open_local_machine("Root").unwrap()
359             } else {
360                 CertStore::open_current_user("Root").unwrap()
361             }
362         }
363 
        /// Creates a new self-signed certificate for `localhost`, names it
        /// `FRIENDLY_NAME`, and adds it to `local_root_store()`.
        ///
        /// Steps:
        /// 1. Acquire (or create) the "tokio-tls test suite" machine-level key
        ///    container.
        /// 2. Generate an exportable signature key in it.
        /// 3. Encode the subject/issuer name.
        /// 4. Compute an expiry one day after the current system time.
        /// 5. Self-sign the certificate.
        /// 6. Print a notice to stdout, then install the certificate with
        ///    `CertAdd::ReplaceExisting`.
        ///
        /// # Errors
        ///
        /// Returns the last OS error if any CryptoAPI call fails. Also returns
        /// the error from `set_friendly_name` or `add_cert` if either fails.
        // NOTE(review): the `provider` (CryptAcquireContextW) and `hkey`
        // (CryptGenKey) handles are never released. There is no
        // CryptReleaseContext / CryptDestroyKey call on either the success path
        // or the error paths. This is harmless for a one-shot test helper, but
        // it is a leak.
        fn install_certificate() -> io::Result<CertContext> {
            unsafe {
                let mut provider = 0;
                let mut hkey = 0;

                // NUL-terminated UTF-16 name of the key container; also reused
                // below as `pwszContainerName`.
                let mut buffer = "tokio-tls test suite".encode_utf16()
                                                         .chain(Some(0))
                                                         .collect::<Vec<_>>();
                // First try to open an existing container of that name.
                let res = CryptAcquireContextW(&mut provider,
                                               buffer.as_ptr(),
                                               ptr::null_mut(),
                                               PROV_RSA_FULL,
                                               CRYPT_MACHINE_KEYSET);
                if res != TRUE {
                    // create a new key container (since it does not exist)
                    let res = CryptAcquireContextW(&mut provider,
                                                   buffer.as_ptr(),
                                                   ptr::null_mut(),
                                                   PROV_RSA_FULL,
                                                   CRYPT_NEWKEYSET | CRYPT_MACHINE_KEYSET);
                    if res != TRUE {
                        return Err(Error::last_os_error())
                    }
                }

                // create a new keypair (RSA-2048)
                // The upper 16 bits of the flags presumably carry the key length
                // in bits (0x0800 = 2048), matching the comment above.
                let res = CryptGenKey(provider,
                                      AT_SIGNATURE,
                                      0x0800<<16 | CRYPT_EXPORTABLE,
                                      &mut hkey);
                if res != TRUE {
                    return Err(Error::last_os_error());
                }

                // start creating the certificate
                // X.500 name string, NUL-terminated UTF-16, used as both the
                // subject and the issuer (self-signed).
                let name = "CN=localhost,O=tokio-tls,OU=tokio-tls,\
                            G=tokio_tls".encode_utf16()
                                          .chain(Some(0))
                                          .collect::<Vec<_>>();
                let mut cname_buffer: [WCHAR; UNLEN as usize + 1] = mem::zeroed();
                // NOTE(review): this is the WCHAR count of `cname_buffer`, while
                // CertStrToNameW's length argument is presumably a byte count.
                // That under-reports the capacity (safe, just conservative) —
                // confirm.
                let mut cname_len = cname_buffer.len() as DWORD;
                // Encode `name`; on success `cname_len` holds the encoded length.
                let res = CertStrToNameW(X509_ASN_ENCODING,
                                         name.as_ptr(),
                                         CERT_X500_NAME_STR,
                                         ptr::null_mut(),
                                         cname_buffer.as_mut_ptr() as *mut u8,
                                         &mut cname_len,
                                         ptr::null_mut());
                if res != TRUE {
                    return Err(Error::last_os_error());
                }

                let mut subject_issuer = CERT_NAME_BLOB {
                    cbData: cname_len,
                    pbData: cname_buffer.as_ptr() as *mut u8,
                };
                // Point the certificate at the key generated in the container
                // acquired above.
                let mut key_provider = CRYPT_KEY_PROV_INFO {
                    pwszContainerName: buffer.as_mut_ptr(),
                    pwszProvName: ptr::null_mut(),
                    dwProvType: PROV_RSA_FULL,
                    dwFlags: CRYPT_MACHINE_KEYSET,
                    cProvParam: 0,
                    rgProvParam: ptr::null_mut(),
                    dwKeySpec: AT_SIGNATURE,
                };
                // SHA-256 with RSA, no algorithm parameters. `as_ptr()` assumes
                // winapi's OID constant is NUL-terminated — TODO confirm.
                let mut sig_algorithm = CRYPT_ALGORITHM_IDENTIFIER {
                    pszObjId: szOID_RSA_SHA256RSA.as_ptr() as *mut _,
                    Parameters: mem::zeroed(),
                };
                // Expiry = current system time + 1 day, computed via FILETIME
                // arithmetic and converted back to SYSTEMTIME.
                let mut expiration_date: SYSTEMTIME = mem::zeroed();
                GetSystemTime(&mut expiration_date);
                let mut file_time: FILETIME = mem::zeroed();
                let res = SystemTimeToFileTime(&expiration_date,
                                               &mut file_time);
                if res != TRUE {
                    return Err(Error::last_os_error());
                }
                let mut timestamp: u64 = file_time.dwLowDateTime as u64 |
                                         (file_time.dwHighDateTime as u64) << 32;
                // one day, timestamp unit is in 100 nanosecond intervals
                timestamp += (1E9 as u64) / 100 * (60 * 60 * 24);
                file_time.dwLowDateTime = timestamp as u32;
                file_time.dwHighDateTime = (timestamp >> 32) as u32;
                let res = FileTimeToSystemTime(&file_time,
                                               &mut expiration_date);
                if res != TRUE {
                    return Err(Error::last_os_error());
                }

                // create a self signed certificate
                // (default start time, no extensions)
                let cert_context = CertCreateSelfSignCertificate(
                        0 as ULONG_PTR,
                        &mut subject_issuer,
                        0,
                        &mut key_provider,
                        &mut sig_algorithm,
                        ptr::null_mut(),
                        &mut expiration_date,
                        ptr::null_mut());
                if cert_context.is_null() {
                    return Err(Error::last_os_error());
                }

                // TODO: this is.. a terrible hack. Right now `schannel`
                //       doesn't provide a public method to go from a raw
                //       cert context pointer to the `CertContext` structure it
                //       has, so we just fake it here with a transmute. This'll
                //       probably break at some point, but hopefully by then
                //       it'll have a method to do this!
                struct MyCertContext<T>(T);
                impl<T> Drop for MyCertContext<T> {
                    fn drop(&mut self) {}
                }

                let cert_context = MyCertContext(cert_context);
                let cert_context: CertContext = mem::transmute(cert_context);

                // The friendly name is how `localhost_cert` finds this cert later.
                cert_context.set_friendly_name(FRIENDLY_NAME)?;

                // install the certificate to the machine's local store
                io::stdout().write_all(br#"

The tokio-tls test suite is about to add a certificate to your set of root
and trusted certificates. This certificate should be for the domain "localhost"
with the description related to "tokio-tls". This certificate is only valid
for one day and will be automatically deleted if you re-run the tokio-tls
test suite later.

        "#).unwrap();
                local_root_store().add_cert(&cert_context,
                                                 CertAdd::ReplaceExisting)?;
                Ok(cert_context)
            }
        }
498     }
499 }
500 
/// Number of bytes the bulk-transfer tests push through the TLS stream (128 KiB).
const AMT: usize = 128 << 10;
502 
copy_data<W: AsyncWrite + Unpin>(mut w: W) -> Result<usize, Error>503 async fn copy_data<W: AsyncWrite + Unpin>(mut w: W) -> Result<usize, Error> {
504     let mut data = vec![9; AMT as usize];
505     let mut amt = 0;
506     while !data.is_empty() {
507         let written = w.write(&data).await?;
508         if written <= data.len() {
509             amt += written;
510             data.resize(data.len() - written, 0);
511         } else {
512             w.write_all(&data).await?;
513             amt += data.len();
514             break;
515         }
516 
517         println!("remaining: {}", data.len());
518     }
519     Ok(amt)
520 }
521 
522 #[tokio::test]
client_to_server()523 async fn client_to_server() {
524     drop(env_logger::try_init());
525 
526     // Create a server listening on a port, then figure out what that port is
527     let mut srv = t!(TcpListener::bind("127.0.0.1:0").await);
528     let addr = t!(srv.local_addr());
529 
530     let (server_cx, client_cx) = contexts();
531 
532     // Create a future to accept one socket, connect the ssl stream, and then
533     // read all the data from it.
534     let server = async move {
535         let mut incoming = srv.incoming();
536         let socket = t!(incoming.next().await.unwrap());
537         let mut socket = t!(server_cx.accept(socket).await);
538         let mut data = Vec::new();
539         t!(socket.read_to_end(&mut data).await);
540         data
541     };
542 
543     // Create a future to connect to our server, connect the ssl stream, and
544     // then write a bunch of data to it.
545     let client = async move {
546         let socket = t!(TcpStream::connect(&addr).await);
547         let socket = t!(client_cx.connect("localhost", socket).await);
548         copy_data(socket).await
549     };
550 
551     // Finally, run everything!
552     let (data, _) = join!(server, client);
553     // assert_eq!(amt, AMT);
554     assert!(data == vec![9; AMT]);
555 }
556 
557 #[tokio::test]
server_to_client()558 async fn server_to_client() {
559     drop(env_logger::try_init());
560 
561     // Create a server listening on a port, then figure out what that port is
562     let mut srv = t!(TcpListener::bind("127.0.0.1:0").await);
563     let addr = t!(srv.local_addr());
564 
565     let (server_cx, client_cx) = contexts();
566 
567     let server = async move {
568         let mut incoming = srv.incoming();
569         let socket = t!(incoming.next().await.unwrap());
570         let socket = t!(server_cx.accept(socket).await);
571         copy_data(socket).await
572     };
573 
574     let client = async move {
575         let socket = t!(TcpStream::connect(&addr).await);
576         let mut socket = t!(client_cx.connect("localhost", socket).await);
577         let mut data = Vec::new();
578         t!(socket.read_to_end(&mut data).await);
579         data
580     };
581 
582     // Finally, run everything!
583     let (_, data) = join!(server, client);
584     // assert_eq!(amt, AMT);
585     assert!(data == vec![9; AMT]);
586 }
587 
588 #[tokio::test]
one_byte_at_a_time()589 async fn one_byte_at_a_time() {
590     const AMT: usize = 1024;
591     drop(env_logger::try_init());
592 
593     let mut srv = t!(TcpListener::bind("127.0.0.1:0").await);
594     let addr = t!(srv.local_addr());
595 
596     let (server_cx, client_cx) = contexts();
597 
598     let server = async move {
599         let mut incoming = srv.incoming();
600         let socket = t!(incoming.next().await.unwrap());
601         let mut socket = t!(server_cx.accept(socket).await);
602         let mut amt = 0;
603         for b in std::iter::repeat(9).take(AMT) {
604             let data = [b as u8];
605             t!(socket.write_all(&data).await);
606             amt += 1;
607         }
608         amt
609     };
610 
611     let client = async move {
612         let socket = t!(TcpStream::connect(&addr).await);
613         let mut socket = t!(client_cx.connect("localhost", socket).await);
614         let mut data = Vec::new();
615         loop {
616             let mut buf = [0; 1];
617             match socket.read_exact(&mut buf).await {
618                 Ok(_) => data.extend_from_slice(&buf),
619                 Err(ref err) if err.kind() == ErrorKind::UnexpectedEof => break,
620                 Err(err) => panic!(err),
621             }
622         }
623         data
624     };
625 
626     let (amt, data) = join!(server, client);
627     assert_eq!(amt, AMT);
628     assert!(data == vec![9; AMT as usize]);
629 }
630