1import threading
2import time
3from sqlalchemy import pool, select, event
4import sqlalchemy as tsa
5from sqlalchemy import testing
6from sqlalchemy.testing.util import gc_collect, lazy_gc
7from sqlalchemy.testing import eq_, assert_raises, is_not_, is_
8from sqlalchemy.testing.engines import testing_engine
9from sqlalchemy.testing import fixtures
10import random
11from sqlalchemy.testing.mock import Mock, call, patch, ANY
12import weakref
13import collections
14
15join_timeout = 10
16
17
def MockDBAPI():  # noqa
    """Build a Mock standing in for a DBAPI module.

    ``connect()`` produces a fresh Mock connection whose ``cursor()``
    returns a new Mock on every call.  ``shutdown(value)`` toggles a
    failure mode: while shut down, ``connect()`` raises
    ``Exception("connect failed")``; passing ``False`` restores normal
    connecting.  ``is_shutdown`` reflects the current mode.
    """
    def _make_cursor():
        return Mock()

    def _make_connection(*args, **kwargs):
        return Mock(cursor=Mock(side_effect=_make_cursor))

    def _shutdown(value):
        if value:
            db.connect = Mock(side_effect=Exception("connect failed"))
        else:
            db.connect = Mock(side_effect=_make_connection)
        db.is_shutdown = value

    db = Mock(
        connect=Mock(side_effect=_make_connection),
        shutdown=_shutdown,
        is_shutdown=False,
    )
    return db
37
38
class PoolTestBase(fixtures.TestBase):
    """Shared fixture support for pool tests.

    Tracks connections (weakly) so they can be closed on teardown, and
    provides helpers that build a ``QueuePool`` backed by a
    :func:`MockDBAPI`.
    """

    def setup(self):
        pool.clear_managers()
        self._teardown_conns = []

    def teardown(self):
        # close any registered connections still alive; dead weakrefs
        # return None and are skipped
        for ref in self._teardown_conns:
            conn = ref()
            if conn:
                conn.close()

    @classmethod
    def teardown_class(cls):
        pool.clear_managers()

    def _with_teardown(self, connection):
        """Register *connection* for closing at teardown; returns it."""
        self._teardown_conns.append(weakref.ref(connection))
        return connection

    def _queuepool_fixture(self, **kw):
        """Return just the QueuePool from _queuepool_dbapi_fixture."""
        # local renamed from ``pool`` so it does not shadow the
        # module-level ``sqlalchemy.pool`` import
        dbapi, queue_pool = self._queuepool_dbapi_fixture(**kw)
        return queue_pool

    def _queuepool_dbapi_fixture(self, **kw):
        """Return a ``(MockDBAPI, QueuePool)`` pair wired together."""
        dbapi = MockDBAPI()
        return dbapi, pool.QueuePool(
            creator=lambda: dbapi.connect('foo.db'),
            **kw)
67
68
class PoolTest(PoolTestBase):
    """Assorted pool behaviors: ``pool.manage()`` keying, raw cursor
    iteration, lazy connecting on ``recreate()``, threadlocal checkout
    and the per-connection-record ``.info`` dictionary."""

    def test_manager(self):
        """Threadlocal manager: equal connect args share a connection;
        differing args (including keyword args) get distinct ones."""
        manager = pool.manage(MockDBAPI(), use_threadlocal=True)

        c1 = manager.connect('foo.db')
        c2 = manager.connect('foo.db')
        c3 = manager.connect('bar.db')
        c4 = manager.connect("foo.db", bar="bat")
        c5 = manager.connect("foo.db", bar="hoho")
        c6 = manager.connect("foo.db", bar="bat")

        assert c1.cursor() is not None
        assert c1 is c2
        assert c1 is not c3
        assert c4 is c6
        assert c4 is not c5

    def test_manager_with_key(self):
        """``sa_pool_key`` overrides argument-based pool keying, so only
        two underlying DBAPI connects occur for three checkouts."""

        dbapi = MockDBAPI()
        manager = pool.manage(dbapi, use_threadlocal=True)

        c1 = manager.connect('foo.db', sa_pool_key="a")
        c2 = manager.connect('foo.db', sa_pool_key="b")
        c3 = manager.connect('bar.db', sa_pool_key="a")

        assert c1.cursor() is not None
        assert c1 is not c2
        assert c1 is c3

        eq_(
            dbapi.connect.mock_calls,
            [
                call("foo.db"),
                call("foo.db"),
            ]
        )

    def test_bad_args(self):
        # connecting with None as the argument should not raise
        manager = pool.manage(MockDBAPI())
        manager.connect(None)

    def test_non_thread_local_manager(self):
        """Without ``use_threadlocal``, every connect returns a distinct
        connection."""
        manager = pool.manage(MockDBAPI(), use_threadlocal=False)

        connection = manager.connect('foo.db')
        connection2 = manager.connect('foo.db')

        self.assert_(connection.cursor() is not None)
        self.assert_(connection is not connection2)

    @testing.fails_on('+pyodbc',
                      "pyodbc cursor doesn't implement tuple __eq__")
    @testing.fails_on("+pg8000", "returns [1], not (1,)")
    def test_cursor_iterable(self):
        """A raw pooled cursor is iterable, yielding row tuples."""
        conn = testing.db.raw_connection()
        cursor = conn.cursor()
        cursor.execute(str(select([1], bind=testing.db)))
        expected = [(1, )]
        for row in cursor:
            eq_(row, expected.pop(0))

    def test_no_connect_on_recreate(self):
        """``dispose()`` / ``recreate()`` must not invoke the creator;
        all pool classes connect lazily."""
        def creator():
            raise Exception("no creates allowed")

        for cls in (pool.SingletonThreadPool, pool.StaticPool,
                    pool.QueuePool, pool.NullPool, pool.AssertionPool):
            p = cls(creator=creator)
            p.dispose()
            p2 = p.recreate()
            assert p2.__class__ is cls

            # same check with a creator that worked once, then breaks
            mock_dbapi = MockDBAPI()
            p = cls(creator=mock_dbapi.connect)
            conn = p.connect()
            conn.close()
            mock_dbapi.connect.side_effect = Exception("error!")
            p.dispose()
            p.recreate()

    def test_threadlocal_del(self):
        self._do_testthreadlocal(useclose=False)

    def test_threadlocal_close(self):
        self._do_testthreadlocal(useclose=True)

    def _do_testthreadlocal(self, useclose=False):
        """With ``use_threadlocal=True``, repeated ``connect()`` within
        one thread returns the same checked-out connection, while
        ``unique_connection()`` bypasses that.  Checkin happens via
        ``close()`` or dereference + gc depending on *useclose*."""
        dbapi = MockDBAPI()
        for p in pool.QueuePool(creator=dbapi.connect,
                                pool_size=3, max_overflow=-1,
                                use_threadlocal=True), \
            pool.SingletonThreadPool(
                creator=dbapi.connect,
                use_threadlocal=True):
            c1 = p.connect()
            c2 = p.connect()
            self.assert_(c1 is c2)
            c3 = p.unique_connection()
            self.assert_(c3 is not c1)
            if useclose:
                c2.close()
            else:
                c2 = None
            c2 = p.connect()
            self.assert_(c1 is c2)
            self.assert_(c3 is not c1)
            if useclose:
                c2.close()
            else:
                c2 = None
                lazy_gc()
            if useclose:
                c1 = p.connect()
                c2 = p.connect()
                c3 = p.connect()
                c3.close()
                c2.close()
                self.assert_(c1.connection is not None)
                c1.close()
            c1 = c2 = c3 = None

            # extra tests with QueuePool to ensure connections get
            # __del__()ed when dereferenced

            if isinstance(p, pool.QueuePool):
                lazy_gc()
                self.assert_(p.checkedout() == 0)
                c1 = p.connect()
                c2 = p.connect()
                if useclose:
                    c2.close()
                    c1.close()
                else:
                    c2 = None
                    c1 = None
                    lazy_gc()
                self.assert_(p.checkedout() == 0)

    def test_info(self):
        """``.info`` lives on the connection record: it survives a
        checkin/checkout cycle, is discarded on ``invalidate()``, and
        follows a ``detach()``-ed connection."""
        p = self._queuepool_fixture(pool_size=1, max_overflow=0)

        c = p.connect()
        self.assert_(not c.info)
        self.assert_(c.info is c._connection_record.info)

        c.info['foo'] = 'bar'
        c.close()
        del c

        # same record is checked out again; info persisted
        c = p.connect()
        self.assert_('foo' in c.info)

        # invalidation discards the record's info
        c.invalidate()
        c = p.connect()
        self.assert_('foo' not in c.info)

        c.info['foo2'] = 'bar2'
        c.detach()
        self.assert_('foo2' in c.info)

        c2 = p.connect()
        is_not_(c.connection, c2.connection)
        assert not c2.info
        assert 'foo2' in c.info
