1 // Assorted public API tests.
2 use std::env;
3 use std::fmt;
4 use std::io::{self, IoSlice, Read, Write};
5 use std::mem;
6 use std::sync::atomic::{AtomicUsize, Ordering};
7 use std::sync::Arc;
8 use std::sync::Mutex;
9 
10 use rustls;
11 
12 #[cfg(feature = "quic")]
13 use rustls::quic::{self, ClientQuicExt, QuicExt, ServerQuicExt};
14 use rustls::sign;
15 use rustls::ClientHello;
16 use rustls::KeyLog;
17 use rustls::Session;
18 use rustls::TLSError;
19 use rustls::{CipherSuite, ProtocolVersion, SignatureScheme};
20 use rustls::{ClientConfig, ClientSession, ResolvesClientCert};
21 use rustls::{ResolvesServerCert, ServerConfig, ServerSession};
22 use rustls::{Stream, StreamOwned};
23 use rustls::{SupportedCipherSuite, ALL_CIPHERSUITES};
24 
25 #[cfg(feature = "dangerous_configuration")]
26 use rustls::ClientCertVerified;
27 
28 use webpki;
29 
30 #[allow(dead_code)]
31 mod common;
32 use crate::common::*;
33 
alpn_test(server_protos: Vec<Vec<u8>>, client_protos: Vec<Vec<u8>>, agreed: Option<&[u8]>)34 fn alpn_test(server_protos: Vec<Vec<u8>>, client_protos: Vec<Vec<u8>>, agreed: Option<&[u8]>) {
35     let mut client_config = make_client_config(KeyType::RSA);
36     let mut server_config = make_server_config(KeyType::RSA);
37 
38     client_config.alpn_protocols = client_protos;
39     server_config.alpn_protocols = server_protos;
40 
41     let server_config = Arc::new(server_config);
42 
43     for client_config in AllClientVersions::new(client_config) {
44         let (mut client, mut server) =
45             make_pair_for_arc_configs(&Arc::new(client_config), &server_config);
46 
47         assert_eq!(client.get_alpn_protocol(), None);
48         assert_eq!(server.get_alpn_protocol(), None);
49         do_handshake(&mut client, &mut server);
50         assert_eq!(client.get_alpn_protocol(), agreed);
51         assert_eq!(server.get_alpn_protocol(), agreed);
52     }
53 }
54 
#[test]
fn alpn() {
    // Converts ASCII protocol names into the byte-vector form `alpn_test` expects.
    fn protos(names: &[&str]) -> Vec<Vec<u8>> {
        names.iter().map(|name| name.as_bytes().to_vec()).collect()
    }

    // Each case: (server offer, client offer, expected agreed protocol).
    let cases: Vec<(Vec<Vec<u8>>, Vec<Vec<u8>>, Option<&[u8]>)> = vec![
        // no support
        (protos(&[]), protos(&[]), None),
        // server support
        (protos(&["server-proto"]), protos(&[]), None),
        // client support
        (protos(&[]), protos(&["client-proto"]), None),
        // no overlap
        (protos(&["server-proto"]), protos(&["client-proto"]), None),
        // server chooses preference
        (
            protos(&["server-proto", "client-proto"]),
            protos(&["client-proto", "server-proto"]),
            Some(&b"server-proto"[..]),
        ),
        // case sensitive
        (protos(&["PROTO"]), protos(&["proto"]), None),
    ];

    for (server_protos, client_protos, agreed) in cases {
        alpn_test(server_protos, client_protos, agreed);
    }
}
83 
version_test( client_versions: Vec<ProtocolVersion>, server_versions: Vec<ProtocolVersion>, result: Option<ProtocolVersion>, )84 fn version_test(
85     client_versions: Vec<ProtocolVersion>,
86     server_versions: Vec<ProtocolVersion>,
87     result: Option<ProtocolVersion>,
88 ) {
89     let mut client_config = make_client_config(KeyType::RSA);
90     let mut server_config = make_server_config(KeyType::RSA);
91 
92     println!(
93         "version {:?} {:?} -> {:?}",
94         client_versions, server_versions, result
95     );
96 
97     if !client_versions.is_empty() {
98         client_config.versions = client_versions;
99     }
100 
101     if !server_versions.is_empty() {
102         server_config.versions = server_versions;
103     }
104 
105     let (mut client, mut server) = make_pair_for_configs(client_config, server_config);
106 
107     assert_eq!(client.get_protocol_version(), None);
108     assert_eq!(server.get_protocol_version(), None);
109     if result.is_none() {
110         let err = do_handshake_until_error(&mut client, &mut server);
111         assert_eq!(err.is_err(), true);
112     } else {
113         do_handshake(&mut client, &mut server);
114         assert_eq!(client.get_protocol_version(), result);
115         assert_eq!(server.get_protocol_version(), result);
116     }
117 }
118 
#[test]
fn versions() {
    // Each case: (client versions, server versions, expected negotiated version).
    // An empty list keeps that side's default; `None` means the handshake must fail.
    let cases = vec![
        // default -> 1.3
        (vec![], vec![], Some(ProtocolVersion::TLSv1_3)),
        // client default, server 1.2 -> 1.2
        (
            vec![],
            vec![ProtocolVersion::TLSv1_2],
            Some(ProtocolVersion::TLSv1_2),
        ),
        // client 1.2, server default -> 1.2
        (
            vec![ProtocolVersion::TLSv1_2],
            vec![],
            Some(ProtocolVersion::TLSv1_2),
        ),
        // client 1.2, server 1.3 -> fail
        (
            vec![ProtocolVersion::TLSv1_2],
            vec![ProtocolVersion::TLSv1_3],
            None,
        ),
        // client 1.3, server 1.2 -> fail
        (
            vec![ProtocolVersion::TLSv1_3],
            vec![ProtocolVersion::TLSv1_2],
            None,
        ),
        // client 1.3, server 1.2+1.3 -> 1.3
        (
            vec![ProtocolVersion::TLSv1_3],
            vec![ProtocolVersion::TLSv1_2, ProtocolVersion::TLSv1_3],
            Some(ProtocolVersion::TLSv1_3),
        ),
        // client 1.2+1.3, server 1.2 -> 1.2
        (
            vec![ProtocolVersion::TLSv1_3, ProtocolVersion::TLSv1_2],
            vec![ProtocolVersion::TLSv1_2],
            Some(ProtocolVersion::TLSv1_2),
        ),
    ];

    for (client_versions, server_versions, expected) in cases {
        version_test(client_versions, server_versions, expected);
    }
}
166 
/// Drains `reader` to end-of-stream and asserts that it yielded exactly `bytes`.
///
/// Panics if reading fails, if the byte count differs, or if the contents differ.
fn check_read(reader: &mut dyn io::Read, bytes: &[u8]) {
    let mut received = Vec::new();
    let count = reader.read_to_end(&mut received).unwrap();
    assert_eq!(bytes.len(), count);
    assert_eq!(bytes, &received[..]);
}
172 
// Application data the client writes before the handshake is held back and
// delivered to the server once the handshake has completed.
#[test]
fn buffered_client_data_sent() {
    let server_config = Arc::new(make_server_config(KeyType::RSA));

    for versioned in AllClientVersions::new(make_client_config(KeyType::RSA)) {
        let client_config = Arc::new(versioned);
        let (mut client, mut server) = make_pair_for_arc_configs(&client_config, &server_config);

        let accepted = client.write(b"hello").unwrap();
        assert_eq!(5, accepted);

        do_handshake(&mut client, &mut server);
        transfer(&mut client, &mut server);
        server.process_new_packets().unwrap();

        check_read(&mut server, b"hello");
    }
}
190 
// Application data the server writes before the handshake is held back and
// delivered to the client once the handshake has completed.
#[test]
fn buffered_server_data_sent() {
    let server_config = Arc::new(make_server_config(KeyType::RSA));

    for versioned in AllClientVersions::new(make_client_config(KeyType::RSA)) {
        let client_config = Arc::new(versioned);
        let (mut client, mut server) = make_pair_for_arc_configs(&client_config, &server_config);

        let accepted = server.write(b"hello").unwrap();
        assert_eq!(5, accepted);

        do_handshake(&mut client, &mut server);
        transfer(&mut server, &mut client);
        client.process_new_packets().unwrap();

        check_read(&mut client, b"hello");
    }
}
208 
// Data queued by both peers before the handshake is delivered in each
// direction once the handshake has completed.
#[test]
fn buffered_both_data_sent() {
    let server_config = Arc::new(make_server_config(KeyType::RSA));

    for versioned in AllClientVersions::new(make_client_config(KeyType::RSA)) {
        let client_config = Arc::new(versioned);
        let (mut client, mut server) = make_pair_for_arc_configs(&client_config, &server_config);

        let from_server = server.write(b"from-server!").unwrap();
        assert_eq!(12, from_server);
        let from_client = client.write(b"from-client!").unwrap();
        assert_eq!(12, from_client);

        do_handshake(&mut client, &mut server);

        // Deliver server -> client first, then client -> server.
        transfer(&mut server, &mut client);
        client.process_new_packets().unwrap();
        transfer(&mut client, &mut server);
        server.process_new_packets().unwrap();

        check_read(&mut client, b"from-server!");
        check_read(&mut server, b"from-client!");
    }
}
231 
// After a full handshake the client exposes the server's certificate chain,
// for every key type and client protocol version.
#[test]
fn client_can_get_server_cert() {
    for &kt in ALL_KEY_TYPES.iter() {
        for versioned in AllClientVersions::new(make_client_config(kt)) {
            let (mut client, mut server) =
                make_pair_for_configs(versioned, make_server_config(kt));
            do_handshake(&mut client, &mut server);

            let presented_chain = client.get_peer_certificates();
            assert_eq!(presented_chain, Some(kt.get_chain()));
        }
    }
}
245 
// The server certificate chain the client reports must be the same on a
// second connection made with the same configs (the case the test name
// describes as resumption) as on the first.
#[test]
fn client_can_get_server_cert_after_resumption() {
    for kt in ALL_KEY_TYPES.iter() {
        let server_config = Arc::new(make_server_config(*kt));
        for client_config in AllClientVersions::new(make_client_config(*kt)) {
            // Reuse one Arc'd config per side for both connections, as
            // `server_can_get_client_cert_after_resumption` does, rather than
            // cloning the configs for every pair.
            let client_config = Arc::new(client_config);

            let (mut client, mut server) =
                make_pair_for_arc_configs(&client_config, &server_config);
            do_handshake(&mut client, &mut server);
            let original_certs = client.get_peer_certificates();

            let (mut client, mut server) =
                make_pair_for_arc_configs(&client_config, &server_config);
            do_handshake(&mut client, &mut server);
            let resumed_certs = client.get_peer_certificates();

            assert_eq!(original_certs, resumed_certs);
        }
    }
}
267 
// With mandatory client auth, the server exposes the client's certificate
// chain once the handshake has completed.
#[test]
fn server_can_get_client_cert() {
    for &kt in ALL_KEY_TYPES.iter() {
        let mut authed_client_config = make_client_config(kt);
        authed_client_config
            .set_single_client_cert(kt.get_chain(), kt.get_key())
            .unwrap();

        let server_config = Arc::new(make_server_config_with_mandatory_client_auth(kt));

        for versioned in AllClientVersions::new(authed_client_config) {
            let client_config = Arc::new(versioned);
            let (mut client, mut server) =
                make_pair_for_arc_configs(&client_config, &server_config);
            do_handshake(&mut client, &mut server);

            let presented_chain = server.get_peer_certificates();
            assert_eq!(presented_chain, Some(kt.get_chain()));
        }
    }
}
288 
// The client certificate chain the server reports must be the same on a
// second connection made with the same Arc'd configs as on the first.
#[test]
fn server_can_get_client_cert_after_resumption() {
    for &kt in ALL_KEY_TYPES.iter() {
        let mut authed_client_config = make_client_config(kt);
        authed_client_config
            .set_single_client_cert(kt.get_chain(), kt.get_key())
            .unwrap();

        let server_config = Arc::new(make_server_config_with_mandatory_client_auth(kt));

        for versioned in AllClientVersions::new(authed_client_config) {
            let shared_client_config = Arc::new(versioned);

            let (mut client, mut server) =
                make_pair_for_arc_configs(&shared_client_config, &server_config);
            do_handshake(&mut client, &mut server);
            let first_certs = server.get_peer_certificates();

            let (mut client, mut server) =
                make_pair_for_arc_configs(&shared_client_config, &server_config);
            do_handshake(&mut client, &mut server);
            let second_certs = server.get_peer_certificates();

            assert_eq!(first_certs, second_certs);
        }
    }
}
314 
/// Reads exactly `expect` from `reader` in a single `read` call, then asserts
/// that the following read fails with `io::ErrorKind::ConnectionAborted`.
/// The close_notify tests use that error to observe a cleanly closed peer.
///
/// Panics if the first read errors, returns a different length or different
/// bytes, or if the follow-up read does not fail with `ConnectionAborted`.
fn check_read_and_close(reader: &mut dyn io::Read, expect: &[u8]) {
    let mut buf = vec![0u8; expect.len()];
    assert_eq!(expect.len(), reader.read(&mut buf).unwrap());
    assert_eq!(expect, &buf[..]);

    // Probe with a non-empty buffer. A zero-length read may legitimately return
    // `Ok(0)` instead of reporting the close, which would make this check fail
    // spuriously whenever `expect` is empty.
    let mut probe = [0u8; 1];
    let err = reader
        .read(&mut probe)
        .expect_err("expected the read after close to fail");
    assert_eq!(err.kind(), io::ErrorKind::ConnectionAborted);
}
325 
// A server-initiated close_notify must not overtake application data that
// was queued before it, and the client's data must still reach the server.
#[test]
fn server_close_notify() {
    let kt = KeyType::RSA;
    let mut authed_client_config = make_client_config(kt);
    authed_client_config
        .set_single_client_cert(kt.get_chain(), kt.get_key())
        .unwrap();

    let server_config = Arc::new(make_server_config_with_mandatory_client_auth(kt));

    for versioned in AllClientVersions::new(authed_client_config) {
        let client_config = Arc::new(versioned);
        let (mut client, mut server) = make_pair_for_arc_configs(&client_config, &server_config);
        do_handshake(&mut client, &mut server);

        // Queue data on both sides, then have the server close.
        let server_wrote = server.write(b"from-server!").unwrap();
        assert_eq!(12, server_wrote);
        let client_wrote = client.write(b"from-client!").unwrap();
        assert_eq!(12, client_wrote);
        server.send_close_notify();

        // Server -> client: the data arrives, followed by the close.
        transfer(&mut server, &mut client);
        client.process_new_packets().unwrap();
        check_read_and_close(&mut client, b"from-server!");

        // Client -> server: the client's data is still delivered.
        transfer(&mut client, &mut server);
        server.process_new_packets().unwrap();
        check_read(&mut server, b"from-client!");
    }
}
355 
// A client-initiated close_notify must not overtake application data that
// was queued before it, and the server's data must still reach the client.
#[test]
fn client_close_notify() {
    let kt = KeyType::RSA;
    let mut authed_client_config = make_client_config(kt);
    authed_client_config
        .set_single_client_cert(kt.get_chain(), kt.get_key())
        .unwrap();

    let server_config = Arc::new(make_server_config_with_mandatory_client_auth(kt));

    for versioned in AllClientVersions::new(authed_client_config) {
        let client_config = Arc::new(versioned);
        let (mut client, mut server) = make_pair_for_arc_configs(&client_config, &server_config);
        do_handshake(&mut client, &mut server);

        // Queue data on both sides, then have the client close.
        let server_wrote = server.write(b"from-server!").unwrap();
        assert_eq!(12, server_wrote);
        let client_wrote = client.write(b"from-client!").unwrap();
        assert_eq!(12, client_wrote);
        client.send_close_notify();

        // Client -> server: the data arrives, followed by the close.
        transfer(&mut client, &mut server);
        server.process_new_packets().unwrap();
        check_read_and_close(&mut server, b"from-client!");

        // Server -> client: the server's data is still delivered.
        transfer(&mut server, &mut client);
        client.process_new_packets().unwrap();
        check_read(&mut client, b"from-server!");
    }
}
385 
386 #[derive(Default)]
387 struct ServerCheckCertResolve {
388     expected_sni: Option<String>,
389     expected_sigalgs: Option<Vec<SignatureScheme>>,
390     expected_alpn: Option<Vec<Vec<u8>>>,
391 }
392 
393 impl ResolvesServerCert for ServerCheckCertResolve {
resolve(&self, client_hello: ClientHello) -> Option<sign::CertifiedKey>394     fn resolve(&self, client_hello: ClientHello) -> Option<sign::CertifiedKey> {
395         if client_hello.sigschemes().len() == 0 {
396             panic!("no signature schemes shared by client");
397         }
398 
399         if let Some(expected_sni) = &self.expected_sni {
400             let sni: &str = client_hello
401                 .server_name()
402                 .expect("sni unexpectedly absent")
403                 .into();
404             assert_eq!(expected_sni, sni);
405         }
406 
407         if let Some(expected_sigalgs) = &self.expected_sigalgs {
408             if expected_sigalgs != &client_hello.sigschemes() {
409                 panic!(
410                     "unexpected signature schemes (wanted {:?} got {:?})",
411                     self.expected_sigalgs,
412                     client_hello.sigschemes()
413                 );
414             }
415         }
416 
417         if let Some(expected_alpn) = &self.expected_alpn {
418             let alpn = client_hello
419                 .alpn()
420                 .expect("alpn unexpectedly absent");
421             assert_eq!(alpn.len(), expected_alpn.len());
422 
423             for (got, wanted) in alpn.iter().zip(expected_alpn.iter()) {
424                 assert_eq!(got, &wanted.as_slice());
425             }
426         }
427 
428         None
429     }
430 }
431 
// The server's certificate resolver must see the SNI value the client session
// was created with. The resolver supplies no certificate, so the handshake is
// expected to fail once its checks have passed.
#[test]
fn server_cert_resolve_with_sni() {
    for kt in ALL_KEY_TYPES.iter() {
        let client_config = make_client_config(*kt);
        let mut server_config = make_server_config(*kt);

        server_config.cert_resolver = Arc::new(ServerCheckCertResolve {
            expected_sni: Some("the-value-from-sni".into()),
            ..Default::default()
        });

        let mut client =
            ClientSession::new(&Arc::new(client_config), dns_name("the-value-from-sni"));
        let mut server = ServerSession::new(&Arc::new(server_config));

        let err = do_handshake_until_error(&mut client, &mut server);
        assert!(err.is_err(), "handshake unexpectedly succeeded");
    }
}
451 
// The server's certificate resolver must see the client's ALPN offer, in order.
// The resolver supplies no certificate, so the handshake is expected to fail
// once its checks have passed.
#[test]
fn server_cert_resolve_with_alpn() {
    for kt in ALL_KEY_TYPES.iter() {
        let mut client_config = make_client_config(*kt);
        client_config.alpn_protocols = vec!["foo".into(), "bar".into()];

        let mut server_config = make_server_config(*kt);
        server_config.cert_resolver = Arc::new(ServerCheckCertResolve {
            expected_alpn: Some(vec![b"foo".to_vec(), b"bar".to_vec()]),
            ..Default::default()
        });

        let mut client = ClientSession::new(&Arc::new(client_config), dns_name("sni-value"));
        let mut server = ServerSession::new(&Arc::new(server_config));

        let err = do_handshake_until_error(&mut client, &mut server);
        assert!(err.is_err(), "handshake unexpectedly succeeded");
    }
}
471 
// A client created for "some-host.com." must send SNI "some-host.com", without
// the trailing dot. The resolver supplies no certificate, so the handshake is
// expected to fail once the SNI check has passed.
#[test]
fn client_trims_terminating_dot() {
    for kt in ALL_KEY_TYPES.iter() {
        let client_config = make_client_config(*kt);
        let mut server_config = make_server_config(*kt);

        server_config.cert_resolver = Arc::new(ServerCheckCertResolve {
            expected_sni: Some("some-host.com".into()),
            ..Default::default()
        });

        let mut client = ClientSession::new(&Arc::new(client_config), dns_name("some-host.com."));
        let mut server = ServerSession::new(&Arc::new(server_config));

        let err = do_handshake_until_error(&mut client, &mut server);
        assert!(err.is_err(), "handshake unexpectedly succeeded");
    }
}
490 
check_sigalgs_reduced_by_ciphersuite( kt: KeyType, suite: CipherSuite, expected_sigalgs: Vec<SignatureScheme>, )491 fn check_sigalgs_reduced_by_ciphersuite(
492     kt: KeyType,
493     suite: CipherSuite,
494     expected_sigalgs: Vec<SignatureScheme>,
495 ) {
496     let mut client_config = make_client_config(kt);
497     client_config.ciphersuites = vec![find_suite(suite)];
498 
499     let mut server_config = make_server_config(kt);
500 
501     server_config.cert_resolver = Arc::new(ServerCheckCertResolve {
502         expected_sigalgs: Some(expected_sigalgs),
503         ..Default::default()
504     });
505 
506     let mut client = ClientSession::new(&Arc::new(client_config), dns_name("localhost"));
507     let mut server = ServerSession::new(&Arc::new(server_config));
508 
509     let err = do_handshake_until_error(&mut client, &mut server);
510     assert_eq!(err.is_err(), true);
511 }
512 
// With only an ECDHE-RSA suite enabled, the client must offer RSA signature
// schemes alone: PSS first, then PKCS#1, each from SHA-512 down to SHA-256.
#[test]
fn server_cert_resolve_reduces_sigalgs_for_rsa_ciphersuite() {
    let suite = CipherSuite::TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256;
    let rsa_schemes = vec![
        SignatureScheme::RSA_PSS_SHA512,
        SignatureScheme::RSA_PSS_SHA384,
        SignatureScheme::RSA_PSS_SHA256,
        SignatureScheme::RSA_PKCS1_SHA512,
        SignatureScheme::RSA_PKCS1_SHA384,
        SignatureScheme::RSA_PKCS1_SHA256,
    ];
    check_sigalgs_reduced_by_ciphersuite(KeyType::RSA, suite, rsa_schemes);
}
528 
// With only an ECDHE-ECDSA suite enabled, the client must offer the ECDSA
// schemes (P-384, then P-256) followed by Ed25519, and nothing else.
#[test]
fn server_cert_resolve_reduces_sigalgs_for_ecdsa_ciphersuite() {
    let suite = CipherSuite::TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256;
    let ecdsa_schemes = vec![
        SignatureScheme::ECDSA_NISTP384_SHA384,
        SignatureScheme::ECDSA_NISTP256_SHA256,
        SignatureScheme::ED25519,
    ];
    check_sigalgs_reduced_by_ciphersuite(KeyType::ECDSA, suite, ecdsa_schemes);
}
541 
542 struct ServerCheckNoSNI {}
543 
544 impl ResolvesServerCert for ServerCheckNoSNI {
resolve(&self, client_hello: ClientHello) -> Option<sign::CertifiedKey>545     fn resolve(&self, client_hello: ClientHello) -> Option<sign::CertifiedKey> {
546         assert!(client_hello.server_name().is_none());
547 
548         None
549     }
550 }
551 
// With `enable_sni = false` the client must omit the server name. The resolver
// supplies no certificate, so the handshake is expected to fail once its check
// has passed.
#[test]
fn client_with_sni_disabled_does_not_send_sni() {
    for kt in ALL_KEY_TYPES.iter() {
        let mut client_config = make_client_config(*kt);
        client_config.enable_sni = false;

        let mut server_config = make_server_config(*kt);
        server_config.cert_resolver = Arc::new(ServerCheckNoSNI {});
        let server_config = Arc::new(server_config);

        for client_config in AllClientVersions::new(client_config) {
            let mut client =
                ClientSession::new(&Arc::new(client_config), dns_name("value-not-sent"));
            let mut server = ServerSession::new(&server_config);

            let err = do_handshake_until_error(&mut client, &mut server);
            assert!(err.is_err(), "handshake unexpectedly succeeded");
        }
    }
}
572 
// The client must reject the server's certificate when it is not valid for the
// hostname the client session was created with.
#[test]
fn client_checks_server_certificate_with_given_name() {
    for &kt in ALL_KEY_TYPES.iter() {
        let base_client_config = make_client_config(kt);
        let server_config = Arc::new(make_server_config(kt));

        for versioned in AllClientVersions::new(base_client_config) {
            let client_config = Arc::new(versioned);
            let mut client =
                ClientSession::new(&client_config, dns_name("not-the-right-hostname.com"));
            let mut server = ServerSession::new(&server_config);

            let outcome = do_handshake_until_error(&mut client, &mut server);
            assert_eq!(
                outcome,
                Err(TLSErrorFromPeer::Client(TLSError::WebPKIError(
                    webpki::Error::CertNotValidForName
                )))
            );
        }
    }
}
596 
597 struct ClientCheckCertResolve {
598     query_count: AtomicUsize,
599     expect_queries: usize,
600 }
601 
602 impl ClientCheckCertResolve {
new(expect_queries: usize) -> ClientCheckCertResolve603     fn new(expect_queries: usize) -> ClientCheckCertResolve {
604         ClientCheckCertResolve {
605             query_count: AtomicUsize::new(0),
606             expect_queries: expect_queries,
607         }
608     }
609 }
610 
611 impl Drop for ClientCheckCertResolve {
drop(&mut self)612     fn drop(&mut self) {
613         let count = self.query_count.load(Ordering::SeqCst);
614         assert_eq!(count, self.expect_queries);
615     }
616 }
617 
618 impl ResolvesClientCert for ClientCheckCertResolve {
resolve( &self, acceptable_issuers: &[&[u8]], sigschemes: &[SignatureScheme], ) -> Option<sign::CertifiedKey>619     fn resolve(
620         &self,
621         acceptable_issuers: &[&[u8]],
622         sigschemes: &[SignatureScheme],
623     ) -> Option<sign::CertifiedKey> {
624         self.query_count
625             .fetch_add(1, Ordering::SeqCst);
626 
627         if acceptable_issuers.len() == 0 {
628             panic!("no issuers offered by server");
629         }
630 
631         if sigschemes.len() == 0 {
632             panic!("no signature schemes shared by server");
633         }
634 
635         None
636     }
637 
has_certs(&self) -> bool638     fn has_certs(&self) -> bool {
639         true
640     }
641 }
642 
// The client's certificate resolver is consulted when the server demands client
// auth. It offers nothing, so the server must fail with
// NoCertificatesPresented. `ClientCheckCertResolve::new(2)` verifies on drop
// that it was queried exactly twice.
#[test]
fn client_cert_resolve() {
    for &kt in ALL_KEY_TYPES.iter() {
        let mut resolving_client_config = make_client_config(kt);
        resolving_client_config.client_auth_cert_resolver =
            Arc::new(ClientCheckCertResolve::new(2));

        let server_config = Arc::new(make_server_config_with_mandatory_client_auth(kt));

        for versioned in AllClientVersions::new(resolving_client_config) {
            let (mut client, mut server) =
                make_pair_for_arc_configs(&Arc::new(versioned), &server_config);

            let outcome = do_handshake_until_error(&mut client, &mut server);
            assert_eq!(
                outcome,
                Err(TLSErrorFromPeer::Server(TLSError::NoCertificatesPresented))
            );
        }
    }
}
662 
// A client configured with a certificate completes a handshake against a
// server that requires client auth, for every key type and protocol version.
#[test]
fn client_auth_works() {
    for &kt in ALL_KEY_TYPES.iter() {
        let authed_client_config = make_client_config_with_auth(kt);
        let server_config = Arc::new(make_server_config_with_mandatory_client_auth(kt));

        for versioned in AllClientVersions::new(authed_client_config) {
            let client_config = Arc::new(versioned);
            let (mut client, mut server) =
                make_pair_for_arc_configs(&client_config, &server_config);
            do_handshake(&mut client, &mut server);
        }
    }
}
676 
677 #[cfg(feature = "dangerous_configuration")]
678 mod test_clientverifier {
679     use super::*;
680     use crate::common::MockClientVerifier;
681     use rustls::internal::msgs::enums::AlertDescription;
682     use rustls::internal::msgs::enums::ContentType;
683 
684     // Client is authorized!
ver_ok() -> Result<ClientCertVerified, TLSError>685     fn ver_ok() -> Result<ClientCertVerified, TLSError> {
686         Ok(rustls::ClientCertVerified::assertion())
687     }
688 
    // Verifier outcome for tests where certificate verification must never be
    // attempted: panics via `unreachable!()` if it is ever invoked.
    fn ver_unreachable() -> Result<ClientCertVerified, TLSError> {
        unreachable!()
    }
693 
694     // Verifier that returns an error that we can expect
ver_err() -> Result<ClientCertVerified, TLSError>695     fn ver_err() -> Result<ClientCertVerified, TLSError> {
696         Err(TLSError::General("test err".to_string()))
697     }
698 
699     #[test]
700     // Happy path, we resolve to a root, it is verified OK, should be able to connect
client_verifier_works()701     fn client_verifier_works() {
702         for kt in ALL_KEY_TYPES.iter() {
703             let client_verifier = MockClientVerifier {
704                 verified: ver_ok,
705                 subjects: Some(get_client_root_store(*kt).get_subjects()),
706                 mandatory: Some(true),
707                 offered_schemes: None,
708             };
709 
710             let mut server_config = ServerConfig::new(Arc::new(client_verifier));
711             server_config
712                 .set_single_cert(kt.get_chain(), kt.get_key())
713                 .unwrap();
714 
715             let server_config = Arc::new(server_config);
716             let client_config = make_client_config_with_auth(*kt);
717 
718             for client_config in AllClientVersions::new(client_config) {
719                 let (mut client, mut server) =
720                     make_pair_for_arc_configs(&Arc::new(client_config.clone()), &server_config);
721                 let err = do_handshake_until_error(&mut client, &mut server);
722                 assert_eq!(err, Ok(()));
723             }
724         }
725     }
726 
727     // Server offers no verification schemes
728     #[test]
client_verifier_no_schemes()729     fn client_verifier_no_schemes() {
730         for kt in ALL_KEY_TYPES.iter() {
731             let client_verifier = MockClientVerifier {
732                 verified: ver_ok,
733                 subjects: Some(get_client_root_store(*kt).get_subjects()),
734                 mandatory: Some(true),
735                 offered_schemes: Some(vec![]),
736             };
737 
738             let mut server_config = ServerConfig::new(Arc::new(client_verifier));
739             server_config
740                 .set_single_cert(kt.get_chain(), kt.get_key())
741                 .unwrap();
742 
743             let server_config = Arc::new(server_config);
744             let client_config = make_client_config_with_auth(*kt);
745 
746             for client_config in AllClientVersions::new(client_config) {
747                 let (mut client, mut server) =
748                     make_pair_for_arc_configs(&Arc::new(client_config.clone()), &server_config);
749                 let err = do_handshake_until_error(&mut client, &mut server);
750                 assert_eq!(
751                     err,
752                     Err(TLSErrorFromPeer::Client(TLSError::CorruptMessagePayload(
753                         ContentType::Handshake
754                     )))
755                 );
756             }
757         }
758     }
759 
760     // Common case, we do not find a root store to resolve to
761     #[test]
client_verifier_no_root()762     fn client_verifier_no_root() {
763         for kt in ALL_KEY_TYPES.iter() {
764             let client_verifier = MockClientVerifier {
765                 verified: ver_ok,
766                 subjects: None,
767                 mandatory: Some(true),
768                 offered_schemes: None,
769             };
770 
771             let mut server_config = ServerConfig::new(Arc::new(client_verifier));
772             server_config
773                 .set_single_cert(kt.get_chain(), kt.get_key())
774                 .unwrap();
775 
776             let server_config = Arc::new(server_config);
777             let client_config = make_client_config_with_auth(*kt);
778 
779             for client_config in AllClientVersions::new(client_config) {
780                 let mut server = ServerSession::new(&server_config);
781                 let mut client =
782                     ClientSession::new(&Arc::new(client_config), dns_name("notlocalhost"));
783                 let errs = do_handshake_until_both_error(&mut client, &mut server);
784                 assert_eq!(
785                     errs,
786                     Err(vec![
787                         TLSErrorFromPeer::Server(TLSError::General(
788                             "client rejected by client_auth_root_subjects".into()
789                         )),
790                         TLSErrorFromPeer::Client(TLSError::AlertReceived(
791                             AlertDescription::AccessDenied
792                         ))
793                     ])
794                 );
795             }
796         }
797     }
798 
799     // If we cannot resolve a root, we cannot decide if auth is mandatory
800     #[test]
client_verifier_no_auth_no_root()801     fn client_verifier_no_auth_no_root() {
802         for kt in ALL_KEY_TYPES.iter() {
803             let client_verifier = MockClientVerifier {
804                 verified: ver_unreachable,
805                 subjects: None,
806                 mandatory: Some(true),
807                 offered_schemes: None,
808             };
809 
810             let mut server_config = ServerConfig::new(Arc::new(client_verifier));
811             server_config
812                 .set_single_cert(kt.get_chain(), kt.get_key())
813                 .unwrap();
814 
815             let server_config = Arc::new(server_config);
816             let client_config = make_client_config(*kt);
817 
818             for client_config in AllClientVersions::new(client_config) {
819                 let mut server = ServerSession::new(&server_config);
820                 let mut client =
821                     ClientSession::new(&Arc::new(client_config), dns_name("notlocalhost"));
822                 let errs = do_handshake_until_both_error(&mut client, &mut server);
823                 assert_eq!(
824                     errs,
825                     Err(vec![
826                         TLSErrorFromPeer::Server(TLSError::General(
827                             "client rejected by client_auth_root_subjects".into()
828                         )),
829                         TLSErrorFromPeer::Client(TLSError::AlertReceived(
830                             AlertDescription::AccessDenied
831                         ))
832                     ])
833                 );
834             }
835         }
836     }
837 
838     // If we do have a root, we must do auth
839     #[test]
client_verifier_no_auth_yes_root()840     fn client_verifier_no_auth_yes_root() {
841         for kt in ALL_KEY_TYPES.iter() {
842             let client_verifier = MockClientVerifier {
843                 verified: ver_unreachable,
844                 subjects: Some(get_client_root_store(*kt).get_subjects()),
845                 mandatory: Some(true),
846                 offered_schemes: None,
847             };
848 
849             let mut server_config = ServerConfig::new(Arc::new(client_verifier));
850             server_config
851                 .set_single_cert(kt.get_chain(), kt.get_key())
852                 .unwrap();
853 
854             let server_config = Arc::new(server_config);
855             let client_config = make_client_config(*kt);
856 
857             for client_config in AllClientVersions::new(client_config) {
858                 println!("Failing: {:?}", client_config.versions);
859                 let mut server = ServerSession::new(&server_config);
860                 let mut client =
861                     ClientSession::new(&Arc::new(client_config), dns_name("localhost"));
862                 let errs = do_handshake_until_both_error(&mut client, &mut server);
863                 assert_eq!(
864                     errs,
865                     Err(vec![
866                         TLSErrorFromPeer::Server(TLSError::NoCertificatesPresented),
867                         TLSErrorFromPeer::Client(TLSError::AlertReceived(
868                             AlertDescription::CertificateRequired
869                         ))
870                     ])
871                 );
872             }
873         }
874     }
875 
876     #[test]
877     // Triple checks we propagate the TLSError through
client_verifier_fails_properly()878     fn client_verifier_fails_properly() {
879         for kt in ALL_KEY_TYPES.iter() {
880             let client_verifier = MockClientVerifier {
881                 verified: ver_err,
882                 subjects: Some(get_client_root_store(*kt).get_subjects()),
883                 mandatory: Some(true),
884                 offered_schemes: None,
885             };
886 
887             let mut server_config = ServerConfig::new(Arc::new(client_verifier));
888             server_config
889                 .set_single_cert(kt.get_chain(), kt.get_key())
890                 .unwrap();
891 
892             let server_config = Arc::new(server_config);
893             let client_config = make_client_config_with_auth(*kt);
894 
895             for client_config in AllClientVersions::new(client_config) {
896                 let mut server = ServerSession::new(&server_config);
897                 let mut client =
898                     ClientSession::new(&Arc::new(client_config), dns_name("localhost"));
899                 let err = do_handshake_until_error(&mut client, &mut server);
900                 assert_eq!(
901                     err,
902                     Err(TLSErrorFromPeer::Server(TLSError::General(
903                         "test err".into()
904                     )))
905                 );
906             }
907         }
908     }
909 
910     #[test]
911     // If a verifier returns a None on Mandatory-ness, then we error out
client_verifier_must_determine_client_auth_requirement_to_continue()912     fn client_verifier_must_determine_client_auth_requirement_to_continue() {
913         for kt in ALL_KEY_TYPES.iter() {
914             let client_verifier = MockClientVerifier {
915                 verified: ver_ok,
916                 subjects: Some(get_client_root_store(*kt).get_subjects()),
917                 mandatory: None,
918                 offered_schemes: None,
919             };
920 
921             let mut server_config = ServerConfig::new(Arc::new(client_verifier));
922             server_config
923                 .set_single_cert(kt.get_chain(), kt.get_key())
924                 .unwrap();
925 
926             let server_config = Arc::new(server_config);
927             let client_config = make_client_config_with_auth(*kt);
928 
929             for client_config in AllClientVersions::new(client_config) {
930                 let mut server = ServerSession::new(&server_config);
931                 let mut client =
932                     ClientSession::new(&Arc::new(client_config), dns_name("localhost"));
933                 let errs = do_handshake_until_both_error(&mut client, &mut server);
934                 assert_eq!(
935                     errs,
936                     Err(vec![
937                         TLSErrorFromPeer::Server(TLSError::General(
938                             "client rejected by client_auth_mandatory".into()
939                         )),
940                         TLSErrorFromPeer::Client(TLSError::AlertReceived(
941                             AlertDescription::AccessDenied
942                         ))
943                     ])
944                 );
945             }
946         }
947     }
948 } // mod test_clientverifier
949 
/// Once the client fails to process incoming data, every later call to
/// `process_new_packets` must keep reporting an error.
#[test]
fn client_error_is_sticky() {
    let (mut client, _) = make_pair(KeyType::RSA);
    // Hand-crafted bytes the client cannot process successfully.
    client
        .read_tls(&mut b"\x16\x03\x03\x00\x08\x0f\x00\x00\x04junk".as_ref())
        .unwrap();
    assert!(client.process_new_packets().is_err());
    assert!(client.process_new_packets().is_err());
}
961 
/// Once the server fails to process incoming data, every later call to
/// `process_new_packets` must keep reporting an error.
#[test]
fn server_error_is_sticky() {
    let (_, mut server) = make_pair(KeyType::RSA);
    // Hand-crafted bytes the server cannot process successfully.
    server
        .read_tls(&mut b"\x16\x03\x03\x00\x08\x0f\x00\x00\x04junk".as_ref())
        .unwrap();
    assert!(server.process_new_packets().is_err());
    assert!(server.process_new_packets().is_err());
}
973 
/// Compile-time check that `ServerSession` is `Send + Sync`.
#[test]
fn server_is_send_and_sync() {
    // The bound is checked by the compiler; the body does nothing at runtime.
    fn assert_send_sync<T: Send + Sync>(_: &T) {}

    let (_, server) = make_pair(KeyType::RSA);
    assert_send_sync(&server);
}
980 
/// Compile-time check that `ClientSession` is `Send + Sync`.
#[test]
fn client_is_send_and_sync() {
    // The bound is checked by the compiler; the body does nothing at runtime.
    fn assert_send_sync<T: Send + Sync>(_: &T) {}

    let (client, _) = make_pair(KeyType::RSA);
    assert_send_sync(&client);
}
987 
/// Before the handshake, the server accepts at most 32 plaintext bytes in
/// total (20 + 12). Exactly those bytes arrive once the handshake completes.
#[test]
fn server_respects_buffer_limit_pre_handshake() {
    let (mut client, mut server) = make_pair(KeyType::RSA);
    server.set_buffer_limit(32);

    let chunk = b"01234567890123456789";
    let first = server.write(chunk).unwrap();
    let second = server.write(chunk).unwrap();
    assert_eq!(first, 20);
    assert_eq!(second, 12);

    do_handshake(&mut client, &mut server);
    transfer(&mut server, &mut client);
    client.process_new_packets().unwrap();

    check_read(&mut client, b"01234567890123456789012345678901");
}
1013 
/// Same as the scalar pre-handshake test, but a single vectored write is
/// truncated to the 32-byte limit.
#[test]
fn server_respects_buffer_limit_pre_handshake_with_vectored_write() {
    let (mut client, mut server) = make_pair(KeyType::RSA);
    server.set_buffer_limit(32);

    let chunk = b"01234567890123456789";
    let slices = [IoSlice::new(chunk), IoSlice::new(chunk)];
    let accepted = server.write_vectored(&slices).unwrap();
    assert_eq!(accepted, 32);

    do_handshake(&mut client, &mut server);
    transfer(&mut server, &mut client);
    client.process_new_packets().unwrap();

    check_read(&mut client, b"01234567890123456789012345678901");
}
1036 
/// After the handshake, a 48-byte limit lets the server accept 20 and then
/// 6 plaintext bytes.
#[test]
fn server_respects_buffer_limit_post_handshake() {
    let (mut client, mut server) = make_pair(KeyType::RSA);

    // this test will vary in behaviour depending on the default suites
    do_handshake(&mut client, &mut server);
    server.set_buffer_limit(48);

    let chunk = b"01234567890123456789";
    let first = server.write(chunk).unwrap();
    let second = server.write(chunk).unwrap();
    assert_eq!(first, 20);
    assert_eq!(second, 6);

    transfer(&mut server, &mut client);
    client.process_new_packets().unwrap();

    check_read(&mut client, b"01234567890123456789012345");
}
1063 
/// Before the handshake, the client accepts at most 32 plaintext bytes in
/// total (20 + 12). Exactly those bytes arrive once the handshake completes.
#[test]
fn client_respects_buffer_limit_pre_handshake() {
    let (mut client, mut server) = make_pair(KeyType::RSA);
    client.set_buffer_limit(32);

    let chunk = b"01234567890123456789";
    let first = client.write(chunk).unwrap();
    let second = client.write(chunk).unwrap();
    assert_eq!(first, 20);
    assert_eq!(second, 12);

    do_handshake(&mut client, &mut server);
    transfer(&mut client, &mut server);
    server.process_new_packets().unwrap();

    check_read(&mut server, b"01234567890123456789012345678901");
}
1089 
/// Same as the scalar client pre-handshake test, but a single vectored
/// write is truncated to the 32-byte limit.
#[test]
fn client_respects_buffer_limit_pre_handshake_with_vectored_write() {
    let (mut client, mut server) = make_pair(KeyType::RSA);
    client.set_buffer_limit(32);

    let chunk = b"01234567890123456789";
    let slices = [IoSlice::new(chunk), IoSlice::new(chunk)];
    let accepted = client.write_vectored(&slices).unwrap();
    assert_eq!(accepted, 32);

    do_handshake(&mut client, &mut server);
    transfer(&mut client, &mut server);
    server.process_new_packets().unwrap();

    check_read(&mut server, b"01234567890123456789012345678901");
}
1112 
/// After the handshake, a 48-byte limit lets the client accept 20 and then
/// 6 plaintext bytes.
#[test]
fn client_respects_buffer_limit_post_handshake() {
    let (mut client, mut server) = make_pair(KeyType::RSA);

    do_handshake(&mut client, &mut server);
    client.set_buffer_limit(48);

    let chunk = b"01234567890123456789";
    let first = client.write(chunk).unwrap();
    let second = client.write(chunk).unwrap();
    assert_eq!(first, 20);
    assert_eq!(second, 6);

    transfer(&mut client, &mut server);
    server.process_new_packets().unwrap();

    check_read(&mut server, b"01234567890123456789012345");
}
1138 
1139 struct OtherSession<'a> {
1140     sess: &'a mut dyn Session,
1141     pub reads: usize,
1142     pub writevs: Vec<Vec<usize>>,
1143     fail_ok: bool,
1144     pub short_writes: bool,
1145     pub last_error: Option<rustls::TLSError>,
1146 }
1147 
1148 impl<'a> OtherSession<'a> {
new(sess: &'a mut dyn Session) -> OtherSession<'a>1149     fn new(sess: &'a mut dyn Session) -> OtherSession<'a> {
1150         OtherSession {
1151             sess,
1152             reads: 0,
1153             writevs: vec![],
1154             fail_ok: false,
1155             short_writes: false,
1156             last_error: None,
1157         }
1158     }
1159 
new_fails(sess: &'a mut dyn Session) -> OtherSession<'a>1160     fn new_fails(sess: &'a mut dyn Session) -> OtherSession<'a> {
1161         let mut os = OtherSession::new(sess);
1162         os.fail_ok = true;
1163         os
1164     }
1165 }
1166 
1167 impl<'a> io::Read for OtherSession<'a> {
read(&mut self, mut b: &mut [u8]) -> io::Result<usize>1168     fn read(&mut self, mut b: &mut [u8]) -> io::Result<usize> {
1169         self.reads += 1;
1170         self.sess.write_tls(b.by_ref())
1171     }
1172 }
1173 
1174 impl<'a> io::Write for OtherSession<'a> {
write(&mut self, _: &[u8]) -> io::Result<usize>1175     fn write(&mut self, _: &[u8]) -> io::Result<usize> {
1176         unreachable!()
1177     }
1178 
flush(&mut self) -> io::Result<()>1179     fn flush(&mut self) -> io::Result<()> {
1180         Ok(())
1181     }
1182 
write_vectored<'b>(&mut self, b: &[io::IoSlice<'b>]) -> io::Result<usize>1183     fn write_vectored<'b>(&mut self, b: &[io::IoSlice<'b>]) -> io::Result<usize> {
1184         let mut total = 0;
1185         let mut lengths = vec![];
1186         for bytes in b {
1187             let write_len = if self.short_writes {
1188                 if bytes.len() > 5 {
1189                     bytes.len() / 2
1190                 } else {
1191                     bytes.len()
1192                 }
1193             } else {
1194                 bytes.len()
1195             };
1196 
1197             let l = self
1198                 .sess
1199                 .read_tls(&mut io::Cursor::new(&bytes[..write_len]))?;
1200             lengths.push(l);
1201             total += l;
1202             if bytes.len() != l {
1203                 break;
1204             }
1205         }
1206 
1207         let rc = self.sess.process_new_packets();
1208         if !self.fail_ok {
1209             rc.unwrap();
1210         } else if rc.is_err() {
1211             self.last_error = rc.err();
1212         }
1213 
1214         self.writevs.push(lengths);
1215         Ok(total)
1216     }
1217 }
1218 
/// `complete_io` on a fresh client drives the whole handshake, both reading
/// and writing bytes.
#[test]
fn client_complete_io_for_handshake() {
    let (mut client, mut server) = make_pair(KeyType::RSA);

    assert!(client.is_handshaking());
    let (rdlen, wrlen) = client
        .complete_io(&mut OtherSession::new(&mut server))
        .unwrap();
    assert!(rdlen > 0 && wrlen > 0);
    assert!(!client.is_handshaking());
}
1230 
/// A client handshaking against an empty stream gets `UnexpectedEof`.
#[test]
fn client_complete_io_for_handshake_eof() {
    let (mut client, _) = make_pair(KeyType::RSA);
    let mut input = io::Cursor::new(Vec::new());

    assert!(client.is_handshaking());
    let err = client
        .complete_io(&mut input)
        .unwrap_err();
    assert_eq!(io::ErrorKind::UnexpectedEof, err.kind());
}
1242 
/// After the handshake, `complete_io` only writes: the two queued writes are
/// flushed in a single vectored write of two 42-byte chunks.
#[test]
fn client_complete_io_for_write() {
    for kt in ALL_KEY_TYPES.iter() {
        let (mut client, mut server) = make_pair(*kt);

        do_handshake(&mut client, &mut server);

        client
            .write(b"01234567890123456789")
            .unwrap();
        client
            .write(b"01234567890123456789")
            .unwrap();
        {
            let mut pipe = OtherSession::new(&mut server);
            let (rdlen, wrlen) = client.complete_io(&mut pipe).unwrap();
            assert!(rdlen == 0 && wrlen > 0);
            assert_eq!(pipe.writevs, vec![vec![42, 42]]);
        }
        check_read(&mut server, b"0123456789012345678901234567890123456789");
    }
}
1266 
/// After the handshake, `complete_io` with only inbound data reads once and
/// writes nothing.
#[test]
fn client_complete_io_for_read() {
    for kt in ALL_KEY_TYPES.iter() {
        let (mut client, mut server) = make_pair(*kt);
        do_handshake(&mut client, &mut server);

        server.write(b"01234567890123456789").unwrap();

        {
            let mut pipe = OtherSession::new(&mut server);
            let (read_bytes, written_bytes) = client.complete_io(&mut pipe).unwrap();
            assert!(read_bytes > 0 && written_bytes == 0);
            assert_eq!(pipe.reads, 1);
        }

        check_read(&mut client, b"01234567890123456789");
    }
}
1286 
/// `complete_io` on a fresh server drives the whole handshake, both reading
/// and writing bytes, for every key type.
#[test]
fn server_complete_io_for_handshake() {
    for kt in ALL_KEY_TYPES.iter() {
        let (mut client, mut server) = make_pair(*kt);

        assert!(server.is_handshaking());
        let (rdlen, wrlen) = server
            .complete_io(&mut OtherSession::new(&mut client))
            .unwrap();
        assert!(rdlen > 0 && wrlen > 0);
        assert!(!server.is_handshaking());
    }
}
1300 
/// A server handshaking against an empty stream gets `UnexpectedEof`.
#[test]
fn server_complete_io_for_handshake_eof() {
    let (_, mut server) = make_pair(KeyType::RSA);
    let mut input = io::Cursor::new(Vec::new());

    assert!(server.is_handshaking());
    let err = server
        .complete_io(&mut input)
        .unwrap_err();
    assert_eq!(io::ErrorKind::UnexpectedEof, err.kind());
}
1312 
/// After the handshake, the server's two queued writes are flushed by
/// `complete_io` as one vectored write of two 42-byte chunks.
#[test]
fn server_complete_io_for_write() {
    for kt in ALL_KEY_TYPES.iter() {
        let (mut client, mut server) = make_pair(*kt);
        do_handshake(&mut client, &mut server);

        let chunk = b"01234567890123456789";
        server.write(chunk).unwrap();
        server.write(chunk).unwrap();

        {
            let mut pipe = OtherSession::new(&mut client);
            let (read_bytes, written_bytes) = server.complete_io(&mut pipe).unwrap();
            assert!(read_bytes == 0 && written_bytes > 0);
            assert_eq!(pipe.writevs, vec![vec![42, 42]]);
        }

        check_read(&mut client, b"0123456789012345678901234567890123456789");
    }
}
1335 
/// After the handshake, server `complete_io` with only inbound data reads
/// once and writes nothing.
#[test]
fn server_complete_io_for_read() {
    for kt in ALL_KEY_TYPES.iter() {
        let (mut client, mut server) = make_pair(*kt);
        do_handshake(&mut client, &mut server);

        client.write(b"01234567890123456789").unwrap();

        {
            let mut pipe = OtherSession::new(&mut client);
            let (read_bytes, written_bytes) = server.complete_io(&mut pipe).unwrap();
            assert!(read_bytes > 0 && written_bytes == 0);
            assert_eq!(pipe.reads, 1);
        }

        check_read(&mut server, b"01234567890123456789");
    }
}
1355 
/// Writing through a borrowed client `Stream` delivers the plaintext to the server.
#[test]
fn client_stream_write() {
    for kt in ALL_KEY_TYPES.iter() {
        let (mut client, mut server) = make_pair(*kt);

        {
            let mut pipe = OtherSession::new(&mut server);
            let written = Stream::new(&mut client, &mut pipe)
                .write(b"hello")
                .unwrap();
            assert_eq!(written, 5);
        }

        check_read(&mut server, b"hello");
    }
}
1369 
/// Writing through an owning client `StreamOwned` delivers the plaintext to the server.
#[test]
fn client_streamowned_write() {
    for kt in ALL_KEY_TYPES.iter() {
        let (client, mut server) = make_pair(*kt);

        {
            let mut owned = StreamOwned::new(client, OtherSession::new(&mut server));
            let written = owned.write(b"hello").unwrap();
            assert_eq!(written, 5);
        }

        check_read(&mut server, b"hello");
    }
}
1383 
/// Reading through a borrowed client `Stream` yields data the server wrote.
#[test]
fn client_stream_read() {
    for kt in ALL_KEY_TYPES.iter() {
        let (mut client, mut server) = make_pair(*kt);
        server.write(b"world").unwrap();

        let mut pipe = OtherSession::new(&mut server);
        let mut reader = Stream::new(&mut client, &mut pipe);
        check_read(&mut reader, b"world");
    }
}
1398 
/// Reading through an owning client `StreamOwned` yields data the server wrote.
#[test]
fn client_streamowned_read() {
    for kt in ALL_KEY_TYPES.iter() {
        let (client, mut server) = make_pair(*kt);
        server.write(b"world").unwrap();

        let mut reader = StreamOwned::new(client, OtherSession::new(&mut server));
        check_read(&mut reader, b"world");
    }
}
1413 
/// Writing through a borrowed server `Stream` delivers the plaintext to the client.
#[test]
fn server_stream_write() {
    for kt in ALL_KEY_TYPES.iter() {
        let (mut client, mut server) = make_pair(*kt);

        {
            let mut pipe = OtherSession::new(&mut client);
            let written = Stream::new(&mut server, &mut pipe)
                .write(b"hello")
                .unwrap();
            assert_eq!(written, 5);
        }

        check_read(&mut client, b"hello");
    }
}
1427 
/// Writing through an owning server `StreamOwned` delivers the plaintext to the client.
#[test]
fn server_streamowned_write() {
    for kt in ALL_KEY_TYPES.iter() {
        let (mut client, server) = make_pair(*kt);

        {
            let mut owned = StreamOwned::new(server, OtherSession::new(&mut client));
            let written = owned.write(b"hello").unwrap();
            assert_eq!(written, 5);
        }

        check_read(&mut client, b"hello");
    }
}
1441 
/// Reading through a borrowed server `Stream` yields data the client wrote.
#[test]
fn server_stream_read() {
    for kt in ALL_KEY_TYPES.iter() {
        let (mut client, mut server) = make_pair(*kt);
        client.write(b"world").unwrap();

        let mut pipe = OtherSession::new(&mut client);
        let mut reader = Stream::new(&mut server, &mut pipe);
        check_read(&mut reader, b"world");
    }
}
1456 
/// Reading through an owning server `StreamOwned` yields data the client wrote.
#[test]
fn server_streamowned_read() {
    for kt in ALL_KEY_TYPES.iter() {
        let (mut client, server) = make_pair(*kt);
        client.write(b"world").unwrap();

        let mut reader = StreamOwned::new(server, OtherSession::new(&mut client));
        check_read(&mut reader, b"world");
    }
}
1471 
/// Socket stand-in: reads always report end-of-stream, and writes succeed
/// `after` times (consuming the whole buffer) before failing with `errkind`.
struct FailsWrites {
    /// Kind of the error returned once the successful writes are used up.
    errkind: io::ErrorKind,
    /// How many more writes will succeed.
    after: usize,
}

