1# Copyright (c) 2016, 2020, Oracle and/or its affiliates.
2#
3# This program is free software; you can redistribute it and/or modify
4# it under the terms of the GNU General Public License, version 2.0, as
5# published by the Free Software Foundation.
6#
7# This program is also distributed with certain software (including
8# but not limited to OpenSSL) that is licensed under separate terms,
9# as designated in a particular file or component or in included license
10# documentation.  The authors of MySQL hereby grant you an
11# additional permission to link the program and your derivative works
12# with the separately licensed software that they have included with
13# MySQL.
14#
15# Without limiting anything contained in the foregoing, this file,
16# which is part of MySQL Connector/Python, is also subject to the
17# Universal FOSS Exception, version 1.0, a copy of which can be found at
18# http://oss.oracle.com/licenses/universal-foss-exception.
19#
20# This program is distributed in the hope that it will be useful, but
21# WITHOUT ANY WARRANTY; without even the implied warranty of
22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
23# See the GNU General Public License, version 2.0, for more details.
24#
25# You should have received a copy of the GNU General Public License
26# along with this program; if not, write to the Free Software Foundation, Inc.,
27# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301  USA
28
29"""Implementation of Statements."""
30
31import copy
32import json
33import warnings
34
35from .errors import ProgrammingError, NotSupportedError
36from .expr import ExprParser
37from .constants import LockContention
38from .dbdoc import DbDoc
39from .helpers import deprecated
40from .result import Result
41from .protobuf import mysqlxpb_enum
42
# Error message template for an index name that is not valid; fill the "{}"
# placeholder with the offending name, e.g. ERR_INVALID_INDEX_NAME.format(name).
ERR_INVALID_INDEX_NAME = 'The given index name "{}" is not valid'
44
45
class Expr(object):
    """Wrapper that holds an expression object.

    Args:
        expr (object): The wrapped expression, exposed as the ``expr``
                       attribute.
    """
    def __init__(self, expr):
        super(Expr, self).__init__()
        self.expr = expr
50
51
def flexible_params(*values):
    """Normalize positional arguments that may also be passed as one sequence.

    Both ``f(1, 2)`` and ``f([1, 2])`` call styles are supported. When exactly
    one argument is given and it is a list or tuple, that same object is
    returned; otherwise the tuple of all arguments is returned.

    Args:
        *values: The positional arguments to normalize.

    Returns:
        list or tuple: The single sequence argument, or the argument tuple.
    """
    if len(values) != 1:
        return values
    only = values[0]
    return only if isinstance(only, (list, tuple)) else values
57
58
def is_quoted_identifier(identifier, sql_mode=""):
    """Check if the given identifier is quoted.

    An identifier is quoted when its first and last characters are the same
    quote character. Backticks are always accepted; double quotes are also
    accepted when ``sql_mode`` contains ``ANSI_QUOTES``.

    Args:
        identifier (string): Identifier to check.
        sql_mode (Optional[string]): SQL mode.

    Returns:
        `True` if the identifier is quoted, and `False` otherwise (including
        for empty or single-character identifiers).
    """
    # An opening and a closing quote need at least two characters; this also
    # prevents an IndexError on an empty identifier and stops a lone quote
    # character from being reported as a quoted identifier.
    if len(identifier) < 2:
        return False
    quotes = ("`", '"') if "ANSI_QUOTES" in sql_mode else ("`",)
    return identifier[0] in quotes and identifier[-1] == identifier[0]
73
74
def quote_identifier(identifier, sql_mode=""):
    """Quote the given identifier, escaping embedded quote characters.

    By default the identifier is wrapped in backticks and every backtick
    inside it is doubled. When ``sql_mode`` contains ``ANSI_QUOTES``, double
    quotes are used instead and every embedded double quote is doubled.
    An empty identifier yields an empty quoted identifier in the same style.

    Args:
        identifier (string): Identifier to quote.
        sql_mode (Optional[string]): SQL mode.

    Returns:
        A string with the quoted identifier.
    """
    quote = '"' if "ANSI_QUOTES" in sql_mode else "`"
    # Doubling the quote character is the escape sequence for it.
    return "{0}{1}{0}".format(quote, identifier.replace(quote, quote * 2))
91
92
def quote_multipart_identifier(identifiers, sql_mode=""):
    """Quote every part of a multi-part identifier and join them with dots.

    Each part is quoted individually with :func:`quote_identifier` using the
    given ``sql_mode``, e.g. ``["db", "tbl"]`` becomes ```db`.`tbl```.

    Args:
        identifiers (iterable): The identifier parts to quote.
        sql_mode (Optional[string]): SQL mode.

    Returns:
        A string with the quoted parts separated by ``"."``.
    """
    quoted_parts = (quote_identifier(part, sql_mode) for part in identifiers)
    return ".".join(quoted_parts)
105
106
def parse_table_name(default_schema, table_name, sql_mode=""):
    """Split a possibly schema-qualified table name into its parts.

    The name is split on the first ``.`` that is not inside a quoted part, so
    ``schema.table``, quoted ``schema`` and/or ``table`` parts, and quoted
    parts containing dots are all handled. Leading and trailing quote
    characters are stripped from each part. The quote character is a double
    quote when ``sql_mode`` contains ``ANSI_QUOTES`` and a backtick otherwise.

    Args:
        default_schema (str): The schema returned when ``table_name`` is not
                              schema-qualified.
        table_name (str): The table name, optionally prefixed by a schema.
        sql_mode (Optional[str]): The SQL mode.

    Returns:
        tuple: A ``(schema, table)`` tuple of strings.
    """
    quote = '"' if "ANSI_QUOTES" in sql_mode else "`"
    in_quotes = False
    for pos, char in enumerate(table_name):
        if char == quote:
            # A doubled (escaped) quote toggles twice, leaving the state as is.
            in_quotes = not in_quotes
        elif char == "." and not in_quotes:
            return (table_name[:pos].strip(quote),
                    table_name[pos + 1:].strip(quote),)
    return (default_schema, table_name.strip(quote),)