234
235
class PoolDialectTest(PoolTestBase):
    """Verify each pool class delegates rollback/commit/close to its
    ``_dialect``, recording the order of those calls."""

    def _dialect(self):
        """Return a (dialect, recorded) pair; the dialect appends a
        marker to ``recorded`` for each operation it performs."""
        recorded = []

        class PoolDialect(object):
            def do_rollback(self, dbapi_connection):
                recorded.append('R')
                dbapi_connection.rollback()

            def do_commit(self, dbapi_connection):
                recorded.append('C')
                dbapi_connection.commit()

            def do_close(self, dbapi_connection):
                recorded.append('CL')
                dbapi_connection.close()

        return PoolDialect(), recorded

    def _do_test(self, pool_cls, assertion):
        """Run a checkout/checkin, dispose/recreate, and a second
        checkout/checkin against *pool_cls*; compare the recorded
        dialect operations against *assertion*."""
        dialect, recorded = self._dialect()

        p = pool_cls(creator=MockDBAPI().connect)
        p._dialect = dialect

        p.connect().close()
        p.dispose()
        p.recreate()
        p.connect().close()

        eq_(recorded, assertion)

    def test_queue_pool(self):
        self._do_test(pool.QueuePool, ['R', 'CL', 'R'])

    def test_assertion_pool(self):
        self._do_test(pool.AssertionPool, ['R', 'CL', 'R'])

    def test_singleton_pool(self):
        self._do_test(pool.SingletonThreadPool, ['R', 'CL', 'R'])

    def test_null_pool(self):
        self._do_test(pool.NullPool, ['R', 'CL', 'R', 'CL'])

    def test_static_pool(self):
        self._do_test(pool.StaticPool, ['R', 'R'])
282
283
class PoolEventsTest(PoolTestBase):
    """Pool event system tests: firing of ``first_connect``,
    ``connect``, ``checkout``, ``checkin``, ``reset`` and the two
    invalidate events, including propagation across ``recreate()`` and
    listener targeting by class vs. instance."""

    def _first_connect_event_fixture(self):
        # each fixture returns (pool, canary) where the canary records
        # one string per event firing
        p = self._queuepool_fixture()
        canary = []

        def first_connect(*arg, **kw):
            canary.append('first_connect')

        event.listen(p, 'first_connect', first_connect)

        return p, canary

    def _connect_event_fixture(self):
        p = self._queuepool_fixture()
        canary = []

        def connect(*arg, **kw):
            canary.append('connect')

        event.listen(p, 'connect', connect)

        return p, canary

    def _checkout_event_fixture(self):
        p = self._queuepool_fixture()
        canary = []

        def checkout(*arg, **kw):
            canary.append('checkout')
        event.listen(p, 'checkout', checkout)

        return p, canary

    def _checkin_event_fixture(self):
        p = self._queuepool_fixture()
        canary = []

        def checkin(*arg, **kw):
            canary.append('checkin')
        event.listen(p, 'checkin', checkin)

        return p, canary

    def _reset_event_fixture(self):
        p = self._queuepool_fixture()
        canary = []

        def reset(*arg, **kw):
            canary.append('reset')
        event.listen(p, 'reset', reset)

        return p, canary

    def _invalidate_event_fixture(self):
        # Mock canary: call args are inspected directly in the tests
        p = self._queuepool_fixture()
        canary = Mock()
        event.listen(p, 'invalidate', canary)

        return p, canary

    def _soft_invalidate_event_fixture(self):
        p = self._queuepool_fixture()
        canary = Mock()
        event.listen(p, 'soft_invalidate', canary)

        return p, canary

    def test_first_connect_event(self):
        p, canary = self._first_connect_event_fixture()

        p.connect()
        eq_(canary, ['first_connect'])

    def test_first_connect_event_fires_once(self):
        p, canary = self._first_connect_event_fixture()

        p.connect()
        p.connect()

        eq_(canary, ['first_connect'])

    def test_first_connect_on_previously_recreated(self):
        # a recreated pool fires its own first_connect
        p, canary = self._first_connect_event_fixture()

        p2 = p.recreate()
        p.connect()
        p2.connect()

        eq_(canary, ['first_connect', 'first_connect'])

    def test_first_connect_on_subsequently_recreated(self):
        p, canary = self._first_connect_event_fixture()

        p.connect()
        p2 = p.recreate()
        p2.connect()

        eq_(canary, ['first_connect', 'first_connect'])

    def test_connect_event(self):
        p, canary = self._connect_event_fixture()

        p.connect()
        eq_(canary, ['connect'])

    def test_connect_event_fires_subsequent(self):
        p, canary = self._connect_event_fixture()

        c1 = p.connect()  # noqa
        c2 = p.connect()  # noqa

        eq_(canary, ['connect', 'connect'])

    def test_connect_on_previously_recreated(self):
        p, canary = self._connect_event_fixture()

        p2 = p.recreate()

        p.connect()
        p2.connect()

        eq_(canary, ['connect', 'connect'])

    def test_connect_on_subsequently_recreated(self):
        p, canary = self._connect_event_fixture()

        p.connect()
        p2 = p.recreate()
        p2.connect()

        eq_(canary, ['connect', 'connect'])

    def test_checkout_event(self):
        p, canary = self._checkout_event_fixture()

        p.connect()
        eq_(canary, ['checkout'])

    def test_checkout_event_fires_subsequent(self):
        p, canary = self._checkout_event_fixture()

        p.connect()
        p.connect()
        eq_(canary, ['checkout', 'checkout'])

    def test_checkout_event_on_subsequently_recreated(self):
        p, canary = self._checkout_event_fixture()

        p.connect()
        p2 = p.recreate()
        p2.connect()

        eq_(canary, ['checkout', 'checkout'])

    def test_checkin_event(self):
        p, canary = self._checkin_event_fixture()

        c1 = p.connect()
        eq_(canary, [])
        c1.close()
        eq_(canary, ['checkin'])

    def test_reset_event(self):
        p, canary = self._reset_event_fixture()

        c1 = p.connect()
        eq_(canary, [])
        c1.close()
        eq_(canary, ['reset'])

    def test_soft_invalidate_event_no_exception(self):
        p, canary = self._soft_invalidate_event_fixture()

        c1 = p.connect()
        c1.close()
        assert not canary.called
        c1 = p.connect()
        dbapi_con = c1.connection
        c1.invalidate(soft=True)
        # event args: position 0 is the DBAPI connection,
        # position 2 is the exception (None here)
        assert canary.call_args_list[0][0][0] is dbapi_con
        assert canary.call_args_list[0][0][2] is None

    def test_soft_invalidate_event_exception(self):
        p, canary = self._soft_invalidate_event_fixture()

        c1 = p.connect()
        c1.close()
        assert not canary.called
        c1 = p.connect()
        dbapi_con = c1.connection
        exc = Exception("hi")
        c1.invalidate(exc, soft=True)
        assert canary.call_args_list[0][0][0] is dbapi_con
        assert canary.call_args_list[0][0][2] is exc

    def test_invalidate_event_no_exception(self):
        p, canary = self._invalidate_event_fixture()

        c1 = p.connect()
        c1.close()
        assert not canary.called
        c1 = p.connect()
        dbapi_con = c1.connection
        c1.invalidate()
        assert canary.call_args_list[0][0][0] is dbapi_con
        assert canary.call_args_list[0][0][2] is None

    def test_invalidate_event_exception(self):
        p, canary = self._invalidate_event_fixture()

        c1 = p.connect()
        c1.close()
        assert not canary.called
        c1 = p.connect()
        dbapi_con = c1.connection
        exc = Exception("hi")
        c1.invalidate(exc)
        assert canary.call_args_list[0][0][0] is dbapi_con
        assert canary.call_args_list[0][0][2] is exc

    def test_checkin_event_gc(self):
        # checkin also fires when the connection is gc'ed, not closed
        p, canary = self._checkin_event_fixture()

        c1 = p.connect()
        eq_(canary, [])
        del c1
        lazy_gc()
        eq_(canary, ['checkin'])

    def test_checkin_event_on_subsequently_recreated(self):
        p, canary = self._checkin_event_fixture()

        c1 = p.connect()
        p2 = p.recreate()
        c2 = p2.connect()

        eq_(canary, [])

        c1.close()
        eq_(canary, ['checkin'])

        c2.close()
        eq_(canary, ['checkin', 'checkin'])

    def test_listen_targets_scope(self):
        """Listeners registered against the Pool class, a pool
        instance, an Engine instance and an Engine class all fire for
        one execute, in the asserted order."""
        canary = []

        def listen_one(*args):
            canary.append("listen_one")

        def listen_two(*args):
            canary.append("listen_two")

        def listen_three(*args):
            canary.append("listen_three")

        def listen_four(*args):
            canary.append("listen_four")

        engine = testing_engine(testing.db.url)
        event.listen(pool.Pool, 'connect', listen_one)
        event.listen(engine.pool, 'connect', listen_two)
        event.listen(engine, 'connect', listen_three)
        event.listen(engine.__class__, 'connect', listen_four)

        engine.execute(select([1])).close()
        eq_(
            canary,
            ["listen_one", "listen_four", "listen_two", "listen_three"]
        )

    def test_listen_targets_per_subclass(self):
        """test that listen() called on a subclass remains specific to
        that subclass."""

        canary = []

        def listen_one(*args):
            canary.append("listen_one")

        def listen_two(*args):
            canary.append("listen_two")

        def listen_three(*args):
            canary.append("listen_three")

        event.listen(pool.Pool, 'connect', listen_one)
        event.listen(pool.QueuePool, 'connect', listen_two)
        event.listen(pool.SingletonThreadPool, 'connect', listen_three)

        p1 = pool.QueuePool(creator=MockDBAPI().connect)
        p2 = pool.SingletonThreadPool(creator=MockDBAPI().connect)

        assert listen_one in p1.dispatch.connect
        assert listen_two in p1.dispatch.connect
        assert listen_three not in p1.dispatch.connect
        assert listen_one in p2.dispatch.connect
        assert listen_two not in p2.dispatch.connect
        assert listen_three in p2.dispatch.connect

        p1.connect()
        eq_(canary, ["listen_one", "listen_two"])
        p2.connect()
        eq_(canary, ["listen_one", "listen_two", "listen_one", "listen_three"])

    def teardown(self):
        # TODO: need to get remove() functionality
        # going
        pool.Pool.dispatch._clear()
