1"""Miscellaneous goodies for psycopg2
2
3This module is a generic place used to hold little helper functions
and classes until a better place in the distribution is found.
5"""
6# psycopg/extras.py - miscellaneous extra goodies for psycopg
7#
8# Copyright (C) 2003-2010 Federico Di Gregorio  <fog@debian.org>
9#
10# psycopg2 is free software: you can redistribute it and/or modify it
11# under the terms of the GNU Lesser General Public License as published
12# by the Free Software Foundation, either version 3 of the License, or
13# (at your option) any later version.
14#
15# In addition, as a special exception, the copyright holders give
16# permission to link this program with the OpenSSL library (or with
17# modified versions of OpenSSL that use the same license as OpenSSL),
18# and distribute linked combinations including the two.
19#
20# You must obey the GNU Lesser General Public License in all respects for
21# all of the code used other than OpenSSL.
22#
23# psycopg2 is distributed in the hope that it will be useful, but WITHOUT
24# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
25# FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Lesser General Public
26# License for more details.
27
28from __future__ import unicode_literals
29
30import os as _os
31import sys as _sys
32import time as _time
33import re as _re
34import six
35
36try:
37    import logging as _logging
38except:
39    _logging = None
40
41import psycopg2cffi as psycopg2
42from psycopg2cffi import extensions as _ext
43from psycopg2cffi.extensions import cursor as _cursor
44from psycopg2cffi.extensions import connection as _connection
45from psycopg2cffi.extensions import adapt as _A
46from psycopg2cffi._impl.adapters import ascii_to_bytes, bytes_to_ascii
47
48
class DictCursorBase(_cursor):
    """Common base for the cursors returning dictionary-like rows.

    Subclasses must supply a ``row_factory`` keyword argument and implement
    ``_build_index()``.  The ``_prefetch`` flag decides whether the parent
    fetch runs before (true) or after (false) the column index is built.
    """

    def __init__(self, *args, **kwargs):
        if 'row_factory' not in kwargs:
            raise NotImplementedError(
                "DictCursorBase can't be instantiated without a row factory.")
        row_factory = kwargs.pop('row_factory')
        super(DictCursorBase, self).__init__(*args, **kwargs)
        self._query_executed = 0
        self._prefetch = 0
        self.row_factory = row_factory

    def _fetch_and_index(self, fetch, *args):
        # Run `fetch` either before or after building the index, depending
        # on the subclass' _prefetch setting, and return what it produced.
        if self._prefetch:
            rows = fetch(*args)
        if self._query_executed:
            self._build_index()
        if not self._prefetch:
            rows = fetch(*args)
        return rows

    def fetchone(self):
        return self._fetch_and_index(super(DictCursorBase, self).fetchone)

    def fetchmany(self, size=None):
        return self._fetch_and_index(
            super(DictCursorBase, self).fetchmany, size)

    def fetchall(self):
        return self._fetch_and_index(super(DictCursorBase, self).fetchall)

    def __iter__(self):
        if self._prefetch:
            rows = super(DictCursorBase, self).__iter__()
            try:
                first = six.next(rows)
            except StopIteration:
                return
        if self._query_executed:
            self._build_index()
        if not self._prefetch:
            rows = super(DictCursorBase, self).__iter__()
            try:
                first = six.next(rows)
            except StopIteration:
                return

        yield first
        # `rows` is already an iterator; a for-loop drains it and stops
        # cleanly on exhaustion.
        for row in rows:
            yield row
113
114
class DictConnection(_connection):
    """A connection whose cursors default to `DictCursor`."""
    def cursor(self, *args, **kwargs):
        # An explicit cursor_factory (even None) from the caller wins.
        if 'cursor_factory' not in kwargs:
            kwargs['cursor_factory'] = DictCursor
        return super(DictConnection, self).cursor(*args, **kwargs)
120
class DictCursor(DictCursorBase):
    """A cursor that keeps a column name -> position mapping for its rows.

    Rows are `DictRow` instances sharing the mapping stored in ``index``.
    """

    def __init__(self, *args, **kwargs):
        kwargs['row_factory'] = DictRow
        super(DictCursor, self).__init__(*args, **kwargs)
        self._prefetch = 1

    def _reset_index(self):
        # A brand-new dict per query: rows fetched earlier keep a reference
        # to their own mapping (DictRow stores cursor.index by reference).
        self.index = {}
        self._query_executed = 1

    def execute(self, query, vars=None):
        self._reset_index()
        return super(DictCursor, self).execute(query, vars)

    def callproc(self, procname, vars=None):
        self._reset_index()
        return super(DictCursor, self).callproc(procname, vars)

    def _build_index(self):
        """Fill ``index`` from ``description`` once per executed query."""
        if self._query_executed != 1:
            return
        description = self.description
        if not description:
            return
        for position, column in enumerate(description):
            self.index[column[0]] = position
        self._query_executed = 0
