1# -*- test-case-name: twisted.test.test_adbapi -*-
2# Copyright (c) Twisted Matrix Laboratories.
3# See LICENSE for details.
4
5"""
6An asynchronous mapping to U{DB-API 2.0<http://www.python.org/topics/database/DatabaseAPI-2.0.html>}.
7"""
8
9import sys
10
11from twisted.internet import threads
12from twisted.python import reflect, log
13from twisted.python.deprecate import deprecated
14from twisted.python.versions import Version
15
16
17
class ConnectionLost(Exception):
    """
    Signals that the underlying database connection has gone away.

    Seeing this exception does not necessarily mean the work is lost for
    good: client code is free to retry the operation.
    """
23
24
25
class Connection(object):
    """
    A wrapper for a DB-API connection instance.

    The wrapper passes almost everything to the wrapped connection and so has
    the same API. However, the Connection knows about its pool and also
    handles reconnecting when the real connection dies.
    """

    def __init__(self, pool):
        """
        @param pool: the L{ConnectionPool} that supplies (via C{connect}) and
            releases (via C{disconnect}) the wrapped DB-API connection.
        """
        self._pool = pool
        # reconnect() reads _connection, so it must exist before the call.
        self._connection = None
        self.reconnect()

    def close(self):
        # The way adbapi works right now means that closing a connection is
        # a really bad thing  as it leaves a dead connection associated with
        # a thread in the thread pool.
        # Really, I think closing a pooled connection should return it to the
        # pool but that's handled by the runWithConnection method already so,
        # rather than upsetting anyone by raising an exception, let's ignore
        # the request
        pass

    def rollback(self):
        """
        Roll back the current transaction on the wrapped connection.

        If the pool is not configured to reconnect, this simply delegates to
        the wrapped connection's C{rollback} and lets any error propagate.

        Otherwise, after rolling back, the pool's C{good_sql} probe query is
        run and committed to verify the connection still works.  If any of
        those steps fails, the failure is logged, the wrapped connection is
        disconnected from the pool, and L{ConnectionLost} is raised.  A later
        call to L{reconnect} will then obtain a fresh connection.

        @raise ConnectionLost: if reconnection is enabled and the connection
            turned out to be unusable.
        """
        if not self._pool.reconnect:
            self._connection.rollback()
            return

        try:
            self._connection.rollback()
            curs = self._connection.cursor()
            curs.execute(self._pool.good_sql)
            curs.close()
            self._connection.commit()
            return
        except:
            log.err(None, "Rollback failed")

        self._pool.disconnect(self._connection)
        # The pool no longer tracks this connection.  Forget it so that a
        # subsequent reconnect() does not try to disconnect it a second time,
        # which the pool would reject as the wrong connection for the thread.
        self._connection = None

        if self._pool.noisy:
            log.msg("Connection lost.")

        raise ConnectionLost()

    def reconnect(self):
        """
        Release the current wrapped connection (if any) back to the pool via
        C{disconnect} and obtain a new one from the pool's C{connect}.
        """
        if self._connection is not None:
            self._pool.disconnect(self._connection)
        self._connection = self._pool.connect()

    def __getattr__(self, name):
        # Anything not defined on the wrapper is looked up on the wrapped
        # DB-API connection (commit, cursor, ...).
        return getattr(self._connection, name)
79
80
class Transaction:
    """A lightweight wrapper for a DB-API 'cursor' object.

    Relays attribute access to the DB cursor. That is, you can call
    execute(), fetchall(), etc., and they will be called on the
    underlying DB-API cursor object. Attributes will also be
    retrieved from there.
    """
    # None whenever no cursor is currently open (before reopen(), after
    # close(), or mid-reconnect).
    _cursor = None

    def __init__(self, pool, connection):
        """
        @param pool: the owning L{ConnectionPool}; its C{reconnect} and
            C{noisy} settings control cursor-creation failure handling.
        @param connection: the connection wrapper to create cursors from
            (the pool passes a L{Connection}).
        """
        self._pool = pool
        self._connection = connection
        self.reopen()

    def close(self):
        """
        Close the wrapped cursor.

        Calling this more than once is harmless: once the cursor has been
        released, further calls do nothing.
        """
        _cursor = self._cursor
        if _cursor is None:
            return
        self._cursor = None
        _cursor.close()

    def reopen(self):
        """
        Close the current cursor (if any) and open a new one.

        If cursor creation fails and the pool does not reconnect, the error
        propagates.  Otherwise the failure is logged, the connection is
        reconnected, and cursor creation is retried once; an error from the
        retry propagates.
        """
        if self._cursor is not None:
            self.close()

        try:
            self._cursor = self._connection.cursor()
            return
        except:
            if not self._pool.reconnect:
                raise
            else:
                log.err(None, "Cursor creation failed")

        if self._pool.noisy:
            log.msg('Connection lost, reconnecting')

        self.reconnect()
        self._cursor = self._connection.cursor()

    def reconnect(self):
        """
        Ask the connection wrapper to reconnect and drop the (now stale)
        cursor reference.
        """
        self._connection.reconnect()
        self._cursor = None

    def __getattr__(self, name):
        # Anything not defined here (execute, fetchall, ...) is served by the
        # current cursor.
        return getattr(self._cursor, name)
