1# Copyright (c) 2016, 2020, Oracle and/or its affiliates.
2#
3# This program is free software; you can redistribute it and/or modify
4# it under the terms of the GNU General Public License, version 2.0, as
5# published by the Free Software Foundation.
6#
7# This program is also distributed with certain software (including
8# but not limited to OpenSSL) that is licensed under separate terms,
9# as designated in a particular file or component or in included license
10# documentation.  The authors of MySQL hereby grant you an
11# additional permission to link the program and your derivative works
12# with the separately licensed software that they have included with
13# MySQL.
14#
15# Without limiting anything contained in the foregoing, this file,
16# which is part of MySQL Connector/Python, is also subject to the
17# Universal FOSS Exception, version 1.0, a copy of which can be found at
18# http://oss.oracle.com/licenses/universal-foss-exception.
19#
20# This program is distributed in the hope that it will be useful, but
21# WITHOUT ANY WARRANTY; without even the implied warranty of
22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
23# See the GNU General Public License, version 2.0, for more details.
24#
25# You should have received a copy of the GNU General Public License
26# along with this program; if not, write to the Free Software Foundation, Inc.,
27# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301  USA
28
29"""MySQL X DevAPI Python implementation"""
30
31import re
32import json
33import logging
34import ssl
35
36from urllib.parse import parse_qsl, unquote, urlparse
37
38try:
39    from json.decoder import JSONDecodeError
40except ImportError:
41    JSONDecodeError = ValueError
42
43from .connection import Client, Session
44from .constants import (Auth, LockContention, OPENSSL_CS_NAMES, SSLMode,
45                        TLS_VERSIONS, TLS_CIPHER_SUITES)
46from .crud import Schema, Collection, Table, View
47from .dbdoc import DbDoc
48# pylint: disable=W0622
49from .errors import (Error, InterfaceError, DatabaseError, NotSupportedError,
50                     DataError, IntegrityError, ProgrammingError,
51                     OperationalError, InternalError, PoolError, TimeoutError)
52from .result import (Column, Row, Result, BufferingResult, RowResult,
53                     SqlResult, DocResult, ColumnType)
54from .statement import (Statement, FilterableStatement, SqlStatement,
55                        FindStatement, AddStatement, RemoveStatement,
56                        ModifyStatement, SelectStatement, InsertStatement,
57                        DeleteStatement, UpdateStatement,
58                        CreateCollectionIndexStatement, Expr, ReadStatement,
59                        WriteStatement)
60from .expr import ExprParser as expr
61
62
# --- Connection-string parsing patterns ------------------------------------
# Splits on commas, except those followed by a ")" with no parenthesis in
# between (i.e. commas inside a "(address=...,priority=...)" group).
_SPLIT_RE = re.compile(r",(?![^\(\)]*\))")
# Router entry carrying an explicit priority: "(address=<addr>,priority=<n>)".
_PRIORITY_RE = re.compile(r"^\(address=(.+),priority=(\d+)\)$", re.VERBOSE)
# Router entry without a priority: "(address=<addr>)".
_ROUTER_RE = re.compile(r"^\(address=(.+)[,]*\)$", re.VERBOSE)
# "<scheme>://<rest>": captures the scheme and everything after "://".
_URI_SCHEME_RE = re.compile(r"^([a-zA-Z][a-zA-Z0-9+\-.]+)://(.*)")

# --- Recognised option names -----------------------------------------------
_SSL_OPTS = [
    "ssl-cert",
    "ssl-ca",
    "ssl-key",
    "ssl-crl",
    "tls-versions",
    "tls-ciphersuites",
]
_SESS_OPTS = [
    *_SSL_OPTS,
    "user",
    "password",
    "schema",
    "host",
    "port",
    "routers",
    "socket",
    "ssl-mode",
    "auth",
    "use-pure",
    "connect-timeout",
    "connection-attributes",
    "compression",
    "compression-algorithms",
    "dns-srv",
]

# Attach a no-op handler to this module's logger.
logging.getLogger(__name__).addHandler(logging.NullHandler())

# --- Error message templates -----------------------------------------------
DUPLICATED_IN_LIST_ERROR = (
    "The '{list}' list must not contain repeated values, "
    "the value '{value}' is duplicated."
)

TLS_VERSION_ERROR = (
    "The given tls-version: '{}' is not recognized as a valid "
    "TLS protocol version (should be one of {})."
)

