1 use crate::{tables, Config, PAD_BYTE};
2 
3 #[cfg(any(feature = "alloc", feature = "std", test))]
4 use crate::STANDARD;
5 #[cfg(any(feature = "alloc", feature = "std", test))]
6 use alloc::vec::Vec;
7 use core::fmt;
8 #[cfg(any(feature = "std", test))]
9 use std::error;
10 
// Decoding works on chunks of 8 input symbols that contain no padding...
const INPUT_CHUNK_LEN: usize = 8;
// ...and each such chunk yields 6 bytes (8 symbols * 6 bits = 48 bits).
const DECODED_CHUNK_LEN: usize = 6;
// A chunk is decoded into a u64 that is written out whole, so every chunk write touches 2 bytes
// past its 6 bytes of data. Those 2 bytes must exist in the output slice, but they are not
// counted as written.
const DECODED_CHUNK_SUFFIX: usize = 2;

// Number of u64 chunks decoded per iteration of the unrolled fast loop.
const CHUNKS_PER_FAST_LOOP_BLOCK: usize = 4;
const INPUT_BLOCK_LEN: usize = INPUT_CHUNK_LEN * CHUNKS_PER_FAST_LOOP_BLOCK;
// Output bytes touched by one fast-loop iteration: the decoded data plus the trailing 2 bytes of
// the final u64 write.
const DECODED_BLOCK_LEN: usize =
    DECODED_CHUNK_LEN * CHUNKS_PER_FAST_LOOP_BLOCK + DECODED_CHUNK_SUFFIX;
25 
/// The ways in which decoding base64 input can fail.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DecodeError {
    /// A byte that is not usable at its position: either it is not a symbol of the configured
    /// alphabet, or it is padding in a place where padding is not allowed.
    /// Carries the byte's offset in the input and the byte itself.
    InvalidByte(usize, u8),
    /// The input length cannot be decoded: it would leave a lone 6-bit remainder.
    /// Stray trailing whitespace or separator bytes are a common cause.
    /// If the excess trailing bytes produce an invalid length *and* the last byte is also not a
    /// valid base64 symbol (as with whitespace), `InvalidByte` is reported instead of
    /// `InvalidLength`, because it points more directly at the problem.
    InvalidLength,
    /// The last non-padding symbol is in the alphabet, but its 6 bits include nonzero bits that
    /// would be thrown away. This suggests corrupted or truncated base64.
    /// Unlike `InvalidByte`, which flags bytes outside the alphabet, this flags valid symbols
    /// that cannot appear in a canonical encoding. Carries the symbol's offset and the symbol.
    InvalidLastSymbol(usize, u8),
}
43 
44 impl fmt::Display for DecodeError {
fmt(&self, f: &mut fmt::Formatter) -> fmt::Result45     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
46         match *self {
47             DecodeError::InvalidByte(index, byte) => {
48                 write!(f, "Invalid byte {}, offset {}.", byte, index)
49             }
50             DecodeError::InvalidLength => write!(f, "Encoded text cannot have a 6-bit remainder."),
51             DecodeError::InvalidLastSymbol(index, byte) => {
52                 write!(f, "Invalid last symbol {}, offset {}.", byte, index)
53             }
54         }
55     }
56 }
57 
#[cfg(any(feature = "std", test))]
impl error::Error for DecodeError {
    /// Short, static name of the error kind.
    ///
    /// Kept for callers that still use the deprecated `Error::description`. New code should use
    /// the `Display` implementation, which also reports the offending offset and byte.
    fn description(&self) -> &str {
        match *self {
            DecodeError::InvalidByte(_, _) => "invalid byte",
            DecodeError::InvalidLength => "invalid length",
            DecodeError::InvalidLastSymbol(_, _) => "invalid last symbol",
        }
    }

