1 use std::io::{self, Read};
2 
3 use rand::{Rng, RngCore};
4 use std::{cmp, iter};
5 
6 use super::decoder::{DecoderReader, BUF_SIZE};
7 use crate::encode::encode_config_buf;
8 use crate::tests::random_config;
9 use crate::{decode_config_buf, DecodeError, STANDARD};
10 
// Decodes a handful of fixed inputs with the STANDARD config, draining the decoder through an
// output buffer of every size from 1 byte up to the full encoded length.
#[test]
fn simple() {
    let tests: &[(&[u8], &[u8])] = &[
        (&b"0"[..], &b"MA=="[..]),
        (b"01", b"MDE="),
        (b"012", b"MDEy"),
        (b"0123", b"MDEyMw=="),
        (b"01234", b"MDEyMzQ="),
        (b"012345", b"MDEyMzQ1"),
        (b"0123456", b"MDEyMzQ1Ng=="),
        (b"01234567", b"MDEyMzQ1Njc="),
        (b"012345678", b"MDEyMzQ1Njc4"),
        (b"0123456789", b"MDEyMzQ1Njc4OQ=="),
    ][..];

    for (text_expected, base64data) in tests.iter() {
        // Read n bytes at a time.
        for n in 1..=base64data.len() {
            let mut wrapped_reader = io::Cursor::new(base64data);
            let mut decoder = DecoderReader::new(&mut wrapped_reader, STANDARD);

            let mut text_got = Vec::new();
            let mut buffer = vec![0u8; n];
            loop {
                // All inputs here are valid, so any error is a test failure rather than a
                // reason to quietly stop reading.
                let read = decoder
                    .read(&mut buffer[..])
                    .expect("decoding valid base64 must not fail");
                if read == 0 {
                    break;
                }
                text_got.extend_from_slice(&buffer[..read]);
            }

            assert_eq!(
                text_got,
                *text_expected,
                "\nGot: {}\nExpected: {}",
                String::from_utf8_lossy(&text_got[..]),
                String::from_utf8_lossy(text_expected)
            );
        }
    }
}
52 
// Junk after otherwise valid base64 must surface as a read error, whatever the size of the
// output buffer used to drain the decoder.
#[test]
fn trailing_junk() {
    let tests: &[&[u8]] = &[&b"MDEyMzQ1Njc4*!@#$%^&"[..], b"MDEyMzQ1Njc4OQ== "][..];

    for base64data in tests.iter() {
        // Try every output buffer size from a single byte up to the whole input length.
        for chunk_len in 1..=base64data.len() {
            let mut wrapped_reader = io::Cursor::new(base64data);
            let mut decoder = DecoderReader::new(&mut wrapped_reader, STANDARD);
            let mut buffer = vec![0u8; chunk_len];

            // Keep reading until the decoder either fails or reports end of input.
            let saw_error = loop {
                match decoder.read(&mut buffer[..]) {
                    Err(_) => break true,
                    Ok(0) => break false,
                    Ok(_) => {}
                }
            };

            assert!(saw_error);
        }
    }
}
82 
// The decoder must reproduce the original payload even when the reader it wraps hands back
// fewer bytes than requested on each call.
#[test]
fn handles_short_read_from_delegate() {
    let mut rng = rand::thread_rng();
    let mut original = Vec::new();
    let mut encoded = String::new();
    let mut round_tripped = Vec::new();

    for _ in 0..10_000 {
        original.clear();
        encoded.clear();
        round_tripped.clear();

        // random payload of random length
        let size = rng.gen_range(0, 10 * BUF_SIZE);
        original.extend(iter::repeat(0).take(size));
        rng.fill_bytes(&mut original[..]);
        assert_eq!(size, original.len());

        let config = random_config(&mut rng);
        encode_config_buf(&original[..], config, &mut encoded);

        // source of encoded data that only yields a few bytes per read
        let mut cursor = io::Cursor::new(encoded.as_bytes());
        let mut stingy_reader = RandomShortRead {
            delegate: &mut cursor,
            rng: &mut rng,
        };

        let mut decoder = DecoderReader::new(&mut stingy_reader, config);

        let decoded_len = decoder.read_to_end(&mut round_tripped).unwrap();
        assert_eq!(size, decoded_len);
        assert_eq!(&original[..], &round_tripped[..]);
    }
}
117 
// Drains the decoder through randomly sized output slices and checks that the decoded bytes
// match the original payload.
#[test]
fn read_in_short_increments() {
    let mut rng = rand::thread_rng();
    let mut original = Vec::new();
    let mut encoded = String::new();
    let mut scratch = Vec::new();

    for _ in 0..10_000 {
        original.clear();
        encoded.clear();
        scratch.clear();

        let size = rng.gen_range(0, 10 * BUF_SIZE);
        original.extend(iter::repeat(0).take(size));
        // oversized on purpose: the validator requests slices larger than what is left to decode
        scratch.extend(iter::repeat(0).take(size * 3));

        rng.fill_bytes(&mut original[..]);
        assert_eq!(size, original.len());

        let config = random_config(&mut rng);

        encode_config_buf(&original[..], config, &mut encoded);

        let mut cursor = io::Cursor::new(&encoded[..]);
        let mut decoder = DecoderReader::new(&mut cursor, config);

        consume_with_short_reads_and_validate(&mut rng, &original[..], &mut scratch, &mut decoder);
    }
}
148 
// Same as reading in short increments, except the decoder itself is wrapped in a reader that
// randomly shortens every read request before passing it on.
#[test]
fn read_in_short_increments_with_short_delegate_reads() {
    let mut rng = rand::thread_rng();
    let mut original = Vec::new();
    let mut encoded = String::new();
    let mut scratch = Vec::new();

    for _ in 0..10_000 {
        original.clear();
        encoded.clear();
        scratch.clear();

        let size = rng.gen_range(0, 10 * BUF_SIZE);
        original.extend(iter::repeat(0).take(size));
        // oversized on purpose: the validator requests slices larger than what is left to decode
        scratch.extend(iter::repeat(0).take(size * 3));

        rng.fill_bytes(&mut original[..]);
        assert_eq!(size, original.len());

        let config = random_config(&mut rng);

        encode_config_buf(&original[..], config, &mut encoded);

        let mut cursor = io::Cursor::new(&encoded[..]);
        let mut decoder = DecoderReader::new(&mut cursor, config);
        // `rng` is mutably borrowed by the validator call below, so the wrapper gets its own
        // handle.
        let mut wrapper_rng = rand::thread_rng();
        let mut stingy_reader = RandomShortRead {
            delegate: &mut decoder,
            rng: &mut wrapper_rng,
        };

        consume_with_short_reads_and_validate(
            &mut rng,
            &original[..],
            &mut scratch,
            &mut stingy_reader,
        );
    }
}
183 
// Replaces the final symbol of an unpadded encoding with every symbol of the alphabet and checks
// that streaming decode reports the same outcome (success or `DecodeError`) as bulk decode.
#[test]
fn reports_invalid_last_symbol_correctly() {
    let mut rng = rand::thread_rng();
    let mut bytes = Vec::new();
    let mut b64 = String::new();
    let mut b64_bytes = Vec::new();
    let mut decoded = Vec::new();
    let mut bulk_decoded = Vec::new();

    for _ in 0..1_000 {
        bytes.clear();
        b64.clear();
        b64_bytes.clear();

        // at least one input byte so there is always a last symbol to replace
        let size = rng.gen_range(1, 10 * BUF_SIZE);
        bytes.extend(iter::repeat(0).take(size));
        rng.fill_bytes(&mut bytes[..]);
        assert_eq!(size, bytes.len());

        let mut config = random_config(&mut rng);
        // changing padding will cause invalid padding errors when we twiddle the last byte
        config.pad = false;

        encode_config_buf(&bytes[..], config, &mut b64);
        b64_bytes.extend_from_slice(b64.as_bytes());
        assert_eq!(b64_bytes.len(), b64.len());

        // change the last character to every possible symbol. Should behave the same as bulk
        // decoding whether invalid or valid.
        for &s1 in config.char_set.encode_table().iter() {
            // `read_to_end` appends, so both output buffers must start empty each round
            decoded.clear();
            bulk_decoded.clear();

            // replace the last
            *b64_bytes.last_mut().unwrap() = s1;
            let bulk_res = decode_config_buf(&b64_bytes[..], config, &mut bulk_decoded);

            let mut wrapped_reader = io::Cursor::new(&b64_bytes[..]);
            let mut decoder = DecoderReader::new(&mut wrapped_reader, config);

            // reduce the io::Error to the DecodeError it wraps (if any) so it can be compared
            let stream_res = decoder.read_to_end(&mut decoded).map(|_| ()).map_err(|e| {
                e.into_inner()
                    .and_then(|e| e.downcast::<DecodeError>().ok())
            });

            assert_eq!(bulk_res.map_err(|e| Some(Box::new(e))), stream_res);
        }
    }
}
234 
// Corrupts one randomly chosen byte of a valid encoding and checks that streaming decode reports
// the same `DecodeError` as bulk decode, wrapped in an `InvalidData` io error.
#[test]
fn reports_invalid_byte_correctly() {
    let mut rng = rand::thread_rng();
    let mut bytes = Vec::new();
    let mut b64 = String::new();
    let mut decoded = Vec::new();

    for _ in 0..10_000 {
        bytes.clear();
        b64.clear();
        decoded.clear();

        // at least one input byte so the encoded form is never empty
        let size = rng.gen_range(1, 10 * BUF_SIZE);
        bytes.extend(iter::repeat(0).take(size));
        rng.fill_bytes(&mut bytes[..size]);
        assert_eq!(size, bytes.len());

        let config = random_config(&mut rng);
        encode_config_buf(&bytes[..], config, &mut b64);
        // replace one byte, somewhere, with '*', which is invalid
        let bad_byte_pos = rng.gen_range(0, b64.len());
        let mut b64_bytes = b64.bytes().collect::<Vec<u8>>();
        b64_bytes[bad_byte_pos] = b'*';

        // borrow the corrupted bytes; bulk decoding below reads the same buffer
        let mut wrapped_reader = io::Cursor::new(&b64_bytes[..]);
        let mut decoder = DecoderReader::new(&mut wrapped_reader, config);

        // capture the kind before `into_inner` consumes the io::Error, then unwrap the
        // DecodeError it carries (if any)
        let read_decode_err = decoder.read_to_end(&mut decoded).err().and_then(|e| {
            let kind = e.kind();
            e.into_inner()
                .and_then(|inner| inner.downcast::<DecodeError>().ok())
                .map(|decode_err| (*decode_err, kind))
        });

        let mut bulk_buf = Vec::new();
        let bulk_decode_err = decode_config_buf(&b64_bytes[..], config, &mut bulk_buf).err();

        // it's tricky to predict where the invalid data's offset will be since if it's in the last
        // chunk it will be reported at the first padding location because it's treated as invalid
        // padding. So, we just check that it's the same as it is for decoding all at once.
        assert_eq!(
            bulk_decode_err.map(|e| (e, io::ErrorKind::InvalidData)),
            read_decode_err
        );
    }
}
287 
consume_with_short_reads_and_validate<R: Read>( rng: &mut rand::rngs::ThreadRng, expected_bytes: &[u8], decoded: &mut Vec<u8>, short_reader: &mut R, ) -> ()288 fn consume_with_short_reads_and_validate<R: Read>(
289     rng: &mut rand::rngs::ThreadRng,
290     expected_bytes: &[u8],
291     decoded: &mut Vec<u8>,
292     short_reader: &mut R,
293 ) -> () {
294     let mut total_read = 0_usize;
295     loop {
296         assert!(
297             total_read <= expected_bytes.len(),
298             "tr {} size {}",
299             total_read,
300             expected_bytes.len()
301         );
302         if total_read == expected_bytes.len() {
303             assert_eq!(expected_bytes, &decoded[..total_read]);
304             // should be done
305             assert_eq!(0, short_reader.read(&mut decoded[..]).unwrap());
306             // didn't write anything
307             assert_eq!(expected_bytes, &decoded[..total_read]);
308 
309             break;
310         }
311         let decode_len = rng.gen_range(1, cmp::max(2, expected_bytes.len() * 2));
312 
313         let read = short_reader
314             .read(&mut decoded[total_read..total_read + decode_len])
315             .unwrap();
316         total_read += read;
317     }
318 }
319 
320 /// Limits how many bytes a reader will provide in each read call.
321 /// Useful for shaking out code that may work fine only with typical input sources that always fill
322 /// the buffer.
323 struct RandomShortRead<'a, 'b, R: io::Read, N: rand::Rng> {
324     delegate: &'b mut R,
325     rng: &'a mut N,
326 }
327 
328 impl<'a, 'b, R: io::Read, N: rand::Rng> io::Read for RandomShortRead<'a, 'b, R, N> {
read(&mut self, buf: &mut [u8]) -> Result<usize, io::Error>329     fn read(&mut self, buf: &mut [u8]) -> Result<usize, io::Error> {
330         // avoid 0 since it means EOF for non-empty buffers
331         let effective_len = cmp::min(self.rng.gen_range(1, 20), buf.len());
332 
333         self.delegate.read(&mut buf[..effective_len])
334     }
335 }
336