1# mysql/mysqldb.py
2# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors
3# <see AUTHORS file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: http://www.opensource.org/licenses/mit-license.php
7
8"""
9
10.. dialect:: mysql+mysqldb
11    :name: mysqlclient (maintained fork of MySQL-Python)
12    :dbapi: mysqldb
13    :connectstring: mysql+mysqldb://<user>:<password>@<host>[:<port>]/<dbname>
14    :url: https://pypi.org/project/mysqlclient/
15
16Driver Status
17-------------
18
The mysqlclient DBAPI is a maintained fork of the
`MySQL-Python <http://sourceforge.net/projects/mysql-python>`_ DBAPI,
which is itself no longer maintained.  `mysqlclient`_ supports Python 2
and Python 3 and is very stable.
23
24.. _mysqlclient: https://github.com/PyMySQL/mysqlclient-python
25
26.. _mysqldb_unicode:
27
28Unicode
29-------
30
31Please see :ref:`mysql_unicode` for current recommendations on unicode
32handling.
33
34
35Using MySQLdb with Google Cloud SQL
36-----------------------------------
37
38Google Cloud SQL now recommends use of the MySQLdb dialect.  Connect
39using a URL like the following::
40
41    mysql+mysqldb://root@/<dbname>?unix_socket=/cloudsql/<projectid>:<instancename>
42
43Server Side Cursors
44-------------------
45
46The mysqldb dialect supports server-side cursors. See :ref:`mysql_ss_cursors`.
47
48"""
49
50import re
51
52from .base import MySQLCompiler
53from .base import MySQLDialect
54from .base import MySQLExecutionContext
55from .base import MySQLIdentifierPreparer
56from .base import TEXT
57from ... import sql
58from ... import util
59
60
class MySQLExecutionContext_mysqldb(MySQLExecutionContext):
    """Execution context used by the mysqldb dialect."""

    @property
    def rowcount(self):
        """Row count for the most recent execution.

        A ``_rowcount`` attribute set on this context takes precedence
        (``MySQLDialect_mysqldb.do_executemany`` assigns one); when it is
        absent, the DBAPI cursor's ``rowcount`` is returned instead.
        """
        if not hasattr(self, "_rowcount"):
            return self.cursor.rowcount
        return self._rowcount
68
69
class MySQLCompiler_mysqldb(MySQLCompiler):
    """SQL compiler for the mysqldb dialect.

    Adds no behavior of its own on top of :class:`MySQLCompiler`; the
    dialect references it as its ``statement_compiler``.
    """
72
73
class MySQLIdentifierPreparer_mysqldb(MySQLIdentifierPreparer):
    """Identifier preparer for the mysqldb dialect.

    Adds no behavior of its own on top of
    :class:`MySQLIdentifierPreparer`; the dialect references it as its
    ``preparer``.
    """