TLS_VER_NO_SUPPORTED = (
    "No supported TLS protocol version found in the 'tls-versions' "
    "list '{}'. "
)

# NOTE: this assignment rebinds the TLS_VERSIONS name that is also imported
# from .constants at the top of the file.
TLS_VERSIONS = ["TLSv1", "TLSv1.1", "TLSv1.2", "TLSv1.3"]

# True only when the ssl module exposes a truthy HAS_TLSv1_3 attribute.
TLS_V1_3_SUPPORTED = bool(getattr(ssl, "HAS_TLSv1_3", False))
91
92
def _parse_address_list(path):
    """Parse a single address or a bracketed list of router addresses.

    All spaces are removed from ``path`` before parsing. Each address may be
    written plainly (``host[:port]``) or wrapped as ``(address=host[:port])``
    or ``(address=host[:port],priority=n)``.

    Args:
        path: String containing a list of routers or just one router.

    Returns:
        dict: ``{"routers": [router, ...]}`` when ``path`` is a bracketed
        list; otherwise the dict of the single parsed address. Each router
        dict holds ``host`` and ``port`` (``None`` when no port is given)
        and, for the parenthesised forms, ``priority`` (100 when the
        ``(address=...)`` form carries no explicit priority).

    Raises:
        :class:`mysqlx.InterfaceError`: If an address is malformed or has no
                                        hostname.
        :class:`mysqlx.ProgrammingError`: If a port is invalid (code 4002) or
                                          if only some of the routers were
                                          given a priority (code 4000).
    """
    path = path.replace(" ", "")
    # No comma, several colons and exactly one "[": presumably a lone
    # bracketed IPv6 literal such as "[::1]", which must not be unwrapped as
    # a list even though it starts with "[" and ends with "]".
    single_bracketed_host = ("," not in path and path.count(":") > 1
                             and path.count("[") == 1)
    array = (not single_bracketed_host and path.startswith("[")
             and path.endswith("]"))

    routers = []
    address_list = _SPLIT_RE.split(path[1:-1] if array else path)
    priority_count = 0
    for address in address_list:
        router = {}

        match = _PRIORITY_RE.match(address)
        if match:
            address = match.group(1)
            router["priority"] = int(match.group(2))
            priority_count += 1
        else:
            match = _ROUTER_RE.match(address)
            if match:
                address = match.group(1)
                router["priority"] = 100

        # urlparse() itself raises ValueError for malformed bracketed hosts
        # (e.g. "[::1" without the closing bracket); report that as an
        # invalid address instead of leaking a bare ValueError.
        try:
            parsed = urlparse("//{0}".format(address))
        except ValueError as err:
            raise InterfaceError(
                "Invalid address: {0}".format(address)) from err
        if not parsed.hostname:
            raise InterfaceError("Invalid address: {0}".format(address))

        # Reading .port raises ValueError when the port is not a valid
        # port number.
        try:
            router.update(host=parsed.hostname, port=parsed.port)
        except ValueError as err:
            raise ProgrammingError(
                "Invalid URI: {0}".format(err), 4002) from err

        routers.append(router)

    # Explicit priorities are all-or-nothing across the list.
    if 0 < priority_count < len(address_list):
        raise ProgrammingError("You must either assign no priority to any "
                               "of the routers or give a priority for "
                               "every router", 4000)

    return {"routers": routers} if array else routers[0]
