1 #![warn(rust_2018_idioms)]
2
3 use cfg_if::cfg_if;
4 use env_logger;
5 use futures::join;
6 use native_tls;
7 use native_tls::{Identity, TlsAcceptor, TlsConnector};
8 use std::io::Write;
9 use std::marker::Unpin;
10 use std::process::Command;
11 use std::ptr;
12 use tokio::io::{AsyncReadExt, AsyncWrite, AsyncWriteExt, Error, ErrorKind};
13 use tokio::net::{TcpListener, TcpStream};
14 use tokio::stream::StreamExt;
15 use tokio_tls;
16
/// Unwraps a `Result`-valued expression.
///
/// On `Err`, panics with the source text of the expression followed by the
/// `Debug` rendering of the error, so test failures point at the exact call.
macro_rules! t {
    ($e:expr) => {
        match $e {
            Err(err) => panic!("{} failed with {:?}", stringify!($e), err),
            Ok(val) => val,
        }
    };
}
25
/// Throw-away `localhost` key material produced by `openssl_keys`, in the
/// encodings the different TLS backends consume.
// Which fields get read depends on the platform branch compiled below, so a
// given build may leave some of them unused.
#[allow(dead_code)]
struct Keys {
    /// DER-encoded self-signed certificate (`openssl x509 -outform der`).
    cert_der: Vec<u8>,
    /// DER-encoded private key (`openssl rsa -outform der`).
    pkey_der: Vec<u8>,
    /// PKCS#12 bundle of the key and certificate, password `foobar`.
    pkcs12_der: Vec<u8>,
}
32
33 #[allow(dead_code)]
openssl_keys() -> &'static Keys34 fn openssl_keys() -> &'static Keys {
35 static INIT: Once = Once::new();
36 static mut KEYS: *mut Keys = ptr::null_mut();
37
38 INIT.call_once(|| {
39 let path = t!(env::current_exe());
40 let path = path.parent().unwrap();
41 let keyfile = path.join("test.key");
42 let certfile = path.join("test.crt");
43 let config = path.join("openssl.config");
44
45 File::create(&config)
46 .unwrap()
47 .write_all(
48 b"\
49 [req]\n\
50 distinguished_name=dn\n\
51 [ dn ]\n\
52 CN=localhost\n\
53 [ ext ]\n\
54 basicConstraints=CA:FALSE,pathlen:0\n\
55 subjectAltName = @alt_names
56 extendedKeyUsage=serverAuth,clientAuth
57 [alt_names]
58 DNS.1 = localhost
59 ",
60 )
61 .unwrap();
62
63 let subj = "/C=US/ST=Denial/L=Sprintfield/O=Dis/CN=localhost";
64 let output = t!(Command::new("openssl")
65 .arg("req")
66 .arg("-nodes")
67 .arg("-x509")
68 .arg("-newkey")
69 .arg("rsa:2048")
70 .arg("-config")
71 .arg(&config)
72 .arg("-extensions")
73 .arg("ext")
74 .arg("-subj")
75 .arg(subj)
76 .arg("-keyout")
77 .arg(&keyfile)
78 .arg("-out")
79 .arg(&certfile)
80 .arg("-days")
81 .arg("1")
82 .output());
83 assert!(output.status.success());
84
85 let crtout = t!(Command::new("openssl")
86 .arg("x509")
87 .arg("-outform")
88 .arg("der")
89 .arg("-in")
90 .arg(&certfile)
91 .output());
92 assert!(crtout.status.success());
93 let keyout = t!(Command::new("openssl")
94 .arg("rsa")
95 .arg("-outform")
96 .arg("der")
97 .arg("-in")
98 .arg(&keyfile)
99 .output());
100 assert!(keyout.status.success());
101
102 let pkcs12out = t!(Command::new("openssl")
103 .arg("pkcs12")
104 .arg("-export")
105 .arg("-nodes")
106 .arg("-inkey")
107 .arg(&keyfile)
108 .arg("-in")
109 .arg(&certfile)
110 .arg("-password")
111 .arg("pass:foobar")
112 .output());
113 assert!(pkcs12out.status.success());
114
115 let keys = Box::new(Keys {
116 cert_der: crtout.stdout,
117 pkey_der: keyout.stdout,
118 pkcs12_der: pkcs12out.stdout,
119 });
120 unsafe {
121 KEYS = Box::into_raw(keys);
122 }
123 });
124 unsafe { &*KEYS }
125 }
126
127 cfg_if! {
128 if #[cfg(feature = "rustls")] {
129 use webpki;
130 use untrusted;
131 use std::env;
132 use std::fs::File;
133 use std::process::Command;
134 use std::sync::Once;
135
136 use untrusted::Input;
137 use webpki::trust_anchor_util;
138
139 fn server_cx() -> io::Result<ServerContext> {
140 let mut cx = ServerContext::new();
141
142 let (cert, key) = keys();
143 cx.config_mut()
144 .set_single_cert(vec![cert.to_vec()], key.to_vec());
145
146 Ok(cx)
147 }
148
149 fn configure_client(cx: &mut ClientContext) {
150 let (cert, _key) = keys();
151 let cert = Input::from(cert);
152 let anchor = trust_anchor_util::cert_der_as_trust_anchor(cert).unwrap();
153 cx.config_mut().root_store.add_trust_anchors(&[anchor]);
154 }
155
156 // Like OpenSSL we generate certificates on the fly, but for OSX we
157 // also have to put them into a specific keychain. We put both the
158 // certificates and the keychain next to our binary.
159 //
160 // Right now I don't know of a way to programmatically create a
161 // self-signed certificate, so we just fork out to the `openssl` binary.
162 fn keys() -> (&'static [u8], &'static [u8]) {
163 static INIT: Once = Once::new();
164 static mut KEYS: *mut (Vec<u8>, Vec<u8>) = ptr::null_mut();
165
166 INIT.call_once(|| {
167 let (key, cert) = openssl_keys();
168 let path = t!(env::current_exe());
169 let path = path.parent().unwrap();
170 let keyfile = path.join("test.key");
171 let certfile = path.join("test.crt");
172 let config = path.join("openssl.config");
173
174 File::create(&config).unwrap().write_all(b"\
175 [req]\n\
176 distinguished_name=dn\n\
177 [ dn ]\n\
178 CN=localhost\n\
179 [ ext ]\n\
180 basicConstraints=CA:FALSE,pathlen:0\n\
181 subjectAltName = @alt_names
182 [alt_names]
183 DNS.1 = localhost
184 ").unwrap();
185
186 let subj = "/C=US/ST=Denial/L=Sprintfield/O=Dis/CN=localhost";
187 let output = t!(Command::new("openssl")
188 .arg("req")
189 .arg("-nodes")
190 .arg("-x509")
191 .arg("-newkey").arg("rsa:2048")
192 .arg("-config").arg(&config)
193 .arg("-extensions").arg("ext")
194 .arg("-subj").arg(subj)
195 .arg("-keyout").arg(&keyfile)
196 .arg("-out").arg(&certfile)
197 .arg("-days").arg("1")
198 .output());
199 assert!(output.status.success());
200
201 let crtout = t!(Command::new("openssl")
202 .arg("x509")
203 .arg("-outform").arg("der")
204 .arg("-in").arg(&certfile)
205 .output());
206 assert!(crtout.status.success());
207 let keyout = t!(Command::new("openssl")
208 .arg("rsa")
209 .arg("-outform").arg("der")
210 .arg("-in").arg(&keyfile)
211 .output());
212 assert!(keyout.status.success());
213
214 let cert = crtout.stdout;
215 let key = keyout.stdout;
216 unsafe {
217 KEYS = Box::into_raw(Box::new((cert, key)));
218 }
219 });
220 unsafe {
221 (&(*KEYS).0, &(*KEYS).1)
222 }
223 }
224 } else if #[cfg(any(feature = "force-openssl",
225 all(not(target_os = "macos"),
226 not(target_os = "windows"),
227 not(target_os = "ios"))))] {
228 use std::fs::File;
229 use std::env;
230 use std::sync::Once;
231
232 fn contexts() -> (tokio_tls::TlsAcceptor, tokio_tls::TlsConnector) {
233 let keys = openssl_keys();
234
235 let pkcs12 = t!(Identity::from_pkcs12(&keys.pkcs12_der, "foobar"));
236 let srv = TlsAcceptor::builder(pkcs12);
237
238 let cert = t!(native_tls::Certificate::from_der(&keys.cert_der));
239
240 let mut client = TlsConnector::builder();
241 t!(client.add_root_certificate(cert).build());
242
243 (t!(srv.build()).into(), t!(client.build()).into())
244 }
245 } else if #[cfg(any(target_os = "macos", target_os = "ios"))] {
246 use std::env;
247 use std::fs::File;
248 use std::sync::Once;
249
250 fn contexts() -> (tokio_tls::TlsAcceptor, tokio_tls::TlsConnector) {
251 let keys = openssl_keys();
252
253 let pkcs12 = t!(Identity::from_pkcs12(&keys.pkcs12_der, "foobar"));
254 let srv = TlsAcceptor::builder(pkcs12);
255
256 let cert = native_tls::Certificate::from_der(&keys.cert_der).unwrap();
257 let mut client = TlsConnector::builder();
258 client.add_root_certificate(cert);
259
260 (t!(srv.build()).into(), t!(client.build()).into())
261 }
262 } else {
263 use schannel;
264 use winapi;
265
266 use std::env;
267 use std::fs::File;
268 use std::io;
269 use std::mem;
270 use std::sync::Once;
271
272 use schannel::cert_context::CertContext;
273 use schannel::cert_store::{CertStore, CertAdd, Memory};
274 use winapi::shared::basetsd::*;
275 use winapi::shared::lmcons::*;
276 use winapi::shared::minwindef::*;
277 use winapi::shared::ntdef::WCHAR;
278 use winapi::um::minwinbase::*;
279 use winapi::um::sysinfoapi::*;
280 use winapi::um::timezoneapi::*;
281 use winapi::um::wincrypt::*;
282
/// Friendly name set on the generated Windows test certificate. It is used to
/// find that certificate, and stale copies of it, in the root store.
const FRIENDLY_NAME: &str = "tokio-tls localhost testing cert";
284
285 fn contexts() -> (tokio_tls::TlsAcceptor, tokio_tls::TlsConnector) {
286 let cert = localhost_cert();
287 let mut store = t!(Memory::new()).into_store();
288 t!(store.add_cert(&cert, CertAdd::Always));
289 let pkcs12_der = t!(store.export_pkcs12("foobar"));
290 let pkcs12 = t!(Identity::from_pkcs12(&pkcs12_der, "foobar"));
291
292 let srv = TlsAcceptor::builder(pkcs12);
293 let client = TlsConnector::builder();
294 (t!(srv.build()).into(), t!(client.build()).into())
295 }
296
297 // ====================================================================
298 // Magic!
299 //
// Lots of magic is happening here to wrangle certificates for running
// these tests on Windows. For more information see the test suite
// in the schannel-rs crate, as this is just copying that.
303 //
// The general gist of this though is that the only way to add custom
// trusted certificates is to add them to the system store of trust. To
// do that we go through the whole rigmarole here to generate a new
// self-signed certificate and then insert that into the system store.
308 //
309 // This generates some dialogs, so we print what we're doing sometimes,
310 // and otherwise we just manage the ephemeral certificates. Because
311 // they're in the system store we always ensure that they're only valid
312 // for a small period of time (e.g. 1 day).
313
/// Returns the `FRIENDLY_NAME` test certificate from the root store chosen by
/// `local_root_store`.
///
/// On the first call in this process, the store is scanned for copies with
/// `FRIENDLY_NAME`. Expired copies found before a time-valid one are deleted,
/// after printing a notice. If no time-valid copy is found, a new certificate
/// is generated and installed via `install_certificate`.
///
/// # Panics
/// Panics if opening the store, checking validity, deleting, or installing
/// fails, or if no certificate named `FRIENDLY_NAME` is found afterwards.
fn localhost_cert() -> CertContext {
    // The scan / cleanup / install step runs at most once per process.
    static INIT: Once = Once::new();
    INIT.call_once(|| {
        for cert in local_root_store().certs() {
            // Certificates without a readable friendly name are skipped.
            let name = match cert.friendly_name() {
                Ok(name) => name,
                Err(_) => continue,
            };
            if name != FRIENDLY_NAME {
                continue
            }
            if !cert.is_time_valid().unwrap() {
                // Stale copy from an earlier run: tell the user, then remove it.
                io::stdout().write_all(br#"

The tokio-tls test suite is about to delete an old copy of one of its
certificates from your root trust store. This certificate was only valid for one
day and it is no longer needed. The host should be "localhost" and the
description should mention "tokio-tls".

"#).unwrap();
                cert.delete().unwrap();
            } else {
                // A still-valid copy exists. This `return` leaves only the
                // `call_once` closure, skipping the install below.
                return
            }
        }

        install_certificate().unwrap();
    });

    // The closure above hands nothing back (and may have kept an existing
    // certificate instead of installing one), so look it up by name.
    for cert in local_root_store().certs() {
        let name = match cert.friendly_name() {
            Ok(name) => name,
            Err(_) => continue,
        };
        if name == FRIENDLY_NAME {
            return cert
        }
    }

    panic!("couldn't find a cert");
}
355
356 fn local_root_store() -> CertStore {
357 if env::var("CI").is_ok() {
358 CertStore::open_local_machine("Root").unwrap()
359 } else {
360 CertStore::open_current_user("Root").unwrap()
361 }
362 }
363
/// Creates a fresh self-signed `localhost` certificate and installs it into
/// the root store returned by `local_root_store`.
///
/// Steps:
/// 1. Opens the `"tokio-tls test suite"` machine key container, creating it
///    if it does not exist.
/// 2. Generates a new exportable RSA-2048 signature key in it.
/// 3. Creates a self-signed certificate for
///    `CN=localhost,O=tokio-tls,OU=tokio-tls,G=tokio_tls`, signed with
///    SHA256-RSA, expiring one day after the current system time.
/// 4. Sets its friendly name to `FRIENDLY_NAME`, prints a notice to stdout,
///    and adds it to the store with `CertAdd::ReplaceExisting`.
///
/// # Errors
/// Returns the last OS error when a CryptoAPI call fails. Errors from setting
/// the friendly name or adding the certificate to the store are propagated.
///
/// NOTE(review): `provider` and `hkey` do not appear to be released (e.g. via
/// `CryptReleaseContext` / `CryptDestroyKey`) on any path. This is presumably
/// tolerated because the function runs at most once per test process —
/// confirm.
fn install_certificate() -> io::Result<CertContext> {
    unsafe {
        let mut provider = 0;
        let mut hkey = 0;

        // NUL-terminated UTF-16 key container name. It is reused below in
        // `key_provider`, so it must outlive the certificate creation call.
        let mut buffer = "tokio-tls test suite".encode_utf16()
            .chain(Some(0))
            .collect::<Vec<_>>();
        let res = CryptAcquireContextW(&mut provider,
                                       buffer.as_ptr(),
                                       ptr::null_mut(),
                                       PROV_RSA_FULL,
                                       CRYPT_MACHINE_KEYSET);
        if res != TRUE {
            // create a new key container (since it does not exist)
            let res = CryptAcquireContextW(&mut provider,
                                           buffer.as_ptr(),
                                           ptr::null_mut(),
                                           PROV_RSA_FULL,
                                           CRYPT_NEWKEYSET | CRYPT_MACHINE_KEYSET);
            if res != TRUE {
                return Err(Error::last_os_error())
            }
        }

        // create a new keypair (RSA-2048): the upper 16 bits of the flags
        // carry the key length (0x0800 = 2048 bits); the key is exportable.
        let res = CryptGenKey(provider,
                              AT_SIGNATURE,
                              0x0800<<16 | CRYPT_EXPORTABLE,
                              &mut hkey);
        if res != TRUE {
            return Err(Error::last_os_error());
        }

        // start creating the certificate: encode the X.500 subject name
        let name = "CN=localhost,O=tokio-tls,OU=tokio-tls,\
                    G=tokio_tls".encode_utf16()
            .chain(Some(0))
            .collect::<Vec<_>>();
        let mut cname_buffer: [WCHAR; UNLEN as usize + 1] = mem::zeroed();
        // In/out length. `len()` counts WCHARs while the buffer is passed as
        // bytes, so this understates the real byte capacity (conservative).
        let mut cname_len = cname_buffer.len() as DWORD;
        let res = CertStrToNameW(X509_ASN_ENCODING,
                                 name.as_ptr(),
                                 CERT_X500_NAME_STR,
                                 ptr::null_mut(),
                                 cname_buffer.as_mut_ptr() as *mut u8,
                                 &mut cname_len,
                                 ptr::null_mut());
        if res != TRUE {
            return Err(Error::last_os_error());
        }

        // The encoded name serves as both subject and issuer (self-signed).
        let mut subject_issuer = CERT_NAME_BLOB {
            cbData: cname_len,
            pbData: cname_buffer.as_ptr() as *mut u8,
        };
        let mut key_provider = CRYPT_KEY_PROV_INFO {
            pwszContainerName: buffer.as_mut_ptr(),
            pwszProvName: ptr::null_mut(),
            dwProvType: PROV_RSA_FULL,
            dwFlags: CRYPT_MACHINE_KEYSET,
            cProvParam: 0,
            rgProvParam: ptr::null_mut(),
            dwKeySpec: AT_SIGNATURE,
        };
        let mut sig_algorithm = CRYPT_ALGORITHM_IDENTIFIER {
            pszObjId: szOID_RSA_SHA256RSA.as_ptr() as *mut _,
            Parameters: mem::zeroed(),
        };
        // Expiration = now + 1 day, computed through FILETIME arithmetic.
        let mut expiration_date: SYSTEMTIME = mem::zeroed();
        GetSystemTime(&mut expiration_date);
        let mut file_time: FILETIME = mem::zeroed();
        let res = SystemTimeToFileTime(&expiration_date,
                                       &mut file_time);
        if res != TRUE {
            return Err(Error::last_os_error());
        }
        let mut timestamp: u64 = file_time.dwLowDateTime as u64 |
            (file_time.dwHighDateTime as u64) << 32;
        // one day, timestamp unit is in 100 nanosecond intervals
        timestamp += (1E9 as u64) / 100 * (60 * 60 * 24);
        file_time.dwLowDateTime = timestamp as u32;
        file_time.dwHighDateTime = (timestamp >> 32) as u32;
        let res = FileTimeToSystemTime(&file_time,
                                       &mut expiration_date);
        if res != TRUE {
            return Err(Error::last_os_error());
        }

        // create a self signed certificate
        let cert_context = CertCreateSelfSignCertificate(
            0 as ULONG_PTR,
            &mut subject_issuer,
            0,
            &mut key_provider,
            &mut sig_algorithm,
            ptr::null_mut(),
            &mut expiration_date,
            ptr::null_mut());
        if cert_context.is_null() {
            return Err(Error::last_os_error());
        }

        // TODO: this is.. a terrible hack. Right now `schannel`
        // doesn't provide a public method to go from a raw
        // cert context pointer to the `CertContext` structure it
        // has, so we just fake it here with a transmute. This'll
        // probably break at some point, but hopefully by then
        // it'll have a method to do this!
        // (`mem::transmute` at least enforces equal sizes at compile time.)
        struct MyCertContext<T>(T);
        impl<T> Drop for MyCertContext<T> {
            fn drop(&mut self) {}
        }

        let cert_context = MyCertContext(cert_context);
        let cert_context: CertContext = mem::transmute(cert_context);

        cert_context.set_friendly_name(FRIENDLY_NAME)?;

        // install the certificate into the root store picked by
        // `local_root_store` (machine store on CI, current user otherwise)
        io::stdout().write_all(br#"

The tokio-tls test suite is about to add a certificate to your set of root
and trusted certificates. This certificate should be for the domain "localhost"
with the description related to "tokio-tls". This certificate is only valid
for one day and will be automatically deleted if you re-run the tokio-tls
test suite later.

"#).unwrap();
        local_root_store().add_cert(&cert_context,
                                    CertAdd::ReplaceExisting)?;
        Ok(cert_context)
    }
}
498 }
499 }
500
501 const AMT: usize = 128 * 1024;
502
copy_data<W: AsyncWrite + Unpin>(mut w: W) -> Result<usize, Error>503 async fn copy_data<W: AsyncWrite + Unpin>(mut w: W) -> Result<usize, Error> {
504 let mut data = vec![9; AMT as usize];
505 let mut amt = 0;
506 while !data.is_empty() {
507 let written = w.write(&data).await?;
508 if written <= data.len() {
509 amt += written;
510 data.resize(data.len() - written, 0);
511 } else {
512 w.write_all(&data).await?;
513 amt += data.len();
514 break;
515 }
516
517 println!("remaining: {}", data.len());
518 }
519 Ok(amt)
520 }
521
522 #[tokio::test]
client_to_server()523 async fn client_to_server() {
524 drop(env_logger::try_init());
525
526 // Create a server listening on a port, then figure out what that port is
527 let mut srv = t!(TcpListener::bind("127.0.0.1:0").await);
528 let addr = t!(srv.local_addr());
529
530 let (server_cx, client_cx) = contexts();
531
532 // Create a future to accept one socket, connect the ssl stream, and then
533 // read all the data from it.
534 let server = async move {
535 let mut incoming = srv.incoming();
536 let socket = t!(incoming.next().await.unwrap());
537 let mut socket = t!(server_cx.accept(socket).await);
538 let mut data = Vec::new();
539 t!(socket.read_to_end(&mut data).await);
540 data
541 };
542
543 // Create a future to connect to our server, connect the ssl stream, and
544 // then write a bunch of data to it.
545 let client = async move {
546 let socket = t!(TcpStream::connect(&addr).await);
547 let socket = t!(client_cx.connect("localhost", socket).await);
548 copy_data(socket).await
549 };
550
551 // Finally, run everything!
552 let (data, _) = join!(server, client);
553 // assert_eq!(amt, AMT);
554 assert!(data == vec![9; AMT]);
555 }
556
557 #[tokio::test]
server_to_client()558 async fn server_to_client() {
559 drop(env_logger::try_init());
560
561 // Create a server listening on a port, then figure out what that port is
562 let mut srv = t!(TcpListener::bind("127.0.0.1:0").await);
563 let addr = t!(srv.local_addr());
564
565 let (server_cx, client_cx) = contexts();
566
567 let server = async move {
568 let mut incoming = srv.incoming();
569 let socket = t!(incoming.next().await.unwrap());
570 let socket = t!(server_cx.accept(socket).await);
571 copy_data(socket).await
572 };
573
574 let client = async move {
575 let socket = t!(TcpStream::connect(&addr).await);
576 let mut socket = t!(client_cx.connect("localhost", socket).await);
577 let mut data = Vec::new();
578 t!(socket.read_to_end(&mut data).await);
579 data
580 };
581
582 // Finally, run everything!
583 let (_, data) = join!(server, client);
584 // assert_eq!(amt, AMT);
585 assert!(data == vec![9; AMT]);
586 }
587
588 #[tokio::test]
one_byte_at_a_time()589 async fn one_byte_at_a_time() {
590 const AMT: usize = 1024;
591 drop(env_logger::try_init());
592
593 let mut srv = t!(TcpListener::bind("127.0.0.1:0").await);
594 let addr = t!(srv.local_addr());
595
596 let (server_cx, client_cx) = contexts();
597
598 let server = async move {
599 let mut incoming = srv.incoming();
600 let socket = t!(incoming.next().await.unwrap());
601 let mut socket = t!(server_cx.accept(socket).await);
602 let mut amt = 0;
603 for b in std::iter::repeat(9).take(AMT) {
604 let data = [b as u8];
605 t!(socket.write_all(&data).await);
606 amt += 1;
607 }
608 amt
609 };
610
611 let client = async move {
612 let socket = t!(TcpStream::connect(&addr).await);
613 let mut socket = t!(client_cx.connect("localhost", socket).await);
614 let mut data = Vec::new();
615 loop {
616 let mut buf = [0; 1];
617 match socket.read_exact(&mut buf).await {
618 Ok(_) => data.extend_from_slice(&buf),
619 Err(ref err) if err.kind() == ErrorKind::UnexpectedEof => break,
620 Err(err) => panic!(err),
621 }
622 }
623 data
624 };
625
626 let (amt, data) = join!(server, client);
627 assert_eq!(amt, AMT);
628 assert!(data == vec![9; AMT as usize]);
629 }
630