1 // Inspired by Clang's clang-format-diff:
2 //
3 // https://github.com/llvm-mirror/clang/blob/master/tools/clang-format/clang-format-diff.py
4 
5 #![deny(warnings)]
6 
7 #[macro_use]
8 extern crate log;
9 
10 use serde::{Deserialize, Serialize};
11 use serde_json as json;
12 use thiserror::Error;
13 
14 use std::collections::HashSet;
15 use std::env;
16 use std::ffi::OsStr;
17 use std::io::{self, BufRead};
18 use std::process;
19 
20 use regex::Regex;
21 
22 use structopt::clap::AppSettings;
23 use structopt::StructOpt;
24 
/// The default pattern of files to format.
///
/// We only want to format Rust files by default. This is the default value of
/// the `--filter` option; `scan_diff` wraps the filter in `^...$`, so the
/// pattern has to match the whole (prefix-stripped) path.
const DEFAULT_PATTERN: &str = r".*\.rs";
29 
/// Errors that can abort a `rustfmt-format-diff` run.
///
/// Every variant wraps an underlying error and displays it unchanged
/// (`#[error("{0}")]`). The `#[from]` attributes generate the `From`
/// conversions that let `?` turn the wrapped error types into this enum.
#[derive(Error, Debug)]
enum FormatDiffError {
    /// A command-line parsing failure reported by `getopts`.
    // NOTE(review): nothing visible in this file constructs this variant;
    // options are parsed through `structopt`. Confirm it is still needed.
    #[error("{0}")]
    IncorrectOptions(#[from] getopts::Fail),
    /// A regular expression (such as the `--filter` pattern) failed to compile.
    #[error("{0}")]
    IncorrectFilter(#[from] regex::Error),
    /// An I/O failure. `run_rustfmt` also uses this variant to report a
    /// non-zero `rustfmt` exit status.
    #[error("{0}")]
    IoError(#[from] io::Error),
}
39 
// Command-line options of the tool, parsed by `structopt`.
//
// Plain `//` comments are used for the maintainer notes in this block on
// purpose: `structopt` turns `///` doc comments into user-visible help text,
// so the existing `///` lines on the fields are part of the program's output.
#[derive(StructOpt, Debug)]
#[structopt(
    name = "rustfmt-format-diff",
    setting = AppSettings::DisableVersion,
    setting = AppSettings::NextLineHelp
)]
pub struct Opts {
    /// Skip the smallest prefix containing NUMBER slashes
    // Forwarded to `scan_diff` as `skip_prefix`; being a `u32`, negative
    // values are rejected at parse time (see the `negative_filter` test).
    #[structopt(
        short = "p",
        long = "skip-prefix",
        value_name = "NUMBER",
        default_value = "0"
    )]
    skip_prefix: u32,

    /// Custom pattern selecting file paths to reformat
    // A regular expression, forwarded to `scan_diff` as `file_filter`;
    // defaults to `DEFAULT_PATTERN`.
    #[structopt(
        short = "f",
        long = "filter",
        value_name = "PATTERN",
        default_value = DEFAULT_PATTERN
    )]
    filter: String,
}
65 
main()66 fn main() {
67     env_logger::Builder::from_env("RUSTFMT_LOG").init();
68     let opts = Opts::from_args();
69     if let Err(e) = run(opts) {
70         println!("{}", e);
71         Opts::clap().print_help().expect("cannot write to stdout");
72         process::exit(1);
73     }
74 }
75 
/// A span of lines in one file that should be reformatted.
///
/// `run_rustfmt` serializes a slice of these to JSON and passes it to
/// `rustfmt` as the value of `--file-lines`.
#[derive(Debug, Eq, PartialEq, Serialize, Deserialize)]
struct Range {
    /// Path of the file, as captured from the diff's `+++` header after the
    /// requested prefix has been stripped.
    file: String,
    /// `[first, last]` line numbers of the span; `scan_diff` computes `last`
    /// as `first + count - 1`, so both ends are inclusive.
    range: [u32; 2],
}
81 
run(opts: Opts) -> Result<(), FormatDiffError>82 fn run(opts: Opts) -> Result<(), FormatDiffError> {
83     let (files, ranges) = scan_diff(io::stdin(), opts.skip_prefix, &opts.filter)?;
84     run_rustfmt(&files, &ranges)
85 }
86 
run_rustfmt(files: &HashSet<String>, ranges: &[Range]) -> Result<(), FormatDiffError>87 fn run_rustfmt(files: &HashSet<String>, ranges: &[Range]) -> Result<(), FormatDiffError> {
88     if files.is_empty() || ranges.is_empty() {
89         debug!("No files to format found");
90         return Ok(());
91     }
92 
93     let ranges_as_json = json::to_string(ranges).unwrap();
94 
95     debug!("Files: {:?}", files);
96     debug!("Ranges: {:?}", ranges);
97 
98     let rustfmt_var = env::var_os("RUSTFMT");
99     let rustfmt = match &rustfmt_var {
100         Some(rustfmt) => rustfmt,
101         None => OsStr::new("rustfmt"),
102     };
103     let exit_status = process::Command::new(rustfmt)
104         .args(files)
105         .arg("--file-lines")
106         .arg(ranges_as_json)
107         .status()?;
108 
109     if !exit_status.success() {
110         return Err(FormatDiffError::IoError(io::Error::new(
111             io::ErrorKind::Other,
112             format!("rustfmt failed with {}", exit_status),
113         )));
114     }
115     Ok(())
116 }
117 
118 /// Scans a diff from `from`, and returns the set of files found, and the ranges
119 /// in those files.
scan_diff<R>( from: R, skip_prefix: u32, file_filter: &str, ) -> Result<(HashSet<String>, Vec<Range>), FormatDiffError> where R: io::Read,120 fn scan_diff<R>(
121     from: R,
122     skip_prefix: u32,
123     file_filter: &str,
124 ) -> Result<(HashSet<String>, Vec<Range>), FormatDiffError>
125 where
126     R: io::Read,
127 {
128     let diff_pattern = format!(r"^\+\+\+\s(?:.*?/){{{}}}(\S*)", skip_prefix);
129     let diff_pattern = Regex::new(&diff_pattern).unwrap();
130 
131     let lines_pattern = Regex::new(r"^@@.*\+(\d+)(,(\d+))?").unwrap();
132 
133     let file_filter = Regex::new(&format!("^{}$", file_filter))?;
134 
135     let mut current_file = None;
136 
137     let mut files = HashSet::new();
138     let mut ranges = vec![];
139     for line in io::BufReader::new(from).lines() {
140         let line = line.unwrap();
141 
142         if let Some(captures) = diff_pattern.captures(&line) {
143             current_file = Some(captures.get(1).unwrap().as_str().to_owned());
144         }
145 
146         let file = match current_file {
147             Some(ref f) => &**f,
148             None => continue,
149         };
150 
151         // FIXME(emilio): We could avoid this most of the time if needed, but
152         // it's not clear it's worth it.
153         if !file_filter.is_match(file) {
154             continue;
155         }
156 
157         let lines_captures = match lines_pattern.captures(&line) {
158             Some(captures) => captures,
159             None => continue,
160         };
161 
162         let start_line = lines_captures
163             .get(1)
164             .unwrap()
165             .as_str()
166             .parse::<u32>()
167             .unwrap();
168         let line_count = match lines_captures.get(3) {
169             Some(line_count) => line_count.as_str().parse::<u32>().unwrap(),
170             None => 1,
171         };
172 
173         if line_count == 0 {
174             continue;
175         }
176 
177         let end_line = start_line + line_count - 1;
178         files.insert(file.to_owned());
179         ranges.push(Range {
180             file: file.to_owned(),
181             range: [start_line, end_line],
182         });
183     }
184 
185     Ok((files, ranges))
186 }
187 
/// Scans the bundled `test/bindgen.diff` fixture with one path component
/// stripped and the Rust-only filter, then checks the reported files and
/// ranges.
#[test]
fn scan_simple_git_diff() {
    const DIFF: &str = include_str!("test/bindgen.diff");

    let scanned = scan_diff(DIFF.as_bytes(), 1, r".*\.rs");
    let (files, ranges) = scanned.expect("scan_diff failed?");

    // A `.rs` path from the diff must be reported...
    assert!(
        files.contains("src/ir/traversal.rs"),
        "Should've matched the filter"
    );
    // ...and a `.hpp` path must not be.
    assert!(
        !files.contains("tests/headers/anon_enum.hpp"),
        "Shouldn't have matched the filter"
    );

    // Expected ranges, in the order the hunks appear in the fixture.
    let expected_spans: [(&str, [u32; 2]); 4] = [
        ("src/ir/item.rs", [148, 158]),
        ("src/ir/item.rs", [160, 170]),
        ("src/ir/traversal.rs", [9, 16]),
        ("src/ir/traversal.rs", [35, 43]),
    ];
    let expected: Vec<Range> = expected_spans
        .iter()
        .map(|&(file, range)| Range {
            file: file.to_owned(),
            range,
        })
        .collect();

    assert_eq!(ranges, expected);
}
225 
/// Tests for command-line parsing of `Opts`.
#[cfg(test)]
mod cmd_line_tests {
    use super::*;

