1 // Assorted public API tests.
2 use std::sync::Arc;
3 use std::sync::Mutex;
4 use std::sync::atomic::{AtomicUsize, Ordering};
5 use std::mem;
6 use std::fmt;
7 use std::env;
8 use std::error::Error;
9 use std::io::{self, Write, Read};
10 
11 use rustls;
12 
13 use rustls::{ClientConfig, ClientSession, ResolvesClientCert};
14 use rustls::{ServerConfig, ServerSession, ResolvesServerCert};
15 use rustls::Session;
16 use rustls::{Stream, StreamOwned};
17 use rustls::{ProtocolVersion, SignatureScheme, CipherSuite};
18 use rustls::TLSError;
19 use rustls::sign;
20 use rustls::{ALL_CIPHERSUITES, SupportedCipherSuite};
21 use rustls::KeyLog;
22 use rustls::ClientHello;
23 #[cfg(feature = "quic")]
24 use rustls::quic::{self, QuicExt, ClientQuicExt, ServerQuicExt};
25 #[cfg(feature = "quic")]
26 use ring::hkdf;
27 
28 #[cfg(feature = "dangerous_configuration")]
29 use rustls::ClientCertVerified;
30 
31 use webpki;
32 
33 #[allow(dead_code)]
34 mod common;
35 use crate::common::*;
36 
alpn_test(server_protos: Vec<Vec<u8>>, client_protos: Vec<Vec<u8>>, agreed: Option<&[u8]>)37 fn alpn_test(server_protos: Vec<Vec<u8>>, client_protos: Vec<Vec<u8>>, agreed: Option<&[u8]>) {
38     let mut client_config = make_client_config(KeyType::RSA);
39     let mut server_config = make_server_config(KeyType::RSA);
40 
41     client_config.alpn_protocols = client_protos;
42     server_config.alpn_protocols = server_protos;
43 
44     let server_config = Arc::new(server_config);
45 
46     for client_config in AllClientVersions::new(client_config) {
47         let (mut client, mut server) = make_pair_for_arc_configs(&Arc::new(client_config),
48                                                                  &server_config);
49 
50         assert_eq!(client.get_alpn_protocol(), None);
51         assert_eq!(server.get_alpn_protocol(), None);
52         do_handshake(&mut client, &mut server);
53         assert_eq!(client.get_alpn_protocol(), agreed);
54         assert_eq!(server.get_alpn_protocol(), agreed);
55     }
56 }
57 
#[test]
fn alpn() {
    /// Shorthand for an owned ALPN protocol identifier.
    fn p(name: &[u8]) -> Vec<u8> {
        name.to_vec()
    }

    // nobody offers ALPN
    alpn_test(vec![], vec![], None);

    // only the server is configured
    alpn_test(vec![p(b"server-proto")], vec![], None);

    // only the client is configured
    alpn_test(vec![], vec![p(b"client-proto")], None);

    // both configured, nothing in common
    alpn_test(vec![p(b"server-proto")], vec![p(b"client-proto")], None);

    // both protocols shared: the server's preference order wins
    alpn_test(vec![p(b"server-proto"), p(b"client-proto")],
              vec![p(b"client-proto"), p(b"server-proto")],
              Some(b"server-proto"));

    // matching is case sensitive
    alpn_test(vec![p(b"PROTO")], vec![p(b"proto")], None);
}
82 
version_test(client_versions: Vec<ProtocolVersion>, server_versions: Vec<ProtocolVersion>, result: Option<ProtocolVersion>)83 fn version_test(client_versions: Vec<ProtocolVersion>,
84                 server_versions: Vec<ProtocolVersion>,
85                 result: Option<ProtocolVersion>) {
86     let mut client_config = make_client_config(KeyType::RSA);
87     let mut server_config = make_server_config(KeyType::RSA);
88 
89     println!("version {:?} {:?} -> {:?}",
90              client_versions,
91              server_versions,
92              result);
93 
94     if !client_versions.is_empty() {
95         client_config.versions = client_versions;
96     }
97 
98     if !server_versions.is_empty() {
99         server_config.versions = server_versions;
100     }
101 
102     let (mut client, mut server) = make_pair_for_configs(client_config, server_config);
103 
104     assert_eq!(client.get_protocol_version(), None);
105     assert_eq!(server.get_protocol_version(), None);
106     if result.is_none() {
107         let err = do_handshake_until_error(&mut client, &mut server);
108         assert_eq!(err.is_err(), true);
109     } else {
110         do_handshake(&mut client, &mut server);
111         assert_eq!(client.get_protocol_version(), result);
112         assert_eq!(server.get_protocol_version(), result);
113     }
114 }
115 
#[test]
fn versions() {
    // (client versions, server versions, expected outcome).
    // An empty list means "that side's default"; `None` means the handshake
    // must fail.
    let cases = vec![
        // default -> 1.3
        (vec![], vec![], Some(ProtocolVersion::TLSv1_3)),
        // client default, server 1.2 -> 1.2
        (vec![], vec![ProtocolVersion::TLSv1_2], Some(ProtocolVersion::TLSv1_2)),
        // client 1.2, server default -> 1.2
        (vec![ProtocolVersion::TLSv1_2], vec![], Some(ProtocolVersion::TLSv1_2)),
        // client 1.2, server 1.3 -> fail
        (vec![ProtocolVersion::TLSv1_2], vec![ProtocolVersion::TLSv1_3], None),
        // client 1.3, server 1.2 -> fail
        (vec![ProtocolVersion::TLSv1_3], vec![ProtocolVersion::TLSv1_2], None),
        // client 1.3, server 1.2+1.3 -> 1.3
        (vec![ProtocolVersion::TLSv1_3],
         vec![ProtocolVersion::TLSv1_2, ProtocolVersion::TLSv1_3],
         Some(ProtocolVersion::TLSv1_3)),
        // client 1.3+1.2, server 1.2 -> 1.2
        (vec![ProtocolVersion::TLSv1_3, ProtocolVersion::TLSv1_2],
         vec![ProtocolVersion::TLSv1_2],
         Some(ProtocolVersion::TLSv1_2)),
    ];

    for (client_versions, server_versions, expected) in cases {
        version_test(client_versions, server_versions, expected);
    }
}
151 
/// Drains `reader` to end-of-stream and asserts that exactly `bytes` were
/// produced.
fn check_read(reader: &mut dyn io::Read, bytes: &[u8]) {
    let mut received: Vec<u8> = Vec::new();
    let count = reader.read_to_end(&mut received).unwrap();
    assert_eq!(bytes.len(), count);
    assert_eq!(bytes.to_vec(), received);
}
157 
#[test]
fn buffered_client_data_sent() {
    let server_config = Arc::new(make_server_config(KeyType::RSA));
    let client_versions = AllClientVersions::new(make_client_config(KeyType::RSA));

    for versioned in client_versions {
        let client_config = Arc::new(versioned);
        let (mut client, mut server) = make_pair_for_arc_configs(&client_config, &server_config);

        // Written before the handshake: must be buffered and delivered later.
        let written = client.write(b"hello").unwrap();
        assert_eq!(5, written);

        do_handshake(&mut client, &mut server);
        transfer(&mut client, &mut server);
        server.process_new_packets().unwrap();

        check_read(&mut server, b"hello");
    }
}
175 
#[test]
fn buffered_server_data_sent() {
    let server_config = Arc::new(make_server_config(KeyType::RSA));
    let client_versions = AllClientVersions::new(make_client_config(KeyType::RSA));

    for versioned in client_versions {
        let client_config = Arc::new(versioned);
        let (mut client, mut server) = make_pair_for_arc_configs(&client_config, &server_config);

        // Server plaintext queued before any handshake traffic has flowed.
        let written = server.write(b"hello").unwrap();
        assert_eq!(5, written);

        do_handshake(&mut client, &mut server);
        transfer(&mut server, &mut client);
        client.process_new_packets().unwrap();

        check_read(&mut client, b"hello");
    }
}
193 
#[test]
fn buffered_both_data_sent() {
    let server_config = Arc::new(make_server_config(KeyType::RSA));
    let client_versions = AllClientVersions::new(make_client_config(KeyType::RSA));

    for versioned in client_versions {
        let client_config = Arc::new(versioned);
        let (mut client, mut server) = make_pair_for_arc_configs(&client_config, &server_config);

        // Both peers queue plaintext before the handshake.
        let from_server = server.write(b"from-server!").unwrap();
        assert_eq!(12, from_server);
        let from_client = client.write(b"from-client!").unwrap();
        assert_eq!(12, from_client);

        do_handshake(&mut client, &mut server);

        // server -> client
        transfer(&mut server, &mut client);
        client.process_new_packets().unwrap();
        // client -> server
        transfer(&mut client, &mut server);
        server.process_new_packets().unwrap();

        check_read(&mut client, b"from-server!");
        check_read(&mut server, b"from-client!");
    }
}
216 
#[test]
fn client_can_get_server_cert() {
    for &kt in ALL_KEY_TYPES.iter() {
        let client_versions = AllClientVersions::new(make_client_config(kt));

        for client_config in client_versions {
            let server_config = make_server_config(kt);
            let (mut client, mut server) = make_pair_for_configs(client_config, server_config);
            do_handshake(&mut client, &mut server);

            // The client must expose exactly the chain the server presented.
            let presented = client.get_peer_certificates();
            assert_eq!(presented, Some(kt.get_chain()));
        }
    }
}
230 
#[test]
fn server_can_get_client_cert() {
    for &kt in ALL_KEY_TYPES.iter() {
        let mut base_client_config = make_client_config(kt);
        base_client_config
            .set_single_client_cert(kt.get_chain(), kt.get_key())
            .unwrap();

        let server_config = Arc::new(make_server_config_with_mandatory_client_auth(kt));

        for versioned in AllClientVersions::new(base_client_config) {
            let client_config = Arc::new(versioned);
            let (mut client, mut server) =
                make_pair_for_arc_configs(&client_config, &server_config);
            do_handshake(&mut client, &mut server);

            // The server must expose exactly the chain the client presented.
            let presented = server.get_peer_certificates();
            assert_eq!(presented, Some(kt.get_chain()));
        }
    }
}
250 
/// Reads `expect.len()` bytes from `reader` with a single `read` call, asserts
/// they equal `expect`, then asserts that the next read fails with
/// `io::ErrorKind::ConnectionAborted`.
///
/// The callers in this file use it on a session whose peer has sent
/// close_notify.
fn check_read_and_close(reader: &mut dyn io::Read, expect: &[u8]) {
    let mut buf = vec![0u8; expect.len()];
    assert_eq!(expect.len(), reader.read(&mut buf).unwrap());
    assert_eq!(expect, &buf[..]);

    // Probe with a non-empty buffer: `Read::read` may return Ok(0) for a
    // zero-length buffer without reporting anything, which would let the
    // closure check pass vacuously (or fail spuriously) when `expect` is empty.
    let mut probe = [0u8; 1];
    let err = reader
        .read(&mut probe)
        .expect_err("read after close unexpectedly succeeded");
    assert_eq!(err.kind(), io::ErrorKind::ConnectionAborted);
}
261 
#[test]
fn server_close_notify() {
    let kt = KeyType::RSA;

    let mut base_client_config = make_client_config(kt);
    base_client_config
        .set_single_client_cert(kt.get_chain(), kt.get_key())
        .unwrap();

    let server_config = Arc::new(make_server_config_with_mandatory_client_auth(kt));

    for versioned in AllClientVersions::new(base_client_config) {
        let client_config = Arc::new(versioned);
        let (mut client, mut server) = make_pair_for_arc_configs(&client_config, &server_config);
        do_handshake(&mut client, &mut server);

        // Application data queued before the alert must not be overtaken by it.
        assert_eq!(12, server.write(b"from-server!").unwrap());
        assert_eq!(12, client.write(b"from-client!").unwrap());
        server.send_close_notify();

        // Client receives the server's data, then the closure.
        transfer(&mut server, &mut client);
        client.process_new_packets().unwrap();
        check_read_and_close(&mut client, b"from-server!");

        // The client's own data still reaches the server.
        transfer(&mut client, &mut server);
        server.process_new_packets().unwrap();
        check_read(&mut server, b"from-client!");
    }
}
290 
#[test]
fn client_close_notify() {
    let kt = KeyType::RSA;

    let mut base_client_config = make_client_config(kt);
    base_client_config
        .set_single_client_cert(kt.get_chain(), kt.get_key())
        .unwrap();

    let server_config = Arc::new(make_server_config_with_mandatory_client_auth(kt));

    for versioned in AllClientVersions::new(base_client_config) {
        let client_config = Arc::new(versioned);
        let (mut client, mut server) = make_pair_for_arc_configs(&client_config, &server_config);
        do_handshake(&mut client, &mut server);

        // Application data queued before the alert must not be overtaken by it.
        assert_eq!(12, server.write(b"from-server!").unwrap());
        assert_eq!(12, client.write(b"from-client!").unwrap());
        client.send_close_notify();

        // Server receives the client's data, then the closure.
        transfer(&mut client, &mut server);
        server.process_new_packets().unwrap();
        check_read_and_close(&mut server, b"from-client!");

        // The server's own data still reaches the client.
        transfer(&mut server, &mut client);
        client.process_new_packets().unwrap();
        check_read(&mut client, b"from-server!");
    }
}
319 
320 #[derive(Default)]
321 struct ServerCheckCertResolve {
322     expected_sni: Option<String>,
323     expected_sigalgs: Option<Vec<SignatureScheme>>,
324     expected_alpn: Option<Vec<Vec<u8>>>,
325 }
326 
327 impl ResolvesServerCert for ServerCheckCertResolve {
resolve(&self, client_hello: ClientHello) -> Option<sign::CertifiedKey>328     fn resolve(&self, client_hello: ClientHello) -> Option<sign::CertifiedKey> {
329         if client_hello.sigschemes().len() == 0 {
330             panic!("no signature schemes shared by client");
331         }
332 
333         if let Some(expected_sni) = &self.expected_sni {
334             let sni: &str = client_hello.server_name().expect("sni unexpectedly absent").into();
335             assert_eq!(expected_sni, sni);
336         }
337 
338         if let Some(expected_sigalgs) = &self.expected_sigalgs {
339             if expected_sigalgs != &client_hello.sigschemes() {
340                 panic!("unexpected signature schemes (wanted {:?} got {:?})",
341                        self.expected_sigalgs, client_hello.sigschemes());
342             }
343         }
344 
345         if let Some(expected_alpn) = &self.expected_alpn {
346             let alpn = client_hello.alpn().expect("alpn unexpectedly absent");
347             assert_eq!(alpn.len(), expected_alpn.len());
348 
349             for (got, wanted) in alpn.iter().zip(expected_alpn.iter()) {
350                 assert_eq!(got, &wanted.as_slice());
351             }
352         }
353 
354         None
355     }
356 }
357 
/// The server's certificate resolver must see the SNI value the client was
/// constructed with. The resolver returns no certificate, so the handshake
/// itself is expected to fail once the check has passed.
#[test]
fn server_cert_resolve_with_sni() {
    for kt in ALL_KEY_TYPES.iter() {
        let client_config = make_client_config(*kt);
        let mut server_config = make_server_config(*kt);

        server_config.cert_resolver = Arc::new(ServerCheckCertResolve {
            expected_sni: Some("the-value-from-sni".into()),
            ..Default::default()
        });

        let mut client = ClientSession::new(&Arc::new(client_config), dns_name("the-value-from-sni"));
        let mut server = ServerSession::new(&Arc::new(server_config));

        let err = do_handshake_until_error(&mut client, &mut server);
        assert!(err.is_err());
    }
}
376 
/// The server's certificate resolver must see the client's ALPN offers in the
/// client's order. The resolver returns no certificate, so the handshake is
/// expected to fail once the check has passed.
#[test]
fn server_cert_resolve_with_alpn() {
    for kt in ALL_KEY_TYPES.iter() {
        let mut client_config = make_client_config(*kt);
        client_config.alpn_protocols = vec![b"foo".to_vec(), b"bar".to_vec()];

        let mut server_config = make_server_config(*kt);
        server_config.cert_resolver = Arc::new(ServerCheckCertResolve {
            expected_alpn: Some(vec![b"foo".to_vec(), b"bar".to_vec()]),
            ..Default::default()
        });

        let mut client = ClientSession::new(&Arc::new(client_config), dns_name("sni-value"));
        let mut server = ServerSession::new(&Arc::new(server_config));

        let err = do_handshake_until_error(&mut client, &mut server);
        assert!(err.is_err());
    }
}
396 
397 
check_sigalgs_reduced_by_ciphersuite(kt: KeyType, suite: CipherSuite, expected_sigalgs: Vec<SignatureScheme>)398 fn check_sigalgs_reduced_by_ciphersuite(kt: KeyType, suite: CipherSuite,
399                                         expected_sigalgs: Vec<SignatureScheme>) {
400     let mut client_config = make_client_config(kt);
401     client_config.ciphersuites = vec![ find_suite(suite) ];
402 
403     let mut server_config = make_server_config(kt);
404 
405     server_config.cert_resolver = Arc::new(ServerCheckCertResolve {
406         expected_sigalgs: Some(expected_sigalgs),
407         ..Default::default()
408     });
409 
410     let mut client = ClientSession::new(&Arc::new(client_config), dns_name("localhost"));
411     let mut server = ServerSession::new(&Arc::new(server_config));
412 
413     let err = do_handshake_until_error(&mut client, &mut server);
414     assert_eq!(err.is_err(), true);
415 }
416 
#[test]
fn server_cert_resolve_reduces_sigalgs_for_rsa_ciphersuite() {
    // With only an ECDHE_RSA suite enabled, the client should offer exactly
    // these RSA schemes, in this order.
    let expected = vec![
        SignatureScheme::RSA_PSS_SHA512,
        SignatureScheme::RSA_PSS_SHA384,
        SignatureScheme::RSA_PSS_SHA256,
        SignatureScheme::RSA_PKCS1_SHA512,
        SignatureScheme::RSA_PKCS1_SHA384,
        SignatureScheme::RSA_PKCS1_SHA256,
    ];
    let suite = CipherSuite::TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256;

    check_sigalgs_reduced_by_ciphersuite(KeyType::RSA, suite, expected);
}
432 
#[test]
fn server_cert_resolve_reduces_sigalgs_for_ecdsa_ciphersuite() {
    // With only an ECDHE_ECDSA suite enabled, the client should offer exactly
    // these ECDSA schemes, in this order.
    let expected = vec![
        SignatureScheme::ECDSA_NISTP384_SHA384,
        SignatureScheme::ECDSA_NISTP256_SHA256,
    ];
    let suite = CipherSuite::TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256;

    check_sigalgs_reduced_by_ciphersuite(KeyType::ECDSA, suite, expected);
}
444 
445 struct ServerCheckNoSNI {}
446 
447 impl ResolvesServerCert for ServerCheckNoSNI {
resolve(&self, client_hello: ClientHello) -> Option<sign::CertifiedKey>448     fn resolve(&self, client_hello: ClientHello) -> Option<sign::CertifiedKey> {
449         assert!(client_hello.server_name().is_none());
450 
451         None
452     }
453 }
454 
/// With `enable_sni = false` the client must not put a server name in its
/// ClientHello, for every protocol version. The server-side check is done by
/// `ServerCheckNoSNI`, which supplies no certificate, so each handshake is
/// expected to fail afterwards.
#[test]
fn client_with_sni_disabled_does_not_send_sni() {
    for kt in ALL_KEY_TYPES.iter() {
        let mut client_config = make_client_config(*kt);
        client_config.enable_sni = false;

        let mut server_config = make_server_config(*kt);
        server_config.cert_resolver = Arc::new(ServerCheckNoSNI {});
        let server_config = Arc::new(server_config);

        for client_config in AllClientVersions::new(client_config) {
            // This name must not appear in the ClientHello.
            let mut client = ClientSession::new(&Arc::new(client_config), dns_name("value-not-sent"));
            let mut server = ServerSession::new(&server_config);

            let err = do_handshake_until_error(&mut client, &mut server);
            assert!(err.is_err());
        }
    }
}
474 
#[test]
fn client_checks_server_certificate_with_given_name() {
    for &kt in ALL_KEY_TYPES.iter() {
        let base_client_config = make_client_config(kt);
        let server_config = Arc::new(make_server_config(kt));

        for versioned in AllClientVersions::new(base_client_config) {
            // Connect under a hostname the server certificate is expected not
            // to cover; the client must reject it.
            let mut client = ClientSession::new(&Arc::new(versioned),
                                                dns_name("not-the-right-hostname.com"));
            let mut server = ServerSession::new(&server_config);

            let outcome = do_handshake_until_error(&mut client, &mut server);
            let expected = Err(TLSErrorFromPeer::Client(
                TLSError::WebPKIError(webpki::Error::CertNotValidForName)));
            assert_eq!(outcome, expected);
        }
    }
}
495 
496 struct ClientCheckCertResolve {
497     query_count: AtomicUsize,
498     expect_queries: usize
499 }
500 
501 impl ClientCheckCertResolve {
new(expect_queries: usize) -> ClientCheckCertResolve502     fn new(expect_queries: usize) -> ClientCheckCertResolve {
503         ClientCheckCertResolve {
504             query_count: AtomicUsize::new(0),
505             expect_queries: expect_queries
506         }
507     }
508 }
509 
510 impl Drop for ClientCheckCertResolve {
drop(&mut self)511     fn drop(&mut self) {
512         let count = self.query_count.load(Ordering::SeqCst);
513         assert_eq!(count, self.expect_queries);
514     }
515 }
516 
517 impl ResolvesClientCert for ClientCheckCertResolve {
resolve(&self, acceptable_issuers: &[&[u8]], sigschemes: &[SignatureScheme]) -> Option<sign::CertifiedKey>518     fn resolve(&self,
519                acceptable_issuers: &[&[u8]],
520                sigschemes: &[SignatureScheme])
521         -> Option<sign::CertifiedKey> {
522         self.query_count.fetch_add(1, Ordering::SeqCst);
523 
524         if acceptable_issuers.len() == 0 {
525             panic!("no issuers offered by server");
526         }
527 
528         if sigschemes.len() == 0 {
529             panic!("no signature schemes shared by server");
530         }
531 
532         None
533     }
534 
has_certs(&self) -> bool535     fn has_certs(&self) -> bool {
536         true
537     }
538 }
539 
#[test]
fn client_cert_resolve() {
    for &kt in ALL_KEY_TYPES.iter() {
        // The resolver is shared by every config AllClientVersions yields from
        // this one, and is expected to be queried twice in total (presumably
        // once per yielded version — verify against AllClientVersions).
        let mut base_client_config = make_client_config(kt);
        base_client_config.client_auth_cert_resolver = Arc::new(ClientCheckCertResolve::new(2));

        let server_config = Arc::new(make_server_config_with_mandatory_client_auth(kt));

        for versioned in AllClientVersions::new(base_client_config) {
            let (mut client, mut server) = make_pair_for_arc_configs(&Arc::new(versioned),
                                                                     &server_config);

            let outcome = do_handshake_until_error(&mut client, &mut server);
            assert_eq!(outcome,
                       Err(TLSErrorFromPeer::Server(TLSError::NoCertificatesPresented)));
        }
    }
}
558 
#[test]
fn client_auth_works() {
    for &kt in ALL_KEY_TYPES.iter() {
        let base_client_config = make_client_config_with_auth(kt);
        let server_config = Arc::new(make_server_config_with_mandatory_client_auth(kt));

        for versioned in AllClientVersions::new(base_client_config) {
            let client_config = Arc::new(versioned);
            let (mut client, mut server) = make_pair_for_arc_configs(&client_config, &server_config);
            // A client with a valid certificate completes mandatory client auth.
            do_handshake(&mut client, &mut server);
        }
    }
}
572 
#[cfg(feature = "dangerous_configuration")]
mod test_verifier {
    use super::*;
    use crate::common::MockClientVerifier;
    use rustls::internal::msgs::enums::AlertDescription;

