1# postgresql/pygresql.py
2# Copyright (C) 2005-2019 the SQLAlchemy authors and contributors
3# <see AUTHORS file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: http://www.opensource.org/licenses/mit-license.php
7"""
8.. dialect:: postgresql+pygresql
9    :name: pygresql
10    :dbapi: pgdb
11    :connectstring: postgresql+pygresql://user:password@host:port/dbname[?key=value&key=value...]
12    :url: http://www.pygresql.org/
13"""  # noqa
14
15import decimal
16import re
17
18from .base import _DECIMAL_TYPES
19from .base import _FLOAT_TYPES
20from .base import _INT_TYPES
21from .base import PGCompiler
22from .base import PGDialect
23from .base import PGIdentifierPreparer
24from .base import UUID
25from .hstore import HSTORE
26from .json import JSON
27from .json import JSONB
28from ... import exc
29from ... import processors
30from ... import util
31from ...sql.elements import Null
32from ...types import JSON as Json
33from ...types import Numeric
34
35
class _PGNumeric(Numeric):
    """Numeric type tuned to the Python types PyGreSQL hands back."""

    def bind_processor(self, dialect):
        # No bind-side conversion: parameters go to the driver unchanged.
        return None

    def result_processor(self, dialect, coltype):
        """Pick a converter so results honour ``self.asdecimal``.

        :param coltype: either an integer type OID or an object exposing
            the OID as ``.oid``.
        :return: a one-argument converter, or ``None`` when the driver's
            native value already has the wanted Python type.
        :raises InvalidRequestError: if the OID is not a known float,
            decimal or integer PostgreSQL type.
        """
        oid = coltype if isinstance(coltype, int) else coltype.oid
        is_float = oid in _FLOAT_TYPES
        is_exact = oid in _DECIMAL_TYPES or oid in _INT_TYPES

        if is_float:
            if self.asdecimal:
                return processors.to_decimal_processor_factory(
                    decimal.Decimal, self._effective_decimal_return_scale
                )
            # PyGreSQL returns float natively for 701 (float8)
            return None

        if is_exact:
            if self.asdecimal:
                # PyGreSQL returns Decimal natively for 1700 (numeric)
                return None
            return processors.to_float

        raise exc.InvalidRequestError(
            "Unknown PG numeric type: %d" % oid
        )
65
66
class _PGHStore(HSTORE):
    """HSTORE that uses the driver's ``Hstore`` wrapper when the dialect
    reports native hstore support, and the generic HSTORE handling
    otherwise."""

    def bind_processor(self, dialect):
        if dialect.has_native_hstore:
            wrap = dialect.dbapi.Hstore

            def process(value):
                # Only dicts are wrapped; anything else passes through.
                return wrap(value) if isinstance(value, dict) else value

            return process
        return super(_PGHStore, self).bind_processor(dialect)

    def result_processor(self, dialect, coltype):
        if dialect.has_native_hstore:
            # Native mode: result values are returned as the driver gives them.
            return None
        return super(_PGHStore, self).result_processor(dialect, coltype)
83
84
class _PGJSON(JSON):
    """JSON type that binds through the driver's ``Json`` wrapper when the
    dialect reports native JSON support."""

    def bind_processor(self, dialect):
        if not dialect.has_native_json:
            return super(_PGJSON, self).bind_processor(dialect)
        wrap = dialect.dbapi.Json

        def process(value):
            if value is self.NULL:
                # JSON.NULL is sent as a wrapped None (a JSON null), in
                # contrast to the SQL NULL cases below which return None.
                return wrap(None)
            if value is None:
                return None if self.none_as_null else wrap(None)
            if isinstance(value, Null):
                return None
            if isinstance(value, (dict, list)):
                return wrap(value)
            # Other scalars are handed to the driver unwrapped.
            return value

        return process

    def result_processor(self, dialect, coltype):
        if dialect.has_native_json:
            # Native mode: result values are returned as the driver gives them.
            return None
        return super(_PGJSON, self).result_processor(dialect, coltype)
107
108
class _PGJSONB(JSONB):
    """JSONB type that binds through the driver's ``Json`` wrapper when the
    dialect reports native JSON support."""

    def bind_processor(self, dialect):
        if not dialect.has_native_json:
            return super(_PGJSONB, self).bind_processor(dialect)
        wrap = dialect.dbapi.Json

        def process(value):
            if value is self.NULL:
                # JSON.NULL is sent as a wrapped None (a JSON null), in
                # contrast to the SQL NULL cases below which return None.
                return wrap(None)
            if value is None:
                return None if self.none_as_null else wrap(None)
            if isinstance(value, Null):
                return None
            if isinstance(value, (dict, list)):
                return wrap(value)
            # Other scalars are handed to the driver unwrapped.
            return value

        return process

    def result_processor(self, dialect, coltype):
        if dialect.has_native_json:
            # Native mode: result values are returned as the driver gives them.
            return None
        return super(_PGJSONB, self).result_processor(dialect, coltype)
