# Protocol Buffers - Google's data interchange format
# Copyright 2008 Google Inc. All rights reserved.
# https://developers.google.com/protocol-buffers/
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
#     * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
#     * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following disclaimer
# in the documentation and/or other materials provided with the
# distribution.
#     * Neither the name of Google Inc. nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

# This code is meant to work on Python 2.4 and above only.
#
# TODO(user): Helpers for verbose, common checks like seeing if a
# descriptor's cpp_type is CPPTYPE_MESSAGE.

"""Contains a metaclass and helper functions used to create
protocol message classes from Descriptor objects at runtime.

Recall that a metaclass is the "type" of a class.
(A class is to a metaclass what an instance is to a class.)

In this case, we use the GeneratedProtocolMessageType metaclass
to inject all the useful functionality into the classes
output by the protocol compiler at compile-time.

The upshot of all this is that the real implementation
details for ALL pure-Python protocol buffers are *here in
this file*.
"""

__author__ = 'robinson@google.com (Will Robinson)'

from io import BytesIO
import struct
import sys
import weakref

import six
from six.moves import range

# We use "as" to avoid name collisions with variables.
from google.protobuf.internal import api_implementation
from google.protobuf.internal import containers
from google.protobuf.internal import decoder
from google.protobuf.internal import encoder
from google.protobuf.internal import enum_type_wrapper
from google.protobuf.internal import extension_dict
from google.protobuf.internal import message_listener as message_listener_mod
from google.protobuf.internal import type_checkers
from google.protobuf.internal import well_known_types
from google.protobuf.internal import wire_format
from google.protobuf import descriptor as descriptor_mod
from google.protobuf import message as message_mod
from google.protobuf import text_format

_FieldDescriptor = descriptor_mod.FieldDescriptor
_AnyFullTypeName = 'google.protobuf.Any'
_ExtensionDict = extension_dict._ExtensionDict


class GeneratedProtocolMessageType(type):

  """Metaclass for protocol message classes created at runtime from Descriptors.

  We add implementations for all methods described in the Message class. We
  also create properties to allow getting/setting all fields in the protocol
  message. Finally, we create slots to prevent users from accidentally
  "setting" nonexistent fields in the protocol message, which then wouldn't get
  serialized / deserialized properly.

  The protocol compiler currently uses this metaclass to create protocol
  message classes at runtime. Clients can also manually create their own
  classes at runtime, as in this example:

  mydescriptor = Descriptor(.....)
  factory = symbol_database.Default()
  factory.pool.AddDescriptor(mydescriptor)
  MyProtoClass = factory.GetPrototype(mydescriptor)
  myproto_instance = MyProtoClass()
  myproto_instance.foo_field = 23
  ...
  """

  # Must be consistent with the protocol-compiler code in
  # proto2/compiler/internal/generator.*.
  _DESCRIPTOR_KEY = 'DESCRIPTOR'

  def __new__(cls, name, bases, dictionary):
    """Custom allocation for runtime-generated class types.

    We override __new__ because this is apparently the only place
    where we can meaningfully set __slots__ on the class we're creating(?).
    (The interplay between metaclasses and slots is not very well-documented).

    Args:
      name: Name of the class (ignored, but required by the
        metaclass protocol).
      bases: Base classes of the class we're constructing.
        (Should be message.Message). We ignore this field, but
        it's required by the metaclass protocol
      dictionary: The class dictionary of the class we're
        constructing. dictionary[_DESCRIPTOR_KEY] must contain
        a Descriptor object describing this protocol message
        type.

    Returns:
      Newly-allocated class, or the class already recorded on the descriptor
      as `_concrete_class` when one exists.

    Raises:
      RuntimeError: Generated code only works with python cpp extension.
    """
    descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY]

    if isinstance(descriptor, str):
      raise RuntimeError('The generated code only work with python cpp '
                         'extension, but it is using pure python runtime.')

    # If a concrete class already exists for this descriptor, don't try to
    # create another. Doing so will break any messages that already exist with
    # the existing class.
    #
    # The C++ implementation appears to have its own internal `PyMessageFactory`
    # to achieve similar results.
    #
    # This most commonly happens in `text_format.py` when using descriptors from
    # a custom pool; it calls symbol_database.Global().getPrototype() on a
    # descriptor which already has an existing concrete class.
    new_class = getattr(descriptor, '_concrete_class', None)
    if new_class:
      return new_class

    if descriptor.full_name in well_known_types.WKTBASES:
      bases += (well_known_types.WKTBASES[descriptor.full_name],)
    _AddClassAttributesForNestedExtensions(descriptor, dictionary)
    _AddSlots(descriptor, dictionary)

    superclass = super(GeneratedProtocolMessageType, cls)
    new_class = superclass.__new__(cls, name, bases, dictionary)
    return new_class

  def __init__(cls, name, bases, dictionary):
    """Here we perform the majority of our work on the class.
    We add enum getters, an __init__ method, implementations
    of all Message methods, and properties for all fields
    in the protocol type.

    Args:
      name: Name of the class (ignored, but required by the
        metaclass protocol).
      bases: Base classes of the class we're constructing.
        (Should be message.Message). We ignore this field, but
        it's required by the metaclass protocol
      dictionary: The class dictionary of the class we're
        constructing. dictionary[_DESCRIPTOR_KEY] must contain
        a Descriptor object describing this protocol message
        type.
    """
    descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY]

    # If this is an _existing_ class looked up via `_concrete_class` in the
    # __new__ method above, then we don't need to re-initialize anything.
    existing_class = getattr(descriptor, '_concrete_class', None)
    if existing_class:
      assert existing_class is cls, (
          'Duplicate `GeneratedProtocolMessageType` created for descriptor %r'
          % (descriptor.full_name))
      return

    cls._decoders_by_tag = {}
    if (descriptor.has_options and
        descriptor.GetOptions().message_set_wire_format):
      cls._decoders_by_tag[decoder.MESSAGE_SET_ITEM_TAG] = (
          decoder.MessageSetItemDecoder(descriptor), None)

    # Attach stuff to each FieldDescriptor for quick lookup later on.
    for field in descriptor.fields:
      _AttachFieldHelpers(cls, field)

    descriptor._concrete_class = cls  # pylint: disable=protected-access
    _AddEnumValues(descriptor, cls)
    _AddInitMethod(descriptor, cls)
    _AddPropertiesForFields(descriptor, cls)
    _AddPropertiesForExtensions(descriptor, cls)
    _AddStaticMethods(cls)
    _AddMessageMethods(descriptor, cls)
    _AddPrivateHelperMethods(descriptor, cls)

    superclass = super(GeneratedProtocolMessageType, cls)
    superclass.__init__(name, bases, dictionary)


# Stateless helpers for GeneratedProtocolMessageType below.
# Outside clients should not access these directly.
#
# I opted not to make any of these methods on the metaclass, to make it more
# clear that I'm not really using any state there and to keep clients from
# thinking that they have direct access to these construction helpers.


