1from datetime import (
2    timedelta as Timedelta, datetime as Datetime, tzinfo, date, time)
3from warnings import warn
4import socket
5from struct import pack
6from hashlib import md5
7from decimal import Decimal
8from collections import deque, defaultdict
9from itertools import count, islice
10from six.moves import map
11from six import (
12    b, PY2, integer_types, next, text_type, u, binary_type, itervalues,
13    iteritems)
14from uuid import UUID
15from copy import deepcopy
16from calendar import timegm
17from distutils.version import LooseVersion
18from struct import Struct
19from time import localtime
20import pg8000
21from json import loads, dumps
22from os import getpid
23
24
25# Copyright (c) 2007-2009, Mathieu Fenniak
26# Copyright (c) The Contributors
27# All rights reserved.
28#
29# Redistribution and use in source and binary forms, with or without
30# modification, are permitted provided that the following conditions are
31# met:
32#
33# * Redistributions of source code must retain the above copyright notice,
34# this list of conditions and the following disclaimer.
35# * Redistributions in binary form must reproduce the above copyright notice,
36# this list of conditions and the following disclaimer in the documentation
37# and/or other materials provided with the distribution.
38# * The name of the author may not be used to endorse or promote products
39# derived from this software without specific prior written permission.
40#
41# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
42# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
43# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
44# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
45# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
46# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
47# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
48# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
49# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
50# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
51# POSSIBILITY OF SUCH DAMAGE.
52
53__author__ = "Mathieu Fenniak"
54
55
56ZERO = Timedelta(0)
57
58
59class UTC(tzinfo):
60
61    def utcoffset(self, dt):
62        return ZERO
63
64    def tzname(self, dt):
65        return "UTC"
66
67    def dst(self, dt):
68        return ZERO
69
70
71utc = UTC()
72
73
74class Interval(object):
75    """An Interval represents a measurement of time.  In PostgreSQL, an
76    interval is defined in the measure of months, days, and microseconds; as
77    such, the pg8000 interval type represents the same information.
78
79    Note that values of the :attr:`microseconds`, :attr:`days` and
80    :attr:`months` properties are independently measured and cannot be
81    converted to each other.  A month may be 28, 29, 30, or 31 days, and a day
82    may occasionally be lengthened slightly by a leap second.
83
84    .. attribute:: microseconds
85
86        Measure of microseconds in the interval.
87
88        The microseconds value is constrained to fit into a signed 64-bit
89        integer.  Any attempt to set a value too large or too small will result
90        in an OverflowError being raised.
91
92    .. attribute:: days
93
94        Measure of days in the interval.
95
96        The days value is constrained to fit into a signed 32-bit integer.
97        Any attempt to set a value too large or too small will result in an
98        OverflowError being raised.
99
100    .. attribute:: months
101
102        Measure of months in the interval.
103
104        The months value is constrained to fit into a signed 32-bit integer.
105        Any attempt to set a value too large or too small will result in an
106        OverflowError being raised.
107    """
108
109    def __init__(self, microseconds=0, days=0, months=0):
110        self.microseconds = microseconds
111        self.days = days
112        self.months = months
113
114    def _setMicroseconds(self, value):
115        if not isinstance(value, integer_types):
116            raise TypeError("microseconds must be an integer type")
117        elif not (min_int8 < value < max_int8):
118            raise OverflowError(
119                "microseconds must be representable as a 64-bit integer")
120        else:
121            self._microseconds = value
122
123    def _setDays(self, value):
124        if not isinstance(value, integer_types):
125            raise TypeError("days must be an integer type")
126        elif not (min_int4 < value < max_int4):
127            raise OverflowError(
128                "days must be representable as a 32-bit integer")
129        else:
130            self._days = value
131
132    def _setMonths(self, value):
133        if not isinstance(value, integer_types):
134            raise TypeError("months must be an integer type")
135        elif not (min_int4 < value < max_int4):
136            raise OverflowError(
137                "months must be representable as a 32-bit integer")
138        else:
139            self._months = value
140
141    microseconds = property(lambda self: self._microseconds, _setMicroseconds)
142    days = property(lambda self: self._days, _setDays)
143    months = property(lambda self: self._months, _setMonths)
144
145    def __repr__(self):
146        return "<Interval %s months %s days %s microseconds>" % (
147            self.months, self.days, self.microseconds)
148
149    def __eq__(self, other):
150        return other is not None and isinstance(other, Interval) and \
151            self.months == other.months and self.days == other.days and \
152            self.microseconds == other.microseconds
153
154    def __neq__(self, other):
155        return not self.__eq__(other)
156
157
158class PGType(object):
159    def __init__(self, value):
160        self.value = value
161
162    def encode(self, encoding):
163        return str(self.value).encode(encoding)
164
165
166class PGEnum(PGType):
167    def __init__(self, value):
168        if isinstance(value, str):
169            self.value = value
170        else:
171            self.value = value.value
172
173
174class PGJson(PGType):
175    def encode(self, encoding):
176        return dumps(self.value).encode(encoding)
177
178
179class PGJsonb(PGType):
180    def encode(self, encoding):
181        return dumps(self.value).encode(encoding)
182
183
184class PGTsvector(PGType):
185    pass
186
187
188class PGVarchar(str):
189    pass
190
191
192class PGText(str):
193    pass
194
195
196def pack_funcs(fmt):
197    struc = Struct('!' + fmt)
198    return struc.pack, struc.unpack_from
199
200
201i_pack, i_unpack = pack_funcs('i')
202h_pack, h_unpack = pack_funcs('h')
203q_pack, q_unpack = pack_funcs('q')
204d_pack, d_unpack = pack_funcs('d')
205f_pack, f_unpack = pack_funcs('f')
206iii_pack, iii_unpack = pack_funcs('iii')
207ii_pack, ii_unpack = pack_funcs('ii')
208qii_pack, qii_unpack = pack_funcs('qii')
209dii_pack, dii_unpack = pack_funcs('dii')
210ihihih_pack, ihihih_unpack = pack_funcs('ihihih')
211ci_pack, ci_unpack = pack_funcs('ci')
212bh_pack, bh_unpack = pack_funcs('bh')
213cccc_pack, cccc_unpack = pack_funcs('cccc')
214
215
216min_int2, max_int2 = -2 ** 15, 2 ** 15
217min_int4, max_int4 = -2 ** 31, 2 ** 31
218min_int8, max_int8 = -2 ** 63, 2 ** 63
219
220
221class Warning(Exception):
222    """Generic exception raised for important database warnings like data
223    truncations.  This exception is not currently used by pg8000.
224
225    This exception is part of the `DBAPI 2.0 specification
226    <http://www.python.org/dev/peps/pep-0249/>`_.
227    """
228    pass
229
230
231class Error(Exception):
232    """Generic exception that is the base exception of all other error
233    exceptions.
234
235    This exception is part of the `DBAPI 2.0 specification
236    <http://www.python.org/dev/peps/pep-0249/>`_.
237    """
238    pass
239
240
241class InterfaceError(Error):
242    """Generic exception raised for errors that are related to the database
243    interface rather than the database itself.  For example, if the interface
244    attempts to use an SSL connection but the server refuses, an InterfaceError
245    will be raised.
246
247    This exception is part of the `DBAPI 2.0 specification
248    <http://www.python.org/dev/peps/pep-0249/>`_.
249    """
250    pass
251
252
253class DatabaseError(Error):
254    """Generic exception raised for errors that are related to the database.
255    This exception is currently never raised by pg8000.
256
257    This exception is part of the `DBAPI 2.0 specification
258    <http://www.python.org/dev/peps/pep-0249/>`_.
259    """
260    pass
261
262
263class DataError(DatabaseError):
264    """Generic exception raised for errors that are due to problems with the
265    processed data.  This exception is not currently raised by pg8000.
266
267    This exception is part of the `DBAPI 2.0 specification
268    <http://www.python.org/dev/peps/pep-0249/>`_.
269    """
270    pass
271
272
273class OperationalError(DatabaseError):
274    """
275    Generic exception raised for errors that are related to the database's
276    operation and not necessarily under the control of the programmer. This
277    exception is currently never raised by pg8000.
278
279    This exception is part of the `DBAPI 2.0 specification
280    <http://www.python.org/dev/peps/pep-0249/>`_.
281    """
282    pass
283
284
285class IntegrityError(DatabaseError):
286    """
287    Generic exception raised when the relational integrity of the database is
288    affected.  This exception is not currently raised by pg8000.
289
290    This exception is part of the `DBAPI 2.0 specification
291    <http://www.python.org/dev/peps/pep-0249/>`_.
292    """
293    pass
294
295
296class InternalError(DatabaseError):
297    """Generic exception raised when the database encounters an internal error.
298    This is currently only raised when unexpected state occurs in the pg8000
299    interface itself, and is typically the result of a interface bug.
300
301    This exception is part of the `DBAPI 2.0 specification
302    <http://www.python.org/dev/peps/pep-0249/>`_.
303    """
304    pass
305
306
307class ProgrammingError(DatabaseError):
308    """Generic exception raised for programming errors.  For example, this
309    exception is raised if more parameter fields are in a query string than
310    there are available parameters.
311
312    This exception is part of the `DBAPI 2.0 specification
313    <http://www.python.org/dev/peps/pep-0249/>`_.
314    """
315    pass
316
317
318class NotSupportedError(DatabaseError):
319    """Generic exception raised in case a method or database API was used which
320    is not supported by the database.
321
322    This exception is part of the `DBAPI 2.0 specification
323    <http://www.python.org/dev/peps/pep-0249/>`_.
324    """
325    pass
326
327
328class ArrayContentNotSupportedError(NotSupportedError):
329    """
330    Raised when attempting to transmit an array where the base type is not
331    supported for binary data transfer by the interface.
332    """
333    pass
334
335
336class ArrayContentNotHomogenousError(ProgrammingError):
337    """
338    Raised when attempting to transmit an array that doesn't contain only a
339    single type of object.
340    """
341    pass
342
343
344class ArrayDimensionsNotConsistentError(ProgrammingError):
345    """
346    Raised when attempting to transmit an array that has inconsistent
347    multi-dimension sizes.
348    """
349    pass
350
351
352class Bytea(binary_type):
353    """Bytea is a str-derived class that is mapped to a PostgreSQL byte array.
354    This class is only used in Python 2, the built-in ``bytes`` type is used in
355    Python 3.
356    """
357    pass
358
359
360def Date(year, month, day):
361    """Constuct an object holding a date value.
362
363    This function is part of the `DBAPI 2.0 specification
364    <http://www.python.org/dev/peps/pep-0249/>`_.
365
366    :rtype: :class:`datetime.date`
367    """
368    return date(year, month, day)
369
370
371def Time(hour, minute, second):
372    """Construct an object holding a time value.
373
374    This function is part of the `DBAPI 2.0 specification
375    <http://www.python.org/dev/peps/pep-0249/>`_.
376
377    :rtype: :class:`datetime.time`
378    """
379    return time(hour, minute, second)
380
381
382def Timestamp(year, month, day, hour, minute, second):
383    """Construct an object holding a timestamp value.
384
385    This function is part of the `DBAPI 2.0 specification
386    <http://www.python.org/dev/peps/pep-0249/>`_.
387
388    :rtype: :class:`datetime.datetime`
389    """
390    return Datetime(year, month, day, hour, minute, second)
391
392
393def DateFromTicks(ticks):
394    """Construct an object holding a date value from the given ticks value
395    (number of seconds since the epoch).
396
397    This function is part of the `DBAPI 2.0 specification
398    <http://www.python.org/dev/peps/pep-0249/>`_.
399
400    :rtype: :class:`datetime.date`
401    """
402    return Date(*localtime(ticks)[:3])
403
404
405def TimeFromTicks(ticks):
406    """Construct an objet holding a time value from the given ticks value
407    (number of seconds since the epoch).
408
409    This function is part of the `DBAPI 2.0 specification
410    <http://www.python.org/dev/peps/pep-0249/>`_.
411
412    :rtype: :class:`datetime.time`
413    """
414    return Time(*localtime(ticks)[3:6])
415
416
417def TimestampFromTicks(ticks):
418    """Construct an object holding a timestamp value from the given ticks value
419    (number of seconds since the epoch).
420
421    This function is part of the `DBAPI 2.0 specification
422    <http://www.python.org/dev/peps/pep-0249/>`_.
423
424    :rtype: :class:`datetime.datetime`
425    """
426    return Timestamp(*localtime(ticks)[:6])
427
428
429def Binary(value):
430    """Construct an object holding binary data.
431
432    This function is part of the `DBAPI 2.0 specification
433    <http://www.python.org/dev/peps/pep-0249/>`_.
434
435    :rtype: :class:`pg8000.types.Bytea` for Python 2, otherwise :class:`bytes`
436    """
437    if PY2:
438        return Bytea(value)
439    else:
440        return value
441
442
443if PY2:
444    BINARY = Bytea
445else:
446    BINARY = bytes
447
448FC_TEXT = 0
449FC_BINARY = 1
450
451BINARY_SPACE = b(" ")
452DDL_COMMANDS = b("ALTER"), b("CREATE")
453
454
455def convert_paramstyle(style, query):
456    # I don't see any way to avoid scanning the query string char by char,
457    # so we might as well take that careful approach and create a
458    # state-based scanner.  We'll use int variables for the state.
459    OUTSIDE = 0    # outside quoted string
460    INSIDE_SQ = 1  # inside single-quote string '...'
461    INSIDE_QI = 2  # inside quoted identifier   "..."
462    INSIDE_ES = 3  # inside escaped single-quote string, E'...'
463    INSIDE_PN = 4  # inside parameter name eg. :name
464    INSIDE_CO = 5  # inside inline comment eg. --
465
466    in_quote_escape = False
467    in_param_escape = False
468    placeholders = []
469    output_query = []
470    param_idx = map(lambda x: "$" + str(x), count(1))
471    state = OUTSIDE
472    prev_c = None
473    for i, c in enumerate(query):
474        if i + 1 < len(query):
475            next_c = query[i + 1]
476        else:
477            next_c = None
478
479        if state == OUTSIDE:
480            if c == "'":
481                output_query.append(c)
482                if prev_c == 'E':
483                    state = INSIDE_ES
484                else:
485                    state = INSIDE_SQ
486            elif c == '"':
487                output_query.append(c)
488                state = INSIDE_QI
489            elif c == '-':
490                output_query.append(c)
491                if prev_c == '-':
492                    state = INSIDE_CO
493            elif style == "qmark" and c == "?":
494                output_query.append(next(param_idx))
495            elif style == "numeric" and c == ":" and next_c != ':' \
496                    and prev_c != ':':
497                # Treat : as beginning of parameter name if and only
498                # if it's the only : around
499                # Needed to properly process type conversions
500                # i.e. sum(x)::float
501                output_query.append("$")
502            elif style == "named" and c == ":" and next_c != ':' \
503                    and prev_c != ':':
504                # Same logic for : as in numeric parameters
505                state = INSIDE_PN
506                placeholders.append('')
507            elif style == "pyformat" and c == '%' and next_c == "(":
508                state = INSIDE_PN
509                placeholders.append('')
510            elif style in ("format", "pyformat") and c == "%":
511                style = "format"
512                if in_param_escape:
513                    in_param_escape = False
514                    output_query.append(c)
515                else:
516                    if next_c == "%":
517                        in_param_escape = True
518                    elif next_c == "s":
519                        state = INSIDE_PN
520                        output_query.append(next(param_idx))
521                    else:
522                        raise InterfaceError(
523                            "Only %s and %% are supported in the query.")
524            else:
525                output_query.append(c)
526
527        elif state == INSIDE_SQ:
528            if c == "'":
529                if in_quote_escape:
530                    in_quote_escape = False
531                else:
532                    if next_c == "'":
533                        in_quote_escape = True
534                    else:
535                        state = OUTSIDE
536            output_query.append(c)
537
538        elif state == INSIDE_QI:
539            if c == '"':
540                state = OUTSIDE
541            output_query.append(c)
542
543        elif state == INSIDE_ES:
544            if c == "'" and prev_c != "\\":
545                # check for escaped single-quote
546                state = OUTSIDE
547            output_query.append(c)
548
549        elif state == INSIDE_PN:
550            if style == 'named':
551                placeholders[-1] += c
552                if next_c is None or (not next_c.isalnum() and next_c != '_'):
553                    state = OUTSIDE
554                    try:
555                        pidx = placeholders.index(placeholders[-1], 0, -1)
556                        output_query.append("$" + str(pidx + 1))
557                        del placeholders[-1]
558                    except ValueError:
559                        output_query.append("$" + str(len(placeholders)))
560            elif style == 'pyformat':
561                if prev_c == ')' and c == "s":
562                    state = OUTSIDE
563                    try:
564                        pidx = placeholders.index(placeholders[-1], 0, -1)
565                        output_query.append("$" + str(pidx + 1))
566                        del placeholders[-1]
567                    except ValueError:
568                        output_query.append("$" + str(len(placeholders)))
569                elif c in "()":
570                    pass
571                else:
572                    placeholders[-1] += c
573            elif style == 'format':
574                state = OUTSIDE
575
576        elif state == INSIDE_CO:
577            output_query.append(c)
578            if c == '\n':
579                state = OUTSIDE
580
581        prev_c = c
582
583    if style in ('numeric', 'qmark', 'format'):
584        def make_args(vals):
585            return vals
586    else:
587        def make_args(vals):
588            return tuple(vals[p] for p in placeholders)
589
590    return ''.join(output_query), make_args
591
592
593EPOCH = Datetime(2000, 1, 1)
594EPOCH_TZ = EPOCH.replace(tzinfo=utc)
595EPOCH_SECONDS = timegm(EPOCH.timetuple())
596INFINITY_MICROSECONDS = 2 ** 63 - 1
597MINUS_INFINITY_MICROSECONDS = -1 * INFINITY_MICROSECONDS - 1
598
599
600# data is 64-bit integer representing microseconds since 2000-01-01
601def timestamp_recv_integer(data, offset, length):
602    micros = q_unpack(data, offset)[0]
603    try:
604        return EPOCH + Timedelta(microseconds=micros)
605    except OverflowError:
606        if micros == INFINITY_MICROSECONDS:
607            return 'infinity'
608        elif micros == MINUS_INFINITY_MICROSECONDS:
609            return '-infinity'
610        else:
611            return micros
612
613
614# data is double-precision float representing seconds since 2000-01-01
615def timestamp_recv_float(data, offset, length):
616    return Datetime.utcfromtimestamp(EPOCH_SECONDS + d_unpack(data, offset)[0])
617
618
619# data is 64-bit integer representing microseconds since 2000-01-01
620def timestamp_send_integer(v):
621    return q_pack(
622        int((timegm(v.timetuple()) - EPOCH_SECONDS) * 1e6) + v.microsecond)
623
624
625# data is double-precision float representing seconds since 2000-01-01
626def timestamp_send_float(v):
627    return d_pack(timegm(v.timetuple()) + v.microsecond / 1e6 - EPOCH_SECONDS)
628
629
630def timestamptz_send_integer(v):
631    # timestamps should be sent as UTC.  If they have zone info,
632    # convert them.
633    return timestamp_send_integer(v.astimezone(utc).replace(tzinfo=None))
634
635
636def timestamptz_send_float(v):
637    # timestamps should be sent as UTC.  If they have zone info,
638    # convert them.
639    return timestamp_send_float(v.astimezone(utc).replace(tzinfo=None))
640
641
642# return a timezone-aware datetime instance if we're reading from a
643# "timestamp with timezone" type.  The timezone returned will always be
644# UTC, but providing that additional information can permit conversion
645# to local.
646def timestamptz_recv_integer(data, offset, length):
647    micros = q_unpack(data, offset)[0]
648    try:
649        return EPOCH_TZ + Timedelta(microseconds=micros)
650    except OverflowError:
651        if micros == INFINITY_MICROSECONDS:
652            return 'infinity'
653        elif micros == MINUS_INFINITY_MICROSECONDS:
654            return '-infinity'
655        else:
656            return micros
657
658
659def timestamptz_recv_float(data, offset, length):
660    return timestamp_recv_float(data, offset, length).replace(tzinfo=utc)
661
662
663def interval_send_integer(v):
664    microseconds = v.microseconds
665    try:
666        microseconds += int(v.seconds * 1e6)
667    except AttributeError:
668        pass
669
670    try:
671        months = v.months
672    except AttributeError:
673        months = 0
674
675    return qii_pack(microseconds, v.days, months)
676
677
678def interval_send_float(v):
679    seconds = v.microseconds / 1000.0 / 1000.0
680    try:
681        seconds += v.seconds
682    except AttributeError:
683        pass
684
685    try:
686        months = v.months
687    except AttributeError:
688        months = 0
689
690    return dii_pack(seconds, v.days, months)
691
692
693def interval_recv_integer(data, offset, length):
694    microseconds, days, months = qii_unpack(data, offset)
695    if months == 0:
696        seconds, micros = divmod(microseconds, 1e6)
697        return Timedelta(days, seconds, micros)
698    else:
699        return Interval(microseconds, days, months)
700
701
702def interval_recv_float(data, offset, length):
703    seconds, days, months = dii_unpack(data, offset)
704    if months == 0:
705        secs, microseconds = divmod(seconds, 1e6)
706        return Timedelta(days, secs, microseconds)
707    else:
708        return Interval(int(seconds * 1000 * 1000), days, months)
709
710
711def int8_recv(data, offset, length):
712    return q_unpack(data, offset)[0]
713
714
715def int2_recv(data, offset, length):
716    return h_unpack(data, offset)[0]
717
718
719def int4_recv(data, offset, length):
720    return i_unpack(data, offset)[0]
721
722
723def float4_recv(data, offset, length):
724    return f_unpack(data, offset)[0]
725
726
727def float8_recv(data, offset, length):
728    return d_unpack(data, offset)[0]
729
730
731def bytea_send(v):
732    return v
733
734
735# bytea
736if PY2:
737    def bytea_recv(data, offset, length):
738        return Bytea(data[offset:offset + length])
739else:
740    def bytea_recv(data, offset, length):
741        return data[offset:offset + length]
742
743
744def uuid_send(v):
745    return v.bytes
746
747
748def uuid_recv(data, offset, length):
749    return UUID(bytes=data[offset:offset+length])
750
751
752TRUE = b("\x01")
753FALSE = b("\x00")
754
755
756def bool_send(v):
757    return TRUE if v else FALSE
758
759
760NULL = i_pack(-1)
761
762NULL_BYTE = b('\x00')
763
764
765def null_send(v):
766    return NULL
767
768
769def int_in(data, offset, length):
770    return int(data[offset: offset + length])
771
772
773class Cursor(object):
774    """A cursor object is returned by the :meth:`~Connection.cursor` method of
775    a connection. It has the following attributes and methods:
776
777    .. attribute:: arraysize
778
779        This read/write attribute specifies the number of rows to fetch at a
780        time with :meth:`fetchmany`.  It defaults to 1.
781
782    .. attribute:: connection
783
784        This read-only attribute contains a reference to the connection object
785        (an instance of :class:`Connection`) on which the cursor was
786        created.
787
788        This attribute is part of a DBAPI 2.0 extension.  Accessing this
789        attribute will generate the following warning: ``DB-API extension
790        cursor.connection used``.
791
792    .. attribute:: rowcount
793
794        This read-only attribute contains the number of rows that the last
795        ``execute()`` or ``executemany()`` method produced (for query
796        statements like ``SELECT``) or affected (for modification statements
797        like ``UPDATE``).
798
799        The value is -1 if:
800
801        - No ``execute()`` or ``executemany()`` method has been performed yet
802          on the cursor.
803        - There was no rowcount associated with the last ``execute()``.
804        - At least one of the statements executed as part of an
805          ``executemany()`` had no row count associated with it.
806        - Using a ``SELECT`` query statement on PostgreSQL server older than
807          version 9.
808        - Using a ``COPY`` query statement on PostgreSQL server version 8.1 or
809          older.
810
811        This attribute is part of the `DBAPI 2.0 specification
812        <http://www.python.org/dev/peps/pep-0249/>`_.
813
814    .. attribute:: description
815
816        This read-only attribute is a sequence of 7-item sequences.  Each value
817        contains information describing one result column.  The 7 items
818        returned for each column are (name, type_code, display_size,
819        internal_size, precision, scale, null_ok).  Only the first two values
820        are provided by the current implementation.
821
822        This attribute is part of the `DBAPI 2.0 specification
823        <http://www.python.org/dev/peps/pep-0249/>`_.
824    """
825
826    def __init__(self, connection):
827        self._c = connection
828        self.arraysize = 1
829        self.ps = None
830        self._row_count = -1
831        self._cached_rows = deque()
832
833    def __enter__(self):
834        return self
835
836    def __exit__(self, exc_type, exc_value, traceback):
837        self.close()
838
839    @property
840    def connection(self):
841        warn("DB-API extension cursor.connection used", stacklevel=3)
842        return self._c
843
844    @property
845    def rowcount(self):
846        return self._row_count
847
848    description = property(lambda self: self._getDescription())
849
850    def _getDescription(self):
851        if self.ps is None:
852            return None
853        row_desc = self.ps['row_desc']
854        if len(row_desc) == 0:
855            return None
856        columns = []
857        for col in row_desc:
858            columns.append(
859                (col["name"], col["type_oid"], None, None, None, None, None))
860        return columns
861
862    ##
863    # Executes a database operation.  Parameters may be provided as a sequence
864    # or mapping and will be bound to variables in the operation.
865    # <p>
866    # Stability: Part of the DBAPI 2.0 specification.
867    def execute(self, operation, args=None, stream=None):
868        """Executes a database operation.  Parameters may be provided as a
869        sequence, or as a mapping, depending upon the value of
870        :data:`pg8000.paramstyle`.
871
872        This method is part of the `DBAPI 2.0 specification
873        <http://www.python.org/dev/peps/pep-0249/>`_.
874
875        :param operation:
876            The SQL statement to execute.
877
878        :param args:
879            If :data:`paramstyle` is ``qmark``, ``numeric``, or ``format``,
880            this argument should be an array of parameters to bind into the
881            statement.  If :data:`paramstyle` is ``named``, the argument should
882            be a dict mapping of parameters.  If the :data:`paramstyle` is
883            ``pyformat``, the argument value may be either an array or a
884            mapping.
885
886        :param stream: This is a pg8000 extension for use with the PostgreSQL
887            `COPY
888            <http://www.postgresql.org/docs/current/static/sql-copy.html>`_
889            command. For a COPY FROM the parameter must be a readable file-like
890            object, and for COPY TO it must be writable.
891
892            .. versionadded:: 1.9.11
893        """
894        try:
895            self.stream = stream
896
897            if not self._c.in_transaction and not self._c.autocommit:
898                self._c.execute(self, "begin transaction", None)
899            self._c.execute(self, operation, args)
900        except AttributeError as e:
901            if self._c is None:
902                raise InterfaceError("Cursor closed")
903            elif self._c._sock is None:
904                raise InterfaceError("connection is closed")
905            else:
906                raise e
907
908    def executemany(self, operation, param_sets):
909        """Prepare a database operation, and then execute it against all
910        parameter sequences or mappings provided.
911
912        This method is part of the `DBAPI 2.0 specification
913        <http://www.python.org/dev/peps/pep-0249/>`_.
914
915        :param operation:
916            The SQL statement to execute
917        :param parameter_sets:
918            A sequence of parameters to execute the statement with. The values
919            in the sequence should be sequences or mappings of parameters, the
920            same as the args argument of the :meth:`execute` method.
921        """
922        rowcounts = []
923        for parameters in param_sets:
924            self.execute(operation, parameters)
925            rowcounts.append(self._row_count)
926
927        self._row_count = -1 if -1 in rowcounts else sum(rowcounts)
928
929    def fetchone(self):
930        """Fetch the next row of a query result set.
931
932        This method is part of the `DBAPI 2.0 specification
933        <http://www.python.org/dev/peps/pep-0249/>`_.
934
935        :returns:
936            A row as a sequence of field values, or ``None`` if no more rows
937            are available.
938        """
939        try:
940            return next(self)
941        except StopIteration:
942            return None
943        except TypeError:
944            raise ProgrammingError("attempting to use unexecuted cursor")
945        except AttributeError:
946            raise ProgrammingError("attempting to use unexecuted cursor")
947
948    def fetchmany(self, num=None):
949        """Fetches the next set of rows of a query result.
950
951        This method is part of the `DBAPI 2.0 specification
952        <http://www.python.org/dev/peps/pep-0249/>`_.
953
954        :param size:
955
956            The number of rows to fetch when called.  If not provided, the
957            :attr:`arraysize` attribute value is used instead.
958
959        :returns:
960
961            A sequence, each entry of which is a sequence of field values
962            making up a row.  If no more rows are available, an empty sequence
963            will be returned.
964        """
965        try:
966            return tuple(
967                islice(self, self.arraysize if num is None else num))
968        except TypeError:
969            raise ProgrammingError("attempting to use unexecuted cursor")
970
971    def fetchall(self):
972        """Fetches all remaining rows of a query result.
973
974        This method is part of the `DBAPI 2.0 specification
975        <http://www.python.org/dev/peps/pep-0249/>`_.
976
977        :returns:
978
979            A sequence, each entry of which is a sequence of field values
980            making up a row.
981        """
982        try:
983            return tuple(self)
984        except TypeError:
985            raise ProgrammingError("attempting to use unexecuted cursor")
986
987    def close(self):
988        """Closes the cursor.
989
990        This method is part of the `DBAPI 2.0 specification
991        <http://www.python.org/dev/peps/pep-0249/>`_.
992        """
993        self._c = None
994
995    def __iter__(self):
996        """A cursor object is iterable to retrieve the rows from a query.
997
998        This is a DBAPI 2.0 extension.
999        """
1000        return self
1001
1002    def setinputsizes(self, sizes):
1003        """This method is part of the `DBAPI 2.0 specification
1004        <http://www.python.org/dev/peps/pep-0249/>`_, however, it is not
1005        implemented by pg8000.
1006        """
1007        pass
1008
1009    def setoutputsize(self, size, column=None):
1010        """This method is part of the `DBAPI 2.0 specification
1011        <http://www.python.org/dev/peps/pep-0249/>`_, however, it is not
1012        implemented by pg8000.
1013        """
1014        pass
1015
1016    def __next__(self):
1017        try:
1018            return self._cached_rows.popleft()
1019        except IndexError:
1020            if self.ps is None:
1021                raise ProgrammingError("A query hasn't been issued.")
1022            elif len(self.ps['row_desc']) == 0:
1023                raise ProgrammingError("no result set")
1024            else:
1025                raise StopIteration()
1026
1027
1028if PY2:
1029    Cursor.next = Cursor.__next__
1030
1031# Message codes
1032NOTICE_RESPONSE = b("N")
1033AUTHENTICATION_REQUEST = b("R")
1034PARAMETER_STATUS = b("S")
1035BACKEND_KEY_DATA = b("K")
1036READY_FOR_QUERY = b("Z")
1037ROW_DESCRIPTION = b("T")
1038ERROR_RESPONSE = b("E")
1039DATA_ROW = b("D")
1040COMMAND_COMPLETE = b("C")
1041PARSE_COMPLETE = b("1")
1042BIND_COMPLETE = b("2")
1043CLOSE_COMPLETE = b("3")
1044PORTAL_SUSPENDED = b("s")
1045NO_DATA = b("n")
1046PARAMETER_DESCRIPTION = b("t")
1047NOTIFICATION_RESPONSE = b("A")
1048COPY_DONE = b("c")
1049COPY_DATA = b("d")
1050COPY_IN_RESPONSE = b("G")
1051COPY_OUT_RESPONSE = b("H")
1052EMPTY_QUERY_RESPONSE = b("I")
1053
1054BIND = b("B")
1055PARSE = b("P")
1056EXECUTE = b("E")
1057FLUSH = b('H')
1058SYNC = b('S')
1059PASSWORD = b('p')
1060DESCRIBE = b('D')
1061TERMINATE = b('X')
1062CLOSE = b('C')
1063
1064
1065def _establish_ssl(_socket, ssl_params):
1066    if not isinstance(ssl_params, dict):
1067        ssl_params = {}
1068
1069    try:
1070        import ssl as sslmodule
1071
1072        keyfile = ssl_params.get('keyfile')
1073        certfile = ssl_params.get('certfile')
1074        ca_certs = ssl_params.get('ca_certs')
1075        if ca_certs is None:
1076            verify_mode = sslmodule.CERT_NONE
1077        else:
1078            verify_mode = sslmodule.CERT_REQUIRED
1079
1080        # Int32(8) - Message length, including self.
1081        # Int32(80877103) - The SSL request code.
1082        _socket.sendall(ii_pack(8, 80877103))
1083        resp = _socket.recv(1)
1084        if resp == b('S'):
1085            return sslmodule.wrap_socket(
1086                _socket, keyfile=keyfile, certfile=certfile,
1087                cert_reqs=verify_mode, ca_certs=ca_certs)
1088        else:
1089            raise InterfaceError("Server refuses SSL")
1090    except ImportError:
1091        raise InterfaceError(
1092            "SSL required but ssl module not available in "
1093            "this python installation")
1094
1095
1096def create_message(code, data=b('')):
1097    return code + i_pack(len(data) + 4) + data
1098
1099
1100FLUSH_MSG = create_message(FLUSH)
1101SYNC_MSG = create_message(SYNC)
1102TERMINATE_MSG = create_message(TERMINATE)
1103COPY_DONE_MSG = create_message(COPY_DONE)
1104EXECUTE_MSG = create_message(EXECUTE, NULL_BYTE + i_pack(0))
1105
1106# DESCRIBE constants
1107STATEMENT = b('S')
1108PORTAL = b('P')
1109
1110# ErrorResponse codes
1111RESPONSE_SEVERITY = "S"  # always present
1112RESPONSE_SEVERITY = "V"  # always present
1113RESPONSE_CODE = "C"  # always present
1114RESPONSE_MSG = "M"  # always present
1115RESPONSE_DETAIL = "D"
1116RESPONSE_HINT = "H"
1117RESPONSE_POSITION = "P"
1118RESPONSE__POSITION = "p"
1119RESPONSE__QUERY = "q"
1120RESPONSE_WHERE = "W"
1121RESPONSE_FILE = "F"
1122RESPONSE_LINE = "L"
1123RESPONSE_ROUTINE = "R"
1124
1125IDLE = b("I")
1126IDLE_IN_TRANSACTION = b("T")
1127IDLE_IN_FAILED_TRANSACTION = b("E")
1128
1129
1130arr_trans = dict(zip(map(ord, u("[] 'u")), list(u('{}')) + [None] * 3))
1131
1132
1133class Connection(object):
1134
1135    # DBAPI Extension: supply exceptions as attributes on the connection
1136    Warning = property(lambda self: self._getError(Warning))
1137    Error = property(lambda self: self._getError(Error))
1138    InterfaceError = property(lambda self: self._getError(InterfaceError))
1139    DatabaseError = property(lambda self: self._getError(DatabaseError))
1140    OperationalError = property(lambda self: self._getError(OperationalError))
1141    IntegrityError = property(lambda self: self._getError(IntegrityError))
1142    InternalError = property(lambda self: self._getError(InternalError))
1143    ProgrammingError = property(lambda self: self._getError(ProgrammingError))
1144    NotSupportedError = property(
1145        lambda self: self._getError(NotSupportedError))
1146
1147    def __enter__(self):
1148        return self
1149
1150    def __exit__(self, exc_type, exc_value, traceback):
1151        self.close()
1152
1153    def _getError(self, error):
1154        warn(
1155            "DB-API extension connection.%s used" %
1156            error.__name__, stacklevel=3)
1157        return error
1158
1159    def __init__(
1160            self, user, host, unix_sock, port, database, password, ssl,
1161            timeout, application_name, max_prepared_statements, tcp_keepalive):
1162        self._client_encoding = "utf8"
1163        self._commands_with_count = (
1164            b("INSERT"), b("DELETE"), b("UPDATE"), b("MOVE"),
1165            b("FETCH"), b("COPY"), b("SELECT"))
1166        self.notifications = deque(maxlen=100)
1167        self.notices = deque(maxlen=100)
1168        self.parameter_statuses = deque(maxlen=100)
1169        self.max_prepared_statements = int(max_prepared_statements)
1170
1171        if user is None:
1172            raise InterfaceError(
1173                "The 'user' connection parameter cannot be None")
1174
1175        if isinstance(user, text_type):
1176            self.user = user.encode('utf8')
1177        else:
1178            self.user = user
1179
1180        if isinstance(password, text_type):
1181            self.password = password.encode('utf8')
1182        else:
1183            self.password = password
1184
1185        self.autocommit = False
1186        self._xid = None
1187
1188        self._caches = {}
1189
1190        try:
1191            if unix_sock is None and host is not None:
1192                self._usock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
1193            elif unix_sock is not None:
1194                if not hasattr(socket, "AF_UNIX"):
1195                    raise InterfaceError(
1196                        "attempt to connect to unix socket on unsupported "
1197                        "platform")
1198                self._usock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
1199            else:
1200                raise ProgrammingError(
1201                    "one of host or unix_sock must be provided")
1202            if not PY2 and timeout is not None:
1203                self._usock.settimeout(timeout)
1204
1205            if unix_sock is None and host is not None:
1206                self._usock.connect((host, port))
1207            elif unix_sock is not None:
1208                self._usock.connect(unix_sock)
1209
1210            if ssl:
1211                self._usock = _establish_ssl(self._usock, ssl)
1212
1213            self._sock = self._usock.makefile(mode="rwb")
1214            if tcp_keepalive:
1215                self._usock.setsockopt(
1216                    socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
1217        except socket.error as e:
1218            self._usock.close()
1219            raise InterfaceError("communication error", e)
1220        self._flush = self._sock.flush
1221        self._read = self._sock.read
1222        self._write = self._sock.write
1223        self._backend_key_data = None
1224
1225        def text_out(v):
1226            return v.encode(self._client_encoding)
1227
1228        def enum_out(v):
1229            return str(v.value).encode(self._client_encoding)
1230
1231        def time_out(v):
1232            return v.isoformat().encode(self._client_encoding)
1233
1234        def date_out(v):
1235            return v.isoformat().encode(self._client_encoding)
1236
1237        def unknown_out(v):
1238            return str(v).encode(self._client_encoding)
1239
1240        trans_tab = dict(zip(map(ord, u('{}')), u('[]')))
1241        glbls = {'Decimal': Decimal}
1242
1243        def array_in(data, idx, length):
1244            arr = []
1245            prev_c = None
1246            for c in data[idx:idx+length].decode(
1247                    self._client_encoding).translate(
1248                    trans_tab).replace(u('NULL'), u('None')):
1249                if c not in ('[', ']', ',', 'N') and prev_c in ('[', ','):
1250                    arr.extend("Decimal('")
1251                elif c in (']', ',') and prev_c not in ('[', ']', ',', 'e'):
1252                    arr.extend("')")
1253
1254                arr.append(c)
1255                prev_c = c
1256            return eval(''.join(arr), glbls)
1257
1258        def array_recv(data, idx, length):
1259            final_idx = idx + length
1260            dim, hasnull, typeoid = iii_unpack(data, idx)
1261            idx += 12
1262
1263            # get type conversion method for typeoid
1264            conversion = self.pg_types[typeoid][1]
1265
1266            # Read dimension info
1267            dim_lengths = []
1268            for i in range(dim):
1269                dim_lengths.append(ii_unpack(data, idx)[0])
1270                idx += 8
1271
1272            # Read all array values
1273            values = []
1274            while idx < final_idx:
1275                element_len, = i_unpack(data, idx)
1276                idx += 4
1277                if element_len == -1:
1278                    values.append(None)
1279                else:
1280                    values.append(conversion(data, idx, element_len))
1281                    idx += element_len
1282
1283            # at this point, {{1,2,3},{4,5,6}}::int[][] looks like
1284            # [1,2,3,4,5,6]. go through the dimensions and fix up the array
1285            # contents to match expected dimensions
1286            for length in reversed(dim_lengths[1:]):
1287                values = list(map(list, zip(*[iter(values)] * length)))
1288            return values
1289
1290        def vector_in(data, idx, length):
1291            return eval('[' + data[idx:idx+length].decode(
1292                self._client_encoding).replace(' ', ',') + ']')
1293
1294        if PY2:
1295            def text_recv(data, offset, length):
1296                return unicode(  # noqa
1297                    data[offset: offset + length], self._client_encoding)
1298
1299            def bool_recv(d, o, l):
1300                return d[o] == "\x01"
1301
1302            def json_in(data, offset, length):
1303                return loads(unicode(  # noqa
1304                    data[offset: offset + length], self._client_encoding))
1305
1306        else:
1307            def text_recv(data, offset, length):
1308                return str(
1309                    data[offset: offset + length], self._client_encoding)
1310
1311            def bool_recv(data, offset, length):
1312                return data[offset] == 1
1313
1314            def json_in(data, offset, length):
1315                return loads(
1316                    str(data[offset: offset + length], self._client_encoding))
1317
1318        def time_in(data, offset, length):
1319            hour = int(data[offset:offset + 2])
1320            minute = int(data[offset + 3:offset + 5])
1321            sec = Decimal(
1322                data[offset + 6:offset + length].decode(self._client_encoding))
1323            return time(
1324                hour, minute, int(sec), int((sec - int(sec)) * 1000000))
1325
1326        def date_in(data, offset, length):
1327            d = data[offset:offset+length].decode(self._client_encoding)
1328            try:
1329                return date(int(d[:4]), int(d[5:7]), int(d[8:10]))
1330            except ValueError:
1331                return d
1332
1333        def numeric_in(data, offset, length):
1334            return Decimal(
1335                data[offset: offset + length].decode(self._client_encoding))
1336
1337        def numeric_out(d):
1338            return str(d).encode(self._client_encoding)
1339
1340        self.pg_types = defaultdict(
1341            lambda: (FC_TEXT, text_recv), {
1342                16: (FC_BINARY, bool_recv),  # boolean
1343                17: (FC_BINARY, bytea_recv),  # bytea
1344                19: (FC_BINARY, text_recv),  # name type
1345                20: (FC_BINARY, int8_recv),  # int8
1346                21: (FC_BINARY, int2_recv),  # int2
1347                22: (FC_TEXT, vector_in),  # int2vector
1348                23: (FC_BINARY, int4_recv),  # int4
1349                25: (FC_BINARY, text_recv),  # TEXT type
1350                26: (FC_TEXT, int_in),  # oid
1351                28: (FC_TEXT, int_in),  # xid
1352                114: (FC_TEXT, json_in),  # json
1353                700: (FC_BINARY, float4_recv),  # float4
1354                701: (FC_BINARY, float8_recv),  # float8
1355                705: (FC_BINARY, text_recv),  # unknown
1356                829: (FC_TEXT, text_recv),  # MACADDR type
1357                1000: (FC_BINARY, array_recv),  # BOOL[]
1358                1003: (FC_BINARY, array_recv),  # NAME[]
1359                1005: (FC_BINARY, array_recv),  # INT2[]
1360                1007: (FC_BINARY, array_recv),  # INT4[]
1361                1009: (FC_BINARY, array_recv),  # TEXT[]
1362                1014: (FC_BINARY, array_recv),  # CHAR[]
1363                1015: (FC_BINARY, array_recv),  # VARCHAR[]
1364                1016: (FC_BINARY, array_recv),  # INT8[]
1365                1021: (FC_BINARY, array_recv),  # FLOAT4[]
1366                1022: (FC_BINARY, array_recv),  # FLOAT8[]
1367                1042: (FC_BINARY, text_recv),  # CHAR type
1368                1043: (FC_BINARY, text_recv),  # VARCHAR type
1369                1082: (FC_TEXT, date_in),  # date
1370                1083: (FC_TEXT, time_in),
1371                1114: (FC_BINARY, timestamp_recv_float),  # timestamp w/ tz
1372                1184: (FC_BINARY, timestamptz_recv_float),
1373                1186: (FC_BINARY, interval_recv_integer),
1374                1231: (FC_TEXT, array_in),  # NUMERIC[]
1375                1263: (FC_BINARY, array_recv),  # cstring[]
1376                1700: (FC_TEXT, numeric_in),  # NUMERIC
1377                2275: (FC_BINARY, text_recv),  # cstring
1378                2950: (FC_BINARY, uuid_recv),  # uuid
1379                3802: (FC_TEXT, json_in),  # jsonb
1380            })
1381
1382        self.py_types = {
1383            type(None): (-1, FC_BINARY, null_send),  # null
1384            bool: (16, FC_BINARY, bool_send),
1385            bytearray: (17, FC_BINARY, bytea_send),  # bytea
1386            20: (20, FC_BINARY, q_pack),  # int8
1387            21: (21, FC_BINARY, h_pack),  # int2
1388            23: (23, FC_BINARY, i_pack),  # int4
1389            PGText: (25, FC_TEXT, text_out),  # text
1390            float: (701, FC_BINARY, d_pack),  # float8
1391            PGEnum: (705, FC_TEXT, enum_out),
1392            date: (1082, FC_TEXT, date_out),  # date
1393            time: (1083, FC_TEXT, time_out),  # time
1394            1114: (1114, FC_BINARY, timestamp_send_integer),  # timestamp
1395            # timestamp w/ tz
1396            PGVarchar: (1043, FC_TEXT, text_out),  # varchar
1397            1184: (1184, FC_BINARY, timestamptz_send_integer),
1398            PGJson: (114, FC_TEXT, text_out),
1399            PGJsonb: (3802, FC_TEXT, text_out),
1400            Timedelta: (1186, FC_BINARY, interval_send_integer),
1401            Interval: (1186, FC_BINARY, interval_send_integer),
1402            Decimal: (1700, FC_TEXT, numeric_out),  # Decimal
1403            PGTsvector: (3614, FC_TEXT, text_out),
1404            UUID: (2950, FC_BINARY, uuid_send)}  # uuid
1405
1406        self.inspect_funcs = {
1407            Datetime: self.inspect_datetime,
1408            list: self.array_inspect,
1409            tuple: self.array_inspect,
1410            int: self.inspect_int}
1411
1412        if PY2:
1413            self.py_types[Bytea] = (17, FC_BINARY, bytea_send)  # bytea
1414            self.py_types[text_type] = (705, FC_TEXT, text_out)  # unknown
1415            self.py_types[str] = (705, FC_TEXT, bytea_send)  # unknown
1416
1417            self.inspect_funcs[long] = self.inspect_int  # noqa
1418        else:
1419            self.py_types[bytes] = (17, FC_BINARY, bytea_send)  # bytea
1420            self.py_types[str] = (705, FC_TEXT, text_out)  # unknown
1421
1422        try:
1423            import enum
1424
1425            self.py_types[enum.Enum] = (705, FC_TEXT, enum_out)
1426        except ImportError:
1427            pass
1428
1429        try:
1430            from ipaddress import (
1431                ip_address, IPv4Address, IPv6Address, ip_network, IPv4Network,
1432                IPv6Network)
1433
1434            def inet_out(v):
1435                return str(v).encode(self._client_encoding)
1436
1437            def inet_in(data, offset, length):
1438                inet_str = data[offset: offset + length].decode(
1439                    self._client_encoding)
1440                if '/' in inet_str:
1441                    return ip_network(inet_str, False)
1442                else:
1443                    return ip_address(inet_str)
1444
1445            self.py_types[IPv4Address] = (869, FC_TEXT, inet_out)  # inet
1446            self.py_types[IPv6Address] = (869, FC_TEXT, inet_out)  # inet
1447            self.py_types[IPv4Network] = (869, FC_TEXT, inet_out)  # inet
1448            self.py_types[IPv6Network] = (869, FC_TEXT, inet_out)  # inet
1449            self.pg_types[869] = (FC_TEXT, inet_in)  # inet
1450        except ImportError:
1451            pass
1452
1453        self.message_types = {
1454            NOTICE_RESPONSE: self.handle_NOTICE_RESPONSE,
1455            AUTHENTICATION_REQUEST: self.handle_AUTHENTICATION_REQUEST,
1456            PARAMETER_STATUS: self.handle_PARAMETER_STATUS,
1457            BACKEND_KEY_DATA: self.handle_BACKEND_KEY_DATA,
1458            READY_FOR_QUERY: self.handle_READY_FOR_QUERY,
1459            ROW_DESCRIPTION: self.handle_ROW_DESCRIPTION,
1460            ERROR_RESPONSE: self.handle_ERROR_RESPONSE,
1461            EMPTY_QUERY_RESPONSE: self.handle_EMPTY_QUERY_RESPONSE,
1462            DATA_ROW: self.handle_DATA_ROW,
1463            COMMAND_COMPLETE: self.handle_COMMAND_COMPLETE,
1464            PARSE_COMPLETE: self.handle_PARSE_COMPLETE,
1465            BIND_COMPLETE: self.handle_BIND_COMPLETE,
1466            CLOSE_COMPLETE: self.handle_CLOSE_COMPLETE,
1467            PORTAL_SUSPENDED: self.handle_PORTAL_SUSPENDED,
1468            NO_DATA: self.handle_NO_DATA,
1469            PARAMETER_DESCRIPTION: self.handle_PARAMETER_DESCRIPTION,
1470            NOTIFICATION_RESPONSE: self.handle_NOTIFICATION_RESPONSE,
1471            COPY_DONE: self.handle_COPY_DONE,
1472            COPY_DATA: self.handle_COPY_DATA,
1473            COPY_IN_RESPONSE: self.handle_COPY_IN_RESPONSE,
1474            COPY_OUT_RESPONSE: self.handle_COPY_OUT_RESPONSE}
1475
1476        # Int32 - Message length, including self.
1477        # Int32(196608) - Protocol version number.  Version 3.0.
1478        # Any number of key/value pairs, terminated by a zero byte:
1479        #   String - A parameter name (user, database, or options)
1480        #   String - Parameter value
1481        protocol = 196608
1482        val = bytearray(
1483            i_pack(protocol) + b("user\x00") + self.user + NULL_BYTE)
1484        if database is not None:
1485            if isinstance(database, text_type):
1486                database = database.encode('utf8')
1487            val.extend(b("database\x00") + database + NULL_BYTE)
1488        if application_name is not None:
1489            if isinstance(application_name, text_type):
1490                application_name = application_name.encode('utf8')
1491            val.extend(
1492                b("application_name\x00") + application_name + NULL_BYTE)
1493        val.append(0)
1494        self._write(i_pack(len(val) + 4))
1495        self._write(val)
1496        self._flush()
1497
1498        self._cursor = self.cursor()
1499        code = self.error = None
1500        while code not in (READY_FOR_QUERY, ERROR_RESPONSE):
1501            code, data_len = ci_unpack(self._read(5))
1502            self.message_types[code](self._read(data_len - 4), None)
1503        if self.error is not None:
1504            raise self.error
1505
1506        self.in_transaction = False
1507
1508    def handle_ERROR_RESPONSE(self, data, ps):
1509        msg = dict(
1510            (
1511                s[:1].decode(self._client_encoding),
1512                s[1:].decode(self._client_encoding)) for s in
1513            data.split(NULL_BYTE) if s != b(''))
1514
1515        response_code = msg[RESPONSE_CODE]
1516        if response_code == '28000':
1517            cls = InterfaceError
1518        elif response_code == '23505':
1519            cls = IntegrityError
1520        else:
1521            cls = ProgrammingError
1522
1523        self.error = cls(msg)
1524
1525    def handle_EMPTY_QUERY_RESPONSE(self, data, ps):
1526        self.error = ProgrammingError("query was empty")
1527
1528    def handle_CLOSE_COMPLETE(self, data, ps):
1529        pass
1530
1531    def handle_PARSE_COMPLETE(self, data, ps):
1532        # Byte1('1') - Identifier.
1533        # Int32(4) - Message length, including self.
1534        pass
1535
1536    def handle_BIND_COMPLETE(self, data, ps):
1537        pass
1538
1539    def handle_PORTAL_SUSPENDED(self, data, cursor):
1540        pass
1541
1542    def handle_PARAMETER_DESCRIPTION(self, data, ps):
1543        # Well, we don't really care -- we're going to send whatever we
1544        # want and let the database deal with it.  But thanks anyways!
1545
1546        # count = h_unpack(data)[0]
1547        # type_oids = unpack_from("!" + "i" * count, data, 2)
1548        pass
1549
1550    def handle_COPY_DONE(self, data, ps):
1551        self._copy_done = True
1552
1553    def handle_COPY_OUT_RESPONSE(self, data, ps):
1554        # Int8(1) - 0 textual, 1 binary
1555        # Int16(2) - Number of columns
1556        # Int16(N) - Format codes for each column (0 text, 1 binary)
1557
1558        is_binary, num_cols = bh_unpack(data)
1559        # column_formats = unpack_from('!' + 'h' * num_cols, data, 3)
1560        if ps.stream is None:
1561            raise InterfaceError(
1562                "An output stream is required for the COPY OUT response.")
1563
1564    def handle_COPY_DATA(self, data, ps):
1565        ps.stream.write(data)
1566
1567    def handle_COPY_IN_RESPONSE(self, data, ps):
1568        # Int16(2) - Number of columns
1569        # Int16(N) - Format codes for each column (0 text, 1 binary)
1570        is_binary, num_cols = bh_unpack(data)
1571        # column_formats = unpack_from('!' + 'h' * num_cols, data, 3)
1572        if ps.stream is None:
1573            raise InterfaceError(
1574                "An input stream is required for the COPY IN response.")
1575
1576        if PY2:
1577            while True:
1578                data = ps.stream.read(8192)
1579                if not data:
1580                    break
1581                self._write(COPY_DATA + i_pack(len(data) + 4))
1582                self._write(data)
1583                self._flush()
1584        else:
1585            bffr = bytearray(8192)
1586            while True:
1587                bytes_read = ps.stream.readinto(bffr)
1588                if bytes_read == 0:
1589                    break
1590                self._write(COPY_DATA + i_pack(bytes_read + 4))
1591                self._write(bffr[:bytes_read])
1592                self._flush()
1593
1594        # Send CopyDone
1595        # Byte1('c') - Identifier.
1596        # Int32(4) - Message length, including self.
1597        self._write(COPY_DONE_MSG)
1598        self._write(SYNC_MSG)
1599        self._flush()
1600
1601    def handle_NOTIFICATION_RESPONSE(self, data, ps):
1602        ##
1603        # A message sent if this connection receives a NOTIFY that it was
1604        # LISTENing for.
1605        # <p>
1606        # Stability: Added in pg8000 v1.03.  When limited to accessing
1607        # properties from a notification event dispatch, stability is
1608        # guaranteed for v1.xx.
1609        backend_pid = i_unpack(data)[0]
1610        idx = 4
1611        null = data.find(NULL_BYTE, idx) - idx
1612        condition = data[idx:idx + null].decode("ascii")
1613        idx += null + 1
1614        null = data.find(NULL_BYTE, idx) - idx
1615        # additional_info = data[idx:idx + null]
1616
1617        self.notifications.append((backend_pid, condition))
1618
1619    def cursor(self):
1620        """Creates a :class:`Cursor` object bound to this
1621        connection.
1622
1623        This function is part of the `DBAPI 2.0 specification
1624        <http://www.python.org/dev/peps/pep-0249/>`_.
1625        """
1626        return Cursor(self)
1627
1628    def commit(self):
1629        """Commits the current database transaction.
1630
1631        This function is part of the `DBAPI 2.0 specification
1632        <http://www.python.org/dev/peps/pep-0249/>`_.
1633        """
1634        self.execute(self._cursor, "commit", None)
1635
1636    def rollback(self):
1637        """Rolls back the current database transaction.
1638
1639        This function is part of the `DBAPI 2.0 specification
1640        <http://www.python.org/dev/peps/pep-0249/>`_.
1641        """
1642        if not self.in_transaction:
1643            return
1644        self.execute(self._cursor, "rollback", None)
1645
1646    def close(self):
1647        """Closes the database connection.
1648
1649        This function is part of the `DBAPI 2.0 specification
1650        <http://www.python.org/dev/peps/pep-0249/>`_.
1651        """
1652        try:
1653            # Byte1('X') - Identifies the message as a terminate message.
1654            # Int32(4) - Message length, including self.
1655            self._write(TERMINATE_MSG)
1656            self._flush()
1657            self._sock.close()
1658        except AttributeError:
1659            raise InterfaceError("connection is closed")
1660        except ValueError:
1661            raise InterfaceError("connection is closed")
1662        except socket.error:
1663            pass
1664        finally:
1665            self._usock.close()
1666            self._sock = None
1667
1668    def handle_AUTHENTICATION_REQUEST(self, data, cursor):
1669        # Int32 -   An authentication code that represents different
1670        #           authentication messages:
1671        #               0 = AuthenticationOk
1672        #               5 = MD5 pwd
1673        #               2 = Kerberos v5 (not supported by pg8000)
1674        #               3 = Cleartext pwd
1675        #               4 = crypt() pwd (not supported by pg8000)
1676        #               6 = SCM credential (not supported by pg8000)
1677        #               7 = GSSAPI (not supported by pg8000)
1678        #               8 = GSSAPI data (not supported by pg8000)
1679        #               9 = SSPI (not supported by pg8000)
1680        # Some authentication messages have additional data following the
1681        # authentication code.  That data is documented in the appropriate
1682        # class.
1683        auth_code = i_unpack(data)[0]
1684        if auth_code == 0:
1685            pass
1686        elif auth_code == 3:
1687            if self.password is None:
1688                raise InterfaceError(
1689                    "server requesting password authentication, but no "
1690                    "password was provided")
1691            self._send_message(PASSWORD, self.password + NULL_BYTE)
1692            self._flush()
1693        elif auth_code == 5:
1694            ##
1695            # A message representing the backend requesting an MD5 hashed
1696            # password response.  The response will be sent as
1697            # md5(md5(pwd + login) + salt).
1698
1699            # Additional message data:
1700            #  Byte4 - Hash salt.
1701            salt = b("").join(cccc_unpack(data, 4))
1702            if self.password is None:
1703                raise InterfaceError(
1704                    "server requesting MD5 password authentication, but no "
1705                    "password was provided")
1706            pwd = b("md5") + md5(
1707                md5(self.password + self.user).hexdigest().encode("ascii") +
1708                salt).hexdigest().encode("ascii")
1709            # Byte1('p') - Identifies the message as a password message.
1710            # Int32 - Message length including self.
1711            # String - The password.  Password may be encrypted.
1712            self._send_message(PASSWORD, pwd + NULL_BYTE)
1713            self._flush()
1714
1715        elif auth_code in (2, 4, 6, 7, 8, 9):
1716            raise InterfaceError(
1717                "Authentication method " + str(auth_code) +
1718                " not supported by pg8000.")
1719        else:
1720            raise InterfaceError(
1721                "Authentication method " + str(auth_code) +
1722                " not recognized by pg8000.")
1723
1724    def handle_READY_FOR_QUERY(self, data, ps):
1725        # Byte1 -   Status indicator.
1726        self.in_transaction = data != IDLE
1727
1728    def handle_BACKEND_KEY_DATA(self, data, ps):
1729        self._backend_key_data = data
1730
1731    def inspect_datetime(self, value):
1732        if value.tzinfo is None:
1733            return self.py_types[1114]  # timestamp
1734        else:
1735            return self.py_types[1184]  # send as timestamptz
1736
1737    def inspect_int(self, value):
1738        if min_int2 < value < max_int2:
1739            return self.py_types[21]
1740        if min_int4 < value < max_int4:
1741            return self.py_types[23]
1742        if min_int8 < value < max_int8:
1743            return self.py_types[20]
1744
1745    def make_params(self, values):
1746        params = []
1747        for value in values:
1748            typ = type(value)
1749            try:
1750                params.append(self.py_types[typ])
1751            except KeyError:
1752                try:
1753                    params.append(self.inspect_funcs[typ](value))
1754                except KeyError as e:
1755                    param = None
1756                    for k, v in iteritems(self.py_types):
1757                        try:
1758                            if isinstance(value, k):
1759                                param = v
1760                                break
1761                        except TypeError:
1762                            pass
1763
1764                    if param is None:
1765                        for k, v in iteritems(self.inspect_funcs):
1766                            try:
1767                                if isinstance(value, k):
1768                                    param = v(value)
1769                                    break
1770                            except TypeError:
1771                                pass
1772                            except KeyError:
1773                                pass
1774
1775                    if param is None:
1776                        raise NotSupportedError(
1777                            "type " + str(e) + " not mapped to pg type")
1778                    else:
1779                        params.append(param)
1780
1781        return tuple(params)
1782
1783    def handle_ROW_DESCRIPTION(self, data, cursor):
1784        count = h_unpack(data)[0]
1785        idx = 2
1786        for i in range(count):
1787            name = data[idx:data.find(NULL_BYTE, idx)]
1788            idx += len(name) + 1
1789            field = dict(
1790                zip((
1791                    "table_oid", "column_attrnum", "type_oid", "type_size",
1792                    "type_modifier", "format"), ihihih_unpack(data, idx)))
1793            field['name'] = name
1794            idx += 18
1795            cursor.ps['row_desc'].append(field)
1796            field['pg8000_fc'], field['func'] = \
1797                self.pg_types[field['type_oid']]
1798
1799    def execute(self, cursor, operation, vals):
1800        if vals is None:
1801            vals = ()
1802
1803        paramstyle = pg8000.paramstyle
1804        pid = getpid()
1805        try:
1806            cache = self._caches[paramstyle][pid]
1807        except KeyError:
1808            try:
1809                param_cache = self._caches[paramstyle]
1810            except KeyError:
1811                param_cache = self._caches[paramstyle] = {}
1812
1813            try:
1814                cache = param_cache[pid]
1815            except KeyError:
1816                cache = param_cache[pid] = {'statement': {}, 'ps': {}}
1817
1818        try:
1819            statement, make_args = cache['statement'][operation]
1820        except KeyError:
1821            statement, make_args = cache['statement'][operation] = \
1822                convert_paramstyle(paramstyle, operation)
1823
1824        args = make_args(vals)
1825        params = self.make_params(args)
1826        key = operation, params
1827
1828        try:
1829            ps = cache['ps'][key]
1830            cursor.ps = ps
1831        except KeyError:
1832            statement_nums = [0]
1833            for style_cache in itervalues(self._caches):
1834                try:
1835                    pid_cache = style_cache[pid]
1836                    for csh in itervalues(pid_cache['ps']):
1837                        statement_nums.append(csh['statement_num'])
1838                except KeyError:
1839                    pass
1840
1841            statement_num = sorted(statement_nums)[-1] + 1
1842            statement_name = '_'.join(
1843                ("pg8000", "statement", str(pid), str(statement_num)))
1844            statement_name_bin = statement_name.encode('ascii') + NULL_BYTE
1845            ps = {
1846                'statement_name_bin': statement_name_bin,
1847                'pid': pid,
1848                'statement_num': statement_num,
1849                'row_desc': [],
1850                'param_funcs': tuple(x[2] for x in params)}
1851            cursor.ps = ps
1852
1853            param_fcs = tuple(x[1] for x in params)
1854
1855            # Byte1('P') - Identifies the message as a Parse command.
1856            # Int32 -   Message length, including self.
1857            # String -  Prepared statement name. An empty string selects the
1858            #           unnamed prepared statement.
1859            # String -  The query string.
1860            # Int16 -   Number of parameter data types specified (can be zero).
1861            # For each parameter:
1862            #   Int32 - The OID of the parameter data type.
1863            val = bytearray(statement_name_bin)
1864            val.extend(statement.encode(self._client_encoding) + NULL_BYTE)
1865            val.extend(h_pack(len(params)))
1866            for oid, fc, send_func in params:
1867                # Parse message doesn't seem to handle the -1 type_oid for NULL
1868                # values that other messages handle.  So we'll provide type_oid
1869                # 705, the PG "unknown" type.
1870                val.extend(i_pack(705 if oid == -1 else oid))
1871
1872            # Byte1('D') - Identifies the message as a describe command.
1873            # Int32 - Message length, including self.
1874            # Byte1 - 'S' for prepared statement, 'P' for portal.
1875            # String - The name of the item to describe.
1876            self._send_message(PARSE, val)
1877            self._send_message(DESCRIBE, STATEMENT + statement_name_bin)
1878            self._write(SYNC_MSG)
1879
1880            try:
1881                self._flush()
1882            except AttributeError as e:
1883                if self._sock is None:
1884                    raise InterfaceError("connection is closed")
1885                else:
1886                    raise e
1887
1888            self.handle_messages(cursor)
1889
1890            # We've got row_desc that allows us to identify what we're
1891            # going to get back from this statement.
1892            output_fc = tuple(
1893                self.pg_types[f['type_oid']][0] for f in ps['row_desc'])
1894
1895            ps['input_funcs'] = tuple(f['func'] for f in ps['row_desc'])
1896            # Byte1('B') - Identifies the Bind command.
1897            # Int32 - Message length, including self.
1898            # String - Name of the destination portal.
1899            # String - Name of the source prepared statement.
1900            # Int16 - Number of parameter format codes.
1901            # For each parameter format code:
1902            #   Int16 - The parameter format code.
1903            # Int16 - Number of parameter values.
1904            # For each parameter value:
1905            #   Int32 - The length of the parameter value, in bytes, not
1906            #           including this length.  -1 indicates a NULL parameter
1907            #           value, in which no value bytes follow.
1908            #   Byte[n] - Value of the parameter.
1909            # Int16 - The number of result-column format codes.
1910            # For each result-column format code:
1911            #   Int16 - The format code.
1912            ps['bind_1'] = NULL_BYTE + statement_name_bin + \
1913                h_pack(len(params)) + \
1914                pack("!" + "h" * len(param_fcs), *param_fcs) + \
1915                h_pack(len(params))
1916
1917            ps['bind_2'] = h_pack(len(output_fc)) + \
1918                pack("!" + "h" * len(output_fc), *output_fc)
1919
1920            if len(cache['ps']) > self.max_prepared_statements:
1921                for p in itervalues(cache['ps']):
1922                    self.close_prepared_statement(p['statement_name_bin'])
1923                cache['ps'].clear()
1924
1925            cache['ps'][key] = ps
1926
1927        cursor._cached_rows.clear()
1928        cursor._row_count = -1
1929
1930        # Byte1('B') - Identifies the Bind command.
1931        # Int32 - Message length, including self.
1932        # String - Name of the destination portal.
1933        # String - Name of the source prepared statement.
1934        # Int16 - Number of parameter format codes.
1935        # For each parameter format code:
1936        #   Int16 - The parameter format code.
1937        # Int16 - Number of parameter values.
1938        # For each parameter value:
1939        #   Int32 - The length of the parameter value, in bytes, not
1940        #           including this length.  -1 indicates a NULL parameter
1941        #           value, in which no value bytes follow.
1942        #   Byte[n] - Value of the parameter.
1943        # Int16 - The number of result-column format codes.
1944        # For each result-column format code:
1945        #   Int16 - The format code.
1946        retval = bytearray(ps['bind_1'])
1947        for value, send_func in zip(args, ps['param_funcs']):
1948            if value is None:
1949                val = NULL
1950            else:
1951                val = send_func(value)
1952                retval.extend(i_pack(len(val)))
1953            retval.extend(val)
1954        retval.extend(ps['bind_2'])
1955
1956        self._send_message(BIND, retval)
1957        self.send_EXECUTE(cursor)
1958        self._write(SYNC_MSG)
1959        self._flush()
1960        self.handle_messages(cursor)
1961
1962    def _send_message(self, code, data):
1963        try:
1964            self._write(code)
1965            self._write(i_pack(len(data) + 4))
1966            self._write(data)
1967            self._write(FLUSH_MSG)
1968        except ValueError as e:
1969            if str(e) == "write to closed file":
1970                raise InterfaceError("connection is closed")
1971            else:
1972                raise e
1973        except AttributeError:
1974            raise InterfaceError("connection is closed")
1975
1976    def send_EXECUTE(self, cursor):
1977        # Byte1('E') - Identifies the message as an execute message.
1978        # Int32 -   Message length, including self.
1979        # String -  The name of the portal to execute.
1980        # Int32 -   Maximum number of rows to return, if portal
1981        #           contains a query # that returns rows.
1982        #           0 = no limit.
1983        self._write(EXECUTE_MSG)
1984        self._write(FLUSH_MSG)
1985
1986    def handle_NO_DATA(self, msg, ps):
1987        pass
1988
1989    def handle_COMMAND_COMPLETE(self, data, cursor):
1990        values = data[:-1].split(BINARY_SPACE)
1991        command = values[0]
1992        if command in self._commands_with_count:
1993            row_count = int(values[-1])
1994            if cursor._row_count == -1:
1995                cursor._row_count = row_count
1996            else:
1997                cursor._row_count += row_count
1998
1999        if command in DDL_COMMANDS:
2000            for scache in itervalues(self._caches):
2001                for pcache in itervalues(scache):
2002                    for ps in itervalues(pcache['ps']):
2003                        self.close_prepared_statement(ps['statement_name_bin'])
2004                    pcache['ps'].clear()
2005
2006    def handle_DATA_ROW(self, data, cursor):
2007        data_idx = 2
2008        row = []
2009        for func in cursor.ps['input_funcs']:
2010            vlen = i_unpack(data, data_idx)[0]
2011            data_idx += 4
2012            if vlen == -1:
2013                row.append(None)
2014            else:
2015                row.append(func(data, data_idx, vlen))
2016                data_idx += vlen
2017        cursor._cached_rows.append(row)
2018
2019    def handle_messages(self, cursor):
2020        code = self.error = None
2021
2022        while code != READY_FOR_QUERY:
2023            code, data_len = ci_unpack(self._read(5))
2024            self.message_types[code](self._read(data_len - 4), cursor)
2025
2026        if self.error is not None:
2027            raise self.error
2028
2029    # Byte1('C') - Identifies the message as a close command.
2030    # Int32 - Message length, including self.
2031    # Byte1 - 'S' for prepared statement, 'P' for portal.
2032    # String - The name of the item to close.
2033    def close_prepared_statement(self, statement_name_bin):
2034        self._send_message(CLOSE, STATEMENT + statement_name_bin)
2035        self._write(SYNC_MSG)
2036        self._flush()
2037        self.handle_messages(self._cursor)
2038
2039    # Byte1('N') - Identifier
2040    # Int32 - Message length
2041    # Any number of these, followed by a zero byte:
2042    #   Byte1 - code identifying the field type (see responseKeys)
2043    #   String - field value
2044    def handle_NOTICE_RESPONSE(self, data, ps):
2045        self.notices.append(
2046            dict((s[0:1], s[1:]) for s in data.split(NULL_BYTE)))
2047
2048    def handle_PARAMETER_STATUS(self, data, ps):
2049        pos = data.find(NULL_BYTE)
2050        key, value = data[:pos], data[pos + 1:-1]
2051        self.parameter_statuses.append((key, value))
2052        if key == b("client_encoding"):
2053            encoding = value.decode("ascii").lower()
2054            self._client_encoding = pg_to_py_encodings.get(encoding, encoding)
2055
2056        elif key == b("integer_datetimes"):
2057            if value == b('on'):
2058
2059                self.py_types[1114] = (1114, FC_BINARY, timestamp_send_integer)
2060                self.pg_types[1114] = (FC_BINARY, timestamp_recv_integer)
2061
2062                self.py_types[1184] = (
2063                    1184, FC_BINARY, timestamptz_send_integer)
2064                self.pg_types[1184] = (FC_BINARY, timestamptz_recv_integer)
2065
2066                self.py_types[Interval] = (
2067                    1186, FC_BINARY, interval_send_integer)
2068                self.py_types[Timedelta] = (
2069                    1186, FC_BINARY, interval_send_integer)
2070                self.pg_types[1186] = (FC_BINARY, interval_recv_integer)
2071            else:
2072                self.py_types[1114] = (1114, FC_BINARY, timestamp_send_float)
2073                self.pg_types[1114] = (FC_BINARY, timestamp_recv_float)
2074                self.py_types[1184] = (1184, FC_BINARY, timestamptz_send_float)
2075                self.pg_types[1184] = (FC_BINARY, timestamptz_recv_float)
2076
2077                self.py_types[Interval] = (
2078                    1186, FC_BINARY, interval_send_float)
2079                self.py_types[Timedelta] = (
2080                    1186, FC_BINARY, interval_send_float)
2081                self.pg_types[1186] = (FC_BINARY, interval_recv_float)
2082
2083        elif key == b("server_version"):
2084            self._server_version = LooseVersion(value.decode('ascii'))
2085            if self._server_version < LooseVersion('8.2.0'):
2086                self._commands_with_count = (
2087                    b("INSERT"), b("DELETE"), b("UPDATE"), b("MOVE"),
2088                    b("FETCH"))
2089            elif self._server_version < LooseVersion('9.0.0'):
2090                self._commands_with_count = (
2091                    b("INSERT"), b("DELETE"), b("UPDATE"), b("MOVE"),
2092                    b("FETCH"), b("COPY"))
2093
2094    def array_inspect(self, value):
2095        # Check if array has any values. If empty, we can just assume it's an
2096        # array of strings
2097        first_element = array_find_first_element(value)
2098        if first_element is None:
2099            oid = 25
2100            # Use binary ARRAY format to avoid having to properly
2101            # escape text in the array literals
2102            fc = FC_BINARY
2103            array_oid = pg_array_types[oid]
2104        else:
2105            # supported array output
2106            typ = type(first_element)
2107
2108            if issubclass(typ, integer_types):
2109                # special int array support -- send as smallest possible array
2110                # type
2111                typ = integer_types
2112                int2_ok, int4_ok, int8_ok = True, True, True
2113                for v in array_flatten(value):
2114                    if v is None:
2115                        continue
2116                    if min_int2 < v < max_int2:
2117                        continue
2118                    int2_ok = False
2119                    if min_int4 < v < max_int4:
2120                        continue
2121                    int4_ok = False
2122                    if min_int8 < v < max_int8:
2123                        continue
2124                    int8_ok = False
2125                if int2_ok:
2126                    array_oid = 1005  # INT2[]
2127                    oid, fc, send_func = (21, FC_BINARY, h_pack)
2128                elif int4_ok:
2129                    array_oid = 1007  # INT4[]
2130                    oid, fc, send_func = (23, FC_BINARY, i_pack)
2131                elif int8_ok:
2132                    array_oid = 1016  # INT8[]
2133                    oid, fc, send_func = (20, FC_BINARY, q_pack)
2134                else:
2135                    raise ArrayContentNotSupportedError(
2136                        "numeric not supported as array contents")
2137            else:
2138                try:
2139                    oid, fc, send_func = self.make_params((first_element,))[0]
2140
2141                    # If unknown or string, assume it's a string array
2142                    if oid in (705, 1043, 25):
2143                        oid = 25
2144                        # Use binary ARRAY format to avoid having to properly
2145                        # escape text in the array literals
2146                        fc = FC_BINARY
2147                    array_oid = pg_array_types[oid]
2148                except KeyError:
2149                    raise ArrayContentNotSupportedError(
2150                        "oid " + str(oid) + " not supported as array contents")
2151                except NotSupportedError:
2152                    raise ArrayContentNotSupportedError(
2153                        "type " + str(typ) +
2154                        " not supported as array contents")
2155        if fc == FC_BINARY:
2156            def send_array(arr):
2157                # check that all array dimensions are consistent
2158                array_check_dimensions(arr)
2159
2160                has_null = array_has_null(arr)
2161                dim_lengths = array_dim_lengths(arr)
2162                data = bytearray(iii_pack(len(dim_lengths), has_null, oid))
2163                for i in dim_lengths:
2164                    data.extend(ii_pack(i, 1))
2165                for v in array_flatten(arr):
2166                    if v is None:
2167                        data += i_pack(-1)
2168                    elif isinstance(v, typ):
2169                        inner_data = send_func(v)
2170                        data += i_pack(len(inner_data))
2171                        data += inner_data
2172                    else:
2173                        raise ArrayContentNotHomogenousError(
2174                            "not all array elements are of type " + str(typ))
2175                return data
2176        else:
2177            def send_array(arr):
2178                array_check_dimensions(arr)
2179                ar = deepcopy(arr)
2180                for a, i, v in walk_array(ar):
2181                    if v is None:
2182                        a[i] = 'NULL'
2183                    elif isinstance(v, typ):
2184                        a[i] = send_func(v).decode('ascii')
2185                    else:
2186                        raise ArrayContentNotHomogenousError(
2187                            "not all array elements are of type " + str(typ))
2188                return u(str(ar)).translate(arr_trans).encode('ascii')
2189
2190        return (array_oid, fc, send_array)
2191
2192    def xid(self, format_id, global_transaction_id, branch_qualifier):
2193        """Create a Transaction IDs (only global_transaction_id is used in pg)
2194        format_id and branch_qualifier are not used in postgres
2195        global_transaction_id may be any string identifier supported by
2196        postgres returns a tuple
2197        (format_id, global_transaction_id, branch_qualifier)"""
2198        return (format_id, global_transaction_id, branch_qualifier)
2199
2200    def tpc_begin(self, xid):
2201        """Begins a TPC transaction with the given transaction ID xid.
2202
2203        This method should be called outside of a transaction (i.e. nothing may
2204        have executed since the last .commit() or .rollback()).
2205
2206        Furthermore, it is an error to call .commit() or .rollback() within the
2207        TPC transaction. A ProgrammingError is raised, if the application calls
2208        .commit() or .rollback() during an active TPC transaction.
2209
2210        This function is part of the `DBAPI 2.0 specification
2211        <http://www.python.org/dev/peps/pep-0249/>`_.
2212        """
2213        self._xid = xid
2214        if self.autocommit:
2215            self.execute(self._cursor, "begin transaction", None)
2216
2217    def tpc_prepare(self):
2218        """Performs the first phase of a transaction started with .tpc_begin().
2219        A ProgrammingError is be raised if this method is called outside of a
2220        TPC transaction.
2221
2222        After calling .tpc_prepare(), no statements can be executed until
2223        .tpc_commit() or .tpc_rollback() have been called.
2224
2225        This function is part of the `DBAPI 2.0 specification
2226        <http://www.python.org/dev/peps/pep-0249/>`_.
2227        """
2228        q = "PREPARE TRANSACTION '%s';" % (self._xid[1],)
2229        self.execute(self._cursor, q, None)
2230
2231    def tpc_commit(self, xid=None):
2232        """When called with no arguments, .tpc_commit() commits a TPC
2233        transaction previously prepared with .tpc_prepare().
2234
2235        If .tpc_commit() is called prior to .tpc_prepare(), a single phase
2236        commit is performed. A transaction manager may choose to do this if
2237        only a single resource is participating in the global transaction.
2238
2239        When called with a transaction ID xid, the database commits the given
2240        transaction. If an invalid transaction ID is provided, a
2241        ProgrammingError will be raised. This form should be called outside of
2242        a transaction, and is intended for use in recovery.
2243
2244        On return, the TPC transaction is ended.
2245
2246        This function is part of the `DBAPI 2.0 specification
2247        <http://www.python.org/dev/peps/pep-0249/>`_.
2248        """
2249        if xid is None:
2250            xid = self._xid
2251
2252        if xid is None:
2253            raise ProgrammingError(
2254                "Cannot tpc_commit() without a TPC transaction!")
2255
2256        try:
2257            previous_autocommit_mode = self.autocommit
2258            self.autocommit = True
2259            if xid in self.tpc_recover():
2260                self.execute(
2261                    self._cursor, "COMMIT PREPARED '%s';" % (xid[1], ),
2262                    None)
2263            else:
2264                # a single-phase commit
2265                self.commit()
2266        finally:
2267            self.autocommit = previous_autocommit_mode
2268        self._xid = None
2269
2270    def tpc_rollback(self, xid=None):
2271        """When called with no arguments, .tpc_rollback() rolls back a TPC
2272        transaction. It may be called before or after .tpc_prepare().
2273
2274        When called with a transaction ID xid, it rolls back the given
2275        transaction. If an invalid transaction ID is provided, a
2276        ProgrammingError is raised. This form should be called outside of a
2277        transaction, and is intended for use in recovery.
2278
2279        On return, the TPC transaction is ended.
2280
2281        This function is part of the `DBAPI 2.0 specification
2282        <http://www.python.org/dev/peps/pep-0249/>`_.
2283        """
2284        if xid is None:
2285            xid = self._xid
2286
2287        if xid is None:
2288            raise ProgrammingError(
2289                "Cannot tpc_rollback() without a TPC prepared transaction!")
2290
2291        try:
2292            previous_autocommit_mode = self.autocommit
2293            self.autocommit = True
2294            if xid in self.tpc_recover():
2295                # a two-phase rollback
2296                self.execute(
2297                    self._cursor, "ROLLBACK PREPARED '%s';" % (xid[1],),
2298                    None)
2299            else:
2300                # a single-phase rollback
2301                self.rollback()
2302        finally:
2303            self.autocommit = previous_autocommit_mode
2304        self._xid = None
2305
2306    def tpc_recover(self):
2307        """Returns a list of pending transaction IDs suitable for use with
2308        .tpc_commit(xid) or .tpc_rollback(xid).
2309
2310        This function is part of the `DBAPI 2.0 specification
2311        <http://www.python.org/dev/peps/pep-0249/>`_.
2312        """
2313        try:
2314            previous_autocommit_mode = self.autocommit
2315            self.autocommit = True
2316            curs = self.cursor()
2317            curs.execute("select gid FROM pg_prepared_xacts")
2318            return [self.xid(0, row[0], '') for row in curs]
2319        finally:
2320            self.autocommit = previous_autocommit_mode
2321
2322
2323# pg element oid -> pg array typeoid
2324pg_array_types = {
2325    16: 1000,
2326    25: 1009,    # TEXT[]
2327    701: 1022,
2328    1043: 1009,
2329    1700: 1231,  # NUMERIC[]
2330}
2331
2332
2333# PostgreSQL encodings:
2334#   http://www.postgresql.org/docs/8.3/interactive/multibyte.html
2335# Python encodings:
2336#   http://www.python.org/doc/2.4/lib/standard-encodings.html
2337#
2338# Commented out encodings don't require a name change between PostgreSQL and
2339# Python.  If the py side is None, then the encoding isn't supported.
2340pg_to_py_encodings = {
2341    # Not supported:
2342    "mule_internal": None,
2343    "euc_tw": None,
2344
2345    # Name fine as-is:
2346    # "euc_jp",
2347    # "euc_jis_2004",
2348    # "euc_kr",
2349    # "gb18030",
2350    # "gbk",
2351    # "johab",
2352    # "sjis",
2353    # "shift_jis_2004",
2354    # "uhc",
2355    # "utf8",
2356
2357    # Different name:
2358    "euc_cn": "gb2312",
2359    "iso_8859_5": "is8859_5",
2360    "iso_8859_6": "is8859_6",
2361    "iso_8859_7": "is8859_7",
2362    "iso_8859_8": "is8859_8",
2363    "koi8": "koi8_r",
2364    "latin1": "iso8859-1",
2365    "latin2": "iso8859_2",
2366    "latin3": "iso8859_3",
2367    "latin4": "iso8859_4",
2368    "latin5": "iso8859_9",
2369    "latin6": "iso8859_10",
2370    "latin7": "iso8859_13",
2371    "latin8": "iso8859_14",
2372    "latin9": "iso8859_15",
2373    "sql_ascii": "ascii",
2374    "win866": "cp886",
2375    "win874": "cp874",
2376    "win1250": "cp1250",
2377    "win1251": "cp1251",
2378    "win1252": "cp1252",
2379    "win1253": "cp1253",
2380    "win1254": "cp1254",
2381    "win1255": "cp1255",
2382    "win1256": "cp1256",
2383    "win1257": "cp1257",
2384    "win1258": "cp1258",
2385    "unicode": "utf-8",  # Needed for Amazon Redshift
2386}
2387
2388
2389def walk_array(arr):
2390    for i, v in enumerate(arr):
2391        if isinstance(v, list):
2392            for a, i2, v2 in walk_array(v):
2393                yield a, i2, v2
2394        else:
2395            yield arr, i, v
2396
2397
2398def array_find_first_element(arr):
2399    for v in array_flatten(arr):
2400        if v is not None:
2401            return v
2402    return None
2403
2404
2405def array_flatten(arr):
2406    for v in arr:
2407        if isinstance(v, list):
2408            for v2 in array_flatten(v):
2409                yield v2
2410        else:
2411            yield v
2412
2413
2414def array_check_dimensions(arr):
2415    if len(arr) > 0:
2416        v0 = arr[0]
2417        if isinstance(v0, list):
2418            req_len = len(v0)
2419            req_inner_lengths = array_check_dimensions(v0)
2420            for v in arr:
2421                inner_lengths = array_check_dimensions(v)
2422                if len(v) != req_len or inner_lengths != req_inner_lengths:
2423                    raise ArrayDimensionsNotConsistentError(
2424                        "array dimensions not consistent")
2425            retval = [req_len]
2426            retval.extend(req_inner_lengths)
2427            return retval
2428        else:
2429            # make sure nothing else at this level is a list
2430            for v in arr:
2431                if isinstance(v, list):
2432                    raise ArrayDimensionsNotConsistentError(
2433                        "array dimensions not consistent")
2434    return []
2435
2436
2437def array_has_null(arr):
2438    for v in array_flatten(arr):
2439        if v is None:
2440            return True
2441    return False
2442
2443
2444def array_dim_lengths(arr):
2445    len_arr = len(arr)
2446    retval = [len_arr]
2447    if len_arr > 0:
2448        v0 = arr[0]
2449        if isinstance(v0, list):
2450            retval.extend(array_dim_lengths(v0))
2451    return retval
2452