1 #[cfg(feature = "tls")]
2 use crate::tls::TlsConfigBuilder;
3 use std::convert::Infallible;
4 use std::error::Error as StdError;
5 use std::future::Future;
6 use std::net::SocketAddr;
7 #[cfg(feature = "tls")]
8 use std::path::Path;
9 
10 use futures::{future, FutureExt, TryFuture, TryStream, TryStreamExt};
11 use hyper::server::conn::AddrIncoming;
12 use hyper::service::{make_service_fn, service_fn};
13 use hyper::Server as HyperServer;
14 use tokio::io::{AsyncRead, AsyncWrite};
15 
16 use crate::filter::Filter;
17 use crate::reject::IsReject;
18 use crate::reply::Reply;
19 use crate::transport::Transport;
20 
21 /// Create a `Server` with the provided `Filter`.
serve<F>(filter: F) -> Server<F> where F: Filter + Clone + Send + Sync + 'static, F::Extract: Reply, F::Error: IsReject,22 pub fn serve<F>(filter: F) -> Server<F>
23 where
24     F: Filter + Clone + Send + Sync + 'static,
25     F::Extract: Reply,
26     F::Error: IsReject,
27 {
28     Server {
29         pipeline: false,
30         filter,
31     }
32 }
33 
/// A Warp Server ready to filter requests.
#[derive(Debug)]
pub struct Server<F> {
    // Forwarded to hyper's `http1_pipeline_flush` when the server is built;
    // `false` unless `unstable_pipeline` has been called.
    pipeline: bool,
    // The filter that is turned into the request-handling service.
    filter: F,
}
40 
/// A Warp Server ready to filter requests over TLS.
///
/// *This type requires the `"tls"` feature.*
#[cfg(feature = "tls")]
pub struct TlsServer<F> {
    // The plain server whose filter and pipeline setting are reused.
    server: Server<F>,
    // Accumulates key/certificate settings; `build()` is called at bind time.
    tls: TlsConfigBuilder,
}
49 
// Spelling out the generic bounds needed to make this a reusable function is
// very complicated, so it is written as a macro instead.
//
// Expands to a hyper "make service" value: for every accepted connection it
// clones the wrapped service and records the connection's remote address,
// which is then passed along with each request on that connection.
macro_rules! into_service {
    ($into:expr) => {{
        let shared = crate::service($into);
        make_service_fn(move |conn| {
            let remote_addr = Transport::remote_addr(conn);
            let per_conn = shared.clone();
            future::ok::<_, Infallible>(service_fn(move |request| {
                per_conn.call_with_addr(request, remote_addr)
            }))
        })
    }};
}
64 
// Binds a listener to `$addr` with `set_nodelay(true)` applied.
//
// Evaluates to `(local_addr, listener)`. Uses `?` on the bind result, so it
// must be expanded where `?` is able to return the error.
macro_rules! addr_incoming {
    ($addr:expr) => {{
        let mut listener = AddrIncoming::bind($addr)?;
        listener.set_nodelay(true);
        (listener.local_addr(), listener)
    }};
}
73 
// Shared implementation behind `bind!` and `try_bind!`.
//
// Builds the service from the server's filter, binds a listener to `$addr`
// and hands both to hyper. Evaluates to `Ok((local_addr, hyper_server))`.
//
// `addr_incoming!` (and `build()?` in the `tls:` arm) use `?`, so this macro
// has to be expanded somewhere `?` can return an error; `bind!` and
// `try_bind!` wrap it in a closure for that reason.
macro_rules! bind_inner {
    // Plain HTTP: errors are reported as `hyper::Error`.
    ($this:ident, $addr:expr) => {{
        let service = into_service!($this.filter);
        let (addr, incoming) = addr_incoming!($addr);
        let srv = HyperServer::builder(incoming)
            .http1_pipeline_flush($this.pipeline)
            .serve(service);
        Ok::<_, hyper::Error>((addr, srv))
    }};

    // TLS: `$this` is a `TlsServer`. The listener is bound before the TLS
    // config is built, so a bind failure is reported first. Errors are boxed,
    // presumably because `build()` has its own error type (not visible here).
    (tls: $this:ident, $addr:expr) => {{
        let service = into_service!($this.server.filter);
        let (addr, incoming) = addr_incoming!($addr);
        let tls = $this.tls.build()?;
        let srv = HyperServer::builder(crate::tls::TlsAcceptor::new(tls, incoming))
            .http1_pipeline_flush($this.server.pipeline)
            .serve(service);
        Ok::<_, Box<dyn std::error::Error + Send + Sync>>((addr, srv))
    }};
}
94 
// Converts `$addr` with `.into()`, binds `$this` to it via `bind_inner!`, and
// panics with the address and the error if that fails.
//
// Evaluates to the `(local_addr, hyper_server)` pair from `bind_inner!`.
macro_rules! bind {
    ($this:ident, $addr:expr) => {{
        let socket = $addr.into();
        // The closure gives the `?` inside `bind_inner!` a scope to return from.
        let bound = (|target| bind_inner!($this, target))(&socket);
        bound.unwrap_or_else(|err| panic!("error binding to {}: {}", socket, err))
    }};

    (tls: $this:ident, $addr:expr) => {{
        let socket = $addr.into();
        // The closure gives the `?` inside `bind_inner!` a scope to return from.
        let bound = (|target| bind_inner!(tls: $this, target))(&socket);
        bound.unwrap_or_else(|err| panic!("error binding to {}: {}", socket, err))
    }};
}
110 
// Non-panicking counterpart of `bind!`: evaluates to the `Result` produced by
// `bind_inner!`, leaving error handling to the caller.
//
// `$addr` is handed to `bind_inner!` as-is (no `.into()` conversion).
macro_rules! try_bind {
    ($this:ident, $addr:expr) => {{
        // The closure gives the `?` inside `bind_inner!` a scope to return from.
        let attempt = |target| bind_inner!($this, target);
        attempt($addr)
    }};

    (tls: $this:ident, $addr:expr) => {{
        // The closure gives the `?` inside `bind_inner!` a scope to return from.
        let attempt = |target| bind_inner!(tls: $this, target);
        attempt($addr)
    }};
}
120 
121 // ===== impl Server =====
122 
123 impl<F> Server<F>
124 where
125     F: Filter + Clone + Send + Sync + 'static,
126     <F::Future as TryFuture>::Ok: Reply,
127     <F::Future as TryFuture>::Error: IsReject,
128 {
129     /// Run this `Server` forever on the current thread.
run(self, addr: impl Into<SocketAddr> + 'static)130     pub async fn run(self, addr: impl Into<SocketAddr> + 'static) {
131         let (addr, fut) = self.bind_ephemeral(addr);
132 
133         log::info!("listening on http://{}", addr);
134 
135         fut.await;
136     }
137 
138     /// Run this `Server` forever on the current thread with a specific stream
139     /// of incoming connections.
140     ///
141     /// This can be used for Unix Domain Sockets, or TLS, etc.
run_incoming<I>(self, incoming: I) where I: TryStream + Send, I::Ok: AsyncRead + AsyncWrite + Send + 'static + Unpin, I::Error: Into<Box<dyn StdError + Send + Sync>>,142     pub async fn run_incoming<I>(self, incoming: I)
143     where
144         I: TryStream + Send,
145         I::Ok: AsyncRead + AsyncWrite + Send + 'static + Unpin,
146         I::Error: Into<Box<dyn StdError + Send + Sync>>,
147     {
148         self.run_incoming2(incoming.map_ok(crate::transport::LiftIo).into_stream())
149             .await;
150     }
151 
run_incoming2<I>(self, incoming: I) where I: TryStream + Send, I::Ok: Transport + Send + 'static + Unpin, I::Error: Into<Box<dyn StdError + Send + Sync>>,152     async fn run_incoming2<I>(self, incoming: I)
153     where
154         I: TryStream + Send,
155         I::Ok: Transport + Send + 'static + Unpin,
156         I::Error: Into<Box<dyn StdError + Send + Sync>>,
157     {
158         let fut = self.serve_incoming2(incoming);
159 
160         log::info!("listening with custom incoming");
161 
162         fut.await;
163     }
164 
165     /// Bind to a socket address, returning a `Future` that can be
166     /// executed on any runtime.
167     ///
168     /// # Panics
169     ///
170     /// Panics if we are unable to bind to the provided address.
bind(self, addr: impl Into<SocketAddr> + 'static) -> impl Future<Output = ()> + 'static171     pub fn bind(self, addr: impl Into<SocketAddr> + 'static) -> impl Future<Output = ()> + 'static {
172         let (_, fut) = self.bind_ephemeral(addr);
173         fut
174     }
175 
176     /// Bind to a socket address, returning a `Future` that can be
177     /// executed on any runtime.
178     ///
179     /// In case we are unable to bind to the specified address, resolves to an
180     /// error and logs the reason.
try_bind(self, addr: impl Into<SocketAddr> + 'static)181     pub async fn try_bind(self, addr: impl Into<SocketAddr> + 'static) {
182         let addr = addr.into();
183         let srv = match try_bind!(self, &addr) {
184             Ok((_, srv)) => srv,
185             Err(err) => {
186                 log::error!("error binding to {}: {}", addr, err);
187                 return;
188             }
189         };
190 
191         srv.map(|result| {
192             if let Err(err) = result {
193                 log::error!("server error: {}", err)
194             }
195         })
196         .await;
197     }
198 
199     /// Bind to a possibly ephemeral socket address.
200     ///
201     /// Returns the bound address and a `Future` that can be executed on
202     /// any runtime.
203     ///
204     /// # Panics
205     ///
206     /// Panics if we are unable to bind to the provided address.
bind_ephemeral( self, addr: impl Into<SocketAddr> + 'static, ) -> (SocketAddr, impl Future<Output = ()> + 'static)207     pub fn bind_ephemeral(
208         self,
209         addr: impl Into<SocketAddr> + 'static,
210     ) -> (SocketAddr, impl Future<Output = ()> + 'static) {
211         let (addr, srv) = bind!(self, addr);
212         let srv = srv.map(|result| {
213             if let Err(err) = result {
214                 log::error!("server error: {}", err)
215             }
216         });
217 
218         (addr, srv)
219     }
220 
221     /// Tried to bind a possibly ephemeral socket address.
222     ///
223     /// Returns a `Result` which fails in case we are unable to bind with the
224     /// underlying error.
225     ///
226     /// Returns the bound address and a `Future` that can be executed on
227     /// any runtime.
try_bind_ephemeral( self, addr: impl Into<SocketAddr> + 'static, ) -> Result<(SocketAddr, impl Future<Output = ()> + 'static), crate::Error>228     pub fn try_bind_ephemeral(
229         self,
230         addr: impl Into<SocketAddr> + 'static,
231     ) -> Result<(SocketAddr, impl Future<Output = ()> + 'static), crate::Error> {
232         let addr = addr.into();
233         let (addr, srv) = try_bind!(self, &addr).map_err(crate::Error::new)?;
234         let srv = srv.map(|result| {
235             if let Err(err) = result {
236                 log::error!("server error: {}", err)
237             }
238         });
239 
240         Ok((addr, srv))
241     }
242 
243     /// Create a server with graceful shutdown signal.
244     ///
245     /// When the signal completes, the server will start the graceful shutdown
246     /// process.
247     ///
248     /// Returns the bound address and a `Future` that can be executed on
249     /// any runtime.
250     ///
251     /// # Example
252     ///
253     /// ```no_run
254     /// use warp::Filter;
255     /// use futures::future::TryFutureExt;
256     /// use tokio::sync::oneshot;
257     ///
258     /// # fn main() {
259     /// let routes = warp::any()
260     ///     .map(|| "Hello, World!");
261     ///
262     /// let (tx, rx) = oneshot::channel();
263     ///
264     /// let (addr, server) = warp::serve(routes)
265     ///     .bind_with_graceful_shutdown(([127, 0, 0, 1], 3030), async {
266     ///          rx.await.ok();
267     ///     });
268     ///
269     /// // Spawn the server into a runtime
270     /// tokio::task::spawn(server);
271     ///
272     /// // Later, start the shutdown...
273     /// let _ = tx.send(());
274     /// # }
275     /// ```
bind_with_graceful_shutdown( self, addr: impl Into<SocketAddr> + 'static, signal: impl Future<Output = ()> + Send + 'static, ) -> (SocketAddr, impl Future<Output = ()> + 'static)276     pub fn bind_with_graceful_shutdown(
277         self,
278         addr: impl Into<SocketAddr> + 'static,
279         signal: impl Future<Output = ()> + Send + 'static,
280     ) -> (SocketAddr, impl Future<Output = ()> + 'static) {
281         let (addr, srv) = bind!(self, addr);
282         let fut = srv.with_graceful_shutdown(signal).map(|result| {
283             if let Err(err) = result {
284                 log::error!("server error: {}", err)
285             }
286         });
287         (addr, fut)
288     }
289 
290     /// Setup this `Server` with a specific stream of incoming connections.
291     ///
292     /// This can be used for Unix Domain Sockets, or TLS, etc.
293     ///
294     /// Returns a `Future` that can be executed on any runtime.
serve_incoming<I>(self, incoming: I) -> impl Future<Output = ()> + 'static where I: TryStream + Send + 'static, I::Ok: AsyncRead + AsyncWrite + Send + 'static + Unpin, I::Error: Into<Box<dyn StdError + Send + Sync>>,295     pub fn serve_incoming<I>(self, incoming: I) -> impl Future<Output = ()> + 'static
296     where
297         I: TryStream + Send + 'static,
298         I::Ok: AsyncRead + AsyncWrite + Send + 'static + Unpin,
299         I::Error: Into<Box<dyn StdError + Send + Sync>>,
300     {
301         let incoming = incoming.map_ok(crate::transport::LiftIo);
302         self.serve_incoming2(incoming)
303     }
304 
serve_incoming2<I>(self, incoming: I) where I: TryStream + Send, I::Ok: Transport + Send + 'static + Unpin, I::Error: Into<Box<dyn StdError + Send + Sync>>,305     async fn serve_incoming2<I>(self, incoming: I)
306     where
307         I: TryStream + Send,
308         I::Ok: Transport + Send + 'static + Unpin,
309         I::Error: Into<Box<dyn StdError + Send + Sync>>,
310     {
311         let service = into_service!(self.filter);
312 
313         let srv = HyperServer::builder(hyper::server::accept::from_stream(incoming.into_stream()))
314             .http1_pipeline_flush(self.pipeline)
315             .serve(service)
316             .await;
317 
318         if let Err(err) = srv {
319             log::error!("server error: {}", err);
320         }
321     }
322 
323     // Generally shouldn't be used, as it can slow down non-pipelined responses.
324     //
325     // It's only real use is to make silly pipeline benchmarks look better.
326     #[doc(hidden)]
unstable_pipeline(mut self) -> Self327     pub fn unstable_pipeline(mut self) -> Self {
328         self.pipeline = true;
329         self
330     }
331 
332     /// Configure a server to use TLS.
333     ///
334     /// *This function requires the `"tls"` feature.*
335     #[cfg(feature = "tls")]
tls(self) -> TlsServer<F>336     pub fn tls(self) -> TlsServer<F> {
337         TlsServer {
338             server: self,
339             tls: TlsConfigBuilder::new(),
340         }
341     }
342 }
343 
// ===== impl TlsServer =====
345 
346 #[cfg(feature = "tls")]
347 impl<F> TlsServer<F>
348 where
349     F: Filter + Clone + Send + Sync + 'static,
350     <F::Future as TryFuture>::Ok: Reply,
351     <F::Future as TryFuture>::Error: IsReject,
352 {
353     // TLS config methods
354 
355     /// Specify the file path to read the private key.
key_path(self, path: impl AsRef<Path>) -> Self356     pub fn key_path(self, path: impl AsRef<Path>) -> Self {
357         self.with_tls(|tls| tls.key_path(path))
358     }
359 
360     /// Specify the file path to read the certificate.
cert_path(self, path: impl AsRef<Path>) -> Self361     pub fn cert_path(self, path: impl AsRef<Path>) -> Self {
362         self.with_tls(|tls| tls.cert_path(path))
363     }
364 
365     /// Specify the in-memory contents of the private key.
key(self, key: impl AsRef<[u8]>) -> Self366     pub fn key(self, key: impl AsRef<[u8]>) -> Self {
367         self.with_tls(|tls| tls.key(key.as_ref()))
368     }
369 
370     /// Specify the in-memory contents of the certificate.
cert(self, cert: impl AsRef<[u8]>) -> Self371     pub fn cert(self, cert: impl AsRef<[u8]>) -> Self {
372         self.with_tls(|tls| tls.cert(cert.as_ref()))
373     }
374 
with_tls<Func>(self, func: Func) -> Self where Func: FnOnce(TlsConfigBuilder) -> TlsConfigBuilder,375     fn with_tls<Func>(self, func: Func) -> Self
376     where
377         Func: FnOnce(TlsConfigBuilder) -> TlsConfigBuilder,
378     {
379         let TlsServer { server, tls } = self;
380         let tls = func(tls);
381         TlsServer { server, tls }
382     }
383 
384     // Server run methods
385 
386     /// Run this `TlsServer` forever on the current thread.
387     ///
388     /// *This function requires the `"tls"` feature.*
run(self, addr: impl Into<SocketAddr> + 'static)389     pub async fn run(self, addr: impl Into<SocketAddr> + 'static) {
390         let (addr, fut) = self.bind_ephemeral(addr);
391 
392         log::info!("listening on https://{}", addr);
393 
394         fut.await;
395     }
396 
397     /// Bind to a socket address, returning a `Future` that can be
398     /// executed on a runtime.
399     ///
400     /// *This function requires the `"tls"` feature.*
bind(self, addr: impl Into<SocketAddr> + 'static)401     pub async fn bind(self, addr: impl Into<SocketAddr> + 'static) {
402         let (_, fut) = self.bind_ephemeral(addr);
403         fut.await;
404     }
405 
406     /// Bind to a possibly ephemeral socket address.
407     ///
408     /// Returns the bound address and a `Future` that can be executed on
409     /// any runtime.
410     ///
411     /// *This function requires the `"tls"` feature.*
bind_ephemeral( self, addr: impl Into<SocketAddr> + 'static, ) -> (SocketAddr, impl Future<Output = ()> + 'static)412     pub fn bind_ephemeral(
413         self,
414         addr: impl Into<SocketAddr> + 'static,
415     ) -> (SocketAddr, impl Future<Output = ()> + 'static) {
416         let (addr, srv) = bind!(tls: self, addr);
417         let srv = srv.map(|result| {
418             if let Err(err) = result {
419                 log::error!("server error: {}", err)
420             }
421         });
422 
423         (addr, srv)
424     }
425 
426     /// Create a server with graceful shutdown signal.
427     ///
428     /// When the signal completes, the server will start the graceful shutdown
429     /// process.
430     ///
431     /// *This function requires the `"tls"` feature.*
bind_with_graceful_shutdown( self, addr: impl Into<SocketAddr> + 'static, signal: impl Future<Output = ()> + Send + 'static, ) -> (SocketAddr, impl Future<Output = ()> + 'static)432     pub fn bind_with_graceful_shutdown(
433         self,
434         addr: impl Into<SocketAddr> + 'static,
435         signal: impl Future<Output = ()> + Send + 'static,
436     ) -> (SocketAddr, impl Future<Output = ()> + 'static) {
437         let (addr, srv) = bind!(tls: self, addr);
438 
439         let fut = srv.with_graceful_shutdown(signal).map(|result| {
440             if let Err(err) = result {
441                 log::error!("server error: {}", err)
442             }
443         });
444         (addr, fut)
445     }
446 }
447 
// Hand-written `Debug`: only the wrapped `server` is printed; the `tls`
// builder field is left out of the output.
#[cfg(feature = "tls")]
impl<F> ::std::fmt::Debug for TlsServer<F>
where
    F: ::std::fmt::Debug,
{
    fn fmt(&self, formatter: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
        let mut out = formatter.debug_struct("TlsServer");
        out.field("server", &self.server);
        out.finish()
    }
}
459