1 //! HTTP Upgrades
2 //!
3 //! This module deals with managing [HTTP Upgrades][mdn] in hyper. Since
4 //! several concepts in HTTP allow for first talking HTTP, and then converting
5 //! to a different protocol, this module conflates them into a single API.
6 //! Those include:
7 //!
8 //! - HTTP/1.1 Upgrades
9 //! - HTTP `CONNECT`
10 //!
11 //! You are responsible for any other pre-requisites to establish an upgrade,
12 //! such as sending the appropriate headers, methods, and status codes. You can
13 //! then use [`on`][] to grab a `Future` which will resolve to the upgraded
14 //! connection object, or an error if the upgrade fails.
15 //!
16 //! [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Protocol_upgrade_mechanism
17 //!
18 //! # Client
19 //!
//! Sending an HTTP upgrade from the [`client`](super::client) involves setting
21 //! either the appropriate method, if wanting to `CONNECT`, or headers such as
22 //! `Upgrade` and `Connection`, on the `http::Request`. Once receiving the
23 //! `http::Response` back, you must check for the specific information that the
24 //! upgrade is agreed upon by the server (such as a `101` status code), and then
25 //! get the `Future` from the `Response`.
26 //!
27 //! # Server
28 //!
29 //! Receiving upgrade requests in a server requires you to check the relevant
30 //! headers in a `Request`, and if an upgrade should be done, you then send the
31 //! corresponding headers in a response. To then wait for hyper to finish the
32 //! upgrade, you call `on()` with the `Request`, and then can spawn a task
33 //! awaiting it.
34 //!
35 //! # Example
36 //!
37 //! See [this example][example] showing how upgrades work with both
38 //! Clients and Servers.
39 //!
40 //! [example]: https://github.com/hyperium/hyper/blob/master/examples/upgrades.rs
41 
42 use std::any::TypeId;
43 use std::error::Error as StdError;
44 use std::fmt;
45 use std::io;
46 use std::marker::Unpin;
47 
48 use bytes::Bytes;
49 use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
50 use tokio::sync::oneshot;
51 #[cfg(any(feature = "http1", feature = "http2"))]
52 use tracing::trace;
53 
54 use crate::common::io::Rewind;
55 use crate::common::{task, Future, Pin, Poll};
56 
/// An upgraded HTTP connection.
///
/// This type holds a trait object internally of the original IO that
/// was used to speak HTTP before the upgrade. It can be used directly
/// as an `AsyncRead` or `AsyncWrite` for convenience.
///
/// Alternatively, if the exact type is known, this can be deconstructed
/// into its parts with [`Upgraded::downcast`].
pub struct Upgraded {
    // Type-erased original IO, wrapped in `Rewind` together with any bytes
    // that were already read before the upgrade (see `Upgraded::new`).
    io: Rewind<Box<dyn Io + Send>>,
}
68 
/// A future for a possible HTTP upgrade.
///
/// If no upgrade was available, or it doesn't succeed, yields an `Error`.
pub struct OnUpgrade {
    // Receiving half of the channel that a `Pending` completes.
    // `None` means no upgrade was registered for the message; polling then
    // resolves immediately to a "no upgrade" user error.
    rx: Option<oneshot::Receiver<crate::Result<Upgraded>>>,
}
75 
/// The deconstructed parts of an [`Upgraded`](Upgraded) type.
///
/// Includes the original IO type, and a read buffer of bytes that the
/// HTTP state machine may have already read before completing an upgrade.
///
/// Produced by [`Upgraded::downcast`] when the requested type matches the
/// original IO type.
#[derive(Debug)]
pub struct Parts<T> {
    /// The original IO object used before the upgrade.
    pub io: T,
    /// A buffer of bytes that have been read but not processed as HTTP.
    ///
    /// For instance, if the `Connection` is used for an HTTP upgrade request,
    /// it is possible the server sent back the first bytes of the new protocol
    /// along with the response upgrade.
    ///
    /// You will want to check for any existing bytes if you plan to continue
    /// communicating on the IO object.
    pub read_buf: Bytes,
    // Private field: code outside this module cannot build a `Parts` with a
    // struct literal or destructure it exhaustively (presumably so fields can
    // be added later without a breaking change).
    _inner: (),
}
95 
96 /// Gets a pending HTTP upgrade from this message.
97 ///
98 /// This can be called on the following types:
99 ///
100 /// - `http::Request<B>`
101 /// - `http::Response<B>`
102 /// - `&mut http::Request<B>`
103 /// - `&mut http::Response<B>`
on<T: sealed::CanUpgrade>(msg: T) -> OnUpgrade104 pub fn on<T: sealed::CanUpgrade>(msg: T) -> OnUpgrade {
105     msg.on_upgrade()
106 }
107 
#[cfg(any(feature = "http1", feature = "http2"))]
/// Sending half of a pending upgrade, created together with its `OnUpgrade`
/// by `pending()`.
///
/// Consumed by `fulfill` (delivers the upgraded IO) or `manual` (delivers a
/// "manual upgrade" error). If it is dropped without either, the paired
/// `OnUpgrade` resolves to a canceled error.
pub(super) struct Pending {
    tx: oneshot::Sender<crate::Result<Upgraded>>,
}
112 
/// Creates a connected (`Pending`, `OnUpgrade`) pair.
///
/// Whatever the `Pending` half sends (or its being dropped) is what the
/// `OnUpgrade` half resolves to.
#[cfg(any(feature = "http1", feature = "http2"))]
pub(super) fn pending() -> (Pending, OnUpgrade) {
    let (sender, receiver) = oneshot::channel();
    let pending = Pending { tx: sender };
    let on_upgrade = OnUpgrade { rx: Some(receiver) };
    (pending, on_upgrade)
}
118 
119 // ===== impl Upgraded =====
120 
121 impl Upgraded {
122     #[cfg(any(feature = "http1", feature = "http2", test))]
new<T>(io: T, read_buf: Bytes) -> Self where T: AsyncRead + AsyncWrite + Unpin + Send + 'static,123     pub(super) fn new<T>(io: T, read_buf: Bytes) -> Self
124     where
125         T: AsyncRead + AsyncWrite + Unpin + Send + 'static,
126     {
127         Upgraded {
128             io: Rewind::new_buffered(Box::new(io), read_buf),
129         }
130     }
131 
132     /// Tries to downcast the internal trait object to the type passed.
133     ///
134     /// On success, returns the downcasted parts. On error, returns the
135     /// `Upgraded` back.
downcast<T: AsyncRead + AsyncWrite + Unpin + 'static>(self) -> Result<Parts<T>, Self>136     pub fn downcast<T: AsyncRead + AsyncWrite + Unpin + 'static>(self) -> Result<Parts<T>, Self> {
137         let (io, buf) = self.io.into_inner();
138         match io.__hyper_downcast() {
139             Ok(t) => Ok(Parts {
140                 io: *t,
141                 read_buf: buf,
142                 _inner: (),
143             }),
144             Err(io) => Err(Upgraded {
145                 io: Rewind::new_buffered(io, buf),
146             }),
147         }
148     }
149 }
150 
151 impl AsyncRead for Upgraded {
poll_read( mut self: Pin<&mut Self>, cx: &mut task::Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll<io::Result<()>>152     fn poll_read(
153         mut self: Pin<&mut Self>,
154         cx: &mut task::Context<'_>,
155         buf: &mut ReadBuf<'_>,
156     ) -> Poll<io::Result<()>> {
157         Pin::new(&mut self.io).poll_read(cx, buf)
158     }
159 }
160 
161 impl AsyncWrite for Upgraded {
poll_write( mut self: Pin<&mut Self>, cx: &mut task::Context<'_>, buf: &[u8], ) -> Poll<io::Result<usize>>162     fn poll_write(
163         mut self: Pin<&mut Self>,
164         cx: &mut task::Context<'_>,
165         buf: &[u8],
166     ) -> Poll<io::Result<usize>> {
167         Pin::new(&mut self.io).poll_write(cx, buf)
168     }
169 
poll_write_vectored( mut self: Pin<&mut Self>, cx: &mut task::Context<'_>, bufs: &[io::IoSlice<'_>], ) -> Poll<io::Result<usize>>170     fn poll_write_vectored(
171         mut self: Pin<&mut Self>,
172         cx: &mut task::Context<'_>,
173         bufs: &[io::IoSlice<'_>],
174     ) -> Poll<io::Result<usize>> {
175         Pin::new(&mut self.io).poll_write_vectored(cx, bufs)
176     }
177 
poll_flush(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>>178     fn poll_flush(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
179         Pin::new(&mut self.io).poll_flush(cx)
180     }
181 
poll_shutdown(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>>182     fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
183         Pin::new(&mut self.io).poll_shutdown(cx)
184     }
185 
is_write_vectored(&self) -> bool186     fn is_write_vectored(&self) -> bool {
187         self.io.is_write_vectored()
188     }
189 }
190 
191 impl fmt::Debug for Upgraded {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result192     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
193         f.debug_struct("Upgraded").finish()
194     }
195 }
196 
197 // ===== impl OnUpgrade =====
198 
199 impl OnUpgrade {
none() -> Self200     pub(super) fn none() -> Self {
201         OnUpgrade { rx: None }
202     }
203 
204     #[cfg(feature = "http1")]
is_none(&self) -> bool205     pub(super) fn is_none(&self) -> bool {
206         self.rx.is_none()
207     }
208 }
209 
210 impl Future for OnUpgrade {
211     type Output = Result<Upgraded, crate::Error>;
212 
poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output>213     fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
214         match self.rx {
215             Some(ref mut rx) => Pin::new(rx).poll(cx).map(|res| match res {
216                 Ok(Ok(upgraded)) => Ok(upgraded),
217                 Ok(Err(err)) => Err(err),
218                 Err(_oneshot_canceled) => Err(crate::Error::new_canceled().with(UpgradeExpected)),
219             }),
220             None => Poll::Ready(Err(crate::Error::new_user_no_upgrade())),
221         }
222     }
223 }
224 
225 impl fmt::Debug for OnUpgrade {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result226     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
227         f.debug_struct("OnUpgrade").finish()
228     }
229 }
230 
231 // ===== impl Pending =====
232 
#[cfg(any(feature = "http1", feature = "http2"))]
impl Pending {
    /// Completes the upgrade by handing `upgraded` to the paired `OnUpgrade`.
    pub(super) fn fulfill(self, upgraded: Upgraded) {
        trace!("pending upgrade fulfill");
        let Pending { tx } = self;
        // A failed send only means the `OnUpgrade` was dropped and nobody is
        // waiting; the upgraded IO is simply dropped in that case.
        let _ = tx.send(Ok(upgraded));
    }

