1# Protocol Buffers - Google's data interchange format
2# Copyright 2008 Google Inc.  All rights reserved.
3# https://developers.google.com/protocol-buffers/
4#
5# Redistribution and use in source and binary forms, with or without
6# modification, are permitted provided that the following conditions are
7# met:
8#
9#     * Redistributions of source code must retain the above copyright
10# notice, this list of conditions and the following disclaimer.
11#     * Redistributions in binary form must reproduce the above
12# copyright notice, this list of conditions and the following disclaimer
13# in the documentation and/or other materials provided with the
14# distribution.
15#     * Neither the name of Google Inc. nor the names of its
16# contributors may be used to endorse or promote products derived from
17# this software without specific prior written permission.
18#
19# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
23# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
24# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
25# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
26# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
27# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
28# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
31"""Code for decoding protocol buffer primitives.
32
33This code is very similar to encoder.py -- read the docs for that module first.
34
35A "decoder" is a function with the signature:
36  Decode(buffer, pos, end, message, field_dict)
37The arguments are:
38  buffer:     The string containing the encoded message.
39  pos:        The current position in the string.
40  end:        The position in the string where the current message ends.  May be
41              less than len(buffer) if we're reading a sub-message.
42  message:    The message object into which we're parsing.
43  field_dict: message._fields (avoids a hashtable lookup).
44The decoder reads the field and stores it into field_dict, returning the new
45buffer position.  A decoder for a repeated field may proactively decode all of
46the elements of that field, if they appear consecutively.
47
48Note that decoders may throw any of the following:
49  IndexError:  Indicates a truncated message.
50  struct.error:  Unpacking of a fixed-width field failed.
51  message.DecodeError:  Other errors.
52
53Decoders are expected to raise an exception if they are called with pos > end.
This allows callers to be lax about bounds checking:  it's fine to read past
55"end" as long as you are sure that someone else will notice and throw an
56exception later on.
57
58Something up the call stack is expected to catch IndexError and struct.error
59and convert them to message.DecodeError.
60
61Decoders are constructed using decoder constructors with the signature:
62  MakeDecoder(field_number, is_repeated, is_packed, key, new_default)
63The arguments are:
64  field_number:  The field number of the field we want to decode.
65  is_repeated:   Is the field a repeated field? (bool)
66  is_packed:     Is the field a packed field? (bool)
67  key:           The key to use when looking up the field within field_dict.
68                 (This is actually the FieldDescriptor but nothing in this
69                 file should depend on that.)
70  new_default:   A function which takes a message object as a parameter and
71                 returns a new instance of the default value for this field.
72                 (This is called for repeated fields and sub-messages, when an
73                 instance does not already exist.)
74
75As with encoders, we define a decoder constructor for every type of field.
76Then, for every field of every message class we construct an actual decoder.
77That decoder goes into a dict indexed by tag, so when we decode a message
78we repeatedly read a tag, look up the corresponding decoder, and invoke it.
79"""
80
81__author__ = 'kenton@google.com (Kenton Varda)'
82
83import struct
84import sys
85import six
86
# Threshold that sys.maxunicode is compared against when deciding whether the
# Python 2 surrogate check for strict-UTF-8 string fields has to run.
_UCS2_MAXUNICODE = 65535
if six.PY3:
  # Python 3 has no separate "long" type; alias the name to int so the rest of
  # this module can keep referring to "long".
  long = int
else:
  import re    # pylint: disable=g-import-not-at-top
  # Matches any code point in the surrogate range U+D800..U+DFFF.  Only
  # defined on Python 2, where the strict-UTF-8 string check uses it.
  _SURROGATE_PATTERN = re.compile(six.u(r'[\ud800-\udfff]'))
93
94from google.protobuf.internal import containers
95from google.protobuf.internal import encoder
96from google.protobuf.internal import wire_format
97from google.protobuf import message
98
99
# This will overflow and thus become IEEE-754 "infinity".  We would use
# "float('inf')" but it doesn't work on Windows pre-Python-2.6.
_POS_INF = 1e10000
_NEG_INF = -_POS_INF
# Infinity multiplied by zero yields IEEE-754 "not a number".
_NAN = _POS_INF * 0


# This is not for optimization, but rather to avoid conflicts with local
# variables named "message".
_DecodeError = message.DecodeError
110
111
def _VarintDecoder(mask, result_type):
  """Return a decoder for a basic varint value (does not include tag).

  The decoded integer is bitwise-anded with `mask` (e.g. to limit it to 32
  bits) and then passed through `result_type` before being returned.  The
  returned decoder takes only (buffer, pos) -- it has no "end" parameter, so
  the caller is expected to do bounds checking after the fact (often the
  caller can defer such checking until later).  It returns a
  (value, new_pos) pair.

  Args:
    mask: Integer bit mask applied to the accumulated value.
    result_type: Callable applied to the masked value (e.g. int or long).

  Returns:
    A function DecodeVarint(buffer, pos) -> (value, new_pos).  It raises
    _DecodeError when the varint does not terminate within the 64-bit shift
    budget.
  """

  def DecodeVarint(buffer, pos):
    value = 0
    bit_offset = 0
    while True:
      byte = six.indexbytes(buffer, pos)
      pos += 1
      # The low seven bits of every byte carry payload, least significant
      # group first.
      value |= (byte & 0x7f) << bit_offset
      if byte & 0x80:
        # Continuation bit set: another 7-bit group follows.
        bit_offset += 7
        if bit_offset >= 64:
          raise _DecodeError('Too many bytes when decoding varint.')
        continue
      return (result_type(value & mask), pos)
  return DecodeVarint
