1 use std::mem::replace;
2 
3 #[cfg(feature = "ssl")]
4 use openssl::ssl::SslStream;
5 #[cfg(feature = "nativetls")]
6 use native_tls::TlsStream as SslStream;
7 use url;
8 
9 use frame::Frame;
10 use handler::Handler;
11 use handshake::{Handshake, Request, Response};
12 use message::Message;
13 use protocol::{CloseCode, OpCode};
14 use result::{Error, Kind, Result};
15 #[cfg(any(feature = "ssl", feature = "nativetls"))]
16 use util::TcpStream;
17 use util::{Timeout, Token};
18 
19 use super::context::{Compressor, Decompressor};
20 
/// Configuration for the permessage-deflate extension handler.
///
/// Each field documents its default value; `DeflateSettings::default()` yields exactly those
/// defaults.
#[derive(Clone, Copy, Debug)]
pub struct DeflateSettings {
    /// Largest LZ77 sliding window to use, expressed in bits. If the other endpoint negotiates
    /// a smaller window, the smaller value is used. Must lie in 9..=15.
    /// Default: 15
    pub max_window_bits: u8,
    /// When true, the other endpoint is asked to reset its sliding window after every message.
    /// Default: false
    pub request_no_context_takeover: bool,
    /// When true, this endpoint agrees to reset its sliding window after every message it
    /// compresses. When false, a client-side handshake fails if the server requests no
    /// context takeover.
    /// Default: true
    pub accept_no_context_takeover: bool,
    /// How many WebSocket frames to reserve room for while reassembling an incoming
    /// fragmented, compressed message. It may differ from the WebSocket's own
    /// `fragments_capacity`, since compressed and uncompressed messages can fragment
    /// differently.
    /// Default: 10
    pub fragments_capacity: usize,
    /// When true, the reassembly buffer may grow beyond `fragments_capacity`; when false,
    /// exceeding it raises a capacity error.
    /// Default: true
    pub fragments_grow: bool,
}
47 
48 impl Default for DeflateSettings {
default() -> DeflateSettings49     fn default() -> DeflateSettings {
50         DeflateSettings {
51             max_window_bits: 15,
52             request_no_context_takeover: false,
53             accept_no_context_takeover: true,
54             fragments_capacity: 10,
55             fragments_grow: true,
56         }
57     }
58 }
59 
60 /// Utility for applying the permessage-deflate extension to a handler with particular deflate
61 /// settings.
62 #[derive(Debug, Clone, Copy)]
63 pub struct DeflateBuilder {
64     settings: DeflateSettings,
65 }
66 
67 impl DeflateBuilder {
68     /// Create a new DeflateBuilder with the default settings.
new() -> DeflateBuilder69     pub fn new() -> DeflateBuilder {
70         DeflateBuilder {
71             settings: DeflateSettings::default(),
72         }
73     }
74 
75     /// Configure the DeflateBuilder with the given deflate settings.
with_settings(&mut self, settings: DeflateSettings) -> &mut DeflateBuilder76     pub fn with_settings(&mut self, settings: DeflateSettings) -> &mut DeflateBuilder {
77         self.settings = settings;
78         self
79     }
80 
81     /// Wrap another handler in with a deflate handler as configured.
build<H: Handler>(&self, handler: H) -> DeflateHandler<H>82     pub fn build<H: Handler>(&self, handler: H) -> DeflateHandler<H> {
83         DeflateHandler {
84             com: Compressor::new(self.settings.max_window_bits as i8),
85             dec: Decompressor::new(self.settings.max_window_bits as i8),
86             fragments: Vec::with_capacity(self.settings.fragments_capacity),
87             compress_reset: false,
88             decompress_reset: false,
89             pass: false,
90             settings: self.settings,
91             inner: handler,
92         }
93     }
94 }
95 
96 /// A WebSocket handler that implements the permessage-deflate extension.
97 ///
98 /// This handler wraps a child handler and proxies all handler methods to it. The handler will
99 /// decompress incoming WebSocket message frames in their reserved bits match the
100 /// permessage-deflate specification and pass them to the child handler. Message frames sent from
101 /// the child handler will be compressed and sent to the other endpoint using deflate compression.
102 pub struct DeflateHandler<H: Handler> {
103     com: Compressor,
104     dec: Decompressor,
105     fragments: Vec<Frame>,
106     compress_reset: bool,
107     decompress_reset: bool,
108     pass: bool,
109     settings: DeflateSettings,
110     inner: H,
111 }
112 
113 impl<H: Handler> DeflateHandler<H> {
114     /// Wrap a child handler to provide the permessage-deflate extension.
new(handler: H) -> DeflateHandler<H>115     pub fn new(handler: H) -> DeflateHandler<H> {
116         trace!("Using permessage-deflate handler.");
117         let settings = DeflateSettings::default();
118         DeflateHandler {
119             com: Compressor::new(settings.max_window_bits as i8),
120             dec: Decompressor::new(settings.max_window_bits as i8),
121             fragments: Vec::with_capacity(settings.fragments_capacity),
122             compress_reset: false,
123             decompress_reset: false,
124             pass: false,
125             settings: settings,
126             inner: handler,
127         }
128     }
129 
130     #[doc(hidden)]
131     #[inline]
decline(&mut self, mut res: Response) -> Result<Response>132     fn decline(&mut self, mut res: Response) -> Result<Response> {
133         trace!("Declined permessage-deflate offer");
134         self.pass = true;
135         res.remove_extension("permessage-deflate");
136         Ok(res)
137     }
138 }
139 
140 impl<H: Handler> Handler for DeflateHandler<H> {
build_request(&mut self, url: &url::Url) -> Result<Request>141     fn build_request(&mut self, url: &url::Url) -> Result<Request> {
142         let mut req = self.inner.build_request(url)?;
143         let mut req_ext = String::with_capacity(100);
144         req_ext.push_str("permessage-deflate");
145         if self.settings.max_window_bits < 15 {
146             req_ext.push_str(&format!(
147                 "; client_max_window_bits={}; server_max_window_bits={}",
148                 self.settings.max_window_bits, self.settings.max_window_bits
149             ))
150         } else {
151             req_ext.push_str("; client_max_window_bits")
152         }
153         if self.settings.request_no_context_takeover {
154             req_ext.push_str("; server_no_context_takeover")
155         }
156         req.add_extension(&req_ext);
157         Ok(req)
158     }
159 
on_request(&mut self, req: &Request) -> Result<Response>160     fn on_request(&mut self, req: &Request) -> Result<Response> {
161         let mut res = self.inner.on_request(req)?;
162 
163         'ext: for req_ext in req.extensions()?
164             .iter()
165             .filter(|&&ext| ext.contains("permessage-deflate"))
166         {
167             let mut res_ext = String::with_capacity(req_ext.len());
168             let mut s_takeover = false;
169             let mut c_takeover = false;
170             let mut s_max = false;
171             let mut c_max = false;
172 
173             for param in req_ext.split(';') {
174                 match param.trim() {
175                     "permessage-deflate" => res_ext.push_str("permessage-deflate"),
176                     "server_no_context_takeover" => {
177                         if s_takeover {
178                             return self.decline(res);
179                         } else {
180                             s_takeover = true;
181                             if self.settings.accept_no_context_takeover {
182                                 self.compress_reset = true;
183                                 res_ext.push_str("; server_no_context_takeover");
184                             } else {
185                                 continue 'ext;
186                             }
187                         }
188                     }
189                     "client_no_context_takeover" => {
190                         if c_takeover {
191                             return self.decline(res);
192                         } else {
193                             c_takeover = true;
194                             self.decompress_reset = true;
195                             res_ext.push_str("; client_no_context_takeover");
196                         }
197                     }
198                     param if param.starts_with("server_max_window_bits") => {
199                         if s_max {
200                             return self.decline(res);
201                         } else {
202                             s_max = true;
203                             let mut param_iter = param.split('=');
204                             param_iter.next(); // we already know the name
205                             if let Some(window_bits_str) = param_iter.next() {
206                                 if let Ok(window_bits) = window_bits_str.trim().parse() {
207                                     if window_bits >= 9 && window_bits <= 15 {
208                                         if window_bits < self.settings.max_window_bits as i8 {
209                                             self.com = Compressor::new(window_bits);
210                                             res_ext.push_str("; ");
211                                             res_ext.push_str(param)
212                                         }
213                                     } else {
214                                         return self.decline(res);
215                                     }
216                                 } else {
217                                     return self.decline(res);
218                                 }
219                             }
220                         }
221                     }
222                     param if param.starts_with("client_max_window_bits") => {
223                         if c_max {
224                             return self.decline(res);
225                         } else {
226                             c_max = true;
227                             let mut param_iter = param.split('=');
228                             param_iter.next(); // we already know the name
229                             if let Some(window_bits_str) = param_iter.next() {
230                                 if let Ok(window_bits) = window_bits_str.trim().parse() {
231                                     if window_bits >= 9 && window_bits <= 15 {
232                                         if window_bits < self.settings.max_window_bits as i8 {
233                                             self.dec = Decompressor::new(window_bits);
234                                             res_ext.push_str("; ");
235                                             res_ext.push_str(param);
236                                             continue;
237                                         }
238                                     } else {
239                                         return self.decline(res);
240                                     }
241                                 } else {
242                                     return self.decline(res);
243                                 }
244                             }
245                             res_ext.push_str("; ");
246                             res_ext.push_str(&format!(
247                                 "client_max_window_bits={}",
248                                 self.settings.max_window_bits
249                             ))
250                         }
251                     }
252                     _ => {
253                         // decline all extension offers because we got a bad parameter
254                         return self.decline(res);
255                     }
256                 }
257             }
258 
259             if !res_ext.contains("client_no_context_takeover")
260                 && self.settings.request_no_context_takeover
261             {
262                 self.decompress_reset = true;
263                 res_ext.push_str("; client_no_context_takeover");
264             }
265 
266             if !res_ext.contains("server_max_window_bits") {
267                 res_ext.push_str("; ");
268                 res_ext.push_str(&format!(
269                     "server_max_window_bits={}",
270                     self.settings.max_window_bits
271                 ))
272             }
273 
274             if !res_ext.contains("client_max_window_bits") && self.settings.max_window_bits < 15 {
275                 continue;
276             }
277 
278             res.add_extension(&res_ext);
279             return Ok(res);
280         }
281         self.decline(res)
282     }
283 
on_response(&mut self, res: &Response) -> Result<()>284     fn on_response(&mut self, res: &Response) -> Result<()> {
285         if let Some(res_ext) = res.extensions()?
286             .iter()
287             .find(|&&ext| ext.contains("permessage-deflate"))
288         {
289             let mut name = false;
290             let mut s_takeover = false;
291             let mut c_takeover = false;
292             let mut s_max = false;
293             let mut c_max = false;
294 
295             for param in res_ext.split(';') {
296                 match param.trim() {
297                     "permessage-deflate" => {
298                         if name {
299                             return Err(Error::new(
300                                 Kind::Protocol,
301                                 format!("Duplicate extension name permessage-deflate"),
302                             ));
303                         } else {
304                             name = true;
305                         }
306                     }
307                     "server_no_context_takeover" => {
308                         if s_takeover {
309                             return Err(Error::new(
310                                 Kind::Protocol,
311                                 format!("Duplicate extension parameter server_no_context_takeover"),
312                             ));
313                         } else {
314                             s_takeover = true;
315                             self.decompress_reset = true;
316                         }
317                     }
318                     "client_no_context_takeover" => {
319                         if c_takeover {
320                             return Err(Error::new(
321                                 Kind::Protocol,
322                                 format!("Duplicate extension parameter client_no_context_takeover"),
323                             ));
324                         } else {
325                             c_takeover = true;
326                             if self.settings.accept_no_context_takeover {
327                                 self.compress_reset = true;
328                             } else {
329                                 return Err(Error::new(
330                                     Kind::Protocol,
331                                     format!("The client requires context takeover."),
332                                 ));
333                             }
334                         }
335                     }
336                     param if param.starts_with("server_max_window_bits") => {
337                         if s_max {
338                             return Err(Error::new(
339                                 Kind::Protocol,
340                                 format!("Duplicate extension parameter server_max_window_bits"),
341                             ));
342                         } else {
343                             s_max = true;
344                             let mut param_iter = param.split('=');
345                             param_iter.next(); // we already know the name
346                             if let Some(window_bits_str) = param_iter.next() {
347                                 if let Ok(window_bits) = window_bits_str.trim().parse() {
348                                     if window_bits >= 9 && window_bits <= 15 {
349                                         if window_bits as u8 != self.settings.max_window_bits {
350                                             self.dec = Decompressor::new(window_bits);
351                                         }
352                                     } else {
353                                         return Err(Error::new(
354                                             Kind::Protocol,
355                                             format!(
356                                                 "Invalid server_max_window_bits parameter: {}",
357                                                 window_bits
358                                             ),
359                                         ));
360                                     }
361                                 } else {
362                                     return Err(Error::new(
363                                         Kind::Protocol,
364                                         format!(
365                                             "Invalid server_max_window_bits parameter: {}",
366                                             window_bits_str
367                                         ),
368                                     ));
369                                 }
370                             }
371                         }
372                     }
373                     param if param.starts_with("client_max_window_bits") => {
374                         if c_max {
375                             return Err(Error::new(
376                                 Kind::Protocol,
377                                 format!("Duplicate extension parameter client_max_window_bits"),
378                             ));
379                         } else {
380                             c_max = true;
381                             let mut param_iter = param.split('=');
382                             param_iter.next(); // we already know the name
383                             if let Some(window_bits_str) = param_iter.next() {
384                                 if let Ok(window_bits) = window_bits_str.trim().parse() {
385                                     if window_bits >= 9 && window_bits <= 15 {
386                                         if window_bits as u8 != self.settings.max_window_bits {
387                                             self.com = Compressor::new(window_bits);
388                                         }
389                                     } else {
390                                         return Err(Error::new(
391                                             Kind::Protocol,
392                                             format!(
393                                                 "Invalid client_max_window_bits parameter: {}",
394                                                 window_bits
395                                             ),
396                                         ));
397                                     }
398                                 } else {
399                                     return Err(Error::new(
400                                         Kind::Protocol,
401                                         format!(
402                                             "Invalid client_max_window_bits parameter: {}",
403                                             window_bits_str
404                                         ),
405                                     ));
406                                 }
407                             }
408                         }
409                     }
410                     param => {
411                         // fail the connection because we got a bad parameter
412                         return Err(Error::new(
413                             Kind::Protocol,
414                             format!("Bad extension parameter: {}", param),
415                         ));
416                     }
417                 }
418             }
419         } else {
420             self.pass = true
421         }
422 
423         Ok(())
424     }
425 
on_frame(&mut self, mut frame: Frame) -> Result<Option<Frame>>426     fn on_frame(&mut self, mut frame: Frame) -> Result<Option<Frame>> {
427         if !self.pass && !frame.is_control() {
428             if !self.fragments.is_empty() || frame.has_rsv1() {
429                 frame.set_rsv1(false);
430 
431                 if !frame.is_final() {
432                     self.fragments.push(frame);
433                     return Ok(None);
434                 } else {
435                     if frame.opcode() == OpCode::Continue {
436                         if self.fragments.is_empty() {
437                             return Err(Error::new(
438                                 Kind::Protocol,
439                                 "Unable to reconstruct fragmented message. No first frame.",
440                             ));
441                         } else {
442                             if !self.settings.fragments_grow
443                                 && self.settings.fragments_capacity == self.fragments.len()
444                             {
445                                 return Err(Error::new(Kind::Capacity, "Exceeded max fragments."));
446                             } else {
447                                 self.fragments.push(frame);
448                             }
449 
450                             // it's safe to unwrap because of the above check for empty
451                             let opcode = self.fragments.first().unwrap().opcode();
452                             let size = self.fragments
453                                 .iter()
454                                 .fold(0, |len, frame| len + frame.payload().len());
455                             let mut compressed = Vec::with_capacity(size);
456                             let mut decompressed = Vec::with_capacity(size * 2);
457                             for frag in replace(
458                                 &mut self.fragments,
459                                 Vec::with_capacity(self.settings.fragments_capacity),
460                             ) {
461                                 compressed.extend(frag.into_data())
462                             }
463 
464                             compressed.extend(&[0, 0, 255, 255]);
465                             self.dec.decompress(&compressed, &mut decompressed)?;
466                             frame = Frame::message(decompressed, opcode, true);
467                         }
468                     } else {
469                         let mut decompressed = Vec::with_capacity(frame.payload().len() * 2);
470                         frame.payload_mut().extend(&[0, 0, 255, 255]);
471 
472                         self.dec.decompress(frame.payload(), &mut decompressed)?;
473 
474                         *frame.payload_mut() = decompressed;
475                     }
476 
477                     if self.decompress_reset {
478                         self.dec.reset()?
479                     }
480                 }
481             }
482         }
483         self.inner.on_frame(frame)
484     }
485 
on_send_frame(&mut self, frame: Frame) -> Result<Option<Frame>>486     fn on_send_frame(&mut self, frame: Frame) -> Result<Option<Frame>> {
487         if let Some(mut frame) = self.inner.on_send_frame(frame)? {
488             if !self.pass && !frame.is_control() {
489                 debug_assert!(
490                     frame.is_final(),
491                     "Received non-final frame from upstream handler!"
492                 );
493                 debug_assert!(
494                     frame.opcode() != OpCode::Continue,
495                     "Received continue frame from upstream handler!"
496                 );
497 
498                 frame.set_rsv1(true);
499                 let mut compressed = Vec::with_capacity(frame.payload().len());
500                 self.com.compress(frame.payload(), &mut compressed)?;
501                 let len = compressed.len();
502                 compressed.truncate(len - 4);
503                 *frame.payload_mut() = compressed;
504 
505                 if self.compress_reset {
506                     self.com.reset()?
507                 }
508             }
509             Ok(Some(frame))
510         } else {
511             Ok(None)
512         }
513     }
514 
515     #[inline]
on_shutdown(&mut self)516     fn on_shutdown(&mut self) {
517         self.inner.on_shutdown()
518     }
519 
520     #[inline]
on_open(&mut self, shake: Handshake) -> Result<()>521     fn on_open(&mut self, shake: Handshake) -> Result<()> {
522         self.inner.on_open(shake)
523     }
524 
525     #[inline]
on_message(&mut self, msg: Message) -> Result<()>526     fn on_message(&mut self, msg: Message) -> Result<()> {
527         self.inner.on_message(msg)
528     }
529 
530     #[inline]
on_close(&mut self, code: CloseCode, reason: &str)531     fn on_close(&mut self, code: CloseCode, reason: &str) {
532         self.inner.on_close(code, reason)
533     }
534 
535     #[inline]
on_error(&mut self, err: Error)536     fn on_error(&mut self, err: Error) {
537         self.inner.on_error(err)
538     }
539 
540     #[inline]
on_timeout(&mut self, event: Token) -> Result<()>541     fn on_timeout(&mut self, event: Token) -> Result<()> {
542         self.inner.on_timeout(event)
543     }
544 
545     #[inline]
on_new_timeout(&mut self, tok: Token, timeout: Timeout) -> Result<()>546     fn on_new_timeout(&mut self, tok: Token, timeout: Timeout) -> Result<()> {
547         self.inner.on_new_timeout(tok, timeout)
548     }
549 
550     #[inline]
551     #[cfg(any(feature = "ssl", feature = "nativetls"))]
upgrade_ssl_client( &mut self, stream: TcpStream, url: &url::Url, ) -> Result<SslStream<TcpStream>>552     fn upgrade_ssl_client(
553         &mut self,
554         stream: TcpStream,
555         url: &url::Url,
556     ) -> Result<SslStream<TcpStream>> {
557         self.inner.upgrade_ssl_client(stream, url)
558     }
559 
560     #[inline]
561     #[cfg(any(feature = "ssl", feature = "nativetls"))]
upgrade_ssl_server(&mut self, stream: TcpStream) -> Result<SslStream<TcpStream>>562     fn upgrade_ssl_server(&mut self, stream: TcpStream) -> Result<SslStream<TcpStream>> {
563         self.inner.upgrade_ssl_server(stream)
564     }
565 }
566