1# postgresql/pg8000.py
2# Copyright (C) 2005-2016 the SQLAlchemy authors and contributors <see AUTHORS
3# file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: http://www.opensource.org/licenses/mit-license.php
7
8"""
9.. dialect:: postgresql+pg8000
10    :name: pg8000
11    :dbapi: pg8000
12    :connectstring: \
13postgresql+pg8000://user:password@host:port/dbname[?key=value&key=value...]
14    :url: https://pythonhosted.org/pg8000/
15
16
17.. _pg8000_unicode:
18
19Unicode
20-------
21
22pg8000 will encode / decode string values between it and the server using the
23PostgreSQL ``client_encoding`` parameter; by default this is the value in
24the ``postgresql.conf`` file, which often defaults to ``SQL_ASCII``.
25Typically, this can be changed to ``utf-8``, as a more useful default::
26
27    #client_encoding = sql_ascii # actually, defaults to database
28                                 # encoding
29    client_encoding = utf8
30
The ``client_encoding`` can be overridden for a session by executing the SQL::

    SET CLIENT_ENCODING TO 'utf8';
34
35SQLAlchemy will execute this SQL on all new connections based on the value
36passed to :func:`.create_engine` using the ``client_encoding`` parameter::
37
38    engine = create_engine(
39        "postgresql+pg8000://user:pass@host/dbname", client_encoding='utf8')
40
41
42.. _pg8000_isolation_level:
43
44pg8000 Transaction Isolation Level
45-------------------------------------
46
47The pg8000 dialect offers the same isolation level settings as that
48of the :ref:`psycopg2 <psycopg2_isolation_level>` dialect:
49
50* ``READ COMMITTED``
51* ``READ UNCOMMITTED``
52* ``REPEATABLE READ``
53* ``SERIALIZABLE``
54* ``AUTOCOMMIT``
55
56.. versionadded:: 0.9.5 support for AUTOCOMMIT isolation level when using
57   pg8000.
58
59.. seealso::
60
61    :ref:`postgresql_isolation_level`
62
63    :ref:`psycopg2_isolation_level`
64
65
66"""
67from ... import util, exc
68import decimal
69from ... import processors
70from ... import types as sqltypes
71from .base import (
72    PGDialect, PGCompiler, PGIdentifierPreparer, PGExecutionContext,
73    _DECIMAL_TYPES, _FLOAT_TYPES, _INT_TYPES)
74import re
75from sqlalchemy.dialects.postgresql.json import JSON
76
77
class _PGNumeric(sqltypes.Numeric):
    """Numeric type whose result handling accounts for pg8000's native values.

    pg8000 already hands back ``Decimal`` for NUMERIC (oid 1700) and
    ``float`` for double precision (oid 701), so a conversion is only
    installed when the Python type requested via ``asdecimal`` differs from
    what the driver produced for the column's type code.
    """

    def result_processor(self, dialect, coltype):
        # Floating-point codes are checked first, exactly as before; a code
        # is only considered "exact" if it is not a float code.
        is_float_code = coltype in _FLOAT_TYPES
        is_exact_code = (not is_float_code) and (
            coltype in _DECIMAL_TYPES or coltype in _INT_TYPES)

        if not (is_float_code or is_exact_code):
            raise exc.InvalidRequestError(
                "Unknown PG numeric type: %d" % coltype)

        if self.asdecimal:
            if is_float_code:
                # Driver gave us a float; convert to Decimal at the
                # configured return scale.
                return processors.to_decimal_processor_factory(
                    decimal.Decimal, self._effective_decimal_return_scale)
            # Exact codes are passed through as produced by the driver.
            return None

        if is_exact_code:
            # Caller wants floats but the driver returned an exact value.
            return processors.to_float
        # Float code and float wanted: nothing to do.
        return None
99
100
class _PGNumericNoBind(_PGNumeric):
    """``_PGNumeric`` variant that installs no bind-parameter conversion.

    Result handling is inherited unchanged; bound values are given to the
    driver exactly as supplied.
    """

    def bind_processor(self, dialect):
        # Explicitly opt out of any bind-side processing.
        return None
104
105
class _PGJSON(JSON):
    """JSON type that skips result processing when pg8000 handles JSON itself."""

    def result_processor(self, dialect, coltype):
        # pg8000 releases newer than 1.10.1 have native JSON support, so only
        # older drivers need the base class's result processor.
        if dialect._dbapi_version <= (1, 10, 1):
            return super(_PGJSON, self).result_processor(dialect, coltype)
        return None