137
138
def _SignedVarintDecoder(bits, result_type):
  """Like _VarintDecoder() but decodes signed values.

  The accumulated value is truncated to `bits` bits and then reinterpreted as
  a two's-complement signed integer of that width.

  Args:
    bits: Width in bits of the signed result (e.g. 32 or 64).
    result_type: Callable applied to the signed value (e.g. int or long).

  Returns:
    A function DecodeVarint(buffer, pos) -> (value, new_pos).  It raises
    _DecodeError when the varint does not terminate within the 64-bit shift
    budget.
  """

  mask = (1 << bits) - 1
  signbit = 1 << (bits - 1)

  def DecodeVarint(buffer, pos):
    value = 0
    bit_offset = 0
    while True:
      byte = six.indexbytes(buffer, pos)
      pos += 1
      value |= (byte & 0x7f) << bit_offset
      if byte & 0x80:
        # Continuation bit set: another 7-bit group follows.
        bit_offset += 7
        if bit_offset >= 64:
          raise _DecodeError('Too many bytes when decoding varint.')
        continue
      value &= mask
      # Sign-extend: flipping the sign bit and subtracting it back maps the
      # upper half of the unsigned range onto the negative numbers.
      return (result_type((value ^ signbit) - signbit), pos)
  return DecodeVarint
161
# We force 32-bit values to int and 64-bit values to long to make
# alternate implementations where the distinction is more significant
# (e.g. the C++ implementation) simpler.

# Unsigned and signed decoders for values masked to 64 bits.
_DecodeVarint = _VarintDecoder((1 << 64) - 1, long)
_DecodeSignedVarint = _SignedVarintDecoder(64, long)

# Use these versions for values which must be limited to 32 bits.
_DecodeVarint32 = _VarintDecoder((1 << 32) - 1, int)
_DecodeSignedVarint32 = _SignedVarintDecoder(32, int)
172
173
def ReadTag(buffer, pos):
  """Read a tag from the memoryview, and return a (tag_bytes, new_pos) tuple.

  The tag is returned as raw bytes instead of a decoded integer.  Those bytes
  can be used directly as a key to look up the proper decoder, which trades
  work that would be done in pure Python (decoding a varint) for work that is
  done in C (searching for a byte string in a hash table).  In a low-level
  language it would be much cheaper to decode the varint and use that, but
  not in Python.

  Args:
    buffer: memoryview object of the encoded bytes
    pos: int of the current position to start from

  Returns:
    Tuple[bytes, int] of the tag data and new position.
  """
  tag_start = pos
  # Consume bytes up to and including the first one whose continuation bit
  # (0x80) is clear.
  while True:
    has_more = six.indexbytes(buffer, pos) & 0x80
    pos += 1
    if not has_more:
      break

  return buffer[tag_start:pos].tobytes(), pos
199
200# --------------------------------------------------------------------
201
202
def _SimpleDecoder(wire_type, decode_value):
  """Return a constructor for a decoder for fields of a particular type.

  Args:
      wire_type:  The field's wire type.
      decode_value:  A function which decodes an individual value, e.g.
        _DecodeVarint().  It is called as decode_value(buffer, pos) and must
        return a (value, new_pos) pair.

  Returns:
      A decoder constructor taking (field_number, is_repeated, is_packed, key,
      new_default, clear_if_default=False) and returning a decoder with the
      signature (buffer, pos, end, message, field_dict) -> new_pos.
  """

  def SpecificDecoder(field_number, is_repeated, is_packed, key, new_default,
                      clear_if_default=False):
    if is_packed:
      read_length = _DecodeVarint

      def DecodePackedField(buffer, pos, end, message, field_dict):
        elements = field_dict.get(key)
        if elements is None:
          elements = field_dict.setdefault(key, new_default(message))
        # A packed field is a length prefix followed by back-to-back values.
        (payload_size, pos) = read_length(buffer, pos)
        limit = pos + payload_size
        if limit > end:
          raise _DecodeError('Truncated message.')
        while pos < limit:
          (item, pos) = decode_value(buffer, pos)
          elements.append(item)
        if pos > limit:
          # The last value ran past the declared payload.
          del elements[-1]   # Discard corrupt value.
          raise _DecodeError('Packed element was truncated.')
        return pos
      return DecodePackedField

    if is_repeated:
      tag_bytes = encoder.TagBytes(field_number, wire_type)
      tag_len = len(tag_bytes)

      def DecodeRepeatedField(buffer, pos, end, message, field_dict):
        elements = field_dict.get(key)
        if elements is None:
          elements = field_dict.setdefault(key, new_default(message))
        while True:
          (item, value_end) = decode_value(buffer, pos)
          elements.append(item)
          # Predict that the next tag is another copy of the same repeated
          # field; if so, skip over it and decode the next value directly.
          pos = value_end + tag_len
          if buffer[value_end:pos] == tag_bytes and value_end < end:
            continue
          # Prediction failed.  Return.
          if value_end > end:
            raise _DecodeError('Truncated message.')
          return value_end
      return DecodeRepeatedField

    def DecodeField(buffer, pos, end, message, field_dict):
      (decoded, pos) = decode_value(buffer, pos)
      if pos > end:
        raise _DecodeError('Truncated message.')
      if clear_if_default and not decoded:
        # A falsy value is represented by the key being absent.
        field_dict.pop(key, None)
      else:
        field_dict[key] = decoded
      return pos
    return DecodeField

  return SpecificDecoder
264
265
def _ModifiedDecoder(wire_type, decode_value, modify_value):
  """Like _SimpleDecoder but additionally passes every decoded value through
  modify_value before storing it.  Usually modify_value is ZigZagDecode.

  Args:
      wire_type:  The field's wire type.
      decode_value:  A function which decodes an individual value and returns
        a (value, new_pos) pair.
      modify_value:  A function applied to each decoded value.
  """

  # Delegating to _SimpleDecoder is slightly slower than duplicating its code
  # here, but not enough to make a significant difference.

  def InnerDecode(buffer, pos):
    raw_value, next_pos = decode_value(buffer, pos)
    return (modify_value(raw_value), next_pos)

  return _SimpleDecoder(wire_type, InnerDecode)
