1# Protocol Buffers - Google's data interchange format
2# Copyright 2008 Google Inc.  All rights reserved.
3# https://developers.google.com/protocol-buffers/
4#
5# Redistribution and use in source and binary forms, with or without
6# modification, are permitted provided that the following conditions are
7# met:
8#
9#     * Redistributions of source code must retain the above copyright
10# notice, this list of conditions and the following disclaimer.
11#     * Redistributions in binary form must reproduce the above
12# copyright notice, this list of conditions and the following disclaimer
13# in the documentation and/or other materials provided with the
14# distribution.
15#     * Neither the name of Google Inc. nor the names of its
16# contributors may be used to endorse or promote products derived from
17# this software without specific prior written permission.
18#
19# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
23# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
24# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
25# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
26# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
27# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
28# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
31# This code is meant to work on Python 2.4 and above only.
32#
33# TODO(user): Helpers for verbose, common checks like seeing if a
34# descriptor's cpp_type is CPPTYPE_MESSAGE.
35
36"""Contains a metaclass and helper functions used to create
37protocol message classes from Descriptor objects at runtime.
38
39Recall that a metaclass is the "type" of a class.
40(A class is to a metaclass what an instance is to a class.)
41
42In this case, we use the GeneratedProtocolMessageType metaclass
43to inject all the useful functionality into the classes
44output by the protocol compiler at compile-time.
45
46The upshot of all this is that the real implementation
47details for ALL pure-Python protocol buffers are *here in
48this file*.
49"""
50
51__author__ = 'robinson@google.com (Will Robinson)'
52
53from io import BytesIO
54import struct
55import sys
56import weakref
57
58import six
59from six.moves import range
60
61# We use "as" to avoid name collisions with variables.
62from google.protobuf.internal import api_implementation
63from google.protobuf.internal import containers
64from google.protobuf.internal import decoder
65from google.protobuf.internal import encoder
66from google.protobuf.internal import enum_type_wrapper
67from google.protobuf.internal import extension_dict
68from google.protobuf.internal import message_listener as message_listener_mod
69from google.protobuf.internal import type_checkers
70from google.protobuf.internal import well_known_types
71from google.protobuf.internal import wire_format
72from google.protobuf import descriptor as descriptor_mod
73from google.protobuf import message as message_mod
74from google.protobuf import text_format
75
76_FieldDescriptor = descriptor_mod.FieldDescriptor
77_AnyFullTypeName = 'google.protobuf.Any'
78_ExtensionDict = extension_dict._ExtensionDict
79
class GeneratedProtocolMessageType(type):

  """Metaclass that turns a Descriptor into a full protocol message class.

  Every class created through this metaclass receives implementations of
  the Message methods, one property per field of the message, and a
  __slots__ declaration so that assigning to a name that is not a field
  fails instead of silently creating an attribute that would never be
  serialized.

  The protocol compiler's generated code relies on this metaclass, and
  clients may also build classes at runtime, for example:

  mydescriptor = Descriptor(.....)
  factory = symbol_database.Default()
  factory.pool.AddDescriptor(mydescriptor)
  MyProtoClass = factory.GetPrototype(mydescriptor)
  myproto_instance = MyProtoClass()
  myproto_instance.foo_field = 23
  ...
  """

  # Key under which the class dictionary carries the Descriptor.  Must be
  # consistent with the protocol-compiler code in
  # proto2/compiler/internal/generator.*.
  _DESCRIPTOR_KEY = 'DESCRIPTOR'

  def __new__(cls, name, bases, dictionary):
    """Allocates the message class, or returns the one that already exists.

    __slots__ has to be placed in the class dictionary before the class
    object is created, which is why this work happens in __new__ rather
    than in __init__.

    Args:
      name: Name of the class being created; forwarded to type.__new__.
      bases: Tuple of base classes.  When the descriptor's full name has an
        entry in well_known_types.WKTBASES, that entry is appended.
      dictionary: Class dictionary.  dictionary[_DESCRIPTOR_KEY] must hold
        the Descriptor describing this protocol message type.

    Returns:
      The class already recorded on the descriptor as `_concrete_class`, if
      there is one; otherwise the newly-allocated class.

    Raises:
      RuntimeError: The descriptor entry is a str, meaning the generated
        code expects the cpp extension rather than this pure-Python runtime.
    """
    descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY]

    if isinstance(descriptor, str):
      raise RuntimeError('The generated code only work with python cpp '
                         'extension, but it is using pure python runtime.')

    # Reuse the concrete class if the descriptor already has one.  Creating
    # a second class would break messages that were built with the first.
    # This most commonly happens in `text_format.py` with descriptors from a
    # custom pool, where a prototype is requested for a descriptor that
    # already has a concrete class.
    already_built = getattr(descriptor, '_concrete_class', None)
    if already_built:
      return already_built

    full_name = descriptor.full_name
    if full_name in well_known_types.WKTBASES:
      bases += (well_known_types.WKTBASES[full_name],)

    # Both helpers mutate `dictionary` in place before the class is created.
    _AddClassAttributesForNestedExtensions(descriptor, dictionary)
    _AddSlots(descriptor, dictionary)

    return super(GeneratedProtocolMessageType, cls).__new__(
        cls, name, bases, dictionary)

  def __init__(cls, name, bases, dictionary):
    """Populates a freshly created message class.

    Registers decoders, attaches per-field helpers to the field
    descriptors, records the class on the descriptor, and then installs
    enum values, __init__, field and extension properties, static methods,
    message methods and private helpers.

    Args:
      name: Name of the class; forwarded to type.__init__.
      bases: Base classes of the class; forwarded to type.__init__.
      dictionary: Class dictionary.  dictionary[_DESCRIPTOR_KEY] must hold
        the Descriptor describing this protocol message type.
    """
    descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY]

    # When __new__ handed back an existing class there is nothing left to
    # initialize; just verify that it really is the same class.
    registered = getattr(descriptor, '_concrete_class', None)
    if registered:
      assert registered is cls, (
          'Duplicate `GeneratedProtocolMessageType` created for descriptor %r'
          % (descriptor.full_name))
      return

    decoders_by_tag = cls._decoders_by_tag = {}
    uses_message_set = (descriptor.has_options and
                        descriptor.GetOptions().message_set_wire_format)
    if uses_message_set:
      decoders_by_tag[decoder.MESSAGE_SET_ITEM_TAG] = (
          decoder.MessageSetItemDecoder(descriptor), None)

    # Attach stuff to each FieldDescriptor for quick lookup later on.
    for field_descriptor in descriptor.fields:
      _AttachFieldHelpers(cls, field_descriptor)

    # Record the class on the descriptor before the remaining helpers run.
    descriptor._concrete_class = cls  # pylint: disable=protected-access

    _AddEnumValues(descriptor, cls)
    _AddInitMethod(descriptor, cls)
    _AddPropertiesForFields(descriptor, cls)
    _AddPropertiesForExtensions(descriptor, cls)
    _AddStaticMethods(cls)
    _AddMessageMethods(descriptor, cls)
    _AddPrivateHelperMethods(descriptor, cls)

    super(GeneratedProtocolMessageType, cls).__init__(name, bases, dictionary)
