1 use byteorder::{BigEndian, ByteOrder};
2 use {tables, Config, STANDARD};
3
4 use std::{error, fmt, str};
5
// The decoder consumes input in chunks of 8 base64 symbols; the fast path requires such chunks
// to be free of padding. Each chunk carries 8 * 6 = 48 bits, i.e. 6 decoded bytes.
const INPUT_CHUNK_LEN: usize = 8;
const DECODED_CHUNK_LEN: usize = 6;
// A chunk is assembled in a u64 and stored with one 8-byte write, but only the first 6 of those
// bytes are real output. The remaining 2 bytes are scratch: they must exist in the output slice,
// yet they are not counted as written.
const DECODED_CHUNK_SUFFIX: usize = 2;

// Number of 8-symbol chunks handled per iteration of the unrolled fast loop.
const CHUNKS_PER_FAST_LOOP_BLOCK: usize = 4;
// Input symbols consumed per fast-loop iteration.
const INPUT_BLOCK_LEN: usize = INPUT_CHUNK_LEN * CHUNKS_PER_FAST_LOOP_BLOCK;
// Output bytes touched per fast-loop iteration: the decoded bytes of every chunk, plus the 2
// scratch bytes left behind by the final chunk's 8-byte write.
const DECODED_BLOCK_LEN: usize =
    DECODED_CHUNK_SUFFIX + DECODED_CHUNK_LEN * CHUNKS_PER_FAST_LOOP_BLOCK;