123
124
class Statement(object):
    """Base class shared by all statement objects.

    Keeps a reference to the target database object and the connection
    obtained from it, plus the counters and flags used to track whether the
    statement has changed, has been prepared, or must be re-prepared.

    Args:
        target (object): The target database object, it can be
                         :class:`mysqlx.Collection` or :class:`mysqlx.Table`.
        doc_based (bool): `True` if it is document based.
    """
    def __init__(self, target, doc_based=True):
        self._target = target
        self._doc_based = doc_based
        # A falsy target (e.g. None for raw SQL statements) has no connection.
        if target:
            self._connection = target.get_connection()
        else:
            self._connection = None
        self._stmt_id = None
        self._exec_counter = 0
        self._changed = True
        self._prepared = False
        self._deallocate_prepare_execute = False

    @property
    def target(self):
        """object: The database object this statement operates on."""
        return self._target

    @property
    def schema(self):
        """:class:`mysqlx.Schema`: The schema of the target object."""
        return self._target.schema

    @property
    def stmt_id(self):
        """The identifier assigned to this statement.

        Returns:
            int: The statement ID, or `None` if none was assigned.
        """
        return self._stmt_id

    @stmt_id.setter
    def stmt_id(self, value):
        self._stmt_id = value

    @property
    def exec_counter(self):
        """int: How many times this statement has been executed."""
        return self._exec_counter

    @property
    def changed(self):
        """bool: `True` when the statement was modified since last tracked."""
        return self._changed

    @changed.setter
    def changed(self, value):
        self._changed = value

    @property
    def prepared(self):
        """bool: `True` once this statement has been prepared."""
        return self._prepared

    @prepared.setter
    def prepared(self, value):
        self._prepared = value

    @property
    def repeated(self):
        """bool: `True` when the statement has run two or more times."""
        return 1 < self._exec_counter

    @property
    def deallocate_prepare_execute(self):
        """bool: `True` when the statement must be deallocated, prepared and
        executed again.
        """
        return self._deallocate_prepare_execute

    @deallocate_prepare_execute.setter
    def deallocate_prepare_execute(self, value):
        self._deallocate_prepare_execute = value

    def is_doc_based(self):
        """Tell whether the statement works on documents.

        Returns:
            bool: `True` if it is document based.
        """
        return self._doc_based

    def increment_exec_counter(self):
        """Add one to the execution counter."""
        self._exec_counter = self._exec_counter + 1

    def reset_exec_counter(self):
        """Set the execution counter back to zero."""
        self._exec_counter = 0

    def execute(self):
        """Execute the statement; subclasses must override this.

        Raises:
           NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError
228
229
class FilterableStatement(Statement):
    """A statement to be used with filterable statements.

    Args:
        target (object): The target database object, it can be
                         :class:`mysqlx.Collection` or :class:`mysqlx.Table`.
        doc_based (Optional[bool]): `True` if it is document based
                                    (default: `True`).
        condition (Optional[str]): Sets the search condition to filter
                                   documents or records.
    """
    def __init__(self, target, doc_based=True, condition=None):
        super(FilterableStatement, self).__init__(target=target,
                                                  doc_based=doc_based)
        self._binding_map = {}
        self._bindings = {}
        self._having = None
        self._grouping_str = ""
        self._grouping = None
        self._limit_offset = 0
        self._limit_row_count = None
        self._projection_str = ""
        self._projection_expr = None
        self._sort_str = ""
        self._sort_expr = None
        self._where_str = ""
        self._where_expr = None
        self.has_bindings = False
        self.has_limit = False
        self.has_group_by = False
        self.has_having = False
        self.has_projection = False
        self.has_sort = False
        self.has_where = False
        if condition:
            self._set_where(condition)

    def _bind_single(self, obj):
        """Bind every key/value pair of a single document-like object.

        Args:
            obj (:class:`mysqlx.DbDoc`, dict or str): DbDoc, dictionary or JSON
                object string whose keys are placeholder names.

        Raises:
            :class:`mysqlx.ProgrammingError`: If ``obj`` is a string that is
                not valid JSON or does not decode to a JSON object, or if
                ``obj`` is of an unsupported type.
        """
        if isinstance(obj, dict):
            # Serialize to a JSON string and bind it through the str branch.
            self.bind(DbDoc(obj).as_str())
        elif isinstance(obj, DbDoc):
            self.bind(obj.as_str())
        elif isinstance(obj, str):
            try:
                res = json.loads(obj)
                if not isinstance(res, dict):
                    raise ValueError
            except ValueError:
                raise ProgrammingError("Invalid JSON string to bind")
            for key, value in res.items():
                self.bind(key, value)
        else:
            raise ProgrammingError("Invalid JSON string or object to bind")

    def _sort(self, *clauses):
        """Sets the sorting criteria.

        Args:
            *clauses: The expression strings defining the sort criteria.

        Returns:
            mysqlx.FilterableStatement: FilterableStatement object.
        """
        self.has_sort = True
        self._sort_str = ",".join(flexible_params(*clauses))
        self._sort_expr = ExprParser(self._sort_str,
                                     not self._doc_based).parse_order_spec()
        self._changed = True
        return self

    def _set_where(self, condition):
        """Sets the search condition to filter.

        The condition is parsed before any state is updated, so a failed call
        leaves a previously set condition intact.

        Args:
            condition (str): Sets the search condition to filter documents or
                             records.

        Returns:
            mysqlx.FilterableStatement: FilterableStatement object.

        Raises:
            ProgrammingError: If the condition cannot be parsed.
        """
        try:
            expr = ExprParser(condition, not self._doc_based)
            where_expr = expr.expr()
        except ValueError:
            raise ProgrammingError("Invalid condition")
        self.has_where = True
        self._where_str = condition
        self._where_expr = where_expr
        self._binding_map = expr.placeholder_name_to_position
        self._changed = True
        return self

    def _set_group_by(self, *fields):
        """Set group by.

        Args:
            *fields: List of fields.
        """
        fields = flexible_params(*fields)
        self.has_group_by = True
        self._grouping_str = ",".join(fields)
        self._grouping = ExprParser(self._grouping_str,
                                    not self._doc_based).parse_expr_list()
        self._changed = True

    def _set_having(self, condition):
        """Set having.

        Args:
            condition (str): The condition.
        """
        self.has_having = True
        self._having = ExprParser(condition, not self._doc_based).expr()
        self._changed = True

    def _set_projection(self, *fields):
        """Set the projection.

        Args:
            *fields: List of fields.

        Returns:
            :class:`mysqlx.FilterableStatement`: Returns self.
        """
        fields = flexible_params(*fields)
        self.has_projection = True
        self._projection_str = ",".join(fields)
        self._projection_expr = ExprParser(
            self._projection_str,
            not self._doc_based).parse_table_select_projection()
        self._changed = True
        return self

    def get_binding_map(self):
        """Returns the binding map dictionary.

        Returns:
            dict: The binding map dictionary.
        """
        return self._binding_map

    def get_bindings(self):
        """Returns the bindings, keyed by placeholder name.

        Returns:
            dict: The bindings dictionary.
        """
        return self._bindings

    def get_grouping(self):
        """Returns the grouping expression list.

        Returns:
            `list`: The grouping expression list.
        """
        return self._grouping

    def get_having(self):
        """Returns the having expression.

        Returns:
            object: The having expression.
        """
        return self._having

    def get_limit_row_count(self):
        """Returns the limit row count.

        Returns:
            int: The limit row count.
        """
        return self._limit_row_count

    def get_limit_offset(self):
        """Returns the limit offset.

        Returns:
            int: The limit offset.
        """
        return self._limit_offset

    def get_where_expr(self):
        """Returns the where expression.

        Returns:
            object: The where expression.
        """
        return self._where_expr

    def get_projection_expr(self):
        """Returns the projection expression.

        Returns:
            object: The projection expression.
        """
        return self._projection_expr

    def get_sort_expr(self):
        """Returns the sort expression.

        Returns:
            object: The sort expression.
        """
        return self._sort_expr

    @deprecated("8.0.12")
    def where(self, condition):
        """Sets the search condition to filter.

        Args:
            condition (str): Sets the search condition to filter documents or
                             records.

        Returns:
            mysqlx.FilterableStatement: FilterableStatement object.

        .. deprecated:: 8.0.12
        """
        return self._set_where(condition)

    @deprecated("8.0.12")
    def sort(self, *clauses):
        """Sets the sorting criteria.

        Args:
            *clauses: The expression strings defining the sort criteria.

        Returns:
            mysqlx.FilterableStatement: FilterableStatement object.

        .. deprecated:: 8.0.12
        """
        return self._sort(*clauses)

    def limit(self, row_count, offset=None):
        """Sets the maximum number of items to be returned.

        Args:
            row_count (int): The maximum number of items (zero or greater).
            offset (Optional[int]): Deprecated, use :meth:`offset` instead.
                                    The number of items to skip; ignored
                                    when falsy.

        Returns:
            mysqlx.FilterableStatement: FilterableStatement object.

        Raises:
            ValueError: If ``row_count`` is not a non-negative integer.

        .. versionchanged:: 8.0.12
           The usage of ``offset`` was deprecated.
        """
        if not isinstance(row_count, int) or row_count < 0:
            raise ValueError("The 'row_count' value must be a positive integer")
        if not self.has_limit:
            # First limit() on this statement: before any execution it is a
            # plain change; afterwards the statement is flagged for
            # deallocate + prepare + execute.
            self._changed = bool(self._exec_counter == 0)
            self._deallocate_prepare_execute = bool(not self._exec_counter == 0)

        self._limit_row_count = row_count
        self.has_limit = True
        if offset:
            self.offset(offset)
            warnings.warn("'limit(row_count, offset)' is deprecated, please "
                          "use 'offset(offset)' to set the number of items to "
                          "skip", category=DeprecationWarning)
        return self

    def offset(self, offset):
        """Sets the number of items to skip.

        Args:
            offset (int): The number of items to skip (zero or greater).

        Returns:
            mysqlx.FilterableStatement: FilterableStatement object.

        Raises:
            ValueError: If ``offset`` is not a non-negative integer.

        .. versionadded:: 8.0.12
        """
        if not isinstance(offset, int) or offset < 0:
            raise ValueError("The 'offset' value must be a positive integer")
        self._limit_offset = offset
        return self

    def bind(self, *args):
        """Binds value(s) to a specific placeholder(s).

        Args:
            *args: Either the name of the placeholder and the value to bind,
                   or a single :class:`mysqlx.DbDoc`, dict or JSON object
                   string mapping placeholder names to values.

        Returns:
            mysqlx.FilterableStatement: FilterableStatement object.

        Raises:
            ProgrammingError: If the number of arguments is invalid or the
                              single object to bind is not valid.
        """
        count = len(args)
        if count == 1:
            self._bind_single(args[0])
        elif count == 2:
            self._bindings[args[0]] = args[1]
        else:
            raise ProgrammingError("Invalid number of arguments to bind")
        # Only flag bindings once they were accepted.
        self.has_bindings = True
        return self

    def execute(self):
        """Execute the statement.

        Raises:
           NotImplementedError: This method must be implemented.
        """
        raise NotImplementedError
552
553
class SqlStatement(Statement):
    """A statement for SQL execution.

    Args:
        connection (mysqlx.connection.Connection): Connection object.
        sql (string): The sql statement to be executed.
    """
    def __init__(self, connection, sql):
        super(SqlStatement, self).__init__(target=None, doc_based=False)
        self._connection = connection
        self._sql = sql
        self._binding_map = None
        self._bindings = []
        self.has_bindings = False
        self.has_limit = False

    @property
    def sql(self):
        """string: The SQL text statement."""
        return self._sql

    def get_binding_map(self):
        """Returns the binding map dictionary.

        Returns:
            dict: The binding map dictionary.
        """
        return self._binding_map

    def get_bindings(self):
        """Returns the bound values.

        Returns:
            `list` or `tuple`: The bound values, in placeholder order.
        """
        return self._bindings

    def bind(self, *args):
        """Binds value(s) to the statement placeholders.

        The values can be given as separate arguments (``bind(1, 2)``) or as
        a single list or tuple (``bind([1, 2])``). Each call replaces the
        values bound by a previous call.

        Args:
            *args: The value(s) to bind.

        Returns:
            mysqlx.SqlStatement: SqlStatement object.

        Raises:
            ProgrammingError: If no value is given.
        """
        if len(args) == 0:
            raise ProgrammingError("Invalid number of arguments to bind")
        self.has_bindings = True
        # flexible_params() always returns a list or tuple, so the previous
        # "append a single value" branch was unreachable; the values simply
        # replace the current bindings.
        self._bindings = flexible_params(*args)
        return self

    def execute(self):
        """Execute the statement.

        Returns:
            mysqlx.SqlResult: SqlResult object.
        """
        return self._connection.send_sql(self)
617
618
class WriteStatement(Statement):
    """Base class for statements that write a list of values.

    Args:
        target (object): The target database object.
        doc_based (bool): `True` if it is document based.
    """
    def __init__(self, target, doc_based):
        super(WriteStatement, self).__init__(target=target,
                                             doc_based=doc_based)
        # Values accumulated by subclasses before execution.
        self._values = []

    def get_values(self):
        """Give access to the values collected so far.

        Returns:
            `list`: The list of values.
        """
        return self._values

    def execute(self):
        """Execute the statement; subclasses must override this.

        Raises:
           NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError
