1 // Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
2 // http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
3 // <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
4 // option. This file may not be copied, modified, or distributed
5 // except according to those terms.
6 
7 #![allow(clippy::module_name_repetitions)]
8 
9 use super::{ExtendedConnectEvents, ExtendedConnectType, SessionCloseReason};
10 use crate::{
11     frames::{FrameReader, StreamReaderRecvStreamWrapper, WebTransportFrame},
12     recv_message::{RecvMessage, RecvMessageInfo},
13     send_message::SendMessage,
14     CloseType, Error, HFrame, Http3StreamInfo, Http3StreamType, HttpRecvStream,
15     HttpRecvStreamEvents, Priority, PriorityHandler, ReceiveOutput, RecvStream, RecvStreamEvents,
16     Res, SendStream, SendStreamEvents, Stream,
17 };
18 use neqo_common::{qtrace, Encoder, Header, MessageType, Role};
19 use neqo_qpack::{QPackDecoder, QPackEncoder};
20 use neqo_transport::{Connection, StreamId};
21 use std::any::Any;
22 use std::cell::RefCell;
23 use std::collections::BTreeSet;
24 use std::mem;
25 use std::rc::Rc;
26 
/// Lifecycle of a WebTransport session as tracked on its control stream.
#[derive(Debug, PartialEq)]
enum SessionState {
    /// Waiting for the response headers to the extended CONNECT request.
    Negotiating,
    /// The session is established and may carry sub-streams.
    Active,
    /// Closing has started, but the control stream is not finished yet.
    FinPending,
    /// The session is over.
    Done,
}

