1# Protocol Buffers - Google's data interchange format
2# Copyright 2008 Google Inc.  All rights reserved.
3# https://developers.google.com/protocol-buffers/
4#
5# Redistribution and use in source and binary forms, with or without
6# modification, are permitted provided that the following conditions are
7# met:
8#
9#     * Redistributions of source code must retain the above copyright
10# notice, this list of conditions and the following disclaimer.
11#     * Redistributions in binary form must reproduce the above
12# copyright notice, this list of conditions and the following disclaimer
13# in the documentation and/or other materials provided with the
14# distribution.
15#     * Neither the name of Google Inc. nor the names of its
16# contributors may be used to endorse or promote products derived from
17# this software without specific prior written permission.
18#
19# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
23# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
24# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
25# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
26# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
27# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
28# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
31"""Code for encoding protocol message primitives.
32
33Contains the logic for encoding every logical protocol field type
34into one of the 5 physical wire types.
35
36This code is designed to push the Python interpreter's performance to the
37limits.
38
39The basic idea is that at startup time, for every field (i.e. every
40FieldDescriptor) we construct two functions:  a "sizer" and an "encoder".  The
41sizer takes a value of this field's type and computes its byte size.  The
42encoder takes a writer function and a value.  It encodes the value into byte
43strings and invokes the writer function to write those strings.  Typically the
44writer function is the write() method of a BytesIO.
45
We try to do as much work as possible when constructing the encoder and the
sizer rather than when calling them.  In particular:
48* We copy any needed global functions to local variables, so that we do not need
49  to do costly global table lookups at runtime.
50* Similarly, we try to do any attribute lookups at startup time if possible.
51* Every field's tag is encoded to bytes at startup, since it can't change at
52  runtime.
53* Whatever component of the field size we can compute at startup, we do.
54* We *avoid* sharing code if doing so would make the code slower and not sharing
55  does not burden us too much.  For example, encoders for repeated fields do
56  not just call the encoders for singular fields in a loop because this would
57  add an extra function call overhead for every loop iteration; instead, we
58  manually inline the single-value encoder into the loop.
59* If a Python function lacks a return statement, Python actually generates
60  instructions to pop the result of the last statement off the stack, push
61  None onto the stack, and then return that.  If we really don't care what
62  value is returned, then we can save two instructions by returning the
63  result of the last statement.  It looks funny but it helps.
64* We assume that type and bounds checking has happened at a higher level.
65"""
66
# Module authorship metadata.
__author__ = 'kenton@google.com (Kenton Varda)'
68
69import struct
70
71import six
72
73from google.protobuf.internal import wire_format
74
75
76# This will overflow and thus become IEEE-754 "infinity".  We would use
77# "float('inf')" but it doesn't work on Windows pre-Python-2.6.
78_POS_INF = 1e10000
79_NEG_INF = -_POS_INF
80
81
def _VarintSize(value):
  """Compute the size of a varint value.

  Returns the number of bytes (1 to 10) that the base-128 varint encoding of
  `value` occupies.  Each varint byte carries 7 payload bits, so each
  threshold below is the largest value that fits in that many bytes.  The
  comparison chain is written out by hand on purpose, in line with the
  module's performance notes.

  Negative inputs are not special-cased: they compare <= 0x7f and report 1.
  Use _SignedVarintSize for values that may be negative.
  """
  if value <= 0x7f: return 1
  if value <= 0x3fff: return 2
  if value <= 0x1fffff: return 3
  if value <= 0xfffffff: return 4
  if value <= 0x7ffffffff: return 5
  if value <= 0x3ffffffffff: return 6
  if value <= 0x1ffffffffffff: return 7
  if value <= 0xffffffffffffff: return 8
  if value <= 0x7fffffffffffffff: return 9
  return 10