def _PropertyName(proto_field_name):
  """Returns the name of the public property attribute which
  clients can use to get and (in some cases) set the value
  of a protocol message field.

  Args:
    proto_field_name: The protocol message field name, exactly
      as it appears (or would appear) in a .proto file.
  """
  # TODO(user): Escape Python keywords (e.g., yield), and test this support.
  # nnorwitz makes my day by writing:
  # """
  # FYI. See the keyword module in the stdlib. This could be as simple as:
  #
  # if keyword.iskeyword(proto_field_name):
  #   return proto_field_name + "_"
  # return proto_field_name
  # """
  # Kenton says: The above is a BAD IDEA. People rely on being able to use
  # getattr() and setattr() to reflectively manipulate field values. If we
  # rename the properties, then every such user has to also make sure to apply
  # the same transformation. Note that currently if you name a field "yield",
  # you can still access it just fine using getattr/setattr -- it's not even
  # that cumbersome to do so.
  # TODO(user): Remove this method entirely if/when everyone agrees with my
  # position.
  return proto_field_name


def _AddSlots(message_descriptor, dictionary):
  """Adds a __slots__ entry to dictionary, containing the names of all valid
  attributes for this message type.

  Args:
    message_descriptor: A Descriptor instance describing this message type.
      (Not consulted here; the slot list is the same for every message.)
    dictionary: Class dictionary to which we'll add a '__slots__' entry.
  """
  dictionary['__slots__'] = ['_cached_byte_size',
                             '_cached_byte_size_dirty',
                             '_fields',
                             '_unknown_fields',
                             '_unknown_field_set',
                             '_is_present_in_parent',
                             '_listener',
                             '_listener_for_children',
                             '__weakref__',
                             '_oneofs']


def _IsMessageSetExtension(field):
  """Returns a true value if field is an optional message-typed extension of a
  message whose options set message_set_wire_format."""
  return (field.is_extension and
          field.containing_type.has_options and
          field.containing_type.GetOptions().message_set_wire_format and
          field.type == _FieldDescriptor.TYPE_MESSAGE and
          field.label == _FieldDescriptor.LABEL_OPTIONAL)


def _IsMapField(field):
  """Returns a true value if field is a message field whose message type has
  the map_entry option set."""
  return (field.type == _FieldDescriptor.TYPE_MESSAGE and
          field.message_type.has_options and
          field.message_type.GetOptions().map_entry)


def _IsMessageMapField(field):
  """Returns True if the 'value' field of this map entry type is a message."""
  value_type = field.message_type.fields_by_name['value']
  return value_type.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE


def _IsStrictUtf8Check(field):
  """Returns True if the field's containing message uses proto3 syntax."""
  return field.containing_type.syntax == 'proto3'


def _AttachFieldHelpers(cls, field_descriptor):
  """Attaches wire-format helpers for one field.

  Sets `_encoder`, `_sizer` and `_default_constructor` on field_descriptor and
  registers the field's decoder(s) in cls._decoders_by_tag, keyed by tag bytes.

  Args:
    cls: The message class being constructed (owner of _decoders_by_tag).
    field_descriptor: FieldDescriptor of the field or extension to set up.
  """
  is_repeated = (field_descriptor.label == _FieldDescriptor.LABEL_REPEATED)
  is_packable = (is_repeated and
                 wire_format.IsTypePackable(field_descriptor.type))
  if not is_packable:
    is_packed = False
  elif field_descriptor.containing_type.syntax == 'proto2':
    # proto2: packed only when explicitly requested.
    is_packed = (field_descriptor.has_options and
                 field_descriptor.GetOptions().packed)
  else:
    # Otherwise: packed unless the option is explicitly set to false.
    has_packed_false = (field_descriptor.has_options and
                        field_descriptor.GetOptions().HasField('packed') and
                        not field_descriptor.GetOptions().packed)
    is_packed = not has_packed_false
  is_map_entry = _IsMapField(field_descriptor)

  if is_map_entry:
    field_encoder = encoder.MapEncoder(field_descriptor)
    sizer = encoder.MapSizer(field_descriptor,
                             _IsMessageMapField(field_descriptor))
  elif _IsMessageSetExtension(field_descriptor):
    field_encoder = encoder.MessageSetItemEncoder(field_descriptor.number)
    sizer = encoder.MessageSetItemSizer(field_descriptor.number)
  else:
    field_encoder = type_checkers.TYPE_TO_ENCODER[field_descriptor.type](
        field_descriptor.number, is_repeated, is_packed)
    sizer = type_checkers.TYPE_TO_SIZER[field_descriptor.type](
        field_descriptor.number, is_repeated, is_packed)

  field_descriptor._encoder = field_encoder
  field_descriptor._sizer = sizer
  field_descriptor._default_constructor = _DefaultValueConstructorForField(
      field_descriptor)

  def AddDecoder(wiretype, is_packed):
    tag_bytes = encoder.TagBytes(field_descriptor.number, wiretype)
    decode_type = field_descriptor.type
    if (decode_type == _FieldDescriptor.TYPE_ENUM and
        type_checkers.SupportsOpenEnums(field_descriptor)):
      decode_type = _FieldDescriptor.TYPE_INT32

    # For a oneof member, the field descriptor itself is stored next to the
    # decoder; for any other field, None is stored.
    oneof_descriptor = None
    if field_descriptor.containing_oneof is not None:
      oneof_descriptor = field_descriptor

    if is_map_entry:
      is_message_map = _IsMessageMapField(field_descriptor)

      field_decoder = decoder.MapDecoder(
          field_descriptor, _GetInitializeDefaultForMap(field_descriptor),
          is_message_map)
    elif decode_type == _FieldDescriptor.TYPE_STRING:
      is_strict_utf8_check = _IsStrictUtf8Check(field_descriptor)
      field_decoder = decoder.StringDecoder(
          field_descriptor.number, is_repeated, is_packed,
          field_descriptor, field_descriptor._default_constructor,
          is_strict_utf8_check)
    else:
      field_decoder = type_checkers.TYPE_TO_DECODER[decode_type](
          field_descriptor.number, is_repeated, is_packed,
          field_descriptor, field_descriptor._default_constructor)

    cls._decoders_by_tag[tag_bytes] = (field_decoder, oneof_descriptor)

  AddDecoder(type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type],
             False)

  if is_packable:
    # To support wire compatibility of adding packed = true, add a decoder for
    # packed values regardless of the field's options.
    AddDecoder(wire_format.WIRETYPE_LENGTH_DELIMITED, True)


def _AddClassAttributesForNestedExtensions(descriptor, dictionary):
  """Adds each extension declared in descriptor to the class dictionary, keyed
  by the extension's name."""
  extensions = descriptor.extensions_by_name
  for extension_name, extension_field in extensions.items():
    assert extension_name not in dictionary
    dictionary[extension_name] = extension_field


