1# postgresql/pg8000.py
2# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors <see AUTHORS
3# file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: https://www.opensource.org/licenses/mit-license.php
7r"""
8.. dialect:: postgresql+pg8000
9    :name: pg8000
10    :dbapi: pg8000
11    :connectstring: postgresql+pg8000://user:password@host:port/dbname[?key=value&key=value...]
12    :url: https://pypi.org/project/pg8000/
13
14.. versionchanged:: 1.4  The pg8000 dialect has been updated for version
15   1.16.6 and higher, and is again part of SQLAlchemy's continuous integration
16   with full feature support.
17
18.. _pg8000_unicode:
19
20Unicode
21-------
22
23pg8000 will encode / decode string values between it and the server using the
24PostgreSQL ``client_encoding`` parameter; by default this is the value in
25the ``postgresql.conf`` file, which often defaults to ``SQL_ASCII``.
26Typically, this can be changed to ``utf-8``, as a more useful default::
27
28    #client_encoding = sql_ascii # actually, defaults to database
29                                 # encoding
30    client_encoding = utf8
31
The ``client_encoding`` can be overridden for a session by executing the SQL::

    SET CLIENT_ENCODING TO 'utf8';
35
36SQLAlchemy will execute this SQL on all new connections based on the value
37passed to :func:`_sa.create_engine` using the ``client_encoding`` parameter::
38
39    engine = create_engine(
40        "postgresql+pg8000://user:pass@host/dbname", client_encoding='utf8')
41
42.. _pg8000_ssl:
43
44SSL Connections
45---------------
46
47pg8000 accepts a Python ``SSLContext`` object which may be specified using the
48:paramref:`_sa.create_engine.connect_args` dictionary::
49
50    import ssl
51    ssl_context = ssl.create_default_context()
52    engine = sa.create_engine(
53        "postgresql+pg8000://scott:tiger@192.168.0.199/test",
54        connect_args={"ssl_context": ssl_context},
55    )
56
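If the server uses an automatically-generated certificate that is self-signed
or does not match the host name (as seen from the client), it may also be
necessary to disable hostname checking and certificate verification (note
that doing so removes TLS's protection against impersonation of the server)::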
60
61    import ssl
62    ssl_context = ssl.create_default_context()
63    ssl_context.check_hostname = False
64    ssl_context.verify_mode = ssl.CERT_NONE
65    engine = sa.create_engine(
66        "postgresql+pg8000://scott:tiger@192.168.0.199/test",
67        connect_args={"ssl_context": ssl_context},
68    )
69
70.. _pg8000_isolation_level:
71
72pg8000 Transaction Isolation Level
73-------------------------------------
74
75The pg8000 dialect offers the same isolation level settings as that
76of the :ref:`psycopg2 <psycopg2_isolation_level>` dialect:
77
78* ``READ COMMITTED``
79* ``READ UNCOMMITTED``
80* ``REPEATABLE READ``
81* ``SERIALIZABLE``
82* ``AUTOCOMMIT``
83
84.. seealso::
85
86    :ref:`postgresql_isolation_level`
87
88    :ref:`psycopg2_isolation_level`
89
90
91"""  # noqa
92import decimal
93import re
94from uuid import UUID as _python_UUID
95
96from .array import ARRAY as PGARRAY
97from .base import _ColonCast
98from .base import _DECIMAL_TYPES
99from .base import _FLOAT_TYPES
100from .base import _INT_TYPES
101from .base import ENUM
102from .base import INTERVAL
103from .base import PGCompiler
104from .base import PGDialect
105from .base import PGExecutionContext
106from .base import PGIdentifierPreparer
107from .base import UUID
108from .json import JSON
109from .json import JSONB
110from .json import JSONPathType
111from ... import exc
112from ... import processors
113from ... import types as sqltypes
114from ... import util
115from ...sql.elements import quoted_name
116
117
class _PGNumeric(sqltypes.Numeric):
    """Numeric type whose result conversion defers to pg8000's native values.

    pg8000 already returns ``float`` for float columns and ``Decimal`` for
    numeric/integer columns, so a processor is only produced when the
    requested Python type (``asdecimal``) differs from what pg8000 returns.
    """

    def result_processor(self, dialect, coltype):
        if coltype in _FLOAT_TYPES:
            if not self.asdecimal:
                # pg8000 returns float natively for 701
                return None
            return processors.to_decimal_processor_factory(
                decimal.Decimal, self._effective_decimal_return_scale
            )

        if coltype in _DECIMAL_TYPES or coltype in _INT_TYPES:
            if self.asdecimal:
                # pg8000 returns Decimal natively for 1700
                return None
            return processors.to_float

        raise exc.InvalidRequestError(
            "Unknown PG numeric type: %d" % coltype
        )
142
143
class _PGNumericNoBind(_PGNumeric):
    """``_PGNumeric`` variant that passes bound values to pg8000 unchanged."""

    def bind_processor(self, dialect):
        # no conversion: the raw Python value is handed to the driver
        return None
147
148
class _PGJSON(JSON):
    """JSON type that leaves result values exactly as pg8000 returns them."""

    def result_processor(self, dialect, coltype):
        # no SQLAlchemy-side deserialization of fetched values
        return None

    def get_dbapi_type(self, dbapi):
        json_type = dbapi.JSON
        return json_type
155
156
class _PGJSONB(JSONB):
    """JSONB type that leaves result values exactly as pg8000 returns them."""

    def result_processor(self, dialect, coltype):
        # no SQLAlchemy-side deserialization of fetched values
        return None

    def get_dbapi_type(self, dbapi):
        jsonb_type = dbapi.JSONB
        return jsonb_type
163
164
class _PGJSONIndexType(sqltypes.JSON.JSONIndexType):
    """Generic JSON index type; asking it for a DBAPI type is an error."""

    def get_dbapi_type(self, dbapi):
        message = "should not be here"
        raise NotImplementedError(message)
168
169
class _PGJSONIntIndexType(sqltypes.JSON.JSONIntIndexType):
    """Integer JSON index, bound using pg8000's ``INTEGER`` type code."""

    def get_dbapi_type(self, dbapi):
        int_type = dbapi.INTEGER
        return int_type
173
174
class _PGJSONStrIndexType(sqltypes.JSON.JSONStrIndexType):
    """String JSON index, bound using pg8000's ``STRING`` type code."""

    def get_dbapi_type(self, dbapi):
        str_type = dbapi.STRING
        return str_type
178
179
class _PGJSONPathType(JSONPathType):
    """JSON path type bound with an explicit PostgreSQL type OID."""

    def get_dbapi_type(self, dbapi):
        # PostgreSQL type OID 1009 is text[]
        text_array_oid = 1009
        return text_array_oid
183
184
class _PGUUID(UUID):
    """UUID type that converts between strings and ``uuid.UUID`` for pg8000.

    When ``as_uuid`` is false, bound strings are converted to ``uuid.UUID``
    objects for the driver and fetched values are turned back into ``str``.
    When ``as_uuid`` is true, no processors are installed.
    """

    def bind_processor(self, dialect):
        if self.as_uuid:
            return None

        def to_uuid(value):
            return None if value is None else _python_UUID(value)

        return to_uuid

    def result_processor(self, dialect, coltype):
        if self.as_uuid:
            return None

        def to_string(value):
            return None if value is None else str(value)

        return to_string
205
206
class _PGEnum(ENUM):
    """ENUM type bound using pg8000's ``UNKNOWN`` type code."""

    def get_dbapi_type(self, dbapi):
        unknown_type = dbapi.UNKNOWN
        return unknown_type
210
211
class _PGInterval(INTERVAL):
    """INTERVAL type bound using pg8000's ``INTERVAL`` type code."""

    def get_dbapi_type(self, dbapi):
        interval_type = dbapi.INTERVAL
        return interval_type

    @classmethod
    def adapt_emulated_to_native(cls, interval, **kw):
        # carry the generic Interval's second_precision over as precision
        precision = interval.second_precision
        return _PGInterval(precision=precision)
219
220
class _PGTimeStamp(sqltypes.DateTime):
    """DateTime type bound with the timestamp OID matching its timezone flag."""

    def get_dbapi_type(self, dbapi):
        # 1184 is TIMESTAMPTZOID, 1114 is TIMESTAMPOID
        return 1184 if self.timezone else 1114
229
230
class _PGTime(sqltypes.Time):
    """Time type bound using pg8000's ``TIME`` type code."""

    def get_dbapi_type(self, dbapi):
        time_type = dbapi.TIME
        return time_type
234
235
class _PGInteger(sqltypes.Integer):
    """Integer type bound using pg8000's ``INTEGER`` type code."""

    def get_dbapi_type(self, dbapi):
        int_type = dbapi.INTEGER
        return int_type
239
240
class _PGSmallInteger(sqltypes.SmallInteger):
    """SmallInteger type; also bound using pg8000's ``INTEGER`` type code."""

    def get_dbapi_type(self, dbapi):
        int_type = dbapi.INTEGER
        return int_type
244
245
class _PGNullType(sqltypes.NullType):
    """NullType bound using pg8000's ``NULLTYPE`` type code."""

    def get_dbapi_type(self, dbapi):
        null_type = dbapi.NULLTYPE
        return null_type
249
250
class _PGBigInteger(sqltypes.BigInteger):
    """BigInteger type bound using pg8000's ``BIGINTEGER`` type code."""

    def get_dbapi_type(self, dbapi):
        bigint_type = dbapi.BIGINTEGER
        return bigint_type
254
255
class _PGBoolean(sqltypes.Boolean):
    """Boolean type bound using pg8000's ``BOOLEAN`` type code."""

    def get_dbapi_type(self, dbapi):
        bool_type = dbapi.BOOLEAN
        return bool_type
259
260
class _PGARRAY(PGARRAY):
    """ARRAY type whose bind parameters are wrapped in ``_ColonCast``."""

    def bind_expression(self, bindvalue):
        # presumably renders as "<bind>::<array type>"; see _ColonCast
        cast_expr = _ColonCast(bindvalue, self)
        return cast_expr
264
265
# Module-level counter; each call yields a new value that
# PGExecutionContext_pg8000.create_server_side_cursor embeds (in hex) in the
# name of every server-side cursor it creates.
_server_side_id = util.counter()
267
268
class PGExecutionContext_pg8000(PGExecutionContext):
    """Execution context that hands out pg8000 server-side cursors."""

    def create_server_side_cursor(self):
        """Return a ``ServerSideCursor`` wrapping a fresh DBAPI cursor.

        The cursor name is built from this context's ``id()`` and the next
        value of the module-level ``_server_side_id`` counter, both in hex.
        """
        context_part = hex(id(self))[2:]
        counter_part = hex(_server_side_id())[2:]
        cursor_name = "c_%s_%s" % (context_part, counter_part)
        dbapi_cursor = self._dbapi_connection.cursor()
        return ServerSideCursor(dbapi_cursor, cursor_name)

    def pre_exec(self):
        # no pg8000-specific setup; returns early for non-compiled statements
        if not self.compiled:
            return
277
278
class ServerSideCursor:
    """Cursor facade that streams results through a named PostgreSQL cursor.

    ``execute()`` wraps the statement in
    ``DECLARE <ident> NO SCROLL CURSOR FOR ...`` on the wrapped pg8000
    cursor. The fetch methods then issue ``FETCH FORWARD`` against that
    name and return the rows the wrapped cursor yields.
    """

    server_side = True

    def __init__(self, cursor, ident):
        # ident: name of the server-side cursor; cursor: wrapped DBAPI cursor
        self.ident = ident
        self.cursor = cursor

    @property
    def connection(self):
        return self.cursor.connection

    @property
    def rowcount(self):
        return self.cursor.rowcount

    @property
    def description(self):
        return self.cursor.description

    def execute(self, operation, args=(), stream=None):
        """Declare the named cursor for *operation*; returns ``self``."""
        op = "DECLARE " + self.ident + " NO SCROLL CURSOR FOR " + operation
        self.cursor.execute(op, args, stream=stream)
        return self

    def executemany(self, operation, param_sets):
        """Run *operation* for each parameter set directly (no cursor)."""
        self.cursor.executemany(operation, param_sets)
        return self

    def fetchone(self):
        """Fetch the next row from the named cursor."""
        self.cursor.execute("FETCH FORWARD 1 FROM " + self.ident)
        return self.cursor.fetchone()

    def fetchmany(self, num=None):
        """Fetch up to *num* rows; with *num* of ``None``, fetch all rows."""
        if num is None:
            return self.fetchall()
        count = int(num)
        if count == 0:
            # PostgreSQL treats "FETCH FORWARD 0" as a re-fetch of the
            # current row, which would hand back a duplicate row; zero rows
            # were requested, so return none without a round trip.
            return []
        self.cursor.execute(
            "FETCH FORWARD " + str(count) + " FROM " + self.ident
        )
        return self.cursor.fetchall()

    def fetchall(self):
        """Fetch every remaining row from the named cursor."""
        self.cursor.execute("FETCH FORWARD ALL FROM " + self.ident)
        return self.cursor.fetchall()

    def close(self):
        """Close the named server-side cursor, then the wrapped cursor."""
        try:
            self.cursor.execute("CLOSE " + self.ident)
        finally:
            # release the client-side cursor even if CLOSE fails (e.g. the
            # transaction is already aborted)
            self.cursor.close()

    def setinputsizes(self, *sizes):
        self.cursor.setinputsizes(*sizes)

    def setoutputsize(self, size, column=None):
        pass
333
334
class PGCompiler_pg8000(PGCompiler):
    """Compiler that renders the modulo operator as an escaped ``%%``.

    With the ``format`` paramstyle a bare ``%`` would be read as the start of
    a parameter placeholder, so the operator is written doubled.
    """

    def visit_mod_binary(self, binary, operator, **kw):
        left_sql = self.process(binary.left, **kw)
        right_sql = self.process(binary.right, **kw)
        return " %% ".join((left_sql, right_sql))
342
343
class PGIdentifierPreparer_pg8000(PGIdentifierPreparer):
    """PostgreSQL identifier preparer with ``_double_percents`` disabled."""

    def __init__(self, *args, **kwargs):
        PGIdentifierPreparer.__init__(self, *args, **kwargs)
        # presumably stops "%" from being doubled when quoting identifiers;
        # NOTE(review): confirm against PGIdentifierPreparer
        self._double_percents = False
348
349
class PGDialect_pg8000(PGDialect):
    """SQLAlchemy PostgreSQL dialect for the pg8000 DBAPI.

    Requires pg8000 1.16.6 or greater (checked in ``__init__``).
    """

    driver = "pg8000"
    supports_statement_cache = True

    supports_unicode_statements = True

    supports_unicode_binds = True

    default_paramstyle = "format"
    supports_sane_multi_rowcount = True
    execution_ctx_cls = PGExecutionContext_pg8000
    statement_compiler = PGCompiler_pg8000
    preparer = PGIdentifierPreparer_pg8000
    supports_server_side_cursors = True

    use_setinputsizes = True

    # reversed as of pg8000 1.16.6.  1.16.5 and lower
    # are no longer compatible
    description_encoding = None
    # description_encoding = "use_encoding"

    colspecs = util.update_copy(
        PGDialect.colspecs,
        {
            sqltypes.Numeric: _PGNumericNoBind,
            sqltypes.Float: _PGNumeric,
            sqltypes.JSON: _PGJSON,
            sqltypes.Boolean: _PGBoolean,
            sqltypes.NullType: _PGNullType,
            JSONB: _PGJSONB,
            sqltypes.JSON.JSONPathType: _PGJSONPathType,
            sqltypes.JSON.JSONIndexType: _PGJSONIndexType,
            sqltypes.JSON.JSONIntIndexType: _PGJSONIntIndexType,
            sqltypes.JSON.JSONStrIndexType: _PGJSONStrIndexType,
            UUID: _PGUUID,
            sqltypes.Interval: _PGInterval,
            INTERVAL: _PGInterval,
            sqltypes.DateTime: _PGTimeStamp,
            sqltypes.Time: _PGTime,
            sqltypes.Integer: _PGInteger,
            sqltypes.SmallInteger: _PGSmallInteger,
            sqltypes.BigInteger: _PGBigInteger,
            sqltypes.Enum: _PGEnum,
            sqltypes.ARRAY: _PGARRAY,
        },
    )

    def __init__(self, client_encoding=None, **kwargs):
        """Create the dialect.

        :param client_encoding: when not ``None``, an encoding name applied
          with ``SET CLIENT_ENCODING`` on every new connection.
        :raises NotImplementedError: if the installed pg8000 is older than
          1.16.6.
        """
        PGDialect.__init__(self, **kwargs)
        self.client_encoding = client_encoding

        if self._dbapi_version < (1, 16, 6):
            raise NotImplementedError("pg8000 1.16.6 or greater is required")

    @util.memoized_property
    def _dbapi_version(self):
        """Integer tuple parsed from ``pg8000.__version__``.

        Falls back to ``(99, 99, 99)`` when no DBAPI module is loaded or it
        has no ``__version__`` attribute.
        """
        if self.dbapi and hasattr(self.dbapi, "__version__"):
            return tuple(
                [
                    int(x)
                    for x in re.findall(
                        r"(\d+)(?:[-\.]?|$)", self.dbapi.__version__
                    )
                ]
            )
        else:
            return (99, 99, 99)

    @classmethod
    def dbapi(cls):
        return __import__("pg8000")

    def create_connect_args(self, url):
        """Build ``pg8000.connect()`` kwargs from *url*.

        ``username`` is renamed to ``user``, ``port`` is coerced to ``int``,
        and the URL query string entries are passed through as-is.
        """
        opts = url.translate_connect_args(username="user")
        if "port" in opts:
            opts["port"] = int(opts["port"])
        opts.update(url.query)
        return ([], opts)

    def is_disconnect(self, e, connection, cursor):
        """Return True when *e* indicates the connection is unusable."""
        if isinstance(e, self.dbapi.InterfaceError) and "network error" in str(
            e
        ):
            # new as of pg8000 1.19.0 for broken connections
            return True

        # connection was closed normally
        return "connection is closed" in str(e)

    def set_isolation_level(self, connection, level):
        """Apply *level* as the session's default isolation level.

        ``AUTOCOMMIT`` only toggles the connection's ``autocommit`` flag;
        other recognized levels turn autocommit off and issue
        ``SET SESSION CHARACTERISTICS``. Underscores in *level* are treated
        as spaces.

        :raises ArgumentError: for an unrecognized level.
        """
        level = level.replace("_", " ")

        # adjust for ConnectionFairy possibly being present
        if hasattr(connection, "dbapi_connection"):
            connection = connection.dbapi_connection

        if level == "AUTOCOMMIT":
            connection.autocommit = True
        elif level in self._isolation_lookup:
            connection.autocommit = False
            cursor = connection.cursor()
            try:
                cursor.execute(
                    "SET SESSION CHARACTERISTICS AS TRANSACTION "
                    "ISOLATION LEVEL %s" % level
                )
                cursor.execute("COMMIT")
            finally:
                # don't leak the cursor if the SET or COMMIT fails
                cursor.close()
        else:
            raise exc.ArgumentError(
                "Invalid value '%s' for isolation_level. "
                "Valid isolation levels for %s are %s or AUTOCOMMIT"
                % (level, self.name, ", ".join(self._isolation_lookup))
            )

    def set_readonly(self, connection, value):
        """Set the session default to READ ONLY (truthy) or READ WRITE."""
        cursor = connection.cursor()
        try:
            cursor.execute(
                "SET SESSION CHARACTERISTICS AS TRANSACTION %s"
                % ("READ ONLY" if value else "READ WRITE")
            )
            cursor.execute("COMMIT")
        finally:
            cursor.close()

    def get_readonly(self, connection):
        """Return True if ``transaction_read_only`` reports ``on``."""
        cursor = connection.cursor()
        try:
            cursor.execute("show transaction_read_only")
            val = cursor.fetchone()[0]
        finally:
            cursor.close()

        return val == "on"

    def set_deferrable(self, connection, value):
        """Set the session default to DEFERRABLE (truthy) or NOT DEFERRABLE."""
        cursor = connection.cursor()
        try:
            cursor.execute(
                "SET SESSION CHARACTERISTICS AS TRANSACTION %s"
                % ("DEFERRABLE" if value else "NOT DEFERRABLE")
            )
            cursor.execute("COMMIT")
        finally:
            cursor.close()

    def get_deferrable(self, connection):
        """Return True if ``transaction_deferrable`` reports ``on``."""
        cursor = connection.cursor()
        try:
            cursor.execute("show transaction_deferrable")
            val = cursor.fetchone()[0]
        finally:
            cursor.close()

        return val == "on"

    def set_client_encoding(self, connection, client_encoding):
        """Issue ``SET CLIENT_ENCODING`` for *client_encoding* and commit."""
        # adjust for ConnectionFairy possibly being present
        if hasattr(connection, "dbapi_connection"):
            connection = connection.dbapi_connection

        # the value is spliced into a single-quoted SQL literal, so embedded
        # single quotes must be doubled to keep the literal well-formed
        quoted_encoding = client_encoding.replace("'", "''")
        cursor = connection.cursor()
        try:
            cursor.execute("SET CLIENT_ENCODING TO '" + quoted_encoding + "'")
            cursor.execute("COMMIT")
        finally:
            # don't leak the cursor if the SET or COMMIT fails
            cursor.close()

    def do_set_input_sizes(self, cursor, list_of_tuples, context):
        """Forward DBAPI type codes to ``cursor.setinputsizes``.

        Positional paramstyles pass every type (including ``None``) in
        order; named paramstyles pass only keys with a truthy type code.
        """
        if self.positional:
            cursor.setinputsizes(
                *[dbtype for key, dbtype, sqltype in list_of_tuples]
            )
        else:
            cursor.setinputsizes(
                **{
                    key: dbtype
                    for key, dbtype, sqltype in list_of_tuples
                    if dbtype
                }
            )

    # Two-phase commit: the SQLAlchemy xid is wrapped as the tuple
    # (0, xid, "") for pg8000's tpc_* methods.

    def do_begin_twophase(self, connection, xid):
        connection.connection.tpc_begin((0, xid, ""))

    def do_prepare_twophase(self, connection, xid):
        connection.connection.tpc_prepare()

    def do_rollback_twophase(
        self, connection, xid, is_prepared=True, recover=False
    ):
        connection.connection.tpc_rollback((0, xid, ""))

    def do_commit_twophase(
        self, connection, xid, is_prepared=True, recover=False
    ):
        connection.connection.tpc_commit((0, xid, ""))

    def do_recover_twophase(self, connection):
        # element [1] of each recovered tuple is the SQLAlchemy xid
        return [row[1] for row in connection.connection.tpc_recover()]

    def on_connect(self):
        """Return a callable that prepares each new pg8000 connection.

        Hooks run in order. The first always registers the text adapter for
        ``quoted_name`` values. The others run only when configured: apply
        ``client_encoding``, apply the isolation level, and register the JSON
        deserializer for json (OID 114) and jsonb (OID 3802).
        """
        fns = []

        def register_quoted_name(conn):
            conn.py_types[quoted_name] = conn.py_types[util.text_type]

        fns.append(register_quoted_name)

        if self.client_encoding is not None:

            def apply_client_encoding(conn):
                self.set_client_encoding(conn, self.client_encoding)

            fns.append(apply_client_encoding)

        if self.isolation_level is not None:

            def apply_isolation_level(conn):
                self.set_isolation_level(conn, self.isolation_level)

            fns.append(apply_isolation_level)

        if self._json_deserializer:

            def register_json_deserializer(conn):
                # json
                conn.register_in_adapter(114, self._json_deserializer)

                # jsonb
                conn.register_in_adapter(3802, self._json_deserializer)

            fns.append(register_json_deserializer)

        # ``fns`` always holds at least the quoted_name hook, so a combined
        # hook is always returned.
        def on_connect(conn):
            for fn in fns:
                fn(conn)

        return on_connect
592
593
# Module-level alias for this module's dialect class; presumably the name
# SQLAlchemy's dialect loader looks up for "postgresql+pg8000" — confirm.
dialect = PGDialect_pg8000
595