1 use std::io;
2 use std::fmt;
3 use std::ffi::CStr;
4 use std::fs::File;
5 use std::io::Read;
6 use std::path::{Path, PathBuf};
7 use std::env::current_dir;
8 use std::default::Default;
9 
10 use nix::mount::{MsFlags, mount};
11 
12 use {OSError, Error};
13 use util::path_to_cstring;
14 use explain::{Explainable, exists, user};
15 use mountinfo::{parse_mount_point};
16 
/// A remount definition
///
/// Usually it is used to change mount flags for a mounted filesystem.
/// Especially to make a readonly filesystem writable or vice versa.
#[derive(Debug, Clone)]
pub struct Remount {
    /// Path of the mount point whose flags are going to be changed
    path: PathBuf,
    /// Requested flag changes; flags that were never set are left as is
    flags: MountFlags,
}
26 
/// Requested changes for individual mount flags
///
/// Every field is tri-state: `Some(true)` turns the flag on, `Some(false)`
/// turns it off and `None` (the default) keeps whatever value the flag
/// already has.
#[derive(Debug, Clone, Default)]
struct MountFlags {
    /// Maps to `MS_BIND`
    pub bind: Option<bool>,
    /// Maps to `MS_RDONLY`
    pub readonly: Option<bool>,
    /// Maps to `MS_NODEV`
    pub nodev: Option<bool>,
    /// Maps to `MS_NOEXEC`
    pub noexec: Option<bool>,
    /// Maps to `MS_NOSUID`
    pub nosuid: Option<bool>,
    /// Maps to `MS_NOATIME`
    pub noatime: Option<bool>,
    /// Maps to `MS_NODIRATIME`
    pub nodiratime: Option<bool>,
    /// Maps to `MS_RELATIME`
    pub relatime: Option<bool>,
    /// Maps to `MS_STRICTATIME`
    pub strictatime: Option<bool>,
    /// Maps to `MS_DIRSYNC`
    pub dirsync: Option<bool>,
    /// Maps to `MS_SYNCHRONOUS`
    pub synchronous: Option<bool>,
    /// Maps to `MS_MANDLOCK`
    pub mandlock: Option<bool>,
}
42 
43 impl MountFlags {
apply_to_flags(&self, flags: MsFlags) -> MsFlags44     fn apply_to_flags(&self, flags: MsFlags) -> MsFlags {
45         let mut flags = flags;
46         flags = apply_flag(flags, MsFlags::MS_BIND, self.bind);
47         flags = apply_flag(flags, MsFlags::MS_RDONLY, self.readonly);
48         flags = apply_flag(flags, MsFlags::MS_NODEV, self.nodev);
49         flags = apply_flag(flags, MsFlags::MS_NOEXEC, self.noexec);
50         flags = apply_flag(flags, MsFlags::MS_NOSUID, self.nosuid);
51         flags = apply_flag(flags, MsFlags::MS_NOATIME, self.noatime);
52         flags = apply_flag(flags, MsFlags::MS_NODIRATIME, self.nodiratime);
53         flags = apply_flag(flags, MsFlags::MS_RELATIME, self.relatime);
54         flags = apply_flag(flags, MsFlags::MS_STRICTATIME, self.strictatime);
55         flags = apply_flag(flags, MsFlags::MS_DIRSYNC, self.dirsync);
56         flags = apply_flag(flags, MsFlags::MS_SYNCHRONOUS, self.synchronous);
57         flags = apply_flag(flags, MsFlags::MS_MANDLOCK, self.mandlock);
58         flags
59     }
60 }
61 
apply_flag(flags: MsFlags, flag: MsFlags, set: Option<bool>) -> MsFlags62 fn apply_flag(flags: MsFlags, flag: MsFlags, set: Option<bool>) -> MsFlags {
63     match set {
64         Some(true) => flags | flag,
65         Some(false) => flags & !flag,
66         None => flags,
67     }
68 }
69 
quick_error! {
    /// Errors that can happen while looking up the current flags of a
    /// mount point
    #[derive(Debug)]
    pub enum RemountError {
        /// An I/O operation failed; `msg` describes what was being attempted
        Io(msg: String, err: io::Error) {
            cause(err)
            display("{}: {}", msg, err)
            description(err.description())
            // A bare `io::Error` is converted with an empty message
            from(err: io::Error) -> (String::new(), err)
        }
        /// The mountinfo data could not be parsed; carries the parser message
        ParseMountInfo(err: String) {
            display("{}", err)
            from()
        }
        /// No mountinfo entry has the given path as its mount point
        UnknownMountPoint(path: PathBuf) {
            display("Cannot find mount point: {:?}", path)
        }
    }
}
88 
89 impl Remount {
90     /// Create a new Remount operation
91     ///
92     /// By default it doesn't modify any flags. So is basically useless, you
93     /// should set some flags to make it effective.
new<A: AsRef<Path>>(path: A) -> Remount94     pub fn new<A: AsRef<Path>>(path: A) -> Remount {
95         Remount {
96             path: path.as_ref().to_path_buf(),
97             flags: Default::default(),
98         }
99     }
100     /// Set bind flag
101     /// Note: remount readonly doesn't work without MS_BIND flag
102     /// inside unpriviledged user namespaces
bind(mut self, flag: bool) -> Remount103     pub fn bind(mut self, flag: bool) -> Remount {
104         self.flags.bind = Some(flag);
105         self
106     }
107     /// Set readonly flag
readonly(mut self, flag: bool) -> Remount108     pub fn readonly(mut self, flag: bool) -> Remount {
109         self.flags.readonly = Some(flag);
110         self
111     }
112     /// Set nodev flag
nodev(mut self, flag: bool) -> Remount113     pub fn nodev(mut self, flag: bool) -> Remount {
114         self.flags.nodev = Some(flag);
115         self
116     }
117     /// Set noexec flag
noexec(mut self, flag: bool) -> Remount118     pub fn noexec(mut self, flag: bool) -> Remount {
119         self.flags.noexec = Some(flag);
120         self
121     }
122     /// Set nosuid flag
nosuid(mut self, flag: bool) -> Remount123     pub fn nosuid(mut self, flag: bool) -> Remount {
124         self.flags.nosuid = Some(flag);
125         self
126     }
127     /// Set noatime flag
noatime(mut self, flag: bool) -> Remount128     pub fn noatime(mut self, flag: bool) -> Remount {
129         self.flags.noatime = Some(flag);
130         self
131     }
132     /// Set nodiratime flag
nodiratime(mut self, flag: bool) -> Remount133     pub fn nodiratime(mut self, flag: bool) -> Remount {
134         self.flags.nodiratime = Some(flag);
135         self
136     }
137     /// Set relatime flag
relatime(mut self, flag: bool) -> Remount138     pub fn relatime(mut self, flag: bool) -> Remount {
139         self.flags.relatime = Some(flag);
140         self
141     }
142     /// Set strictatime flag
strictatime(mut self, flag: bool) -> Remount143     pub fn strictatime(mut self, flag: bool) -> Remount {
144         self.flags.strictatime = Some(flag);
145         self
146     }
147     /// Set dirsync flag
dirsync(mut self, flag: bool) -> Remount148     pub fn dirsync(mut self, flag: bool) -> Remount {
149         self.flags.dirsync = Some(flag);
150         self
151     }
152     /// Set synchronous flag
synchronous(mut self, flag: bool) -> Remount153     pub fn synchronous(mut self, flag: bool) -> Remount {
154         self.flags.synchronous = Some(flag);
155         self
156     }
157     /// Set mandlock flag
mandlock(mut self, flag: bool) -> Remount158     pub fn mandlock(mut self, flag: bool) -> Remount {
159         self.flags.mandlock = Some(flag);
160         self
161     }
162 
163     /// Execute a remount
bare_remount(self) -> Result<(), OSError>164     pub fn bare_remount(self) -> Result<(), OSError> {
165         let mut flags = match get_mountpoint_flags(&self.path) {
166             Ok(flags) => flags,
167             Err(e) => {
168                 return Err(OSError::from_remount(e, Box::new(self)));
169             },
170         };
171         flags = self.flags.apply_to_flags(flags) | MsFlags::MS_REMOUNT;
172         mount(
173             None::<&CStr>,
174             &*path_to_cstring(&self.path),
175             None::<&CStr>,
176             flags,
177             None::<&CStr>,
178         ).map_err(|err| OSError::from_nix(err, Box::new(self)))
179     }
180 
181     /// Execute a remount and explain the error immediately
remount(self) -> Result<(), Error>182     pub fn remount(self) -> Result<(), Error> {
183         self.bare_remount().map_err(OSError::explain)
184     }
185 }
186 
187 impl fmt::Display for MountFlags {
fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result188     fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
189         let mut prefix = "";
190         if let Some(true) = self.bind {
191             try!(write!(fmt, "{}bind", prefix));
192             prefix = ",";
193         }
194         if let Some(true) = self.readonly {
195             try!(write!(fmt, "{}ro", prefix));
196             prefix = ",";
197         }
198         if let Some(true) = self.nodev {
199             try!(write!(fmt, "{}nodev", prefix));
200             prefix = ",";
201         }
202         if let Some(true) = self.noexec {
203             try!(write!(fmt, "{}noexec", prefix));
204             prefix = ",";
205         }
206         if let Some(true) = self.nosuid {
207             try!(write!(fmt, "{}nosuid", prefix));
208             prefix = ",";
209         }
210         if let Some(true) = self.noatime {
211             try!(write!(fmt, "{}noatime", prefix));
212             prefix = ",";
213         }
214         if let Some(true) = self.nodiratime {
215             try!(write!(fmt, "{}nodiratime", prefix));
216             prefix = ",";
217         }
218         if let Some(true) = self.relatime {
219             try!(write!(fmt, "{}relatime", prefix));
220             prefix = ",";
221         }
222         if let Some(true) = self.strictatime {
223             try!(write!(fmt, "{}strictatime", prefix));
224             prefix = ",";
225         }
226         if let Some(true) = self.dirsync {
227             try!(write!(fmt, "{}dirsync", prefix));
228             prefix = ",";
229         }
230         if let Some(true) = self.synchronous {
231             try!(write!(fmt, "{}sync", prefix));
232             prefix = ",";
233         }
234         if let Some(true) = self.mandlock {
235             try!(write!(fmt, "{}mand", prefix));
236         }
237         Ok(())
238     }
239 }
240 
241 impl fmt::Display for Remount {
fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result242     fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
243         if !self.flags.apply_to_flags(MsFlags::empty()).is_empty() {
244             try!(write!(fmt, "{} ", self.flags));
245         }
246         write!(fmt, "remount {:?}", &self.path)
247     }
248 }
249 
250 impl Explainable for Remount {
explain(&self) -> String251     fn explain(&self) -> String {
252         [
253             format!("path: {}", exists(&self.path)),
254             format!("{}", user()),
255         ].join(", ")
256     }
257 }
258 
get_mountpoint_flags(path: &Path) -> Result<MsFlags, RemountError>259 fn get_mountpoint_flags(path: &Path) -> Result<MsFlags, RemountError> {
260     let mount_path = if path.is_absolute() {
261         path.to_path_buf()
262     } else {
263         let mut mpath = try!(current_dir());
264         mpath.push(path);
265         mpath
266     };
267     let mut mountinfo_content = Vec::with_capacity(4 * 1024);
268     let mountinfo_path = Path::new("/proc/self/mountinfo");
269     let mut mountinfo_file = try!(File::open(mountinfo_path)
270         .map_err(|e| RemountError::Io(
271             format!("Cannot open file: {:?}", mountinfo_path), e)));
272     try!(mountinfo_file.read_to_end(&mut mountinfo_content)
273         .map_err(|e| RemountError::Io(
274             format!("Cannot read file: {:?}", mountinfo_path), e)));
275     match get_mountpoint_flags_from(&mountinfo_content, &mount_path) {
276         Ok(Some(flags)) => Ok(flags),
277         Ok(None) => Err(RemountError::UnknownMountPoint(mount_path)),
278         Err(e) => Err(e),
279     }
280 }
281 
get_mountpoint_flags_from(content: &[u8], path: &Path) -> Result<Option<MsFlags>, RemountError>282 fn get_mountpoint_flags_from(content: &[u8], path: &Path)
283     -> Result<Option<MsFlags>, RemountError>
284 {
285     // iterate from the end of the mountinfo file
286     for line in content.split(|c| *c == b'\n').rev() {
287         let entry = parse_mount_point(line)
288             .map_err(|e| RemountError::ParseMountInfo(e.0))?;
289         if let Some(mount_point) = entry {
290             if mount_point.mount_point == path {
291                 return Ok(Some(mount_point.get_mount_flags()));
292             }
293         }
294     }
295     Ok(None)
296 }
297 
#[cfg(test)]
mod test {
    use std::path::Path;
    use std::ffi::OsStr;
    use std::os::unix::ffi::OsStrExt;

