1 // Copyright (c) 2009-2021, Google LLC
2 // All rights reserved.
3 //
4 // Redistribution and use in source and binary forms, with or without
5 // modification, are permitted provided that the following conditions are met:
6 //     * Redistributions of source code must retain the above copyright
7 //       notice, this list of conditions and the following disclaimer.
8 //     * Redistributions in binary form must reproduce the above copyright
9 //       notice, this list of conditions and the following disclaimer in the
10 //       documentation and/or other materials provided with the distribution.
11 //     * Neither the name of Google LLC nor the
12 //       names of its contributors may be used to endorse or promote products
13 //       derived from this software without specific prior written permission.
14 //
15 // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
16 // ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
17 // WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
18 // DISCLAIMED. IN NO EVENT SHALL Google LLC BE LIABLE FOR ANY
19 // DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
20 // (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
21 // LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
22 // ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
23 // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
24 // SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
25 
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <limits>
#include <memory>
#include <string>

#include "absl/container/flat_hash_map.h"
#include "absl/strings/ascii.h"
#include "absl/strings/substitute.h"
#include "google/protobuf/compiler/code_generator.h"
#include "google/protobuf/compiler/plugin.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/descriptor.pb.h"
#include "google/protobuf/wire_format.h"
#include "upbc/common.h"
#include "upbc/message_layout.h"
38 
39 namespace upbc {
40 namespace {
41 
42 namespace protoc = ::google::protobuf::compiler;
43 namespace protobuf = ::google::protobuf;
44 
HeaderFilename(std::string proto_filename)45 std::string HeaderFilename(std::string proto_filename) {
46   return StripExtension(proto_filename) + ".upb.h";
47 }
48 
SourceFilename(std::string proto_filename)49 std::string SourceFilename(std::string proto_filename) {
50   return StripExtension(proto_filename) + ".upb.c";
51 }
52 
AddEnums(const protobuf::Descriptor * message,std::vector<const protobuf::EnumDescriptor * > * enums)53 void AddEnums(const protobuf::Descriptor* message,
54               std::vector<const protobuf::EnumDescriptor*>* enums) {
55   for (int i = 0; i < message->enum_type_count(); i++) {
56     enums->push_back(message->enum_type(i));
57   }
58   for (int i = 0; i < message->nested_type_count(); i++) {
59     AddEnums(message->nested_type(i), enums);
60   }
61 }
62 
// Sorts `*defs` in place by ascending full_name().  T is any pointer-like
// type whose pointee exposes full_name() returning something comparable with
// operator< (e.g. const protobuf::EnumDescriptor*).
template <class T>
void SortDefs(std::vector<T>* defs) {
  auto by_full_name = [](T lhs, T rhs) {
    return lhs->full_name() < rhs->full_name();
  };
  std::sort(std::begin(*defs), std::end(*defs), by_full_name);
}
68 
SortedEnums(const protobuf::FileDescriptor * file)69 std::vector<const protobuf::EnumDescriptor*> SortedEnums(
70     const protobuf::FileDescriptor* file) {
71   std::vector<const protobuf::EnumDescriptor*> enums;
72   for (int i = 0; i < file->enum_type_count(); i++) {
73     enums.push_back(file->enum_type(i));
74   }
75   for (int i = 0; i < file->message_type_count(); i++) {
76     AddEnums(file->message_type(i), &enums);
77   }
78   SortDefs(&enums);
79   return enums;
80 }
81 
FieldNumberOrder(const protobuf::Descriptor * message)82 std::vector<const protobuf::FieldDescriptor*> FieldNumberOrder(
83     const protobuf::Descriptor* message) {
84   std::vector<const protobuf::FieldDescriptor*> fields;
85   for (int i = 0; i < message->field_count(); i++) {
86     fields.push_back(message->field(i));
87   }
88   std::sort(fields.begin(), fields.end(),
89             [](const protobuf::FieldDescriptor* a,
90                const protobuf::FieldDescriptor* b) {
91               return a->number() < b->number();
92             });
93   return fields;
94 }
95 
SortedSubmessages(const protobuf::Descriptor * message)96 std::vector<const protobuf::FieldDescriptor*> SortedSubmessages(
97     const protobuf::Descriptor* message) {
98   std::vector<const protobuf::FieldDescriptor*> ret;
99   for (int i = 0; i < message->field_count(); i++) {
100     if (message->field(i)->cpp_type() ==
101         protobuf::FieldDescriptor::CPPTYPE_MESSAGE) {
102       ret.push_back(message->field(i));
103     }
104   }
105   std::sort(ret.begin(), ret.end(),
106             [](const protobuf::FieldDescriptor* a,
107                const protobuf::FieldDescriptor* b) {
108               return a->message_type()->full_name() <
109                      b->message_type()->full_name();
110             });
111   return ret;
112 }
113 
EnumValueSymbol(const protobuf::EnumValueDescriptor * value)114 std::string EnumValueSymbol(const protobuf::EnumValueDescriptor* value) {
115   return ToCIdent(value->full_name());
116 }
117 
GetSizeInit(const MessageLayout::Size & size)118 std::string GetSizeInit(const MessageLayout::Size& size) {
119   return absl::Substitute("UPB_SIZE($0, $1)", size.size32, size.size64);
120 }
121 
CTypeInternal(const protobuf::FieldDescriptor * field,bool is_const)122 std::string CTypeInternal(const protobuf::FieldDescriptor* field,
123                           bool is_const) {
124   std::string maybe_const = is_const ? "const " : "";
125   switch (field->cpp_type()) {
126     case protobuf::FieldDescriptor::CPPTYPE_MESSAGE: {
127       std::string maybe_struct =
128           field->file() != field->message_type()->file() ? "struct " : "";
129       return maybe_const + maybe_struct + MessageName(field->message_type()) +
130              "*";
131     }
132     case protobuf::FieldDescriptor::CPPTYPE_BOOL:
133       return "bool";
134     case protobuf::FieldDescriptor::CPPTYPE_FLOAT:
135       return "float";
136     case protobuf::FieldDescriptor::CPPTYPE_INT32:
137     case protobuf::FieldDescriptor::CPPTYPE_ENUM:
138       return "int32_t";
139     case protobuf::FieldDescriptor::CPPTYPE_UINT32:
140       return "uint32_t";
141     case protobuf::FieldDescriptor::CPPTYPE_DOUBLE:
142       return "double";
143     case protobuf::FieldDescriptor::CPPTYPE_INT64:
144       return "int64_t";
145     case protobuf::FieldDescriptor::CPPTYPE_UINT64:
146       return "uint64_t";
147     case protobuf::FieldDescriptor::CPPTYPE_STRING:
148       return "upb_strview";
149     default:
150       fprintf(stderr, "Unexpected type");
151       abort();
152   }
153 }
154 
SizeLg2(const protobuf::FieldDescriptor * field)155 std::string SizeLg2(const protobuf::FieldDescriptor* field) {
156   switch (field->cpp_type()) {
157     case protobuf::FieldDescriptor::CPPTYPE_MESSAGE:
158       return "UPB_SIZE(2, 3)";
159     case protobuf::FieldDescriptor::CPPTYPE_ENUM:
160       return std::to_string(2);
161     case protobuf::FieldDescriptor::CPPTYPE_BOOL:
162       return std::to_string(1);
163     case protobuf::FieldDescriptor::CPPTYPE_FLOAT:
164       return std::to_string(2);
165     case protobuf::FieldDescriptor::CPPTYPE_INT32:
166       return std::to_string(2);
167     case protobuf::FieldDescriptor::CPPTYPE_UINT32:
168       return std::to_string(2);
169     case protobuf::FieldDescriptor::CPPTYPE_DOUBLE:
170       return std::to_string(3);
171     case protobuf::FieldDescriptor::CPPTYPE_INT64:
172       return std::to_string(3);
173     case protobuf::FieldDescriptor::CPPTYPE_UINT64:
174       return std::to_string(3);
175     case protobuf::FieldDescriptor::CPPTYPE_STRING:
176       return "UPB_SIZE(3, 4)";
177     default:
178       fprintf(stderr, "Unexpected type");
179       abort();
180   }
181 }
182 
FieldDefault(const protobuf::FieldDescriptor * field)183 std::string FieldDefault(const protobuf::FieldDescriptor* field) {
184   switch (field->cpp_type()) {
185     case protobuf::FieldDescriptor::CPPTYPE_MESSAGE:
186       return "NULL";
187     case protobuf::FieldDescriptor::CPPTYPE_STRING:
188       return absl::Substitute("upb_strview_make(\"$0\", strlen(\"$0\"))",
189                               absl::CEscape(field->default_value_string()));
190     case protobuf::FieldDescriptor::CPPTYPE_INT32:
191       return absl::StrCat(field->default_value_int32());
192     case protobuf::FieldDescriptor::CPPTYPE_INT64:
193       return absl::StrCat(field->default_value_int64());
194     case protobuf::FieldDescriptor::CPPTYPE_UINT32:
195       return absl::StrCat(field->default_value_uint32());
196     case protobuf::FieldDescriptor::CPPTYPE_UINT64:
197       return absl::StrCat(field->default_value_uint64());
198     case protobuf::FieldDescriptor::CPPTYPE_FLOAT:
199       return absl::StrCat(field->default_value_float());
200     case protobuf::FieldDescriptor::CPPTYPE_DOUBLE:
201       return absl::StrCat(field->default_value_double());
202     case protobuf::FieldDescriptor::CPPTYPE_BOOL:
203       return field->default_value_bool() ? "true" : "false";
204     case protobuf::FieldDescriptor::CPPTYPE_ENUM:
205       // Use a number instead of a symbolic name so that we don't require
206       // this enum's header to be included.
207       return absl::StrCat(field->default_value_enum()->number());
208   }
209   ABSL_ASSERT(false);
210   return "XXX";
211 }
212 
CType(const protobuf::FieldDescriptor * field)213 std::string CType(const protobuf::FieldDescriptor* field) {
214   return CTypeInternal(field, false);
215 }
216 
CTypeConst(const protobuf::FieldDescriptor * field)217 std::string CTypeConst(const protobuf::FieldDescriptor* field) {
218   return CTypeInternal(field, true);
219 }
220 
DumpEnumValues(const protobuf::EnumDescriptor * desc,Output & output)221 void DumpEnumValues(const protobuf::EnumDescriptor* desc, Output& output) {
222   std::vector<const protobuf::EnumValueDescriptor*> values;
223   for (int i = 0; i < desc->value_count(); i++) {
224     values.push_back(desc->value(i));
225   }
226   std::sort(values.begin(), values.end(),
227             [](const protobuf::EnumValueDescriptor* a,
228                const protobuf::EnumValueDescriptor* b) {
229               return a->number() < b->number();
230             });
231 
232   for (size_t i = 0; i < values.size(); i++) {
233     auto value = values[i];
234     output("  $0 = $1", EnumValueSymbol(value), value->number());
235     if (i != values.size() - 1) {
236       output(",");
237     }
238     output("\n");
239   }
240 }
241 
// Emits the C API for one message into the generated .upb.h via `output`:
//   * _new / _parse / _parse_ex / _serialize wrappers (skipped for map-entry
//     messages, i.e. those with options().map_entry() set);
//   * for each real oneof, a `<oneof>_oneofcases` enum and a `_case()` reader;
//   * const accessors (hazzers and getters), in field-number order;
//   * mutable accessors (setters, array/map mutators, `mutable_` helpers for
//     sub-messages), also in field-number order.
// Field offsets, oneof-case offsets and hasbit indexes come from a
// MessageLayout computed for `message`.  Sizes and offsets are emitted as
// UPB_SIZE(size32, size64) via GetSizeInit().
void GenerateMessageInHeader(const protobuf::Descriptor* message, Output& output) {
  MessageLayout layout(message);

  output("/* $0 */\n\n", message->full_name());
  std::string msg_name = ToCIdent(message->full_name());

  // Map-entry messages get no standalone constructor or (de)serialization
  // wrappers; all other messages get _new, _parse, _parse_ex and _serialize.
  if (!message->options().map_entry()) {
    output(
        "UPB_INLINE $0 *$0_new(upb_arena *arena) {\n"
        "  return ($0 *)_upb_msg_new(&$1, arena);\n"
        "}\n"
        "UPB_INLINE $0 *$0_parse(const char *buf, size_t size,\n"
        "                        upb_arena *arena) {\n"
        "  $0 *ret = $0_new(arena);\n"
        "  if (!ret) return NULL;\n"
        "  if (!upb_decode(buf, size, ret, &$1, arena)) return NULL;\n"
        "  return ret;\n"
        "}\n"
        "UPB_INLINE $0 *$0_parse_ex(const char *buf, size_t size,\n"
        "                           const upb_extreg *extreg, int options,\n"
        "                           upb_arena *arena) {\n"
        "  $0 *ret = $0_new(arena);\n"
        "  if (!ret) return NULL;\n"
        "  if (!_upb_decode(buf, size, ret, &$1, extreg, options, arena)) {\n"
        "    return NULL;\n"
        "  }\n"
        "  return ret;\n"
        "}\n"
        "UPB_INLINE char *$0_serialize(const $0 *msg, upb_arena *arena, size_t "
        "*len) {\n"
        "  return upb_encode(msg, &$1, arena, len);\n"
        "}\n"
        "\n",
        MessageName(message), MessageInit(message));
  }

  // One enum per real oneof (protobuf's real_oneof_decl_count()).  Its
  // enumerators are `<oneof full name>_<field name>` = field number, plus
  // `<oneof>_NOT_SET = 0`.  The `_case()` reader reinterprets the int32 stored
  // at the oneof-case offset as that enum.
  for (int i = 0; i < message->real_oneof_decl_count(); i++) {
    const protobuf::OneofDescriptor* oneof = message->oneof_decl(i);
    std::string fullname = ToCIdent(oneof->full_name());
    output("typedef enum {\n");
    for (int j = 0; j < oneof->field_count(); j++) {
      const protobuf::FieldDescriptor* field = oneof->field(j);
      output("  $0_$1 = $2,\n", fullname, field->name(), field->number());
    }
    output(
        "  $0_NOT_SET = 0\n"
        "} $0_oneofcases;\n",
        fullname);
    output(
        "UPB_INLINE $0_oneofcases $1_$2_case(const $1* msg) { "
        "return ($0_oneofcases)*UPB_PTR_AT(msg, $3, int32_t); }\n"
        "\n",
        fullname, msg_name, oneof->name(),
        GetSizeInit(layout.GetOneofCaseOffset(oneof)));
  }

  // Generate const methods.

  for (auto field : FieldNumberOrder(message)) {
    // Generate hazzer (if any).  The strategy is chosen in priority order:
    //   1. the field has a hasbit in the layout -> test that bit;
    //   2. the field is in a real oneof -> compare the oneof case to its number;
    //   3. the field is message-typed -> _upb_has_submsg_nohasbit() on the
    //      field's slot (presumably a non-NULL pointer check; confirm in
    //      upb/msg_internal.h).  Repeated and map message fields also
    //      reach this branch.
    // Other fields (e.g. proto3 scalars without presence) get no hazzer.
    if (layout.HasHasbit(field)) {
      output(
          "UPB_INLINE bool $0_has_$1(const $0 *msg) { "
          "return _upb_hasbit(msg, $2); }\n",
          msg_name, field->name(), layout.GetHasbitIndex(field));
    } else if (field->real_containing_oneof()) {
      output(
          "UPB_INLINE bool $0_has_$1(const $0 *msg) { "
          "return _upb_getoneofcase(msg, $2) == $3; }\n",
          msg_name, field->name(),
          GetSizeInit(
              layout.GetOneofCaseOffset(field->real_containing_oneof())),
          field->number());
    } else if (field->message_type()) {
      output(
          "UPB_INLINE bool $0_has_$1(const $0 *msg) { "
          "return _upb_has_submsg_nohasbit(msg, $2); }\n",
          msg_name, field->name(), GetSizeInit(layout.GetFieldOffset(field)));
    }

    // Generate getter.  Shape depends on the field kind:
    //   map field         -> _size, _get(key, *val) and _next(iter);
    //   map-entry message -> by-value readers of the entry's key/value,
    //                        delegating to _upb_msg_map_<field name>;
    //   repeated field    -> const pointer to the elements plus *len;
    //   oneof member      -> value if this member is set, else FieldDefault();
    //   anything else     -> direct read at the field offset.
    // NOTE(review): for string keys/values the size argument is 0 instead of
    // sizeof(...); presumably this tells the runtime the element is a
    // upb_strview.  Confirm in upb/msg_internal.h.
    if (field->is_map()) {
      const protobuf::Descriptor* entry = field->message_type();
      const protobuf::FieldDescriptor* key = entry->FindFieldByNumber(1);
      const protobuf::FieldDescriptor* val = entry->FindFieldByNumber(2);
      output(
          "UPB_INLINE size_t $0_$1_size(const $0 *msg) {"
          "return _upb_msg_map_size(msg, $2); }\n",
          msg_name, field->name(), GetSizeInit(layout.GetFieldOffset(field)));
      output(
          "UPB_INLINE bool $0_$1_get(const $0 *msg, $2 key, $3 *val) { "
          "return _upb_msg_map_get(msg, $4, &key, $5, val, $6); }\n",
          msg_name, field->name(), CType(key), CType(val),
          GetSizeInit(layout.GetFieldOffset(field)),
          key->cpp_type() == protobuf::FieldDescriptor::CPPTYPE_STRING
              ? "0"
              : "sizeof(key)",
          val->cpp_type() == protobuf::FieldDescriptor::CPPTYPE_STRING
              ? "0"
              : "sizeof(*val)");
      output(
          "UPB_INLINE $0 $1_$2_next(const $1 *msg, size_t* iter) { "
          "return ($0)_upb_msg_map_next(msg, $3, iter); }\n",
          CTypeConst(field), msg_name, field->name(),
          GetSizeInit(layout.GetFieldOffset(field)));
    } else if (message->options().map_entry()) {
      output(
          "UPB_INLINE $0 $1_$2(const $1 *msg) {\n"
          "  $3 ret;\n"
          "  _upb_msg_map_$2(msg, &ret, $4);\n"
          "  return ret;\n"
          "}\n",
          CTypeConst(field), msg_name, field->name(), CType(field),
          field->cpp_type() == protobuf::FieldDescriptor::CPPTYPE_STRING
              ? "0"
              : "sizeof(ret)");
    } else if (field->is_repeated()) {
      output(
          "UPB_INLINE $0 const* $1_$2(const $1 *msg, size_t *len) { "
          "return ($0 const*)_upb_array_accessor(msg, $3, len); }\n",
          CTypeConst(field), msg_name, field->name(),
          GetSizeInit(layout.GetFieldOffset(field)));
    } else if (field->real_containing_oneof()) {
      output(
          "UPB_INLINE $0 $1_$2(const $1 *msg) { "
          "return UPB_READ_ONEOF(msg, $0, $3, $4, $5, $6); }\n",
          CTypeConst(field), msg_name, field->name(),
          GetSizeInit(layout.GetFieldOffset(field)),
          GetSizeInit(layout.GetOneofCaseOffset(field->real_containing_oneof())),
          field->number(), FieldDefault(field));
    } else {
      output(
          "UPB_INLINE $0 $1_$2(const $1 *msg) { "
          "return *UPB_PTR_AT(msg, $3, $0); }\n",
          CTypeConst(field), msg_name, field->name(),
          GetSizeInit(layout.GetFieldOffset(field)));
    }
  }

  output("\n");

  // Generate mutable methods.

  for (auto field : FieldNumberOrder(message)) {
    if (field->is_map()) {
      // Map fields: _clear, _set(key, val, arena), _delete(key) and
      // _nextmutable(iter).  String keys/values pass 0 as the size, as in
      // the getters above.
      // TODO(haberman): add map-based mutators.
      const protobuf::Descriptor* entry = field->message_type();
      const protobuf::FieldDescriptor* key = entry->FindFieldByNumber(1);
      const protobuf::FieldDescriptor* val = entry->FindFieldByNumber(2);
      output(
          "UPB_INLINE void $0_$1_clear($0 *msg) { _upb_msg_map_clear(msg, $2); }\n",
          msg_name, field->name(),
          GetSizeInit(layout.GetFieldOffset(field)));
      output(
          "UPB_INLINE bool $0_$1_set($0 *msg, $2 key, $3 val, upb_arena *a) { "
          "return _upb_msg_map_set(msg, $4, &key, $5, &val, $6, a); }\n",
          msg_name, field->name(), CType(key), CType(val),
          GetSizeInit(layout.GetFieldOffset(field)),
          key->cpp_type() == protobuf::FieldDescriptor::CPPTYPE_STRING
              ? "0"
              : "sizeof(key)",
          val->cpp_type() == protobuf::FieldDescriptor::CPPTYPE_STRING
              ? "0"
              : "sizeof(val)");
      output(
          "UPB_INLINE bool $0_$1_delete($0 *msg, $2 key) { "
          "return _upb_msg_map_delete(msg, $3, &key, $4); }\n",
          msg_name, field->name(), CType(key),
          GetSizeInit(layout.GetFieldOffset(field)),
          key->cpp_type() == protobuf::FieldDescriptor::CPPTYPE_STRING
              ? "0"
              : "sizeof(key)");
      output(
          "UPB_INLINE $0 $1_$2_nextmutable($1 *msg, size_t* iter) { "
          "return ($0)_upb_msg_map_next(msg, $3, iter); }\n",
          CType(field), msg_name, field->name(),
          GetSizeInit(layout.GetFieldOffset(field)));
    } else if (field->is_repeated()) {
      // Repeated fields: mutable_ (element pointer plus *len) and resize_,
      // both taking the element size as SizeLg2(field).  Message elements
      // also get add_, which allocates a new sub-message and appends a
      // pointer to it.  Scalar/string elements get add_(val).
      output(
          "UPB_INLINE $0* $1_mutable_$2($1 *msg, size_t *len) {\n"
          "  return ($0*)_upb_array_mutable_accessor(msg, $3, len);\n"
          "}\n",
          CType(field), msg_name, field->name(),
          GetSizeInit(layout.GetFieldOffset(field)));
      output(
          "UPB_INLINE $0* $1_resize_$2($1 *msg, size_t len, "
          "upb_arena *arena) {\n"
          "  return ($0*)_upb_array_resize_accessor2(msg, $3, len, $4, arena);\n"
          "}\n",
          CType(field), msg_name, field->name(),
          GetSizeInit(layout.GetFieldOffset(field)),
          SizeLg2(field));
      if (field->cpp_type() == protobuf::FieldDescriptor::CPPTYPE_MESSAGE) {
        output(
            "UPB_INLINE struct $0* $1_add_$2($1 *msg, upb_arena *arena) {\n"
            "  struct $0* sub = (struct $0*)_upb_msg_new(&$3, arena);\n"
            "  bool ok = _upb_array_append_accessor2(\n"
            "      msg, $4, $5, &sub, arena);\n"
            "  if (!ok) return NULL;\n"
            "  return sub;\n"
            "}\n",
            MessageName(field->message_type()), msg_name, field->name(),
            MessageInit(field->message_type()),
            GetSizeInit(layout.GetFieldOffset(field)),
            SizeLg2(field));
      } else {
        output(
            "UPB_INLINE bool $1_add_$2($1 *msg, $0 val, upb_arena *arena) {\n"
            "  return _upb_array_append_accessor2(msg, $3, $4, &val,\n"
            "      arena);\n"
            "}\n",
            CType(field), msg_name, field->name(),
            GetSizeInit(layout.GetFieldOffset(field)),
            SizeLg2(field));
      }
    } else {
      // Non-repeated field.
      if (message->options().map_entry() && field->name() == "key") {
        // Key cannot be mutated.
        continue;
      }

      // The common function signature for all setters.  Varying implementations
      // follow.
      output("UPB_INLINE void $0_set_$1($0 *msg, $2 value) {\n", msg_name,
             field->name(), CType(field));

      // Setter body:
      //   map-entry value -> _upb_msg_map_set_value() (size 0 for strings);
      //   oneof member    -> UPB_WRITE_ONEOF(), which also records the case;
      //   otherwise       -> set the hasbit (if any), then store at the offset.
      if (message->options().map_entry()) {
        output(
            "  _upb_msg_map_set_value(msg, &value, $0);\n"
            "}\n",
            field->cpp_type() == protobuf::FieldDescriptor::CPPTYPE_STRING
                ? "0"
                : "sizeof(" + CType(field) + ")");
      } else if (field->real_containing_oneof()) {
        output(
            "  UPB_WRITE_ONEOF(msg, $0, $1, value, $2, $3);\n"
            "}\n",
            CType(field), GetSizeInit(layout.GetFieldOffset(field)),
            GetSizeInit(
                layout.GetOneofCaseOffset(field->real_containing_oneof())),
            field->number());
      } else {
        if (MessageLayout::HasHasbit(field)) {
          output("  _upb_sethas(msg, $0);\n", layout.GetHasbitIndex(field));
        }
        output(
            "  *UPB_PTR_AT(msg, $1, $0) = value;\n"
            "}\n",
            CType(field), GetSizeInit(layout.GetFieldOffset(field)));
      }

      // Singular sub-message fields (outside map entries) also get mutable_:
      // it returns the existing sub-message, or allocates one in `arena` and
      // installs it through the setter when the getter returns NULL.
      if (field->cpp_type() == protobuf::FieldDescriptor::CPPTYPE_MESSAGE &&
          !message->options().map_entry()) {
        output(
            "UPB_INLINE struct $0* $1_mutable_$2($1 *msg, upb_arena *arena) {\n"
            "  struct $0* sub = (struct $0*)$1_$2(msg);\n"
            "  if (sub == NULL) {\n"
            "    sub = (struct $0*)_upb_msg_new(&$3, arena);\n"
            "    if (!sub) return NULL;\n"
            "    $1_set_$2(msg, sub);\n"
            "  }\n"
            "  return sub;\n"
            "}\n",
            MessageName(field->message_type()), msg_name, field->name(),
            MessageInit(field->message_type()));
      }
    }
  }

  output("\n");
}
514 
// Writes the complete generated .upb.h for `file` to `output`, in this order:
// include guard and upb runtime includes; #includes of the generated headers
// of public dependencies; port_def.inc and an extern "C" opener; forward
// declarations of this file's messages (struct tags, typedefs, extern
// upb_msglayout symbols); forward declarations of messages from other files
// used as field types; enum typedefs; per-message accessors
// (GenerateMessageInHeader); then the extern "C" closer, port_undef.inc and
// the end of the include guard.
void WriteHeader(const protobuf::FileDescriptor* file, Output& output) {
  EmitFileWarning(file, output);
  output(
      "#ifndef $0_UPB_H_\n"
      "#define $0_UPB_H_\n\n"
      "#include \"upb/msg_internal.h\"\n"
      "#include \"upb/decode.h\"\n"
      "#include \"upb/decode_fast.h\"\n"
      "#include \"upb/encode.h\"\n\n",
      ToPreproc(file->name()));

  // Public imports are re-exported by #including their generated headers.
  // The comment line and trailing blank line appear only if there is at
  // least one.
  for (int i = 0; i < file->public_dependency_count(); i++) {
    const auto& name = file->public_dependency(i)->name();
    if (i == 0) {
      output("/* Public Imports. */\n");
    }
    output("#include \"$0\"\n", HeaderFilename(name));
    if (i == file->public_dependency_count() - 1) {
      output("\n");
    }
  }

  output(
      "#include \"upb/port_def.inc\"\n"
      "\n"
      "#ifdef __cplusplus\n"
      "extern \"C\" {\n"
      "#endif\n"
      "\n");

  const std::vector<const protobuf::Descriptor*> this_file_messages =
      SortedMessages(file);

  // Forward-declare types defined in this file.
  for (auto message : this_file_messages) {
    output("struct $0;\n", ToCIdent(message->full_name()));
  }
  for (auto message : this_file_messages) {
    output("typedef struct $0 $0;\n", ToCIdent(message->full_name()));
  }
  for (auto message : this_file_messages) {
    output("extern const upb_msglayout $0;\n", MessageInit(message));
  }

  // Forward-declare types not in this file, but used as submessages.
  // Order by full name for consistent ordering.
  // Only the regular fields (field(i)) of this file's messages are scanned.
  // These types get a bare `struct X;` with no typedef, which is why
  // CTypeInternal() prefixes cross-file message types with "struct ".
  std::map<std::string, const protobuf::Descriptor*> forward_messages;

  for (auto* message : this_file_messages) {
    for (int i = 0; i < message->field_count(); i++) {
      const protobuf::FieldDescriptor* field = message->field(i);
      if (field->cpp_type() == protobuf::FieldDescriptor::CPPTYPE_MESSAGE &&
          field->file() != field->message_type()->file()) {
        forward_messages[field->message_type()->full_name()] =
            field->message_type();
      }
    }
  }
  for (const auto& pair : forward_messages) {
    output("struct $0;\n", MessageName(pair.second));
  }
  for (const auto& pair : forward_messages) {
    output("extern const upb_msglayout $0;\n", MessageInit(pair.second));
  }

  // Blank separator only when this file defines at least one message (and so
  // the forward declarations above are non-empty).
  if (!this_file_messages.empty()) {
    output("\n");
  }

  std::vector<const protobuf::EnumDescriptor*> this_file_enums =
      SortedEnums(file);

  // Every enum in the file (top-level and nested), sorted by full name,
  // becomes a C typedef'd enum named after its full name.
  for (auto enumdesc : this_file_enums) {
    output("typedef enum {\n");
    DumpEnumValues(enumdesc, output);
    output("} $0;\n\n", ToCIdent(enumdesc->full_name()));
  }

  output("\n");

  for (auto message : this_file_messages) {
    GenerateMessageInHeader(message, output);
  }

  output(
      "#ifdef __cplusplus\n"
      "}  /* extern \"C\" */\n"
      "#endif\n"
      "\n"
      "#include \"upb/port_undef.inc\"\n"
      "\n"
      "#endif  /* $0_UPB_H_ */\n",
      ToPreproc(file->name()));
}
609 
TableDescriptorType(const protobuf::FieldDescriptor * field)610 int TableDescriptorType(const protobuf::FieldDescriptor* field) {
611   if (field->file()->syntax() == protobuf::FileDescriptor::SYNTAX_PROTO2 &&
612       field->type() == protobuf::FieldDescriptor::TYPE_STRING) {
613     // From the perspective of the binary encoder/decoder, proto2 string fields
614     // are identical to bytes fields. Only in proto3 do we check UTF-8 for
615     // string fields at parse time.
616     //
617     // If we ever use these tables for JSON encoding/decoding (for example by
618     // embedding field names on the side) we will have to revisit this, because
619     // string vs. bytes behavior is not affected by proto2 vs proto3.
620     return protobuf::FieldDescriptor::TYPE_BYTES;
621   } else {
622     return field->type();
623   }
624 }
625 
626 struct SubmsgArray {
627  public:
SubmsgArrayupbc::__anond357e7fd0111::SubmsgArray628   SubmsgArray(const protobuf::Descriptor* message) : message_(message) {
629     MessageLayout layout(message);
630     std::vector<const protobuf::FieldDescriptor*> sorted_submsgs =
631         SortedSubmessages(message);
632     int i = 0;
633     for (auto submsg : sorted_submsgs) {
634       if (indexes_.find(submsg->message_type()) != indexes_.end()) {
635         continue;
636       }
637       submsgs_.push_back(submsg->message_type());
638       indexes_[submsg->message_type()] = i++;
639     }
640   }
641 
submsgsupbc::__anond357e7fd0111::SubmsgArray642   const std::vector<const protobuf::Descriptor*>& submsgs() const {
643     return submsgs_;
644   }
645 
GetIndexupbc::__anond357e7fd0111::SubmsgArray646   int GetIndex(const protobuf::FieldDescriptor* field) {
647     (void)message_;
648     assert(field->containing_type() == message_);
649     auto it = indexes_.find(field->message_type());
650     assert(it != indexes_.end());
651     return it->second;
652   }
653 
654  private:
655   const protobuf::Descriptor* message_;
656   std::vector<const protobuf::Descriptor*> submsgs_;
657   absl::flat_hash_map<const protobuf::Descriptor*, int> indexes_;
658 };
659 
// A fast-table entry: the name of the parse function (as C source text) and
// the 64-bit data word passed alongside it.
using TableEntry = std::pair<std::string, uint64_t>;
661 
GetEncodedTag(const protobuf::FieldDescriptor * field)662 uint64_t GetEncodedTag(const protobuf::FieldDescriptor* field) {
663   protobuf::internal::WireFormatLite::WireType wire_type =
664       protobuf::internal::WireFormat::WireTypeForField(field);
665   uint32_t unencoded_tag =
666       protobuf::internal::WireFormatLite::MakeTag(field->number(), wire_type);
667   uint8_t tag_bytes[10] = {0};
668   protobuf::io::CodedOutputStream::WriteVarint32ToArray(unencoded_tag,
669                                                         tag_bytes);
670   uint64_t encoded_tag = 0;
671   memcpy(&encoded_tag, tag_bytes, sizeof(encoded_tag));
672   // TODO: byte-swap for big endian.
673   return encoded_tag;
674 }
675 
GetTableSlot(const protobuf::FieldDescriptor * field)676 int GetTableSlot(const protobuf::FieldDescriptor* field) {
677   uint64_t tag = GetEncodedTag(field);
678   if (tag > 0x7fff) {
679     // Tag must fit within a two-byte varint.
680     return -1;
681   }
682   return (tag & 0xf8) >> 3;
683 }
684 
TryFillTableEntry(const protobuf::Descriptor * message,const MessageLayout & layout,const protobuf::FieldDescriptor * field,TableEntry & ent)685 bool TryFillTableEntry(const protobuf::Descriptor* message,
686                        const MessageLayout& layout,
687                        const protobuf::FieldDescriptor* field,
688                        TableEntry& ent) {
689   std::string type = "";
690   std::string cardinality = "";
691   switch (field->type()) {
692     case protobuf::FieldDescriptor::TYPE_BOOL:
693       type = "b1";
694       break;
695     case protobuf::FieldDescriptor::TYPE_INT32:
696     case protobuf::FieldDescriptor::TYPE_ENUM:
697     case protobuf::FieldDescriptor::TYPE_UINT32:
698       type = "v4";
699       break;
700     case protobuf::FieldDescriptor::TYPE_INT64:
701     case protobuf::FieldDescriptor::TYPE_UINT64:
702       type = "v8";
703       break;
704     case protobuf::FieldDescriptor::TYPE_FIXED32:
705     case protobuf::FieldDescriptor::TYPE_SFIXED32:
706     case protobuf::FieldDescriptor::TYPE_FLOAT:
707       type = "f4";
708       break;
709     case protobuf::FieldDescriptor::TYPE_FIXED64:
710     case protobuf::FieldDescriptor::TYPE_SFIXED64:
711     case protobuf::FieldDescriptor::TYPE_DOUBLE:
712       type = "f8";
713       break;
714     case protobuf::FieldDescriptor::TYPE_SINT32:
715       type = "z4";
716       break;
717     case protobuf::FieldDescriptor::TYPE_SINT64:
718       type = "z8";
719       break;
720     case protobuf::FieldDescriptor::TYPE_STRING:
721       if (field->file()->syntax() == protobuf::FileDescriptor::SYNTAX_PROTO3) {
722         // Only proto3 validates UTF-8.
723         type = "s";
724         break;
725       }
726       ABSL_FALLTHROUGH_INTENDED;
727     case protobuf::FieldDescriptor::TYPE_BYTES:
728       type = "b";
729       break;
730     case protobuf::FieldDescriptor::TYPE_MESSAGE:
731       if (field->is_map()) {
732         return false;  // Not supported yet (ever?).
733       }
734       type = "m";
735       break;
736     default:
737       return false;  // Not supported yet.
738   }
739 
740   switch (field->label()) {
741     case protobuf::FieldDescriptor::LABEL_REPEATED:
742       if (field->is_packed()) {
743         cardinality = "p";
744       } else {
745         cardinality = "r";
746       }
747       break;
748     case protobuf::FieldDescriptor::LABEL_OPTIONAL:
749     case protobuf::FieldDescriptor::LABEL_REQUIRED:
750       if (field->real_containing_oneof()) {
751         cardinality = "o";
752       } else {
753         cardinality = "s";
754       }
755       break;
756   }
757 
758   uint64_t expected_tag = GetEncodedTag(field);
759   MessageLayout::Size offset = layout.GetFieldOffset(field);
760 
761   // Data is:
762   //
763   //                  48                32                16                 0
764   // |--------|--------|--------|--------|--------|--------|--------|--------|
765   // |   offset (16)   |case offset (16) |presence| submsg |  exp. tag (16)  |
766   // |--------|--------|--------|--------|--------|--------|--------|--------|
767   //
768   // - |presence| is either hasbit index or field number for oneofs.
769 
770   uint64_t data = offset.size64 << 48 | expected_tag;
771 
772   if (field->is_repeated()) {
773     // No hasbit/oneof-related fields.
774   } if (field->real_containing_oneof()) {
775     MessageLayout::Size case_offset =
776         layout.GetOneofCaseOffset(field->real_containing_oneof());
777     if (case_offset.size64 > 0xffff) return false;
778     assert(field->number() < 256);
779     data |= field->number() << 24;
780     data |= case_offset.size64 << 32;
781   } else {
782     uint64_t hasbit_index = 63;  // No hasbit (set a high, unused bit).
783     if (layout.HasHasbit(field)) {
784       hasbit_index = layout.GetHasbitIndex(field);
785       if (hasbit_index > 31) return false;
786     }
787     data |= hasbit_index << 24;
788   }
789 
790   if (field->cpp_type() == protobuf::FieldDescriptor::CPPTYPE_MESSAGE) {
791     SubmsgArray submsg_array(message);
792     uint64_t idx = submsg_array.GetIndex(field);
793     if (idx > 255) return false;
794     data |= idx << 16;
795 
796     std::string size_ceil = "max";
797     size_t size = SIZE_MAX;
798     if (field->message_type()->file() == field->file()) {
799       // We can only be guaranteed the size of the sub-message if it is in the
800       // same file as us.  We could relax this to increase the speed of
801       // cross-file sub-message parsing if we are comfortable requiring that
802       // users compile all messages at the same time.
803       MessageLayout sub_layout(field->message_type());
804       size = sub_layout.message_size().size64 + 8;
805     }
806     std::vector<size_t> breaks = {64, 128, 192, 256};
807     for (auto brk : breaks) {
808       if (size <= brk) {
809         size_ceil = std::to_string(brk);
810         break;
811       }
812     }
813     ent.first = absl::Substitute("upb_p$0$1_$2bt_max$3b", cardinality, type,
814                                  expected_tag > 0xff ? "2" : "1", size_ceil);
815 
816   } else {
817     ent.first = absl::Substitute("upb_p$0$1_$2bt", cardinality, type,
818                                  expected_tag > 0xff ? "2" : "1");
819   }
820   ent.second = data;
821   return true;
822 }
823 
FastDecodeTable(const protobuf::Descriptor * message,const MessageLayout & layout)824 std::vector<TableEntry> FastDecodeTable(const protobuf::Descriptor* message,
825                                         const MessageLayout& layout) {
826   std::vector<TableEntry> table;
827   for (const auto field : FieldHotnessOrder(message)) {
828     TableEntry ent;
829     int slot = GetTableSlot(field);
830     // std::cerr << "table slot: " << field->number() << ": " << slot << "\n";
831     if (slot < 0) {
832       // Tag can't fit in the table.
833       continue;
834     }
835     if (!TryFillTableEntry(message, layout, field, ent)) {
836       // Unsupported field type or offset, hasbit index, etc. doesn't fit.
837       continue;
838     }
839     while ((size_t)slot >= table.size()) {
840       size_t size = std::max(static_cast<size_t>(1), table.size() * 2);
841       table.resize(size, TableEntry{"fastdecode_generic", 0});
842     }
843     if (table[slot].first != "fastdecode_generic") {
844       // A hotter field already filled this slot.
845       continue;
846     }
847     table[slot] = ent;
848   }
849   return table;
850 }
851 
WriteField(const protobuf::FieldDescriptor * field,absl::string_view offset,absl::string_view presence,int submsg_index,Output & output)852 void WriteField(const protobuf::FieldDescriptor* field,
853                 absl::string_view offset, absl::string_view presence,
854                 int submsg_index, Output& output) {
855   std::string mode;
856   if (field->is_map()) {
857     mode = "_UPB_MODE_MAP";
858   } else if (field->is_repeated()) {
859     mode = "_UPB_MODE_ARRAY";
860   } else {
861     mode = "_UPB_MODE_SCALAR";
862   }
863 
864   if (field->is_packed()) {
865     absl::StrAppend(&mode, " | _UPB_MODE_IS_PACKED");
866   }
867 
868   output("{$0, $1, $2, $3, $4, $5}", field->number(), offset, presence,
869          submsg_index, TableDescriptorType(field), mode);
870 }
871 
872 // Writes a single field into a .upb.c source file.
WriteMessageField(const protobuf::FieldDescriptor * field,const MessageLayout & layout,int submsg_index,Output & output)873 void WriteMessageField(const protobuf::FieldDescriptor* field,
874                        const MessageLayout& layout, int submsg_index,
875                        Output& output) {
876   std::string presence = "0";
877 
878   if (MessageLayout::HasHasbit(field)) {
879     int index = layout.GetHasbitIndex(field);
880     assert(index != 0);
881     presence = absl::StrCat(index);
882   } else if (field->real_containing_oneof()) {
883     MessageLayout::Size case_offset =
884         layout.GetOneofCaseOffset(field->real_containing_oneof());
885 
886     // We encode as negative to distinguish from hasbits.
887     case_offset.size32 = ~case_offset.size32;
888     case_offset.size64 = ~case_offset.size64;
889     assert(case_offset.size32 < 0);
890     assert(case_offset.size64 < 0);
891     presence = GetSizeInit(case_offset);
892   }
893 
894   output("  ");
895   WriteField(field, GetSizeInit(layout.GetFieldOffset(field)), presence,
896              submsg_index, output);
897   output(",\n");
898 }
899 
900 // Writes a single message into a .upb.c source file.
WriteMessage(const protobuf::Descriptor * message,Output & output,bool fasttable_enabled)901 void WriteMessage(const protobuf::Descriptor* message, Output& output,
902                   bool fasttable_enabled) {
903   std::string msg_name = ToCIdent(message->full_name());
904   std::string fields_array_ref = "NULL";
905   std::string submsgs_array_ref = "NULL";
906   uint8_t dense_below = 0;
907   const int dense_below_max = std::numeric_limits<decltype(dense_below)>::max();
908   MessageLayout layout(message);
909   SubmsgArray submsg_array(message);
910 
911   if (!submsg_array.submsgs().empty()) {
912     // TODO(haberman): could save a little bit of space by only generating a
913     // "submsgs" array for every strongly-connected component.
914     std::string submsgs_array_name = msg_name + "_submsgs";
915     submsgs_array_ref = "&" + submsgs_array_name + "[0]";
916     output("static const upb_msglayout *const $0[$1] = {\n",
917            submsgs_array_name, submsg_array.submsgs().size());
918 
919     for (auto submsg : submsg_array.submsgs()) {
920       output("  &$0,\n", MessageInit(submsg));
921     }
922 
923     output("};\n\n");
924   }
925 
926   std::vector<const protobuf::FieldDescriptor*> field_number_order =
927       FieldNumberOrder(message);
928   if (!field_number_order.empty()) {
929     std::string fields_array_name = msg_name + "__fields";
930     fields_array_ref = "&" + fields_array_name + "[0]";
931     output("static const upb_msglayout_field $0[$1] = {\n",
932            fields_array_name, field_number_order.size());
933     for (int i = 0; i < static_cast<int>(field_number_order.size()); i++) {
934       auto field = field_number_order[i];
935       int submsg_index = 0;
936 
937       if (i < dense_below_max && field->number() == i + 1 &&
938           (i == 0 || field_number_order[i - 1]->number() == i)) {
939         dense_below = i + 1;
940       }
941 
942       if (field->cpp_type() == protobuf::FieldDescriptor::CPPTYPE_MESSAGE) {
943         submsg_index = submsg_array.GetIndex(field);
944       }
945 
946       WriteMessageField(field, layout, submsg_index, output);
947     }
948     output("};\n\n");
949   }
950 
951   std::vector<TableEntry> table;
952   uint8_t table_mask = -1;
953 
954   if (fasttable_enabled) {
955     table = FastDecodeTable(message, layout);
956   }
957 
958   if (table.size() > 1) {
959     assert((table.size() & (table.size() - 1)) == 0);
960     table_mask = (table.size() - 1) << 3;
961   }
962 
963   output("const upb_msglayout $0 = {\n", MessageInit(message));
964   output("  $0,\n", submsgs_array_ref);
965   output("  $0,\n", fields_array_ref);
966   output("  $0, $1, $2, $3, $4,\n", GetSizeInit(layout.message_size()),
967          field_number_order.size(),
968          "false",  // TODO: extendable
969          dense_below,
970          table_mask
971   );
972   if (!table.empty()) {
973     output("  UPB_FASTTABLE_INIT({\n");
974     for (const auto& ent : table) {
975       output("    {0x$1, &$0},\n", ent.first,
976              absl::StrCat(absl::Hex(ent.second, absl::kZeroPad16)));
977     }
978     output("  }),\n");
979   }
980   output("};\n\n");
981 }
982 
WriteMessages(const protobuf::FileDescriptor * file,Output & output,bool fasttable_enabled)983 void WriteMessages(const protobuf::FileDescriptor* file, Output& output,
984                    bool fasttable_enabled) {
985   for (auto* message : SortedMessages(file)) {
986     WriteMessage(message, output, fasttable_enabled);
987   }
988 }
989 
990 // Writes a .upb.c source file.
WriteSource(const protobuf::FileDescriptor * file,Output & output,bool fasttable_enabled)991 void WriteSource(const protobuf::FileDescriptor* file, Output& output,
992                  bool fasttable_enabled) {
993   EmitFileWarning(file, output);
994 
995   output(
996       "#include <stddef.h>\n"
997       "#include \"upb/msg_internal.h\"\n"
998       "#include \"$0\"\n",
999       HeaderFilename(file->name()));
1000 
1001   for (int i = 0; i < file->dependency_count(); i++) {
1002     output("#include \"$0\"\n", HeaderFilename(file->dependency(i)->name()));
1003   }
1004 
1005   output(
1006       "\n"
1007       "#include \"upb/port_def.inc\"\n"
1008       "\n");
1009 
1010   WriteMessages(file, output, fasttable_enabled);
1011 
1012   output("#include \"upb/port_undef.inc\"\n");
1013   output("\n");
1014 }
1015 
1016 class Generator : public protoc::CodeGenerator {
~Generator()1017   ~Generator() override {}
1018   bool Generate(const protobuf::FileDescriptor* file,
1019                 const std::string& parameter, protoc::GeneratorContext* context,
1020                 std::string* error) const override;
GetSupportedFeatures() const1021   uint64_t GetSupportedFeatures() const override {
1022     return FEATURE_PROTO3_OPTIONAL;
1023   }
1024 };
1025 
Generate(const protobuf::FileDescriptor * file,const std::string & parameter,protoc::GeneratorContext * context,std::string * error) const1026 bool Generator::Generate(const protobuf::FileDescriptor* file,
1027                          const std::string& parameter,
1028                          protoc::GeneratorContext* context,
1029                          std::string* error) const {
1030   bool fasttable_enabled = false;
1031   std::vector<std::pair<std::string, std::string>> params;
1032   google::protobuf::compiler::ParseGeneratorParameter(parameter, &params);
1033 
1034   for (const auto& pair : params) {
1035     if (pair.first == "fasttable") {
1036       fasttable_enabled = true;
1037     } else {
1038       *error = "Unknown parameter: " + pair.first;
1039       return false;
1040     }
1041   }
1042 
1043   Output h_output(context->Open(HeaderFilename(file->name())));
1044   WriteHeader(file, h_output);
1045 
1046   Output c_output(context->Open(SourceFilename(file->name())));
1047   WriteSource(file, c_output, fasttable_enabled);
1048 
1049   return true;
1050 }
1051 
1052 }  // namespace
1053 }  // namespace upbc
1054 
main(int argc,char ** argv)1055 int main(int argc, char** argv) {
1056   std::unique_ptr<google::protobuf::compiler::CodeGenerator> generator(
1057       new upbc::Generator());
1058   return google::protobuf::compiler::PluginMain(argc, argv, generator.get());
1059 }
1060