142
143
def _parse_connection_uri(uri):
    """Parses the connection string and returns a dictionary with the
    connection settings.

    The expected layout is
    ``[scheme://]user:password@host[/schema][?option=value&...]``. When no
    scheme is present, ``mysqlx`` is assumed.

    Args:
        uri: mysqlx URI scheme to connect to a MySQL server/farm.

    Returns:
        Returns a dict with parsed values of credentials and address of the
        MySQL server/farm. It always contains ``schema`` (empty string when
        not given), ``user`` and ``password``, plus either ``socket`` or the
        result of parsing the host part as an address list, and every option
        found in the query string.

    Raises:
        :class:`mysqlx.InterfaceError`: If contains a duplicate or invalid
                                        option, the URI is malformed, a
                                        Windows pipe is given, or the URI
                                        scheme is not valid.
    """
    settings = {"schema": ""}

    match = _URI_SCHEME_RE.match(uri)
    scheme, uri = match.groups() if match else ("mysqlx", uri)

    if scheme not in ("mysqlx", "mysqlx+srv"):
        raise InterfaceError("Scheme '{0}' is not valid".format(scheme))

    if scheme == "mysqlx+srv":
        settings["dns-srv"] = True

    # partition(...)[::2] keeps the text before and after the separator.
    userinfo, tmp = uri.partition("@")[::2]
    host, query_str = tmp.partition("?")[::2]

    # The last "/" separates the schema from the host, unless a ")" follows
    # it (the slash then belongs to a parenthesised path) or it is the very
    # first character.
    pos = host.rfind("/")
    if host[pos:].find(")") == -1 and pos > 0:
        host, settings["schema"] = host.rsplit("/", 1)
    host = host.strip("()")

    if not host or not userinfo or ":" not in userinfo:
        raise InterfaceError("Malformed URI '{0}'".format(uri))
    user, password = userinfo.split(":", 1)
    settings["user"], settings["password"] = unquote(user), unquote(password)

    if host.startswith(("/", "..", ".")):
        settings["socket"] = unquote(host)
    elif host.startswith(("\\\\.", "\\.")):
        # Windows named pipes are written "\\.\pipe\<name>" (two leading
        # backslashes). The single-backslash prefix is kept as well so that
        # every input rejected before is still rejected.
        raise InterfaceError("Windows Pipe is not supported")
    else:
        settings.update(_parse_address_list(host))

    # These can only be provided through the userinfo part or the scheme.
    invalid_options = ("user", "password", "dns-srv")
    for key, val in parse_qsl(query_str, True):
        opt = key.replace("_", "-").lower()
        if opt in invalid_options:
            raise InterfaceError("Invalid option: '{0}'".format(key))
        if opt in settings:
            raise InterfaceError("Duplicate option: '{0}'".format(key))
        if opt in _SSL_OPTS:
            # SSL/TLS option values keep their case (e.g. file paths).
            settings[opt] = unquote(val.strip("()"))
        else:
            # Other values are lower-cased; "1"/"true" and "0"/"false"
            # become booleans.
            val_str = val.lower()
            if val_str in ("1", "true"):
                settings[opt] = True
            elif val_str in ("0", "false"):
                settings[opt] = False
            else:
                settings[opt] = val_str
    return settings
