1 use crate::Asn1DerError;
2 use std::{
3     io::{self, Read, Write},
4     mem::size_of,
5 };
6 
/// Number of bytes occupied by a `usize` on the target platform.
///
/// This is also the maximum number of big-endian length bytes that
/// `Length::deserialized` accepts; longer encodings are rejected.
const USIZE_LEN: usize = size_of::<usize>();
9 
/// Byte-level convenience methods available on every [`io::Read`] implementor.
pub trait ReadExt {
    /// Reads exactly one byte from the source and returns it.
    ///
    /// # Errors
    /// Propagates the error of `read_exact`, which is `UnexpectedEof` when the
    /// source has no byte left.
    fn read_one(&mut self) -> io::Result<u8>;
}

impl<T: Read> ReadExt for T {
    fn read_one(&mut self) -> io::Result<u8> {
        let mut single = [0u8; 1];
        self.read_exact(&mut single).map(|()| single[0])
    }
}
/// Convenience methods available on every [`io::Write`] implementor that report how
/// many bytes they wrote.
pub trait WriteExt {
    /// Writes the single byte `byte` and returns `1` on success.
    fn write_one(&mut self, byte: u8) -> io::Result<usize>;
    /// Writes every byte of `data` and returns `data.len()` on success.
    ///
    /// # Errors
    /// Propagates the first error reported by `write_all` (for example `WriteZero`
    /// when the sink cannot accept more bytes).
    fn write_exact(&mut self, data: &[u8]) -> io::Result<usize>;
}

impl<T: Write> WriteExt for T {
    fn write_one(&mut self, byte: u8) -> io::Result<usize> {
        self.write_all(&[byte]).map(|()| 1)
    }

    fn write_exact(&mut self, data: &[u8]) -> io::Result<usize> {
        let count = data.len();
        self.write_all(data).map(|()| count)
    }
}
38 
/// Capacity of the look-ahead buffer used by `PeekableReader::peek_buffer`.
const PEEKED_BUFFER_SIZE: usize = 10;

/// Bytes that were pulled from an underlying reader but have not been consumed yet.
///
/// Only the first `len` bytes of `buffer` are meaningful; the remaining bytes of the
/// array carry no information.
#[derive(Debug)]
pub struct PeekedContent {
    /// Number of valid bytes at the front of `buffer`.
    len: usize,
    /// Backing storage; only `buffer[..len]` is valid.
    buffer: [u8; PEEKED_BUFFER_SIZE],
}

impl PeekedContent {
    /// Creates an empty `PeekedContent` with zeroed storage.
    fn new() -> Self {
        Self {
            len: 0,
            buffer: [0; PEEKED_BUFFER_SIZE],
        }
    }

    /// Moves the current content out, leaving an empty `PeekedContent` in its place.
    pub fn take(&mut self) -> Self {
        std::mem::replace(self, Self::new())
    }

    /// Returns the number of valid peeked bytes.
    pub fn len(&self) -> usize {
        self.len
    }

    /// Returns `true` if no bytes are currently peeked.
    pub fn is_empty(&self) -> bool {
        self.len == 0
    }

    /// Returns a copy of the whole backing array.
    ///
    /// Bytes at index `len()` and beyond are not valid; use `as_slice` to get only
    /// the peeked bytes.
    pub fn buffer(&self) -> [u8; PEEKED_BUFFER_SIZE] {
        self.buffer
    }