209
210
211# Stateless helpers for GeneratedProtocolMessageType below.
212# Outside clients should not access these directly.
213#
214# I opted not to make any of these methods on the metaclass, to make it more
215# clear that I'm not really using any state there and to keep clients from
216# thinking that they have direct access to these construction helpers.
217
218
219def _PropertyName(proto_field_name):
220  """Returns the name of the public property attribute which
221  clients can use to get and (in some cases) set the value
222  of a protocol message field.
223
224  Args:
225    proto_field_name: The protocol message field name, exactly
226      as it appears (or would appear) in a .proto file.
227  """
228  # TODO(user): Escape Python keywords (e.g., yield), and test this support.
229  # nnorwitz makes my day by writing:
230  # """
231  # FYI.  See the keyword module in the stdlib. This could be as simple as:
232  #
233  # if keyword.iskeyword(proto_field_name):
234  #   return proto_field_name + "_"
235  # return proto_field_name
236  # """
237  # Kenton says:  The above is a BAD IDEA.  People rely on being able to use
238  #   getattr() and setattr() to reflectively manipulate field values.  If we
239  #   rename the properties, then every such user has to also make sure to apply
240  #   the same transformation.  Note that currently if you name a field "yield",
241  #   you can still access it just fine using getattr/setattr -- it's not even
242  #   that cumbersome to do so.
243  # TODO(user):  Remove this method entirely if/when everyone agrees with my
244  #   position.
245  return proto_field_name
246
247
248def _AddSlots(message_descriptor, dictionary):
249  """Adds a __slots__ entry to dictionary, containing the names of all valid
250  attributes for this message type.
251
252  Args:
253    message_descriptor: A Descriptor instance describing this message type.
254    dictionary: Class dictionary to which we'll add a '__slots__' entry.
255  """
256  dictionary['__slots__'] = ['_cached_byte_size',
257                             '_cached_byte_size_dirty',
258                             '_fields',
259                             '_unknown_fields',
260                             '_unknown_field_set',
261                             '_is_present_in_parent',
262                             '_listener',
263                             '_listener_for_children',
264                             '__weakref__',
265                             '_oneofs']
266
267
268def _IsMessageSetExtension(field):
269  return (field.is_extension and
270          field.containing_type.has_options and
271          field.containing_type.GetOptions().message_set_wire_format and
272          field.type == _FieldDescriptor.TYPE_MESSAGE and
273          field.label == _FieldDescriptor.LABEL_OPTIONAL)
274
275
def _IsMapField(field):
  """Tells whether `field` is a map field.

  A field counts as a map when it is a TYPE_MESSAGE field whose message
  type has options with `map_entry` set.
  """
  if field.type != _FieldDescriptor.TYPE_MESSAGE:
    return False
  entry_type = field.message_type
  return entry_type.has_options and entry_type.GetOptions().map_entry
280
281
def _IsMessageMapField(field):
  """Tells whether the 'value' field of a map entry is itself a message.

  Args:
    field: FieldDescriptor whose message type has a field named 'value'.
  """
  entry_fields = field.message_type.fields_by_name
  value_is_message = (
      entry_fields['value'].cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE)
  return value_is_message
285
286
287def _IsStrictUtf8Check(field):
288  if field.containing_type.syntax != 'proto3':
289    return False
290  enforce_utf8 = True
291  return enforce_utf8
292
293
def _AttachFieldHelpers(cls, field_descriptor):
  """Precomputes the encoder, sizer, default constructor and decoders.

  The encoder, sizer and default-value constructor are stored on the field
  descriptor as `_encoder`, `_sizer` and `_default_constructor`.  Decoders
  are stored in `cls._decoders_by_tag`, keyed by the encoded tag bytes, as
  `(decoder, field descriptor or None)` pairs.

  Args:
    cls: The message class under construction; must already have a
      `_decoders_by_tag` dict.
    field_descriptor: FieldDescriptor of the field to prepare.
  """
  is_repeated = (field_descriptor.label == _FieldDescriptor.LABEL_REPEATED)
  is_packable = (is_repeated and
                 wire_format.IsTypePackable(field_descriptor.type))

  # Work out whether values are written in packed form.  Under proto2 the
  # `packed` option must be set; under any other syntax a packable field is
  # packed unless the option is explicitly present and false.
  if is_packable and field_descriptor.containing_type.syntax == 'proto2':
    is_packed = (field_descriptor.has_options and
                 field_descriptor.GetOptions().packed)
  elif is_packable:
    explicitly_unpacked = (
        field_descriptor.has_options and
        field_descriptor.GetOptions().HasField('packed') and
        not field_descriptor.GetOptions().packed)
    is_packed = not explicitly_unpacked
  else:
    is_packed = False

  is_map_entry = _IsMapField(field_descriptor)

  if is_map_entry:
    field_encoder = encoder.MapEncoder(field_descriptor)
    sizer = encoder.MapSizer(field_descriptor,
                             _IsMessageMapField(field_descriptor))
  elif _IsMessageSetExtension(field_descriptor):
    field_encoder = encoder.MessageSetItemEncoder(field_descriptor.number)
    sizer = encoder.MessageSetItemSizer(field_descriptor.number)
  else:
    make_encoder = type_checkers.TYPE_TO_ENCODER[field_descriptor.type]
    field_encoder = make_encoder(
        field_descriptor.number, is_repeated, is_packed)
    make_sizer = type_checkers.TYPE_TO_SIZER[field_descriptor.type]
    sizer = make_sizer(field_descriptor.number, is_repeated, is_packed)

  field_descriptor._encoder = field_encoder
  field_descriptor._sizer = sizer
  field_descriptor._default_constructor = _DefaultValueConstructorForField(
      field_descriptor)

  def AddDecoder(wiretype, is_packed):
    # Registers one decoder for this field under the tag for `wiretype`.
    # Note that `is_packed` here is the argument, not the outer variable.
    tag_bytes = encoder.TagBytes(field_descriptor.number, wiretype)

    decode_type = field_descriptor.type
    if (decode_type == _FieldDescriptor.TYPE_ENUM and
        type_checkers.SupportsOpenEnums(field_descriptor)):
      # Enums that support open values are decoded like int32.
      decode_type = _FieldDescriptor.TYPE_INT32

    # For a oneof member the second slot of the registry entry is the field
    # descriptor itself; otherwise it is None.
    if field_descriptor.containing_oneof is not None:
      oneof_descriptor = field_descriptor
    else:
      oneof_descriptor = None

    if is_map_entry:
      field_decoder = decoder.MapDecoder(
          field_descriptor, _GetInitializeDefaultForMap(field_descriptor),
          _IsMessageMapField(field_descriptor))
    elif decode_type == _FieldDescriptor.TYPE_STRING:
      field_decoder = decoder.StringDecoder(
          field_descriptor.number, is_repeated, is_packed,
          field_descriptor, field_descriptor._default_constructor,
          _IsStrictUtf8Check(field_descriptor))
    else:
      make_decoder = type_checkers.TYPE_TO_DECODER[decode_type]
      field_decoder = make_decoder(
          field_descriptor.number, is_repeated, is_packed,
          field_descriptor, field_descriptor._default_constructor)

    cls._decoders_by_tag[tag_bytes] = (field_decoder, oneof_descriptor)

  natural_wire_type = type_checkers.FIELD_TYPE_TO_WIRE_TYPE[
      field_descriptor.type]
  AddDecoder(natural_wire_type, False)

  if is_packable:
    # To support wire compatibility of adding packed = true, add a decoder for
    # packed values regardless of the field's options.
    AddDecoder(wire_format.WIRETYPE_LENGTH_DELIMITED, True)
