1 #![forbid(unsafe_code)]
2 
3 use std::fmt::{self, Display};
4 
/// Failures reported by the decoders in this module.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Error {
    /// The caller-supplied output buffer is too small for the decoded bytes.
    Overflow,
    /// The encoded input is malformed.
    InvalidInput,
}

impl std::error::Error for Error {}

impl Display for Error {
    /// Writes a short, human-readable description of the error.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let text = match *self {
            Error::Overflow => "Overflow",
            Error::InvalidInput => "Invalid input",
        };
        f.write_str(text)
    }
}
21 
22 pub trait Decoder {
decode<IN: AsRef<[u8]>>(bin: &mut [u8], encoded: IN) -> Result<&[u8], Error>23     fn decode<IN: AsRef<[u8]>>(bin: &mut [u8], encoded: IN) -> Result<&[u8], Error>;
24 
decode_to_vec<IN: AsRef<[u8]>>(encoded: IN) -> Result<Vec<u8>, Error>25     fn decode_to_vec<IN: AsRef<[u8]>>(encoded: IN) -> Result<Vec<u8>, Error> {
26         let mut bin = vec![0u8; encoded.as_ref().len()];
27         let bin_len = Self::decode(&mut bin, encoded)?.len();
28         bin.truncate(bin_len);
29         Ok(bin)
30     }
31 }
32 
33 struct Base64Impl;
34 
35 impl Base64Impl {
36     #[inline]
_eq(x: u8, y: u8) -> u837     fn _eq(x: u8, y: u8) -> u8 {
38         !(((0u16.wrapping_sub((x as u16) ^ (y as u16))) >> 8) as u8)
39     }
40 
41     #[inline]
_gt(x: u8, y: u8) -> u842     fn _gt(x: u8, y: u8) -> u8 {
43         (((y as u16).wrapping_sub(x as u16)) >> 8) as u8
44     }
45 
46     #[inline]
_ge(x: u8, y: u8) -> u847     fn _ge(x: u8, y: u8) -> u8 {
48         !Self::_gt(y, x)
49     }
50 
51     #[inline]
_lt(x: u8, y: u8) -> u852     fn _lt(x: u8, y: u8) -> u8 {
53         Self::_gt(y, x)
54     }
55 
56     #[inline]
_le(x: u8, y: u8) -> u857     fn _le(x: u8, y: u8) -> u8 {
58         Self::_ge(y, x)
59     }
60 
61     #[inline]
b64_char_to_byte(c: u8) -> u862     fn b64_char_to_byte(c: u8) -> u8 {
63         let x = (Self::_ge(c, b'A') & Self::_le(c, b'Z') & (c.wrapping_sub(b'A')))
64             | (Self::_ge(c, b'a') & Self::_le(c, b'z') & (c.wrapping_sub(b'a'.wrapping_sub(26))))
65             | (Self::_ge(c, b'0') & Self::_le(c, b'9') & (c.wrapping_sub(b'0'.wrapping_sub(52))))
66             | (Self::_eq(c, b'+') & 62)
67             | (Self::_eq(c, b'/') & 63);
68         x | (Self::_eq(x, 0) & (Self::_eq(c, b'A') ^ 0xff))
69     }
70 
skip_padding(b64: &[u8], mut padding_len: usize) -> Result<&[u8], Error>71     fn skip_padding(b64: &[u8], mut padding_len: usize) -> Result<&[u8], Error> {
72         let b64_len = b64.len();
73         let mut b64_pos = 0usize;
74         while padding_len > 0 {
75             if b64_pos > b64_len {
76                 return Err(Error::InvalidInput);
77             }
78             let c = b64[b64_pos];
79             if c == b'=' {
80                 padding_len -= 1
81             } else {
82                 return Err(Error::InvalidInput);
83             }
84             b64_pos += 1
85         }
86         Ok(&b64[b64_pos..])
87     }
88 
decode<'t>(bin: &'t mut [u8], b64: &[u8]) -> Result<&'t [u8], Error>89     pub fn decode<'t>(bin: &'t mut [u8], b64: &[u8]) -> Result<&'t [u8], Error> {
90         let bin_maxlen = bin.len();
91         let mut acc = 0u16;
92         let mut acc_len = 0usize;
93         let mut bin_pos = 0usize;
94         let mut premature_end = None;
95         for (b64_pos, &c) in b64.iter().enumerate() {
96             let d = Self::b64_char_to_byte(c);
97             if d == 0xff {
98                 premature_end = Some(b64_pos);
99                 break;
100             }
101             acc = (acc << 6) + d as u16;
102             acc_len += 6;
103             if acc_len >= 8 {
104                 acc_len -= 8;
105                 if bin_pos >= bin_maxlen {
106                     return Err(Error::Overflow);
107                 }
108                 bin[bin_pos] = (acc >> acc_len) as u8;
109                 bin_pos += 1;
110             }
111         }
112         if acc_len > 4 || (acc & ((1u16 << acc_len).wrapping_sub(1))) != 0 {
113             return Err(Error::InvalidInput);
114         }
115         if let Some(premature_end) = premature_end {
116             let remaining = {
117                 let padding_len = acc_len / 2;
118                 Self::skip_padding(&b64[premature_end..], padding_len)?
119             };
120             if !remaining.is_empty() {
121                 return Err(Error::InvalidInput);
122             }
123         }
124         Ok(&bin[..bin_pos])
125     }
126 }
127 
128 pub struct Base64;
129 
130 impl Decoder for Base64 {
131     #[inline]
decode<IN: AsRef<[u8]>>(bin: &mut [u8], b64: IN) -> Result<&[u8], Error>132     fn decode<IN: AsRef<[u8]>>(bin: &mut [u8], b64: IN) -> Result<&[u8], Error> {
133         Base64Impl::decode(bin, b64.as_ref())
134     }
135 }
136