1 use std::io::{self, Read, Write, Cursor};
2 use std::net::{SocketAddr, Shutdown};
3 use std::time::Duration;
4 use std::cell::Cell;
5 
6 use net::{NetworkStream, NetworkConnector, SslClient};
7 
/// In-memory stand-in for a network stream, used by tests.
///
/// Reads are served from `read`, and when that buffer is fully consumed the
/// next queued buffer from `next_reads` (if any) takes its place. Writes are
/// appended to `write`. The boolean flags let a test force I/O errors or
/// observe that the stream was closed.
#[derive(Clone, Debug)]
pub struct MockStream {
    /// Buffer currently served to `Read::read`, with its read position.
    pub read: Cursor<Vec<u8>>,
    /// Further buffers, in order, that replace `read` once it is exhausted.
    next_reads: Vec<Vec<u8>>,
    /// Every byte successfully written through `Write::write`.
    pub write: Vec<u8>,
    /// Set to `true` by `NetworkStream::close`.
    pub is_closed: bool,
    /// When `true`, `Write::write` returns an `ErrorKind::Other` "mock error".
    pub error_on_write: bool,
    /// When `true`, `Read::read` returns an `ErrorKind::Other` "mock error".
    pub error_on_read: bool,
    /// Last value passed to `set_read_timeout`. It is a `Cell` because that
    /// setter only receives `&self`.
    pub read_timeout: Cell<Option<Duration>>,
    /// Last value passed to `set_write_timeout` (a `Cell` for the same reason).
    pub write_timeout: Cell<Option<Duration>>,
    /// Always 0 from the constructors and not read in this file.
    /// NOTE(review): presumably lets tests tell streams apart — confirm with callers.
    pub id: u64,
}
20 
21 impl PartialEq for MockStream {
eq(&self, other: &MockStream) -> bool22     fn eq(&self, other: &MockStream) -> bool {
23         self.read.get_ref() == other.read.get_ref() && self.write == other.write
24     }
25 }
26 
27 impl MockStream {
new() -> MockStream28     pub fn new() -> MockStream {
29         MockStream::with_input(b"")
30     }
31 
with_input(input: &[u8]) -> MockStream32     pub fn with_input(input: &[u8]) -> MockStream {
33         MockStream::with_responses(vec![input])
34     }
35 
with_responses(mut responses: Vec<&[u8]>) -> MockStream36     pub fn with_responses(mut responses: Vec<&[u8]>) -> MockStream {
37         MockStream {
38             read: Cursor::new(responses.remove(0).to_vec()),
39             next_reads: responses.into_iter().map(|arr| arr.to_vec()).collect(),
40             write: vec![],
41             is_closed: false,
42             error_on_write: false,
43             error_on_read: false,
44             read_timeout: Cell::new(None),
45             write_timeout: Cell::new(None),
46             id: 0,
47         }
48     }
49 }
50 
51 impl Read for MockStream {
read(&mut self, buf: &mut [u8]) -> io::Result<usize>52     fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
53         if self.error_on_read {
54             Err(io::Error::new(io::ErrorKind::Other, "mock error"))
55         } else {
56             match self.read.read(buf) {
57                 Ok(n) => {
58                     if self.read.position() as usize == self.read.get_ref().len() {
59                         if self.next_reads.len() > 0 {
60                             self.read = Cursor::new(self.next_reads.remove(0));
61                         }
62                     }
63                     Ok(n)
64                 },
65                 r => r
66             }
67         }
68     }
69 }
70 
71 impl Write for MockStream {
write(&mut self, msg: &[u8]) -> io::Result<usize>72     fn write(&mut self, msg: &[u8]) -> io::Result<usize> {
73         if self.error_on_write {
74             Err(io::Error::new(io::ErrorKind::Other, "mock error"))
75         } else {
76             Write::write(&mut self.write, msg)
77         }
78     }
79 
flush(&mut self) -> io::Result<()>80     fn flush(&mut self) -> io::Result<()> {
81         Ok(())
82     }
83 }
84 
85 impl NetworkStream for MockStream {
peer_addr(&mut self) -> io::Result<SocketAddr>86     fn peer_addr(&mut self) -> io::Result<SocketAddr> {
87         Ok("127.0.0.1:1337".parse().unwrap())
88     }
89 
set_read_timeout(&self, dur: Option<Duration>) -> io::Result<()>90     fn set_read_timeout(&self, dur: Option<Duration>) -> io::Result<()> {
91         self.read_timeout.set(dur);
92         Ok(())
93     }
94 
set_write_timeout(&self, dur: Option<Duration>) -> io::Result<()>95     fn set_write_timeout(&self, dur: Option<Duration>) -> io::Result<()> {
96         self.write_timeout.set(dur);
97         Ok(())
98     }
99 
close(&mut self, _how: Shutdown) -> io::Result<()>100     fn close(&mut self, _how: Shutdown) -> io::Result<()> {
101         self.is_closed = true;
102         Ok(())
103     }
104 }
105 
106 pub struct MockConnector;
107 
108 impl NetworkConnector for MockConnector {
109     type Stream = MockStream;
110 
connect(&self, _host: &str, _port: u16, _scheme: &str) -> ::Result<MockStream>111     fn connect(&self, _host: &str, _port: u16, _scheme: &str) -> ::Result<MockStream> {
112         Ok(MockStream::new())
113     }
114 }
115 
/// Defines a unit struct `$name` that implements `NetworkConnector` by handing
/// out canned `MockStream`s. Define a new connector this way whenever a test
/// needs to intercept requests.
///
/// Two forms are accepted:
///
/// * `mock_connector!(Name { "scheme://host" => "response" ... })` builds a
///   table of URL to response. Each `connect` looks up `"{scheme}://{host}"`
///   (the port is ignored) and returns a stream whose input is the matching
///   response. Connecting to a URL that is not in the table panics.
/// * `mock_connector!(Name { resp1, resp2, ... })` makes every `connect`,
///   whatever its target, return a stream that serves the responses in order.
// NOTE(review): the first arm's unseparated repetition places an `expr`
// fragment (`$res`) directly before another `expr` fragment (`$url`). Newer
// rustc releases enforce follow-set rules inside repetitions and may reject
// this. If so, a separator or a `tt` fragment for `$res` is needed; changing
// it alters caller syntax, so confirm against the toolchain in use first.
macro_rules! mock_connector (
    ($name:ident {
        $($url:expr => $res:expr)*
    }) => (

        struct $name;

        impl $crate::net::NetworkConnector for $name {
            // Use `$crate` throughout (as the second arm does) so the macro
            // resolves the mock module regardless of where it is expanded.
            type Stream = $crate::mock::MockStream;
            fn connect(&self, host: &str, port: u16, scheme: &str)
                    -> $crate::Result<$crate::mock::MockStream> {
                use std::collections::HashMap;
                debug!("MockStream::connect({:?}, {:?}, {:?})", host, port, scheme);
                let mut map = HashMap::new();
                $(map.insert($url, $res);)*

                // The port is deliberately not part of the lookup key.
                let key = format!("{}://{}", scheme, host);
                match map.get(&*key) {
                    Some(&res) => Ok($crate::mock::MockStream::with_input(res.as_bytes())),
                    None => panic!("{:?} doesn't know url {}", stringify!($name), key)
                }
            }
        }

    );

    ($name:ident { $($response:expr),+ }) => (
        struct $name;

        impl $crate::net::NetworkConnector for $name {
            type Stream = $crate::mock::MockStream;
            fn connect(&self, _: &str, _: u16, _: &str)
                    -> $crate::Result<$crate::mock::MockStream> {
                Ok($crate::mock::MockStream::with_responses(vec![
                    $($response),+
                ]))
            }
        }
    );
);
159 
160 #[derive(Debug, Default)]
161 pub struct MockSsl;
162 
163 impl<T: NetworkStream + Send + Clone> SslClient<T> for MockSsl {
164     type Stream = T;
wrap_client(&self, stream: T, _host: &str) -> ::Result<T>165     fn wrap_client(&self, stream: T, _host: &str) -> ::Result<T> {
166         Ok(stream)
167     }
168 }
169