1 // Protocol Buffers - Google's data interchange format
2 // Copyright 2008 Google Inc.  All rights reserved.
3 // https://developers.google.com/protocol-buffers/
4 //
5 // Redistribution and use in source and binary forms, with or without
6 // modification, are permitted provided that the following conditions are
7 // met:
8 //
9 //     * Redistributions of source code must retain the above copyright
10 // notice, this list of conditions and the following disclaimer.
11 //     * Redistributions in binary form must reproduce the above
12 // copyright notice, this list of conditions and the following disclaimer
13 // in the documentation and/or other materials provided with the
14 // distribution.
15 //     * Neither the name of Google Inc. nor the names of its
16 // contributors may be used to endorse or promote products derived from
17 // this software without specific prior written permission.
18 //
19 // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20 // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21 // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22 // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
23 // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
24 // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
25 // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
26 // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
27 // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
28 // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 
31 // Author: kenton@google.com (Kenton Varda)
32 //  Based on original Protocol Buffers design by
33 //  Sanjay Ghemawat, Jeff Dean, and others.
34 
35 #include <google/protobuf/wire_format.h>
36 
37 #include <stack>
38 #include <string>
39 #include <vector>
40 
41 #include <google/protobuf/stubs/logging.h>
42 #include <google/protobuf/stubs/common.h>
43 #include <google/protobuf/stubs/stringprintf.h>
44 #include <google/protobuf/descriptor.pb.h>
45 #include <google/protobuf/io/coded_stream.h>
46 #include <google/protobuf/io/zero_copy_stream.h>
47 #include <google/protobuf/io/zero_copy_stream_impl.h>
48 #include <google/protobuf/descriptor.h>
49 #include <google/protobuf/dynamic_message.h>
50 #include <google/protobuf/map_field.h>
51 #include <google/protobuf/map_field_inl.h>
52 #include <google/protobuf/message_lite.h>
53 #include <google/protobuf/unknown_field_set.h>
54 
55 
56 #include <google/protobuf/port_def.inc>
57 
// Combined byte size of the tags written for a map entry's key (field 1) and
// value (field 2); each of those tags fits in a single varint byte.
constexpr size_t kMapEntryTagByteSize = 2;
59 
60 namespace google {
61 namespace protobuf {
62 namespace internal {
63 
64 // Forward declare static functions
65 static size_t MapKeyDataOnlyByteSize(const FieldDescriptor* field,
66                                      const MapKey& value);
67 static size_t MapValueRefDataOnlyByteSize(const FieldDescriptor* field,
68                                           const MapValueRef& value);
69 
70 // ===================================================================
71 
SkipField(io::CodedInputStream * input,uint32 tag)72 bool UnknownFieldSetFieldSkipper::SkipField(io::CodedInputStream* input,
73                                             uint32 tag) {
74   return WireFormat::SkipField(input, tag, unknown_fields_);
75 }
76 
SkipMessage(io::CodedInputStream * input)77 bool UnknownFieldSetFieldSkipper::SkipMessage(io::CodedInputStream* input) {
78   return WireFormat::SkipMessage(input, unknown_fields_);
79 }
80 
SkipUnknownEnum(int field_number,int value)81 void UnknownFieldSetFieldSkipper::SkipUnknownEnum(int field_number, int value) {
82   unknown_fields_->AddVarint(field_number, value);
83 }
84 
SkipField(io::CodedInputStream * input,uint32 tag,UnknownFieldSet * unknown_fields)85 bool WireFormat::SkipField(io::CodedInputStream* input, uint32 tag,
86                            UnknownFieldSet* unknown_fields) {
87   int number = WireFormatLite::GetTagFieldNumber(tag);
88   // Field number 0 is illegal.
89   if (number == 0) return false;
90 
91   switch (WireFormatLite::GetTagWireType(tag)) {
92     case WireFormatLite::WIRETYPE_VARINT: {
93       uint64 value;
94       if (!input->ReadVarint64(&value)) return false;
95       if (unknown_fields != NULL) unknown_fields->AddVarint(number, value);
96       return true;
97     }
98     case WireFormatLite::WIRETYPE_FIXED64: {
99       uint64 value;
100       if (!input->ReadLittleEndian64(&value)) return false;
101       if (unknown_fields != NULL) unknown_fields->AddFixed64(number, value);
102       return true;
103     }
104     case WireFormatLite::WIRETYPE_LENGTH_DELIMITED: {
105       uint32 length;
106       if (!input->ReadVarint32(&length)) return false;
107       if (unknown_fields == NULL) {
108         if (!input->Skip(length)) return false;
109       } else {
110         if (!input->ReadString(unknown_fields->AddLengthDelimited(number),
111                                length)) {
112           return false;
113         }
114       }
115       return true;
116     }
117     case WireFormatLite::WIRETYPE_START_GROUP: {
118       if (!input->IncrementRecursionDepth()) return false;
119       if (!SkipMessage(input, (unknown_fields == NULL)
120                                   ? NULL
121                                   : unknown_fields->AddGroup(number))) {
122         return false;
123       }
124       input->DecrementRecursionDepth();
125       // Check that the ending tag matched the starting tag.
126       if (!input->LastTagWas(
127               WireFormatLite::MakeTag(WireFormatLite::GetTagFieldNumber(tag),
128                                       WireFormatLite::WIRETYPE_END_GROUP))) {
129         return false;
130       }
131       return true;
132     }
133     case WireFormatLite::WIRETYPE_END_GROUP: {
134       return false;
135     }
136     case WireFormatLite::WIRETYPE_FIXED32: {
137       uint32 value;
138       if (!input->ReadLittleEndian32(&value)) return false;
139       if (unknown_fields != NULL) unknown_fields->AddFixed32(number, value);
140       return true;
141     }
142     default: {
143       return false;
144     }
145   }
146 }
147 
SkipMessage(io::CodedInputStream * input,UnknownFieldSet * unknown_fields)148 bool WireFormat::SkipMessage(io::CodedInputStream* input,
149                              UnknownFieldSet* unknown_fields) {
150   while (true) {
151     uint32 tag = input->ReadTag();
152     if (tag == 0) {
153       // End of input.  This is a valid place to end, so return true.
154       return true;
155     }
156 
157     WireFormatLite::WireType wire_type = WireFormatLite::GetTagWireType(tag);
158 
159     if (wire_type == WireFormatLite::WIRETYPE_END_GROUP) {
160       // Must be the end of the message.
161       return true;
162     }
163 
164     if (!SkipField(input, tag, unknown_fields)) return false;
165   }
166 }
167 
ReadPackedEnumPreserveUnknowns(io::CodedInputStream * input,uint32 field_number,bool (* is_valid)(int),UnknownFieldSet * unknown_fields,RepeatedField<int> * values)168 bool WireFormat::ReadPackedEnumPreserveUnknowns(io::CodedInputStream* input,
169                                                 uint32 field_number,
170                                                 bool (*is_valid)(int),
171                                                 UnknownFieldSet* unknown_fields,
172                                                 RepeatedField<int>* values) {
173   uint32 length;
174   if (!input->ReadVarint32(&length)) return false;
175   io::CodedInputStream::Limit limit = input->PushLimit(length);
176   while (input->BytesUntilLimit() > 0) {
177     int value;
178     if (!WireFormatLite::ReadPrimitive<int, WireFormatLite::TYPE_ENUM>(
179             input, &value)) {
180       return false;
181     }
182     if (is_valid == NULL || is_valid(value)) {
183       values->Add(value);
184     } else {
185       unknown_fields->AddVarint(field_number, value);
186     }
187   }
188   input->PopLimit(limit);
189   return true;
190 }
191 
InternalSerializeUnknownFieldsToArray(const UnknownFieldSet & unknown_fields,uint8 * target,io::EpsCopyOutputStream * stream)192 uint8* WireFormat::InternalSerializeUnknownFieldsToArray(
193     const UnknownFieldSet& unknown_fields, uint8* target,
194     io::EpsCopyOutputStream* stream) {
195   for (int i = 0; i < unknown_fields.field_count(); i++) {
196     const UnknownField& field = unknown_fields.field(i);
197 
198     target = stream->EnsureSpace(target);
199     switch (field.type()) {
200       case UnknownField::TYPE_VARINT:
201         target = WireFormatLite::WriteUInt64ToArray(field.number(),
202                                                     field.varint(), target);
203         break;
204       case UnknownField::TYPE_FIXED32:
205         target = WireFormatLite::WriteFixed32ToArray(field.number(),
206                                                      field.fixed32(), target);
207         break;
208       case UnknownField::TYPE_FIXED64:
209         target = WireFormatLite::WriteFixed64ToArray(field.number(),
210                                                      field.fixed64(), target);
211         break;
212       case UnknownField::TYPE_LENGTH_DELIMITED:
213         target = stream->WriteString(field.number(), field.length_delimited(),
214                                      target);
215         break;
216       case UnknownField::TYPE_GROUP:
217         target = WireFormatLite::WriteTagToArray(
218             field.number(), WireFormatLite::WIRETYPE_START_GROUP, target);
219         target = InternalSerializeUnknownFieldsToArray(field.group(), target,
220                                                        stream);
221         target = stream->EnsureSpace(target);
222         target = WireFormatLite::WriteTagToArray(
223             field.number(), WireFormatLite::WIRETYPE_END_GROUP, target);
224         break;
225     }
226   }
227   return target;
228 }
229 
InternalSerializeUnknownMessageSetItemsToArray(const UnknownFieldSet & unknown_fields,uint8 * target,io::EpsCopyOutputStream * stream)230 uint8* WireFormat::InternalSerializeUnknownMessageSetItemsToArray(
231     const UnknownFieldSet& unknown_fields, uint8* target,
232     io::EpsCopyOutputStream* stream) {
233   for (int i = 0; i < unknown_fields.field_count(); i++) {
234     const UnknownField& field = unknown_fields.field(i);
235 
236     // The only unknown fields that are allowed to exist in a MessageSet are
237     // messages, which are length-delimited.
238     if (field.type() == UnknownField::TYPE_LENGTH_DELIMITED) {
239       target = stream->EnsureSpace(target);
240       // Start group.
241       target = io::CodedOutputStream::WriteTagToArray(
242           WireFormatLite::kMessageSetItemStartTag, target);
243 
244       // Write type ID.
245       target = io::CodedOutputStream::WriteTagToArray(
246           WireFormatLite::kMessageSetTypeIdTag, target);
247       target =
248           io::CodedOutputStream::WriteVarint32ToArray(field.number(), target);
249 
250       // Write message.
251       target = io::CodedOutputStream::WriteTagToArray(
252           WireFormatLite::kMessageSetMessageTag, target);
253 
254       target = field.InternalSerializeLengthDelimitedNoTag(target, stream);
255 
256       target = stream->EnsureSpace(target);
257       // End group.
258       target = io::CodedOutputStream::WriteTagToArray(
259           WireFormatLite::kMessageSetItemEndTag, target);
260     }
261   }
262 
263   return target;
264 }
265 
ComputeUnknownFieldsSize(const UnknownFieldSet & unknown_fields)266 size_t WireFormat::ComputeUnknownFieldsSize(
267     const UnknownFieldSet& unknown_fields) {
268   size_t size = 0;
269   for (int i = 0; i < unknown_fields.field_count(); i++) {
270     const UnknownField& field = unknown_fields.field(i);
271 
272     switch (field.type()) {
273       case UnknownField::TYPE_VARINT:
274         size += io::CodedOutputStream::VarintSize32(WireFormatLite::MakeTag(
275             field.number(), WireFormatLite::WIRETYPE_VARINT));
276         size += io::CodedOutputStream::VarintSize64(field.varint());
277         break;
278       case UnknownField::TYPE_FIXED32:
279         size += io::CodedOutputStream::VarintSize32(WireFormatLite::MakeTag(
280             field.number(), WireFormatLite::WIRETYPE_FIXED32));
281         size += sizeof(int32);
282         break;
283       case UnknownField::TYPE_FIXED64:
284         size += io::CodedOutputStream::VarintSize32(WireFormatLite::MakeTag(
285             field.number(), WireFormatLite::WIRETYPE_FIXED64));
286         size += sizeof(int64);
287         break;
288       case UnknownField::TYPE_LENGTH_DELIMITED:
289         size += io::CodedOutputStream::VarintSize32(WireFormatLite::MakeTag(
290             field.number(), WireFormatLite::WIRETYPE_LENGTH_DELIMITED));
291         size += io::CodedOutputStream::VarintSize32(
292             field.length_delimited().size());
293         size += field.length_delimited().size();
294         break;
295       case UnknownField::TYPE_GROUP:
296         size += io::CodedOutputStream::VarintSize32(WireFormatLite::MakeTag(
297             field.number(), WireFormatLite::WIRETYPE_START_GROUP));
298         size += ComputeUnknownFieldsSize(field.group());
299         size += io::CodedOutputStream::VarintSize32(WireFormatLite::MakeTag(
300             field.number(), WireFormatLite::WIRETYPE_END_GROUP));
301         break;
302     }
303   }
304 
305   return size;
306 }
307 
ComputeUnknownMessageSetItemsSize(const UnknownFieldSet & unknown_fields)308 size_t WireFormat::ComputeUnknownMessageSetItemsSize(
309     const UnknownFieldSet& unknown_fields) {
310   size_t size = 0;
311   for (int i = 0; i < unknown_fields.field_count(); i++) {
312     const UnknownField& field = unknown_fields.field(i);
313 
314     // The only unknown fields that are allowed to exist in a MessageSet are
315     // messages, which are length-delimited.
316     if (field.type() == UnknownField::TYPE_LENGTH_DELIMITED) {
317       size += WireFormatLite::kMessageSetItemTagsSize;
318       size += io::CodedOutputStream::VarintSize32(field.number());
319 
320       int field_size = field.GetLengthDelimitedSize();
321       size += io::CodedOutputStream::VarintSize32(field_size);
322       size += field_size;
323     }
324   }
325 
326   return size;
327 }
328 
329 // ===================================================================
330 
ParseAndMergePartial(io::CodedInputStream * input,Message * message)331 bool WireFormat::ParseAndMergePartial(io::CodedInputStream* input,
332                                       Message* message) {
333   const Descriptor* descriptor = message->GetDescriptor();
334   const Reflection* message_reflection = message->GetReflection();
335 
336   while (true) {
337     uint32 tag = input->ReadTag();
338     if (tag == 0) {
339       // End of input.  This is a valid place to end, so return true.
340       return true;
341     }
342 
343     if (WireFormatLite::GetTagWireType(tag) ==
344         WireFormatLite::WIRETYPE_END_GROUP) {
345       // Must be the end of the message.
346       return true;
347     }
348 
349     const FieldDescriptor* field = NULL;
350 
351     if (descriptor != NULL) {
352       int field_number = WireFormatLite::GetTagFieldNumber(tag);
353       field = descriptor->FindFieldByNumber(field_number);
354 
355       // If that failed, check if the field is an extension.
356       if (field == NULL && descriptor->IsExtensionNumber(field_number)) {
357         if (input->GetExtensionPool() == NULL) {
358           field = message_reflection->FindKnownExtensionByNumber(field_number);
359         } else {
360           field = input->GetExtensionPool()->FindExtensionByNumber(
361               descriptor, field_number);
362         }
363       }
364 
365       // If that failed, but we're a MessageSet, and this is the tag for a
366       // MessageSet item, then parse that.
367       if (field == NULL && descriptor->options().message_set_wire_format() &&
368           tag == WireFormatLite::kMessageSetItemStartTag) {
369         if (!ParseAndMergeMessageSetItem(input, message)) {
370           return false;
371         }
372         continue;  // Skip ParseAndMergeField(); already taken care of.
373       }
374     }
375 
376     if (!ParseAndMergeField(tag, field, message, input)) {
377       return false;
378     }
379   }
380 }
381 
SkipMessageSetField(io::CodedInputStream * input,uint32 field_number,UnknownFieldSet * unknown_fields)382 bool WireFormat::SkipMessageSetField(io::CodedInputStream* input,
383                                      uint32 field_number,
384                                      UnknownFieldSet* unknown_fields) {
385   uint32 length;
386   if (!input->ReadVarint32(&length)) return false;
387   return input->ReadString(unknown_fields->AddLengthDelimited(field_number),
388                            length);
389 }
390 
ParseAndMergeMessageSetField(uint32 field_number,const FieldDescriptor * field,Message * message,io::CodedInputStream * input)391 bool WireFormat::ParseAndMergeMessageSetField(uint32 field_number,
392                                               const FieldDescriptor* field,
393                                               Message* message,
394                                               io::CodedInputStream* input) {
395   const Reflection* message_reflection = message->GetReflection();
396   if (field == NULL) {
397     // We store unknown MessageSet extensions as groups.
398     return SkipMessageSetField(
399         input, field_number, message_reflection->MutableUnknownFields(message));
400   } else if (field->is_repeated() ||
401              field->type() != FieldDescriptor::TYPE_MESSAGE) {
402     // This shouldn't happen as we only allow optional message extensions to
403     // MessageSet.
404     GOOGLE_LOG(ERROR) << "Extensions of MessageSets must be optional messages.";
405     return false;
406   } else {
407     Message* sub_message = message_reflection->MutableMessage(
408         message, field, input->GetExtensionFactory());
409     return WireFormatLite::ReadMessage(input, sub_message);
410   }
411 }
412 
StrictUtf8Check(const FieldDescriptor * field)413 static bool StrictUtf8Check(const FieldDescriptor* field) {
414   return field->file()->syntax() == FileDescriptor::SYNTAX_PROTO3;
415 }
416 
ParseAndMergeField(uint32 tag,const FieldDescriptor * field,Message * message,io::CodedInputStream * input)417 bool WireFormat::ParseAndMergeField(
418     uint32 tag,
419     const FieldDescriptor* field,  // May be NULL for unknown
420     Message* message, io::CodedInputStream* input) {
421   const Reflection* message_reflection = message->GetReflection();
422 
423   enum { UNKNOWN, NORMAL_FORMAT, PACKED_FORMAT } value_format;
424 
425   if (field == NULL) {
426     value_format = UNKNOWN;
427   } else if (WireFormatLite::GetTagWireType(tag) ==
428              WireTypeForFieldType(field->type())) {
429     value_format = NORMAL_FORMAT;
430   } else if (field->is_packable() &&
431              WireFormatLite::GetTagWireType(tag) ==
432                  WireFormatLite::WIRETYPE_LENGTH_DELIMITED) {
433     value_format = PACKED_FORMAT;
434   } else {
435     // We don't recognize this field. Either the field number is unknown
436     // or the wire type doesn't match. Put it in our unknown field set.
437     value_format = UNKNOWN;
438   }
439 
440   if (value_format == UNKNOWN) {
441     return SkipField(input, tag,
442                      message_reflection->MutableUnknownFields(message));
443   } else if (value_format == PACKED_FORMAT) {
444     uint32 length;
445     if (!input->ReadVarint32(&length)) return false;
446     io::CodedInputStream::Limit limit = input->PushLimit(length);
447 
448     switch (field->type()) {
449 #define HANDLE_PACKED_TYPE(TYPE, CPPTYPE, CPPTYPE_METHOD)                      \
450   case FieldDescriptor::TYPE_##TYPE: {                                         \
451     while (input->BytesUntilLimit() > 0) {                                     \
452       CPPTYPE value;                                                           \
453       if (!WireFormatLite::ReadPrimitive<CPPTYPE,                              \
454                                          WireFormatLite::TYPE_##TYPE>(input,   \
455                                                                       &value)) \
456         return false;                                                          \
457       message_reflection->Add##CPPTYPE_METHOD(message, field, value);          \
458     }                                                                          \
459     break;                                                                     \
460   }
461 
462       HANDLE_PACKED_TYPE(INT32, int32, Int32)
463       HANDLE_PACKED_TYPE(INT64, int64, Int64)
464       HANDLE_PACKED_TYPE(SINT32, int32, Int32)
465       HANDLE_PACKED_TYPE(SINT64, int64, Int64)
466       HANDLE_PACKED_TYPE(UINT32, uint32, UInt32)
467       HANDLE_PACKED_TYPE(UINT64, uint64, UInt64)
468 
469       HANDLE_PACKED_TYPE(FIXED32, uint32, UInt32)
470       HANDLE_PACKED_TYPE(FIXED64, uint64, UInt64)
471       HANDLE_PACKED_TYPE(SFIXED32, int32, Int32)
472       HANDLE_PACKED_TYPE(SFIXED64, int64, Int64)
473 
474       HANDLE_PACKED_TYPE(FLOAT, float, Float)
475       HANDLE_PACKED_TYPE(DOUBLE, double, Double)
476 
477       HANDLE_PACKED_TYPE(BOOL, bool, Bool)
478 #undef HANDLE_PACKED_TYPE
479 
480       case FieldDescriptor::TYPE_ENUM: {
481         while (input->BytesUntilLimit() > 0) {
482           int value;
483           if (!WireFormatLite::ReadPrimitive<int, WireFormatLite::TYPE_ENUM>(
484                   input, &value))
485             return false;
486           if (message->GetDescriptor()->file()->syntax() ==
487               FileDescriptor::SYNTAX_PROTO3) {
488             message_reflection->AddEnumValue(message, field, value);
489           } else {
490             const EnumValueDescriptor* enum_value =
491                 field->enum_type()->FindValueByNumber(value);
492             if (enum_value != NULL) {
493               message_reflection->AddEnum(message, field, enum_value);
494             } else {
495               // The enum value is not one of the known values.  Add it to the
496               // UnknownFieldSet.
497               int64 sign_extended_value = static_cast<int64>(value);
498               message_reflection->MutableUnknownFields(message)->AddVarint(
499                   WireFormatLite::GetTagFieldNumber(tag), sign_extended_value);
500             }
501           }
502         }
503 
504         break;
505       }
506 
507       case FieldDescriptor::TYPE_STRING:
508       case FieldDescriptor::TYPE_GROUP:
509       case FieldDescriptor::TYPE_MESSAGE:
510       case FieldDescriptor::TYPE_BYTES:
511         // Can't have packed fields of these types: these should be caught by
512         // the protocol compiler.
513         return false;
514         break;
515     }
516 
517     input->PopLimit(limit);
518   } else {
519     // Non-packed value (value_format == NORMAL_FORMAT)
520     switch (field->type()) {
521 #define HANDLE_TYPE(TYPE, CPPTYPE, CPPTYPE_METHOD)                            \
522   case FieldDescriptor::TYPE_##TYPE: {                                        \
523     CPPTYPE value;                                                            \
524     if (!WireFormatLite::ReadPrimitive<CPPTYPE, WireFormatLite::TYPE_##TYPE>( \
525             input, &value))                                                   \
526       return false;                                                           \
527     if (field->is_repeated()) {                                               \
528       message_reflection->Add##CPPTYPE_METHOD(message, field, value);         \
529     } else {                                                                  \
530       message_reflection->Set##CPPTYPE_METHOD(message, field, value);         \
531     }                                                                         \
532     break;                                                                    \
533   }
534 
535       HANDLE_TYPE(INT32, int32, Int32)
536       HANDLE_TYPE(INT64, int64, Int64)
537       HANDLE_TYPE(SINT32, int32, Int32)
538       HANDLE_TYPE(SINT64, int64, Int64)
539       HANDLE_TYPE(UINT32, uint32, UInt32)
540       HANDLE_TYPE(UINT64, uint64, UInt64)
541 
542       HANDLE_TYPE(FIXED32, uint32, UInt32)
543       HANDLE_TYPE(FIXED64, uint64, UInt64)
544       HANDLE_TYPE(SFIXED32, int32, Int32)
545       HANDLE_TYPE(SFIXED64, int64, Int64)
546 
547       HANDLE_TYPE(FLOAT, float, Float)
548       HANDLE_TYPE(DOUBLE, double, Double)
549 
550       HANDLE_TYPE(BOOL, bool, Bool)
551 #undef HANDLE_TYPE
552 
553       case FieldDescriptor::TYPE_ENUM: {
554         int value;
555         if (!WireFormatLite::ReadPrimitive<int, WireFormatLite::TYPE_ENUM>(
556                 input, &value))
557           return false;
558         if (field->is_repeated()) {
559           message_reflection->AddEnumValue(message, field, value);
560         } else {
561           message_reflection->SetEnumValue(message, field, value);
562         }
563         break;
564       }
565 
566       // Handle strings separately so that we can optimize the ctype=CORD case.
567       case FieldDescriptor::TYPE_STRING: {
568         bool strict_utf8_check = StrictUtf8Check(field);
569         std::string value;
570         if (!WireFormatLite::ReadString(input, &value)) return false;
571         if (strict_utf8_check) {
572           if (!WireFormatLite::VerifyUtf8String(value.data(), value.length(),
573                                                 WireFormatLite::PARSE,
574                                                 field->full_name().c_str())) {
575             return false;
576           }
577         } else {
578           VerifyUTF8StringNamedField(value.data(), value.length(), PARSE,
579                                      field->full_name().c_str());
580         }
581         if (field->is_repeated()) {
582           message_reflection->AddString(message, field, value);
583         } else {
584           message_reflection->SetString(message, field, value);
585         }
586         break;
587       }
588 
589       case FieldDescriptor::TYPE_BYTES: {
590         std::string value;
591         if (!WireFormatLite::ReadBytes(input, &value)) return false;
592         if (field->is_repeated()) {
593           message_reflection->AddString(message, field, value);
594         } else {
595           message_reflection->SetString(message, field, value);
596         }
597         break;
598       }
599 
600       case FieldDescriptor::TYPE_GROUP: {
601         Message* sub_message;
602         if (field->is_repeated()) {
603           sub_message = message_reflection->AddMessage(
604               message, field, input->GetExtensionFactory());
605         } else {
606           sub_message = message_reflection->MutableMessage(
607               message, field, input->GetExtensionFactory());
608         }
609 
610         if (!WireFormatLite::ReadGroup(WireFormatLite::GetTagFieldNumber(tag),
611                                        input, sub_message))
612           return false;
613         break;
614       }
615 
616       case FieldDescriptor::TYPE_MESSAGE: {
617         Message* sub_message;
618         if (field->is_repeated()) {
619           sub_message = message_reflection->AddMessage(
620               message, field, input->GetExtensionFactory());
621         } else {
622           sub_message = message_reflection->MutableMessage(
623               message, field, input->GetExtensionFactory());
624         }
625 
626         if (!WireFormatLite::ReadMessage(input, sub_message)) return false;
627         break;
628       }
629     }
630   }
631 
632   return true;
633 }
634 
ParseAndMergeMessageSetItem(io::CodedInputStream * input,Message * message)635 bool WireFormat::ParseAndMergeMessageSetItem(io::CodedInputStream* input,
636                                              Message* message) {
637   struct MSReflective {
638     bool ParseField(int type_id, io::CodedInputStream* input) {
639       const FieldDescriptor* field =
640           message_reflection->FindKnownExtensionByNumber(type_id);
641       return ParseAndMergeMessageSetField(type_id, field, message, input);
642     }
643 
644     bool SkipField(uint32 tag, io::CodedInputStream* input) {
645       return WireFormat::SkipField(input, tag, NULL);
646     }
647 
648     const Reflection* message_reflection;
649     Message* message;
650   };
651 
652   return ParseMessageSetItemImpl(
653       input, MSReflective{message->GetReflection(), message});
654 }
655 
656 // ===================================================================
657 
_InternalSerialize(const Message & message,uint8 * target,io::EpsCopyOutputStream * stream)658 uint8* WireFormat::_InternalSerialize(const Message& message, uint8* target,
659                                       io::EpsCopyOutputStream* stream) {
660   const Descriptor* descriptor = message.GetDescriptor();
661   const Reflection* message_reflection = message.GetReflection();
662 
663   std::vector<const FieldDescriptor*> fields;
664 
665   // Fields of map entry should always be serialized.
666   if (descriptor->options().map_entry()) {
667     for (int i = 0; i < descriptor->field_count(); i++) {
668       fields.push_back(descriptor->field(i));
669     }
670   } else {
671     message_reflection->ListFields(message, &fields);
672   }
673 
674   for (auto field : fields) {
675     target = InternalSerializeField(field, message, target, stream);
676   }
677 
678   if (descriptor->options().message_set_wire_format()) {
679     return InternalSerializeUnknownMessageSetItemsToArray(
680         message_reflection->GetUnknownFields(message), target, stream);
681   } else {
682     return InternalSerializeUnknownFieldsToArray(
683         message_reflection->GetUnknownFields(message), target, stream);
684   }
685 }
686 
SerializeMapKeyWithCachedSizes(const FieldDescriptor * field,const MapKey & value,uint8 * target,io::EpsCopyOutputStream * stream)687 static uint8* SerializeMapKeyWithCachedSizes(const FieldDescriptor* field,
688                                              const MapKey& value, uint8* target,
689                                              io::EpsCopyOutputStream* stream) {
690   target = stream->EnsureSpace(target);
691   switch (field->type()) {
692     case FieldDescriptor::TYPE_DOUBLE:
693     case FieldDescriptor::TYPE_FLOAT:
694     case FieldDescriptor::TYPE_GROUP:
695     case FieldDescriptor::TYPE_MESSAGE:
696     case FieldDescriptor::TYPE_BYTES:
697     case FieldDescriptor::TYPE_ENUM:
698       GOOGLE_LOG(FATAL) << "Unsupported";
699       break;
700 #define CASE_TYPE(FieldType, CamelFieldType, CamelCppType)   \
701   case FieldDescriptor::TYPE_##FieldType:                    \
702     target = WireFormatLite::Write##CamelFieldType##ToArray( \
703         1, value.Get##CamelCppType##Value(), target);        \
704     break;
705       CASE_TYPE(INT64, Int64, Int64)
706       CASE_TYPE(UINT64, UInt64, UInt64)
707       CASE_TYPE(INT32, Int32, Int32)
708       CASE_TYPE(FIXED64, Fixed64, UInt64)
709       CASE_TYPE(FIXED32, Fixed32, UInt32)
710       CASE_TYPE(BOOL, Bool, Bool)
711       CASE_TYPE(UINT32, UInt32, UInt32)
712       CASE_TYPE(SFIXED32, SFixed32, Int32)
713       CASE_TYPE(SFIXED64, SFixed64, Int64)
714       CASE_TYPE(SINT32, SInt32, Int32)
715       CASE_TYPE(SINT64, SInt64, Int64)
716 #undef CASE_TYPE
717     case FieldDescriptor::TYPE_STRING:
718       target = stream->WriteString(1, value.GetStringValue(), target);
719       break;
720   }
721   return target;
722 }
723 
SerializeMapValueRefWithCachedSizes(const FieldDescriptor * field,const MapValueRef & value,uint8 * target,io::EpsCopyOutputStream * stream)724 static uint8* SerializeMapValueRefWithCachedSizes(
725     const FieldDescriptor* field, const MapValueRef& value, uint8* target,
726     io::EpsCopyOutputStream* stream) {
727   target = stream->EnsureSpace(target);
728   switch (field->type()) {
729 #define CASE_TYPE(FieldType, CamelFieldType, CamelCppType)   \
730   case FieldDescriptor::TYPE_##FieldType:                    \
731     target = WireFormatLite::Write##CamelFieldType##ToArray( \
732         2, value.Get##CamelCppType##Value(), target);        \
733     break;
734     CASE_TYPE(INT64, Int64, Int64)
735     CASE_TYPE(UINT64, UInt64, UInt64)
736     CASE_TYPE(INT32, Int32, Int32)
737     CASE_TYPE(FIXED64, Fixed64, UInt64)
738     CASE_TYPE(FIXED32, Fixed32, UInt32)
739     CASE_TYPE(BOOL, Bool, Bool)
740     CASE_TYPE(UINT32, UInt32, UInt32)
741     CASE_TYPE(SFIXED32, SFixed32, Int32)
742     CASE_TYPE(SFIXED64, SFixed64, Int64)
743     CASE_TYPE(SINT32, SInt32, Int32)
744     CASE_TYPE(SINT64, SInt64, Int64)
745     CASE_TYPE(ENUM, Enum, Enum)
746     CASE_TYPE(DOUBLE, Double, Double)
747     CASE_TYPE(FLOAT, Float, Float)
748 #undef CASE_TYPE
749     case FieldDescriptor::TYPE_STRING:
750     case FieldDescriptor::TYPE_BYTES:
751       target = stream->WriteString(2, value.GetStringValue(), target);
752       break;
753     case FieldDescriptor::TYPE_MESSAGE:
754       target = WireFormatLite::InternalWriteMessage(2, value.GetMessageValue(),
755                                                     target, stream);
756       break;
757     case FieldDescriptor::TYPE_GROUP:
758       target = WireFormatLite::InternalWriteGroup(2, value.GetMessageValue(),
759                                                   target, stream);
760       break;
761   }
762   return target;
763 }
764 
765 class MapKeySorter {
766  public:
SortKey(const Message & message,const Reflection * reflection,const FieldDescriptor * field)767   static std::vector<MapKey> SortKey(const Message& message,
768                                      const Reflection* reflection,
769                                      const FieldDescriptor* field) {
770     std::vector<MapKey> sorted_key_list;
771     for (MapIterator it =
772              reflection->MapBegin(const_cast<Message*>(&message), field);
773          it != reflection->MapEnd(const_cast<Message*>(&message), field);
774          ++it) {
775       sorted_key_list.push_back(it.GetKey());
776     }
777     MapKeyComparator comparator;
778     std::sort(sorted_key_list.begin(), sorted_key_list.end(), comparator);
779     return sorted_key_list;
780   }
781 
782  private:
783   class MapKeyComparator {
784    public:
operator ()(const MapKey & a,const MapKey & b) const785     bool operator()(const MapKey& a, const MapKey& b) const {
786       GOOGLE_DCHECK(a.type() == b.type());
787       switch (a.type()) {
788 #define CASE_TYPE(CppType, CamelCppType)                                \
789   case FieldDescriptor::CPPTYPE_##CppType: {                            \
790     return a.Get##CamelCppType##Value() < b.Get##CamelCppType##Value(); \
791   }
792         CASE_TYPE(STRING, String)
793         CASE_TYPE(INT64, Int64)
794         CASE_TYPE(INT32, Int32)
795         CASE_TYPE(UINT64, UInt64)
796         CASE_TYPE(UINT32, UInt32)
797         CASE_TYPE(BOOL, Bool)
798 #undef CASE_TYPE
799 
800         default:
801           GOOGLE_LOG(DFATAL) << "Invalid key for map field.";
802           return true;
803       }
804     }
805   };
806 };
807 
InternalSerializeMapEntry(const FieldDescriptor * field,const MapKey & key,const MapValueRef & value,uint8 * target,io::EpsCopyOutputStream * stream)808 static uint8* InternalSerializeMapEntry(const FieldDescriptor* field,
809                                         const MapKey& key,
810                                         const MapValueRef& value, uint8* target,
811                                         io::EpsCopyOutputStream* stream) {
812   const FieldDescriptor* key_field = field->message_type()->field(0);
813   const FieldDescriptor* value_field = field->message_type()->field(1);
814 
815   size_t size = kMapEntryTagByteSize;
816   size += MapKeyDataOnlyByteSize(key_field, key);
817   size += MapValueRefDataOnlyByteSize(value_field, value);
818   target = stream->EnsureSpace(target);
819   target = WireFormatLite::WriteTagToArray(
820       field->number(), WireFormatLite::WIRETYPE_LENGTH_DELIMITED, target);
821   target = io::CodedOutputStream::WriteVarint32ToArray(size, target);
822   target = SerializeMapKeyWithCachedSizes(key_field, key, target, stream);
823   target =
824       SerializeMapValueRefWithCachedSizes(value_field, value, target, stream);
825   return target;
826 }
827 
InternalSerializeField(const FieldDescriptor * field,const Message & message,uint8 * target,io::EpsCopyOutputStream * stream)828 uint8* WireFormat::InternalSerializeField(const FieldDescriptor* field,
829                                           const Message& message, uint8* target,
830                                           io::EpsCopyOutputStream* stream) {
831   const Reflection* message_reflection = message.GetReflection();
832 
833   if (field->is_extension() &&
834       field->containing_type()->options().message_set_wire_format() &&
835       field->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE &&
836       !field->is_repeated()) {
837     return InternalSerializeMessageSetItem(field, message, target, stream);
838   }
839 
840   // For map fields, we can use either repeated field reflection or map
841   // reflection.  Our choice has some subtle effects.  If we use repeated field
842   // reflection here, then the repeated field representation becomes
843   // authoritative for this field: any existing references that came from map
844   // reflection remain valid for reading, but mutations to them are lost and
845   // will be overwritten next time we call map reflection!
846   //
847   // So far this mainly affects Python, which keeps long-term references to map
848   // values around, and always uses map reflection.  See: b/35918691
849   //
850   // Here we choose to use map reflection API as long as the internal
851   // map is valid. In this way, the serialization doesn't change map field's
852   // internal state and existing references that came from map reflection remain
853   // valid for both reading and writing.
854   if (field->is_map()) {
855     const MapFieldBase* map_field =
856         message_reflection->GetMapData(message, field);
857     if (map_field->IsMapValid()) {
858       if (stream->IsSerializationDeterministic()) {
859         std::vector<MapKey> sorted_key_list =
860             MapKeySorter::SortKey(message, message_reflection, field);
861         for (std::vector<MapKey>::iterator it = sorted_key_list.begin();
862              it != sorted_key_list.end(); ++it) {
863           MapValueRef map_value;
864           message_reflection->InsertOrLookupMapValue(
865               const_cast<Message*>(&message), field, *it, &map_value);
866           target =
867               InternalSerializeMapEntry(field, *it, map_value, target, stream);
868         }
869       } else {
870         for (MapIterator it = message_reflection->MapBegin(
871                  const_cast<Message*>(&message), field);
872              it !=
873              message_reflection->MapEnd(const_cast<Message*>(&message), field);
874              ++it) {
875           target = InternalSerializeMapEntry(field, it.GetKey(),
876                                              it.GetValueRef(), target, stream);
877         }
878       }
879 
880       return target;
881     }
882   }
883   int count = 0;
884 
885   if (field->is_repeated()) {
886     count = message_reflection->FieldSize(message, field);
887   } else if (field->containing_type()->options().map_entry()) {
888     // Map entry fields always need to be serialized.
889     count = 1;
890   } else if (message_reflection->HasField(message, field)) {
891     count = 1;
892   }
893 
894   // map_entries is for maps that'll be deterministically serialized.
895   std::vector<const Message*> map_entries;
896   if (count > 1 && field->is_map() && stream->IsSerializationDeterministic()) {
897     map_entries =
898         DynamicMapSorter::Sort(message, count, message_reflection, field);
899   }
900 
901   if (field->is_packed()) {
902     if (count == 0) return target;
903     target = stream->EnsureSpace(target);
904     switch (field->type()) {
905 #define HANDLE_PRIMITIVE_TYPE(TYPE, CPPTYPE, TYPE_METHOD, CPPTYPE_METHOD)      \
906   case FieldDescriptor::TYPE_##TYPE: {                                         \
907     auto r =                                                                   \
908         message_reflection->GetRepeatedFieldInternal<CPPTYPE>(message, field); \
909     target = stream->Write##TYPE_METHOD##Packed(                               \
910         field->number(), r, FieldDataOnlyByteSize(field, message), target);    \
911     break;                                                                     \
912   }
913 
914       HANDLE_PRIMITIVE_TYPE(INT32, int32, Int32, Int32)
915       HANDLE_PRIMITIVE_TYPE(INT64, int64, Int64, Int64)
916       HANDLE_PRIMITIVE_TYPE(SINT32, int32, SInt32, Int32)
917       HANDLE_PRIMITIVE_TYPE(SINT64, int64, SInt64, Int64)
918       HANDLE_PRIMITIVE_TYPE(UINT32, uint32, UInt32, UInt32)
919       HANDLE_PRIMITIVE_TYPE(UINT64, uint64, UInt64, UInt64)
920       HANDLE_PRIMITIVE_TYPE(ENUM, int, Enum, Enum)
921 
922 #undef HANDLE_PRIMITIVE_TYPE
923 #define HANDLE_PRIMITIVE_TYPE(TYPE, CPPTYPE, TYPE_METHOD, CPPTYPE_METHOD)      \
924   case FieldDescriptor::TYPE_##TYPE: {                                         \
925     auto r =                                                                   \
926         message_reflection->GetRepeatedFieldInternal<CPPTYPE>(message, field); \
927     target = stream->WriteFixedPacked(field->number(), r, target);             \
928     break;                                                                     \
929   }
930 
931       HANDLE_PRIMITIVE_TYPE(FIXED32, uint32, Fixed32, UInt32)
932       HANDLE_PRIMITIVE_TYPE(FIXED64, uint64, Fixed64, UInt64)
933       HANDLE_PRIMITIVE_TYPE(SFIXED32, int32, SFixed32, Int32)
934       HANDLE_PRIMITIVE_TYPE(SFIXED64, int64, SFixed64, Int64)
935 
936       HANDLE_PRIMITIVE_TYPE(FLOAT, float, Float, Float)
937       HANDLE_PRIMITIVE_TYPE(DOUBLE, double, Double, Double)
938 
939       HANDLE_PRIMITIVE_TYPE(BOOL, bool, Bool, Bool)
940 #undef HANDLE_PRIMITIVE_TYPE
941       default:
942         GOOGLE_LOG(FATAL) << "Invalid descriptor";
943     }
944     return target;
945   }
946 
947   for (int j = 0; j < count; j++) {
948     target = stream->EnsureSpace(target);
949     switch (field->type()) {
950 #define HANDLE_PRIMITIVE_TYPE(TYPE, CPPTYPE, TYPE_METHOD, CPPTYPE_METHOD)     \
951   case FieldDescriptor::TYPE_##TYPE: {                                        \
952     const CPPTYPE value =                                                     \
953         field->is_repeated()                                                  \
954             ? message_reflection->GetRepeated##CPPTYPE_METHOD(message, field, \
955                                                               j)              \
956             : message_reflection->Get##CPPTYPE_METHOD(message, field);        \
957     target = WireFormatLite::Write##TYPE_METHOD##ToArray(field->number(),     \
958                                                          value, target);      \
959     break;                                                                    \
960   }
961 
962       HANDLE_PRIMITIVE_TYPE(INT32, int32, Int32, Int32)
963       HANDLE_PRIMITIVE_TYPE(INT64, int64, Int64, Int64)
964       HANDLE_PRIMITIVE_TYPE(SINT32, int32, SInt32, Int32)
965       HANDLE_PRIMITIVE_TYPE(SINT64, int64, SInt64, Int64)
966       HANDLE_PRIMITIVE_TYPE(UINT32, uint32, UInt32, UInt32)
967       HANDLE_PRIMITIVE_TYPE(UINT64, uint64, UInt64, UInt64)
968 
969       HANDLE_PRIMITIVE_TYPE(FIXED32, uint32, Fixed32, UInt32)
970       HANDLE_PRIMITIVE_TYPE(FIXED64, uint64, Fixed64, UInt64)
971       HANDLE_PRIMITIVE_TYPE(SFIXED32, int32, SFixed32, Int32)
972       HANDLE_PRIMITIVE_TYPE(SFIXED64, int64, SFixed64, Int64)
973 
974       HANDLE_PRIMITIVE_TYPE(FLOAT, float, Float, Float)
975       HANDLE_PRIMITIVE_TYPE(DOUBLE, double, Double, Double)
976 
977       HANDLE_PRIMITIVE_TYPE(BOOL, bool, Bool, Bool)
978 #undef HANDLE_PRIMITIVE_TYPE
979 
980 #define HANDLE_TYPE(TYPE, TYPE_METHOD, CPPTYPE_METHOD)                         \
981   case FieldDescriptor::TYPE_##TYPE:                                           \
982     target = WireFormatLite::InternalWrite##TYPE_METHOD(                       \
983         field->number(),                                                       \
984         field->is_repeated()                                                   \
985             ? (map_entries.empty()                                             \
986                    ? message_reflection->GetRepeated##CPPTYPE_METHOD(message,  \
987                                                                      field, j) \
988                    : *map_entries[j])                                          \
989             : message_reflection->Get##CPPTYPE_METHOD(message, field),         \
990         target, stream);                                                       \
991     break;
992 
993       HANDLE_TYPE(GROUP, Group, Message)
994       HANDLE_TYPE(MESSAGE, Message, Message)
995 #undef HANDLE_TYPE
996 
997       case FieldDescriptor::TYPE_ENUM: {
998         const EnumValueDescriptor* value =
999             field->is_repeated()
1000                 ? message_reflection->GetRepeatedEnum(message, field, j)
1001                 : message_reflection->GetEnum(message, field);
1002         target = WireFormatLite::WriteEnumToArray(field->number(),
1003                                                   value->number(), target);
1004         break;
1005       }
1006 
1007       // Handle strings separately so that we can get string references
1008       // instead of copying.
1009       case FieldDescriptor::TYPE_STRING: {
1010         bool strict_utf8_check = StrictUtf8Check(field);
1011         std::string scratch;
1012         const std::string& value =
1013             field->is_repeated()
1014                 ? message_reflection->GetRepeatedStringReference(message, field,
1015                                                                  j, &scratch)
1016                 : message_reflection->GetStringReference(message, field,
1017                                                          &scratch);
1018         if (strict_utf8_check) {
1019           WireFormatLite::VerifyUtf8String(value.data(), value.length(),
1020                                            WireFormatLite::SERIALIZE,
1021                                            field->full_name().c_str());
1022         } else {
1023           VerifyUTF8StringNamedField(value.data(), value.length(), SERIALIZE,
1024                                      field->full_name().c_str());
1025         }
1026         target = stream->WriteString(field->number(), value, target);
1027         break;
1028       }
1029 
1030       case FieldDescriptor::TYPE_BYTES: {
1031         std::string scratch;
1032         const std::string& value =
1033             field->is_repeated()
1034                 ? message_reflection->GetRepeatedStringReference(message, field,
1035                                                                  j, &scratch)
1036                 : message_reflection->GetStringReference(message, field,
1037                                                          &scratch);
1038         target = stream->WriteString(field->number(), value, target);
1039         break;
1040       }
1041     }
1042   }
1043   return target;
1044 }
1045 
InternalSerializeMessageSetItem(const FieldDescriptor * field,const Message & message,uint8 * target,io::EpsCopyOutputStream * stream)1046 uint8* WireFormat::InternalSerializeMessageSetItem(
1047     const FieldDescriptor* field, const Message& message, uint8* target,
1048     io::EpsCopyOutputStream* stream) {
1049   const Reflection* message_reflection = message.GetReflection();
1050 
1051   target = stream->EnsureSpace(target);
1052   // Start group.
1053   target = io::CodedOutputStream::WriteTagToArray(
1054       WireFormatLite::kMessageSetItemStartTag, target);
1055   // Write type ID.
1056   target = WireFormatLite::WriteUInt32ToArray(
1057       WireFormatLite::kMessageSetTypeIdNumber, field->number(), target);
1058   // Write message.
1059   target = WireFormatLite::InternalWriteMessage(
1060       WireFormatLite::kMessageSetMessageNumber,
1061       message_reflection->GetMessage(message, field), target, stream);
1062   // End group.
1063   target = stream->EnsureSpace(target);
1064   target = io::CodedOutputStream::WriteTagToArray(
1065       WireFormatLite::kMessageSetItemEndTag, target);
1066   return target;
1067 }
1068 
1069 // ===================================================================
1070 
ByteSize(const Message & message)1071 size_t WireFormat::ByteSize(const Message& message) {
1072   const Descriptor* descriptor = message.GetDescriptor();
1073   const Reflection* message_reflection = message.GetReflection();
1074 
1075   size_t our_size = 0;
1076 
1077   std::vector<const FieldDescriptor*> fields;
1078 
1079   // Fields of map entry should always be serialized.
1080   if (descriptor->options().map_entry()) {
1081     for (int i = 0; i < descriptor->field_count(); i++) {
1082       fields.push_back(descriptor->field(i));
1083     }
1084   } else {
1085     message_reflection->ListFields(message, &fields);
1086   }
1087 
1088   for (int i = 0; i < fields.size(); i++) {
1089     our_size += FieldByteSize(fields[i], message);
1090   }
1091 
1092   if (descriptor->options().message_set_wire_format()) {
1093     our_size += ComputeUnknownMessageSetItemsSize(
1094         message_reflection->GetUnknownFields(message));
1095   } else {
1096     our_size +=
1097         ComputeUnknownFieldsSize(message_reflection->GetUnknownFields(message));
1098   }
1099 
1100   return our_size;
1101 }
1102 
FieldByteSize(const FieldDescriptor * field,const Message & message)1103 size_t WireFormat::FieldByteSize(const FieldDescriptor* field,
1104                                  const Message& message) {
1105   const Reflection* message_reflection = message.GetReflection();
1106 
1107   if (field->is_extension() &&
1108       field->containing_type()->options().message_set_wire_format() &&
1109       field->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE &&
1110       !field->is_repeated()) {
1111     return MessageSetItemByteSize(field, message);
1112   }
1113 
1114   size_t count = 0;
1115   if (field->is_repeated()) {
1116     if (field->is_map()) {
1117       const MapFieldBase* map_field =
1118           message_reflection->GetMapData(message, field);
1119       if (map_field->IsMapValid()) {
1120         count = FromIntSize(map_field->size());
1121       } else {
1122         count = FromIntSize(message_reflection->FieldSize(message, field));
1123       }
1124     } else {
1125       count = FromIntSize(message_reflection->FieldSize(message, field));
1126     }
1127   } else if (field->containing_type()->options().map_entry()) {
1128     // Map entry fields always need to be serialized.
1129     count = 1;
1130   } else if (message_reflection->HasField(message, field)) {
1131     count = 1;
1132   }
1133 
1134   const size_t data_size = FieldDataOnlyByteSize(field, message);
1135   size_t our_size = data_size;
1136   if (field->is_packed()) {
1137     if (data_size > 0) {
1138       // Packed fields get serialized like a string, not their native type.
1139       // Technically this doesn't really matter; the size only changes if it's
1140       // a GROUP
1141       our_size += TagSize(field->number(), FieldDescriptor::TYPE_STRING);
1142       our_size += io::CodedOutputStream::VarintSize32(data_size);
1143     }
1144   } else {
1145     our_size += count * TagSize(field->number(), field->type());
1146   }
1147   return our_size;
1148 }
1149 
MapKeyDataOnlyByteSize(const FieldDescriptor * field,const MapKey & value)1150 static size_t MapKeyDataOnlyByteSize(const FieldDescriptor* field,
1151                                      const MapKey& value) {
1152   GOOGLE_DCHECK_EQ(FieldDescriptor::TypeToCppType(field->type()), value.type());
1153   switch (field->type()) {
1154     case FieldDescriptor::TYPE_DOUBLE:
1155     case FieldDescriptor::TYPE_FLOAT:
1156     case FieldDescriptor::TYPE_GROUP:
1157     case FieldDescriptor::TYPE_MESSAGE:
1158     case FieldDescriptor::TYPE_BYTES:
1159     case FieldDescriptor::TYPE_ENUM:
1160       GOOGLE_LOG(FATAL) << "Unsupported";
1161       return 0;
1162 #define CASE_TYPE(FieldType, CamelFieldType, CamelCppType) \
1163   case FieldDescriptor::TYPE_##FieldType:                  \
1164     return WireFormatLite::CamelFieldType##Size(           \
1165         value.Get##CamelCppType##Value());
1166 
1167 #define FIXED_CASE_TYPE(FieldType, CamelFieldType) \
1168   case FieldDescriptor::TYPE_##FieldType:          \
1169     return WireFormatLite::k##CamelFieldType##Size;
1170 
1171       CASE_TYPE(INT32, Int32, Int32);
1172       CASE_TYPE(INT64, Int64, Int64);
1173       CASE_TYPE(UINT32, UInt32, UInt32);
1174       CASE_TYPE(UINT64, UInt64, UInt64);
1175       CASE_TYPE(SINT32, SInt32, Int32);
1176       CASE_TYPE(SINT64, SInt64, Int64);
1177       CASE_TYPE(STRING, String, String);
1178       FIXED_CASE_TYPE(FIXED32, Fixed32);
1179       FIXED_CASE_TYPE(FIXED64, Fixed64);
1180       FIXED_CASE_TYPE(SFIXED32, SFixed32);
1181       FIXED_CASE_TYPE(SFIXED64, SFixed64);
1182       FIXED_CASE_TYPE(BOOL, Bool);
1183 
1184 #undef CASE_TYPE
1185 #undef FIXED_CASE_TYPE
1186   }
1187   GOOGLE_LOG(FATAL) << "Cannot get here";
1188   return 0;
1189 }
1190 
MapValueRefDataOnlyByteSize(const FieldDescriptor * field,const MapValueRef & value)1191 static size_t MapValueRefDataOnlyByteSize(const FieldDescriptor* field,
1192                                           const MapValueRef& value) {
1193   switch (field->type()) {
1194     case FieldDescriptor::TYPE_GROUP:
1195       GOOGLE_LOG(FATAL) << "Unsupported";
1196       return 0;
1197 #define CASE_TYPE(FieldType, CamelFieldType, CamelCppType) \
1198   case FieldDescriptor::TYPE_##FieldType:                  \
1199     return WireFormatLite::CamelFieldType##Size(           \
1200         value.Get##CamelCppType##Value());
1201 
1202 #define FIXED_CASE_TYPE(FieldType, CamelFieldType) \
1203   case FieldDescriptor::TYPE_##FieldType:          \
1204     return WireFormatLite::k##CamelFieldType##Size;
1205 
1206       CASE_TYPE(INT32, Int32, Int32);
1207       CASE_TYPE(INT64, Int64, Int64);
1208       CASE_TYPE(UINT32, UInt32, UInt32);
1209       CASE_TYPE(UINT64, UInt64, UInt64);
1210       CASE_TYPE(SINT32, SInt32, Int32);
1211       CASE_TYPE(SINT64, SInt64, Int64);
1212       CASE_TYPE(STRING, String, String);
1213       CASE_TYPE(BYTES, Bytes, String);
1214       CASE_TYPE(ENUM, Enum, Enum);
1215       CASE_TYPE(MESSAGE, Message, Message);
1216       FIXED_CASE_TYPE(FIXED32, Fixed32);
1217       FIXED_CASE_TYPE(FIXED64, Fixed64);
1218       FIXED_CASE_TYPE(SFIXED32, SFixed32);
1219       FIXED_CASE_TYPE(SFIXED64, SFixed64);
1220       FIXED_CASE_TYPE(DOUBLE, Double);
1221       FIXED_CASE_TYPE(FLOAT, Float);
1222       FIXED_CASE_TYPE(BOOL, Bool);
1223 
1224 #undef CASE_TYPE
1225 #undef FIXED_CASE_TYPE
1226   }
1227   GOOGLE_LOG(FATAL) << "Cannot get here";
1228   return 0;
1229 }
1230 
FieldDataOnlyByteSize(const FieldDescriptor * field,const Message & message)1231 size_t WireFormat::FieldDataOnlyByteSize(const FieldDescriptor* field,
1232                                          const Message& message) {
1233   const Reflection* message_reflection = message.GetReflection();
1234 
1235   size_t data_size = 0;
1236 
1237   if (field->is_map()) {
1238     const MapFieldBase* map_field =
1239         message_reflection->GetMapData(message, field);
1240     if (map_field->IsMapValid()) {
1241       MapIterator iter(const_cast<Message*>(&message), field);
1242       MapIterator end(const_cast<Message*>(&message), field);
1243       const FieldDescriptor* key_field = field->message_type()->field(0);
1244       const FieldDescriptor* value_field = field->message_type()->field(1);
1245       for (map_field->MapBegin(&iter), map_field->MapEnd(&end); iter != end;
1246            ++iter) {
1247         size_t size = kMapEntryTagByteSize;
1248         size += MapKeyDataOnlyByteSize(key_field, iter.GetKey());
1249         size += MapValueRefDataOnlyByteSize(value_field, iter.GetValueRef());
1250         data_size += WireFormatLite::LengthDelimitedSize(size);
1251       }
1252       return data_size;
1253     }
1254   }
1255 
1256   size_t count = 0;
1257   if (field->is_repeated()) {
1258     count =
1259         internal::FromIntSize(message_reflection->FieldSize(message, field));
1260   } else if (field->containing_type()->options().map_entry()) {
1261     // Map entry fields always need to be serialized.
1262     count = 1;
1263   } else if (message_reflection->HasField(message, field)) {
1264     count = 1;
1265   }
1266 
1267   switch (field->type()) {
1268 #define HANDLE_TYPE(TYPE, TYPE_METHOD, CPPTYPE_METHOD)                      \
1269   case FieldDescriptor::TYPE_##TYPE:                                        \
1270     if (field->is_repeated()) {                                             \
1271       for (int j = 0; j < count; j++) {                                     \
1272         data_size += WireFormatLite::TYPE_METHOD##Size(                     \
1273             message_reflection->GetRepeated##CPPTYPE_METHOD(message, field, \
1274                                                             j));            \
1275       }                                                                     \
1276     } else {                                                                \
1277       data_size += WireFormatLite::TYPE_METHOD##Size(                       \
1278           message_reflection->Get##CPPTYPE_METHOD(message, field));         \
1279     }                                                                       \
1280     break;
1281 
1282 #define HANDLE_FIXED_TYPE(TYPE, TYPE_METHOD)                   \
1283   case FieldDescriptor::TYPE_##TYPE:                           \
1284     data_size += count * WireFormatLite::k##TYPE_METHOD##Size; \
1285     break;
1286 
1287     HANDLE_TYPE(INT32, Int32, Int32)
1288     HANDLE_TYPE(INT64, Int64, Int64)
1289     HANDLE_TYPE(SINT32, SInt32, Int32)
1290     HANDLE_TYPE(SINT64, SInt64, Int64)
1291     HANDLE_TYPE(UINT32, UInt32, UInt32)
1292     HANDLE_TYPE(UINT64, UInt64, UInt64)
1293 
1294     HANDLE_FIXED_TYPE(FIXED32, Fixed32)
1295     HANDLE_FIXED_TYPE(FIXED64, Fixed64)
1296     HANDLE_FIXED_TYPE(SFIXED32, SFixed32)
1297     HANDLE_FIXED_TYPE(SFIXED64, SFixed64)
1298 
1299     HANDLE_FIXED_TYPE(FLOAT, Float)
1300     HANDLE_FIXED_TYPE(DOUBLE, Double)
1301 
1302     HANDLE_FIXED_TYPE(BOOL, Bool)
1303 
1304     HANDLE_TYPE(GROUP, Group, Message)
1305     HANDLE_TYPE(MESSAGE, Message, Message)
1306 #undef HANDLE_TYPE
1307 #undef HANDLE_FIXED_TYPE
1308 
1309     case FieldDescriptor::TYPE_ENUM: {
1310       if (field->is_repeated()) {
1311         for (int j = 0; j < count; j++) {
1312           data_size += WireFormatLite::EnumSize(
1313               message_reflection->GetRepeatedEnum(message, field, j)->number());
1314         }
1315       } else {
1316         data_size += WireFormatLite::EnumSize(
1317             message_reflection->GetEnum(message, field)->number());
1318       }
1319       break;
1320     }
1321 
1322     // Handle strings separately so that we can get string references
1323     // instead of copying.
1324     case FieldDescriptor::TYPE_STRING:
1325     case FieldDescriptor::TYPE_BYTES: {
1326       for (int j = 0; j < count; j++) {
1327         std::string scratch;
1328         const std::string& value =
1329             field->is_repeated()
1330                 ? message_reflection->GetRepeatedStringReference(message, field,
1331                                                                  j, &scratch)
1332                 : message_reflection->GetStringReference(message, field,
1333                                                          &scratch);
1334         data_size += WireFormatLite::StringSize(value);
1335       }
1336       break;
1337     }
1338   }
1339   return data_size;
1340 }
1341 
MessageSetItemByteSize(const FieldDescriptor * field,const Message & message)1342 size_t WireFormat::MessageSetItemByteSize(const FieldDescriptor* field,
1343                                           const Message& message) {
1344   const Reflection* message_reflection = message.GetReflection();
1345 
1346   size_t our_size = WireFormatLite::kMessageSetItemTagsSize;
1347 
1348   // type_id
1349   our_size += io::CodedOutputStream::VarintSize32(field->number());
1350 
1351   // message
1352   const Message& sub_message = message_reflection->GetMessage(message, field);
1353   size_t message_size = sub_message.ByteSizeLong();
1354 
1355   our_size += io::CodedOutputStream::VarintSize32(message_size);
1356   our_size += message_size;
1357 
1358   return our_size;
1359 }
1360 
1361 // Compute the size of the UnknownFieldSet on the wire.
ComputeUnknownFieldsSize(const InternalMetadataWithArena & metadata,size_t total_size,CachedSize * cached_size)1362 size_t ComputeUnknownFieldsSize(const InternalMetadataWithArena& metadata,
1363                                 size_t total_size, CachedSize* cached_size) {
1364   total_size += WireFormat::ComputeUnknownFieldsSize(metadata.unknown_fields());
1365   cached_size->Set(ToCachedSize(total_size));
1366   return total_size;
1367 }
1368 
1369 }  // namespace internal
1370 }  // namespace protobuf
1371 }  // namespace google
1372