1from datetime import (
2    date as Date,
3    datetime as Datetime,
4    time as Time,
5    timedelta as Timedelta,
6    timezone as Timezone,
7)
8from decimal import Decimal
9from enum import Enum
10from ipaddress import (
11    IPv4Address,
12    IPv4Network,
13    IPv6Address,
14    IPv6Network,
15    ip_address,
16    ip_network,
17)
18from json import dumps, loads
19from uuid import UUID
20
21from pg8000.exceptions import InterfaceError
22
23
# Numeric identifiers for the PostgreSQL types handled by this module,
# grouped by family.  Each *_ARRAY constant sits next to its element type.

# Pseudo / placeholder types.  NULLTYPE is -1 rather than a positive
# identifier like the others.
ANY_ARRAY = 2277
NULLTYPE = -1
UNKNOWN = 705

# Integers and identifiers.  INT2VECTOR and SMALLINT_VECTOR are two names for
# the same value.
SMALLINT = 21
SMALLINT_ARRAY = 1005
SMALLINT_VECTOR = 22
INT2VECTOR = 22
INTEGER = 23
INTEGER_ARRAY = 1007
BIGINT = 20
BIGINT_ARRAY = 1016
OID = 26
XID = 28

# Floating point, exact numerics and money.
REAL = 700
REAL_ARRAY = 1021
FLOAT = 701
FLOAT_ARRAY = 1022
NUMERIC = 1700
NUMERIC_ARRAY = 1231
MONEY = 790
MONEY_ARRAY = 791

# Booleans and binary data.
BOOLEAN = 16
BOOLEAN_ARRAY = 1000
BYTES = 17
BYTES_ARRAY = 1001

# Character data.  STRING and VARCHAR are two names for the same value.
CHAR = 1042
CHAR_ARRAY = 1014
CSTRING = 2275
CSTRING_ARRAY = 1263
NAME = 19
NAME_ARRAY = 1003
STRING = 1043
TEXT = 25
TEXT_ARRAY = 1009
VARCHAR = 1043
VARCHAR_ARRAY = 1015

# Dates, times and intervals.
DATE = 1082
DATE_ARRAY = 1182
TIME = 1083
TIME_ARRAY = 1183
TIMESTAMP = 1114
TIMESTAMP_ARRAY = 1115
TIMESTAMPTZ = 1184
TIMESTAMPTZ_ARRAY = 1185
INTERVAL = 1186
INTERVAL_ARRAY = 1187

# Network addresses.
CIDR = 650
CIDR_ARRAY = 651
INET = 869
INET_ARRAY = 1041
MACADDR = 829

# JSON, UUID and geometry.
JSON = 114
JSON_ARRAY = 199
JSONB = 3802
JSONB_ARRAY = 3807
UUID_TYPE = 2950
UUID_ARRAY = 2951
POINT = 600
83
84
# Bounds for 2-, 4- and 8-byte signed integers.  Each MAX_* is 2 ** (bits - 1),
# which is one greater than the largest value a signed integer of that width
# can hold -- presumably used as an exclusive upper bound; confirm with callers.
MIN_INT2 = -(2 ** 15)
MAX_INT2 = 2 ** 15
MIN_INT4 = -(2 ** 31)
MAX_INT4 = 2 ** 31
MIN_INT8 = -(2 ** 63)
MAX_INT8 = 2 ** 63
88
89
def bool_in(data):
    """Return True only when *data* is exactly the text ``"t"``."""
    is_true = data == "t"
    return is_true
92
93
def bool_out(v):
    """Render a truthy value as ``"true"`` and a falsy one as ``"false"``."""
    if v:
        return "true"
    return "false"
96
97
def bytes_in(data):
    """Decode hex text into ``bytes``, ignoring the first two characters.

    The two skipped characters are presumably the backslash-x prefix that
    ``bytes_out`` writes; they are not validated here.
    """
    hex_digits = data[2:]
    return bytes.fromhex(hex_digits)
100
101
def bytes_out(v):
    """Render a bytes-like value as a backslash-x prefix followed by its hex digits."""
    return f"\\x{v.hex()}"
104
105
def cidr_out(v):
    """Return the ``str()`` form of a network/address value."""
    text = str(v)
    return text
