1 // Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
2 // http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
3 // <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
4 // option. This file may not be copied, modified, or distributed
5 // except according to those terms.
6 
7 use std::cmp::min;
8 use std::mem;
9 
10 use crate::codec::Decoder;
11 
12 #[derive(Clone, Debug, Default)]
13 pub struct IncrementalDecoderUint {
14     v: u64,
15     remaining: Option<usize>,
16 }
17 
18 impl IncrementalDecoderUint {
19     #[must_use]
min_remaining(&self) -> usize20     pub fn min_remaining(&self) -> usize {
21         self.remaining.unwrap_or(1)
22     }
23 
24     /// Consume some data.
25     #[allow(clippy::missing_panics_doc)] // See https://github.com/rust-lang/rust-clippy/issues/6699
consume(&mut self, dv: &mut Decoder) -> Option<u64>26     pub fn consume(&mut self, dv: &mut Decoder) -> Option<u64> {
27         if let Some(r) = &mut self.remaining {
28             let amount = min(*r, dv.remaining());
29             if amount < 8 {
30                 self.v <<= amount * 8;
31             }
32             self.v |= dv.decode_uint(amount).unwrap();
33             *r -= amount;
34             if *r == 0 {
35                 Some(self.v)
36             } else {
37                 None
38             }
39         } else {
40             let (v, remaining) = match dv.decode_byte() {
41                 Some(b) => (
42                     u64::from(b & 0x3f),
43                     match b >> 6 {
44                         0 => 0,
45                         1 => 1,
46                         2 => 3,
47                         3 => 7,
48                         _ => unreachable!(),
49                     },
50                 ),
51                 None => unreachable!(),
52             };
53             self.remaining = Some(remaining);
54             self.v = v;
55             if remaining == 0 {
56                 Some(v)
57             } else {
58                 None
59             }
60         }
61     }
62 
63     #[must_use]
decoding_in_progress(&self) -> bool64     pub fn decoding_in_progress(&self) -> bool {
65         self.remaining.is_some()
66     }
67 }
68 
69 #[derive(Clone, Debug)]
70 pub struct IncrementalDecoderBuffer {
71     v: Vec<u8>,
72     remaining: usize,
73 }
74 
75 impl IncrementalDecoderBuffer {
76     #[must_use]
new(n: usize) -> Self77     pub fn new(n: usize) -> Self {
78         Self {
79             v: Vec::new(),
80             remaining: n,
81         }
82     }
83 
84     #[must_use]
min_remaining(&self) -> usize85     pub fn min_remaining(&self) -> usize {
86         self.remaining
87     }
88 
89     /// Consume some bytes from the decoder.
90     /// # Panics
91     /// Never; but rust doesn't know that.
consume(&mut self, dv: &mut Decoder) -> Option<Vec<u8>>92     pub fn consume(&mut self, dv: &mut Decoder) -> Option<Vec<u8>> {
93         let amount = min(self.remaining, dv.remaining());
94         let b = dv.decode(amount).unwrap();
95         self.v.extend_from_slice(b);
96         self.remaining -= amount;
97         if self.remaining == 0 {
98             Some(mem::take(&mut self.v))
99         } else {
100             None
101         }
102     }
103 }
104 
105 #[derive(Clone, Debug)]
106 pub struct IncrementalDecoderIgnore {
107     remaining: usize,
108 }
109 
110 impl IncrementalDecoderIgnore {
111     /// Make a new ignoring decoder.
112     /// # Panics
113     /// If the amount to ignore is zero.
114     #[must_use]
new(n: usize) -> Self115     pub fn new(n: usize) -> Self {
116         assert_ne!(n, 0);
117         Self { remaining: n }
118     }
119 
120     #[must_use]
min_remaining(&self) -> usize121     pub fn min_remaining(&self) -> usize {
122         self.remaining
123     }
124 
consume(&mut self, dv: &mut Decoder) -> bool125     pub fn consume(&mut self, dv: &mut Decoder) -> bool {
126         let amount = min(self.remaining, dv.remaining());
127         let _ = dv.decode(amount);
128         self.remaining -= amount;
129         self.remaining == 0
130     }
131 }
132 
#[cfg(test)]
mod tests {
    use super::{
        Decoder, IncrementalDecoderBuffer, IncrementalDecoderIgnore, IncrementalDecoderUint,
    };
    use crate::codec::Encoder;