    #[cfg(feature = "http1")]
    /// Don't fulfill the pending Upgrade, but instead signal that
    /// upgrades are handled manually.
    pub(super) fn manual(self) {
        trace!("pending upgrade handled manually");
        let Pending { tx } = self;
        // As above, a missing receiver is not an error worth reporting.
        let _ = tx.send(Err(crate::Error::new_user_manual_upgrade()));
    }
}
248 
249 // ===== impl UpgradeExpected =====
250 
/// Error cause attached when an upgrade was expected, but the sending side
/// went away before completing it.
///
/// This likely means the actual `Conn` future wasn't polled and upgraded.
#[derive(Debug)]
struct UpgradeExpected;

impl fmt::Display for UpgradeExpected {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "upgrade expected but not completed")
    }
}

/// Leaf cause: it wraps no further `source()`.
impl StdError for UpgradeExpected {}
265 
266 // ===== impl Io =====
267 
/// Object-safe IO bound used for the type-erased transport inside `Upgraded`.
///
/// `__hyper_type_id` is dispatched through the trait object's vtable, so it
/// reports the concrete type behind a `dyn Io`. `Upgraded::downcast` relies on
/// this to recover the original type.
pub(super) trait Io: AsyncRead + AsyncWrite + Unpin + 'static {
    // Default body. The blanket impl below does not override it, so for those
    // implementors this is `TypeId::of::<T>()` of the concrete type.
    fn __hyper_type_id(&self) -> TypeId {
        TypeId::of::<Self>()
    }
}

