1 use proc_macro2::{Span, TokenStream};
2 use quote::ToTokens;
3 use syn::{
4     visit_mut::{self, VisitMut},
5     *,
6 };
7 
8 use crate::utils::*;
9 
attribute(args: &TokenStream, input: Stmt, mutability: Mutability) -> TokenStream10 pub(crate) fn attribute(args: &TokenStream, input: Stmt, mutability: Mutability) -> TokenStream {
11     parse_as_empty(args)
12         .and_then(|()| parse(input, mutability))
13         .unwrap_or_else(|e| e.to_compile_error())
14 }
15 
parse(mut stmt: Stmt, mutability: Mutability) -> Result<TokenStream>16 fn parse(mut stmt: Stmt, mutability: Mutability) -> Result<TokenStream> {
17     match &mut stmt {
18         Stmt::Expr(Expr::Match(expr)) | Stmt::Semi(Expr::Match(expr), _) => {
19             Context::new(mutability).replace_expr_match(expr)
20         }
21         Stmt::Local(local) => Context::new(mutability).replace_local(local)?,
22         Stmt::Item(Item::Fn(item)) => replace_item_fn(item, mutability)?,
23         Stmt::Item(Item::Impl(item)) => replace_item_impl(item, mutability),
24         Stmt::Item(Item::Use(item)) => replace_item_use(item, mutability)?,
25         _ => {}
26     }
27 
28     Ok(stmt.into_token_stream())
29 }
30 
31 struct Context {
32     register: Option<(Ident, usize)>,
33     replaced: bool,
34     mutability: Mutability,
35 }
36 
37 impl Context {
new(mutability: Mutability) -> Self38     fn new(mutability: Mutability) -> Self {
39         Self { register: None, replaced: false, mutability }
40     }
41 
update(&mut self, ident: &Ident, len: usize)42     fn update(&mut self, ident: &Ident, len: usize) {
43         if self.register.is_none() {
44             self.register = Some((ident.clone(), len));
45         }
46     }
47 
compare_paths(&self, ident: &Ident, len: usize) -> bool48     fn compare_paths(&self, ident: &Ident, len: usize) -> bool {
49         match &self.register {
50             Some((i, l)) => *l == len && ident == i,
51             None => false,
52         }
53     }
54 
replace_local(&mut self, local: &mut Local) -> Result<()>55     fn replace_local(&mut self, local: &mut Local) -> Result<()> {
56         if let Some(Expr::Match(expr)) = local.init.as_mut().map(|(_, expr)| &mut **expr) {
57             self.replace_expr_match(expr);
58         }
59 
60         if self.replaced {
61             if is_replaceable(&local.pat, false) {
62                 return Err(error!(
63                     local.pat,
64                     "Both initializer expression and pattern are replaceable, \
65                      you need to split the initializer expression into separate let bindings \
66                      to avoid ambiguity"
67                 ));
68             }
69         } else {
70             self.replace_pat(&mut local.pat, false);
71         }
72 
73         Ok(())
74     }
75 
replace_expr_match(&mut self, expr: &mut ExprMatch)76     fn replace_expr_match(&mut self, expr: &mut ExprMatch) {
77         expr.arms.iter_mut().for_each(|Arm { pat, .. }| self.replace_pat(pat, true))
78     }
79 
replace_pat(&mut self, pat: &mut Pat, allow_pat_path: bool)80     fn replace_pat(&mut self, pat: &mut Pat, allow_pat_path: bool) {
81         match pat {
82             Pat::Ident(PatIdent { subpat: Some((_, pat)), .. })
83             | Pat::Reference(PatReference { pat, .. })
84             | Pat::Box(PatBox { pat, .. })
85             | Pat::Type(PatType { pat, .. }) => self.replace_pat(pat, allow_pat_path),
86 
87             Pat::Or(PatOr { cases, .. }) => {
88                 cases.iter_mut().for_each(|pat| self.replace_pat(pat, allow_pat_path))
89             }
90 
91             Pat::Struct(PatStruct { path, .. }) | Pat::TupleStruct(PatTupleStruct { path, .. }) => {
92                 self.replace_path(path)
93             }
94             Pat::Path(PatPath { qself: None, path, .. }) if allow_pat_path => {
95                 self.replace_path(path)
96             }
97             _ => {}
98         }
99     }
100 
replace_path(&mut self, path: &mut Path)101     fn replace_path(&mut self, path: &mut Path) {
102         let len = match path.segments.len() {
103             // 1: struct
104             // 2: enum
105             len @ 1 | len @ 2 => len,
106             // other path
107             _ => return,
108         };
109 
110         if self.register.is_none() || self.compare_paths(&path.segments[0].ident, len) {
111             self.update(&path.segments[0].ident, len);
112             self.replaced = true;
113             replace_ident(&mut path.segments[0].ident, self.mutability);
114         }
115     }
116 }
117 
is_replaceable(pat: &Pat, allow_pat_path: bool) -> bool118 fn is_replaceable(pat: &Pat, allow_pat_path: bool) -> bool {
119     match pat {
120         Pat::Ident(PatIdent { subpat: Some((_, pat)), .. })
121         | Pat::Reference(PatReference { pat, .. })
122         | Pat::Box(PatBox { pat, .. })
123         | Pat::Type(PatType { pat, .. }) => is_replaceable(pat, allow_pat_path),
124 
125         Pat::Or(PatOr { cases, .. }) => cases.iter().any(|pat| is_replaceable(pat, allow_pat_path)),
126 
127         Pat::Struct(_) | Pat::TupleStruct(_) => true,
128         Pat::Path(PatPath { qself: None, .. }) => allow_pat_path,
129         _ => false,
130     }
131 }
132 
replace_item_impl(item: &mut ItemImpl, mutability: Mutability)133 fn replace_item_impl(item: &mut ItemImpl, mutability: Mutability) {
134     let PathSegment { ident, arguments } = match &mut *item.self_ty {
135         Type::Path(TypePath { qself: None, path }) => path.segments.last_mut().unwrap(),
136         _ => return,
137     };
138 
139     replace_ident(ident, mutability);
140 
141     let mut lifetime_name = String::from(DEFAULT_LIFETIME_NAME);
142     determine_lifetime_name(&mut lifetime_name, &item.generics.params);
143     item.items
144         .iter_mut()
145         .filter_map(|i| if let ImplItem::Method(i) = i { Some(i) } else { None })
146         .for_each(|item| determine_lifetime_name(&mut lifetime_name, &item.sig.generics.params));
147     let lifetime = Lifetime::new(&lifetime_name, Span::call_site());
148 
149     insert_lifetime(&mut item.generics, lifetime.clone());
150 
151     match arguments {
152         PathArguments::None => {
153             *arguments = PathArguments::AngleBracketed(syn::parse_quote!(<#lifetime>));
154         }
155         PathArguments::AngleBracketed(args) => {
156             args.args.insert(0, syn::parse_quote!(#lifetime));
157         }
158         PathArguments::Parenthesized(_) => unreachable!(),
159     }
160 }
161 
replace_item_fn(item: &mut ItemFn, mutability: Mutability) -> Result<()>162 fn replace_item_fn(item: &mut ItemFn, mutability: Mutability) -> Result<()> {
163     let mut visitor = FnVisitor { res: Ok(()), mutability };
164     visitor.visit_block_mut(&mut item.block);
165     visitor.res
166 }
167 
replace_item_use(item: &mut ItemUse, mutability: Mutability) -> Result<()>168 fn replace_item_use(item: &mut ItemUse, mutability: Mutability) -> Result<()> {
169     let mut visitor = UseTreeVisitor { res: Ok(()), mutability };
170     visitor.visit_item_use_mut(item);
171     visitor.res
172 }
173 
replace_ident(ident: &mut Ident, mutability: Mutability)174 fn replace_ident(ident: &mut Ident, mutability: Mutability) {
175     *ident = proj_ident(ident, mutability);
176 }
177 
178 // =================================================================================================
179 // visitors
180 
181 struct FnVisitor {
182     res: Result<()>,
183     mutability: Mutability,
184 }
185 
186 impl FnVisitor {
187     /// Returns the attribute name.
name(&self) -> &str188     fn name(&self) -> &str {
189         if self.mutability == Mutable { "project" } else { "project_ref" }
190     }
191 
visit_stmt(&mut self, node: &mut Stmt) -> Result<()>192     fn visit_stmt(&mut self, node: &mut Stmt) -> Result<()> {
193         let attr = match node {
194             Stmt::Expr(Expr::Match(expr)) | Stmt::Semi(Expr::Match(expr), _) => {
195                 expr.attrs.find_remove(self.name())?
196             }
197             Stmt::Local(local) => local.attrs.find_remove(self.name())?,
198             _ => return Ok(()),
199         };
200         if let Some(attr) = attr {
201             parse_as_empty(&attr.tokens)?;
202             match node {
203                 Stmt::Expr(Expr::Match(expr)) | Stmt::Semi(Expr::Match(expr), _) => {
204                     Context::new(self.mutability).replace_expr_match(expr)
205                 }
206                 Stmt::Local(local) => Context::new(self.mutability).replace_local(local)?,
207                 _ => unreachable!(),
208             }
209         }
210         Ok(())
211     }
212 }
213 
214 impl VisitMut for FnVisitor {
visit_stmt_mut(&mut self, node: &mut Stmt)215     fn visit_stmt_mut(&mut self, node: &mut Stmt) {
216         if self.res.is_err() {
217             return;
218         }
219 
220         visit_mut::visit_stmt_mut(self, node);
221 
222         if let Err(e) = self.visit_stmt(node) {
223             self.res = Err(e)
224         }
225     }
226 
visit_item_mut(&mut self, _: &mut Item)227     fn visit_item_mut(&mut self, _: &mut Item) {
228         // Do not recurse into nested items.
229     }
230 }
231 
232 struct UseTreeVisitor {
233     res: Result<()>,
234     mutability: Mutability,
235 }
236 
237 impl VisitMut for UseTreeVisitor {
visit_use_tree_mut(&mut self, node: &mut UseTree)238     fn visit_use_tree_mut(&mut self, node: &mut UseTree) {
239         if self.res.is_err() {
240             return;
241         }
242 
243         match node {
244             // Desugar `use tree::<name>` into `tree::__<name>Projection`.
245             UseTree::Name(name) => replace_ident(&mut name.ident, self.mutability),
246             UseTree::Glob(glob) => {
247                 self.res =
248                     Err(error!(glob, "#[project] attribute may not be used on glob imports"));
249             }
250             UseTree::Rename(rename) => {
251                 // TODO: Consider allowing the projected type to be renamed by `#[project] use Foo as Bar`.
252                 self.res =
253                     Err(error!(rename, "#[project] attribute may not be used on renamed imports"));
254             }
255             node @ UseTree::Path(_) | node @ UseTree::Group(_) => {
256                 visit_mut::visit_use_tree_mut(self, node)
257             }
258         }
259     }
260 }
261