94
95
def _SignedVarintSize(value):
  """Compute the size of a signed varint value.

  Non-negative values are sized exactly like _VarintSize.  Negative values
  are always reported as 10 bytes.  This matches _SignedVarintEncoder, which
  adds 2**64 to a negative value before encoding it.
  """
  # Negative values are written in their 64-bit two's-complement form, which
  # always needs the maximum varint length.
  if value < 0: return 10
  if value <= 0x7f: return 1
  if value <= 0x3fff: return 2
  if value <= 0x1fffff: return 3
  if value <= 0xfffffff: return 4
  if value <= 0x7ffffffff: return 5
  if value <= 0x3ffffffffff: return 6
  if value <= 0x1ffffffffffff: return 7
  if value <= 0xffffffffffffff: return 8
  if value <= 0x7fffffffffffffff: return 9
  return 10
109
110
def _TagSize(field_number):
  """Return how many bytes the tag for `field_number` occupies on the wire.

  The wire type is packed into the same varint as the field number but does
  not affect its encoded length, so a placeholder wire type of 0 is used.
  """
  placeholder_wire_type = 0
  return _VarintSize(wire_format.PackTag(field_number, placeholder_wire_type))
116
117
118# --------------------------------------------------------------------
119# In this section we define some generic sizers.  Each of these functions
120# takes parameters specific to a particular field type, e.g. int32 or fixed64.
121# It returns another function which in turn takes parameters specific to a
122# particular field, e.g. the field number and whether it is repeated or packed.
123# Look at the next section to see how these are used.
124
125
def _SimpleSizer(compute_value_size):
  """Build a sizer constructor for variable-width scalar fields.

  Args:
    compute_value_size: Function returning the encoded size of one value
      (without its tag), typically _VarintSize or _SignedVarintSize.

  Returns:
    A function (field_number, is_repeated, is_packed) that returns a sizer.
    The sizer maps a field value to its total encoded size, including tags
    and, for packed fields, the length prefix.
  """

  def SpecificSizer(field_number, is_repeated, is_packed):
    tag_size = _TagSize(field_number)

    if not is_packed and not is_repeated:
      def FieldSize(value):
        return tag_size + compute_value_size(value)
      return FieldSize

    if not is_packed:
      # Every element of a non-packed repeated field carries its own tag.
      def RepeatedFieldSize(value):
        return tag_size * len(value) + sum(map(compute_value_size, value))
      return RepeatedFieldSize

    # Packed: a single tag, a varint length prefix, then all values.
    local_VarintSize = _VarintSize
    def PackedFieldSize(value):
      payload = sum(map(compute_value_size, value))
      return payload + local_VarintSize(payload) + tag_size
    return PackedFieldSize

  return SpecificSizer
153
154
def _ModifiedSizer(compute_value_size, modify_value):
  """Build a sizer constructor whose values are transformed before sizing.

  Behaves like _SimpleSizer, except that each value is first passed through
  `modify_value` (typically wire_format.ZigZagEncode).  The transformed
  result is what `compute_value_size` measures.
  """

  def SpecificSizer(field_number, is_repeated, is_packed):
    tag_size = _TagSize(field_number)

    if not is_packed and not is_repeated:
      def FieldSize(value):
        return tag_size + compute_value_size(modify_value(value))
      return FieldSize

    if not is_packed:
      # One tag per element, plus each transformed element's size.
      def RepeatedFieldSize(value):
        return tag_size * len(value) + sum(
            compute_value_size(modify_value(element)) for element in value)
      return RepeatedFieldSize

    # Packed: a single tag, a varint length prefix, then all values.
    local_VarintSize = _VarintSize
    def PackedFieldSize(value):
      payload = sum(
          compute_value_size(modify_value(element)) for element in value)
      return payload + local_VarintSize(payload) + tag_size
    return PackedFieldSize

  return SpecificSizer
182
183
def _FixedSizer(value_size):
  """Build a sizer constructor for fixed-width fields.

  Args:
    value_size: Encoded size in bytes of a single value (e.g. 4 or 8).

  Returns:
    A function (field_number, is_repeated, is_packed) that returns a sizer.
  """

  def SpecificSizer(field_number, is_repeated, is_packed):
    tag_size = _TagSize(field_number)

    if is_packed:
      local_VarintSize = _VarintSize
      def PackedFieldSize(value):
        payload = value_size * len(value)
        return payload + local_VarintSize(payload) + tag_size
      return PackedFieldSize

    # Outside packed mode each occurrence costs exactly one tag plus one value.
    per_element = tag_size + value_size

    if is_repeated:
      def RepeatedFieldSize(value):
        return per_element * len(value)
      return RepeatedFieldSize

    def FieldSize(value):
      return per_element
    return FieldSize

  return SpecificSizer