20
/// Errors that can occur while decoding base64 input.
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum DecodeError {
    /// A byte that cannot appear at its position was found: either a symbol outside the alphabet
    /// or a misplaced `=` padding byte. Carries the offset of the offending byte within the
    /// input, followed by the byte itself.
    InvalidByte(usize, u8),
    /// The length of the input is invalid, e.g. it would leave a lone trailing symbol (6 bits)
    /// that cannot form a whole byte.
    InvalidLength,
    /// The last non-padding input symbol's encoded 6 bits have nonzero bits that will be
    /// discarded. This is indicative of corrupted or truncated Base64.
    ///
    /// Unlike `InvalidByte`, which reports symbols that aren't in the alphabet, this error is for
    /// symbols that are in the alphabet but represent nonsensical encodings. Carries the offset
    /// of that symbol within the input, followed by the symbol itself.
    InvalidLastSymbol(usize, u8),
}
34
35 impl fmt::Display for DecodeError {
fmt(&self, f: &mut fmt::Formatter) -> fmt::Result36 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
37 match *self {
38 DecodeError::InvalidByte(index, byte) => {
39 write!(f, "Invalid byte {}, offset {}.", byte, index)
40 }
41 DecodeError::InvalidLength => write!(f, "Encoded text cannot have a 6-bit remainder."),
42 DecodeError::InvalidLastSymbol(index, byte) => {
43 write!(f, "Invalid last symbol {}, offset {}.", byte, index)
44 }
45 }
46 }
47 }
48
49 impl error::Error for DecodeError {
description(&self) -> &str50 fn description(&self) -> &str {
51 match *self {
52 DecodeError::InvalidByte(_, _) => "invalid byte",
53 DecodeError::InvalidLength => "invalid length",
54 DecodeError::InvalidLastSymbol(_, _) => "invalid last symbol",
55 }
56 }
57
cause(&self) -> Option<&error::Error>58 fn cause(&self) -> Option<&error::Error> {
59 None
60 }
61 }
62
63 ///Decode from string reference as octets.
64 ///Returns a Result containing a Vec<u8>.
65 ///Convenience `decode_config(input, base64::STANDARD);`.
66 ///
67 ///# Example
68 ///
69 ///```rust
70 ///extern crate base64;
71 ///
72 ///fn main() {
73 /// let bytes = base64::decode("aGVsbG8gd29ybGQ=").unwrap();
74 /// println!("{:?}", bytes);
75 ///}
76 ///```
decode<T: ?Sized + AsRef<[u8]>>(input: &T) -> Result<Vec<u8>, DecodeError>77 pub fn decode<T: ?Sized + AsRef<[u8]>>(input: &T) -> Result<Vec<u8>, DecodeError> {
78 decode_config(input, STANDARD)
79 }
80
81 ///Decode from string reference as octets.
82 ///Returns a Result containing a Vec<u8>.
83 ///
84 ///# Example
85 ///
86 ///```rust
87 ///extern crate base64;
88 ///
89 ///fn main() {
90 /// let bytes = base64::decode_config("aGVsbG8gd29ybGR+Cg==", base64::STANDARD).unwrap();
91 /// println!("{:?}", bytes);
92 ///
93 /// let bytes_url = base64::decode_config("aGVsbG8gaW50ZXJuZXR-Cg==", base64::URL_SAFE).unwrap();
94 /// println!("{:?}", bytes_url);
95 ///}
96 ///```
decode_config<T: ?Sized + AsRef<[u8]>>( input: &T, config: Config, ) -> Result<Vec<u8>, DecodeError>97 pub fn decode_config<T: ?Sized + AsRef<[u8]>>(
98 input: &T,
99 config: Config,
100 ) -> Result<Vec<u8>, DecodeError> {
101 let mut buffer = Vec::<u8>::with_capacity(input.as_ref().len() * 4 / 3);
102
103 decode_config_buf(input, config, &mut buffer).map(|_| buffer)
104 }
105
106 ///Decode from string reference as octets.
107 ///Writes into the supplied buffer to avoid allocation.
108 ///Returns a Result containing an empty tuple, aka ().
109 ///
110 ///# Example
111 ///
112 ///```rust
113 ///extern crate base64;
114 ///
115 ///fn main() {
116 /// let mut buffer = Vec::<u8>::new();
117 /// base64::decode_config_buf("aGVsbG8gd29ybGR+Cg==", base64::STANDARD, &mut buffer).unwrap();
118 /// println!("{:?}", buffer);
119 ///
120 /// buffer.clear();
121 ///
122 /// base64::decode_config_buf("aGVsbG8gaW50ZXJuZXR-Cg==", base64::URL_SAFE, &mut buffer)
123 /// .unwrap();
124 /// println!("{:?}", buffer);
125 ///}
126 ///```
decode_config_buf<T: ?Sized + AsRef<[u8]>>( input: &T, config: Config, buffer: &mut Vec<u8>, ) -> Result<(), DecodeError>127 pub fn decode_config_buf<T: ?Sized + AsRef<[u8]>>(
128 input: &T,
129 config: Config,
130 buffer: &mut Vec<u8>,
131 ) -> Result<(), DecodeError> {
132 let input_bytes = input.as_ref();
133
134 let starting_output_len = buffer.len();
135
136 let num_chunks = num_chunks(input_bytes);
137 let decoded_len_estimate = num_chunks
138 .checked_mul(DECODED_CHUNK_LEN)
139 .and_then(|p| p.checked_add(starting_output_len))
140 .expect("Overflow when calculating output buffer length");
141 buffer.resize(decoded_len_estimate, 0);
142
143 let bytes_written;
144 {
145 let buffer_slice = &mut buffer.as_mut_slice()[starting_output_len..];
146 bytes_written = decode_helper(input_bytes, num_chunks, config, buffer_slice)?;
147 }
148
149 buffer.truncate(starting_output_len + bytes_written);
150
151 Ok(())
152 }
153
154 /// Decode the input into the provided output slice.
155 ///
156 /// This will not write any bytes past exactly what is decoded (no stray garbage bytes at the end).
157 ///
158 /// If you don't know ahead of time what the decoded length should be, size your buffer with a
159 /// conservative estimate for the decoded length of an input: 3 bytes of output for every 4 bytes of
160 /// input, rounded up, or in other words `(input_len + 3) / 4 * 3`.
161 ///
162 /// If the slice is not large enough, this will panic.
decode_config_slice<T: ?Sized + AsRef<[u8]>>( input: &T, config: Config, output: &mut [u8], ) -> Result<usize, DecodeError>163 pub fn decode_config_slice<T: ?Sized + AsRef<[u8]>>(
164 input: &T,
165 config: Config,
166 output: &mut [u8],
167 ) -> Result<usize, DecodeError> {
168 let input_bytes = input.as_ref();
169
170 decode_helper(
171 input_bytes,
172 num_chunks(input_bytes),
173 config,
174 output,
175 )
176 }
177
178 /// Return the number of input chunks (including a possibly partial final chunk) in the input
num_chunks(input: &[u8]) -> usize179 fn num_chunks(input: &[u8]) -> usize {
180 input
181 .len()
182 .checked_add(INPUT_CHUNK_LEN - 1)
183 .expect("Overflow when calculating number of chunks in input")
184 / INPUT_CHUNK_LEN
185 }
186
/// Helper to avoid duplicating num_chunks calculation, which is costly on short inputs.
/// Returns the number of bytes written, or an error.
///
/// Decodes all of `input` into the front of `output`. It uses the decode table of `config`'s
/// character set and honors `config.decode_allow_trailing_bits`. `num_chunks` is expected to
/// equal `num_chunks(input)`; it is passed in so callers that already computed it don't pay twice.
///
/// Decoding proceeds in four stages:
/// 1. an unrolled fast loop over blocks of `CHUNKS_PER_FAST_LOOP_BLOCK` 8-symbol chunks;
/// 2. a fast loop over single 8-symbol chunks;
/// 3. a slower loop (`decode_chunk_precise`) over remaining full chunks, which does not write the
///    2 scratch bytes;
/// 4. a symbol-by-symbol pass over the final (possibly partial, possibly padded) chunk.
///
/// Stages 1 and 2 write 8 bytes per chunk, of which only 6 are data. `trailing_bytes_to_skip`
/// keeps enough input back from those stages that later stages always write real data over the
/// 2 scratch bytes.
///
/// # Errors
///
/// - `InvalidLength` if `input.len() % 8` is 1 or 5.
/// - `InvalidByte` for a symbol not in the decode table, or for misplaced `=` padding.
/// - `InvalidLastSymbol` if the last non-padding symbol has nonzero bits that would be discarded
///   and `config.decode_allow_trailing_bits` is false.
///
/// # Panics
///
/// Panics (slice index out of bounds) if `output` is too small for the decoded data.
// We're on the fragile edge of compiler heuristics here. If this is not inlined, slow. If this is
// inlined(always), a different slow. plain ol' inline makes the benchmarks happiest at the moment,
// but this is fragile and the best setting changes with only minor code modifications.
#[inline]
fn decode_helper(
    input: &[u8],
    num_chunks: usize,
    config: Config,
    output: &mut [u8],
) -> Result<usize, DecodeError> {
    let char_set = config.char_set;
    let decode_table = char_set.decode_table();

    let remainder_len = input.len() % INPUT_CHUNK_LEN;

    // The fast decode loops read input 8 symbols at a time (unrolled to
    // CHUNKS_PER_FAST_LOOP_BLOCK chunks where possible). For each chunk they write 8 output
    // bytes, of which only 6 are valid data. We therefore have to stop using the fast loops
    // early enough that at least 2 more bytes of valid data are always written after them,
    // overwriting the scratch bytes. This value is how many trailing input bytes to hold back.
    let trailing_bytes_to_skip = match remainder_len {
        // if input is a multiple of the chunk size, ignore the last chunk as it may have padding,
        // and the fast decode logic cannot handle padding
        0 => INPUT_CHUNK_LEN,
        // 1 and 5 trailing bytes are illegal: can't decode 6 bits of input into a byte
        1 | 5 => return Err(DecodeError::InvalidLength),
        // This will decode to one output byte, which isn't enough to overwrite the 2 extra bytes
        // written by the fast decode loop. So, we have to ignore both these 2 bytes and the
        // previous chunk.
        2 => INPUT_CHUNK_LEN + 2,
        // If this is 3 unpadded chars, then it would actually decode to 2 bytes. However, if this
        // is an erroneous 2 chars + 1 pad char that would decode to 1 byte, then it should fail
        // with an error, not panic from going past the bounds of the output slice, so we let it
        // use stage 3 + 4.
        3 => INPUT_CHUNK_LEN + 3,
        // This can also decode to one output byte because it may be 2 input chars + 2 padding
        // chars, which would decode to 1 byte.
        4 => INPUT_CHUNK_LEN + 4,
        // Everything else is a legal decode len (given that we don't require padding), and will
        // decode to at least 2 bytes of output.
        _ => remainder_len,
    };

    // rounded up to include partial chunks
    let mut remaining_chunks = num_chunks;

    // input_index only ever advances in whole chunks, so it stays a multiple of INPUT_CHUNK_LEN.
    let mut input_index = 0;
    let mut output_index = 0;

    {
        // Input prefix that is safe to hand to the fast (8-byte-writing) loops.
        let length_of_fast_decode_chunks = input.len().saturating_sub(trailing_bytes_to_skip);

        // Fast loop, stage 1
        // manual unroll to CHUNKS_PER_FAST_LOOP_BLOCK of u64s to amortize slice bounds checks
        if let Some(max_start_index) = length_of_fast_decode_chunks.checked_sub(INPUT_BLOCK_LEN) {
            while input_index <= max_start_index {
                let input_slice = &input[input_index..(input_index + INPUT_BLOCK_LEN)];
                // Covers 4 * 6 data bytes plus the 2 scratch bytes of the last chunk's write.
                let output_slice = &mut output[output_index..(output_index + DECODED_BLOCK_LEN)];

                decode_chunk(
                    &input_slice[0..],
                    input_index,
                    decode_table,
                    &mut output_slice[0..],
                )?;
                decode_chunk(
                    &input_slice[8..],
                    input_index + 8,
                    decode_table,
                    &mut output_slice[6..],
                )?;
                decode_chunk(
                    &input_slice[16..],
                    input_index + 16,
                    decode_table,
                    &mut output_slice[12..],
                )?;
                decode_chunk(
                    &input_slice[24..],
                    input_index + 24,
                    decode_table,
                    &mut output_slice[18..],
                )?;

                input_index += INPUT_BLOCK_LEN;
                // Advance only past the real data; the 2 scratch bytes get overwritten next.
                output_index += DECODED_BLOCK_LEN - DECODED_CHUNK_SUFFIX;
                remaining_chunks -= CHUNKS_PER_FAST_LOOP_BLOCK;
            }
        }

        // Fast loop, stage 2 (aka still pretty fast loop)
        // 8 bytes at a time for whatever we didn't do in stage 1.
        // NOTE(review): `<` (unlike stage 1's `<=`) leaves the chunk starting exactly at
        // max_start_index to stage 3. That looks like a perf-only difference, not a correctness
        // issue — confirm before changing, given the fragile tuning noted above.
        if let Some(max_start_index) = length_of_fast_decode_chunks.checked_sub(INPUT_CHUNK_LEN) {
            while input_index < max_start_index {
                decode_chunk(
                    &input[input_index..(input_index + INPUT_CHUNK_LEN)],
                    input_index,
                    decode_table,
                    &mut output
                        [output_index..(output_index + DECODED_CHUNK_LEN + DECODED_CHUNK_SUFFIX)],
                )?;

                output_index += DECODED_CHUNK_LEN;
                input_index += INPUT_CHUNK_LEN;
                remaining_chunks -= 1;
            }
        }
    }

    // Stage 3
    // If input length was such that a chunk had to be deferred until after the fast loop
    // because decoding it would have produced 2 trailing bytes that wouldn't then be
    // overwritten, we decode that chunk here. This way is slower but doesn't write the 2
    // trailing bytes.
    // However, we still need to avoid the last chunk (partial or complete) because it could
    // have padding, so we always do 1 fewer to avoid the last chunk.
    for _ in 1..remaining_chunks {
        decode_chunk_precise(
            &input[input_index..],
            input_index,
            decode_table,
            &mut output[output_index..(output_index + DECODED_CHUNK_LEN)],
        )?;

        input_index += INPUT_CHUNK_LEN;
        output_index += DECODED_CHUNK_LEN;
    }

    // always have one more (possibly partial) chunk of up to 8 input bytes
    debug_assert!(input.len() - input_index > 1 || input.is_empty());
    debug_assert!(input.len() - input_index <= 8);

    // Stage 4
    // Finally, decode any leftovers that aren't a complete input block of 8 bytes.
    // Use a u64 as a stack-resident 8 byte buffer.
    let mut leftover_bits: u64 = 0;
    let mut morsels_in_leftover = 0;
    let mut padding_bytes = 0;
    let mut first_padding_index: usize = 0;
    let mut last_symbol = 0_u8;
    // A multiple of INPUT_CHUNK_LEN, so `i % 4` below is the position within the current quad.
    let start_of_leftovers = input_index;
    for (i, b) in input[start_of_leftovers..].iter().enumerate() {
        // '=' padding
        if *b == 0x3D {
            // There can be bad padding in a few ways:
            // 1 - Padding with non-padding characters after it
            // 2 - Padding after zero or one non-padding characters before it
            //     in the current quad.
            // 3 - More than two characters of padding. If 3 or 4 padding chars
            //     are in the same quad, that implies it will be caught by #2.
            //     If it spreads from one quad to another, it will be caught by
            //     #2 in the second quad.

            if i % 4 < 2 {
                // Check for case #2.
                let bad_padding_index = start_of_leftovers + if padding_bytes > 0 {
                    // If we've already seen padding, report the first padding index.
                    // This is to be consistent with the faster logic above: it will report an
                    // error on the first padding character (since it doesn't expect to see
                    // anything but actual encoded data).
                    first_padding_index
                } else {
                    // haven't seen padding before, just use where we are now
                    i
                };
                return Err(DecodeError::InvalidByte(bad_padding_index, *b));
            }

            if padding_bytes == 0 {
                first_padding_index = i;
            }

            padding_bytes += 1;
            continue;
        }

        // Check for case #1.
        // To make '=' handling consistent with the main loop, don't allow
        // non-suffix '=' in trailing chunk either. Report error as first
        // erroneous padding.
        if padding_bytes > 0 {
            return Err(DecodeError::InvalidByte(
                start_of_leftovers + first_padding_index,
                0x3D,
            ));
        }
        last_symbol = *b;

        // can use up to 8 * 6 = 48 bits of the u64, if last chunk has no padding.
        // To minimize shifts, pack the leftovers from left to right.
        let shift = 64 - (morsels_in_leftover + 1) * 6;
        // tables are all 256 elements, lookup with a u8 index always succeeds
        let morsel = decode_table[*b as usize];
        if morsel == tables::INVALID_VALUE {
            return Err(DecodeError::InvalidByte(start_of_leftovers + i, *b));
        }

        leftover_bits |= (morsel as u64) << shift;
        morsels_in_leftover += 1;
    }

    // Number of whole bits (a multiple of 8) the collected symbols yield. 1 and 5 symbols cannot
    // occur: such lengths were rejected as InvalidLength, and padding placement rules above reject
    // inputs that would leave 1 or 5 non-padding symbols.
    let leftover_bits_ready_to_append = match morsels_in_leftover {
        0 => 0,
        2 => 8,
        3 => 16,
        4 => 24,
        6 => 32,
        7 => 40,
        8 => 48,
        _ => unreachable!(
            "Impossible: must only have 0 to 8 input bytes in last chunk, with no invalid lengths"
        ),
    };

    // if there are bits set outside the bits we care about, last symbol encodes trailing bits that
    // will not be included in the output
    // `mask` selects every bit of `leftover_bits` below the bytes that will be emitted.
    let mask = !0 >> leftover_bits_ready_to_append;
    if !config.decode_allow_trailing_bits && (leftover_bits & mask) != 0 {
        // last morsel is at `morsels_in_leftover` - 1 (padding can only follow all symbols)
        return Err(DecodeError::InvalidLastSymbol(
            start_of_leftovers + morsels_in_leftover - 1,
            last_symbol,
        ));
    }

    // Emit the ready bytes from the top of `leftover_bits` down, one byte per iteration.
    let mut leftover_bits_appended_to_buf = 0;
    while leftover_bits_appended_to_buf < leftover_bits_ready_to_append {
        // `as` simply truncates the higher bits, which is what we want here
        let selected_bits = (leftover_bits >> (56 - leftover_bits_appended_to_buf)) as u8;
        output[output_index] = selected_bits;
        output_index += 1;

        leftover_bits_appended_to_buf += 8;
    }

    Ok(output_index)
}
425
426 /// Decode 8 bytes of input into 6 bytes of output. 8 bytes of output will be written, but only the
427 /// first 6 of those contain meaningful data.
428 ///
429 /// `input` is the bytes to decode, of which the first 8 bytes will be processed.
430 /// `index_at_start_of_input` is the offset in the overall input (used for reporting errors
431 /// accurately)
432 /// `decode_table` is the lookup table for the particular base64 alphabet.
433 /// `output` will have its first 8 bytes overwritten, of which only the first 6 are valid decoded
434 /// data.
435 // yes, really inline (worth 30-50% speedup)
436 #[inline(always)]
decode_chunk( input: &[u8], index_at_start_of_input: usize, decode_table: &[u8; 256], output: &mut [u8], ) -> Result<(), DecodeError>437 fn decode_chunk(
438 input: &[u8],
439 index_at_start_of_input: usize,
440 decode_table: &[u8; 256],
441 output: &mut [u8],
442 ) -> Result<(), DecodeError> {
443 let mut accum: u64;
444
445 let morsel = decode_table[input[0] as usize];
446 if morsel == tables::INVALID_VALUE {
447 return Err(DecodeError::InvalidByte(index_at_start_of_input, input[0]));
448 }
449 accum = (morsel as u64) << 58;
450
451 let morsel = decode_table[input[1] as usize];
452 if morsel == tables::INVALID_VALUE {
453 return Err(DecodeError::InvalidByte(
454 index_at_start_of_input + 1,
455 input[1],
456 ));
457 }
458 accum |= (morsel as u64) << 52;
459
460 let morsel = decode_table[input[2] as usize];
461 if morsel == tables::INVALID_VALUE {
462 return Err(DecodeError::InvalidByte(
463 index_at_start_of_input + 2,
464 input[2],
465 ));
466 }
467 accum |= (morsel as u64) << 46;
468
469 let morsel = decode_table[input[3] as usize];
470 if morsel == tables::INVALID_VALUE {
471 return Err(DecodeError::InvalidByte(
472 index_at_start_of_input + 3,
473 input[3],
474 ));
475 }
476 accum |= (morsel as u64) << 40;
477
478 let morsel = decode_table[input[4] as usize];
479 if morsel == tables::INVALID_VALUE {
480 return Err(DecodeError::InvalidByte(
481 index_at_start_of_input + 4,
482 input[4],
483 ));
484 }
485 accum |= (morsel as u64) << 34;
486
487 let morsel = decode_table[input[5] as usize];
488 if morsel == tables::INVALID_VALUE {
489 return Err(DecodeError::InvalidByte(
490 index_at_start_of_input + 5,
491 input[5],
492 ));
493 }
494 accum |= (morsel as u64) << 28;
495
496 let morsel = decode_table[input[6] as usize];
497 if morsel == tables::INVALID_VALUE {
498 return Err(DecodeError::InvalidByte(
499 index_at_start_of_input + 6,
500 input[6],
501 ));
502 }
503 accum |= (morsel as u64) << 22;
504
505 let morsel = decode_table[input[7] as usize];
506 if morsel == tables::INVALID_VALUE {
507 return Err(DecodeError::InvalidByte(
508 index_at_start_of_input + 7,
509 input[7],
510 ));
511 }
512 accum |= (morsel as u64) << 16;
513
514 BigEndian::write_u64(output, accum);
515
516 Ok(())
517 }
518
519 /// Decode an 8-byte chunk, but only write the 6 bytes actually decoded instead of including 2
520 /// trailing garbage bytes.
521 #[inline]
decode_chunk_precise( input: &[u8], index_at_start_of_input: usize, decode_table: &[u8; 256], output: &mut [u8], ) -> Result<(), DecodeError>522 fn decode_chunk_precise(
523 input: &[u8],
524 index_at_start_of_input: usize,
525 decode_table: &[u8; 256],
526 output: &mut [u8],
527 ) -> Result<(), DecodeError> {
528 let mut tmp_buf = [0_u8; 8];
529
530 decode_chunk(
531 input,
532 index_at_start_of_input,
533 decode_table,
534 &mut tmp_buf[..],
535 )?;
536
537 output[0..6].copy_from_slice(&tmp_buf[0..6]);
538
539 Ok(())
540 }
541
#[cfg(test)]
mod tests {
    extern crate rand;

