1# Copyright (c) 2017, 2020, Oracle and/or its affiliates.
2#
3# This program is free software; you can redistribute it and/or modify
4# it under the terms of the GNU General Public License, version 2.0, as
5# published by the Free Software Foundation.
6#
7# This program is also distributed with certain software (including
8# but not limited to OpenSSL) that is licensed under separate terms,
9# as designated in a particular file or component or in included license
10# documentation.  The authors of MySQL hereby grant you an
11# additional permission to link the program and your derivative works
12# with the separately licensed software that they have included with
13# MySQL.
14#
15# Without limiting anything contained in the foregoing, this file,
16# which is part of MySQL Connector/Python, is also subject to the
17# Universal FOSS Exception, version 1.0, a copy of which can be found at
18# http://oss.oracle.com/licenses/universal-foss-exception.
19#
20# This program is distributed in the hope that it will be useful, but
21# WITHOUT ANY WARRANTY; without even the implied warranty of
22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
23# See the GNU General Public License, version 2.0, for more details.
24#
25# You should have received a copy of the GNU General Public License
26# along with this program; if not, write to the Free Software Foundation, Inc.,
27# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301  USA
28
29"""This module contains the implementation of a helper class for MySQL X
30Protobuf messages."""
31
# Python releases older than 3.6 do not define ModuleNotFoundError; alias it
# to ImportError so the name can be referenced unconditionally in this module.
try:
    ModuleNotFoundError
except NameError:
    ModuleNotFoundError = ImportError

# Pairs of (server message type enumerator name, Protobuf message type name).
_SERVER_MESSAGES_TUPLES = (
    ("Mysqlx.ServerMessages.Type.OK",
     "Mysqlx.Ok"),
    ("Mysqlx.ServerMessages.Type.ERROR",
     "Mysqlx.Error"),
    ("Mysqlx.ServerMessages.Type.CONN_CAPABILITIES",
     "Mysqlx.Connection.Capabilities"),
    ("Mysqlx.ServerMessages.Type.SESS_AUTHENTICATE_CONTINUE",
     "Mysqlx.Session.AuthenticateContinue"),
    ("Mysqlx.ServerMessages.Type.SESS_AUTHENTICATE_OK",
     "Mysqlx.Session.AuthenticateOk"),
    ("Mysqlx.ServerMessages.Type.NOTICE",
     "Mysqlx.Notice.Frame"),
    ("Mysqlx.ServerMessages.Type.RESULTSET_COLUMN_META_DATA",
     "Mysqlx.Resultset.ColumnMetaData"),
    ("Mysqlx.ServerMessages.Type.RESULTSET_ROW",
     "Mysqlx.Resultset.Row"),
    ("Mysqlx.ServerMessages.Type.RESULTSET_FETCH_DONE",
     "Mysqlx.Resultset.FetchDone"),
    ("Mysqlx.ServerMessages.Type.RESULTSET_FETCH_SUSPENDED",
     "Mysqlx.Resultset.FetchSuspended"),
    ("Mysqlx.ServerMessages.Type.RESULTSET_FETCH_DONE_MORE_RESULTSETS",
     "Mysqlx.Resultset.FetchDoneMoreResultsets"),
    ("Mysqlx.ServerMessages.Type.SQL_STMT_EXECUTE_OK",
     "Mysqlx.Sql.StmtExecuteOk"),
    ("Mysqlx.ServerMessages.Type.RESULTSET_FETCH_DONE_MORE_OUT_PARAMS",
     "Mysqlx.Resultset.FetchDoneMoreOutParams"),
    ("Mysqlx.ServerMessages.Type.COMPRESSION",
     "Mysqlx.Connection.Compression"),
)

# Protobuf details; both start with placeholder values at import time.
PROTOBUF_VERSION = None
PROTOBUF_REPEATED_TYPES = [list]

# Probe for the C extension. When it imports, the server message lookup table
# is keyed by the integer values the extension reports for each enumerator.
try:
    import _mysqlxpb
    SERVER_MESSAGES = {
        int(_mysqlxpb.enum_value(enum_name)): msg_name
        for enum_name, msg_name in _SERVER_MESSAGES_TUPLES
    }
    HAVE_MYSQLXPB_CEXT = True
