1 //! Test utilities to test your filters.
2 //!
3 //! [`Filter`](../trait.Filter.html)s can be easily tested without starting up an HTTP
4 //! server, by making use of the [`RequestBuilder`](./struct.RequestBuilder.html) in this
5 //! module.
6 //!
7 //! # Testing Filters
8 //!
9 //! It's easy to test filters, especially if smaller filters are used to build
10 //! up your full set. Consider these example filters:
11 //!
12 //! ```
13 //! use warp::Filter;
14 //!
15 //! fn sum() -> impl Filter<Extract = (u32,), Error = warp::Rejection> + Copy {
16 //!     warp::path::param()
17 //!         .and(warp::path::param())
18 //!         .map(|x: u32, y: u32| {
19 //!             x + y
20 //!         })
21 //! }
22 //!
23 //! fn math() -> impl Filter<Extract = (String,), Error = warp::Rejection> + Copy {
24 //!     warp::post()
25 //!         .and(sum())
26 //!         .map(|z: u32| {
27 //!             format!("Sum = {}", z)
28 //!         })
29 //! }
30 //! ```
31 //!
32 //! We can test some requests against the `sum` filter like this:
33 //!
34 //! ```
35 //! # use warp::Filter;
//! #[tokio::test]
//! async fn test_sum() {
//! #    let sum = || warp::any().map(|| 3);
//!     let filter = sum();
//!
//!     // Execute `sum` and get the `Extract` back.
//!     let value = warp::test::request()
//!         .path("/1/2")
//!         .filter(&filter)
//!         .await
//!         .unwrap();
//!     assert_eq!(value, 3);
//!
//!     // Or simply test if a request matches (doesn't reject).
//!     assert!(
//!         !warp::test::request()
//!             .path("/1/-5")
//!             .matches(&filter)
//!             .await
//!     );
//! }
55 //! ```
56 //!
57 //! If the filter returns something that implements `Reply`, and thus can be
58 //! turned into a response sent back to the client, we can test what exact
59 //! response is returned. The `math` filter uses the `sum` filter, but returns
60 //! a `String` that can be turned into a response.
61 //!
62 //! ```
63 //! # use warp::Filter;
//! #[tokio::test]
//! async fn test_math() {
//! #    let math = || warp::any().map(warp::reply);
//!     let filter = math();
//!
//!     let res = warp::test::request()
//!         .path("/1/2")
//!         .reply(&filter)
//!         .await;
//!     assert_eq!(res.status(), 405, "GET is not allowed");
//!
//!     let res = warp::test::request()
//!         .method("POST")
//!         .path("/1/2")
//!         .reply(&filter)
//!         .await;
//!     assert_eq!(res.status(), 200);
//!     assert_eq!(res.body(), "Sum = 3");
//! }
81 //! ```
82 use std::convert::TryFrom;
83 use std::error::Error as StdError;
84 use std::fmt;
85 use std::future::Future;
86 use std::net::SocketAddr;
87 #[cfg(feature = "websocket")]
88 use std::pin::Pin;
89 #[cfg(feature = "websocket")]
90 use std::task::{self, Poll};
91 
92 use bytes::Bytes;
93 #[cfg(feature = "websocket")]
94 use futures::StreamExt;
95 use futures::{future, FutureExt, TryFutureExt};
96 use http::{
97     header::{HeaderName, HeaderValue},
98     Response,
99 };
100 use serde::Serialize;
101 use serde_json;
102 #[cfg(feature = "websocket")]
103 use tokio::sync::{mpsc, oneshot};
104 
105 use crate::filter::Filter;
106 use crate::reject::IsReject;
107 use crate::reply::Reply;
108 use crate::route::{self, Route};
109 use crate::Request;
110 
111 use self::inner::OneOrTuple;
112 
113 /// Starts a new test `RequestBuilder`.
request() -> RequestBuilder114 pub fn request() -> RequestBuilder {
115     RequestBuilder {
116         remote_addr: None,
117         req: Request::default(),
118     }
119 }
120 
/// Starts a new test `WsBuilder`.
///
/// The wrapped request starts out exactly like the one from `request()`.
#[cfg(feature = "websocket")]
pub fn ws() -> WsBuilder {
    let req = request();
    WsBuilder { req }
}
126 
/// A request builder for testing filters.
///
/// Created with `request()`. Configure the request with the builder methods
/// (`method`, `path`, `header`, `body`, ...) and then run it against a filter
/// with `filter`, `matches`, or `reply`.
///
/// See [module documentation](crate::test) for an overview.
#[must_use = "RequestBuilder does nothing on its own"]
#[derive(Debug)]
pub struct RequestBuilder {
    // Peer address handed to `Route::new`; `request()` initializes it to `None`.
    remote_addr: Option<SocketAddr>,
    // The HTTP request the filter under test will see.
    req: Request,
}
136 
/// A Websocket builder for testing filters.
///
/// Created with `ws()`. It wraps a `RequestBuilder`; `handshake` adds the
/// websocket upgrade headers before sending the request.
///
/// See [module documentation](crate::test) for an overview.
#[cfg(feature = "websocket")]
#[must_use = "WsBuilder does nothing on its own"]
#[derive(Debug)]
pub struct WsBuilder {
    // The request that will be sent as the upgrade request.
    req: RequestBuilder,
}
146 
/// A test client for Websocket filters.
///
/// Returned by `WsBuilder::handshake`. Messages are exchanged through
/// channels with a background task that owns the actual connection.
#[cfg(feature = "websocket")]
pub struct WsClient {
    // Outgoing messages; the background task forwards them to the server.
    tx: mpsc::UnboundedSender<crate::ws::Message>,
    // Incoming messages (or errors) read from the server by the background task.
    rx: mpsc::UnboundedReceiver<Result<crate::ws::Message, crate::error::Error>>,
}
153 
/// An error from Websocket filter tests.
///
/// Its `Display` output is `websocket error: ` followed by the cause.
#[derive(Debug)]
pub struct WsError {
    // The underlying error (for example a connection error or a plain message).
    cause: Box<dyn StdError + Send + Sync>,
}
159 
160 impl RequestBuilder {
161     /// Sets the method of this builder.
162     ///
163     /// The default if not set is `GET`.
164     ///
165     /// # Example
166     ///
167     /// ```
168     /// let req = warp::test::request()
169     ///     .method("POST");
170     /// ```
171     ///
172     /// # Panic
173     ///
174     /// This panics if the passed string is not able to be parsed as a valid
175     /// `Method`.
method(mut self, method: &str) -> Self176     pub fn method(mut self, method: &str) -> Self {
177         *self.req.method_mut() = method.parse().expect("valid method");
178         self
179     }
180 
181     /// Sets the request path of this builder.
182     ///
183     /// The default is not set is `/`.
184     ///
185     /// # Example
186     ///
187     /// ```
188     /// let req = warp::test::request()
189     ///     .path("/todos/33");
190     /// ```
191     ///
192     /// # Panic
193     ///
194     /// This panics if the passed string is not able to be parsed as a valid
195     /// `Uri`.
path(mut self, p: &str) -> Self196     pub fn path(mut self, p: &str) -> Self {
197         let uri = p.parse().expect("test request path invalid");
198         *self.req.uri_mut() = uri;
199         self
200     }
201 
202     /// Set a header for this request.
203     ///
204     /// # Example
205     ///
206     /// ```
207     /// let req = warp::test::request()
208     ///     .header("accept", "application/json");
209     /// ```
210     ///
211     /// # Panic
212     ///
213     /// This panics if the passed strings are not able to be parsed as a valid
214     /// `HeaderName` and `HeaderValue`.
header<K, V>(mut self, key: K, value: V) -> Self where HeaderName: TryFrom<K>, HeaderValue: TryFrom<V>,215     pub fn header<K, V>(mut self, key: K, value: V) -> Self
216     where
217         HeaderName: TryFrom<K>,
218         HeaderValue: TryFrom<V>,
219     {
220         let name: HeaderName = TryFrom::try_from(key)
221             .map_err(|_| ())
222             .expect("invalid header name");
223         let value = TryFrom::try_from(value)
224             .map_err(|_| ())
225             .expect("invalid header value");
226         self.req.headers_mut().insert(name, value);
227         self
228     }
229 
230     /// Add a type to the request's `http::Extensions`.
extension<T>(mut self, ext: T) -> Self where T: Send + Sync + 'static,231     pub fn extension<T>(mut self, ext: T) -> Self
232     where
233         T: Send + Sync + 'static,
234     {
235         self.req.extensions_mut().insert(ext);
236         self
237     }
238 
239     /// Set the bytes of this request body.
240     ///
241     /// Default is an empty body.
242     ///
243     /// # Example
244     ///
245     /// ```
246     /// let req = warp::test::request()
247     ///     .body("foo=bar&baz=quux");
248     /// ```
body(mut self, body: impl AsRef<[u8]>) -> Self249     pub fn body(mut self, body: impl AsRef<[u8]>) -> Self {
250         let body = body.as_ref().to_vec();
251         let len = body.len();
252         *self.req.body_mut() = body.into();
253         self.header("content-length", len.to_string())
254     }
255 
256     /// Set the bytes of this request body by serializing a value into JSON.
257     ///
258     /// # Example
259     ///
260     /// ```
261     /// let req = warp::test::request()
262     ///     .json(&true);
263     /// ```
json(mut self, val: &impl Serialize) -> Self264     pub fn json(mut self, val: &impl Serialize) -> Self {
265         let vec = serde_json::to_vec(val).expect("json() must serialize to JSON");
266         let len = vec.len();
267         *self.req.body_mut() = vec.into();
268         self.header("content-length", len.to_string())
269             .header("content-type", "application/json")
270     }
271 
272     /// Tries to apply the `Filter` on this request.
273     ///
274     /// # Example
275     ///
276     /// ```no_run
277     /// async {
278     ///     let param = warp::path::param::<u32>();
279     ///
280     ///     let ex = warp::test::request()
281     ///         .path("/41")
282     ///         .filter(&param)
283     ///         .await
284     ///         .unwrap();
285     ///
286     ///     assert_eq!(ex, 41);
287     ///
288     ///     assert!(
289     ///         warp::test::request()
290     ///             .path("/foo")
291     ///             .filter(&param)
292     ///             .await
293     ///             .is_err()
294     ///     );
295     ///};
296     /// ```
filter<F>(self, f: &F) -> Result<<F::Extract as OneOrTuple>::Output, F::Error> where F: Filter, F::Future: Send + 'static, F::Extract: OneOrTuple + Send + 'static, F::Error: Send + 'static,297     pub async fn filter<F>(self, f: &F) -> Result<<F::Extract as OneOrTuple>::Output, F::Error>
298     where
299         F: Filter,
300         F::Future: Send + 'static,
301         F::Extract: OneOrTuple + Send + 'static,
302         F::Error: Send + 'static,
303     {
304         self.apply_filter(f).await.map(|ex| ex.one_or_tuple())
305     }
306 
307     /// Returns whether the `Filter` matches this request, or rejects it.
308     ///
309     /// # Example
310     ///
311     /// ```no_run
312     /// async {
313     ///     let get = warp::get();
314     ///     let post = warp::post();
315     ///
316     ///     assert!(
317     ///         warp::test::request()
318     ///             .method("GET")
319     ///             .matches(&get)
320     ///             .await
321     ///     );
322     ///
323     ///     assert!(
324     ///         !warp::test::request()
325     ///             .method("GET")
326     ///             .matches(&post)
327     ///             .await
328     ///     );
329     ///};
330     /// ```
matches<F>(self, f: &F) -> bool where F: Filter, F::Future: Send + 'static, F::Extract: Send + 'static, F::Error: Send + 'static,331     pub async fn matches<F>(self, f: &F) -> bool
332     where
333         F: Filter,
334         F::Future: Send + 'static,
335         F::Extract: Send + 'static,
336         F::Error: Send + 'static,
337     {
338         self.apply_filter(f).await.is_ok()
339     }
340 
341     /// Returns `Response` provided by applying the `Filter`.
342     ///
343     /// This requires that the supplied `Filter` return a [`Reply`](Reply).
reply<F>(self, f: &F) -> Response<Bytes> where F: Filter + 'static, F::Extract: Reply + Send, F::Error: IsReject + Send,344     pub async fn reply<F>(self, f: &F) -> Response<Bytes>
345     where
346         F: Filter + 'static,
347         F::Extract: Reply + Send,
348         F::Error: IsReject + Send,
349     {
350         // TODO: de-duplicate this and apply_filter()
351         assert!(!route::is_set(), "nested test filter calls");
352 
353         let route = Route::new(self.req, self.remote_addr);
354         let mut fut = Box::pin(
355             route::set(&route, move || f.filter(crate::filter::Internal)).then(|result| {
356                 let res = match result {
357                     Ok(rep) => rep.into_response(),
358                     Err(rej) => {
359                         log::debug!("rejected: {:?}", rej);
360                         rej.into_response()
361                     }
362                 };
363                 let (parts, body) = res.into_parts();
364                 hyper::body::to_bytes(body)
365                     .map_ok(|chunk| Response::from_parts(parts, chunk.into()))
366             }),
367         );
368 
369         let fut = future::poll_fn(move |cx| route::set(&route, || fut.as_mut().poll(cx)));
370 
371         fut.await.expect("reply shouldn't fail")
372     }
373 
apply_filter<F>(self, f: &F) -> impl Future<Output = Result<F::Extract, F::Error>> where F: Filter, F::Future: Send + 'static, F::Extract: Send + 'static, F::Error: Send + 'static,374     fn apply_filter<F>(self, f: &F) -> impl Future<Output = Result<F::Extract, F::Error>>
375     where
376         F: Filter,
377         F::Future: Send + 'static,
378         F::Extract: Send + 'static,
379         F::Error: Send + 'static,
380     {
381         assert!(!route::is_set(), "nested test filter calls");
382 
383         let route = Route::new(self.req, self.remote_addr);
384         let mut fut = Box::pin(route::set(&route, move || {
385             f.filter(crate::filter::Internal)
386         }));
387         future::poll_fn(move |cx| route::set(&route, || fut.as_mut().poll(cx)))
388     }
389 }
390 
#[cfg(feature = "websocket")]
impl WsBuilder {
    /// Sets the request path of this builder.
    ///
    /// The default if not set is `/`. Any query string included here is
    /// preserved in the upgrade request sent by `handshake`.
    ///
    /// # Example
    ///
    /// ```
    /// let req = warp::test::ws()
    ///     .path("/chat");
    /// ```
    ///
    /// # Panics
    ///
    /// This panics if the passed string is not able to be parsed as a valid
    /// `Uri`.
    pub fn path(self, p: &str) -> Self {
        WsBuilder {
            req: self.req.path(p),
        }
    }