278
279
def _StructPackDecoder(wire_type, format):
  """Return a constructor for a decoder for a fixed-width field.

  Args:
      wire_type:  The field's wire type.
      format:  The format string to pass to struct.unpack().
  """

  unpack = struct.unpack
  width = struct.calcsize(format)

  # Delegating to _SimpleDecoder is slightly slower than duplicating its code
  # here, but not enough to make a significant difference.

  # struct.error is deliberately not handled here: someone up-stack is
  # expected to catch it and convert it to _DecodeError, so that no
  # exception-handling block has to be set up for every single value.

  def InnerDecode(buffer, pos):
    value_end = pos + width
    unpacked = unpack(format, buffer[pos:value_end])
    return (unpacked[0], value_end)

  return _SimpleDecoder(wire_type, InnerDecode)
303
304
def _FloatDecoder():
  """Returns a decoder for a float field.

  Non-finite 32-bit values are recognized from their raw bytes and mapped to
  the module-level NaN/infinity constants, which works around a bug in
  struct.unpack for non-finite 32-bit floating-point values.
  """

  unpack = struct.unpack

  def InnerDecode(buffer, pos):
    """Decode serialized float to a float and new position.

    Args:
      buffer: memoryview of the serialized bytes
      pos: int, position in the memory view to start at.

    Returns:
      Tuple[float, int] of the deserialized float value and new position
      in the serialized data.
    """
    # We expect a 32-bit value in little-endian byte order.  Bit 1 is the sign
    # bit, bits 2-9 represent the exponent, and bits 10-32 are the significand.
    end_pos = pos + 4
    raw = buffer[pos:end_pos].tobytes()

    # If this value has all its exponent bits set, then it's non-finite.
    # In Python 2.4, struct.unpack will convert it to a finite 64-bit value.
    # To avoid that, such values are parsed specially below.
    exponent_all_ones = raw[3:4] in b'\x7F\xFF' and raw[2:3] >= b'\x80'
    if not exponent_all_ones:
      # Note that we expect someone up-stack to catch struct.error and convert
      # it to _DecodeError -- this way we don't have to set up exception-
      # handling blocks every time we parse one value.
      return (unpack('<f', raw)[0], end_pos)

    # If at least one significand bit is set...
    if raw[0:3] != b'\x00\x00\x80':
      return (_NAN, end_pos)
    # If sign bit is set...
    if raw[3:4] == b'\xFF':
      return (_NEG_INF, end_pos)
    return (_POS_INF, end_pos)

  return _SimpleDecoder(wire_format.WIRETYPE_FIXED32, InnerDecode)
348
349
def _DoubleDecoder():
  """Returns a decoder for a double field.

  Not-a-number values are recognized from their raw bytes and mapped to the
  module-level NaN constant, which works around a bug in struct.unpack for
  not-a-number.
  """

  unpack = struct.unpack

  def InnerDecode(buffer, pos):
    """Decode serialized double to a double and new position.

    Args:
      buffer: memoryview of the serialized bytes.
      pos: int, position in the memory view to start at.

    Returns:
      Tuple[float, int] of the decoded double value and new position
      in the serialized data.
    """
    # We expect a 64-bit value in little-endian byte order.  Bit 1 is the sign
    # bit, bits 2-12 represent the exponent, and bits 13-64 are the significand.
    end_pos = pos + 8
    raw = buffer[pos:end_pos].tobytes()

    # If this value has all its exponent bits set and at least one significand
    # bit set, it's not a number.  In Python 2.4, struct.unpack will treat it
    # as inf or -inf.  To avoid that, we treat it specially.
    is_nan = (raw[7:8] in b'\x7F\xFF'
              and raw[6:7] >= b'\xF0'
              and raw[0:7] != b'\x00\x00\x00\x00\x00\x00\xF0')
    if is_nan:
      return (_NAN, end_pos)

    # Note that we expect someone up-stack to catch struct.error and convert
    # it to _DecodeError -- this way we don't have to set up exception-
    # handling blocks every time we parse one value.
    return (unpack('<d', raw)[0], end_pos)

  return _SimpleDecoder(wire_format.WIRETYPE_FIXED64, InnerDecode)
