1 //! Route match guards.
2 //!
//! Guards are one of the ways the actix-web router chooses a
//! handler service. In essence, a guard is just a function that accepts a
//! reference to a `RequestHead` instance and returns a boolean.
//! Guards can be added to *scopes*, *resources*
//! and *routes*. Actix provides several guards out of the box, such as guards for
//! HTTP methods and headers. To become a guard, a type must implement the `Guard`
//! trait. Plain functions and closures can be guards as well.
10 //!
//! Guards cannot modify the request object, but it is possible
//! to store extra attributes on a request by using the `Extensions` container.
//! Extensions containers are available via the `RequestHead::extensions()` method.
14 //!
15 //! ```
16 //! use actix_web::{web, http, dev, guard, App, HttpResponse};
17 //!
18 //! fn main() {
19 //! App::new().service(web::resource("/index.html").route(
20 //! web::route()
21 //! .guard(guard::Post())
22 //! .guard(guard::fn_guard(|head| head.method == http::Method::GET))
23 //! .to(|| HttpResponse::MethodNotAllowed()))
24 //! );
25 //! }
26 //! ```
27 #![allow(non_snake_case)]
28 use std::convert::TryFrom;
29 use std::ops::Deref;
30 use std::rc::Rc;
31
32 use actix_http::http::{self, header, uri::Uri};
33 use actix_http::RequestHead;
34
35 /// Trait defines resource guards. Guards are used for route selection.
36 ///
37 /// Guards can not modify the request object. But it is possible
38 /// to store extra attributes on a request by using the `Extensions` container.
39 /// Extensions containers are available via the `RequestHead::extensions()` method.
40 pub trait Guard {
41 /// Check if request matches predicate
check(&self, request: &RequestHead) -> bool42 fn check(&self, request: &RequestHead) -> bool;
43 }
44
45 impl Guard for Rc<dyn Guard> {
check(&self, request: &RequestHead) -> bool46 fn check(&self, request: &RequestHead) -> bool {
47 self.deref().check(request)
48 }
49 }
50
51 /// Create guard object for supplied function.
52 ///
53 /// ```
54 /// use actix_web::{guard, web, App, HttpResponse};
55 ///
56 /// fn main() {
57 /// App::new().service(web::resource("/index.html").route(
58 /// web::route()
59 /// .guard(
60 /// guard::fn_guard(
61 /// |req| req.headers()
62 /// .contains_key("content-type")))
63 /// .to(|| HttpResponse::MethodNotAllowed()))
64 /// );
65 /// }
66 /// ```
fn_guard<F>(f: F) -> impl Guard where F: Fn(&RequestHead) -> bool,67 pub fn fn_guard<F>(f: F) -> impl Guard
68 where
69 F: Fn(&RequestHead) -> bool,
70 {
71 FnGuard(f)
72 }
73
74 struct FnGuard<F: Fn(&RequestHead) -> bool>(F);
75
76 impl<F> Guard for FnGuard<F>
77 where
78 F: Fn(&RequestHead) -> bool,
79 {
check(&self, head: &RequestHead) -> bool80 fn check(&self, head: &RequestHead) -> bool {
81 (self.0)(head)
82 }
83 }
84
85 impl<F> Guard for F
86 where
87 F: Fn(&RequestHead) -> bool,
88 {
check(&self, head: &RequestHead) -> bool89 fn check(&self, head: &RequestHead) -> bool {
90 (self)(head)
91 }
92 }
93
94 /// Return guard that matches if any of supplied guards.
95 ///
96 /// ```
97 /// use actix_web::{web, guard, App, HttpResponse};
98 ///
99 /// fn main() {
100 /// App::new().service(web::resource("/index.html").route(
101 /// web::route()
102 /// .guard(guard::Any(guard::Get()).or(guard::Post()))
103 /// .to(|| HttpResponse::MethodNotAllowed()))
104 /// );
105 /// }
106 /// ```
Any<F: Guard + 'static>(guard: F) -> AnyGuard107 pub fn Any<F: Guard + 'static>(guard: F) -> AnyGuard {
108 AnyGuard(vec![Box::new(guard)])
109 }
110
111 /// Matches any of supplied guards.
112 pub struct AnyGuard(Vec<Box<dyn Guard>>);
113
114 impl AnyGuard {
115 /// Add guard to a list of guards to check
or<F: Guard + 'static>(mut self, guard: F) -> Self116 pub fn or<F: Guard + 'static>(mut self, guard: F) -> Self {
117 self.0.push(Box::new(guard));
118 self
119 }
120 }
121
122 impl Guard for AnyGuard {
check(&self, req: &RequestHead) -> bool123 fn check(&self, req: &RequestHead) -> bool {
124 for p in &self.0 {
125 if p.check(req) {
126 return true;
127 }
128 }
129 false
130 }
131 }
132
133 /// Return guard that matches if all of the supplied guards.
134 ///
135 /// ```
136 /// use actix_web::{guard, web, App, HttpResponse};
137 ///
138 /// fn main() {
139 /// App::new().service(web::resource("/index.html").route(
140 /// web::route()
141 /// .guard(
142 /// guard::All(guard::Get()).and(guard::Header("content-type", "text/plain")))
143 /// .to(|| HttpResponse::MethodNotAllowed()))
144 /// );
145 /// }
146 /// ```
All<F: Guard + 'static>(guard: F) -> AllGuard147 pub fn All<F: Guard + 'static>(guard: F) -> AllGuard {
148 AllGuard(vec![Box::new(guard)])
149 }
150
151 /// Matches if all of supplied guards.
152 pub struct AllGuard(Vec<Box<dyn Guard>>);
153
154 impl AllGuard {
155 /// Add new guard to the list of guards to check
and<F: Guard + 'static>(mut self, guard: F) -> Self156 pub fn and<F: Guard + 'static>(mut self, guard: F) -> Self {
157 self.0.push(Box::new(guard));
158 self
159 }
160 }
161
162 impl Guard for AllGuard {
check(&self, request: &RequestHead) -> bool163 fn check(&self, request: &RequestHead) -> bool {
164 for p in &self.0 {
165 if !p.check(request) {
166 return false;
167 }
168 }
169 true
170 }
171 }
172
173 /// Return guard that matches if supplied guard does not match.
Not<F: Guard + 'static>(guard: F) -> NotGuard174 pub fn Not<F: Guard + 'static>(guard: F) -> NotGuard {
175 NotGuard(Box::new(guard))
176 }
177
178 #[doc(hidden)]
179 pub struct NotGuard(Box<dyn Guard>);
180
181 impl Guard for NotGuard {
check(&self, request: &RequestHead) -> bool182 fn check(&self, request: &RequestHead) -> bool {
183 !self.0.check(request)
184 }
185 }
186
187 /// HTTP method guard.
188 #[doc(hidden)]
189 pub struct MethodGuard(http::Method);
190
191 impl Guard for MethodGuard {
check(&self, request: &RequestHead) -> bool192 fn check(&self, request: &RequestHead) -> bool {
193 request.method == self.0
194 }
195 }
196
197 /// Guard to match *GET* HTTP method.
Get() -> MethodGuard198 pub fn Get() -> MethodGuard {
199 MethodGuard(http::Method::GET)
200 }
201
202 /// Predicate to match *POST* HTTP method.
Post() -> MethodGuard203 pub fn Post() -> MethodGuard {
204 MethodGuard(http::Method::POST)
205 }
206
207 /// Predicate to match *PUT* HTTP method.
Put() -> MethodGuard208 pub fn Put() -> MethodGuard {
209 MethodGuard(http::Method::PUT)
210 }
211
212 /// Predicate to match *DELETE* HTTP method.
Delete() -> MethodGuard213 pub fn Delete() -> MethodGuard {
214 MethodGuard(http::Method::DELETE)
215 }
216
217 /// Predicate to match *HEAD* HTTP method.
Head() -> MethodGuard218 pub fn Head() -> MethodGuard {
219 MethodGuard(http::Method::HEAD)
220 }
221
222 /// Predicate to match *OPTIONS* HTTP method.
Options() -> MethodGuard223 pub fn Options() -> MethodGuard {
224 MethodGuard(http::Method::OPTIONS)
225 }
226
227 /// Predicate to match *CONNECT* HTTP method.
Connect() -> MethodGuard228 pub fn Connect() -> MethodGuard {
229 MethodGuard(http::Method::CONNECT)
230 }
231
232 /// Predicate to match *PATCH* HTTP method.
Patch() -> MethodGuard233 pub fn Patch() -> MethodGuard {
234 MethodGuard(http::Method::PATCH)
235 }
236
237 /// Predicate to match *TRACE* HTTP method.
Trace() -> MethodGuard238 pub fn Trace() -> MethodGuard {
239 MethodGuard(http::Method::TRACE)
240 }
241
242 /// Predicate to match specified HTTP method.
Method(method: http::Method) -> MethodGuard243 pub fn Method(method: http::Method) -> MethodGuard {
244 MethodGuard(method)
245 }
246
247 /// Return predicate that matches if request contains specified header and
248 /// value.
Header(name: &'static str, value: &'static str) -> HeaderGuard249 pub fn Header(name: &'static str, value: &'static str) -> HeaderGuard {
250 HeaderGuard(
251 header::HeaderName::try_from(name).unwrap(),
252 header::HeaderValue::from_static(value),
253 )
254 }
255
256 #[doc(hidden)]
257 pub struct HeaderGuard(header::HeaderName, header::HeaderValue);
258
259 impl Guard for HeaderGuard {
check(&self, req: &RequestHead) -> bool260 fn check(&self, req: &RequestHead) -> bool {
261 if let Some(val) = req.headers.get(&self.0) {
262 return val == self.1;
263 }
264 false
265 }
266 }
267
268 /// Return predicate that matches if request contains specified Host name.
269 ///
270 /// ```
271 /// use actix_web::{web, guard::Host, App, HttpResponse};
272 ///
273 /// fn main() {
274 /// App::new().service(
275 /// web::resource("/index.html")
276 /// .guard(Host("www.rust-lang.org"))
277 /// .to(|| HttpResponse::MethodNotAllowed())
278 /// );
279 /// }
280 /// ```
Host<H: AsRef<str>>(host: H) -> HostGuard281 pub fn Host<H: AsRef<str>>(host: H) -> HostGuard {
282 HostGuard(host.as_ref().to_string(), None)
283 }
284
get_host_uri(req: &RequestHead) -> Option<Uri>285 fn get_host_uri(req: &RequestHead) -> Option<Uri> {
286 use core::str::FromStr;
287 req.headers
288 .get(header::HOST)
289 .and_then(|host_value| host_value.to_str().ok())
290 .or_else(|| req.uri.host())
291 .map(|host: &str| Uri::from_str(host).ok())
292 .and_then(|host_success| host_success)
293 }
294
/// Guard that matches on the request's host name and, optionally, its scheme.
///
/// Field `0` is the expected host; field `1` is the expected scheme, if one was set.
#[doc(hidden)]
pub struct HostGuard(String, Option<String>);