    /// Feeds ten bytes through `IncrementalDecoderBuffer` in chunks that grow by
    /// one byte each step until half the input is used. The remainder is then
    /// handed over in one piece. Only the final chunk may complete the buffer.
    #[test]
    fn buffer_incremental() {
        let input = &[1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
        let total = input.len();
        let mut collector = IncrementalDecoderBuffer::new(total);
        let mut fed = 0;
        while fed < total {
            let step = if fed >= total / 2 { total - fed } else { fed + 1 };
            let mut chunk = Decoder::from(&input[fed..fed + step]);
            fed += step;
            if let Some(out) = collector.consume(&mut chunk) {
                assert_eq!(fed, total);
                assert_eq!(out, input);
            } else {
                assert!(fed < total);
            }
        }
    }

    /// One variable-length integer test vector: the hex-encoded input and the
    /// value it must decode to.
    struct UintTestCase {
        hex: String,
        value: u64,
    }

    impl UintTestCase {
        /// Decodes `hex` plus one trailing padding byte, split at every possible
        /// position. Checks that the two halves together yield `value` and that
        /// the padding byte is never consumed.
        pub fn run(&self) {
            eprintln!(
                "IncrementalDecoderUint decoder with {:?} ; expect {:?}",
                self.hex, self.value
            );

            let template = IncrementalDecoderUint::default();
            let mut encoded = Encoder::from_hex(&self.hex);
            // The trailing padding byte detects any over-read past the integer.
            encoded.encode_byte(0xff);

            for tail in 1..encoded.len() {
                let split = encoded.len() - tail;
                let mut head = Decoder::from(&encoded[0..split]);
                eprintln!("  split at {}: {:?}", split, head);

                // Each split starts from a fresh copy of the pristine decoder.
                let mut dec = template.clone();
                let mut res = None;
                while head.remaining() > 0 {
                    res = dec.consume(&mut head);
                }
                assert!(dec.min_remaining() < tail);

                if tail > 1 {
                    assert_eq!(res, None);
                    assert!(dec.min_remaining() > 0);
                    let mut rest = Decoder::from(&encoded[split..]);
                    eprintln!("  split remainder {}: {:?}", split, rest);
                    res = dec.consume(&mut rest);
                    assert_eq!(rest.remaining(), 1);
                }

                assert_eq!(dec.min_remaining(), 0);
                assert_eq!(res.unwrap(), self.value);
            }
        }
    }

    /// Builds a `Vec<UintTestCase>` from `"hex" => value` pairs.
    macro_rules! uint_tc {
        [$( $b:expr => $v:expr ),+ $(,)?] => {
            vec![ $( UintTestCase { hex: String::from($b), value: $v, } ),+]
        };
    }

    #[test]
    fn varint() {
        let cases = uint_tc![
            "00" => 0,
            "01" => 1,
            "3f" => 63,
            "4040" => 64,
            "7fff" => 16383,
            "80004000" => 16384,
            "bfffffff" => (1 << 30) - 1,
            "c000000040000000" => 1 << 30,
            "ffffffffffffffff" => (1 << 62) - 1,
        ];
        cases.iter().for_each(UintTestCase::run);
    }

    /// A zero-length buffer completes immediately without reading anything.
    #[test]
    fn zero_len() {
        let encoded = Encoder::from_hex("ff");
        let mut reader = Decoder::new(&encoded);
        let mut empty = IncrementalDecoderBuffer::new(0);
        assert_eq!(empty.consume(&mut reader), Some(Vec::new()));
        assert_eq!(reader.remaining(), encoded.len());
    }

    /// Skips four bytes split at every position and checks that the trailing
    /// padding byte is left alone.
    #[test]
    fn ignore() {
        let encoded = Encoder::from_hex("12345678ff");
        let template = IncrementalDecoderIgnore::new(4);

        for tail in 1..encoded.len() {
            let split = encoded.len() - tail;
            let mut head = Decoder::from(&encoded[0..split]);
            eprintln!("  split at {}: {:?}", split, head);

            // Each split starts from a fresh copy of the pristine decoder.
            let mut skipper = template.clone();
            let mut done = skipper.consume(&mut head);
            assert_eq!(head.remaining(), 0);
            assert!(skipper.min_remaining() < tail);

            if tail > 1 {
                assert!(!done);
                assert!(skipper.min_remaining() > 0);
                let mut rest = Decoder::from(&encoded[split..]);
                eprintln!("  split remainder {}: {:?}", split, rest);
                done = skipper.consume(&mut rest);
                assert_eq!(rest.remaining(), 1);
            }

            assert_eq!(skipper.min_remaining(), 0);
            assert!(done);
        }
    }
}
270