1 // Protocol Buffers - Google's data interchange format
2 // Copyright 2008 Google Inc.  All rights reserved.
3 // https://developers.google.com/protocol-buffers/
4 //
5 // Redistribution and use in source and binary forms, with or without
6 // modification, are permitted provided that the following conditions are
7 // met:
8 //
9 //     * Redistributions of source code must retain the above copyright
10 // notice, this list of conditions and the following disclaimer.
11 //     * Redistributions in binary form must reproduce the above
12 // copyright notice, this list of conditions and the following disclaimer
13 // in the documentation and/or other materials provided with the
14 // distribution.
15 //     * Neither the name of Google Inc. nor the names of its
16 // contributors may be used to endorse or promote products derived from
17 // this software without specific prior written permission.
18 //
19 // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20 // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21 // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22 // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
23 // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
24 // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
25 // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
26 // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
27 // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
28 // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 
31 // Author: kenton@google.com (Kenton Varda)
32 //  Based on original Protocol Buffers design by
33 //  Sanjay Ghemawat, Jeff Dean, and others.
34 
35 #include <google/protobuf/descriptor_database.h>
36 
37 #include <set>
38 
39 #include <google/protobuf/descriptor.pb.h>
40 #include <google/protobuf/stubs/map_util.h>
41 #include <google/protobuf/stubs/stl_util.h>
42 
43 
44 namespace google {
45 namespace protobuf {
46 
47 namespace {
RecordMessageNames(const DescriptorProto & desc_proto,const std::string & prefix,std::set<std::string> * output)48 void RecordMessageNames(const DescriptorProto& desc_proto,
49                         const std::string& prefix,
50                         std::set<std::string>* output) {
51   GOOGLE_CHECK(desc_proto.has_name());
52   std::string full_name = prefix.empty()
53                               ? desc_proto.name()
54                               : StrCat(prefix, ".", desc_proto.name());
55   output->insert(full_name);
56 
57   for (const auto& d : desc_proto.nested_type()) {
58     RecordMessageNames(d, full_name, output);
59   }
60 }
61 
RecordMessageNames(const FileDescriptorProto & file_proto,std::set<std::string> * output)62 void RecordMessageNames(const FileDescriptorProto& file_proto,
63                         std::set<std::string>* output) {
64   for (const auto& d : file_proto.message_type()) {
65     RecordMessageNames(d, file_proto.package(), output);
66   }
67 }
68 
69 template <typename Fn>
ForAllFileProtos(DescriptorDatabase * db,Fn callback,std::vector<std::string> * output)70 bool ForAllFileProtos(DescriptorDatabase* db, Fn callback,
71                       std::vector<std::string>* output) {
72   std::vector<std::string> file_names;
73   if (!db->FindAllFileNames(&file_names)) {
74     return false;
75   }
76   std::set<std::string> set;
77   FileDescriptorProto file_proto;
78   for (const auto& f : file_names) {
79     file_proto.Clear();
80     if (!db->FindFileByName(f, &file_proto)) {
81       GOOGLE_LOG(ERROR) << "File not found in database (unexpected): " << f;
82       return false;
83     }
84     callback(file_proto, &set);
85   }
86   output->insert(output->end(), set.begin(), set.end());
87   return true;
88 }
89 }  // namespace
90 
~DescriptorDatabase()91 DescriptorDatabase::~DescriptorDatabase() {}
92 
FindAllPackageNames(std::vector<std::string> * output)93 bool DescriptorDatabase::FindAllPackageNames(std::vector<std::string>* output) {
94   return ForAllFileProtos(
95       this,
96       [](const FileDescriptorProto& file_proto, std::set<std::string>* set) {
97         set->insert(file_proto.package());
98       },
99       output);
100 }
101 
FindAllMessageNames(std::vector<std::string> * output)102 bool DescriptorDatabase::FindAllMessageNames(std::vector<std::string>* output) {
103   return ForAllFileProtos(
104       this,
105       [](const FileDescriptorProto& file_proto, std::set<std::string>* set) {
106         RecordMessageNames(file_proto, set);
107       },
108       output);
109 }
110 
111 // ===================================================================
112 
SimpleDescriptorDatabase()113 SimpleDescriptorDatabase::SimpleDescriptorDatabase() {}
~SimpleDescriptorDatabase()114 SimpleDescriptorDatabase::~SimpleDescriptorDatabase() {}
115 
116 template <typename Value>
AddFile(const FileDescriptorProto & file,Value value)117 bool SimpleDescriptorDatabase::DescriptorIndex<Value>::AddFile(
118     const FileDescriptorProto& file, Value value) {
119   if (!InsertIfNotPresent(&by_name_, file.name(), value)) {
120     GOOGLE_LOG(ERROR) << "File already exists in database: " << file.name();
121     return false;
122   }
123 
124   // We must be careful here -- calling file.package() if file.has_package() is
125   // false could access an uninitialized static-storage variable if we are being
126   // run at startup time.
127   std::string path = file.has_package() ? file.package() : std::string();
128   if (!path.empty()) path += '.';
129 
130   for (int i = 0; i < file.message_type_size(); i++) {
131     if (!AddSymbol(path + file.message_type(i).name(), value)) return false;
132     if (!AddNestedExtensions(file.name(), file.message_type(i), value))
133       return false;
134   }
135   for (int i = 0; i < file.enum_type_size(); i++) {
136     if (!AddSymbol(path + file.enum_type(i).name(), value)) return false;
137   }
138   for (int i = 0; i < file.extension_size(); i++) {
139     if (!AddSymbol(path + file.extension(i).name(), value)) return false;
140     if (!AddExtension(file.name(), file.extension(i), value)) return false;
141   }
142   for (int i = 0; i < file.service_size(); i++) {
143     if (!AddSymbol(path + file.service(i).name(), value)) return false;
144   }
145 
146   return true;
147 }
148 
149 namespace {
150 
151 // Returns true if and only if all characters in the name are alphanumerics,
152 // underscores, or periods.
ValidateSymbolName(StringPiece name)153 bool ValidateSymbolName(StringPiece name) {
154   for (char c : name) {
155     // I don't trust ctype.h due to locales.  :(
156     if (c != '.' && c != '_' && (c < '0' || c > '9') && (c < 'A' || c > 'Z') &&
157         (c < 'a' || c > 'z')) {
158       return false;
159     }
160   }
161   return true;
162 }
163 
// Locates the last key in |container| that sorts less than or equal to |key|.
// upper_bound() yields the *first* key that sorts *greater* than |key|, so the
// wanted element is the one immediately before it.
//
// Caveat: if no key is <= |key| there is nothing to step back to and the
// result is container->begin() -- i.e. end() for an empty container, or the
// smallest (greater-than-|key|) element otherwise.
template <typename Container, typename Key>
typename Container::const_iterator FindLastLessOrEqual(
    const Container* container, const Key& key) {
  auto first_greater = container->upper_bound(key);
  if (first_greater == container->begin()) return first_greater;
  --first_greater;
  return first_greater;
}
174 
// Same contract as the overload without a comparator, but searches a sorted
// random-access/forward range with std::upper_bound and the ordering |cmp|.
// Returns container->begin() when no element is <= |key|.
template <typename Container, typename Key, typename Cmp>
typename Container::const_iterator FindLastLessOrEqual(
    const Container* container, const Key& key, const Cmp& cmp) {
  auto first_greater =
      std::upper_bound(container->begin(), container->end(), key, cmp);
  if (first_greater == container->begin()) return first_greater;
  --first_greater;
  return first_greater;
}
183 
184 // True if either the arguments are equal or super_symbol identifies a
185 // parent symbol of sub_symbol (e.g. "foo.bar" is a parent of
186 // "foo.bar.baz", but not a parent of "foo.barbaz").
IsSubSymbol(StringPiece sub_symbol,StringPiece super_symbol)187 bool IsSubSymbol(StringPiece sub_symbol, StringPiece super_symbol) {
188   return sub_symbol == super_symbol ||
189          (HasPrefixString(super_symbol, sub_symbol) &&
190           super_symbol[sub_symbol.size()] == '.');
191 }
192 
193 }  // namespace
194 
195 template <typename Value>
AddSymbol(const std::string & name,Value value)196 bool SimpleDescriptorDatabase::DescriptorIndex<Value>::AddSymbol(
197     const std::string& name, Value value) {
198   // We need to make sure not to violate our map invariant.
199 
200   // If the symbol name is invalid it could break our lookup algorithm (which
201   // relies on the fact that '.' sorts before all other characters that are
202   // valid in symbol names).
203   if (!ValidateSymbolName(name)) {
204     GOOGLE_LOG(ERROR) << "Invalid symbol name: " << name;
205     return false;
206   }
207 
208   // Try to look up the symbol to make sure a super-symbol doesn't already
209   // exist.
210   auto iter = FindLastLessOrEqual(&by_symbol_, name);
211 
212   if (iter == by_symbol_.end()) {
213     // Apparently the map is currently empty.  Just insert and be done with it.
214     by_symbol_.insert(
215         typename std::map<std::string, Value>::value_type(name, value));
216     return true;
217   }
218 
219   if (IsSubSymbol(iter->first, name)) {
220     GOOGLE_LOG(ERROR) << "Symbol name \"" << name
221                << "\" conflicts with the existing "
222                   "symbol \""
223                << iter->first << "\".";
224     return false;
225   }
226 
227   // OK, that worked.  Now we have to make sure that no symbol in the map is
228   // a sub-symbol of the one we are inserting.  The only symbol which could
229   // be so is the first symbol that is greater than the new symbol.  Since
230   // |iter| points at the last symbol that is less than or equal, we just have
231   // to increment it.
232   ++iter;
233 
234   if (iter != by_symbol_.end() && IsSubSymbol(name, iter->first)) {
235     GOOGLE_LOG(ERROR) << "Symbol name \"" << name
236                << "\" conflicts with the existing "
237                   "symbol \""
238                << iter->first << "\".";
239     return false;
240   }
241 
242   // OK, no conflicts.
243 
244   // Insert the new symbol using the iterator as a hint, the new entry will
245   // appear immediately before the one the iterator is pointing at.
246   by_symbol_.insert(
247       iter, typename std::map<std::string, Value>::value_type(name, value));
248 
249   return true;
250 }
251 
252 template <typename Value>
AddNestedExtensions(const std::string & filename,const DescriptorProto & message_type,Value value)253 bool SimpleDescriptorDatabase::DescriptorIndex<Value>::AddNestedExtensions(
254     const std::string& filename, const DescriptorProto& message_type,
255     Value value) {
256   for (int i = 0; i < message_type.nested_type_size(); i++) {
257     if (!AddNestedExtensions(filename, message_type.nested_type(i), value))
258       return false;
259   }
260   for (int i = 0; i < message_type.extension_size(); i++) {
261     if (!AddExtension(filename, message_type.extension(i), value)) return false;
262   }
263   return true;
264 }
265 
266 template <typename Value>
AddExtension(const std::string & filename,const FieldDescriptorProto & field,Value value)267 bool SimpleDescriptorDatabase::DescriptorIndex<Value>::AddExtension(
268     const std::string& filename, const FieldDescriptorProto& field,
269     Value value) {
270   if (!field.extendee().empty() && field.extendee()[0] == '.') {
271     // The extension is fully-qualified.  We can use it as a lookup key in
272     // the by_symbol_ table.
273     if (!InsertIfNotPresent(
274             &by_extension_,
275             std::make_pair(field.extendee().substr(1), field.number()),
276             value)) {
277       GOOGLE_LOG(ERROR) << "Extension conflicts with extension already in database: "
278                     "extend "
279                  << field.extendee() << " { " << field.name() << " = "
280                  << field.number() << " } from:" << filename;
281       return false;
282     }
283   } else {
284     // Not fully-qualified.  We can't really do anything here, unfortunately.
285     // We don't consider this an error, though, because the descriptor is
286     // valid.
287   }
288   return true;
289 }
290 
291 template <typename Value>
FindFile(const std::string & filename)292 Value SimpleDescriptorDatabase::DescriptorIndex<Value>::FindFile(
293     const std::string& filename) {
294   return FindWithDefault(by_name_, filename, Value());
295 }
296 
297 template <typename Value>
FindSymbol(const std::string & name)298 Value SimpleDescriptorDatabase::DescriptorIndex<Value>::FindSymbol(
299     const std::string& name) {
300   auto iter = FindLastLessOrEqual(&by_symbol_, name);
301 
302   return (iter != by_symbol_.end() && IsSubSymbol(iter->first, name))
303              ? iter->second
304              : Value();
305 }
306 
307 template <typename Value>
FindExtension(const std::string & containing_type,int field_number)308 Value SimpleDescriptorDatabase::DescriptorIndex<Value>::FindExtension(
309     const std::string& containing_type, int field_number) {
310   return FindWithDefault(
311       by_extension_, std::make_pair(containing_type, field_number), Value());
312 }
313 
314 template <typename Value>
FindAllExtensionNumbers(const std::string & containing_type,std::vector<int> * output)315 bool SimpleDescriptorDatabase::DescriptorIndex<Value>::FindAllExtensionNumbers(
316     const std::string& containing_type, std::vector<int>* output) {
317   typename std::map<std::pair<std::string, int>, Value>::const_iterator it =
318       by_extension_.lower_bound(std::make_pair(containing_type, 0));
319   bool success = false;
320 
321   for (; it != by_extension_.end() && it->first.first == containing_type;
322        ++it) {
323     output->push_back(it->first.second);
324     success = true;
325   }
326 
327   return success;
328 }
329 
330 template <typename Value>
FindAllFileNames(std::vector<std::string> * output)331 void SimpleDescriptorDatabase::DescriptorIndex<Value>::FindAllFileNames(
332     std::vector<std::string>* output) {
333   output->resize(by_name_.size());
334   int i = 0;
335   for (const auto& kv : by_name_) {
336     (*output)[i] = kv.first;
337     i++;
338   }
339 }
340 
341 // -------------------------------------------------------------------
342 
Add(const FileDescriptorProto & file)343 bool SimpleDescriptorDatabase::Add(const FileDescriptorProto& file) {
344   FileDescriptorProto* new_file = new FileDescriptorProto;
345   new_file->CopyFrom(file);
346   return AddAndOwn(new_file);
347 }
348 
AddAndOwn(const FileDescriptorProto * file)349 bool SimpleDescriptorDatabase::AddAndOwn(const FileDescriptorProto* file) {
350   files_to_delete_.emplace_back(file);
351   return index_.AddFile(*file, file);
352 }
353 
FindFileByName(const std::string & filename,FileDescriptorProto * output)354 bool SimpleDescriptorDatabase::FindFileByName(const std::string& filename,
355                                               FileDescriptorProto* output) {
356   return MaybeCopy(index_.FindFile(filename), output);
357 }
358 
FindFileContainingSymbol(const std::string & symbol_name,FileDescriptorProto * output)359 bool SimpleDescriptorDatabase::FindFileContainingSymbol(
360     const std::string& symbol_name, FileDescriptorProto* output) {
361   return MaybeCopy(index_.FindSymbol(symbol_name), output);
362 }
363 
FindFileContainingExtension(const std::string & containing_type,int field_number,FileDescriptorProto * output)364 bool SimpleDescriptorDatabase::FindFileContainingExtension(
365     const std::string& containing_type, int field_number,
366     FileDescriptorProto* output) {
367   return MaybeCopy(index_.FindExtension(containing_type, field_number), output);
368 }
369 
FindAllExtensionNumbers(const std::string & extendee_type,std::vector<int> * output)370 bool SimpleDescriptorDatabase::FindAllExtensionNumbers(
371     const std::string& extendee_type, std::vector<int>* output) {
372   return index_.FindAllExtensionNumbers(extendee_type, output);
373 }
374 
375 
FindAllFileNames(std::vector<std::string> * output)376 bool SimpleDescriptorDatabase::FindAllFileNames(
377     std::vector<std::string>* output) {
378   index_.FindAllFileNames(output);
379   return true;
380 }
381 
MaybeCopy(const FileDescriptorProto * file,FileDescriptorProto * output)382 bool SimpleDescriptorDatabase::MaybeCopy(const FileDescriptorProto* file,
383                                          FileDescriptorProto* output) {
384   if (file == NULL) return false;
385   output->CopyFrom(*file);
386   return true;
387 }
388 
389 // -------------------------------------------------------------------
390 
391 class EncodedDescriptorDatabase::DescriptorIndex {
392  public:
393   using Value = std::pair<const void*, int>;
394   // Helpers to recursively add particular descriptors and all their contents
395   // to the index.
396   template <typename FileProto>
397   bool AddFile(const FileProto& file, Value value);
398 
399   Value FindFile(StringPiece filename);
400   Value FindSymbol(StringPiece name);
401   Value FindSymbolOnlyFlat(StringPiece name) const;
402   Value FindExtension(StringPiece containing_type, int field_number);
403   bool FindAllExtensionNumbers(StringPiece containing_type,
404                                std::vector<int>* output);
405   void FindAllFileNames(std::vector<std::string>* output) const;
406 
407  private:
408   friend class EncodedDescriptorDatabase;
409 
410   bool AddSymbol(StringPiece symbol);
411 
412   template <typename DescProto>
413   bool AddNestedExtensions(StringPiece filename,
414                            const DescProto& message_type);
415   template <typename FieldProto>
416   bool AddExtension(StringPiece filename, const FieldProto& field);
417 
418   // All the maps below have two representations:
419   //  - a std::set<> where we insert initially.
420   //  - a std::vector<> where we flatten the structure on demand.
421   // The initial tree helps avoid O(N) behavior of inserting into a sorted
422   // vector, while the vector reduces the heap requirements of the data
423   // structure.
424 
425   void EnsureFlat();
426 
427   using String = std::string;
428 
EncodeString(StringPiece str) const429   String EncodeString(StringPiece str) const { return String(str); }
DecodeString(const String & str,int) const430   StringPiece DecodeString(const String& str, int) const { return str; }
431 
432   struct EncodedEntry {
433     // Do not use `Value` here to avoid the padding of that object.
434     const void* data;
435     int size;
436     // Keep the package here instead of each SymbolEntry to save space.
437     String encoded_package;
438 
valuegoogle::protobuf::EncodedDescriptorDatabase::DescriptorIndex::EncodedEntry439     Value value() const { return {data, size}; }
440   };
441   std::vector<EncodedEntry> all_values_;
442 
443   struct FileEntry {
444     int data_offset;
445     String encoded_name;
446 
namegoogle::protobuf::EncodedDescriptorDatabase::DescriptorIndex::FileEntry447     StringPiece name(const DescriptorIndex& index) const {
448       return index.DecodeString(encoded_name, data_offset);
449     }
450   };
451   struct FileCompare {
452     const DescriptorIndex& index;
453 
operator ()google::protobuf::EncodedDescriptorDatabase::DescriptorIndex::FileCompare454     bool operator()(const FileEntry& a, const FileEntry& b) const {
455       return a.name(index) < b.name(index);
456     }
operator ()google::protobuf::EncodedDescriptorDatabase::DescriptorIndex::FileCompare457     bool operator()(const FileEntry& a, StringPiece b) const {
458       return a.name(index) < b;
459     }
operator ()google::protobuf::EncodedDescriptorDatabase::DescriptorIndex::FileCompare460     bool operator()(StringPiece a, const FileEntry& b) const {
461       return a < b.name(index);
462     }
463   };
464   std::set<FileEntry, FileCompare> by_name_{FileCompare{*this}};
465   std::vector<FileEntry> by_name_flat_;
466 
467   struct SymbolEntry {
468     int data_offset;
469     String encoded_symbol;
470 
packagegoogle::protobuf::EncodedDescriptorDatabase::DescriptorIndex::SymbolEntry471     StringPiece package(const DescriptorIndex& index) const {
472       return index.DecodeString(index.all_values_[data_offset].encoded_package,
473                                 data_offset);
474     }
symbolgoogle::protobuf::EncodedDescriptorDatabase::DescriptorIndex::SymbolEntry475     StringPiece symbol(const DescriptorIndex& index) const {
476       return index.DecodeString(encoded_symbol, data_offset);
477     }
478 
AsStringgoogle::protobuf::EncodedDescriptorDatabase::DescriptorIndex::SymbolEntry479     std::string AsString(const DescriptorIndex& index) const {
480       auto p = package(index);
481       return StrCat(p, p.empty() ? "" : ".", symbol(index));
482     }
483   };
484 
485   struct SymbolCompare {
486     const DescriptorIndex& index;
487 
AsStringgoogle::protobuf::EncodedDescriptorDatabase::DescriptorIndex::SymbolCompare488     std::string AsString(const SymbolEntry& entry) const {
489       return entry.AsString(index);
490     }
AsStringgoogle::protobuf::EncodedDescriptorDatabase::DescriptorIndex::SymbolCompare491     static StringPiece AsString(StringPiece str) { return str; }
492 
GetPartsgoogle::protobuf::EncodedDescriptorDatabase::DescriptorIndex::SymbolCompare493     std::pair<StringPiece, StringPiece> GetParts(
494         const SymbolEntry& entry) const {
495       auto package = entry.package(index);
496       if (package.empty()) return {entry.symbol(index), StringPiece{}};
497       return {package, entry.symbol(index)};
498     }
GetPartsgoogle::protobuf::EncodedDescriptorDatabase::DescriptorIndex::SymbolCompare499     std::pair<StringPiece, StringPiece> GetParts(
500         StringPiece str) const {
501       return {str, {}};
502     }
503 
504     template <typename T, typename U>
operator ()google::protobuf::EncodedDescriptorDatabase::DescriptorIndex::SymbolCompare505     bool operator()(const T& lhs, const U& rhs) const {
506       auto lhs_parts = GetParts(lhs);
507       auto rhs_parts = GetParts(rhs);
508 
509       // Fast path to avoid making the whole string for common cases.
510       if (int res =
511               lhs_parts.first.substr(0, rhs_parts.first.size())
512                   .compare(rhs_parts.first.substr(0, lhs_parts.first.size()))) {
513         // If the packages already differ, exit early.
514         return res < 0;
515       } else if (lhs_parts.first.size() == rhs_parts.first.size()) {
516         return lhs_parts.second < rhs_parts.second;
517       }
518       return AsString(lhs) < AsString(rhs);
519     }
520   };
521   std::set<SymbolEntry, SymbolCompare> by_symbol_{SymbolCompare{*this}};
522   std::vector<SymbolEntry> by_symbol_flat_;
523 
524   struct ExtensionEntry {
525     int data_offset;
526     String encoded_extendee;
extendeegoogle::protobuf::EncodedDescriptorDatabase::DescriptorIndex::ExtensionEntry527     StringPiece extendee(const DescriptorIndex& index) const {
528       return index.DecodeString(encoded_extendee, data_offset).substr(1);
529     }
530     int extension_number;
531   };
532   struct ExtensionCompare {
533     const DescriptorIndex& index;
534 
operator ()google::protobuf::EncodedDescriptorDatabase::DescriptorIndex::ExtensionCompare535     bool operator()(const ExtensionEntry& a, const ExtensionEntry& b) const {
536       return std::make_tuple(a.extendee(index), a.extension_number) <
537              std::make_tuple(b.extendee(index), b.extension_number);
538     }
operator ()google::protobuf::EncodedDescriptorDatabase::DescriptorIndex::ExtensionCompare539     bool operator()(const ExtensionEntry& a,
540                     std::tuple<StringPiece, int> b) const {
541       return std::make_tuple(a.extendee(index), a.extension_number) < b;
542     }
operator ()google::protobuf::EncodedDescriptorDatabase::DescriptorIndex::ExtensionCompare543     bool operator()(std::tuple<StringPiece, int> a,
544                     const ExtensionEntry& b) const {
545       return a < std::make_tuple(b.extendee(index), b.extension_number);
546     }
547   };
548   std::set<ExtensionEntry, ExtensionCompare> by_extension_{
549       ExtensionCompare{*this}};
550   std::vector<ExtensionEntry> by_extension_flat_;
551 };
552 
Add(const void * encoded_file_descriptor,int size)553 bool EncodedDescriptorDatabase::Add(const void* encoded_file_descriptor,
554                                     int size) {
555   FileDescriptorProto file;
556   if (file.ParseFromArray(encoded_file_descriptor, size)) {
557     return index_->AddFile(file, std::make_pair(encoded_file_descriptor, size));
558   } else {
559     GOOGLE_LOG(ERROR) << "Invalid file descriptor data passed to "
560                   "EncodedDescriptorDatabase::Add().";
561     return false;
562   }
563 }
564 
AddCopy(const void * encoded_file_descriptor,int size)565 bool EncodedDescriptorDatabase::AddCopy(const void* encoded_file_descriptor,
566                                         int size) {
567   void* copy = operator new(size);
568   memcpy(copy, encoded_file_descriptor, size);
569   files_to_delete_.push_back(copy);
570   return Add(copy, size);
571 }
572 
FindFileByName(const std::string & filename,FileDescriptorProto * output)573 bool EncodedDescriptorDatabase::FindFileByName(const std::string& filename,
574                                                FileDescriptorProto* output) {
575   return MaybeParse(index_->FindFile(filename), output);
576 }
577 
FindFileContainingSymbol(const std::string & symbol_name,FileDescriptorProto * output)578 bool EncodedDescriptorDatabase::FindFileContainingSymbol(
579     const std::string& symbol_name, FileDescriptorProto* output) {
580   return MaybeParse(index_->FindSymbol(symbol_name), output);
581 }
582 
FindNameOfFileContainingSymbol(const std::string & symbol_name,std::string * output)583 bool EncodedDescriptorDatabase::FindNameOfFileContainingSymbol(
584     const std::string& symbol_name, std::string* output) {
585   auto encoded_file = index_->FindSymbol(symbol_name);
586   if (encoded_file.first == NULL) return false;
587 
588   // Optimization:  The name should be the first field in the encoded message.
589   //   Try to just read it directly.
590   io::CodedInputStream input(static_cast<const uint8*>(encoded_file.first),
591                              encoded_file.second);
592 
593   const uint32 kNameTag = internal::WireFormatLite::MakeTag(
594       FileDescriptorProto::kNameFieldNumber,
595       internal::WireFormatLite::WIRETYPE_LENGTH_DELIMITED);
596 
597   if (input.ReadTagNoLastTag() == kNameTag) {
598     // Success!
599     return internal::WireFormatLite::ReadString(&input, output);
600   } else {
601     // Slow path.  Parse whole message.
602     FileDescriptorProto file_proto;
603     if (!file_proto.ParseFromArray(encoded_file.first, encoded_file.second)) {
604       return false;
605     }
606     *output = file_proto.name();
607     return true;
608   }
609 }
610 
FindFileContainingExtension(const std::string & containing_type,int field_number,FileDescriptorProto * output)611 bool EncodedDescriptorDatabase::FindFileContainingExtension(
612     const std::string& containing_type, int field_number,
613     FileDescriptorProto* output) {
614   return MaybeParse(index_->FindExtension(containing_type, field_number),
615                     output);
616 }
617 
FindAllExtensionNumbers(const std::string & extendee_type,std::vector<int> * output)618 bool EncodedDescriptorDatabase::FindAllExtensionNumbers(
619     const std::string& extendee_type, std::vector<int>* output) {
620   return index_->FindAllExtensionNumbers(extendee_type, output);
621 }
622 
623 template <typename FileProto>
AddFile(const FileProto & file,Value value)624 bool EncodedDescriptorDatabase::DescriptorIndex::AddFile(const FileProto& file,
625                                                          Value value) {
626   // We push `value` into the array first. This is important because the AddXXX
627   // functions below will expect it to be there.
628   all_values_.push_back({value.first, value.second, {}});
629 
630   if (!ValidateSymbolName(file.package())) {
631     GOOGLE_LOG(ERROR) << "Invalid package name: " << file.package();
632     return false;
633   }
634   all_values_.back().encoded_package = EncodeString(file.package());
635 
636   if (!InsertIfNotPresent(
637           &by_name_, FileEntry{static_cast<int>(all_values_.size() - 1),
638                                EncodeString(file.name())}) ||
639       std::binary_search(by_name_flat_.begin(), by_name_flat_.end(),
640                          file.name(), by_name_.key_comp())) {
641     GOOGLE_LOG(ERROR) << "File already exists in database: " << file.name();
642     return false;
643   }
644 
645   for (const auto& message_type : file.message_type()) {
646     if (!AddSymbol(message_type.name())) return false;
647     if (!AddNestedExtensions(file.name(), message_type)) return false;
648   }
649   for (const auto& enum_type : file.enum_type()) {
650     if (!AddSymbol(enum_type.name())) return false;
651   }
652   for (const auto& extension : file.extension()) {
653     if (!AddSymbol(extension.name())) return false;
654     if (!AddExtension(file.name(), extension)) return false;
655   }
656   for (const auto& service : file.service()) {
657     if (!AddSymbol(service.name())) return false;
658   }
659 
660   return true;
661 }
662 
663 template <typename Iter, typename Iter2, typename Index>
CheckForMutualSubsymbols(StringPiece symbol_name,Iter * iter,Iter2 end,const Index & index)664 static bool CheckForMutualSubsymbols(StringPiece symbol_name, Iter* iter,
665                                      Iter2 end, const Index& index) {
666   if (*iter != end) {
667     if (IsSubSymbol((*iter)->AsString(index), symbol_name)) {
668       GOOGLE_LOG(ERROR) << "Symbol name \"" << symbol_name
669                  << "\" conflicts with the existing symbol \""
670                  << (*iter)->AsString(index) << "\".";
671       return false;
672     }
673 
674     // OK, that worked.  Now we have to make sure that no symbol in the map is
675     // a sub-symbol of the one we are inserting.  The only symbol which could
676     // be so is the first symbol that is greater than the new symbol.  Since
677     // |iter| points at the last symbol that is less than or equal, we just have
678     // to increment it.
679     ++*iter;
680 
681     if (*iter != end && IsSubSymbol(symbol_name, (*iter)->AsString(index))) {
682       GOOGLE_LOG(ERROR) << "Symbol name \"" << symbol_name
683                  << "\" conflicts with the existing symbol \""
684                  << (*iter)->AsString(index) << "\".";
685       return false;
686     }
687   }
688   return true;
689 }
690 
AddSymbol(StringPiece symbol)691 bool EncodedDescriptorDatabase::DescriptorIndex::AddSymbol(
692     StringPiece symbol) {
693   SymbolEntry entry = {static_cast<int>(all_values_.size() - 1),
694                        EncodeString(symbol)};
695   std::string entry_as_string = entry.AsString(*this);
696 
697   // We need to make sure not to violate our map invariant.
698 
699   // If the symbol name is invalid it could break our lookup algorithm (which
700   // relies on the fact that '.' sorts before all other characters that are
701   // valid in symbol names).
702   if (!ValidateSymbolName(symbol)) {
703     GOOGLE_LOG(ERROR) << "Invalid symbol name: " << entry_as_string;
704     return false;
705   }
706 
707   auto iter = FindLastLessOrEqual(&by_symbol_, entry);
708   if (!CheckForMutualSubsymbols(entry_as_string, &iter, by_symbol_.end(),
709                                 *this)) {
710     return false;
711   }
712 
713   // Same, but on by_symbol_flat_
714   auto flat_iter =
715       FindLastLessOrEqual(&by_symbol_flat_, entry, by_symbol_.key_comp());
716   if (!CheckForMutualSubsymbols(entry_as_string, &flat_iter,
717                                 by_symbol_flat_.end(), *this)) {
718     return false;
719   }
720 
721   // OK, no conflicts.
722 
723   // Insert the new symbol using the iterator as a hint, the new entry will
724   // appear immediately before the one the iterator is pointing at.
725   by_symbol_.insert(iter, entry);
726 
727   return true;
728 }
729 
730 template <typename DescProto>
AddNestedExtensions(StringPiece filename,const DescProto & message_type)731 bool EncodedDescriptorDatabase::DescriptorIndex::AddNestedExtensions(
732     StringPiece filename, const DescProto& message_type) {
733   for (const auto& nested_type : message_type.nested_type()) {
734     if (!AddNestedExtensions(filename, nested_type)) return false;
735   }
736   for (const auto& extension : message_type.extension()) {
737     if (!AddExtension(filename, extension)) return false;
738   }
739   return true;
740 }
741 
742 template <typename FieldProto>
AddExtension(StringPiece filename,const FieldProto & field)743 bool EncodedDescriptorDatabase::DescriptorIndex::AddExtension(
744     StringPiece filename, const FieldProto& field) {
745   if (!field.extendee().empty() && field.extendee()[0] == '.') {
746     // The extension is fully-qualified.  We can use it as a lookup key in
747     // the by_symbol_ table.
748     if (!InsertIfNotPresent(
749             &by_extension_,
750             ExtensionEntry{static_cast<int>(all_values_.size() - 1),
751                            EncodeString(field.extendee()), field.number()}) ||
752         std::binary_search(
753             by_extension_flat_.begin(), by_extension_flat_.end(),
754             std::make_pair(field.extendee().substr(1), field.number()),
755             by_extension_.key_comp())) {
756       GOOGLE_LOG(ERROR) << "Extension conflicts with extension already in database: "
757                     "extend "
758                  << field.extendee() << " { " << field.name() << " = "
759                  << field.number() << " } from:" << filename;
760       return false;
761     }
762   } else {
763     // Not fully-qualified.  We can't really do anything here, unfortunately.
764     // We don't consider this an error, though, because the descriptor is
765     // valid.
766   }
767   return true;
768 }
769 
770 std::pair<const void*, int>
FindSymbol(StringPiece name)771 EncodedDescriptorDatabase::DescriptorIndex::FindSymbol(StringPiece name) {
772   EnsureFlat();
773   return FindSymbolOnlyFlat(name);
774 }
775 
776 std::pair<const void*, int>
FindSymbolOnlyFlat(StringPiece name) const777 EncodedDescriptorDatabase::DescriptorIndex::FindSymbolOnlyFlat(
778     StringPiece name) const {
779   auto iter =
780       FindLastLessOrEqual(&by_symbol_flat_, name, by_symbol_.key_comp());
781 
782   return iter != by_symbol_flat_.end() &&
783                  IsSubSymbol(iter->AsString(*this), name)
784              ? all_values_[iter->data_offset].value()
785              : Value();
786 }
787 
788 std::pair<const void*, int>
FindExtension(StringPiece containing_type,int field_number)789 EncodedDescriptorDatabase::DescriptorIndex::FindExtension(
790     StringPiece containing_type, int field_number) {
791   EnsureFlat();
792 
793   auto it = std::lower_bound(
794       by_extension_flat_.begin(), by_extension_flat_.end(),
795       std::make_tuple(containing_type, field_number), by_extension_.key_comp());
796   return it == by_extension_flat_.end() ||
797                  it->extendee(*this) != containing_type ||
798                  it->extension_number != field_number
799              ? std::make_pair(nullptr, 0)
800              : all_values_[it->data_offset].value();
801 }
802 
// Moves the contents of the sorted set |s| into the sorted vector |flat|.
//
// Both ranges are merged using the set's comparator; afterwards |flat| holds
// every element of both inputs in sorted order and |s| is empty.  When |s| is
// empty on entry, |flat| is left untouched.
template <typename T, typename Less>
static void MergeIntoFlat(std::set<T, Less>* s, std::vector<T>* flat) {
  if (s->empty()) {
    return;
  }

  // Pre-size the destination so std::merge can write straight into it.
  const typename std::vector<T>::size_type total = s->size() + flat->size();
  std::vector<T> combined(total);
  std::merge(s->begin(), s->end(), flat->begin(), flat->end(),
             combined.begin(), s->key_comp());

  flat->swap(combined);
  s->clear();
}
812 
EnsureFlat()813 void EncodedDescriptorDatabase::DescriptorIndex::EnsureFlat() {
814   all_values_.shrink_to_fit();
815   // Merge each of the sets into their flat counterpart.
816   MergeIntoFlat(&by_name_, &by_name_flat_);
817   MergeIntoFlat(&by_symbol_, &by_symbol_flat_);
818   MergeIntoFlat(&by_extension_, &by_extension_flat_);
819 }
820 
FindAllExtensionNumbers(StringPiece containing_type,std::vector<int> * output)821 bool EncodedDescriptorDatabase::DescriptorIndex::FindAllExtensionNumbers(
822     StringPiece containing_type, std::vector<int>* output) {
823   EnsureFlat();
824 
825   bool success = false;
826   auto it = std::lower_bound(
827       by_extension_flat_.begin(), by_extension_flat_.end(),
828       std::make_tuple(containing_type, 0), by_extension_.key_comp());
829   for (;
830        it != by_extension_flat_.end() && it->extendee(*this) == containing_type;
831        ++it) {
832     output->push_back(it->extension_number);
833     success = true;
834   }
835 
836   return success;
837 }
838 
FindAllFileNames(std::vector<std::string> * output) const839 void EncodedDescriptorDatabase::DescriptorIndex::FindAllFileNames(
840     std::vector<std::string>* output) const {
841   output->resize(by_name_.size() + by_name_flat_.size());
842   int i = 0;
843   for (const auto& entry : by_name_) {
844     (*output)[i] = std::string(entry.name(*this));
845     i++;
846   }
847   for (const auto& entry : by_name_flat_) {
848     (*output)[i] = std::string(entry.name(*this));
849     i++;
850   }
851 }
852 
853 std::pair<const void*, int>
FindFile(StringPiece filename)854 EncodedDescriptorDatabase::DescriptorIndex::FindFile(
855     StringPiece filename) {
856   EnsureFlat();
857 
858   auto it = std::lower_bound(by_name_flat_.begin(), by_name_flat_.end(),
859                              filename, by_name_.key_comp());
860   return it == by_name_flat_.end() || it->name(*this) != filename
861              ? std::make_pair(nullptr, 0)
862              : all_values_[it->data_offset].value();
863 }
864 
865 
FindAllFileNames(std::vector<std::string> * output)866 bool EncodedDescriptorDatabase::FindAllFileNames(
867     std::vector<std::string>* output) {
868   index_->FindAllFileNames(output);
869   return true;
870 }
871 
MaybeParse(std::pair<const void *,int> encoded_file,FileDescriptorProto * output)872 bool EncodedDescriptorDatabase::MaybeParse(
873     std::pair<const void*, int> encoded_file, FileDescriptorProto* output) {
874   if (encoded_file.first == NULL) return false;
875   return output->ParseFromArray(encoded_file.first, encoded_file.second);
876 }
877 
EncodedDescriptorDatabase()878 EncodedDescriptorDatabase::EncodedDescriptorDatabase()
879     : index_(new DescriptorIndex()) {}
880 
~EncodedDescriptorDatabase()881 EncodedDescriptorDatabase::~EncodedDescriptorDatabase() {
882   for (void* p : files_to_delete_) {
883     operator delete(p);
884   }
885 }
886 
887 // ===================================================================
888 
DescriptorPoolDatabase(const DescriptorPool & pool)889 DescriptorPoolDatabase::DescriptorPoolDatabase(const DescriptorPool& pool)
890     : pool_(pool) {}
~DescriptorPoolDatabase()891 DescriptorPoolDatabase::~DescriptorPoolDatabase() {}
892 
FindFileByName(const std::string & filename,FileDescriptorProto * output)893 bool DescriptorPoolDatabase::FindFileByName(const std::string& filename,
894                                             FileDescriptorProto* output) {
895   const FileDescriptor* file = pool_.FindFileByName(filename);
896   if (file == NULL) return false;
897   output->Clear();
898   file->CopyTo(output);
899   return true;
900 }
901 
FindFileContainingSymbol(const std::string & symbol_name,FileDescriptorProto * output)902 bool DescriptorPoolDatabase::FindFileContainingSymbol(
903     const std::string& symbol_name, FileDescriptorProto* output) {
904   const FileDescriptor* file = pool_.FindFileContainingSymbol(symbol_name);
905   if (file == NULL) return false;
906   output->Clear();
907   file->CopyTo(output);
908   return true;
909 }
910 
FindFileContainingExtension(const std::string & containing_type,int field_number,FileDescriptorProto * output)911 bool DescriptorPoolDatabase::FindFileContainingExtension(
912     const std::string& containing_type, int field_number,
913     FileDescriptorProto* output) {
914   const Descriptor* extendee = pool_.FindMessageTypeByName(containing_type);
915   if (extendee == NULL) return false;
916 
917   const FieldDescriptor* extension =
918       pool_.FindExtensionByNumber(extendee, field_number);
919   if (extension == NULL) return false;
920 
921   output->Clear();
922   extension->file()->CopyTo(output);
923   return true;
924 }
925 
FindAllExtensionNumbers(const std::string & extendee_type,std::vector<int> * output)926 bool DescriptorPoolDatabase::FindAllExtensionNumbers(
927     const std::string& extendee_type, std::vector<int>* output) {
928   const Descriptor* extendee = pool_.FindMessageTypeByName(extendee_type);
929   if (extendee == NULL) return false;
930 
931   std::vector<const FieldDescriptor*> extensions;
932   pool_.FindAllExtensions(extendee, &extensions);
933 
934   for (const FieldDescriptor* extension : extensions) {
935     output->push_back(extension->number());
936   }
937 
938   return true;
939 }
940 
941 // ===================================================================
942 
MergedDescriptorDatabase(DescriptorDatabase * source1,DescriptorDatabase * source2)943 MergedDescriptorDatabase::MergedDescriptorDatabase(
944     DescriptorDatabase* source1, DescriptorDatabase* source2) {
945   sources_.push_back(source1);
946   sources_.push_back(source2);
947 }
MergedDescriptorDatabase(const std::vector<DescriptorDatabase * > & sources)948 MergedDescriptorDatabase::MergedDescriptorDatabase(
949     const std::vector<DescriptorDatabase*>& sources)
950     : sources_(sources) {}
~MergedDescriptorDatabase()951 MergedDescriptorDatabase::~MergedDescriptorDatabase() {}
952 
FindFileByName(const std::string & filename,FileDescriptorProto * output)953 bool MergedDescriptorDatabase::FindFileByName(const std::string& filename,
954                                               FileDescriptorProto* output) {
955   for (DescriptorDatabase* source : sources_) {
956     if (source->FindFileByName(filename, output)) {
957       return true;
958     }
959   }
960   return false;
961 }
962 
FindFileContainingSymbol(const std::string & symbol_name,FileDescriptorProto * output)963 bool MergedDescriptorDatabase::FindFileContainingSymbol(
964     const std::string& symbol_name, FileDescriptorProto* output) {
965   for (size_t i = 0; i < sources_.size(); i++) {
966     if (sources_[i]->FindFileContainingSymbol(symbol_name, output)) {
967       // The symbol was found in source i.  However, if one of the previous
968       // sources defines a file with the same name (which presumably doesn't
969       // contain the symbol, since it wasn't found in that source), then we
970       // must hide it from the caller.
971       FileDescriptorProto temp;
972       for (size_t j = 0; j < i; j++) {
973         if (sources_[j]->FindFileByName(output->name(), &temp)) {
974           // Found conflicting file in a previous source.
975           return false;
976         }
977       }
978       return true;
979     }
980   }
981   return false;
982 }
983 
FindFileContainingExtension(const std::string & containing_type,int field_number,FileDescriptorProto * output)984 bool MergedDescriptorDatabase::FindFileContainingExtension(
985     const std::string& containing_type, int field_number,
986     FileDescriptorProto* output) {
987   for (size_t i = 0; i < sources_.size(); i++) {
988     if (sources_[i]->FindFileContainingExtension(containing_type, field_number,
989                                                  output)) {
990       // The symbol was found in source i.  However, if one of the previous
991       // sources defines a file with the same name (which presumably doesn't
992       // contain the symbol, since it wasn't found in that source), then we
993       // must hide it from the caller.
994       FileDescriptorProto temp;
995       for (size_t j = 0; j < i; j++) {
996         if (sources_[j]->FindFileByName(output->name(), &temp)) {
997           // Found conflicting file in a previous source.
998           return false;
999         }
1000       }
1001       return true;
1002     }
1003   }
1004   return false;
1005 }
1006 
FindAllExtensionNumbers(const std::string & extendee_type,std::vector<int> * output)1007 bool MergedDescriptorDatabase::FindAllExtensionNumbers(
1008     const std::string& extendee_type, std::vector<int>* output) {
1009   std::set<int> merged_results;
1010   std::vector<int> results;
1011   bool success = false;
1012 
1013   for (DescriptorDatabase* source : sources_) {
1014     if (source->FindAllExtensionNumbers(extendee_type, &results)) {
1015       std::copy(results.begin(), results.end(),
1016                 std::insert_iterator<std::set<int> >(merged_results,
1017                                                      merged_results.begin()));
1018       success = true;
1019     }
1020     results.clear();
1021   }
1022 
1023   std::copy(merged_results.begin(), merged_results.end(),
1024             std::insert_iterator<std::vector<int> >(*output, output->end()));
1025 
1026   return success;
1027 }
1028 
1029 
1030 }  // namespace protobuf
1031 }  // namespace google
1032