641
642
class AddStatement(WriteStatement):
    """A statement for document addition on a collection.

    Args:
        collection (mysqlx.Collection): The Collection object.
    """
    def __init__(self, collection):
        super(AddStatement, self).__init__(collection, True)
        self._upsert = False
        self.ids = []

    def is_upsert(self):
        """Returns `True` if it's an upsert.

        Returns:
            bool: `True` if it's an upsert.
        """
        return self._upsert

    def upsert(self, value=True):
        """Sets the upsert flag to the boolean value of ``value``.

        Setting this flag allows updating of the matched rows/documents
        with the provided value.

        Args:
            value (optional[bool]): Set or unset the upsert flag.

        Returns:
            mysqlx.AddStatement: AddStatement object.
        """
        # Coerce so that is_upsert() always returns a real bool, as documented.
        self._upsert = bool(value)
        return self

    def add(self, *values):
        """Adds a list of documents into a collection.

        Args:
            *values: The documents to be added into the collection, either
                     as separate arguments or as a single list/tuple. Values
                     that are not :class:`mysqlx.DbDoc` are wrapped in one.

        Returns:
            mysqlx.AddStatement: AddStatement object.
        """
        for val in flexible_params(*values):
            self._values.append(val if isinstance(val, DbDoc) else DbDoc(val))
        return self

    def execute(self):
        """Execute the statement.

        Returns:
            mysqlx.Result: Result object. An empty :class:`mysqlx.Result` is
                           returned without contacting the server when no
                           documents were added.
        """
        if not self._values:
            return Result()

        return self._connection.send_insert(self)
