"""Miscellaneous goodies for psycopg2

This module is a generic place used to hold little helper functions
and classes until a better place in the distribution is found.
"""
# psycopg/extras.py - miscellaneous extra goodies for psycopg
#
# Copyright (C) 2003-2010 Federico Di Gregorio <fog@debian.org>
#
# psycopg2 is free software: you can redistribute it and/or modify it
# under the terms of the GNU Lesser General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# In addition, as a special exception, the copyright holders give
# permission to link this program with the OpenSSL library (or with
# modified versions of OpenSSL that use the same license as OpenSSL),
# and distribute linked combinations including the two.
#
# You must obey the GNU Lesser General Public License in all respects for
# all of the code used other than OpenSSL.
#
# psycopg2 is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public
# License for more details.
27 28from __future__ import unicode_literals 29 30import os as _os 31import sys as _sys 32import time as _time 33import re as _re 34import six 35 36try: 37 import logging as _logging 38except: 39 _logging = None 40 41import psycopg2cffi as psycopg2 42from psycopg2cffi import extensions as _ext 43from psycopg2cffi.extensions import cursor as _cursor 44from psycopg2cffi.extensions import connection as _connection 45from psycopg2cffi.extensions import adapt as _A 46from psycopg2cffi._impl.adapters import ascii_to_bytes, bytes_to_ascii 47 48 49class DictCursorBase(_cursor): 50 """Base class for all dict-like cursors.""" 51 52 def __init__(self, *args, **kwargs): 53 if 'row_factory' in kwargs: 54 row_factory = kwargs['row_factory'] 55 del kwargs['row_factory'] 56 else: 57 raise NotImplementedError( 58 "DictCursorBase can't be instantiated without a row factory.") 59 super(DictCursorBase, self).__init__(*args, **kwargs) 60 self._query_executed = 0 61 self._prefetch = 0 62 self.row_factory = row_factory 63 64 def fetchone(self): 65 if self._prefetch: 66 res = super(DictCursorBase, self).fetchone() 67 if self._query_executed: 68 self._build_index() 69 if not self._prefetch: 70 res = super(DictCursorBase, self).fetchone() 71 return res 72 73 def fetchmany(self, size=None): 74 if self._prefetch: 75 res = super(DictCursorBase, self).fetchmany(size) 76 if self._query_executed: 77 self._build_index() 78 if not self._prefetch: 79 res = super(DictCursorBase, self).fetchmany(size) 80 return res 81 82 def fetchall(self): 83 if self._prefetch: 84 res = super(DictCursorBase, self).fetchall() 85 if self._query_executed: 86 self._build_index() 87 if not self._prefetch: 88 res = super(DictCursorBase, self).fetchall() 89 return res 90 91 def __iter__(self): 92 if self._prefetch: 93 res = super(DictCursorBase, self).__iter__() 94 try: 95 first = six.next(res) 96 except StopIteration: 97 return 98 if self._query_executed: 99 self._build_index() 100 if not self._prefetch: 101 res = 
super(DictCursorBase, self).__iter__() 102 try: 103 first = six.next(res) 104 except StopIteration: 105 return 106 107 yield first 108 while 1: 109 try: 110 yield six.next(res) 111 except StopIteration: 112 return 113 114 115class DictConnection(_connection): 116 """A connection that uses `DictCursor` automatically.""" 117 def cursor(self, *args, **kwargs): 118 kwargs.setdefault('cursor_factory', DictCursor) 119 return super(DictConnection, self).cursor(*args, **kwargs) 120 121class DictCursor(DictCursorBase): 122 """A cursor that keeps a list of column name -> index mappings.""" 123 124 def __init__(self, *args, **kwargs): 125 kwargs['row_factory'] = DictRow 126 super(DictCursor, self).__init__(*args, **kwargs) 127 self._prefetch = 1 128 129 def execute(self, query, vars=None): 130 self.index = {} 131 self._query_executed = 1 132 return super(DictCursor, self).execute(query, vars) 133 134 def callproc(self, procname, vars=None): 135 self.index = {} 136 self._query_executed = 1 137 return super(DictCursor, self).callproc(procname, vars) 138 139 def _build_index(self): 140 if self._query_executed == 1 and self.description: 141 for i in range(len(self.description)): 142 self.index[self.description[i][0]] = i 143 self._query_executed = 0 144 145class DictRow(list): 146 """A row object that allow by-colmun-name access to data.""" 147 148 __slots__ = ('_index',) 149 150 def __init__(self, cursor): 151 self._index = cursor.index 152 self[:] = [None] * len(cursor.description) 153 154 def __getitem__(self, x): 155 if not isinstance(x, (int, slice)): 156 x = self._index[x] 157 return list.__getitem__(self, x) 158 159 def __setitem__(self, x, v): 160 if not isinstance(x, (int, slice)): 161 x = self._index[x] 162 list.__setitem__(self, x, v) 163 164 def items(self): 165 return list(self.iteritems()) 166 167 def keys(self): 168 return self._index.keys() 169 170 def values(self): 171 return tuple(self[:]) 172 173 def has_key(self, x): 174 return x in self._index 175 176 def 
get(self, x, default=None): 177 try: 178 return self[x] 179 except: 180 return default 181 182 def iteritems(self): 183 for n, v in six.iteritems(self._index): 184 yield n, list.__getitem__(self, v) 185 186 def iterkeys(self): 187 return six.iterkeys(self._index) 188 189 def itervalues(self): 190 return list.__iter__(self) 191 192 def copy(self): 193 return dict(self.iteritems()) 194 195 def __contains__(self, x): 196 return x in self._index 197 198 def __getstate__(self): 199 return self[:], self._index.copy() 200 201 def __setstate__(self, data): 202 self[:] = data[0] 203 self._index = data[1] 204 205 # drop the crusty Py2 methods 206 if _sys.version_info[0] > 2: 207 items = iteritems; del iteritems 208 keys = iterkeys; del iterkeys 209 values = itervalues; del itervalues 210 del has_key 211 212 213class RealDictConnection(_connection): 214 """A connection that uses `RealDictCursor` automatically.""" 215 def cursor(self, *args, **kwargs): 216 kwargs.setdefault('cursor_factory', RealDictCursor) 217 return super(RealDictConnection, self).cursor(*args, **kwargs) 218 219class RealDictCursor(DictCursorBase): 220 """A cursor that uses a real dict as the base type for rows. 221 222 Note that this cursor is extremely specialized and does not allow 223 the normal access (using integer indices) to fetched data. If you need 224 to access database rows both as a dictionary and a list, then use 225 the generic `DictCursor` instead of `!RealDictCursor`. 
226 """ 227 def __init__(self, *args, **kwargs): 228 kwargs['row_factory'] = RealDictRow 229 super(RealDictCursor, self).__init__(*args, **kwargs) 230 self._prefetch = 0 231 232 def execute(self, query, vars=None): 233 self.column_mapping = [] 234 self._query_executed = 1 235 return super(RealDictCursor, self).execute(query, vars) 236 237 def callproc(self, procname, vars=None): 238 self.column_mapping = [] 239 self._query_executed = 1 240 return super(RealDictCursor, self).callproc(procname, vars) 241 242 def _build_index(self): 243 if self._query_executed == 1 and self.description: 244 for i in range(len(self.description)): 245 self.column_mapping.append(self.description[i][0]) 246 self._query_executed = 0 247 248class RealDictRow(dict): 249 """A `!dict` subclass representing a data record.""" 250 251 __slots__ = ('_column_mapping') 252 253 def __init__(self, cursor): 254 dict.__init__(self) 255 # Required for named cursors 256 if cursor.description and not cursor.column_mapping: 257 cursor._build_index() 258 259 self._column_mapping = cursor.column_mapping 260 261 def __setitem__(self, name, value): 262 if type(name) == int: 263 name = self._column_mapping[name] 264 return dict.__setitem__(self, name, value) 265 266 def __getstate__(self): 267 return (self.copy(), self._column_mapping[:]) 268 269 def __setstate__(self, data): 270 self.update(data[0]) 271 self._column_mapping = data[1] 272 273 274class NamedTupleConnection(_connection): 275 """A connection that uses `NamedTupleCursor` automatically.""" 276 def cursor(self, *args, **kwargs): 277 kwargs.setdefault('cursor_factory', NamedTupleCursor) 278 return super(NamedTupleConnection, self).cursor(*args, **kwargs) 279 280class NamedTupleCursor(_cursor): 281 """A cursor that generates results as `~collections.namedtuple`. 282 283 `!fetch*()` methods will return named tuples instead of regular tuples, so 284 their elements can be accessed both as regular numeric items as well as 285 attributes. 
286 287 >>> nt_cur = conn.cursor(cursor_factory=psycopg2.extras.NamedTupleCursor) 288 >>> rec = nt_cur.fetchone() 289 >>> rec 290 Record(id=1, num=100, data="abc'def") 291 >>> rec[1] 292 100 293 >>> rec.data 294 "abc'def" 295 """ 296 Record = None 297 298 def execute(self, query, vars=None): 299 self.Record = None 300 return super(NamedTupleCursor, self).execute(query, vars) 301 302 def executemany(self, query, vars): 303 self.Record = None 304 return super(NamedTupleCursor, self).executemany(query, vars) 305 306 def callproc(self, procname, vars=None): 307 self.Record = None 308 return super(NamedTupleCursor, self).callproc(procname, vars) 309 310 def fetchone(self): 311 t = super(NamedTupleCursor, self).fetchone() 312 if t is not None: 313 nt = self.Record 314 if nt is None: 315 nt = self.Record = self._make_nt() 316 return nt._make(t) 317 318 def fetchmany(self, size=None): 319 ts = super(NamedTupleCursor, self).fetchmany(size) 320 nt = self.Record 321 if nt is None: 322 nt = self.Record = self._make_nt() 323 return [nt._make(x) for x in ts] 324 325 def fetchall(self): 326 ts = super(NamedTupleCursor, self).fetchall() 327 nt = self.Record 328 if nt is None: 329 nt = self.Record = self._make_nt() 330 return [nt._make(x) for x in ts] 331 332 def __iter__(self): 333 it = super(NamedTupleCursor, self).__iter__() 334 t = six.next(it) 335 336 nt = self.Record 337 if nt is None: 338 nt = self.Record = self._make_nt() 339 340 yield nt._make(t) 341 342 while 1: 343 try: 344 t = six.next(it) 345 except StopIteration: 346 return 347 else: 348 yield nt._make(t) 349 350 try: 351 from collections import namedtuple 352 except ImportError as _exc: 353 def _make_nt(self): 354 raise self._exc 355 else: 356 def _make_nt(self, namedtuple=namedtuple): 357 return namedtuple("Record", [d[0] for d in self.description or ()]) 358 359 360class LoggingConnection(_connection): 361 """A connection that logs all queries to a file or logger__ object. 362 363 .. 
__: http://docs.python.org/library/logging.html 364 """ 365 366 def initialize(self, logobj): 367 """Initialize the connection to log to `!logobj`. 368 369 The `!logobj` parameter can be an open file object or a Logger 370 instance from the standard logging module. 371 """ 372 self._logobj = logobj 373 if _logging and isinstance(logobj, _logging.Logger): 374 self.log = self._logtologger 375 else: 376 self.log = self._logtofile 377 378 def filter(self, msg, curs): 379 """Filter the query before logging it. 380 381 This is the method to overwrite to filter unwanted queries out of the 382 log or to add some extra data to the output. The default implementation 383 just does nothing. 384 """ 385 return msg 386 387 def _logtofile(self, msg, curs): 388 msg = self.filter(msg, curs) 389 if msg: self._logobj.write(msg + _os.linesep) 390 391 def _logtologger(self, msg, curs): 392 msg = self.filter(msg, curs) 393 if msg: self._logobj.debug(msg) 394 395 def _check(self): 396 if not hasattr(self, '_logobj'): 397 raise self.ProgrammingError( 398 "LoggingConnection object has not been initialize()d") 399 400 def cursor(self, *args, **kwargs): 401 self._check() 402 kwargs.setdefault('cursor_factory', LoggingCursor) 403 return super(LoggingConnection, self).cursor(*args, **kwargs) 404 405class LoggingCursor(_cursor): 406 """A cursor that logs queries using its connection logging facilities.""" 407 408 def execute(self, query, vars=None): 409 try: 410 return super(LoggingCursor, self).execute(query, vars) 411 finally: 412 self.connection.log(self.query, self) 413 414 def callproc(self, procname, vars=None): 415 try: 416 return super(LoggingCursor, self).callproc(procname, vars) 417 finally: 418 self.connection.log(self.query, self) 419 420 421class MinTimeLoggingConnection(LoggingConnection): 422 """A connection that logs queries based on execution time. 423 424 This is just an example of how to sub-class `LoggingConnection` to 425 provide some extra filtering for the logged queries. 
Both the 426 `inizialize()` and `filter()` methods are overwritten to make sure 427 that only queries executing for more than ``mintime`` ms are logged. 428 429 Note that this connection uses the specialized cursor 430 `MinTimeLoggingCursor`. 431 """ 432 def initialize(self, logobj, mintime=0): 433 LoggingConnection.initialize(self, logobj) 434 self._mintime = mintime 435 436 def filter(self, msg, curs): 437 t = (_time.time() - curs.timestamp) * 1000 438 if t > self._mintime: 439 return msg + _os.linesep + " (execution time: %d ms)" % t 440 441 def cursor(self, *args, **kwargs): 442 kwargs.setdefault('cursor_factory', MinTimeLoggingCursor) 443 return LoggingConnection.cursor(self, *args, **kwargs) 444 445class MinTimeLoggingCursor(LoggingCursor): 446 """The cursor sub-class companion to `MinTimeLoggingConnection`.""" 447 448 def execute(self, query, vars=None): 449 self.timestamp = _time.time() 450 return LoggingCursor.execute(self, query, vars) 451 452 def callproc(self, procname, vars=None): 453 self.timestamp = _time.time() 454 return LoggingCursor.execute(self, procname, vars) 455 456 457# a dbtype and adapter for Python UUID type 458 459class UUID_adapter(object): 460 """Adapt Python's uuid.UUID__ type to PostgreSQL's uuid__. 461 462 .. __: http://docs.python.org/library/uuid.html 463 .. __: http://www.postgresql.org/docs/current/static/datatype-uuid.html 464 """ 465 466 def __init__(self, uuid): 467 self._uuid = uuid 468 469 def __conform__(self, proto): 470 if proto is _ext.ISQLQuote: 471 return self 472 473 def getquoted(self): 474 return b''.join([b"'", ascii_to_bytes(str(self._uuid)), b"'::uuid"]) 475 476 def __bytes__(self): 477 return self.getquoted() 478 479 def __str__(self): 480 return "'%s'::uuid" % self._uuid 481 482 483def register_uuid(oids=None, conn_or_curs=None): 484 """Create the UUID type and an uuid.UUID adapter. 485 486 :param oids: oid for the PostgreSQL :sql:`uuid` type, or 2-items sequence 487 with oids of the type and the array. 
If not specified, use PostgreSQL 488 standard oids. 489 :param conn_or_curs: where to register the typecaster. If not specified, 490 register it globally. 491 """ 492 493 import uuid 494 495 if not oids: 496 oid1 = 2950 497 oid2 = 2951 498 elif isinstance(oids, (list, tuple)): 499 oid1, oid2 = oids 500 else: 501 oid1 = oids 502 oid2 = 2951 503 504 _ext.UUID = _ext.new_type((oid1, ), "UUID", 505 lambda data, cursor: data and uuid.UUID(data) or None) 506 _ext.UUIDARRAY = _ext.new_array_type((oid2,), "UUID[]", _ext.UUID) 507 508 _ext.register_type(_ext.UUID, conn_or_curs) 509 _ext.register_type(_ext.UUIDARRAY, conn_or_curs) 510 _ext.register_adapter(uuid.UUID, UUID_adapter) 511 512 return _ext.UUID 513 514 515# a type, dbtype and adapter for PostgreSQL inet type 516 517class Inet(object): 518 """Wrap a string to allow for correct SQL-quoting of inet values. 519 520 Note that this adapter does NOT check the passed value to make 521 sure it really is an inet-compatible address but DOES call adapt() 522 on it to make sure it is impossible to execute an SQL-injection 523 by passing an evil value to the initializer. 524 """ 525 def __init__(self, addr): 526 self.addr = addr 527 528 def __repr__(self): 529 return "%s(%r)" % (self.__class__.__name__, self.addr) 530 531 def prepare(self, conn): 532 self._conn = conn 533 534 def getquoted(self): 535 obj = _A(self.addr) 536 if hasattr(obj, 'prepare'): 537 obj.prepare(self._conn) 538 return obj.getquoted() + b"::inet" 539 540 def __conform__(self, proto): 541 if proto is _ext.ISQLQuote: 542 return self 543 544 def __str__(self): 545 return str(self.addr) 546 547def register_inet(oid=None, conn_or_curs=None): 548 """Create the INET type and an Inet adapter. 549 550 :param oid: oid for the PostgreSQL :sql:`inet` type, or 2-items sequence 551 with oids of the type and the array. If not specified, use PostgreSQL 552 standard oids. 553 :param conn_or_curs: where to register the typecaster. If not specified, 554 register it globally. 
555 """ 556 if not oid: 557 oid1 = 869 558 oid2 = 1041 559 elif isinstance(oid, (list, tuple)): 560 oid1, oid2 = oid 561 else: 562 oid1 = oid 563 oid2 = 1041 564 565 _ext.INET = _ext.new_type((oid1, ), "INET", 566 lambda data, cursor: data and Inet(data) or None) 567 _ext.INETARRAY = _ext.new_array_type((oid2, ), "INETARRAY", _ext.INET) 568 569 _ext.register_type(_ext.INET, conn_or_curs) 570 _ext.register_type(_ext.INETARRAY, conn_or_curs) 571 572 return _ext.INET 573 574 575def register_tstz_w_secs(oids=None, conn_or_curs=None): 576 """The function used to register an alternate type caster for 577 :sql:`TIMESTAMP WITH TIME ZONE` to deal with historical time zones with 578 seconds in the UTC offset. 579 580 These are now correctly handled by the default type caster, so currently 581 the function doesn't do anything. 582 """ 583 import warnings 584 warnings.warn("deprecated", DeprecationWarning) 585 586 587def wait_select(conn): 588 """Wait until a connection or cursor has data available. 589 590 The function is an example of a wait callback to be registered with 591 `~psycopg2.extensions.set_wait_callback()`. This function uses 592 :py:func:`~select.select()` to wait for data available. 
593 594 """ 595 import select 596 from psycopg2cffi.extensions import POLL_OK, POLL_READ, POLL_WRITE 597 598 while 1: 599 state = conn.poll() 600 if state == POLL_OK: 601 break 602 elif state == POLL_READ: 603 select.select([conn.fileno()], [], []) 604 elif state == POLL_WRITE: 605 select.select([], [conn.fileno()], []) 606 else: 607 raise conn.OperationalError("bad state from poll: %s" % state) 608 609 610def _solve_conn_curs(conn_or_curs): 611 """Return the connection and a DBAPI cursor from a connection or cursor.""" 612 if conn_or_curs is None: 613 raise psycopg2.ProgrammingError("no connection or cursor provided") 614 615 if hasattr(conn_or_curs, 'execute'): 616 conn = conn_or_curs.connection 617 curs = conn.cursor(cursor_factory=_cursor) 618 else: 619 conn = conn_or_curs 620 curs = conn.cursor(cursor_factory=_cursor) 621 622 return conn, curs 623 624 625class HstoreAdapter(object): 626 """Adapt a Python dict to the hstore syntax.""" 627 def __init__(self, wrapped): 628 self.wrapped = wrapped 629 self.unicode = False 630 631 def prepare(self, conn): 632 self.conn = conn 633 634 # use an old-style getquoted implementation if required 635 if conn.server_version < 90000: 636 self.getquoted = self._getquoted_8 637 638 def _getquoted_8(self): 639 """Use the operators available in PG pre-9.0.""" 640 if not self.wrapped: 641 return b"''::hstore" 642 643 adapt = _ext.adapt 644 rv = [] 645 for k, v in six.iteritems(self.wrapped): 646 k = adapt(k) 647 k.prepare(self.conn) 648 k = k.getquoted() 649 650 if v is not None: 651 v = adapt(v) 652 v.prepare(self.conn) 653 v = v.getquoted() 654 else: 655 v = b'NULL' 656 657 rv.append(b"(" + k + b" => " + v + b")") 658 659 return b"(" + b'||'.join(rv) + b")" 660 661 def _getquoted_9(self): 662 """Use the hstore(text[], text[]) function.""" 663 if not self.wrapped: 664 return b"''::hstore" 665 666 k = _ext.adapt(list(self.wrapped.keys())) 667 k.prepare(self.conn) 668 v = _ext.adapt(list(self.wrapped.values())) 669 
v.prepare(self.conn) 670 return b"hstore(" + k.getquoted() + b", " + v.getquoted() + b")" 671 672 getquoted = _getquoted_9 673 674 _re_hstore = _re.compile(r""" 675 # hstore key: 676 # a string of normal or escaped chars 677 "((?: [^"\\] | \\. )*)" 678 \s*=>\s* # hstore value 679 (?: 680 NULL # the value can be null - not catched 681 # or a quoted string like the key 682 | "((?: [^"\\] | \\. )*)" 683 ) 684 (?:\s*,\s*|$) # pairs separated by comma or end of string. 685 """, _re.VERBOSE | _re.UNICODE) 686 687 @classmethod 688 def parse(self, s, cur, _bsdec=_re.compile(r"\\(.)", _re.UNICODE)): 689 """Parse an hstore representation in a Python string. 690 691 The hstore is represented as something like:: 692 693 "a"=>"1", "b"=>"2" 694 695 with backslash-escaped strings. 696 """ 697 if s is None: 698 return None 699 700 rv = {} 701 start = 0 702 if six.PY3 and isinstance(s, six.binary_type): 703 s = s.decode(_ext.encodings[cur.connection.encoding]) \ 704 if cur else bytes_to_ascii(s) 705 for m in self._re_hstore.finditer(s): 706 if m is None or m.start() != start: 707 raise psycopg2.InterfaceError( 708 "error parsing hstore pair at char %d" % start) 709 k = _bsdec.sub(r'\1', m.group(1), _re.UNICODE) 710 v = m.group(2) 711 if v is not None: 712 v = _bsdec.sub(r'\1', v, _re.UNICODE) 713 714 rv[k] = v 715 start = m.end() 716 717 if start < len(s): 718 raise psycopg2.InterfaceError( 719 "error parsing hstore: unparsed data after char %d" % start) 720 721 return rv 722 723 @classmethod 724 def _to_unicode(self, s, cur): 725 if s is None: 726 return None 727 else: 728 return s.decode(_ext.encodings[cur.connection.encoding]) \ 729 if cur else bytes_to_ascii(s) 730 731 @classmethod 732 def parse_unicode(self, s, cur): 733 """Parse an hstore returning unicode keys and values.""" 734 if s is None: 735 return None 736 737 if not isinstance(s, unicode): 738 s = s.decode(_ext.encodings[cur.connection.encoding]) 739 740 return self.parse(s, cur) 741 742 @classmethod 743 def 
get_oids(self, conn_or_curs): 744 """Return the lists of OID of the hstore and hstore[] types. 745 """ 746 conn, curs = _solve_conn_curs(conn_or_curs) 747 748 # Store the transaction status of the connection to revert it after use 749 conn_status = conn.status 750 751 # column typarray not available before PG 8.3 752 typarray = conn.server_version >= 80300 and "typarray" or "NULL" 753 754 rv0, rv1 = [], [] 755 756 # get the oid for the hstore 757 curs.execute("""\ 758SELECT t.oid, %s 759FROM pg_type t JOIN pg_namespace ns 760 ON typnamespace = ns.oid 761WHERE typname = 'hstore'; 762""" % typarray) 763 for oids in curs: 764 rv0.append(oids[0]) 765 rv1.append(oids[1]) 766 767 # revert the status of the connection as before the command 768 if (conn_status != _ext.STATUS_IN_TRANSACTION 769 and not conn.autocommit): 770 conn.rollback() 771 772 return tuple(rv0), tuple(rv1) 773 774def register_hstore(conn_or_curs, globally=False, unicode=False, 775 oid=None, array_oid=None): 776 """Register adapter and typecaster for `!dict`\-\ |hstore| conversions. 777 778 :param conn_or_curs: a connection or cursor: the typecaster will be 779 registered only on this object unless *globally* is set to `!True` 780 :param globally: register the adapter globally, not only on *conn_or_curs* 781 :param unicode: if `!True`, keys and values returned from the database 782 will be `!unicode` instead of `!str`. The option is not available on 783 Python 3 784 :param oid: the OID of the |hstore| type if known. If not, it will be 785 queried on *conn_or_curs*. 786 :param array_oid: the OID of the |hstore| array type if known. If not, it 787 will be queried on *conn_or_curs*. 788 789 The connection or cursor passed to the function will be used to query the 790 database and look for the OID of the |hstore| type (which may be different 791 across databases). If querying is not desirable (e.g. 
with 792 :ref:`asynchronous connections <async-support>`) you may specify it in the 793 *oid* parameter, which can be found using a query such as :sql:`SELECT 794 'hstore'::regtype::oid`. Analogously you can obtain a value for *array_oid* 795 using a query such as :sql:`SELECT 'hstore[]'::regtype::oid`. 796 797 Note that, when passing a dictionary from Python to the database, both 798 strings and unicode keys and values are supported. Dictionaries returned 799 from the database have keys/values according to the *unicode* parameter. 800 801 The |hstore| contrib module must be already installed in the database 802 (executing the ``hstore.sql`` script in your ``contrib`` directory). 803 Raise `~psycopg2.ProgrammingError` if the type is not found. 804 """ 805 if oid is None: 806 oid = HstoreAdapter.get_oids(conn_or_curs) 807 if oid is None or not oid[0]: 808 raise psycopg2.ProgrammingError( 809 "hstore type not found in the database. " 810 "please install it from your 'contrib/hstore.sql' file") 811 else: 812 array_oid = oid[1] 813 oid = oid[0] 814 815 if isinstance(oid, int): 816 oid = (oid,) 817 818 if array_oid is not None: 819 if isinstance(array_oid, int): 820 array_oid = (array_oid,) 821 else: 822 array_oid = tuple([x for x in array_oid if x]) 823 824 # create and register the typecaster 825 if _sys.version_info[0] < 3 and unicode: 826 cast = HstoreAdapter.parse_unicode 827 else: 828 cast = HstoreAdapter.parse 829 830 HSTORE = _ext.new_type(oid, "HSTORE", cast, unicode) 831 _ext.register_type(HSTORE, not globally and conn_or_curs or None) 832 _ext.register_adapter(dict, HstoreAdapter) 833 834 if array_oid: 835 HSTOREARRAY = _ext.new_array_type(array_oid, "HSTOREARRAY", HSTORE) 836 _ext.register_type(HSTOREARRAY, not globally and conn_or_curs or None) 837 838 839class CompositeCaster(object): 840 """Helps conversion of a PostgreSQL composite type into a Python object. 841 842 The class is usually created by the `register_composite()` function. 
843 You may want to create and register manually instances of the class if 844 querying the database at registration time is not desirable (such as when 845 using an :ref:`asynchronous connections <async-support>`). 846 847 """ 848 def __init__(self, name, oid, attrs, array_oid=None, schema=None): 849 self.name = name 850 self.schema = schema 851 self.oid = oid 852 self.array_oid = array_oid 853 854 self.attnames = [ a[0] for a in attrs ] 855 self.atttypes = [ a[1] for a in attrs ] 856 self._create_type(name, self.attnames) 857 self.typecaster = _ext.new_type((oid,), name, self.parse) 858 if array_oid: 859 self.array_typecaster = _ext.new_array_type( 860 (array_oid,), "%sARRAY" % name, self.typecaster) 861 else: 862 self.array_typecaster = None 863 864 def parse(self, s, curs): 865 if s is None: 866 return None 867 868 tokens = self.tokenize(s) 869 if len(tokens) != len(self.atttypes): 870 raise psycopg2.DataError( 871 "expecting %d components for the type %s, %d found instead" % 872 (len(self.atttypes), self.name, len(tokens))) 873 874 values = [ curs.cast(oid, token) 875 for oid, token in zip(self.atttypes, tokens) ] 876 877 return self.make(values) 878 879 def make(self, values): 880 """Return a new Python object representing the data being casted. 881 882 *values* is the list of attributes, already casted into their Python 883 representation. 884 885 You can subclass this method to :ref:`customize the composite cast 886 <custom-composite>`. 887 """ 888 889 return self._ctor(values) 890 891 _re_tokenize = _re.compile(r""" 892 \(? ([,)]) # an empty token, representing NULL 893| \(? " ((?: [^"] | "")*) " [,)] # or a quoted string 894| \(? 
([^",)]+) [,)] # or an unquoted string 895 """, _re.VERBOSE) 896 897 _re_undouble = _re.compile(r'(["\\])\1') 898 899 @classmethod 900 def tokenize(self, s): 901 ''' Gets bytestring, returns list of bytestrings 902 ''' 903 rv = [] 904 for m in self._re_tokenize.finditer(s): 905 if m is None: 906 raise psycopg2.InterfaceError("can't parse type: %r", s) 907 if m.group(1) is not None: 908 rv.append(None) 909 elif m.group(2) is not None: 910 rv.append(self._re_undouble.sub(r"\1", m.group(2))) 911 else: 912 rv.append(m.group(3)) 913 return rv 914 915 def _create_type(self, name, attnames): 916 try: 917 from collections import namedtuple 918 except ImportError: 919 self.type = tuple 920 self._ctor = self.type 921 else: 922 self.type = namedtuple(name, attnames) 923 self._ctor = self.type._make 924 925 @classmethod 926 def _from_db(self, name, conn_or_curs): 927 """Return a `CompositeCaster` instance for the type *name*. 928 929 Raise `ProgrammingError` if the type is not found. 930 """ 931 conn, curs = _solve_conn_curs(conn_or_curs) 932 933 # Store the transaction status of the connection to revert it after use 934 conn_status = conn.status 935 936 # Use the correct schema 937 if '.' 
in name: 938 schema, tname = name.split('.', 1) 939 else: 940 tname = name 941 schema = 'public' 942 943 # column typarray not available before PG 8.3 944 typarray = conn.server_version >= 80300 and "typarray" or "NULL" 945 946 # get the type oid and attributes 947 curs.execute("""\ 948SELECT t.oid, %s, attname, atttypid 949FROM pg_type t 950JOIN pg_namespace ns ON typnamespace = ns.oid 951JOIN pg_attribute a ON attrelid = typrelid 952WHERE typname = %%s AND nspname = %%s 953 AND attnum > 0 AND NOT attisdropped 954ORDER BY attnum; 955""" % typarray, (tname, schema)) 956 957 recs = curs.fetchall() 958 959 # revert the status of the connection as before the command 960 if (conn_status != _ext.STATUS_IN_TRANSACTION 961 and not conn.autocommit): 962 conn.rollback() 963 964 if not recs: 965 raise psycopg2.ProgrammingError( 966 "PostgreSQL type '%s' not found" % name) 967 968 type_oid = recs[0][0] 969 array_oid = recs[0][1] 970 type_attrs = [ (r[2], r[3]) for r in recs ] 971 972 return self(tname, type_oid, type_attrs, 973 array_oid=array_oid, schema=schema) 974 975def register_composite(name, conn_or_curs, globally=False, factory=None): 976 """Register a typecaster to convert a composite type into a tuple. 977 978 :param name: the name of a PostgreSQL composite type, e.g. 
created using 979 the |CREATE TYPE|_ command 980 :param conn_or_curs: a connection or cursor used to find the type oid and 981 components; the typecaster is registered in a scope limited to this 982 object, unless *globally* is set to `!True` 983 :param globally: if `!False` (default) register the typecaster only on 984 *conn_or_curs*, otherwise register it globally 985 :param factory: if specified it should be a `CompositeCaster` subclass: use 986 it to :ref:`customize how to cast composite types <custom-composite>` 987 :return: the registered `CompositeCaster` or *factory* instance 988 responsible for the conversion 989 """ 990 if factory is None: 991 factory = CompositeCaster 992 993 caster = factory._from_db(name, conn_or_curs) 994 _ext.register_type(caster.typecaster, not globally and conn_or_curs or None) 995 996 if caster.array_typecaster is not None: 997 _ext.register_type(caster.array_typecaster, not globally and conn_or_curs or None) 998 999 return caster 1000 1001 1002 1003def _paginate(seq, page_size): 1004 """Consume an iterable and return it in chunks. 1005 1006 Every chunk is at most `page_size`. Never return an empty chunk. 1007 """ 1008 page = [] 1009 it = iter(seq) 1010 while 1: 1011 try: 1012 for i in range(page_size): 1013 page.append(next(it)) 1014 yield page 1015 page = [] 1016 except StopIteration: 1017 if page: 1018 yield page 1019 return 1020 1021 1022def execute_batch(cur, sql, argslist, page_size=100): 1023 r"""Execute groups of statements in fewer server roundtrips. 1024 1025 Execute *sql* several times, against all parameters set (sequences or 1026 mappings) found in *argslist*. 1027 1028 The function is semantically similar to 1029 1030 .. 
parsed-literal:: 1031 1032 *cur*\.\ `~cursor.executemany`\ (\ *sql*\ , *argslist*\ ) 1033 1034 but has a different implementation: Psycopg will join the statements into 1035 fewer multi-statement commands, each one containing at most *page_size* 1036 statements, resulting in a reduced number of server roundtrips. 1037 1038 After the execution of the function the `cursor.rowcount` property will 1039 **not** contain a total result. 1040 1041 """ 1042 for page in _paginate(argslist, page_size=page_size): 1043 sqls = [cur.mogrify(sql, args) for args in page] 1044 cur.execute(b";".join(sqls)) 1045 1046 1047def execute_values(cur, sql, argslist, template=None, page_size=100, fetch=False): 1048 '''Execute a statement using :sql:`VALUES` with a sequence of parameters. 1049 1050 :param cur: the cursor to use to execute the query. 1051 1052 :param sql: the query to execute. It must contain a single ``%s`` 1053 placeholder, which will be replaced by a `VALUES list`__. 1054 Example: ``"INSERT INTO mytable (id, f1, f2) VALUES %s"``. 1055 1056 :param argslist: sequence of sequences or dictionaries with the arguments 1057 to send to the query. The type and content must be consistent with 1058 *template*. 1059 1060 :param template: the snippet to merge to every item in *argslist* to 1061 compose the query. 1062 1063 - If the *argslist* items are sequences it should contain positional 1064 placeholders (e.g. ``"(%s, %s, %s)"``, or ``"(%s, %s, 42)``" if there 1065 are constants value...). 1066 1067 - If the *argslist* items are mappings it should contain named 1068 placeholders (e.g. ``"(%(id)s, %(f1)s, 42)"``). 1069 1070 If not specified, assume the arguments are sequence and use a simple 1071 positional template (i.e. ``(%s, %s, ...)``), with the number of 1072 placeholders sniffed by the first element in *argslist*. 1073 1074 :param page_size: maximum number of *argslist* items to include in every 1075 statement. 
If there are more items the function will execute more than 1076 one statement. 1077 1078 :param fetch: if `!True` return the query results into a list (like in a 1079 `~cursor.fetchall()`). Useful for queries with :sql:`RETURNING` 1080 clause. 1081 1082 .. __: https://www.postgresql.org/docs/current/static/queries-values.html 1083 1084 After the execution of the function the `cursor.rowcount` property will 1085 **not** contain a total result. 1086 1087 While :sql:`INSERT` is an obvious candidate for this function it is 1088 possible to use it with other statements, for example:: 1089 1090 >>> cur.execute( 1091 ... "create table test (id int primary key, v1 int, v2 int)") 1092 1093 >>> execute_values(cur, 1094 ... "INSERT INTO test (id, v1, v2) VALUES %s", 1095 ... [(1, 2, 3), (4, 5, 6), (7, 8, 9)]) 1096 1097 >>> execute_values(cur, 1098 ... """UPDATE test SET v1 = data.v1 FROM (VALUES %s) AS data (id, v1) 1099 ... WHERE test.id = data.id""", 1100 ... [(1, 20), (4, 50)]) 1101 1102 >>> cur.execute("select * from test order by id") 1103 >>> cur.fetchall() 1104 [(1, 20, 3), (4, 50, 6), (7, 8, 9)]) 1105 1106 ''' 1107 # we can't just use sql % vals because vals is bytes: if sql is bytes 1108 # there will be some decoding error because of stupid codec used, and Py3 1109 # doesn't implement % on bytes. 1110 if not isinstance(sql, bytes): 1111 sql = sql.encode(_ext.encodings[cur.connection.encoding]) 1112 pre, post = _split_sql(sql) 1113 1114 result = [] if fetch else None 1115 for page in _paginate(argslist, page_size=page_size): 1116 if template is None: 1117 template = b'(' + b','.join([b'%s'] * len(page[0])) + b')' 1118 parts = pre[:] 1119 for args in page: 1120 parts.append(cur.mogrify(template, args)) 1121 parts.append(b',') 1122 parts[-1:] = post 1123 cur.execute(b''.join(parts)) 1124 if fetch: 1125 result.extend(cur.fetchall()) 1126 1127 return result 1128 1129 1130def _split_sql(sql): 1131 """Split *sql* on a single ``%s`` placeholder. 
1132 1133 Split on the %s, perform %% replacement and return pre, post lists of 1134 snippets. 1135 """ 1136 curr = pre = [] 1137 post = [] 1138 tokens = _re.split(br'(%.)', sql) 1139 for token in tokens: 1140 if len(token) != 2 or token[:1] != b'%': 1141 curr.append(token) 1142 continue 1143 1144 if token[1:] == b's': 1145 if curr is pre: 1146 curr = post 1147 else: 1148 raise ValueError( 1149 "the query contains more than one '%s' placeholder") 1150 elif token[1:] == b'%': 1151 curr.append(b'%') 1152 else: 1153 raise ValueError("unsupported format character: '%s'" 1154 % token[1:].decode('ascii', 'replace')) 1155 1156 if curr is pre: 1157 raise ValueError("the query doesn't contain any '%s' placeholder") 1158 1159 return pre, post 1160 1161 1162# expose the json adaptation stuff into the module 1163from psycopg2cffi._json import json, Json, register_json, register_default_json 1164from psycopg2cffi._json import register_default_json, register_default_jsonb 1165 1166 1167 1168# Expose range-related objects 1169from psycopg2cffi._range import Range, NumericRange 1170from psycopg2cffi._range import DateRange, DateTimeRange, DateTimeTZRange 1171from psycopg2cffi._range import register_range, RangeAdapter, RangeCaster 1172