1# Copyright (c) 2016, 2021, Oracle and/or its affiliates.
2#
3# This program is free software; you can redistribute it and/or modify
4# it under the terms of the GNU General Public License, version 2.0, as
5# published by the Free Software Foundation.
6#
7# This program is also distributed with certain software (including
8# but not limited to OpenSSL) that is licensed under separate terms,
9# as designated in a particular file or component or in included license
10# documentation.  The authors of MySQL hereby grant you an
11# additional permission to link the program and your derivative works
12# with the separately licensed software that they have included with
13# MySQL.
14#
15# Without limiting anything contained in the foregoing, this file,
16# which is part of MySQL Connector/Python, is also subject to the
17# Universal FOSS Exception, version 1.0, a copy of which can be found at
18# http://oss.oracle.com/licenses/universal-foss-exception.
19#
20# This program is distributed in the hope that it will be useful, but
21# WITHOUT ANY WARRANTY; without even the implied warranty of
22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
23# See the GNU General Public License, version 2.0, for more details.
24#
25# You should have received a copy of the GNU General Public License
26# along with this program; if not, write to the Free Software Foundation, Inc.,
27# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301  USA
28
29"""Implementation of the CRUD database objects."""
30
31import json
32import warnings
33
34from .dbdoc import DbDoc
35from .errorcode import (ER_NO_SUCH_TABLE, ER_TABLE_EXISTS_ERROR,
36                        ER_X_CMD_NUM_ARGUMENTS, ER_X_INVALID_ADMIN_COMMAND)
37from .errors import NotSupportedError, OperationalError, ProgrammingError
38from .helpers import deprecated, escape, quote_identifier
39from .statement import (FindStatement, AddStatement, RemoveStatement,
40                        ModifyStatement, SelectStatement, InsertStatement,
41                        DeleteStatement, UpdateStatement,
42                        CreateCollectionIndexStatement)
43
44
45_COUNT_VIEWS_QUERY = ("SELECT COUNT(*) FROM information_schema.views "
46                      "WHERE table_schema = '{0}' AND table_name = '{1}'")
47_COUNT_TABLES_QUERY = ("SELECT COUNT(*) FROM information_schema.tables "
48                       "WHERE table_schema = '{0}' AND table_name = '{1}'")
49_COUNT_SCHEMAS_QUERY = ("SELECT COUNT(*) FROM information_schema.schemata "
50                        "WHERE schema_name = '{0}'")
51_COUNT_QUERY = "SELECT COUNT(*) FROM {0}.{1}"
52_DROP_TABLE_QUERY = "DROP TABLE IF EXISTS {0}.{1}"
53
54
class DatabaseObject(object):
    """Common base for the client-side database objects of this module.

    Stores the owning schema and the object name, and caches the session
    (taken from the schema) and the connection (taken from that session).

    Args:
        schema (mysqlx.Schema): The Schema object.
        name (str): The database object name; a ``bytes`` name is decoded
                    to ``str``.
    """
    def __init__(self, schema, name):
        self._schema = schema
        if isinstance(name, bytes):
            name = name.decode()
        self._name = name
        session = self._schema.get_session()
        self._session = session
        self._connection = session.get_connection()

    @property
    def session(self):
        """:class:`mysqlx.Session`: Session this object was obtained from.
        """
        return self._session

    @property
    def schema(self):
        """:class:`mysqlx.Schema`: Schema that owns this object.
        """
        return self._schema

    @property
    def name(self):
        """str: Name of this database object.
        """
        return self._name

    def get_connection(self):
        """Gives access to the connection cached at construction time.

        Returns:
            mysqlx.connection.Connection: The connection object.
        """
        return self._connection

    def get_session(self):
        """Gives access to the session cached at construction time.

        Returns:
            mysqlx.Session: The Session object.
        """
        return self._session

    def get_schema(self):
        """Gives access to the schema that owns this object.

        Returns:
            mysqlx.Schema: The Schema object.
        """
        return self._schema

    def get_name(self):
        """Gives access to the name of this object.

        Returns:
            str: The name of this database object.
        """
        return self._name

    def exists_in_database(self):
        """Checks whether this object is present on the server.

        Subclasses provide the actual check; the base class has none.

        Returns:
            bool: `True` if object exists in database.

        Raises:
           NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    @deprecated("8.0.12", "Use 'exists_in_database()' method instead")
    def am_i_real(self):
        """Checks whether this object is present on the server.

        Thin alias that forwards to ``exists_in_database()``.

        Returns:
            bool: `True` if object exists in database.

        Raises:
           NotImplementedError: When the subclass does not implement
                                ``exists_in_database()``.

        .. deprecated:: 8.0.12
           Use ``exists_in_database()`` method instead.
        """
        return self.exists_in_database()

    @deprecated("8.0.12", "Use 'get_name()' method instead")
    def who_am_i(self):
        """Gives access to the name of this object.

        Thin alias that forwards to ``get_name()``.

        Returns:
            str: The name of this database object.

        .. deprecated:: 8.0.12
           Use ``get_name()`` method instead.
        """
        return self.get_name()
155
156
class Schema(DatabaseObject):
    """A client-side representation of a database schema. Provides access to
    the schema contents.

    Args:
        session (mysqlx.XSession): Session object.
        name (str): The Schema name.
    """
    def __init__(self, session, name):
        # DatabaseObject.__init__ reads the session through
        # ``schema.get_session()`` and a Schema passes itself as its own
        # schema, so ``_session`` must be set before delegating.
        self._session = session
        super(Schema, self).__init__(self, name)

    @staticmethod
    def _validation_options(validation):
        """Validates a JSON schema ``validation`` dict and converts it into the
        ``options`` value sent with the collection admin commands.

        Args:
            validation (dict): A non-empty dict whose only allowed keys are
                               ``level`` (str) and ``schema`` (str or dict).

        Returns:
            tuple: ``("validation", options)`` where ``options`` is a list of
                   ``(key, value)`` pairs; a dict ``schema`` is serialized to
                   a JSON string, a str ``schema`` is passed through as is.

        Raises:
            :class:`mysqlx.ProgrammingError`: If ``validation`` is not a
                                              non-empty dict, contains an
                                              unknown key, or holds a value
                                              of the wrong type.
        """
        if not isinstance(validation, dict) or not validation:
            raise ProgrammingError("Invalid value for 'validation'")

        valid_options = ("level", "schema")
        for option in validation:
            if option not in valid_options:
                raise ProgrammingError("Invalid option in 'validation': {}"
                                       "".format(option))

        options = []

        if "level" in validation:
            level = validation["level"]
            if not isinstance(level, str):
                raise ProgrammingError("Invalid value for 'level'")
            options.append(("level", level))

        if "schema" in validation:
            schema = validation["schema"]
            if not isinstance(schema, (str, dict)):
                raise ProgrammingError("Invalid value for 'schema'")
            options.append(
                ("schema", json.dumps(schema)
                           if isinstance(schema, dict) else schema))

        return ("validation", options)

    def _list_objects(self, object_types, factory):
        """Builds client-side objects for the schema objects of given types.

        Args:
            object_types (tuple): Values of the ``type`` column (as returned
                                  by ``list_objects``) to keep.
            factory (type): Class instantiated as ``factory(schema, name)``
                            for every kept row.

        Returns:
            `list`: The objects created by ``factory``.
        """
        rows = self._connection.get_row_result("list_objects",
                                               {"schema": self._name})
        rows.fetch_all()
        objects = []
        for row in rows:
            if row["type"] not in object_types:
                continue
            # The object name is read from ``TABLE_NAME``; if that lookup
            # (or the construction) raises ValueError, ``name`` is used.
            try:
                obj = factory(self, row["TABLE_NAME"])
            except ValueError:
                obj = factory(self, row["name"])
            objects.append(obj)
        return objects

    def exists_in_database(self):
        """Verifies if this object exists in the database.

        Returns:
            bool: `True` if object exists in database.
        """
        sql = _COUNT_SCHEMAS_QUERY.format(escape(self._name))
        return self._connection.execute_sql_scalar(sql) == 1

    def get_collections(self):
        """Returns a list of collections for this schema.

        Returns:
            `list`: List of Collection objects.
        """
        return self._list_objects(("COLLECTION",), Collection)

    def get_collection_as_table(self, name, check_existence=False):
        """Returns a table object for the given collection.

        Args:
            name (str): The name of the collection.
            check_existence (bool): `True` to verify that the table exists.

        Returns:
            mysqlx.Table: Table object.

        Raises:
            :class:`mysqlx.ProgrammingError`: If ``check_existence`` is
                                              `True` and the table does not
                                              exist.
        """
        return self.get_table(name, check_existence)

    def get_tables(self):
        """Returns a list of tables (including views) for this schema.

        Returns:
            `list`: List of Table objects.
        """
        return self._list_objects(("TABLE", "VIEW",), Table)

    def get_table(self, name, check_existence=False):
        """Returns the table of the given name for this schema.

        Args:
            name (str): The name of the table.
            check_existence (bool): `True` to verify that the table exists.

        Returns:
            mysqlx.Table: Table object.

        Raises:
            :class:`mysqlx.ProgrammingError`: If ``check_existence`` is
                                              `True` and the table does not
                                              exist.
        """
        table = Table(self, name)
        if check_existence:
            if not table.exists_in_database():
                raise ProgrammingError("Table does not exist")
        return table

    def get_view(self, name, check_existence=False):
        """Returns the view of the given name for this schema.

        Args:
            name (str): The name of the view.
            check_existence (bool): `True` to verify that the view exists.

        Returns:
            mysqlx.View: View object.

        Raises:
            :class:`mysqlx.ProgrammingError`: If ``check_existence`` is
                                              `True` and the view does not
                                              exist.
        """
        view = View(self, name)
        if check_existence:
            if not view.exists_in_database():
                raise ProgrammingError("View does not exist")
        return view

    def get_collection(self, name, check_existence=False):
        """Returns the collection of the given name for this schema.

        Args:
            name (str): The name of the collection.
            check_existence (bool): `True` to verify that the collection
                                    exists.

        Returns:
            mysqlx.Collection: Collection object.

        Raises:
            :class:`mysqlx.ProgrammingError`: If ``check_existence`` is
                                              `True` and the collection does
                                              not exist.
        """
        collection = Collection(self, name)
        if check_existence:
            if not collection.exists_in_database():
                raise ProgrammingError("Collection does not exist")
        return collection

    def drop_collection(self, name):
        """Drops a collection (a ``DROP TABLE IF EXISTS`` on the server).

        Args:
            name (str): The name of the collection to be dropped.
        """
        self._connection.execute_nonquery(
            "sql", _DROP_TABLE_QUERY.format(quote_identifier(self._name),
                                            quote_identifier(name)), False)

    def create_collection(self, name, reuse_existing=False, validation=None,
                          **kwargs):
        """Creates in the current schema a new collection with the specified
        name and retrieves an object representing the new collection created.

        Args:
            name (str): The name of the collection.
            reuse_existing (bool): `True` to reuse an existing collection.
            validation (Optional[dict]): A dict, containing the keys `level`
                                         with the validation level and `schema`
                                         with a dict or a string representation
                                         of a JSON schema specification.
            **kwargs: Only ``reuse`` is recognized: a deprecated alias of
                      ``reuse_existing`` that overrides it when given. Other
                      keyword arguments are ignored.

        Returns:
            mysqlx.Collection: Collection object.

        Raises:
            :class:`mysqlx.ProgrammingError`: If ``reuse_existing`` is False
                                              and collection exists or the
                                              collection name is invalid.
            :class:`mysqlx.NotSupportedError`: If schema validation is not
                                               supported by the server.

        .. versionchanged:: 8.0.21
        """
        if not name:
            raise ProgrammingError("Collection name is invalid")

        if "reuse" in kwargs:
            warnings.warn("'reuse' is deprecated since 8.0.21. "
                          "Please use 'reuse_existing' instead",
                          DeprecationWarning)
            reuse_existing = kwargs["reuse"]

        collection = Collection(self, name)
        fields = {"schema": self._name, "name": name}

        if validation is not None:
            fields["options"] = self._validation_options(validation)

        try:
            self._connection.execute_nonquery(
                "mysqlx", "create_collection", True, fields)
        except OperationalError as err:
            if err.errno == ER_X_CMD_NUM_ARGUMENTS:
                raise NotSupportedError(
                    "Your MySQL server does not support the requested "
                    "operation. Please update to MySQL 8.0.19 or a later "
                    "version")
            if err.errno == ER_TABLE_EXISTS_ERROR:
                # An existing collection is only an error when the caller
                # did not ask to reuse it.
                if not reuse_existing:
                    raise ProgrammingError(
                        "Collection '{}' already exists".format(name))
            else:
                raise ProgrammingError(err.msg, err.errno)

        return collection

    def modify_collection(self, name, validation=None):
        """Modifies a collection using a JSON schema validation.

        Args:
            name (str): The name of the collection.
            validation (Optional[dict]): A dict, containing the keys `level`
                                         with the validation level and `schema`
                                         with a dict or a string representation
                                         of a JSON schema specification.

        Raises:
            :class:`mysqlx.ProgrammingError`: If the collection name or
                                              validation is invalid.
            :class:`mysqlx.NotSupportedError`: If schema validation is not
                                               supported by the server.

        .. versionadded:: 8.0.21
        """
        if not name:
            raise ProgrammingError("Collection name is invalid")

        fields = {
            "schema": self._name,
            "name": name,
            "options": self._validation_options(validation),
        }

        try:
            self._connection.execute_nonquery(
                "mysqlx", "modify_collection_options", True, fields)
        except OperationalError as err:
            if err.errno == ER_X_INVALID_ADMIN_COMMAND:
                raise NotSupportedError(
                    "Your MySQL server does not support the requested "
                    "operation. Please update to MySQL 8.0.19 or a later "
                    "version")
            raise ProgrammingError(err.msg, err.errno)
417
418
class Collection(DatabaseObject):
    """Represents a collection of documents on a schema.

    Args:
        schema (mysqlx.Schema): The Schema object.
        name (str): The collection name.
    """

    @staticmethod
    def _check_replacement_id(doc_id, doc):
        """Rejects a replacement document whose ``_id`` differs from
        ``doc_id``.

        Args:
            doc_id (str): Document ID of the matched document.
            doc (:class:`mysqlx.DbDoc` or `dict`): Replacement document.

        Raises:
            :class:`mysqlx.ProgrammingError`: If ``doc`` has an ``_id`` that
                                              is not ``doc_id``.
        """
        if "_id" in doc and doc["_id"] != doc_id:
            raise ProgrammingError(
                "Replacement document has an _id that is different than the "
                "matched document"
            )

    def exists_in_database(self):
        """Verifies if this object exists in the database.

        Returns:
            bool: `True` if object exists in database.
        """
        sql = _COUNT_TABLES_QUERY.format(escape(self._schema.name),
                                         escape(self._name))
        return self._connection.execute_sql_scalar(sql) == 1

    def find(self, condition=None):
        """Retrieves documents from a collection.

        Args:
            condition (Optional[str]): The string with the filter expression of
                                       the documents to be retrieved.

        Returns:
            mysqlx.FindStatement: FindStatement object.
        """
        stmt = FindStatement(self, condition)
        stmt.stmt_id = self._connection.get_next_statement_id()
        return stmt

    def add(self, *values):
        """Adds a list of documents to a collection.

        Args:
            *values: The document list to be added into the collection.

        Returns:
            mysqlx.AddStatement: AddStatement object.
        """
        return AddStatement(self).add(*values)

    def remove(self, condition):
        """Removes documents based on the ``condition``.

        Args:
            condition (str): The string with the filter expression of the
                             documents to be removed.

        Returns:
            mysqlx.RemoveStatement: RemoveStatement object.

        .. versionchanged:: 8.0.12
           The ``condition`` parameter is now mandatory.
        """
        stmt = RemoveStatement(self, condition)
        stmt.stmt_id = self._connection.get_next_statement_id()
        return stmt

    def modify(self, condition):
        """Modifies documents based on the ``condition``.

        Args:
            condition (str): The string with the filter expression of the
                             documents to be modified.

        Returns:
            mysqlx.ModifyStatement: ModifyStatement object.

        .. versionchanged:: 8.0.12
           The ``condition`` parameter is now mandatory.
        """
        stmt = ModifyStatement(self, condition)
        stmt.stmt_id = self._connection.get_next_statement_id()
        return stmt

    def count(self):
        """Counts the documents in the collection.

        Returns:
            int: The total of documents in the collection.

        Raises:
            :class:`mysqlx.OperationalError`: If the collection does not
                                              exist, or the query fails.
        """
        sql = _COUNT_QUERY.format(quote_identifier(self._schema.name),
                                  quote_identifier(self._name))
        try:
            res = self._connection.execute_sql_scalar(sql)
        except OperationalError as err:
            if err.errno == ER_NO_SUCH_TABLE:
                # Replace the server message with one naming the collection.
                raise OperationalError(
                    "Collection '{}' does not exist in schema '{}'"
                    "".format(self._name, self._schema.name))
            raise
        return res

    def create_index(self, index_name, fields_desc):
        """Creates a collection index.

        Args:
            index_name (str): Index name.
            fields_desc (dict): A dictionary containing the fields members that
                                constraints the index to be created. It must
                                have the form as shown in the following::

                                   {"fields": [{"field": member_path,
                                                "type": member_type,
                                                "required": member_required,
                                                "array": array,
                                                "collation": collation,
                                                "options": options,
                                                "srid": srid},
                                                # {... more members,
                                                #      repeated as many times
                                                #      as needed}
                                                ],
                                    "type": type}

        Returns:
            mysqlx.CreateCollectionIndexStatement: The statement object.
        """
        return CreateCollectionIndexStatement(self, index_name, fields_desc)

    def drop_index(self, index_name):
        """Drops a collection index.

        Args:
            index_name (str): Index name.
        """
        self._connection.execute_nonquery("mysqlx", "drop_collection_index",
                                          False, {"schema": self._schema.name,
                                                  "collection": self._name,
                                                  "name": index_name})

    def replace_one(self, doc_id, doc):
        """Replaces the Document matching the document ID with a new document
        provided.

        Args:
            doc_id (str): Document ID
            doc (:class:`mysqlx.DbDoc` or `dict`): New Document

        Returns:
            mysqlx.Result: Result of the executed modify statement.

        Raises:
            :class:`mysqlx.ProgrammingError`: If ``doc`` has an ``_id``
                                              different from ``doc_id``.
        """
        self._check_replacement_id(doc_id, doc)
        # Setting the root path "$" replaces the whole matched document.
        return (self.modify("_id = :id")
                .set("$", doc)
                .bind("id", doc_id)
                .execute())

    def add_or_replace_one(self, doc_id, doc):
        """Upserts the Document matching the document ID with a new document
        provided.

        Args:
            doc_id (str): Document ID
            doc (:class:`mysqlx.DbDoc` or dict): New Document

        Returns:
            mysqlx.Result: Result of the executed add (upsert) statement.

        Raises:
            :class:`mysqlx.ProgrammingError`: If ``doc`` has an ``_id``
                                              different from ``doc_id``.
        """
        self._check_replacement_id(doc_id, doc)
        if not isinstance(doc, DbDoc):
            doc = DbDoc(doc)
        # A copy carrying ``doc_id`` is added, so the caller's document is
        # not modified.
        return self.add(doc.copy(doc_id)).upsert(True).execute()

    def get_one(self, doc_id):
        """Returns a Document matching the Document ID.

        Args:
            doc_id (str): Document ID

        Returns:
            mysqlx.DbDoc: The Document matching the Document ID.
        """
        result = self.find("_id = :id").bind("id", doc_id).execute()
        doc = result.fetch_one()
        # Consume whatever is left of the active result on the connection.
        self._connection.fetch_active_result()
        return doc

    def remove_one(self, doc_id):
        """Removes a Document matching the Document ID.

        Args:
            doc_id (str): Document ID

        Returns:
            mysqlx.Result: Result object.
        """
        return self.remove("_id = :id").bind("id", doc_id).execute()
603
604
class Table(DatabaseObject):
    """A table belonging to a schema.

    Exposes the relational CRUD builders (SELECT, INSERT, UPDATE and
    DELETE) plus row counting and a view check.

    Args:
        schema (mysqlx.Schema): The Schema object.
        name (str): The table name.
    """

    def _with_statement_id(self, stmt):
        # Tags a freshly built statement with the connection's next id and
        # hands it back so the builders below stay single expressions.
        stmt.stmt_id = self._connection.get_next_statement_id()
        return stmt

    def _matches_one(self, query_template):
        # Runs an information_schema COUNT(*) template for this object and
        # reports whether the count is exactly one.
        query = query_template.format(escape(self._schema.name),
                                      escape(self._name))
        return self._connection.execute_sql_scalar(query) == 1

    def exists_in_database(self):
        """Checks ``information_schema.tables`` for this table.

        Returns:
            bool: `True` if object exists in database.
        """
        return self._matches_one(_COUNT_TABLES_QUERY)

    def select(self, *fields):
        """Starts a :class:`mysqlx.SelectStatement` on this table.

        Args:
            *fields: The fields to be retrieved.

        Returns:
            mysqlx.SelectStatement: SelectStatement object
        """
        return self._with_statement_id(SelectStatement(self, *fields))

    def insert(self, *fields):
        """Starts a :class:`mysqlx.InsertStatement` on this table.

        Args:
            *fields: The fields to be inserted.

        Returns:
            mysqlx.InsertStatement: InsertStatement object
        """
        return self._with_statement_id(InsertStatement(self, *fields))

    def update(self):
        """Starts a :class:`mysqlx.UpdateStatement` on this table.

        Returns:
            mysqlx.UpdateStatement: UpdateStatement object
        """
        return self._with_statement_id(UpdateStatement(self))

    def delete(self):
        """Starts a :class:`mysqlx.DeleteStatement` on this table.

        Returns:
            mysqlx.DeleteStatement: DeleteStatement object

        .. versionchanged:: 8.0.12
           The ``condition`` parameter was removed.
        """
        return self._with_statement_id(DeleteStatement(self))

    def count(self):
        """Returns the number of rows stored in the table.

        Returns:
            int: The total of rows in the table.

        Raises:
            :class:`mysqlx.OperationalError`: If the table is missing (with
                                              a message naming it) or the
                                              query fails otherwise.
        """
        query = _COUNT_QUERY.format(quote_identifier(self._schema.name),
                                    quote_identifier(self._name))
        try:
            return self._connection.execute_sql_scalar(query)
        except OperationalError as err:
            if err.errno != ER_NO_SUCH_TABLE:
                raise
            raise OperationalError(
                "Table '{}' does not exist in schema '{}'"
                "".format(self._name, self._schema.name))

    def is_view(self):
        """Checks ``information_schema.views`` for this object.

        Returns:
            bool: `True` if the underlying object is a view.
        """
        return self._matches_one(_COUNT_VIEWS_QUERY)
702
703
class View(Table):
    """Represents a database view on a schema.

    Offers the same operations as :class:`Table`; only the existence check
    differs, looking the name up in ``information_schema.views``.

    Args:
        schema (mysqlx.Schema): The Schema object.
        name (str): The view name.
    """

    def exists_in_database(self):
        """Checks ``information_schema.views`` for this view.

        Returns:
            bool: `True` if object exists in database.
        """
        # The inherited view check runs exactly the query needed here.
        return self.is_view()
723