1 //! Client Connection Pooling
2 use std::borrow::ToOwned;
3 use std::collections::HashMap;
4 use std::fmt;
5 use std::io::{self, Read, Write};
6 use std::net::{SocketAddr, Shutdown};
7 use std::sync::{Arc, Mutex};
8 use std::sync::atomic::{AtomicBool, Ordering};
9 
10 use std::time::{Duration, Instant};
11 
12 use net::{NetworkConnector, NetworkStream, DefaultConnector};
13 use client::scheme::Scheme;
14 
15 use self::stale::{StaleCheck, Stale};
16 
17 /// The `NetworkConnector` that behaves as a connection pool used by hyper's `Client`.
18 pub struct Pool<C: NetworkConnector> {
19     connector: C,
20     inner: Arc<Mutex<PoolImpl<<C as NetworkConnector>::Stream>>>,
21     stale_check: Option<StaleCallback<C::Stream>>,
22 }
23 
/// Configuration options for a `Pool`.
#[derive(Debug)]
pub struct Config {
    /// Upper bound on idle connections kept *per host*.
    pub max_idle: usize,
}

impl Default for Config {
    /// The default configuration keeps at most 5 idle connections per host.
    #[inline]
    fn default() -> Config {
        let max_idle = 5;
        Config { max_idle: max_idle }
    }
}
39 
// Every field of the public `Config` is `pub`, so adding a field there would
// be a breaking change. Newer settings therefore live in this private struct.
#[derive(Debug)]
struct Config2 {
    // How long a connection may stay idle in the pool; `None` means no limit.
    idle_timeout: Option<Duration>,
    // Maximum idle connections kept per key (taken from `Config::max_idle`).
    max_idle: usize,
}
47 
48 
49 #[derive(Debug)]
50 struct PoolImpl<S> {
51     conns: HashMap<Key, Vec<PooledStreamInner<S>>>,
52     config: Config2,
53 }
54 
55 type Key = (String, u16, Scheme);
56 
key<T: Into<Scheme>>(host: &str, port: u16, scheme: T) -> Key57 fn key<T: Into<Scheme>>(host: &str, port: u16, scheme: T) -> Key {
58     (host.to_owned(), port, scheme.into())
59 }
60 
61 impl Pool<DefaultConnector> {
62     /// Creates a `Pool` with a `DefaultConnector`.
63     #[inline]
new(config: Config) -> Pool<DefaultConnector>64     pub fn new(config: Config) -> Pool<DefaultConnector> {
65         Pool::with_connector(config, DefaultConnector::default())
66     }
67 }
68 
69 impl<C: NetworkConnector> Pool<C> {
70     /// Creates a `Pool` with a specified `NetworkConnector`.
71     #[inline]
with_connector(config: Config, connector: C) -> Pool<C>72     pub fn with_connector(config: Config, connector: C) -> Pool<C> {
73         Pool {
74             connector: connector,
75             inner: Arc::new(Mutex::new(PoolImpl {
76                 conns: HashMap::new(),
77                 config: Config2 {
78                     idle_timeout: None,
79                     max_idle: config.max_idle,
80                 }
81             })),
82             stale_check: None,
83         }
84     }
85 
86     /// Set a duration for how long an idle connection is still valid.
set_idle_timeout(&mut self, timeout: Option<Duration>)87     pub fn set_idle_timeout(&mut self, timeout: Option<Duration>) {
88         self.inner.lock().unwrap().config.idle_timeout = timeout;
89     }
90 
set_stale_check<F>(&mut self, callback: F) where F: Fn(StaleCheck<C::Stream>) -> Stale + Send + Sync + 'static91     pub fn set_stale_check<F>(&mut self, callback: F)
92     where F: Fn(StaleCheck<C::Stream>) -> Stale + Send + Sync + 'static {
93         self.stale_check = Some(Box::new(callback));
94     }
95 
96     /// Clear all idle connections from the Pool, closing them.
97     #[inline]
clear_idle(&mut self)98     pub fn clear_idle(&mut self) {
99         self.inner.lock().unwrap().conns.clear();
100     }
101 
102     // private
103 
checkout(&self, key: &Key) -> Option<PooledStreamInner<C::Stream>>104     fn checkout(&self, key: &Key) -> Option<PooledStreamInner<C::Stream>> {
105         while let Some(mut inner) = self.lookup(key) {
106             if let Some(ref stale_check) = self.stale_check {
107                 let dur = inner.idle.expect("idle is never missing inside pool").elapsed();
108                 let arg = stale::check(&mut inner.stream, dur);
109                 if stale_check(arg).is_stale() {
110                     trace!("ejecting stale connection");
111                     continue;
112                 }
113             }
114             return Some(inner);
115         }
116         None
117     }
118 
119 
lookup(&self, key: &Key) -> Option<PooledStreamInner<C::Stream>>120     fn lookup(&self, key: &Key) -> Option<PooledStreamInner<C::Stream>> {
121         let mut locked = self.inner.lock().unwrap();
122         let mut should_remove = false;
123         let deadline = locked.config.idle_timeout.map(|dur| Instant::now() - dur);
124         let inner = locked.conns.get_mut(key).and_then(|vec| {
125             while let Some(inner) = vec.pop() {
126                 should_remove = vec.is_empty();
127                 if let Some(deadline) = deadline {
128                     if inner.idle.expect("idle is never missing inside pool") < deadline {
129                         trace!("ejecting expired connection");
130                         continue;
131                     }
132                 }
133                 return Some(inner);
134             }
135             None
136         });
137         if should_remove {
138             locked.conns.remove(key);
139         }
140         inner
141     }
142 }
143 
144 impl<S> PoolImpl<S> {
reuse(&mut self, key: Key, conn: PooledStreamInner<S>)145     fn reuse(&mut self, key: Key, conn: PooledStreamInner<S>) {
146         trace!("reuse {:?}", key);
147         let conns = self.conns.entry(key).or_insert(vec![]);
148         if conns.len() < self.config.max_idle {
149             conns.push(conn);
150         }
151     }
152 }
153 
154 impl<C: NetworkConnector<Stream=S>, S: NetworkStream + Send> NetworkConnector for Pool<C> {
155     type Stream = PooledStream<S>;
connect(&self, host: &str, port: u16, scheme: &str) -> ::Result<PooledStream<S>>156     fn connect(&self, host: &str, port: u16, scheme: &str) -> ::Result<PooledStream<S>> {
157         let key = key(host, port, scheme);
158         let inner = match self.checkout(&key) {
159             Some(inner) => {
160                 trace!("Pool had connection, using");
161                 inner
162             },
163             None => PooledStreamInner {
164                 key: key.clone(),
165                 idle: None,
166                 stream: try!(self.connector.connect(host, port, scheme)),
167                 previous_response_expected_no_content: false,
168             }
169 
170         };
171         Ok(PooledStream {
172             has_read: false,
173             inner: Some(inner),
174             is_closed: AtomicBool::new(false),
175             pool: self.inner.clone(),
176         })
177     }
178 }
179 
180 type StaleCallback<S> = Box<Fn(StaleCheck<S>) -> Stale + Send + Sync + 'static>;
181 
182 // private on purpose
183 //
184 // Yes, I know! Shame on me! This hurts docs! And it means it only
185 // works with closures! I know!
186 //
// The thing is, this is experimental. I'm not certain about the naming.
188 // Or other things. So I don't really want it in the docs, yet.
189 //
190 // As for only working with closures, that's fine. A closure is probably
191 // enough, and if it isn't, well you can grab the stream and duration and
192 // pass those to a function, and then figure out whether to call stale()
193 // or fresh() based on the return value.
194 //
195 // Point is, it's not that bad. And it's not ready to publicize.
mod stale {
    use std::time::Duration;

