1 //! Provides an implementation of a WebSocket server
2 use crate::server::upgrade::sync::{Buffer, IntoWs, Upgrade};
3 pub use crate::server::upgrade::{HyperIntoWsError, Request};
4 use crate::server::{InvalidConnection, NoTlsAcceptor, OptionalTlsAcceptor, WsServer};
5 #[cfg(feature = "sync-ssl")]
6 use native_tls::{TlsAcceptor, TlsStream};
7 use std::convert::Into;
8 use std::io;
9 use std::net::{SocketAddr, TcpListener, TcpStream, ToSocketAddrs};
10 
11 #[cfg(feature = "async")]
12 use crate::server::r#async;
13 #[cfg(feature = "async")]
14 use tokio_reactor::Handle;
15 #[cfg(feature = "async")]
16 use tokio_tcp::TcpListener as AsyncTcpListener;
17 
18 /// Either the stream was established and it sent a websocket handshake
19 /// which represents the `Ok` variant, or there was an error (this is the
20 /// `Err` variant).
21 pub type AcceptResult<S> = Result<Upgrade<S>, InvalidConnection<S, Buffer>>;
22 
23 /// Represents a WebSocket server which can work with either normal
24 /// (non-secure) connections, or secure WebSocket connections.
25 ///
26 /// This is a convenient way to implement WebSocket servers, however
27 /// it is possible to use any sendable Reader and Writer to obtain
28 /// a WebSocketClient, so if needed, an alternative server implementation can be used.
29 pub type Server<S> = WsServer<S, TcpListener>;
30 
31 /// Synchronous methods for creating a server and accepting incoming connections.
32 impl<S> WsServer<S, TcpListener>
33 where
34 	S: OptionalTlsAcceptor,
35 {
36 	/// Get the socket address of this server
local_addr(&self) -> io::Result<SocketAddr>37 	pub fn local_addr(&self) -> io::Result<SocketAddr> {
38 		self.listener.local_addr()
39 	}
40 
41 	/// Changes whether the Server is in nonblocking mode.
42 	/// NOTE: It is strongly encouraged to use the `websocket::async` module instead
43 	/// of this. It provides high level APIs for creating asynchronous servers.
44 	///
45 	/// If it is in nonblocking mode, accept() will return an error instead of
46 	/// blocking when there are no incoming connections.
47 	///
48 	///```no_run
49 	/// # extern crate websocket;
50 	/// # use websocket::sync::Server;
51 	/// # fn main() {
52 	/// // Suppose we have to work in a single thread, but want to
53 	/// // accomplish two unrelated things:
54 	/// // (1) Once in a while we want to check if anybody tried to connect to
55 	/// // our websocket server, and if so, handle the TcpStream.
56 	/// // (2) In between we need to something else, possibly unrelated to networking.
57 	///
58 	/// let mut server = Server::bind("127.0.0.1:0").unwrap();
59 	///
60 	/// // Set the server to non-blocking.
61 	/// server.set_nonblocking(true);
62 	///
63 	/// for i in 1..3 {
64 	/// 	let result = match server.accept() {
65 	/// 		Ok(wsupgrade) => {
66 	/// 			// Do something with the established TcpStream.
67 	/// 		}
68 	/// 		_ => {
69 	/// 			// Nobody tried to connect, move on.
70 	/// 		}
71 	/// 	};
72 	/// 	// Perform another task. Because we have a non-blocking server,
73 	/// 	// this will execute independent of whether someone tried to
74 	/// 	// establish a connection.
75 	/// 	let two = 1+1;
76 	/// }
77 	/// # }
78 	///```
set_nonblocking(&self, nonblocking: bool) -> io::Result<()>79 	pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> {
80 		self.listener.set_nonblocking(nonblocking)
81 	}
82 
83 	/// Turns an existing synchronous server into an asynchronous one.
84 	/// This will only work if the stream used for this server `S` already implements
85 	/// `AsyncRead + AsyncWrite`. Useful if you would like some blocking things to happen
86 	/// at the start of your server.
87 	#[cfg(feature = "async")]
into_async(self, handle: &Handle) -> io::Result<r#async::Server<S>>88 	pub fn into_async(self, handle: &Handle) -> io::Result<r#async::Server<S>> {
89 		Ok(WsServer {
90 			listener: AsyncTcpListener::from_std(self.listener, handle)?,
91 			ssl_acceptor: self.ssl_acceptor,
92 		})
93 	}
94 }
95 
/// Synchronous methods for creating a TLS-secured server and accepting incoming connections.
#[cfg(feature = "sync-ssl")]
impl WsServer<TlsAcceptor, TcpListener> {
	/// Binds a secure server to the given address.
	///
	/// Each accepted connection is run through a TLS handshake with `acceptor`
	/// before its WebSocket handshake request is read.
	///
	/// # Errors
	///
	/// Propagates any error from binding the underlying [`TcpListener`].
	///
	/// # Secure Servers
	/// ```no_run
	/// extern crate websocket;
	/// extern crate native_tls;
	/// # fn main() {
	/// use std::thread;
	/// use std::io::Read;
	/// use std::fs::File;
	/// use websocket::Message;
	/// use websocket::sync::Server;
	/// use native_tls::{Identity, TlsAcceptor};
	///
	/// // In this example we retrieve our keypair and certificate chain from a PKCS #12 archive,
	/// // but they can also come from other sources. See the documentation for
	/// // `native_tls::Identity` for the supported formats.
	/// let mut file = File::open("identity.pfx").unwrap();
	/// let mut pkcs12 = vec![];
	/// file.read_to_end(&mut pkcs12).unwrap();
	/// let pkcs12 = Identity::from_pkcs12(&pkcs12, "hacktheplanet").unwrap();
	///
	/// let acceptor = TlsAcceptor::builder(pkcs12).build().unwrap();
	///
	/// let server = Server::bind_secure("127.0.0.1:1234", acceptor).unwrap();
	///
	/// for connection in server.filter_map(Result::ok) {
	///     // Spawn a new thread for each connection.
	///     thread::spawn(move || {
	///         let mut client = connection.accept().unwrap();
	///
	///         let message = Message::text("Hello, client!");
	///         let _ = client.send_message(&message);
	///
	///         // ...
	///     });
	/// }
	/// # }
	/// ```
	pub fn bind_secure<A>(addr: A, acceptor: TlsAcceptor) -> io::Result<Self>
	where
		A: ToSocketAddrs,
	{
		Ok(Server {
			listener: TcpListener::bind(&addr)?,
			ssl_acceptor: acceptor,
		})
	}