108
109
def cidr_in(data):
    """Parse cidr text into an ``ipaddress`` object.

    Text containing "/" becomes a network (parsed non-strictly, so host bits
    are masked off instead of raising); anything else becomes a single
    address.
    """
    if "/" in data:
        return ip_network(data, False)
    return ip_address(data)
112
113
def date_in(data):
    """Parse ``YYYY-MM-DD`` text into a ``datetime.date``.

    The special values "infinity" and "-infinity" have no ``date``
    equivalent, so they are returned unchanged as strings, the same way
    ``timestamp_in`` treats them.

    Raises:
        ValueError: if *data* is neither a special value nor a valid date.
    """
    if data in ("infinity", "-infinity"):
        return data

    return Datetime.strptime(data, "%Y-%m-%d").date()
116
117
def date_out(v):
    """Return the ISO 8601 text produced by ``v.isoformat()``."""
    iso_text = v.isoformat()
    return iso_text
120
121
def datetime_out(v):
    """Render a datetime as ISO 8601 text.

    Naive values are written as they are; aware values are converted to UTC
    first.
    """
    if v.tzinfo is not None:
        v = v.astimezone(Timezone.utc)
    return v.isoformat()
127
128
def enum_out(v):
    """Return ``str()`` of the member's ``value`` attribute."""
    member_value = v.value
    return str(member_value)
131
132
def float_out(v):
    """Return the ``str()`` form of a float."""
    text = str(v)
    return text
135
136
def inet_in(data):
    """Parse inet text into an ``ipaddress`` object.

    Text containing "/" becomes a network (parsed non-strictly, so any host
    bits are masked off); anything else becomes a single address.
    """
    if "/" in data:
        return ip_network(data, False)
    return ip_address(data)
139
140
def inet_out(v):
    """Return the ``str()`` form of an address or network value."""
    text = str(v)
    return text
143
144
def int_in(data):
    """Convert *data* with the built-in ``int``."""
    number = int(data)
    return number
147
148
def int_out(v):
    """Return the ``str()`` form of an integer."""
    text = str(v)
    return text
151
152
def interval_in(data):
    """Parse interval text into a ``datetime.timedelta``.

    The text is split on whitespace.  A token containing ":" is read as an
    ``HH:MM:SS[.ffffff]`` clock value.  Any other token is either a number or
    a unit word (looked up in ``PGInterval.UNIT_MAP``) that applies to the
    number seen just before it.

    A leading "-" on the clock token applies to hours, minutes and seconds
    together, so "-01:02:03" means minus (1 hour 2 minutes 3 seconds).

    Raises:
        InterfaceError: if the interval uses weeks, months, years, decades,
            centuries or millennia.
        KeyError: if a unit word is not present in ``PGInterval.UNIT_MAP``.
    """
    t = {}

    curr_val = None
    for k in data.split():
        if ":" in k:
            # Only the first field carries the sign, and "-00" parses to a
            # zero that loses it, so read the sign from the token itself and
            # apply it to the magnitude of every field.
            sign = -1 if k.startswith("-") else 1
            hours, minutes, seconds = (abs(float(p)) for p in k.split(":"))
            t["hours"] = sign * hours
            t["minutes"] = sign * minutes
            t["seconds"] = sign * seconds
        else:
            try:
                curr_val = float(k)
            except ValueError:
                # Not a number, so it is the unit for the preceding number.
                t[PGInterval.UNIT_MAP[k]] = curr_val

    for n in ["weeks", "months", "years", "decades", "centuries", "millennia"]:
        if n in t:
            raise InterfaceError(
                f"Can't fit the interval {t} into a datetime.timedelta."
            )

    return Timedelta(**t)
173
174
def interval_out(v):
    """Render *v* from its ``days``, ``seconds`` and ``microseconds`` attributes."""
    days, seconds, microseconds = v.days, v.seconds, v.microseconds
    return f"{days} days {seconds} seconds {microseconds} microseconds"
177
178
def json_in(data):
    """Deserialize JSON text with ``json.loads``."""
    decoded = loads(data)
    return decoded
181
182
def json_out(v):
    """Serialize *v* to JSON text with ``json.dumps``."""
    encoded = dumps(v)
    return encoded
185
186
def null_out(v):
    """Return ``None`` regardless of the argument."""
    return None
189
190
def numeric_in(data):
    """Convert numeric text into a ``decimal.Decimal``."""
    number = Decimal(data)
    return number