    /// Argument given to a stale-check callback: a borrowed stream and the
    /// duration it has been idle.
    pub struct StaleCheck<'a, S: 'a> {
        stream: &'a mut S,
        duration: Duration,
    }

    /// Bundles `stream` with its idle duration `dur` for a stale check.
    #[inline]
    pub fn check<'a, S: 'a>(stream: &'a mut S, dur: Duration) -> StaleCheck<'a, S> {
        StaleCheck {
            duration: dur,
            stream: stream,
        }
    }

    impl<'a, S: 'a> StaleCheck<'a, S> {
        /// Mutable access to the stream under inspection.
        pub fn stream(&mut self) -> &mut S {
            &mut *self.stream
        }

        /// The idle duration this check was created with.
        pub fn idle_duration(&self) -> Duration {
            self.duration
        }

        /// Verdict: the connection should be discarded.
        pub fn stale(self) -> Stale {
            Stale(true)
        }

        /// Verdict: the connection may be reused.
        pub fn fresh(self) -> Stale {
            Stale(false)
        }
    }

    /// Verdict returned by a stale-check callback.
    pub struct Stale(bool);

    impl Stale {
        /// `true` when the callback judged the connection stale.
        #[inline]
        pub fn is_stale(self) -> bool {
            let Stale(flag) = self;
            flag
        }
    }
}
240 
241 
242 /// A Stream that will try to be returned to the Pool when dropped.
243 pub struct PooledStream<S> {
244     has_read: bool,
245     inner: Option<PooledStreamInner<S>>,
246     // mutated in &self methods
247     is_closed: AtomicBool,
248     pool: Arc<Mutex<PoolImpl<S>>>,
249 }
250 
251 // manual impl to add the 'static bound for 1.7 compat
252 impl<S> fmt::Debug for PooledStream<S> where S: fmt::Debug + 'static {
fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result253     fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
254         fmt.debug_struct("PooledStream")
255            .field("inner", &self.inner)
256            .field("has_read", &self.has_read)
257            .field("is_closed", &self.is_closed.load(Ordering::Relaxed))
258            .field("pool", &self.pool)
259            .finish()
260     }
261 }
262 
263 impl<S: NetworkStream> PooledStream<S> {
264     /// Take the wrapped stream out of the pool completely.
into_inner(mut self) -> S265     pub fn into_inner(mut self) -> S {
266         self.inner.take().expect("PooledStream lost its inner stream").stream
267     }
268 
269     /// Gets a borrowed reference to the underlying stream.
get_ref(&self) -> &S270     pub fn get_ref(&self) -> &S {
271         &self.inner.as_ref().expect("PooledStream lost its inner stream").stream
272     }
273 
274     #[cfg(test)]
get_mut(&mut self) -> &mut S275     fn get_mut(&mut self) -> &mut S {
276         &mut self.inner.as_mut().expect("PooledStream lost its inner stream").stream
277     }
278 }
279 
280 #[derive(Debug)]
281 struct PooledStreamInner<S> {
282     key: Key,
283     idle: Option<Instant>,
284     stream: S,
285     previous_response_expected_no_content: bool,
286 }
287 
288 impl<S: NetworkStream> Read for PooledStream<S> {
289     #[inline]
read(&mut self, buf: &mut [u8]) -> io::Result<usize>290     fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
291         let inner = self.inner.as_mut().unwrap();
292         let n = try!(inner.stream.read(buf));
293         if n == 0 {
294             // if the wrapped stream returns EOF (Ok(0)), that means the
295             // server has closed the stream. we must be sure this stream
296             // is dropped and not put back into the pool.
297             self.is_closed.store(true, Ordering::Relaxed);
298 
299             // if the stream has never read bytes before, then the pooled
300             // stream may have been disconnected by the server while
301             // we checked it back out
302             if !self.has_read && inner.idle.is_some() {
303                 // idle being some means this is a reused stream
304                 Err(io::Error::new(
305                     io::ErrorKind::ConnectionAborted,
306                     "Pooled stream disconnected"
307                 ))
308             } else {
309                 Ok(0)
310             }
311         } else {
312             self.has_read = true;
313             Ok(n)
314         }
315     }
316 }
317 
318 impl<S: NetworkStream> Write for PooledStream<S> {
319     #[inline]
write(&mut self, buf: &[u8]) -> io::Result<usize>320     fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
321         self.inner.as_mut().unwrap().stream.write(buf)
322     }
323 
324     #[inline]
flush(&mut self) -> io::Result<()>325     fn flush(&mut self) -> io::Result<()> {
326         self.inner.as_mut().unwrap().stream.flush()
327     }
328 }
329 
330 impl<S: NetworkStream> NetworkStream for PooledStream<S> {
331     #[inline]
peer_addr(&mut self) -> io::Result<SocketAddr>332     fn peer_addr(&mut self) -> io::Result<SocketAddr> {
333         self.inner.as_mut().unwrap().stream.peer_addr()
334             .map_err(|e| {
335                 self.is_closed.store(true, Ordering::Relaxed);
336                 e
337             })
338     }
339 
340     #[inline]
set_read_timeout(&self, dur: Option<Duration>) -> io::Result<()>341     fn set_read_timeout(&self, dur: Option<Duration>) -> io::Result<()> {
342         self.inner.as_ref().unwrap().stream.set_read_timeout(dur)
343             .map_err(|e| {
344                 self.is_closed.store(true, Ordering::Relaxed);
345                 e
346             })
347     }
348 
349     #[inline]
set_write_timeout(&self, dur: Option<Duration>) -> io::Result<()>350     fn set_write_timeout(&self, dur: Option<Duration>) -> io::Result<()> {
351         self.inner.as_ref().unwrap().stream.set_write_timeout(dur)
352             .map_err(|e| {
353                 self.is_closed.store(true, Ordering::Relaxed);
354                 e
355             })
356     }
357 
358     #[inline]
close(&mut self, how: Shutdown) -> io::Result<()>359     fn close(&mut self, how: Shutdown) -> io::Result<()> {
360         self.is_closed.store(true, Ordering::Relaxed);
361         self.inner.as_mut().unwrap().stream.close(how)
362     }
363 
364     #[inline]
set_previous_response_expected_no_content(&mut self, expected: bool)365     fn set_previous_response_expected_no_content(&mut self, expected: bool) {
366         trace!("set_previous_response_expected_no_content {}", expected);
367         self.inner.as_mut().unwrap().previous_response_expected_no_content = expected;
368     }
369 
370     #[inline]
previous_response_expected_no_content(&self) -> bool371     fn previous_response_expected_no_content(&self) -> bool {
372         let answer = self.inner.as_ref().unwrap().previous_response_expected_no_content;
373         trace!("previous_response_expected_no_content {}", answer);
374         answer
375     }
376 }
377 
378 impl<S> Drop for PooledStream<S> {
drop(&mut self)379     fn drop(&mut self) {
380         let is_closed = self.is_closed.load(Ordering::Relaxed);
381         trace!("PooledStream.drop, is_closed={}", is_closed);
382         if !is_closed {
383             self.inner.take().map(|mut inner| {
384                 let now = Instant::now();
385                 inner.idle = Some(now);
386                 if let Ok(mut pool) = self.pool.lock() {
387                     pool.reuse(inner.key.clone(), inner);
388                 }
389                 // else poisoned, give up
390             });
391         }
392     }
393 }
394 
#[cfg(test)]
mod tests {
    use std::io::{ErrorKind, Read};
    use std::net::Shutdown;
    use std::time::Duration;
    use mock::MockConnector;
    use net::{NetworkConnector, NetworkStream};