593
594
class PoolFirstConnectSyncTest(PoolTestBase):
    # test [ticket:2964]

    @testing.requires.timing_intensive
    def test_sync(self):
        """Concurrent first checkouts must wait for the (slow)
        'first_connect' handler to finish before any 'connect' events
        are processed: exactly one first_connect, then the connects."""
        # NOTE: local name shadows the module-level ``pool`` import,
        # but the module is not referenced below
        pool = self._queuepool_fixture(pool_size=3, max_overflow=0)

        evt = Mock()

        @event.listens_for(pool, 'first_connect')
        def slow_first_connect(dbapi_con, rec):
            # deliberately slow so other threads pile up behind it
            time.sleep(1)
            evt.first_connect()

        @event.listens_for(pool, 'connect')
        def on_connect(dbapi_con, rec):
            evt.connect()

        def checkout():
            for j in range(2):
                c1 = pool.connect()
                time.sleep(.02)
                c1.close()
                time.sleep(.02)

        threads = []
        for i in range(5):
            th = threading.Thread(target=checkout)
            th.start()
            threads.append(th)
        for th in threads:
            th.join(join_timeout)

        # one first_connect, then one connect per pooled connection
        eq_(
            evt.mock_calls,
            [
                call.first_connect(),
                call.connect(),
                call.connect(),
                call.connect()]
        )
636
637
class DeprecatedPoolListenerTest(PoolTestBase):
    """Tests for the deprecated ``add_listener()`` / ``listeners=``
    PoolListener API, covering per-event listener counts, event
    recording, and behavior across invalidate/detach/recreate."""

    @testing.requires.predictable_gc
    @testing.uses_deprecated(r".*Use event.listen")
    def test_listeners(self):

        class InstrumentingListener(object):
            """Records connections seen per event; subclasses declare
            which events they listen for by defining the method, and
            __init__ swaps in the recording implementation."""
            def __init__(self):
                if hasattr(self, 'connect'):
                    self.connect = self.inst_connect
                if hasattr(self, 'first_connect'):
                    self.first_connect = self.inst_first_connect
                if hasattr(self, 'checkout'):
                    self.checkout = self.inst_checkout
                if hasattr(self, 'checkin'):
                    self.checkin = self.inst_checkin
                self.clear()

            def clear(self):
                self.connected = []
                self.first_connected = []
                self.checked_out = []
                self.checked_in = []

            def assert_total(self, conn, fconn, cout, cin):
                eq_(len(self.connected), conn)
                eq_(len(self.first_connected), fconn)
                eq_(len(self.checked_out), cout)
                eq_(len(self.checked_in), cin)

            def assert_in(
                    self, item, in_conn, in_fconn,
                    in_cout, in_cin):
                eq_((item in self.connected), in_conn)
                eq_((item in self.first_connected), in_fconn)
                eq_((item in self.checked_out), in_cout)
                eq_((item in self.checked_in), in_cin)

            def inst_connect(self, con, record):
                print("connect(%s, %s)" % (con, record))
                assert con is not None
                assert record is not None
                self.connected.append(con)

            def inst_first_connect(self, con, record):
                print("first_connect(%s, %s)" % (con, record))
                assert con is not None
                assert record is not None
                self.first_connected.append(con)

            def inst_checkout(self, con, record, proxy):
                print("checkout(%s, %s, %s)" % (con, record, proxy))
                assert con is not None
                assert record is not None
                assert proxy is not None
                self.checked_out.append(con)

            def inst_checkin(self, con, record):
                print("checkin(%s, %s)" % (con, record))
                # con can be None if invalidated
                assert record is not None
                self.checked_in.append(con)

        class ListenAll(tsa.interfaces.PoolListener, InstrumentingListener):
            pass

        class ListenConnect(InstrumentingListener):
            def connect(self, con, record):
                pass

        class ListenFirstConnect(InstrumentingListener):
            def first_connect(self, con, record):
                pass

        class ListenCheckOut(InstrumentingListener):
            def checkout(self, con, record, proxy, num):
                pass

        class ListenCheckIn(InstrumentingListener):
            def checkin(self, con, record):
                pass

        def assert_listeners(p, total, conn, fconn, cout, cin):
            # counts must survive a recreate() as well
            for instance in (p, p.recreate()):
                self.assert_(len(instance.dispatch.connect) == conn)
                self.assert_(len(instance.dispatch.first_connect) == fconn)
                self.assert_(len(instance.dispatch.checkout) == cout)
                self.assert_(len(instance.dispatch.checkin) == cin)

        p = self._queuepool_fixture()
        assert_listeners(p, 0, 0, 0, 0, 0)

        p.add_listener(ListenAll())
        assert_listeners(p, 1, 1, 1, 1, 1)

        p.add_listener(ListenConnect())
        assert_listeners(p, 2, 2, 1, 1, 1)

        p.add_listener(ListenFirstConnect())
        assert_listeners(p, 3, 2, 2, 1, 1)

        p.add_listener(ListenCheckOut())
        assert_listeners(p, 4, 2, 2, 2, 1)

        p.add_listener(ListenCheckIn())
        assert_listeners(p, 5, 2, 2, 2, 2)
        del p

        snoop = ListenAll()
        p = self._queuepool_fixture(listeners=[snoop])
        assert_listeners(p, 1, 1, 1, 1, 1)

        c = p.connect()
        snoop.assert_total(1, 1, 1, 0)
        cc = c.connection
        snoop.assert_in(cc, True, True, True, False)
        c.close()
        snoop.assert_in(cc, True, True, True, True)
        del c, cc

        snoop.clear()

        # this one depends on immediate gc
        c = p.connect()
        cc = c.connection
        snoop.assert_in(cc, False, False, True, False)
        snoop.assert_total(0, 0, 1, 0)
        del c, cc
        lazy_gc()
        snoop.assert_total(0, 0, 1, 1)

        p.dispose()
        snoop.clear()

        c = p.connect()
        c.close()
        c = p.connect()
        snoop.assert_total(1, 0, 2, 1)
        c.close()
        snoop.assert_total(1, 0, 2, 2)

        # invalidation
        p.dispose()
        snoop.clear()

        c = p.connect()
        snoop.assert_total(1, 0, 1, 0)
        c.invalidate()
        snoop.assert_total(1, 0, 1, 1)
        c.close()
        snoop.assert_total(1, 0, 1, 1)
        del c
        lazy_gc()
        snoop.assert_total(1, 0, 1, 1)
        c = p.connect()
        snoop.assert_total(2, 0, 2, 1)
        c.close()
        del c
        lazy_gc()
        snoop.assert_total(2, 0, 2, 2)

        # detached
        p.dispose()
        snoop.clear()

        c = p.connect()
        snoop.assert_total(1, 0, 1, 0)
        c.detach()
        snoop.assert_total(1, 0, 1, 0)
        c.close()
        del c
        snoop.assert_total(1, 0, 1, 0)
        c = p.connect()
        snoop.assert_total(2, 0, 2, 0)
        c.close()
        del c
        snoop.assert_total(2, 0, 2, 1)

        # recreated
        p = p.recreate()
        snoop.clear()

        c = p.connect()
        snoop.assert_total(1, 1, 1, 0)
        c.close()
        snoop.assert_total(1, 1, 1, 1)
        c = p.connect()
        snoop.assert_total(1, 1, 2, 1)
        c.close()
        snoop.assert_total(1, 1, 2, 2)

    @testing.uses_deprecated(r".*Use event.listen")
    def test_listeners_callables(self):
        """Listeners may also be plain dicts of callables; verify
        per-event counts and firing for QueuePool and StaticPool."""
        def connect(dbapi_con, con_record):
            counts[0] += 1

        def checkout(dbapi_con, con_record, con_proxy):
            counts[1] += 1

        def checkin(dbapi_con, con_record):
            counts[2] += 1

        i_all = dict(connect=connect, checkout=checkout, checkin=checkin)
        i_connect = dict(connect=connect)
        i_checkout = dict(checkout=checkout)
        i_checkin = dict(checkin=checkin)

        for cls in (pool.QueuePool, pool.StaticPool):
            counts = [0, 0, 0]

            def assert_listeners(p, total, conn, cout, cin):
                for instance in (p, p.recreate()):
                    eq_(len(instance.dispatch.connect), conn)
                    eq_(len(instance.dispatch.checkout), cout)
                    eq_(len(instance.dispatch.checkin), cin)

            p = self._queuepool_fixture()
            assert_listeners(p, 0, 0, 0, 0)

            p.add_listener(i_all)
            assert_listeners(p, 1, 1, 1, 1)

            p.add_listener(i_connect)
            assert_listeners(p, 2, 1, 1, 1)

            p.add_listener(i_checkout)
            assert_listeners(p, 3, 1, 1, 1)

            p.add_listener(i_checkin)
            assert_listeners(p, 4, 1, 1, 1)
            del p

            p = self._queuepool_fixture(listeners=[i_all])
            assert_listeners(p, 1, 1, 1, 1)

            c = p.connect()
            assert counts == [1, 1, 0]
            c.close()
            assert counts == [1, 1, 1]

            c = p.connect()
            assert counts == [1, 2, 1]
            p.add_listener(i_checkin)
            c.close()
            assert counts == [1, 2, 2]
