1 #[cfg(feature = "logging")]
2 use crate::log::warn;
3 use crate::msgs::codec::Codec;
4 use crate::msgs::handshake::HandshakeMessagePayload;
5 use crate::msgs::message::{Message, MessagePayload};
6 use ring::digest;
7 use std::mem;
8 
9 /// This deals with keeping a running hash of the handshake
10 /// payloads.  This is computed by buffering initially.  Once
11 /// we know what hash function we need to use we switch to
12 /// incremental hashing.
13 ///
14 /// For client auth, we also need to buffer all the messages.
15 /// This is disabled in cases where client auth is not possible.
16 pub struct HandshakeHash {
17     /// None before we know what hash function we're using
18     alg: Option<&'static digest::Algorithm>,
19 
20     /// None before we know what hash function we're using
21     ctx: Option<digest::Context>,
22 
23     /// true if we need to keep all messages
24     client_auth_enabled: bool,
25 
26     /// buffer for pre-hashing stage and client-auth.
27     buffer: Vec<u8>,
28 }
29 
30 impl HandshakeHash {
new() -> HandshakeHash31     pub fn new() -> HandshakeHash {
32         HandshakeHash {
33             alg: None,
34             ctx: None,
35             client_auth_enabled: false,
36             buffer: Vec::new(),
37         }
38     }
39 
40     /// We might be doing client auth, so need to keep a full
41     /// log of the handshake.
set_client_auth_enabled(&mut self)42     pub fn set_client_auth_enabled(&mut self) {
43         debug_assert!(self.ctx.is_none()); // or we might have already discarded messages
44         self.client_auth_enabled = true;
45     }
46 
47     /// We decided not to do client auth after all, so discard
48     /// the transcript.
abandon_client_auth(&mut self)49     pub fn abandon_client_auth(&mut self) {
50         self.client_auth_enabled = false;
51         self.buffer.drain(..);
52     }
53 
54     /// We now know what hash function the verify_data will use.
start_hash(&mut self, alg: &'static digest::Algorithm) -> bool55     pub fn start_hash(&mut self, alg: &'static digest::Algorithm) -> bool {
56         match self.alg {
57             None => {}
58             Some(started) => {
59                 if started != alg {
60                     // hash type is changing
61                     warn!("altered hash to HandshakeHash::start_hash");
62                     return false;
63                 }
64 
65                 return true;
66             }
67         }
68         self.alg = Some(alg);
69         debug_assert!(self.ctx.is_none());
70 
71         let mut ctx = digest::Context::new(alg);
72         ctx.update(&self.buffer);
73         self.ctx = Some(ctx);
74 
75         // Discard buffer if we don't need it now.
76         if !self.client_auth_enabled {
77             self.buffer.drain(..);
78         }
79         true
80     }
81 
82     /// Hash/buffer a handshake message.
add_message(&mut self, m: &Message) -> &mut HandshakeHash83     pub fn add_message(&mut self, m: &Message) -> &mut HandshakeHash {
84         match m.payload {
85             MessagePayload::Handshake(ref hs) => {
86                 let buf = hs.get_encoding();
87                 self.update_raw(&buf);
88             }
89             _ => {}
90         };
91         self
92     }
93 
94     /// Hash or buffer a byte slice.
update_raw(&mut self, buf: &[u8]) -> &mut Self95     fn update_raw(&mut self, buf: &[u8]) -> &mut Self {
96         if self.ctx.is_some() {
97             self.ctx.as_mut().unwrap().update(buf);
98         }
99 
100         if self.ctx.is_none() || self.client_auth_enabled {
101             self.buffer.extend_from_slice(buf);
102         }
103 
104         self
105     }
106 
107     /// Get the hash value if we were to hash `extra` too,
108     /// using hash function `hash`.
get_hash_given(&self, hash: &'static digest::Algorithm, extra: &[u8]) -> Vec<u8>109     pub fn get_hash_given(&self, hash: &'static digest::Algorithm, extra: &[u8]) -> Vec<u8> {
110         let mut ctx = if self.ctx.is_none() {
111             let mut ctx = digest::Context::new(hash);
112             ctx.update(&self.buffer);
113             ctx
114         } else {
115             self.ctx.as_ref().unwrap().clone()
116         };
117 
118         ctx.update(extra);
119         let hash = ctx.finish();
120         let mut ret = Vec::new();
121         ret.extend_from_slice(hash.as_ref());
122         ret
123     }
124 
125     /// Take the current hash value, and encapsulate it in a
126     /// 'handshake_hash' handshake message.  Start this hash
127     /// again, with that message at the front.
rollup_for_hrr(&mut self)128     pub fn rollup_for_hrr(&mut self) {
129         let old_hash = self.ctx.take().unwrap().finish();
130         let old_handshake_hash_msg =
131             HandshakeMessagePayload::build_handshake_hash(old_hash.as_ref());
132 
133         self.ctx = Some(digest::Context::new(self.alg.unwrap()));
134         self.update_raw(&old_handshake_hash_msg.get_encoding());
135     }
136 
137     /// Get the current hash value.
get_current_hash(&self) -> Vec<u8>138     pub fn get_current_hash(&self) -> Vec<u8> {
139         let hash = self
140             .ctx
141             .as_ref()
142             .unwrap()
143             .clone()
144             .finish();
145         let mut ret = Vec::new();
146         ret.extend_from_slice(hash.as_ref());
147         ret
148     }
149 
150     /// Takes this object's buffer containing all handshake messages
151     /// so far.  This method only works once; it resets the buffer
152     /// to empty.
take_handshake_buf(&mut self) -> Vec<u8>153     pub fn take_handshake_buf(&mut self) -> Vec<u8> {
154         debug_assert!(self.client_auth_enabled);
155         mem::replace(&mut self.buffer, Vec::new())
156     }
157 }
158 
#[cfg(test)]
mod test {
    use super::HandshakeHash;
    use ring::digest;

