1 use proc_macro2::{Span, TokenStream};
2 use quote::ToTokens;
3 use syn::{
4     visit_mut::{self, VisitMut},
5     *,
6 };
7 
8 use crate::utils::{
9     determine_lifetime_name, insert_lifetime, parse_as_empty, ProjKind, SliceExt, VecExt,
10 };
11 
attribute(args: &TokenStream, input: Stmt, kind: ProjKind) -> TokenStream12 pub(crate) fn attribute(args: &TokenStream, input: Stmt, kind: ProjKind) -> TokenStream {
13     parse_as_empty(args).and_then(|()| parse(input, kind)).unwrap_or_else(|e| e.to_compile_error())
14 }
15 
replace_expr(expr: &mut Expr, kind: ProjKind)16 fn replace_expr(expr: &mut Expr, kind: ProjKind) {
17     match expr {
18         Expr::Match(expr) => {
19             Context::new(kind).replace_expr_match(expr);
20         }
21         Expr::If(expr_if) => {
22             let mut expr_if = expr_if;
23             while let Expr::Let(ref mut expr) = &mut *expr_if.cond {
24                 Context::new(kind).replace_expr_let(expr);
25                 if let Some((_, ref mut expr)) = expr_if.else_branch {
26                     if let Expr::If(new_expr_if) = &mut **expr {
27                         expr_if = new_expr_if;
28                         continue;
29                     }
30                 }
31                 break;
32             }
33         }
34         _ => {}
35     }
36 }
37 
parse(mut stmt: Stmt, kind: ProjKind) -> Result<TokenStream>38 fn parse(mut stmt: Stmt, kind: ProjKind) -> Result<TokenStream> {
39     match &mut stmt {
40         Stmt::Expr(expr) | Stmt::Semi(expr, _) => replace_expr(expr, kind),
41         Stmt::Local(local) => Context::new(kind).replace_local(local)?,
42         Stmt::Item(Item::Fn(item)) => replace_item_fn(item, kind)?,
43         Stmt::Item(Item::Impl(item)) => replace_item_impl(item, kind)?,
44         Stmt::Item(Item::Use(item)) => replace_item_use(item, kind)?,
45         _ => {}
46     }
47 
48     Ok(stmt.into_token_stream())
49 }
50 
51 struct Context {
52     register: Option<(Ident, usize)>,
53     replaced: bool,
54     kind: ProjKind,
55 }
56 
57 impl Context {
new(kind: ProjKind) -> Self58     fn new(kind: ProjKind) -> Self {
59         Self { register: None, replaced: false, kind }
60     }
61 
update(&mut self, ident: &Ident, len: usize)62     fn update(&mut self, ident: &Ident, len: usize) {
63         self.register.get_or_insert_with(|| (ident.clone(), len));
64     }
65 
compare_paths(&self, ident: &Ident, len: usize) -> bool66     fn compare_paths(&self, ident: &Ident, len: usize) -> bool {
67         match &self.register {
68             Some((i, l)) => *l == len && i == ident,
69             None => false,
70         }
71     }
72 
replace_local(&mut self, local: &mut Local) -> Result<()>73     fn replace_local(&mut self, local: &mut Local) -> Result<()> {
74         if let Some(attr) = local.attrs.find(self.kind.method_name()) {
75             return Err(error!(attr, "duplicate #[{}] attribute", self.kind.method_name()));
76         }
77 
78         if let Some(Expr::Match(expr)) = local.init.as_mut().map(|(_, expr)| &mut **expr) {
79             self.replace_expr_match(expr);
80         }
81 
82         if self.replaced {
83             if is_replaceable(&local.pat, false) {
84                 return Err(error!(
85                     local.pat,
86                     "Both initializer expression and pattern are replaceable, \
87                      you need to split the initializer expression into separate let bindings \
88                      to avoid ambiguity"
89                 ));
90             }
91         } else {
92             self.replace_pat(&mut local.pat, false);
93         }
94 
95         Ok(())
96     }
97 
replace_expr_let(&mut self, expr: &mut ExprLet)98     fn replace_expr_let(&mut self, expr: &mut ExprLet) {
99         self.replace_pat(&mut expr.pat, true)
100     }
101 
replace_expr_match(&mut self, expr: &mut ExprMatch)102     fn replace_expr_match(&mut self, expr: &mut ExprMatch) {
103         expr.arms.iter_mut().for_each(|arm| self.replace_pat(&mut arm.pat, true))
104     }
105 
replace_pat(&mut self, pat: &mut Pat, allow_pat_path: bool)106     fn replace_pat(&mut self, pat: &mut Pat, allow_pat_path: bool) {
107         match pat {
108             Pat::Ident(PatIdent { subpat: Some((_, pat)), .. })
109             | Pat::Reference(PatReference { pat, .. })
110             | Pat::Box(PatBox { pat, .. })
111             | Pat::Type(PatType { pat, .. }) => self.replace_pat(pat, allow_pat_path),
112 
113             Pat::Or(PatOr { cases, .. }) => {
114                 cases.iter_mut().for_each(|pat| self.replace_pat(pat, allow_pat_path))
115             }
116 
117             Pat::Struct(PatStruct { path, .. }) | Pat::TupleStruct(PatTupleStruct { path, .. }) => {
118                 self.replace_path(path)
119             }
120             Pat::Path(PatPath { qself: None, path, .. }) if allow_pat_path => {
121                 self.replace_path(path)
122             }
123             _ => {}
124         }
125     }
126 
replace_path(&mut self, path: &mut Path)127     fn replace_path(&mut self, path: &mut Path) {
128         let len = match path.segments.len() {
129             // 1: struct
130             // 2: enum
131             len @ 1 | len @ 2 => len,
132             // other path
133             _ => return,
134         };
135 
136         if self.register.is_none() || self.compare_paths(&path.segments[0].ident, len) {
137             self.update(&path.segments[0].ident, len);
138             self.replaced = true;
139             replace_ident(&mut path.segments[0].ident, self.kind);
140         }
141     }
142 }
143 
is_replaceable(pat: &Pat, allow_pat_path: bool) -> bool144 fn is_replaceable(pat: &Pat, allow_pat_path: bool) -> bool {
145     match pat {
146         Pat::Ident(PatIdent { subpat: Some((_, pat)), .. })
147         | Pat::Reference(PatReference { pat, .. })
148         | Pat::Box(PatBox { pat, .. })
149         | Pat::Type(PatType { pat, .. }) => is_replaceable(pat, allow_pat_path),
150 
151         Pat::Or(PatOr { cases, .. }) => cases.iter().any(|pat| is_replaceable(pat, allow_pat_path)),
152 
153         Pat::Struct(_) | Pat::TupleStruct(_) => true,
154         Pat::Path(PatPath { qself: None, .. }) => allow_pat_path,
155         _ => false,
156     }
157 }
158 
replace_ident(ident: &mut Ident, kind: ProjKind)159 fn replace_ident(ident: &mut Ident, kind: ProjKind) {
160     *ident = kind.proj_ident(ident);
161 }
162 
replace_item_impl(item: &mut ItemImpl, kind: ProjKind) -> Result<()>163 fn replace_item_impl(item: &mut ItemImpl, kind: ProjKind) -> Result<()> {
164     if let Some(attr) = item.attrs.find(kind.method_name()) {
165         return Err(error!(attr, "duplicate #[{}] attribute", kind.method_name()));
166     }
167 
168     let PathSegment { ident, arguments } = match &mut *item.self_ty {
169         Type::Path(TypePath { qself: None, path }) => path.segments.last_mut().unwrap(),
170         _ => return Ok(()),
171     };
172 
173     replace_ident(ident, kind);
174 
175     let mut lifetime_name = String::from("'pin");
176     determine_lifetime_name(&mut lifetime_name, &mut item.generics);
177     item.items
178         .iter_mut()
179         .filter_map(|i| if let ImplItem::Method(i) = i { Some(i) } else { None })
180         .for_each(|item| determine_lifetime_name(&mut lifetime_name, &mut item.sig.generics));
181     let lifetime = Lifetime::new(&lifetime_name, Span::call_site());
182 
183     insert_lifetime(&mut item.generics, lifetime.clone());
184 
185     match arguments {
186         PathArguments::None => {
187             *arguments = PathArguments::AngleBracketed(parse_quote!(<#lifetime>));
188         }
189         PathArguments::AngleBracketed(args) => {
190             args.args.insert(0, parse_quote!(#lifetime));
191         }
192         PathArguments::Parenthesized(_) => unreachable!(),
193     }
194     Ok(())
195 }
196 
replace_item_fn(item: &mut ItemFn, kind: ProjKind) -> Result<()>197 fn replace_item_fn(item: &mut ItemFn, kind: ProjKind) -> Result<()> {
198     struct FnVisitor(Result<()>);
199 
200     impl FnVisitor {
201         fn visit_stmt(&mut self, node: &mut Stmt) -> Result<()> {
202             match node {
203                 Stmt::Expr(expr) | Stmt::Semi(expr, _) => self.visit_expr(expr),
204                 Stmt::Local(local) => {
205                     visit_mut::visit_local_mut(self, local);
206 
207                     let mut prev = None;
208                     for &kind in &ProjKind::ALL {
209                         if let Some(attr) = local.attrs.find_remove(kind.method_name())? {
210                             if let Some(prev) = prev.replace(kind) {
211                                 return Err(error!(
212                                     attr,
213                                     "attributes `{}` and `{}` are mutually exclusive",
214                                     prev.method_name(),
215                                     kind.method_name(),
216                                 ));
217                             }
218                             Context::new(kind).replace_local(local)?;
219                         }
220                     }
221 
222                     Ok(())
223                 }
224                 // Do not recurse into nested items.
225                 Stmt::Item(_) => Ok(()),
226             }
227         }
228 
229         fn visit_expr(&mut self, node: &mut Expr) -> Result<()> {
230             visit_mut::visit_expr_mut(self, node);
231             match node {
232                 Expr::Match(expr) => {
233                     let mut prev = None;
234                     for &kind in &ProjKind::ALL {
235                         if let Some(attr) = expr.attrs.find_remove(kind.method_name())? {
236                             if let Some(prev) = prev.replace(kind) {
237                                 return Err(error!(
238                                     attr,
239                                     "attributes `{}` and `{}` are mutually exclusive",
240                                     prev.method_name(),
241                                     kind.method_name(),
242                                 ));
243                             }
244                         }
245                     }
246                     if let Some(kind) = prev {
247                         replace_expr(node, kind);
248                     }
249                 }
250                 Expr::If(expr_if) => {
251                     if let Expr::Let(_) = &*expr_if.cond {
252                         let mut prev = None;
253                         for &kind in &ProjKind::ALL {
254                             if let Some(attr) = expr_if.attrs.find_remove(kind.method_name())? {
255                                 if let Some(prev) = prev.replace(kind) {
256                                     return Err(error!(
257                                         attr,
258                                         "attributes `{}` and `{}` are mutually exclusive",
259                                         prev.method_name(),
260                                         kind.method_name(),
261                                     ));
262                                 }
263                             }
264                         }
265                         if let Some(kind) = prev {
266                             replace_expr(node, kind);
267                         }
268                     }
269                 }
270                 _ => {}
271             }
272             Ok(())
273         }
274     }
275 
276     impl VisitMut for FnVisitor {
277         fn visit_stmt_mut(&mut self, node: &mut Stmt) {
278             if self.0.is_err() {
279                 return;
280             }
281             if let Err(e) = self.visit_stmt(node) {
282                 self.0 = Err(e)
283             }
284         }
285 
286         fn visit_expr_mut(&mut self, node: &mut Expr) {
287             if self.0.is_err() {
288                 return;
289             }
290             if let Err(e) = self.visit_expr(node) {
291                 self.0 = Err(e)
292             }
293         }
294 
295         fn visit_item_mut(&mut self, _: &mut Item) {
296             // Do not recurse into nested items.
297         }
298     }
299 
300     if let Some(attr) = item.attrs.find(kind.method_name()) {
301         return Err(error!(attr, "duplicate #[{}] attribute", kind.method_name()));
302     }
303 
304     let mut visitor = FnVisitor(Ok(()));
305     visitor.visit_block_mut(&mut item.block);
306     visitor.0
307 }
308 
replace_item_use(item: &mut ItemUse, kind: ProjKind) -> Result<()>309 fn replace_item_use(item: &mut ItemUse, kind: ProjKind) -> Result<()> {
310     struct UseTreeVisitor {
311         res: Result<()>,
312         kind: ProjKind,
313     }
314 
315     impl VisitMut for UseTreeVisitor {
316         fn visit_use_tree_mut(&mut self, node: &mut UseTree) {
317             if self.res.is_err() {
318                 return;
319             }
320 
321             match node {
322                 // Desugar `use tree::<name>` into `tree::__<name>Projection`.
323                 UseTree::Name(name) => replace_ident(&mut name.ident, self.kind),
324                 UseTree::Glob(glob) => {
325                     self.res = Err(error!(
326                         glob,
327                         "#[{}] attribute may not be used on glob imports",
328                         self.kind.method_name()
329                     ));
330                 }
331                 UseTree::Rename(rename) => {
332                     self.res = Err(error!(
333                         rename,
334                         "#[{}] attribute may not be used on renamed imports",
335                         self.kind.method_name()
336                     ));
337                 }
338                 UseTree::Path(_) | UseTree::Group(_) => visit_mut::visit_use_tree_mut(self, node),
339             }
340         }
341     }
342 
343     if let Some(attr) = item.attrs.find(kind.method_name()) {
344         return Err(error!(attr, "duplicate #[{}] attribute", kind.method_name()));
345     }
346 
347     let mut visitor = UseTreeVisitor { res: Ok(()), kind };
348     visitor.visit_item_use_mut(item);
349     visitor.res
350 }
351