699
700
class UpdateSpec(object):
    """A single update operation: its type, target field and value.

    For the ``SET`` update type the source is parsed as a table update field;
    for every other type it is parsed as a document field (a leading ``$`` is
    dropped first).

    Args:
        update_type (int): The update type.
        source (str): The field the operation applies to.
        value (Optional[str]): The value.
    """
    def __init__(self, update_type, source, value=None):
        set_type = mysqlxpb_enum("Mysqlx.Crud.UpdateOperation.UpdateType.SET")
        if update_type == set_type:
            self._table_set(source, value)
            return
        self.update_type = update_type
        field = source[1:] if source[:1] == "$" else source
        self.source = ExprParser(field, False).document_field().identifier
        self.value = value

    def _table_set(self, source, value):
        """Configure a table-style SET operation.

        Args:
            source (str): The field to set.
            value (str): The value.
        """
        self.update_type = mysqlxpb_enum(
            "Mysqlx.Crud.UpdateOperation.UpdateType.SET")
        self.source = ExprParser(source, True).parse_table_update_field()
        self.value = value
733
734
class ModifyStatement(FilterableStatement):
    """A statement for document update operations on a Collection.

    Update operations are stored keyed by document path, so a later
    operation on the same path replaces an earlier one.

    Args:
        collection (mysqlx.Collection): The Collection object.
        condition (str): Sets the search condition to identify the documents
                         to be modified.

    .. versionchanged:: 8.0.12
       The ``condition`` parameter is now mandatory.
    """
    def __init__(self, collection, condition):
        super(ModifyStatement, self).__init__(target=collection,
                                              condition=condition)
        self._update_ops = {}

    def sort(self, *clauses):
        """Sets the sorting criteria.

        Args:
            *clauses: The expression strings defining the sort criteria.

        Returns:
            mysqlx.ModifyStatement: ModifyStatement object.
        """
        return self._sort(*clauses)

    def get_update_ops(self):
        """Returns the update operations.

        Returns:
            `dict`: The :class:`UpdateSpec` objects keyed by document path
                    (``"patch"`` for the operation added by :meth:`patch`).
        """
        return self._update_ops

    def set(self, doc_path, value):
        """Sets or updates attributes on documents in a collection.

        Args:
            doc_path (string): The document path of the item to be set.
            value (object): The value to be set on the specified attribute.

        Returns:
            mysqlx.ModifyStatement: ModifyStatement object.
        """
        self._update_ops[doc_path] = UpdateSpec(mysqlxpb_enum(
            "Mysqlx.Crud.UpdateOperation.UpdateType.ITEM_SET"),
                                                doc_path, value)
        self._changed = True
        return self

    @deprecated("8.0.12")
    def change(self, doc_path, value):
        """Add an update to the statement setting the field, if it exists at
        the document path, to the given value.

        Args:
            doc_path (string): The document path of the item to be set.
            value (object): The value to be set on the specified attribute.

        Returns:
            mysqlx.ModifyStatement: ModifyStatement object.

        .. deprecated:: 8.0.12
        """
        self._update_ops[doc_path] = UpdateSpec(mysqlxpb_enum(
            "Mysqlx.Crud.UpdateOperation.UpdateType.ITEM_REPLACE"),
                                                doc_path, value)
        self._changed = True
        return self

    def unset(self, *doc_paths):
        """Removes attributes from documents in a collection.

        Args:
            *doc_paths: The document paths of the attributes to be removed,
                        either as separate arguments or as a single
                        list/tuple.

        Returns:
            mysqlx.ModifyStatement: ModifyStatement object.
        """
        for item in flexible_params(*doc_paths):
            self._update_ops[item] = UpdateSpec(mysqlxpb_enum(
                "Mysqlx.Crud.UpdateOperation.UpdateType.ITEM_REMOVE"), item)
        self._changed = True
        return self

    def array_insert(self, field, value):
        """Insert a value into the specified array in documents of a
        collection.

        Args:
            field (string): A document path that identifies the array attribute
                            and position where the value will be inserted.
            value (object): The value to be inserted.

        Returns:
            mysqlx.ModifyStatement: ModifyStatement object.
        """
        self._update_ops[field] = UpdateSpec(mysqlxpb_enum(
            "Mysqlx.Crud.UpdateOperation.UpdateType.ARRAY_INSERT"),
                                             field, value)
        self._changed = True
        return self

    def array_append(self, doc_path, value):
        """Appends a value to an array attribute in documents of a
        collection.

        Args:
            doc_path (string): A document path that identifies the array
                               attribute the value will be appended to.
            value (object): The value to be appended.

        Returns:
            mysqlx.ModifyStatement: ModifyStatement object.
        """
        self._update_ops[doc_path] = UpdateSpec(mysqlxpb_enum(
            "Mysqlx.Crud.UpdateOperation.UpdateType.ARRAY_APPEND"),
                                                doc_path, value)
        self._changed = True
        return self

    def patch(self, doc):
        """Takes a :class:`mysqlx.DbDoc`, string JSON format or a dict with the
        changes and applies it on all matching documents.

        Args:
            doc (object): A generic document (DbDoc), string in JSON format,
                          dict or parsed expression, with the changes to apply
                          to the matching documents. `None` is treated as an
                          empty string.

        Returns:
            mysqlx.ModifyStatement: ModifyStatement object.

        Raises:
            ProgrammingError: If ``doc`` is of an unsupported type.
        """
        if doc is None:
            doc = ''
        if not isinstance(doc, (ExprParser, dict, DbDoc, str)):
            raise ProgrammingError(
                "Invalid data for update operation on document collection "
                "table")
        self._update_ops["patch"] = UpdateSpec(
            mysqlxpb_enum("Mysqlx.Crud.UpdateOperation.UpdateType.MERGE_PATCH"),
            '', doc.expr() if isinstance(doc, ExprParser) else doc)
        self._changed = True
        return self

    def execute(self):
        """Execute the statement.

        Returns:
            mysqlx.Result: Result object.

        Raises:
            ProgrammingError: If condition was not set.
        """
        if not self.has_where:
            raise ProgrammingError("No condition was found for modify")
        return self._connection.send_update(self)
