1 //
2 // Web Headers and caching
3 //
4 use std::io::Cursor;
5 
6 use rocket::{
7     fairing::{Fairing, Info, Kind},
8     http::{ContentType, Header, HeaderMap, Method, RawStr, Status},
9     request::FromParam,
10     response::{self, Responder},
11     Data, Request, Response, Rocket,
12 };
13 
14 use crate::CONFIG;
15 
16 pub struct AppHeaders();
17 
18 impl Fairing for AppHeaders {
info(&self) -> Info19     fn info(&self) -> Info {
20         Info {
21             name: "Application Headers",
22             kind: Kind::Response,
23         }
24     }
25 
on_response(&self, _req: &Request, res: &mut Response)26     fn on_response(&self, _req: &Request, res: &mut Response) {
27         res.set_raw_header("Feature-Policy", "accelerometer 'none'; ambient-light-sensor 'none'; autoplay 'none'; camera 'none'; encrypted-media 'none'; fullscreen 'none'; geolocation 'none'; gyroscope 'none'; magnetometer 'none'; microphone 'none'; midi 'none'; payment 'none'; picture-in-picture 'none'; sync-xhr 'self' https://haveibeenpwned.com https://2fa.directory; usb 'none'; vr 'none'");
28         res.set_raw_header("Referrer-Policy", "same-origin");
29         res.set_raw_header("X-Frame-Options", "SAMEORIGIN");
30         res.set_raw_header("X-Content-Type-Options", "nosniff");
31         res.set_raw_header("X-XSS-Protection", "1; mode=block");
32         let csp = format!(
33             // Chrome Web Store: https://chrome.google.com/webstore/detail/bitwarden-free-password-m/nngceckbapebfimnlniiiahkandclblb
34             // Edge Add-ons: https://microsoftedge.microsoft.com/addons/detail/bitwarden-free-password/jbkfoedolllekgbhcbcoahefnbanhhlh?hl=en-US
35             // Firefox Browser Add-ons: https://addons.mozilla.org/en-US/firefox/addon/bitwarden-password-manager/
36             "frame-ancestors 'self' chrome-extension://nngceckbapebfimnlniiiahkandclblb chrome-extension://jbkfoedolllekgbhcbcoahefnbanhhlh moz-extension://* {};",
37             CONFIG.allowed_iframe_ancestors()
38         );
39         res.set_raw_header("Content-Security-Policy", csp);
40 
41         // Disable cache unless otherwise specified
42         if !res.headers().contains("cache-control") {
43             res.set_raw_header("Cache-Control", "no-cache, no-store, max-age=0");
44         }
45     }
46 }
47 
48 pub struct Cors();
49 
50 impl Cors {
get_header(headers: &HeaderMap, name: &str) -> String51     fn get_header(headers: &HeaderMap, name: &str) -> String {
52         match headers.get_one(name) {
53             Some(h) => h.to_string(),
54             _ => "".to_string(),
55         }
56     }
57 
58     // Check a request's `Origin` header against the list of allowed origins.
59     // If a match exists, return it. Otherwise, return None.
get_allowed_origin(headers: &HeaderMap) -> Option<String>60     fn get_allowed_origin(headers: &HeaderMap) -> Option<String> {
61         let origin = Cors::get_header(headers, "Origin");
62         let domain_origin = CONFIG.domain_origin();
63         let safari_extension_origin = "file://";
64         if origin == domain_origin || origin == safari_extension_origin {
65             Some(origin)
66         } else {
67             None
68         }
69     }
70 }
71 
72 impl Fairing for Cors {
info(&self) -> Info73     fn info(&self) -> Info {
74         Info {
75             name: "Cors",
76             kind: Kind::Response,
77         }
78     }
79 
on_response(&self, request: &Request, response: &mut Response)80     fn on_response(&self, request: &Request, response: &mut Response) {
81         let req_headers = request.headers();
82 
83         if let Some(origin) = Cors::get_allowed_origin(req_headers) {
84             response.set_header(Header::new("Access-Control-Allow-Origin", origin));
85         }
86 
87         // Preflight request
88         if request.method() == Method::Options {
89             let req_allow_headers = Cors::get_header(req_headers, "Access-Control-Request-Headers");
90             let req_allow_method = Cors::get_header(req_headers, "Access-Control-Request-Method");
91 
92             response.set_header(Header::new("Access-Control-Allow-Methods", req_allow_method));
93             response.set_header(Header::new("Access-Control-Allow-Headers", req_allow_headers));
94             response.set_header(Header::new("Access-Control-Allow-Credentials", "true"));
95             response.set_status(Status::Ok);
96             response.set_header(ContentType::Plain);
97             response.set_sized_body(Cursor::new(""));
98         }
99     }
100 }
101 
/// Wraps a responder and overrides its `Cache-Control` header with the stored value.
pub struct Cached<R>(R, String);

