1 use std::collections::VecDeque;
2 use std::io;
3 
4 use crate::msgs::codec;
5 use crate::msgs::message::{Message, MessageError};
6 
7 /// This deframer works to reconstruct TLS messages
8 /// from arbitrary-sized reads, buffering as necessary.
9 /// The input is `read()`, the output is the `frames` deque.
10 pub struct MessageDeframer {
11     /// Completed frames for output.
12     pub frames: VecDeque<Message>,
13 
14     /// Set to true if the peer is not talking TLS, but some other
15     /// protocol.  The caller should abort the connection, because
16     /// the deframer cannot recover.
17     pub desynced: bool,
18 
19     /// A fixed-size buffer containing the currently-accumulating
20     /// TLS message.
21     buf: Box<[u8; Message::MAX_WIRE_SIZE]>,
22 
23     /// What size prefix of `buf` is used.
24     used: usize,
25 }
26 
/// Outcome of trying to deframe a single message from the front of the buffer.
enum BufferContents {
    /// The front of the buffer holds a header that can never form a valid
    /// message, so waiting for more data will not help.
    Invalid,

    /// Not enough bytes to decide yet (the buffer may even be empty); more
    /// input might complete a message.
    Partial,

    /// The front of the buffer held a complete, valid message.
    Valid,
}
38 
39 impl Default for MessageDeframer {
default() -> Self40     fn default() -> Self {
41         Self::new()
42     }
43 }
44 
45 impl MessageDeframer {
new() -> MessageDeframer46     pub fn new() -> MessageDeframer {
47         MessageDeframer {
48             frames: VecDeque::new(),
49             desynced: false,
50             buf: Box::new([0u8; Message::MAX_WIRE_SIZE]),
51             used: 0,
52         }
53     }
54 
55     /// Read some bytes from `rd`, and add them to our internal
56     /// buffer.  If this means our internal buffer contains
57     /// full messages, decode them all.
read(&mut self, rd: &mut dyn io::Read) -> io::Result<usize>58     pub fn read(&mut self, rd: &mut dyn io::Read) -> io::Result<usize> {
59         // Try to do the largest reads possible.  Note that if
60         // we get a message with a length field out of range here,
61         // we do a zero length read.  That looks like an EOF to
62         // the next layer up, which is fine.
63         debug_assert!(self.used <= Message::MAX_WIRE_SIZE);
64         let new_bytes = rd.read(&mut self.buf[self.used..])?;
65 
66         self.used += new_bytes;
67 
68         loop {
69             match self.try_deframe_one() {
70                 BufferContents::Invalid => {
71                     self.desynced = true;
72                     break;
73                 }
74                 BufferContents::Valid => continue,
75                 BufferContents::Partial => break,
76             }
77         }
78 
79         Ok(new_bytes)
80     }
81 
82     /// Returns true if we have messages for the caller
83     /// to process, either whole messages in our output
84     /// queue or partial messages in our buffer.
has_pending(&self) -> bool85     pub fn has_pending(&self) -> bool {
86         !self.frames.is_empty() || self.used > 0
87     }
88 
89     /// Does our `buf` contain a full message?  It does if it is big enough to
90     /// contain a header, and that header has a length which falls within `buf`.
91     /// If so, deframe it and place the message onto the frames output queue.
try_deframe_one(&mut self) -> BufferContents92     fn try_deframe_one(&mut self) -> BufferContents {
93         // Try to decode a message off the front of buf.
94         let mut rd = codec::Reader::init(&self.buf[..self.used]);
95 
96         match Message::read_with_detailed_error(&mut rd) {
97             Ok(m) => {
98                 let used = rd.used();
99                 self.frames.push_back(m);
100                 self.buf_consume(used);
101                 BufferContents::Valid
102             }
103             Err(MessageError::TooShortForHeader) | Err(MessageError::TooShortForLength) => {
104                 BufferContents::Partial
105             }
106             Err(_) => BufferContents::Invalid,
107         }
108     }
109 
buf_consume(&mut self, taken: usize)110     fn buf_consume(&mut self, taken: usize) {
111         if taken < self.used {
112             /* Before:
113              * +----------+----------+----------+
114              * | taken    | pending  |xxxxxxxxxx|
115              * +----------+----------+----------+
116              * 0          ^ taken    ^ self.used
117              *
118              * After:
119              * +----------+----------+----------+
120              * | pending  |xxxxxxxxxxxxxxxxxxxxx|
121              * +----------+----------+----------+
122              * 0          ^ self.used
123              */
124 
125             self.buf
126                 .copy_within(taken..self.used, 0);
127             self.used = self.used - taken;
128         } else if taken == self.used {
129             self.used = 0;
130         }
131     }
132 }
133 
#[cfg(test)]
mod tests {
    use super::MessageDeframer;
    use crate::msgs;
    use std::io;