131
132
class _PGUUID(UUID):
    """UUID type that binds through the driver's ``Uuid`` class when the
    dialect reports native UUID support."""

    def bind_processor(self, dialect):
        """Return a processor coercing bind values to ``dialect.dbapi.Uuid``.

        Accepted inputs in native mode:

        * ``None`` -> ``None``
        * ``bytes`` of length 16 -> raw binary UUID
        * other ``bytes`` -> decoded as ASCII and parsed as UUID text
        * ``str`` -> parsed as UUID text
        * ``int`` -> 128-bit integer form
        * anything else (e.g. an existing UUID object) -> passed through
        """
        if not dialect.has_native_uuid:
            return super(_PGUUID, self).bind_processor(dialect)
        # NOTE(review): presumably pgdb's Uuid is uuid.UUID (it accepts the
        # bytes=/int= keywords used below) -- confirm for older PyGreSQL.
        uuid = dialect.dbapi.Uuid

        def process(value):
            if value is None:
                return None
            if isinstance(value, bytes):
                # Only bytes can be the 16-byte raw form; a 16-char *text*
                # string must never reach Uuid(bytes=...), which requires a
                # bytes object.  Longer/shorter bytes are textual UUIDs, so
                # decode them instead of passing bytes to the hex parser.
                if len(value) == 16:
                    return uuid(bytes=value)
                return uuid(value.decode("ascii"))
            if isinstance(value, str):
                return uuid(value)
            if isinstance(value, int):
                return uuid(int=value)
            return value

        return process

    def result_processor(self, dialect, coltype):
        """In native mode, stringify results unless ``as_uuid`` is set."""
        if not dialect.has_native_uuid:
            return super(_PGUUID, self).result_processor(dialect, coltype)
        if not self.as_uuid:

            def process(value):
                if value is not None:
                    return str(value)

            return process
162
163
class _PGCompiler(PGCompiler):
    """Statement compiler that doubles literal ``%`` characters.

    NOTE(review): presumably needed because the driver treats ``%`` as the
    start of a parameter marker -- confirm against pgdb's paramstyle.
    """

    def visit_mod_binary(self, binary, operator, **kw):
        # Render "left %% right": the modulo operator, escaped.
        return " %% ".join(
            [
                self.process(binary.left, **kw),
                self.process(binary.right, **kw),
            ]
        )

    def post_process_text(self, text):
        # Escape every '%' in textual SQL fragments.
        escaped = text.replace("%", "%%")
        return escaped
174
175
class _PGIdentifierPreparer(PGIdentifierPreparer):
    """Identifier preparer that also doubles ``%`` inside identifiers."""

    def _escape_identifier(self, value):
        # Standard quote escaping first, then '%' escaping, matching the
        # compiler's treatment of '%' in statement text.
        quoted = value.replace(self.escape_quote, self.escape_to_quote)
        return quoted.replace("%", "%%")
180
181
class PGDialect_pygresql(PGDialect):
    """PostgreSQL dialect backed by the PyGreSQL ``pgdb`` DB-API module."""

    driver = "pygresql"

    statement_compiler = _PGCompiler
    preparer = _PGIdentifierPreparer

    @classmethod
    def dbapi(cls):
        """Import and return the ``pgdb`` module."""
        import pgdb

        return pgdb

    # Swap in the PyGreSQL-specific type implementations from this module.
    colspecs = util.update_copy(
        PGDialect.colspecs,
        {
            Numeric: _PGNumeric,
            HSTORE: _PGHStore,
            Json: _PGJSON,
            JSON: _PGJSON,
            JSONB: _PGJSONB,
            UUID: _PGUUID,
        },
    )

    def __init__(self, **kwargs):
        """Detect the pgdb version and configure native type support.

        ``self.dbapi_version`` becomes a ``(major, minor)`` tuple, or
        ``(0, 0)`` when the version cannot be read or parsed.  Versions
        >= 5.0 enable unicode statements/binds and native hstore, json
        and uuid handling; a known version below 5.0 triggers a warning.
        """
        super(PGDialect_pygresql, self).__init__(**kwargs)
        try:
            match = re.match(r"(\d+)\.(\d+)", self.dbapi.version)
            dbapi_version = tuple(int(part) for part in match.groups())
        except (AttributeError, ValueError, TypeError):
            dbapi_version = (0, 0)
        self.dbapi_version = dbapi_version

        native = dbapi_version >= (5, 0)
        if native:
            self.supports_unicode_statements = True
            self.supports_unicode_binds = True
        elif dbapi_version != (0, 0):
            util.warn(
                "PyGreSQL is only fully supported by SQLAlchemy"
                " since version 5.0."
            )
        self.has_native_hstore = native
        self.has_native_json = native
        self.has_native_uuid = native

    def create_connect_args(self, url):
        """Translate *url* into ``([], kwargs)`` for the driver's connect().

        The port is folded into ``host`` as ``host:port`` (any existing
        ``:suffix`` on the host is dropped first); URL query parameters are
        merged in last and therefore win over translated values.
        """
        opts = url.translate_connect_args(username="user")
        if "port" in opts:
            bare_host = opts.get("host", "").rsplit(":", 1)[0]
            opts["host"] = "%s:%s" % (bare_host, opts.pop("port"))
        opts.update(url.query)
        return [], opts

    def is_disconnect(self, e, connection, cursor):
        """Return True if *e* is a DBAPI error on a connection now closed."""
        if not isinstance(e, self.dbapi.Error) or not connection:
            return False
        # Unwrap to the inner ``.connection`` when the object has one.
        try:
            raw = connection.connection
        except AttributeError:
            raw = connection
        else:
            if not raw:
                return False
        try:
            return raw.closed
        except AttributeError:  # PyGreSQL < 5.0
            return raw._cnx is None
257
258
# Module-level alias for the dialect class.
# NOTE(review): presumably the name SQLAlchemy's dialect loader looks up
# for "postgresql+pygresql" -- confirm against the dialect registry.
dialect = PGDialect_pygresql
260