	/// Waits for and accepts an incoming TCP connection, performs the TLS
	/// handshake on it, then reads its WebSocket handshake request and
	/// returns an [`Upgrade`].
	///
	/// # Errors
	///
	/// Returns an [`InvalidConnection`] when:
	/// - accepting the TCP connection fails (no stream attached);
	/// - the TLS handshake fails (no stream attached, because the TCP stream
	///   was consumed by the handshake);
	/// - reading or validating the WebSocket handshake fails (the TLS stream,
	///   any parsed request and the read buffer are handed back).
	pub fn accept(&mut self) -> AcceptResult<TlsStream<TcpStream>> {
		let tcp_stream = match self.listener.accept() {
			Ok((stream, _peer_addr)) => stream,
			Err(e) => {
				return Err(InvalidConnection {
					stream: None,
					parsed: None,
					buffer: None,
					error: HyperIntoWsError::Io(e),
				});
			}
		};

		// The handshake error is wrapped in an `io::Error` (kind `Other`) so it
		// fits `HyperIntoWsError`. The TCP stream was moved into the handshake,
		// so no stream can be attached to the error.
		let tls_stream = match self.ssl_acceptor.accept(tcp_stream) {
			Ok(stream) => stream,
			Err(err) => {
				return Err(InvalidConnection {
					stream: None,
					parsed: None,
					buffer: None,
					error: io::Error::new(io::ErrorKind::Other, err).into(),
				});
			}
		};

		tls_stream
			.into_ws()
			.map_err(|(s, r, b, e)| InvalidConnection {
				stream: Some(s),
				parsed: r,
				buffer: b,
				error: e,
			})
	}
}
185 
/// Iterating over a secure server accepts connections one at a time: each
/// call to `next` performs exactly one `accept`, so it waits for a connection
/// just as `accept` does. The iterator never yields `None`.
#[cfg(feature = "sync-ssl")]
impl Iterator for WsServer<TlsAcceptor, TcpListener> {
	type Item = AcceptResult<TlsStream<TcpStream>>;

