//! WebSocket protocol support.
//!
//! To set up a WebSocket connection, first perform the WebSocket handshake
//! (see [`handshake`] / [`verify_handshake`]); on success, use [`Codec`] (or
//! [`Dispatcher`]) to communicate with the peer.
6 use std::io;
7
8 use derive_more::{Display, From};
9 use http::{header, Method, StatusCode};
10
11 use crate::error::ResponseError;
12 use crate::message::RequestHead;
13 use crate::response::{Response, ResponseBuilder};
14
15 mod codec;
16 mod dispatcher;
17 mod frame;
18 mod mask;
19 mod proto;
20
21 pub use self::codec::{Codec, Frame, Item, Message};
22 pub use self::dispatcher::Dispatcher;
23 pub use self::frame::Parser;
24 pub use self::proto::{hash_key, CloseCode, CloseReason, OpCode};
25
26 /// Websocket protocol errors
27 #[derive(Debug, Display, From)]
28 pub enum ProtocolError {
29 /// Received an unmasked frame from client
30 #[display(fmt = "Received an unmasked frame from client")]
31 UnmaskedFrame,
32 /// Received a masked frame from server
33 #[display(fmt = "Received a masked frame from server")]
34 MaskedFrame,
35 /// Encountered invalid opcode
36 #[display(fmt = "Invalid opcode: {}", _0)]
37 InvalidOpcode(u8),
38 /// Invalid control frame length
39 #[display(fmt = "Invalid control frame length: {}", _0)]
40 InvalidLength(usize),
41 /// Bad web socket op code
42 #[display(fmt = "Bad web socket op code")]
43 BadOpCode,
44 /// A payload reached size limit.
45 #[display(fmt = "A payload reached size limit.")]
46 Overflow,
47 /// Continuation is not started
48 #[display(fmt = "Continuation is not started.")]
49 ContinuationNotStarted,
50 /// Received new continuation but it is already started
51 #[display(fmt = "Received new continuation but it is already started")]
52 ContinuationStarted,
53 /// Unknown continuation fragment
54 #[display(fmt = "Unknown continuation fragment.")]
55 ContinuationFragment(OpCode),
56 /// Io error
57 #[display(fmt = "io error: {}", _0)]
58 Io(io::Error),
59 }
60
61 impl ResponseError for ProtocolError {}
62
63 /// Websocket handshake errors
64 #[derive(PartialEq, Debug, Display)]
65 pub enum HandshakeError {
66 /// Only get method is allowed
67 #[display(fmt = "Method not allowed")]
68 GetMethodRequired,
69 /// Upgrade header if not set to websocket
70 #[display(fmt = "Websocket upgrade is expected")]
71 NoWebsocketUpgrade,
72 /// Connection header is not set to upgrade
73 #[display(fmt = "Connection upgrade is expected")]
74 NoConnectionUpgrade,
75 /// Websocket version header is not set
76 #[display(fmt = "Websocket version header is required")]
77 NoVersionHeader,
78 /// Unsupported websocket version
79 #[display(fmt = "Unsupported version")]
80 UnsupportedVersion,
81 /// Websocket key is not set or wrong
82 #[display(fmt = "Unknown websocket key")]
83 BadWebsocketKey,
84 }
85
86 impl ResponseError for HandshakeError {
error_response(&self) -> Response87 fn error_response(&self) -> Response {
88 match *self {
89 HandshakeError::GetMethodRequired => Response::MethodNotAllowed()
90 .header(header::ALLOW, "GET")
91 .finish(),
92 HandshakeError::NoWebsocketUpgrade => Response::BadRequest()
93 .reason("No WebSocket UPGRADE header found")
94 .finish(),
95 HandshakeError::NoConnectionUpgrade => Response::BadRequest()
96 .reason("No CONNECTION upgrade")
97 .finish(),
98 HandshakeError::NoVersionHeader => Response::BadRequest()
99 .reason("Websocket version header is required")
100 .finish(),
101 HandshakeError::UnsupportedVersion => Response::BadRequest()
102 .reason("Unsupported version")
103 .finish(),
104 HandshakeError::BadWebsocketKey => {
105 Response::BadRequest().reason("Handshake error").finish()
106 }
107 }
108 }
109 }
110
111 /// Verify `WebSocket` handshake request and create handshake reponse.
112 // /// `protocols` is a sequence of known protocols. On successful handshake,
113 // /// the returned response headers contain the first protocol in this list
114 // /// which the server also knows.
handshake(req: &RequestHead) -> Result<ResponseBuilder, HandshakeError>115 pub fn handshake(req: &RequestHead) -> Result<ResponseBuilder, HandshakeError> {
116 verify_handshake(req)?;
117 Ok(handshake_response(req))
118 }
119
120 /// Verify `WebSocket` handshake request.
121 // /// `protocols` is a sequence of known protocols. On successful handshake,
122 // /// the returned response headers contain the first protocol in this list
123 // /// which the server also knows.
verify_handshake(req: &RequestHead) -> Result<(), HandshakeError>124 pub fn verify_handshake(req: &RequestHead) -> Result<(), HandshakeError> {
125 // WebSocket accepts only GET
126 if req.method != Method::GET {
127 return Err(HandshakeError::GetMethodRequired);
128 }
129
130 // Check for "UPGRADE" to websocket header
131 let has_hdr = if let Some(hdr) = req.headers().get(header::UPGRADE) {
132 if let Ok(s) = hdr.to_str() {
133 s.to_ascii_lowercase().contains("websocket")
134 } else {
135 false
136 }
137 } else {
138 false
139 };
140 if !has_hdr {
141 return Err(HandshakeError::NoWebsocketUpgrade);
142 }
143
144 // Upgrade connection
145 if !req.upgrade() {
146 return Err(HandshakeError::NoConnectionUpgrade);
147 }
148
149 // check supported version
150 if !req.headers().contains_key(header::SEC_WEBSOCKET_VERSION) {
151 return Err(HandshakeError::NoVersionHeader);
152 }
153 let supported_ver = {
154 if let Some(hdr) = req.headers().get(header::SEC_WEBSOCKET_VERSION) {
155 hdr == "13" || hdr == "8" || hdr == "7"
156 } else {
157 false
158 }
159 };
160 if !supported_ver {
161 return Err(HandshakeError::UnsupportedVersion);
162 }
163
164 // check client handshake for validity
165 if !req.headers().contains_key(header::SEC_WEBSOCKET_KEY) {
166 return Err(HandshakeError::BadWebsocketKey);
167 }
168 Ok(())
169 }
170
171 /// Create websocket's handshake response
172 ///
173 /// This function returns handshake `Response`, ready to send to peer.
handshake_response(req: &RequestHead) -> ResponseBuilder174 pub fn handshake_response(req: &RequestHead) -> ResponseBuilder {
175 let key = {
176 let key = req.headers().get(header::SEC_WEBSOCKET_KEY).unwrap();
177 proto::hash_key(key.as_ref())
178 };
179
180 Response::build(StatusCode::SWITCHING_PROTOCOLS)
181 .upgrade("websocket")
182 .header(header::TRANSFER_ENCODING, "chunked")
183 .header(header::SEC_WEBSOCKET_ACCEPT, key.as_str())
184 .take()
185 }
186
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test::TestRequest;
    use http::{header, Method};

    /// Walks `verify_handshake` through each failure in the order its checks
    /// run, adding one more required header per step, then checks that a
    /// complete request yields a `101 Switching Protocols` response.
    #[test]
    fn test_handshake() {
        // Non-GET method is rejected before any header is inspected.
        let req = TestRequest::default().method(Method::POST).finish();
        assert_eq!(
            HandshakeError::GetMethodRequired,
            verify_handshake(req.head()).err().unwrap()
        );

        // Default request (accepted as GET here) with no `Upgrade` header.
        let req = TestRequest::default().finish();
        assert_eq!(
            HandshakeError::NoWebsocketUpgrade,
            verify_handshake(req.head()).err().unwrap()
        );

        // `Upgrade` present but not naming websocket.
        let req = TestRequest::default()
            .header(header::UPGRADE, header::HeaderValue::from_static("test"))
            .finish();
        assert_eq!(
            HandshakeError::NoWebsocketUpgrade,
            verify_handshake(req.head()).err().unwrap()
        );

        // `Upgrade: websocket` but no `Connection: upgrade`.
        let req = TestRequest::default()
            .header(
                header::UPGRADE,
                header::HeaderValue::from_static("websocket"),
            )
            .finish();
        assert_eq!(
            HandshakeError::NoConnectionUpgrade,
            verify_handshake(req.head()).err().unwrap()
        );

        // Upgrade headers present but no `Sec-WebSocket-Version`.
        let req = TestRequest::default()
            .header(
                header::UPGRADE,
                header::HeaderValue::from_static("websocket"),
            )
            .header(
                header::CONNECTION,
                header::HeaderValue::from_static("upgrade"),
            )
            .finish();
        assert_eq!(
            HandshakeError::NoVersionHeader,
            verify_handshake(req.head()).err().unwrap()
        );

        // Version header present with an unsupported value.
        let req = TestRequest::default()
            .header(
                header::UPGRADE,
                header::HeaderValue::from_static("websocket"),
            )
            .header(
                header::CONNECTION,
                header::HeaderValue::from_static("upgrade"),
            )
            .header(
                header::SEC_WEBSOCKET_VERSION,
                header::HeaderValue::from_static("5"),
            )
            .finish();
        assert_eq!(
            HandshakeError::UnsupportedVersion,
            verify_handshake(req.head()).err().unwrap()
        );

        // Supported version, but no `Sec-WebSocket-Key`.
        let req = TestRequest::default()
            .header(
                header::UPGRADE,
                header::HeaderValue::from_static("websocket"),
            )
            .header(
                header::CONNECTION,
                header::HeaderValue::from_static("upgrade"),
            )
            .header(
                header::SEC_WEBSOCKET_VERSION,
                header::HeaderValue::from_static("13"),
            )
            .finish();
        assert_eq!(
            HandshakeError::BadWebsocketKey,
            verify_handshake(req.head()).err().unwrap()
        );

        // All required headers present: the handshake response is a 101.
        let req = TestRequest::default()
            .header(
                header::UPGRADE,
                header::HeaderValue::from_static("websocket"),
            )
            .header(
                header::CONNECTION,
                header::HeaderValue::from_static("upgrade"),
            )
            .header(
                header::SEC_WEBSOCKET_VERSION,
                header::HeaderValue::from_static("13"),
            )
            .header(
                header::SEC_WEBSOCKET_KEY,
                header::HeaderValue::from_static("13"),
            )
            .finish();
        assert_eq!(
            StatusCode::SWITCHING_PROTOCOLS,
            handshake_response(req.head()).finish().status()
        );
    }

    /// Each `HandshakeError` maps to the expected HTTP status: 405 for the
    /// method check, 400 for every other variant.
    #[test]
    fn test_wserror_http_response() {
        let resp: Response = HandshakeError::GetMethodRequired.error_response();
        assert_eq!(resp.status(), StatusCode::METHOD_NOT_ALLOWED);
        let resp: Response = HandshakeError::NoWebsocketUpgrade.error_response();
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
        let resp: Response = HandshakeError::NoConnectionUpgrade.error_response();
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
        let resp: Response = HandshakeError::NoVersionHeader.error_response();
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
        let resp: Response = HandshakeError::UnsupportedVersion.error_response();
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
        let resp: Response = HandshakeError::BadWebsocketKey.error_response();
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
    }
}
319