365
366
367def _AddClassAttributesForNestedExtensions(descriptor, dictionary):
368  extensions = descriptor.extensions_by_name
369  for extension_name, extension_field in extensions.items():
370    assert extension_name not in dictionary
371    dictionary[extension_name] = extension_field
372
373
374def _AddEnumValues(descriptor, cls):
375  """Sets class-level attributes for all enum fields defined in this message.
376
377  Also exporting a class-level object that can name enum values.
378
379  Args:
380    descriptor: Descriptor object for this message type.
381    cls: Class we're constructing for this message type.
382  """
383  for enum_type in descriptor.enum_types:
384    setattr(cls, enum_type.name, enum_type_wrapper.EnumTypeWrapper(enum_type))
385    for enum_value in enum_type.values:
386      setattr(cls, enum_value.name, enum_value.number)
387
388
def _GetInitializeDefaultForMap(field):
  """Builds the factory that creates an empty map container for `field`.

  Args:
    field: FieldDescriptor of a map field.  Its message type must have
      'key' and 'value' fields.

  Returns:
    A one-argument function taking the parent message and returning a
    containers.MessageMap (message values) or containers.ScalarMap (any
    other value type) wired to the parent's `_listener_for_children`.

  Raises:
    ValueError: The field is not repeated.
  """
  if field.label != _FieldDescriptor.LABEL_REPEATED:
    raise ValueError('map_entry set on non-repeated field %s' % (
        field.name))

  entry_fields = field.message_type.fields_by_name
  key_checker = type_checkers.GetTypeChecker(entry_fields['key'])
  value_field = entry_fields['value']

  if not _IsMessageMapField(field):
    value_checker = type_checkers.GetTypeChecker(value_field)

    def MakePrimitiveMapDefault(message):
      return containers.ScalarMap(
          message._listener_for_children, key_checker, value_checker,
          field.message_type)
    return MakePrimitiveMapDefault

  def MakeMessageMapDefault(message):
    return containers.MessageMap(
        message._listener_for_children, value_field.message_type, key_checker,
        field.message_type)
  return MakeMessageMapDefault
410
def _DefaultValueConstructorForField(field):
  """Selects the factory that produces the default value of a field.

  Args:
    field: FieldDescriptor object for this field.

  Returns:
    A function of one argument, `message` (the Message instance containing
    the field, or a weakref proxy of same), which returns the default
    value: a map container, a repeated container, a sub-message instance,
    or the scalar default.  Containers and sub-messages are connected to
    the parent through its child listener, or through a _OneofListener
    for a sub-message field that belongs to a oneof.

  Raises:
    ValueError: A repeated field declares a default other than [].
  """
  if _IsMapField(field):
    return _GetInitializeDefaultForMap(field)

  is_message = field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE

  if field.label == _FieldDescriptor.LABEL_REPEATED:
    if field.has_default_value and field.default_value != []:
      raise ValueError('Repeated field default value not empty list: %s' % (
          field.default_value))

    if is_message:
      # We can't look at _concrete_class yet since it might not have
      # been set.  (Depends on order in which we initialize the classes).
      def MakeRepeatedMessageDefault(message):
        return containers.RepeatedCompositeFieldContainer(
            message._listener_for_children, field.message_type)
      return MakeRepeatedMessageDefault

    type_checker = type_checkers.GetTypeChecker(field)

    def MakeRepeatedScalarDefault(message):
      return containers.RepeatedScalarFieldContainer(
          message._listener_for_children, type_checker)
    return MakeRepeatedScalarDefault

  if is_message:
    # _concrete_class may not yet be initialized, so it is looked up only
    # when the factory is actually called.
    message_type = field.message_type

    def MakeSubMessageDefault(message):
      assert getattr(message_type, '_concrete_class', None), (
          'Uninitialized concrete class found for field %r (message type %r)'
          % (field.full_name, message_type.full_name))
      result = message_type._concrete_class()
      if field.containing_oneof is not None:
        listener = _OneofListener(message, field)
      else:
        listener = message._listener_for_children
      result._SetListener(listener)
      return result
    return MakeSubMessageDefault

  def MakeScalarDefault(message):
    # TODO(protobuf-team): This may be broken since there may not be
    # default_value.  Combine with has_default_value somehow.
    return field.default_value
  return MakeScalarDefault
467
468
def _ReraiseTypeErrorWithFieldName(message_name, field_name):
  """Re-raises the exception being handled, naming the offending field.

  Must be called while an exception is being handled.  Only an exception
  whose type is exactly TypeError and that carries a single argument is
  replaced by a new TypeError with ' for field <message>.<field>' appended;
  any other exception is re-raised unchanged.  The original traceback is
  kept in both cases.

  Args:
    message_name: Name of the message type, used in the amended text.
    field_name: Name of the field, used in the amended text.
  """
  _, error, tb = sys.exc_info()
  if len(error.args) == 1 and type(error) is TypeError:
    # simple TypeError; add field name to exception message
    amended_text = '%s for field %s.%s' % (
        str(error), message_name, field_name)
    error = TypeError(amended_text)

  # re-raise possibly-amended exception with original traceback:
  six.reraise(type(error), error, tb)
478
479
def _AddInitMethod(message_descriptor, cls):
  """Installs the keyword-argument constructor as cls.__init__.

  The generated __init__ resets the bookkeeping attributes of the instance
  and then applies each keyword argument as an initial field value.

  Args:
    message_descriptor: Descriptor of the message type; used to resolve
      keyword names to fields and to build error messages.
    cls: The message class that receives the __init__ method.
  """

  def _GetIntegerEnumValue(enum_type, value):
    """Translates an enum label into its number.

    A string is looked up by name in enum_type; an unknown label raises
    ValueError.  Any non-string value is returned as-is, without
    conversion or bounds-checking.
    """
    if not isinstance(value, six.string_types):
      return value
    try:
      return enum_type.values_by_name[value].number
    except KeyError:
      raise ValueError('Enum type %s: unknown label "%s"' % (
          enum_type.full_name, value))

  def init(self, **kwargs):
    self._cached_byte_size = 0
    self._cached_byte_size_dirty = len(kwargs) > 0
    self._fields = {}
    # Contains a mapping from oneof field descriptors to the descriptor
    # of the currently set field in that oneof field.
    self._oneofs = {}

    # _unknown_fields is () when empty for efficiency, and will be turned into
    # a list if fields are added.
    self._unknown_fields = ()
    # _unknown_field_set is None when empty for efficiency, and will be
    # turned into UnknownFieldSet struct if fields are added.
    self._unknown_field_set = None      # pylint: disable=protected-access
    self._is_present_in_parent = False
    self._listener = message_listener_mod.NullMessageListener()
    self._listener_for_children = _Listener(self)

    for field_name, field_value in kwargs.items():
      field = _GetFieldByName(message_descriptor, field_name)
      # NOTE: _GetFieldByName as defined in this file raises ValueError for
      # an unknown name instead of returning None, so this guard only fires
      # if that helper ever starts returning None.
      if field is None:
        raise TypeError('%s() got an unexpected keyword argument "%s"' %
                        (message_descriptor.name, field_name))
      if field_value is None:
        # field=None is the same as no field at all.
        continue

      is_message = field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE

      if field.label == _FieldDescriptor.LABEL_REPEATED:
        container = field._default_constructor(self)
        if not is_message:
          # Repeated scalars; enum labels are converted to numbers first.
          if field.cpp_type == _FieldDescriptor.CPPTYPE_ENUM:
            field_value = [_GetIntegerEnumValue(field.enum_type, item)
                           for item in field_value]
          container.extend(field_value)
        elif not _IsMapField(field):
          # Repeated messages; each item is a dict of constructor keyword
          # arguments or a message to merge from.
          for item in field_value:
            if isinstance(item, dict):
              container.add(**item)
            else:
              container.add().MergeFrom(item)
        elif _IsMessageMapField(field):
          for key in field_value:
            container[key].MergeFrom(field_value[key])
        else:
          container.update(field_value)
        self._fields[field] = container
      elif is_message:
        sub_message = field._default_constructor(self)
        if isinstance(field_value, dict):
          source = field.message_type._concrete_class(**field_value)
        else:
          source = field_value
        try:
          sub_message.MergeFrom(source)
        except TypeError:
          _ReraiseTypeErrorWithFieldName(message_descriptor.name, field_name)
        self._fields[field] = sub_message
      else:
        if field.cpp_type == _FieldDescriptor.CPPTYPE_ENUM:
          field_value = _GetIntegerEnumValue(field.enum_type, field_value)
        # Singular scalars go through the field's property setter.
        try:
          setattr(self, field_name, field_value)
        except TypeError:
          _ReraiseTypeErrorWithFieldName(message_descriptor.name, field_name)

  init.__module__ = None
  init.__doc__ = None
  cls.__init__ = init