193
194
def numeric_out(d):
    """Return the ``str()`` form of a decimal value."""
    text = str(d)
    return text
197
198
def pg_interval_in(data):
    """Parse interval text into a ``PGInterval`` via ``PGInterval.from_str``."""
    interval = PGInterval.from_str(data)
    return interval
201
202
def pg_interval_out(v):
    """Return the ``str()`` form of an interval value."""
    text = str(v)
    return text
205
206
def string_in(data):
    """Return *data* unchanged; text needs no conversion."""
    return data
209
210
def string_out(v):
    """Return *v* unchanged; text needs no conversion."""
    return v
213
214
def time_in(data):
    """Parse ``HH:MM:SS`` text, with optional fractional seconds, into a ``time``."""
    if "." in data:
        parsed = Datetime.strptime(data, "%H:%M:%S.%f")
    else:
        parsed = Datetime.strptime(data, "%H:%M:%S")
    return parsed.time()
218
219
def time_out(v):
    """Return the ISO 8601 text produced by ``v.isoformat()``."""
    iso_text = v.isoformat()
    return iso_text
222
223
def timestamp_in(data):
    """Parse timestamp text into a naive ``datetime``.

    "infinity" and "-infinity" are returned unchanged as strings.  Otherwise
    the text is parsed as ``YYYY-MM-DD HH:MM:SS``, with fractional seconds
    when it contains a ".".
    """
    if data == "infinity" or data == "-infinity":
        return data

    if "." in data:
        return Datetime.strptime(data, "%Y-%m-%d %H:%M:%S.%f")
    return Datetime.strptime(data, "%Y-%m-%d %H:%M:%S")
230
231
def timestamptz_in(data):
    """Parse timestamp-with-time-zone text into an aware ``datetime``.

    "infinity" and "-infinity" are returned unchanged as strings, the same
    way ``timestamp_in`` treats them.

    The UTC offset may be abbreviated to whole hours ("+01").  Such an offset
    is padded to "+0100" because ``%z`` needs the minutes.  Offsets that
    already include minutes ("+05:30") are passed through unchanged; parsing
    the colon form relies on ``%z`` accepting it (Python 3.7 and later).

    Raises:
        ValueError: if *data* is neither a special value nor a valid
            timestamp with offset.
    """
    if data in ("infinity", "-infinity"):
        return data

    patt = "%Y-%m-%d %H:%M:%S.%f%z" if "." in data else "%Y-%m-%d %H:%M:%S%z"

    # An hours-only offset is exactly sign plus two digits at the very end.
    # The slice (not an index) keeps short input on the ValueError path.
    if data[-3:-2] in ("+", "-"):
        data += "00"

    return Datetime.strptime(data, patt)
235
236
def unknown_out(v):
    """Return the ``str()`` form of a value with no dedicated converter."""
    text = str(v)
    return text
239
240
def vector_in(data):
    """Parse whitespace-separated integer text such as ``"1 2 3"`` into a list.

    NOTE(security): this used to build a list literal and pass it to
    ``eval``, which would execute arbitrary expressions found in the data.
    The values are now parsed explicitly with ``int``.

    Raises:
        ValueError: if any item is not an integer.
    """
    return [int(item) for item in data.split()]
243
244
def uuid_out(v):
    """Return the ``str()`` form of a UUID."""
    text = str(v)
    return text
247
248
def uuid_in(data):
    """Convert UUID text into a ``uuid.UUID``."""
    parsed = UUID(data)
    return parsed