144
class DictRow(list):
    """A row object that allows by-column-name access to data.

    The row is a list; string keys are translated to positions through the
    name -> index mapping shared with the cursor that produced it.
    """

    __slots__ = ('_index',)

    def __init__(self, cursor):
        self._index = cursor.index
        self[:] = [None] * len(cursor.description)

    def __getitem__(self, x):
        if not isinstance(x, (int, slice)):
            x = self._index[x]
        return list.__getitem__(self, x)

    def __setitem__(self, x, v):
        if not isinstance(x, (int, slice)):
            x = self._index[x]
        list.__setitem__(self, x, v)

    def items(self):
        return list(self.iteritems())

    def keys(self):
        return self._index.keys()

    def values(self):
        return tuple(self[:])

    def has_key(self, x):
        return x in self._index

    def get(self, x, default=None):
        """Return the value for name/position *x*, or *default* if missing.

        Any lookup error (unknown name, out-of-range index, unhashable key)
        yields *default*; unlike a bare ``except:``, KeyboardInterrupt and
        SystemExit are not swallowed.
        """
        try:
            return self[x]
        except Exception:
            return default

    def iteritems(self):
        for n, v in self._index.items():
            yield n, list.__getitem__(self, v)

    def iterkeys(self):
        return iter(self._index)

    def itervalues(self):
        return list.__iter__(self)

    def copy(self):
        """Return a plain ``dict`` of column name -> value.

        Goes through ``items()`` rather than ``iteritems()`` because the
        latter is removed on Python 3 (see the end of the class body).
        """
        return dict(self.items())

    def __contains__(self, x):
        return x in self._index

    def __getstate__(self):
        return self[:], self._index.copy()

    def __setstate__(self, data):
        self[:] = data[0]
        self._index = data[1]

    # drop the crusty Py2 methods
    if _sys.version_info[0] > 2:
        items = iteritems; del iteritems
        keys = iterkeys; del iterkeys
        values = itervalues; del itervalues
        del has_key
211
212
class RealDictConnection(_connection):
    """A connection whose cursors default to `RealDictCursor`."""
    def cursor(self, *args, **kwargs):
        # An explicit cursor_factory (even None) from the caller wins.
        if 'cursor_factory' not in kwargs:
            kwargs['cursor_factory'] = RealDictCursor
        return super(RealDictConnection, self).cursor(*args, **kwargs)
218
class RealDictCursor(DictCursorBase):
    """A cursor that uses a real dict as the base type for rows.

    Note that this cursor is extremely specialized and does not allow
    the normal access (using integer indices) to fetched data. If you need
    to access database rows both as a dictionary and a list, then use
    the generic `DictCursor` instead of `!RealDictCursor`.
    """
    def __init__(self, *args, **kwargs):
        kwargs['row_factory'] = RealDictRow
        super(RealDictCursor, self).__init__(*args, **kwargs)
        self._prefetch = 0

    def _reset_mapping(self):
        # Fresh list per query; RealDictRow keeps a reference to it.
        self.column_mapping = []
        self._query_executed = 1

    def execute(self, query, vars=None):
        self._reset_mapping()
        return super(RealDictCursor, self).execute(query, vars)

    def callproc(self, procname, vars=None):
        self._reset_mapping()
        return super(RealDictCursor, self).callproc(procname, vars)

    def _build_index(self):
        """Record the column names, in order, once per executed query."""
        if self._query_executed != 1:
            return
        description = self.description
        if not description:
            return
        self.column_mapping.extend(column[0] for column in description)
        self._query_executed = 0
247
class RealDictRow(dict):
    """A `!dict` subclass representing a data record.

    Values assigned by integer position are stored under the matching
    column name taken from the cursor's ``column_mapping`` list.
    """

    # A one-element tuple: without the trailing comma this would be a bare
    # string, which only works by accident of __slots__ accepting strings.
    __slots__ = ('_column_mapping',)

    def __init__(self, cursor):
        dict.__init__(self)
        # Required for named cursors
        if cursor.description and not cursor.column_mapping:
            cursor._build_index()

        self._column_mapping = cursor.column_mapping

    def __setitem__(self, name, value):
        # Exact ints are positions and get translated to column names;
        # anything else is used as the key as-is.
        if type(name) is int:
            name = self._column_mapping[name]
        return dict.__setitem__(self, name, value)

    def __getstate__(self):
        return (self.copy(), self._column_mapping[:])

    def __setstate__(self, data):
        self.update(data[0])
        self._column_mapping = data[1]
272
273
class NamedTupleConnection(_connection):
    """A connection whose cursors default to `NamedTupleCursor`."""
    def cursor(self, *args, **kwargs):
        # An explicit cursor_factory (even None) from the caller wins.
        if 'cursor_factory' not in kwargs:
            kwargs['cursor_factory'] = NamedTupleCursor
        return super(NamedTupleConnection, self).cursor(*args, **kwargs)
