1 use proc_macro2::{Span, TokenStream};
2 use quote::ToTokens;
3 use syn::{
4 visit_mut::{self, VisitMut},
5 *,
6 };
7
8 use crate::utils::*;
9
attribute(args: &TokenStream, input: Stmt, mutability: Mutability) -> TokenStream10 pub(crate) fn attribute(args: &TokenStream, input: Stmt, mutability: Mutability) -> TokenStream {
11 parse_as_empty(args)
12 .and_then(|()| parse(input, mutability))
13 .unwrap_or_else(|e| e.to_compile_error())
14 }
15
parse(mut stmt: Stmt, mutability: Mutability) -> Result<TokenStream>16 fn parse(mut stmt: Stmt, mutability: Mutability) -> Result<TokenStream> {
17 match &mut stmt {
18 Stmt::Expr(Expr::Match(expr)) | Stmt::Semi(Expr::Match(expr), _) => {
19 Context::new(mutability).replace_expr_match(expr)
20 }
21 Stmt::Local(local) => Context::new(mutability).replace_local(local)?,
22 Stmt::Item(Item::Fn(item)) => replace_item_fn(item, mutability)?,
23 Stmt::Item(Item::Impl(item)) => replace_item_impl(item, mutability),
24 Stmt::Item(Item::Use(item)) => replace_item_use(item, mutability)?,
25 _ => {}
26 }
27
28 Ok(stmt.into_token_stream())
29 }
30
31 struct Context {
32 register: Option<(Ident, usize)>,
33 replaced: bool,
34 mutability: Mutability,
35 }
36
37 impl Context {
new(mutability: Mutability) -> Self38 fn new(mutability: Mutability) -> Self {
39 Self { register: None, replaced: false, mutability }
40 }
41
update(&mut self, ident: &Ident, len: usize)42 fn update(&mut self, ident: &Ident, len: usize) {
43 if self.register.is_none() {
44 self.register = Some((ident.clone(), len));
45 }
46 }
47
compare_paths(&self, ident: &Ident, len: usize) -> bool48 fn compare_paths(&self, ident: &Ident, len: usize) -> bool {
49 match &self.register {
50 Some((i, l)) => *l == len && ident == i,
51 None => false,
52 }
53 }
54
replace_local(&mut self, local: &mut Local) -> Result<()>55 fn replace_local(&mut self, local: &mut Local) -> Result<()> {
56 if let Some(Expr::Match(expr)) = local.init.as_mut().map(|(_, expr)| &mut **expr) {
57 self.replace_expr_match(expr);
58 }
59
60 if self.replaced {
61 if is_replaceable(&local.pat, false) {
62 return Err(error!(
63 local.pat,
64 "Both initializer expression and pattern are replaceable, \
65 you need to split the initializer expression into separate let bindings \
66 to avoid ambiguity"
67 ));
68 }
69 } else {
70 self.replace_pat(&mut local.pat, false);
71 }
72
73 Ok(())
74 }
75
replace_expr_match(&mut self, expr: &mut ExprMatch)76 fn replace_expr_match(&mut self, expr: &mut ExprMatch) {
77 expr.arms.iter_mut().for_each(|Arm { pat, .. }| self.replace_pat(pat, true))
78 }
79
replace_pat(&mut self, pat: &mut Pat, allow_pat_path: bool)80 fn replace_pat(&mut self, pat: &mut Pat, allow_pat_path: bool) {
81 match pat {
82 Pat::Ident(PatIdent { subpat: Some((_, pat)), .. })
83 | Pat::Reference(PatReference { pat, .. })
84 | Pat::Box(PatBox { pat, .. })
85 | Pat::Type(PatType { pat, .. }) => self.replace_pat(pat, allow_pat_path),
86
87 Pat::Or(PatOr { cases, .. }) => {
88 cases.iter_mut().for_each(|pat| self.replace_pat(pat, allow_pat_path))
89 }
90
91 Pat::Struct(PatStruct { path, .. }) | Pat::TupleStruct(PatTupleStruct { path, .. }) => {
92 self.replace_path(path)
93 }
94 Pat::Path(PatPath { qself: None, path, .. }) if allow_pat_path => {
95 self.replace_path(path)
96 }
97 _ => {}
98 }
99 }
100
replace_path(&mut self, path: &mut Path)101 fn replace_path(&mut self, path: &mut Path) {
102 let len = match path.segments.len() {
103 // 1: struct
104 // 2: enum
105 len @ 1 | len @ 2 => len,
106 // other path
107 _ => return,
108 };
109
110 if self.register.is_none() || self.compare_paths(&path.segments[0].ident, len) {
111 self.update(&path.segments[0].ident, len);
112 self.replaced = true;
113 replace_ident(&mut path.segments[0].ident, self.mutability);
114 }
115 }
116 }
117
is_replaceable(pat: &Pat, allow_pat_path: bool) -> bool118 fn is_replaceable(pat: &Pat, allow_pat_path: bool) -> bool {
119 match pat {
120 Pat::Ident(PatIdent { subpat: Some((_, pat)), .. })
121 | Pat::Reference(PatReference { pat, .. })
122 | Pat::Box(PatBox { pat, .. })
123 | Pat::Type(PatType { pat, .. }) => is_replaceable(pat, allow_pat_path),
124
125 Pat::Or(PatOr { cases, .. }) => cases.iter().any(|pat| is_replaceable(pat, allow_pat_path)),
126
127 Pat::Struct(_) | Pat::TupleStruct(_) => true,
128 Pat::Path(PatPath { qself: None, .. }) => allow_pat_path,
129 _ => false,
130 }
131 }
132
replace_item_impl(item: &mut ItemImpl, mutability: Mutability)133 fn replace_item_impl(item: &mut ItemImpl, mutability: Mutability) {
134 let PathSegment { ident, arguments } = match &mut *item.self_ty {
135 Type::Path(TypePath { qself: None, path }) => path.segments.last_mut().unwrap(),
136 _ => return,
137 };
138
139 replace_ident(ident, mutability);
140
141 let mut lifetime_name = String::from(DEFAULT_LIFETIME_NAME);
142 determine_lifetime_name(&mut lifetime_name, &item.generics.params);
143 item.items
144 .iter_mut()
145 .filter_map(|i| if let ImplItem::Method(i) = i { Some(i) } else { None })
146 .for_each(|item| determine_lifetime_name(&mut lifetime_name, &item.sig.generics.params));
147 let lifetime = Lifetime::new(&lifetime_name, Span::call_site());
148
149 insert_lifetime(&mut item.generics, lifetime.clone());
150
151 match arguments {
152 PathArguments::None => {
153 *arguments = PathArguments::AngleBracketed(syn::parse_quote!(<#lifetime>));
154 }
155 PathArguments::AngleBracketed(args) => {
156 args.args.insert(0, syn::parse_quote!(#lifetime));
157 }
158 PathArguments::Parenthesized(_) => unreachable!(),
159 }
160 }
161
replace_item_fn(item: &mut ItemFn, mutability: Mutability) -> Result<()>162 fn replace_item_fn(item: &mut ItemFn, mutability: Mutability) -> Result<()> {
163 let mut visitor = FnVisitor { res: Ok(()), mutability };
164 visitor.visit_block_mut(&mut item.block);
165 visitor.res
166 }
167
replace_item_use(item: &mut ItemUse, mutability: Mutability) -> Result<()>168 fn replace_item_use(item: &mut ItemUse, mutability: Mutability) -> Result<()> {
169 let mut visitor = UseTreeVisitor { res: Ok(()), mutability };
170 visitor.visit_item_use_mut(item);
171 visitor.res
172 }
173
replace_ident(ident: &mut Ident, mutability: Mutability)174 fn replace_ident(ident: &mut Ident, mutability: Mutability) {
175 *ident = proj_ident(ident, mutability);
176 }
177
178 // =================================================================================================
179 // visitors
180
181 struct FnVisitor {
182 res: Result<()>,
183 mutability: Mutability,
184 }
185
186 impl FnVisitor {
187 /// Returns the attribute name.
name(&self) -> &str188 fn name(&self) -> &str {
189 if self.mutability == Mutable { "project" } else { "project_ref" }
190 }
191
visit_stmt(&mut self, node: &mut Stmt) -> Result<()>192 fn visit_stmt(&mut self, node: &mut Stmt) -> Result<()> {
193 let attr = match node {
194 Stmt::Expr(Expr::Match(expr)) | Stmt::Semi(Expr::Match(expr), _) => {
195 expr.attrs.find_remove(self.name())?
196 }
197 Stmt::Local(local) => local.attrs.find_remove(self.name())?,
198 _ => return Ok(()),
199 };
200 if let Some(attr) = attr {
201 parse_as_empty(&attr.tokens)?;
202 match node {
203 Stmt::Expr(Expr::Match(expr)) | Stmt::Semi(Expr::Match(expr), _) => {
204 Context::new(self.mutability).replace_expr_match(expr)
205 }
206 Stmt::Local(local) => Context::new(self.mutability).replace_local(local)?,
207 _ => unreachable!(),
208 }
209 }
210 Ok(())
211 }
212 }
213
214 impl VisitMut for FnVisitor {
visit_stmt_mut(&mut self, node: &mut Stmt)215 fn visit_stmt_mut(&mut self, node: &mut Stmt) {
216 if self.res.is_err() {
217 return;
218 }
219
220 visit_mut::visit_stmt_mut(self, node);
221
222 if let Err(e) = self.visit_stmt(node) {
223 self.res = Err(e)
224 }
225 }
226
visit_item_mut(&mut self, _: &mut Item)227 fn visit_item_mut(&mut self, _: &mut Item) {
228 // Do not recurse into nested items.
229 }
230 }
231
232 struct UseTreeVisitor {
233 res: Result<()>,
234 mutability: Mutability,
235 }
236
237 impl VisitMut for UseTreeVisitor {
visit_use_tree_mut(&mut self, node: &mut UseTree)238 fn visit_use_tree_mut(&mut self, node: &mut UseTree) {
239 if self.res.is_err() {
240 return;
241 }
242
243 match node {
244 // Desugar `use tree::<name>` into `tree::__<name>Projection`.
245 UseTree::Name(name) => replace_ident(&mut name.ident, self.mutability),
246 UseTree::Glob(glob) => {
247 self.res =
248 Err(error!(glob, "#[project] attribute may not be used on glob imports"));
249 }
250 UseTree::Rename(rename) => {
251 // TODO: Consider allowing the projected type to be renamed by `#[project] use Foo as Bar`.
252 self.res =
253 Err(error!(rename, "#[project] attribute may not be used on renamed imports"));
254 }
255 node @ UseTree::Path(_) | node @ UseTree::Group(_) => {
256 visit_mut::visit_use_tree_mut(self, node)
257 }
258 }
259 }
260 }
261