1 #![allow(dead_code)]
2 
3 use neqo_common::qinfo;
4 use neqo_crypto::{
5     AntiReplay, AuthenticationStatus, Client, HandshakeState, RecordList, Res, ResumptionToken,
6     SecretAgent, Server, ZeroRttCheckResult, ZeroRttChecker,
7 };
8 use std::mem;
9 use std::time::Instant;
10 use test_fixture::{anti_replay, fixture_init, now};
11 
12 /// Consume records until the handshake state changes.
forward_records( now: Instant, agent: &mut SecretAgent, records_in: RecordList, ) -> Res<RecordList>13 pub fn forward_records(
14     now: Instant,
15     agent: &mut SecretAgent,
16     records_in: RecordList,
17 ) -> Res<RecordList> {
18     let mut expected_state = match agent.state() {
19         HandshakeState::New => HandshakeState::New,
20         _ => HandshakeState::InProgress,
21     };
22     let mut records_out = RecordList::default();
23     for record in records_in {
24         assert_eq!(records_out.len(), 0);
25         assert_eq!(*agent.state(), expected_state);
26 
27         records_out = agent.handshake_raw(now, Some(record))?;
28         expected_state = HandshakeState::InProgress;
29     }
30     Ok(records_out)
31 }
32 
handshake(now: Instant, client: &mut SecretAgent, server: &mut SecretAgent)33 fn handshake(now: Instant, client: &mut SecretAgent, server: &mut SecretAgent) {
34     let mut a = client;
35     let mut b = server;
36     let mut records = a.handshake_raw(now, None).unwrap();
37     let is_done = |agent: &mut SecretAgent| agent.state().is_final();
38     while !is_done(b) {
39         records = if let Ok(r) = forward_records(now, &mut b, records) {
40             r
41         } else {
42             // TODO(mt) take the alert generated by the failed handshake
43             // and allow it to be sent to the peer.
44             return;
45         };
46 
47         if *b.state() == HandshakeState::AuthenticationPending {
48             b.authenticated(AuthenticationStatus::Ok);
49             records = if let Ok(r) = b.handshake_raw(now, None) {
50                 r
51             } else {
52                 // TODO(mt) - as above.
53                 return;
54             }
55         }
56         mem::swap(&mut a, &mut b);
57     }
58 }
59 
/// Run a handshake between `client` and `server` at time `now`, log both final
/// states, and assert that both agents are connected.
///
/// # Panics
///
/// Panics if either agent is not connected once the handshake loop stops.
/// The loop stops silently on errors, so a failed handshake surfaces here.
pub fn connect_at(now: Instant, client: &mut SecretAgent, server: &mut SecretAgent) {
    handshake(now, client, server);
    qinfo!("client: {:?}", client.state());
    qinfo!("server: {:?}", server.state());
    assert!(client.state().is_connected());
    assert!(server.state().is_connected());
}
67 
connect(client: &mut SecretAgent, server: &mut SecretAgent)68 pub fn connect(client: &mut SecretAgent, server: &mut SecretAgent) {
69     connect_at(now(), client, server);
70 }
71 
/// Run a handshake between `client` and `server` that is expected to fail,
/// asserting that neither agent ends up connected.
///
/// # Panics
///
/// Panics if either agent reports being connected after the handshake loop
/// stops.
pub fn connect_fail(client: &mut SecretAgent, server: &mut SecretAgent) {
    // `handshake` returns quietly on error, so check the outcome explicitly.
    handshake(now(), client, server);
    assert!(!client.state().is_connected());
    assert!(!server.state().is_connected());
}
77 
/// Selects whether `resumption_setup` enables 0-RTT on both agents before the
/// initial (ticket-issuing) handshake.
///
/// Derives `PartialEq`/`Eq` so callers can compare modes directly
/// (`mode == Resumption::WithZeroRtt`) instead of pattern matching.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Resumption {
    /// Leave 0-RTT disabled on both agents.
    WithoutZeroRtt,
    /// Enable 0-RTT on both the client and the server.
    WithZeroRtt,
}
83 
/// Application token that `resumption_setup` asks the server to put in its
/// ticket. `PermissiveZeroRttChecker` expects to see it when resuming.
pub const ZERO_RTT_TOKEN_DATA: &[u8] = "zero-rtt-token".as_bytes();
85 
/// Test 0-RTT checker that accepts every attempt, after asserting that the
/// application token it receives matches what the test expects.
#[derive(Debug)]
pub struct PermissiveZeroRttChecker {
    /// `true`: the token must equal `ZERO_RTT_TOKEN_DATA`.
    /// `false`: the token must be empty.
    resuming: bool,
}

impl Default for PermissiveZeroRttChecker {
    /// A checker that expects the resumption token (`resuming == true`).
    fn default() -> Self {
        let resuming = true;
        Self { resuming }
    }
}
96 
97 impl ZeroRttChecker for PermissiveZeroRttChecker {
check(&self, token: &[u8]) -> ZeroRttCheckResult98     fn check(&self, token: &[u8]) -> ZeroRttCheckResult {
99         if self.resuming {
100             assert_eq!(ZERO_RTT_TOKEN_DATA, token);
101         } else {
102             assert!(token.is_empty());
103         }
104         ZeroRttCheckResult::Accept
105     }
106 }
107 
zero_rtt_setup( mode: Resumption, client: &mut Client, server: &mut Server, ) -> Option<AntiReplay>108 fn zero_rtt_setup(
109     mode: Resumption,
110     client: &mut Client,
111     server: &mut Server,
112 ) -> Option<AntiReplay> {
113     if let Resumption::WithZeroRtt = mode {
114         client.enable_0rtt().expect("should enable 0-RTT on client");
115 
116         let anti_replay = anti_replay();
117         server
118             .enable_0rtt(
119                 &anti_replay,
120                 0xffff_ffff,
121                 Box::new(PermissiveZeroRttChecker { resuming: false }),
122             )
123             .expect("should enable 0-RTT on server");
124         Some(anti_replay)
125     } else {
126         None
127     }
128 }
129 
resumption_setup(mode: Resumption) -> (Option<AntiReplay>, ResumptionToken)130 pub fn resumption_setup(mode: Resumption) -> (Option<AntiReplay>, ResumptionToken) {
131     fixture_init();
132 
133     let mut client = Client::new("server.example").expect("should create client");
134     let mut server = Server::new(&["key"]).expect("should create server");
135     let anti_replay = zero_rtt_setup(mode, &mut client, &mut server);
136 
137     connect(&mut client, &mut server);
138 
139     assert!(!client.info().unwrap().resumed());
140     assert!(!server.info().unwrap().resumed());
141     assert!(!client.info().unwrap().early_data_accepted());
142     assert!(!server.info().unwrap().early_data_accepted());
143 
144     let server_records = server
145         .send_ticket(now(), ZERO_RTT_TOKEN_DATA)
146         .expect("ticket sent");
147     assert_eq!(server_records.len(), 1);
148     let client_records = client
149         .handshake_raw(now(), server_records.into_iter().next())
150         .expect("records ingested");
151     assert_eq!(client_records.len(), 0);
152 
153     // `client` is about to go out of scope,
154     // but we only need to keep the resumption token, so clone it.
155     let token = client.resumption_token().expect("token is present");
156     (anti_replay, token)
157 }
158