    /// Set a header for this request.
    ///
    /// # Example
    ///
    /// ```
    /// let req = warp::test::ws()
    ///     .header("foo", "bar");
    /// ```
    ///
    /// # Panics
    ///
    /// This panics if the passed strings are not able to be parsed as a valid
    /// `HeaderName` and `HeaderValue`.
    pub fn header<K, V>(self, key: K, value: V) -> Self
    where
        HeaderName: TryFrom<K>,
        HeaderValue: TryFrom<V>,
    {
        WsBuilder {
            req: self.req.header(key, value),
        }
    }

    /// Execute this Websocket request against the provided filter.
    ///
    /// The filter is served on an ephemeral port on `127.0.0.1`. A real HTTP
    /// upgrade request is sent to it, carrying this builder's path, query
    /// string and headers. If the handshake succeeds, returns a `WsClient`.
    ///
    /// # Example
    ///
    /// ```no_run
    /// use futures::future;
    /// use warp::Filter;
    /// #[tokio::main]
    /// # async fn main() {
    ///
    /// // Some route that accepts websockets (but drops them immediately).
    /// let route = warp::ws()
    ///     .map(|ws: warp::ws::Ws| {
    ///         ws.on_upgrade(|_| future::ready(()))
    ///     });
    ///
    /// let client = warp::test::ws()
    ///     .handshake(route)
    ///     .await
    ///     .expect("handshake");
    /// # }
    /// ```
    ///
    /// # Errors
    ///
    /// Returns a `WsError` if the HTTP request or the upgrade fails.
    ///
    /// # Panics
    ///
    /// Panics if the background task driving the handshake ends without
    /// reporting a result.
    pub async fn handshake<F>(self, f: F) -> Result<WsClient, WsError>
    where
        F: Filter + Clone + Send + Sync + 'static,
        F::Extract: Reply + Send,
        F::Error: IsReject + Send,
    {
        let (upgraded_tx, upgraded_rx) = oneshot::channel();
        let (wr_tx, wr_rx) = mpsc::unbounded_channel();
        let (rd_tx, rd_rx) = mpsc::unbounded_channel();

        tokio::spawn(async move {
            use tokio_tungstenite::tungstenite::protocol;

            let (addr, srv) = crate::serve(f).bind_ephemeral(([127, 0, 0, 1], 0));

            let mut req = self
                .req
                .header("connection", "upgrade")
                .header("upgrade", "websocket")
                .header("sec-websocket-version", "13")
                .header("sec-websocket-key", "dGhlIHNhbXBsZSBub25jZQ==")
                .req;

            // Point the request at the ephemeral server. Keep the query string
            // as well as the path, so filters such as `warp::query()` see what
            // the test configured. `path()` alone would silently drop it.
            let uri = {
                let original = req.uri();
                let path_and_query = original
                    .path_and_query()
                    .map(|pq| pq.as_str())
                    .unwrap_or_else(|| original.path());
                format!("http://{}{}", addr, path_and_query)
                    .parse()
                    .expect("addr + path is valid URI")
            };

            *req.uri_mut() = uri;

            tokio::spawn(srv);

            let upgrade = ::hyper::Client::builder()
                .build(AddrConnect(addr))
                .request(req)
                .and_then(|res| res.into_body().on_upgrade());

            let upgraded = match upgrade.await {
                Ok(up) => {
                    let _ = upgraded_tx.send(Ok(()));
                    up
                }
                Err(err) => {
                    let _ = upgraded_tx.send(Err(err));
                    return;
                }
            };
            let ws = crate::ws::WebSocket::from_raw_socket(
                upgraded,
                protocol::Role::Client,
                Default::default(),
            )
            .await;

            let (tx, rx) = ws.split();
            let write = wr_rx.map(Ok).forward(tx).map(|_| ());

            let read = rx
                .take_while(|result| match result {
                    Err(_) => future::ready(false),
                    Ok(m) => future::ready(!m.is_close()),
                })
                .for_each(move |item| {
                    // The test may already have dropped its `WsClient`.
                    // Undeliverable messages are discarded rather than
                    // panicking inside this background task.
                    let _ = rd_tx.send(item);
                    future::ready(())
                });

            future::join(write, read).await;
        });

        match upgraded_rx.await {
            Ok(Ok(())) => Ok(WsClient {
                tx: wr_tx,
                rx: rd_rx,
            }),
            Ok(Err(err)) => Err(WsError::new(err)),
            Err(_canceled) => panic!("websocket handshake thread panicked"),
        }
    }
}
541 
#[cfg(feature = "websocket")]
impl WsClient {
    /// Send a "text" websocket message to the server.
    pub async fn send_text(&mut self, text: impl Into<String>) {
        let message = crate::ws::Message::text(text);
        self.send(message).await
    }

