1# mysql/mariadbconnector.py
2# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors
3# <see AUTHORS file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: https://www.opensource.org/licenses/mit-license.php
7
8"""
9
10.. dialect:: mysql+mariadbconnector
11    :name: MariaDB Connector/Python
12    :dbapi: mariadb
13    :connectstring: mariadb+mariadbconnector://<user>:<password>@<host>[:<port>]/<dbname>
14    :url: https://pypi.org/project/mariadb/
15
16Driver Status
17-------------
18
19MariaDB Connector/Python enables Python programs to access MariaDB and MySQL
20databases using an API which is compliant with the Python DB API 2.0 (PEP-249).
It is written in C and uses the MariaDB Connector/C client library for
client-server communication.
23
24Note that the default driver for a ``mariadb://`` connection URI continues to
25be ``mysqldb``. ``mariadb+mariadbconnector://`` is required to use this driver.
26
.. _mariadb: https://github.com/mariadb-corporation/mariadb-connector-python
28
29"""  # noqa
30import re
31
32from .base import MySQLCompiler
33from .base import MySQLDialect
34from .base import MySQLExecutionContext
35from ... import sql
36from ... import util
37
# Oldest MariaDB Connector/Python release accepted by this dialect, as a
# tuple of ints comparable with the parsed driver ``__version__``.
mariadb_cpy_minimum_version = 1, 0, 1
39
40
class MySQLExecutionContext_mariadbconnector(MySQLExecutionContext):
    """Execution context for the MariaDB Connector/Python driver.

    Selects the driver cursor's ``buffered`` flag according to the kind of
    cursor being requested.
    """

    def create_server_side_cursor(self):
        # Server-side cursors are requested as unbuffered driver cursors.
        unbuffered = False
        return self._dbapi_connection.cursor(buffered=unbuffered)

    def create_default_cursor(self):
        # Ordinary cursors are requested as buffered driver cursors.
        dbapi_conn = self._dbapi_connection
        return dbapi_conn.cursor(buffered=True)
47
48
class MySQLCompiler_mariadbconnector(MySQLCompiler):
    """SQL compiler for MariaDB Connector/Python; adds nothing to MySQLCompiler."""