279
class NamedTupleCursor(_cursor):
    """A cursor that generates results as `~collections.namedtuple`.

    `!fetch*()` methods will return named tuples instead of regular tuples, so
    their elements can be accessed both as regular numeric items as well as
    attributes.

        >>> nt_cur = conn.cursor(cursor_factory=psycopg2.extras.NamedTupleCursor)
        >>> rec = nt_cur.fetchone()
        >>> rec
        Record(id=1, num=100, data="abc'def")
        >>> rec[1]
        100
        >>> rec.data
        "abc'def"
    """
    # The namedtuple class for the current result; reset on every new query
    # and built lazily from ``description`` on the first fetch.
    Record = None

    def execute(self, query, vars=None):
        self.Record = None
        return super(NamedTupleCursor, self).execute(query, vars)

    def executemany(self, query, vars):
        self.Record = None
        return super(NamedTupleCursor, self).executemany(query, vars)

    def callproc(self, procname, vars=None):
        self.Record = None
        return super(NamedTupleCursor, self).callproc(procname, vars)

    def fetchone(self):
        t = super(NamedTupleCursor, self).fetchone()
        if t is not None:
            nt = self.Record
            if nt is None:
                nt = self.Record = self._make_nt()
            return nt._make(t)
        return None

    def fetchmany(self, size=None):
        ts = super(NamedTupleCursor, self).fetchmany(size)
        nt = self.Record
        if nt is None:
            nt = self.Record = self._make_nt()
        return [nt._make(x) for x in ts]

    def fetchall(self):
        ts = super(NamedTupleCursor, self).fetchall()
        nt = self.Record
        if nt is None:
            nt = self.Record = self._make_nt()
        return [nt._make(x) for x in ts]

    def __iter__(self):
        it = super(NamedTupleCursor, self).__iter__()
        try:
            t = six.next(it)
        except StopIteration:
            # Empty result: end the generator normally.  Letting
            # StopIteration escape a generator body is turned into
            # RuntimeError by PEP 479 (the default since Python 3.7).
            return

        nt = self.Record
        if nt is None:
            nt = self.Record = self._make_nt()

        yield nt._make(t)

        for t in it:
            yield nt._make(t)

    try:
        from collections import namedtuple
    except ImportError as _exc:
        # The ``as`` target is unbound when the except clause ends (Python
        # 3), so keep our own reference for _make_nt() to re-raise.
        _namedtuple_import_error = _exc

        def _make_nt(self):
            raise self._namedtuple_import_error
    else:
        def _make_nt(self, namedtuple=namedtuple):
            return namedtuple("Record", [d[0] for d in self.description or ()])
358
359
class LoggingConnection(_connection):
    """A connection that logs all queries to a file or logger__ object.

    .. __: http://docs.python.org/library/logging.html
    """

    def initialize(self, logobj):
        """Initialize the connection to log to `!logobj`.

        The `!logobj` parameter can be an open file object or a Logger
        instance from the standard logging module.
        """
        self._logobj = logobj
        if _logging and isinstance(logobj, _logging.Logger):
            self.log = self._logtologger
        else:
            self.log = self._logtofile

    def filter(self, msg, curs):
        """Filter the query before logging it.

        This is the method to overwrite to filter unwanted queries out of the
        log or to add some extra data to the output. The default implementation
        just does nothing.
        """
        return msg

    def _logtofile(self, msg, curs):
        msg = self.filter(msg, curs)
        if msg:
            # On Python 3 the query text may arrive as bytes, and
            # ``bytes + str`` raises TypeError; decode with the connection
            # encoding (replacing undecodable bytes) before writing.
            if _sys.version_info[0] >= 3 and isinstance(msg, bytes):
                msg = msg.decode(_ext.encodings[self.encoding], 'replace')
            self._logobj.write(msg + _os.linesep)

    def _logtologger(self, msg, curs):
        msg = self.filter(msg, curs)
        if msg:
            self._logobj.debug(msg)

    def _check(self):
        if not hasattr(self, '_logobj'):
            raise self.ProgrammingError(
                "LoggingConnection object has not been initialize()d")

    def cursor(self, *args, **kwargs):
        self._check()
        kwargs.setdefault('cursor_factory', LoggingCursor)
        return super(LoggingConnection, self).cursor(*args, **kwargs)
404
class LoggingCursor(_cursor):
    """A cursor that logs queries using its connection logging facilities.

    The last query is handed to ``connection.log`` after every execute()
    and callproc() call, whether the call returned or raised.
    """

    def _log_last_query(self):
        self.connection.log(self.query, self)

    def execute(self, query, vars=None):
        try:
            return super(LoggingCursor, self).execute(query, vars)
        finally:
            self._log_last_query()

    def callproc(self, procname, vars=None):
        try:
            return super(LoggingCursor, self).callproc(procname, vars)
        finally:
            self._log_last_query()
419
420
class MinTimeLoggingConnection(LoggingConnection):
    """A connection that logs queries based on execution time.

    This is just an example of how to sub-class `LoggingConnection` to
    provide some extra filtering for the logged queries. Both the
    `initialize()` and `filter()` methods are overwritten to make sure
    that only queries executing for more than ``mintime`` ms are logged.

    Note that this connection uses the specialized cursor
    `MinTimeLoggingCursor`.
    """
    def initialize(self, logobj, mintime=0):
        LoggingConnection.initialize(self, logobj)
        self._mintime = mintime

    def filter(self, msg, curs):
        """Return *msg* annotated with its execution time, or None.

        Queries faster than ``mintime`` milliseconds (and a missing query)
        yield None, which the logging methods skip.
        """
        # The cursor logs from a ``finally`` block, so the query may be None
        # if execution failed early; don't mask that error with a TypeError.
        if msg is None:
            return None
        t = (_time.time() - curs.timestamp) * 1000
        if t > self._mintime:
            # Python 3: the query may be bytes; concatenating it with str
            # would raise TypeError.
            if _sys.version_info[0] >= 3 and isinstance(msg, bytes):
                msg = msg.decode(_ext.encodings[self.encoding], 'replace')
            # str() keeps the suffix a native string, so on Python 2 a
            # byte-string query is not implicitly ASCII-decoded.
            return msg + _os.linesep + str("  (execution time: %d ms)" % t)

    def cursor(self, *args, **kwargs):
        kwargs.setdefault('cursor_factory', MinTimeLoggingCursor)
        return LoggingConnection.cursor(self, *args, **kwargs)
