1 #![doc(html_root_url = "https://docs.rs/prost-derive/0.6.1")]
2 // The `quote!` macro requires deep recursion.
3 #![recursion_limit = "4096"]
4 
5 extern crate proc_macro;
6 
7 use anyhow::bail;
8 use quote::quote;
9 
10 use anyhow::Error;
11 use itertools::Itertools;
12 use proc_macro::TokenStream;
13 use proc_macro2::Span;
14 use syn::{
15     punctuated::Punctuated, Data, DataEnum, DataStruct, DeriveInput, Expr, Fields, FieldsNamed,
16     FieldsUnnamed, Ident, Variant,
17 };
18 
19 mod field;
20 use crate::field::Field;
21 
try_message(input: TokenStream) -> Result<TokenStream, Error>22 fn try_message(input: TokenStream) -> Result<TokenStream, Error> {
23     let input: DeriveInput = syn::parse(input)?;
24 
25     let ident = input.ident;
26 
27     let variant_data = match input.data {
28         Data::Struct(variant_data) => variant_data,
29         Data::Enum(..) => bail!("Message can not be derived for an enum"),
30         Data::Union(..) => bail!("Message can not be derived for a union"),
31     };
32 
33     if !input.generics.params.is_empty() || input.generics.where_clause.is_some() {
34         bail!("Message may not be derived for generic type");
35     }
36 
37     let fields = match variant_data {
38         DataStruct {
39             fields: Fields::Named(FieldsNamed { named: fields, .. }),
40             ..
41         }
42         | DataStruct {
43             fields:
44                 Fields::Unnamed(FieldsUnnamed {
45                     unnamed: fields, ..
46                 }),
47             ..
48         } => fields.into_iter().collect(),
49         DataStruct {
50             fields: Fields::Unit,
51             ..
52         } => Vec::new(),
53     };
54 
55     let mut next_tag: u32 = 1;
56     let mut fields = fields
57         .into_iter()
58         .enumerate()
59         .flat_map(|(idx, field)| {
60             let field_ident = field
61                 .ident
62                 .unwrap_or_else(|| Ident::new(&idx.to_string(), Span::call_site()));
63             match Field::new(field.attrs, Some(next_tag)) {
64                 Ok(Some(field)) => {
65                     next_tag = field.tags().iter().max().map(|t| t + 1).unwrap_or(next_tag);
66                     Some(Ok((field_ident, field)))
67                 }
68                 Ok(None) => None,
69                 Err(err) => Some(Err(
70                     err.context(format!("invalid message field {}.{}", ident, field_ident))
71                 )),
72             }
73         })
74         .collect::<Result<Vec<_>, _>>()?;
75 
76     // We want Debug to be in declaration order
77     let unsorted_fields = fields.clone();
78 
79     // Sort the fields by tag number so that fields will be encoded in tag order.
80     // TODO: This encodes oneof fields in the position of their lowest tag,
81     // regardless of the currently occupied variant, is that consequential?
82     // See: https://developers.google.com/protocol-buffers/docs/encoding#order
83     fields.sort_by_key(|&(_, ref field)| field.tags().into_iter().min().unwrap());
84     let fields = fields;
85 
86     let mut tags = fields
87         .iter()
88         .flat_map(|&(_, ref field)| field.tags())
89         .collect::<Vec<_>>();
90     let num_tags = tags.len();
91     tags.sort();
92     tags.dedup();
93     if tags.len() != num_tags {
94         bail!("message {} has fields with duplicate tags", ident);
95     }
96 
97     let encoded_len = fields
98         .iter()
99         .map(|&(ref field_ident, ref field)| field.encoded_len(quote!(self.#field_ident)));
100 
101     let encode = fields
102         .iter()
103         .map(|&(ref field_ident, ref field)| field.encode(quote!(self.#field_ident)));
104 
105     let merge = fields.iter().map(|&(ref field_ident, ref field)| {
106         let merge = field.merge(quote!(value));
107         let tags = field
108             .tags()
109             .into_iter()
110             .map(|tag| quote!(#tag))
111             .intersperse(quote!(|));
112         quote! {
113             #(#tags)* => {
114                 let mut value = &mut self.#field_ident;
115                 #merge.map_err(|mut error| {
116                     error.push(STRUCT_NAME, stringify!(#field_ident));
117                     error
118                 })
119             },
120         }
121     });
122 
123     let struct_name = if fields.is_empty() {
124         quote!()
125     } else {
126         quote!(
127             const STRUCT_NAME: &'static str = stringify!(#ident);
128         )
129     };
130 
131     // TODO
132     let is_struct = true;
133 
134     let clear = fields
135         .iter()
136         .map(|&(ref field_ident, ref field)| field.clear(quote!(self.#field_ident)));
137 
138     let default = fields.iter().map(|&(ref field_ident, ref field)| {
139         let value = field.default();
140         quote!(#field_ident: #value,)
141     });
142 
143     let methods = fields
144         .iter()
145         .flat_map(|&(ref field_ident, ref field)| field.methods(field_ident))
146         .collect::<Vec<_>>();
147     let methods = if methods.is_empty() {
148         quote!()
149     } else {
150         quote! {
151             #[allow(dead_code)]
152             impl #ident {
153                 #(#methods)*
154             }
155         }
156     };
157 
158     let debugs = unsorted_fields.iter().map(|&(ref field_ident, ref field)| {
159         let wrapper = field.debug(quote!(self.#field_ident));
160         let call = if is_struct {
161             quote!(builder.field(stringify!(#field_ident), &wrapper))
162         } else {
163             quote!(builder.field(&wrapper))
164         };
165         quote! {
166              let builder = {
167                  let wrapper = #wrapper;
168                  #call
169              };
170         }
171     });
172     let debug_builder = if is_struct {
173         quote!(f.debug_struct(stringify!(#ident)))
174     } else {
175         quote!(f.debug_tuple(stringify!(#ident)))
176     };
177 
178     let expanded = quote! {
179         impl ::prost::Message for #ident {
180             #[allow(unused_variables)]
181             fn encode_raw<B>(&self, buf: &mut B) where B: ::prost::bytes::BufMut {
182                 #(#encode)*
183             }
184 
185             #[allow(unused_variables)]
186             fn merge_field<B>(
187                 &mut self,
188                 tag: u32,
189                 wire_type: ::prost::encoding::WireType,
190                 buf: &mut B,
191                 ctx: ::prost::encoding::DecodeContext,
192             ) -> ::std::result::Result<(), ::prost::DecodeError>
193             where B: ::prost::bytes::Buf {
194                 #struct_name
195                 match tag {
196                     #(#merge)*
197                     _ => ::prost::encoding::skip_field(wire_type, tag, buf, ctx),
198                 }
199             }
200 
201             #[inline]
202             fn encoded_len(&self) -> usize {
203                 0 #(+ #encoded_len)*
204             }
205 
206             fn clear(&mut self) {
207                 #(#clear;)*
208             }
209         }
210 
211         impl Default for #ident {
212             fn default() -> #ident {
213                 #ident {
214                     #(#default)*
215                 }
216             }
217         }
218 
219         impl ::std::fmt::Debug for #ident {
220             fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
221                 let mut builder = #debug_builder;
222                 #(#debugs;)*
223                 builder.finish()
224             }
225         }
226 
227         #methods
228     };
229 
230     Ok(expanded.into())
231 }
232 
233 #[proc_macro_derive(Message, attributes(prost))]
message(input: TokenStream) -> TokenStream234 pub fn message(input: TokenStream) -> TokenStream {
235     try_message(input).unwrap()
236 }
237 
try_enumeration(input: TokenStream) -> Result<TokenStream, Error>238 fn try_enumeration(input: TokenStream) -> Result<TokenStream, Error> {
239     let input: DeriveInput = syn::parse(input)?;
240     let ident = input.ident;
241 
242     if !input.generics.params.is_empty() || input.generics.where_clause.is_some() {
243         bail!("Message may not be derived for generic type");
244     }
245 
246     let punctuated_variants = match input.data {
247         Data::Enum(DataEnum { variants, .. }) => variants,
248         Data::Struct(_) => bail!("Enumeration can not be derived for a struct"),
249         Data::Union(..) => bail!("Enumeration can not be derived for a union"),
250     };
251 
252     // Map the variants into 'fields'.
253     let mut variants: Vec<(Ident, Expr)> = Vec::new();
254     for Variant {
255         ident,
256         fields,
257         discriminant,
258         ..
259     } in punctuated_variants
260     {
261         match fields {
262             Fields::Unit => (),
263             Fields::Named(_) | Fields::Unnamed(_) => {
264                 bail!("Enumeration variants may not have fields")
265             }
266         }
267 
268         match discriminant {
269             Some((_, expr)) => variants.push((ident, expr)),
270             None => bail!("Enumeration variants must have a disriminant"),
271         }
272     }
273 
274     if variants.is_empty() {
275         panic!("Enumeration must have at least one variant");
276     }
277 
278     let default = variants[0].0.clone();
279 
280     let is_valid = variants
281         .iter()
282         .map(|&(_, ref value)| quote!(#value => true));
283     let from = variants.iter().map(
284         |&(ref variant, ref value)| quote!(#value => ::std::option::Option::Some(#ident::#variant)),
285     );
286 
287     let is_valid_doc = format!("Returns `true` if `value` is a variant of `{}`.", ident);
288     let from_i32_doc = format!(
289         "Converts an `i32` to a `{}`, or `None` if `value` is not a valid variant.",
290         ident
291     );
292 
293     let expanded = quote! {
294         impl #ident {
295             #[doc=#is_valid_doc]
296             pub fn is_valid(value: i32) -> bool {
297                 match value {
298                     #(#is_valid,)*
299                     _ => false,
300                 }
301             }
302 
303             #[doc=#from_i32_doc]
304             pub fn from_i32(value: i32) -> ::std::option::Option<#ident> {
305                 match value {
306                     #(#from,)*
307                     _ => ::std::option::Option::None,
308                 }
309             }
310         }
311 
312         impl ::std::default::Default for #ident {
313             fn default() -> #ident {
314                 #ident::#default
315             }
316         }
317 
318         impl ::std::convert::From<#ident> for i32 {
319             fn from(value: #ident) -> i32 {
320                 value as i32
321             }
322         }
323     };
324 
325     Ok(expanded.into())
326 }
327 
328 #[proc_macro_derive(Enumeration, attributes(prost))]
enumeration(input: TokenStream) -> TokenStream329 pub fn enumeration(input: TokenStream) -> TokenStream {
330     try_enumeration(input).unwrap()
331 }
332 
try_oneof(input: TokenStream) -> Result<TokenStream, Error>333 fn try_oneof(input: TokenStream) -> Result<TokenStream, Error> {
334     let input: DeriveInput = syn::parse(input)?;
335 
336     let ident = input.ident;
337 
338     let variants = match input.data {
339         Data::Enum(DataEnum { variants, .. }) => variants,
340         Data::Struct(..) => bail!("Oneof can not be derived for a struct"),
341         Data::Union(..) => bail!("Oneof can not be derived for a union"),
342     };
343 
344     if !input.generics.params.is_empty() || input.generics.where_clause.is_some() {
345         bail!("Message may not be derived for generic type");
346     }
347 
348     // Map the variants into 'fields'.
349     let mut fields: Vec<(Ident, Field)> = Vec::new();
350     for Variant {
351         attrs,
352         ident: variant_ident,
353         fields: variant_fields,
354         ..
355     } in variants
356     {
357         let variant_fields = match variant_fields {
358             Fields::Unit => Punctuated::new(),
359             Fields::Named(FieldsNamed { named: fields, .. })
360             | Fields::Unnamed(FieldsUnnamed {
361                 unnamed: fields, ..
362             }) => fields,
363         };
364         if variant_fields.len() != 1 {
365             bail!("Oneof enum variants must have a single field");
366         }
367         match Field::new_oneof(attrs)? {
368             Some(field) => fields.push((variant_ident, field)),
369             None => bail!("invalid oneof variant: oneof variants may not be ignored"),
370         }
371     }
372 
373     let mut tags = fields
374         .iter()
375         .flat_map(|&(ref variant_ident, ref field)| -> Result<u32, Error> {
376             if field.tags().len() > 1 {
377                 bail!(
378                     "invalid oneof variant {}::{}: oneof variants may only have a single tag",
379                     ident,
380                     variant_ident
381                 );
382             }
383             Ok(field.tags()[0])
384         })
385         .collect::<Vec<_>>();
386     tags.sort();
387     tags.dedup();
388     if tags.len() != fields.len() {
389         panic!("invalid oneof {}: variants have duplicate tags", ident);
390     }
391 
392     let encode = fields.iter().map(|&(ref variant_ident, ref field)| {
393         let encode = field.encode(quote!(*value));
394         quote!(#ident::#variant_ident(ref value) => { #encode })
395     });
396 
397     let merge = fields.iter().map(|&(ref variant_ident, ref field)| {
398         let tag = field.tags()[0];
399         let merge = field.merge(quote!(value));
400         quote! {
401             #tag => {
402                 match field {
403                     ::std::option::Option::Some(#ident::#variant_ident(ref mut value)) => {
404                         #merge
405                     },
406                     _ => {
407                         let mut owned_value = ::std::default::Default::default();
408                         let value = &mut owned_value;
409                         #merge.map(|_| *field = ::std::option::Option::Some(#ident::#variant_ident(owned_value)))
410                     },
411                 }
412             }
413         }
414     });
415 
416     let encoded_len = fields.iter().map(|&(ref variant_ident, ref field)| {
417         let encoded_len = field.encoded_len(quote!(*value));
418         quote!(#ident::#variant_ident(ref value) => #encoded_len)
419     });
420 
421     let debug = fields.iter().map(|&(ref variant_ident, ref field)| {
422         let wrapper = field.debug(quote!(*value));
423         quote!(#ident::#variant_ident(ref value) => {
424             let wrapper = #wrapper;
425             f.debug_tuple(stringify!(#variant_ident))
426                 .field(&wrapper)
427                 .finish()
428         })
429     });
430 
431     let expanded = quote! {
432         impl #ident {
433             pub fn encode<B>(&self, buf: &mut B) where B: ::prost::bytes::BufMut {
434                 match *self {
435                     #(#encode,)*
436                 }
437             }
438 
439             pub fn merge<B>(
440                 field: &mut ::std::option::Option<#ident>,
441                 tag: u32,
442                 wire_type: ::prost::encoding::WireType,
443                 buf: &mut B,
444                 ctx: ::prost::encoding::DecodeContext,
445             ) -> ::std::result::Result<(), ::prost::DecodeError>
446             where B: ::prost::bytes::Buf {
447                 match tag {
448                     #(#merge,)*
449                     _ => unreachable!(concat!("invalid ", stringify!(#ident), " tag: {}"), tag),
450                 }
451             }
452 
453             #[inline]
454             pub fn encoded_len(&self) -> usize {
455                 match *self {
456                     #(#encoded_len,)*
457                 }
458             }
459         }
460 
461         impl ::std::fmt::Debug for #ident {
462             fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
463                 match *self {
464                     #(#debug,)*
465                 }
466             }
467         }
468     };
469 
470     Ok(expanded.into())
471 }
472 
473 #[proc_macro_derive(Oneof, attributes(prost))]
oneof(input: TokenStream) -> TokenStream474 pub fn oneof(input: TokenStream) -> TokenStream {
475     try_oneof(input).unwrap()
476 }
477