1 //! Async WebSockets.
2 //!
3 //! This crate is based on [tungstenite](https://crates.io/crates/tungstenite)
4 //! Rust WebSocket library and provides async bindings and wrappers for it, so you
//! can use it with non-blocking/asynchronous `TcpStream`s and couple it
//! together with other crates from the async stack. In addition, optional
//! integration with various other crates can be enabled via feature flags:
8 //!
9 //! * `async-tls`: Enables the `async_tls` module, which provides integration
10 //! with the [async-tls](https://crates.io/crates/async-tls) TLS stack and can
//! be used independently of any async runtime.
12 //! * `async-std-runtime`: Enables the `async_std` module, which provides
13 //! integration with the [async-std](https://async.rs) runtime.
14 //! * `async-native-tls`: Enables the additional functions in the `async_std`
15 //! module to implement TLS via
16 //! [async-native-tls](https://crates.io/crates/async-native-tls).
17 //! * `tokio-runtime`: Enables the `tokio` module, which provides integration
18 //! with the [tokio](https://tokio.rs) runtime.
19 //! * `tokio-native-tls`: Enables the additional functions in the `tokio` module to
20 //! implement TLS via [tokio-native-tls](https://crates.io/crates/tokio-native-tls).
21 //! * `tokio-rustls-native-certs`: Enables the additional functions in the `tokio`
22 //! module to implement TLS via [tokio-rustls](https://crates.io/crates/tokio-rustls)
23 //! and uses native system certificates found with
24 //! [rustls-native-certs](https://github.com/rustls/rustls-native-certs).
25 //! * `tokio-rustls-webpki-roots`: Enables the additional functions in the `tokio`
26 //! module to implement TLS via [tokio-rustls](https://crates.io/crates/tokio-rustls)
27 //! and uses the certificates [webpki-roots](https://github.com/rustls/webpki-roots)
28 //! provides.
29 //! * `tokio-openssl`: Enables the additional functions in the `tokio` module to
30 //! implement TLS via [tokio-openssl](https://crates.io/crates/tokio-openssl).
31 //! * `gio-runtime`: Enables the `gio` module, which provides integration with
32 //! the [gio](https://www.gtk-rs.org) runtime.
33 //!
34 //! Each WebSocket stream implements the required `Stream` and `Sink` traits,
35 //! making the socket a stream of WebSocket messages coming in and going out.
36
37 #![deny(
38 missing_docs,
39 unused_must_use,
40 unused_mut,
41 unused_imports,
42 unused_import_braces
43 )]
44
45 pub use tungstenite;
46
47 mod compat;
48 mod handshake;
49
50 #[cfg(any(
51 feature = "async-tls",
52 feature = "async-native-tls",
53 feature = "tokio-native-tls",
54 feature = "tokio-rustls-native-certs",
55 feature = "tokio-rustls-webpki-roots",
56 feature = "tokio-openssl",
57 ))]
58 pub mod stream;
59
60 use std::io::{Read, Write};
61
62 use compat::{cvt, AllowStd, ContextWaker};
63 use futures_io::{AsyncRead, AsyncWrite};
64 use futures_util::{
65 sink::{Sink, SinkExt},
66 stream::Stream,
67 };
68 use log::*;
69 use std::pin::Pin;
70 use std::task::{Context, Poll};
71
72 use tungstenite::{
73 client::IntoClientRequest,
74 error::Error as WsError,
75 handshake::{
76 client::{ClientHandshake, Response},
77 server::{Callback, NoCallback},
78 HandshakeError,
79 },
80 protocol::{Message, Role, WebSocket, WebSocketConfig},
81 server,
82 };
83
84 #[cfg(feature = "async-std-runtime")]
85 pub mod async_std;
86 #[cfg(feature = "async-tls")]
87 pub mod async_tls;
88 #[cfg(feature = "gio-runtime")]
89 pub mod gio;
90 #[cfg(feature = "tokio-runtime")]
91 pub mod tokio;
92
93 use tungstenite::protocol::CloseFrame;
94
95 /// Creates a WebSocket handshake from a request and a stream.
96 /// For convenience, the user may call this with a url string, a URL,
97 /// or a `Request`. Calling with `Request` allows the user to add
98 /// a WebSocket protocol or other custom headers.
99 ///
100 /// Internally, this custom creates a handshake representation and returns
101 /// a future representing the resolution of the WebSocket handshake. The
102 /// returned future will resolve to either `WebSocketStream<S>` or `Error`
103 /// depending on whether the handshake is successful.
104 ///
105 /// This is typically used for clients who have already established, for
106 /// example, a TCP connection to the remote server.
client_async<'a, R, S>( request: R, stream: S, ) -> Result<(WebSocketStream<S>, Response), WsError> where R: IntoClientRequest + Unpin, S: AsyncRead + AsyncWrite + Unpin,107 pub async fn client_async<'a, R, S>(
108 request: R,
109 stream: S,
110 ) -> Result<(WebSocketStream<S>, Response), WsError>
111 where
112 R: IntoClientRequest + Unpin,
113 S: AsyncRead + AsyncWrite + Unpin,
114 {
115 client_async_with_config(request, stream, None).await
116 }
117
118 /// The same as `client_async()` but the one can specify a websocket configuration.
119 /// Please refer to `client_async()` for more details.
client_async_with_config<'a, R, S>( request: R, stream: S, config: Option<WebSocketConfig>, ) -> Result<(WebSocketStream<S>, Response), WsError> where R: IntoClientRequest + Unpin, S: AsyncRead + AsyncWrite + Unpin,120 pub async fn client_async_with_config<'a, R, S>(
121 request: R,
122 stream: S,
123 config: Option<WebSocketConfig>,
124 ) -> Result<(WebSocketStream<S>, Response), WsError>
125 where
126 R: IntoClientRequest + Unpin,
127 S: AsyncRead + AsyncWrite + Unpin,
128 {
129 let f = handshake::client_handshake(stream, move |allow_std| {
130 let request = request.into_client_request()?;
131 let cli_handshake = ClientHandshake::start(allow_std, request, config)?;
132 cli_handshake.handshake()
133 });
134 f.await.map_err(|e| match e {
135 HandshakeError::Failure(e) => e,
136 e => WsError::Io(std::io::Error::new(
137 std::io::ErrorKind::Other,
138 e.to_string(),
139 )),
140 })
141 }
142
143 /// Accepts a new WebSocket connection with the provided stream.
144 ///
145 /// This function will internally call `server::accept` to create a
146 /// handshake representation and returns a future representing the
147 /// resolution of the WebSocket handshake. The returned future will resolve
148 /// to either `WebSocketStream<S>` or `Error` depending if it's successful
149 /// or not.
150 ///
151 /// This is typically used after a socket has been accepted from a
152 /// `TcpListener`. That socket is then passed to this function to perform
153 /// the server half of the accepting a client's websocket connection.
accept_async<S>(stream: S) -> Result<WebSocketStream<S>, WsError> where S: AsyncRead + AsyncWrite + Unpin,154 pub async fn accept_async<S>(stream: S) -> Result<WebSocketStream<S>, WsError>
155 where
156 S: AsyncRead + AsyncWrite + Unpin,
157 {
158 accept_hdr_async(stream, NoCallback).await
159 }
160
161 /// The same as `accept_async()` but the one can specify a websocket configuration.
162 /// Please refer to `accept_async()` for more details.
accept_async_with_config<S>( stream: S, config: Option<WebSocketConfig>, ) -> Result<WebSocketStream<S>, WsError> where S: AsyncRead + AsyncWrite + Unpin,163 pub async fn accept_async_with_config<S>(
164 stream: S,
165 config: Option<WebSocketConfig>,
166 ) -> Result<WebSocketStream<S>, WsError>
167 where
168 S: AsyncRead + AsyncWrite + Unpin,
169 {
170 accept_hdr_async_with_config(stream, NoCallback, config).await
171 }
172
173 /// Accepts a new WebSocket connection with the provided stream.
174 ///
175 /// This function does the same as `accept_async()` but accepts an extra callback
176 /// for header processing. The callback receives headers of the incoming
177 /// requests and is able to add extra headers to the reply.
accept_hdr_async<S, C>(stream: S, callback: C) -> Result<WebSocketStream<S>, WsError> where S: AsyncRead + AsyncWrite + Unpin, C: Callback + Unpin,178 pub async fn accept_hdr_async<S, C>(stream: S, callback: C) -> Result<WebSocketStream<S>, WsError>
179 where
180 S: AsyncRead + AsyncWrite + Unpin,
181 C: Callback + Unpin,
182 {
183 accept_hdr_async_with_config(stream, callback, None).await
184 }
185
186 /// The same as `accept_hdr_async()` but the one can specify a websocket configuration.
187 /// Please refer to `accept_hdr_async()` for more details.
accept_hdr_async_with_config<S, C>( stream: S, callback: C, config: Option<WebSocketConfig>, ) -> Result<WebSocketStream<S>, WsError> where S: AsyncRead + AsyncWrite + Unpin, C: Callback + Unpin,188 pub async fn accept_hdr_async_with_config<S, C>(
189 stream: S,
190 callback: C,
191 config: Option<WebSocketConfig>,
192 ) -> Result<WebSocketStream<S>, WsError>
193 where
194 S: AsyncRead + AsyncWrite + Unpin,
195 C: Callback + Unpin,
196 {
197 let f = handshake::server_handshake(stream, move |allow_std| {
198 server::accept_hdr_with_config(allow_std, callback, config)
199 });
200 f.await.map_err(|e| match e {
201 HandshakeError::Failure(e) => e,
202 e => WsError::Io(std::io::Error::new(
203 std::io::ErrorKind::Other,
204 e.to_string(),
205 )),
206 })
207 }
208
209 /// A wrapper around an underlying raw stream which implements the WebSocket
210 /// protocol.
211 ///
212 /// A `WebSocketStream<S>` represents a handshake that has been completed
213 /// successfully and both the server and the client are ready for receiving
214 /// and sending data. Message from a `WebSocketStream<S>` are accessible
215 /// through the respective `Stream` and `Sink`. Check more information about
216 /// them in `futures-rs` crate documentation or have a look on the examples
217 /// and unit tests for this crate.
218 #[derive(Debug)]
219 pub struct WebSocketStream<S> {
220 inner: WebSocket<AllowStd<S>>,
221 }
222
223 impl<S> WebSocketStream<S> {
224 /// Convert a raw socket into a WebSocketStream without performing a
225 /// handshake.
from_raw_socket(stream: S, role: Role, config: Option<WebSocketConfig>) -> Self where S: AsyncRead + AsyncWrite + Unpin,226 pub async fn from_raw_socket(stream: S, role: Role, config: Option<WebSocketConfig>) -> Self
227 where
228 S: AsyncRead + AsyncWrite + Unpin,
229 {
230 handshake::without_handshake(stream, move |allow_std| {
231 WebSocket::from_raw_socket(allow_std, role, config)
232 })
233 .await
234 }
235
236 /// Convert a raw socket into a WebSocketStream without performing a
237 /// handshake.
from_partially_read( stream: S, part: Vec<u8>, role: Role, config: Option<WebSocketConfig>, ) -> Self where S: AsyncRead + AsyncWrite + Unpin,238 pub async fn from_partially_read(
239 stream: S,
240 part: Vec<u8>,
241 role: Role,
242 config: Option<WebSocketConfig>,
243 ) -> Self
244 where
245 S: AsyncRead + AsyncWrite + Unpin,
246 {
247 handshake::without_handshake(stream, move |allow_std| {
248 WebSocket::from_partially_read(allow_std, part, role, config)
249 })
250 .await
251 }
252
new(ws: WebSocket<AllowStd<S>>) -> Self253 pub(crate) fn new(ws: WebSocket<AllowStd<S>>) -> Self {
254 WebSocketStream { inner: ws }
255 }
256
with_context<F, R>(&mut self, ctx: Option<(ContextWaker, &mut Context<'_>)>, f: F) -> R where S: Unpin, F: FnOnce(&mut WebSocket<AllowStd<S>>) -> R, AllowStd<S>: Read + Write,257 fn with_context<F, R>(&mut self, ctx: Option<(ContextWaker, &mut Context<'_>)>, f: F) -> R
258 where
259 S: Unpin,
260 F: FnOnce(&mut WebSocket<AllowStd<S>>) -> R,
261 AllowStd<S>: Read + Write,
262 {
263 #[cfg(feature = "verbose-logging")]
264 trace!("{}:{} WebSocketStream.with_context", file!(), line!());
265 if let Some((kind, ctx)) = ctx {
266 self.inner.get_mut().set_waker(kind, &ctx.waker());
267 }
268 f(&mut self.inner)
269 }
270
271 /// Returns a shared reference to the inner stream.
get_ref(&self) -> &S where S: AsyncRead + AsyncWrite + Unpin,272 pub fn get_ref(&self) -> &S
273 where
274 S: AsyncRead + AsyncWrite + Unpin,
275 {
276 &self.inner.get_ref().get_ref()
277 }
278
279 /// Returns a mutable reference to the inner stream.
get_mut(&mut self) -> &mut S where S: AsyncRead + AsyncWrite + Unpin,280 pub fn get_mut(&mut self) -> &mut S
281 where
282 S: AsyncRead + AsyncWrite + Unpin,
283 {
284 self.inner.get_mut().get_mut()
285 }
286
287 /// Returns a reference to the configuration of the tungstenite stream.
get_config(&self) -> &WebSocketConfig288 pub fn get_config(&self) -> &WebSocketConfig {
289 self.inner.get_config()
290 }
291
292 /// Close the underlying web socket
close(&mut self, msg: Option<CloseFrame<'_>>) -> Result<(), WsError> where S: AsyncRead + AsyncWrite + Unpin,293 pub async fn close(&mut self, msg: Option<CloseFrame<'_>>) -> Result<(), WsError>
294 where
295 S: AsyncRead + AsyncWrite + Unpin,
296 {
297 let msg = msg.map(|msg| msg.into_owned());
298 self.send(Message::Close(msg)).await
299 }
300 }
301
302 impl<T> Stream for WebSocketStream<T>
303 where
304 T: AsyncRead + AsyncWrite + Unpin,
305 {
306 type Item = Result<Message, WsError>;
307
poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>>308 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
309 #[cfg(feature = "verbose-logging")]
310 trace!("{}:{} Stream.poll_next", file!(), line!());
311 match futures_util::ready!(self.with_context(Some((ContextWaker::Read, cx)), |s| {
312 #[cfg(feature = "verbose-logging")]
313 trace!(
314 "{}:{} Stream.with_context poll_next -> read_message()",
315 file!(),
316 line!()
317 );
318 cvt(s.read_message())
319 })) {
320 Ok(v) => Poll::Ready(Some(Ok(v))),
321 Err(WsError::AlreadyClosed) | Err(WsError::ConnectionClosed) => Poll::Ready(None),
322 Err(e) => Poll::Ready(Some(Err(e))),
323 }
324 }
325 }
326
327 impl<T> Sink<Message> for WebSocketStream<T>
328 where
329 T: AsyncRead + AsyncWrite + Unpin,
330 {
331 type Error = WsError;
332
poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>333 fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
334 (*self).with_context(Some((ContextWaker::Write, cx)), |s| cvt(s.write_pending()))
335 }
336
start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error>337 fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
338 match (*self).with_context(None, |s| s.write_message(item)) {
339 Ok(()) => Ok(()),
340 Err(::tungstenite::Error::Io(ref err))
341 if err.kind() == std::io::ErrorKind::WouldBlock =>
342 {
343 // the message was accepted and queued
344 // isn't an error.
345 Ok(())
346 }
347 Err(e) => {
348 debug!("websocket start_send error: {}", e);
349 Err(e)
350 }
351 }
352 }
353
poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>354 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
355 (*self).with_context(Some((ContextWaker::Write, cx)), |s| cvt(s.write_pending()))
356 }
357
poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>358 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
359 match (*self).with_context(Some((ContextWaker::Write, cx)), |s| s.close(None)) {
360 Ok(()) => Poll::Ready(Ok(())),
361 Err(::tungstenite::Error::ConnectionClosed) => Poll::Ready(Ok(())),
362 Err(err) => {
363 debug!("websocket close error: {}", err);
364 Poll::Ready(Err(err))
365 }
366 }
367 }
368 }
369
#[cfg(any(
    feature = "async-tls",
    feature = "async-std-runtime",
    feature = "tokio-runtime",
    feature = "gio-runtime"
))]
/// Extract the host part of the request URI as an owned `String`.
///
/// Fails with `UrlError::NoHostName` when the URI has no host.
#[inline]
pub(crate) fn domain(
    request: &tungstenite::handshake::client::Request,
) -> Result<String, tungstenite::Error> {
    request
        .uri()
        .host()
        .map(ToString::to_string)
        .ok_or_else(|| tungstenite::Error::Url(tungstenite::error::UrlError::NoHostName))
}
388
#[cfg(any(
    feature = "async-std-runtime",
    feature = "tokio-runtime",
    feature = "gio-runtime"
))]
/// Determine the TCP port to connect to for the request URI.
///
/// An explicit port in the URI wins. Otherwise the port is derived from the
/// scheme: `wss` maps to 443 and `ws` to 80. Any other (or missing) scheme
/// yields `UrlError::UnsupportedUrlScheme`.
#[inline]
pub(crate) fn port(
    request: &tungstenite::handshake::client::Request,
) -> Result<u16, tungstenite::Error> {
    let uri = request.uri();
    if let Some(explicit) = uri.port_u16() {
        return Ok(explicit);
    }
    match uri.scheme_str() {
        Some("wss") => Ok(443),
        Some("ws") => Ok(80),
        _ => Err(tungstenite::Error::Url(
            tungstenite::error::UrlError::UnsupportedUrlScheme,
        )),
    }
}
411