444
class MinTimeLoggingCursor(LoggingCursor):
    """The cursor sub-class companion to `MinTimeLoggingConnection`.

    Records a start timestamp before each call so that the connection's
    ``filter()`` can compute the execution time.
    """

    def execute(self, query, vars=None):
        self.timestamp = _time.time()
        return LoggingCursor.execute(self, query, vars)

    def callproc(self, procname, vars=None):
        self.timestamp = _time.time()
        # Must delegate to callproc: routing through execute() would send
        # the bare procedure name to the server as a SQL statement.
        return LoggingCursor.callproc(self, procname, vars)
455
456
457# a dbtype and adapter for Python UUID type
458
class UUID_adapter(object):
    """Adapt Python's uuid.UUID__ type to PostgreSQL's uuid__.

    .. __: http://docs.python.org/library/uuid.html
    .. __: http://www.postgresql.org/docs/current/static/datatype-uuid.html
    """

    def __init__(self, uuid):
        self._uuid = uuid

    def __conform__(self, proto):
        # Only the ISQLQuote protocol is handled; other protocols get None.
        return self if proto is _ext.ISQLQuote else None

    def getquoted(self):
        """Return the quoted literal ``'<uuid>'::uuid`` as bytes."""
        text = ascii_to_bytes(str(self._uuid))
        return b''.join((b"'", text, b"'::uuid"))

    def __bytes__(self):
        return self.getquoted()

    def __str__(self):
        return "'%s'::uuid" % self._uuid
481
482
def register_uuid(oids=None, conn_or_curs=None):
    """Create the UUID type and an uuid.UUID adapter.

    :param oids: oid for the PostgreSQL :sql:`uuid` type, or 2-items sequence
        with oids of the type and the array. If not specified, use PostgreSQL
        standard oids.
    :param conn_or_curs: where to register the typecaster. If not specified,
        register it globally.
    :return: the UUID typecaster (also stored as ``extensions.UUID``).
    """

    import uuid

    # Resolve (type oid, array oid); 2950/2951 are the built-in uuid oids.
    if not oids:
        type_oid, array_oid = 2950, 2951
    elif isinstance(oids, (list, tuple)):
        type_oid, array_oid = oids
    else:
        type_oid, array_oid = oids, 2951

    _ext.UUID = _ext.new_type((type_oid, ), "UUID",
            lambda data, cursor: uuid.UUID(data) if data else None)
    _ext.UUIDARRAY = _ext.new_array_type((array_oid,), "UUID[]", _ext.UUID)

    for caster in (_ext.UUID, _ext.UUIDARRAY):
        _ext.register_type(caster, conn_or_curs)
    _ext.register_adapter(uuid.UUID, UUID_adapter)

    return _ext.UUID
513
514
515# a type, dbtype and adapter for PostgreSQL inet type
516
class Inet(object):
    """Wrap a string to allow for correct SQL-quoting of inet values.

    Note that this adapter does NOT check the passed value to make
    sure it really is an inet-compatible address but DOES call adapt()
    on it to make sure it is impossible to execute an SQL-injection
    by passing an evil value to the initializer.
    """
    # Connection supplied by prepare(); None until then so that
    # getquoted() also works on an adapter that was never prepared.
    _conn = None

    def __init__(self, addr):
        self.addr = addr

    def __repr__(self):
        return "%s(%r)" % (self.__class__.__name__, self.addr)

    def prepare(self, conn):
        self._conn = conn

    def getquoted(self):
        """Return the quoted address followed by ``::inet``, as bytes."""
        obj = _A(self.addr)
        # Forward the connection only if we have one; without it the inner
        # adapter quotes with its defaults instead of raising AttributeError.
        if self._conn is not None and hasattr(obj, 'prepare'):
            obj.prepare(self._conn)
        return obj.getquoted() + b"::inet"

    def __conform__(self, proto):
        if proto is _ext.ISQLQuote:
            return self

    def __str__(self):
        return str(self.addr)
546
def register_inet(oid=None, conn_or_curs=None):
    """Create the INET type and an Inet adapter.

    :param oid: oid for the PostgreSQL :sql:`inet` type, or 2-items sequence
        with oids of the type and the array. If not specified, use PostgreSQL
        standard oids.
    :param conn_or_curs: where to register the typecaster. If not specified,
        register it globally.
    :return: the INET typecaster (also stored as ``extensions.INET``).
    """
    # Resolve (type oid, array oid); 869/1041 are the built-in inet oids.
    if not oid:
        type_oid, array_oid = 869, 1041
    elif isinstance(oid, (list, tuple)):
        type_oid, array_oid = oid
    else:
        type_oid, array_oid = oid, 1041

    _ext.INET = _ext.new_type((type_oid, ), "INET",
            lambda data, cursor: Inet(data) if data else None)
    _ext.INETARRAY = _ext.new_array_type((array_oid, ), "INETARRAY", _ext.INET)

    for caster in (_ext.INET, _ext.INETARRAY):
        _ext.register_type(caster, conn_or_curs)

    return _ext.INET
573
574
def register_tstz_w_secs(oids=None, conn_or_curs=None):
    """The function used to register an alternate type caster for
    :sql:`TIMESTAMP WITH TIME ZONE` to deal with historical time zones with
    seconds in the UTC offset.

    These are now correctly handled by the default type caster, so currently
    the function doesn't do anything except emit a `DeprecationWarning`.
    Both parameters are accepted for backward compatibility and ignored.
    """
    import warnings
    # stacklevel=2 attributes the warning to the caller's line, which is
    # the code that needs to stop calling this function.
    warnings.warn("deprecated", DeprecationWarning, stacklevel=2)