565
566
567def _GetFieldByName(message_descriptor, field_name):
568  """Returns a field descriptor by field name.
569
570  Args:
571    message_descriptor: A Descriptor describing all fields in message.
572    field_name: The name of the field to retrieve.
573  Returns:
574    The field descriptor associated with the field name.
575  """
576  try:
577    return message_descriptor.fields_by_name[field_name]
578  except KeyError:
579    raise ValueError('Protocol message %s has no "%s" field.' %
580                     (message_descriptor.name, field_name))
581
582
def _AddPropertiesForFields(descriptor, cls):
  """Installs one property per field, plus `Extensions` when extendable.

  Args:
    descriptor: Descriptor of the message type.
    cls: The class we're constructing.
  """
  for field_descriptor in descriptor.fields:
    _AddPropertiesForField(field_descriptor, cls)

  if not descriptor.is_extendable:
    return

  # _ExtensionDict is just an adaptor with no state so we allocate a new one
  # every time it is accessed.
  cls.Extensions = property(lambda self: _ExtensionDict(self))
592
593
def _AddPropertiesForField(field, cls):
  """Installs the public property and the field-number constant for a field.

  The class gains `<FIELD_NAME>_FIELD_NUMBER`, holding the field number,
  and a property named after the field.  Which kind of property is created
  depends on whether the field is repeated, a singular message, or a
  singular scalar.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  # Catch it if we add other types that we should
  # handle specially here.
  assert _FieldDescriptor.MAX_CPPTYPE == 10

  setattr(cls, field.name.upper() + '_FIELD_NUMBER', field.number)

  if field.label == _FieldDescriptor.LABEL_REPEATED:
    add_property = _AddPropertiesForRepeatedField
  elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    add_property = _AddPropertiesForNonRepeatedCompositeField
  else:
    add_property = _AddPropertiesForNonRepeatedScalarField
  add_property(field, cls)
617
618
619class _FieldProperty(property):
620  __slots__ = ('DESCRIPTOR',)
621
622  def __init__(self, descriptor, getter, setter, doc):
623    property.__init__(self, getter, setter, doc=doc)
624    self.DESCRIPTOR = descriptor
625
626
def _AddPropertiesForRepeatedField(field, cls):
  """Installs the read-only property for a "repeated" field.

  Reading the property returns the container stored for the field in
  `self._fields`, creating it through the field's default constructor the
  first time it is accessed.  Assigning to the property always raises
  AttributeError.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  proto_field_name = field.name

  def getter(self):
    container = self._fields.get(field)
    if container is not None:
      return container

    # Nothing stored yet: construct a new object to represent this field.
    fresh_container = field._default_constructor(self)

    # Atomically check if another thread has preempted us and, if not, swap
    # in the new object we just created.  If someone has preempted us, we
    # take that object and discard ours.
    # WARNING:  We are relying on setdefault() being atomic.  This is true
    #   in CPython but we haven't investigated others.  This warning appears
    #   in several other locations in this file.
    return self._fields.setdefault(field, fresh_container)
  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % proto_field_name

  # The setter exists only so that assignment fails with a helpful error
  # message.
  def setter(self, new_value):
    raise AttributeError('Assignment not allowed to repeated field '
                         '"%s" in protocol message object.' % proto_field_name)

  doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
  field_property = _FieldProperty(field, getter, setter, doc=doc)
  setattr(cls, _PropertyName(proto_field_name), field_property)
669
670
def _AddPropertiesForNonRepeatedScalarField(field, cls):
  """Installs the read/write property for a singular scalar field.

  Reading returns the stored value, or the field's default when nothing is
  stored.  Writing type-checks the value, stores it (or removes the stored
  entry, see below), marks the message as modified and, for a oneof member,
  updates the oneof state.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  proto_field_name = field.name
  property_name = _PropertyName(proto_field_name)
  type_checker = type_checkers.GetTypeChecker(field)
  default_value = field.default_value

  # Under proto3, a field outside a oneof that is set to a falsy value is
  # removed from `_fields` instead of being stored.
  clear_when_set_to_default = (
      field.containing_type.syntax == 'proto3' and not field.containing_oneof)

  def getter(self):
    # TODO(protobuf-team): This may be broken since there may not be
    # default_value.  Combine with has_default_value somehow.
    return self._fields.get(field, default_value)
  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % proto_field_name

  def field_setter(self, new_value):
    # pylint: disable=protected-access
    try:
      checked_value = type_checker.CheckValue(new_value)
    except TypeError as e:
      raise TypeError(
          'Cannot set %s to %.1024r: %s' % (field.full_name, new_value, e))

    # Testing the value for truthiness captures all of the proto3 defaults
    # (0, 0.0, enum 0, and False).
    if clear_when_set_to_default and not checked_value:
      self._fields.pop(field, None)
    else:
      self._fields[field] = checked_value

    # Check _cached_byte_size_dirty inline to improve performance, since scalar
    # setters are called frequently.
    if not self._cached_byte_size_dirty:
      self._Modified()

  if field.containing_oneof:
    def setter(self, new_value):
      field_setter(self, new_value)
      self._UpdateOneofState(field)
  else:
    setter = field_setter

  setter.__module__ = None
  setter.__doc__ = 'Setter for %s.' % proto_field_name

  # Add a property to encapsulate the getter/setter.
  doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
  field_property = _FieldProperty(field, getter, setter, doc=doc)
  setattr(cls, property_name, field_property)
728
729
def _AddPropertiesForNonRepeatedCompositeField(field, cls):
  """Installs the read-only property for a singular composite field.

  A composite field is a "group" or "message" field.  Reading the property
  returns the sub-message stored in `self._fields`, creating it through the
  field's default constructor on first access.  Assigning to the property
  always raises AttributeError.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  # TODO(user): Remove duplication with similar method
  # for non-repeated scalars.
  proto_field_name = field.name

  def getter(self):
    sub_message = self._fields.get(field)
    if sub_message is None:
      # Construct a new object to represent this field.
      candidate = field._default_constructor(self)

      # Atomically check if another thread has preempted us and, if not, swap
      # in the new object we just created.  If someone has preempted us, we
      # take that object and discard ours.
      # WARNING:  We are relying on setdefault() being atomic.  This is true
      #   in CPython but we haven't investigated others.  This warning appears
      #   in several other locations in this file.
      sub_message = self._fields.setdefault(field, candidate)
    return sub_message
  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % proto_field_name

  # The setter exists only so that assignment fails with a helpful error
  # message.
  def setter(self, new_value):
    raise AttributeError('Assignment not allowed to composite field '
                         '"%s" in protocol message object.' % proto_field_name)

  # Add a property to encapsulate the getter.
  doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
  field_property = _FieldProperty(field, getter, setter, doc=doc)
  setattr(cls, _PropertyName(proto_field_name), field_property)