208
209
210# ====================================================================
211# Here we declare a sizer constructor for each field type.  Each "sizer
212# constructor" is a function that takes (field_number, is_repeated, is_packed)
213# as parameters and returns a sizer, which in turn takes a field value as
214# a parameter and returns its encoded size.
215
216
# Varint types that may hold negative values use the signed sizing rule.
Int32Sizer = _SimpleSizer(_SignedVarintSize)
Int64Sizer = Int32Sizer
EnumSizer = Int32Sizer

# Unsigned varint types use the plain varint sizing rule.
UInt32Sizer = _SimpleSizer(_VarintSize)
UInt64Sizer = UInt32Sizer

# sint32/sint64 values are ZigZag-mapped before being sized.
SInt32Sizer = _ModifiedSizer(_SignedVarintSize, wire_format.ZigZagEncode)
SInt64Sizer = SInt32Sizer

# Fixed-width types: 4 bytes, 8 bytes, and a single byte for bool.
Fixed32Sizer = _FixedSizer(4)
SFixed32Sizer = Fixed32Sizer
FloatSizer = Fixed32Sizer

Fixed64Sizer = _FixedSizer(8)
SFixed64Sizer = Fixed64Sizer
DoubleSizer = Fixed64Sizer

BoolSizer = _FixedSizer(1)
228
229
def StringSizer(field_number, is_repeated, is_packed):
  """Returns a sizer for a string field.

  Each string is measured as tag + varint byte length + UTF-8 bytes.  String
  fields cannot be packed.
  """

  tag_size = _TagSize(field_number)
  local_VarintSize = _VarintSize
  local_len = len
  assert not is_packed

  if not is_repeated:
    def FieldSize(value):
      byte_count = local_len(value.encode('utf-8'))
      return tag_size + local_VarintSize(byte_count) + byte_count
    return FieldSize

  def RepeatedFieldSize(value):
    total = tag_size * local_len(value)
    for element in value:
      byte_count = local_len(element.encode('utf-8'))
      total += local_VarintSize(byte_count) + byte_count
    return total
  return RepeatedFieldSize
250
251
def BytesSizer(field_number, is_repeated, is_packed):
  """Returns a sizer for a bytes field.

  Each value is measured as tag + varint length + raw bytes.  Bytes fields
  cannot be packed.
  """

  tag_size = _TagSize(field_number)
  local_VarintSize = _VarintSize
  local_len = len
  assert not is_packed

  if not is_repeated:
    def FieldSize(value):
      byte_count = local_len(value)
      return tag_size + local_VarintSize(byte_count) + byte_count
    return FieldSize

  def RepeatedFieldSize(value):
    total = tag_size * local_len(value)
    for element in value:
      byte_count = local_len(element)
      total += local_VarintSize(byte_count) + byte_count
    return total
  return RepeatedFieldSize
272
273
def GroupSizer(field_number, is_repeated, is_packed):
  """Returns a sizer for a group field.

  A group is framed by a start tag and an end tag carrying the same field
  number.  Each occurrence therefore costs two tags plus the group's
  ByteSize().
  """

  framing_size = 2 * _TagSize(field_number)
  assert not is_packed

  if not is_repeated:
    def FieldSize(value):
      return framing_size + value.ByteSize()
    return FieldSize

  def RepeatedFieldSize(value):
    return framing_size * len(value) + sum(
        element.ByteSize() for element in value)
  return RepeatedFieldSize
290
291
def MessageSizer(field_number, is_repeated, is_packed):
  """Returns a sizer for a message field.

  Embedded messages are length-delimited: tag, varint length of the
  serialized message, then the message bytes.  Message fields cannot be
  packed.
  """

  tag_size = _TagSize(field_number)
  local_VarintSize = _VarintSize
  assert not is_packed

  if not is_repeated:
    def FieldSize(value):
      body_size = value.ByteSize()
      return tag_size + local_VarintSize(body_size) + body_size
    return FieldSize

  def RepeatedFieldSize(value):
    total = tag_size * len(value)
    for element in value:
      body_size = element.ByteSize()
      total += local_VarintSize(body_size) + body_size
    return total
  return RepeatedFieldSize