585
586
def wait_select(conn):
    """Wait until a connection or cursor has data available.

    The function is an example of a wait callback to be registered with
    `~psycopg2.extensions.set_wait_callback()`. This function uses
    :py:func:`~select.select()` to wait for data available.

    Raises ``conn.OperationalError`` if ``poll()`` returns an unknown state.
    """
    import select
    from psycopg2cffi.extensions import POLL_OK, POLL_READ, POLL_WRITE

    while True:
        state = conn.poll()
        if state == POLL_OK:
            return
        if state == POLL_READ:
            select.select([conn.fileno()], [], [])
        elif state == POLL_WRITE:
            select.select([], [conn.fileno()], [])
        else:
            raise conn.OperationalError("bad state from poll: %s" % state)
608
609
def _solve_conn_curs(conn_or_curs):
    """Return the connection and a DBAPI cursor from a connection or cursor.

    A new plain cursor (``cursor_factory=_cursor``) is always opened on the
    connection, even when a cursor was passed in.
    """
    if conn_or_curs is None:
        raise psycopg2.ProgrammingError("no connection or cursor provided")

    # Anything exposing ``execute`` is treated as a cursor.
    if hasattr(conn_or_curs, 'execute'):
        conn = conn_or_curs.connection
    else:
        conn = conn_or_curs

    return conn, conn.cursor(cursor_factory=_cursor)
623
624
class HstoreAdapter(object):
    """Adapt a Python dict to the hstore syntax."""
    def __init__(self, wrapped):
        self.wrapped = wrapped
        self.unicode = False

    def prepare(self, conn):
        self.conn = conn

        # use an old-style getquoted implementation if required
        if conn.server_version < 90000:
            self.getquoted = self._getquoted_8

    def _getquoted_8(self):
        """Use the operators available in PG pre-9.0."""
        if not self.wrapped:
            return b"''::hstore"

        adapt = _ext.adapt
        rv = []
        for k, v in six.iteritems(self.wrapped):
            k = adapt(k)
            k.prepare(self.conn)
            k = k.getquoted()

            if v is not None:
                v = adapt(v)
                v.prepare(self.conn)
                v = v.getquoted()
            else:
                v = b'NULL'

            rv.append(b"(" + k + b" => " + v + b")")

        return b"(" + b'||'.join(rv) + b")"

    def _getquoted_9(self):
        """Use the hstore(text[], text[]) function."""
        if not self.wrapped:
            return b"''::hstore"

        k = _ext.adapt(list(self.wrapped.keys()))
        k.prepare(self.conn)
        v = _ext.adapt(list(self.wrapped.values()))
        v.prepare(self.conn)
        return b"hstore(" + k.getquoted() + b", " + v.getquoted() + b")"

    getquoted = _getquoted_9

    _re_hstore = _re.compile(r"""
        # hstore key:
        # a string of normal or escaped chars
        "((?: [^"\\] | \\. )*)"
        \s*=>\s* # hstore value
        (?:
            NULL # the value can be null - not catched
            # or a quoted string like the key
            | "((?: [^"\\] | \\. )*)"
        )
        (?:\s*,\s*|$) # pairs separated by comma or end of string.
    """, _re.VERBOSE | _re.UNICODE)

    @classmethod
    def parse(cls, s, cur, _bsdec=_re.compile(r"\\(.)", _re.UNICODE)):
        """Parse an hstore representation in a Python string.

        The hstore is represented as something like::

            "a"=>"1", "b"=>"2"

        with backslash-escaped strings. Returns a dict (NULL values become
        None), or None when *s* is None. On Python 3, bytes input is decoded
        with the connection encoding of *cur* (or as ASCII without a cursor).

        Raise `~psycopg2.InterfaceError` on malformed input.
        """
        if s is None:
            return None

        rv = {}
        start = 0
        if _sys.version_info[0] >= 3 and isinstance(s, bytes):
            s = (s.decode(_ext.encodings[cur.connection.encoding])
                 if cur else bytes_to_ascii(s))
        for m in cls._re_hstore.finditer(s):
            if m is None or m.start() != start:
                raise psycopg2.InterfaceError(
                    "error parsing hstore pair at char %d" % start)
            # Pattern.sub()'s third positional argument is *count*, not
            # flags: every escape sequence must be undone, so pass nothing.
            k = _bsdec.sub(r'\1', m.group(1))
            v = m.group(2)
            if v is not None:
                v = _bsdec.sub(r'\1', v)

            rv[k] = v
            start = m.end()

        if start < len(s):
            raise psycopg2.InterfaceError(
                "error parsing hstore: unparsed data after char %d" % start)

        return rv

    @classmethod
    def _to_unicode(cls, s, cur):
        if s is None:
            return None
        else:
            return s.decode(_ext.encodings[cur.connection.encoding]) \
                    if cur else bytes_to_ascii(s)

    @classmethod
    def parse_unicode(cls, s, cur):
        """Parse an hstore returning unicode keys and values."""
        if s is None:
            return None

        # six.text_type instead of the Python 2-only ``unicode`` builtin.
        if not isinstance(s, six.text_type):
            s = s.decode(_ext.encodings[cur.connection.encoding])

        return cls.parse(s, cur)

    @classmethod
    def get_oids(cls, conn_or_curs):
        """Return the lists of OID of the hstore and hstore[] types.

        Both are returned as tuples, one entry per ``hstore`` type found
        (in any schema). The array oid is None on servers older than 8.3.
        """
        conn, curs = _solve_conn_curs(conn_or_curs)

        # Store the transaction status of the connection to revert it after use
        conn_status = conn.status

        # column typarray not available before PG 8.3
        typarray = conn.server_version >= 80300 and "typarray" or "NULL"

        rv0, rv1 = [], []

        # get the oid for the hstore
        curs.execute("""\
SELECT t.oid, %s
FROM pg_type t JOIN pg_namespace ns
    ON typnamespace = ns.oid
WHERE typname = 'hstore';
""" % typarray)
        for oids in curs:
            rv0.append(oids[0])
            rv1.append(oids[1])

        # revert the status of the connection as before the command
        if (conn_status != _ext.STATUS_IN_TRANSACTION
        and not conn.autocommit):
            conn.rollback()

        return tuple(rv0), tuple(rv1)