impl HostGuard {
    /// Sets the scheme (e.g. `"https"`) the request must use.
    ///
    /// The scheme is only enforced when one can be determined from the request.
    pub fn scheme<H: AsRef<str>>(mut self, scheme: H) -> HostGuard {
        let expected_scheme = String::from(scheme.as_ref());
        self.1 = Some(expected_scheme);
        self
    }
}
305
306 impl Guard for HostGuard {
check(&self, req: &RequestHead) -> bool307 fn check(&self, req: &RequestHead) -> bool {
308 let req_host_uri = if let Some(uri) = get_host_uri(req) {
309 uri
310 } else {
311 return false;
312 };
313
314 if let Some(uri_host) = req_host_uri.host() {
315 if self.0 != uri_host {
316 return false;
317 }
318 } else {
319 return false;
320 }
321
322 if let Some(ref scheme) = self.1 {
323 if let Some(ref req_host_uri_scheme) = req_host_uri.scheme_str() {
324 return scheme == req_host_uri_scheme;
325 }
326 }
327
328 true
329 }
330 }
331
#[cfg(test)]
mod tests {
    use actix_http::http::{header, Method};

    use super::*;
    use crate::test::TestRequest;

    /// `Header` requires both the header name and its exact value to match.
    #[test]
    fn test_header() {
        let req = TestRequest::default()
            .insert_header((header::TRANSFER_ENCODING, "chunked"))
            .to_http_request();
        let head = req.head();

        assert!(Header("transfer-encoding", "chunked").check(head));
        assert!(!Header("transfer-encoding", "other").check(head));
        assert!(!Header("content-type", "other").check(head));
    }