76
77
class MySQLDialect_mysqldb(MySQLDialect):
    """MySQL dialect for the mysqlclient / MySQLdb DBAPI (``mysql+mysqldb``).

    Uses the ``format`` paramstyle together with the driver-specific
    execution context, compiler and identifier preparer defined in this
    module.
    """

    driver = "mysqldb"
    supports_unicode_statements = True
    supports_sane_rowcount = True
    supports_sane_multi_rowcount = True

    supports_native_decimal = True

    default_paramstyle = "format"
    execution_ctx_cls = MySQLExecutionContext_mysqldb
    statement_compiler = MySQLCompiler_mysqldb
    preparer = MySQLIdentifierPreparer_mysqldb

    def __init__(self, server_side_cursors=False, **kwargs):
        """Initialize the dialect.

        :param server_side_cursors: stored as ``self.server_side_cursors``.
        :param kwargs: forwarded unchanged to :class:`MySQLDialect`.
        """
        super(MySQLDialect_mysqldb, self).__init__(**kwargs)
        self.server_side_cursors = server_side_cursors
        # Parse the driver's ``__version__`` when a DBAPI module with that
        # attribute is available; otherwise fall back to (0, 0, 0).
        self._mysql_dbapi_version = (
            self._parse_dbapi_version(self.dbapi.__version__)
            if self.dbapi is not None and hasattr(self.dbapi, "__version__")
            else (0, 0, 0)
        )

    def _parse_dbapi_version(self, version):
        """Parse a version string such as ``"1.4.6"`` into a tuple of ints.

        Only the leading ``major.minor[.patch]`` digits are used, so the
        result has two or three elements.  ``(0, 0, 0)`` is returned when
        the string does not start with ``major.minor``.
        """
        m = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", version)
        if m:
            return tuple(int(x) for x in m.group(1, 2, 3) if x is not None)
        else:
            return (0, 0, 0)

    @util.langhelpers.memoized_property
    def supports_server_side_cursors(self):
        """True if ``MySQLdb.cursors.SSCursor`` can be imported.

        As a side effect, stores that cursor class on ``self._sscursor``.
        """
        try:
            cursors = __import__("MySQLdb.cursors").cursors
            self._sscursor = cursors.SSCursor
            return True
        except (ImportError, AttributeError):
            return False

    @classmethod
    def dbapi(cls):
        """Import and return the ``MySQLdb`` module."""
        return __import__("MySQLdb")

    def on_connect(self):
        """Return a per-connection hook that chains the parent dialect's hook.

        The returned callable first runs the parent hook (if any).  It then
        reads the connection's ``character_set_name()``; when that is not
        None, it issues ``SET NAMES`` with that name.
        """
        super_ = super(MySQLDialect_mysqldb, self).on_connect()

        def on_connect(conn):
            if super_ is not None:
                super_(conn)

            charset_name = conn.character_set_name()

            if charset_name is not None:
                cursor = conn.cursor()
                try:
                    # charset_name comes from the driver, not from user input.
                    cursor.execute("SET NAMES %s" % charset_name)
                finally:
                    # Release the cursor even if SET NAMES fails.
                    cursor.close()

        return on_connect

    def do_ping(self, dbapi_connection):
        """Ping *dbapi_connection* and report whether it is usable.

        Returns True on success, or False when the raised DBAPI error is
        classified by ``is_disconnect`` as a disconnect.  Any other DBAPI
        error is re-raised.
        """
        try:
            # NOTE(review): the False argument presumably disables the
            # driver's reconnect-on-ping; confirm against mysqlclient docs.
            dbapi_connection.ping(False)
        except self.dbapi.Error as err:
            if self.is_disconnect(err, dbapi_connection, None):
                return False
            else:
                raise
        else:
            return True

    def do_executemany(self, cursor, statement, parameters, context=None):
        """Run ``cursor.executemany`` and record its return value.

        When a *context* is given, the value returned by ``executemany`` is
        stored as ``context._rowcount``.
        """
        rowcount = cursor.executemany(statement, parameters)
        if context is not None:
            context._rowcount = rowcount

    def _check_unicode_returns(self, connection):
        """Delegate to the base check, adding a utf8mb4_bin probe if possible.

        When the server version is above ``(5,)`` and the server reports a
        ``utf8mb4`` / ``utf8mb4_bin`` collation, an extra test expression
        (a literal cast to utf8mb4 TEXT and collated ``utf8mb4_bin``) is
        passed to the base implementation.
        """
        # work around issue fixed in
        # https://github.com/farcepest/MySQLdb1/commit/cd44524fef63bd3fcb71947392326e9742d520e8
        # specific issue w/ the utf8mb4_bin collation and unicode returns

        has_utf8mb4_bin = self.server_version_info > (
            5,
        ) and connection.scalar(
            "show collation where %s = 'utf8mb4' and %s = 'utf8mb4_bin'"
            % (
                self.identifier_preparer.quote("Charset"),
                self.identifier_preparer.quote("Collation"),
            )
        )
        if has_utf8mb4_bin:
            additional_tests = [
                sql.collate(
                    sql.cast(
                        sql.literal_column("'test collated returns'"),
                        TEXT(charset="utf8mb4"),
                    ),
                    "utf8mb4_bin",
                )
            ]
        else:
            additional_tests = []
        return super(MySQLDialect_mysqldb, self)._check_unicode_returns(
            connection, additional_tests
        )

    def create_connect_args(self, url, _translate_args=None):
        """Build ``[args, kwargs]`` connection arguments from *url*.

        By default, URL ``database``/``username``/``password`` are renamed
        to ``db``/``user``/``passwd``, and the URL query string is merged
        in.  A fixed set of query keys is coerced to bool/int/str.  Any
        ``ssl_*`` keys are moved into a nested ``ssl`` dict with the
        ``ssl_`` prefix stripped.  When a DBAPI module is present, its
        ``FOUND_ROWS`` client flag is OR-ed into ``client_flag``.  If that
        flag cannot be imported, ``supports_sane_rowcount`` is set to False
        instead.
        """
        if _translate_args is None:
            _translate_args = dict(
                database="db", username="user", password="passwd"
            )

        opts = url.translate_connect_args(**_translate_args)
        opts.update(url.query)

        util.coerce_kw_type(opts, "compress", bool)
        util.coerce_kw_type(opts, "connect_timeout", int)
        util.coerce_kw_type(opts, "read_timeout", int)
        util.coerce_kw_type(opts, "write_timeout", int)
        util.coerce_kw_type(opts, "client_flag", int)
        util.coerce_kw_type(opts, "local_infile", int)
        # Note: using either of the below will cause all strings to be
        # returned as Unicode, both in raw SQL operations and with column
        # types like String and MSString.
        util.coerce_kw_type(opts, "use_unicode", bool)
        util.coerce_kw_type(opts, "charset", str)

        # Rich values 'cursorclass' and 'conv' are not supported via
        # query string.

        ssl = {}
        keys = ["ssl_ca", "ssl_key", "ssl_cert", "ssl_capath", "ssl_cipher"]
        for key in keys:
            if key in opts:
                # e.g. "ssl_ca" -> ssl["ca"]
                ssl[key[4:]] = opts[key]
                util.coerce_kw_type(ssl, key[4:], str)
                del opts[key]
        if ssl:
            opts["ssl"] = ssl

        # FOUND_ROWS must be set in CLIENT_FLAGS to enable
        # supports_sane_rowcount.
        client_flag = opts.get("client_flag", 0)
        if self.dbapi is not None:
            try:
                CLIENT_FLAGS = __import__(
                    self.dbapi.__name__ + ".constants.CLIENT"
                ).constants.CLIENT
                client_flag |= CLIENT_FLAGS.FOUND_ROWS
            except (AttributeError, ImportError):
                self.supports_sane_rowcount = False
            opts["client_flag"] = client_flag
        return [[], opts]

    def _extract_error_code(self, exception):
        """Return the first element of *exception*'s ``args``."""
        return exception.args[0]

    def _detect_charset(self, connection):
        """Sniff out the character set in use for connection results.

        Calls the raw DBAPI connection's ``character_set_name()``.  If the
        driver lacks that method, emits a warning and returns ``"latin1"``.
        """

        try:
            # note: the SQL here would be
            # "SHOW VARIABLES LIKE 'character_set%%'"
            cset_name = connection.connection.character_set_name
        except AttributeError:
            util.warn(
                "No 'character_set_name' can be detected with "
                "this MySQL-Python version; "
                "please upgrade to a recent version of MySQL-Python.  "
                "Assuming latin1."
            )
            return "latin1"
        else:
            return cset_name()

    # Isolation level names accepted by this dialect, including the
    # driver-level "AUTOCOMMIT" pseudo-level handled in _set_isolation_level.
    _isolation_lookup = {
        "SERIALIZABLE",
        "READ UNCOMMITTED",
        "READ COMMITTED",
        "REPEATABLE READ",
        "AUTOCOMMIT",
    }

    def _set_isolation_level(self, connection, level):
        """Apply *level* to the raw DBAPI *connection*.

        ``"AUTOCOMMIT"`` only switches on the driver's autocommit mode.  Any
        other level first switches autocommit off, then delegates to the
        base dialect.
        """
        if level == "AUTOCOMMIT":
            connection.autocommit(True)
        else:
            connection.autocommit(False)
            super(MySQLDialect_mysqldb, self)._set_isolation_level(
                connection, level
            )
269
270
# Module-level alias for this module's dialect class (presumably the name
# SQLAlchemy's dialect loader looks up for "mysql+mysqldb" — confirm).
dialect = MySQLDialect_mysqldb
272