1# Copyright 2016 Google LLC. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#    http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# Copyright 2002, Google LLC.
15
16
17from __future__ import absolute_import
18import array
19import six.moves.http_client
20import itertools
21import re
22import struct
23import six
24
try:
  # NOTE(user): This deliberately uses a non-google-style import to work
  # around a zipimport_tinypar issue with zip files embedded in par files.
  # See http://b/13811096
  import googlecloudsdk.third_party.appengine.proto.proto1 as proto1
except ImportError:
  # proto1 is unavailable (missing deps / unusual environment such as GAE),
  # so define local exception classes with the same names.
  class ProtocolBufferDecodeError(Exception):
    pass

  class ProtocolBufferEncodeError(Exception):
    pass

  class ProtocolBufferReturnError(Exception):
    pass
else:
  # Re-export proto1's exception classes so that both modules raise and
  # catch the very same types.
  (ProtocolBufferDecodeError,
   ProtocolBufferEncodeError,
   ProtocolBufferReturnError) = (proto1.ProtocolBufferDecodeError,
                                 proto1.ProtocolBufferEncodeError,
                                 proto1.ProtocolBufferReturnError)
38
# Names exported by "from ... import *".
__all__ = [
    'ProtocolMessage',
    'Encoder',
    'Decoder',
    'ExtendableProtocolMessage',
    'ProtocolBufferDecodeError',
    'ProtocolBufferEncodeError',
    'ProtocolBufferReturnError',
]