def _AddEnumValues(descriptor, cls):
  """Sets class-level attributes for all enum fields defined in this message.

  Also exporting a class-level object that can name enum values.

  Args:
    descriptor: Descriptor object for this message type.
    cls: Class we're constructing for this message type.
  """
  for enum_type in descriptor.enum_types:
    setattr(cls, enum_type.name, enum_type_wrapper.EnumTypeWrapper(enum_type))
    for enum_value in enum_type.values:
      setattr(cls, enum_value.name, enum_value.number)


def _GetInitializeDefaultForMap(field):
  """Returns a function that builds the empty map container for a map field.

  Args:
    field: FieldDescriptor of the map field.

  Returns:
    A one-argument function taking the owning message and returning a
    containers.MessageMap (message values) or containers.ScalarMap.

  Raises:
    ValueError: If the field is not repeated.
  """
  if field.label != _FieldDescriptor.LABEL_REPEATED:
    raise ValueError('map_entry set on non-repeated field %s' % (
        field.name))
  fields_by_name = field.message_type.fields_by_name
  key_checker = type_checkers.GetTypeChecker(fields_by_name['key'])

  value_field = fields_by_name['value']
  if _IsMessageMapField(field):
    def MakeMessageMapDefault(message):
      return containers.MessageMap(
          message._listener_for_children, value_field.message_type, key_checker,
          field.message_type)
    return MakeMessageMapDefault
  else:
    value_checker = type_checkers.GetTypeChecker(value_field)
    def MakePrimitiveMapDefault(message):
      return containers.ScalarMap(
          message._listener_for_children, key_checker, value_checker,
          field.message_type)
    return MakePrimitiveMapDefault


def _DefaultValueConstructorForField(field):
  """Returns a function which returns a default value for a field.

  Args:
    field: FieldDescriptor object for this field.

  The returned function has one argument:
    message: Message instance containing this field, or a weakref proxy
      of same.

  That function in turn returns a default value for this field. The default
  value may refer back to |message| via a weak reference.

  Raises:
    ValueError: If a repeated field declares a non-empty default value, or
      (via _GetInitializeDefaultForMap) a map field is not repeated.
  """

  if _IsMapField(field):
    return _GetInitializeDefaultForMap(field)

  if field.label == _FieldDescriptor.LABEL_REPEATED:
    if field.has_default_value and field.default_value != []:
      raise ValueError('Repeated field default value not empty list: %s' % (
          field.default_value))
    if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
      # We can't look at _concrete_class yet since it might not have
      # been set. (Depends on order in which we initialize the classes).
      def MakeRepeatedMessageDefault(message):
        return containers.RepeatedCompositeFieldContainer(
            message._listener_for_children, field.message_type)
      return MakeRepeatedMessageDefault
    else:
      type_checker = type_checkers.GetTypeChecker(field)
      def MakeRepeatedScalarDefault(message):
        return containers.RepeatedScalarFieldContainer(
            message._listener_for_children, type_checker)
      return MakeRepeatedScalarDefault

  if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    # _concrete_class may not yet be initialized.
    message_type = field.message_type
    def MakeSubMessageDefault(message):
      assert getattr(message_type, '_concrete_class', None), (
          'Uninitialized concrete class found for field %r (message type %r)'
          % (field.full_name, message_type.full_name))
      result = message_type._concrete_class()
      result._SetListener(
          _OneofListener(message, field)
          if field.containing_oneof is not None
          else message._listener_for_children)
      return result
    return MakeSubMessageDefault

  def MakeScalarDefault(message):
    # TODO(protobuf-team): This may be broken since there may not be
    # default_value. Combine with has_default_value somehow.
    return field.default_value
  return MakeScalarDefault


def _ReraiseTypeErrorWithFieldName(message_name, field_name):
  """Re-raise the currently-handled TypeError with the field name added."""
  exc = sys.exc_info()[1]
  # Exact type match on purpose: only a plain one-argument TypeError is
  # rewritten; anything else is re-raised unchanged.
  if len(exc.args) == 1 and type(exc) is TypeError:
    # simple TypeError; add field name to exception message
    exc = TypeError('%s for field %s.%s' % (str(exc), message_name, field_name))

  # re-raise possibly-amended exception with original traceback:
  six.reraise(type(exc), exc, sys.exc_info()[2])


def _AddInitMethod(message_descriptor, cls):
  """Adds an __init__ method to cls."""

  def _GetIntegerEnumValue(enum_type, value):
    """Convert a string or integer enum value to an integer.

    If the value is a string, it is converted to the enum value in
    enum_type with the same name. If the value is not a string, it's
    returned as-is. (No conversion or bounds-checking is done.)
    """
    if isinstance(value, six.string_types):
      try:
        return enum_type.values_by_name[value].number
      except KeyError:
        raise ValueError('Enum type %s: unknown label "%s"' % (
            enum_type.full_name, value))
    return value

  def init(self, **kwargs):
    self._cached_byte_size = 0
    self._cached_byte_size_dirty = len(kwargs) > 0
    self._fields = {}
    # Contains a mapping from oneof field descriptors to the descriptor
    # of the currently set field in that oneof field.
    self._oneofs = {}

    # _unknown_fields is () when empty for efficiency, and will be turned into
    # a list if fields are added.
    self._unknown_fields = ()
    # _unknown_field_set is None when empty for efficiency, and will be
    # turned into UnknownFieldSet struct if fields are added.
    self._unknown_field_set = None  # pylint: disable=protected-access
    self._is_present_in_parent = False
    self._listener = message_listener_mod.NullMessageListener()
    self._listener_for_children = _Listener(self)
    for field_name, field_value in kwargs.items():
      field = _GetFieldByName(message_descriptor, field_name)
      # NOTE: _GetFieldByName raises ValueError for an unknown name rather than
      # returning None, so this branch is only a defensive guard.
      if field is None:
        raise TypeError('%s() got an unexpected keyword argument "%s"' %
                        (message_descriptor.name, field_name))
      if field_value is None:
        # field=None is the same as no field at all.
        continue
      if field.label == _FieldDescriptor.LABEL_REPEATED:
        copy = field._default_constructor(self)
        if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:  # Composite
          if _IsMapField(field):
            if _IsMessageMapField(field):
              for key in field_value:
                copy[key].MergeFrom(field_value[key])
            else:
              copy.update(field_value)
          else:
            for val in field_value:
              if isinstance(val, dict):
                copy.add(**val)
              else:
                copy.add().MergeFrom(val)
        else:  # Scalar
          if field.cpp_type == _FieldDescriptor.CPPTYPE_ENUM:
            field_value = [_GetIntegerEnumValue(field.enum_type, val)
                           for val in field_value]
          copy.extend(field_value)
        self._fields[field] = copy
      elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
        copy = field._default_constructor(self)
        new_val = field_value
        if isinstance(field_value, dict):
          new_val = field.message_type._concrete_class(**field_value)
        try:
          copy.MergeFrom(new_val)
        except TypeError:
          _ReraiseTypeErrorWithFieldName(message_descriptor.name, field_name)
        self._fields[field] = copy
      else:
        if field.cpp_type == _FieldDescriptor.CPPTYPE_ENUM:
          field_value = _GetIntegerEnumValue(field.enum_type, field_value)
        try:
          setattr(self, field_name, field_value)
        except TypeError:
          _ReraiseTypeErrorWithFieldName(message_descriptor.name, field_name)

  init.__module__ = None
  init.__doc__ = None
  cls.__init__ = init


