1 //! Route match guards.
2 //!
//! Guards are one of the ways the actix-web router chooses a
//! handler service. In essence, a guard is just a function that accepts a
//! reference to a `RequestHead` instance and returns a boolean.
//! Guards can be added to *scopes*, *resources*
//! and *routes*. Actix provides several guards out of the box, such as
//! HTTP method and header guards. To become a guard, a type must implement the `Guard`
//! trait. Plain functions and closures can be guards as well.
10 //!
//! Guards cannot modify the request object, but it is possible
//! to store extra attributes on a request by using the `Extensions` container.
//! Extensions containers are available via the `RequestHead::extensions()` method.
14 //!
15 //! ```
16 //! use actix_web::{web, http, dev, guard, App, HttpResponse};
17 //!
18 //! fn main() {
19 //!     App::new().service(web::resource("/index.html").route(
20 //!         web::route()
21 //!              .guard(guard::Post())
22 //!              .guard(guard::fn_guard(|head| head.method == http::Method::GET))
23 //!              .to(|| HttpResponse::MethodNotAllowed()))
24 //!     );
25 //! }
26 //! ```
27 #![allow(non_snake_case)]
28 use std::convert::TryFrom;
29 use std::ops::Deref;
30 use std::rc::Rc;
31 
32 use actix_http::http::{self, header, uri::Uri};
33 use actix_http::RequestHead;
34 
35 /// Trait defines resource guards. Guards are used for route selection.
36 ///
37 /// Guards can not modify the request object. But it is possible
38 /// to store extra attributes on a request by using the `Extensions` container.
39 /// Extensions containers are available via the `RequestHead::extensions()` method.
40 pub trait Guard {
41     /// Check if request matches predicate
check(&self, request: &RequestHead) -> bool42     fn check(&self, request: &RequestHead) -> bool;
43 }
44 
45 impl Guard for Rc<dyn Guard> {
check(&self, request: &RequestHead) -> bool46     fn check(&self, request: &RequestHead) -> bool {
47         self.deref().check(request)
48     }
49 }
50 
51 /// Create guard object for supplied function.
52 ///
53 /// ```
54 /// use actix_web::{guard, web, App, HttpResponse};
55 ///
56 /// fn main() {
57 ///     App::new().service(web::resource("/index.html").route(
58 ///         web::route()
59 ///             .guard(
60 ///                 guard::fn_guard(
61 ///                     |req| req.headers()
62 ///                              .contains_key("content-type")))
63 ///             .to(|| HttpResponse::MethodNotAllowed()))
64 ///     );
65 /// }
66 /// ```
fn_guard<F>(f: F) -> impl Guard where F: Fn(&RequestHead) -> bool,67 pub fn fn_guard<F>(f: F) -> impl Guard
68 where
69     F: Fn(&RequestHead) -> bool,
70 {
71     FnGuard(f)
72 }
73 
74 struct FnGuard<F: Fn(&RequestHead) -> bool>(F);
75 
76 impl<F> Guard for FnGuard<F>
77 where
78     F: Fn(&RequestHead) -> bool,
79 {
check(&self, head: &RequestHead) -> bool80     fn check(&self, head: &RequestHead) -> bool {
81         (self.0)(head)
82     }
83 }
84 
85 impl<F> Guard for F
86 where
87     F: Fn(&RequestHead) -> bool,
88 {
check(&self, head: &RequestHead) -> bool89     fn check(&self, head: &RequestHead) -> bool {
90         (self)(head)
91     }
92 }
93 
94 /// Return guard that matches if any of supplied guards.
95 ///
96 /// ```
97 /// use actix_web::{web, guard, App, HttpResponse};
98 ///
99 /// fn main() {
100 ///     App::new().service(web::resource("/index.html").route(
101 ///         web::route()
102 ///              .guard(guard::Any(guard::Get()).or(guard::Post()))
103 ///              .to(|| HttpResponse::MethodNotAllowed()))
104 ///     );
105 /// }
106 /// ```
Any<F: Guard + 'static>(guard: F) -> AnyGuard107 pub fn Any<F: Guard + 'static>(guard: F) -> AnyGuard {
108     AnyGuard(vec![Box::new(guard)])
109 }
110 
111 /// Matches any of supplied guards.
112 pub struct AnyGuard(Vec<Box<dyn Guard>>);
113 
114 impl AnyGuard {
115     /// Add guard to a list of guards to check
or<F: Guard + 'static>(mut self, guard: F) -> Self116     pub fn or<F: Guard + 'static>(mut self, guard: F) -> Self {
117         self.0.push(Box::new(guard));
118         self
119     }
120 }
121 
122 impl Guard for AnyGuard {
check(&self, req: &RequestHead) -> bool123     fn check(&self, req: &RequestHead) -> bool {
124         for p in &self.0 {
125             if p.check(req) {
126                 return true;
127             }
128         }
129         false
130     }
131 }
132 
133 /// Return guard that matches if all of the supplied guards.
134 ///
135 /// ```
136 /// use actix_web::{guard, web, App, HttpResponse};
137 ///
138 /// fn main() {
139 ///     App::new().service(web::resource("/index.html").route(
140 ///         web::route()
141 ///             .guard(
142 ///                 guard::All(guard::Get()).and(guard::Header("content-type", "text/plain")))
143 ///             .to(|| HttpResponse::MethodNotAllowed()))
144 ///     );
145 /// }
146 /// ```
All<F: Guard + 'static>(guard: F) -> AllGuard147 pub fn All<F: Guard + 'static>(guard: F) -> AllGuard {
148     AllGuard(vec![Box::new(guard)])
149 }
150 
151 /// Matches if all of supplied guards.
152 pub struct AllGuard(Vec<Box<dyn Guard>>);
153 
154 impl AllGuard {
155     /// Add new guard to the list of guards to check
and<F: Guard + 'static>(mut self, guard: F) -> Self156     pub fn and<F: Guard + 'static>(mut self, guard: F) -> Self {
157         self.0.push(Box::new(guard));
158         self
159     }
160 }
161 
162 impl Guard for AllGuard {
check(&self, request: &RequestHead) -> bool163     fn check(&self, request: &RequestHead) -> bool {
164         for p in &self.0 {
165             if !p.check(request) {
166                 return false;
167             }
168         }
169         true
170     }
171 }
172 
173 /// Return guard that matches if supplied guard does not match.
Not<F: Guard + 'static>(guard: F) -> NotGuard174 pub fn Not<F: Guard + 'static>(guard: F) -> NotGuard {
175     NotGuard(Box::new(guard))
176 }
177 
178 #[doc(hidden)]
179 pub struct NotGuard(Box<dyn Guard>);
180 
181 impl Guard for NotGuard {
check(&self, request: &RequestHead) -> bool182     fn check(&self, request: &RequestHead) -> bool {
183         !self.0.check(request)
184     }
185 }
186 
187 /// HTTP method guard.
188 #[doc(hidden)]
189 pub struct MethodGuard(http::Method);
190 
191 impl Guard for MethodGuard {
check(&self, request: &RequestHead) -> bool192     fn check(&self, request: &RequestHead) -> bool {
193         request.method == self.0
194     }
195 }
196 
197 /// Guard to match *GET* HTTP method.
Get() -> MethodGuard198 pub fn Get() -> MethodGuard {
199     MethodGuard(http::Method::GET)
200 }
201 
202 /// Predicate to match *POST* HTTP method.
Post() -> MethodGuard203 pub fn Post() -> MethodGuard {
204     MethodGuard(http::Method::POST)
205 }
206 
207 /// Predicate to match *PUT* HTTP method.
Put() -> MethodGuard208 pub fn Put() -> MethodGuard {
209     MethodGuard(http::Method::PUT)
210 }
211 
212 /// Predicate to match *DELETE* HTTP method.
Delete() -> MethodGuard213 pub fn Delete() -> MethodGuard {
214     MethodGuard(http::Method::DELETE)
215 }
216 
217 /// Predicate to match *HEAD* HTTP method.
Head() -> MethodGuard218 pub fn Head() -> MethodGuard {
219     MethodGuard(http::Method::HEAD)
220 }
221 
222 /// Predicate to match *OPTIONS* HTTP method.
Options() -> MethodGuard223 pub fn Options() -> MethodGuard {
224     MethodGuard(http::Method::OPTIONS)
225 }
226 
227 /// Predicate to match *CONNECT* HTTP method.
Connect() -> MethodGuard228 pub fn Connect() -> MethodGuard {
229     MethodGuard(http::Method::CONNECT)
230 }
231 
232 /// Predicate to match *PATCH* HTTP method.
Patch() -> MethodGuard233 pub fn Patch() -> MethodGuard {
234     MethodGuard(http::Method::PATCH)
235 }
236 
237 /// Predicate to match *TRACE* HTTP method.
Trace() -> MethodGuard238 pub fn Trace() -> MethodGuard {
239     MethodGuard(http::Method::TRACE)
240 }
241 
242 /// Predicate to match specified HTTP method.
Method(method: http::Method) -> MethodGuard243 pub fn Method(method: http::Method) -> MethodGuard {
244     MethodGuard(method)
245 }
246 
247 /// Return predicate that matches if request contains specified header and
248 /// value.
Header(name: &'static str, value: &'static str) -> HeaderGuard249 pub fn Header(name: &'static str, value: &'static str) -> HeaderGuard {
250     HeaderGuard(
251         header::HeaderName::try_from(name).unwrap(),
252         header::HeaderValue::from_static(value),
253     )
254 }
255 
256 #[doc(hidden)]
257 pub struct HeaderGuard(header::HeaderName, header::HeaderValue);
258 
259 impl Guard for HeaderGuard {
check(&self, req: &RequestHead) -> bool260     fn check(&self, req: &RequestHead) -> bool {
261         if let Some(val) = req.headers.get(&self.0) {
262             return val == self.1;
263         }
264         false
265     }
266 }
267 
268 /// Return predicate that matches if request contains specified Host name.
269 ///
270 /// ```
271 /// use actix_web::{web, guard::Host, App, HttpResponse};
272 ///
273 /// fn main() {
274 ///     App::new().service(
275 ///         web::resource("/index.html")
276 ///             .guard(Host("www.rust-lang.org"))
277 ///             .to(|| HttpResponse::MethodNotAllowed())
278 ///     );
279 /// }
280 /// ```
Host<H: AsRef<str>>(host: H) -> HostGuard281 pub fn Host<H: AsRef<str>>(host: H) -> HostGuard {
282     HostGuard(host.as_ref().to_string(), None)
283 }
284 
get_host_uri(req: &RequestHead) -> Option<Uri>285 fn get_host_uri(req: &RequestHead) -> Option<Uri> {
286     use core::str::FromStr;
287     req.headers
288         .get(header::HOST)
289         .and_then(|host_value| host_value.to_str().ok())
290         .or_else(|| req.uri.host())
291         .map(|host: &str| Uri::from_str(host).ok())
292         .and_then(|host_success| host_success)
293 }
294 
295 #[doc(hidden)]
296 pub struct HostGuard(String, Option<String>);
297 
298 impl HostGuard {
299     /// Set request scheme to match
scheme<H: AsRef<str>>(mut self, scheme: H) -> HostGuard300     pub fn scheme<H: AsRef<str>>(mut self, scheme: H) -> HostGuard {
301         self.1 = Some(scheme.as_ref().to_string());
302         self
303     }
304 }
305 
306 impl Guard for HostGuard {
check(&self, req: &RequestHead) -> bool307     fn check(&self, req: &RequestHead) -> bool {
308         let req_host_uri = if let Some(uri) = get_host_uri(req) {
309             uri
310         } else {
311             return false;
312         };
313 
314         if let Some(uri_host) = req_host_uri.host() {
315             if self.0 != uri_host {
316                 return false;
317             }
318         } else {
319             return false;
320         }
321 
322         if let Some(ref scheme) = self.1 {
323             if let Some(ref req_host_uri_scheme) = req_host_uri.scheme_str() {
324                 return scheme == req_host_uri_scheme;
325             }
326         }
327 
328         true
329     }
330 }
331 
#[cfg(test)]
mod tests {
    use actix_http::http::{header, Method};