311
312
313# --------------------------------------------------------------------
314# MessageSet is special: it needs custom logic to compute its size properly.
315
316
def MessageSetItemSizer(field_number):
  """Returns a sizer for extensions of MessageSet.

  The message set message looks like this:
    message MessageSet {
      repeated group Item = 1 {
        required int32 type_id = 2;
        required string message = 3;
      }
    }

  Everything except the embedded message's length and bytes is constant for
  a given extension number, so it is computed once up front.
  """
  # Start and end tags of the Item group.
  group_framing = _TagSize(1) * 2
  # Tag and varint value of type_id.
  type_id_size = _TagSize(2) + _VarintSize(field_number)
  # Tag of the message field.
  message_tag_size = _TagSize(3)
  static_size = group_framing + type_id_size + message_tag_size
  local_VarintSize = _VarintSize

  def FieldSize(value):
    payload = value.ByteSize()
    return static_size + local_VarintSize(payload) + payload

  return FieldSize
337
338
339# --------------------------------------------------------------------
340# Map is special: it needs custom logic to compute its size properly.
341
342
def MapSizer(field_descriptor, is_message_map):
  """Returns a sizer for a map field.

  Each (key, value) pair is sized as an embedded entry message built from
  that pair, under the map field's number.

  Args:
    field_descriptor: Descriptor of the map field.
    is_message_map: True when map values are messages.  In that case each
      value's ByteSize() is also called to refresh its cached size.
  """

  # message_type._concrete_class may not be initialized yet at this point, so
  # it is only looked up later, inside FieldSize.
  entry_type = field_descriptor.message_type
  entry_sizer = MessageSizer(field_descriptor.number, False, False)

  def FieldSize(map_value):
    total = 0
    for entry_key in map_value:
      entry_value = map_value[entry_key]
      # The entry built here is thrown away; encoding builds it again.  There
      # is no cheap way to share it within the current design.
      entry = entry_type._concrete_class(key=entry_key, value=entry_value)
      total += entry_sizer(entry)
      if is_message_map:
        entry_value.ByteSize()
    return total

  return FieldSize
367
368# ====================================================================
369# Encoders!
370
371
def _VarintEncoder():
  """Return an encoder for a basic varint value (does not include tag).

  The returned function writes `value` 7 bits at a time, least significant
  group first.  Every byte except the last has its high bit (0x80) set.  One
  byte is written per write() call.
  """

  local_int2byte = six.int2byte

  def EncodeVarint(write, value, unused_deterministic=None):
    while True:
      low7 = value & 0x7f
      value >>= 7
      if not value:
        return write(local_int2byte(low7))
      write(local_int2byte(0x80 | low7))

  return EncodeVarint
386
387
def _SignedVarintEncoder():
  """Return an encoder for a basic signed varint value (does not include tag).

  Negative values are first shifted into their 64-bit two's-complement form
  by adding 2**64.  The result is then written like an unsigned varint, one
  byte per write() call.
  """

  local_int2byte = six.int2byte

  def EncodeSignedVarint(write, value, unused_deterministic=None):
    if value < 0:
      value += (1 << 64)
    while True:
      low7 = value & 0x7f
      value >>= 7
      if not value:
        return write(local_int2byte(low7))
      write(local_int2byte(0x80 | low7))

  return EncodeSignedVarint
405
406
# Shared varint writers, built once at import time and reused by the encoder
# constructors in this module.
_EncodeSignedVarint = _SignedVarintEncoder()
_EncodeVarint = _VarintEncoder()
409
410
def _VarintBytes(value):
  """Return the varint encoding of `value` as a single bytes object.

  Only used while building encoders, so clarity matters more than speed here.
  """
  chunks = []
  collect = chunks.append
  _EncodeVarint(collect, value, True)
  return b"".join(chunks)