    /// Returns only the valid peeked bytes, i.e. the first `len()` bytes of the buffer.
    pub fn as_slice(&self) -> &[u8] {
        &self.buffer[..self.len]
    }
}
69 
70 /// A peekable reader
71 pub struct PeekableReader<R: Read> {
72     reader: R,
73     peeked: PeekedContent,
74     pos: usize,
75 }
76 impl<R: Read> PeekableReader<R> {
77     /// Creates a new `PeekableReader` with `reader` as source
new(reader: R) -> Self78     pub fn new(reader: R) -> Self {
79         Self {
80             reader,
81             peeked: PeekedContent::new(),
82             pos: 0,
83         }
84     }
85 
86     /// Peeks one byte without removing it from the `read`-queue
87     ///
88     /// Multiple successive calls to `peek_one` will always return the same next byte
peek_one(&mut self) -> io::Result<u8>89     pub fn peek_one(&mut self) -> io::Result<u8> {
90         // Check if we already have peeked data
91         if self.peeked.len == 0 {
92             self.peeked.buffer[0] = self.reader.read_one()?;
93             self.peeked.len = 1;
94         }
95         Ok(self.peeked.buffer[0])
96     }
97 
98     /// Peeks several bytes at once without removing them from the `read`-queue
99     /// Buffer size is defined by `PeekedBuffer`.
100     ///
101     /// Successive calls to `peek_buffer` always return the same bytes.
peek_buffer(&mut self) -> io::Result<&PeekedContent>102     pub fn peek_buffer(&mut self) -> io::Result<&PeekedContent> {
103         // Check if we already have peeked data
104         if self.peeked.len < PEEKED_BUFFER_SIZE {
105             let n = self.reader.read(&mut self.peeked.buffer[self.peeked.len..])?;
106             self.peeked.len += n;
107         }
108 
109         Ok(&self.peeked)
110     }
111 
112     /// The current position (amount of bytes read)
pos(&self) -> usize113     pub fn pos(&self) -> usize {
114         self.pos
115     }
116 }
117 impl<R: Read> Read for PeekableReader<R> {
read(&mut self, mut buf: &mut [u8]) -> io::Result<usize>118     fn read(&mut self, mut buf: &mut [u8]) -> io::Result<usize> {
119         let mut read = 0;
120 
121         let peeked = self.peeked.take();
122         let new_start_index = if buf.len() <= peeked.len {
123             buf.copy_from_slice(&peeked.buffer[..buf.len()]);
124 
125             // keep remaining peeked bytes
126             let remaining_bytes = peeked.len - buf.len();
127             if remaining_bytes > 0 {
128                 self.peeked.buffer[..remaining_bytes].copy_from_slice(&peeked.buffer[buf.len()..peeked.len]);
129                 self.peeked.len = remaining_bytes;
130             }
131 
132             buf.len()
133         } else {
134             buf[..peeked.len].copy_from_slice(&peeked.buffer[..peeked.len]);
135             peeked.len
136         };
137         read += new_start_index;
138         buf = &mut buf[new_start_index..];
139 
140         // Read remaining bytes
141         read += self.reader.read(buf)?;
142 
143         self.pos += read;
144 
145         Ok(read)
146     }
147 }
148 
149 /// An implementation of the ASN.1-DER length
150 pub struct Length;
151 impl Length {
152     /// Deserializes a length from `reader`
deserialized(mut reader: impl Read) -> Result<usize, Asn1DerError>153     pub fn deserialized(mut reader: impl Read) -> Result<usize, Asn1DerError> {
154         // Deserialize length
155         Ok(match reader.read_one()? {
156             n @ 128..=255 => {
157                 // Deserialize the amount of length bytes
158                 let len = n as usize & 127;
159                 if len > USIZE_LEN {
160                     return Err(Asn1DerError::UnsupportedValue);
161                 }
162 
163                 // Deserialize value
164                 let mut num = [0; USIZE_LEN];
165                 reader.read_exact(&mut num[USIZE_LEN - len..])?;
166                 usize::from_be_bytes(num)
167             }
168             n => n as usize,
169         })
170     }
171 
172     /// Serializes `len` to `writer`
serialize(len: usize, mut writer: impl Write) -> Result<usize, Asn1DerError>173     pub fn serialize(len: usize, mut writer: impl Write) -> Result<usize, Asn1DerError> {
174         // Determine the serialized length
175         let written = match len {
176             0..=127 => writer.write_one(len as u8)?,
177             _ => {
178                 let to_write = USIZE_LEN - (len.leading_zeros() / 8) as usize;
179                 // Write number of bytes used to encode length
180                 let mut written = writer.write_one(to_write as u8 | 0x80)?;
181 
182                 // Write length
183                 let mut buf = [0; USIZE_LEN];
184                 buf.copy_from_slice(&len.to_be_bytes());
185                 written += writer.write_exact(&buf[USIZE_LEN - to_write..])?;
186 
187                 written
188             }
189         };
190 
191         Ok(written)
192     }
193 
194     /// Returns how many bytes are going to be needed to encode `len`.
encoded_len(len: usize) -> usize195     pub fn encoded_len(len: usize) -> usize {
196         match len {
197             0..=127 => 1,
198             _ => 1 + USIZE_LEN - (len.leading_zeros() / 8) as usize,
199         }
200     }
201 }
202 
#[cfg(test)]
mod tests {
    use super::*;

