1import codecs
2import socket
3import struct
4from collections import defaultdict, deque
5from hashlib import md5
6from io import TextIOBase
7from itertools import count
8from struct import Struct
9
10import scramp
11
12from pg8000.converters import (
13    PG_PY_ENCODINGS,
14    PG_TYPES,
15    PY_TYPES,
16    make_params,
17    string_in,
18)
19from pg8000.exceptions import DatabaseError, InterfaceError
20
21
def pack_funcs(fmt):
    """Build network-order (big-endian) packing helpers for a struct format.

    Returns a ``(pack, unpack_from)`` pair bound to a single pre-compiled
    ``Struct`` so the format string is parsed only once.
    """
    compiled = Struct("!{}".format(fmt))
    pack, unpack_from = compiled.pack, compiled.unpack_from
    return pack, unpack_from
25
26
# Network-byte-order pack / unpack_from helpers, named after their format codes.
i_pack, i_unpack = pack_funcs("i")  # Int32
h_pack, h_unpack = pack_funcs("h")  # Int16
ii_pack, ii_unpack = pack_funcs("ii")  # Int32, Int32 (e.g. the SSLRequest)
ihihih_pack, ihihih_unpack = pack_funcs("ihihih")  # RowDescription field trailer
ci_pack, ci_unpack = pack_funcs("ci")  # message type byte + Int32 length
bh_pack, bh_unpack = pack_funcs("bh")  # Int8 + Int16 (COPY response header)
cccc_pack, cccc_unpack = pack_funcs("cccc")  # four single bytes (MD5 salt)
34
35
36# Copyright (c) 2007-2009, Mathieu Fenniak
37# Copyright (c) The Contributors
38# All rights reserved.
39#
40# Redistribution and use in source and binary forms, with or without
41# modification, are permitted provided that the following conditions are
42# met:
43#
44# * Redistributions of source code must retain the above copyright notice,
45# this list of conditions and the following disclaimer.
46# * Redistributions in binary form must reproduce the above copyright notice,
47# this list of conditions and the following disclaimer in the documentation
48# and/or other materials provided with the distribution.
49# * The name of the author may not be used to endorse or promote products
50# derived from this software without specific prior written permission.
51#
52# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
53# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
54# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
55# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
56# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
57# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
58# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
59# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
60# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
61# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
62# POSSIBILITY OF SUCH DAMAGE.
63
__author__ = "Mathieu Fenniak"


NULL_BYTE = b"\x00"  # terminator for every protocol string


# Message codes
# -- backend (server -> client) message type bytes --
NOTICE_RESPONSE = b"N"  # NoticeResponse
AUTHENTICATION_REQUEST = b"R"  # Authentication*
PARAMETER_STATUS = b"S"  # ParameterStatus
BACKEND_KEY_DATA = b"K"  # BackendKeyData
READY_FOR_QUERY = b"Z"  # ReadyForQuery
ROW_DESCRIPTION = b"T"  # RowDescription
ERROR_RESPONSE = b"E"  # ErrorResponse
DATA_ROW = b"D"  # DataRow
COMMAND_COMPLETE = b"C"  # CommandComplete
PARSE_COMPLETE = b"1"  # ParseComplete
BIND_COMPLETE = b"2"  # BindComplete
CLOSE_COMPLETE = b"3"  # CloseComplete
PORTAL_SUSPENDED = b"s"  # PortalSuspended
NO_DATA = b"n"  # NoData
PARAMETER_DESCRIPTION = b"t"  # ParameterDescription
NOTIFICATION_RESPONSE = b"A"  # NotificationResponse
COPY_DONE = b"c"  # CopyDone (also sent by the frontend)
COPY_DATA = b"d"  # CopyData (also sent by the frontend)
COPY_IN_RESPONSE = b"G"  # CopyInResponse
COPY_OUT_RESPONSE = b"H"  # CopyOutResponse
EMPTY_QUERY_RESPONSE = b"I"  # EmptyQueryResponse

