1 #![doc(html_root_url = "https://docs.rs/prost-derive/0.8.0")]
2 // The `quote!` macro requires deep recursion.
3 #![recursion_limit = "4096"]
4 
5 extern crate alloc;
6 extern crate proc_macro;
7 
8 use anyhow::{bail, Error};
9 use itertools::Itertools;
10 use proc_macro::TokenStream;
11 use proc_macro2::Span;
12 use quote::quote;
13 use syn::{
14     punctuated::Punctuated, Data, DataEnum, DataStruct, DeriveInput, Expr, Fields, FieldsNamed,
15     FieldsUnnamed, Ident, Variant,
16 };
17 
18 mod field;
19 use crate::field::Field;
20 
try_message(input: TokenStream) -> Result<TokenStream, Error>21 fn try_message(input: TokenStream) -> Result<TokenStream, Error> {
22     let input: DeriveInput = syn::parse(input)?;
23 
24     let ident = input.ident;
25 
26     let variant_data = match input.data {
27         Data::Struct(variant_data) => variant_data,
28         Data::Enum(..) => bail!("Message can not be derived for an enum"),
29         Data::Union(..) => bail!("Message can not be derived for a union"),
30     };
31 
32     let generics = &input.generics;
33     let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
34 
35     let fields = match variant_data {
36         DataStruct {
37             fields: Fields::Named(FieldsNamed { named: fields, .. }),
38             ..
39         }
40         | DataStruct {
41             fields:
42                 Fields::Unnamed(FieldsUnnamed {
43                     unnamed: fields, ..
44                 }),
45             ..
46         } => fields.into_iter().collect(),
47         DataStruct {
48             fields: Fields::Unit,
49             ..
50         } => Vec::new(),
51     };
52 
53     let mut next_tag: u32 = 1;
54     let mut fields = fields
55         .into_iter()
56         .enumerate()
57         .flat_map(|(idx, field)| {
58             let field_ident = field
59                 .ident
60                 .unwrap_or_else(|| Ident::new(&idx.to_string(), Span::call_site()));
61             match Field::new(field.attrs, Some(next_tag)) {
62                 Ok(Some(field)) => {
63                     next_tag = field.tags().iter().max().map(|t| t + 1).unwrap_or(next_tag);
64                     Some(Ok((field_ident, field)))
65                 }
66                 Ok(None) => None,
67                 Err(err) => Some(Err(
68                     err.context(format!("invalid message field {}.{}", ident, field_ident))
69                 )),
70             }
71         })
72         .collect::<Result<Vec<_>, _>>()?;
73 
74     // We want Debug to be in declaration order
75     let unsorted_fields = fields.clone();
76 
77     // Sort the fields by tag number so that fields will be encoded in tag order.
78     // TODO: This encodes oneof fields in the position of their lowest tag,
79     // regardless of the currently occupied variant, is that consequential?
80     // See: https://developers.google.com/protocol-buffers/docs/encoding#order
81     fields.sort_by_key(|&(_, ref field)| field.tags().into_iter().min().unwrap());
82     let fields = fields;
83 
84     let mut tags = fields
85         .iter()
86         .flat_map(|&(_, ref field)| field.tags())
87         .collect::<Vec<_>>();
88     let num_tags = tags.len();
89     tags.sort_unstable();
90     tags.dedup();
91     if tags.len() != num_tags {
92         bail!("message {} has fields with duplicate tags", ident);
93     }
94 
95     let encoded_len = fields
96         .iter()
97         .map(|&(ref field_ident, ref field)| field.encoded_len(quote!(self.#field_ident)));
98 
99     let encode = fields
100         .iter()
101         .map(|&(ref field_ident, ref field)| field.encode(quote!(self.#field_ident)));
102 
103     let merge = fields.iter().map(|&(ref field_ident, ref field)| {
104         let merge = field.merge(quote!(value));
105         let tags = field.tags().into_iter().map(|tag| quote!(#tag));
106         let tags = Itertools::intersperse(tags, quote!(|));
107 
108         quote! {
109             #(#tags)* => {
110                 let mut value = &mut self.#field_ident;
111                 #merge.map_err(|mut error| {
112                     error.push(STRUCT_NAME, stringify!(#field_ident));
113                     error
114                 })
115             },
116         }
117     });
118 
119     let struct_name = if fields.is_empty() {
120         quote!()
121     } else {
122         quote!(
123             const STRUCT_NAME: &'static str = stringify!(#ident);
124         )
125     };
126 
127     // TODO
128     let is_struct = true;
129 
130     let clear = fields
131         .iter()
132         .map(|&(ref field_ident, ref field)| field.clear(quote!(self.#field_ident)));
133 
134     let default = fields.iter().map(|&(ref field_ident, ref field)| {
135         let value = field.default();
136         quote!(#field_ident: #value,)
137     });
138 
139     let methods = fields
140         .iter()
141         .flat_map(|&(ref field_ident, ref field)| field.methods(field_ident))
142         .collect::<Vec<_>>();
143     let methods = if methods.is_empty() {
144         quote!()
145     } else {
146         quote! {
147             #[allow(dead_code)]
148             impl #impl_generics #ident #ty_generics #where_clause {
149                 #(#methods)*
150             }
151         }
152     };
153 
154     let debugs = unsorted_fields.iter().map(|&(ref field_ident, ref field)| {
155         let wrapper = field.debug(quote!(self.#field_ident));
156         let call = if is_struct {
157             quote!(builder.field(stringify!(#field_ident), &wrapper))
158         } else {
159             quote!(builder.field(&wrapper))
160         };
161         quote! {
162              let builder = {
163                  let wrapper = #wrapper;
164                  #call
165              };
166         }
167     });
168     let debug_builder = if is_struct {
169         quote!(f.debug_struct(stringify!(#ident)))
170     } else {
171         quote!(f.debug_tuple(stringify!(#ident)))
172     };
173 
174     let expanded = quote! {
175         impl #impl_generics ::prost::Message for #ident #ty_generics #where_clause {
176             #[allow(unused_variables)]
177             fn encode_raw<B>(&self, buf: &mut B) where B: ::prost::bytes::BufMut {
178                 #(#encode)*
179             }
180 
181             #[allow(unused_variables)]
182             fn merge_field<B>(
183                 &mut self,
184                 tag: u32,
185                 wire_type: ::prost::encoding::WireType,
186                 buf: &mut B,
187                 ctx: ::prost::encoding::DecodeContext,
188             ) -> ::core::result::Result<(), ::prost::DecodeError>
189             where B: ::prost::bytes::Buf {
190                 #struct_name
191                 match tag {
192                     #(#merge)*
193                     _ => ::prost::encoding::skip_field(wire_type, tag, buf, ctx),
194                 }
195             }
196 
197             #[inline]
198             fn encoded_len(&self) -> usize {
199                 0 #(+ #encoded_len)*
200             }
201 
202             fn clear(&mut self) {
203                 #(#clear;)*
204             }
205         }
206 
207         impl #impl_generics Default for #ident #ty_generics #where_clause {
208             fn default() -> Self {
209                 #ident {
210                     #(#default)*
211                 }
212             }
213         }
214 
215         impl #impl_generics ::core::fmt::Debug for #ident #ty_generics #where_clause {
216             fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
217                 let mut builder = #debug_builder;
218                 #(#debugs;)*
219                 builder.finish()
220             }
221         }
222 
223         #methods
224     };
225 
226     Ok(expanded.into())
227 }
228 
229 #[proc_macro_derive(Message, attributes(prost))]
message(input: TokenStream) -> TokenStream230 pub fn message(input: TokenStream) -> TokenStream {
231     try_message(input).unwrap()
232 }
233 
try_enumeration(input: TokenStream) -> Result<TokenStream, Error>234 fn try_enumeration(input: TokenStream) -> Result<TokenStream, Error> {
235     let input: DeriveInput = syn::parse(input)?;
236     let ident = input.ident;
237 
238     let generics = &input.generics;
239     let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
240 
241     let punctuated_variants = match input.data {
242         Data::Enum(DataEnum { variants, .. }) => variants,
243         Data::Struct(_) => bail!("Enumeration can not be derived for a struct"),
244         Data::Union(..) => bail!("Enumeration can not be derived for a union"),
245     };
246 
247     // Map the variants into 'fields'.
248     let mut variants: Vec<(Ident, Expr)> = Vec::new();
249     for Variant {
250         ident,
251         fields,
252         discriminant,
253         ..
254     } in punctuated_variants
255     {
256         match fields {
257             Fields::Unit => (),
258             Fields::Named(_) | Fields::Unnamed(_) => {
259                 bail!("Enumeration variants may not have fields")
260             }
261         }
262 
263         match discriminant {
264             Some((_, expr)) => variants.push((ident, expr)),
265             None => bail!("Enumeration variants must have a disriminant"),
266         }
267     }
268 
269     if variants.is_empty() {
270         panic!("Enumeration must have at least one variant");
271     }
272 
273     let default = variants[0].0.clone();
274 
275     let is_valid = variants
276         .iter()
277         .map(|&(_, ref value)| quote!(#value => true));
278     let from = variants.iter().map(
279         |&(ref variant, ref value)| quote!(#value => ::core::option::Option::Some(#ident::#variant)),
280     );
281 
282     let is_valid_doc = format!("Returns `true` if `value` is a variant of `{}`.", ident);
283     let from_i32_doc = format!(
284         "Converts an `i32` to a `{}`, or `None` if `value` is not a valid variant.",
285         ident
286     );
287 
288     let expanded = quote! {
289         impl #impl_generics #ident #ty_generics #where_clause {
290             #[doc=#is_valid_doc]
291             pub fn is_valid(value: i32) -> bool {
292                 match value {
293                     #(#is_valid,)*
294                     _ => false,
295                 }
296             }
297 
298             #[doc=#from_i32_doc]
299             pub fn from_i32(value: i32) -> ::core::option::Option<#ident> {
300                 match value {
301                     #(#from,)*
302                     _ => ::core::option::Option::None,
303                 }
304             }
305         }
306 
307         impl #impl_generics ::core::default::Default for #ident #ty_generics #where_clause {
308             fn default() -> #ident {
309                 #ident::#default
310             }
311         }
312 
313         impl #impl_generics ::core::convert::From::<#ident> for i32 #ty_generics #where_clause {
314             fn from(value: #ident) -> i32 {
315                 value as i32
316             }
317         }
318     };
319 
320     Ok(expanded.into())
321 }
322 
323 #[proc_macro_derive(Enumeration, attributes(prost))]
enumeration(input: TokenStream) -> TokenStream324 pub fn enumeration(input: TokenStream) -> TokenStream {
325     try_enumeration(input).unwrap()
326 }
327 
try_oneof(input: TokenStream) -> Result<TokenStream, Error>328 fn try_oneof(input: TokenStream) -> Result<TokenStream, Error> {
329     let input: DeriveInput = syn::parse(input)?;
330 
331     let ident = input.ident;
332 
333     let variants = match input.data {
334         Data::Enum(DataEnum { variants, .. }) => variants,
335         Data::Struct(..) => bail!("Oneof can not be derived for a struct"),
336         Data::Union(..) => bail!("Oneof can not be derived for a union"),
337     };
338 
339     let generics = &input.generics;
340     let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
341 
342     // Map the variants into 'fields'.
343     let mut fields: Vec<(Ident, Field)> = Vec::new();
344     for Variant {
345         attrs,
346         ident: variant_ident,
347         fields: variant_fields,
348         ..
349     } in variants
350     {
351         let variant_fields = match variant_fields {
352             Fields::Unit => Punctuated::new(),
353             Fields::Named(FieldsNamed { named: fields, .. })
354             | Fields::Unnamed(FieldsUnnamed {
355                 unnamed: fields, ..
356             }) => fields,
357         };
358         if variant_fields.len() != 1 {
359             bail!("Oneof enum variants must have a single field");
360         }
361         match Field::new_oneof(attrs)? {
362             Some(field) => fields.push((variant_ident, field)),
363             None => bail!("invalid oneof variant: oneof variants may not be ignored"),
364         }
365     }
366 
367     let mut tags = fields
368         .iter()
369         .flat_map(|&(ref variant_ident, ref field)| -> Result<u32, Error> {
370             if field.tags().len() > 1 {
371                 bail!(
372                     "invalid oneof variant {}::{}: oneof variants may only have a single tag",
373                     ident,
374                     variant_ident
375                 );
376             }
377             Ok(field.tags()[0])
378         })
379         .collect::<Vec<_>>();
380     tags.sort_unstable();
381     tags.dedup();
382     if tags.len() != fields.len() {
383         panic!("invalid oneof {}: variants have duplicate tags", ident);
384     }
385 
386     let encode = fields.iter().map(|&(ref variant_ident, ref field)| {
387         let encode = field.encode(quote!(*value));
388         quote!(#ident::#variant_ident(ref value) => { #encode })
389     });
390 
391     let merge = fields.iter().map(|&(ref variant_ident, ref field)| {
392         let tag = field.tags()[0];
393         let merge = field.merge(quote!(value));
394         quote! {
395             #tag => {
396                 match field {
397                     ::core::option::Option::Some(#ident::#variant_ident(ref mut value)) => {
398                         #merge
399                     },
400                     _ => {
401                         let mut owned_value = ::core::default::Default::default();
402                         let value = &mut owned_value;
403                         #merge.map(|_| *field = ::core::option::Option::Some(#ident::#variant_ident(owned_value)))
404                     },
405                 }
406             }
407         }
408     });
409 
410     let encoded_len = fields.iter().map(|&(ref variant_ident, ref field)| {
411         let encoded_len = field.encoded_len(quote!(*value));
412         quote!(#ident::#variant_ident(ref value) => #encoded_len)
413     });
414 
415     let debug = fields.iter().map(|&(ref variant_ident, ref field)| {
416         let wrapper = field.debug(quote!(*value));
417         quote!(#ident::#variant_ident(ref value) => {
418             let wrapper = #wrapper;
419             f.debug_tuple(stringify!(#variant_ident))
420                 .field(&wrapper)
421                 .finish()
422         })
423     });
424 
425     let expanded = quote! {
426         impl #impl_generics #ident #ty_generics #where_clause {
427             pub fn encode<B>(&self, buf: &mut B) where B: ::prost::bytes::BufMut {
428                 match *self {
429                     #(#encode,)*
430                 }
431             }
432 
433             pub fn merge<B>(
434                 field: &mut ::core::option::Option<#ident #ty_generics>,
435                 tag: u32,
436                 wire_type: ::prost::encoding::WireType,
437                 buf: &mut B,
438                 ctx: ::prost::encoding::DecodeContext,
439             ) -> ::core::result::Result<(), ::prost::DecodeError>
440             where B: ::prost::bytes::Buf {
441                 match tag {
442                     #(#merge,)*
443                     _ => unreachable!(concat!("invalid ", stringify!(#ident), " tag: {}"), tag),
444                 }
445             }
446 
447             #[inline]
448             pub fn encoded_len(&self) -> usize {
449                 match *self {
450                     #(#encoded_len,)*
451                 }
452             }
453         }
454 
455         impl #impl_generics ::core::fmt::Debug for #ident #ty_generics #where_clause {
456             fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
457                 match *self {
458                     #(#debug,)*
459                 }
460             }
461         }
462     };
463 
464     Ok(expanded.into())
465 }
466 
467 #[proc_macro_derive(Oneof, attributes(prost))]
oneof(input: TokenStream) -> TokenStream468 pub fn oneof(input: TokenStream) -> TokenStream {
469     try_oneof(input).unwrap()
470 }
471