772
773
def _AddPropertiesForExtensions(descriptor, cls):
  """Exposes extension field numbers and extension registries on cls.

  For every entry in descriptor.extensions_by_name a class constant named
  after the upper-cased extension name plus '_FIELD_NUMBER' is set to the
  extension's field number.  When the descriptor belongs to a file, the
  pool's extension tables for this descriptor are also attached to cls.

  Args:
    descriptor: The message Descriptor whose extensions are exposed.
    cls: The class we're constructing.
  """
  for name, field in descriptor.extensions_by_name.items():
    setattr(cls, name.upper() + '_FIELD_NUMBER', field.number)

  # TODO(user): Migrate all users of these attributes to functions like
  #   pool.FindExtensionByNumber(descriptor).
  if descriptor.file is None:
    return

  # TODO(user): Use cls.MESSAGE_FACTORY.pool when available.
  pool = descriptor.file.pool
  cls._extensions_by_number = pool._extensions_by_number[descriptor]
  cls._extensions_by_name = pool._extensions_by_name[descriptor]
788
def _AddStaticMethods(cls):
  """Attaches the RegisterExtension and FromString static methods to cls."""

  # TODO(user): This probably needs to be thread-safe(?)
  def RegisterExtension(extension_handle):
    # Tie the extension to this message type, record it in the pool that
    # owns the message's file, then attach the per-field helpers.
    extension_handle.containing_type = cls.DESCRIPTOR
    # TODO(user): Use cls.MESSAGE_FACTORY.pool when available.
    # pylint: disable=protected-access
    owning_pool = cls.DESCRIPTOR.file.pool
    owning_pool._AddExtensionDescriptor(extension_handle)
    _AttachFieldHelpers(cls, extension_handle)

  def FromString(s):
    # Build an empty instance and merge the serialized bytes into it.
    parsed = cls()
    parsed.MergeFromString(s)
    return parsed

  cls.RegisterExtension = staticmethod(RegisterExtension)
  cls.FromString = staticmethod(FromString)
804
805
def _IsPresent(item):
  """Tells whether a _fields entry belongs in the ListFields() result.

  Args:
    item: A (FieldDescriptor, value) tuple taken from _fields.

  Returns:
    For a repeated field, whether the value is non-empty; for a singular
    message field, the value's _is_present_in_parent flag; otherwise True.
  """
  field, value = item[0], item[1]
  if field.label == _FieldDescriptor.LABEL_REPEATED:
    return bool(value)
  if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    return value._is_present_in_parent
  return True
816
817
def _AddListFieldsMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs ListFields() on cls."""

  def ListFields(self):
    # Keep only the entries _IsPresent() accepts, ordered by field number.
    present = (pair for pair in self._fields.items() if _IsPresent(pair))
    return sorted(present, key=lambda pair: pair[0].number)

  cls.ListFields = ListFields
827
# Templates for the ValueError message raised by HasField(); each takes the
# message's full name and the requested field name.
_PROTO3_ERROR_TEMPLATE = (
    'Protocol message %s has no non-repeated submessage field "%s" '
    'nor marked as optional')
_PROTO2_ERROR_TEMPLATE = 'Protocol message %s has no non-repeated field "%s"'
832
def _AddHasFieldMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs HasField() on cls.

  The set of names HasField() accepts is computed once here: every
  non-repeated field that tracks presence, plus every oneof.
  """

  is_proto3 = (message_descriptor.syntax == "proto3")
  if is_proto3:
    error_msg = _PROTO3_ERROR_TEMPLATE
  else:
    error_msg = _PROTO2_ERROR_TEMPLATE

  hassable_fields = {}
  for field in message_descriptor.fields:
    if field.label == _FieldDescriptor.LABEL_REPEATED:
      continue
    # For proto3, only submessages and fields inside a oneof have presence.
    if (not is_proto3 or
        field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE or
        field.containing_oneof):
      hassable_fields[field.name] = field

  # Has methods are supported for oneof descriptors.
  for oneof in message_descriptor.oneofs:
    hassable_fields[oneof.name] = oneof

  def HasField(self, field_name):
    try:
      target = hassable_fields[field_name]
    except KeyError:
      raise ValueError(error_msg % (message_descriptor.full_name, field_name))

    if isinstance(target, descriptor_mod.OneofDescriptor):
      # A oneof is "set" exactly when its currently active member is.
      try:
        return HasField(self, self._oneofs[target].name)
      except KeyError:
        return False

    if target.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE:
      return target in self._fields

    # A stored sub-message only counts once it is marked present.
    stored = self._fields.get(target)
    return stored is not None and stored._is_present_in_parent

  cls.HasField = HasField
872
873
def _AddClearFieldMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs ClearField() on cls."""
  def ClearField(self, field_name):
    try:
      field = message_descriptor.fields_by_name[field_name]
    except KeyError:
      # Not a field name; it may name a oneof instead.
      try:
        oneof = message_descriptor.oneofs_by_name[field_name]
      except KeyError:
        raise ValueError('Protocol message %s has no "%s" field.' %
                         (message_descriptor.name, field_name))
      if oneof not in self._oneofs:
        # No member of the oneof is set, so there is nothing to clear.
        return
      field = self._oneofs[oneof]

    if field in self._fields:
      stored = self._fields[field]
      # To match the C++ implementation, we need to invalidate iterators
      # for map fields when ClearField() happens.
      if hasattr(stored, 'InvalidateIterators'):
        stored.InvalidateIterators()

      # Note:  If the field is a sub-message, its listener will still point
      #   at us.  That's fine, because the worst that can happen is that it
      #   will call _Modified() and invalidate our byte size.  Big deal.
      del self._fields[field]

      if self._oneofs.get(field.containing_oneof, None) is field:
        del self._oneofs[field.containing_oneof]

    # Always call _Modified() -- even if nothing was changed, this is
    # a mutating method, and thus calling it should cause the field to become
    # present in the parent message.
    self._Modified()

  cls.ClearField = ClearField
910
911
def _AddClearExtensionMethod(cls):
  """Helper for _AddMessageMethods(): installs ClearExtension() on cls."""
  def ClearExtension(self, extension_handle):
    extension_dict._VerifyExtensionHandle(self, extension_handle)

    # Similar to ClearField(), above: drop the stored value if there is one
    # and mark the message as modified either way.
    self._fields.pop(extension_handle, None)
    self._Modified()
  cls.ClearExtension = ClearExtension
922
923
def _AddHasExtensionMethod(cls):
  """Helper for _AddMessageMethods(): installs HasExtension() on cls."""
  def HasExtension(self, extension_handle):
    extension_dict._VerifyExtensionHandle(self, extension_handle)
    if extension_handle.label == _FieldDescriptor.LABEL_REPEATED:
      raise KeyError('"%s" is repeated.' % extension_handle.full_name)

    if extension_handle.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE:
      return extension_handle in self._fields

    # A stored sub-message only counts once it is marked present.
    stored = self._fields.get(extension_handle)
    return stored is not None and stored._is_present_in_parent
  cls.HasExtension = HasExtension
937
def _InternalUnpackAny(msg):
  """Unpacks Any message and returns the unpacked message.

  This internal method is different from public Any Unpack method which takes
  the target message as argument. _InternalUnpackAny method does not have
  target message type and need to find the message type in descriptor pool.

  Args:
    msg: An Any message to be unpacked.

  Returns:
    The unpacked message, or None when msg has an empty type_url or its
    type cannot be found in the default symbol database's pool.
  """
  # TODO(user): Don't use the factory of generated messages.
  # To make Any work with custom factories, use the message factory of the
  # parent message.
  # pylint: disable=g-import-not-at-top
  from google.protobuf import symbol_database
  factory = symbol_database.Default()

  type_url = msg.type_url

  if not type_url:
    return None

  # TODO(user): For now we just strip the hostname.  Better logic will be
  # required.
  type_name = type_url.split('/')[-1]
  try:
    descriptor = factory.pool.FindMessageTypeByName(type_name)
  except KeyError:
    # The pool reports an unknown type by raising KeyError rather than by
    # returning None.  Treat it as "cannot unpack" so that callers fall
    # back to their non-Any handling instead of failing.
    return None

  if descriptor is None:
    return None

  message_class = factory.GetPrototype(descriptor)
  message = message_class()

  message.ParseFromString(msg.value)
  return message
976
977
def _AddEqualsMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs __eq__ on cls."""
  def __eq__(self, other):
    # Only messages sharing our descriptor can compare equal.
    if not isinstance(other, message_mod.Message):
      return False
    if other.DESCRIPTOR != self.DESCRIPTOR:
      return False

    if self is other:
      return True

    if self.DESCRIPTOR.full_name == _AnyFullTypeName:
      # Compare the packed payloads when both sides can be unpacked.
      unpacked_self = _InternalUnpackAny(self)
      unpacked_other = _InternalUnpackAny(other)
      if unpacked_self and unpacked_other:
        return unpacked_self == unpacked_other

    if not self.ListFields() == other.ListFields():
      return False

    # TODO(user): Fix UnknownFieldSet to consider MessageSet extensions,
    # then use it for the comparison.
    # Unknown fields are compared without regard to their order.
    return sorted(self._unknown_fields) == sorted(other._unknown_fields)

  cls.__eq__ = __eq__