388
389
def EnumDecoder(field_number, is_repeated, is_packed, key, new_default,
                clear_if_default=False):
  """Returns a decoder for enum field.

  Values whose number is not present in the field's enum type are not stored
  in the field; they are recorded in the message's unknown fields instead.

  Args:
    field_number: int, the field number of the enum field.
    is_repeated: bool, whether the field is repeated.
    is_packed: bool, whether the field uses the packed encoding.
    key: key under which the field is stored in field_dict; its enum_type
      attribute supplies the known enum numbers.
    new_default: function taking the message and returning a new container
      for the field (called for packed/repeated fields when field_dict has no
      entry yet).
    clear_if_default: bool; when true, a singular field that decodes to zero
      is removed from field_dict instead of being stored.

  Returns:
    A decoder function (buffer, pos, end, message, field_dict) -> new_pos.
  """
  enum_type = key.enum_type
  # The tag depends only on the field number, so build it once here instead of
  # every time an unknown value is encountered.
  tag_bytes = encoder.TagBytes(field_number, wire_format.WIRETYPE_VARINT)

  def _AddUnknownValue(message, raw_value, enum_value):
    """Records an unrecognized enum value in both unknown-field stores."""
    # pylint: disable=protected-access
    if not message._unknown_fields:
      message._unknown_fields = []
    message._unknown_fields.append((tag_bytes, raw_value))
    if message._unknown_field_set is None:
      message._unknown_field_set = containers.UnknownFieldSet()
    message._unknown_field_set._add(
        field_number, wire_format.WIRETYPE_VARINT, enum_value)
    # pylint: enable=protected-access

  if is_packed:
    local_DecodeVarint = _DecodeVarint
    def DecodePackedField(buffer, pos, end, message, field_dict):
      """Decode serialized packed enum to its value and a new position.

      Args:
        buffer: memoryview of the serialized bytes.
        pos: int, position in the memory view to start at.
        end: int, end position of serialized data
        message: Message object to store unknown fields in
        field_dict: Map[Descriptor, Any] to store decoded values in.

      Returns:
        int, new position in serialized data.
      """
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      # The payload is a length prefix followed by back-to-back varints.
      (endpoint, pos) = local_DecodeVarint(buffer, pos)
      endpoint += pos
      if endpoint > end:
        raise _DecodeError('Truncated message.')
      while pos < endpoint:
        value_start_pos = pos
        (element, pos) = _DecodeSignedVarint32(buffer, pos)
        if element in enum_type.values_by_number:
          value.append(element)
        else:
          _AddUnknownValue(
              message, buffer[value_start_pos:pos].tobytes(), element)
      if pos > endpoint:
        # The last element ran past the declared payload: undo whichever
        # store it caused before reporting the error.
        if element in enum_type.values_by_number:
          del value[-1]   # Discard corrupt value.
        else:
          del message._unknown_fields[-1]
          # pylint: disable=protected-access
          del message._unknown_field_set._values[-1]
          # pylint: enable=protected-access
        raise _DecodeError('Packed element was truncated.')
      return pos
    return DecodePackedField
  elif is_repeated:
    tag_len = len(tag_bytes)
    def DecodeRepeatedField(buffer, pos, end, message, field_dict):
      """Decode serialized repeated enum to its value and a new position.

      Args:
        buffer: memoryview of the serialized bytes.
        pos: int, position in the memory view to start at.
        end: int, end position of serialized data
        message: Message object to store unknown fields in
        field_dict: Map[Descriptor, Any] to store decoded values in.

      Returns:
        int, new position in serialized data.
      """
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      while 1:
        (element, new_pos) = _DecodeSignedVarint32(buffer, pos)
        if element in enum_type.values_by_number:
          value.append(element)
        else:
          _AddUnknownValue(message, buffer[pos:new_pos].tobytes(), element)
        # Predict that the next tag is another copy of the same repeated
        # field.
        pos = new_pos + tag_len
        if buffer[new_pos:pos] != tag_bytes or new_pos >= end:
          # Prediction failed.  Return.
          if new_pos > end:
            raise _DecodeError('Truncated message.')
          return new_pos
    return DecodeRepeatedField
  else:
    def DecodeField(buffer, pos, end, message, field_dict):
      """Decode serialized singular enum to its value and a new position.

      Args:
        buffer: memoryview of the serialized bytes.
        pos: int, position in the memory view to start at.
        end: int, end position of serialized data
        message: Message object to store unknown fields in
        field_dict: Map[Descriptor, Any] to store decoded values in.

      Returns:
        int, new position in serialized data.
      """
      value_start_pos = pos
      (enum_value, pos) = _DecodeSignedVarint32(buffer, pos)
      if pos > end:
        raise _DecodeError('Truncated message.')
      if clear_if_default and not enum_value:
        # A zero value is represented by the key being absent.
        field_dict.pop(key, None)
        return pos
      if enum_value in enum_type.values_by_number:
        field_dict[key] = enum_value
      else:
        _AddUnknownValue(
            message, buffer[value_start_pos:pos].tobytes(), enum_value)
      return pos
    return DecodeField
527
528
529# --------------------------------------------------------------------
530
531
# Decoder constructors for values read as signed varints (32- and 64-bit).
Int32Decoder = _SimpleDecoder(
    wire_format.WIRETYPE_VARINT, _DecodeSignedVarint32)

Int64Decoder = _SimpleDecoder(
    wire_format.WIRETYPE_VARINT, _DecodeSignedVarint)

# Decoder constructors for values read as unsigned varints.
UInt32Decoder = _SimpleDecoder(wire_format.WIRETYPE_VARINT, _DecodeVarint32)
UInt64Decoder = _SimpleDecoder(wire_format.WIRETYPE_VARINT, _DecodeVarint)

# These read an unsigned varint and then pass it through
# wire_format.ZigZagDecode.
SInt32Decoder = _ModifiedDecoder(
    wire_format.WIRETYPE_VARINT, _DecodeVarint32, wire_format.ZigZagDecode)
SInt64Decoder = _ModifiedDecoder(
    wire_format.WIRETYPE_VARINT, _DecodeVarint, wire_format.ZigZagDecode)

# Note that Python conveniently guarantees that when using the '<' prefix on
# formats, they will also have the same size across all platforms (as opposed
# to without the prefix, where their sizes depend on the C compiler's basic
# type sizes).
Fixed32Decoder  = _StructPackDecoder(wire_format.WIRETYPE_FIXED32, '<I')
Fixed64Decoder  = _StructPackDecoder(wire_format.WIRETYPE_FIXED64, '<Q')
SFixed32Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED32, '<i')
SFixed64Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED64, '<q')
# Float and double use dedicated constructors (_FloatDecoder and
# _DoubleDecoder) instead of _StructPackDecoder.
FloatDecoder = _FloatDecoder()
DoubleDecoder = _DoubleDecoder()

# Booleans are read as a varint and converted with bool().
BoolDecoder = _ModifiedDecoder(
    wire_format.WIRETYPE_VARINT, _DecodeVarint, bool)