    use nix::mount::MsFlags;

    use Error;
    use super::{Remount, RemountError, MountFlags};
    use super::{get_mountpoint_flags, get_mountpoint_flags_from};

    // Checks all three states of every field: `Some(true)` sets the bit,
    // `Some(false)` clears it and `None` leaves the input untouched.
    #[test]
    fn test_mount_flags() {
        let flags = MountFlags {
            bind: Some(true),
            readonly: Some(true),
            nodev: Some(true),
            noexec: Some(true),
            nosuid: Some(true),
            noatime: Some(true),
            nodiratime: Some(true),
            relatime: Some(true),
            strictatime: Some(true),
            dirsync: Some(true),
            synchronous: Some(true),
            mandlock: Some(true),
        };
        let bits = (MsFlags::MS_BIND | MsFlags::MS_RDONLY | MsFlags::MS_NODEV | MsFlags::MS_NOEXEC | MsFlags::MS_NOSUID |
            MsFlags::MS_NOATIME | MsFlags::MS_NODIRATIME | MsFlags::MS_RELATIME | MsFlags::MS_STRICTATIME |
            MsFlags::MS_DIRSYNC | MsFlags::MS_SYNCHRONOUS | MsFlags::MS_MANDLOCK).bits();
        assert_eq!(flags.apply_to_flags(MsFlags::empty()).bits(), bits);

        let flags = MountFlags {
            bind: Some(false),
            readonly: Some(false),
            nodev: Some(false),
            noexec: Some(false),
            nosuid: Some(false),
            noatime: Some(false),
            nodiratime: Some(false),
            relatime: Some(false),
            strictatime: Some(false),
            dirsync: Some(false),
            synchronous: Some(false),
            mandlock: Some(false),
        };
        assert_eq!(flags.apply_to_flags(MsFlags::from_bits_truncate(bits)).bits(), 0);

        let flags = MountFlags::default();
        assert_eq!(flags.apply_to_flags(MsFlags::from_bits_truncate(0)).bits(), 0);
        assert_eq!(flags.apply_to_flags(MsFlags::from_bits_truncate(bits)).bits(), bits);
    }