418
419
def TagBytes(field_number, wire_type):
  """Return the encoded tag for (field_number, wire_type) as bytes.

  Only called while building encoders, never once per value.
  """
  packed_tag = wire_format.PackTag(field_number, wire_type)
  return six.binary_type(_VarintBytes(packed_tag))
425
426# --------------------------------------------------------------------
427# As with sizers (see above), we have a number of common encoder
428# implementations.
429
430
def _SimpleEncoder(wire_type, encode_value, compute_value_size):
  """Return a constructor for an encoder for fields of a particular type.

  Args:
      wire_type:  The field's wire type, used in the tag of non-packed fields.
      encode_value:  Function (write, value, deterministic) that writes one
        value without its tag, e.g. _EncodeVarint.
      compute_value_size:  Function returning the encoded size of one value,
        e.g. _VarintSize.  Used for the length prefix of packed fields.

  Returns:
      A function (field_number, is_repeated, is_packed) that returns an
      encoder with signature (write, value, deterministic).
  """

  def SpecificEncoder(field_number, is_repeated, is_packed):
    if not is_packed:
      tag_bytes = TagBytes(field_number, wire_type)

      if not is_repeated:
        def EncodeField(write, value, deterministic):
          write(tag_bytes)
          return encode_value(write, value, deterministic)
        return EncodeField

      def EncodeRepeatedField(write, value, deterministic):
        for element in value:
          write(tag_bytes)
          encode_value(write, element, deterministic)
      return EncodeRepeatedField

    # Packed: a single length-delimited tag, the total payload size, then
    # every value back to back.
    tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
    local_EncodeVarint = _EncodeVarint
    def EncodePackedField(write, value, deterministic):
      write(tag_bytes)
      payload_size = sum(map(compute_value_size, value))
      local_EncodeVarint(write, payload_size, deterministic)
      for element in value:
        encode_value(write, element, deterministic)
    return EncodePackedField

  return SpecificEncoder
470
471
def _ModifiedEncoder(wire_type, encode_value, compute_value_size, modify_value):
  """Like _SimpleEncoder, but transforms every value before encoding it.

  `modify_value` (usually wire_format.ZigZagEncode) is applied to each value.
  The result is what gets sized with `compute_value_size` and written with
  `encode_value`.
  """

  def SpecificEncoder(field_number, is_repeated, is_packed):
    if not is_packed:
      tag_bytes = TagBytes(field_number, wire_type)

      if not is_repeated:
        def EncodeField(write, value, deterministic):
          write(tag_bytes)
          return encode_value(write, modify_value(value), deterministic)
        return EncodeField

      def EncodeRepeatedField(write, value, deterministic):
        for element in value:
          write(tag_bytes)
          encode_value(write, modify_value(element), deterministic)
      return EncodeRepeatedField

    # Packed: a single length-delimited tag, the total size of the
    # transformed values, then the transformed values themselves.
    tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
    local_EncodeVarint = _EncodeVarint
    def EncodePackedField(write, value, deterministic):
      write(tag_bytes)
      payload_size = sum(
          compute_value_size(modify_value(element)) for element in value)
      local_EncodeVarint(write, payload_size, deterministic)
      for element in value:
        encode_value(write, modify_value(element), deterministic)
    return EncodePackedField

  return SpecificEncoder
504
505
def _StructPackEncoder(wire_type, format):
  """Return a constructor for an encoder for a fixed-width field.

  Args:
      wire_type:  The field's wire type, for encoding tags.
      format:  The format string to pass to struct.pack().

  Returns:
      A function (field_number, is_repeated, is_packed) that returns an
      encoder taking (write, value, deterministic).
  """

  value_size = struct.calcsize(format)

  def SpecificEncoder(field_number, is_repeated, is_packed):
    local_struct_pack = struct.pack

    if not is_packed:
      tag_bytes = TagBytes(field_number, wire_type)

      if is_repeated:
        def EncodeRepeatedField(write, value, unused_deterministic=None):
          for element in value:
            write(tag_bytes)
            write(local_struct_pack(format, element))
        return EncodeRepeatedField

      def EncodeField(write, value, unused_deterministic=None):
        write(tag_bytes)
        return write(local_struct_pack(format, value))
      return EncodeField

    tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
    local_EncodeVarint = _EncodeVarint
    def EncodePackedField(write, value, deterministic):
      write(tag_bytes)
      # All elements have the same width, so the payload size is known without
      # packing anything first.
      local_EncodeVarint(write, value_size * len(value), deterministic)
      for element in value:
        write(local_struct_pack(format, element))
    return EncodePackedField

  return SpecificEncoder