# Splits an absolute http(s) URL into (scheme, host, path); sendCommand uses
# it to parse the Location header of a redirect.
URL_RE = re.compile('^(https?)://([^/]+)(/.*)$')
46
47
class ProtocolMessage:
  """
  The parent class of all protocol buffers.
  NOTE: the methods that unconditionally raise NotImplementedError are
  reimplemented by the subclasses of this class.
  Subclasses are automatically generated by tools/protocol_converter.
  Encoding methods can raise ProtocolBufferEncodeError if a value for an
  integer or long field is too large, or if any required field is not set.
  Decoding methods can raise ProtocolBufferDecodeError if they couldn't
  decode correctly, or the decoded message doesn't have all required fields.
  """

  #####################################
  # methods you should use            #
  #####################################

  def __init__(self, contents=None):
    """Construct a new protocol buffer, with optional starting contents
    in binary protocol buffer format."""
    raise NotImplementedError

  def Clear(self):
    """Erases all fields of protocol buffer (& resets to defaults
    if fields have defaults)."""
    raise NotImplementedError

  def IsInitialized(self, debug_strs=None):
    """returns true iff all required fields have been set."""
    raise NotImplementedError

  def Encode(self):
    """Returns a string representing the protocol buffer object.

    Uses the C++ encoder (_CEncode) when a subclass provides one, and falls
    back to the pure-Python Encoder otherwise.

    Raises:
      ProtocolBufferEncodeError: a required field is missing or a value is
        out of range.
    """
    try:
      return self._CEncode()
    except (NotImplementedError, AttributeError):
      e = Encoder()
      self.Output(e)
      return self._ArrayToBytes(e.buffer())

  def SerializeToString(self):
    """Same as Encode(), but has same name as proto2's serialize function."""
    return self.Encode()

  def SerializePartialToString(self):
    """Returns a string representing the protocol buffer object.
    Same as SerializeToString() but does not enforce required fields are set.
    """
    try:
      return self._CEncodePartial()
    except (NotImplementedError, AttributeError):
      e = Encoder()
      self.OutputPartial(e)
      return self._ArrayToBytes(e.buffer())

  def _CEncode(self):
    """Call into C++ encode code.

    Generated protocol buffer classes will override this method to
    provide C++-based serialization. If a subclass does not
    implement this method, Encode() will fall back to
    using pure-Python encoding.
    """
    raise NotImplementedError

  def _CEncodePartial(self):
    """Same as _CEncode, except does not encode missing required fields."""
    raise NotImplementedError

  def ParseFromString(self, s):
    """Reads data from the string 's'.
    Raises a ProtocolBufferDecodeError if, after successfully reading
    in the contents of 's', this protocol message is still not initialized."""
    self.Clear()
    self.MergeFromString(s)

  def ParsePartialFromString(self, s):
    """Reads data from the string 's'.
    Does not enforce required fields are set."""
    self.Clear()
    self.MergePartialFromString(s)

  def MergeFromString(self, s):
    """Adds in data from the string 's'.
    Raises a ProtocolBufferDecodeError if, after successfully merging
    in the contents of 's', this protocol message is still not initialized."""
    self.MergePartialFromString(s)
    dbg = []
    if not self.IsInitialized(dbg):
      raise ProtocolBufferDecodeError('\n\t'.join(dbg))

  def MergePartialFromString(self, s):
    """Merges in data from the string 's'.
    Does not enforce required fields are set."""
    try:
      self._CMergeFromString(s)
    except (NotImplementedError, AttributeError):
      # If we can't call into C++ to deserialize the string, use
      # the (much slower) pure-Python implementation.
      a = self._BytesToArray(s)
      d = Decoder(a, 0, len(a))
      self.TryMerge(d)

  def _CMergeFromString(self, s):
    """Call into C++ parsing code to merge from a string.

    Does *not* check IsInitialized() before returning.

    Generated protocol buffer classes will override this method to
    provide C++-based deserialization.  If a subclass does not
    implement this method, MergeFromString() will fall back to
    using pure-Python parsing.
    """
    raise NotImplementedError

  def __getstate__(self):
    """Return the pickled representation of the data inside protocol buffer,
    which is the same as its binary-encoded representation (as a string)."""
    return self.Encode()

  def __setstate__(self, contents_):
    """Restore the pickled representation of the data inside protocol buffer.
    Note that the mechanism underlying pickle.load() does not call __init__."""
    self.__init__(contents=contents_)

  def sendCommand(self, server, url, response, follow_redirects=1,
                  secure=0, keyfile=None, certfile=None):
    """posts the protocol buffer to the desired url on the server
    and puts the return data into the protocol buffer 'response'

    NOTE: The underlying socket raises the 'error' exception
    for all I/O related errors (can't connect, etc.).

    If 'response' is None, the server's PB response will be ignored.

    The optional 'follow_redirects' argument indicates the number
    of HTTP redirects that are followed before giving up and raising an
    exception.  The default is 1.

    If 'secure' is true, HTTPS will be used instead of HTTP.  Also,
    'keyfile' and 'certfile' may be set for client authentication.

    Raises:
      ProtocolBufferReturnError: the final HTTP status is not 200.
    """
    data = self.Encode()
    if secure:
      if keyfile and certfile:
        conn = six.moves.http_client.HTTPSConnection(server, key_file=keyfile,
                                                     cert_file=certfile)
      else:
        conn = six.moves.http_client.HTTPSConnection(server)
    else:
      conn = six.moves.http_client.HTTPConnection(server)
    # Everything needed from the reply is read inside the try block, so the
    # connection can be closed on every path (including errors) before any
    # redirect is followed.
    try:
      conn.putrequest("POST", url)
      conn.putheader("Content-Length", "%d" % len(data))
      conn.endheaders()
      conn.send(data)
      resp = conn.getresponse()
      status = resp.status
      location = resp.getheader('Location')
      body = None
      if status == 200 and response is not None:
        body = resp.read()
    finally:
      conn.close()
    if follow_redirects > 0 and status == 302:
      # A 302 without a usable Location header falls through to the
      # non-200 error below instead of crashing in URL_RE.match(None).
      m = URL_RE.match(location or '')
      if m:
        protocol, server, url = m.groups()
        return self.sendCommand(server, url, response,
                                follow_redirects=follow_redirects - 1,
                                secure=(protocol == 'https'),
                                keyfile=keyfile,
                                certfile=certfile)
    if status != 200:
      raise ProtocolBufferReturnError(status)
    if response is not None:
      response.ParseFromString(body)
    return response

  def sendSecureCommand(self, server, keyfile, certfile, url, response,
                        follow_redirects=1):
    """posts the protocol buffer via https to the desired url on the server,
    using the specified key and certificate files, and puts the return
    data into the protocol buffer 'response'.

    See caveats in sendCommand.

    You need an SSL-aware build of the Python2 interpreter to use this command.
    (Python1 is not supported).  An SSL build of python2.2 is in
    /home/build/buildtools/python-ssl-2.2 . An SSL build of python is
    standard on all prod machines.

    keyfile: Contains our private RSA key
    certfile: Contains SSL certificate for remote host
    Specify None for keyfile/certfile if you don't want to do client auth.
    """
    return self.sendCommand(server, url, response,
                            follow_redirects=follow_redirects,
                            secure=1, keyfile=keyfile, certfile=certfile)

  def __str__(self, prefix="", printElemNumber=0):
    """Returns nicely formatted contents of this protocol buffer."""
    raise NotImplementedError

  def ToASCII(self):
    """Returns the protocol buffer as a human-readable string."""
    return self._CToASCII(ProtocolMessage._SYMBOLIC_FULL_ASCII)

  def ToShortASCII(self):
    """Returns the protocol buffer as an ASCII string.
    The output is short, leaving out newlines and some other niceties.
    Defers to the C++ ProtocolPrinter class in SYMBOLIC_SHORT mode.
    """
    return self._CToASCII(ProtocolMessage._SYMBOLIC_SHORT_ASCII)

  # Note that these must be consistent with the ProtocolPrinter::Level C++
  # enum.
  _NUMERIC_ASCII = 0
  _SYMBOLIC_SHORT_ASCII = 1
  _SYMBOLIC_FULL_ASCII = 2

  def _CToASCII(self, output_format):
    """Calls into C++ ASCII-generating code.

    Generated protocol buffer classes will override this method to provide
    C++-based ASCII output.
    """
    raise NotImplementedError

  def ParseASCII(self, ascii_string):
    """Parses a string generated by ToASCII() or by the C++ DebugString()
    method, initializing this protocol buffer with its contents. This method
    raises a ValueError if it encounters an unknown field.
    """
    raise NotImplementedError

  def ParseASCIIIgnoreUnknown(self, ascii_string):
    """Parses a string generated by ToASCII() or by the C++ DebugString()
    method, initializing this protocol buffer with its contents.  Ignores
    unknown fields.
    """
    raise NotImplementedError

  def Equals(self, other):
    """Returns whether or not this protocol buffer is equivalent to another.

    This assumes that self and other are of the same type.
    """
    raise NotImplementedError

  def __eq__(self, other):
    """Implementation of operator ==."""
    # If self and other are of different types we return NotImplemented, which
    # tells the Python interpreter to try some other methods of measuring
    # equality before finally performing an identity comparison.  This allows
    # other classes to implement custom __eq__ or __ne__ methods.
    # See http://docs.sympy.org/_sources/python-comparisons.txt
    if other.__class__ is self.__class__:
      return self.Equals(other)
    return NotImplemented

  def __ne__(self, other):
    """Implementation of operator !=."""
    # We repeat code for __ne__ instead of returning "not (self == other)"
    # so that we can return NotImplemented when comparing against an object of
    # a different type.
    # See http://bugs.python.org/msg76374 for an example of when __ne__ might
    # return something other than the Boolean opposite of __eq__.
    if other.__class__ is self.__class__:
      return not self.Equals(other)
    return NotImplemented

  #####################################
  # methods power-users might want    #
  #####################################

  def Output(self, e):
    """write self to the encoder 'e'."""
    dbg = []
    if not self.IsInitialized(dbg):
      raise ProtocolBufferEncodeError('\n\t'.join(dbg))
    self.OutputUnchecked(e)
    return

  def OutputUnchecked(self, e):
    """write self to the encoder 'e', don't check for initialization."""
    raise NotImplementedError

  def OutputPartial(self, e):
    """write self to the encoder 'e', don't check for initialization and
    don't assume required fields exist."""
    raise NotImplementedError

  def Parse(self, d):
    """reads data from the Decoder 'd'."""
    self.Clear()
    self.Merge(d)
    return

  def Merge(self, d):
    """merges data from the Decoder 'd'."""
    self.TryMerge(d)
    dbg = []
    if not self.IsInitialized(dbg):
      raise ProtocolBufferDecodeError('\n\t'.join(dbg))
    return

  def TryMerge(self, d):
    """merges data from the Decoder 'd'."""
    raise NotImplementedError

  def CopyFrom(self, pb):
    """copy data from another protocol buffer"""
    # Only copying a message onto itself needs special handling: Clear()
    # would otherwise wipe the source before MergeFrom() reads it.  An
    # identity test avoids a full (deep) Equals() comparison on every copy.
    if pb is self:
      return
    self.Clear()
    self.MergeFrom(pb)

  def MergeFrom(self, pb):
    """merge data from another protocol buffer"""
    raise NotImplementedError

  #####################################
  # helper methods for subclasses     #
  #####################################

  def lengthVarInt32(self, n):
    return self.lengthVarInt64(n)

  def lengthVarInt64(self, n):
    """Returns the number of bytes needed to varint-encode n."""
    if n < 0:
      return 10 # ceil(64/7)
    result = 0
    while 1:
      result += 1
      n >>= 7
      if n == 0:
        break
    return result

  def lengthString(self, n):
    """Returns the encoded size of a length-prefixed string of n bytes."""
    return self.lengthVarInt32(n) + n

  def DebugFormat(self, value):
    return "%s" % value
  def DebugFormatInt32(self, value):
    if (value <= -2000000000 or value >= 2000000000):
      return self.DebugFormatFixed32(value)
    return "%d" % value
  def DebugFormatInt64(self, value):
    if (value <= -20000000000000 or value >= 20000000000000):
      return self.DebugFormatFixed64(value)
    return "%d" % value
  def DebugFormatString(self, value):
    """Returns 'value' (bytes or text) as a double-quoted, escaped string."""
    # For now we only escape the bare minimum to insure interoperability
    # and readability. In the future we may want to mimic the c++ behavior
    # more closely, but this will make the code a lot more messy.
    def escape(c):
      # Iterating over bytes yields ints on Python 3; turn them back into
      # one-character strings so that ord() below works for bytes and text.
      if isinstance(c, int):
        c = chr(c)
      o = ord(c)
      if o == 10: return r"\n"   # optional escape
      if o == 39: return r"\'"   # optional escape

      if o == 34: return r'\"'   # necessary escape
      if o == 92: return r"\\"   # necessary escape

      if o >= 127 or o < 32: return "\\%03o" % o # necessary escapes
      return c
    return '"' + "".join(escape(c) for c in value) + '"'
  def DebugFormatFloat(self, value):
    return "%ff" % value
  def DebugFormatFixed32(self, value):
    if (value < 0): value += (1<<32)
    return "0x%x" % value
  def DebugFormatFixed64(self, value):
    if (value < 0): value += (1<<64)
    return "0x%x" % value
  def DebugFormatBool(self, value):
    if value:
      return "true"
    else:
      return "false"

  #####################################
  # internal helpers                  #
  #####################################

  def _ArrayToBytes(self, buf):
    """Returns the contents of the array.array 'buf' as a byte string.

    array.tostring() was removed in Python 3.9 in favour of tobytes();
    Python 2 arrays only provide tostring().
    """
    tobytes = getattr(buf, 'tobytes', None)
    if tobytes is not None:
      return tobytes()
    return buf.tostring()

  def _BytesToArray(self, s):
    """Returns a new array.array('B') holding the bytes of 's'.

    Uses frombytes() when available (Python 3; fromstring() was removed in
    3.9) and fromstring() on Python 2.  On Python 3 a text string is encoded
    as UTF-8 first, so text input keeps being accepted.
    """
    a = array.array('B')
    frombytes = getattr(a, 'frombytes', None)
    if frombytes is None:
      a.fromstring(s)
    else:
      if isinstance(s, str):
        s = s.encode('utf-8')
      frombytes(s)
    return a