251
252
class PGInterval:
    """An interval stored as separate, optional unit fields.

    Every field defaults to ``None``, meaning "not specified".  Unlike
    ``datetime.timedelta`` this can hold months, years and larger units.
    Two intervals compare equal when their normalized months, days and
    seconds match.
    """

    # Maps the unit words understood by from_str() to constructor keywords.
    UNIT_MAP = {
        "millennia": "millennia",
        "millennium": "millennia",
        # Misspelling present in earlier versions; kept so existing callers
        # that relied on it keep working.
        "millenium": "millennia",
        "centuries": "centuries",
        "century": "centuries",
        "decades": "decades",
        "decade": "decades",
        "years": "years",
        "year": "years",
        "months": "months",
        "month": "months",
        "mon": "months",
        "mons": "months",
        "weeks": "weeks",
        "week": "weeks",
        "days": "days",
        "day": "days",
        "hours": "hours",
        "hour": "hours",
        "minutes": "minutes",
        "minute": "minutes",
        "seconds": "seconds",
        "second": "seconds",
        "microseconds": "microseconds",
        "microsecond": "microseconds",
    }

    @staticmethod
    def from_str(interval_str):
        """Build a ``PGInterval`` from text like ``"1 year 2 mons 04:05:06"``.

        The text is split on whitespace.  A token containing ":" is read as
        ``HH:MM:SS[.ffffff]``; a leading "-" on it negates hours, minutes and
        seconds together.  Any other token is an integer or a unit word from
        ``UNIT_MAP`` that applies to the integer just before it.  Clock
        fields equal to zero are left unset.

        Raises:
            KeyError: if a unit word is not in ``UNIT_MAP``.
            ValueError: if a clock field is not numeric.
        """
        t = {}

        curr_val = None
        for k in interval_str.split():
            if ":" in k:
                # The sign is written once, on the hours field, but belongs
                # to the whole clock value ("-00:00:05" is minus 5 seconds).
                sign = -1 if k.startswith("-") else 1
                hours_str, minutes_str, seconds_str = k.lstrip("+-").split(":")

                hours = sign * int(hours_str)
                if hours != 0:
                    t["hours"] = hours

                minutes = sign * int(minutes_str)
                if minutes != 0:
                    t["minutes"] = minutes

                try:
                    seconds = int(seconds_str)
                except ValueError:
                    seconds = float(seconds_str)
                seconds *= sign

                if seconds != 0:
                    t["seconds"] = seconds

            else:
                try:
                    curr_val = int(k)
                except ValueError:
                    # Not a number, so it is the unit for the preceding one.
                    t[PGInterval.UNIT_MAP[k]] = curr_val

        return PGInterval(**t)

    def __init__(
        self,
        millennia=None,
        centuries=None,
        decades=None,
        years=None,
        months=None,
        weeks=None,
        days=None,
        hours=None,
        minutes=None,
        seconds=None,
        microseconds=None,
    ):
        """Store each unit as given; ``None`` means the unit is absent."""
        self.millennia = millennia
        self.centuries = centuries
        self.decades = decades
        self.years = years
        self.months = months
        self.weeks = weeks
        self.days = days
        self.hours = hours
        self.minutes = minutes
        self.seconds = seconds
        self.microseconds = microseconds

    def __repr__(self):
        return f"<PGInterval {self}>"

    def __str__(self):
        """Return the set fields as ``"<value> <unit>"`` pairs, largest unit first."""
        pairs = (
            ("millennia", self.millennia),
            ("centuries", self.centuries),
            ("decades", self.decades),
            ("years", self.years),
            ("months", self.months),
            ("weeks", self.weeks),
            ("days", self.days),
            ("hours", self.hours),
            ("minutes", self.minutes),
            ("seconds", self.seconds),
            ("microseconds", self.microseconds),
        )
        return " ".join(f"{v} {n}" for n, v in pairs if v is not None)

    def normalize(self):
        """Return an equivalent interval using only months, days and seconds.

        Millennia, centuries, decades and years are folded into months
        (12 months per year), weeks into days, and hours, minutes and
        microseconds into seconds.  Months, days and seconds are never
        converted into one another.
        """
        months = 0
        if self.months is not None:
            months += self.months
        if self.years is not None:
            months += self.years * 12
        # The larger calendar units must be counted too, otherwise intervals
        # that differ only in them would normalize to the same value.
        if self.decades is not None:
            months += self.decades * 12 * 10
        if self.centuries is not None:
            months += self.centuries * 12 * 100
        if self.millennia is not None:
            months += self.millennia * 12 * 1000

        days = 0
        if self.days is not None:
            days += self.days
        if self.weeks is not None:
            days += self.weeks * 7

        seconds = 0
        if self.hours is not None:
            seconds += self.hours * 60 * 60
        if self.minutes is not None:
            seconds += self.minutes * 60
        if self.seconds is not None:
            seconds += self.seconds
        if self.microseconds is not None:
            seconds += self.microseconds / 1000000

        return PGInterval(months=months, days=days, seconds=seconds)

    def __eq__(self, other):
        """Compare normalized months, days and seconds; other types are unequal."""
        if isinstance(other, PGInterval):
            s = self.normalize()
            o = other.normalize()
            return s.months == o.months and s.days == o.days and s.seconds == o.seconds
        else:
            return False