542
543
def _FloatingPointEncoder(wire_type, format):
  """Return a constructor for an encoder for float fields.

  This is like StructPackEncoder, but catches errors that may be due to
  passing non-finite floating-point values to struct.pack, and makes a
  second attempt to encode those values.

  Args:
      wire_type:  The field's wire type, for encoding tags.
      format:  The format string to pass to struct.pack().  It must describe
        a 4-byte or 8-byte value.  Any other size raises ValueError here, at
        construction time.

  Returns:
      A function (field_number, is_repeated, is_packed) that returns an
      encoder taking (write, value, deterministic).
  """

  value_size = struct.calcsize(format)
  # EncodeNonFiniteOrRaise is only called from inside the
  # `except SystemError:` handlers below.  It writes the little-endian bit
  # pattern for +inf, -inf or NaN.  For any other value, the bare `raise`
  # re-raises the SystemError currently being handled.
  # NOTE(review): newer CPython versions appear to let struct.pack encode
  # inf/NaN directly, so this fallback may only matter on old interpreters.
  # Confirm before relying on it.
  if value_size == 4:
    def EncodeNonFiniteOrRaise(write, value):
      # Remember that the serialized form uses little-endian byte order.
      if value == _POS_INF:
        write(b'\x00\x00\x80\x7F')
      elif value == _NEG_INF:
        write(b'\x00\x00\x80\xFF')
      elif value != value:           # NaN
        write(b'\x00\x00\xC0\x7F')
      else:
        raise
  elif value_size == 8:
    def EncodeNonFiniteOrRaise(write, value):
      if value == _POS_INF:
        write(b'\x00\x00\x00\x00\x00\x00\xF0\x7F')
      elif value == _NEG_INF:
        write(b'\x00\x00\x00\x00\x00\x00\xF0\xFF')
      elif value != value:                         # NaN
        write(b'\x00\x00\x00\x00\x00\x00\xF8\x7F')
      else:
        raise
  else:
    raise ValueError('Can\'t encode floating-point values that are '
                     '%d bytes long (only 4 or 8)' % value_size)

  def SpecificEncoder(field_number, is_repeated, is_packed):
    local_struct_pack = struct.pack
    if is_packed:
      tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
      local_EncodeVarint = _EncodeVarint
      def EncodePackedField(write, value, deterministic):
        write(tag_bytes)
        # Fixed-width elements: the payload length is known without packing.
        local_EncodeVarint(write, len(value) * value_size, deterministic)
        for element in value:
          # This try/except block is going to be faster than any code that
          # we could write to check whether element is finite.
          try:
            write(local_struct_pack(format, element))
          except SystemError:
            EncodeNonFiniteOrRaise(write, element)
      return EncodePackedField
    elif is_repeated:
      tag_bytes = TagBytes(field_number, wire_type)
      def EncodeRepeatedField(write, value, unused_deterministic=None):
        for element in value:
          write(tag_bytes)
          # Same finite-value fast path as in EncodePackedField.
          try:
            write(local_struct_pack(format, element))
          except SystemError:
            EncodeNonFiniteOrRaise(write, element)
      return EncodeRepeatedField
    else:
      tag_bytes = TagBytes(field_number, wire_type)
      def EncodeField(write, value, unused_deterministic=None):
        write(tag_bytes)
        # Same finite-value fast path as in EncodePackedField.
        try:
          write(local_struct_pack(format, value))
        except SystemError:
          EncodeNonFiniteOrRaise(write, value)
      return EncodeField

  return SpecificEncoder