except ImportError:
    HAVE_MYSQLXPB_CEXT = False
78
79from ..helpers import BYTE_TYPES, NUMERIC_TYPES, encode_to_bytes
80
# Set up the pure Python implementation. Any failure below is recorded in
# HAVE_PROTOBUF_ERROR and is only fatal when the C extension is missing too.
try:
    from google import protobuf
    from google.protobuf import descriptor_database
    from google.protobuf import descriptor_pb2
    from google.protobuf import descriptor_pool
    from google.protobuf import message_factory
    from google.protobuf.internal.containers import (
        RepeatedCompositeFieldContainer)
    try:
        # The container class of the optional `pyext` implementation is
        # registered only when it can be imported.
        from google.protobuf.pyext._message import (
            RepeatedCompositeContainer)
        PROTOBUF_REPEATED_TYPES.append(RepeatedCompositeContainer)
    except ImportError:
        pass

    PROTOBUF_REPEATED_TYPES.append(RepeatedCompositeFieldContainer)
    if hasattr(protobuf, "__version__"):
        # Only Protobuf versions >=3.0.0 provide `__version__`
        PROTOBUF_VERSION = protobuf.__version__

    # Imported after PROTOBUF_VERSION is set: the `except` clause at the end
    # of this block picks its error message based on that value.
    from . import mysqlx_connection_pb2
    from . import mysqlx_crud_pb2
    from . import mysqlx_cursor_pb2
    from . import mysqlx_datatypes_pb2
    from . import mysqlx_expect_pb2
    from . import mysqlx_expr_pb2
    from . import mysqlx_notice_pb2
    from . import mysqlx_pb2
    from . import mysqlx_prepare_pb2
    from . import mysqlx_resultset_pb2
    from . import mysqlx_session_pb2
    from . import mysqlx_sql_pb2

    # Dictionary mapping fully qualified enumerator names to their values
    _MESSAGES = {}

    # Mysqlx
    for key, val in mysqlx_pb2.ClientMessages.Type.items():
        _MESSAGES["Mysqlx.ClientMessages.Type.{0}".format(key)] = val
    for key, val in mysqlx_pb2.ServerMessages.Type.items():
        _MESSAGES["Mysqlx.ServerMessages.Type.{0}".format(key)] = val
    for key, val in mysqlx_pb2.Error.Severity.items():
        _MESSAGES["Mysqlx.Error.Severity.{0}".format(key)] = val

    # Mysqlx.Crud
    for key, val in mysqlx_crud_pb2.DataModel.items():
        _MESSAGES["Mysqlx.Crud.DataModel.{0}".format(key)] = val
    for key, val in mysqlx_crud_pb2.Find.RowLock.items():
        _MESSAGES["Mysqlx.Crud.Find.RowLock.{0}".format(key)] = val
    for key, val in mysqlx_crud_pb2.Order.Direction.items():
        _MESSAGES["Mysqlx.Crud.Order.Direction.{0}".format(key)] = val
    for key, val in mysqlx_crud_pb2.UpdateOperation.UpdateType.items():
        _MESSAGES["Mysqlx.Crud.UpdateOperation.UpdateType.{0}"
                  "".format(key)] = val

    # Mysqlx.Datatypes
    for key, val in mysqlx_datatypes_pb2.Scalar.Type.items():
        _MESSAGES["Mysqlx.Datatypes.Scalar.Type.{0}".format(key)] = val
    for key, val in mysqlx_datatypes_pb2.Any.Type.items():
        _MESSAGES["Mysqlx.Datatypes.Any.Type.{0}".format(key)] = val

    # Mysqlx.Expect
    for key, val in mysqlx_expect_pb2.Open.Condition.ConditionOperation.items():
        _MESSAGES["Mysqlx.Expect.Open.Condition.ConditionOperation.{0}"
                  "".format(key)] = val
    for key, val in mysqlx_expect_pb2.Open.Condition.Key.items():
        _MESSAGES["Mysqlx.Expect.Open.Condition.Key.{0}"
                  "".format(key)] = val
    for key, val in mysqlx_expect_pb2.Open.CtxOperation.items():
        _MESSAGES["Mysqlx.Expect.Open.CtxOperation.{0}".format(key)] = val

    # Mysqlx.Expr
    for key, val in mysqlx_expr_pb2.Expr.Type.items():
        _MESSAGES["Mysqlx.Expr.Expr.Type.{0}".format(key)] = val
    for key, val in mysqlx_expr_pb2.DocumentPathItem.Type.items():
        _MESSAGES["Mysqlx.Expr.DocumentPathItem.Type.{0}".format(key)] = val

    # Mysqlx.Notice
    for key, val in mysqlx_notice_pb2.Frame.Scope.items():
        _MESSAGES["Mysqlx.Notice.Frame.Scope.{0}".format(key)] = val
    for key, val in mysqlx_notice_pb2.Warning.Level.items():
        _MESSAGES["Mysqlx.Notice.Warning.Level.{0}".format(key)] = val
    for key, val in mysqlx_notice_pb2.SessionStateChanged.Parameter.items():
        _MESSAGES["Mysqlx.Notice.SessionStateChanged.Parameter.{0}"
                  "".format(key)] = val

    # Mysqlx.Prepare
    for key, val in mysqlx_prepare_pb2.Prepare.OneOfMessage.Type.items():
        _MESSAGES["Mysqlx.Prepare.Prepare.OneOfMessage.Type.{0}"
                  "".format(key)] = val

    # Mysqlx.Resultset
    for key, val in mysqlx_resultset_pb2.ColumnMetaData.FieldType.items():
        _MESSAGES["Mysqlx.Resultset.ColumnMetaData.FieldType.{0}"
                  "".format(key)] = val

    # Add messages to the descriptor pool
    _DESCRIPTOR_DB = descriptor_database.DescriptorDatabase()
    _DESCRIPTOR_POOL = descriptor_pool.DescriptorPool(_DESCRIPTOR_DB)

    _DESCRIPTOR_DB.Add(descriptor_pb2.FileDescriptorProto.FromString(
        mysqlx_connection_pb2.DESCRIPTOR.serialized_pb))
    _DESCRIPTOR_DB.Add(descriptor_pb2.FileDescriptorProto.FromString(
        mysqlx_crud_pb2.DESCRIPTOR.serialized_pb))
    _DESCRIPTOR_DB.Add(descriptor_pb2.FileDescriptorProto.FromString(
        mysqlx_cursor_pb2.DESCRIPTOR.serialized_pb))
    _DESCRIPTOR_DB.Add(descriptor_pb2.FileDescriptorProto.FromString(
        mysqlx_datatypes_pb2.DESCRIPTOR.serialized_pb))
    _DESCRIPTOR_DB.Add(descriptor_pb2.FileDescriptorProto.FromString(
        mysqlx_expect_pb2.DESCRIPTOR.serialized_pb))
    _DESCRIPTOR_DB.Add(descriptor_pb2.FileDescriptorProto.FromString(
        mysqlx_expr_pb2.DESCRIPTOR.serialized_pb))
    _DESCRIPTOR_DB.Add(descriptor_pb2.FileDescriptorProto.FromString(
        mysqlx_notice_pb2.DESCRIPTOR.serialized_pb))
    _DESCRIPTOR_DB.Add(descriptor_pb2.FileDescriptorProto.FromString(
        mysqlx_pb2.DESCRIPTOR.serialized_pb))
    _DESCRIPTOR_DB.Add(descriptor_pb2.FileDescriptorProto.FromString(
        mysqlx_prepare_pb2.DESCRIPTOR.serialized_pb))
    _DESCRIPTOR_DB.Add(descriptor_pb2.FileDescriptorProto.FromString(
        mysqlx_resultset_pb2.DESCRIPTOR.serialized_pb))
    _DESCRIPTOR_DB.Add(descriptor_pb2.FileDescriptorProto.FromString(
        mysqlx_session_pb2.DESCRIPTOR.serialized_pb))
    _DESCRIPTOR_DB.Add(descriptor_pb2.FileDescriptorProto.FromString(
        mysqlx_sql_pb2.DESCRIPTOR.serialized_pb))

    # Server message lookup table keyed by the enumerator values collected
    # above; this replaces any table built from the C extension.
    SERVER_MESSAGES = dict(
        [(_MESSAGES[key], val) for key, val in _SERVER_MESSAGES_TUPLES]
    )
    HAVE_PROTOBUF = True
    HAVE_PROTOBUF_ERROR = None

    class _mysqlxpb_pure(object):
        """This class implements the methods in pure Python used by the
        _mysqlxpb C++ extension."""

        factory = message_factory.MessageFactory()

        @staticmethod
        def new_message(name):
            """Returns a new, empty message of the type called `name`.

            Args:
                name (string): Fully qualified Protobuf message type name.

            Returns:
                object: A new instance of the message class.
            """
            descriptor = _DESCRIPTOR_POOL.FindMessageTypeByName(name)
            # Newer Protobuf releases deprecate and later remove
            # `MessageFactory.GetPrototype()` in favor of the module-level
            # `GetMessageClass()`; use the latter whenever it is available.
            get_class = getattr(message_factory, "GetMessageClass", None)
            if get_class is None:
                get_class = _mysqlxpb_pure.factory.GetPrototype
            return get_class(descriptor)()

        @staticmethod
        def enum_value(key):
            """Returns the value registered for the enumerator name `key`."""
            return _MESSAGES[key]

        @staticmethod
        def serialize_message(msg):
            """Returns the result of `msg.SerializeToString()`."""
            return msg.SerializeToString()

        @staticmethod
        def serialize_partial_message(msg):
            """Returns the result of `msg.SerializePartialToString()`."""
            return msg.SerializePartialToString()

        @staticmethod
        def parse_message(msg_type_name, payload):
            """Creates a message of type `msg_type_name` and parses `payload`
            into it.

            Args:
                msg_type_name (string): Fully qualified message type name.
                payload (bytes): Serialized message data.

            Returns:
                object: The message populated from `payload`.
            """
            msg = _mysqlxpb_pure.new_message(msg_type_name)
            msg.ParseFromString(payload)
            return msg

        @staticmethod
        def parse_server_message(msg_type, payload):
            """Parses `payload` as the server message identified by
            `msg_type`.

            Args:
                msg_type (int): Server message type value.
                payload (bytes): Serialized message data.

            Returns:
                object: The message populated from `payload`.

            Raises:
                ValueError: If `msg_type` is not in `SERVER_MESSAGES`.
            """
            msg_type_name = SERVER_MESSAGES.get(msg_type)
            if not msg_type_name:
                raise ValueError("Unknown msg_type: {0}".format(msg_type))
            return _mysqlxpb_pure.parse_message(msg_type_name, payload)