773
def register_hstore(conn_or_curs, globally=False, unicode=False,
        oid=None, array_oid=None):
    r"""Register adapter and typecaster for `!dict`\-\ |hstore| conversions.

    :param conn_or_curs: a connection or cursor: the typecaster will be
        registered only on this object unless *globally* is set to `!True`
    :param globally: register the adapter globally, not only on *conn_or_curs*
    :param unicode: if `!True`, keys and values returned from the database
        will be `!unicode` instead of `!str`. The option is not available on
        Python 3
    :param oid: the OID of the |hstore| type if known. If not, it will be
        queried on *conn_or_curs*.
    :param array_oid: the OID of the |hstore| array type if known. If not, it
        will be queried on *conn_or_curs*.

    The connection or cursor passed to the function will be used to query the
    database and look for the OID of the |hstore| type (which may be different
    across databases). If querying is not desirable (e.g. with
    :ref:`asynchronous connections <async-support>`) you may specify it in the
    *oid* parameter, which can be found using a query such as :sql:`SELECT
    'hstore'::regtype::oid`. Analogously you can obtain a value for *array_oid*
    using a query such as :sql:`SELECT 'hstore[]'::regtype::oid`.

    Note that, when passing a dictionary from Python to the database, both
    strings and unicode keys and values are supported. Dictionaries returned
    from the database have keys/values according to the *unicode* parameter.

    The |hstore| contrib module must be already installed in the database
    (executing the ``hstore.sql`` script in your ``contrib`` directory).
    Raise `~psycopg2.ProgrammingError` if the type is not found.
    """
    if oid is None:
        oid = HstoreAdapter.get_oids(conn_or_curs)
        if oid is None or not oid[0]:
            raise psycopg2.ProgrammingError(
                "hstore type not found in the database. "
                "please install it from your 'contrib/hstore.sql' file")
        array_oid = oid[1]
        oid = oid[0]

    # six.integer_types also covers Python 2 ``long`` oids.
    if isinstance(oid, six.integer_types):
        oid = (oid,)

    if array_oid is not None:
        if isinstance(array_oid, six.integer_types):
            array_oid = (array_oid,)
        else:
            # Drop NULL entries (typarray is NULL before PG 8.3).
            array_oid = tuple(x for x in array_oid if x)

    # create and register the typecaster
    if _sys.version_info[0] < 3 and unicode:
        cast = HstoreAdapter.parse_unicode
    else:
        cast = HstoreAdapter.parse

    # None registers the typecasters globally.
    scope = None if globally else (conn_or_curs or None)

    HSTORE = _ext.new_type(oid, "HSTORE", cast, unicode)
    _ext.register_type(HSTORE, scope)
    _ext.register_adapter(dict, HstoreAdapter)

    if array_oid:
        HSTOREARRAY = _ext.new_array_type(array_oid, "HSTOREARRAY", HSTORE)
        _ext.register_type(HSTOREARRAY, scope)
