1 use anyhow::{anyhow, Result};
2 use lazy_static::lazy_static;
3 use regex::Regex;
4 use std::fs;
5 use tree_sitter::{Language, Parser, Point};
6 
lazy_static! {
    // Matches one capture name: a run of one or more word characters, hyphens,
    // or dots. It is used to pull the expected capture name out of the text that
    // follows the arrow in an assertion comment.
    static ref CAPTURE_NAME_REGEX: Regex = Regex::new("[\\w_\\-.]+").unwrap();
}
10 
/// A named capture together with the range of the source it covers.
///
/// `assert_expected_captures` compares these against the assertions parsed from
/// a test file, treating `start` as inclusive and `end` as exclusive.
#[derive(Debug, Eq, PartialEq)]
pub struct CaptureInfo {
    /// Name of the capture.
    pub name: String,
    /// Point at which the captured range starts.
    pub start: Point,
    /// Point at which the captured range ends.
    pub end: Point,
}
17 
/// A single expectation read from an assertion comment: the capture name that
/// is expected at a given position in the source.
#[derive(Debug, PartialEq, Eq)]
pub struct Assertion {
    /// Position in the source that the assertion refers to. After
    /// `parse_position_comments` returns, the row is that of the code line
    /// above the comment, not of the comment itself.
    pub position: Point,
    /// Capture name written after the arrow in the assertion comment.
    pub expected_capture_name: String,
}
23 
24 /// Parse the given source code, finding all of the comments that contain
25 /// highlighting assertions. Return a vector of (position, expected highlight name)
26 /// pairs.
parse_position_comments( parser: &mut Parser, language: Language, source: &[u8], ) -> Result<Vec<Assertion>>27 pub fn parse_position_comments(
28     parser: &mut Parser,
29     language: Language,
30     source: &[u8],
31 ) -> Result<Vec<Assertion>> {
32     let mut result = Vec::new();
33     let mut assertion_ranges = Vec::new();
34 
35     // Parse the code.
36     parser.set_included_ranges(&[]).unwrap();
37     parser.set_language(language).unwrap();
38     let tree = parser.parse(source, None).unwrap();
39 
40     // Walk the tree, finding comment nodes that contain assertions.
41     let mut ascending = false;
42     let mut cursor = tree.root_node().walk();
43     loop {
44         if ascending {
45             let node = cursor.node();
46 
47             // Find every comment node.
48             if node.kind().contains("comment") {
49                 if let Ok(text) = node.utf8_text(source) {
50                     let mut position = node.start_position();
51                     if position.row > 0 {
52                         // Find the arrow character ("^" or '<-") in the comment. A left arrow
53                         // refers to the column where the comment node starts. An up arrow refers
54                         // to its own column.
55                         let mut has_left_caret = false;
56                         let mut has_arrow = false;
57                         let mut arrow_end = 0;
58                         for (i, c) in text.char_indices() {
59                             arrow_end = i + 1;
60                             if c == '-' && has_left_caret {
61                                 has_arrow = true;
62                                 break;
63                             }
64                             if c == '^' {
65                                 has_arrow = true;
66                                 position.column += i;
67                                 break;
68                             }
69                             has_left_caret = c == '<';
70                         }
71 
72                         // If the comment node contains an arrow and a highlight name, record the
73                         // highlight name and the position.
74                         if let (true, Some(mat)) =
75                             (has_arrow, CAPTURE_NAME_REGEX.find(&text[arrow_end..]))
76                         {
77                             assertion_ranges.push((node.start_position(), node.end_position()));
78                             result.push(Assertion {
79                                 position: position,
80                                 expected_capture_name: mat.as_str().to_string(),
81                             });
82                         }
83                     }
84                 }
85             }
86 
87             // Continue walking the tree.
88             if cursor.goto_next_sibling() {
89                 ascending = false;
90             } else if !cursor.goto_parent() {
91                 break;
92             }
93         } else if !cursor.goto_first_child() {
94             ascending = true;
95         }
96     }
97 
98     // Adjust the row number in each assertion's position to refer to the line of
99     // code *above* the assertion. There can be multiple lines of assertion comments,
100     // so the positions may have to be decremented by more than one row.
101     let mut i = 0;
102     for assertion in result.iter_mut() {
103         loop {
104             let on_assertion_line = assertion_ranges[i..]
105                 .iter()
106                 .any(|(start, _)| start.row == assertion.position.row);
107             if on_assertion_line {
108                 assertion.position.row -= 1;
109             } else {
110                 while i < assertion_ranges.len()
111                     && assertion_ranges[i].0.row < assertion.position.row
112                 {
113                     i += 1;
114                 }
115                 break;
116             }
117         }
118     }
119 
120     // The assertions can end up out of order due to the line adjustments.
121     result.sort_unstable_by_key(|a| a.position);
122 
123     Ok(result)
124 }
125 
assert_expected_captures( infos: Vec<CaptureInfo>, path: String, parser: &mut Parser, language: Language, ) -> Result<()>126 pub fn assert_expected_captures(
127     infos: Vec<CaptureInfo>,
128     path: String,
129     parser: &mut Parser,
130     language: Language,
131 ) -> Result<()> {
132     let contents = fs::read_to_string(path)?;
133     let pairs = parse_position_comments(parser, language, contents.as_bytes())?;
134     for info in &infos {
135         if let Some(found) = pairs.iter().find(|p| {
136             p.position.row == info.start.row && p.position >= info.start && p.position < info.end
137         }) {
138             if found.expected_capture_name != info.name && info.name != "name" {
139                 Err(anyhow!(
140                     "Assertion failed: at {}, found {}, expected {}",
141                     info.start,
142                     found.expected_capture_name,
143                     info.name
144                 ))?
145             }
146         }
147     }
148     Ok(())
149 }
150