impl SessionState {
    /// Returns `true` for the states reached once closing has begun
    /// (`FinPending` and `Done`), `false` otherwise.
    pub fn closing_state(&self) -> bool {
        match self {
            Self::Negotiating | Self::Active => false,
            Self::FinPending | Self::Done => true,
        }
    }
}
40 
/// One WebTransport session. Its control stream is the extended CONNECT stream whose id is
/// `session_id`.
#[derive(Debug)]
pub struct WebTransportSession {
    /// Receive half of the control stream.
    control_stream_recv: Box<dyn RecvStream>,
    /// Send half of the control stream.
    control_stream_send: Box<dyn SendStream>,
    /// Listener shared with both control-stream halves; it collects the response headers
    /// that `maybe_check_headers` consumes.
    stream_event_listener: Rc<RefCell<WebTransportSessionListener>>,
    /// Id of the extended CONNECT stream; also identifies the session in reported events.
    session_id: StreamId,
    /// Current lifecycle state of the session.
    state: SessionState,
    /// Decodes `WebTransportFrame`s (e.g. `CloseSession`) read from the control stream.
    frame_reader: FrameReader,
    /// Receives session start/end and new sub-stream notifications.
    events: Box<dyn ExtendedConnectEvents>,
    /// Sub-streams with a send side: bidirectional ones and locally initiated
    /// unidirectional ones.
    send_streams: BTreeSet<StreamId>,
    /// Sub-streams with a receive side: bidirectional ones and peer-initiated
    /// unidirectional ones.
    recv_streams: BTreeSet<StreamId>,
    /// Local endpoint role; used to tell locally initiated streams from peer-initiated ones.
    role: Role,
}
54 
55 impl ::std::fmt::Display for WebTransportSession {
fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result56     fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
57         write!(f, "WebTransportSession session={}", self.session_id,)
58     }
59 }
60 
61 impl WebTransportSession {
62     #[must_use]
new( session_id: StreamId, events: Box<dyn ExtendedConnectEvents>, role: Role, qpack_encoder: Rc<RefCell<QPackEncoder>>, qpack_decoder: Rc<RefCell<QPackDecoder>>, ) -> Self63     pub fn new(
64         session_id: StreamId,
65         events: Box<dyn ExtendedConnectEvents>,
66         role: Role,
67         qpack_encoder: Rc<RefCell<QPackEncoder>>,
68         qpack_decoder: Rc<RefCell<QPackDecoder>>,
69     ) -> Self {
70         let stream_event_listener = Rc::new(RefCell::new(WebTransportSessionListener::default()));
71         Self {
72             control_stream_recv: Box::new(RecvMessage::new(
73                 &RecvMessageInfo {
74                     message_type: MessageType::Response,
75                     stream_type: Http3StreamType::ExtendedConnect,
76                     stream_id: session_id,
77                     header_frame_type_read: false,
78                 },
79                 qpack_decoder,
80                 Box::new(stream_event_listener.clone()),
81                 None,
82                 PriorityHandler::new(false, Priority::default()),
83             )),
84             control_stream_send: Box::new(SendMessage::new(
85                 MessageType::Request,
86                 Http3StreamType::ExtendedConnect,
87                 session_id,
88                 qpack_encoder,
89                 Box::new(stream_event_listener.clone()),
90             )),
91             stream_event_listener,
92             session_id,
93             state: SessionState::Negotiating,
94             frame_reader: FrameReader::new(),
95             events,
96             send_streams: BTreeSet::new(),
97             recv_streams: BTreeSet::new(),
98             role,
99         }
100     }
101 
102     /// # Panics
103     /// This function is only called with `RecvStream` and `SendStream` that also implement
104     /// the http specific functions and `http_stream()` will never return `None`.
105     #[must_use]
new_with_http_streams( session_id: StreamId, events: Box<dyn ExtendedConnectEvents>, role: Role, mut control_stream_recv: Box<dyn RecvStream>, mut control_stream_send: Box<dyn SendStream>, ) -> Self106     pub fn new_with_http_streams(
107         session_id: StreamId,
108         events: Box<dyn ExtendedConnectEvents>,
109         role: Role,
110         mut control_stream_recv: Box<dyn RecvStream>,
111         mut control_stream_send: Box<dyn SendStream>,
112     ) -> Self {
113         let stream_event_listener = Rc::new(RefCell::new(WebTransportSessionListener::default()));
114         control_stream_recv
115             .http_stream()
116             .unwrap()
117             .set_new_listener(Box::new(stream_event_listener.clone()));
118         control_stream_send
119             .http_stream()
120             .unwrap()
121             .set_new_listener(Box::new(stream_event_listener.clone()));
122         Self {
123             control_stream_recv,
124             control_stream_send,
125             stream_event_listener,
126             session_id,
127             state: SessionState::Active,
128             frame_reader: FrameReader::new(),
129             events,
130             send_streams: BTreeSet::new(),
131             recv_streams: BTreeSet::new(),
132             role,
133         }
134     }
135 
136     /// # Errors
137     /// The function can only fail if supplied headers are not valid http headers.
138     /// # Panics
139     /// `control_stream_send` implements the  http specific functions and `http_stream()`
140     /// will never return `None`.
send_request(&mut self, headers: &[Header], conn: &mut Connection) -> Res<()>141     pub fn send_request(&mut self, headers: &[Header], conn: &mut Connection) -> Res<()> {
142         self.control_stream_send
143             .http_stream()
144             .unwrap()
145             .send_headers(headers, conn)
146     }
147 
receive(&mut self, conn: &mut Connection) -> Res<(ReceiveOutput, bool)>148     fn receive(&mut self, conn: &mut Connection) -> Res<(ReceiveOutput, bool)> {
149         qtrace!([self], "receive control data");
150         let (out, _) = self.control_stream_recv.receive(conn)?;
151         debug_assert!(out == ReceiveOutput::NoOutput);
152         self.maybe_check_headers();
153         self.read_control_stream(conn)?;
154         Ok((ReceiveOutput::NoOutput, self.state == SessionState::Done))
155     }
156 
header_unblocked(&mut self, conn: &mut Connection) -> Res<(ReceiveOutput, bool)>157     fn header_unblocked(&mut self, conn: &mut Connection) -> Res<(ReceiveOutput, bool)> {
158         let (out, _) = self
159             .control_stream_recv
160             .http_stream()
161             .unwrap()
162             .header_unblocked(conn)?;
163         debug_assert!(out == ReceiveOutput::NoOutput);
164         self.maybe_check_headers();
165         self.read_control_stream(conn)?;
166         Ok((ReceiveOutput::NoOutput, self.state == SessionState::Done))
167     }
168 
maybe_update_priority(&mut self, priority: Priority) -> bool169     fn maybe_update_priority(&mut self, priority: Priority) -> bool {
170         self.control_stream_recv
171             .http_stream()
172             .unwrap()
173             .maybe_update_priority(priority)
174     }
175 
priority_update_frame(&mut self) -> Option<HFrame>176     fn priority_update_frame(&mut self) -> Option<HFrame> {
177         self.control_stream_recv
178             .http_stream()
179             .unwrap()
180             .priority_update_frame()
181     }
182 
priority_update_sent(&mut self)183     fn priority_update_sent(&mut self) {
184         self.control_stream_recv
185             .http_stream()
186             .unwrap()
187             .priority_update_sent();
188     }
189 
send(&mut self, conn: &mut Connection) -> Res<()>190     fn send(&mut self, conn: &mut Connection) -> Res<()> {
191         self.control_stream_send.send(conn)?;
192         if self.control_stream_send.done() {
193             self.state = SessionState::Done;
194         }
195         Ok(())
196     }
197 
has_data_to_send(&self) -> bool198     fn has_data_to_send(&self) -> bool {
199         self.control_stream_send.has_data_to_send()
200     }
201 
done(&self) -> bool202     fn done(&self) -> bool {
203         self.state == SessionState::Done
204     }
205 
close(&mut self, close_type: CloseType)206     fn close(&mut self, close_type: CloseType) {
207         if self.state.closing_state() {
208             return;
209         }
210         qtrace!("ExtendedConnect close the session");
211         self.state = SessionState::Done;
212         if let CloseType::ResetApp(_) = close_type {
213             return;
214         }
215         self.events.session_end(
216             ExtendedConnectType::WebTransport,
217             self.session_id,
218             SessionCloseReason::from(close_type),
219         );
220     }
221 
222     /// # Panics
223     /// This cannot panic because headers are checked before this function called.
maybe_check_headers(&mut self)224     pub fn maybe_check_headers(&mut self) {
225         if SessionState::Negotiating != self.state {
226             return;
227         }
228 
229         if let Some((headers, interim, fin)) = self.stream_event_listener.borrow_mut().get_headers()
230         {
231             qtrace!(
232                 "ExtendedConnect response headers {:?}, fin={}",
233                 headers,
234                 fin
235             );
236 
237             if interim {
238                 if fin {
239                     self.events.session_end(
240                         ExtendedConnectType::WebTransport,
241                         self.session_id,
242                         SessionCloseReason::Clean {
243                             error: 0,
244                             message: "".to_string(),
245                         },
246                     );
247                     self.state = SessionState::Done;
248                 }
249             } else {
250                 let status = headers
251                     .iter()
252                     .find_map(|h| {
253                         if h.name() == ":status" {
254                             h.value().parse::<u16>().ok()
255                         } else {
256                             None
257                         }
258                     })
259                     .unwrap();
260 
261                 self.state = if (200..300).contains(&status) {
262                     if fin {
263                         self.events.session_end(
264                             ExtendedConnectType::WebTransport,
265                             self.session_id,
266                             SessionCloseReason::Clean {
267                                 error: 0,
268                                 message: "".to_string(),
269                             },
270                         );
271                         SessionState::Done
272                     } else {
273                         self.events.session_start(
274                             ExtendedConnectType::WebTransport,
275                             self.session_id,
276                             status,
277                         );
278                         SessionState::Active
279                     }
280                 } else {
281                     self.events.session_end(
282                         ExtendedConnectType::WebTransport,
283                         self.session_id,
284                         SessionCloseReason::Status(status),
285                     );
286                     SessionState::Done
287                 };
288             }
289         }
290     }
291 
add_stream(&mut self, stream_id: StreamId)292     pub fn add_stream(&mut self, stream_id: StreamId) {
293         if let SessionState::Active = self.state {
294             if stream_id.is_bidi() {
295                 self.send_streams.insert(stream_id);
296                 self.recv_streams.insert(stream_id);
297             } else if stream_id.is_self_initiated(self.role) {
298                 self.send_streams.insert(stream_id);
299             } else {
300                 self.recv_streams.insert(stream_id);
301             }
302 
303             if !stream_id.is_self_initiated(self.role) {
304                 self.events
305                     .extended_connect_new_stream(Http3StreamInfo::new(
306                         stream_id,
307                         ExtendedConnectType::WebTransport.get_stream_type(self.session_id),
308                     ));
309             }
310         }
311     }
312 
remove_recv_stream(&mut self, stream_id: StreamId)313     pub fn remove_recv_stream(&mut self, stream_id: StreamId) {
314         self.recv_streams.remove(&stream_id);
315     }
316 
remove_send_stream(&mut self, stream_id: StreamId)317     pub fn remove_send_stream(&mut self, stream_id: StreamId) {
318         self.send_streams.remove(&stream_id);
319     }
320 
321     #[must_use]
is_active(&self) -> bool322     pub fn is_active(&self) -> bool {
323         matches!(self.state, SessionState::Active)
324     }
325 
take_sub_streams(&mut self) -> Option<(BTreeSet<StreamId>, BTreeSet<StreamId>)>326     pub fn take_sub_streams(&mut self) -> Option<(BTreeSet<StreamId>, BTreeSet<StreamId>)> {
327         Some((
328             mem::take(&mut self.recv_streams),
329             mem::take(&mut self.send_streams),
330         ))
331     }
332 
333     /// # Errors
334     /// It may return an error if the frame is not correctly decoded.
read_control_stream(&mut self, conn: &mut Connection) -> Res<()>335     pub fn read_control_stream(&mut self, conn: &mut Connection) -> Res<()> {
336         let (f, fin) = self
337             .frame_reader
338             .receive::<WebTransportFrame>(&mut StreamReaderRecvStreamWrapper::new(
339                 conn,
340                 &mut self.control_stream_recv,
341             ))
342             .map_err(|_| Error::HttpGeneralProtocolStream)?;
343         qtrace!([self], "Received frame: {:?} fin={}", f, fin);
344         if let Some(WebTransportFrame::CloseSession { error, message }) = f {
345             self.events.session_end(
346                 ExtendedConnectType::WebTransport,
347                 self.session_id,
348                 SessionCloseReason::Clean { error, message },
349             );
350             self.state = if fin {
351                 SessionState::Done
352             } else {
353                 SessionState::FinPending
354             };
355         } else if fin {
356             self.events.session_end(
357                 ExtendedConnectType::WebTransport,
358                 self.session_id,
359                 SessionCloseReason::Clean {
360                     error: 0,
361                     message: "".to_string(),
362                 },
363             );
364             self.state = SessionState::Done;
365         }
366         Ok(())
367     }
368 
369     /// # Errors
370     /// Return an error if the stream was closed on the transport layer, but that information is not yet
371     /// consumed on the http/3 layer.
close_session(&mut self, conn: &mut Connection, error: u32, message: &str) -> Res<()>372     pub fn close_session(&mut self, conn: &mut Connection, error: u32, message: &str) -> Res<()> {
373         self.state = SessionState::Done;
374         let close_frame = WebTransportFrame::CloseSession {
375             error,
376             message: message.to_string(),
377         };
378         let mut encoder = Encoder::default();
379         close_frame.encode(&mut encoder);
380         self.control_stream_send.send_data_atomic(conn, &encoder)?;
381         self.control_stream_send.close(conn)?;
382         self.state = if self.control_stream_send.done() {
383             SessionState::Done
384         } else {
385             SessionState::FinPending
386         };
387         Ok(())
388     }
389 
send_data(&mut self, conn: &mut Connection, buf: &[u8]) -> Res<usize>390     fn send_data(&mut self, conn: &mut Connection, buf: &[u8]) -> Res<usize> {
391         self.control_stream_send.send_data(conn, buf)
392     }
393 }
394 
impl Stream for Rc<RefCell<WebTransportSession>> {
    /// A session is always reported as an extended CONNECT stream.
    fn stream_type(&self) -> Http3StreamType {
        Http3StreamType::ExtendedConnect
    }
}
400 
401 impl RecvStream for Rc<RefCell<WebTransportSession>> {
receive(&mut self, conn: &mut Connection) -> Res<(ReceiveOutput, bool)>402     fn receive(&mut self, conn: &mut Connection) -> Res<(ReceiveOutput, bool)> {
403         self.borrow_mut().receive(conn)
404     }
405 
reset(&mut self, close_type: CloseType) -> Res<()>406     fn reset(&mut self, close_type: CloseType) -> Res<()> {
407         self.borrow_mut().close(close_type);
408         Ok(())
409     }
410 
http_stream(&mut self) -> Option<&mut dyn HttpRecvStream>411     fn http_stream(&mut self) -> Option<&mut dyn HttpRecvStream> {
412         Some(self)
413     }
414 
webtransport(&self) -> Option<Rc<RefCell<WebTransportSession>>>415     fn webtransport(&self) -> Option<Rc<RefCell<WebTransportSession>>> {
416         Some(self.clone())
417     }
418 }
419 
420 impl HttpRecvStream for Rc<RefCell<WebTransportSession>> {
header_unblocked(&mut self, conn: &mut Connection) -> Res<(ReceiveOutput, bool)>421     fn header_unblocked(&mut self, conn: &mut Connection) -> Res<(ReceiveOutput, bool)> {
422         self.borrow_mut().header_unblocked(conn)
423     }
424 
maybe_update_priority(&mut self, priority: Priority) -> bool425     fn maybe_update_priority(&mut self, priority: Priority) -> bool {
426         self.borrow_mut().maybe_update_priority(priority)
427     }
428 
priority_update_frame(&mut self) -> Option<HFrame>429     fn priority_update_frame(&mut self) -> Option<HFrame> {
430         self.borrow_mut().priority_update_frame()
431     }
432 
priority_update_sent(&mut self)433     fn priority_update_sent(&mut self) {
434         self.borrow_mut().priority_update_sent();
435     }
436 
any(&self) -> &dyn Any437     fn any(&self) -> &dyn Any {
438         self
439     }
440 }
441 
442 impl SendStream for Rc<RefCell<WebTransportSession>> {
send(&mut self, conn: &mut Connection) -> Res<()>443     fn send(&mut self, conn: &mut Connection) -> Res<()> {
444         self.borrow_mut().send(conn)
445     }
446 
send_data(&mut self, conn: &mut Connection, buf: &[u8]) -> Res<usize>447     fn send_data(&mut self, conn: &mut Connection, buf: &[u8]) -> Res<usize> {
448         self.borrow_mut().send_data(conn, buf)
449     }
450 
has_data_to_send(&self) -> bool451     fn has_data_to_send(&self) -> bool {
452         self.borrow_mut().has_data_to_send()
453     }
454 
stream_writable(&self)455     fn stream_writable(&self) {}
456 
done(&self) -> bool457     fn done(&self) -> bool {
458         self.borrow_mut().done()
459     }
460 
close(&mut self, conn: &mut Connection) -> Res<()>461     fn close(&mut self, conn: &mut Connection) -> Res<()> {
462         self.borrow_mut().close_session(conn, 0, "")
463     }
464 
close_with_message(&mut self, conn: &mut Connection, error: u32, message: &str) -> Res<()>465     fn close_with_message(&mut self, conn: &mut Connection, error: u32, message: &str) -> Res<()> {
466         self.borrow_mut().close_session(conn, error, message)
467     }
468 
handle_stop_sending(&mut self, close_type: CloseType)469     fn handle_stop_sending(&mut self, close_type: CloseType) {
470         self.borrow_mut().close(close_type);
471     }
472 }
473 
474 #[derive(Debug, Default)]
475 struct WebTransportSessionListener {
476     headers: Option<(Vec<Header>, bool, bool)>,
477 }
478 
479 impl WebTransportSessionListener {
set_headers(&mut self, headers: Vec<Header>, interim: bool, fin: bool)480     fn set_headers(&mut self, headers: Vec<Header>, interim: bool, fin: bool) {
481         self.headers = Some((headers, interim, fin));
482     }
483 
get_headers(&mut self) -> Option<(Vec<Header>, bool, bool)>484     pub fn get_headers(&mut self) -> Option<(Vec<Header>, bool, bool)> {
485         mem::take(&mut self.headers)
486     }
487 }
488 
// No receive-stream events are handled here; the trait's provided methods are used.
impl RecvStreamEvents for Rc<RefCell<WebTransportSessionListener>> {}
490 
491 impl HttpRecvStreamEvents for Rc<RefCell<WebTransportSessionListener>> {
header_ready( &self, _stream_info: Http3StreamInfo, headers: Vec<Header>, interim: bool, fin: bool, )492     fn header_ready(
493         &self,
494         _stream_info: Http3StreamInfo,
495         headers: Vec<Header>,
496         interim: bool,
497         fin: bool,
498     ) {
499         if !interim || fin {
500             self.borrow_mut().set_headers(headers, interim, fin);
501         }
502     }
503 }
504 
// The send half of the control stream also reports to this listener (see
// `WebTransportSession::new`), but no send events are handled; provided methods are used.
impl SendStreamEvents for Rc<RefCell<WebTransportSessionListener>> {}
506