420
# Field type codes.  The numeric values must match Proto::Type and
# net/proto/protocoltype.proto:
#   1 DOUBLE   2 FLOAT    3 INT64    4 UINT64   5 INT32    6 FIXED64
#   7 FIXED32  8 BOOL     9 STRING  10 GROUP   11 FOREIGN
(TYPE_DOUBLE, TYPE_FLOAT, TYPE_INT64, TYPE_UINT64, TYPE_INT32, TYPE_FIXED64,
 TYPE_FIXED32, TYPE_BOOL, TYPE_STRING, TYPE_GROUP, TYPE_FOREIGN) = range(1, 12)
433
# Debug-string formatter for each scalar extension field type.  Every entry is
# an unbound ProtocolMessage method taking (self, value).  Composite types
# (TYPE_GROUP, TYPE_FOREIGN) have no entry.  Doubles use the generic "%s"
# formatter: DebugFormatFloat appends an 'f' suffix, which marks a
# single-precision value.
_TYPE_TO_DEBUG_STRING = {
    TYPE_DOUBLE:  ProtocolMessage.DebugFormat,
    TYPE_INT32:   ProtocolMessage.DebugFormatInt32,
    TYPE_INT64:   ProtocolMessage.DebugFormatInt64,
    TYPE_UINT64:  ProtocolMessage.DebugFormatInt64,
    TYPE_FLOAT:   ProtocolMessage.DebugFormatFloat,
    TYPE_STRING:  ProtocolMessage.DebugFormatString,
    TYPE_FIXED32: ProtocolMessage.DebugFormatFixed32,
    TYPE_FIXED64: ProtocolMessage.DebugFormatFixed64,
    TYPE_BOOL:    ProtocolMessage.DebugFormatBool }