    /// Send a websocket message to the server.
    ///
    /// # Panics
    ///
    /// Panics if the message can no longer be handed to the background
    /// connection task (for example, because that task has finished).
    pub async fn send(&mut self, msg: crate::ws::Message) {
        self.tx.send(msg).unwrap();
    }

    /// Receive a websocket message from the server.
    ///
    /// Returns an error if the connection reported one, or if no more
    /// messages will arrive because the websocket is closed.
    pub async fn recv(&mut self) -> Result<crate::filters::ws::Message, WsError> {
        match self.rx.next().await {
            Some(Ok(message)) => Ok(message),
            Some(Err(cause)) => Err(WsError::new(cause)),
            // websocket is closed
            None => Err(WsError::new("closed")),
        }
    }

    /// Assert the server has closed the connection.
    ///
    /// Succeeds only when no further message or error is pending.
    pub async fn recv_closed(&mut self) -> Result<(), WsError> {
        match self.rx.next().await {
            // closed successfully
            None => Ok(()),
            Some(Ok(message)) => Err(WsError::new(format!("received message: {:?}", message))),
            Some(Err(cause)) => Err(WsError::new(cause)),
        }
    }
}
581 
#[cfg(feature = "websocket")]
impl fmt::Debug for WsClient {
    /// Prints only the type name; the internal channel halves are omitted.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let builder = f.debug_struct("WsClient");
        let mut builder = builder;
        builder.finish()
    }
}
588 
589 // ===== impl WsError =====
590 
#[cfg(feature = "websocket")]
impl WsError {
    /// Wraps anything convertible into a boxed error as the cause. This
    /// includes string messages such as `"closed"` or a formatted `String`.
    fn new<E>(cause: E) -> Self
    where
        E: Into<Box<dyn StdError + Send + Sync>>,
    {
        let cause = cause.into();
        WsError { cause }
    }
}
599 
600 impl fmt::Display for WsError {
fmt(&self, f: &mut fmt::Formatter) -> fmt::Result601     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
602         write!(f, "websocket error: {}", self.cause)
603     }
604 }
605 
606 impl StdError for WsError {
description(&self) -> &str607     fn description(&self) -> &str {
608         "websocket error"
609     }
610 }
611 
612 // ===== impl AddrConnect =====
613 
/// A connector for hyper's client that ignores the requested URI and always
/// opens a TCP connection to the wrapped address.
#[cfg(feature = "websocket")]
#[derive(Clone)]
struct AddrConnect(SocketAddr);

