1# Protocol Buffers - Google's data interchange format 2# Copyright 2008 Google Inc. All rights reserved. 3# https://developers.google.com/protocol-buffers/ 4# 5# Redistribution and use in source and binary forms, with or without 6# modification, are permitted provided that the following conditions are 7# met: 8# 9# * Redistributions of source code must retain the above copyright 10# notice, this list of conditions and the following disclaimer. 11# * Redistributions in binary form must reproduce the above 12# copyright notice, this list of conditions and the following disclaimer 13# in the documentation and/or other materials provided with the 14# distribution. 15# * Neither the name of Google Inc. nor the names of its 16# contributors may be used to endorse or promote products derived from 17# this software without specific prior written permission. 18# 19# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 20# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 21# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 22# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 23# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 24# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 25# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 26# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 27# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 28# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 31"""Code for decoding protocol buffer primitives. 32 33This code is very similar to encoder.py -- read the docs for that module first. 34 35A "decoder" is a function with the signature: 36 Decode(buffer, pos, end, message, field_dict) 37The arguments are: 38 buffer: The string containing the encoded message. 39 pos: The current position in the string. 40 end: The position in the string where the current message ends. May be 41 less than len(buffer) if we're reading a sub-message. 42 message: The message object into which we're parsing. 43 field_dict: message._fields (avoids a hashtable lookup). 44The decoder reads the field and stores it into field_dict, returning the new 45buffer position. A decoder for a repeated field may proactively decode all of 46the elements of that field, if they appear consecutively. 47 48Note that decoders may throw any of the following: 49 IndexError: Indicates a truncated message. 50 struct.error: Unpacking of a fixed-width field failed. 51 message.DecodeError: Other errors. 52 53Decoders are expected to raise an exception if they are called with pos > end. 54This allows callers to be lax about bounds checking: it's fineto read past 55"end" as long as you are sure that someone else will notice and throw an 56exception later on. 57 58Something up the call stack is expected to catch IndexError and struct.error 59and convert them to message.DecodeError. 60 61Decoders are constructed using decoder constructors with the signature: 62 MakeDecoder(field_number, is_repeated, is_packed, key, new_default) 63The arguments are: 64 field_number: The field number of the field we want to decode. 65 is_repeated: Is the field a repeated field? (bool) 66 is_packed: Is the field a packed field? (bool) 67 key: The key to use when looking up the field within field_dict. 68 (This is actually the FieldDescriptor but nothing in this 69 file should depend on that.) 70 new_default: A function which takes a message object as a parameter and 71 returns a new instance of the default value for this field. 72 (This is called for repeated fields and sub-messages, when an 73 instance does not already exist.) 74 75As with encoders, we define a decoder constructor for every type of field. 76Then, for every field of every message class we construct an actual decoder. 77That decoder goes into a dict indexed by tag, so when we decode a message 78we repeatedly read a tag, look up the corresponding decoder, and invoke it. 79""" 80 81__author__ = 'kenton@google.com (Kenton Varda)' 82 83import struct 84import sys 85import six 86 87_UCS2_MAXUNICODE = 65535 88if six.PY3: 89 long = int 90else: 91 import re # pylint: disable=g-import-not-at-top 92 _SURROGATE_PATTERN = re.compile(six.u(r'[\ud800-\udfff]')) 93 94from google.protobuf.internal import containers 95from google.protobuf.internal import encoder 96from google.protobuf.internal import wire_format 97from google.protobuf import message 98 99 100# This will overflow and thus become IEEE-754 "infinity". We would use 101# "float('inf')" but it doesn't work on Windows pre-Python-2.6. 102_POS_INF = 1e10000 103_NEG_INF = -_POS_INF 104_NAN = _POS_INF * 0 105 106 107# This is not for optimization, but rather to avoid conflicts with local 108# variables named "message". 109_DecodeError = message.DecodeError 110 111 112def _VarintDecoder(mask, result_type): 113 """Return an encoder for a basic varint value (does not include tag). 114 115 Decoded values will be bitwise-anded with the given mask before being 116 returned, e.g. to limit them to 32 bits. The returned decoder does not 117 take the usual "end" parameter -- the caller is expected to do bounds checking 118 after the fact (often the caller can defer such checking until later). The 119 decoder returns a (value, new_pos) pair. 120 """ 121 122 def DecodeVarint(buffer, pos): 123 result = 0 124 shift = 0 125 while 1: 126 b = six.indexbytes(buffer, pos) 127 result |= ((b & 0x7f) << shift) 128 pos += 1 129 if not (b & 0x80): 130 result &= mask 131 result = result_type(result) 132 return (result, pos) 133 shift += 7 134 if shift >= 64: 135 raise _DecodeError('Too many bytes when decoding varint.') 136 return DecodeVarint 137 138 139def _SignedVarintDecoder(bits, result_type): 140 """Like _VarintDecoder() but decodes signed values.""" 141 142 signbit = 1 << (bits - 1) 143 mask = (1 << bits) - 1 144 145 def DecodeVarint(buffer, pos): 146 result = 0 147 shift = 0 148 while 1: 149 b = six.indexbytes(buffer, pos) 150 result |= ((b & 0x7f) << shift) 151 pos += 1 152 if not (b & 0x80): 153 result &= mask 154 result = (result ^ signbit) - signbit 155 result = result_type(result) 156 return (result, pos) 157 shift += 7 158 if shift >= 64: 159 raise _DecodeError('Too many bytes when decoding varint.') 160 return DecodeVarint 161 162# We force 32-bit values to int and 64-bit values to long to make 163# alternate implementations where the distinction is more significant 164# (e.g. the C++ implementation) simpler. 165 166_DecodeVarint = _VarintDecoder((1 << 64) - 1, long) 167_DecodeSignedVarint = _SignedVarintDecoder(64, long) 168 169# Use these versions for values which must be limited to 32 bits. 170_DecodeVarint32 = _VarintDecoder((1 << 32) - 1, int) 171_DecodeSignedVarint32 = _SignedVarintDecoder(32, int) 172 173 174def ReadTag(buffer, pos): 175 """Read a tag from the memoryview, and return a (tag_bytes, new_pos) tuple. 176 177 We return the raw bytes of the tag rather than decoding them. The raw 178 bytes can then be used to look up the proper decoder. This effectively allows 179 us to trade some work that would be done in pure-python (decoding a varint) 180 for work that is done in C (searching for a byte string in a hash table). 181 In a low-level language it would be much cheaper to decode the varint and 182 use that, but not in Python. 183 184 Args: 185 buffer: memoryview object of the encoded bytes 186 pos: int of the current position to start from 187 188 Returns: 189 Tuple[bytes, int] of the tag data and new position. 190 """ 191 start = pos 192 while six.indexbytes(buffer, pos) & 0x80: 193 pos += 1 194 pos += 1 195 196 tag_bytes = buffer[start:pos].tobytes() 197 return tag_bytes, pos 198 199 200# -------------------------------------------------------------------- 201 202 203def _SimpleDecoder(wire_type, decode_value): 204 """Return a constructor for a decoder for fields of a particular type. 205 206 Args: 207 wire_type: The field's wire type. 208 decode_value: A function which decodes an individual value, e.g. 209 _DecodeVarint() 210 """ 211 212 def SpecificDecoder(field_number, is_repeated, is_packed, key, new_default, 213 clear_if_default=False): 214 if is_packed: 215 local_DecodeVarint = _DecodeVarint 216 def DecodePackedField(buffer, pos, end, message, field_dict): 217 value = field_dict.get(key) 218 if value is None: 219 value = field_dict.setdefault(key, new_default(message)) 220 (endpoint, pos) = local_DecodeVarint(buffer, pos) 221 endpoint += pos 222 if endpoint > end: 223 raise _DecodeError('Truncated message.') 224 while pos < endpoint: 225 (element, pos) = decode_value(buffer, pos) 226 value.append(element) 227 if pos > endpoint: 228 del value[-1] # Discard corrupt value. 229 raise _DecodeError('Packed element was truncated.') 230 return pos 231 return DecodePackedField 232 elif is_repeated: 233 tag_bytes = encoder.TagBytes(field_number, wire_type) 234 tag_len = len(tag_bytes) 235 def DecodeRepeatedField(buffer, pos, end, message, field_dict): 236 value = field_dict.get(key) 237 if value is None: 238 value = field_dict.setdefault(key, new_default(message)) 239 while 1: 240 (element, new_pos) = decode_value(buffer, pos) 241 value.append(element) 242 # Predict that the next tag is another copy of the same repeated 243 # field. 244 pos = new_pos + tag_len 245 if buffer[new_pos:pos] != tag_bytes or new_pos >= end: 246 # Prediction failed. Return. 247 if new_pos > end: 248 raise _DecodeError('Truncated message.') 249 return new_pos 250 return DecodeRepeatedField 251 else: 252 def DecodeField(buffer, pos, end, message, field_dict): 253 (new_value, pos) = decode_value(buffer, pos) 254 if pos > end: 255 raise _DecodeError('Truncated message.') 256 if clear_if_default and not new_value: 257 field_dict.pop(key, None) 258 else: 259 field_dict[key] = new_value 260 return pos 261 return DecodeField 262 263 return SpecificDecoder 264 265 266def _ModifiedDecoder(wire_type, decode_value, modify_value): 267 """Like SimpleDecoder but additionally invokes modify_value on every value 268 before storing it. Usually modify_value is ZigZagDecode. 269 """ 270 271 # Reusing _SimpleDecoder is slightly slower than copying a bunch of code, but 272 # not enough to make a significant difference. 273 274 def InnerDecode(buffer, pos): 275 (result, new_pos) = decode_value(buffer, pos) 276 return (modify_value(result), new_pos) 277 return _SimpleDecoder(wire_type, InnerDecode) 278 279 280def _StructPackDecoder(wire_type, format): 281 """Return a constructor for a decoder for a fixed-width field. 282 283 Args: 284 wire_type: The field's wire type. 285 format: The format string to pass to struct.unpack(). 286 """ 287 288 value_size = struct.calcsize(format) 289 local_unpack = struct.unpack 290 291 # Reusing _SimpleDecoder is slightly slower than copying a bunch of code, but 292 # not enough to make a significant difference. 293 294 # Note that we expect someone up-stack to catch struct.error and convert 295 # it to _DecodeError -- this way we don't have to set up exception- 296 # handling blocks every time we parse one value. 297 298 def InnerDecode(buffer, pos): 299 new_pos = pos + value_size 300 result = local_unpack(format, buffer[pos:new_pos])[0] 301 return (result, new_pos) 302 return _SimpleDecoder(wire_type, InnerDecode) 303 304 305def _FloatDecoder(): 306 """Returns a decoder for a float field. 307 308 This code works around a bug in struct.unpack for non-finite 32-bit 309 floating-point values. 310 """ 311 312 local_unpack = struct.unpack 313 314 def InnerDecode(buffer, pos): 315 """Decode serialized float to a float and new position. 316 317 Args: 318 buffer: memoryview of the serialized bytes 319 pos: int, position in the memory view to start at. 320 321 Returns: 322 Tuple[float, int] of the deserialized float value and new position 323 in the serialized data. 324 """ 325 # We expect a 32-bit value in little-endian byte order. Bit 1 is the sign 326 # bit, bits 2-9 represent the exponent, and bits 10-32 are the significand. 327 new_pos = pos + 4 328 float_bytes = buffer[pos:new_pos].tobytes() 329 330 # If this value has all its exponent bits set, then it's non-finite. 331 # In Python 2.4, struct.unpack will convert it to a finite 64-bit value. 332 # To avoid that, we parse it specially. 333 if (float_bytes[3:4] in b'\x7F\xFF' and float_bytes[2:3] >= b'\x80'): 334 # If at least one significand bit is set... 335 if float_bytes[0:3] != b'\x00\x00\x80': 336 return (_NAN, new_pos) 337 # If sign bit is set... 338 if float_bytes[3:4] == b'\xFF': 339 return (_NEG_INF, new_pos) 340 return (_POS_INF, new_pos) 341 342 # Note that we expect someone up-stack to catch struct.error and convert 343 # it to _DecodeError -- this way we don't have to set up exception- 344 # handling blocks every time we parse one value. 345 result = local_unpack('<f', float_bytes)[0] 346 return (result, new_pos) 347 return _SimpleDecoder(wire_format.WIRETYPE_FIXED32, InnerDecode) 348 349 350def _DoubleDecoder(): 351 """Returns a decoder for a double field. 352 353 This code works around a bug in struct.unpack for not-a-number. 354 """ 355 356 local_unpack = struct.unpack 357 358 def InnerDecode(buffer, pos): 359 """Decode serialized double to a double and new position. 360 361 Args: 362 buffer: memoryview of the serialized bytes. 363 pos: int, position in the memory view to start at. 364 365 Returns: 366 Tuple[float, int] of the decoded double value and new position 367 in the serialized data. 368 """ 369 # We expect a 64-bit value in little-endian byte order. Bit 1 is the sign 370 # bit, bits 2-12 represent the exponent, and bits 13-64 are the significand. 371 new_pos = pos + 8 372 double_bytes = buffer[pos:new_pos].tobytes() 373 374 # If this value has all its exponent bits set and at least one significand 375 # bit set, it's not a number. In Python 2.4, struct.unpack will treat it 376 # as inf or -inf. To avoid that, we treat it specially. 377 if ((double_bytes[7:8] in b'\x7F\xFF') 378 and (double_bytes[6:7] >= b'\xF0') 379 and (double_bytes[0:7] != b'\x00\x00\x00\x00\x00\x00\xF0')): 380 return (_NAN, new_pos) 381 382 # Note that we expect someone up-stack to catch struct.error and convert 383 # it to _DecodeError -- this way we don't have to set up exception- 384 # handling blocks every time we parse one value. 385 result = local_unpack('<d', double_bytes)[0] 386 return (result, new_pos) 387 return _SimpleDecoder(wire_format.WIRETYPE_FIXED64, InnerDecode) 388 389 390def EnumDecoder(field_number, is_repeated, is_packed, key, new_default, 391 clear_if_default=False): 392 """Returns a decoder for enum field.""" 393 enum_type = key.enum_type 394 if is_packed: 395 local_DecodeVarint = _DecodeVarint 396 def DecodePackedField(buffer, pos, end, message, field_dict): 397 """Decode serialized packed enum to its value and a new position. 398 399 Args: 400 buffer: memoryview of the serialized bytes. 401 pos: int, position in the memory view to start at. 402 end: int, end position of serialized data 403 message: Message object to store unknown fields in 404 field_dict: Map[Descriptor, Any] to store decoded values in. 405 406 Returns: 407 int, new position in serialized data. 408 """ 409 value = field_dict.get(key) 410 if value is None: 411 value = field_dict.setdefault(key, new_default(message)) 412 (endpoint, pos) = local_DecodeVarint(buffer, pos) 413 endpoint += pos 414 if endpoint > end: 415 raise _DecodeError('Truncated message.') 416 while pos < endpoint: 417 value_start_pos = pos 418 (element, pos) = _DecodeSignedVarint32(buffer, pos) 419 # pylint: disable=protected-access 420 if element in enum_type.values_by_number: 421 value.append(element) 422 else: 423 if not message._unknown_fields: 424 message._unknown_fields = [] 425 tag_bytes = encoder.TagBytes(field_number, 426 wire_format.WIRETYPE_VARINT) 427 428 message._unknown_fields.append( 429 (tag_bytes, buffer[value_start_pos:pos].tobytes())) 430 if message._unknown_field_set is None: 431 message._unknown_field_set = containers.UnknownFieldSet() 432 message._unknown_field_set._add( 433 field_number, wire_format.WIRETYPE_VARINT, element) 434 # pylint: enable=protected-access 435 if pos > endpoint: 436 if element in enum_type.values_by_number: 437 del value[-1] # Discard corrupt value. 438 else: 439 del message._unknown_fields[-1] 440 # pylint: disable=protected-access 441 del message._unknown_field_set._values[-1] 442 # pylint: enable=protected-access 443 raise _DecodeError('Packed element was truncated.') 444 return pos 445 return DecodePackedField 446 elif is_repeated: 447 tag_bytes = encoder.TagBytes(field_number, wire_format.WIRETYPE_VARINT) 448 tag_len = len(tag_bytes) 449 def DecodeRepeatedField(buffer, pos, end, message, field_dict): 450 """Decode serialized repeated enum to its value and a new position. 451 452 Args: 453 buffer: memoryview of the serialized bytes. 454 pos: int, position in the memory view to start at. 455 end: int, end position of serialized data 456 message: Message object to store unknown fields in 457 field_dict: Map[Descriptor, Any] to store decoded values in. 458 459 Returns: 460 int, new position in serialized data. 461 """ 462 value = field_dict.get(key) 463 if value is None: 464 value = field_dict.setdefault(key, new_default(message)) 465 while 1: 466 (element, new_pos) = _DecodeSignedVarint32(buffer, pos) 467 # pylint: disable=protected-access 468 if element in enum_type.values_by_number: 469 value.append(element) 470 else: 471 if not message._unknown_fields: 472 message._unknown_fields = [] 473 message._unknown_fields.append( 474 (tag_bytes, buffer[pos:new_pos].tobytes())) 475 if message._unknown_field_set is None: 476 message._unknown_field_set = containers.UnknownFieldSet() 477 message._unknown_field_set._add( 478 field_number, wire_format.WIRETYPE_VARINT, element) 479 # pylint: enable=protected-access 480 # Predict that the next tag is another copy of the same repeated 481 # field. 482 pos = new_pos + tag_len 483 if buffer[new_pos:pos] != tag_bytes or new_pos >= end: 484 # Prediction failed. Return. 485 if new_pos > end: 486 raise _DecodeError('Truncated message.') 487 return new_pos 488 return DecodeRepeatedField 489 else: 490 def DecodeField(buffer, pos, end, message, field_dict): 491 """Decode serialized repeated enum to its value and a new position. 492 493 Args: 494 buffer: memoryview of the serialized bytes. 495 pos: int, position in the memory view to start at. 496 end: int, end position of serialized data 497 message: Message object to store unknown fields in 498 field_dict: Map[Descriptor, Any] to store decoded values in. 499 500 Returns: 501 int, new position in serialized data. 502 """ 503 value_start_pos = pos 504 (enum_value, pos) = _DecodeSignedVarint32(buffer, pos) 505 if pos > end: 506 raise _DecodeError('Truncated message.') 507 if clear_if_default and not enum_value: 508 field_dict.pop(key, None) 509 return pos 510 # pylint: disable=protected-access 511 if enum_value in enum_type.values_by_number: 512 field_dict[key] = enum_value 513 else: 514 if not message._unknown_fields: 515 message._unknown_fields = [] 516 tag_bytes = encoder.TagBytes(field_number, 517 wire_format.WIRETYPE_VARINT) 518 message._unknown_fields.append( 519 (tag_bytes, buffer[value_start_pos:pos].tobytes())) 520 if message._unknown_field_set is None: 521 message._unknown_field_set = containers.UnknownFieldSet() 522 message._unknown_field_set._add( 523 field_number, wire_format.WIRETYPE_VARINT, enum_value) 524 # pylint: enable=protected-access 525 return pos 526 return DecodeField 527 528 529# -------------------------------------------------------------------- 530 531 532Int32Decoder = _SimpleDecoder( 533 wire_format.WIRETYPE_VARINT, _DecodeSignedVarint32) 534 535Int64Decoder = _SimpleDecoder( 536 wire_format.WIRETYPE_VARINT, _DecodeSignedVarint) 537 538UInt32Decoder = _SimpleDecoder(wire_format.WIRETYPE_VARINT, _DecodeVarint32) 539UInt64Decoder = _SimpleDecoder(wire_format.WIRETYPE_VARINT, _DecodeVarint) 540 541SInt32Decoder = _ModifiedDecoder( 542 wire_format.WIRETYPE_VARINT, _DecodeVarint32, wire_format.ZigZagDecode) 543SInt64Decoder = _ModifiedDecoder( 544 wire_format.WIRETYPE_VARINT, _DecodeVarint, wire_format.ZigZagDecode) 545 546# Note that Python conveniently guarantees that when using the '<' prefix on 547# formats, they will also have the same size across all platforms (as opposed 548# to without the prefix, where their sizes depend on the C compiler's basic 549# type sizes). 550Fixed32Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED32, '<I') 551Fixed64Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED64, '<Q') 552SFixed32Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED32, '<i') 553SFixed64Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED64, '<q') 554FloatDecoder = _FloatDecoder() 555DoubleDecoder = _DoubleDecoder() 556 557BoolDecoder = _ModifiedDecoder( 558 wire_format.WIRETYPE_VARINT, _DecodeVarint, bool) 559 560 561def StringDecoder(field_number, is_repeated, is_packed, key, new_default, 562 is_strict_utf8=False, clear_if_default=False): 563 """Returns a decoder for a string field.""" 564 565 local_DecodeVarint = _DecodeVarint 566 local_unicode = six.text_type 567 568 def _ConvertToUnicode(memview): 569 """Convert byte to unicode.""" 570 byte_str = memview.tobytes() 571 try: 572 value = local_unicode(byte_str, 'utf-8') 573 except UnicodeDecodeError as e: 574 # add more information to the error message and re-raise it. 575 e.reason = '%s in field: %s' % (e, key.full_name) 576 raise 577 578 if is_strict_utf8 and six.PY2 and sys.maxunicode > _UCS2_MAXUNICODE: 579 # Only do the check for python2 ucs4 when is_strict_utf8 enabled 580 if _SURROGATE_PATTERN.search(value): 581 reason = ('String field %s contains invalid UTF-8 data when parsing' 582 'a protocol buffer: surrogates not allowed. Use' 583 'the bytes type if you intend to send raw bytes.') % ( 584 key.full_name) 585 raise message.DecodeError(reason) 586 587 return value 588 589 assert not is_packed 590 if is_repeated: 591 tag_bytes = encoder.TagBytes(field_number, 592 wire_format.WIRETYPE_LENGTH_DELIMITED) 593 tag_len = len(tag_bytes) 594 def DecodeRepeatedField(buffer, pos, end, message, field_dict): 595 value = field_dict.get(key) 596 if value is None: 597 value = field_dict.setdefault(key, new_default(message)) 598 while 1: 599 (size, pos) = local_DecodeVarint(buffer, pos) 600 new_pos = pos + size 601 if new_pos > end: 602 raise _DecodeError('Truncated string.') 603 value.append(_ConvertToUnicode(buffer[pos:new_pos])) 604 # Predict that the next tag is another copy of the same repeated field. 605 pos = new_pos + tag_len 606 if buffer[new_pos:pos] != tag_bytes or new_pos == end: 607 # Prediction failed. Return. 608 return new_pos 609 return DecodeRepeatedField 610 else: 611 def DecodeField(buffer, pos, end, message, field_dict): 612 (size, pos) = local_DecodeVarint(buffer, pos) 613 new_pos = pos + size 614 if new_pos > end: 615 raise _DecodeError('Truncated string.') 616 if clear_if_default and not size: 617 field_dict.pop(key, None) 618 else: 619 field_dict[key] = _ConvertToUnicode(buffer[pos:new_pos]) 620 return new_pos 621 return DecodeField 622 623 624def BytesDecoder(field_number, is_repeated, is_packed, key, new_default, 625 clear_if_default=False): 626 """Returns a decoder for a bytes field.""" 627 628 local_DecodeVarint = _DecodeVarint 629 630 assert not is_packed 631 if is_repeated: 632 tag_bytes = encoder.TagBytes(field_number, 633 wire_format.WIRETYPE_LENGTH_DELIMITED) 634 tag_len = len(tag_bytes) 635 def DecodeRepeatedField(buffer, pos, end, message, field_dict): 636 value = field_dict.get(key) 637 if value is None: 638 value = field_dict.setdefault(key, new_default(message)) 639 while 1: 640 (size, pos) = local_DecodeVarint(buffer, pos) 641 new_pos = pos + size 642 if new_pos > end: 643 raise _DecodeError('Truncated string.') 644 value.append(buffer[pos:new_pos].tobytes()) 645 # Predict that the next tag is another copy of the same repeated field. 646 pos = new_pos + tag_len 647 if buffer[new_pos:pos] != tag_bytes or new_pos == end: 648 # Prediction failed. Return. 649 return new_pos 650 return DecodeRepeatedField 651 else: 652 def DecodeField(buffer, pos, end, message, field_dict): 653 (size, pos) = local_DecodeVarint(buffer, pos) 654 new_pos = pos + size 655 if new_pos > end: 656 raise _DecodeError('Truncated string.') 657 if clear_if_default and not size: 658 field_dict.pop(key, None) 659 else: 660 field_dict[key] = buffer[pos:new_pos].tobytes() 661 return new_pos 662 return DecodeField 663 664 665def GroupDecoder(field_number, is_repeated, is_packed, key, new_default): 666 """Returns a decoder for a group field.""" 667 668 end_tag_bytes = encoder.TagBytes(field_number, 669 wire_format.WIRETYPE_END_GROUP) 670 end_tag_len = len(end_tag_bytes) 671 672 assert not is_packed 673 if is_repeated: 674 tag_bytes = encoder.TagBytes(field_number, 675 wire_format.WIRETYPE_START_GROUP) 676 tag_len = len(tag_bytes) 677 def DecodeRepeatedField(buffer, pos, end, message, field_dict): 678 value = field_dict.get(key) 679 if value is None: 680 value = field_dict.setdefault(key, new_default(message)) 681 while 1: 682 value = field_dict.get(key) 683 if value is None: 684 value = field_dict.setdefault(key, new_default(message)) 685 # Read sub-message. 686 pos = value.add()._InternalParse(buffer, pos, end) 687 # Read end tag. 688 new_pos = pos+end_tag_len 689 if buffer[pos:new_pos] != end_tag_bytes or new_pos > end: 690 raise _DecodeError('Missing group end tag.') 691 # Predict that the next tag is another copy of the same repeated field. 692 pos = new_pos + tag_len 693 if buffer[new_pos:pos] != tag_bytes or new_pos == end: 694 # Prediction failed. Return. 695 return new_pos 696 return DecodeRepeatedField 697 else: 698 def DecodeField(buffer, pos, end, message, field_dict): 699 value = field_dict.get(key) 700 if value is None: 701 value = field_dict.setdefault(key, new_default(message)) 702 # Read sub-message. 703 pos = value._InternalParse(buffer, pos, end) 704 # Read end tag. 705 new_pos = pos+end_tag_len 706 if buffer[pos:new_pos] != end_tag_bytes or new_pos > end: 707 raise _DecodeError('Missing group end tag.') 708 return new_pos 709 return DecodeField 710 711 712def MessageDecoder(field_number, is_repeated, is_packed, key, new_default): 713 """Returns a decoder for a message field.""" 714 715 local_DecodeVarint = _DecodeVarint 716 717 assert not is_packed 718 if is_repeated: 719 tag_bytes = encoder.TagBytes(field_number, 720 wire_format.WIRETYPE_LENGTH_DELIMITED) 721 tag_len = len(tag_bytes) 722 def DecodeRepeatedField(buffer, pos, end, message, field_dict): 723 value = field_dict.get(key) 724 if value is None: 725 value = field_dict.setdefault(key, new_default(message)) 726 while 1: 727 # Read length. 728 (size, pos) = local_DecodeVarint(buffer, pos) 729 new_pos = pos + size 730 if new_pos > end: 731 raise _DecodeError('Truncated message.') 732 # Read sub-message. 733 if value.add()._InternalParse(buffer, pos, new_pos) != new_pos: 734 # The only reason _InternalParse would return early is if it 735 # encountered an end-group tag. 736 raise _DecodeError('Unexpected end-group tag.') 737 # Predict that the next tag is another copy of the same repeated field. 738 pos = new_pos + tag_len 739 if buffer[new_pos:pos] != tag_bytes or new_pos == end: 740 # Prediction failed. Return. 741 return new_pos 742 return DecodeRepeatedField 743 else: 744 def DecodeField(buffer, pos, end, message, field_dict): 745 value = field_dict.get(key) 746 if value is None: 747 value = field_dict.setdefault(key, new_default(message)) 748 # Read length. 749 (size, pos) = local_DecodeVarint(buffer, pos) 750 new_pos = pos + size 751 if new_pos > end: 752 raise _DecodeError('Truncated message.') 753 # Read sub-message. 754 if value._InternalParse(buffer, pos, new_pos) != new_pos: 755 # The only reason _InternalParse would return early is if it encountered 756 # an end-group tag. 757 raise _DecodeError('Unexpected end-group tag.') 758 return new_pos 759 return DecodeField 760 761 762# -------------------------------------------------------------------- 763 764MESSAGE_SET_ITEM_TAG = encoder.TagBytes(1, wire_format.WIRETYPE_START_GROUP) 765 766def MessageSetItemDecoder(descriptor): 767 """Returns a decoder for a MessageSet item. 768 769 The parameter is the message Descriptor. 770 771 The message set message looks like this: 772 message MessageSet { 773 repeated group Item = 1 { 774 required int32 type_id = 2; 775 required string message = 3; 776 } 777 } 778 """ 779 780 type_id_tag_bytes = encoder.TagBytes(2, wire_format.WIRETYPE_VARINT) 781 message_tag_bytes = encoder.TagBytes(3, wire_format.WIRETYPE_LENGTH_DELIMITED) 782 item_end_tag_bytes = encoder.TagBytes(1, wire_format.WIRETYPE_END_GROUP) 783 784 local_ReadTag = ReadTag 785 local_DecodeVarint = _DecodeVarint 786 local_SkipField = SkipField 787 788 def DecodeItem(buffer, pos, end, message, field_dict): 789 """Decode serialized message set to its value and new position. 790 791 Args: 792 buffer: memoryview of the serialized bytes. 793 pos: int, position in the memory view to start at. 794 end: int, end position of serialized data 795 message: Message object to store unknown fields in 796 field_dict: Map[Descriptor, Any] to store decoded values in. 797 798 Returns: 799 int, new position in serialized data. 800 """ 801 message_set_item_start = pos 802 type_id = -1 803 message_start = -1 804 message_end = -1 805 806 # Technically, type_id and message can appear in any order, so we need 807 # a little loop here. 808 while 1: 809 (tag_bytes, pos) = local_ReadTag(buffer, pos) 810 if tag_bytes == type_id_tag_bytes: 811 (type_id, pos) = local_DecodeVarint(buffer, pos) 812 elif tag_bytes == message_tag_bytes: 813 (size, message_start) = local_DecodeVarint(buffer, pos) 814 pos = message_end = message_start + size 815 elif tag_bytes == item_end_tag_bytes: 816 break 817 else: 818 pos = SkipField(buffer, pos, end, tag_bytes) 819 if pos == -1: 820 raise _DecodeError('Missing group end tag.') 821 822 if pos > end: 823 raise _DecodeError('Truncated message.') 824 825 if type_id == -1: 826 raise _DecodeError('MessageSet item missing type_id.') 827 if message_start == -1: 828 raise _DecodeError('MessageSet item missing message.') 829 830 extension = message.Extensions._FindExtensionByNumber(type_id) 831 # pylint: disable=protected-access 832 if extension is not None: 833 value = field_dict.get(extension) 834 if value is None: 835 message_type = extension.message_type 836 if not hasattr(message_type, '_concrete_class'): 837 # pylint: disable=protected-access 838 message._FACTORY.GetPrototype(message_type) 839 value = field_dict.setdefault( 840 extension, message_type._concrete_class()) 841 if value._InternalParse(buffer, message_start,message_end) != message_end: 842 # The only reason _InternalParse would return early is if it encountered 843 # an end-group tag. 844 raise _DecodeError('Unexpected end-group tag.') 845 else: 846 if not message._unknown_fields: 847 message._unknown_fields = [] 848 message._unknown_fields.append( 849 (MESSAGE_SET_ITEM_TAG, buffer[message_set_item_start:pos].tobytes())) 850 if message._unknown_field_set is None: 851 message._unknown_field_set = containers.UnknownFieldSet() 852 message._unknown_field_set._add( 853 type_id, 854 wire_format.WIRETYPE_LENGTH_DELIMITED, 855 buffer[message_start:message_end].tobytes()) 856 # pylint: enable=protected-access 857 858 return pos 859 860 return DecodeItem 861 862# -------------------------------------------------------------------- 863 864def MapDecoder(field_descriptor, new_default, is_message_map): 865 """Returns a decoder for a map field.""" 866 867 key = field_descriptor 868 tag_bytes = encoder.TagBytes(field_descriptor.number, 869 wire_format.WIRETYPE_LENGTH_DELIMITED) 870 tag_len = len(tag_bytes) 871 local_DecodeVarint = _DecodeVarint 872 # Can't read _concrete_class yet; might not be initialized. 873 message_type = field_descriptor.message_type 874 875 def DecodeMap(buffer, pos, end, message, field_dict): 876 submsg = message_type._concrete_class() 877 value = field_dict.get(key) 878 if value is None: 879 value = field_dict.setdefault(key, new_default(message)) 880 while 1: 881 # Read length. 882 (size, pos) = local_DecodeVarint(buffer, pos) 883 new_pos = pos + size 884 if new_pos > end: 885 raise _DecodeError('Truncated message.') 886 # Read sub-message. 887 submsg.Clear() 888 if submsg._InternalParse(buffer, pos, new_pos) != new_pos: 889 # The only reason _InternalParse would return early is if it 890 # encountered an end-group tag. 891 raise _DecodeError('Unexpected end-group tag.') 892 893 if is_message_map: 894 value[submsg.key].CopyFrom(submsg.value) 895 else: 896 value[submsg.key] = submsg.value 897 898 # Predict that the next tag is another copy of the same repeated field. 899 pos = new_pos + tag_len 900 if buffer[new_pos:pos] != tag_bytes or new_pos == end: 901 # Prediction failed. Return. 902 return new_pos 903 904 return DecodeMap 905 906# -------------------------------------------------------------------- 907# Optimization is not as heavy here because calls to SkipField() are rare, 908# except for handling end-group tags. 909 910def _SkipVarint(buffer, pos, end): 911 """Skip a varint value. Returns the new position.""" 912 # Previously ord(buffer[pos]) raised IndexError when pos is out of range. 913 # With this code, ord(b'') raises TypeError. Both are handled in 914 # python_message.py to generate a 'Truncated message' error. 915 while ord(buffer[pos:pos+1].tobytes()) & 0x80: 916 pos += 1 917 pos += 1 918 if pos > end: 919 raise _DecodeError('Truncated message.') 920 return pos 921 922def _SkipFixed64(buffer, pos, end): 923 """Skip a fixed64 value. Returns the new position.""" 924 925 pos += 8 926 if pos > end: 927 raise _DecodeError('Truncated message.') 928 return pos 929 930 931def _DecodeFixed64(buffer, pos): 932 """Decode a fixed64.""" 933 new_pos = pos + 8 934 return (struct.unpack('<Q', buffer[pos:new_pos])[0], new_pos) 935 936 937def _SkipLengthDelimited(buffer, pos, end): 938 """Skip a length-delimited value. Returns the new position.""" 939 940 (size, pos) = _DecodeVarint(buffer, pos) 941 pos += size 942 if pos > end: 943 raise _DecodeError('Truncated message.') 944 return pos 945 946 947def _SkipGroup(buffer, pos, end): 948 """Skip sub-group. Returns the new position.""" 949 950 while 1: 951 (tag_bytes, pos) = ReadTag(buffer, pos) 952 new_pos = SkipField(buffer, pos, end, tag_bytes) 953 if new_pos == -1: 954 return pos 955 pos = new_pos 956 957 958def _DecodeUnknownFieldSet(buffer, pos, end_pos=None): 959 """Decode UnknownFieldSet. Returns the UnknownFieldSet and new position.""" 960 961 unknown_field_set = containers.UnknownFieldSet() 962 while end_pos is None or pos < end_pos: 963 (tag_bytes, pos) = ReadTag(buffer, pos) 964 (tag, _) = _DecodeVarint(tag_bytes, 0) 965 field_number, wire_type = wire_format.UnpackTag(tag) 966 if wire_type == wire_format.WIRETYPE_END_GROUP: 967 break 968 (data, pos) = _DecodeUnknownField(buffer, pos, wire_type) 969 # pylint: disable=protected-access 970 unknown_field_set._add(field_number, wire_type, data) 971 972 return (unknown_field_set, pos) 973 974 975def _DecodeUnknownField(buffer, pos, wire_type): 976 """Decode a unknown field. Returns the UnknownField and new position.""" 977 978 if wire_type == wire_format.WIRETYPE_VARINT: 979 (data, pos) = _DecodeVarint(buffer, pos) 980 elif wire_type == wire_format.WIRETYPE_FIXED64: 981 (data, pos) = _DecodeFixed64(buffer, pos) 982 elif wire_type == wire_format.WIRETYPE_FIXED32: 983 (data, pos) = _DecodeFixed32(buffer, pos) 984 elif wire_type == wire_format.WIRETYPE_LENGTH_DELIMITED: 985 (size, pos) = _DecodeVarint(buffer, pos) 986 data = buffer[pos:pos+size].tobytes() 987 pos += size 988 elif wire_type == wire_format.WIRETYPE_START_GROUP: 989 (data, pos) = _DecodeUnknownFieldSet(buffer, pos) 990 elif wire_type == wire_format.WIRETYPE_END_GROUP: 991 return (0, -1) 992 else: 993 raise _DecodeError('Wrong wire type in tag.') 994 995 return (data, pos) 996 997 998def _EndGroup(buffer, pos, end): 999 """Skipping an END_GROUP tag returns -1 to tell the parent loop to break.""" 1000 1001 return -1 1002 1003 1004def _SkipFixed32(buffer, pos, end): 1005 """Skip a fixed32 value. Returns the new position.""" 1006 1007 pos += 4 1008 if pos > end: 1009 raise _DecodeError('Truncated message.') 1010 return pos 1011 1012 1013def _DecodeFixed32(buffer, pos): 1014 """Decode a fixed32.""" 1015 1016 new_pos = pos + 4 1017 return (struct.unpack('<I', buffer[pos:new_pos])[0], new_pos) 1018 1019 1020def _RaiseInvalidWireType(buffer, pos, end): 1021 """Skip function for unknown wire types. Raises an exception.""" 1022 1023 raise _DecodeError('Tag had invalid wire type.') 1024 1025def _FieldSkipper(): 1026 """Constructs the SkipField function.""" 1027 1028 WIRETYPE_TO_SKIPPER = [ 1029 _SkipVarint, 1030 _SkipFixed64, 1031 _SkipLengthDelimited, 1032 _SkipGroup, 1033 _EndGroup, 1034 _SkipFixed32, 1035 _RaiseInvalidWireType, 1036 _RaiseInvalidWireType, 1037 ] 1038 1039 wiretype_mask = wire_format.TAG_TYPE_MASK 1040 1041 def SkipField(buffer, pos, end, tag_bytes): 1042 """Skips a field with the specified tag. 1043 1044 |pos| should point to the byte immediately after the tag. 1045 1046 Returns: 1047 The new position (after the tag value), or -1 if the tag is an end-group 1048 tag (in which case the calling loop should break). 1049 """ 1050 1051 # The wire type is always in the first byte since varints are little-endian. 1052 wire_type = ord(tag_bytes[0:1]) & wiretype_mask 1053 return WIRETYPE_TO_SKIPPER[wire_type](buffer, pos, end) 1054 1055 return SkipField 1056 1057SkipField = _FieldSkipper() 1058