from datetime import (
    date as Date,
    datetime as Datetime,
    time as Time,
    timedelta as Timedelta,
    timezone as Timezone,
)
from decimal import Decimal
from enum import Enum
from ipaddress import (
    IPv4Address,
    IPv4Network,
    IPv6Address,
    IPv6Network,
    ip_address,
    ip_network,
)
from json import dumps, loads
from uuid import UUID

from pg8000.exceptions import InterfaceError


# PostgreSQL type OIDs. NULLTYPE is not a real OID: it is the code used for
# Python None. Some names are aliases for the same OID
# (INT2VECTOR/SMALLINT_VECTOR, STRING/VARCHAR).
ANY_ARRAY = 2277
BIGINT = 20
BIGINT_ARRAY = 1016
BOOLEAN = 16
BOOLEAN_ARRAY = 1000
BYTES = 17
BYTES_ARRAY = 1001
CHAR = 1042
CHAR_ARRAY = 1014
CIDR = 650
CIDR_ARRAY = 651
CSTRING = 2275
CSTRING_ARRAY = 1263
DATE = 1082
DATE_ARRAY = 1182
FLOAT = 701
FLOAT_ARRAY = 1022
INET = 869
INET_ARRAY = 1041
INT2VECTOR = 22
INTEGER = 23
INTEGER_ARRAY = 1007
INTERVAL = 1186
INTERVAL_ARRAY = 1187
OID = 26
JSON = 114
JSON_ARRAY = 199
JSONB = 3802
JSONB_ARRAY = 3807
MACADDR = 829
MONEY = 790
MONEY_ARRAY = 791
NAME = 19
NAME_ARRAY = 1003
NUMERIC = 1700
NUMERIC_ARRAY = 1231
NULLTYPE = -1
POINT = 600
REAL = 700
REAL_ARRAY = 1021
SMALLINT = 21
SMALLINT_ARRAY = 1005
SMALLINT_VECTOR = 22
STRING = 1043
TEXT = 25
TEXT_ARRAY = 1009
TIME = 1083
TIME_ARRAY = 1183
TIMESTAMP = 1114
TIMESTAMP_ARRAY = 1115
TIMESTAMPTZ = 1184
TIMESTAMPTZ_ARRAY = 1185
UNKNOWN = 705
UUID_TYPE = 2950
UUID_ARRAY = 2951
VARCHAR = 1043
VARCHAR_ARRAY = 1015
XID = 28


# Bounds of 2-, 4- and 8-byte signed integers. The MAX_* values are one past
# the largest representable value, i.e. exclusive upper bounds.
MIN_INT2, MAX_INT2 = -(2 ** 15), 2 ** 15
MIN_INT4, MAX_INT4 = -(2 ** 31), 2 ** 31
MIN_INT8, MAX_INT8 = -(2 ** 63), 2 ** 63


def bool_in(data):
    """Decode a text-format boolean: only "t" is true."""
    return data == "t"


def bool_out(v):
    """Encode a truthy/falsy value as "true"/"false"."""
    return "true" if v else "false"


def bytes_in(data):
    """Decode bytea in hex format ("\\x" followed by hex digits)."""
    return bytes.fromhex(data[2:])


def bytes_out(v):
    """Encode bytes as bytea hex format ("\\x" followed by hex digits)."""
    return "\\x" + v.hex()


def cidr_out(v):
    """Encode an ipaddress network/address as its string form."""
    return str(v)


def cidr_in(data):
    """Decode text to an ip_network (if it has a "/") or an ip_address.

    ``strict=False`` tolerates host bits being set in the network part.
    """
    return ip_network(data, False) if "/" in data else ip_address(data)


def date_in(data):
    """Decode a "YYYY-MM-DD" date."""
    return Datetime.strptime(data, "%Y-%m-%d").date()


def date_out(v):
    """Encode a date in ISO 8601 format."""
    return v.isoformat()


def datetime_out(v):
    """Encode a datetime in ISO 8601 format.

    Naive datetimes are sent as-is; aware ones are converted to UTC first.
    """
    if v.tzinfo is None:
        return v.isoformat()
    else:
        return v.astimezone(Timezone.utc).isoformat()