    use super::*;
    use encode::encode_config_buf;
    use tests::{assert_encode_sanity, random_config};

    use self::rand::distributions::{Distribution, Uniform};
    use self::rand::{FromEntropy, Rng};

    #[test]
    fn decode_chunk_precise_writes_only_6_bytes() {
        let input = b"Zm9vYmFy"; // "foobar"
        let mut output = [0_u8, 1, 2, 3, 4, 5, 6, 7];
        decode_chunk_precise(&input[..], 0, tables::STANDARD_DECODE, &mut output).unwrap();
        assert_eq!(&vec![b'f', b'o', b'o', b'b', b'a', b'r', 6, 7], &output);
    }

    #[test]
    fn decode_chunk_writes_8_bytes() {
        let input = b"Zm9vYmFy"; // "foobar"
        let mut output = [0_u8, 1, 2, 3, 4, 5, 6, 7];
        decode_chunk(&input[..], 0, tables::STANDARD_DECODE, &mut output).unwrap();
        assert_eq!(&vec![b'f', b'o', b'o', b'b', b'a', b'r', 0, 0], &output);
    }

    #[test]
    fn decode_into_nonempty_vec_doesnt_clobber_existing_prefix() {
        let mut orig_data = Vec::new();
        let mut encoded_data = String::new();
        let mut decoded_with_prefix = Vec::new();
        let mut decoded_without_prefix = Vec::new();
        let mut prefix = Vec::new();

        let prefix_len_range = Uniform::new(0, 1000);
        let input_len_range = Uniform::new(0, 1000);

        let mut rng = rand::rngs::SmallRng::from_entropy();

        for _ in 0..10_000 {
            orig_data.clear();
            encoded_data.clear();
            decoded_with_prefix.clear();
            decoded_without_prefix.clear();
            prefix.clear();

            let input_len = input_len_range.sample(&mut rng);

            for _ in 0..input_len {
                orig_data.push(rng.gen());
            }

            let config = random_config(&mut rng);
            encode_config_buf(&orig_data, config, &mut encoded_data);
            assert_encode_sanity(&encoded_data, config, input_len);

            let prefix_len = prefix_len_range.sample(&mut rng);

            // fill the buf with a prefix
            for _ in 0..prefix_len {
                prefix.push(rng.gen());
            }

            decoded_with_prefix.resize(prefix_len, 0);
            decoded_with_prefix.copy_from_slice(&prefix);

            // decode into the non-empty buf
            decode_config_buf(&encoded_data, config, &mut decoded_with_prefix).unwrap();
            // also decode into the empty buf
            decode_config_buf(&encoded_data, config, &mut decoded_without_prefix).unwrap();

            assert_eq!(
                prefix_len + decoded_without_prefix.len(),
                decoded_with_prefix.len()
            );
            assert_eq!(orig_data, decoded_without_prefix);

            // append plain decode onto prefix
            prefix.append(&mut decoded_without_prefix);

            assert_eq!(prefix, decoded_with_prefix);
        }
    }