    /// A `Host` header without a scheme matches on the host name alone; a configured
    /// scheme does not reject it.
    #[test]
    fn test_host() {
        let req = TestRequest::default()
            .insert_header((
                header::HOST,
                header::HeaderValue::from_static("www.rust-lang.org"),
            ))
            .to_http_request();
        let head = req.head();

        assert!(Host("www.rust-lang.org").check(head));
        assert!(Host("www.rust-lang.org").scheme("https").check(head));

        for &other in &["blog.rust-lang.org", "crates.io", "localhost"] {
            assert!(!Host(other).check(head));
        }
        assert!(!Host("blog.rust-lang.org").scheme("https").check(head));
    }

    /// When the `Host` header carries a scheme, a configured scheme must agree with it.
    #[test]
    fn test_host_scheme() {
        let req = TestRequest::default()
            .insert_header((
                header::HOST,
                header::HeaderValue::from_static("https://www.rust-lang.org"),
            ))
            .to_http_request();
        let head = req.head();

        assert!(Host("www.rust-lang.org").scheme("https").check(head));
        assert!(Host("www.rust-lang.org").check(head));
        assert!(!Host("www.rust-lang.org").scheme("http").check(head));

        assert!(!Host("blog.rust-lang.org").check(head));
        assert!(!Host("blog.rust-lang.org").scheme("https").check(head));
        assert!(!Host("crates.io").scheme("https").check(head));
        assert!(!Host("localhost").check(head));
    }