208
209
def _validate_settings(settings):
    """Validates the settings to be passed to a Session object
    the port values are converted to int if specified or set to 33060
    otherwise. The priority values for each router is converted to int
    if specified.

    The dictionary is validated and normalized in place: ``ssl-mode``,
    ``auth`` and ``compression`` are lower-cased, a string
    ``compression-algorithms`` is split into a list, ``connect-timeout`` is
    converted to int, and ``connection-attributes``, ``tls-versions`` and
    ``tls-ciphersuites`` are replaced by their validated forms.

    Args:
        settings: dict containing connection settings.

    Raises:
        :class:`mysqlx.InterfaceError`: On any configuration issue.
        TypeError: If the connection timeout is not a non-negative integer.
    """
    invalid_opts = set(settings.keys()).difference(_SESS_OPTS)
    if invalid_opts:
        # Sorted so the error message is deterministic.
        raise InterfaceError("Invalid option(s): '{0}'"
                             "".format("', '".join(sorted(invalid_opts))))

    if "routers" in settings:
        for router in settings["routers"]:
            _validate_hosts(router, 33060)
    elif "host" in settings:
        _validate_hosts(settings)

    if "ssl-mode" in settings:
        try:
            settings["ssl-mode"] = settings["ssl-mode"].lower()
            SSLMode.index(settings["ssl-mode"])
        except (AttributeError, ValueError):
            raise InterfaceError("Invalid SSL Mode '{0}'"
                                 "".format(settings["ssl-mode"]))
        if settings["ssl-mode"] == SSLMode.DISABLED and \
                any(key in settings for key in _SSL_OPTS):
            raise InterfaceError("SSL options used with ssl-mode 'disabled'")

    if "ssl-crl" in settings and not "ssl-ca" in settings:
        raise InterfaceError("CA Certificate not provided")
    if "ssl-key" in settings and not "ssl-cert" in settings:
        raise InterfaceError("Client Certificate not provided")

    # A CA certificate and a verifying ssl-mode must be given together.
    if not "ssl-ca" in settings and settings.get("ssl-mode") \
            in [SSLMode.VERIFY_IDENTITY, SSLMode.VERIFY_CA]:
        raise InterfaceError("Cannot verify Server without CA")
    if "ssl-ca" in settings and settings.get("ssl-mode") \
            not in [SSLMode.VERIFY_IDENTITY, SSLMode.VERIFY_CA]:
        raise InterfaceError("Must verify Server if CA is provided")

    if "auth" in settings:
        try:
            settings["auth"] = settings["auth"].lower()
            Auth.index(settings["auth"])
        except (AttributeError, ValueError):
            raise InterfaceError("Invalid Auth '{0}'".format(settings["auth"]))

    if "compression" in settings:
        try:
            compression = settings["compression"].lower().strip()
        except AttributeError:
            # Non-string values (e.g. a boolean) are rejected below with the
            # same InterfaceError as any other unacceptable value.
            compression = None
        if compression not in ("preferred", "required", "disabled"):
            raise InterfaceError(
                "The connection property 'compression' acceptable values are: "
                "'preferred', 'required', or 'disabled'. The value '{0}' is "
                "not acceptable".format(settings["compression"]))
        settings["compression"] = compression

    if "compression-algorithms" in settings:
        if isinstance(settings["compression-algorithms"], str):
            compression_algorithms = \
                settings["compression-algorithms"].strip().strip("[]")
            if compression_algorithms:
                settings["compression-algorithms"] = \
                    compression_algorithms.split(",")
            else:
                settings["compression-algorithms"] = None
        elif not isinstance(settings["compression-algorithms"], (list, tuple)):
            raise InterfaceError("Invalid type of the connection property "
                                 "'compression-algorithms'")
        # The algorithm list is dropped when compression is disabled.
        if settings.get("compression") == "disabled":
            settings["compression-algorithms"] = None

    if "connection-attributes" in settings:
        _validate_connection_attributes(settings)

    if "connect-timeout" in settings:
        try:
            if isinstance(settings["connect-timeout"], str):
                settings["connect-timeout"] = int(settings["connect-timeout"])
            if not isinstance(settings["connect-timeout"], int) \
               or settings["connect-timeout"] < 0:
                raise ValueError
        except ValueError:
            raise TypeError("The connection timeout value must be a positive "
                            "integer (including 0)")

    if "dns-srv" in settings:
        if not isinstance(settings["dns-srv"], bool):
            raise InterfaceError("The value of 'dns-srv' must be a boolean")
        if settings.get("socket"):
            raise InterfaceError("Using Unix domain sockets with DNS SRV "
                                 "lookup is not allowed")
        if settings.get("port"):
            raise InterfaceError("Specifying a port number with DNS SRV "
                                 "lookup is not allowed")
        if settings.get("routers"):
            raise InterfaceError("Specifying multiple hostnames with DNS "
                                 "SRV look up is not allowed")
    elif "host" in settings and not settings.get("port"):
        # Default X Protocol port when a host is given without a port.
        settings["port"] = 33060

    if "tls-versions" in settings:
        _validate_tls_versions(settings)

    if "tls-ciphersuites" in settings:
        _validate_tls_ciphersuites(settings)
321
322
def _validate_hosts(settings, default_port=None):
    """Validate hosts.

    Converts the ``priority`` and ``port`` entries of ``settings`` to int in
    place. Falsy values (missing, ``None``, ``0``, ``""``) are left
    untouched; when the port is missing or falsy, a ``host`` is present and
    ``default_port`` is given, the port is set to ``default_port``.

    Args:
        settings (dict): Settings dictionary.
        default_port (int): Default connection port.

    Raises:
        :class:`mysqlx.ProgrammingError`: If priority is not an integer in
                                          the range 0-100 (code 4007).
        :class:`mysqlx.InterfaceError`: If port is not a valid integer.
    """
    if "priority" in settings and settings["priority"]:
        try:
            settings["priority"] = int(settings["priority"])
        except (TypeError, ValueError) as err:
            # The assignment did not happen, so the original value is
            # still available for the message.
            raise ProgrammingError(
                "Invalid priority: {}".format(settings["priority"]),
                4007) from err
        if settings["priority"] < 0 or settings["priority"] > 100:
            raise ProgrammingError("Invalid priority value, "
                                   "must be between 0 and 100", 4007)

    if "port" in settings and settings["port"]:
        try:
            settings["port"] = int(settings["port"])
        except (TypeError, ValueError) as err:
            # int() signals bad input with ValueError/TypeError (never
            # NameError); translate it into the documented error type.
            raise InterfaceError(
                "Invalid port: {}".format(settings["port"])) from err
    elif "host" in settings and default_port:
        settings["port"] = default_port