def _GetFieldByName(message_descriptor, field_name):
  """Returns a field descriptor by field name.

  Args:
    message_descriptor: A Descriptor describing all fields in message.
    field_name: The name of the field to retrieve.
  Returns:
    The field descriptor associated with the field name.
  Raises:
    ValueError: If the message has no field with that name.
  """
  try:
    return message_descriptor.fields_by_name[field_name]
  except KeyError:
    raise ValueError('Protocol message %s has no "%s" field.' %
                     (message_descriptor.name, field_name))


def _AddPropertiesForFields(descriptor, cls):
  """Adds properties for all fields in this protocol message type."""
  for field in descriptor.fields:
    _AddPropertiesForField(field, cls)

  if descriptor.is_extendable:
    # _ExtensionDict is just an adaptor with no state so we allocate a new one
    # every time it is accessed.
    cls.Extensions = property(lambda self: _ExtensionDict(self))


def _AddPropertiesForField(field, cls):
  """Adds a public property for a protocol message field.
  Clients can use this property to get and (in the case
  of non-repeated scalar fields) directly set the value
  of a protocol message field.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  # Catch it if we add other types that we should
  # handle specially here.
  assert _FieldDescriptor.MAX_CPPTYPE == 10

  constant_name = field.name.upper() + '_FIELD_NUMBER'
  setattr(cls, constant_name, field.number)

  if field.label == _FieldDescriptor.LABEL_REPEATED:
    _AddPropertiesForRepeatedField(field, cls)
  elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    _AddPropertiesForNonRepeatedCompositeField(field, cls)
  else:
    _AddPropertiesForNonRepeatedScalarField(field, cls)


class _FieldProperty(property):
  """A property that also records the field's descriptor as DESCRIPTOR."""
  __slots__ = ('DESCRIPTOR',)

  def __init__(self, descriptor, getter, setter, doc):
    property.__init__(self, getter, setter, doc=doc)
    self.DESCRIPTOR = descriptor


def _AddPropertiesForRepeatedField(field, cls):
  """Adds a public property for a "repeated" protocol message field. Clients
  can use this property to get the value of the field, which will be either a
  RepeatedScalarFieldContainer or RepeatedCompositeFieldContainer (see
  below).

  Note that when clients add values to these containers, we perform
  type-checking in the case of repeated scalar fields, and we also set any
  necessary "has" bits as a side-effect.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  proto_field_name = field.name
  property_name = _PropertyName(proto_field_name)

  def getter(self):
    field_value = self._fields.get(field)
    if field_value is None:
      # Construct a new object to represent this field.
      field_value = field._default_constructor(self)

      # Atomically check if another thread has preempted us and, if not, swap
      # in the new object we just created. If someone has preempted us, we
      # take that object and discard ours.
      # WARNING: We are relying on setdefault() being atomic. This is true
      # in CPython but we haven't investigated others. This warning appears
      # in several other locations in this file.
      field_value = self._fields.setdefault(field, field_value)
    return field_value
  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % proto_field_name

  # We define a setter just so we can throw an exception with a more
  # helpful error message.
  def setter(self, new_value):
    raise AttributeError('Assignment not allowed to repeated field '
                         '"%s" in protocol message object.' % proto_field_name)

  doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
  setattr(cls, property_name, _FieldProperty(field, getter, setter, doc=doc))


def _AddPropertiesForNonRepeatedScalarField(field, cls):
  """Adds a public property for a nonrepeated, scalar protocol message field.
  Clients can use this property to get and directly set the value of the field.
  Note that when the client sets the value of a field by using this property,
  all necessary "has" bits are set as a side-effect, and we also perform
  type-checking.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  proto_field_name = field.name
  property_name = _PropertyName(proto_field_name)
  type_checker = type_checkers.GetTypeChecker(field)
  default_value = field.default_value
  is_proto3 = field.containing_type.syntax == 'proto3'

  def getter(self):
    # TODO(protobuf-team): This may be broken since there may not be
    # default_value. Combine with has_default_value somehow.
    return self._fields.get(field, default_value)
  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % proto_field_name

  clear_when_set_to_default = is_proto3 and not field.containing_oneof

  def field_setter(self, new_value):
    # pylint: disable=protected-access
    # Testing the value for truthiness captures all of the proto3 defaults
    # (0, 0.0, enum 0, and False).
    try:
      new_value = type_checker.CheckValue(new_value)
    except TypeError as e:
      raise TypeError(
          'Cannot set %s to %.1024r: %s' % (field.full_name, new_value, e))
    if clear_when_set_to_default and not new_value:
      self._fields.pop(field, None)
    else:
      self._fields[field] = new_value
    # Check _cached_byte_size_dirty inline to improve performance, since scalar
    # setters are called frequently.
    if not self._cached_byte_size_dirty:
      self._Modified()

  if field.containing_oneof:
    def setter(self, new_value):
      field_setter(self, new_value)
      self._UpdateOneofState(field)
  else:
    setter = field_setter

  setter.__module__ = None
  setter.__doc__ = 'Setter for %s.' % proto_field_name

  # Add a property to encapsulate the getter/setter.
  doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
  setattr(cls, property_name, _FieldProperty(field, getter, setter, doc=doc))


def _AddPropertiesForNonRepeatedCompositeField(field, cls):
  """Adds a public property for a nonrepeated, composite protocol message field.
  A composite field is a "group" or "message" field.

  Clients can use this property to get the value of the field, but cannot
  assign to the property directly.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  # TODO(user): Remove duplication with similar method
  # for non-repeated scalars.
  proto_field_name = field.name
  property_name = _PropertyName(proto_field_name)

  def getter(self):
    field_value = self._fields.get(field)
    if field_value is None:
      # Construct a new object to represent this field.
      field_value = field._default_constructor(self)

      # Atomically check if another thread has preempted us and, if not, swap
      # in the new object we just created. If someone has preempted us, we
      # take that object and discard ours.
      # WARNING: We are relying on setdefault() being atomic. This is true
      # in CPython but we haven't investigated others. This warning appears
      # in several other locations in this file.
      field_value = self._fields.setdefault(field, field_value)
    return field_value
  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % proto_field_name

  # We define a setter just so we can throw an exception with a more
  # helpful error message.
  def setter(self, new_value):
    raise AttributeError('Assignment not allowed to composite field '
                         '"%s" in protocol message object.' % proto_field_name)

  # Add a property to encapsulate the getter.
  doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
  setattr(cls, property_name, _FieldProperty(field, getter, setter, doc=doc))


def _AddPropertiesForExtensions(descriptor, cls):
  """Adds <NAME>_FIELD_NUMBER constants for all extensions declared in this
  message type and, when the descriptor has a file, points the class's
  extension lookup tables at the ones held by the file's pool."""
  extensions = descriptor.extensions_by_name
  for extension_name, extension_field in extensions.items():
    constant_name = extension_name.upper() + '_FIELD_NUMBER'
    setattr(cls, constant_name, extension_field.number)

  # TODO(user): Migrate all users of these attributes to functions like
  # pool.FindExtensionByNumber(descriptor).
  if descriptor.file is not None:
    # TODO(user): Use cls.MESSAGE_FACTORY.pool when available.
    pool = descriptor.file.pool
    cls._extensions_by_number = pool._extensions_by_number[descriptor]
    cls._extensions_by_name = pool._extensions_by_name[descriptor]