837
838
839class CompositeCaster(object):
840    """Helps conversion of a PostgreSQL composite type into a Python object.
841
842    The class is usually created by the `register_composite()` function.
843    You may want to create and register manually instances of the class if
844    querying the database at registration time is not desirable (such as when
845    using an :ref:`asynchronous connections <async-support>`).
846
847    """
848    def __init__(self, name, oid, attrs, array_oid=None, schema=None):
849        self.name = name
850        self.schema = schema
851        self.oid = oid
852        self.array_oid = array_oid
853
854        self.attnames = [ a[0] for a in attrs ]
855        self.atttypes = [ a[1] for a in attrs ]
856        self._create_type(name, self.attnames)
857        self.typecaster = _ext.new_type((oid,), name, self.parse)
858        if array_oid:
859            self.array_typecaster = _ext.new_array_type(
860                (array_oid,), "%sARRAY" % name, self.typecaster)
861        else:
862            self.array_typecaster = None
863
864    def parse(self, s, curs):
865        if s is None:
866            return None
867
868        tokens = self.tokenize(s)
869        if len(tokens) != len(self.atttypes):
870            raise psycopg2.DataError(
871                "expecting %d components for the type %s, %d found instead" %
872                (len(self.atttypes), self.name, len(tokens)))
873
874        values = [ curs.cast(oid, token)
875            for oid, token in zip(self.atttypes, tokens) ]
876
877        return self.make(values)
878
879    def make(self, values):
880        """Return a new Python object representing the data being casted.
881
882        *values* is the list of attributes, already casted into their Python
883        representation.
884
885        You can subclass this method to :ref:`customize the composite cast
886        <custom-composite>`.
887        """
888
889        return self._ctor(values)
890
891    _re_tokenize = _re.compile(r"""
892  \(? ([,)])                        # an empty token, representing NULL
893| \(? " ((?: [^"] | "")*) " [,)]    # or a quoted string
894| \(? ([^",)]+) [,)]                # or an unquoted string
895    """, _re.VERBOSE)
896
897    _re_undouble = _re.compile(r'(["\\])\1')
898
899    @classmethod
900    def tokenize(self, s):
901        ''' Gets bytestring, returns list of bytestrings
902        '''
903        rv = []
904        for m in self._re_tokenize.finditer(s):
905            if m is None:
906                raise psycopg2.InterfaceError("can't parse type: %r", s)
907            if m.group(1) is not None:
908                rv.append(None)
909            elif m.group(2) is not None:
910                rv.append(self._re_undouble.sub(r"\1", m.group(2)))
911            else:
912                rv.append(m.group(3))
913        return rv
914
915    def _create_type(self, name, attnames):
916        try:
917            from collections import namedtuple
918        except ImportError:
919            self.type = tuple
920            self._ctor = self.type
921        else:
922            self.type = namedtuple(name, attnames)
923            self._ctor = self.type._make
924
925    @classmethod
926    def _from_db(self, name, conn_or_curs):
927        """Return a `CompositeCaster` instance for the type *name*.
928
929        Raise `ProgrammingError` if the type is not found.
930        """
931        conn, curs = _solve_conn_curs(conn_or_curs)
932
933        # Store the transaction status of the connection to revert it after use
934        conn_status = conn.status
935
936        # Use the correct schema
937        if '.' in name:
938            schema, tname = name.split('.', 1)
939        else:
940            tname = name
941            schema = 'public'
942
943        # column typarray not available before PG 8.3
944        typarray = conn.server_version >= 80300 and "typarray" or "NULL"
945
946        # get the type oid and attributes
947        curs.execute("""\
948SELECT t.oid, %s, attname, atttypid
949FROM pg_type t
950JOIN pg_namespace ns ON typnamespace = ns.oid
951JOIN pg_attribute a ON attrelid = typrelid
952WHERE typname = %%s AND nspname = %%s
953    AND attnum > 0 AND NOT attisdropped
954ORDER BY attnum;
955""" % typarray, (tname, schema))
956
957        recs = curs.fetchall()
958
959        # revert the status of the connection as before the command
960        if (conn_status != _ext.STATUS_IN_TRANSACTION
961        and not conn.autocommit):
962            conn.rollback()
963
964        if not recs:
965            raise psycopg2.ProgrammingError(
966                "PostgreSQL type '%s' not found" % name)
967
968        type_oid = recs[0][0]
969        array_oid = recs[0][1]
970        type_attrs = [ (r[2], r[3]) for r in recs ]
971
972        return self(tname, type_oid, type_attrs,
973            array_oid=array_oid, schema=schema)
974
def register_composite(name, conn_or_curs, globally=False, factory=None):
    """Register a typecaster to convert a composite type into a tuple.

    :param name: the name of a PostgreSQL composite type, e.g. created using
        the |CREATE TYPE|_ command
    :param conn_or_curs: a connection or cursor used to find the type oid and
        components; the typecaster is registered in a scope limited to this
        object, unless *globally* is set to `!True`
    :param globally: if `!False` (default) register the typecaster only on
        *conn_or_curs*, otherwise register it globally
    :param factory: if specified it should be a `CompositeCaster` subclass: use
        it to :ref:`customize how to cast composite types <custom-composite>`
    :return: the registered `CompositeCaster` or *factory* instance
        responsible for the conversion

    The array typecaster is registered too, when the server reported one.
    """
    caster_class = CompositeCaster if factory is None else factory
    caster = caster_class._from_db(name, conn_or_curs)

    # A None scope is what makes the registration global.
    scope = None if globally else (conn_or_curs or None)

    typecasters = [caster.typecaster]
    if caster.array_typecaster is not None:
        typecasters.append(caster.array_typecaster)
    for typecaster in typecasters:
        _ext.register_type(typecaster, scope)

    return caster
1000
1001
1002
def _paginate(seq, page_size):
    """Consume an iterable and return it in chunks.

    Every chunk is a list of at most *page_size* items. Never return an empty
    chunk: an empty *seq* produces no chunks at all.

    :raise ValueError: if *page_size* is less than 1.
    """
    from itertools import islice

    # A non-positive size would never consume *seq*: the generator would
    # yield empty chunks forever (and execute_batch would loop forever).
    if page_size < 1:
        raise ValueError("page_size must be at least 1, got %r" % (page_size,))

    it = iter(seq)
    while True:
        page = list(islice(it, page_size))
        if page:
            yield page
        # A short (or empty) page means the iterator is exhausted: stop
        # without asking it for more items.
        if len(page) < page_size:
            return
1020
1021
def execute_batch(cur, sql, argslist, page_size=100):
    r"""Execute groups of statements in fewer server roundtrips.

    Execute *sql* once for every parameter set (sequence or mapping) found
    in *argslist*. The result is semantically similar to

    .. parsed-literal::

        *cur*\.\ `~cursor.executemany`\ (\ *sql*\ , *argslist*\ )

    but the statements are merged client-side with `~cursor.mogrify` and
    joined with ``;`` into multi-statement commands of at most *page_size*
    statements each. The server is therefore contacted once per page rather
    than once per parameter set.

    After the execution of the function the `cursor.rowcount` property will
    **not** contain a total result.

    """
    for chunk in _paginate(argslist, page_size=page_size):
        # One round trip per chunk: all its statements in a single command.
        cur.execute(b";".join(cur.mogrify(sql, params) for params in chunk))
