1# Protocol Buffers - Google's data interchange format
2# Copyright 2008 Google Inc.  All rights reserved.
3# https://developers.google.com/protocol-buffers/
4#
5# Redistribution and use in source and binary forms, with or without
6# modification, are permitted provided that the following conditions are
7# met:
8#
9#     * Redistributions of source code must retain the above copyright
10# notice, this list of conditions and the following disclaimer.
11#     * Redistributions in binary form must reproduce the above
12# copyright notice, this list of conditions and the following disclaimer
13# in the documentation and/or other materials provided with the
14# distribution.
15#     * Neither the name of Google Inc. nor the names of its
16# contributors may be used to endorse or promote products derived from
17# this software without specific prior written permission.
18#
19# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
23# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
24# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
25# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
26# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
27# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
28# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
31"""Contains _ExtensionDict class to represent extensions.
32"""
33
34from google.protobuf.internal import type_checkers
35from google.protobuf.descriptor import FieldDescriptor
36
37
def _VerifyExtensionHandle(message, extension_handle):
  """Raises KeyError unless extension_handle is a valid extension of message.

  Args:
    message: Message instance whose DESCRIPTOR the extension must extend.
    extension_handle: Candidate extension handle; must be a FieldDescriptor.

  Raises:
    KeyError: If the handle is not a FieldDescriptor, is not an extension,
      has no containing_type, or its containing_type is not the very same
      object as message.DESCRIPTOR.
  """
  if not isinstance(extension_handle, FieldDescriptor):
    raise KeyError(
        'HasExtension() expects an extension handle, got: %s' %
        extension_handle)

  if not extension_handle.is_extension:
    raise KeyError(
        '"%s" is not an extension.' % extension_handle.full_name)

  extended_type = extension_handle.containing_type
  if not extended_type:
    raise KeyError(
        '"%s" is missing a containing_type.' % extension_handle.full_name)

  # Identity comparison: the extension has to extend this exact descriptor.
  if extended_type is message.DESCRIPTOR:
    return

  names = (extension_handle.full_name,
           extended_type.full_name,
           message.DESCRIPTOR.full_name)
  raise KeyError(
      'Extension "%s" extends message type "%s", but this message is of '
      'type "%s".' % names)