882
883
884class QueuePoolTest(PoolTestBase):
885
886    def test_queuepool_del(self):
887        self._do_testqueuepool(useclose=False)
888
889    def test_queuepool_close(self):
890        self._do_testqueuepool(useclose=True)
891
    def _do_testqueuepool(self, useclose=False):
        """Walk a pool_size=3 / unlimited-overflow pool through
        checkout/checkin cycles, asserting the tuple
        (size, checkedin, overflow, checkedout) at each step.
        Checkin happens via close() or via dereference + gc depending
        on *useclose*."""
        p = self._queuepool_fixture(
            pool_size=3,
            max_overflow=-1)

        def status(pool):
            # -> (size, checkedin, overflow, checkedout)
            return pool.size(), pool.checkedin(), pool.overflow(), \
                pool.checkedout()

        c1 = p.connect()
        self.assert_(status(p) == (3, 0, -2, 1))
        c2 = p.connect()
        self.assert_(status(p) == (3, 0, -1, 2))
        c3 = p.connect()
        self.assert_(status(p) == (3, 0, 0, 3))
        c4 = p.connect()
        self.assert_(status(p) == (3, 0, 1, 4))
        c5 = p.connect()
        self.assert_(status(p) == (3, 0, 2, 5))
        c6 = p.connect()
        self.assert_(status(p) == (3, 0, 3, 6))
        if useclose:
            c4.close()
            c3.close()
            c2.close()
        else:
            c4 = c3 = c2 = None
            lazy_gc()
        self.assert_(status(p) == (3, 3, 3, 3))
        if useclose:
            c1.close()
            c5.close()
            c6.close()
        else:
            c1 = c5 = c6 = None
            lazy_gc()
        self.assert_(status(p) == (3, 3, 0, 0))
        c1 = p.connect()
        c2 = p.connect()
        self.assert_(status(p) == (3, 1, 0, 2), status(p))
        if useclose:
            c2.close()
        else:
            c2 = None
            lazy_gc()
        self.assert_(status(p) == (3, 2, 0, 1))
        c1.close()
        lazy_gc()
        # no connection records should remain referenced
        assert not pool._refs
941
942    @testing.requires.timing_intensive
943    def test_timeout(self):
944        p = self._queuepool_fixture(
945            pool_size=3,
946            max_overflow=0,
947            timeout=2)
948        c1 = p.connect()  # noqa
949        c2 = p.connect()  # noqa
950        c3 = p.connect()  # noqa
951        now = time.time()
952
953        assert_raises(
954            tsa.exc.TimeoutError,
955            p.connect
956        )
957        assert int(time.time() - now) == 2
958
    @testing.requires.threading_with_mock
    @testing.requires.timing_intensive
    def test_timeout_race(self):
        """Ten threads race a pool of 2+1 overflow with a 3s timeout;
        the losers must wait out the full timeout (>= 3s) before
        receiving TimeoutError."""
        # test a race condition where the initial connecting threads all race
        # to queue.Empty, then block on the mutex.  each thread consumes a
        # connection as they go in.  when the limit is reached, the remaining
        # threads go in, and get TimeoutError; even though they never got to
        # wait for the timeout on queue.get().  the fix involves checking the
        # timeout again within the mutex, and if so, unlocking and throwing
        # them back to the start of do_get()
        dbapi = MockDBAPI()
        p = pool.QueuePool(
            creator=lambda: dbapi.connect(delay=.05),
            pool_size=2,
            max_overflow=1, use_threadlocal=False, timeout=3)
        timeouts = []

        def checkout():
            for x in range(1):
                now = time.time()
                try:
                    c1 = p.connect()
                except tsa.exc.TimeoutError:
                    # record how long we actually waited
                    timeouts.append(time.time() - now)
                    continue
                time.sleep(4)
                c1.close()

        threads = []
        for i in range(10):
            th = threading.Thread(target=checkout)
            th.start()
            threads.append(th)
        for th in threads:
            th.join(join_timeout)

        assert len(timeouts) > 0
        for t in timeouts:
            assert t >= 3, "Not all timeouts were >= 3 seconds %r" % timeouts
            # normally, the timeout should under 4 seconds,
            # but on a loaded down buildbot it can go up.
            assert t < 14, "Not all timeouts were < 14 seconds %r" % timeouts