def _AddStaticMethods(cls):
  """Adds the RegisterExtension and FromString static methods to cls."""
  # TODO(user): This probably needs to be thread-safe(?)
  def RegisterExtension(extension_handle):
    extension_handle.containing_type = cls.DESCRIPTOR
    # TODO(user): Use cls.MESSAGE_FACTORY.pool when available.
    # pylint: disable=protected-access
    cls.DESCRIPTOR.file.pool._AddExtensionDescriptor(extension_handle)
    _AttachFieldHelpers(cls, extension_handle)
  cls.RegisterExtension = staticmethod(RegisterExtension)

  def FromString(s):
    message = cls()
    message.MergeFromString(s)
    return message
  cls.FromString = staticmethod(FromString)


def _IsPresent(item):
  """Given a (FieldDescriptor, value) tuple from _fields, return true if the
  value should be included in the list returned by ListFields()."""

  if item[0].label == _FieldDescriptor.LABEL_REPEATED:
    return bool(item[1])
  elif item[0].cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    return item[1]._is_present_in_parent
  else:
    return True


def _AddListFieldsMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods()."""

  def ListFields(self):
    all_fields = [item for item in self._fields.items() if _IsPresent(item)]
    all_fields.sort(key=lambda item: item[0].number)
    return all_fields

  cls.ListFields = ListFields


_PROTO3_ERROR_TEMPLATE = (
    'Protocol message %s has no non-repeated submessage field "%s" '
    'nor marked as optional')
_PROTO2_ERROR_TEMPLATE = 'Protocol message %s has no non-repeated field "%s"'


def _AddHasFieldMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods()."""

  is_proto3 = (message_descriptor.syntax == "proto3")
  error_msg = _PROTO3_ERROR_TEMPLATE if is_proto3 else _PROTO2_ERROR_TEMPLATE

  hassable_fields = {}
  for field in message_descriptor.fields:
    if field.label == _FieldDescriptor.LABEL_REPEATED:
      continue
    # For proto3, only submessages and fields inside a oneof have presence.
    if (is_proto3 and field.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE and
        not field.containing_oneof):
      continue
    hassable_fields[field.name] = field

  # Has methods are supported for oneof descriptors.
  for oneof in message_descriptor.oneofs:
    hassable_fields[oneof.name] = oneof

  def HasField(self, field_name):
    try:
      field = hassable_fields[field_name]
    except KeyError:
      raise ValueError(error_msg % (message_descriptor.full_name, field_name))

    if isinstance(field, descriptor_mod.OneofDescriptor):
      try:
        return HasField(self, self._oneofs[field].name)
      except KeyError:
        return False
    else:
      if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
        value = self._fields.get(field)
        return value is not None and value._is_present_in_parent
      else:
        return field in self._fields

  cls.HasField = HasField