58
59
60# TODO(robinson): Unify error handling of "unknown extension" crap.
61# TODO(robinson): Support iteritems()-style iteration over all
62# extensions with the "has" bits turned on?
class _ExtensionDict(object):

  """Dict-like container for Extension fields on proto instances.

  Note that in all cases we expect extension handles to be
  FieldDescriptors.
  """

  def __init__(self, extended_message):
    """Initializes the extension dict.

    Args:
      extended_message: Message instance for which we are the Extensions dict.
    """
    self._extended_message = extended_message

  def __getitem__(self, extension_handle):
    """Returns the current value of the given extension handle.

    A value already stored in the message's _fields is returned as is.
    Otherwise a repeated extension gets the object built by the descriptor's
    _default_constructor and a message-typed extension gets a new instance of
    its concrete class; either one is stored in _fields and returned.  For a
    singular scalar the descriptor's default_value is returned and nothing is
    stored.

    Args:
      extension_handle: FieldDescriptor of the extension to read.

    Returns:
      The stored value, the newly created object, or the default value.

    Raises:
      KeyError: If _VerifyExtensionHandle rejects extension_handle.
    """

    _VerifyExtensionHandle(self._extended_message, extension_handle)

    result = self._extended_message._fields.get(extension_handle)
    if result is not None:
      return result

    if extension_handle.label == FieldDescriptor.LABEL_REPEATED:
      result = extension_handle._default_constructor(self._extended_message)
    elif extension_handle.cpp_type == FieldDescriptor.CPPTYPE_MESSAGE:
      message_type = extension_handle.message_type
      if not hasattr(message_type, '_concrete_class'):
        # The call result is discarded; presumably GetPrototype() attaches
        # _concrete_class to the descriptor as a side effect -- the assert
        # below checks that it is set afterwards.
        # pylint: disable=protected-access
        self._extended_message._FACTORY.GetPrototype(message_type)
      assert getattr(extension_handle.message_type, '_concrete_class', None), (
          'Uninitialized concrete class found for field %r (message type %r)'
          % (extension_handle.full_name,
             extension_handle.message_type.full_name))
      result = extension_handle.message_type._concrete_class()
      try:
        result._SetListener(self._extended_message._listener_for_children)
      except ReferenceError:
        # Deliberate best effort: the new sub-message is still returned
        # without a listener.  NOTE(review): presumably the listener refers
        # weakly to a parent that is already gone -- confirm.
        pass
    else:
      # Singular scalar -- just return the default without inserting into the
      # dict.
      return extension_handle.default_value

    # Atomically check if another thread has preempted us and, if not, swap
    # in the new object we just created.  If someone has preempted us, we
    # take that object and discard ours.
    # WARNING:  We are relying on setdefault() being atomic.  This is true
    #   in CPython but we haven't investigated others.  This warning appears
    #   in several other locations in this file.
    result = self._extended_message._fields.setdefault(
        extension_handle, result)

    return result

  def __eq__(self, other):
    """Compares the populated extension fields of two extension dicts.

    Args:
      other: Object to compare against.

    Returns:
      False if other is not an instance of this class; otherwise whether the
      two extended messages list equal entries for their extension fields.
    """
    if not isinstance(other, self.__class__):
      return False

    my_fields = self._extended_message.ListFields()
    other_fields = other._extended_message.ListFields()

    # Get rid of non-extension fields.  As in __len__ and __iter__, each
    # ListFields() entry is indexed and the field descriptor is element 0.
    my_fields = [field for field in my_fields if field[0].is_extension]
    other_fields = [field for field in other_fields if field[0].is_extension]

    return my_fields == other_fields

  def __ne__(self, other):
    """Returns the negation of __eq__."""
    return not self == other

  def __len__(self):
    """Returns the number of ListFields() entries that are extensions."""
    fields = self._extended_message.ListFields()
    # Get rid of non-extension fields.
    extension_fields = [field for field in fields if field[0].is_extension]
    return len(extension_fields)

  def __hash__(self):
    """Always raises TypeError: extension dicts are not hashable."""
    raise TypeError('unhashable object')

  # Note that this is only meaningful for non-repeated, scalar extension
  # fields.  Note also that we may have to call _Modified() when we do
  # successfully set a field this way, to set any necessary "has" bits in the
  # ancestors of the extended message.
  def __setitem__(self, extension_handle, value):
    """If extension_handle specifies a non-repeated, scalar extension
    field, sets the value of that field.

    Args:
      extension_handle: FieldDescriptor of the extension to assign.
      value: New value; it is passed through the field's type checker and the
        checker's result is what gets stored.

    Raises:
      KeyError: If _VerifyExtensionHandle rejects extension_handle.
      TypeError: If the extension is repeated or message-typed.
    """

    _VerifyExtensionHandle(self._extended_message, extension_handle)

    if (extension_handle.label == FieldDescriptor.LABEL_REPEATED or
        extension_handle.cpp_type == FieldDescriptor.CPPTYPE_MESSAGE):
      raise TypeError(
          'Cannot assign to extension "%s" because it is a repeated or '
          'composite type.' % extension_handle.full_name)

    # It's slightly wasteful to lookup the type checker each time,
    # but we expect this to be a vanishingly uncommon case anyway.
    type_checker = type_checkers.GetTypeChecker(extension_handle)
    # pylint: disable=protected-access
    self._extended_message._fields[extension_handle] = (
        type_checker.CheckValue(value))
    self._extended_message._Modified()

  def __delitem__(self, extension_handle):
    """Clears the extension by delegating to the message's ClearExtension()."""
    self._extended_message.ClearExtension(extension_handle)

  def _FindExtensionByName(self, name):
    """Tries to find a known extension with the specified name.

    Args:
      name: Extension full name.

    Returns:
      Extension field descriptor, or None if the name is not registered.
    """
    return self._extended_message._extensions_by_name.get(name, None)

  def _FindExtensionByNumber(self, number):
    """Tries to find a known extension with the field number.

    Args:
      number: Extension field number.

    Returns:
      Extension field descriptor, or None if the number is not registered.
    """
    return self._extended_message._extensions_by_number.get(number, None)

  def __iter__(self):
    """Returns a generator over the descriptors of populated extensions."""
    # Return a generator over the populated extension fields
    return (f[0] for f in self._extended_message.ListFields()
            if f[0].is_extension)

  def __contains__(self, extension_handle):
    """Reports whether the given extension is populated on the message.

    Args:
      extension_handle: FieldDescriptor of the extension to look up.

    Returns:
      False if the handle is not a key of the message's _fields.  Otherwise,
      for a repeated extension the truthiness of the stored value, for a
      message-typed extension whether the stored value is not None and has
      _is_present_in_parent set, and True for any other extension.

    Raises:
      KeyError: If _VerifyExtensionHandle rejects extension_handle.
    """
    _VerifyExtensionHandle(self._extended_message, extension_handle)

    if extension_handle not in self._extended_message._fields:
      return False

    if extension_handle.label == FieldDescriptor.LABEL_REPEATED:
      return bool(self._extended_message._fields.get(extension_handle))

    if extension_handle.cpp_type == FieldDescriptor.CPPTYPE_MESSAGE:
      value = self._extended_message._fields.get(extension_handle)
      # pylint: disable=protected-access
      return value is not None and value._is_present_in_parent

    return True
214