    // Client is authorized!
    fn ver_ok() -> Result<ClientCertVerified, TLSError> {
        Ok(rustls::ClientCertVerified::assertion())
    }

    // Use when we shouldn't even attempt verification
    fn ver_unreachable() -> Result<ClientCertVerified, TLSError> {
        unreachable!()
    }

    // Verifier that returns an error that we can expect
    fn ver_err() -> Result<ClientCertVerified, TLSError> {
        Err(TLSError::General("test err".to_string()))
    }

    /// Builds a shared server config for key type `kt` whose client-certificate
    /// decisions are delegated to `client_verifier`.
    fn server_config_with_verifier(kt: KeyType,
                                   client_verifier: MockClientVerifier) -> Arc<ServerConfig> {
        let mut server_config = ServerConfig::new(Arc::new(client_verifier));
        server_config.set_single_cert(kt.get_chain(), kt.get_key()).unwrap();
        Arc::new(server_config)
    }

    #[test]
    // Happy path, we resolve to a root, it is verified OK, should be able to connect
    fn client_verifier_works() {
        for kt in ALL_KEY_TYPES.iter() {
            let client_verifier = MockClientVerifier {
                verified: ver_ok,
                subjects: Some(get_client_root_store(*kt).get_subjects()),
                mandatory: Some(true),
            };

            let server_config = server_config_with_verifier(*kt, client_verifier);
            let client_config = make_client_config_with_auth(*kt);

            for client_config in AllClientVersions::new(client_config) {
                let (mut client, mut server) = make_pair_for_arc_configs(&Arc::new(client_config),
                                                                         &server_config);
                let err = do_handshake_until_error(&mut client, &mut server);
                assert_eq!(err, Ok(()));
            }
        }
    }