559
560
def StringDecoder(field_number, is_repeated, is_packed, key, new_default,
                  is_strict_utf8=False, clear_if_default=False):
  """Returns a decoder for a string field.

  Args:
    field_number: int, the field number of the string field.
    is_repeated: bool, whether the field is repeated.
    is_packed: bool, must be false; string fields cannot be packed.
    key: key under which the field is stored in field_dict; its full_name
      attribute is used in error messages.
    new_default: function taking the message and returning a new container
      for the field (called for repeated fields when field_dict has no entry
      yet).
    is_strict_utf8: bool; when true, on Python 2 builds whose sys.maxunicode
      exceeds _UCS2_MAXUNICODE, decoded text containing surrogate code points
      is rejected with a DecodeError.
    clear_if_default: bool; when true, a singular field whose encoded length
      is zero is removed from field_dict instead of being stored.

  Returns:
    A decoder function (buffer, pos, end, message, field_dict) -> new_pos.
  """

  local_DecodeVarint = _DecodeVarint
  local_unicode = six.text_type

  def _ConvertToUnicode(memview):
    """Convert byte to unicode."""
    byte_str = memview.tobytes()
    try:
      value = local_unicode(byte_str, 'utf-8')
    except UnicodeDecodeError as e:
      # add more information to the error message and re-raise it.
      e.reason = '%s in field: %s' % (e, key.full_name)
      raise

    if is_strict_utf8 and six.PY2 and sys.maxunicode > _UCS2_MAXUNICODE:
      # Only do the check for python2 ucs4 when is_strict_utf8 enabled
      if _SURROGATE_PATTERN.search(value):
        # Each fragment ends with a space so the implicitly concatenated
        # literals form readable sentences.
        reason = ('String field %s contains invalid UTF-8 data when parsing '
                  'a protocol buffer: surrogates not allowed. Use '
                  'the bytes type if you intend to send raw bytes.') % (
                      key.full_name)
        # Use the module-level alias, like every other raise in this file.
        raise _DecodeError(reason)

    return value

  assert not is_packed
  if is_repeated:
    tag_bytes = encoder.TagBytes(field_number,
                                 wire_format.WIRETYPE_LENGTH_DELIMITED)
    tag_len = len(tag_bytes)
    def DecodeRepeatedField(buffer, pos, end, message, field_dict):
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      while 1:
        # Each element is a varint length followed by that many bytes.
        (size, pos) = local_DecodeVarint(buffer, pos)
        new_pos = pos + size
        if new_pos > end:
          raise _DecodeError('Truncated string.')
        value.append(_ConvertToUnicode(buffer[pos:new_pos]))
        # Predict that the next tag is another copy of the same repeated field.
        pos = new_pos + tag_len
        if buffer[new_pos:pos] != tag_bytes or new_pos == end:
          # Prediction failed.  Return.
          return new_pos
    return DecodeRepeatedField
  else:
    def DecodeField(buffer, pos, end, message, field_dict):
      (size, pos) = local_DecodeVarint(buffer, pos)
      new_pos = pos + size
      if new_pos > end:
        raise _DecodeError('Truncated string.')
      if clear_if_default and not size:
        # An empty string is represented by the key being absent.
        field_dict.pop(key, None)
      else:
        field_dict[key] = _ConvertToUnicode(buffer[pos:new_pos])
      return new_pos
    return DecodeField
622
623
def BytesDecoder(field_number, is_repeated, is_packed, key, new_default,
                 clear_if_default=False):
  """Returns a decoder for a bytes field.

  Each value is encoded as a varint length followed by that many raw bytes,
  which are copied out of the buffer with tobytes().

  Args:
    field_number: int, the field number of the bytes field.
    is_repeated: bool, whether the field is repeated.
    is_packed: bool, must be false; bytes fields cannot be packed.
    key: key under which the field is stored in field_dict.
    new_default: function taking the message and returning a new container
      for the field (called for repeated fields when field_dict has no entry
      yet).
    clear_if_default: bool; when true, a singular field whose encoded length
      is zero is removed from field_dict instead of being stored.

  Returns:
    A decoder function (buffer, pos, end, message, field_dict) -> new_pos.
  """

  read_length = _DecodeVarint

  assert not is_packed
  if not is_repeated:
    def DecodeField(buffer, pos, end, message, field_dict):
      (length, pos) = read_length(buffer, pos)
      value_end = pos + length
      if value_end > end:
        raise _DecodeError('Truncated string.')
      if clear_if_default and not length:
        # An empty value is represented by the key being absent.
        field_dict.pop(key, None)
      else:
        field_dict[key] = buffer[pos:value_end].tobytes()
      return value_end
    return DecodeField

  tag_bytes = encoder.TagBytes(field_number,
                               wire_format.WIRETYPE_LENGTH_DELIMITED)
  tag_len = len(tag_bytes)

  def DecodeRepeatedField(buffer, pos, end, message, field_dict):
    chunks = field_dict.get(key)
    if chunks is None:
      chunks = field_dict.setdefault(key, new_default(message))
    while True:
      (length, pos) = read_length(buffer, pos)
      value_end = pos + length
      if value_end > end:
        raise _DecodeError('Truncated string.')
      chunks.append(buffer[pos:value_end].tobytes())
      # Predict that the next tag is another copy of the same repeated field;
      # if so, skip over it and decode the next element directly.
      pos = value_end + tag_len
      if buffer[value_end:pos] == tag_bytes and value_end != end:
        continue
      # Prediction failed.  Return.
      return value_end
  return DecodeRepeatedField