    /// Without a `Host` header the host is taken from the request URI.
    #[test]
    fn test_host_without_header() {
        let req = TestRequest::default()
            .uri("www.rust-lang.org")
            .to_http_request();
        let head = req.head();

        assert!(Host("www.rust-lang.org").check(head));
        assert!(Host("www.rust-lang.org").scheme("https").check(head));

        for &other in &["blog.rust-lang.org", "crates.io", "localhost"] {
            assert!(!Host(other).check(head));
        }
        assert!(!Host("blog.rust-lang.org").scheme("https").check(head));
    }

    #[test]
    fn test_methods() {
        let get_req = TestRequest::default().to_http_request();
        let post_req = TestRequest::default()
            .method(Method::POST)
            .to_http_request();

        assert!(Get().check(get_req.head()));
        assert!(!Get().check(post_req.head()));
        assert!(Post().check(post_req.head()));
        assert!(!Post().check(get_req.head()));

        // Every other helper must match its own method and reject the default GET request.
        let cases = vec![
            (Method::PUT, Put()),
            (Method::DELETE, Delete()),
            (Method::HEAD, Head()),
            (Method::OPTIONS, Options()),
            (Method::CONNECT, Connect()),
            (Method::PATCH, Patch()),
            (Method::TRACE, Trace()),
        ];
        for (method, guard) in cases {
            let req = TestRequest::default().method(method).to_http_request();
            assert!(guard.check(req.head()));
            assert!(!guard.check(get_req.head()));
        }
    }

    #[test]
    fn test_preds() {
        let req = TestRequest::default()
            .method(Method::TRACE)
            .to_http_request();
        let head = req.head();

        assert!(Not(Get()).check(head));
        assert!(!Not(Trace()).check(head));

        assert!(All(Trace()).and(Trace()).check(head));
        assert!(!All(Get()).and(Trace()).check(head));

        assert!(Any(Get()).or(Trace()).check(head));
        assert!(!Any(Get()).or(Get()).check(head));
    }
}
508