	fn next(&mut self) -> Option<Self::Item> {
		let accepted = self.accept();
		Some(accepted)
	}
}
194 
195 impl WsServer<NoTlsAcceptor, TcpListener> {
196 	/// Bind this Server to this socket
197 	///
198 	/// # Non-secure Servers
199 	///
200 	/// ```no_run
201 	/// extern crate websocket;
202 	/// # fn main() {
203 	/// use std::thread;
204 	/// use websocket::Message;
205 	/// use websocket::sync::Server;
206 	///
207 	/// let server = Server::bind("127.0.0.1:1234").unwrap();
208 	///
209 	/// for connection in server.filter_map(Result::ok) {
210 	///     // Spawn a new thread for each connection.
211 	///     thread::spawn(move || {
212 	///		      let mut client = connection.accept().unwrap();
213 	///
214 	///		      let message = Message::text("Hello, client!");
215 	///		      let _ = client.send_message(&message);
216 	///
217 	///		      // ...
218 	///    });
219 	/// }
220 	/// # }
221 	/// ```
bind<A: ToSocketAddrs>(addr: A) -> io::Result<Self>222 	pub fn bind<A: ToSocketAddrs>(addr: A) -> io::Result<Self> {
223 		Ok(Server {
224 			listener: TcpListener::bind(&addr)?,
225 			ssl_acceptor: NoTlsAcceptor,
226 		})
227 	}
228 
229 	/// Wait for and accept an incoming WebSocket connection, returning a WebSocketRequest
accept(&mut self) -> AcceptResult<TcpStream>230 	pub fn accept(&mut self) -> AcceptResult<TcpStream> {
231 		let stream = match self.listener.accept() {
232 			Ok(s) => s.0,
233 			Err(e) => {
234 				return Err(InvalidConnection {
235 					stream: None,
236 					parsed: None,
237 					buffer: None,
238 					error: e.into(),
239 				});
240 			}
241 		};
242 
243 		match stream.into_ws() {
244 			Ok(u) => Ok(u),
245 			Err((s, r, b, e)) => Err(InvalidConnection {
246 				stream: Some(s),
247 				parsed: r,
248 				buffer: b,
249 				error: e,
250 			}),
251 		}
252 	}
253 
254 	/// Create a new independently owned handle to the underlying socket.
try_clone(&self) -> io::Result<Self>255 	pub fn try_clone(&self) -> io::Result<Self> {
256 		let inner = self.listener.try_clone()?;
257 		Ok(Server {
258 			listener: inner,
259 			ssl_acceptor: self.ssl_acceptor.clone(),
260 		})
261 	}
262 }
263 
264 impl Iterator for WsServer<NoTlsAcceptor, TcpListener> {
265 	type Item = AcceptResult<TcpStream>;
266 
next(&mut self) -> Option<<Self as Iterator>::Item>267 	fn next(&mut self) -> Option<<Self as Iterator>::Item> {
268 		Some(self.accept())
269 	}
270 }
271 
#[cfg(test)]
mod tests {
	use super::*;

	/// Tests `set_nonblocking()` for the plain `Server<NoTlsAcceptor>`: with no
	/// pending connections, `accept()` must fail with `WouldBlock` instead of
	/// blocking. Some of this is adapted from
	/// https://doc.rust-lang.org/src/std/net/tcp.rs.html#1413
	#[test]
	fn set_nonblocking() {
		// Test unsecure server
		let mut server = Server::bind("127.0.0.1:0").unwrap();

		// If set_nonblocking() silently has no effect, the accept() call below
		// blocks forever instead of failing, and the test hangs.
		server.set_nonblocking(true).unwrap();

		match server.accept() {
			// Nobody tried to establish a connection, so we expect an error.
			Ok(_) => panic!("expected error"),
			Err(e) => match e.error {
				HyperIntoWsError::Io(ref err) if err.kind() == io::ErrorKind::WouldBlock => {}
				// Actually report the unexpected error (the original format
				// string had a `{}` placeholder with no argument).
				ref other => panic!("unexpected error {:?}", other),
			},
		}
	}
}
300