1 use clap::{Parser, ValueHint};
2 use clap_generate::Shell;
3 use http::header::{HeaderMap, HeaderName, HeaderValue};
4 use std::net::IpAddr;
5 use std::path::PathBuf;
6 
7 use crate::auth;
8 use crate::errors::ContextualError;
9 use crate::renderer;
10 
11 #[derive(Parser)]
12 #[clap(name = "miniserve", author, about, version)]
13 pub struct CliArgs {
14     /// Be verbose, includes emitting access logs
15     #[clap(short = 'v', long = "verbose")]
16     pub verbose: bool,
17 
18     /// Which path to serve
19     #[clap(name = "PATH", parse(from_os_str), value_hint = ValueHint::AnyPath)]
20     pub path: Option<PathBuf>,
21 
22     /// The name of a directory index file to serve, like "index.html"
23     ///
24     /// Normally, when miniserve serves a directory, it creates a listing for that directory.
25     /// However, if a directory contains this file, miniserve will serve that file instead.
26     #[clap(long, parse(from_os_str), name = "index_file", value_hint = ValueHint::FilePath)]
27     pub index: Option<PathBuf>,
28 
29     /// Activate SPA (Single Page Application) mode
30     ///
31     /// This will cause the file given by --index to be served for all non-existing file paths. In
32     /// effect, this will serve the index file whenever a 404 would otherwise occur in order to
33     /// allow the SPA router to handle the request instead.
34     #[clap(long, requires = "index_file")]
35     pub spa: bool,
36 
37     /// Port to use
38     #[clap(short = 'p', long = "port", default_value = "8080")]
39     pub port: u16,
40 
41     /// Interface to listen on
42     #[clap(
43         short = 'i',
44         long = "interfaces",
45         parse(try_from_str = parse_interface),
46         multiple_occurrences(true),
47         number_of_values = 1,
48     )]
49     pub interfaces: Vec<IpAddr>,
50 
51     /// Set authentication. Currently supported formats:
52     /// username:password, username:sha256:hash, username:sha512:hash
53     /// (e.g. joe:123, joe:sha256:a665a45920422f9d417e4867efdc4fb8a04a1f3fff1fa07e998e86f7f7a27ae3)
54     #[clap(
55         short = 'a',
56         long = "auth",
57         parse(try_from_str = parse_auth),
58         multiple_occurrences(true),
59         number_of_values = 1,
60     )]
61     pub auth: Vec<auth::RequiredAuth>,
62 
63     /// Generate a random 6-hexdigit route
64     #[clap(long = "random-route")]
65     pub random_route: bool,
66 
67     /// Do not follow symbolic links
68     #[clap(short = 'P', long = "no-symlinks")]
69     pub no_symlinks: bool,
70 
71     /// Show hidden files
72     #[clap(short = 'H', long = "hidden")]
73     pub hidden: bool,
74 
75     /// Default color scheme
76     #[clap(
77         short = 'c',
78         long = "color-scheme",
79         default_value = "squirrel",
80         possible_values = &*renderer::THEME_SLUGS,
81         case_insensitive = true,
82     )]
83     pub color_scheme: String,
84 
85     /// Default color scheme
86     #[clap(
87         short = 'd',
88         long = "color-scheme-dark",
89         default_value = "archlinux",
90         possible_values = &*renderer::THEME_SLUGS,
91         case_insensitive = true,
92     )]
93     pub color_scheme_dark: String,
94 
95     /// Enable QR code display
96     #[clap(short = 'q', long = "qrcode")]
97     pub qrcode: bool,
98 
99     /// Enable file uploading
100     #[clap(short = 'u', long = "upload-files")]
101     pub file_upload: bool,
102 
103     /// Enable overriding existing files during file upload
104     #[clap(short = 'o', long = "overwrite-files")]
105     pub overwrite_files: bool,
106 
107     /// Enable uncompressed tar archive generation
108     #[clap(short = 'r', long = "enable-tar")]
109     pub enable_tar: bool,
110 
111     /// Enable gz-compressed tar archive generation
112     #[clap(short = 'g', long = "enable-tar-gz")]
113     pub enable_tar_gz: bool,
114 
115     /// Enable zip archive generation
116     ///
117     /// WARNING: Zipping large directories can result in out-of-memory exception
118     /// because zip generation is done in memory and cannot be sent on the fly
119     #[clap(short = 'z', long = "enable-zip")]
120     pub enable_zip: bool,
121 
122     /// List directories first
123     #[clap(short = 'D', long = "dirs-first")]
124     pub dirs_first: bool,
125 
126     /// Shown instead of host in page title and heading
127     #[clap(short = 't', long = "title")]
128     pub title: Option<String>,
129 
130     /// Set custom header for responses
131     #[clap(
132         long = "header",
133         parse(try_from_str = parse_header),
134         multiple_occurrences(true),
135         number_of_values = 1
136     )]
137     pub header: Vec<HeaderMap>,
138 
139     /// Show symlink info
140     #[clap(short = 'l', long = "show-symlink-info")]
141     pub show_symlink_info: bool,
142 
143     /// Hide version footer
144     #[clap(short = 'F', long = "hide-version-footer")]
145     pub hide_version_footer: bool,
146 
147     /// If enabled, display a wget command to recursively download the current directory
148     #[clap(short = 'W', long = "show-wget-footer")]
149     pub show_wget_footer: bool,
150 
151     /// Generate completion file for a shell
152     #[clap(long = "print-completions", value_name = "shell", arg_enum)]
153     pub print_completions: Option<Shell>,
154 
155     /// TLS certificate to use
156     #[cfg(feature = "tls")]
157     #[clap(long = "tls-cert", requires = "tls-key", value_hint = ValueHint::FilePath)]
158     pub tls_cert: Option<PathBuf>,
159 
160     /// TLS private key to use
161     #[cfg(feature = "tls")]
162     #[clap(long = "tls-key", requires = "tls-cert", value_hint = ValueHint::FilePath)]
163     pub tls_key: Option<PathBuf>,
164 }
165 
/// Checks whether an interface is valid, i.e. it can be parsed into an IP address.
///
/// Both IPv4 (e.g. `127.0.0.1`) and IPv6 (e.g. `::1`) literals are accepted; host names such as
/// `localhost` are rejected with the standard [`std::net::AddrParseError`].
fn parse_interface(src: &str) -> Result<IpAddr, std::net::AddrParseError> {
    // The target type is inferred from the function's return type.
    str::parse(src)
}
170 
171 /// Parse authentication requirement
parse_auth(src: &str) -> Result<auth::RequiredAuth, ContextualError>172 fn parse_auth(src: &str) -> Result<auth::RequiredAuth, ContextualError> {
173     let mut split = src.splitn(3, ':');
174     let invalid_auth_format = Err(ContextualError::InvalidAuthFormat);
175 
176     let username = match split.next() {
177         Some(username) => username,
178         None => return invalid_auth_format,
179     };
180 
181     // second_part is either password in username:password or method in username:method:hash
182     let second_part = match split.next() {
183         // This allows empty passwords, as the spec does not forbid it
184         Some(password) => password,
185         None => return invalid_auth_format,
186     };
187 
188     let password = if let Some(hash_hex) = split.next() {
189         let hash_bin = hex::decode(hash_hex).map_err(|_| ContextualError::InvalidPasswordHash)?;
190 
191         match second_part {
192             "sha256" => auth::RequiredAuthPassword::Sha256(hash_bin),
193             "sha512" => auth::RequiredAuthPassword::Sha512(hash_bin),
194             _ => return Err(ContextualError::InvalidHashMethod(second_part.to_owned())),
195         }
196     } else {
197         // To make it Windows-compatible, the password needs to be shorter than 255 characters.
198         // After 255 characters, Windows will truncate the value.
199         // As for the username, the spec does not mention a limit in length
200         if second_part.len() > 255 {
201             return Err(ContextualError::PasswordTooLongError);
202         }
203 
204         auth::RequiredAuthPassword::Plain(second_part.to_owned())
205     };
206 
207     Ok(auth::RequiredAuth {
208         username: username.to_owned(),
209         password,
210     })
211 }
212 
213 /// Custom header parser (allow multiple headers input)
parse_header(src: &str) -> Result<HeaderMap, httparse::Error>214 pub fn parse_header(src: &str) -> Result<HeaderMap, httparse::Error> {
215     let mut headers = [httparse::EMPTY_HEADER; 1];
216     let header = format!("{}\n", src);
217     httparse::parse_headers(header.as_bytes(), &mut headers)?;
218 
219     let mut header_map = HeaderMap::new();
220     if let Some(h) = headers.first() {
221         if h.name != httparse::EMPTY_HEADER.name {
222             header_map.insert(
223                 HeaderName::from_bytes(h.name.as_bytes()).unwrap(),
224                 HeaderValue::from_bytes(h.value).unwrap(),
225             );
226         }
227     }
228 
229     Ok(header_map)
230 }
231 
#[rustfmt::skip]
#[cfg(test)]
mod tests {
    use super::*;
    use rstest::rstest;
    use pretty_assertions::assert_eq;

