1 //! Async WebSockets.
2 //!
//! This crate is based on the [tungstenite](https://crates.io/crates/tungstenite)
//! Rust WebSocket library and provides async bindings and wrappers for it, so you
//! can use it with non-blocking/asynchronous `TcpStream`s and combine it with
//! other crates from the async stack. In addition, optional integration with
//! various other crates can be enabled via feature flags:
8 //!
//!  * `async-tls`: Enables the `async_tls` module, which provides integration
//!    with the [async-tls](https://crates.io/crates/async-tls) TLS stack and can
//!    be used independently of any async runtime.
12 //!  * `async-std-runtime`: Enables the `async_std` module, which provides
13 //!    integration with the [async-std](https://async.rs) runtime.
14 //!  * `async-native-tls`: Enables the additional functions in the `async_std`
15 //!    module to implement TLS via
16 //!    [async-native-tls](https://crates.io/crates/async-native-tls).
17 //!  * `tokio-runtime`: Enables the `tokio` module, which provides integration
18 //!    with the [tokio](https://tokio.rs) runtime.
19 //!  * `tokio-native-tls`: Enables the additional functions in the `tokio` module to
20 //!    implement TLS via [tokio-native-tls](https://crates.io/crates/tokio-native-tls).
21 //!  * `tokio-rustls-native-certs`: Enables the additional functions in the `tokio`
22 //!    module to implement TLS via [tokio-rustls](https://crates.io/crates/tokio-rustls)
23 //!    and uses native system certificates found with
24 //!    [rustls-native-certs](https://github.com/rustls/rustls-native-certs).
//!  * `tokio-rustls-webpki-roots`: Enables the additional functions in the `tokio`
//!    module to implement TLS via [tokio-rustls](https://crates.io/crates/tokio-rustls)
//!    and uses the root certificates provided by
//!    [webpki-roots](https://github.com/rustls/webpki-roots).
29 //!  * `tokio-openssl`: Enables the additional functions in the `tokio` module to
30 //!    implement TLS via [tokio-openssl](https://crates.io/crates/tokio-openssl).
31 //!  * `gio-runtime`: Enables the `gio` module, which provides integration with
32 //!    the [gio](https://www.gtk-rs.org) runtime.
33 //!
34 //! Each WebSocket stream implements the required `Stream` and `Sink` traits,
35 //! making the socket a stream of WebSocket messages coming in and going out.
36 
37 #![deny(
38     missing_docs,
39     unused_must_use,
40     unused_mut,
41     unused_imports,
42     unused_import_braces
43 )]
44 
45 pub use tungstenite;
46 
47 mod compat;
48 mod handshake;
49 
50 #[cfg(any(
51     feature = "async-tls",
52     feature = "async-native-tls",
53     feature = "tokio-native-tls",
54     feature = "tokio-rustls-native-certs",
55     feature = "tokio-rustls-webpki-roots",
56     feature = "tokio-openssl",
57 ))]
58 pub mod stream;
59 
60 use std::io::{Read, Write};
61 
62 use compat::{cvt, AllowStd, ContextWaker};
63 use futures_io::{AsyncRead, AsyncWrite};
64 use futures_util::{
65     sink::{Sink, SinkExt},
66     stream::Stream,
67 };
68 use log::*;
69 use std::pin::Pin;
70 use std::task::{Context, Poll};
71 
72 use tungstenite::{
73     client::IntoClientRequest,
74     error::Error as WsError,
75     handshake::{
76         client::{ClientHandshake, Response},
77         server::{Callback, NoCallback},
78         HandshakeError,
79     },
80     protocol::{Message, Role, WebSocket, WebSocketConfig},
81     server,
82 };
83 
84 #[cfg(feature = "async-std-runtime")]
85 pub mod async_std;
86 #[cfg(feature = "async-tls")]
87 pub mod async_tls;
88 #[cfg(feature = "gio-runtime")]
89 pub mod gio;
90 #[cfg(feature = "tokio-runtime")]
91 pub mod tokio;
92 
93 use tungstenite::protocol::CloseFrame;
94 
95 /// Creates a WebSocket handshake from a request and a stream.
96 /// For convenience, the user may call this with a url string, a URL,
97 /// or a `Request`. Calling with `Request` allows the user to add
98 /// a WebSocket protocol or other custom headers.
99 ///
100 /// Internally, this custom creates a handshake representation and returns
101 /// a future representing the resolution of the WebSocket handshake. The
102 /// returned future will resolve to either `WebSocketStream<S>` or `Error`
103 /// depending on whether the handshake is successful.
104 ///
105 /// This is typically used for clients who have already established, for
106 /// example, a TCP connection to the remote server.
client_async<'a, R, S>( request: R, stream: S, ) -> Result<(WebSocketStream<S>, Response), WsError> where R: IntoClientRequest + Unpin, S: AsyncRead + AsyncWrite + Unpin,107 pub async fn client_async<'a, R, S>(
108     request: R,
109     stream: S,
110 ) -> Result<(WebSocketStream<S>, Response), WsError>
111 where
112     R: IntoClientRequest + Unpin,
113     S: AsyncRead + AsyncWrite + Unpin,
114 {
115     client_async_with_config(request, stream, None).await
116 }
117 
118 /// The same as `client_async()` but the one can specify a websocket configuration.
119 /// Please refer to `client_async()` for more details.
client_async_with_config<'a, R, S>( request: R, stream: S, config: Option<WebSocketConfig>, ) -> Result<(WebSocketStream<S>, Response), WsError> where R: IntoClientRequest + Unpin, S: AsyncRead + AsyncWrite + Unpin,120 pub async fn client_async_with_config<'a, R, S>(
121     request: R,
122     stream: S,
123     config: Option<WebSocketConfig>,
124 ) -> Result<(WebSocketStream<S>, Response), WsError>
125 where
126     R: IntoClientRequest + Unpin,
127     S: AsyncRead + AsyncWrite + Unpin,
128 {
129     let f = handshake::client_handshake(stream, move |allow_std| {
130         let request = request.into_client_request()?;
131         let cli_handshake = ClientHandshake::start(allow_std, request, config)?;
132         cli_handshake.handshake()
133     });
134     f.await.map_err(|e| match e {
135         HandshakeError::Failure(e) => e,
136         e => WsError::Io(std::io::Error::new(
137             std::io::ErrorKind::Other,
138             e.to_string(),
139         )),
140     })
141 }
142 
143 /// Accepts a new WebSocket connection with the provided stream.
144 ///
145 /// This function will internally call `server::accept` to create a
146 /// handshake representation and returns a future representing the
147 /// resolution of the WebSocket handshake. The returned future will resolve
148 /// to either `WebSocketStream<S>` or `Error` depending if it's successful
149 /// or not.
150 ///
151 /// This is typically used after a socket has been accepted from a
152 /// `TcpListener`. That socket is then passed to this function to perform
153 /// the server half of the accepting a client's websocket connection.
accept_async<S>(stream: S) -> Result<WebSocketStream<S>, WsError> where S: AsyncRead + AsyncWrite + Unpin,154 pub async fn accept_async<S>(stream: S) -> Result<WebSocketStream<S>, WsError>
155 where
156     S: AsyncRead + AsyncWrite + Unpin,
157 {
158     accept_hdr_async(stream, NoCallback).await
159 }
160 
161 /// The same as `accept_async()` but the one can specify a websocket configuration.
162 /// Please refer to `accept_async()` for more details.
accept_async_with_config<S>( stream: S, config: Option<WebSocketConfig>, ) -> Result<WebSocketStream<S>, WsError> where S: AsyncRead + AsyncWrite + Unpin,163 pub async fn accept_async_with_config<S>(
164     stream: S,
165     config: Option<WebSocketConfig>,
166 ) -> Result<WebSocketStream<S>, WsError>
167 where
168     S: AsyncRead + AsyncWrite + Unpin,
169 {
170     accept_hdr_async_with_config(stream, NoCallback, config).await
171 }
172 
173 /// Accepts a new WebSocket connection with the provided stream.
174 ///
175 /// This function does the same as `accept_async()` but accepts an extra callback
176 /// for header processing. The callback receives headers of the incoming
177 /// requests and is able to add extra headers to the reply.
accept_hdr_async<S, C>(stream: S, callback: C) -> Result<WebSocketStream<S>, WsError> where S: AsyncRead + AsyncWrite + Unpin, C: Callback + Unpin,178 pub async fn accept_hdr_async<S, C>(stream: S, callback: C) -> Result<WebSocketStream<S>, WsError>
179 where
180     S: AsyncRead + AsyncWrite + Unpin,
181     C: Callback + Unpin,
182 {
183     accept_hdr_async_with_config(stream, callback, None).await
184 }
185 
186 /// The same as `accept_hdr_async()` but the one can specify a websocket configuration.
187 /// Please refer to `accept_hdr_async()` for more details.
accept_hdr_async_with_config<S, C>( stream: S, callback: C, config: Option<WebSocketConfig>, ) -> Result<WebSocketStream<S>, WsError> where S: AsyncRead + AsyncWrite + Unpin, C: Callback + Unpin,188 pub async fn accept_hdr_async_with_config<S, C>(
189     stream: S,
190     callback: C,
191     config: Option<WebSocketConfig>,
192 ) -> Result<WebSocketStream<S>, WsError>
193 where
194     S: AsyncRead + AsyncWrite + Unpin,
195     C: Callback + Unpin,
196 {
197     let f = handshake::server_handshake(stream, move |allow_std| {
198         server::accept_hdr_with_config(allow_std, callback, config)
199     });
200     f.await.map_err(|e| match e {
201         HandshakeError::Failure(e) => e,
202         e => WsError::Io(std::io::Error::new(
203             std::io::ErrorKind::Other,
204             e.to_string(),
205         )),
206     })
207 }
208 
/// A wrapper around an underlying raw stream which implements the WebSocket
/// protocol.
///
/// A `WebSocketStream<S>` represents a handshake that has been completed
/// successfully and both the server and the client are ready for receiving
/// and sending data. Messages from a `WebSocketStream<S>` are accessible
/// through the respective `Stream` and `Sink` implementations. See the
/// `futures-rs` crate documentation for more information about those traits,
/// or have a look at the examples and unit tests for this crate.
#[derive(Debug)]
pub struct WebSocketStream<S> {
    // tungstenite's WebSocket over the raw stream. `AllowStd` adapts `S` to the
    // `std::io::Read`/`Write` traits tungstenite works with (note the
    // `AllowStd<S>: Read + Write` bound on `with_context`); task wakers are
    // registered on it through `with_context`.
    inner: WebSocket<AllowStd<S>>,
}
222 
223 impl<S> WebSocketStream<S> {
224     /// Convert a raw socket into a WebSocketStream without performing a
225     /// handshake.
from_raw_socket(stream: S, role: Role, config: Option<WebSocketConfig>) -> Self where S: AsyncRead + AsyncWrite + Unpin,226     pub async fn from_raw_socket(stream: S, role: Role, config: Option<WebSocketConfig>) -> Self
227     where
228         S: AsyncRead + AsyncWrite + Unpin,
229     {
230         handshake::without_handshake(stream, move |allow_std| {
231             WebSocket::from_raw_socket(allow_std, role, config)
232         })
233         .await
234     }
235 
236     /// Convert a raw socket into a WebSocketStream without performing a
237     /// handshake.
from_partially_read( stream: S, part: Vec<u8>, role: Role, config: Option<WebSocketConfig>, ) -> Self where S: AsyncRead + AsyncWrite + Unpin,238     pub async fn from_partially_read(
239         stream: S,
240         part: Vec<u8>,
241         role: Role,
242         config: Option<WebSocketConfig>,
243     ) -> Self
244     where
245         S: AsyncRead + AsyncWrite + Unpin,
246     {
247         handshake::without_handshake(stream, move |allow_std| {
248             WebSocket::from_partially_read(allow_std, part, role, config)
249         })
250         .await
251     }
252 
new(ws: WebSocket<AllowStd<S>>) -> Self253     pub(crate) fn new(ws: WebSocket<AllowStd<S>>) -> Self {
254         WebSocketStream { inner: ws }
255     }
256 
with_context<F, R>(&mut self, ctx: Option<(ContextWaker, &mut Context<'_>)>, f: F) -> R where S: Unpin, F: FnOnce(&mut WebSocket<AllowStd<S>>) -> R, AllowStd<S>: Read + Write,257     fn with_context<F, R>(&mut self, ctx: Option<(ContextWaker, &mut Context<'_>)>, f: F) -> R
258     where
259         S: Unpin,
260         F: FnOnce(&mut WebSocket<AllowStd<S>>) -> R,
261         AllowStd<S>: Read + Write,
262     {
263         #[cfg(feature = "verbose-logging")]
264         trace!("{}:{} WebSocketStream.with_context", file!(), line!());
265         if let Some((kind, ctx)) = ctx {
266             self.inner.get_mut().set_waker(kind, &ctx.waker());
267         }
268         f(&mut self.inner)
269     }
270 
271     /// Returns a shared reference to the inner stream.
get_ref(&self) -> &S where S: AsyncRead + AsyncWrite + Unpin,272     pub fn get_ref(&self) -> &S
273     where
274         S: AsyncRead + AsyncWrite + Unpin,
275     {
276         &self.inner.get_ref().get_ref()
277     }
278 
279     /// Returns a mutable reference to the inner stream.
get_mut(&mut self) -> &mut S where S: AsyncRead + AsyncWrite + Unpin,280     pub fn get_mut(&mut self) -> &mut S
281     where
282         S: AsyncRead + AsyncWrite + Unpin,
283     {
284         self.inner.get_mut().get_mut()
285     }
286 
287     /// Returns a reference to the configuration of the tungstenite stream.
get_config(&self) -> &WebSocketConfig288     pub fn get_config(&self) -> &WebSocketConfig {
289         self.inner.get_config()
290     }
291 
292     /// Close the underlying web socket
close(&mut self, msg: Option<CloseFrame<'_>>) -> Result<(), WsError> where S: AsyncRead + AsyncWrite + Unpin,293     pub async fn close(&mut self, msg: Option<CloseFrame<'_>>) -> Result<(), WsError>
294     where
295         S: AsyncRead + AsyncWrite + Unpin,
296     {
297         let msg = msg.map(|msg| msg.into_owned());
298         self.send(Message::Close(msg)).await
299     }
300 }
301 
302 impl<T> Stream for WebSocketStream<T>
303 where
304     T: AsyncRead + AsyncWrite + Unpin,
305 {
306     type Item = Result<Message, WsError>;
307 
poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>>308     fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
309         #[cfg(feature = "verbose-logging")]
310         trace!("{}:{} Stream.poll_next", file!(), line!());
311         match futures_util::ready!(self.with_context(Some((ContextWaker::Read, cx)), |s| {
312             #[cfg(feature = "verbose-logging")]
313             trace!(
314                 "{}:{} Stream.with_context poll_next -> read_message()",
315                 file!(),
316                 line!()
317             );
318             cvt(s.read_message())
319         })) {
320             Ok(v) => Poll::Ready(Some(Ok(v))),
321             Err(WsError::AlreadyClosed) | Err(WsError::ConnectionClosed) => Poll::Ready(None),
322             Err(e) => Poll::Ready(Some(Err(e))),
323         }
324     }
325 }
326 
327 impl<T> Sink<Message> for WebSocketStream<T>
328 where
329     T: AsyncRead + AsyncWrite + Unpin,
330 {
331     type Error = WsError;
332 
poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>333     fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
334         (*self).with_context(Some((ContextWaker::Write, cx)), |s| cvt(s.write_pending()))
335     }
336 
start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error>337     fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
338         match (*self).with_context(None, |s| s.write_message(item)) {
339             Ok(()) => Ok(()),
340             Err(::tungstenite::Error::Io(ref err))
341                 if err.kind() == std::io::ErrorKind::WouldBlock =>
342             {
343                 // the message was accepted and queued
344                 // isn't an error.
345                 Ok(())
346             }
347             Err(e) => {
348                 debug!("websocket start_send error: {}", e);
349                 Err(e)
350             }
351         }
352     }
353 
poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>354     fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
355         (*self).with_context(Some((ContextWaker::Write, cx)), |s| cvt(s.write_pending()))
356     }
357 
poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>358     fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
359         match (*self).with_context(Some((ContextWaker::Write, cx)), |s| s.close(None)) {
360             Ok(()) => Poll::Ready(Ok(())),
361             Err(::tungstenite::Error::ConnectionClosed) => Poll::Ready(Ok(())),
362             Err(err) => {
363                 debug!("websocket close error: {}", err);
364                 Poll::Ready(Err(err))
365             }
366         }
367     }
368 }
369 
/// Extract the host component of a client request's URI.
///
/// Fails with `UrlError::NoHostName` when the URI carries no host.
#[cfg(any(
    feature = "async-tls",
    feature = "async-std-runtime",
    feature = "tokio-runtime",
    feature = "gio-runtime"
))]
#[inline]
pub(crate) fn domain(
    request: &tungstenite::handshake::client::Request,
) -> Result<String, tungstenite::Error> {
    let missing_host = || tungstenite::Error::Url(tungstenite::error::UrlError::NoHostName);
    request
        .uri()
        .host()
        .map(str::to_owned)
        .ok_or_else(missing_host)
}
388 
/// Determine the port to connect to for a client request.
///
/// An explicit port in the URI wins. Otherwise the default for the scheme is
/// used: 443 for `wss` and 80 for `ws`. Any other (or missing) scheme yields
/// `UrlError::UnsupportedUrlScheme`.
#[cfg(any(
    feature = "async-std-runtime",
    feature = "tokio-runtime",
    feature = "gio-runtime"
))]
#[inline]
pub(crate) fn port(
    request: &tungstenite::handshake::client::Request,
) -> Result<u16, tungstenite::Error> {
    let uri = request.uri();
    if let Some(explicit) = uri.port_u16() {
        return Ok(explicit);
    }
    match uri.scheme_str() {
        Some("wss") => Ok(443),
        Some("ws") => Ok(80),
        _ => Err(tungstenite::Error::Url(
            tungstenite::error::UrlError::UnsupportedUrlScheme,
        )),
    }
}
411