impl io::Read for FailsWrites {
    /// Always returns `Ok(0)`.
    fn read(&mut self, _buf: &mut [u8]) -> io::Result<usize> {
        Ok(0)
    }
}

impl io::Write for FailsWrites {
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        match self.after.checked_sub(1) {
            Some(remaining) => {
                self.after = remaining;
                Ok(buf.len())
            }
            None => Err(io::Error::new(self.errkind, "oops")),
        }
    }

    /// Nothing is buffered, so flushing always succeeds.
    fn flush(&mut self) -> io::Result<()> {
        Ok(())
    }
}
1497 
/// If the socket fails before the new plaintext is taken, `Stream::write`
/// surfaces the socket's error kind.
#[test]
fn stream_write_reports_underlying_io_error_before_plaintext_processed() {
    let (mut client, mut server) = make_pair(KeyType::RSA);
    do_handshake(&mut client, &mut server);

    // The very first socket write fails.
    let mut pipe = FailsWrites {
        errkind: io::ErrorKind::WouldBlock,
        after: 0,
    };
    client.write(b"hello").unwrap();

    let err = Stream::new(&mut client, &mut pipe)
        .write(b"world")
        .unwrap_err();
    assert_eq!(err.kind(), io::ErrorKind::WouldBlock);
}
1514 
/// If the socket fails only after one successful write, `Stream::write`
/// still reports the plaintext as written.
#[test]
fn stream_write_swallows_underlying_io_error_after_plaintext_processed() {
    let (mut client, mut server) = make_pair(KeyType::RSA);
    do_handshake(&mut client, &mut server);

    // One socket write succeeds, every later one fails.
    let mut pipe = FailsWrites {
        errkind: io::ErrorKind::WouldBlock,
        after: 1,
    };
    client.write(b"hello").unwrap();

    let outcome = Stream::new(&mut client, &mut pipe).write(b"world");
    assert_eq!(format!("{:?}", outcome), "Ok(5)");
}
1529 
make_disjoint_suite_configs() -> (ClientConfig, ServerConfig)1530 fn make_disjoint_suite_configs() -> (ClientConfig, ServerConfig) {
1531     let kt = KeyType::RSA;
1532     let mut server_config = make_server_config(kt);
1533     server_config.ciphersuites = vec![find_suite(
1534         CipherSuite::TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256,
1535     )];
1536 
1537     let mut client_config = make_client_config(kt);
1538     client_config.ciphersuites = vec![find_suite(
1539         CipherSuite::TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
1540     )];
1541 
1542     (client_config, server_config)
1543 }
1544 
/// A failed handshake makes every client `Stream::write` return the
/// received `HandshakeFailure` alert, not just the first one.
#[test]
fn client_stream_handshake_error() {
    let (client_config, server_config) = make_disjoint_suite_configs();
    let (mut client, mut server) = make_pair_for_configs(client_config, server_config);

    let mut pipe = OtherSession::new_fails(&mut server);
    let mut client_stream = Stream::new(&mut client, &mut pipe);
    for _attempt in 0..2 {
        let outcome = client_stream.write(b"hello");
        assert!(outcome.is_err());
        assert_eq!(
            format!("{:?}", outcome),
            "Err(Custom { kind: InvalidData, error: AlertReceived(HandshakeFailure) })"
        );
    }
}
1567 
/// A failed handshake makes every client `StreamOwned::write` return the
/// received `HandshakeFailure` alert, not just the first one.
#[test]
fn client_streamowned_handshake_error() {
    let (client_config, server_config) = make_disjoint_suite_configs();
    let (client, mut server) = make_pair_for_configs(client_config, server_config);

    let mut client_stream = StreamOwned::new(client, OtherSession::new_fails(&mut server));
    for _attempt in 0..2 {
        let outcome = client_stream.write(b"hello");
        assert!(outcome.is_err());
        assert_eq!(
            format!("{:?}", outcome),
            "Err(Custom { kind: InvalidData, error: AlertReceived(HandshakeFailure) })"
        );
    }
}
1588 
/// A server `Stream::read` during a handshake with no common suite reports
/// `PeerIncompatibleError`.
#[test]
fn server_stream_handshake_error() {
    let (client_config, server_config) = make_disjoint_suite_configs();
    let (mut client, mut server) = make_pair_for_configs(client_config, server_config);

    client.write(b"world").unwrap();

    let mut pipe = OtherSession::new_fails(&mut client);
    let mut buf = [0u8; 5];
    let outcome = Stream::new(&mut server, &mut pipe).read(&mut buf);
    assert!(outcome.is_err());
    assert_eq!(
        format!("{:?}", outcome),
        "Err(Custom { kind: InvalidData, error: PeerIncompatibleError(\"no ciphersuites in common\") })"
    );
}
1608 
/// A server `StreamOwned::read` during a handshake with no common suite
/// reports `PeerIncompatibleError`.
#[test]
fn server_streamowned_handshake_error() {
    let (client_config, server_config) = make_disjoint_suite_configs();
    let (mut client, server) = make_pair_for_configs(client_config, server_config);

    client.write(b"world").unwrap();

    let mut server_stream = StreamOwned::new(server, OtherSession::new_fails(&mut client));
    let mut buf = [0u8; 5];
    let outcome = server_stream.read(&mut buf);
    assert!(outcome.is_err());
    assert_eq!(
        format!("{:?}", outcome),
        "Err(Custom { kind: InvalidData, error: PeerIncompatibleError(\"no ciphersuites in common\") })"
    );
}
1626 
/// `ServerConfig` implements `Clone`.
#[test]
fn server_config_is_clone() {
    let _ = ServerConfig::clone(&make_server_config(KeyType::RSA));
}
1631 
/// `ClientConfig` implements `Clone`.
#[test]
fn client_config_is_clone() {
    let _ = ClientConfig::clone(&make_client_config(KeyType::RSA));
}
1636 
/// `ClientSession` implements `Debug`.
#[test]
fn client_session_is_debug() {
    let (client, _) = make_pair(KeyType::RSA);
    let rendered = format!("{:?}", client);
    println!("{}", rendered);
}
1642 
/// `ServerSession` implements `Debug`.
#[test]
fn server_session_is_debug() {
    let (_, server) = make_pair(KeyType::RSA);
    let rendered = format!("{:?}", server);
    println!("{}", rendered);
}
1648 
/// When the handshake fails, server `complete_io` returns an error. By then
/// the server has nothing left to write, and the client has received the
/// `HandshakeFailure` alert.
#[test]
fn server_complete_io_for_handshake_ending_with_alert() {
    let (client_config, server_config) = make_disjoint_suite_configs();
    let (mut client, mut server) = make_pair_for_configs(client_config, server_config);

    assert!(server.is_handshaking());

    let mut pipe = OtherSession::new_fails(&mut client);
    let rc = server.complete_io(&mut pipe);
    assert!(rc.is_err(), "server io failed due to handshake failure");
    assert!(!server.wants_write(), "but server did send its alert");
    assert_eq!(
        format!("{:?}", pipe.last_error),
        "Some(AlertReceived(HandshakeFailure))",
        "which was received by client"
    );
}
1666 
/// The server reports the client's SNI hostname once the handshake is done,
/// and no hostname before it starts.
#[test]
fn server_exposes_offered_sni() {
    let kt = KeyType::RSA;
    for version_config in AllClientVersions::new(make_client_config(kt)) {
        let client_config = Arc::new(version_config);
        let mut client = ClientSession::new(&client_config, dns_name("second.testserver.com"));
        let server_config = Arc::new(make_server_config(kt));
        let mut server = ServerSession::new(&server_config);

        assert_eq!(server.get_sni_hostname(), None);
        do_handshake(&mut client, &mut server);
        assert_eq!(server.get_sni_hostname(), Some("second.testserver.com"));
    }
}
1680 
/// A mixed-case SNI is reported in lowercase. Per the original note, webpki
/// does this lowercasing in its DNSName type.
#[test]
fn server_exposes_offered_sni_smashed_to_lowercase() {
    let kt = KeyType::RSA;
    for version_config in AllClientVersions::new(make_client_config(kt)) {
        let client_config = Arc::new(version_config);
        let mut client = ClientSession::new(&client_config, dns_name("SECOND.TESTServer.com"));
        let server_config = Arc::new(make_server_config(kt));
        let mut server = ServerSession::new(&server_config);

        assert_eq!(server.get_sni_hostname(), None);
        do_handshake(&mut client, &mut server);
        assert_eq!(server.get_sni_hostname(), Some("second.testserver.com"));
    }
}
1695 
/// The offered SNI hostname is recorded (lowercased) even when the
/// certificate resolver finds no certificate for it.
#[test]
fn server_exposes_offered_sni_even_if_resolver_fails() {
    let kt = KeyType::RSA;

    let mut server_config = make_server_config(kt);
    // A resolver with no registered names.
    server_config.cert_resolver = Arc::new(rustls::ResolvesServerCertUsingSNI::new());
    let server_config = Arc::new(server_config);

    for version_config in AllClientVersions::new(make_client_config(kt)) {
        let mut server = ServerSession::new(&server_config);
        let client_config = Arc::new(version_config);
        let mut client = ClientSession::new(&client_config, dns_name("thisdoesNOTexist.com"));

        assert_eq!(server.get_sni_hostname(), None);
        transfer(&mut client, &mut server);
        let expected = Err(TLSError::General(
            "no server certificate chain resolved".to_string(),
        ));
        assert_eq!(server.process_new_packets(), expected);
        assert_eq!(server.get_sni_hostname(), Some("thisdoesnotexist.com"));
    }
}
1721 
#[test]
fn sni_resolver_works() {
    let kt = KeyType::RSA;
    let mut resolver = rustls::ResolvesServerCertUsingSNI::new();
    let rsa_key = sign::RSASigningKey::new(&kt.get_key()).unwrap();
    let key: Arc<Box<dyn sign::SigningKey>> = Arc::new(Box::new(rsa_key));

    // Register the test chain under "localhost" only.
    resolver
        .add("localhost", sign::CertifiedKey::new(kt.get_chain(), key))
        .unwrap();

    let server_config = {
        let mut cfg = make_server_config(kt);
        cfg.cert_resolver = Arc::new(resolver);
        Arc::new(cfg)
    };

    // A client asking for the registered name completes the handshake.
    let mut matching_server = ServerSession::new(&server_config);
    let mut matching_client =
        ClientSession::new(&Arc::new(make_client_config(kt)), dns_name("localhost"));
    let outcome = do_handshake_until_error(&mut matching_client, &mut matching_server);
    assert_eq!(outcome, Ok(()));

    // Any other name resolves no certificate, and the server reports the error.
    let mut other_server = ServerSession::new(&server_config);
    let mut other_client =
        ClientSession::new(&Arc::new(make_client_config(kt)), dns_name("notlocalhost"));
    let outcome = do_handshake_until_error(&mut other_client, &mut other_server);
    let expected = TLSErrorFromPeer::Server(TLSError::General(
        "no server certificate chain resolved".into(),
    ));
    assert_eq!(outcome, Err(expected));
}
1755 
/// `ResolvesServerCertUsingSNI::add` must refuse names the certificate does
/// not cover and strings that are not valid DNS names.
#[test]
fn sni_resolver_rejects_wrong_names() {
    let kt = KeyType::RSA;
    let mut resolver = rustls::ResolvesServerCertUsingSNI::new();
    let signing_key = sign::RSASigningKey::new(&kt.get_key()).unwrap();
    let signing_key: Arc<Box<dyn sign::SigningKey>> = Arc::new(Box::new(signing_key));

    // The test chain is valid for "localhost", so registering it there succeeds.
    assert_eq!(
        Ok(()),
        resolver.add(
            "localhost",
            sign::CertifiedKey::new(kt.get_chain(), signing_key.clone())
        )
    );
    // A well-formed name that the certificate does not cover is refused.
    assert_eq!(
        Err(TLSError::General(
            "The server certificate is not valid for the given name".into()
        )),
        resolver.add(
            "not-localhost",
            sign::CertifiedKey::new(kt.get_chain(), signing_key.clone())
        )
    );
    // A string that is not a DNS name at all is refused as a bad DNS name.
    // The non-ASCII character is written as a `\u{..}` escape so the test does
    // not depend on the source file's encoding surviving intact.
    assert_eq!(
        Err(TLSError::General("Bad DNS name".into())),
        resolver.add(
            "not ascii \u{1f645}",
            sign::CertifiedKey::new(kt.get_chain(), signing_key)
        )
    );
}
1787 
#[test]
fn sni_resolver_rejects_bad_certs() {
    let kt = KeyType::RSA;
    let mut resolver = rustls::ResolvesServerCertUsingSNI::new();
    let rsa_key = sign::RSASigningKey::new(&kt.get_key()).unwrap();
    let key: Arc<Box<dyn sign::SigningKey>> = Arc::new(Box::new(rsa_key));

    // An empty chain has no end-entity certificate to check the name against.
    let empty = sign::CertifiedKey::new(vec![], key.clone());
    assert_eq!(
        resolver.add("localhost", empty),
        Err(TLSError::General(
            "No end-entity certificate in certificate chain".into()
        ))
    );

    // A one-byte blob is not a well-formed certificate.
    let garbage_chain = vec![rustls::Certificate(vec![0xa0])];
    let garbage = sign::CertifiedKey::new(garbage_chain, key);
    assert_eq!(
        resolver.add("localhost", garbage),
        Err(TLSError::General(
            "End-entity certificate in certificate chain is syntactically invalid".into()
        ))
    );
}
1816 
do_exporter_test(client_config: ClientConfig, server_config: ServerConfig)1817 fn do_exporter_test(client_config: ClientConfig, server_config: ServerConfig) {
1818     let mut client_secret = [0u8; 64];
1819     let mut server_secret = [0u8; 64];
1820 
1821     let (mut client, mut server) = make_pair_for_configs(client_config, server_config);
1822 
1823     assert_eq!(
1824         Err(TLSError::HandshakeNotComplete),
1825         client.export_keying_material(&mut client_secret, b"label", Some(b"context"))
1826     );
1827     assert_eq!(
1828         Err(TLSError::HandshakeNotComplete),
1829         server.export_keying_material(&mut server_secret, b"label", Some(b"context"))
1830     );
1831     do_handshake(&mut client, &mut server);
1832 
1833     assert_eq!(
1834         Ok(()),
1835         client.export_keying_material(&mut client_secret, b"label", Some(b"context"))
1836     );
1837     assert_eq!(
1838         Ok(()),
1839         server.export_keying_material(&mut server_secret, b"label", Some(b"context"))
1840     );
1841     assert_eq!(client_secret.to_vec(), server_secret.to_vec());
1842 
1843     assert_eq!(
1844         Ok(()),
1845         client.export_keying_material(&mut client_secret, b"label", None)
1846     );
1847     assert_ne!(client_secret.to_vec(), server_secret.to_vec());
1848     assert_eq!(
1849         Ok(()),
1850         server.export_keying_material(&mut server_secret, b"label", None)
1851     );
1852     assert_eq!(client_secret.to_vec(), server_secret.to_vec());
1853 }
1854 
#[test]
fn test_tls12_exporter() {
    // Pin the client to TLS 1.2 and run the exporter checks for every key type.
    for &kt in ALL_KEY_TYPES.iter() {
        let mut tls12_client = make_client_config(kt);
        let server = make_server_config(kt);
        tls12_client.versions = vec![ProtocolVersion::TLSv1_2];
        do_exporter_test(tls12_client, server);
    }
}
1865 
#[test]
fn test_tls13_exporter() {
    // Pin the client to TLS 1.3 and run the exporter checks for every key type.
    for &kt in ALL_KEY_TYPES.iter() {
        let mut tls13_client = make_client_config(kt);
        let server = make_server_config(kt);
        tls13_client.versions = vec![ProtocolVersion::TLSv1_3];
        do_exporter_test(tls13_client, server);
    }
}
1876 
do_suite_test( client_config: ClientConfig, server_config: ServerConfig, expect_suite: &'static SupportedCipherSuite, expect_version: ProtocolVersion, )1877 fn do_suite_test(
1878     client_config: ClientConfig,
1879     server_config: ServerConfig,
1880     expect_suite: &'static SupportedCipherSuite,
1881     expect_version: ProtocolVersion,
1882 ) {
1883     println!(
1884         "do_suite_test {:?} {:?}",
1885         expect_version, expect_suite.suite
1886     );
1887     let (mut client, mut server) = make_pair_for_configs(client_config, server_config);
1888 
1889     assert_eq!(None, client.get_negotiated_ciphersuite());
1890     assert_eq!(None, server.get_negotiated_ciphersuite());
1891     assert_eq!(None, client.get_protocol_version());
1892     assert_eq!(None, server.get_protocol_version());
1893     assert_eq!(true, client.is_handshaking());
1894     assert_eq!(true, server.is_handshaking());
1895 
1896     transfer(&mut client, &mut server);
1897     server.process_new_packets().unwrap();
1898 
1899     assert_eq!(true, client.is_handshaking());
1900     assert_eq!(true, server.is_handshaking());
1901     assert_eq!(None, client.get_protocol_version());
1902     assert_eq!(Some(expect_version), server.get_protocol_version());
1903     assert_eq!(None, client.get_negotiated_ciphersuite());
1904     assert_eq!(Some(expect_suite), server.get_negotiated_ciphersuite());
1905 
1906     transfer(&mut server, &mut client);
1907     client.process_new_packets().unwrap();
1908 
1909     assert_eq!(Some(expect_suite), client.get_negotiated_ciphersuite());
1910     assert_eq!(Some(expect_suite), server.get_negotiated_ciphersuite());
1911 
1912     transfer(&mut client, &mut server);
1913     server.process_new_packets().unwrap();
1914     transfer(&mut server, &mut client);
1915     client.process_new_packets().unwrap();
1916 
1917     assert_eq!(false, client.is_handshaking());
1918     assert_eq!(false, server.is_handshaking());
1919     assert_eq!(Some(expect_version), client.get_protocol_version());
1920     assert_eq!(Some(expect_version), server.get_protocol_version());
1921     assert_eq!(Some(expect_suite), client.get_negotiated_ciphersuite());
1922     assert_eq!(Some(expect_suite), server.get_negotiated_ciphersuite());
1923 }
1924 
find_suite(suite: CipherSuite) -> &'static SupportedCipherSuite1925 fn find_suite(suite: CipherSuite) -> &'static SupportedCipherSuite {
1926     for scs in ALL_CIPHERSUITES.iter() {
1927         if scs.suite == suite {
1928             return scs;
1929         }
1930     }
1931 
1932     panic!("find_suite given unsuppported suite");
1933 }
1934 
1935 static TEST_CIPHERSUITES: [(ProtocolVersion, KeyType, CipherSuite); 9] = [
1936     (
1937         ProtocolVersion::TLSv1_3,
1938         KeyType::RSA,
1939         CipherSuite::TLS13_CHACHA20_POLY1305_SHA256,
1940     ),
1941     (
1942         ProtocolVersion::TLSv1_3,
1943         KeyType::RSA,
1944         CipherSuite::TLS13_AES_256_GCM_SHA384,
1945     ),
1946     (
1947         ProtocolVersion::TLSv1_3,
1948         KeyType::RSA,
1949         CipherSuite::TLS13_AES_128_GCM_SHA256,
1950     ),
1951     (
1952         ProtocolVersion::TLSv1_2,
1953         KeyType::ECDSA,
1954         CipherSuite::TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256,
1955     ),
1956     (
1957         ProtocolVersion::TLSv1_2,
1958         KeyType::RSA,
1959         CipherSuite::TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256,
1960     ),
1961     (
1962         ProtocolVersion::TLSv1_2,
1963         KeyType::ECDSA,
1964         CipherSuite::TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
1965     ),
1966     (
1967         ProtocolVersion::TLSv1_2,
1968         KeyType::ECDSA,
1969         CipherSuite::TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
1970     ),
1971     (
1972         ProtocolVersion::TLSv1_2,
1973         KeyType::RSA,
1974         CipherSuite::TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
1975     ),
1976     (
1977         ProtocolVersion::TLSv1_2,
1978         KeyType::RSA,
1979         CipherSuite::TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
1980     ),
1981 ];
1982 
#[test]
fn negotiated_ciphersuite_default() {
    // With unmodified configs, every key type should negotiate TLS 1.3 with
    // ChaCha20-Poly1305.
    let expected_suite = find_suite(CipherSuite::TLS13_CHACHA20_POLY1305_SHA256);
    for &kt in ALL_KEY_TYPES.iter() {
        let client = make_client_config(kt);
        let server = make_server_config(kt);
        do_suite_test(client, server, expected_suite, ProtocolVersion::TLSv1_3);
    }
}
1994 
/// Every ciphersuite rustls supports must appear in `TEST_CIPHERSUITES`, so
/// that the negotiation tests exercise all of them.
#[test]
fn all_suites_covered() {
    assert_eq!(ALL_CIPHERSUITES.len(), TEST_CIPHERSUITES.len());

    // Equal lengths alone would not catch a duplicated entry hiding a missing
    // one, so also check membership suite by suite.
    for scs in ALL_CIPHERSUITES.iter() {
        assert!(
            TEST_CIPHERSUITES
                .iter()
                .any(|&(_, _, suite)| suite == scs.suite),
            "{:?} is missing from TEST_CIPHERSUITES",
            scs.suite
        );
    }
}
1999 
#[test]
fn negotiated_ciphersuite_client() {
    // Restrict the client to exactly one suite and version; the server must follow.
    for &(version, kt, suite) in TEST_CIPHERSUITES.iter() {
        let only_suite = find_suite(suite);
        let mut restricted = make_client_config(kt);
        restricted.ciphersuites = vec![only_suite];
        restricted.versions = vec![version];

        do_suite_test(restricted, make_server_config(kt), only_suite, version);
    }
}
2012 
#[test]
fn negotiated_ciphersuite_server() {
    // Restrict the server to exactly one suite and version; the client must follow.
    for &(version, kt, suite) in TEST_CIPHERSUITES.iter() {
        let only_suite = find_suite(suite);
        let mut restricted = make_server_config(kt);
        restricted.ciphersuites = vec![only_suite];
        restricted.versions = vec![version];

        do_suite_test(make_client_config(kt), restricted, only_suite, version);
    }
}
2025 
/// One entry recorded by `KeyLogToVec`: the arguments of a single
/// `KeyLog::log` call, copied into owned storage.
#[derive(Debug, PartialEq)]
struct KeyLogItem {
    /// The label passed to `log`, e.g. "CLIENT_RANDOM".
    label: String,
    /// The `client` bytes passed to `log`.
    client_random: Vec<u8>,
    /// The secret bytes passed to `log`.
    secret: Vec<u8>,
}
2032 
2033 struct KeyLogToVec {
2034     label: &'static str,
2035     items: Mutex<Vec<KeyLogItem>>,
2036 }
2037 
2038 impl KeyLogToVec {
new(who: &'static str) -> Self2039     fn new(who: &'static str) -> Self {
2040         KeyLogToVec {
2041             label: who,
2042             items: Mutex::new(vec![]),
2043         }
2044     }
2045 
take(&self) -> Vec<KeyLogItem>2046     fn take(&self) -> Vec<KeyLogItem> {
2047         mem::replace(&mut self.items.lock().unwrap(), vec![])
2048     }
2049 }
2050 
2051 impl KeyLog for KeyLogToVec {
log(&self, label: &str, client: &[u8], secret: &[u8])2052     fn log(&self, label: &str, client: &[u8], secret: &[u8]) {
2053         let value = KeyLogItem {
2054             label: label.into(),
2055             client_random: client.into(),
2056             secret: secret.into(),
2057         };
2058 
2059         println!("key log {:?}: {:?}", self.label, value);
2060 
2061         self.items.lock().unwrap().push(value);
2062     }
2063 }
2064 
#[test]
fn key_log_for_tls12() {
    let client_key_log = Arc::new(KeyLogToVec::new("client"));
    let server_key_log = Arc::new(KeyLogToVec::new("server"));
    let kt = KeyType::RSA;

    let client_config = {
        let mut cfg = make_client_config(kt);
        cfg.versions = vec![ProtocolVersion::TLSv1_2];
        cfg.key_log = client_key_log.clone();
        Arc::new(cfg)
    };
    let server_config = {
        let mut cfg = make_server_config(kt);
        cfg.key_log = server_key_log.clone();
        Arc::new(cfg)
    };

    // Full handshake: both peers log the same single CLIENT_RANDOM entry.
    let (mut client, mut server) = make_pair_for_arc_configs(&client_config, &server_config);
    do_handshake(&mut client, &mut server);
    let full_log = client_key_log.take();
    assert_eq!(full_log, server_key_log.take());
    assert_eq!(full_log.len(), 1);
    assert_eq!(full_log[0].label, "CLIENT_RANDOM");

    // Second handshake on the same configs: same shape, and the logged secret
    // matches the one from the full handshake.
    let (mut client, mut server) = make_pair_for_arc_configs(&client_config, &server_config);
    do_handshake(&mut client, &mut server);
    let resume_log = client_key_log.take();
    assert_eq!(resume_log, server_key_log.take());
    assert_eq!(resume_log.len(), 1);
    assert_eq!(resume_log[0].label, "CLIENT_RANDOM");
    assert_eq!(resume_log[0].secret, full_log[0].secret);
}
2101 
#[test]
fn key_log_for_tls13() {
    /// Labels the client is expected to log, in this order.
    const CLIENT_LABELS: [&str; 5] = [
        "CLIENT_HANDSHAKE_TRAFFIC_SECRET",
        "SERVER_HANDSHAKE_TRAFFIC_SECRET",
        "SERVER_TRAFFIC_SECRET_0",
        "EXPORTER_SECRET",
        "CLIENT_TRAFFIC_SECRET_0",
    ];

    /// Checks one handshake's worth of client and server key-log entries.
    fn check_logs(client_log: &[KeyLogItem], server_log: &[KeyLogItem]) {
        assert_eq!(client_log.len(), CLIENT_LABELS.len());
        for (item, &label) in client_log.iter().zip(CLIENT_LABELS.iter()) {
            assert_eq!(item.label, label);
        }
        // The two handshake traffic secrets sit in swapped positions in the
        // server's log; the remaining entries line up one-to-one.
        assert_eq!(client_log[0], server_log[1]);
        assert_eq!(client_log[1], server_log[0]);
        assert_eq!(client_log[2..5], server_log[2..5]);
    }

    let client_key_log = Arc::new(KeyLogToVec::new("client"));
    let server_key_log = Arc::new(KeyLogToVec::new("server"));
    let kt = KeyType::RSA;

    let client_config = {
        let mut cfg = make_client_config(kt);
        cfg.versions = vec![ProtocolVersion::TLSv1_3];
        cfg.key_log = client_key_log.clone();
        Arc::new(cfg)
    };
    let server_config = {
        let mut cfg = make_server_config(kt);
        cfg.key_log = server_key_log.clone();
        Arc::new(cfg)
    };

    // Full handshake.
    let (mut client, mut server) = make_pair_for_arc_configs(&client_config, &server_config);
    do_handshake(&mut client, &mut server);
    check_logs(&client_key_log.take(), &server_key_log.take());

    // Second handshake on the same configs logs the same set of secrets.
    let (mut client, mut server) = make_pair_for_arc_configs(&client_config, &server_config);
    do_handshake(&mut client, &mut server);
    check_logs(&client_key_log.take(), &server_key_log.take());
}
2163 
#[test]
fn vectored_write_for_server_appdata() {
    let (mut client, mut server) = make_pair(KeyType::RSA);
    do_handshake(&mut client, &mut server);

    // Two separate 20-byte application writes from the server.
    let chunk = b"01234567890123456789";
    server.write(chunk).unwrap();
    server.write(chunk).unwrap();

    {
        let mut pipe = OtherSession::new(&mut client);
        let written = server.write_tls(&mut pipe).unwrap();
        // Both 42-byte records leave in one vectored write.
        assert_eq!(written, 84);
        assert_eq!(pipe.writevs, vec![vec![42, 42]]);
    }
    check_read(&mut client, b"0123456789012345678901234567890123456789");
}
2183 
#[test]
fn vectored_write_for_client_appdata() {
    let (mut client, mut server) = make_pair(KeyType::RSA);
    do_handshake(&mut client, &mut server);

    // Two separate 20-byte application writes from the client.
    let chunk = b"01234567890123456789";
    client.write(chunk).unwrap();
    client.write(chunk).unwrap();

    {
        let mut pipe = OtherSession::new(&mut server);
        let written = client.write_tls(&mut pipe).unwrap();
        // Both 42-byte records leave in one vectored write.
        assert_eq!(written, 84);
        assert_eq!(pipe.writevs, vec![vec![42, 42]]);
    }
    check_read(&mut server, b"0123456789012345678901234567890123456789");
}
2203 
#[test]
fn vectored_write_for_server_handshake() {
    let (mut client, mut server) = make_pair(KeyType::RSA);

    // Queue 30 bytes of application data before any handshake bytes are exchanged.
    server.write(b"01234567890123456789").unwrap();
    server.write(b"0123456789").unwrap();

    transfer(&mut client, &mut server);
    server.process_new_packets().unwrap();
    {
        let mut pipe = OtherSession::new(&mut client);
        let written = server.write_tls(&mut pipe).unwrap();
        // Exact sizes are deliberately not asserted, to avoid a brittle test.
        assert!(written > 4000); // it's large because it carries the cert chain
        assert_eq!(pipe.writevs.len(), 1); // a single writev...
        assert!(pipe.writevs[0].len() > 3); // ...of at least server hello/cert/serverkx
    }

    client.process_new_packets().unwrap();
    transfer(&mut client, &mut server);
    server.process_new_packets().unwrap();
    {
        let mut pipe = OtherSession::new(&mut client);
        let written = server.write_tls(&mut pipe).unwrap();
        assert_eq!(written, 177);
        assert_eq!(pipe.writevs, vec![vec![103, 42, 32]]);
    }

    assert!(!server.is_handshaking());
    assert!(!client.is_handshaking());
    check_read(&mut client, b"012345678901234567890123456789");
}
2238 
#[test]
fn vectored_write_for_client_handshake() {
    let (mut client, mut server) = make_pair(KeyType::RSA);

    // Queue 30 bytes of application data before any handshake bytes are exchanged.
    client.write(b"01234567890123456789").unwrap();
    client.write(b"0123456789").unwrap();
    {
        let mut pipe = OtherSession::new(&mut server);
        let written = client.write_tls(&mut pipe).unwrap();
        // Exact sizes are deliberately not asserted, to avoid a brittle test.
        assert!(written > 200); // just the client hello
        assert_eq!(pipe.writevs.len(), 1); // a single writev...
        assert_eq!(pipe.writevs[0].len(), 1); // ...holding only the client hello
    }

    transfer(&mut server, &mut client);
    client.process_new_packets().unwrap();

    {
        let mut pipe = OtherSession::new(&mut server);
        let written = client.write_tls(&mut pipe).unwrap();
        assert_eq!(written, 138);
        // CCS, finished, then two application datas
        assert_eq!(pipe.writevs, vec![vec![6, 58, 42, 32]]);
    }

    assert!(!server.is_handshaking());
    assert!(!client.is_handshaking());
    check_read(&mut server, b"012345678901234567890123456789");
}
2271 
#[test]
fn vectored_write_with_slow_client() {
    let (mut client, mut server) = make_pair(KeyType::RSA);
    client.set_buffer_limit(32);
    do_handshake(&mut client, &mut server);

    server.write(b"01234567890123456789").unwrap();
    {
        let mut pipe = OtherSession::new(&mut client);
        // `short_writes` presumably makes the pipe accept only part of each write
        // (see common), so draining the 42 bytes of TLS data takes several calls.
        pipe.short_writes = true;
        let total: usize = (0..6)
            .map(|_| server.write_tls(&mut pipe).unwrap())
            .sum();
        assert_eq!(total, 42);
        assert_eq!(pipe.writevs, vec![vec![21], vec![10], vec![5], vec![3], vec![3]]);
    }
    check_read(&mut client, b"01234567890123456789");
}
2300 
2301 struct ServerStorage {
2302     storage: Arc<dyn rustls::StoresServerSessions>,
2303     put_count: AtomicUsize,
2304     get_count: AtomicUsize,
2305     take_count: AtomicUsize,
2306 }
2307 
2308 impl ServerStorage {
new() -> ServerStorage2309     fn new() -> ServerStorage {
2310         ServerStorage {
2311             storage: rustls::ServerSessionMemoryCache::new(1024),
2312             put_count: AtomicUsize::new(0),
2313             get_count: AtomicUsize::new(0),
2314             take_count: AtomicUsize::new(0),
2315         }
2316     }
2317 
puts(&self) -> usize2318     fn puts(&self) -> usize {
2319         self.put_count.load(Ordering::SeqCst)
2320     }
gets(&self) -> usize2321     fn gets(&self) -> usize {
2322         self.get_count.load(Ordering::SeqCst)
2323     }
takes(&self) -> usize2324     fn takes(&self) -> usize {
2325         self.take_count.load(Ordering::SeqCst)
2326     }
2327 }
2328 
2329 impl fmt::Debug for ServerStorage {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result2330     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
2331         write!(
2332             f,
2333             "(put: {:?}, get: {:?}, take: {:?})",
2334             self.put_count, self.get_count, self.take_count
2335         )
2336     }
2337 }
2338 
2339 impl rustls::StoresServerSessions for ServerStorage {
put(&self, key: Vec<u8>, value: Vec<u8>) -> bool2340     fn put(&self, key: Vec<u8>, value: Vec<u8>) -> bool {
2341         self.put_count
2342             .fetch_add(1, Ordering::SeqCst);
2343         self.storage.put(key, value)
2344     }
2345 
get(&self, key: &[u8]) -> Option<Vec<u8>>2346     fn get(&self, key: &[u8]) -> Option<Vec<u8>> {
2347         self.get_count
2348             .fetch_add(1, Ordering::SeqCst);
2349         self.storage.get(key)
2350     }
2351 
take(&self, key: &[u8]) -> Option<Vec<u8>>2352     fn take(&self, key: &[u8]) -> Option<Vec<u8>> {
2353         self.take_count
2354             .fetch_add(1, Ordering::SeqCst);
2355         self.storage.take(key)
2356     }
2357 }
2358 
#[test]
fn tls13_stateful_resumption() {
    /// Number of certificates the client holds from its peer, if any.
    fn peer_cert_count(client: &ClientSession) -> Option<usize> {
        client.get_peer_certificates().map(|certs| certs.len())
    }

    /// Asserts the (put, get, take) call counts recorded by `storage`.
    fn assert_calls(storage: &ServerStorage, puts: usize, gets: usize, takes: usize) {
        assert_eq!(storage.puts(), puts);
        assert_eq!(storage.gets(), gets);
        assert_eq!(storage.takes(), takes);
    }

    let kt = KeyType::RSA;
    let client_config = {
        let mut cfg = make_client_config(kt);
        cfg.versions = vec![ProtocolVersion::TLSv1_3];
        Arc::new(cfg)
    };

    let storage = Arc::new(ServerStorage::new());
    let server_config = {
        let mut cfg = make_server_config(kt);
        cfg.session_storage = storage.clone();
        Arc::new(cfg)
    };

    // Full handshake: one session is stored, nothing is looked up.
    let (mut client, mut server) = make_pair_for_arc_configs(&client_config, &server_config);
    let (full_c2s, full_s2c) = do_handshake(&mut client, &mut server);
    assert_calls(&storage, 1, 0, 0);
    assert_eq!(peer_cert_count(&client), Some(3));

    // First resumption: the client sends more and the server less than in the
    // full handshake; one stored session is taken and a new one stored.
    let (mut client, mut server) = make_pair_for_arc_configs(&client_config, &server_config);
    let (resume_c2s, resume_s2c) = do_handshake(&mut client, &mut server);
    assert!(resume_c2s > full_c2s);
    assert!(resume_s2c < full_s2c);
    assert_calls(&storage, 2, 0, 1);
    assert_eq!(peer_cert_count(&client), Some(3));

    // Second resumption: same traffic volume as the first, one more put and take.
    let (mut client, mut server) = make_pair_for_arc_configs(&client_config, &server_config);
    let (resume2_c2s, resume2_s2c) = do_handshake(&mut client, &mut server);
    assert_eq!(resume2_s2c, resume_s2c);
    assert_eq!(resume2_c2s, resume_c2s);
    assert_calls(&storage, 3, 0, 2);
    assert_eq!(peer_cert_count(&client), Some(3));
}
2414 
#[test]
fn tls13_stateless_resumption() {
    /// Number of certificates the client holds from its peer, if any.
    fn peer_cert_count(client: &ClientSession) -> Option<usize> {
        client.get_peer_certificates().map(|certs| certs.len())
    }

    /// Asserts that `storage` has not been touched at all.
    fn assert_untouched(storage: &ServerStorage) {
        assert_eq!(storage.puts(), 0);
        assert_eq!(storage.gets(), 0);
        assert_eq!(storage.takes(), 0);
    }

    let kt = KeyType::RSA;
    let client_config = {
        let mut cfg = make_client_config(kt);
        cfg.versions = vec![ProtocolVersion::TLSv1_3];
        Arc::new(cfg)
    };

    // With a ticketer configured, the session store is never used.
    let storage = Arc::new(ServerStorage::new());
    let server_config = {
        let mut cfg = make_server_config(kt);
        cfg.ticketer = rustls::Ticketer::new();
        cfg.session_storage = storage.clone();
        Arc::new(cfg)
    };

    // Full handshake.
    let (mut client, mut server) = make_pair_for_arc_configs(&client_config, &server_config);
    let (full_c2s, full_s2c) = do_handshake(&mut client, &mut server);
    assert_untouched(&storage);
    assert_eq!(peer_cert_count(&client), Some(3));

    // First resumption: the client sends more and the server less than before.
    let (mut client, mut server) = make_pair_for_arc_configs(&client_config, &server_config);
    let (resume_c2s, resume_s2c) = do_handshake(&mut client, &mut server);
    assert!(resume_c2s > full_c2s);
    assert!(resume_s2c < full_s2c);
    assert_untouched(&storage);
    assert_eq!(peer_cert_count(&client), Some(3));

    // Second resumption: same traffic volume as the first.
    let (mut client, mut server) = make_pair_for_arc_configs(&client_config, &server_config);
    let (resume2_c2s, resume2_s2c) = do_handshake(&mut client, &mut server);
    assert_eq!(resume2_s2c, resume_s2c);
    assert_eq!(resume2_c2s, resume_c2s);
    assert_untouched(&storage);
    assert_eq!(peer_cert_count(&client), Some(3));
}
2471 
2472 #[cfg(feature = "quic")]
2473 mod test_quic {
2474     use super::*;
2475 
2476     // Returns the sender's next secrets to use, or the receiver's error.
step( send: &mut dyn Session, recv: &mut dyn Session, ) -> Result<Option<quic::Keys>, TLSError>2477     fn step(
2478         send: &mut dyn Session,
2479         recv: &mut dyn Session,
2480     ) -> Result<Option<quic::Keys>, TLSError> {
2481         let mut buf = Vec::new();
2482         let secrets = loop {
2483             let prev = buf.len();
2484             if let Some(x) = send.write_hs(&mut buf) {
2485                 break Some(x);
2486             }
2487             if prev == buf.len() {
2488                 break None;
2489             }
2490         };
2491         if let Err(e) = recv.read_hs(&buf) {
2492             return Err(e);
2493         } else {
2494             assert_eq!(recv.get_alert(), None);
2495         }
2496         Ok(secrets)
2497     }
2498 
    // Drives pairs of QUIC sessions through four order-dependent scenarios that share
    // one client config and one server config:
    //   1. a full handshake, checking transport parameters and that both sides derive
    //      mirrored handshake and 1-RTT keys;
    //   2. a second connection that reports a ciphersuite up front, derives 0-RTT keys
    //      and must have early data accepted;
    //   3. a connection from a clone of the client config that also offers ALPN "foo",
    //      where the server must not derive 0-RTT keys and early data is not accepted;
    //   4. a connection to "example.com" (instead of "localhost"), which must fail with
    //      a BadCertificate alert on the client.
    // Scenarios 2 and 3 depend on state left behind by scenario 1 (presumably the
    // stored resumption session in the shared client config; TODO confirm).
    #[test]
    fn test_quic_handshake() {
        fn equal_dir_keys(x: &quic::DirectionalKeys, y: &quic::DirectionalKeys) -> bool {
            // Check that these two sets of keys are equal. The quic module's unit tests validate
            // that the IV and the keys are consistent, so we can just check the IV here.
            x.packet.iv.nonce_for(42).as_ref() == y.packet.iv.nonce_for(42).as_ref()
        }
        // Two key sets belong to opposite ends of one connection when each side's local
        // keys equal the other side's remote keys.
        fn compatible_keys(x: &quic::Keys, y: &quic::Keys) -> bool {
            equal_dir_keys(&x.local, &y.remote) && equal_dir_keys(&x.remote, &y.local)
        }

        // Both ends are limited to TLS 1.3. The client enables early data and the server
        // allows up to 0xffffffff bytes of it. Only the server sets ALPN ("foo") here; the
        // client config's ALPN list is left as make_client_config built it.
        let kt = KeyType::RSA;
        let mut client_config = make_client_config(kt);
        client_config.versions = vec![ProtocolVersion::TLSv1_3];
        client_config.enable_early_data = true;
        let client_config = Arc::new(client_config);
        let mut server_config = make_server_config(kt);
        server_config.versions = vec![ProtocolVersion::TLSv1_3];
        server_config.max_early_data_size = 0xffffffff;
        server_config.alpn_protocols = vec!["foo".into()];
        let server_config = Arc::new(server_config);
        let client_params = &b"client params"[..];
        let server_params = &b"server params"[..];

        // full handshake
        // Each `step(a, b)` moves one flight from `a` to `b` and returns any new keys `a`
        // derived while writing it.
        let mut client =
            ClientSession::new_quic(&client_config, dns_name("localhost"), client_params.into());
        let mut server = ServerSession::new_quic(&server_config, server_params.into());
        // Client's first flight: no new client keys and no 0-RTT keys on this first
        // connection; the server now has the client's transport parameters.
        let client_initial = step(&mut client, &mut server).unwrap();
        assert!(client_initial.is_none());
        assert!(client.get_0rtt_keys().is_none());
        assert_eq!(server.get_quic_transport_parameters(), Some(client_params));
        // Server's reply yields its handshake keys; the server has no 0-RTT keys either.
        let server_hs = step(&mut server, &mut client)
            .unwrap()
            .unwrap();
        assert!(server.get_0rtt_keys().is_none());
        // Client's next flight yields handshake keys that must mirror the server's.
        let client_hs = step(&mut client, &mut server)
            .unwrap()
            .unwrap();
        assert!(compatible_keys(&server_hs, &client_hs));
        assert!(client.is_handshaking());
        // Next server flight yields the server's 1-RTT keys and finishes the client's
        // handshake; the client now has the server's transport parameters.
        let server_1rtt = step(&mut server, &mut client)
            .unwrap()
            .unwrap();
        assert!(!client.is_handshaking());
        assert_eq!(client.get_quic_transport_parameters(), Some(server_params));
        assert!(server.is_handshaking());
        // Client's final flight finishes the server's handshake. The 1-RTT keys must be
        // mirrored, and the server's handshake and 1-RTT keys must not look compatible
        // (sanity check that new keys were actually derived).
        let client_1rtt = step(&mut client, &mut server)
            .unwrap()
            .unwrap();
        assert!(!server.is_handshaking());
        assert!(compatible_keys(&server_1rtt, &client_1rtt));
        assert!(!compatible_keys(&server_hs, &server_1rtt));
        // With both handshakes complete, further steps yield no new keys.
        assert!(
            step(&mut client, &mut server)
                .unwrap()
                .is_none()
        );
        assert!(
            step(&mut server, &mut client)
                .unwrap()
                .is_none()
        );

        // 0-RTT handshake
        // A fresh client on the same config already reports a negotiated ciphersuite
        // before any data is exchanged (presumably from the session resumed from the full
        // handshake above; TODO confirm).
        let mut client =
            ClientSession::new_quic(&client_config, dns_name("localhost"), client_params.into());
        assert!(
            client
                .get_negotiated_ciphersuite()
                .is_some()
        );
        let mut server = ServerSession::new_quic(&server_config, server_params.into());
        // After only the first client flight, the client already reports the server's
        // transport parameters and both sides hold matching 0-RTT keys.
        step(&mut client, &mut server).unwrap();
        assert_eq!(client.get_quic_transport_parameters(), Some(server_params));
        {
            let client_early = client.get_0rtt_keys().unwrap();
            let server_early = server.get_0rtt_keys().unwrap();
            assert!(equal_dir_keys(&client_early, &server_early));
        }
        // Each remaining flight must yield new keys; afterwards early data must be accepted.
        step(&mut server, &mut client)
            .unwrap()
            .unwrap();
        step(&mut client, &mut server)
            .unwrap()
            .unwrap();
        step(&mut server, &mut client)
            .unwrap()
            .unwrap();
        assert!(client.is_early_data_accepted());

        // 0-RTT rejection
        // A clone of the client config that additionally offers ALPN "foo". The client
        // still derives 0-RTT keys but the server does not, and early data must end up not
        // accepted (presumably because the offered ALPN differs from the resumed session's;
        // TODO confirm).
        {
            let mut client_config = (*client_config).clone();
            client_config.alpn_protocols = vec!["foo".into()];
            let mut client = ClientSession::new_quic(
                &Arc::new(client_config),
                dns_name("localhost"),
                client_params.into(),
            );
            let mut server = ServerSession::new_quic(&server_config, server_params.into());
            step(&mut client, &mut server).unwrap();
            assert_eq!(client.get_quic_transport_parameters(), Some(server_params));
            assert!(client.get_0rtt_keys().is_some());
            assert!(server.get_0rtt_keys().is_none());
            step(&mut server, &mut client)
                .unwrap()
                .unwrap();
            step(&mut client, &mut server)
                .unwrap()
                .unwrap();
            step(&mut server, &mut client)
                .unwrap()
                .unwrap();
            assert!(!client.is_early_data_accepted());
        }

        // failed handshake
        // Connect to "example.com" rather than "localhost". The first client flight and the
        // first server reply succeed; the client must then fail to read the server's next
        // flight and record BadCertificate (presumably the test certificate is not valid
        // for example.com).
        let mut client = ClientSession::new_quic(
            &client_config,
            dns_name("example.com"),
            client_params.into(),
        );
        let mut server = ServerSession::new_quic(&server_config, server_params.into());
        step(&mut client, &mut server).unwrap();
        step(&mut server, &mut client)
            .unwrap()
            .unwrap();
        assert!(step(&mut server, &mut client).is_err());
        assert_eq!(
            client.get_alert(),
            Some(rustls::internal::msgs::enums::AlertDescription::BadCertificate)
        );
    }
2633 
2634     #[test]
test_quic_rejects_missing_alpn()2635     fn test_quic_rejects_missing_alpn() {
2636         let client_params = &b"client params"[..];
2637         let server_params = &b"server params"[..];
2638 
2639         for &kt in ALL_KEY_TYPES.iter() {
2640             let mut client_config = make_client_config(kt);
2641             client_config.versions = vec![ProtocolVersion::TLSv1_3];
2642             client_config.alpn_protocols = vec!["bar".into()];
2643             let client_config = Arc::new(client_config);
2644 
2645             let mut server_config = make_server_config(kt);
2646             server_config.versions = vec![ProtocolVersion::TLSv1_3];
2647             server_config.alpn_protocols = vec!["foo".into()];
2648             let server_config = Arc::new(server_config);
2649 
2650             let mut client = ClientSession::new_quic(
2651                 &client_config,
2652                 dns_name("localhost"),
2653                 client_params.into(),
2654             );
2655             let mut server = ServerSession::new_quic(&server_config, server_params.into());
2656 
2657             assert_eq!(
2658                 step(&mut client, &mut server)
2659                     .err()
2660                     .unwrap(),
2661                 TLSError::NoApplicationProtocol
2662             );
2663 
2664             assert_eq!(
2665                 server.get_alert(),
2666                 Some(rustls::internal::msgs::enums::AlertDescription::NoApplicationProtocol)
2667             );
2668         }
2669     }
2670 
2671     #[test]
test_quic_exporter()2672     fn test_quic_exporter() {
2673         for &kt in ALL_KEY_TYPES.iter() {
2674             let mut client_config = make_client_config(kt);
2675             client_config.versions = vec![ProtocolVersion::TLSv1_3];
2676             client_config.alpn_protocols = vec!["bar".into()];
2677 
2678             let mut server_config = make_server_config(kt);
2679             server_config.versions = vec![ProtocolVersion::TLSv1_3];
2680             server_config.alpn_protocols = vec!["foo".into()];
2681 
2682             do_exporter_test(client_config, server_config);
2683         }
2684     }
2685 } // mod test_quic
2686 
// For every key type and every client protocol version, parses the client's first
// flight as a ClientHello and checks that its signature_algorithms extension does not
// list RSA_PKCS1_SHA1.
#[test]
fn test_client_does_not_offer_sha1() {
    use rustls::internal::msgs::{
        codec::Codec, enums::HandshakeType, handshake::HandshakePayload, message::Message,
        message::MessagePayload,
    };

    for kt in ALL_KEY_TYPES.iter() {
        for client_config in AllClientVersions::new(make_client_config(*kt)) {
            let (mut client, _) = make_pair_for_configs(client_config, make_server_config(*kt));

            // Capture the client's first flight and decode it as a single handshake message.
            assert!(client.wants_write());
            let mut buf = [0u8; 262144];
            let sz = client
                .write_tls(&mut buf.as_mut())
                .unwrap();
            let mut msg = Message::read_bytes(&buf[..sz]).unwrap();
            assert!(msg.decode_payload());
            assert!(msg.is_handshake_type(HandshakeType::ClientHello));

            // The assertion above guarantees these other arms cannot be reached.
            let client_hello = match msg.payload {
                MessagePayload::Handshake(hs) => match hs.payload {
                    HandshakePayload::ClientHello(ch) => ch,
                    _ => unreachable!(),
                },
                _ => unreachable!(),
            };

            let sigalgs = client_hello
                .get_sigalgs_extension()
                .unwrap();
            assert!(
                !sigalgs.contains(&SignatureScheme::RSA_PKCS1_SHA1),
                "sha1 unexpectedly offered"
            );
        }
    }
}
2726 
// With an MTU of 64 configured, the client's first `write_tls` must go out as a single
// vectored write made of several pieces, none longer than 64 bytes.
#[test]
fn test_client_mtu_reduction() {
    // `io::Write` sink that records the length of each slice passed to every
    // `write_vectored` call. The plain `write` and `flush` paths are not expected to
    // be used and panic if they are.
    struct CollectWrites {
        writevs: Vec<Vec<usize>>,
    }

    impl io::Write for CollectWrites {
        fn write(&mut self, _: &[u8]) -> io::Result<usize> {
            panic!()
        }
        fn flush(&mut self) -> io::Result<()> {
            panic!()
        }
        fn write_vectored(&mut self, bufs: &[io::IoSlice<'_>]) -> io::Result<usize> {
            let lengths: Vec<usize> = bufs.iter().map(|piece| piece.len()).collect();
            let total: usize = lengths.iter().sum();
            self.writevs.push(lengths);
            Ok(total)
        }
    }

    // Runs one `write_tls` and returns the piece lengths of its single vectored write.
    fn collect_write_lengths(client: &mut ClientSession) -> Vec<usize> {
        let mut sink = CollectWrites { writevs: Vec::new() };
        client.write_tls(&mut sink).unwrap();
        assert_eq!(sink.writevs.len(), 1);
        sink.writevs.pop().unwrap()
    }

    for kt in ALL_KEY_TYPES.iter().copied() {
        let client_config = {
            let mut cfg = make_client_config(kt);
            cfg.set_mtu(&Some(64));
            Arc::new(cfg)
        };

        let mut client = ClientSession::new(&client_config, dns_name("localhost"));
        let writes = collect_write_lengths(&mut client);
        println!("writes at mtu=64: {:?}", writes);
        assert!(writes.iter().all(|&len| len <= 64));
        assert!(writes.len() > 1);
    }
}
2772 
// Smoke test: a client whose config logs keys through `KeyLogFile` (with
// SSLKEYLOGFILE pointing at ./sslkeylogfile.txt) completes a handshake and sends
// data for every client protocol version.
// NOTE(review): `env::set_var` mutates process-global state shared with other tests
// running in parallel.
#[test]
fn exercise_key_log_file_for_client() {
    let server_config = Arc::new(make_server_config(KeyType::RSA));
    let mut base_client_config = make_client_config(KeyType::RSA);
    env::set_var("SSLKEYLOGFILE", "./sslkeylogfile.txt");
    base_client_config.key_log = Arc::new(rustls::KeyLogFile::new());

    for versioned_config in AllClientVersions::new(base_client_config) {
        let client_config = Arc::new(versioned_config);
        let (mut client, mut server) = make_pair_for_arc_configs(&client_config, &server_config);

        let written = client.write(b"hello").unwrap();
        assert_eq!(5, written);

        do_handshake(&mut client, &mut server);
        transfer(&mut client, &mut server);
        server.process_new_packets().unwrap();
    }
}
2791 
// Smoke test: a server whose config logs keys through `KeyLogFile` (with
// SSLKEYLOGFILE pointing at ./sslkeylogfile.txt) completes a handshake and receives
// data from clients of every protocol version.
// NOTE(review): `env::set_var` mutates process-global state shared with other tests
// running in parallel.
#[test]
fn exercise_key_log_file_for_server() {
    let mut base_server_config = make_server_config(KeyType::RSA);

    env::set_var("SSLKEYLOGFILE", "./sslkeylogfile.txt");
    base_server_config.key_log = Arc::new(rustls::KeyLogFile::new());
    let server_config = Arc::new(base_server_config);

    let base_client_config = make_client_config(KeyType::RSA);
    for versioned_config in AllClientVersions::new(base_client_config) {
        let client_config = Arc::new(versioned_config);
        let (mut client, mut server) = make_pair_for_arc_configs(&client_config, &server_config);

        let written = client.write(b"hello").unwrap();
        assert_eq!(5, written);

        do_handshake(&mut client, &mut server);
        transfer(&mut client, &mut server);
        server.process_new_packets().unwrap();
    }
}
2812 
// Panics with "expected {left} < {right}" unless `left` is strictly less than `right`.
fn assert_lt(left: usize, right: usize) {
    assert!(left < right, "expected {} < {}", left, right);
}
2818 
// Guards against the session structs growing too large, as the test name says.
// The limit itself is an arbitrary ceiling.
#[test]
fn session_types_are_not_huge() {
    const MAX_SESSION_SIZE: usize = 1600;
    assert_lt(mem::size_of::<ServerSession>(), MAX_SESSION_SIZE);
    assert_lt(mem::size_of::<ClientSession>(), MAX_SESSION_SIZE);
}
2825 
2826 use rustls::internal::msgs::{
2827     handshake::ClientExtension, handshake::HandshakePayload, message::Message,
2828     message::MessagePayload,
2829 };
2830 
// The server must reject a ClientHello whose SNI extension lists the same name type
// more than once.
#[test]
fn test_server_rejects_duplicate_sni_names() {
    // Rewrites a ClientHello so its SNI list repeats its first entry. An empty SNI list
    // is left untouched, so the server's response (not an index panic here) decides
    // the outcome.
    fn duplicate_sni_payload(msg: &mut Message) {
        if let MessagePayload::Handshake(hs) = &mut msg.payload {
            if let HandshakePayload::ClientHello(ch) = &mut hs.payload {
                for ext in ch.extensions.iter_mut() {
                    if let ClientExtension::ServerName(snr) = ext {
                        if let Some(first) = snr.first().cloned() {
                            snr.push(first);
                        }
                    }
                }
            }
        }
    }

    let (mut client, mut server) = make_pair(KeyType::RSA);
    transfer_altered(&mut client, duplicate_sni_payload, &mut server);
    assert_eq!(
        server.process_new_packets(),
        Err(TLSError::PeerMisbehavedError(
            "ClientHello SNI contains duplicate name types".into()
        ))
    );
}
2854 
// The server must reject a ClientHello that carries an SNI extension with no names.
#[test]
fn test_server_rejects_empty_sni_extension() {
    // Rewrites a ClientHello so its SNI extension is present but empty.
    fn empty_sni_payload(msg: &mut Message) {
        if let MessagePayload::Handshake(hs) = &mut msg.payload {
            if let HandshakePayload::ClientHello(ch) = &mut hs.payload {
                for ext in ch.extensions.iter_mut() {
                    if let ClientExtension::ServerName(snr) = ext {
                        snr.clear();
                    }
                }
            }
        }
    }

    let (mut client, mut server) = make_pair(KeyType::RSA);
    transfer_altered(&mut client, empty_sni_payload, &mut server);
    assert_eq!(
        server.process_new_packets(),
        Err(TLSError::PeerMisbehavedError(
            "ClientHello SNI did not contain a hostname".into()
        ))
    );
}
2878 
// The server must refuse a ClientHello that offers no key-exchange group at all, and
// report it as an incompatibility.
#[test]
fn test_server_rejects_clients_without_any_kx_group_overlap() {
    // Rewrites a ClientHello so both its supported-groups list and its key shares are
    // empty.
    fn different_kx_group(msg: &mut Message) {
        if let MessagePayload::Handshake(hs) = &mut msg.payload {
            if let HandshakePayload::ClientHello(ch) = &mut hs.payload {
                for ext in ch.extensions.iter_mut() {
                    match ext {
                        ClientExtension::NamedGroups(ngs) => ngs.clear(),
                        ClientExtension::KeyShare(ks) => ks.clear(),
                        _ => {}
                    }
                }
            }
        }
    }

    let (mut client, mut server) = make_pair(KeyType::RSA);
    transfer_altered(&mut client, different_kx_group, &mut server);
    assert_eq!(
        server.process_new_packets(),
        Err(TLSError::PeerIncompatibleError(
            "no kx group overlap with client".into()
        ))
    );
}
2905 
// Mainly checks, per its name, that converting a stored root into a
// `webpki::TrustAnchor` is callable from outside the crate. At runtime it also requires
// the test root store to be non-empty.
#[test]
fn test_ownedtrustanchor_to_trust_anchor_is_public() {
    let config = make_client_config(KeyType::RSA);
    let first_root = &config.root_store.roots[0];
    let _anchor: webpki::TrustAnchor = first_root.to_trust_anchor();
}
2911