def enum_out(v):
    """Encode an Enum member as the string form of its value."""
    return str(v.value)


def float_out(v):
    """Encode a float with Python's str()."""
    return str(v)


def inet_in(data):
    """Decode text to an ip_network (if it has a "/") or an ip_address."""
    return ip_network(data, False) if "/" in data else ip_address(data)


def inet_out(v):
    """Encode an ipaddress object as its string form."""
    return str(v)


def int_in(data):
    """Decode an integer."""
    return int(data)


def int_out(v):
    """Encode an integer."""
    return str(v)


def interval_in(data):
    """Decode interval text such as "1 day -01:30:00" into a timedelta.

    Raises:
        InterfaceError: if the interval uses calendar units (weeks, months,
            years, ...) that a datetime.timedelta cannot represent.
        KeyError: if a unit word is not in PGInterval.UNIT_MAP.
    """
    t = {}

    curr_val = None
    for k in data.split():
        if ":" in k:
            # A sign in front of the time part applies to every field in it:
            # "-01:30:00" means minus one and a half hours, not -1h + 30m.
            sign = -1.0 if k.startswith("-") else 1.0
            hours, minutes, seconds = map(float, k.lstrip("+-").split(":"))
            t["hours"] = sign * hours
            t["minutes"] = sign * minutes
            t["seconds"] = sign * seconds
        else:
            try:
                curr_val = float(k)
            except ValueError:
                # Not a number, so it is the unit word for the preceding value.
                t[PGInterval.UNIT_MAP[k]] = curr_val

    for n in ["weeks", "months", "years", "decades", "centuries", "millennia"]:
        if n in t:
            raise InterfaceError(
                f"Can't fit the interval {t} into a datetime.timedelta."
            )

    return Timedelta(**t)


def interval_out(v):
    """Encode a timedelta-like object (days/seconds/microseconds attributes)."""
    return f"{v.days} days {v.seconds} seconds {v.microseconds} microseconds"


def json_in(data):
    """Decode JSON text."""
    return loads(data)


def json_out(v):
    """Encode a value as JSON text."""
    return dumps(v)


def null_out(v):
    """Encode None: the parameter is sent as None."""
    return None


def numeric_in(data):
    """Decode a numeric value into a Decimal (no float rounding)."""
    return Decimal(data)


def numeric_out(d):
    """Encode a Decimal."""
    return str(d)


def pg_interval_in(data):
    """Decode interval text into a PGInterval (keeps calendar units)."""
    return PGInterval.from_str(data)


def pg_interval_out(v):
    """Encode a PGInterval using its str() form."""
    return str(v)


def string_in(data):
    """Return text unchanged."""
    return data


def string_out(v):
    """Return text unchanged."""
    return v


def time_in(data):
    """Decode "HH:MM:SS" with optional fractional seconds."""
    pattern = "%H:%M:%S.%f" if "." in data else "%H:%M:%S"
    return Datetime.strptime(data, pattern).time()


def time_out(v):
    """Encode a time in ISO 8601 format."""
    return v.isoformat()


def timestamp_in(data):
    """Decode a timestamp without time zone.

    The special values "infinity" and "-infinity" are returned as strings
    because datetime cannot represent them.
    """
    if data in ("infinity", "-infinity"):
        return data

    pattern = "%Y-%m-%d %H:%M:%S.%f" if "." in data else "%Y-%m-%d %H:%M:%S"
    return Datetime.strptime(data, pattern)


def timestamptz_in(data):
    """Decode a timestamp with time zone into an aware datetime.

    The offset may be "+HH", "+HH:MM" or "+HH:MM:SS". "infinity" and
    "-infinity" are returned as strings, as in timestamp_in.
    """
    if data in ("infinity", "-infinity"):
        return data

    patt = "%Y-%m-%d %H:%M:%S.%f%z" if "." in data else "%Y-%m-%d %H:%M:%S%z"
    # The offset sign is the last '+' or '-' in the string; the dashes in the
    # date part always come earlier.
    sign_at = max(data.rfind("+"), data.rfind("-"))
    if len(data) - sign_at - 1 == 2:
        # A bare "+HH" offset: %z needs minutes, so add "00". Offsets that
        # already carry minutes must not be padded ("+05:3000" is rejected).
        data += "00"
    return Datetime.strptime(data, patt)