    #[test]
    fn decode_into_slice_doesnt_clobber_existing_prefix_or_suffix() {
        let mut orig_data = Vec::new();
        let mut encoded_data = String::new();
        let mut decode_buf = Vec::new();
        let mut decode_buf_copy: Vec<u8> = Vec::new();

        let input_len_range = Uniform::new(0, 1000);

        let mut rng = rand::rngs::SmallRng::from_entropy();

        for _ in 0..10_000 {
            orig_data.clear();
            encoded_data.clear();
            decode_buf.clear();
            decode_buf_copy.clear();

            let input_len = input_len_range.sample(&mut rng);

            for _ in 0..input_len {
                orig_data.push(rng.gen());
            }

            let config = random_config(&mut rng);
            encode_config_buf(&orig_data, config, &mut encoded_data);
            assert_encode_sanity(&encoded_data, config, input_len);

            // fill the buffer with random garbage, long enough to have some room before and after
            for _ in 0..5000 {
                decode_buf.push(rng.gen());
            }

            // keep a copy for later comparison
            decode_buf_copy.extend(decode_buf.iter());

            let offset = 1000;

            // decode into the non-empty buf
            let decode_bytes_written =
                decode_config_slice(&encoded_data, config, &mut decode_buf[offset..]).unwrap();

            assert_eq!(orig_data.len(), decode_bytes_written);
            assert_eq!(
                orig_data,
                &decode_buf[offset..(offset + decode_bytes_written)]
            );
            assert_eq!(&decode_buf_copy[0..offset], &decode_buf[0..offset]);
            assert_eq!(
                &decode_buf_copy[offset + decode_bytes_written..],
                &decode_buf[offset + decode_bytes_written..]
            );
        }
    }