895
896
class ReadStatement(FilterableStatement):
    """Base class for statements that read documents or records.

    On top of :class:`FilterableStatement` it adds row locking (shared or
    exclusive, each with a lock contention mode), grouping and a condition
    on the grouped results.

    Args:
        target (object): The target database object, it can be
                         :class:`mysqlx.Collection` or :class:`mysqlx.Table`.
        doc_based (Optional[bool]): `True` if it is document based
                                    (default: `True`).
        condition (Optional[str]): Sets the search condition to filter
                                   documents or records.
    """
    def __init__(self, target, doc_based=True, condition=None):
        super(ReadStatement, self).__init__(target, doc_based, condition)
        # No lock is requested until lock_shared()/lock_exclusive() is called.
        self._lock_exclusive = self._lock_shared = False
        self._lock_contention = LockContention.DEFAULT

    @property
    def lock_contention(self):
        """:class:`mysqlx.LockContention`: The lock contention value."""
        return self._lock_contention

    def _set_lock_contention(self, lock_contention):
        """Validate and store the lock contention mode.

        Args:
            lock_contention (:class:`mysqlx.LockContention`): Lock contention.

        Raises:
            ProgrammingError: If the value is not one of the
                              :class:`mysqlx.LockContention` values.
        """
        try:
            # Used purely as a membership test: raises ValueError when the
            # value is not one of the LockContention members.
            LockContention.index(lock_contention)
        except ValueError:
            raise ProgrammingError("Invalid lock contention mode. Use 'NOWAIT' "
                                   "or 'SKIP_LOCKED'")
        else:
            self._lock_contention = lock_contention

    def is_lock_exclusive(self):
        """Tell whether an `EXCLUSIVE LOCK` was requested.

        Returns:
            bool: `True` if is `EXCLUSIVE LOCK`.
        """
        return self._lock_exclusive

    def is_lock_shared(self):
        """Tell whether a `SHARED LOCK` was requested.

        Returns:
            bool: `True` if is `SHARED LOCK`.
        """
        return self._lock_shared

    def lock_shared(self, lock_contention=LockContention.DEFAULT):
        """Execute the read operation with a `SHARED LOCK`.

        Only one lock can be active at a time, so this clears any previously
        requested exclusive lock.

        Args:
            lock_contention (:class:`mysqlx.LockContention`): Lock contention.

        Returns:
            mysqlx.ReadStatement: ReadStatement object.
        """
        # The flags are switched before the contention value is validated.
        self._lock_exclusive, self._lock_shared = False, True
        self._set_lock_contention(lock_contention)
        return self

    def lock_exclusive(self, lock_contention=LockContention.DEFAULT):
        """Execute the read operation with an `EXCLUSIVE LOCK`.

        Only one lock can be active at a time, so this clears any previously
        requested shared lock.

        Args:
            lock_contention (:class:`mysqlx.LockContention`): Lock contention.

        Returns:
            mysqlx.ReadStatement: ReadStatement object.
        """
        # The flags are switched before the contention value is validated.
        self._lock_exclusive, self._lock_shared = True, False
        self._set_lock_contention(lock_contention)
        return self

    def group_by(self, *fields):
        """Set the grouping criteria for the result set.

        Args:
            *fields: The string expressions identifying the grouping criteria.

        Returns:
            mysqlx.ReadStatement: ReadStatement object.
        """
        self._set_group_by(*fields)
        return self

    def having(self, condition):
        """Set a condition that grouped records must satisfy to be included
        in aggregate function operations.

        Args:
            condition (string): A condition on the aggregate functions used
                                on the grouping criteria.

        Returns:
            mysqlx.ReadStatement: ReadStatement object.
        """
        self._set_having(condition)
        return self

    def execute(self):
        """Send the statement through the connection as a find operation.

        Returns:
            mysqlx.Result: Result object.
        """
        connection = self._connection
        return connection.send_find(self)