    /// Helper function that creates a `RequiredAuth` structure
    ///
    /// `encrypt` selects the password variant: "plain" stores `password` verbatim, while
    /// "sha256"/"sha512" hex-decode it into the corresponding hash variant.
    fn create_required_auth(username: &str, password: &str, encrypt: &str) -> auth::RequiredAuth {
        use auth::*;
        use RequiredAuthPassword::*;

        let password = match encrypt {
            "plain" => Plain(password.to_owned()),
            "sha256" => Sha256(hex::decode(password.to_owned()).unwrap()),
            "sha512" => Sha512(hex::decode(password.to_owned()).unwrap()),
            _ => panic!("Unknown encryption type"),
        };

        auth::RequiredAuth {
            username: username.to_owned(),
            password,
        }
    }

    /// Well-formed credential strings parse into the expected `RequiredAuth`.
    #[rstest(
        auth_string, username, password, encrypt,
        case("username:password", "username", "password", "plain"),
        case("username:sha256:abcd", "username", "abcd", "sha256"),
        case("username:sha512:abcd", "username", "abcd", "sha512"),
        // Empty plain passwords are deliberately accepted by `parse_auth`
        case("username:", "username", "", "plain")
    )]
    fn parse_auth_valid(auth_string: &str, username: &str, password: &str, encrypt: &str) {
        assert_eq!(
            parse_auth(auth_string).unwrap(),
            create_required_auth(username, password, encrypt),
        );
    }

    /// Malformed credential strings are rejected with a descriptive error message.
    #[rstest(
        auth_string, err_msg,
        case(
            "foo",
            "Invalid format for credentials string. Expected username:password, username:sha256:hash or username:sha512:hash"
        ),
        case(
            "username:blahblah:abcd",
            "blahblah is not a valid hashing method. Expected sha256 or sha512"
        ),
        case(
            "username:sha256:invalid",
            "Invalid format for password hash. Expected hex code"
        ),
        case(
            "username:sha512:invalid",
            "Invalid format for password hash. Expected hex code"
        ),
    )]
    fn parse_auth_invalid(auth_string: &str, err_msg: &str) {
        let err = parse_auth(auth_string).unwrap_err();
        assert_eq!(format!("{}", err), err_msg.to_owned());
    }
}
294