def unknown_out(v):
    """Encode a value of unknown type with str()."""
    return str(v)


def vector_in(data):
    """Decode a space-separated vector of integers (e.g. int2vector "1 2 3").

    Parsed with int() instead of eval() so server-supplied text is never
    executed as Python code.
    """
    return [int(item) for item in data.split()]


def uuid_out(v):
    """Encode a UUID in its canonical string form."""
    return str(v)


def uuid_in(data):
    """Decode a UUID string."""
    return UUID(data)


class PGInterval:
    """An interval that can hold calendar units (millennia ... weeks).

    Unlike datetime.timedelta, this can represent months and years. Every
    component is optional; unset components are None.
    """

    # Unit words that may appear in interval text -> PGInterval field names.
    UNIT_MAP = {
        "millennia": "millennia",
        "millennium": "millennia",
        "millenium": "millennia",  # misspelling kept for compatibility
        "centuries": "centuries",
        "century": "centuries",
        "decades": "decades",
        "decade": "decades",
        "years": "years",
        "year": "years",
        "months": "months",
        "month": "months",
        "mon": "months",
        "mons": "months",
        "weeks": "weeks",
        "week": "weeks",
        "days": "days",
        "day": "days",
        "hours": "hours",
        "hour": "hours",
        "minutes": "minutes",
        "minute": "minutes",
        "seconds": "seconds",
        "second": "seconds",
        "microseconds": "microseconds",
        "microsecond": "microseconds",
    }

    @staticmethod
    def from_str(interval_str):
        """Parse text such as "1 year 2 mons 3 days -04:05:06.5".

        "<number> <unit>" pairs set the named field. An "HH:MM:SS[.f]" token
        sets hours/minutes/seconds, with any leading sign applied to all
        three. Zero hours, minutes or seconds are left as None.

        Raises:
            KeyError: if a unit word is not in UNIT_MAP.
        """
        t = {}

        curr_val = None
        for k in interval_str.split():
            if ":" in k:
                # "-01:30:00" is -(1h 30m): the sign covers every field.
                sign = -1 if k.startswith("-") else 1
                hours_str, minutes_str, seconds_str = k.lstrip("+-").split(":")
                hours = sign * int(hours_str)
                if hours != 0:
                    t["hours"] = hours
                minutes = sign * int(minutes_str)
                if minutes != 0:
                    t["minutes"] = minutes
                try:
                    seconds = sign * int(seconds_str)
                except ValueError:
                    # Fractional seconds, e.g. "06.5".
                    seconds = sign * float(seconds_str)

                if seconds != 0:
                    t["seconds"] = seconds

            else:
                try:
                    curr_val = int(k)
                except ValueError:
                    # Not a number, so it is the unit of the preceding value.
                    t[PGInterval.UNIT_MAP[k]] = curr_val

        return PGInterval(**t)

    def __init__(
        self,
        millennia=None,
        centuries=None,
        decades=None,
        years=None,
        months=None,
        weeks=None,
        days=None,
        hours=None,
        minutes=None,
        seconds=None,
        microseconds=None,
    ):
        self.millennia = millennia
        self.centuries = centuries
        self.decades = decades
        self.years = years
        self.months = months
        self.weeks = weeks
        self.days = days
        self.hours = hours
        self.minutes = minutes
        self.seconds = seconds
        self.microseconds = microseconds

    def __repr__(self):
        return f"<PGInterval {self}>"

    def __str__(self):
        """Render the set components as "<value> <unit>" pairs, largest first."""
        pairs = (
            ("millennia", self.millennia),
            ("centuries", self.centuries),
            ("decades", self.decades),
            ("years", self.years),
            ("months", self.months),
            ("weeks", self.weeks),
            ("days", self.days),
            ("hours", self.hours),
            ("minutes", self.minutes),
            ("seconds", self.seconds),
            ("microseconds", self.microseconds),
        )
        return " ".join(f"{v} {n}" for n, v in pairs if v is not None)

    def normalize(self):
        """Return an equivalent PGInterval with only months, days and seconds.

        Millennia, centuries, decades and years fold into months; weeks fold
        into days; hours, minutes and microseconds fold into seconds.
        """
        months = 0
        if self.months is not None:
            months += self.months
        if self.years is not None:
            months += self.years * 12
        if self.decades is not None:
            months += self.decades * 10 * 12
        if self.centuries is not None:
            months += self.centuries * 100 * 12
        if self.millennia is not None:
            months += self.millennia * 1000 * 12

        days = 0
        if self.days is not None:
            days += self.days
        if self.weeks is not None:
            days += self.weeks * 7

        seconds = 0
        if self.hours is not None:
            seconds += self.hours * 60 * 60
        if self.minutes is not None:
            seconds += self.minutes * 60
        if self.seconds is not None:
            seconds += self.seconds
        if self.microseconds is not None:
            seconds += self.microseconds / 1000000

        return PGInterval(months=months, days=days, seconds=seconds)

    def __eq__(self, other):
        """Compare by normalized months, days and seconds."""
        if isinstance(other, PGInterval):
            s = self.normalize()
            o = other.normalize()
            return s.months == o.months and s.days == o.days and s.seconds == o.seconds
        else:
            return False
