1 use std::collections::hash_map::DefaultHasher;
2 use std::hash::{Hash, Hasher};
3
4 use proc_macro::{TokenStream, Span};
5 use crate::proc_macro2::TokenStream as TokenStream2;
6 use devise::{syn, Spanned, SpanWrapped, Result, FromMeta, ext::TypeExt};
7 use indexmap::IndexSet;
8
9 use crate::proc_macro_ext::{Diagnostics, StringLit};
10 use crate::syn_ext::{syn_to_diag, IdentExt};
11 use self::syn::{Attribute, parse::Parser};
12
13 use crate::http_codegen::{Method, MediaType, RoutePath, DataSegment, Optional};
14 use crate::attribute::segments::{Source, Kind, Segment};
15 use crate::{ROUTE_FN_PREFIX, ROUTE_STRUCT_PREFIX, URI_MACRO_PREFIX, ROCKET_PARAM_PREFIX};
16
/// The raw, parsed `#[route]` attribute.
///
/// This is the form in which the HTTP method is written by the user as part
/// of the attribute's arguments.
#[derive(Debug, FromMeta)]
struct RouteAttribute {
    // The HTTP method of the route, along with the span it was written at.
    #[meta(naked)]
    method: SpanWrapped<Method>,
    // The route's path: its path segments, optional query segments, and the
    // origin URI that is emitted into the generated static route info.
    path: RoutePath,
    // The optional `data` parameter; a warning is emitted by `parse_route`
    // when it is used with a method that does not support payloads.
    data: Option<SpanWrapped<DataSegment>>,
    // The optional `format`, forwarded into the generated route info.
    format: Option<MediaType>,
    // The optional `rank`, forwarded into the generated route info.
    rank: Option<isize>,
}
27
/// The raw, parsed `#[method]` (e.g., `get`, `put`, `post`, etc.) attribute.
///
/// It has the same fields as `RouteAttribute` except for the method, which
/// `incomplete_route` fills in before building a `RouteAttribute`.
#[derive(Debug, FromMeta)]
struct MethodRouteAttribute {
    // The route's path, including any query segments.
    #[meta(naked)]
    path: RoutePath,
    // The optional `data` parameter.
    data: Option<SpanWrapped<DataSegment>>,
    // The optional `format`, forwarded into the generated route info.
    format: Option<MediaType>,
    // The optional `rank`, forwarded into the generated route info.
    rank: Option<isize>,
}
37
/// This structure represents the parsed `route` attribute and associated items.
#[derive(Debug)]
struct Route {
    /// The parsed attribute: method, path, and the optional `data`, `format`,
    /// and `rank` parameters.
    attribute: RouteAttribute,
    /// The function that was decorated with the `route` attribute.
    function: syn::ItemFn,
    /// The non-static parameters declared in the route's path, query, and
    /// `data` segments, in the order they were collected.
    segments: IndexSet<Segment>,
    /// The parsed inputs to the user's function. The first ident is the ident
    /// as the user wrote it, while the second ident is the identifier that
    /// should be used during code generation, the `rocket_ident`. The type is
    /// the argument's declared type with its lifetimes stripped.
    inputs: Vec<(syn::Ident, syn::Ident, syn::Type)>,
}
52
parse_route(attr: RouteAttribute, function: syn::ItemFn) -> Result<Route>53 fn parse_route(attr: RouteAttribute, function: syn::ItemFn) -> Result<Route> {
54 // Gather diagnostics as we proceed.
55 let mut diags = Diagnostics::new();
56
57 // Emit a warning if a `data` param was supplied for non-payload methods.
58 if let Some(ref data) = attr.data {
59 if !attr.method.0.supports_payload() {
60 let msg = format!("'{}' does not typically support payloads", attr.method.0);
61 data.full_span.warning("`data` used with non-payload-supporting method")
62 .span_note(attr.method.span, msg)
63 .emit()
64 }
65 }
66
67 // Collect all of the dynamic segments in an `IndexSet`, checking for dups.
68 let mut segments: IndexSet<Segment> = IndexSet::new();
69 fn dup_check<I>(set: &mut IndexSet<Segment>, iter: I, diags: &mut Diagnostics)
70 where I: Iterator<Item = Segment>
71 {
72 for segment in iter.filter(|s| s.kind != Kind::Static) {
73 let span = segment.span;
74 if let Some(previous) = set.replace(segment) {
75 diags.push(span.error(format!("duplicate parameter: `{}`", previous.name))
76 .span_note(previous.span, "previous parameter with the same name here"))
77 }
78 }
79 }
80
81 dup_check(&mut segments, attr.path.path.iter().cloned(), &mut diags);
82 attr.path.query.as_ref().map(|q| dup_check(&mut segments, q.iter().cloned(), &mut diags));
83 dup_check(&mut segments, attr.data.clone().map(|s| s.value.0).into_iter(), &mut diags);
84
85 // Check the validity of function arguments.
86 let mut inputs = vec![];
87 let mut fn_segments: IndexSet<Segment> = IndexSet::new();
88 for input in &function.sig.inputs {
89 let help = "all handler arguments must be of the form: `ident: Type`";
90 let span = input.span();
91 let (ident, ty) = match input {
92 syn::FnArg::Typed(arg) => match *arg.pat {
93 syn::Pat::Ident(ref pat) => (&pat.ident, &arg.ty),
94 syn::Pat::Wild(_) => {
95 diags.push(span.error("handler arguments cannot be ignored").help(help));
96 continue;
97 }
98 _ => {
99 diags.push(span.error("invalid use of pattern").help(help));
100 continue;
101 }
102 }
103 // Other cases shouldn't happen since we parsed an `ItemFn`.
104 _ => {
105 diags.push(span.error("invalid handler argument").help(help));
106 continue;
107 }
108 };
109
110 let rocket_ident = ident.prepend(ROCKET_PARAM_PREFIX);
111 inputs.push((ident.clone(), rocket_ident, ty.with_stripped_lifetimes()));
112 fn_segments.insert(ident.into());
113 }
114
115 // Check that all of the declared parameters are function inputs.
116 let span = match function.sig.inputs.is_empty() {
117 false => function.sig.inputs.span(),
118 true => function.span()
119 };
120
121 for missing in segments.difference(&fn_segments) {
122 diags.push(missing.span.error("unused dynamic parameter")
123 .span_note(span, format!("expected argument named `{}` here", missing.name)))
124 }
125
126 diags.head_err_or(Route { attribute: attr, function, inputs, segments })
127 }
128
param_expr(seg: &Segment, ident: &syn::Ident, ty: &syn::Type) -> TokenStream2129 fn param_expr(seg: &Segment, ident: &syn::Ident, ty: &syn::Type) -> TokenStream2 {
130 define_vars_and_mods!(req, data, error, log, request, _None, _Some, _Ok, _Err, Outcome);
131 let i = seg.index.expect("dynamic parameters must be indexed");
132 let span = ident.span().unstable().join(ty.span()).unwrap().into();
133 let name = ident.to_string();
134
135 // All dynamic parameter should be found if this function is being called;
136 // that's the point of statically checking the URI parameters.
137 let internal_error = quote!({
138 #log::error("Internal invariant error: expected dynamic parameter not found.");
139 #log::error("Please report this error to the Rocket issue tracker.");
140 #Outcome::Forward(#data)
141 });
142
143 // Returned when a dynamic parameter fails to parse.
144 let parse_error = quote!({
145 #log::warn_(&format!("Failed to parse '{}': {:?}", #name, #error));
146 #Outcome::Forward(#data)
147 });
148
149 let expr = match seg.kind {
150 Kind::Single => quote_spanned! { span =>
151 match #req.raw_segment_str(#i) {
152 #_Some(__s) => match <#ty as #request::FromParam>::from_param(__s) {
153 #_Ok(__v) => __v,
154 #_Err(#error) => return #parse_error,
155 },
156 #_None => return #internal_error
157 }
158 },
159 Kind::Multi => quote_spanned! { span =>
160 match #req.raw_segments(#i) {
161 #_Some(__s) => match <#ty as #request::FromSegments>::from_segments(__s) {
162 #_Ok(__v) => __v,
163 #_Err(#error) => return #parse_error,
164 },
165 #_None => return #internal_error
166 }
167 },
168 Kind::Static => return quote!()
169 };
170
171 quote! {
172 #[allow(non_snake_case, unreachable_patterns, unreachable_code)]
173 let #ident: #ty = #expr;
174 }
175 }
176
data_expr(ident: &syn::Ident, ty: &syn::Type) -> TokenStream2177 fn data_expr(ident: &syn::Ident, ty: &syn::Type) -> TokenStream2 {
178 define_vars_and_mods!(req, data, FromData, Outcome, Transform);
179 let span = ident.span().unstable().join(ty.span()).unwrap().into();
180 quote_spanned! { span =>
181 let __transform = <#ty as #FromData>::transform(#req, #data);
182
183 #[allow(unreachable_patterns, unreachable_code)]
184 let __outcome = match __transform {
185 #Transform::Owned(#Outcome::Success(__v)) => {
186 #Transform::Owned(#Outcome::Success(__v))
187 },
188 #Transform::Borrowed(#Outcome::Success(ref __v)) => {
189 #Transform::Borrowed(#Outcome::Success(::std::borrow::Borrow::borrow(__v)))
190 },
191 #Transform::Borrowed(__o) => #Transform::Borrowed(__o.map(|_| {
192 unreachable!("Borrowed(Success(..)) case handled in previous block")
193 })),
194 #Transform::Owned(__o) => #Transform::Owned(__o),
195 };
196
197 #[allow(non_snake_case, unreachable_patterns, unreachable_code)]
198 let #ident: #ty = match <#ty as #FromData>::from_data(#req, __outcome) {
199 #Outcome::Success(__d) => __d,
200 #Outcome::Forward(__d) => return #Outcome::Forward(__d),
201 #Outcome::Failure((__c, _)) => return #Outcome::Failure(__c),
202 };
203 }
204 }
205
query_exprs(route: &Route) -> Option<TokenStream2>206 fn query_exprs(route: &Route) -> Option<TokenStream2> {
207 define_vars_and_mods!(_None, _Some, _Ok, _Err, _Option);
208 define_vars_and_mods!(data, trail, log, request, req, Outcome, SmallVec, Query);
209 let query_segments = route.attribute.path.query.as_ref()?;
210 let (mut decls, mut matchers, mut builders) = (vec![], vec![], vec![]);
211 for segment in query_segments {
212 let name = &segment.name;
213 let (ident, ty, span) = if segment.kind != Kind::Static {
214 let (ident, ty) = route.inputs.iter()
215 .find(|(ident, _, _)| ident == &segment.name)
216 .map(|(_, rocket_ident, ty)| (rocket_ident, ty))
217 .unwrap();
218
219 let span = ident.span().unstable().join(ty.span()).unwrap();
220 (Some(ident), Some(ty), span.into())
221 } else {
222 (None, None, segment.span.into())
223 };
224
225 let decl = match segment.kind {
226 Kind::Single => quote_spanned! { span =>
227 #[allow(non_snake_case)]
228 let mut #ident: #_Option<#ty> = #_None;
229 },
230 Kind::Multi => quote_spanned! { span =>
231 #[allow(non_snake_case)]
232 let mut #trail = #SmallVec::<[#request::FormItem; 8]>::new();
233 },
234 Kind::Static => quote!()
235 };
236
237 let matcher = match segment.kind {
238 Kind::Single => quote_spanned! { span =>
239 (_, #name, __v) => {
240 #[allow(unreachable_patterns, unreachable_code)]
241 let __v = match <#ty as #request::FromFormValue>::from_form_value(__v) {
242 #_Ok(__v) => __v,
243 #_Err(__e) => {
244 #log::warn_(&format!("Failed to parse '{}': {:?}", #name, __e));
245 return #Outcome::Forward(#data);
246 }
247 };
248
249 #ident = #_Some(__v);
250 }
251 },
252 Kind::Static => quote! {
253 (#name, _, _) => continue,
254 },
255 Kind::Multi => quote! {
256 _ => #trail.push(__i),
257 }
258 };
259
260 let builder = match segment.kind {
261 Kind::Single => quote_spanned! { span =>
262 #[allow(non_snake_case)]
263 let #ident = match #ident.or_else(<#ty as #request::FromFormValue>::default) {
264 #_Some(__v) => __v,
265 #_None => {
266 #log::warn_(&format!("Missing required query parameter '{}'.", #name));
267 return #Outcome::Forward(#data);
268 }
269 };
270 },
271 Kind::Multi => quote_spanned! { span =>
272 #[allow(non_snake_case)]
273 let #ident = match <#ty as #request::FromQuery>::from_query(#Query(&#trail)) {
274 #_Ok(__v) => __v,
275 #_Err(__e) => {
276 #log::warn_(&format!("Failed to parse '{}': {:?}", #name, __e));
277 return #Outcome::Forward(#data);
278 }
279 };
280 },
281 Kind::Static => quote!()
282 };
283
284 decls.push(decl);
285 matchers.push(matcher);
286 builders.push(builder);
287 }
288
289 matchers.push(quote!(_ => continue));
290 Some(quote! {
291 #(#decls)*
292
293 if let #_Some(__items) = #req.raw_query_items() {
294 for __i in __items {
295 match (__i.raw.as_str(), __i.key.as_str(), __i.value) {
296 #(
297 #[allow(unreachable_patterns, unreachable_code)]
298 #matchers
299 )*
300 }
301 }
302 }
303
304 #(
305 #[allow(unreachable_patterns, unreachable_code)]
306 #builders
307 )*
308 })
309 }
310
request_guard_expr(ident: &syn::Ident, ty: &syn::Type) -> TokenStream2311 fn request_guard_expr(ident: &syn::Ident, ty: &syn::Type) -> TokenStream2 {
312 define_vars_and_mods!(req, data, request, Outcome);
313 let span = ident.span().unstable().join(ty.span()).unwrap().into();
314 quote_spanned! { span =>
315 #[allow(non_snake_case, unreachable_patterns, unreachable_code)]
316 let #ident: #ty = match <#ty as #request::FromRequest>::from_request(#req) {
317 #Outcome::Success(__v) => __v,
318 #Outcome::Forward(_) => return #Outcome::Forward(#data),
319 #Outcome::Failure((__c, _)) => return #Outcome::Failure(__c),
320 };
321 }
322 }
323
generate_internal_uri_macro(route: &Route) -> TokenStream2324 fn generate_internal_uri_macro(route: &Route) -> TokenStream2 {
325 let dynamic_args = route.segments.iter()
326 .filter(|seg| seg.source == Source::Path || seg.source == Source::Query)
327 .filter(|seg| seg.kind != Kind::Static)
328 .map(|seg| &seg.name)
329 .map(|name| route.inputs.iter().find(|(ident, ..)| ident == name).unwrap())
330 .map(|(ident, _, ty)| quote!(#ident: #ty));
331
332 let mut hasher = DefaultHasher::new();
333 let route_span = route.function.span();
334 route_span.source_file().path().hash(&mut hasher);
335 let line_column = route_span.start();
336 line_column.line.hash(&mut hasher);
337 line_column.column.hash(&mut hasher);
338
339 let mut generated_macro_name = route.function.sig.ident.prepend(URI_MACRO_PREFIX);
340 generated_macro_name.set_span(Span::call_site().into());
341 let inner_generated_macro_name = generated_macro_name.append(&hasher.finish().to_string());
342 let route_uri = route.attribute.path.origin.0.to_string();
343
344 quote! {
345 #[doc(hidden)]
346 #[macro_export]
347 macro_rules! #inner_generated_macro_name {
348 ($($token:tt)*) => {{
349 extern crate std;
350 extern crate rocket;
351 rocket::rocket_internal_uri!(#route_uri, (#(#dynamic_args),*), $($token)*)
352 }};
353 }
354
355 #[doc(hidden)]
356 pub use #inner_generated_macro_name as #generated_macro_name;
357 }
358 }
359
generate_respond_expr(route: &Route) -> TokenStream2360 fn generate_respond_expr(route: &Route) -> TokenStream2 {
361 let ret_span = match route.function.sig.output {
362 syn::ReturnType::Default => route.function.sig.ident.span(),
363 syn::ReturnType::Type(_, ref ty) => ty.span().into()
364 };
365
366 define_vars_and_mods!(req);
367 define_vars_and_mods!(ret_span => handler);
368 let user_handler_fn_name = &route.function.sig.ident;
369 let parameter_names = route.inputs.iter()
370 .map(|(_, rocket_ident, _)| rocket_ident);
371
372 quote_spanned! { ret_span =>
373 let ___responder = #user_handler_fn_name(#(#parameter_names),*);
374 #handler::Outcome::from(#req, ___responder)
375 }
376 }
377
codegen_route(route: Route) -> Result<TokenStream>378 fn codegen_route(route: Route) -> Result<TokenStream> {
379 // Generate the declarations for path, data, and request guard parameters.
380 let mut data_stmt = None;
381 let mut req_guard_definitions = vec![];
382 let mut parameter_definitions = vec![];
383 for (ident, rocket_ident, ty) in &route.inputs {
384 let fn_segment: Segment = ident.into();
385 match route.segments.get(&fn_segment) {
386 Some(seg) if seg.source == Source::Path => {
387 parameter_definitions.push(param_expr(seg, rocket_ident, &ty));
388 }
389 Some(seg) if seg.source == Source::Data => {
390 // the data statement needs to come last, so record it specially
391 data_stmt = Some(data_expr(rocket_ident, &ty));
392 }
393 Some(_) => continue, // handle query parameters later
394 None => {
395 req_guard_definitions.push(request_guard_expr(rocket_ident, &ty));
396 }
397 };
398 }
399
400 // Generate the declarations for query parameters.
401 if let Some(exprs) = query_exprs(&route) {
402 parameter_definitions.push(exprs);
403 }
404
405 // Gather everything we need.
406 define_vars_and_mods!(req, data, handler, Request, Data, StaticRouteInfo);
407 let (vis, user_handler_fn) = (&route.function.vis, &route.function);
408 let user_handler_fn_name = &user_handler_fn.sig.ident;
409 let generated_fn_name = user_handler_fn_name.prepend(ROUTE_FN_PREFIX);
410 let generated_struct_name = user_handler_fn_name.prepend(ROUTE_STRUCT_PREFIX);
411 let generated_internal_uri_macro = generate_internal_uri_macro(&route);
412 let generated_respond_expr = generate_respond_expr(&route);
413
414 let method = route.attribute.method;
415 let path = route.attribute.path.origin.0.to_string();
416 let rank = Optional(route.attribute.rank);
417 let format = Optional(route.attribute.format);
418
419 Ok(quote! {
420 #user_handler_fn
421
422 /// Rocket code generated wrapping route function.
423 #[doc(hidden)]
424 #vis fn #generated_fn_name<'_b>(
425 #req: &'_b #Request,
426 #data: #Data
427 ) -> #handler::Outcome<'_b> {
428 #(#req_guard_definitions)*
429 #(#parameter_definitions)*
430 #data_stmt
431
432 #generated_respond_expr
433 }
434
435 /// Rocket code generated wrapping URI macro.
436 #generated_internal_uri_macro
437
438 /// Rocket code generated static route info.
439 #[doc(hidden)]
440 #[allow(non_upper_case_globals)]
441 #vis static #generated_struct_name: #StaticRouteInfo =
442 #StaticRouteInfo {
443 name: stringify!(#user_handler_fn_name),
444 method: #method,
445 path: #path,
446 handler: #generated_fn_name,
447 format: #format,
448 rank: #rank,
449 };
450 }.into())
451 }
452
complete_route(args: TokenStream2, input: TokenStream) -> Result<TokenStream>453 fn complete_route(args: TokenStream2, input: TokenStream) -> Result<TokenStream> {
454 let function: syn::ItemFn = syn::parse(input).map_err(syn_to_diag)
455 .map_err(|diag| diag.help("`#[route]` can only be used on functions"))?;
456
457 let full_attr = quote!(#[route(#args)]);
458 let attrs = Attribute::parse_outer.parse2(full_attr).map_err(syn_to_diag)?;
459 let attribute = match RouteAttribute::from_attrs("route", &attrs) {
460 Some(result) => result?,
461 None => return Err(Span::call_site().error("internal error: bad attribute"))
462 };
463
464 codegen_route(parse_route(attribute, function)?)
465 }
466
incomplete_route( method: crate::http::Method, args: TokenStream2, input: TokenStream ) -> Result<TokenStream>467 fn incomplete_route(
468 method: crate::http::Method,
469 args: TokenStream2,
470 input: TokenStream
471 ) -> Result<TokenStream> {
472 let method_str = method.to_string().to_lowercase();
473 // FIXME(proc_macro): there should be a way to get this `Span`.
474 let method_span = StringLit::new(format!("#[{}]", method), Span::call_site())
475 .subspan(2..2 + method_str.len());
476
477 let method_ident = syn::Ident::new(&method_str, method_span.into());
478
479 let function: syn::ItemFn = syn::parse(input).map_err(syn_to_diag)
480 .map_err(|d| d.help(format!("#[{}] can only be used on functions", method_str)))?;
481
482 let full_attr = quote!(#[#method_ident(#args)]);
483 let attrs = Attribute::parse_outer.parse2(full_attr).map_err(syn_to_diag)?;
484 let method_attribute = match MethodRouteAttribute::from_attrs(&method_str, &attrs) {
485 Some(result) => result?,
486 None => return Err(Span::call_site().error("internal error: bad attribute"))
487 };
488
489 let attribute = RouteAttribute {
490 method: SpanWrapped {
491 full_span: method_span, span: method_span, value: Method(method)
492 },
493 path: method_attribute.path,
494 data: method_attribute.data,
495 format: method_attribute.format,
496 rank: method_attribute.rank,
497 };
498
499 codegen_route(parse_route(attribute, function)?)
500 }
501
route_attribute<M: Into<Option<crate::http::Method>>>( method: M, args: TokenStream, input: TokenStream ) -> TokenStream502 pub fn route_attribute<M: Into<Option<crate::http::Method>>>(
503 method: M,
504 args: TokenStream,
505 input: TokenStream
506 ) -> TokenStream {
507 let result = match method.into() {
508 Some(method) => incomplete_route(method, args.into(), input),
509 None => complete_route(args.into(), input)
510 };
511
512 result.unwrap_or_else(|diag| { diag.emit(); TokenStream::new() })
513 }
514