1001
    def _test_overflow(self, thread_count, max_overflow):
        """Hammer a size-3 pool with ``thread_count`` threads and verify
        the observed overflow never exceeds ``max_overflow``."""
        gc_collect()

        dbapi = MockDBAPI()
        mutex = threading.Lock()

        def creator():
            # slow the connect down slightly so threads pile up on the pool
            time.sleep(.05)
            with mutex:
                return dbapi.connect()

        p = pool.QueuePool(creator=creator,
                           pool_size=3, timeout=2,
                           max_overflow=max_overflow)
        peaks = []

        def whammy():
            for i in range(10):
                try:
                    con = p.connect()
                    time.sleep(.005)
                    # sample the current overflow while holding a connection
                    peaks.append(p.overflow())
                    con.close()
                    del con
                except tsa.exc.TimeoutError:
                    pass
        threads = []
        for i in range(thread_count):
            th = threading.Thread(target=whammy)
            th.start()
            threads.append(th)
        for th in threads:
            th.join(join_timeout)

        self.assert_(max(peaks) <= max_overflow)

        lazy_gc()
        # all fairies should have been cleaned up; no lingering refs
        assert not pool._refs
1040
1041    def test_overflow_reset_on_failed_connect(self):
1042        dbapi = Mock()
1043
1044        def failing_dbapi():
1045            time.sleep(2)
1046            raise Exception("connection failed")
1047
1048        creator = dbapi.connect
1049
1050        def create():
1051            return creator()
1052
1053        p = pool.QueuePool(creator=create, pool_size=2, max_overflow=3)
1054        c1 = self._with_teardown(p.connect())  # noqa
1055        c2 = self._with_teardown(p.connect())  # noqa
1056        c3 = self._with_teardown(p.connect())  # noqa
1057        eq_(p._overflow, 1)
1058        creator = failing_dbapi
1059        assert_raises(Exception, p.connect)
1060        eq_(p._overflow, 1)
1061
    @testing.requires.threading_with_mock
    @testing.requires.timing_intensive
    def test_hanging_connect_within_overflow(self):
        """test that a single connect() call which is hanging
        does not block other connections from proceeding."""

        dbapi = Mock()
        mutex = threading.Lock()

        def hanging_dbapi():
            # simulate a connect that takes two seconds to complete
            time.sleep(2)
            with mutex:
                return dbapi.connect()

        def fast_dbapi():
            with mutex:
                return dbapi.connect()

        # thread-local slot so each worker selects its own connector
        creator = threading.local()

        def create():
            return creator.mock_connector()

        def run_test(name, pool, should_hang):
            if should_hang:
                creator.mock_connector = hanging_dbapi
            else:
                creator.mock_connector = fast_dbapi

            conn = pool.connect()
            conn.operation(name)
            time.sleep(1)
            conn.close()

        p = pool.QueuePool(creator=create, pool_size=2, max_overflow=3)

        threads = [
            threading.Thread(
                target=run_test, args=("success_one", p, False)),
            threading.Thread(
                target=run_test, args=("success_two", p, False)),
            threading.Thread(
                target=run_test, args=("overflow_one", p, True)),
            threading.Thread(
                target=run_test, args=("overflow_two", p, False)),
            threading.Thread(
                target=run_test, args=("overflow_three", p, False))
        ]
        # stagger the starts so the hanging connect is in flight when
        # the later overflow connects arrive
        for t in threads:
            t.start()
            time.sleep(.2)

        for t in threads:
            t.join(timeout=join_timeout)
        # the hanging connect finishes last; everyone else proceeded
        eq_(
            dbapi.connect().operation.mock_calls,
            [call("success_one"), call("success_two"),
                call("overflow_two"), call("overflow_three"),
                call("overflow_one")]
        )
1122
    @testing.requires.threading_with_mock
    @testing.requires.timing_intensive
    def test_waiters_handled(self):
        """test that threads waiting for connections are
        handled when the pool is replaced.

        """
        mutex = threading.Lock()
        dbapi = MockDBAPI()

        def creator():
            # serialize access to the (non-threadsafe) mock DBAPI
            mutex.acquire()
            try:
                return dbapi.connect()
            finally:
                mutex.release()

        success = []
        # 2 timeout settings x 3 overflow settings x 2 waiter threads
        # should produce 12 successful checkouts in total
        for timeout in (None, 30):
            for max_overflow in (0, -1, 3):
                p = pool.QueuePool(creator=creator,
                                   pool_size=2, timeout=timeout,
                                   max_overflow=max_overflow)

                def waiter(p, timeout, max_overflow):
                    success_key = (timeout, max_overflow)
                    conn = p.connect()
                    success.append(success_key)
                    time.sleep(.1)
                    conn.close()

                c1 = p.connect()  # noqa
                c2 = p.connect()

                threads = []
                for i in range(2):
                    t = threading.Thread(
                        target=waiter,
                        args=(p, timeout, max_overflow))
                    t.daemon = True
                    t.start()
                    threads.append(t)

                # this sleep makes sure that the
                # two waiter threads hit upon wait()
                # inside the queue, before we invalidate the other
                # two conns
                time.sleep(.2)
                p._invalidate(c2)

                for t in threads:
                    t.join(join_timeout)

        eq_(len(success), 12, "successes: %s" % success)
1177
    def test_connrec_invalidated_within_checkout_no_race(self):
        """Test that a concurrent ConnectionRecord.invalidate() which
        occurs after the ConnectionFairy has called
        _ConnectionRecord.checkout()
        but before the ConnectionFairy tests "fairy.connection is None"
        will not result in an InvalidRequestError.

        This use case assumes that a listener on the checkout() event
        will be raising DisconnectionError so that a reconnect attempt
        may occur.

        """
        dbapi = MockDBAPI()

        def creator():
            return dbapi.connect()

        p = pool.QueuePool(creator=creator, pool_size=1, max_overflow=0)

        conn = p.connect()
        conn.close()

        _existing_checkout = pool._ConnectionRecord.checkout

        @classmethod
        def _decorate_existing_checkout(cls, *arg, **kw):
            # run the real checkout, then invalidate the record before
            # the fairy has inspected it -- simulating the race
            # deterministically
            fairy = _existing_checkout(*arg, **kw)
            connrec = fairy._connection_record
            connrec.invalidate()
            return fairy

        with patch(
                "sqlalchemy.pool._ConnectionRecord.checkout",
                _decorate_existing_checkout):
            conn = p.connect()
            # the record was invalidated mid-checkout; no error raised
            is_(conn._connection_record.connection, None)
        conn.close()
1215
    @testing.requires.threading_with_mock
    @testing.requires.timing_intensive
    def test_notify_waiters(self):
        """Waiters blocked on a size-1 pool are woken up when the
        checked-out connection is invalidated."""
        dbapi = MockDBAPI()

        # canary records 1 for each connection created and 2 for each
        # successful waiter checkout
        canary = []

        def creator():
            canary.append(1)
            return dbapi.connect()
        p1 = pool.QueuePool(
            creator=creator,
            pool_size=1, timeout=None,
            max_overflow=0)

        def waiter(p):
            conn = p.connect()
            canary.append(2)
            time.sleep(.5)
            conn.close()

        c1 = p1.connect()

        threads = []
        for i in range(5):
            t = threading.Thread(target=waiter, args=(p1, ))
            t.start()
            threads.append(t)
        time.sleep(.5)
        # all five waiters still blocked; only the initial connect ran
        eq_(canary, [1])

        # this also calls invalidate()
        # on c1
        p1._invalidate(c1)

        for t in threads:
            t.join(join_timeout)

        # one reconnect (the second 1), then all five waiters succeed
        eq_(canary, [1, 1, 2, 2, 2, 2, 2])
1255
1256    def test_dispose_closes_pooled(self):
1257        dbapi = MockDBAPI()
1258
1259        p = pool.QueuePool(creator=dbapi.connect,
1260                           pool_size=2, timeout=None,
1261                           max_overflow=0)
1262        c1 = p.connect()
1263        c2 = p.connect()
1264        c1_con = c1.connection
1265        c2_con = c2.connection
1266
1267        c1.close()
1268
1269        eq_(c1_con.close.call_count, 0)
1270        eq_(c2_con.close.call_count, 0)
1271
1272        p.dispose()
1273
1274        eq_(c1_con.close.call_count, 1)
1275        eq_(c2_con.close.call_count, 0)
1276
1277        # currently, if a ConnectionFairy is closed
1278        # after the pool has been disposed, there's no
1279        # flag that states it should be invalidated
1280        # immediately - it just gets returned to the
1281        # pool normally...
1282        c2.close()
1283        eq_(c1_con.close.call_count, 1)
1284        eq_(c2_con.close.call_count, 0)
1285
1286        # ...and that's the one we'll get back next.
1287        c3 = p.connect()
1288        assert c3.connection is c2_con
1289
    @testing.requires.threading_with_mock
    @testing.requires.timing_intensive
    def test_no_overflow(self):
        # 40 threads against a pool that permits zero overflow
        self._test_overflow(40, 0)
