1 use crate::{tables, Config};
2 
3 #[cfg(any(feature = "alloc", feature = "std", test))]
4 use crate::STANDARD;
5 #[cfg(any(feature = "alloc", feature = "std", test))]
6 use alloc::vec::Vec;
7 #[cfg(any(feature = "std", test))]
8 use std::error;
9 use core::fmt;
10 
// The decoder works on unpadded input in chunks of 8 symbols (8 * 6 = 48 bits), and each such
// chunk produces 6 decoded bytes.
const INPUT_CHUNK_LEN: usize = 8;
const DECODED_CHUNK_LEN: usize = 6;
// A chunk is decoded into a u64 and stored whole, so each store touches 8 output bytes even
// though only the first 6 carry data. Those last 2 bytes are not counted as decoded output, but
// the output slice must still have room for them.
const DECODED_CHUNK_SUFFIX: usize = 2;

// The fastest loop consumes this many u64-sized input chunks per iteration.
const CHUNKS_PER_FAST_LOOP_BLOCK: usize = 4;
const INPUT_BLOCK_LEN: usize = INPUT_CHUNK_LEN * CHUNKS_PER_FAST_LOOP_BLOCK;
// Output span touched by one fast-loop iteration, counting the 2 trailing bytes of its final
// u64 store.
const DECODED_BLOCK_LEN: usize =
    DECODED_CHUNK_LEN * CHUNKS_PER_FAST_LOOP_BLOCK + DECODED_CHUNK_SUFFIX;
25 
/// Errors that can occur while decoding.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum DecodeError {
    /// An invalid byte was found in the input. The offset and offending byte are provided.
    InvalidByte(usize, u8),
    /// The length of the input is invalid: it would leave a lone trailing symbol, and 6 bits
    /// cannot encode a whole byte.
    InvalidLength,
    /// The last non-padding input symbol's encoded 6 bits have nonzero bits that will be discarded.
    /// This is indicative of corrupted or truncated Base64.
    /// Unlike InvalidByte, which reports symbols that aren't in the alphabet, this error is for
    /// symbols that are in the alphabet but represent nonsensical encodings.
    /// The offset and offending byte are provided.
    InvalidLastSymbol(usize, u8),
}

impl fmt::Display for DecodeError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match *self {
            DecodeError::InvalidByte(index, byte) => {
                write!(f, "Invalid byte {}, offset {}.", byte, index)
            }
            DecodeError::InvalidLength => write!(f, "Encoded text cannot have a 6-bit remainder."),
            DecodeError::InvalidLastSymbol(index, byte) => {
                write!(f, "Invalid last symbol {}, offset {}.", byte, index)
            }
        }
    }
}

#[cfg(any(feature = "std", test))]
impl error::Error for DecodeError {
    // Kept for callers that still use the deprecated `description()`; `Display` is the
    // preferred way to get a message.
    fn description(&self) -> &str {
        match *self {
            DecodeError::InvalidByte(_, _) => "invalid byte",
            DecodeError::InvalidLength => "invalid length",
            DecodeError::InvalidLastSymbol(_, _) => "invalid last symbol",
        }
    }

