1 //! WebSocket protocol support.
2 //!
//! To set up a `WebSocket`, first perform the WebSocket handshake; on
//! success, convert the `Payload` into a `WsStream` stream and then use
//! `WsWriter` to communicate with the peer.
6 use std::io;
7 
8 use derive_more::{Display, From};
9 use http::{header, Method, StatusCode};
10 
11 use crate::error::ResponseError;
12 use crate::message::RequestHead;
13 use crate::response::{Response, ResponseBuilder};
14 
15 mod codec;
16 mod dispatcher;
17 mod frame;
18 mod mask;
19 mod proto;
20 
21 pub use self::codec::{Codec, Frame, Item, Message};
22 pub use self::dispatcher::Dispatcher;
23 pub use self::frame::Parser;
24 pub use self::proto::{hash_key, CloseCode, CloseReason, OpCode};
25 
/// Websocket protocol errors
///
/// The `Display` and `From` implementations are generated by `derive_more`;
/// the user-facing text of each variant is the string given in its
/// `#[display(...)]` attribute.
#[derive(Debug, Display, From)]
pub enum ProtocolError {
    /// Received an unmasked frame from client
    /// (RFC 6455 requires client-to-server frames to be masked).
    #[display(fmt = "Received an unmasked frame from client")]
    UnmaskedFrame,
    /// Received a masked frame from server
    /// (RFC 6455 requires server-to-client frames to be unmasked).
    #[display(fmt = "Received a masked frame from server")]
    MaskedFrame,
    /// Encountered invalid opcode; the payload is the offending opcode value.
    #[display(fmt = "Invalid opcode: {}", _0)]
    InvalidOpcode(u8),
    /// Invalid control frame length; the payload is the offending length.
    #[display(fmt = "Invalid control frame length: {}", _0)]
    InvalidLength(usize),
    /// Bad web socket op code
    #[display(fmt = "Bad web socket op code")]
    BadOpCode,
    /// A payload reached size limit.
    #[display(fmt = "A payload reached size limit.")]
    Overflow,
    /// Continuation is not started
    #[display(fmt = "Continuation is not started.")]
    ContinuationNotStarted,
    /// Received new continuation but it is already started
    #[display(fmt = "Received new continuation but it is already started")]
    ContinuationStarted,
    /// Unknown continuation fragment.
    ///
    /// Carries the `OpCode` involved; note that the `Display` message does
    /// not include it.
    #[display(fmt = "Unknown continuation fragment.")]
    ContinuationFragment(OpCode),
    /// Io error; wraps the underlying `std::io::Error`, whose message is
    /// appended to the `Display` output.
    #[display(fmt = "io error: {}", _0)]
    Io(io::Error),
}
60 
// No methods are overridden here, so `ProtocolError` is turned into a
// response using whatever defaults the `ResponseError` trait provides.
impl ResponseError for ProtocolError {}
62 
/// Websocket handshake errors
///
/// Each variant names the handshake check that a request failed. The
/// `ResponseError` implementation for this type maps `GetMethodRequired` to a
/// `MethodNotAllowed` response and every other variant to a `BadRequest`
/// response.
#[derive(PartialEq, Debug, Display)]
pub enum HandshakeError {
    /// Only get method is allowed
    #[display(fmt = "Method not allowed")]
    GetMethodRequired,
    /// Upgrade header is not set to websocket
    #[display(fmt = "Websocket upgrade is expected")]
    NoWebsocketUpgrade,
    /// Connection header is not set to upgrade
    #[display(fmt = "Connection upgrade is expected")]
    NoConnectionUpgrade,
    /// Websocket version header is not set
    #[display(fmt = "Websocket version header is required")]
    NoVersionHeader,
    /// Unsupported websocket version
    #[display(fmt = "Unsupported version")]
    UnsupportedVersion,
    /// Websocket key is not set or wrong
    #[display(fmt = "Unknown websocket key")]
    BadWebsocketKey,
}
85 
86 impl ResponseError for HandshakeError {
error_response(&self) -> Response87     fn error_response(&self) -> Response {
88         match *self {
89             HandshakeError::GetMethodRequired => Response::MethodNotAllowed()
90                 .header(header::ALLOW, "GET")
91                 .finish(),
92             HandshakeError::NoWebsocketUpgrade => Response::BadRequest()
93                 .reason("No WebSocket UPGRADE header found")
94                 .finish(),
95             HandshakeError::NoConnectionUpgrade => Response::BadRequest()
96                 .reason("No CONNECTION upgrade")
97                 .finish(),
98             HandshakeError::NoVersionHeader => Response::BadRequest()
99                 .reason("Websocket version header is required")
100                 .finish(),
101             HandshakeError::UnsupportedVersion => Response::BadRequest()
102                 .reason("Unsupported version")
103                 .finish(),
104             HandshakeError::BadWebsocketKey => {
105                 Response::BadRequest().reason("Handshake error").finish()
106             }
107         }
108     }
109 }
110 
111 /// Verify `WebSocket` handshake request and create handshake reponse.
112 // /// `protocols` is a sequence of known protocols. On successful handshake,
113 // /// the returned response headers contain the first protocol in this list
114 // /// which the server also knows.
handshake(req: &RequestHead) -> Result<ResponseBuilder, HandshakeError>115 pub fn handshake(req: &RequestHead) -> Result<ResponseBuilder, HandshakeError> {
116     verify_handshake(req)?;
117     Ok(handshake_response(req))
118 }
119 
120 /// Verify `WebSocket` handshake request.
121 // /// `protocols` is a sequence of known protocols. On successful handshake,
122 // /// the returned response headers contain the first protocol in this list
123 // /// which the server also knows.
verify_handshake(req: &RequestHead) -> Result<(), HandshakeError>124 pub fn verify_handshake(req: &RequestHead) -> Result<(), HandshakeError> {
125     // WebSocket accepts only GET
126     if req.method != Method::GET {
127         return Err(HandshakeError::GetMethodRequired);
128     }
129 
130     // Check for "UPGRADE" to websocket header
131     let has_hdr = if let Some(hdr) = req.headers().get(header::UPGRADE) {
132         if let Ok(s) = hdr.to_str() {
133             s.to_ascii_lowercase().contains("websocket")
134         } else {
135             false
136         }
137     } else {
138         false
139     };
140     if !has_hdr {
141         return Err(HandshakeError::NoWebsocketUpgrade);
142     }
143 
144     // Upgrade connection
145     if !req.upgrade() {
146         return Err(HandshakeError::NoConnectionUpgrade);
147     }
148 
149     // check supported version
150     if !req.headers().contains_key(header::SEC_WEBSOCKET_VERSION) {
151         return Err(HandshakeError::NoVersionHeader);
152     }
153     let supported_ver = {
154         if let Some(hdr) = req.headers().get(header::SEC_WEBSOCKET_VERSION) {
155             hdr == "13" || hdr == "8" || hdr == "7"
156         } else {
157             false
158         }
159     };
160     if !supported_ver {
161         return Err(HandshakeError::UnsupportedVersion);
162     }
163 
164     // check client handshake for validity
165     if !req.headers().contains_key(header::SEC_WEBSOCKET_KEY) {
166         return Err(HandshakeError::BadWebsocketKey);
167     }
168     Ok(())
169 }
170 
171 /// Create websocket's handshake response
172 ///
173 /// This function returns handshake `Response`, ready to send to peer.
handshake_response(req: &RequestHead) -> ResponseBuilder174 pub fn handshake_response(req: &RequestHead) -> ResponseBuilder {
175     let key = {
176         let key = req.headers().get(header::SEC_WEBSOCKET_KEY).unwrap();
177         proto::hash_key(key.as_ref())
178     };
179 
180     Response::build(StatusCode::SWITCHING_PROTOCOLS)
181         .upgrade("websocket")
182         .header(header::TRANSFER_ENCODING, "chunked")
183         .header(header::SEC_WEBSOCKET_ACCEPT, key.as_str())
184         .take()
185 }
186 
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test::TestRequest;
    use http::{header, Method};

    /// Walks through the handshake checks one at a time: each request adds
    /// the header that satisfies the previous check, so the next check is the
    /// one that fails. The final request is complete and must be accepted.
    #[test]
    fn test_handshake() {
        // Non-GET method is rejected first.
        let req = TestRequest::default().method(Method::POST).finish();
        assert_eq!(
            HandshakeError::GetMethodRequired,
            verify_handshake(req.head()).err().unwrap()
        );

        // No `Upgrade` header at all.
        let req = TestRequest::default().finish();
        assert_eq!(
            HandshakeError::NoWebsocketUpgrade,
            verify_handshake(req.head()).err().unwrap()
        );

        // `Upgrade` header present but not for websocket.
        let req = TestRequest::default()
            .header(header::UPGRADE, header::HeaderValue::from_static("test"))
            .finish();
        assert_eq!(
            HandshakeError::NoWebsocketUpgrade,
            verify_handshake(req.head()).err().unwrap()
        );

        // `Upgrade: websocket` without `Connection: upgrade`.
        let req = TestRequest::default()
            .header(
                header::UPGRADE,
                header::HeaderValue::from_static("websocket"),
            )
            .finish();
        assert_eq!(
            HandshakeError::NoConnectionUpgrade,
            verify_handshake(req.head()).err().unwrap()
        );

        // Upgrade requested, but no version header.
        let req = TestRequest::default()
            .header(
                header::UPGRADE,
                header::HeaderValue::from_static("websocket"),
            )
            .header(
                header::CONNECTION,
                header::HeaderValue::from_static("upgrade"),
            )
            .finish();
        assert_eq!(
            HandshakeError::NoVersionHeader,
            verify_handshake(req.head()).err().unwrap()
        );

        // Version header present but not one of 13, 8 or 7.
        let req = TestRequest::default()
            .header(
                header::UPGRADE,
                header::HeaderValue::from_static("websocket"),
            )
            .header(
                header::CONNECTION,
                header::HeaderValue::from_static("upgrade"),
            )
            .header(
                header::SEC_WEBSOCKET_VERSION,
                header::HeaderValue::from_static("5"),
            )
            .finish();
        assert_eq!(
            HandshakeError::UnsupportedVersion,
            verify_handshake(req.head()).err().unwrap()
        );

        // Supported version, but the client key is missing.
        let req = TestRequest::default()
            .header(
                header::UPGRADE,
                header::HeaderValue::from_static("websocket"),
            )
            .header(
                header::CONNECTION,
                header::HeaderValue::from_static("upgrade"),
            )
            .header(
                header::SEC_WEBSOCKET_VERSION,
                header::HeaderValue::from_static("13"),
            )
            .finish();
        assert_eq!(
            HandshakeError::BadWebsocketKey,
            verify_handshake(req.head()).err().unwrap()
        );

        // Complete request: verification succeeds and both entry points
        // produce a `101 Switching Protocols` response.
        let req = TestRequest::default()
            .header(
                header::UPGRADE,
                header::HeaderValue::from_static("websocket"),
            )
            .header(
                header::CONNECTION,
                header::HeaderValue::from_static("upgrade"),
            )
            .header(
                header::SEC_WEBSOCKET_VERSION,
                header::HeaderValue::from_static("13"),
            )
            .header(
                header::SEC_WEBSOCKET_KEY,
                header::HeaderValue::from_static("13"),
            )
            .finish();
        assert!(verify_handshake(req.head()).is_ok());
        assert_eq!(
            StatusCode::SWITCHING_PROTOCOLS,
            handshake_response(req.head()).finish().status()
        );
        assert_eq!(
            StatusCode::SWITCHING_PROTOCOLS,
            handshake(req.head()).unwrap().finish().status()
        );
    }

    /// Every `HandshakeError` variant maps to the expected HTTP status.
    #[test]
    fn test_wserror_http_response() {
        let resp: Response = HandshakeError::GetMethodRequired.error_response();
        assert_eq!(resp.status(), StatusCode::METHOD_NOT_ALLOWED);

        // All remaining variants are reported as `400 Bad Request`.
        let bad_request_errors = [
            HandshakeError::NoWebsocketUpgrade,
            HandshakeError::NoConnectionUpgrade,
            HandshakeError::NoVersionHeader,
            HandshakeError::UnsupportedVersion,
            HandshakeError::BadWebsocketKey,
        ];
        for err in bad_request_errors.iter() {
            let resp: Response = err.error_response();
            assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
        }
    }
}
319