    /// Lengths covering the short form, the long-form width boundaries up to three
    /// bytes, and the largest representable value (`!0` is `usize::MAX`).
    const SAMPLE_LENGTHS: [usize; 10] = [0, 1, 127, 128, 255, 256, 290, 65_535, 65_536, !0];

    #[test]
    fn asn1_short_form_length() {
        let mut writer: Vec<u8> = Vec::new();
        let written = Length::serialize(10, &mut writer).expect("serialization failed");
        assert_eq!(written, 1);
        assert_eq!(writer.len(), 1);
        assert_eq!(writer[0], 10);
    }

    #[test]
    fn asn1_long_form_length_1_byte() {
        let mut writer: Vec<u8> = Vec::new();
        let written = Length::serialize(129, &mut writer).expect("serialization failed");
        assert_eq!(written, 2);
        assert_eq!(writer.len(), 2);
        assert_eq!(writer[0], 0x81);
        assert_eq!(writer[1], 0x81);
    }

    #[test]
    fn asn1_long_form_length_2_bytes() {
        let mut writer: Vec<u8> = Vec::new();
        let written = Length::serialize(290, &mut writer).expect("serialization failed");
        assert_eq!(written, 3);
        assert_eq!(writer.len(), 3);
        assert_eq!(writer[0], 0x82);
        assert_eq!(writer[1], 0x01);
        assert_eq!(writer[2], 0x22);
    }

    #[test]
    fn asn1_encoded_len_matches_serialized_size() {
        for &len in SAMPLE_LENGTHS.iter() {
            let mut writer: Vec<u8> = Vec::new();
            let written = Length::serialize(len, &mut writer).expect("serialization failed");
            assert_eq!(written, writer.len());
            assert_eq!(written, Length::encoded_len(len));
        }
    }

    #[test]
    fn asn1_length_round_trip() {
        for &len in SAMPLE_LENGTHS.iter() {
            let mut writer: Vec<u8> = Vec::new();
            Length::serialize(len, &mut writer).expect("serialization failed");
            let decoded = Length::deserialized(&writer[..]).expect("deserialization failed");
            assert_eq!(decoded, len);
        }
    }

    #[test]
    fn asn1_length_rejects_too_many_length_bytes() {
        // 0x89 announces 9 length bytes, more than a 32- or 64-bit `usize` can hold.
        let encoded = [0x89u8, 1, 2, 3, 4, 5, 6, 7, 8, 9];
        assert!(Length::deserialized(&encoded[..]).is_err());
    }

    #[test]
    fn asn1_length_rejects_truncated_input() {
        assert!(Length::deserialized(&[0u8; 0][..]).is_err());
        assert!(Length::deserialized(&[0x82u8, 0x01][..]).is_err());
    }

    #[test]
    fn peekable_reader_peek_does_not_consume() {
        let data = [1u8, 2, 3, 4];
        let mut reader = PeekableReader::new(&data[..]);
        assert_eq!(reader.peek_one().expect("peek failed"), 1);
        assert_eq!(reader.peek_one().expect("peek failed"), 1);

        let mut buf = [0u8; 4];
        reader.read_exact(&mut buf).expect("read failed");
        assert_eq!(buf, data);
        assert_eq!(reader.pos(), 4);
    }

    #[test]
    fn peekable_reader_keeps_unread_peeked_bytes() {
        let data = [10u8, 20, 30];
        let mut reader = PeekableReader::new(&data[..]);
        let peeked_len = reader.peek_buffer().expect("peek failed").len();
        assert_eq!(peeked_len, 3);

        let mut first = [0u8; 2];
        assert_eq!(reader.read(&mut first).expect("read failed"), 2);
        assert_eq!(first, [10, 20]);
        assert_eq!(reader.peek_one().expect("peek failed"), 30);

        let mut rest = [0u8; 4];
        assert_eq!(reader.read(&mut rest).expect("read failed"), 1);
        assert_eq!(rest[0], 30);
        assert_eq!(reader.pos(), 3);
    }
}
237