    // Common case, we do not find a root store to resolve to
    #[test]
    fn client_verifier_no_root() {
        for kt in ALL_KEY_TYPES.iter() {
            let client_verifier = MockClientVerifier {
                verified: ver_ok,
                subjects: None,
                mandatory: Some(true),
            };

            let server_config = server_config_with_verifier(*kt, client_verifier);
            let client_config = make_client_config_with_auth(*kt);

            for client_config in AllClientVersions::new(client_config) {
                let mut server = ServerSession::new(&server_config);
                let mut client = ClientSession::new(&Arc::new(client_config), dns_name("notlocalhost"));
                let errs = do_handshake_until_both_error(&mut client, &mut server);
                assert_eq!(errs,
                           Err(vec![
                               TLSErrorFromPeer::Server(TLSError::General("client rejected by client_auth_root_subjects".into())),
                               TLSErrorFromPeer::Client(TLSError::AlertReceived(AlertDescription::AccessDenied))
                           ]));
            }
        }
    }

    // If we cannot resolve a root, we cannot decide if auth is mandatory
    #[test]
    fn client_verifier_no_auth_no_root() {
        for kt in ALL_KEY_TYPES.iter() {
            let client_verifier = MockClientVerifier {
                verified: ver_unreachable,
                subjects: None,
                mandatory: Some(true),
            };

            let server_config = server_config_with_verifier(*kt, client_verifier);
            let client_config = make_client_config(*kt);

            for client_config in AllClientVersions::new(client_config) {
                let mut server = ServerSession::new(&server_config);
                let mut client = ClientSession::new(&Arc::new(client_config), dns_name("notlocalhost"));
                let errs = do_handshake_until_both_error(&mut client, &mut server);
                assert_eq!(errs,
                           Err(vec![
                               TLSErrorFromPeer::Server(TLSError::General("client rejected by client_auth_root_subjects".into())),
                               TLSErrorFromPeer::Client(TLSError::AlertReceived(AlertDescription::AccessDenied))
                           ]));
            }
        }
    }

    // If we do have a root, we must do auth
    #[test]
    fn client_verifier_no_auth_yes_root() {
        for kt in ALL_KEY_TYPES.iter() {
            let client_verifier = MockClientVerifier {
                verified: ver_unreachable,
                subjects: Some(get_client_root_store(*kt).get_subjects()),
                mandatory: Some(true),
            };

            let server_config = server_config_with_verifier(*kt, client_verifier);
            let client_config = make_client_config(*kt);

            for client_config in AllClientVersions::new(client_config) {
                // Captured output; shown only if this iteration fails.
                println!("Failing: {:?}", client_config.versions);
                let mut server = ServerSession::new(&server_config);
                let mut client = ClientSession::new(&Arc::new(client_config), dns_name("localhost"));
                let errs = do_handshake_until_both_error(&mut client, &mut server);
                assert_eq!(errs,
                           Err(vec![
                               TLSErrorFromPeer::Server(TLSError::NoCertificatesPresented),
                               TLSErrorFromPeer::Client(TLSError::AlertReceived(AlertDescription::CertificateRequired))
                           ]));
            }
        }
    }

    #[test]
    // Triple checks we propagate the TLSError through
    fn client_verifier_fails_properly() {
        for kt in ALL_KEY_TYPES.iter() {
            let client_verifier = MockClientVerifier {
                verified: ver_err,
                subjects: Some(get_client_root_store(*kt).get_subjects()),
                mandatory: Some(true),
            };

            let server_config = server_config_with_verifier(*kt, client_verifier);
            let client_config = make_client_config_with_auth(*kt);

            for client_config in AllClientVersions::new(client_config) {
                let mut server = ServerSession::new(&server_config);
                let mut client = ClientSession::new(&Arc::new(client_config), dns_name("localhost"));
                let err = do_handshake_until_error(&mut client, &mut server);
                assert_eq!(err, Err(TLSErrorFromPeer::Server(
                            TLSError::General("test err".into()))));
            }
        }
    }