390
391
class ArrayState(Enum):
    """Parser states used by ``_parse_array`` while scanning array text."""

    # Inside a double-quoted element.
    InString = 1
    # Immediately after a backslash inside a quoted element.
    InEscape = 2
    # Inside an unquoted element.
    InValue = 3
    # Between elements.
    Out = 4
397
398
def _parse_array(data, adapter):
    """Parse array text such as ``{1,2,{3,NULL}}`` into nested lists.

    The text is scanned one character at a time with a small state machine
    (see ``ArrayState``).  Each element's text is passed to *adapter*, except
    an unquoted ``NULL``, which becomes ``None``.  A quoted ``"NULL"`` is
    passed to *adapter* like any other string.

    Returns the list built for the outermost pair of braces.
    """
    state = ArrayState.Out
    # stack[-1] is the list currently being filled.  stack[0] is a root
    # holder that receives the outermost array so it can be returned at the
    # end.
    stack = [[]]
    val = []
    for c in data:
        if state == ArrayState.InValue:
            if c in ("}", ","):
                # A delimiter ends the unquoted element.
                value = "".join(val)
                stack[-1].append(None if value == "NULL" else adapter(value))
                state = ArrayState.Out
            else:
                val.append(c)

        # Deliberately "if", not "elif": after an unquoted element ends, the
        # same delimiter character must still be handled below.
        if state == ArrayState.Out:
            if c == "{":
                # Open a nested array and make it the current list.
                a = []
                stack[-1].append(a)
                stack.append(a)
            elif c == "}":
                stack.pop()
            elif c == ",":
                pass
            elif c == '"':
                val = []
                state = ArrayState.InString
            else:
                # Any other character starts an unquoted element.
                val = [c]
                state = ArrayState.InValue

        elif state == ArrayState.InString:
            if c == '"':
                # Closing quote: the element is complete.
                stack[-1].append(adapter("".join(val)))
                state = ArrayState.Out
            elif c == "\\":
                state = ArrayState.InEscape
            else:
                val.append(c)
        elif state == ArrayState.InEscape:
            # The character after a backslash is taken literally.
            val.append(c)
            state = ArrayState.InString

    return stack[0][0]
441
442
def _array_in(adapter):
    """Return a one-argument converter that parses array text with *adapter*.

    The returned function passes its text and *adapter* to ``_parse_array``.
    """

    def array_adapter(data):
        return _parse_array(data, adapter)

    return array_adapter
448
449
# Array converters.  Each one parses array text and converts its elements
# with the scalar converter passed to _array_in().
bool_array_in = _array_in(bool_in)
bytes_array_in = _array_in(bytes_in)
cidr_array_in = _array_in(cidr_in)
date_array_in = _array_in(date_in)
float_array_in = _array_in(float)
inet_array_in = _array_in(inet_in)
int_array_in = _array_in(int)
interval_array_in = _array_in(interval_in)
json_array_in = _array_in(json_in)
numeric_array_in = _array_in(numeric_in)
string_array_in = _array_in(string_in)
time_array_in = _array_in(time_in)
timestamp_array_in = _array_in(timestamp_in)
timestamptz_array_in = _array_in(timestamptz_in)
uuid_array_in = _array_in(uuid_in)
465
466
def array_string_escape(v):
    """Escape a string so it can be used as an element of an array literal.

    Backslashes and double quotes are prefixed with a backslash.  The result
    is wrapped in double quotes when leaving it bare would change its
    meaning:

    * it is empty;
    * it spells NULL in any letter case, because an unquoted NULL in any
      case is read as a null element rather than a string;
    * it contains whitespace of any kind (unquoted leading or trailing
      whitespace is dropped when the literal is read), a brace, a comma or a
      backslash.
    """
    escaped = []
    for c in v:
        if c in ("\\", '"'):
            escaped.append("\\")
        escaped.append(c)
    val = "".join(escaped)

    if (
        not val
        or val.upper() == "NULL"
        or any(c.isspace() or c in "{},\\" for c in val)
    ):
        val = f'"{val}"'
    return val