class ArrayState(Enum):
    """States of the _parse_array state machine."""

    InString = 1  # inside a double-quoted element
    InEscape = 2  # just after a backslash inside a quoted element
    InValue = 3  # inside an unquoted element
    Out = 4  # between elements / braces


def _parse_array(data, adapter):
    """Parse array text such as '{1,NULL,{2,3}}' into nested Python lists.

    Every element is passed through *adapter*. An unquoted NULL becomes
    None, but a quoted "NULL" is passed to *adapter* as ordinary text.
    """
    state = ArrayState.Out
    stack = [[]]
    val = []
    for c in data:
        if state == ArrayState.InValue:
            if c in ("}", ","):
                value = "".join(val)
                stack[-1].append(None if value == "NULL" else adapter(value))
                state = ArrayState.Out
            else:
                val.append(c)

        # Deliberately ``if`` and not ``elif``: the '}' or ',' that ended an
        # unquoted value above must also be processed in the Out state.
        if state == ArrayState.Out:
            if c == "{":
                a = []
                stack[-1].append(a)
                stack.append(a)
            elif c == "}":
                stack.pop()
            elif c == ",":
                pass
            elif c == '"':
                val = []
                state = ArrayState.InString
            else:
                val = [c]
                state = ArrayState.InValue

        elif state == ArrayState.InString:
            if c == '"':
                stack[-1].append(adapter("".join(val)))
                state = ArrayState.Out
            elif c == "\\":
                state = ArrayState.InEscape
            else:
                val.append(c)
        elif state == ArrayState.InEscape:
            # The escaped character is taken literally.
            val.append(c)
            state = ArrayState.InString

    return stack[0][0]


def _array_in(adapter):
    """Build an array decoder that applies *adapter* to each element."""

    def f(data):
        return _parse_array(data, adapter)

    return f


bool_array_in = _array_in(bool_in)
bytes_array_in = _array_in(bytes_in)
cidr_array_in = _array_in(cidr_in)
date_array_in = _array_in(date_in)
inet_array_in = _array_in(inet_in)
int_array_in = _array_in(int)
interval_array_in = _array_in(interval_in)
json_array_in = _array_in(json_in)
float_array_in = _array_in(float)
numeric_array_in = _array_in(numeric_in)
string_array_in = _array_in(string_in)
time_array_in = _array_in(time_in)
timestamp_array_in = _array_in(timestamp_in)
timestamptz_array_in = _array_in(timestamptz_in)
uuid_array_in = _array_in(uuid_in)


def array_string_escape(v):
    """Escape and, if needed, double-quote one array element's text.

    Backslashes and double quotes are backslash-escaped. The element is
    quoted when it is empty, is the word NULL in any letter case (which the
    server would otherwise read as a null), or contains braces, commas,
    backslashes or any whitespace.
    """
    cs = []
    for c in v:
        if c in ("\\", '"'):
            cs.append("\\")
        cs.append(c)
    val = "".join(cs)
    if (
        len(val) == 0
        or val.upper() == "NULL"
        or any(c in val for c in ("{", "}", ",", "\\"))
        or any(c.isspace() for c in val)
    ):
        val = f'"{val}"'
    return val


