1 use std::ascii;
2 use std::borrow::Cow;
3 use std::collections::{HashMap, HashSet};
4 use std::iter;
5 
6 use itertools::{Either, Itertools};
7 use log::debug;
8 use multimap::MultiMap;
9 use prost_types::field_descriptor_proto::{Label, Type};
10 use prost_types::source_code_info::Location;
11 use prost_types::{
12     DescriptorProto, EnumDescriptorProto, EnumValueDescriptorProto, FieldDescriptorProto,
13     FieldOptions, FileDescriptorProto, OneofDescriptorProto, ServiceDescriptorProto,
14     SourceCodeInfo,
15 };
16 
17 use crate::ast::{Comments, Method, Service};
18 use crate::extern_paths::ExternPaths;
19 use crate::ident::{to_snake, to_upper_camel};
20 use crate::message_graph::MessageGraph;
21 use crate::{BytesType, Config, MapType};
22 
/// Protobuf language level of the file currently being generated.
///
/// The level changes generated code in two places visible in this file: whether
/// scalar `optional` fields are wrapped in `Option`, and whether repeated scalar
/// fields default to packed encoding.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum Syntax {
    /// `syntax = "proto2"`; also used when the file declares no syntax at all.
    Proto2,
    /// `syntax = "proto3"`.
    Proto3,
}
28 
29 pub struct CodeGenerator<'a> {
30     config: &'a mut Config,
31     package: String,
32     source_info: SourceCodeInfo,
33     syntax: Syntax,
34     message_graph: &'a MessageGraph,
35     extern_paths: &'a ExternPaths,
36     depth: u8,
37     path: Vec<i32>,
38     buf: &'a mut String,
39 }
40 
41 impl<'a> CodeGenerator<'a> {
generate( config: &mut Config, message_graph: &MessageGraph, extern_paths: &ExternPaths, file: FileDescriptorProto, buf: &mut String, )42     pub fn generate(
43         config: &mut Config,
44         message_graph: &MessageGraph,
45         extern_paths: &ExternPaths,
46         file: FileDescriptorProto,
47         buf: &mut String,
48     ) {
49         let mut source_info = file
50             .source_code_info
51             .expect("no source code info in request");
52         source_info.location.retain(|location| {
53             let len = location.path.len();
54             len > 0 && len % 2 == 0
55         });
56         source_info
57             .location
58             .sort_by_key(|location| location.path.clone());
59 
60         let syntax = match file.syntax.as_ref().map(String::as_str) {
61             None | Some("proto2") => Syntax::Proto2,
62             Some("proto3") => Syntax::Proto3,
63             Some(s) => panic!("unknown syntax: {}", s),
64         };
65 
66         let mut code_gen = CodeGenerator {
67             config,
68             package: file.package.unwrap(),
69             source_info,
70             syntax,
71             message_graph,
72             extern_paths,
73             depth: 0,
74             path: Vec::new(),
75             buf,
76         };
77 
78         debug!(
79             "file: {:?}, package: {:?}",
80             file.name.as_ref().unwrap(),
81             code_gen.package
82         );
83 
84         code_gen.path.push(4);
85         for (idx, message) in file.message_type.into_iter().enumerate() {
86             code_gen.path.push(idx as i32);
87             code_gen.append_message(message);
88             code_gen.path.pop();
89         }
90         code_gen.path.pop();
91 
92         code_gen.path.push(5);
93         for (idx, desc) in file.enum_type.into_iter().enumerate() {
94             code_gen.path.push(idx as i32);
95             code_gen.append_enum(desc);
96             code_gen.path.pop();
97         }
98         code_gen.path.pop();
99 
100         if code_gen.config.service_generator.is_some() {
101             code_gen.path.push(6);
102             for (idx, service) in file.service.into_iter().enumerate() {
103                 code_gen.path.push(idx as i32);
104                 code_gen.push_service(service);
105                 code_gen.path.pop();
106             }
107 
108             if let Some(service_generator) = code_gen.config.service_generator.as_mut() {
109                 service_generator.finalize(code_gen.buf);
110             }
111 
112             code_gen.path.pop();
113         }
114     }
115 
append_message(&mut self, message: DescriptorProto)116     fn append_message(&mut self, message: DescriptorProto) {
117         debug!("  message: {:?}", message.name());
118 
119         let message_name = message.name().to_string();
120         let fq_message_name = format!(".{}.{}", self.package, message.name());
121 
122         // Skip external types.
123         if self.extern_paths.resolve_ident(&fq_message_name).is_some() {
124             return;
125         }
126 
127         // Split the nested message types into a vector of normal nested message types, and a map
128         // of the map field entry types. The path index of the nested message types is preserved so
129         // that comments can be retrieved.
130         type NestedTypes = Vec<(DescriptorProto, usize)>;
131         type MapTypes = HashMap<String, (FieldDescriptorProto, FieldDescriptorProto)>;
132         let (nested_types, map_types): (NestedTypes, MapTypes) = message
133             .nested_type
134             .into_iter()
135             .enumerate()
136             .partition_map(|(idx, nested_type)| {
137                 if nested_type
138                     .options
139                     .as_ref()
140                     .and_then(|options| options.map_entry)
141                     .unwrap_or(false)
142                 {
143                     let key = nested_type.field[0].clone();
144                     let value = nested_type.field[1].clone();
145                     assert_eq!("key", key.name());
146                     assert_eq!("value", value.name());
147 
148                     let name = format!("{}.{}", &fq_message_name, nested_type.name());
149                     Either::Right((name, (key, value)))
150                 } else {
151                     Either::Left((nested_type, idx))
152                 }
153             });
154 
155         // Split the fields into a vector of the normal fields, and oneof fields.
156         // Path indexes are preserved so that comments can be retrieved.
157         type Fields = Vec<(FieldDescriptorProto, usize)>;
158         type OneofFields = MultiMap<i32, (FieldDescriptorProto, usize)>;
159         let (fields, mut oneof_fields): (Fields, OneofFields) = message
160             .field
161             .into_iter()
162             .enumerate()
163             .partition_map(|(idx, field)| {
164                 if field.proto3_optional.unwrap_or(false) {
165                     Either::Left((field, idx))
166                 } else if let Some(oneof_index) = field.oneof_index {
167                     Either::Right((oneof_index, (field, idx)))
168                 } else {
169                     Either::Left((field, idx))
170                 }
171             });
172 
173         self.append_doc(&fq_message_name, None);
174         self.append_type_attributes(&fq_message_name);
175         self.push_indent();
176         self.buf
177             .push_str("#[derive(Clone, PartialEq, ::prost::Message)]\n");
178         self.push_indent();
179         self.buf.push_str("pub struct ");
180         self.buf.push_str(&to_upper_camel(&message_name));
181         self.buf.push_str(" {\n");
182 
183         self.depth += 1;
184         self.path.push(2);
185         for (field, idx) in fields {
186             self.path.push(idx as i32);
187             match field
188                 .type_name
189                 .as_ref()
190                 .and_then(|type_name| map_types.get(type_name))
191             {
192                 Some(&(ref key, ref value)) => {
193                     self.append_map_field(&fq_message_name, field, key, value)
194                 }
195                 None => self.append_field(&fq_message_name, field),
196             }
197             self.path.pop();
198         }
199         self.path.pop();
200 
201         self.path.push(8);
202         for (idx, oneof) in message.oneof_decl.iter().enumerate() {
203             let idx = idx as i32;
204 
205             let fields = match oneof_fields.get_vec(&idx) {
206                 Some(fields) => fields,
207                 None => continue,
208             };
209 
210             self.path.push(idx);
211             self.append_oneof_field(&message_name, &fq_message_name, oneof, fields);
212             self.path.pop();
213         }
214         self.path.pop();
215 
216         self.depth -= 1;
217         self.push_indent();
218         self.buf.push_str("}\n");
219 
220         if !message.enum_type.is_empty() || !nested_types.is_empty() || !oneof_fields.is_empty() {
221             self.push_mod(&message_name);
222             self.path.push(3);
223             for (nested_type, idx) in nested_types {
224                 self.path.push(idx as i32);
225                 self.append_message(nested_type);
226                 self.path.pop();
227             }
228             self.path.pop();
229 
230             self.path.push(4);
231             for (idx, nested_enum) in message.enum_type.into_iter().enumerate() {
232                 self.path.push(idx as i32);
233                 self.append_enum(nested_enum);
234                 self.path.pop();
235             }
236             self.path.pop();
237 
238             for (idx, oneof) in message.oneof_decl.into_iter().enumerate() {
239                 let idx = idx as i32;
240                 // optional fields create a synthetic oneof that we want to skip
241                 let fields = match oneof_fields.remove(&idx) {
242                     Some(fields) => fields,
243                     None => continue,
244                 };
245                 self.append_oneof(&fq_message_name, oneof, idx, fields);
246             }
247 
248             self.pop_mod();
249         }
250     }
251 
append_type_attributes(&mut self, fq_message_name: &str)252     fn append_type_attributes(&mut self, fq_message_name: &str) {
253         assert_eq!(b'.', fq_message_name.as_bytes()[0]);
254         // TODO: this clone is dirty, but expedious.
255         if let Some(attributes) = self.config.type_attributes.get(fq_message_name).cloned() {
256             self.push_indent();
257             self.buf.push_str(&attributes);
258             self.buf.push('\n');
259         }
260     }
261 
append_field_attributes(&mut self, fq_message_name: &str, field_name: &str)262     fn append_field_attributes(&mut self, fq_message_name: &str, field_name: &str) {
263         assert_eq!(b'.', fq_message_name.as_bytes()[0]);
264         // TODO: this clone is dirty, but expedious.
265         if let Some(attributes) = self
266             .config
267             .field_attributes
268             .get_field(fq_message_name, field_name)
269             .cloned()
270         {
271             self.push_indent();
272             self.buf.push_str(&attributes);
273             self.buf.push('\n');
274         }
275     }
276 
append_field(&mut self, fq_message_name: &str, field: FieldDescriptorProto)277     fn append_field(&mut self, fq_message_name: &str, field: FieldDescriptorProto) {
278         let type_ = field.r#type();
279         let repeated = field.label == Some(Label::Repeated as i32);
280         let deprecated = self.deprecated(&field);
281         let optional = self.optional(&field);
282         let ty = self.resolve_type(&field, fq_message_name);
283 
284         let boxed = !repeated
285             && (type_ == Type::Message || type_ == Type::Group)
286             && self
287                 .message_graph
288                 .is_nested(field.type_name(), fq_message_name);
289 
290         debug!(
291             "    field: {:?}, type: {:?}, boxed: {}",
292             field.name(),
293             ty,
294             boxed
295         );
296 
297         self.append_doc(fq_message_name, Some(field.name()));
298 
299         if deprecated {
300             self.push_indent();
301             self.buf.push_str("#[deprecated]\n");
302         }
303 
304         self.push_indent();
305         self.buf.push_str("#[prost(");
306         let type_tag = self.field_type_tag(&field);
307         self.buf.push_str(&type_tag);
308 
309         if type_ == Type::Bytes {
310             let bytes_type = self
311                 .config
312                 .bytes_type
313                 .get_field(fq_message_name, field.name())
314                 .copied()
315                 .unwrap_or_default();
316             self.buf
317                 .push_str(&format!("={:?}", bytes_type.annotation()));
318         }
319 
320         match field.label() {
321             Label::Optional => {
322                 if optional {
323                     self.buf.push_str(", optional");
324                 }
325             }
326             Label::Required => self.buf.push_str(", required"),
327             Label::Repeated => {
328                 self.buf.push_str(", repeated");
329                 if can_pack(&field)
330                     && !field
331                         .options
332                         .as_ref()
333                         .map_or(self.syntax == Syntax::Proto3, |options| options.packed())
334                 {
335                     self.buf.push_str(", packed=\"false\"");
336                 }
337             }
338         }
339 
340         if boxed {
341             self.buf.push_str(", boxed");
342         }
343         self.buf.push_str(", tag=\"");
344         self.buf.push_str(&field.number().to_string());
345 
346         if let Some(ref default) = field.default_value {
347             self.buf.push_str("\", default=\"");
348             if type_ == Type::Bytes {
349                 self.buf.push_str("b\\\"");
350                 for b in unescape_c_escape_string(default) {
351                     self.buf.extend(
352                         ascii::escape_default(b).flat_map(|c| (c as char).escape_default()),
353                     );
354                 }
355                 self.buf.push_str("\\\"");
356             } else if type_ == Type::Enum {
357                 let enum_value = to_upper_camel(default);
358                 let stripped_prefix = if self.config.strip_enum_prefix {
359                     // Field types are fully qualified, so we extract
360                     // the last segment and strip it from the left
361                     // side of the default value.
362                     let enum_type = field
363                         .type_name
364                         .as_ref()
365                         .and_then(|ty| ty.split('.').last())
366                         .unwrap();
367 
368                     strip_enum_prefix(&to_upper_camel(&enum_type), &enum_value)
369                 } else {
370                     &enum_value
371                 };
372                 self.buf.push_str(stripped_prefix);
373             } else {
374                 // TODO: this is only correct if the Protobuf escaping matches Rust escaping. To be
375                 // safer, we should unescape the Protobuf string and re-escape it with the Rust
376                 // escaping mechanisms.
377                 self.buf.push_str(default);
378             }
379         }
380 
381         self.buf.push_str("\")]\n");
382         self.append_field_attributes(fq_message_name, field.name());
383         self.push_indent();
384         self.buf.push_str("pub ");
385         self.buf.push_str(&to_snake(field.name()));
386         self.buf.push_str(": ");
387         if repeated {
388             self.buf.push_str("::prost::alloc::vec::Vec<");
389         } else if optional {
390             self.buf.push_str("::core::option::Option<");
391         }
392         if boxed {
393             self.buf.push_str("::prost::alloc::boxed::Box<");
394         }
395         self.buf.push_str(&ty);
396         if boxed {
397             self.buf.push('>');
398         }
399         if repeated || optional {
400             self.buf.push('>');
401         }
402         self.buf.push_str(",\n");
403     }
404 
append_map_field( &mut self, fq_message_name: &str, field: FieldDescriptorProto, key: &FieldDescriptorProto, value: &FieldDescriptorProto, )405     fn append_map_field(
406         &mut self,
407         fq_message_name: &str,
408         field: FieldDescriptorProto,
409         key: &FieldDescriptorProto,
410         value: &FieldDescriptorProto,
411     ) {
412         let key_ty = self.resolve_type(key, fq_message_name);
413         let value_ty = self.resolve_type(value, fq_message_name);
414 
415         debug!(
416             "    map field: {:?}, key type: {:?}, value type: {:?}",
417             field.name(),
418             key_ty,
419             value_ty
420         );
421 
422         self.append_doc(fq_message_name, Some(field.name()));
423         self.push_indent();
424 
425         let map_type = self
426             .config
427             .map_type
428             .get_field(fq_message_name, field.name())
429             .copied()
430             .unwrap_or_default();
431         let key_tag = self.field_type_tag(key);
432         let value_tag = self.map_value_type_tag(value);
433 
434         self.buf.push_str(&format!(
435             "#[prost({}=\"{}, {}\", tag=\"{}\")]\n",
436             map_type.annotation(),
437             key_tag,
438             value_tag,
439             field.number()
440         ));
441         self.append_field_attributes(fq_message_name, field.name());
442         self.push_indent();
443         self.buf.push_str(&format!(
444             "pub {}: {}<{}, {}>,\n",
445             to_snake(field.name()),
446             map_type.rust_type(),
447             key_ty,
448             value_ty
449         ));
450     }
451 
append_oneof_field( &mut self, message_name: &str, fq_message_name: &str, oneof: &OneofDescriptorProto, fields: &[(FieldDescriptorProto, usize)], )452     fn append_oneof_field(
453         &mut self,
454         message_name: &str,
455         fq_message_name: &str,
456         oneof: &OneofDescriptorProto,
457         fields: &[(FieldDescriptorProto, usize)],
458     ) {
459         let name = format!(
460             "{}::{}",
461             to_snake(message_name),
462             to_upper_camel(oneof.name())
463         );
464         self.append_doc(fq_message_name, None);
465         self.push_indent();
466         self.buf.push_str(&format!(
467             "#[prost(oneof=\"{}\", tags=\"{}\")]\n",
468             name,
469             fields
470                 .iter()
471                 .map(|&(ref field, _)| field.number())
472                 .join(", ")
473         ));
474         self.append_field_attributes(fq_message_name, oneof.name());
475         self.push_indent();
476         self.buf.push_str(&format!(
477             "pub {}: ::core::option::Option<{}>,\n",
478             to_snake(oneof.name()),
479             name
480         ));
481     }
482 
append_oneof( &mut self, fq_message_name: &str, oneof: OneofDescriptorProto, idx: i32, fields: Vec<(FieldDescriptorProto, usize)>, )483     fn append_oneof(
484         &mut self,
485         fq_message_name: &str,
486         oneof: OneofDescriptorProto,
487         idx: i32,
488         fields: Vec<(FieldDescriptorProto, usize)>,
489     ) {
490         self.path.push(8);
491         self.path.push(idx);
492         self.append_doc(fq_message_name, None);
493         self.path.pop();
494         self.path.pop();
495 
496         let oneof_name = format!("{}.{}", fq_message_name, oneof.name());
497         self.append_type_attributes(&oneof_name);
498         self.push_indent();
499         self.buf
500             .push_str("#[derive(Clone, PartialEq, ::prost::Oneof)]\n");
501         self.push_indent();
502         self.buf.push_str("pub enum ");
503         self.buf.push_str(&to_upper_camel(oneof.name()));
504         self.buf.push_str(" {\n");
505 
506         self.path.push(2);
507         self.depth += 1;
508         for (field, idx) in fields {
509             let type_ = field.r#type();
510 
511             self.path.push(idx as i32);
512             self.append_doc(fq_message_name, Some(field.name()));
513             self.path.pop();
514 
515             self.push_indent();
516             let ty_tag = self.field_type_tag(&field);
517             self.buf.push_str(&format!(
518                 "#[prost({}, tag=\"{}\")]\n",
519                 ty_tag,
520                 field.number()
521             ));
522             self.append_field_attributes(&oneof_name, field.name());
523 
524             self.push_indent();
525             let ty = self.resolve_type(&field, fq_message_name);
526 
527             let boxed = (type_ == Type::Message || type_ == Type::Group)
528                 && self
529                     .message_graph
530                     .is_nested(field.type_name(), fq_message_name);
531 
532             debug!(
533                 "    oneof: {:?}, type: {:?}, boxed: {}",
534                 field.name(),
535                 ty,
536                 boxed
537             );
538 
539             if boxed {
540                 self.buf.push_str(&format!(
541                     "{}(::prost::alloc::boxed::Box<{}>),\n",
542                     to_upper_camel(field.name()),
543                     ty
544                 ));
545             } else {
546                 self.buf
547                     .push_str(&format!("{}({}),\n", to_upper_camel(field.name()), ty));
548             }
549         }
550         self.depth -= 1;
551         self.path.pop();
552 
553         self.push_indent();
554         self.buf.push_str("}\n");
555     }
556 
location(&self) -> &Location557     fn location(&self) -> &Location {
558         let idx = self
559             .source_info
560             .location
561             .binary_search_by_key(&&self.path[..], |location| &location.path[..])
562             .unwrap();
563 
564         &self.source_info.location[idx]
565     }
566 
append_doc(&mut self, fq_name: &str, field_name: Option<&str>)567     fn append_doc(&mut self, fq_name: &str, field_name: Option<&str>) {
568         let append_doc = if let Some(field_name) = field_name {
569             self.config
570                 .disable_comments
571                 .get_field(fq_name, field_name)
572                 .is_none()
573         } else {
574             self.config.disable_comments.get(fq_name).is_none()
575         };
576         if append_doc {
577             Comments::from_location(self.location()).append_with_indent(self.depth, &mut self.buf)
578         }
579     }
580 
append_enum(&mut self, desc: EnumDescriptorProto)581     fn append_enum(&mut self, desc: EnumDescriptorProto) {
582         debug!("  enum: {:?}", desc.name());
583 
584         // Skip external types.
585         let enum_name = &desc.name();
586         let enum_values = &desc.value;
587         let fq_enum_name = format!(".{}.{}", self.package, enum_name);
588         if self.extern_paths.resolve_ident(&fq_enum_name).is_some() {
589             return;
590         }
591 
592         self.append_doc(&fq_enum_name, None);
593         self.append_type_attributes(&fq_enum_name);
594         self.push_indent();
595         self.buf.push_str(
596             "#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)]\n",
597         );
598         self.push_indent();
599         self.buf.push_str("#[repr(i32)]\n");
600         self.push_indent();
601         self.buf.push_str("pub enum ");
602         self.buf.push_str(&to_upper_camel(desc.name()));
603         self.buf.push_str(" {\n");
604 
605         let mut numbers = HashSet::new();
606 
607         self.depth += 1;
608         self.path.push(2);
609         for (idx, value) in enum_values.iter().enumerate() {
610             // Skip duplicate enum values. Protobuf allows this when the
611             // 'allow_alias' option is set.
612             if !numbers.insert(value.number()) {
613                 continue;
614             }
615 
616             self.path.push(idx as i32);
617             let stripped_prefix = if self.config.strip_enum_prefix {
618                 Some(to_upper_camel(&enum_name))
619             } else {
620                 None
621             };
622             self.append_enum_value(&fq_enum_name, value, stripped_prefix);
623             self.path.pop();
624         }
625         self.path.pop();
626         self.depth -= 1;
627 
628         self.push_indent();
629         self.buf.push_str("}\n");
630     }
631 
append_enum_value( &mut self, fq_enum_name: &str, value: &EnumValueDescriptorProto, prefix_to_strip: Option<String>, )632     fn append_enum_value(
633         &mut self,
634         fq_enum_name: &str,
635         value: &EnumValueDescriptorProto,
636         prefix_to_strip: Option<String>,
637     ) {
638         self.append_doc(fq_enum_name, Some(value.name()));
639         self.append_field_attributes(fq_enum_name, &value.name());
640         self.push_indent();
641         let name = to_upper_camel(value.name());
642         let name_unprefixed = match prefix_to_strip {
643             Some(prefix) => strip_enum_prefix(&prefix, &name),
644             None => &name,
645         };
646         self.buf.push_str(name_unprefixed);
647         self.buf.push_str(" = ");
648         self.buf.push_str(&value.number().to_string());
649         self.buf.push_str(",\n");
650     }
651 
push_service(&mut self, service: ServiceDescriptorProto)652     fn push_service(&mut self, service: ServiceDescriptorProto) {
653         let name = service.name().to_owned();
654         debug!("  service: {:?}", name);
655 
656         let comments = Comments::from_location(self.location());
657 
658         self.path.push(2);
659         let methods = service
660             .method
661             .into_iter()
662             .enumerate()
663             .map(|(idx, mut method)| {
664                 debug!("  method: {:?}", method.name());
665                 self.path.push(idx as i32);
666                 let comments = Comments::from_location(self.location());
667                 self.path.pop();
668 
669                 let name = method.name.take().unwrap();
670                 let input_proto_type = method.input_type.take().unwrap();
671                 let output_proto_type = method.output_type.take().unwrap();
672                 let input_type = self.resolve_ident(&input_proto_type);
673                 let output_type = self.resolve_ident(&output_proto_type);
674                 let client_streaming = method.client_streaming();
675                 let server_streaming = method.server_streaming();
676 
677                 Method {
678                     name: to_snake(&name),
679                     proto_name: name,
680                     comments,
681                     input_type,
682                     output_type,
683                     input_proto_type,
684                     output_proto_type,
685                     options: method.options.unwrap_or_default(),
686                     client_streaming,
687                     server_streaming,
688                 }
689             })
690             .collect();
691         self.path.pop();
692 
693         let service = Service {
694             name: to_upper_camel(&name),
695             proto_name: name,
696             package: self.package.clone(),
697             comments,
698             methods,
699             options: service.options.unwrap_or_default(),
700         };
701 
702         if let Some(service_generator) = self.config.service_generator.as_mut() {
703             service_generator.generate(service, &mut self.buf)
704         }
705     }
706 
push_indent(&mut self)707     fn push_indent(&mut self) {
708         for _ in 0..self.depth {
709             self.buf.push_str("    ");
710         }
711     }
712 
push_mod(&mut self, module: &str)713     fn push_mod(&mut self, module: &str) {
714         self.push_indent();
715         self.buf.push_str("/// Nested message and enum types in `");
716         self.buf.push_str(module);
717         self.buf.push_str("`.\n");
718 
719         self.push_indent();
720         self.buf.push_str("pub mod ");
721         self.buf.push_str(&to_snake(module));
722         self.buf.push_str(" {\n");
723 
724         self.package.push('.');
725         self.package.push_str(module);
726 
727         self.depth += 1;
728     }
729 
pop_mod(&mut self)730     fn pop_mod(&mut self) {
731         self.depth -= 1;
732 
733         let idx = self.package.rfind('.').unwrap();
734         self.package.truncate(idx);
735 
736         self.push_indent();
737         self.buf.push_str("}\n");
738     }
739 
resolve_type(&self, field: &FieldDescriptorProto, fq_message_name: &str) -> String740     fn resolve_type(&self, field: &FieldDescriptorProto, fq_message_name: &str) -> String {
741         match field.r#type() {
742             Type::Float => String::from("f32"),
743             Type::Double => String::from("f64"),
744             Type::Uint32 | Type::Fixed32 => String::from("u32"),
745             Type::Uint64 | Type::Fixed64 => String::from("u64"),
746             Type::Int32 | Type::Sfixed32 | Type::Sint32 | Type::Enum => String::from("i32"),
747             Type::Int64 | Type::Sfixed64 | Type::Sint64 => String::from("i64"),
748             Type::Bool => String::from("bool"),
749             Type::String => String::from("::prost::alloc::string::String"),
750             Type::Bytes => self
751                 .config
752                 .bytes_type
753                 .get_field(fq_message_name, field.name())
754                 .copied()
755                 .unwrap_or_default()
756                 .rust_type()
757                 .to_owned(),
758             Type::Group | Type::Message => self.resolve_ident(field.type_name()),
759         }
760     }
761 
resolve_ident(&self, pb_ident: &str) -> String762     fn resolve_ident(&self, pb_ident: &str) -> String {
763         // protoc should always give fully qualified identifiers.
764         assert_eq!(".", &pb_ident[..1]);
765 
766         if let Some(proto_ident) = self.extern_paths.resolve_ident(pb_ident) {
767             return proto_ident;
768         }
769 
770         let mut local_path = self.package.split('.').peekable();
771 
772         let mut ident_path = pb_ident[1..].split('.');
773         let ident_type = ident_path.next_back().unwrap();
774         let mut ident_path = ident_path.peekable();
775 
776         // Skip path elements in common.
777         while local_path.peek().is_some() && local_path.peek() == ident_path.peek() {
778             local_path.next();
779             ident_path.next();
780         }
781 
782         local_path
783             .map(|_| "super".to_string())
784             .chain(ident_path.map(to_snake))
785             .chain(iter::once(to_upper_camel(ident_type)))
786             .join("::")
787     }
788 
field_type_tag(&self, field: &FieldDescriptorProto) -> Cow<'static, str>789     fn field_type_tag(&self, field: &FieldDescriptorProto) -> Cow<'static, str> {
790         match field.r#type() {
791             Type::Float => Cow::Borrowed("float"),
792             Type::Double => Cow::Borrowed("double"),
793             Type::Int32 => Cow::Borrowed("int32"),
794             Type::Int64 => Cow::Borrowed("int64"),
795             Type::Uint32 => Cow::Borrowed("uint32"),
796             Type::Uint64 => Cow::Borrowed("uint64"),
797             Type::Sint32 => Cow::Borrowed("sint32"),
798             Type::Sint64 => Cow::Borrowed("sint64"),
799             Type::Fixed32 => Cow::Borrowed("fixed32"),
800             Type::Fixed64 => Cow::Borrowed("fixed64"),
801             Type::Sfixed32 => Cow::Borrowed("sfixed32"),
802             Type::Sfixed64 => Cow::Borrowed("sfixed64"),
803             Type::Bool => Cow::Borrowed("bool"),
804             Type::String => Cow::Borrowed("string"),
805             Type::Bytes => Cow::Borrowed("bytes"),
806             Type::Group => Cow::Borrowed("group"),
807             Type::Message => Cow::Borrowed("message"),
808             Type::Enum => Cow::Owned(format!(
809                 "enumeration={:?}",
810                 self.resolve_ident(field.type_name())
811             )),
812         }
813     }
814 
map_value_type_tag(&self, field: &FieldDescriptorProto) -> Cow<'static, str>815     fn map_value_type_tag(&self, field: &FieldDescriptorProto) -> Cow<'static, str> {
816         match field.r#type() {
817             Type::Enum => Cow::Owned(format!(
818                 "enumeration({})",
819                 self.resolve_ident(field.type_name())
820             )),
821             _ => self.field_type_tag(field),
822         }
823     }
824 
optional(&self, field: &FieldDescriptorProto) -> bool825     fn optional(&self, field: &FieldDescriptorProto) -> bool {
826         if field.proto3_optional.unwrap_or(false) {
827             return true;
828         }
829 
830         if field.label() != Label::Optional {
831             return false;
832         }
833 
834         match field.r#type() {
835             Type::Message => true,
836             _ => self.syntax == Syntax::Proto2,
837         }
838     }
839 
840     /// Returns `true` if the field options includes the `deprecated` option.
deprecated(&self, field: &FieldDescriptorProto) -> bool841     fn deprecated(&self, field: &FieldDescriptorProto) -> bool {
842         field
843             .options
844             .as_ref()
845             .map_or(false, FieldOptions::deprecated)
846     }
847 }
848 
849 /// Returns `true` if the repeated field type can be packed.
can_pack(field: &FieldDescriptorProto) -> bool850 fn can_pack(field: &FieldDescriptorProto) -> bool {
851     matches!(
852         field.r#type(),
853         Type::Float
854             | Type::Double
855             | Type::Int32
856             | Type::Int64
857             | Type::Uint32
858             | Type::Uint64
859             | Type::Sint32
860             | Type::Sint64
861             | Type::Fixed32
862             | Type::Fixed64
863             | Type::Sfixed32
864             | Type::Sfixed64
865             | Type::Bool
866             | Type::Enum
867     )
868 }
869 
/// Decodes a C-escaped default binary value into its raw bytes.
///
/// Bytes other than a backslash are copied through unchanged. After a
/// backslash the following escapes are recognised:
///
/// * `\a \b \f \n \r \t \v \\ \? \' \"` - single-character escapes;
/// * one to three octal digits (`\0` through `\377`);
/// * `\x` or `\X` followed by exactly two hexadecimal digits.
///
/// Based on [`google::protobuf::UnescapeCEscapeString`][1].
///
/// # Panics
///
/// Panics if the input ends with a lone backslash, contains an unknown
/// escape, an incomplete or non-hexadecimal `\x` escape, or an octal escape
/// whose value does not fit in a byte.
///
/// [1]: https://github.com/google/protobuf/blob/3.3.x/src/google/protobuf/stubs/strutil.cc#L312-L322
fn unescape_c_escape_string(s: &str) -> Vec<u8> {
    let src = s.as_bytes();
    let len = src.len();
    // Every escape decodes to a single byte, so the output never exceeds the input.
    let mut dst = Vec::with_capacity(len);

    let mut p = 0;

    while p < len {
        if src[p] != b'\\' {
            dst.push(src[p]);
            p += 1;
            continue;
        }

        // Step over the backslash; `p` now indexes the escape selector.
        p += 1;
        if p == len {
            panic!(
                "invalid c-escaped default binary value ({}): ends with '\\'",
                s
            )
        }

        match src[p] {
            b'0'..=b'7' => {
                // Accumulate in a wider type so that values above `\377` are
                // reported instead of overflowing (debug) or wrapping (release).
                let mut octal: u32 = 0;
                let mut digits = 0;
                while digits < 3 && p < len && src[p] >= b'0' && src[p] <= b'7' {
                    octal = octal * 8 + u32::from(src[p] - b'0');
                    p += 1;
                    digits += 1;
                }
                if octal > 0xFF {
                    panic!(
                        "invalid c-escaped default binary value ({}): octal value out of range",
                        s
                    )
                }
                // Checked above: the value fits in a byte.
                dst.push(octal as u8);
            }
            b'x' | b'X' => {
                // The selector plus two digits must all be present.
                if p + 3 > len {
                    panic!(
                        "invalid c-escaped default binary value ({}): incomplete hex value",
                        s
                    )
                }
                // Work on bytes rather than slicing the `&str`: this cannot hit
                // a char-boundary panic and rejects signs such as `\x+f`.
                let high = (src[p + 1] as char).to_digit(16);
                let low = (src[p + 2] as char).to_digit(16);
                match (high, low) {
                    // Both digits are below 16, so the result fits in a byte.
                    (Some(high), Some(low)) => dst.push((high * 16 + low) as u8),
                    _ => panic!(
                        "invalid c-escaped default binary value ({}): invalid hex value",
                        s
                    ),
                }
                p += 3;
            }
            escape => {
                let byte = match escape {
                    b'a' => 0x07,
                    b'b' => 0x08,
                    b'f' => 0x0C,
                    b'n' => 0x0A,
                    b'r' => 0x0D,
                    b't' => 0x09,
                    b'v' => 0x0B,
                    // These escapes stand for the character itself.
                    b'\\' | b'?' | b'\'' | b'"' => escape,
                    _ => panic!(
                        "invalid c-escaped default binary value ({}): invalid escape",
                        s
                    ),
                };
                dst.push(byte);
                p += 1;
            }
        }
    }
    dst
}
975 
/// Removes the enum's type name from the front of an enum value name.
///
/// Both inputs are expected to already be in Rust's upper camel case.
///
/// The prefix is only dropped when what remains starts with an uppercase
/// character. This keeps `"Foobar"` intact for the prefix `"Foo"` (not a
/// real word boundary) and also avoids producing an empty name or one that
/// begins with a digit; in all of those cases `name` is returned unchanged.
fn strip_enum_prefix<'a>(prefix: &str, name: &'a str) -> &'a str {
    let remainder = match name.strip_prefix(prefix) {
        Some(rest) => rest,
        None => name,
    };

    match remainder.chars().next() {
        Some(first) if first.is_uppercase() => remainder,
        _ => name,
    }
}
1000 
1001 impl MapType {
1002     /// The `prost-derive` annotation type corresponding to the map type.
annotation(&self) -> &'static str1003     fn annotation(&self) -> &'static str {
1004         match self {
1005             MapType::HashMap => "map",
1006             MapType::BTreeMap => "btree_map",
1007         }
1008     }
1009 
1010     /// The fully-qualified Rust type corresponding to the map type.
rust_type(&self) -> &'static str1011     fn rust_type(&self) -> &'static str {
1012         match self {
1013             MapType::HashMap => "::std::collections::HashMap",
1014             MapType::BTreeMap => "::prost::alloc::collections::BTreeMap",
1015         }
1016     }
1017 }
1018 
1019 impl BytesType {
1020     /// The `prost-derive` annotation type corresponding to the bytes type.
annotation(&self) -> &'static str1021     fn annotation(&self) -> &'static str {
1022         match self {
1023             BytesType::Vec => "vec",
1024             BytesType::Bytes => "bytes",
1025         }
1026     }
1027 
1028     /// The fully-qualified Rust type corresponding to the bytes type.
rust_type(&self) -> &'static str1029     fn rust_type(&self) -> &'static str {
1030         match self {
1031             BytesType::Vec => "::prost::alloc::vec::Vec<u8>",
1032             BytesType::Bytes => "::prost::bytes::Bytes",
1033         }
1034     }
1035 }
1036 
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks plain text, octal escapes, hex escapes and a mix of every
    /// single-character escape.
    #[test]
    fn test_unescape_c_escape_string() {
        // Text without escapes passes through untouched.
        assert_eq!(
            unescape_c_escape_string("hello world"),
            b"hello world".to_vec()
        );

        // A lone octal digit.
        assert_eq!(unescape_c_escape_string(r#"\0"#), vec![0u8]);

        // Three-digit octal escapes.
        assert_eq!(
            unescape_c_escape_string(r#"\012\156"#),
            vec![0o012u8, 0o156]
        );

        // Two-digit hex escapes.
        assert_eq!(unescape_c_escape_string(r#"\x01\x02"#), vec![0x01u8, 0x02]);

        // A mix of every escape family.
        let mixed = unescape_c_escape_string(r#"\0\001\a\b\f\n\r\t\v\\\'\"\xfe"#);
        assert_eq!(mixed, b"\0\x01\x07\x08\x0C\n\r\t\x0B\\\'\"\xFE".to_vec());
    }

    /// Checks that the prefix is removed only at a camel-case word boundary.
    #[test]
    fn test_strip_enum_prefix() {
        // (prefix, name, expected)
        let cases = [
            ("Foo", "FooBar", "Bar"),
            ("Foo", "Foobar", "Foobar"),
            ("Foo", "Foo", "Foo"),
            ("Foo", "Bar", "Bar"),
            ("Foo", "Foo1", "Foo1"),
        ];

        for &(prefix, name, expected) in cases.iter() {
            assert_eq!(strip_enum_prefix(prefix, name), expected);
        }
    }
}
1071