1 // Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
2 // http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
3 // <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
4 // option. This file may not be copied, modified, or distributed
5 // except according to those terms.
6
7 use crate::constants::{ContentType, Epoch};
8 use crate::err::{nspr, Error, PR_SetError, Res};
9 use crate::prio;
10 use crate::ssl;
11
12 use neqo_common::{hex, hex_with_len, qtrace};
13 use std::cmp::min;
14 use std::convert::{TryFrom, TryInto};
15 use std::fmt;
16 use std::mem;
17 use std::ops::Deref;
18 use std::os::raw::{c_uint, c_void};
19 use std::pin::Pin;
20 use std::ptr::{null, null_mut};
21 use std::vec::Vec;
22
// Alias common types.
// Raw NSPR file descriptor pointer; every `agent_*` I/O method below receives one.
type PrFd = *mut prio::PRFileDesc;
// Status type of NSPR's `PRStatus`. The I/O methods below also return
// `PR_FAILURE` from functions typed as `PRInt32`, so this is the same integer type.
type PrStatus = prio::PRStatus::Type;
const PR_SUCCESS: PrStatus = prio::PRStatus::PR_SUCCESS;
const PR_FAILURE: PrStatus = prio::PRStatus::PR_FAILURE;
28
/// Erase the type of a pinned, boxed value, yielding a raw `void *` to the
/// boxed contents so it can be passed to C code as an opaque argument.
///
/// The pointer refers to the heap allocation owned by `pin`; it is only
/// meaningful while that box is alive.
pub fn as_c_void<T: Unpin>(pin: &mut Pin<Box<T>>) -> *mut c_void {
    let target: &mut T = Pin::into_inner(pin.as_mut());
    let typed: *mut T = target;
    typed.cast::<c_void>()
}
33
34 /// A slice of the output.
35 #[derive(Default)]
36 pub struct Record {
37 pub epoch: Epoch,
38 pub ct: ContentType,
39 pub data: Vec<u8>,
40 }
41
42 impl Record {
43 #[must_use]
new(epoch: Epoch, ct: ContentType, data: &[u8]) -> Self44 pub fn new(epoch: Epoch, ct: ContentType, data: &[u8]) -> Self {
45 Self {
46 epoch,
47 ct,
48 data: data.to_vec(),
49 }
50 }
51
52 // Shoves this record into the socket, returns true if blocked.
write(self, fd: *mut ssl::PRFileDesc) -> Res<()>53 pub(crate) fn write(self, fd: *mut ssl::PRFileDesc) -> Res<()> {
54 qtrace!("write {:?}", self);
55 unsafe {
56 ssl::SSL_RecordLayerData(
57 fd,
58 self.epoch,
59 ssl::SSLContentType::Type::from(self.ct),
60 self.data.as_ptr(),
61 c_uint::try_from(self.data.len())?,
62 )
63 }
64 }
65 }
66
67 impl fmt::Debug for Record {
fmt(&self, f: &mut fmt::Formatter) -> fmt::Result68 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
69 write!(
70 f,
71 "Record {:?}:{:?} {}",
72 self.epoch,
73 self.ct,
74 hex_with_len(&self.data[..])
75 )
76 }
77 }
78
79 #[derive(Debug, Default)]
80 pub struct RecordList {
81 records: Vec<Record>,
82 }
83
84 impl RecordList {
append(&mut self, epoch: Epoch, ct: ContentType, data: &[u8])85 fn append(&mut self, epoch: Epoch, ct: ContentType, data: &[u8]) {
86 self.records.push(Record::new(epoch, ct, data));
87 }
88
89 #[allow(clippy::unused_self)]
ingest( _fd: *mut ssl::PRFileDesc, epoch: ssl::PRUint16, ct: ssl::SSLContentType::Type, data: *const ssl::PRUint8, len: c_uint, arg: *mut c_void, ) -> ssl::SECStatus90 unsafe extern "C" fn ingest(
91 _fd: *mut ssl::PRFileDesc,
92 epoch: ssl::PRUint16,
93 ct: ssl::SSLContentType::Type,
94 data: *const ssl::PRUint8,
95 len: c_uint,
96 arg: *mut c_void,
97 ) -> ssl::SECStatus {
98 let records = arg.cast::<Self>().as_mut().unwrap();
99
100 let slice = std::slice::from_raw_parts(data, len as usize);
101 records.append(epoch, ContentType::try_from(ct).unwrap(), slice);
102 ssl::SECSuccess
103 }
104
105 /// Create a new record list.
setup(fd: *mut ssl::PRFileDesc) -> Res<Pin<Box<Self>>>106 pub(crate) fn setup(fd: *mut ssl::PRFileDesc) -> Res<Pin<Box<Self>>> {
107 let mut records = Box::pin(Self::default());
108 unsafe {
109 ssl::SSL_RecordLayerWriteCallback(fd, Some(Self::ingest), as_c_void(&mut records))
110 }?;
111 Ok(records)
112 }
113 }
114
115 impl Deref for RecordList {
116 type Target = Vec<Record>;
117 #[must_use]
deref(&self) -> &Vec<Record>118 fn deref(&self) -> &Vec<Record> {
119 &self.records
120 }
121 }
122
123 pub struct RecordListIter(std::vec::IntoIter<Record>);
124
125 impl Iterator for RecordListIter {
126 type Item = Record;
next(&mut self) -> Option<Self::Item>127 fn next(&mut self) -> Option<Self::Item> {
128 self.0.next()
129 }
130 }
131
132 impl IntoIterator for RecordList {
133 type Item = Record;
134 type IntoIter = RecordListIter;
135 #[must_use]
into_iter(self) -> Self::IntoIter136 fn into_iter(self) -> Self::IntoIter {
137 RecordListIter(self.records.into_iter())
138 }
139 }
140
141 pub struct AgentIoInputContext<'a> {
142 input: &'a mut AgentIoInput,
143 }
144
145 impl<'a> Drop for AgentIoInputContext<'a> {
drop(&mut self)146 fn drop(&mut self) {
147 self.input.reset();
148 }
149 }
150
/// Raw cursor over caller-provided bytes that the TLS stack may read.
#[derive(Debug)]
struct AgentIoInput {
    /// Next unread byte of the wrapped input, or null when nothing is wrapped.
    input: *const u8,
    /// Number of bytes that can still be read starting at `input`.
    available: usize,
}
158
159 impl AgentIoInput {
wrap<'a: 'c, 'b: 'c, 'c>(&'a mut self, input: &'b [u8]) -> AgentIoInputContext<'c>160 fn wrap<'a: 'c, 'b: 'c, 'c>(&'a mut self, input: &'b [u8]) -> AgentIoInputContext<'c> {
161 assert!(self.input.is_null());
162 self.input = input.as_ptr();
163 self.available = input.len();
164 qtrace!("AgentIoInput wrap {:p}", self.input);
165 AgentIoInputContext { input: self }
166 }
167
168 // Take the data provided as input and provide it to the TLS stack.
read_input(&mut self, buf: *mut u8, count: usize) -> Res<usize>169 fn read_input(&mut self, buf: *mut u8, count: usize) -> Res<usize> {
170 let amount = min(self.available, count);
171 if amount == 0 {
172 unsafe {
173 PR_SetError(nspr::PR_WOULD_BLOCK_ERROR, 0);
174 }
175 return Err(Error::NoDataAvailable);
176 }
177
178 let src = unsafe { std::slice::from_raw_parts(self.input, amount) };
179 qtrace!([self], "read {}", hex(src));
180 let dst = unsafe { std::slice::from_raw_parts_mut(buf, amount) };
181 dst.copy_from_slice(src);
182 self.input = self.input.wrapping_add(amount);
183 self.available -= amount;
184 Ok(amount)
185 }
186
reset(&mut self)187 fn reset(&mut self) {
188 qtrace!([self], "reset");
189 self.input = null();
190 self.available = 0;
191 }
192 }
193
194 impl ::std::fmt::Display for AgentIoInput {
fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result195 fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
196 write!(f, "AgentIoInput {:p}", self.input)
197 }
198 }
199
200 #[derive(Debug)]
201 pub struct AgentIo {
202 // input collects the input we might provide to TLS.
203 input: AgentIoInput,
204
205 // output contains data that is written by TLS.
206 output: Vec<u8>,
207 }
208
209 impl AgentIo {
new() -> Self210 pub fn new() -> Self {
211 Self {
212 input: AgentIoInput {
213 input: null(),
214 available: 0,
215 },
216 output: Vec::new(),
217 }
218 }
219
borrow(fd: &mut PrFd) -> &mut Self220 unsafe fn borrow(fd: &mut PrFd) -> &mut Self {
221 #[allow(clippy::cast_ptr_alignment)]
222 (**fd).secret.cast::<Self>().as_mut().unwrap()
223 }
224
wrap<'a: 'c, 'b: 'c, 'c>(&'a mut self, input: &'b [u8]) -> AgentIoInputContext<'c>225 pub fn wrap<'a: 'c, 'b: 'c, 'c>(&'a mut self, input: &'b [u8]) -> AgentIoInputContext<'c> {
226 assert_eq!(self.output.len(), 0);
227 self.input.wrap(input)
228 }
229
230 // Stage output from TLS into the output buffer.
save_output(&mut self, buf: *const u8, count: usize)231 fn save_output(&mut self, buf: *const u8, count: usize) {
232 let slice = unsafe { std::slice::from_raw_parts(buf, count) };
233 qtrace!([self], "save output {}", hex(slice));
234 self.output.extend_from_slice(slice);
235 }
236
take_output(&mut self) -> Vec<u8>237 pub fn take_output(&mut self) -> Vec<u8> {
238 qtrace!([self], "take output");
239 mem::take(&mut self.output)
240 }
241 }
242
243 impl ::std::fmt::Display for AgentIo {
fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result244 fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
245 write!(f, "AgentIo")
246 }
247 }
248
agent_close(fd: PrFd) -> PrStatus249 unsafe extern "C" fn agent_close(fd: PrFd) -> PrStatus {
250 (*fd).secret = null_mut();
251 if let Some(dtor) = (*fd).dtor {
252 dtor(fd);
253 }
254 PR_SUCCESS
255 }
256
agent_read(mut fd: PrFd, buf: *mut c_void, amount: prio::PRInt32) -> PrStatus257 unsafe extern "C" fn agent_read(mut fd: PrFd, buf: *mut c_void, amount: prio::PRInt32) -> PrStatus {
258 let io = AgentIo::borrow(&mut fd);
259 if let Ok(a) = usize::try_from(amount) {
260 match io.input.read_input(buf.cast(), a) {
261 Ok(_) => PR_SUCCESS,
262 Err(_) => PR_FAILURE,
263 }
264 } else {
265 PR_FAILURE
266 }
267 }
268
agent_recv( mut fd: PrFd, buf: *mut c_void, amount: prio::PRInt32, flags: prio::PRIntn, _timeout: prio::PRIntervalTime, ) -> prio::PRInt32269 unsafe extern "C" fn agent_recv(
270 mut fd: PrFd,
271 buf: *mut c_void,
272 amount: prio::PRInt32,
273 flags: prio::PRIntn,
274 _timeout: prio::PRIntervalTime,
275 ) -> prio::PRInt32 {
276 let io = AgentIo::borrow(&mut fd);
277 if flags != 0 {
278 return PR_FAILURE;
279 }
280 if let Ok(a) = usize::try_from(amount) {
281 match io.input.read_input(buf.cast(), a) {
282 Ok(v) => prio::PRInt32::try_from(v).unwrap_or(PR_FAILURE),
283 Err(_) => PR_FAILURE,
284 }
285 } else {
286 PR_FAILURE
287 }
288 }
289
agent_write( mut fd: PrFd, buf: *const c_void, amount: prio::PRInt32, ) -> PrStatus290 unsafe extern "C" fn agent_write(
291 mut fd: PrFd,
292 buf: *const c_void,
293 amount: prio::PRInt32,
294 ) -> PrStatus {
295 let io = AgentIo::borrow(&mut fd);
296 if let Ok(a) = usize::try_from(amount) {
297 io.save_output(buf.cast(), a);
298 amount
299 } else {
300 PR_FAILURE
301 }
302 }
303
agent_send( mut fd: PrFd, buf: *const c_void, amount: prio::PRInt32, flags: prio::PRIntn, _timeout: prio::PRIntervalTime, ) -> prio::PRInt32304 unsafe extern "C" fn agent_send(
305 mut fd: PrFd,
306 buf: *const c_void,
307 amount: prio::PRInt32,
308 flags: prio::PRIntn,
309 _timeout: prio::PRIntervalTime,
310 ) -> prio::PRInt32 {
311 let io = AgentIo::borrow(&mut fd);
312
313 if flags != 0 {
314 return PR_FAILURE;
315 }
316 if let Ok(a) = usize::try_from(amount) {
317 io.save_output(buf.cast(), a);
318 amount
319 } else {
320 PR_FAILURE
321 }
322 }
323
agent_available(mut fd: PrFd) -> prio::PRInt32324 unsafe extern "C" fn agent_available(mut fd: PrFd) -> prio::PRInt32 {
325 let io = AgentIo::borrow(&mut fd);
326 io.input.available.try_into().unwrap_or(PR_FAILURE)
327 }
328
agent_available64(mut fd: PrFd) -> prio::PRInt64329 unsafe extern "C" fn agent_available64(mut fd: PrFd) -> prio::PRInt64 {
330 let io = AgentIo::borrow(&mut fd);
331 io.input
332 .available
333 .try_into()
334 .unwrap_or_else(|_| PR_FAILURE.into())
335 }
336
337 #[allow(clippy::cast_possible_truncation)]
agent_getname(_fd: PrFd, addr: *mut prio::PRNetAddr) -> PrStatus338 unsafe extern "C" fn agent_getname(_fd: PrFd, addr: *mut prio::PRNetAddr) -> PrStatus {
339 let a = addr.as_mut().unwrap();
340 // Cast is safe because prio::PR_AF_INET is 2
341 a.inet.family = prio::PR_AF_INET as prio::PRUint16;
342 a.inet.port = 0;
343 a.inet.ip = 0;
344 PR_SUCCESS
345 }
346
agent_getsockopt(_fd: PrFd, opt: *mut prio::PRSocketOptionData) -> PrStatus347 unsafe extern "C" fn agent_getsockopt(_fd: PrFd, opt: *mut prio::PRSocketOptionData) -> PrStatus {
348 let o = opt.as_mut().unwrap();
349 if o.option == prio::PRSockOption::PR_SockOpt_Nonblocking {
350 o.value.non_blocking = 1;
351 return PR_SUCCESS;
352 }
353 PR_FAILURE
354 }
355
/// NSPR I/O method table for a layered descriptor whose callbacks route I/O
/// through the `AgentIo` stored in the descriptor's `secret`.
///
/// Only close, read/recv, write/send, available/available64,
/// getsockname/getpeername and getsocketoption are implemented; every other
/// entry is `None`.
pub const METHODS: &prio::PRIOMethods = &prio::PRIOMethods {
    file_type: prio::PRDescType::PR_DESC_LAYERED,
    close: Some(agent_close),
    read: Some(agent_read),
    write: Some(agent_write),
    available: Some(agent_available),
    available64: Some(agent_available64),
    fsync: None,
    seek: None,
    seek64: None,
    fileInfo: None,
    fileInfo64: None,
    writev: None,
    connect: None,
    accept: None,
    bind: None,
    listen: None,
    shutdown: None,
    recv: Some(agent_recv),
    send: Some(agent_send),
    recvfrom: None,
    sendto: None,
    poll: None,
    acceptread: None,
    transmitfile: None,
    // Both address queries share one handler, so they report the same placeholder address.
    getsockname: Some(agent_getname),
    getpeername: Some(agent_getname),
    reserved_fn_6: None,
    reserved_fn_5: None,
    getsocketoption: Some(agent_getsockopt),
    setsocketoption: None,
    sendfile: None,
    connectcontinue: None,
    reserved_fn_3: None,
    reserved_fn_2: None,
    reserved_fn_1: None,
    reserved_fn_0: None,
};
394