1 //! Websockets Filters
2 
3 use std::borrow::Cow;
4 use std::fmt;
5 use std::future::Future;
6 use std::pin::Pin;
7 use std::task::{Context, Poll};
8 
9 use super::{body, header};
10 use crate::filter::{Filter, One};
11 use crate::reject::Rejection;
12 use crate::reply::{Reply, Response};
13 use futures::{future, ready, FutureExt, Sink, Stream, TryFutureExt};
14 use headers::{Connection, HeaderMapExt, SecWebsocketAccept, SecWebsocketKey, Upgrade};
15 use http;
16 use tokio_tungstenite::{
17     tungstenite::protocol::{self, WebSocketConfig},
18     WebSocketStream,
19 };
20 
21 /// Creates a Websocket Filter.
22 ///
23 /// The yielded `Ws` is used to finish the websocket upgrade.
24 ///
25 /// # Note
26 ///
27 /// This filter combines multiple filters internally, so you don't need them:
28 ///
29 /// - Method must be `GET`
30 /// - Header `connection` must be `upgrade`
31 /// - Header `upgrade` must be `websocket`
32 /// - Header `sec-websocket-version` must be `13`
33 /// - Header `sec-websocket-key` must be set.
34 ///
35 /// If the filters are met, yields a `Ws`. Calling `Ws::on_upgrade` will
36 /// return a reply with:
37 ///
38 /// - Status of `101 Switching Protocols`
39 /// - Header `connection: upgrade`
40 /// - Header `upgrade: websocket`
41 /// - Header `sec-websocket-accept` with the hash value of the received key.
ws() -> impl Filter<Extract = One<Ws>, Error = Rejection> + Copy42 pub fn ws() -> impl Filter<Extract = One<Ws>, Error = Rejection> + Copy {
43     let connection_has_upgrade = header::header2()
44         .and_then(|conn: ::headers::Connection| {
45             if conn.contains("upgrade") {
46                 future::ok(())
47             } else {
48                 future::err(crate::reject::known(MissingConnectionUpgrade))
49             }
50         })
51         .untuple_one();
52 
53     crate::get()
54         .and(connection_has_upgrade)
55         .and(header::exact_ignore_case("upgrade", "websocket"))
56         .and(header::exact("sec-websocket-version", "13"))
57         //.and(header::exact2(Upgrade::websocket()))
58         //.and(header::exact2(SecWebsocketVersion::V13))
59         .and(header::header2::<SecWebsocketKey>())
60         .and(body::body())
61         .map(move |key: SecWebsocketKey, body: ::hyper::Body| Ws {
62             body,
63             config: None,
64             key,
65         })
66 }
67 
68 /// Extracted by the [`ws`](ws) filter, and used to finish an upgrade.
69 pub struct Ws {
70     body: ::hyper::Body,
71     config: Option<WebSocketConfig>,
72     key: SecWebsocketKey,
73 }
74 
75 impl Ws {
76     /// Finish the upgrade, passing a function to handle the `WebSocket`.
77     ///
78     /// The passed function must return a `Future`.
on_upgrade<F, U>(self, func: F) -> impl Reply where F: FnOnce(WebSocket) -> U + Send + 'static, U: Future<Output = ()> + Send + 'static,79     pub fn on_upgrade<F, U>(self, func: F) -> impl Reply
80     where
81         F: FnOnce(WebSocket) -> U + Send + 'static,
82         U: Future<Output = ()> + Send + 'static,
83     {
84         WsReply {
85             ws: self,
86             on_upgrade: func,
87         }
88     }
89 
90     // config
91 
92     /// Set the size of the internal message send queue.
max_send_queue(mut self, max: usize) -> Self93     pub fn max_send_queue(mut self, max: usize) -> Self {
94         self.config
95             .get_or_insert_with(WebSocketConfig::default)
96             .max_send_queue = Some(max);
97         self
98     }
99 
100     /// Set the maximum message size (defaults to 64 megabytes)
max_message_size(mut self, max: usize) -> Self101     pub fn max_message_size(mut self, max: usize) -> Self {
102         self.config
103             .get_or_insert_with(WebSocketConfig::default)
104             .max_message_size = Some(max);
105         self
106     }
107 }
108 
109 impl fmt::Debug for Ws {
fmt(&self, f: &mut fmt::Formatter) -> fmt::Result110     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
111         f.debug_struct("Ws").finish()
112     }
113 }
114 
115 #[allow(missing_debug_implementations)]
116 struct WsReply<F> {
117     ws: Ws,
118     on_upgrade: F,
119 }
120 
121 impl<F, U> Reply for WsReply<F>
122 where
123     F: FnOnce(WebSocket) -> U + Send + 'static,
124     U: Future<Output = ()> + Send + 'static,
125 {
into_response(self) -> Response126     fn into_response(self) -> Response {
127         let on_upgrade = self.on_upgrade;
128         let config = self.ws.config;
129         let fut = self
130             .ws
131             .body
132             .on_upgrade()
133             .and_then(move |upgraded| {
134                 log::trace!("websocket upgrade complete");
135                 WebSocket::from_raw_socket(upgraded, protocol::Role::Server, config).map(Ok)
136             })
137             .and_then(move |socket| on_upgrade(socket).map(Ok))
138             .map(|result| {
139                 if let Err(err) = result {
140                     log::debug!("ws upgrade error: {}", err);
141                 }
142             });
143         ::tokio::task::spawn(fut);
144 
145         let mut res = http::Response::default();
146 
147         *res.status_mut() = http::StatusCode::SWITCHING_PROTOCOLS;
148 
149         res.headers_mut().typed_insert(Connection::upgrade());
150         res.headers_mut().typed_insert(Upgrade::websocket());
151         res.headers_mut()
152             .typed_insert(SecWebsocketAccept::from(self.ws.key));
153 
154         res
155     }
156 }
157 
158 /// A websocket `Stream` and `Sink`, provided to `ws` filters.
159 pub struct WebSocket {
160     inner: WebSocketStream<hyper::upgrade::Upgraded>,
161 }
162 
163 impl WebSocket {
from_raw_socket( upgraded: hyper::upgrade::Upgraded, role: protocol::Role, config: Option<protocol::WebSocketConfig>, ) -> Self164     pub(crate) async fn from_raw_socket(
165         upgraded: hyper::upgrade::Upgraded,
166         role: protocol::Role,
167         config: Option<protocol::WebSocketConfig>,
168     ) -> Self {
169         WebSocketStream::from_raw_socket(upgraded, role, config)
170             .map(|inner| WebSocket { inner })
171             .await
172     }
173 
174     /// Gracefully close this websocket.
close(mut self) -> Result<(), crate::Error>175     pub async fn close(mut self) -> Result<(), crate::Error> {
176         future::poll_fn(|cx| Pin::new(&mut self).poll_close(cx)).await
177     }
178 }
179 
180 impl Stream for WebSocket {
181     type Item = Result<Message, crate::Error>;
182 
poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>>183     fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
184         match ready!(Pin::new(&mut self.inner).poll_next(cx)) {
185             Some(Ok(item)) => Poll::Ready(Some(Ok(Message { inner: item }))),
186             Some(Err(e)) => {
187                 log::debug!("websocket poll error: {}", e);
188                 Poll::Ready(Some(Err(crate::Error::new(e))))
189             }
190             None => {
191                 log::trace!("websocket closed");
192                 Poll::Ready(None)
193             }
194         }
195     }
196 }
197 
198 impl Sink<Message> for WebSocket {
199     type Error = crate::Error;
200 
poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>201     fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
202         match ready!(Pin::new(&mut self.inner).poll_ready(cx)) {
203             Ok(()) => Poll::Ready(Ok(())),
204             Err(e) => Poll::Ready(Err(crate::Error::new(e))),
205         }
206     }
207 
start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error>208     fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
209         match Pin::new(&mut self.inner).start_send(item.inner) {
210             Ok(()) => Ok(()),
211             Err(e) => {
212                 log::debug!("websocket start_send error: {}", e);
213                 Err(crate::Error::new(e))
214             }
215         }
216     }
217 
poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>>218     fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
219         match ready!(Pin::new(&mut self.inner).poll_flush(cx)) {
220             Ok(()) => Poll::Ready(Ok(())),
221             Err(e) => Poll::Ready(Err(crate::Error::new(e))),
222         }
223     }
224 
poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>>225     fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
226         match ready!(Pin::new(&mut self.inner).poll_close(cx)) {
227             Ok(()) => Poll::Ready(Ok(())),
228             Err(err) => {
229                 log::debug!("websocket close error: {}", err);
230                 Poll::Ready(Err(crate::Error::new(err)))
231             }
232         }
233     }
234 }
235 
236 impl fmt::Debug for WebSocket {
fmt(&self, f: &mut fmt::Formatter) -> fmt::Result237     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
238         f.debug_struct("WebSocket").finish()
239     }
240 }
241 
242 /// A WebSocket message.
243 ///
244 /// Only repesents Text and Binary messages.
245 ///
246 /// This will likely become a `non-exhaustive` enum in the future, once that
247 /// language feature has stabilized.
248 #[derive(Eq, PartialEq, Clone)]
249 pub struct Message {
250     inner: protocol::Message,
251 }
252 
253 impl Message {
254     /// Construct a new Text `Message`.
text<S: Into<String>>(s: S) -> Message255     pub fn text<S: Into<String>>(s: S) -> Message {
256         Message {
257             inner: protocol::Message::text(s),
258         }
259     }
260 
261     /// Construct a new Binary `Message`.
binary<V: Into<Vec<u8>>>(v: V) -> Message262     pub fn binary<V: Into<Vec<u8>>>(v: V) -> Message {
263         Message {
264             inner: protocol::Message::binary(v),
265         }
266     }
267 
268     /// Construct a new Ping `Message`.
ping<V: Into<Vec<u8>>>(v: V) -> Message269     pub fn ping<V: Into<Vec<u8>>>(v: V) -> Message {
270         Message {
271             inner: protocol::Message::Ping(v.into()),
272         }
273     }
274 
275     /// Construct the default Close `Message`.
close() -> Message276     pub fn close() -> Message {
277         Message {
278             inner: protocol::Message::Close(None),
279         }
280     }
281 
282     /// Construct a Close `Message` with a code and reason.
close_with(code: impl Into<u16>, reason: impl Into<Cow<'static, str>>) -> Message283     pub fn close_with(code: impl Into<u16>, reason: impl Into<Cow<'static, str>>) -> Message {
284         Message {
285             inner: protocol::Message::Close(Some(protocol::frame::CloseFrame {
286                 code: protocol::frame::coding::CloseCode::from(code.into()),
287                 reason: reason.into(),
288             })),
289         }
290     }
291 
292     /// Returns true if this message is a Text message.
is_text(&self) -> bool293     pub fn is_text(&self) -> bool {
294         self.inner.is_text()
295     }
296 
297     /// Returns true if this message is a Binary message.
is_binary(&self) -> bool298     pub fn is_binary(&self) -> bool {
299         self.inner.is_binary()
300     }
301 
302     /// Returns true if this message a is a Close message.
is_close(&self) -> bool303     pub fn is_close(&self) -> bool {
304         self.inner.is_close()
305     }
306 
307     /// Returns true if this message is a Ping message.
is_ping(&self) -> bool308     pub fn is_ping(&self) -> bool {
309         self.inner.is_ping()
310     }
311 
312     /// Returns true if this message is a Pong message.
is_pong(&self) -> bool313     pub fn is_pong(&self) -> bool {
314         self.inner.is_pong()
315     }
316 
317     /// Try to get a reference to the string text, if this is a Text message.
to_str(&self) -> Result<&str, ()>318     pub fn to_str(&self) -> Result<&str, ()> {
319         match self.inner {
320             protocol::Message::Text(ref s) => Ok(s),
321             _ => Err(()),
322         }
323     }
324 
325     /// Return the bytes of this message, if the message can contain data.
as_bytes(&self) -> &[u8]326     pub fn as_bytes(&self) -> &[u8] {
327         match self.inner {
328             protocol::Message::Text(ref s) => s.as_bytes(),
329             protocol::Message::Binary(ref v) => v,
330             protocol::Message::Ping(ref v) => v,
331             protocol::Message::Pong(ref v) => v,
332             protocol::Message::Close(_) => &[],
333         }
334     }
335 
336     /// Destructure this message into binary data.
into_bytes(self) -> Vec<u8>337     pub fn into_bytes(self) -> Vec<u8> {
338         self.inner.into_data()
339     }
340 }
341 
342 impl fmt::Debug for Message {
fmt(&self, f: &mut fmt::Formatter) -> fmt::Result343     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
344         fmt::Debug::fmt(&self.inner, f)
345     }
346 }
347 
348 impl Into<Vec<u8>> for Message {
into(self) -> Vec<u8>349     fn into(self) -> Vec<u8> {
350         self.into_bytes()
351     }
352 }
353 
354 // ===== Rejections =====
355 
/// Rejection cause used by the `ws` filter when the `connection` header does
/// not contain `upgrade`.
#[derive(Debug)]
pub(crate) struct MissingConnectionUpgrade;

impl fmt::Display for MissingConnectionUpgrade {
    /// Writes a fixed, human-readable description of the rejection.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str("Connection header did not include 'upgrade'")
    }
}

// No underlying source error: the default trait methods are sufficient.
impl std::error::Error for MissingConnectionUpgrade {}
366