    #[test]
    fn decode_into_slice_fits_in_precisely_sized_slice() {
        let mut orig_data = Vec::new();
        let mut encoded_data = String::new();
        let mut decode_buf = Vec::new();

        let input_len_range = Uniform::new(0, 1000);

        let mut rng = rand::rngs::SmallRng::from_entropy();

        for _ in 0..10_000 {
            orig_data.clear();
            encoded_data.clear();
            decode_buf.clear();

            let input_len = input_len_range.sample(&mut rng);

            for _ in 0..input_len {
                orig_data.push(rng.gen());
            }

            let config = random_config(&mut rng);
            encode_config_buf(&orig_data, config, &mut encoded_data);
            assert_encode_sanity(&encoded_data, config, input_len);

            decode_buf.resize(input_len, 0);

            // decode into the non-empty buf
            let decode_bytes_written =
                decode_config_slice(&encoded_data, config, &mut decode_buf[..]).unwrap();

            assert_eq!(orig_data.len(), decode_bytes_written);
            assert_eq!(orig_data, decode_buf);
        }
    }

    /// Each non-canonical final symbol must be rejected in strict mode and accepted (decoding to
    /// the same bytes as the canonical form) when trailing bits are allowed.
    #[test]
    fn detect_invalid_last_symbol_two_bytes() {
        let decode = |input, forgiving| {
            decode_config(input, STANDARD.decode_allow_trailing_bits(forgiving))
        };

        // example from https://github.com/alicemaz/rust-base64/issues/75
        assert!(decode("iYU=", false).is_ok());
        // trailing 01
        assert_eq!(Err(DecodeError::InvalidLastSymbol(2, b'V')), decode("iYV=", false));
        assert_eq!(Ok(vec![137, 133]), decode("iYV=", true));
        // trailing 10
        assert_eq!(Err(DecodeError::InvalidLastSymbol(2, b'W')), decode("iYW=", false));
        assert_eq!(Ok(vec![137, 133]), decode("iYW=", true));
        // trailing 11
        assert_eq!(Err(DecodeError::InvalidLastSymbol(2, b'X')), decode("iYX=", false));
        assert_eq!(Ok(vec![137, 133]), decode("iYX=", true));

        // also works when there are 2 quads in the last block
        assert_eq!(Err(DecodeError::InvalidLastSymbol(6, b'X')), decode("AAAAiYX=", false));
        assert_eq!(Ok(vec![0, 0, 0, 137, 133]), decode("AAAAiYX=", true));
    }

