1# Protocol Buffers - Google's data interchange format 2# Copyright 2008 Google Inc. All rights reserved. 3# https://developers.google.com/protocol-buffers/ 4# 5# Redistribution and use in source and binary forms, with or without 6# modification, are permitted provided that the following conditions are 7# met: 8# 9# * Redistributions of source code must retain the above copyright 10# notice, this list of conditions and the following disclaimer. 11# * Redistributions in binary form must reproduce the above 12# copyright notice, this list of conditions and the following disclaimer 13# in the documentation and/or other materials provided with the 14# distribution. 15# * Neither the name of Google Inc. nor the names of its 16# contributors may be used to endorse or promote products derived from 17# this software without specific prior written permission. 18# 19# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 20# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 21# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 22# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 23# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 24# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 25# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 26# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 27# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 28# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 31"""Code for decoding protocol buffer primitives. 32 33This code is very similar to encoder.py -- read the docs for that module first. 34 35A "decoder" is a function with the signature: 36 Decode(buffer, pos, end, message, field_dict) 37The arguments are: 38 buffer: The string containing the encoded message. 39 pos: The current position in the string. 40 end: The position in the string where the current message ends. May be 41 less than len(buffer) if we're reading a sub-message. 42 message: The message object into which we're parsing. 43 field_dict: message._fields (avoids a hashtable lookup). 44The decoder reads the field and stores it into field_dict, returning the new 45buffer position. A decoder for a repeated field may proactively decode all of 46the elements of that field, if they appear consecutively. 47 48Note that decoders may throw any of the following: 49 IndexError: Indicates a truncated message. 50 struct.error: Unpacking of a fixed-width field failed. 51 message.DecodeError: Other errors. 52 53Decoders are expected to raise an exception if they are called with pos > end. 54This allows callers to be lax about bounds checking: it's fineto read past 55"end" as long as you are sure that someone else will notice and throw an 56exception later on. 57 58Something up the call stack is expected to catch IndexError and struct.error 59and convert them to message.DecodeError. 60 61Decoders are constructed using decoder constructors with the signature: 62 MakeDecoder(field_number, is_repeated, is_packed, key, new_default) 63The arguments are: 64 field_number: The field number of the field we want to decode. 65 is_repeated: Is the field a repeated field? (bool) 66 is_packed: Is the field a packed field? (bool) 67 key: The key to use when looking up the field within field_dict. 68 (This is actually the FieldDescriptor but nothing in this 69 file should depend on that.) 70 new_default: A function which takes a message object as a parameter and 71 returns a new instance of the default value for this field. 72 (This is called for repeated fields and sub-messages, when an 73 instance does not already exist.) 74 75As with encoders, we define a decoder constructor for every type of field. 76Then, for every field of every message class we construct an actual decoder. 77That decoder goes into a dict indexed by tag, so when we decode a message 78we repeatedly read a tag, look up the corresponding decoder, and invoke it. 79""" 80 81__author__ = 'kenton@google.com (Kenton Varda)' 82 83import struct 84import sys 85import six 86 87_UCS2_MAXUNICODE = 65535 88if six.PY3: 89 long = int 90else: 91 import re # pylint: disable=g-import-not-at-top 92 _SURROGATE_PATTERN = re.compile(six.u(r'[\ud800-\udfff]')) 93 94from google.protobuf.internal import containers 95from google.protobuf.internal import encoder 96from google.protobuf.internal import wire_format 97from google.protobuf import message 98 99 100# This will overflow and thus become IEEE-754 "infinity". We would use 101# "float('inf')" but it doesn't work on Windows pre-Python-2.6. 102_POS_INF = 1e10000 103_NEG_INF = -_POS_INF 104_NAN = _POS_INF * 0 105 106 107# This is not for optimization, but rather to avoid conflicts with local 108# variables named "message". 109_DecodeError = message.DecodeError 110 111 112def _VarintDecoder(mask, result_type): 113 """Return an encoder for a basic varint value (does not include tag). 114 115 Decoded values will be bitwise-anded with the given mask before being 116 returned, e.g. to limit them to 32 bits. The returned decoder does not 117 take the usual "end" parameter -- the caller is expected to do bounds checking 118 after the fact (often the caller can defer such checking until later). The 119 decoder returns a (value, new_pos) pair. 120 """ 121 122 def DecodeVarint(buffer, pos): 123 result = 0 124 shift = 0 125 while 1: 126 b = six.indexbytes(buffer, pos) 127 result |= ((b & 0x7f) << shift) 128 pos += 1 129 if not (b & 0x80): 130 result &= mask 131 result = result_type(result) 132 return (result, pos) 133 shift += 7 134 if shift >= 64: 135 raise _DecodeError('Too many bytes when decoding varint.') 136 return DecodeVarint 137 138 139def _SignedVarintDecoder(bits, result_type): 140 """Like _VarintDecoder() but decodes signed values.""" 141 142 signbit = 1 << (bits - 1) 143 mask = (1 << bits) - 1 144 145 def DecodeVarint(buffer, pos): 146 result = 0 147 shift = 0 148 while 1: 149 b = six.indexbytes(buffer, pos) 150 result |= ((b & 0x7f) << shift) 151 pos += 1 152 if not (b & 0x80): 153 result &= mask 154 result = (result ^ signbit) - signbit 155 result = result_type(result) 156 return (result, pos) 157 shift += 7 158 if shift >= 64: 159 raise _DecodeError('Too many bytes when decoding varint.') 160 return DecodeVarint 161 162# We force 32-bit values to int and 64-bit values to long to make 163# alternate implementations where the distinction is more significant 164# (e.g. the C++ implementation) simpler. 165 166_DecodeVarint = _VarintDecoder((1 << 64) - 1, long) 167_DecodeSignedVarint = _SignedVarintDecoder(64, long) 168 169# Use these versions for values which must be limited to 32 bits. 170_DecodeVarint32 = _VarintDecoder((1 << 32) - 1, int) 171_DecodeSignedVarint32 = _SignedVarintDecoder(32, int) 172 173 174def ReadTag(buffer, pos): 175 """Read a tag from the memoryview, and return a (tag_bytes, new_pos) tuple. 176 177 We return the raw bytes of the tag rather than decoding them. The raw 178 bytes can then be used to look up the proper decoder. This effectively allows 179 us to trade some work that would be done in pure-python (decoding a varint) 180 for work that is done in C (searching for a byte string in a hash table). 181 In a low-level language it would be much cheaper to decode the varint and 182 use that, but not in Python. 183 184 Args: 185 buffer: memoryview object of the encoded bytes 186 pos: int of the current position to start from 187 188 Returns: 189 Tuple[bytes, int] of the tag data and new position. 190 """ 191 start = pos 192 while six.indexbytes(buffer, pos) & 0x80: 193 pos += 1 194 pos += 1 195 196 tag_bytes = buffer[start:pos].tobytes() 197 return tag_bytes, pos 198 199 200# -------------------------------------------------------------------- 201 202 203def _SimpleDecoder(wire_type, decode_value): 204 """Return a constructor for a decoder for fields of a particular type. 205 206 Args: 207 wire_type: The field's wire type. 208 decode_value: A function which decodes an individual value, e.g. 209 _DecodeVarint() 210 """ 211 212 def SpecificDecoder(field_number, is_repeated, is_packed, key, new_default): 213 if is_packed: 214 local_DecodeVarint = _DecodeVarint 215 def DecodePackedField(buffer, pos, end, message, field_dict): 216 value = field_dict.get(key) 217 if value is None: 218 value = field_dict.setdefault(key, new_default(message)) 219 (endpoint, pos) = local_DecodeVarint(buffer, pos) 220 endpoint += pos 221 if endpoint > end: 222 raise _DecodeError('Truncated message.') 223 while pos < endpoint: 224 (element, pos) = decode_value(buffer, pos) 225 value.append(element) 226 if pos > endpoint: 227 del value[-1] # Discard corrupt value. 228 raise _DecodeError('Packed element was truncated.') 229 return pos 230 return DecodePackedField 231 elif is_repeated: 232 tag_bytes = encoder.TagBytes(field_number, wire_type) 233 tag_len = len(tag_bytes) 234 def DecodeRepeatedField(buffer, pos, end, message, field_dict): 235 value = field_dict.get(key) 236 if value is None: 237 value = field_dict.setdefault(key, new_default(message)) 238 while 1: 239 (element, new_pos) = decode_value(buffer, pos) 240 value.append(element) 241 # Predict that the next tag is another copy of the same repeated 242 # field. 243 pos = new_pos + tag_len 244 if buffer[new_pos:pos] != tag_bytes or new_pos >= end: 245 # Prediction failed. Return. 246 if new_pos > end: 247 raise _DecodeError('Truncated message.') 248 return new_pos 249 return DecodeRepeatedField 250 else: 251 def DecodeField(buffer, pos, end, message, field_dict): 252 (field_dict[key], pos) = decode_value(buffer, pos) 253 if pos > end: 254 del field_dict[key] # Discard corrupt value. 255 raise _DecodeError('Truncated message.') 256 return pos 257 return DecodeField 258 259 return SpecificDecoder 260 261 262def _ModifiedDecoder(wire_type, decode_value, modify_value): 263 """Like SimpleDecoder but additionally invokes modify_value on every value 264 before storing it. Usually modify_value is ZigZagDecode. 265 """ 266 267 # Reusing _SimpleDecoder is slightly slower than copying a bunch of code, but 268 # not enough to make a significant difference. 269 270 def InnerDecode(buffer, pos): 271 (result, new_pos) = decode_value(buffer, pos) 272 return (modify_value(result), new_pos) 273 return _SimpleDecoder(wire_type, InnerDecode) 274 275 276def _StructPackDecoder(wire_type, format): 277 """Return a constructor for a decoder for a fixed-width field. 278 279 Args: 280 wire_type: The field's wire type. 281 format: The format string to pass to struct.unpack(). 282 """ 283 284 value_size = struct.calcsize(format) 285 local_unpack = struct.unpack 286 287 # Reusing _SimpleDecoder is slightly slower than copying a bunch of code, but 288 # not enough to make a significant difference. 289 290 # Note that we expect someone up-stack to catch struct.error and convert 291 # it to _DecodeError -- this way we don't have to set up exception- 292 # handling blocks every time we parse one value. 293 294 def InnerDecode(buffer, pos): 295 new_pos = pos + value_size 296 result = local_unpack(format, buffer[pos:new_pos])[0] 297 return (result, new_pos) 298 return _SimpleDecoder(wire_type, InnerDecode) 299 300 301def _FloatDecoder(): 302 """Returns a decoder for a float field. 303 304 This code works around a bug in struct.unpack for non-finite 32-bit 305 floating-point values. 306 """ 307 308 local_unpack = struct.unpack 309 310 def InnerDecode(buffer, pos): 311 """Decode serialized float to a float and new position. 312 313 Args: 314 buffer: memoryview of the serialized bytes 315 pos: int, position in the memory view to start at. 316 317 Returns: 318 Tuple[float, int] of the deserialized float value and new position 319 in the serialized data. 320 """ 321 # We expect a 32-bit value in little-endian byte order. Bit 1 is the sign 322 # bit, bits 2-9 represent the exponent, and bits 10-32 are the significand. 323 new_pos = pos + 4 324 float_bytes = buffer[pos:new_pos].tobytes() 325 326 # If this value has all its exponent bits set, then it's non-finite. 327 # In Python 2.4, struct.unpack will convert it to a finite 64-bit value. 328 # To avoid that, we parse it specially. 329 if (float_bytes[3:4] in b'\x7F\xFF' and float_bytes[2:3] >= b'\x80'): 330 # If at least one significand bit is set... 331 if float_bytes[0:3] != b'\x00\x00\x80': 332 return (_NAN, new_pos) 333 # If sign bit is set... 334 if float_bytes[3:4] == b'\xFF': 335 return (_NEG_INF, new_pos) 336 return (_POS_INF, new_pos) 337 338 # Note that we expect someone up-stack to catch struct.error and convert 339 # it to _DecodeError -- this way we don't have to set up exception- 340 # handling blocks every time we parse one value. 341 result = local_unpack('<f', float_bytes)[0] 342 return (result, new_pos) 343 return _SimpleDecoder(wire_format.WIRETYPE_FIXED32, InnerDecode) 344 345 346def _DoubleDecoder(): 347 """Returns a decoder for a double field. 348 349 This code works around a bug in struct.unpack for not-a-number. 350 """ 351 352 local_unpack = struct.unpack 353 354 def InnerDecode(buffer, pos): 355 """Decode serialized double to a double and new position. 356 357 Args: 358 buffer: memoryview of the serialized bytes. 359 pos: int, position in the memory view to start at. 360 361 Returns: 362 Tuple[float, int] of the decoded double value and new position 363 in the serialized data. 364 """ 365 # We expect a 64-bit value in little-endian byte order. Bit 1 is the sign 366 # bit, bits 2-12 represent the exponent, and bits 13-64 are the significand. 367 new_pos = pos + 8 368 double_bytes = buffer[pos:new_pos].tobytes() 369 370 # If this value has all its exponent bits set and at least one significand 371 # bit set, it's not a number. In Python 2.4, struct.unpack will treat it 372 # as inf or -inf. To avoid that, we treat it specially. 373 if ((double_bytes[7:8] in b'\x7F\xFF') 374 and (double_bytes[6:7] >= b'\xF0') 375 and (double_bytes[0:7] != b'\x00\x00\x00\x00\x00\x00\xF0')): 376 return (_NAN, new_pos) 377 378 # Note that we expect someone up-stack to catch struct.error and convert 379 # it to _DecodeError -- this way we don't have to set up exception- 380 # handling blocks every time we parse one value. 381 result = local_unpack('<d', double_bytes)[0] 382 return (result, new_pos) 383 return _SimpleDecoder(wire_format.WIRETYPE_FIXED64, InnerDecode) 384 385 386def EnumDecoder(field_number, is_repeated, is_packed, key, new_default): 387 enum_type = key.enum_type 388 if is_packed: 389 local_DecodeVarint = _DecodeVarint 390 def DecodePackedField(buffer, pos, end, message, field_dict): 391 """Decode serialized packed enum to its value and a new position. 392 393 Args: 394 buffer: memoryview of the serialized bytes. 395 pos: int, position in the memory view to start at. 396 end: int, end position of serialized data 397 message: Message object to store unknown fields in 398 field_dict: Map[Descriptor, Any] to store decoded values in. 399 400 Returns: 401 int, new position in serialized data. 402 """ 403 value = field_dict.get(key) 404 if value is None: 405 value = field_dict.setdefault(key, new_default(message)) 406 (endpoint, pos) = local_DecodeVarint(buffer, pos) 407 endpoint += pos 408 if endpoint > end: 409 raise _DecodeError('Truncated message.') 410 while pos < endpoint: 411 value_start_pos = pos 412 (element, pos) = _DecodeSignedVarint32(buffer, pos) 413 # pylint: disable=protected-access 414 if element in enum_type.values_by_number: 415 value.append(element) 416 else: 417 if not message._unknown_fields: 418 message._unknown_fields = [] 419 tag_bytes = encoder.TagBytes(field_number, 420 wire_format.WIRETYPE_VARINT) 421 422 message._unknown_fields.append( 423 (tag_bytes, buffer[value_start_pos:pos].tobytes())) 424 if message._unknown_field_set is None: 425 message._unknown_field_set = containers.UnknownFieldSet() 426 message._unknown_field_set._add( 427 field_number, wire_format.WIRETYPE_VARINT, element) 428 # pylint: enable=protected-access 429 if pos > endpoint: 430 if element in enum_type.values_by_number: 431 del value[-1] # Discard corrupt value. 432 else: 433 del message._unknown_fields[-1] 434 # pylint: disable=protected-access 435 del message._unknown_field_set._values[-1] 436 # pylint: enable=protected-access 437 raise _DecodeError('Packed element was truncated.') 438 return pos 439 return DecodePackedField 440 elif is_repeated: 441 tag_bytes = encoder.TagBytes(field_number, wire_format.WIRETYPE_VARINT) 442 tag_len = len(tag_bytes) 443 def DecodeRepeatedField(buffer, pos, end, message, field_dict): 444 """Decode serialized repeated enum to its value and a new position. 445 446 Args: 447 buffer: memoryview of the serialized bytes. 448 pos: int, position in the memory view to start at. 449 end: int, end position of serialized data 450 message: Message object to store unknown fields in 451 field_dict: Map[Descriptor, Any] to store decoded values in. 452 453 Returns: 454 int, new position in serialized data. 455 """ 456 value = field_dict.get(key) 457 if value is None: 458 value = field_dict.setdefault(key, new_default(message)) 459 while 1: 460 (element, new_pos) = _DecodeSignedVarint32(buffer, pos) 461 # pylint: disable=protected-access 462 if element in enum_type.values_by_number: 463 value.append(element) 464 else: 465 if not message._unknown_fields: 466 message._unknown_fields = [] 467 message._unknown_fields.append( 468 (tag_bytes, buffer[pos:new_pos].tobytes())) 469 if message._unknown_field_set is None: 470 message._unknown_field_set = containers.UnknownFieldSet() 471 message._unknown_field_set._add( 472 field_number, wire_format.WIRETYPE_VARINT, element) 473 # pylint: enable=protected-access 474 # Predict that the next tag is another copy of the same repeated 475 # field. 476 pos = new_pos + tag_len 477 if buffer[new_pos:pos] != tag_bytes or new_pos >= end: 478 # Prediction failed. Return. 479 if new_pos > end: 480 raise _DecodeError('Truncated message.') 481 return new_pos 482 return DecodeRepeatedField 483 else: 484 def DecodeField(buffer, pos, end, message, field_dict): 485 """Decode serialized repeated enum to its value and a new position. 486 487 Args: 488 buffer: memoryview of the serialized bytes. 489 pos: int, position in the memory view to start at. 490 end: int, end position of serialized data 491 message: Message object to store unknown fields in 492 field_dict: Map[Descriptor, Any] to store decoded values in. 493 494 Returns: 495 int, new position in serialized data. 496 """ 497 value_start_pos = pos 498 (enum_value, pos) = _DecodeSignedVarint32(buffer, pos) 499 if pos > end: 500 raise _DecodeError('Truncated message.') 501 # pylint: disable=protected-access 502 if enum_value in enum_type.values_by_number: 503 field_dict[key] = enum_value 504 else: 505 if not message._unknown_fields: 506 message._unknown_fields = [] 507 tag_bytes = encoder.TagBytes(field_number, 508 wire_format.WIRETYPE_VARINT) 509 message._unknown_fields.append( 510 (tag_bytes, buffer[value_start_pos:pos].tobytes())) 511 if message._unknown_field_set is None: 512 message._unknown_field_set = containers.UnknownFieldSet() 513 message._unknown_field_set._add( 514 field_number, wire_format.WIRETYPE_VARINT, enum_value) 515 # pylint: enable=protected-access 516 return pos 517 return DecodeField 518 519 520# -------------------------------------------------------------------- 521 522 523Int32Decoder = _SimpleDecoder( 524 wire_format.WIRETYPE_VARINT, _DecodeSignedVarint32) 525 526Int64Decoder = _SimpleDecoder( 527 wire_format.WIRETYPE_VARINT, _DecodeSignedVarint) 528 529UInt32Decoder = _SimpleDecoder(wire_format.WIRETYPE_VARINT, _DecodeVarint32) 530UInt64Decoder = _SimpleDecoder(wire_format.WIRETYPE_VARINT, _DecodeVarint) 531 532SInt32Decoder = _ModifiedDecoder( 533 wire_format.WIRETYPE_VARINT, _DecodeVarint32, wire_format.ZigZagDecode) 534SInt64Decoder = _ModifiedDecoder( 535 wire_format.WIRETYPE_VARINT, _DecodeVarint, wire_format.ZigZagDecode) 536 537# Note that Python conveniently guarantees that when using the '<' prefix on 538# formats, they will also have the same size across all platforms (as opposed 539# to without the prefix, where their sizes depend on the C compiler's basic 540# type sizes). 541Fixed32Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED32, '<I') 542Fixed64Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED64, '<Q') 543SFixed32Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED32, '<i') 544SFixed64Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED64, '<q') 545FloatDecoder = _FloatDecoder() 546DoubleDecoder = _DoubleDecoder() 547 548BoolDecoder = _ModifiedDecoder( 549 wire_format.WIRETYPE_VARINT, _DecodeVarint, bool) 550 551 552def StringDecoder(field_number, is_repeated, is_packed, key, new_default, 553 is_strict_utf8=False): 554 """Returns a decoder for a string field.""" 555 556 local_DecodeVarint = _DecodeVarint 557 local_unicode = six.text_type 558 559 def _ConvertToUnicode(memview): 560 """Convert byte to unicode.""" 561 byte_str = memview.tobytes() 562 try: 563 value = local_unicode(byte_str, 'utf-8') 564 except UnicodeDecodeError as e: 565 # add more information to the error message and re-raise it. 566 e.reason = '%s in field: %s' % (e, key.full_name) 567 raise 568 569 if is_strict_utf8 and six.PY2 and sys.maxunicode > _UCS2_MAXUNICODE: 570 # Only do the check for python2 ucs4 when is_strict_utf8 enabled 571 if _SURROGATE_PATTERN.search(value): 572 reason = ('String field %s contains invalid UTF-8 data when parsing' 573 'a protocol buffer: surrogates not allowed. Use' 574 'the bytes type if you intend to send raw bytes.') % ( 575 key.full_name) 576 raise message.DecodeError(reason) 577 578 return value 579 580 assert not is_packed 581 if is_repeated: 582 tag_bytes = encoder.TagBytes(field_number, 583 wire_format.WIRETYPE_LENGTH_DELIMITED) 584 tag_len = len(tag_bytes) 585 def DecodeRepeatedField(buffer, pos, end, message, field_dict): 586 value = field_dict.get(key) 587 if value is None: 588 value = field_dict.setdefault(key, new_default(message)) 589 while 1: 590 (size, pos) = local_DecodeVarint(buffer, pos) 591 new_pos = pos + size 592 if new_pos > end: 593 raise _DecodeError('Truncated string.') 594 value.append(_ConvertToUnicode(buffer[pos:new_pos])) 595 # Predict that the next tag is another copy of the same repeated field. 596 pos = new_pos + tag_len 597 if buffer[new_pos:pos] != tag_bytes or new_pos == end: 598 # Prediction failed. Return. 599 return new_pos 600 return DecodeRepeatedField 601 else: 602 def DecodeField(buffer, pos, end, message, field_dict): 603 (size, pos) = local_DecodeVarint(buffer, pos) 604 new_pos = pos + size 605 if new_pos > end: 606 raise _DecodeError('Truncated string.') 607 field_dict[key] = _ConvertToUnicode(buffer[pos:new_pos]) 608 return new_pos 609 return DecodeField 610 611 612def BytesDecoder(field_number, is_repeated, is_packed, key, new_default): 613 """Returns a decoder for a bytes field.""" 614 615 local_DecodeVarint = _DecodeVarint 616 617 assert not is_packed 618 if is_repeated: 619 tag_bytes = encoder.TagBytes(field_number, 620 wire_format.WIRETYPE_LENGTH_DELIMITED) 621 tag_len = len(tag_bytes) 622 def DecodeRepeatedField(buffer, pos, end, message, field_dict): 623 value = field_dict.get(key) 624 if value is None: 625 value = field_dict.setdefault(key, new_default(message)) 626 while 1: 627 (size, pos) = local_DecodeVarint(buffer, pos) 628 new_pos = pos + size 629 if new_pos > end: 630 raise _DecodeError('Truncated string.') 631 value.append(buffer[pos:new_pos].tobytes()) 632 # Predict that the next tag is another copy of the same repeated field. 633 pos = new_pos + tag_len 634 if buffer[new_pos:pos] != tag_bytes or new_pos == end: 635 # Prediction failed. Return. 636 return new_pos 637 return DecodeRepeatedField 638 else: 639 def DecodeField(buffer, pos, end, message, field_dict): 640 (size, pos) = local_DecodeVarint(buffer, pos) 641 new_pos = pos + size 642 if new_pos > end: 643 raise _DecodeError('Truncated string.') 644 field_dict[key] = buffer[pos:new_pos].tobytes() 645 return new_pos 646 return DecodeField 647 648 649def GroupDecoder(field_number, is_repeated, is_packed, key, new_default): 650 """Returns a decoder for a group field.""" 651 652 end_tag_bytes = encoder.TagBytes(field_number, 653 wire_format.WIRETYPE_END_GROUP) 654 end_tag_len = len(end_tag_bytes) 655 656 assert not is_packed 657 if is_repeated: 658 tag_bytes = encoder.TagBytes(field_number, 659 wire_format.WIRETYPE_START_GROUP) 660 tag_len = len(tag_bytes) 661 def DecodeRepeatedField(buffer, pos, end, message, field_dict): 662 value = field_dict.get(key) 663 if value is None: 664 value = field_dict.setdefault(key, new_default(message)) 665 while 1: 666 value = field_dict.get(key) 667 if value is None: 668 value = field_dict.setdefault(key, new_default(message)) 669 # Read sub-message. 670 pos = value.add()._InternalParse(buffer, pos, end) 671 # Read end tag. 672 new_pos = pos+end_tag_len 673 if buffer[pos:new_pos] != end_tag_bytes or new_pos > end: 674 raise _DecodeError('Missing group end tag.') 675 # Predict that the next tag is another copy of the same repeated field. 676 pos = new_pos + tag_len 677 if buffer[new_pos:pos] != tag_bytes or new_pos == end: 678 # Prediction failed. Return. 679 return new_pos 680 return DecodeRepeatedField 681 else: 682 def DecodeField(buffer, pos, end, message, field_dict): 683 value = field_dict.get(key) 684 if value is None: 685 value = field_dict.setdefault(key, new_default(message)) 686 # Read sub-message. 687 pos = value._InternalParse(buffer, pos, end) 688 # Read end tag. 689 new_pos = pos+end_tag_len 690 if buffer[pos:new_pos] != end_tag_bytes or new_pos > end: 691 raise _DecodeError('Missing group end tag.') 692 return new_pos 693 return DecodeField 694 695 696def MessageDecoder(field_number, is_repeated, is_packed, key, new_default): 697 """Returns a decoder for a message field.""" 698 699 local_DecodeVarint = _DecodeVarint 700 701 assert not is_packed 702 if is_repeated: 703 tag_bytes = encoder.TagBytes(field_number, 704 wire_format.WIRETYPE_LENGTH_DELIMITED) 705 tag_len = len(tag_bytes) 706 def DecodeRepeatedField(buffer, pos, end, message, field_dict): 707 value = field_dict.get(key) 708 if value is None: 709 value = field_dict.setdefault(key, new_default(message)) 710 while 1: 711 # Read length. 712 (size, pos) = local_DecodeVarint(buffer, pos) 713 new_pos = pos + size 714 if new_pos > end: 715 raise _DecodeError('Truncated message.') 716 # Read sub-message. 717 if value.add()._InternalParse(buffer, pos, new_pos) != new_pos: 718 # The only reason _InternalParse would return early is if it 719 # encountered an end-group tag. 720 raise _DecodeError('Unexpected end-group tag.') 721 # Predict that the next tag is another copy of the same repeated field. 722 pos = new_pos + tag_len 723 if buffer[new_pos:pos] != tag_bytes or new_pos == end: 724 # Prediction failed. Return. 725 return new_pos 726 return DecodeRepeatedField 727 else: 728 def DecodeField(buffer, pos, end, message, field_dict): 729 value = field_dict.get(key) 730 if value is None: 731 value = field_dict.setdefault(key, new_default(message)) 732 # Read length. 733 (size, pos) = local_DecodeVarint(buffer, pos) 734 new_pos = pos + size 735 if new_pos > end: 736 raise _DecodeError('Truncated message.') 737 # Read sub-message. 738 if value._InternalParse(buffer, pos, new_pos) != new_pos: 739 # The only reason _InternalParse would return early is if it encountered 740 # an end-group tag. 741 raise _DecodeError('Unexpected end-group tag.') 742 return new_pos 743 return DecodeField 744 745 746# -------------------------------------------------------------------- 747 748MESSAGE_SET_ITEM_TAG = encoder.TagBytes(1, wire_format.WIRETYPE_START_GROUP) 749 750def MessageSetItemDecoder(descriptor): 751 """Returns a decoder for a MessageSet item. 752 753 The parameter is the message Descriptor. 754 755 The message set message looks like this: 756 message MessageSet { 757 repeated group Item = 1 { 758 required int32 type_id = 2; 759 required string message = 3; 760 } 761 } 762 """ 763 764 type_id_tag_bytes = encoder.TagBytes(2, wire_format.WIRETYPE_VARINT) 765 message_tag_bytes = encoder.TagBytes(3, wire_format.WIRETYPE_LENGTH_DELIMITED) 766 item_end_tag_bytes = encoder.TagBytes(1, wire_format.WIRETYPE_END_GROUP) 767 768 local_ReadTag = ReadTag 769 local_DecodeVarint = _DecodeVarint 770 local_SkipField = SkipField 771 772 def DecodeItem(buffer, pos, end, message, field_dict): 773 """Decode serialized message set to its value and new position. 774 775 Args: 776 buffer: memoryview of the serialized bytes. 777 pos: int, position in the memory view to start at. 778 end: int, end position of serialized data 779 message: Message object to store unknown fields in 780 field_dict: Map[Descriptor, Any] to store decoded values in. 781 782 Returns: 783 int, new position in serialized data. 784 """ 785 message_set_item_start = pos 786 type_id = -1 787 message_start = -1 788 message_end = -1 789 790 # Technically, type_id and message can appear in any order, so we need 791 # a little loop here. 792 while 1: 793 (tag_bytes, pos) = local_ReadTag(buffer, pos) 794 if tag_bytes == type_id_tag_bytes: 795 (type_id, pos) = local_DecodeVarint(buffer, pos) 796 elif tag_bytes == message_tag_bytes: 797 (size, message_start) = local_DecodeVarint(buffer, pos) 798 pos = message_end = message_start + size 799 elif tag_bytes == item_end_tag_bytes: 800 break 801 else: 802 pos = SkipField(buffer, pos, end, tag_bytes) 803 if pos == -1: 804 raise _DecodeError('Missing group end tag.') 805 806 if pos > end: 807 raise _DecodeError('Truncated message.') 808 809 if type_id == -1: 810 raise _DecodeError('MessageSet item missing type_id.') 811 if message_start == -1: 812 raise _DecodeError('MessageSet item missing message.') 813 814 extension = message.Extensions._FindExtensionByNumber(type_id) 815 # pylint: disable=protected-access 816 if extension is not None: 817 value = field_dict.get(extension) 818 if value is None: 819 message_type = extension.message_type 820 if not hasattr(message_type, '_concrete_class'): 821 # pylint: disable=protected-access 822 message._FACTORY.GetPrototype(message_type) 823 value = field_dict.setdefault( 824 extension, message_type._concrete_class()) 825 if value._InternalParse(buffer, message_start,message_end) != message_end: 826 # The only reason _InternalParse would return early is if it encountered 827 # an end-group tag. 828 raise _DecodeError('Unexpected end-group tag.') 829 else: 830 if not message._unknown_fields: 831 message._unknown_fields = [] 832 message._unknown_fields.append( 833 (MESSAGE_SET_ITEM_TAG, buffer[message_set_item_start:pos].tobytes())) 834 if message._unknown_field_set is None: 835 message._unknown_field_set = containers.UnknownFieldSet() 836 message._unknown_field_set._add( 837 type_id, 838 wire_format.WIRETYPE_LENGTH_DELIMITED, 839 buffer[message_start:message_end].tobytes()) 840 # pylint: enable=protected-access 841 842 return pos 843 844 return DecodeItem 845 846# -------------------------------------------------------------------- 847 848def MapDecoder(field_descriptor, new_default, is_message_map): 849 """Returns a decoder for a map field.""" 850 851 key = field_descriptor 852 tag_bytes = encoder.TagBytes(field_descriptor.number, 853 wire_format.WIRETYPE_LENGTH_DELIMITED) 854 tag_len = len(tag_bytes) 855 local_DecodeVarint = _DecodeVarint 856 # Can't read _concrete_class yet; might not be initialized. 857 message_type = field_descriptor.message_type 858 859 def DecodeMap(buffer, pos, end, message, field_dict): 860 submsg = message_type._concrete_class() 861 value = field_dict.get(key) 862 if value is None: 863 value = field_dict.setdefault(key, new_default(message)) 864 while 1: 865 # Read length. 866 (size, pos) = local_DecodeVarint(buffer, pos) 867 new_pos = pos + size 868 if new_pos > end: 869 raise _DecodeError('Truncated message.') 870 # Read sub-message. 871 submsg.Clear() 872 if submsg._InternalParse(buffer, pos, new_pos) != new_pos: 873 # The only reason _InternalParse would return early is if it 874 # encountered an end-group tag. 875 raise _DecodeError('Unexpected end-group tag.') 876 877 if is_message_map: 878 value[submsg.key].CopyFrom(submsg.value) 879 else: 880 value[submsg.key] = submsg.value 881 882 # Predict that the next tag is another copy of the same repeated field. 883 pos = new_pos + tag_len 884 if buffer[new_pos:pos] != tag_bytes or new_pos == end: 885 # Prediction failed. Return. 886 return new_pos 887 888 return DecodeMap 889 890# -------------------------------------------------------------------- 891# Optimization is not as heavy here because calls to SkipField() are rare, 892# except for handling end-group tags. 893 894def _SkipVarint(buffer, pos, end): 895 """Skip a varint value. Returns the new position.""" 896 # Previously ord(buffer[pos]) raised IndexError when pos is out of range. 897 # With this code, ord(b'') raises TypeError. Both are handled in 898 # python_message.py to generate a 'Truncated message' error. 899 while ord(buffer[pos:pos+1].tobytes()) & 0x80: 900 pos += 1 901 pos += 1 902 if pos > end: 903 raise _DecodeError('Truncated message.') 904 return pos 905 906def _SkipFixed64(buffer, pos, end): 907 """Skip a fixed64 value. Returns the new position.""" 908 909 pos += 8 910 if pos > end: 911 raise _DecodeError('Truncated message.') 912 return pos 913 914 915def _DecodeFixed64(buffer, pos): 916 """Decode a fixed64.""" 917 new_pos = pos + 8 918 return (struct.unpack('<Q', buffer[pos:new_pos])[0], new_pos) 919 920 921def _SkipLengthDelimited(buffer, pos, end): 922 """Skip a length-delimited value. Returns the new position.""" 923 924 (size, pos) = _DecodeVarint(buffer, pos) 925 pos += size 926 if pos > end: 927 raise _DecodeError('Truncated message.') 928 return pos 929 930 931def _SkipGroup(buffer, pos, end): 932 """Skip sub-group. Returns the new position.""" 933 934 while 1: 935 (tag_bytes, pos) = ReadTag(buffer, pos) 936 new_pos = SkipField(buffer, pos, end, tag_bytes) 937 if new_pos == -1: 938 return pos 939 pos = new_pos 940 941 942def _DecodeUnknownFieldSet(buffer, pos, end_pos=None): 943 """Decode UnknownFieldSet. Returns the UnknownFieldSet and new position.""" 944 945 unknown_field_set = containers.UnknownFieldSet() 946 while end_pos is None or pos < end_pos: 947 (tag_bytes, pos) = ReadTag(buffer, pos) 948 (tag, _) = _DecodeVarint(tag_bytes, 0) 949 field_number, wire_type = wire_format.UnpackTag(tag) 950 if wire_type == wire_format.WIRETYPE_END_GROUP: 951 break 952 (data, pos) = _DecodeUnknownField(buffer, pos, wire_type) 953 # pylint: disable=protected-access 954 unknown_field_set._add(field_number, wire_type, data) 955 956 return (unknown_field_set, pos) 957 958 959def _DecodeUnknownField(buffer, pos, wire_type): 960 """Decode a unknown field. Returns the UnknownField and new position.""" 961 962 if wire_type == wire_format.WIRETYPE_VARINT: 963 (data, pos) = _DecodeVarint(buffer, pos) 964 elif wire_type == wire_format.WIRETYPE_FIXED64: 965 (data, pos) = _DecodeFixed64(buffer, pos) 966 elif wire_type == wire_format.WIRETYPE_FIXED32: 967 (data, pos) = _DecodeFixed32(buffer, pos) 968 elif wire_type == wire_format.WIRETYPE_LENGTH_DELIMITED: 969 (size, pos) = _DecodeVarint(buffer, pos) 970 data = buffer[pos:pos+size].tobytes() 971 pos += size 972 elif wire_type == wire_format.WIRETYPE_START_GROUP: 973 (data, pos) = _DecodeUnknownFieldSet(buffer, pos) 974 elif wire_type == wire_format.WIRETYPE_END_GROUP: 975 return (0, -1) 976 else: 977 raise _DecodeError('Wrong wire type in tag.') 978 979 return (data, pos) 980 981 982def _EndGroup(buffer, pos, end): 983 """Skipping an END_GROUP tag returns -1 to tell the parent loop to break.""" 984 985 return -1 986 987 988def _SkipFixed32(buffer, pos, end): 989 """Skip a fixed32 value. Returns the new position.""" 990 991 pos += 4 992 if pos > end: 993 raise _DecodeError('Truncated message.') 994 return pos 995 996 997def _DecodeFixed32(buffer, pos): 998 """Decode a fixed32.""" 999 1000 new_pos = pos + 4 1001 return (struct.unpack('<I', buffer[pos:new_pos])[0], new_pos) 1002 1003 1004def _RaiseInvalidWireType(buffer, pos, end): 1005 """Skip function for unknown wire types. Raises an exception.""" 1006 1007 raise _DecodeError('Tag had invalid wire type.') 1008 1009def _FieldSkipper(): 1010 """Constructs the SkipField function.""" 1011 1012 WIRETYPE_TO_SKIPPER = [ 1013 _SkipVarint, 1014 _SkipFixed64, 1015 _SkipLengthDelimited, 1016 _SkipGroup, 1017 _EndGroup, 1018 _SkipFixed32, 1019 _RaiseInvalidWireType, 1020 _RaiseInvalidWireType, 1021 ] 1022 1023 wiretype_mask = wire_format.TAG_TYPE_MASK 1024 1025 def SkipField(buffer, pos, end, tag_bytes): 1026 """Skips a field with the specified tag. 1027 1028 |pos| should point to the byte immediately after the tag. 1029 1030 Returns: 1031 The new position (after the tag value), or -1 if the tag is an end-group 1032 tag (in which case the calling loop should break). 1033 """ 1034 1035 # The wire type is always in the first byte since varints are little-endian. 1036 wire_type = ord(tag_bytes[0:1]) & wiretype_mask 1037 return WIRETYPE_TO_SKIPPER[wire_type](buffer, pos, end) 1038 1039 return SkipField 1040 1041SkipField = _FieldSkipper() 1042