//! WebSocket filters.
2
3 use std::borrow::Cow;
4 use std::fmt;
5 use std::future::Future;
6 use std::pin::Pin;
7 use std::task::{Context, Poll};
8
9 use super::{body, header};
10 use crate::filter::{Filter, One};
11 use crate::reject::Rejection;
12 use crate::reply::{Reply, Response};
13 use futures::{future, ready, FutureExt, Sink, Stream, TryFutureExt};
14 use headers::{Connection, HeaderMapExt, SecWebsocketAccept, SecWebsocketKey, Upgrade};
15 use http;
16 use tokio_tungstenite::{
17 tungstenite::protocol::{self, WebSocketConfig},
18 WebSocketStream,
19 };
20
21 /// Creates a Websocket Filter.
22 ///
23 /// The yielded `Ws` is used to finish the websocket upgrade.
24 ///
25 /// # Note
26 ///
27 /// This filter combines multiple filters internally, so you don't need them:
28 ///
29 /// - Method must be `GET`
30 /// - Header `connection` must be `upgrade`
31 /// - Header `upgrade` must be `websocket`
32 /// - Header `sec-websocket-version` must be `13`
33 /// - Header `sec-websocket-key` must be set.
34 ///
35 /// If the filters are met, yields a `Ws`. Calling `Ws::on_upgrade` will
36 /// return a reply with:
37 ///
38 /// - Status of `101 Switching Protocols`
39 /// - Header `connection: upgrade`
40 /// - Header `upgrade: websocket`
41 /// - Header `sec-websocket-accept` with the hash value of the received key.
ws() -> impl Filter<Extract = One<Ws>, Error = Rejection> + Copy42 pub fn ws() -> impl Filter<Extract = One<Ws>, Error = Rejection> + Copy {
43 let connection_has_upgrade = header::header2()
44 .and_then(|conn: ::headers::Connection| {
45 if conn.contains("upgrade") {
46 future::ok(())
47 } else {
48 future::err(crate::reject::known(MissingConnectionUpgrade))
49 }
50 })
51 .untuple_one();
52
53 crate::get()
54 .and(connection_has_upgrade)
55 .and(header::exact_ignore_case("upgrade", "websocket"))
56 .and(header::exact("sec-websocket-version", "13"))
57 //.and(header::exact2(Upgrade::websocket()))
58 //.and(header::exact2(SecWebsocketVersion::V13))
59 .and(header::header2::<SecWebsocketKey>())
60 .and(body::body())
61 .map(move |key: SecWebsocketKey, body: ::hyper::Body| Ws {
62 body,
63 config: None,
64 key,
65 })
66 }
67
68 /// Extracted by the [`ws`](ws) filter, and used to finish an upgrade.
69 pub struct Ws {
70 body: ::hyper::Body,
71 config: Option<WebSocketConfig>,
72 key: SecWebsocketKey,
73 }
74
75 impl Ws {
76 /// Finish the upgrade, passing a function to handle the `WebSocket`.
77 ///
78 /// The passed function must return a `Future`.
on_upgrade<F, U>(self, func: F) -> impl Reply where F: FnOnce(WebSocket) -> U + Send + 'static, U: Future<Output = ()> + Send + 'static,79 pub fn on_upgrade<F, U>(self, func: F) -> impl Reply
80 where
81 F: FnOnce(WebSocket) -> U + Send + 'static,
82 U: Future<Output = ()> + Send + 'static,
83 {
84 WsReply {
85 ws: self,
86 on_upgrade: func,
87 }
88 }
89
90 // config
91
92 /// Set the size of the internal message send queue.
max_send_queue(mut self, max: usize) -> Self93 pub fn max_send_queue(mut self, max: usize) -> Self {
94 self.config
95 .get_or_insert_with(WebSocketConfig::default)
96 .max_send_queue = Some(max);
97 self
98 }
99
100 /// Set the maximum message size (defaults to 64 megabytes)
max_message_size(mut self, max: usize) -> Self101 pub fn max_message_size(mut self, max: usize) -> Self {
102 self.config
103 .get_or_insert_with(WebSocketConfig::default)
104 .max_message_size = Some(max);
105 self
106 }
107 }
108
109 impl fmt::Debug for Ws {
fmt(&self, f: &mut fmt::Formatter) -> fmt::Result110 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
111 f.debug_struct("Ws").finish()
112 }
113 }
114
115 #[allow(missing_debug_implementations)]
116 struct WsReply<F> {
117 ws: Ws,
118 on_upgrade: F,
119 }
120
121 impl<F, U> Reply for WsReply<F>
122 where
123 F: FnOnce(WebSocket) -> U + Send + 'static,
124 U: Future<Output = ()> + Send + 'static,
125 {
into_response(self) -> Response126 fn into_response(self) -> Response {
127 let on_upgrade = self.on_upgrade;
128 let config = self.ws.config;
129 let fut = self
130 .ws
131 .body
132 .on_upgrade()
133 .and_then(move |upgraded| {
134 log::trace!("websocket upgrade complete");
135 WebSocket::from_raw_socket(upgraded, protocol::Role::Server, config).map(Ok)
136 })
137 .and_then(move |socket| on_upgrade(socket).map(Ok))
138 .map(|result| {
139 if let Err(err) = result {
140 log::debug!("ws upgrade error: {}", err);
141 }
142 });
143 ::tokio::task::spawn(fut);
144
145 let mut res = http::Response::default();
146
147 *res.status_mut() = http::StatusCode::SWITCHING_PROTOCOLS;
148
149 res.headers_mut().typed_insert(Connection::upgrade());
150 res.headers_mut().typed_insert(Upgrade::websocket());
151 res.headers_mut()
152 .typed_insert(SecWebsocketAccept::from(self.ws.key));
153
154 res
155 }
156 }
157
158 /// A websocket `Stream` and `Sink`, provided to `ws` filters.
159 pub struct WebSocket {
160 inner: WebSocketStream<hyper::upgrade::Upgraded>,
161 }
162
163 impl WebSocket {
from_raw_socket( upgraded: hyper::upgrade::Upgraded, role: protocol::Role, config: Option<protocol::WebSocketConfig>, ) -> Self164 pub(crate) async fn from_raw_socket(
165 upgraded: hyper::upgrade::Upgraded,
166 role: protocol::Role,
167 config: Option<protocol::WebSocketConfig>,
168 ) -> Self {
169 WebSocketStream::from_raw_socket(upgraded, role, config)
170 .map(|inner| WebSocket { inner })
171 .await
172 }
173
174 /// Gracefully close this websocket.
close(mut self) -> Result<(), crate::Error>175 pub async fn close(mut self) -> Result<(), crate::Error> {
176 future::poll_fn(|cx| Pin::new(&mut self).poll_close(cx)).await
177 }
178 }
179
180 impl Stream for WebSocket {
181 type Item = Result<Message, crate::Error>;
182
poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>>183 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
184 match ready!(Pin::new(&mut self.inner).poll_next(cx)) {
185 Some(Ok(item)) => Poll::Ready(Some(Ok(Message { inner: item }))),
186 Some(Err(e)) => {
187 log::debug!("websocket poll error: {}", e);
188 Poll::Ready(Some(Err(crate::Error::new(e))))
189 }
190 None => {
191 log::trace!("websocket closed");
192 Poll::Ready(None)
193 }
194 }
195 }
196 }
197
198 impl Sink<Message> for WebSocket {
199 type Error = crate::Error;
200
poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>201 fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
202 match ready!(Pin::new(&mut self.inner).poll_ready(cx)) {
203 Ok(()) => Poll::Ready(Ok(())),
204 Err(e) => Poll::Ready(Err(crate::Error::new(e))),
205 }
206 }
207
start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error>208 fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
209 match Pin::new(&mut self.inner).start_send(item.inner) {
210 Ok(()) => Ok(()),
211 Err(e) => {
212 log::debug!("websocket start_send error: {}", e);
213 Err(crate::Error::new(e))
214 }
215 }
216 }
217
poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>>218 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
219 match ready!(Pin::new(&mut self.inner).poll_flush(cx)) {
220 Ok(()) => Poll::Ready(Ok(())),
221 Err(e) => Poll::Ready(Err(crate::Error::new(e))),
222 }
223 }
224
poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>>225 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
226 match ready!(Pin::new(&mut self.inner).poll_close(cx)) {
227 Ok(()) => Poll::Ready(Ok(())),
228 Err(err) => {
229 log::debug!("websocket close error: {}", err);
230 Poll::Ready(Err(crate::Error::new(err)))
231 }
232 }
233 }
234 }
235
236 impl fmt::Debug for WebSocket {
fmt(&self, f: &mut fmt::Formatter) -> fmt::Result237 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
238 f.debug_struct("WebSocket").finish()
239 }
240 }
241
242 /// A WebSocket message.
243 ///
244 /// Only repesents Text and Binary messages.
245 ///
246 /// This will likely become a `non-exhaustive` enum in the future, once that
247 /// language feature has stabilized.
248 #[derive(Eq, PartialEq, Clone)]
249 pub struct Message {
250 inner: protocol::Message,
251 }
252
253 impl Message {
254 /// Construct a new Text `Message`.
text<S: Into<String>>(s: S) -> Message255 pub fn text<S: Into<String>>(s: S) -> Message {
256 Message {
257 inner: protocol::Message::text(s),
258 }
259 }
260
261 /// Construct a new Binary `Message`.
binary<V: Into<Vec<u8>>>(v: V) -> Message262 pub fn binary<V: Into<Vec<u8>>>(v: V) -> Message {
263 Message {
264 inner: protocol::Message::binary(v),
265 }
266 }
267
268 /// Construct a new Ping `Message`.
ping<V: Into<Vec<u8>>>(v: V) -> Message269 pub fn ping<V: Into<Vec<u8>>>(v: V) -> Message {
270 Message {
271 inner: protocol::Message::Ping(v.into()),
272 }
273 }
274
275 /// Construct the default Close `Message`.
close() -> Message276 pub fn close() -> Message {
277 Message {
278 inner: protocol::Message::Close(None),
279 }
280 }
281
282 /// Construct a Close `Message` with a code and reason.
close_with(code: impl Into<u16>, reason: impl Into<Cow<'static, str>>) -> Message283 pub fn close_with(code: impl Into<u16>, reason: impl Into<Cow<'static, str>>) -> Message {
284 Message {
285 inner: protocol::Message::Close(Some(protocol::frame::CloseFrame {
286 code: protocol::frame::coding::CloseCode::from(code.into()),
287 reason: reason.into(),
288 })),
289 }
290 }
291
292 /// Returns true if this message is a Text message.
is_text(&self) -> bool293 pub fn is_text(&self) -> bool {
294 self.inner.is_text()
295 }
296
297 /// Returns true if this message is a Binary message.
is_binary(&self) -> bool298 pub fn is_binary(&self) -> bool {
299 self.inner.is_binary()
300 }
301
302 /// Returns true if this message a is a Close message.
is_close(&self) -> bool303 pub fn is_close(&self) -> bool {
304 self.inner.is_close()
305 }
306
307 /// Returns true if this message is a Ping message.
is_ping(&self) -> bool308 pub fn is_ping(&self) -> bool {
309 self.inner.is_ping()
310 }
311
312 /// Returns true if this message is a Pong message.
is_pong(&self) -> bool313 pub fn is_pong(&self) -> bool {
314 self.inner.is_pong()
315 }
316
317 /// Try to get a reference to the string text, if this is a Text message.
to_str(&self) -> Result<&str, ()>318 pub fn to_str(&self) -> Result<&str, ()> {
319 match self.inner {
320 protocol::Message::Text(ref s) => Ok(s),
321 _ => Err(()),
322 }
323 }
324
325 /// Return the bytes of this message, if the message can contain data.
as_bytes(&self) -> &[u8]326 pub fn as_bytes(&self) -> &[u8] {
327 match self.inner {
328 protocol::Message::Text(ref s) => s.as_bytes(),
329 protocol::Message::Binary(ref v) => v,
330 protocol::Message::Ping(ref v) => v,
331 protocol::Message::Pong(ref v) => v,
332 protocol::Message::Close(_) => &[],
333 }
334 }
335
336 /// Destructure this message into binary data.
into_bytes(self) -> Vec<u8>337 pub fn into_bytes(self) -> Vec<u8> {
338 self.inner.into_data()
339 }
340 }
341
342 impl fmt::Debug for Message {
fmt(&self, f: &mut fmt::Formatter) -> fmt::Result343 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
344 fmt::Debug::fmt(&self.inner, f)
345 }
346 }
347
348 impl Into<Vec<u8>> for Message {
into(self) -> Vec<u8>349 fn into(self) -> Vec<u8> {
350 self.into_bytes()
351 }
352 }
353
// ===== Rejections =====

/// Rejection cause used when the request's `Connection` header does not
/// include the `upgrade` option.
#[derive(Debug)]
pub(crate) struct MissingConnectionUpgrade;

impl fmt::Display for MissingConnectionUpgrade {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("Connection header did not include 'upgrade'")
    }
}

impl std::error::Error for MissingConnectionUpgrade {}
366