def _AddClearFieldMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods()."""
  def ClearField(self, field_name):
    try:
      field = message_descriptor.fields_by_name[field_name]
    except KeyError:
      try:
        field = message_descriptor.oneofs_by_name[field_name]
        if field in self._oneofs:
          field = self._oneofs[field]
        else:
          return
      except KeyError:
        raise ValueError('Protocol message %s has no "%s" field.' %
                         (message_descriptor.name, field_name))

    if field in self._fields:
      # To match the C++ implementation, we need to invalidate iterators
      # for map fields when ClearField() happens.
      if hasattr(self._fields[field], 'InvalidateIterators'):
        self._fields[field].InvalidateIterators()

      # Note: If the field is a sub-message, its listener will still point
      # at us. That's fine, because the worst that can happen is that it
      # will call _Modified() and invalidate our byte size. Big deal.
      del self._fields[field]

      if self._oneofs.get(field.containing_oneof, None) is field:
        del self._oneofs[field.containing_oneof]

    # Always call _Modified() -- even if nothing was changed, this is
    # a mutating method, and thus calling it should cause the field to become
    # present in the parent message.
    self._Modified()

  cls.ClearField = ClearField


def _AddClearExtensionMethod(cls):
  """Helper for _AddMessageMethods()."""
  def ClearExtension(self, extension_handle):
    extension_dict._VerifyExtensionHandle(self, extension_handle)

    # Similar to ClearField(), above.
    if extension_handle in self._fields:
      del self._fields[extension_handle]
    self._Modified()
  cls.ClearExtension = ClearExtension


def _AddHasExtensionMethod(cls):
  """Helper for _AddMessageMethods()."""
  def HasExtension(self, extension_handle):
    extension_dict._VerifyExtensionHandle(self, extension_handle)
    if extension_handle.label == _FieldDescriptor.LABEL_REPEATED:
      raise KeyError('"%s" is repeated.' % extension_handle.full_name)

    if extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
      value = self._fields.get(extension_handle)
      return value is not None and value._is_present_in_parent
    else:
      return extension_handle in self._fields
  cls.HasExtension = HasExtension


def _InternalUnpackAny(msg):
  """Unpacks Any message and returns the unpacked message.

  This internal method is different from public Any Unpack method which takes
  the target message as argument. _InternalUnpackAny method does not have
  target message type and need to find the message type in descriptor pool.

  Args:
    msg: An Any message to be unpacked.

  Returns:
    The unpacked message.
  """
  # TODO(user): Don't use the factory of generated messages.
  # To make Any work with custom factories, use the message factory of the
  # parent message.
  # pylint: disable=g-import-not-at-top
  from google.protobuf import symbol_database
  factory = symbol_database.Default()

  type_url = msg.type_url

  if not type_url:
    return None

  # TODO(user): For now we just strip the hostname. Better logic will be
  # required.
965 type_name = type_url.split('/')[-1] 966 descriptor = factory.pool.FindMessageTypeByName(type_name) 967 968 if descriptor is None: 969 return None 970 971 message_class = factory.GetPrototype(descriptor) 972 message = message_class() 973 974 message.ParseFromString(msg.value) 975 return message 976 977 978def _AddEqualsMethod(message_descriptor, cls): 979 """Helper for _AddMessageMethods().""" 980 def __eq__(self, other): 981 if (not isinstance(other, message_mod.Message) or 982 other.DESCRIPTOR != self.DESCRIPTOR): 983 return False 984 985 if self is other: 986 return True 987 988 if self.DESCRIPTOR.full_name == _AnyFullTypeName: 989 any_a = _InternalUnpackAny(self) 990 any_b = _InternalUnpackAny(other) 991 if any_a and any_b: 992 return any_a == any_b 993 994 if not self.ListFields() == other.ListFields(): 995 return False 996 997 # TODO(user): Fix UnknownFieldSet to consider MessageSet extensions, 998 # then use it for the comparison. 999 unknown_fields = list(self._unknown_fields) 1000 unknown_fields.sort() 1001 other_unknown_fields = list(other._unknown_fields) 1002 other_unknown_fields.sort() 1003 return unknown_fields == other_unknown_fields 1004 1005 cls.__eq__ = __eq__ 1006 1007 1008def _AddStrMethod(message_descriptor, cls): 1009 """Helper for _AddMessageMethods().""" 1010 def __str__(self): 1011 return text_format.MessageToString(self) 1012 cls.__str__ = __str__ 1013 1014 1015def _AddReprMethod(message_descriptor, cls): 1016 """Helper for _AddMessageMethods().""" 1017 def __repr__(self): 1018 return text_format.MessageToString(self) 1019 cls.__repr__ = __repr__ 1020 1021 1022def _AddUnicodeMethod(unused_message_descriptor, cls): 1023 """Helper for _AddMessageMethods().""" 1024 1025 def __unicode__(self): 1026 return text_format.MessageToString(self, as_utf8=True).decode('utf-8') 1027 cls.__unicode__ = __unicode__ 1028 1029 1030def _BytesForNonRepeatedElement(value, field_number, field_type): 1031 """Returns the number of bytes needed to serialize a 
non-repeated element. 1032 The returned byte count includes space for tag information and any 1033 other additional space associated with serializing value. 1034 1035 Args: 1036 value: Value we're serializing. 1037 field_number: Field number of this value. (Since the field number 1038 is stored as part of a varint-encoded tag, this has an impact 1039 on the total bytes required to serialize the value). 1040 field_type: The type of the field. One of the TYPE_* constants 1041 within FieldDescriptor. 1042 """ 1043 try: 1044 fn = type_checkers.TYPE_TO_BYTE_SIZE_FN[field_type] 1045 return fn(field_number, value) 1046 except KeyError: 1047 raise message_mod.EncodeError('Unrecognized field type: %d' % field_type) 1048 1049 1050def _AddByteSizeMethod(message_descriptor, cls): 1051 """Helper for _AddMessageMethods().""" 1052 1053 def ByteSize(self): 1054 if not self._cached_byte_size_dirty: 1055 return self._cached_byte_size 1056 1057 size = 0 1058 descriptor = self.DESCRIPTOR 1059 if descriptor.GetOptions().map_entry: 1060 # Fields of map entry should always be serialized. 1061 size = descriptor.fields_by_name['key']._sizer(self.key) 1062 size += descriptor.fields_by_name['value']._sizer(self.value) 1063 else: 1064 for field_descriptor, field_value in self.ListFields(): 1065 size += field_descriptor._sizer(field_value) 1066 for tag_bytes, value_bytes in self._unknown_fields: 1067 size += len(tag_bytes) + len(value_bytes) 1068 1069 self._cached_byte_size = size 1070 self._cached_byte_size_dirty = False 1071 self._listener_for_children.dirty = False 1072 return size 1073 1074 cls.ByteSize = ByteSize 1075 1076 1077def _AddSerializeToStringMethod(message_descriptor, cls): 1078 """Helper for _AddMessageMethods().""" 1079 1080 def SerializeToString(self, **kwargs): 1081 # Check if the message has all of its required fields set. 
1082 if not self.IsInitialized(): 1083 raise message_mod.EncodeError( 1084 'Message %s is missing required fields: %s' % ( 1085 self.DESCRIPTOR.full_name, ','.join(self.FindInitializationErrors()))) 1086 return self.SerializePartialToString(**kwargs) 1087 cls.SerializeToString = SerializeToString 1088 1089 1090def _AddSerializePartialToStringMethod(message_descriptor, cls): 1091 """Helper for _AddMessageMethods().""" 1092 1093 def SerializePartialToString(self, **kwargs): 1094 out = BytesIO() 1095 self._InternalSerialize(out.write, **kwargs) 1096 return out.getvalue() 1097 cls.SerializePartialToString = SerializePartialToString 1098 1099 def InternalSerialize(self, write_bytes, deterministic=None): 1100 if deterministic is None: 1101 deterministic = ( 1102 api_implementation.IsPythonDefaultSerializationDeterministic()) 1103 else: 1104 deterministic = bool(deterministic) 1105 1106 descriptor = self.DESCRIPTOR 1107 if descriptor.GetOptions().map_entry: 1108 # Fields of map entry should always be serialized. 
1109 descriptor.fields_by_name['key']._encoder( 1110 write_bytes, self.key, deterministic) 1111 descriptor.fields_by_name['value']._encoder( 1112 write_bytes, self.value, deterministic) 1113 else: 1114 for field_descriptor, field_value in self.ListFields(): 1115 field_descriptor._encoder(write_bytes, field_value, deterministic) 1116 for tag_bytes, value_bytes in self._unknown_fields: 1117 write_bytes(tag_bytes) 1118 write_bytes(value_bytes) 1119 cls._InternalSerialize = InternalSerialize 1120 1121 1122def _AddMergeFromStringMethod(message_descriptor, cls): 1123 """Helper for _AddMessageMethods().""" 1124 def MergeFromString(self, serialized): 1125 if isinstance(serialized, memoryview) and six.PY2: 1126 raise TypeError( 1127 'memoryview not supported in Python 2 with the pure Python proto ' 1128 'implementation: this is to maintain compatibility with the C++ ' 1129 'implementation') 1130 1131 serialized = memoryview(serialized) 1132 length = len(serialized) 1133 try: 1134 if self._InternalParse(serialized, 0, length) != length: 1135 # The only reason _InternalParse would return early is if it 1136 # encountered an end-group tag. 1137 raise message_mod.DecodeError('Unexpected end-group tag.') 1138 except (IndexError, TypeError): 1139 # Now ord(buf[p:p+1]) == ord('') gets TypeError. 1140 raise message_mod.DecodeError('Truncated message.') 1141 except struct.error as e: 1142 raise message_mod.DecodeError(e) 1143 return length # Return this for legacy reasons. 1144 cls.MergeFromString = MergeFromString 1145 1146 local_ReadTag = decoder.ReadTag 1147 local_SkipField = decoder.SkipField 1148 decoders_by_tag = cls._decoders_by_tag 1149 1150 def InternalParse(self, buffer, pos, end): 1151 """Create a message from serialized bytes. 1152 1153 Args: 1154 self: Message, instance of the proto message object. 1155 buffer: memoryview of the serialized data. 1156 pos: int, position to start in the serialized data. 1157 end: int, end position of the serialized data. 
1158 1159 Returns: 1160 Message object. 1161 """ 1162 # Guard against internal misuse, since this function is called internally 1163 # quite extensively, and its easy to accidentally pass bytes. 1164 assert isinstance(buffer, memoryview) 1165 self._Modified() 1166 field_dict = self._fields 1167 # pylint: disable=protected-access 1168 unknown_field_set = self._unknown_field_set 1169 while pos != end: 1170 (tag_bytes, new_pos) = local_ReadTag(buffer, pos) 1171 field_decoder, field_desc = decoders_by_tag.get(tag_bytes, (None, None)) 1172 if field_decoder is None: 1173 if not self._unknown_fields: # pylint: disable=protected-access 1174 self._unknown_fields = [] # pylint: disable=protected-access 1175 if unknown_field_set is None: 1176 # pylint: disable=protected-access 1177 self._unknown_field_set = containers.UnknownFieldSet() 1178 # pylint: disable=protected-access 1179 unknown_field_set = self._unknown_field_set 1180 # pylint: disable=protected-access 1181 (tag, _) = decoder._DecodeVarint(tag_bytes, 0) 1182 field_number, wire_type = wire_format.UnpackTag(tag) 1183 if field_number == 0: 1184 raise message_mod.DecodeError('Field number 0 is illegal.') 1185 # TODO(user): remove old_pos. 1186 old_pos = new_pos 1187 (data, new_pos) = decoder._DecodeUnknownField( 1188 buffer, new_pos, wire_type) # pylint: disable=protected-access 1189 if new_pos == -1: 1190 return pos 1191 # pylint: disable=protected-access 1192 unknown_field_set._add(field_number, wire_type, data) 1193 # TODO(user): remove _unknown_fields. 
1194 new_pos = local_SkipField(buffer, old_pos, end, tag_bytes) 1195 if new_pos == -1: 1196 return pos 1197 self._unknown_fields.append( 1198 (tag_bytes, buffer[old_pos:new_pos].tobytes())) 1199 pos = new_pos 1200 else: 1201 pos = field_decoder(buffer, new_pos, end, self, field_dict) 1202 if field_desc: 1203 self._UpdateOneofState(field_desc) 1204 return pos 1205 cls._InternalParse = InternalParse 1206 1207 1208def _AddIsInitializedMethod(message_descriptor, cls): 1209 """Adds the IsInitialized and FindInitializationError methods to the 1210 protocol message class.""" 1211 1212 required_fields = [field for field in message_descriptor.fields 1213 if field.label == _FieldDescriptor.LABEL_REQUIRED] 1214 1215 def IsInitialized(self, errors=None): 1216 """Checks if all required fields of a message are set. 1217 1218 Args: 1219 errors: A list which, if provided, will be populated with the field 1220 paths of all missing required fields. 1221 1222 Returns: 1223 True iff the specified message has all required fields set. 1224 """ 1225 1226 # Performance is critical so we avoid HasField() and ListFields(). 1227 1228 for field in required_fields: 1229 if (field not in self._fields or 1230 (field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE and 1231 not self._fields[field]._is_present_in_parent)): 1232 if errors is not None: 1233 errors.extend(self.FindInitializationErrors()) 1234 return False 1235 1236 for field, value in list(self._fields.items()): # dict can change size! 
1237 if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: 1238 if field.label == _FieldDescriptor.LABEL_REPEATED: 1239 if (field.message_type.has_options and 1240 field.message_type.GetOptions().map_entry): 1241 continue 1242 for element in value: 1243 if not element.IsInitialized(): 1244 if errors is not None: 1245 errors.extend(self.FindInitializationErrors()) 1246 return False 1247 elif value._is_present_in_parent and not value.IsInitialized(): 1248 if errors is not None: 1249 errors.extend(self.FindInitializationErrors()) 1250 return False 1251 1252 return True 1253 1254 cls.IsInitialized = IsInitialized 1255 1256 def FindInitializationErrors(self): 1257 """Finds required fields which are not initialized. 1258 1259 Returns: 1260 A list of strings. Each string is a path to an uninitialized field from 1261 the top-level message, e.g. "foo.bar[5].baz". 1262 """ 1263 1264 errors = [] # simplify things 1265 1266 for field in required_fields: 1267 if not self.HasField(field.name): 1268 errors.append(field.name) 1269 1270 for field, value in self.ListFields(): 1271 if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: 1272 if field.is_extension: 1273 name = '(%s)' % field.full_name 1274 else: 1275 name = field.name 1276 1277 if _IsMapField(field): 1278 if _IsMessageMapField(field): 1279 for key in value: 1280 element = value[key] 1281 prefix = '%s[%s].' % (name, key) 1282 sub_errors = element.FindInitializationErrors() 1283 errors += [prefix + error for error in sub_errors] 1284 else: 1285 # ScalarMaps can't have any initialization errors. 1286 pass 1287 elif field.label == _FieldDescriptor.LABEL_REPEATED: 1288 for i in range(len(value)): 1289 element = value[i] 1290 prefix = '%s[%d].' % (name, i) 1291 sub_errors = element.FindInitializationErrors() 1292 errors += [prefix + error for error in sub_errors] 1293 else: 1294 prefix = name + '.' 
1295 sub_errors = value.FindInitializationErrors() 1296 errors += [prefix + error for error in sub_errors] 1297 1298 return errors 1299 1300 cls.FindInitializationErrors = FindInitializationErrors 1301 1302 1303def _AddMergeFromMethod(cls): 1304 LABEL_REPEATED = _FieldDescriptor.LABEL_REPEATED 1305 CPPTYPE_MESSAGE = _FieldDescriptor.CPPTYPE_MESSAGE 1306 1307 def MergeFrom(self, msg): 1308 if not isinstance(msg, cls): 1309 raise TypeError( 1310 'Parameter to MergeFrom() must be instance of same class: ' 1311 'expected %s got %s.' % (cls.__name__, msg.__class__.__name__)) 1312 1313 assert msg is not self 1314 self._Modified() 1315 1316 fields = self._fields 1317 1318 for field, value in msg._fields.items(): 1319 if field.label == LABEL_REPEATED: 1320 field_value = fields.get(field) 1321 if field_value is None: 1322 # Construct a new object to represent this field. 1323 field_value = field._default_constructor(self) 1324 fields[field] = field_value 1325 field_value.MergeFrom(value) 1326 elif field.cpp_type == CPPTYPE_MESSAGE: 1327 if value._is_present_in_parent: 1328 field_value = fields.get(field) 1329 if field_value is None: 1330 # Construct a new object to represent this field. 
1331 field_value = field._default_constructor(self) 1332 fields[field] = field_value 1333 field_value.MergeFrom(value) 1334 else: 1335 self._fields[field] = value 1336 if field.containing_oneof: 1337 self._UpdateOneofState(field) 1338 1339 if msg._unknown_fields: 1340 if not self._unknown_fields: 1341 self._unknown_fields = [] 1342 self._unknown_fields.extend(msg._unknown_fields) 1343 # pylint: disable=protected-access 1344 if self._unknown_field_set is None: 1345 self._unknown_field_set = containers.UnknownFieldSet() 1346 self._unknown_field_set._extend(msg._unknown_field_set) 1347 1348 cls.MergeFrom = MergeFrom 1349 1350 1351def _AddWhichOneofMethod(message_descriptor, cls): 1352 def WhichOneof(self, oneof_name): 1353 """Returns the name of the currently set field inside a oneof, or None.""" 1354 try: 1355 field = message_descriptor.oneofs_by_name[oneof_name] 1356 except KeyError: 1357 raise ValueError( 1358 'Protocol message has no oneof "%s" field.' % oneof_name) 1359 1360 nested_field = self._oneofs.get(field, None) 1361 if nested_field is not None and self.HasField(nested_field.name): 1362 return nested_field.name 1363 else: 1364 return None 1365 1366 cls.WhichOneof = WhichOneof 1367 1368 1369def _Clear(self): 1370 # Clear fields. 
1371 self._fields = {} 1372 self._unknown_fields = () 1373 # pylint: disable=protected-access 1374 if self._unknown_field_set is not None: 1375 self._unknown_field_set._clear() 1376 self._unknown_field_set = None 1377 1378 self._oneofs = {} 1379 self._Modified() 1380 1381 1382def _UnknownFields(self): 1383 if self._unknown_field_set is None: # pylint: disable=protected-access 1384 # pylint: disable=protected-access 1385 self._unknown_field_set = containers.UnknownFieldSet() 1386 return self._unknown_field_set # pylint: disable=protected-access 1387 1388 1389def _DiscardUnknownFields(self): 1390 self._unknown_fields = [] 1391 self._unknown_field_set = None # pylint: disable=protected-access 1392 for field, value in self.ListFields(): 1393 if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: 1394 if _IsMapField(field): 1395 if _IsMessageMapField(field): 1396 for key in value: 1397 value[key].DiscardUnknownFields() 1398 elif field.label == _FieldDescriptor.LABEL_REPEATED: 1399 for sub_message in value: 1400 sub_message.DiscardUnknownFields() 1401 else: 1402 value.DiscardUnknownFields() 1403 1404 1405def _SetListener(self, listener): 1406 if listener is None: 1407 self._listener = message_listener_mod.NullMessageListener() 1408 else: 1409 self._listener = listener 1410 1411 1412def _AddMessageMethods(message_descriptor, cls): 1413 """Adds implementations of all Message methods to cls.""" 1414 _AddListFieldsMethod(message_descriptor, cls) 1415 _AddHasFieldMethod(message_descriptor, cls) 1416 _AddClearFieldMethod(message_descriptor, cls) 1417 if message_descriptor.is_extendable: 1418 _AddClearExtensionMethod(cls) 1419 _AddHasExtensionMethod(cls) 1420 _AddEqualsMethod(message_descriptor, cls) 1421 _AddStrMethod(message_descriptor, cls) 1422 _AddReprMethod(message_descriptor, cls) 1423 _AddUnicodeMethod(message_descriptor, cls) 1424 _AddByteSizeMethod(message_descriptor, cls) 1425 _AddSerializeToStringMethod(message_descriptor, cls) 1426 
_AddSerializePartialToStringMethod(message_descriptor, cls) 1427 _AddMergeFromStringMethod(message_descriptor, cls) 1428 _AddIsInitializedMethod(message_descriptor, cls) 1429 _AddMergeFromMethod(cls) 1430 _AddWhichOneofMethod(message_descriptor, cls) 1431 # Adds methods which do not depend on cls. 1432 cls.Clear = _Clear 1433 cls.UnknownFields = _UnknownFields 1434 cls.DiscardUnknownFields = _DiscardUnknownFields 1435 cls._SetListener = _SetListener 1436 1437 1438def _AddPrivateHelperMethods(message_descriptor, cls): 1439 """Adds implementation of private helper methods to cls.""" 1440 1441 def Modified(self): 1442 """Sets the _cached_byte_size_dirty bit to true, 1443 and propagates this to our listener iff this was a state change. 1444 """ 1445 1446 # Note: Some callers check _cached_byte_size_dirty before calling 1447 # _Modified() as an extra optimization. So, if this method is ever 1448 # changed such that it does stuff even when _cached_byte_size_dirty is 1449 # already true, the callers need to be updated. 1450 if not self._cached_byte_size_dirty: 1451 self._cached_byte_size_dirty = True 1452 self._listener_for_children.dirty = True 1453 self._is_present_in_parent = True 1454 self._listener.Modified() 1455 1456 def _UpdateOneofState(self, field): 1457 """Sets field as the active field in its containing oneof. 1458 1459 Will also delete currently active field in the oneof, if it is different 1460 from the argument. Does not mark the message as modified. 1461 """ 1462 other_field = self._oneofs.setdefault(field.containing_oneof, field) 1463 if other_field is not field: 1464 del self._fields[other_field] 1465 self._oneofs[field.containing_oneof] = field 1466 1467 cls._Modified = Modified 1468 cls.SetInParent = Modified 1469 cls._UpdateOneofState = _UpdateOneofState 1470 1471 1472class _Listener(object): 1473 1474 """MessageListener implementation that a parent message registers with its 1475 child message. 
1476 1477 In order to support semantics like: 1478 1479 foo.bar.baz.qux = 23 1480 assert foo.HasField('bar') 1481 1482 ...child objects must have back references to their parents. 1483 This helper class is at the heart of this support. 1484 """ 1485 1486 def __init__(self, parent_message): 1487 """Args: 1488 parent_message: The message whose _Modified() method we should call when 1489 we receive Modified() messages. 1490 """ 1491 # This listener establishes a back reference from a child (contained) object 1492 # to its parent (containing) object. We make this a weak reference to avoid 1493 # creating cyclic garbage when the client finishes with the 'parent' object 1494 # in the tree. 1495 if isinstance(parent_message, weakref.ProxyType): 1496 self._parent_message_weakref = parent_message 1497 else: 1498 self._parent_message_weakref = weakref.proxy(parent_message) 1499 1500 # As an optimization, we also indicate directly on the listener whether 1501 # or not the parent message is dirty. This way we can avoid traversing 1502 # up the tree in the common case. 1503 self.dirty = False 1504 1505 def Modified(self): 1506 if self.dirty: 1507 return 1508 try: 1509 # Propagate the signal to our parents iff this is the first field set. 1510 self._parent_message_weakref._Modified() 1511 except ReferenceError: 1512 # We can get here if a client has kept a reference to a child object, 1513 # and is now setting a field on it, but the child's parent has been 1514 # garbage-collected. This is not an error. 1515 pass 1516 1517 1518class _OneofListener(_Listener): 1519 """Special listener implementation for setting composite oneof fields.""" 1520 1521 def __init__(self, parent_message, field): 1522 """Args: 1523 parent_message: The message whose _Modified() method we should call when 1524 we receive Modified() messages. 1525 field: The descriptor of the field being set in the parent message. 
1526 """ 1527 super(_OneofListener, self).__init__(parent_message) 1528 self._field = field 1529 1530 def Modified(self): 1531 """Also updates the state of the containing oneof in the parent message.""" 1532 try: 1533 self._parent_message_weakref._UpdateOneofState(self._field) 1534 super(_OneofListener, self).Modified() 1535 except ReferenceError: 1536 pass 1537