1045
1046
def execute_values(cur, sql, argslist, template=None, page_size=100, fetch=False):
    '''Execute a statement using :sql:`VALUES` with a sequence of parameters.

    :param cur: the cursor to use to execute the query.

    :param sql: the query to execute. It must contain a single ``%s``
        placeholder, which will be replaced by a `VALUES list`__.
        Example: ``"INSERT INTO mytable (id, f1, f2) VALUES %s"``.

    :param argslist: sequence of sequences or dictionaries with the arguments
        to send to the query. The type and content must be consistent with
        *template*.

    :param template: the snippet to merge to every item in *argslist* to
        compose the query.

        - If the *argslist* items are sequences it should contain positional
          placeholders (e.g. ``"(%s, %s, %s)"``, or ``"(%s, %s, 42)"`` if there
          are constant values...).

        - If the *argslist* items are mappings it should contain named
          placeholders (e.g. ``"(%(id)s, %(f1)s, 42)"``).

        If not specified, assume the arguments are sequences and use a simple
        positional template (i.e.  ``(%s, %s, ...)``), with the number of
        placeholders sniffed from the first element in *argslist*.

    :param page_size: maximum number of *argslist* items to include in every
        statement. If there are more items the function will execute more than
        one statement.

    :param fetch: if `!True` return the query results into a list (like in a
        `~cursor.fetchall()`).  Useful for queries with :sql:`RETURNING`
        clause.

    .. __: https://www.postgresql.org/docs/current/static/queries-values.html

    After the execution of the function the `cursor.rowcount` property will
    **not** contain a total result.

    While :sql:`INSERT` is an obvious candidate for this function it is
    possible to use it with other statements, for example::

        >>> cur.execute(
        ... "create table test (id int primary key, v1 int, v2 int)")

        >>> execute_values(cur,
        ... "INSERT INTO test (id, v1, v2) VALUES %s",
        ... [(1, 2, 3), (4, 5, 6), (7, 8, 9)])

        >>> execute_values(cur,
        ... """UPDATE test SET v1 = data.v1 FROM (VALUES %s) AS data (id, v1)
        ... WHERE test.id = data.id""",
        ... [(1, 20), (4, 50)])

        >>> cur.execute("select * from test order by id")
        >>> cur.fetchall()
        [(1, 20, 3), (4, 50, 6), (7, 8, 9)]

    '''
    # Work on bytes throughout: the mogrified values are bytes, combining
    # them with a text query would need decoding with the right codec, and
    # Python 3 has no ``%`` operator on bytes.
    if not isinstance(sql, bytes):
        encoding = _ext.encodings[cur.connection.encoding]
        sql = sql.encode(encoding)
    head, tail = _split_sql(sql)

    rows = [] if fetch else None
    for chunk in _paginate(argslist, page_size=page_size):
        if template is None:
            # Sniffed once from the first item, then reused for every page.
            placeholders = b','.join([b'%s'] * len(chunk[0]))
            template = b'(' + placeholders + b')'
        values = b','.join(cur.mogrify(template, args) for args in chunk)
        cur.execute(b''.join(head + [values] + tail))
        if fetch:
            rows.extend(cur.fetchall())

    return rows
1128
1129
def _split_sql(sql):
    """Split *sql* on a single ``%s`` placeholder.

    Split on the ``%s``, replace every ``%%`` with ``%`` and return the
    ``pre`` and ``post`` lists of byte snippets found before and after the
    placeholder.

    :param sql: the query, as bytes.
    :raise ValueError: if the query contains no ``%s`` placeholder, more
        than one, or any other ``%`` directive.
    """
    pre = []
    post = []
    curr = pre
    tokens = _re.split(br'(%.)', sql)
    # With one capturing group re.split() alternates plain text (even
    # positions) and captured two-byte ``%`` directives (odd positions).
    # Classify by position, not by shape: a text chunk may itself be two
    # bytes starting with ``%`` (``%`` followed by a newline, which ``.``
    # does not match) and must be kept verbatim like any other text.
    for i, token in enumerate(tokens):
        if i % 2 == 0:
            curr.append(token)
            continue

        directive = token[1:]
        if directive == b's':
            if curr is not pre:
                raise ValueError(
                    "the query contains more than one '%s' placeholder")
            curr = post
        elif directive == b'%':
            curr.append(b'%')
        else:
            raise ValueError("unsupported format character: '%s'"
                % directive.decode('ascii', 'replace'))

    if curr is pre:
        raise ValueError("the query doesn't contain any '%s' placeholder")

    return pre, post
1160
1161
1162# expose the json adaptation stuff into the module
1163from psycopg2cffi._json import json, Json, register_json, register_default_json
1164from psycopg2cffi._json import register_default_json, register_default_jsonb
1165
1166
1167
1168# Expose range-related objects
1169from psycopg2cffi._range import Range, NumericRange
1170from psycopg2cffi._range import DateRange, DateTimeRange, DateTimeTZRange
1171from psycopg2cffi._range import register_range, RangeAdapter, RangeCaster
1172