1 use anyhow::{anyhow, Result};
2 use lazy_static::lazy_static;
3 use regex::Regex;
4 use std::fs;
5 use tree_sitter::{Language, Parser, Point};
6
7 lazy_static! {
8 static ref CAPTURE_NAME_REGEX: Regex = Regex::new("[\\w_\\-.]+").unwrap();
9 }
10
11 #[derive(Debug, Eq, PartialEq)]
12 pub struct CaptureInfo {
13 pub name: String,
14 pub start: Point,
15 pub end: Point,
16 }
17
18 #[derive(Debug, PartialEq, Eq)]
19 pub struct Assertion {
20 pub position: Point,
21 pub expected_capture_name: String,
22 }
23
24 /// Parse the given source code, finding all of the comments that contain
25 /// highlighting assertions. Return a vector of (position, expected highlight name)
26 /// pairs.
parse_position_comments( parser: &mut Parser, language: Language, source: &[u8], ) -> Result<Vec<Assertion>>27 pub fn parse_position_comments(
28 parser: &mut Parser,
29 language: Language,
30 source: &[u8],
31 ) -> Result<Vec<Assertion>> {
32 let mut result = Vec::new();
33 let mut assertion_ranges = Vec::new();
34
35 // Parse the code.
36 parser.set_included_ranges(&[]).unwrap();
37 parser.set_language(language).unwrap();
38 let tree = parser.parse(source, None).unwrap();
39
40 // Walk the tree, finding comment nodes that contain assertions.
41 let mut ascending = false;
42 let mut cursor = tree.root_node().walk();
43 loop {
44 if ascending {
45 let node = cursor.node();
46
47 // Find every comment node.
48 if node.kind().contains("comment") {
49 if let Ok(text) = node.utf8_text(source) {
50 let mut position = node.start_position();
51 if position.row > 0 {
52 // Find the arrow character ("^" or '<-") in the comment. A left arrow
53 // refers to the column where the comment node starts. An up arrow refers
54 // to its own column.
55 let mut has_left_caret = false;
56 let mut has_arrow = false;
57 let mut arrow_end = 0;
58 for (i, c) in text.char_indices() {
59 arrow_end = i + 1;
60 if c == '-' && has_left_caret {
61 has_arrow = true;
62 break;
63 }
64 if c == '^' {
65 has_arrow = true;
66 position.column += i;
67 break;
68 }
69 has_left_caret = c == '<';
70 }
71
72 // If the comment node contains an arrow and a highlight name, record the
73 // highlight name and the position.
74 if let (true, Some(mat)) =
75 (has_arrow, CAPTURE_NAME_REGEX.find(&text[arrow_end..]))
76 {
77 assertion_ranges.push((node.start_position(), node.end_position()));
78 result.push(Assertion {
79 position: position,
80 expected_capture_name: mat.as_str().to_string(),
81 });
82 }
83 }
84 }
85 }
86
87 // Continue walking the tree.
88 if cursor.goto_next_sibling() {
89 ascending = false;
90 } else if !cursor.goto_parent() {
91 break;
92 }
93 } else if !cursor.goto_first_child() {
94 ascending = true;
95 }
96 }
97
98 // Adjust the row number in each assertion's position to refer to the line of
99 // code *above* the assertion. There can be multiple lines of assertion comments,
100 // so the positions may have to be decremented by more than one row.
101 let mut i = 0;
102 for assertion in result.iter_mut() {
103 loop {
104 let on_assertion_line = assertion_ranges[i..]
105 .iter()
106 .any(|(start, _)| start.row == assertion.position.row);
107 if on_assertion_line {
108 assertion.position.row -= 1;
109 } else {
110 while i < assertion_ranges.len()
111 && assertion_ranges[i].0.row < assertion.position.row
112 {
113 i += 1;
114 }
115 break;
116 }
117 }
118 }
119
120 // The assertions can end up out of order due to the line adjustments.
121 result.sort_unstable_by_key(|a| a.position);
122
123 Ok(result)
124 }
125
assert_expected_captures( infos: Vec<CaptureInfo>, path: String, parser: &mut Parser, language: Language, ) -> Result<()>126 pub fn assert_expected_captures(
127 infos: Vec<CaptureInfo>,
128 path: String,
129 parser: &mut Parser,
130 language: Language,
131 ) -> Result<()> {
132 let contents = fs::read_to_string(path)?;
133 let pairs = parse_position_comments(parser, language, contents.as_bytes())?;
134 for info in &infos {
135 if let Some(found) = pairs.iter().find(|p| {
136 p.position.row == info.start.row && p.position >= info.start && p.position < info.end
137 }) {
138 if found.expected_capture_name != info.name && info.name != "name" {
139 Err(anyhow!(
140 "Assertion failed: at {}, found {}, expected {}",
141 info.start,
142 found.expected_capture_name,
143 info.name
144 ))?
145 }
146 }
147 }
148 Ok(())
149 }
150