1 #![doc(html_root_url = "https://docs.rs/prost-derive/0.6.1")]
2 // The `quote!` macro requires deep recursion.
3 #![recursion_limit = "4096"]
4
5 extern crate proc_macro;
6
7 use anyhow::bail;
8 use quote::quote;
9
10 use anyhow::Error;
11 use itertools::Itertools;
12 use proc_macro::TokenStream;
13 use proc_macro2::Span;
14 use syn::{
15 punctuated::Punctuated, Data, DataEnum, DataStruct, DeriveInput, Expr, Fields, FieldsNamed,
16 FieldsUnnamed, Ident, Variant,
17 };
18
19 mod field;
20 use crate::field::Field;
21
try_message(input: TokenStream) -> Result<TokenStream, Error>22 fn try_message(input: TokenStream) -> Result<TokenStream, Error> {
23 let input: DeriveInput = syn::parse(input)?;
24
25 let ident = input.ident;
26
27 let variant_data = match input.data {
28 Data::Struct(variant_data) => variant_data,
29 Data::Enum(..) => bail!("Message can not be derived for an enum"),
30 Data::Union(..) => bail!("Message can not be derived for a union"),
31 };
32
33 if !input.generics.params.is_empty() || input.generics.where_clause.is_some() {
34 bail!("Message may not be derived for generic type");
35 }
36
37 let fields = match variant_data {
38 DataStruct {
39 fields: Fields::Named(FieldsNamed { named: fields, .. }),
40 ..
41 }
42 | DataStruct {
43 fields:
44 Fields::Unnamed(FieldsUnnamed {
45 unnamed: fields, ..
46 }),
47 ..
48 } => fields.into_iter().collect(),
49 DataStruct {
50 fields: Fields::Unit,
51 ..
52 } => Vec::new(),
53 };
54
55 let mut next_tag: u32 = 1;
56 let mut fields = fields
57 .into_iter()
58 .enumerate()
59 .flat_map(|(idx, field)| {
60 let field_ident = field
61 .ident
62 .unwrap_or_else(|| Ident::new(&idx.to_string(), Span::call_site()));
63 match Field::new(field.attrs, Some(next_tag)) {
64 Ok(Some(field)) => {
65 next_tag = field.tags().iter().max().map(|t| t + 1).unwrap_or(next_tag);
66 Some(Ok((field_ident, field)))
67 }
68 Ok(None) => None,
69 Err(err) => Some(Err(
70 err.context(format!("invalid message field {}.{}", ident, field_ident))
71 )),
72 }
73 })
74 .collect::<Result<Vec<_>, _>>()?;
75
76 // We want Debug to be in declaration order
77 let unsorted_fields = fields.clone();
78
79 // Sort the fields by tag number so that fields will be encoded in tag order.
80 // TODO: This encodes oneof fields in the position of their lowest tag,
81 // regardless of the currently occupied variant, is that consequential?
82 // See: https://developers.google.com/protocol-buffers/docs/encoding#order
83 fields.sort_by_key(|&(_, ref field)| field.tags().into_iter().min().unwrap());
84 let fields = fields;
85
86 let mut tags = fields
87 .iter()
88 .flat_map(|&(_, ref field)| field.tags())
89 .collect::<Vec<_>>();
90 let num_tags = tags.len();
91 tags.sort();
92 tags.dedup();
93 if tags.len() != num_tags {
94 bail!("message {} has fields with duplicate tags", ident);
95 }
96
97 let encoded_len = fields
98 .iter()
99 .map(|&(ref field_ident, ref field)| field.encoded_len(quote!(self.#field_ident)));
100
101 let encode = fields
102 .iter()
103 .map(|&(ref field_ident, ref field)| field.encode(quote!(self.#field_ident)));
104
105 let merge = fields.iter().map(|&(ref field_ident, ref field)| {
106 let merge = field.merge(quote!(value));
107 let tags = field
108 .tags()
109 .into_iter()
110 .map(|tag| quote!(#tag))
111 .intersperse(quote!(|));
112 quote! {
113 #(#tags)* => {
114 let mut value = &mut self.#field_ident;
115 #merge.map_err(|mut error| {
116 error.push(STRUCT_NAME, stringify!(#field_ident));
117 error
118 })
119 },
120 }
121 });
122
123 let struct_name = if fields.is_empty() {
124 quote!()
125 } else {
126 quote!(
127 const STRUCT_NAME: &'static str = stringify!(#ident);
128 )
129 };
130
131 // TODO
132 let is_struct = true;
133
134 let clear = fields
135 .iter()
136 .map(|&(ref field_ident, ref field)| field.clear(quote!(self.#field_ident)));
137
138 let default = fields.iter().map(|&(ref field_ident, ref field)| {
139 let value = field.default();
140 quote!(#field_ident: #value,)
141 });
142
143 let methods = fields
144 .iter()
145 .flat_map(|&(ref field_ident, ref field)| field.methods(field_ident))
146 .collect::<Vec<_>>();
147 let methods = if methods.is_empty() {
148 quote!()
149 } else {
150 quote! {
151 #[allow(dead_code)]
152 impl #ident {
153 #(#methods)*
154 }
155 }
156 };
157
158 let debugs = unsorted_fields.iter().map(|&(ref field_ident, ref field)| {
159 let wrapper = field.debug(quote!(self.#field_ident));
160 let call = if is_struct {
161 quote!(builder.field(stringify!(#field_ident), &wrapper))
162 } else {
163 quote!(builder.field(&wrapper))
164 };
165 quote! {
166 let builder = {
167 let wrapper = #wrapper;
168 #call
169 };
170 }
171 });
172 let debug_builder = if is_struct {
173 quote!(f.debug_struct(stringify!(#ident)))
174 } else {
175 quote!(f.debug_tuple(stringify!(#ident)))
176 };
177
178 let expanded = quote! {
179 impl ::prost::Message for #ident {
180 #[allow(unused_variables)]
181 fn encode_raw<B>(&self, buf: &mut B) where B: ::prost::bytes::BufMut {
182 #(#encode)*
183 }
184
185 #[allow(unused_variables)]
186 fn merge_field<B>(
187 &mut self,
188 tag: u32,
189 wire_type: ::prost::encoding::WireType,
190 buf: &mut B,
191 ctx: ::prost::encoding::DecodeContext,
192 ) -> ::std::result::Result<(), ::prost::DecodeError>
193 where B: ::prost::bytes::Buf {
194 #struct_name
195 match tag {
196 #(#merge)*
197 _ => ::prost::encoding::skip_field(wire_type, tag, buf, ctx),
198 }
199 }
200
201 #[inline]
202 fn encoded_len(&self) -> usize {
203 0 #(+ #encoded_len)*
204 }
205
206 fn clear(&mut self) {
207 #(#clear;)*
208 }
209 }
210
211 impl Default for #ident {
212 fn default() -> #ident {
213 #ident {
214 #(#default)*
215 }
216 }
217 }
218
219 impl ::std::fmt::Debug for #ident {
220 fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
221 let mut builder = #debug_builder;
222 #(#debugs;)*
223 builder.finish()
224 }
225 }
226
227 #methods
228 };
229
230 Ok(expanded.into())
231 }
232
233 #[proc_macro_derive(Message, attributes(prost))]
message(input: TokenStream) -> TokenStream234 pub fn message(input: TokenStream) -> TokenStream {
235 try_message(input).unwrap()
236 }
237
try_enumeration(input: TokenStream) -> Result<TokenStream, Error>238 fn try_enumeration(input: TokenStream) -> Result<TokenStream, Error> {
239 let input: DeriveInput = syn::parse(input)?;
240 let ident = input.ident;
241
242 if !input.generics.params.is_empty() || input.generics.where_clause.is_some() {
243 bail!("Message may not be derived for generic type");
244 }
245
246 let punctuated_variants = match input.data {
247 Data::Enum(DataEnum { variants, .. }) => variants,
248 Data::Struct(_) => bail!("Enumeration can not be derived for a struct"),
249 Data::Union(..) => bail!("Enumeration can not be derived for a union"),
250 };
251
252 // Map the variants into 'fields'.
253 let mut variants: Vec<(Ident, Expr)> = Vec::new();
254 for Variant {
255 ident,
256 fields,
257 discriminant,
258 ..
259 } in punctuated_variants
260 {
261 match fields {
262 Fields::Unit => (),
263 Fields::Named(_) | Fields::Unnamed(_) => {
264 bail!("Enumeration variants may not have fields")
265 }
266 }
267
268 match discriminant {
269 Some((_, expr)) => variants.push((ident, expr)),
270 None => bail!("Enumeration variants must have a disriminant"),
271 }
272 }
273
274 if variants.is_empty() {
275 panic!("Enumeration must have at least one variant");
276 }
277
278 let default = variants[0].0.clone();
279
280 let is_valid = variants
281 .iter()
282 .map(|&(_, ref value)| quote!(#value => true));
283 let from = variants.iter().map(
284 |&(ref variant, ref value)| quote!(#value => ::std::option::Option::Some(#ident::#variant)),
285 );
286
287 let is_valid_doc = format!("Returns `true` if `value` is a variant of `{}`.", ident);
288 let from_i32_doc = format!(
289 "Converts an `i32` to a `{}`, or `None` if `value` is not a valid variant.",
290 ident
291 );
292
293 let expanded = quote! {
294 impl #ident {
295 #[doc=#is_valid_doc]
296 pub fn is_valid(value: i32) -> bool {
297 match value {
298 #(#is_valid,)*
299 _ => false,
300 }
301 }
302
303 #[doc=#from_i32_doc]
304 pub fn from_i32(value: i32) -> ::std::option::Option<#ident> {
305 match value {
306 #(#from,)*
307 _ => ::std::option::Option::None,
308 }
309 }
310 }
311
312 impl ::std::default::Default for #ident {
313 fn default() -> #ident {
314 #ident::#default
315 }
316 }
317
318 impl ::std::convert::From<#ident> for i32 {
319 fn from(value: #ident) -> i32 {
320 value as i32
321 }
322 }
323 };
324
325 Ok(expanded.into())
326 }
327
328 #[proc_macro_derive(Enumeration, attributes(prost))]
enumeration(input: TokenStream) -> TokenStream329 pub fn enumeration(input: TokenStream) -> TokenStream {
330 try_enumeration(input).unwrap()
331 }
332
try_oneof(input: TokenStream) -> Result<TokenStream, Error>333 fn try_oneof(input: TokenStream) -> Result<TokenStream, Error> {
334 let input: DeriveInput = syn::parse(input)?;
335
336 let ident = input.ident;
337
338 let variants = match input.data {
339 Data::Enum(DataEnum { variants, .. }) => variants,
340 Data::Struct(..) => bail!("Oneof can not be derived for a struct"),
341 Data::Union(..) => bail!("Oneof can not be derived for a union"),
342 };
343
344 if !input.generics.params.is_empty() || input.generics.where_clause.is_some() {
345 bail!("Message may not be derived for generic type");
346 }
347
348 // Map the variants into 'fields'.
349 let mut fields: Vec<(Ident, Field)> = Vec::new();
350 for Variant {
351 attrs,
352 ident: variant_ident,
353 fields: variant_fields,
354 ..
355 } in variants
356 {
357 let variant_fields = match variant_fields {
358 Fields::Unit => Punctuated::new(),
359 Fields::Named(FieldsNamed { named: fields, .. })
360 | Fields::Unnamed(FieldsUnnamed {
361 unnamed: fields, ..
362 }) => fields,
363 };
364 if variant_fields.len() != 1 {
365 bail!("Oneof enum variants must have a single field");
366 }
367 match Field::new_oneof(attrs)? {
368 Some(field) => fields.push((variant_ident, field)),
369 None => bail!("invalid oneof variant: oneof variants may not be ignored"),
370 }
371 }
372
373 let mut tags = fields
374 .iter()
375 .flat_map(|&(ref variant_ident, ref field)| -> Result<u32, Error> {
376 if field.tags().len() > 1 {
377 bail!(
378 "invalid oneof variant {}::{}: oneof variants may only have a single tag",
379 ident,
380 variant_ident
381 );
382 }
383 Ok(field.tags()[0])
384 })
385 .collect::<Vec<_>>();
386 tags.sort();
387 tags.dedup();
388 if tags.len() != fields.len() {
389 panic!("invalid oneof {}: variants have duplicate tags", ident);
390 }
391
392 let encode = fields.iter().map(|&(ref variant_ident, ref field)| {
393 let encode = field.encode(quote!(*value));
394 quote!(#ident::#variant_ident(ref value) => { #encode })
395 });
396
397 let merge = fields.iter().map(|&(ref variant_ident, ref field)| {
398 let tag = field.tags()[0];
399 let merge = field.merge(quote!(value));
400 quote! {
401 #tag => {
402 match field {
403 ::std::option::Option::Some(#ident::#variant_ident(ref mut value)) => {
404 #merge
405 },
406 _ => {
407 let mut owned_value = ::std::default::Default::default();
408 let value = &mut owned_value;
409 #merge.map(|_| *field = ::std::option::Option::Some(#ident::#variant_ident(owned_value)))
410 },
411 }
412 }
413 }
414 });
415
416 let encoded_len = fields.iter().map(|&(ref variant_ident, ref field)| {
417 let encoded_len = field.encoded_len(quote!(*value));
418 quote!(#ident::#variant_ident(ref value) => #encoded_len)
419 });
420
421 let debug = fields.iter().map(|&(ref variant_ident, ref field)| {
422 let wrapper = field.debug(quote!(*value));
423 quote!(#ident::#variant_ident(ref value) => {
424 let wrapper = #wrapper;
425 f.debug_tuple(stringify!(#variant_ident))
426 .field(&wrapper)
427 .finish()
428 })
429 });
430
431 let expanded = quote! {
432 impl #ident {
433 pub fn encode<B>(&self, buf: &mut B) where B: ::prost::bytes::BufMut {
434 match *self {
435 #(#encode,)*
436 }
437 }
438
439 pub fn merge<B>(
440 field: &mut ::std::option::Option<#ident>,
441 tag: u32,
442 wire_type: ::prost::encoding::WireType,
443 buf: &mut B,
444 ctx: ::prost::encoding::DecodeContext,
445 ) -> ::std::result::Result<(), ::prost::DecodeError>
446 where B: ::prost::bytes::Buf {
447 match tag {
448 #(#merge,)*
449 _ => unreachable!(concat!("invalid ", stringify!(#ident), " tag: {}"), tag),
450 }
451 }
452
453 #[inline]
454 pub fn encoded_len(&self) -> usize {
455 match *self {
456 #(#encoded_len,)*
457 }
458 }
459 }
460
461 impl ::std::fmt::Debug for #ident {
462 fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
463 match *self {
464 #(#debug,)*
465 }
466 }
467 }
468 };
469
470 Ok(expanded.into())
471 }
472
473 #[proc_macro_derive(Oneof, attributes(prost))]
oneof(input: TokenStream) -> TokenStream474 pub fn oneof(input: TokenStream) -> TokenStream {
475 try_oneof(input).unwrap()
476 }
477