352
353
def _parse_connection_attribute_pairs(pairs):
    """Build an attribute dict from an iterable of "name=value" strings.

    Empty strings are skipped and a pair without "=" gets an empty value.

    Raises:
        :class:`mysqlx.InterfaceError`: If a pair is not a string or a name
                                        is repeated.
    """
    attributes = {}
    for pair in pairs:
        if pair == "":
            continue
        if not isinstance(pair, str):
            raise InterfaceError("connection-attributes must be Boolean or a "
                                 "list of key-value pairs, found: '{}'"
                                 "".format(pair))
        # Split on the first "=" only, so a value may itself contain "=".
        attr_name, _, attr_val = pair.partition("=")
        if attr_name in attributes:
            raise InterfaceError("Duplicate key '{}' used in "
                                 "connection-attributes"
                                 "".format(attr_name))
        attributes[attr_name] = attr_val
    return attributes


def _validate_connection_attributes(settings):
    """Validate connection-attributes.

    Normalizes ``settings["connection-attributes"]`` in place. Accepted
    inputs and their results:

    * ``""``, ``"True"``/``"true"``, ``True`` or ``1``: ``{}``.
    * ``"False"``/``"false"``, ``False`` or ``0``: ``False``.
    * a ``"[name=value,...]"`` string or a list of ``"name=value"``
      strings: a dict of names to values.
    * a dict: the same mapping, with non-string values replaced by their
      ``repr()``.
    * a set: a dict mapping every element to ``""``.

    Does nothing when the key is absent.

    Args:
        settings (dict): Settings dictionary.

    Raises:
        :class:`mysqlx.InterfaceError`: If the value has an unsupported type
            or format, a key is duplicated, a name is not a string, starts
            with "_" or exceeds 32 characters, or a value is not a string or
            exceeds 1024 characters.
    """
    if "connection-attributes" not in settings:
        return

    conn_attrs = settings["connection-attributes"]
    attributes = {}

    if isinstance(conn_attrs, str):
        if conn_attrs == "":
            settings["connection-attributes"] = {}
            return
        if conn_attrs in ("False", "false"):
            settings["connection-attributes"] = False
            return
        if conn_attrs in ("True", "true"):
            settings["connection-attributes"] = {}
            return
        if not (conn_attrs.startswith("[") and conn_attrs.endswith("]")):
            raise InterfaceError("The value of 'connection-attributes' must "
                                 "be a boolean or a list of key-value pairs, "
                                 "found: '{}'".format(conn_attrs))
        attributes = _parse_connection_attribute_pairs(
            conn_attrs[1:-1].split(","))
    elif isinstance(conn_attrs, dict):
        for attr_name, attr_value in conn_attrs.items():
            if not isinstance(attr_value, str):
                attr_value = repr(attr_value)
            attributes[attr_name] = attr_value
    elif isinstance(conn_attrs, bool) or conn_attrs in [0, 1]:
        settings["connection-attributes"] = {} if conn_attrs else False
        return
    elif isinstance(conn_attrs, set):
        for attr_name in conn_attrs:
            attributes[attr_name] = ""
    elif isinstance(conn_attrs, list):
        attributes = _parse_connection_attribute_pairs(conn_attrs)
    else:
        raise InterfaceError("connection-attributes must be Boolean or a list "
                             "of key-value pairs, found: '{}'"
                             "".format(conn_attrs))

    for attr_name, attr_value in attributes.items():
        # Validate name type
        if not isinstance(attr_name, str):
            raise InterfaceError("Attribute name '{}' must be a string "
                                 "type".format(attr_name))
        # Validate attribute name limit 32 characters
        if len(attr_name) > 32:
            raise InterfaceError("Attribute name '{}' exceeds 32 "
                                 "characters limit size".format(attr_name))
        # Validate names in connection-attributes cannot start with "_"
        if attr_name.startswith("_"):
            raise InterfaceError("Key names in connection-attributes "
                                 "cannot start with '_', found: '{}'"
                                 "".format(attr_name))

        # Validate value type
        if not isinstance(attr_value, str):
            raise InterfaceError("Attribute '{}' value: '{}' must "
                                 "be a string type"
                                 "".format(attr_name, attr_value))
        # Validate attribute value limit 1024 characters
        if len(attr_value) > 1024:
            raise InterfaceError("Attribute '{}' value: '{}' "
                                 "exceeds 1024 characters limit size"
                                 "".format(attr_name, attr_value))

    settings["connection-attributes"] = attributes