444
# users of protocol buffers usually won't need to concern themselves
# with either Encoders or Decoders.
class Encoder:
  """Serializes protocol buffer primitives into an array.array('B').

  Every put* method appends its encoding to self.buf.  The fixed-width and
  varint writers raise ProtocolBufferEncodeError when the value is out of
  range for the requested encoding.
  """

  # types of data (the low three bits of a field tag; see Decoder.skipData)
  NUMERIC     = 0
  DOUBLE      = 1
  STRING      = 2
  STARTGROUP  = 3
  ENDGROUP    = 4
  FLOAT       = 5
  MAX_TYPE    = 6

  def __init__(self):
    self.buf = array.array('B')
    return

  def buffer(self):
    return self.buf

  def put8(self, v):
    if v < 0 or v >= (1<<8): raise ProtocolBufferEncodeError("u8 too big")
    self.buf.append(v & 255)
    return

  def put16(self, v):
    if v < 0 or v >= (1<<16): raise ProtocolBufferEncodeError("u16 too big")
    self.buf.append((v >> 0) & 255)
    self.buf.append((v >> 8) & 255)
    return

  def put32(self, v):
    if v < 0 or v >= (1<<32): raise ProtocolBufferEncodeError("u32 too big")
    self.buf.append((v >> 0) & 255)
    self.buf.append((v >> 8) & 255)
    self.buf.append((v >> 16) & 255)
    self.buf.append((v >> 24) & 255)
    return

  def put64(self, v):
    if v < 0 or v >= (1<<64): raise ProtocolBufferEncodeError("u64 too big")
    self.buf.append((v >> 0) & 255)
    self.buf.append((v >> 8) & 255)
    self.buf.append((v >> 16) & 255)
    self.buf.append((v >> 24) & 255)
    self.buf.append((v >> 32) & 255)
    self.buf.append((v >> 40) & 255)
    self.buf.append((v >> 48) & 255)
    self.buf.append((v >> 56) & 255)
    return

  def putVarInt32(self, v):
    # Profiling has shown this code to be very performance critical
    # so we duplicate code, go for early exits when possible, etc.
    # VarInt32 gets more unrolling because VarInt32s are far and away
    # the most common element in protobufs (field tags and string
    # lengths), so they get more attention.  They're also more
    # likely to fit in one byte (string lengths again), so we
    # check and bail out early if possible.

    buf_append = self.buf.append  # cache attribute lookup
    if v & 127 == v:
      buf_append(v)
      return
    if v >= 0x80000000 or v < -0x80000000:  # python2.4 doesn't fold constants
      raise ProtocolBufferEncodeError("int32 too big")
    if v < 0:
      # Negative int32 values are sign-extended to 64 bits (10 bytes).
      v += 0x10000000000000000
    while True:
      bits = v & 127
      v >>= 7
      if v:
        bits |= 128
      buf_append(bits)
      if not v:
        break
    return

  def putVarInt64(self, v):
    buf_append = self.buf.append
    if v >= 0x8000000000000000 or v < -0x8000000000000000:
      raise ProtocolBufferEncodeError("int64 too big")
    if v < 0:
      v += 0x10000000000000000
    while True:
      bits = v & 127
      v >>= 7
      if v:
        bits |= 128
      buf_append(bits)
      if not v:
        break
    return

  def putVarUint64(self, v):
    buf_append = self.buf.append
    if v < 0 or v >= 0x10000000000000000:
      raise ProtocolBufferEncodeError("uint64 too big")
    while True:
      bits = v & 127
      v >>= 7
      if v:
        bits |= 128
      buf_append(bits)
      if not v:
        break
    return

  def putFloat(self, v):
    self._putBytes(struct.pack("<f", v))
    return

  def putDouble(self, v):
    self._putBytes(struct.pack("<d", v))
    return

  def putBoolean(self, v):
    if v:
      self.buf.append(1)
    else:
      self.buf.append(0)
    return

  def putPrefixedString(self, v):
    """Appends 'v' preceded by its length in bytes as a varint.

    Text (unicode) is encoded as UTF-8 so that the length prefix counts bytes
    rather than characters; this prevents the corrupted encodings that
    unicode values used to cause.  Values that are neither bytes nor text are
    converted with str() first, as before.
    """
    if isinstance(v, bytearray):
      v = bytes(v)
    elif not isinstance(v, (bytes, six.text_type)):
      v = str(v)
    if isinstance(v, six.text_type):
      v = v.encode('utf-8')
    self.putVarInt32(len(v))
    self._putBytes(v)
    return

  def putRawString(self, v):
    """Appends the bytes of 'v' without a length prefix (text -> UTF-8)."""
    if isinstance(v, six.text_type):
      v = v.encode('utf-8')
    self._putBytes(v)

  def _putBytes(self, data):
    """Appends the raw byte string 'data' to the buffer."""
    # array.fromstring() was removed in Python 3.9; frombytes() replaces it.
    # Python 2 arrays only provide fromstring().
    frombytes = getattr(self.buf, 'frombytes', None)
    if frombytes is None:
      self.buf.fromstring(data)
    else:
      frombytes(data)

  _TYPE_TO_METHOD = {
      TYPE_DOUBLE:   putDouble,
      TYPE_FLOAT:    putFloat,
      TYPE_FIXED64:  put64,
      TYPE_FIXED32:  put32,
      TYPE_INT32:    putVarInt32,
      TYPE_INT64:    putVarInt64,
      TYPE_UINT64:   putVarUint64,
      TYPE_BOOL:     putBoolean,
      TYPE_STRING:   putPrefixedString }

  _TYPE_TO_BYTE_SIZE = {
      TYPE_DOUBLE:  8,
      TYPE_FLOAT:   4,
      TYPE_FIXED64: 8,
      TYPE_FIXED32: 4,
      TYPE_BOOL:    1 }