1294
    @testing.requires.threading_with_mock
    @testing.requires.timing_intensive
    def test_max_overflow(self):
        # 40 threads against a pool permitting 5 overflow connections
        self._test_overflow(40, 5)
1299
    def test_mixed_close(self):
        """In threadlocal mode the same fairy is returned to the same
        thread; the checkout is only released once all references to it
        are dropped and collected."""
        pool._refs.clear()
        p = self._queuepool_fixture(pool_size=3, max_overflow=-1,
                                    use_threadlocal=True)
        c1 = p.connect()
        c2 = p.connect()
        # threadlocal: both checkouts are the same fairy
        assert c1 is c2
        c1.close()
        c2 = None
        # still referenced via c1, so still counted as checked out
        assert p.checkedout() == 1
        c1 = None
        lazy_gc()
        # last reference gone; weakref cleanup returns the connection
        assert p.checkedout() == 0
        lazy_gc()
        assert not pool._refs
1315
    def test_overflow_no_gc_tlocal(self):
        # see _test_overflow_no_gc for the actual test body
        self._test_overflow_no_gc(True)
1318
    def test_overflow_no_gc(self):
        # see _test_overflow_no_gc for the actual test body
        self._test_overflow_no_gc(False)
1321
    def _test_overflow_no_gc(self, threadlocal):
        """Exercise overflow connections while holding strong references
        to the raw DBAPI connections, so that cleanup happens through
        explicit close() rather than garbage collection.

        NOTE(review): the ``threadlocal`` argument is currently unused;
        the fixture below is created without ``use_threadlocal``.
        Confirm whether the _tlocal variant is still meant to exercise
        threadlocal mode.
        """
        p = self._queuepool_fixture(
            pool_size=2,
            max_overflow=2)

        # disable weakref collection of the
        # underlying connections
        strong_refs = set()

        def _conn():
            c = p.connect()
            strong_refs.add(c.connection)
            return c

        for j in range(5):
            # open 4 conns at a time.  each time this
            # will yield two pooled connections + two
            # overflow connections.
            conns = [_conn() for i in range(4)]
            for c in conns:
                c.close()

        # doing that for a total of 5 times yields
        # ten overflow connections closed plus the
        # two pooled connections unclosed.

        eq_(
            set([c.close.call_count for c in strong_refs]),
            set([1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0])
        )
1352
    @testing.requires.predictable_gc
    def test_weakref_kaboom(self):
        """Dropping all references to a checked-out fairy returns the
        connection to the pool via the weakref callback."""
        p = self._queuepool_fixture(
            pool_size=3,
            max_overflow=-1, use_threadlocal=True)
        c1 = p.connect()
        c2 = p.connect()
        c1.close()
        c2 = None
        del c1
        del c2
        gc_collect()
        # weakref cleanup has checked the connection back in
        assert p.checkedout() == 0
        c3 = p.connect()
        assert c3 is not None
1368
    def test_trick_the_counter(self):
        """this is a "flaw" in the connection pool; since threadlocal
        uses a single ConnectionFairy per thread with an open/close
        counter, you can fool the counter into giving you a
        ConnectionFairy with an ambiguous counter.  i.e. its not true
        reference counting."""

        p = self._queuepool_fixture(
            pool_size=3,
            max_overflow=-1, use_threadlocal=True)
        c1 = p.connect()
        c2 = p.connect()
        assert c1 is c2
        c1.close()
        # re-checkout bumps the shared counter back up...
        c2 = p.connect()
        c2.close()
        # ...leaving the pool thinking a connection is still out
        self.assert_(p.checkedout() != 0)
        c2.close()
        self.assert_(p.checkedout() == 0)
1388
1389    def test_recycle(self):
1390        with patch("sqlalchemy.pool.time.time") as mock:
1391            mock.return_value = 10000
1392
1393            p = self._queuepool_fixture(
1394                pool_size=1,
1395                max_overflow=0,
1396                recycle=30)
1397            c1 = p.connect()
1398            c_ref = weakref.ref(c1.connection)
1399            c1.close()
1400            mock.return_value = 10001
1401            c2 = p.connect()
1402
1403            is_(c2.connection, c_ref())
1404            c2.close()
1405
1406            mock.return_value = 10035
1407            c3 = p.connect()
1408            is_not_(c3.connection, c_ref())
1409
    @testing.requires.timing_intensive
    def test_recycle_on_invalidate(self):
        """_invalidate() on the pool discards the connection; the next
        checkout creates a fresh one."""
        p = self._queuepool_fixture(
            pool_size=1,
            max_overflow=0)
        c1 = p.connect()
        c_ref = weakref.ref(c1.connection)
        c1.close()
        c2 = p.connect()
        # the same underlying connection is reused on normal checkin
        is_(c2.connection, c_ref())

        c2_rec = c2._connection_record
        p._invalidate(c2)
        # hard invalidation clears the record's connection immediately
        assert c2_rec.connection is None
        c2.close()
        # let the invalidation timestamp age past the next checkout
        time.sleep(.5)
        c3 = p.connect()

        is_not_(c3.connection, c_ref())
1429
    @testing.requires.timing_intensive
    def test_recycle_on_soft_invalidate(self):
        """Soft invalidation keeps the connection usable until checkin,
        then replaces it within the same connection record."""
        p = self._queuepool_fixture(
            pool_size=1,
            max_overflow=0)
        c1 = p.connect()
        c_ref = weakref.ref(c1.connection)
        c1.close()
        c2 = p.connect()
        is_(c2.connection, c_ref())

        c2_rec = c2._connection_record
        c2.invalidate(soft=True)
        # soft: the record still holds the live connection
        is_(c2_rec.connection, c2.connection)

        c2.close()
        time.sleep(.5)
        c3 = p.connect()
        # new DBAPI connection, but the same record was reused
        is_not_(c3.connection, c_ref())
        is_(c3._connection_record, c2_rec)
        is_(c2_rec.connection, c3.connection)
1451
    def _no_wr_finalize(self):
        """Return a patch context that wraps pool._finalize_fairy so the
        test fails if it is ever invoked as a weakref callback
        (fairy=None) rather than through an explicit close."""
        finalize_fairy = pool._finalize_fairy

        def assert_no_wr_callback(
            connection, connection_record,
                pool, ref, echo, fairy=None):
            if fairy is None:
                raise AssertionError(
                    "finalize fairy was called as a weakref callback")
            return finalize_fairy(
                connection, connection_record, pool, ref, echo, fairy)
        return patch.object(
            pool, '_finalize_fairy', assert_no_wr_callback)
1465
    def _assert_cleanup_on_pooled_reconnect(self, dbapi, p):
        # p is QueuePool with size=1, max_overflow=2,
        # and one connection in the pool that will need to
        # reconnect when next used (either due to recycle or invalidate)

        with self._no_wr_finalize():
            eq_(p.checkedout(), 0)
            eq_(p._overflow, 0)
            # make the DBAPI refuse connections, then verify a failed
            # reconnect leaves the bookkeeping untouched
            dbapi.shutdown(True)
            assert_raises(
                Exception,
                p.connect
            )
            eq_(p._overflow, 0)
            eq_(p.checkedout(), 0)  # and not 1

            dbapi.shutdown(False)

            c1 = self._with_teardown(p.connect())  # noqa
            assert p._pool.empty()  # poolsize is one, so we're empty OK
            c2 = self._with_teardown(p.connect())  # noqa
            eq_(p._overflow, 1)  # and not 2

            # this hangs if p._overflow is 2
            c3 = self._with_teardown(p.connect())

            c3.close()
1493
    def test_error_on_pooled_reconnect_cleanup_invalidate(self):
        # a hard-invalidated pooled connection must reconnect cleanly
        dbapi, p = self._queuepool_dbapi_fixture(pool_size=1, max_overflow=2)
        c1 = p.connect()
        c1.invalidate()
        c1.close()
        self._assert_cleanup_on_pooled_reconnect(dbapi, p)
