import codecs
import socket
import struct
from collections import defaultdict, deque
from hashlib import md5
from io import TextIOBase
from itertools import count
from struct import Struct

import scramp

from pg8000.converters import (
    PG_PY_ENCODINGS,
    PG_TYPES,
    PY_TYPES,
    make_params,
    string_in,
)
from pg8000.exceptions import DatabaseError, InterfaceError


def pack_funcs(fmt):
    """Return the ``(pack, unpack_from)`` pair for the struct format *fmt*.

    The format is prefixed with ``!`` (network byte order, standard sizes),
    which is the encoding used by every integer in the PostgreSQL protocol.
    The ``Struct`` is compiled once so the returned callables are cheap.
    """
    struc = Struct(f"!{fmt}")
    return struc.pack, struc.unpack_from


i_pack, i_unpack = pack_funcs("i")
h_pack, h_unpack = pack_funcs("h")
ii_pack, ii_unpack = pack_funcs("ii")
ihihih_pack, ihihih_unpack = pack_funcs("ihihih")
ci_pack, ci_unpack = pack_funcs("ci")
bh_pack, bh_unpack = pack_funcs("bh")
cccc_pack, cccc_unpack = pack_funcs("cccc")


# Copyright (c) 2007-2009, Mathieu Fenniak
# Copyright (c) The Contributors
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
# * Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# * The name of the author may not be used to endorse or promote products
# derived from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.

__author__ = "Mathieu Fenniak"


NULL_BYTE = b"\x00"


# Message type codes received from the server (backend -> frontend).
# CopyData and CopyDone are also sent by the client during COPY IN.
NOTICE_RESPONSE = b"N"
AUTHENTICATION_REQUEST = b"R"
PARAMETER_STATUS = b"S"
BACKEND_KEY_DATA = b"K"
READY_FOR_QUERY = b"Z"
ROW_DESCRIPTION = b"T"
ERROR_RESPONSE = b"E"
DATA_ROW = b"D"
COMMAND_COMPLETE = b"C"
PARSE_COMPLETE = b"1"
BIND_COMPLETE = b"2"
CLOSE_COMPLETE = b"3"
PORTAL_SUSPENDED = b"s"
NO_DATA = b"n"
PARAMETER_DESCRIPTION = b"t"
NOTIFICATION_RESPONSE = b"A"
COPY_DONE = b"c"
COPY_DATA = b"d"
COPY_IN_RESPONSE = b"G"
COPY_OUT_RESPONSE = b"H"
EMPTY_QUERY_RESPONSE = b"I"

# Message type codes sent by the client (frontend -> backend). Some letters
# coincide with backend codes above (e.g. "E", "H", "S"); the direction of
# travel disambiguates them.
BIND = b"B"
PARSE = b"P"
QUERY = b"Q"
EXECUTE = b"E"
FLUSH = b"H"
SYNC = b"S"
PASSWORD = b"p"
DESCRIBE = b"D"
TERMINATE = b"X"
CLOSE = b"C"


def _create_message(code, data=b""):
    """Frame *data* as a protocol message with type byte *code*.

    The Int32 length field counts itself (4 bytes) plus the payload, but not
    the leading type byte.
    """
    return code + i_pack(len(data) + 4) + data


# Pre-built messages that carry no variable payload.
FLUSH_MSG = _create_message(FLUSH)
SYNC_MSG = _create_message(SYNC)
TERMINATE_MSG = _create_message(TERMINATE)
COPY_DONE_MSG = _create_message(COPY_DONE)
# Execute the unnamed portal (empty name) with a row limit of 0 (no limit).
EXECUTE_MSG = _create_message(EXECUTE, NULL_BYTE + i_pack(0))

# Target kinds for Describe / Close messages.
STATEMENT = b"S"
PORTAL = b"P"

# ErrorResponse / NoticeResponse field type codes.
RESPONSE_SEVERITY_LOCALIZED = "S"  # always present; text may be localized
# "V" keeps the historical name RESPONSE_SEVERITY (its previous effective
# value). It is never localized but is only sent by PostgreSQL 9.6+.
RESPONSE_SEVERITY = "V"
RESPONSE_CODE = "C"  # always present (SQLSTATE)
RESPONSE_MSG = "M"  # always present
RESPONSE_DETAIL = "D"
RESPONSE_HINT = "H"
RESPONSE_POSITION = "P"
RESPONSE__POSITION = "p"
RESPONSE__QUERY = "q"
RESPONSE_WHERE = "W"
RESPONSE_FILE = "F"
RESPONSE_LINE = "L"
RESPONSE_ROUTINE = "R"

