1 use crate::lifetime::CollectLifetimes;
2 use crate::parse::Item;
3 use crate::receiver::{has_self_in_block, has_self_in_sig, mut_pat, ReplaceSelf};
4 use proc_macro2::TokenStream;
5 use quote::{format_ident, quote, quote_spanned, ToTokens};
6 use std::collections::BTreeSet as Set;
7 use syn::punctuated::Punctuated;
8 use syn::visit_mut::{self, VisitMut};
9 use syn::{
10 parse_quote, Attribute, Block, FnArg, GenericParam, Generics, Ident, ImplItem, Lifetime, Pat,
11 PatIdent, Receiver, ReturnType, Signature, Stmt, Token, TraitItem, Type, TypeParamBound,
12 TypePath, WhereClause,
13 };
14
/// Like `syn::parse_quote!`, but the tokens written inside the template are
/// given the supplied span (via `quote_spanned!`) before being parsed.
///
/// The target type is inferred from the call site. Panics if the generated
/// tokens do not parse as that type.
macro_rules! parse_quote_spanned {
    ($span:expr => $($tokens:tt)*) => {{
        let spanned_tokens = quote_spanned!($span => $($tokens)*);
        syn::parse2(spanned_tokens).unwrap()
    }};
}
20
21 impl ToTokens for Item {
to_tokens(&self, tokens: &mut TokenStream)22 fn to_tokens(&self, tokens: &mut TokenStream) {
23 match self {
24 Item::Trait(item) => item.to_tokens(tokens),
25 Item::Impl(item) => item.to_tokens(tokens),
26 }
27 }
28 }
29
30 #[derive(Clone, Copy)]
31 enum Context<'a> {
32 Trait {
33 generics: &'a Generics,
34 supertraits: &'a Supertraits,
35 },
36 Impl {
37 impl_generics: &'a Generics,
38 associated_type_impl_traits: &'a Set<Ident>,
39 },
40 }
41
42 impl Context<'_> {
lifetimes<'a>(&'a self, used: &'a [Lifetime]) -> impl Iterator<Item = &'a GenericParam>43 fn lifetimes<'a>(&'a self, used: &'a [Lifetime]) -> impl Iterator<Item = &'a GenericParam> {
44 let generics = match self {
45 Context::Trait { generics, .. } => generics,
46 Context::Impl { impl_generics, .. } => impl_generics,
47 };
48 generics.params.iter().filter(move |param| {
49 if let GenericParam::Lifetime(param) = param {
50 used.contains(¶m.lifetime)
51 } else {
52 false
53 }
54 })
55 }
56 }
57
58 type Supertraits = Punctuated<TypeParamBound, Token![+]>;
59
/// Rewrites every `async fn` in a trait definition or trait impl so that it
/// returns a boxed future instead.
///
/// Trait: each `async` method gets `#[must_use]` plus a clippy `#[allow(...)]`
/// attribute; its default body (if any) is rewritten by `transform_block` and
/// its signature by `transform_sig`.
///
/// Impl: lifetimes collected by `CollectLifetimes` (prefix `'impl`) from the
/// self type and trait path are prepended to the impl's generic parameters,
/// associated types defined as `impl Trait` are recorded, and then every
/// `async` method's body and signature are rewritten.
///
/// `is_local` is forwarded to `transform_sig`, where it drops the `Send`
/// requirements from the generated future type and the `Self` bound.
pub fn expand(input: &mut Item, is_local: bool) {
    match input {
        Item::Trait(input) => {
            let context = Context::Trait {
                generics: &input.generics,
                supertraits: &input.supertraits,
            };
            for inner in &mut input.items {
                if let TraitItem::Method(method) = inner {
                    let sig = &mut method.sig;
                    if sig.asyncness.is_some() {
                        let block = &mut method.default;
                        // Computed before `transform_block`, which rewrites the body
                        // (via `ReplaceSelf`).
                        let mut has_self = has_self_in_sig(sig);
                        method.attrs.push(parse_quote!(#[must_use]));
                        if let Some(block) = block {
                            has_self |= has_self_in_block(block);
                            transform_block(context, sig, block);
                            method.attrs.push(lint_suppress_with_body());
                        } else {
                            method.attrs.push(lint_suppress_without_body());
                        }
                        let has_default = method.default.is_some();
                        // Must run after `transform_block`: that function reads the
                        // argument patterns / `mut` markers which this call strips.
                        transform_sig(context, sig, has_self, has_default, is_local);
                    }
                }
            }
        }
        Item::Impl(input) => {
            let mut lifetimes = CollectLifetimes::new("'impl", input.impl_token.span);
            lifetimes.visit_type_mut(&mut *input.self_ty);
            // NOTE(review): panics on an inherent impl (`trait_` is None); presumably
            // the parser only accepts trait impls — confirm.
            lifetimes.visit_path_mut(&mut input.trait_.as_mut().unwrap().1);
            let params = &input.generics.params;
            let elided = lifetimes.elided;
            // Newly named lifetimes go before the impl's existing generic parameters.
            input.generics.params = parse_quote!(#(#elided,)* #params);

            // Associated types defined as `impl Trait`; consulted by `transform_block`.
            let mut associated_type_impl_traits = Set::new();
            for inner in &input.items {
                if let ImplItem::Type(assoc) = inner {
                    if let Type::ImplTrait(_) = assoc.ty {
                        associated_type_impl_traits.insert(assoc.ident.clone());
                    }
                }
            }

            let context = Context::Impl {
                impl_generics: &input.generics,
                associated_type_impl_traits: &associated_type_impl_traits,
            };
            for inner in &mut input.items {
                if let ImplItem::Method(method) = inner {
                    let sig = &mut method.sig;
                    if sig.asyncness.is_some() {
                        let block = &mut method.block;
                        // Evaluated before either transform mutates `sig` or `block`.
                        let has_self = has_self_in_sig(sig) || has_self_in_block(block);
                        transform_block(context, sig, block);
                        transform_sig(context, sig, has_self, false, is_local);
                        method.attrs.push(lint_suppress_with_body());
                    }
                }
            }
        }
    }
}
123
lint_suppress_with_body() -> Attribute124 fn lint_suppress_with_body() -> Attribute {
125 parse_quote! {
126 #[allow(
127 clippy::let_unit_value,
128 clippy::no_effect_underscore_binding,
129 clippy::shadow_same,
130 clippy::type_complexity,
131 clippy::type_repetition_in_bounds,
132 clippy::used_underscore_binding
133 )]
134 }
135 }
136
lint_suppress_without_body() -> Attribute137 fn lint_suppress_without_body() -> Attribute {
138 parse_quote! {
139 #[allow(
140 clippy::type_complexity,
141 clippy::type_repetition_in_bounds
142 )]
143 }
144 }
145
146 // Input:
147 // async fn f<T>(&self, x: &T) -> Ret;
148 //
149 // Output:
150 // fn f<'life0, 'life1, 'async_trait, T>(
151 // &'life0 self,
152 // x: &'life1 T,
153 // ) -> Pin<Box<dyn Future<Output = Ret> + Send + 'async_trait>>
154 // where
155 // 'life0: 'async_trait,
156 // 'life1: 'async_trait,
157 // T: 'async_trait,
158 // Self: Sync + 'async_trait;
transform_sig( context: Context, sig: &mut Signature, has_self: bool, has_default: bool, is_local: bool, )159 fn transform_sig(
160 context: Context,
161 sig: &mut Signature,
162 has_self: bool,
163 has_default: bool,
164 is_local: bool,
165 ) {
166 sig.fn_token.span = sig.asyncness.take().unwrap().span;
167
168 let ret = match &sig.output {
169 ReturnType::Default => quote!(()),
170 ReturnType::Type(_, ret) => quote!(#ret),
171 };
172
173 let default_span = sig
174 .ident
175 .span()
176 .join(sig.paren_token.span)
177 .unwrap_or_else(|| sig.ident.span());
178
179 let mut lifetimes = CollectLifetimes::new("'life", default_span);
180 for arg in sig.inputs.iter_mut() {
181 match arg {
182 FnArg::Receiver(arg) => lifetimes.visit_receiver_mut(arg),
183 FnArg::Typed(arg) => lifetimes.visit_type_mut(&mut arg.ty),
184 }
185 }
186
187 for param in sig
188 .generics
189 .params
190 .iter()
191 .chain(context.lifetimes(&lifetimes.explicit))
192 {
193 match param {
194 GenericParam::Type(param) => {
195 let param = ¶m.ident;
196 let span = param.span();
197 where_clause_or_default(&mut sig.generics.where_clause)
198 .predicates
199 .push(parse_quote_spanned!(span=> #param: 'async_trait));
200 }
201 GenericParam::Lifetime(param) => {
202 let param = ¶m.lifetime;
203 let span = param.span();
204 where_clause_or_default(&mut sig.generics.where_clause)
205 .predicates
206 .push(parse_quote_spanned!(span=> #param: 'async_trait));
207 }
208 GenericParam::Const(_) => {}
209 }
210 }
211
212 if sig.generics.lt_token.is_none() {
213 sig.generics.lt_token = Some(Token![<](sig.ident.span()));
214 }
215 if sig.generics.gt_token.is_none() {
216 sig.generics.gt_token = Some(Token![>](sig.paren_token.span));
217 }
218
219 for elided in lifetimes.elided {
220 sig.generics.params.push(parse_quote!(#elided));
221 where_clause_or_default(&mut sig.generics.where_clause)
222 .predicates
223 .push(parse_quote_spanned!(elided.span()=> #elided: 'async_trait));
224 }
225
226 sig.generics
227 .params
228 .push(parse_quote_spanned!(default_span=> 'async_trait));
229
230 if has_self {
231 let bound_span = sig.ident.span();
232 let bound = match sig.inputs.iter().next() {
233 Some(FnArg::Receiver(Receiver {
234 reference: Some(_),
235 mutability: None,
236 ..
237 })) => Ident::new("Sync", bound_span),
238 Some(FnArg::Typed(arg))
239 if match (arg.pat.as_ref(), arg.ty.as_ref()) {
240 (Pat::Ident(pat), Type::Reference(ty)) => {
241 pat.ident == "self" && ty.mutability.is_none()
242 }
243 _ => false,
244 } =>
245 {
246 Ident::new("Sync", bound_span)
247 }
248 _ => Ident::new("Send", bound_span),
249 };
250
251 let assume_bound = match context {
252 Context::Trait { supertraits, .. } => !has_default || has_bound(supertraits, &bound),
253 Context::Impl { .. } => true,
254 };
255
256 let where_clause = where_clause_or_default(&mut sig.generics.where_clause);
257 where_clause.predicates.push(if assume_bound || is_local {
258 parse_quote_spanned!(bound_span=> Self: 'async_trait)
259 } else {
260 parse_quote_spanned!(bound_span=> Self: ::core::marker::#bound + 'async_trait)
261 });
262 }
263
264 for (i, arg) in sig.inputs.iter_mut().enumerate() {
265 match arg {
266 FnArg::Receiver(Receiver {
267 reference: Some(_), ..
268 }) => {}
269 FnArg::Receiver(arg) => arg.mutability = None,
270 FnArg::Typed(arg) => {
271 if let Pat::Ident(ident) = &mut *arg.pat {
272 ident.by_ref = None;
273 ident.mutability = None;
274 } else {
275 let positional = positional_arg(i, &arg.pat);
276 let m = mut_pat(&mut arg.pat);
277 arg.pat = parse_quote!(#m #positional);
278 }
279 }
280 }
281 }
282
283 let ret_span = sig.ident.span();
284 let bounds = if is_local {
285 quote_spanned!(ret_span=> 'async_trait)
286 } else {
287 quote_spanned!(ret_span=> ::core::marker::Send + 'async_trait)
288 };
289 sig.output = parse_quote_spanned! {ret_span=>
290 -> ::core::pin::Pin<Box<
291 dyn ::core::future::Future<Output = #ret> + #bounds
292 >>
293 };
294 }
295
296 // Input:
297 // async fn f<T>(&self, x: &T, (a, b): (A, B)) -> Ret {
298 // self + x + a + b
299 // }
300 //
301 // Output:
302 // Box::pin(async move {
303 // let ___ret: Ret = {
304 // let __self = self;
305 // let x = x;
306 // let (a, b) = __arg1;
307 //
308 // __self + x + a + b
309 // };
310 //
311 // ___ret
312 // })
transform_block(context: Context, sig: &mut Signature, block: &mut Block)313 fn transform_block(context: Context, sig: &mut Signature, block: &mut Block) {
314 if let Some(Stmt::Item(syn::Item::Verbatim(item))) = block.stmts.first() {
315 if block.stmts.len() == 1 && item.to_string() == ";" {
316 return;
317 }
318 }
319
320 let mut self_span = None;
321 let decls = sig
322 .inputs
323 .iter()
324 .enumerate()
325 .map(|(i, arg)| match arg {
326 FnArg::Receiver(Receiver {
327 self_token,
328 mutability,
329 ..
330 }) => {
331 let ident = Ident::new("__self", self_token.span);
332 self_span = Some(self_token.span);
333 quote!(let #mutability #ident = #self_token;)
334 }
335 FnArg::Typed(arg) => {
336 if let Pat::Ident(PatIdent {
337 ident, mutability, ..
338 }) = &*arg.pat
339 {
340 if ident == "self" {
341 self_span = Some(ident.span());
342 let prefixed = Ident::new("__self", ident.span());
343 quote!(let #mutability #prefixed = #ident;)
344 } else {
345 quote!(let #mutability #ident = #ident;)
346 }
347 } else {
348 let pat = &arg.pat;
349 let ident = positional_arg(i, pat);
350 quote!(let #pat = #ident;)
351 }
352 }
353 })
354 .collect::<Vec<_>>();
355
356 if let Some(span) = self_span {
357 let mut replace_self = ReplaceSelf(span);
358 replace_self.visit_block_mut(block);
359 }
360
361 let stmts = &block.stmts;
362 let let_ret = match &mut sig.output {
363 ReturnType::Default => quote_spanned! {block.brace_token.span=>
364 #(#decls)*
365 let _: () = { #(#stmts)* };
366 },
367 ReturnType::Type(_, ret) => {
368 if contains_associated_type_impl_trait(context, ret) {
369 if decls.is_empty() {
370 quote!(#(#stmts)*)
371 } else {
372 quote!(#(#decls)* { #(#stmts)* })
373 }
374 } else {
375 quote_spanned! {block.brace_token.span=>
376 if let ::core::option::Option::Some(__ret) = ::core::option::Option::None::<#ret> {
377 return __ret;
378 }
379 #(#decls)*
380 let __ret: #ret = { #(#stmts)* };
381 #[allow(unreachable_code)]
382 __ret
383 }
384 }
385 }
386 };
387 let box_pin = quote_spanned!(block.brace_token.span=>
388 Box::pin(async move { #let_ret })
389 );
390 block.stmts = parse_quote!(#box_pin);
391 }
392
positional_arg(i: usize, pat: &Pat) -> Ident393 fn positional_arg(i: usize, pat: &Pat) -> Ident {
394 use syn::spanned::Spanned;
395 format_ident!("__arg{}", i, span = pat.span())
396 }
397
has_bound(supertraits: &Supertraits, marker: &Ident) -> bool398 fn has_bound(supertraits: &Supertraits, marker: &Ident) -> bool {
399 for bound in supertraits {
400 if let TypeParamBound::Trait(bound) = bound {
401 if bound.path.is_ident(marker)
402 || bound.path.segments.len() == 3
403 && (bound.path.segments[0].ident == "std"
404 || bound.path.segments[0].ident == "core")
405 && bound.path.segments[1].ident == "marker"
406 && bound.path.segments[2].ident == *marker
407 {
408 return true;
409 }
410 }
411 }
412 false
413 }
414
contains_associated_type_impl_trait(context: Context, ret: &mut Type) -> bool415 fn contains_associated_type_impl_trait(context: Context, ret: &mut Type) -> bool {
416 struct AssociatedTypeImplTraits<'a> {
417 set: &'a Set<Ident>,
418 contains: bool,
419 }
420
421 impl<'a> VisitMut for AssociatedTypeImplTraits<'a> {
422 fn visit_type_path_mut(&mut self, ty: &mut TypePath) {
423 if ty.qself.is_none()
424 && ty.path.segments.len() == 2
425 && ty.path.segments[0].ident == "Self"
426 && self.set.contains(&ty.path.segments[1].ident)
427 {
428 self.contains = true;
429 }
430 visit_mut::visit_type_path_mut(self, ty);
431 }
432 }
433
434 match context {
435 Context::Trait { .. } => false,
436 Context::Impl {
437 associated_type_impl_traits,
438 ..
439 } => {
440 let mut visit = AssociatedTypeImplTraits {
441 set: associated_type_impl_traits,
442 contains: false,
443 };
444 visit.visit_type_mut(ret);
445 visit.contains
446 }
447 }
448 }
449
where_clause_or_default(clause: &mut Option<WhereClause>) -> &mut WhereClause450 fn where_clause_or_default(clause: &mut Option<WhereClause>) -> &mut WhereClause {
451 clause.get_or_insert_with(|| WhereClause {
452 where_token: Default::default(),
453 predicates: Punctuated::new(),
454 })
455 }
456