except (ImportError, ModuleNotFoundError, SyntaxError, TypeError) as err:
    HAVE_PROTOBUF = False
    # PROTOBUF_VERSION is still None when the `google.protobuf` imports
    # failed or the package has no `__version__` attribute.
    HAVE_PROTOBUF_ERROR = err if PROTOBUF_VERSION is not None \
        else "Protobuf >=3.0.0 is required"
    if not HAVE_MYSQLXPB_CEXT:
        raise ImportError("Protobuf is not available: {}"
                          "".format(HAVE_PROTOBUF_ERROR))
255
# Maps a client message type enumerator name to a pair made of the
# `Prepare.OneOfMessage` type enumerator name and a short lowercase name
# (presumably the matching `OneOfMessage` field -- verify against callers).
CRUD_PREPARE_MAPPING = {
    "Mysqlx.ClientMessages.Type.CRUD_FIND": (
        "Mysqlx.Prepare.Prepare.OneOfMessage.Type.FIND",
        "find",
    ),
    "Mysqlx.ClientMessages.Type.CRUD_INSERT": (
        "Mysqlx.Prepare.Prepare.OneOfMessage.Type.INSERT",
        "insert",
    ),
    "Mysqlx.ClientMessages.Type.CRUD_UPDATE": (
        "Mysqlx.Prepare.Prepare.OneOfMessage.Type.UPDATE",
        "update",
    ),
    "Mysqlx.ClientMessages.Type.CRUD_DELETE": (
        "Mysqlx.Prepare.Prepare.OneOfMessage.Type.DELETE",
        "delete",
    ),
    "Mysqlx.ClientMessages.Type.SQL_STMT_EXECUTE": (
        "Mysqlx.Prepare.Prepare.OneOfMessage.Type.STMT",
        "stmt_execute",
    ),
}
268
269
class Protobuf(object):
    """Holder of the message handler implementation currently in use.

    The `mysqlxpb` class attribute points either to the C extension or to the
    pure Python implementation, and `use_pure` tells which of the two it is.
    The C extension is selected initially whenever it is available.
    """
    # Only the branch that is taken is evaluated, so the handler that is not
    # available is never referenced.
    if HAVE_MYSQLXPB_CEXT:
        mysqlxpb = _mysqlxpb
        use_pure = False
    else:
        mysqlxpb = _mysqlxpb_pure
        use_pure = True

    @staticmethod
    def set_use_pure(use_pure):
        """Selects the C extension or the pure Python implementation.

        Args:
            use_pure (bool): `True` to use pure Python implementation.

        Raises:
            ImportError: If the requested implementation is not available.
        """
        if use_pure:
            if not HAVE_PROTOBUF:
                raise ImportError("Protobuf is not available: {}"
                                  "".format(HAVE_PROTOBUF_ERROR))
            handler = _mysqlxpb_pure
        else:
            if not HAVE_MYSQLXPB_CEXT:
                raise ImportError(
                    "MySQL X Protobuf C extension is not available")
            handler = _mysqlxpb
        Protobuf.mysqlxpb = handler
        Protobuf.use_pure = use_pure