601
class Decoder:
  """Reads protocol buffer primitives from an indexable byte buffer.

  Reading starts at position 'idx' of 'buf'; the bounds-checked getters refuse
  to read past 'limit'.  ProtocolMessage.MergePartialFromString passes an
  array.array('B').  Malformed or truncated input raises
  ProtocolBufferDecodeError.
  """

  def __init__(self, buf, idx, limit):
    self.buf = buf
    self.idx = idx
    self.limit = limit
    return

  def avail(self):
    return self.limit - self.idx

  def buffer(self):
    return self.buf

  def pos(self):
    return self.idx

  def skip(self, n):
    if self.idx + n > self.limit: raise ProtocolBufferDecodeError("truncated")
    self.idx += n
    return

  def skipData(self, tag):
    """Skips over the value of a field whose tag has just been read."""
    t = tag & 7               # tag format type
    if t == Encoder.NUMERIC:
      self.getVarInt64()
    elif t == Encoder.DOUBLE:
      self.skip(8)
    elif t == Encoder.STRING:
      n = self.getVarInt32()
      # A negative length would move idx backwards and could make a parser
      # loop forever on malformed input.
      if n < 0: raise ProtocolBufferDecodeError("corrupted")
      self.skip(n)
    elif t == Encoder.STARTGROUP:
      while 1:
        t = self.getVarInt32()
        if (t & 7) == Encoder.ENDGROUP:
          break
        else:
          self.skipData(t)
      # The closing tag must carry the same field number as the opening one.
      if (t - Encoder.ENDGROUP) != (tag - Encoder.STARTGROUP):
        raise ProtocolBufferDecodeError("corrupted")
    elif t == Encoder.ENDGROUP:
      raise ProtocolBufferDecodeError("corrupted")
    elif t == Encoder.FLOAT:
      self.skip(4)
    else:
      raise ProtocolBufferDecodeError("corrupted")

  # these are all unsigned gets
  def get8(self):
    if self.idx >= self.limit: raise ProtocolBufferDecodeError("truncated")
    c = self.buf[self.idx]
    self.idx += 1
    return c

  def get16(self):
    if self.idx + 2 > self.limit: raise ProtocolBufferDecodeError("truncated")
    c = self.buf[self.idx]
    d = self.buf[self.idx + 1]
    self.idx += 2
    return (d << 8) | c

  def get32(self):
    if self.idx + 4 > self.limit: raise ProtocolBufferDecodeError("truncated")
    c = self.buf[self.idx]
    d = self.buf[self.idx + 1]
    e = self.buf[self.idx + 2]
    f = int(self.buf[self.idx + 3])
    self.idx += 4
    return (f << 24) | (e << 16) | (d << 8) | c

  def get64(self):
    if self.idx + 8 > self.limit: raise ProtocolBufferDecodeError("truncated")
    c = self.buf[self.idx]
    d = self.buf[self.idx + 1]
    e = self.buf[self.idx + 2]
    f = int(self.buf[self.idx + 3])
    g = int(self.buf[self.idx + 4])
    h = int(self.buf[self.idx + 5])
    i = int(self.buf[self.idx + 6])
    j = int(self.buf[self.idx + 7])
    self.idx += 8
    return ((j << 56) | (i << 48) | (h << 40) | (g << 32) | (f << 24)
            | (e << 16) | (d << 8) | c)

  def getVarInt32(self):
    # getVarInt32 gets different treatment than other integer getter
    # functions due to the much larger number of varInt32s and also
    # varInt32s that fit in one byte.  See the comment at putVarInt32.
    b = self.get8()
    if not (b & 128):
      return b

    result = int(0)
    shift = 0

    while 1:
      result |= (int(b & 127) << shift)
      shift += 7
      if not (b & 128):
        if result >= 0x10000000000000000:  # (1L << 64):
          raise ProtocolBufferDecodeError("corrupted")
        break
      if shift >= 64: raise ProtocolBufferDecodeError("corrupted")
      b = self.get8()

    if result >= 0x8000000000000000:  # (1L << 63)
      result -= 0x10000000000000000  # (1L << 64)
    if result >= 0x80000000 or result < -0x80000000:  # (1L << 31)
      raise ProtocolBufferDecodeError("corrupted")
    return result

  def getVarInt64(self):
    result = self.getVarUint64()
    if result >= (1 << 63):
      result -= (1 << 64)
    return result

  def getVarUint64(self):
    result = int(0)
    shift = 0
    while 1:
      if shift >= 64: raise ProtocolBufferDecodeError("corrupted")
      b = self.get8()
      result |= (int(b & 127) << shift)
      shift += 7
      if not (b & 128):
        if result >= (1 << 64): raise ProtocolBufferDecodeError("corrupted")
        return result
    return result             # make pychecker happy

  def getFloat(self):
    if self.idx + 4 > self.limit: raise ProtocolBufferDecodeError("truncated")
    a = self.buf[self.idx:self.idx+4]
    self.idx += 4
    return struct.unpack("<f", a)[0]

  def getDouble(self):
    if self.idx + 8 > self.limit: raise ProtocolBufferDecodeError("truncated")
    a = self.buf[self.idx:self.idx+8]
    self.idx += 8
    return struct.unpack("<d", a)[0]

  def getBoolean(self):
    b = self.get8()
    if b != 0 and b != 1: raise ProtocolBufferDecodeError("corrupted")
    return b

  def getPrefixedString(self):
    """Reads a varint length followed by that many bytes."""
    length = self.getVarInt32()
    # Reject negative lengths: they would rewind idx instead of advancing.
    if length < 0:
      raise ProtocolBufferDecodeError("corrupted")
    if self.idx + length > self.limit:
      raise ProtocolBufferDecodeError("truncated")
    r = self.buf[self.idx : self.idx + length]
    self.idx += length
    return self._arrayToBytes(r)

  def getRawString(self):
    """Returns all remaining bytes up to the limit."""
    r = self.buf[self.idx:self.limit]
    self.idx = self.limit
    return self._arrayToBytes(r)

  def _arrayToBytes(self, arr):
    """Returns the contents of the array slice 'arr' as a byte string."""
    # array.tostring() was removed in Python 3.9; tobytes() replaces it.
    # Python 2 arrays only provide tostring().
    tobytes = getattr(arr, 'tobytes', None)
    if tobytes is not None:
      return tobytes()
    return arr.tostring()

  _TYPE_TO_METHOD = {
      TYPE_DOUBLE:   getDouble,
      TYPE_FLOAT:    getFloat,
      TYPE_FIXED64:  get64,
      TYPE_FIXED32:  get32,
      TYPE_INT32:    getVarInt32,
      TYPE_INT64:    getVarInt64,
      TYPE_UINT64:   getVarUint64,
      TYPE_BOOL:     getBoolean,
      TYPE_STRING:   getPrefixedString }