663
664
def GroupDecoder(field_number, is_repeated, is_packed, key, new_default):
  """Returns a decoder for a group field.

  A group carries no length prefix: the sub-message's _InternalParse is run
  from the current position, and the bytes immediately following the position
  it returns must be this field's end-group tag.

  Args:
    field_number: int, the field number of the group field.
    is_repeated: bool, whether the field is repeated.
    is_packed: bool, must be false; group fields cannot be packed.
    key: key under which the field is stored in field_dict.
    new_default: function taking the message and returning a new default
      value for the field (called when field_dict has no entry yet).

  Returns:
    A decoder function (buffer, pos, end, message, field_dict) -> new_pos.
  """

  end_tag_bytes = encoder.TagBytes(field_number,
                                   wire_format.WIRETYPE_END_GROUP)
  end_tag_len = len(end_tag_bytes)

  assert not is_packed
  if is_repeated:
    tag_bytes = encoder.TagBytes(field_number,
                                 wire_format.WIRETYPE_START_GROUP)
    tag_len = len(tag_bytes)
    def DecodeRepeatedField(buffer, pos, end, message, field_dict):
      # Look the container up once; nothing in the loop below replaces the
      # entry in field_dict, so it does not need to be fetched per element.
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      while 1:
        # Read sub-message.
        pos = value.add()._InternalParse(buffer, pos, end)
        # Read end tag.
        new_pos = pos+end_tag_len
        if buffer[pos:new_pos] != end_tag_bytes or new_pos > end:
          raise _DecodeError('Missing group end tag.')
        # Predict that the next tag is another copy of the same repeated field.
        pos = new_pos + tag_len
        if buffer[new_pos:pos] != tag_bytes or new_pos == end:
          # Prediction failed.  Return.
          return new_pos
    return DecodeRepeatedField
  else:
    def DecodeField(buffer, pos, end, message, field_dict):
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      # Read sub-message.
      pos = value._InternalParse(buffer, pos, end)
      # Read end tag.
      new_pos = pos+end_tag_len
      if buffer[pos:new_pos] != end_tag_bytes or new_pos > end:
        raise _DecodeError('Missing group end tag.')
      return new_pos
    return DecodeField
710
711
def MessageDecoder(field_number, is_repeated, is_packed, key, new_default):
  """Returns a decoder for a message field.

  A sub-message is encoded as a varint length followed by that many bytes,
  which are handed to the sub-message's _InternalParse.

  Args:
    field_number: int, the field number of the message field.
    is_repeated: bool, whether the field is repeated.
    is_packed: bool, must be false; message fields cannot be packed.
    key: key under which the field is stored in field_dict.
    new_default: function taking the message and returning a new default
      value for the field (called when field_dict has no entry yet).

  Returns:
    A decoder function (buffer, pos, end, message, field_dict) -> new_pos.
  """

  read_length = _DecodeVarint

  assert not is_packed
  if not is_repeated:
    def DecodeField(buffer, pos, end, message, field_dict):
      submessage = field_dict.get(key)
      if submessage is None:
        submessage = field_dict.setdefault(key, new_default(message))
      # Read length.
      (length, pos) = read_length(buffer, pos)
      body_end = pos + length
      if body_end > end:
        raise _DecodeError('Truncated message.')
      # Read sub-message.
      parsed_to = submessage._InternalParse(buffer, pos, body_end)
      if parsed_to != body_end:
        # The only reason _InternalParse would return early is if it encountered
        # an end-group tag.
        raise _DecodeError('Unexpected end-group tag.')
      return body_end
    return DecodeField

  tag_bytes = encoder.TagBytes(field_number,
                               wire_format.WIRETYPE_LENGTH_DELIMITED)
  tag_len = len(tag_bytes)

  def DecodeRepeatedField(buffer, pos, end, message, field_dict):
    submessages = field_dict.get(key)
    if submessages is None:
      submessages = field_dict.setdefault(key, new_default(message))
    while True:
      # Read length.
      (length, pos) = read_length(buffer, pos)
      body_end = pos + length
      if body_end > end:
        raise _DecodeError('Truncated message.')
      # Read sub-message.
      parsed_to = submessages.add()._InternalParse(buffer, pos, body_end)
      if parsed_to != body_end:
        # The only reason _InternalParse would return early is if it
        # encountered an end-group tag.
        raise _DecodeError('Unexpected end-group tag.')
      # Predict that the next tag is another copy of the same repeated field;
      # if so, skip over it and decode the next element directly.
      pos = body_end + tag_len
      if buffer[body_end:pos] == tag_bytes and body_end != end:
        continue
      # Prediction failed.  Return.
      return body_end
  return DecodeRepeatedField
