1 use std::iter::{Fuse, Iterator};
2 
3 use crate::core::{BytePos, ByteRange};
4 use crate::scopes;
5 use crate::util::is_whitespace_byte;
6 
/// An iterator which iterates statements.
///
/// Each item is the `ByteRange` of one statement within the source text,
/// e.g. for "let a = 5; let b = 4;" it yields the range covering "let a = 5;"
/// and then the range covering "let b = 4;".
///
/// This iterator only works for comment-masked source code.
pub struct StmtIndicesIter<'a> {
    /// Source text being scanned.
    src: &'a str,
    /// Byte offset at which the next call to `next` starts scanning.
    pos: BytePos,
    /// Byte offset (exclusive) at which scanning stops.
    end: BytePos,
}
15 
16 impl<'a> Iterator for StmtIndicesIter<'a> {
17     type Item = ByteRange;
18 
19     #[inline]
next(&mut self) -> Option<Self::Item>20     fn next(&mut self) -> Option<Self::Item> {
21         let src_bytes = self.src.as_bytes();
22         let mut enddelim = b';';
23         let mut bracelevel = 0isize;
24         let mut parenlevel = 0isize;
25         let mut bracketlevel = 0isize;
26         let mut pos = self.pos;
27         for &b in &src_bytes[pos.0..self.end.0] {
28             match b {
29                 b' ' | b'\r' | b'\n' | b'\t' => {
30                     pos += BytePos(1);
31                 }
32                 _ => {
33                     break;
34                 }
35             }
36         }
37         let start = pos;
38         // test attribute   #[foo = bar]
39         if pos < self.end && src_bytes[pos.0] == b'#' {
40             enddelim = b']'
41         };
42         // iterate through the chunk, looking for stmt end
43         for &b in &src_bytes[pos.0..self.end.0] {
44             pos += BytePos(1);
45             match b {
46                 b'(' => {
47                     parenlevel += 1;
48                 }
49                 b')' => {
50                     parenlevel -= 1;
51                 }
52                 b'[' => {
53                     bracketlevel += 1;
54                 }
55                 b']' => {
56                     bracketlevel -= 1;
57                 }
58                 b'{' => {
59                     // if we are top level and stmt is not a 'use' or 'let' then
60                     // closebrace finishes the stmt
61                     if bracelevel == 0
62                         && parenlevel == 0
63                         && !(is_a_use_stmt(src_bytes, start, pos)
64                             || is_a_let_stmt(src_bytes, start, pos))
65                     {
66                         enddelim = b'}';
67                     }
68                     bracelevel += 1;
69                 }
70                 b'}' => {
71                     // have we reached the end of the scope?
72                     if bracelevel == 0 {
73                         self.pos = pos;
74                         return None;
75                     }
76                     bracelevel -= 1;
77                 }
78                 b'!' => {
79                     // macro if followed by at least one space or (
80                     // FIXME: test with boolean 'not' expression
81                     if parenlevel == 0 && bracelevel == 0 && pos < self.end && (pos - start).0 > 1 {
82                         match src_bytes[pos.0] {
83                             b' ' | b'\r' | b'\n' | b'\t' | b'(' => {
84                                 enddelim = b')';
85                             }
86                             _ => {}
87                         }
88                     }
89                 }
90                 _ => {}
91             }
92             if parenlevel < 0
93                 || bracelevel < 0
94                 || bracketlevel < 0
95                 || (enddelim == b && bracelevel == 0 && parenlevel == 0 && bracketlevel == 0)
96             {
97                 self.pos = pos;
98                 return Some(ByteRange::new(start, pos));
99             }
100         }
101         if start < self.end {
102             self.pos = pos;
103             return Some(ByteRange::new(start, self.end));
104         }
105         None
106     }
107 }
108 
is_a_use_stmt(src_bytes: &[u8], start: BytePos, pos: BytePos) -> bool109 fn is_a_use_stmt(src_bytes: &[u8], start: BytePos, pos: BytePos) -> bool {
110     let src = unsafe { ::std::str::from_utf8_unchecked(&src_bytes[start.0..pos.0]) };
111     scopes::use_stmt_start(&src).is_some()
112 }
113 
is_a_let_stmt(src_bytes: &[u8], start: BytePos, pos: BytePos) -> bool114 fn is_a_let_stmt(src_bytes: &[u8], start: BytePos, pos: BytePos) -> bool {
115     pos.0 > 3
116         && &src_bytes[start.0..start.0 + 3] == b"let"
117         && is_whitespace_byte(src_bytes[start.0 + 3])
118 }
119 
120 impl<'a> StmtIndicesIter<'a> {
from_parts(src: &str) -> Fuse<StmtIndicesIter<'_>>121     pub fn from_parts(src: &str) -> Fuse<StmtIndicesIter<'_>> {
122         StmtIndicesIter {
123             src,
124             pos: BytePos::ZERO,
125             end: BytePos(src.len()),
126         }
127         .fuse()
128     }
129 }
130 
// Unit tests for `StmtIndicesIter`. Each test feeds a small snippet through `iter_stmts`
// and checks the text covered by the yielded ranges. The fixtures are whitespace-exact:
// the expected strings must match the snippet byte for byte.
#[cfg(test)]
mod test {
    use std::iter::Fuse;