126
127
128class ConnectionPool:
129    """
130    Represent a pool of connections to a DB-API 2.0 compliant database.
131
132    @ivar connectionFactory: factory for connections, default to L{Connection}.
133    @type connectionFactory: any callable.
134
135    @ivar transactionFactory: factory for transactions, default to
136        L{Transaction}.
137    @type transactionFactory: any callable
138
139    @ivar shutdownID: C{None} or a handle on the shutdown event trigger
140        which will be used to stop the connection pool workers when the
141        reactor stops.
142
143    @ivar _reactor: The reactor which will be used to schedule startup and
144        shutdown events.
145    @type _reactor: L{IReactorCore} provider
146    """
147
148    CP_ARGS = "min max name noisy openfun reconnect good_sql".split()
149
150    noisy = False # if true, generate informational log messages
151    min = 3 # minimum number of connections in pool
152    max = 5 # maximum number of connections in pool
153    name = None # Name to assign to thread pool for debugging
154    openfun = None # A function to call on new connections
155    reconnect = False # reconnect when connections fail
156    good_sql = 'select 1' # a query which should always succeed
157
158    running = False # true when the pool is operating
159    connectionFactory = Connection
160    transactionFactory = Transaction
161
162    # Initialize this to None so it's available in close() even if start()
163    # never runs.
164    shutdownID = None
165
166    def __init__(self, dbapiName, *connargs, **connkw):
167        """Create a new ConnectionPool.
168
169        Any positional or keyword arguments other than those documented here
170        are passed to the DB-API object when connecting. Use these arguments to
171        pass database names, usernames, passwords, etc.
172
173        @param dbapiName: an import string to use to obtain a DB-API compatible
174                          module (e.g. 'pyPgSQL.PgSQL')
175
176        @param cp_min: the minimum number of connections in pool (default 3)
177
178        @param cp_max: the maximum number of connections in pool (default 5)
179
180        @param cp_noisy: generate informational log messages during operation
181                         (default False)
182
183        @param cp_openfun: a callback invoked after every connect() on the
184                           underlying DB-API object. The callback is passed a
185                           new DB-API connection object.  This callback can
186                           setup per-connection state such as charset,
187                           timezone, etc.
188
189        @param cp_reconnect: detect connections which have failed and reconnect
190                             (default False). Failed connections may result in
191                             ConnectionLost exceptions, which indicate the
192                             query may need to be re-sent.
193
194        @param cp_good_sql: an sql query which should always succeed and change
195                            no state (default 'select 1')
196
197        @param cp_reactor: use this reactor instead of the global reactor
198            (added in Twisted 10.2).
199        @type cp_reactor: L{IReactorCore} provider
200        """
201
202        self.dbapiName = dbapiName
203        self.dbapi = reflect.namedModule(dbapiName)
204
205        if getattr(self.dbapi, 'apilevel', None) != '2.0':
206            log.msg('DB API module not DB API 2.0 compliant.')
207
208        if getattr(self.dbapi, 'threadsafety', 0) < 1:
209            log.msg('DB API module not sufficiently thread-safe.')
210
211        reactor = connkw.pop('cp_reactor', None)
212        if reactor is None:
213            from twisted.internet import reactor
214        self._reactor = reactor
215
216        self.connargs = connargs
217        self.connkw = connkw
218
219        for arg in self.CP_ARGS:
220            cp_arg = 'cp_%s' % arg
221            if cp_arg in connkw:
222                setattr(self, arg, connkw[cp_arg])
223                del connkw[cp_arg]
224
225        self.min = min(self.min, self.max)
226        self.max = max(self.min, self.max)
227
228        self.connections = {}  # all connections, hashed on thread id
229
230        # these are optional so import them here
231        from twisted.python import threadpool
232        import thread
233
234        self.threadID = thread.get_ident
235        self.threadpool = threadpool.ThreadPool(self.min, self.max)
236        self.startID = self._reactor.callWhenRunning(self._start)
237
238
239    def _start(self):
240        self.startID = None
241        return self.start()
242
243
244    def start(self):
245        """
246        Start the connection pool.
247
248        If you are using the reactor normally, this function does *not*
249        need to be called.
250        """
251        if not self.running:
252            self.threadpool.start()
253            self.shutdownID = self._reactor.addSystemEventTrigger(
254                'during', 'shutdown', self.finalClose)
255            self.running = True
256
257
258    def runWithConnection(self, func, *args, **kw):
259        """
260        Execute a function with a database connection and return the result.
261
262        @param func: A callable object of one argument which will be executed
263            in a thread with a connection from the pool.  It will be passed as
264            its first argument a L{Connection} instance (whose interface is
265            mostly identical to that of a connection object for your DB-API
266            module of choice), and its results will be returned as a Deferred.
267            If the method raises an exception the transaction will be rolled
268            back.  Otherwise, the transaction will be committed.  B{Note} that
269            this function is B{not} run in the main thread: it must be
270            threadsafe.
271
272        @param *args: positional arguments to be passed to func
273
274        @param **kw: keyword arguments to be passed to func
275
276        @return: a Deferred which will fire the return value of
277            C{func(Transaction(...), *args, **kw)}, or a Failure.
278        """
279        from twisted.internet import reactor
280        return threads.deferToThreadPool(reactor, self.threadpool,
281                                         self._runWithConnection,
282                                         func, *args, **kw)
283
284
285    def _runWithConnection(self, func, *args, **kw):
286        conn = self.connectionFactory(self)
287        try:
288            result = func(conn, *args, **kw)
289            conn.commit()
290            return result
291        except:
292            excType, excValue, excTraceback = sys.exc_info()
293            try:
294                conn.rollback()
295            except:
296                log.err(None, "Rollback failed")
297            raise excType, excValue, excTraceback
298
299
300    def runInteraction(self, interaction, *args, **kw):
301        """
302        Interact with the database and return the result.
303
304        The 'interaction' is a callable object which will be executed
305        in a thread using a pooled connection. It will be passed an
306        L{Transaction} object as an argument (whose interface is
307        identical to that of the database cursor for your DB-API
308        module of choice), and its results will be returned as a
309        Deferred. If running the method raises an exception, the
310        transaction will be rolled back. If the method returns a
311        value, the transaction will be committed.
312
313        NOTE that the function you pass is *not* run in the main
314        thread: you may have to worry about thread-safety in the
315        function you pass to this if it tries to use non-local
316        objects.
317
318        @param interaction: a callable object whose first argument
319            is an L{adbapi.Transaction}.
320
321        @param *args: additional positional arguments to be passed
322            to interaction
323
324        @param **kw: keyword arguments to be passed to interaction
325
326        @return: a Deferred which will fire the return value of
327            'interaction(Transaction(...), *args, **kw)', or a Failure.
328        """
329        from twisted.internet import reactor
330        return threads.deferToThreadPool(reactor, self.threadpool,
331                                         self._runInteraction,
332                                         interaction, *args, **kw)
333
334
335    def runQuery(self, *args, **kw):
336        """Execute an SQL query and return the result.
337
338        A DB-API cursor will will be invoked with cursor.execute(*args, **kw).
339        The exact nature of the arguments will depend on the specific flavor
340        of DB-API being used, but the first argument in *args be an SQL
341        statement. The result of a subsequent cursor.fetchall() will be
342        fired to the Deferred which is returned. If either the 'execute' or
343        'fetchall' methods raise an exception, the transaction will be rolled
344        back and a Failure returned.
345
346        The  *args and **kw arguments will be passed to the DB-API cursor's
347        'execute' method.
348
349        @return: a Deferred which will fire the return value of a DB-API
350        cursor's 'fetchall' method, or a Failure.
351        """
352        return self.runInteraction(self._runQuery, *args, **kw)
353
354
355    def runOperation(self, *args, **kw):
356        """Execute an SQL query and return None.
357
358        A DB-API cursor will will be invoked with cursor.execute(*args, **kw).
359        The exact nature of the arguments will depend on the specific flavor
360        of DB-API being used, but the first argument in *args will be an SQL
361        statement. This method will not attempt to fetch any results from the
362        query and is thus suitable for INSERT, DELETE, and other SQL statements
363        which do not return values. If the 'execute' method raises an
364        exception, the transaction will be rolled back and a Failure returned.
365
366        The args and kw arguments will be passed to the DB-API cursor's
367        'execute' method.
368
369        return: a Deferred which will fire None or a Failure.
370        """
371        return self.runInteraction(self._runOperation, *args, **kw)
372
373
374    def close(self):
375        """
376        Close all pool connections and shutdown the pool.
377        """
378        if self.shutdownID:
379            self._reactor.removeSystemEventTrigger(self.shutdownID)
380            self.shutdownID = None
381        if self.startID:
382            self._reactor.removeSystemEventTrigger(self.startID)
383            self.startID = None
384        self.finalClose()
385
386    def finalClose(self):
387        """This should only be called by the shutdown trigger."""
388
389        self.shutdownID = None
390        self.threadpool.stop()
391        self.running = False
392        for conn in self.connections.values():
393            self._close(conn)
394        self.connections.clear()
395
396    def connect(self):
397        """Return a database connection when one becomes available.
398
399        This method blocks and should be run in a thread from the internal
400        threadpool. Don't call this method directly from non-threaded code.
401        Using this method outside the external threadpool may exceed the
402        maximum number of connections in the pool.
403
404        @return: a database connection from the pool.
405        """
406
407        tid = self.threadID()
408        conn = self.connections.get(tid)
409        if conn is None:
410            if self.noisy:
411                log.msg('adbapi connecting: %s %s%s' % (self.dbapiName,
412                                                        self.connargs or '',
413                                                        self.connkw or ''))
414            conn = self.dbapi.connect(*self.connargs, **self.connkw)
415            if self.openfun != None:
416                self.openfun(conn)
417            self.connections[tid] = conn
418        return conn
419
420    def disconnect(self, conn):
421        """Disconnect a database connection associated with this pool.
422
423        Note: This function should only be used by the same thread which
424        called connect(). As with connect(), this function is not used
425        in normal non-threaded twisted code.
426        """
427        tid = self.threadID()
428        if conn is not self.connections.get(tid):
429            raise Exception("wrong connection for thread")
430        if conn is not None:
431            self._close(conn)
432            del self.connections[tid]
433
434
435    def _close(self, conn):
436        if self.noisy:
437            log.msg('adbapi closing: %s' % (self.dbapiName,))
438        try:
439            conn.close()
440        except:
441            log.err(None, "Connection close failed")
442
443
444    def _runInteraction(self, interaction, *args, **kw):
445        conn = self.connectionFactory(self)
446        trans = self.transactionFactory(self, conn)
447        try:
448            result = interaction(trans, *args, **kw)
449            trans.close()
450            conn.commit()
451            return result
452        except:
453            excType, excValue, excTraceback = sys.exc_info()
454            try:
455                conn.rollback()
456            except:
457                log.err(None, "Rollback failed")
458            raise excType, excValue, excTraceback
459
460
461    def _runQuery(self, trans, *args, **kw):
462        trans.execute(*args, **kw)
463        return trans.fetchall()
464
465    def _runOperation(self, trans, *args, **kw):
466        trans.execute(*args, **kw)
467
468    def __getstate__(self):
469        return {'dbapiName': self.dbapiName,
470                'min': self.min,
471                'max': self.max,
472                'noisy': self.noisy,
473                'reconnect': self.reconnect,
474                'good_sql': self.good_sql,
475                'connargs': self.connargs,
476                'connkw': self.connkw}
477
478    def __setstate__(self, state):
479        self.__dict__ = state
480        self.__init__(self.dbapiName, *self.connargs, **self.connkw)
481
482
# ConnectionLost is raised to client code when cp_reconnect is enabled and a
# connection dies, so callers need it as part of the public API.
__all__ = ['Transaction', 'ConnectionPool', 'ConnectionLost']
484