1009
1010
class FindStatement(ReadStatement):
    """A statement for selecting documents from a Collection.

    Args:
        collection (mysqlx.Collection): The Collection object.
        condition (Optional[str]): An optional expression to identify the
                                   documents to be retrieved. If not specified
                                   all the documents will be included on the
                                   result unless a limit is set.
    """
    def __init__(self, collection, condition=None):
        # Collections are always document based.
        super(FindStatement, self).__init__(collection, True, condition)

    def fields(self, *fields):
        """Restrict the returned documents to the given fields.

        Args:
            *fields: The string expressions identifying the fields to be
                     extracted.

        Returns:
            mysqlx.FindStatement: FindStatement object.
        """
        statement = self._set_projection(*fields)
        return statement

    def sort(self, *clauses):
        """Set the sorting criteria for the returned documents.

        Args:
            *clauses: The expression strings defining the sort criteria.

        Returns:
            mysqlx.FindStatement: FindStatement object.
        """
        statement = self._sort(*clauses)
        return statement
1046
1047
class SelectStatement(ReadStatement):
    """A statement for retrieving records from a Table.

    Args:
        table (mysqlx.Table): The Table object.
        *fields: The fields to be retrieved.
    """
    def __init__(self, table, *fields):
        # Tables are never document based.
        super(SelectStatement, self).__init__(table, False)
        self._set_projection(*fields)

    def where(self, condition):
        """Set the search condition used to filter records.

        Args:
            condition (str): Sets the search condition to filter records.

        Returns:
            mysqlx.SelectStatement: SelectStatement object.
        """
        statement = self._set_where(condition)
        return statement

    def order_by(self, *clauses):
        """Set the ordering criteria for the returned records.

        Args:
            *clauses: The expression strings defining the order by criteria.

        Returns:
            mysqlx.SelectStatement: SelectStatement object.
        """
        statement = self._sort(*clauses)
        return statement

    def get_sql(self):
        """Build the SQL text equivalent of this statement.

        Only the clauses that were set are included, in the order
        WHERE, GROUP BY, HAVING, ORDER BY, LIMIT/OFFSET. With no projection
        the statement selects ``*``.

        Returns:
            str: The generated SQL.
        """
        clauses = []
        if self.has_where:
            clauses.append(" WHERE {0}".format(self._where_str))
        if self.has_group_by:
            clauses.append(" GROUP BY {0}".format(self._grouping_str))
        if self.has_having:
            # NOTE(review): unlike the other clauses, which format a ``*_str``
            # attribute, this formats ``self._having`` directly. If that
            # holds a parsed expression rather than the condition text, the
            # generated HAVING clause is wrong; confirm against _set_having.
            clauses.append(" HAVING {0}".format(self._having))
        if self.has_sort:
            clauses.append(" ORDER BY {0}".format(self._sort_str))
        if self.has_limit:
            clauses.append(" LIMIT {0} OFFSET {1}".format(
                self._limit_row_count, self._limit_offset))
        return "SELECT {0} FROM {1}.{2}{3}".format(
            self._projection_str or "*",
            self.schema.name,
            self.target.name,
            "".join(clauses))