impl<R> Cached<R> {
    /// Public, immutable caching for 7 days.
    pub fn long(r: R) -> Cached<R> {
        const SEVEN_DAYS_SECS: u64 = 7 * 24 * 60 * 60;
        Cached::ttl(r, SEVEN_DAYS_SECS)
    }

    /// Public caching for 10 minutes (not marked immutable).
    pub fn short(r: R) -> Cached<R> {
        Cached(r, "public, max-age=600".to_string())
    }

    /// Public, immutable caching for `ttl` seconds.
    pub fn ttl(r: R, ttl: u64) -> Cached<R> {
        let cache_control = format!("public, immutable, max-age={}", ttl);
        Cached(r, cache_control)
    }
}
119 
120 impl<'r, R: Responder<'r>> Responder<'r> for Cached<R> {
respond_to(self, req: &Request) -> response::Result<'r>121     fn respond_to(self, req: &Request) -> response::Result<'r> {
122         match self.0.respond_to(req) {
123             Ok(mut res) => {
124                 res.set_raw_header("Cache-Control", self.1);
125                 Ok(res)
126             }
127             e @ Err(_) => e,
128         }
129     }
130 }
131 
132 pub struct SafeString(String);
133 
134 impl std::fmt::Display for SafeString {
fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result135     fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
136         self.0.fmt(f)
137     }
138 }
139 
140 impl AsRef<Path> for SafeString {
141     #[inline]
as_ref(&self) -> &Path142     fn as_ref(&self) -> &Path {
143         Path::new(&self.0)
144     }
145 }
146 
147 impl<'r> FromParam<'r> for SafeString {
148     type Error = ();
149 
150     #[inline(always)]
from_param(param: &'r RawStr) -> Result<Self, Self::Error>151     fn from_param(param: &'r RawStr) -> Result<Self, Self::Error> {
152         let s = param.percent_decode().map(|cow| cow.into_owned()).map_err(|_| ())?;
153 
154         if s.chars().all(|c| matches!(c, 'a'..='z' | 'A'..='Z' |'0'..='9' | '-')) {
155             Ok(SafeString(s))
156         } else {
157             Err(())
158         }
159     }
160 }
161 
// Path prefixes that `BetterLogging` logs when extra debug is off: every route from the main
// paths list plus the attachments endpoint. Prefixes are matched after the configured domain
// path has been stripped, so static file routes and the alive endpoint are effectively ignored.
const LOGGED_ROUTES: [&str; 6] = [
    "/api",
    "/admin",
    "/identity",
    "/icons",
    "/notifications/hub/negotiate",
    "/attachments",
];
166 
167 // Boolean is extra debug, when true, we ignore the whitelist above and also print the mounts
168 pub struct BetterLogging(pub bool);
169 impl Fairing for BetterLogging {
info(&self) -> Info170     fn info(&self) -> Info {
171         Info {
172             name: "Better Logging",
173             kind: Kind::Launch | Kind::Request | Kind::Response,
174         }
175     }
176 
on_launch(&self, rocket: &Rocket)177     fn on_launch(&self, rocket: &Rocket) {
178         if self.0 {
179             info!(target: "routes", "Routes loaded:");
180             let mut routes: Vec<_> = rocket.routes().collect();
181             routes.sort_by_key(|r| r.uri.path());
182             for route in routes {
183                 if route.rank < 0 {
184                     info!(target: "routes", "{:<6} {}", route.method, route.uri);
185                 } else {
186                     info!(target: "routes", "{:<6} {} [{}]", route.method, route.uri, route.rank);
187                 }
188             }
189         }
190 
191         let config = rocket.config();
192         let scheme = if config.tls_enabled() {
193             "https"
194         } else {
195             "http"
196         };
197         let addr = format!("{}://{}:{}", &scheme, &config.address, &config.port);
198         info!(target: "start", "Rocket has launched from {}", addr);
199     }
200 
on_request(&self, request: &mut Request<'_>, _data: &Data)201     fn on_request(&self, request: &mut Request<'_>, _data: &Data) {
202         let method = request.method();
203         if !self.0 && method == Method::Options {
204             return;
205         }
206         let uri = request.uri();
207         let uri_path = uri.path();
208         let uri_subpath = uri_path.strip_prefix(&CONFIG.domain_path()).unwrap_or(uri_path);
209         if self.0 || LOGGED_ROUTES.iter().any(|r| uri_subpath.starts_with(r)) {
210             match uri.query() {
211                 Some(q) => info!(target: "request", "{} {}?{}", method, uri_path, &q[..q.len().min(30)]),
212                 None => info!(target: "request", "{} {}", method, uri_path),
213             };
214         }
215     }
216 
on_response(&self, request: &Request, response: &mut Response)217     fn on_response(&self, request: &Request, response: &mut Response) {
218         if !self.0 && request.method() == Method::Options {
219             return;
220         }
221         let uri_path = request.uri().path();
222         let uri_subpath = uri_path.strip_prefix(&CONFIG.domain_path()).unwrap_or(uri_path);
223         if self.0 || LOGGED_ROUTES.iter().any(|r| uri_subpath.starts_with(r)) {
224             let status = response.status();
225             if let Some(route) = request.route() {
226                 info!(target: "response", "{} => {} {}", route, status.code, status.reason)
227             } else {
228                 info!(target: "response", "{} {}", status.code, status.reason)
229             }
230         }
231     }
232 }
233 
234 //
235 // File handling
236 //
237 use std::{
238     fs::{self, File},
239     io::{Read, Result as IOResult},
240     path::Path,
241 };
242 
/// Returns `true` when metadata for `path` can be read, meaning the path exists and is
/// accessible. Any error, including permission errors, yields `false`.
pub fn file_exists(path: &str) -> bool {
    fs::metadata(path).is_ok()
}
246 
/// Reads the whole file at `path` into a byte vector.
///
/// # Errors
/// Propagates any I/O error from opening or reading the file.
pub fn read_file(path: &str) -> IOResult<Vec<u8>> {
    fs::read(path)
}
255 
write_file(path: &str, content: &[u8]) -> Result<(), crate::error::Error>256 pub fn write_file(path: &str, content: &[u8]) -> Result<(), crate::error::Error> {
257     use std::io::Write;
258     let mut f = File::create(path)?;
259     f.write_all(content)?;
260     f.flush()?;
261     Ok(())
262 }
263 
/// Reads the whole file at `path` as UTF-8 text.
///
/// # Errors
/// Propagates I/O errors. Returns an `InvalidData` error when the contents are not valid UTF-8.
pub fn read_file_string(path: &str) -> IOResult<String> {
    fs::read_to_string(path)
}
272 
/// Removes the file at `path`, then tries to remove its parent directory.
///
/// Removing the parent only succeeds when the directory is empty. Any failure there is
/// deliberately ignored. The returned result is that of the file removal alone.
pub fn delete_file(path: &str) -> IOResult<()> {
    let removal = fs::remove_file(path);

    let parent_dir = Path::new(path).parent();
    if let Some(dir) = parent_dir {
        // Best effort: a non-empty (or missing) directory simply stays as it is.
        let _ = fs::remove_dir(dir);
    }

    removal
}
284 
/// Formats a byte count as a human-readable size with two decimals, e.g. `"1.50 KB"`.
///
/// The value is divided by 1024 while it is at least 1024, so exactly 1024 of a unit is
/// shown as `1.00` of the next unit (e.g. 1024 bytes -> `"1.00 KB"`). Values that are zero or
/// negative are shown in bytes.
pub fn get_display_size(size: i32) -> String {
    const UNITS: [&str; 6] = ["bytes", "KB", "MB", "GB", "TB", "PB"];

    let mut size = f64::from(size);
    let mut unit_index = 0;

    // `>=` (not `>`) so that 1024 rolls over to the next unit. The index is capped at the last
    // unit so `UNITS[unit_index]` can never go out of bounds.
    while size >= 1024.0 && unit_index < UNITS.len() - 1 {
        size /= 1024.0;
        unit_index += 1;
    }

    format!("{:.2} {}", size, UNITS[unit_index])
}
302 
get_uuid() -> String303 pub fn get_uuid() -> String {
304     uuid::Uuid::new_v4().to_string()
305 }
306 
307 //
308 // String util methods
309 //
310 
311 use std::str::FromStr;
312 
/// Returns `s` with its first character converted to uppercase. The rest is unchanged.
///
/// Uses Unicode uppercasing, so a single character may expand (e.g. `ß` becomes `SS`).
/// An empty input yields an empty string.
pub fn upcase_first(s: &str) -> String {
    let mut chars = s.chars();
    if let Some(first) = chars.next() {
        let mut out: String = first.to_uppercase().collect();
        out.push_str(chars.as_str());
        out
    } else {
        String::new()
    }
}
320 
/// Parses an optional string into `T`.
///
/// Returns `None` when the input is absent or when parsing fails. The parse error is discarded.
pub fn try_parse_string<S, T>(string: Option<S>) -> Option<T>
where
    S: AsRef<str>,
    T: FromStr,
{
    string.and_then(|raw| raw.as_ref().parse::<T>().ok())
}
332 
333 //
334 // Env methods
335 //
336 
337 use std::env;
338 
/// Reads a configuration value from the environment.
///
/// The value comes either directly from `key` or from the file named by `{key}_FILE`. A value
/// read from a file is trimmed of surrounding whitespace; a direct value is returned as-is.
/// Returns `None` when neither variable is set (or readable as Unicode).
///
/// # Panics
/// Panics if both `key` and `{key}_FILE` are set, or if the `_FILE` path cannot be read.
pub fn get_env_str_value(key: &str) -> Option<String> {
    let key_file = format!("{}_FILE", key);
    let direct = env::var(key).ok();
    let from_file = env::var(&key_file).ok();

    match (direct, from_file) {
        (Some(_), Some(_)) => panic!("You should not define both {} and {}!", key, key_file),
        (Some(value), None) => Some(value),
        (None, Some(file_path)) => {
            let content = fs::read_to_string(file_path).unwrap_or_else(|e| panic!("Failed to load {}: {:?}", key, e));
            Some(content.trim().to_string())
        }
        (None, None) => None,
    }
}
354 
get_env<V>(key: &str) -> Option<V> where V: FromStr,355 pub fn get_env<V>(key: &str) -> Option<V>
356 where
357     V: FromStr,
358 {
359     try_parse_string(get_env_str_value(key))
360 }
361 
get_env_bool(key: &str) -> Option<bool>362 pub fn get_env_bool(key: &str) -> Option<bool> {
363     const TRUE_VALUES: &[&str] = &["true", "t", "yes", "y", "1"];
364     const FALSE_VALUES: &[&str] = &["false", "f", "no", "n", "0"];
365 
366     match get_env_str_value(key) {
367         Some(val) if TRUE_VALUES.contains(&val.to_lowercase().as_ref()) => Some(true),
368         Some(val) if FALSE_VALUES.contains(&val.to_lowercase().as_ref()) => Some(false),
369         _ => None,
370     }
371 }
372 
373 //
374 // Date util methods
375 //
376 
377 use chrono::{DateTime, Local, NaiveDateTime, TimeZone};
378 
379 /// Formats a UTC-offset `NaiveDateTime` in the format used by Bitwarden API
380 /// responses with "date" fields (`CreationDate`, `RevisionDate`, etc.).
format_date(dt: &NaiveDateTime) -> String381 pub fn format_date(dt: &NaiveDateTime) -> String {
382     dt.format("%Y-%m-%dT%H:%M:%S%.6fZ").to_string()
383 }
384 
385 /// Formats a `DateTime<Local>` using the specified format string.
386 ///
387 /// For a `DateTime<Local>`, the `%Z` specifier normally formats as the
388 /// time zone's UTC offset (e.g., `+00:00`). In this function, if the
389 /// `TZ` environment variable is set, then `%Z` instead formats as the
390 /// abbreviation for that time zone (e.g., `UTC`).
format_datetime_local(dt: &DateTime<Local>, fmt: &str) -> String391 pub fn format_datetime_local(dt: &DateTime<Local>, fmt: &str) -> String {
392     // Try parsing the `TZ` environment variable to enable formatting `%Z` as
393     // a time zone abbreviation.
394     if let Ok(tz) = env::var("TZ") {
395         if let Ok(tz) = tz.parse::<chrono_tz::Tz>() {
396             return dt.with_timezone(&tz).format(fmt).to_string();
397         }
398     }
399 
400     // Otherwise, fall back to formatting `%Z` as a UTC offset.
401     dt.format(fmt).to_string()
402 }
403 
404 /// Formats a UTC-offset `NaiveDateTime` as a datetime in the local time zone.
405 ///
406 /// This function basically converts the `NaiveDateTime` to a `DateTime<Local>`,
407 /// and then calls [format_datetime_local](crate::util::format_datetime_local).
format_naive_datetime_local(dt: &NaiveDateTime, fmt: &str) -> String408 pub fn format_naive_datetime_local(dt: &NaiveDateTime, fmt: &str) -> String {
409     format_datetime_local(&Local.from_utc_datetime(dt), fmt)
410 }
411 
412 //
413 // Deployment environment methods
414 //
415 
/// Returns true if the program is running in Docker or Podman.
///
/// Detection relies on the container marker files `/.dockerenv` and `/run/.containerenv`.
pub fn is_running_in_docker() -> bool {
    const MARKER_FILES: [&str; 2] = ["/.dockerenv", "/run/.containerenv"];
    MARKER_FILES.iter().any(|marker| Path::new(marker).exists())
}
420 
/// Simple check to determine which docker base image vaultwarden is running on.
///
/// We build images based on Debian or Alpine, so those are the distributions detected, via
/// their release marker files, in that order. Anything else is reported as `"Unknown"`.
pub fn docker_base_image() -> String {
    const RELEASE_FILES: [(&str, &str); 2] = [("/etc/debian_version", "Debian"), ("/etc/alpine-release", "Alpine")];

    RELEASE_FILES
        .iter()
        .find(|&&(marker, _)| Path::new(marker).exists())
        .map_or("Unknown", |&(_, distro)| distro)
        .to_string()
}
432 
433 //
434 // Deserialization methods
435 //
436 
437 use std::fmt;
438 
439 use serde::de::{self, DeserializeOwned, Deserializer, MapAccess, SeqAccess, Visitor};
440 use serde_json::{self, Value};
441 
442 pub type JsonMap = serde_json::Map<String, Value>;
443 
444 #[derive(Serialize, Deserialize)]
445 pub struct UpCase<T: DeserializeOwned> {
446     #[serde(deserialize_with = "upcase_deserialize")]
447     #[serde(flatten)]
448     pub data: T,
449 }
450 
451 // https://github.com/serde-rs/serde/issues/586
upcase_deserialize<'de, T, D>(deserializer: D) -> Result<T, D::Error> where T: DeserializeOwned, D: Deserializer<'de>,452 pub fn upcase_deserialize<'de, T, D>(deserializer: D) -> Result<T, D::Error>
453 where
454     T: DeserializeOwned,
455     D: Deserializer<'de>,
456 {
457     let d = deserializer.deserialize_any(UpCaseVisitor)?;
458     T::deserialize(d).map_err(de::Error::custom)
459 }
460 
461 struct UpCaseVisitor;
462 
463 impl<'de> Visitor<'de> for UpCaseVisitor {
464     type Value = Value;
465 
expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result466     fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
467         formatter.write_str("an object or an array")
468     }
469 
visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error> where A: MapAccess<'de>,470     fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
471     where
472         A: MapAccess<'de>,
473     {
474         let mut result_map = JsonMap::new();
475 
476         while let Some((key, value)) = map.next_entry()? {
477             result_map.insert(upcase_first(key), upcase_value(value));
478         }
479 
480         Ok(Value::Object(result_map))
481     }
482 
visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error> where A: SeqAccess<'de>,483     fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
484     where
485         A: SeqAccess<'de>,
486     {
487         let mut result_seq = Vec::<Value>::new();
488 
489         while let Some(value) = seq.next_element()? {
490             result_seq.push(upcase_value(value));
491         }
492 
493         Ok(Value::Array(result_seq))
494     }
495 }
496 
upcase_value(value: Value) -> Value497 fn upcase_value(value: Value) -> Value {
498     if let Value::Object(map) = value {
499         let mut new_value = json!({});
500 
501         for (key, val) in map.into_iter() {
502             let processed_key = _process_key(&key);
503             new_value[processed_key] = upcase_value(val);
504         }
505         new_value
506     } else if let Value::Array(array) = value {
507         // Initialize array with null values
508         let mut new_value = json!(vec![Value::Null; array.len()]);
509 
510         for (index, val) in array.into_iter().enumerate() {
511             new_value[index] = upcase_value(val);
512         }
513         new_value
514     } else {
515         value
516     }
517 }
518 
519 // Inner function to handle some speciale case for the 'ssn' key.
520 // This key is part of the Identity Cipher (Social Security Number)
_process_key(key: &str) -> String521 fn _process_key(key: &str) -> String {
522     match key.to_lowercase().as_ref() {
523         "ssn" => "SSN".into(),
524         _ => self::upcase_first(key),
525     }
526 }
527 
528 //
529 // Retry methods
530 //
531 
/// Calls `func` until it succeeds or `max_tries` attempts have failed.
///
/// Sleeps 500 ms between failed attempts. Returns the first `Ok`, or the last `Err` once the
/// attempt count reaches `max_tries`. A `max_tries` of 0 behaves like 1 (a single attempt).
pub fn retry<F, T, E>(func: F, max_tries: u32) -> Result<T, E>
where
    F: Fn() -> Result<T, E>,
{
    let mut attempts = 0;

    loop {
        let outcome = func();
        if outcome.is_ok() {
            return outcome;
        }

        attempts += 1;
        if attempts >= max_tries {
            return outcome;
        }

        sleep(Duration::from_millis(500));
    }
}
553 
554 use std::{thread::sleep, time::Duration};
555 
retry_db<F, T, E>(func: F, max_tries: u32) -> Result<T, E> where F: Fn() -> Result<T, E>, E: std::error::Error,556 pub fn retry_db<F, T, E>(func: F, max_tries: u32) -> Result<T, E>
557 where
558     F: Fn() -> Result<T, E>,
559     E: std::error::Error,
560 {
561     let mut tries = 0;
562 
563     loop {
564         match func() {
565             ok @ Ok(_) => return ok,
566             Err(e) => {
567                 tries += 1;
568 
569                 if tries >= max_tries && max_tries > 0 {
570                     return Err(e);
571                 }
572 
573                 warn!("Can't connect to database, retrying: {:?}", e);
574 
575                 sleep(Duration::from_millis(1_000));
576             }
577         }
578     }
579 }
580 
581 use reqwest::{
582     blocking::{Client, ClientBuilder},
583     header,
584 };
585 
get_reqwest_client() -> Client586 pub fn get_reqwest_client() -> Client {
587     get_reqwest_client_builder().build().expect("Failed to build client")
588 }
589 
get_reqwest_client_builder() -> ClientBuilder590 pub fn get_reqwest_client_builder() -> ClientBuilder {
591     let mut headers = header::HeaderMap::new();
592     headers.insert(header::USER_AGENT, header::HeaderValue::from_static("Vaultwarden"));
593     Client::builder().default_headers(headers).timeout(Duration::from_secs(10))
594 }
595