292
293
class Message(object):
    """Helper class for interfacing with the MySQL X Protobuf extension.

    The wrapped message is stored in the instance dictionary under `_msg`;
    attribute and item access on this object is forwarded to it.

    Args:
        msg_type_name (string): Protobuf type name.
        **kwargs: Arbitrary keyword arguments with values for the message.
    """
    def __init__(self, msg_type_name=None, **kwargs):
        # Written through `__dict__` so the overridden `__setattr__`, which
        # forwards to the wrapped message, is bypassed.
        self.__dict__["_msg"] = Protobuf.mysqlxpb.new_message(msg_type_name) \
            if msg_type_name else None
        for key, value in kwargs.items():
            setattr(self, key, value)

    def __setattr__(self, name, value):
        """Sets the field `name` of the wrapped message to `value`.

        With the pure Python implementation the assignment depends on the
        type of `value`; otherwise the value is stored by key.
        """
        if Protobuf.use_pure:
            if isinstance(value, str):
                setattr(self._msg, name, encode_to_bytes(value))
            elif isinstance(value, (NUMERIC_TYPES, BYTE_TYPES)):
                setattr(self._msg, name, value)
            elif isinstance(value, list):
                getattr(self._msg, name).extend(value)
            elif isinstance(value, Message):
                getattr(self._msg, name).MergeFrom(value.get_message())
            else:
                getattr(self._msg, name).MergeFrom(value)
        else:
            self._msg[name] = value.get_message() \
                if isinstance(value, Message) else value

    def __getattr__(self, name):
        """Returns the field `name` of the wrapped message.

        Raises:
            AttributeError: If the field does not exist, or if the wrapped
                            message itself has not been set on this object.
        """
        if name == "_msg":
            # `_msg` lives in the instance dictionary, so this point is only
            # reached when it is missing (for example on an object created
            # without running `__init__`). Reading `self._msg` below would
            # then call this method again without end.
            raise AttributeError(name)
        msg = self._msg
        if Protobuf.use_pure:
            return getattr(msg, name)
        try:
            return msg[name]
        except KeyError:
            raise AttributeError(name)

    def __setitem__(self, name, value):
        self.__setattr__(name, value)

    def __getitem__(self, name):
        return self.__getattr__(name)

    def get(self, name, default=None):
        """Returns the value of an element of the message dictionary.

        Args:
            name (string): Key name.
            default (object): The default value if the key does not exist.

        Returns:
            object: The value of the provided key name.
        """
        return self.__dict__["_msg"].get(name, default) \
            if not Protobuf.use_pure \
               else getattr(self.__dict__["_msg"], name, default)

    def set_message(self, msg):
        """Sets the message.

        Args:
            msg (object): The message to wrap, replacing the current one.
        """
        self.__dict__["_msg"] = msg

    def get_message(self):
        """Returns the wrapped message.

        Returns:
            object: The message stored by the constructor or by
                    `set_message()`; `None` if there is none.
        """
        return self.__dict__["_msg"]

    def serialize_to_string(self):
        """Serializes the message using the active implementation.

        Returns:
            object: The serialized message as returned by the
                    `serialize_message()` function of the implementation.
        """
        return Protobuf.mysqlxpb.serialize_message(self._msg)

    def serialize_partial_to_string(self):
        """Serializes the protocol message to a binary string.

        This method is similar to serialize_to_string but doesn't check if the
        message is initialized.

        Returns:
            object: The serialized message as returned by the
                    `serialize_partial_message()` function of the
                    implementation.
        """
        return Protobuf.mysqlxpb.serialize_partial_message(self._msg)

    @property
    def type(self):
        """string: Message type name."""
        return self._msg["_mysqlxpb_type_name"] if not Protobuf.use_pure \
            else self._msg.DESCRIPTOR.full_name

    @staticmethod
    def parse(msg_type_name, payload):
        """Creates a new message, initialized with parsed data.

        Args:
            msg_type_name (string): Message type name.
            payload (string): Serialized message data.

        Returns:
            object: The parsed message as returned by the active
                    implementation (not wrapped in a `Message`).

        .. versionadded:: 8.0.21
        """
        return Protobuf.mysqlxpb.parse_message(msg_type_name, payload)

    @staticmethod
    def byte_size(msg):
        """Returns the size of the message in bytes.

        Args:
            msg (mysqlx.protobuf.Message): MySQL X Protobuf Message.

        Returns:
            int: Size of the message in bytes.

        .. versionadded:: 8.0.21
        """
        # The pure Python message reports its own size; otherwise the size
        # is the length of the serialized message.
        return msg.ByteSize() if Protobuf.use_pure \
            else len(encode_to_bytes(msg.serialize_to_string()))

    @staticmethod
    def parse_from_server(msg_type, payload):
        """Creates a new server-side message, initialized with parsed data.

        Args:
            msg_type (int): Message type.
            payload (string): Serialized message data.

        Returns:
            object: The parsed message as returned by the active
                    implementation (not wrapped in a `Message`).
        """
        return Protobuf.mysqlxpb.parse_server_message(msg_type, payload)

    @classmethod
    def from_message(cls, msg_type_name, payload):
        """Creates a new message, initialized with parsed data and returns a
        :class:`mysqlx.protobuf.Message` object.

        Args:
            msg_type_name (string): Message type name.
            payload (string): Serialized message data.

        Returns:
            mysqlx.protobuf.Message: The Message representing a message
                                     containing parsed data.
        """
        msg = cls()
        msg.set_message(Protobuf.mysqlxpb.parse_message(msg_type_name, payload))
        return msg

    @classmethod
    def from_server_message(cls, msg_type, payload):
        """Creates a new server-side message, initialized with parsed data and
        returns a :class:`mysqlx.protobuf.Message` object.

        Args:
            msg_type (int): Message type.
            payload (string): Serialized message data.

        Returns:
            mysqlx.protobuf.Message: The Message representing a message
                                     containing parsed data.
        """
        msg = cls()
        msg.set_message(
            Protobuf.mysqlxpb.parse_server_message(msg_type, payload))
        return msg
469
470
def mysqlxpb_enum(name):
    """Looks up a MySQL X Protobuf enumerator by name.

    The lookup is delegated to the implementation currently selected in
    `Protobuf.mysqlxpb`.

    Args:
        name (string): MySQL X Protobuf enumerator name.

    Returns:
        int: Value of the enumerator.
    """
    handler = Protobuf.mysqlxpb
    return handler.enum_value(name)
481