# Transaction status byte carried by a ReadyForQuery ("Z") message.
IDLE = b"I"
IDLE_IN_TRANSACTION = b"T"
IDLE_IN_FAILED_TRANSACTION = b"E"


class CoreConnection:
    """A connection speaking the PostgreSQL frontend/backend protocol 3.0.

    The constructor opens the socket (TCP or Unix domain), optionally
    negotiates SSL, sends the startup message and processes the server's
    replies until it reports ReadyForQuery or an ErrorResponse.

    Incoming messages are dispatched by their one-byte type code through
    ``self.message_types`` to the ``handle_*`` methods, which record results
    on a :class:`Context` or on the connection. An ErrorResponse is stored in
    ``self.error`` and raised when the current message loop ends.
    """

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

    def __init__(
        self,
        user,
        host="localhost",
        database=None,
        port=5432,
        password=None,
        source_address=None,
        unix_sock=None,
        ssl_context=None,
        timeout=None,
        tcp_keepalive=True,
        application_name=None,
        replication=None,
    ):
        """Open a connection and complete the startup handshake.

        :param user: Role to connect as (``str`` is UTF-8 encoded).
        :param host: Server host for TCP connections; also passed as
            ``server_hostname`` when wrapping the socket for SSL.
        :param database: Database name, omitted from the startup message
            when ``None``.
        :param port: TCP port.
        :param password: Password as ``str`` (UTF-8 encoded) or ``bytes``.
        :param source_address: Passed to :func:`socket.create_connection`.
        :param unix_sock: Path of a Unix-domain socket; when given it is used
            instead of TCP.
        :param ssl_context: ``None`` for no SSL, ``True`` for
            ``ssl.create_default_context()``, or a context used to wrap the
            socket. A context whose ``request_ssl`` attribute is false skips
            the SSLRequest exchange and channel binding.
        :param timeout: Socket timeout in seconds.
        :param tcp_keepalive: Set ``SO_KEEPALIVE`` on the socket.
        :param application_name: Sent as a startup parameter when not None.
        :param replication: Sent as a startup parameter when not None.
        :raises InterfaceError: for invalid parameters, socket/SSL failures
            or unsupported authentication requests.
        :raises DatabaseError: if the server answers with an ErrorResponse.
        """
        self._client_encoding = "utf8"
        self._commands_with_count = (
            b"INSERT",
            b"DELETE",
            b"UPDATE",
            b"MOVE",
            b"FETCH",
            b"COPY",
            b"SELECT",
        )
        self.notifications = deque(maxlen=100)
        self.notices = deque(maxlen=100)
        self.parameter_statuses = deque(maxlen=100)

        if user is None:
            raise InterfaceError("The 'user' connection parameter cannot be None")

        init_params = {
            "user": user,
            "database": database,
            "application_name": application_name,
            "replication": replication,
        }

        # Normalise startup parameters to bytes and drop unset ones.
        for k, v in tuple(init_params.items()):
            if isinstance(v, str):
                init_params[k] = v.encode("utf8")
            elif v is None:
                del init_params[k]
            elif not isinstance(v, (bytes, bytearray)):
                raise InterfaceError(f"The parameter {k} can't be of type {type(v)}.")

        self.user = init_params["user"]

        if isinstance(password, str):
            self.password = password.encode("utf8")
        else:
            self.password = password

        self.autocommit = False
        self._xid = None
        self._statement_nums = set()

        self._caches = {}

        # Assigned up front so the Unix-socket error path can safely test
        # whether a socket was actually created before closing it.
        self._usock = None

        if unix_sock is None and host is not None:
            try:
                self._usock = socket.create_connection(
                    (host, port), timeout, source_address
                )
            except socket.error as e:
                raise InterfaceError(
                    f"Can't create a connection to host {host} and port {port} "
                    f"(timeout is {timeout} and source_address is {source_address})."
                ) from e

        elif unix_sock is not None:
            try:
                if not hasattr(socket, "AF_UNIX"):
                    raise InterfaceError(
                        "attempt to connect to unix socket on unsupported platform"
                    )
                self._usock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
                self._usock.settimeout(timeout)
                self._usock.connect(unix_sock)
            except socket.error as e:
                if self._usock is not None:
                    self._usock.close()
                raise InterfaceError("communication error") from e

        else:
            raise InterfaceError("one of host or unix_sock must be provided")

        if tcp_keepalive:
            self._usock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)

        self.channel_binding = None
        if ssl_context is not None:
            try:
                import ssl

                if ssl_context is True:
                    ssl_context = ssl.create_default_context()

                request_ssl = getattr(ssl_context, "request_ssl", True)

                if request_ssl:
                    # Int32(8) - Message length, including self.
                    # Int32(80877103) - The SSL request code.
                    self._usock.sendall(ii_pack(8, 80877103))
                    resp = self._usock.recv(1)
                    if resp != b"S":
                        # Don't leak the socket when the server declines SSL.
                        self._usock.close()
                        raise InterfaceError("Server refuses SSL")

                self._usock = ssl_context.wrap_socket(self._usock, server_hostname=host)

                if request_ssl:
                    self.channel_binding = scramp.make_channel_binding(
                        "tls-server-end-point", self._usock
                    )

            except ImportError:
                raise InterfaceError(
                    "SSL required but ssl module not available in this python "
                    "installation."
                )

        self._sock = self._usock.makefile(mode="rwb")

        # Socket I/O wrappers: translate OS-level failures into InterfaceError.
        def sock_flush():
            try:
                self._sock.flush()
            except OSError as e:
                raise InterfaceError("network error on flush") from e

        self._flush = sock_flush

        def sock_read(b):
            try:
                return self._sock.read(b)
            except OSError as e:
                raise InterfaceError("network error on read") from e

        self._read = sock_read

        def sock_write(d):
            try:
                self._sock.write(d)
            except OSError as e:
                raise InterfaceError("network error on write") from e

        self._write = sock_write
        self._backend_key_data = None

        # Type OID -> input converter (unknown OIDs fall back to string_in),
        # and Python type -> output converter.
        self.pg_types = defaultdict(lambda: string_in, PG_TYPES)
        self.py_types = dict(PY_TYPES)

        self.message_types = {
            NOTICE_RESPONSE: self.handle_NOTICE_RESPONSE,
            AUTHENTICATION_REQUEST: self.handle_AUTHENTICATION_REQUEST,
            PARAMETER_STATUS: self.handle_PARAMETER_STATUS,
            BACKEND_KEY_DATA: self.handle_BACKEND_KEY_DATA,
            READY_FOR_QUERY: self.handle_READY_FOR_QUERY,
            ROW_DESCRIPTION: self.handle_ROW_DESCRIPTION,
            ERROR_RESPONSE: self.handle_ERROR_RESPONSE,
            EMPTY_QUERY_RESPONSE: self.handle_EMPTY_QUERY_RESPONSE,
            DATA_ROW: self.handle_DATA_ROW,
            COMMAND_COMPLETE: self.handle_COMMAND_COMPLETE,
            PARSE_COMPLETE: self.handle_PARSE_COMPLETE,
            BIND_COMPLETE: self.handle_BIND_COMPLETE,
            CLOSE_COMPLETE: self.handle_CLOSE_COMPLETE,
            PORTAL_SUSPENDED: self.handle_PORTAL_SUSPENDED,
            NO_DATA: self.handle_NO_DATA,
            PARAMETER_DESCRIPTION: self.handle_PARAMETER_DESCRIPTION,
            NOTIFICATION_RESPONSE: self.handle_NOTIFICATION_RESPONSE,
            COPY_DONE: self.handle_COPY_DONE,
            COPY_DATA: self.handle_COPY_DATA,
            COPY_IN_RESPONSE: self.handle_COPY_IN_RESPONSE,
            COPY_OUT_RESPONSE: self.handle_COPY_OUT_RESPONSE,
        }

        # Int32 - Message length, including self.
        # Int32(196608) - Protocol version number. Version 3.0.
        # Any number of key/value pairs, terminated by a zero byte:
        #   String - A parameter name (user, database, or options)
        #   String - Parameter value
        protocol = 196608
        val = bytearray(i_pack(protocol))

        for k, v in init_params.items():
            val.extend(k.encode("ascii") + NULL_BYTE + v + NULL_BYTE)
        val.append(0)
        self._write(i_pack(len(val) + 4))
        self._write(val)
        self._flush()

        code = self.error = None
        while code not in (READY_FOR_QUERY, ERROR_RESPONSE):
            try:
                code, data_len = ci_unpack(self._read(5))
            except struct.error as e:
                # A short read means the server closed the connection.
                raise InterfaceError("network error on read") from e
            self.message_types[code](self._read(data_len - 4), None)
        if self.error is not None:
            raise self.error

        self.in_transaction = False

    def register_out_adapter(self, typ, out_func):
        """Use *out_func* to convert Python values of type *typ* to text."""
        self.py_types[typ] = out_func

    def register_in_adapter(self, oid, in_func):
        """Use *in_func* to convert text values of PostgreSQL type *oid*."""
        self.pg_types[oid] = in_func

    def handle_ERROR_RESPONSE(self, data, context):
        """Store an ErrorResponse as ``self.error`` (a DatabaseError).

        The payload is a sequence of null-terminated fields, each prefixed by
        a one-byte field code; they become a ``{code: text}`` dict.
        """
        msg = dict(
            (
                s[:1].decode("ascii"),
                s[1:].decode(self._client_encoding, errors="replace"),
            )
            for s in data.split(NULL_BYTE)
            if s != b""
        )

        self.error = DatabaseError(msg)

    def handle_EMPTY_QUERY_RESPONSE(self, data, context):
        """An empty query string was executed; report it as an error."""
        self.error = DatabaseError("query was empty")

    def handle_CLOSE_COMPLETE(self, data, context):
        pass

    def handle_PARSE_COMPLETE(self, data, context):
        # Byte1('1') - Identifier.
        # Int32(4) - Message length, including self.
        pass

    def handle_BIND_COMPLETE(self, data, context):
        pass

    def handle_PORTAL_SUSPENDED(self, data, context):
        pass

    def handle_PARAMETER_DESCRIPTION(self, data, context):
        """https://www.postgresql.org/docs/current/protocol-message-formats.html

        The parameter type OIDs (Int16 count followed by Int32 OIDs) are
        currently not used, so the message is ignored.
        """

    def handle_COPY_DONE(self, data, context):
        pass

    def handle_COPY_OUT_RESPONSE(self, data, context):
        """https://www.postgresql.org/docs/current/protocol-message-formats.html

        Choose how following CopyData messages are written to
        ``context.stream``: decoded with the client encoding for text streams,
        written as raw bytes otherwise. Per-column format codes that follow
        the Int8/Int16 header are not inspected.
        """
        is_binary, num_cols = bh_unpack(data)

        if context.stream is None:
            raise InterfaceError(
                "An output stream is required for the COPY OUT response."
            )

        elif isinstance(context.stream, TextIOBase):
            if is_binary:
                raise InterfaceError(
                    "The COPY OUT stream is binary, but the stream parameter is text."
                )
            else:
                decode = codecs.getdecoder(self._client_encoding)

                def w(data):
                    context.stream.write(decode(data)[0])

                context.stream_write = w

        else:
            context.stream_write = context.stream.write

    def handle_COPY_DATA(self, data, context):
        context.stream_write(data)

    def handle_COPY_IN_RESPONSE(self, data, context):
        """https://www.postgresql.org/docs/current/protocol-message-formats.html

        Stream ``context.stream`` to the server as CopyData messages, then
        send CopyDone and Sync. Text streams are read 4096 characters at a
        time and encoded with the client encoding; binary streams are read
        with ``readinto``.
        """
        is_binary, num_cols = bh_unpack(data)

        if context.stream is None:
            raise InterfaceError(
                "An input stream is required for the COPY IN response."
            )

        elif isinstance(context.stream, TextIOBase):
            if is_binary:
                raise InterfaceError(
                    "The COPY IN stream is binary, but the stream parameter is text."
                )

            else:

                def ri(bffr):
                    bffr.clear()
                    bffr.extend(context.stream.read(4096).encode(self._client_encoding))
                    return len(bffr)

                readinto = ri
        else:
            readinto = context.stream.readinto

        bffr = bytearray(8192)
        while True:
            bytes_read = readinto(bffr)
            if bytes_read == 0:
                break
            self._write(COPY_DATA)
            self._write(i_pack(bytes_read + 4))
            self._write(bffr[:bytes_read])
            self._flush()

        # Send CopyDone
        self._write(COPY_DONE_MSG)
        self._write(SYNC_MSG)
        self._flush()

    def handle_NOTIFICATION_RESPONSE(self, data, context):
        """https://www.postgresql.org/docs/current/protocol-message-formats.html

        Append ``(backend_pid, channel, payload)`` to ``self.notifications``.
        Channel and payload are decoded with the client encoding so non-ASCII
        LISTEN/NOTIFY text does not raise.
        """
        backend_pid = i_unpack(data)[0]
        idx = 4
        null_idx = data.find(NULL_BYTE, idx)
        channel = data[idx:null_idx].decode(self._client_encoding)
        payload = data[null_idx + 1 : -1].decode(self._client_encoding)

        self.notifications.append((backend_pid, channel, payload))

    def close(self):
        """Closes the database connection.

        This function is part of the `DBAPI 2.0 specification
        <http://www.python.org/dev/peps/pep-0249/>`_.
        """
        try:
            self._write(TERMINATE_MSG)
            self._flush()
            self._sock.close()
        except AttributeError:
            # self._sock is None: the connection was already closed.
            raise InterfaceError("connection is closed")
        except ValueError:
            raise InterfaceError("connection is closed")
        except socket.error:
            pass
        finally:
            self._usock.close()
            self._sock = None

    def handle_AUTHENTICATION_REQUEST(self, data, context):
        """https://www.postgresql.org/docs/current/protocol-message-formats.html

        Answer an Authentication* request. Codes handled: 0 (ok),
        3 (cleartext), 5 (MD5), 10/11/12 (SASL start/continue/final).
        Codes 2, 4, 6, 7, 8 and 9 are recognised but unsupported.
        """
        auth_code = i_unpack(data)[0]
        if auth_code == 0:
            pass
        elif auth_code == 3:
            if self.password is None:
                raise InterfaceError(
                    "server requesting password authentication, but no password was "
                    "provided"
                )
            self._send_message(PASSWORD, self.password + NULL_BYTE)
            self._flush()

        elif auth_code == 5:
            salt = b"".join(cccc_unpack(data, 4))
            if self.password is None:
                raise InterfaceError(
                    "server requesting MD5 password authentication, but no password "
                    "was provided"
                )
            pwd = b"md5" + md5(
                md5(self.password + self.user).hexdigest().encode("ascii") + salt
            ).hexdigest().encode("ascii")
            self._send_message(PASSWORD, pwd + NULL_BYTE)
            self._flush()

        elif auth_code == 10:
            # AuthenticationSASL
            if self.password is None:
                raise InterfaceError(
                    "server requesting SASL authentication, but no password was "
                    "provided"
                )
            mechanisms = [m.decode("ascii") for m in data[4:-2].split(NULL_BYTE)]

            self.auth = scramp.ScramClient(
                mechanisms,
                self.user.decode("utf8"),
                self.password.decode("utf8"),
                channel_binding=self.channel_binding,
            )

            init = self.auth.get_client_first().encode("utf8")
            mech = self.auth.mechanism_name.encode("ascii") + NULL_BYTE

            # SASLInitialResponse
            self._send_message(PASSWORD, mech + i_pack(len(init)) + init)
            self._flush()

        elif auth_code == 11:
            # AuthenticationSASLContinue
            self.auth.set_server_first(data[4:].decode("utf8"))

            # SASLResponse
            msg = self.auth.get_client_final().encode("utf8")
            self._send_message(PASSWORD, msg)
            self._flush()

        elif auth_code == 12:
            # AuthenticationSASLFinal
            self.auth.set_server_final(data[4:].decode("utf8"))

        elif auth_code in (2, 4, 6, 7, 8, 9):
            raise InterfaceError(
                f"Authentication method {auth_code} not supported by pg8000."
            )
        else:
            raise InterfaceError(
                f"Authentication method {auth_code} not recognized by pg8000."
            )

    def handle_READY_FOR_QUERY(self, data, context):
        """Track whether the server reports an open (or failed) transaction."""
        self.in_transaction = data != IDLE

    def handle_BACKEND_KEY_DATA(self, data, context):
        self._backend_key_data = data

    def handle_ROW_DESCRIPTION(self, data, context):
        """Parse a RowDescription into ``context.columns``/``input_funcs``.

        Each field is a null-terminated name followed by 18 bytes:
        Int32 table OID, Int16 attnum, Int32 type OID, Int16 type size,
        Int32 type modifier, Int16 format code.
        """
        count = h_unpack(data)[0]
        idx = 2
        columns = []
        input_funcs = []
        for i in range(count):
            name = data[idx : data.find(NULL_BYTE, idx)]
            idx += len(name) + 1
            field = dict(
                zip(
                    (
                        "table_oid",
                        "column_attrnum",
                        "type_oid",
                        "type_size",
                        "type_modifier",
                        "format",
                    ),
                    ihihih_unpack(data, idx),
                )
            )
            field["name"] = name.decode(self._client_encoding)
            idx += 18
            columns.append(field)
            input_funcs.append(self.pg_types[field["type_oid"]])

        context.columns = columns
        context.input_funcs = input_funcs
        if context.rows is None:
            context.rows = []

    def send_PARSE(self, statement_name_bin, statement, oids=()):
        """Queue a Parse message (plus Flush) for *statement*.

        An oid of -1 is sent as 0, which leaves the parameter type unspecified.
        """
        val = bytearray(statement_name_bin)
        val.extend(statement.encode(self._client_encoding) + NULL_BYTE)
        val.extend(h_pack(len(oids)))
        for oid in oids:
            val.extend(i_pack(0 if oid == -1 else oid))

        self._send_message(PARSE, val)
        self._write(FLUSH_MSG)

    def send_DESCRIBE_STATEMENT(self, statement_name_bin):
        self._send_message(DESCRIBE, STATEMENT + statement_name_bin)
        self._write(FLUSH_MSG)

    def send_QUERY(self, sql):
        self._send_message(QUERY, sql.encode(self._client_encoding) + NULL_BYTE)

    def execute_simple(self, statement):
        """Run *statement* with the simple query protocol; return its Context."""
        context = Context()

        self.send_QUERY(statement)
        self._flush()
        self.handle_messages(context)

        return context

    def execute_unnamed(self, statement, vals=(), oids=(), stream=None):
        """Parse, describe, bind and execute *statement* as the unnamed statement.

        *vals* are converted to text with ``self.py_types``; *stream* is used
        for COPY. Returns the populated Context.
        """
        context = Context(stream=stream)

        self.send_PARSE(NULL_BYTE, statement, oids)
        self._write(SYNC_MSG)
        self._flush()
        self.handle_messages(context)
        self.send_DESCRIBE_STATEMENT(NULL_BYTE)

        self._write(SYNC_MSG)

        try:
            self._flush()
        except AttributeError as e:
            if self._sock is None:
                raise InterfaceError("connection is closed")
            else:
                raise e
        params = make_params(self.py_types, vals)
        self.send_BIND(NULL_BYTE, params)
        self.handle_messages(context)
        self.send_EXECUTE()

        self._write(SYNC_MSG)
        self._flush()
        self.handle_messages(context)

        return context

    def prepare_statement(self, statement, oids=None):
        """Create a named prepared statement for *statement*.

        :param oids: Parameter type OIDs, or ``None`` to leave all parameter
            types for the server to infer.
        :returns: ``(statement_name_bin, columns, input_funcs)``.
        """
        if oids is None:
            # send_PARSE needs a sized sequence; None previously crashed there.
            oids = ()

        # Pick the lowest statement number not currently in use.
        for i in count():
            statement_name = f"pg8000_statement_{i}"
            statement_name_bin = statement_name.encode("ascii") + NULL_BYTE
            if statement_name_bin not in self._statement_nums:
                self._statement_nums.add(statement_name_bin)
                break

        self.send_PARSE(statement_name_bin, statement, oids)
        self.send_DESCRIBE_STATEMENT(statement_name_bin)
        self._write(SYNC_MSG)

        try:
            self._flush()
        except AttributeError as e:
            if self._sock is None:
                raise InterfaceError("connection is closed")
            else:
                raise e

        context = Context()
        self.handle_messages(context)

        return statement_name_bin, context.columns, context.input_funcs

    def execute_named(self, statement_name_bin, params, columns, input_funcs):
        """Bind *params* to a prepared statement, execute it, return a Context."""
        context = Context(columns=columns, input_funcs=input_funcs)

        self.send_BIND(statement_name_bin, params)
        self.send_EXECUTE()
        self._write(SYNC_MSG)
        self._flush()
        self.handle_messages(context)
        return context

    def _send_message(self, code, data):
        """Write a framed message; map closed-socket errors to InterfaceError."""
        try:
            self._write(code)
            self._write(i_pack(len(data) + 4))
            self._write(data)
        except ValueError as e:
            if str(e) == "write to closed file":
                raise InterfaceError("connection is closed")
            else:
                raise e
        except AttributeError:
            raise InterfaceError("connection is closed")

    def send_BIND(self, statement_name_bin, params):
        """https://www.postgresql.org/docs/current/protocol-message-formats.html

        Bind *params* (text values, ``None`` for NULL) to the unnamed portal.
        Zero parameter/result format codes mean everything is sent and
        received as text.
        """
        retval = bytearray(
            NULL_BYTE + statement_name_bin + h_pack(0) + h_pack(len(params))
        )

        for value in params:
            if value is None:
                retval.extend(i_pack(-1))
            else:
                val = value.encode(self._client_encoding)
                retval.extend(i_pack(len(val)))
                retval.extend(val)
        retval.extend(h_pack(0))

        self._send_message(BIND, retval)
        self._write(FLUSH_MSG)

    def send_EXECUTE(self):
        """https://www.postgresql.org/docs/current/protocol-message-formats.html"""
        self._write(EXECUTE_MSG)
        self._write(FLUSH_MSG)

    def handle_NO_DATA(self, msg, context):
        pass

    def handle_COMMAND_COMPLETE(self, data, context):
        """Add the row count from a command tag (e.g. ``INSERT 0 5``).

        Tags whose last word is not an integer (e.g. ``BEGIN``) are ignored.
        """
        values = data[:-1].split(b" ")
        try:
            row_count = int(values[-1])
            if context.row_count == -1:
                context.row_count = row_count
            else:
                context.row_count += row_count
        except ValueError:
            pass

    def handle_DATA_ROW(self, data, context):
        """Decode a DataRow with ``context.input_funcs`` and append it."""
        idx = 2  # skip the Int16 column count
        row = []
        for func in context.input_funcs:
            vlen = i_unpack(data, idx)[0]
            idx += 4
            if vlen == -1:
                v = None
            else:
                v = func(str(data[idx : idx + vlen], encoding=self._client_encoding))
                idx += vlen
            row.append(v)
        context.rows.append(row)

    def handle_messages(self, context):
        """Dispatch server messages until ReadyForQuery, then raise any error."""
        code = self.error = None

        while code != READY_FOR_QUERY:

            try:
                code, data_len = ci_unpack(self._read(5))
            except struct.error as e:
                raise InterfaceError("network error on read") from e

            self.message_types[code](self._read(data_len - 4), context)

        if self.error is not None:
            raise self.error

    def close_prepared_statement(self, statement_name_bin):
        """https://www.postgresql.org/docs/current/protocol-message-formats.html"""
        self._send_message(CLOSE, STATEMENT + statement_name_bin)
        self._write(FLUSH_MSG)
        self._write(SYNC_MSG)
        self._flush()
        context = Context()
        self.handle_messages(context)
        self._statement_nums.remove(statement_name_bin)

    def handle_NOTICE_RESPONSE(self, data, context):
        """https://www.postgresql.org/docs/current/protocol-message-formats.html

        Append the notice fields as a ``{code_byte: value_bytes}`` dict. Empty
        segments from the terminators are skipped, as in ErrorResponse.
        """
        self.notices.append(
            {s[0:1]: s[1:] for s in data.split(NULL_BYTE) if s != b""}
        )

    def handle_PARAMETER_STATUS(self, data, context):
        """Record a ParameterStatus ``(key, value)`` pair of bytes.

        A ``client_encoding`` report also switches the Python codec used for
        text on this connection.
        """
        pos = data.find(NULL_BYTE)
        key, value = data[:pos], data[pos + 1 : -1]
        self.parameter_statuses.append((key, value))
        if key == b"client_encoding":
            encoding = value.decode("ascii").lower()
            self._client_encoding = PG_PY_ENCODINGS.get(encoding, encoding)


class Context:
    """Per-request state populated by the ``CoreConnection.handle_*`` methods.

    :param stream: Stream read by COPY IN or written by COPY OUT.
    :param columns: Column descriptions from an earlier RowDescription. When
        given, ``rows`` starts as an empty list so DataRow messages can be
        collected without a fresh RowDescription.
    :param input_funcs: Per-column converters applied to DataRow values.
    """

    def __init__(self, stream=None, columns=None, input_funcs=None):
        self.rows = None if columns is None else []
        # -1 until a CommandComplete tag carrying a row count is received.
        self.row_count = -1
        self.columns = columns
        self.stream = stream
        self.input_funcs = [] if input_funcs is None else input_funcs