    use super::{Pool, Key, key};

    // Address every test connects to.
    const HOST: &'static str = "127.0.0.1";
    const PORT: u16 = 3000;

    // A pool with the default `Config` over the mock connector.
    macro_rules! mocked {
        () => ({
            Pool::with_connector(Default::default(), MockConnector)
        })
    }

    // Number of distinct keys that currently have a bucket in the pool.
    fn key_count(pool: &Pool<MockConnector>) -> usize {
        pool.inner.lock().unwrap().conns.len()
    }

    // Number of idle connections under `k`; panics if `k` has no bucket.
    fn idle_for(pool: &Pool<MockConnector>, k: &Key) -> usize {
        pool.inner.lock().unwrap().conns.get(k).unwrap().len()
    }

    #[test]
    fn test_connect_and_drop() {
        let mut pool = mocked!();
        pool.set_idle_timeout(Some(Duration::from_millis(100)));
        let k = key(HOST, PORT, "http");

        let mut first = pool.connect(HOST, PORT, "http").unwrap();
        assert_eq!(first.get_ref().id, 0);
        first.get_mut().id = 9;
        drop(first);
        assert_eq!(key_count(&pool), 1);
        assert_eq!(idle_for(&pool, &k), 1);

        // The connection tagged above comes back out of the pool.
        let second = pool.connect(HOST, PORT, "http").unwrap();
        assert_eq!(second.get_ref().id, 9);
        drop(second);
        assert_eq!(key_count(&pool), 1);
        assert_eq!(idle_for(&pool, &k), 1);
    }