    const FIRST_MESSAGE: &[u8] = include_bytes!("../testdata/deframer-test.1.bin");
    const SECOND_MESSAGE: &[u8] = include_bytes!("../testdata/deframer-test.2.bin");

    /// An `io::Read` over an in-memory slice that hands out as many bytes
    /// as fit into the caller's buffer on each call.
    struct ByteRead<'a> {
        buf: &'a [u8],
        offs: usize,
    }

    impl<'a> ByteRead<'a> {
        fn new(bytes: &'a [u8]) -> ByteRead<'a> {
            ByteRead { buf: bytes, offs: 0 }
        }
    }

    impl<'a> io::Read for ByteRead<'a> {
        fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
            let remaining = &self.buf[self.offs..];
            let len = buf.len().min(remaining.len());
            buf[..len].copy_from_slice(&remaining[..len]);
            self.offs += len;
            Ok(len)
        }
    }

    /// Feeds `bytes` to the deframer in a single `read` call.
    fn input_bytes(d: &mut MessageDeframer, bytes: &[u8]) -> io::Result<usize> {
        let mut rd = ByteRead::new(bytes);
        d.read(&mut rd)
    }

    /// Feeds `bytes1` followed by `bytes2` to the deframer in a single `read` call.
    fn input_bytes_concat(
        d: &mut MessageDeframer,
        bytes1: &[u8],
        bytes2: &[u8],
    ) -> io::Result<usize> {
        let bytes = [bytes1, bytes2].concat();
        input_bytes(d, &bytes)
    }

    /// A reader whose single `read` call fails with the stored error.
    struct ErrorRead {
        error: Option<io::Error>,
    }

    impl ErrorRead {
        fn new(error: io::Error) -> ErrorRead {
            ErrorRead { error: Some(error) }
        }
    }

    impl io::Read for ErrorRead {
        fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
            // Scribble over the buffer before failing, so a test notices if
            // the deframer wrongly keeps bytes from a failed read.
            for (i, b) in buf.iter_mut().enumerate() {
                *b = i as u8;
            }

            let error = self.error.take().expect("ErrorRead::read called twice");
            Err(error)
        }
    }

    /// Performs a failing read and checks the error is propagated.
    fn input_error(d: &mut MessageDeframer) {
        let error = io::Error::from(io::ErrorKind::TimedOut);
        let mut rd = ErrorRead::new(error);
        d.read(&mut rd).expect_err("error not propagated");
    }

    /// Feeds `bytes` one byte per `read` call, checking that exactly one
    /// new frame appears, and only after the final byte.
    fn input_whole_incremental(d: &mut MessageDeframer, bytes: &[u8]) {
        let frames_before = d.frames.len();

        for (i, byte) in bytes.iter().enumerate() {
            assert_len(1, input_bytes(d, std::slice::from_ref(byte)));
            assert!(d.has_pending());

            if i < bytes.len() - 1 {
                assert_eq!(frames_before, d.frames.len());
            }
        }

        assert_eq!(frames_before + 1, d.frames.len());
    }

    fn assert_len(want: usize, got: io::Result<usize>) {
        match got {
            Ok(gotval) => assert_eq!(gotval, want),
            Err(err) => panic!("read failed, expected {:?} bytes: {:?}", want, err),
        }
    }

    fn pop_first(d: &mut MessageDeframer) {
        let mut m = d.frames.pop_front().unwrap();
        m.decode_payload();
        assert_eq!(m.typ, msgs::enums::ContentType::Handshake);
    }

    fn pop_second(d: &mut MessageDeframer) {
        let mut m = d.frames.pop_front().unwrap();
        m.decode_payload();
        assert_eq!(m.typ, msgs::enums::ContentType::Alert);
    }

    #[test]
    fn check_incremental() {
        let mut d = MessageDeframer::new();
        assert!(!d.has_pending());
        input_whole_incremental(&mut d, FIRST_MESSAGE);
        assert!(d.has_pending());
        assert_eq!(1, d.frames.len());
        pop_first(&mut d);
        assert!(!d.has_pending());
    }

    #[test]
    fn check_incremental_2() {
        let mut d = MessageDeframer::new();
        assert!(!d.has_pending());
        input_whole_incremental(&mut d, FIRST_MESSAGE);
        assert!(d.has_pending());
        input_whole_incremental(&mut d, SECOND_MESSAGE);
        assert!(d.has_pending());
        assert_eq!(2, d.frames.len());
        pop_first(&mut d);
        assert!(d.has_pending());
        pop_second(&mut d);
        assert!(!d.has_pending());
    }

    #[test]
    fn check_whole() {
        let mut d = MessageDeframer::new();
        assert!(!d.has_pending());
        assert_len(FIRST_MESSAGE.len(), input_bytes(&mut d, FIRST_MESSAGE));
        assert!(d.has_pending());
        assert_eq!(d.frames.len(), 1);
        pop_first(&mut d);
        assert!(!d.has_pending());
    }

    #[test]
    fn check_whole_2() {
        let mut d = MessageDeframer::new();
        assert!(!d.has_pending());
        assert_len(FIRST_MESSAGE.len(), input_bytes(&mut d, FIRST_MESSAGE));
        assert_len(SECOND_MESSAGE.len(), input_bytes(&mut d, SECOND_MESSAGE));
        assert_eq!(d.frames.len(), 2);
        pop_first(&mut d);
        pop_second(&mut d);
        assert!(!d.has_pending());
    }

    #[test]
    fn test_two_in_one_read() {
        let mut d = MessageDeframer::new();
        assert!(!d.has_pending());
        assert_len(
            FIRST_MESSAGE.len() + SECOND_MESSAGE.len(),
            input_bytes_concat(&mut d, FIRST_MESSAGE, SECOND_MESSAGE),
        );
        assert_eq!(d.frames.len(), 2);
        pop_first(&mut d);
        pop_second(&mut d);
        assert!(!d.has_pending());
    }

    #[test]
    fn test_two_in_one_read_shortest_first() {
        let mut d = MessageDeframer::new();
        assert!(!d.has_pending());
        assert_len(
            FIRST_MESSAGE.len() + SECOND_MESSAGE.len(),
            input_bytes_concat(&mut d, SECOND_MESSAGE, FIRST_MESSAGE),
        );
        assert_eq!(d.frames.len(), 2);
        pop_second(&mut d);
        pop_first(&mut d);
        assert!(!d.has_pending());
    }

    #[test]
    fn test_incremental_with_nonfatal_read_error() {
        let mut d = MessageDeframer::new();
        assert_len(3, input_bytes(&mut d, &FIRST_MESSAGE[..3]));
        input_error(&mut d);
        assert_len(
            FIRST_MESSAGE.len() - 3,
            input_bytes(&mut d, &FIRST_MESSAGE[3..]),
        );
        assert_eq!(d.frames.len(), 1);
        pop_first(&mut d);
        assert!(!d.has_pending());
    }
}
343