# -- frontend (client -> server) message type bytes --
BIND = b"B"  # Bind
PARSE = b"P"  # Parse
QUERY = b"Q"  # Query (simple protocol)
EXECUTE = b"E"  # Execute
FLUSH = b"H"  # Flush
SYNC = b"S"  # Sync
PASSWORD = b"p"  # PasswordMessage / SASL responses
DESCRIBE = b"D"  # Describe
TERMINATE = b"X"  # Terminate
CLOSE = b"C"  # Close
103
104
def _create_message(code, data=b""):
    """Frame ``data`` as a protocol message.

    The result is the one-byte type ``code``, then an Int32 length that
    counts itself (4 bytes) plus the payload, then the payload itself.
    """
    length_field = i_pack(len(data) + 4)
    return b"".join((code, length_field, data))
107
108
# Fixed messages, framed once at import time and reused verbatim.
FLUSH_MSG = _create_message(FLUSH, b"")
SYNC_MSG = _create_message(SYNC, b"")
TERMINATE_MSG = _create_message(TERMINATE, b"")
COPY_DONE_MSG = _create_message(COPY_DONE, b"")
# Execute the unnamed portal (empty name) with a row limit of 0.
EXECUTE_MSG = _create_message(EXECUTE, b"".join((NULL_BYTE, i_pack(0))))
114
# DESCRIBE / CLOSE target kinds
STATEMENT = b"S"
PORTAL = b"P"

# ErrorResponse / NoticeResponse field codes.
# "S" used to be assigned to RESPONSE_SEVERITY and then immediately
# overwritten by "V"; RESPONSE_SEVERITY keeps its effective value "V" for
# backward compatibility and the localized code gets its own name.
RESPONSE_SEVERITY_LOCALIZED = "S"  # always present; text may be translated
RESPONSE_SEVERITY = "V"  # never localized; only sent by PostgreSQL 9.6+
RESPONSE_CODE = "C"  # always present (SQLSTATE)
RESPONSE_MSG = "M"  # always present
RESPONSE_DETAIL = "D"
RESPONSE_HINT = "H"
RESPONSE_POSITION = "P"
RESPONSE__POSITION = "p"  # position within an internally generated query
RESPONSE__QUERY = "q"  # text of the internally generated query
RESPONSE_WHERE = "W"
RESPONSE_SCHEMA_NAME = "s"
RESPONSE_TABLE_NAME = "t"
RESPONSE_COLUMN_NAME = "c"
RESPONSE_DATA_TYPE_NAME = "d"
RESPONSE_CONSTRAINT_NAME = "n"
RESPONSE_FILE = "F"
RESPONSE_LINE = "L"
RESPONSE_ROUTINE = "R"

