1 use std::collections::hash_map::DefaultHasher;
2 use std::hash::{Hash, Hasher};
3 
4 use proc_macro::{TokenStream, Span};
5 use crate::proc_macro2::TokenStream as TokenStream2;
6 use devise::{syn, Spanned, SpanWrapped, Result, FromMeta, ext::TypeExt};
7 use indexmap::IndexSet;
8 
9 use crate::proc_macro_ext::{Diagnostics, StringLit};
10 use crate::syn_ext::{syn_to_diag, IdentExt};
11 use self::syn::{Attribute, parse::Parser};
12 
13 use crate::http_codegen::{Method, MediaType, RoutePath, DataSegment, Optional};
14 use crate::attribute::segments::{Source, Kind, Segment};
15 use crate::{ROUTE_FN_PREFIX, ROUTE_STRUCT_PREFIX, URI_MACRO_PREFIX, ROCKET_PARAM_PREFIX};
16 
/// The raw, parsed `#[route]` attribute.
///
/// This is the generic form of the route attribute, in which the HTTP method
/// is supplied explicitly as an argument rather than implied by the
/// attribute's name (compare `MethodRouteAttribute`).
#[derive(Debug, FromMeta)]
struct RouteAttribute {
    /// The HTTP method, accepted as an unnamed ("naked") argument, together
    /// with the span it was written at.
    #[meta(naked)]
    method: SpanWrapped<Method>,
    /// The route's URI template; holds the path segments, the optional query
    /// segments, and the origin URI.
    path: RoutePath,
    /// The optional `data` segment naming the parameter bound via `FromData`.
    data: Option<SpanWrapped<DataSegment>>,
    /// The optional `format` media type.
    format: Option<MediaType>,
    /// The optional route rank.
    rank: Option<isize>,
}
27 
/// The raw, parsed `#[method]` (e.g, `get`, `put`, `post`, etc.) attribute.
///
/// The HTTP method is implied by the attribute's name, so it is not a field
/// here; `incomplete_route` converts this into a full `RouteAttribute` by
/// supplying the method itself.
#[derive(Debug, FromMeta)]
struct MethodRouteAttribute {
    /// The route's URI template, accepted as an unnamed ("naked") argument.
    #[meta(naked)]
    path: RoutePath,
    /// The optional `data` segment naming the parameter bound via `FromData`.
    data: Option<SpanWrapped<DataSegment>>,
    /// The optional `format` media type.
    format: Option<MediaType>,
    /// The optional route rank.
    rank: Option<isize>,
}
37 
/// This structure represents the parsed `route` attribute and associated items.
///
/// It is produced by `parse_route` only after all validation has succeeded.
#[derive(Debug)]
struct Route {
    /// The parsed route attribute: method, URI template, and the optional
    /// `data`, `format`, and `rank` arguments.
    attribute: RouteAttribute,
    /// The function that was decorated with the `route` attribute.
    function: syn::ItemFn,
    /// The non-static parameters declared in the route segments, gathered
    /// from the path, the query, and the `data` argument, in that order.
    segments: IndexSet<Segment>,
    /// The parsed inputs to the user's function. The first ident is the ident
    /// as the user wrote it, while the second ident is the identifier that
    /// should be used during code generation, the `rocket_ident` (the user's
    /// ident prefixed with `ROCKET_PARAM_PREFIX`). The type has its lifetimes
    /// stripped.
    inputs: Vec<(syn::Ident, syn::Ident, syn::Type)>,
}
52 
parse_route(attr: RouteAttribute, function: syn::ItemFn) -> Result<Route>53 fn parse_route(attr: RouteAttribute, function: syn::ItemFn) -> Result<Route> {
54     // Gather diagnostics as we proceed.
55     let mut diags = Diagnostics::new();
56 
57     // Emit a warning if a `data` param was supplied for non-payload methods.
58     if let Some(ref data) = attr.data {
59         if !attr.method.0.supports_payload() {
60             let msg = format!("'{}' does not typically support payloads", attr.method.0);
61             data.full_span.warning("`data` used with non-payload-supporting method")
62                 .span_note(attr.method.span, msg)
63                 .emit()
64         }
65     }
66 
67     // Collect all of the dynamic segments in an `IndexSet`, checking for dups.
68     let mut segments: IndexSet<Segment> = IndexSet::new();
69     fn dup_check<I>(set: &mut IndexSet<Segment>, iter: I, diags: &mut Diagnostics)
70         where I: Iterator<Item = Segment>
71     {
72         for segment in iter.filter(|s| s.kind != Kind::Static) {
73             let span = segment.span;
74             if let Some(previous) = set.replace(segment) {
75                 diags.push(span.error(format!("duplicate parameter: `{}`", previous.name))
76                     .span_note(previous.span, "previous parameter with the same name here"))
77             }
78         }
79     }
80 
81     dup_check(&mut segments, attr.path.path.iter().cloned(), &mut diags);
82     attr.path.query.as_ref().map(|q| dup_check(&mut segments, q.iter().cloned(), &mut diags));
83     dup_check(&mut segments, attr.data.clone().map(|s| s.value.0).into_iter(), &mut diags);
84 
85     // Check the validity of function arguments.
86     let mut inputs = vec![];
87     let mut fn_segments: IndexSet<Segment> = IndexSet::new();
88     for input in &function.sig.inputs {
89         let help = "all handler arguments must be of the form: `ident: Type`";
90         let span = input.span();
91         let (ident, ty) = match input {
92             syn::FnArg::Typed(arg) => match *arg.pat {
93                 syn::Pat::Ident(ref pat) => (&pat.ident, &arg.ty),
94                 syn::Pat::Wild(_) => {
95                     diags.push(span.error("handler arguments cannot be ignored").help(help));
96                     continue;
97                 }
98                 _ => {
99                     diags.push(span.error("invalid use of pattern").help(help));
100                     continue;
101                 }
102             }
103             // Other cases shouldn't happen since we parsed an `ItemFn`.
104             _ => {
105                 diags.push(span.error("invalid handler argument").help(help));
106                 continue;
107             }
108         };
109 
110         let rocket_ident = ident.prepend(ROCKET_PARAM_PREFIX);
111         inputs.push((ident.clone(), rocket_ident, ty.with_stripped_lifetimes()));
112         fn_segments.insert(ident.into());
113     }
114 
115     // Check that all of the declared parameters are function inputs.
116     let span = match function.sig.inputs.is_empty() {
117         false => function.sig.inputs.span(),
118         true => function.span()
119     };
120 
121     for missing in segments.difference(&fn_segments) {
122         diags.push(missing.span.error("unused dynamic parameter")
123             .span_note(span, format!("expected argument named `{}` here", missing.name)))
124     }
125 
126     diags.head_err_or(Route { attribute: attr, function, inputs, segments })
127 }
128 
param_expr(seg: &Segment, ident: &syn::Ident, ty: &syn::Type) -> TokenStream2129 fn param_expr(seg: &Segment, ident: &syn::Ident, ty: &syn::Type) -> TokenStream2 {
130     define_vars_and_mods!(req, data, error, log, request, _None, _Some, _Ok, _Err, Outcome);
131     let i = seg.index.expect("dynamic parameters must be indexed");
132     let span = ident.span().unstable().join(ty.span()).unwrap().into();
133     let name = ident.to_string();
134 
135     // All dynamic parameter should be found if this function is being called;
136     // that's the point of statically checking the URI parameters.
137     let internal_error = quote!({
138         #log::error("Internal invariant error: expected dynamic parameter not found.");
139         #log::error("Please report this error to the Rocket issue tracker.");
140         #Outcome::Forward(#data)
141     });
142 
143     // Returned when a dynamic parameter fails to parse.
144     let parse_error = quote!({
145         #log::warn_(&format!("Failed to parse '{}': {:?}", #name, #error));
146         #Outcome::Forward(#data)
147     });
148 
149     let expr = match seg.kind {
150         Kind::Single => quote_spanned! { span =>
151             match #req.raw_segment_str(#i) {
152                 #_Some(__s) => match <#ty as #request::FromParam>::from_param(__s) {
153                     #_Ok(__v) => __v,
154                     #_Err(#error) => return #parse_error,
155                 },
156                 #_None => return #internal_error
157             }
158         },
159         Kind::Multi => quote_spanned! { span =>
160             match #req.raw_segments(#i) {
161                 #_Some(__s) => match <#ty as #request::FromSegments>::from_segments(__s) {
162                     #_Ok(__v) => __v,
163                     #_Err(#error) => return #parse_error,
164                 },
165                 #_None => return #internal_error
166             }
167         },
168         Kind::Static => return quote!()
169     };
170 
171     quote! {
172         #[allow(non_snake_case, unreachable_patterns, unreachable_code)]
173         let #ident: #ty = #expr;
174     }
175 }
176 
data_expr(ident: &syn::Ident, ty: &syn::Type) -> TokenStream2177 fn data_expr(ident: &syn::Ident, ty: &syn::Type) -> TokenStream2 {
178     define_vars_and_mods!(req, data, FromData, Outcome, Transform);
179     let span = ident.span().unstable().join(ty.span()).unwrap().into();
180     quote_spanned! { span =>
181         let __transform = <#ty as #FromData>::transform(#req, #data);
182 
183         #[allow(unreachable_patterns, unreachable_code)]
184         let __outcome = match __transform {
185             #Transform::Owned(#Outcome::Success(__v)) => {
186                 #Transform::Owned(#Outcome::Success(__v))
187             },
188             #Transform::Borrowed(#Outcome::Success(ref __v)) => {
189                 #Transform::Borrowed(#Outcome::Success(::std::borrow::Borrow::borrow(__v)))
190             },
191             #Transform::Borrowed(__o) => #Transform::Borrowed(__o.map(|_| {
192                 unreachable!("Borrowed(Success(..)) case handled in previous block")
193             })),
194             #Transform::Owned(__o) => #Transform::Owned(__o),
195         };
196 
197         #[allow(non_snake_case, unreachable_patterns, unreachable_code)]
198         let #ident: #ty = match <#ty as #FromData>::from_data(#req, __outcome) {
199             #Outcome::Success(__d) => __d,
200             #Outcome::Forward(__d) => return #Outcome::Forward(__d),
201             #Outcome::Failure((__c, _)) => return #Outcome::Failure(__c),
202         };
203     }
204 }
205 
query_exprs(route: &Route) -> Option<TokenStream2>206 fn query_exprs(route: &Route) -> Option<TokenStream2> {
207     define_vars_and_mods!(_None, _Some, _Ok, _Err, _Option);
208     define_vars_and_mods!(data, trail, log, request, req, Outcome, SmallVec, Query);
209     let query_segments = route.attribute.path.query.as_ref()?;
210     let (mut decls, mut matchers, mut builders) = (vec![], vec![], vec![]);
211     for segment in query_segments {
212         let name = &segment.name;
213         let (ident, ty, span) = if segment.kind != Kind::Static {
214             let (ident, ty) = route.inputs.iter()
215                 .find(|(ident, _, _)| ident == &segment.name)
216                 .map(|(_, rocket_ident, ty)| (rocket_ident, ty))
217                 .unwrap();
218 
219             let span = ident.span().unstable().join(ty.span()).unwrap();
220             (Some(ident), Some(ty), span.into())
221         } else {
222             (None, None, segment.span.into())
223         };
224 
225         let decl = match segment.kind {
226             Kind::Single => quote_spanned! { span =>
227                 #[allow(non_snake_case)]
228                 let mut #ident: #_Option<#ty> = #_None;
229             },
230             Kind::Multi => quote_spanned! { span =>
231                 #[allow(non_snake_case)]
232                 let mut #trail = #SmallVec::<[#request::FormItem; 8]>::new();
233             },
234             Kind::Static => quote!()
235         };
236 
237         let matcher = match segment.kind {
238             Kind::Single => quote_spanned! { span =>
239                 (_, #name, __v) => {
240                     #[allow(unreachable_patterns, unreachable_code)]
241                     let __v = match <#ty as #request::FromFormValue>::from_form_value(__v) {
242                         #_Ok(__v) => __v,
243                         #_Err(__e) => {
244                             #log::warn_(&format!("Failed to parse '{}': {:?}", #name, __e));
245                             return #Outcome::Forward(#data);
246                         }
247                     };
248 
249                     #ident = #_Some(__v);
250                 }
251             },
252             Kind::Static => quote! {
253                 (#name, _, _) => continue,
254             },
255             Kind::Multi => quote! {
256                 _ => #trail.push(__i),
257             }
258         };
259 
260         let builder = match segment.kind {
261             Kind::Single => quote_spanned! { span =>
262                 #[allow(non_snake_case)]
263                 let #ident = match #ident.or_else(<#ty as #request::FromFormValue>::default) {
264                     #_Some(__v) => __v,
265                     #_None => {
266                         #log::warn_(&format!("Missing required query parameter '{}'.", #name));
267                         return #Outcome::Forward(#data);
268                     }
269                 };
270             },
271             Kind::Multi => quote_spanned! { span =>
272                 #[allow(non_snake_case)]
273                 let #ident = match <#ty as #request::FromQuery>::from_query(#Query(&#trail)) {
274                     #_Ok(__v) => __v,
275                     #_Err(__e) => {
276                         #log::warn_(&format!("Failed to parse '{}': {:?}", #name, __e));
277                         return #Outcome::Forward(#data);
278                     }
279                 };
280             },
281             Kind::Static => quote!()
282         };
283 
284         decls.push(decl);
285         matchers.push(matcher);
286         builders.push(builder);
287     }
288 
289     matchers.push(quote!(_ => continue));
290     Some(quote! {
291         #(#decls)*
292 
293         if let #_Some(__items) = #req.raw_query_items() {
294             for __i in __items {
295                 match (__i.raw.as_str(), __i.key.as_str(), __i.value) {
296                     #(
297                         #[allow(unreachable_patterns, unreachable_code)]
298                         #matchers
299                     )*
300                 }
301             }
302         }
303 
304         #(
305             #[allow(unreachable_patterns, unreachable_code)]
306             #builders
307         )*
308     })
309 }
310 
request_guard_expr(ident: &syn::Ident, ty: &syn::Type) -> TokenStream2311 fn request_guard_expr(ident: &syn::Ident, ty: &syn::Type) -> TokenStream2 {
312     define_vars_and_mods!(req, data, request, Outcome);
313     let span = ident.span().unstable().join(ty.span()).unwrap().into();
314     quote_spanned! { span =>
315         #[allow(non_snake_case, unreachable_patterns, unreachable_code)]
316         let #ident: #ty = match <#ty as #request::FromRequest>::from_request(#req) {
317             #Outcome::Success(__v) => __v,
318             #Outcome::Forward(_) => return #Outcome::Forward(#data),
319             #Outcome::Failure((__c, _)) => return #Outcome::Failure(__c),
320         };
321     }
322 }
323 
generate_internal_uri_macro(route: &Route) -> TokenStream2324 fn generate_internal_uri_macro(route: &Route) -> TokenStream2 {
325     let dynamic_args = route.segments.iter()
326         .filter(|seg| seg.source == Source::Path || seg.source == Source::Query)
327         .filter(|seg| seg.kind != Kind::Static)
328         .map(|seg| &seg.name)
329         .map(|name| route.inputs.iter().find(|(ident, ..)| ident == name).unwrap())
330         .map(|(ident, _, ty)| quote!(#ident: #ty));
331 
332     let mut hasher = DefaultHasher::new();
333     let route_span = route.function.span();
334     route_span.source_file().path().hash(&mut hasher);
335     let line_column = route_span.start();
336     line_column.line.hash(&mut hasher);
337     line_column.column.hash(&mut hasher);
338 
339     let mut generated_macro_name = route.function.sig.ident.prepend(URI_MACRO_PREFIX);
340     generated_macro_name.set_span(Span::call_site().into());
341     let inner_generated_macro_name = generated_macro_name.append(&hasher.finish().to_string());
342     let route_uri = route.attribute.path.origin.0.to_string();
343 
344     quote! {
345         #[doc(hidden)]
346         #[macro_export]
347         macro_rules! #inner_generated_macro_name {
348             ($($token:tt)*) => {{
349                 extern crate std;
350                 extern crate rocket;
351                 rocket::rocket_internal_uri!(#route_uri, (#(#dynamic_args),*), $($token)*)
352             }};
353         }
354 
355         #[doc(hidden)]
356         pub use #inner_generated_macro_name as #generated_macro_name;
357     }
358 }
359 
generate_respond_expr(route: &Route) -> TokenStream2360 fn generate_respond_expr(route: &Route) -> TokenStream2 {
361     let ret_span = match route.function.sig.output {
362         syn::ReturnType::Default => route.function.sig.ident.span(),
363         syn::ReturnType::Type(_, ref ty) => ty.span().into()
364     };
365 
366     define_vars_and_mods!(req);
367     define_vars_and_mods!(ret_span => handler);
368     let user_handler_fn_name = &route.function.sig.ident;
369     let parameter_names = route.inputs.iter()
370         .map(|(_, rocket_ident, _)| rocket_ident);
371 
372     quote_spanned! { ret_span =>
373         let ___responder = #user_handler_fn_name(#(#parameter_names),*);
374         #handler::Outcome::from(#req, ___responder)
375     }
376 }
377 
codegen_route(route: Route) -> Result<TokenStream>378 fn codegen_route(route: Route) -> Result<TokenStream> {
379     // Generate the declarations for path, data, and request guard parameters.
380     let mut data_stmt = None;
381     let mut req_guard_definitions = vec![];
382     let mut parameter_definitions = vec![];
383     for (ident, rocket_ident, ty) in &route.inputs {
384         let fn_segment: Segment = ident.into();
385         match route.segments.get(&fn_segment) {
386             Some(seg) if seg.source == Source::Path => {
387                 parameter_definitions.push(param_expr(seg, rocket_ident, &ty));
388             }
389             Some(seg) if seg.source == Source::Data => {
390                 // the data statement needs to come last, so record it specially
391                 data_stmt = Some(data_expr(rocket_ident, &ty));
392             }
393             Some(_) => continue, // handle query parameters later
394             None => {
395                 req_guard_definitions.push(request_guard_expr(rocket_ident, &ty));
396             }
397         };
398     }
399 
400     // Generate the declarations for query parameters.
401     if let Some(exprs) = query_exprs(&route) {
402         parameter_definitions.push(exprs);
403     }
404 
405     // Gather everything we need.
406     define_vars_and_mods!(req, data, handler, Request, Data, StaticRouteInfo);
407     let (vis, user_handler_fn) = (&route.function.vis, &route.function);
408     let user_handler_fn_name = &user_handler_fn.sig.ident;
409     let generated_fn_name = user_handler_fn_name.prepend(ROUTE_FN_PREFIX);
410     let generated_struct_name = user_handler_fn_name.prepend(ROUTE_STRUCT_PREFIX);
411     let generated_internal_uri_macro = generate_internal_uri_macro(&route);
412     let generated_respond_expr = generate_respond_expr(&route);
413 
414     let method = route.attribute.method;
415     let path = route.attribute.path.origin.0.to_string();
416     let rank = Optional(route.attribute.rank);
417     let format = Optional(route.attribute.format);
418 
419     Ok(quote! {
420         #user_handler_fn
421 
422         /// Rocket code generated wrapping route function.
423         #[doc(hidden)]
424         #vis fn #generated_fn_name<'_b>(
425             #req: &'_b #Request,
426             #data: #Data
427         ) -> #handler::Outcome<'_b> {
428             #(#req_guard_definitions)*
429             #(#parameter_definitions)*
430             #data_stmt
431 
432             #generated_respond_expr
433         }
434 
435         /// Rocket code generated wrapping URI macro.
436         #generated_internal_uri_macro
437 
438         /// Rocket code generated static route info.
439         #[doc(hidden)]
440         #[allow(non_upper_case_globals)]
441         #vis static #generated_struct_name: #StaticRouteInfo =
442             #StaticRouteInfo {
443                 name: stringify!(#user_handler_fn_name),
444                 method: #method,
445                 path: #path,
446                 handler: #generated_fn_name,
447                 format: #format,
448                 rank: #rank,
449             };
450     }.into())
451 }
452 
complete_route(args: TokenStream2, input: TokenStream) -> Result<TokenStream>453 fn complete_route(args: TokenStream2, input: TokenStream) -> Result<TokenStream> {
454     let function: syn::ItemFn = syn::parse(input).map_err(syn_to_diag)
455         .map_err(|diag| diag.help("`#[route]` can only be used on functions"))?;
456 
457     let full_attr = quote!(#[route(#args)]);
458     let attrs = Attribute::parse_outer.parse2(full_attr).map_err(syn_to_diag)?;
459     let attribute = match RouteAttribute::from_attrs("route", &attrs) {
460         Some(result) => result?,
461         None => return Err(Span::call_site().error("internal error: bad attribute"))
462     };
463 
464     codegen_route(parse_route(attribute, function)?)
465 }
466 
incomplete_route( method: crate::http::Method, args: TokenStream2, input: TokenStream ) -> Result<TokenStream>467 fn incomplete_route(
468     method: crate::http::Method,
469     args: TokenStream2,
470     input: TokenStream
471 ) -> Result<TokenStream> {
472     let method_str = method.to_string().to_lowercase();
473     // FIXME(proc_macro): there should be a way to get this `Span`.
474     let method_span = StringLit::new(format!("#[{}]", method), Span::call_site())
475         .subspan(2..2 + method_str.len());
476 
477     let method_ident = syn::Ident::new(&method_str, method_span.into());
478 
479     let function: syn::ItemFn = syn::parse(input).map_err(syn_to_diag)
480         .map_err(|d| d.help(format!("#[{}] can only be used on functions", method_str)))?;
481 
482     let full_attr = quote!(#[#method_ident(#args)]);
483     let attrs = Attribute::parse_outer.parse2(full_attr).map_err(syn_to_diag)?;
484     let method_attribute = match MethodRouteAttribute::from_attrs(&method_str, &attrs) {
485         Some(result) => result?,
486         None => return Err(Span::call_site().error("internal error: bad attribute"))
487     };
488 
489     let attribute = RouteAttribute {
490         method: SpanWrapped {
491             full_span: method_span, span: method_span, value: Method(method)
492         },
493         path: method_attribute.path,
494         data: method_attribute.data,
495         format: method_attribute.format,
496         rank: method_attribute.rank,
497     };
498 
499     codegen_route(parse_route(attribute, function)?)
500 }
501 
route_attribute<M: Into<Option<crate::http::Method>>>( method: M, args: TokenStream, input: TokenStream ) -> TokenStream502 pub fn route_attribute<M: Into<Option<crate::http::Method>>>(
503     method: M,
504     args: TokenStream,
505     input: TokenStream
506 ) -> TokenStream {
507     let result = match method.into() {
508         Some(method) => incomplete_route(method, args.into(), input),
509         None => complete_route(args.into(), input)
510     };
511 
512     result.unwrap_or_else(|diag| { diag.emit(); TokenStream::new() })
513 }
514