    use super::*;
    use crate::test::TestRequest;

    /// Asserts that `guard` accepts a request made with `method` and rejects a default (GET)
    /// request.
    fn assert_method_guard(guard: MethodGuard, method: Method) {
        let matching = TestRequest::default().method(method).to_http_request();
        let default_get = TestRequest::default().to_http_request();

        assert!(guard.check(matching.head()));
        assert!(!guard.check(default_get.head()));
    }

    #[test]
    fn test_header() {
        let req = TestRequest::default()
            .insert_header((header::TRANSFER_ENCODING, "chunked"))
            .to_http_request();
        let head = req.head();

        // Same header, same value.
        assert!(Header("transfer-encoding", "chunked").check(head));
        // Same header, different value.
        assert!(!Header("transfer-encoding", "other").check(head));
        // Header absent from the request.
        assert!(!Header("content-type", "other").check(head));
    }

    #[test]
    fn test_host() {
        let req = TestRequest::default()
            .insert_header((
                header::HOST,
                header::HeaderValue::from_static("www.rust-lang.org"),
            ))
            .to_http_request();
        let head = req.head();

        assert!(Host("www.rust-lang.org").check(head));
        // The Host header carries no scheme, so the scheme constraint is not enforced.
        assert!(Host("www.rust-lang.org").scheme("https").check(head));

        let mismatches = vec![
            Host("blog.rust-lang.org"),
            Host("blog.rust-lang.org").scheme("https"),
            Host("crates.io"),
            Host("localhost"),
        ];
        for guard in mismatches {
            assert!(!guard.check(head));
        }
    }