760
761
762# --------------------------------------------------------------------
763
# Raw tag bytes for field number 1 with the start-group wire type, i.e. the
# tag that opens an Item group in the MessageSet layout shown in
# MessageSetItemDecoder's docstring.
MESSAGE_SET_ITEM_TAG = encoder.TagBytes(1, wire_format.WIRETYPE_START_GROUP)
765
def MessageSetItemDecoder(descriptor):
  """Returns a decoder for a MessageSet item.

  The parameter is the message Descriptor.  (It is not consulted by the
  decoder built here; extensions are looked up on the message being parsed.)

  The message set message looks like this:
    message MessageSet {
      repeated group Item = 1 {
        required int32 type_id = 2;
        required string message = 3;
      }
    }
  """

  type_id_tag_bytes = encoder.TagBytes(2, wire_format.WIRETYPE_VARINT)
  message_tag_bytes = encoder.TagBytes(3, wire_format.WIRETYPE_LENGTH_DELIMITED)
  item_end_tag_bytes = encoder.TagBytes(1, wire_format.WIRETYPE_END_GROUP)

  # Bind the module-level helpers to closure variables; DecodeItem uses only
  # these bindings inside its loop.
  local_ReadTag = ReadTag
  local_DecodeVarint = _DecodeVarint
  local_SkipField = SkipField

  def DecodeItem(buffer, pos, end, message, field_dict):
    """Decode serialized message set to its value and new position.

    Args:
      buffer: memoryview of the serialized bytes.
      pos: int, position in the memory view to start at.
      end: int, end position of serialized data
      message: Message object to store unknown fields in
      field_dict: Map[Descriptor, Any] to store decoded values in.

    Returns:
      int, new position in serialized data.

    Raises:
      _DecodeError: an end-group tag other than the item's own is read, the
        item runs past |end|, the item lacks a type_id or a message, or the
        extension's parser stops before the end of the message payload.
    """
    message_set_item_start = pos
    # -1 marks "not seen yet" for each of the values gathered by the loop.
    type_id = -1
    message_start = -1
    message_end = -1

    # Technically, type_id and message can appear in any order, so we need
    # a little loop here.
    while 1:
      (tag_bytes, pos) = local_ReadTag(buffer, pos)
      if tag_bytes == type_id_tag_bytes:
        (type_id, pos) = local_DecodeVarint(buffer, pos)
      elif tag_bytes == message_tag_bytes:
        # Only remember where the payload lives; it is parsed after the loop,
        # once the type_id is known.
        (size, message_start) = local_DecodeVarint(buffer, pos)
        pos = message_end = message_start + size
      elif tag_bytes == item_end_tag_bytes:
        break
      else:
        # Any other field inside the item is skipped.  A result of -1 means
        # the skipped tag was itself an end-group tag.
        pos = local_SkipField(buffer, pos, end, tag_bytes)
        if pos == -1:
          raise _DecodeError('Missing group end tag.')

    if pos > end:
      raise _DecodeError('Truncated message.')

    if type_id == -1:
      raise _DecodeError('MessageSet item missing type_id.')
    if message_start == -1:
      raise _DecodeError('MessageSet item missing message.')

    # pylint: disable=protected-access
    extension = message.Extensions._FindExtensionByNumber(type_id)
    if extension is not None:
      value = field_dict.get(extension)
      if value is None:
        message_type = extension.message_type
        if not hasattr(message_type, '_concrete_class'):
          # The result is discarded; presumably the call is made so that
          # _concrete_class gets attached to message_type -- TODO confirm.
          message._FACTORY.GetPrototype(message_type)
        value = field_dict.setdefault(
            extension, message_type._concrete_class())
      if value._InternalParse(buffer, message_start, message_end) != message_end:
        # The only reason _InternalParse would return early is if it encountered
        # an end-group tag.
        raise _DecodeError('Unexpected end-group tag.')
    else:
      # No extension is registered for this type_id: keep the raw item bytes
      # in _unknown_fields and the payload in _unknown_field_set.
      if not message._unknown_fields:
        message._unknown_fields = []
      message._unknown_fields.append(
          (MESSAGE_SET_ITEM_TAG, buffer[message_set_item_start:pos].tobytes()))
      if message._unknown_field_set is None:
        message._unknown_field_set = containers.UnknownFieldSet()
      message._unknown_field_set._add(
          type_id,
          wire_format.WIRETYPE_LENGTH_DELIMITED,
          buffer[message_start:message_end].tobytes())
    # pylint: enable=protected-access

    return pos

  return DecodeItem
861
862# --------------------------------------------------------------------
863
def MapDecoder(field_descriptor, new_default, is_message_map):
  """Builds the decoder function for a map field.

  Every map entry is read as a length-delimited sub-message whose `key` and
  `value` attributes are transferred into the field's container.

  Args:
    field_descriptor: descriptor of the map field; also used as the key into
      field_dict.
    new_default: callable taking the parent message and returning a new
      container for the field.
    is_message_map: when true, the entry's value is merged into the container
      slot with CopyFrom(); otherwise it is assigned directly.

  Returns:
    A function DecodeMap(buffer, pos, end, message, field_dict) that returns
    the new position.
  """

  entry_tag = encoder.TagBytes(field_descriptor.number,
                               wire_format.WIRETYPE_LENGTH_DELIMITED)
  entry_tag_len = len(entry_tag)
  decode_varint = _DecodeVarint
  # Can't read _concrete_class yet; might not be initialized.
  entry_type = field_descriptor.message_type

  def DecodeMap(buffer, pos, end, message, field_dict):
    # One scratch entry message is reused for every entry decoded by this call.
    entry = entry_type._concrete_class()
    container = field_dict.get(field_descriptor)
    if container is None:
      container = field_dict.setdefault(field_descriptor, new_default(message))
    while True:
      # Each entry starts with its length as a varint.
      size, pos = decode_varint(buffer, pos)
      entry_end = pos + size
      if entry_end > end:
        raise _DecodeError('Truncated message.')
      # Parse the entry into the cleared scratch message.
      entry.Clear()
      if entry._InternalParse(buffer, pos, entry_end) != entry_end:
        # The only reason _InternalParse would return early is if it
        # encountered an end-group tag.
        raise _DecodeError('Unexpected end-group tag.')

      if is_message_map:
        container[entry.key].CopyFrom(entry.value)
      else:
        container[entry.key] = entry.value

      # Guess that another entry of this same field follows immediately; if
      # the next bytes are not this field's tag, or the data is exhausted,
      # hand control back to the caller.
      pos = entry_end + entry_tag_len
      if buffer[entry_end:pos] != entry_tag or entry_end == end:
        return entry_end

  return DecodeMap
905
906# --------------------------------------------------------------------
907# Optimization is not as heavy here because calls to SkipField() are rare,
908# except for handling end-group tags.
909
910def _SkipVarint(buffer, pos, end):
911  """Skip a varint value.  Returns the new position."""
912  # Previously ord(buffer[pos]) raised IndexError when pos is out of range.
913  # With this code, ord(b'') raises TypeError.  Both are handled in
914  # python_message.py to generate a 'Truncated message' error.
915  while ord(buffer[pos:pos+1].tobytes()) & 0x80:
916    pos += 1
917  pos += 1
918  if pos > end:
919    raise _DecodeError('Truncated message.')
920  return pos
921
922def _SkipFixed64(buffer, pos, end):
923  """Skip a fixed64 value.  Returns the new position."""
924
925  pos += 8
926  if pos > end:
927    raise _DecodeError('Truncated message.')
928  return pos
929
930
931def _DecodeFixed64(buffer, pos):
932  """Decode a fixed64."""
933  new_pos = pos + 8
934  return (struct.unpack('<Q', buffer[pos:new_pos])[0], new_pos)
935
936
def _SkipLengthDelimited(buffer, pos, end):
  """Skip a length-delimited value.  Returns the new position.

  The value is a varint length followed by that many payload bytes.  Raises
  _DecodeError if the payload would extend beyond |end|.
  """

  length, payload_start = _DecodeVarint(buffer, pos)
  payload_end = payload_start + length
  if payload_end > end:
    raise _DecodeError('Truncated message.')
  return payload_end
