1 use std::collections::VecDeque; 2 use std::io; 3 4 use crate::msgs::codec; 5 use crate::msgs::message::{Message, MessageError}; 6 7 /// This deframer works to reconstruct TLS messages 8 /// from arbitrary-sized reads, buffering as necessary. 9 /// The input is `read()`, the output is the `frames` deque. 10 pub struct MessageDeframer { 11 /// Completed frames for output. 12 pub frames: VecDeque<Message>, 13 14 /// Set to true if the peer is not talking TLS, but some other 15 /// protocol. The caller should abort the connection, because 16 /// the deframer cannot recover. 17 pub desynced: bool, 18 19 /// A fixed-size buffer containing the currently-accumulating 20 /// TLS message. 21 buf: Box<[u8; Message::MAX_WIRE_SIZE]>, 22 23 /// What size prefix of `buf` is used. 24 used: usize, 25 } 26 27 enum BufferContents { 28 /// Contains an invalid message as a header. 29 Invalid, 30 31 /// Might contain a valid message if we receive more. 32 /// Perhaps totally empty! 33 Partial, 34 35 /// Contains a valid frame as a prefix. 36 Valid, 37 } 38 39 impl Default for MessageDeframer { default() -> Self40 fn default() -> Self { 41 Self::new() 42 } 43 } 44 45 impl MessageDeframer { new() -> MessageDeframer46 pub fn new() -> MessageDeframer { 47 MessageDeframer { 48 frames: VecDeque::new(), 49 desynced: false, 50 buf: Box::new([0u8; Message::MAX_WIRE_SIZE]), 51 used: 0, 52 } 53 } 54 55 /// Read some bytes from `rd`, and add them to our internal 56 /// buffer. If this means our internal buffer contains 57 /// full messages, decode them all. read(&mut self, rd: &mut dyn io::Read) -> io::Result<usize>58 pub fn read(&mut self, rd: &mut dyn io::Read) -> io::Result<usize> { 59 // Try to do the largest reads possible. Note that if 60 // we get a message with a length field out of range here, 61 // we do a zero length read. That looks like an EOF to 62 // the next layer up, which is fine. 63 debug_assert!(self.used <= Message::MAX_WIRE_SIZE); 64 let new_bytes = rd.read(&mut self.buf[self.used..])?; 65 66 self.used += new_bytes; 67 68 loop { 69 match self.try_deframe_one() { 70 BufferContents::Invalid => { 71 self.desynced = true; 72 break; 73 } 74 BufferContents::Valid => continue, 75 BufferContents::Partial => break, 76 } 77 } 78 79 Ok(new_bytes) 80 } 81 82 /// Returns true if we have messages for the caller 83 /// to process, either whole messages in our output 84 /// queue or partial messages in our buffer. has_pending(&self) -> bool85 pub fn has_pending(&self) -> bool { 86 !self.frames.is_empty() || self.used > 0 87 } 88 89 /// Does our `buf` contain a full message? It does if it is big enough to 90 /// contain a header, and that header has a length which falls within `buf`. 91 /// If so, deframe it and place the message onto the frames output queue. try_deframe_one(&mut self) -> BufferContents92 fn try_deframe_one(&mut self) -> BufferContents { 93 // Try to decode a message off the front of buf. 94 let mut rd = codec::Reader::init(&self.buf[..self.used]); 95 96 match Message::read_with_detailed_error(&mut rd) { 97 Ok(m) => { 98 let used = rd.used(); 99 self.frames.push_back(m); 100 self.buf_consume(used); 101 BufferContents::Valid 102 } 103 Err(MessageError::TooShortForHeader) | Err(MessageError::TooShortForLength) => { 104 BufferContents::Partial 105 } 106 Err(_) => BufferContents::Invalid, 107 } 108 } 109 buf_consume(&mut self, taken: usize)110 fn buf_consume(&mut self, taken: usize) { 111 if taken < self.used { 112 /* Before: 113 * +----------+----------+----------+ 114 * | taken | pending |xxxxxxxxxx| 115 * +----------+----------+----------+ 116 * 0 ^ taken ^ self.used 117 * 118 * After: 119 * +----------+----------+----------+ 120 * | pending |xxxxxxxxxxxxxxxxxxxxx| 121 * +----------+----------+----------+ 122 * 0 ^ self.used 123 */ 124 125 self.buf 126 .copy_within(taken..self.used, 0); 127 self.used = self.used - taken; 128 } else if taken == self.used { 129 self.used = 0; 130 } 131 } 132 } 133 134 #[cfg(test)] 135 mod tests { 136 use super::MessageDeframer; 137 use crate::msgs; 138 use std::io; 139 140 const FIRST_MESSAGE: &'static [u8] = include_bytes!("../testdata/deframer-test.1.bin"); 141 const SECOND_MESSAGE: &'static [u8] = include_bytes!("../testdata/deframer-test.2.bin"); 142 143 struct ByteRead<'a> { 144 buf: &'a [u8], 145 offs: usize, 146 } 147 148 impl<'a> ByteRead<'a> { new(bytes: &'a [u8]) -> ByteRead149 fn new(bytes: &'a [u8]) -> ByteRead { 150 ByteRead { 151 buf: bytes, 152 offs: 0, 153 } 154 } 155 } 156 157 impl<'a> io::Read for ByteRead<'a> { read(&mut self, buf: &mut [u8]) -> io::Result<usize>158 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> { 159 let mut len = 0; 160 161 while len < buf.len() && len < self.buf.len() - self.offs { 162 buf[len] = self.buf[self.offs + len]; 163 len += 1; 164 } 165 166 self.offs += len; 167 168 Ok(len) 169 } 170 } 171 input_bytes(d: &mut MessageDeframer, bytes: &[u8]) -> io::Result<usize>172 fn input_bytes(d: &mut MessageDeframer, bytes: &[u8]) -> io::Result<usize> { 173 let mut rd = ByteRead::new(bytes); 174 d.read(&mut rd) 175 } 176 input_bytes_concat( d: &mut MessageDeframer, bytes1: &[u8], bytes2: &[u8], ) -> io::Result<usize>177 fn input_bytes_concat( 178 d: &mut MessageDeframer, 179 bytes1: &[u8], 180 bytes2: &[u8], 181 ) -> io::Result<usize> { 182 let mut bytes = vec![0u8; bytes1.len() + bytes2.len()]; 183 bytes[..bytes1.len()].clone_from_slice(bytes1); 184 bytes[bytes1.len()..].clone_from_slice(bytes2); 185 let mut rd = ByteRead::new(&bytes); 186 d.read(&mut rd) 187 } 188 189 struct ErrorRead { 190 error: Option<io::Error>, 191 } 192 193 impl ErrorRead { new(error: io::Error) -> ErrorRead194 fn new(error: io::Error) -> ErrorRead { 195 ErrorRead { error: Some(error) } 196 } 197 } 198 199 impl io::Read for ErrorRead { read(&mut self, buf: &mut [u8]) -> io::Result<usize>200 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> { 201 for (i, b) in buf.iter_mut().enumerate() { 202 *b = i as u8; 203 } 204 205 let error = self.error.take().unwrap(); 206 Err(error) 207 } 208 } 209 input_error(d: &mut MessageDeframer)210 fn input_error(d: &mut MessageDeframer) { 211 let error = io::Error::from(io::ErrorKind::TimedOut); 212 let mut rd = ErrorRead::new(error); 213 d.read(&mut rd) 214 .expect_err("error not propagated"); 215 } 216 input_whole_incremental(d: &mut MessageDeframer, bytes: &[u8])217 fn input_whole_incremental(d: &mut MessageDeframer, bytes: &[u8]) { 218 let frames_before = d.frames.len(); 219 220 for i in 0..bytes.len() { 221 assert_len(1, input_bytes(d, &bytes[i..i + 1])); 222 assert_eq!(d.has_pending(), true); 223 224 if i < bytes.len() - 1 { 225 assert_eq!(frames_before, d.frames.len()); 226 } 227 } 228 229 assert_eq!(frames_before + 1, d.frames.len()); 230 } 231 assert_len(want: usize, got: io::Result<usize>)232 fn assert_len(want: usize, got: io::Result<usize>) { 233 if let Ok(gotval) = got { 234 assert_eq!(gotval, want); 235 } else { 236 assert!(false, "read failed, expected {:?} bytes", want); 237 } 238 } 239 pop_first(d: &mut MessageDeframer)240 fn pop_first(d: &mut MessageDeframer) { 241 let mut m = d.frames.pop_front().unwrap(); 242 m.decode_payload(); 243 assert_eq!(m.typ, msgs::enums::ContentType::Handshake); 244 } 245 pop_second(d: &mut MessageDeframer)246 fn pop_second(d: &mut MessageDeframer) { 247 let mut m = d.frames.pop_front().unwrap(); 248 m.decode_payload(); 249 assert_eq!(m.typ, msgs::enums::ContentType::Alert); 250 } 251 252 #[test] check_incremental()253 fn check_incremental() { 254 let mut d = MessageDeframer::new(); 255 assert_eq!(d.has_pending(), false); 256 input_whole_incremental(&mut d, FIRST_MESSAGE); 257 assert_eq!(d.has_pending(), true); 258 assert_eq!(1, d.frames.len()); 259 pop_first(&mut d); 260 assert_eq!(d.has_pending(), false); 261 } 262 263 #[test] check_incremental_2()264 fn check_incremental_2() { 265 let mut d = MessageDeframer::new(); 266 assert_eq!(d.has_pending(), false); 267 input_whole_incremental(&mut d, FIRST_MESSAGE); 268 assert_eq!(d.has_pending(), true); 269 input_whole_incremental(&mut d, SECOND_MESSAGE); 270 assert_eq!(d.has_pending(), true); 271 assert_eq!(2, d.frames.len()); 272 pop_first(&mut d); 273 assert_eq!(d.has_pending(), true); 274 pop_second(&mut d); 275 assert_eq!(d.has_pending(), false); 276 } 277 278 #[test] check_whole()279 fn check_whole() { 280 let mut d = MessageDeframer::new(); 281 assert_eq!(d.has_pending(), false); 282 assert_len(FIRST_MESSAGE.len(), input_bytes(&mut d, FIRST_MESSAGE)); 283 assert_eq!(d.has_pending(), true); 284 assert_eq!(d.frames.len(), 1); 285 pop_first(&mut d); 286 assert_eq!(d.has_pending(), false); 287 } 288 289 #[test] check_whole_2()290 fn check_whole_2() { 291 let mut d = MessageDeframer::new(); 292 assert_eq!(d.has_pending(), false); 293 assert_len(FIRST_MESSAGE.len(), input_bytes(&mut d, FIRST_MESSAGE)); 294 assert_len(SECOND_MESSAGE.len(), input_bytes(&mut d, SECOND_MESSAGE)); 295 assert_eq!(d.frames.len(), 2); 296 pop_first(&mut d); 297 pop_second(&mut d); 298 assert_eq!(d.has_pending(), false); 299 } 300 301 #[test] test_two_in_one_read()302 fn test_two_in_one_read() { 303 let mut d = MessageDeframer::new(); 304 assert_eq!(d.has_pending(), false); 305 assert_len( 306 FIRST_MESSAGE.len() + SECOND_MESSAGE.len(), 307 input_bytes_concat(&mut d, FIRST_MESSAGE, SECOND_MESSAGE), 308 ); 309 assert_eq!(d.frames.len(), 2); 310 pop_first(&mut d); 311 pop_second(&mut d); 312 assert_eq!(d.has_pending(), false); 313 } 314 315 #[test] test_two_in_one_read_shortest_first()316 fn test_two_in_one_read_shortest_first() { 317 let mut d = MessageDeframer::new(); 318 assert_eq!(d.has_pending(), false); 319 assert_len( 320 FIRST_MESSAGE.len() + SECOND_MESSAGE.len(), 321 input_bytes_concat(&mut d, SECOND_MESSAGE, FIRST_MESSAGE), 322 ); 323 assert_eq!(d.frames.len(), 2); 324 pop_second(&mut d); 325 pop_first(&mut d); 326 assert_eq!(d.has_pending(), false); 327 } 328 329 #[test] test_incremental_with_nonfatal_read_error()330 fn test_incremental_with_nonfatal_read_error() { 331 let mut d = MessageDeframer::new(); 332 assert_len(3, input_bytes(&mut d, &FIRST_MESSAGE[..3])); 333 input_error(&mut d); 334 assert_len( 335 FIRST_MESSAGE.len() - 3, 336 input_bytes(&mut d, &FIRST_MESSAGE[3..]), 337 ); 338 assert_eq!(d.frames.len(), 1); 339 pop_first(&mut d); 340 assert_eq!(d.has_pending(), false); 341 } 342 } 343