1 use crate::{decode_config_slice, Config, DecodeError};
2 use std::io::Read;
3 use std::{cmp, fmt, io};
4 
/// Capacity, in bytes, of the base64 staging buffer embedded in every `DecoderReader`.
///
/// A larger buffer means fewer reads from the wrapped reader, but the buffer is stored inline in
/// the decoder struct, so it has to stay small enough to fit on the stack.
pub(crate) const BUF_SIZE: usize = 1 << 10;

/// Number of base64 symbols in one complete chunk.
/// Every 4 symbols of base64 encode 3 bytes of raw data (modulo padding).
const BASE64_CHUNK_SIZE: usize = 4;
/// Number of raw bytes produced by decoding one complete, unpadded base64 chunk.
const DECODED_CHUNK_SIZE: usize = 3;
11 
12 /// A `Read` implementation that decodes base64 data read from an underlying reader.
13 ///
14 /// # Examples
15 ///
16 /// ```
17 /// use std::io::Read;
18 /// use std::io::Cursor;
19 ///
20 /// // use a cursor as the simplest possible `Read` -- in real code this is probably a file, etc.
21 /// let mut wrapped_reader = Cursor::new(b"YXNkZg==");
22 /// let mut decoder = base64::read::DecoderReader::new(
23 ///     &mut wrapped_reader, base64::STANDARD);
24 ///
25 /// // handle errors as you normally would
26 /// let mut result = Vec::new();
27 /// decoder.read_to_end(&mut result).unwrap();
28 ///
29 /// assert_eq!(b"asdf", &result[..]);
30 ///
31 /// ```
32 pub struct DecoderReader<'a, R: 'a + io::Read> {
33     config: Config,
34     /// Where b64 data is read from
35     r: &'a mut R,
36 
37     // Holds b64 data read from the delegate reader.
38     b64_buffer: [u8; BUF_SIZE],
39     // The start of the pending buffered data in b64_buffer.
40     b64_offset: usize,
41     // The amount of buffered b64 data.
42     b64_len: usize,
43     // Since the caller may provide us with a buffer of size 1 or 2 that's too small to copy a
44     // decoded chunk in to, we have to be able to hang on to a few decoded bytes.
45     // Technically we only need to hold 2 bytes but then we'd need a separate temporary buffer to
46     // decode 3 bytes into and then juggle copying one byte into the provided read buf and the rest
47     // into here, which seems like a lot of complexity for 1 extra byte of storage.
48     decoded_buffer: [u8; 3],
49     // index of start of decoded data
50     decoded_offset: usize,
51     // length of decoded data
52     decoded_len: usize,
53     // used to provide accurate offsets in errors
54     total_b64_decoded: usize,
55 }
56 
57 impl<'a, R: io::Read> fmt::Debug for DecoderReader<'a, R> {
fmt(&self, f: &mut fmt::Formatter) -> fmt::Result58     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
59         f.debug_struct("DecoderReader")
60             .field("config", &self.config)
61             .field("b64_offset", &self.b64_offset)
62             .field("b64_len", &self.b64_len)
63             .field("decoded_buffer", &self.decoded_buffer)
64             .field("decoded_offset", &self.decoded_offset)
65             .field("decoded_len", &self.decoded_len)
66             .field("total_b64_decoded", &self.total_b64_decoded)
67             .finish()
68     }
69 }
70 
71 impl<'a, R: io::Read> DecoderReader<'a, R> {
72     /// Create a new decoder that will read from the provided reader `r`.
new(r: &'a mut R, config: Config) -> Self73     pub fn new(r: &'a mut R, config: Config) -> Self {
74         DecoderReader {
75             config,
76             r,
77             b64_buffer: [0; BUF_SIZE],
78             b64_offset: 0,
79             b64_len: 0,
80             decoded_buffer: [0; DECODED_CHUNK_SIZE],
81             decoded_offset: 0,
82             decoded_len: 0,
83             total_b64_decoded: 0,
84         }
85     }
86 
87     /// Write as much as possible of the decoded buffer into the target buffer.
88     /// Must only be called when there is something to write and space to write into.
89     /// Returns a Result with the number of (decoded) bytes copied.
flush_decoded_buf(&mut self, buf: &mut [u8]) -> io::Result<usize>90     fn flush_decoded_buf(&mut self, buf: &mut [u8]) -> io::Result<usize> {
91         debug_assert!(self.decoded_len > 0);
92         debug_assert!(buf.len() > 0);
93 
94         let copy_len = cmp::min(self.decoded_len, buf.len());
95         debug_assert!(copy_len > 0);
96         debug_assert!(copy_len <= self.decoded_len);
97 
98         buf[..copy_len].copy_from_slice(
99             &self.decoded_buffer[self.decoded_offset..self.decoded_offset + copy_len],
100         );
101 
102         self.decoded_offset += copy_len;
103         self.decoded_len -= copy_len;
104 
105         debug_assert!(self.decoded_len < DECODED_CHUNK_SIZE);
106 
107         Ok(copy_len)
108     }
109 
110     /// Read into the remaining space in the buffer after the current contents.
111     /// Must only be called when there is space to read into in the buffer.
112     /// Returns the number of bytes read.
read_from_delegate(&mut self) -> io::Result<usize>113     fn read_from_delegate(&mut self) -> io::Result<usize> {
114         debug_assert!(self.b64_offset + self.b64_len < BUF_SIZE);
115 
116         let read = self
117             .r
118             .read(&mut self.b64_buffer[self.b64_offset + self.b64_len..])?;
119         self.b64_len += read;
120 
121         debug_assert!(self.b64_offset + self.b64_len <= BUF_SIZE);
122 
123         return Ok(read);
124     }
125 
126     /// Decode the requested number of bytes from the b64 buffer into the provided buffer. It's the
127     /// caller's responsibility to choose the number of b64 bytes to decode correctly.
128     ///
129     /// Returns a Result with the number of decoded bytes written to `buf`.
decode_to_buf(&mut self, num_bytes: usize, buf: &mut [u8]) -> io::Result<usize>130     fn decode_to_buf(&mut self, num_bytes: usize, buf: &mut [u8]) -> io::Result<usize> {
131         debug_assert!(self.b64_len >= num_bytes);
132         debug_assert!(self.b64_offset + self.b64_len <= BUF_SIZE);
133         debug_assert!(buf.len() > 0);
134 
135         let decoded = decode_config_slice(
136             &self.b64_buffer[self.b64_offset..self.b64_offset + num_bytes],
137             self.config,
138             &mut buf[..],
139         )
140         .map_err(|e| match e {
141             DecodeError::InvalidByte(offset, byte) => {
142                 DecodeError::InvalidByte(self.total_b64_decoded + offset, byte)
143             }
144             DecodeError::InvalidLength => DecodeError::InvalidLength,
145             DecodeError::InvalidLastSymbol(offset, byte) => {
146                 DecodeError::InvalidLastSymbol(self.total_b64_decoded + offset, byte)
147             }
148         })
149         .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
150 
151         self.total_b64_decoded += num_bytes;
152         self.b64_offset += num_bytes;
153         self.b64_len -= num_bytes;
154 
155         debug_assert!(self.b64_offset + self.b64_len <= BUF_SIZE);
156 
157         Ok(decoded)
158     }
159 }
160 
161 impl<'a, R: Read> Read for DecoderReader<'a, R> {
162     /// Decode input from the wrapped reader.
163     ///
164     /// Under non-error circumstances, this returns `Ok` with the value being the number of bytes
165     /// written in `buf`.
166     ///
167     /// Where possible, this function buffers base64 to minimize the number of read() calls to the
168     /// delegate reader.
169     ///
170     /// # Errors
171     ///
172     /// Any errors emitted by the delegate reader are returned. Decoding errors due to invalid
173     /// base64 are also possible, and will have `io::ErrorKind::InvalidData`.
read(&mut self, buf: &mut [u8]) -> io::Result<usize>174     fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
175         if buf.len() == 0 {
176             return Ok(0);
177         }
178 
179         // offset == BUF_SIZE when we copied it all last time
180         debug_assert!(self.b64_offset <= BUF_SIZE);
181         debug_assert!(self.b64_offset + self.b64_len <= BUF_SIZE);
182         debug_assert!(if self.b64_offset == BUF_SIZE {
183             self.b64_len == 0
184         } else {
185             self.b64_len <= BUF_SIZE
186         });
187 
188         debug_assert!(if self.decoded_len == 0 {
189             // can be = when we were able to copy the complete chunk
190             self.decoded_offset <= DECODED_CHUNK_SIZE
191         } else {
192             self.decoded_offset < DECODED_CHUNK_SIZE
193         });
194 
195         // We shouldn't ever decode into here when we can't immediately write at least one byte into
196         // the provided buf, so the effective length should only be 3 momentarily between when we
197         // decode and when we copy into the target buffer.
198         debug_assert!(self.decoded_len < DECODED_CHUNK_SIZE);
199         debug_assert!(self.decoded_len + self.decoded_offset <= DECODED_CHUNK_SIZE);
200 
201         if self.decoded_len > 0 {
202             // we have a few leftover decoded bytes; flush that rather than pull in more b64
203             self.flush_decoded_buf(buf)
204         } else {
205             let mut at_eof = false;
206             while self.b64_len < BASE64_CHUNK_SIZE {
207                 // Work around lack of copy_within, which is only present in 1.37
208                 // Copy any bytes we have to the start of the buffer.
209                 // We know we have < 1 chunk, so we can use a tiny tmp buffer.
210                 let mut memmove_buf = [0_u8; BASE64_CHUNK_SIZE];
211                 memmove_buf[..self.b64_len].copy_from_slice(
212                     &self.b64_buffer[self.b64_offset..self.b64_offset + self.b64_len],
213                 );
214                 self.b64_buffer[0..self.b64_len].copy_from_slice(&memmove_buf[..self.b64_len]);
215                 self.b64_offset = 0;
216 
217                 // then fill in more data
218                 let read = self.read_from_delegate()?;
219                 if read == 0 {
220                     // we never pass in an empty buf, so 0 => we've hit EOF
221                     at_eof = true;
222                     break;
223                 }
224             }
225 
226             if self.b64_len == 0 {
227                 debug_assert!(at_eof);
228                 // we must be at EOF, and we have no data left to decode
229                 return Ok(0);
230             };
231 
232             debug_assert!(if at_eof {
233                 // if we are at eof, we may not have a complete chunk
234                 self.b64_len > 0
235             } else {
236                 // otherwise, we must have at least one chunk
237                 self.b64_len >= BASE64_CHUNK_SIZE
238             });
239 
240             debug_assert_eq!(0, self.decoded_len);
241 
242             if buf.len() < DECODED_CHUNK_SIZE {
243                 // caller requested an annoyingly short read
244                 // have to write to a tmp buf first to avoid double mutable borrow
245                 let mut decoded_chunk = [0_u8; DECODED_CHUNK_SIZE];
246                 // if we are at eof, could have less than BASE64_CHUNK_SIZE, in which case we have
247                 // to assume that these last few tokens are, in fact, valid (i.e. must be 2-4 b64
248                 // tokens, not 1, since 1 token can't decode to 1 byte).
249                 let to_decode = cmp::min(self.b64_len, BASE64_CHUNK_SIZE);
250 
251                 let decoded = self.decode_to_buf(to_decode, &mut decoded_chunk[..])?;
252                 self.decoded_buffer[..decoded].copy_from_slice(&decoded_chunk[..decoded]);
253 
254                 self.decoded_offset = 0;
255                 self.decoded_len = decoded;
256 
257                 // can be less than 3 on last block due to padding
258                 debug_assert!(decoded <= 3);
259 
260                 self.flush_decoded_buf(buf)
261             } else {
262                 let b64_bytes_that_can_decode_into_buf = (buf.len() / DECODED_CHUNK_SIZE)
263                     .checked_mul(BASE64_CHUNK_SIZE)
264                     .expect("too many chunks");
265                 debug_assert!(b64_bytes_that_can_decode_into_buf >= BASE64_CHUNK_SIZE);
266 
267                 let b64_bytes_available_to_decode = if at_eof {
268                     self.b64_len
269                 } else {
270                     // only use complete chunks
271                     self.b64_len - self.b64_len % 4
272                 };
273 
274                 let actual_decode_len = cmp::min(
275                     b64_bytes_that_can_decode_into_buf,
276                     b64_bytes_available_to_decode,
277                 );
278                 self.decode_to_buf(actual_decode_len, buf)
279             }
280         }
281     }
282 }
283