    #[test]
    fn test_host_scheme() {
        let req = TestRequest::default()
            .insert_header((
                header::HOST,
                header::HeaderValue::from_static("https://www.rust-lang.org"),
            ))
            .to_http_request();
        let head = req.head();

        assert!(Host("www.rust-lang.org").scheme("https").check(head));
        assert!(Host("www.rust-lang.org").check(head));

        let mismatches = vec![
            Host("www.rust-lang.org").scheme("http"),
            Host("blog.rust-lang.org"),
            Host("blog.rust-lang.org").scheme("https"),
            Host("crates.io").scheme("https"),
            Host("localhost"),
        ];
        for guard in mismatches {
            assert!(!guard.check(head));
        }
    }

    #[test]
    fn test_host_without_header() {
        let req = TestRequest::default()
            .uri("www.rust-lang.org")
            .to_http_request();
        let head = req.head();

        assert!(Host("www.rust-lang.org").check(head));
        assert!(Host("www.rust-lang.org").scheme("https").check(head));

        let mismatches = vec![
            Host("blog.rust-lang.org"),
            Host("blog.rust-lang.org").scheme("https"),
            Host("crates.io"),
            Host("localhost"),
        ];
        for guard in mismatches {
            assert!(!guard.check(head));
        }
    }

    #[test]
    fn test_methods() {
        let get_req = TestRequest::default().to_http_request();
        let post_req = TestRequest::default()
            .method(Method::POST)
            .to_http_request();

        assert!(Get().check(get_req.head()));
        assert!(!Get().check(post_req.head()));
        assert!(Post().check(post_req.head()));
        assert!(!Post().check(get_req.head()));

        assert_method_guard(Put(), Method::PUT);
        assert_method_guard(Delete(), Method::DELETE);
        assert_method_guard(Head(), Method::HEAD);
        assert_method_guard(Options(), Method::OPTIONS);
        assert_method_guard(Connect(), Method::CONNECT);
        assert_method_guard(Patch(), Method::PATCH);
        assert_method_guard(Trace(), Method::TRACE);
    }

    #[test]
    fn test_preds() {
        let req = TestRequest::default()
            .method(Method::TRACE)
            .to_http_request();
        let head = req.head();

        assert!(Not(Get()).check(head));
        assert!(!Not(Trace()).check(head));

        assert!(All(Trace()).and(Trace()).check(head));
        assert!(!All(Get()).and(Trace()).check(head));

        assert!(Any(Get()).or(Trace()).check(head));
        assert!(!Any(Get()).or(Get()).check(head));
    }
}
508