    // Checks the `Display` output of `Remount` with and without flags set.
    #[test]
    fn test_remount() {
        let remount = Remount::new("/");
        assert_eq!(format!("{}", remount), "remount \"/\"");

        let remount = Remount::new("/").readonly(true).nodev(true);
        assert_eq!(format!("{}", remount), "ro,nodev remount \"/\"");
    }

    // A single mountinfo line without a trailing newline is parsed and its
    // per-mount options are returned as flags.
    #[test]
    fn test_get_mountpoint_flags_from() {
        let content = b"19 24 0:4 / /proc rw,nosuid,nodev,noexec,relatime shared:12 - proc proc rw";
        let flags = get_mountpoint_flags_from(&content[..], Path::new("/proc")).unwrap().unwrap();
        assert_eq!(flags, MsFlags::MS_NODEV | MsFlags::MS_NOEXEC | MsFlags::MS_NOSUID | MsFlags::MS_RELATIME);
    }

    // When a mount point is listed twice, the later line must win.
    #[test]
    fn test_get_mountpoint_flags_from_dups() {
        let content = b"11 18 0:4 / /tmp rw shared:28 - tmpfs tmpfs rw\n\
                        12 18 0:6 / /tmp rw,nosuid,nodev shared:29 - tmpfs tmpfs rw\n";
        let flags = get_mountpoint_flags_from(&content[..], Path::new("/tmp")).unwrap().unwrap();
        assert_eq!(flags, MsFlags::MS_NOSUID | MsFlags::MS_NODEV);
    }

