1 use crate::msgs::base::Payload;
2 use crate::msgs::enums::{ContentType, ProtocolVersion};
3 use crate::msgs::message::{BorrowedPlainMessage, PlainMessage};
4 use crate::Error;
5 use std::collections::VecDeque;
6 
/// Default upper bound, in bytes, on the payload carried by a single fragment.
///
/// This is the limit a `MessageFragmenter` uses when no explicit maximum
/// fragment size is configured.
pub const MAX_FRAGMENT_LEN: usize = 16384;

/// Bytes an encoded message adds on top of its payload.
///
/// The `1 + 2 + 2` split matches the TLS record header: a 1-byte content type,
/// a 2-byte protocol version and a 2-byte length.
pub const PACKET_OVERHEAD: usize = 1 + 2 + 2;

/// Largest accepted maximum fragment size, overhead included.
pub const MAX_FRAGMENT_SIZE: usize = MAX_FRAGMENT_LEN + PACKET_OVERHEAD;
10 
/// Splits messages into fragments whose payload does not exceed a configured
/// limit.
#[derive(Debug)]
pub struct MessageFragmenter {
    /// Maximum number of payload bytes per fragment (overhead excluded).
    max_frag: usize,
}
14 
15 impl MessageFragmenter {
16     /// Make a new fragmenter.
17     ///
18     /// `max_fragment_size` is the maximum fragment size that will be produced --
19     /// this includes overhead. A `max_fragment_size` of 10 will produce TLS fragments
20     /// up to 10 bytes.
new(max_fragment_size: Option<usize>) -> Result<Self, Error>21     pub fn new(max_fragment_size: Option<usize>) -> Result<Self, Error> {
22         let mut new = Self { max_frag: 0 };
23         new.set_max_fragment_size(max_fragment_size)?;
24         Ok(new)
25     }
26 
27     /// Take the Message `msg` and re-fragment it into new
28     /// messages whose fragment is no more than max_frag.
29     /// The new messages are appended to the `out` deque.
30     /// Payloads are copied.
fragment(&self, msg: PlainMessage, out: &mut VecDeque<PlainMessage>)31     pub fn fragment(&self, msg: PlainMessage, out: &mut VecDeque<PlainMessage>) {
32         // Non-fragment path
33         if msg.payload.0.len() <= self.max_frag {
34             out.push_back(msg);
35             return;
36         }
37 
38         for chunk in msg.payload.0.chunks(self.max_frag) {
39             out.push_back(PlainMessage {
40                 typ: msg.typ,
41                 version: msg.version,
42                 payload: Payload(chunk.to_vec()),
43             });
44         }
45     }
46 
47     /// Enqueue borrowed fragments of (version, typ, payload) which
48     /// are no longer than max_frag onto the `out` deque.
fragment_borrow<'a>( &self, typ: ContentType, version: ProtocolVersion, payload: &'a [u8], out: &mut VecDeque<BorrowedPlainMessage<'a>>, )49     pub fn fragment_borrow<'a>(
50         &self,
51         typ: ContentType,
52         version: ProtocolVersion,
53         payload: &'a [u8],
54         out: &mut VecDeque<BorrowedPlainMessage<'a>>,
55     ) {
56         for chunk in payload.chunks(self.max_frag) {
57             let cm = BorrowedPlainMessage {
58                 typ,
59                 version,
60                 payload: chunk,
61             };
62             out.push_back(cm);
63         }
64     }
65 
set_max_fragment_size(&mut self, new: Option<usize>) -> Result<(), Error>66     pub fn set_max_fragment_size(&mut self, new: Option<usize>) -> Result<(), Error> {
67         self.max_frag = match new {
68             Some(sz @ 32..=MAX_FRAGMENT_SIZE) => sz - PACKET_OVERHEAD,
69             None => MAX_FRAGMENT_LEN,
70             _ => return Err(Error::BadMaxFragmentSize),
71         };
72         Ok(())
73     }
74 }
75 
#[cfg(test)]
mod tests {
    use super::{MessageFragmenter, PACKET_OVERHEAD};
    use crate::msgs::base::Payload;
    use crate::msgs::enums::{ContentType, ProtocolVersion};
    use crate::msgs::message::PlainMessage;
    use std::collections::VecDeque;

    /// Assert that `mm` holds a message with the given type, version and
    /// payload bytes, and that it encodes to exactly `total_len` bytes.
    fn msg_eq(
        mm: Option<PlainMessage>,
        total_len: usize,
        typ: ContentType,
        version: ProtocolVersion,
        bytes: &[u8],
    ) {
        let m = mm.expect("fragmenter produced fewer messages than expected");

        assert_eq!(m.typ, typ);
        assert_eq!(m.version, version);
        assert_eq!(m.payload.0, bytes);

        // Encoding consumes the message, so it goes last and needs no clone.
        let encoded = m.into_unencrypted_opaque().encode();
        assert_eq!(total_len, encoded.len());
    }

    /// A 69-byte payload with a 32-byte maximum fragment size is split into
    /// two full fragments (27 payload bytes each) and a 15-byte remainder.
    #[test]
    fn smoke() {
        let typ = ContentType::Handshake;
        let version = ProtocolVersion::TLSv1_2;
        let data: Vec<u8> = (1..70u8).collect();
        let m = PlainMessage {
            typ,
            version,
            payload: Payload::new(data),
        };

        let frag = MessageFragmenter::new(Some(32)).unwrap();
        let mut q = VecDeque::new();
        frag.fragment(m, &mut q);

        let first: Vec<u8> = (1..=27).collect();
        msg_eq(q.pop_front(), 32, typ, version, &first);

        let second: Vec<u8> = (28..=54).collect();
        msg_eq(q.pop_front(), 32, typ, version, &second);

        let third: Vec<u8> = (55..=69).collect();
        msg_eq(q.pop_front(), 20, typ, version, &third);

        assert!(q.is_empty());
    }

    /// A payload that already fits is passed through as a single message.
    #[test]
    fn non_fragment() {
        let body = b"\x01\x02\x03\x04\x05\x06\x07\x08";
        let m = PlainMessage {
            typ: ContentType::Handshake,
            version: ProtocolVersion::TLSv1_2,
            payload: Payload::new(body.to_vec()),
        };

        let frag = MessageFragmenter::new(Some(32)).unwrap();
        let mut q = VecDeque::new();
        frag.fragment(m, &mut q);

        msg_eq(
            q.pop_front(),
            PACKET_OVERHEAD + 8,
            ContentType::Handshake,
            ProtocolVersion::TLSv1_2,
            body,
        );
        assert!(q.is_empty());
    }
}
169