    #[test]
    fn detect_invalid_last_symbol_one_byte() {
        // 0xFF -> "/w==", so all letters > w, 0-9, and '+', '/' should get InvalidLastSymbol

        assert!(decode("/w==").is_ok());
        // trailing 01
        assert_eq!(Err(DecodeError::InvalidLastSymbol(1, b'x')), decode("/x=="));
        assert_eq!(Err(DecodeError::InvalidLastSymbol(1, b'z')), decode("/z=="));
        assert_eq!(Err(DecodeError::InvalidLastSymbol(1, b'0')), decode("/0=="));
        assert_eq!(Err(DecodeError::InvalidLastSymbol(1, b'9')), decode("/9=="));
        assert_eq!(Err(DecodeError::InvalidLastSymbol(1, b'+')), decode("/+=="));
        assert_eq!(Err(DecodeError::InvalidLastSymbol(1, b'/')), decode("//=="));

        // also works when there are 2 quads in the last block
        assert_eq!(
            Err(DecodeError::InvalidLastSymbol(5, b'x')),
            decode("AAAA/x==")
        );
    }

    #[test]
    fn detect_invalid_last_symbol_every_possible_three_symbols() {
        let mut base64_to_bytes = ::std::collections::HashMap::new();

        let mut bytes = [0_u8; 2];
        for b1 in 0_u16..256 {
            bytes[0] = b1 as u8;
            for b2 in 0_u16..256 {
                bytes[1] = b2 as u8;
                let mut b64 = vec![0_u8; 4];
                assert_eq!(4, ::encode_config_slice(&bytes, STANDARD, &mut b64[..]));
                let mut v = ::std::vec::Vec::with_capacity(2);
                v.extend_from_slice(&bytes[..]);

                assert!(base64_to_bytes.insert(b64, v).is_none());
            }
        }

        // every possible combination of symbols must either decode to 2 bytes or get InvalidLastSymbol

        let mut symbols = [0_u8; 4];
        for &s1 in STANDARD.char_set.encode_table().iter() {
            symbols[0] = s1;
            for &s2 in STANDARD.char_set.encode_table().iter() {
                symbols[1] = s2;
                for &s3 in STANDARD.char_set.encode_table().iter() {
                    symbols[2] = s3;
                    symbols[3] = b'=';

                    match base64_to_bytes.get(&symbols[..]) {
                        Some(bytes) => {
                            assert_eq!(Ok(bytes.to_vec()), decode_config(&symbols, STANDARD))
                        }
                        None => assert_eq!(
                            Err(DecodeError::InvalidLastSymbol(2, s3)),
                            decode_config(&symbols[..], STANDARD)
                        ),
                    }
                }
            }
        }
    }