461
462
def _validate_tls_versions(settings):
    """Validate tls-versions.

    Replaces ``settings["tls-versions"]`` with a list of the recognized TLS
    version names it contains. The value may be a ``"[v1,v2,...]"`` string
    (entries are stripped of surrounding whitespace and empty entries are
    ignored), a list or a set. Does nothing when the key is absent.

    Args:
        settings (dict): Settings dictionary.

    Raises:
        :class:`mysqlx.InterfaceError`: If the value has an unsupported type
            or format, contains an unrecognized or repeated version, yields
            no version at all, or names only TLSv1.3 when the ssl module
            does not support it.
    """
    if "tls-versions" not in settings:
        return

    tls_versions_settings = settings["tls-versions"]

    if isinstance(tls_versions_settings, str):
        if not (tls_versions_settings.startswith("[") and
                tls_versions_settings.endswith("]")):
            raise InterfaceError("tls-versions must be a list, found: '{}'"
                                 "".format(tls_versions_settings))
        stripped = (tls_ver.strip()
                    for tls_ver in tls_versions_settings[1:-1].split(","))
        candidates = [tls_ver for tls_ver in stripped if tls_ver != ""]
    elif isinstance(tls_versions_settings, (list, set)):
        candidates = tls_versions_settings
    else:
        # Report the offending value itself (not the still-empty result
        # list) so the message shows what was actually received.
        raise InterfaceError("tls-versions should be a list with one or more "
                             "of versions in {}. found: '{}'"
                             "".format(", ".join(TLS_VERSIONS),
                                       tls_versions_settings))

    tls_versions = []
    for tls_version in candidates:
        if tls_version not in TLS_VERSIONS:
            raise InterfaceError(TLS_VERSION_ERROR.format(tls_version,
                                                          TLS_VERSIONS))
        if tls_version in tls_versions:
            raise InterfaceError(
                DUPLICATED_IN_LIST_ERROR.format(list="tls_versions",
                                                value=tls_version))
        tls_versions.append(tls_version)

    if not tls_versions:
        raise InterfaceError("At least one TLS protocol version must be "
                             "specified in 'tls-versions' list.")

    # TLSv1.3 alone is unusable when the ssl module lacks support for it.
    if tls_versions == ["TLSv1.3"] and not TLS_V1_3_SUPPORTED:
        raise InterfaceError(TLS_VER_NO_SUPPORTED.format(tls_versions))

    settings["tls-versions"] = tls_versions
534
535
def _validate_tls_ciphersuites(settings):
    """Validate tls-ciphersuites.

    Replaces ``settings["tls-ciphersuites"]`` with a list of cipher suite
    names. The value may be a ``"[cs1,cs2,...]"`` string (entries are
    stripped and upper-cased), a list or a set; empty entries are ignored.
    A name containing "-" that is listed in ``OPENSSL_CS_NAMES`` is kept as
    is; any other name is looked up in ``TLS_CIPHER_SUITES`` and replaced by
    the value found there. Only the tables of the newest TLS version in
    ``settings["tls-versions"]`` (or in ``TLS_VERSIONS`` when not set) and
    of the versions before it are consulted. Does nothing when the key is
    absent.

    Args:
        settings (dict): Settings dictionary.

    Raises:
        :class:`mysqlx.InterfaceError`: If tls-ciphersuites name is not
            valid, the value has an unsupported type or format, or no cipher
            suite remains.
        AttributeError: If two names translate to the same cipher suite.
    """
    if "tls-ciphersuites" not in settings:
        return

    tls_ciphersuites_settings = settings["tls-ciphersuites"]

    if isinstance(tls_ciphersuites_settings, str):
        if not (tls_ciphersuites_settings.startswith("[") and
                tls_ciphersuites_settings.endswith("]")):
            raise InterfaceError("tls-ciphersuites must be a list, found: '{}'"
                                 "".format(tls_ciphersuites_settings))
        normalized = (tls_cs.strip().upper()
                      for tls_cs in tls_ciphersuites_settings[1:-1].split(","))
        tls_ciphersuites = [tls_cs for tls_cs in normalized if tls_cs]
    elif isinstance(tls_ciphersuites_settings, (list, set)):
        # Iterate the provided collection itself; sets are handled exactly
        # like lists.
        tls_ciphersuites = [tls_cs for tls_cs in tls_ciphersuites_settings
                            if tls_cs]
    else:
        raise InterfaceError("tls-ciphersuites should be a list with one or "
                             "more ciphersuites. Found: '{}'"
                             "".format(tls_ciphersuites_settings))

    tls_versions = TLS_VERSIONS[:] if settings.get("tls-versions", None) \
        is None else settings["tls-versions"][:]

    # A newer TLS version can use a cipher introduced on
    # an older version.
    tls_versions.sort(reverse=True)
    newer_tls_ver = tls_versions[0]

    translated_names = []
    iani_cipher_suites_names = {}
    ossl_cipher_suites_names = []

    # Old ciphers can work with new TLS versions.
    # Find all the ciphers introduced on previous TLS versions
    for tls_ver in TLS_VERSIONS[:TLS_VERSIONS.index(newer_tls_ver) + 1]:
        iani_cipher_suites_names.update(TLS_CIPHER_SUITES[tls_ver])
        ossl_cipher_suites_names.extend(OPENSSL_CS_NAMES[tls_ver])

    for name in tls_ciphersuites:
        if "-" in name and name in ossl_cipher_suites_names:
            translated_names.append(name)
        elif name in iani_cipher_suites_names:
            translated_name = iani_cipher_suites_names[name]
            if translated_name in translated_names:
                # NOTE(review): AttributeError (rather than InterfaceError)
                # is kept here because callers may already rely on it.
                raise AttributeError(
                    DUPLICATED_IN_LIST_ERROR.format(
                        list="tls_ciphersuites", value=translated_name))
            translated_names.append(translated_name)
        else:
            raise InterfaceError(
                "The value '{}' in cipher suites is not a valid "
                "cipher suite".format(name))

    if not translated_names:
        raise InterfaceError("No valid cipher suite found in the "
                             "'tls-ciphersuites' list.")

    settings["tls-ciphersuites"] = translated_names
