1 use proc_macro2::{Span, TokenStream};
2 use quote::ToTokens;
3 use syn::{
4     parse_quote,
5     visit_mut::{self, VisitMut},
6     Expr, ExprLet, ExprMatch, Ident, ImplItem, Item, ItemFn, ItemImpl, ItemUse, Lifetime, Local,
7     Pat, PatBox, PatIdent, PatOr, PatPath, PatReference, PatStruct, PatTupleStruct, PatType, Path,
8     PathArguments, PathSegment, Result, Stmt, Type, TypePath, UseTree,
9 };
10 
11 use crate::utils::{
12     determine_lifetime_name, insert_lifetime, parse_as_empty, ProjKind, SliceExt, VecExt,
13 };
14 
attribute(args: &TokenStream, input: Stmt, kind: ProjKind) -> TokenStream15 pub(crate) fn attribute(args: &TokenStream, input: Stmt, kind: ProjKind) -> TokenStream {
16     parse_as_empty(args).and_then(|()| parse(input, kind)).unwrap_or_else(|e| e.to_compile_error())
17 }
18 
replace_expr(expr: &mut Expr, kind: ProjKind)19 fn replace_expr(expr: &mut Expr, kind: ProjKind) {
20     match expr {
21         Expr::Match(expr) => {
22             Context::new(kind).replace_expr_match(expr);
23         }
24         Expr::If(expr_if) => {
25             let mut expr_if = expr_if;
26             while let Expr::Let(ref mut expr) = &mut *expr_if.cond {
27                 Context::new(kind).replace_expr_let(expr);
28                 if let Some((_, ref mut expr)) = expr_if.else_branch {
29                     if let Expr::If(new_expr_if) = &mut **expr {
30                         expr_if = new_expr_if;
31                         continue;
32                     }
33                 }
34                 break;
35             }
36         }
37         _ => {}
38     }
39 }
40 
parse(mut stmt: Stmt, kind: ProjKind) -> Result<TokenStream>41 fn parse(mut stmt: Stmt, kind: ProjKind) -> Result<TokenStream> {
42     match &mut stmt {
43         Stmt::Expr(expr) | Stmt::Semi(expr, _) => replace_expr(expr, kind),
44         Stmt::Local(local) => Context::new(kind).replace_local(local)?,
45         Stmt::Item(Item::Fn(item)) => replace_item_fn(item, kind)?,
46         Stmt::Item(Item::Impl(item)) => replace_item_impl(item, kind)?,
47         Stmt::Item(Item::Use(item)) => replace_item_use(item, kind)?,
48         _ => {}
49     }
50 
51     Ok(stmt.into_token_stream())
52 }
53 
54 struct Context {
55     register: Option<(Ident, usize)>,
56     replaced: bool,
57     kind: ProjKind,
58 }
59 
60 impl Context {
new(kind: ProjKind) -> Self61     fn new(kind: ProjKind) -> Self {
62         Self { register: None, replaced: false, kind }
63     }
64 
update(&mut self, ident: &Ident, len: usize)65     fn update(&mut self, ident: &Ident, len: usize) {
66         self.register.get_or_insert_with(|| (ident.clone(), len));
67     }
68 
compare_paths(&self, ident: &Ident, len: usize) -> bool69     fn compare_paths(&self, ident: &Ident, len: usize) -> bool {
70         match &self.register {
71             Some((i, l)) => *l == len && i == ident,
72             None => false,
73         }
74     }
75 
replace_local(&mut self, local: &mut Local) -> Result<()>76     fn replace_local(&mut self, local: &mut Local) -> Result<()> {
77         if let Some(attr) = local.attrs.find(self.kind.method_name()) {
78             return Err(error!(attr, "duplicate #[{}] attribute", self.kind.method_name()));
79         }
80 
81         if let Some(Expr::Match(expr)) = local.init.as_mut().map(|(_, expr)| &mut **expr) {
82             self.replace_expr_match(expr);
83         }
84 
85         if self.replaced {
86             if is_replaceable(&local.pat, false) {
87                 return Err(error!(
88                     local.pat,
89                     "Both initializer expression and pattern are replaceable, \
90                      you need to split the initializer expression into separate let bindings \
91                      to avoid ambiguity"
92                 ));
93             }
94         } else {
95             self.replace_pat(&mut local.pat, false);
96         }
97 
98         Ok(())
99     }
100 
replace_expr_let(&mut self, expr: &mut ExprLet)101     fn replace_expr_let(&mut self, expr: &mut ExprLet) {
102         self.replace_pat(&mut expr.pat, true)
103     }
104 
replace_expr_match(&mut self, expr: &mut ExprMatch)105     fn replace_expr_match(&mut self, expr: &mut ExprMatch) {
106         expr.arms.iter_mut().for_each(|arm| self.replace_pat(&mut arm.pat, true))
107     }
108 
replace_pat(&mut self, pat: &mut Pat, allow_pat_path: bool)109     fn replace_pat(&mut self, pat: &mut Pat, allow_pat_path: bool) {
110         match pat {
111             Pat::Ident(PatIdent { subpat: Some((_, pat)), .. })
112             | Pat::Reference(PatReference { pat, .. })
113             | Pat::Box(PatBox { pat, .. })
114             | Pat::Type(PatType { pat, .. }) => self.replace_pat(pat, allow_pat_path),
115 
116             Pat::Or(PatOr { cases, .. }) => {
117                 cases.iter_mut().for_each(|pat| self.replace_pat(pat, allow_pat_path))
118             }
119 
120             Pat::Struct(PatStruct { path, .. }) | Pat::TupleStruct(PatTupleStruct { path, .. }) => {
121                 self.replace_path(path)
122             }
123             Pat::Path(PatPath { qself: None, path, .. }) if allow_pat_path => {
124                 self.replace_path(path)
125             }
126             _ => {}
127         }
128     }
129 
replace_path(&mut self, path: &mut Path)130     fn replace_path(&mut self, path: &mut Path) {
131         let len = match path.segments.len() {
132             // 1: struct
133             // 2: enum
134             len @ 1 | len @ 2 => len,
135             // other path
136             _ => return,
137         };
138 
139         if self.register.is_none() || self.compare_paths(&path.segments[0].ident, len) {
140             self.update(&path.segments[0].ident, len);
141             self.replaced = true;
142             replace_ident(&mut path.segments[0].ident, self.kind);
143         }
144     }
145 }
146 
is_replaceable(pat: &Pat, allow_pat_path: bool) -> bool147 fn is_replaceable(pat: &Pat, allow_pat_path: bool) -> bool {
148     match pat {
149         Pat::Ident(PatIdent { subpat: Some((_, pat)), .. })
150         | Pat::Reference(PatReference { pat, .. })
151         | Pat::Box(PatBox { pat, .. })
152         | Pat::Type(PatType { pat, .. }) => is_replaceable(pat, allow_pat_path),
153 
154         Pat::Or(PatOr { cases, .. }) => cases.iter().any(|pat| is_replaceable(pat, allow_pat_path)),
155 
156         Pat::Struct(_) | Pat::TupleStruct(_) => true,
157         Pat::Path(PatPath { qself: None, .. }) => allow_pat_path,
158         _ => false,
159     }
160 }
161 
replace_ident(ident: &mut Ident, kind: ProjKind)162 fn replace_ident(ident: &mut Ident, kind: ProjKind) {
163     *ident = kind.proj_ident(ident);
164 }
165 
replace_item_impl(item: &mut ItemImpl, kind: ProjKind) -> Result<()>166 fn replace_item_impl(item: &mut ItemImpl, kind: ProjKind) -> Result<()> {
167     if let Some(attr) = item.attrs.find(kind.method_name()) {
168         return Err(error!(attr, "duplicate #[{}] attribute", kind.method_name()));
169     }
170 
171     let PathSegment { ident, arguments } = match &mut *item.self_ty {
172         Type::Path(TypePath { qself: None, path }) => path.segments.last_mut().unwrap(),
173         _ => return Ok(()),
174     };
175 
176     replace_ident(ident, kind);
177 
178     let mut lifetime_name = String::from("'pin");
179     determine_lifetime_name(&mut lifetime_name, &mut item.generics);
180     item.items
181         .iter_mut()
182         .filter_map(|i| if let ImplItem::Method(i) = i { Some(i) } else { None })
183         .for_each(|item| determine_lifetime_name(&mut lifetime_name, &mut item.sig.generics));
184     let lifetime = Lifetime::new(&lifetime_name, Span::call_site());
185 
186     insert_lifetime(&mut item.generics, lifetime.clone());
187 
188     match arguments {
189         PathArguments::None => {
190             *arguments = PathArguments::AngleBracketed(parse_quote!(<#lifetime>));
191         }
192         PathArguments::AngleBracketed(args) => {
193             args.args.insert(0, parse_quote!(#lifetime));
194         }
195         PathArguments::Parenthesized(_) => unreachable!(),
196     }
197     Ok(())
198 }
199 
replace_item_fn(item: &mut ItemFn, kind: ProjKind) -> Result<()>200 fn replace_item_fn(item: &mut ItemFn, kind: ProjKind) -> Result<()> {
201     struct FnVisitor(Result<()>);
202 
203     impl FnVisitor {
204         fn visit_stmt(&mut self, node: &mut Stmt) -> Result<()> {
205             match node {
206                 Stmt::Expr(expr) | Stmt::Semi(expr, _) => self.visit_expr(expr),
207                 Stmt::Local(local) => {
208                     visit_mut::visit_local_mut(self, local);
209 
210                     let mut prev = None;
211                     for &kind in &ProjKind::ALL {
212                         if let Some(attr) = local.attrs.find_remove(kind.method_name())? {
213                             if let Some(prev) = prev.replace(kind) {
214                                 return Err(error!(
215                                     attr,
216                                     "attributes `{}` and `{}` are mutually exclusive",
217                                     prev.method_name(),
218                                     kind.method_name(),
219                                 ));
220                             }
221                             Context::new(kind).replace_local(local)?;
222                         }
223                     }
224 
225                     Ok(())
226                 }
227                 // Do not recurse into nested items.
228                 Stmt::Item(_) => Ok(()),
229             }
230         }
231 
232         fn visit_expr(&mut self, node: &mut Expr) -> Result<()> {
233             visit_mut::visit_expr_mut(self, node);
234             match node {
235                 Expr::Match(expr) => {
236                     let mut prev = None;
237                     for &kind in &ProjKind::ALL {
238                         if let Some(attr) = expr.attrs.find_remove(kind.method_name())? {
239                             if let Some(prev) = prev.replace(kind) {
240                                 return Err(error!(
241                                     attr,
242                                     "attributes `{}` and `{}` are mutually exclusive",
243                                     prev.method_name(),
244                                     kind.method_name(),
245                                 ));
246                             }
247                         }
248                     }
249                     if let Some(kind) = prev {
250                         replace_expr(node, kind);
251                     }
252                 }
253                 Expr::If(expr_if) => {
254                     if let Expr::Let(_) = &*expr_if.cond {
255                         let mut prev = None;
256                         for &kind in &ProjKind::ALL {
257                             if let Some(attr) = expr_if.attrs.find_remove(kind.method_name())? {
258                                 if let Some(prev) = prev.replace(kind) {
259                                     return Err(error!(
260                                         attr,
261                                         "attributes `{}` and `{}` are mutually exclusive",
262                                         prev.method_name(),
263                                         kind.method_name(),
264                                     ));
265                                 }
266                             }
267                         }
268                         if let Some(kind) = prev {
269                             replace_expr(node, kind);
270                         }
271                     }
272                 }
273                 _ => {}
274             }
275             Ok(())
276         }
277     }
278 
279     impl VisitMut for FnVisitor {
280         fn visit_stmt_mut(&mut self, node: &mut Stmt) {
281             if self.0.is_err() {
282                 return;
283             }
284             if let Err(e) = self.visit_stmt(node) {
285                 self.0 = Err(e)
286             }
287         }
288 
289         fn visit_expr_mut(&mut self, node: &mut Expr) {
290             if self.0.is_err() {
291                 return;
292             }
293             if let Err(e) = self.visit_expr(node) {
294                 self.0 = Err(e)
295             }
296         }
297 
298         fn visit_item_mut(&mut self, _: &mut Item) {
299             // Do not recurse into nested items.
300         }
301     }
302 
303     if let Some(attr) = item.attrs.find(kind.method_name()) {
304         return Err(error!(attr, "duplicate #[{}] attribute", kind.method_name()));
305     }
306 
307     let mut visitor = FnVisitor(Ok(()));
308     visitor.visit_block_mut(&mut item.block);
309     visitor.0
310 }
311 
replace_item_use(item: &mut ItemUse, kind: ProjKind) -> Result<()>312 fn replace_item_use(item: &mut ItemUse, kind: ProjKind) -> Result<()> {
313     struct UseTreeVisitor {
314         res: Result<()>,
315         kind: ProjKind,
316     }
317 
318     impl VisitMut for UseTreeVisitor {
319         fn visit_use_tree_mut(&mut self, node: &mut UseTree) {
320             if self.res.is_err() {
321                 return;
322             }
323 
324             match node {
325                 // Desugar `use tree::<name>` into `tree::__<name>Projection`.
326                 UseTree::Name(name) => replace_ident(&mut name.ident, self.kind),
327                 UseTree::Glob(glob) => {
328                     self.res = Err(error!(
329                         glob,
330                         "#[{}] attribute may not be used on glob imports",
331                         self.kind.method_name()
332                     ));
333                 }
334                 UseTree::Rename(rename) => {
335                     self.res = Err(error!(
336                         rename,
337                         "#[{}] attribute may not be used on renamed imports",
338                         self.kind.method_name()
339                     ));
340                 }
341                 UseTree::Path(_) | UseTree::Group(_) => visit_mut::visit_use_tree_mut(self, node),
342             }
343         }
344     }
345 
346     if let Some(attr) = item.attrs.find(kind.method_name()) {
347         return Err(error!(attr, "duplicate #[{}] attribute", kind.method_name()));
348     }
349 
350     let mut visitor = UseTreeVisitor { res: Ok(()), kind };
351     visitor.visit_item_use_mut(item);
352     visitor.res
353 }
354