    #[test]
    // If a verifier returns a None on Mandatory-ness, then we error out
    fn client_verifier_must_determine_client_auth_requirement_to_continue() {
        for kt in ALL_KEY_TYPES.iter() {
            let client_verifier = MockClientVerifier {
                verified: ver_ok,
                subjects: Some(get_client_root_store(*kt).get_subjects()),
                mandatory: None,
            };

            let server_config = server_config_with_verifier(*kt, client_verifier);
            let client_config = make_client_config_with_auth(*kt);

            for client_config in AllClientVersions::new(client_config) {
                let mut server = ServerSession::new(&server_config);
                let mut client = ClientSession::new(&Arc::new(client_config), dns_name("localhost"));
                let errs = do_handshake_until_both_error(&mut client, &mut server);
                assert_eq!(errs,
                           Err(vec![
                               TLSErrorFromPeer::Server(TLSError::General("client rejected by client_auth_mandatory".into())),
                               TLSErrorFromPeer::Client(TLSError::AlertReceived(AlertDescription::AccessDenied))
                           ]));
            }
        }
    }
} // mod test_verifier
763 
#[test]
fn client_error_is_sticky() {
    // Feed the client a handshake record it cannot accept; once processing has
    // failed, later calls must keep failing even without new input.
    let (mut client, _) = make_pair(KeyType::RSA);
    client.read_tls(&mut b"\x16\x03\x03\x00\x08\x0f\x00\x00\x04junk".as_ref()).unwrap();
    assert!(client.process_new_packets().is_err());
    assert!(client.process_new_packets().is_err());
}
773 
#[test]
fn server_error_is_sticky() {
    // Feed the server a handshake record it cannot accept; once processing has
    // failed, later calls must keep failing even without new input.
    let (_, mut server) = make_pair(KeyType::RSA);
    server.read_tls(&mut b"\x16\x03\x03\x00\x08\x0f\x00\x00\x04junk".as_ref()).unwrap();
    assert!(server.process_new_packets().is_err());
    assert!(server.process_new_packets().is_err());
}
783 
#[test]
fn server_is_send_and_sync() {
    // Compile-time check: this call only type-checks if the argument's type is
    // both Send and Sync. Replaces no-effect `&x as &dyn Send;` statements.
    fn assert_send_and_sync<T: Send + Sync>(_: &T) {}

    let (_, server) = make_pair(KeyType::RSA);
    assert_send_and_sync(&server);
}
790 
#[test]
fn client_is_send_and_sync() {
    // Compile-time check: this call only type-checks if the argument's type is
    // both Send and Sync. Replaces no-effect `&x as &dyn Send;` statements.
    fn assert_send_and_sync<T: Send + Sync>(_: &T) {}

    let (client, _) = make_pair(KeyType::RSA);
    assert_send_and_sync(&client);
}
797 
#[test]
fn server_respects_buffer_limit_pre_handshake() {
    let chunk: &[u8] = b"01234567890123456789";
    let (mut client, mut server) = make_pair(KeyType::RSA);

    // Cap the server's outgoing buffer at 32 bytes while still handshaking.
    server.set_buffer_limit(32);

    // The first 20-byte write fits; the second is cut to the 12 bytes left.
    let first = server.write(chunk).unwrap();
    assert_eq!(first, 20);
    let second = server.write(chunk).unwrap();
    assert_eq!(second, 12);

    do_handshake(&mut client, &mut server);
    transfer(&mut server, &mut client);
    client.process_new_packets().unwrap();

    check_read(&mut client, b"01234567890123456789012345678901");
}
813 
#[test]
fn server_respects_buffer_limit_post_handshake() {
    let chunk: &[u8] = b"01234567890123456789";
    let (mut client, mut server) = make_pair(KeyType::RSA);

    // this test will vary in behaviour depending on the default suites
    do_handshake(&mut client, &mut server);
    server.set_buffer_limit(48);

    // NOTE(review): after the handshake the limit appears to count encrypted
    // bytes (payload plus record overhead), which is presumably why only 6 bytes
    // of the second write fit.
    let first = server.write(chunk).unwrap();
    assert_eq!(first, 20);
    let second = server.write(chunk).unwrap();
    assert_eq!(second, 6);

    transfer(&mut server, &mut client);
    client.process_new_packets().unwrap();

    check_read(&mut client, b"01234567890123456789012345");
}
830 
#[test]
fn client_respects_buffer_limit_pre_handshake() {
    let chunk: &[u8] = b"01234567890123456789";
    let (mut client, mut server) = make_pair(KeyType::RSA);

    // Cap the client's outgoing buffer at 32 bytes while still handshaking.
    client.set_buffer_limit(32);

    // The first 20-byte write fits; the second is cut to the 12 bytes left.
    let first = client.write(chunk).unwrap();
    assert_eq!(first, 20);
    let second = client.write(chunk).unwrap();
    assert_eq!(second, 12);

    do_handshake(&mut client, &mut server);
    transfer(&mut client, &mut server);
    server.process_new_packets().unwrap();

    check_read(&mut server, b"01234567890123456789012345678901");
}
846 
#[test]
fn client_respects_buffer_limit_post_handshake() {
    let chunk: &[u8] = b"01234567890123456789";
    let (mut client, mut server) = make_pair(KeyType::RSA);

    do_handshake(&mut client, &mut server);
    client.set_buffer_limit(48);

    // NOTE(review): post-handshake the limit appears to count encrypted bytes,
    // so the second write is expected to be truncated to 6 bytes.
    let first = client.write(chunk).unwrap();
    assert_eq!(first, 20);
    let second = client.write(chunk).unwrap();
    assert_eq!(second, 6);

    transfer(&mut client, &mut server);
    server.process_new_packets().unwrap();

    check_read(&mut server, b"01234567890123456789012345");
}
862 
863 struct OtherSession<'a> {
864     sess: &'a mut dyn Session,
865     pub reads: usize,
866     pub writes: usize,
867     pub writevs: Vec<Vec<usize>>,
868     fail_ok: bool,
869     pub short_writes: bool,
870     pub last_error: Option<rustls::TLSError>,
871 }
872 
873 impl<'a> OtherSession<'a> {
new(sess: &'a mut dyn Session) -> OtherSession<'a>874     fn new(sess: &'a mut dyn Session) -> OtherSession<'a> {
875         OtherSession {
876             sess,
877             reads: 0,
878             writes: 0,
879             writevs: vec![],
880             fail_ok: false,
881             short_writes: false,
882             last_error: None,
883         }
884     }
885 
new_fails(sess: &'a mut dyn Session) -> OtherSession<'a>886     fn new_fails(sess: &'a mut dyn Session) -> OtherSession<'a> {
887         let mut os = OtherSession::new(sess);
888         os.fail_ok = true;
889         os
890     }
891 }
892 
893 impl<'a> io::Read for OtherSession<'a> {
read(&mut self, mut b: &mut [u8]) -> io::Result<usize>894     fn read(&mut self, mut b: &mut [u8]) -> io::Result<usize> {
895         self.reads += 1;
896         self.sess.write_tls(b.by_ref())
897     }
898 }
899 
900 impl<'a> io::Write for OtherSession<'a> {
write(&mut self, mut b: &[u8]) -> io::Result<usize>901     fn write(&mut self, mut b: &[u8]) -> io::Result<usize> {
902         self.writes += 1;
903         let l = self.sess.read_tls(b.by_ref())?;
904         let rc = self.sess.process_new_packets();
905 
906         if !self.fail_ok {
907             rc.unwrap();
908         } else if rc.is_err() {
909             self.last_error = rc.err();
910         }
911 
912         Ok(l)
913     }
914 
flush(&mut self) -> io::Result<()>915     fn flush(&mut self) -> io::Result<()> {
916         Ok(())
917     }
918 }
919 
920 impl<'a> rustls::WriteV for OtherSession<'a> {
writev(&mut self, b: &[&[u8]]) -> io::Result<usize>921     fn writev(&mut self, b: &[&[u8]]) -> io::Result<usize> {
922         let mut total = 0;
923         let mut lengths = vec![];
924         for bytes in b {
925             let write_len = if self.short_writes {
926                 if bytes.len() > 5 { bytes.len() / 2 } else { bytes.len() }
927             } else {
928                 bytes.len()
929             };
930 
931             let l = self.sess.read_tls(&mut io::Cursor::new(&bytes[..write_len]))?;
932             lengths.push(l);
933             total += l;
934             if bytes.len() != l {
935                 break;
936             }
937         }
938 
939         let rc = self.sess.process_new_packets();
940         if !self.fail_ok {
941             rc.unwrap();
942         } else if rc.is_err() {
943             self.last_error = rc.err();
944         }
945 
946         self.writevs.push(lengths);
947         Ok(total)
948     }
949 }
950 
#[test]
fn client_complete_io_for_handshake() {
    // Cover every key type, matching server_complete_io_for_handshake.
    for kt in ALL_KEY_TYPES.iter() {
        let (mut client, mut server) = make_pair(*kt);

        assert!(client.is_handshaking());
        let (rdlen, wrlen) = client.complete_io(&mut OtherSession::new(&mut server)).unwrap();
        // A full handshake needs traffic in both directions.
        assert!(rdlen > 0 && wrlen > 0);
        assert!(!client.is_handshaking());
    }
}
960 
#[test]
fn client_complete_io_for_handshake_eof() {
    let (mut client, _) = make_pair(KeyType::RSA);
    // An empty reader: the peer hangs up before sending anything.
    let mut input = io::Cursor::new(Vec::new());

    assert!(client.is_handshaking());
    let err = client.complete_io(&mut input).unwrap_err();
    assert_eq!(io::ErrorKind::UnexpectedEof, err.kind());
}
970 
#[test]
fn client_complete_io_for_write() {
    for kt in ALL_KEY_TYPES.iter() {
        let (mut client, mut server) = make_pair(*kt);

        do_handshake(&mut client, &mut server);

        // write_all, not write: ignoring write's count could hide a short write.
        client.write_all(b"01234567890123456789").unwrap();
        client.write_all(b"01234567890123456789").unwrap();
        {
            let mut pipe = OtherSession::new(&mut server);
            let (rdlen, wrlen) = client.complete_io(&mut pipe).unwrap();
            // Only outbound traffic, delivered as two writes to the pipe.
            assert!(rdlen == 0 && wrlen > 0);
            assert_eq!(pipe.writes, 2);
        }
        check_read(&mut server, b"0123456789012345678901234567890123456789");
    }
}
989 
#[test]
fn client_complete_io_for_read() {
    for kt in ALL_KEY_TYPES.iter() {
        let (mut client, mut server) = make_pair(*kt);

        do_handshake(&mut client, &mut server);

        // write_all, not write: ignoring write's count could hide a short write.
        server.write_all(b"01234567890123456789").unwrap();
        {
            let mut pipe = OtherSession::new(&mut server);
            let (rdlen, wrlen) = client.complete_io(&mut pipe).unwrap();
            // Only inbound traffic, fetched with a single read from the pipe.
            assert!(rdlen > 0 && wrlen == 0);
            assert_eq!(pipe.reads, 1);
        }
        check_read(&mut client, b"01234567890123456789");
    }
}
1007 
#[test]
fn server_complete_io_for_handshake() {
    for kt in ALL_KEY_TYPES.iter() {
        let (mut client, mut server) = make_pair(*kt);

        assert!(server.is_handshaking());
        let (rdlen, wrlen) = server.complete_io(&mut OtherSession::new(&mut client)).unwrap();
        // A full handshake needs traffic in both directions.
        assert!(rdlen > 0 && wrlen > 0);
        assert!(!server.is_handshaking());
    }
}
1019 
#[test]
fn server_complete_io_for_handshake_eof() {
    let (_, mut server) = make_pair(KeyType::RSA);
    // An empty reader: the peer hangs up before sending anything.
    let mut input = io::Cursor::new(Vec::new());

    assert!(server.is_handshaking());
    let err = server.complete_io(&mut input).unwrap_err();
    assert_eq!(io::ErrorKind::UnexpectedEof, err.kind());
}
1029 
#[test]
fn server_complete_io_for_write() {
    for kt in ALL_KEY_TYPES.iter() {
        let (mut client, mut server) = make_pair(*kt);

        do_handshake(&mut client, &mut server);

        // write_all, not write: ignoring write's count could hide a short write.
        server.write_all(b"01234567890123456789").unwrap();
        server.write_all(b"01234567890123456789").unwrap();
        {
            let mut pipe = OtherSession::new(&mut client);
            let (rdlen, wrlen) = server.complete_io(&mut pipe).unwrap();
            // Only outbound traffic, delivered as two writes to the pipe.
            assert!(rdlen == 0 && wrlen > 0);
            assert_eq!(pipe.writes, 2);
        }
        check_read(&mut client, b"0123456789012345678901234567890123456789");
    }
}
1048 
#[test]
fn server_complete_io_for_read() {
    for kt in ALL_KEY_TYPES.iter() {
        let (mut client, mut server) = make_pair(*kt);

        do_handshake(&mut client, &mut server);

        // write_all, not write: ignoring write's count could hide a short write.
        client.write_all(b"01234567890123456789").unwrap();
        {
            let mut pipe = OtherSession::new(&mut client);
            let (rdlen, wrlen) = server.complete_io(&mut pipe).unwrap();
            // Only inbound traffic, fetched with a single read from the pipe.
            assert!(rdlen > 0 && wrlen == 0);
            assert_eq!(pipe.reads, 1);
        }
        check_read(&mut server, b"01234567890123456789");
    }
}
1066 
#[test]
fn client_stream_write() {
    for kt in ALL_KEY_TYPES.iter() {
        let (mut client, mut server) = make_pair(*kt);

        // Borrow `server` through the pipe only for the duration of this scope.
        {
            let mut pipe = OtherSession::new(&mut server);
            let written = Stream::new(&mut client, &mut pipe)
                .write(b"hello")
                .unwrap();
            assert_eq!(written, 5);
        }

        check_read(&mut server, b"hello");
    }
}
1080 
#[test]
fn client_streamowned_write() {
    for kt in ALL_KEY_TYPES.iter() {
        let (client, mut server) = make_pair(*kt);

        let mut stream = StreamOwned::new(client, OtherSession::new(&mut server));
        let written = stream.write(b"hello").unwrap();
        assert_eq!(written, 5);
        // Release the stream, and with it the borrow of `server`, before reading.
        drop(stream);

        check_read(&mut server, b"hello");
    }
}
1094 
#[test]
fn client_stream_read() {
    for kt in ALL_KEY_TYPES.iter() {
        let (mut client, mut server) = make_pair(*kt);

        // write_all, not write: ignoring write's count could hide a short write.
        server.write_all(b"world").unwrap();

        {
            let mut pipe = OtherSession::new(&mut server);
            let mut stream = Stream::new(&mut client, &mut pipe);
            check_read(&mut stream, b"world");
        }
    }
}
1109 
#[test]
fn client_streamowned_read() {
    for kt in ALL_KEY_TYPES.iter() {
        let (client, mut server) = make_pair(*kt);

        // write_all, not write: ignoring write's count could hide a short write.
        server.write_all(b"world").unwrap();

        {
            let pipe = OtherSession::new(&mut server);
            let mut stream = StreamOwned::new(client, pipe);
            check_read(&mut stream, b"world");
        }
    }
}
1124 
#[test]
fn server_stream_write() {
    for kt in ALL_KEY_TYPES.iter() {
        let (mut client, mut server) = make_pair(*kt);

        // Borrow `client` through the pipe only for the duration of this scope.
        {
            let mut pipe = OtherSession::new(&mut client);
            let written = Stream::new(&mut server, &mut pipe)
                .write(b"hello")
                .unwrap();
            assert_eq!(written, 5);
        }

        check_read(&mut client, b"hello");
    }
}
1138 
#[test]
fn server_streamowned_write() {
    for kt in ALL_KEY_TYPES.iter() {
        let (mut client, server) = make_pair(*kt);

        let mut stream = StreamOwned::new(server, OtherSession::new(&mut client));
        let written = stream.write(b"hello").unwrap();
        assert_eq!(written, 5);
        // Release the stream, and with it the borrow of `client`, before reading.
        drop(stream);

        check_read(&mut client, b"hello");
    }
}
1152 
#[test]
fn server_stream_read() {
    for kt in ALL_KEY_TYPES.iter() {
        let (mut client, mut server) = make_pair(*kt);

        // write_all, not write: ignoring write's count could hide a short write.
        client.write_all(b"world").unwrap();

        {
            let mut pipe = OtherSession::new(&mut client);
            let mut stream = Stream::new(&mut server, &mut pipe);
            check_read(&mut stream, b"world");
        }
    }
}
1167 
#[test]
fn server_streamowned_read() {
    for kt in ALL_KEY_TYPES.iter() {
        let (mut client, server) = make_pair(*kt);

        // write_all, not write: ignoring write's count could hide a short write.
        client.write_all(b"world").unwrap();

        {
            let pipe = OtherSession::new(&mut client);
            let mut stream = StreamOwned::new(server, pipe);
            check_read(&mut stream, b"world");
        }
    }
}
1182 
/// Test transport whose writes succeed `after` times and then fail with an
/// `io::Error` of kind `errkind`; reads always report end-of-stream.
struct FailsWrites {
    /// Remaining number of writes that will still succeed.
    after: usize,
    /// Error kind returned once `after` reaches zero.
    errkind: io::ErrorKind,
}
1187 
1188 impl io::Read for FailsWrites {
read(&mut self, _b: &mut [u8]) -> io::Result<usize>1189     fn read(&mut self, _b: &mut [u8]) -> io::Result<usize> {
1190         Ok(0)
1191     }
1192 }
1193 
1194 impl io::Write for FailsWrites {
write(&mut self, b: &[u8]) -> io::Result<usize>1195     fn write(&mut self, b: &[u8]) -> io::Result<usize> {
1196         if self.after > 0 {
1197             self.after -= 1;
1198             Ok(b.len())
1199         } else {
1200             Err(io::Error::new(self.errkind, "oops"))
1201         }
1202     }
1203 
flush(&mut self) -> io::Result<()>1204     fn flush(&mut self) -> io::Result<()> {
1205         Ok(())
1206     }
1207 }
1208 
#[test]
fn stream_write_reports_underlying_io_error_before_plaintext_processed() {
    let (mut client, mut server) = make_pair(KeyType::RSA);
    do_handshake(&mut client, &mut server);

    // A transport that fails its very first write.
    let mut pipe = FailsWrites {
        errkind: io::ErrorKind::WouldBlock,
        after: 0,
    };
    // write_all, not write: ignoring write's count could hide a short write.
    client.write_all(b"hello").unwrap();
    let mut client_stream = Stream::new(&mut client, &mut pipe);
    // The transport fails before "world" is taken, so the transport's error must
    // reach the caller unchanged.
    let err = client_stream.write(b"world").unwrap_err();
    assert_eq!(err.kind(), io::ErrorKind::WouldBlock);
    // `Error::description` is deprecated; the Display text carries the message.
    assert_eq!(err.to_string(), "oops");
}
1226 
#[test]
fn stream_write_swallows_underlying_io_error_after_plaintext_processed() {
    let (mut client, mut server) = make_pair(KeyType::RSA);
    do_handshake(&mut client, &mut server);

    // A transport that accepts one write and fails all later ones.
    let mut pipe = FailsWrites {
        errkind: io::ErrorKind::WouldBlock,
        after: 1,
    };
    // write_all, not write: ignoring write's count could hide a short write.
    client.write_all(b"hello").unwrap();
    let mut client_stream = Stream::new(&mut client, &mut pipe);
    // Once "world" has been taken as plaintext, a later transport error must not
    // turn the write into a failure.
    let written = client_stream.write(b"world").unwrap();
    assert_eq!(written, 5);
}
1241 
make_disjoint_suite_configs() -> (ClientConfig, ServerConfig)1242 fn make_disjoint_suite_configs() -> (ClientConfig, ServerConfig) {
1243     let kt = KeyType::RSA;
1244     let mut server_config = make_server_config(kt);
1245     server_config.ciphersuites = vec![find_suite(CipherSuite::TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256)];
1246 
1247     let mut client_config = make_client_config(kt);
1248     client_config.ciphersuites = vec![find_suite(CipherSuite::TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384)];
1249 
1250     (client_config, server_config)
1251 }
1252 
#[test]
fn client_stream_handshake_error() {
    let (client_config, server_config) = make_disjoint_suite_configs();
    let (mut client, mut server) = make_pair_for_configs(client_config, server_config);

    let mut pipe = OtherSession::new_fails(&mut server);
    let mut client_stream = Stream::new(&mut client, &mut pipe);

    // The failure must be reported on the first write and again on a retry.
    for _attempt in 0..2 {
        let rc = client_stream.write(b"hello");
        assert!(rc.is_err());
        assert_eq!(format!("{:?}", rc),
                   "Err(Custom { kind: InvalidData, error: AlertReceived(HandshakeFailure) })");
    }
}
1272 
#[test]
fn client_streamowned_handshake_error() {
    let (client_config, server_config) = make_disjoint_suite_configs();
    let (client, mut server) = make_pair_for_configs(client_config, server_config);

    let mut client_stream = StreamOwned::new(client, OtherSession::new_fails(&mut server));

    // The failure must be reported on the first write and again on a retry.
    for _attempt in 0..2 {
        let rc = client_stream.write(b"hello");
        assert!(rc.is_err());
        assert_eq!(format!("{:?}", rc),
                   "Err(Custom { kind: InvalidData, error: AlertReceived(HandshakeFailure) })");
    }
}
1289 
#[test]
fn server_stream_handshake_error() {
    let (client_config, server_config) = make_disjoint_suite_configs();
    let (mut client, mut server) = make_pair_for_configs(client_config, server_config);

    // write_all, not write: ignoring write's count could hide a short write.
    client.write_all(b"world").unwrap();

    {
        let mut pipe = OtherSession::new_fails(&mut client);
        let mut server_stream = Stream::new(&mut server, &mut pipe);
        let mut bytes = [0u8; 5];
        // The server cannot pick a suite, so the read surfaces the handshake error.
        let rc = server_stream.read(&mut bytes);
        assert!(rc.is_err());
        assert_eq!(format!("{:?}", rc),
                   "Err(Custom { kind: InvalidData, error: PeerIncompatibleError(\"no ciphersuites in common\") })");
    }
}
1307 
#[test]
fn server_streamowned_handshake_error() {
    let (client_config, server_config) = make_disjoint_suite_configs();
    let (mut client, server) = make_pair_for_configs(client_config, server_config);

    // write_all, not write: ignoring write's count could hide a short write.
    client.write_all(b"world").unwrap();

    let pipe = OtherSession::new_fails(&mut client);
    let mut server_stream = StreamOwned::new(server, pipe);
    let mut bytes = [0u8; 5];
    // The server cannot pick a suite, so the read surfaces the handshake error.
    let rc = server_stream.read(&mut bytes);
    assert!(rc.is_err());
    assert_eq!(format!("{:?}", rc),
               "Err(Custom { kind: InvalidData, error: PeerIncompatibleError(\"no ciphersuites in common\") })");
}
1323 
#[test]
fn server_config_is_clone() {
    // Compiles only if ServerConfig implements Clone.
    let config = make_server_config(KeyType::RSA);
    let _copy = config.clone();
}
1328 
#[test]
fn client_config_is_clone() {
    // Compiles only if ClientConfig implements Clone.
    let config = make_client_config(KeyType::RSA);
    let _copy = config.clone();
}
1333 
#[test]
fn client_session_is_debug() {
    // Compiles only if ClientSession implements Debug; also exercises the impl.
    let client = make_pair(KeyType::RSA).0;
    println!("{:?}", client);
}
1339 
#[test]
fn server_session_is_debug() {
    // Compiles only if ServerSession implements Debug; also exercises the impl.
    let server = make_pair(KeyType::RSA).1;
    println!("{:?}", server);
}
1345 
#[test]
fn server_complete_io_for_handshake_ending_with_alert() {
    let (client_config, server_config) = make_disjoint_suite_configs();
    let (mut client, mut server) = make_pair_for_configs(client_config, server_config);

    assert!(server.is_handshaking());

    // The client side records, rather than unwraps, its processing error.
    let mut pipe = OtherSession::new_fails(&mut client);
    let rc = server.complete_io(&mut pipe);
    assert!(rc.is_err(),
            "server io failed due to handshake failure");
    assert!(!server.wants_write(),
            "but server did send its alert");
    assert_eq!(format!("{:?}", pipe.last_error),
               "Some(AlertReceived(HandshakeFailure))",
               "which was received by client");
}
1363 
#[test]
fn server_exposes_offered_sni() {
    let kt = KeyType::RSA;
    let client_config = Arc::new(make_client_config(kt));
    let server_config = Arc::new(make_server_config(kt));
    let mut client = ClientSession::new(&client_config, dns_name("second.testserver.com"));
    let mut server = ServerSession::new(&server_config);

    // No name is known until the client's hello has been processed.
    assert_eq!(None, server.get_sni_hostname());
    do_handshake(&mut client, &mut server);
    assert_eq!(Some("second.testserver.com"), server.get_sni_hostname());
}
1375 
#[test]
fn sni_resolver_works() {
    let kt = KeyType::RSA;
    let mut resolver = rustls::ResolvesServerCertUsingSNI::new();
    let signing_key: Arc<Box<dyn sign::SigningKey>> =
        Arc::new(Box::new(sign::RSASigningKey::new(&kt.get_key()).unwrap()));
    resolver.add("localhost", sign::CertifiedKey::new(kt.get_chain(), signing_key))
        .unwrap();

    let server_config = {
        let mut cfg = make_server_config(kt);
        cfg.cert_resolver = Arc::new(resolver);
        Arc::new(cfg)
    };

    // A client asking for the registered name completes the handshake.
    let mut server1 = ServerSession::new(&server_config);
    let mut client1 = ClientSession::new(&Arc::new(make_client_config(kt)), dns_name("localhost"));
    let outcome1 = do_handshake_until_error(&mut client1, &mut server1);
    assert_eq!(outcome1, Ok(()));

    // Any other name leaves the server with no certificate to present.
    let mut server2 = ServerSession::new(&server_config);
    let mut client2 = ClientSession::new(&Arc::new(make_client_config(kt)), dns_name("notlocalhost"));
    let outcome2 = do_handshake_until_error(&mut client2, &mut server2);
    assert_eq!(outcome2,
               Err(TLSErrorFromPeer::Server(
                       TLSError::General("no server certificate chain resolved".into()))));
}
1403 
#[test]
fn sni_resolver_rejects_wrong_names() {
    let kt = KeyType::RSA;
    let mut resolver = rustls::ResolvesServerCertUsingSNI::new();
    let signing_key = sign::RSASigningKey::new(&kt.get_key())
        .unwrap();
    let signing_key: Arc<Box<dyn sign::SigningKey>> = Arc::new(Box::new(signing_key));

    // The test chain is accepted for the name it was issued for...
    assert_eq!(Ok(()),
               resolver.add("localhost",
                            sign::CertifiedKey::new(kt.get_chain(), signing_key.clone())));
    // ...but rejected for a different, well-formed name...
    assert_eq!(Err(TLSError::General("The server certificate is not valid for the given name".into())),
               resolver.add("not-localhost",
                            sign::CertifiedKey::new(kt.get_chain(), signing_key.clone())));
    // ...and a name containing non-ASCII characters is rejected as a bad DNS name.
    // Written as an escape (U+1F980) so the non-ASCII input survives re-encoding;
    // the previous literal had degraded to U+FFFD replacement characters.
    assert_eq!(Err(TLSError::General("Bad DNS name".into())),
               resolver.add("not ascii \u{1f980}",
                            sign::CertifiedKey::new(kt.get_chain(), signing_key)));
}
1422 
#[test]
fn sni_resolver_rejects_bad_certs() {
    let kt = KeyType::RSA;
    let mut resolver = rustls::ResolvesServerCertUsingSNI::new();
    let signing_key: Arc<Box<dyn sign::SigningKey>> =
        Arc::new(Box::new(sign::RSASigningKey::new(&kt.get_key()).unwrap()));

    // An empty chain has no end-entity certificate at all.
    let empty_chain = sign::CertifiedKey::new(vec![], signing_key.clone());
    assert_eq!(Err(TLSError::General("No end-entity certificate in certificate chain".into())),
               resolver.add("localhost", empty_chain));

    // A single 0xa0 byte cannot be parsed as a certificate.
    let garbage_chain = vec![rustls::Certificate(vec![0xa0])];
    assert_eq!(Err(TLSError::General("End-entity certificate in certificate chain is syntactically invalid".into())),
               resolver.add("localhost", sign::CertifiedKey::new(garbage_chain, signing_key)));
}
1440 
do_exporter_test(client_config: ClientConfig, server_config: ServerConfig)1441 fn do_exporter_test(client_config: ClientConfig, server_config: ServerConfig) {
1442     let mut client_secret = [0u8; 64];
1443     let mut server_secret = [0u8; 64];
1444 
1445     let (mut client, mut server) = make_pair_for_configs(client_config, server_config);
1446 
1447     assert_eq!(Err(TLSError::HandshakeNotComplete),
1448                client.export_keying_material(&mut client_secret, b"label", Some(b"context")));
1449     assert_eq!(Err(TLSError::HandshakeNotComplete),
1450                server.export_keying_material(&mut server_secret, b"label", Some(b"context")));
1451     do_handshake(&mut client, &mut server);
1452 
1453     assert_eq!(Ok(()),
1454                client.export_keying_material(&mut client_secret, b"label", Some(b"context")));
1455     assert_eq!(Ok(()),
1456                server.export_keying_material(&mut server_secret, b"label", Some(b"context")));
1457     assert_eq!(client_secret.to_vec(), server_secret.to_vec());
1458 
1459     assert_eq!(Ok(()),
1460                client.export_keying_material(&mut client_secret, b"label", None));
1461     assert_ne!(client_secret.to_vec(), server_secret.to_vec());
1462     assert_eq!(Ok(()),
1463                server.export_keying_material(&mut server_secret, b"label", None));
1464     assert_eq!(client_secret.to_vec(), server_secret.to_vec());
1465 }
1466 
#[test]
fn test_tls12_exporter() {
    // Pin the client to TLS 1.2 for every key type.
    for &kt in ALL_KEY_TYPES.iter() {
        let mut client_config = make_client_config(kt);
        let server_config = make_server_config(kt);
        client_config.versions = vec![ProtocolVersion::TLSv1_2];

        do_exporter_test(client_config, server_config);
    }
}
1477 
#[test]
fn test_tls13_exporter() {
    // Pin the client to TLS 1.3 for every key type.
    for &kt in ALL_KEY_TYPES.iter() {
        let mut client_config = make_client_config(kt);
        let server_config = make_server_config(kt);
        client_config.versions = vec![ProtocolVersion::TLSv1_3];

        do_exporter_test(client_config, server_config);
    }
}
1488 
do_suite_test(client_config: ClientConfig, server_config: ServerConfig, expect_suite: &'static SupportedCipherSuite, expect_version: ProtocolVersion)1489 fn do_suite_test(client_config: ClientConfig,
1490                  server_config: ServerConfig,
1491                  expect_suite: &'static SupportedCipherSuite,
1492                  expect_version: ProtocolVersion) {
1493     println!("do_suite_test {:?} {:?}", expect_version, expect_suite.suite);
1494     let (mut client, mut server) = make_pair_for_configs(client_config, server_config);
1495 
1496     assert_eq!(None, client.get_negotiated_ciphersuite());
1497     assert_eq!(None, server.get_negotiated_ciphersuite());
1498     assert_eq!(None, client.get_protocol_version());
1499     assert_eq!(None, server.get_protocol_version());
1500     assert_eq!(true, client.is_handshaking());
1501     assert_eq!(true, server.is_handshaking());
1502 
1503     transfer(&mut client, &mut server);
1504     server.process_new_packets().unwrap();
1505 
1506     assert_eq!(true, client.is_handshaking());
1507     assert_eq!(true, server.is_handshaking());
1508     assert_eq!(None, client.get_protocol_version());
1509     assert_eq!(Some(expect_version), server.get_protocol_version());
1510     assert_eq!(None, client.get_negotiated_ciphersuite());
1511     assert_eq!(Some(expect_suite), server.get_negotiated_ciphersuite());
1512 
1513     transfer(&mut server, &mut client);
1514     client.process_new_packets().unwrap();
1515 
1516     assert_eq!(Some(expect_suite), client.get_negotiated_ciphersuite());
1517     assert_eq!(Some(expect_suite), server.get_negotiated_ciphersuite());
1518 
1519     transfer(&mut client, &mut server);
1520     server.process_new_packets().unwrap();
1521     transfer(&mut server, &mut client);
1522     client.process_new_packets().unwrap();
1523 
1524     assert_eq!(false, client.is_handshaking());
1525     assert_eq!(false, server.is_handshaking());
1526     assert_eq!(Some(expect_version), client.get_protocol_version());
1527     assert_eq!(Some(expect_version), server.get_protocol_version());
1528     assert_eq!(Some(expect_suite), client.get_negotiated_ciphersuite());
1529     assert_eq!(Some(expect_suite), server.get_negotiated_ciphersuite());
1530 }
1531 
find_suite(suite: CipherSuite) -> &'static SupportedCipherSuite1532 fn find_suite(suite: CipherSuite) -> &'static SupportedCipherSuite {
1533     for scs in ALL_CIPHERSUITES.iter() {
1534         if scs.suite == suite {
1535             return scs;
1536         }
1537     }
1538 
1539     panic!("find_suite given unsuppported suite");
1540 }
1541 
1542 static TEST_CIPHERSUITES: [(ProtocolVersion, KeyType, CipherSuite); 9] = [
1543     (ProtocolVersion::TLSv1_3, KeyType::RSA, CipherSuite::TLS13_CHACHA20_POLY1305_SHA256),
1544     (ProtocolVersion::TLSv1_3, KeyType::RSA, CipherSuite::TLS13_AES_256_GCM_SHA384),
1545     (ProtocolVersion::TLSv1_3, KeyType::RSA, CipherSuite::TLS13_AES_128_GCM_SHA256),
1546     (ProtocolVersion::TLSv1_2, KeyType::ECDSA, CipherSuite::TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256),
1547     (ProtocolVersion::TLSv1_2, KeyType::RSA,   CipherSuite::TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256),
1548     (ProtocolVersion::TLSv1_2, KeyType::ECDSA, CipherSuite::TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384),
1549     (ProtocolVersion::TLSv1_2, KeyType::ECDSA, CipherSuite::TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256),
1550     (ProtocolVersion::TLSv1_2, KeyType::RSA,   CipherSuite::TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384),
1551     (ProtocolVersion::TLSv1_2, KeyType::RSA,   CipherSuite::TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256)
1552 ];
1553 
/// With default client and server configs, every key type ends up on TLS1.3
/// with TLS13_CHACHA20_POLY1305_SHA256.
#[test]
fn negotiated_ciphersuite_default() {
    let expected_suite = find_suite(CipherSuite::TLS13_CHACHA20_POLY1305_SHA256);

    for &kt in ALL_KEY_TYPES.iter() {
        let client_config = make_client_config(kt);
        let server_config = make_server_config(kt);
        do_suite_test(client_config, server_config, expected_suite, ProtocolVersion::TLSv1_3);
    }
}
1563 
/// Checks that `TEST_CIPHERSUITES` exercises every suite in `ALL_CIPHERSUITES`.
///
/// Equal lengths alone would still pass if the table listed one suite twice and
/// omitted another, so each supported suite is also looked up in the table.
#[test]
fn all_suites_covered() {
    assert_eq!(ALL_CIPHERSUITES.len(), TEST_CIPHERSUITES.len());

    for scs in ALL_CIPHERSUITES.iter() {
        assert!(TEST_CIPHERSUITES.iter().any(|&(_, _, suite)| suite == scs.suite),
                "suite {:?} missing from TEST_CIPHERSUITES", scs.suite);
    }
}
1568 
/// For each entry of `TEST_CIPHERSUITES`, restricts the *client* to just that
/// suite and version, and checks that the handshake negotiates them.
#[test]
fn negotiated_ciphersuite_client() {
    for &(version, kt, suite) in TEST_CIPHERSUITES.iter() {
        let expected_suite = find_suite(suite);

        let mut restricted = make_client_config(kt);
        restricted.ciphersuites = vec![expected_suite];
        restricted.versions = vec![version];

        do_suite_test(restricted, make_server_config(kt), expected_suite, version);
    }
}
1584 
/// For each entry of `TEST_CIPHERSUITES`, restricts the *server* to just that
/// suite and version, and checks that the handshake negotiates them.
#[test]
fn negotiated_ciphersuite_server() {
    for &(version, kt, suite) in TEST_CIPHERSUITES.iter() {
        let expected_suite = find_suite(suite);

        let mut restricted = make_server_config(kt);
        restricted.ciphersuites = vec![expected_suite];
        restricted.versions = vec![version];

        do_suite_test(make_client_config(kt), restricted, expected_suite, version);
    }
}
1600 
1601 #[derive(Debug, PartialEq)]
1602 struct KeyLogItem {
1603     label: String,
1604     client_random: Vec<u8>,
1605     secret: Vec<u8>,
1606 }
1607 
1608 struct KeyLogToVec {
1609     label: &'static str,
1610     items: Mutex<Vec<KeyLogItem>>,
1611 }
1612 
1613 impl KeyLogToVec {
new(who: &'static str) -> Self1614     fn new(who: &'static str) -> Self {
1615         KeyLogToVec {
1616             label: who,
1617             items: Mutex::new(vec![]),
1618         }
1619     }
1620 
take(&self) -> Vec<KeyLogItem>1621     fn take(&self) -> Vec<KeyLogItem> {
1622         mem::replace(&mut self.items.lock()
1623                          .unwrap(),
1624                      vec![])
1625     }
1626 }
1627 
1628 impl KeyLog for KeyLogToVec {
log(&self, label: &str, client: &[u8], secret: &[u8])1629     fn log(&self, label: &str, client: &[u8], secret: &[u8]) {
1630         let value = KeyLogItem {
1631             label: label.into(),
1632             client_random: client.into(),
1633             secret: secret.into()
1634         };
1635 
1636         println!("key log {:?}: {:?}", self.label, value);
1637 
1638         self.items.lock()
1639             .unwrap()
1640             .push(value);
1641     }
1642 }
1643 
/// A TLS1.2 handshake logs exactly one CLIENT_RANDOM entry on each side. Both
/// sides log the same entry, and a resumed handshake logs the same secret as
/// the full handshake it resumed.
#[test]
fn key_log_for_tls12() {
    let client_key_log = Arc::new(KeyLogToVec::new("client"));
    let server_key_log = Arc::new(KeyLogToVec::new("server"));

    let client_config = {
        let mut cfg = make_client_config(KeyType::RSA);
        cfg.versions = vec![ProtocolVersion::TLSv1_2];
        cfg.key_log = client_key_log.clone();
        Arc::new(cfg)
    };

    let server_config = {
        let mut cfg = make_server_config(KeyType::RSA);
        cfg.key_log = server_key_log.clone();
        Arc::new(cfg)
    };

    // Runs one handshake between fresh sessions and returns what each side logged.
    let handshake_logs = || {
        let (mut client, mut server) = make_pair_for_arc_configs(&client_config, &server_config);
        do_handshake(&mut client, &mut server);
        (client_key_log.take(), server_key_log.take())
    };

    // full handshake
    let (client_full_log, server_full_log) = handshake_logs();
    assert_eq!(client_full_log, server_full_log);
    assert_eq!(1, client_full_log.len());
    assert_eq!("CLIENT_RANDOM", client_full_log[0].label);

    // resumed
    let (client_resume_log, server_resume_log) = handshake_logs();
    assert_eq!(client_resume_log, server_resume_log);
    assert_eq!(1, client_resume_log.len());
    assert_eq!("CLIENT_RANDOM", client_resume_log[0].label);
    assert_eq!(client_full_log[0].secret, client_resume_log[0].secret);
}
1680 
/// A TLS1.3 handshake (full or resumed) logs five secrets on each side, with
/// identical values on both sides.
#[test]
fn key_log_for_tls13() {
    /// Labels the client is expected to log, in this order, for each handshake.
    const CLIENT_LABELS: [&str; 5] = [
        "CLIENT_HANDSHAKE_TRAFFIC_SECRET",
        "SERVER_HANDSHAKE_TRAFFIC_SECRET",
        "SERVER_TRAFFIC_SECRET_0",
        "EXPORTER_SECRET",
        "CLIENT_TRAFFIC_SECRET_0",
    ];

    /// Checks one handshake's logs. The client must log exactly `CLIENT_LABELS`,
    /// the server must log the same number of entries, and the two logs must
    /// hold the same items. The first two items appear in swapped order on the
    /// server, so those are compared crosswise.
    fn check_logs(client_log: &[KeyLogItem], server_log: &[KeyLogItem]) {
        assert_eq!(CLIENT_LABELS.len(), client_log.len());
        // Without this, extra trailing server entries would go unnoticed.
        assert_eq!(client_log.len(), server_log.len());

        for (expected, item) in CLIENT_LABELS.iter().zip(client_log) {
            assert_eq!(*expected, item.label);
        }

        assert_eq!(client_log[0], server_log[1]);
        assert_eq!(client_log[1], server_log[0]);
        assert_eq!(client_log[2..], server_log[2..]);
    }

    let client_key_log = Arc::new(KeyLogToVec::new("client"));
    let server_key_log = Arc::new(KeyLogToVec::new("server"));

    let kt = KeyType::RSA;
    let mut client_config = make_client_config(kt);
    client_config.versions = vec![ ProtocolVersion::TLSv1_3 ];
    client_config.key_log = client_key_log.clone();
    let client_config = Arc::new(client_config);

    let mut server_config = make_server_config(kt);
    server_config.key_log = server_key_log.clone();
    let server_config = Arc::new(server_config);

    // full handshake
    let (mut client, mut server) = make_pair_for_arc_configs(&client_config, &server_config);
    do_handshake(&mut client, &mut server);
    check_logs(&client_key_log.take(), &server_key_log.take());

    // resumed
    let (mut client, mut server) = make_pair_for_arc_configs(&client_config, &server_config);
    do_handshake(&mut client, &mut server);
    check_logs(&client_key_log.take(), &server_key_log.take());
}
1736 
/// After the handshake, two buffered server writes are flushed by one
/// `writev_tls` call as a single writev of two 42-byte records, and the
/// client reads both back.
#[test]
fn vectored_write_for_server_appdata() {
    let (mut client, mut server) = make_pair(KeyType::RSA);
    do_handshake(&mut client, &mut server);

    // `write_all` rather than `write`: `write` may accept fewer bytes than it is
    // given, and discarding its count is what clippy's `unused_io_amount` rejects.
    server.write_all(b"01234567890123456789").unwrap();
    server.write_all(b"01234567890123456789").unwrap();
    {
        let mut pipe = OtherSession::new(&mut client);
        let wrlen = server.writev_tls(&mut pipe).unwrap();
        assert_eq!(84, wrlen);
        assert_eq!(pipe.writevs, vec![vec![42, 42]]);
    }
    check_read(&mut client, b"0123456789012345678901234567890123456789");
}
1752 
/// After the handshake, two buffered client writes are flushed by one
/// `writev_tls` call as a single writev of two 42-byte records, and the
/// server reads both back.
#[test]
fn vectored_write_for_client_appdata() {
    let (mut client, mut server) = make_pair(KeyType::RSA);
    do_handshake(&mut client, &mut server);

    // `write_all` rather than `write`: `write` may accept fewer bytes than it is
    // given, and discarding its count is what clippy's `unused_io_amount` rejects.
    client.write_all(b"01234567890123456789").unwrap();
    client.write_all(b"01234567890123456789").unwrap();
    {
        let mut pipe = OtherSession::new(&mut server);
        let wrlen = client.writev_tls(&mut pipe).unwrap();
        assert_eq!(84, wrlen);
        assert_eq!(pipe.writevs, vec![vec![42, 42]]);
    }
    check_read(&mut server, b"0123456789012345678901234567890123456789");
}
1768 
/// Server data written before the handshake completes is held back. Each
/// server flight is emitted as one writev, and the queued data follows the
/// final handshake flight in that same writev.
#[test]
fn vectored_write_for_server_handshake() {
    let (mut client, mut server) = make_pair(KeyType::RSA);

    // `write_all` rather than `write`: `write` may accept fewer bytes than it is
    // given, and discarding its count is what clippy's `unused_io_amount` rejects.
    server.write_all(b"01234567890123456789").unwrap();
    server.write_all(b"0123456789").unwrap();

    transfer(&mut client, &mut server);
    server.process_new_packets().unwrap();
    {
        let mut pipe = OtherSession::new(&mut client);
        let wrlen = server.writev_tls(&mut pipe).unwrap();
        // don't assert exact sizes here, to avoid a brittle test
        assert!(wrlen > 4000); // its pretty big (contains cert chain)
        assert_eq!(pipe.writevs.len(), 1); // only one writev
        assert!(pipe.writevs[0].len() > 3); // at least a server hello/cert/serverkx
    }

    client.process_new_packets().unwrap();
    transfer(&mut client, &mut server);
    server.process_new_packets().unwrap();
    {
        let mut pipe = OtherSession::new(&mut client);
        let wrlen = server.writev_tls(&mut pipe).unwrap();
        assert_eq!(wrlen, 177);
        assert_eq!(pipe.writevs, vec![vec![103, 42, 32]]);
    }

    assert_eq!(server.is_handshaking(), false);
    assert_eq!(client.is_handshaking(), false);
    check_read(&mut client, b"012345678901234567890123456789");
}
1801 
/// Client data written before the handshake is held back. The first writev
/// carries only the ClientHello. The second carries the client's final flight
/// followed by the two queued application-data records.
#[test]
fn vectored_write_for_client_handshake() {
    let (mut client, mut server) = make_pair(KeyType::RSA);

    // `write_all` rather than `write`: `write` may accept fewer bytes than it is
    // given, and discarding its count is what clippy's `unused_io_amount` rejects.
    client.write_all(b"01234567890123456789").unwrap();
    client.write_all(b"0123456789").unwrap();
    {
        let mut pipe = OtherSession::new(&mut server);
        let wrlen = client.writev_tls(&mut pipe).unwrap();
        // don't assert exact sizes here, to avoid a brittle test
        assert!(wrlen > 200); // just the client hello
        assert_eq!(pipe.writevs.len(), 1); // only one writev
        assert!(pipe.writevs[0].len() == 1); // only a client hello
    }

    transfer(&mut server, &mut client);
    client.process_new_packets().unwrap();

    {
        let mut pipe = OtherSession::new(&mut server);
        let wrlen = client.writev_tls(&mut pipe).unwrap();
        assert_eq!(wrlen, 138);
        // CCS, finished, then two application datas
        assert_eq!(pipe.writevs, vec![vec![6, 58, 42, 32]]);
    }

    assert_eq!(server.is_handshaking(), false);
    assert_eq!(client.is_handshaking(), false);
    check_read(&mut server, b"012345678901234567890123456789");
}
1832 
/// When the peer only accepts short writes (`pipe.short_writes`), the server's
/// single 42-byte record takes several `writev_tls` calls to drain. The test
/// records the size of each partial write.
#[test]
fn vectored_write_with_slow_client() {
    let (mut client, mut server) = make_pair(KeyType::RSA);

    client.set_buffer_limit(32);

    do_handshake(&mut client, &mut server);
    // `write_all` rather than `write`: `write` may accept fewer bytes than it is
    // given, and discarding its count is what clippy's `unused_io_amount` rejects.
    server.write_all(b"01234567890123456789").unwrap();

    {
        let mut pipe = OtherSession::new(&mut client);
        pipe.short_writes = true;

        // Six calls: one more than the number of partial writes expected below.
        let mut wrlen = 0;
        for _ in 0..6 {
            wrlen += server.writev_tls(&mut pipe).unwrap();
        }
        assert_eq!(42, wrlen);
        assert_eq!(pipe.writevs, vec![vec![21], vec![10], vec![5], vec![3], vec![3]]);
    }
    check_read(&mut client, b"01234567890123456789");
}
1856 
1857 struct ServerStorage {
1858     storage: Arc<dyn rustls::StoresServerSessions>,
1859     put_count: AtomicUsize,
1860     get_count: AtomicUsize,
1861     take_count: AtomicUsize,
1862 }
1863 
1864 impl ServerStorage {
new() -> ServerStorage1865     fn new() -> ServerStorage {
1866         ServerStorage {
1867             storage: rustls::ServerSessionMemoryCache::new(1024),
1868             put_count: AtomicUsize::new(0),
1869             get_count: AtomicUsize::new(0),
1870             take_count: AtomicUsize::new(0),
1871         }
1872     }
1873 
puts(&self) -> usize1874     fn puts(&self) -> usize { self.put_count.load(Ordering::SeqCst) }
gets(&self) -> usize1875     fn gets(&self) -> usize { self.get_count.load(Ordering::SeqCst) }
takes(&self) -> usize1876     fn takes(&self) -> usize { self.take_count.load(Ordering::SeqCst) }
1877 }
1878 
1879 impl fmt::Debug for ServerStorage {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result1880     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1881         write!(f, "(put: {:?}, get: {:?}, take: {:?})",
1882                self.put_count, self.get_count, self.take_count)
1883     }
1884 }
1885 
1886 impl rustls::StoresServerSessions for ServerStorage {
put(&self, key: Vec<u8>, value: Vec<u8>) -> bool1887     fn put(&self, key: Vec<u8>, value: Vec<u8>) -> bool {
1888         self.put_count.fetch_add(1, Ordering::SeqCst);
1889         self.storage.put(key, value)
1890     }
1891 
get(&self, key: &[u8]) -> Option<Vec<u8>>1892     fn get(&self, key: &[u8]) -> Option<Vec<u8>> {
1893         self.get_count.fetch_add(1, Ordering::SeqCst);
1894         self.storage.get(key)
1895     }
1896 
take(&self, key: &[u8]) -> Option<Vec<u8>>1897     fn take(&self, key: &[u8]) -> Option<Vec<u8>> {
1898         self.take_count.fetch_add(1, Ordering::SeqCst);
1899         self.storage.take(key)
1900     }
1901 }
1902 
/// Without a ticketer, TLS1.3 resumption goes through the server's session
/// storage. Each handshake stores one session, and each resumption consumes
/// one with `take`; `get` is never used.
#[test]
fn tls13_stateful_resumption() {
    let kt = KeyType::RSA;

    let client_config = {
        let mut cfg = make_client_config(kt);
        cfg.versions = vec![ProtocolVersion::TLSv1_3];
        Arc::new(cfg)
    };

    let mut server_config = make_server_config(kt);
    let storage = Arc::new(ServerStorage::new());
    server_config.session_storage = storage.clone();
    let server_config = Arc::new(server_config);

    // One handshake between fresh sessions; yields do_handshake's byte counts.
    let handshake = || {
        let (mut client, mut server) = make_pair_for_arc_configs(&client_config, &server_config);
        do_handshake(&mut client, &mut server)
    };
    // (puts, gets, takes) recorded by the server's storage so far.
    let counts = || (storage.puts(), storage.gets(), storage.takes());

    // full handshake
    let (full_c2s, full_s2c) = handshake();
    assert_eq!(counts(), (1, 0, 0));

    // resumed
    let (resume_c2s, resume_s2c) = handshake();
    assert!(resume_c2s > full_c2s);
    assert!(resume_s2c < full_s2c);
    assert_eq!(counts(), (2, 0, 1));

    // resumed again
    let (resume2_c2s, resume2_s2c) = handshake();
    assert_eq!(resume_s2c, resume2_s2c);
    assert_eq!(resume_c2s, resume2_c2s);
    assert_eq!(counts(), (3, 0, 2));
}
1940 
/// With a ticketer installed, TLS1.3 resumption is ticket-based, and the
/// server's session storage is never touched.
#[test]
fn tls13_stateless_resumption() {
    let kt = KeyType::RSA;

    let client_config = {
        let mut cfg = make_client_config(kt);
        cfg.versions = vec![ProtocolVersion::TLSv1_3];
        Arc::new(cfg)
    };

    let mut server_config = make_server_config(kt);
    server_config.ticketer = rustls::Ticketer::new();
    let storage = Arc::new(ServerStorage::new());
    server_config.session_storage = storage.clone();
    let server_config = Arc::new(server_config);

    // One handshake between fresh sessions; yields do_handshake's byte counts.
    let handshake = || {
        let (mut client, mut server) = make_pair_for_arc_configs(&client_config, &server_config);
        do_handshake(&mut client, &mut server)
    };
    // (puts, gets, takes) recorded by the server's storage so far.
    let counts = || (storage.puts(), storage.gets(), storage.takes());

    // full handshake
    let (full_c2s, full_s2c) = handshake();
    assert_eq!(counts(), (0, 0, 0));

    // resumed
    let (resume_c2s, resume_s2c) = handshake();
    assert!(resume_c2s > full_c2s);
    assert!(resume_s2c < full_s2c);
    assert_eq!(counts(), (0, 0, 0));

    // resumed again
    let (resume2_c2s, resume2_s2c) = handshake();
    assert_eq!(resume_s2c, resume2_s2c);
    assert_eq!(resume_c2s, resume2_c2s);
    assert_eq!(counts(), (0, 0, 0));
}
1979 
#[cfg(feature = "quic")]
mod test_quic {
    use super::*;

    /// Moves one batch of QUIC handshake data from `send` to `recv`.
    ///
    /// Calls `send.write_hs` repeatedly, appending into one buffer. It stops as
    /// soon as `write_hs` returns secrets, or once a call adds no bytes. The
    /// whole buffer is then passed to `recv.read_hs`.
    ///
    /// Returns the secrets `send` produced (if any), or the error from `recv`.
    ///
    /// # Panics
    /// Panics if `recv` accepted the data but has an alert set.
    fn step(send: &mut dyn Session, recv: &mut dyn Session) -> Result<Option<quic::Secrets>, TLSError> {
        let mut buf = Vec::new();
        let secrets = loop {
            let prev = buf.len();
            if let Some(x) = send.write_hs(&mut buf) {
                break Some(x);
            }
            if prev == buf.len() {
                break None;
            }
        };
        recv.read_hs(&buf)?;
        assert_eq!(recv.get_alert(), None);
        Ok(secrets)
    }

    /// Exercises a full QUIC handshake, a key update, 0-RTT acceptance and
    /// rejection, and a handshake that fails certificate verification.
    #[test]
    fn test_quic_handshake() {
        fn equal_prk(x: &ring::hkdf::Prk, y: &ring::hkdf::Prk) -> bool {
            let mut x_data = [0; 16];
            let mut y_data = [0; 16];
            let x_okm = x.expand(&[b"info"], &ring::aead::quic::AES_128).unwrap();
            x_okm.fill(&mut x_data[..]).unwrap();
            let y_okm = y.expand(&[b"info"], &ring::aead::quic::AES_128).unwrap();
            y_okm.fill(&mut y_data[..]).unwrap();
            x_data == y_data
        }

        fn equal_secrets(x: &quic::Secrets, y: &quic::Secrets) -> bool {
            equal_prk(&x.client, &y.client) && equal_prk(&x.server, &y.server)
        }

        let kt = KeyType::RSA;
        let mut client_config = make_client_config(kt);
        client_config.versions = vec![ProtocolVersion::TLSv1_3];
        client_config.enable_early_data = true;
        let client_config = Arc::new(client_config);
        let mut server_config = make_server_config(kt);
        server_config.versions = vec![ProtocolVersion::TLSv1_3];
        server_config.max_early_data_size = 0xffffffff;
        server_config.alpn_protocols = vec!["foo".into()];
        let server_config = Arc::new(server_config);
        let client_params = &b"client params"[..];
        let server_params = &b"server params"[..];

        // full handshake
        let mut client =
            ClientSession::new_quic(&client_config, dns_name("localhost"), client_params.into());
        let mut server = ServerSession::new_quic(&server_config, server_params.into());
        let client_initial = step(&mut client, &mut server).unwrap();
        assert!(client_initial.is_none());
        assert!(client.get_early_secret().is_none());
        assert_eq!(server.get_quic_transport_parameters(), Some(client_params));
        let server_hs = step(&mut server, &mut client).unwrap().unwrap();
        assert!(server.get_early_secret().is_none());
        let client_hs = step(&mut client, &mut server).unwrap().unwrap();
        assert!(equal_secrets(&server_hs, &client_hs));
        assert!(client.is_handshaking());
        let server_1rtt = step(&mut server, &mut client).unwrap().unwrap();
        assert!(!client.is_handshaking());
        assert_eq!(client.get_quic_transport_parameters(), Some(server_params));
        assert!(server.is_handshaking());
        let client_1rtt = step(&mut client, &mut server).unwrap().unwrap();
        assert!(!server.is_handshaking());
        assert!(equal_secrets(&server_1rtt, &client_1rtt));
        assert!(!equal_secrets(&server_hs, &server_1rtt));
        assert!(step(&mut client, &mut server).unwrap().is_none());
        assert!(step(&mut server, &mut client).unwrap().is_none());

        // key update
        let initial = quic::Secrets {
            // Constant dummy values for reproducibility
            client: hkdf::Prk::new_less_safe(hkdf::HKDF_SHA256, &[
                0xb8, 0x76, 0x77, 0x08, 0xf8, 0x77, 0x23, 0x58, 0xa6, 0xea, 0x9f, 0xc4, 0x3e, 0x4a,
                0xdd, 0x2c, 0x96, 0x1b, 0x3f, 0x52, 0x87, 0xa6, 0xd1, 0x46, 0x7e, 0xe0, 0xae, 0xab,
                0x33, 0x72, 0x4d, 0xbf,
            ]),
            server: hkdf::Prk::new_less_safe(hkdf::HKDF_SHA256, &[
                0x42, 0xdc, 0x97, 0x21, 0x40, 0xe0, 0xf2, 0xe3, 0x98, 0x45, 0xb7, 0x67, 0x61, 0x34,
                0x39, 0xdc, 0x67, 0x58, 0xca, 0x43, 0x25, 0x9b, 0x87, 0x85, 0x06, 0x82, 0x4e, 0xb1,
                0xe4, 0x38, 0xd8, 0x55,
            ]),
        };
        let updated = client.update_secrets(&initial.client, &initial.server);
        // The expected values will need to be updated if the negotiated hash function changes. Pull the
        // values from ring's `hmac::Key::construct` with a debugger.
        assert!(equal_prk(
            &updated.client,
            &hkdf::Prk::new_less_safe(hkdf::HKDF_SHA256, &[
                0x42, 0xca, 0xc8, 0xc9, 0x1c, 0xd5, 0xeb, 0x40, 0x68, 0x2e, 0x43,
                0x2e, 0xdf, 0x2d, 0x2b, 0xe9, 0xf4, 0x1a, 0x52, 0xca, 0x6b, 0x22, 0xd8, 0xe6, 0xcd, 0xb1,
                0xe8, 0xac, 0xa9, 0x6, 0x1f, 0xce
            ]))
        );
        assert!(equal_prk(
            &updated.server,
            &hkdf::Prk::new_less_safe(hkdf::HKDF_SHA256, &[
                0xeb, 0x7f, 0x5e, 0x2a, 0x12, 0x3f, 0x40, 0x7d, 0xb4, 0x99, 0xe3,
                0x61, 0xca, 0xe5, 0x90, 0xd4, 0xd9, 0x92, 0xe1, 0x4b, 0x7a, 0xce, 0x3, 0xc2, 0x44, 0xe0,
                0x42, 0x21, 0x15, 0xb6, 0xd3, 0x8a
            ]))
        );

        // 0-RTT handshake
        let mut client =
            ClientSession::new_quic(&client_config, dns_name("localhost"), client_params.into());
        assert!(client.get_negotiated_ciphersuite().is_some());
        let mut server = ServerSession::new_quic(&server_config, server_params.into());
        step(&mut client, &mut server).unwrap();
        assert_eq!(client.get_quic_transport_parameters(), Some(server_params));
        {
            let client_early = client.get_early_secret().unwrap();
            let server_early = server.get_early_secret().unwrap();
            assert!(equal_prk(client_early, server_early));
        }
        step(&mut server, &mut client).unwrap().unwrap();
        step(&mut client, &mut server).unwrap().unwrap();
        step(&mut server, &mut client).unwrap().unwrap();
        assert!(client.is_early_data_accepted());

        // 0-RTT rejection
        {
            let mut client_config = (*client_config).clone();
            client_config.alpn_protocols = vec!["foo".into()];
            let mut client =
                ClientSession::new_quic(&Arc::new(client_config), dns_name("localhost"), client_params.into());
            let mut server = ServerSession::new_quic(&server_config, server_params.into());
            step(&mut client, &mut server).unwrap();
            assert_eq!(client.get_quic_transport_parameters(), Some(server_params));
            assert!(client.get_early_secret().is_some());
            assert!(server.get_early_secret().is_none());
            step(&mut server, &mut client).unwrap().unwrap();
            step(&mut client, &mut server).unwrap().unwrap();
            step(&mut server, &mut client).unwrap().unwrap();
            assert!(!client.is_early_data_accepted());
        }

        // failed handshake
        let mut client = ClientSession::new_quic(
            &client_config,
            dns_name("example.com"),
            client_params.into(),
        );
        let mut server = ServerSession::new_quic(&server_config, server_params.into());
        step(&mut client, &mut server).unwrap();
        step(&mut server, &mut client).unwrap().unwrap();
        step(&mut server, &mut client).unwrap_err();
        assert_eq!(
            client.get_alert(),
            Some(rustls::internal::msgs::enums::AlertDescription::BadCertificate)
        );
    }

    /// A server configured with ALPN protocols rejects a QUIC client that offers
    /// only non-matching protocols: it returns `NoApplicationProtocol` and sets
    /// the matching alert.
    #[test]
    fn test_quic_rejects_missing_alpn() {
        let client_params = &b"client params"[..];
        let server_params = &b"server params"[..];

        for &kt in ALL_KEY_TYPES.iter() {
            let mut client_config = make_client_config(kt);
            client_config.versions = vec![ProtocolVersion::TLSv1_3];
            client_config.alpn_protocols = vec!["bar".into()];
            let client_config = Arc::new(client_config);

            let mut server_config = make_server_config(kt);
            server_config.versions = vec![ProtocolVersion::TLSv1_3];
            server_config.alpn_protocols = vec!["foo".into()];
            let server_config = Arc::new(server_config);

            let mut client = ClientSession::new_quic(&client_config,
                                                     dns_name("localhost"),
                                                     client_params.into());
            let mut server = ServerSession::new_quic(&server_config,
                                                     server_params.into());

            assert_eq!(step(&mut client, &mut server).unwrap_err(),
                       TLSError::NoApplicationProtocol);

            assert_eq!(server.get_alert(),
                       Some(rustls::internal::msgs::enums::AlertDescription::NoApplicationProtocol));
        }
    }
} // mod test_quic
2171 
/// For every key type and protocol version, the ClientHello's
/// signature_algorithms extension does not include RSA_PKCS1_SHA1.
#[test]
fn test_client_does_not_offer_sha1() {
    use rustls::internal::msgs::{message::Message, message::MessagePayload,
        handshake::HandshakePayload, enums::HandshakeType, codec::Codec};

    for &kt in ALL_KEY_TYPES.iter() {
        for client_config in AllClientVersions::new(make_client_config(kt)) {
            let (mut client, _) = make_pair_for_configs(client_config, make_server_config(kt));
            assert!(client.wants_write());

            // Capture the client's first flight and parse it as a single message.
            let mut buf = [0u8; 262144];
            let written = client.write_tls(&mut buf.as_mut()).unwrap();
            let mut msg = Message::read_bytes(&buf[..written]).unwrap();
            assert!(msg.decode_payload());
            assert!(msg.is_handshake_type(HandshakeType::ClientHello));

            let handshake = match msg.payload {
                MessagePayload::Handshake(hs) => hs,
                _ => unreachable!(),
            };
            let client_hello = match handshake.payload {
                HandshakePayload::ClientHello(ch) => ch,
                _ => unreachable!(),
            };

            let offered = client_hello.get_sigalgs_extension().unwrap();
            let offers_sha1 = offered.contains(&SignatureScheme::RSA_PKCS1_SHA1);
            assert_eq!(offers_sha1, false, "sha1 unexpectedly offered");
        }
    }
}
2206 
/// With `set_mtu(&Some(64))`, the client's first flight is split into several
/// `write_tls` chunks, none longer than 64 bytes.
#[test]
fn test_client_mtu_reduction() {
    /// Calls `write_tls` into a 128-byte buffer until a call writes fewer than
    /// 64 bytes. Returns the length of every call, including that last one.
    fn collect_write_lengths(client: &mut ClientSession) -> Vec<usize> {
        let mut lengths = Vec::new();
        let mut buf = [0u8; 128];

        loop {
            let written = client.write_tls(&mut buf.as_mut()).unwrap();
            lengths.push(written);
            assert!(written <= 64);
            if written < 64 {
                return lengths;
            }
        }
    }

    for &kt in ALL_KEY_TYPES.iter() {
        let mut client_config = make_client_config(kt);
        client_config.set_mtu(&Some(64));
        let client_config = Arc::new(client_config);

        let mut client = ClientSession::new(&client_config, dns_name("localhost"));
        let lengths = collect_write_lengths(&mut client);
        assert!(lengths.iter().all(|&len| len <= 64));
        assert!(lengths.len() > 1);
    }
}
2236 
/// Smoke test: a client using `KeyLogFile` (pointed at ./sslkeylogfile.txt via
/// SSLKEYLOGFILE) completes a handshake and sends data with every protocol
/// version.
#[test]
fn exercise_key_log_file_for_client() {
    let server_config = Arc::new(make_server_config(KeyType::RSA));

    let mut base_client_config = make_client_config(KeyType::RSA);
    // Set the variable before constructing `KeyLogFile`.
    env::set_var("SSLKEYLOGFILE", "./sslkeylogfile.txt");
    base_client_config.key_log = Arc::new(rustls::KeyLogFile::new());

    for version_config in AllClientVersions::new(base_client_config) {
        let version_config = Arc::new(version_config);
        let (mut client, mut server) = make_pair_for_arc_configs(&version_config, &server_config);

        let written = client.write(b"hello").unwrap();
        assert_eq!(5, written);

        do_handshake(&mut client, &mut server);
        transfer(&mut client, &mut server);
        server.process_new_packets().unwrap();
    }
}
2255 
/// Smoke test: a server using `KeyLogFile` (pointed at ./sslkeylogfile.txt via
/// SSLKEYLOGFILE) completes a handshake and receives data with clients of
/// every protocol version.
#[test]
fn exercise_key_log_file_for_server() {
    let server_config = {
        let mut cfg = make_server_config(KeyType::RSA);
        // Set the variable before constructing `KeyLogFile`.
        env::set_var("SSLKEYLOGFILE", "./sslkeylogfile.txt");
        cfg.key_log = Arc::new(rustls::KeyLogFile::new());
        Arc::new(cfg)
    };

    for version_config in AllClientVersions::new(make_client_config(KeyType::RSA)) {
        let version_config = Arc::new(version_config);
        let (mut client, mut server) = make_pair_for_arc_configs(&version_config, &server_config);

        let written = client.write(b"hello").unwrap();
        assert_eq!(5, written);

        do_handshake(&mut client, &mut server);
        transfer(&mut client, &mut server);
        server.process_new_packets().unwrap();
    }
}
2276 
assert_lt(left: usize, right: usize)2277 fn assert_lt(left: usize, right: usize) {
2278     if left >= right {
2279         panic!("expected {} < {}", left, right);
2280     }
2281 }
2282 
#[test]
fn session_types_are_not_huge() {
    // Arbitrary upper bound (in bytes, per mem::size_of) shared by both
    // session types; the test only checks they stay strictly below it.
    const SIZE_LIMIT: usize = 1600;

    assert_lt(mem::size_of::<ServerSession>(), SIZE_LIMIT);
    assert_lt(mem::size_of::<ClientSession>(), SIZE_LIMIT);
}
2289 
2290 use rustls::internal::msgs::{message::Message, message::MessagePayload,
2291     handshake::HandshakePayload, handshake::ClientExtension};
2292 
/// A ClientHello whose SNI extension lists the same name type twice must be
/// rejected by the server as peer misbehaviour.
#[test]
fn test_server_rejects_duplicate_sni_names() {
    /// Appends a clone of the first entry of every ServerName extension in a
    /// ClientHello, so the SNI list repeats its first name. Other messages are
    /// left untouched.
    fn duplicate_sni_payload(msg: &mut Message) {
        if let MessagePayload::Handshake(hs) = &mut msg.payload {
            if let HandshakePayload::ClientHello(ch) = &mut hs.payload {
                // `iter_mut()` already yields `&mut ClientExtension`; match on
                // it directly instead of borrowing the binding a second time.
                for ext in ch.extensions.iter_mut() {
                    if let ClientExtension::ServerName(snr) = ext {
                        snr.push(snr[0].clone());
                    }
                }
            }
        }
    }

    let (mut client, mut server) = make_pair(KeyType::RSA);
    // Presumably applies the mutator to each client message before the
    // server sees it — see transfer_altered in the common test module.
    transfer_altered(&mut client, duplicate_sni_payload, &mut server);
    assert_eq!(server.process_new_packets(),
               Err(TLSError::PeerMisbehavedError("ClientHello SNI contains duplicate name types".into())));
}
2312 
/// A ClientHello carrying an SNI extension with no entries must be rejected
/// by the server as peer misbehaviour.
#[test]
fn test_server_rejects_empty_sni_extension() {
    /// Empties every ServerName extension in a ClientHello while keeping the
    /// extension itself present. Other messages are left untouched.
    fn empty_sni_payload(msg: &mut Message) {
        if let MessagePayload::Handshake(hs) = &mut msg.payload {
            if let HandshakePayload::ClientHello(ch) = &mut hs.payload {
                // `iter_mut()` already yields `&mut ClientExtension`; match on
                // it directly instead of borrowing the binding a second time.
                for ext in ch.extensions.iter_mut() {
                    if let ClientExtension::ServerName(snr) = ext {
                        snr.clear();
                    }
                }
            }
        }
    }

    let (mut client, mut server) = make_pair(KeyType::RSA);
    // Presumably applies the mutator to each client message before the
    // server sees it — see transfer_altered in the common test module.
    transfer_altered(&mut client, empty_sni_payload, &mut server);
    assert_eq!(server.process_new_packets(),
               Err(TLSError::PeerMisbehavedError("ClientHello SNI did not contain a hostname".into())));
}
2332 
/// A ClientHello that advertises no named groups and no key shares leaves
/// the server with no key-exchange group in common, which must be reported
/// as an incompatibility.
#[test]
fn test_server_rejects_clients_without_any_kx_group_overlap() {
    /// Empties the NamedGroups and KeyShare extensions of a ClientHello.
    /// The extensions stay present, but their contents are removed. Other
    /// messages are left untouched.
    fn clear_kx_groups(msg: &mut Message) {
        if let MessagePayload::Handshake(hs) = &mut msg.payload {
            if let HandshakePayload::ClientHello(ch) = &mut hs.payload {
                for ext in ch.extensions.iter_mut() {
                    match ext {
                        ClientExtension::NamedGroups(ngs) => ngs.clear(),
                        ClientExtension::KeyShare(ks) => ks.clear(),
                        _ => {}
                    }
                }
            }
        }
    }

    let (mut client, mut server) = make_pair(KeyType::RSA);
    // Presumably applies the mutator to each client message before the
    // server sees it — see transfer_altered in the common test module.
    transfer_altered(&mut client, clear_kx_groups, &mut server);
    assert_eq!(server.process_new_packets(),
               Err(TLSError::PeerIncompatibleError("no kx group overlap with client".into())));
}
2355