1 use proc_macro2::{Span, TokenStream};
2 use quote::ToTokens;
3 use syn::{
4 parse_quote,
5 visit_mut::{self, VisitMut},
6 Expr, ExprLet, ExprMatch, Ident, ImplItem, Item, ItemFn, ItemImpl, ItemUse, Lifetime, Local,
7 Pat, PatBox, PatIdent, PatOr, PatPath, PatReference, PatStruct, PatTupleStruct, PatType, Path,
8 PathArguments, PathSegment, Result, Stmt, Type, TypePath, UseTree,
9 };
10
11 use crate::utils::{
12 determine_lifetime_name, insert_lifetime, parse_as_empty, ProjKind, SliceExt, VecExt,
13 };
14
attribute(args: &TokenStream, input: Stmt, kind: ProjKind) -> TokenStream15 pub(crate) fn attribute(args: &TokenStream, input: Stmt, kind: ProjKind) -> TokenStream {
16 parse_as_empty(args).and_then(|()| parse(input, kind)).unwrap_or_else(|e| e.to_compile_error())
17 }
18
replace_expr(expr: &mut Expr, kind: ProjKind)19 fn replace_expr(expr: &mut Expr, kind: ProjKind) {
20 match expr {
21 Expr::Match(expr) => {
22 Context::new(kind).replace_expr_match(expr);
23 }
24 Expr::If(expr_if) => {
25 let mut expr_if = expr_if;
26 while let Expr::Let(ref mut expr) = &mut *expr_if.cond {
27 Context::new(kind).replace_expr_let(expr);
28 if let Some((_, ref mut expr)) = expr_if.else_branch {
29 if let Expr::If(new_expr_if) = &mut **expr {
30 expr_if = new_expr_if;
31 continue;
32 }
33 }
34 break;
35 }
36 }
37 _ => {}
38 }
39 }
40
parse(mut stmt: Stmt, kind: ProjKind) -> Result<TokenStream>41 fn parse(mut stmt: Stmt, kind: ProjKind) -> Result<TokenStream> {
42 match &mut stmt {
43 Stmt::Expr(expr) | Stmt::Semi(expr, _) => replace_expr(expr, kind),
44 Stmt::Local(local) => Context::new(kind).replace_local(local)?,
45 Stmt::Item(Item::Fn(item)) => replace_item_fn(item, kind)?,
46 Stmt::Item(Item::Impl(item)) => replace_item_impl(item, kind)?,
47 Stmt::Item(Item::Use(item)) => replace_item_use(item, kind)?,
48 _ => {}
49 }
50
51 Ok(stmt.into_token_stream())
52 }
53
54 struct Context {
55 register: Option<(Ident, usize)>,
56 replaced: bool,
57 kind: ProjKind,
58 }
59
60 impl Context {
new(kind: ProjKind) -> Self61 fn new(kind: ProjKind) -> Self {
62 Self { register: None, replaced: false, kind }
63 }
64
update(&mut self, ident: &Ident, len: usize)65 fn update(&mut self, ident: &Ident, len: usize) {
66 self.register.get_or_insert_with(|| (ident.clone(), len));
67 }
68
compare_paths(&self, ident: &Ident, len: usize) -> bool69 fn compare_paths(&self, ident: &Ident, len: usize) -> bool {
70 match &self.register {
71 Some((i, l)) => *l == len && i == ident,
72 None => false,
73 }
74 }
75
replace_local(&mut self, local: &mut Local) -> Result<()>76 fn replace_local(&mut self, local: &mut Local) -> Result<()> {
77 if let Some(attr) = local.attrs.find(self.kind.method_name()) {
78 return Err(error!(attr, "duplicate #[{}] attribute", self.kind.method_name()));
79 }
80
81 if let Some(Expr::Match(expr)) = local.init.as_mut().map(|(_, expr)| &mut **expr) {
82 self.replace_expr_match(expr);
83 }
84
85 if self.replaced {
86 if is_replaceable(&local.pat, false) {
87 return Err(error!(
88 local.pat,
89 "Both initializer expression and pattern are replaceable, \
90 you need to split the initializer expression into separate let bindings \
91 to avoid ambiguity"
92 ));
93 }
94 } else {
95 self.replace_pat(&mut local.pat, false);
96 }
97
98 Ok(())
99 }
100
replace_expr_let(&mut self, expr: &mut ExprLet)101 fn replace_expr_let(&mut self, expr: &mut ExprLet) {
102 self.replace_pat(&mut expr.pat, true)
103 }
104
replace_expr_match(&mut self, expr: &mut ExprMatch)105 fn replace_expr_match(&mut self, expr: &mut ExprMatch) {
106 expr.arms.iter_mut().for_each(|arm| self.replace_pat(&mut arm.pat, true))
107 }
108
replace_pat(&mut self, pat: &mut Pat, allow_pat_path: bool)109 fn replace_pat(&mut self, pat: &mut Pat, allow_pat_path: bool) {
110 match pat {
111 Pat::Ident(PatIdent { subpat: Some((_, pat)), .. })
112 | Pat::Reference(PatReference { pat, .. })
113 | Pat::Box(PatBox { pat, .. })
114 | Pat::Type(PatType { pat, .. }) => self.replace_pat(pat, allow_pat_path),
115
116 Pat::Or(PatOr { cases, .. }) => {
117 cases.iter_mut().for_each(|pat| self.replace_pat(pat, allow_pat_path))
118 }
119
120 Pat::Struct(PatStruct { path, .. }) | Pat::TupleStruct(PatTupleStruct { path, .. }) => {
121 self.replace_path(path)
122 }
123 Pat::Path(PatPath { qself: None, path, .. }) if allow_pat_path => {
124 self.replace_path(path)
125 }
126 _ => {}
127 }
128 }
129
replace_path(&mut self, path: &mut Path)130 fn replace_path(&mut self, path: &mut Path) {
131 let len = match path.segments.len() {
132 // 1: struct
133 // 2: enum
134 len @ 1 | len @ 2 => len,
135 // other path
136 _ => return,
137 };
138
139 if self.register.is_none() || self.compare_paths(&path.segments[0].ident, len) {
140 self.update(&path.segments[0].ident, len);
141 self.replaced = true;
142 replace_ident(&mut path.segments[0].ident, self.kind);
143 }
144 }
145 }
146
is_replaceable(pat: &Pat, allow_pat_path: bool) -> bool147 fn is_replaceable(pat: &Pat, allow_pat_path: bool) -> bool {
148 match pat {
149 Pat::Ident(PatIdent { subpat: Some((_, pat)), .. })
150 | Pat::Reference(PatReference { pat, .. })
151 | Pat::Box(PatBox { pat, .. })
152 | Pat::Type(PatType { pat, .. }) => is_replaceable(pat, allow_pat_path),
153
154 Pat::Or(PatOr { cases, .. }) => cases.iter().any(|pat| is_replaceable(pat, allow_pat_path)),
155
156 Pat::Struct(_) | Pat::TupleStruct(_) => true,
157 Pat::Path(PatPath { qself: None, .. }) => allow_pat_path,
158 _ => false,
159 }
160 }
161
replace_ident(ident: &mut Ident, kind: ProjKind)162 fn replace_ident(ident: &mut Ident, kind: ProjKind) {
163 *ident = kind.proj_ident(ident);
164 }
165
replace_item_impl(item: &mut ItemImpl, kind: ProjKind) -> Result<()>166 fn replace_item_impl(item: &mut ItemImpl, kind: ProjKind) -> Result<()> {
167 if let Some(attr) = item.attrs.find(kind.method_name()) {
168 return Err(error!(attr, "duplicate #[{}] attribute", kind.method_name()));
169 }
170
171 let PathSegment { ident, arguments } = match &mut *item.self_ty {
172 Type::Path(TypePath { qself: None, path }) => path.segments.last_mut().unwrap(),
173 _ => return Ok(()),
174 };
175
176 replace_ident(ident, kind);
177
178 let mut lifetime_name = String::from("'pin");
179 determine_lifetime_name(&mut lifetime_name, &mut item.generics);
180 item.items
181 .iter_mut()
182 .filter_map(|i| if let ImplItem::Method(i) = i { Some(i) } else { None })
183 .for_each(|item| determine_lifetime_name(&mut lifetime_name, &mut item.sig.generics));
184 let lifetime = Lifetime::new(&lifetime_name, Span::call_site());
185
186 insert_lifetime(&mut item.generics, lifetime.clone());
187
188 match arguments {
189 PathArguments::None => {
190 *arguments = PathArguments::AngleBracketed(parse_quote!(<#lifetime>));
191 }
192 PathArguments::AngleBracketed(args) => {
193 args.args.insert(0, parse_quote!(#lifetime));
194 }
195 PathArguments::Parenthesized(_) => unreachable!(),
196 }
197 Ok(())
198 }
199
replace_item_fn(item: &mut ItemFn, kind: ProjKind) -> Result<()>200 fn replace_item_fn(item: &mut ItemFn, kind: ProjKind) -> Result<()> {
201 struct FnVisitor(Result<()>);
202
203 impl FnVisitor {
204 fn visit_stmt(&mut self, node: &mut Stmt) -> Result<()> {
205 match node {
206 Stmt::Expr(expr) | Stmt::Semi(expr, _) => self.visit_expr(expr),
207 Stmt::Local(local) => {
208 visit_mut::visit_local_mut(self, local);
209
210 let mut prev = None;
211 for &kind in &ProjKind::ALL {
212 if let Some(attr) = local.attrs.find_remove(kind.method_name())? {
213 if let Some(prev) = prev.replace(kind) {
214 return Err(error!(
215 attr,
216 "attributes `{}` and `{}` are mutually exclusive",
217 prev.method_name(),
218 kind.method_name(),
219 ));
220 }
221 Context::new(kind).replace_local(local)?;
222 }
223 }
224
225 Ok(())
226 }
227 // Do not recurse into nested items.
228 Stmt::Item(_) => Ok(()),
229 }
230 }
231
232 fn visit_expr(&mut self, node: &mut Expr) -> Result<()> {
233 visit_mut::visit_expr_mut(self, node);
234 match node {
235 Expr::Match(expr) => {
236 let mut prev = None;
237 for &kind in &ProjKind::ALL {
238 if let Some(attr) = expr.attrs.find_remove(kind.method_name())? {
239 if let Some(prev) = prev.replace(kind) {
240 return Err(error!(
241 attr,
242 "attributes `{}` and `{}` are mutually exclusive",
243 prev.method_name(),
244 kind.method_name(),
245 ));
246 }
247 }
248 }
249 if let Some(kind) = prev {
250 replace_expr(node, kind);
251 }
252 }
253 Expr::If(expr_if) => {
254 if let Expr::Let(_) = &*expr_if.cond {
255 let mut prev = None;
256 for &kind in &ProjKind::ALL {
257 if let Some(attr) = expr_if.attrs.find_remove(kind.method_name())? {
258 if let Some(prev) = prev.replace(kind) {
259 return Err(error!(
260 attr,
261 "attributes `{}` and `{}` are mutually exclusive",
262 prev.method_name(),
263 kind.method_name(),
264 ));
265 }
266 }
267 }
268 if let Some(kind) = prev {
269 replace_expr(node, kind);
270 }
271 }
272 }
273 _ => {}
274 }
275 Ok(())
276 }
277 }
278
279 impl VisitMut for FnVisitor {
280 fn visit_stmt_mut(&mut self, node: &mut Stmt) {
281 if self.0.is_err() {
282 return;
283 }
284 if let Err(e) = self.visit_stmt(node) {
285 self.0 = Err(e)
286 }
287 }
288
289 fn visit_expr_mut(&mut self, node: &mut Expr) {
290 if self.0.is_err() {
291 return;
292 }
293 if let Err(e) = self.visit_expr(node) {
294 self.0 = Err(e)
295 }
296 }
297
298 fn visit_item_mut(&mut self, _: &mut Item) {
299 // Do not recurse into nested items.
300 }
301 }
302
303 if let Some(attr) = item.attrs.find(kind.method_name()) {
304 return Err(error!(attr, "duplicate #[{}] attribute", kind.method_name()));
305 }
306
307 let mut visitor = FnVisitor(Ok(()));
308 visitor.visit_block_mut(&mut item.block);
309 visitor.0
310 }
311
replace_item_use(item: &mut ItemUse, kind: ProjKind) -> Result<()>312 fn replace_item_use(item: &mut ItemUse, kind: ProjKind) -> Result<()> {
313 struct UseTreeVisitor {
314 res: Result<()>,
315 kind: ProjKind,
316 }
317
318 impl VisitMut for UseTreeVisitor {
319 fn visit_use_tree_mut(&mut self, node: &mut UseTree) {
320 if self.res.is_err() {
321 return;
322 }
323
324 match node {
325 // Desugar `use tree::<name>` into `tree::__<name>Projection`.
326 UseTree::Name(name) => replace_ident(&mut name.ident, self.kind),
327 UseTree::Glob(glob) => {
328 self.res = Err(error!(
329 glob,
330 "#[{}] attribute may not be used on glob imports",
331 self.kind.method_name()
332 ));
333 }
334 UseTree::Rename(rename) => {
335 self.res = Err(error!(
336 rename,
337 "#[{}] attribute may not be used on renamed imports",
338 self.kind.method_name()
339 ));
340 }
341 UseTree::Path(_) | UseTree::Group(_) => visit_mut::visit_use_tree_mut(self, node),
342 }
343 }
344 }
345
346 if let Some(attr) = item.attrs.find(kind.method_name()) {
347 return Err(error!(attr, "duplicate #[{}] attribute", kind.method_name()));
348 }
349
350 let mut visitor = UseTreeVisitor { res: Ok(()), kind };
351 visitor.visit_item_use_mut(item);
352 visitor.res
353 }
354