945
946
def _SkipGroup(buffer, pos, end):
  """Skip sub-group.  Returns the new position.

  Fields are skipped one after another until SkipField reports an end-group
  tag by returning -1; the position just past that tag is returned.
  """

  while True:
    tag_bytes, after_tag = ReadTag(buffer, pos)
    pos = SkipField(buffer, after_tag, end, tag_bytes)
    if pos == -1:
      return after_tag
956
957
def _DecodeUnknownFieldSet(buffer, pos, end_pos=None):
  """Decode UnknownFieldSet.  Returns the UnknownFieldSet and new position.

  Fields are read starting at |pos| until an end-group tag has been read or,
  when |end_pos| is given, until the position is no longer below it.
  """

  decoded_fields = containers.UnknownFieldSet()
  while end_pos is None or pos < end_pos:
    tag_bytes, pos = ReadTag(buffer, pos)
    # The tag bytes are themselves a varint holding field number + wire type.
    tag = _DecodeVarint(tag_bytes, 0)[0]
    field_number, wire_type = wire_format.UnpackTag(tag)
    if wire_type == wire_format.WIRETYPE_END_GROUP:
      # Stop here; the returned position is just past this end-group tag.
      break
    data, pos = _DecodeUnknownField(buffer, pos, wire_type)
    # pylint: disable=protected-access
    decoded_fields._add(field_number, wire_type, data)

  return (decoded_fields, pos)
973
974
def _DecodeUnknownField(buffer, pos, wire_type):
  """Decode an unknown field.  Returns the field data and new position.

  Args:
    buffer: memoryview of the serialized bytes.
    pos: int, position of the first byte of the field's value (after the tag).
    wire_type: int, wire type taken from the field's tag.

  Returns:
    A (data, new_pos) tuple.  |data| is an int for varint and fixed-width
    values, bytes for length-delimited values, and an UnknownFieldSet for
    groups.  For an end-group wire type the result is (0, -1).

  Raises:
    _DecodeError: the wire type is not recognized, or a length-delimited
      value claims more bytes than |buffer| contains.
  """

  if wire_type == wire_format.WIRETYPE_VARINT:
    (data, pos) = _DecodeVarint(buffer, pos)
  elif wire_type == wire_format.WIRETYPE_FIXED64:
    (data, pos) = _DecodeFixed64(buffer, pos)
  elif wire_type == wire_format.WIRETYPE_FIXED32:
    (data, pos) = _DecodeFixed32(buffer, pos)
  elif wire_type == wire_format.WIRETYPE_LENGTH_DELIMITED:
    (size, pos) = _DecodeVarint(buffer, pos)
    new_pos = pos + size
    # Slicing past the end of the buffer silently yields fewer bytes, so an
    # over-long length must be rejected explicitly.
    if new_pos > len(buffer):
      raise _DecodeError('Truncated message.')
    data = buffer[pos:new_pos].tobytes()
    pos = new_pos
  elif wire_type == wire_format.WIRETYPE_START_GROUP:
    (data, pos) = _DecodeUnknownFieldSet(buffer, pos)
  elif wire_type == wire_format.WIRETYPE_END_GROUP:
    return (0, -1)
  else:
    raise _DecodeError('Wrong wire type in tag.')

  return (data, pos)
996
997
def _EndGroup(buffer, pos, end):
  """Skipping an END_GROUP tag returns -1 to tell the parent loop to break.

  None of the arguments is used; they are accepted so that this function has
  the same (buffer, pos, end) signature as the other skip functions.
  """

  return -1
1002
1003
1004def _SkipFixed32(buffer, pos, end):
1005  """Skip a fixed32 value.  Returns the new position."""
1006
1007  pos += 4
1008  if pos > end:
1009    raise _DecodeError('Truncated message.')
1010  return pos
1011
1012
1013def _DecodeFixed32(buffer, pos):
1014  """Decode a fixed32."""
1015
1016  new_pos = pos + 4
1017  return (struct.unpack('<I', buffer[pos:new_pos])[0], new_pos)
1018
1019
def _RaiseInvalidWireType(buffer, pos, end):
  """Skip function for unknown wire types.  Raises an exception.

  None of the arguments is used; they are accepted so that this function has
  the same (buffer, pos, end) signature as the other skip functions.

  Raises:
    _DecodeError: always.
  """

  raise _DecodeError('Tag had invalid wire type.')
1024
def _FieldSkipper():
  """Constructs the SkipField function."""

  # Skip functions, indexed by the wire-type bits of a tag.  The last two
  # slots have no real skipper and raise instead.
  skippers = (
      _SkipVarint,
      _SkipFixed64,
      _SkipLengthDelimited,
      _SkipGroup,
      _EndGroup,
      _SkipFixed32,
      _RaiseInvalidWireType,
      _RaiseInvalidWireType,
  )

  type_mask = wire_format.TAG_TYPE_MASK

  def SkipField(buffer, pos, end, tag_bytes):
    """Skips a field with the specified tag.

    |pos| should point to the byte immediately after the tag.

    Returns:
        The new position (after the tag value), or -1 if the tag is an end-group
        tag (in which case the calling loop should break).
    """

    # The wire type is always in the first byte since varints are little-endian.
    first_byte = ord(tag_bytes[0:1])
    skip = skippers[first_byte & type_mask]
    return skip(buffer, pos, end)

  return SkipField
1056
# Module-level SkipField(buffer, pos, end, tag_bytes), built once when this
# module is loaded.
SkipField = _FieldSkipper()
1058