    /// Decoding errors never wrap an underlying error.
    ///
    /// This overrides `source` rather than the deprecated `cause`. The default `cause()` forwards
    /// to `source()`, so both still return `None`.
    fn source(&self) -> Option<&(dyn error::Error + 'static)> {
        None
    }
}
72 
/// Decodes base64 `input` with the `STANDARD` config and returns the decoded bytes.
///
/// This is shorthand for `decode_config(input, base64::STANDARD)`.
///
/// # Errors
///
/// Returns a `DecodeError` if `input` is not valid base64 for the standard config.
///
/// # Example
///
/// ```rust
/// extern crate base64;
///
/// fn main() {
///     let bytes = base64::decode("aGVsbG8gd29ybGQ=").unwrap();
///     println!("{:?}", bytes);
/// }
/// ```
#[cfg(any(feature = "alloc", feature = "std", test))]
pub fn decode<T: AsRef<[u8]>>(input: T) -> Result<Vec<u8>, DecodeError> {
    let standard_config = STANDARD;
    decode_config(input, standard_config)
}
91 
/// Decodes base64 `input` with `config` and returns the decoded bytes in a new `Vec`.
///
/// # Errors
///
/// Returns a `DecodeError` if `input` is not valid base64 for `config`.
///
/// # Example
///
/// ```rust
/// extern crate base64;
///
/// fn main() {
///     let bytes = base64::decode_config("aGVsbG8gd29ybGR+Cg==", base64::STANDARD).unwrap();
///     println!("{:?}", bytes);
///
///     let bytes_url = base64::decode_config("aGVsbG8gaW50ZXJuZXR-Cg==", base64::URL_SAFE).unwrap();
///     println!("{:?}", bytes_url);
/// }
/// ```
#[cfg(any(feature = "alloc", feature = "std", test))]
pub fn decode_config<T: AsRef<[u8]>>(input: T, config: Config) -> Result<Vec<u8>, DecodeError> {
    // Don't pre-size the buffer from the input length. Decoding shrinks data to about 3/4 of the
    // input, so `len * 4 / 3` over-reserved, and the `len * 4` multiplication can overflow for
    // very large inputs. `decode_config_buf` already resizes once, to its own decoded-length
    // estimate.
    let mut buffer = Vec::<u8>::new();

    decode_config_buf(input, config, &mut buffer).map(|_| buffer)
}
114 
/// Decodes base64 `input` with `config` and appends the decoded bytes to `buffer`.
///
/// Existing contents of `buffer` are kept, and the decoded bytes are added after them. This lets
/// a caller reuse one allocation across many decodes.
///
/// # Errors
///
/// Returns a `DecodeError` if `input` is not valid base64 for `config`. In that case `buffer` is
/// truncated back to the length it had on entry, so no partially decoded or scratch bytes are
/// left in it. Its capacity may still have grown.
///
/// # Panics
///
/// Panics if the required buffer length overflows `usize`.
///
/// # Example
///
/// ```rust
/// extern crate base64;
///
/// fn main() {
///     let mut buffer = Vec::<u8>::new();
///     base64::decode_config_buf("aGVsbG8gd29ybGR+Cg==", base64::STANDARD, &mut buffer).unwrap();
///     println!("{:?}", buffer);
///
///     buffer.clear();
///
///     base64::decode_config_buf("aGVsbG8gaW50ZXJuZXR-Cg==", base64::URL_SAFE, &mut buffer)
///         .unwrap();
///     println!("{:?}", buffer);
/// }
/// ```
#[cfg(any(feature = "alloc", feature = "std", test))]
pub fn decode_config_buf<T: AsRef<[u8]>>(
    input: T,
    config: Config,
    buffer: &mut Vec<u8>,
) -> Result<(), DecodeError> {
    let input_bytes = input.as_ref();

    let starting_output_len = buffer.len();

    let num_chunks = num_chunks(input_bytes);
    let decoded_len_estimate = num_chunks
        .checked_mul(DECODED_CHUNK_LEN)
        .and_then(|p| p.checked_add(starting_output_len))
        .expect("Overflow when calculating output buffer length");
    // The decoder writes into a slice, so make room for the worst case up front.
    buffer.resize(decoded_len_estimate, 0);

    let outcome = decode_helper(
        input_bytes,
        num_chunks,
        config,
        &mut buffer[starting_output_len..],
    );

    match outcome {
        Ok(bytes_written) => {
            // Drop the unused tail of the estimate (and any fast-loop scratch bytes in it).
            buffer.truncate(starting_output_len + bytes_written);
            Ok(())
        }
        Err(e) => {
            // Don't leave zero-fill or partially decoded output in the caller's buffer.
            buffer.truncate(starting_output_len);
            Err(e)
        }
    }
}
163 
164 /// Decode the input into the provided output slice.
165 ///
166 /// This will not write any bytes past exactly what is decoded (no stray garbage bytes at the end).
167 ///
168 /// If you don't know ahead of time what the decoded length should be, size your buffer with a
169 /// conservative estimate for the decoded length of an input: 3 bytes of output for every 4 bytes of
170 /// input, rounded up, or in other words `(input_len + 3) / 4 * 3`.
171 ///
172 /// If the slice is not large enough, this will panic.
decode_config_slice<T: AsRef<[u8]>>( input: T, config: Config, output: &mut [u8], ) -> Result<usize, DecodeError>173 pub fn decode_config_slice<T: AsRef<[u8]>>(
174     input: T,
175     config: Config,
176     output: &mut [u8],
177 ) -> Result<usize, DecodeError> {
178     let input_bytes = input.as_ref();
179 
180     decode_helper(input_bytes, num_chunks(input_bytes), config, output)
181 }
182 
183 /// Return the number of input chunks (including a possibly partial final chunk) in the input
num_chunks(input: &[u8]) -> usize184 fn num_chunks(input: &[u8]) -> usize {
185     input
186         .len()
187         .checked_add(INPUT_CHUNK_LEN - 1)
188         .expect("Overflow when calculating number of chunks in input")
189         / INPUT_CHUNK_LEN
190 }
191 
/// Decodes `input` into `output` and returns how many bytes of `output` hold decoded data.
///
/// This is the shared implementation behind `decode_config_buf` and `decode_config_slice`.
/// It takes `num_chunks` (the number of `INPUT_CHUNK_LEN`-byte input chunks, counting a trailing
/// partial chunk as one) as a parameter. That avoids duplicating the `num_chunks` calculation,
/// which is costly on short inputs.
///
/// Decoding runs in four stages:
/// 1. An unrolled loop that decodes `CHUNKS_PER_FAST_LOOP_BLOCK` chunks per iteration.
/// 2. A loop that decodes one chunk per iteration.
/// 3. A slower loop that writes exactly `DECODED_CHUNK_LEN` bytes per chunk. It handles chunks
///    that were held back from the fast loops, except the last chunk.
/// 4. A symbol-at-a-time pass over the last (possibly partial, possibly padded) chunk.
///
/// Stages 1 and 2 write `DECODED_CHUNK_SUFFIX` scratch bytes after each chunk's data. They stop
/// early enough that stages 3 and 4 always overwrite those bytes with real output.
///
/// # Errors
///
/// - `InvalidLength` if `input.len() % INPUT_CHUNK_LEN` is 1 or 5. The exception is when the
///   last byte is neither `PAD_BYTE` nor a symbol in the decode table; then `InvalidByte` is
///   returned for that byte instead.
/// - `InvalidByte` for a byte the decode table marks invalid, or for misplaced padding. Padding
///   errors in the final chunk are reported at the first padding byte.
/// - `InvalidLastSymbol` if the last symbol has nonzero bits that would be discarded, unless
///   `config.decode_allow_trailing_bits` is set.
///
/// # Panics
///
/// Panics (via bounds-checked slice indexing) if `output` is too short for the bytes written.
// We're on the fragile edge of compiler heuristics here. If this is not inlined, slow. If this is
// inlined(always), a different slow. plain ol' inline makes the benchmarks happiest at the moment,
// but this is fragile and the best setting changes with only minor code modifications.
#[inline]
fn decode_helper(
    input: &[u8],
    num_chunks: usize,
    config: Config,
    output: &mut [u8],
) -> Result<usize, DecodeError> {
    let char_set = config.char_set;
    // 256-entry table indexed by input byte: the symbol's 6-bit value, or tables::INVALID_VALUE.
    let decode_table = char_set.decode_table();

    let remainder_len = input.len() % INPUT_CHUNK_LEN;

    // Because the fast decode loop writes in groups of 8 bytes (unrolled to
    // CHUNKS_PER_FAST_LOOP_BLOCK times 8 bytes, where possible) and outputs 8 bytes at a time (of
    // which only 6 are valid data), we need to be sure that we stop using the fast decode loop
    // soon enough that there will always be 2 more bytes of valid data written after that loop.
    let trailing_bytes_to_skip = match remainder_len {
        // if input is a multiple of the chunk size, ignore the last chunk as it may have padding,
        // and the fast decode logic cannot handle padding
        0 => INPUT_CHUNK_LEN,
        // 1 and 5 trailing bytes are illegal: can't decode 6 bits of input into a byte
        1 | 5 => {
            // trailing whitespace is so common that it's worth it to check the last byte to
            // possibly return a better error message
            if let Some(b) = input.last() {
                if *b != PAD_BYTE && decode_table[*b as usize] == tables::INVALID_VALUE {
                    return Err(DecodeError::InvalidByte(input.len() - 1, *b));
                }
            }

            return Err(DecodeError::InvalidLength);
        }
        // This will decode to one output byte, which isn't enough to overwrite the 2 extra bytes
        // written by the fast decode loop. So, we have to ignore both these 2 bytes and the
        // previous chunk.
        2 => INPUT_CHUNK_LEN + 2,
        // If this is 3 unpadded chars, then it would actually decode to 2 bytes. However, if this
        // is an erroneous 2 chars + 1 pad char that would decode to 1 byte, then it should fail
        // with an error, not panic from going past the bounds of the output slice, so we let it
        // use stage 3 + 4.
        3 => INPUT_CHUNK_LEN + 3,
        // This can also decode to one output byte because it may be 2 input chars + 2 padding
        // chars, which would decode to 1 byte.
        4 => INPUT_CHUNK_LEN + 4,
        // Everything else is a legal decode len (given that we don't require padding), and will
        // decode to at least 2 bytes of output.
        _ => remainder_len,
    };

    // rounded up to include partial chunks
    // Decremented as stages 1 and 2 consume chunks; stages 3 and 4 handle whatever is left.
    let mut remaining_chunks = num_chunks;

    let mut input_index = 0;
    let mut output_index = 0;

    {
        // Prefix of the input (in bytes) that the fast loops are allowed to consume.
        let length_of_fast_decode_chunks = input.len().saturating_sub(trailing_bytes_to_skip);

        // Fast loop, stage 1
        // manual unroll to CHUNKS_PER_FAST_LOOP_BLOCK of u64s to amortize slice bounds checks
        if let Some(max_start_index) = length_of_fast_decode_chunks.checked_sub(INPUT_BLOCK_LEN) {
            while input_index <= max_start_index {
                let input_slice = &input[input_index..(input_index + INPUT_BLOCK_LEN)];
                let output_slice = &mut output[output_index..(output_index + DECODED_BLOCK_LEN)];

                decode_chunk(
                    &input_slice[0..],
                    input_index,
                    decode_table,
                    &mut output_slice[0..],
                )?;
                decode_chunk(
                    &input_slice[8..],
                    input_index + 8,
                    decode_table,
                    &mut output_slice[6..],
                )?;
                decode_chunk(
                    &input_slice[16..],
                    input_index + 16,
                    decode_table,
                    &mut output_slice[12..],
                )?;
                decode_chunk(
                    &input_slice[24..],
                    input_index + 24,
                    decode_table,
                    &mut output_slice[18..],
                )?;

                input_index += INPUT_BLOCK_LEN;
                // Advance only past the real data; the 2 scratch bytes after it are overwritten
                // by the next write.
                output_index += DECODED_BLOCK_LEN - DECODED_CHUNK_SUFFIX;
                remaining_chunks -= CHUNKS_PER_FAST_LOOP_BLOCK;
            }
        }

        // Fast loop, stage 2 (aka still pretty fast loop)
        // 8 bytes at a time for whatever we didn't do in stage 1.
        // NOTE(review): the strict `<` leaves a chunk starting exactly at `max_start_index` for
        // stage 3. That is still correct (stage 3 decodes it precisely), just not the fast path.
        if let Some(max_start_index) = length_of_fast_decode_chunks.checked_sub(INPUT_CHUNK_LEN) {
            while input_index < max_start_index {
                decode_chunk(
                    &input[input_index..(input_index + INPUT_CHUNK_LEN)],
                    input_index,
                    decode_table,
                    &mut output
                        [output_index..(output_index + DECODED_CHUNK_LEN + DECODED_CHUNK_SUFFIX)],
                )?;

                output_index += DECODED_CHUNK_LEN;
                input_index += INPUT_CHUNK_LEN;
                remaining_chunks -= 1;
            }
        }
    }

    // Stage 3
    // If input length was such that a chunk had to be deferred until after the fast loop
    // because decoding it would have produced 2 trailing bytes that wouldn't then be
    // overwritten, we decode that chunk here. This way is slower but doesn't write the 2
    // trailing bytes.
    // However, we still need to avoid the last chunk (partial or complete) because it could
    // have padding, so we always do 1 fewer to avoid the last chunk.
    for _ in 1..remaining_chunks {
        decode_chunk_precise(
            &input[input_index..],
            input_index,
            decode_table,
            &mut output[output_index..(output_index + DECODED_CHUNK_LEN)],
        )?;

        input_index += INPUT_CHUNK_LEN;
        output_index += DECODED_CHUNK_LEN;
    }

    // always have one more (possibly partial) block of 8 input
    debug_assert!(input.len() - input_index > 1 || input.is_empty());
    debug_assert!(input.len() - input_index <= 8);

    // Stage 4
    // Finally, decode any leftovers that aren't a complete input block of 8 bytes.
    // Use a u64 as a stack-resident 8 byte buffer.
    let mut leftover_bits: u64 = 0;
    let mut morsels_in_leftover = 0;
    let mut padding_bytes = 0;
    let mut first_padding_index: usize = 0;
    let mut last_symbol = 0_u8;
    let start_of_leftovers = input_index;
    for (i, b) in input[start_of_leftovers..].iter().enumerate() {
        // '=' padding
        if *b == PAD_BYTE {
            // There can be bad padding in a few ways:
            // 1 - Padding with non-padding characters after it
            // 2 - Padding after zero or one non-padding characters before it
            //     in the current quad.
            // 3 - More than two characters of padding. If 3 or 4 padding chars
            //     are in the same quad, that implies it will be caught by #2.
            //     If it spreads from one quad to another, it will be caught by
            //     #2 in the second quad.

            if i % 4 < 2 {
                // Check for case #2.
                let bad_padding_index = start_of_leftovers
                    + if padding_bytes > 0 {
                        // If we've already seen padding, report the first padding index.
                        // This is to be consistent with the faster logic above: it will report an
                        // error on the first padding character (since it doesn't expect to see
                        // anything but actual encoded data).
                        first_padding_index
                    } else {
                        // haven't seen padding before, just use where we are now
                        i
                    };
                return Err(DecodeError::InvalidByte(bad_padding_index, *b));
            }

            if padding_bytes == 0 {
                first_padding_index = i;
            }

            padding_bytes += 1;
            continue;
        }

        // Check for case #1.
        // To make '=' handling consistent with the main loop, don't allow
        // non-suffix '=' in trailing chunk either. Report error as first
        // erroneous padding.
        if padding_bytes > 0 {
            return Err(DecodeError::InvalidByte(
                start_of_leftovers + first_padding_index,
                PAD_BYTE,
            ));
        }
        last_symbol = *b;

        // can use up to 8 * 6 = 48 bits of the u64, if last chunk has no padding.
        // To minimize shifts, pack the leftovers from left to right.
        let shift = 64 - (morsels_in_leftover + 1) * 6;
        // tables are all 256 elements, lookup with a u8 index always succeeds
        let morsel = decode_table[*b as usize];
        if morsel == tables::INVALID_VALUE {
            return Err(DecodeError::InvalidByte(start_of_leftovers + i, *b));
        }

        leftover_bits |= (morsel as u64) << shift;
        morsels_in_leftover += 1;
    }

    // Whole output bits available from N morsels: 6 * N rounded down to a multiple of 8.
    // Counts of 1 and 5 cannot reach here: those lengths are rejected before stage 1, and
    // padding after only one symbol in a quad is rejected in the loop above.
    let leftover_bits_ready_to_append = match morsels_in_leftover {
        0 => 0,
        2 => 8,
        3 => 16,
        4 => 24,
        6 => 32,
        7 => 40,
        8 => 48,
        _ => unreachable!(
            "Impossible: must only have 0 to 8 input bytes in last chunk, with no invalid lengths"
        ),
    };

    // if there are bits set outside the bits we care about, last symbol encodes trailing bits that
    // will not be included in the output
    // `mask` selects every bit below the top `leftover_bits_ready_to_append` bits.
    let mask = !0 >> leftover_bits_ready_to_append;
    if !config.decode_allow_trailing_bits && (leftover_bits & mask) != 0 {
        // last morsel is at `morsels_in_leftover` - 1
        return Err(DecodeError::InvalidLastSymbol(
            start_of_leftovers + morsels_in_leftover - 1,
            last_symbol,
        ));
    }

    // Emit the top `leftover_bits_ready_to_append` bits, most significant byte first.
    let mut leftover_bits_appended_to_buf = 0;
    while leftover_bits_appended_to_buf < leftover_bits_ready_to_append {
        // `as` simply truncates the higher bits, which is what we want here
        let selected_bits = (leftover_bits >> (56 - leftover_bits_appended_to_buf)) as u8;
        output[output_index] = selected_bits;
        output_index += 1;

        leftover_bits_appended_to_buf += 8;
    }

    Ok(output_index)
}
441 
/// Stores `value` into the first 8 bytes of `output` in big-endian order.
///
/// Panics if `output` is shorter than 8 bytes.
#[inline]
fn write_u64(output: &mut [u8], value: u64) {
    let bytes = value.to_be_bytes();
    output[..bytes.len()].copy_from_slice(&bytes);
}
446 
447 /// Decode 8 bytes of input into 6 bytes of output. 8 bytes of output will be written, but only the
448 /// first 6 of those contain meaningful data.
449 ///
450 /// `input` is the bytes to decode, of which the first 8 bytes will be processed.
451 /// `index_at_start_of_input` is the offset in the overall input (used for reporting errors
452 /// accurately)
453 /// `decode_table` is the lookup table for the particular base64 alphabet.
454 /// `output` will have its first 8 bytes overwritten, of which only the first 6 are valid decoded
455 /// data.
456 // yes, really inline (worth 30-50% speedup)
457 #[inline(always)]
decode_chunk( input: &[u8], index_at_start_of_input: usize, decode_table: &[u8; 256], output: &mut [u8], ) -> Result<(), DecodeError>458 fn decode_chunk(
459     input: &[u8],
460     index_at_start_of_input: usize,
461     decode_table: &[u8; 256],
462     output: &mut [u8],
463 ) -> Result<(), DecodeError> {
464     let mut accum: u64;
465 
466     let morsel = decode_table[input[0] as usize];
467     if morsel == tables::INVALID_VALUE {
468         return Err(DecodeError::InvalidByte(index_at_start_of_input, input[0]));
469     }
470     accum = (morsel as u64) << 58;
471 
472     let morsel = decode_table[input[1] as usize];
473     if morsel == tables::INVALID_VALUE {
474         return Err(DecodeError::InvalidByte(
475             index_at_start_of_input + 1,
476             input[1],
477         ));
478     }
479     accum |= (morsel as u64) << 52;
480 
481     let morsel = decode_table[input[2] as usize];
482     if morsel == tables::INVALID_VALUE {
483         return Err(DecodeError::InvalidByte(
484             index_at_start_of_input + 2,
485             input[2],
486         ));
487     }
488     accum |= (morsel as u64) << 46;
489 
490     let morsel = decode_table[input[3] as usize];
491     if morsel == tables::INVALID_VALUE {
492         return Err(DecodeError::InvalidByte(
493             index_at_start_of_input + 3,
494             input[3],
495         ));
496     }
497     accum |= (morsel as u64) << 40;
498 
499     let morsel = decode_table[input[4] as usize];
500     if morsel == tables::INVALID_VALUE {
501         return Err(DecodeError::InvalidByte(
502             index_at_start_of_input + 4,
503             input[4],
504         ));
505     }
506     accum |= (morsel as u64) << 34;
507 
508     let morsel = decode_table[input[5] as usize];
509     if morsel == tables::INVALID_VALUE {
510         return Err(DecodeError::InvalidByte(
511             index_at_start_of_input + 5,
512             input[5],
513         ));
514     }
515     accum |= (morsel as u64) << 28;
516 
517     let morsel = decode_table[input[6] as usize];
518     if morsel == tables::INVALID_VALUE {
519         return Err(DecodeError::InvalidByte(
520             index_at_start_of_input + 6,
521             input[6],
522         ));
523     }
524     accum |= (morsel as u64) << 22;
525 
526     let morsel = decode_table[input[7] as usize];
527     if morsel == tables::INVALID_VALUE {
528         return Err(DecodeError::InvalidByte(
529             index_at_start_of_input + 7,
530             input[7],
531         ));
532     }
533     accum |= (morsel as u64) << 16;
534 
535     write_u64(output, accum);
536 
537     Ok(())
538 }
539 
540 /// Decode an 8-byte chunk, but only write the 6 bytes actually decoded instead of including 2
541 /// trailing garbage bytes.
542 #[inline]
decode_chunk_precise( input: &[u8], index_at_start_of_input: usize, decode_table: &[u8; 256], output: &mut [u8], ) -> Result<(), DecodeError>543 fn decode_chunk_precise(
544     input: &[u8],
545     index_at_start_of_input: usize,
546     decode_table: &[u8; 256],
547     output: &mut [u8],
548 ) -> Result<(), DecodeError> {
549     let mut tmp_buf = [0_u8; 8];
550 
551     decode_chunk(
552         input,
553         index_at_start_of_input,
554         decode_table,
555         &mut tmp_buf[..],
556     )?;
557 
558     output[0..6].copy_from_slice(&tmp_buf[0..6]);
559 
560     Ok(())
561 }
562 
563 #[cfg(test)]
564 mod tests {
565     use super::*;
566     use crate::{
567         encode::encode_config_buf,
568         encode::encode_config_slice,
569         tests::{assert_encode_sanity, random_config},
570     };
571 
572     use rand::{
573         distributions::{Distribution, Uniform},
574         FromEntropy, Rng,
575     };
576 
577     #[test]
decode_chunk_precise_writes_only_6_bytes()578     fn decode_chunk_precise_writes_only_6_bytes() {
579         let input = b"Zm9vYmFy"; // "foobar"
580         let mut output = [0_u8, 1, 2, 3, 4, 5, 6, 7];
581         decode_chunk_precise(&input[..], 0, tables::STANDARD_DECODE, &mut output).unwrap();
582         assert_eq!(&vec![b'f', b'o', b'o', b'b', b'a', b'r', 6, 7], &output);
583     }
584 
585     #[test]
decode_chunk_writes_8_bytes()586     fn decode_chunk_writes_8_bytes() {
587         let input = b"Zm9vYmFy"; // "foobar"
588         let mut output = [0_u8, 1, 2, 3, 4, 5, 6, 7];
589         decode_chunk(&input[..], 0, tables::STANDARD_DECODE, &mut output).unwrap();
590         assert_eq!(&vec![b'f', b'o', b'o', b'b', b'a', b'r', 0, 0], &output);
591     }
592 
593     #[test]
decode_into_nonempty_vec_doesnt_clobber_existing_prefix()594     fn decode_into_nonempty_vec_doesnt_clobber_existing_prefix() {
595         let mut orig_data = Vec::new();
596         let mut encoded_data = String::new();
597         let mut decoded_with_prefix = Vec::new();
598         let mut decoded_without_prefix = Vec::new();
599         let mut prefix = Vec::new();
600 
601         let prefix_len_range = Uniform::new(0, 1000);
602         let input_len_range = Uniform::new(0, 1000);
603 
604         let mut rng = rand::rngs::SmallRng::from_entropy();
605 
606         for _ in 0..10_000 {
607             orig_data.clear();
608             encoded_data.clear();
609             decoded_with_prefix.clear();
610             decoded_without_prefix.clear();
611             prefix.clear();
612 
613             let input_len = input_len_range.sample(&mut rng);
614 
615             for _ in 0..input_len {
616                 orig_data.push(rng.gen());
617             }
618 
619             let config = random_config(&mut rng);
620             encode_config_buf(&orig_data, config, &mut encoded_data);
621             assert_encode_sanity(&encoded_data, config, input_len);
622 
623             let prefix_len = prefix_len_range.sample(&mut rng);
624 
625             // fill the buf with a prefix
626             for _ in 0..prefix_len {
627                 prefix.push(rng.gen());
628             }
629 
630             decoded_with_prefix.resize(prefix_len, 0);
631             decoded_with_prefix.copy_from_slice(&prefix);
632 
633             // decode into the non-empty buf
634             decode_config_buf(&encoded_data, config, &mut decoded_with_prefix).unwrap();
635             // also decode into the empty buf
636             decode_config_buf(&encoded_data, config, &mut decoded_without_prefix).unwrap();
637 
638             assert_eq!(
639                 prefix_len + decoded_without_prefix.len(),
640                 decoded_with_prefix.len()
641             );
642             assert_eq!(orig_data, decoded_without_prefix);
643 
644             // append plain decode onto prefix
645             prefix.append(&mut decoded_without_prefix);
646 
647             assert_eq!(prefix, decoded_with_prefix);
648         }
649     }
650 
651     #[test]
decode_into_slice_doesnt_clobber_existing_prefix_or_suffix()652     fn decode_into_slice_doesnt_clobber_existing_prefix_or_suffix() {
653         let mut orig_data = Vec::new();
654         let mut encoded_data = String::new();
655         let mut decode_buf = Vec::new();
656         let mut decode_buf_copy: Vec<u8> = Vec::new();
657 
658         let input_len_range = Uniform::new(0, 1000);
659 
660         let mut rng = rand::rngs::SmallRng::from_entropy();
661 
662         for _ in 0..10_000 {
663             orig_data.clear();
664             encoded_data.clear();
665             decode_buf.clear();
666             decode_buf_copy.clear();
667 
668             let input_len = input_len_range.sample(&mut rng);
669 
670             for _ in 0..input_len {
671                 orig_data.push(rng.gen());
672             }
673 
674             let config = random_config(&mut rng);
675             encode_config_buf(&orig_data, config, &mut encoded_data);
676             assert_encode_sanity(&encoded_data, config, input_len);
677 
678             // fill the buffer with random garbage, long enough to have some room before and after
679             for _ in 0..5000 {
680                 decode_buf.push(rng.gen());
681             }
682 
683             // keep a copy for later comparison
684             decode_buf_copy.extend(decode_buf.iter());
685 
686             let offset = 1000;
687 
688             // decode into the non-empty buf
689             let decode_bytes_written =
690                 decode_config_slice(&encoded_data, config, &mut decode_buf[offset..]).unwrap();
691 
692             assert_eq!(orig_data.len(), decode_bytes_written);
693             assert_eq!(
694                 orig_data,
695                 &decode_buf[offset..(offset + decode_bytes_written)]
696             );
697             assert_eq!(&decode_buf_copy[0..offset], &decode_buf[0..offset]);
698             assert_eq!(
699                 &decode_buf_copy[offset + decode_bytes_written..],
700                 &decode_buf[offset + decode_bytes_written..]
701             );
702         }
703     }
704 
705     #[test]
decode_into_slice_fits_in_precisely_sized_slice()706     fn decode_into_slice_fits_in_precisely_sized_slice() {
707         let mut orig_data = Vec::new();
708         let mut encoded_data = String::new();
709         let mut decode_buf = Vec::new();
710 
711         let input_len_range = Uniform::new(0, 1000);
712 
713         let mut rng = rand::rngs::SmallRng::from_entropy();
714 
715         for _ in 0..10_000 {
716             orig_data.clear();
717             encoded_data.clear();
718             decode_buf.clear();
719 
720             let input_len = input_len_range.sample(&mut rng);
721 
722             for _ in 0..input_len {
723                 orig_data.push(rng.gen());
724             }
725 
726             let config = random_config(&mut rng);
727             encode_config_buf(&orig_data, config, &mut encoded_data);
728             assert_encode_sanity(&encoded_data, config, input_len);
729 
730             decode_buf.resize(input_len, 0);
731 
732             // decode into the non-empty buf
733             let decode_bytes_written =
734                 decode_config_slice(&encoded_data, config, &mut decode_buf[..]).unwrap();
735 
736             assert_eq!(orig_data.len(), decode_bytes_written);
737             assert_eq!(orig_data, decode_buf);
738         }
739     }
740 
741     #[test]
detect_invalid_last_symbol_two_bytes()742     fn detect_invalid_last_symbol_two_bytes() {
743         let decode =
744             |input, forgiving| decode_config(input, STANDARD.decode_allow_trailing_bits(forgiving));
745 
746         // example from https://github.com/marshallpierce/rust-base64/issues/75
747         assert!(decode("iYU=", false).is_ok());
748         // trailing 01
749         assert_eq!(
750             Err(DecodeError::InvalidLastSymbol(2, b'V')),
751             decode("iYV=", false)
752         );
753         assert_eq!(Ok(vec![137, 133]), decode("iYV=", true));
754         // trailing 10
755         assert_eq!(
756             Err(DecodeError::InvalidLastSymbol(2, b'W')),
757             decode("iYW=", false)
758         );
759         assert_eq!(Ok(vec![137, 133]), decode("iYV=", true));
760         // trailing 11
761         assert_eq!(
762             Err(DecodeError::InvalidLastSymbol(2, b'X')),
763             decode("iYX=", false)
764         );
765         assert_eq!(Ok(vec![137, 133]), decode("iYV=", true));
766 
767         // also works when there are 2 quads in the last block
768         assert_eq!(
769             Err(DecodeError::InvalidLastSymbol(6, b'X')),
770             decode("AAAAiYX=", false)
771         );
772         assert_eq!(Ok(vec![0, 0, 0, 137, 133]), decode("AAAAiYX=", true));
773     }
774 
775     #[test]
detect_invalid_last_symbol_one_byte()776     fn detect_invalid_last_symbol_one_byte() {
777         // 0xFF -> "/w==", so all letters > w, 0-9, and '+', '/' should get InvalidLastSymbol
778 
779         assert!(decode("/w==").is_ok());
780         // trailing 01
781         assert_eq!(Err(DecodeError::InvalidLastSymbol(1, b'x')), decode("/x=="));
782         assert_eq!(Err(DecodeError::InvalidLastSymbol(1, b'z')), decode("/z=="));
783         assert_eq!(Err(DecodeError::InvalidLastSymbol(1, b'0')), decode("/0=="));
784         assert_eq!(Err(DecodeError::InvalidLastSymbol(1, b'9')), decode("/9=="));
785         assert_eq!(Err(DecodeError::InvalidLastSymbol(1, b'+')), decode("/+=="));
786         assert_eq!(Err(DecodeError::InvalidLastSymbol(1, b'/')), decode("//=="));
787 
788         // also works when there are 2 quads in the last block
789         assert_eq!(
790             Err(DecodeError::InvalidLastSymbol(5, b'x')),
791             decode("AAAA/x==")
792         );
793     }
794 
795     #[test]
detect_invalid_last_symbol_every_possible_three_symbols()796     fn detect_invalid_last_symbol_every_possible_three_symbols() {
797         let mut base64_to_bytes = ::std::collections::HashMap::new();
798 
799         let mut bytes = [0_u8; 2];
800         for b1 in 0_u16..256 {
801             bytes[0] = b1 as u8;
802             for b2 in 0_u16..256 {
803                 bytes[1] = b2 as u8;
804                 let mut b64 = vec![0_u8; 4];
805                 assert_eq!(4, encode_config_slice(&bytes, STANDARD, &mut b64[..]));
806                 let mut v = ::std::vec::Vec::with_capacity(2);
807                 v.extend_from_slice(&bytes[..]);
808 
809                 assert!(base64_to_bytes.insert(b64, v).is_none());
810             }
811         }
812 
813         // every possible combination of symbols must either decode to 2 bytes or get InvalidLastSymbol
814 
815         let mut symbols = [0_u8; 4];
816         for &s1 in STANDARD.char_set.encode_table().iter() {
817             symbols[0] = s1;
818             for &s2 in STANDARD.char_set.encode_table().iter() {
819                 symbols[1] = s2;
820                 for &s3 in STANDARD.char_set.encode_table().iter() {
821                     symbols[2] = s3;
822                     symbols[3] = PAD_BYTE;
823 
824                     match base64_to_bytes.get(&symbols[..]) {
825                         Some(bytes) => {
826                             assert_eq!(Ok(bytes.to_vec()), decode_config(&symbols, STANDARD))
827                         }
828                         None => assert_eq!(
829                             Err(DecodeError::InvalidLastSymbol(2, s3)),
830                             decode_config(&symbols[..], STANDARD)
831                         ),
832                     }
833                 }
834             }
835         }
836     }
837 
838     #[test]
detect_invalid_last_symbol_every_possible_two_symbols()839     fn detect_invalid_last_symbol_every_possible_two_symbols() {
840         let mut base64_to_bytes = ::std::collections::HashMap::new();
841 
842         for b in 0_u16..256 {
843             let mut b64 = vec![0_u8; 4];
844             assert_eq!(4, encode_config_slice(&[b as u8], STANDARD, &mut b64[..]));
845             let mut v = ::std::vec::Vec::with_capacity(1);
846             v.push(b as u8);
847 
848             assert!(base64_to_bytes.insert(b64, v).is_none());
849         }
850 
851         // every possible combination of symbols must either decode to 1 byte or get InvalidLastSymbol
852 
853         let mut symbols = [0_u8; 4];
854         for &s1 in STANDARD.char_set.encode_table().iter() {
855             symbols[0] = s1;
856             for &s2 in STANDARD.char_set.encode_table().iter() {
857                 symbols[1] = s2;
858                 symbols[2] = PAD_BYTE;
859                 symbols[3] = PAD_BYTE;
860 
861                 match base64_to_bytes.get(&symbols[..]) {
862                     Some(bytes) => {
863                         assert_eq!(Ok(bytes.to_vec()), decode_config(&symbols, STANDARD))
864                     }
865                     None => assert_eq!(
866                         Err(DecodeError::InvalidLastSymbol(1, s2)),
867                         decode_config(&symbols[..], STANDARD)
868                     ),
869                 }
870             }
871         }
872     }
873 }
874