771
772#####################################
773# extensions                        #
774#####################################
775
class ExtensionIdentifier(object):
  """Describes one extension field of an extendable protocol message.

  The constructor assigns only the six descriptive attributes.  The remaining
  slots (containing_cls, composite_cls, message_name) are left unassigned;
  reading one of them before it has been set raises AttributeError.
  """
  __slots__ = ('full_name', 'number', 'field_type', 'wire_tag', 'is_repeated',
               'default', 'containing_cls', 'composite_cls', 'message_name')

  def __init__(self, full_name, number, field_type, wire_tag, is_repeated,
               default):
    (self.full_name, self.number, self.field_type,
     self.wire_tag, self.is_repeated, self.default) = (
         full_name, number, field_type, wire_tag, is_repeated, default)
787
788class ExtendableProtocolMessage(ProtocolMessage):
789  def HasExtension(self, extension):
790    """Checks if the message contains a certain non-repeated extension."""
791    self._VerifyExtensionIdentifier(extension)
792    return extension in self._extension_fields
793
794  def ClearExtension(self, extension):
795    """Clears the value of extension, so that HasExtension() returns false or
796    ExtensionSize() returns 0."""
797    self._VerifyExtensionIdentifier(extension)
798    if extension in self._extension_fields:
799      del self._extension_fields[extension]
800
801  def GetExtension(self, extension, index=None):
802    """Gets the extension value for a certain extension.
803
804    Args:
805      extension: The ExtensionIdentifier for the extension.
806      index: The index of element to get in a repeated field. Only needed if
807          the extension is repeated.
808
809    Returns:
810      The value of the extension if exists, otherwise the default value of the
811      extension will be returned.
812    """
813    self._VerifyExtensionIdentifier(extension)
814    if extension in self._extension_fields:
815      result = self._extension_fields[extension]
816    else:
817      if extension.is_repeated:
818        result = []
819      elif extension.composite_cls:
820        result = extension.composite_cls()
821      else:
822        result = extension.default
823    if extension.is_repeated:
824      result = result[index]
825    return result
826
827  def SetExtension(self, extension, *args):
828    """Sets the extension value for a certain scalar type extension.
829
830    Arg varies according to extension type:
831    - Singular:
832      message.SetExtension(extension, value)
833    - Repeated:
834      message.SetExtension(extension, index, value)
835    where
836      extension: The ExtensionIdentifier for the extension.
837      index: The index of element to set in a repeated field. Only needed if
838          the extension is repeated.
839      value: The value to set.
840
841    Raises:
842      TypeError if a message type extension is given.
843    """
844    self._VerifyExtensionIdentifier(extension)
845    if extension.composite_cls:
846      raise TypeError(
847          'Cannot assign to extension "%s" because it is a composite type.' %
848          extension.full_name)
849    if extension.is_repeated:
850      try:
851        index, value = args
852      except ValueError:
853        raise TypeError(
854            "SetExtension(extension, index, value) for repeated extension "
855            "takes exactly 4 arguments: (%d given)" % (len(args) + 2))
856      self._extension_fields[extension][index] = value
857    else:
858      try:
859        (value,) = args
860      except ValueError:
861        raise TypeError(
862            "SetExtension(extension, value) for singular extension "
863            "takes exactly 3 arguments: (%d given)" % (len(args) + 2))
864      self._extension_fields[extension] = value
865
866  def MutableExtension(self, extension, index=None):
867    """Gets a mutable reference of a message type extension.
868
869    For repeated extension, index must be specified, and only one element will
870    be returned. For optional extension, if the extension does not exist, a new
871    message will be created and set in parent message.
872
873    Args:
874      extension: The ExtensionIdentifier for the extension.
875      index: The index of element to mutate in a repeated field. Only needed if
876          the extension is repeated.
877
878    Returns:
879      The mutable message reference.
880
881    Raises:
882      TypeError if non-message type extension is given.
883    """
884    self._VerifyExtensionIdentifier(extension)
885    if extension.composite_cls is None:
886      raise TypeError(
887          'MutableExtension() cannot be applied to "%s", because it is not a '
888          'composite type.' % extension.full_name)
889    if extension.is_repeated:
890      if index is None:
891        raise TypeError(
892            'MutableExtension(extension, index) for repeated extension '
893            'takes exactly 2 arguments: (1 given)')
894      return self.GetExtension(extension, index)
895    if extension in self._extension_fields:
896      return self._extension_fields[extension]
897    else:
898      result = extension.composite_cls()
899      self._extension_fields[extension] = result
900      return result
901
902  def ExtensionList(self, extension):
903    """Returns a mutable list of extensions.
904
905    Raises:
906      TypeError if the extension is not repeated.
907    """
908    self._VerifyExtensionIdentifier(extension)
909    if not extension.is_repeated:
910      raise TypeError(
911          'ExtensionList() cannot be applied to "%s", because it is not a '
912          'repeated extension.' % extension.full_name)
913    if extension in self._extension_fields:
914      return self._extension_fields[extension]
915    result = []
916    self._extension_fields[extension] = result
917    return result
918
919  def ExtensionSize(self, extension):
920    """Returns the size of a repeated extension.
921
922    Raises:
923      TypeError if the extension is not repeated.
924    """
925    self._VerifyExtensionIdentifier(extension)
926    if not extension.is_repeated:
927      raise TypeError(
928          'ExtensionSize() cannot be applied to "%s", because it is not a '
929          'repeated extension.' % extension.full_name)
930    if extension in self._extension_fields:
931      return len(self._extension_fields[extension])
932    return 0
933
934  def AddExtension(self, extension, value=None):
935    """Appends a new element into a repeated extension.
936
937    Arg varies according to the extension field type:
938    - Scalar/String:
939      message.AddExtension(extension, value)
940    - Message:
941      mutable_message = AddExtension(extension)
942
943    Args:
944      extension: The ExtensionIdentifier for the extension.
945      value: The value of the extension if the extension is scalar/string type.
946          The value must NOT be set for message type extensions; set values on
947          the returned message object instead.
948
949    Returns:
950      A mutable new message if it's a message type extension, or None otherwise.
951
952    Raises:
953      TypeError if the extension is not repeated, or value is given for message
954      type extensions.
955    """
956    self._VerifyExtensionIdentifier(extension)
957    if not extension.is_repeated:
958      raise TypeError(
959          'AddExtension() cannot be applied to "%s", because it is not a '
960          'repeated extension.' % extension.full_name)
961    if extension in self._extension_fields:
962      field = self._extension_fields[extension]
963    else:
964      field = []
965      self._extension_fields[extension] = field
966    # Composite field
967    if extension.composite_cls:
968      if value is not None:
969        raise TypeError(
970            'value must not be set in AddExtension() for "%s", because it is '
971            'a message type extension. Set values on the returned message '
972            'instead.' % extension.full_name)
973      msg = extension.composite_cls()
974      field.append(msg)
975      return msg
976    # Scalar and string field
977    field.append(value)
978
979  def _VerifyExtensionIdentifier(self, extension):
980    if extension.containing_cls != self.__class__:
981      raise TypeError("Containing type of %s is %s, but not %s."
982                      % (extension.full_name,
983                         extension.containing_cls.__name__,
984                         self.__class__.__name__))
985
986  def _MergeExtensionFields(self, x):
987    for ext, val in x._extension_fields.items():
988      if ext.is_repeated:
989        for single_val in val:
990          if ext.composite_cls is None:
991            self.AddExtension(ext, single_val)
992          else:
993            self.AddExtension(ext).MergeFrom(single_val)
994      else:
995        if ext.composite_cls is None:
996          self.SetExtension(ext, val)
997        else:
998          self.MutableExtension(ext).MergeFrom(val)
999
1000  def _ListExtensions(self):
1001    return sorted(
1002        (ext for ext in self._extension_fields
1003         if (not ext.is_repeated) or self.ExtensionSize(ext) > 0),
1004        key=lambda item: item.number)
1005
1006  def _ExtensionEquals(self, x):
1007    extensions = self._ListExtensions()
1008    if extensions != x._ListExtensions():
1009      return False
1010    for ext in extensions:
1011      if ext.is_repeated:
1012        if self.ExtensionSize(ext) != x.ExtensionSize(ext): return False
1013        for e1, e2 in zip(self.ExtensionList(ext),
1014                                     x.ExtensionList(ext)):
1015          if e1 != e2: return False
1016      else:
1017        if self.GetExtension(ext) != x.GetExtension(ext): return False
1018    return True
1019
1020  def _OutputExtensionFields(self, out, partial, extensions, start_index,
1021                             end_field_number):
1022    """Serialize a range of extensions.
1023
1024    To generate canonical output when encoding, we interleave fields and
1025    extensions to preserve tag order.
1026
1027    Generated code will prepare a list of ExtensionIdentifier sorted in field
1028    number order and call this method to serialize a specific range of
1029    extensions. The range is specified by the two arguments, start_index and
1030    end_field_number.
1031
1032    The method will serialize all extensions[i] with i >= start_index and
1033    extensions[i].number < end_field_number. Since extensions argument is sorted
1034    by field_number, this is a contiguous range; the first index j not included
1035    in that range is returned. The return value can be used as the start_index
1036    in the next call to serialize the next range of extensions.
1037
1038    Args:
1039      extensions: A list of ExtensionIdentifier sorted in field number order.
1040      start_index: The start index in the extensions list.
1041      end_field_number: The end field number of the extension range.
1042
1043    Returns:
1044      The first index that is not in the range. Or the size of extensions if all
1045      the extensions are within the range.
1046    """
1047    def OutputSingleField(ext, value):
1048      out.putVarInt32(ext.wire_tag)
1049      if ext.field_type == TYPE_GROUP:
1050        if partial:
1051          value.OutputPartial(out)
1052        else:
1053          value.OutputUnchecked(out)
1054        out.putVarInt32(ext.wire_tag + 1)  # End the group
1055      elif ext.field_type == TYPE_FOREIGN:
1056        if partial:
1057          out.putVarInt32(value.ByteSizePartial())
1058          value.OutputPartial(out)
1059        else:
1060          out.putVarInt32(value.ByteSize())
1061          value.OutputUnchecked(out)
1062      else:
1063        Encoder._TYPE_TO_METHOD[ext.field_type](out, value)
1064
1065    for ext_index, ext in enumerate(
1066        itertools.islice(extensions, start_index, None), start=start_index):
1067      if ext.number >= end_field_number:
1068        # exceeding extension range end.
1069        return ext_index
1070      if ext.is_repeated:
1071        for field in self._extension_fields[ext]:
1072          OutputSingleField(ext, field)
1073      else:
1074        OutputSingleField(ext, self._extension_fields[ext])
1075    return len(extensions)
1076
1077  def _ParseOneExtensionField(self, wire_tag, d):
1078    number = wire_tag >> 3
1079    if number in self._extensions_by_field_number:
1080      ext = self._extensions_by_field_number[number]
1081      if wire_tag != ext.wire_tag:
1082        # wire_tag doesn't match; discard as unknown field.
1083        return
1084      if ext.field_type == TYPE_FOREIGN:
1085        length = d.getVarInt32()
1086        tmp = Decoder(d.buffer(), d.pos(), d.pos() + length)
1087        if ext.is_repeated:
1088          self.AddExtension(ext).TryMerge(tmp)
1089        else:
1090          self.MutableExtension(ext).TryMerge(tmp)
1091        d.skip(length)
1092      elif ext.field_type == TYPE_GROUP:
1093        if ext.is_repeated:
1094          self.AddExtension(ext).TryMerge(d)
1095        else:
1096          self.MutableExtension(ext).TryMerge(d)
1097      else:
1098        value = Decoder._TYPE_TO_METHOD[ext.field_type](d)
1099        if ext.is_repeated:
1100          self.AddExtension(ext, value)
1101        else:
1102          self.SetExtension(ext, value)
1103    else:
1104      # discard unknown extensions.
1105      d.skipData(wire_tag)
1106
1107  def _ExtensionByteSize(self, partial):
1108    size = 0
1109    for extension, value in six.iteritems(self._extension_fields):
1110      ftype = extension.field_type
1111      tag_size = self.lengthVarInt64(extension.wire_tag)
1112      if ftype == TYPE_GROUP:
1113        tag_size *= 2  # end tag
1114      if extension.is_repeated:
1115        size += tag_size * len(value)
1116        for single_value in value:
1117          size += self._FieldByteSize(ftype, single_value, partial)
1118      else:
1119        size += tag_size + self._FieldByteSize(ftype, value, partial)
1120    return size
1121
1122  def _FieldByteSize(self, ftype, value, partial):
1123    size = 0
1124    if ftype == TYPE_STRING:
1125      size = self.lengthString(len(value))
1126    elif ftype == TYPE_FOREIGN or ftype == TYPE_GROUP:
1127      if partial:
1128        size = self.lengthString(value.ByteSizePartial())
1129      else:
1130        size = self.lengthString(value.ByteSize())
1131    elif ftype == TYPE_INT64 or  \
1132         ftype == TYPE_UINT64 or \
1133         ftype == TYPE_INT32:
1134      size = self.lengthVarInt64(value)
1135    else:
1136      if ftype in Encoder._TYPE_TO_BYTE_SIZE:
1137        size = Encoder._TYPE_TO_BYTE_SIZE[ftype]
1138      else:
1139        raise AssertionError(
1140            'Extension type %d is not recognized.' % ftype)
1141    return size
1142
1143  def _ExtensionDebugString(self, prefix, printElemNumber):
1144    res = ''
1145    extensions = self._ListExtensions()
1146    for extension in extensions:
1147      value = self._extension_fields[extension]
1148      if extension.is_repeated:
1149        cnt = 0
1150        for e in value:
1151          elm=""
1152          if printElemNumber: elm = "(%d)" % cnt
1153          if extension.composite_cls is not None:
1154            res += prefix + "[%s%s] {\n" % \
1155                (extension.full_name, elm)
1156            res += e.__str__(prefix + "  ", printElemNumber)
1157            res += prefix + "}\n"
1158      else:
1159        if extension.composite_cls is not None:
1160          res += prefix + "[%s] {\n" % extension.full_name
1161          res += value.__str__(
1162              prefix + "  ", printElemNumber)
1163          res += prefix + "}\n"
1164        else:
1165          if extension.field_type in _TYPE_TO_DEBUG_STRING:
1166            text_value = _TYPE_TO_DEBUG_STRING[
1167                extension.field_type](self, value)
1168          else:
1169            text_value = self.DebugFormat(value)
1170          res += prefix + "[%s]: %s\n" % (extension.full_name, text_value)
1171    return res
1172
1173  @staticmethod
1174  def _RegisterExtension(cls, extension, composite_cls=None):
1175    extension.containing_cls = cls
1176    extension.composite_cls = composite_cls
1177    if composite_cls is not None:
1178      extension.message_name = composite_cls._PROTO_DESCRIPTOR_NAME
1179    actual_handle = cls._extensions_by_field_number.setdefault(
1180        extension.number, extension)
1181    if actual_handle is not extension:
1182      raise AssertionError(
1183          'Extensions "%s" and "%s" both try to extend message type "%s" with '
1184          'field number %d.' %
1185          (extension.full_name, actual_handle.full_name,
1186           cls.__name__, extension.number))
1187