1# Protocol Buffers - Google's data interchange format
2# Copyright 2008 Google Inc.  All rights reserved.
3# https://developers.google.com/protocol-buffers/
4#
5# Redistribution and use in source and binary forms, with or without
6# modification, are permitted provided that the following conditions are
7# met:
8#
9#     * Redistributions of source code must retain the above copyright
10# notice, this list of conditions and the following disclaimer.
11#     * Redistributions in binary form must reproduce the above
12# copyright notice, this list of conditions and the following disclaimer
13# in the documentation and/or other materials provided with the
14# distribution.
15#     * Neither the name of Google Inc. nor the names of its
16# contributors may be used to endorse or promote products derived from
17# this software without specific prior written permission.
18#
19# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
23# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
24# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
25# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
26# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
27# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
28# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
31"""A database of Python protocol buffer generated symbols.
32
33SymbolDatabase is the MessageFactory for messages generated at compile time,
34and makes it easy to create new instances of a registered type, given only the
35type's protocol buffer symbol name.
36
37Example usage:
38
39  db = symbol_database.SymbolDatabase()
40
41  # Register symbols of interest, from one or multiple files.
42  db.RegisterFileDescriptor(my_proto_pb2.DESCRIPTOR)
43  db.RegisterMessage(my_proto_pb2.MyMessage)
44  db.RegisterEnumDescriptor(my_proto_pb2.MyEnum.DESCRIPTOR)
45
46  # The database can be used as a MessageFactory, to generate types based on
47  # their name:
48  types = db.GetMessages(['my_proto.proto'])
49  my_message_instance = types['MyMessage']()
50
51  # The database's underlying descriptor pool can be queried, so it's not
52  # necessary to know a type's filename to be able to generate it:
53  filename = db.pool.FindFileContainingSymbol('MyMessage')
54  my_message_instance = db.GetMessages([filename])['MyMessage']()
55
56  # This functionality is also provided directly via a convenience method:
57  my_message_instance = db.GetSymbol('MyMessage')()
58"""
59
60
61from google.protobuf import descriptor_pool
62from google.protobuf import message_factory
63
64
class SymbolDatabase(message_factory.MessageFactory):
  """A database of Python generated symbols.

  Registered message classes are stored in ``self._classes`` keyed by their
  message descriptor, and descriptors of every kind are added to
  ``self.pool``. (Both attributes are presumably set up by the
  MessageFactory base class; confirm there.)
  """

  def RegisterMessage(self, message):
    """Registers the given message type in the local database.

    Calls to GetSymbol() and GetMessages() will return messages registered here.

    Args:
      message: a generated message class (a message.Message subclass) to be
        registered; its DESCRIPTOR attribute is used as the lookup key.

    Returns:
      The provided message class, so this can be used as a decorator-style
      pass-through.
    """
    desc = message.DESCRIPTOR
    self._classes[desc] = message
    self.RegisterMessageDescriptor(desc)
    return message

  def RegisterMessageDescriptor(self, message_descriptor):
    """Registers the given message descriptor in the local database.

    Args:
      message_descriptor: a descriptor.MessageDescriptor.

    Returns:
      The provided descriptor.
    """
    self.pool.AddDescriptor(message_descriptor)
    return message_descriptor

  def RegisterEnumDescriptor(self, enum_descriptor):
    """Registers the given enum descriptor in the local database.

    Args:
      enum_descriptor: a descriptor.EnumDescriptor.

    Returns:
      The provided descriptor.
    """
    self.pool.AddEnumDescriptor(enum_descriptor)
    return enum_descriptor

  def RegisterServiceDescriptor(self, service_descriptor):
    """Registers the given service descriptor in the local database.

    Args:
      service_descriptor: a descriptor.ServiceDescriptor.

    Returns:
      The provided descriptor.
    """
    self.pool.AddServiceDescriptor(service_descriptor)
    # Previously returned None despite the documented contract above.
    return service_descriptor

  def RegisterFileDescriptor(self, file_descriptor):
    """Registers the given file descriptor in the local database.

    Args:
      file_descriptor: a descriptor.FileDescriptor.

    Returns:
      The provided descriptor.
    """
    self.pool.AddFileDescriptor(file_descriptor)
    # Previously returned None despite the documented contract above.
    return file_descriptor

  def GetSymbol(self, symbol):
    """Tries to find a symbol in the local database.

    Currently, this method only returns message classes (message.Message
    subclasses); however, it may be extended in future to support other
    symbol types.

    Args:
      symbol: A str, a fully-qualified protocol buffer symbol name.

    Returns:
      A Python class corresponding to the symbol.

    Raises:
      KeyError: if the symbol could not be found, or if its descriptor is
        known to the pool but no class was registered for it.
    """
    return self._classes[self.pool.FindMessageTypeByName(symbol)]

  def GetMessages(self, files):
    """Gets all registered messages from the specified files.

    Only messages whose classes have already been created and registered are
    returned (this is the case for imported _pb2 modules). Unlike
    MessageFactory, this version also returns already-defined nested
    messages, but does not register any message extensions.

    Args:
      files: The file names to extract messages from.

    Returns:
      A dictionary mapping fully-qualified proto names to the message classes.

    Raises:
      KeyError: if a file could not be found.
    """
    # TODO(amauryfa): Fix the differences with MessageFactory.

    def _GetAllMessages(desc):
      """Walks a message Descriptor and recursively yields it and all nested
      message descriptors."""
      yield desc
      for msg_desc in desc.nested_types:
        # Explicit loop rather than `yield from` keeps Python 2 compatibility.
        for nested_desc in _GetAllMessages(msg_desc):
          yield nested_desc

    result = {}
    for file_name in files:
      file_desc = self.pool.FindFileByName(file_name)
      for msg_desc in file_desc.message_types_by_name.values():
        for desc in _GetAllMessages(msg_desc):
          try:
            result[desc.full_name] = self._classes[desc]
          except KeyError:
            # This descriptor has no registered class, skip it.
            pass
    return result
182
183
# Module-level database returned by Default(); it is built over the shared
# pool returned by descriptor_pool.Default().
_DEFAULT = SymbolDatabase(pool=descriptor_pool.Default())


def Default():
  """Returns the default SymbolDatabase.

  Returns:
    The module-level SymbolDatabase created at import time over
    descriptor_pool.Default(); every call returns the same object.
  """
  return _DEFAULT
190