51
52
class MySQLDialect_mariadbconnector(MySQLDialect):
    """MySQL/MariaDB dialect backed by the ``mariadb`` DBAPI module.

    Uses the ``qmark`` paramstyle and a fixed ``utf8mb4`` charset, and
    implements two-phase commit with explicit ``XA`` statements.
    """

    driver = "mariadbconnector"
    supports_statement_cache = True

    # set this to True at the module level to prevent the driver from running
    # against a backend that is detected as MySQL. currently this appears to
    # be unnecessary as MariaDB client libraries have always worked against
    # MySQL databases.   However, if this changes at some point, this can be
    # adjusted, but PLEASE ADD A TEST in test/dialect/mysql/test_dialect.py if
    # this change is made at some point to ensure the correct exception
    # is raised at the correct point when running the driver against
    # a MySQL backend.
    # is_mariadb = True

    supports_unicode_statements = True
    encoding = "utf8mb4"
    convert_unicode = True
    supports_sane_rowcount = True
    supports_sane_multi_rowcount = True
    supports_native_decimal = True
    default_paramstyle = "qmark"
    execution_ctx_cls = MySQLExecutionContext_mariadbconnector
    statement_compiler = MySQLCompiler_mariadbconnector

    supports_server_side_cursors = True

    @util.memoized_property
    def _dbapi_version(self):
        """Return the driver's ``__version__`` parsed into a tuple of ints.

        Each run of digits in the version string becomes one element, e.g.
        ``"1.0.1"`` -> ``(1, 0, 1)``.  When there is no DBAPI module, or it
        has no ``__version__``, the sentinel ``(99, 99, 99)`` is returned,
        which satisfies the minimum-version check in ``__init__``.
        """
        if self.dbapi and hasattr(self.dbapi, "__version__"):
            return tuple(
                int(x)
                for x in re.findall(
                    r"(\d+)(?:[-\.]?|$)", self.dbapi.__version__
                )
            )
        return (99, 99, 99)

    def __init__(self, **kwargs):
        """Initialize the dialect and enforce the minimum driver version.

        :raises NotImplementedError: if a DBAPI module is present and its
            parsed version is lower than ``mariadb_cpy_minimum_version``.
        """
        super(MySQLDialect_mariadbconnector, self).__init__(**kwargs)
        self.paramstyle = "qmark"
        if self.dbapi is not None:
            if self._dbapi_version < mariadb_cpy_minimum_version:
                raise NotImplementedError(
                    "The minimum required version for MariaDB "
                    "Connector/Python is %s"
                    % ".".join(str(x) for x in mariadb_cpy_minimum_version)
                )

    @classmethod
    def dbapi(cls):
        """Import and return the ``mariadb`` DBAPI module."""
        return __import__("mariadb")

    def is_disconnect(self, e, connection, cursor):
        """Return True if exception *e* signals a lost connection.

        The base MySQL dialect is consulted first.  After that, any driver
        ``Error`` whose message contains "not connected" or "isn't valid"
        (case-insensitively) is treated as a disconnect.
        """
        if super(MySQLDialect_mariadbconnector, self).is_disconnect(
            e, connection, cursor
        ):
            return True
        elif isinstance(e, self.dbapi.Error):
            str_e = str(e).lower()
            return "not connected" in str_e or "isn't valid" in str_e
        else:
            return False

    def create_connect_args(self, url):
        """Build the ``[args, kwargs]`` pair used to open a connection.

        The result of ``url.translate_connect_args()`` is merged with the
        URL's query-string options.  Known integer and boolean options are
        coerced from their string form.  When a DBAPI module is present,
        the driver's ``CLIENT.FOUND_ROWS`` flag is OR-ed into
        ``client_flag``.

        :param url: the URL to translate.
        :return: ``[[], opts]``, i.e. no positional arguments and every
            option passed as a keyword.
        """
        opts = url.translate_connect_args()
        # translate_connect_args() does not include query-string options
        # (e.g. ``?connect_timeout=10&ssl=true``).  Without this merge they
        # would be silently dropped, and the coercions below could never
        # apply to keys such as ``connect_timeout`` or ``ssl``.
        opts.update(url.query)

        int_params = [
            "connect_timeout",
            "read_timeout",
            "write_timeout",
            "client_flag",
            "port",
            "pool_size",
        ]
        bool_params = [
            "local_infile",
            "ssl_verify_cert",
            "ssl",
            "pool_reset_connection",
        ]

        for key in int_params:
            util.coerce_kw_type(opts, key, int)
        for key in bool_params:
            util.coerce_kw_type(opts, key, bool)

        # FOUND_ROWS must be set in CLIENT_FLAGS to enable
        # supports_sane_rowcount.
        client_flag = opts.get("client_flag", 0)
        if self.dbapi is not None:
            try:
                # __import__ of a dotted name returns the top-level
                # package, hence the attribute walk afterwards.
                CLIENT_FLAGS = __import__(
                    self.dbapi.__name__ + ".constants.CLIENT"
                ).constants.CLIENT
                client_flag |= CLIENT_FLAGS.FOUND_ROWS
            except (AttributeError, ImportError):
                # Without FOUND_ROWS, rowcounts cannot be trusted.
                self.supports_sane_rowcount = False
            opts["client_flag"] = client_flag
        return [[], opts]

    def _extract_error_code(self, exception):
        """Return the exception's ``errno``, or -1 if it has none."""
        # getattr replaces a bare ``except:``, which would also have
        # swallowed KeyboardInterrupt / SystemExit.
        return getattr(exception, "errno", -1)

    def _detect_charset(self, connection):
        """Return ``"utf8mb4"`` unconditionally, matching ``encoding``."""
        return "utf8mb4"

    # Isolation level names accepted by this dialect; "AUTOCOMMIT" is
    # handled specially in _set_isolation_level.
    _isolation_lookup = {
        "SERIALIZABLE",
        "READ UNCOMMITTED",
        "READ COMMITTED",
        "REPEATABLE READ",
        "AUTOCOMMIT",
    }

    def _set_isolation_level(self, connection, level):
        """Apply *level* to the raw DBAPI *connection*.

        ``"AUTOCOMMIT"`` only switches the driver's ``autocommit`` attribute
        on.  Any other level switches autocommit off and then defers to the
        base MySQL implementation.
        """
        if level == "AUTOCOMMIT":
            connection.autocommit = True
        else:
            connection.autocommit = False
            super(MySQLDialect_mariadbconnector, self)._set_isolation_level(
                connection, level
            )

    def _execute_xa(self, connection, verb, xid):
        """Execute ``XA <verb> <xid>`` on *connection*.

        The xid is rendered inline into the SQL (``literal_execute=True``)
        rather than sent as a driver-level bound parameter.  Presumably this
        is because XA statements do not accept placeholders; TODO confirm.
        """
        connection.execute(
            sql.text("XA %s :xid" % verb).bindparams(
                sql.bindparam("xid", xid, literal_execute=True)
            )
        )

    def do_begin_twophase(self, connection, xid):
        """Start XA transaction *xid* (``XA BEGIN``)."""
        self._execute_xa(connection, "BEGIN", xid)

    def do_prepare_twophase(self, connection, xid):
        """End and prepare XA transaction *xid* (``XA END``, ``XA PREPARE``)."""
        self._execute_xa(connection, "END", xid)
        self._execute_xa(connection, "PREPARE", xid)

    def do_rollback_twophase(
        self, connection, xid, is_prepared=True, recover=False
    ):
        """Roll back XA transaction *xid*.

        When *is_prepared* is false, ``XA END`` is issued before
        ``XA ROLLBACK``.  *recover* is accepted but not used here.
        """
        if not is_prepared:
            self._execute_xa(connection, "END", xid)
        self._execute_xa(connection, "ROLLBACK", xid)

    def do_commit_twophase(
        self, connection, xid, is_prepared=True, recover=False
    ):
        """Commit XA transaction *xid*.

        When *is_prepared* is false, the transaction is first ended and
        prepared via :meth:`do_prepare_twophase`, then ``XA COMMIT`` is
        issued.  *recover* is accepted but not used here.
        """
        if not is_prepared:
            self.do_prepare_twophase(connection, xid)
        self._execute_xa(connection, "COMMIT", xid)
229
230
# Module-level alias for this module's dialect class.  The name ``dialect``
# is presumably what SQLAlchemy's dialect loader resolves for this module;
# confirm before renaming.
dialect = MySQLDialect_mariadbconnector
232