1500
    @testing.requires.timing_intensive
    def test_error_on_pooled_reconnect_cleanup_recycle(self):
        dbapi, p = self._queuepool_dbapi_fixture(
            pool_size=1,
            max_overflow=2, recycle=1)
        c1 = p.connect()
        c1.close()
        # sleep past the 1-second recycle interval so the pooled
        # connection must reconnect on next use
        time.sleep(1.5)
        self._assert_cleanup_on_pooled_reconnect(dbapi, p)
1510
    def test_connect_handler_not_called_for_recycled(self):
        """test [ticket:3497]"""

        dbapi, p = self._queuepool_dbapi_fixture(
            pool_size=2, max_overflow=2)

        canary = Mock()

        c1 = p.connect()
        c2 = p.connect()

        c1.close()
        c2.close()

        dbapi.shutdown(True)

        # check out while the DBAPI is down, then invalidate so the
        # pool enters "recently invalidated" state
        bad = p.connect()
        p._invalidate(bad)
        bad.close()
        assert p._invalidate_time

        event.listen(p, "connect", canary.connect)
        event.listen(p, "checkout", canary.checkout)

        assert_raises(
            Exception,
            p.connect
        )

        # prune dead records from the underlying queue
        p._pool.queue = collections.deque(
            [
                c for c in p._pool.queue
                if c.connection is not None
            ]
        )

        dbapi.shutdown(False)
        c = p.connect()
        c.close()

        # exactly one connect and one checkout event fired; the failed
        # and invalidated attempts did not produce extra events
        eq_(
            canary.mock_calls,
            [
                call.connect(ANY, ANY),
                call.checkout(ANY, ANY, ANY)
            ]
        )
1558
    def test_connect_checkout_handler_always_gets_info(self):
        """test [ticket:3497]"""

        dbapi, p = self._queuepool_dbapi_fixture(
            pool_size=2, max_overflow=2)

        c1 = p.connect()
        c2 = p.connect()

        c1.close()
        c2.close()

        dbapi.shutdown(True)

        # check out while the DBAPI is down, then invalidate so the
        # pool enters "recently invalidated" state
        bad = p.connect()
        p._invalidate(bad)
        bad.close()
        assert p._invalidate_time

        @event.listens_for(p, "connect")
        def connect(conn, conn_rec):
            conn_rec.info['x'] = True

        @event.listens_for(p, "checkout")
        def checkout(conn, conn_rec, conn_f):
            # any record seen by checkout must have been through the
            # connect handler first
            assert 'x' in conn_rec.info

        assert_raises(
            Exception,
            p.connect
        )

        # prune dead records from the underlying queue
        p._pool.queue = collections.deque(
            [
                c for c in p._pool.queue
                if c.connection is not None
            ]
        )

        dbapi.shutdown(False)
        c = p.connect()
        c.close()
1601
    def test_error_on_pooled_reconnect_cleanup_wcheckout_event(self):
        dbapi, p = self._queuepool_dbapi_fixture(
            pool_size=1,
            max_overflow=2)

        c1 = p.connect()
        c1.close()

        @event.listens_for(p, "checkout")
        def handle_checkout_event(dbapi_con, con_record, con_proxy):
            # emulate a disconnect-detecting listener, the documented
            # use of DisconnectionError in a checkout handler
            if dbapi.is_shutdown:
                raise tsa.exc.DisconnectionError()

        self._assert_cleanup_on_pooled_reconnect(dbapi, p)
1616
    @testing.requires.timing_intensive
    def test_recycle_pool_no_race(self):
        """Concurrent invalidations that trigger a pool recreate must
        not race: at most one additional pool is created."""
        def slow_close():
            slow_closing_connection._slow_close()
            time.sleep(.5)

        slow_closing_connection = Mock()
        slow_closing_connection.connect.return_value.close = slow_close

        class Error(Exception):
            pass

        dialect = Mock()
        dialect.is_disconnect = lambda *arg, **kw: True
        dialect.dbapi.Error = Error

        pools = []

        class TrackQueuePool(pool.QueuePool):
            # records every pool instance created so we can count
            # how many recreates occurred
            def __init__(self, *arg, **kw):
                pools.append(self)
                super(TrackQueuePool, self).__init__(*arg, **kw)

        def creator():
            return slow_closing_connection.connect()
        p1 = TrackQueuePool(creator=creator, pool_size=20)

        from sqlalchemy import create_engine
        eng = create_engine(testing.db.url, pool=p1, _initialize=False)
        eng.dialect = dialect

        # 15 total connections
        conns = [eng.connect() for i in range(15)]

        # return 8 back to the pool
        for conn in conns[3:10]:
            conn.close()

        def attempt(conn):
            # random delay spreads the invalidations out in time
            time.sleep(random.random())
            try:
                conn._handle_dbapi_exception(
                    Error(), "statement", {},
                    Mock(), Mock())
            except tsa.exc.DBAPIError:
                pass

        # run an error + invalidate operation on the remaining 7 open
        # connections
        threads = []
        for conn in conns:
            t = threading.Thread(target=attempt, args=(conn, ))
            t.start()
            threads.append(t)

        for t in threads:
            t.join()

        # return all 15 connections to the pool
        for conn in conns:
            conn.close()

        # re-open 15 total connections
        conns = [eng.connect() for i in range(15)]

        # 15 connections have been fully closed due to invalidate
        assert slow_closing_connection._slow_close.call_count == 15

        # 15 initial connections + 15 reconnections
        assert slow_closing_connection.connect.call_count == 30
        assert len(pools) <= 2, len(pools)
1688
1689    def test_invalidate(self):
1690        p = self._queuepool_fixture(pool_size=1, max_overflow=0)
1691        c1 = p.connect()
1692        c_id = c1.connection.id
1693        c1.close()
1694        c1 = None
1695        c1 = p.connect()
1696        assert c1.connection.id == c_id
1697        c1.invalidate()
1698        c1 = None
1699        c1 = p.connect()
1700        assert c1.connection.id != c_id
1701
1702    def test_recreate(self):
1703        p = self._queuepool_fixture(reset_on_return=None, pool_size=1,
1704                                    max_overflow=0)
1705        p2 = p.recreate()
1706        assert p2.size() == 1
1707        assert p2._reset_on_return is pool.reset_none
1708        assert p2._use_threadlocal is False
1709        assert p2._max_overflow == 0
1710
    def test_reconnect(self):
        """tests reconnect operations at the pool level.  SA's
        engine/dialect includes another layer of reconnect support for
        'database was lost' errors."""

        dbapi, p = self._queuepool_dbapi_fixture(pool_size=1, max_overflow=0)
        c1 = p.connect()
        c_id = c1.connection.id
        c1.close()
        c1 = None
        c1 = p.connect()
        # the pooled connection is reused
        assert c1.connection.id == c_id
        dbapi.raise_error = True
        c1.invalidate()
        c1 = None
        # invalidation discards the old connection; a new one is made
        c1 = p.connect()
        assert c1.connection.id != c_id
