1 // Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
2 // http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
3 // <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
4 // option. This file may not be copied, modified, or distributed
5 // except according to those terms.
6 
7 use crate::constants::{ContentType, Epoch};
8 use crate::err::{nspr, Error, PR_SetError, Res};
9 use crate::prio;
10 use crate::ssl;
11 
12 use neqo_common::{hex, hex_with_len, qtrace};
13 use std::cmp::min;
14 use std::convert::{TryFrom, TryInto};
15 use std::fmt;
16 use std::mem;
17 use std::ops::Deref;
18 use std::os::raw::{c_uint, c_void};
19 use std::pin::Pin;
20 use std::ptr::{null, null_mut};
21 use std::vec::Vec;
22 
// Aliases for the NSPR types used by the I/O layer callbacks in this file.
type PrFd = *mut prio::PRFileDesc;
type PrStatus = prio::PRStatus::Type;
// NSPR status values returned by the layer callbacks.
const PR_SUCCESS: PrStatus = prio::PRStatus::PR_SUCCESS;
const PR_FAILURE: PrStatus = prio::PRStatus::PR_FAILURE;
28 
/// Expose the heap address of a pinned, boxed value as an untyped C pointer.
///
/// The box keeps the value at a stable address, so the returned pointer stays valid for as
/// long as the `Pin<Box<T>>` itself is alive. `T: Unpin` is what allows a plain `&mut T` to be
/// taken out of the pin.
pub fn as_c_void<T: Unpin>(pin: &mut Pin<Box<T>>) -> *mut c_void {
    let target: &mut T = Pin::get_mut(pin.as_mut());
    let typed: *mut T = target;
    typed.cast::<c_void>()
}
33 
34 /// A slice of the output.
35 #[derive(Default)]
36 pub struct Record {
37     pub epoch: Epoch,
38     pub ct: ContentType,
39     pub data: Vec<u8>,
40 }
41 
42 impl Record {
43     #[must_use]
new(epoch: Epoch, ct: ContentType, data: &[u8]) -> Self44     pub fn new(epoch: Epoch, ct: ContentType, data: &[u8]) -> Self {
45         Self {
46             epoch,
47             ct,
48             data: data.to_vec(),
49         }
50     }
51 
52     // Shoves this record into the socket, returns true if blocked.
write(self, fd: *mut ssl::PRFileDesc) -> Res<()>53     pub(crate) fn write(self, fd: *mut ssl::PRFileDesc) -> Res<()> {
54         qtrace!("write {:?}", self);
55         unsafe {
56             ssl::SSL_RecordLayerData(
57                 fd,
58                 self.epoch,
59                 ssl::SSLContentType::Type::from(self.ct),
60                 self.data.as_ptr(),
61                 c_uint::try_from(self.data.len())?,
62             )
63         }
64     }
65 }
66 
67 impl fmt::Debug for Record {
fmt(&self, f: &mut fmt::Formatter) -> fmt::Result68     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
69         write!(
70             f,
71             "Record {:?}:{:?} {}",
72             self.epoch,
73             self.ct,
74             hex_with_len(&self.data[..])
75         )
76     }
77 }
78 
79 #[derive(Debug, Default)]
80 pub struct RecordList {
81     records: Vec<Record>,
82 }
83 
84 impl RecordList {
append(&mut self, epoch: Epoch, ct: ContentType, data: &[u8])85     fn append(&mut self, epoch: Epoch, ct: ContentType, data: &[u8]) {
86         self.records.push(Record::new(epoch, ct, data));
87     }
88 
89     #[allow(clippy::unused_self)]
ingest( _fd: *mut ssl::PRFileDesc, epoch: ssl::PRUint16, ct: ssl::SSLContentType::Type, data: *const ssl::PRUint8, len: c_uint, arg: *mut c_void, ) -> ssl::SECStatus90     unsafe extern "C" fn ingest(
91         _fd: *mut ssl::PRFileDesc,
92         epoch: ssl::PRUint16,
93         ct: ssl::SSLContentType::Type,
94         data: *const ssl::PRUint8,
95         len: c_uint,
96         arg: *mut c_void,
97     ) -> ssl::SECStatus {
98         let records = arg.cast::<Self>().as_mut().unwrap();
99 
100         let slice = std::slice::from_raw_parts(data, len as usize);
101         records.append(epoch, ContentType::try_from(ct).unwrap(), slice);
102         ssl::SECSuccess
103     }
104 
105     /// Create a new record list.
setup(fd: *mut ssl::PRFileDesc) -> Res<Pin<Box<Self>>>106     pub(crate) fn setup(fd: *mut ssl::PRFileDesc) -> Res<Pin<Box<Self>>> {
107         let mut records = Box::pin(Self::default());
108         unsafe {
109             ssl::SSL_RecordLayerWriteCallback(fd, Some(Self::ingest), as_c_void(&mut records))
110         }?;
111         Ok(records)
112     }
113 }
114 
115 impl Deref for RecordList {
116     type Target = Vec<Record>;
117     #[must_use]
deref(&self) -> &Vec<Record>118     fn deref(&self) -> &Vec<Record> {
119         &self.records
120     }
121 }
122 
123 pub struct RecordListIter(std::vec::IntoIter<Record>);
124 
125 impl Iterator for RecordListIter {
126     type Item = Record;
next(&mut self) -> Option<Self::Item>127     fn next(&mut self) -> Option<Self::Item> {
128         self.0.next()
129     }
130 }
131 
132 impl IntoIterator for RecordList {
133     type Item = Record;
134     type IntoIter = RecordListIter;
135     #[must_use]
into_iter(self) -> Self::IntoIter136     fn into_iter(self) -> Self::IntoIter {
137         RecordListIter(self.records.into_iter())
138     }
139 }
140 
141 pub struct AgentIoInputContext<'a> {
142     input: &'a mut AgentIoInput,
143 }
144 
145 impl<'a> Drop for AgentIoInputContext<'a> {
drop(&mut self)146     fn drop(&mut self) {
147         self.input.reset();
148     }
149 }
150 
151 #[derive(Debug)]
152 struct AgentIoInput {
153     // input is data that is read by TLS.
154     input: *const u8,
155     // input_available is how much data is left for reading.
156     available: usize,
157 }
158 
159 impl AgentIoInput {
wrap<'a: 'c, 'b: 'c, 'c>(&'a mut self, input: &'b [u8]) -> AgentIoInputContext<'c>160     fn wrap<'a: 'c, 'b: 'c, 'c>(&'a mut self, input: &'b [u8]) -> AgentIoInputContext<'c> {
161         assert!(self.input.is_null());
162         self.input = input.as_ptr();
163         self.available = input.len();
164         qtrace!("AgentIoInput wrap {:p}", self.input);
165         AgentIoInputContext { input: self }
166     }
167 
168     // Take the data provided as input and provide it to the TLS stack.
read_input(&mut self, buf: *mut u8, count: usize) -> Res<usize>169     fn read_input(&mut self, buf: *mut u8, count: usize) -> Res<usize> {
170         let amount = min(self.available, count);
171         if amount == 0 {
172             unsafe {
173                 PR_SetError(nspr::PR_WOULD_BLOCK_ERROR, 0);
174             }
175             return Err(Error::NoDataAvailable);
176         }
177 
178         let src = unsafe { std::slice::from_raw_parts(self.input, amount) };
179         qtrace!([self], "read {}", hex(src));
180         let dst = unsafe { std::slice::from_raw_parts_mut(buf, amount) };
181         dst.copy_from_slice(src);
182         self.input = self.input.wrapping_add(amount);
183         self.available -= amount;
184         Ok(amount)
185     }
186 
reset(&mut self)187     fn reset(&mut self) {
188         qtrace!([self], "reset");
189         self.input = null();
190         self.available = 0;
191     }
192 }
193 
194 impl ::std::fmt::Display for AgentIoInput {
fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result195     fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
196         write!(f, "AgentIoInput {:p}", self.input)
197     }
198 }
199 
200 #[derive(Debug)]
201 pub struct AgentIo {
202     // input collects the input we might provide to TLS.
203     input: AgentIoInput,
204 
205     // output contains data that is written by TLS.
206     output: Vec<u8>,
207 }
208 
209 impl AgentIo {
new() -> Self210     pub fn new() -> Self {
211         Self {
212             input: AgentIoInput {
213                 input: null(),
214                 available: 0,
215             },
216             output: Vec::new(),
217         }
218     }
219 
borrow(fd: &mut PrFd) -> &mut Self220     unsafe fn borrow(fd: &mut PrFd) -> &mut Self {
221         #[allow(clippy::cast_ptr_alignment)]
222         (**fd).secret.cast::<Self>().as_mut().unwrap()
223     }
224 
wrap<'a: 'c, 'b: 'c, 'c>(&'a mut self, input: &'b [u8]) -> AgentIoInputContext<'c>225     pub fn wrap<'a: 'c, 'b: 'c, 'c>(&'a mut self, input: &'b [u8]) -> AgentIoInputContext<'c> {
226         assert_eq!(self.output.len(), 0);
227         self.input.wrap(input)
228     }
229 
230     // Stage output from TLS into the output buffer.
save_output(&mut self, buf: *const u8, count: usize)231     fn save_output(&mut self, buf: *const u8, count: usize) {
232         let slice = unsafe { std::slice::from_raw_parts(buf, count) };
233         qtrace!([self], "save output {}", hex(slice));
234         self.output.extend_from_slice(slice);
235     }
236 
take_output(&mut self) -> Vec<u8>237     pub fn take_output(&mut self) -> Vec<u8> {
238         qtrace!([self], "take output");
239         mem::take(&mut self.output)
240     }
241 }
242 
243 impl ::std::fmt::Display for AgentIo {
fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result244     fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
245         write!(f, "AgentIo")
246     }
247 }
248 
agent_close(fd: PrFd) -> PrStatus249 unsafe extern "C" fn agent_close(fd: PrFd) -> PrStatus {
250     (*fd).secret = null_mut();
251     if let Some(dtor) = (*fd).dtor {
252         dtor(fd);
253     }
254     PR_SUCCESS
255 }
256 
agent_read(mut fd: PrFd, buf: *mut c_void, amount: prio::PRInt32) -> PrStatus257 unsafe extern "C" fn agent_read(mut fd: PrFd, buf: *mut c_void, amount: prio::PRInt32) -> PrStatus {
258     let io = AgentIo::borrow(&mut fd);
259     if let Ok(a) = usize::try_from(amount) {
260         match io.input.read_input(buf.cast(), a) {
261             Ok(_) => PR_SUCCESS,
262             Err(_) => PR_FAILURE,
263         }
264     } else {
265         PR_FAILURE
266     }
267 }
268 
agent_recv( mut fd: PrFd, buf: *mut c_void, amount: prio::PRInt32, flags: prio::PRIntn, _timeout: prio::PRIntervalTime, ) -> prio::PRInt32269 unsafe extern "C" fn agent_recv(
270     mut fd: PrFd,
271     buf: *mut c_void,
272     amount: prio::PRInt32,
273     flags: prio::PRIntn,
274     _timeout: prio::PRIntervalTime,
275 ) -> prio::PRInt32 {
276     let io = AgentIo::borrow(&mut fd);
277     if flags != 0 {
278         return PR_FAILURE;
279     }
280     if let Ok(a) = usize::try_from(amount) {
281         match io.input.read_input(buf.cast(), a) {
282             Ok(v) => prio::PRInt32::try_from(v).unwrap_or(PR_FAILURE),
283             Err(_) => PR_FAILURE,
284         }
285     } else {
286         PR_FAILURE
287     }
288 }
289 
agent_write( mut fd: PrFd, buf: *const c_void, amount: prio::PRInt32, ) -> PrStatus290 unsafe extern "C" fn agent_write(
291     mut fd: PrFd,
292     buf: *const c_void,
293     amount: prio::PRInt32,
294 ) -> PrStatus {
295     let io = AgentIo::borrow(&mut fd);
296     if let Ok(a) = usize::try_from(amount) {
297         io.save_output(buf.cast(), a);
298         amount
299     } else {
300         PR_FAILURE
301     }
302 }
303 
agent_send( mut fd: PrFd, buf: *const c_void, amount: prio::PRInt32, flags: prio::PRIntn, _timeout: prio::PRIntervalTime, ) -> prio::PRInt32304 unsafe extern "C" fn agent_send(
305     mut fd: PrFd,
306     buf: *const c_void,
307     amount: prio::PRInt32,
308     flags: prio::PRIntn,
309     _timeout: prio::PRIntervalTime,
310 ) -> prio::PRInt32 {
311     let io = AgentIo::borrow(&mut fd);
312 
313     if flags != 0 {
314         return PR_FAILURE;
315     }
316     if let Ok(a) = usize::try_from(amount) {
317         io.save_output(buf.cast(), a);
318         amount
319     } else {
320         PR_FAILURE
321     }
322 }
323 
agent_available(mut fd: PrFd) -> prio::PRInt32324 unsafe extern "C" fn agent_available(mut fd: PrFd) -> prio::PRInt32 {
325     let io = AgentIo::borrow(&mut fd);
326     io.input.available.try_into().unwrap_or(PR_FAILURE)
327 }
328 
agent_available64(mut fd: PrFd) -> prio::PRInt64329 unsafe extern "C" fn agent_available64(mut fd: PrFd) -> prio::PRInt64 {
330     let io = AgentIo::borrow(&mut fd);
331     io.input
332         .available
333         .try_into()
334         .unwrap_or_else(|_| PR_FAILURE.into())
335 }
336 
337 #[allow(clippy::cast_possible_truncation)]
agent_getname(_fd: PrFd, addr: *mut prio::PRNetAddr) -> PrStatus338 unsafe extern "C" fn agent_getname(_fd: PrFd, addr: *mut prio::PRNetAddr) -> PrStatus {
339     let a = addr.as_mut().unwrap();
340     // Cast is safe because prio::PR_AF_INET is 2
341     a.inet.family = prio::PR_AF_INET as prio::PRUint16;
342     a.inet.port = 0;
343     a.inet.ip = 0;
344     PR_SUCCESS
345 }
346 
agent_getsockopt(_fd: PrFd, opt: *mut prio::PRSocketOptionData) -> PrStatus347 unsafe extern "C" fn agent_getsockopt(_fd: PrFd, opt: *mut prio::PRSocketOptionData) -> PrStatus {
348     let o = opt.as_mut().unwrap();
349     if o.option == prio::PRSockOption::PR_SockOpt_Nonblocking {
350         o.value.non_blocking = 1;
351         return PR_SUCCESS;
352     }
353     PR_FAILURE
354 }
355 
/// Method table for the NSPR layered I/O that routes NSS reads and writes through an
/// `AgentIo` (recovered from the descriptor's `secret`).
///
/// Only the operations implemented in this file are provided; every other slot is `None`.
pub const METHODS: &prio::PRIOMethods = &prio::PRIOMethods {
    file_type: prio::PRDescType::PR_DESC_LAYERED,
    close: Some(agent_close),
    read: Some(agent_read),
    write: Some(agent_write),
    available: Some(agent_available),
    available64: Some(agent_available64),
    fsync: None,
    seek: None,
    seek64: None,
    fileInfo: None,
    fileInfo64: None,
    writev: None,
    connect: None,
    accept: None,
    bind: None,
    listen: None,
    shutdown: None,
    recv: Some(agent_recv),
    send: Some(agent_send),
    recvfrom: None,
    sendto: None,
    poll: None,
    acceptread: None,
    transmitfile: None,
    // Both local and peer names report the same placeholder address.
    getsockname: Some(agent_getname),
    getpeername: Some(agent_getname),
    reserved_fn_6: None,
    reserved_fn_5: None,
    getsocketoption: Some(agent_getsockopt),
    setsocketoption: None,
    sendfile: None,
    connectcontinue: None,
    reserved_fn_3: None,
    reserved_fn_2: None,
    reserved_fn_1: None,
    reserved_fn_0: None,
};
394