# ReadyForQuery transaction-status indicators
IDLE = b"I"
IDLE_IN_TRANSACTION = b"T"
IDLE_IN_FAILED_TRANSACTION = b"E"
137
138
class CoreConnection:
    """Low-level PostgreSQL connection speaking frontend/backend protocol 3.0.

    Owns the socket, performs the startup/authentication exchange and routes
    every backend message to a ``handle_*`` method through
    ``self.message_types``. Notifications, notices and parameter statuses are
    kept in deques bounded to their 100 most recent entries.
    """

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

    def __init__(
        self,
        user,
        host="localhost",
        database=None,
        port=5432,
        password=None,
        source_address=None,
        unix_sock=None,
        ssl_context=None,
        timeout=None,
        tcp_keepalive=True,
        application_name=None,
        replication=None,
    ):
        """Open the connection and run the startup/authentication exchange.

        ``unix_sock`` selects a Unix-domain socket; otherwise TCP to
        ``host``/``port`` is used. ``ssl_context`` may be ``True`` (use
        ``ssl.create_default_context()``) or an object providing
        ``wrap_socket``; if it has a false ``request_ssl`` attribute the
        SSLRequest preamble is skipped.

        Raises:
            InterfaceError: invalid parameters, network/TLS failure, or an
                authentication request that cannot be satisfied.
            DatabaseError: the server answered startup with an ErrorResponse.
        """
        self._client_encoding = "utf8"
        self._commands_with_count = (
            b"INSERT",
            b"DELETE",
            b"UPDATE",
            b"MOVE",
            b"FETCH",
            b"COPY",
            b"SELECT",
        )
        self.notifications = deque(maxlen=100)
        self.notices = deque(maxlen=100)
        self.parameter_statuses = deque(maxlen=100)

        if user is None:
            raise InterfaceError("The 'user' connection parameter cannot be None")

        init_params = {
            "user": user,
            "database": database,
            "application_name": application_name,
            "replication": replication,
        }

        # Startup parameters are sent as bytes; unset ones are omitted.
        for k, v in tuple(init_params.items()):
            if isinstance(v, str):
                init_params[k] = v.encode("utf8")
            elif v is None:
                del init_params[k]
            elif not isinstance(v, (bytes, bytearray)):
                raise InterfaceError(f"The parameter {k} can't be of type {type(v)}.")

        self.user = init_params["user"]

        if isinstance(password, str):
            self.password = password.encode("utf8")
        else:
            self.password = password

        self.autocommit = False
        self._xid = None
        self._statement_nums = set()

        self._caches = {}

        # I/O wrappers over the buffered socket file. They look up self._sock
        # at call time, so they can be defined before the socket exists.
        def sock_flush():
            try:
                self._sock.flush()
            except OSError as e:
                raise InterfaceError("network error on flush") from e

        self._flush = sock_flush

        def sock_read(b):
            try:
                return self._sock.read(b)
            except OSError as e:
                raise InterfaceError("network error on read") from e

        self._read = sock_read

        def sock_write(d):
            try:
                self._sock.write(d)
            except OSError as e:
                raise InterfaceError("network error on write") from e

        self._write = sock_write
        self._backend_key_data = None

        self.pg_types = defaultdict(lambda: string_in, PG_TYPES)
        self.py_types = dict(PY_TYPES)

        self.message_types = {
            NOTICE_RESPONSE: self.handle_NOTICE_RESPONSE,
            AUTHENTICATION_REQUEST: self.handle_AUTHENTICATION_REQUEST,
            PARAMETER_STATUS: self.handle_PARAMETER_STATUS,
            BACKEND_KEY_DATA: self.handle_BACKEND_KEY_DATA,
            READY_FOR_QUERY: self.handle_READY_FOR_QUERY,
            ROW_DESCRIPTION: self.handle_ROW_DESCRIPTION,
            ERROR_RESPONSE: self.handle_ERROR_RESPONSE,
            EMPTY_QUERY_RESPONSE: self.handle_EMPTY_QUERY_RESPONSE,
            DATA_ROW: self.handle_DATA_ROW,
            COMMAND_COMPLETE: self.handle_COMMAND_COMPLETE,
            PARSE_COMPLETE: self.handle_PARSE_COMPLETE,
            BIND_COMPLETE: self.handle_BIND_COMPLETE,
            CLOSE_COMPLETE: self.handle_CLOSE_COMPLETE,
            PORTAL_SUSPENDED: self.handle_PORTAL_SUSPENDED,
            NO_DATA: self.handle_NO_DATA,
            PARAMETER_DESCRIPTION: self.handle_PARAMETER_DESCRIPTION,
            NOTIFICATION_RESPONSE: self.handle_NOTIFICATION_RESPONSE,
            COPY_DONE: self.handle_COPY_DONE,
            COPY_DATA: self.handle_COPY_DATA,
            COPY_IN_RESPONSE: self.handle_COPY_IN_RESPONSE,
            COPY_OUT_RESPONSE: self.handle_COPY_OUT_RESPONSE,
        }

        self.channel_binding = None
        self._sock = None
        self._usock = self._open_socket(host, port, source_address, unix_sock, timeout)

        try:
            if tcp_keepalive:
                self._usock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)

            if ssl_context is not None:
                self._negotiate_ssl(ssl_context, host)

            self._sock = self._usock.makefile(mode="rwb")
            self._send_startup(init_params)
        except BaseException:
            # A failed handshake/authentication must not leak the socket.
            self._close_transport_quietly()
            raise

        self.in_transaction = False

    def _open_socket(self, host, port, source_address, unix_sock, timeout):
        """Create and connect the raw TCP or Unix-domain socket and return it."""
        if unix_sock is None and host is not None:
            try:
                return socket.create_connection((host, port), timeout, source_address)
            except socket.error as e:
                raise InterfaceError(
                    f"Can't create a connection to host {host} and port {port} "
                    f"(timeout is {timeout} and source_address is {source_address})."
                ) from e

        if unix_sock is not None:
            if not hasattr(socket, "AF_UNIX"):
                raise InterfaceError(
                    "attempt to connect to unix socket on unsupported platform"
                )
            # Bound before the try so the cleanup below never sees an
            # unassigned name when socket.socket() itself fails.
            usock = None
            try:
                usock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
                usock.settimeout(timeout)
                usock.connect(unix_sock)
            except socket.error as e:
                if usock is not None:
                    usock.close()
                raise InterfaceError("communication error") from e
            return usock

        raise InterfaceError("one of host or unix_sock must be provided")

    def _negotiate_ssl(self, ssl_context, host):
        """Optionally send SSLRequest, then wrap ``self._usock`` with TLS."""
        try:
            import ssl
        except ImportError as e:
            raise InterfaceError(
                "SSL required but ssl module not available in this python "
                "installation."
            ) from e

        if ssl_context is True:
            ssl_context = ssl.create_default_context()

        request_ssl = getattr(ssl_context, "request_ssl", True)

        if request_ssl:
            # Int32(8) - Message length, including self.
            # Int32(80877103) - The SSL request code.
            self._usock.sendall(ii_pack(8, 80877103))
            resp = self._usock.recv(1)
            if resp != b"S":
                raise InterfaceError("Server refuses SSL")

        self._usock = ssl_context.wrap_socket(self._usock, server_hostname=host)

        if request_ssl:
            self.channel_binding = scramp.make_channel_binding(
                "tls-server-end-point", self._usock
            )

    def _send_startup(self, init_params):
        """Send the StartupMessage and process replies until the server is ready.

        Raises the DatabaseError built from an ErrorResponse, if one arrives.
        """
        # Int32 - Message length, including self.
        # Int32(196608) - Protocol version number.  Version 3.0.
        # Any number of key/value pairs, terminated by a zero byte:
        #   String - A parameter name (user, database, or options)
        #   String - Parameter value
        protocol = 196608
        val = bytearray(i_pack(protocol))

        for k, v in init_params.items():
            val.extend(k.encode("ascii") + NULL_BYTE + v + NULL_BYTE)
        val.append(0)
        self._write(i_pack(len(val) + 4))
        self._write(val)
        self._flush()

        # Authentication requests are answered from inside
        # handle_AUTHENTICATION_REQUEST; stop at ReadyForQuery or at the
        # first ErrorResponse.
        code = self.error = None
        while code not in (READY_FOR_QUERY, ERROR_RESPONSE):
            code = self._receive_message(None)
        if self.error is not None:
            raise self.error

    def _close_transport_quietly(self):
        """Best-effort close of the socket file and raw socket after a failed connect."""
        for resource in (self._sock, self._usock):
            if resource is not None:
                try:
                    resource.close()
                except OSError:
                    pass
        self._sock = None

    def _receive_message(self, context):
        """Read one backend message, dispatch it to its handler, return its code.

        Raises:
            InterfaceError: the header could not be read (connection lost) or
                the message type has no registered handler.
        """
        try:
            code, data_len = ci_unpack(self._read(5))
        except struct.error as e:
            raise InterfaceError("network error on read") from e

        data = self._read(data_len - 4)
        handler = self.message_types.get(code)
        if handler is None:
            raise InterfaceError(
                f"Unrecognized message type {code!r} received from the server."
            )
        handler(data, context)
        return code

    def register_out_adapter(self, typ, out_func):
        """Register ``out_func`` as the outgoing adapter for Python type ``typ``.

        Stored in ``self.py_types``, which is passed to ``make_params``.
        """
        self.py_types[typ] = out_func

    def register_in_adapter(self, oid, in_func):
        """Register ``in_func`` as the converter for columns of type ``oid``."""
        self.pg_types[oid] = in_func

    def handle_ERROR_RESPONSE(self, data, context):
        """Store an ErrorResponse as ``self.error`` (raised once ReadyForQuery arrives).

        Fields become a dict keyed by their one-letter code.
        """
        msg = {
            s[:1].decode("ascii"): s[1:].decode(
                self._client_encoding, errors="replace"
            )
            for s in data.split(NULL_BYTE)
            if s != b""
        }

        self.error = DatabaseError(msg)

    def handle_EMPTY_QUERY_RESPONSE(self, data, context):
        self.error = DatabaseError("query was empty")

    def handle_CLOSE_COMPLETE(self, data, context):
        pass

    def handle_PARSE_COMPLETE(self, data, context):
        # Byte1('1') - Identifier.
        # Int32(4) - Message length, including self.
        pass

    def handle_BIND_COMPLETE(self, data, context):
        pass

    def handle_PORTAL_SUSPENDED(self, data, context):
        pass

    def handle_PARAMETER_DESCRIPTION(self, data, context):
        """https://www.postgresql.org/docs/current/protocol-message-formats.html

        The parameter type OIDs (Int16 count followed by Int32 OIDs) are not
        used by this class, so the message body is ignored.
        """

    def handle_COPY_DONE(self, data, context):
        pass

    def handle_COPY_OUT_RESPONSE(self, data, context):
        """https://www.postgresql.org/docs/current/protocol-message-formats.html

        Select ``context.stream_write`` for the CopyData messages that follow.
        Text streams receive data decoded with the client encoding.
        """
        # Per-column format codes follow the header; they are not needed here.
        is_binary, _num_cols = bh_unpack(data)

        if context.stream is None:
            raise InterfaceError(
                "An output stream is required for the COPY OUT response."
            )

        elif isinstance(context.stream, TextIOBase):
            if is_binary:
                raise InterfaceError(
                    "The COPY OUT stream is binary, but the stream parameter is text."
                )
            else:
                decode = codecs.getdecoder(self._client_encoding)

                def w(data):
                    context.stream.write(decode(data)[0])

                context.stream_write = w

        else:
            context.stream_write = context.stream.write

    def handle_COPY_DATA(self, data, context):
        context.stream_write(data)

    def handle_COPY_IN_RESPONSE(self, data, context):
        """https://www.postgresql.org/docs/current/protocol-message-formats.html

        Stream ``context.stream`` to the server as CopyData messages, then send
        CopyDone followed by Sync. Text streams are read 4096 characters at a
        time and encoded with the client encoding.
        """
        # Per-column format codes follow the header; they are not needed here.
        is_binary, _num_cols = bh_unpack(data)

        if context.stream is None:
            raise InterfaceError(
                "An input stream is required for the COPY IN response."
            )

        elif isinstance(context.stream, TextIOBase):
            if is_binary:
                raise InterfaceError(
                    "The COPY IN stream is binary, but the stream parameter is text."
                )

            else:

                def ri(bffr):
                    bffr.clear()
                    bffr.extend(context.stream.read(4096).encode(self._client_encoding))
                    return len(bffr)

                readinto = ri
        else:
            readinto = context.stream.readinto

        bffr = bytearray(8192)
        while True:
            bytes_read = readinto(bffr)
            if bytes_read == 0:
                break
            self._write(COPY_DATA)
            self._write(i_pack(bytes_read + 4))
            self._write(bffr[:bytes_read])
            self._flush()

        # Send CopyDone
        self._write(COPY_DONE_MSG)
        self._write(SYNC_MSG)
        self._flush()

    def handle_NOTIFICATION_RESPONSE(self, data, context):
        """https://www.postgresql.org/docs/current/protocol-message-formats.html

        Append ``(backend_pid, channel, payload)`` to ``self.notifications``.
        Channel and payload are decoded with the client encoding, because
        they are not restricted to ASCII.
        """
        backend_pid = i_unpack(data)[0]
        idx = 4
        null_idx = data.find(NULL_BYTE, idx)
        channel = data[idx:null_idx].decode(self._client_encoding)
        payload = data[null_idx + 1 : -1].decode(self._client_encoding)

        self.notifications.append((backend_pid, channel, payload))

    def close(self):
        """Closes the database connection.

        This function is part of the `DBAPI 2.0 specification
        <http://www.python.org/dev/peps/pep-0249/>`_.

        Sends Terminate and closes the socket. InterfaceError is raised when
        the connection is already closed.
        """
        try:
            self._write(TERMINATE_MSG)
            self._flush()
            self._sock.close()
        except (AttributeError, ValueError):
            # _sock is None (closed earlier) or its file object is closed.
            raise InterfaceError("connection is closed")
        except socket.error:
            pass
        finally:
            self._usock.close()
            self._sock = None

    def handle_AUTHENTICATION_REQUEST(self, data, context):
        """https://www.postgresql.org/docs/current/protocol-message-formats.html

        Answer cleartext (3), MD5 (5) and SASL (10/11/12) requests.

        Raises:
            InterfaceError: a password is needed but none was given, or the
                method is unsupported/unknown.
        """
        auth_code = i_unpack(data)[0]
        if auth_code == 0:
            pass
        elif auth_code == 3:
            if self.password is None:
                raise InterfaceError(
                    "server requesting password authentication, but no password was "
                    "provided"
                )
            self._send_message(PASSWORD, self.password + NULL_BYTE)
            self._flush()

        elif auth_code == 5:
            salt = b"".join(cccc_unpack(data, 4))
            if self.password is None:
                raise InterfaceError(
                    "server requesting MD5 password authentication, but no password "
                    "was provided"
                )
            pwd = b"md5" + md5(
                md5(self.password + self.user).hexdigest().encode("ascii") + salt
            ).hexdigest().encode("ascii")
            self._send_message(PASSWORD, pwd + NULL_BYTE)
            self._flush()

        elif auth_code == 10:
            # AuthenticationSASL
            if self.password is None:
                raise InterfaceError(
                    "server requesting SASL authentication, but no password was "
                    "provided"
                )
            # The mechanism list is a sequence of null-terminated names
            # followed by a final null byte.
            mechanisms = [m.decode("ascii") for m in data[4:-2].split(NULL_BYTE)]

            self.auth = scramp.ScramClient(
                mechanisms,
                self.user.decode("utf8"),
                self.password.decode("utf8"),
                channel_binding=self.channel_binding,
            )

            init = self.auth.get_client_first().encode("utf8")
            mech = self.auth.mechanism_name.encode("ascii") + NULL_BYTE

            # SASLInitialResponse
            self._send_message(PASSWORD, mech + i_pack(len(init)) + init)
            self._flush()

        elif auth_code == 11:
            # AuthenticationSASLContinue
            self.auth.set_server_first(data[4:].decode("utf8"))

            # SASLResponse
            msg = self.auth.get_client_final().encode("utf8")
            self._send_message(PASSWORD, msg)
            self._flush()

        elif auth_code == 12:
            # AuthenticationSASLFinal
            self.auth.set_server_final(data[4:].decode("utf8"))

        elif auth_code in (2, 4, 6, 7, 8, 9):
            raise InterfaceError(
                f"Authentication method {auth_code} not supported by pg8000."
            )
        else:
            raise InterfaceError(
                f"Authentication method {auth_code} not recognized by pg8000."
            )

    def handle_READY_FOR_QUERY(self, data, context):
        # Any status other than idle ('I') means a transaction block is open.
        self.in_transaction = data != IDLE

    def handle_BACKEND_KEY_DATA(self, data, context):
        self._backend_key_data = data

    def handle_ROW_DESCRIPTION(self, data, context):
        """Record column metadata and per-column input converters on ``context``."""
        count = h_unpack(data)[0]
        idx = 2
        columns = []
        input_funcs = []
        for _ in range(count):
            name = data[idx : data.find(NULL_BYTE, idx)]
            idx += len(name) + 1
            field = dict(
                zip(
                    (
                        "table_oid",
                        "column_attrnum",
                        "type_oid",
                        "type_size",
                        "type_modifier",
                        "format",
                    ),
                    ihihih_unpack(data, idx),
                )
            )
            field["name"] = name.decode(self._client_encoding)
            idx += 18  # size of the fixed "ihihih" trailer
            columns.append(field)
            input_funcs.append(self.pg_types[field["type_oid"]])

        context.columns = columns
        context.input_funcs = input_funcs
        if context.rows is None:
            context.rows = []

    def send_PARSE(self, statement_name_bin, statement, oids=()):
        """Buffer a Parse message followed by Flush.

        ``oids`` may be None (treated as empty); -1 entries are sent as 0,
        i.e. "let the server infer the type".
        """
        if oids is None:
            oids = ()

        val = bytearray(statement_name_bin)
        val.extend(statement.encode(self._client_encoding) + NULL_BYTE)
        val.extend(h_pack(len(oids)))
        for oid in oids:
            val.extend(i_pack(0 if oid == -1 else oid))

        self._send_message(PARSE, val)
        self._write(FLUSH_MSG)

    def send_DESCRIBE_STATEMENT(self, statement_name_bin):
        self._send_message(DESCRIBE, STATEMENT + statement_name_bin)
        self._write(FLUSH_MSG)

    def send_QUERY(self, sql):
        self._send_message(QUERY, sql.encode(self._client_encoding) + NULL_BYTE)

    def execute_simple(self, statement):
        """Run ``statement`` through the simple Query protocol; return its Context."""
        context = Context()

        self.send_QUERY(statement)
        self._flush()
        self.handle_messages(context)

        return context

    def execute_unnamed(self, statement, vals=(), oids=(), stream=None):
        """Run ``statement`` via the unnamed statement and portal (extended protocol).

        Round trips: Parse+Sync, Describe+Sync, then Bind+Execute+Sync.
        """
        context = Context(stream=stream)

        # Convert parameters before anything is sent: a conversion error
        # raised between round trips would leave unread replies on the wire.
        params = make_params(self.py_types, vals)

        self.send_PARSE(NULL_BYTE, statement, oids)
        self._write(SYNC_MSG)
        self._flush()
        self.handle_messages(context)

        self.send_DESCRIBE_STATEMENT(NULL_BYTE)
        self._write(SYNC_MSG)

        try:
            self._flush()
        except AttributeError as e:
            if self._sock is None:
                raise InterfaceError("connection is closed")
            else:
                raise e
        self.handle_messages(context)

        # Bind is buffered only after Describe succeeded, so a failure there
        # leaves nothing stale in the write buffer.
        self.send_BIND(NULL_BYTE, params)
        self.send_EXECUTE()
        self._write(SYNC_MSG)
        self._flush()
        self.handle_messages(context)

        return context

    def prepare_statement(self, statement, oids=None):
        """Create a named server-side prepared statement.

        Returns ``(statement_name_bin, columns, input_funcs)``; ``columns`` is
        None when the server describes no result rows.
        """
        for i in count():
            statement_name = f"pg8000_statement_{i}"
            statement_name_bin = statement_name.encode("ascii") + NULL_BYTE
            if statement_name_bin not in self._statement_nums:
                self._statement_nums.add(statement_name_bin)
                break

        context = Context()
        try:
            self.send_PARSE(statement_name_bin, statement, oids)
            self.send_DESCRIBE_STATEMENT(statement_name_bin)
            self._write(SYNC_MSG)

            try:
                self._flush()
            except AttributeError as e:
                if self._sock is None:
                    raise InterfaceError("connection is closed")
                else:
                    raise e

            self.handle_messages(context)
        except DatabaseError:
            # The server rejected the batch, so the statement was not created.
            # Release the reserved name so it can be reused.
            self._statement_nums.discard(statement_name_bin)
            raise

        return statement_name_bin, context.columns, context.input_funcs

    def execute_named(self, statement_name_bin, params, columns, input_funcs):
        """Bind and execute a statement created by ``prepare_statement``."""
        context = Context(columns=columns, input_funcs=input_funcs)

        self.send_BIND(statement_name_bin, params)
        self.send_EXECUTE()
        self._write(SYNC_MSG)
        self._flush()
        self.handle_messages(context)
        return context

    def _send_message(self, code, data):
        """Buffer one framed message; map a closed socket to InterfaceError."""
        try:
            self._write(code)
            self._write(i_pack(len(data) + 4))
            self._write(data)
        except ValueError as e:
            if str(e) == "write to closed file":
                raise InterfaceError("connection is closed")
            else:
                raise e
        except AttributeError:
            raise InterfaceError("connection is closed")

    def send_BIND(self, statement_name_bin, params):
        """https://www.postgresql.org/docs/current/protocol-message-formats.html

        Bind ``params`` (strings, or None for NULL) to the unnamed portal.
        All parameters and results use text format.
        """
        retval = bytearray(
            NULL_BYTE + statement_name_bin + h_pack(0) + h_pack(len(params))
        )

        for value in params:
            if value is None:
                retval.extend(i_pack(-1))
            else:
                val = value.encode(self._client_encoding)
                retval.extend(i_pack(len(val)))
                retval.extend(val)
        retval.extend(h_pack(0))

        self._send_message(BIND, retval)
        self._write(FLUSH_MSG)

    def send_EXECUTE(self):
        """https://www.postgresql.org/docs/current/protocol-message-formats.html"""
        self._write(EXECUTE_MSG)
        self._write(FLUSH_MSG)

    def handle_NO_DATA(self, msg, context):
        pass

    def handle_COMMAND_COMPLETE(self, data, context):
        """Accumulate the row count from a CommandComplete tag.

        The last word of the tag is used when it is an integer
        (e.g. ``b"INSERT 0 5"``); tags without a count are ignored.
        """
        values = data[:-1].split(b" ")
        try:
            row_count = int(values[-1])
            if context.row_count == -1:
                context.row_count = row_count
            else:
                context.row_count += row_count
        except ValueError:
            pass

    def handle_DATA_ROW(self, data, context):
        """Convert one DataRow with ``context.input_funcs`` and append it to ``context.rows``.

        A length of -1 denotes SQL NULL and becomes None.
        """
        idx = 2
        row = []
        for func in context.input_funcs:
            vlen = i_unpack(data, idx)[0]
            idx += 4
            if vlen == -1:
                v = None
            else:
                v = func(str(data[idx : idx + vlen], encoding=self._client_encoding))
                idx += vlen
            row.append(v)
        context.rows.append(row)

    def handle_messages(self, context):
        """Process backend messages until ReadyForQuery, then raise any stored error."""
        code = self.error = None

        while code != READY_FOR_QUERY:
            code = self._receive_message(context)

        if self.error is not None:
            raise self.error

    def close_prepared_statement(self, statement_name_bin):
        """https://www.postgresql.org/docs/current/protocol-message-formats.html"""
        self._send_message(CLOSE, STATEMENT + statement_name_bin)
        self._write(FLUSH_MSG)
        self._write(SYNC_MSG)
        self._flush()
        context = Context()
        self.handle_messages(context)
        self._statement_nums.remove(statement_name_bin)

    def handle_NOTICE_RESPONSE(self, data, context):
        """https://www.postgresql.org/docs/current/protocol-message-formats.html

        Append the notice fields as ``{code_byte: value_bytes}``. Empty
        segments from the terminating null bytes are skipped, consistent
        with handle_ERROR_RESPONSE.
        """
        self.notices.append(
            {s[:1]: s[1:] for s in data.split(NULL_BYTE) if s != b""}
        )

    def handle_PARAMETER_STATUS(self, data, context):
        """Record a ParameterStatus and track changes to ``client_encoding``."""
        pos = data.find(NULL_BYTE)
        key, value = data[:pos], data[pos + 1 : -1]
        self.parameter_statuses.append((key, value))
        if key == b"client_encoding":
            encoding = value.decode("ascii").lower()
            self._client_encoding = PG_PY_ENCODINGS.get(encoding, encoding)
800
801
class Context:
    """Mutable per-request state filled in by the CoreConnection handlers.

    Args:
        stream: file-like object used for COPY IN / COPY OUT, or None.
        columns: column descriptions already known (named statements), or
            None; when given, ``rows`` starts as an empty list.
        input_funcs: per-column converters applied to DataRow values;
            defaults to a fresh empty list.
    """

    def __init__(self, stream=None, columns=None, input_funcs=None):
        self.rows = [] if columns is not None else None
        self.row_count = -1  # -1 until a CommandComplete reports a count
        self.columns = columns
        self.stream = stream
        if input_funcs is None:
            input_funcs = []
        self.input_funcs = input_funcs
809