1103
1104
class InsertStatement(WriteStatement):
    """A statement for inserting records into a Table.

    Args:
        table (mysqlx.Table): The Table object.
        *fields: The fields to be inserted.
    """
    def __init__(self, table, *fields):
        # Tables are never document based.
        super(InsertStatement, self).__init__(table, False)
        self._fields = flexible_params(*fields)

    def values(self, *values):
        """Queue one row of values to be inserted.

        Each call adds one row; call it repeatedly to insert several rows.

        Args:
            *values: The values of the columns to be inserted.

        Returns:
            mysqlx.InsertStatement: InsertStatement object.
        """
        row = list(flexible_params(*values))
        self._values.append(row)
        return self

    def execute(self):
        """Send the statement through the connection as an insert operation.

        Returns:
            mysqlx.Result: Result object.
        """
        connection = self._connection
        return connection.send_insert(self)
1135
1136
class UpdateStatement(FilterableStatement):
    """A statement for updating records in a Table.

    Args:
        table (mysqlx.Table): The Table object.

    .. versionchanged:: 8.0.12
       The ``fields`` parameters were removed.
    """
    def __init__(self, table):
        super(UpdateStatement, self).__init__(target=table, doc_based=False)
        # Maps each column name to the UpdateSpec built by set(); setting the
        # same column again replaces its previous operation.
        self._update_ops = {}

    def where(self, condition):
        """Set the search condition used to filter records.

        Args:
            condition (str): Sets the search condition to filter records.

        Returns:
            mysqlx.UpdateStatement: UpdateStatement object.
        """
        statement = self._set_where(condition)
        return statement

    def order_by(self, *clauses):
        """Set the ordering criteria for the records to update.

        Args:
            *clauses: The expression strings defining the order by criteria.

        Returns:
            mysqlx.UpdateStatement: UpdateStatement object.
        """
        statement = self._sort(*clauses)
        return statement

    def get_update_ops(self):
        """Return the update operations collected by :meth:`set`.

        Returns:
            dict: Mapping of column name to its update operation.
        """
        return self._update_ops

    def set(self, field, value):
        """Set a column to a value on the records in a table.

        Args:
            field (string): The column name to be updated.
            value (object): The value to be set on the specified column.

        Returns:
            mysqlx.UpdateStatement: UpdateStatement object.
        """
        update_type = mysqlxpb_enum(
            "Mysqlx.Crud.UpdateOperation.UpdateType.SET")
        self._update_ops[field] = UpdateSpec(update_type, field, value)
        self._changed = True
        return self

    def execute(self):
        """Send the statement through the connection as an update operation.

        Returns:
            mysqlx.Result: Result object

        Raises:
            ProgrammingError: If condition was not set.
        """
        if self.has_where:
            return self._connection.send_update(self)
        # Refuse to update every record of the table implicitly.
        raise ProgrammingError("No condition was found for update")
1207
1208
class RemoveStatement(FilterableStatement):
    """A statement for removing documents from a Collection.

    Args:
        collection (mysqlx.Collection): The Collection object.
        condition (str): Sets the search condition to identify the documents
                         to be removed.

    .. versionchanged:: 8.0.12
       The ``condition`` parameter was added.
    """
    def __init__(self, collection, condition):
        super(RemoveStatement, self).__init__(target=collection,
                                              condition=condition)

    def sort(self, *clauses):
        """Set the sorting criteria for the documents to remove.

        Args:
            *clauses: The expression strings defining the sort criteria.

        Returns:
            mysqlx.RemoveStatement: RemoveStatement object.
        """
        statement = self._sort(*clauses)
        return statement

    def execute(self):
        """Send the statement through the connection as a delete operation.

        Returns:
            mysqlx.Result: Result object.

        Raises:
            ProgrammingError: If condition was not set.
        """
        if self.has_where:
            return self._connection.send_delete(self)
        # Refuse to remove every document of the collection implicitly.
        raise ProgrammingError("No condition was found for remove")
1247
1248
class DeleteStatement(FilterableStatement):
    """A statement for deleting records from a Table.

    Args:
        table (mysqlx.Table): The Table object.

    .. versionchanged:: 8.0.12
       The ``condition`` parameter was removed.
    """
    def __init__(self, table):
        super(DeleteStatement, self).__init__(target=table, doc_based=False)

    def where(self, condition):
        """Set the search condition used to filter records.

        Args:
            condition (str): Sets the search condition to filter records.

        Returns:
            mysqlx.DeleteStatement: DeleteStatement object.
        """
        statement = self._set_where(condition)
        return statement

    def order_by(self, *clauses):
        """Set the ordering criteria for the records to delete.

        Args:
            *clauses: The expression strings defining the order by criteria.

        Returns:
            mysqlx.DeleteStatement: DeleteStatement object.
        """
        statement = self._sort(*clauses)
        return statement

    def execute(self):
        """Send the statement through the connection as a delete operation.

        Returns:
            mysqlx.Result: Result object.

        Raises:
            ProgrammingError: If condition was not set.
        """
        if self.has_where:
            return self._connection.send_delete(self)
        # Refuse to delete every record of the table implicitly.
        raise ProgrammingError("No condition was found for delete")