    /// Returns `true` when the argument parser refuses `args`.
    fn is_rejected(args: &[&str]) -> bool {
        Opts::clap().get_matches_from_safe(args).is_err()
    }

    /// With no arguments at all, both options take their defaults.
    #[test]
    fn default_options() {
        let no_args: Vec<String> = Vec::new();
        let opts = Opts::from_iter(&no_args);
        assert_eq!(DEFAULT_PATTERN, opts.filter);
        assert_eq!(0, opts.skip_prefix);
    }

    /// Short flags set the corresponding fields.
    #[test]
    fn good_options() {
        let args = ["test", "-p", "10", "-f", r".*\.hs"];
        let opts = Opts::from_iter(&args);
        assert_eq!(r".*\.hs", opts.filter);
        assert_eq!(10, opts.skip_prefix);
    }

    /// Positional arguments are not accepted.
    #[test]
    fn unexpected_option() {
        assert!(is_rejected(&["test", "unexpected"]));
    }

    /// Unknown flags are not accepted.
    #[test]
    fn unexpected_flag() {
        assert!(is_rejected(&["test", "--flag"]));
    }

    /// Passing the same option twice is an error.
    #[test]
    fn overridden_option() {
        assert!(is_rejected(&["test", "-p", "10", "-p", "20"]));
    }

    /// Despite its name, this exercises a negative `-p` (skip-prefix) value.
    #[test]
    fn negative_filter() {
        assert!(is_rejected(&["test", "-p", "-1"]));
    }
}
281