def array_out(ar):
    """Encode a (possibly nested) list/tuple as array literal text."""
    result = []
    for v in ar:

        if isinstance(v, (list, tuple)):
            val = array_out(v)

        elif v is None:
            val = "NULL"

        elif isinstance(v, dict):
            val = array_string_escape(json_out(v))

        elif isinstance(v, (bytes, bytearray)):
            # bytes_out yields \x...; its backslash must itself be escaped.
            val = f'"\\{bytes_out(v)}"'

        elif isinstance(v, str):
            val = array_string_escape(v)

        else:
            # Other scalars use the same converters as top-level parameters.
            # Their text is escaped too, because some of it (e.g. interval
            # text, which contains spaces, or arbitrary enum values) would
            # otherwise be misread inside an array literal.
            val = array_string_escape(make_param(PY_TYPES, v))

        result.append(val)

    return "{" + ",".join(result) + "}"


# Python type -> type code used to describe a parameter of that type.
PY_PG = {
    Date: DATE,
    Decimal: NUMERIC,
    IPv4Address: INET,
    IPv6Address: INET,
    IPv4Network: INET,
    IPv6Network: INET,
    PGInterval: INTERVAL,
    Time: TIME,
    Timedelta: INTERVAL,
    UUID: UUID_TYPE,
    bool: BOOLEAN,
    bytearray: BYTES,
    dict: JSONB,
    float: FLOAT,
    type(None): NULLTYPE,
    bytes: BYTES,
    str: TEXT,
}


# Python type -> function converting a value of that type to parameter text.
# Insertion order matters: make_param falls back to the first isinstance match.
PY_TYPES = {
    Date: date_out,  # date
    Datetime: datetime_out,
    Decimal: numeric_out,  # numeric
    Enum: enum_out,  # enum
    IPv4Address: inet_out,  # inet
    IPv6Address: inet_out,  # inet
    IPv4Network: inet_out,  # inet
    IPv6Network: inet_out,  # inet
    PGInterval: pg_interval_out,  # interval (interval_out is timedelta-only)
    Time: time_out,  # time
    Timedelta: interval_out,  # interval
    UUID: uuid_out,  # uuid
    bool: bool_out,  # bool
    bytearray: bytes_out,  # bytea
    dict: json_out,  # jsonb
    float: float_out,  # float8
    type(None): null_out,  # null
    bytes: bytes_out,  # bytea
    str: string_out,  # unknown
    int: int_out,
    list: array_out,
    tuple: array_out,
}


