1 #[cfg(feature = "logging")] 2 use crate::log::warn; 3 use crate::msgs::codec::Codec; 4 use crate::msgs::handshake::HandshakeMessagePayload; 5 use crate::msgs::message::{Message, MessagePayload}; 6 use ring::digest; 7 use std::mem; 8 9 /// This deals with keeping a running hash of the handshake 10 /// payloads. This is computed by buffering initially. Once 11 /// we know what hash function we need to use we switch to 12 /// incremental hashing. 13 /// 14 /// For client auth, we also need to buffer all the messages. 15 /// This is disabled in cases where client auth is not possible. 16 pub struct HandshakeHash { 17 /// None before we know what hash function we're using 18 alg: Option<&'static digest::Algorithm>, 19 20 /// None before we know what hash function we're using 21 ctx: Option<digest::Context>, 22 23 /// true if we need to keep all messages 24 client_auth_enabled: bool, 25 26 /// buffer for pre-hashing stage and client-auth. 27 buffer: Vec<u8>, 28 } 29 30 impl HandshakeHash { new() -> HandshakeHash31 pub fn new() -> HandshakeHash { 32 HandshakeHash { 33 alg: None, 34 ctx: None, 35 client_auth_enabled: false, 36 buffer: Vec::new(), 37 } 38 } 39 40 /// We might be doing client auth, so need to keep a full 41 /// log of the handshake. set_client_auth_enabled(&mut self)42 pub fn set_client_auth_enabled(&mut self) { 43 debug_assert!(self.ctx.is_none()); // or we might have already discarded messages 44 self.client_auth_enabled = true; 45 } 46 47 /// We decided not to do client auth after all, so discard 48 /// the transcript. abandon_client_auth(&mut self)49 pub fn abandon_client_auth(&mut self) { 50 self.client_auth_enabled = false; 51 self.buffer.drain(..); 52 } 53 54 /// We now know what hash function the verify_data will use. start_hash(&mut self, alg: &'static digest::Algorithm) -> bool55 pub fn start_hash(&mut self, alg: &'static digest::Algorithm) -> bool { 56 match self.alg { 57 None => {} 58 Some(started) => { 59 if started != alg { 60 // hash type is changing 61 warn!("altered hash to HandshakeHash::start_hash"); 62 return false; 63 } 64 65 return true; 66 } 67 } 68 self.alg = Some(alg); 69 debug_assert!(self.ctx.is_none()); 70 71 let mut ctx = digest::Context::new(alg); 72 ctx.update(&self.buffer); 73 self.ctx = Some(ctx); 74 75 // Discard buffer if we don't need it now. 76 if !self.client_auth_enabled { 77 self.buffer.drain(..); 78 } 79 true 80 } 81 82 /// Hash/buffer a handshake message. add_message(&mut self, m: &Message) -> &mut HandshakeHash83 pub fn add_message(&mut self, m: &Message) -> &mut HandshakeHash { 84 match m.payload { 85 MessagePayload::Handshake(ref hs) => { 86 let buf = hs.get_encoding(); 87 self.update_raw(&buf); 88 } 89 _ => {} 90 }; 91 self 92 } 93 94 /// Hash or buffer a byte slice. update_raw(&mut self, buf: &[u8]) -> &mut Self95 fn update_raw(&mut self, buf: &[u8]) -> &mut Self { 96 if self.ctx.is_some() { 97 self.ctx.as_mut().unwrap().update(buf); 98 } 99 100 if self.ctx.is_none() || self.client_auth_enabled { 101 self.buffer.extend_from_slice(buf); 102 } 103 104 self 105 } 106 107 /// Get the hash value if we were to hash `extra` too, 108 /// using hash function `hash`. get_hash_given(&self, hash: &'static digest::Algorithm, extra: &[u8]) -> Vec<u8>109 pub fn get_hash_given(&self, hash: &'static digest::Algorithm, extra: &[u8]) -> Vec<u8> { 110 let mut ctx = if self.ctx.is_none() { 111 let mut ctx = digest::Context::new(hash); 112 ctx.update(&self.buffer); 113 ctx 114 } else { 115 self.ctx.as_ref().unwrap().clone() 116 }; 117 118 ctx.update(extra); 119 let hash = ctx.finish(); 120 let mut ret = Vec::new(); 121 ret.extend_from_slice(hash.as_ref()); 122 ret 123 } 124 125 /// Take the current hash value, and encapsulate it in a 126 /// 'handshake_hash' handshake message. Start this hash 127 /// again, with that message at the front. rollup_for_hrr(&mut self)128 pub fn rollup_for_hrr(&mut self) { 129 let old_hash = self.ctx.take().unwrap().finish(); 130 let old_handshake_hash_msg = 131 HandshakeMessagePayload::build_handshake_hash(old_hash.as_ref()); 132 133 self.ctx = Some(digest::Context::new(self.alg.unwrap())); 134 self.update_raw(&old_handshake_hash_msg.get_encoding()); 135 } 136 137 /// Get the current hash value. get_current_hash(&self) -> Vec<u8>138 pub fn get_current_hash(&self) -> Vec<u8> { 139 let hash = self 140 .ctx 141 .as_ref() 142 .unwrap() 143 .clone() 144 .finish(); 145 let mut ret = Vec::new(); 146 ret.extend_from_slice(hash.as_ref()); 147 ret 148 } 149 150 /// Takes this object's buffer containing all handshake messages 151 /// so far. This method only works once; it resets the buffer 152 /// to empty. take_handshake_buf(&mut self) -> Vec<u8>153 pub fn take_handshake_buf(&mut self) -> Vec<u8> { 154 debug_assert!(self.client_auth_enabled); 155 mem::replace(&mut self.buffer, Vec::new()) 156 } 157 } 158 159 #[cfg(test)] 160 mod test { 161 use super::HandshakeHash; 162 use ring::digest; 163 164 #[test] hashes_correctly()165 fn hashes_correctly() { 166 let mut hh = HandshakeHash::new(); 167 hh.update_raw(b"hello"); 168 assert_eq!(hh.buffer.len(), 5); 169 hh.start_hash(&digest::SHA256); 170 assert_eq!(hh.buffer.len(), 0); 171 hh.update_raw(b"world"); 172 let h = hh.get_current_hash(); 173 assert_eq!(h[0], 0x93); 174 assert_eq!(h[1], 0x6a); 175 assert_eq!(h[2], 0x18); 176 assert_eq!(h[3], 0x5c); 177 } 178 179 #[test] buffers_correctly()180 fn buffers_correctly() { 181 let mut hh = HandshakeHash::new(); 182 hh.set_client_auth_enabled(); 183 hh.update_raw(b"hello"); 184 assert_eq!(hh.buffer.len(), 5); 185 hh.start_hash(&digest::SHA256); 186 assert_eq!(hh.buffer.len(), 5); 187 hh.update_raw(b"world"); 188 assert_eq!(hh.buffer.len(), 10); 189 let h = hh.get_current_hash(); 190 assert_eq!(h[0], 0x93); 191 assert_eq!(h[1], 0x6a); 192 assert_eq!(h[2], 0x18); 193 assert_eq!(h[3], 0x5c); 194 let buf = hh.take_handshake_buf(); 195 assert_eq!(b"helloworld".to_vec(), buf); 196 } 197 198 #[test] abandon()199 fn abandon() { 200 let mut hh = HandshakeHash::new(); 201 hh.set_client_auth_enabled(); 202 hh.update_raw(b"hello"); 203 assert_eq!(hh.buffer.len(), 5); 204 hh.start_hash(&digest::SHA256); 205 assert_eq!(hh.buffer.len(), 5); 206 hh.abandon_client_auth(); 207 assert_eq!(hh.buffer.len(), 0); 208 hh.update_raw(b"world"); 209 assert_eq!(hh.buffer.len(), 0); 210 let h = hh.get_current_hash(); 211 assert_eq!(h[0], 0x93); 212 assert_eq!(h[1], 0x6a); 213 assert_eq!(h[2], 0x18); 214 assert_eq!(h[3], 0x5c); 215 } 216 } 217