1 use crate::{tables, Config};
2
3 #[cfg(any(feature = "alloc", feature = "std", test))]
4 use crate::STANDARD;
5 #[cfg(any(feature = "alloc", feature = "std", test))]
6 use alloc::vec::Vec;
7 #[cfg(any(feature = "std", test))]
8 use std::error;
9 use core::fmt;
10
// Decode logic operates on chunks of 8 input bytes (base64 symbols) without padding.
const INPUT_CHUNK_LEN: usize = 8;
// 8 symbols * 6 bits = 48 bits = 6 decoded bytes per input chunk.
const DECODED_CHUNK_LEN: usize = 6;
// We read a u64 and write a u64, but a u64 of input only yields 6 bytes of output, so the last
// 2 bytes of any output u64 should not be counted as written to (but must be available in a
// slice).
const DECODED_CHUNK_SUFFIX: usize = 2;

// How many u64's of input to handle at a time in the unrolled fast loop.
const CHUNKS_PER_FAST_LOOP_BLOCK: usize = 4;
// Input bytes consumed per fast-loop iteration: 4 chunks * 8 bytes = 32.
const INPUT_BLOCK_LEN: usize = CHUNKS_PER_FAST_LOOP_BLOCK * INPUT_CHUNK_LEN;
// Output bytes touched per fast-loop iteration: 4 chunks * 6 decoded bytes = 24, plus the
// trailing 2 bytes for the final u64 write (26 total).
const DECODED_BLOCK_LEN: usize =
    CHUNKS_PER_FAST_LOOP_BLOCK * DECODED_CHUNK_LEN + DECODED_CHUNK_SUFFIX;
25
/// Errors that can occur while decoding.
///
/// Every offset carried by a variant is a byte offset into the complete encoded input.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum DecodeError {
    /// An invalid byte was found in the input. The offset and offending byte are provided.
    ///
    /// Besides bytes outside the alphabet, the decoder also reports misplaced `=` padding with
    /// this variant.
    InvalidByte(usize, u8),
    /// The length of the input is invalid.
    ///
    /// The decoder reports this when the input length modulo 8 is 1 or 5: the trailing symbol
    /// would contribute 6 bits, which is not enough to form a byte.
    InvalidLength,
    /// The last non-padding input symbol's encoded 6 bits have nonzero bits that will be discarded.
    /// This is indicative of corrupted or truncated Base64.
    /// Unlike InvalidByte, which reports symbols that aren't in the alphabet, this error is for
    /// symbols that are in the alphabet but represent nonsensical encodings.
    /// The offset and offending symbol are provided.
    InvalidLastSymbol(usize, u8),
}
39
40 impl fmt::Display for DecodeError {
fmt(&self, f: &mut fmt::Formatter) -> fmt::Result41 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
42 match *self {
43 DecodeError::InvalidByte(index, byte) => {
44 write!(f, "Invalid byte {}, offset {}.", byte, index)
45 }
46 DecodeError::InvalidLength => write!(f, "Encoded text cannot have a 6-bit remainder."),
47 DecodeError::InvalidLastSymbol(index, byte) => {
48 write!(f, "Invalid last symbol {}, offset {}.", byte, index)
49 }
50 }
51 }
52 }
53
/// `DecodeError` is a leaf error: it never wraps an underlying cause.
#[cfg(any(feature = "std", test))]
impl error::Error for DecodeError {
    // `description` is deprecated in favour of `Display`, but it is kept so that existing callers
    // of `Error::description` keep getting these short messages instead of the std default text.
    fn description(&self) -> &str {
        match *self {
            DecodeError::InvalidByte(_, _) => "invalid byte",
            DecodeError::InvalidLength => "invalid length",
            DecodeError::InvalidLastSymbol(_, _) => "invalid last symbol",
        }
    }

