1# Copyright 2016 Google LLC. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# Copyright 2002, Google LLC. 15 16 17from __future__ import absolute_import 18import array 19import six.moves.http_client 20import itertools 21import re 22import struct 23import six 24 25try: 26 # NOTE(user): Using non-google-style import to workaround a zipimport_tinypar 27 # issue for zip files embedded in par files. See http://b/13811096 28 import googlecloudsdk.third_party.appengine.proto.proto1 as proto1 29except ImportError: 30 # Protect in case of missing deps / strange env (GAE?) / etc. 31 class ProtocolBufferDecodeError(Exception): pass 32 class ProtocolBufferEncodeError(Exception): pass 33 class ProtocolBufferReturnError(Exception): pass 34else: 35 ProtocolBufferDecodeError = proto1.ProtocolBufferDecodeError 36 ProtocolBufferEncodeError = proto1.ProtocolBufferEncodeError 37 ProtocolBufferReturnError = proto1.ProtocolBufferReturnError 38 39__all__ = ['ProtocolMessage', 'Encoder', 'Decoder', 40 'ExtendableProtocolMessage', 41 'ProtocolBufferDecodeError', 42 'ProtocolBufferEncodeError', 43 'ProtocolBufferReturnError'] 44 45URL_RE = re.compile('^(https?)://([^/]+)(/.*)$') 46 47 48class ProtocolMessage: 49 """ 50 The parent class of all protocol buffers. 51 NOTE: the methods that unconditionally raise NotImplementedError are 52 reimplemented by the subclasses of this class. 53 Subclasses are automatically generated by tools/protocol_converter. 54 Encoding methods can raise ProtocolBufferEncodeError if a value for an 55 integer or long field is too large, or if any required field is not set. 56 Decoding methods can raise ProtocolBufferDecodeError if they couldn't 57 decode correctly, or the decoded message doesn't have all required fields. 58 """ 59 60 ##################################### 61 # methods you should use # 62 ##################################### 63 64 def __init__(self, contents=None): 65 """Construct a new protocol buffer, with optional starting contents 66 in binary protocol buffer format.""" 67 raise NotImplementedError 68 69 def Clear(self): 70 """Erases all fields of protocol buffer (& resets to defaults 71 if fields have defaults).""" 72 raise NotImplementedError 73 74 def IsInitialized(self, debug_strs=None): 75 """returns true iff all required fields have been set.""" 76 raise NotImplementedError 77 78 def Encode(self): 79 """Returns a string representing the protocol buffer object.""" 80 try: 81 return self._CEncode() 82 except (NotImplementedError, AttributeError): 83 e = Encoder() 84 self.Output(e) 85 return e.buffer().tostring() 86 87 def SerializeToString(self): 88 """Same as Encode(), but has same name as proto2's serialize function.""" 89 return self.Encode() 90 91 def SerializePartialToString(self): 92 """Returns a string representing the protocol buffer object. 93 Same as SerializeToString() but does not enforce required fields are set. 94 """ 95 try: 96 return self._CEncodePartial() 97 except (NotImplementedError, AttributeError): 98 e = Encoder() 99 self.OutputPartial(e) 100 return e.buffer().tostring() 101 102 def _CEncode(self): 103 """Call into C++ encode code. 104 105 Generated protocol buffer classes will override this method to 106 provide C++-based serialization. If a subclass does not 107 implement this method, Encode() will fall back to 108 using pure-Python encoding. 109 """ 110 raise NotImplementedError 111 112 def _CEncodePartial(self): 113 """Same as _CEncode, except does not encode missing required fields.""" 114 raise NotImplementedError 115 116 def ParseFromString(self, s): 117 """Reads data from the string 's'. 118 Raises a ProtocolBufferDecodeError if, after successfully reading 119 in the contents of 's', this protocol message is still not initialized.""" 120 self.Clear() 121 self.MergeFromString(s) 122 123 def ParsePartialFromString(self, s): 124 """Reads data from the string 's'. 125 Does not enforce required fields are set.""" 126 self.Clear() 127 self.MergePartialFromString(s) 128 129 def MergeFromString(self, s): 130 """Adds in data from the string 's'. 131 Raises a ProtocolBufferDecodeError if, after successfully merging 132 in the contents of 's', this protocol message is still not initialized.""" 133 self.MergePartialFromString(s) 134 dbg = [] 135 if not self.IsInitialized(dbg): 136 raise ProtocolBufferDecodeError('\n\t'.join(dbg)) 137 138 def MergePartialFromString(self, s): 139 """Merges in data from the string 's'. 140 Does not enforce required fields are set.""" 141 try: 142 self._CMergeFromString(s) 143 except (NotImplementedError, AttributeError): 144 # If we can't call into C++ to deserialize the string, use 145 # the (much slower) pure-Python implementation. 146 a = array.array('B') 147 a.fromstring(s) 148 d = Decoder(a, 0, len(a)) 149 self.TryMerge(d) 150 151 def _CMergeFromString(self, s): 152 """Call into C++ parsing code to merge from a string. 153 154 Does *not* check IsInitialized() before returning. 155 156 Generated protocol buffer classes will override this method to 157 provide C++-based deserialization. If a subclass does not 158 implement this method, MergeFromString() will fall back to 159 using pure-Python parsing. 160 """ 161 raise NotImplementedError 162 163 def __getstate__(self): 164 """Return the pickled representation of the data inside protocol buffer, 165 which is the same as its binary-encoded representation (as a string).""" 166 return self.Encode() 167 168 def __setstate__(self, contents_): 169 """Restore the pickled representation of the data inside protocol buffer. 170 Note that the mechanism underlying pickle.load() does not call __init__.""" 171 self.__init__(contents=contents_) 172 173 def sendCommand(self, server, url, response, follow_redirects=1, 174 secure=0, keyfile=None, certfile=None): 175 """posts the protocol buffer to the desired url on the server 176 and puts the return data into the protocol buffer 'response' 177 178 NOTE: The underlying socket raises the 'error' exception 179 for all I/O related errors (can't connect, etc.). 180 181 If 'response' is None, the server's PB response will be ignored. 182 183 The optional 'follow_redirects' argument indicates the number 184 of HTTP redirects that are followed before giving up and raising an 185 exception. The default is 1. 186 187 If 'secure' is true, HTTPS will be used instead of HTTP. Also, 188 'keyfile' and 'certfile' may be set for client authentication. 189 """ 190 data = self.Encode() 191 if secure: 192 if keyfile and certfile: 193 conn = six.moves.http_client.HTTPSConnection(server, key_file=keyfile, 194 cert_file=certfile) 195 else: 196 conn = six.moves.http_client.HTTPSConnection(server) 197 else: 198 conn = six.moves.http_client.HTTPConnection(server) 199 conn.putrequest("POST", url) 200 conn.putheader("Content-Length", "%d" %len(data)) 201 conn.endheaders() 202 conn.send(data) 203 resp = conn.getresponse() 204 if follow_redirects > 0 and resp.status == 302: 205 m = URL_RE.match(resp.getheader('Location')) 206 if m: 207 protocol, server, url = m.groups() 208 return self.sendCommand(server, url, response, 209 follow_redirects=follow_redirects - 1, 210 secure=(protocol == 'https'), 211 keyfile=keyfile, 212 certfile=certfile) 213 if resp.status != 200: 214 raise ProtocolBufferReturnError(resp.status) 215 if response is not None: 216 response.ParseFromString(resp.read()) 217 return response 218 219 def sendSecureCommand(self, server, keyfile, certfile, url, response, 220 follow_redirects=1): 221 """posts the protocol buffer via https to the desired url on the server, 222 using the specified key and certificate files, and puts the return 223 data int othe protocol buffer 'response'. 224 225 See caveats in sendCommand. 226 227 You need an SSL-aware build of the Python2 interpreter to use this command. 228 (Python1 is not supported). An SSL build of python2.2 is in 229 /home/build/buildtools/python-ssl-2.2 . An SSL build of python is 230 standard on all prod machines. 231 232 keyfile: Contains our private RSA key 233 certfile: Contains SSL certificate for remote host 234 Specify None for keyfile/certfile if you don't want to do client auth. 235 """ 236 return self.sendCommand(server, url, response, 237 follow_redirects=follow_redirects, 238 secure=1, keyfile=keyfile, certfile=certfile) 239 240 def __str__(self, prefix="", printElemNumber=0): 241 """Returns nicely formatted contents of this protocol buffer.""" 242 raise NotImplementedError 243 244 def ToASCII(self): 245 """Returns the protocol buffer as a human-readable string.""" 246 return self._CToASCII(ProtocolMessage._SYMBOLIC_FULL_ASCII) 247 248 def ToShortASCII(self): 249 """Returns the protocol buffer as an ASCII string. 250 The output is short, leaving out newlines and some other niceties. 251 Defers to the C++ ProtocolPrinter class in SYMBOLIC_SHORT mode. 252 """ 253 return self._CToASCII(ProtocolMessage._SYMBOLIC_SHORT_ASCII) 254 255 # Note that these must be consistent with the ProtocolPrinter::Level C++ 256 # enum. 257 _NUMERIC_ASCII = 0 258 _SYMBOLIC_SHORT_ASCII = 1 259 _SYMBOLIC_FULL_ASCII = 2 260 261 def _CToASCII(self, output_format): 262 """Calls into C++ ASCII-generating code. 263 264 Generated protocol buffer classes will override this method to provide 265 C++-based ASCII output. 266 """ 267 raise NotImplementedError 268 269 def ParseASCII(self, ascii_string): 270 """Parses a string generated by ToASCII() or by the C++ DebugString() 271 method, initializing this protocol buffer with its contents. This method 272 raises a ValueError if it encounters an unknown field. 273 """ 274 raise NotImplementedError 275 276 def ParseASCIIIgnoreUnknown(self, ascii_string): 277 """Parses a string generated by ToASCII() or by the C++ DebugString() 278 method, initializing this protocol buffer with its contents. Ignores 279 unknown fields. 280 """ 281 raise NotImplementedError 282 283 def Equals(self, other): 284 """Returns whether or not this protocol buffer is equivalent to another. 285 286 This assumes that self and other are of the same type. 287 """ 288 raise NotImplementedError 289 290 def __eq__(self, other): 291 """Implementation of operator ==.""" 292 # If self and other are of different types we return NotImplemented, which 293 # tells the Python interpreter to try some other methods of measuring 294 # equality before finally performing an identity comparison. This allows 295 # other classes to implement custom __eq__ or __ne__ methods. 296 # See http://docs.sympy.org/_sources/python-comparisons.txt 297 if other.__class__ is self.__class__: 298 return self.Equals(other) 299 return NotImplemented 300 301 def __ne__(self, other): 302 """Implementation of operator !=.""" 303 # We repeat code for __ne__ instead of returning "not (self == other)" 304 # so that we can return NotImplemented when comparing against an object of 305 # a different type. 306 # See http://bugs.python.org/msg76374 for an example of when __ne__ might 307 # return something other than the Boolean opposite of __eq__. 308 if other.__class__ is self.__class__: 309 return not self.Equals(other) 310 return NotImplemented 311 312 ##################################### 313 # methods power-users might want # 314 ##################################### 315 316 def Output(self, e): 317 """write self to the encoder 'e'.""" 318 dbg = [] 319 if not self.IsInitialized(dbg): 320 raise ProtocolBufferEncodeError('\n\t'.join(dbg)) 321 self.OutputUnchecked(e) 322 return 323 324 def OutputUnchecked(self, e): 325 """write self to the encoder 'e', don't check for initialization.""" 326 raise NotImplementedError 327 328 def OutputPartial(self, e): 329 """write self to the encoder 'e', don't check for initialization and 330 don't assume required fields exist.""" 331 raise NotImplementedError 332 333 def Parse(self, d): 334 """reads data from the Decoder 'd'.""" 335 self.Clear() 336 self.Merge(d) 337 return 338 339 def Merge(self, d): 340 """merges data from the Decoder 'd'.""" 341 self.TryMerge(d) 342 dbg = [] 343 if not self.IsInitialized(dbg): 344 raise ProtocolBufferDecodeError('\n\t'.join(dbg)) 345 return 346 347 def TryMerge(self, d): 348 """merges data from the Decoder 'd'.""" 349 raise NotImplementedError 350 351 def CopyFrom(self, pb): 352 """copy data from another protocol buffer""" 353 if (pb == self): return 354 self.Clear() 355 self.MergeFrom(pb) 356 357 def MergeFrom(self, pb): 358 """merge data from another protocol buffer""" 359 raise NotImplementedError 360 361 ##################################### 362 # helper methods for subclasses # 363 ##################################### 364 365 def lengthVarInt32(self, n): 366 return self.lengthVarInt64(n) 367 368 def lengthVarInt64(self, n): 369 if n < 0: 370 return 10 # ceil(64/7) 371 result = 0 372 while 1: 373 result += 1 374 n >>= 7 375 if n == 0: 376 break 377 return result 378 379 def lengthString(self, n): 380 return self.lengthVarInt32(n) + n 381 382 def DebugFormat(self, value): 383 return "%s" % value 384 def DebugFormatInt32(self, value): 385 if (value <= -2000000000 or value >= 2000000000): 386 return self.DebugFormatFixed32(value) 387 return "%d" % value 388 def DebugFormatInt64(self, value): 389 if (value <= -20000000000000 or value >= 20000000000000): 390 return self.DebugFormatFixed64(value) 391 return "%d" % value 392 def DebugFormatString(self, value): 393 # For now we only escape the bare minimum to insure interoperability 394 # and redability. In the future we may want to mimick the c++ behavior 395 # more closely, but this will make the code a lot more messy. 396 def escape(c): 397 o = ord(c) 398 if o == 10: return r"\n" # optional escape 399 if o == 39: return r"\'" # optional escape 400 401 if o == 34: return r'\"' # necessary escape 402 if o == 92: return r"\\" # necessary escape 403 404 if o >= 127 or o < 32: return "\\%03o" % o # necessary escapes 405 return c 406 return '"' + "".join(escape(c) for c in value) + '"' 407 def DebugFormatFloat(self, value): 408 return "%ff" % value 409 def DebugFormatFixed32(self, value): 410 if (value < 0): value += (1<<32) 411 return "0x%x" % value 412 def DebugFormatFixed64(self, value): 413 if (value < 0): value += (1<<64) 414 return "0x%x" % value 415 def DebugFormatBool(self, value): 416 if value: 417 return "true" 418 else: 419 return "false" 420 421# types of fields, must match Proto::Type and net/proto/protocoltype.proto 422TYPE_DOUBLE = 1 423TYPE_FLOAT = 2 424TYPE_INT64 = 3 425TYPE_UINT64 = 4 426TYPE_INT32 = 5 427TYPE_FIXED64 = 6 428TYPE_FIXED32 = 7 429TYPE_BOOL = 8 430TYPE_STRING = 9 431TYPE_GROUP = 10 432TYPE_FOREIGN = 11 433 434# debug string for extensions 435_TYPE_TO_DEBUG_STRING = { 436 TYPE_INT32: ProtocolMessage.DebugFormatInt32, 437 TYPE_INT64: ProtocolMessage.DebugFormatInt64, 438 TYPE_UINT64: ProtocolMessage.DebugFormatInt64, 439 TYPE_FLOAT: ProtocolMessage.DebugFormatFloat, 440 TYPE_STRING: ProtocolMessage.DebugFormatString, 441 TYPE_FIXED32: ProtocolMessage.DebugFormatFixed32, 442 TYPE_FIXED64: ProtocolMessage.DebugFormatFixed64, 443 TYPE_BOOL: ProtocolMessage.DebugFormatBool } 444 445# users of protocol buffers usually won't need to concern themselves 446# with either Encoders or Decoders. 447class Encoder: 448 449 # types of data 450 NUMERIC = 0 451 DOUBLE = 1 452 STRING = 2 453 STARTGROUP = 3 454 ENDGROUP = 4 455 FLOAT = 5 456 MAX_TYPE = 6 457 458 def __init__(self): 459 self.buf = array.array('B') 460 return 461 462 def buffer(self): 463 return self.buf 464 465 def put8(self, v): 466 if v < 0 or v >= (1<<8): raise ProtocolBufferEncodeError("u8 too big") 467 self.buf.append(v & 255) 468 return 469 470 def put16(self, v): 471 if v < 0 or v >= (1<<16): raise ProtocolBufferEncodeError("u16 too big") 472 self.buf.append((v >> 0) & 255) 473 self.buf.append((v >> 8) & 255) 474 return 475 476 def put32(self, v): 477 if v < 0 or v >= (1<<32): raise ProtocolBufferEncodeError("u32 too big") 478 self.buf.append((v >> 0) & 255) 479 self.buf.append((v >> 8) & 255) 480 self.buf.append((v >> 16) & 255) 481 self.buf.append((v >> 24) & 255) 482 return 483 484 def put64(self, v): 485 if v < 0 or v >= (1<<64): raise ProtocolBufferEncodeError("u64 too big") 486 self.buf.append((v >> 0) & 255) 487 self.buf.append((v >> 8) & 255) 488 self.buf.append((v >> 16) & 255) 489 self.buf.append((v >> 24) & 255) 490 self.buf.append((v >> 32) & 255) 491 self.buf.append((v >> 40) & 255) 492 self.buf.append((v >> 48) & 255) 493 self.buf.append((v >> 56) & 255) 494 return 495 496 def putVarInt32(self, v): 497 # Profiling has shown this code to be very performance critical 498 # so we duplicate code, go for early exits when possible, etc. 499 # VarInt32 gets more unrolling because VarInt32s are far and away 500 # the most common element in protobufs (field tags and string 501 # lengths), so they get more attention. They're also more 502 # likely to fit in one byte (string lengths again), so we 503 # check and bail out early if possible. 504 505 buf_append = self.buf.append # cache attribute lookup 506 if v & 127 == v: 507 buf_append(v) 508 return 509 if v >= 0x80000000 or v < -0x80000000: # python2.4 doesn't fold constants 510 raise ProtocolBufferEncodeError("int32 too big") 511 if v < 0: 512 v += 0x10000000000000000 513 while True: 514 bits = v & 127 515 v >>= 7 516 if v: 517 bits |= 128 518 buf_append(bits) 519 if not v: 520 break 521 return 522 523 def putVarInt64(self, v): 524 buf_append = self.buf.append 525 if v >= 0x8000000000000000 or v < -0x8000000000000000: 526 raise ProtocolBufferEncodeError("int64 too big") 527 if v < 0: 528 v += 0x10000000000000000 529 while True: 530 bits = v & 127 531 v >>= 7 532 if v: 533 bits |= 128 534 buf_append(bits) 535 if not v: 536 break 537 return 538 539 def putVarUint64(self, v): 540 buf_append = self.buf.append 541 if v < 0 or v >= 0x10000000000000000: 542 raise ProtocolBufferEncodeError("uint64 too big") 543 while True: 544 bits = v & 127 545 v >>= 7 546 if v: 547 bits |= 128 548 buf_append(bits) 549 if not v: 550 break 551 return 552 553 def putFloat(self, v): 554 a = array.array('B') 555 a.fromstring(struct.pack("<f", v)) 556 self.buf.extend(a) 557 return 558 559 def putDouble(self, v): 560 a = array.array('B') 561 a.fromstring(struct.pack("<d", v)) 562 self.buf.extend(a) 563 return 564 565 def putBoolean(self, v): 566 if v: 567 self.buf.append(1) 568 else: 569 self.buf.append(0) 570 return 571 572 def putPrefixedString(self, v): 573 # This change prevents corrupted encoding an YouTube, where 574 # our default encoding is utf-8 and unicode strings may occasionally be 575 # passed into ProtocolBuffers. 576 v = str(v) 577 self.putVarInt32(len(v)) 578 self.buf.fromstring(v) 579 return 580 581 def putRawString(self, v): 582 self.buf.fromstring(v) 583 584 _TYPE_TO_METHOD = { 585 TYPE_DOUBLE: putDouble, 586 TYPE_FLOAT: putFloat, 587 TYPE_FIXED64: put64, 588 TYPE_FIXED32: put32, 589 TYPE_INT32: putVarInt32, 590 TYPE_INT64: putVarInt64, 591 TYPE_UINT64: putVarUint64, 592 TYPE_BOOL: putBoolean, 593 TYPE_STRING: putPrefixedString } 594 595 _TYPE_TO_BYTE_SIZE = { 596 TYPE_DOUBLE: 8, 597 TYPE_FLOAT: 4, 598 TYPE_FIXED64: 8, 599 TYPE_FIXED32: 4, 600 TYPE_BOOL: 1 } 601 602class Decoder: 603 def __init__(self, buf, idx, limit): 604 self.buf = buf 605 self.idx = idx 606 self.limit = limit 607 return 608 609 def avail(self): 610 return self.limit - self.idx 611 612 def buffer(self): 613 return self.buf 614 615 def pos(self): 616 return self.idx 617 618 def skip(self, n): 619 if self.idx + n > self.limit: raise ProtocolBufferDecodeError("truncated") 620 self.idx += n 621 return 622 623 def skipData(self, tag): 624 t = tag & 7 # tag format type 625 if t == Encoder.NUMERIC: 626 self.getVarInt64() 627 elif t == Encoder.DOUBLE: 628 self.skip(8) 629 elif t == Encoder.STRING: 630 n = self.getVarInt32() 631 self.skip(n) 632 elif t == Encoder.STARTGROUP: 633 while 1: 634 t = self.getVarInt32() 635 if (t & 7) == Encoder.ENDGROUP: 636 break 637 else: 638 self.skipData(t) 639 if (t - Encoder.ENDGROUP) != (tag - Encoder.STARTGROUP): 640 raise ProtocolBufferDecodeError("corrupted") 641 elif t == Encoder.ENDGROUP: 642 raise ProtocolBufferDecodeError("corrupted") 643 elif t == Encoder.FLOAT: 644 self.skip(4) 645 else: 646 raise ProtocolBufferDecodeError("corrupted") 647 648 # these are all unsigned gets 649 def get8(self): 650 if self.idx >= self.limit: raise ProtocolBufferDecodeError("truncated") 651 c = self.buf[self.idx] 652 self.idx += 1 653 return c 654 655 def get16(self): 656 if self.idx + 2 > self.limit: raise ProtocolBufferDecodeError("truncated") 657 c = self.buf[self.idx] 658 d = self.buf[self.idx + 1] 659 self.idx += 2 660 return (d << 8) | c 661 662 def get32(self): 663 if self.idx + 4 > self.limit: raise ProtocolBufferDecodeError("truncated") 664 c = self.buf[self.idx] 665 d = self.buf[self.idx + 1] 666 e = self.buf[self.idx + 2] 667 f = int(self.buf[self.idx + 3]) 668 self.idx += 4 669 return (f << 24) | (e << 16) | (d << 8) | c 670 671 def get64(self): 672 if self.idx + 8 > self.limit: raise ProtocolBufferDecodeError("truncated") 673 c = self.buf[self.idx] 674 d = self.buf[self.idx + 1] 675 e = self.buf[self.idx + 2] 676 f = int(self.buf[self.idx + 3]) 677 g = int(self.buf[self.idx + 4]) 678 h = int(self.buf[self.idx + 5]) 679 i = int(self.buf[self.idx + 6]) 680 j = int(self.buf[self.idx + 7]) 681 self.idx += 8 682 return ((j << 56) | (i << 48) | (h << 40) | (g << 32) | (f << 24) 683 | (e << 16) | (d << 8) | c) 684 685 def getVarInt32(self): 686 # getVarInt32 gets different treatment than other integer getter 687 # functions due to the much larger number of varInt32s and also 688 # varInt32s that fit in one byte. See the comment at putVarInt32. 689 b = self.get8() 690 if not (b & 128): 691 return b 692 693 result = int(0) 694 shift = 0 695 696 while 1: 697 result |= (int(b & 127) << shift) 698 shift += 7 699 if not (b & 128): 700 if result >= 0x10000000000000000: # (1L << 64): 701 raise ProtocolBufferDecodeError("corrupted") 702 break 703 if shift >= 64: raise ProtocolBufferDecodeError("corrupted") 704 b = self.get8() 705 706 if result >= 0x8000000000000000: # (1L << 63) 707 result -= 0x10000000000000000 # (1L << 64) 708 if result >= 0x80000000 or result < -0x80000000: # (1L << 31) 709 raise ProtocolBufferDecodeError("corrupted") 710 return result 711 712 def getVarInt64(self): 713 result = self.getVarUint64() 714 if result >= (1 << 63): 715 result -= (1 << 64) 716 return result 717 718 def getVarUint64(self): 719 result = int(0) 720 shift = 0 721 while 1: 722 if shift >= 64: raise ProtocolBufferDecodeError("corrupted") 723 b = self.get8() 724 result |= (int(b & 127) << shift) 725 shift += 7 726 if not (b & 128): 727 if result >= (1 << 64): raise ProtocolBufferDecodeError("corrupted") 728 return result 729 return result # make pychecker happy 730 731 def getFloat(self): 732 if self.idx + 4 > self.limit: raise ProtocolBufferDecodeError("truncated") 733 a = self.buf[self.idx:self.idx+4] 734 self.idx += 4 735 return struct.unpack("<f", a)[0] 736 737 def getDouble(self): 738 if self.idx + 8 > self.limit: raise ProtocolBufferDecodeError("truncated") 739 a = self.buf[self.idx:self.idx+8] 740 self.idx += 8 741 return struct.unpack("<d", a)[0] 742 743 def getBoolean(self): 744 b = self.get8() 745 if b != 0 and b != 1: raise ProtocolBufferDecodeError("corrupted") 746 return b 747 748 def getPrefixedString(self): 749 length = self.getVarInt32() 750 if self.idx + length > self.limit: 751 raise ProtocolBufferDecodeError("truncated") 752 r = self.buf[self.idx : self.idx + length] 753 self.idx += length 754 return r.tostring() 755 756 def getRawString(self): 757 r = self.buf[self.idx:self.limit] 758 self.idx = self.limit 759 return r.tostring() 760 761 _TYPE_TO_METHOD = { 762 TYPE_DOUBLE: getDouble, 763 TYPE_FLOAT: getFloat, 764 TYPE_FIXED64: get64, 765 TYPE_FIXED32: get32, 766 TYPE_INT32: getVarInt32, 767 TYPE_INT64: getVarInt64, 768 TYPE_UINT64: getVarUint64, 769 TYPE_BOOL: getBoolean, 770 TYPE_STRING: getPrefixedString } 771 772##################################### 773# extensions # 774##################################### 775 776class ExtensionIdentifier(object): 777 __slots__ = ('full_name', 'number', 'field_type', 'wire_tag', 'is_repeated', 778 'default', 'containing_cls', 'composite_cls', 'message_name') 779 def __init__(self, full_name, number, field_type, wire_tag, is_repeated, 780 default): 781 self.full_name = full_name 782 self.number = number 783 self.field_type = field_type 784 self.wire_tag = wire_tag 785 self.is_repeated = is_repeated 786 self.default = default 787 788class ExtendableProtocolMessage(ProtocolMessage): 789 def HasExtension(self, extension): 790 """Checks if the message contains a certain non-repeated extension.""" 791 self._VerifyExtensionIdentifier(extension) 792 return extension in self._extension_fields 793 794 def ClearExtension(self, extension): 795 """Clears the value of extension, so that HasExtension() returns false or 796 ExtensionSize() returns 0.""" 797 self._VerifyExtensionIdentifier(extension) 798 if extension in self._extension_fields: 799 del self._extension_fields[extension] 800 801 def GetExtension(self, extension, index=None): 802 """Gets the extension value for a certain extension. 803 804 Args: 805 extension: The ExtensionIdentifier for the extension. 806 index: The index of element to get in a repeated field. Only needed if 807 the extension is repeated. 808 809 Returns: 810 The value of the extension if exists, otherwise the default value of the 811 extension will be returned. 812 """ 813 self._VerifyExtensionIdentifier(extension) 814 if extension in self._extension_fields: 815 result = self._extension_fields[extension] 816 else: 817 if extension.is_repeated: 818 result = [] 819 elif extension.composite_cls: 820 result = extension.composite_cls() 821 else: 822 result = extension.default 823 if extension.is_repeated: 824 result = result[index] 825 return result 826 827 def SetExtension(self, extension, *args): 828 """Sets the extension value for a certain scalar type extension. 829 830 Arg varies according to extension type: 831 - Singular: 832 message.SetExtension(extension, value) 833 - Repeated: 834 message.SetExtension(extension, index, value) 835 where 836 extension: The ExtensionIdentifier for the extension. 837 index: The index of element to set in a repeated field. Only needed if 838 the extension is repeated. 839 value: The value to set. 840 841 Raises: 842 TypeError if a message type extension is given. 843 """ 844 self._VerifyExtensionIdentifier(extension) 845 if extension.composite_cls: 846 raise TypeError( 847 'Cannot assign to extension "%s" because it is a composite type.' % 848 extension.full_name) 849 if extension.is_repeated: 850 try: 851 index, value = args 852 except ValueError: 853 raise TypeError( 854 "SetExtension(extension, index, value) for repeated extension " 855 "takes exactly 4 arguments: (%d given)" % (len(args) + 2)) 856 self._extension_fields[extension][index] = value 857 else: 858 try: 859 (value,) = args 860 except ValueError: 861 raise TypeError( 862 "SetExtension(extension, value) for singular extension " 863 "takes exactly 3 arguments: (%d given)" % (len(args) + 2)) 864 self._extension_fields[extension] = value 865 866 def MutableExtension(self, extension, index=None): 867 """Gets a mutable reference of a message type extension. 868 869 For repeated extension, index must be specified, and only one element will 870 be returned. For optional extension, if the extension does not exist, a new 871 message will be created and set in parent message. 872 873 Args: 874 extension: The ExtensionIdentifier for the extension. 875 index: The index of element to mutate in a repeated field. Only needed if 876 the extension is repeated. 877 878 Returns: 879 The mutable message reference. 880 881 Raises: 882 TypeError if non-message type extension is given. 883 """ 884 self._VerifyExtensionIdentifier(extension) 885 if extension.composite_cls is None: 886 raise TypeError( 887 'MutableExtension() cannot be applied to "%s", because it is not a ' 888 'composite type.' % extension.full_name) 889 if extension.is_repeated: 890 if index is None: 891 raise TypeError( 892 'MutableExtension(extension, index) for repeated extension ' 893 'takes exactly 2 arguments: (1 given)') 894 return self.GetExtension(extension, index) 895 if extension in self._extension_fields: 896 return self._extension_fields[extension] 897 else: 898 result = extension.composite_cls() 899 self._extension_fields[extension] = result 900 return result 901 902 def ExtensionList(self, extension): 903 """Returns a mutable list of extensions. 904 905 Raises: 906 TypeError if the extension is not repeated. 907 """ 908 self._VerifyExtensionIdentifier(extension) 909 if not extension.is_repeated: 910 raise TypeError( 911 'ExtensionList() cannot be applied to "%s", because it is not a ' 912 'repeated extension.' % extension.full_name) 913 if extension in self._extension_fields: 914 return self._extension_fields[extension] 915 result = [] 916 self._extension_fields[extension] = result 917 return result 918 919 def ExtensionSize(self, extension): 920 """Returns the size of a repeated extension. 921 922 Raises: 923 TypeError if the extension is not repeated. 924 """ 925 self._VerifyExtensionIdentifier(extension) 926 if not extension.is_repeated: 927 raise TypeError( 928 'ExtensionSize() cannot be applied to "%s", because it is not a ' 929 'repeated extension.' % extension.full_name) 930 if extension in self._extension_fields: 931 return len(self._extension_fields[extension]) 932 return 0 933 934 def AddExtension(self, extension, value=None): 935 """Appends a new element into a repeated extension. 936 937 Arg varies according to the extension field type: 938 - Scalar/String: 939 message.AddExtension(extension, value) 940 - Message: 941 mutable_message = AddExtension(extension) 942 943 Args: 944 extension: The ExtensionIdentifier for the extension. 945 value: The value of the extension if the extension is scalar/string type. 946 The value must NOT be set for message type extensions; set values on 947 the returned message object instead. 948 949 Returns: 950 A mutable new message if it's a message type extension, or None otherwise. 951 952 Raises: 953 TypeError if the extension is not repeated, or value is given for message 954 type extensions. 955 """ 956 self._VerifyExtensionIdentifier(extension) 957 if not extension.is_repeated: 958 raise TypeError( 959 'AddExtension() cannot be applied to "%s", because it is not a ' 960 'repeated extension.' % extension.full_name) 961 if extension in self._extension_fields: 962 field = self._extension_fields[extension] 963 else: 964 field = [] 965 self._extension_fields[extension] = field 966 # Composite field 967 if extension.composite_cls: 968 if value is not None: 969 raise TypeError( 970 'value must not be set in AddExtension() for "%s", because it is ' 971 'a message type extension. Set values on the returned message ' 972 'instead.' % extension.full_name) 973 msg = extension.composite_cls() 974 field.append(msg) 975 return msg 976 # Scalar and string field 977 field.append(value) 978 979 def _VerifyExtensionIdentifier(self, extension): 980 if extension.containing_cls != self.__class__: 981 raise TypeError("Containing type of %s is %s, but not %s." 982 % (extension.full_name, 983 extension.containing_cls.__name__, 984 self.__class__.__name__)) 985 986 def _MergeExtensionFields(self, x): 987 for ext, val in x._extension_fields.items(): 988 if ext.is_repeated: 989 for single_val in val: 990 if ext.composite_cls is None: 991 self.AddExtension(ext, single_val) 992 else: 993 self.AddExtension(ext).MergeFrom(single_val) 994 else: 995 if ext.composite_cls is None: 996 self.SetExtension(ext, val) 997 else: 998 self.MutableExtension(ext).MergeFrom(val) 999 1000 def _ListExtensions(self): 1001 return sorted( 1002 (ext for ext in self._extension_fields 1003 if (not ext.is_repeated) or self.ExtensionSize(ext) > 0), 1004 key=lambda item: item.number) 1005 1006 def _ExtensionEquals(self, x): 1007 extensions = self._ListExtensions() 1008 if extensions != x._ListExtensions(): 1009 return False 1010 for ext in extensions: 1011 if ext.is_repeated: 1012 if self.ExtensionSize(ext) != x.ExtensionSize(ext): return False 1013 for e1, e2 in zip(self.ExtensionList(ext), 1014 x.ExtensionList(ext)): 1015 if e1 != e2: return False 1016 else: 1017 if self.GetExtension(ext) != x.GetExtension(ext): return False 1018 return True 1019 1020 def _OutputExtensionFields(self, out, partial, extensions, start_index, 1021 end_field_number): 1022 """Serialize a range of extensions. 1023 1024 To generate canonical output when encoding, we interleave fields and 1025 extensions to preserve tag order. 1026 1027 Generated code will prepare a list of ExtensionIdentifier sorted in field 1028 number order and call this method to serialize a specific range of 1029 extensions. The range is specified by the two arguments, start_index and 1030 end_field_number. 1031 1032 The method will serialize all extensions[i] with i >= start_index and 1033 extensions[i].number < end_field_number. Since extensions argument is sorted 1034 by field_number, this is a contiguous range; the first index j not included 1035 in that range is returned. The return value can be used as the start_index 1036 in the next call to serialize the next range of extensions. 1037 1038 Args: 1039 extensions: A list of ExtensionIdentifier sorted in field number order. 1040 start_index: The start index in the extensions list. 1041 end_field_number: The end field number of the extension range. 1042 1043 Returns: 1044 The first index that is not in the range. Or the size of extensions if all 1045 the extensions are within the range. 1046 """ 1047 def OutputSingleField(ext, value): 1048 out.putVarInt32(ext.wire_tag) 1049 if ext.field_type == TYPE_GROUP: 1050 if partial: 1051 value.OutputPartial(out) 1052 else: 1053 value.OutputUnchecked(out) 1054 out.putVarInt32(ext.wire_tag + 1) # End the group 1055 elif ext.field_type == TYPE_FOREIGN: 1056 if partial: 1057 out.putVarInt32(value.ByteSizePartial()) 1058 value.OutputPartial(out) 1059 else: 1060 out.putVarInt32(value.ByteSize()) 1061 value.OutputUnchecked(out) 1062 else: 1063 Encoder._TYPE_TO_METHOD[ext.field_type](out, value) 1064 1065 for ext_index, ext in enumerate( 1066 itertools.islice(extensions, start_index, None), start=start_index): 1067 if ext.number >= end_field_number: 1068 # exceeding extension range end. 1069 return ext_index 1070 if ext.is_repeated: 1071 for field in self._extension_fields[ext]: 1072 OutputSingleField(ext, field) 1073 else: 1074 OutputSingleField(ext, self._extension_fields[ext]) 1075 return len(extensions) 1076 1077 def _ParseOneExtensionField(self, wire_tag, d): 1078 number = wire_tag >> 3 1079 if number in self._extensions_by_field_number: 1080 ext = self._extensions_by_field_number[number] 1081 if wire_tag != ext.wire_tag: 1082 # wire_tag doesn't match; discard as unknown field. 1083 return 1084 if ext.field_type == TYPE_FOREIGN: 1085 length = d.getVarInt32() 1086 tmp = Decoder(d.buffer(), d.pos(), d.pos() + length) 1087 if ext.is_repeated: 1088 self.AddExtension(ext).TryMerge(tmp) 1089 else: 1090 self.MutableExtension(ext).TryMerge(tmp) 1091 d.skip(length) 1092 elif ext.field_type == TYPE_GROUP: 1093 if ext.is_repeated: 1094 self.AddExtension(ext).TryMerge(d) 1095 else: 1096 self.MutableExtension(ext).TryMerge(d) 1097 else: 1098 value = Decoder._TYPE_TO_METHOD[ext.field_type](d) 1099 if ext.is_repeated: 1100 self.AddExtension(ext, value) 1101 else: 1102 self.SetExtension(ext, value) 1103 else: 1104 # discard unknown extensions. 1105 d.skipData(wire_tag) 1106 1107 def _ExtensionByteSize(self, partial): 1108 size = 0 1109 for extension, value in six.iteritems(self._extension_fields): 1110 ftype = extension.field_type 1111 tag_size = self.lengthVarInt64(extension.wire_tag) 1112 if ftype == TYPE_GROUP: 1113 tag_size *= 2 # end tag 1114 if extension.is_repeated: 1115 size += tag_size * len(value) 1116 for single_value in value: 1117 size += self._FieldByteSize(ftype, single_value, partial) 1118 else: 1119 size += tag_size + self._FieldByteSize(ftype, value, partial) 1120 return size 1121 1122 def _FieldByteSize(self, ftype, value, partial): 1123 size = 0 1124 if ftype == TYPE_STRING: 1125 size = self.lengthString(len(value)) 1126 elif ftype == TYPE_FOREIGN or ftype == TYPE_GROUP: 1127 if partial: 1128 size = self.lengthString(value.ByteSizePartial()) 1129 else: 1130 size = self.lengthString(value.ByteSize()) 1131 elif ftype == TYPE_INT64 or \ 1132 ftype == TYPE_UINT64 or \ 1133 ftype == TYPE_INT32: 1134 size = self.lengthVarInt64(value) 1135 else: 1136 if ftype in Encoder._TYPE_TO_BYTE_SIZE: 1137 size = Encoder._TYPE_TO_BYTE_SIZE[ftype] 1138 else: 1139 raise AssertionError( 1140 'Extension type %d is not recognized.' % ftype) 1141 return size 1142 1143 def _ExtensionDebugString(self, prefix, printElemNumber): 1144 res = '' 1145 extensions = self._ListExtensions() 1146 for extension in extensions: 1147 value = self._extension_fields[extension] 1148 if extension.is_repeated: 1149 cnt = 0 1150 for e in value: 1151 elm="" 1152 if printElemNumber: elm = "(%d)" % cnt 1153 if extension.composite_cls is not None: 1154 res += prefix + "[%s%s] {\n" % \ 1155 (extension.full_name, elm) 1156 res += e.__str__(prefix + " ", printElemNumber) 1157 res += prefix + "}\n" 1158 else: 1159 if extension.composite_cls is not None: 1160 res += prefix + "[%s] {\n" % extension.full_name 1161 res += value.__str__( 1162 prefix + " ", printElemNumber) 1163 res += prefix + "}\n" 1164 else: 1165 if extension.field_type in _TYPE_TO_DEBUG_STRING: 1166 text_value = _TYPE_TO_DEBUG_STRING[ 1167 extension.field_type](self, value) 1168 else: 1169 text_value = self.DebugFormat(value) 1170 res += prefix + "[%s]: %s\n" % (extension.full_name, text_value) 1171 return res 1172 1173 @staticmethod 1174 def _RegisterExtension(cls, extension, composite_cls=None): 1175 extension.containing_cls = cls 1176 extension.composite_cls = composite_cls 1177 if composite_cls is not None: 1178 extension.message_name = composite_cls._PROTO_DESCRIPTOR_NAME 1179 actual_handle = cls._extensions_by_field_number.setdefault( 1180 extension.number, extension) 1181 if actual_handle is not extension: 1182 raise AssertionError( 1183 'Extensions "%s" and "%s" both try to extend message type "%s" with ' 1184 'field number %d.' % 1185 (extension.full_name, actual_handle.full_name, 1186 cls.__name__, extension.number)) 1187