1 //! A procedural macro attribute for instrumenting functions with [`tracing`].
2 //!
3 //! [`tracing`] is a framework for instrumenting Rust programs to collect
4 //! structured, event-based diagnostic information. This crate provides the
5 //! [`#[instrument]`][instrument] procedural macro attribute.
6 //!
7 //! Note that this macro is also re-exported by the main `tracing` crate.
8 //!
9 //! *Compiler support: [requires `rustc` 1.40+][msrv]*
10 //!
11 //! [msrv]: #supported-rust-versions
12 //!
13 //! ## Usage
14 //!
15 //! First, add this to your `Cargo.toml`:
16 //!
17 //! ```toml
18 //! [dependencies]
19 //! tracing-attributes = "0.1.11"
20 //! ```
21 //!
//! The [`#[instrument]`][instrument] attribute can now be added to a function
//! to automatically create and enter a `tracing` [span] when that function is
//! called. For example:
25 //!
26 //! ```
27 //! use tracing_attributes::instrument;
28 //!
29 //! #[instrument]
30 //! pub fn my_function(my_arg: usize) {
31 //!     // ...
32 //! }
33 //!
34 //! # fn main() {}
35 //! ```
36 //!
37 //! [`tracing`]: https://crates.io/crates/tracing
38 //! [span]: https://docs.rs/tracing/latest/tracing/span/index.html
39 //! [instrument]: attr.instrument.html
40 //!
41 //! ## Supported Rust Versions
42 //!
43 //! Tracing is built against the latest stable release. The minimum supported
44 //! version is 1.40. The current Tracing version is not guaranteed to build on
45 //! Rust versions earlier than the minimum supported version.
46 //!
47 //! Tracing follows the same compiler support policies as the rest of the Tokio
48 //! project. The current stable Rust compiler and the three most recent minor
49 //! versions before it will always be supported. For example, if the current
50 //! stable compiler version is 1.45, the minimum supported version will not be
51 //! increased past 1.42, three minor versions prior. Increasing the minimum
52 //! supported compiler version is not considered a semver breaking change as
53 //! long as doing so complies with this policy.
54 //!
55 #![doc(html_root_url = "https://docs.rs/tracing-attributes/0.1.11")]
56 #![doc(
57     html_logo_url = "https://raw.githubusercontent.com/tokio-rs/tracing/master/assets/logo.svg",
58     issue_tracker_base_url = "https://github.com/tokio-rs/tracing/issues/"
59 )]
60 #![warn(
61     missing_debug_implementations,
62     missing_docs,
63     rust_2018_idioms,
64     unreachable_pub,
65     bad_style,
66     const_err,
67     dead_code,
68     improper_ctypes,
69     non_shorthand_field_patterns,
70     no_mangle_generic_items,
71     overflowing_literals,
72     path_statements,
73     patterns_in_fns_without_body,
74     private_in_public,
75     unconditional_recursion,
76     unused,
77     unused_allocation,
78     unused_comparisons,
79     unused_parens,
80     while_true
81 )]
82 // TODO: once `tracing` bumps its MSRV to 1.42, remove this allow.
83 #![allow(unused)]
84 extern crate proc_macro;
85 
86 use std::collections::{HashMap, HashSet};
87 use std::iter;
88 
89 use proc_macro2::TokenStream;
90 use quote::{quote, quote_spanned, ToTokens, TokenStreamExt as _};
91 use syn::ext::IdentExt as _;
92 use syn::parse::{Parse, ParseStream};
93 use syn::{
94     punctuated::Punctuated, spanned::Spanned, AttributeArgs, Block, Expr, ExprCall, FieldPat,
95     FnArg, Ident, Item, ItemFn, Lit, LitInt, LitStr, Meta, MetaList, MetaNameValue, NestedMeta,
96     Pat, PatIdent, PatReference, PatStruct, PatTuple, PatTupleStruct, PatType, Path, Signature,
97     Stmt, Token,
98 };
99 /// Instruments a function to create and enter a `tracing` [span] every time
100 /// the function is called.
101 ///
102 /// The generated span's name will be the name of the function. Any arguments
103 /// to that function will be recorded as fields using [`fmt::Debug`]. To skip
104 /// recording a function's or method's argument, pass the argument's name
105 /// to the `skip` argument on the `#[instrument]` macro. For example,
106 /// `skip` can be used when an argument to an instrumented function does
107 /// not implement [`fmt::Debug`], or to exclude an argument with a verbose
108 /// or costly Debug implementation. Note that:
109 /// - multiple argument names can be passed to `skip`.
110 /// - arguments passed to `skip` do _not_ need to implement `fmt::Debug`.
111 ///
112 /// You can also pass additional fields (key-value pairs with arbitrary data)
113 /// to the generated span. This is achieved using the `fields` argument on the
/// `#[instrument]` macro. The value of each field may be any Rust expression,
/// such as a string, integer or boolean literal. The name of the field must be
/// a single valid Rust identifier, nested (dotted) field names are not supported.
117 ///
/// Note that if a field has the same name as a (non-skipped) argument, the
/// argument is not recorded separately; the field's value is recorded instead.
120 ///
121 /// # Examples
122 /// Instrumenting a function:
123 /// ```
124 /// # use tracing_attributes::instrument;
125 /// #[instrument]
126 /// pub fn my_function(my_arg: usize) {
127 ///     // This event will be recorded inside a span named `my_function` with the
128 ///     // field `my_arg`.
129 ///     tracing::info!("inside my_function!");
130 ///     // ...
131 /// }
132 /// ```
133 /// Setting the level for the generated span:
134 /// ```
135 /// # use tracing_attributes::instrument;
136 /// #[instrument(level = "debug")]
137 /// pub fn my_function() {
138 ///     // ...
139 /// }
140 /// ```
141 /// Overriding the generated span's name:
142 /// ```
143 /// # use tracing_attributes::instrument;
144 /// #[instrument(name = "my_name")]
145 /// pub fn my_function() {
146 ///     // ...
147 /// }
148 /// ```
149 /// Overriding the generated span's target:
150 /// ```
151 /// # use tracing_attributes::instrument;
152 /// #[instrument(target = "my_target")]
153 /// pub fn my_function() {
154 ///     // ...
155 /// }
156 /// ```
157 ///
158 /// To skip recording an argument, pass the argument's name to the `skip`:
159 ///
160 /// ```
161 /// # use tracing_attributes::instrument;
162 /// struct NonDebug;
163 ///
164 /// #[instrument(skip(non_debug))]
165 /// fn my_function(arg: usize, non_debug: NonDebug) {
166 ///     // ...
167 /// }
168 /// ```
169 ///
170 /// To add an additional context to the span, you can pass key-value pairs to `fields`:
171 ///
172 /// ```
173 /// # use tracing_attributes::instrument;
174 /// #[instrument(fields(foo="bar", id=1, show=true))]
175 /// fn my_function(arg: usize) {
176 ///     // ...
177 /// }
178 /// ```
179 ///
180 /// If the function returns a `Result<T, E>` and `E` implements `std::fmt::Display`, you can add
181 /// `err` to emit error events when the function returns `Err`:
182 ///
183 /// ```
184 /// # use tracing_attributes::instrument;
185 /// #[instrument(err)]
186 /// fn my_function(arg: usize) -> Result<(), std::io::Error> {
187 ///     Ok(())
188 /// }
189 /// ```
190 ///
191 /// If `tracing_futures` is specified as a dependency in `Cargo.toml`,
192 /// `async fn`s may also be instrumented:
193 ///
194 /// ```
195 /// # use tracing_attributes::instrument;
196 /// #[instrument]
197 /// pub async fn my_function() -> Result<(), ()> {
198 ///     // ...
199 ///     # Ok(())
200 /// }
201 /// ```
202 ///
203 /// It also works with [async-trait](https://crates.io/crates/async-trait)
204 /// (a crate that allows defining async functions in traits,
205 /// something not currently possible in Rust),
206 /// and hopefully most libraries that exhibit similar behaviors:
207 ///
208 /// ```
209 /// # use tracing::instrument;
210 /// use async_trait::async_trait;
211 ///
212 /// #[async_trait]
213 /// pub trait Foo {
214 ///     async fn foo(&self, arg: usize);
215 /// }
216 ///
217 /// #[derive(Debug)]
218 /// struct FooImpl(usize);
219 ///
220 /// #[async_trait]
221 /// impl Foo for FooImpl {
222 ///     #[instrument(fields(value = self.0, tmp = std::any::type_name::<Self>()))]
223 ///     async fn foo(&self, arg: usize) {}
224 /// }
225 /// ```
226 ///
/// Note that references to the `Self` type inside the `fields` argument are
/// only allowed when the instrumented function is a method, i.e. when the
/// function receives `self` as an argument.
/// For example, this *will not work* because it doesn't receive `self`:
231 /// ```compile_fail
232 /// # use tracing::instrument;
233 /// use async_trait::async_trait;
234 ///
235 /// #[async_trait]
236 /// pub trait Bar {
237 ///     async fn bar();
238 /// }
239 ///
240 /// #[derive(Debug)]
241 /// struct BarImpl(usize);
242 ///
243 /// #[async_trait]
244 /// impl Bar for BarImpl {
245 ///     #[instrument(fields(tmp = std::any::type_name::<Self>()))]
246 ///     async fn bar() {}
247 /// }
248 /// ```
/// Instead, you should manually rewrite any `Self` types as the type for
/// which you implement the trait: `#[instrument(fields(tmp = std::any::type_name::<BarImpl>()))]`.
251 
252 ///
253 /// [span]: https://docs.rs/tracing/latest/tracing/span/index.html
254 /// [`tracing`]: https://github.com/tokio-rs/tracing
255 /// [`fmt::Debug`]: https://doc.rust-lang.org/std/fmt/trait.Debug.html
256 #[proc_macro_attribute]
instrument( args: proc_macro::TokenStream, item: proc_macro::TokenStream, ) -> proc_macro::TokenStream257 pub fn instrument(
258     args: proc_macro::TokenStream,
259     item: proc_macro::TokenStream,
260 ) -> proc_macro::TokenStream {
261     let input: ItemFn = syn::parse_macro_input!(item as ItemFn);
262     let args = syn::parse_macro_input!(args as InstrumentArgs);
263 
264     let instrumented_function_name = input.sig.ident.to_string();
265 
266     // check for async_trait-like patterns in the block and wrap the
267     // internal function with Instrument instead of wrapping the
268     // async_trait generated wrapper
269     if let Some(internal_fun) = get_async_trait_info(&input.block, input.sig.asyncness.is_some()) {
270         // let's rewrite some statements!
271         let mut stmts: Vec<Stmt> = input.block.stmts.to_vec();
272         for stmt in &mut stmts {
273             if let Stmt::Item(Item::Fn(fun)) = stmt {
274                 // instrument the function if we considered it as the one we truly want to trace
275                 if fun.sig.ident == internal_fun.name {
276                     *stmt = syn::parse2(gen_body(
277                         fun,
278                         args,
279                         instrumented_function_name,
280                         Some(internal_fun),
281                     ))
282                     .unwrap();
283                     break;
284                 }
285             }
286         }
287 
288         let sig = &input.sig;
289         let attrs = &input.attrs;
290         quote!(
291             #(#attrs) *
292             #sig {
293                 #(#stmts) *
294             }
295         )
296         .into()
297     } else {
298         gen_body(&input, args, instrumented_function_name, None).into()
299     }
300 }
301 
gen_body( input: &ItemFn, mut args: InstrumentArgs, instrumented_function_name: String, async_trait_fun: Option<AsyncTraitInfo>, ) -> proc_macro2::TokenStream302 fn gen_body(
303     input: &ItemFn,
304     mut args: InstrumentArgs,
305     instrumented_function_name: String,
306     async_trait_fun: Option<AsyncTraitInfo>,
307 ) -> proc_macro2::TokenStream {
308     // these are needed ahead of time, as ItemFn contains the function body _and_
309     // isn't representable inside a quote!/quote_spanned! macro
310     // (Syn's ToTokens isn't implemented for ItemFn)
311     let ItemFn {
312         attrs,
313         vis,
314         block,
315         sig,
316         ..
317     } = input;
318 
319     let Signature {
320         output: return_type,
321         inputs: params,
322         unsafety,
323         asyncness,
324         constness,
325         abi,
326         ident,
327         generics:
328             syn::Generics {
329                 params: gen_params,
330                 where_clause,
331                 ..
332             },
333         ..
334     } = sig;
335 
336     let err = args.err;
337     let warnings = args.warnings();
338 
339     // generate the span's name
340     let span_name = args
341         // did the user override the span's name?
342         .name
343         .as_ref()
344         .map(|name| quote!(#name))
345         .unwrap_or_else(|| quote!(#instrumented_function_name));
346 
347     // generate this inside a closure, so we can return early on errors.
348     let span = (|| {
349         // Pull out the arguments-to-be-skipped first, so we can filter results
350         // below.
351         let param_names: Vec<(Ident, Ident)> = params
352             .clone()
353             .into_iter()
354             .flat_map(|param| match param {
355                 FnArg::Typed(PatType { pat, .. }) => param_names(*pat),
356                 FnArg::Receiver(_) => Box::new(iter::once(Ident::new("self", param.span()))),
357             })
358             // Little dance with new (user-exposed) names and old (internal)
359             // names of identifiers. That way, you can do the following
360             // even though async_trait rewrite "self" as "_self":
361             // ```
362             // #[async_trait]
363             // impl Foo for FooImpl {
364             //     #[instrument(skip(self))]
365             //     async fn foo(&self, v: usize) {}
366             // }
367             // ```
368             .map(|x| {
369                 // if we are inside a function generated by async-trait, we
370                 // should take care to rewrite "_self" as "self" for
371                 // 'user convenience'
372                 if async_trait_fun.is_some() && x == "_self" {
373                     (Ident::new("self", x.span()), x)
374                 } else {
375                     (x.clone(), x)
376                 }
377             })
378             .collect();
379 
380         for skip in &args.skips {
381             if !param_names.iter().map(|(user, _)| user).any(|y| y == skip) {
382                 return quote_spanned! {skip.span()=>
383                     compile_error!("attempting to skip non-existent parameter")
384                 };
385             }
386         }
387 
388         let level = args.level();
389         let target = args.target();
390 
391         // filter out skipped fields
392         let mut quoted_fields: Vec<_> = param_names
393             .into_iter()
394             .filter(|(param, _)| {
395                 if args.skips.contains(param) {
396                     return false;
397                 }
398 
399                 // If any parameters have the same name as a custom field, skip
400                 // and allow them to be formatted by the custom field.
401                 if let Some(ref fields) = args.fields {
402                     fields.0.iter().all(|Field { ref name, .. }| {
403                         let first = name.first();
404                         first != name.last() || !first.iter().any(|name| name == &param)
405                     })
406                 } else {
407                     true
408                 }
409             })
410             .map(|(user_name, real_name)| quote!(#user_name = tracing::field::debug(&#real_name)))
411             .collect();
412 
413         // when async-trait is in use, replace instances of "self" with "_self" inside the fields values
414         if let (Some(ref async_trait_fun), Some(Fields(ref mut fields))) =
415             (async_trait_fun, &mut args.fields)
416         {
417             let mut replacer = SelfReplacer {
418                 ty: async_trait_fun.self_type.clone(),
419             };
420             for e in fields.iter_mut().filter_map(|f| f.value.as_mut()) {
421                 syn::visit_mut::visit_expr_mut(&mut replacer, e);
422             }
423         }
424 
425         let custom_fields = &args.fields;
426 
427         quote!(tracing::span!(
428             target: #target,
429             #level,
430             #span_name,
431             #(#quoted_fields,)*
432             #custom_fields
433 
434         ))
435     })();
436 
437     // Generate the instrumented function body.
438     // If the function is an `async fn`, this will wrap it in an async block,
439     // which is `instrument`ed using `tracing-futures`. Otherwise, this will
440     // enter the span and then perform the rest of the body.
441     // If `err` is in args, instrument any resulting `Err`s.
442     let body = if asyncness.is_some() {
443         if err {
444             quote_spanned! {block.span()=>
445                 let __tracing_attr_span = #span;
446                 tracing_futures::Instrument::instrument(async move {
447                     match async move { #block }.await {
448                         Ok(x) => Ok(x),
449                         Err(e) => {
450                             tracing::error!(error = %e);
451                             Err(e)
452                         }
453                     }
454                 }, __tracing_attr_span).await
455             }
456         } else {
457             quote_spanned!(block.span()=>
458                 let __tracing_attr_span = #span;
459                     tracing_futures::Instrument::instrument(
460                         async move { #block },
461                         __tracing_attr_span
462                     )
463                     .await
464             )
465         }
466     } else if err {
467         quote_spanned!(block.span()=>
468             let __tracing_attr_span = #span;
469             let __tracing_attr_guard = __tracing_attr_span.enter();
470             match { #block } {
471                 Ok(x) => Ok(x),
472                 Err(e) => {
473                     tracing::error!(error = %e);
474                     Err(e)
475                 }
476             }
477         )
478     } else {
479         quote_spanned!(block.span()=>
480             let __tracing_attr_span = #span;
481             let __tracing_attr_guard = __tracing_attr_span.enter();
482             #block
483         )
484     };
485 
486     quote!(
487         #(#attrs) *
488         #vis #constness #unsafety #asyncness #abi fn #ident<#gen_params>(#params) #return_type
489         #where_clause
490         {
491             #warnings
492             #body
493         }
494     )
495 }
496 
/// Arguments parsed from `#[instrument(...)]`.
#[derive(Default, Debug)]
struct InstrumentArgs {
    /// The `level = ...` argument, if given.
    level: Option<Level>,
    /// The span name, from `name = "..."` or from a bare string literal.
    name: Option<LitStr>,
    /// The `target = "..."` argument, if given.
    target: Option<LitStr>,
    /// Parameter names listed in `skip(...)`.
    skips: HashSet<Ident>,
    /// Custom key-value pairs listed in `fields(...)`.
    fields: Option<Fields>,
    /// Set when the bare `err` flag is present.
    err: bool,
    /// Errors describing any unrecognized parse inputs that we skipped.
    parse_warnings: Vec<syn::Error>,
}
508 
509 impl InstrumentArgs {
level(&self) -> impl ToTokens510     fn level(&self) -> impl ToTokens {
511         fn is_level(lit: &LitInt, expected: u64) -> bool {
512             match lit.base10_parse::<u64>() {
513                 Ok(value) => value == expected,
514                 Err(_) => false,
515             }
516         }
517 
518         match &self.level {
519             Some(Level::Str(ref lit)) if lit.value().eq_ignore_ascii_case("trace") => {
520                 quote!(tracing::Level::TRACE)
521             }
522             Some(Level::Str(ref lit)) if lit.value().eq_ignore_ascii_case("debug") => {
523                 quote!(tracing::Level::DEBUG)
524             }
525             Some(Level::Str(ref lit)) if lit.value().eq_ignore_ascii_case("info") => {
526                 quote!(tracing::Level::INFO)
527             }
528             Some(Level::Str(ref lit)) if lit.value().eq_ignore_ascii_case("warn") => {
529                 quote!(tracing::Level::WARN)
530             }
531             Some(Level::Str(ref lit)) if lit.value().eq_ignore_ascii_case("error") => {
532                 quote!(tracing::Level::ERROR)
533             }
534             Some(Level::Int(ref lit)) if is_level(lit, 1) => quote!(tracing::Level::TRACE),
535             Some(Level::Int(ref lit)) if is_level(lit, 2) => quote!(tracing::Level::DEBUG),
536             Some(Level::Int(ref lit)) if is_level(lit, 3) => quote!(tracing::Level::INFO),
537             Some(Level::Int(ref lit)) if is_level(lit, 4) => quote!(tracing::Level::WARN),
538             Some(Level::Int(ref lit)) if is_level(lit, 5) => quote!(tracing::Level::ERROR),
539             Some(Level::Path(ref pat)) => quote!(#pat),
540             Some(lit) => quote! {
541                 compile_error!(
542                     "unknown verbosity level, expected one of \"trace\", \
543                      \"debug\", \"info\", \"warn\", or \"error\", or a number 1-5"
544                 )
545             },
546             None => quote!(tracing::Level::INFO),
547         }
548     }
549 
target(&self) -> impl ToTokens550     fn target(&self) -> impl ToTokens {
551         if let Some(ref target) = self.target {
552             quote!(#target)
553         } else {
554             quote!(module_path!())
555         }
556     }
557 
558     /// Generate "deprecation" warnings for any unrecognized attribute inputs
559     /// that we skipped.
560     ///
561     /// For backwards compatibility, we need to emit compiler warnings rather
562     /// than errors for unrecognized inputs. Generating a fake deprecation is
563     /// the only way to do this on stable Rust right now.
warnings(&self) -> impl ToTokens564     fn warnings(&self) -> impl ToTokens {
565         let warnings = self.parse_warnings.iter().map(|err| {
566             let msg = format!("found unrecognized input, {}", err);
567             let msg = LitStr::new(&msg, err.span());
568             // TODO(eliza): This is a bit of a hack, but it's just about the
569             // only way to emit warnings from a proc macro on stable Rust.
570             // Eventually, when the `proc_macro::Diagnostic` API stabilizes, we
571             // should definitely use that instead.
572             quote_spanned! {err.span()=>
573                 #[warn(deprecated)]
574                 {
575                     #[deprecated(since = "not actually deprecated", note = #msg)]
576                     const TRACING_INSTRUMENT_WARNING: () = ();
577                     let _ = TRACING_INSTRUMENT_WARNING;
578                 }
579             }
580         });
581         quote! {
582             { #(#warnings)* }
583         }
584     }
585 }
586 
587 impl Parse for InstrumentArgs {
parse(input: ParseStream<'_>) -> syn::Result<Self>588     fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
589         let mut args = Self::default();
590         while !input.is_empty() {
591             let lookahead = input.lookahead1();
592             if lookahead.peek(kw::name) {
593                 if args.name.is_some() {
594                     return Err(input.error("expected only a single `name` argument"));
595                 }
596                 let name = input.parse::<StrArg<kw::name>>()?.value;
597                 args.name = Some(name);
598             } else if lookahead.peek(LitStr) {
599                 // XXX: apparently we support names as either named args with an
600                 // sign, _or_ as unnamed string literals. That's weird, but
601                 // changing it is apparently breaking.
602                 if args.name.is_some() {
603                     return Err(input.error("expected only a single `name` argument"));
604                 }
605                 args.name = Some(input.parse()?);
606             } else if lookahead.peek(kw::target) {
607                 if args.target.is_some() {
608                     return Err(input.error("expected only a single `target` argument"));
609                 }
610                 let target = input.parse::<StrArg<kw::target>>()?.value;
611                 args.target = Some(target);
612             } else if lookahead.peek(kw::level) {
613                 if args.level.is_some() {
614                     return Err(input.error("expected only a single `level` argument"));
615                 }
616                 args.level = Some(input.parse()?);
617             } else if lookahead.peek(kw::skip) {
618                 if !args.skips.is_empty() {
619                     return Err(input.error("expected only a single `skip` argument"));
620                 }
621                 let Skips(skips) = input.parse()?;
622                 args.skips = skips;
623             } else if lookahead.peek(kw::fields) {
624                 if args.fields.is_some() {
625                     return Err(input.error("expected only a single `fields` argument"));
626                 }
627                 args.fields = Some(input.parse()?);
628             } else if lookahead.peek(kw::err) {
629                 let _ = input.parse::<kw::err>()?;
630                 args.err = true;
631             } else if lookahead.peek(Token![,]) {
632                 let _ = input.parse::<Token![,]>()?;
633             } else {
634                 // We found a token that we didn't expect!
635                 // We want to emit warnings for these, rather than errors, so
636                 // we'll add it to the list of unrecognized inputs we've seen so
637                 // far and keep going.
638                 args.parse_warnings.push(lookahead.error());
639                 // Parse the unrecognized token tree to advance the parse
640                 // stream, and throw it away so we can keep parsing.
641                 let _ = input.parse::<proc_macro2::TokenTree>();
642             }
643         }
644         Ok(args)
645     }
646 }
647 
648 struct StrArg<T> {
649     value: LitStr,
650     _p: std::marker::PhantomData<T>,
651 }
652 
653 impl<T: Parse> Parse for StrArg<T> {
parse(input: ParseStream<'_>) -> syn::Result<Self>654     fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
655         let _ = input.parse::<T>()?;
656         let _ = input.parse::<Token![=]>()?;
657         let value = input.parse()?;
658         Ok(Self {
659             value,
660             _p: std::marker::PhantomData,
661         })
662     }
663 }
664 
665 struct Skips(HashSet<Ident>);
666 
667 impl Parse for Skips {
parse(input: ParseStream<'_>) -> syn::Result<Self>668     fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
669         let _ = input.parse::<kw::skip>();
670         let content;
671         let _ = syn::parenthesized!(content in input);
672         let names: Punctuated<Ident, Token![,]> = content.parse_terminated(Ident::parse_any)?;
673         let mut skips = HashSet::new();
674         for name in names {
675             if skips.contains(&name) {
676                 return Err(syn::Error::new(
677                     name.span(),
678                     "tried to skip the same field twice",
679                 ));
680             } else {
681                 skips.insert(name);
682             }
683         }
684         Ok(Self(skips))
685     }
686 }
687 
/// The comma-separated entries of a `fields(...)` argument.
#[derive(Debug)]
struct Fields(Punctuated<Field, Token![,]>);

/// A single entry of `fields(...)`, e.g. `foo = 1`, `?bar`, or `baz = %qux`.
#[derive(Debug)]
struct Field {
    /// The field name; the parser accepts dot-separated segments (`a.b`).
    name: Punctuated<Ident, Token![.]>,
    /// The expression after `=`, if one was given.
    value: Option<Expr>,
    /// Which sigil (if any) is emitted in front of the value or name.
    kind: FieldKind,
}

/// The recording sigil for a custom field.
#[derive(Debug, Eq, PartialEq)]
enum FieldKind {
    /// Written as `?`; emitted as `?` in the generated `span!` call.
    Debug,
    /// Written as `%`; emitted as `%` in the generated `span!` call.
    Display,
    /// No sigil; nothing extra is emitted.
    Value,
}
704 
705 impl Parse for Fields {
parse(input: ParseStream<'_>) -> syn::Result<Self>706     fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
707         let _ = input.parse::<kw::fields>();
708         let content;
709         let _ = syn::parenthesized!(content in input);
710         let fields: Punctuated<_, Token![,]> = content.parse_terminated(Field::parse)?;
711         Ok(Self(fields))
712     }
713 }
714 
715 impl ToTokens for Fields {
to_tokens(&self, tokens: &mut TokenStream)716     fn to_tokens(&self, tokens: &mut TokenStream) {
717         self.0.to_tokens(tokens)
718     }
719 }
720 
721 impl Parse for Field {
parse(input: ParseStream<'_>) -> syn::Result<Self>722     fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
723         let mut kind = FieldKind::Value;
724         if input.peek(Token![%]) {
725             input.parse::<Token![%]>()?;
726             kind = FieldKind::Display;
727         } else if input.peek(Token![?]) {
728             input.parse::<Token![?]>()?;
729             kind = FieldKind::Debug;
730         };
731         let name = Punctuated::parse_separated_nonempty_with(input, Ident::parse_any)?;
732         let value = if input.peek(Token![=]) {
733             input.parse::<Token![=]>()?;
734             if input.peek(Token![%]) {
735                 input.parse::<Token![%]>()?;
736                 kind = FieldKind::Display;
737             } else if input.peek(Token![?]) {
738                 input.parse::<Token![?]>()?;
739                 kind = FieldKind::Debug;
740             };
741             Some(input.parse()?)
742         } else {
743             None
744         };
745         Ok(Self { name, kind, value })
746     }
747 }
748 
749 impl ToTokens for Field {
to_tokens(&self, tokens: &mut TokenStream)750     fn to_tokens(&self, tokens: &mut TokenStream) {
751         if let Some(ref value) = self.value {
752             let name = &self.name;
753             let kind = &self.kind;
754             tokens.extend(quote! {
755                 #name = #kind#value
756             })
757         } else if self.kind == FieldKind::Value {
758             // XXX(eliza): I don't like that fields without values produce
759             // empty fields rather than local variable shorthand...but,
760             // we've released a version where field names without values in
761             // `instrument` produce empty field values, so changing it now
762             // is a breaking change. agh.
763             let name = &self.name;
764             tokens.extend(quote!(#name = tracing::field::Empty))
765         } else {
766             self.kind.to_tokens(tokens);
767             self.name.to_tokens(tokens);
768         }
769     }
770 }
771 
772 impl ToTokens for FieldKind {
to_tokens(&self, tokens: &mut TokenStream)773     fn to_tokens(&self, tokens: &mut TokenStream) {
774         match self {
775             FieldKind::Debug => tokens.extend(quote! { ? }),
776             FieldKind::Display => tokens.extend(quote! { % }),
777             _ => {}
778         }
779     }
780 }
781 
782 #[derive(Debug)]
783 enum Level {
784     Str(LitStr),
785     Int(LitInt),
786     Path(Path),
787 }
788 
789 impl Parse for Level {
parse(input: ParseStream<'_>) -> syn::Result<Self>790     fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
791         let _ = input.parse::<kw::level>()?;
792         let _ = input.parse::<Token![=]>()?;
793         let lookahead = input.lookahead1();
794         if lookahead.peek(LitStr) {
795             Ok(Self::Str(input.parse()?))
796         } else if lookahead.peek(LitInt) {
797             Ok(Self::Int(input.parse()?))
798         } else if lookahead.peek(Ident) {
799             Ok(Self::Path(input.parse()?))
800         } else {
801             Err(lookahead.error())
802         }
803     }
804 }
805 
param_names(pat: Pat) -> Box<dyn Iterator<Item = Ident>>806 fn param_names(pat: Pat) -> Box<dyn Iterator<Item = Ident>> {
807     match pat {
808         Pat::Ident(PatIdent { ident, .. }) => Box::new(iter::once(ident)),
809         Pat::Reference(PatReference { pat, .. }) => param_names(*pat),
810         Pat::Struct(PatStruct { fields, .. }) => Box::new(
811             fields
812                 .into_iter()
813                 .flat_map(|FieldPat { pat, .. }| param_names(*pat)),
814         ),
815         Pat::Tuple(PatTuple { elems, .. }) => Box::new(elems.into_iter().flat_map(param_names)),
816         Pat::TupleStruct(PatTupleStruct {
817             pat: PatTuple { elems, .. },
818             ..
819         }) => Box::new(elems.into_iter().flat_map(param_names)),
820 
821         // The above *should* cover all cases of irrefutable patterns,
822         // but we purposefully don't do any funny business here
823         // (such as panicking) because that would obscure rustc's
824         // much more informative error message.
825         _ => Box::new(iter::empty()),
826     }
827 }
828 
// Custom keyword tokens for parsing the attribute's arguments.
// `syn::custom_keyword!` generates a token type for each name, which can then
// be peeked and parsed like a built-in token (e.g. `input.parse::<kw::level>()`
// in `impl Parse for Level`).
mod kw {
    syn::custom_keyword!(fields);
    syn::custom_keyword!(skip);
    syn::custom_keyword!(level);
    syn::custom_keyword!(target);
    syn::custom_keyword!(name);
    syn::custom_keyword!(err);
}
837 
838 // Get the AST of the inner function we need to hook, if it was generated
839 // by async-trait.
840 // When we are given a function annotated by async-trait, that function
841 // is only a placeholder that returns a pinned future containing the
842 // user logic, and it is that pinned future that needs to be instrumented.
843 // Were we to instrument its parent, we would only collect information
844 // regarding the allocation of that future, and not its own span of execution.
845 // So we inspect the block of the function to find if it matches the pattern
846 // `async fn foo<...>(...) {...}; Box::pin(foo<...>(...))` and we return
847 // the name `foo` if that is the case. 'gen_body' will then be able
848 // to use that information to instrument the proper function.
849 // (this follows the approach suggested in
850 // https://github.com/dtolnay/async-trait/issues/45#issuecomment-571245673)
get_async_trait_function(block: &Block, block_is_async: bool) -> Option<&ItemFn>851 fn get_async_trait_function(block: &Block, block_is_async: bool) -> Option<&ItemFn> {
852     // are we in an async context? If yes, this isn't a async_trait-like pattern
853     if block_is_async {
854         return None;
855     }
856 
857     // list of async functions declared inside the block
858     let mut inside_funs = Vec::new();
859     // last expression declared in the block (it determines the return
860     // value of the block, so that if we are working on a function
861     // whose `trait` or `impl` declaration is annotated by async_trait,
862     // this is quite likely the point where the future is pinned)
863     let mut last_expr = None;
864 
865     // obtain the list of direct internal functions and the last
866     // expression of the block
867     for stmt in &block.stmts {
868         if let Stmt::Item(Item::Fn(fun)) = &stmt {
869             // is the function declared as async? If so, this is a good
870             // candidate, let's keep it in hand
871             if fun.sig.asyncness.is_some() {
872                 inside_funs.push(fun);
873             }
874         } else if let Stmt::Expr(e) = &stmt {
875             last_expr = Some(e);
876         }
877     }
878 
879     // let's play with (too much) pattern matching
880     // is the last expression a function call?
881     if let Some(Expr::Call(ExprCall {
882         func: outside_func,
883         args: outside_args,
884         ..
885     })) = last_expr
886     {
887         if let Expr::Path(path) = outside_func.as_ref() {
888             // is it a call to `Box::pin()`?
889             if "Box::pin" == path_to_string(&path.path) {
890                 // does it takes at least an argument? (if it doesn't,
891                 // it's not gonna compile anyway, but that's no reason
892                 // to (try to) perform an out of bounds access)
893                 if outside_args.is_empty() {
894                     return None;
895                 }
896                 // is the argument to Box::pin a function call itself?
897                 if let Expr::Call(ExprCall { func, args, .. }) = &outside_args[0] {
898                     if let Expr::Path(inside_path) = func.as_ref() {
899                         // "stringify" the path of the function called
900                         let func_name = path_to_string(&inside_path.path);
901                         // is this function directly defined insided the current block?
902                         for fun in inside_funs {
903                             if fun.sig.ident == func_name {
904                                 // we must hook this function now
905                                 return Some(fun);
906                             }
907                         }
908                     }
909                 }
910             }
911         }
912     }
913     None
914 }
915 
// The information extracted by `get_async_trait_info` from a function body
// that matches the async-trait expansion pattern.
struct AsyncTraitInfo {
    // Name of the inner `async fn` that should be instrumented.
    name: String,
    // Type of the inner function's `_self` argument (with one level of
    // reference stripped) when one was found and it is a path type; used to
    // rewrite `Self` in user-supplied field expressions. `None` otherwise.
    self_type: Option<syn::TypePath>,
}
920 
921 // Return the informations necessary to process a function annotated with async-trait.
get_async_trait_info(block: &Block, block_is_async: bool) -> Option<AsyncTraitInfo>922 fn get_async_trait_info(block: &Block, block_is_async: bool) -> Option<AsyncTraitInfo> {
923     let fun = get_async_trait_function(block, block_is_async)?;
924 
925     // if "_self" is present as an argument, we store its type to be able to rewrite "Self" (the
926     // parameter type) with the type of "_self"
927     let self_type = fun
928         .sig
929         .inputs
930         .iter()
931         .map(|arg| {
932             if let FnArg::Typed(ty) = arg {
933                 if let Pat::Ident(PatIdent { ident, .. }) = &*ty.pat {
934                     if ident == "_self" {
935                         let mut ty = &*ty.ty;
936                         // extract the inner type if the argument is "&self" or "&mut self"
937                         if let syn::Type::Reference(syn::TypeReference { elem, .. }) = ty {
938                             ty = &*elem;
939                         }
940                         if let syn::Type::Path(tp) = ty {
941                             return Some(tp.clone());
942                         }
943                     }
944                 }
945             }
946 
947             None
948         })
949         .next();
950     let self_type = match self_type {
951         Some(x) => x,
952         None => None,
953     };
954 
955     Some(AsyncTraitInfo {
956         name: fun.sig.ident.to_string(),
957         self_type,
958     })
959 }
960 
961 // Return a path as a String
path_to_string(path: &Path) -> String962 fn path_to_string(path: &Path) -> String {
963     use std::fmt::Write;
964     // some heuristic to prevent too many allocations
965     let mut res = String::with_capacity(path.segments.len() * 5);
966     for i in 0..path.segments.len() {
967         write!(&mut res, "{}", path.segments[i].ident)
968             .expect("writing to a String should never fail");
969         if i < path.segments.len() - 1 {
970             res.push_str("::");
971         }
972     }
973     res
974 }
975 
976 // A visitor struct replacing the "self" and "Self" tokens in user-supplied fields expressions when
977 // the function is generated by async-trait.
978 struct SelfReplacer {
979     ty: Option<syn::TypePath>,
980 }
981 
982 impl syn::visit_mut::VisitMut for SelfReplacer {
visit_ident_mut(&mut self, id: &mut Ident)983     fn visit_ident_mut(&mut self, id: &mut Ident) {
984         if id == "self" {
985             *id = Ident::new("_self", id.span())
986         }
987     }
988 
visit_type_mut(&mut self, ty: &mut syn::Type)989     fn visit_type_mut(&mut self, ty: &mut syn::Type) {
990         if let syn::Type::Path(syn::TypePath { ref mut path, .. }) = ty {
991             if path_to_string(path) == "Self" {
992                 if let Some(ref true_type) = self.ty {
993                     *path = true_type.path.clone();
994                 }
995             }
996         }
997     }
998 }
999