619
620
621# ====================================================================
622# Here we declare an encoder constructor for each field type.  These work
623# very similarly to sizer constructors, described earlier.
624
625
# Varint encoders.  int32/int64/enum use the signed encoder, which writes
# negative values in 64-bit two's-complement form.  uint32/uint64 use the
# plain varint encoder.  sint32/sint64 are ZigZag-mapped and then written as
# plain varints.
Int32Encoder = _SimpleEncoder(
    wire_format.WIRETYPE_VARINT, _EncodeSignedVarint, _SignedVarintSize)
Int64Encoder = Int32Encoder
EnumEncoder = Int32Encoder

UInt32Encoder = _SimpleEncoder(
    wire_format.WIRETYPE_VARINT, _EncodeVarint, _VarintSize)
UInt64Encoder = UInt32Encoder

SInt32Encoder = _ModifiedEncoder(
    wire_format.WIRETYPE_VARINT, _EncodeVarint, _VarintSize,
    wire_format.ZigZagEncode)
SInt64Encoder = SInt32Encoder

# Fixed-width encoders.  The '<' prefix selects little-endian byte order with
# struct's standard sizes, so encoded widths are identical on every platform.
# Without the prefix, sizes would depend on the C compiler's native type
# sizes.
Fixed32Encoder = _StructPackEncoder(wire_format.WIRETYPE_FIXED32, '<I')
Fixed64Encoder = _StructPackEncoder(wire_format.WIRETYPE_FIXED64, '<Q')
SFixed32Encoder = _StructPackEncoder(wire_format.WIRETYPE_FIXED32, '<i')
SFixed64Encoder = _StructPackEncoder(wire_format.WIRETYPE_FIXED64, '<q')
FloatEncoder = _FloatingPointEncoder(wire_format.WIRETYPE_FIXED32, '<f')
DoubleEncoder = _FloatingPointEncoder(wire_format.WIRETYPE_FIXED64, '<d')
646
647
def BoolEncoder(field_number, is_repeated, is_packed):
  """Returns an encoder for a boolean field.

  Each value is written as a single byte: 0x01 when truthy, 0x00 otherwise.
  """

  false_byte = b'\x00'
  true_byte = b'\x01'

  if not is_packed and not is_repeated:
    tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_VARINT)
    def EncodeField(write, value, unused_deterministic=None):
      write(tag_bytes)
      return write(true_byte if value else false_byte)
    return EncodeField

  if not is_packed:
    tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_VARINT)
    def EncodeRepeatedField(write, value, unused_deterministic=None):
      for element in value:
        write(tag_bytes)
        write(true_byte if element else false_byte)
    return EncodeRepeatedField

  tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
  local_EncodeVarint = _EncodeVarint
  def EncodePackedField(write, value, deterministic):
    write(tag_bytes)
    # One byte per element, so the element count is the payload length.
    local_EncodeVarint(write, len(value), deterministic)
    for element in value:
      write(true_byte if element else false_byte)
  return EncodePackedField
683
684
def StringEncoder(field_number, is_repeated, is_packed):
  """Returns an encoder for a string field.

  Each value is UTF-8 encoded, then written as tag, varint byte length and
  bytes.  Encoding happens before anything is written, so a value that fails
  to encode leaves no partial output for that value.
  """

  tag = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
  local_EncodeVarint = _EncodeVarint
  local_len = len
  assert not is_packed

  if not is_repeated:
    def EncodeField(write, value, deterministic):
      utf8 = value.encode('utf-8')
      write(tag)
      local_EncodeVarint(write, local_len(utf8), deterministic)
      return write(utf8)
    return EncodeField

  def EncodeRepeatedField(write, value, deterministic):
    for element in value:
      utf8 = element.encode('utf-8')
      write(tag)
      local_EncodeVarint(write, local_len(utf8), deterministic)
      write(utf8)
  return EncodeRepeatedField
707
708
def BytesEncoder(field_number, is_repeated, is_packed):
  """Returns an encoder for a bytes field.

  Each value is written as tag, varint length, then the raw bytes.
  """

  tag = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
  local_EncodeVarint = _EncodeVarint
  local_len = len
  assert not is_packed

  if not is_repeated:
    def EncodeField(write, value, deterministic):
      write(tag)
      local_EncodeVarint(write, local_len(value), deterministic)
      return write(value)
    return EncodeField

  def EncodeRepeatedField(write, value, deterministic):
    for blob in value:
      write(tag)
      local_EncodeVarint(write, local_len(blob), deterministic)
      write(blob)
  return EncodeRepeatedField