    // Reads the real /proc/self/mountinfo of the test process; expects "/"
    // to be listed there.
    #[test]
    fn test_get_mountpoint_flags() {
        assert!(get_mountpoint_flags(Path::new("/")).is_ok());
    }

    // Looks up a path built from a non-UTF-8 byte, which is expected not to
    // be a mount point, and checks that the error carries that same path.
    #[test]
    fn test_get_mountpoint_flags_unknown() {
        let mount_point = Path::new(OsStr::from_bytes(b"/\xff"));
        let error = get_mountpoint_flags(mount_point).unwrap_err();
        match error {
            RemountError::UnknownMountPoint(p) => assert_eq!(p, mount_point),
            _ => panic!(),
        }
    }

    // End-to-end failure path: the lookup error must reach the caller both
    // as the wrapped cause and as the start of the explained message.
    #[test]
    fn test_remount_unknown_mountpoint() {
        let remount = Remount::new("/non-existent");
        let error = remount.remount().unwrap_err();
        let Error(_, e, msg) = error;
        match e.get_ref() {
            Some(e) => {
                assert_eq!(
                   e.to_string(),
                   "Cannot find mount point: \"/non-existent\"");
            },
            _ => panic!(),
        }
        assert!(msg.starts_with(
            "Cannot find mount point: \"/non-existent\", path: missing, "));
    }
}
408