113
114
class PGExecutionContext_pg8000(PGExecutionContext):
    """Execution context for pg8000; all behavior comes from the base class."""
117
118
class PGCompiler_pg8000(PGCompiler):
    """Statement compiler that doubles ``%`` for pg8000's ``format`` paramstyle.

    With ``format`` style, a bare ``%`` in the SQL string would be read as a
    parameter marker, so literal percent signs must be written as ``%%``.
    """

    def visit_mod_binary(self, binary, operator, **kw):
        # Render the modulo operator as an escaped '%%'.
        lhs = self.process(binary.left, **kw)
        rhs = self.process(binary.right, **kw)
        return " %% ".join([lhs, rhs])

    def post_process_text(self, text):
        # Text that already contains '%%' was probably hand-escaped by the
        # caller; warn, because it is escaped again below regardless.
        looks_pre_escaped = '%%' in text
        if looks_pre_escaped:
            util.warn("The SQLAlchemy postgresql dialect "
                      "now automatically escapes '%' in text() "
                      "expressions to '%%'.")
        escaped = text.replace('%', '%%')
        return escaped
130
131
class PGIdentifierPreparer_pg8000(PGIdentifierPreparer):
    """Identifier preparer that also doubles ``%`` inside quoted identifiers."""

    def _escape_identifier(self, value):
        # First apply the standard quote escaping from the base preparer's
        # settings, then double '%' because identifiers end up inside
        # 'format'-paramstyle SQL strings.
        quote_escaped = value.replace(self.escape_quote, self.escape_to_quote)
        percent_escaped = quote_escaped.replace('%', '%%')
        return percent_escaped