1006
1007
def _AddStrMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs __str__ on cls."""
  def __str__(self):
    # str(message) is whatever text_format.MessageToString() produces.
    rendered = text_format.MessageToString(self)
    return rendered
  cls.__str__ = __str__
1013
1014
def _AddReprMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs __repr__ on cls."""
  def __repr__(self):
    # repr(message) is whatever text_format.MessageToString() produces.
    rendered = text_format.MessageToString(self)
    return rendered
  cls.__repr__ = __repr__
1020
1021
def _AddUnicodeMethod(unused_message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs __unicode__ on cls."""

  def __unicode__(self):
    # Render with as_utf8=True, then decode the result as UTF-8.
    encoded = text_format.MessageToString(self, as_utf8=True)
    return encoded.decode('utf-8')
  cls.__unicode__ = __unicode__
1028
1029
def _BytesForNonRepeatedElement(value, field_number, field_type):
  """Returns the number of bytes needed to serialize a non-repeated element.
  The returned byte count includes space for tag information and any
  other additional space associated with serializing value.

  Args:
    value: Value we're serializing.
    field_number: Field number of this value.  (Since the field number
      is stored as part of a varint-encoded tag, this has an impact
      on the total bytes required to serialize the value).
    field_type: The type of the field.  One of the TYPE_* constants
      within FieldDescriptor.

  Returns:
    The result of the byte-size function registered for field_type.

  Raises:
    EncodeError: If field_type has no entry in TYPE_TO_BYTE_SIZE_FN.
  """
  # Only the table lookup is guarded: a KeyError raised from inside the
  # size function itself must not be misreported as an unknown field type.
  try:
    fn = type_checkers.TYPE_TO_BYTE_SIZE_FN[field_type]
  except KeyError:
    raise message_mod.EncodeError('Unrecognized field type: %d' % field_type)
  return fn(field_number, value)
1048
1049
def _AddByteSizeMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs the caching ByteSize()."""

  def ByteSize(self):
    # A clean cache means nothing changed since the last computation.
    if not self._cached_byte_size_dirty:
      return self._cached_byte_size

    descriptor = self.DESCRIPTOR
    if descriptor.GetOptions().map_entry:
      # Fields of map entry should always be serialized.
      by_name = descriptor.fields_by_name
      total = by_name['key']._sizer(self.key)
      total += by_name['value']._sizer(self.value)
    else:
      total = sum(field_descriptor._sizer(field_value)
                  for field_descriptor, field_value in self.ListFields())
      total += sum(len(tag_bytes) + len(value_bytes)
                   for tag_bytes, value_bytes in self._unknown_fields)

    # Remember the result and mark this message and its child listener clean.
    self._cached_byte_size = total
    self._cached_byte_size_dirty = False
    self._listener_for_children.dirty = False
    return total

  cls.ByteSize = ByteSize
1075
1076
def _AddSerializeToStringMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs SerializeToString() on cls."""

  def SerializeToString(self, **kwargs):
    # Serialization is only allowed once every required field is set.
    if self.IsInitialized():
      return self.SerializePartialToString(**kwargs)

    full_name = self.DESCRIPTOR.full_name
    missing = ','.join(self.FindInitializationErrors())
    raise message_mod.EncodeError(
        'Message %s is missing required fields: %s' % (full_name, missing))
  cls.SerializeToString = SerializeToString
1088
1089
def _AddSerializePartialToStringMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods().

  Installs SerializePartialToString() and the private _InternalSerialize()
  on cls.
  """

  def SerializePartialToString(self, **kwargs):
    # Collect everything _InternalSerialize() writes into a memory buffer.
    sink = BytesIO()
    self._InternalSerialize(sink.write, **kwargs)
    return sink.getvalue()
  cls.SerializePartialToString = SerializePartialToString

  def InternalSerialize(self, write_bytes, deterministic=None):
    # An explicit setting wins; otherwise use the process-wide default.
    if deterministic is not None:
      deterministic = bool(deterministic)
    else:
      deterministic = (
          api_implementation.IsPythonDefaultSerializationDeterministic())

    descriptor = self.DESCRIPTOR
    if descriptor.GetOptions().map_entry:
      # Fields of map entry should always be serialized.
      for entry_part in ('key', 'value'):
        descriptor.fields_by_name[entry_part]._encoder(
            write_bytes, getattr(self, entry_part), deterministic)
      return

    for field_descriptor, field_value in self.ListFields():
      field_descriptor._encoder(write_bytes, field_value, deterministic)
    # Unknown fields are written back as their raw tag and value bytes.
    for tag_bytes, value_bytes in self._unknown_fields:
      write_bytes(tag_bytes)
      write_bytes(value_bytes)
  cls._InternalSerialize = InternalSerialize
1120
1121
def _AddMergeFromStringMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods().

  Installs MergeFromString() and the private _InternalParse() on cls.
  """
  def MergeFromString(self, serialized):
    """Merges the serialized data into this message.

    Args:
      serialized: The serialized data; it is wrapped in a memoryview before
        parsing.

    Returns:
      The length of the serialized data.

    Raises:
      TypeError: If serialized is a memoryview under Python 2.
      DecodeError: If the data is truncated, malformed, or contains an
        unexpected end-group tag.
    """
    if isinstance(serialized, memoryview) and six.PY2:
      raise TypeError(
          'memoryview not supported in Python 2 with the pure Python proto '
          'implementation: this is to maintain compatibility with the C++ '
          'implementation')

    serialized = memoryview(serialized)
    length = len(serialized)
    try:
      if self._InternalParse(serialized, 0, length) != length:
        # The only reason _InternalParse would return early is if it
        # encountered an end-group tag.
        raise message_mod.DecodeError('Unexpected end-group tag.')
    except (IndexError, TypeError):
      # Now ord(buf[p:p+1]) == ord('') gets TypeError.
      raise message_mod.DecodeError('Truncated message.')
    except struct.error as e:
      raise message_mod.DecodeError(e)
    return length   # Return this for legacy reasons.
  cls.MergeFromString = MergeFromString

  # Bind the helpers used by the parse loop to closure variables once, at
  # class-construction time.
  local_ReadTag = decoder.ReadTag
  local_SkipField = decoder.SkipField
  decoders_by_tag = cls._decoders_by_tag

  def InternalParse(self, buffer, pos, end):
    """Merges the serialized fields in buffer[pos:end] into self.

    Args:
      self: Message, instance of the proto message object.
      buffer: memoryview of the serialized data.
      pos: int, position to start in the serialized data.
      end: int, end position of the serialized data.

    Returns:
      int, the position at which parsing stopped.  This is end unless the
      unknown-field decoder or skipper returned -1, in which case it is the
      position of the tag that was being processed.
    """
    # Guard against internal misuse, since this function is called internally
    # quite extensively, and its easy to accidentally pass bytes.
    assert isinstance(buffer, memoryview)
    self._Modified()
    field_dict = self._fields
    # pylint: disable=protected-access
    unknown_field_set = self._unknown_field_set
    while pos != end:
      (tag_bytes, new_pos) = local_ReadTag(buffer, pos)
      field_decoder, field_desc = decoders_by_tag.get(tag_bytes, (None, None))
      if field_decoder is None:
        # No decoder is registered for this tag: record it as an unknown
        # field, both in the UnknownFieldSet and in the legacy
        # _unknown_fields list of raw (tag bytes, value bytes) pairs.
        if not self._unknown_fields:   # pylint: disable=protected-access
          self._unknown_fields = []    # pylint: disable=protected-access
        if unknown_field_set is None:
          # pylint: disable=protected-access
          self._unknown_field_set = containers.UnknownFieldSet()
          # pylint: disable=protected-access
          unknown_field_set = self._unknown_field_set
        # pylint: disable=protected-access
        (tag, _) = decoder._DecodeVarint(tag_bytes, 0)
        field_number, wire_type = wire_format.UnpackTag(tag)
        if field_number == 0:
          raise message_mod.DecodeError('Field number 0 is illegal.')
        # TODO(user): remove old_pos.
        old_pos = new_pos
        (data, new_pos) = decoder._DecodeUnknownField(
            buffer, new_pos, wire_type)  # pylint: disable=protected-access
        if new_pos == -1:
          # Stop early and report where the offending tag started.
          # NOTE(review): -1 presumably signals an end-group tag, matching
          # the caller's 'Unexpected end-group tag.' handling -- confirm.
          return pos
        # pylint: disable=protected-access
        unknown_field_set._add(field_number, wire_type, data)
        # TODO(user): remove _unknown_fields.
        # The same field is walked a second time, from old_pos, to obtain
        # the raw byte range for the legacy list.
        new_pos = local_SkipField(buffer, old_pos, end, tag_bytes)
        if new_pos == -1:
          return pos
        self._unknown_fields.append(
            (tag_bytes, buffer[old_pos:new_pos].tobytes()))
        pos = new_pos
      else:
        pos = field_decoder(buffer, new_pos, end, self, field_dict)
        if field_desc:
          self._UpdateOneofState(field_desc)
    return pos
  cls._InternalParse = InternalParse
1206
1207
def _AddIsInitializedMethod(message_descriptor, cls):
  """Adds the IsInitialized and FindInitializationErrors methods to the
  protocol message class."""

  required_fields = [
      candidate for candidate in message_descriptor.fields
      if candidate.label == _FieldDescriptor.LABEL_REQUIRED]

  def _Fail(message, errors):
    # Shared failure exit: fill the caller's error list, if one was given.
    if errors is not None:
      errors.extend(message.FindInitializationErrors())
    return False

  def _Prefixed(prefix, child):
    # Paths of the uninitialized fields below child, each made relative to
    # the parent by prepending prefix.
    return [prefix + error for error in child.FindInitializationErrors()]

  def IsInitialized(self, errors=None):
    """Checks if all required fields of a message are set.

    Args:
      errors:  A list which, if provided, will be populated with the field
               paths of all missing required fields.

    Returns:
      True iff the specified message has all required fields set.
    """

    # Performance is critical so we avoid HasField() and ListFields().

    for field in required_fields:
      if field not in self._fields:
        return _Fail(self, errors)
      if (field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE and
          not self._fields[field]._is_present_in_parent):
        return _Fail(self, errors)

    for field, value in list(self._fields.items()):  # dict can change size!
      if field.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE:
        continue
      if field.label == _FieldDescriptor.LABEL_REPEATED:
        # Entries of map fields are not checked here.
        if (field.message_type.has_options and
            field.message_type.GetOptions().map_entry):
          continue
        for element in value:
          if not element.IsInitialized():
            return _Fail(self, errors)
      elif value._is_present_in_parent and not value.IsInitialized():
        return _Fail(self, errors)

    return True

  cls.IsInitialized = IsInitialized

  def FindInitializationErrors(self):
    """Finds required fields which are not initialized.

    Returns:
      A list of strings.  Each string is a path to an uninitialized field from
      the top-level message, e.g. "foo.bar[5].baz".
    """

    # Required fields missing on this message itself come first.
    errors = [required.name for required in required_fields
              if not self.HasField(required.name)]

    # Then descend into every present message-typed field.
    for field, value in self.ListFields():
      if field.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE:
        continue

      if field.is_extension:
        name = '(%s)' % field.full_name
      else:
        name = field.name

      if _IsMapField(field):
        # ScalarMaps can't have any initialization errors.
        if _IsMessageMapField(field):
          for key in value:
            errors += _Prefixed('%s[%s].' % (name, key), value[key])
      elif field.label == _FieldDescriptor.LABEL_REPEATED:
        for index, element in enumerate(value):
          errors += _Prefixed('%s[%d].' % (name, index), element)
      else:
        errors += _Prefixed(name + '.', value)

    return errors

  cls.FindInitializationErrors = FindInitializationErrors
1301
1302
def _AddMergeFromMethod(cls):
  """Installs MergeFrom() on cls."""
  LABEL_REPEATED = _FieldDescriptor.LABEL_REPEATED
  CPPTYPE_MESSAGE = _FieldDescriptor.CPPTYPE_MESSAGE

  def MergeFrom(self, msg):
    if not isinstance(msg, cls):
      raise TypeError(
          'Parameter to MergeFrom() must be instance of same class: '
          'expected %s got %s.' % (cls.__name__, msg.__class__.__name__))

    assert msg is not self
    self._Modified()

    fields = self._fields

    for field, incoming in msg._fields.items():
      if field.label != LABEL_REPEATED:
        if field.cpp_type != CPPTYPE_MESSAGE:
          # Singular scalar: the incoming value simply replaces ours.
          self._fields[field] = incoming
          if field.containing_oneof:
            self._UpdateOneofState(field)
          continue
        if not incoming._is_present_in_parent:
          # A sub-message that is not marked present contributes nothing.
          continue

      # Repeated field or present sub-message: merge into our own value,
      # constructing a new object to represent the field when we have none.
      target = fields.get(field)
      if target is None:
        target = field._default_constructor(self)
        fields[field] = target
      target.MergeFrom(incoming)

    if not msg._unknown_fields:
      return

    # Carry the unknown fields over in both representations.
    if not self._unknown_fields:
      self._unknown_fields = []
    self._unknown_fields.extend(msg._unknown_fields)
    # pylint: disable=protected-access
    if self._unknown_field_set is None:
      self._unknown_field_set = containers.UnknownFieldSet()
    self._unknown_field_set._extend(msg._unknown_field_set)

  cls.MergeFrom = MergeFrom
1349
1350
def _AddWhichOneofMethod(message_descriptor, cls):
  """Installs WhichOneof() on cls."""
  def WhichOneof(self, oneof_name):
    """Returns the name of the currently set field inside a oneof, or None."""
    try:
      oneof = message_descriptor.oneofs_by_name[oneof_name]
    except KeyError:
      raise ValueError(
          'Protocol message has no oneof "%s" field.' % oneof_name)

    # The recorded member only counts while HasField() still reports it set.
    active = self._oneofs.get(oneof, None)
    if active is None or not self.HasField(active.name):
      return None
    return active.name

  cls.WhichOneof = WhichOneof
1367
1368
def _Clear(self):
  """Resets the message's fields, unknown fields and oneof state."""
  # Clear fields.
  self._fields = {}
  self._unknown_fields = ()

  # Empty the unknown field set, if there is one, before dropping it.
  # pylint: disable=protected-access
  stale_set = self._unknown_field_set
  if stale_set is not None:
    stale_set._clear()
    self._unknown_field_set = None

  self._oneofs = {}
  self._Modified()
1380
1381
def _UnknownFields(self):
  """Returns the message's unknown field set, creating it on first use."""
  # pylint: disable=protected-access
  current = self._unknown_field_set
  if current is not None:
    return current
  self._unknown_field_set = containers.UnknownFieldSet()
  return self._unknown_field_set
1387
1388
def _DiscardUnknownFields(self):
  """Drops unknown fields from this message and from its sub-messages."""
  self._unknown_fields = []
  self._unknown_field_set = None      # pylint: disable=protected-access

  # Recurse into every listed message-typed field.
  for field, value in self.ListFields():
    if field.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE:
      continue
    if _IsMapField(field):
      # Only maps whose values are messages have anything to discard.
      if _IsMessageMapField(field):
        for key in value:
          value[key].DiscardUnknownFields()
    elif field.label == _FieldDescriptor.LABEL_REPEATED:
      for sub_message in value:
        sub_message.DiscardUnknownFields()
    else:
      value.DiscardUnknownFields()
1403
1404
def _SetListener(self, listener):
  """Stores listener on the message; None selects a NullMessageListener."""
  if listener is not None:
    self._listener = listener
    return
  self._listener = message_listener_mod.NullMessageListener()
1410
1411
def _AddMessageMethods(message_descriptor, cls):
  """Adds implementations of all Message methods to cls."""
  # The installers run in the same fixed order as before; those taking the
  # descriptor are grouped into tuples around the extension-only ones.
  for install in (_AddListFieldsMethod,
                  _AddHasFieldMethod,
                  _AddClearFieldMethod):
    install(message_descriptor, cls)

  if message_descriptor.is_extendable:
    _AddClearExtensionMethod(cls)
    _AddHasExtensionMethod(cls)

  for install in (_AddEqualsMethod,
                  _AddStrMethod,
                  _AddReprMethod,
                  _AddUnicodeMethod,
                  _AddByteSizeMethod,
                  _AddSerializeToStringMethod,
                  _AddSerializePartialToStringMethod,
                  _AddMergeFromStringMethod,
                  _AddIsInitializedMethod):
    install(message_descriptor, cls)

  _AddMergeFromMethod(cls)
  _AddWhichOneofMethod(message_descriptor, cls)

  # Adds methods which do not depend on cls.
  cls.Clear = _Clear
  cls.UnknownFields = _UnknownFields
  cls.DiscardUnknownFields = _DiscardUnknownFields
  cls._SetListener = _SetListener
1436
1437
def _AddPrivateHelperMethods(message_descriptor, cls):
  """Adds implementation of private helper methods to cls."""

  def Modified(self):
    """Sets the _cached_byte_size_dirty bit to true,
    and propagates this to our listener iff this was a state change.
    """

    # Note:  Some callers check _cached_byte_size_dirty before calling
    #   _Modified() as an extra optimization.  So, if this method is ever
    #   changed such that it does stuff even when _cached_byte_size_dirty is
    #   already true, the callers need to be updated.
    if self._cached_byte_size_dirty:
      return
    self._cached_byte_size_dirty = True
    self._listener_for_children.dirty = True
    self._is_present_in_parent = True
    self._listener.Modified()

  def _UpdateOneofState(self, field):
    """Sets field as the active field in its containing oneof.

    Will also delete currently active field in the oneof, if it is different
    from the argument. Does not mark the message as modified.
    """
    oneof = field.containing_oneof
    # setdefault() records field when the oneof had no active member and
    # otherwise hands back the member that was active until now.
    previous = self._oneofs.setdefault(oneof, field)
    if previous is field:
      return
    del self._fields[previous]
    self._oneofs[oneof] = field

  cls._Modified = Modified
  cls.SetInParent = Modified
  cls._UpdateOneofState = _UpdateOneofState
1470
1471
class _Listener(object):

  """MessageListener implementation that a parent message registers with its
  child message.

  Child objects need a back reference to their parent so that code such as

    foo.bar.baz.qux = 23
    assert foo.HasField('bar')

  works: modifying a child notifies its parent.  This helper class carries
  that back reference.
  """

  def __init__(self, parent_message):
    """Args:
      parent_message: The message whose _Modified() method we should call when
        we receive Modified() messages.
    """
    # The back reference from child to parent is kept weak, so that it does
    # not create cyclic garbage once the client is done with the parent.  A
    # parent that already arrives as a weak proxy is stored unchanged.
    if not isinstance(parent_message, weakref.ProxyType):
      parent_message = weakref.proxy(parent_message)
    self._parent_message_weakref = parent_message

    # As an optimization, we also indicate directly on the listener whether
    # or not the parent message is dirty.  This way we can avoid traversing
    # up the tree in the common case.
    self.dirty = False

  def Modified(self):
    # Propagate the signal to our parents iff this is the first field set.
    if not self.dirty:
      try:
        self._parent_message_weakref._Modified()
      except ReferenceError:
        # We can get here if a client has kept a reference to a child object,
        # and is now setting a field on it, but the child's parent has been
        # garbage-collected.  This is not an error.
        pass
1516
1517
class _OneofListener(_Listener):
  """Listener used for composite fields that are members of a oneof."""

  def __init__(self, parent_message, field):
    """Args:
      parent_message: The message whose _Modified() method we should call when
        we receive Modified() messages.
      field: The descriptor of the field being set in the parent message.
    """
    super(_OneofListener, self).__init__(parent_message)
    self._field = field

  def Modified(self):
    """Also updates the state of the containing oneof in the parent message."""
    parent = self._parent_message_weakref
    try:
      # Make our field the oneof's active member, then notify as usual.
      parent._UpdateOneofState(self._field)
      super(_OneofListener, self).Modified()
    except ReferenceError:
      # The parent no longer exists, so there is nothing to update.
      pass
1537