    use crate::codecleaner;
    use crate::testutils::{rejustify, slice};

    use super::*;

    /// Masks the comments of `src` (via `codecleaner::code_chunks` and
    /// `scopes::mask_comments`) and returns a statement iterator over the masked copy.
    ///
    /// The masked string is leaked to obtain a `'static` borrow for the returned iterator.
    /// The tests slice the yielded ranges out of the unmasked `src`, so masking is
    /// presumably offset-preserving — confirm against `scopes::mask_comments`.
    fn iter_stmts(src: &str) -> Fuse<StmtIndicesIter<'_>> {
        let idx: Vec<_> = codecleaner::code_chunks(&src).collect();
        let code = scopes::mask_comments(src, &idx);
        let code: &'static str = Box::leak(code.into_boxed_str());
        StmtIndicesIter::from_parts(code)
    }

    /// Two single-line `use` statements are yielded separately, without the trailing comment.
    #[test]
    fn iterates_single_use_stmts() {
        let src = rejustify(
            "
            use std::Foo; // a comment
            use std::Bar;
        ",
        );

        let mut it = iter_stmts(src.as_ref());
        assert_eq!("use std::Foo;", slice(&src, it.next().unwrap()));
        assert_eq!("use std::Bar;", slice(&src, it.next().unwrap()));
    }

    /// A `;` inside square brackets (array types / nested arrays) does not end the statement.
    #[test]
    fn iterates_array_stmts() {
        let src = rejustify(
            "
            let a: [i32; 2] = [1, 2];
            let b = [[0], [1], [2]];
            let c = ([1, 2, 3])[1];
        ",
        );

        let mut it = iter_stmts(src.as_ref());
        assert_eq!("let a: [i32; 2] = [1, 2];", slice(&src, it.next().unwrap()));
        assert_eq!("let b = [[0], [1], [2]];", slice(&src, it.next().unwrap()));
        assert_eq!("let c = ([1, 2, 3])[1];", slice(&src, it.next().unwrap()));
    }

    /// A braced `use` list spanning two lines is one statement ending at `;`, not at `}`.
    #[test]
    fn iterates_use_stmt_over_two_lines() {
        let src = rejustify(
            "
        use std::{Foo,
                  Bar}; // a comment
        ",
        );
        let mut it = iter_stmts(src.as_ref());
        assert_eq!(
            "use std::{Foo,
              Bar};",
            slice(&src, it.next().unwrap())
        );
    }

    /// A `pub use {..};` with no path prefix is still treated as a `use` statement.
    #[test]
    fn iterates_use_stmt_without_the_prefix() {
        let src = rejustify(
            "
        pub use {Foo,
                 Bar}; // this is also legit apparently
        ",
        );
        let mut it = iter_stmts(src.as_ref());
        assert_eq!(
            "pub use {Foo,
             Bar};",
            slice(&src, it.next().unwrap())
        );
    }

    /// A `while` statement ends at the closing brace of its block.
    #[test]
    fn iterates_while_stmt() {
        let src = rejustify(
            "
            while self.pos < 3 { }
        ",
        );
        let mut it = iter_stmts(src.as_ref());
        assert_eq!("while self.pos < 3 { }", slice(&src, it.next().unwrap()));
    }

    /// Braces nested inside parentheses (a closure body argument) do not end the statement.
    #[test]
    fn iterates_lambda_arg() {
        let src = rejustify(
            "
            myfn(|n|{});
        ",
        );
        let mut it = iter_stmts(src.as_ref());
        assert_eq!("myfn(|n|{});", slice(&src, it.next().unwrap()));
    }

    /// A parenthesised `macro_rules!` definition ends at its closing `)`, with no `;` needed.
    #[test]
    fn iterates_macro() {
        let src = "
        mod foo;
        macro_rules! otry(
            ($e:expr) => (match $e { Some(e) => e, None => return })
        )
        mod bar;
        ";
        let mut it = iter_stmts(src.as_ref());
        assert_eq!("mod foo;", slice(&src, it.next().unwrap()));
        assert_eq!(
            "macro_rules! otry(
            ($e:expr) => (match $e { Some(e) => e, None => return })
        )",
            slice(&src, it.next().unwrap())
        );
        assert_eq!("mod bar;", slice(&src, it.next().unwrap()));
    }

    /// A macro invocation without a trailing `;` ends at its closing `)`.
    #[test]
    fn iterates_macro_invocation() {
        let src = "
            mod foo;
            local_data_key!(local_stdout: Box<Writer + Send>)  // no ';'
            mod bar;
        ";
        let mut it = iter_stmts(src.as_ref());
        assert_eq!("mod foo;", slice(&src, it.next().unwrap()));
        assert_eq!(
            "local_data_key!(local_stdout: Box<Writer + Send>)",
            slice(&src, it.next().unwrap())
        );
        assert_eq!("mod bar;", slice(&src, it.next().unwrap()));
    }

    /// An `if`/`else` is yielded as two ranges: the `if` block, then the `else` block.
    #[test]
    fn iterates_if_else_stmt() {
        let src = "
            if self.pos < 3 { } else { }
        ";
        let mut it = iter_stmts(src.as_ref());
        assert_eq!("if self.pos < 3 { }", slice(&src, it.next().unwrap()));
        assert_eq!("else { }", slice(&src, it.next().unwrap()));
    }

    /// Iteration stops with `None` at the `}` closing the scope it started in.
    /// The `[29..]` slice drops the `while(.. {` header so scanning begins inside the block;
    /// the block that follows the closing brace is never visited.
    #[test]
    fn iterates_inner_scope() {
        let src = &"
        while(self.pos < 3 {
            let a = 35;
            return a + 35;  // should iterate this
        }
        {
            b = foo;       // but not this
        }
        "[29..];

        let mut it = iter_stmts(src.as_ref());

        assert_eq!("let a = 35;", slice(&src, it.next().unwrap()));
        assert_eq!("return a + 35;", slice(&src, it.next().unwrap()));
        assert_eq!(None, it.next());
    }

    /// Attributes (`#![..]` and `#[..]`) are yielded as statements ending at `]`.
    #[test]
    fn iterates_module_attribute() {
        let src = rejustify(
            "
            #![license = \"BSD\"]
            #[test]
        ",
        );
        let mut it = iter_stmts(src.as_ref());
        assert_eq!("#![license = \"BSD\"]", slice(&src, it.next().unwrap()));
        assert_eq!("#[test]", slice(&src, it.next().unwrap()));
    }

    /// A block left open at the end of the input is yielded as one range running to the
    /// end of the text.
    #[test]
    fn iterates_half_open_subscope_if_is_the_last_thing() {
        let src = "
            let something = 35;
            while self.pos < 3 {
            let a = 35;
            return a + 35;  // should iterate this
        ";

        let mut it = iter_stmts(src.as_ref());
        assert_eq!("let something = 35;", slice(&src, it.next().unwrap()));
        assert_eq!(
            "while self.pos < 3 {
            let a = 35;
            return a + 35;  // should iterate this
        ",
            slice(&src, it.next().unwrap())
        );
    }

    /// A `;` inside nested array types does not end the statement, including inside a
    /// tuple-struct declaration.
    #[test]
    fn iterates_ndarray() {
        let src = "
            let a = [[f64; 5]; 5];
            pub struct Matrix44f(pub [[f64; 4]; 4]);
        ";
        let mut it = iter_stmts(src.as_ref());
        assert_eq!("let a = [[f64; 5]; 5];", slice(&src, it.next().unwrap()));
        assert_eq!(
            "pub struct Matrix44f(pub [[f64; 4]; 4]);",
            slice(&src, it.next().unwrap())
        );
    }

    /// Expects a struct pattern in a `for` header (`for St { a, b } in ..`) not to end the
    /// statement at the pattern's `}`.
    /// NOTE(review): marked `#[ignore]`, presumably because the iterator does not handle
    /// this yet — confirm before enabling.
    #[test]
    #[ignore]
    fn iterates_for_struct() {
        let src = "
            let a = 5;
            for St { a, b } in iter() {
                let b = a;
            }
            while let St { a, b } = iter().next() {

            }
            if let St(a) = hoge() {

            }
        ";
        let mut it = iter_stmts(src.as_ref());
        assert_eq!("let a = 5;", slice(&src, it.next().unwrap()));
        assert_eq!(
            r"for St { a, b } in iter() {
                let b = a;
            }",
            slice(&src, it.next().unwrap())
        );
    }
}
368