136
137
class PGDialect_pg8000(PGDialect):
    """PostgreSQL dialect for the pg8000 DBAPI.

    On top of :class:`.PGDialect` this configures the ``format`` paramstyle
    (with ``%``-escaping compiler and identifier-preparer classes),
    pg8000-specific result handling for numeric and JSON types, optional
    ``client_encoding`` and isolation-level setup for each new connection,
    and two-phase commit hooks that delegate to the DBAPI connection's
    ``tpc_*`` methods.
    """

    driver = 'pg8000'

    supports_unicode_statements = True

    supports_unicode_binds = True

    default_paramstyle = 'format'
    # Class-level default; initialize() replaces it on each dialect instance
    # according to the installed pg8000 version.
    supports_sane_multi_rowcount = True
    execution_ctx_cls = PGExecutionContext_pg8000
    statement_compiler = PGCompiler_pg8000
    preparer = PGIdentifierPreparer_pg8000
    description_encoding = 'use_encoding'

    # Start from the generic PostgreSQL type map and swap in the pg8000
    # adaptations for Numeric, Float and JSON.
    colspecs = util.update_copy(
        PGDialect.colspecs,
        {
            sqltypes.Numeric: _PGNumericNoBind,
            sqltypes.Float: _PGNumeric,
            JSON: _PGJSON,
        }
    )

    def __init__(self, client_encoding=None, **kwargs):
        """Create the dialect.

        :param client_encoding: if not ``None``, an encoding name that the
          hook returned by :meth:`on_connect` applies to every new
          connection via :meth:`set_client_encoding`.
        :param kwargs: passed through unchanged to :class:`.PGDialect`.
        """
        PGDialect.__init__(self, **kwargs)
        self.client_encoding = client_encoding

    def initialize(self, connection):
        # Only report sane multi-row rowcounts for pg8000 1.9.14 and later.
        self.supports_sane_multi_rowcount = self._dbapi_version >= (1, 9, 14)
        super(PGDialect_pg8000, self).initialize(connection)

    @util.memoized_property
    def _dbapi_version(self):
        """The pg8000 ``__version__`` as a tuple of integers.

        Each run of digits in the version string becomes one element, so
        ``"1.10.6"`` yields ``(1, 10, 6)``. When no DBAPI module is loaded,
        or it has no ``__version__``, ``(99, 99, 99)`` is returned. That
        value satisfies every version check made in this module.
        """
        if self.dbapi and hasattr(self.dbapi, '__version__'):
            version_string = self.dbapi.__version__
            return tuple(
                int(part) for part in re.findall(r'\d+', version_string))
        return (99, 99, 99)

    @classmethod
    def dbapi(cls):
        """Import and return the ``pg8000`` module."""
        return __import__('pg8000')

    def create_connect_args(self, url):
        """Translate a URL into ``([], kwargs)`` for ``pg8000.connect()``.

        ``username`` is renamed to ``user``, ``port`` is coerced to ``int``,
        and any URL query parameters are passed through as extra kwargs.
        """
        opts = url.translate_connect_args(username='user')
        if 'port' in opts:
            opts['port'] = int(opts['port'])
        opts.update(url.query)
        return ([], opts)

    def is_disconnect(self, e, connection, cursor):
        """Treat errors whose text contains "connection is closed" as disconnects."""
        return "connection is closed" in str(e)

    def set_isolation_level(self, connection, level):
        """Apply ``level`` to a pg8000 connection (or a wrapper around one).

        ``AUTOCOMMIT`` turns on the driver's ``autocommit`` flag. Any other
        level found in ``_isolation_lookup`` turns ``autocommit`` off and
        sets the session's default transaction isolation level. Underscores
        in ``level`` are treated as spaces.

        :raises ArgumentError: if ``level`` is not a recognized level.
        """
        level = level.replace('_', ' ')

        # adjust for ConnectionFairy possibly being present
        if hasattr(connection, 'connection'):
            connection = connection.connection

        if level == 'AUTOCOMMIT':
            connection.autocommit = True
        elif level in self._isolation_lookup:
            connection.autocommit = False
            cursor = connection.cursor()
            try:
                # Only values present in _isolation_lookup reach this point,
                # which is what makes the string interpolation acceptable.
                cursor.execute(
                    "SET SESSION CHARACTERISTICS AS TRANSACTION "
                    "ISOLATION LEVEL %s" % level)
                # With autocommit off, presumably needed to end the
                # transaction opened for the SET above.
                cursor.execute("COMMIT")
            finally:
                # Release the cursor even if either statement fails.
                cursor.close()
        else:
            raise exc.ArgumentError(
                "Invalid value '%s' for isolation_level. "
                "Valid isolation levels for %s are %s or AUTOCOMMIT" %
                (level, self.name, ", ".join(self._isolation_lookup))
            )

    def set_client_encoding(self, connection, client_encoding):
        """Issue ``SET CLIENT_ENCODING`` on a pg8000 connection (or wrapper).

        Single quotes in ``client_encoding`` are doubled so the value stays
        a well-formed SQL string literal.
        """
        # adjust for ConnectionFairy possibly being present
        if hasattr(connection, 'connection'):
            connection = connection.connection

        encoding_literal = client_encoding.replace("'", "''")
        cursor = connection.cursor()
        try:
            cursor.execute(
                "SET CLIENT_ENCODING TO '" + encoding_literal + "'")
            cursor.execute("COMMIT")
        finally:
            # Release the cursor even if either statement fails.
            cursor.close()

    # Two-phase commit: ``xid`` is wrapped as ``(0, xid, '')`` before being
    # handed to pg8000 -- presumably the PEP 249 (format_id, gtrid, bqual)
    # triple; confirm against the pg8000 tpc_* documentation.

    def do_begin_twophase(self, connection, xid):
        connection.connection.tpc_begin((0, xid, ''))

    def do_prepare_twophase(self, connection, xid):
        connection.connection.tpc_prepare()

    def do_rollback_twophase(
            self, connection, xid, is_prepared=True, recover=False):
        connection.connection.tpc_rollback((0, xid, ''))

    def do_commit_twophase(
            self, connection, xid, is_prepared=True, recover=False):
        connection.connection.tpc_commit((0, xid, ''))

    def do_recover_twophase(self, connection):
        # Index 1 is the position where the original ``xid`` was placed
        # by the triples built above.
        return [row[1] for row in connection.connection.tpc_recover()]

    def on_connect(self):
        """Return a hook that prepares each new DBAPI connection, or ``None``.

        The hook applies ``client_encoding`` first and then
        ``isolation_level``, skipping whichever is ``None``. If neither is
        configured, ``None`` is returned.
        """
        fns = []
        if self.client_encoding is not None:
            def on_connect(conn):
                self.set_client_encoding(conn, self.client_encoding)
            fns.append(on_connect)

        if self.isolation_level is not None:
            def on_connect(conn):
                self.set_isolation_level(conn, self.isolation_level)
            fns.append(on_connect)

        if not fns:
            return None

        def on_connect(conn):
            for fn in fns:
                fn(conn)
        return on_connect
263
# Module-level alias for the dialect class; presumably the name SQLAlchemy's
# dialect loader looks up for "postgresql+pg8000" -- confirm against loader.
dialect = PGDialect_pg8000
265