617
618
def _get_connection_settings(*args, **kwargs):
    """Build the connection settings dictionary and validate it.

    The settings come from the first positional argument when one is given
    (a connection string is parsed, a dictionary is copied); keyword
    arguments are only used when there are no positional arguments.
    Underscores in dictionary/keyword names are replaced by dashes.

    Args:
        *args: Variable length argument list with the connection data used
               to connect to the database. It can be a dictionary or a
               connection string.
        **kwargs: Arbitrary keyword arguments with connection data used to
                  connect to the database.

    Returns:
        dict: The validated connection settings.

    Raises:
        TypeError: If connection timeout is not a positive integer.
        :class:`mysqlx.InterfaceError`: If settings not provided.
    """
    settings = {}
    raw_options = None

    if args:
        first = args[0]
        if isinstance(first, str):
            settings = _parse_connection_uri(first)
        elif isinstance(first, dict):
            raw_options = first
    elif kwargs:
        raw_options = kwargs

    if raw_options is not None:
        settings = {name.replace("_", "-"): value
                    for name, value in raw_options.items()}

    if not settings:
        raise InterfaceError("Settings not provided")

    _validate_settings(settings)
    return settings
653
654
def get_session(*args, **kwargs):
    """Open a Session from the given connection data.

    The arguments are turned into a validated settings dictionary, which is
    then handed to :class:`Session`.

    Args:
        *args: Variable length argument list with the connection data used
               to connect to a MySQL server. It can be a dictionary or a
               connection string.
        **kwargs: Arbitrary keyword arguments with connection data used to
                  connect to the database.

    Returns:
        mysqlx.Session: Session object.
    """
    return Session(_get_connection_settings(*args, **kwargs))