    // Decode errors never wrap an underlying error. Overriding `source` (rather than the
    // deprecated `cause`, whose default delegates to `source`) keeps both APIs returning `None`.
    fn source(&self) -> Option<&(dyn error::Error + 'static)> {
        None
    }
}
68 
/// Decode base64 `input` with the standard alphabet and configuration, returning a new `Vec<u8>`.
///
/// This is shorthand for `decode_config(input, base64::STANDARD)`.
///
/// # Errors
///
/// Returns a `DecodeError` if `input` is not valid base64 under `base64::STANDARD`.
///
///# Example
///
///```rust
///extern crate base64;
///
///fn main() {
///    let bytes = base64::decode("aGVsbG8gd29ybGQ=").unwrap();
///    println!("{:?}", bytes);
///}
///```
#[cfg(any(feature = "alloc", feature = "std", test))]
pub fn decode<T: ?Sized + AsRef<[u8]>>(input: &T) -> Result<Vec<u8>, DecodeError> {
    let encoded: &[u8] = input.as_ref();
    decode_config(encoded, STANDARD)
}
87 
///Decode from string reference as octets.
///Returns a Result containing a Vec<u8>.
///
///The returned `Vec` is sized for the decoded data (6 bytes per 8 input bytes, rounded up) rather
///than for the input.
///
///# Errors
///
///Returns a `DecodeError` if `input` is not valid base64 for `config`.
///
///# Example
///
///```rust
///extern crate base64;
///
///fn main() {
///    let bytes = base64::decode_config("aGVsbG8gd29ybGR+Cg==", base64::STANDARD).unwrap();
///    println!("{:?}", bytes);
///
///    let bytes_url = base64::decode_config("aGVsbG8gaW50ZXJuZXR-Cg==", base64::URL_SAFE).unwrap();
///    println!("{:?}", bytes_url);
///}
///```
#[cfg(any(feature = "alloc", feature = "std", test))]
pub fn decode_config<T: ?Sized + AsRef<[u8]>>(
    input: &T,
    config: Config,
) -> Result<Vec<u8>, DecodeError> {
    // No up-front reservation: `decode_config_buf` resizes the buffer to its own decoded-length
    // estimate before decoding, which costs a single allocation from an empty Vec. Reserving by
    // the input length (e.g. `len * 4 / 3`, the *encoding* expansion ratio) would over-allocate
    // ~1.8x, keep that excess capacity in the returned Vec, and `len * 4` can overflow on
    // 32-bit targets.
    let mut buffer = Vec::<u8>::new();

    decode_config_buf(input, config, &mut buffer).map(|()| buffer)
}
113 
///Decode from string reference as octets.
///Writes into the supplied buffer to avoid allocation.
///Returns a Result containing an empty tuple, aka ().
///
///Decoded bytes are appended after whatever `buffer` already contains; existing contents are
///never modified.
///
///# Errors
///
///Returns a `DecodeError` if `input` is not valid base64 for `config`. In that case `buffer` is
///truncated back to the length it had on entry, so no zero-fill or partially decoded bytes are
///left behind (its capacity may still have grown).
///
///# Panics
///
///Panics if computing the required buffer length overflows `usize`.
///
///# Example
///
///```rust
///extern crate base64;
///
///fn main() {
///    let mut buffer = Vec::<u8>::new();
///    base64::decode_config_buf("aGVsbG8gd29ybGR+Cg==", base64::STANDARD, &mut buffer).unwrap();
///    println!("{:?}", buffer);
///
///    buffer.clear();
///
///    base64::decode_config_buf("aGVsbG8gaW50ZXJuZXR-Cg==", base64::URL_SAFE, &mut buffer)
///        .unwrap();
///    println!("{:?}", buffer);
///}
///```
#[cfg(any(feature = "alloc", feature = "std", test))]
pub fn decode_config_buf<T: ?Sized + AsRef<[u8]>>(
    input: &T,
    config: Config,
    buffer: &mut Vec<u8>,
) -> Result<(), DecodeError> {
    let input_bytes = input.as_ref();

    let starting_output_len = buffer.len();

    let num_chunks = num_chunks(input_bytes);
    let decoded_len_estimate = num_chunks
        .checked_mul(DECODED_CHUNK_LEN)
        .and_then(|p| p.checked_add(starting_output_len))
        .expect("Overflow when calculating output buffer length");
    // decode_helper writes through a plain slice, so make room for the worst case up front.
    buffer.resize(decoded_len_estimate, 0);

    let result = decode_helper(
        input_bytes,
        num_chunks,
        config,
        &mut buffer[starting_output_len..],
    );

    match result {
        Ok(bytes_written) => {
            // Drop the unused tail of the estimate.
            buffer.truncate(starting_output_len + bytes_written);
            Ok(())
        }
        Err(e) => {
            // Roll back the zero-fill and any partially decoded bytes so a failed decode does
            // not leave garbage appended to the caller's data.
            buffer.truncate(starting_output_len);
            Err(e)
        }
    }
}
162 
163 /// Decode the input into the provided output slice.
164 ///
165 /// This will not write any bytes past exactly what is decoded (no stray garbage bytes at the end).
166 ///
167 /// If you don't know ahead of time what the decoded length should be, size your buffer with a
168 /// conservative estimate for the decoded length of an input: 3 bytes of output for every 4 bytes of
169 /// input, rounded up, or in other words `(input_len + 3) / 4 * 3`.
170 ///
171 /// If the slice is not large enough, this will panic.
decode_config_slice<T: ?Sized + AsRef<[u8]>>( input: &T, config: Config, output: &mut [u8], ) -> Result<usize, DecodeError>172 pub fn decode_config_slice<T: ?Sized + AsRef<[u8]>>(
173     input: &T,
174     config: Config,
175     output: &mut [u8],
176 ) -> Result<usize, DecodeError> {
177     let input_bytes = input.as_ref();
178 
179     decode_helper(input_bytes, num_chunks(input_bytes), config, output)
180 }
181 
182 /// Return the number of input chunks (including a possibly partial final chunk) in the input
num_chunks(input: &[u8]) -> usize183 fn num_chunks(input: &[u8]) -> usize {
184     input
185         .len()
186         .checked_add(INPUT_CHUNK_LEN - 1)
187         .expect("Overflow when calculating number of chunks in input")
188         / INPUT_CHUNK_LEN
189 }
190 
/// Helper to avoid duplicating num_chunks calculation, which is costly on short inputs.
/// Returns the number of bytes written, or an error.
///
/// `num_chunks` is the number of 8-byte input chunks, counting a trailing partial chunk; the
/// callers in this file pass `num_chunks(input)`. `output` is indexed directly, so this panics
/// if `output` is too small for what gets written.
///
/// Decoding runs in four stages:
/// 1. blocks of CHUNKS_PER_FAST_LOOP_BLOCK chunks via `decode_chunk`, which stores whole u64s;
/// 2. single chunks via `decode_chunk`;
/// 3. single chunks via `decode_chunk_precise`, which stores only the 6 decoded bytes;
/// 4. the final (possibly partial, possibly padded) chunk, one symbol at a time.
///
/// Errors:
/// - `InvalidLength` when `input.len() % 8` is 1 or 5.
/// - `InvalidByte` for a symbol whose table entry is `tables::INVALID_VALUE`, or for `=` padding
///   in a disallowed position within the final chunk.
/// - `InvalidLastSymbol` when the final symbol has nonzero bits that would be discarded and
///   `config.decode_allow_trailing_bits` is false.
// We're on the fragile edge of compiler heuristics here. If this is not inlined, slow. If this is
// inlined(always), a different slow. plain ol' inline makes the benchmarks happiest at the moment,
// but this is fragile and the best setting changes with only minor code modifications.
#[inline]
fn decode_helper(
    input: &[u8],
    num_chunks: usize,
    config: Config,
    output: &mut [u8],
) -> Result<usize, DecodeError> {
    let char_set = config.char_set;
    let decode_table = char_set.decode_table();

    // Length of the final partial chunk; 0 when the input is a whole number of 8-byte chunks.
    let remainder_len = input.len() % INPUT_CHUNK_LEN;

    // Because the fast decode loop writes in groups of 8 bytes (unrolled to
    // CHUNKS_PER_FAST_LOOP_BLOCK times 8 bytes, where possible) and outputs 8 bytes at a time (of
    // which only 6 are valid data), we need to be sure that we stop using the fast decode loop
    // soon enough that there will always be 2 more bytes of valid data written after that loop.
    let trailing_bytes_to_skip = match remainder_len {
        // if input is a multiple of the chunk size, ignore the last chunk as it may have padding,
        // and the fast decode logic cannot handle padding
        0 => INPUT_CHUNK_LEN,
        // 1 and 5 trailing bytes are illegal: can't decode 6 bits of input into a byte
        1 | 5 => return Err(DecodeError::InvalidLength),
        // This will decode to one output byte, which isn't enough to overwrite the 2 extra bytes
        // written by the fast decode loop. So, we have to ignore both these 2 bytes and the
        // previous chunk.
        2 => INPUT_CHUNK_LEN + 2,
        // If this is 3 unpadded chars, then it would actually decode to 2 bytes. However, if this
        // is an erroneous 2 chars + 1 pad char that would decode to 1 byte, then it should fail
        // with an error, not panic from going past the bounds of the output slice, so we let it
        // use stage 3 + 4.
        3 => INPUT_CHUNK_LEN + 3,
        // This can also decode to one output byte because it may be 2 input chars + 2 padding
        // chars, which would decode to 1 byte.
        4 => INPUT_CHUNK_LEN + 4,
        // Everything else is a legal decode len (given that we don't require padding), and will
        // decode to at least 2 bytes of output.
        _ => remainder_len,
    };

    // rounded up to include partial chunks
    // Decremented once per chunk handled by stages 1 and 2; stage 3 reads what is left.
    let mut remaining_chunks = num_chunks;

    let mut input_index = 0;
    let mut output_index = 0;

    {
        // Prefix of `input` that the u64-writing fast loops (stages 1 and 2) may consume.
        let length_of_fast_decode_chunks = input.len().saturating_sub(trailing_bytes_to_skip);

        // Fast loop, stage 1
        // manual unroll to CHUNKS_PER_FAST_LOOP_BLOCK of u64s to amortize slice bounds checks
        // `max_start_index` is the last start offset at which a whole INPUT_BLOCK_LEN block still
        // fits in the fast-decodable prefix; `None` means not even one block fits.
        if let Some(max_start_index) = length_of_fast_decode_chunks.checked_sub(INPUT_BLOCK_LEN) {
            while input_index <= max_start_index {
                let input_slice = &input[input_index..(input_index + INPUT_BLOCK_LEN)];
                let output_slice = &mut output[output_index..(output_index + DECODED_BLOCK_LEN)];

                decode_chunk(
                    &input_slice[0..],
                    input_index,
                    decode_table,
                    &mut output_slice[0..],
                )?;
                decode_chunk(
                    &input_slice[8..],
                    input_index + 8,
                    decode_table,
                    &mut output_slice[6..],
                )?;
                decode_chunk(
                    &input_slice[16..],
                    input_index + 16,
                    decode_table,
                    &mut output_slice[12..],
                )?;
                decode_chunk(
                    &input_slice[24..],
                    input_index + 24,
                    decode_table,
                    &mut output_slice[18..],
                )?;

                input_index += INPUT_BLOCK_LEN;
                // Advance past the valid bytes only, so the next write overwrites the 2 suffix
                // bytes left by this block's last u64 store.
                output_index += DECODED_BLOCK_LEN - DECODED_CHUNK_SUFFIX;
                remaining_chunks -= CHUNKS_PER_FAST_LOOP_BLOCK;
            }
        }

        // Fast loop, stage 2 (aka still pretty fast loop)
        // 8 bytes at a time for whatever we didn't do in stage 1.
        // Note the strict `<`: a chunk starting exactly at `max_start_index` is not decoded here
        // and is left to the later stages.
        if let Some(max_start_index) = length_of_fast_decode_chunks.checked_sub(INPUT_CHUNK_LEN) {
            while input_index < max_start_index {
                decode_chunk(
                    &input[input_index..(input_index + INPUT_CHUNK_LEN)],
                    input_index,
                    decode_table,
                    &mut output
                        [output_index..(output_index + DECODED_CHUNK_LEN + DECODED_CHUNK_SUFFIX)],
                )?;

                output_index += DECODED_CHUNK_LEN;
                input_index += INPUT_CHUNK_LEN;
                remaining_chunks -= 1;
            }
        }
    }

    // Stage 3
    // If input length was such that a chunk had to be deferred until after the fast loop
    // because decoding it would have produced 2 trailing bytes that wouldn't then be
    // overwritten, we decode that chunk here. This way is slower but doesn't write the 2
    // trailing bytes.
    // However, we still need to avoid the last chunk (partial or complete) because it could
    // have padding, so we always do 1 fewer to avoid the last chunk.
    // (`1..remaining_chunks` runs `remaining_chunks - 1` times, or not at all if it is <= 1.)
    for _ in 1..remaining_chunks {
        decode_chunk_precise(
            &input[input_index..],
            input_index,
            decode_table,
            &mut output[output_index..(output_index + DECODED_CHUNK_LEN)],
        )?;

        input_index += INPUT_CHUNK_LEN;
        output_index += DECODED_CHUNK_LEN;
    }

    // always have one more (possibly partial) block of 8 input
    debug_assert!(input.len() - input_index > 1 || input.is_empty());
    debug_assert!(input.len() - input_index <= 8);

    // Stage 4
    // Finally, decode any leftovers that aren't a complete input block of 8 bytes.
    // Use a u64 as a stack-resident 8 byte buffer.
    let mut leftover_bits: u64 = 0;
    // Number of non-padding symbols packed into `leftover_bits` so far.
    let mut morsels_in_leftover = 0;
    let mut padding_bytes = 0;
    // Offset (relative to `start_of_leftovers`) of the first '=' seen; valid once padding_bytes > 0.
    let mut first_padding_index: usize = 0;
    let mut last_symbol = 0_u8;
    let start_of_leftovers = input_index;
    for (i, b) in input[start_of_leftovers..].iter().enumerate() {
        // '=' padding
        if *b == 0x3D {
            // There can be bad padding in a few ways:
            // 1 - Padding with non-padding characters after it
            // 2 - Padding after zero or one non-padding characters before it
            //     in the current quad.
            // 3 - More than two characters of padding. If 3 or 4 padding chars
            //     are in the same quad, that implies it will be caught by #2.
            //     If it spreads from one quad to another, it will be caught by
            //     #2 in the second quad.

            if i % 4 < 2 {
                // Check for case #2.
                let bad_padding_index = start_of_leftovers
                    + if padding_bytes > 0 {
                        // If we've already seen padding, report the first padding index.
                        // This is to be consistent with the faster logic above: it will report an
                        // error on the first padding character (since it doesn't expect to see
                        // anything but actual encoded data).
                        first_padding_index
                    } else {
                        // haven't seen padding before, just use where we are now
                        i
                    };
                return Err(DecodeError::InvalidByte(bad_padding_index, *b));
            }

            if padding_bytes == 0 {
                first_padding_index = i;
            }

            padding_bytes += 1;
            continue;
        }

        // Check for case #1.
        // To make '=' handling consistent with the main loop, don't allow
        // non-suffix '=' in trailing chunk either. Report error as first
        // erroneous padding.
        if padding_bytes > 0 {
            return Err(DecodeError::InvalidByte(
                start_of_leftovers + first_padding_index,
                0x3D,
            ));
        }
        last_symbol = *b;

        // can use up to 8 * 6 = 48 bits of the u64, if last chunk has no padding.
        // To minimize shifts, pack the leftovers from left to right.
        let shift = 64 - (morsels_in_leftover + 1) * 6;
        // tables are all 256 elements, lookup with a u8 index always succeeds
        let morsel = decode_table[*b as usize];
        if morsel == tables::INVALID_VALUE {
            return Err(DecodeError::InvalidByte(start_of_leftovers + i, *b));
        }

        leftover_bits |= (morsel as u64) << shift;
        morsels_in_leftover += 1;
    }

    // Number of bits that form whole output bytes. Each symbol carries 6 bits, so e.g.
    // 2 symbols = 12 bits -> 1 byte (8 bits) and 3 symbols = 18 bits -> 2 bytes (16 bits).
    // 1 and 5 symbols cannot occur: those lengths were rejected above or fail the padding checks.
    let leftover_bits_ready_to_append = match morsels_in_leftover {
        0 => 0,
        2 => 8,
        3 => 16,
        4 => 24,
        6 => 32,
        7 => 40,
        8 => 48,
        _ => unreachable!(
            "Impossible: must only have 0 to 8 input bytes in last chunk, with no invalid lengths"
        ),
    };

    // if there are bits set outside the bits we care about, last symbol encodes trailing bits that
    // will not be included in the output
    // (`mask` is a u64, inferred from `leftover_bits & mask`, covering every bit after the
    // leading `leftover_bits_ready_to_append` bits.)
    let mask = !0 >> leftover_bits_ready_to_append;
    if !config.decode_allow_trailing_bits && (leftover_bits & mask) != 0 {
        // last morsel is at `morsels_in_leftover` - 1
        return Err(DecodeError::InvalidLastSymbol(
            start_of_leftovers + morsels_in_leftover - 1,
            last_symbol,
        ));
    }

    // Emit whole bytes most-significant first: the first byte comes from bits 63..56.
    let mut leftover_bits_appended_to_buf = 0;
    while leftover_bits_appended_to_buf < leftover_bits_ready_to_append {
        // `as` simply truncates the higher bits, which is what we want here
        let selected_bits = (leftover_bits >> (56 - leftover_bits_appended_to_buf)) as u8;
        output[output_index] = selected_bits;
        output_index += 1;

        leftover_bits_appended_to_buf += 8;
    }

    Ok(output_index)
}
430 
/// Store `value` big-endian into the first 8 bytes of `output`, leaving any later bytes alone.
///
/// Panics if `output` is shorter than 8 bytes.
#[inline]
fn write_u64(output: &mut [u8], value: u64) {
    let be_bytes = value.to_be_bytes();
    output[..be_bytes.len()].copy_from_slice(&be_bytes);
}
435 
/// Decode 8 bytes of input into 6 bytes of output. 8 bytes of output will be written, but only the
/// first 6 of those contain meaningful data.
///
/// `input` is the bytes to decode, of which the first 8 bytes will be processed.
/// `index_at_start_of_input` is the offset in the overall input (used for reporting errors
/// accurately)
/// `decode_table` is the lookup table for the particular base64 alphabet.
/// `output` will have its first 8 bytes overwritten, of which only the first 6 are valid decoded
/// data.
///
/// Symbols are checked in order. The first one whose table entry is `tables::INVALID_VALUE`
/// produces `DecodeError::InvalidByte(index_at_start_of_input + i, byte)`, and in that case
/// nothing is written to `output`. Indexing panics if `input` is shorter than 8 bytes (unless an
/// earlier symbol already failed), or if `output` is shorter than 8 bytes on success.
// yes, really inline (worth 30-50% speedup)
#[inline(always)]
fn decode_chunk(
    input: &[u8],
    index_at_start_of_input: usize,
    decode_table: &[u8; 256],
    output: &mut [u8],
) -> Result<(), DecodeError> {
    // Morsels are packed left to right: symbol i is shifted left by 58 - 6 * i. The smallest
    // shift is 16, so bits 15..0 are never set and the 2 trailing output bytes are always 0.
    let mut accum: u64;

    let morsel = decode_table[input[0] as usize];
    if morsel == tables::INVALID_VALUE {
        return Err(DecodeError::InvalidByte(index_at_start_of_input, input[0]));
    }
    accum = (morsel as u64) << 58;

    let morsel = decode_table[input[1] as usize];
    if morsel == tables::INVALID_VALUE {
        return Err(DecodeError::InvalidByte(
            index_at_start_of_input + 1,
            input[1],
        ));
    }
    accum |= (morsel as u64) << 52;

    let morsel = decode_table[input[2] as usize];
    if morsel == tables::INVALID_VALUE {
        return Err(DecodeError::InvalidByte(
            index_at_start_of_input + 2,
            input[2],
        ));
    }
    accum |= (morsel as u64) << 46;

    let morsel = decode_table[input[3] as usize];
    if morsel == tables::INVALID_VALUE {
        return Err(DecodeError::InvalidByte(
            index_at_start_of_input + 3,
            input[3],
        ));
    }
    accum |= (morsel as u64) << 40;

    let morsel = decode_table[input[4] as usize];
    if morsel == tables::INVALID_VALUE {
        return Err(DecodeError::InvalidByte(
            index_at_start_of_input + 4,
            input[4],
        ));
    }
    accum |= (morsel as u64) << 34;

    let morsel = decode_table[input[5] as usize];
    if morsel == tables::INVALID_VALUE {
        return Err(DecodeError::InvalidByte(
            index_at_start_of_input + 5,
            input[5],
        ));
    }
    accum |= (morsel as u64) << 28;

    let morsel = decode_table[input[6] as usize];
    if morsel == tables::INVALID_VALUE {
        return Err(DecodeError::InvalidByte(
            index_at_start_of_input + 6,
            input[6],
        ));
    }
    accum |= (morsel as u64) << 22;

    let morsel = decode_table[input[7] as usize];
    if morsel == tables::INVALID_VALUE {
        return Err(DecodeError::InvalidByte(
            index_at_start_of_input + 7,
            input[7],
        ));
    }
    accum |= (morsel as u64) << 16;

    // Big-endian store: the first decoded byte comes from the top 8 bits of `accum`.
    write_u64(output, accum);

    Ok(())
}
528 
529 /// Decode an 8-byte chunk, but only write the 6 bytes actually decoded instead of including 2
530 /// trailing garbage bytes.
531 #[inline]
decode_chunk_precise( input: &[u8], index_at_start_of_input: usize, decode_table: &[u8; 256], output: &mut [u8], ) -> Result<(), DecodeError>532 fn decode_chunk_precise(
533     input: &[u8],
534     index_at_start_of_input: usize,
535     decode_table: &[u8; 256],
536     output: &mut [u8],
537 ) -> Result<(), DecodeError> {
538     let mut tmp_buf = [0_u8; 8];
539 
540     decode_chunk(
541         input,
542         index_at_start_of_input,
543         decode_table,
544         &mut tmp_buf[..],
545     )?;
546 
547     output[0..6].copy_from_slice(&tmp_buf[0..6]);
548 
549     Ok(())
550 }
551 
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        encode::encode_config_buf,
        encode::encode_config_slice,
        tests::{assert_encode_sanity, random_config},
    };

    use rand::{
        distributions::{Distribution, Uniform},
        FromEntropy, Rng,
    };

    #[test]
    fn decode_chunk_precise_writes_only_6_bytes() {
        let input = b"Zm9vYmFy"; // "foobar"
        let mut output = [0_u8, 1, 2, 3, 4, 5, 6, 7];
        decode_chunk_precise(&input[..], 0, tables::STANDARD_DECODE, &mut output).unwrap();
        assert_eq!(&vec![b'f', b'o', b'o', b'b', b'a', b'r', 6, 7], &output);
    }

    #[test]
    fn decode_chunk_writes_8_bytes() {
        let input = b"Zm9vYmFy"; // "foobar"
        let mut output = [0_u8, 1, 2, 3, 4, 5, 6, 7];
        decode_chunk(&input[..], 0, tables::STANDARD_DECODE, &mut output).unwrap();
        assert_eq!(&vec![b'f', b'o', b'o', b'b', b'a', b'r', 0, 0], &output);
    }

    #[test]
    fn decode_into_nonempty_vec_doesnt_clobber_existing_prefix() {
        let mut orig_data = Vec::new();
        let mut encoded_data = String::new();
        let mut decoded_with_prefix = Vec::new();
        let mut decoded_without_prefix = Vec::new();
        let mut prefix = Vec::new();

        let prefix_len_range = Uniform::new(0, 1000);
        let input_len_range = Uniform::new(0, 1000);

        let mut rng = rand::rngs::SmallRng::from_entropy();

        for _ in 0..10_000 {
            orig_data.clear();
            encoded_data.clear();
            decoded_with_prefix.clear();
            decoded_without_prefix.clear();
            prefix.clear();

            let input_len = input_len_range.sample(&mut rng);

            for _ in 0..input_len {
                orig_data.push(rng.gen());
            }

            let config = random_config(&mut rng);
            encode_config_buf(&orig_data, config, &mut encoded_data);
            assert_encode_sanity(&encoded_data, config, input_len);

            let prefix_len = prefix_len_range.sample(&mut rng);

            // fill the buf with a prefix
            for _ in 0..prefix_len {
                prefix.push(rng.gen());
            }

            decoded_with_prefix.resize(prefix_len, 0);
            decoded_with_prefix.copy_from_slice(&prefix);

            // decode into the non-empty buf
            decode_config_buf(&encoded_data, config, &mut decoded_with_prefix).unwrap();
            // also decode into the empty buf
            decode_config_buf(&encoded_data, config, &mut decoded_without_prefix).unwrap();

            assert_eq!(
                prefix_len + decoded_without_prefix.len(),
                decoded_with_prefix.len()
            );
            assert_eq!(orig_data, decoded_without_prefix);

            // append plain decode onto prefix
            prefix.append(&mut decoded_without_prefix);

            assert_eq!(prefix, decoded_with_prefix);
        }
    }

    #[test]
    fn decode_into_slice_doesnt_clobber_existing_prefix_or_suffix() {
        let mut orig_data = Vec::new();
        let mut encoded_data = String::new();
        let mut decode_buf = Vec::new();
        let mut decode_buf_copy: Vec<u8> = Vec::new();

        let input_len_range = Uniform::new(0, 1000);

        let mut rng = rand::rngs::SmallRng::from_entropy();

        for _ in 0..10_000 {
            orig_data.clear();
            encoded_data.clear();
            decode_buf.clear();
            decode_buf_copy.clear();

            let input_len = input_len_range.sample(&mut rng);

            for _ in 0..input_len {
                orig_data.push(rng.gen());
            }

            let config = random_config(&mut rng);
            encode_config_buf(&orig_data, config, &mut encoded_data);
            assert_encode_sanity(&encoded_data, config, input_len);

            // fill the buffer with random garbage, long enough to have some room before and after
            for _ in 0..5000 {
                decode_buf.push(rng.gen());
            }

            // keep a copy for later comparison
            decode_buf_copy.extend(decode_buf.iter());

            let offset = 1000;

            // decode into the non-empty buf
            let decode_bytes_written =
                decode_config_slice(&encoded_data, config, &mut decode_buf[offset..]).unwrap();

            assert_eq!(orig_data.len(), decode_bytes_written);
            assert_eq!(
                orig_data,
                &decode_buf[offset..(offset + decode_bytes_written)]
            );
            assert_eq!(&decode_buf_copy[0..offset], &decode_buf[0..offset]);
            assert_eq!(
                &decode_buf_copy[offset + decode_bytes_written..],
                &decode_buf[offset + decode_bytes_written..]
            );
        }
    }

    #[test]
    fn decode_into_slice_fits_in_precisely_sized_slice() {
        let mut orig_data = Vec::new();
        let mut encoded_data = String::new();
        let mut decode_buf = Vec::new();

        let input_len_range = Uniform::new(0, 1000);

        let mut rng = rand::rngs::SmallRng::from_entropy();

        for _ in 0..10_000 {
            orig_data.clear();
            encoded_data.clear();
            decode_buf.clear();

            let input_len = input_len_range.sample(&mut rng);

            for _ in 0..input_len {
                orig_data.push(rng.gen());
            }

            let config = random_config(&mut rng);
            encode_config_buf(&orig_data, config, &mut encoded_data);
            assert_encode_sanity(&encoded_data, config, input_len);

            decode_buf.resize(input_len, 0);

            // decode into the non-empty buf
            let decode_bytes_written =
                decode_config_slice(&encoded_data, config, &mut decode_buf[..]).unwrap();

            assert_eq!(orig_data.len(), decode_bytes_written);
            assert_eq!(orig_data, decode_buf);
        }
    }

    #[test]
    fn detect_invalid_last_symbol_two_bytes() {
        let decode =
            |input, forgiving| decode_config(input, STANDARD.decode_allow_trailing_bits(forgiving));

        // example from https://github.com/marshallpierce/rust-base64/issues/75
        assert!(decode("iYU=", false).is_ok());
        // trailing 01
        assert_eq!(
            Err(DecodeError::InvalidLastSymbol(2, b'V')),
            decode("iYV=", false)
        );
        assert_eq!(Ok(vec![137, 133]), decode("iYV=", true));
        // trailing 10
        assert_eq!(
            Err(DecodeError::InvalidLastSymbol(2, b'W')),
            decode("iYW=", false)
        );
        // Each forgiving-mode check must use the same symbol as the strict check above it.
        assert_eq!(Ok(vec![137, 133]), decode("iYW=", true));
        // trailing 11
        assert_eq!(
            Err(DecodeError::InvalidLastSymbol(2, b'X')),
            decode("iYX=", false)
        );
        assert_eq!(Ok(vec![137, 133]), decode("iYX=", true));

        // also works when there are 2 quads in the last block
        assert_eq!(
            Err(DecodeError::InvalidLastSymbol(6, b'X')),
            decode("AAAAiYX=", false)
        );
        assert_eq!(Ok(vec![0, 0, 0, 137, 133]), decode("AAAAiYX=", true));
    }

    #[test]
    fn detect_invalid_last_symbol_one_byte() {
        // 0xFF -> "/w==", so all letters > w, 0-9, and '+', '/' should get InvalidLastSymbol

        assert!(decode("/w==").is_ok());
        // trailing 01
        assert_eq!(Err(DecodeError::InvalidLastSymbol(1, b'x')), decode("/x=="));
        assert_eq!(Err(DecodeError::InvalidLastSymbol(1, b'z')), decode("/z=="));
        assert_eq!(Err(DecodeError::InvalidLastSymbol(1, b'0')), decode("/0=="));
        assert_eq!(Err(DecodeError::InvalidLastSymbol(1, b'9')), decode("/9=="));
        assert_eq!(Err(DecodeError::InvalidLastSymbol(1, b'+')), decode("/+=="));
        assert_eq!(Err(DecodeError::InvalidLastSymbol(1, b'/')), decode("//=="));

        // also works when there are 2 quads in the last block
        assert_eq!(
            Err(DecodeError::InvalidLastSymbol(5, b'x')),
            decode("AAAA/x==")
        );
    }

    #[test]
    fn detect_invalid_last_symbol_every_possible_three_symbols() {
        let mut base64_to_bytes = ::std::collections::HashMap::new();

        let mut bytes = [0_u8; 2];
        for b1 in 0_u16..256 {
            bytes[0] = b1 as u8;
            for b2 in 0_u16..256 {
                bytes[1] = b2 as u8;
                let mut b64 = vec![0_u8; 4];
                assert_eq!(4, encode_config_slice(&bytes, STANDARD, &mut b64[..]));
                let mut v = ::std::vec::Vec::with_capacity(2);
                v.extend_from_slice(&bytes[..]);

                assert!(base64_to_bytes.insert(b64, v).is_none());
            }
        }

        // every possible combination of symbols must either decode to 2 bytes or get InvalidLastSymbol

        let mut symbols = [0_u8; 4];
        for &s1 in STANDARD.char_set.encode_table().iter() {
            symbols[0] = s1;
            for &s2 in STANDARD.char_set.encode_table().iter() {
                symbols[1] = s2;
                for &s3 in STANDARD.char_set.encode_table().iter() {
                    symbols[2] = s3;
                    symbols[3] = b'=';

                    match base64_to_bytes.get(&symbols[..]) {
                        Some(bytes) => {
                            assert_eq!(Ok(bytes.to_vec()), decode_config(&symbols, STANDARD))
                        }
                        None => assert_eq!(
                            Err(DecodeError::InvalidLastSymbol(2, s3)),
                            decode_config(&symbols[..], STANDARD)
                        ),
                    }
                }
            }
        }
    }

    #[test]
    fn detect_invalid_last_symbol_every_possible_two_symbols() {
        let mut base64_to_bytes = ::std::collections::HashMap::new();

        for b in 0_u16..256 {
            let mut b64 = vec![0_u8; 4];
            assert_eq!(4, encode_config_slice(&[b as u8], STANDARD, &mut b64[..]));
            let mut v = ::std::vec::Vec::with_capacity(1);
            v.push(b as u8);

            assert!(base64_to_bytes.insert(b64, v).is_none());
        }

        // every possible combination of symbols must either decode to 1 byte or get InvalidLastSymbol

        let mut symbols = [0_u8; 4];
        for &s1 in STANDARD.char_set.encode_table().iter() {
            symbols[0] = s1;
            for &s2 in STANDARD.char_set.encode_table().iter() {
                symbols[1] = s2;
                symbols[2] = b'=';
                symbols[3] = b'=';

                match base64_to_bytes.get(&symbols[..]) {
                    Some(bytes) => {
                        assert_eq!(Ok(bytes.to_vec()), decode_config(&symbols, STANDARD))
                    }
                    None => assert_eq!(
                        Err(DecodeError::InvalidLastSymbol(1, s2)),
                        decode_config(&symbols[..], STANDARD)
                    ),
                }
            }
        }
    }
}
863