1295
1296
class CreateCollectionIndexStatement(Statement):
    """A statement that creates an index on a collection.

    Args:
        collection (mysqlx.Collection): Collection.
        index_name (string): Index name.
        index_desc (dict): A dictionary containing the field members that
                           constrain the index to be created. It must have
                           the form as shown in the following::

                               {"fields": [{"field": member_path,
                                            "type": member_type,
                                            "required": member_required,
                                            "array": array,
                                            "collation": collation,
                                            "options": options,
                                            "srid": srid},
                                            # {... more members,
                                            #      repeated as many times
                                            #      as needed}
                                            ],
                                "type": type}
    """
    def __init__(self, collection, index_name, index_desc):
        super(CreateCollectionIndexStatement, self).__init__(target=collection)
        self._index_desc = copy.deepcopy(index_desc)
        self._index_name = index_name
        # "fields" is kept apart; whatever remains in _index_desc is the
        # index-level description ("type", "unique", ...).
        self._fields_desc = self._index_desc.pop("fields", [])

    def _validate_index_name(self):
        """Check that the index name parses as a plain identifier.

        Raises:
            ProgrammingError: If the name is `None` or is not an identifier.
        """
        if self._index_name is None:
            raise ProgrammingError(
                ERR_INVALID_INDEX_NAME.format(self._index_name))
        try:
            parsed_ident = ExprParser(self._index_name).expr().get_message()
            # The message is a dict when the Protobuf C extension is used,
            # otherwise an object exposing the expression ``type``.
            if isinstance(parsed_ident, dict):
                ident_type = parsed_ident["type"]
            else:
                ident_type = parsed_ident.type
            is_ident = ident_type == mysqlxpb_enum(
                "Mysqlx.Expr.Expr.Type.IDENT")
        except (ValueError, AttributeError):
            is_ident = False
        if not is_ident:
            raise ProgrammingError(
                ERR_INVALID_INDEX_NAME.format(self._index_name))

    @staticmethod
    def _build_constraint(index_type, field_desc):
        """Build one index constraint from a single field description.

        Recognised members are popped from ``field_desc``, so any keys left
        in it afterwards are unknown members.

        Args:
            index_type (str): The index type (e.g. "INDEX" or "SPATIAL").
            field_desc (dict): One entry of the "fields" list (mutated).

        Returns:
            dict: The constraint entry sent with the index creation request.

        Raises:
            KeyError: If the "field" or "type" member is missing.
            TypeError: If "required" or "array" is not a bool.
            ProgrammingError: If the members are inconsistent with each other
                              or with the index type.
        """
        constraint = {
            "member": field_desc.pop("field"),
            "type": field_desc.pop("type"),
            "required": field_desc.pop("required", False),
            "array": field_desc.pop("array", False),
        }
        if not isinstance(constraint["required"], bool):
            raise TypeError("Field member 'required' must be Boolean")
        if not isinstance(constraint["array"], bool):
            raise TypeError("Field member 'array' must be Boolean")

        field_type = constraint["type"]
        # Field types are compared case-insensitively in every check below.
        upper_type = (field_type.upper() if hasattr(field_type, "upper")
                      else field_type)
        upper_index_type = index_type.upper()

        if upper_index_type == "SPATIAL" and not constraint["required"]:
            raise ProgrammingError(
                "Field member 'required' must be set to 'True' when "
                "index type is set to 'SPATIAL'")
        if upper_index_type == "INDEX" and upper_type == "GEOJSON":
            raise ProgrammingError(
                "Index 'type' must be set to 'SPATIAL' when field "
                "type is set to 'GEOJSON'")
        if "collation" in field_desc:
            if not upper_type.startswith("TEXT"):
                raise ProgrammingError(
                    "The 'collation' member can only be used when "
                    "field type is set to 'TEXT' (got '{}')"
                    "".format(field_type))
            constraint["collation"] = field_desc.pop("collation")
        # "options" and "srid" fields in IndexField can be
        # present only if "type" is set to "GEOJSON"
        if "options" in field_desc:
            if upper_type != "GEOJSON":
                raise ProgrammingError(
                    "The 'options' member can only be used when "
                    "field type is set to 'GEOJSON'")
            constraint["options"] = field_desc.pop("options")
        if "srid" in field_desc:
            if upper_type != "GEOJSON":
                raise ProgrammingError(
                    "The 'srid' member can only be used when field "
                    "type is set to 'GEOJSON'")
            constraint["srid"] = field_desc.pop("srid")
        return constraint

    def execute(self):
        """Execute the statement.

        The stored index description is left untouched, so the statement
        can be executed more than once.

        Returns:
            mysqlx.Result: Result object.

        Raises:
            ProgrammingError: If the index name or the index description is
                              invalid.
            NotSupportedError: If a unique index is requested.
            TypeError: If a field's "required" or "array" member is not a
                       Boolean.
        """
        self._validate_index_name()

        # Validate copies: the checks pop recognised keys to detect unknown
        # ones. Popping from the stored description made a second execute()
        # lose the index type and then fail on the already-popped "field".
        index_desc = copy.deepcopy(self._index_desc)
        fields_desc = copy.deepcopy(self._fields_desc)

        # Validate members that constrain the index
        if not fields_desc:
            raise ProgrammingError("Required member 'fields' not found in "
                                   "the given index description: {}"
                                   "".format(index_desc))
        if not isinstance(fields_desc, list):
            raise ProgrammingError("Required member 'fields' must contain a "
                                   "list.")

        args = {
            "name": self._index_name,
            "collection": self._target.name,
            "schema": self._target.schema.name,
            "type": index_desc.pop("type", "INDEX"),
            "unique": index_desc.pop("unique", False),
            "constraint": [],
        }
        # Currently unique indexes are not supported:
        if args["unique"]:
            raise NotSupportedError("Unique indexes are not supported.")
        if index_desc:
            raise ProgrammingError("Unidentified fields: {}"
                                   "".format(index_desc))

        for field_desc in fields_desc:
            if not isinstance(field_desc, dict):
                raise ProgrammingError("Each member of 'fields' must be a "
                                       "dict: {}".format(field_desc))
            try:
                constraint = self._build_constraint(args["type"], field_desc)
            except KeyError as err:
                raise ProgrammingError("Required inner member {} not found in "
                                       "constraint: {}".format(err, field_desc))
            args["constraint"].append(constraint)

        # Anything still left in a field description was not recognised.
        for field_desc in fields_desc:
            if field_desc:
                raise ProgrammingError("Unidentified inner fields:{}"
                                       "".format(field_desc))

        return self._connection.execute_nonquery(
            "mysqlx", "create_collection_index", True, args)
1436