// Blanket impl: every sized type meeting the supertrait bounds is an `Io`.
impl<T: AsyncRead + AsyncWrite + Unpin + 'static> Io for T {}
275 
impl dyn Io + Send {
    /// Returns `true` if the concrete type behind this trait object is `T`.
    fn __hyper_is<T: Io>(&self) -> bool {
        let t = TypeId::of::<T>();
        self.__hyper_type_id() == t
    }

    /// Converts the boxed trait object into `Box<T>` if its concrete type is
    /// `T`; otherwise returns the original box unchanged in `Err`.
    fn __hyper_downcast<T: Io>(self: Box<Self>) -> Result<Box<T>, Box<Self>> {
        if self.__hyper_is::<T>() {
            // Taken from `std::error::Error::downcast()`.
            // SAFETY: `__hyper_is` just confirmed that the erased value's
            // `TypeId` equals `T`'s, so the allocation really holds a `T`.
            // Casting the fat `*mut dyn Io` to `*mut T` discards the vtable and
            // keeps the data pointer. `Box::from_raw` then re-takes ownership of
            // exactly the allocation that `Box::into_raw` released.
            unsafe {
                let raw: *mut dyn Io = Box::into_raw(self);
                Ok(Box::from_raw(raw as *mut T))
            }
        } else {
            Err(self)
        }
    }
}
294 
295 mod sealed {
296     use super::OnUpgrade;
297 
298     pub trait CanUpgrade {
on_upgrade(self) -> OnUpgrade299         fn on_upgrade(self) -> OnUpgrade;
300     }
301 
302     impl<B> CanUpgrade for http::Request<B> {
on_upgrade(mut self) -> OnUpgrade303         fn on_upgrade(mut self) -> OnUpgrade {
304             self.extensions_mut()
305                 .remove::<OnUpgrade>()
306                 .unwrap_or_else(OnUpgrade::none)
307         }
308     }
309 
310     impl<B> CanUpgrade for &'_ mut http::Request<B> {
on_upgrade(self) -> OnUpgrade311         fn on_upgrade(self) -> OnUpgrade {
312             self.extensions_mut()
313                 .remove::<OnUpgrade>()
314                 .unwrap_or_else(OnUpgrade::none)
315         }
316     }
317 
318     impl<B> CanUpgrade for http::Response<B> {
on_upgrade(mut self) -> OnUpgrade319         fn on_upgrade(mut self) -> OnUpgrade {
320             self.extensions_mut()
321                 .remove::<OnUpgrade>()
322                 .unwrap_or_else(OnUpgrade::none)
323         }
324     }
325 
326     impl<B> CanUpgrade for &'_ mut http::Response<B> {
on_upgrade(self) -> OnUpgrade327         fn on_upgrade(self) -> OnUpgrade {
328             self.extensions_mut()
329                 .remove::<OnUpgrade>()
330                 .unwrap_or_else(OnUpgrade::none)
331         }
332     }
333 }
334 
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn upgraded_downcast() {
        let upgraded = Upgraded::new(Mock, Bytes::new());

        let upgraded = upgraded.downcast::<std::io::Cursor<Vec<u8>>>().unwrap_err();

        upgraded.downcast::<Mock>().unwrap();
    }

    #[test]
    fn upgraded_downcast_keeps_read_buf() {
        let upgraded = Upgraded::new(Mock, Bytes::from_static(b"early bytes"));

        // A failed downcast must hand back an `Upgraded` that still carries the
        // bytes read before the upgrade, not silently lose them.
        let upgraded = upgraded
            .downcast::<std::io::Cursor<Vec<u8>>>()
            .unwrap_err();

        let parts = upgraded.downcast::<Mock>().unwrap();
        assert_eq!(parts.read_buf, &b"early bytes"[..]);
    }

    // TODO: replace with tokio_test::io when it can test write_buf
    /// Minimal IO type for downcast tests; no test here drives actual IO.
    struct Mock;

    impl AsyncRead for Mock {
        fn poll_read(
            self: Pin<&mut Self>,
            _cx: &mut task::Context<'_>,
            _buf: &mut ReadBuf<'_>,
        ) -> Poll<io::Result<()>> {
            unreachable!("Mock::poll_read")
        }
    }

    impl AsyncWrite for Mock {
        fn poll_write(
            self: Pin<&mut Self>,
            _: &mut task::Context<'_>,
            buf: &[u8],
        ) -> Poll<io::Result<usize>> {
            // Accept (and discard) every write in full.
            Poll::Ready(Ok(buf.len()))
        }

        fn poll_flush(self: Pin<&mut Self>, _cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
            unreachable!("Mock::poll_flush")
        }

        fn poll_shutdown(
            self: Pin<&mut Self>,
            _cx: &mut task::Context<'_>,
        ) -> Poll<io::Result<()>> {
            unreachable!("Mock::poll_shutdown")
        }
    }
}
383