    /// Known-answer prefix of SHA-256("helloworld").
    const HELLOWORLD_PREFIX: [u8; 4] = [0x93, 0x6a, 0x18, 0x5c];

    /// One-shot SHA-256 of the bytes the tests feed in, so the incremental
    /// transcript can be compared against it in full.
    fn sha256_helloworld() -> Vec<u8> {
        let d = digest::digest(&digest::SHA256, b"helloworld");
        let bytes: &[u8] = d.as_ref();
        bytes.to_vec()
    }

    #[test]
    fn hashes_correctly() {
        let mut hh = HandshakeHash::new();
        hh.update_raw(b"hello");
        assert_eq!(hh.buffer.len(), 5);
        assert!(hh.start_hash(&digest::SHA256));
        assert_eq!(hh.buffer.len(), 0);
        hh.update_raw(b"world");
        let h = hh.get_current_hash();
        assert_eq!(&h[..4], &HELLOWORLD_PREFIX[..]);
        assert_eq!(h, sha256_helloworld());
    }

    #[test]
    fn buffers_correctly() {
        let mut hh = HandshakeHash::new();
        hh.set_client_auth_enabled();
        hh.update_raw(b"hello");
        assert_eq!(hh.buffer.len(), 5);
        assert!(hh.start_hash(&digest::SHA256));
        assert_eq!(hh.buffer.len(), 5);
        hh.update_raw(b"world");
        assert_eq!(hh.buffer.len(), 10);
        let h = hh.get_current_hash();
        assert_eq!(&h[..4], &HELLOWORLD_PREFIX[..]);
        assert_eq!(h, sha256_helloworld());
        let buf = hh.take_handshake_buf();
        assert_eq!(b"helloworld".to_vec(), buf);
        // Taking the buffer leaves it empty.
        assert!(hh.buffer.is_empty());
    }

    #[test]
    fn abandon() {
        let mut hh = HandshakeHash::new();
        hh.set_client_auth_enabled();
        hh.update_raw(b"hello");
        assert_eq!(hh.buffer.len(), 5);
        assert!(hh.start_hash(&digest::SHA256));
        assert_eq!(hh.buffer.len(), 5);
        hh.abandon_client_auth();
        assert_eq!(hh.buffer.len(), 0);
        hh.update_raw(b"world");
        assert_eq!(hh.buffer.len(), 0);
        let h = hh.get_current_hash();
        assert_eq!(&h[..4], &HELLOWORLD_PREFIX[..]);
        assert_eq!(h, sha256_helloworld());
    }

    #[test]
    fn rejects_changed_hash() {
        let mut hh = HandshakeHash::new();
        assert!(hh.start_hash(&digest::SHA256));
        // Restarting with the same algorithm is accepted...
        assert!(hh.start_hash(&digest::SHA256));
        // ...but switching to a different one is refused.
        assert!(!hh.start_hash(&digest::SHA384));
    }

    #[test]
    fn hash_given_matches_transcript() {
        let mut hh = HandshakeHash::new();
        hh.update_raw(b"hello");
        // Before start_hash, the buffer is hashed with the requested algorithm.
        assert_eq!(
            hh.get_hash_given(&digest::SHA256, b"world"),
            sha256_helloworld()
        );
        assert!(hh.start_hash(&digest::SHA256));
        // After start_hash, the running context is cloned and extended.
        assert_eq!(
            hh.get_hash_given(&digest::SHA256, b"world"),
            sha256_helloworld()
        );
        // get_hash_given must not disturb the running transcript.
        hh.update_raw(b"world");
        assert_eq!(hh.get_current_hash(), sha256_helloworld());
    }
}
217