1728
1729    def test_detach(self):
1730        dbapi, p = self._queuepool_dbapi_fixture(pool_size=1, max_overflow=0)
1731        c1 = p.connect()
1732        c1.detach()
1733        c2 = p.connect()  # noqa
1734        eq_(dbapi.connect.mock_calls, [call("foo.db"), call("foo.db")])
1735
1736        c1_con = c1.connection
1737        assert c1_con is not None
1738        eq_(c1_con.close.call_count, 0)
1739        c1.close()
1740        eq_(c1_con.close.call_count, 1)
1741
1742    def test_detach_via_invalidate(self):
1743        dbapi, p = self._queuepool_dbapi_fixture(pool_size=1, max_overflow=0)
1744
1745        c1 = p.connect()
1746        c1_con = c1.connection
1747        c1.invalidate()
1748        assert c1.connection is None
1749        eq_(c1_con.close.call_count, 1)
1750
1751        c2 = p.connect()
1752        assert c2.connection is not c1_con
1753        c2_con = c2.connection
1754
1755        c2.close()
1756        eq_(c2_con.close.call_count, 0)
1757
1758    def test_threadfairy(self):
1759        p = self._queuepool_fixture(pool_size=3, max_overflow=-1,
1760                                    use_threadlocal=True)
1761        c1 = p.connect()
1762        c1.close()
1763        c2 = p.connect()
1764        assert c2.connection is not None
1765
1766
class ResetOnReturnTest(PoolTestBase):
    """Tests for QueuePool's ``reset_on_return`` modes: 'rollback',
    'commit', None, and delegation to a ``_reset_agent``."""

    def _fixture(self, **kw):
        dbapi = Mock()
        p = pool.QueuePool(
            creator=lambda: dbapi.connect('foo.db'),
            **kw)
        return dbapi, p

    def test_plain_rollback(self):
        # 'rollback' mode issues a rollback when returned to the pool
        dbapi, p = self._fixture(reset_on_return='rollback')

        p.connect().close()
        assert dbapi.connect().rollback.called
        assert not dbapi.connect().commit.called

    def test_plain_commit(self):
        # 'commit' mode issues a commit when returned to the pool
        dbapi, p = self._fixture(reset_on_return='commit')

        p.connect().close()
        assert not dbapi.connect().rollback.called
        assert dbapi.connect().commit.called

    def test_plain_none(self):
        # None disables any reset on return
        dbapi, p = self._fixture(reset_on_return=None)

        p.connect().close()
        assert not dbapi.connect().rollback.called
        assert not dbapi.connect().commit.called

    def test_agent_rollback(self):
        dbapi, p = self._fixture(reset_on_return='rollback')

        class Agent(object):
            """Reset agent that redirects the reset to special_*
            methods on the wrapped connection."""

            def __init__(self, conn):
                self.conn = conn

            def rollback(self):
                self.conn.special_rollback()

            def commit(self):
                self.conn.special_commit()

        fairy = p.connect()
        fairy._reset_agent = Agent(fairy)
        fairy.close()

        # the agent intercepted the reset...
        assert dbapi.connect().special_rollback.called
        assert not dbapi.connect().special_commit.called

        # ...so no plain rollback/commit took place
        assert not dbapi.connect().rollback.called
        assert not dbapi.connect().commit.called

        # with no agent installed, the plain rollback happens and the
        # agent's counts stay where they were
        fairy = p.connect()
        fairy.close()
        eq_(dbapi.connect().special_rollback.call_count, 1)
        eq_(dbapi.connect().special_commit.call_count, 0)

        assert dbapi.connect().rollback.called
        assert not dbapi.connect().commit.called

    def test_agent_commit(self):
        dbapi, p = self._fixture(reset_on_return='commit')

        class Agent(object):
            """Reset agent that redirects the reset to special_*
            methods on the wrapped connection."""

            def __init__(self, conn):
                self.conn = conn

            def rollback(self):
                self.conn.special_rollback()

            def commit(self):
                self.conn.special_commit()

        fairy = p.connect()
        fairy._reset_agent = Agent(fairy)
        fairy.close()
        # the agent intercepted the reset...
        assert not dbapi.connect().special_rollback.called
        assert dbapi.connect().special_commit.called

        # ...so no plain rollback/commit took place
        assert not dbapi.connect().rollback.called
        assert not dbapi.connect().commit.called

        # with no agent installed, the plain commit happens and the
        # agent's counts stay where they were
        fairy = p.connect()
        fairy.close()

        eq_(dbapi.connect().special_rollback.call_count, 0)
        eq_(dbapi.connect().special_commit.call_count, 1)
        assert not dbapi.connect().rollback.called
        assert dbapi.connect().commit.called
1858
1859
class SingletonThreadPoolTest(PoolTestBase):
    """Tests for SingletonThreadPool, which keeps one connection per
    thread and prunes itself down to pool_size connections."""

    @testing.requires.threading_with_mock
    def test_cleanup(self):
        self._test_cleanup(False)

    @testing.requires.threading_with_mock
    def test_cleanup_no_gc(self):
        self._test_cleanup(True)

    def _test_cleanup(self, strong_refs):
        """test that the pool's connections are OK after cleanup() has
        been called."""

        dbapi = MockDBAPI()

        lock = threading.Lock()

        def creator():
            # the mock iterator isn't threadsafe...
            with lock:
                return dbapi.connect()
        p = pool.SingletonThreadPool(creator=creator, pool_size=3)

        if strong_refs:
            # hold strong references so connections can't be
            # garbage-collected out from under the pool
            sr = set()

            def _conn():
                c = p.connect()
                sr.add(c.connection)
                return c
        else:
            def _conn():
                return p.connect()

        def checkout():
            for x in range(10):
                c = _conn()
                assert c
                c.cursor()
                c.close()
                time.sleep(.1)

        threads = []
        for i in range(10):
            th = threading.Thread(target=checkout)
            th.start()
            threads.append(th)
        for th in threads:
            th.join(join_timeout)
        # pruning keeps at most pool_size connections alive
        assert len(p._all_conns) == 3

        if strong_refs:
            # exactly pool_size connections remain un-closed
            still_opened = len([c for c in sr if not c.close.call_count])
            eq_(still_opened, 3)
1915
1916
class AssertionPoolTest(PoolTestBase):
    """AssertionPool allows at most one checked-out connection at a
    time and raises AssertionError on a second concurrent checkout."""

    def test_connect_error(self):
        dbapi = MockDBAPI()
        p = pool.AssertionPool(creator=lambda: dbapi.connect('foo.db'))
        held = p.connect()  # noqa
        # a second checkout while one is outstanding must fail
        assert_raises(AssertionError, p.connect)

    def test_connect_multiple(self):
        dbapi = MockDBAPI()
        p = pool.AssertionPool(creator=lambda: dbapi.connect('foo.db'))
        # sequential checkout/checkin cycles are fine
        for _ in range(2):
            conn = p.connect()
            conn.close()

        held = p.connect()  # noqa
        assert_raises(AssertionError, p.connect)
1934
1935
class NullPoolTest(PoolTestBase):
    """NullPool opens a fresh DBAPI connection for every checkout and
    never pools; invalidation simply discards the connection."""

    def test_reconnect(self):
        dbapi = MockDBAPI()
        p = pool.NullPool(creator=lambda: dbapi.connect('foo.db'))

        conn = p.connect()
        conn.close()
        conn = None

        conn = p.connect()
        conn.invalidate()
        conn = None

        conn = p.connect()
        # at least two distinct connects were made against the DBAPI
        dbapi.connect.assert_has_calls(
            [call('foo.db'), call('foo.db')],
            any_order=True)
1955
1956
class StaticPoolTest(PoolTestBase):
    """StaticPool serves a single connection; recreate() keeps the same
    creator callable."""

    def test_recreate(self):
        dbapi = MockDBAPI()

        def creator():
            return dbapi.connect('foo.db')

        original = pool.StaticPool(creator)
        clone = original.recreate()
        assert original._creator is clone._creator
1966
1967
class CreatorCompatibilityTest(PoolTestBase):
    """Tests that the pool accepts both zero-argument and one-argument
    (connection-record) creator callables, including when the creator
    is monkeypatched after pool creation.

    Fix: the previous ``try: conn = creator() finally: conn.close()``
    pattern referenced ``conn`` in the ``finally`` block; if the
    creator itself raised, that produced an UnboundLocalError masking
    the real failure.  The try/finally guarded nothing between the
    assignment and the close, so it is replaced with a plain
    assign-then-close.
    """

    def test_creator_callable_outside_noarg(self):
        e = testing_engine()

        creator = e.pool._creator
        conn = creator()
        conn.close()

    def test_creator_callable_outside_witharg(self):
        e = testing_engine()

        creator = e.pool._creator
        conn = creator(Mock())
        conn.close()

    def test_creator_patching_arg_to_noarg(self):
        e = testing_engine()
        creator = e.pool._creator
        # the creator is the two-arg form
        conn = creator(Mock())
        conn.close()

        def mock_create():
            return creator()

        conn = e.connect()
        conn.invalidate()
        conn.close()

        # test that the 'should_wrap_creator' status
        # will dynamically switch if the _creator is monkeypatched.

        # patch it with a zero-arg form
        with patch.object(e.pool, "_creator", mock_create):
            conn = e.connect()
            conn.invalidate()
            conn.close()

        conn = e.connect()
        conn.close()
2014