    // Override `source` (the current API) rather than the deprecated `cause`. The default `cause`
    // forwards to `source`, so both keep returning `None`.
    fn source(&self) -> Option<&(dyn error::Error + 'static)> {
        None
    }
}
68
/// Decode base64 `input` with the standard alphabet and configuration, returning the bytes.
///
/// This is shorthand for `decode_config(input, base64::STANDARD)`.
///
/// # Errors
///
/// Returns a [`DecodeError`] if `input` is not valid base64 for the `STANDARD` configuration.
///
/// # Example
///
/// ```rust
/// extern crate base64;
///
/// fn main() {
///     let bytes = base64::decode("aGVsbG8gd29ybGQ=").unwrap();
///     println!("{:?}", bytes);
/// }
/// ```
#[cfg(any(feature = "alloc", feature = "std", test))]
pub fn decode<T: ?Sized + AsRef<[u8]>>(input: &T) -> Result<Vec<u8>, DecodeError> {
    let encoded: &[u8] = input.as_ref();
    decode_config(encoded, STANDARD)
}
87
/// Decode base64 `input` using `config`, returning the decoded bytes in a newly allocated `Vec`.
///
/// # Errors
///
/// Returns a [`DecodeError`] if `input` is not valid base64 for `config`.
///
/// # Example
///
/// ```rust
/// extern crate base64;
///
/// fn main() {
///     let bytes = base64::decode_config("aGVsbG8gd29ybGR+Cg==", base64::STANDARD).unwrap();
///     println!("{:?}", bytes);
///
///     let bytes_url = base64::decode_config("aGVsbG8gaW50ZXJuZXR-Cg==", base64::URL_SAFE).unwrap();
///     println!("{:?}", bytes_url);
/// }
/// ```
#[cfg(any(feature = "alloc", feature = "std", test))]
pub fn decode_config<T: ?Sized + AsRef<[u8]>>(
    input: &T,
    config: Config,
) -> Result<Vec<u8>, DecodeError> {
    let input_len = input.as_ref().len();
    // Decoding shrinks data (8 symbols -> 6 bytes). `decode_config_buf` immediately resizes the
    // buffer to one 6-byte slot per (possibly partial) 8-symbol chunk, so reserve exactly that,
    // ceil(input_len / 8) * 6, and the resize never reallocates.
    //
    // The previous `input_len * 4 / 3` estimate had two problems:
    // - it reserved roughly 1.8x the memory actually needed;
    // - `input_len * 4` overflowed for inputs above `usize::MAX / 4` bytes.
    //
    // `input_len + 7` cannot overflow, because slice lengths never exceed `isize::MAX`.
    let capacity = (input_len + 7) / 8 * 6;
    let mut buffer = Vec::<u8>::with_capacity(capacity);

    decode_config_buf(input, config, &mut buffer).map(|_| buffer)
}
113
/// Decode base64 `input` using `config`, appending the decoded bytes to `buffer`.
///
/// Existing contents of `buffer` are preserved and the decoded bytes are written after them.
/// Reusing one buffer across calls avoids allocating a new `Vec` for every decode.
///
/// # Errors
///
/// Returns a [`DecodeError`] if `input` is not valid base64 for `config`. In that case `buffer`
/// is truncated back to its original length, so no zero-fill or partially decoded bytes are left
/// behind. Its capacity may still have grown.
///
/// # Panics
///
/// Panics if the required output length overflows `usize`.
///
/// # Example
///
/// ```rust
/// extern crate base64;
///
/// fn main() {
///     let mut buffer = Vec::<u8>::new();
///     base64::decode_config_buf("aGVsbG8gd29ybGR+Cg==", base64::STANDARD, &mut buffer).unwrap();
///     println!("{:?}", buffer);
///
///     buffer.clear();
///
///     base64::decode_config_buf("aGVsbG8gaW50ZXJuZXR-Cg==", base64::URL_SAFE, &mut buffer)
///         .unwrap();
///     println!("{:?}", buffer);
/// }
/// ```
#[cfg(any(feature = "alloc", feature = "std", test))]
pub fn decode_config_buf<T: ?Sized + AsRef<[u8]>>(
    input: &T,
    config: Config,
    buffer: &mut Vec<u8>,
) -> Result<(), DecodeError> {
    let input_bytes = input.as_ref();

    let starting_output_len = buffer.len();

    // Grow the buffer to hold one 6-byte slot per (possibly partial) 8-symbol chunk. This is an
    // upper bound on the decoded length and gives the fast loop room for its 8-byte writes.
    let num_chunks = num_chunks(input_bytes);
    let decoded_len_estimate = num_chunks
        .checked_mul(DECODED_CHUNK_LEN)
        .and_then(|p| p.checked_add(starting_output_len))
        .expect("Overflow when calculating output buffer length");
    buffer.resize(decoded_len_estimate, 0);

    let result = decode_helper(
        input_bytes,
        num_chunks,
        config,
        &mut buffer[starting_output_len..],
    );

    match result {
        Ok(bytes_written) => {
            // Drop the unused tail of the over-sized estimate.
            buffer.truncate(starting_output_len + bytes_written);
            Ok(())
        }
        Err(e) => {
            // Don't leave zero-fill or partial output appended to the caller's data.
            buffer.truncate(starting_output_len);
            Err(e)
        }
    }
}
162
163 /// Decode the input into the provided output slice.
164 ///
165 /// This will not write any bytes past exactly what is decoded (no stray garbage bytes at the end).
166 ///
167 /// If you don't know ahead of time what the decoded length should be, size your buffer with a
168 /// conservative estimate for the decoded length of an input: 3 bytes of output for every 4 bytes of
169 /// input, rounded up, or in other words `(input_len + 3) / 4 * 3`.
170 ///
171 /// If the slice is not large enough, this will panic.
decode_config_slice<T: ?Sized + AsRef<[u8]>>( input: &T, config: Config, output: &mut [u8], ) -> Result<usize, DecodeError>172 pub fn decode_config_slice<T: ?Sized + AsRef<[u8]>>(
173 input: &T,
174 config: Config,
175 output: &mut [u8],
176 ) -> Result<usize, DecodeError> {
177 let input_bytes = input.as_ref();
178
179 decode_helper(input_bytes, num_chunks(input_bytes), config, output)
180 }
181
182 /// Return the number of input chunks (including a possibly partial final chunk) in the input
num_chunks(input: &[u8]) -> usize183 fn num_chunks(input: &[u8]) -> usize {
184 input
185 .len()
186 .checked_add(INPUT_CHUNK_LEN - 1)
187 .expect("Overflow when calculating number of chunks in input")
188 / INPUT_CHUNK_LEN
189 }
190
/// Decode `input` into `output`, returning the number of decoded bytes written.
///
/// Helper to avoid duplicating num_chunks calculation, which is costly on short inputs.
///
/// * `input`: the encoded symbols, optionally ending with `=` padding.
/// * `num_chunks`: the number of 8-symbol chunks in `input`, counting a trailing partial chunk.
///   The callers in this file pass `num_chunks(input)`.
/// * `config`: supplies the alphabet (`config.char_set`) and whether non-zero trailing bits in
///   the last symbol are tolerated (`config.decode_allow_trailing_bits`).
/// * `output`: destination slice. The fast loops write 8 bytes at a time, and the last 2 of those
///   are scratch that later writes overwrite.
///
/// # Errors
///
/// * `DecodeError::InvalidLength` if `input.len() % 8` is 1 or 5.
/// * `DecodeError::InvalidByte` for a symbol not in the alphabet, or for misplaced `=` padding.
/// * `DecodeError::InvalidLastSymbol` if the final symbol carries non-zero bits that would be
///   discarded and `config.decode_allow_trailing_bits` is false.
///
/// # Panics
///
/// Panics (via slice indexing) if `output` is too small for the decoded data.
// We're on the fragile edge of compiler heuristics here. If this is not inlined, slow. If this is
// inlined(always), a different slow. plain ol' inline makes the benchmarks happiest at the moment,
// but this is fragile and the best setting changes with only minor code modifications.
#[inline]
fn decode_helper(
    input: &[u8],
    num_chunks: usize,
    config: Config,
    output: &mut [u8],
) -> Result<usize, DecodeError> {
    let char_set = config.char_set;
    let decode_table = char_set.decode_table();

    let remainder_len = input.len() % INPUT_CHUNK_LEN;

    // Because the fast decode loop reads in groups of 8 bytes (unrolled to
    // CHUNKS_PER_FAST_LOOP_BLOCK times 8 bytes, where possible) and outputs 8 bytes at a time (of
    // which only 6 are valid data), we need to be sure that we stop using the fast decode loop
    // soon enough that there will always be 2 more bytes of valid data written after that loop.
    let trailing_bytes_to_skip = match remainder_len {
        // if input is a multiple of the chunk size, ignore the last chunk as it may have padding,
        // and the fast decode logic cannot handle padding
        0 => INPUT_CHUNK_LEN,
        // 1 and 5 trailing bytes are illegal: can't decode 6 bits of input into a byte
        1 | 5 => return Err(DecodeError::InvalidLength),
        // This will decode to one output byte, which isn't enough to overwrite the 2 extra bytes
        // written by the fast decode loop. So, we have to ignore both these 2 bytes and the
        // previous chunk.
        2 => INPUT_CHUNK_LEN + 2,
        // If this is 3 unpadded chars, then it would actually decode to 2 bytes. However, if this
        // is an erroneous 2 chars + 1 pad char that would decode to 1 byte, then it should fail
        // with an error, not panic from going past the bounds of the output slice, so we let it
        // use stage 3 + 4.
        3 => INPUT_CHUNK_LEN + 3,
        // This can also decode to one output byte because it may be 2 input chars + 2 padding
        // chars, which would decode to 1 byte.
        4 => INPUT_CHUNK_LEN + 4,
        // Everything else is a legal decode len (given that we don't require padding), and will
        // decode to at least 2 bytes of output.
        _ => remainder_len,
    };

    // rounded up to include partial chunks
    let mut remaining_chunks = num_chunks;

    let mut input_index = 0;
    let mut output_index = 0;

    {
        // Input bytes eligible for the fast loops; the skipped tail is handled by stages 3 and 4.
        let length_of_fast_decode_chunks = input.len().saturating_sub(trailing_bytes_to_skip);

        // Fast loop, stage 1
        // manual unroll to CHUNKS_PER_FAST_LOOP_BLOCK of u64s to amortize slice bounds checks
        if let Some(max_start_index) = length_of_fast_decode_chunks.checked_sub(INPUT_BLOCK_LEN) {
            while input_index <= max_start_index {
                let input_slice = &input[input_index..(input_index + INPUT_BLOCK_LEN)];
                let output_slice = &mut output[output_index..(output_index + DECODED_BLOCK_LEN)];

                // Each chunk's 2 trailing scratch bytes are overwritten by the next chunk's
                // output, which starts 6 bytes later.
                decode_chunk(
                    &input_slice[0..],
                    input_index,
                    decode_table,
                    &mut output_slice[0..],
                )?;
                decode_chunk(
                    &input_slice[8..],
                    input_index + 8,
                    decode_table,
                    &mut output_slice[6..],
                )?;
                decode_chunk(
                    &input_slice[16..],
                    input_index + 16,
                    decode_table,
                    &mut output_slice[12..],
                )?;
                decode_chunk(
                    &input_slice[24..],
                    input_index + 24,
                    decode_table,
                    &mut output_slice[18..],
                )?;

                input_index += INPUT_BLOCK_LEN;
                // Advance only past the 24 valid bytes; the 2 scratch bytes get overwritten next.
                output_index += DECODED_BLOCK_LEN - DECODED_CHUNK_SUFFIX;
                remaining_chunks -= CHUNKS_PER_FAST_LOOP_BLOCK;
            }
        }

        // Fast loop, stage 2 (aka still pretty fast loop)
        // 8 bytes at a time for whatever we didn't do in stage 1.
        // Note the strict `<`: a chunk starting exactly at `max_start_index` is left for stage 3.
        if let Some(max_start_index) = length_of_fast_decode_chunks.checked_sub(INPUT_CHUNK_LEN) {
            while input_index < max_start_index {
                decode_chunk(
                    &input[input_index..(input_index + INPUT_CHUNK_LEN)],
                    input_index,
                    decode_table,
                    &mut output
                        [output_index..(output_index + DECODED_CHUNK_LEN + DECODED_CHUNK_SUFFIX)],
                )?;

                output_index += DECODED_CHUNK_LEN;
                input_index += INPUT_CHUNK_LEN;
                remaining_chunks -= 1;
            }
        }
    }

    // Stage 3
    // If input length was such that a chunk had to be deferred until after the fast loop
    // because decoding it would have produced 2 trailing bytes that wouldn't then be
    // overwritten, we decode that chunk here. This way is slower but doesn't write the 2
    // trailing bytes.
    // However, we still need to avoid the last chunk (partial or complete) because it could
    // have padding, so we always do 1 fewer to avoid the last chunk.
    for _ in 1..remaining_chunks {
        decode_chunk_precise(
            &input[input_index..],
            input_index,
            decode_table,
            &mut output[output_index..(output_index + DECODED_CHUNK_LEN)],
        )?;

        input_index += INPUT_CHUNK_LEN;
        output_index += DECODED_CHUNK_LEN;
    }

    // always have one more (possibly partial) block of 8 input
    debug_assert!(input.len() - input_index > 1 || input.is_empty());
    debug_assert!(input.len() - input_index <= 8);

    // Stage 4
    // Finally, decode any leftovers that aren't a complete input block of 8 bytes.
    // Use a u64 as a stack-resident 8 byte buffer.
    let mut leftover_bits: u64 = 0;
    let mut morsels_in_leftover = 0;
    let mut padding_bytes = 0;
    // Offset (relative to `start_of_leftovers`) of the first '=' seen, used for error reporting.
    let mut first_padding_index: usize = 0;
    let mut last_symbol = 0_u8;
    let start_of_leftovers = input_index;
    for (i, b) in input[start_of_leftovers..].iter().enumerate() {
        // '=' padding
        if *b == 0x3D {
            // There can be bad padding in a few ways:
            // 1 - Padding with non-padding characters after it
            // 2 - Padding after zero or one non-padding characters before it
            //     in the current quad.
            // 3 - More than two characters of padding. If 3 or 4 padding chars
            //     are in the same quad, that implies it will be caught by #2.
            //     If it spreads from one quad to another, it will be caught by
            //     #2 in the second quad.

            if i % 4 < 2 {
                // Check for case #2.
                let bad_padding_index = start_of_leftovers
                    + if padding_bytes > 0 {
                        // If we've already seen padding, report the first padding index.
                        // This is to be consistent with the faster logic above: it will report an
                        // error on the first padding character (since it doesn't expect to see
                        // anything but actual encoded data).
                        first_padding_index
                    } else {
                        // haven't seen padding before, just use where we are now
                        i
                    };
                return Err(DecodeError::InvalidByte(bad_padding_index, *b));
            }

            if padding_bytes == 0 {
                first_padding_index = i;
            }

            padding_bytes += 1;
            continue;
        }

        // Check for case #1.
        // To make '=' handling consistent with the main loop, don't allow
        // non-suffix '=' in trailing chunk either. Report error as first
        // erroneous padding.
        if padding_bytes > 0 {
            return Err(DecodeError::InvalidByte(
                start_of_leftovers + first_padding_index,
                0x3D,
            ));
        }
        last_symbol = *b;

        // can use up to 8 * 6 = 48 bits of the u64, if last chunk has no padding.
        // To minimize shifts, pack the leftovers from left to right.
        let shift = 64 - (morsels_in_leftover + 1) * 6;
        // tables are all 256 elements, lookup with a u8 index always succeeds
        let morsel = decode_table[*b as usize];
        if morsel == tables::INVALID_VALUE {
            return Err(DecodeError::InvalidByte(start_of_leftovers + i, *b));
        }

        leftover_bits |= (morsel as u64) << shift;
        morsels_in_leftover += 1;
    }

    // Number of bits that form whole output bytes. Each symbol carries 6 bits:
    // 2 symbols = 12 bits -> 1 byte, 3 = 18 -> 2, 4 = 24 -> 3, 6 = 36 -> 4, 7 = 42 -> 5,
    // 8 = 48 -> 6. Counts of 1 and 5 were rejected as InvalidLength or bad padding above.
    let leftover_bits_ready_to_append = match morsels_in_leftover {
        0 => 0,
        2 => 8,
        3 => 16,
        4 => 24,
        6 => 32,
        7 => 40,
        8 => 48,
        _ => unreachable!(
            "Impossible: must only have 0 to 8 input bytes in last chunk, with no invalid lengths"
        ),
    };

    // if there are bits set outside the bits we care about, last symbol encodes trailing bits that
    // will not be included in the output
    // (bits are packed from the most significant end, so `mask` covers everything below them)
    let mask = !0 >> leftover_bits_ready_to_append;
    if !config.decode_allow_trailing_bits && (leftover_bits & mask) != 0 {
        // last morsel is at `morsels_in_leftover` - 1
        return Err(DecodeError::InvalidLastSymbol(
            start_of_leftovers + morsels_in_leftover - 1,
            last_symbol,
        ));
    }

    // Emit whole bytes from the most significant end: bits 63..56 first, then 55..48, etc.
    let mut leftover_bits_appended_to_buf = 0;
    while leftover_bits_appended_to_buf < leftover_bits_ready_to_append {
        // `as` simply truncates the higher bits, which is what we want here
        let selected_bits = (leftover_bits >> (56 - leftover_bits_appended_to_buf)) as u8;
        output[output_index] = selected_bits;
        output_index += 1;

        leftover_bits_appended_to_buf += 8;
    }

    Ok(output_index)
}
430
/// Store `value` into `output[..8]` in big-endian (most significant byte first) order.
///
/// # Panics
///
/// Panics if `output` is shorter than 8 bytes.
#[inline]
fn write_u64(output: &mut [u8], value: u64) {
    let be_bytes = value.to_be_bytes();
    output[..be_bytes.len()].copy_from_slice(&be_bytes);
}
435
436 /// Decode 8 bytes of input into 6 bytes of output. 8 bytes of output will be written, but only the
437 /// first 6 of those contain meaningful data.
438 ///
439 /// `input` is the bytes to decode, of which the first 8 bytes will be processed.
440 /// `index_at_start_of_input` is the offset in the overall input (used for reporting errors
441 /// accurately)
442 /// `decode_table` is the lookup table for the particular base64 alphabet.
443 /// `output` will have its first 8 bytes overwritten, of which only the first 6 are valid decoded
444 /// data.
445 // yes, really inline (worth 30-50% speedup)
446 #[inline(always)]
decode_chunk( input: &[u8], index_at_start_of_input: usize, decode_table: &[u8; 256], output: &mut [u8], ) -> Result<(), DecodeError>447 fn decode_chunk(
448 input: &[u8],
449 index_at_start_of_input: usize,
450 decode_table: &[u8; 256],
451 output: &mut [u8],
452 ) -> Result<(), DecodeError> {
453 let mut accum: u64;
454
455 let morsel = decode_table[input[0] as usize];
456 if morsel == tables::INVALID_VALUE {
457 return Err(DecodeError::InvalidByte(index_at_start_of_input, input[0]));
458 }
459 accum = (morsel as u64) << 58;
460
461 let morsel = decode_table[input[1] as usize];
462 if morsel == tables::INVALID_VALUE {
463 return Err(DecodeError::InvalidByte(
464 index_at_start_of_input + 1,
465 input[1],
466 ));
467 }
468 accum |= (morsel as u64) << 52;
469
470 let morsel = decode_table[input[2] as usize];
471 if morsel == tables::INVALID_VALUE {
472 return Err(DecodeError::InvalidByte(
473 index_at_start_of_input + 2,
474 input[2],
475 ));
476 }
477 accum |= (morsel as u64) << 46;
478
479 let morsel = decode_table[input[3] as usize];
480 if morsel == tables::INVALID_VALUE {
481 return Err(DecodeError::InvalidByte(
482 index_at_start_of_input + 3,
483 input[3],
484 ));
485 }
486 accum |= (morsel as u64) << 40;
487
488 let morsel = decode_table[input[4] as usize];
489 if morsel == tables::INVALID_VALUE {
490 return Err(DecodeError::InvalidByte(
491 index_at_start_of_input + 4,
492 input[4],
493 ));
494 }
495 accum |= (morsel as u64) << 34;
496
497 let morsel = decode_table[input[5] as usize];
498 if morsel == tables::INVALID_VALUE {
499 return Err(DecodeError::InvalidByte(
500 index_at_start_of_input + 5,
501 input[5],
502 ));
503 }
504 accum |= (morsel as u64) << 28;
505
506 let morsel = decode_table[input[6] as usize];
507 if morsel == tables::INVALID_VALUE {
508 return Err(DecodeError::InvalidByte(
509 index_at_start_of_input + 6,
510 input[6],
511 ));
512 }
513 accum |= (morsel as u64) << 22;
514
515 let morsel = decode_table[input[7] as usize];
516 if morsel == tables::INVALID_VALUE {
517 return Err(DecodeError::InvalidByte(
518 index_at_start_of_input + 7,
519 input[7],
520 ));
521 }
522 accum |= (morsel as u64) << 16;
523
524 write_u64(output, accum);
525
526 Ok(())
527 }
528
529 /// Decode an 8-byte chunk, but only write the 6 bytes actually decoded instead of including 2
530 /// trailing garbage bytes.
531 #[inline]
decode_chunk_precise( input: &[u8], index_at_start_of_input: usize, decode_table: &[u8; 256], output: &mut [u8], ) -> Result<(), DecodeError>532 fn decode_chunk_precise(
533 input: &[u8],
534 index_at_start_of_input: usize,
535 decode_table: &[u8; 256],
536 output: &mut [u8],
537 ) -> Result<(), DecodeError> {
538 let mut tmp_buf = [0_u8; 8];
539
540 decode_chunk(
541 input,
542 index_at_start_of_input,
543 decode_table,
544 &mut tmp_buf[..],
545 )?;
546
547 output[0..6].copy_from_slice(&tmp_buf[0..6]);
548
549 Ok(())
550 }
551
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        encode::encode_config_buf,
        encode::encode_config_slice,
        tests::{assert_encode_sanity, random_config},
    };

    use rand::{
        distributions::{Distribution, Uniform},
        FromEntropy, Rng,
    };

    #[test]
    fn decode_chunk_precise_writes_only_6_bytes() {
        let input = b"Zm9vYmFy"; // "foobar"
        let mut output = [0_u8, 1, 2, 3, 4, 5, 6, 7];
        decode_chunk_precise(&input[..], 0, tables::STANDARD_DECODE, &mut output).unwrap();
        assert_eq!(&vec![b'f', b'o', b'o', b'b', b'a', b'r', 6, 7], &output);
    }

    #[test]
    fn decode_chunk_writes_8_bytes() {
        let input = b"Zm9vYmFy"; // "foobar"
        let mut output = [0_u8, 1, 2, 3, 4, 5, 6, 7];
        decode_chunk(&input[..], 0, tables::STANDARD_DECODE, &mut output).unwrap();
        assert_eq!(&vec![b'f', b'o', b'o', b'b', b'a', b'r', 0, 0], &output);
    }

    #[test]
    fn decode_into_nonempty_vec_doesnt_clobber_existing_prefix() {
        let mut orig_data = Vec::new();
        let mut encoded_data = String::new();
        let mut decoded_with_prefix = Vec::new();
        let mut decoded_without_prefix = Vec::new();
        let mut prefix = Vec::new();

        let prefix_len_range = Uniform::new(0, 1000);
        let input_len_range = Uniform::new(0, 1000);

        let mut rng = rand::rngs::SmallRng::from_entropy();

        for _ in 0..10_000 {
            orig_data.clear();
            encoded_data.clear();
            decoded_with_prefix.clear();
            decoded_without_prefix.clear();
            prefix.clear();

            let input_len = input_len_range.sample(&mut rng);

            for _ in 0..input_len {
                orig_data.push(rng.gen());
            }

            let config = random_config(&mut rng);
            encode_config_buf(&orig_data, config, &mut encoded_data);
            assert_encode_sanity(&encoded_data, config, input_len);

            let prefix_len = prefix_len_range.sample(&mut rng);

            // fill the buf with a prefix
            for _ in 0..prefix_len {
                prefix.push(rng.gen());
            }

            decoded_with_prefix.resize(prefix_len, 0);
            decoded_with_prefix.copy_from_slice(&prefix);

            // decode into the non-empty buf
            decode_config_buf(&encoded_data, config, &mut decoded_with_prefix).unwrap();
            // also decode into the empty buf
            decode_config_buf(&encoded_data, config, &mut decoded_without_prefix).unwrap();

            assert_eq!(
                prefix_len + decoded_without_prefix.len(),
                decoded_with_prefix.len()
            );
            assert_eq!(orig_data, decoded_without_prefix);

            // append plain decode onto prefix
            prefix.append(&mut decoded_without_prefix);

            assert_eq!(prefix, decoded_with_prefix);
        }
    }

    #[test]
    fn decode_into_slice_doesnt_clobber_existing_prefix_or_suffix() {
        let mut orig_data = Vec::new();
        let mut encoded_data = String::new();
        let mut decode_buf = Vec::new();
        let mut decode_buf_copy: Vec<u8> = Vec::new();

        let input_len_range = Uniform::new(0, 1000);

        let mut rng = rand::rngs::SmallRng::from_entropy();

        for _ in 0..10_000 {
            orig_data.clear();
            encoded_data.clear();
            decode_buf.clear();
            decode_buf_copy.clear();

            let input_len = input_len_range.sample(&mut rng);

            for _ in 0..input_len {
                orig_data.push(rng.gen());
            }

            let config = random_config(&mut rng);
            encode_config_buf(&orig_data, config, &mut encoded_data);
            assert_encode_sanity(&encoded_data, config, input_len);

            // fill the buffer with random garbage, long enough to have some room before and after
            for _ in 0..5000 {
                decode_buf.push(rng.gen());
            }

            // keep a copy for later comparison
            decode_buf_copy.extend(decode_buf.iter());

            let offset = 1000;

            // decode into the non-empty buf
            let decode_bytes_written =
                decode_config_slice(&encoded_data, config, &mut decode_buf[offset..]).unwrap();

            assert_eq!(orig_data.len(), decode_bytes_written);
            assert_eq!(
                orig_data,
                &decode_buf[offset..(offset + decode_bytes_written)]
            );
            assert_eq!(&decode_buf_copy[0..offset], &decode_buf[0..offset]);
            assert_eq!(
                &decode_buf_copy[offset + decode_bytes_written..],
                &decode_buf[offset + decode_bytes_written..]
            );
        }
    }

    #[test]
    fn decode_into_slice_fits_in_precisely_sized_slice() {
        let mut orig_data = Vec::new();
        let mut encoded_data = String::new();
        let mut decode_buf = Vec::new();

        let input_len_range = Uniform::new(0, 1000);

        let mut rng = rand::rngs::SmallRng::from_entropy();

        for _ in 0..10_000 {
            orig_data.clear();
            encoded_data.clear();
            decode_buf.clear();

            let input_len = input_len_range.sample(&mut rng);

            for _ in 0..input_len {
                orig_data.push(rng.gen());
            }

            let config = random_config(&mut rng);
            encode_config_buf(&orig_data, config, &mut encoded_data);
            assert_encode_sanity(&encoded_data, config, input_len);

            decode_buf.resize(input_len, 0);

            // decode into the non-empty buf
            let decode_bytes_written =
                decode_config_slice(&encoded_data, config, &mut decode_buf[..]).unwrap();

            assert_eq!(orig_data.len(), decode_bytes_written);
            assert_eq!(orig_data, decode_buf);
        }
    }

    #[test]
    fn detect_invalid_last_symbol_two_bytes() {
        let decode =
            |input, forgiving| decode_config(input, STANDARD.decode_allow_trailing_bits(forgiving));

        // example from https://github.com/marshallpierce/rust-base64/issues/75
        assert!(decode("iYU=", false).is_ok());
        // trailing 01
        assert_eq!(
            Err(DecodeError::InvalidLastSymbol(2, b'V')),
            decode("iYV=", false)
        );
        assert_eq!(Ok(vec![137, 133]), decode("iYV=", true));
        // trailing 10
        assert_eq!(
            Err(DecodeError::InvalidLastSymbol(2, b'W')),
            decode("iYW=", false)
        );
        assert_eq!(Ok(vec![137, 133]), decode("iYW=", true));
        // trailing 11
        assert_eq!(
            Err(DecodeError::InvalidLastSymbol(2, b'X')),
            decode("iYX=", false)
        );
        assert_eq!(Ok(vec![137, 133]), decode("iYX=", true));

        // also works when there are 2 quads in the last block
        assert_eq!(
            Err(DecodeError::InvalidLastSymbol(6, b'X')),
            decode("AAAAiYX=", false)
        );
        assert_eq!(Ok(vec![0, 0, 0, 137, 133]), decode("AAAAiYX=", true));
    }

    #[test]
    fn detect_invalid_last_symbol_one_byte() {
        // 0xFF -> "/w==", so all letters > w, 0-9, and '+', '/' should get InvalidLastSymbol

        assert!(decode("/w==").is_ok());
        // trailing 01
        assert_eq!(Err(DecodeError::InvalidLastSymbol(1, b'x')), decode("/x=="));
        assert_eq!(Err(DecodeError::InvalidLastSymbol(1, b'z')), decode("/z=="));
        assert_eq!(Err(DecodeError::InvalidLastSymbol(1, b'0')), decode("/0=="));
        assert_eq!(Err(DecodeError::InvalidLastSymbol(1, b'9')), decode("/9=="));
        assert_eq!(Err(DecodeError::InvalidLastSymbol(1, b'+')), decode("/+=="));
        assert_eq!(Err(DecodeError::InvalidLastSymbol(1, b'/')), decode("//=="));

        // also works when there are 2 quads in the last block
        assert_eq!(
            Err(DecodeError::InvalidLastSymbol(5, b'x')),
            decode("AAAA/x==")
        );
    }

    #[test]
    fn detect_invalid_last_symbol_every_possible_three_symbols() {
        let mut base64_to_bytes = ::std::collections::HashMap::new();

        let mut bytes = [0_u8; 2];
        for b1 in 0_u16..256 {
            bytes[0] = b1 as u8;
            for b2 in 0_u16..256 {
                bytes[1] = b2 as u8;
                let mut b64 = vec![0_u8; 4];
                assert_eq!(4, encode_config_slice(&bytes, STANDARD, &mut b64[..]));
                let mut v = ::std::vec::Vec::with_capacity(2);
                v.extend_from_slice(&bytes[..]);

                assert!(base64_to_bytes.insert(b64, v).is_none());
            }
        }

        // every possible combination of symbols must either decode to 2 bytes or get InvalidLastSymbol

        let mut symbols = [0_u8; 4];
        for &s1 in STANDARD.char_set.encode_table().iter() {
            symbols[0] = s1;
            for &s2 in STANDARD.char_set.encode_table().iter() {
                symbols[1] = s2;
                for &s3 in STANDARD.char_set.encode_table().iter() {
                    symbols[2] = s3;
                    symbols[3] = b'=';

                    match base64_to_bytes.get(&symbols[..]) {
                        Some(bytes) => {
                            assert_eq!(Ok(bytes.to_vec()), decode_config(&symbols, STANDARD))
                        }
                        None => assert_eq!(
                            Err(DecodeError::InvalidLastSymbol(2, s3)),
                            decode_config(&symbols[..], STANDARD)
                        ),
                    }
                }
            }
        }
    }

    #[test]
    fn detect_invalid_last_symbol_every_possible_two_symbols() {
        let mut base64_to_bytes = ::std::collections::HashMap::new();

        for b in 0_u16..256 {
            let mut b64 = vec![0_u8; 4];
            assert_eq!(4, encode_config_slice(&[b as u8], STANDARD, &mut b64[..]));
            let mut v = ::std::vec::Vec::with_capacity(1);
            v.push(b as u8);

            assert!(base64_to_bytes.insert(b64, v).is_none());
        }

        // every possible combination of symbols must either decode to 1 byte or get InvalidLastSymbol

        let mut symbols = [0_u8; 4];
        for &s1 in STANDARD.char_set.encode_table().iter() {
            symbols[0] = s1;
            for &s2 in STANDARD.char_set.encode_table().iter() {
                symbols[1] = s2;
                symbols[2] = b'=';
                symbols[3] = b'=';

                match base64_to_bytes.get(&symbols[..]) {
                    Some(bytes) => {
                        assert_eq!(Ok(bytes.to_vec()), decode_config(&symbols, STANDARD))
                    }
                    None => assert_eq!(
                        Err(DecodeError::InvalidLastSymbol(1, s2)),
                        decode_config(&symbols[..], STANDARD)
                    ),
                }
            }
        }
    }
}
863