# Type OID -> function converting result text to a Python value.
PG_TYPES = {
    BIGINT: int,  # int8
    BIGINT_ARRAY: int_array_in,  # int8[]
    BOOLEAN: bool_in,  # bool
    BOOLEAN_ARRAY: bool_array_in,  # bool[]
    BYTES: bytes_in,  # bytea
    BYTES_ARRAY: bytes_array_in,  # bytea[]
    CHAR: string_in,  # char
    CHAR_ARRAY: string_array_in,  # char[]
    CIDR: cidr_in,  # cidr
    CIDR_ARRAY: cidr_array_in,  # cidr[]
    CSTRING: string_in,  # cstring
    CSTRING_ARRAY: string_array_in,  # cstring[]
    DATE: date_in,  # date
    DATE_ARRAY: date_array_in,  # date[]
    FLOAT: float,  # float8
    FLOAT_ARRAY: float_array_in,  # float8[]
    INET: inet_in,  # inet
    INET_ARRAY: inet_array_in,  # inet[]
    INTEGER: int,  # int4
    INTEGER_ARRAY: int_array_in,  # int4[]
    JSON: json_in,  # json
    JSON_ARRAY: json_array_in,  # json[]
    JSONB: json_in,  # jsonb
    JSONB_ARRAY: json_array_in,  # jsonb[]
    MACADDR: string_in,  # MACADDR type
    MONEY: string_in,  # money
    MONEY_ARRAY: string_array_in,  # money[]
    NAME: string_in,  # name
    NAME_ARRAY: string_array_in,  # name[]
    NUMERIC: numeric_in,  # numeric
    NUMERIC_ARRAY: numeric_array_in,  # numeric[]
    OID: int,  # oid
    INTERVAL: interval_in,  # interval
    INTERVAL_ARRAY: interval_array_in,  # interval[]
    REAL: float,  # float4
    REAL_ARRAY: float_array_in,  # float4[]
    SMALLINT: int,  # int2
    SMALLINT_ARRAY: int_array_in,  # int2[]
    SMALLINT_VECTOR: vector_in,  # int2vector
    TEXT: string_in,  # text
    TEXT_ARRAY: string_array_in,  # text[]
    TIME: time_in,  # time
    TIME_ARRAY: time_array_in,  # time[]
    TIMESTAMP: timestamp_in,  # timestamp
    TIMESTAMP_ARRAY: timestamp_array_in,  # timestamp[]
    TIMESTAMPTZ: timestamptz_in,  # timestamptz
    TIMESTAMPTZ_ARRAY: timestamptz_array_in,  # timestamptz[]
    UNKNOWN: string_in,  # unknown
    UUID_ARRAY: uuid_array_in,  # uuid[]
    UUID_TYPE: uuid_in,  # uuid
    VARCHAR: string_in,  # varchar
    VARCHAR_ARRAY: string_array_in,  # varchar[]
    XID: int,  # xid
}


# PostgreSQL encodings:
# https://www.postgresql.org/docs/current/multibyte.html
#
# Python encodings:
# https://docs.python.org/3/library/codecs.html
#
# Commented-out encodings need no renaming: the PostgreSQL name is also a
# valid Python codec name. If the Python side is None, Python has no codec
# for that encoding, so it isn't supported.
PG_PY_ENCODINGS = {
    # Not supported:
    "mule_internal": None,
    "euc_tw": None,
    # Name fine as-is:
    # "euc_jp",
    # "euc_jis_2004",
    # "euc_kr",
    # "gb18030",
    # "gbk",
    # "johab",
    # "sjis",
    # "shift_jis_2004",
    # "uhc",
    # "utf8",
    # Different name:
    "euc_cn": "gb2312",
    "iso_8859_5": "iso8859_5",
    "iso_8859_6": "iso8859_6",
    "iso_8859_7": "iso8859_7",
    "iso_8859_8": "iso8859_8",
    "koi8": "koi8_r",
    "koi8r": "koi8_r",
    "koi8u": "koi8_u",
    "latin1": "iso8859_1",
    "latin2": "iso8859_2",
    "latin3": "iso8859_3",
    "latin4": "iso8859_4",
    "latin5": "iso8859_9",
    "latin6": "iso8859_10",
    "latin7": "iso8859_13",
    "latin8": "iso8859_14",
    "latin9": "iso8859_15",
    "latin10": "iso8859_16",
    "sql_ascii": "ascii",
    "win866": "cp866",
    "win874": "cp874",
    "win1250": "cp1250",
    "win1251": "cp1251",
    "win1252": "cp1252",
    "win1253": "cp1253",
    "win1254": "cp1254",
    "win1255": "cp1255",
    "win1256": "cp1256",
    "win1257": "cp1257",
    "win1258": "cp1258",
    "unicode": "utf-8",  # Needed for Amazon Redshift
}


def make_param(py_types, value):
    """Convert *value* to its parameter form using the *py_types* mapping.

    The converter is looked up by exact type first. Otherwise the first key
    in *py_types* (in iteration order) that *value* is an instance of is
    used, and if nothing matches, str() is applied.
    """
    try:
        func = py_types[type(value)]
    except KeyError:
        func = str
        for k, v in py_types.items():
            try:
                if isinstance(value, k):
                    func = v
                    break
            except TypeError:
                # isinstance raises TypeError for keys that are not classes;
                # such keys are skipped.
                pass

    return func(value)


def make_params(py_types, values):
    """Apply make_param to each value, returning the results as a tuple."""
    return tuple(make_param(py_types, v) for v in values)