729
730
def GroupEncoder(field_number, is_repeated, is_packed):
  """Returns an encoder for a group field.

  Each group is written between a start-group tag and an end-group tag that
  share the field number.  Its contents are serialized in between.
  """

  start_tag = TagBytes(field_number, wire_format.WIRETYPE_START_GROUP)
  end_tag = TagBytes(field_number, wire_format.WIRETYPE_END_GROUP)
  assert not is_packed

  if not is_repeated:
    def EncodeField(write, value, deterministic):
      write(start_tag)
      value._InternalSerialize(write, deterministic)
      return write(end_tag)
    return EncodeField

  def EncodeRepeatedField(write, value, deterministic):
    for group in value:
      write(start_tag)
      group._InternalSerialize(write, deterministic)
      write(end_tag)
  return EncodeRepeatedField
750
751
def MessageEncoder(field_number, is_repeated, is_packed):
  """Returns an encoder for a message field.

  Each message is written as tag, varint ByteSize(), then its serialized
  contents.
  """

  tag = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
  local_EncodeVarint = _EncodeVarint
  assert not is_packed

  if not is_repeated:
    def EncodeField(write, value, deterministic):
      write(tag)
      local_EncodeVarint(write, value.ByteSize(), deterministic)
      return value._InternalSerialize(write, deterministic)
    return EncodeField

  def EncodeRepeatedField(write, value, deterministic):
    for submessage in value:
      write(tag)
      local_EncodeVarint(write, submessage.ByteSize(), deterministic)
      submessage._InternalSerialize(write, deterministic)
  return EncodeRepeatedField
771
772
773# --------------------------------------------------------------------
774# As before, MessageSet is special.
775
776
def MessageSetItemEncoder(field_number):
  """Encoder for extensions of MessageSet.

  The message set message looks like this:
    message MessageSet {
      repeated group Item = 1 {
        required int32 type_id = 2;
        required string message = 3;
      }
    }

  The bytes before the payload length (group start tag, type_id tag and
  value, message tag) and the trailing end-group tag are constant for a given
  extension number.  They are therefore built once here.
  """
  prefix = b"".join((
      TagBytes(1, wire_format.WIRETYPE_START_GROUP),
      TagBytes(2, wire_format.WIRETYPE_VARINT),
      _VarintBytes(field_number),
      TagBytes(3, wire_format.WIRETYPE_LENGTH_DELIMITED),
  ))
  suffix = TagBytes(1, wire_format.WIRETYPE_END_GROUP)
  local_EncodeVarint = _EncodeVarint

  def EncodeField(write, value, deterministic):
    write(prefix)
    local_EncodeVarint(write, value.ByteSize(), deterministic)
    value._InternalSerialize(write, deterministic)
    return write(suffix)

  return EncodeField
803
804
805# --------------------------------------------------------------------
806# As before, Map is special.
807
808
def MapEncoder(field_descriptor):
  """Returns an encoder for a map field.

  Maps always have a wire format like this:
    message MapEntry {
      key_type key = 1;
      value_type value = 2;
    }
    repeated MapEntry map = N;

  Args:
    field_descriptor: Descriptor of the map field.  Its `number` is used as
      the field number of every entry, and its `message_type` is the entry
      message type.

  Returns:
    An encoder (write, value, deterministic) that writes one length-delimited
    entry message per key of `value`.  When `deterministic` is true the keys
    are written in sorted order; otherwise they follow the mapping's own
    iteration order.
  """
  # Can't look at field_descriptor.message_type._concrete_class because it may
  # not have been initialized yet.
  message_type = field_descriptor.message_type
  encode_message = MessageEncoder(field_descriptor.number, False, False)

  def EncodeField(write, value, deterministic):
    # Sorting the keys makes the output independent of insertion order.
    value_keys = sorted(value.keys()) if deterministic else value
    for key in value_keys:
      entry_msg = message_type._concrete_class(key=key, value=value[key])
      encode_message(write, entry_msg, deterministic)

  return EncodeField
831