1 use std::mem::replace; 2 3 #[cfg(feature = "ssl")] 4 use openssl::ssl::SslStream; 5 #[cfg(feature = "nativetls")] 6 use native_tls::TlsStream as SslStream; 7 use url; 8 9 use frame::Frame; 10 use handler::Handler; 11 use handshake::{Handshake, Request, Response}; 12 use message::Message; 13 use protocol::{CloseCode, OpCode}; 14 use result::{Error, Kind, Result}; 15 #[cfg(any(feature = "ssl", feature = "nativetls"))] 16 use util::TcpStream; 17 use util::{Timeout, Token}; 18 19 use super::context::{Compressor, Decompressor}; 20 21 /// Deflate Extension Handler Settings 22 #[derive(Debug, Clone, Copy)] 23 pub struct DeflateSettings { 24 /// The max size of the sliding window. If the other endpoint selects a smaller size, that size 25 /// will be used instead. This must be an integer between 9 and 15 inclusive. 26 /// Default: 15 27 pub max_window_bits: u8, 28 /// Indicates whether to ask the other endpoint to reset the sliding window for each message. 29 /// Default: false 30 pub request_no_context_takeover: bool, 31 /// Indicates whether this endpoint will agree to reset the sliding window for each message it 32 /// compresses. If this endpoint won't agree to reset the sliding window, then the handshake 33 /// will fail if this endpoint is a client and the server requests no context takeover. 34 /// Default: true 35 pub accept_no_context_takeover: bool, 36 /// The number of WebSocket frames to store when defragmenting an incoming fragmented 37 /// compressed message. 38 /// This setting may be different from the `fragments_capacity` setting of the WebSocket in order to 39 /// allow for differences between compressed and uncompressed messages. 40 /// Default: 10 41 pub fragments_capacity: usize, 42 /// Indicates whether the extension handler will reallocate if the `fragments_capacity` is 43 /// exceeded. If this is not true, a capacity error will be triggered instead. 44 /// Default: true 45 pub fragments_grow: bool, 46 } 47 48 impl Default for DeflateSettings { default() -> DeflateSettings49 fn default() -> DeflateSettings { 50 DeflateSettings { 51 max_window_bits: 15, 52 request_no_context_takeover: false, 53 accept_no_context_takeover: true, 54 fragments_capacity: 10, 55 fragments_grow: true, 56 } 57 } 58 } 59 60 /// Utility for applying the permessage-deflate extension to a handler with particular deflate 61 /// settings. 62 #[derive(Debug, Clone, Copy)] 63 pub struct DeflateBuilder { 64 settings: DeflateSettings, 65 } 66 67 impl DeflateBuilder { 68 /// Create a new DeflateBuilder with the default settings. new() -> DeflateBuilder69 pub fn new() -> DeflateBuilder { 70 DeflateBuilder { 71 settings: DeflateSettings::default(), 72 } 73 } 74 75 /// Configure the DeflateBuilder with the given deflate settings. with_settings(&mut self, settings: DeflateSettings) -> &mut DeflateBuilder76 pub fn with_settings(&mut self, settings: DeflateSettings) -> &mut DeflateBuilder { 77 self.settings = settings; 78 self 79 } 80 81 /// Wrap another handler in with a deflate handler as configured. build<H: Handler>(&self, handler: H) -> DeflateHandler<H>82 pub fn build<H: Handler>(&self, handler: H) -> DeflateHandler<H> { 83 DeflateHandler { 84 com: Compressor::new(self.settings.max_window_bits as i8), 85 dec: Decompressor::new(self.settings.max_window_bits as i8), 86 fragments: Vec::with_capacity(self.settings.fragments_capacity), 87 compress_reset: false, 88 decompress_reset: false, 89 pass: false, 90 settings: self.settings, 91 inner: handler, 92 } 93 } 94 } 95 96 /// A WebSocket handler that implements the permessage-deflate extension. 97 /// 98 /// This handler wraps a child handler and proxies all handler methods to it. The handler will 99 /// decompress incoming WebSocket message frames in their reserved bits match the 100 /// permessage-deflate specification and pass them to the child handler. Message frames sent from 101 /// the child handler will be compressed and sent to the other endpoint using deflate compression. 102 pub struct DeflateHandler<H: Handler> { 103 com: Compressor, 104 dec: Decompressor, 105 fragments: Vec<Frame>, 106 compress_reset: bool, 107 decompress_reset: bool, 108 pass: bool, 109 settings: DeflateSettings, 110 inner: H, 111 } 112 113 impl<H: Handler> DeflateHandler<H> { 114 /// Wrap a child handler to provide the permessage-deflate extension. new(handler: H) -> DeflateHandler<H>115 pub fn new(handler: H) -> DeflateHandler<H> { 116 trace!("Using permessage-deflate handler."); 117 let settings = DeflateSettings::default(); 118 DeflateHandler { 119 com: Compressor::new(settings.max_window_bits as i8), 120 dec: Decompressor::new(settings.max_window_bits as i8), 121 fragments: Vec::with_capacity(settings.fragments_capacity), 122 compress_reset: false, 123 decompress_reset: false, 124 pass: false, 125 settings: settings, 126 inner: handler, 127 } 128 } 129 130 #[doc(hidden)] 131 #[inline] decline(&mut self, mut res: Response) -> Result<Response>132 fn decline(&mut self, mut res: Response) -> Result<Response> { 133 trace!("Declined permessage-deflate offer"); 134 self.pass = true; 135 res.remove_extension("permessage-deflate"); 136 Ok(res) 137 } 138 } 139 140 impl<H: Handler> Handler for DeflateHandler<H> { build_request(&mut self, url: &url::Url) -> Result<Request>141 fn build_request(&mut self, url: &url::Url) -> Result<Request> { 142 let mut req = self.inner.build_request(url)?; 143 let mut req_ext = String::with_capacity(100); 144 req_ext.push_str("permessage-deflate"); 145 if self.settings.max_window_bits < 15 { 146 req_ext.push_str(&format!( 147 "; client_max_window_bits={}; server_max_window_bits={}", 148 self.settings.max_window_bits, self.settings.max_window_bits 149 )) 150 } else { 151 req_ext.push_str("; client_max_window_bits") 152 } 153 if self.settings.request_no_context_takeover { 154 req_ext.push_str("; server_no_context_takeover") 155 } 156 req.add_extension(&req_ext); 157 Ok(req) 158 } 159 on_request(&mut self, req: &Request) -> Result<Response>160 fn on_request(&mut self, req: &Request) -> Result<Response> { 161 let mut res = self.inner.on_request(req)?; 162 163 'ext: for req_ext in req.extensions()? 164 .iter() 165 .filter(|&&ext| ext.contains("permessage-deflate")) 166 { 167 let mut res_ext = String::with_capacity(req_ext.len()); 168 let mut s_takeover = false; 169 let mut c_takeover = false; 170 let mut s_max = false; 171 let mut c_max = false; 172 173 for param in req_ext.split(';') { 174 match param.trim() { 175 "permessage-deflate" => res_ext.push_str("permessage-deflate"), 176 "server_no_context_takeover" => { 177 if s_takeover { 178 return self.decline(res); 179 } else { 180 s_takeover = true; 181 if self.settings.accept_no_context_takeover { 182 self.compress_reset = true; 183 res_ext.push_str("; server_no_context_takeover"); 184 } else { 185 continue 'ext; 186 } 187 } 188 } 189 "client_no_context_takeover" => { 190 if c_takeover { 191 return self.decline(res); 192 } else { 193 c_takeover = true; 194 self.decompress_reset = true; 195 res_ext.push_str("; client_no_context_takeover"); 196 } 197 } 198 param if param.starts_with("server_max_window_bits") => { 199 if s_max { 200 return self.decline(res); 201 } else { 202 s_max = true; 203 let mut param_iter = param.split('='); 204 param_iter.next(); // we already know the name 205 if let Some(window_bits_str) = param_iter.next() { 206 if let Ok(window_bits) = window_bits_str.trim().parse() { 207 if window_bits >= 9 && window_bits <= 15 { 208 if window_bits < self.settings.max_window_bits as i8 { 209 self.com = Compressor::new(window_bits); 210 res_ext.push_str("; "); 211 res_ext.push_str(param) 212 } 213 } else { 214 return self.decline(res); 215 } 216 } else { 217 return self.decline(res); 218 } 219 } 220 } 221 } 222 param if param.starts_with("client_max_window_bits") => { 223 if c_max { 224 return self.decline(res); 225 } else { 226 c_max = true; 227 let mut param_iter = param.split('='); 228 param_iter.next(); // we already know the name 229 if let Some(window_bits_str) = param_iter.next() { 230 if let Ok(window_bits) = window_bits_str.trim().parse() { 231 if window_bits >= 9 && window_bits <= 15 { 232 if window_bits < self.settings.max_window_bits as i8 { 233 self.dec = Decompressor::new(window_bits); 234 res_ext.push_str("; "); 235 res_ext.push_str(param); 236 continue; 237 } 238 } else { 239 return self.decline(res); 240 } 241 } else { 242 return self.decline(res); 243 } 244 } 245 res_ext.push_str("; "); 246 res_ext.push_str(&format!( 247 "client_max_window_bits={}", 248 self.settings.max_window_bits 249 )) 250 } 251 } 252 _ => { 253 // decline all extension offers because we got a bad parameter 254 return self.decline(res); 255 } 256 } 257 } 258 259 if !res_ext.contains("client_no_context_takeover") 260 && self.settings.request_no_context_takeover 261 { 262 self.decompress_reset = true; 263 res_ext.push_str("; client_no_context_takeover"); 264 } 265 266 if !res_ext.contains("server_max_window_bits") { 267 res_ext.push_str("; "); 268 res_ext.push_str(&format!( 269 "server_max_window_bits={}", 270 self.settings.max_window_bits 271 )) 272 } 273 274 if !res_ext.contains("client_max_window_bits") && self.settings.max_window_bits < 15 { 275 continue; 276 } 277 278 res.add_extension(&res_ext); 279 return Ok(res); 280 } 281 self.decline(res) 282 } 283 on_response(&mut self, res: &Response) -> Result<()>284 fn on_response(&mut self, res: &Response) -> Result<()> { 285 if let Some(res_ext) = res.extensions()? 286 .iter() 287 .find(|&&ext| ext.contains("permessage-deflate")) 288 { 289 let mut name = false; 290 let mut s_takeover = false; 291 let mut c_takeover = false; 292 let mut s_max = false; 293 let mut c_max = false; 294 295 for param in res_ext.split(';') { 296 match param.trim() { 297 "permessage-deflate" => { 298 if name { 299 return Err(Error::new( 300 Kind::Protocol, 301 format!("Duplicate extension name permessage-deflate"), 302 )); 303 } else { 304 name = true; 305 } 306 } 307 "server_no_context_takeover" => { 308 if s_takeover { 309 return Err(Error::new( 310 Kind::Protocol, 311 format!("Duplicate extension parameter server_no_context_takeover"), 312 )); 313 } else { 314 s_takeover = true; 315 self.decompress_reset = true; 316 } 317 } 318 "client_no_context_takeover" => { 319 if c_takeover { 320 return Err(Error::new( 321 Kind::Protocol, 322 format!("Duplicate extension parameter client_no_context_takeover"), 323 )); 324 } else { 325 c_takeover = true; 326 if self.settings.accept_no_context_takeover { 327 self.compress_reset = true; 328 } else { 329 return Err(Error::new( 330 Kind::Protocol, 331 format!("The client requires context takeover."), 332 )); 333 } 334 } 335 } 336 param if param.starts_with("server_max_window_bits") => { 337 if s_max { 338 return Err(Error::new( 339 Kind::Protocol, 340 format!("Duplicate extension parameter server_max_window_bits"), 341 )); 342 } else { 343 s_max = true; 344 let mut param_iter = param.split('='); 345 param_iter.next(); // we already know the name 346 if let Some(window_bits_str) = param_iter.next() { 347 if let Ok(window_bits) = window_bits_str.trim().parse() { 348 if window_bits >= 9 && window_bits <= 15 { 349 if window_bits as u8 != self.settings.max_window_bits { 350 self.dec = Decompressor::new(window_bits); 351 } 352 } else { 353 return Err(Error::new( 354 Kind::Protocol, 355 format!( 356 "Invalid server_max_window_bits parameter: {}", 357 window_bits 358 ), 359 )); 360 } 361 } else { 362 return Err(Error::new( 363 Kind::Protocol, 364 format!( 365 "Invalid server_max_window_bits parameter: {}", 366 window_bits_str 367 ), 368 )); 369 } 370 } 371 } 372 } 373 param if param.starts_with("client_max_window_bits") => { 374 if c_max { 375 return Err(Error::new( 376 Kind::Protocol, 377 format!("Duplicate extension parameter client_max_window_bits"), 378 )); 379 } else { 380 c_max = true; 381 let mut param_iter = param.split('='); 382 param_iter.next(); // we already know the name 383 if let Some(window_bits_str) = param_iter.next() { 384 if let Ok(window_bits) = window_bits_str.trim().parse() { 385 if window_bits >= 9 && window_bits <= 15 { 386 if window_bits as u8 != self.settings.max_window_bits { 387 self.com = Compressor::new(window_bits); 388 } 389 } else { 390 return Err(Error::new( 391 Kind::Protocol, 392 format!( 393 "Invalid client_max_window_bits parameter: {}", 394 window_bits 395 ), 396 )); 397 } 398 } else { 399 return Err(Error::new( 400 Kind::Protocol, 401 format!( 402 "Invalid client_max_window_bits parameter: {}", 403 window_bits_str 404 ), 405 )); 406 } 407 } 408 } 409 } 410 param => { 411 // fail the connection because we got a bad parameter 412 return Err(Error::new( 413 Kind::Protocol, 414 format!("Bad extension parameter: {}", param), 415 )); 416 } 417 } 418 } 419 } else { 420 self.pass = true 421 } 422 423 Ok(()) 424 } 425 on_frame(&mut self, mut frame: Frame) -> Result<Option<Frame>>426 fn on_frame(&mut self, mut frame: Frame) -> Result<Option<Frame>> { 427 if !self.pass && !frame.is_control() { 428 if !self.fragments.is_empty() || frame.has_rsv1() { 429 frame.set_rsv1(false); 430 431 if !frame.is_final() { 432 self.fragments.push(frame); 433 return Ok(None); 434 } else { 435 if frame.opcode() == OpCode::Continue { 436 if self.fragments.is_empty() { 437 return Err(Error::new( 438 Kind::Protocol, 439 "Unable to reconstruct fragmented message. No first frame.", 440 )); 441 } else { 442 if !self.settings.fragments_grow 443 && self.settings.fragments_capacity == self.fragments.len() 444 { 445 return Err(Error::new(Kind::Capacity, "Exceeded max fragments.")); 446 } else { 447 self.fragments.push(frame); 448 } 449 450 // it's safe to unwrap because of the above check for empty 451 let opcode = self.fragments.first().unwrap().opcode(); 452 let size = self.fragments 453 .iter() 454 .fold(0, |len, frame| len + frame.payload().len()); 455 let mut compressed = Vec::with_capacity(size); 456 let mut decompressed = Vec::with_capacity(size * 2); 457 for frag in replace( 458 &mut self.fragments, 459 Vec::with_capacity(self.settings.fragments_capacity), 460 ) { 461 compressed.extend(frag.into_data()) 462 } 463 464 compressed.extend(&[0, 0, 255, 255]); 465 self.dec.decompress(&compressed, &mut decompressed)?; 466 frame = Frame::message(decompressed, opcode, true); 467 } 468 } else { 469 let mut decompressed = Vec::with_capacity(frame.payload().len() * 2); 470 frame.payload_mut().extend(&[0, 0, 255, 255]); 471 472 self.dec.decompress(frame.payload(), &mut decompressed)?; 473 474 *frame.payload_mut() = decompressed; 475 } 476 477 if self.decompress_reset { 478 self.dec.reset()? 479 } 480 } 481 } 482 } 483 self.inner.on_frame(frame) 484 } 485 on_send_frame(&mut self, frame: Frame) -> Result<Option<Frame>>486 fn on_send_frame(&mut self, frame: Frame) -> Result<Option<Frame>> { 487 if let Some(mut frame) = self.inner.on_send_frame(frame)? { 488 if !self.pass && !frame.is_control() { 489 debug_assert!( 490 frame.is_final(), 491 "Received non-final frame from upstream handler!" 492 ); 493 debug_assert!( 494 frame.opcode() != OpCode::Continue, 495 "Received continue frame from upstream handler!" 496 ); 497 498 frame.set_rsv1(true); 499 let mut compressed = Vec::with_capacity(frame.payload().len()); 500 self.com.compress(frame.payload(), &mut compressed)?; 501 let len = compressed.len(); 502 compressed.truncate(len - 4); 503 *frame.payload_mut() = compressed; 504 505 if self.compress_reset { 506 self.com.reset()? 507 } 508 } 509 Ok(Some(frame)) 510 } else { 511 Ok(None) 512 } 513 } 514 515 #[inline] on_shutdown(&mut self)516 fn on_shutdown(&mut self) { 517 self.inner.on_shutdown() 518 } 519 520 #[inline] on_open(&mut self, shake: Handshake) -> Result<()>521 fn on_open(&mut self, shake: Handshake) -> Result<()> { 522 self.inner.on_open(shake) 523 } 524 525 #[inline] on_message(&mut self, msg: Message) -> Result<()>526 fn on_message(&mut self, msg: Message) -> Result<()> { 527 self.inner.on_message(msg) 528 } 529 530 #[inline] on_close(&mut self, code: CloseCode, reason: &str)531 fn on_close(&mut self, code: CloseCode, reason: &str) { 532 self.inner.on_close(code, reason) 533 } 534 535 #[inline] on_error(&mut self, err: Error)536 fn on_error(&mut self, err: Error) { 537 self.inner.on_error(err) 538 } 539 540 #[inline] on_timeout(&mut self, event: Token) -> Result<()>541 fn on_timeout(&mut self, event: Token) -> Result<()> { 542 self.inner.on_timeout(event) 543 } 544 545 #[inline] on_new_timeout(&mut self, tok: Token, timeout: Timeout) -> Result<()>546 fn on_new_timeout(&mut self, tok: Token, timeout: Timeout) -> Result<()> { 547 self.inner.on_new_timeout(tok, timeout) 548 } 549 550 #[inline] 551 #[cfg(any(feature = "ssl", feature = "nativetls"))] upgrade_ssl_client( &mut self, stream: TcpStream, url: &url::Url, ) -> Result<SslStream<TcpStream>>552 fn upgrade_ssl_client( 553 &mut self, 554 stream: TcpStream, 555 url: &url::Url, 556 ) -> Result<SslStream<TcpStream>> { 557 self.inner.upgrade_ssl_client(stream, url) 558 } 559 560 #[inline] 561 #[cfg(any(feature = "ssl", feature = "nativetls"))] upgrade_ssl_server(&mut self, stream: TcpStream) -> Result<SslStream<TcpStream>>562 fn upgrade_ssl_server(&mut self, stream: TcpStream) -> Result<SslStream<TcpStream>> { 563 self.inner.upgrade_ssl_server(stream) 564 } 565 } 566