    #[test]
    fn test_double_connect_reuse() {
        let mut pool = mocked!();
        pool.set_idle_timeout(Some(Duration::from_millis(100)));
        let k = key(HOST, PORT, "http");

        let a = pool.connect(HOST, PORT, "http").unwrap();
        let b = pool.connect(HOST, PORT, "http").unwrap();
        drop(a);
        drop(b);

        // One of the two idle connections is checked out; the other stays.
        let held = pool.connect(HOST, PORT, "http").unwrap();
        assert_eq!(key_count(&pool), 1);
        assert_eq!(idle_for(&pool, &k), 1);
        drop(held);
    }

    #[test]
    fn test_closed() {
        let pool = mocked!();
        let mut conn = pool.connect(HOST, PORT, "http").unwrap();
        conn.close(Shutdown::Both).unwrap();
        drop(conn);
        assert_eq!(key_count(&pool), 0);
    }

    #[test]
    fn test_eof_closes() {
        let pool = mocked!();

        let mut conn = pool.connect(HOST, PORT, "http").unwrap();
        let mut byte = [0u8];
        assert_eq!(conn.read(&mut byte).unwrap(), 0);
        drop(conn);
        assert_eq!(key_count(&pool), 0);
    }

    #[test]
    fn test_read_conn_aborted() {
        let pool = mocked!();

        // Open and immediately drop a connection so an idle one is pooled.
        drop(pool.connect(HOST, PORT, "http").unwrap());
        let mut conn = pool.connect(HOST, PORT, "http").unwrap();
        let mut byte = [0u8];
        let err = conn.read(&mut byte).unwrap_err();
        assert_eq!(err.kind(), ErrorKind::ConnectionAborted);
        drop(conn);
        assert_eq!(key_count(&pool), 0);
    }

    #[test]
    fn test_idle_timeout() {
        let mut pool = mocked!();
        pool.set_idle_timeout(Some(Duration::from_millis(10)));

        let mut conn = pool.connect(HOST, PORT, "http").unwrap();
        assert_eq!(conn.get_ref().id, 0);
        conn.get_mut().id = 1337;
        drop(conn);
        ::std::thread::sleep(Duration::from_millis(100));

        // The pooled connection expired, so a new mock stream (id 0) is used.
        let conn = pool.connect(HOST, PORT, "http").unwrap();
        assert_eq!(conn.get_ref().id, 0);
    }
}
500