1 use crate::lifetime::CollectLifetimes;
2 use crate::parse::Item;
3 use crate::receiver::{has_self_in_block, has_self_in_sig, mut_pat, ReplaceSelf};
4 use proc_macro2::TokenStream;
5 use quote::{format_ident, quote, quote_spanned, ToTokens};
6 use std::collections::BTreeSet as Set;
7 use syn::punctuated::Punctuated;
8 use syn::visit_mut::{self, VisitMut};
9 use syn::{
10     parse_quote, Attribute, Block, FnArg, GenericParam, Generics, Ident, ImplItem, Lifetime, Pat,
11     PatIdent, Receiver, ReturnType, Signature, Stmt, Token, TraitItem, Type, TypeParamBound,
12     TypePath, WhereClause,
13 };
14 
// Span-aware counterpart of `parse_quote!`: the template is expanded with
// `quote_spanned!` so every produced token carries the given span, and the
// resulting token stream is parsed into whatever syn type the call site
// expects. Panics if the tokens do not parse as that type.
macro_rules! parse_quote_spanned {
    ($sp:expr=> $($tok:tt)*) => {{
        let spanned_tokens = quote_spanned!($sp=> $($tok)*);
        syn::parse2(spanned_tokens).unwrap()
    }};
}
20 
21 impl ToTokens for Item {
to_tokens(&self, tokens: &mut TokenStream)22     fn to_tokens(&self, tokens: &mut TokenStream) {
23         match self {
24             Item::Trait(item) => item.to_tokens(tokens),
25             Item::Impl(item) => item.to_tokens(tokens),
26         }
27     }
28 }
29 
// Describes the item that encloses the `async fn` being expanded. It is
// `Copy`, so `expand` passes it by value into `transform_block` and
// `transform_sig`, which adapt their output to it.
#[derive(Clone, Copy)]
enum Context<'a> {
    // The method is declared inside a trait definition.
    Trait {
        // The trait's own generic parameters.
        generics: &'a Generics,
        // The trait's supertrait bounds; `transform_sig` checks them with
        // `has_bound` to decide whether a `Self: Send`/`Sync` predicate can
        // be left out.
        supertraits: &'a Supertraits,
    },
    // The method is defined inside a trait impl block.
    Impl {
        // The impl's generic parameters, after `expand` has prepended the
        // lifetimes `CollectLifetimes` introduced for elided lifetimes in the
        // self type and trait path.
        impl_generics: &'a Generics,
        // Names of the impl's associated types whose type is `impl Trait`;
        // consulted by `contains_associated_type_impl_trait`.
        associated_type_impl_traits: &'a Set<Ident>,
    },
}
41 
42 impl Context<'_> {
lifetimes<'a>(&'a self, used: &'a [Lifetime]) -> impl Iterator<Item = &'a GenericParam>43     fn lifetimes<'a>(&'a self, used: &'a [Lifetime]) -> impl Iterator<Item = &'a GenericParam> {
44         let generics = match self {
45             Context::Trait { generics, .. } => generics,
46             Context::Impl { impl_generics, .. } => impl_generics,
47         };
48         generics.params.iter().filter(move |param| {
49             if let GenericParam::Lifetime(param) = param {
50                 used.contains(&param.lifetime)
51             } else {
52                 false
53             }
54         })
55     }
56 }
57 
// The `+`-separated bound list stored in a trait's `supertraits` field.
type Supertraits = Punctuated<TypeParamBound, Token![+]>;
59 
expand(input: &mut Item, is_local: bool)60 pub fn expand(input: &mut Item, is_local: bool) {
61     match input {
62         Item::Trait(input) => {
63             let context = Context::Trait {
64                 generics: &input.generics,
65                 supertraits: &input.supertraits,
66             };
67             for inner in &mut input.items {
68                 if let TraitItem::Method(method) = inner {
69                     let sig = &mut method.sig;
70                     if sig.asyncness.is_some() {
71                         let block = &mut method.default;
72                         let mut has_self = has_self_in_sig(sig);
73                         method.attrs.push(parse_quote!(#[must_use]));
74                         if let Some(block) = block {
75                             has_self |= has_self_in_block(block);
76                             transform_block(context, sig, block);
77                             method.attrs.push(lint_suppress_with_body());
78                         } else {
79                             method.attrs.push(lint_suppress_without_body());
80                         }
81                         let has_default = method.default.is_some();
82                         transform_sig(context, sig, has_self, has_default, is_local);
83                     }
84                 }
85             }
86         }
87         Item::Impl(input) => {
88             let mut lifetimes = CollectLifetimes::new("'impl", input.impl_token.span);
89             lifetimes.visit_type_mut(&mut *input.self_ty);
90             lifetimes.visit_path_mut(&mut input.trait_.as_mut().unwrap().1);
91             let params = &input.generics.params;
92             let elided = lifetimes.elided;
93             input.generics.params = parse_quote!(#(#elided,)* #params);
94 
95             let mut associated_type_impl_traits = Set::new();
96             for inner in &input.items {
97                 if let ImplItem::Type(assoc) = inner {
98                     if let Type::ImplTrait(_) = assoc.ty {
99                         associated_type_impl_traits.insert(assoc.ident.clone());
100                     }
101                 }
102             }
103 
104             let context = Context::Impl {
105                 impl_generics: &input.generics,
106                 associated_type_impl_traits: &associated_type_impl_traits,
107             };
108             for inner in &mut input.items {
109                 if let ImplItem::Method(method) = inner {
110                     let sig = &mut method.sig;
111                     if sig.asyncness.is_some() {
112                         let block = &mut method.block;
113                         let has_self = has_self_in_sig(sig) || has_self_in_block(block);
114                         transform_block(context, sig, block);
115                         transform_sig(context, sig, has_self, false, is_local);
116                         method.attrs.push(lint_suppress_with_body());
117                     }
118                 }
119             }
120         }
121     }
122 }
123 
lint_suppress_with_body() -> Attribute124 fn lint_suppress_with_body() -> Attribute {
125     parse_quote! {
126         #[allow(
127             clippy::let_unit_value,
128             clippy::no_effect_underscore_binding,
129             clippy::shadow_same,
130             clippy::type_complexity,
131             clippy::type_repetition_in_bounds,
132             clippy::used_underscore_binding
133         )]
134     }
135 }
136 
lint_suppress_without_body() -> Attribute137 fn lint_suppress_without_body() -> Attribute {
138     parse_quote! {
139         #[allow(
140             clippy::type_complexity,
141             clippy::type_repetition_in_bounds
142         )]
143     }
144 }
145 
// Rewrites the signature of an `async fn` into a plain `fn` returning a boxed
// future bounded by a fresh `'async_trait` lifetime.
//
// Input (a trait method without a default body):
//     async fn f<T>(&self, x: &T) -> Ret;
//
// Output:
//     fn f<'life0, 'life1, 'async_trait, T>(
//         &'life0 self,
//         x: &'life1 T,
//     ) -> Pin<Box<dyn Future<Output = Ret> + Send + 'async_trait>>
//     where
//         T: 'async_trait,
//         'life0: 'async_trait,
//         'life1: 'async_trait,
//         Self: 'async_trait;
//
// The last predicate becomes `Self: Sync + 'async_trait` (or `Send` for a
// non-shared receiver) only for a trait method that has a default body, whose
// trait does not already list that marker as a supertrait, when `is_local` is
// false. With `is_local`, `Send` is also dropped from the returned future.
//
// Parameters:
// - `context`: the enclosing trait or impl; supplies outer lifetimes that
//   need an `'async_trait` bound and the supertraits checked for `Send`/`Sync`.
// - `has_self`: computed by the caller from `has_self_in_sig` (and
//   `has_self_in_block` when there is a body); a `Self: ...` predicate is
//   added only when it is true.
// - `has_default`: whether a trait method has a default body (`expand`
//   passes `false` for impl methods).
// - `is_local`: omit every `Send`/`Sync` requirement.
//
// Panics if `sig` is not `async` (the `asyncness.take().unwrap()` below);
// `expand` only calls this for async methods.
fn transform_sig(
    context: Context,
    sig: &mut Signature,
    has_self: bool,
    has_default: bool,
    is_local: bool,
) {
    // Remove `async`, moving its span onto the `fn` keyword.
    sig.fn_token.span = sig.asyncness.take().unwrap().span;

    // The declared return type (or `()`), captured before the output is
    // replaced below; it becomes the future's `Output`.
    let ret = match &sig.output {
        ReturnType::Default => quote!(()),
        ReturnType::Type(_, ret) => quote!(#ret),
    };

    // Span covering `name(...)` when `Span::join` succeeds, otherwise just the
    // method name.
    let default_span = sig
        .ident
        .span()
        .join(sig.paren_token.span)
        .unwrap_or_else(|| sig.ident.span());

    // Visit the receiver and every parameter type. `lifetimes.explicit`
    // collects lifetimes written in the inputs; `lifetimes.elided` collects
    // the `'life`-prefixed lifetimes introduced for elided ones (compare the
    // example above).
    let mut lifetimes = CollectLifetimes::new("'life", default_span);
    for arg in sig.inputs.iter_mut() {
        match arg {
            FnArg::Receiver(arg) => lifetimes.visit_receiver_mut(arg),
            FnArg::Typed(arg) => lifetimes.visit_type_mut(&mut arg.ty),
        }
    }

    // Each of the method's own type/lifetime parameters, plus each lifetime of
    // the enclosing trait/impl that the inputs name explicitly, must outlive
    // `'async_trait`. Const parameters get no bound.
    for param in sig
        .generics
        .params
        .iter()
        .chain(context.lifetimes(&lifetimes.explicit))
    {
        match param {
            GenericParam::Type(param) => {
                let param = &param.ident;
                let span = param.span();
                where_clause_or_default(&mut sig.generics.where_clause)
                    .predicates
                    .push(parse_quote_spanned!(span=> #param: 'async_trait));
            }
            GenericParam::Lifetime(param) => {
                let param = &param.lifetime;
                let span = param.span();
                where_clause_or_default(&mut sig.generics.where_clause)
                    .predicates
                    .push(parse_quote_spanned!(span=> #param: 'async_trait));
            }
            GenericParam::Const(_) => {}
        }
    }

    // Parameters are added below even if the method had no generics, so make
    // sure the angle brackets exist.
    if sig.generics.lt_token.is_none() {
        sig.generics.lt_token = Some(Token![<](sig.ident.span()));
    }
    if sig.generics.gt_token.is_none() {
        sig.generics.gt_token = Some(Token![>](sig.paren_token.span));
    }

    // Declare each introduced lifetime and require it to outlive
    // `'async_trait`.
    // NOTE(review): these are pushed after any existing type parameters; the
    // lifetimes-first order shown in the example relies on how syn prints
    // `Generics` — confirm before depending on it.
    for elided in lifetimes.elided {
        sig.generics.params.push(parse_quote!(#elided));
        where_clause_or_default(&mut sig.generics.where_clause)
            .predicates
            .push(parse_quote_spanned!(elided.span()=> #elided: 'async_trait));
    }

    // The lifetime that bounds the returned future.
    sig.generics
        .params
        .push(parse_quote_spanned!(default_span=> 'async_trait));

    if has_self {
        // A shared receiver (`&self`, or a parameter named `self` whose type
        // is a non-`mut` reference) selects `Sync`; any other first parameter
        // (or none) selects `Send`.
        let bound_span = sig.ident.span();
        let bound = match sig.inputs.iter().next() {
            Some(FnArg::Receiver(Receiver {
                reference: Some(_),
                mutability: None,
                ..
            })) => Ident::new("Sync", bound_span),
            Some(FnArg::Typed(arg))
                if match (arg.pat.as_ref(), arg.ty.as_ref()) {
                    (Pat::Ident(pat), Type::Reference(ty)) => {
                        pat.ident == "self" && ty.mutability.is_none()
                    }
                    _ => false,
                } =>
            {
                Ident::new("Sync", bound_span)
            }
            _ => Ident::new("Send", bound_span),
        };

        // The marker bound can be left out for impl methods, for trait
        // methods without a default body, and when the trait already lists
        // the marker among its supertraits.
        let assume_bound = match context {
            Context::Trait { supertraits, .. } => !has_default || has_bound(supertraits, &bound),
            Context::Impl { .. } => true,
        };

        let where_clause = where_clause_or_default(&mut sig.generics.where_clause);
        where_clause.predicates.push(if assume_bound || is_local {
            parse_quote_spanned!(bound_span=> Self: 'async_trait)
        } else {
            parse_quote_spanned!(bound_span=> Self: ::core::marker::#bound + 'async_trait)
        });
    }

    // Normalize parameter patterns in the signature:
    // - reference receivers are left unchanged;
    // - by-value receivers lose `mut`;
    // - identifier patterns lose `ref`/`mut`;
    // - any other pattern is replaced by the positional name `__argN`,
    //   prefixed with whatever `mut_pat` returns for the original pattern.
    // When there is a body, `transform_block` (which `expand` calls first)
    // has already re-bound the original patterns inside it.
    for (i, arg) in sig.inputs.iter_mut().enumerate() {
        match arg {
            FnArg::Receiver(Receiver {
                reference: Some(_), ..
            }) => {}
            FnArg::Receiver(arg) => arg.mutability = None,
            FnArg::Typed(arg) => {
                if let Pat::Ident(ident) = &mut *arg.pat {
                    ident.by_ref = None;
                    ident.mutability = None;
                } else {
                    let positional = positional_arg(i, &arg.pat);
                    let m = mut_pat(&mut arg.pat);
                    arg.pat = parse_quote!(#m #positional);
                }
            }
        }
    }

    // Replace the return type with the boxed future; `Send` is required
    // unless `is_local`.
    let ret_span = sig.ident.span();
    let bounds = if is_local {
        quote_spanned!(ret_span=> 'async_trait)
    } else {
        quote_spanned!(ret_span=> ::core::marker::Send + 'async_trait)
    };
    sig.output = parse_quote_spanned! {ret_span=>
        -> ::core::pin::Pin<Box<
            dyn ::core::future::Future<Output = #ret> + #bounds
        >>
    };
}
295 
// Rewrites the body of an `async fn` so that it returns
// `Box::pin(async move { ... })`, with every parameter moved into the async
// block and re-bound there.
//
// Input:
//     async fn f<T>(&self, x: &T, (a, b): (A, B)) -> Ret {
//         self + x + a + b
//     }
//
// Output:
//     Box::pin(async move {
//         if let ::core::option::Option::Some(__ret) = ::core::option::Option::None::<Ret> {
//             return __ret;
//         }
//         let __self = self;
//         let x = x;
//         let (a, b) = __arg2;
//         let __ret: Ret = {
//             __self + x + a + b
//         };
//         #[allow(unreachable_code)]
//         __ret
//     })
//
// A method without a declared return type instead ends with
// `let _: () = { ... };`. In an impl, a return type that mentions
// `Self::Assoc` for an associated `type Assoc = impl Trait` gets the
// bindings and statements with no `Ret` annotation at all.
//
// `expand` calls this before `transform_sig`, so `sig` still holds the
// original parameter patterns; the `__argN` names match the ones
// `transform_sig` later puts into the signature (both use `positional_arg`
// with the parameter's index, counting the receiver). `block` is replaced in
// place.
fn transform_block(context: Context, sig: &mut Signature, block: &mut Block) {
    // A body that is exactly one verbatim `;` item is left untouched.
    // NOTE(review): presumably the parser's representation of a method
    // written without a real body — confirm against parse.rs.
    if let Some(Stmt::Item(syn::Item::Verbatim(item))) = block.stmts.first() {
        if block.stmts.len() == 1 && item.to_string() == ";" {
            return;
        }
    }

    // One `let` per parameter, re-binding it inside the async block:
    // - the receiver, or a typed parameter named `self`, becomes `__self`,
    //   and its span is recorded in `self_span`;
    // - identifier parameters keep their `mut` here (it is stripped from the
    //   signature by `transform_sig`);
    // - other patterns are destructured from the positional `__argN` name.
    let mut self_span = None;
    let decls = sig
        .inputs
        .iter()
        .enumerate()
        .map(|(i, arg)| match arg {
            FnArg::Receiver(Receiver {
                self_token,
                mutability,
                ..
            }) => {
                let ident = Ident::new("__self", self_token.span);
                self_span = Some(self_token.span);
                quote!(let #mutability #ident = #self_token;)
            }
            FnArg::Typed(arg) => {
                if let Pat::Ident(PatIdent {
                    ident, mutability, ..
                }) = &*arg.pat
                {
                    if ident == "self" {
                        self_span = Some(ident.span());
                        let prefixed = Ident::new("__self", ident.span());
                        quote!(let #mutability #prefixed = #ident;)
                    } else {
                        quote!(let #mutability #ident = #ident;)
                    }
                } else {
                    let pat = &arg.pat;
                    let ident = positional_arg(i, pat);
                    quote!(let #pat = #ident;)
                }
            }
        })
        .collect::<Vec<_>>();

    // When a `self` parameter exists, run `ReplaceSelf` over the body,
    // presumably so that `self` references become the `__self` binding
    // created above.
    if let Some(span) = self_span {
        let mut replace_self = ReplaceSelf(span);
        replace_self.visit_block_mut(block);
    }

    let stmts = &block.stmts;
    let let_ret = match &mut sig.output {
        ReturnType::Default => quote_spanned! {block.brace_token.span=>
            #(#decls)*
            let _: () = { #(#stmts)* };
        },
        ReturnType::Type(_, ret) => {
            if contains_associated_type_impl_trait(context, ret) {
                // No `Ret` annotation is written here. Presumably that is
                // because `Self::Assoc` names an `impl Trait` type — confirm.
                if decls.is_empty() {
                    quote!(#(#stmts)*)
                } else {
                    quote!(#(#decls)* { #(#stmts)* })
                }
            } else {
                // The `if let ... None::<Ret>` never matches, so it never
                // runs. It presumably exists so the async block's output type
                // is inferred as `Ret` even when the body returns early.
                // `unreachable_code` is allowed on the tail in case the body
                // diverges.
                quote_spanned! {block.brace_token.span=>
                    if let ::core::option::Option::Some(__ret) = ::core::option::Option::None::<#ret> {
                        return __ret;
                    }
                    #(#decls)*
                    let __ret: #ret = { #(#stmts)* };
                    #[allow(unreachable_code)]
                    __ret
                }
            }
        }
    };
    let box_pin = quote_spanned!(block.brace_token.span=>
        Box::pin(async move { #let_ret })
    );
    block.stmts = parse_quote!(#box_pin);
}
392 
positional_arg(i: usize, pat: &Pat) -> Ident393 fn positional_arg(i: usize, pat: &Pat) -> Ident {
394     use syn::spanned::Spanned;
395     format_ident!("__arg{}", i, span = pat.span())
396 }
397 
has_bound(supertraits: &Supertraits, marker: &Ident) -> bool398 fn has_bound(supertraits: &Supertraits, marker: &Ident) -> bool {
399     for bound in supertraits {
400         if let TypeParamBound::Trait(bound) = bound {
401             if bound.path.is_ident(marker)
402                 || bound.path.segments.len() == 3
403                     && (bound.path.segments[0].ident == "std"
404                         || bound.path.segments[0].ident == "core")
405                     && bound.path.segments[1].ident == "marker"
406                     && bound.path.segments[2].ident == *marker
407             {
408                 return true;
409             }
410         }
411     }
412     false
413 }
414 
contains_associated_type_impl_trait(context: Context, ret: &mut Type) -> bool415 fn contains_associated_type_impl_trait(context: Context, ret: &mut Type) -> bool {
416     struct AssociatedTypeImplTraits<'a> {
417         set: &'a Set<Ident>,
418         contains: bool,
419     }
420 
421     impl<'a> VisitMut for AssociatedTypeImplTraits<'a> {
422         fn visit_type_path_mut(&mut self, ty: &mut TypePath) {
423             if ty.qself.is_none()
424                 && ty.path.segments.len() == 2
425                 && ty.path.segments[0].ident == "Self"
426                 && self.set.contains(&ty.path.segments[1].ident)
427             {
428                 self.contains = true;
429             }
430             visit_mut::visit_type_path_mut(self, ty);
431         }
432     }
433 
434     match context {
435         Context::Trait { .. } => false,
436         Context::Impl {
437             associated_type_impl_traits,
438             ..
439         } => {
440             let mut visit = AssociatedTypeImplTraits {
441                 set: associated_type_impl_traits,
442                 contains: false,
443             };
444             visit.visit_type_mut(ret);
445             visit.contains
446         }
447     }
448 }
449 
where_clause_or_default(clause: &mut Option<WhereClause>) -> &mut WhereClause450 fn where_clause_or_default(clause: &mut Option<WhereClause>) -> &mut WhereClause {
451     clause.get_or_insert_with(|| WhereClause {
452         where_token: Default::default(),
453         predicates: Punctuated::new(),
454     })
455 }
456