670
671
def get_client(connection_string, options_string):
    """Creates a Client instance with the provided connection data and settings.

    Args:
        connection_string: A string or a dict type object to indicate the \
            connection data used to connect to a MySQL server.

            The string must have the following uri format::

                cnx_str = 'mysqlx://{user}:{pwd}@{host}:{port}'
                cnx_str = ('mysqlx://{user}:{pwd}@['
                           '    (address={host}:{port}, priority=n),'
                           '    (address={host}:{port}, priority=n), ...]'
                           '       ?[option=value]')

            And the dictionary::

                cnx_dict = {
                    'host': 'The host where the MySQL product is running',
                    'port': '(int) the port number configured for X protocol',
                    'user': 'The user name account',
                    'password': 'The password for the given user account',
                    'ssl-mode': 'The flags for ssl mode in mysqlx.SSLMode.FLAG',
                    'ssl-ca': 'The path to the ca.cert',
                    "connect-timeout": '(int) milliseconds to wait on timeout'
                }

        options_string: A string in the form of a document or a dictionary \
            type with configuration for the client.

            Current options include::

                options = {
                    'pooling': {
                        'enabled': (bool), # [True | False], True by default
                        'max_size': (int), # Maximum connections per pool
                        "max_idle_time": (int), # milliseconds that a
                            # connection will remain active while not in use.
                            # By default 0, means infinite.
                        "queue_timeout": (int), # milliseconds a request will
                            # wait for a connection to become available.
                            # By default 0, means infinite.
                    }
                }

    Returns:
        mysqlx.Client: Client object.

    Raises:
        :class:`mysqlx.InterfaceError`: If an argument has the wrong type,
            the options string is not a valid JSON document, or
            unrecognized options are found alongside 'pooling'.

    .. versionadded:: 8.0.13
    """
    if not isinstance(connection_string, (str, dict)):
        raise InterfaceError("connection_data must be a string or dict")

    settings_dict = _get_connection_settings(connection_string)

    if not isinstance(options_string, (str, dict)):
        raise InterfaceError("connection_options must be a string or dict")

    if isinstance(options_string, str):
        try:
            options_dict = json.loads(options_string)
        except JSONDecodeError:
            raise InterfaceError("'pooling' options must be given in the form "
                                 "of a document or dict")
    else:
        # Top-level keys may use dashes; normalize them to underscores.
        options_dict = {}
        for key, value in options_string.items():
            options_dict[key.replace("-", "_")] = value

    if not isinstance(options_dict, dict):
        raise InterfaceError("'pooling' options must be given in the form of a "
                             "document or dict")
    pooling_options_dict = {}
    if "pooling" in options_dict:
        pooling_options = options_dict.pop("pooling")
        if not isinstance(pooling_options, (dict)):
            raise InterfaceError("'pooling' options must be given in the form "
                                 "document or dict")
        # Work on a copy: the pop() calls below would otherwise empty the
        # caller's own 'pooling' dict, so reusing the same options object
        # for a second client would silently fall back to the defaults.
        pooling_options = dict(pooling_options)

        # Fill default pooling settings
        pooling_options_dict["enabled"] = pooling_options.pop("enabled", True)
        pooling_options_dict["max_size"] = pooling_options.pop("max_size", 25)
        pooling_options_dict["max_idle_time"] = \
            pooling_options.pop("max_idle_time", 0)
        pooling_options_dict["queue_timeout"] = \
            pooling_options.pop("queue_timeout", 0)

        # Anything left over is not a known pooling option
        if len(pooling_options) > 0:
            raise InterfaceError("Unrecognized pooling options: {}"
                                 "".format(pooling_options))
        # No other options besides pooling are supported
        # NOTE(review): this check only runs when 'pooling' is present;
        # unknown top-level options are otherwise ignored. Kept as is to
        # avoid rejecting inputs that are accepted today.
        if len(options_dict) > 0:
            raise InterfaceError("Unrecognized connection options: {}"
                                 "".format(options_dict.keys()))

    return Client(settings_dict, pooling_options_dict)
768
769
# Public names of the package, grouped by the module they are imported from.
__all__ = [
    # mysqlx.connection (plus this module's get_client/get_session and
    # "expr", the ExprParser alias imported from mysqlx.expr)
    "Client", "Session", "get_client", "get_session", "expr",

    # mysqlx.constants
    "Auth", "LockContention", "SSLMode",

    # mysqlx.crud
    "Schema", "Collection", "Table", "View",

    # mysqlx.errors
    "Error", "InterfaceError", "DatabaseError", "NotSupportedError",
    "DataError", "IntegrityError", "ProgrammingError", "OperationalError",
    "InternalError", "PoolError", "TimeoutError",

    # mysqlx.result
    "Column", "Row", "Result", "BufferingResult", "RowResult",
    "SqlResult", "DocResult", "ColumnType",

    # mysqlx.statement (plus "DbDoc", imported from mysqlx.dbdoc)
    "DbDoc", "Statement", "FilterableStatement", "SqlStatement",
    "FindStatement", "AddStatement", "RemoveStatement", "ModifyStatement",
    "SelectStatement", "InsertStatement", "DeleteStatement", "UpdateStatement",
    "ReadStatement", "WriteStatement", "CreateCollectionIndexStatement",
    "Expr",
]
796