1 use crate::msgs::base::Payload; 2 use crate::msgs::enums::{ContentType, ProtocolVersion}; 3 use crate::msgs::message::{BorrowedPlainMessage, PlainMessage}; 4 use crate::Error; 5 use std::collections::VecDeque; 6 7 pub const MAX_FRAGMENT_LEN: usize = 16384; 8 pub const PACKET_OVERHEAD: usize = 1 + 2 + 2; 9 pub const MAX_FRAGMENT_SIZE: usize = MAX_FRAGMENT_LEN + PACKET_OVERHEAD; 10 11 pub struct MessageFragmenter { 12 max_frag: usize, 13 } 14 15 impl MessageFragmenter { 16 /// Make a new fragmenter. 17 /// 18 /// `max_fragment_size` is the maximum fragment size that will be produced -- 19 /// this includes overhead. A `max_fragment_size` of 10 will produce TLS fragments 20 /// up to 10 bytes. new(max_fragment_size: Option<usize>) -> Result<Self, Error>21 pub fn new(max_fragment_size: Option<usize>) -> Result<Self, Error> { 22 let mut new = Self { max_frag: 0 }; 23 new.set_max_fragment_size(max_fragment_size)?; 24 Ok(new) 25 } 26 27 /// Take the Message `msg` and re-fragment it into new 28 /// messages whose fragment is no more than max_frag. 29 /// The new messages are appended to the `out` deque. 30 /// Payloads are copied. fragment(&self, msg: PlainMessage, out: &mut VecDeque<PlainMessage>)31 pub fn fragment(&self, msg: PlainMessage, out: &mut VecDeque<PlainMessage>) { 32 // Non-fragment path 33 if msg.payload.0.len() <= self.max_frag { 34 out.push_back(msg); 35 return; 36 } 37 38 for chunk in msg.payload.0.chunks(self.max_frag) { 39 out.push_back(PlainMessage { 40 typ: msg.typ, 41 version: msg.version, 42 payload: Payload(chunk.to_vec()), 43 }); 44 } 45 } 46 47 /// Enqueue borrowed fragments of (version, typ, payload) which 48 /// are no longer than max_frag onto the `out` deque. fragment_borrow<'a>( &self, typ: ContentType, version: ProtocolVersion, payload: &'a [u8], out: &mut VecDeque<BorrowedPlainMessage<'a>>, )49 pub fn fragment_borrow<'a>( 50 &self, 51 typ: ContentType, 52 version: ProtocolVersion, 53 payload: &'a [u8], 54 out: &mut VecDeque<BorrowedPlainMessage<'a>>, 55 ) { 56 for chunk in payload.chunks(self.max_frag) { 57 let cm = BorrowedPlainMessage { 58 typ, 59 version, 60 payload: chunk, 61 }; 62 out.push_back(cm); 63 } 64 } 65 set_max_fragment_size(&mut self, new: Option<usize>) -> Result<(), Error>66 pub fn set_max_fragment_size(&mut self, new: Option<usize>) -> Result<(), Error> { 67 self.max_frag = match new { 68 Some(sz @ 32..=MAX_FRAGMENT_SIZE) => sz - PACKET_OVERHEAD, 69 None => MAX_FRAGMENT_LEN, 70 _ => return Err(Error::BadMaxFragmentSize), 71 }; 72 Ok(()) 73 } 74 } 75 76 #[cfg(test)] 77 mod tests { 78 use super::{MessageFragmenter, PACKET_OVERHEAD}; 79 use crate::msgs::base::Payload; 80 use crate::msgs::enums::{ContentType, ProtocolVersion}; 81 use crate::msgs::message::PlainMessage; 82 use std::collections::VecDeque; 83 msg_eq( mm: Option<PlainMessage>, total_len: usize, typ: &ContentType, version: &ProtocolVersion, bytes: &[u8], )84 fn msg_eq( 85 mm: Option<PlainMessage>, 86 total_len: usize, 87 typ: &ContentType, 88 version: &ProtocolVersion, 89 bytes: &[u8], 90 ) { 91 let m = mm.unwrap(); 92 let buf = m 93 .clone() 94 .into_unencrypted_opaque() 95 .encode(); 96 97 assert_eq!(&m.typ, typ); 98 assert_eq!(&m.version, version); 99 assert_eq!(m.payload.0, bytes.to_vec()); 100 101 assert_eq!(total_len, buf.len()); 102 } 103 104 #[test] smoke()105 fn smoke() { 106 let typ = ContentType::Handshake; 107 let version = ProtocolVersion::TLSv1_2; 108 let data: Vec<u8> = (1..70u8).collect(); 109 let m = PlainMessage { 110 typ, 111 version, 112 payload: Payload::new(data), 113 }; 114 115 let frag = MessageFragmenter::new(Some(32)).unwrap(); 116 let mut q = VecDeque::new(); 117 frag.fragment(m, &mut q); 118 msg_eq( 119 q.pop_front(), 120 32, 121 &typ, 122 &version, 123 &[ 124 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 125 24, 25, 26, 27, 126 ], 127 ); 128 msg_eq( 129 q.pop_front(), 130 32, 131 &typ, 132 &version, 133 &[ 134 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 135 49, 50, 51, 52, 53, 54, 136 ], 137 ); 138 msg_eq( 139 q.pop_front(), 140 20, 141 &typ, 142 &version, 143 &[55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69], 144 ); 145 assert_eq!(q.len(), 0); 146 } 147 148 #[test] non_fragment()149 fn non_fragment() { 150 let m = PlainMessage { 151 typ: ContentType::Handshake, 152 version: ProtocolVersion::TLSv1_2, 153 payload: Payload::new(b"\x01\x02\x03\x04\x05\x06\x07\x08".to_vec()), 154 }; 155 156 let frag = MessageFragmenter::new(Some(32)).unwrap(); 157 let mut q = VecDeque::new(); 158 frag.fragment(m, &mut q); 159 msg_eq( 160 q.pop_front(), 161 PACKET_OVERHEAD + 8, 162 &ContentType::Handshake, 163 &ProtocolVersion::TLSv1_2, 164 b"\x01\x02\x03\x04\x05\x06\x07\x08", 165 ); 166 assert_eq!(q.len(), 0); 167 } 168 } 169