    #[test]
    fn detect_invalid_last_symbol_every_possible_two_symbols() {
        let mut base64_to_bytes = ::std::collections::HashMap::new();

        for b in 0_u16..256 {
            let mut b64 = vec![0_u8; 4];
            assert_eq!(4, ::encode_config_slice(&[b as u8], STANDARD, &mut b64[..]));
            let mut v = ::std::vec::Vec::with_capacity(1);
            v.push(b as u8);

            assert!(base64_to_bytes.insert(b64, v).is_none());
        }

        // every possible combination of symbols must either decode to 1 byte or get InvalidLastSymbol

        let mut symbols = [0_u8; 4];
        for &s1 in STANDARD.char_set.encode_table().iter() {
            symbols[0] = s1;
            for &s2 in STANDARD.char_set.encode_table().iter() {
                symbols[1] = s2;
                symbols[2] = b'=';
                symbols[3] = b'=';

                match base64_to_bytes.get(&symbols[..]) {
                    Some(bytes) => {
                        assert_eq!(Ok(bytes.to_vec()), decode_config(&symbols, STANDARD))
                    }
                    None => assert_eq!(
                        Err(DecodeError::InvalidLastSymbol(1, s2)),
                        decode_config(&symbols[..], STANDARD)
                    ),
                }
            }
        }
    }
}
839