1 use std::cmp;
2 use std::io::BufRead;
3 use std::io::BufReader;
4 use std::io::Read;
5 use std::mem;
6 use std::u64;
7 
8 #[cfg(feature = "bytes")]
9 use bytes::buf::UninitSlice;
10 #[cfg(feature = "bytes")]
11 use bytes::BufMut;
12 #[cfg(feature = "bytes")]
13 use bytes::Bytes;
14 #[cfg(feature = "bytes")]
15 use bytes::BytesMut;
16 
17 use crate::coded_input_stream::READ_RAW_BYTES_MAX_ALLOC;
18 use crate::error::WireError;
19 use crate::ProtobufError;
20 use crate::ProtobufResult;
21 
// If an input stream is constructed with a `Read`, we create a
// `BufReader` with an internal buffer of this size.
const INPUT_STREAM_BUFFER_SIZE: usize = 4096;

// When `true`, hot paths (`read_byte`, `remaining_in_buf`, `fill_buf`) index `buf` with
// `get_unchecked`, relying on the `pos_within_buf <= limit_within_buf <= buf.len()`
// invariant (checked in debug builds by `BufReadIter::assertions`) instead of a bounds check.
const USE_UNSAFE_FOR_SPEED: bool = true;

// Sentinel value of `BufReadIter::limit` meaning "no limit is set";
// `bytes_until_limit` returns it unchanged in that case.
const NO_LIMIT: u64 = u64::MAX;
29 
/// Holds every supported kind of input source.
enum InputSource<'a> {
    /// Caller-provided buffered reader; bytes read are reported back via `consume`.
    BufRead(&'a mut dyn BufRead),
    /// Caller-provided unbuffered reader, wrapped in a `BufReader` owned by the iterator.
    Read(BufReader<&'a mut dyn Read>),
    /// Input held entirely in memory; never refilled.
    Slice(&'a [u8]),
    /// Input held entirely in memory as `Bytes`, which lets `read_exact_bytes`
    /// return zero-copy sub-slices.
    #[cfg(feature = "bytes")]
    Bytes(&'a Bytes),
}
38 
/// Dangerous implementation of `BufRead`.
///
/// Unsafe wrapper around `BufRead` which assumes that the `BufRead` buffer is
/// not moved when the `BufRead` itself is moved.
///
/// This assumption is generally incorrect; however, in practice
/// `BufReadIter` is created either from a `BufRead` reference (which
/// cannot be moved, because it is locked by `CodedInputStream`) or from
/// a `BufReader`, which does not move its buffer (we know that from
/// inspecting the Rust standard library).
///
/// It is important for `CodedInputStream` performance that small reads
/// (e.g. 4-byte reads) do not involve virtual calls or switches.
/// This is achievable with `BufReadIter`.
pub struct BufReadIter<'a> {
    // Where new bytes come from once `buf` has been read.
    input_source: InputSource<'a>,
    // Current window of input. For `Read`/`BufRead` sources this is the reader's internal
    // buffer with its lifetime extended in `do_fill_buf`; for in-memory sources it starts as
    // the whole input and becomes empty once exhausted.
    buf: &'a [u8],
    // Read cursor, as an index into `buf`.
    pos_within_buf: usize,
    // End of the readable part of `buf`: `buf.len()` clipped to `limit`
    // (recomputed by `update_limit_within_buf`).
    limit_within_buf: usize,
    // Absolute position (bytes consumed since creation) of `buf[0]`.
    pos_of_buf_start: u64,
    // Absolute position that reads must not pass; `NO_LIMIT` when unrestricted.
    limit: u64,
}
61 
62 impl<'a> Drop for BufReadIter<'a> {
drop(&mut self)63     fn drop(&mut self) {
64         match self.input_source {
65             InputSource::BufRead(ref mut buf_read) => buf_read.consume(self.pos_within_buf),
66             InputSource::Read(_) => {
67                 // Nothing to flush, because we own BufReader
68             }
69             _ => {}
70         }
71     }
72 }
73 
74 impl<'ignore> BufReadIter<'ignore> {
from_read<'a>(read: &'a mut dyn Read) -> BufReadIter<'a>75     pub fn from_read<'a>(read: &'a mut dyn Read) -> BufReadIter<'a> {
76         BufReadIter {
77             input_source: InputSource::Read(BufReader::with_capacity(
78                 INPUT_STREAM_BUFFER_SIZE,
79                 read,
80             )),
81             buf: &[],
82             pos_within_buf: 0,
83             limit_within_buf: 0,
84             pos_of_buf_start: 0,
85             limit: NO_LIMIT,
86         }
87     }
88 
from_buf_read<'a>(buf_read: &'a mut dyn BufRead) -> BufReadIter<'a>89     pub fn from_buf_read<'a>(buf_read: &'a mut dyn BufRead) -> BufReadIter<'a> {
90         BufReadIter {
91             input_source: InputSource::BufRead(buf_read),
92             buf: &[],
93             pos_within_buf: 0,
94             limit_within_buf: 0,
95             pos_of_buf_start: 0,
96             limit: NO_LIMIT,
97         }
98     }
99 
from_byte_slice<'a>(bytes: &'a [u8]) -> BufReadIter<'a>100     pub fn from_byte_slice<'a>(bytes: &'a [u8]) -> BufReadIter<'a> {
101         BufReadIter {
102             input_source: InputSource::Slice(bytes),
103             buf: bytes,
104             pos_within_buf: 0,
105             limit_within_buf: bytes.len(),
106             pos_of_buf_start: 0,
107             limit: NO_LIMIT,
108         }
109     }
110 
111     #[cfg(feature = "bytes")]
from_bytes<'a>(bytes: &'a Bytes) -> BufReadIter<'a>112     pub fn from_bytes<'a>(bytes: &'a Bytes) -> BufReadIter<'a> {
113         BufReadIter {
114             input_source: InputSource::Bytes(bytes),
115             buf: &bytes,
116             pos_within_buf: 0,
117             limit_within_buf: bytes.len(),
118             pos_of_buf_start: 0,
119             limit: NO_LIMIT,
120         }
121     }
122 
123     #[inline]
assertions(&self)124     fn assertions(&self) {
125         debug_assert!(self.pos_within_buf <= self.limit_within_buf);
126         debug_assert!(self.limit_within_buf <= self.buf.len());
127         debug_assert!(self.pos_of_buf_start + self.pos_within_buf as u64 <= self.limit);
128     }
129 
130     #[inline(always)]
pos(&self) -> u64131     pub fn pos(&self) -> u64 {
132         self.pos_of_buf_start + self.pos_within_buf as u64
133     }
134 
135     /// Recompute `limit_within_buf` after update of `limit`
136     #[inline]
update_limit_within_buf(&mut self)137     fn update_limit_within_buf(&mut self) {
138         if self.pos_of_buf_start + (self.buf.len() as u64) <= self.limit {
139             self.limit_within_buf = self.buf.len();
140         } else {
141             self.limit_within_buf = (self.limit - self.pos_of_buf_start) as usize;
142         }
143 
144         self.assertions();
145     }
146 
push_limit(&mut self, limit: u64) -> ProtobufResult<u64>147     pub fn push_limit(&mut self, limit: u64) -> ProtobufResult<u64> {
148         let new_limit = match self.pos().checked_add(limit) {
149             Some(new_limit) => new_limit,
150             None => return Err(ProtobufError::WireError(WireError::Other)),
151         };
152 
153         if new_limit > self.limit {
154             return Err(ProtobufError::WireError(WireError::Other));
155         }
156 
157         let prev_limit = mem::replace(&mut self.limit, new_limit);
158 
159         self.update_limit_within_buf();
160 
161         Ok(prev_limit)
162     }
163 
164     #[inline]
pop_limit(&mut self, limit: u64)165     pub fn pop_limit(&mut self, limit: u64) {
166         assert!(limit >= self.limit);
167 
168         self.limit = limit;
169 
170         self.update_limit_within_buf();
171     }
172 
173     #[inline]
remaining_in_buf(&self) -> &[u8]174     pub fn remaining_in_buf(&self) -> &[u8] {
175         if USE_UNSAFE_FOR_SPEED {
176             unsafe {
177                 &self
178                     .buf
179                     .get_unchecked(self.pos_within_buf..self.limit_within_buf)
180             }
181         } else {
182             &self.buf[self.pos_within_buf..self.limit_within_buf]
183         }
184     }
185 
186     #[inline(always)]
remaining_in_buf_len(&self) -> usize187     pub fn remaining_in_buf_len(&self) -> usize {
188         self.limit_within_buf - self.pos_within_buf
189     }
190 
191     #[inline(always)]
bytes_until_limit(&self) -> u64192     pub fn bytes_until_limit(&self) -> u64 {
193         if self.limit == NO_LIMIT {
194             NO_LIMIT
195         } else {
196             self.limit - (self.pos_of_buf_start + self.pos_within_buf as u64)
197         }
198     }
199 
200     #[inline(always)]
eof(&mut self) -> ProtobufResult<bool>201     pub fn eof(&mut self) -> ProtobufResult<bool> {
202         if self.pos_within_buf == self.limit_within_buf {
203             Ok(self.fill_buf()?.is_empty())
204         } else {
205             Ok(false)
206         }
207     }
208 
209     #[inline(always)]
read_byte(&mut self) -> ProtobufResult<u8>210     pub fn read_byte(&mut self) -> ProtobufResult<u8> {
211         if self.pos_within_buf == self.limit_within_buf {
212             self.do_fill_buf()?;
213             if self.remaining_in_buf_len() == 0 {
214                 return Err(ProtobufError::WireError(WireError::UnexpectedEof));
215             }
216         }
217 
218         let r = if USE_UNSAFE_FOR_SPEED {
219             unsafe { *self.buf.get_unchecked(self.pos_within_buf) }
220         } else {
221             self.buf[self.pos_within_buf]
222         };
223         self.pos_within_buf += 1;
224         Ok(r)
225     }
226 
227     /// Read at most `max` bytes, append to `Vec`.
228     ///
229     /// Returns 0 when EOF or limit reached.
read_to_vec(&mut self, vec: &mut Vec<u8>, max: usize) -> ProtobufResult<usize>230     fn read_to_vec(&mut self, vec: &mut Vec<u8>, max: usize) -> ProtobufResult<usize> {
231         let len = {
232             let rem = self.fill_buf()?;
233 
234             let len = cmp::min(rem.len(), max);
235             vec.extend_from_slice(&rem[..len]);
236             len
237         };
238         self.pos_within_buf += len;
239         Ok(len)
240     }
241 
242     /// Read exact number of bytes into `Vec`.
243     ///
244     /// `Vec` is cleared in the beginning.
read_exact_to_vec(&mut self, count: usize, target: &mut Vec<u8>) -> ProtobufResult<()>245     pub fn read_exact_to_vec(&mut self, count: usize, target: &mut Vec<u8>) -> ProtobufResult<()> {
246         // TODO: also do some limits when reading from unlimited source
247         if count as u64 > self.bytes_until_limit() {
248             return Err(ProtobufError::WireError(WireError::TruncatedMessage));
249         }
250 
251         target.clear();
252 
253         if count >= READ_RAW_BYTES_MAX_ALLOC && count > target.capacity() {
254             // avoid calling `reserve` on buf with very large buffer: could be a malformed message
255 
256             target.reserve(READ_RAW_BYTES_MAX_ALLOC);
257 
258             while target.len() < count {
259                 let need_to_read = count - target.len();
260                 if need_to_read <= target.len() {
261                     target.reserve_exact(need_to_read);
262                 } else {
263                     target.reserve(1);
264                 }
265 
266                 let max = cmp::min(target.capacity() - target.len(), need_to_read);
267                 let read = self.read_to_vec(target, max)?;
268                 if read == 0 {
269                     return Err(ProtobufError::WireError(WireError::TruncatedMessage));
270                 }
271             }
272         } else {
273             target.reserve_exact(count);
274 
275             unsafe {
276                 self.read_exact(&mut target.get_unchecked_mut(..count))?;
277                 target.set_len(count);
278             }
279         }
280 
281         debug_assert_eq!(count, target.len());
282 
283         Ok(())
284     }
285 
286     #[cfg(feature = "bytes")]
read_exact_bytes(&mut self, len: usize) -> ProtobufResult<Bytes>287     pub fn read_exact_bytes(&mut self, len: usize) -> ProtobufResult<Bytes> {
288         if let InputSource::Bytes(bytes) = self.input_source {
289             let end = match self.pos_within_buf.checked_add(len) {
290                 Some(end) => end,
291                 None => return Err(ProtobufError::WireError(WireError::UnexpectedEof)),
292             };
293 
294             if end > self.limit_within_buf {
295                 return Err(ProtobufError::WireError(WireError::UnexpectedEof));
296             }
297 
298             let r = bytes.slice(self.pos_within_buf..end);
299             self.pos_within_buf += len;
300             Ok(r)
301         } else {
302             if len >= READ_RAW_BYTES_MAX_ALLOC {
303                 // We cannot trust `len` because protobuf message could be malformed.
304                 // Reading should not result in OOM when allocating a buffer.
305                 let mut v = Vec::new();
306                 self.read_exact_to_vec(len, &mut v)?;
307                 Ok(Bytes::from(v))
308             } else {
309                 let mut r = BytesMut::with_capacity(len);
310                 unsafe {
311                     let buf = Self::uninit_slice_as_mut_slice(&mut r.chunk_mut()[..len]);
312                     self.read_exact(buf)?;
313                     r.advance_mut(len);
314                 }
315                 Ok(r.freeze())
316             }
317         }
318     }
319 
320     #[cfg(feature = "bytes")]
uninit_slice_as_mut_slice(slice: &mut UninitSlice) -> &mut [u8]321     unsafe fn uninit_slice_as_mut_slice(slice: &mut UninitSlice) -> &mut [u8] {
322         use std::slice;
323         slice::from_raw_parts_mut(slice.as_mut_ptr(), slice.len())
324     }
325 
326     /// Returns 0 when EOF or limit reached.
read(&mut self, buf: &mut [u8]) -> ProtobufResult<usize>327     pub fn read(&mut self, buf: &mut [u8]) -> ProtobufResult<usize> {
328         self.fill_buf()?;
329 
330         let rem = &self.buf[self.pos_within_buf..self.limit_within_buf];
331 
332         let len = cmp::min(rem.len(), buf.len());
333         buf[..len].copy_from_slice(&rem[..len]);
334         self.pos_within_buf += len;
335         Ok(len)
336     }
337 
read_exact(&mut self, buf: &mut [u8]) -> ProtobufResult<()>338     pub fn read_exact(&mut self, buf: &mut [u8]) -> ProtobufResult<()> {
339         if self.remaining_in_buf_len() >= buf.len() {
340             let buf_len = buf.len();
341             buf.copy_from_slice(&self.buf[self.pos_within_buf..self.pos_within_buf + buf_len]);
342             self.pos_within_buf += buf_len;
343             return Ok(());
344         }
345 
346         if self.bytes_until_limit() < buf.len() as u64 {
347             return Err(ProtobufError::WireError(WireError::UnexpectedEof));
348         }
349 
350         let consume = self.pos_within_buf;
351         self.pos_of_buf_start += self.pos_within_buf as u64;
352         self.pos_within_buf = 0;
353         self.buf = &[];
354         self.limit_within_buf = 0;
355 
356         match self.input_source {
357             InputSource::Read(ref mut buf_read) => {
358                 buf_read.consume(consume);
359                 buf_read.read_exact(buf)?;
360             }
361             InputSource::BufRead(ref mut buf_read) => {
362                 buf_read.consume(consume);
363                 buf_read.read_exact(buf)?;
364             }
365             _ => {
366                 return Err(ProtobufError::WireError(WireError::UnexpectedEof));
367             }
368         }
369 
370         self.pos_of_buf_start += buf.len() as u64;
371 
372         self.assertions();
373 
374         Ok(())
375     }
376 
do_fill_buf(&mut self) -> ProtobufResult<()>377     fn do_fill_buf(&mut self) -> ProtobufResult<()> {
378         debug_assert!(self.pos_within_buf == self.limit_within_buf);
379 
380         // Limit is reached, do not fill buf, because otherwise
381         // synchronous read from `CodedInputStream` may block.
382         if self.limit == self.pos() {
383             return Ok(());
384         }
385 
386         let consume = self.buf.len();
387         self.pos_of_buf_start += self.buf.len() as u64;
388         self.buf = &[];
389         self.pos_within_buf = 0;
390         self.limit_within_buf = 0;
391 
392         match self.input_source {
393             InputSource::Read(ref mut buf_read) => {
394                 buf_read.consume(consume);
395                 self.buf = unsafe { mem::transmute(buf_read.fill_buf()?) };
396             }
397             InputSource::BufRead(ref mut buf_read) => {
398                 buf_read.consume(consume);
399                 self.buf = unsafe { mem::transmute(buf_read.fill_buf()?) };
400             }
401             _ => {
402                 return Ok(());
403             }
404         }
405 
406         self.update_limit_within_buf();
407 
408         Ok(())
409     }
410 
411     #[inline(always)]
fill_buf(&mut self) -> ProtobufResult<&[u8]>412     pub fn fill_buf(&mut self) -> ProtobufResult<&[u8]> {
413         if self.pos_within_buf == self.limit_within_buf {
414             self.do_fill_buf()?;
415         }
416 
417         Ok(if USE_UNSAFE_FOR_SPEED {
418             unsafe {
419                 self.buf
420                     .get_unchecked(self.pos_within_buf..self.limit_within_buf)
421             }
422         } else {
423             &self.buf[self.pos_within_buf..self.limit_within_buf]
424         })
425     }
426 
427     #[inline(always)]
consume(&mut self, amt: usize)428     pub fn consume(&mut self, amt: usize) {
429         assert!(amt <= self.limit_within_buf - self.pos_within_buf);
430         self.pos_within_buf += amt;
431     }
432 }
433 
#[cfg(all(test, feature = "bytes"))]
mod test_bytes {
    use super::*;
    use std::io::Write;

    /// Produces exactly `len` bytes of deterministic ASCII test data: the decimal value of the
    /// current length is appended repeatedly ("0", "01", "012", ...) and the result is cut
    /// to `len`.
    fn make_long_string(len: usize) -> Vec<u8> {
        let mut out = Vec::new();
        while out.len() < len {
            let written_so_far = out.len();
            write!(out, "{}", written_so_far).expect("unexpected");
        }
        out.truncate(len);
        out
    }

    /// On a plain slice, `read_exact_bytes` returns the requested prefix and leaves the
    /// cursor right after it.
    #[test]
    fn read_exact_bytes_from_slice() {
        let data = make_long_string(100);
        let mut iter = BufReadIter::from_byte_slice(&data);

        let head = iter.read_exact_bytes(90).unwrap();
        assert_eq!(&data[..90], &head[..]);

        let next = iter.read_byte().expect("read_byte");
        assert_eq!(data[90], next);
    }

    /// On a `Bytes` source, `read_exact_bytes` returns a view sharing the input's storage.
    #[test]
    fn read_exact_bytes_from_bytes() {
        let data = Bytes::from(make_long_string(100));
        let mut iter = BufReadIter::from_bytes(&data);

        let head = iter.read_exact_bytes(90).unwrap();
        assert_eq!(&data[..90], &head[..]);
        // Same start address: no copy was made.
        assert_eq!(data.as_ptr(), head.as_ptr());

        let next = iter.read_byte().expect("read_byte");
        assert_eq!(data[90], next);
    }
}
467 
#[cfg(test)]
mod test {
    use super::*;
    use std::io;
    use std::io::BufRead;
    use std::io::Read;

    /// Once a pushed limit is reached, `eof` must report `true` without asking the underlying
    /// reader for more data (which could block on a synchronous stream).
    #[test]
    fn eof_at_limit() {
        /// `BufRead` that serves the five bytes `0..5` from a single buffer. It panics if asked
        /// for anything else: `read` must never be called, and `fill_buf` only before any
        /// bytes were consumed.
        struct Read5ThenPanic {
            pos: usize,
        }

        impl Read for Read5ThenPanic {
            fn read(&mut self, _buf: &mut [u8]) -> io::Result<usize> {
                unreachable!();
            }
        }

        impl BufRead for Read5ThenPanic {
            fn fill_buf(&mut self) -> io::Result<&[u8]> {
                assert_eq!(0, self.pos);
                static ZERO_TO_FIVE: &'static [u8] = &[0, 1, 2, 3, 4];
                Ok(ZERO_TO_FIVE)
            }

            fn consume(&mut self, amt: usize) {
                if amt == 0 {
                    // The first refill of `BufReadIter` consumes its (still empty) buffer.
                    return;
                }

                // The only non-empty consume is from dropping `BufReadIter`, which hands back
                // all five bytes it read.
                assert_eq!(0, self.pos);
                assert_eq!(5, amt);
                self.pos += amt;
            }
        }

        let mut read = Read5ThenPanic { pos: 0 };
        {
            let mut buf_read_iter = BufReadIter::from_buf_read(&mut read);
            assert_eq!(0, buf_read_iter.pos());
            buf_read_iter.push_limit(5).expect("push_limit");

            assert_eq!(0, buf_read_iter.read_byte().expect("read_byte"));

            let mut rest = [0u8; 4];
            buf_read_iter.read_exact(&mut rest).expect("read_exact");
            assert_eq!([1, 2, 3, 4], rest);
            assert_eq!(5, buf_read_iter.pos());

            // Limit reached: `eof` must not call `fill_buf` on the reader again.
            assert!(buf_read_iter.eof().expect("eof"));
        }
        // Dropping the iterator reported the five consumed bytes back to the reader.
        assert_eq!(5, read.pos);
    }
}
517