483
484
def array_out(ar):
    """Render a list or tuple as brace-delimited array text.

    Nested lists and tuples are rendered recursively, ``None`` becomes
    ``NULL``, dicts are written as escaped JSON, bytes as a quoted hex
    string, and strings are escaped.  Any other value is converted with
    ``make_param`` using the module-level ``PY_TYPES`` table.
    """

    def render(item):
        if isinstance(item, (list, tuple)):
            return array_out(item)
        if item is None:
            return "NULL"
        if isinstance(item, dict):
            return array_string_escape(json_out(item))
        if isinstance(item, (bytes, bytearray)):
            return f'"\\{bytes_out(item)}"'
        if isinstance(item, str):
            return array_string_escape(item)
        return make_param(PY_TYPES, item)

    rendered = [render(item) for item in ar]
    return "{" + ",".join(rendered) + "}"
510
511
# Maps a Python type to one of the type identifier constants defined at the
# top of this module.  Both address and network classes map to INET, both
# PGInterval and timedelta map to INTERVAL, and dict maps to JSONB.
PY_PG = {
    Date: DATE,
    Decimal: NUMERIC,
    IPv4Address: INET,
    IPv6Address: INET,
    IPv4Network: INET,
    IPv6Network: INET,
    PGInterval: INTERVAL,
    Time: TIME,
    Timedelta: INTERVAL,
    UUID: UUID_TYPE,
    bool: BOOLEAN,
    bytearray: BYTES,
    dict: JSONB,
    float: FLOAT,
    type(None): NULLTYPE,
    bytes: BYTES,
    str: TEXT,
}
531
532
# Maps a Python type to the function that renders a value of that type as
# text.  make_param() first looks up the exact type and otherwise walks this
# dict in order using isinstance(), so the entry order is significant.
PY_TYPES = {
    Date: date_out,  # date
    Datetime: datetime_out,  # timestamp / timestamptz
    Decimal: numeric_out,  # numeric
    Enum: enum_out,  # enum
    IPv4Address: inet_out,  # inet
    IPv6Address: inet_out,  # inet
    IPv4Network: inet_out,  # inet
    IPv6Network: inet_out,  # inet
    # PGInterval fields may be None and include months and years, so it must
    # be rendered via its own str() form; interval_out() only understands
    # the days/seconds/microseconds attributes of a timedelta.
    PGInterval: pg_interval_out,  # interval
    Time: time_out,  # time
    Timedelta: interval_out,  # interval
    UUID: uuid_out,  # uuid
    bool: bool_out,  # bool
    bytearray: bytes_out,  # bytea
    dict: json_out,  # jsonb
    float: float_out,  # float8
    type(None): null_out,  # null
    bytes: bytes_out,  # bytea
    str: string_out,  # unknown
    int: int_out,  # integer
    list: array_out,  # array
    tuple: array_out,  # array
}
557
558
# Maps a type identifier to the function that converts the text received for
# that type into a Python value.
PG_TYPES = {
    BIGINT: int,  # int8
    BIGINT_ARRAY: int_array_in,  # int8[]
    BOOLEAN: bool_in,  # bool
    BOOLEAN_ARRAY: bool_array_in,  # bool[]
    BYTES: bytes_in,  # bytea
    BYTES_ARRAY: bytes_array_in,  # bytea[]
    CHAR: string_in,  # char
    CHAR_ARRAY: string_array_in,  # char[]
    # Scalar cidr was previously unregistered even though cidr[] was, so the
    # two were converted inconsistently.
    CIDR: cidr_in,  # cidr
    CIDR_ARRAY: cidr_array_in,  # cidr[]
    CSTRING: string_in,  # cstring
    CSTRING_ARRAY: string_array_in,  # cstring[]
    DATE: date_in,  # date
    DATE_ARRAY: date_array_in,  # date[]
    FLOAT: float,  # float8
    FLOAT_ARRAY: float_array_in,  # float8[]
    INET: inet_in,  # inet
    INET_ARRAY: inet_array_in,  # inet[]
    INTEGER: int,  # int4
    INTEGER_ARRAY: int_array_in,  # int4[]
    JSON: json_in,  # json
    JSON_ARRAY: json_array_in,  # json[]
    JSONB: json_in,  # jsonb
    JSONB_ARRAY: json_array_in,  # jsonb[]
    MACADDR: string_in,  # macaddr
    MONEY: string_in,  # money
    MONEY_ARRAY: string_array_in,  # money[]
    NAME: string_in,  # name
    NAME_ARRAY: string_array_in,  # name[]
    NUMERIC: numeric_in,  # numeric
    NUMERIC_ARRAY: numeric_array_in,  # numeric[]
    OID: int,  # oid
    INTERVAL: interval_in,  # interval
    INTERVAL_ARRAY: interval_array_in,  # interval[]
    REAL: float,  # float4
    REAL_ARRAY: float_array_in,  # float4[]
    SMALLINT: int,  # int2
    SMALLINT_ARRAY: int_array_in,  # int2[]
    SMALLINT_VECTOR: vector_in,  # int2vector
    TEXT: string_in,  # text
    TEXT_ARRAY: string_array_in,  # text[]
    TIME: time_in,  # time
    TIME_ARRAY: time_array_in,  # time[]
    TIMESTAMP: timestamp_in,  # timestamp
    TIMESTAMP_ARRAY: timestamp_array_in,  # timestamp[]
    TIMESTAMPTZ: timestamptz_in,  # timestamptz
    TIMESTAMPTZ_ARRAY: timestamptz_array_in,  # timestamptz[]
    UNKNOWN: string_in,  # unknown
    UUID_ARRAY: uuid_array_in,  # uuid[]
    UUID_TYPE: uuid_in,  # uuid
    VARCHAR: string_in,  # varchar
    VARCHAR_ARRAY: string_array_in,  # varchar[]
    XID: int,  # xid
}
614
615
# PostgreSQL encodings:
# https://www.postgresql.org/docs/current/multibyte.html
#
# Python encodings:
# https://docs.python.org/3/library/codecs.html
#
# Maps a PostgreSQL encoding name to the Python codec name to use for it.
# Encodings that appear only as comments don't require a name change between
# PostgreSQL and Python.  If the Python side is None, the encoding isn't
# supported.
PG_PY_ENCODINGS = {
    # Not supported:
    "mule_internal": None,
    "euc_tw": None,
    # Name fine as-is:
    # "euc_jp",
    # "euc_jis_2004",
    # "euc_kr",
    # "gb18030",
    # "gbk",
    # "johab",
    # "sjis",
    # "shift_jis_2004",
    # "uhc",
    # "utf8",
    # Different name:
    "euc_cn": "gb2312",
    "iso_8859_5": "iso8859_5",
    "iso_8859_6": "iso8859_6",
    "iso_8859_7": "iso8859_7",
    "iso_8859_8": "iso8859_8",
    "koi8": "koi8_r",
    "latin1": "iso8859-1",
    "latin2": "iso8859_2",
    "latin3": "iso8859_3",
    "latin4": "iso8859_4",
    "latin5": "iso8859_9",
    "latin6": "iso8859_10",
    "latin7": "iso8859_13",
    "latin8": "iso8859_14",
    "latin9": "iso8859_15",
    "sql_ascii": "ascii",
    "win866": "cp866",
    "win874": "cp874",
    "win1250": "cp1250",
    "win1251": "cp1251",
    "win1252": "cp1252",
    "win1253": "cp1253",
    "win1254": "cp1254",
    "win1255": "cp1255",
    "win1256": "cp1256",
    "win1257": "cp1257",
    "win1258": "cp1258",
    "unicode": "utf-8",  # Needed for Amazon Redshift
}
669
670
def make_param(py_types, value):
    """Convert *value* using the converter registered for its type.

    The exact type of *value* is looked up in *py_types* first.  If it is not
    a key, the entries are tried in order and the first key that *value* is
    an instance of wins.  Keys that ``isinstance`` rejects with ``TypeError``
    are skipped.  When nothing matches, ``str`` is used.
    """
    try:
        converter = py_types[type(value)]
    except KeyError:
        for candidate_type, candidate in py_types.items():
            try:
                matched = isinstance(value, candidate_type)
            except TypeError:
                continue
            if matched:
                converter = candidate
                break
        else:
            converter = str

    return converter(value)
685
686
def make_params(py_types, values):
    """Convert every item of *values* with ``make_param`` and return a tuple."""
    converted = [make_param(py_types, item) for item in values]
    return tuple(converted)
689