#[cfg(feature = "websocket")]
impl tower_service::Service<::http::Uri> for AddrConnect {
    type Response = ::tokio::net::TcpStream;
    type Error = ::std::io::Error;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    /// Always ready; the connection is only opened in `call`.
    fn poll_ready(&mut self, _cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
        Poll::Ready(Ok(()))
    }

    /// Connects to the fixed address, regardless of `_uri`.
    fn call(&mut self, _uri: ::http::Uri) -> Self::Future {
        let target = self.0;
        Box::pin(async move { tokio::net::TcpStream::connect(target).await })
    }
}
632 
mod inner {
    /// Flattens a filter's extracted tuple for test ergonomics.
    ///
    /// `()` stays `()` and a 1-tuple `(T,)` becomes `T`. Tuples of 2 to 16
    /// elements are returned unchanged.
    pub trait OneOrTuple {
        type Output;

        fn one_or_tuple(self) -> Self::Output;
    }

    impl OneOrTuple for () {
        type Output = ();

        fn one_or_tuple(self) -> Self::Output {}
    }

    impl<T> OneOrTuple for (T,) {
        type Output = T;

        fn one_or_tuple(self) -> Self::Output {
            let (only,) = self;
            only
        }
    }

    // Emits the identity impl for the whole tuple, then recurses on the tail.
    // It stops once a single ident remains, because 1-tuples are covered by
    // the dedicated impl above.
    macro_rules! tuple_passthrough {
        ($last:ident) => {};
        ($head:ident, $($tail:ident),+) => {
            impl<$head, $($tail),+> OneOrTuple for ($head, $($tail),+) {
                type Output = Self;

                fn one_or_tuple(self) -> Self::Output {
                    self
                }
            }

            tuple_passthrough!($($tail),+);
        };
    }

    tuple_passthrough!(
        T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16
    );
}
685