1# connectors/pyodbc.py
2# Copyright (C) 2005-2016 the SQLAlchemy authors and contributors
3# <see AUTHORS file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: http://www.opensource.org/licenses/mit-license.php
7
8from . import Connector
9from .. import util
10
11
12import sys
13import re
14
15
class PyODBCConnector(Connector):
    """Connection logic shared by dialects that talk to a database via pyodbc.

    Builds ODBC connection strings from SQLAlchemy URLs, classifies pyodbc
    exceptions as disconnects, sniffs the FreeTDS / Easysoft drivers to
    decide unicode support, and parses pyodbc / server version strings.
    """

    driver = 'pyodbc'

    supports_sane_multi_rowcount = False

    if util.py2k:
        # PyODBC unicode is broken on UCS-4 builds
        supports_unicode = sys.maxunicode == 65535
        supports_unicode_statements = supports_unicode

    supports_native_decimal = True
    default_paramstyle = 'named'

    # for non-DSN connections, this *may* be used to
    # hold the desired driver name
    pyodbc_driver_name = None

    # will be set to True after initialize()
    # if the freetds.so is detected
    freetds = False

    # will be set to the string version of
    # the FreeTDS driver if freetds is detected
    freetds_driver_version = None

    # will be set to True after initialize()
    # if the libessqlsrv.so is detected
    easysoft = False

    def __init__(self, supports_unicode_binds=None, **kw):
        """Create the connector.

        :param supports_unicode_binds: explicit override for
          ``supports_unicode_binds``.  When ``None`` (the default) the value
          is derived from the detected driver in :meth:`initialize`.
        :param kw: forwarded unchanged to the ``Connector`` base class.
        """
        super(PyODBCConnector, self).__init__(**kw)
        self._user_supports_unicode_binds = supports_unicode_binds

    @classmethod
    def dbapi(cls):
        """Import and return the ``pyodbc`` module."""
        return __import__('pyodbc')

    @staticmethod
    def _quote_odbc_value(value):
        """Brace-quote *value* for an ODBC connection string if required.

        An unquoted ``;`` would end the attribute early, and a leading ``{``
        would be read as the start of a quoted value; such values are wrapped
        in ``{...}`` with every ``}`` doubled.  Other values pass through
        unchanged.
        """
        if ';' in value or value.startswith('{'):
            return '{%s}' % value.replace('}', '}}')
        return value

    def create_connect_args(self, url):
        """Translate *url* into arguments for ``pyodbc.connect()``.

        Three forms of connection string are produced:

        * ``odbc_connect`` in the query string: its URL-unquoted value is
          used verbatim and every other URL part is ignored (apart from the
          boolean flags described below).
        * DSN connection (``dsn`` given, or a host without a database):
          ``dsn=<host, else dsn>``.
        * DSN-less connection:
          ``DRIVER={driver};Server=host[,port];Database=db``.

        For the latter two forms, ``UID``/``PWD`` are added when a username
        is present (brace-quoted if they contain ``;``), otherwise
        ``Trusted_Connection=Yes``.  ``odbc_autotranslate`` becomes
        ``AutoTranslate``, and every remaining URL/query key is appended
        as ``key=value``.

        :return: ``[[connection_string], connect_args]``, where
          ``connect_args`` holds the ``ansi`` / ``unicode_results`` /
          ``autocommit`` query-string flags converted to booleans.
        """
        keys = url.translate_connect_args(username='user')
        keys.update(url.query)
        query = url.query

        connect_args = {}
        for param in ('ansi', 'unicode_results', 'autocommit'):
            if param in keys:
                connect_args[param] = util.asbool(keys.pop(param))

        if 'odbc_connect' in keys:
            # a complete, URL-encoded ODBC connection string; use it as-is
            return [[util.unquote_plus(keys.pop('odbc_connect'))],
                    connect_args]

        dsn_connection = 'dsn' in keys or \
            ('host' in keys and 'database' not in keys)
        if dsn_connection:
            # pop both keys so that a 'dsn' shadowed by a non-empty 'host'
            # is not appended again below as a stray "dsn=..." attribute
            host = keys.pop('host', '')
            dsn = keys.pop('dsn', '')
            connectors = ['dsn=%s' % (host or dsn)]
        else:
            port = ''
            # a port supplied in the query string is left in ``keys`` and
            # passed through below as a plain "port=..." attribute
            if 'port' in keys and 'port' not in query:
                port = ',%d' % int(keys.pop('port'))

            connectors = []
            driver = keys.pop('driver', self.pyodbc_driver_name)
            if driver is None:
                util.warn(
                    "No driver name specified; "
                    "this is expected by PyODBC when using "
                    "DSN-less connections")
            else:
                connectors.append("DRIVER={%s}" % driver)

            connectors.extend(
                [
                    'Server=%s%s' % (keys.pop('host', ''), port),
                    'Database=%s' % keys.pop('database', '')
                ])

        user = keys.pop("user", None)
        if user:
            connectors.append("UID=%s" % self._quote_odbc_value(user))
            connectors.append(
                "PWD=%s" % self._quote_odbc_value(keys.pop('password', '')))
        else:
            connectors.append("Trusted_Connection=Yes")

        # if set to 'Yes', the ODBC layer will try to automagically
        # convert textual data from your database encoding to your
        # client encoding.  This should obviously be set to 'No' if
        # you query a cp1253 encoded database from a latin1 client...
        if 'odbc_autotranslate' in keys:
            connectors.append("AutoTranslate=%s" %
                              keys.pop("odbc_autotranslate"))

        connectors.extend(['%s=%s' % (k, v) for k, v in keys.items()])
        return [[";".join(connectors)], connect_args]

    def is_disconnect(self, e, connection, cursor):
        """Return True if pyodbc exception *e* signals a lost connection.

        A ``ProgrammingError`` counts when its message says the connection
        or the cursor's connection is closed.  Any other pyodbc ``Error``
        counts when its message carries SQLSTATE ``[08S01]`` (listed in the
        ODBC reference as "Communication link failure").  Anything else
        returns False.
        """
        # ProgrammingError is tested first: it is also a dbapi.Error
        if isinstance(e, self.dbapi.ProgrammingError):
            message = str(e)
            return ("The cursor's connection has been closed." in message or
                    'Attempt to use a closed connection.' in message)
        elif isinstance(e, self.dbapi.Error):
            return '[08S01]' in str(e)
        return False

    @staticmethod
    def _version_tuple(version_string):
        """Return the digit runs of *version_string* as a tuple of ints.

        ``'00.91.0100'`` -> ``(0, 91, 100)``; ``None`` or ``''`` -> ``()``.
        """
        if not version_string:
            return ()
        return tuple(int(d) for d in re.findall(r'\d+', version_string))

    def initialize(self, connection):
        """Detect the ODBC driver and configure unicode support, then run
        the base-class initialization.

        Sets ``freetds`` / ``easysoft`` from the driver library name reported
        by ``SQL_DRIVER_NAME``.  Stores the FreeTDS ``SQL_DRIVER_VER`` in
        ``freetds_driver_version``.  Sets ``supports_unicode_statements`` and
        ``supports_unicode_binds`` accordingly; a ``supports_unicode_binds``
        value passed to the constructor always wins.
        """
        # determine FreeTDS first.   can't issue SQL easily
        # without getting unicode_statements/binds set up.

        pyodbc = self.dbapi

        dbapi_con = connection.connection

        # "or ''" keeps re.match from raising if no driver name is reported
        _sql_driver_name = dbapi_con.getinfo(pyodbc.SQL_DRIVER_NAME) or ''
        self.freetds = bool(
            re.match(r".*libtdsodbc.*\.so", _sql_driver_name))
        self.easysoft = bool(
            re.match(r".*libessqlsrv.*\.so", _sql_driver_name))

        if self.freetds:
            self.freetds_driver_version = dbapi_con.getinfo(
                pyodbc.SQL_DRIVER_VER)

        self.supports_unicode_statements = (
            not util.py2k or
            (not self.freetds and not self.easysoft)
        )

        if self._user_supports_unicode_binds is not None:
            self.supports_unicode_binds = self._user_supports_unicode_binds
        elif util.py2k:
            # compare numerically: ODBC defines SQL_DRIVER_VER as
            # "##.##.####", and a zero-padded "00.64.0000" would sort
            # above '0.91' under a plain string comparison
            self.supports_unicode_binds = (
                not self.freetds or
                self._version_tuple(self.freetds_driver_version) >= (0, 91)
            ) and not self.easysoft
        else:
            self.supports_unicode_binds = True

        # run other initialization which asks for user name, etc.
        super(PyODBCConnector, self).initialize(connection)

    def _dbapi_version(self):
        """Return the parsed pyodbc version tuple, or ``()`` if pyodbc is
        not loaded."""
        if not self.dbapi:
            return ()
        return self._parse_dbapi_version(self.dbapi.version)

    def _parse_dbapi_version(self, vers):
        """Parse a pyodbc version string into a tuple.

        Handles forms such as ``'4.0.30'`` -> ``(4, 0, 30)`` and
        ``'py3-3.0.1-beta'`` -> ``(3, 0, 1, 'beta')``.  Returns ``()`` when
        the string does not match at all.
        """
        m = re.match(
            r'(?:py.*-)?([\d\.]+)(?:-(\w+))?',
            vers
        )
        if not m:
            return ()
        # skip empty pieces so a stray or trailing '.' ("4.0.") cannot
        # make int('') raise
        vers = tuple([int(x) for x in m.group(1).split(".") if x])
        if m.group(2):
            vers += (m.group(2),)
        return vers

    def _get_server_version_info(self, connection):
        """Return the ``SQL_DBMS_VER`` string split on '.' and '-' as a
        tuple, with numeric pieces converted to int and the rest kept as
        strings."""
        dbapi_con = connection.connection
        version = []
        # raw string: '\-' is an invalid escape in a plain string literal
        r = re.compile(r'[.\-]')
        for n in r.split(dbapi_con.getinfo(self.dbapi.SQL_DBMS_VER)):
            try:
                version.append(int(n))
            except ValueError:
                version.append(n)
        return tuple(version)
184